package repositories_clover
import (
"fmt"
clover "github.com/ostafen/clover/v2"
)
// NewDB opens (or creates) the clover database stored at path and makes sure
// every collection listed in collections exists.
//
// Collections that already exist are left untouched, so NewDB also works on a
// database created by an earlier run. If any setup step fails, the database
// handle is closed before the error is returned so it is not leaked.
func NewDB(path string, collections []string) (*clover.DB, error) {
	db, err := clover.Open(path)
	if err != nil {
		return nil, fmt.Errorf("failed to connect to database: %w", err)
	}

	for _, collection := range collections {
		exists, err := db.HasCollection(collection)
		if err != nil {
			_ = db.Close() // best-effort cleanup; the lookup error is what the caller needs
			return nil, fmt.Errorf("failed to check collection %s: %w", collection, err)
		}
		if exists {
			// CreateCollection rejects existing collections, which would make
			// every restart on a populated database fail.
			continue
		}
		if err := db.CreateCollection(collection); err != nil {
			_ = db.Close() // best-effort cleanup; the creation error is what the caller needs
			return nil, fmt.Errorf("failed to create collection %s: %w", collection, err)
		}
	}
	return db, nil
}
package repositories_clover
import (
clover "github.com/ostafen/clover/v2"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// DeploymentRequestFlatClover provides repositories.DeploymentRequestFlat on a
// Clover database; every operation comes from the embedded generic repository.
type DeploymentRequestFlatClover struct {
	repositories.GenericRepository[types.DeploymentRequestFlat]
}

// NewDeploymentRequestFlat wires a Clover-backed generic repository for
// types.DeploymentRequestFlat into a DeploymentRequestFlatClover.
func NewDeploymentRequestFlat(db *clover.DB) repositories.DeploymentRequestFlat {
	generic := NewGenericRepository[types.DeploymentRequestFlat](db)
	return &DeploymentRequestFlatClover{GenericRepository: generic}
}
package repositories_clover
import (
"github.com/ostafen/clover/v2"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// RequestTrackerClover provides repositories.RequestTracker on a Clover
// database; every operation comes from the embedded generic repository.
type RequestTrackerClover struct {
	repositories.GenericRepository[types.RequestTracker]
}

// NewRequestTracker wires a Clover-backed generic repository for
// types.RequestTracker into a RequestTrackerClover.
func NewRequestTracker(db *clover.DB) repositories.RequestTracker {
	generic := NewGenericRepository[types.RequestTracker](db)
	return &RequestTrackerClover{GenericRepository: generic}
}
package repositories_clover
import (
"github.com/ostafen/clover/v2"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// VirtualMachineClover provides repositories.VirtualMachine on a Clover
// database; every operation comes from the embedded generic repository.
type VirtualMachineClover struct {
	repositories.GenericRepository[types.VirtualMachine]
}

// NewVirtualMachine wires a Clover-backed generic repository for
// types.VirtualMachine into a VirtualMachineClover.
func NewVirtualMachine(db *clover.DB) repositories.VirtualMachine {
	generic := NewGenericRepository[types.VirtualMachine](db)
	return &VirtualMachineClover{GenericRepository: generic}
}
package repositories_clover
import (
"context"
"fmt"
"reflect"
"time"
"github.com/iancoleman/strcase"
clover "github.com/ostafen/clover/v2"
clover_d "github.com/ostafen/clover/v2/document"
clover_q "github.com/ostafen/clover/v2/query"
"gitlab.com/nunet/device-management-service/db/repositories"
)
// pKField is the primary-key field name of Clover documents.
// NOTE(review): not referenced in this file and duplicates pkField declared
// elsewhere in the package — confirm whether it is still needed.
const pKField = "_id"
// GenericEntityRepositoryClover is a generic single-entity repository backed
// by Clover. Embed it in single-entity model repositories to get the
// Save/Get/Clear/History operations.
type GenericEntityRepositoryClover[T repositories.ModelType] struct {
	db         *clover.DB // Clover database handle
	collection string     // collection holding every saved version of the entity
}

// NewGenericEntityRepository returns a Clover-backed GenericEntityRepository
// for T. Documents live in the collection named after T's type name converted
// to snake_case (e.g. FreeResources -> "free_resources").
func NewGenericEntityRepository[T repositories.ModelType](
	db *clover.DB,
) repositories.GenericEntityRepository[T] {
	var zero T
	typeName := reflect.TypeOf(zero).Name()
	return &GenericEntityRepositoryClover[T]{
		db:         db,
		collection: strcase.ToSnake(typeName),
	}
}
// GetQuery returns a zero-valued Query for T, to be filled in by the caller.
func (repo *GenericEntityRepositoryClover[T]) GetQuery() repositories.Query[T] {
	var empty repositories.Query[T]
	return empty
}

// query returns a fresh, unfiltered Clover query over the entity's collection.
func (repo *GenericEntityRepositoryClover[T]) query() *clover_q.Query {
	fresh := clover_q.NewQuery(repo.collection)
	return fresh
}
// Save creates or updates the record to the repository and returns the new/updated data.
func (repo *GenericEntityRepositoryClover[T]) Save(ctx context.Context, data T) (T, error) {
var model T
doc := toCloverDoc(data)
doc.Set("CreatedAt", time.Now())
_, err := repo.db.InsertOne(repo.collection, doc)
if err != nil {
return data, handleDBError(err)
}
model, err = toModel[T](doc, true)
if err != nil {
return model, handleDBError(fmt.Errorf("%v: %v", repositories.ErrParsingModel, err))
}
return model, nil
}
// Get returns the most recently saved version of the entity, i.e. the
// document with the greatest "CreatedAt".
//
// When nothing has been saved yet, the error is handleDBError's mapping of
// clover.ErrDocumentNotExist (a not-found error) rather than a zero value
// with a nil error. Decoding failures are reported with
// repositories.ErrParsingModel, like the other repository methods.
func (repo *GenericEntityRepositoryClover[T]) Get(ctx context.Context) (T, error) {
	var model T
	q := repo.query().Sort(clover_q.SortOption{
		Field:     "CreatedAt",
		Direction: -1,
	})
	doc, err := repo.db.FindFirst(q)
	if err != nil {
		return model, handleDBError(err)
	}
	if doc == nil {
		// FindFirst signals "no match" as (nil, nil); surface it as not found.
		return model, handleDBError(clover.ErrDocumentNotExist)
	}
	model, err = toModel[T](doc, true)
	if err != nil {
		// %w keeps ErrParsingModel matchable with errors.Is; the text is unchanged.
		return model, handleDBError(fmt.Errorf("%w: %v", repositories.ErrParsingModel, err))
	}
	return model, nil
}
// Clear deletes every stored version of the entity from its collection.
// Errors are normalized through handleDBError, like every other method of
// this repository.
func (repo *GenericEntityRepositoryClover[T]) Clear(ctx context.Context) error {
	return handleDBError(repo.db.Delete(repo.query()))
}
// History returns the stored versions of the entity that match query.
// Conditions, instance fields, sorting, limit and offset are applied by
// applyConditions.
//
// A document that cannot be decoded into T stops the iteration. The versions
// decoded so far are returned together with a parsing error, instead of
// silently returning a truncated list with a nil error.
func (repo *GenericEntityRepositoryClover[T]) History(
	ctx context.Context,
	query repositories.Query[T],
) ([]T, error) {
	var models []T
	var parseErr error
	q := applyConditions(repo.query(), query)
	err := repo.db.ForEach(q, func(doc *clover_d.Document) bool {
		model, convErr := toModel[T](doc, true)
		if convErr != nil {
			parseErr = handleDBError(fmt.Errorf("%w: %v", repositories.ErrParsingModel, convErr))
			return false
		}
		models = append(models, model)
		return true
	})
	if err != nil {
		return models, handleDBError(err)
	}
	return models, parseErr
}
package repositories_clover
import (
"context"
"fmt"
"reflect"
"time"
"github.com/iancoleman/strcase"
clover "github.com/ostafen/clover/v2"
clover_d "github.com/ostafen/clover/v2/document"
clover_q "github.com/ostafen/clover/v2/query"
"gitlab.com/nunet/device-management-service/db/repositories"
)
// pkField is the primary-key field name of Clover documents.
const pkField = "_id"

// deletedAtField holds the soft-delete timestamp set by Delete and checked by
// the default (non-deleted) query filter.
const deletedAtField = "DeletedAt"
// GenericRepositoryClover is a generic, soft-delete aware repository backed by
// Clover. Embed it in model repositories to get the basic CRUD and query
// operations.
type GenericRepositoryClover[T repositories.ModelType] struct {
	db         *clover.DB // Clover database handle
	collection string     // collection holding T's documents
}

// NewGenericRepository returns a Clover-backed GenericRepository for T. Its
// documents live in the collection named after T's type name in snake_case.
func NewGenericRepository[T repositories.ModelType](
	db *clover.DB,
) repositories.GenericRepository[T] {
	var zero T
	return &GenericRepositoryClover[T]{
		db:         db,
		collection: strcase.ToSnake(reflect.TypeOf(zero).Name()),
	}
}
// GetQuery returns a clean Query instance for building queries.
func (repo *GenericRepositoryClover[T]) GetQuery() repositories.Query[T] {
return repositories.Query[T]{}
}
func (repo *GenericRepositoryClover[T]) query(includeDeleted bool) *clover_q.Query {
query := clover_q.NewQuery(repo.collection)
if !includeDeleted {
query = query.Where(clover_q.Field(deletedAtField).LtEq(time.Unix(0, 0)))
}
return query
}
// queryWithID narrows query(includeDeleted) to the document whose primary key
// equals id.
//
// Clover keys are strings. A non-string id (e.g. a fmt.Stringer ID type) is
// converted with fmt.Sprint instead of panicking on a failed type assertion.
// An id that matches no key simply yields an empty result.
func (repo *GenericRepositoryClover[T]) queryWithID(
	id interface{},
	includeDeleted bool,
) *clover_q.Query {
	key, ok := id.(string)
	if !ok {
		key = fmt.Sprint(id)
	}
	return repo.query(includeDeleted).Where(clover_q.Field(pkField).Eq(key))
}
// Create adds a new record to the repository and returns the created data.
func (repo *GenericRepositoryClover[T]) Create(ctx context.Context, data T) (T, error) {
var model T
doc := toCloverDoc(data)
doc.Set("CreatedAt", time.Now())
_, err := repo.db.InsertOne(repo.collection, doc)
if err != nil {
return data, handleDBError(err)
}
model, err = toModel[T](doc, false)
if err != nil {
return data, handleDBError(fmt.Errorf("%v: %v", repositories.ErrParsingModel, err))
}
return model, nil
}
// Get returns the non-deleted record whose primary key is id.
//
// A missing or soft-deleted record is reported as handleDBError's mapping of
// clover.ErrDocumentNotExist (not found), the same way Find reports no match,
// instead of a zero value with a nil error.
func (repo *GenericRepositoryClover[T]) Get(ctx context.Context, id interface{}) (T, error) {
	var model T
	doc, err := repo.db.FindFirst(repo.queryWithID(id, false))
	if err != nil {
		return model, handleDBError(err)
	}
	if doc == nil {
		// FindFirst signals "no match" as (nil, nil).
		return model, handleDBError(clover.ErrDocumentNotExist)
	}
	model, err = toModel[T](doc, false)
	if err != nil {
		// %w keeps ErrParsingModel matchable with errors.Is; the text is unchanged.
		return model, handleDBError(fmt.Errorf("%w: %v", repositories.ErrParsingModel, err))
	}
	return model, nil
}
// Update overwrites the non-deleted record identified by id with the
// JSON-encoded fields of data, stamps "UpdatedAt" with the current time, and
// returns the record as re-read through Get.
//
// Get's error is returned as-is. It has already been normalized by
// handleDBError, and normalizing it again would wrap a not-found error into a
// generic database error.
//
// NOTE(review): every field produced by toCloverDoc is written, including
// zero values of fields without `omitempty`. An unset field in data (e.g.
// CreatedAt) may therefore overwrite the stored value — confirm this is
// intended.
func (repo *GenericRepositoryClover[T]) Update(
	ctx context.Context,
	id interface{},
	data T,
) (T, error) {
	updates := toCloverDoc(data).AsMap()
	updates["UpdatedAt"] = time.Now()
	if err := repo.db.Update(repo.queryWithID(id, false), updates); err != nil {
		return data, handleDBError(err)
	}
	return repo.Get(ctx, id)
}
// Delete soft-deletes the record identified by id by setting its DeletedAt
// field to the current time. Get, Find and FindAll then no longer return it,
// since they only query non-deleted documents.
// Errors are normalized through handleDBError, like the other methods.
func (repo *GenericRepositoryClover[T]) Delete(ctx context.Context, id interface{}) error {
	err := repo.db.Update(
		repo.queryWithID(id, false),
		map[string]interface{}{deletedAtField: time.Now()},
	)
	return handleDBError(err)
}
// Find returns the first non-deleted record matching query (see
// applyConditions for how the query is translated).
//
// No match is reported as handleDBError's mapping of
// clover.ErrDocumentNotExist. A document that cannot be decoded into T is
// reported with repositories.ErrParsingModel, consistently with Get and
// FindAll.
func (repo *GenericRepositoryClover[T]) Find(
	ctx context.Context,
	query repositories.Query[T],
) (T, error) {
	var model T
	q := applyConditions(repo.query(false), query)
	doc, err := repo.db.FindFirst(q)
	if err != nil {
		return model, handleDBError(err)
	}
	if doc == nil {
		return model, handleDBError(clover.ErrDocumentNotExist)
	}
	model, err = toModel[T](doc, false)
	if err != nil {
		return model, handleDBError(fmt.Errorf("%w: %v", repositories.ErrParsingModel, err))
	}
	return model, nil
}
// FindAll returns every non-deleted record matching query, translated by
// applyConditions.
//
// Iteration stops at the first document that cannot be decoded into T. The
// records decoded so far are then returned together with the parsing error.
// A database error takes precedence over a parsing error.
func (repo *GenericRepositoryClover[T]) FindAll(
	ctx context.Context,
	query repositories.Query[T],
) ([]T, error) {
	var (
		results  []T
		parseErr error
	)
	collect := func(doc *clover_d.Document) bool {
		item, convErr := toModel[T](doc, false)
		if convErr == nil {
			results = append(results, item)
			return true
		}
		parseErr = handleDBError(fmt.Errorf("%v: %v", repositories.ErrParsingModel, convErr))
		return false
	}

	if err := repo.db.ForEach(applyConditions(repo.query(false), query), collect); err != nil {
		return results, handleDBError(err)
	}
	return results, parseErr
}
// applyConditions translates a repositories.Query into clauses on the Clover
// query q and returns the extended query.
//
// It applies, in this order:
//   - query.Conditions: one clause per condition. Supported operators are
//     "=", ">", ">=", "<", "<=", "!=", "IN" (Value must be []interface{}) and
//     "LIKE" (Value must be a string). Any other operator or value type is
//     silently ignored.
//   - query.Instance: an equality clause for every exported field whose value
//     is not empty according to repositories.IsEmptyValue.
//   - query.SortBy: sort by that field, descending when prefixed with '-'.
//   - query.Limit / query.Offset: applied when positive. Offset skips that
//     many documents.
//
// Every Go field name (including the sort field, in either direction) is
// mapped through fieldJSONTag. This is needed because documents are stored
// under their JSON keys (see toCloverDoc).
func applyConditions[T repositories.ModelType](
	q *clover_q.Query,
	query repositories.Query[T],
) *clover_q.Query {
	for _, condition := range query.Conditions {
		field := clover_q.Field(fieldJSONTag[T](condition.Field))
		switch condition.Operator {
		case "=":
			q = q.Where(field.Eq(condition.Value))
		case ">":
			q = q.Where(field.Gt(condition.Value))
		case ">=":
			q = q.Where(field.GtEq(condition.Value))
		case "<":
			q = q.Where(field.Lt(condition.Value))
		case "<=":
			q = q.Where(field.LtEq(condition.Value))
		case "!=":
			q = q.Where(field.Neq(condition.Value))
		case "IN":
			if values, ok := condition.Value.([]interface{}); ok {
				q = q.Where(field.In(values...))
			}
		case "LIKE":
			if pattern, ok := condition.Value.(string); ok {
				q = q.Where(field.Like(pattern))
			}
		}
	}

	if !repositories.IsEmptyValue(query.Instance) {
		instanceType := reflect.TypeOf(query.Instance)
		instanceValue := reflect.ValueOf(query.Instance)
		for i := 0; i < instanceType.NumField(); i++ {
			structField := instanceType.Field(i)
			// Value.Interface panics on unexported fields, and such fields are
			// never JSON-encoded into documents anyway.
			if !structField.IsExported() {
				continue
			}
			value := instanceValue.Field(i).Interface()
			if repositories.IsEmptyValue(value) {
				continue
			}
			q = q.Where(clover_q.Field(fieldJSONTag[T](structField.Name)).Eq(value))
		}
	}

	if query.SortBy != "" {
		sortField, direction := query.SortBy, 1
		if sortField[0] == '-' {
			sortField, direction = sortField[1:], -1
		}
		// Map to the JSON key for both directions; previously only descending
		// sorts were mapped, so ascending sorts on tagged fields matched nothing.
		q = q.Sort(clover_q.SortOption{Field: fieldJSONTag[T](sortField), Direction: direction})
	}

	if query.Limit > 0 {
		q = q.Limit(query.Limit)
	}
	if query.Offset > 0 {
		// Offset skips documents; it previously (wrongly) overrode the limit.
		q = q.Skip(query.Offset)
	}
	return q
}
package repositories_clover
import (
"github.com/ostafen/clover/v2"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// PeerInfoClover provides repositories.PeerInfo on a Clover database; every
// operation comes from the embedded generic repository.
type PeerInfoClover struct {
	repositories.GenericRepository[types.PeerInfo]
}

// NewPeerInfo wires a Clover-backed generic repository for types.PeerInfo
// into a PeerInfoClover.
func NewPeerInfo(db *clover.DB) repositories.PeerInfo {
	return &PeerInfoClover{
		GenericRepository: NewGenericRepository[types.PeerInfo](db),
	}
}
// MachineClover provides repositories.Machine on a Clover database; every
// operation comes from the embedded generic repository.
type MachineClover struct {
	repositories.GenericRepository[types.Machine]
}

// NewMachine wires a Clover-backed generic repository for types.Machine into
// a MachineClover.
func NewMachine(db *clover.DB) repositories.Machine {
	return &MachineClover{
		GenericRepository: NewGenericRepository[types.Machine](db),
	}
}
// FreeResourcesClover provides repositories.FreeResources on a Clover
// database; every operation comes from the embedded single-entity repository.
type FreeResourcesClover struct {
	repositories.GenericEntityRepository[types.FreeResources]
}

// NewFreeResources wires a Clover-backed single-entity repository for
// types.FreeResources into a FreeResourcesClover.
func NewFreeResources(db *clover.DB) repositories.FreeResources {
	return &FreeResourcesClover{
		GenericEntityRepository: NewGenericEntityRepository[types.FreeResources](db),
	}
}
// AvailableResourcesClover provides repositories.AvailableResources on a
// Clover database; every operation comes from the embedded single-entity
// repository.
type AvailableResourcesClover struct {
	repositories.GenericEntityRepository[types.AvailableResources]
}

// AvailableResourcesRepositoryClover has the same shape as
// AvailableResourcesClover.
// NOTE(review): nothing in this file constructs it — verify whether it is
// still used before removing it.
type AvailableResourcesRepositoryClover struct {
	repositories.GenericEntityRepository[types.AvailableResources]
}

// NewAvailableResources wires a Clover-backed single-entity repository for
// types.AvailableResources into an AvailableResourcesClover.
func NewAvailableResources(db *clover.DB) repositories.AvailableResources {
	entity := NewGenericEntityRepository[types.AvailableResources](db)
	return &AvailableResourcesClover{GenericEntityRepository: entity}
}
// ServicesClover provides repositories.Services on a Clover database; every
// operation comes from the embedded generic repository.
type ServicesClover struct {
	repositories.GenericRepository[types.Services]
}

// NewServices wires a Clover-backed generic repository for types.Services
// into a ServicesClover.
func NewServices(db *clover.DB) repositories.Services {
	return &ServicesClover{
		GenericRepository: NewGenericRepository[types.Services](db),
	}
}
// ServiceResourceRequirementsClover provides
// repositories.ServiceResourceRequirements on a Clover database; every
// operation comes from the embedded generic repository.
type ServiceResourceRequirementsClover struct {
	repositories.GenericRepository[types.ServiceResourceRequirements]
}

// NewServiceResourceRequirements wires a Clover-backed generic repository for
// types.ServiceResourceRequirements into a ServiceResourceRequirementsClover.
func NewServiceResourceRequirements(
	db *clover.DB,
) repositories.ServiceResourceRequirements {
	generic := NewGenericRepository[types.ServiceResourceRequirements](db)
	return &ServiceResourceRequirementsClover{GenericRepository: generic}
}
// Libp2pInfoClover provides repositories.Libp2pInfo on a Clover database;
// every operation comes from the embedded single-entity repository.
type Libp2pInfoClover struct {
	repositories.GenericEntityRepository[types.Libp2pInfo]
}

// NewLibp2pInfo wires a Clover-backed single-entity repository for
// types.Libp2pInfo into a Libp2pInfoClover.
func NewLibp2pInfo(db *clover.DB) repositories.Libp2pInfo {
	return &Libp2pInfoClover{
		GenericEntityRepository: NewGenericEntityRepository[types.Libp2pInfo](db),
	}
}
// MachineUUIDClover provides repositories.MachineUUID on a Clover database;
// every operation comes from the embedded single-entity repository.
type MachineUUIDClover struct {
	repositories.GenericEntityRepository[types.MachineUUID]
}

// NewMachineUUID wires a Clover-backed single-entity repository for
// types.MachineUUID into a MachineUUIDClover.
func NewMachineUUID(db *clover.DB) repositories.MachineUUID {
	return &MachineUUIDClover{
		GenericEntityRepository: NewGenericEntityRepository[types.MachineUUID](db),
	}
}
// ConnectionClover provides repositories.Connection on a Clover database;
// every operation comes from the embedded generic repository.
type ConnectionClover struct {
	repositories.GenericRepository[types.Connection]
}

// NewConnection wires a Clover-backed generic repository for types.Connection
// into a ConnectionClover.
func NewConnection(db *clover.DB) repositories.Connection {
	return &ConnectionClover{
		GenericRepository: NewGenericRepository[types.Connection](db),
	}
}
// ElasticTokenClover provides repositories.ElasticToken on a Clover database;
// every operation comes from the embedded generic repository.
type ElasticTokenClover struct {
	repositories.GenericRepository[types.ElasticToken]
}

// NewElasticToken wires a Clover-backed generic repository for
// types.ElasticToken into an ElasticTokenClover.
func NewElasticToken(db *clover.DB) repositories.ElasticToken {
	return &ElasticTokenClover{
		GenericRepository: NewGenericRepository[types.ElasticToken](db),
	}
}
package repositories_clover
import (
"github.com/ostafen/clover/v2"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/db/repositories"
)
// FreeResourcesRepositoryClover provides repositories.FreeResources on a
// Clover database; every operation comes from the embedded single-entity
// repository.
type FreeResourcesRepositoryClover struct {
	repositories.GenericEntityRepository[types.FreeResources]
}

// NewFreeResourcesRepository wires a Clover-backed single-entity repository
// for types.FreeResources into a FreeResourcesRepositoryClover.
func NewFreeResourcesRepository(db *clover.DB) repositories.FreeResources {
	entity := NewGenericEntityRepository[types.FreeResources](db)
	return &FreeResourcesRepositoryClover{GenericEntityRepository: entity}
}
// OnboardedResourcesRepositoryClover provides repositories.OnboardedResources
// on a Clover database; every operation comes from the embedded single-entity
// repository.
type OnboardedResourcesRepositoryClover struct {
	repositories.GenericEntityRepository[types.OnboardedResources]
}

// NewOnboardedResourcesRepository wires a Clover-backed single-entity
// repository for types.OnboardedResources into an
// OnboardedResourcesRepositoryClover.
func NewOnboardedResourcesRepository(db *clover.DB) repositories.OnboardedResources {
	entity := NewGenericEntityRepository[types.OnboardedResources](db)
	return &OnboardedResourcesRepositoryClover{GenericEntityRepository: entity}
}
// RequiredResourcesRepositoryClover provides repositories.RequiredResources on
// a Clover database; every operation comes from the embedded generic
// repository.
type RequiredResourcesRepositoryClover struct {
	repositories.GenericRepository[types.RequiredResources]
}

// NewRequiredResourcesRepository wires a Clover-backed generic repository for
// types.RequiredResources into a RequiredResourcesRepositoryClover.
func NewRequiredResourcesRepository(db *clover.DB) repositories.RequiredResources {
	generic := NewGenericRepository[types.RequiredResources](db)
	return &RequiredResourcesRepositoryClover{GenericRepository: generic}
}
package repositories_clover
import (
clover "github.com/ostafen/clover/v2"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// StorageVolumeClover provides repositories.StorageVolume on a Clover
// database; every operation comes from the embedded generic repository.
type StorageVolumeClover struct {
	repositories.GenericRepository[types.StorageVolume]
}

// NewStorageVolume wires a Clover-backed generic repository for
// types.StorageVolume into a StorageVolumeClover.
func NewStorageVolume(db *clover.DB) repositories.StorageVolume {
	generic := NewGenericRepository[types.StorageVolume](db)
	return &StorageVolumeClover{GenericRepository: generic}
}
package repositories_clover
import (
"encoding/json"
"errors"
"reflect"
"strings"
"github.com/ostafen/clover/v2"
clover_d "github.com/ostafen/clover/v2/document"
"gitlab.com/nunet/device-management-service/db/repositories"
)
// handleDBError maps Clover and parsing errors onto the repository-level
// error values that callers check against:
//   - nil                                     -> nil
//   - already-normalized repository errors    -> returned unchanged
//   - clover.ErrDocumentNotExist              -> repositories.NotFoundError
//   - clover.ErrDuplicateKey                  -> repositories.InvalidDataError
//   - repositories.ErrParsingModel            -> returned unchanged
//   - anything else                           -> joined with repositories.DatabaseError
//
// Matching uses errors.Is, so wrapped errors are classified too. Because
// normalized errors pass through unchanged, calling it twice on the same
// error is harmless.
func handleDBError(err error) error {
	switch {
	case err == nil:
		return nil
	case errors.Is(err, repositories.NotFoundError),
		errors.Is(err, repositories.InvalidDataError),
		errors.Is(err, repositories.DatabaseError):
		return err
	case errors.Is(err, clover.ErrDocumentNotExist):
		return repositories.NotFoundError
	case errors.Is(err, clover.ErrDuplicateKey):
		return repositories.InvalidDataError
	case errors.Is(err, repositories.ErrParsingModel):
		return err
	default:
		return errors.Join(repositories.DatabaseError, err)
	}
}
// toCloverDoc converts data into a Clover document by round-tripping it
// through JSON, so document keys follow the struct's json tags.
//
// NOTE(review): conversion is best-effort. If JSON marshalling of data, or
// decoding the result into a map, fails, an empty document is returned and
// the error is dropped. Callers cannot tell this apart from an empty model.
func toCloverDoc[T repositories.ModelType](data T) *clover_d.Document {
	fields := make(map[string]interface{})
	encoded, err := json.Marshal(data)
	if err == nil {
		err = json.Unmarshal(encoded, &fields)
	}
	if err != nil {
		return clover_d.NewDocument()
	}
	return clover_d.NewDocumentOf(fields)
}
// toModel decodes doc into a T.
//
// For regular (non-entity) repositories, the model's ID field is then set
// from the document's Clover object ID via repositories.UpdateField. Entity
// repositories skip that step because their models may not have an ID field
// at all.
func toModel[T repositories.ModelType](doc *clover_d.Document, isEntityRepo bool) (T, error) {
	var model T
	if err := doc.Unmarshal(&model); err != nil {
		return model, err
	}
	if isEntityRepo {
		return model, nil
	}
	return repositories.UpdateField(model, "ID", doc.ObjectId())
}
// fieldJSONTag maps a Go field name of T to the key the field is stored under
// in documents, i.e. the name given in its `json` struct tag.
//
// field is returned unchanged in these cases:
//   - T is not a struct.
//   - T has no such field, or the field has no json tag.
//   - The tag does not rename the field, e.g. `json:",omitempty"`
//     (encoding/json then uses the Go name) or `json:"-"` (never encoded).
func fieldJSONTag[T repositories.ModelType](field string) string {
	t := reflect.TypeOf(*new(T))
	// TypeOf yields nil for interface-typed T, and FieldByName panics on
	// non-struct kinds.
	if t == nil || t.Kind() != reflect.Struct {
		return field
	}
	structField, ok := t.FieldByName(field)
	if !ok {
		return field
	}
	tag, ok := structField.Tag.Lookup("json")
	if !ok || tag == "-" {
		return field
	}
	if name, _, _ := strings.Cut(tag, ","); name != "" {
		return name
	}
	return field
}
package repositories
import (
"context"
)
// QueryCondition is a single filter: Field compared against Value using
// Operator.
type QueryCondition struct {
	Field    string // struct or database field the condition applies to
	Operator string // comparison operator, e.g. "=", ">", "<", "IN", "LIKE"
	Value    any    // value the field is compared against
}

// ModelType is the constraint for repository model types; it accepts any type.
type ModelType any

// Query bundles everything needed to look up records of type T: an optional
// example instance, explicit conditions, and sorting/paging parameters.
type Query[T any] struct {
	Instance   T                // optional example whose non-empty fields act as equality filters
	Conditions []QueryCondition // explicit filters applied to the query
	SortBy     string           // field to sort by (backends treat a leading '-' as descending)
	Limit      int              // maximum number of results to return
	Offset     int              // number of results to skip before returning data
}

// GenericRepository is the CRUD and query contract shared by every
// multi-record repository backend.
type GenericRepository[T ModelType] interface {
	// Create stores data as a new record and returns the stored version.
	Create(ctx context.Context, data T) (T, error)
	// Get loads the record identified by id.
	Get(ctx context.Context, id any) (T, error)
	// Update changes the record identified by id using data and returns the result.
	Update(ctx context.Context, id any, data T) (T, error)
	// Delete removes the record identified by id.
	Delete(ctx context.Context, id any) error
	// Find returns the first record matching query.
	Find(ctx context.Context, query Query[T]) (T, error)
	// FindAll returns every record matching query.
	FindAll(ctx context.Context, query Query[T]) ([]T, error)
	// GetQuery returns an empty Query for T.
	GetQuery() Query[T]
}

// EQ builds a condition requiring field to equal value ("=").
func EQ(field string, value any) QueryCondition {
	return QueryCondition{
		Operator: "=",
		Field:    field,
		Value:    value,
	}
}

// GT builds a condition requiring field to be greater than value (">").
func GT(field string, value any) QueryCondition {
	return QueryCondition{
		Operator: ">",
		Field:    field,
		Value:    value,
	}
}

// GTE builds a condition requiring field to be greater than or equal to
// value (">=").
func GTE(field string, value any) QueryCondition {
	return QueryCondition{
		Operator: ">=",
		Field:    field,
		Value:    value,
	}
}

// LT builds a condition requiring field to be less than value ("<").
func LT(field string, value any) QueryCondition {
	return QueryCondition{
		Operator: "<",
		Field:    field,
		Value:    value,
	}
}

// LTE builds a condition requiring field to be less than or equal to
// value ("<=").
func LTE(field string, value any) QueryCondition {
	return QueryCondition{
		Operator: "<=",
		Field:    field,
		Value:    value,
	}
}

// IN builds a condition requiring field to equal one of values ("IN").
func IN(field string, values []any) QueryCondition {
	return QueryCondition{
		Operator: "IN",
		Field:    field,
		Value:    values,
	}
}

// LIKE builds a condition requiring field to match pattern ("LIKE").
func LIKE(field, pattern string) QueryCondition {
	return QueryCondition{
		Operator: "LIKE",
		Field:    field,
		Value:    pattern,
	}
}
package repositories_gorm
import (
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// DeploymentRequestFlatGORM provides repositories.DeploymentRequestFlat on a
// GORM database; every operation comes from the embedded generic repository.
type DeploymentRequestFlatGORM struct {
	repositories.GenericRepository[types.DeploymentRequestFlat]
}

// NewDeploymentRequestFlat wires a GORM-backed generic repository for
// types.DeploymentRequestFlat into a DeploymentRequestFlatGORM.
func NewDeploymentRequestFlat(db *gorm.DB) repositories.DeploymentRequestFlat {
	generic := NewGenericRepository[types.DeploymentRequestFlat](db)
	return &DeploymentRequestFlatGORM{GenericRepository: generic}
}
package repositories_gorm
import (
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// RequestTrackerGORM provides repositories.RequestTracker on a GORM database;
// every operation comes from the embedded generic repository.
type RequestTrackerGORM struct {
	repositories.GenericRepository[types.RequestTracker]
}

// NewRequestTracker wires a GORM-backed generic repository for
// types.RequestTracker into a RequestTrackerGORM.
func NewRequestTracker(db *gorm.DB) repositories.RequestTracker {
	generic := NewGenericRepository[types.RequestTracker](db)
	return &RequestTrackerGORM{GenericRepository: generic}
}
package repositories_gorm
import (
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// VirtualMachineGORM provides repositories.VirtualMachine on a GORM database;
// every operation comes from the embedded generic repository.
type VirtualMachineGORM struct {
	repositories.GenericRepository[types.VirtualMachine]
}

// NewVirtualMachine wires a GORM-backed generic repository for
// types.VirtualMachine into a VirtualMachineGORM.
func NewVirtualMachine(db *gorm.DB) repositories.VirtualMachine {
	generic := NewGenericRepository[types.VirtualMachine](db)
	return &VirtualMachineGORM{GenericRepository: generic}
}
package repositories_gorm
import (
"context"
"fmt"
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
)
// createdAtField is the field the entity repository orders by to find the
// most recently saved record.
const createdAtField = "CreatedAt"
// GenericEntityRepositoryGORM is a generic single-entity repository backed by
// GORM. Embed it in single-entity model repositories to get the
// Save/Get/Clear/History operations.
type GenericEntityRepositoryGORM[T repositories.ModelType] struct {
	db *gorm.DB // GORM database handle
}

// NewGenericEntityRepository returns a GORM-backed GenericEntityRepository
// for T that uses db for all operations.
func NewGenericEntityRepository[T repositories.ModelType](
	db *gorm.DB,
) repositories.GenericEntityRepository[T] {
	repo := new(GenericEntityRepositoryGORM[T])
	repo.db = db
	return repo
}
// GetQuery returns a zero-valued Query for T, to be filled in by the caller.
func (repo *GenericEntityRepositoryGORM[T]) GetQuery() repositories.Query[T] {
	var empty repositories.Query[T]
	return empty
}
// Save inserts data as a new row through GORM's Create and returns it as
// populated by GORM (e.g. a generated key, depending on the model). Earlier
// rows are kept; Get returns the newest by CreatedAt.
func (repo *GenericEntityRepositoryGORM[T]) Save(ctx context.Context, data T) (T, error) {
	result := repo.db.WithContext(ctx).Create(&data)
	return data, handleDBError(result.Error)
}
// Get returns the most recently created row of T, ordering by CreatedAt
// descending through applyConditions and taking the first row.
func (repo *GenericEntityRepositoryGORM[T]) Get(ctx context.Context) (T, error) {
	var latest T
	newestFirst := repo.GetQuery()
	newestFirst.SortBy = fmt.Sprintf("-%s", createdAtField)
	tx := applyConditions(repo.db.WithContext(ctx), newestFirst)
	err := tx.First(&latest).Error
	return latest, handleDBError(err)
}
// Clear deletes every stored row of T, i.e. the entity together with its
// history. Errors are normalized through handleDBError, like the other
// methods of this repository.
//
// The "id IS NOT NULL" condition matches every row; presumably it is there to
// satisfy GORM's protection against DELETE statements without conditions.
// NOTE(review): if T has a gorm.DeletedAt field this is a soft delete —
// confirm that is intended.
func (repo *GenericEntityRepositoryGORM[T]) Clear(ctx context.Context) error {
	err := repo.db.WithContext(ctx).Delete(new(T), "id IS NOT NULL").Error
	return handleDBError(err)
}
// History returns the stored rows of T that satisfy query (conditions,
// instance fields, sorting and paging are applied by applyConditions).
func (repo *GenericEntityRepositoryGORM[T]) History(
	ctx context.Context,
	query repositories.Query[T],
) ([]T, error) {
	var rows []T
	base := repo.db.WithContext(ctx).Model(new(T))
	err := applyConditions(base, query).Find(&rows).Error
	return rows, handleDBError(err)
}
package repositories_gorm
import (
"context"
"fmt"
"reflect"
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
)
// GenericRepositoryGORM is a generic repository backed by GORM. Embed it in
// model repositories to get the basic CRUD and query operations.
type GenericRepositoryGORM[T repositories.ModelType] struct {
	db *gorm.DB // GORM database handle
}

// NewGenericRepository returns a GORM-backed GenericRepository for T that
// uses db for all operations.
func NewGenericRepository[T repositories.ModelType](db *gorm.DB) repositories.GenericRepository[T] {
	repo := new(GenericRepositoryGORM[T])
	repo.db = db
	return repo
}
// GetQuery returns a zero-valued Query for T, to be filled in by the caller.
func (repo *GenericRepositoryGORM[T]) GetQuery() repositories.Query[T] {
	var empty repositories.Query[T]
	return empty
}
// Create inserts data as a new row through GORM's Create and returns it as
// populated by GORM (e.g. a generated primary key, depending on the model).
func (repo *GenericRepositoryGORM[T]) Create(ctx context.Context, data T) (T, error) {
	tx := repo.db.WithContext(ctx)
	err := tx.Create(&data).Error
	return data, handleDBError(err)
}
// Get returns the row of T whose id column equals id. Lookup errors (including
// a missing row) are normalized through handleDBError.
func (repo *GenericRepositoryGORM[T]) Get(ctx context.Context, id interface{}) (T, error) {
	var result T
	err := repo.db.WithContext(ctx).First(&result, "id = ?", id).Error
	// Both original branches returned exactly this; the duplicated if was dead weight.
	return result, handleDBError(err)
}
// Update writes data to the row whose id column equals id and returns the row
// as re-read from the database through Get.
//
// GORM's Updates with a struct argument skips zero-value fields. The caller's
// data may therefore differ from what is stored, so the stored row is
// returned instead of echoing the input (matching the Clover backend). If no
// row has that id, Get's not-found error is returned.
func (repo *GenericRepositoryGORM[T]) Update(ctx context.Context, id interface{}, data T) (T, error) {
	err := repo.db.WithContext(ctx).Model(new(T)).Where("id = ?", id).Updates(data).Error
	if err != nil {
		return data, handleDBError(err)
	}
	return repo.Get(ctx, id)
}
// Delete removes the row of T whose id column equals id. Errors are
// normalized through handleDBError, like the other methods of this
// repository.
func (repo *GenericRepositoryGORM[T]) Delete(ctx context.Context, id interface{}) error {
	err := repo.db.WithContext(ctx).Delete(new(T), "id = ?", id).Error
	return handleDBError(err)
}
// Find returns the first row of T that satisfies query, as translated by
// applyConditions.
func (repo *GenericRepositoryGORM[T]) Find(
	ctx context.Context,
	query repositories.Query[T],
) (T, error) {
	var match T
	base := repo.db.WithContext(ctx).Model(new(T))
	err := applyConditions(base, query).First(&match).Error
	return match, handleDBError(err)
}
// FindAll returns every record matching query (see applyConditions).
func (repo *GenericRepositoryGORM[T]) FindAll(ctx context.Context, query repositories.Query[T]) ([]T, error) {
	var matches []T
	scoped := applyConditions(repo.db.WithContext(ctx).Model(new(T)), query)
	err := scoped.Find(&matches).Error
	return matches, handleDBError(err)
}
// applyConditions translates a repositories.Query into GORM clauses on db and
// returns the resulting *gorm.DB.
//
// In order it adds:
//   - one WHERE clause per entry of query.Conditions ("<column> <operator> ?");
//   - one equality WHERE clause per non-zero exported field of query.Instance;
//   - an ORDER BY when query.SortBy is set (a leading '-' selects DESC);
//   - LIMIT when query.Limit > 0 and OFFSET when query.Offset > 0.
//
// Field names are mapped to column names with db's NamingStrategy, using the
// table name derived from T's type name.
//
// NOTE(review): condition.Field and condition.Operator are interpolated into
// the SQL text (only condition.Value is bound as a parameter), so they must
// come from trusted code, never from user input.
func applyConditions[T any](db *gorm.DB, query repositories.Query[T]) *gorm.DB {
	tableName := db.NamingStrategy.TableName(reflect.TypeOf(*new(T)).Name())

	for _, condition := range query.Conditions {
		columnName := db.NamingStrategy.ColumnName(tableName, condition.Field)
		db = db.Where(
			fmt.Sprintf("%s %s ?", columnName, condition.Operator),
			condition.Value,
		)
	}

	if !repositories.IsEmptyValue(query.Instance) {
		exampleType := reflect.TypeOf(query.Instance)
		exampleValue := reflect.ValueOf(query.Instance)
		for i := 0; i < exampleType.NumField(); i++ {
			field := exampleType.Field(i)
			// Value.Interface panics on unexported fields, and GORM does not
			// map unexported fields to columns anyway.
			if !field.IsExported() {
				continue
			}
			fieldValue := exampleValue.Field(i).Interface()
			if !repositories.IsEmptyValue(fieldValue) {
				columnName := db.NamingStrategy.ColumnName(tableName, field.Name)
				db = db.Where(fmt.Sprintf("%s = ?", columnName), fieldValue)
			}
		}
	}

	if sortField := query.SortBy; sortField != "" {
		dir := "ASC"
		if sortField[0] == '-' {
			sortField = sortField[1:]
			dir = "DESC"
		}
		columnName := db.NamingStrategy.ColumnName(tableName, sortField)
		db = db.Order(fmt.Sprintf("%s.%s %s", tableName, columnName, dir))
	}

	if query.Limit > 0 {
		db = db.Limit(query.Limit)
	}
	// Offset (not Limit): skip rows instead of overwriting the limit.
	if query.Offset > 0 {
		db = db.Offset(query.Offset)
	}
	return db
}
package repositories_gorm
import (
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// PeerInfoGORM is the GORM-backed repositories.PeerInfo; all of its behaviour
// comes from the embedded generic repository.
type PeerInfoGORM struct {
	repositories.GenericRepository[types.PeerInfo]
}

// NewPeerInfo returns a PeerInfo repository that stores entities in db.
func NewPeerInfo(db *gorm.DB) repositories.PeerInfo {
	return &PeerInfoGORM{GenericRepository: NewGenericRepository[types.PeerInfo](db)}
}
// MachineGORM is the GORM-backed repositories.Machine; all of its behaviour
// comes from the embedded generic repository.
type MachineGORM struct {
	repositories.GenericRepository[types.Machine]
}

// NewMachine returns a Machine repository that stores entities in db.
func NewMachine(db *gorm.DB) repositories.Machine {
	return &MachineGORM{GenericRepository: NewGenericRepository[types.Machine](db)}
}
// AvailableResourcesGORM is the GORM-backed repositories.AvailableResources,
// built on the single-entity generic repository.
type AvailableResourcesGORM struct {
	repositories.GenericEntityRepository[types.AvailableResources]
}

// NewAvailableResources returns an AvailableResources repository backed by db.
func NewAvailableResources(db *gorm.DB) repositories.AvailableResources {
	inner := NewGenericEntityRepository[types.AvailableResources](db)
	return &AvailableResourcesGORM{GenericEntityRepository: inner}
}
// ServicesGORM is the GORM-backed repositories.Services; all of its behaviour
// comes from the embedded generic repository.
type ServicesGORM struct {
	repositories.GenericRepository[types.Services]
}

// NewServices returns a Services repository that stores entities in db.
func NewServices(db *gorm.DB) repositories.Services {
	return &ServicesGORM{GenericRepository: NewGenericRepository[types.Services](db)}
}
// ServiceResourceRequirementsGORM is the GORM-backed
// repositories.ServiceResourceRequirements; all of its behaviour comes from
// the embedded generic repository.
type ServiceResourceRequirementsGORM struct {
	repositories.GenericRepository[types.ServiceResourceRequirements]
}

// NewServiceResourceRequirements returns a ServiceResourceRequirements
// repository that stores entities in db.
func NewServiceResourceRequirements(db *gorm.DB) repositories.ServiceResourceRequirements {
	inner := NewGenericRepository[types.ServiceResourceRequirements](db)
	return &ServiceResourceRequirementsGORM{GenericRepository: inner}
}
// Libp2pInfoGORM is the GORM-backed repositories.Libp2pInfo, built on the
// single-entity generic repository.
type Libp2pInfoGORM struct {
	repositories.GenericEntityRepository[types.Libp2pInfo]
}

// NewLibp2pInfo returns a Libp2pInfo repository backed by db.
func NewLibp2pInfo(db *gorm.DB) repositories.Libp2pInfo {
	inner := NewGenericEntityRepository[types.Libp2pInfo](db)
	return &Libp2pInfoGORM{GenericEntityRepository: inner}
}
// MachineUUIDGORM is the GORM-backed repositories.MachineUUID, built on the
// single-entity generic repository.
type MachineUUIDGORM struct {
	repositories.GenericEntityRepository[types.MachineUUID]
}

// NewMachineUUID returns a MachineUUID repository backed by db.
func NewMachineUUID(db *gorm.DB) repositories.MachineUUID {
	inner := NewGenericEntityRepository[types.MachineUUID](db)
	return &MachineUUIDGORM{GenericEntityRepository: inner}
}
// ConnectionGORM is the GORM-backed repositories.Connection; all of its
// behaviour comes from the embedded generic repository.
type ConnectionGORM struct {
	repositories.GenericRepository[types.Connection]
}

// NewConnection returns a Connection repository that stores entities in db.
func NewConnection(db *gorm.DB) repositories.Connection {
	return &ConnectionGORM{GenericRepository: NewGenericRepository[types.Connection](db)}
}
// ElasticTokenGORM is the GORM-backed repositories.ElasticToken; all of its
// behaviour comes from the embedded generic repository.
type ElasticTokenGORM struct {
	repositories.GenericRepository[types.ElasticToken]
}

// NewElasticToken returns an ElasticToken repository that stores entities in db.
func NewElasticToken(db *gorm.DB) repositories.ElasticToken {
	return &ElasticTokenGORM{GenericRepository: NewGenericRepository[types.ElasticToken](db)}
}
package repositories_gorm
import (
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// OnboardingParamsGORM is the GORM-backed repositories.OnboardingParams. It
// stores a single types.OnboardingConfig entity via the generic entity
// repository.
type OnboardingParamsGORM struct {
	repositories.GenericEntityRepository[types.OnboardingConfig]
}

// NewOnboardingParams returns an OnboardingParams repository backed by db.
func NewOnboardingParams(db *gorm.DB) repositories.OnboardingParams {
	inner := NewGenericEntityRepository[types.OnboardingConfig](db)
	return &OnboardingParamsGORM{GenericEntityRepository: inner}
}
package repositories_gorm
import (
"gitlab.com/nunet/device-management-service/types"
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
)
// FreeResourcesGORM is the GORM-backed repositories.FreeResources, built on
// the single-entity generic repository.
type FreeResourcesGORM struct {
	repositories.GenericEntityRepository[types.FreeResources]
}

// NewFreeResources returns a FreeResourcesGORM repository backed by db.
func NewFreeResources(db *gorm.DB) repositories.FreeResources {
	inner := NewGenericEntityRepository[types.FreeResources](db)
	return &FreeResourcesGORM{GenericEntityRepository: inner}
}
// OnboardedResourcesRepositoryGORM is the GORM-backed
// repositories.OnboardedResources, built on the single-entity generic
// repository.
type OnboardedResourcesRepositoryGORM struct {
	repositories.GenericEntityRepository[types.OnboardedResources]
}

// NewOnboardedResources returns an OnboardedResources repository backed by db.
func NewOnboardedResources(db *gorm.DB) repositories.OnboardedResources {
	inner := NewGenericEntityRepository[types.OnboardedResources](db)
	return &OnboardedResourcesRepositoryGORM{GenericEntityRepository: inner}
}
// RequiredResourcesRepositoryGORM is the GORM-backed
// repositories.RequiredResources; all of its behaviour comes from the
// embedded generic repository.
type RequiredResourcesRepositoryGORM struct {
	repositories.GenericRepository[types.RequiredResources]
}

// NewRequiredResources returns a RequiredResources repository backed by db.
func NewRequiredResources(db *gorm.DB) repositories.RequiredResources {
	inner := NewGenericRepository[types.RequiredResources](db)
	return &RequiredResourcesRepositoryGORM{GenericRepository: inner}
}
package repositories_gorm
import (
"errors"
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
)
const (
	// structFieldNameDeletedAt is the struct field name "DeletedAt".
	// NOTE(review): not referenced in this chunk; presumably used elsewhere
	// for soft-delete handling — confirm.
	structFieldNameDeletedAt = "DeletedAt"
)
// handleDBError translates a GORM error into the repository package's errors.
//
//   - nil                                        -> nil
//   - gorm.ErrRecordNotFound                     -> repositories.NotFoundError
//   - gorm.ErrInvalidData/InvalidField/InvalidValue -> repositories.InvalidDataError
//   - repositories.ErrParsingModel               -> err unchanged
//   - anything else                              -> errors.Join(repositories.DatabaseError, err)
//
// Matching uses errors.Is so that wrapped errors are classified correctly;
// the previous == comparison sent wrapped "not found" errors down the generic
// DatabaseError path.
func handleDBError(err error) error {
	switch {
	case err == nil:
		return nil
	case errors.Is(err, gorm.ErrRecordNotFound):
		return repositories.NotFoundError
	case errors.Is(err, gorm.ErrInvalidData),
		errors.Is(err, gorm.ErrInvalidField),
		errors.Is(err, gorm.ErrInvalidValue):
		return repositories.InvalidDataError
	case errors.Is(err, repositories.ErrParsingModel):
		return err
	default:
		return errors.Join(repositories.DatabaseError, err)
	}
}
package repositories
import (
"fmt"
"reflect"
)
// UpdateField sets the exported field fieldName of input to newValue using
// reflection and returns input.
//
// input may be a struct or a pointer to a struct. For a pointer the pointed-to
// struct is modified in place; for a value a modified copy is returned.
// newValue is converted to the field's type when Go's conversion rules allow
// it. A nil newValue sets the field to its zero value when the field's kind
// can hold nil (pointer, interface, map, slice, chan, func).
//
// Errors are returned when input is not a struct (or a nil pointer), the field
// does not exist or is not settable (e.g. unexported), or newValue cannot be
// converted to the field's type.
func UpdateField[T interface{}](input T, fieldName string, newValue interface{}) (T, error) {
	val := reflect.ValueOf(input)
	if val.Kind() == reflect.Ptr {
		// Modify the pointed-to struct in place.
		val = val.Elem()
	} else {
		// Work on the local copy through its address so it is settable.
		val = reflect.ValueOf(&input).Elem()
	}

	if val.Kind() != reflect.Struct {
		return input, fmt.Errorf("Not a struct: %T", input)
	}

	field := val.FieldByName(fieldName)
	if !field.IsValid() {
		return input, fmt.Errorf("Field not found: %v", fieldName)
	}
	if !field.CanSet() {
		return input, fmt.Errorf("Field not settable: %v", fieldName)
	}

	var converted reflect.Value
	if newValue == nil {
		// reflect.TypeOf(nil) is nil, so the conversion path below would panic.
		switch field.Kind() {
		case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice:
			converted = reflect.Zero(field.Type())
		default:
			return input, fmt.Errorf("Incompatible conversion: nil -> %v", field.Type())
		}
	} else {
		src := reflect.ValueOf(newValue)
		if !src.Type().ConvertibleTo(field.Type()) {
			return input, fmt.Errorf(
				"Incompatible conversion: %v -> %v; value: %v",
				src.Type(), field.Type(), newValue,
			)
		}
		converted = src.Convert(field.Type())
	}

	field.Set(converted)
	return input, nil
}
// IsEmptyValue reports whether value is "empty": nil, a nil pointer, a
// pointer to a zero value, or itself a zero value (as defined by
// reflect.Value.IsZero). Only one level of pointer is dereferenced.
func IsEmptyValue(value interface{}) bool {
	if value == nil {
		return true
	}
	val := reflect.ValueOf(value)
	if val.Kind() == reflect.Ptr {
		// A typed nil pointer has no element; IsZero on the invalid Value
		// returned by Elem would panic.
		if val.IsNil() {
			return true
		}
		val = val.Elem()
	}
	return val.IsZero()
}
package dms
import (
"context"
"errors"
"fmt"
"github.com/google/uuid"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/network"
)
// ActorInterface lists the operations an actor exposes.
//
// NOTE(review): (*Actor).SendMessage takes a context.Context and returns an
// error, so *Actor does not currently satisfy this interface — confirm which
// signature is intended.
type ActorInterface interface {
	// CreateActor spawns a new actor and returns its address.
	CreateActor() (*ActorAddrInfo, error)
	// Address reports where this actor can be reached.
	Address() *ActorAddrInfo
	// SendMessage delivers m to the actor at destination.
	SendMessage(destination *ActorAddrInfo, m *Message)
}
// ActorAddrInfo encapsulates the data required to address an actor.
type ActorAddrInfo struct {
	// HostID identifies the host running the actor; SendMessage resolves it
	// to network addresses.
	HostID string
	// InboxAddress identifies the actor on that host.
	InboxAddress string
}

// Valid reports whether both HostID and InboxAddress are set. A nil receiver
// is reported as invalid instead of panicking, so callers can validate
// possibly-nil addresses directly.
func (a *ActorAddrInfo) Valid() bool {
	return a != nil && a.HostID != "" && a.InboxAddress != ""
}
// Message is the unit of data passed between actors.
type Message struct {
	// msgType names the kind of message; Hello sets it to "hello".
	// SendMessage does not currently include it in the network envelope.
	msgType string
	// sender identifies the originating actor; the inbound handler registered
	// in Start currently fills in the literal placeholder "sender".
	sender string
	// data is the raw payload; SendMessage forwards it as the envelope's Data.
	data []byte
}
// ActorFactory creates actors; every actor it creates shares the factory's
// host ID, network and actor registry.
type ActorFactory struct {
	// hostID is copied into every actor the factory creates.
	hostID string
	// network is handed to every created actor for sending/receiving messages.
	network network.Network
	// actorRegistry is where created actors (and their parent links) are recorded.
	actorRegistry *ActorRegistry
}
// Actor is a single actor instance: it owns an inbox address on a host and
// receives messages through its buffered messages channel.
type Actor struct {
	// hostID is the host the actor runs on; it is reported by Address.
	hostID string
	// address is the actor's inbox address (a UUID string generated in newActor).
	address string
	// network is used to resolve hosts, send messages and register the inbox handler.
	network network.Network
	// messages queues inbound messages; newActor creates it with a buffer of 100
	// and the handler registered in Start writes to it.
	messages chan Message
	// actorRegistry records this actor and its parent/child relationships.
	actorRegistry *ActorRegistry
	// factory is used by CreateActor to spawn child actors.
	factory *ActorFactory
}
// NewActorFactory returns a factory that creates actors on hostID, wired to
// network and recorded in actorRegistry.
func NewActorFactory(hostID string, network network.Network, actorRegistry *ActorRegistry) *ActorFactory {
	f := new(ActorFactory)
	f.hostID = hostID
	f.network = network
	f.actorRegistry = actorRegistry
	return f
}
// NewActor creates a root actor (one without a parent) using the factory's
// dependencies. The actor is not started.
func (f *ActorFactory) NewActor() (*Actor, error) {
	var noParent *ActorAddrInfo
	return f.newActor(noParent)
}
// newActor creates an actor with the factory's dependencies, recording parent
// as its parent when non-nil.
func (f *ActorFactory) newActor(parent *ActorAddrInfo) (*Actor, error) {
	actor, err := newActor(parent, f.hostID, f.network, f.actorRegistry, f)
	return actor, err
}
// newActor builds an actor with a fresh UUID inbox address and a message
// buffer of 100, records it in actorRegistry and, when parentActorAddress is
// non-nil, links it to that parent in both directions.
//
// It returns an error if hostID is empty, net is nil, actorRegistry is nil
// (previously a nil registry caused a panic), or UUID generation fails.
// The returned actor is not started.
//
// NOTE(review): uuid.NewUUID produces a time/MAC-based (version 1) UUID that is
// later exposed as a network topic; consider whether a random UUID is preferable.
func newActor(parentActorAddress *ActorAddrInfo, hostID string, net network.Network, actorRegistry *ActorRegistry, factory *ActorFactory) (*Actor, error) {
	if hostID == "" {
		return nil, errors.New("host id is empty")
	}
	if net == nil {
		return nil, errors.New("network is nil")
	}
	if actorRegistry == nil {
		return nil, errors.New("actor registry is nil")
	}

	id, err := uuid.NewUUID()
	if err != nil {
		return nil, fmt.Errorf("failed to generate uuid: %w", err)
	}

	createdActor := &Actor{
		hostID:        hostID,
		network:       net,
		address:       id.String(),
		actorRegistry: actorRegistry,
		factory:       factory,
		messages:      make(chan Message, 100),
	}

	actorRegistry.AddActorAddress(createdActor.Address())
	if parentActorAddress != nil {
		actorRegistry.SetParentAddress(createdActor.address, parentActorAddress)
		actorRegistry.AddChild(parentActorAddress.InboxAddress, createdActor.Address())
	}
	return createdActor, nil
}
// Address returns a freshly allocated ActorAddrInfo describing this actor's
// host and inbox address.
func (a *Actor) Address() *ActorAddrInfo {
	addr := new(ActorAddrInfo)
	addr.HostID = a.hostID
	addr.InboxAddress = a.address
	return addr
}
// SendMessage delivers m's payload to the actor at destination.
//
// The destination host is resolved to network addresses and the message is
// sent to the first one, on the topic "actor/<inbox>/messages/0.0.1".
// It returns an error when destination is nil/invalid, m is nil, resolution
// fails or yields no addresses, or the send fails. Underlying errors are
// wrapped so callers can inspect them with errors.Is/As.
//
// NOTE(review): only m.data is transmitted; m.msgType and m.sender are not part
// of the envelope.
func (a *Actor) SendMessage(ctx context.Context, destination *ActorAddrInfo, m *Message) error {
	if destination == nil || !destination.Valid() {
		return errors.New("destination actor addr info is invalid")
	}
	if m == nil {
		return errors.New("message is invalid")
	}

	// get the multiaddress of a host by resolving the hostid
	addresses, err := a.network.ResolveAddress(ctx, destination.HostID)
	if err != nil {
		return fmt.Errorf("failed to send message to actor %s: %w", destination.HostID, err)
	}
	if len(addresses) == 0 {
		// Guard against indexing an empty result below.
		return fmt.Errorf("failed to send message to actor %s: no addresses resolved", destination.HostID)
	}

	err = a.network.SendMessage(ctx, []string{addresses[0]}, types.MessageEnvelope{
		Type: types.MessageType(fmt.Sprintf("actor/%s/messages/0.0.1", destination.InboxAddress)),
		Data: m.data,
	})
	if err != nil {
		return fmt.Errorf("failed to send message to remote actor %s: %w", destination.InboxAddress, err)
	}
	return nil
}
// Start registers a handler for the topic "actor/<address>/messages/0.0.1"
// that pushes each received payload onto the actor's messages channel.
//
// NOTE(review): the send into a.messages blocks once its buffer is full, which
// stalls the network handler until ProcessMessages drains it.
func (a *Actor) Start() error {
	topic := fmt.Sprintf("actor/%s/messages/0.0.1", a.address)
	onMessage := func(data []byte) {
		a.messages <- Message{sender: "sender", data: data}
	}
	if err := a.network.HandleMessage(topic, onMessage); err != nil {
		return fmt.Errorf("failed to start actor %s: %w", a.address, err)
	}
	return nil
}
// CreateActor spawns a child of this actor through its factory, starts it,
// and returns the child's address.
func (a *Actor) CreateActor() (*ActorAddrInfo, error) {
	child, err := a.factory.newActor(a.Address())
	if err != nil {
		return nil, fmt.Errorf("failed to create new actor: %w", err)
	}
	if err = child.Start(); err != nil {
		return nil, fmt.Errorf("failed to start new actor: %w", err)
	}
	return child.Address(), nil
}
// ProcessMessages drains the actor's inbox, printing the sender of each
// message. It only returns once a.messages is closed; nothing visible here
// closes it, so callers presumably run it in its own goroutine.
func (a *Actor) ProcessMessages() {
	for msg := range a.messages {
		// Terminate the line so output for consecutive messages does not run together.
		fmt.Printf("received message from %s\n", msg.sender)
		// switch msg.msgType {
		// case "hello":
		// 	{
		// 		a.handleHello(msg)
		// 	}
		// default:
		// 	fmt.Printf("unhandled message type: %s\n", msg.msgType)
		// }
	}
}
// Hello marks m as a "hello" message (mutating the caller's message) and sends
// it to destination. Because the method has no error result, a nil message or
// a failed send is logged instead of being silently dropped.
func (a *Actor) Hello(ctx context.Context, destination *ActorAddrInfo, m *Message) {
	if m == nil {
		zlog.Sugar().Errorf("hello: message is nil")
		return
	}
	m.msgType = "hello"
	if err := a.SendMessage(ctx, destination, m); err != nil {
		zlog.Sugar().Errorf("hello: %v", err)
	}
}
// handleHello prints a received "hello" message. It is referenced only from
// the commented-out dispatch in ProcessMessages.
func (a *Actor) handleHello(msg Message) {
	fmt.Println("handled hello message", msg)
}
package dms
// ActorRegistry records known actors by inbox address together with their
// parent and child links.
//
// NOTE(review): the map is not guarded by a mutex, so the registry is not safe
// for concurrent use.
type ActorRegistry struct {
	actors map[string]*actorInfo
}

// actorInfo is the registry entry for one actor.
type actorInfo struct {
	addrInfo *ActorAddrInfo
	parent   *ActorAddrInfo
	childs   []*ActorAddrInfo
}

// NewActorRegistry creates an empty actor registry.
func NewActorRegistry() *ActorRegistry {
	return &ActorRegistry{
		actors: make(map[string]*actorInfo),
	}
}

// AddActorAddress registers a under its InboxAddress. Existing entries are
// left untouched and a nil address is ignored. A zero-value registry is
// initialised on first use instead of panicking on its nil map.
func (r *ActorRegistry) AddActorAddress(a *ActorAddrInfo) {
	if a == nil {
		return
	}
	if r.actors == nil {
		r.actors = make(map[string]*actorInfo)
	}
	if _, exists := r.actors[a.InboxAddress]; exists {
		return
	}
	r.actors[a.InboxAddress] = &actorInfo{
		addrInfo: a,
		parent:   nil,
		childs:   make([]*ActorAddrInfo, 0),
	}
}

// SetParentAddress records parent as the parent of actorID; unknown actors are ignored.
func (r *ActorRegistry) SetParentAddress(actorID string, parent *ActorAddrInfo) {
	if actor, ok := r.actors[actorID]; ok {
		// Entries are pointers, so updating in place is sufficient.
		actor.parent = parent
	}
}

// AddChild appends child to actorID's children; unknown actors are ignored.
func (r *ActorRegistry) AddChild(actorID string, child *ActorAddrInfo) {
	if actor, ok := r.actors[actorID]; ok {
		actor.childs = append(actor.childs, child)
	}
}

// GetActorAddress returns the address registered under address and whether it
// exists. For unknown addresses it returns (nil, false) instead of
// dereferencing a missing entry.
func (r *ActorRegistry) GetActorAddress(address string) (*ActorAddrInfo, bool) {
	a, ok := r.actors[address]
	if !ok {
		return nil, false
	}
	return a.addrInfo, true
}
package dms
import (
// "context"
"context"
"fmt"
"log"
"time"
"gitlab.com/nunet/device-management-service/api"
gdb "gitlab.com/nunet/device-management-service/db/repositories/gorm"
"gitlab.com/nunet/device-management-service/dms/onboarding"
"gitlab.com/nunet/device-management-service/internal"
"gitlab.com/nunet/device-management-service/internal/config"
"gitlab.com/nunet/device-management-service/telemetry/logger"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
// "gitlab.com/nunet/device-management-service/internal/messaging"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/network/libp2p"
"gitlab.com/nunet/device-management-service/utils"
"github.com/libp2p/go-libp2p/core/crypto"
"github.com/spf13/afero"
)
// NewP2P is a stub returning a zero-value libp2p.Libp2p. A real
// implementation is needed so routers can use it from their handlers.
func NewP2P() libp2p.Libp2p {
	var p2p libp2p.Libp2p
	return p2p
}
// Run boots the DMS.
//
// It loads the configuration and opens the SQLite database
// "<General.WorkDir>/nunet.db". When the machine is already onboarded, it also
// loads and validates the persisted onboarding and libp2p state. It then starts
// the REST server in the background and blocks until a value arrives on
// internal.ShutdownChan. Unrecoverable start-up errors terminate the process.
//
// QUESTION(dms-initialization): should the db instance be constructed here?
func Run() {
	ctx := context.Background()
	log.Println("WARNING: Most parts commented out in dms.Run()")
	config.LoadConfig()

	// XXX: wait for server to start properly before sending requests below
	// TODO: should be removed
	// NOTE(review): the REST server is only started further down, so this sleep
	// does not wait for it; kept to avoid changing start-up timing.
	time.Sleep(time.Second * 5)

	dbPath := fmt.Sprintf("%s/nunet.db", config.GetConfig().General.WorkDir)
	db, err := gorm.Open(sqlite.Open(dbPath), &gorm.Config{})
	if err != nil {
		log.Fatalf("Failed to connect to database: %v", err)
	}

	onboardR := gdb.NewOnboardingParams(db)
	p2pR := gdb.NewLibp2pInfo(db)
	uuidR := gdb.NewMachineUUID(db)
	avResR := gdb.NewAvailableResources(db)

	onboard := onboarding.New(onboarding.OnboardingConfig{
		Fs:             afero.Afero{Fs: afero.NewOsFs()},
		P2PRepo:        p2pR,
		UUIDRepo:       uuidR,
		AvResourceRepo: avResR,
		// NOTE(review): WorkDir reads config.GetConfig().WorkDir while dbPath
		// uses .General.WorkDir — confirm both resolve to the same setting.
		WorkDir:      config.GetConfig().WorkDir,
		DatabasePath: dbPath,
		Channels:     []string{"nunet", "nunet-test", "nunet-team", "nunet-edge"},
	})

	// check if onboarded
	onboarded, err := onboard.IsOnboarded(ctx)
	if err != nil {
		zlog.Sugar().Errorf("unable to determine onboarding status: %v", err)
	}
	if onboarded {
		// The onboarding config is only needed (and presumably only present)
		// once the machine is onboarded. Loading it unconditionally made Run
		// exit on a machine that had not been onboarded yet.
		oConf, err := onboardR.Get(ctx)
		if err != nil {
			log.Fatalf("Failed to get onboarding config: %v", err)
		}
		ValidateOnboarding(&oConf)

		p2pParams, err := p2pR.Get(ctx)
		if err != nil {
			log.Fatalf("Failed to get libp2p info: %v", err)
		}
		if _, err := crypto.UnmarshalPrivateKey(p2pParams.PrivateKey); err != nil {
			zlog.Sugar().Fatalf("unable to unmarshal private key: %v", err)
		}
	}

	// initialize rest api server
	restConfig := api.RESTServerConfig{
		P2p:        nil,
		Onboarding: nil,
		Logger:     logger.New("rest-server"),
		MidW:       nil,
		Port:       config.GetConfig().Rest.Port,
		Addr:       config.GetConfig().Rest.Addr,
	}
	rServer := api.NewRESTServer(restConfig)
	rServer.InitializeRoutes()
	go rServer.Run()

	// wait for SIGINT or SIGTERM
	sig := <-internal.ShutdownChan
	fmt.Printf("Shutting down after receiving %v...\n", sig)

	// add cleanup code here
	fmt.Println("Cleaning up before shutting down")
}
// ValidateOnboarding checks the persisted onboarding configuration and logs
// any problem it finds. A nil config is logged rather than dereferenced.
//
// Check 1 validates oConf.PublicKey with utils.ValidateAddress; the log text
// refers to it as the payment address.
func ValidateOnboarding(oConf *types.OnboardingConfig) {
	if oConf == nil {
		zlog.Sugar().Error("onboarding config is nil; skipping validation")
		return
	}
	// Check 1: Check if payment address is valid
	if err := utils.ValidateAddress(oConf.PublicKey); err != nil {
		zlog.Sugar().Errorf("the payment address %s is not valid", oConf.PublicKey)
		// NOTE(review): despite this message, execution continues — the function
		// only returns. Confirm whether the DMS is meant to terminate here.
		zlog.Sugar().Error("exiting DMS")
		return
	}
}
package dms
import (
"gitlab.com/nunet/device-management-service/telemetry/logger"
)
// zlog is the package-level logger for the dms package.
//
// It is initialised at declaration rather than in an init function. Package
// variables are initialised before any init function runs, so zlog is usable
// from every init in the package regardless of file order.
var zlog *logger.Logger = logger.New("dms")
package parser
import (
"gitlab.com/nunet/device-management-service/dms/jobs"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/nunet"
)
// registry holds the job-spec parsers known to this package, keyed by SpecType.
var registry Registry[jobs.JobSpec]

// init builds the parser registry and registers the built-in parsers.
func init() {
	impl := &RegistryImpl[jobs.JobSpec]{
		parsers: make(map[SpecType]Parser[jobs.JobSpec]),
	}

	// NuNet specs are handled by the NuNet transformer/validator pair.
	impl.RegisterParser(specTypeNuNet, NewParser[jobs.JobSpec](
		nunet.NewNuNetTransformer(),
		nunet.NewNuNetValidator(),
	))

	// Register other parsers here.

	registry = impl
}
package nunet
import (
"fmt"
"strconv"
"strings"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/transform"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/tree"
)
// NewNuNetTransformer creates the transformer for NuNet job specs.
//
// Transformers run in three stages, each a full pass over the configuration:
//  1. job, volume and network maps are turned into lists (map keys become the
//     "name" field);
//  2. individual volume, network and library entries are normalised;
//  3. execution blocks and remote volume definitions are reshaped into
//     {type, params} form.
func NewNuNetTransformer() transform.Transformer {
	return transform.NewTransformer(
		[]map[tree.Path]transform.TransformerFunc{
			{
				"jobs":              TransformJobs,
				"jobs.**.children": TransformJobs,
				"jobs.**.volumes":  TransformVolumes,
				"jobs.**.networks": TransformNetworks,
			},
			{
				"jobs.**.volumes.[]":   TransformVolume,
				"jobs.**.networks.[]":  TransformNetwork,
				"jobs.**.libraries.[]": TransformLibrary,
			},
			{
				"jobs.**.execution":         TransformExecution,
				"jobs.**.volumes.[].remote": TransformVolumeRemote,
			},
		},
	)
}
// TransformJobs converts a map of jobs into a list, storing each map key in
// the entry's "name" field. A nil value passes through as nil; any non-map
// value is rejected.
func TransformJobs(root *map[string]interface{}, data any, path tree.Path) (any, error) {
	switch jobs := data.(type) {
	case nil:
		return nil, nil
	case map[string]any:
		return transform.MapToSlice(jobs)
	default:
		return nil, fmt.Errorf("invalid jobs configuration: %v", data)
	}
}
// TransformVolumes converts a map of volumes into a list, storing each map key
// in the entry's "name" field. A nil value passes through as nil; any non-map
// value is rejected.
func TransformVolumes(root *map[string]interface{}, data any, path tree.Path) (any, error) {
	switch volumes := data.(type) {
	case nil:
		return nil, nil
	case map[string]any:
		return transform.MapToSlice(volumes)
	default:
		return nil, fmt.Errorf("invalid volumes configuration: %v", data)
	}
}
// TransformNetworks converts a map of networks into a list, storing each map
// key in the entry's "name" field. A nil value passes through as nil; any
// non-map value is rejected.
func TransformNetworks(root *map[string]interface{}, data any, path tree.Path) (any, error) {
	switch networks := data.(type) {
	case nil:
		return nil, nil
	case map[string]any:
		return transform.MapToSlice(networks)
	default:
		return nil, fmt.Errorf("invalid networks configuration: %v", data)
	}
}
// TransformExecution reshapes a flat engine map into
// {"type": <engine type>, "params": {<every other key>}}.
// A nil value passes through as nil; any non-map value is rejected.
func TransformExecution(root *map[string]interface{}, data any, path tree.Path) (any, error) {
	if data == nil {
		return nil, nil
	}
	engine, isMap := data.(map[string]any)
	if !isMap {
		return nil, fmt.Errorf("invalid engine configuration: %v", data)
	}

	params := make(map[string]any, len(engine))
	for key, value := range engine {
		if key == "type" {
			continue
		}
		params[key] = value
	}
	return map[string]any{
		"type":   engine["type"],
		"params": params,
	}, nil
}
// TransformVolume normalises a volume entry and applies inheritance.
//
// The entry may be a "name:mountpoint" string or a map. For every ancestor job
// (each path prefix that precedes a "children" segment), any volume in that
// ancestor's "volumes" list with the same "name" is merged into the entry.
//
// NOTE(review): keys from the ancestor's definition overwrite the child's
// (including "mountpoint"); confirm this precedence is intended.
func TransformVolume(root *map[string]interface{}, data any, path tree.Path) (any, error) {
	var config map[string]any

	// If the data is a string, split it into name and mountpoint.
	switch v := data.(type) {
	case string:
		mapping := strings.Split(v, ":")
		if len(mapping) != 2 {
			return nil, fmt.Errorf("invalid volume configuration: %v", data)
		}
		config = map[string]any{
			"name":       mapping[0],
			"mountpoint": mapping[1],
		}
	case map[string]any:
		config = v
	default:
		return nil, fmt.Errorf("invalid volume configuration: %v", data)
	}

	// Collect all potential parent paths where the volume could be defined.
	parentPaths := []tree.Path{}
	pathParts := path.Parts()
	for i, part := range pathParts {
		if part == "children" {
			parentPaths = append(parentPaths, tree.NewPath(pathParts[:i]...))
		}
	}

	// Merge the volume configuration with the parent configurations.
	for _, parent := range parentPaths {
		c, err := transform.GetConfigAtPath(*root, parent.Next("volumes"))
		if err != nil {
			// Presumably the ancestor has no volumes section, so there is nothing
			// to inherit. This is skipped quietly rather than printed to stdout,
			// which leaked debug output from the parser.
			continue
		}
		volumes, _ := transform.ToAnySlice(c)
		for _, v := range volumes {
			if volume, ok := v.(map[string]any); ok && volume["name"] == config["name"] {
				for k, val := range volume {
					config[k] = val
				}
			}
		}
	}
	return config, nil
}
// TransformVolumeRemote reshapes a remote volume definition into
// {"type": ..., "params": {...}}.
//
// An explicit "params" map is used as is. Otherwise every key other than
// "type" is moved into "params". A nil value passes through as nil. A non-map
// definition, or a "params" value that is not a map, is reported as an error;
// the latter previously panicked.
func TransformVolumeRemote(root *map[string]interface{}, data any, path tree.Path) (any, error) {
	if data == nil {
		return nil, nil
	}
	config, ok := data.(map[string]any)
	if !ok {
		return nil, fmt.Errorf("invalid volume configuration: %v", data)
	}

	remoteConfig := map[string]any{"type": config["type"]}
	if rawParams, present := config["params"]; present {
		params, isMap := rawParams.(map[string]any)
		if !isMap {
			return nil, fmt.Errorf("invalid volume params: %v", rawParams)
		}
		remoteConfig["params"] = params
		return remoteConfig, nil
	}

	params := map[string]any{}
	for k, v := range config {
		if k != "type" {
			params[k] = v
		}
	}
	remoteConfig["params"] = params
	return remoteConfig, nil
}
// TransformNetwork normalises a network configuration. Its "ports" list is
// converted into a "port_map" list of {protocol, host_port, container_port}
// maps, and the original "ports" key is removed.
//
// Each port entry may be:
//   - a string "port", "host:container" or "protocol:host:container";
//   - a number, used as both host and container port;
//   - a map with "host_port" and "container_port" (numbers or numeric strings)
//     and an optional "protocol".
//
// The protocol defaults to "tcp". Numbers decoded from JSON (float64) are
// accepted as well as YAML ints. Malformed entries are reported as errors
// instead of silently becoming port 0.
func TransformNetwork(root *map[string]interface{}, data any, path tree.Path) (any, error) {
	if data == nil {
		return nil, nil
	}
	config, ok := data.(map[string]any)
	if !ok {
		return nil, fmt.Errorf("invalid network configuration: %v", data)
	}

	// The ToAnySlice error is ignored, as before: presumably it only signals a
	// missing or non-list "ports" value, which yields an empty port map.
	ports, _ := transform.ToAnySlice(config["ports"])
	portMap := []map[string]any{}
	for _, port := range ports {
		protocol, host, container, err := parseNetworkPortSpec(port)
		if err != nil {
			return nil, fmt.Errorf("invalid network configuration: %w", err)
		}
		portMap = append(portMap, map[string]any{
			"protocol":       protocol,
			"host_port":      host,
			"container_port": container,
		})
	}
	config["port_map"] = portMap
	delete(config, "ports")
	return config, nil
}

// parseNetworkPortSpec converts one "ports" entry into its protocol, host port
// and container port.
func parseNetworkPortSpec(port any) (string, int, int, error) {
	protocol := "tcp"
	switch v := port.(type) {
	case string:
		parts := strings.Split(v, ":")
		hostPart, containerPart := parts[0], parts[len(parts)-1]
		switch len(parts) {
		case 1, 2:
			// "port" or "host:container"
		case 3:
			protocol, hostPart = parts[0], parts[1]
		default:
			return "", 0, 0, fmt.Errorf("invalid port mapping %q", v)
		}
		host, err := parseNetworkPortNumber(hostPart)
		if err != nil {
			return "", 0, 0, err
		}
		container, err := parseNetworkPortNumber(containerPart)
		if err != nil {
			return "", 0, 0, err
		}
		return protocol, host, container, nil
	case map[string]any:
		host, err := parseNetworkPortNumber(v["host_port"])
		if err != nil {
			return "", 0, 0, err
		}
		container, err := parseNetworkPortNumber(v["container_port"])
		if err != nil {
			return "", 0, 0, err
		}
		if p, ok := v["protocol"].(string); ok {
			protocol = p
		}
		return protocol, host, container, nil
	default:
		n, err := parseNetworkPortNumber(v)
		if err != nil {
			return "", 0, 0, err
		}
		return protocol, n, n, nil
	}
}

// parseNetworkPortNumber converts a decoded port value into an int.
// It accepts an int (from YAML), an integral float64 (from encoding/json), or a
// numeric string. nil means "unset" and yields 0.
func parseNetworkPortNumber(value any) (int, error) {
	switch p := value.(type) {
	case nil:
		return 0, nil
	case int:
		return p, nil
	case float64:
		if p != float64(int(p)) {
			return 0, fmt.Errorf("invalid port number %v", p)
		}
		return int(p), nil
	case string:
		n, err := strconv.Atoi(p)
		if err != nil {
			return 0, fmt.Errorf("invalid port number %q: %w", p, err)
		}
		return n, nil
	default:
		return 0, fmt.Errorf("invalid port number %v", value)
	}
}
// TransformLibrary normalises a library entry into a map.
// A "name" or "name:version" string becomes {"name": ..., "version": ...}
// (version is "" when omitted). A map is returned unchanged and nil passes
// through as nil. Anything else, including strings with more than one ':', is
// rejected.
func TransformLibrary(root *map[string]interface{}, data any, path tree.Path) (any, error) {
	if data == nil {
		return nil, nil
	}
	if lib, isMap := data.(map[string]any); isMap {
		return lib, nil
	}
	spec, isString := data.(string)
	if !isString {
		return nil, fmt.Errorf("invalid library configuration: %v", data)
	}

	fields := strings.Split(spec, ":")
	switch len(fields) {
	case 1:
		return map[string]any{"name": fields[0], "version": ""}, nil
	case 2:
		return map[string]any{"name": fields[0], "version": fields[1]}, nil
	default:
		return nil, fmt.Errorf("invalid library configuration: %v", data)
	}
}
package nunet
import (
"fmt"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/tree"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/validate"
)
// NewNuNetValidator returns the validator for NuNet job specs. It checks the
// spec root and every job, including nested child jobs.
func NewNuNetValidator() validate.Validator {
	rules := map[tree.Path]validate.ValidatorFunc{
		"":                    ValidateSpec,
		"jobs.[]":             ValidateJob,
		"jobs.**.children.[]": ValidateJob,
	}
	return validate.NewValidator(rules)
}
// ValidateSpec checks that the spec root is a map containing a non-empty
// "jobs" list. A "jobs" value that is not a list is reported as missing rather
// than causing a panic through an unchecked type assertion.
func ValidateSpec(root *map[string]any, data any, path tree.Path) error {
	spec, ok := data.(map[string]any)
	if !ok {
		return fmt.Errorf("invalid spec configuration: %v", data)
	}
	// Check if the jobs list is present and not empty.
	jobs, isList := spec["jobs"].([]any)
	if !isList || len(jobs) == 0 {
		return fmt.Errorf("jobs list is required")
	}
	return nil
}
// ValidateJob checks that a job has either a non-empty "children" list or an
// "execution" block. A "children" value that is present but not a list is
// reported as an error instead of panicking.
func ValidateJob(root *map[string]any, data any, path tree.Path) error {
	job, ok := data.(map[string]any)
	if !ok {
		return fmt.Errorf("invalid job configuration: %v", data)
	}
	children, isList := job["children"].([]any)
	if job["children"] != nil && !isList {
		return fmt.Errorf("invalid children configuration: %v", job["children"])
	}
	// Check if the job has either children or an execution.
	if len(children) == 0 && job["execution"] == nil {
		return fmt.Errorf("job must have either children or an execution")
	}
	return nil
}
package parser
import (
"fmt"
"gitlab.com/nunet/device-management-service/dms/jobs"
)
// Parse decodes data as a job spec of the given type using the parser
// registered for specType. An unknown spec type yields an error.
func Parse(specType SpecType, data []byte) (jobs.JobSpec, error) {
	p, found := registry.GetParser(specType)
	if !found {
		return jobs.JobSpec{}, fmt.Errorf("parser for spec type %s not found", specType)
	}
	return p.Parse(data)
}
package parser
import (
"encoding/json"
"fmt"
"github.com/mitchellh/mapstructure"
yaml "gopkg.in/yaml.v3"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/transform"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/validate"
)
// SpecType names a job-specification dialect and selects its parser.
type SpecType string

// Known specification dialects. Only NuNet currently has a parser registered
// in this package's init.
const specTypeNuNet SpecType = "nunet"
const specTypeNomad SpecType = "nomad"
const specTypeK8s SpecType = "k8s"

// DefaultTagName is the struct tag mapstructure reads when decoding parsed specs.
const DefaultTagName = "json"
// Parser turns raw specification bytes into a value of type T.
type Parser[T any] interface {
	Parse(raw []byte) (T, error)
}
// ParserImpl is the default Parser implementation. It decodes YAML (falling
// back to JSON), applies a Transformer, runs a Validator, and decodes the
// result into T with mapstructure using the DefaultTagName struct tag.
type ParserImpl[T any] struct {
	validator   validate.Validator
	transformer transform.Transformer
}

// NewParser returns a Parser that uses transformer and validator.
func NewParser[T any](transformer transform.Transformer, validator validate.Validator) Parser[T] {
	return ParserImpl[T]{
		transformer: transformer,
		validator:   validator,
	}
}

// Parse decodes data into a T.
//
// data is parsed as YAML first and, if that fails, as JSON. The resulting map
// is transformed, validated and then decoded into T. On any error the zero T
// is returned with an error that wraps the underlying cause, except that
// validation errors are returned unwrapped.
func (p ParserImpl[T]) Parse(data []byte) (T, error) {
	var zero T
	var rawConfig map[string]any

	// Try to unmarshal as YAML first
	if yamlErr := yaml.Unmarshal(data, &rawConfig); yamlErr != nil {
		// Start from a clean map so partial YAML output cannot leak into the JSON attempt.
		rawConfig = nil
		if err := json.Unmarshal(data, &rawConfig); err != nil {
			return zero, fmt.Errorf("failed to parse config: %w", err)
		}
	}

	transformed, err := p.transformer.Transform(&rawConfig)
	if err != nil {
		return zero, fmt.Errorf("failed to transform config: %w", err)
	}

	// NOTE(review): validation runs on rawConfig, not on transformed. The
	// transformer appears to rewrite rawConfig's nested values in place, so
	// validators see transformed data, but not the normalized copy that is
	// decoded below — confirm this is intended.
	if err := p.validator.Validate(&rawConfig); err != nil {
		return zero, err
	}

	var config T
	decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
		Result:  &config,
		TagName: DefaultTagName,
	})
	if err != nil {
		return zero, fmt.Errorf("failed to create decoder: %w", err)
	}
	if err := decoder.Decode(transformed); err != nil {
		return zero, fmt.Errorf("failed to decode config: %w", err)
	}
	return config, nil
}
package parser
import (
"sync"
)
// Registry maps spec types to the parsers that handle them.
type Registry[T any] interface {
	// RegisterParser associates parser with kind, replacing any previous entry.
	RegisterParser(kind SpecType, parser Parser[T])
	// GetParser returns the parser for kind and whether one is registered.
	GetParser(kind SpecType) (Parser[T], bool)
}
// RegistryImpl is a Registry backed by a map and guarded by an RWMutex, so
// registration and lookup may happen concurrently. Its zero value is ready to
// use: the map is created on first registration.
type RegistryImpl[T any] struct {
	parsers map[SpecType]Parser[T]
	mu      sync.RWMutex
}

// RegisterParser associates p with specType, replacing any previous parser.
func (r *RegistryImpl[T]) RegisterParser(specType SpecType, p Parser[T]) {
	r.mu.Lock()
	defer r.mu.Unlock()
	// Writing to a nil map panics; lazily create it for zero-value registries.
	if r.parsers == nil {
		r.parsers = make(map[SpecType]Parser[T])
	}
	r.parsers[specType] = p
}

// GetParser returns the parser registered for specType and whether it exists.
func (r *RegistryImpl[T]) GetParser(specType SpecType) (Parser[T], bool) {
	r.mu.RLock()
	defer r.mu.RUnlock()
	p, exists := r.parsers[specType]
	return p, exists
}
package transform
import (
	"fmt"
	"sort"

	"gitlab.com/nunet/device-management-service/dms/jobs/parser/tree"
)
// TransformerFunc rewrites one node of the configuration tree.
// It receives a pointer to the root configuration, the node's current value
// and the node's path, and returns the value that should replace the node.
type TransformerFunc func(root *map[string]interface{}, data interface{}, path tree.Path) (any, error)
// Transformer rewrites a raw configuration map into the shape expected by
// the decoder and returns the rewritten configuration.
type Transformer interface {
	Transform(rawConfig *map[string]any) (any, error)
}
// TransformerImpl is the implementation of the Transformer interface.
// Each element of transformers is one stage: a map from a path pattern to the
// function applied at every node whose path matches that pattern. Transform
// runs the stages in slice order, feeding each stage's output to the next.
type TransformerImpl struct {
transformers []map[tree.Path]TransformerFunc
}
// NewTransformer returns a Transformer that runs the given stages in slice
// order. Each stage maps path patterns to the function applied at the
// matching nodes.
func NewTransformer(transformers []map[tree.Path]TransformerFunc) Transformer {
	impl := TransformerImpl{}
	impl.transformers = transformers
	return impl
}
// Transform runs every transformer stage over the configuration, starting at
// the root path and feeding each stage's output into the next. It returns the
// normalized result, or the first error a stage reports. Nested maps inside
// *rawConfig may be modified in place along the way.
func (t TransformerImpl) Transform(rawConfig *map[string]interface{}) (interface{}, error) {
	var current any = *rawConfig
	for _, stage := range t.transformers {
		next, err := t.transform(rawConfig, current, tree.NewPath(), stage)
		if err != nil {
			return nil, err
		}
		current = next
	}
	return Normalize(current), nil
}
// transform applies one stage of transformers to data (the node at path) and
// then recurses into the node's children.
//
// Every transformer whose pattern matches path is applied, each one receiving
// the previous one's output. Patterns are applied in sorted order, so that when
// several patterns match the same node the result does not depend on Go's
// randomized map iteration order.
//
// Children of a map[string]interface{} are rewritten in place. Any other slice
// kind is copied into a []any via ToAnySlice, and the copy is returned.
// Every other value is a leaf and is returned as-is.
func (t TransformerImpl) transform(root *map[string]interface{}, data any, path tree.Path, transformers map[tree.Path]TransformerFunc) (interface{}, error) {
	patterns := make([]tree.Path, 0, len(transformers))
	for pattern := range transformers {
		patterns = append(patterns, pattern)
	}
	sort.Slice(patterns, func(i, j int) bool { return patterns[i] < patterns[j] })

	var err error
	for _, pattern := range patterns {
		if !path.Matches(pattern) {
			continue
		}
		if data, err = transformers[pattern](root, data, path); err != nil {
			return nil, err
		}
	}

	if node, ok := data.(map[string]interface{}); ok {
		for key, value := range node {
			// Overwriting an existing key while ranging over the map is safe.
			if node[key], err = t.transform(root, value, path.Next(key), transformers); err != nil {
				return nil, err
			}
		}
		return node, nil
	}

	items, sliceErr := ToAnySlice(data)
	if sliceErr != nil {
		// Not a slice: a leaf value.
		return data, nil
	}
	for i, value := range items {
		if items[i], err = t.transform(root, value, path.Next(fmt.Sprintf("[%d]", i)), transformers); err != nil {
			return nil, err
		}
	}
	return items, nil
}
package transform
import (
"fmt"
"reflect"
"sort"
"strconv"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/tree"
)
// MapToSlice converts a map of named entries into a slice and copies each key
// into the entry's "name" field.
//
// Entries are emitted in ascending key order, so the output is deterministic
// even though Go's map iteration order is randomized. A nil entry is replaced
// by an empty map so it can still carry its name. Entries that are not
// map[string]any are appended unchanged. Map entries are modified in place.
// A nil input yields a nil slice; an empty input yields an empty, non-nil
// slice. The error is always nil.
func MapToSlice(data map[string]any) ([]any, error) {
	if data == nil {
		return nil, nil
	}

	keys := make([]string, 0, len(data))
	for k := range data {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	result := make([]any, 0, len(data))
	for _, k := range keys {
		v := data[k]
		if v == nil {
			v = map[string]any{}
		}
		if entry, ok := v.(map[string]any); ok {
			entry["name"] = k
		}
		result = append(result, v)
	}
	return result, nil
}
// GetConfigAtPath retrieves the part of config located at path.
//
// A path part selects a map key when the current value is a map[string]any.
// When the current value is a []any or []map[string]any, the part must be a
// bracketed index such as "[2]". The empty path returns config itself.
// A missing map key yields nil. A malformed or out-of-range index, or a path
// that continues past a value that is neither a map nor a list, yields an
// error instead of a panic.
func GetConfigAtPath(config map[string]interface{}, path tree.Path) (any, error) {
	var current any = config
	if path == "" {
		return current, nil
	}

	// index parses a "[N]" path part and checks it against the list length.
	index := func(key string, length int) (int, error) {
		if len(key) < 2 || key[0] != '[' || key[len(key)-1] != ']' {
			return 0, fmt.Errorf("invalid index: %v", key)
		}
		i, err := strconv.Atoi(key[1 : len(key)-1])
		if err != nil {
			return 0, fmt.Errorf("invalid index: %v", key)
		}
		if i < 0 || i >= length {
			return 0, fmt.Errorf("index out of range: %v (length %d)", key, length)
		}
		return i, nil
	}

	for _, key := range path.Parts() {
		switch v := current.(type) {
		case map[string]any:
			current = v[key]
		case []any:
			i, err := index(key, len(v))
			if err != nil {
				return nil, err
			}
			current = v[i]
		case []map[string]any:
			i, err := index(key, len(v))
			if err != nil {
				return nil, err
			}
			current = v[i]
		default:
			return nil, fmt.Errorf("invalid data type: %v", current)
		}
	}
	return current, nil
}
// ToAnySlice copies the elements of any slice value into a new []any.
// It returns an error for every non-slice argument, including arrays,
// strings and nil.
func ToAnySlice(slice any) ([]any, error) {
	rv := reflect.ValueOf(slice)
	if rv.Kind() != reflect.Slice {
		return nil, fmt.Errorf("input is not a slice. type: %T", slice)
	}
	out := make([]any, rv.Len())
	for i := range out {
		out[i] = rv.Index(i).Interface()
	}
	return out, nil
}
func normalizeMap(m interface{}) interface{} {
v := reflect.ValueOf(m)
switch v.Kind() {
case reflect.Map:
// Create a new map to hold normalized values
newMap := reflect.MakeMap(reflect.MapOf(v.Type().Key(), reflect.TypeOf((*interface{})(nil)).Elem()))
for _, key := range v.MapKeys() {
newValue := normalizeMap(v.MapIndex(key).Interface())
newMap.SetMapIndex(key, reflect.ValueOf(newValue))
}
return newMap.Interface()
case reflect.Slice:
// Create a new []interface{} slice to hold normalized values
newSlice := make([]interface{}, v.Len())
for i := 0; i < v.Len(); i++ {
newSlice[i] = normalizeMap(v.Index(i).Interface())
}
// Sort the slice if it's sortable
sort.Slice(newSlice, func(i, j int) bool {
return fmt.Sprint(newSlice[i]) < fmt.Sprint(newSlice[j])
})
return newSlice
default:
// For other types, return as is
return m
}
}
// NormalizeMap is the exported function that users will call
func Normalize(m any) interface{} {
return normalizeMap(m)
}
package tree
import (
"strings"
)
// Tokens used to build and match configuration paths.
const (
	// configPathSeparator joins the parts of a path, e.g. "a.b.c".
	configPathSeparator = "."
	// configPathList, used as a pattern part, matches one list-index part
	// such as "[3]".
	configPathList = "[]"
	// configPathMatchAny, used as a pattern part, matches exactly one part.
	configPathMatchAny = "*"
	// configPathMatchAnyMultiple, used as a pattern part, matches a sequence
	// of parts.
	configPathMatchAnyMultiple = "**"
)
// Path is a custom type for representing paths in the configuration.
// Parts are joined with configPathSeparator ("."). The transformer and
// validator walkers represent list elements as bracketed parts such as "[0]".
type Path string
func NewPath(path ...string) Path {
return Path(strings.Join(path, configPathSeparator))
}
// Parts splits the path on configPathSeparator. Following strings.Split,
// the empty path yields a single empty part.
func (p Path) Parts() []string {
	return strings.SplitN(string(p), configPathSeparator, -1)
}
// Parent returns the path without its last part. It returns the empty path
// when p has only one part.
func (p Path) Parent() Path {
	cut := strings.LastIndex(string(p), configPathSeparator)
	if cut < 0 {
		return ""
	}
	return p[:cut]
}
// Next returns p extended by one more part. An empty part leaves p
// unchanged, and extending the empty path yields just the new part.
func (p Path) Next(path string) Path {
	switch {
	case path == "":
		return p
	case p == "":
		return Path(path)
	default:
		return p + Path(configPathSeparator) + Path(path)
	}
}
// Last returns the final part of the path, or the whole path when it
// contains no separator.
func (p Path) Last() string {
	s := string(p)
	if cut := strings.LastIndex(s, configPathSeparator); cut >= 0 {
		return s[cut+len(configPathSeparator):]
	}
	return s
}
// Matches reports whether p matches pattern. The two are compared part by
// part; see matchParts for the wildcard rules.
func (p Path) Matches(pattern Path) bool {
	return matchParts(p.Parts(), pattern.Parts())
}
// String implements fmt.Stringer and returns the path as a plain string.
func (p Path) String() string {
	s := string(p)
	return s
}
// matchParts reports whether pathParts matches patternParts.
//
// Pattern parts are compared with path parts position by position:
//   - "*" matches any single part;
//   - "[]" matches a single list-index part of the form "[...]";
//   - "**" lets the rest of the pattern match at any later position. When it
//     is the last pattern part, it matches any remainder;
//   - any other part must equal the path part exactly.
//
// Apart from a trailing "**", the pattern must consume the whole path.
//
// NOTE(review): because of the up-front length check, "**" cannot match zero
// parts when that would leave the path shorter than the pattern (for example,
// "a.b" does not match "a.**.b"). Confirm that this is intended.
func matchParts(pathParts, patternParts []string) bool {
	// A pattern longer than the path cannot match.
	if len(pathParts) < len(patternParts) {
		return false
	}
	for i, part := range patternParts {
		switch part {
		case configPathMatchAnyMultiple:
			if i == len(patternParts)-1 {
				return true
			}
			// Let "**" absorb pathParts[i:j], then match the rest of the
			// pattern against what remains. If no split works, there is no
			// match.
			for j := i; j < len(pathParts); j++ {
				if matchParts(pathParts[j:], patternParts[i+1:]) {
					return true
				}
			}
			return false
		case configPathList:
			// Use prefix/suffix checks so an empty part (e.g. the root path "")
			// is rejected instead of being indexed out of range.
			segment := pathParts[i]
			if !strings.HasPrefix(segment, "[") || !strings.HasSuffix(segment, "]") {
				return false
			}
		default:
			if part != configPathMatchAny && part != pathParts[i] {
				return false
			}
		}
		// On the last pattern part, any leftover path parts mean no match.
		if i == len(patternParts)-1 && i < len(pathParts)-1 {
			return false
		}
	}
	return true
}
package validate
import (
"fmt"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/tree"
)
// ValidatorFunc checks one node of the configuration tree. It receives a
// pointer to the root configuration, the node's value and the node's path,
// and returns a non-nil error if the node is invalid.
type ValidatorFunc func(root *map[string]any, data any, path tree.Path) error
// Validator checks a raw configuration map and returns an error if it is
// invalid.
type Validator interface {
	Validate(rawConfig *map[string]any) error
}
// ValidatorImpl is the implementation of the Validator interface.
// validators maps a path pattern to the function run at every node whose
// path matches that pattern.
type ValidatorImpl struct {
validators map[tree.Path]ValidatorFunc
}
// NewValidator returns a Validator that runs each function in validators at
// every node whose path matches the function's pattern.
func NewValidator(validators map[tree.Path]ValidatorFunc) Validator {
	impl := ValidatorImpl{}
	impl.validators = validators
	return impl
}
// Validate walks the whole configuration from the root path and returns the
// first error reported by a matching validator.
func (v ValidatorImpl) Validate(rawConfig *map[string]any) error {
	root := *rawConfig
	return v.validate(rawConfig, root, tree.NewPath(), v.validators)
}
// validate is a recursive function that applies the validators to the configuration.
func (v ValidatorImpl) validate(root *map[string]interface{}, data any, path tree.Path, validators map[tree.Path]ValidatorFunc) error {
// Apply validators that match the current path.
for pattern, validator := range validators {
if path.Matches(pattern) {
if err := validator(root, data, path); err != nil {
return err
}
}
}
// Recursively apply validators to children.
switch data := data.(type) {
case map[string]interface{}:
for key, value := range data {
next := path.Next(key)
if err := v.validate(root, value, next, validators); err != nil {
return err
}
}
case []interface{}:
for i, value := range data {
next := path.Next(fmt.Sprintf("[%d]", i))
if err := v.validate(root, value, next, validators); err != nil {
return err
}
}
}
return nil
}
package resources
import (
"context"
"fmt"
"github.com/uptrace/opentelemetry-go-extra/otelzap"
"gitlab.com/nunet/device-management-service/db"
"gitlab.com/nunet/device-management-service/db/repositories"
gormRepo "gitlab.com/nunet/device-management-service/db/repositories/gorm"
"gitlab.com/nunet/device-management-service/telemetry/logger"
"gitlab.com/nunet/device-management-service/types"
)
// zlog is the logger used throughout the resources package.
var zlog *otelzap.Logger

// ManagerInstance is the package-wide resource Manager, built in init.
var ManagerInstance Manager
// TODO: This needs to be initialized in `dms` package and removed from here
// https://gitlab.com/nunet/device-management-service/-/issues/536
//
// init sets up the package logger and builds ManagerInstance on top of
// gorm-backed repositories that share the global db.DB handle.
func init() {
	zlog = logger.OtelZapLogger("resources")

	ManagerInstance = newResourceManager(ManagerRepos{
		FreeResources:      gormRepo.NewFreeResources(db.DB),
		OnboardedResources: gormRepo.NewOnboardedResources(db.DB),
		RequiredResources:  gormRepo.NewRequiredResources(db.DB),
		VirtualMachine:     gormRepo.NewVirtualMachine(db.DB),
		Services:           gormRepo.NewServices(db.DB),
	})
}
// SystemSpecs reports the hardware specifications of the host machine.
type SystemSpecs interface {
	// GetSpecInfo returns the machine's detailed specifications.
	GetSpecInfo() (types.SpecInfo, error)
	// GetGPUVendors returns the vendors of the machine's GPUs.
	GetGPUVendors() ([]types.GPUVendor, error)
	// GetGPUInfo returns information about the machine's GPUs from the given
	// vendors, or about all GPUs when no vendor is given.
	GetGPUInfo(vendors ...types.GPUVendor) ([]types.GPU, error)
	// GetTotalMemory returns the machine's total memory in MB.
	GetTotalMemory() (uint64, error)
	// GetTotalStorage returns the machine's total storage in MB.
	GetTotalStorage() (uint64, error)
	// GetCPUInfo returns the machine's CPU information.
	GetCPUInfo() (types.CPUInfo, error)
	// GetProvisionedResources returns the machine's total resources.
	GetProvisionedResources() (types.Resources, error)
}
// Manager manages the machine's resources. It tracks onboarded, required and
// free resources, and exposes the SystemSpecs and UsageMonitor it uses.
type Manager interface {
	// UpdateFreeResources recomputes the machine's free resources, stores them
	// in the database and returns them.
	UpdateFreeResources(context.Context) (types.FreeResources, error)
	// GetOnboardedResources returns the resources onboarded on the machine.
	GetOnboardedResources(context.Context) (types.OnboardedResources, error)
	// GetRequiredResources returns the resources required by the jobs
	// running on the machine.
	GetRequiredResources(context.Context) (types.Resources, error)
	// UpdateOnboardedResources stores the machine's onboarded resources in
	// the database.
	UpdateOnboardedResources(context.Context, types.OnboardedResources) error
	// SystemSpecs returns the SystemSpecs instance.
	SystemSpecs() SystemSpecs
	// UsageMonitor returns the UsageMonitor instance.
	UsageMonitor() UsageMonitor
	// ... other methods
}
// defaultManager implements the Manager interface. It combines the
// repositories in repos, a SystemSpecs for host hardware queries and a
// UsageMonitor for current usage.
// TODO: do we want to have an in-memory cache for the resources instead of querying the DB every time?
// TODO: Add telemetry for the methods https://gitlab.com/nunet/device-management-service/-/issues/535
type defaultManager struct {
// usageMonitor computes the resources currently in use (used by UpdateFreeResources).
usageMonitor UsageMonitor
// systemSpecs reports the host's hardware specifications.
systemSpecs SystemSpecs
// repos gives access to the resource-related repositories.
repos ManagerRepos
}
// ManagerRepos holds all the repositories needed for resource management.
// init fills every field with a gorm-backed repository that shares db.DB.
type ManagerRepos struct {
// FreeResources stores the computed free resources.
FreeResources repositories.FreeResources
// OnboardedResources stores the resources onboarded on the machine.
OnboardedResources repositories.OnboardedResources
// RequiredResources stores the resource requirements of individual jobs.
RequiredResources repositories.RequiredResources
// VirtualMachine stores VM records (queried for running VMs).
VirtualMachine repositories.VirtualMachine
// Services stores service records (queried for running containers).
Services repositories.Services
}
// newResourceManager builds a defaultManager on top of repos, using the Linux
// SystemSpecs both directly and inside the usage monitor.
func newResourceManager(repos ManagerRepos) *defaultManager {
	specs := newSystemSpecs()
	monitor := newUsageMonitor(specs, repos.VirtualMachine, repos.Services, repos.RequiredResources)
	return &defaultManager{
		systemSpecs:  specs,
		usageMonitor: monitor,
		repos:        repos,
	}
}
// UpdateFreeResources computes free resources as onboarded resources minus
// current usage, saves the result to the database and returns it.
func (d defaultManager) UpdateFreeResources(ctx context.Context) (types.FreeResources, error) {
	usage, err := d.usageMonitor.GetUsage(ctx)
	if err != nil {
		return types.FreeResources{}, fmt.Errorf("getting usage: %w", err)
	}

	onboarded, err := d.GetOnboardedResources(ctx)
	if err != nil {
		return types.FreeResources{}, fmt.Errorf("getting total resources: %w", err)
	}

	remaining, err := onboarded.Subtract(usage)
	if err != nil {
		return types.FreeResources{}, fmt.Errorf("calculating free resources: %w", err)
	}

	free := types.FreeResources{Resources: remaining}
	if err := d.updateDBFreeResources(ctx, free); err != nil {
		return types.FreeResources{}, fmt.Errorf("updating free resources in db: %w", err)
	}
	return free, nil
}
// GetOnboardedResources reads the onboarded resources from the repository.
func (d defaultManager) GetOnboardedResources(ctx context.Context) (types.OnboardedResources, error) {
	resources, err := d.repos.OnboardedResources.Get(ctx)
	if err == nil {
		return resources, nil
	}
	return types.OnboardedResources{}, fmt.Errorf("failed to get onboarded resources: %w", err)
}
// GetRequiredResources returns the sum of the Resources of every
// RequiredResources record in the repository. The repository error is wrapped
// with %w so callers can inspect it with errors.Is/errors.As.
func (d defaultManager) GetRequiredResources(ctx context.Context) (types.Resources, error) {
	repo := d.repos.RequiredResources
	jobRequirements, err := repo.FindAll(ctx, repo.GetQuery())
	if err != nil {
		return types.Resources{}, fmt.Errorf("unable to get resource requirements from db - %w", err)
	}

	var total types.Resources
	for _, req := range jobRequirements {
		total = total.Add(req.Resources)
	}
	return total, nil
}
// UpdateOnboardedResources saves resources through the OnboardedResources
// repository and ignores the repository's first return value.
func (d defaultManager) UpdateOnboardedResources(ctx context.Context, resources types.OnboardedResources) error {
	if _, err := d.repos.OnboardedResources.Save(ctx, resources); err != nil {
		return fmt.Errorf("failed to update onboarded resources: %w", err)
	}
	return nil
}
// SystemSpecs exposes the SystemSpecs the manager was built with.
func (d defaultManager) SystemSpecs() SystemSpecs {
	specs := d.systemSpecs
	return specs
}

// UsageMonitor exposes the UsageMonitor the manager uses to compute usage.
func (d defaultManager) UsageMonitor() UsageMonitor {
	monitor := d.usageMonitor
	return monitor
}
// updateDBFreeResources saves freeResources through the FreeResources
// repository and ignores the repository's first return value.
func (d defaultManager) updateDBFreeResources(ctx context.Context, freeResources types.FreeResources) error {
	if _, err := d.repos.FreeResources.Save(ctx, freeResources); err != nil {
		return fmt.Errorf("updating free resources: %w", err)
	}
	return nil
}

// Compile-time check that *defaultManager satisfies Manager.
var _ Manager = (*defaultManager)(nil)
package resources
import (
"context"
"errors"
"fmt"
"gitlab.com/nunet/device-management-service/types"
"os/exec"
"regexp"
"strconv"
"strings"
"github.com/NVIDIA/go-nvml/pkg/nvml"
"github.com/jaypipes/ghw"
"github.com/shirou/gopsutil/v4/cpu"
"github.com/shirou/gopsutil/v4/disk"
"github.com/shirou/gopsutil/v4/mem"
)
// gpuMetadata holds what fetchGPUMetadata learns about one GPU from ghw.
// The vendor-specific fetchers pair each entry with their own results by
// position.
type gpuMetadata struct {
// PCIAddress is the card's address as reported by ghw (card.Address).
PCIAddress string
}
// linuxSystemSpecs is the Linux implementation of SystemSpecs. It holds no
// state.
type linuxSystemSpecs struct{}

// newSystemSpecs returns a ready-to-use linuxSystemSpecs.
func newSystemSpecs() *linuxSystemSpecs {
	return new(linuxSystemSpecs)
}

// Compile-time check that *linuxSystemSpecs satisfies SystemSpecs.
var _ SystemSpecs = (*linuxSystemSpecs)(nil)
// GetSpecInfo is meant to return the detailed specifications of the system.
// It is not implemented yet. It returns an error instead of panicking, so a
// caller going through the SystemSpecs interface cannot crash the process.
// TODO: implement the method
// https://gitlab.com/nunet/device-management-service/-/issues/537
func (l linuxSystemSpecs) GetSpecInfo() (types.SpecInfo, error) {
	return types.SpecInfo{}, errors.New("GetSpecInfo is not implemented yet")
}
// GetGPUVendors lists the vendor of every display, VGA-compatible, 2D or 3D
// controller that ghw reports, one entry per card (duplicates are kept).
// Cards without device, class or vendor information are skipped. Vendors
// other than NVIDIA, AMD and Intel are reported as GPUVendorUnknown.
func (l linuxSystemSpecs) GetGPUVendors() ([]types.GPUVendor, error) {
	info, err := ghw.GPU()
	if err != nil {
		return nil, err
	}

	gpuClasses := []string{
		"display controller",
		"vga compatible controller",
		"3d controller",
		"2d controller",
	}

	var vendors []types.GPUVendor
	for _, card := range info.GraphicsCards {
		dev := card.DeviceInfo
		if dev == nil || dev.Class == nil || dev.Vendor == nil {
			continue
		}

		className := strings.ToLower(dev.Class.Name)
		isGPU := false
		for _, class := range gpuClasses {
			if strings.Contains(className, class) {
				isGPU = true
				break
			}
		}
		if !isGPU {
			continue
		}

		vendorName := strings.ToLower(dev.Vendor.Name)
		var vendor types.GPUVendor
		switch {
		case strings.Contains(vendorName, "nvidia"):
			vendor = types.GPUVendorNvidia
		case strings.Contains(vendorName, "amd"):
			vendor = types.GPUVendorAMDATI
		case strings.Contains(vendorName, "intel"):
			vendor = types.GPUVendorIntel
		default:
			vendor = types.GPUVendorUnknown
		}
		vendors = append(vendors, vendor)
	}
	return vendors, nil
}
// GetGPUInfo returns information about the system's GPUs from the requested
// vendors. With no vendors given, it queries Intel, NVIDIA and AMD, in that
// order.
//
// GPU metadata is fetched once through ghw. A supported vendor with no
// detected GPUs is logged and skipped. A vendor whose details cannot be
// retrieved is logged as a warning and skipped. An unsupported vendor aborts
// with an error. Each returned GPU gets a dms-internal Index (0..n-1), which
// is not the device index used by the vendor tools.
func (l linuxSystemSpecs) GetGPUInfo(vendors ...types.GPUVendor) ([]types.GPU, error) {
	metadataByVendor, err := l.fetchGPUMetadata()
	if err != nil {
		return nil, fmt.Errorf("failed to fetch GPU metadata: %w", err)
	}

	fetchers := map[types.GPUVendor]func([]gpuMetadata) ([]types.GPU, error){
		types.GPUVendorIntel:  l.getIntelGPUInfo,
		types.GPUVendorNvidia: l.getNVIDIAGPUInfo,
		types.GPUVendorAMDATI: l.getAMDGPUInfo,
	}
	if len(vendors) == 0 {
		vendors = []types.GPUVendor{types.GPUVendorIntel, types.GPUVendorNvidia, types.GPUVendorAMDATI}
	}

	var gpus []types.GPU
	for _, vendor := range vendors {
		fetch, supported := fetchers[vendor]
		if !supported {
			return nil, fmt.Errorf("unsupported GPU vendor: %v", vendor)
		}
		metadata, found := metadataByVendor[vendor]
		if !found {
			zlog.Sugar().Infof("No %s GPUs found", vendor)
			continue
		}
		list, fetchErr := fetch(metadata)
		if fetchErr != nil {
			zlog.Sugar().Warnf("Failed to retrieve %s GPU information: %v", vendor, fetchErr)
			continue
		}
		gpus = append(gpus, list...)
	}

	return assignIndexToGPUs(gpus), nil
}
// GetTotalMemory returns the total memory reported by gopsutil's
// mem.VirtualMemory, converted from bytes to MiB (truncated). The underlying
// error is wrapped with %w.
func (l linuxSystemSpecs) GetTotalMemory() (uint64, error) {
	vm, err := mem.VirtualMemory()
	if err != nil {
		return 0, fmt.Errorf("failed to get total memory: %w", err)
	}
	return vm.Total / 1024 / 1024, nil
}
// GetTotalStorage returns the combined capacity, in MiB, of the partitions
// gopsutil reports with all=false. MiB matches the unit documented on
// SystemSpecs.GetTotalStorage and used by GetTotalMemory; the previous
// implementation returned bytes.
//
// A device mounted at several mount points (for example, through a bind
// mount) is counted once. Partitions without a device name are keyed by
// mount point. Any failure to read a partition's usage aborts with an error.
func (l linuxSystemSpecs) GetTotalStorage() (uint64, error) {
	ctx := context.Background()
	partitions, err := disk.PartitionsWithContext(ctx, false)
	if err != nil {
		return 0, fmt.Errorf("failed to get partitions: %w", err)
	}

	seen := make(map[string]struct{}, len(partitions))
	var totalBytes uint64
	for _, partition := range partitions {
		key := partition.Device
		if key == "" {
			key = partition.Mountpoint
		}
		if _, dup := seen[key]; dup {
			continue
		}
		seen[key] = struct{}{}

		usage, err := disk.UsageWithContext(ctx, partition.Mountpoint)
		if err != nil {
			return 0, fmt.Errorf("failed to get disk usage: %w", err)
		}
		totalBytes += usage.Total
	}
	return totalBytes / 1024 / 1024, nil
}
// GetCPUInfo summarizes the entries returned by gopsutil's cpu.Info:
//   - NumCores is the number of entries;
//   - Compute is the sum of their MHz values;
//   - MHzPerCore is the MHz value of the first entry.
//
// It returns an error when no CPU entries are reported, instead of panicking
// while reading the first one.
func (l linuxSystemSpecs) GetCPUInfo() (types.CPUInfo, error) {
	cores, err := cpu.Info()
	if err != nil {
		return types.CPUInfo{}, fmt.Errorf("failed to get CPU info: %w", err)
	}
	if len(cores) == 0 {
		return types.CPUInfo{}, errors.New("failed to get CPU info: no CPUs reported")
	}

	var totalCompute float64
	for i := range cores {
		totalCompute += cores[i].Mhz
	}

	return types.CPUInfo{
		Compute:    totalCompute,
		NumCores:   uint64(len(cores)),
		MHzPerCore: cores[0].Mhz,
	}, nil
}
// GetProvisionedResources combines the host's resources into one
// types.Resources value:
//   - CPU is the total MHz from GetCPUInfo;
//   - RAM comes from GetTotalMemory;
//   - Disk comes from GetTotalStorage;
//   - GPU comes from GetGPUInfo for all vendors.
//
// Errors from each step are wrapped with %w.
func (l linuxSystemSpecs) GetProvisionedResources() (types.Resources, error) {
	cpuInfo, err := l.GetCPUInfo()
	if err != nil {
		return types.Resources{}, fmt.Errorf("failed to get CPU info: %w", err)
	}
	totalMemory, err := l.GetTotalMemory()
	if err != nil {
		return types.Resources{}, fmt.Errorf("failed to get total memory: %w", err)
	}
	gpus, err := l.GetGPUInfo()
	if err != nil {
		return types.Resources{}, fmt.Errorf("failed to get GPU info: %w", err)
	}
	totalDisk, err := l.GetTotalStorage()
	if err != nil {
		return types.Resources{}, fmt.Errorf("failed to get total storage: %w", err)
	}

	return types.Resources{
		CPU:  cpuInfo.Compute,
		RAM:  totalMemory,
		Disk: totalDisk,
		GPU:  gpus,
	}, nil
}
// getAMDGPUInfo collects AMD GPU details by running
// `rocm-smi --showid --showproductname --showmeminfo vram`. For each GPU it
// parses the card series, total VRAM and used VRAM (reported in bytes and
// converted here to MiB). The n-th GPU in the output is paired with
// metadata[n] to get its PCI address.
//
// It returns an error when:
//   - rocm-smi fails;
//   - the output is missing required fields or the fields are inconsistent;
//   - rocm-smi reports more GPUs than there are metadata entries (instead of
//     indexing past the end of metadata).
//
// Free VRAM is clamped at zero.
func (l linuxSystemSpecs) getAMDGPUInfo(metadata []gpuMetadata) ([]types.GPU, error) {
	cmd := exec.Command("rocm-smi", "--showid", "--showproductname", "--showmeminfo", "vram")
	output, err := cmd.CombinedOutput()
	if err != nil {
		return nil, fmt.Errorf("AMD ROCm not installed, initialized, or configured (reboot recommended for newly installed AMD GPU Drivers): %w", err)
	}
	outputStr := string(output)

	gpuNameRegex := regexp.MustCompile(`GPU\[\d+\]\s+: Card Series:\s+([^\n]+)`)
	totalRegex := regexp.MustCompile(`GPU\[\d+\]\s+: VRAM Total Memory \(B\):\s+(\d+)`)
	usedRegex := regexp.MustCompile(`GPU\[\d+\]\s+: VRAM Total Used Memory \(B\):\s+(\d+)`)

	gpuNameMatches := gpuNameRegex.FindAllStringSubmatch(outputStr, -1)
	totalMatches := totalRegex.FindAllStringSubmatch(outputStr, -1)
	usedMatches := usedRegex.FindAllStringSubmatch(outputStr, -1)

	if len(gpuNameMatches) == 0 || len(totalMatches) == 0 || len(usedMatches) == 0 {
		return nil, fmt.Errorf("failed to find AMD GPU information or vram information in the output")
	}
	if len(gpuNameMatches) != len(totalMatches) || len(totalMatches) != len(usedMatches) {
		return nil, fmt.Errorf("inconsistent AMD GPU information detected")
	}
	// GPUs are paired with metadata by position, so there must be a metadata
	// entry for every GPU rocm-smi reported.
	if len(gpuNameMatches) > len(metadata) {
		return nil, fmt.Errorf("rocm-smi reported %d AMD GPUs but metadata was found for only %d", len(gpuNameMatches), len(metadata))
	}

	gpuInfos := make([]types.GPU, 0, len(gpuNameMatches))
	for i := range gpuNameMatches {
		totalMemoryBytes, err := strconv.ParseUint(totalMatches[i][1], 10, 64)
		if err != nil {
			return nil, fmt.Errorf("failed to parse total amdgpu vram: %w", err)
		}
		usedMemoryBytes, err := strconv.ParseUint(usedMatches[i][1], 10, 64)
		if err != nil {
			return nil, fmt.Errorf("failed to parse used amdgpu vram: %w", err)
		}

		totalMemoryMiB := totalMemoryBytes / 1024 / 1024
		usedMemoryMiB := usedMemoryBytes / 1024 / 1024
		// Clamp so a used value above the total cannot wrap around.
		var freeMemoryMiB uint64
		if usedMemoryMiB < totalMemoryMiB {
			freeMemoryMiB = totalMemoryMiB - usedMemoryMiB
		}

		gpuInfos = append(gpuInfos, types.GPU{
			PCIAddress: metadata[i].PCIAddress,
			Model:      gpuNameMatches[i][1],
			TotalVRAM:  totalMemoryMiB,
			UsedVRAM:   usedMemoryMiB,
			FreeVRAM:   freeMemoryMiB,
			Vendor:     types.GPUVendorAMDATI,
		})
	}
	return gpuInfos, nil
}
// getNVIDIAGPUInfo queries NVML for every NVIDIA device and pairs device i
// with metadata[i] (by position) to fill in the PCI address. NVML is
// initialized for the duration of the call and shut down on return.
//
// It fails if NVML cannot be initialized, if the NVML device count differs
// from len(metadata), or if any per-device query fails. VRAM figures are
// converted from bytes to MiB.
func (l linuxSystemSpecs) getNVIDIAGPUInfo(metadata []gpuMetadata) ([]types.GPU, error) {
	if ret := nvml.Init(); !errors.Is(ret, nvml.SUCCESS) {
		return nil, fmt.Errorf("NVIDIA Management Library not installed, initialized or configured (reboot recommended for newly installed NVIDIA GPU drivers): %s", nvml.ErrorString(ret))
	}
	defer nvml.Shutdown()

	count, ret := nvml.DeviceGetCount()
	if !errors.Is(ret, nvml.SUCCESS) {
		return nil, fmt.Errorf("failed to get device count: %s", nvml.ErrorString(ret))
	}
	if count != len(metadata) {
		return nil, fmt.Errorf("failed to find NVIDIA GPU information for all GPUs")
	}

	const bytesPerMiB = 1024 * 1024

	var gpus []types.GPU
	for idx := 0; idx < count; idx++ {
		device, ret := nvml.DeviceGetHandleByIndex(idx)
		if !errors.Is(ret, nvml.SUCCESS) {
			return nil, fmt.Errorf("failed to get device handle for device %d: %s", idx, nvml.ErrorString(ret))
		}

		name, ret := device.GetName()
		if !errors.Is(ret, nvml.SUCCESS) {
			return nil, fmt.Errorf("failed to get name for device %d: %s", idx, nvml.ErrorString(ret))
		}

		memory, ret := device.GetMemoryInfo()
		if !errors.Is(ret, nvml.SUCCESS) {
			return nil, fmt.Errorf("failed to get nvidiagpu vram info for device %d: %s", idx, nvml.ErrorString(ret))
		}

		gpus = append(gpus, types.GPU{
			PCIAddress: metadata[idx].PCIAddress,
			Name:       name,
			Model:      name,
			TotalVRAM:  memory.Total / bytesPerMiB,
			UsedVRAM:   memory.Used / bytesPerMiB,
			FreeVRAM:   memory.Free / bytesPerMiB,
			Vendor:     types.GPUVendorNvidia,
		})
	}
	return gpus, nil
}
// getIntelGPUInfo collects Intel GPU details with xpu-smi:
//   - `xpu-smi health -l` lists the device IDs;
//   - `xpu-smi discovery -d ID` gives each device's name and physical memory
//     size (MiB);
//   - `xpu-smi stats -d ID` gives its used memory (MiB).
//
// The n-th device is paired with metadata[n] to get its PCI address. It fails
// if any command fails, if no devices are found, if the device count differs
// from len(metadata), or if any field cannot be parsed. Free memory is clamped
// at zero.
func (l linuxSystemSpecs) getIntelGPUInfo(metadata []gpuMetadata) ([]types.GPU, error) {
	// Compile the output patterns once per call rather than once per device.
	var (
		deviceIDRegex = regexp.MustCompile(`(?i)\| Device ID\s+\|\s+(\d+)\s+\|`)
		nameRegex     = regexp.MustCompile(`(?i)Device Name:\s+([^\n|]+)`)
		totalMemRegex = regexp.MustCompile(`(?i)Memory Physical Size:\s+([^\s]+)\s+MiB`)
		usedMemRegex  = regexp.MustCompile(`(?i)GPU Memory Used \(MiB\)\s+\|\s+(\d+)\s+\|`)
	)

	// Determine the number of discrete Intel GPUs.
	healthOut, err := exec.Command("xpu-smi", "health", "-l").CombinedOutput()
	if err != nil {
		return nil, fmt.Errorf("xpu-smi not installed, initialized, or configured: %w", err)
	}
	deviceIDMatches := deviceIDRegex.FindAllStringSubmatch(string(healthOut), -1)
	if len(deviceIDMatches) == 0 {
		return nil, fmt.Errorf("failed to find any Intel GPUs")
	}
	if len(deviceIDMatches) != len(metadata) {
		return nil, fmt.Errorf("failed to find Intel GPU information for all GPUs")
	}

	var gpuInfos []types.GPU
	for i, match := range deviceIDMatches {
		deviceID := match[1]

		discoveryOut, err := exec.Command("xpu-smi", "discovery", "-d", deviceID).CombinedOutput()
		if err != nil {
			return nil, fmt.Errorf("failed to get discovery info for Intel GPU %s: %w", deviceID, err)
		}
		discovery := string(discoveryOut)
		nameMatch := nameRegex.FindStringSubmatch(discovery)
		totalMemMatch := totalMemRegex.FindStringSubmatch(discovery)
		if nameMatch == nil || totalMemMatch == nil {
			return nil, fmt.Errorf("failed to parse discovery info for Intel GPU %s", deviceID)
		}
		gpuName := strings.TrimSpace(nameMatch[1])
		totalMemoryMiB, err := strconv.ParseFloat(totalMemMatch[1], 64)
		if err != nil {
			return nil, fmt.Errorf("failed to parse total memory for Intel GPU %s: %w", deviceID, err)
		}

		statsOut, err := exec.Command("xpu-smi", "stats", "-d", deviceID).CombinedOutput()
		if err != nil {
			return nil, fmt.Errorf("failed to get stats for Intel GPU %s: %w", deviceID, err)
		}
		usedMemMatch := usedMemRegex.FindStringSubmatch(string(statsOut))
		if usedMemMatch == nil {
			return nil, fmt.Errorf("failed to parse used memory for Intel GPU %s", deviceID)
		}
		usedMemoryMiB, err := strconv.ParseFloat(usedMemMatch[1], 64)
		if err != nil {
			return nil, fmt.Errorf("failed to parse used memory for Intel GPU %s: %w", deviceID, err)
		}

		// Converting a negative float to uint64 is implementation-defined in
		// Go, so clamp the free amount at zero.
		freeMemoryMiB := 0.0
		if usedMemoryMiB < totalMemoryMiB {
			freeMemoryMiB = totalMemoryMiB - usedMemoryMiB
		}

		gpuInfos = append(gpuInfos, types.GPU{
			PCIAddress: metadata[i].PCIAddress,
			Model:      gpuName,
			TotalVRAM:  uint64(totalMemoryMiB),
			UsedVRAM:   uint64(usedMemoryMiB),
			FreeVRAM:   uint64(freeMemoryMiB),
			Vendor:     types.GPUVendorIntel,
		})
	}
	return gpuInfos, nil
}
// fetchGPUMetadata groups the PCI addresses of the graphics cards reported by
// ghw.GPU() by vendor, with vendors resolved through types.ParseGPUVendor.
// Cards without device or vendor information are skipped, as in
// GetGPUVendors, rather than dereferencing a nil Vendor.
// TODO: Use one single library to fetch GPU information or improve the match criteria
// https://gitlab.com/nunet/device-management-service/-/issues/548
// TODO: write tests by mocking the gpu snapshot
// https://gitlab.com/nunet/device-management-service/-/issues/534
func (l linuxSystemSpecs) fetchGPUMetadata() (map[types.GPUVendor][]gpuMetadata, error) {
	gpuInfo, err := ghw.GPU()
	if err != nil {
		return nil, err
	}

	gpuDetails := make(map[types.GPUVendor][]gpuMetadata)
	for _, card := range gpuInfo.GraphicsCards {
		if card.DeviceInfo == nil || card.DeviceInfo.Vendor == nil {
			continue
		}
		vendor := types.ParseGPUVendor(card.DeviceInfo.Vendor.Name)
		gpuDetails[vendor] = append(gpuDetails[vendor], gpuMetadata{PCIAddress: card.Address})
	}
	return gpuDetails, nil
}
// assignIndexToGPUs numbers the GPUs 0..len-1 in slice order. It modifies the
// slice in place and also returns it for convenience.
func assignIndexToGPUs(gpus []types.GPU) []types.GPU {
	for idx := 0; idx < len(gpus); idx++ {
		gpus[idx].Index = idx
	}
	return gpus
}
package resources
import (
"context"
"fmt"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/db/repositories"
)
// UsageMonitor reports how much of the machine's resources are in use.
type UsageMonitor interface {
	// GetUsage returns the resources currently used by the machine.
	GetUsage(context.Context) (types.Resources, error)
	// getVMUsage returns the resources used by the firecracker VMs. The CPU
	// information is used to express vCPU usage in MHz.
	getVMUsage(context.Context, types.CPUInfo) (types.Resources, error)
	// getContainerUsage returns the resources used by the docker containers.
	getContainerUsage(context.Context) (types.Resources, error)
}
// defaultUsageMonitor implements the UsageMonitor interface. It derives usage
// from repository records (running VMs and services and their requirements)
// rather than measuring the host directly.
type defaultUsageMonitor struct {
// systemSpecs supplies the CPU info used to convert vCPUs into MHz.
systemSpecs SystemSpecs
// vmRepo is queried for VMs whose State is "running".
vmRepo repositories.VirtualMachine
// serviceRepo is queried for services whose JobStatus is "running".
serviceRepo repositories.Services
// requiredResourcesRepo supplies the resource requirements of jobs.
requiredResourcesRepo repositories.RequiredResources
}
// newUsageMonitor returns a defaultUsageMonitor wired to the given system
// specs and repositories.
func newUsageMonitor(
	systemSpecs SystemSpecs,
	vmRepo repositories.VirtualMachine,
	serviceRepo repositories.Services,
	requiredResourcesRepo repositories.RequiredResources,
) *defaultUsageMonitor {
	um := new(defaultUsageMonitor)
	um.systemSpecs = systemSpecs
	um.vmRepo = vmRepo
	um.serviceRepo = serviceRepo
	um.requiredResourcesRepo = requiredResourcesRepo
	return um
}
// GetUsage returns the resources currently consumed on the machine: the
// usage of running VMs (with vCPUs converted to MHz from the host's CPU info)
// plus the usage of running containers.
func (um *defaultUsageMonitor) GetUsage(ctx context.Context) (types.Resources, error) {
	cpuInfo, err := um.systemSpecs.GetCPUInfo()
	if err != nil {
		return types.Resources{}, fmt.Errorf("getting CPU info: %w", err)
	}

	vm, err := um.getVMUsage(ctx, cpuInfo)
	if err != nil {
		return types.Resources{}, fmt.Errorf("getting VM usage: %w", err)
	}

	containers, err := um.getContainerUsage(ctx)
	if err != nil {
		return types.Resources{}, fmt.Errorf("getting container usage: %w", err)
	}

	total := vm.Add(containers)
	return total, nil
}
// getVMUsage sums the vCPU counts and memory sizes of all VMs whose State is
// "running". RAM is reported as the summed MemSizeMib. CPU is the summed
// vCPU count times cpuInfo.MHzPerCore. Disk usage is not tracked yet.
func (um *defaultUsageMonitor) getVMUsage(ctx context.Context, cpuInfo types.CPUInfo) (types.Resources, error) {
	query := um.vmRepo.GetQuery()
	query.Conditions = append(query.Conditions, repositories.EQ("State", "running"))

	vms, err := um.vmRepo.FindAll(ctx, query)
	if err != nil {
		return types.Resources{}, fmt.Errorf("unable to get running VMs: %w", err)
	}

	var usage types.Resources
	if len(vms) == 0 {
		return usage, nil
	}

	// TODO: disk usage
	var vcpus, memMiB uint
	for _, vm := range vms {
		vcpus += vm.VCPUCount
		memMiB += vm.MemSizeMib
	}
	usage.RAM = uint64(memMiB)
	usage.CPU = float64(vcpus) * cpuInfo.MHzPerCore // CPU in MHz
	return usage, nil
}
// getContainerUsage sums the CPU and RAM requirements of every service whose
// JobStatus is "running". Each service's requirements are looked up with its
// ResourceRequirements value in the map returned by getResourceRequirements
// (which is keyed by RequiredResources.JobID); services without a matching
// entry contribute nothing. Disk usage is not yet accounted for.
//
// NOTE(review): the lookup key (service.ResourceRequirements) and the map key
// (RequiredResources.JobID) have different names — confirm they refer to the
// same identifier.
func (um *defaultUsageMonitor) getContainerUsage(ctx context.Context) (types.Resources, error) {
	q := um.serviceRepo.GetQuery()
	q.Conditions = append(q.Conditions, repositories.EQ("JobStatus", "running"))

	running, err := um.serviceRepo.FindAll(ctx, q)
	if err != nil {
		return types.Resources{}, fmt.Errorf("unable to get running containers: %w", err)
	}

	var usage types.Resources
	if len(running) == 0 {
		return usage, nil
	}

	reqByKey, err := um.getResourceRequirements(ctx)
	if err != nil {
		return types.Resources{}, fmt.Errorf("unable to get resource requirements: %w", err)
	}

	// TODO: disk usage
	for i := range running {
		req := reqByKey[running[i].ResourceRequirements]
		usage.CPU += req.CPU
		usage.RAM += req.RAM
	}
	return usage, nil
}
// getResourceRequirements loads every RequiredResources record and indexes it
// by JobID. When several records share a JobID, the last one returned by the
// repository wins.
func (um *defaultUsageMonitor) getResourceRequirements(ctx context.Context) (map[int]types.RequiredResources, error) {
	all, err := um.requiredResourcesRepo.FindAll(ctx, um.requiredResourcesRepo.GetQuery())
	if err != nil {
		return nil, fmt.Errorf("unable to query resource requirements: %w", err)
	}

	byJobID := make(map[int]types.RequiredResources, len(all))
	for i := range all {
		byJobID[all[i].JobID] = all[i]
	}
	return byJobID, nil
}
// Compile-time check that *defaultUsageMonitor satisfies UsageMonitor.
var _ UsageMonitor = (*defaultUsageMonitor)(nil)
package dms
import (
"gorm.io/gorm"
)
// SanityCheck is a placeholder for start-up consistency checks and currently
// does nothing; its argument is ignored.
//
// Before its body was removed it ran, in this order:
//   - stop and remove services that the database still marked as running,
//     and update their status to 'finished with errors';
//   - recalculate free resources and update the database.
//
// That logic was deleted because its dependencies, such as the docker
// package, have been replaced by executor/docker.
func SanityCheck(_ *gorm.DB) {
	// TODO: sanity check of DMS last exit and correction of invalid states
	// resources.CalcFreeResAndUpdateDB()
}
package background_tasks
import (
"github.com/uptrace/opentelemetry-go-extra/otelzap"
"gitlab.com/nunet/device-management-service/telemetry/logger"
)
// zlog is the package logger for background_tasks; it is assigned in init.
var zlog *otelzap.Logger

// init creates the package logger under the name "background_tasks".
func init() {
	zlog = logger.OtelZapLogger("background_tasks")
}
package background_tasks
import (
"sort"
"sync"
"time"
)
// Scheduler orchestrates the execution of tasks based on their triggers and priority.
//
// Once Start has been called, triggers are evaluated on every tick of a
// one-second ticker (created in NewScheduler) until Stop is called.
type Scheduler struct {
	tasks           map[int]*Task // Map of tasks by their ID.
	runningTasks    map[int]bool  // Map to keep track of running tasks.
	ticker          *time.Ticker  // Ticker for periodic checks of task triggers.
	stopChan        chan struct{} // Channel to signal stopping the scheduler.
	maxRunningTasks int           // Maximum number of tasks that can run concurrently; with a value <= 0 no task is ever started.
	lastTaskID      int           // Counter for assigning unique IDs to tasks.
	mu              sync.Mutex    // Mutex to protect access to task maps.
}
// NewScheduler returns a Scheduler that lets at most maxRunningTasks tasks run
// concurrently. Task IDs are handed out sequentially starting at 0, and the
// one-second ticker that paces trigger evaluation is created immediately.
func NewScheduler(maxRunningTasks int) *Scheduler {
	s := &Scheduler{maxRunningTasks: maxRunningTasks}
	s.tasks = make(map[int]*Task)
	s.runningTasks = make(map[int]bool)
	s.ticker = time.NewTicker(time.Second)
	s.stopChan = make(chan struct{})
	return s
}
// AddTask registers task with the scheduler. It assigns the next sequential
// ID, enables the task and resets each of its triggers, then returns the
// same pointer it was given.
func (s *Scheduler) AddTask(task *Task) *Task {
	s.mu.Lock()
	defer s.mu.Unlock()

	id := s.lastTaskID
	task.ID = id
	task.Enabled = true
	for i := range task.Triggers {
		task.Triggers[i].Reset()
	}
	s.tasks[id] = task
	s.lastTaskID = id + 1
	return task
}
// RemoveTask forgets the task with the given ID; unknown IDs are ignored.
// It only deletes the map entry and does not interrupt a run in progress.
func (s *Scheduler) RemoveTask(taskID int) {
	s.mu.Lock()
	delete(s.tasks, taskID)
	s.mu.Unlock()
}
// Start launches a background goroutine that calls runTasks on every ticker
// tick until the stop channel is closed by Stop. Start itself returns
// immediately.
func (s *Scheduler) Start() {
	stop, ticks := s.stopChan, s.ticker.C
	go func() {
		for {
			select {
			case <-ticks:
				s.runTasks()
			case <-stop:
				return
			}
		}
	}()
}
// runningTasksCount reports how many tasks are currently flagged as running.
// It takes s.mu itself, so it must not be called with the lock held.
func (s *Scheduler) runningTasksCount() int {
	s.mu.Lock()
	defer s.mu.Unlock()

	n := 0
	for _, running := range s.runningTasks {
		if !running {
			continue
		}
		n++
	}
	return n
}
// runTasks evaluates the triggers of every enabled task that is not already
// running, highest Priority first. For each task that has a ready trigger it
// starts a runTask goroutine, as long as fewer than maxRunningTasks tasks are
// running, and then Resets that trigger. Tasks without any trigger are
// removed.
//
// The whole pass holds s.mu. s.tasks and s.runningTasks are also written by
// AddTask, RemoveTask and the runTask goroutines, so touching them without the
// lock is a data race. Goroutines started here simply wait for the lock until
// this pass returns.
func (s *Scheduler) runTasks() {
	s.mu.Lock()
	defer s.mu.Unlock()

	// Sort tasks by priority.
	sortedTasks := make([]*Task, 0, len(s.tasks))
	for _, task := range s.tasks {
		sortedTasks = append(sortedTasks, task)
	}
	sort.Slice(sortedTasks, func(i, j int) bool {
		return sortedTasks[i].Priority > sortedTasks[j].Priority
	})

	// Count running tasks once. No runTask goroutine can change the map while
	// the lock is held, so only the tasks started below need to be added.
	running := 0
	for _, isRunning := range s.runningTasks {
		if isRunning {
			running++
		}
	}

	for _, task := range sortedTasks {
		if !task.Enabled || s.runningTasks[task.ID] {
			continue
		}
		if len(task.Triggers) == 0 {
			delete(s.tasks, task.ID)
			continue
		}
		for _, trigger := range task.Triggers {
			// Check capacity before IsReady: IsReady may consume state (e.g.
			// an EventTrigger's channel value), which would otherwise be lost
			// whenever the task could not be started.
			if running < s.maxRunningTasks && trigger.IsReady() {
				s.runningTasks[task.ID] = true
				running++
				go s.runTask(task.ID)
				trigger.Reset()
				break
			}
		}
	}
}
// Stop signals the loop started by Start to exit and stops the ticker so it
// no longer fires. Tasks that are already running are not interrupted.
// Calling Stop more than once is safe; only the first call closes the stop
// channel.
func (s *Scheduler) Stop() {
	s.mu.Lock()
	defer s.mu.Unlock()

	s.ticker.Stop()
	select {
	case <-s.stopChan:
		// Already stopped: closing the channel again would panic.
	default:
		close(s.stopChan)
	}
}
// runTask executes the task identified by taskID. It attempts the task up to
// RetryPolicy.MaxRetries+1 times and records the outcome ("SUCCESS" or
// "FAILED", plus the last error message) in the task's ExecutionHist. When it
// returns, the task is no longer flagged as running.
//
// If the task was removed before this goroutine looked it up, nothing is
// executed. The execution record is appended to the task object only. The
// task is not re-inserted into s.tasks, so a RemoveTask issued while it was
// running is honoured.
func (s *Scheduler) runTask(taskID int) {
	defer func() {
		s.mu.Lock()
		defer s.mu.Unlock()
		// A missing entry reads as "not running", same as false, and deleting
		// keeps entries for removed tasks from accumulating.
		delete(s.runningTasks, taskID)
	}()

	// s.tasks is shared with AddTask, RemoveTask and runTasks; read it under the lock.
	s.mu.Lock()
	task, ok := s.tasks[taskID]
	s.mu.Unlock()
	if !ok || task == nil {
		return
	}

	execution := Execution{StartedAt: time.Now()}
	defer func() {
		s.mu.Lock()
		defer s.mu.Unlock()
		task.ExecutionHist = append(task.ExecutionHist, execution)
	}()

	for i := 0; i < task.RetryPolicy.MaxRetries+1; i++ {
		err := runTaskWithRetry(task.Function, task.Args, task.RetryPolicy.Delay)
		if err == nil {
			execution.Status = "SUCCESS"
			execution.EndedAt = time.Now()
			return
		}
		execution.Error = err.Error()
	}
	execution.Status = "FAILED"
	execution.EndedAt = time.Now()
}
// runTaskWithRetry performs a single attempt of fn, passing args as its
// argument. On failure it sleeps for delay before returning the error, so a
// caller looping over attempts is paced between them. Despite the name it
// does not retry on its own; the attempt loop lives in runTask.
func runTaskWithRetry(
	fn func(args interface{}) error,
	args []interface{},
	delay time.Duration,
) error {
	err := fn(args)
	if err == nil {
		return nil
	}
	// Back off before handing the error back so the next attempt is delayed.
	time.Sleep(delay)
	return err
}
package background_tasks
import (
"time"
"github.com/robfig/cron/v3"
)
// Trigger interface defines a method to check if a trigger condition is met.
//
// The Scheduler calls Reset on every trigger when a task is added. It polls
// IsReady on each tick, and calls Reset on a trigger right after that trigger
// has caused its task to start.
type Trigger interface {
	IsReady() bool // Returns true if the trigger condition is met.
	Reset()        // Resets the trigger state.
}
// PeriodicTrigger triggers at regular intervals or based on a cron expression.
//
// Both schedules are measured from lastTriggered, which Reset sets to the
// current time. CronExpr, when set, is parsed with cron.ParseStandard on
// every IsReady call.
type PeriodicTrigger struct {
	Interval      time.Duration // Interval for periodic triggering.
	CronExpr      string        // Cron expression for triggering (parsed by cron.ParseStandard).
	lastTriggered time.Time     // Last time the trigger was activated.
}
// IsReady reports whether the trigger should fire now.
//
// Without a CronExpr, the trigger is ready once more than Interval has passed
// since the last firing.
//
// With a CronExpr, it is ready when either
//   - Interval is positive and more than Interval has passed since the last
//     firing, or
//   - the cron schedule's next activation after the last firing is already
//     in the past.
//
// A zero (or negative) Interval therefore disables the interval check for
// cron triggers. Previously, a zero Interval made the interval check succeed
// on every call, so cron-only triggers fired on every scheduler tick.
// An unparsable CronExpr is logged and treated as "not ready".
func (t *PeriodicTrigger) IsReady() bool {
	now := time.Now()

	if t.CronExpr == "" {
		return t.lastTriggered.Add(t.Interval).Before(now)
	}

	// Trigger based on interval, only when one was actually configured.
	if t.Interval > 0 && t.lastTriggered.Add(t.Interval).Before(now) {
		return true
	}

	// Trigger based on cron expression.
	schedule, err := cron.ParseStandard(t.CronExpr)
	if err != nil {
		zlog.Sugar().Errorf("Error parsing CronExpr: %v", err)
		return false
	}
	return schedule.Next(t.lastTriggered).Before(now)
}
// Reset records the current time as the moment the trigger last fired; both
// the interval and the cron schedule are measured from it.
func (t *PeriodicTrigger) Reset() {
	now := time.Now()
	t.lastTriggered = now
}
// EventTrigger triggers based on an external event signaled through a channel.
// Each value sent on Trigger, whether true or false, makes IsReady return true
// exactly once.
type EventTrigger struct {
	Trigger chan bool // Channel to signal an event.
}

// IsReady performs a non-blocking receive on t.Trigger and reports whether a
// value was received, consuming it. A nil channel is never ready; a closed
// channel is always ready.
func (t *EventTrigger) IsReady() bool {
	select {
	case <-t.Trigger:
		return true
	default:
	}
	return false
}

// Reset is a no-op: an EventTrigger's state is driven entirely by the sender.
func (t *EventTrigger) Reset() {}
// OneTimeTrigger becomes ready once Delay has elapsed since it was last Reset.
//
// NOTE(review): nothing here prevents it from firing again. Scheduler.runTasks
// calls Reset right after a trigger fires, which re-arms it, so a task using
// it would presumably run every Delay — confirm intended usage.
type OneTimeTrigger struct {
	Delay        time.Duration // The delay after which to trigger.
	registeredAt time.Time     // Time when the trigger was set.
}

// Reset (re)arms the trigger by recording the current time.
func (t *OneTimeTrigger) Reset() {
	t.registeredAt = time.Now()
}

// IsReady reports whether strictly more than Delay has passed since the last
// Reset. A trigger that was never Reset counts from the zero time, so it is
// ready immediately.
func (t *OneTimeTrigger) IsReady() bool {
	deadline := t.registeredAt.Add(t.Delay)
	return time.Now().After(deadline)
}
package libp2p
import (
"bytes"
"context"
"errors"
"fmt"
"strings"
"sync"
"github.com/gin-gonic/gin"
"github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/peerstore"
"github.com/multiformats/go-multiaddr"
commonproto "gitlab.com/nunet/device-management-service/proto/generated/v1/common"
"google.golang.org/protobuf/proto"
)
// Bootstrap prepares the DHT and then connects to every bootstrap peer in
// parallel, waiting until all connection attempts have finished. Failures to
// parse or reach an individual bootstrap peer are only logged. The returned
// error is non-nil only when the DHT itself fails to bootstrap.
func (l *Libp2p) Bootstrap(ctx context.Context, bootstrapPeers []multiaddr.Multiaddr) error {
	if err := l.DHT.Bootstrap(ctx); err != nil {
		return fmt.Errorf("failed to prepare this node for bootstraping: %w", err)
	}

	// bootstrap all nodes at the same time.
	var wg sync.WaitGroup
	wg.Add(len(bootstrapPeers))
	for _, addr := range bootstrapPeers {
		go func(peerAddr multiaddr.Multiaddr) {
			defer wg.Done()

			addrInfo, err := peer.AddrInfoFromP2pAddr(peerAddr)
			if err != nil {
				zlog.Sugar().Errorf("failed to convert multi addr to addr info %v - %v", peerAddr, err)
				return
			}

			id := addrInfo.ID.String()
			if err := l.Host.Connect(ctx, *addrInfo); err != nil {
				zlog.Sugar().Errorf("failed to connect to bootstrap node %s - %v", id, err)
				return
			}
			zlog.Sugar().Infof("connected to Bootstrap Node %s", id)
		}(addr)
	}
	wg.Wait()
	return nil
}
// dhtValidator validates records stored under this node's DHT namespace; see
// Validate for the checks performed.
type dhtValidator struct {
	// PS is the host peerstore. NOTE(review): not read by Validate or Select.
	PS peerstore.Peerstore
	// customNamespace is the prefix every validated key must start with; an
	// empty value makes the prefix check accept any key.
	// NOTE(review): NewHost builds dhtValidator{PS: ps} without setting it.
	customNamespace string
}
// Validate validates an item placed into the dht.
//
// An empty value is accepted unconditionally (treated as a deletion). A
// non-empty value is accepted only if all of the following hold:
//   - key starts with d.customNamespace (an empty customNamespace matches any key);
//   - value unmarshals as a protobuf commonproto.Advertisement;
//   - envelope.PublicKey is a secp256k1 public key and envelope.Signature
//     verifies under it over PeerId || byte(Timestamp) || Data || PublicKey.
//
// NOTE(review): byte(envelope.Timestamp) keeps only the low 8 bits of the
// (presumably integer) timestamp, so the rest of it is not covered by the
// signature. The signing side presumably builds the same byte string, so
// any change here must be made there in lockstep.
// NOTE(review): nothing here checks that PublicKey corresponds to PeerId or
// to the key being written — confirm whether that is enforced elsewhere.
func (d dhtValidator) Validate(key string, value []byte) error {
	// empty value is considered deleting an item from the dht
	if len(value) == 0 {
		return nil
	}
	if !strings.HasPrefix(key, d.customNamespace) {
		return errors.New("invalid key namespace")
	}
	// verify signature
	var envelope commonproto.Advertisement
	err := proto.Unmarshal(value, &envelope)
	if err != nil {
		return fmt.Errorf("failed to unmarshal envelope: %w", err)
	}
	pubKey, err := crypto.UnmarshalSecp256k1PublicKey(envelope.PublicKey)
	if err != nil {
		return fmt.Errorf("failed to unmarshal public key: %w", err)
	}
	// Signed payload: plain concatenation (nil separator) of these fields.
	concatenatedBytes := bytes.Join([][]byte{
		[]byte(envelope.PeerId),
		{byte(envelope.Timestamp)},
		envelope.Data,
		envelope.PublicKey,
	}, nil)
	ok, err := pubKey.Verify(concatenatedBytes, envelope.Signature)
	if err != nil {
		return fmt.Errorf("failed to verify envelope: %w", err)
	}
	if !ok {
		return errors.New("failed to verify envelope, public key didn't sign payload")
	}
	return nil
}
// Select always chooses the first candidate value (index 0) and never fails;
// both the key and the values are ignored.
func (dhtValidator) Select(_ string, _ [][]byte) (int, error) {
	return 0, nil
}
// TODO remove the below when network package is fully implemented

// UpdateKadDHT is a placeholder: it only logs a warning and does not touch
// the DHT.
func (l *Libp2p) UpdateKadDHT() {
	const msg = "UpdateKadDHT: Stub"
	zlog.Warn(msg)
}
// ListKadDHTPeers is a placeholder: it logs a warning and returns no peers
// and no error. Both arguments are ignored.
func (l *Libp2p) ListKadDHTPeers(_ *gin.Context, _ context.Context) ([]string, error) {
	zlog.Warn("ListKadDHTPeers: Stub")
	var peers []string
	return peers, nil
}
package libp2p
import (
"context"
"fmt"
"os"
"github.com/libp2p/go-libp2p/core/discovery"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
dutil "github.com/libp2p/go-libp2p/p2p/discovery/util"
)
// DiscoverDialPeers looks up peers advertised at the configured rendezvous
// point and dials the useful ones. If the lookup returns nothing, the
// previously discovered set is kept. Entries without addresses, and this host
// itself, are then dropped, and every remaining peer that is not yet connected
// is dialled. Only a failed lookup is returned as an error.
func (l *Libp2p) DiscoverDialPeers(ctx context.Context) error {
	found, err := l.findPeersFromRendezvousDiscovery(ctx)
	if err != nil {
		return err
	}
	// Presumably kept so that an empty lookup does not wipe known peers.
	if len(found) != 0 {
		l.discoveredPeers = found
	}

	// filter out peers with no listening addresses and self host
	selfFilter := NoAddrIDFilter{ID: l.Host.ID()}
	l.discoveredPeers = PeerPassFilter(l.discoveredPeers, selfFilter)

	l.dialPeers(ctx)
	return nil
}
// advertiseForRendezvousDiscovery advertises this node through the DHT-backed
// discovery service under the configured rendezvous string. The TTL returned
// by Advertise is discarded.
func (l *Libp2p) advertiseForRendezvousDiscovery(ctx context.Context) error {
	if _, err := l.discovery.Advertise(ctx, l.config.Rendezvous); err != nil {
		return err
	}
	return nil
}
// findPeersFromRendezvousDiscovery queries the discovery service for peers
// advertising the configured rendezvous string. It passes
// config.PeerCountDiscoveryLimit as the discovery.Limit option.
func (l *Libp2p) findPeersFromRendezvousDiscovery(ctx context.Context) ([]peer.AddrInfo, error) {
	limit := discovery.Limit(l.config.PeerCountDiscoveryLimit)
	found, err := dutil.FindPeers(ctx, l.discovery, l.config.Rendezvous, limit)
	if err != nil {
		return nil, fmt.Errorf("failed to discover peers: %w", err)
	}
	return found, nil
}
// dialPeers dials every peer in l.discoveredPeers that is not this host and
// is not already connected. Dial failures are skipped. Outcomes are logged at
// debug level only when the NUNET_DEBUG_VERBOSE environment variable is set.
func (l *Libp2p) dialPeers(ctx context.Context) {
	// Loop-invariant lookups: done once instead of once (or twice) per peer.
	_, verbose := os.LookupEnv("NUNET_DEBUG_VERBOSE")
	self := l.Host.ID()
	nw := l.Host.Network()

	for _, p := range l.discoveredPeers {
		if p.ID == self || nw.Connectedness(p.ID) == network.Connected {
			continue
		}
		if _, err := nw.DialPeer(ctx, p.ID); err != nil {
			if verbose {
				zlog.Sugar().Debugf("couldn't establish connection with: %s - error: %v", p.ID.String(), err)
			}
			continue
		}
		if verbose {
			zlog.Sugar().Debugf("connected with: %s", p.ID.String())
		}
	}
}
package libp2p
import (
"github.com/libp2p/go-libp2p/core/connmgr"
"github.com/libp2p/go-libp2p/core/control"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/multiformats/go-multiaddr"
mafilt "github.com/whyrusleeping/multiaddr-filter"
)
// defaultServerFilters lists special-purpose IPv4/IPv6 ranges that are not
// publicly routable, as /ipcidr multiaddr masks. When config.Server is set,
// NewHost neither announces addresses in these ranges nor dials or accepts
// connections from them. Loopback ranges are not included.
var defaultServerFilters = []string{
	"/ip4/10.0.0.0/ipcidr/8",      // RFC 1918 private
	"/ip4/100.64.0.0/ipcidr/10",   // RFC 6598 shared (carrier-grade NAT)
	"/ip4/169.254.0.0/ipcidr/16",  // link-local
	"/ip4/172.16.0.0/ipcidr/12",   // RFC 1918 private
	"/ip4/192.0.0.0/ipcidr/24",    // IETF protocol assignments
	"/ip4/192.0.2.0/ipcidr/24",    // TEST-NET-1 (documentation)
	"/ip4/192.168.0.0/ipcidr/16",  // RFC 1918 private
	"/ip4/198.18.0.0/ipcidr/15",   // benchmarking
	"/ip4/198.51.100.0/ipcidr/24", // TEST-NET-2 (documentation)
	"/ip4/203.0.113.0/ipcidr/24",  // TEST-NET-3 (documentation)
	"/ip4/240.0.0.0/ipcidr/4",     // reserved
	"/ip6/100::/ipcidr/64",        // discard-only
	"/ip6/2001:2::/ipcidr/48",     // benchmarking
	"/ip6/2001:db8::/ipcidr/32",   // documentation
	"/ip6/fc00::/ipcidr/7",        // unique local addresses
	"/ip6/fe80::/ipcidr/10",       // link-local
}
// PeerFilter is an interface for filtering peers
// satisfaction of filter criteria allows the peer to pass
// (see PeerPassFilter). The method is unexported, so only types in this
// package can implement it.
type PeerFilter interface {
	satisfies(p peer.AddrInfo) bool
}
// NoAddrIDFilter filters out peers with no listening addresses
// and a peer with a specific ID
type NoAddrIDFilter struct {
	// ID is the peer to exclude (DiscoverDialPeers passes the local host's ID).
	ID peer.ID
}

// satisfies lets p pass only if its ID differs from f.ID and it has at least
// one address.
func (f NoAddrIDFilter) satisfies(p peer.AddrInfo) bool {
	if p.ID == f.ID {
		return false
	}
	return len(p.Addrs) != 0
}
// PeerPassFilter returns the peers that satisfy pf, in their original order.
// The input slice is not modified, and the result is nil when no peer passes.
func PeerPassFilter(peers []peer.AddrInfo, pf PeerFilter) []peer.AddrInfo {
	var passed []peer.AddrInfo
	for i := range peers {
		if !pf.satisfies(peers[i]) {
			continue
		}
		passed = append(passed, peers[i])
	}
	return passed
}
// filtersConnectionGater adapts a multiaddr.Filters list to the libp2p
// ConnectionGater interface. Outbound dials to, and connections from,
// addresses blocked by the filters are refused. Peer-level and post-upgrade
// checks always allow.
type filtersConnectionGater multiaddr.Filters

var _ connmgr.ConnectionGater = (*filtersConnectionGater)(nil)

// InterceptAddrDial refuses dials to addresses blocked by the filters.
func (f *filtersConnectionGater) InterceptAddrDial(_ peer.ID, addr multiaddr.Multiaddr) (allow bool) {
	filters := (*multiaddr.Filters)(f)
	allow = !filters.AddrBlocked(addr)
	return allow
}

// InterceptPeerDial never refuses on the basis of the peer ID.
func (f *filtersConnectionGater) InterceptPeerDial(_ peer.ID) (allow bool) {
	return true
}

// InterceptAccept refuses inbound connections from blocked remote addresses.
func (f *filtersConnectionGater) InterceptAccept(connAddr network.ConnMultiaddrs) (allow bool) {
	filters := (*multiaddr.Filters)(f)
	allow = !filters.AddrBlocked(connAddr.RemoteMultiaddr())
	return allow
}

// InterceptSecured re-checks the remote address once the connection is secured.
func (f *filtersConnectionGater) InterceptSecured(_ network.Direction, _ peer.ID, connAddr network.ConnMultiaddrs) (allow bool) {
	filters := (*multiaddr.Filters)(f)
	allow = !filters.AddrBlocked(connAddr.RemoteMultiaddr())
	return allow
}

// InterceptUpgraded always allows upgraded connections.
func (f *filtersConnectionGater) InterceptUpgraded(_ network.Conn) (allow bool, reason control.DisconnectReason) {
	return true, 0
}
// makeAddrsFactory returns an address factory (for libp2p.AddrsFactory) that
// decides which of the host's addresses are advertised.
//
//   - announce: when non-empty, these addresses replace the host's own
//     addresses entirely.
//   - appendAnnounce: extra addresses advertised in addition to the above.
//     Entries already in announce, or repeated within appendAnnounce, are
//     skipped.
//   - noAnnounce: addresses that are never advertised. Each entry is either
//     a mask accepted by mafilt.NewMask (e.g. "/ip4/10.0.0.0/ipcidr/8") or
//     an exact multiaddr.
//
// It returns nil if any announce, appendAnnounce or exact noAnnounce entry is
// not a valid multiaddr.
//
// The returned function never appends to the slice it is given, so it cannot
// overwrite the caller's backing array.
func makeAddrsFactory(announce []string, appendAnnounce []string, noAnnounce []string) func([]multiaddr.Multiaddr) []multiaddr.Multiaddr {
	existing := make(map[string]bool, len(announce)+len(appendAnnounce)) // To avoid duplicates

	annAddrs := make([]multiaddr.Multiaddr, 0, len(announce))
	for _, addr := range announce {
		maddr, err := multiaddr.NewMultiaddr(addr)
		if err != nil {
			return nil
		}
		annAddrs = append(annAddrs, maddr)
		existing[addr] = true
	}

	var appendAnnAddrs []multiaddr.Multiaddr
	for _, addr := range appendAnnounce {
		if existing[addr] {
			// skip AppendAnnounce entries already on the Announce list (or seen before)
			continue
		}
		maddr, err := multiaddr.NewMultiaddr(addr)
		if err != nil {
			return nil
		}
		appendAnnAddrs = append(appendAnnAddrs, maddr)
		existing[addr] = true
	}

	filters := multiaddr.NewFilters()
	noAnnAddrs := make(map[string]bool, len(noAnnounce))
	for _, addr := range noAnnounce {
		// CIDR-style entries become deny filters; anything else must be an exact multiaddr.
		if mask, err := mafilt.NewMask(addr); err == nil {
			filters.AddFilter(*mask, multiaddr.ActionDeny)
			continue
		}
		maddr, err := multiaddr.NewMultiaddr(addr)
		if err != nil {
			return nil
		}
		noAnnAddrs[string(maddr.Bytes())] = true
	}

	return func(allAddrs []multiaddr.Multiaddr) []multiaddr.Multiaddr {
		base := allAddrs
		if len(annAddrs) > 0 {
			base = annAddrs
		}

		var out []multiaddr.Multiaddr
		// Iterate both groups instead of append(base, ...), which could write
		// into spare capacity of the caller's allAddrs slice.
		for _, group := range [][]multiaddr.Multiaddr{base, appendAnnAddrs} {
			for _, maddr := range group {
				// exact matches first, then /ipcidr matches
				if noAnnAddrs[string(maddr.Bytes())] || filters.AddrBlocked(maddr) {
					continue
				}
				out = append(out, maddr)
			}
		}
		return out
	}
}
package libp2p
import (
"errors"
"sync"
"github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/protocol"
"gitlab.com/nunet/device-management-service/types"
)
// StreamHandler is a function type that processes data from a stream.
// It has the same shape as network.StreamHandler and is converted to it when
// installed on the host.
type StreamHandler func(stream network.Stream)

// HandlerRegistry manages the registration of stream handlers for different protocols.
type HandlerRegistry struct {
	// host is the libp2p host on which stream handlers are installed.
	host host.Host
	// handlers records protocols registered via RegisterHandlerWithStreamCallback.
	handlers map[protocol.ID]StreamHandler
	// bytesHandlers maps a protocol to the callback that receives message
	// bytes (registered via RegisterHandlerWithBytesCallback and used by
	// SendMessageToLocalHandler).
	bytesHandlers map[protocol.ID]func(data []byte)
	// mu is taken by the registry's methods around map accesses.
	mu sync.RWMutex
}
// NewHandlerRegistry returns an empty registry whose handlers will be
// installed on h.
func NewHandlerRegistry(h host.Host) *HandlerRegistry {
	r := &HandlerRegistry{host: h}
	r.handlers = make(map[protocol.ID]StreamHandler)
	r.bytesHandlers = make(map[protocol.ID]func(data []byte))
	return r
}
// RegisterHandlerWithStreamCallback installs handler on the host for the
// protocol named by messageType and records it in the registry.
//
// It returns an error if the protocol already has a handler registered
// through either this method or RegisterHandlerWithBytesCallback. Both install
// a host stream handler for the same protocol ID, so a second registration
// would otherwise silently replace the first.
func (r *HandlerRegistry) RegisterHandlerWithStreamCallback(messageType types.MessageType, handler StreamHandler) error {
	r.mu.Lock()
	defer r.mu.Unlock()

	protoID := protocol.ID(messageType)
	_, streamTaken := r.handlers[protoID]
	_, bytesTaken := r.bytesHandlers[protoID]
	if streamTaken || bytesTaken {
		return errors.New("stream with this protocol is already registered")
	}

	r.handlers[protoID] = handler
	r.host.SetStreamHandler(protoID, network.StreamHandler(handler))
	return nil
}
// RegisterHandlerWithBytesCallback installs s as the host stream handler for
// the protocol named by messageType. It also records handler as the callback
// that receives message bytes for that protocol; SendMessageToLocalHandler
// uses it, and presumably s does too.
//
// It returns an error if the protocol already has a handler registered
// through either this method or RegisterHandlerWithStreamCallback. Both install
// a host stream handler for the same protocol ID, so a later registration
// would otherwise silently replace the earlier one.
func (r *HandlerRegistry) RegisterHandlerWithBytesCallback(messageType types.MessageType, s StreamHandler, handler func(data []byte)) error {
	r.mu.Lock()
	defer r.mu.Unlock()

	protoID := protocol.ID(messageType)
	_, bytesTaken := r.bytesHandlers[protoID]
	_, streamTaken := r.handlers[protoID]
	if bytesTaken || streamTaken {
		return errors.New("stream with this protocol is already registered")
	}

	r.bytesHandlers[protoID] = handler
	r.host.SetStreamHandler(protoID, network.StreamHandler(s))
	return nil
}
// SendMessageToLocalHandler delivers data to the bytes handler registered for
// messageType, if any, without going through the network. The handler runs in
// a new goroutine, so the caller is never blocked by it. When no handler is
// registered, the data is dropped silently.
func (r *HandlerRegistry) SendMessageToLocalHandler(messageType types.MessageType, data []byte) {
	// Read-only map access: a shared lock suffices and doesn't serialize concurrent senders.
	r.mu.RLock()
	h, ok := r.bytesHandlers[protocol.ID(messageType)]
	r.mu.RUnlock()
	if !ok {
		return
	}
	// we need this goroutine to avoid blocking the caller goroutine
	go h(data)
}
package libp2p
import (
"context"
"fmt"
"strings"
"time"
"github.com/libp2p/go-libp2p"
dht "github.com/libp2p/go-libp2p-kad-dht"
pubsub "github.com/libp2p/go-libp2p-pubsub"
"github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/pnet"
"github.com/libp2p/go-libp2p/core/protocol"
"github.com/libp2p/go-libp2p/core/routing"
"github.com/libp2p/go-libp2p/p2p/host/autorelay"
"github.com/libp2p/go-libp2p/p2p/host/peerstore/pstoremem"
"github.com/libp2p/go-libp2p/p2p/net/connmgr"
"github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay"
"github.com/libp2p/go-libp2p/p2p/security/noise"
libp2ptls "github.com/libp2p/go-libp2p/p2p/security/tls"
quic "github.com/libp2p/go-libp2p/p2p/transport/quic"
"github.com/libp2p/go-libp2p/p2p/transport/tcp"
ws "github.com/libp2p/go-libp2p/p2p/transport/websocket"
webtransport "github.com/libp2p/go-libp2p/p2p/transport/webtransport"
"github.com/multiformats/go-multiaddr"
"github.com/spf13/afero"
mafilt "github.com/whyrusleeping/multiaddr-filter"
"gitlab.com/nunet/device-management-service/types"
)
// NewHost returns a new libp2p host with dht and other related settings, plus
// a GossipSub router.
//
// Configuration applied:
//   - connection manager with low/high water marks 100/400 and a grace
//     period of config.GracePeriodMs milliseconds;
//   - DHT in server mode, with config.DHTPrefix as protocol prefix and
//     dhtValidator registered for config.CustomNamespace (slashes removed);
//   - with config.PrivateNetwork.WithSwarmKey: a private network using the
//     swarm key loaded from fs, and the process-wide pnet.ForcePrivateNetwork
//     set to true. Otherwise QUIC and WebTransport are added and
//     pnet.ForcePrivateNetwork is set to false;
//   - TCP and WebSocket transports, TLS and Noise security, NAT service,
//     relay client and a resource-limited relay service, hole punching, and
//     auto-relay fed from the package-level newPeer channel;
//   - with config.Server: addresses in defaultServerFilters are neither
//     announced nor dialed/accepted. Otherwise NAT port mapping is enabled.
//
// libp2p constructs the DHT through the Routing option, so the returned
// *dht.IpfsDHT is only set once libp2p.New has invoked it. If GossipSub
// cannot be created, the DHT and host are closed before the error is
// returned, so they are not leaked.
func NewHost(ctx context.Context, config *types.Libp2pConfig, fs afero.Fs) (host.Host, *dht.IpfsDHT, *pubsub.PubSub, error) {
	var idht *dht.IpfsDHT

	// Named connManager so the connmgr package is not shadowed.
	connManager, err := connmgr.NewConnManager(
		100,
		400,
		connmgr.WithGracePeriod(time.Duration(config.GracePeriodMs)*time.Millisecond),
	)
	if err != nil {
		return nil, nil, nil, err
	}

	filter := multiaddr.NewFilters()
	for _, s := range defaultServerFilters {
		f, err := mafilt.NewMask(s)
		if err != nil {
			// f is nil on error; dereferencing it would panic.
			zlog.Sugar().Errorf("incorrectly formatted address filter in config: %s - %v", s, err)
			continue
		}
		filter.AddFilter(*f, multiaddr.ActionDeny)
	}

	ps, err := pstoremem.NewPeerstore()
	if err != nil {
		return nil, nil, nil, err
	}

	var libp2pOpts []libp2p.Option
	baseOpts := []dht.Option{
		dht.ProtocolPrefix(protocol.ID(config.DHTPrefix)),
		// NOTE(review): customNamespace is left empty here, so the key-prefix
		// check in dhtValidator.Validate accepts every key — confirm intended.
		dht.NamespacedValidator(strings.ReplaceAll(config.CustomNamespace, "/", ""), dhtValidator{PS: ps}),
		dht.Mode(dht.ModeServer),
	}

	if config.PrivateNetwork.WithSwarmKey {
		psk, err := configureSwarmKey(fs)
		if err != nil {
			return nil, nil, nil, fmt.Errorf("failed to configure swarm key: %w", err)
		}
		libp2pOpts = append(libp2pOpts, libp2p.PrivateNetwork(psk))
		// guarantee that outer connection will be refused
		pnet.ForcePrivateNetwork = true
	} else {
		// enable quic (it does not work with pnet enabled)
		libp2pOpts = append(libp2pOpts, libp2p.Transport(quic.NewTransport))
		libp2pOpts = append(libp2pOpts, libp2p.Transport(webtransport.New))
		// for some reason, ForcePrivateNetwork was equal to true even without being set to true
		pnet.ForcePrivateNetwork = false
	}

	libp2pOpts = append(libp2pOpts, libp2p.ListenAddrStrings(config.ListenAddress...),
		libp2p.Identity(config.PrivateKey),
		libp2p.Routing(func(h host.Host) (routing.PeerRouting, error) {
			// Local err: don't write the enclosing function's err from inside libp2p.New.
			d, err := dht.New(ctx, h, baseOpts...)
			if err != nil {
				return nil, err
			}
			idht = d
			return d, nil
		}),
		libp2p.Peerstore(ps),
		libp2p.Security(libp2ptls.ID, libp2ptls.New),
		libp2p.Security(noise.ID, noise.New),
		// Do not use DefaulTransports as we can not enable Quic when pnet
		libp2p.Transport(tcp.NewTCPTransport),
		libp2p.Transport(ws.New),
		libp2p.EnableNATService(),
		libp2p.ConnectionManager(connManager),
		libp2p.EnableRelay(),
		libp2p.EnableHolePunching(),
		libp2p.EnableRelayService(
			relay.WithResources(
				relay.Resources{
					MaxReservations:        256,
					MaxCircuits:            32,
					BufferSize:             4096,
					MaxReservationsPerPeer: 8,
					MaxReservationsPerIP:   16,
				},
			),
			relay.WithLimit(&relay.RelayLimit{
				Duration: 5 * time.Minute,
				Data:     1 << 21, // 2 MiB
			}),
		),
		libp2p.EnableAutoRelayWithPeerSource(
			func(ctx context.Context, num int) <-chan peer.AddrInfo {
				r := make(chan peer.AddrInfo)
				go func() {
					defer close(r)
					for i := 0; i < num; i++ {
						select {
						case p := <-newPeer:
							select {
							case r <- p:
							case <-ctx.Done():
								return
							}
						case <-ctx.Done():
							return
						}
					}
				}()
				return r
			},
			autorelay.WithBootDelay(time.Minute),
			autorelay.WithBackoff(30*time.Second),
			autorelay.WithMinCandidates(2),
			autorelay.WithMaxCandidates(3),
			autorelay.WithNumRelays(2),
		),
	)

	if config.Server {
		libp2pOpts = append(libp2pOpts, libp2p.AddrsFactory(makeAddrsFactory([]string{}, []string{}, defaultServerFilters)))
		libp2pOpts = append(libp2pOpts, libp2p.ConnectionGater((*filtersConnectionGater)(filter)))
	} else {
		libp2pOpts = append(libp2pOpts, libp2p.NATPortMap())
	}

	// Named p2pHost so the host package is not shadowed.
	p2pHost, err := libp2p.New(libp2pOpts...)
	if err != nil {
		return nil, nil, nil, err
	}

	optsPS := []pubsub.Option{pubsub.WithMessageSigning(true), pubsub.WithMaxMessageSize(config.GossipMaxMessageSize)}
	gossip, err := pubsub.NewGossipSub(ctx, p2pHost, optsPS...)
	// gossip, err := pubsub.NewGossipSubWithRouter(ctx, host, pubsub.DefaultGossipSubRouter(host), optsPS...)
	if err != nil {
		// Best-effort cleanup so the DHT and host don't leak; the GossipSub
		// error is the one reported to the caller.
		if idht != nil {
			_ = idht.Close()
		}
		_ = p2pHost.Close()
		return nil, nil, nil, err
	}
	return p2pHost, idht, gossip, nil
}
package libp2p
import (
"github.com/libp2p/go-libp2p/core/peer"
"github.com/uptrace/opentelemetry-go-extra/otelzap"
"gitlab.com/nunet/device-management-service/telemetry/logger"
)
const (
	// Custom namespace for DHT protocol with version number
	// NOTE(review): NewHost takes its namespace from config.CustomNamespace,
	// not from this constant — confirm where this constant is used.
	customNamespace = "/nunet-dht-1/"
)
// TODO: pass the logger to the constructor and remove from here
var (
	// zlog is the package logger; it is assigned in init.
	zlog *otelzap.Logger
	// newPeer is an unbuffered channel that NewHost's auto-relay peer source
	// reads candidate relay peers from.
	// NOTE(review): no sender is visible here — confirm where peers are published.
	newPeer = make(chan peer.AddrInfo)
)

// init creates the package logger under the name "network.libp2p".
func init() {
	zlog = logger.OtelZapLogger("network.libp2p")
}
package libp2p
import (
"context"
"bufio"
"bytes"
"crypto/sha256"
"encoding/binary"
"errors"
"fmt"
"io"
"strings"
"sync"
"time"
"github.com/ipfs/go-cid"
kbucket "github.com/libp2p/go-libp2p-kbucket"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/p2p/protocol/ping"
"gitlab.com/nunet/device-management-service/types"
pubsub "github.com/libp2p/go-libp2p-pubsub"
"github.com/multiformats/go-multiaddr"
"github.com/multiformats/go-multihash"
"github.com/spf13/afero"
"google.golang.org/protobuf/proto"
dht "github.com/libp2p/go-libp2p-kad-dht"
libp2pdiscovery "github.com/libp2p/go-libp2p/core/discovery"
"github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peerstore"
"github.com/libp2p/go-libp2p/core/protocol"
drouting "github.com/libp2p/go-libp2p/p2p/discovery/routing"
bt "gitlab.com/nunet/device-management-service/internal/background_tasks"
commonproto "gitlab.com/nunet/device-management-service/proto/generated/v1/common"
)
const (
	// MB is the number of bytes in a mebibyte (1024 * 1024).
	MB = 1024 * 1024
	// MaxMessageLengthMB is, presumably, the upper bound on a message's size in
	// units of MB — confirm against its users.
	MaxMessageLengthMB = 10
)
// Libp2p contains the configuration for a Libp2p instance.
//
// It is created by New, populated by Init (host, DHT, pubsub, discovery and
// handler registry) and started by Start.
//
// TODO-suggestion: maybe we should call it something else like Libp2pPeer,
// Libp2pHost or just Peer (callers would use libp2p.Peer...)
type Libp2p struct {
	// Host is the libp2p host created in Init.
	Host host.Host
	// DHT is the Kademlia DHT created alongside the host in Init.
	DHT *dht.IpfsDHT
	// PS is the host's peerstore (Host.Peerstore()), set in Init.
	PS peerstore.Peerstore
	// pubsub is the GossipSub router created in Init.
	pubsub *pubsub.PubSub
	// pubsubTopics and topicSubscription start empty (see New); presumably
	// keyed by topic name and guarded by topicMux — confirm.
	pubsubTopics      map[string]*pubsub.Topic
	topicSubscription map[string]*pubsub.Subscription
	topicMux          sync.RWMutex
	// a list of peers discovered by discovery
	discoveredPeers []peer.AddrInfo
	// discovery is a routing discovery built on the DHT in Init.
	discovery libp2pdiscovery.Discovery
	// services
	// NOTE(review): pingService is not assigned by New or Init.
	pingService *ping.PingService
	// tasks
	// discoveryTask is the periodic peer-discovery task registered by Start.
	discoveryTask *bt.Task
	// handlerRegistry tracks protocol handlers installed on Host; set in Init.
	handlerRegistry *HandlerRegistry
	config          *types.Libp2pConfig
	// dependencies (db, filesystem...)
	fs afero.Fs
}
// New validates config and returns an uninitialised Libp2p instance; Init
// must be called before Start. Both config and config.Scheduler must be
// non-nil.
//
// TODO-Suggestion: move types.Libp2pConfig to here for better readability.
// Unless there is a reason to keep within types.
func New(config *types.Libp2pConfig, fs afero.Fs) (*Libp2p, error) {
	switch {
	case config == nil:
		return nil, errors.New("config is nil")
	case config.Scheduler == nil:
		return nil, errors.New("scheduler is nil")
	}

	l := &Libp2p{
		config: config,
		fs:     fs,
	}
	l.discoveredPeers = make([]peer.AddrInfo, 0)
	l.pubsubTopics = make(map[string]*pubsub.Topic)
	l.topicSubscription = make(map[string]*pubsub.Subscription)
	return l, nil
}
// Init creates the host, DHT and GossipSub router via NewHost. From them it
// derives the peerstore, a routing discovery over the DHT, and the handler
// registry. A NewHost error is logged and returned unchanged.
func (l *Libp2p) Init(ctx context.Context) error {
	h, kad, gossip, err := NewHost(ctx, l.config, l.fs)
	if err != nil {
		zlog.Sugar().Error(err)
		return err
	}

	l.Host, l.DHT, l.pubsub = h, kad, gossip
	l.PS = h.Peerstore()
	l.discovery = drouting.NewRoutingDiscovery(kad)
	l.handlerRegistry = NewHandlerRegistry(h)
	return nil
}
// Start performs the following steps in order:
//   - registers the stream handlers;
//   - bootstraps the DHT and connects to config.BootstrapPeers;
//   - advertises this node at the rendezvous point;
//   - runs one discovery/dial round;
//   - schedules DiscoverDialPeers every 15 minutes on config.Scheduler and
//     starts the scheduler.
//
// Only a bootstrap failure aborts Start; advertise and discovery failures are
// logged and ignored.
func (l *Libp2p) Start(ctx context.Context) error {
	l.registerStreamHandlers()

	// bootstrap should return error if it had an error
	if err := l.Bootstrap(ctx, l.config.BootstrapPeers); err != nil {
		zlog.Sugar().Errorf("failed to start network: %v", err)
		return err
	}

	if err := l.advertiseForRendezvousDiscovery(ctx); err != nil {
		// TODO: the error might be misleading as a peer can normally work well if an error
		// is returned here (e.g.: the error is yielded in tests even though all tests pass).
		zlog.Sugar().Errorf("failed to start network with randevouz discovery: %v", err)
	}

	if err := l.DiscoverDialPeers(ctx); err != nil {
		zlog.Sugar().Errorf("failed to discover peers: %v", err)
	}

	// NOTE(review): the periodic task reuses ctx; if ctx is cancelled later,
	// scheduled discoveries may keep failing — confirm ctx is long-lived.
	l.discoveryTask = l.config.Scheduler.AddTask(&bt.Task{
		Name:        "Peer Discovery",
		Description: "Periodic task to discover new peers every 15 minutes",
		Function: func(interface{}) error {
			return l.DiscoverDialPeers(ctx)
		},
		Triggers: []bt.Trigger{&bt.PeriodicTrigger{Interval: 15 * time.Minute}},
	})
	l.config.Scheduler.Start()
	return nil
}
// RegisterStreamMessageHandler installs handler for the protocol named by
// messageType. It rejects an empty messageType and wraps any registry error
// with the message type.
func (l *Libp2p) RegisterStreamMessageHandler(messageType types.MessageType, handler StreamHandler) error {
	if messageType == "" {
		return errors.New("message type is empty")
	}
	err := l.handlerRegistry.RegisterHandlerWithStreamCallback(messageType, handler)
	if err == nil {
		return nil
	}
	return fmt.Errorf("failed to register handler %s: %w", messageType, err)
}
// RegisterBytesMessageHandler installs handleReadBytesFromStream for the
// protocol named by messageType and records handler as the callback that
// receives each message's bytes. It rejects an empty messageType and wraps any
// registry error with the message type.
func (l *Libp2p) RegisterBytesMessageHandler(messageType types.MessageType, handler func(data []byte)) error {
	if messageType == "" {
		return errors.New("message type is empty")
	}
	err := l.handlerRegistry.RegisterHandlerWithBytesCallback(messageType, l.handleReadBytesFromStream, handler)
	if err == nil {
		return nil
	}
	return fmt.Errorf("failed to register handler %s: %w", messageType, err)
}
// HandleMessage is a string-typed convenience wrapper around
// RegisterBytesMessageHandler.
func (l *Libp2p) HandleMessage(messageType string, handler func(data []byte)) error {
	mt := types.MessageType(messageType)
	return l.RegisterBytesMessageHandler(mt, handler)
}
// handleReadBytesFromStream reads one length-prefixed message from s and
// passes its payload to the bytes handler registered for s.Protocol().
//
// Wire format: an 8-byte little-endian unsigned length, followed by exactly
// that many payload bytes. The length comes from the remote peer, so messages
// declaring more than MaxMessageLengthMB*MB bytes are dropped before anything
// is allocated. Read errors also drop the message. The stream is always closed
// on return.
func (l *Libp2p) handleReadBytesFromStream(s network.Stream) {
	defer s.Close()

	// bytesHandlers is written by the Register* methods under mu; read it under the lock too.
	l.handlerRegistry.mu.RLock()
	callback, ok := l.handlerRegistry.bytesHandlers[s.Protocol()]
	l.handlerRegistry.mu.RUnlock()
	if !ok {
		return
	}

	c := bufio.NewReader(s)

	// read the first 8 bytes to determine the size of the message.
	// io.ReadFull is required: a single Read may return fewer than 8 bytes.
	msgLengthBuffer := make([]byte, 8)
	if _, err := io.ReadFull(c, msgLengthBuffer); err != nil {
		return
	}

	lengthPrefix := binary.LittleEndian.Uint64(msgLengthBuffer)
	if lengthPrefix > MaxMessageLengthMB*MB {
		zlog.Sugar().Warnf("dropping message on protocol %s: declared length %d exceeds limit of %d bytes",
			s.Protocol(), lengthPrefix, MaxMessageLengthMB*MB)
		return
	}

	// read the full message
	buf := make([]byte, lengthPrefix)
	if _, err := io.ReadFull(c, buf); err != nil {
		return
	}
	callback(buf)
}
// SendMessage sends msg to every address in addrs concurrently and waits for
// all sends to finish. It returns nil if every send succeeded, and the error
// itself if exactly one failed. Otherwise it returns an error whose message
// joins the individual messages with "; ", in the order the failures were
// reported.
func (l *Libp2p) SendMessage(ctx context.Context, addrs []string, msg types.MessageEnvelope) error {
	errs := make(chan error, len(addrs))

	var wg sync.WaitGroup
	wg.Add(len(addrs))
	for _, a := range addrs {
		go func(target string) {
			defer wg.Done()
			if err := l.sendMessage(ctx, target, msg); err != nil {
				errs <- err
			}
		}(a)
	}
	wg.Wait()
	close(errs)

	var combined error
	for err := range errs {
		if combined != nil {
			combined = fmt.Errorf("%v; %v", combined, err)
			continue
		}
		combined = err
	}
	return combined
}
// OpenStream parses addr as a multiaddress containing a peer ID, connects to
// that peer and opens a new stream for the protocol named by messageType. The
// caller is responsible for the returned stream, including closing it.
func (l *Libp2p) OpenStream(ctx context.Context, addr string, messageType types.MessageType) (network.Stream, error) {
	maddr, err := multiaddr.NewMultiaddr(addr)
	if err != nil {
		return nil, fmt.Errorf("invalid multiaddress: %w", err)
	}

	info, err := peer.AddrInfoFromP2pAddr(maddr)
	if err != nil {
		return nil, fmt.Errorf("could not resolve peer info: %w", err)
	}

	if err = l.Host.Connect(ctx, *info); err != nil {
		return nil, fmt.Errorf("failed to connect to peer: %w", err)
	}

	s, err := l.Host.NewStream(ctx, info.ID, protocol.ID(messageType))
	if err != nil {
		return nil, fmt.Errorf("failed to open stream: %w", err)
	}
	return s, nil
}
// GetMultiaddr returns this host's addresses, each combined with the host's
// peer ID into a full /p2p multiaddress.
func (l *Libp2p) GetMultiaddr() ([]multiaddr.Multiaddr, error) {
	self := peer.AddrInfo{ID: l.Host.ID(), Addrs: l.Host.Addrs()}
	return peer.AddrInfoToP2pAddrs(&self)
}
// Stop performs a cleanup of any resources used in this package. It removes
// the periodic discovery task (if Start registered one), then closes the DHT
// and the host. Components that were never created are skipped, so Stop is
// safe after New alone or after Init without Start. Close errors are collected
// and returned joined with "; ".
func (l *Libp2p) Stop() error {
	var errorMessages []string

	// discoveryTask is only set by Start.
	if l.discoveryTask != nil {
		l.config.Scheduler.RemoveTask(l.discoveryTask.ID)
	}
	// DHT and Host are only set by a successful Init.
	if l.DHT != nil {
		if err := l.DHT.Close(); err != nil {
			errorMessages = append(errorMessages, err.Error())
		}
	}
	if l.Host != nil {
		if err := l.Host.Close(); err != nil {
			errorMessages = append(errorMessages, err.Error())
		}
	}

	if len(errorMessages) > 0 {
		return errors.New(strings.Join(errorMessages, "; "))
	}
	return nil
}
// Stat reports this host's peer ID and its addresses joined with ", ".
func (l *Libp2p) Stat() types.NetworkStats {
	addrs := l.Host.Addrs()
	listen := make([]string, 0, len(addrs))
	for _, a := range addrs {
		listen = append(listen, a.String())
	}
	return types.NetworkStats{
		ID:         l.Host.ID().String(),
		ListenAddr: strings.Join(listen, ", "),
	}
}
// Ping sends a libp2p ping to the peer whose encoded ID is peerIDAddress and
// waits for the first result, bounded by timeout. Pinging this node's own ID
// is rejected up front.
//
// On failure the error is reported both in PingResult.Error and as the
// second return value.
//
// TODO (Return error once): returning the error twice (as a PingResult field
// and as the return value) is confusing for callers.
func (l *Libp2p) Ping(ctx context.Context, peerIDAddress string, timeout time.Duration) (types.PingResult, error) {
	// Refuse to dial ourselves.
	if l.Host.ID().String() == peerIDAddress {
		selfErr := errors.New("can't ping self")
		return types.PingResult{Success: false, Error: selfErr}, selfErr
	}

	pingCtx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()

	target, err := peer.Decode(peerIDAddress)
	if err != nil {
		return types.PingResult{}, err
	}

	results := ping.Ping(pingCtx, l.Host, target)
	select {
	case <-pingCtx.Done():
		ctxErr := pingCtx.Err()
		return types.PingResult{Error: ctxErr}, ctxErr
	case r := <-results:
		if r.Error == nil {
			return types.PingResult{RTT: r.RTT, Success: true}, nil
		}
		zlog.Sugar().Errorf("failed to ping peer %s: %v", peerIDAddress, r.Error)
		return types.PingResult{Success: false, RTT: r.RTT, Error: r.Error}, r.Error
	}
}
// ResolveAddress returns the p2p multiaddrs (as strings) of the peer whose
// encoded ID is id. This node's own ID is answered locally via GetMultiaddr;
// any other peer is looked up in the DHT with a fixed 3-second timeout.
func (l *Libp2p) ResolveAddress(ctx context.Context, id string) ([]string, error) {
	pid, err := peer.Decode(id)
	if err != nil {
		return nil, fmt.Errorf("failed to resolve invalid peer: %w", err)
	}

	toStrings := func(addrs []multiaddr.Multiaddr) []string {
		out := make([]string, len(addrs))
		for i := range addrs {
			out[i] = addrs[i].String()
		}
		return out
	}

	// Resolving ourselves needs no network round-trip.
	if id == l.Host.ID().String() {
		own, selfErr := l.GetMultiaddr()
		if selfErr != nil {
			return nil, fmt.Errorf("failed to resolve self: %w", selfErr)
		}
		return toStrings(own), nil
	}

	lookupCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
	defer cancel()

	found, err := l.DHT.FindPeer(lookupCtx, pid)
	if err != nil {
		return nil, fmt.Errorf("failed to resolve address %s: %w", id, err)
	}
	p2pAddrs, err := peer.AddrInfoToP2pAddrs(&peer.AddrInfo{ID: found.ID, Addrs: found.Addrs})
	if err != nil {
		return nil, fmt.Errorf("failed to convert to p2p address: %w", err)
	}
	return toStrings(p2pAddrs), nil
}
// Query returns all the advertisements in the network related to key.
//
// The DHT is first asked for the providers of the CID derived from key
// (peers we aren't connected to can be returned too). For each provider, the
// advertisement stored under its custom namespace is fetched and decoded.
// Providers whose record cannot be fetched or decoded are skipped, so one
// unreachable or misbehaving peer does not hide everybody else's
// advertisements. The result is nil when no advertisement could be read.
func (l *Libp2p) Query(ctx context.Context, key string) ([]*commonproto.Advertisement, error) {
	if key == "" {
		return nil, errors.New("advertisement key is empty")
	}
	customCID, err := createCIDFromKey(key)
	if err != nil {
		return nil, fmt.Errorf("failed to create cid for key %s: %w", key, err)
	}
	providers, err := l.DHT.FindProviders(ctx, customCID)
	if err != nil {
		return nil, fmt.Errorf("failed to find providers for key %s: %w", key, err)
	}

	var advertisements []*commonproto.Advertisement
	for _, provider := range providers {
		// TODO: use go routines to get the values in parallel.
		raw, err := l.DHT.GetValue(ctx, l.getCustomNamespace(key, provider.ID.String()))
		if err != nil {
			// Best effort: a provider may not hold (or serve) a record.
			continue
		}
		ad := new(commonproto.Advertisement)
		if err := proto.Unmarshal(raw, ad); err != nil {
			// A malformed record from one remote peer must not fail the whole query.
			zlog.Sugar().Warnf("skipping undecodable advertisement for key %s from peer %s: %v",
				key, provider.ID.String(), err)
			continue
		}
		advertisements = append(advertisements, ad)
	}
	return advertisements, nil
}
// Advertise publishes data under key in the DHT.
//
// The payload is wrapped in a signed commonproto.Advertisement envelope
// (peer ID, Unix timestamp, data, public key, signature). The envelope is
// stored under this node's custom namespace for key, and this node is then
// announced as a provider of the CID derived from key, which is how Query
// finds it.
func (l *Libp2p) Advertise(ctx context.Context, key string, data []byte) error {
	if key == "" {
		return errors.New("advertisement key is empty")
	}

	pubKeyBytes, err := l.getPublicKey()
	if err != nil {
		return fmt.Errorf("failed to get public key: %w", err)
	}
	envelope := &commonproto.Advertisement{
		PeerId:    l.Host.ID().String(),
		Timestamp: time.Now().Unix(),
		Data:      data,
		PublicKey: pubKeyBytes,
	}

	// Signed content: peer ID || timestamp || data || public key.
	// NOTE(review): byte(...) keeps only the lowest byte of the timestamp, so
	// the signature barely covers it. Changing this alters the signed payload
	// and must be coordinated with whatever verifies these signatures.
	signedContent := bytes.Join([][]byte{
		[]byte(envelope.PeerId),
		{byte(envelope.Timestamp)},
		envelope.Data,
		pubKeyBytes,
	}, nil)
	if envelope.Signature, err = l.sign(signedContent); err != nil {
		return fmt.Errorf("failed to sign advertisement envelope content: %w", err)
	}

	payload, err := proto.Marshal(envelope)
	if err != nil {
		return fmt.Errorf("failed to marshal advertise envelope: %w", err)
	}
	providerCID, err := createCIDFromKey(key)
	if err != nil {
		return fmt.Errorf("failed to create cid for key %s: %w", key, err)
	}

	recordKey := l.getCustomNamespace(key, l.DHT.PeerID().String())
	if err := l.DHT.PutValue(ctx, recordKey, payload); err != nil {
		return fmt.Errorf("failed to put key %s into the dht: %w", key, err)
	}
	if err := l.DHT.Provide(ctx, providerCID, true); err != nil {
		return fmt.Errorf("failed to provide key %s into the dht: %w", key, err)
	}
	return nil
}
// Unadvertise overwrites this node's DHT record for key with an empty value.
// It does not withdraw the provider announcement made by Advertise.
func (l *Libp2p) Unadvertise(ctx context.Context, key string) error {
	recordKey := l.getCustomNamespace(key, l.DHT.PeerID().String())
	if err := l.DHT.PutValue(ctx, recordKey, nil); err != nil {
		return fmt.Errorf("failed to remove key %s from the DHT: %w", key, err)
	}
	return nil
}
// Publish publishes data to topic, joining the topic first if needed.
// Only one topic handler is kept per topic (see getOrJoinTopicHandler).
func (l *Libp2p) Publish(ctx context.Context, topic string, data []byte) error {
	t, err := l.getOrJoinTopicHandler(topic)
	if err != nil {
		return fmt.Errorf("failed to publish: %w", err)
	}
	if err = t.Publish(ctx, data); err != nil {
		return fmt.Errorf("failed to publish to topic %s: %w", topic, err)
	}
	return nil
}
// Subscribe subscribes to topic and calls handler with the payload of every
// message received on it.
//
// Messages are received in a dedicated goroutine that exits when ctx is done
// or the subscription is cancelled (e.g. by Unsubscribe). The subscription is
// recorded per topic. Subscribing again to the same topic replaces the
// recorded subscription without cancelling the previous one.
func (l *Libp2p) Subscribe(ctx context.Context, topic string, handler func(data []byte)) error {
	topicHandler, err := l.getOrJoinTopicHandler(topic)
	if err != nil {
		return fmt.Errorf("failed to subscribe to topic: %w", err)
	}
	sub, err := topicHandler.Subscribe()
	if err != nil {
		return fmt.Errorf("failed to subscribe to topic %s: %w", topic, err)
	}

	l.topicMux.Lock()
	l.topicSubscription[topic] = sub
	l.topicMux.Unlock()

	go func() {
		for {
			msg, err := sub.Next(ctx)
			if err != nil {
				// Next only fails once ctx is done or the subscription was
				// cancelled. Both conditions are permanent, so retrying
				// would spin forever.
				zlog.Sugar().Debugf("subscription to topic %s ended: %v", topic, err)
				return
			}
			handler(msg.Data)
		}
	}()
	return nil
}
// sendMessage delivers a single message envelope to the peer identified by
// the p2p multiaddr addr.
//
// A message addressed to this node is handed to the locally registered
// handler without touching the network. Otherwise, the payload is framed as an
// 8-byte little-endian length followed by msg.Data. The frame is written on a
// new stream opened with msg.Type as protocol ID. Frames larger than
// MaxMessageLengthMB*MB bytes are rejected before any connection attempt.
func (l *Libp2p) sendMessage(ctx context.Context, addr string, msg types.MessageEnvelope) error {
	peerAddr, err := multiaddr.NewMultiaddr(addr)
	if err != nil {
		return fmt.Errorf("invalid multiaddr %s: %w", addr, err)
	}
	peerInfo, err := peer.AddrInfoFromP2pAddr(peerAddr)
	if err != nil {
		return fmt.Errorf("failed to get peer info %s: %w", addr, err)
	}

	// Delivering to ourselves: dispatch to the handler previously registered
	// for this message type instead of going through the network.
	if peerInfo.ID == l.Host.ID() {
		l.handlerRegistry.SendMessageToLocalHandler(msg.Type, msg.Data)
		return nil
	}

	// Validate the frame size before dialing, so an oversized message does
	// not cost a connection and a stream.
	frameSize := 8 + len(msg.Data)
	if frameSize > MaxMessageLengthMB*MB {
		return fmt.Errorf("message size %d is greater than limit %d bytes", frameSize, MaxMessageLengthMB*MB)
	}

	if err := l.Host.Connect(ctx, *peerInfo); err != nil {
		return fmt.Errorf("failed to connect to peer %v: %w", peerInfo.ID, err)
	}
	stream, err := l.Host.NewStream(ctx, peerInfo.ID, protocol.ID(msg.Type))
	if err != nil {
		return fmt.Errorf("failed to open stream to peer %v: %w", peerInfo.ID, err)
	}
	defer stream.Close()

	frame := make([]byte, frameSize)
	binary.LittleEndian.PutUint64(frame, uint64(len(msg.Data)))
	copy(frame[8:], msg.Data)
	if _, err := stream.Write(frame); err != nil {
		return fmt.Errorf("failed to send message to peer %v: %w", peerInfo.ID, err)
	}
	return nil
}
// getOrJoinTopicHandler returns the pubsub handle for topic, joining the topic
// and caching the handle on first use. Publish and Subscribe both need the
// handle, hence this shared helper. Lookup and join happen while holding
// topicMux, so two callers of this helper never join the same topic twice.
func (l *Libp2p) getOrJoinTopicHandler(topic string) (*pubsub.Topic, error) {
	l.topicMux.Lock()
	defer l.topicMux.Unlock()

	if existing, found := l.pubsubTopics[topic]; found {
		return existing, nil
	}
	joined, err := l.pubsub.Join(topic)
	if err != nil {
		return nil, fmt.Errorf("failed to join topic %s: %w", topic, err)
	}
	l.pubsubTopics[topic] = joined
	return joined, nil
}
// Unsubscribe cancels the recorded subscription to topic (if any) and closes
// the topic handle. It fails if the topic was never joined. If closing the
// handle fails, the topic remains registered.
func (l *Libp2p) Unsubscribe(topic string) error {
	l.topicMux.Lock()
	defer l.topicMux.Unlock()

	handle, joined := l.pubsubTopics[topic]
	if !joined {
		return fmt.Errorf("not subscribed to topic: %s", topic)
	}

	// drop the subscription first so its receive loop stops
	if sub, subscribed := l.topicSubscription[topic]; subscribed {
		sub.Cancel()
		delete(l.topicSubscription, topic)
	}

	if closeErr := handle.Close(); closeErr != nil {
		return fmt.Errorf("failed to close topic handler: %w", closeErr)
	}
	delete(l.pubsubTopics, topic)
	return nil
}
// VisiblePeers returns the peers recorded in discoveredPeers (presumably
// filled by peer discovery — confirm). The internal slice is returned as-is,
// not a copy.
func (l *Libp2p) VisiblePeers() []peer.AddrInfo {
	discovered := l.discoveredPeers
	return discovered
}
// KnownPeers lists every peer present in the host's peerstore. Only the ID
// field of each returned AddrInfo is populated; the error is always nil.
func (l *Libp2p) KnownPeers() ([]peer.AddrInfo, error) {
	ids := l.Host.Peerstore().Peers()
	result := make([]peer.AddrInfo, len(ids))
	for i, id := range ids {
		result[i] = peer.AddrInfo{ID: id}
	}
	return result, nil
}
// DumpDHTRoutingTable returns the peer information currently held in the
// DHT's routing table. The error is always nil.
func (l *Libp2p) DumpDHTRoutingTable() ([]kbucket.PeerInfo, error) {
	return l.DHT.RoutingTable().GetPeerInfos(), nil
}
// registerStreamHandlers creates the ping service for this host and installs
// its handler for the "/ipfs/ping/1.0.0" protocol.
func (l *Libp2p) registerStreamHandlers() {
	svc := ping.NewPingService(l.Host)
	l.pingService = svc
	l.Host.SetStreamHandler(protocol.ID("/ipfs/ping/1.0.0"), svc.PingHandler)
}
// sign signs data with the host's private key, as stored in its peerstore.
func (l *Libp2p) sign(data []byte) ([]byte, error) {
	key := l.Host.Peerstore().PrivKey(l.Host.ID())
	if key == nil {
		return nil, errors.New("private key not found for the host")
	}
	sig, err := key.Sign(data)
	if err != nil {
		return nil, fmt.Errorf("failed to sign data: %w", err)
	}
	return sig, nil
}
// getPublicKey returns the raw bytes of the public half of the host's
// private key (looked up in the peerstore).
func (l *Libp2p) getPublicKey() ([]byte, error) {
	key := l.Host.Peerstore().PrivKey(l.Host.ID())
	if key == nil {
		return nil, errors.New("private key not found for the host")
	}
	return key.GetPublic().Raw()
}
// getCustomNamespace builds the DHT record key used for a node's advertisement
// of key: "<CustomNamespace>-<key>-<peerID>" (each part formatted with %s).
func (l *Libp2p) getCustomNamespace(key, peerID string) string {
return fmt.Sprintf("%s-%s-%s", l.config.CustomNamespace, key, peerID)
}
// createCIDFromKey derives a deterministic CIDv1 with the raw codec from key:
// the SHA-256 digest of key is wrapped in a SHA2-256 multihash.
func createCIDFromKey(key string) (cid.Cid, error) {
	digest := sha256.Sum256([]byte(key))
	encoded, err := multihash.Encode(digest[:], multihash.SHA2_256)
	if err != nil {
		return cid.Cid{}, err
	}
	return cid.NewCidV1(cid.Raw, encoded), nil
}
// CleanupPeer is a placeholder: it logs a warning and reports success without
// doing anything.
func CleanupPeer(id peer.ID) error {
	zlog.Warn("CleanupPeer: Stub")
	return nil
}

// PingPeer is a placeholder: it logs a warning and always reports failure
// with a nil result.
func PingPeer(ctx context.Context, target peer.ID) (bool, *ping.Result) {
	zlog.Warn("PingPeer: Stub")
	var result *ping.Result
	return false, result
}

// DumpKademliaDHT is a placeholder: it logs a warning and returns no peers
// and no error.
func DumpKademliaDHT(ctx context.Context) ([]types.PeerData, error) {
	zlog.Warn("DumpKademliaDHT: Stub")
	var peers []types.PeerData
	return peers, nil
}

// OldPingPeer is a placeholder: it logs a warning and always reports failure
// with a nil result.
func OldPingPeer(ctx context.Context, target peer.ID) (bool, *types.PingResult) {
	zlog.Warn("OldPingPeer: Stub")
	var result *types.PingResult
	return false, result
}
package libp2p
import (
"bytes"
"encoding/base64"
"fmt"
"os"
"path/filepath"
"github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/pnet"
"github.com/spf13/afero"
)
/*
** Swarm key **
By default, the swarm key is stored in a file named `swarm.key`
using the path-based V1 PSK codec, one component per line:
	/key/swarm/psk/1.0.0/
	/<base_encoding>/
	<256_bits_key>
where `<base_encoding>` is either bin, base16 or base64.
*/
// getBasePath returns nunet's local configuration directory, `<home>/.nunet`.
// The fs argument is currently unused.
//
// TODO-pnet-1: we shouldn't handle configuration paths here, a general configuration path
// should be provided by /internal/config.go
func getBasePath(fs afero.Fs) (string, error) {
	home, err := os.UserHomeDir()
	if err != nil {
		return "", fmt.Errorf("error getting home directory: %w", err)
	}
	return filepath.Join(home, ".nunet"), nil
}
// configureSwarmKey loads the swarm key from `<config_path>/swarm.key`.
// A new key is generated and persisted only when that file does not exist.
//
// An existing file that cannot be read or decoded is reported as an error
// rather than overwritten. Replacing it silently would move this node to a
// different private network.
//
// TODO-ask: should we continue to generate a new swarm key if one is not found?
// Or we should enforce the user to use some cmd/API rpc to generate a new one?
func configureSwarmKey(fs afero.Fs) (pnet.PSK, error) {
	basePath, err := getBasePath(fs)
	if err != nil {
		return nil, fmt.Errorf("failed to locate swarm key: %w", err)
	}
	exists, err := afero.Exists(fs, filepath.Join(basePath, "swarm.key"))
	if err != nil {
		return nil, fmt.Errorf("failed to check for existing swarm key: %w", err)
	}
	if !exists {
		psk, err := generateSwarmKey(fs)
		if err != nil {
			return nil, fmt.Errorf("failed to generate new swarm key: %w", err)
		}
		return psk, nil
	}
	psk, err := getSwarmKey(fs)
	if err != nil {
		return nil, fmt.Errorf("failed to load existing swarm key: %w", err)
	}
	return psk, nil
}
// getSwarmKey reads the swarm key file `<config_path>/swarm.key` and decodes
// it as a V1 PSK.
func getSwarmKey(fs afero.Fs) (pnet.PSK, error) {
	// A failure here must not be ignored: an empty base path would silently
	// resolve "swarm.key" relative to the working directory.
	basePath, err := getBasePath(fs)
	if err != nil {
		return nil, fmt.Errorf("failed to locate swarm key file: %w", err)
	}
	swarmkey, err := afero.ReadFile(fs, filepath.Join(basePath, "swarm.key"))
	if err != nil {
		return nil, fmt.Errorf("failed to read swarm key file: %w", err)
	}
	psk, err := pnet.DecodeV1PSK(bytes.NewReader(swarmkey))
	if err != nil {
		return nil, fmt.Errorf("failed to configure private network: %w", err)
	}
	// TODO-ask: should we return psk fingerprint?
	return psk, nil
}
// generateSwarmKey generates a new swarm key and writes it, with mode 0600, to
// `<nunet_config_dir>/swarm.key`, creating the directory if needed. The file
// uses the V1 PSK format, one component per line:
//
//	/key/swarm/psk/1.0.0/
//	/base64/
//	<key>
//
// NOTE(review): the key material is the base64 of a marshalled secp256k1
// private key. pnet.DecodeV1PSK appears to use only the first 32 decoded
// bytes, which then include the marshalling header rather than 32 random
// bytes — confirm and consider writing 32 random bytes instead.
func generateSwarmKey(fs afero.Fs) (pnet.PSK, error) {
	priv, _, err := crypto.GenerateKeyPair(crypto.Secp256k1, 256)
	if err != nil {
		return nil, err
	}
	privBytes, err := crypto.MarshalPrivateKey(priv)
	if err != nil {
		return nil, err
	}
	encodedKey := base64.StdEncoding.EncodeToString(privBytes)
	swarmKeyWithCodec := fmt.Sprintf("/key/swarm/psk/1.0.0/\n/base64/\n%s\n", encodedKey)

	// TODO-pnet-1
	nunetDir, err := getBasePath(fs)
	if err != nil {
		return nil, err
	}
	// On a fresh install the config directory may not exist yet, and
	// WriteFile does not create parent directories.
	if err := fs.MkdirAll(nunetDir, 0700); err != nil {
		return nil, fmt.Errorf("error creating swarm key directory: %w", err)
	}
	swarmKeyPath := filepath.Join(nunetDir, "swarm.key")
	if err := afero.WriteFile(fs, swarmKeyPath, []byte(swarmKeyWithCodec), 0600); err != nil {
		return nil, fmt.Errorf("error writing swarm key to file: %w", err)
	}

	psk, err := pnet.DecodeV1PSK(bytes.NewReader([]byte(swarmKeyWithCodec)))
	if err != nil {
		return nil, fmt.Errorf("failed to decode generated swarm key: %w", err)
	}
	zlog.Sugar().Infof("A new Swarm key was generated and written to %s\n"+
		"IMPORTANT: If you'd like to create the swarm key using a cryptography algorithm "+
		"of your choice, just modify the swarm.key file with your own key.\n"+
		"The content of `swarm.key` should be three lines: `/key/swarm/psk/1.0.0/`, `/<base_encoding>/` and `<your_key>`\n"+
		"where `<base_encoding>` is either `bin`, `base16`, or `base64`.\n",
		swarmKeyPath,
	)
	return psk, nil
}
package network
import (
"context"
"errors"
"fmt"
"time"
"github.com/spf13/afero"
commonproto "gitlab.com/nunet/device-management-service/proto/generated/v1/common"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/network/libp2p"
)
// Messenger defines the interface for sending messages.
type Messenger interface {
// SendMessage sends msg to each of the given addresses (addrs).
SendMessage(ctx context.Context, addrs []string, msg types.MessageEnvelope) error
}
// Network abstracts the node's peer-to-peer transport. Besides sending
// messages (Messenger), it covers lifecycle, peer resolution, advertisements,
// and publish/subscribe.
type Network interface {
	// Messenger embedded interface
	Messenger
	// Init initializes the network.
	Init(ctx context.Context) error
	// Start starts the network.
	Start(ctx context.Context) error
	// Stat returns the network information.
	Stat() types.NetworkStats
	// Ping pings the given address and returns the PingResult.
	Ping(ctx context.Context, address string, timeout time.Duration) (types.PingResult, error)
	// HandleMessage is responsible for registering a message type and its handler.
	HandleMessage(messageType string, handler func(data []byte)) error
	// ResolveAddress, given an id, returns the addresses of the peer.
	// In libp2p, id represents the peerID and the response is the addrinfo.
	ResolveAddress(ctx context.Context, id string) ([]string, error)
	// Advertise advertises the given data under the given key,
	// such as advertising device capabilities on the DHT.
	Advertise(ctx context.Context, key string, data []byte) error
	// Unadvertise stops advertising the data corresponding to the given key.
	Unadvertise(ctx context.Context, key string) error
	// Query returns the network advertisements related to the given key.
	Query(ctx context.Context, key string) ([]*commonproto.Advertisement, error)
	// Publish publishes the given data to the given topic if the network
	// type allows publish/subscribe functionality such as gossipsub or nats.
	Publish(ctx context.Context, topic string, data []byte) error
	// Subscribe subscribes to the given topic and calls the handler function
	// if the network type allows it, similar to Publish().
	Subscribe(ctx context.Context, topic string, handler func(data []byte)) error
	// Unsubscribe unsubscribes from a topic.
	Unsubscribe(topic string) error
	// Stop stops the network, including any existing advertisements and subscriptions.
	Stop() error
}
// NewNetwork returns a new network given the configuration.
//
// Only the libp2p network type is implemented. The NATS type returns a "not
// implemented" error, and any other type is rejected as unsupported. On error,
// the returned Network is always a true nil interface.
func NewNetwork(netConfig *types.NetworkConfig, fs afero.Fs) (Network, error) {
	// TODO: probable additional params to receive: DB, FileSystem
	if netConfig == nil {
		return nil, errors.New("network configuration is nil")
	}
	switch netConfig.Type {
	case types.Libp2pNetwork:
		ln, err := libp2p.New(&netConfig.Libp2pConfig, fs)
		if err != nil {
			// Return an untyped nil: wrapping a nil concrete pointer from
			// libp2p.New in the Network interface would compare non-nil.
			return nil, err
		}
		return ln, nil
	case types.NATSNetwork:
		return nil, errors.New("not implemented")
	default:
		return nil, fmt.Errorf("unsupported network type: %s", netConfig.Type)
	}
}
package basic_controller
import (
"context"
"fmt"
"github.com/spf13/afero"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/storage"
"gitlab.com/nunet/device-management-service/utils"
)
// BasicVolumeController is the default implementation of storage.VolumeController.
// It creates one directory per volume on FS and persists the volumes'
// metadata through a repositories.StorageVolume repository.
type BasicVolumeController struct {
// repo is the repository for storage volume records (create, find, update, delete).
repo repositories.StorageVolume
// basePath is the prefix volume directory names are appended to. Paths are
// built by plain string concatenation, so it is expected to end with `/`.
basePath string
// FS is the file system volume directories live on. It is exported so that
// storage providers (e.g. the S3 one) can write into volumes through the
// same instance.
FS afero.Fs
}
// NewDefaultVolumeController returns a new BasicVolumeController. It stores
// volume records in repo and creates volume directories under volBasePath
// on fs.
//
// Volume paths are built by appending to the base path. A trailing `/` is
// therefore added to volBasePath when missing, so "<base>" and "<base>/"
// behave the same. An empty volBasePath is kept as-is. The error is currently
// always nil.
func NewDefaultVolumeController(repo repositories.StorageVolume, volBasePath string, fs afero.Fs) (*BasicVolumeController, error) {
	if n := len(volBasePath); n > 0 && volBasePath[n-1] != '/' {
		volBasePath += "/"
	}
	return &BasicVolumeController{
		repo:     repo,
		basePath: volBasePath,
		FS:       fs,
	}, nil
}
// CreateVolume creates a new storage volume given a source (S3, IPFS, job, etc).
// The volume is a new empty directory (mode 0770) on the controller's file
// system plus a record in the database. Options are applied to the record
// before it is stored. Volumes start public, writable and unencrypted unless
// an option changes that.
//
// The directory name follows the format `<volSource>-<random 16 chars>` and is
// appended to the controller's base path.
//
// If storing the record fails, the freshly created directory is removed again
// so no orphaned directory is left behind.
//
// TODO-maybe [withName]: allow callers to specify custom name for path
func (vc *BasicVolumeController) CreateVolume(volSource storage.VolumeSource, opts ...storage.CreateVolOpt) (types.StorageVolume, error) {
	vol := types.StorageVolume{
		Private:        false,
		ReadOnly:       false,
		EncryptionType: types.EncryptionTypeNull,
	}
	for _, opt := range opts {
		opt(&vol)
	}

	vol.Path = vc.basePath + string(volSource) + "-" + utils.RandomString(16)
	if err := vc.FS.Mkdir(vol.Path, 0770); err != nil {
		return types.StorageVolume{}, fmt.Errorf("failed to create storage volume: %w", err)
	}

	createdVol, err := vc.repo.Create(context.Background(), vol)
	if err != nil {
		if rmErr := vc.FS.Remove(vol.Path); rmErr != nil {
			return types.StorageVolume{}, fmt.Errorf(
				"failed to create storage volume in repository: %w (cleanup of %s also failed: %v)",
				err, vol.Path, rmErr)
		}
		return types.StorageVolume{}, fmt.Errorf("failed to create storage volume in repository: %w", err)
	}
	return createdVol, nil
}
// LockVolume marks the volume stored at pathToVol as read-only. It should be
// used after all necessary data has been written.
//
// The options (e.g. WithCID, WithPrivate) are applied to the record, ReadOnly
// is set, and the record is updated. Then the volume directory's permissions
// are changed to 0400. The record is updated first, so a failing Chmod leaves
// a record that already says ReadOnly.
//
// NOTE(review): 0400 on a directory removes the execute bit, which on a real
// OS file system prevents non-root users from accessing entries inside it.
// 0500 may be what is intended — confirm before changing.
//
// TODO-maybe [CID]: maybe calculate CID of every volume in case WithCID opt is not provided
func (vc *BasicVolumeController) LockVolume(pathToVol string, opts ...storage.LockVolOpt) error {
	ctx := context.Background()

	q := vc.repo.GetQuery()
	q.Conditions = append(q.Conditions, repositories.EQ("Path", pathToVol))
	vol, err := vc.repo.Find(ctx, q)
	if err != nil {
		return fmt.Errorf("failed to find storage volume with path %s - Error: %w", pathToVol, err)
	}

	for _, apply := range opts {
		apply(&vol)
	}
	vol.ReadOnly = true

	updated, err := vc.repo.Update(ctx, vol.ID, vol)
	if err != nil {
		return fmt.Errorf("failed to update storage volume with path %s - Error: %w", pathToVol, err)
	}

	if chmodErr := vc.FS.Chmod(updated.Path, 0400); chmodErr != nil {
		return fmt.Errorf("failed to make storage volume read-only (path: %s): %w", updated.Path, chmodErr)
	}
	return nil
}
// WithPrivate returns an option that marks a volume as private. The type
// parameter selects whether the option is meant for CreateVolume
// (storage.CreateVolOpt) or LockVolume (storage.LockVolOpt).
func WithPrivate[T storage.CreateVolOpt | storage.LockVolOpt]() T {
	var opt T = func(v *types.StorageVolume) {
		v.Private = true
	}
	return opt
}
// WithCID returns a LockVolume option that records an already calculated CID
// on the volume.
//
// TODO [validate]: check if CID provided is valid
func WithCID(cid string) storage.LockVolOpt {
	setCID := func(v *types.StorageVolume) { v.CID = cid }
	return setCID
}
// DeleteVolume deletes the database record of the volume identified by
// identifier. The identifier is interpreted as a path or a CID according to
// idType. Only the single record returned by the lookup is removed. The
// volume's directory on the file system is left untouched.
//
// Note [CID]: if we start to type CID as cid.CID, we may have to use generics here
// as in `[T string | cid.CID]`
func (vc *BasicVolumeController) DeleteVolume(identifier string, idType storage.IDType) error {
	q := vc.repo.GetQuery()
	switch idType {
	case storage.IDTypePath:
		q.Conditions = append(q.Conditions, repositories.EQ("Path", identifier))
	case storage.IDTypeCID:
		q.Conditions = append(q.Conditions, repositories.EQ("CID", identifier))
	default:
		return fmt.Errorf("identifier type not supported")
	}

	ctx := context.Background()
	vol, err := vc.repo.Find(ctx, q)
	switch {
	case err == repositories.NotFoundError:
		return fmt.Errorf("volume not found: %w", err)
	case err != nil:
		return fmt.Errorf("failed to find volume: %w", err)
	}

	if err = vc.repo.Delete(ctx, vol.ID); err != nil {
		return fmt.Errorf("failed to delete volume: %w", err)
	}
	return nil
}
// ListVolumes returns every storage volume record stored in the database,
// using the repository's default (unfiltered) query.
//
// TODO [filter]: maybe add opts to filter results by certain values
func (vc *BasicVolumeController) ListVolumes() ([]types.StorageVolume, error) {
	all := vc.repo.GetQuery()
	vols, err := vc.repo.FindAll(context.Background(), all)
	if err != nil {
		return nil, fmt.Errorf("failed to list volumes: %w", err)
	}
	return vols, nil
}
// GetSize returns the size of the volume identified by identifier (a path or a
// CID, according to idType). The size is computed by utils.GetDirectorySize
// over the volume's directory on the controller's file system.
//
// TODO-minor: identify which measurement type will be used
func (vc *BasicVolumeController) GetSize(identifier string, idType storage.IDType) (int64, error) {
	q := vc.repo.GetQuery()
	switch idType {
	case storage.IDTypeCID:
		q.Conditions = append(q.Conditions, repositories.EQ("CID", identifier))
	case storage.IDTypePath:
		q.Conditions = append(q.Conditions, repositories.EQ("Path", identifier))
	default:
		return 0, fmt.Errorf("unsupported ID type: %d", idType)
	}

	vol, err := vc.repo.Find(context.Background(), q)
	if err != nil {
		return 0, fmt.Errorf("failed to find volume: %w", err)
	}

	size, sizeErr := utils.GetDirectorySize(vc.FS, vol.Path)
	if sizeErr != nil {
		return 0, fmt.Errorf("failed to get directory size: %w", sizeErr)
	}
	return size, nil
}
// EncryptVolume is meant to encrypt the volume at path. It is not implemented
// yet and always returns an error.
func (vc *BasicVolumeController) EncryptVolume(path string, encryptor types.Encryptor, encryptionType types.EncryptionType) error {
	notImplemented := fmt.Errorf("not implemented")
	return notImplemented
}

// DecryptVolume is meant to decrypt the volume at path. It is not implemented
// yet and always returns an error.
func (vc *BasicVolumeController) DecryptVolume(path string, decryptor types.Decryptor, decryptionType types.EncryptionType) error {
	notImplemented := fmt.Errorf("not implemented")
	return notImplemented
}
// Compile-time check that *BasicVolumeController implements storage.VolumeController.
var _ storage.VolumeController = (*BasicVolumeController)(nil)
package basic_controller
import (
"context"
"fmt"
"os"
"testing"
clover "github.com/ostafen/clover/v2"
"github.com/spf13/afero"
"gitlab.com/nunet/device-management-service/db/repositories/clover"
"gitlab.com/nunet/device-management-service/types"
)
// VolControllerTestSuiteHelper bundles the resources created by
// SetupVolControllerTestSuite. It is built with a positional composite
// literal there, so the field order matters.
type VolControllerTestSuiteHelper struct {
// BasicVolController is the controller under test.
BasicVolController *BasicVolumeController
// Fs is the in-memory file system the controller operates on.
Fs afero.Fs
// DB is the clover database backing the controller's repository.
DB *clover.DB
// Volumes are the volumes seeded at setup, as passed by the caller.
Volumes map[string]*types.StorageVolume
// TempDBDir is the temporary directory holding the database files.
TempDBDir string
}
// SetupVolControllerTestSuite sets up a volume controller backed by a
// temporary clover database and an in-memory file system containing basePath.
// Each given volume gets its directory created on the in-memory file system
// and its record stored in the database.
//
// Teardown (closing the database and removing its temporary directory) is
// registered with t.Cleanup. If setup fails part-way, the resources created so
// far are released before the error is returned.
func SetupVolControllerTestSuite(t *testing.T, basePath string,
	volumes map[string]*types.StorageVolume) (*VolControllerTestSuiteHelper, error) {
	t.Helper()

	tempDir, err := os.MkdirTemp("", "clover-test-*")
	if err != nil {
		return nil, fmt.Errorf("failed to create temp directory: %w", err)
	}
	db, err := repositories_clover.NewDB(tempDir, []string{"storage_volume"})
	if err != nil {
		os.RemoveAll(tempDir)
		return nil, fmt.Errorf("failed to open clover db: %w", err)
	}

	// fail releases the database and its temp directory, then wraps err.
	// RemoveAll (not Remove) is required: the directory holds the DB files,
	// and os.Remove fails on a non-empty directory.
	fail := func(msg string, err error) (*VolControllerTestSuiteHelper, error) {
		db.Close()
		os.RemoveAll(tempDir)
		return nil, fmt.Errorf("%s: %w", msg, err)
	}

	fs := afero.NewMemMapFs()
	if err := fs.MkdirAll(basePath, 0755); err != nil {
		return fail("failed to create base path", err)
	}

	repo := repositories_clover.NewStorageVolume(db)
	vc, err := NewDefaultVolumeController(repo, basePath, fs)
	if err != nil {
		return fail("failed to create volume controller", err)
	}

	for _, vol := range volumes {
		// create root volume dir
		if err := fs.MkdirAll(vol.Path, 0755); err != nil {
			return fail("failed to create volume dir", err)
		}
		// create volume record in db
		if _, err := repo.Create(context.Background(), *vol); err != nil {
			return fail("failed to create volume record", err)
		}
	}

	helper := &VolControllerTestSuiteHelper{vc, fs, db, volumes, tempDir}
	t.Cleanup(func() {
		TearDownVolControllerTestSuite(helper)
	})
	return helper, nil
}
// TearDownVolControllerTestSuite releases what SetupVolControllerTestSuite
// created. It closes the database (when set) and removes the database's
// temporary directory (when set). Errors are ignored.
func TearDownVolControllerTestSuite(helper *VolControllerTestSuiteHelper) {
	if db := helper.DB; db != nil {
		db.Close()
	}
	if dir := helper.TempDBDir; dir != "" {
		os.RemoveAll(dir)
	}
}
package s3
import (
"context"
"fmt"
"os"
"path/filepath"
"strings"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/spf13/afero"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/storage"
"gitlab.com/nunet/device-management-service/storage/basic_controller"
)
// Download fetches files from a given S3 bucket into a new, locked storage
// volume. The key may designate a single object, a directory ending with `/`,
// or a prefix with a trailing wildcard (`*`). Normal S3 folders are handled,
// but x-directory objects are not.
//
// The objects are resolved before the volume is created, so an invalid key or
// bucket does not leave an empty volume behind. The returned value is the
// volume as returned by CreateVolume; it is not refreshed after locking.
//
// Warning: the implementation should rely on the FS provided by the volume controller,
// be careful if managing files with `os` (the volume controller might be
// using an in-memory one)
func (s *S3Storage) Download(ctx context.Context, sourceSpecs *types.SpecConfig) (
	types.StorageVolume, error) {
	source, err := DecodeInputSpec(sourceSpecs)
	if err != nil {
		return types.StorageVolume{}, err
	}

	resolvedObjects, err := resolveStorageKey(ctx, s.Client, &source)
	if err != nil {
		return types.StorageVolume{}, fmt.Errorf("failed to resolve storage key: %w", err)
	}

	storageVol, err := s.volController.CreateVolume(storage.VolumeSourceS3)
	if err != nil {
		return types.StorageVolume{}, fmt.Errorf("failed to create storage volume: %w", err)
	}

	for _, resolvedObject := range resolvedObjects {
		if err := s.downloadObject(ctx, &source, resolvedObject, storageVol.Path); err != nil {
			return types.StorageVolume{}, fmt.Errorf("failed to download s3 object: %w", err)
		}
	}

	// after data is filled within the volume, we have to lock it
	if err := s.volController.LockVolume(storageVol.Path); err != nil {
		return types.StorageVolume{}, fmt.Errorf("failed to lock storage volume: %w", err)
	}
	return storageVol, nil
}
// downloadObject writes one resolved S3 object into the volume at volPath,
// under volPath/<object key>. Directory placeholder objects only produce a
// directory. Files are written through the BasicVolumeController's file
// system, so in-memory file systems are honored.
//
// Object keys come from the bucket and are treated as untrusted: a key that
// would resolve outside volPath (e.g. containing "../") is rejected.
func (s *S3Storage) downloadObject(ctx context.Context, source *S3InputSource,
	object s3Object, volPath string) error {
	// use the same file system instance used by the Volume Controller
	basicVolController, ok := s.volController.(*basic_controller.BasicVolumeController)
	if !ok || basicVolController == nil || basicVolController.FS == nil {
		return fmt.Errorf("unsupported volume controller %T: cannot access its file system", s.volController)
	}
	fs := basicVolController.FS

	outputPath := filepath.Join(volPath, *object.key)
	rel, err := filepath.Rel(volPath, outputPath)
	if err != nil || rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) {
		return fmt.Errorf("s3 object key %q resolves outside of the volume", *object.key)
	}

	if object.isDir {
		// if object is a directory, we don't need to download it (just create the dir)
		if err := fs.MkdirAll(outputPath, 0755); err != nil {
			return fmt.Errorf("failed to create directory: %w", err)
		}
		return nil
	}

	// Create only the parent directories: creating outputPath itself as a
	// directory would make opening it as a file fail.
	if err := fs.MkdirAll(filepath.Dir(outputPath), 0755); err != nil {
		return fmt.Errorf("failed to create directory: %w", err)
	}
	outputFile, err := fs.OpenFile(outputPath, os.O_RDWR|os.O_CREATE, 0755)
	if err != nil {
		return err
	}
	defer outputFile.Close()

	zlog.Sugar().Debugf("Downloading s3 object %s to %s", *object.key, outputPath)
	_, err = s.downloader.Download(ctx, outputFile, &s3.GetObjectInput{
		Bucket:  aws.String(source.Bucket),
		Key:     object.key,
		IfMatch: object.eTag,
	})
	if err != nil {
		return fmt.Errorf("failed to download s3 object %s: %w", *object.key, err)
	}
	return nil
}
// resolveStorageKey returns the S3 objects selected by source.Key. A key
// without a trailing "/" and without any "*" designates a single object.
// Any other key is treated as a prefix. An empty key is an error.
func resolveStorageKey(ctx context.Context, client *s3.Client, source *S3InputSource) ([]s3Object, error) {
	key := source.Key
	switch {
	case key == "":
		return nil, fmt.Errorf("key is required")
	case strings.HasSuffix(key, "/") || strings.Contains(key, "*"):
		// key represents multiple objects
		return resolveObjectsWithPrefix(ctx, client, source)
	default:
		return resolveSingleObject(ctx, client, source)
	}
}
// resolveSingleObject fetches the metadata (HEAD) of the single object named
// by source.Key (after sanitizeKey) and returns it as a one-element list.
// The object's ETag is kept so the download can be pinned to that version.
// Objects with an "application/x-directory" content type are rejected.
func resolveSingleObject(ctx context.Context, client *s3.Client, source *S3InputSource) ([]s3Object, error) {
	key := sanitizeKey(source.Key)
	headObjectOut, err := client.HeadObject(ctx, &s3.HeadObjectInput{
		Bucket: aws.String(source.Bucket),
		Key:    aws.String(key),
	})
	if err != nil {
		return []s3Object{}, fmt.Errorf("failed to retrieve object metadata: %w", err)
	}
	// TODO-minor: validate checksum if provided

	// ContentType and ContentLength are optional pointers in the SDK
	// response; the aws.To* helpers map nil to the zero value.
	if strings.HasPrefix(aws.ToString(headObjectOut.ContentType), "application/x-directory") {
		return []s3Object{}, fmt.Errorf("x-directory is not yet handled!")
	}
	return []s3Object{
		{
			// Use the key the metadata was fetched with, so the later GET
			// targets exactly the same object.
			key:  aws.String(key),
			eTag: headObjectOut.ETag,
			size: aws.ToInt64(headObjectOut.ContentLength),
		},
	}, nil
}
// resolveObjectsWithPrefix lists, page by page, every object whose key starts
// with source.Key after sanitizeKey. sanitizeKey only strips one trailing "*";
// a "*" elsewhere stays part of the literal prefix. Keys ending in "/" are
// marked as directories.
func resolveObjectsWithPrefix(ctx context.Context, client *s3.Client, source *S3InputSource) ([]s3Object, error) {
	key := sanitizeKey(source.Key)
	paginator := s3.NewListObjectsV2Paginator(client, &s3.ListObjectsV2Input{
		Bucket: aws.String(source.Bucket),
		Prefix: aws.String(key),
	})

	var objects []s3Object
	for paginator.HasMorePages() {
		page, err := paginator.NextPage(ctx)
		if err != nil {
			return nil, fmt.Errorf("failed to list objects: %w", err)
		}
		for _, obj := range page.Contents {
			// Key and Size are optional pointers in the SDK types; avoid nil derefs.
			objKey := aws.ToString(obj.Key)
			objects = append(objects, s3Object{
				key:   aws.String(objKey),
				size:  aws.ToInt64(obj.Size),
				isDir: strings.HasSuffix(objKey, "/"),
			})
		}
	}
	return objects, nil
}
package s3
import (
"context"
"strings"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
)
// GetAWSDefaultConfig loads the default AWS configuration, resolved from
// environment variables and the shared configuration and credentials files.
// No extra load options are applied.
func GetAWSDefaultConfig() (aws.Config, error) {
	return config.LoadDefaultConfig(context.Background())
}
// hasValidCredentials reports whether config has a credentials provider that
// currently yields credentials with keys set (credentials.HasKeys). A missing
// provider or a retrieval error yields false instead of panicking or failing.
func hasValidCredentials(config aws.Config) bool {
	if config.Credentials == nil {
		return false
	}
	credentials, err := config.Credentials.Retrieve(context.Background())
	if err != nil {
		return false
	}
	return credentials.HasKeys()
}
// sanitizeKey trims leading and trailing whitespace from key, then drops a
// single trailing "*" wildcard if one is present.
func sanitizeKey(key string) string {
	trimmed := strings.TrimSpace(key)
	if strings.HasSuffix(trimmed, "*") {
		return trimmed[:len(trimmed)-1]
	}
	return trimmed
}
package s3
import (
"github.com/uptrace/opentelemetry-go-extra/otelzap"
"gitlab.com/nunet/device-management-service/telemetry/logger"
)
// zlog is the logger used by the s3 package, registered under the "s3" name.
// It is set by a package-level initializer rather than an init function.
var zlog *otelzap.Logger = logger.OtelZapLogger("s3")
package s3
import (
"context"
"fmt"
"github.com/aws/aws-sdk-go-v2/aws"
s3Manager "github.com/aws/aws-sdk-go-v2/feature/s3/manager"
"github.com/aws/aws-sdk-go-v2/service/s3"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/storage"
)
// S3Storage fetches data from S3 into storage volumes. It is built by
// NewClient with a positional composite literal, so the field order matters.
type S3Storage struct {
// Client is the embedded S3 SDK client used for HEAD/List requests.
*s3.Client
// volController creates and locks the volumes that downloads are written into.
volController storage.VolumeController
// downloader is the transfer-manager downloader built on the same client.
downloader *s3Manager.Downloader
// uploader is the transfer-manager uploader built on the same client.
uploader *s3Manager.Uploader
}
// s3Object describes one S3 object selected for download.
type s3Object struct {
// key is the object key within the bucket.
key *string
// eTag, when set, is sent as IfMatch on download; only the single-object resolver sets it.
eTag *string
// versionID is not populated by the key resolvers in this package.
versionID *string
// size is the object size in bytes as reported by S3.
size int64
// isDir marks keys ending in "/"; such objects are created as directories, not downloaded.
isDir bool
}
// NewClient creates a new S3Storage, including an S3 SDK client and
// transfer-manager downloader/uploader built from config. It depends on
// volController to manage the volumes being acted upon. It fails with
// "invalid credentials" when config does not yield usable credentials.
func NewClient(config aws.Config, volController storage.VolumeController) (*S3Storage, error) {
	if !hasValidCredentials(config) {
		return nil, fmt.Errorf("invalid credentials")
	}
	client := s3.NewFromConfig(config)
	return &S3Storage{
		Client:        client,
		volController: volController,
		downloader:    s3Manager.NewDownloader(client),
		uploader:      s3Manager.NewUploader(client),
	}, nil
}
// Size returns the size in bytes of the S3 object identified by the Bucket and
// Key of source, using a HeadObject request. Other source fields, such as
// Filter, are not consulted.
func (s *S3Storage) Size(ctx context.Context, source *types.SpecConfig) (uint64, error) {
	inputSource, err := DecodeInputSpec(source)
	if err != nil {
		return 0, fmt.Errorf("failed to decode input spec: %w", err)
	}
	output, err := s.HeadObject(ctx, &s3.HeadObjectInput{
		Bucket: aws.String(inputSource.Bucket),
		Key:    aws.String(inputSource.Key),
	})
	if err != nil {
		return 0, fmt.Errorf("failed to get object size: %w", err)
	}
	// ContentLength is a pointer field. Guard against a missing value (which
	// would otherwise panic) and a negative one (which would wrap around in the
	// uint64 conversion).
	if output.ContentLength == nil || *output.ContentLength < 0 {
		return 0, fmt.Errorf("failed to get object size: missing or invalid content length for s3://%s/%s",
			inputSource.Bucket, inputSource.Key)
	}
	return uint64(*output.ContentLength), nil
}
// Compile time interface check
// var _ storage.StorageProvider = (*S3Storage)(nil)
package s3
import (
"fmt"
"github.com/fatih/structs"
"github.com/mitchellh/mapstructure"
"gitlab.com/nunet/device-management-service/types"
)
// S3InputSource holds the decoded parameters of an S3 storage spec.
type S3InputSource struct {
	// Bucket is the target bucket; it is the only required field.
	Bucket string
	// Key is the object key or key prefix.
	Key string
	// Filter is an optional filter string; its semantics are defined by the callers that use it.
	Filter string
	// Region is the AWS region, if provided.
	Region string
	// Endpoint is an optional endpoint override (presumably for S3-compatible stores — confirm).
	Endpoint string
}

// Validate returns an error when the required Bucket field is empty.
func (s S3InputSource) Validate() error {
	if s.Bucket != "" {
		return nil
	}
	return fmt.Errorf("invalid s3 storage params: bucket cannot be empty")
}
// ToMap returns the fields of s as a map keyed by field name, via structs.Map.
func (s S3InputSource) ToMap() map[string]interface{} {
	fields := structs.Map(s)
	return fields
}
// DecodeInputSpec converts a generic storage SpecConfig into an S3InputSource.
// It returns an error when:
//   - spec is nil;
//   - spec is not of type types.StorageProviderS3;
//   - its Params are nil;
//   - Params cannot be decoded;
//   - the decoded value fails Validate.
func DecodeInputSpec(spec *types.SpecConfig) (S3InputSource, error) {
	// Guard nil explicitly: IsType tolerates a nil spec, but the type-mismatch
	// error below reads spec.Type and would panic.
	if spec == nil {
		return S3InputSource{}, fmt.Errorf("invalid storage source spec: cannot be nil")
	}
	if !spec.IsType(types.StorageProviderS3) {
		return S3InputSource{}, fmt.Errorf("invalid storage source type. Expected %s but received %s", types.StorageProviderS3, spec.Type)
	}
	if spec.Params == nil {
		return S3InputSource{}, fmt.Errorf("invalid storage input source params. cannot be nil")
	}
	var c S3InputSource
	if err := mapstructure.Decode(spec.Params, &c); err != nil {
		// Do not hand back a partially decoded value on failure.
		return S3InputSource{}, err
	}
	return c, c.Validate()
}
package s3
import (
"context"
"fmt"
"os"
"path/filepath"
"github.com/spf13/afero"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/s3"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/storage/basic_controller"
)
// Upload uploads every regular file found (recursively) under vol.Path to the
// bucket in destinationSpecs. Each object key is the sanitized target key
// joined with the file's path relative to vol.Path, always using '/' as the
// separator. Directories themselves are not uploaded.
//
// Warning: the implementation should rely on the FS provided by the volume controller,
// be careful if managing files with `os` (the volume controller might be
// using an in-memory one)
func (s *S3Storage) Upload(ctx context.Context, vol types.StorageVolume,
	destinationSpecs *types.SpecConfig) error {
	target, err := DecodeInputSpec(destinationSpecs)
	if err != nil {
		return fmt.Errorf("failed to decode input spec: %w", err)
	}
	sanitizedKey := sanitizeKey(target.Key)

	// Set the file system to act upon based on the volume controller
	// implementation. Only the basic controller exposes its FS. With any other
	// controller fs stays nil, and afero.Walk would panic on it, so fail
	// explicitly instead.
	var fs afero.Fs
	if basicVolController, ok := s.volController.(*basic_controller.BasicVolumeController); ok && basicVolController != nil {
		fs = basicVolController.FS
	}
	if fs == nil {
		return fmt.Errorf("upload failed: volume controller %T does not provide a filesystem", s.volController)
	}

	zlog.Sugar().Debugf("Uploading files from %s to s3://%s/%s", vol.Path, target.Bucket, sanitizedKey)
	err = afero.Walk(fs, vol.Path, func(filePath string, info os.FileInfo, err error) error {
		if err != nil {
			return err
		}
		// Skip directories; only files are uploaded.
		if info.IsDir() {
			return nil
		}
		relPath, err := filepath.Rel(vol.Path, filePath)
		if err != nil {
			return fmt.Errorf("failed to get relative path: %w", err)
		}
		// S3 keys use '/' regardless of the host OS, so normalize the
		// OS-specific separators produced by filepath.Join.
		s3Key := filepath.ToSlash(filepath.Join(sanitizedKey, relPath))
		file, err := fs.Open(filePath)
		if err != nil {
			return fmt.Errorf("failed to open file: %w", err)
		}
		// The defer is scoped to this callback, so each file is closed as soon
		// as its upload returns.
		defer file.Close()
		zlog.Sugar().Debugf("Uploading %s to s3://%s/%s", filePath, target.Bucket, s3Key)
		_, err = s.uploader.Upload(ctx, &s3.PutObjectInput{
			Bucket: aws.String(target.Bucket),
			Key:    aws.String(s3Key),
			Body:   file,
		})
		if err != nil {
			return fmt.Errorf("failed to upload file to S3: %w", err)
		}
		return nil
	})
	if err != nil {
		return fmt.Errorf("upload failed. It's possible that some files were uploaded; Error: %w", err)
	}
	return nil
}
package types
// Executor identifies which execution engine is in use.
type Executor struct {
	ExecutorType ExecutorType `json:"executor_type"`
}

// ExecutorType is the name of an execution engine, e.g. "docker".
type ExecutorType string

// Executor type names. They are untyped string constants, so they can be used
// both as ExecutorType values and as plain strings.
const (
	ExecutorTypeDocker      = "docker"
	ExecutorTypeFirecracker = "firecracker"
	ExecutorTypeWasm        = "wasm"
)

// ExecutionStatusCodeSuccess is the success status code (0).
const ExecutionStatusCodeSuccess = 0
// ExecutionRequest carries everything an executor needs to run one execution of a job.
type ExecutionRequest struct {
	// JobID identifies the job to execute.
	JobID string
	// ExecutionID identifies this particular execution.
	ExecutionID string
	// EngineSpec is the engine spec for the execution.
	EngineSpec *SpecConfig
	// Resources are the resources allotted to the execution.
	Resources *ExecutionResources
	// Inputs are the input volumes for the execution.
	Inputs []*StorageVolumeExecutor
	// Outputs are the output volumes for the results.
	Outputs []*StorageVolumeExecutor
	// ResultsDir is the directory where results are stored.
	ResultsDir string
}
// ExecutionResult is the result of an execution.
type ExecutionResult struct {
	STDOUT   string `json:"stdout"`    // STDOUT of the execution
	STDERR   string `json:"stderr"`    // STDERR of the execution
	ExitCode int    `json:"exit_code"` // Exit code of the execution
	ErrorMsg string `json:"error_msg"` // Error message if the execution failed
}

// NewExecutionResult returns an ExecutionResult with the given exit code and
// empty output and error fields.
func NewExecutionResult(code int) *ExecutionResult {
	return &ExecutionResult{ExitCode: code}
}

// NewFailedExecutionResult returns an ExecutionResult for a failed execution.
// The exit code is -1 and ErrorMsg holds err's message. A nil err is tolerated
// (previously it panicked) and leaves ErrorMsg empty.
func NewFailedExecutionResult(err error) *ExecutionResult {
	result := &ExecutionResult{ExitCode: -1}
	if err != nil {
		result.ErrorMsg = err.Error()
	}
	return result
}
// LogStreamRequest asks for the logs of one execution to be streamed.
type LogStreamRequest struct {
	// JobID identifies the job.
	JobID string
	// ExecutionID identifies the execution within the job.
	ExecutionID string
	// Tail requests tailing of the logs.
	Tail bool
	// Follow requests following of the logs.
	Follow bool
}
package types
// NetP2P is the network type name for peer-to-peer networking.
const NetP2P = "p2p"
// NetworkSpec is a placeholder with no fields yet; extend it as requirements emerge.
type NetworkSpec struct{}
// NetConfig is a placeholder network configuration; expand or replace it as
// requirements become clear.
type NetConfig struct {
	// NetworkSpec is the network specification.
	NetworkSpec SpecConfig `json:"network_spec"`
}

// GetNetworkConfig returns a pointer to nc's own NetworkSpec, not a copy, so
// changes through it modify nc.
func (nc *NetConfig) GetNetworkConfig() *SpecConfig {
	spec := &nc.NetworkSpec
	return spec
}
// NetworkStats is meant to hold the network information a user cares about.
// It currently carries only the peer ID and listening address. Reachability,
// local/remote addresses and similar fields can be added when needed.
type NetworkStats struct {
	ID         string `json:"id"`
	ListenAddr string `json:"listen_addr"`
}
// MessageInfo is a placeholder message descriptor; expand or replace it as
// requirements become clear.
type MessageInfo struct {
	// Info is the message information.
	Info string `json:"info"`
}
package types
import (
"fmt"
"strings"
)
// GPUVendor identifies the manufacturer of a GPU.
type GPUVendor string

const (
	GPUVendorNvidia  GPUVendor = "NVIDIA"
	GPUVendorAMDATI  GPUVendor = "AMD/ATI"
	GPUVendorIntel   GPUVendor = "Intel"
	GPUVendorUnknown GPUVendor = "Unknown"
	// None marks the absence of any GPU (see GetGPUWithHighestFreeVRAM).
	None GPUVendor = "None"
)

// ParseGPUVendor maps a free-form vendor description, such as
// "NVIDIA Corporation" or "Advanced Micro Devices, Inc. [AMD/ATI]", to a
// GPUVendor. Unrecognized input yields GPUVendorUnknown.
//
// Matching is case-insensitive. "NVIDIA", "AMD" and "INTEL" match as
// substrings. "ATI" must appear as a whole alphanumeric word, so that strings
// such as "INTEL CORPORATION" (where "ATI" occurs inside "CORPORATION") are not
// misclassified as AMD.
func ParseGPUVendor(vendor string) GPUVendor {
	upper := strings.ToUpper(vendor)
	// Split on anything that is not an ASCII letter or digit to obtain whole words.
	words := strings.FieldsFunc(upper, func(r rune) bool {
		return !(r >= 'A' && r <= 'Z' || r >= '0' && r <= '9')
	})
	hasWord := func(w string) bool {
		for _, candidate := range words {
			if candidate == w {
				return true
			}
		}
		return false
	}

	switch {
	case strings.Contains(upper, "NVIDIA"):
		return GPUVendorNvidia
	case strings.Contains(upper, "AMD") || hasWord("ATI"):
		return GPUVendorAMDATI
	case strings.Contains(upper, "INTEL"):
		return GPUVendorIntel
	default:
		return GPUVendorUnknown
	}
}

type GPU struct {
	// Index is the self-reported index of the device in the system
	Index int
	// Name is the model name of the GPU e.g. Tesla T4
	Name string
	// Vendor is the maker of the GPU, e.g. NVidia, AMD, Intel
	Vendor GPUVendor
	// PCIAddress is the PCI address of the device, in the format AAAA:BB:CC.C
	// Used to discover the correct device rendering cards
	PCIAddress string
	// Model of the GPU, e.g. A100
	Model string `json:"model" description:"GPU model, ex A100"`
	// TotalVRAM is the total amount of VRAM on the device
	TotalVRAM uint64
	// UsedVRAM is the amount of VRAM currently in use
	UsedVRAM uint64
	// FreeVRAM is the amount of VRAM currently free
	FreeVRAM uint64
	// Gorm fields
	// TODO(team): confirm this is the right way to model the GPU -> Resources relation.
	ResourceID uint `gorm:"foreignKey:ID"`
}

// GPUList is a list of detected GPUs.
type GPUList []GPU

// GetGPUWithHighestFreeVRAM returns the GPU with the most free VRAM, which is
// useful for choosing a device in multi-GPU systems. Ties keep the earliest
// GPU in the list. An empty list yields a GPU whose Vendor is None, which
// callers can use to launch CPU-only workloads. The error is currently always nil.
func (gpus GPUList) GetGPUWithHighestFreeVRAM() (GPU, error) {
	if len(gpus) == 0 {
		return GPU{Vendor: None}, nil
	}
	// Seed with the first GPU so the result is always a real device, even when
	// every GPU reports zero free VRAM. Previously an empty GPU{} was returned
	// in that case.
	best := gpus[0]
	for i := 1; i < len(gpus); i++ {
		if gpus[i].FreeVRAM > best.FreeVRAM {
			best = gpus[i]
		}
	}
	return best, nil
}
// negativeValueError reports that subtracting r2 from r1 for the named resource
// would produce a negative value. r1 and r2 may hold floats (CPU) or unsigned
// integers (RAM, Disk).
type negativeValueError struct {
	resource string
	r1       any
	r2       any
}

// Error returns the error message. %v is used for the operands because %d
// renders float64 values as "%!d(float64=...)".
func (e *negativeValueError) Error() string {
	return fmt.Sprintf("Error: %s subtraction results in negative values. (%v - %v)", e.resource, e.r1, e.r2)
}
// ResourceOps is the arithmetic contract for resource amounts; Resources implements it.
// TODO: Check how to handle GPU resources
type ResourceOps interface {
	// Add combines the receiver with r and returns the total.
	Add(r Resources) Resources
	// Subtract removes r from the receiver; implementations may return an
	// error when the result would be negative.
	Subtract(r Resources) (Resources, error)
}
// Resources is an amount of machine resources (capacity, free or required,
// depending on the wrapping type).
type Resources struct {
	CPU      float64 // CPU amount; units not fixed here — TODO confirm
	NumCores uint64  // number of CPU cores
	GPU      []GPU   `gorm:"foreignKey:ResourceID"`
	RAM      uint64  // memory amount; units not fixed here — TODO confirm
	Disk     uint64  // disk amount; units not fixed here — TODO confirm
}
// Add returns the component-wise sum of r and r2 for CPU, NumCores, RAM and Disk.
// NumCores was previously dropped from the result. GPUs are not summed yet, so
// the result carries no GPU entries.
func (r Resources) Add(r2 Resources) Resources {
	//TODO: GPU addition
	return Resources{
		CPU:      r.CPU + r2.CPU,
		NumCores: r.NumCores + r2.NumCores,
		RAM:      r.RAM + r2.RAM,
		Disk:     r.Disk + r2.Disk,
	}
}
// Subtract returns r minus r2 for CPU, NumCores, RAM and Disk. NumCores was
// previously dropped from the result.
//
// If any of these would become negative, Subtract returns a zero Resources and
// a *negativeValueError naming the first offending resource. The checks run in
// the order CPU, NumCores, RAM, Disk. GPUs are not subtracted yet, so the
// result carries no GPU entries.
//
// NOTE(team): Should this return a negative value instead of an error, letting
// callers detect an overused machine themselves? The unsigned fields (NumCores,
// RAM, Disk) cannot hold negative values, so that would need a type change.
func (r Resources) Subtract(r2 Resources) (Resources, error) {
	if r.CPU < r2.CPU {
		return Resources{}, &negativeValueError{resource: "CPU", r1: r.CPU, r2: r2.CPU}
	}
	if r.NumCores < r2.NumCores {
		return Resources{}, &negativeValueError{resource: "NumCores", r1: r.NumCores, r2: r2.NumCores}
	}
	// TODO: GPU subtraction
	if r.RAM < r2.RAM {
		return Resources{}, &negativeValueError{resource: "RAM", r1: r.RAM, r2: r2.RAM}
	}
	if r.Disk < r2.Disk {
		return Resources{}, &negativeValueError{resource: "Disk", r1: r.Disk, r2: r2.Disk}
	}
	return Resources{
		CPU:      r.CPU - r2.CPU,
		NumCores: r.NumCores - r2.NumCores,
		RAM:      r.RAM - r2.RAM,
		Disk:     r.Disk - r2.Disk,
	}, nil
}
// Compile-time check that *Resources satisfies ResourceOps.
var (
	_ ResourceOps = (*Resources)(nil)
)
// FreeResources is the database record of the machine's currently free resources.
type FreeResources struct {
	BaseDBModel
	Resources
}
// OnboardedResources is the database record of the resources the machine has onboarded.
type OnboardedResources struct {
	BaseDBModel
	Resources
}
// RequiredResources is the database record of the resources needed by a job
// running on the machine.
// TODO: this is a replacement for ServiceResourceRequirements. Check with the team on this.
type RequiredResources struct {
	BaseDBModel
	// JobID identifies the job these requirements belong to.
	JobID int
	Resources
}
// CPUInfo summarizes the machine's CPU capacity.
type CPUInfo struct {
	NumCores   uint64  // number of cores
	MHzPerCore float64 // clock per core in MHz
	Compute    float64 // aggregate compute figure (derivation not shown here — TODO confirm)
}
// SpecInfo describes the hardware specification of a machine.
// TODO: Finalise the fields required in this struct
// https://gitlab.com/nunet/device-management-service/-/issues/533
type SpecInfo struct {
	CPUs    []CPU       // installed CPUs
	GPUs    []GPU       // installed GPUs
	RAMs    []RAM       // installed memory modules
	Disks   []Disk      // attached disks
	Network NetworkInfo // network characteristics
}
// CPU describes a single CPU.
type CPU struct {
	Model        string // e.g. "Intel Core i7-9700K", "AMD Ryzen 9 5900X"
	Vendor       string // manufacturer, e.g. "Intel", "AMD"
	ClockSpeedHz uint64 // clock speed in Hz
	Cores        int    // physical cores
	Threads      int    // logical threads, including hyperthreading
	Architecture string // e.g. "x86", "x86_64", "arm64"
	CacheSize    uint64 // cache size in bytes
}
// RAM describes a memory module (or memory pool).
type RAM struct {
	Size         uint64 // size in bytes
	ClockSpeedHz uint64 // clock speed in Hz
	Type         string // e.g. "DDR4", "DDR5", "LPDDR4"
}
// Disk describes a storage device.
type Disk struct {
	// Model, e.g. "Samsung 970 EVO Plus", "Western Digital Blue SN550".
	// TODO: may be removed as Disk models will be usually irrelevant, right?
	Model string
	// Vendor, e.g. "Samsung", "Western Digital".
	// TODO: may be removed as Disk vendors will be usually irrelevant, right?
	Vendor string
	// Size in bytes.
	Size uint64
	// Type, e.g. "SSD", "HDD", "NVMe".
	Type string
	// Interface, e.g. "SATA", "PCIe", "M.2".
	Interface string
	// ReadSpeed in bytes per second.
	// TODO: may be removed as it may be too specific for our case
	ReadSpeed uint64
	// WriteSpeed in bytes per second.
	// TODO: may be removed as it may be too specific for our case
	WriteSpeed uint64
}
// NetworkInfo describes the machine's network link.
type NetworkInfo struct {
	Bandwidth   uint64 // bits per second
	NetworkType string // e.g. "Ethernet", "Wi-Fi", "Cellular"
}
// ExecutionResources lists the hardware a task needs in order to run.
type ExecutionResources struct {
	CPU    CPU   `json:"cpu,omitempty" description:"CPU configuration"`       // CPU requirements
	Memory RAM   `json:"memory,omitempty" description:"Memory configuration"` // memory requirements
	Disk   Disk  `json:"disk,omitempty" description:"Disk configuration"`     // disk requirements
	GPUs   []GPU `json:"gpus,omitempty" description:"GPU configuration"`      // GPU requirements
}
package types
import (
"errors"
"strings"
"gitlab.com/nunet/device-management-service/utils/validate"
)
// SpecConfig is a typed bag of parameters. It can describe an engine spec, a
// storage volume, and similar configurable components.
type SpecConfig struct {
	// Type of the spec (e.g. docker, firecracker, storage, etc.)
	Type string `json:"type"`
	// Params holds the spec's type-specific parameters.
	Params map[string]interface{} `json:"params,omitempty"`
}

// Config is implemented by configurations that can expose a network spec.
type Config interface {
	GetNetworkConfig() *SpecConfig
}

// NewSpecConfig returns a SpecConfig of type t with an empty, non-nil Params map.
func NewSpecConfig(t string) *SpecConfig {
	cfg := &SpecConfig{Type: t}
	cfg.Params = map[string]interface{}{}
	return cfg
}

// WithParam stores value under key in Params, allocating Params if it is nil.
// It returns the receiver so calls can be chained.
func (s *SpecConfig) WithParam(key string, value interface{}) *SpecConfig {
	params := s.Params
	if params == nil {
		params = map[string]interface{}{}
		s.Params = params
	}
	params[key] = value
	return s
}

// Normalize trims surrounding whitespace from Type and replaces a nil or empty
// Params with a fresh empty map, so nil and empty params end up identical.
// Calling it on a nil receiver does nothing.
func (s *SpecConfig) Normalize() {
	if s != nil {
		s.Type = strings.TrimSpace(s.Type)
		if len(s.Params) == 0 {
			s.Params = map[string]interface{}{}
		}
	}
}
// Validate returns an error when s is nil or its Type is blank.
func (s *SpecConfig) Validate() error {
	switch {
	case s == nil:
		return errors.New("nil spec config")
	case validate.IsBlank(s.Type):
		return errors.New("missing spec type")
	default:
		return nil
	}
}
// IsType reports whether s is of type t. The comparison is case-insensitive
// and ignores surrounding whitespace on both sides, so it also works on specs
// that have not been Normalize()d. Previously only t was trimmed. A nil
// receiver is never of any type.
func (s *SpecConfig) IsType(t string) bool {
	if s == nil {
		return false
	}
	return strings.EqualFold(strings.TrimSpace(s.Type), strings.TrimSpace(t))
}
// IsEmpty reports whether s is nil, or has a blank Type and no Params.
func (s *SpecConfig) IsEmpty() bool {
	if s == nil {
		return true
	}
	return len(s.Params) == 0 && validate.IsBlank(s.Type)
}
package types
import (
"log"
"os"
)
// CollectorConfig names a telemetry collector and the endpoint it sends to.
type CollectorConfig struct {
	CollectorType     string // e.g. "OPENTELEMETRY" or "LOG"
	CollectorEndpoint string // endpoint for this collector
}

// TelemetryConfig is the telemetry configuration for the service.
type TelemetryConfig struct {
	ServiceName        string                     // reported service name
	GlobalEndpoint     string                     // default collector endpoint
	ObservabilityLevel int                        // numeric ObservabilityLevel
	CollectorConfigs   map[string]CollectorConfig // per-collector settings keyed by type
}
// LoadConfigFromEnv builds a TelemetryConfig from environment variables:
//   - SERVICE_NAME becomes ServiceName.
//   - COLLECTOR_ENDPOINT becomes GlobalEndpoint.
//   - OBSERVABILITY_LEVEL becomes ObservabilityLevel, via parseObservabilityLevel.
//   - COLLECTOR_<TYPE>_ENDPOINT, for TYPE in OPENTELEMETRY and LOG, becomes
//     CollectorConfigs[TYPE]; an entry is added only when the variable is non-empty.
//
// The loaded values are logged. The returned error is currently always nil.
func LoadConfigFromEnv() (*TelemetryConfig, error) {
	rawLevel := os.Getenv("OBSERVABILITY_LEVEL")
	cfg := &TelemetryConfig{
		ServiceName:        os.Getenv("SERVICE_NAME"),
		GlobalEndpoint:     os.Getenv("COLLECTOR_ENDPOINT"),
		ObservabilityLevel: parseObservabilityLevel(rawLevel),
		CollectorConfigs:   map[string]CollectorConfig{},
	}

	for _, kind := range [...]string{"OPENTELEMETRY", "LOG"} {
		if ep := os.Getenv("COLLECTOR_" + kind + "_ENDPOINT"); ep != "" {
			cfg.CollectorConfigs[kind] = CollectorConfig{CollectorType: kind, CollectorEndpoint: ep}
		}
	}

	log.Printf("Loaded environment variables: SERVICE_NAME=%s, COLLECTOR_ENDPOINT=%s, OBSERVABILITY_LEVEL=%s", cfg.ServiceName, cfg.GlobalEndpoint, rawLevel)
	return cfg, nil
}
// parseObservabilityLevel converts an OBSERVABILITY_LEVEL value into its
// ObservabilityLevel, returned as an int. Accepted values are "TRACE",
// "DEBUG", "INFO", "WARN", "ERROR" and "FATAL" (case-sensitive).
//
// An unset or empty value silently defaults to INFO; previously it was logged
// as invalid. Any other unrecognized value is logged and also defaults to INFO.
func parseObservabilityLevel(levelStr string) int {
	switch levelStr {
	case "":
		// Not configured: use the default without reporting it as invalid.
		return int(INFO)
	case "TRACE":
		return int(TRACE)
	case "DEBUG":
		return int(DEBUG)
	case "INFO":
		return int(INFO)
	case "WARN":
		return int(WARN)
	case "ERROR":
		return int(ERROR)
	case "FATAL":
		return int(FATAL)
	default:
		log.Printf("Invalid OBSERVABILITY_LEVEL: %s, defaulting to INFO", levelStr)
		return int(INFO)
	}
}

// ObservabilityLevel defines levels of observability.
type ObservabilityLevel int

// Constants representing levels of observability, from most to least verbose.
const (
	TRACE ObservabilityLevel = 1
	DEBUG ObservabilityLevel = 2
	INFO  ObservabilityLevel = 3
	WARN  ObservabilityLevel = 4
	ERROR ObservabilityLevel = 5
	FATAL ObservabilityLevel = 6
)
package types
import (
"time"
"github.com/google/uuid"
"gorm.io/gorm"
)
// BaseDBModel holds the common identity and timestamp fields embedded in
// database entities.
type BaseDBModel struct {
	ID        string         `gorm:"type:uuid"` // UUID string, assigned in BeforeCreate
	CreatedAt time.Time      // set in BeforeCreate
	UpdatedAt time.Time      // set in BeforeUpdate
	DeletedAt gorm.DeletedAt `gorm:"index"` // GORM soft-delete marker
}
// GetID returns the entity's ID field.
func (m BaseDBModel) GetID() string {
	id := m.ID
	return id
}
// BeforeCreate is a GORM hook that runs before an insert. It always assigns a
// fresh UUID to ID, overwriting any existing value, and stamps CreatedAt with
// the current time. It never returns an error. Do not call it directly; this
// logic could move into the generic repository's create methods.
func (m *BaseDBModel) BeforeCreate(tx *gorm.DB) error {
	id := uuid.NewString()
	m.ID, m.CreatedAt = id, time.Now()
	return nil
}
// BeforeUpdate is a GORM hook that runs before an update and stamps UpdatedAt
// with the current time. It never returns an error. Do not call it directly;
// this logic could move into the generic repository's update methods.
func (m *BaseDBModel) BeforeUpdate(tx *gorm.DB) error {
	now := time.Now()
	m.UpdatedAt = now
	return nil
}
package validate
import (
"reflect"
)
// ConvertNumericToFloat64 converts a built-in integer or floating-point value
// to float64. The accepted types are int, int8, int16, int32, int64, uint,
// uint8, uint16, uint32, uint64, float32 and float64. For anything else
// (including named types derived from them, uintptr and nil) it returns (0, false).
func ConvertNumericToFloat64(n any) (float64, bool) {
	switch v := n.(type) {
	case float64:
		return v, true
	case float32:
		return float64(v), true
	case uint, uint8, uint16, uint32, uint64:
		return float64(reflect.ValueOf(v).Uint()), true
	case int, int8, int16, int32, int64:
		return float64(reflect.ValueOf(v).Int()), true
	}
	return 0, false
}
package validate
import (
"strings"
)
// IsBlank reports whether s is empty or consists solely of whitespace.
func IsBlank(s string) bool {
	return strings.TrimSpace(s) == ""
}

// IsNotBlank reports whether s contains at least one non-whitespace character.
func IsNotBlank(s string) bool {
	return len(strings.TrimSpace(s)) > 0
}
// IsLiteral reports whether the dynamic type of s is exactly string.
func IsLiteral(s interface{}) bool {
	_, ok := s.(string)
	return ok
}