package repositories_clover
import (
"fmt"
clover "github.com/ostafen/clover/v2"
)
// NewDB opens the Clover database located at path and makes sure that every
// collection named in collections exists.
//
// Collections that are already present are left untouched, so NewDB can be
// called again on a database that was populated by an earlier run. If any
// step after opening fails, the database handle is closed before the error
// is returned so that the caller never has to clean up a half-initialized DB.
func NewDB(path string, collections []string) (*clover.DB, error) {
	db, err := clover.Open(path)
	if err != nil {
		return nil, fmt.Errorf("failed to connect to database: %w", err)
	}

	for _, collection := range collections {
		exists, err := db.HasCollection(collection)
		if err != nil {
			// Best-effort cleanup: the setup error is the one worth reporting.
			_ = db.Close()
			return nil, fmt.Errorf("failed to check collection %s: %w", collection, err)
		}
		if exists {
			continue
		}
		if err := db.CreateCollection(collection); err != nil {
			// Best-effort cleanup: the setup error is the one worth reporting.
			_ = db.Close()
			return nil, fmt.Errorf("failed to create collection %s: %w", collection, err)
		}
	}

	return db, nil
}
package repositories_clover
import (
clover "github.com/ostafen/clover/v2"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/models"
)
// DeploymentRequestFlatRepositoryClover stores models.DeploymentRequestFlat
// records in Clover by embedding the generic repository for that model.
type DeploymentRequestFlatRepositoryClover struct {
	repositories.GenericRepository[models.DeploymentRequestFlat]
}

// NewDeploymentRequestFlatRepository builds a DeploymentRequestFlatRepositoryClover
// on top of db and returns it as a repositories.DeploymentRequestFlatRepository.
func NewDeploymentRequestFlatRepository(db *clover.DB) repositories.DeploymentRequestFlatRepository {
	generic := NewGenericRepository[models.DeploymentRequestFlat](db)
	return &DeploymentRequestFlatRepositoryClover{GenericRepository: generic}
}
package repositories_clover
import (
"github.com/ostafen/clover/v2"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/models"
)
// RequestTrackerRepositoryClover stores models.RequestTracker records in
// Clover by embedding the generic repository for that model.
type RequestTrackerRepositoryClover struct {
	repositories.GenericRepository[models.RequestTracker]
}

// NewRequestTrackerRepository builds a RequestTrackerRepositoryClover on top
// of db and returns it as a repositories.RequestTrackerRepository.
func NewRequestTrackerRepository(db *clover.DB) repositories.RequestTrackerRepository {
	generic := NewGenericRepository[models.RequestTracker](db)
	return &RequestTrackerRepositoryClover{GenericRepository: generic}
}
package repositories_clover
import (
"github.com/ostafen/clover/v2"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/models"
)
// VirtualMachineRepositoryClover stores models.VirtualMachine records in
// Clover by embedding the generic repository for that model.
type VirtualMachineRepositoryClover struct {
	repositories.GenericRepository[models.VirtualMachine]
}

// NewVirtualMachineRepository builds a VirtualMachineRepositoryClover on top
// of db and returns it as a repositories.VirtualMachineRepository.
func NewVirtualMachineRepository(db *clover.DB) repositories.VirtualMachineRepository {
	generic := NewGenericRepository[models.VirtualMachine](db)
	return &VirtualMachineRepositoryClover{GenericRepository: generic}
}
package repositories_clover
import (
"context"
"fmt"
"reflect"
"time"
"github.com/iancoleman/strcase"
clover "github.com/ostafen/clover/v2"
clover_d "github.com/ostafen/clover/v2/document"
clover_q "github.com/ostafen/clover/v2/query"
"gitlab.com/nunet/device-management-service/db/repositories"
)
const (
// pKField holds the document key "_id".
// NOTE(review): nothing in the visible entity-repository code references
// this constant, and the generic repository declares its own pkField with
// the same value — confirm whether this one is still needed.
pKField = "_id"
)
// GenericEntityRepositoryClover is a generic single entity repository implementation using Clover.
// It is intended to be embedded in single entity model repositories to provide basic database operations.
// Every Save inserts a new document, and Get returns the document with the
// most recent CreatedAt value, so the collection doubles as a version history.
type GenericEntityRepositoryClover[T repositories.ModelType] struct {
db *clover.DB // db is the Clover database instance.
collection string // collection is the name of the collection in the database.
}
// NewGenericEntityRepository returns a Clover-backed single entity repository
// for T. The collection name is the snake_case form of T's type name; the
// collection itself is not created here.
func NewGenericEntityRepository[T repositories.ModelType](
	db *clover.DB,
) repositories.GenericEntityRepository[T] {
	var zero T
	typeName := reflect.TypeOf(zero).Name()
	return &GenericEntityRepositoryClover[T]{
		db:         db,
		collection: strcase.ToSnake(typeName),
	}
}
// GetQuery returns a zero-valued Query for T that callers can fill in with
// conditions, sorting, limit and offset.
func (repo *GenericEntityRepositoryClover[T]) GetQuery() repositories.Query[T] {
	var empty repositories.Query[T]
	return empty
}
// query returns a fresh, unfiltered Clover query over the repository's collection.
func (repo *GenericEntityRepositoryClover[T]) query() *clover_q.Query {
	name := repo.collection
	return clover_q.NewQuery(name)
}
// Save stores data as a new version of the entity and returns the stored model.
//
// The record is always inserted as a new document stamped with the current
// time in its "CreatedAt" field; earlier versions are kept and remain
// reachable through History. On an insert failure the input data is returned
// together with the normalized database error. If the stored document cannot
// be converted back into T, the returned error wraps repositories.ErrParsingModel.
func (repo *GenericEntityRepositoryClover[T]) Save(ctx context.Context, data T) (T, error) {
	doc := toCloverDoc(data)
	doc.Set("CreatedAt", time.Now())

	if _, err := repo.db.InsertOne(repo.collection, doc); err != nil {
		return data, handleDBError(err)
	}

	model, err := toModel[T](doc, true)
	if err != nil {
		// %w keeps ErrParsingModel in the chain so that it can be detected
		// with errors.Is; the message text is unchanged.
		return model, handleDBError(fmt.Errorf("%w: %v", repositories.ErrParsingModel, err))
	}
	return model, nil
}
// Get returns the most recently saved version of the entity, i.e. the
// document with the highest "CreatedAt" value.
//
// When the collection holds no document the zero value of T is returned
// together with the not-found error produced by handleDBError, matching the
// behavior of Find on the generic repository. A document that cannot be
// converted into T yields an error wrapping repositories.ErrParsingModel.
func (repo *GenericEntityRepositoryClover[T]) Get(ctx context.Context) (T, error) {
	var model T

	newestFirst := repo.query().Sort(clover_q.SortOption{
		Field:     "CreatedAt",
		Direction: -1,
	})

	doc, err := repo.db.FindFirst(newestFirst)
	if err != nil {
		return model, handleDBError(err)
	}
	if doc == nil {
		// FindFirst reported no error but also no document: nothing was saved yet.
		return model, handleDBError(clover.ErrDocumentNotExist)
	}

	model, err = toModel[T](doc, true)
	if err != nil {
		return model, handleDBError(fmt.Errorf("%w: %v", repositories.ErrParsingModel, err))
	}
	return model, nil
}
// Clear deletes every document in the repository's collection, which removes
// the current record as well as all of its earlier versions.
func (repo *GenericEntityRepositoryClover[T]) Clear(ctx context.Context) error {
	everything := repo.query()
	err := repo.db.Delete(everything)
	return err
}
// History returns the stored versions of the entity that match query.
//
// Conditions, sorting, limit and offset from query are applied through
// applyConditions. Iteration stops at the first document that cannot be
// unmarshalled into T; in that case the models collected so far are returned
// together with an error wrapping repositories.ErrParsingModel instead of
// silently reporting a truncated result as success.
func (repo *GenericEntityRepositoryClover[T]) History(
	ctx context.Context,
	query repositories.Query[T],
) ([]T, error) {
	var (
		models   []T
		parseErr error
	)

	q := applyConditions(repo.query(), query)
	err := repo.db.ForEach(q, func(doc *clover_d.Document) bool {
		var model T
		if parseErr = doc.Unmarshal(&model); parseErr != nil {
			return false
		}
		models = append(models, model)
		return true
	})
	if err != nil {
		return models, handleDBError(err)
	}
	if parseErr != nil {
		return models, handleDBError(fmt.Errorf("%w: %v", repositories.ErrParsingModel, parseErr))
	}
	return models, nil
}
package repositories_clover
import (
"context"
"fmt"
"reflect"
"time"
"github.com/iancoleman/strcase"
clover "github.com/ostafen/clover/v2"
clover_d "github.com/ostafen/clover/v2/document"
clover_q "github.com/ostafen/clover/v2/query"
"gitlab.com/nunet/device-management-service/db/repositories"
)
const (
// pkField is the document key that queryWithID matches an identifier against.
pkField = "_id"
// deletedAtField is the document key that query inspects to leave out
// soft-deleted documents.
deletedAtField = "DeletedAt"
)
// GenericRepositoryClover is a generic repository implementation using Clover.
// It is intended to be embedded in model repositories to provide basic database operations.
// Deletion is soft: Delete stamps the DeletedAt field and the read paths
// filter on that field instead of removing documents.
type GenericRepositoryClover[T repositories.ModelType] struct {
db *clover.DB // db is the Clover database instance.
collection string // collection is the name of the collection in the database.
}
// NewGenericRepository returns a Clover-backed repository for T. The
// collection name is the snake_case form of T's type name; the collection
// itself is not created here.
func NewGenericRepository[T repositories.ModelType](
	db *clover.DB,
) repositories.GenericRepository[T] {
	var zero T
	typeName := reflect.TypeOf(zero).Name()
	return &GenericRepositoryClover[T]{
		db:         db,
		collection: strcase.ToSnake(typeName),
	}
}
// GetQuery returns a zero-valued Query for T that callers can fill in with
// conditions, sorting, limit and offset.
func (repo *GenericRepositoryClover[T]) GetQuery() repositories.Query[T] {
	var empty repositories.Query[T]
	return empty
}
// query returns a Clover query over the repository's collection. Unless
// includeDeleted is set, it only selects documents whose DeletedAt value is
// less than or equal to the Unix epoch.
func (repo *GenericRepositoryClover[T]) query(includeDeleted bool) *clover_q.Query {
	base := clover_q.NewQuery(repo.collection)
	if includeDeleted {
		return base
	}
	notDeleted := clover_q.Field(deletedAtField).LtEq(time.Unix(0, 0))
	return base.Where(notDeleted)
}
// queryWithID returns the base query (see query) narrowed to the document
// whose primary key equals id.
//
// The identifier is handed to the equality criterion as-is. A string id
// behaves exactly as before; an id of any other type no longer panics on a
// failed type assertion and simply matches no document.
// NOTE(review): this chains a second Where onto the query returned by
// query(); confirm that the Clover version in use combines successive Where
// criteria rather than replacing the earlier one.
func (repo *GenericRepositoryClover[T]) queryWithID(
	id interface{},
	includeDeleted bool,
) *clover_q.Query {
	byID := clover_q.Field(pkField).Eq(id)
	return repo.query(includeDeleted).Where(byID)
}
// Create inserts data as a new document stamped with the current time in its
// "CreatedAt" field and returns the stored model with its ID field populated
// from the document's object id.
//
// On failure the input data is returned unchanged together with the error.
// A stored document that cannot be converted back into T yields an error
// wrapping repositories.ErrParsingModel.
func (repo *GenericRepositoryClover[T]) Create(ctx context.Context, data T) (T, error) {
	doc := toCloverDoc(data)
	doc.Set("CreatedAt", time.Now())

	if _, err := repo.db.InsertOne(repo.collection, doc); err != nil {
		return data, handleDBError(err)
	}

	model, err := toModel[T](doc, false)
	if err != nil {
		// %w keeps ErrParsingModel in the chain so that it can be detected
		// with errors.Is; the message text is unchanged.
		return data, handleDBError(fmt.Errorf("%w: %v", repositories.ErrParsingModel, err))
	}
	return model, nil
}
// Get returns the non-deleted record whose primary key equals id.
//
// When no such document exists the zero value of T is returned together with
// the not-found error produced by handleDBError, matching Find. A document
// that cannot be converted into T yields an error wrapping
// repositories.ErrParsingModel.
func (repo *GenericRepositoryClover[T]) Get(ctx context.Context, id interface{}) (T, error) {
	var model T

	doc, err := repo.db.FindFirst(repo.queryWithID(id, false))
	if err != nil {
		return model, handleDBError(err)
	}
	if doc == nil {
		// No error but no document either: report it instead of returning
		// a zero model that looks like a real record.
		return model, handleDBError(clover.ErrDocumentNotExist)
	}

	model, err = toModel[T](doc, false)
	if err != nil {
		return model, handleDBError(fmt.Errorf("%w: %v", repositories.ErrParsingModel, err))
	}
	return model, nil
}
// Update overwrites the fields of the non-deleted record identified by id
// with the values from data, stamps "UpdatedAt" with the current time and
// returns the record as it is stored afterwards.
//
// If the update itself fails, the input data is returned with the normalized
// error. The follow-up read goes through Get, whose error is already
// normalized and is therefore returned unchanged rather than being passed
// through handleDBError a second time.
func (repo *GenericRepositoryClover[T]) Update(
	ctx context.Context,
	id interface{},
	data T,
) (T, error) {
	updates := toCloverDoc(data).AsMap()
	updates["UpdatedAt"] = time.Now()

	if err := repo.db.Update(repo.queryWithID(id, false), updates); err != nil {
		return data, handleDBError(err)
	}

	return repo.Get(ctx, id)
}
// Delete soft-deletes the record identified by id: the document is kept but
// its DeletedAt field is set to the current time, which makes the default
// (non-deleted) queries of this repository skip it.
//
// Errors are normalized through handleDBError like in the other methods.
func (repo *GenericRepositoryClover[T]) Delete(ctx context.Context, id interface{}) error {
	err := repo.db.Update(
		repo.queryWithID(id, false),
		map[string]interface{}{deletedAtField: time.Now()},
	)
	return handleDBError(err)
}
// Find returns the first non-deleted record that matches query.
//
// When nothing matches, the zero value of T is returned together with the
// not-found error produced by handleDBError. A document that cannot be
// converted into T yields an error wrapping repositories.ErrParsingModel,
// consistent with Create, Get and FindAll.
func (repo *GenericRepositoryClover[T]) Find(
	ctx context.Context,
	query repositories.Query[T],
) (T, error) {
	var model T

	q := applyConditions(repo.query(false), query)
	doc, err := repo.db.FindFirst(q)
	if err != nil {
		return model, handleDBError(err)
	}
	if doc == nil {
		return model, handleDBError(clover.ErrDocumentNotExist)
	}

	model, err = toModel[T](doc, false)
	if err != nil {
		return model, handleDBError(fmt.Errorf("%w: %v", repositories.ErrParsingModel, err))
	}
	return model, nil
}
// FindAll returns every non-deleted record that matches query.
//
// Iteration stops at the first document that cannot be converted into T; the
// models collected up to that point are returned together with an error
// wrapping repositories.ErrParsingModel. A failure of the iteration itself
// takes precedence and is returned normalized through handleDBError.
func (repo *GenericRepositoryClover[T]) FindAll(
	ctx context.Context,
	query repositories.Query[T],
) ([]T, error) {
	var (
		models   []T
		parseErr error
	)

	q := applyConditions(repo.query(false), query)
	err := repo.db.ForEach(q, func(doc *clover_d.Document) bool {
		model, convErr := toModel[T](doc, false)
		if convErr != nil {
			// %w keeps ErrParsingModel detectable with errors.Is.
			parseErr = handleDBError(fmt.Errorf("%w: %v", repositories.ErrParsingModel, convErr))
			return false
		}
		models = append(models, model)
		return true
	})
	if err != nil {
		return models, handleDBError(err)
	}
	return models, parseErr
}
// applyConditions extends the Clover query q with everything described by query
// and returns the resulting query.
//
//   - Each entry of query.Conditions becomes a criterion on the field's JSON
//     name (see fieldJSONTag). Supported operators are "=", ">", ">=", "<",
//     "<=", "!=", "IN" (value must be a []interface{}) and "LIKE" (value must
//     be a string). Unknown operators and values of the wrong type are skipped.
//   - Every non-empty field of query.Instance becomes an equality criterion.
//   - query.SortBy selects the sort field; a leading '-' requests descending
//     order. The field is mapped to its JSON name in both directions.
//   - query.Limit and query.Offset, when positive, bound and skip results.
//
// NOTE(review): criteria are attached through successive Where calls; confirm
// that the Clover version in use combines them rather than keeping only the
// most recent one.
func applyConditions[T repositories.ModelType](
	q *clover_q.Query,
	query repositories.Query[T],
) *clover_q.Query {
	// Apply conditions specified in the query.
	for _, condition := range query.Conditions {
		// Documents are stored under their JSON names, not the Go field names.
		field := clover_q.Field(fieldJSONTag[T](condition.Field))
		switch condition.Operator {
		case "=":
			q = q.Where(field.Eq(condition.Value))
		case ">":
			q = q.Where(field.Gt(condition.Value))
		case ">=":
			q = q.Where(field.GtEq(condition.Value))
		case "<":
			q = q.Where(field.Lt(condition.Value))
		case "<=":
			q = q.Where(field.LtEq(condition.Value))
		case "!=":
			q = q.Where(field.Neq(condition.Value))
		case "IN":
			if values, ok := condition.Value.([]interface{}); ok {
				q = q.Where(field.In(values...))
			}
		case "LIKE":
			if value, ok := condition.Value.(string); ok {
				q = q.Where(field.Like(value))
			}
		}
	}

	// Apply conditions based on non-zero values in the query instance.
	if !repositories.IsEmptyValue(query.Instance) {
		exampleType := reflect.TypeOf(query.Instance)
		exampleValue := reflect.ValueOf(query.Instance)
		for i := 0; i < exampleType.NumField(); i++ {
			fieldValue := exampleValue.Field(i).Interface()
			if repositories.IsEmptyValue(fieldValue) {
				continue
			}
			fieldName := fieldJSONTag[T](exampleType.Field(i).Name)
			q = q.Where(clover_q.Field(fieldName).Eq(fieldValue))
		}
	}

	// Apply sorting if specified in the query.
	if query.SortBy != "" {
		dir := 1
		sortField := query.SortBy
		if sortField[0] == '-' {
			dir = -1
			sortField = sortField[1:]
		}
		// A bare "-" names no field; skip sorting instead of sorting on "".
		if sortField != "" {
			q = q.Sort(clover_q.SortOption{
				Field:     fieldJSONTag[T](sortField),
				Direction: dir,
			})
		}
	}

	// Apply limit if specified in the query.
	if query.Limit > 0 {
		q = q.Limit(query.Limit)
	}

	// Apply offset if specified in the query. Skip is used so that the
	// offset does not overwrite the limit set above.
	if query.Offset > 0 {
		q = q.Skip(query.Offset)
	}

	return q
}
package repositories_clover
import (
"github.com/ostafen/clover/v2"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/models"
)
// PeerInfoRepositoryClover stores models.PeerInfo records in Clover by
// embedding the generic repository for that model.
type PeerInfoRepositoryClover struct {
	repositories.GenericRepository[models.PeerInfo]
}

// NewPeerInfoRepository builds a PeerInfoRepositoryClover on top of db and
// returns it as a repositories.PeerInfoRepository.
func NewPeerInfoRepository(db *clover.DB) repositories.PeerInfoRepository {
	generic := NewGenericRepository[models.PeerInfo](db)
	return &PeerInfoRepositoryClover{GenericRepository: generic}
}
// MachineRepositoryClover stores models.Machine records in Clover by
// embedding the generic repository for that model.
type MachineRepositoryClover struct {
	repositories.GenericRepository[models.Machine]
}

// NewMachineRepository builds a MachineRepositoryClover on top of db and
// returns it as a repositories.MachineRepository.
func NewMachineRepository(db *clover.DB) repositories.MachineRepository {
	generic := NewGenericRepository[models.Machine](db)
	return &MachineRepositoryClover{GenericRepository: generic}
}
// FreeResourcesRepositoryClover stores the models.FreeResources entity in
// Clover by embedding the generic single entity repository for that model.
type FreeResourcesRepositoryClover struct {
	repositories.GenericEntityRepository[models.FreeResources]
}

// NewFreeResourcesRepository builds a FreeResourcesRepositoryClover on top of
// db and returns it as a repositories.FreeResourcesRepository.
func NewFreeResourcesRepository(db *clover.DB) repositories.FreeResourcesRepository {
	entity := NewGenericEntityRepository[models.FreeResources](db)
	return &FreeResourcesRepositoryClover{GenericEntityRepository: entity}
}
// AvailableResourcesRepositoryClover stores the models.AvailableResources
// entity in Clover by embedding the generic single entity repository for
// that model.
type AvailableResourcesRepositoryClover struct {
	repositories.GenericEntityRepository[models.AvailableResources]
}

// NewAvailableResourcesRepository builds an AvailableResourcesRepositoryClover
// on top of db and returns it as a repositories.AvailableResourcesRepository.
func NewAvailableResourcesRepository(db *clover.DB) repositories.AvailableResourcesRepository {
	entity := NewGenericEntityRepository[models.AvailableResources](db)
	return &AvailableResourcesRepositoryClover{GenericEntityRepository: entity}
}
// ServicesRepositoryClover stores models.Services records in Clover by
// embedding the generic repository for that model.
type ServicesRepositoryClover struct {
	repositories.GenericRepository[models.Services]
}

// NewServicesRepository builds a ServicesRepositoryClover on top of db and
// returns it as a repositories.ServicesRepository.
func NewServicesRepository(db *clover.DB) repositories.ServicesRepository {
	generic := NewGenericRepository[models.Services](db)
	return &ServicesRepositoryClover{GenericRepository: generic}
}
// ServiceResourceRequirementsRepositoryClover stores
// models.ServiceResourceRequirements records in Clover by embedding the
// generic repository for that model.
type ServiceResourceRequirementsRepositoryClover struct {
	repositories.GenericRepository[models.ServiceResourceRequirements]
}

// NewServiceResourceRequirementsRepository builds a
// ServiceResourceRequirementsRepositoryClover on top of db and returns it as
// a repositories.ServiceResourceRequirementsRepository.
func NewServiceResourceRequirementsRepository(
	db *clover.DB,
) repositories.ServiceResourceRequirementsRepository {
	generic := NewGenericRepository[models.ServiceResourceRequirements](db)
	return &ServiceResourceRequirementsRepositoryClover{GenericRepository: generic}
}
// Libp2pInfoRepositoryClover stores the models.Libp2pInfo entity in Clover by
// embedding the generic single entity repository for that model.
type Libp2pInfoRepositoryClover struct {
	repositories.GenericEntityRepository[models.Libp2pInfo]
}

// NewLibp2pInfoRepository builds a Libp2pInfoRepositoryClover on top of db
// and returns it as a repositories.Libp2pInfoRepository.
func NewLibp2pInfoRepository(db *clover.DB) repositories.Libp2pInfoRepository {
	entity := NewGenericEntityRepository[models.Libp2pInfo](db)
	return &Libp2pInfoRepositoryClover{GenericEntityRepository: entity}
}
// MachineUUIDRepositoryClover stores the models.MachineUUID entity in Clover
// by embedding the generic single entity repository for that model.
type MachineUUIDRepositoryClover struct {
	repositories.GenericEntityRepository[models.MachineUUID]
}

// NewMachineUUIDRepository builds a MachineUUIDRepositoryClover on top of db
// and returns it as a repositories.MachineUUIDRepository.
func NewMachineUUIDRepository(db *clover.DB) repositories.MachineUUIDRepository {
	entity := NewGenericEntityRepository[models.MachineUUID](db)
	return &MachineUUIDRepositoryClover{GenericEntityRepository: entity}
}
// ConnectionRepositoryClover stores models.Connection records in Clover by
// embedding the generic repository for that model.
type ConnectionRepositoryClover struct {
	repositories.GenericRepository[models.Connection]
}

// NewConnectionRepository builds a ConnectionRepositoryClover on top of db
// and returns it as a repositories.ConnectionRepository.
func NewConnectionRepository(db *clover.DB) repositories.ConnectionRepository {
	generic := NewGenericRepository[models.Connection](db)
	return &ConnectionRepositoryClover{GenericRepository: generic}
}
// ElasticTokenRepositoryClover stores models.ElasticToken records in Clover
// by embedding the generic repository for that model.
type ElasticTokenRepositoryClover struct {
	repositories.GenericRepository[models.ElasticToken]
}

// NewElasticTokenRepository builds an ElasticTokenRepositoryClover on top of
// db and returns it as a repositories.ElasticTokenRepository.
func NewElasticTokenRepository(db *clover.DB) repositories.ElasticTokenRepository {
	generic := NewGenericRepository[models.ElasticToken](db)
	return &ElasticTokenRepositoryClover{GenericRepository: generic}
}
package repositories_clover
import (
clover "github.com/ostafen/clover/v2"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/models"
)
// StorageVolumeRepositoryClover stores models.StorageVolume records in Clover
// by embedding the generic repository for that model.
type StorageVolumeRepositoryClover struct {
	repositories.GenericRepository[models.StorageVolume]
}

// NewStorageVolumeRepository builds a StorageVolumeRepositoryClover on top of
// db and returns it as a repositories.StorageVolumeRepository.
func NewStorageVolumeRepository(db *clover.DB) repositories.StorageVolumeRepository {
	generic := NewGenericRepository[models.StorageVolume](db)
	return &StorageVolumeRepositoryClover{GenericRepository: generic}
}
package repositories_clover
import (
"encoding/json"
"errors"
"reflect"
"strings"
"github.com/ostafen/clover/v2"
clover_d "github.com/ostafen/clover/v2/document"
"gitlab.com/nunet/device-management-service/db/repositories"
)
func handleDBError(err error) error {
if err != nil {
switch err {
case clover.ErrDocumentNotExist:
return repositories.NotFoundError
case clover.ErrDuplicateKey:
return repositories.InvalidDataError
case repositories.ErrParsingModel:
return err
default:
return errors.Join(repositories.DatabaseError, err)
}
}
return nil
}
// toCloverDoc converts data into a Clover document by round-tripping it
// through JSON, so the document keys are the model's JSON field names.
// If either the marshalling or the unmarshalling step fails, an empty
// document is returned instead; the error itself is not reported.
func toCloverDoc[T repositories.ModelType](data T) *clover_d.Document {
	raw, marshalErr := json.Marshal(data)
	if marshalErr != nil {
		return clover_d.NewDocument()
	}

	fields := map[string]interface{}{}
	if unmarshalErr := json.Unmarshal(raw, &fields); unmarshalErr != nil {
		return clover_d.NewDocument()
	}

	return clover_d.NewDocumentOf(fields)
}
// toModel unmarshals doc into a value of type T. For regular (non-entity)
// repositories it additionally copies the document's object id into the
// model's "ID" field via repositories.UpdateField. Entity repositories skip
// that step because their models might not have an ID field at all.
func toModel[T repositories.ModelType](doc *clover_d.Document, isEntityRepo bool) (T, error) {
	var model T
	if err := doc.Unmarshal(&model); err != nil {
		return model, err
	}
	if isEntityRepo {
		return model, nil
	}

	withID, err := repositories.UpdateField(model, "ID", doc.ObjectId())
	if err != nil {
		return withID, err
	}
	return withID, nil
}
// fieldJSONTag returns the JSON name of the struct field called field on T,
// i.e. the part of its `json` tag before the first comma.
//
// The Go field name is returned unchanged when T (after following pointers)
// is not a struct, when the field does not exist, when it has no `json` tag,
// or when the tag carries only options (for example `json:",omitempty"`), in
// which case encoding/json uses the Go field name as well.
func fieldJSONTag[T repositories.ModelType](field string) string {
	typ := reflect.TypeOf(*new(T))
	// Look through pointer types so that T = *Model behaves like T = Model.
	for typ != nil && typ.Kind() == reflect.Pointer {
		typ = typ.Elem()
	}
	// TypeOf yields nil for an interface-typed T; FieldByName panics on
	// anything that is not a struct.
	if typ == nil || typ.Kind() != reflect.Struct {
		return field
	}

	structField, found := typ.FieldByName(field)
	if !found {
		return field
	}
	tag, tagged := structField.Tag.Lookup("json")
	if !tagged {
		return field
	}
	if name, _, _ := strings.Cut(tag, ","); name != "" {
		return name
	}
	return field
}
package repositories
import (
"context"
)
// QueryCondition describes a single filter of the form "Field Operator Value".
// Values are normally built with the helper constructors in this package
// (EQ, GT, GTE, LT, LTE, IN, LIKE) and collected in Query.Conditions.
type QueryCondition struct {
Field string // Field specifies the database or struct field to which the condition applies.
Operator string // Operator defines the comparison operator (e.g., "=", ">", "<").
Value interface{} // Value is the expected value for the given field.
}
// ModelType is the type constraint used for repository models. It is the
// empty interface, so it places no restriction on the model type.
type ModelType interface{}
// Query is a struct that wraps both the instance of type T and additional query parameters.
// It is used to construct queries with conditions, sorting, limiting, and offsetting.
// The zero value is an empty query with no filters, sorting or paging.
type Query[T any] struct {
Instance T // Instance is an optional object of type T used to build conditions from its fields.
Conditions []QueryCondition // Conditions represent the conditions applied to the query.
SortBy string // SortBy specifies the field by which the query results should be sorted.
Limit int // Limit specifies the maximum number of results to return.
Offset int // Offset specifies the number of results to skip before starting to return data.
}
// GenericRepository is an interface defining basic CRUD operations and standard querying methods.
// It is generic over the model type T and is meant to be implemented once
// per storage backend and embedded in the model-specific repositories.
type GenericRepository[T ModelType] interface {
// Create adds a new record to the repository.
// It returns the record as stored.
Create(ctx context.Context, data T) (T, error)
// Get retrieves a record by its identifier.
Get(ctx context.Context, id interface{}) (T, error)
// Update modifies a record by its identifier.
// It returns the resulting record.
Update(ctx context.Context, id interface{}, data T) (T, error)
// Delete removes a record by its identifier.
Delete(ctx context.Context, id interface{}) error
// Find retrieves a single record based on a query.
Find(ctx context.Context, query Query[T]) (T, error)
// FindAll retrieves multiple records based on a query.
FindAll(ctx context.Context, query Query[T]) ([]T, error)
// GetQuery returns an empty query instance for the repository's type.
GetQuery() Query[T]
}
// EQ builds the condition "field = value".
func EQ(field string, value interface{}) QueryCondition {
	cond := QueryCondition{Operator: "="}
	cond.Field = field
	cond.Value = value
	return cond
}
// GT builds the condition "field > value".
func GT(field string, value interface{}) QueryCondition {
	cond := QueryCondition{Operator: ">"}
	cond.Field = field
	cond.Value = value
	return cond
}
// GTE builds the condition "field >= value".
func GTE(field string, value interface{}) QueryCondition {
	cond := QueryCondition{Operator: ">="}
	cond.Field = field
	cond.Value = value
	return cond
}
// LT builds the condition "field < value".
func LT(field string, value interface{}) QueryCondition {
	cond := QueryCondition{Operator: "<"}
	cond.Field = field
	cond.Value = value
	return cond
}
// LTE builds the condition "field <= value".
func LTE(field string, value interface{}) QueryCondition {
	cond := QueryCondition{Operator: "<="}
	cond.Field = field
	cond.Value = value
	return cond
}
// IN builds the condition "field IN values". The slice is stored in the
// condition's Value as given, without being copied.
func IN(field string, values []interface{}) QueryCondition {
	cond := QueryCondition{Operator: "IN"}
	cond.Field = field
	cond.Value = values
	return cond
}
// LIKE builds the condition "field LIKE pattern".
func LIKE(field, pattern string) QueryCondition {
	cond := QueryCondition{Operator: "LIKE"}
	cond.Field = field
	cond.Value = pattern
	return cond
}
package repositories_gorm
import (
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/models"
)
// DeploymentRequestFlatRepositoryGORM stores models.DeploymentRequestFlat
// records through GORM by embedding the generic repository for that model.
type DeploymentRequestFlatRepositoryGORM struct {
	repositories.GenericRepository[models.DeploymentRequestFlat]
}

// NewDeploymentRequestFlatRepository builds a DeploymentRequestFlatRepositoryGORM
// on top of db and returns it as a repositories.DeploymentRequestFlatRepository.
func NewDeploymentRequestFlatRepository(db *gorm.DB) repositories.DeploymentRequestFlatRepository {
	generic := NewGenericRepository[models.DeploymentRequestFlat](db)
	return &DeploymentRequestFlatRepositoryGORM{GenericRepository: generic}
}
package repositories_gorm
import (
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/models"
)
// RequestTrackerRepositoryGORM stores models.RequestTracker records through
// GORM by embedding the generic repository for that model.
type RequestTrackerRepositoryGORM struct {
	repositories.GenericRepository[models.RequestTracker]
}

// NewRequestTrackerRepository builds a RequestTrackerRepositoryGORM on top of
// db and returns it as a repositories.RequestTrackerRepository.
func NewRequestTrackerRepository(db *gorm.DB) repositories.RequestTrackerRepository {
	generic := NewGenericRepository[models.RequestTracker](db)
	return &RequestTrackerRepositoryGORM{GenericRepository: generic}
}
package repositories_gorm
import (
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/models"
)
// VirtualMachineRepositoryGORM stores models.VirtualMachine records through
// GORM by embedding the generic repository for that model.
type VirtualMachineRepositoryGORM struct {
	repositories.GenericRepository[models.VirtualMachine]
}

// NewVirtualMachineRepository builds a VirtualMachineRepositoryGORM on top of
// db and returns it as a repositories.VirtualMachineRepository.
func NewVirtualMachineRepository(db *gorm.DB) repositories.VirtualMachineRepository {
	generic := NewGenericRepository[models.VirtualMachine](db)
	return &VirtualMachineRepositoryGORM{GenericRepository: generic}
}
package repositories_gorm
import (
"context"
"fmt"
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
)
const (
// createdAtField is the model field name that the entity repository's Get
// sorts on (descending) to pick the record it returns.
createdAtField = "CreatedAt"
)
// GenericEntityRepositoryGORM is a generic single entity repository implementation using GORM as an ORM.
// It is intended to be embedded in single entity model repositories to provide basic database operations.
// Save creates a new row on every call and Get returns the row that sorts
// first by CreatedAt in descending order.
type GenericEntityRepositoryGORM[T repositories.ModelType] struct {
db *gorm.DB // db is the GORM database instance.
}
// NewGenericEntityRepository returns a GORM-backed single entity repository
// for T that operates on db.
func NewGenericEntityRepository[T repositories.ModelType](
	db *gorm.DB,
) repositories.GenericEntityRepository[T] {
	repo := new(GenericEntityRepositoryGORM[T])
	repo.db = db
	return repo
}
// GetQuery returns a zero-valued Query for T that callers can fill in with
// conditions, sorting, limit and offset.
func (repo *GenericEntityRepositoryGORM[T]) GetQuery() repositories.Query[T] {
	var empty repositories.Query[T]
	return empty
}
// Save creates a new row from data and returns data as left by GORM's Create
// call together with the normalized error.
func (repo *GenericEntityRepositoryGORM[T]) Save(ctx context.Context, data T) (T, error) {
	tx := repo.db.WithContext(ctx).Create(&data)
	saveErr := tx.Error
	return data, handleDBError(saveErr)
}
// Get returns the first row when ordering by the CreatedAt field in
// descending order, together with the normalized error.
func (repo *GenericEntityRepositoryGORM[T]) Get(ctx context.Context) (T, error) {
	var latest T

	newestFirst := repo.GetQuery()
	newestFirst.SortBy = fmt.Sprintf("-%s", createdAtField)

	tx := applyConditions(repo.db.WithContext(ctx), newestFirst)
	lookupErr := tx.First(&latest).Error
	return latest, handleDBError(lookupErr)
}
// Clear deletes every row of T whose id is not NULL, which removes the
// current record as well as its earlier versions. The GORM error is returned
// as-is.
func (repo *GenericEntityRepositoryGORM[T]) Clear(ctx context.Context) error {
	tx := repo.db.WithContext(ctx).Delete(new(T), "id IS NOT NULL")
	return tx.Error
}
// History returns the rows of T selected by query (conditions, sorting, limit
// and offset are applied through applyConditions) together with the
// normalized error.
func (repo *GenericEntityRepositoryGORM[T]) History(
	ctx context.Context,
	query repositories.Query[T],
) ([]T, error) {
	var versions []T

	scope := repo.db.WithContext(ctx).Model(new(T))
	scope = applyConditions(scope, query)

	findErr := scope.Find(&versions).Error
	return versions, handleDBError(findErr)
}
package repositories_gorm
import (
"context"
"fmt"
"reflect"
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
)
// GenericRepositoryGORM is a generic repository implementation using GORM as an ORM.
// It is intended to be embedded in model repositories to provide basic database operations.
type GenericRepositoryGORM[T repositories.ModelType] struct {
// db is the GORM database handle every operation is issued on.
db *gorm.DB
}
// NewGenericRepository returns a GORM-backed repository for T that operates on db.
func NewGenericRepository[T repositories.ModelType](db *gorm.DB) repositories.GenericRepository[T] {
	repo := new(GenericRepositoryGORM[T])
	repo.db = db
	return repo
}
// GetQuery returns a zero-valued Query for T that callers can fill in with
// conditions, sorting, limit and offset.
func (repo *GenericRepositoryGORM[T]) GetQuery() repositories.Query[T] {
	var empty repositories.Query[T]
	return empty
}
// Create inserts data as a new row and returns data as left by GORM's Create
// call together with the normalized error.
func (repo *GenericRepositoryGORM[T]) Create(ctx context.Context, data T) (T, error) {
	tx := repo.db.WithContext(ctx).Create(&data)
	createErr := tx.Error
	return data, handleDBError(createErr)
}
// Get loads the first row whose id column equals id and returns it together
// with the normalized error.
func (repo *GenericRepositoryGORM[T]) Get(ctx context.Context, id interface{}) (T, error) {
	var record T
	lookupErr := repo.db.WithContext(ctx).First(&record, "id = ?", id).Error
	return record, handleDBError(lookupErr)
}
// Update applies data to the row whose id column equals id through GORM's
// Updates call. The value returned is the data argument itself, not a fresh
// read of the row, together with the normalized error.
func (repo *GenericRepositoryGORM[T]) Update(ctx context.Context, id interface{}, data T) (T, error) {
	target := repo.db.WithContext(ctx).Model(new(T)).Where("id = ?", id)
	updateErr := target.Updates(data).Error
	return data, handleDBError(updateErr)
}
// Delete issues a GORM delete for the row whose id column equals id and
// returns the GORM error as-is.
func (repo *GenericRepositoryGORM[T]) Delete(ctx context.Context, id interface{}) error {
	tx := repo.db.WithContext(ctx).Delete(new(T), "id = ?", id)
	return tx.Error
}
// Find returns the first row of T selected by query (conditions, sorting,
// limit and offset are applied through applyConditions) together with the
// normalized error.
func (repo *GenericRepositoryGORM[T]) Find(
	ctx context.Context,
	query repositories.Query[T],
) (T, error) {
	var record T

	scope := repo.db.WithContext(ctx).Model(new(T))
	scope = applyConditions(scope, query)

	findErr := scope.First(&record).Error
	return record, handleDBError(findErr)
}
// FindAll returns all rows of T selected by query (conditions, sorting, limit
// and offset are applied through applyConditions) together with the
// normalized error.
func (repo *GenericRepositoryGORM[T]) FindAll(
	ctx context.Context,
	query repositories.Query[T],
) ([]T, error) {
	var records []T

	scope := repo.db.WithContext(ctx).Model(new(T))
	scope = applyConditions(scope, query)

	findErr := scope.Find(&records).Error
	return records, handleDBError(findErr)
}
// applyConditions extends the GORM statement db with everything described by
// query and returns the resulting statement.
//
//   - Each entry of query.Conditions becomes a WHERE clause of the form
//     "<column> <operator> ?" with the condition's value bound as parameter.
//   - Every non-empty field of query.Instance becomes an equality clause.
//   - query.SortBy selects the sort column; a leading '-' requests descending
//     order.
//   - query.Limit and query.Offset, when positive, bound and skip results.
//
// Table and column names are derived from T's type name and the field names
// through db's naming strategy.
func applyConditions[T any](db *gorm.DB, query repositories.Query[T]) *gorm.DB {
	// Retrieve the table name using the GORM naming strategy.
	tableName := db.NamingStrategy.TableName(reflect.TypeOf(*new(T)).Name())

	// Apply conditions specified in the query.
	for _, condition := range query.Conditions {
		columnName := db.NamingStrategy.ColumnName(tableName, condition.Field)
		// NOTE(review): the operator is spliced into the SQL text verbatim;
		// only the value is bound as a parameter. Operators must therefore
		// never come from untrusted input.
		db = db.Where(
			fmt.Sprintf("%s %s ?", columnName, condition.Operator),
			condition.Value,
		)
	}

	// Apply conditions based on non-zero values in the query instance.
	if !repositories.IsEmptyValue(query.Instance) {
		exampleType := reflect.TypeOf(query.Instance)
		exampleValue := reflect.ValueOf(query.Instance)
		for i := 0; i < exampleType.NumField(); i++ {
			fieldValue := exampleValue.Field(i).Interface()
			if repositories.IsEmptyValue(fieldValue) {
				continue
			}
			columnName := db.NamingStrategy.ColumnName(tableName, exampleType.Field(i).Name)
			db = db.Where(fmt.Sprintf("%s = ?", columnName), fieldValue)
		}
	}

	// Apply sorting if specified in the query.
	if query.SortBy != "" {
		dir := "ASC"
		sortField := query.SortBy
		if sortField[0] == '-' {
			sortField = sortField[1:]
			dir = "DESC"
		}
		// A bare "-" names no column; skip ordering instead of emitting an
		// ORDER BY clause with an empty column name.
		if sortField != "" {
			columnName := db.NamingStrategy.ColumnName(tableName, sortField)
			db = db.Order(fmt.Sprintf("%s.%s %s", tableName, columnName, dir))
		}
	}

	// Apply limit if specified in the query.
	if query.Limit > 0 {
		db = db.Limit(query.Limit)
	}

	// Apply offset if specified in the query. Offset is used so that the
	// value does not overwrite the limit set above.
	if query.Offset > 0 {
		db = db.Offset(query.Offset)
	}

	return db
}
package repositories_gorm
import (
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/models"
)
// PeerInfoRepositoryGORM is the GORM-backed PeerInfoRepository. All of its
// behaviour comes from the embedded generic repository for models.PeerInfo.
type PeerInfoRepositoryGORM struct {
	repositories.GenericRepository[models.PeerInfo]
}

// NewPeerInfoRepository builds a PeerInfoRepositoryGORM on top of db and
// returns it as a repositories.PeerInfoRepository.
func NewPeerInfoRepository(db *gorm.DB) repositories.PeerInfoRepository {
	generic := NewGenericRepository[models.PeerInfo](db)
	return &PeerInfoRepositoryGORM{GenericRepository: generic}
}
// MachineRepositoryGORM is the GORM-backed MachineRepository. All of its
// behaviour comes from the embedded generic repository for models.Machine.
type MachineRepositoryGORM struct {
	repositories.GenericRepository[models.Machine]
}

// NewMachineRepository builds a MachineRepositoryGORM on top of db and
// returns it as a repositories.MachineRepository.
func NewMachineRepository(db *gorm.DB) repositories.MachineRepository {
	generic := NewGenericRepository[models.Machine](db)
	return &MachineRepositoryGORM{GenericRepository: generic}
}
// FreeResourcesRepositoryGORM is the GORM-backed FreeResourcesRepository.
// All of its behaviour comes from the embedded generic entity repository for
// models.FreeResources.
type FreeResourcesRepositoryGORM struct {
	repositories.GenericEntityRepository[models.FreeResources]
}

// NewFreeResourcesRepository builds a FreeResourcesRepositoryGORM on top of
// db and returns it as a repositories.FreeResourcesRepository.
func NewFreeResourcesRepository(db *gorm.DB) repositories.FreeResourcesRepository {
	entity := NewGenericEntityRepository[models.FreeResources](db)
	return &FreeResourcesRepositoryGORM{GenericEntityRepository: entity}
}
// AvailableResourcesRepositoryGORM is the GORM-backed
// AvailableResourcesRepository. All of its behaviour comes from the embedded
// generic entity repository for models.AvailableResources.
type AvailableResourcesRepositoryGORM struct {
	repositories.GenericEntityRepository[models.AvailableResources]
}

// NewAvailableResourcesRepository builds an AvailableResourcesRepositoryGORM
// on top of db and returns it as a repositories.AvailableResourcesRepository.
func NewAvailableResourcesRepository(db *gorm.DB) repositories.AvailableResourcesRepository {
	entity := NewGenericEntityRepository[models.AvailableResources](db)
	return &AvailableResourcesRepositoryGORM{GenericEntityRepository: entity}
}
// ServicesRepositoryGORM is the GORM-backed ServicesRepository. All of its
// behaviour comes from the embedded generic repository for models.Services.
type ServicesRepositoryGORM struct {
	repositories.GenericRepository[models.Services]
}

// NewServicesRepository builds a ServicesRepositoryGORM on top of db and
// returns it as a repositories.ServicesRepository.
func NewServicesRepository(db *gorm.DB) repositories.ServicesRepository {
	generic := NewGenericRepository[models.Services](db)
	return &ServicesRepositoryGORM{GenericRepository: generic}
}
// ServiceResourceRequirementsRepositoryGORM is the GORM-backed
// ServiceResourceRequirementsRepository. All of its behaviour comes from the
// embedded generic repository for models.ServiceResourceRequirements.
type ServiceResourceRequirementsRepositoryGORM struct {
	repositories.GenericRepository[models.ServiceResourceRequirements]
}

// NewServiceResourceRequirementsRepository builds a
// ServiceResourceRequirementsRepositoryGORM on top of db and returns it as a
// repositories.ServiceResourceRequirementsRepository.
func NewServiceResourceRequirementsRepository(
	db *gorm.DB,
) repositories.ServiceResourceRequirementsRepository {
	generic := NewGenericRepository[models.ServiceResourceRequirements](db)
	return &ServiceResourceRequirementsRepositoryGORM{GenericRepository: generic}
}
// Libp2pInfoRepositoryGORM is the GORM-backed Libp2pInfoRepository. All of
// its behaviour comes from the embedded generic entity repository for
// models.Libp2pInfo.
type Libp2pInfoRepositoryGORM struct {
	repositories.GenericEntityRepository[models.Libp2pInfo]
}

// NewLibp2pInfoRepository builds a Libp2pInfoRepositoryGORM on top of db and
// returns it as a repositories.Libp2pInfoRepository.
func NewLibp2pInfoRepository(db *gorm.DB) repositories.Libp2pInfoRepository {
	entity := NewGenericEntityRepository[models.Libp2pInfo](db)
	return &Libp2pInfoRepositoryGORM{GenericEntityRepository: entity}
}
// MachineUUIDRepositoryGORM is the GORM-backed MachineUUIDRepository. All of
// its behaviour comes from the embedded generic entity repository for
// models.MachineUUID.
type MachineUUIDRepositoryGORM struct {
	repositories.GenericEntityRepository[models.MachineUUID]
}

// NewMachineUUIDRepository builds a MachineUUIDRepositoryGORM on top of db
// and returns it as a repositories.MachineUUIDRepository.
func NewMachineUUIDRepository(db *gorm.DB) repositories.MachineUUIDRepository {
	entity := NewGenericEntityRepository[models.MachineUUID](db)
	return &MachineUUIDRepositoryGORM{GenericEntityRepository: entity}
}
// ConnectionRepositoryGORM is the GORM-backed ConnectionRepository. All of
// its behaviour comes from the embedded generic repository for
// models.Connection.
type ConnectionRepositoryGORM struct {
	repositories.GenericRepository[models.Connection]
}

// NewConnectionRepository builds a ConnectionRepositoryGORM on top of db and
// returns it as a repositories.ConnectionRepository.
func NewConnectionRepository(db *gorm.DB) repositories.ConnectionRepository {
	generic := NewGenericRepository[models.Connection](db)
	return &ConnectionRepositoryGORM{GenericRepository: generic}
}
// ElasticTokenRepositoryGORM is the GORM-backed ElasticTokenRepository. All
// of its behaviour comes from the embedded generic repository for
// models.ElasticToken.
type ElasticTokenRepositoryGORM struct {
	repositories.GenericRepository[models.ElasticToken]
}

// NewElasticTokenRepository builds an ElasticTokenRepositoryGORM on top of db
// and returns it as a repositories.ElasticTokenRepository.
func NewElasticTokenRepository(db *gorm.DB) repositories.ElasticTokenRepository {
	generic := NewGenericRepository[models.ElasticToken](db)
	return &ElasticTokenRepositoryGORM{GenericRepository: generic}
}
package repositories_gorm
import (
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/models"
)
// OnboardingParamsRepositoryGORM is the GORM-backed
// OnboardingParamsRepository. All of its behaviour comes from the embedded
// generic entity repository for models.OnboardingConfig.
type OnboardingParamsRepositoryGORM struct {
	repositories.GenericEntityRepository[models.OnboardingConfig]
}

// NewOnboardingParamsRepository builds an OnboardingParamsRepositoryGORM on
// top of db and returns it as a repositories.OnboardingParamsRepository.
func NewOnboardingParamsRepository(db *gorm.DB) repositories.OnboardingParamsRepository {
	entity := NewGenericEntityRepository[models.OnboardingConfig](db)
	return &OnboardingParamsRepositoryGORM{GenericEntityRepository: entity}
}
package repositories_gorm
import (
"errors"
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
)
// structFieldNameDeletedAt is the Go struct field name "DeletedAt".
// NOTE(review): presumably used elsewhere in this package to recognise GORM's
// soft-delete field by name — confirm against its callers.
const structFieldNameDeletedAt = "DeletedAt"
// handleDBError translates a GORM error into the repository error vocabulary.
//
// Mapping:
//   - nil                                                      -> nil
//   - gorm.ErrRecordNotFound                                   -> repositories.NotFoundError
//   - gorm.ErrInvalidData / ErrInvalidField / ErrInvalidValue  -> repositories.InvalidDataError
//   - repositories.ErrParsingModel                             -> returned unchanged
//   - anything else                                            -> repositories.DatabaseError joined with err
//
// Matching uses errors.Is rather than ==, so sentinel errors that were wrapped
// (for example with fmt.Errorf("...: %w", err)) are still recognised.
func handleDBError(err error) error {
	switch {
	case err == nil:
		return nil
	case errors.Is(err, gorm.ErrRecordNotFound):
		return repositories.NotFoundError
	case errors.Is(err, gorm.ErrInvalidData),
		errors.Is(err, gorm.ErrInvalidField),
		errors.Is(err, gorm.ErrInvalidValue):
		return repositories.InvalidDataError
	case errors.Is(err, repositories.ErrParsingModel):
		return err
	default:
		return errors.Join(repositories.DatabaseError, err)
	}
}
package repositories
import (
"fmt"
"reflect"
)
// UpdateField sets the field called fieldName on input to newValue using
// reflection and returns the (possibly updated) input.
//
// input may be a struct or a pointer to a struct. For a struct value the
// update is made on the function's own copy, which is then returned; for a
// pointer the pointed-to struct is modified in place.
//
// An error is returned, and input is returned unmodified, when:
//   - input is neither a struct nor a non-nil pointer to a struct,
//   - the struct has no field named fieldName,
//   - the field cannot be set (for example it is unexported),
//   - newValue is nil or cannot be converted to the field's type.
func UpdateField[T interface{}](input T, fieldName string, newValue interface{}) (T, error) {
	// Use reflection to get the struct's field
	val := reflect.ValueOf(input)
	if val.Kind() == reflect.Ptr {
		// If input is a pointer, get the underlying element
		val = val.Elem()
	} else {
		// If input is not a pointer, ensure it's addressable
		val = reflect.ValueOf(&input).Elem()
	}

	// Check if the input is a struct (a nil pointer yields an invalid Value
	// and is rejected here as well).
	if val.Kind() != reflect.Struct {
		return input, fmt.Errorf("Not a struct: %T", input)
	}

	// Get the field by name
	field := val.FieldByName(fieldName)
	if !field.IsValid() {
		return input, fmt.Errorf("Field not found: %v", fieldName)
	}

	// Check if the field is settable
	if !field.CanSet() {
		return input, fmt.Errorf("Field not settable: %v", fieldName)
	}

	// Check if the value can be converted. A nil newValue produces an invalid
	// reflect.Value, which must be rejected before any type inspection;
	// CanConvert also covers value-dependent conversions that a pure type
	// check would let through.
	newVal := reflect.ValueOf(newValue)
	if !newVal.IsValid() || !newVal.CanConvert(field.Type()) {
		return input, fmt.Errorf(
			"Incompatible conversion: %v -> %v; value: %v",
			reflect.TypeOf(newValue), field.Type(), newValue,
		)
	}

	// Convert the new value to the field type and store it.
	field.Set(newVal.Convert(field.Type()))
	return input, nil
}
// IsEmptyValue reports whether value is "empty": an untyped nil, a nil
// pointer, a zero value, or a pointer to a zero value (as defined by
// reflect.Value.IsZero).
func IsEmptyValue(value interface{}) bool {
	// Check if the value is nil
	if value == nil {
		return true
	}

	val := reflect.ValueOf(value)

	// If the value is a pointer, dereference it to get the underlying element.
	// A typed nil pointer has no element; calling IsZero on the resulting
	// invalid Value would panic, so treat it as empty here.
	if val.Kind() == reflect.Ptr {
		if val.IsNil() {
			return true
		}
		val = val.Elem()
	}

	// Check if the value is zero (empty) based on its kind
	return val.IsZero()
}
package nunet
import (
"fmt"
"strconv"
"strings"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/transform"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/tree"
)
// NewNuNetTransformer creates the transformer for the NuNet configuration.
// The transformer functions are registered in three groups; the groups are
// passed to transform.NewTransformer in the order listed here.
func NewNuNetTransformer() transform.Transformer {
	// Group 1: turn keyed maps into slices.
	mapsToSlices := map[tree.Path]transform.TransformerFunc{
		"jobs":               TransformJobs,
		"jobs.**.children":   TransformJobs,
		"jobs.**.volumes":    TransformVolumes,
		"jobs.**.networks":   TransformNetworks,
	}

	// Group 2: rewrite the individual list entries.
	listEntries := map[tree.Path]transform.TransformerFunc{
		"jobs.**.volumes.[]":   TransformVolume,
		"jobs.**.networks.[]":  TransformNetwork,
		"jobs.**.libraries.[]": TransformLibrary,
	}

	// Group 3: reshape execution and remote-volume settings.
	specConfigs := map[tree.Path]transform.TransformerFunc{
		"jobs.**.execution":          TransformExecution,
		"jobs.**.volumes.[].remote":  TransformVolumeRemote,
	}

	stages := []map[tree.Path]transform.TransformerFunc{
		mapsToSlices,
		listEntries,
		specConfigs,
	}
	return transform.NewTransformer(stages)
}
// TransformJobs converts the jobs map into a slice via transform.MapToSlice,
// which stores each map key in the entry's "name" field.
// A nil input yields (nil, nil); any non-map input is an error.
func TransformJobs(root *map[string]interface{}, data any, path tree.Path) (any, error) {
	switch jobs := data.(type) {
	case nil:
		return nil, nil
	case map[string]any:
		return transform.MapToSlice(jobs)
	default:
		return nil, fmt.Errorf("invalid jobs configuration: %v", data)
	}
}
// TransformVolumes converts the volumes map into a slice via
// transform.MapToSlice, which stores each map key in the entry's "name" field.
// A nil input yields (nil, nil); any non-map input is an error.
func TransformVolumes(root *map[string]interface{}, data any, path tree.Path) (any, error) {
	switch volumes := data.(type) {
	case nil:
		return nil, nil
	case map[string]any:
		return transform.MapToSlice(volumes)
	default:
		return nil, fmt.Errorf("invalid volumes configuration: %v", data)
	}
}
// TransformNetworks converts the networks map into a slice via
// transform.MapToSlice, which stores each map key in the entry's "name" field.
// A nil input yields (nil, nil); any non-map input is an error.
func TransformNetworks(root *map[string]interface{}, data any, path tree.Path) (any, error) {
	switch networks := data.(type) {
	case nil:
		return nil, nil
	case map[string]any:
		return transform.MapToSlice(networks)
	default:
		return nil, fmt.Errorf("invalid networks configuration: %v", data)
	}
}
// TransformExecution reshapes a flat engine map into {"type": ..., "params": {...}}.
// The "type" entry is copied as-is (nil when absent); every other key of the
// input is moved under "params". A nil input yields (nil, nil); a non-map
// input is an error.
func TransformExecution(root *map[string]interface{}, data any, path tree.Path) (any, error) {
	if data == nil {
		return nil, nil
	}
	engine, isMap := data.(map[string]any)
	if !isMap {
		return nil, fmt.Errorf("invalid engine configuration: %v", data)
	}

	params := make(map[string]any)
	for key, value := range engine {
		if key == "type" {
			continue
		}
		params[key] = value
	}

	return map[string]any{
		"type":   engine["type"],
		"params": params,
	}, nil
}
// TransformVolume normalises a single volume entry and merges in matching
// volume definitions from enclosing jobs.
//
// data is either a "name:mountpoint" string (exactly one ':'), which becomes
// {"name": ..., "mountpoint": ...}, or a map, which is used directly and
// modified in place. Anything else is an error.
//
// For every "children" segment in path, the "volumes" list of the job in
// front of that segment is looked up in root. Each list entry whose "name"
// equals this volume's "name" has all of its keys copied into the result,
// overwriting keys that are already present. Lookup errors are printed to
// standard output and that ancestor is skipped.
func TransformVolume(root *map[string]interface{}, data any, path tree.Path) (any, error) {
	var volumeCfg map[string]any

	switch typed := data.(type) {
	case map[string]any:
		volumeCfg = typed
	case string:
		// Short form: "name:mountpoint".
		fields := strings.Split(typed, ":")
		if len(fields) != 2 {
			return nil, fmt.Errorf("invalid volume configuration: %v", data)
		}
		volumeCfg = map[string]any{
			"name":       fields[0],
			"mountpoint": fields[1],
		}
	default:
		return nil, fmt.Errorf("invalid volume configuration: %v", data)
	}

	// Walk the path from the root; every "children" segment marks an ancestor
	// job whose volumes may define this volume.
	segments := path.Parts()
	for idx := range segments {
		if segments[idx] != "children" {
			continue
		}
		ancestor := tree.NewPath(segments[:idx]...)

		inherited, err := transform.GetConfigAtPath(*root, ancestor.Next("volumes"))
		if err != nil {
			fmt.Println("error: ", err)
			continue
		}

		// A non-slice value simply yields no candidates.
		candidates, _ := transform.ToAnySlice(inherited)
		for _, candidate := range candidates {
			parentVolume, isMap := candidate.(map[string]any)
			if !isMap || parentVolume["name"] != volumeCfg["name"] {
				continue
			}
			// Merge the configurations
			for key, value := range parentVolume {
				volumeCfg[key] = value
			}
		}
	}

	return volumeCfg, nil
}
// TransformVolumeRemote reshapes a volume's remote configuration into
// {"type": ..., "params": {...}}.
//
// When the input already has a "params" key, that map is used as the
// parameters and all other keys except "type" are ignored. Otherwise every
// key except "type" is moved under "params".
//
// A nil input yields (nil, nil). An error is returned when data is not a map
// or when "params" is present but is not a map.
func TransformVolumeRemote(root *map[string]interface{}, data any, path tree.Path) (any, error) {
	if data == nil {
		return nil, nil
	}

	config, ok := data.(map[string]any)
	if !ok {
		return nil, fmt.Errorf("invalid volume configuration: %v", data)
	}

	remoteConfig := map[string]any{}
	remoteConfig["type"] = config["type"]

	if rawParams, ok := config["params"]; ok {
		// Checked assertion: malformed user input must produce an error,
		// not a panic.
		params, isMap := rawParams.(map[string]any)
		if !isMap {
			return nil, fmt.Errorf("invalid volume remote params: %v", rawParams)
		}
		remoteConfig["params"] = params
		return remoteConfig, nil
	}

	params := map[string]any{}
	for k, v := range config {
		if k != "type" {
			params[k] = v
		}
	}
	remoteConfig["params"] = params

	return remoteConfig, nil
}
// TransformNetwork rewrites a network's "ports" list into a "port_map" list
// of {"protocol", "host_port", "container_port"} maps and removes "ports".
// The input map is modified in place and returned.
//
// Accepted port forms (protocol defaults to "tcp", ports default to 0):
//   - int: used for both host and container port;
//   - string "port", "host:container" or "protocol:host:container";
//   - map with "host_port", "container_port" (int or numeric string) and an
//     optional string "protocol".
//
// Strings with more than three ':'-separated parts and values of any other
// type keep the defaults. Number parsing errors are ignored, in which case
// the value returned by strconv.Atoi is used.
//
// A nil input yields (nil, nil); a non-map input is an error.
func TransformNetwork(root *map[string]interface{}, data any, path tree.Path) (any, error) {
	if data == nil {
		return nil, nil
	}
	config, isMap := data.(map[string]any)
	if !isMap {
		return nil, fmt.Errorf("invalid network configuration: %v", data)
	}

	// atoi parses s and deliberately discards the parse error.
	atoi := func(s string) int {
		n, _ := strconv.Atoi(s)
		return n
	}
	// portNumber reads a port given as int or numeric string; anything else is 0.
	portNumber := func(raw any) int {
		switch p := raw.(type) {
		case int:
			return p
		case string:
			return atoi(p)
		}
		return 0
	}

	// A missing or non-slice "ports" value yields no entries.
	rawPorts, _ := transform.ToAnySlice(config["ports"])
	portMap := []map[string]any{}

	for _, rawPort := range rawPorts {
		proto := "tcp"
		hostPort := 0
		containerPort := 0

		switch spec := rawPort.(type) {
		case int:
			hostPort, containerPort = spec, spec
		case string:
			fields := strings.Split(spec, ":")
			switch n := len(fields); {
			case n <= 2:
				hostPort = atoi(fields[0])
				containerPort = atoi(fields[n-1])
			case n == 3:
				proto = fields[0]
				hostPort = atoi(fields[1])
				containerPort = atoi(fields[n-1])
			}
		case map[string]any:
			hostPort = portNumber(spec["host_port"])
			containerPort = portNumber(spec["container_port"])
			if named, isString := spec["protocol"].(string); isString {
				proto = named
			}
		}

		portMap = append(portMap, map[string]any{
			"protocol":       proto,
			"host_port":      hostPort,
			"container_port": containerPort,
		})
	}

	config["port_map"] = portMap
	delete(config, "ports")
	return config, nil
}
// TransformLibrary transforms a library entry into a map.
// A string is read as "name" or "name:version" (an absent version becomes
// the empty string); more than one ':' is an error. A map is returned
// unchanged, a nil input yields (nil, nil), and any other type is an error.
func TransformLibrary(root *map[string]interface{}, data any, path tree.Path) (any, error) {
	switch lib := data.(type) {
	case nil:
		return nil, nil
	case map[string]any:
		return lib, nil
	case string:
		name, version := lib, ""
		pieces := strings.Split(lib, ":")
		switch len(pieces) {
		case 1:
			// Name only; version stays empty.
		case 2:
			name, version = pieces[0], pieces[1]
		default:
			return nil, fmt.Errorf("invalid library configuration: %v", data)
		}
		return map[string]any{
			"name":    name,
			"version": version,
		}, nil
	default:
		return nil, fmt.Errorf("invalid library configuration: %v", data)
	}
}
package nunet
import (
"fmt"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/tree"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/validate"
)
// NewNuNetValidator creates the validator for the NuNet configuration.
// It registers ValidateSpec for the root path and ValidateJob for every job,
// both top-level and nested under "children".
func NewNuNetValidator() validate.Validator {
	rules := map[tree.Path]validate.ValidatorFunc{
		"":                    ValidateSpec,
		"jobs.[]":             ValidateJob,
		"jobs.**.children.[]": ValidateJob,
	}
	return validate.NewValidator(rules)
}
// ValidateSpec checks the root configuration for consistency.
// It returns an error when data is not a map, when "jobs" is missing, nil or
// empty, or when "jobs" is present but is not a list.
func ValidateSpec(root *map[string]any, data any, path tree.Path) error {
	spec, ok := data.(map[string]any)
	if !ok {
		return fmt.Errorf("invalid spec configuration: %v", data)
	}

	// Check if the jobs list is present and not empty.
	rawJobs := spec["jobs"]
	if rawJobs == nil {
		return fmt.Errorf("jobs list is required")
	}
	// Checked assertion: a wrongly typed "jobs" value is a validation
	// failure, not a reason to panic.
	jobs, ok := rawJobs.([]any)
	if !ok {
		return fmt.Errorf("invalid jobs configuration: %v", rawJobs)
	}
	if len(jobs) == 0 {
		return fmt.Errorf("jobs list is required")
	}
	return nil
}
// ValidateJob checks a single job configuration.
// It returns an error when data is not a map, when "children" is present but
// is not a list, or when the job has neither a non-empty "children" list nor
// an "execution" entry.
func ValidateJob(root *map[string]any, data any, path tree.Path) error {
	job, ok := data.(map[string]any)
	if !ok {
		return fmt.Errorf("invalid job configuration: %v", data)
	}

	hasChildren := false
	if rawChildren := job["children"]; rawChildren != nil {
		// Checked assertion: a wrongly typed "children" value is a
		// validation failure, not a reason to panic.
		children, ok := rawChildren.([]any)
		if !ok {
			return fmt.Errorf("invalid children configuration: %v", rawChildren)
		}
		hasChildren = len(children) > 0
	}

	// Check if the job has either children or an execution.
	if !hasChildren && job["execution"] == nil {
		return fmt.Errorf("job must have either children or an execution")
	}
	return nil
}
package transform
import (
"fmt"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/tree"
)
// TransformerFunc is a function that transforms one node of the configuration.
// It receives a pointer to the root configuration, the data found at the
// current node and the node's path in the tree, and returns the replacement
// for that node or an error.
type TransformerFunc func(*map[string]interface{}, interface{}, tree.Path) (any, error)
// Transformer is a configuration transformer.
// Transform takes a pointer to the raw configuration and returns the
// transformed configuration or an error.
type Transformer interface {
Transform(*map[string]interface{}) (interface{}, error)
}
// TransformerImpl is the implementation of the Transformer interface.
// transformers is an ordered list of groups; each group maps a path pattern
// to the function applied to nodes whose path matches that pattern.
type TransformerImpl struct {
transformers []map[tree.Path]TransformerFunc
}
// NewTransformer wraps the given transformer groups in a TransformerImpl and
// returns it as a Transformer.
func NewTransformer(transformers []map[tree.Path]TransformerFunc) Transformer {
	var impl TransformerImpl
	impl.transformers = transformers
	return impl
}
// Transform runs every transformer group over the configuration, in order.
// Each group starts at the root path and receives the output of the previous
// group. The final result is passed through Normalize. The first error
// aborts the run and is returned with a nil result.
func (t TransformerImpl) Transform(rawConfig *map[string]interface{}) (interface{}, error) {
	var current any = *rawConfig
	for i := range t.transformers {
		next, err := t.transform(rawConfig, current, tree.NewPath(), t.transformers[i])
		if err != nil {
			return nil, err
		}
		current = next
	}
	return Normalize(current), nil
}
// transform applies one group of transformers to data and then recurses into
// its children.
//
// Every transformer whose pattern matches path is applied to the node first
// (in map iteration order). Afterwards, map values are transformed in place
// under path.Next(key); any slice is copied into a []any via ToAnySlice and
// its elements are transformed under path.Next("[i]"). Other values are
// returned as they are. The first error aborts the traversal.
func (t TransformerImpl) transform(root *map[string]interface{}, data any, path tree.Path, transformers map[tree.Path]TransformerFunc) (interface{}, error) {
	// Apply transformers that match the current path.
	for pattern, apply := range transformers {
		if !path.Matches(pattern) {
			continue
		}
		replaced, err := apply(root, data, path)
		if err != nil {
			return nil, err
		}
		data = replaced
	}

	// Recursively apply transformers to children.
	if children, isMap := data.(map[string]interface{}); isMap {
		for key, value := range children {
			transformed, err := t.transform(root, value, path.Next(key), transformers)
			// Store the result before checking the error, as the entry is
			// overwritten even when the child fails.
			children[key] = transformed
			if err != nil {
				return nil, err
			}
		}
		return children, nil
	}

	items, sliceErr := ToAnySlice(data)
	if sliceErr != nil {
		// Neither a map nor a slice: a leaf value.
		return data, nil
	}
	for idx := range items {
		childPath := path.Next(fmt.Sprintf("[%d]", idx))
		transformed, err := t.transform(root, items[idx], childPath, transformers)
		items[idx] = transformed
		if err != nil {
			return nil, err
		}
	}
	return items, nil
}
package transform
import (
"fmt"
"reflect"
"sort"
"strconv"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/tree"
)
// MapToSlice converts a map of entries into a slice, storing each key in the
// "name" field of its entry.
//
// Entries are emitted in ascending key order so the result is deterministic.
// A nil value (untyped nil or a nil map) is replaced by a fresh map that only
// carries "name". Values that are not map[string]any are appended unchanged
// and their key is dropped. Map entries are modified in place.
//
// A nil input map yields (nil, nil); the error result is always nil.
func MapToSlice(data map[string]any) ([]any, error) {
	if data == nil {
		return nil, nil
	}

	// Go randomises map iteration order; sort the keys for a stable result.
	keys := make([]string, 0, len(data))
	for k := range data {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	result := make([]any, 0, len(keys))
	for _, k := range keys {
		v := data[k]
		if v == nil {
			v = map[string]any{}
		}
		if e, ok := v.(map[string]any); ok {
			if e == nil {
				// A typed nil map cannot be written to; replace it.
				e = map[string]any{}
				v = e
			}
			e["name"] = k
		}
		result = append(result, v)
	}
	return result, nil
}
// GetConfigAtPath retrieves the part of config found at path.
//
// Map nodes are descended by key (a missing key yields nil). List nodes
// ([]any or []map[string]any) are descended by an index part such as "[3]";
// the first and last byte of the part are stripped and the rest is parsed as
// the index. The empty path refers to the root and returns config itself.
//
// An error is returned when an index part is malformed or out of range, or
// when the path continues below a value that is neither a map nor a list.
func GetConfigAtPath(config map[string]interface{}, path tree.Path) (any, error) {
	current := any(config)
	if path == "" {
		// The empty path denotes the root; splitting it would otherwise
		// produce a single empty key.
		return current, nil
	}

	for _, key := range path.Parts() {
		switch v := current.(type) {
		case map[string]any:
			current = v[key]
		case []any:
			i, err := parseListIndex(key, len(v))
			if err != nil {
				return nil, err
			}
			current = v[i]
		case []map[string]any:
			i, err := parseListIndex(key, len(v))
			if err != nil {
				return nil, err
			}
			current = v[i]
		default:
			return nil, fmt.Errorf("invalid data type: %v", current)
		}
	}
	return current, nil
}

// parseListIndex extracts the index from a path part such as "[3]" and checks
// that it addresses an element of a list with the given length.
func parseListIndex(key string, length int) (int, error) {
	// Guard the slice expression below against parts shorter than "[]".
	if len(key) < 2 {
		return 0, fmt.Errorf("invalid index: %v", key)
	}
	i, err := strconv.Atoi(key[1 : len(key)-1])
	if err != nil {
		return 0, fmt.Errorf("invalid index: %v", key)
	}
	if i < 0 || i >= length {
		return 0, fmt.Errorf("index out of range: %v", key)
	}
	return i, nil
}
// ToAnySlice copies the elements of any slice into a new []any.
// It returns an error when the argument is not a slice (including nil).
func ToAnySlice(slice any) ([]any, error) {
	rv := reflect.ValueOf(slice)
	if rv.Kind() == reflect.Slice {
		converted := make([]any, rv.Len())
		for i := range converted {
			converted[i] = rv.Index(i).Interface()
		}
		return converted, nil
	}
	return nil, fmt.Errorf("input is not a slice. type: %T", slice)
}
// normalizeMap recursively rebuilds maps and slices into a canonical shape.
//
// Maps become new maps with the same key type and interface{} values. Slices
// of any element type become []interface{} whose elements are normalised and
// then sorted by their fmt.Sprint representation. All other values are
// returned unchanged.
//
// NOTE(review): map entries whose normalised value is nil are not carried
// over, because SetMapIndex is called with the zero reflect.Value for them —
// confirm this is intended.
func normalizeMap(m interface{}) interface{} {
	rv := reflect.ValueOf(m)

	if rv.Kind() == reflect.Map {
		// Build a map[K]interface{} to hold the normalised values.
		anyType := reflect.TypeOf((*interface{})(nil)).Elem()
		normalized := reflect.MakeMap(reflect.MapOf(rv.Type().Key(), anyType))
		entries := rv.MapRange()
		for entries.Next() {
			child := normalizeMap(entries.Value().Interface())
			normalized.SetMapIndex(entries.Key(), reflect.ValueOf(child))
		}
		return normalized.Interface()
	}

	if rv.Kind() == reflect.Slice {
		items := make([]interface{}, rv.Len())
		for i := range items {
			items[i] = normalizeMap(rv.Index(i).Interface())
		}
		// Order the elements by their printed form.
		sort.Slice(items, func(a, b int) bool {
			return fmt.Sprint(items[a]) < fmt.Sprint(items[b])
		})
		return items
	}

	// For other types, return as is
	return m
}
// Normalize is the exported entry point for normalisation; it returns the
// result of normalizeMap for m.
func Normalize(m any) interface{} {
	normalized := normalizeMap(m)
	return normalized
}
package tree
import (
"strings"
)
// Building blocks of configuration paths and path patterns.
const (
// configPathSeparator separates the parts of a path.
configPathSeparator = "."
// configPathMatchAny is the pattern part that matches any single path part.
configPathMatchAny = "*"
// configPathMatchAnyMultiple is the pattern part that matches any run of
// path parts, including none.
configPathMatchAnyMultiple = "**"
// configPathList is the pattern part that matches a path part enclosed in
// square brackets, i.e. a list index.
configPathList = "[]"
)
// Path is a custom type for representing paths in the configuration.
// Parts are joined with configPathSeparator; the empty Path is the root.
type Path string
// NewPath builds a Path by joining the given parts with the path separator.
// Called without arguments it returns the empty (root) path.
func NewPath(path ...string) Path {
	joined := strings.Join(path, configPathSeparator)
	return Path(joined)
}
// Parts splits the path at every separator and returns the pieces.
// The empty path yields a single empty part.
func (p Path) Parts() []string {
	raw := string(p)
	return strings.Split(raw, configPathSeparator)
}
// Parent returns the path without its last part, or the empty path when
// there is only one part.
func (p Path) Parent() Path {
	// Everything in front of the last separator is the parent.
	cut := strings.LastIndex(string(p), configPathSeparator)
	if cut < 0 {
		return ""
	}
	return p[:cut]
}
// Next returns the path extended by one more part.
// An empty part leaves the path unchanged; extending the empty path yields
// just the part, without a leading separator.
func (p Path) Next(path string) Path {
	switch {
	case path == "":
		return p
	case p == "":
		return Path(path)
	default:
		return Path(string(p) + configPathSeparator + path)
	}
}
// Last returns the final part of the path, i.e. everything after the last
// separator, or the whole path when it contains no separator.
func (p Path) Last() string {
	raw := string(p)
	// LastIndex returns -1 when there is no separator, so the slice below
	// then starts at 0 and yields the whole string.
	cut := strings.LastIndex(raw, configPathSeparator)
	return raw[cut+len(configPathSeparator):]
}
// Matches reports whether the path matches pattern. Both are split into
// parts and compared by matchParts.
func (p Path) Matches(pattern Path) bool {
	return matchParts(p.Parts(), pattern.Parts())
}
// String converts the path back to a plain string.
func (p Path) String() string {
	text := string(p)
	return text
}
// matchParts reports whether pathParts matches patternParts.
//
// Pattern parts are compared position by position:
//   - "**" matches any run of path parts, including none; as the last pattern
//     part it matches everything that remains;
//   - "[]" matches a path part enclosed in square brackets (a list index);
//   - "*" matches any single part;
//   - any other part must be equal to the path part.
//
// Apart from "**", the pattern must consume the whole path: a path that is
// longer or shorter than the pattern does not match.
func matchParts(pathParts, patternParts []string) bool {
	// If the pattern is longer than the path, it can't match
	if len(pathParts) < len(patternParts) {
		return false
	}

	for i, part := range patternParts {
		switch part {
		case configPathMatchAnyMultiple:
			// if it is the last part of the pattern, it matches
			if i == len(patternParts)-1 {
				return true
			}
			// Otherwise, try to match the rest of the pattern against every
			// possible remainder of the path.
			for j := i; j < len(pathParts); j++ {
				if matchParts(pathParts[j:], patternParts[i+1:]) {
					return true
				}
			}
			// Every split has been tried; continuing positionally would only
			// repeat the attempt where "**" stands for exactly one part.
			return false
		case configPathList:
			// Check that the path part is enclosed by []. The prefix/suffix
			// helpers also cope with an empty part (e.g. the root path).
			if !strings.HasPrefix(pathParts[i], "[") || !strings.HasSuffix(pathParts[i], "]") {
				return false
			}
		default:
			// If the part doesn't match, it doesn't match
			if part != configPathMatchAny && part != pathParts[i] {
				return false
			}
		}

		// If it is the last part of the pattern and the path is longer, it doesn't match
		if i == len(patternParts)-1 && i < len(pathParts)-1 {
			return false
		}
	}
	return true
}
package validate
import (
"fmt"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/tree"
)
// ValidatorFunc is a function that validates one node of the configuration.
// It receives a pointer to the root configuration, the data found at the
// current node and the node's path in the tree, and returns an error when the
// node is invalid.
type ValidatorFunc func(*map[string]any, any, tree.Path) error
// Validator is a configuration validator.
// Validate takes a pointer to the raw configuration and returns an error
// when validation fails.
type Validator interface {
Validate(*map[string]any) error
}
// ValidatorImpl is the implementation of the Validator interface.
// validators maps a path pattern to the function applied to nodes whose path
// matches that pattern.
type ValidatorImpl struct {
validators map[tree.Path]ValidatorFunc
}
// NewValidator wraps the given pattern-to-function map in a ValidatorImpl and
// returns it as a Validator.
func NewValidator(validators map[tree.Path]ValidatorFunc) Validator {
	var impl ValidatorImpl
	impl.validators = validators
	return impl
}
// Validate walks the whole configuration, starting at the root path, and
// returns the first error reported by a matching validator.
func (v ValidatorImpl) Validate(rawConfig *map[string]any) error {
	var rootNode any = *rawConfig
	return v.validate(rawConfig, rootNode, tree.NewPath(), v.validators)
}
// validate runs every validator whose pattern matches path against data and
// then recurses into the children of data.
//
// Map values are visited under path.Next(key) and []interface{} elements
// under path.Next("[i]"); values of any other type have no children. The
// first error stops the traversal and is returned.
func (v ValidatorImpl) validate(root *map[string]interface{}, data any, path tree.Path, validators map[tree.Path]ValidatorFunc) error {
	// Apply validators that match the current path.
	for pattern, check := range validators {
		if !path.Matches(pattern) {
			continue
		}
		if err := check(root, data, path); err != nil {
			return err
		}
	}

	// Recursively apply validators to children.
	if fields, isMap := data.(map[string]interface{}); isMap {
		for key, value := range fields {
			if err := v.validate(root, value, path.Next(key), validators); err != nil {
				return err
			}
		}
		return nil
	}

	if items, isList := data.([]interface{}); isList {
		for idx, value := range items {
			childPath := path.Next(fmt.Sprintf("[%d]", idx))
			if err := v.validate(root, value, childPath, validators); err != nil {
				return err
			}
		}
	}
	return nil
}
package matching
import (
"gitlab.com/nunet/device-management-service/models"
)
// CapabilityComparator is a placeholder comparator for the Capability type.
// It currently ignores its arguments and always returns models.Error.
func CapabilityComparator(l, r interface{}, preference ...Preference) models.Comparison {
	// TODO: implement the comparison logic for the Capability type
	// after all the fields are implemented
	return models.Error
}
package matching
import (
"gitlab.com/nunet/device-management-service/models"
)
// CpuComparator compares two models.CPU values.
//
// It returns models.Error when either argument is not a models.CPU or when
// the architecture comparison itself fails, models.Worse when the
// architectures are not equal, and otherwise the numeric comparison of
// Cores*Freq between left and right.
//
// This is a deliberately simple comparison, based on the assumption that more
// cores, or an equal amount of cores and frequency, is acceptable, but
// nothing less. More complex comparisons would need very specific hardware
// knowledge, for example ranking CPU models using benchmarking data.
func CpuComparator(l, r interface{}, preference ...Preference) models.Comparison {
	// validate input type
	left, leftOK := l.(models.CPU)
	right, rightOK := r.(models.CPU)
	if !(leftOK && rightOK) {
		return models.Error
	}

	performance := NumericComparator(
		left.Cores*left.Freq,
		right.Cores*right.Freq,
	)

	switch LiteralComparator(left.Arch, right.Arch) {
	case models.Error:
		return models.Error
	case models.Equal:
		return performance
	default:
		return models.Worse
	}
}
package matching
import (
"gitlab.com/nunet/device-management-service/models"
)
// DiskComparator compares two models.Disk values.
//
// It returns models.Error when either argument is not a models.Disk or when
// the "Type" field comparison is models.Error, models.Worse when the "Type"
// comparison is anything other than models.Equal, and otherwise the result of
// the "Size" field comparison.
//
// This is a deliberately simple comparison, based on the assumption that more
// size, or an equal amount of size, is acceptable, but nothing less.
func DiskComparator(l, r interface{}, preference ...Preference) models.Comparison {
	// validate input type
	if _, ok := l.(models.Disk); !ok {
		return models.Error
	}
	if _, ok := r.(models.Disk); !ok {
		return models.Error
	}

	fields := ReturnComplexComparison(l, r)
	switch fields["Type"] {
	case models.Error:
		return models.Error
	case models.Equal:
		return fields["Size"]
	default:
		return models.Worse
	}
}
package matching
import (
"gitlab.com/nunet/device-management-service/models"
)
// ExecutionResourcesComparator compares two models.ExecutionResources values
// through the per-field results for "CPU", "Memory", "Disk" and "GPUs".
//
// Result, in order of precedence:
//   - models.Error  when an argument has the wrong type or any field is Error;
//   - models.Worse  when any field is Worse;
//   - models.Equal  when all four fields are Equal;
//   - models.Better otherwise.
func ExecutionResourcesComparator(l, r interface{}, preference ...Preference) models.Comparison {
	// validate input type
	if _, ok := l.(models.ExecutionResources); !ok {
		return models.Error
	}
	if _, ok := r.(models.ExecutionResources); !ok {
		return models.Error
	}

	fields := ReturnComplexComparison(l, r)

	anyWorse := false
	allEqual := true
	for _, name := range [...]string{"CPU", "Memory", "Disk", "GPUs"} {
		switch fields[name] {
		case models.Error:
			// An error in any field wins over everything else.
			return models.Error
		case models.Worse:
			anyWorse = true
			allEqual = false
		case models.Equal:
			// Keeps allEqual as it is.
		default:
			allEqual = false
		}
	}

	if anyWorse {
		return models.Worse
	}
	if allEqual {
		return models.Equal
	}
	return models.Better // none of the above applies, so the result is better
}
package matching
import (
"gitlab.com/nunet/device-management-service/models"
)
// ExecutorComparator compares two models.Executor values.
// The left argument represents machine capabilities, the right one the
// required capabilities. The type has a single field of interest, so the
// comparison is delegated to Compare on the two ExecutorType values.
// It returns models.Error when either argument is not a models.Executor.
func ExecutorComparator(lraw, rraw interface{}, preference ...Preference) models.Comparison {
	// validate input type
	left, leftOK := lraw.(models.Executor)
	right, rightOK := rraw.(models.Executor)
	if !leftOK || !rightOK {
		return models.Error
	}

	return Compare(left.ExecutorType, right.ExecutorType)
}
package matching
import (
"reflect"
"gitlab.com/nunet/device-management-service/models"
)
// ExecutorTypeComparator compares two models.ExecutorType values.
// It returns models.Equal when both arguments are ExecutorType values that
// are deeply equal, and models.Error in every other case.
func ExecutorTypeComparator(l, r interface{}, preference ...Preference) models.Comparison {
	if _, ok := l.(models.ExecutorType); !ok {
		return models.Error
	}
	if _, ok := r.(models.ExecutorType); !ok {
		return models.Error
	}

	if !reflect.DeepEqual(l, r) {
		return models.Error // default answer is error
	}
	return models.Equal
}
package matching
import (
"reflect"
"gitlab.com/nunet/device-management-service/models"
"gitlab.com/nunet/device-management-service/utils"
)
// ExecutorsComparator compares two models.Executors values.
//
// lraw represents the machine capabilities and rraw the required capabilities.
// The result is:
//   - models.Equal when both sides are deeply equal;
//   - models.Better when utils.IsStrictlyContained(l, r) holds (the machine
//     capabilities contain all the required capabilities);
//   - models.Worse when utils.IsStrictlyContained(r, l) holds (the required
//     capabilities contain all the machine capabilities);
//   - models.Error otherwise, including when either argument is not a
//     models.Executors.
//
// The preference arguments are not used by this comparator.
func ExecutorsComparator(lraw, rraw interface{}, preference ...Preference) models.Comparison {
	// validate input type
	if _, ok := lraw.(models.Executors); !ok {
		return models.Error
	}
	if _, ok := rraw.(models.Executors); !ok {
		return models.Error
	}

	// A value that holds models.Executors cannot be asserted to []interface{}
	// (that assertion panics), so convert element by element with the same
	// helper JobTypesComparator uses.
	l := utils.ConvertTypedSliceToUntypedSlice(lraw)
	r := utils.ConvertTypedSliceToUntypedSlice(rraw)

	result := models.Error // error is the default value
	if !utils.IsSameShallowType(l, r) {
		// NOTE(review): as in JobTypesComparator this does not stop the
		// comparison below; confirm whether an early return is intended.
		result = models.Error
	}

	switch {
	case reflect.DeepEqual(l, r):
		// available capabilities are equal to required capabilities
		result = models.Equal
	case utils.IsStrictlyContained(l, r):
		// machine capabilities contain all the required capabilities
		result = models.Better
	case utils.IsStrictlyContained(r, l):
		// required capabilities contain all the machine capabilities
		// ("available capabilities are worse than required");
		// the Equal case is already handled above
		result = models.Worse
	}

	return result
}
package matching
import (
"reflect"
"gitlab.com/nunet/device-management-service/models"
)
// GPUVendorComparator compares two models.GPUVendor values.
//
// It only tells whether the vendor is the same: models.Equal when both
// arguments are models.GPUVendor and deeply equal, models.Error otherwise.
//
// There is no mechanism yet for externally defined user preferences. Some
// compute may prefer one vendor over another, or strictly depend on a specific
// vendor; that will have to be solved generically on the resource matching
// level. The preference arguments are not used by this comparator.
func GPUVendorComparator(l, r interface{}, preference ...Preference) models.Comparison {
	// validate input type
	if _, ok := l.(models.GPUVendor); !ok {
		return models.Error
	}
	if _, ok := r.(models.GPUVendor); !ok {
		return models.Error
	}

	if !reflect.DeepEqual(l, r) {
		return models.Error // default answer is error
	}
	return models.Equal
}
package matching
import (
"gitlab.com/nunet/device-management-service/models"
"golang.org/x/exp/slices"
)
// GPUsComparator compares two []models.GPU slices.
//
// lraw represents the machine capabilities and rraw the required capabilities.
// The slices are not ordered, so for each required GPU the comparator greedily
// picks the best still-unused available GPU (see returnBestMatch) and then
// aggregates the per-GPU verdicts:
//   - models.Error if any required GPU produced Error, which includes the case
//     where no available GPU is left to match it;
//   - models.Worse if any required GPU could only be matched by a worse one;
//   - models.Equal if every verdict is Equal (also when nothing is required);
//   - models.Better otherwise.
//
// models.Error is also returned when either argument is not a []models.GPU.
// The preference arguments are not used by this comparator.
func GPUsComparator(lraw, rraw interface{}, preference ...Preference) models.Comparison {
	// validate input type
	l, ok := lraw.([]models.GPU)
	if !ok {
		return models.Error
	}
	r, ok := rraw.([]models.GPU)
	if !ok {
		return models.Error
	}

	// matches[i][j] holds the comparison of available GPU j against required
	// GPU i: the first dimension follows the required (right) GPUs, the second
	// the available (left) GPUs that have not been consumed yet.
	matches := make([][]models.Comparison, 0, len(r))
	for _, rGPU := range r {
		row := make([]models.Comparison, 0, len(l))
		for _, lGPU := range l {
			row = append(row, Compare(lGPU, rGPU))
		}
		matches = append(matches, row)
	}

	// Find the best match for every required GPU. Every row is visited: when
	// the available GPUs are exhausted the row is empty and returnBestMatch
	// reports Error, so unmatched requirements are never silently dropped.
	finalComparison := make([]models.Comparison, 0, len(matches))
	for i := range matches {
		bestMatch, index := returnBestMatch(matches[i])
		finalComparison = append(finalComparison, bestMatch)
		// the chosen available GPU must not be reused for another requirement
		matches = removeIndex(matches, index)
	}

	if slices.Contains(finalComparison, models.Error) {
		return models.Error
	}
	if slices.Contains(finalComparison, models.Worse) {
		return models.Worse
	}
	if SliceContainsOneValue(finalComparison, models.Equal) {
		return models.Equal
	}
	return models.Better
}
package matching
import (
"gitlab.com/nunet/device-management-service/models"
)
// GpuComparator compares two models.GPU values.
//
// The inner fields are compared with ReturnComplexComparison, but only the
// "VRAM" verdict decides the result: it is returned as is when it is one of
// Error, Worse, Better or Equal, and models.Error is returned for anything
// else or when either argument is not a models.GPU.
//
// This is a deliberately simple comparison. More complex comparisons would
// need very specific hardware knowledge, e.g. ranking GPU models using
// benchmarking data. Custom preference parameters may be passed on in the
// future; for now the preference arguments are not used.
func GpuComparator(l, r interface{}, preference ...Preference) models.Comparison {
	// validate input type
	if _, ok := l.(models.GPU); !ok {
		return models.Error
	}
	if _, ok := r.(models.GPU); !ok {
		return models.Error
	}

	switch vram := ReturnComplexComparison(l, r)["VRAM"]; vram {
	case models.Error, models.Worse, models.Better, models.Equal:
		return vram
	default:
		return models.Error // error is the default value
	}
}
package matching
import (
"reflect"
"gitlab.com/nunet/device-management-service/models"
)
// JobTypeComparator compares two models.JobType values.
//
// It returns models.Equal when both arguments are models.JobType and are
// deeply equal, and models.Error in every other case. The preference
// arguments are not used by this comparator.
func JobTypeComparator(l, r interface{}, preference ...Preference) models.Comparison {
	// validate input type
	if _, ok := l.(models.JobType); !ok {
		return models.Error
	}
	if _, ok := r.(models.JobType); !ok {
		return models.Error
	}

	if !reflect.DeepEqual(l, r) {
		return models.Error // default answer is error
	}
	return models.Equal
}
package matching
import (
"reflect"
"gitlab.com/nunet/device-management-service/models"
"gitlab.com/nunet/device-management-service/utils"
)
// JobTypesComparator compares two models.JobTypes values.
//
// lraw represents the machine capabilities and rraw the required capabilities;
// if the machine capabilities contain all the required capabilities the match
// is acceptable. The result is:
//   - models.Equal when both sides are deeply equal;
//   - models.Better when utils.IsStrictlyContained(l, r) holds;
//   - models.Worse when utils.IsStrictlyContained(r, l) holds;
//   - models.Error otherwise, including when either argument is not a
//     models.JobTypes.
//
// TODO: this comparator does not take into account options when several job
// types are available and several job types are required in the same data
// structure; this is why the test fails.
//
// The preference arguments are not used by this comparator.
func JobTypesComparator(lraw, rraw interface{}, preference ...Preference) models.Comparison {
	// validate input type
	if _, ok := lraw.(models.JobTypes); !ok {
		return models.Error
	}
	if _, ok := rraw.(models.JobTypes); !ok {
		return models.Error
	}

	// the inputs are known to be slices, so flatten them to untyped slices
	available := utils.ConvertTypedSliceToUntypedSlice(lraw)
	required := utils.ConvertTypedSliceToUntypedSlice(rraw)

	outcome := models.Error // error is the default value
	if !utils.IsSameShallowType(available, required) {
		outcome = models.Error
	}

	switch {
	case reflect.DeepEqual(available, required):
		// available capabilities are equal to required capabilities
		outcome = models.Equal
	case utils.IsStrictlyContained(available, required):
		// machine capabilities contain all the required capabilities
		outcome = models.Better
	case utils.IsStrictlyContained(required, available):
		// required capabilities contain all the machine capabilities
		// ("available capabilities are worse than required");
		// the Equal case is already handled above
		outcome = models.Worse
	}

	return outcome
}
package matching
import (
"gitlab.com/nunet/device-management-service/models"
"golang.org/x/exp/slices"
)
// LibrariesComparator compares two []models.Library slices, which may have
// different lengths.
//
// lraw represents the machine capabilities and rraw the required capabilities.
// For each required library the comparator greedily picks the best still-unused
// available library (see returnBestMatch) and then aggregates the verdicts:
//   - models.Error if any required library produced Error, which includes the
//     case where no available library is left to match it;
//   - models.Worse if any required library could only be matched by a worse one;
//   - models.Equal if every verdict is Equal (also when nothing is required);
//   - models.Better otherwise.
//
// models.Error is also returned when either argument is not a []models.Library.
// The preference arguments are not used by this comparator.
func LibrariesComparator(lraw, rraw interface{}, preference ...Preference) models.Comparison {
	// validate input type
	l, ok := lraw.([]models.Library)
	if !ok {
		return models.Error
	}
	r, ok := rraw.([]models.Library)
	if !ok {
		return models.Error
	}

	// matches[i][j] holds the comparison of available library j against
	// required library i: the first dimension follows the required (right)
	// libraries, the second the available (left) libraries not yet consumed.
	matches := make([][]models.Comparison, 0, len(r))
	for _, rLibrary := range r {
		row := make([]models.Comparison, 0, len(l))
		for _, lLibrary := range l {
			row = append(row, Compare(lLibrary, rLibrary))
		}
		matches = append(matches, row)
	}

	// Find the best match for every required library. Every row is visited:
	// when the available libraries are exhausted the row is empty and
	// returnBestMatch reports Error, so unmatched requirements are never
	// silently dropped.
	finalComparison := make([]models.Comparison, 0, len(matches))
	for i := range matches {
		bestMatch, index := returnBestMatch(matches[i])
		finalComparison = append(finalComparison, bestMatch)
		// the chosen available library must not be reused for another requirement
		matches = removeIndex(matches, index)
	}

	if slices.Contains(finalComparison, models.Error) {
		return models.Error
	}
	if slices.Contains(finalComparison, models.Worse) {
		return models.Worse
	}
	if SliceContainsOneValue(finalComparison, models.Equal) {
		return models.Equal
	}
	return models.Better
}
package matching
import (
"gitlab.com/nunet/device-management-service/models"
"github.com/hashicorp/go-version"
)
// LibraryComparator compares two models.Library values.
//
// lraw represents the machine capabilities and rraw the required capabilities.
// The available library's Version is parsed as a version and checked against
// the constraint built from the required library's Constraint and Version.
// The result is:
//   - models.Error when either argument is not a models.Library, the available
//     version or the required constraint cannot be parsed, or the names differ;
//   - models.Equal when the constraint operator is "=" and it is satisfied;
//   - models.Better when any other constraint is satisfied;
//   - models.Worse when the constraint is not satisfied.
//
// The preference arguments are not used by this comparator.
func LibraryComparator(lraw, rraw interface{}, preference ...Preference) models.Comparison {
	// validate input type
	available, ok := lraw.(models.Library)
	if !ok {
		return models.Error
	}
	required, ok := rraw.(models.Library)
	if !ok {
		return models.Error
	}

	// the version of the available library must be valid
	installed, err := version.NewVersion(available.Version)
	if err != nil {
		return models.Error
	}

	// the required constraint must be valid, e.g. ">= 1.2.3"
	constraint, err := version.NewConstraint(required.Constraint + " " + required.Version)
	if err != nil {
		return models.Error
	}

	// libraries with different names cannot be compared
	if available.Name != required.Name {
		return models.Error
	}

	satisfied := constraint.Check(installed)
	switch {
	case !satisfied:
		return models.Worse
	case required.Constraint == "=":
		return models.Equal
	default:
		return models.Better
	}
}
package matching
import (
"gitlab.com/nunet/device-management-service/models"
)
// LiteralComparator compares two literal (string) values.
//
// l represents the machine capabilities and r the required capabilities.
// Literals can only be equal or not equal: models.Equal is returned when both
// arguments are strings with the same content, models.Error otherwise. The
// preference arguments are not used by this comparator.
func LiteralComparator(l, r interface{}, preference ...Preference) models.Comparison {
	// validate input type
	left, ok := l.(string)
	if !ok {
		return models.Error
	}
	right, ok := r.(string)
	if !ok {
		return models.Error
	}

	if left == right {
		return models.Equal
	}
	return models.Error // error is the default value
}
package matching
import (
// "reflect"
"golang.org/x/exp/slices"
"gitlab.com/nunet/device-management-service/models"
)
// LocalitiesComparator is a simplified comparator for []models.Locality; there
// is no separate type defined for a list of localities.
//
// lraw represents the machine capabilities and rraw the required capabilities.
// For each required locality the verdict starts as models.Error and is replaced
// by Compare(available, required) for every available locality of the same
// Kind (the last one wins), so a verdict exists even if the slice dimensions
// do not match. The verdicts are then aggregated:
//   - models.Error if any verdict is Error;
//   - models.Worse if any verdict is Worse;
//   - models.Equal if every verdict is Equal (also when nothing is required);
//   - models.Better otherwise.
//
// models.Error is also returned when either argument is not a []models.Locality.
// The signature accepts preferences, but they are not consulted in this body.
func LocalitiesComparator(lraw interface{}, rraw interface{}, preference ...Preference) models.Comparison {
	// validate input type
	available, ok := lraw.([]models.Locality)
	if !ok {
		return models.Error
	}
	required, ok := rraw.([]models.Locality)
	if !ok {
		return models.Error
	}

	finalComparison := make([]models.Comparison, 0, len(required))
	for _, want := range required {
		verdict := models.Error // no available locality of this kind (yet)
		for _, have := range available {
			if have.Kind == want.Kind {
				verdict = Compare(have, want)
			}
		}
		finalComparison = append(finalComparison, verdict)
	}

	switch {
	case slices.Contains(finalComparison, models.Error):
		return models.Error
	case slices.Contains(finalComparison, models.Worse):
		return models.Worse
	case SliceContainsOneValue(finalComparison, models.Equal):
		return models.Equal
	default:
		return models.Better
	}
}
package matching
import (
"gitlab.com/nunet/device-management-service/models"
)
// LocalityComparator is a simplified (placeholder) comparator for
// models.Locality.
//
// lraw represents the machine capabilities and rraw the required capabilities.
// The result is:
//   - models.Error when either argument is not a models.Locality or the Kinds
//     differ;
//   - models.Equal when Kind and Name both match;
//   - models.Worse when the Kind matches but the Name does not.
//
// The preference arguments are not used by this comparator.
func LocalityComparator(lraw interface{}, rraw interface{}, preference ...Preference) models.Comparison {
	// validate input type
	available, ok := lraw.(models.Locality)
	if !ok {
		return models.Error
	}
	required, ok := rraw.(models.Locality)
	if !ok {
		return models.Error
	}

	if available.Kind != required.Kind {
		return models.Error
	}
	if available.Name != required.Name {
		return models.Worse
	}
	return models.Equal
}
package matching
import (
"gitlab.com/nunet/device-management-service/models"
)
// MemoryComparator compares two models.Memory values by reasoning about their
// inner fields.
//
// The fields are compared with ReturnComplexComparison. If the "Size" verdict
// is Error or Worse it is returned directly; otherwise the "Speed" verdict is
// the result. This is a very simple comparison, based on the assumption that
// more size, or an equal amount of size and speed, is acceptable but nothing
// less. models.Error is returned when either argument is not a models.Memory.
// The preference arguments are not used by this comparator.
func MemoryComparator(l, r interface{}, preference ...Preference) models.Comparison {
	// validate input type
	if _, ok := l.(models.Memory); !ok {
		return models.Error
	}
	if _, ok := r.(models.Memory); !ok {
		return models.Error
	}

	fields := ReturnComplexComparison(l, r)
	switch size := fields["Size"]; size {
	case models.Error, models.Worse:
		return size
	}
	return fields["Speed"]
}
package matching
import (
"reflect"
"gitlab.com/nunet/device-management-service/models"
"gitlab.com/nunet/device-management-service/utils/validate"
)
// NumericComparator compares two numeric values.
//
// lraw represents the machine capabilities and rraw the required capabilities.
// Both are converted with validate.ConvertNumericToFloat64 first. The result is:
//   - models.Error when either value cannot be converted to a number;
//   - models.Equal when the converted values are equal;
//   - models.Worse when the available value is less than the required one;
//   - models.Better when the available value is more than the required one.
//
// The preference arguments are not used by this comparator.
func NumericComparator(lraw, rraw interface{}, preference ...Preference) models.Comparison {
	// validate input type
	l, lnumeric := validate.ConvertNumericToFloat64(lraw)
	r, rnumeric := validate.ConvertNumericToFloat64(rraw)
	if !lnumeric || !rnumeric {
		// Stop here: comparing the values of a failed conversion would
		// overwrite the error with a bogus Equal/Worse/Better verdict.
		return models.Error
	}

	switch {
	case reflect.DeepEqual(l, r):
		// available capability is equal to the required capability
		return models.Equal
	case l < r:
		// declared machine capability is less than the job requires
		// ("less is available than required")
		return models.Worse
	case l > r:
		// declared machine capability is more than the job requires
		// ("more is available than required")
		return models.Better
	}

	// none of the orderings hold (presumably only possible for NaN operands)
	return models.Error
}
package matching
import (
"reflect"
"gitlab.com/nunet/device-management-service/models"
"gitlab.com/nunet/device-management-service/utils/validate"
)
// Comparator is the function signature shared by the type-specific comparators
// registered in the comparator map (see initComparatorMap) and dispatched to by
// the generic Compare function. It receives the left and right values as
// untyped interfaces plus optional preferences and returns a models.Comparison.
// Comparators are meant for simple types (i.e. which are not nested in the
// map[string]interface{} structure).
type Comparator func(l, r interface{}, preference ...Preference) models.Comparison
// Compare is the generic entry point for comparing two values of any supported
// type. It selects a Comparator based on the left value and returns its verdict:
//   - values accepted by validate.ConvertNumericToFloat64 use the "Numeric"
//     comparator;
//   - []models.GPU, []models.Library and []models.Locality use the "GPUs",
//     "Libraries" and "Localities" comparators respectively;
//   - every other value uses the comparator registered under the name of its
//     type (reflect.TypeOf(l).Name()).
//
// models.Error is returned when l is a nil interface or when no comparator is
// registered for the detected type. Any preferences are forwarded unchanged to
// the selected comparator.
func Compare(l, r interface{}, preference ...Preference) models.Comparison {
	// TODO: it would be better to build this map once. Note that a plain
	// package-level variable initialised with initComparatorMap() would form an
	// initialization cycle, because the registered comparators call Compare.
	comparatorMap := initComparatorMap()

	var typeName string
	if _, numeric := validate.ConvertNumericToFloat64(l); numeric {
		typeName = "Numeric"
	} else {
		lType := reflect.TypeOf(l)
		if lType == nil {
			// l is a nil interface: there is no type to dispatch on
			return models.Error
		}
		typeName = lType.Name()

		// Unnamed slices of custom types have an empty type name, so each
		// supported element type has to be mapped explicitly.
		if lType.Kind() == reflect.Slice {
			switch l.(type) {
			case []models.GPU:
				typeName = "GPUs"
			case []models.Library:
				typeName = "Libraries"
			case []models.Locality:
				typeName = "Localities"
			}
		}
	}

	// select the comparator based on type
	comparator := comparatorMap[typeName]
	if comparator == nil {
		return models.Error
	}
	return comparator(l, r, preference...)
}
// ComparatorMap maps a type name (such as "Numeric", "GPU" or "Libraries") to
// the Comparator responsible for values of that type.
type ComparatorMap map[string]Comparator
// initComparatorMap builds a fresh ComparatorMap holding all defined
// comparators, keyed by type name, so that it can be passed around and
// searched / referenced.
func initComparatorMap() ComparatorMap {
	return ComparatorMap{
		"Numeric":            NumericComparator,
		"Capability":         CapabilityComparator,
		"string":             LiteralComparator,
		"Executors":          ExecutorsComparator,
		"ExecutorType":       ExecutorTypeComparator,
		"JobType":            JobTypeComparator,
		"JobTypes":           JobTypesComparator,
		"GPUVendor":          GPUVendorComparator,
		"GPUs":               GPUsComparator,
		"GPU":                GpuComparator,
		"Executor":           ExecutorComparator,
		"ExecutionResources": ExecutionResourcesComparator,
		"CPU":                CpuComparator,
		"Memory":             MemoryComparator,
		"Disk":               DiskComparator,
		"Library":            LibraryComparator,
		"Libraries":          LibrariesComparator,
		"Locality":           LocalityComparator,
		"Localities":         LocalitiesComparator,
	}
}
// Preference describes a comparison preference that can be passed to Compare
// and to the individual comparators through their variadic parameter.
// NOTE(review): none of the comparators visible in this package read these
// fields yet; the field semantics below are inferred from names — confirm.
type Preference struct {
// TypeName presumably names the type the preference applies to.
TypeName string
// Strength is either Hard or Soft (see PreferenceString).
Strength PreferenceString
// DefaultComparatorOverride presumably replaces the default comparator for TypeName.
DefaultComparatorOverride Comparator
}
// PreferenceString is the string type used for the Strength of a Preference.
type PreferenceString string
// Allowed PreferenceString values.
// NOTE(review): presumably Hard marks a strict requirement and Soft a
// negotiable one; the distinction is not used in the visible code — confirm.
const (
Hard PreferenceString = "Hard"
Soft PreferenceString = "Soft"
)
package matching
import (
"reflect"
"gitlab.com/nunet/device-management-service/models"
)
// ReturnComplexComparison compares two complex (struct) values field by field.
//
// Complex types have nested fields that need to be considered together before
// a final verdict for the whole type can be made. This helper, used by several
// type comparators, returns a models.ComplexComparison keyed by the field names
// of l, where each entry is Compare(l.field, r.field) for the field at the same
// index. Callers are expected to pass two struct values of the same type; the
// reflection calls panic otherwise.
func ReturnComplexComparison(l, r interface{}) models.ComplexComparison {
	left := reflect.ValueOf(l)
	right := reflect.ValueOf(r)

	verdicts := make(models.ComplexComparison)
	fieldCount := left.NumField()
	for idx := 0; idx < fieldCount; idx++ {
		fieldName := left.Type().Field(idx).Name
		leftValue := left.Field(idx).Interface()
		rightValue := right.Field(idx).Interface()
		verdicts[fieldName] = Compare(leftValue, rightValue)
	}
	return verdicts
}
// removeIndex removes the element at the given index from each sub-slice of
// slice and returns slice. Sub-slices for which index is out of bounds
// (negative or not smaller than their length) are left unmodified. The
// sub-slices are shortened in place, so the caller's backing arrays change.
func removeIndex(slice [][]models.Comparison, index int) [][]models.Comparison {
	for row := range slice {
		current := slice[row]
		if index >= 0 && index < len(current) {
			slice[row] = append(current[:index], current[index+1:]...)
		}
	}
	return slice
}
// returnBestMatch picks the most desirable verdict from dimension and returns
// it together with its index. Verdicts are preferred in the order Equal (the
// most efficient match), Better, Worse, Error; among equal verdicts the first
// occurrence wins. When dimension contains none of these values (for example
// when it is empty) it returns models.Error and -1.
//
// This is a simple manual scan rather than a smarter matrix matching algorithm.
func returnBestMatch(dimension []models.Comparison) (models.Comparison, int) {
	ranking := []models.Comparison{models.Equal, models.Better, models.Worse, models.Error}
	for _, wanted := range ranking {
		for pos, verdict := range dimension {
			if verdict == wanted {
				return verdict, pos
			}
		}
	}
	return models.Error, -1
}
// SliceContainsOneValue reports whether all elements of slice are equal to
// value. It returns true for an empty slice.
func SliceContainsOneValue(slice []models.Comparison, value models.Comparison) bool {
	for idx := 0; idx < len(slice); idx++ {
		if slice[idx] != value {
			return false
		}
	}
	return true
}
package resources
import (
"fmt"
"github.com/shirou/gopsutil/cpu"
"gitlab.com/nunet/device-management-service/db"
"gitlab.com/nunet/device-management-service/models"
"gorm.io/gorm"
)
// CalcFreeResAndUpdateDB recalculates the FreeResources from the
// AvailableResources and all processes started by DMS, and stores the result
// in the FreeResources table of the database. It returns the first error
// encountered while reading the CPU info, calculating, or updating the table.
func CalcFreeResAndUpdateDB() error {
	cpuInfo, err := cpu.Info()
	if err != nil {
		return err
	}

	freeRes, err := calcFreeResources(db.DB, cpuInfo)
	if err != nil {
		return err
	}

	return updateDBFreeResources(freeRes)
}
// calcFreeResources returns the onboarded resources (AvailableResources) minus
// the summed resource usage of every service, virtual machine, plugin and any
// other process started by DMS on the user's machine.
//
// It queries the running VMs, the running containers and the service resource
// requirements from gormDB; cpuInfo is used to express the VM usage in CPU
// frequency. Query failures are returned wrapped, with a zero FreeResources.
func calcFreeResources(gormDB *gorm.DB, cpuInfo []cpu.InfoStat) (models.FreeResources, error) {
	var none models.FreeResources

	runningVMs, err := queryRunningVMs(gormDB)
	if err != nil {
		return none, fmt.Errorf("Error querying running VMs: %w", err)
	}

	runningConts, err := queryRunningConts(gormDB)
	if err != nil {
		return none, fmt.Errorf("Error querying running containers: %w", err)
	}

	requirements, err := getServiceResourcesRequirements(gormDB)
	if err != nil {
		return none, fmt.Errorf("Error querying ServiceResourceRequirements table: %w", err)
	}

	vmUsage := calcUsedResourcesVMs(runningVMs, cpuInfo)
	contUsage := calcUsedResourcesConts(runningConts, requirements)
	totalUsage := addResourcesUsage(contUsage, vmUsage)

	free, err := subtractFromAvailableRes(gormDB, totalUsage)
	if err != nil {
		// the value produced by subtractFromAvailableRes is passed through as is
		return free, fmt.Errorf("Couldn't subtract from AvailableResources, Error: %w", err)
	}
	return free, nil
}
// calcUsedResourcesVMs returns the summed resource usage of all running
// Firecracker virtual machines started by DMS.
//
// Ram is the sum of the VMs' MemSizeMib and TotCpuHz is the total vCPU count
// multiplied by the frequency (MHz) of the first entry in cpuInfo. With no VMs
// the zero FreeResources is returned and cpuInfo is not inspected; otherwise
// cpuInfo must contain at least one entry.
func calcUsedResourcesVMs(vms []models.VirtualMachine, cpuInfo []cpu.InfoStat) models.FreeResources {
	var usage models.FreeResources
	if len(vms) == 0 {
		return usage
	}

	vcpus, memMib := 0, 0
	for idx := range vms {
		vcpus += vms[idx].VCPUCount
		memMib += vms[idx].MemSizeMib
	}

	usage.Ram = memMib
	usage.TotCpuHz = vcpus * int(cpuInfo[0].Mhz) // CPU in MHz
	return usage
}
// calcUsedResourcesConts returns the summed resource usage of all running
// Docker containers started by DMS.
//
// Each service's ResourceRequirements value is formatted with fmt.Sprint and
// used as the key into requirements; the CPU and RAM of the entry found (zero
// when the key is missing) are added to the totals. With no services the zero
// FreeResources is returned. Otherwise 1 is added to both totals after summing
// (the reason is not stated in the source).
func calcUsedResourcesConts(
	services []models.Services, requirements map[string]models.ServiceResourceRequirements,
) models.FreeResources {
	var usage models.FreeResources
	if len(services) == 0 {
		return usage
	}

	for idx := range services {
		key := fmt.Sprint(services[idx].ResourceRequirements)
		required := requirements[key]
		usage.TotCpuHz += required.CPU
		usage.Ram += required.RAM
	}

	usage.TotCpuHz++
	usage.Ram++
	return usage
}
// queryRunningVMs returns the Firecracker virtual machines whose state column
// is "running". A query failure is returned as an error with a nil slice.
func queryRunningVMs(gormDB *gorm.DB) ([]models.VirtualMachine, error) {
	var running []models.VirtualMachine
	if err := gormDB.Where("state = ?", "running").Find(&running).Error; err != nil {
		return nil, fmt.Errorf("unable to query running vms - %v", err)
	}
	return running, nil
}
// queryRunningConts returns the services (Docker containers) whose job_status
// column is "running". A query failure is returned as an error with a nil slice.
func queryRunningConts(gormDB *gorm.DB) ([]models.Services, error) {
	var running []models.Services
	if err := gormDB.Where("job_status = ?", "running").Find(&running).Error; err != nil {
		return nil, fmt.Errorf("unable to query running containers - %v", err)
	}
	return running, nil
}
package resources
import (
"strings"
"github.com/jaypipes/ghw"
)
// GPUVendor identifies the vendor of a detected GPU.
type GPUVendor int

// Known GPU vendors. Unknown is the zero value.
const (
	Unknown GPUVendor = iota
	NVIDIA
	AMD
)

// String returns the vendor name: "NVIDIA", "AMD", or "Unknown" for the
// Unknown vendor and for any value outside the declared constants.
func (g GPUVendor) String() string {
	if g == NVIDIA {
		return "NVIDIA"
	}
	if g == AMD {
		return "AMD"
	}
	return "Unknown"
}
// DetectGPUVendors lists the vendors of the graphics cards reported by
// ghw.GPU().
//
// A card is considered only when it has device info with a class whose
// lower-cased name contains "display controller", "vga compatible controller",
// "3d controller" or "2d controller", and a vendor entry. NVIDIA is appended
// when the lower-cased vendor name contains "nvidia" and AMD when it contains
// "amd", one entry per matching card. If no vendor was recognised the result
// is a single-element slice holding Unknown. An error from ghw.GPU() is
// returned unchanged.
func DetectGPUVendors() ([]GPUVendor, error) {
	gpu, err := ghw.GPU()
	if err != nil {
		return nil, err
	}

	graphicsClasses := []string{
		"display controller",
		"vga compatible controller",
		"3d controller",
		"2d controller",
	}

	var vendors []GPUVendor
	for _, card := range gpu.GraphicsCards {
		info := card.DeviceInfo
		if info == nil || info.Class == nil {
			continue
		}

		className := strings.ToLower(info.Class.Name)
		isGraphics := false
		for _, marker := range graphicsClasses {
			if strings.Contains(className, marker) {
				isGraphics = true
				break
			}
		}
		if !isGraphics || info.Vendor == nil {
			continue
		}

		vendorName := strings.ToLower(info.Vendor.Name)
		if strings.Contains(vendorName, "nvidia") {
			vendors = append(vendors, NVIDIA)
		}
		if strings.Contains(vendorName, "amd") {
			vendors = append(vendors, AMD)
		}
	}

	if len(vendors) == 0 {
		return []GPUVendor{Unknown}, nil
	}
	return vendors, nil
}
package resources
import (
"context"
"fmt"
"gitlab.com/nunet/device-management-service/db"
"gitlab.com/nunet/device-management-service/models"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"gorm.io/gorm"
)
// GetFreeResource recalculates the free resources, persists them, and returns
// the record read back from the database.
//
// The span found in ctx is tagged with the "/telemetry/free" URL attribute.
// An error is returned when the recalculation fails, when the query fails, or
// when the query affects no rows.
func GetFreeResource(ctx context.Context) (*models.FreeResources, error) {
	trace.SpanFromContext(ctx).SetAttributes(attribute.String("URL", "/telemetry/free"))

	if err := CalcFreeResAndUpdateDB(); err != nil {
		return nil, fmt.Errorf("could not calculate free resources and update database: %w", err)
	}

	free := &models.FreeResources{}
	res := db.DB.WithContext(ctx).Find(free)
	switch {
	case res.Error != nil:
		return nil, fmt.Errorf("could not find free resources table: %w", res.Error)
	case res.RowsAffected == 0:
		return nil, fmt.Errorf("no rows were affected")
	}
	return free, nil
}
// updateDBFreeResources stores freeRes as the FreeResources record of this
// machine. The ID is forced to "1" so that there is a single record: when the
// table query affects no rows the record is created, otherwise all columns of
// the row with id 1 are updated. The create/update error, if any, is returned.
func updateDBFreeResources(freeRes models.FreeResources) error {
	freeRes.ID = "1" // Enforce unique record for a given machine

	var existing models.FreeResources
	if db.DB.Find(&existing).RowsAffected == 0 {
		return db.DB.Create(&freeRes).Error
	}

	return db.DB.Model(&models.FreeResources{}).Where("id = ?", 1).Select("*").Updates(&freeRes).Error
}
// getServiceResourcesRequirements loads every ServiceResourceRequirements row
// from gormDB and returns them in a map keyed by the row's ID. A query failure
// is returned as an error with a nil map.
func getServiceResourcesRequirements(gormDB *gorm.DB) (map[string]models.ServiceResourceRequirements, error) {
	var rows []models.ServiceResourceRequirements
	if err := gormDB.Find(&rows).Error; err != nil {
		return nil, fmt.Errorf("unable to query resource requirements - %v", err)
	}

	byID := make(map[string]models.ServiceResourceRequirements, len(rows))
	for idx := range rows {
		byID[rows[idx].ID] = rows[idx]
	}
	return byID, nil
}
// GetFreeResources reads the FreeResources record from the global database.
// When the query affects no rows, the zero-valued record is returned together
// with the query's error (which may be nil); otherwise the record is returned
// with a nil error.
func GetFreeResources() (models.FreeResources, error) {
	var freeResource models.FreeResources
	res := db.DB.Find(&freeResource)
	if res.RowsAffected != 0 {
		return freeResource, nil
	}
	return freeResource, res.Error
}
// GetAvailableResources reads the AvailableResources record from gormDB.
// When the query affects no rows, the zero-valued record is returned together
// with the query's error (which may be nil); otherwise the record is returned
// with a nil error.
func GetAvailableResources(gormDB *gorm.DB) (models.AvailableResources, error) {
	var availableRes models.AvailableResources
	res := gormDB.Find(&availableRes)
	if res.RowsAffected != 0 {
		return availableRes, nil
	}
	return availableRes, res.Error
}
package resources
import (
"github.com/uptrace/opentelemetry-go-extra/otelzap"
"gitlab.com/nunet/device-management-service/telemetry/logger"
)
// zlog is the package-level logger of the resources package; it is assigned in init.
var zlog *otelzap.Logger
// init creates the package logger under the name "resources".
func init() {
zlog = logger.OtelZapLogger("resources")
}
//go:build linux && amd64
package resources
import (
"fmt"
"gitlab.com/nunet/device-management-service/models"
// "gitlab.com/nunet/device-management-service/library/gpudetect"
// "gitlab.com/nunet/device-management-service/dms/onboarding/gpuinfo"
)
// CheckGPU detects the GPU vendors present on the machine and returns one
// models.Gpu (name, free VRAM, total VRAM) for every NVIDIA and AMD GPU found.
//
// Each vendor is queried at most once, no matter how many times it appears in
// the detected vendor list. Unknown vendors only produce a message on standard
// output. An error is returned when vendor detection or a vendor-specific
// query fails.
func CheckGPU() ([]models.Gpu, error) {
	vendors, err := DetectGPUVendors()
	if err != nil {
		return nil, fmt.Errorf("unable to detect GPU Vendor: %v", err)
	}

	var gpuInfo []models.Gpu
	// collect converts vendor-specific info entries into models.Gpu records
	collect := func(infos []GPUInfo) {
		for _, entry := range infos {
			gpuInfo = append(gpuInfo, models.Gpu{
				Name:     entry.GPUName,
				FreeVram: entry.FreeMemory,
				TotVram:  entry.TotalMemory,
			})
		}
	}

	var seenNVIDIA, seenAMD bool
	for _, vendor := range vendors {
		switch vendor {
		case NVIDIA:
			if seenNVIDIA {
				continue
			}
			info, err := GetNVIDIAGPUInfo()
			if err != nil {
				return nil, fmt.Errorf("error getting NVIDIA GPU info: %v", err)
			}
			collect(info)
			seenNVIDIA = true
		case AMD:
			if seenAMD {
				continue
			}
			info, err := GetAMDGPUInfo()
			if err != nil {
				return nil, fmt.Errorf("error getting AMD GPU info: %v", err)
			}
			collect(info)
			seenAMD = true
		case Unknown:
			fmt.Println("Unknown GPU(s) detected")
		}
	}
	return gpuInfo, nil
}
//go:build linux && amd64
package resources
import (
"fmt"
"os/exec"
"regexp"
"strconv"
"github.com/NVIDIA/go-nvml/pkg/nvml"
)
// GPUInfo describes a single detected GPU. GetAMDGPUInfo and GetNVIDIAGPUInfo
// fill the memory fields with values converted from bytes to MiB.
type GPUInfo struct {
// GPUName is the device name (AMD entries are prefixed with "AMD ").
GPUName string
// TotalMemory is the total VRAM.
TotalMemory uint64
// UsedMemory is the VRAM currently in use.
UsedMemory uint64
// FreeMemory is the VRAM currently free.
FreeMemory uint64
// Vendor is the GPU vendor (AMD or NVIDIA).
Vendor GPUVendor
}
// GetGPUInfo collects GPU information for both supported vendors.
//
// The result always has two entries: index 0 holds the AMD GPUs and index 1
// the NVIDIA GPUs. A failing AMD query is only logged as a warning. A failing
// NVIDIA query is an error only when no AMD GPUs were found either; in that
// case the NVIDIA error is logged and returned with a nil result.
func GetGPUInfo() ([][]GPUInfo, error) {
	amdGPUs, amdErr := GetAMDGPUInfo()
	if amdErr != nil {
		zlog.Sugar().Warnf("AMD GPU not found: %v", amdErr)
	}

	nvidiaGPUs, nvidiaErr := GetNVIDIAGPUInfo()
	if nvidiaErr != nil && amdGPUs == nil {
		zlog.Sugar().Errorf("NVIDIA GPU not found: %v", nvidiaErr)
		return nil, nvidiaErr
	}

	return [][]GPUInfo{amdGPUs, nvidiaGPUs}, nil
}
// GetAMDGPUInfo queries AMD GPUs by running
// `rocm-smi --showid --showproductname --showmeminfo vram` and parsing its
// combined output.
//
// The card series, total memory and used memory are extracted with regular
// expressions; the three lists must have the same length, and entries are
// paired by position. Memory values are converted from bytes to MiB and the
// free memory is total minus used. An error is returned when the command
// fails, when the lists have different lengths, or when a memory value cannot
// be parsed. When the output contains no matches at all the result is a nil
// slice and a nil error.
func GetAMDGPUInfo() ([]GPUInfo, error) {
	raw, err := exec.Command("rocm-smi", "--showid", "--showproductname", "--showmeminfo", "vram").CombinedOutput()
	if err != nil {
		return nil, fmt.Errorf("AMD ROCm not installed, initialized, or configured (reboot recommended for newly installed AMD GPU Drivers): %s", err)
	}
	text := string(raw)

	nameMatches := regexp.MustCompile(`Card series:\s+([^\n]+)`).FindAllStringSubmatch(text, -1)
	totalMatches := regexp.MustCompile(`Total Memory \(B\):\s+(\d+)`).FindAllStringSubmatch(text, -1)
	usedMatches := regexp.MustCompile(`Total Used Memory \(B\):\s+(\d+)`).FindAllStringSubmatch(text, -1)

	if len(nameMatches) != len(totalMatches) || len(totalMatches) != len(usedMatches) {
		return nil, fmt.Errorf("failed to find AMD GPU information or vram information in the output")
	}

	var gpuInfos []GPUInfo
	for idx, nameMatch := range nameMatches {
		totalBytes, err := strconv.ParseInt(totalMatches[idx][1], 10, 64)
		if err != nil {
			return nil, fmt.Errorf("failed to parse total amdgpu vram: %s", err)
		}
		usedBytes, err := strconv.ParseInt(usedMatches[idx][1], 10, 64)
		if err != nil {
			return nil, fmt.Errorf("failed to parse used amdgpu vram: %s", err)
		}

		totalMiB := totalBytes / 1024 / 1024
		usedMiB := usedBytes / 1024 / 1024

		gpuInfos = append(gpuInfos, GPUInfo{
			GPUName:     "AMD " + nameMatch[1],
			TotalMemory: uint64(totalMiB),
			UsedMemory:  uint64(usedMiB),
			FreeMemory:  uint64(totalMiB - usedMiB),
			Vendor:      AMD,
		})
	}
	return gpuInfos, nil
}
// GetNVIDIAGPUInfo queries NVIDIA GPUs through NVML.
//
// NVML is initialised for the duration of the call and shut down on return.
// For every device reported by NVML the name and memory info are read; memory
// values are converted from bytes to MiB. Any NVML call that does not return
// nvml.SUCCESS aborts the query with an error describing the failing step.
func GetNVIDIAGPUInfo() ([]GPUInfo, error) {
	// Initialize NVML
	if ret := nvml.Init(); ret != nvml.SUCCESS {
		return nil, fmt.Errorf("NVIDIA Management Library not installed, initialized or configured (reboot recommended for newly installed NVIDIA GPU drivers): %s", nvml.ErrorString(ret))
	}
	defer nvml.Shutdown()

	// Get the number of GPU devices
	deviceCount, ret := nvml.DeviceGetCount()
	if ret != nvml.SUCCESS {
		return nil, fmt.Errorf("failed to get device count: %s", nvml.ErrorString(ret))
	}

	var gpuInfos []GPUInfo
	for idx := uint32(0); idx < uint32(deviceCount); idx++ {
		device, ret := nvml.DeviceGetHandleByIndex(int(idx))
		if ret != nvml.SUCCESS {
			return nil, fmt.Errorf("failed to get device handle for device %d: %s", idx, nvml.ErrorString(ret))
		}

		name, ret := nvml.DeviceGetName(device)
		if ret != nvml.SUCCESS {
			return nil, fmt.Errorf("failed to get name for device %d: %s", idx, nvml.ErrorString(ret))
		}

		memory, ret := nvml.DeviceGetMemoryInfo(device)
		if ret != nvml.SUCCESS {
			return nil, fmt.Errorf("failed to get nvidiagpu vram info for device %d: %s", idx, nvml.ErrorString(ret))
		}

		// NVML reports bytes; store MiB
		gpuInfos = append(gpuInfos, GPUInfo{
			GPUName:     name,
			TotalMemory: memory.Total / 1024 / 1024,
			UsedMemory:  memory.Used / 1024 / 1024,
			FreeMemory:  memory.Free / 1024 / 1024,
			Vendor:      NVIDIA,
		})
	}
	return gpuInfos, nil
}
//go:build linux && amd64
package resources
import (
"github.com/shirou/gopsutil/cpu"
"github.com/shirou/gopsutil/mem"
"gitlab.com/nunet/device-management-service/models"
)
// totalRamInMB returns the total memory installed on the host machine, in MiB
// (bytes divided by 1024*1024).
//
// If the memory statistics cannot be read it returns 0 instead of dereferencing
// a nil result.
func totalRamInMB() uint64 {
	v, err := mem.VirtualMemory()
	if err != nil || v == nil {
		return 0
	}
	return v.Total / 1024 / 1024
}
// totalCPUInMHz returns the compute capacity of the host machine, computed as
// the sum of the Mhz value of every entry returned by cpu.Info.
func totalCPUInMHz() float64 {
	// The error is deliberately ignored: an empty result simply sums to 0.
	infos, _ := cpu.Info()

	total := float64(0)
	for idx := range infos {
		total += infos[idx].Mhz
	}
	return total
}
// Hz_per_cpu returns the Mhz value reported for the first CPU entry, assuming
// all cores have the same clock speed.
//
// It returns 0 when cpu.Info yields no entries, instead of panicking on an
// out-of-range index.
func Hz_per_cpu() float64 {
	// The error itself is not surfaced; an empty result is handled below.
	cores, _ := cpu.Info()
	if len(cores) == 0 {
		return 0
	}
	return cores[0].Mhz
}
// GetTotalProvisioned builds a Provisioned value describing the host: the
// summed CPU MHz, the total memory in MiB and the number of entries reported
// by cpu.Info.
func GetTotalProvisioned() *models.Provisioned {
	// The error is ignored; a failed query is reported as zero cores.
	cpuInfos, _ := cpu.Info()

	return &models.Provisioned{
		CPU:      totalCPUInMHz(),
		Memory:   totalRamInMB(),
		NumCores: uint64(len(cpuInfos)),
	}
}
package resources
import (
"fmt"
"gitlab.com/nunet/device-management-service/models"
"gorm.io/gorm"
)
// negativeValueError is the custom error returned by the
// resultsInNegativeValues* helpers. It records which field would become
// negative and the two operands of the offending subtraction.
type negativeValueError struct {
	fieldName string
	r1, r2    int
}

// Error implements the error interface, naming the field and both operands.
func (e *negativeValueError) Error() string {
	const format = "Error: %s subtraction results in negative values. (%d - %d)"
	return fmt.Sprintf(format, e.fieldName, e.r1, e.r2)
}
// Note: Although the FreeResources model is mainly meant to represent the machine's free resources,
// we also use it to represent general resource usage in all operations between
// resource-related structs (a REFACTORING is needed: the resource-related models should be simplified).
// addResourcesUsage returns a FreeResources whose TotCpuHz, Ram and Disk are
// the field-wise sums of r1 and r2. Use it when increasing resource usage based
// on two resource-usage structs. All other fields of the result are left at
// their zero value.
func addResourcesUsage(r1, r2 models.FreeResources) models.FreeResources {
	var sum models.FreeResources
	sum.TotCpuHz = r1.TotCpuHz + r2.TotCpuHz
	sum.Ram = r1.Ram + r2.Ram
	sum.Disk = r1.Disk + r2.Disk
	return sum
}
// subResourcesUsage subtracts r2 from r1 field by field (TotCpuHz, Ram, Disk).
// It is used when decreasing resource usage based on two resource-usage structs.
//
// The difference is always returned. When any field would become negative the
// returned error is non-nil and wraps the error from
// resultsInNegativeValuesFreeRes; the (negative) difference is still returned
// alongside it.
func subResourcesUsage(r1, r2 models.FreeResources) (models.FreeResources, error) {
	// This check becomes unnecessary once the resource models are refactored
	// to use uint instead of int.
	diff := models.FreeResources{
		TotCpuHz: r1.TotCpuHz - r2.TotCpuHz,
		Ram:      r1.Ram - r2.Ram,
		Disk:     r1.Disk - r2.Disk,
	}

	if negErr := resultsInNegativeValuesFreeRes(&r1, &r2); negErr != nil {
		return diff, fmt.Errorf("Subtraction of resources would result in negative values, Error: %w", negErr)
	}
	return diff, nil
}
// resultsInNegativeValuesFreeRes reports whether subtracting r2 from r1 would
// make TotCpuHz, Ram or Disk negative. Fields are checked in that order and the
// first offending one is returned as a *negativeValueError; nil means every
// difference is non-negative.
func resultsInNegativeValuesFreeRes(r1, r2 *models.FreeResources) error {
	// This check becomes unnecessary once the resource models are refactored
	// to use uint instead of int.
	switch {
	case r1.TotCpuHz-r2.TotCpuHz < 0:
		return &negativeValueError{fieldName: "TotCpuHz", r1: r1.TotCpuHz, r2: r2.TotCpuHz}
	case r1.Ram-r2.Ram < 0:
		return &negativeValueError{fieldName: "Ram", r1: r1.Ram, r2: r2.Ram}
	case r1.Disk-r2.Disk < 0:
		return &negativeValueError{fieldName: "Disk", r1: int(r1.Disk), r2: int(r2.Disk)}
	}
	return nil
}
// resultsInNegativeValuesAvailableRes reports whether subtracting a
// FreeResources (r2) from an AvailableResources (r1) would make TotCpuHz, Ram
// or Disk negative. Fields are checked in that order and the first offending
// one is returned as a *negativeValueError; nil means every difference is
// non-negative.
func resultsInNegativeValuesAvailableRes(r1 models.AvailableResources, r2 models.FreeResources) error {
	// Near-duplicate of resultsInNegativeValuesFreeRes that exists only because
	// the two models differ; to be removed when the resource structs are simplified.
	switch {
	case r1.TotCpuHz-r2.TotCpuHz < 0:
		return &negativeValueError{fieldName: "TotCpuHz", r1: r1.TotCpuHz, r2: r2.TotCpuHz}
	case r1.Ram-r2.Ram < 0:
		return &negativeValueError{fieldName: "Ram", r1: r1.Ram, r2: r2.Ram}
	case r1.Disk-r2.Disk < 0:
		return &negativeValueError{fieldName: "Disk", r1: int(r1.Disk), r2: int(r2.Disk)}
	}
	return nil
}
// subtractFromAvailableRes returns the difference between the stored
// AvailableResources and the given resourcesUsage.
//
// The AvailableResources record is loaded through GetAvailableResources. The
// result carries TotCpuHz, Ram and Disk as "available minus usage", Vcpu as the
// resulting TotCpuHz divided by the per-core CpuHz, and NTXPricePerMinute
// copied from the available record.
//
// A zero-valued FreeResources and an error are returned when the record cannot
// be queried, when any subtraction would be negative, or when CpuHz truncates
// to zero (which would otherwise cause an integer division-by-zero panic).
func subtractFromAvailableRes(gormDB *gorm.DB, resourcesUsage models.FreeResources,
) (models.FreeResources, error) {
	var freeRes models.FreeResources

	availableRes, err := GetAvailableResources(gormDB)
	if err != nil {
		return freeRes,
			fmt.Errorf("Couldn't query AvailableResources for subtraction, Error: %w", err)
	}

	if err := resultsInNegativeValuesAvailableRes(availableRes, resourcesUsage); err != nil {
		return freeRes,
			fmt.Errorf("Subtraction of resources results in negative values, Error: %w", err)
	}

	// Guard the integer division below: dividing by zero panics at runtime.
	cpuHz := int(availableRes.CpuHz)
	if cpuHz == 0 {
		return freeRes,
			fmt.Errorf("cannot derive Vcpu: AvailableResources.CpuHz is zero")
	}

	freeRes.TotCpuHz = availableRes.TotCpuHz - resourcesUsage.TotCpuHz
	freeRes.Vcpu = freeRes.TotCpuHz / cpuHz
	freeRes.Ram = availableRes.Ram - resourcesUsage.Ram
	freeRes.Disk = float64(availableRes.Disk) - resourcesUsage.Disk
	freeRes.NTXPricePerMinute = availableRes.NTXPricePerMinute

	return freeRes, nil
}
package docker
import (
"bufio"
"context"
"encoding/json"
"fmt"
"io"
"os"
"strings"
"sync"
"time"
"github.com/docker/docker/api/types"
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/api/types/filters"
"github.com/docker/docker/api/types/network"
"github.com/docker/docker/client"
"github.com/docker/docker/pkg/jsonmessage"
"github.com/docker/docker/pkg/stdcopy"
v1 "github.com/opencontainers/image-spec/specs-go/v1"
"github.com/pkg/errors"
"go.uber.org/multierr"
)
// Client wraps the Docker client to provide high-level operations on Docker containers and networks.
// All methods delegate to the wrapped SDK client held in the unexported field.
type Client struct {
client *client.Client // Underlying Docker SDK client (a named field, not an embedded one).
}
// NewDockerClient builds a Client whose underlying Docker SDK client is
// configured from the environment (client.FromEnv) with API version
// negotiation enabled. The SDK's construction error is returned unchanged.
func NewDockerClient() (*Client, error) {
	sdkClient, err := client.NewClientWithOpts(
		client.FromEnv,
		client.WithAPIVersionNegotiation(),
	)
	if err != nil {
		return nil, err
	}

	wrapped := &Client{client: sdkClient}
	return wrapped, nil
}
// IsInstalled reports whether the Docker daemon answers a ping. Any ping error
// is treated as "not installed / not reachable".
func (c *Client) IsInstalled(ctx context.Context) bool {
	if _, pingErr := c.client.Ping(ctx); pingErr != nil {
		return false
	}
	return true
}
// CreateContainer pulls config.Image and then creates (but does not start) a
// Docker container from the supplied configuration.
//
// It returns the ID of the created container. A failure of either the image
// pull or the container creation is returned unchanged with an empty ID.
func (c *Client) CreateContainer(
	ctx context.Context,
	config *container.Config,
	hostConfig *container.HostConfig,
	networkingConfig *network.NetworkingConfig,
	platform *v1.Platform,
	name string,
) (string, error) {
	// Make sure the image is available locally before creating the container.
	if _, pullErr := c.PullImage(ctx, config.Image); pullErr != nil {
		return "", pullErr
	}

	created, createErr := c.client.ContainerCreate(ctx, config, hostConfig, networkingConfig, platform, name)
	if createErr != nil {
		return "", createErr
	}
	return created.ID, nil
}
// InspectContainer returns the Docker inspection data for the container with
// the given id, together with any error from the Docker client.
func (c *Client) InspectContainer(ctx context.Context, id string) (types.ContainerJSON, error) {
	details, err := c.client.ContainerInspect(ctx, id)
	return details, err
}
// FollowLogs tails the logs of the given container and returns separate
// readers for its stdout and stderr.
//
// The container is inspected first to resolve its ID, then its log stream is
// opened with Follow enabled. A background goroutine demultiplexes that stream
// with stdcopy.StdCopy into two pipes; when the copy ends it closes the log
// stream, flushes both buffers and closes both pipe writers, so readers see EOF.
//
// Both returned readers are fed by the same goroutine through blocking pipes,
// so callers should drain them concurrently: reading only one can stall the
// copy once the other side's buffer fills up.
//
// An error is returned when the container cannot be inspected or its logs
// cannot be opened.
func (c *Client) FollowLogs(ctx context.Context, id string) (stdout, stderr io.Reader, err error) {
	cont, err := c.InspectContainer(ctx, id)
	if err != nil {
		return nil, nil, errors.Wrap(err, "failed to get container")
	}

	logOptions := types.ContainerLogsOptions{
		ShowStdout: true,
		ShowStderr: true,
		Follow:     true,
	}
	logsReader, err := c.client.ContainerLogs(ctx, cont.ID, logOptions)
	if err != nil {
		return nil, nil, errors.Wrap(err, "failed to get container logs")
	}

	stdoutReader, stdoutWriter := io.Pipe()
	stderrReader, stderrWriter := io.Pipe()

	go func() {
		stdoutBuffer := bufio.NewWriter(stdoutWriter)
		stderrBuffer := bufio.NewWriter(stderrWriter)
		defer func() {
			logsReader.Close()
			stdoutBuffer.Flush()
			stdoutWriter.Close()
			stderrBuffer.Flush()
			stderrWriter.Close()
		}()

		// Use a goroutine-local error: assigning to the named result `err`
		// here would race with the caller, which has already returned.
		_, copyErr := stdcopy.StdCopy(stdoutBuffer, stderrBuffer, logsReader)
		if copyErr != nil && !errors.Is(copyErr, context.Canceled) {
			zlog.Sugar().Warnf("context closed while getting logs: %v\n", copyErr)
		}
	}()

	return stdoutReader, stderrReader, nil
}
// StartContainer starts the container with the given ID using default start
// options and returns the Docker client's error, if any.
func (c *Client) StartContainer(ctx context.Context, containerID string) error {
	var startOpts types.ContainerStartOptions
	return c.client.ContainerStart(ctx, containerID, startOpts)
}
// WaitContainer asks Docker to wait until the container is no longer running
// (container.WaitConditionNotRunning). It returns the Docker client's two
// channels: one for the wait result and one for errors.
func (c *Client) WaitContainer(
	ctx context.Context,
	containerID string,
) (<-chan container.ContainerWaitOKBody, <-chan error) {
	resultCh, waitErrCh := c.client.ContainerWait(ctx, containerID, container.WaitConditionNotRunning)
	return resultCh, waitErrCh
}
// StopContainer asks Docker to stop the container, passing timeout as the stop
// timeout, and returns the Docker client's error, if any.
func (c *Client) StopContainer(
	ctx context.Context,
	containerID string,
	timeout time.Duration,
) error {
	stopTimeout := timeout
	return c.client.ContainerStop(ctx, containerID, &stopTimeout)
}
// RemoveContainer removes the container with the given ID. Removal is always
// forced and always removes the container's associated volumes.
func (c *Client) RemoveContainer(ctx context.Context, containerID string) error {
	removeOpts := types.ContainerRemoveOptions{
		RemoveVolumes: true,
		Force:         true,
	}
	return c.client.ContainerRemove(ctx, containerID, removeOpts)
}
// removeContainers removes every container (running or not) that matches the
// given filters. Removals run concurrently, one goroutine per container, and
// all removal errors are combined into a single multierr error. A failure to
// list the containers is returned directly.
func (c *Client) removeContainers(ctx context.Context, filterz filters.Args) error {
	listOpts := types.ContainerListOptions{All: true, Filters: filterz}
	matches, err := c.client.ContainerList(ctx, listOpts)
	if err != nil {
		return err
	}

	// The channel is buffered for one result per container, so no worker can
	// block on send and it is safe to wait for all of them before draining.
	results := make(chan error, len(matches))
	var pending sync.WaitGroup
	pending.Add(len(matches))
	for i := range matches {
		go func(id string) {
			defer pending.Done()
			results <- c.RemoveContainer(ctx, id)
		}(matches[i].ID)
	}
	pending.Wait()
	close(results)

	var combined error
	for removeErr := range results {
		combined = multierr.Append(combined, removeErr)
	}
	return combined
}
// removeNetworks removes every network that matches the given filters.
// Removals run concurrently, one goroutine per network, and all removal errors
// are combined into a single multierr error. A failure to list the networks is
// returned directly.
func (c *Client) removeNetworks(ctx context.Context, filterz filters.Args) error {
	matches, err := c.client.NetworkList(ctx, types.NetworkListOptions{Filters: filterz})
	if err != nil {
		return err
	}

	// The channel is buffered for one result per network, so no worker can
	// block on send and it is safe to wait for all of them before draining.
	results := make(chan error, len(matches))
	var pending sync.WaitGroup
	pending.Add(len(matches))
	for i := range matches {
		go func(id string) {
			defer pending.Done()
			results <- c.client.NetworkRemove(ctx, id)
		}(matches[i].ID)
	}
	pending.Wait()
	close(results)

	var combined error
	for removeErr := range results {
		combined = multierr.Append(combined, removeErr)
	}
	return combined
}
// RemoveObjectsWithLabel removes all Docker containers, and then all Docker
// networks, carrying the label `label=value`. Both steps always run; their
// errors are combined into a single multierr error.
func (c *Client) RemoveObjectsWithLabel(ctx context.Context, label string, value string) error {
	labelSelector := fmt.Sprintf("%s=%s", label, value)
	byLabel := filters.NewArgs(filters.Arg("label", labelSelector))

	return multierr.Combine(
		c.removeContainers(ctx, byLabel),
		c.removeNetworks(ctx, byLabel),
	)
}
// GetOutputStream opens the combined stdout/stderr log stream of a container.
//
// 'since' is passed to Docker as the point from which logs are returned, and
// 'follow' controls whether the stream stays open for logs produced later.
// The container must currently be running; otherwise an error is returned.
// On success the caller owns the returned io.ReadCloser.
func (c *Client) GetOutputStream(
	ctx context.Context,
	containerID string,
	since string,
	follow bool,
) (io.ReadCloser, error) {
	inspected, err := c.InspectContainer(ctx, containerID)
	switch {
	case err != nil:
		return nil, errors.Wrap(err, "failed to get container")
	case !inspected.State.Running:
		return nil, fmt.Errorf("cannot get logs for a container that is not running")
	}

	stream, err := c.client.ContainerLogs(ctx, containerID, types.ContainerLogsOptions{
		ShowStdout: true,
		ShowStderr: true,
		Follow:     follow,
		Since:      since,
	})
	if err != nil {
		return nil, errors.Wrap(err, "failed to get container logs")
	}
	return stream, nil
}
// FindContainer lists all containers (running or not) and returns the ID of
// the first one whose label `label` equals `value`. An error is returned when
// the listing fails or when no container matches.
func (c *Client) FindContainer(ctx context.Context, label string, value string) (string, error) {
	all, err := c.client.ContainerList(ctx, types.ContainerListOptions{All: true})
	if err != nil {
		return "", err
	}

	for i := range all {
		if all[i].Labels[label] != value {
			continue
		}
		return all[i].ID, nil
	}
	return "", fmt.Errorf("unable to find container for %s=%s", label, value)
}
// PullImage pulls a Docker image from a registry and returns the digest
// reported in the pull progress stream.
//
// The progress stream is decoded message by message and, as a side effect,
// echoed verbatim to os.Stdout. Messages carrying an Aux payload are skipped;
// a message carrying an Error aborts the pull with that error. The digest is
// taken from the status line starting with "Digest"; it is empty if no such
// line is seen.
//
// Each message is decoded into a fresh JSONMessage: encoding/json leaves
// fields that are absent from the input untouched, so reusing one value would
// let a stale Aux or Error from an earlier message affect later ones.
func (c *Client) PullImage(ctx context.Context, imageName string) (string, error) {
	out, err := c.client.ImagePull(ctx, imageName, types.ImagePullOptions{})
	if err != nil {
		zlog.Sugar().Errorf("unable to pull image: %v", err)
		return "", err
	}
	defer out.Close()

	d := json.NewDecoder(io.TeeReader(out, os.Stdout))

	var digest string
	for {
		var message jsonmessage.JSONMessage
		if err := d.Decode(&message); err != nil {
			if err == io.EOF {
				break
			}
			zlog.Sugar().Errorf("unable pull image: %v", err)
			return "", err
		}

		if message.Aux != nil {
			continue
		}

		if message.Error != nil {
			zlog.Sugar().Errorf("unable pull image: %v", message.Error.Message)
			return "", errors.New(message.Error.Message)
		}

		if strings.HasPrefix(message.Status, "Digest") {
			digest = strings.TrimPrefix(message.Status, "Digest: ")
		}
	}

	return digest, nil
}
package docker
import (
"context"
"fmt"
"io"
"os"
"path/filepath"
"sync/atomic"
"time"
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/api/types/mount"
"github.com/pkg/errors"
"gitlab.com/nunet/device-management-service/models"
"gitlab.com/nunet/device-management-service/utils"
)
// Container label keys applied by the executor, and the polling parameters
// used by GetLogStream while it waits for an execution handler to appear.
const (
// labelExecutorName is the label key whose value is the executor ID.
labelExecutorName = "nunet-executor"
// labelJobID is the label key whose value is built by labelJobValue.
labelJobID = "nunet-jobID"
// labelExecutionID is the label key whose value is built by labelExecutionValue.
labelExecutionID = "nunet-executionID"
// outputStreamCheckTickTime is the interval between handler lookups in GetLogStream.
outputStreamCheckTickTime = 100 * time.Millisecond
// outputStreamCheckTimeout is how long GetLogStream keeps looking before giving up.
outputStreamCheckTimeout = 5 * time.Second
)
// Executor manages the lifecycle of Docker containers for execution requests.
// It keeps one executionHandler per execution ID and delegates all Docker
// operations to its Client.
type Executor struct {
// ID identifies this executor; it is the value of the labelExecutorName
// label and the prefix of the job and execution label values.
ID string
handlers utils.SyncMap[string, *executionHandler] // Maps execution IDs to their handlers.
client *Client // Docker client for container management.
}
// NewExecutor creates an Executor with the given id and a freshly constructed
// Docker client. The context argument is currently unused. Any error from
// NewDockerClient is returned unchanged.
func NewExecutor(_ context.Context, id string) (*Executor, error) {
	cli, err := NewDockerClient()
	if err != nil {
		return nil, err
	}

	executor := &Executor{ID: id, client: cli}
	return executor, nil
}
// IsInstalled reports whether the Docker daemon is reachable, as determined by
// the executor's Docker client.
func (e *Executor) IsInstalled(ctx context.Context) bool {
	reachable := e.client.IsInstalled(ctx)
	return reachable
}
// Start begins the execution of a request by starting a Docker container.
//
// Because Start may be called again after a restart, it first looks for an
// existing container labelled for this job/execution. If none is found and a
// handler is already registered for the execution ID, Start refuses with
// "already started" or "already completed"; otherwise a new container is
// created. A handler is then registered under the execution ID and run in a
// background goroutine; Start itself returns without waiting for the container.
//
// The handler is given the request's JobID so that its cleanup computes the
// same execution label value that containerLabels applied to the container.
func (e *Executor) Start(ctx context.Context, request *models.ExecutionRequest) error {
	zlog.Sugar().
		Infof("Starting execution for job %s, execution %s", request.JobID, request.ExecutionID)

	// It's possible that this is being called due to a restart. We should check if the
	// container is already running.
	containerID, err := e.FindRunningContainer(ctx, request.JobID, request.ExecutionID)
	if err != nil {
		// Unable to find a container for this execution, so check for a handler and,
		// failing that, create a new container.
		if handler, ok := e.handlers.Get(request.ExecutionID); ok {
			if handler.active() {
				return fmt.Errorf("execution is already started")
			}
			return fmt.Errorf("execution is already completed")
		}

		containerID, err = e.newDockerExecutionContainer(ctx, request)
		if err != nil {
			return fmt.Errorf("failed to create new container: %w", err)
		}
	}

	handler := &executionHandler{
		client:      e.client,
		ID:          e.ID,
		jobID:       request.JobID,
		executionID: request.ExecutionID,
		containerID: containerID,
		resultsDir:  request.ResultsDir,
		waitCh:      make(chan bool),
		activeCh:    make(chan bool),
		running:     &atomic.Bool{},
	}

	// Register the handler for this executionID.
	e.handlers.Put(request.ExecutionID, handler)

	// Run the container.
	go handler.run(ctx)
	return nil
}
// Wait initiates a wait for the completion of the execution identified by
// executionID. It returns two channels, each buffered for one value: one for
// the result and one for an error.
//
// If no handler is registered for executionID, a "not found" error is placed
// on the error channel immediately. Otherwise doWait is spawned in a goroutine
// to deliver either the result or an error. Callers should wait on both
// channels; this lets them synchronise Wait with an earlier call to Start.
func (e *Executor) Wait(
	ctx context.Context,
	executionID string,
) (<-chan *models.ExecutionResult, <-chan error) {
	resultCh := make(chan *models.ExecutionResult, 1)
	errCh := make(chan error, 1)

	handler, found := e.handlers.Get(executionID)
	if found {
		go e.doWait(ctx, resultCh, errCh, handler)
		return resultCh, errCh
	}

	errCh <- fmt.Errorf("execution (%s) not found", executionID)
	return resultCh, errCh
}
// doWait blocks until either the handler signals completion on its wait
// channel or ctx is done, and reports the outcome on exactly one of the two
// channels:
//   - on completion with a non-nil result, the result is sent on out;
//   - on completion with a nil result, an error is sent on errCh (this points
//     to a flaw in the executor logic);
//   - on context cancellation, ctx.Err() is sent on errCh.
//
// Both channels are closed before doWait returns (errCh first, then out).
func (e *Executor) doWait(
	ctx context.Context,
	out chan *models.ExecutionResult,
	errCh chan error,
	handler *executionHandler,
) {
	zlog.Sugar().Infof("executionID %s waiting for execution", handler.executionID)

	defer func() {
		close(errCh)
		close(out)
	}()

	select {
	case <-handler.waitCh:
		result := handler.result
		if result == nil {
			errCh <- fmt.Errorf("execution (%s) result is nil", handler.executionID)
			return
		}
		zlog.Sugar().
			Infof("executionID %s recieved results from execution", handler.executionID)
		out <- result
	case <-ctx.Done():
		// Relay the cancellation reason to the caller.
		errCh <- ctx.Err()
	}
}
// Cancel stops the execution identified by executionID by asking its handler
// to kill the container. It returns an error if no handler is registered for
// that ID, or the handler's error if stopping fails.
func (e *Executor) Cancel(ctx context.Context, executionID string) error {
	if handler, found := e.handlers.Get(executionID); found {
		return handler.kill(ctx)
	}
	return fmt.Errorf("failed to cancel execution (%s). execution not found", executionID)
}
// GetLogStream provides a stream of output logs for the execution named in
// request.ExecutionID, as produced by the handler's outputStream.
//
// The execution may already be recorded as running before its handler has been
// added to the handler map (the container is still starting). The map is
// therefore checked immediately and then every outputStreamCheckTickTime until
// the handler appears, outputStreamCheckTimeout elapses, or ctx is done.
//
// It returns ctx.Err() if the context ends first, and an "execution not found"
// error if the handler does not appear before the timeout.
func (e *Executor) GetLogStream(
	ctx context.Context,
	request models.LogStreamRequest,
) (io.ReadCloser, error) {
	ticker := time.NewTicker(outputStreamCheckTickTime)
	defer ticker.Stop()
	deadline := time.NewTimer(outputStreamCheckTimeout)
	defer deadline.Stop()

	// Poll in the calling goroutine: there is no helper goroutine to leak and
	// no unbuffered hand-off that could deadlock when the timeout fires at the
	// same moment the handler is found.
	for {
		if handler, found := e.handlers.Get(request.ExecutionID); found {
			return handler.outputStream(ctx, request)
		}

		select {
		case <-ctx.Done():
			return nil, ctx.Err()
		case <-deadline.C:
			return nil, fmt.Errorf("execution (%s) not found", request.ExecutionID)
		case <-ticker.C:
			// Look the handler up again on the next iteration.
		}
	}
}
// Run initiates an execution and waits for its completion in one call. It is a
// convenience wrapper around Start and Wait.
//
// It returns the execution result, or an error if starting fails, waiting
// fails, or ctx is cancelled.
//
// The waiter closes both of its channels when it finishes, so a receive may
// succeed on a channel that was closed without a value. The "ok" flag of each
// receive is checked and, in that case, the outcome is taken from the other
// channel instead of returning a nil result with a nil error.
func (e *Executor) Run(
	ctx context.Context,
	request *models.ExecutionRequest,
) (*models.ExecutionResult, error) {
	if err := e.Start(ctx, request); err != nil {
		return nil, err
	}

	resCh, errCh := e.Wait(ctx, request.ExecutionID)
	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	case out, ok := <-resCh:
		if ok {
			return out, nil
		}
		// resCh was closed empty: the waiter reported an error instead.
		if err, ok := <-errCh; ok && err != nil {
			return nil, err
		}
	case err, ok := <-errCh:
		if ok {
			return nil, err
		}
		// errCh was closed empty: the waiter delivered a result instead.
		if out, ok := <-resCh; ok {
			return out, nil
		}
	}
	return nil, fmt.Errorf("execution (%s) finished without a result or an error", request.ExecutionID)
}
// Cleanup removes the Docker containers and networks labelled with this
// executor's ID (label key labelExecutorName). Container removal also removes
// the containers' volumes. A removal failure is wrapped and returned.
func (e *Executor) Cleanup(ctx context.Context) error {
	if err := e.client.RemoveObjectsWithLabel(ctx, labelExecutorName, e.ID); err != nil {
		return fmt.Errorf("failed to remove containers: %w", err)
	}

	zlog.Info("Cleaned up all Docker resources")
	return nil
}
// newDockerExecutionContainer is called by Start to set up a new Docker
// container for a job execution. It decodes the engine spec, prepares mounts
// and GPU device requests/mappings, pulls the image and creates the container
// without starting it.
//
// It returns the ID of the created container, or an error naming the step
// that failed.
func (e *Executor) newDockerExecutionContainer(
	ctx context.Context,
	params *models.ExecutionRequest,
) (string, error) {
	spec, err := DecodeSpec(params.EngineSpec)
	if err != nil {
		return "", fmt.Errorf("failed to decode docker engine spec: %w", err)
	}

	volumeMounts, err := makeContainerMounts(params.Inputs, params.Outputs, params.ResultsDir)
	if err != nil {
		return "", fmt.Errorf("failed to create container mounts: %w", err)
	}

	zlog.Sugar().Infof("Adding %d GPUs to request", len(params.Resources.GPUs))
	gpuRequests, gpuMappings, err := configureDevices(params.Resources)
	if err != nil {
		return "", fmt.Errorf("creating container devices: %w", err)
	}

	// NOTE(review): CreateContainer pulls the image as well, so the image is
	// pulled twice on this path; the explicit pull here yields the more
	// specific error message.
	if _, err = e.client.PullImage(ctx, spec.Image); err != nil {
		return "", fmt.Errorf("failed to pull docker image: %w", err)
	}

	cfg := container.Config{
		Image:      spec.Image,
		Tty:        false,
		Env:        spec.Environment,
		Entrypoint: spec.Entrypoint,
		Cmd:        spec.Cmd,
		Labels:     e.containerLabels(params.JobID, params.ExecutionID),
		WorkingDir: spec.WorkingDirectory,
	}
	hostCfg := container.HostConfig{
		Mounts: volumeMounts,
		Resources: container.Resources{
			NanoCPUs:       int64(params.Resources.CPU.Freq),
			CPUCount:       int64(params.Resources.CPU.Cores),
			DeviceRequests: gpuRequests,
			Devices:        gpuMappings,
		},
	}

	// The container is named after the execution label value.
	containerName := labelExecutionValue(e.ID, params.JobID, params.ExecutionID)
	createdID, err := e.client.CreateContainer(ctx, &cfg, &hostCfg, nil, nil, containerName)
	if err != nil {
		return "", fmt.Errorf("failed to create container: %w", err)
	}
	return createdID, nil
}
// configureDevices builds the Docker device requests and device mappings for
// the GPUs listed in resources. Currently, only GPUs are supported.
//
// GPUs are grouped by vendor:
//   - NVIDIA: one device request listing the GPU indices with the "gpu" capability.
//   - AMD/ATI: a mapping for /dev/kfd plus the DRI devices described below.
//   - Intel (and AMD/ATI): for each GPU, the "card" and "render" entries under
//     /dev/dri/by-path for its PCI address are resolved through their symlinks
//     and mapped into the container at the same path.
//
// An error is returned for any other vendor, or when a DRI path cannot be
// resolved. Vendors are visited in map-iteration order, so the relative order
// of entries from different vendors is not fixed.
func configureDevices(
	resources *models.ExecutionResources,
) ([]container.DeviceRequest, []container.DeviceMapping, error) {
	requests := []container.DeviceRequest{}
	mappings := []container.DeviceMapping{}

	byVendor := make(map[models.GPUVendor][]models.GPU)
	for _, gpu := range resources.GPUs {
		byVendor[gpu.Vendor] = append(byVendor[gpu.Vendor], gpu)
	}

	// Name patterns of the two DRI nodes looked up per GPU, in lookup order.
	driPatterns := []string{"pci-%s-card", "pci-%s-render"}

	for vendor, gpus := range byVendor {
		switch vendor {
		case models.GPUVendorNvidia:
			ids := make([]string, 0, len(gpus))
			for _, gpu := range gpus {
				ids = append(ids, fmt.Sprint(gpu.Index))
			}
			requests = append(requests, container.DeviceRequest{
				DeviceIDs:    ids,
				Capabilities: [][]string{{"gpu"}},
			})

		case models.GPUVendorAMDATI, models.GPUVendorIntel:
			if vendor == models.GPUVendorAMDATI {
				// https://docs.amd.com/en/latest/deploy/docker.html
				mappings = append(mappings, container.DeviceMapping{
					PathOnHost:        "/dev/kfd",
					PathInContainer:   "/dev/kfd",
					CgroupPermissions: "rwm",
				})
			}

			// https://github.com/openvinotoolkit/docker_ci/blob/master/docs/accelerators.md
			for _, gpu := range gpus {
				for _, pattern := range driPatterns {
					// The GPU's PCI address is used to look up the devices to expose.
					link := filepath.Join("/dev/dri/by-path/", fmt.Sprintf(pattern, gpu.PCIAddress))
					resolved, err := filepath.EvalSymlinks(link)
					if err != nil {
						return nil, nil, errors.Wrapf(
							err,
							"could not find attached device for GPU at %q",
							link,
						)
					}
					mappings = append(mappings, container.DeviceMapping{
						PathOnHost:        resolved,
						PathInContainer:   resolved,
						CgroupPermissions: "rwm",
					})
				}
			}

		default:
			return nil, nil, fmt.Errorf("job requires GPU from unsupported vendor %q", vendor)
		}
	}
	return requests, mappings, nil
}
// makeContainerMounts creates the bind mounts for the container from the input
// and output volumes of an execution request.
//
// Every input must be a bind volume (models.StorageVolumeTypeBind); any other
// type is rejected as unsupported. Each input is mounted with its own ReadOnly
// flag. Every output must have a non-empty Source and is mounted writable.
// When there is at least one output, resultsDir must be non-empty and is
// created if it does not exist.
//
// It returns the list of mounts, or an error describing the first problem found.
func makeContainerMounts(
	inputs []*models.StorageVolumeExecutor,
	outputs []*models.StorageVolumeExecutor,
	resultsDir string,
) ([]mount.Mount, error) {
	// The actual mounts given to the container: paths for both input and
	// output data.
	var mounts []mount.Mount

	for _, input := range inputs {
		// Only bind volumes can be turned into a bind mount.
		if input.Type != models.StorageVolumeTypeBind {
			return nil, fmt.Errorf("unsupported storage volume type: %s", input.Type)
		}
		mounts = append(mounts, mount.Mount{
			Type:     mount.TypeBind,
			Source:   input.Source,
			Target:   input.Target,
			ReadOnly: input.ReadOnly,
		})
	}

	for _, output := range outputs {
		if output.Source == "" {
			return nil, fmt.Errorf("output source is empty")
		}

		if resultsDir == "" {
			return nil, fmt.Errorf("results directory is empty")
		}

		if err := os.MkdirAll(resultsDir, os.ModePerm); err != nil {
			return nil, fmt.Errorf("failed to create results directory: %w", err)
		}

		mounts = append(mounts, mount.Mount{
			Type:   mount.TypeBind,
			Source: output.Source,
			Target: output.Target,
			// This is an output volume, so it must be writable.
			ReadOnly: false,
		})
	}

	return mounts, nil
}
// containerLabels returns the three labels applied to a container for the
// given job and execution: the executor ID, the job label value and the
// execution label value.
func (e *Executor) containerLabels(jobID string, executionID string) map[string]string {
	labels := make(map[string]string, 3)
	labels[labelExecutorName] = e.ID
	labels[labelJobID] = labelJobValue(e.ID, jobID)
	labels[labelExecutionID] = labelExecutionValue(e.ID, jobID, executionID)
	return labels
}
// labelJobValue returns the value for the job label: the executor ID and the
// job ID joined by an underscore.
func labelJobValue(executorID string, jobID string) string {
	const format = "%s_%s"
	value := fmt.Sprintf(format, executorID, jobID)
	return value
}
// labelExecutionValue returns the value for the execution label: the executor
// ID, the job ID and the execution ID joined by underscores.
func labelExecutionValue(executorID string, jobID string, executionID string) string {
	const format = "%s_%s_%s"
	value := fmt.Sprintf(format, executorID, jobID, executionID)
	return value
}
// FindRunningContainer looks up the container labelled for the given job and
// execution of this executor. It returns the container ID if one is found, or
// the client's error otherwise.
//
// NOTE(review): the lookup goes through Client.FindContainer, which lists all
// containers, not only running ones — confirm that matches the intent.
func (e *Executor) FindRunningContainer(
	ctx context.Context,
	jobID string,
	executionID string,
) (string, error) {
	wanted := labelExecutionValue(e.ID, jobID, executionID)
	containerID, err := e.client.FindContainer(ctx, labelExecutionID, wanted)
	return containerID, err
}
package docker
import (
"bufio"
"context"
"errors"
"fmt"
"io"
"strconv"
"sync/atomic"
"time"
"gitlab.com/nunet/device-management-service/models"
)
// DestroyTimeout is the stop timeout passed to Docker when a container is
// killed, and the overall deadline run allows for destroying a container.
const DestroyTimeout = time.Second * 10
// executionHandler manages the lifecycle and execution of a Docker container for a specific job.
// One handler exists per execution; run drives the container and records the outcome in result.
type executionHandler struct {
// provided by the executor
ID string // Executor ID; first component of the execution label value.
client *Client // Docker client for container management.
// meta data about the task
jobID string // Job ID; used by destroy to build the execution label value.
executionID string // Execution ID this handler is registered under.
containerID string // ID of the Docker container being managed.
resultsDir string // Directory to store execution results.
// synchronization
activeCh chan bool // Closed by run once the container has been started.
waitCh chan bool // Closed by run when it returns, i.e. when execution completes or fails.
running *atomic.Bool // Indicates if the container is currently running.
// result of the execution
result *models.ExecutionResult // Set by run before waitCh is closed.
}
// active reports whether run is currently in progress for this handler, as
// recorded in the running flag.
func (h *executionHandler) active() bool {
	isRunning := h.running.Load()
	return isRunning
}
// run starts the container and handles its execution lifecycle, storing the
// outcome in h.result.
//
// Sequence:
//  1. Mark the handler as running and start the container; on success close
//     activeCh to signal that the container has started.
//  2. Wait for the container to stop, for a wait error, or for ctx to end.
//  3. Inspect the container (to detect an OOM kill) and collect its stdout and
//     stderr through FollowLogs.
//
// On every exit path the deferred cleanup destroys the container, clears the
// running flag and closes waitCh so that waiters can read h.result.
func (h *executionHandler) run(ctx context.Context) {
	h.running.Store(true)

	defer func() {
		if err := h.destroy(DestroyTimeout); err != nil {
			zlog.Sugar().Warnf("failed to destroy container: %v\n", err)
		}
		h.running.Store(false)
		close(h.waitCh)
	}()

	if err := h.client.StartContainer(ctx, h.containerID); err != nil {
		h.result = models.NewFailedExecutionResult(fmt.Errorf("failed to start container: %v", err))
		return
	}

	close(h.activeCh) // Indicate that the container has started.

	var containerError error
	var containerExitStatusCode int64

	// Wait for the container to finish or for an execution error.
	statusCh, errCh := h.client.WaitContainer(ctx, h.containerID)
	select {
	case <-ctx.Done():
		// Report the cancellation reason; the value received from Done() is
		// an empty struct and carries no information.
		h.result = models.NewFailedExecutionResult(fmt.Errorf("execution cancelled: %v", ctx.Err()))
		return

	case err := <-errCh:
		zlog.Sugar().Errorf("error while waiting for container: %v\n", err)
		h.result = models.NewFailedExecutionResult(
			fmt.Errorf("failed to wait for container: %v", err),
		)
		return

	case exitStatus := <-statusCh:
		containerExitStatusCode = exitStatus.StatusCode
		containerJSON, err := h.client.InspectContainer(ctx, h.containerID)
		if err != nil {
			h.result = &models.ExecutionResult{
				ExitCode: int(containerExitStatusCode),
				ErrorMsg: err.Error(),
			}
			return
		}
		if containerJSON.ContainerJSONBase.State.OOMKilled {
			containerError = errors.New("container was killed due to OOM")
			h.result = &models.ExecutionResult{
				ExitCode: int(containerExitStatusCode),
				ErrorMsg: containerError.Error(),
			}
			return
		}
		if exitStatus.Error != nil {
			containerError = errors.New(exitStatus.Error.Message)
		}
	}

	// Follow container logs to capture stdout and stderr.
	stdoutPipe, stderrPipe, err := h.client.FollowLogs(ctx, h.containerID)
	if err != nil {
		followError := fmt.Errorf("failed to follow container logs: %w", err)
		if containerError != nil {
			h.result = &models.ExecutionResult{
				ExitCode: int(containerExitStatusCode),
				ErrorMsg: fmt.Sprintf(
					"container error: '%s'. logs error: '%s'",
					containerError,
					followError,
				),
			}
		} else {
			h.result = &models.ExecutionResult{
				ExitCode: int(containerExitStatusCode),
				ErrorMsg: followError.Error(),
			}
		}
		return
	}

	// readStream returns the stream's content up to the first NUL byte or EOF
	// (the read error is expected to be EOF and is ignored) and then discards
	// whatever is left, so the goroutine feeding the pipe is never left
	// blocked on a write.
	readStream := func(r io.Reader) string {
		br := bufio.NewReader(r)
		content, _ := br.ReadString('\x00') // EOF delimiter
		_, _ = io.Copy(io.Discard, br)
		return content
	}

	// Both pipes are fed by a single goroutine through blocking writes, so
	// they must be drained concurrently: reading stdout to EOF before touching
	// stderr stalls forever once stderr fills its buffer.
	stderrCh := make(chan string, 1)
	go func() { stderrCh <- readStream(stderrPipe) }()

	h.result = models.NewExecutionResult(int(containerExitStatusCode))
	h.result.STDOUT = readStream(stdoutPipe)
	h.result.STDERR = <-stderrCh
}
// kill asks the docker client to stop the handler's container, passing
// DestroyTimeout as the stop timeout.
func (h *executionHandler) kill(ctx context.Context) error {
	stopErr := h.client.StopContainer(ctx, h.containerID, DestroyTimeout)
	return stopErr
}
// destroy tears down the container and everything labelled as belonging to
// this execution. All steps share one context bounded by timeout and run in
// order: stop the container, remove it, then remove labelled objects. The
// first failing step aborts the sequence and its error is returned.
func (h *executionHandler) destroy(timeout time.Duration) error {
	cleanupCtx, release := context.WithTimeout(context.Background(), timeout)
	defer release()

	// Stop the container first.
	stopErr := h.kill(cleanupCtx)
	if stopErr != nil {
		return fmt.Errorf("failed to kill container (%s): %w", h.containerID, stopErr)
	}

	removeErr := h.client.RemoveContainer(cleanupCtx, h.containerID)
	if removeErr != nil {
		return removeErr
	}

	// Finally remove related objects (such as networks or volumes) that were
	// created for this execution and carry its label.
	labelValue := labelExecutionValue(h.ID, h.jobID, h.executionID)
	return h.client.RemoveObjectsWithLabel(cleanupCtx, labelExecutionID, labelValue)
}
// outputStream returns a reader over the container's output.
//
// When request.Tail is set, only output produced from the moment of the call
// onwards is requested; otherwise output is requested from UNIX timestamp 1,
// i.e. everything. The call blocks until the container has started (activeCh
// closed) or ctx is done, in which case ctx.Err() is returned.
func (h *executionHandler) outputStream(
	ctx context.Context,
	request models.LogStreamRequest,
) (io.ReadCloser, error) {
	var since string
	if request.Tail {
		since = strconv.FormatInt(time.Now().Unix(), 10)
	} else {
		since = "1" // Start of UNIX time: fetch all logs.
	}

	// The container must be active before its logs can be streamed.
	select {
	case <-h.activeCh:
	case <-ctx.Done():
		return nil, ctx.Err()
	}

	// Hand back the underlying reader, limited to data newer than `since`.
	stream, err := h.client.GetOutputStream(ctx, h.containerID, since, request.Follow)
	return stream, err
}
package docker
import (
"gitlab.com/nunet/device-management-service/telemetry/logger"
)
// zlog is the package-level logger used by the docker executor code.
var zlog *logger.Logger
// init assigns zlog a logger created under the "docker.executor" name.
func init() {
zlog = logger.New("docker.executor")
}
package docker
import (
"encoding/json"
"fmt"
"gitlab.com/nunet/device-management-service/models"
"gitlab.com/nunet/device-management-service/utils/validate"
)
// Parameter keys used when building a docker SpecConfig. Each value matches the
// JSON tag of the corresponding EngineSpec field, which is what lets DecodeSpec
// map the params onto an EngineSpec.
const (
EngineKeyImage = "image"
EngineKeyEntrypoint = "entrypoint"
EngineKeyCmd = "cmd"
EngineKeyEnvironment = "environment"
EngineKeyWorkingDirectory = "working_directory"
)
// EngineSpec contains necessary parameters to execute a docker job.
// All fields are omitted from the JSON encoding when empty; Validate only
// requires Image to be non-blank.
type EngineSpec struct {
// Image is the image to run; it should be pullable by docker.
Image string `json:"image,omitempty"`
// Entrypoint optionally overrides the default entrypoint.
Entrypoint []string `json:"entrypoint,omitempty"`
// Cmd specifies the command to run in the container.
Cmd []string `json:"cmd,omitempty"`
// Environment is a slice of environment entries to run the container with.
Environment []string `json:"environment,omitempty"`
// WorkingDirectory is the working directory inside the container.
WorkingDirectory string `json:"working_directory,omitempty"`
}
// Validate reports an error when the spec's Image is blank; no other field is
// checked.
func (c EngineSpec) Validate() error {
	imageMissing := validate.IsBlank(c.Image)
	if !imageMissing {
		return nil
	}
	return fmt.Errorf("invalid docker engine params: image cannot be empty")
}
// DecodeSpec decodes a spec config into a docker engine spec.
//
// The spec must be non-nil, of type models.ExecutorTypeDocker and carry
// non-nil params. The params are round-tripped through JSON into an EngineSpec,
// which is then validated. On a validation failure the decoded spec is still
// returned alongside the validation error; on any other failure a zero
// EngineSpec is returned.
func DecodeSpec(spec *models.SpecConfig) (EngineSpec, error) {
	// Guard against a nil spec, which would otherwise panic on field access.
	if spec == nil {
		return EngineSpec{}, fmt.Errorf("invalid docker engine spec: spec cannot be nil")
	}

	if !spec.IsType(models.ExecutorTypeDocker) {
		return EngineSpec{}, fmt.Errorf(
			"invalid docker engine type. expected %s, but recieved: %s",
			models.ExecutorTypeDocker,
			spec.Type,
		)
	}

	inputParams := spec.Params
	if inputParams == nil {
		return EngineSpec{}, fmt.Errorf("invalid docker engine params: params cannot be nil")
	}

	paramBytes, err := json.Marshal(inputParams)
	if err != nil {
		return EngineSpec{}, fmt.Errorf("failed to encode docker engine params: %w", err)
	}

	// Decode into a value rather than through a nil pointer so the result can
	// never be a nil dereference, whatever the encoded params look like.
	var dockerSpec EngineSpec
	if err := json.Unmarshal(paramBytes, &dockerSpec); err != nil {
		return EngineSpec{}, fmt.Errorf("failed to decode docker engine params: %w", err)
	}

	return dockerSpec, dockerSpec.Validate()
}
// DockerEngineBuilder is a struct that is used for constructing a docker
// SpecConfig using the Builder pattern.
// It holds a pointer to the models.SpecConfig being built; the With* methods
// set params on it and Build returns it.
type DockerEngineBuilder struct {
eb *models.SpecConfig // spec config under construction
}
// NewDockerEngineBuilder returns a builder whose spec config is created with
// the models.ExecutorTypeDocker type and has its image param set to image.
func NewDockerEngineBuilder(image string) *DockerEngineBuilder {
	builder := &DockerEngineBuilder{
		eb: models.NewSpecConfig(models.ExecutorTypeDocker),
	}
	builder.eb.WithParam(EngineKeyImage, image)
	return builder
}
// WithEntrypoint stores the given entrypoint arguments under the entrypoint
// param key and returns the same builder so calls can be chained.
func (b *DockerEngineBuilder) WithEntrypoint(e ...string) *DockerEngineBuilder {
	entrypoint := e
	b.eb.WithParam(EngineKeyEntrypoint, entrypoint)

	return b
}
// WithCmd stores the given command arguments under the cmd param key and
// returns the same builder so calls can be chained.
func (b *DockerEngineBuilder) WithCmd(c ...string) *DockerEngineBuilder {
	command := c
	b.eb.WithParam(EngineKeyCmd, command)

	return b
}
// WithEnvironment stores the given environment entries under the environment
// param key and returns the same builder so calls can be chained.
func (b *DockerEngineBuilder) WithEnvironment(e ...string) *DockerEngineBuilder {
	environment := e
	b.eb.WithParam(EngineKeyEnvironment, environment)

	return b
}
// WithWorkingDirectory stores w under the working directory param key and
// returns the same builder so calls can be chained.
func (b *DockerEngineBuilder) WithWorkingDirectory(w string) *DockerEngineBuilder {
	workingDir := w
	b.eb.WithParam(EngineKeyWorkingDirectory, workingDir)

	return b
}
// Build returns the spec config accumulated by the builder. The returned
// pointer is the builder's own config, not a copy.
func (b *DockerEngineBuilder) Build() *models.SpecConfig {
	spec := b.eb
	return spec
}
package firecracker
import (
"context"
"fmt"
"os"
"os/exec"
"syscall"
"time"
"github.com/firecracker-microvm/firecracker-go-sdk"
)
// pidCheckTickTime is the interval at which DestroyVM re-checks whether the
// Firecracker process is still running while waiting for it to shut down.
const pidCheckTickTime = 100 * time.Millisecond
// Client wraps the Firecracker SDK to provide high-level operations on Firecracker VMs.
// It has no fields; all state lives in the firecracker.Machine values passed to its methods.
type Client struct{}
// NewFirecrackerClient returns a new, empty Client. The returned error is
// always nil.
func NewFirecrackerClient() (*Client, error) {
	client := new(Client)
	return client, nil
}
// IsInstalled reports whether an executable named "firecracker" can be found
// via exec.LookPath, i.e. in the directories listed in PATH.
// There might be a better way to check if Firecracker is installed.
func (c *Client) IsInstalled() bool {
	if _, lookupErr := exec.LookPath("firecracker"); lookupErr != nil {
		return false
	}
	return true
}
// CreateVM builds a Firecracker machine for cfg. The VM command is built with
// cfg.SocketPath as its socket path and handed to the machine as its process
// runner. The machine is created but not started.
func (c *Client) CreateVM(
	ctx context.Context,
	cfg firecracker.Config,
) (*firecracker.Machine, error) {
	commandBuilder := firecracker.VMCommandBuilder{}.WithSocketPath(cfg.SocketPath)
	runner := commandBuilder.Build(ctx)

	return firecracker.NewMachine(ctx, cfg, firecracker.WithProcessRunner(runner))
}
// StartVM starts the given machine and returns the error from its Start call.
func (c *Client) StartVM(ctx context.Context, m *firecracker.Machine) error {
	startErr := m.Start(ctx)
	return startErr
}
// ShutdownVM requests a shutdown of the given machine and returns the error
// from its Shutdown call.
func (c *Client) ShutdownVM(ctx context.Context, m *firecracker.Machine) error {
	shutdownErr := m.Shutdown(ctx)
	return shutdownErr
}
// DestroyVM shuts the Firecracker VM down and removes its socket file.
//
// A shutdown is requested first. If the machine reported no PID beforehand,
// there is nothing to wait for and DestroyVM returns immediately. Otherwise the
// machine's PID is polled every pidCheckTickTime; if the process is still
// running once timeout has elapsed, it is killed with SIGKILL.
//
// The returned error is non-nil only when the SIGKILL could not be delivered
// for a reason other than the process already being gone.
func (c *Client) DestroyVM(
	ctx context.Context,
	m *firecracker.Machine,
	timeout time.Duration,
) error {
	// Remove the socket file on the way out. Best effort: a failure here is
	// not actionable for the caller.
	defer os.Remove(m.Cfg.SocketPath)

	// Capture the PID before requesting shutdown so the process can still be
	// killed if it does not exit on its own.
	pid, _ := m.PID()

	// Best effort: whether or not the shutdown request succeeds, the process
	// is polled below and killed if it outlives the timeout.
	_ = c.ShutdownVM(ctx, m)

	// If the process is not running, return early.
	if pid <= 0 {
		return nil
	}

	ticker := time.NewTicker(pidCheckTickTime)
	defer ticker.Stop()

	deadline := time.NewTimer(timeout)
	defer deadline.Stop()

	for {
		select {
		case <-ticker.C:
			// The machine no longer reports a PID: the process has exited.
			if current, _ := m.PID(); current <= 0 {
				return nil
			}

		case <-deadline.C:
			// The shutdown request timed out, kill the process with SIGKILL.
			// ESRCH means the process exited in the meantime, which is the
			// outcome we wanted anyway.
			if err := syscall.Kill(pid, syscall.SIGKILL); err != nil && err != syscall.ESRCH {
				return fmt.Errorf("failed to kill process: %w", err)
			}
			return nil
		}
	}
}
// FindVM finds a Firecracker VM by its socket path.
//
// It first checks that the socket file exists, then creates a machine instance
// bound to that socket and queries the Firecracker API for the instance info.
// The machine is returned only when the reported state is "Running"; in every
// other case an error is returned.
func (c *Client) FindVM(ctx context.Context, socketPath string) (*firecracker.Machine, error) {
	// Check if the socket file exists.
	if _, err := os.Stat(socketPath); err != nil {
		return nil, fmt.Errorf("VM with socket path %v not found", socketPath)
	}

	// Create a new Firecracker machine instance.
	cmd := firecracker.VMCommandBuilder{}.WithSocketPath(socketPath).Build(ctx)
	machine, err := firecracker.NewMachine(
		ctx,
		firecracker.Config{SocketPath: socketPath},
		firecracker.WithProcessRunner(cmd),
	)
	if err != nil {
		return nil, fmt.Errorf("failed to create machine with socket %s: %w", socketPath, err)
	}

	// Check if the VM is running by getting its instance info.
	info, err := machine.DescribeInstanceInfo(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to get instance info for socket %s: %w", socketPath, err)
	}

	// State is a pointer; guard against a response that carries no state
	// instead of dereferencing nil.
	if info.State == nil {
		return nil, fmt.Errorf("VM with socket %s reported no state", socketPath)
	}

	if *info.State != "Running" {
		return nil, fmt.Errorf(
			"VM with socket %s is not running, current state: %s",
			socketPath,
			*info.State,
		)
	}

	return machine, nil
}
package firecracker
import (
"context"
"fmt"
"io"
"os"
"sync"
"sync/atomic"
"time"
"github.com/firecracker-microvm/firecracker-go-sdk"
fcModels "github.com/firecracker-microvm/firecracker-go-sdk/client/models"
"go.uber.org/multierr"
"gitlab.com/nunet/device-management-service/models"
"gitlab.com/nunet/device-management-service/utils"
)
const (
// socketDir is the directory in which generateSocketPath places VM socket files.
socketDir = "/tmp"
// DefaultCpuCount and DefaultMemSize are default VM resource values.
// NOTE(review): neither is referenced in this file's visible code; the memory
// unit is presumably MiB (MachineConfiguration uses MemSizeMib) — confirm.
DefaultCpuCount int64 = 1
DefaultMemSize int64 = 50
)
// Executor manages the lifecycle of Firecracker VMs for execution requests.
type Executor struct {
// ID identifies this executor; it is part of every generated socket path.
ID string
handlers utils.SyncMap[string, *executionHandler] // Maps execution IDs to their handlers.
client *Client // Firecracker client for VM management.
}
// NewExecutor builds an Executor identified by id and backed by a freshly
// created Firecracker client. The context argument is unused. An error from
// creating the client is returned unchanged.
func NewExecutor(
	_ context.Context,
	id string,
) (*Executor, error) {
	client, err := NewFirecrackerClient()
	if err != nil {
		return nil, err
	}

	return &Executor{ID: id, client: client}, nil
}
// IsInstalled reports whether Firecracker is installed on the host by
// delegating to the client. The context argument is unused.
func (e *Executor) IsInstalled(ctx context.Context) bool {
	installed := e.client.IsInstalled()
	return installed
}
// Start begins the execution of a request by starting a new Firecracker VM.
//
// It first looks for a VM that is already running for this job/execution (the
// call may be the result of a restart). If none is found and a handler is
// already registered for the execution ID, an error is returned stating whether
// that execution is still active or already completed. Otherwise a new VM is
// created. A handler is then registered under the execution ID and run in its
// own goroutine; Start itself does not wait for the execution.
func (e *Executor) Start(ctx context.Context, request *models.ExecutionRequest) error {
	zlog.Sugar().
		Infof("Starting execution for job %s, execution %s", request.JobID, request.ExecutionID)

	// It's possible that this is being called due to a restart. We should check if the
	// VM is already running.
	machine, err := e.FindRunningVM(ctx, request.JobID, request.ExecutionID)
	if err != nil {
		// Unable to find a running VM for this execution, we will instead check for a handler, and
		// failing that will create a new VM.
		if handler, ok := e.handlers.Get(request.ExecutionID); ok {
			if handler.active() {
				return fmt.Errorf("execution is already started")
			}
			return fmt.Errorf("execution is already completed")
		}

		// Create a new VM for the execution.
		machine, err = e.newFirecrackerExecutionVM(ctx, request)
		if err != nil {
			return fmt.Errorf("failed to create new firecracker VM: %w", err)
		}
	}

	handler := &executionHandler{
		client:      e.client,
		ID:          e.ID,
		JobID:       request.JobID, // previously left empty
		executionID: request.ExecutionID,
		machine:     machine,
		resultsDir:  request.ResultsDir,
		waitCh:      make(chan bool),
		activeCh:    make(chan bool),
		running:     &atomic.Bool{},
	}

	// register the handler for this executionID
	e.handlers.Put(request.ExecutionID, handler)

	// run the VM.
	go handler.run(ctx)
	return nil
}
// Wait starts waiting for the execution identified by executionID and returns
// two channels, each buffered with capacity one: the first delivers the
// execution result, the second an error.
//
// If no handler is registered for executionID, an error is placed on the error
// channel right away and no goroutine is started. Otherwise doWait is launched
// in a goroutine to deliver the outcome asynchronously. Callers should select
// on both channels; this lets them synchronize with a preceding call to Start.
func (e *Executor) Wait(
	ctx context.Context,
	executionID string,
) (<-chan *models.ExecutionResult, <-chan error) {
	resultCh := make(chan *models.ExecutionResult, 1)
	errCh := make(chan error, 1)

	handler, found := e.handlers.Get(executionID)
	if found {
		go e.doWait(ctx, resultCh, errCh, handler)
		return resultCh, errCh
	}

	errCh <- fmt.Errorf("execution (%s) not found", executionID)
	return resultCh, errCh
}
// doWait blocks until the handler's waitCh is closed or ctx is done, then
// reports the outcome on exactly one of the two channels:
//   - ctx done first: ctx.Err() is sent on errCh
//   - execution finished with a result: the result is sent on out
//   - execution finished with a nil result: an error is sent on errCh, which
//     suggests a flaw in the executor logic
//
// Both channels are closed before doWait returns (errCh first, then out).
func (e *Executor) doWait(
	ctx context.Context,
	out chan *models.ExecutionResult,
	errCh chan error,
	handler *executionHandler,
) {
	zlog.Sugar().Infof("executionID %s waiting for execution", handler.executionID)

	defer func() {
		close(errCh)
		close(out)
	}()

	select {
	case <-handler.waitCh:
		result := handler.result
		if result == nil {
			errCh <- fmt.Errorf("execution (%s) result is nil", handler.executionID)
			return
		}
		zlog.Sugar().
			Infof("executionID %s recieved results from execution", handler.executionID)
		out <- result
	case <-ctx.Done():
		// Relay the cancellation to the caller.
		errCh <- ctx.Err()
	}
}
// Cancel looks up the handler registered for executionID and asks it to kill
// its VM. An error is returned when no handler is registered for that ID;
// otherwise the result of the kill request is returned.
func (e *Executor) Cancel(ctx context.Context, executionID string) error {
	if handler, found := e.handlers.Get(executionID); found {
		return handler.kill(ctx)
	}
	return fmt.Errorf("failed to cancel execution (%s). execution not found", executionID)
}
// Run initiates and waits for the completion of an execution in one call.
// This method serves as a higher-level convenience function that
// internally calls Start and Wait methods.
// It returns the result of the execution or an error if either starting
// or waiting fails, or if the context is canceled.
//
// The waiter behind Wait sends on exactly one of the two channels and then
// closes both. A receive can therefore succeed on the channel that was closed
// without a value; in that case the outcome is taken from the other channel
// instead of being reported as (nil, nil).
func (e *Executor) Run(
	ctx context.Context,
	request *models.ExecutionRequest,
) (*models.ExecutionResult, error) {
	if err := e.Start(ctx, request); err != nil {
		return nil, err
	}

	resCh, errCh := e.Wait(ctx, request.ExecutionID)
	select {
	case <-ctx.Done():
		return nil, ctx.Err()

	case out, ok := <-resCh:
		if ok {
			return out, nil
		}
		// Result channel closed empty: the error channel holds the outcome.
		if err, ok := <-errCh; ok {
			return nil, err
		}

	case err, ok := <-errCh:
		if ok {
			return nil, err
		}
		// Error channel closed empty: the result channel holds the outcome.
		if out, ok := <-resCh; ok {
			return out, nil
		}
	}

	return nil, fmt.Errorf(
		"execution (%s) finished without a result or an error",
		request.ExecutionID,
	)
}
// GetLogStream exists only to satisfy the Executor interface. Log streaming is
// not implemented for Firecracker, so it always returns a nil reader and an
// error; both arguments are ignored.
func (e *Executor) GetLogStream(
	ctx context.Context,
	request models.LogStreamRequest,
) (io.ReadCloser, error) {
	notImplemented := fmt.Errorf("GetLogStream is not implemented for Firecracker")
	return nil, notImplemented
}
// Cleanup destroys every registered execution handler concurrently, giving
// each one ten seconds, and returns the combination of all errors reported
// (nil when every handler was destroyed cleanly). The ctx argument is not used
// by the destroy calls.
func (e *Executor) Cleanup(ctx context.Context) error {
	var pending sync.WaitGroup
	results := make(chan error, len(e.handlers.Keys()))

	e.handlers.Iter(func(_ string, handler *executionHandler) bool {
		pending.Add(1)
		go func() {
			defer pending.Done()
			results <- handler.destroy(time.Second * 10)
		}()
		return true
	})

	// Close the results channel once every destroy call has reported, so the
	// collection loop below terminates.
	go func() {
		pending.Wait()
		close(results)
	}()

	var combined error
	for destroyErr := range results {
		combined = multierr.Append(combined, destroyErr)
	}

	zlog.Info("Cleaned up all firecracker resources")
	return combined
}
// newFirecrackerExecutionVM prepares a Firecracker VM for the given execution
// request on behalf of Start. It decodes the engine spec, assembles the machine
// configuration (VM ID, socket path, kernel image, initrd, kernel args, vCPU
// count and memory size taken from the request's resources), builds the drive
// list, and creates the machine through the client. The VM is created but not
// started. Any failing step is returned as a wrapped error.
func (e *Executor) newFirecrackerExecutionVM(
	ctx context.Context,
	params *models.ExecutionRequest,
) (*firecracker.Machine, error) {
	spec, err := DecodeSpec(params.EngineSpec)
	if err != nil {
		return nil, fmt.Errorf("failed to decode firecracker engine spec: %w", err)
	}

	vcpus := int64(params.Resources.CPU.Cores)
	memory := int64(params.Resources.Memory.Size)

	vmConfig := firecracker.Config{
		VMID:            params.ExecutionID,
		SocketPath:      e.generateSocketPath(params.JobID, params.ExecutionID),
		KernelImagePath: spec.KernelImage,
		InitrdPath:      spec.Initrd,
		KernelArgs:      spec.KernelArgs,
		MachineCfg: fcModels.MachineConfiguration{
			VcpuCount:  firecracker.Int64(vcpus),
			MemSizeMib: firecracker.Int64(memory),
		},
	}

	drives, err := makeVMMounts(spec.RootFileSystem, params.Inputs, params.Outputs, params.ResultsDir)
	if err != nil {
		return nil, fmt.Errorf("failed to create VM mounts: %w", err)
	}
	vmConfig.Drives = drives

	vm, err := e.client.CreateVM(ctx, vmConfig)
	if err != nil {
		return nil, fmt.Errorf("failed to create VM: %w", err)
	}

	// e.client.VMPassMMDs(ctx, machine, fcArgs.MMDSMessage)

	return vm, nil
}
// makeVMMounts creates the drive list for the VM: the root file system plus
// one drive per input volume and one writable drive per output volume.
//
// For every output volume the source must be non-empty and resultsDir must be
// set; the results directory is created if it does not exist. The function
// returns the drives, or a nil slice and an error if any check or the directory
// creation fails.
func makeVMMounts(
	rootFileSystem string,
	inputs []*models.StorageVolumeExecutor,
	outputs []*models.StorageVolumeExecutor,
	resultsDir string,
) ([]fcModels.Drive, error) {
	var drives []fcModels.Drive
	drivesBuilder := firecracker.NewDrivesBuilder(rootFileSystem)

	// AddDrive returns the updated builder rather than mutating its receiver,
	// so the result must be kept or the drive is silently dropped.
	for _, input := range inputs {
		drivesBuilder = drivesBuilder.AddDrive(input.Source, input.ReadOnly)
	}

	for _, output := range outputs {
		if output.Source == "" {
			return drives, fmt.Errorf("output source is empty")
		}

		if resultsDir == "" {
			return drives, fmt.Errorf("results directory is empty")
		}

		if err := os.MkdirAll(resultsDir, os.ModePerm); err != nil {
			return drives, fmt.Errorf("failed to create results directory: %w", err)
		}

		// Output drives are always attached writable.
		drivesBuilder = drivesBuilder.AddDrive(output.Source, false)
	}

	drives = drivesBuilder.Build()
	return drives, nil
}
// FindRunningVM looks up the VM running the execution identified by jobID and
// executionID. It derives the socket path for that execution and asks the
// client to find the VM behind it, returning the Machine instance on success
// or the client's error otherwise.
func (e *Executor) FindRunningVM(
	ctx context.Context,
	jobID string,
	executionID string,
) (*firecracker.Machine, error) {
	socketPath := e.generateSocketPath(jobID, executionID)
	return e.client.FindVM(ctx, socketPath)
}
// generateSocketPath returns the socket path for an execution, of the form
// "<socketDir>/<executor ID>_<job ID>_<execution ID>.sock".
func (e *Executor) generateSocketPath(jobID string, executionID string) string {
	fileName := fmt.Sprintf("%s_%s_%s.sock", e.ID, jobID, executionID)
	return socketDir + "/" + fileName
}
package firecracker
import (
"context"
"fmt"
"sync/atomic"
"time"
"github.com/firecracker-microvm/firecracker-go-sdk"
"gitlab.com/nunet/device-management-service/models"
)
// executionHandler is a struct that holds the necessary information to manage the execution of a firecracker VM.
type executionHandler struct {
// provided by the executor
ID string
client *Client
// metadata about the task
JobID string
executionID string
machine *firecracker.Machine
resultsDir string
// synchronization
activeCh chan bool // Closed by run once the VM has been started successfully.
waitCh chan bool // Closed by run when it returns, i.e. when execution completes or fails.
running *atomic.Bool // True while run is in progress.
// result of the execution
result *models.ExecutionResult
}
// active reports the current value of the handler's atomic running flag,
// which run keeps set while it is managing the firecracker VM.
func (h *executionHandler) active() bool {
	isRunning := h.running.Load()
	return isRunning
}
// run starts the firecracker VM, waits for it to finish and records the
// outcome in h.result. activeCh is closed only after a successful start. On
// every return path the deferred cleanup destroys the VM (with a ten second
// timeout), clears the running flag and closes waitCh.
func (h *executionHandler) run(ctx context.Context) {
	h.running.Store(true)

	defer func() {
		if err := h.destroy(time.Second * 10); err != nil {
			zlog.Sugar().Warnf("failed to destroy container: %v\n", err)
		}
		h.running.Store(false)
		close(h.waitCh)
	}()

	// Start the VM.
	zlog.Sugar().Info("starting firecracker execution")
	startErr := h.client.StartVM(ctx, h.machine)
	if startErr != nil {
		h.result = models.NewFailedExecutionResult(fmt.Errorf("failed to start VM: %v", startErr))
		return
	}

	close(h.activeCh) // Indicate that the VM has started.

	waitErr := h.machine.Wait(ctx)
	switch {
	case waitErr == nil:
		h.result = models.NewExecutionResult(models.ExecutionStatusCodeSuccess)
	case ctx.Err() != nil:
		// The wait failed and the context is done: report it as a context closure.
		h.result = models.NewFailedExecutionResult(
			fmt.Errorf("context closed while waiting on VM: %v", waitErr),
		)
	default:
		h.result = models.NewFailedExecutionResult(fmt.Errorf("failed to wait on VM: %v", waitErr))
	}
}
// kill asks the client to shut down the handler's firecracker VM and returns
// the resulting error.
func (h *executionHandler) kill(ctx context.Context) error {
	shutdownErr := h.client.ShutdownVM(ctx, h.machine)
	return shutdownErr
}
// destroy asks the client to destroy the handler's firecracker VM. The call is
// made with a fresh background context limited to timeout, and the same
// timeout is passed on to the client.
func (h *executionHandler) destroy(timeout time.Duration) error {
	deadlineCtx, release := context.WithTimeout(context.Background(), timeout)
	defer release()

	destroyErr := h.client.DestroyVM(deadlineCtx, h.machine, timeout)
	return destroyErr
}
package firecracker
import (
"gitlab.com/nunet/device-management-service/telemetry/logger"
)
// zlog is the package-level logger used by the firecracker executor code.
var zlog *logger.Logger
// init assigns zlog a logger created under the "executor.firecracker" name.
func init() {
zlog = logger.New("executor.firecracker")
}
package firecracker
import (
"encoding/json"
"fmt"
"gitlab.com/nunet/device-management-service/models"
"gitlab.com/nunet/device-management-service/utils/validate"
)
// Parameter keys used when building a firecracker SpecConfig. Each value
// matches the JSON tag of the corresponding EngineSpec field. There is no key
// constant here for the EngineSpec initrd field ("initrd_path").
const (
EngineKeyKernelImage = "kernel_image"
EngineKeyKernelArgs = "kernel_args"
EngineKeyRootFileSystem = "root_file_system"
EngineKeyMMDSMessage = "mmds_message"
)
// EngineSpec contains necessary parameters to execute a firecracker job.
// All fields are omitted from the JSON encoding when empty; Validate requires
// RootFileSystem and KernelImage to be non-blank.
type EngineSpec struct {
// KernelImage is the path to the kernel image file.
KernelImage string `json:"kernel_image,omitempty"`
// Initrd is the path to the initial ramdisk file (JSON key "initrd_path").
Initrd string `json:"initrd_path,omitempty"`
// KernelArgs is the kernel command line arguments.
KernelArgs string `json:"kernel_args,omitempty"`
// RootFileSystem is the path to the root file system.
RootFileSystem string `json:"root_file_system,omitempty"`
// MMDSMessage is the MMDS message to be sent to the Firecracker VM.
MMDSMessage string `json:"mmds_message,omitempty"`
}
// Validate reports an error when a required field is blank. RootFileSystem is
// checked first, then KernelImage; the first blank one determines the error.
func (c EngineSpec) Validate() error {
	switch {
	case validate.IsBlank(c.RootFileSystem):
		return fmt.Errorf("invalid firecracker engine params: root_file_system cannot be empty")
	case validate.IsBlank(c.KernelImage):
		return fmt.Errorf("invalid firecracker engine params: kernel_image cannot be empty")
	default:
		return nil
	}
}
// DecodeSpec decodes a spec config into a firecracker engine spec.
//
// The spec must be non-nil, of type models.ExecutorTypeFirecracker and carry
// non-nil params. The params are round-tripped through JSON into an EngineSpec,
// which is then validated. On a validation failure the decoded spec is still
// returned alongside the validation error; on any other failure a zero
// EngineSpec is returned.
func DecodeSpec(spec *models.SpecConfig) (EngineSpec, error) {
	// Guard against a nil spec, which would otherwise panic on field access.
	if spec == nil {
		return EngineSpec{}, fmt.Errorf("invalid firecracker engine spec: spec cannot be nil")
	}

	if !spec.IsType(models.ExecutorTypeFirecracker) {
		return EngineSpec{}, fmt.Errorf(
			"invalid firecracker engine type. expected %s, but recieved: %s",
			models.ExecutorTypeFirecracker,
			spec.Type,
		)
	}

	inputParams := spec.Params
	if inputParams == nil {
		return EngineSpec{}, fmt.Errorf("invalid firecracker engine params: params cannot be nil")
	}

	paramBytes, err := json.Marshal(inputParams)
	if err != nil {
		return EngineSpec{}, fmt.Errorf("failed to encode firecracker engine params: %w", err)
	}

	// Decode into a value rather than through a nil pointer so the result can
	// never be a nil dereference, whatever the encoded params look like.
	var firecrackerSpec EngineSpec
	if err := json.Unmarshal(paramBytes, &firecrackerSpec); err != nil {
		return EngineSpec{}, fmt.Errorf("failed to decode firecracker engine params: %w", err)
	}

	return firecrackerSpec, firecrackerSpec.Validate()
}
// FirecrackerEngineBuilder is a struct that is used for constructing a
// firecracker SpecConfig using the Builder pattern.
// It holds a pointer to the models.SpecConfig being built; the With* methods
// set params on it and Build returns it.
type FirecrackerEngineBuilder struct {
eb *models.SpecConfig // spec config under construction
}
// NewFirecrackerEngineBuilder returns a builder whose spec config is created
// with the models.ExecutorTypeFirecracker type and has its root file system
// param set to rootFileSystem.
func NewFirecrackerEngineBuilder(rootFileSystem string) *FirecrackerEngineBuilder {
	builder := &FirecrackerEngineBuilder{
		eb: models.NewSpecConfig(models.ExecutorTypeFirecracker),
	}
	builder.eb.WithParam(EngineKeyRootFileSystem, rootFileSystem)
	return builder
}
// WithRootFileSystem stores e under the root file system param key and returns
// the same builder so calls can be chained.
func (b *FirecrackerEngineBuilder) WithRootFileSystem(e string) *FirecrackerEngineBuilder {
	rootFS := e
	b.eb.WithParam(EngineKeyRootFileSystem, rootFS)

	return b
}
// WithKernelImage stores e under the kernel image param key and returns the
// same builder so calls can be chained.
func (b *FirecrackerEngineBuilder) WithKernelImage(e string) *FirecrackerEngineBuilder {
	kernelImage := e
	b.eb.WithParam(EngineKeyKernelImage, kernelImage)

	return b
}
// WithKernelArgs stores e under the kernel args param key and returns the same
// builder so calls can be chained.
func (b *FirecrackerEngineBuilder) WithKernelArgs(e string) *FirecrackerEngineBuilder {
	kernelArgs := e
	b.eb.WithParam(EngineKeyKernelArgs, kernelArgs)

	return b
}
// WithMMDSMessage stores e under the MMDS message param key and returns the
// same builder so calls can be chained.
func (b *FirecrackerEngineBuilder) WithMMDSMessage(e string) *FirecrackerEngineBuilder {
	message := e
	b.eb.WithParam(EngineKeyMMDSMessage, message)

	return b
}
// Build returns the spec config accumulated by the builder. The returned
// pointer is the builder's own config, not a copy.
func (b *FirecrackerEngineBuilder) Build() *models.SpecConfig {
	spec := b.eb
	return spec
}
package background_tasks
import (
"github.com/uptrace/opentelemetry-go-extra/otelzap"
"gitlab.com/nunet/device-management-service/telemetry/logger"
)
// zlog is the package-level logger used by the background_tasks package.
var zlog *otelzap.Logger
// init assigns zlog the logger obtained from logger.OtelZapLogger for "background_tasks".
func init() {
zlog = logger.OtelZapLogger("background_tasks")
}
package background_tasks
import (
"sort"
"sync"
"time"
)
// Scheduler orchestrates the execution of tasks based on their triggers and priority.
// Once started, it checks task triggers on every tick of its ticker and runs
// ready tasks in their own goroutines, up to maxRunningTasks at a time.
type Scheduler struct {
tasks map[int]*Task // Map of tasks by their ID.
runningTasks map[int]bool // Map to keep track of running tasks (true while a task is running).
ticker *time.Ticker // Ticker for periodic checks of task triggers.
stopChan chan struct{} // Closed by Stop to end the scheduler loop.
maxRunningTasks int // Maximum number of tasks that can run concurrently.
lastTaskID int // Next ID to assign; incremented after each AddTask.
mu sync.Mutex // Mutex to protect access to task maps.
}
// NewScheduler returns a Scheduler that allows at most maxRunningTasks tasks
// to run at once. Its task maps are empty, its ticker fires once per second,
// and task IDs start at zero. The scheduler does nothing until Start is called.
func NewScheduler(maxRunningTasks int) *Scheduler {
	scheduler := new(Scheduler)
	scheduler.tasks = make(map[int]*Task)
	scheduler.runningTasks = make(map[int]bool)
	scheduler.ticker = time.NewTicker(1 * time.Second)
	scheduler.stopChan = make(chan struct{})
	scheduler.maxRunningTasks = maxRunningTasks
	scheduler.lastTaskID = 0
	return scheduler
}
// AddTask registers task with the scheduler under the mutex. The task receives
// the next free ID, is marked enabled, and has every one of its triggers
// reset before it is stored. The same task pointer is returned.
func (s *Scheduler) AddTask(task *Task) *Task {
	s.mu.Lock()
	defer s.mu.Unlock()

	id := s.lastTaskID
	task.ID = id
	task.Enabled = true

	for _, tr := range task.Triggers {
		tr.Reset()
	}

	s.tasks[id] = task
	s.lastTaskID = id + 1

	return task
}
// RemoveTask deletes the task with the given ID from the scheduler's task map
// while holding the mutex. Removing an unknown ID is a no-op.
func (s *Scheduler) RemoveTask(taskID int) {
	s.mu.Lock()
	delete(s.tasks, taskID)
	s.mu.Unlock()
}
// Start launches the scheduler loop in a new goroutine and returns
// immediately. The loop calls runTasks on every ticker tick and exits once
// stopChan is closed.
func (s *Scheduler) Start() {
	loop := func() {
		for {
			select {
			case <-s.ticker.C:
				s.runTasks()
			case <-s.stopChan:
				return
			}
		}
	}
	go loop()
}
// runningTasksCount returns how many entries of runningTasks are currently
// marked true. The map is read while holding the mutex.
func (s *Scheduler) runningTasksCount() int {
	s.mu.Lock()
	defer s.mu.Unlock()

	var active int
	for _, running := range s.runningTasks {
		if !running {
			continue
		}
		active++
	}
	return active
}
// runTasks checks and runs tasks based on their triggers and priority.
//
// Tasks are considered in descending priority order. Disabled or already
// running tasks are skipped, and tasks without triggers are removed. A task is
// started in its own goroutine when there is spare capacity and one of its
// triggers is ready; that trigger is then reset.
//
// The task maps are shared with AddTask, RemoveTask and runTask, so every
// access here is made while holding the mutex. The mutex is never held across
// calls to RemoveTask or runningTasksCount, which lock it themselves.
func (s *Scheduler) runTasks() {
	// Snapshot the tasks under the lock, then sort by priority.
	s.mu.Lock()
	sortedTasks := make([]*Task, 0, len(s.tasks))
	for _, task := range s.tasks {
		sortedTasks = append(sortedTasks, task)
	}
	s.mu.Unlock()

	sort.Slice(sortedTasks, func(i, j int) bool {
		return sortedTasks[i].Priority > sortedTasks[j].Priority
	})

	for _, task := range sortedTasks {
		s.mu.Lock()
		alreadyRunning := s.runningTasks[task.ID]
		s.mu.Unlock()

		if !task.Enabled || alreadyRunning {
			continue
		}

		if len(task.Triggers) == 0 {
			s.RemoveTask(task.ID)
			continue
		}

		for _, trigger := range task.Triggers {
			// Check capacity before asking the trigger: IsReady may consume
			// state (an EventTrigger drains its channel), and that signal
			// would be lost if the task could not be started anyway.
			if s.runningTasksCount() < s.maxRunningTasks && trigger.IsReady() {
				s.mu.Lock()
				s.runningTasks[task.ID] = true
				s.mu.Unlock()

				go s.runTask(task.ID)
				trigger.Reset()
				break
			}
		}
	}
}
// Stop signals the scheduler to stop running tasks. It stops the ticker so it
// no longer fires and closes stopChan, which ends the loop started by Start.
// Tasks that are already running are not interrupted. Stop must be called at
// most once: closing stopChan a second time panics.
func (s *Scheduler) Stop() {
	if s.ticker != nil {
		s.ticker.Stop()
	}
	close(s.stopChan)
}
// runTask executes a task and manages its lifecycle and retry policy.
//
// The task function is attempted up to RetryPolicy.MaxRetries+1 times. The
// execution is recorded as "SUCCESS" on the first attempt that returns nil and
// as "FAILED" once all attempts are used up; the last error message is kept in
// the record. When runTask returns, the execution is appended to the task's
// history and the task is marked as no longer running.
//
// If the task no longer exists (for example it was removed between being
// scheduled and this goroutine starting), runTask only clears the running mark.
func (s *Scheduler) runTask(taskID int) {
	defer func() {
		s.mu.Lock()
		defer s.mu.Unlock()
		s.runningTasks[taskID] = false
	}()

	// The task map is shared with other goroutines; read it under the lock.
	s.mu.Lock()
	task, ok := s.tasks[taskID]
	s.mu.Unlock()
	if !ok || task == nil {
		return
	}

	execution := Execution{StartedAt: time.Now()}

	defer func() {
		s.mu.Lock()
		task.ExecutionHist = append(task.ExecutionHist, execution)
		// NOTE(review): this re-inserts the task if RemoveTask was called
		// while it was running — confirm that is intended.
		s.tasks[taskID] = task
		s.mu.Unlock()
	}()

	for i := 0; i < task.RetryPolicy.MaxRetries+1; i++ {
		err := runTaskWithRetry(task.Function, task.Args, task.RetryPolicy.Delay)
		if err == nil {
			execution.Status = "SUCCESS"
			execution.EndedAt = time.Now()
			return
		}
		execution.Error = err.Error()
	}

	execution.Status = "FAILED"
	execution.EndedAt = time.Now()
}
// runTaskWithRetry performs a single attempt of fn, passing args as its one
// argument. On success it returns nil right away. On failure it sleeps for
// delay and then returns fn's error, leaving the decision to retry to the
// caller.
func runTaskWithRetry(
	fn func(args interface{}) error,
	args []interface{},
	delay time.Duration,
) error {
	attemptErr := fn(args)
	if attemptErr == nil {
		return nil
	}

	// Back off before handing the failure to the caller.
	time.Sleep(delay)
	return attemptErr
}
package background_tasks
import (
"time"
"github.com/robfig/cron/v3"
)
// Trigger decides when a task should run. The scheduler calls IsReady to find
// out whether the trigger condition is met, and Reset when a task is added and
// after a task has been started because of this trigger.
type Trigger interface {
IsReady() bool // Returns true if the trigger condition is met.
Reset() // Resets the trigger state.
}
// PeriodicTrigger triggers at regular intervals or based on a cron expression.
type PeriodicTrigger struct {
Interval time.Duration // Interval for periodic triggering.
CronExpr string // Cron expression for triggering; ignored when empty.
lastTriggered time.Time // Last time the trigger was reset.
}
// IsReady checks if the trigger should activate based on time or cron expression.
//
// The trigger is ready when Interval has elapsed since the last reset, or when
// the next time matching CronExpr after the last reset has passed. A trigger
// with only a cron expression (zero Interval) is governed by the cron schedule
// alone; otherwise the zero interval would make it ready on every call. A cron
// expression that fails to parse is logged and treated as not ready.
func (t *PeriodicTrigger) IsReady() bool {
	now := time.Now()

	// Trigger based on interval. The interval check is skipped only for
	// cron-only triggers (cron expression set, no positive interval).
	if t.Interval > 0 || t.CronExpr == "" {
		if t.lastTriggered.Add(t.Interval).Before(now) {
			return true
		}
	}

	if t.CronExpr == "" {
		return false
	}

	// Trigger based on cron expression.
	cronExpr, err := cron.ParseStandard(t.CronExpr)
	if err != nil {
		zlog.Sugar().Errorf("Error parsing CronExpr: %v", err)
		return false
	}

	nextCronTriggerTime := cronExpr.Next(t.lastTriggered)
	return nextCronTriggerTime.Before(now)
}
// Reset records the current time as the moment the trigger last fired, which
// restarts both the interval and the cron calculation in IsReady.
func (t *PeriodicTrigger) Reset() {
	now := time.Now()
	t.lastTriggered = now
}
// PeriodicTriggerWithJitter triggers at regular intervals or based on a cron
// expression, adding the duration returned by Jitter to the interval each time
// readiness is checked.
type PeriodicTriggerWithJitter struct {
Interval time.Duration // Interval for periodic triggering.
CronExpr string // Cron expression for triggering; ignored when empty.
lastTriggered time.Time // Last time the trigger was reset.
Jitter func() time.Duration // Returns the extra delay added to Interval.
}
// IsReady checks if the trigger should activate based on time or cron expression.
func (t *PeriodicTriggerWithJitter) IsReady() bool {
// Trigger based on interval.
if t.lastTriggered.Add(t.Interval + t.Jitter()).Before(time.Now()) {
return true
}
// Trigger based on cron expression.
if t.CronExpr != "" {
cronExpr, err := cron.ParseStandard(t.CronExpr)
if err != nil {
zlog.Sugar().Errorf("Error parsing CronExpr: %v", err)
return false
}
nextCronTriggerTime := cronExpr.Next(t.lastTriggered)
return nextCronTriggerTime.Before(time.Now())
}
return false
}
// Reset updates the last triggered time to the current time.
func (t *PeriodicTriggerWithJitter) Reset() {
t.lastTriggered = time.Now()
}
// EventTrigger becomes ready whenever a value can be received from its
// Trigger channel.
type EventTrigger struct {
	Trigger chan bool // Channel on which an external party signals the event.
}

// IsReady performs a non-blocking receive on the Trigger channel. It reports
// true if a value was received (consuming it) and false when nothing is
// pending.
func (t *EventTrigger) IsReady() bool {
	signalled := false
	select {
	case <-t.Trigger:
		signalled = true
	default:
		// nothing pending; fall through with signalled == false
	}
	return signalled
}

// Reset is intentionally empty: the trigger keeps no state of its own, the
// channel is driven by whoever sends on it.
func (t *EventTrigger) Reset() {
}
// OneTimeTrigger becomes ready once Delay has elapsed since the time recorded
// by Reset.
type OneTimeTrigger struct {
	Delay        time.Duration // How long after registration the trigger becomes ready.
	registeredAt time.Time     // Reference time set by Reset; the zero time until then.
}

// Reset records the current time as the moment the delay starts counting.
func (t *OneTimeTrigger) Reset() {
	t.registeredAt = time.Now()
}

// IsReady reports whether the current time is strictly later than
// registeredAt + Delay.
func (t *OneTimeTrigger) IsReady() bool {
	deadline := t.registeredAt.Add(t.Delay)
	return time.Now().After(deadline)
}
package libp2p
import (
"bytes"
"context"
"errors"
"fmt"
"strings"
"sync"
"github.com/gin-gonic/gin"
"github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/peerstore"
"github.com/multiformats/go-multiaddr"
commonproto "gitlab.com/nunet/device-management-service/proto/generated/v1/common"
"google.golang.org/protobuf/proto"
)
// Bootstrap prepares the DHT and then dials every address in bootstrapPeers.
//
// A failure of the DHT bootstrap itself is returned to the caller. Failures to
// parse or reach an individual bootstrap peer are only logged; the method
// waits for all dial attempts to finish and then returns nil.
func (l *Libp2p) Bootstrap(ctx context.Context, bootstrapPeers []multiaddr.Multiaddr) error {
	if err := l.DHT.Bootstrap(ctx); err != nil {
		return fmt.Errorf("failed to prepare this node for bootstraping: %w", err)
	}
	if len(bootstrapPeers) == 0 {
		return nil
	}

	// dial handles a single bootstrap address; every outcome is logged only.
	dial := func(peerAddr multiaddr.Multiaddr) {
		info, err := peer.AddrInfoFromP2pAddr(peerAddr)
		if err != nil {
			zlog.Sugar().Errorf("failed to convert multi addr to addr info %v - %v", peerAddr, err)
			return
		}
		if err = l.Host.Connect(ctx, *info); err != nil {
			zlog.Sugar().Errorf("failed to connect to bootstrap node %s - %v", info.ID.String(), err)
			return
		}
		zlog.Sugar().Infof("connected to Bootstrap Node %s", info.ID.String())
	}

	// All peers are contacted at the same time, one goroutine each.
	var pending sync.WaitGroup
	pending.Add(len(bootstrapPeers))
	for _, addr := range bootstrapPeers {
		go func(target multiaddr.Multiaddr) {
			defer pending.Done()
			dial(target)
		}(addr)
	}
	pending.Wait()
	return nil
}
// dhtValidator validates records stored in the DHT under the custom namespace.
// It provides the Validate and Select methods below.
type dhtValidator struct {
PS peerstore.Peerstore
// customNamespace is the key prefix Validate requires.
// NOTE(review): NewHost constructs the validator as dhtValidator{PS: ps}, which
// leaves this field empty, so the prefix check in Validate always passes there —
// confirm whether that is intended.
customNamespace string
}
// Validate validates an item placed into the dht.
//
// An empty value is accepted unconditionally (it represents a deletion).
// Otherwise the key must start with customNamespace and the value must be a
// marshalled commonproto.Advertisement whose Signature verifies against the
// secp256k1 public key embedded in the same envelope.
func (d dhtValidator) Validate(key string, value []byte) error {
// empty value is considered deleting an item from the dht
if len(value) == 0 {
return nil
}
if !strings.HasPrefix(key, d.customNamespace) {
return errors.New("invalid key namespace")
}
// verify signature
var envelope commonproto.Advertisement
err := proto.Unmarshal(value, &envelope)
if err != nil {
return fmt.Errorf("failed to unmarshal envelope: %w", err)
}
pubKey, err := crypto.UnmarshalSecp256k1PublicKey(envelope.PublicKey)
if err != nil {
return fmt.Errorf("failed to unmarshal public key: %w", err)
}
// Rebuild the signed payload. The byte layout must stay identical to the one
// assembled in Libp2p.Advertise. Note that byte(envelope.Timestamp) keeps only
// the least-significant byte of the timestamp, so only those 8 bits are covered
// by the signature.
concatenatedBytes := bytes.Join([][]byte{
[]byte(envelope.PeerId),
{byte(envelope.Timestamp)},
envelope.Data,
envelope.PublicKey,
}, nil)
ok, err := pubKey.Verify(concatenatedBytes, envelope.Signature)
if err != nil {
return fmt.Errorf("failed to verify envelope: %w", err)
}
if !ok {
return errors.New("failed to verify envelope, public key didn't sign payload")
}
return nil
}
// Select always picks the first candidate (index 0) and never returns an error.
func (dhtValidator) Select(_ string, _ [][]byte) (int, error) { return 0, nil }
// TODO remove the below when network package is fully implemented
// UpdateKadDHT is a stub: it only logs a warning and performs no DHT update.
func (l *Libp2p) UpdateKadDHT() {
zlog.Warn("UpdateKadDHT: Stub")
}
// ListKadDHTPeers is a stub: it logs a warning, ignores both arguments and
// always returns a nil slice and a nil error.
func (l *Libp2p) ListKadDHTPeers(c *gin.Context, ctx context.Context) ([]string, error) {
zlog.Warn("ListKadDHTPeers: Stub")
return nil, nil
}
package libp2p
import (
"context"
"fmt"
"os"
"github.com/libp2p/go-libp2p/core/discovery"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
dutil "github.com/libp2p/go-libp2p/p2p/discovery/util"
)
// DiscoverDialPeers looks up peers at the configured rendezvous point, stores
// them and attempts to connect to each one.
//
// A non-empty discovery result replaces the stored peer list; an empty result
// keeps the previous list. The list is then reduced to peers that have at
// least one address and are not this host before dialing. Only a discovery
// failure is returned as an error.
func (l *Libp2p) DiscoverDialPeers(ctx context.Context) error {
	found, err := l.findPeersFromRendezvousDiscovery(ctx)
	if err != nil {
		return err
	}
	if len(found) != 0 {
		l.discoveredPeers = found
	}

	// drop peers without listening addresses as well as ourselves
	l.discoveredPeers = PeerPassFilter(l.discoveredPeers, NoAddrIDFilter{ID: l.Host.ID()})

	l.dialPeers(ctx)
	return nil
}
// advertiseForRendezvousDiscovery announces this node under the configured
// rendezvous string through the discovery service and returns any error the
// announcement produced. The first value returned by Advertise is discarded.
func (l *Libp2p) advertiseForRendezvousDiscovery(context context.Context) error {
	if _, err := l.discovery.Advertise(context, l.config.Rendezvous); err != nil {
		return err
	}
	return nil
}
// findPeersFromRendezvousDiscovery asks the discovery service for peers that
// advertised the configured rendezvous string, requesting at most
// PeerCountDiscoveryLimit results. Lookup errors are wrapped and returned.
func (l *Libp2p) findPeersFromRendezvousDiscovery(ctx context.Context) ([]peer.AddrInfo, error) {
	limit := discovery.Limit(l.config.PeerCountDiscoveryLimit)

	found, err := dutil.FindPeers(ctx, l.discovery, l.config.Rendezvous, limit)
	if err != nil {
		return nil, fmt.Errorf("failed to discover peers: %w", err)
	}
	return found, nil
}
// dialPeers tries to connect to every stored discovered peer that is neither
// this host nor already connected. Dial outcomes are never returned; they are
// logged at debug level, and only when the NUNET_DEBUG_VERBOSE environment
// variable is set.
func (l *Libp2p) dialPeers(ctx context.Context) {
	// verbose reports whether per-peer dial logging is enabled.
	verbose := func() bool {
		_, set := os.LookupEnv("NUNET_DEBUG_VERBOSE")
		return set
	}

	for _, candidate := range l.discoveredPeers {
		if candidate.ID == l.Host.ID() {
			continue
		}
		if l.Host.Network().Connectedness(candidate.ID) == network.Connected {
			continue
		}

		if _, err := l.Host.Network().DialPeer(ctx, candidate.ID); err != nil {
			if verbose() {
				zlog.Sugar().Debugf("couldn't establish connection with: %s - error: %v", candidate.ID.String(), err)
			}
			continue
		}
		if verbose() {
			zlog.Sugar().Debugf("connected with: %s", candidate.ID.String())
		}
	}
}
package libp2p
import (
"github.com/libp2p/go-libp2p/core/connmgr"
"github.com/libp2p/go-libp2p/core/control"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/multiformats/go-multiaddr"
mafilt "github.com/whyrusleeping/multiaddr-filter"
)
// defaultServerFilters lists IPv4 and IPv6 CIDR ranges, written as
// /ipcidr multiaddr masks, covering private, link-local, documentation and
// other special-purpose address space. NewHost turns every entry into a deny
// filter and, when the node runs in server mode, also passes the list to
// makeAddrsFactory as the set of addresses not to announce.
var defaultServerFilters = []string{
"/ip4/10.0.0.0/ipcidr/8",
"/ip4/100.64.0.0/ipcidr/10",
"/ip4/169.254.0.0/ipcidr/16",
"/ip4/172.16.0.0/ipcidr/12",
"/ip4/192.0.0.0/ipcidr/24",
"/ip4/192.0.2.0/ipcidr/24",
"/ip4/192.168.0.0/ipcidr/16",
"/ip4/198.18.0.0/ipcidr/15",
"/ip4/198.51.100.0/ipcidr/24",
"/ip4/203.0.113.0/ipcidr/24",
"/ip4/240.0.0.0/ipcidr/4",
"/ip6/100::/ipcidr/64",
"/ip6/2001:2::/ipcidr/48",
"/ip6/2001:db8::/ipcidr/32",
"/ip6/fc00::/ipcidr/7",
"/ip6/fe80::/ipcidr/10",
}
// PeerFilter is an interface for filtering peers.
// A peer passes the filter when satisfies returns true for it; PeerPassFilter
// keeps exactly those peers.
type PeerFilter interface {
// satisfies reports whether p meets the filter criteria.
satisfies(p peer.AddrInfo) bool
}
// NoAddrIDFilter rejects peers that have no addresses as well as the peer
// whose ID equals ID.
type NoAddrIDFilter struct {
	ID peer.ID // peer ID that must not pass the filter
}

// satisfies reports whether p has at least one address and a different ID
// than the one held by the filter.
func (f NoAddrIDFilter) satisfies(p peer.AddrInfo) bool {
	if len(p.Addrs) == 0 {
		return false
	}
	return p.ID != f.ID
}
// PeerPassFilter returns, in their original order, the peers for which
// pf.satisfies reports true. The result is nil when no peer passes.
func PeerPassFilter(peers []peer.AddrInfo, pf PeerFilter) []peer.AddrInfo {
	var kept []peer.AddrInfo
	for i := range peers {
		if !pf.satisfies(peers[i]) {
			continue
		}
		kept = append(kept, peers[i])
	}
	return kept
}
// filtersConnectionGater adapts a multiaddr.Filters set to the
// connmgr.ConnectionGater interface: addresses blocked by the filters are
// refused, everything else is allowed.
type filtersConnectionGater multiaddr.Filters

// Compile-time check that the adapter implements connmgr.ConnectionGater.
var _ connmgr.ConnectionGater = (*filtersConnectionGater)(nil)

// InterceptAddrDial allows an outgoing dial unless addr is blocked by the filters.
func (f *filtersConnectionGater) InterceptAddrDial(_ peer.ID, addr multiaddr.Multiaddr) (allow bool) {
	blocked := (*multiaddr.Filters)(f).AddrBlocked(addr)
	return !blocked
}

// InterceptPeerDial allows dialing any peer; filtering happens per address.
func (f *filtersConnectionGater) InterceptPeerDial(p peer.ID) (allow bool) {
	return true
}

// InterceptAccept allows an inbound connection unless its remote address is
// blocked by the filters.
func (f *filtersConnectionGater) InterceptAccept(connAddr network.ConnMultiaddrs) (allow bool) {
	remote := connAddr.RemoteMultiaddr()
	blocked := (*multiaddr.Filters)(f).AddrBlocked(remote)
	return !blocked
}

// InterceptSecured allows a secured connection unless its remote address is
// blocked by the filters; direction and peer ID are not considered.
func (f *filtersConnectionGater) InterceptSecured(_ network.Direction, _ peer.ID, connAddr network.ConnMultiaddrs) (allow bool) {
	remote := connAddr.RemoteMultiaddr()
	blocked := (*multiaddr.Filters)(f).AddrBlocked(remote)
	return !blocked
}

// InterceptUpgraded always allows the fully upgraded connection and reports
// disconnect reason 0.
func (f *filtersConnectionGater) InterceptUpgraded(_ network.Conn) (allow bool, reason control.DisconnectReason) {
	return true, 0
}
// makeAddrsFactory builds the callback that decides which addresses the host
// announces.
//
//   - announce: if non-empty, these addresses replace the host's own addresses.
//   - appendAnnouce: addresses added after the chosen base list; entries whose
//     string already appears in announce are skipped.
//   - noAnnounce: addresses to suppress. Each entry is first tried as a mask via
//     mafilt.NewMask (added as a deny filter); if that fails it is parsed as a
//     plain multiaddr and suppressed on exact byte match.
//
// It returns nil when any entry in the three lists cannot be parsed.
// NOTE(review): NewHost passes the result straight to libp2p.AddrsFactory
// without a nil check — confirm a nil factory cannot occur with the inputs used.
func makeAddrsFactory(announce []string, appendAnnouce []string, noAnnounce []string) func([]multiaddr.Multiaddr) []multiaddr.Multiaddr {
var err error // To assign to the slice in the for loop
existing := make(map[string]bool) // To avoid duplicates
annAddrs := make([]multiaddr.Multiaddr, len(announce))
for i, addr := range announce {
annAddrs[i], err = multiaddr.NewMultiaddr(addr)
if err != nil {
return nil
}
existing[addr] = true
}
var appendAnnAddrs []multiaddr.Multiaddr
for _, addr := range appendAnnouce {
if existing[addr] {
// skip AppendAnnounce that is on the Announce list already
continue
}
appendAddr, err := multiaddr.NewMultiaddr(addr)
if err != nil {
return nil
}
appendAnnAddrs = append(appendAnnAddrs, appendAddr)
}
// filters holds the mask-style noAnnounce entries; noAnnAddrs holds the
// exact-match ones, keyed by the multiaddr's binary form.
filters := multiaddr.NewFilters()
noAnnAddrs := map[string]bool{}
for _, addr := range noAnnounce {
f, err := mafilt.NewMask(addr)
if err == nil {
filters.AddFilter(*f, multiaddr.ActionDeny)
continue
}
// not a mask: fall back to treating the entry as a single address
maddr, err := multiaddr.NewMultiaddr(addr)
if err != nil {
return nil
}
noAnnAddrs[string(maddr.Bytes())] = true
}
return func(allAddrs []multiaddr.Multiaddr) []multiaddr.Multiaddr {
// base list: explicit announce addresses win over the host's own addresses
var addrs []multiaddr.Multiaddr
if len(annAddrs) > 0 {
addrs = annAddrs
} else {
addrs = allAddrs
}
// NOTE(review): when addrs aliases allAddrs and that slice has spare
// capacity, this append writes into the caller's backing array — confirm
// callers do not depend on it.
addrs = append(addrs, appendAnnAddrs...)
var out []multiaddr.Multiaddr
for _, maddr := range addrs {
// check for exact matches
ok := noAnnAddrs[string(maddr.Bytes())]
// check for /ipcidr matches
if !ok && !filters.AddrBlocked(maddr) {
out = append(out, maddr)
}
}
return out
}
}
package libp2p
import (
"errors"
"sync"
"github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/protocol"
"gitlab.com/nunet/device-management-service/models"
)
// StreamHandler is a function type that processes data from a stream.
// It has the same signature as network.StreamHandler, to which the registry
// converts it when installing the handler on the host.
type StreamHandler func(stream network.Stream)
// HandlerRegistry manages the registration of stream handlers for different protocols.
// It records which protocol IDs already have a handler and installs the
// handlers on the underlying host.
type HandlerRegistry struct {
// host is the libp2p host on which stream handlers are installed.
host host.Host
// handlers holds stream-level callbacks keyed by protocol ID.
handlers map[protocol.ID]StreamHandler
// bytesHandlers holds byte-level callbacks keyed by protocol ID.
bytesHandlers map[protocol.ID]func(data []byte)
// mu is locked by the registry's methods around access to the two maps.
mu sync.RWMutex
}
// NewHandlerRegistry returns a registry bound to host with both handler maps
// allocated and empty.
func NewHandlerRegistry(host host.Host) *HandlerRegistry {
	registry := new(HandlerRegistry)
	registry.host = host
	registry.handlers = map[protocol.ID]StreamHandler{}
	registry.bytesHandlers = map[protocol.ID]func(data []byte){}
	return registry
}
// RegisterHandlerWithStreamCallback records handler for the protocol derived
// from messageType and installs it on the host. It fails if a stream handler
// was already registered for that protocol.
func (r *HandlerRegistry) RegisterHandlerWithStreamCallback(messageType models.MessageType, handler StreamHandler) error {
	r.mu.Lock()
	defer r.mu.Unlock()

	id := protocol.ID(messageType)
	if _, taken := r.handlers[id]; taken {
		return errors.New("stream with this protocol is already registered")
	}

	r.handlers[id] = handler
	r.host.SetStreamHandler(id, network.StreamHandler(handler))
	return nil
}
// RegisterHandlerWithBytesCallback records the byte-level callback handler for
// the protocol derived from messageType and installs s as the stream handler
// on the host. It fails if a bytes callback was already registered for that
// protocol.
func (r *HandlerRegistry) RegisterHandlerWithBytesCallback(messageType models.MessageType, s StreamHandler, handler func(data []byte)) error {
	r.mu.Lock()
	defer r.mu.Unlock()

	id := protocol.ID(messageType)
	if _, taken := r.bytesHandlers[id]; taken {
		return errors.New("stream with this protocol is already registered")
	}

	r.bytesHandlers[id] = handler
	r.host.SetStreamHandler(id, network.StreamHandler(s))
	return nil
}
// SendMessageToLocalHandler delivers data to the bytes callback registered for
// messageType, if there is one. The callback runs in its own goroutine so the
// caller is never blocked; when no callback is registered the call is a no-op.
func (r *HandlerRegistry) SendMessageToLocalHandler(messageType models.MessageType, data []byte) {
	// The map is only read here, so a shared read lock is sufficient; it is
	// released before the callback is started so that slow or concurrent
	// deliveries never serialise on the registry.
	r.mu.RLock()
	h, ok := r.bytesHandlers[protocol.ID(messageType)]
	r.mu.RUnlock()
	if !ok {
		return
	}
	// we need this goroutine to avoid blocking the caller goroutine
	go h(data)
}
package libp2p
import (
"context"
"fmt"
"strings"
"time"
"github.com/libp2p/go-libp2p"
dht "github.com/libp2p/go-libp2p-kad-dht"
pubsub "github.com/libp2p/go-libp2p-pubsub"
"github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/pnet"
"github.com/libp2p/go-libp2p/core/protocol"
"github.com/libp2p/go-libp2p/core/routing"
"github.com/libp2p/go-libp2p/p2p/host/autorelay"
"github.com/libp2p/go-libp2p/p2p/host/peerstore/pstoremem"
"github.com/libp2p/go-libp2p/p2p/net/connmgr"
"github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay"
"github.com/libp2p/go-libp2p/p2p/security/noise"
libp2ptls "github.com/libp2p/go-libp2p/p2p/security/tls"
quic "github.com/libp2p/go-libp2p/p2p/transport/quic"
"github.com/libp2p/go-libp2p/p2p/transport/tcp"
ws "github.com/libp2p/go-libp2p/p2p/transport/websocket"
webtransport "github.com/libp2p/go-libp2p/p2p/transport/webtransport"
"github.com/multiformats/go-multiaddr"
"github.com/spf13/afero"
mafilt "github.com/whyrusleeping/multiaddr-filter"
"gitlab.com/nunet/device-management-service/models"
)
// NewHost returns a new libp2p host with dht and other related settings.
//
// It builds, in order: a connection manager, the deny filters derived from
// defaultServerFilters, an in-memory peerstore, the libp2p option list
// (transports, security, relay, autorelay, routing via a Kademlia DHT running
// in server mode) and finally a GossipSub instance on top of the host.
//
// When config.PrivateNetwork.WithSwarmKey is set the host joins a private
// network using the swarm key read through fs, and QUIC/WebTransport are left
// out; otherwise both transports are enabled. In server mode (config.Server)
// the default filters are applied to announced addresses and to connection
// gating; otherwise NAT port mapping is enabled instead.
//
// It returns the host, the DHT created by the routing option, the pubsub
// instance, or the first error encountered.
func NewHost(ctx context.Context, config *models.Libp2pConfig, fs afero.Fs) (host.Host, *dht.IpfsDHT, *pubsub.PubSub, error) {
	var idht *dht.IpfsDHT
	// 100 and 400 are the low and high watermarks passed to the connection manager.
	connmgr, err := connmgr.NewConnManager(
		100,
		400,
		connmgr.WithGracePeriod(time.Duration(config.GracePeriodMs)*time.Millisecond),
	)
	if err != nil {
		return nil, nil, nil, err
	}
	filter := multiaddr.NewFilters()
	for _, s := range defaultServerFilters {
		f, err := mafilt.NewMask(s)
		if err != nil {
			zlog.Sugar().Errorf("incorrectly formatted address filter in config: %s - %v", s, err)
			// f is not usable after a parse failure; skip the entry instead of
			// dereferencing it below.
			continue
		}
		filter.AddFilter(*f, multiaddr.ActionDeny)
	}
	ps, err := pstoremem.NewPeerstore()
	if err != nil {
		return nil, nil, nil, err
	}
	var libp2pOpts []libp2p.Option
	baseOpts := []dht.Option{
		dht.ProtocolPrefix(protocol.ID(config.DHTPrefix)),
		dht.NamespacedValidator(strings.ReplaceAll(config.CustomNamespace, "/", ""), dhtValidator{PS: ps}),
		dht.Mode(dht.ModeServer),
	}
	if config.PrivateNetwork.WithSwarmKey {
		psk, err := configureSwarmKey(fs)
		if err != nil {
			return nil, nil, nil, fmt.Errorf("failed to configure swarm key: %w", err)
		}
		libp2pOpts = append(libp2pOpts, libp2p.PrivateNetwork(psk))
		// guarantee that outer connection will be refused
		pnet.ForcePrivateNetwork = true
	} else {
		// enable quic (it does not work with pnet enabled)
		libp2pOpts = append(libp2pOpts, libp2p.Transport(quic.NewTransport))
		libp2pOpts = append(libp2pOpts, libp2p.Transport(webtransport.New))
		// for some reason, ForcePrivateNetwork was equal to true even without being set to true
		pnet.ForcePrivateNetwork = false
	}
	libp2pOpts = append(libp2pOpts, libp2p.ListenAddrStrings(config.ListenAddress...),
		libp2p.Identity(config.PrivateKey),
		// The routing option creates the DHT and stores it in idht so it can
		// be returned to the caller once the host is built.
		libp2p.Routing(func(h host.Host) (routing.PeerRouting, error) {
			idht, err = dht.New(ctx, h, baseOpts...)
			return idht, err
		}),
		libp2p.Peerstore(ps),
		libp2p.Security(libp2ptls.ID, libp2ptls.New),
		libp2p.Security(noise.ID, noise.New),
		// Do not use DefaulTransports as we can not enable Quic when pnet
		libp2p.Transport(tcp.NewTCPTransport),
		libp2p.Transport(ws.New),
		libp2p.EnableNATService(),
		libp2p.ConnectionManager(connmgr),
		libp2p.EnableRelay(),
		libp2p.EnableHolePunching(),
		libp2p.EnableRelayService(
			relay.WithResources(
				relay.Resources{
					MaxReservations:        256,
					MaxCircuits:            32,
					BufferSize:             4096,
					MaxReservationsPerPeer: 8,
					MaxReservationsPerIP:   16,
				},
			),
			relay.WithLimit(&relay.RelayLimit{
				Duration: 5 * time.Minute,
				Data:     1 << 21, // 2 MiB
			}),
		),
		libp2p.EnableAutoRelayWithPeerSource(
			// Peer source for autorelay: forwards up to num peers received on
			// the package-level newPeer channel, stopping early when ctx ends.
			func(ctx context.Context, num int) <-chan peer.AddrInfo {
				r := make(chan peer.AddrInfo)
				go func() {
					defer close(r)
					for i := 0; i < num; i++ {
						select {
						case p := <-newPeer:
							select {
							case r <- p:
							case <-ctx.Done():
								return
							}
						case <-ctx.Done():
							return
						}
					}
				}()
				return r
			},
			autorelay.WithBootDelay(time.Minute),
			autorelay.WithBackoff(30*time.Second),
			autorelay.WithMinCandidates(2),
			autorelay.WithMaxCandidates(3),
			autorelay.WithNumRelays(2),
		),
	)
	if config.Server {
		libp2pOpts = append(libp2pOpts, libp2p.AddrsFactory(makeAddrsFactory([]string{}, []string{}, defaultServerFilters)))
		libp2pOpts = append(libp2pOpts, libp2p.ConnectionGater((*filtersConnectionGater)(filter)))
	} else {
		libp2pOpts = append(libp2pOpts, libp2p.NATPortMap())
	}
	host, err := libp2p.New(libp2pOpts...)
	if err != nil {
		return nil, nil, nil, err
	}
	optsPS := []pubsub.Option{pubsub.WithMessageSigning(true), pubsub.WithMaxMessageSize(config.GossipMaxMessageSize)}
	gossip, err := pubsub.NewGossipSub(ctx, host, optsPS...)
	// gossip, err := pubsub.NewGossipSubWithRouter(ctx, host, pubsub.DefaultGossipSubRouter(host), optsPS...)
	if err != nil {
		return nil, nil, nil, err
	}
	return host, idht, gossip, nil
}
package libp2p
import (
"github.com/libp2p/go-libp2p/core/peer"
"github.com/uptrace/opentelemetry-go-extra/otelzap"
"gitlab.com/nunet/device-management-service/telemetry/logger"
)
const (
// Custom namespace for DHT protocol with version number
customNamespace = "/nunet-dht-1/"
)
// TODO: pass the logger to the constructor and remove from here
var (
// zlog is the package-wide logger; it is assigned in init.
zlog *otelzap.Logger
// newPeer is an unbuffered channel of peer addresses. NewHost reads from it
// inside the AutoRelay peer source.
// NOTE(review): no sender is visible in this part of the package — confirm
// where peers are fed into it.
newPeer = make(chan peer.AddrInfo)
)
// init creates the package logger under the name "network.libp2p".
func init() {
zlog = logger.OtelZapLogger("network.libp2p")
}
package libp2p
import (
"context"
"bufio"
"bytes"
"crypto/sha256"
"encoding/binary"
"errors"
"fmt"
"io"
"strings"
"sync"
"time"
"github.com/ipfs/go-cid"
kbucket "github.com/libp2p/go-libp2p-kbucket"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/p2p/protocol/ping"
"gitlab.com/nunet/device-management-service/models"
pubsub "github.com/libp2p/go-libp2p-pubsub"
"github.com/multiformats/go-multiaddr"
"github.com/multiformats/go-multihash"
"github.com/spf13/afero"
"google.golang.org/protobuf/proto"
dht "github.com/libp2p/go-libp2p-kad-dht"
libp2pdiscovery "github.com/libp2p/go-libp2p/core/discovery"
"github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peerstore"
"github.com/libp2p/go-libp2p/core/protocol"
drouting "github.com/libp2p/go-libp2p/p2p/discovery/routing"
bt "gitlab.com/nunet/device-management-service/internal/background_tasks"
commonproto "gitlab.com/nunet/device-management-service/proto/generated/v1/common"
)
const (
// MB is the number of bytes in one mebibyte.
MB = 1024 * 1024
// MaxMessageLengthMB is the limit, in MB, that sendMessage enforces on an
// outgoing message including its 8-byte length prefix.
MaxMessageLengthMB = 10
)
// Libp2p contains the configuration for a Libp2p instance.
// New fills in the configuration and empty collections; Init populates the
// host, DHT, peerstore, discovery, pubsub and handler registry; Start sets
// discoveryTask.
//
// TODO-suggestion: maybe we should call it something else like Libp2pPeer,
// Libp2pHost or just Peer (callers would use libp2p.Peer...)
type Libp2p struct {
// Host is the underlying libp2p host (set by Init).
Host host.Host
// DHT is the Kademlia DHT created together with the host (set by Init).
DHT *dht.IpfsDHT
// PS is the host's peerstore (set by Init from Host.Peerstore()).
PS peerstore.Peerstore
// pubsub is the pubsub instance returned by NewHost (set by Init).
pubsub *pubsub.PubSub
// pubsubTopics caches joined topics by topic name.
pubsubTopics map[string]*pubsub.Topic
// topicSubscription holds the latest subscription per topic name.
topicSubscription map[string]*pubsub.Subscription
// topicMux is locked around access to pubsubTopics and topicSubscription.
topicMux sync.RWMutex
// a list of peers discovered by discovery
discoveredPeers []peer.AddrInfo
discovery libp2pdiscovery.Discovery
// services
pingService *ping.PingService
// tasks
// discoveryTask is the periodic peer-discovery task registered by Start and
// removed by Stop.
discoveryTask *bt.Task
handlerRegistry *HandlerRegistry
config *models.Libp2pConfig
// dependencies (db, filesystem...)
fs afero.Fs
}
// New creates a libp2p instance from config and fs without touching the
// network. It rejects a nil config and a config without a scheduler; the
// returned instance has empty, non-nil peer, topic and subscription
// collections and still needs Init and Start.
//
// TODO-Suggestion: move models.Libp2pConfig to here for better readability.
// Unless there is a reason to keep within models.
func New(config *models.Libp2pConfig, fs afero.Fs) (*Libp2p, error) {
	switch {
	case config == nil:
		return nil, errors.New("config is nil")
	case config.Scheduler == nil:
		return nil, errors.New("scheduler is nil")
	}

	instance := &Libp2p{config: config, fs: fs}
	instance.discoveredPeers = []peer.AddrInfo{}
	instance.pubsubTopics = map[string]*pubsub.Topic{}
	instance.topicSubscription = map[string]*pubsub.Subscription{}
	return instance, nil
}
// Init builds the libp2p host through NewHost and wires the host, DHT,
// peerstore, routing discovery, pubsub and a fresh handler registry into l.
// A NewHost failure is logged and returned, leaving l unchanged.
func (l *Libp2p) Init(context context.Context) error {
	h, kad, gossip, err := NewHost(context, l.config, l.fs)
	if err != nil {
		zlog.Sugar().Error(err)
		return err
	}

	l.Host, l.DHT = h, kad
	l.PS = h.Peerstore()
	l.discovery = drouting.NewRoutingDiscovery(kad)
	l.pubsub = gossip
	l.handlerRegistry = NewHandlerRegistry(h)
	return nil
}
// Start registers the stream handlers, bootstraps the network, advertises at
// the rendezvous point, runs a first peer discovery and finally schedules a
// periodic discovery task.
//
// Only a bootstrap failure aborts and is returned. Advertising and discovery
// failures are logged and startup continues.
func (l *Libp2p) Start(context context.Context) error {
	// set stream handlers
	l.registerStreamHandlers()

	// bootstrap should return error if it had an error
	if err := l.Bootstrap(context, l.config.BootstrapPeers); err != nil {
		zlog.Sugar().Errorf("failed to start network: %v", err)
		return err
	}

	// TODO: the error might be misleading as a peer can normally work well if an error
	// is returned here (e.g.: the error is yielded in tests even though all tests pass).
	if err := l.advertiseForRendezvousDiscovery(context); err != nil {
		zlog.Sugar().Errorf("failed to start network with randevouz discovery: %v", err)
	}

	if err := l.DiscoverDialPeers(context); err != nil {
		zlog.Sugar().Errorf("failed to discover peers: %v", err)
	}

	// The periodic task reuses the context handed to Start for every run.
	rediscover := func(args interface{}) error {
		return l.DiscoverDialPeers(context)
	}
	everyQuarterHour := &bt.PeriodicTrigger{Interval: 15 * time.Minute}

	l.discoveryTask = l.config.Scheduler.AddTask(&bt.Task{
		Name:        "Peer Discovery",
		Description: "Periodic task to discover new peers every 15 minutes",
		Function:    rediscover,
		Triggers:    []bt.Trigger{everyQuarterHour},
	})
	l.config.Scheduler.Start()
	return nil
}
// RegisterStreamMessageHandler installs handler as the stream handler for the
// protocol named by messageType. An empty message type is rejected, and a
// registry failure is wrapped with the message type.
func (l *Libp2p) RegisterStreamMessageHandler(messageType models.MessageType, handler StreamHandler) error {
	if messageType == "" {
		return errors.New("message type is empty")
	}

	err := l.handlerRegistry.RegisterHandlerWithStreamCallback(messageType, handler)
	if err != nil {
		return fmt.Errorf("failed to register handler %s: %w", messageType, err)
	}
	return nil
}
// RegisterBytesMessageHandler registers handler as the bytes callback for the
// protocol named by messageType, with handleReadBytesFromStream as the stream
// handler that feeds it. An empty message type is rejected, and a registry
// failure is wrapped with the message type.
func (l *Libp2p) RegisterBytesMessageHandler(messageType models.MessageType, handler func(data []byte)) error {
	if messageType == "" {
		return errors.New("message type is empty")
	}

	err := l.handlerRegistry.RegisterHandlerWithBytesCallback(messageType, l.handleReadBytesFromStream, handler)
	if err != nil {
		return fmt.Errorf("failed to register handler %s: %w", messageType, err)
	}
	return nil
}
// HandleMessage is the string-typed convenience form of
// RegisterBytesMessageHandler: it converts messageType to models.MessageType
// and registers handler for it.
func (l *Libp2p) HandleMessage(messageType string, handler func(data []byte)) error {
	typed := models.MessageType(messageType)
	return l.RegisterBytesMessageHandler(typed, handler)
}
// handleReadBytesFromStream is the stream handler behind every bytes callback.
//
// It looks up the callback registered for the stream's protocol, reads one
// length-prefixed message (8-byte little-endian length followed by that many
// payload bytes) and passes the payload to the callback. The stream is always
// closed on return. Streams for unknown protocols, truncated messages and
// messages announcing an oversized payload are dropped silently.
func (l *Libp2p) handleReadBytesFromStream(s network.Stream) {
	// maxIncomingMessageBytes mirrors MaxMessageLengthMB*MB (10 MiB), the limit
	// sendMessage enforces on prefix plus payload, so no conforming sender can
	// exceed it. Keep the two in sync.
	const maxIncomingMessageBytes = 10 * 1024 * 1024

	defer s.Close()

	// The handler map is written under the registry lock during registration,
	// so it must be read under the lock as well.
	l.handlerRegistry.mu.RLock()
	callback, ok := l.handlerRegistry.bytesHandlers[s.Protocol()]
	l.handlerRegistry.mu.RUnlock()
	if !ok {
		return
	}

	c := bufio.NewReader(s)

	// read the first 8 bytes to determine the size of the message; ReadFull
	// is required because a plain Read may return fewer bytes than requested.
	var msgLengthBuffer [8]byte
	if _, err := io.ReadFull(c, msgLengthBuffer[:]); err != nil {
		return
	}

	// The length comes from the remote peer: bound it before allocating so a
	// bogus prefix can neither panic make nor exhaust memory.
	msgLength := binary.LittleEndian.Uint64(msgLengthBuffer[:])
	if msgLength > maxIncomingMessageBytes {
		return
	}

	// read the full message
	buf := make([]byte, msgLength)
	if _, err := io.ReadFull(c, buf); err != nil {
		return
	}
	callback(buf)
}
// SendMessage delivers msg to every address in addrs concurrently and waits
// for all deliveries to finish. It returns nil when every delivery succeeded;
// otherwise the individual failures are combined into one error whose text
// joins them with "; ".
func (l *Libp2p) SendMessage(ctx context.Context, addrs []string, msg models.MessageEnvelope) error {
	// Buffered for one failure per address so no sender ever blocks.
	failures := make(chan error, len(addrs))

	var inflight sync.WaitGroup
	inflight.Add(len(addrs))
	for _, target := range addrs {
		go func(addr string) {
			defer inflight.Done()
			if err := l.sendMessage(ctx, addr, msg); err != nil {
				failures <- err
			}
		}(target)
	}
	inflight.Wait()
	close(failures)

	var combined error
	for failure := range failures {
		if combined != nil {
			combined = fmt.Errorf("%v; %v", combined, failure)
			continue
		}
		combined = failure
	}
	return combined
}
// OpenStream connects to the peer identified by the p2p multiaddress addr and
// opens a stream for the protocol named by messageType. The stream is returned
// for the caller to use. Each failing step (parsing the address, extracting
// the peer info, connecting, opening the stream) yields a wrapped error.
func (l *Libp2p) OpenStream(ctx context.Context, addr string, messageType models.MessageType) (network.Stream, error) {
	target, err := multiaddr.NewMultiaddr(addr)
	if err != nil {
		return nil, fmt.Errorf("invalid multiaddress: %w", err)
	}

	info, err := peer.AddrInfoFromP2pAddr(target)
	if err != nil {
		return nil, fmt.Errorf("could not resolve peer info: %w", err)
	}

	if err = l.Host.Connect(ctx, *info); err != nil {
		return nil, fmt.Errorf("failed to connect to peer: %w", err)
	}

	opened, err := l.Host.NewStream(ctx, info.ID, protocol.ID(messageType))
	if err != nil {
		return nil, fmt.Errorf("failed to open stream: %w", err)
	}
	return opened, nil
}
// GetMultiaddr returns this host's addresses converted to p2p multiaddrs, as
// produced by peer.AddrInfoToP2pAddrs from the host ID and its addresses.
func (l *Libp2p) GetMultiaddr() ([]multiaddr.Multiaddr, error) {
	self := &peer.AddrInfo{
		ID:    l.Host.ID(),
		Addrs: l.Host.Addrs(),
	}
	return peer.AddrInfoToP2pAddrs(self)
}
// Stop performs a cleanup of any resources used in this package.
//
// It removes the periodic discovery task from the scheduler and closes the
// DHT and the host. Components that were never created are skipped, so Stop
// is safe to call after a failed or missing Init/Start (for example when
// Start returned early because bootstrapping failed). Close errors are
// collected and returned as a single error joined with "; ".
func (l *Libp2p) Stop() error {
	var errorMessages []string

	// discoveryTask is only set at the end of a successful Start.
	if l.discoveryTask != nil {
		l.config.Scheduler.RemoveTask(l.discoveryTask.ID)
	}

	// DHT and Host are only set by a successful Init.
	if l.DHT != nil {
		if err := l.DHT.Close(); err != nil {
			errorMessages = append(errorMessages, err.Error())
		}
	}
	if l.Host != nil {
		if err := l.Host.Close(); err != nil {
			errorMessages = append(errorMessages, err.Error())
		}
	}

	if len(errorMessages) > 0 {
		return errors.New(strings.Join(errorMessages, "; "))
	}
	return nil
}
// Stat reports the host's peer ID and its addresses, the latter rendered as a
// single string joined with ", ".
func (l *Libp2p) Stat() models.NetworkStats {
	hostAddrs := l.Host.Addrs()
	rendered := make([]string, len(hostAddrs))
	for i, a := range hostAddrs {
		rendered[i] = a.String()
	}

	return models.NetworkStats{
		ID:         l.Host.ID().String(),
		ListenAddr: strings.Join(rendered, ", "),
	}
}
// Ping pings the peer whose encoded peer ID is peerIDAddress, waiting at most
// timeout for the first reply.
//
// Pinging the local host is refused. An undecodable peer ID returns an empty
// PingResult together with the decode error. A failed ping is logged and its
// error is reported both in the result and as the returned error; the same
// applies when the timeout (or ctx) expires first.
//
// TODO (Return error once): something that was confusing me when using this method is that the error is
// returned twice if any. Once as a field of PingResult and one as a return value.
func (l *Libp2p) Ping(ctx context.Context, peerIDAddress string, timeout time.Duration) (models.PingResult, error) {
	// avoid dial to self attempt
	if l.Host.ID().String() == peerIDAddress {
		selfErr := errors.New("can't ping self")
		return models.PingResult{Success: false, Error: selfErr}, selfErr
	}

	pingCtx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()

	target, decodeErr := peer.Decode(peerIDAddress)
	if decodeErr != nil {
		return models.PingResult{}, decodeErr
	}

	replies := ping.Ping(pingCtx, l.Host, target)
	select {
	case <-pingCtx.Done():
		ctxErr := pingCtx.Err()
		return models.PingResult{Error: ctxErr}, ctxErr
	case reply := <-replies:
		if reply.Error == nil {
			return models.PingResult{RTT: reply.RTT, Success: true}, nil
		}
		zlog.Sugar().Errorf("failed to ping peer %s: %v", peerIDAddress, reply.Error)
		failed := models.PingResult{
			Success: false,
			RTT:     reply.RTT,
			Error:   reply.Error,
		}
		return failed, reply.Error
	}
}
// ResolveAddress returns the p2p multiaddresses, as strings, of the peer with
// the encoded peer ID id.
//
// The local host is answered from its own addresses. Any other peer is looked
// up in the DHT with a three-second timeout. An undecodable ID, a failed
// lookup or a failed address conversion yields a wrapped error.
func (l *Libp2p) ResolveAddress(ctx context.Context, id string) ([]string, error) {
	pid, err := peer.Decode(id)
	if err != nil {
		return nil, fmt.Errorf("failed to resolve invalid peer: %w", err)
	}

	// render turns multiaddrs into their string form, preserving order.
	render := func(addrs []multiaddr.Multiaddr) []string {
		out := make([]string, len(addrs))
		for i := range addrs {
			out[i] = addrs[i].String()
		}
		return out
	}

	// resolve ourself
	if id == l.Host.ID().String() {
		own, err := l.GetMultiaddr()
		if err != nil {
			return nil, fmt.Errorf("failed to resolve self: %w", err)
		}
		return render(own), nil
	}

	lookupCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
	defer cancel()

	found, err := l.DHT.FindPeer(lookupCtx, pid)
	if err != nil {
		return nil, fmt.Errorf("failed to resolve address %s: %w", id, err)
	}

	remote := peer.AddrInfo{ID: found.ID, Addrs: found.Addrs}
	p2pAddrs, err := peer.AddrInfoToP2pAddrs(&remote)
	if err != nil {
		return nil, fmt.Errorf("failed to convert to p2p address: %w", err)
	}
	return render(p2pAddrs), nil
}
// Query returns the advertisements published in the network under key.
//
// It derives a CID from key, asks the DHT for the providers of that CID and
// then fetches each provider's advertisement record from the DHT. Providers
// whose record cannot be fetched are skipped, whereas a record that cannot be
// unmarshalled aborts the query with an error. The result is nil when no
// advertisement was collected.
func (l *Libp2p) Query(ctx context.Context, key string) ([]*commonproto.Advertisement, error) {
	if key == "" {
		return nil, errors.New("advertisement key is empty")
	}

	keyCID, err := createCIDFromKey(key)
	if err != nil {
		return nil, fmt.Errorf("failed to create cid for key %s: %w", key, err)
	}

	providers, err := l.DHT.FindProviders(ctx, keyCID)
	if err != nil {
		return nil, fmt.Errorf("failed to find providers for key %s: %w", key, err)
	}

	var advertisements []*commonproto.Advertisement
	// TODO: use go routines to get the values in parallel.
	for _, provider := range providers {
		record, getErr := l.DHT.GetValue(ctx, l.getCustomNamespace(key, provider.ID.String()))
		if getErr != nil {
			continue
		}

		ad := new(commonproto.Advertisement)
		if decodeErr := proto.Unmarshal(record, ad); decodeErr != nil {
			return nil, fmt.Errorf("failed to unmarshal advertisement payload: %w", decodeErr)
		}
		advertisements = append(advertisements, ad)
	}
	return advertisements, nil
}
// Advertise given data and a key pushes the data to the dht.
//
// The data is wrapped in a commonproto.Advertisement carrying this host's peer
// ID, the current Unix timestamp and the host's public key, signed, and stored
// in the DHT under the namespaced key for this peer. Finally the host
// announces itself as a provider of the CID derived from key so that Query can
// find the record. An empty key is rejected; every failing step returns a
// wrapped error.
func (l *Libp2p) Advertise(ctx context.Context, key string, data []byte) error {
if key == "" {
return errors.New("advertisement key is empty")
}
pubKeyBytes, err := l.getPublicKey()
if err != nil {
return fmt.Errorf("failed to get public key: %w", err)
}
envelope := &commonproto.Advertisement{
PeerId: l.Host.ID().String(),
Timestamp: time.Now().Unix(),
Data: data,
PublicKey: pubKeyBytes,
}
// Payload that gets signed. dhtValidator.Validate rebuilds exactly this byte
// sequence to verify the signature, so both places must change together.
// Note that byte(envelope.Timestamp) keeps only the least-significant byte of
// the timestamp.
concatenatedBytes := bytes.Join([][]byte{
[]byte(envelope.PeerId),
{byte(envelope.Timestamp)},
envelope.Data,
pubKeyBytes,
}, nil)
sig, err := l.sign(concatenatedBytes)
if err != nil {
return fmt.Errorf("failed to sign advertisement envelope content: %w", err)
}
envelope.Signature = sig
envelopeBytes, err := proto.Marshal(envelope)
if err != nil {
return fmt.Errorf("failed to marshal advertise envelope: %w", err)
}
customCID, err := createCIDFromKey(key)
if err != nil {
return fmt.Errorf("failed to create cid for key %s: %w", key, err)
}
// store the signed envelope first, then announce ourselves as its provider
err = l.DHT.PutValue(ctx, l.getCustomNamespace(key, l.DHT.PeerID().String()), envelopeBytes)
if err != nil {
return fmt.Errorf("failed to put key %s into the dht: %w", key, err)
}
err = l.DHT.Provide(ctx, customCID, true)
if err != nil {
return fmt.Errorf("failed to provide key %s into the dht: %w", key, err)
}
return nil
}
// Unadvertise withdraws this peer's advertisement for key by writing a nil
// value under the peer's namespaced DHT key (the DHT validator treats an empty
// value as a deletion). A failing write is returned as a wrapped error.
func (l *Libp2p) Unadvertise(ctx context.Context, key string) error {
	recordKey := l.getCustomNamespace(key, l.DHT.PeerID().String())
	if err := l.DHT.PutValue(ctx, recordKey, nil); err != nil {
		return fmt.Errorf("failed to remove key %s from the DHT: %w", key, err)
	}
	return nil
}
// Publish sends data to the pubsub topic named topic, joining the topic first
// if this instance has not joined it yet. Only one topic handler exists per
// topic; it is obtained through getOrJoinTopicHandler. Join and publish
// failures are returned as wrapped errors.
func (l *Libp2p) Publish(ctx context.Context, topic string, data []byte) error {
	topicHandler, joinErr := l.getOrJoinTopicHandler(topic)
	if joinErr != nil {
		return fmt.Errorf("failed to publish: %w", joinErr)
	}

	if pubErr := topicHandler.Publish(ctx, data); pubErr != nil {
		return fmt.Errorf("failed to publish to topic %s: %w", topic, pubErr)
	}
	return nil
}
// Subscribe subscribes to a topic and sends the messages to the handler.
//
// The topic is joined if necessary, the subscription is remembered under the
// topic name, and a background goroutine passes the Data of every received
// message to handler. That goroutine ends as soon as receiving fails, which
// happens when ctx is cancelled or the subscription is closed. Join and
// subscribe failures are returned as wrapped errors.
func (l *Libp2p) Subscribe(ctx context.Context, topic string, handler func(data []byte)) error {
	topicHandler, err := l.getOrJoinTopicHandler(topic)
	if err != nil {
		return fmt.Errorf("failed to subscribe to topic: %w", err)
	}
	sub, err := topicHandler.Subscribe()
	if err != nil {
		return fmt.Errorf("failed to subscribe to topic %s: %w", topic, err)
	}

	l.topicMux.Lock()
	l.topicSubscription[topic] = sub
	l.topicMux.Unlock()

	go func() {
		for {
			msg, err := sub.Next(ctx)
			if err != nil {
				// Next fails once ctx is done or the subscription has been
				// closed. Neither condition recovers, so retrying would only
				// spin at full speed and keep this goroutine alive forever.
				return
			}
			handler(msg.Data)
		}
	}()
	return nil
}
// sendMessage delivers msg to the peer identified by the p2p multiaddr addr.
//
// When addr refers to this host, the payload is handed directly to the
// locally registered handler for msg.Type and no stream is opened. Otherwise
// the payload is written to a new stream that uses msg.Type as protocol ID,
// framed as an 8-byte little-endian length followed by the payload.
//
// An error is returned when addr cannot be parsed, the framed message exceeds
// MaxMessageLengthMB, or connecting, opening the stream or writing fails.
func (l *Libp2p) sendMessage(ctx context.Context, addr string, msg models.MessageEnvelope) error {
	peerAddr, err := multiaddr.NewMultiaddr(addr)
	if err != nil {
		return fmt.Errorf("invalid multiaddr %s: %w", addr, err)
	}

	peerInfo, err := peer.AddrInfoFromP2pAddr(peerAddr)
	if err != nil {
		return fmt.Errorf("failed to get peer info %s: %w", addr, err)
	}

	// we are delivering a message to ourself
	// we should use the handler to send the message to the handler directly which has been previously registered.
	if peerInfo.ID.String() == l.Host.ID().String() {
		l.handlerRegistry.SendMessageToLocalHandler(msg.Type, msg.Data)
		return nil
	}

	// Reject oversized messages before spending a connection and a stream on them.
	requestBufferSize := 8 + len(msg.Data)
	if requestBufferSize > MaxMessageLengthMB*MB {
		return fmt.Errorf("message size %d is greater than limit %d bytes", requestBufferSize, MaxMessageLengthMB*MB)
	}

	if err := l.Host.Connect(ctx, *peerInfo); err != nil {
		return fmt.Errorf("failed to connect to peer %v: %w", peerInfo.ID, err)
	}

	stream, err := l.Host.NewStream(ctx, peerInfo.ID, protocol.ID(msg.Type))
	if err != nil {
		return fmt.Errorf("failed to open stream to peer %v: %w", peerInfo.ID, err)
	}
	defer stream.Close()

	// Frame layout: [8-byte little-endian payload length][payload].
	requestPayloadWithLength := make([]byte, requestBufferSize)
	binary.LittleEndian.PutUint64(requestPayloadWithLength, uint64(len(msg.Data)))
	copy(requestPayloadWithLength[8:], msg.Data)

	if _, err = stream.Write(requestPayloadWithLength); err != nil {
		return fmt.Errorf("failed to send message to peer %v: %w", peerInfo.ID, err)
	}

	return nil
}
// getOrJoinTopicHandler returns the cached handler for topic, joining the
// topic and caching the new handler when none exists yet. Both publishing and
// subscribing need a handler, which is why the lookup lives here.
// The topic maps are guarded by l.topicMux for the duration of the call.
func (l *Libp2p) getOrJoinTopicHandler(topic string) (*pubsub.Topic, error) {
	l.topicMux.Lock()
	defer l.topicMux.Unlock()

	if existing, found := l.pubsubTopics[topic]; found {
		return existing, nil
	}

	joined, err := l.pubsub.Join(topic)
	if err != nil {
		return nil, fmt.Errorf("failed to join topic %s: %w", topic, err)
	}
	l.pubsubTopics[topic] = joined
	return joined, nil
}
// Unsubscribe cancels the subscription to a topic (if one is recorded),
// closes the topic handler and forgets the topic.
// It fails when the topic was never joined or when closing the handler fails;
// in the latter case the topic handler stays registered.
func (l *Libp2p) Unsubscribe(topic string) error {
	l.topicMux.Lock()
	defer l.topicMux.Unlock()

	handle, joined := l.pubsubTopics[topic]
	if !joined {
		return fmt.Errorf("not subscribed to topic: %s", topic)
	}

	// drop the active subscription first, then the handler itself
	if active, subscribed := l.topicSubscription[topic]; subscribed {
		active.Cancel()
		delete(l.topicSubscription, topic)
	}

	closeErr := handle.Close()
	if closeErr != nil {
		return fmt.Errorf("failed to close topic handler: %w", closeErr)
	}

	delete(l.pubsubTopics, topic)
	return nil
}
// VisiblePeers returns the discoveredPeers slice held by l.
// The slice is returned as-is, not copied.
func (l *Libp2p) VisiblePeers() []peer.AddrInfo {
	visible := l.discoveredPeers
	return visible
}
// KnownPeers lists every peer ID present in the host's peerstore.
// Only the ID field of each AddrInfo is filled in; the returned error is
// always nil.
func (l *Libp2p) KnownPeers() ([]peer.AddrInfo, error) {
	ids := l.Host.Peerstore().Peers()
	infos := make([]peer.AddrInfo, len(ids))
	for i := range ids {
		infos[i].ID = ids[i]
	}
	return infos, nil
}
// DumpDHTRoutingTable returns the peer infos currently held in the DHT
// routing table. The returned error is always nil.
func (l *Libp2p) DumpDHTRoutingTable() ([]kbucket.PeerInfo, error) {
	return l.DHT.RoutingTable().GetPeerInfos(), nil
}
// registerStreamHandlers creates the ping service for the host, stores it on
// l and registers its handler for the "/ipfs/ping/1.0.0" protocol.
func (l *Libp2p) registerStreamHandlers() {
	l.pingService = ping.NewPingService(l.Host)

	pingProtocol := protocol.ID("/ipfs/ping/1.0.0")
	l.Host.SetStreamHandler(pingProtocol, l.pingService.PingHandler)
}
// sign signs data with the host's private key taken from the peerstore.
// It fails when the peerstore holds no private key for the host or when the
// signing operation itself fails.
func (l *Libp2p) sign(data []byte) ([]byte, error) {
	hostKey := l.Host.Peerstore().PrivKey(l.Host.ID())
	if hostKey == nil {
		return nil, errors.New("private key not found for the host")
	}

	sig, err := hostKey.Sign(data)
	if err == nil {
		return sig, nil
	}
	return nil, fmt.Errorf("failed to sign data: %w", err)
}
// getPublicKey returns the raw bytes of the public key derived from the
// host's private key. It fails when the peerstore holds no private key for
// the host.
func (l *Libp2p) getPublicKey() ([]byte, error) {
	hostKey := l.Host.Peerstore().PrivKey(l.Host.ID())
	if hostKey == nil {
		return nil, errors.New("private key not found for the host")
	}
	return hostKey.GetPublic().Raw()
}
// getCustomNamespace builds the DHT key "<namespace>-<key>-<peerID>", where
// the namespace comes from the configured CustomNamespace.
func (l *Libp2p) getCustomNamespace(key, peerID string) string {
	const layout = "%s-%s-%s"
	return fmt.Sprintf(layout, l.config.CustomNamespace, key, peerID)
}
// createCIDFromKey derives a CIDv1 (raw codec) from the SHA-256 digest of key.
// It returns an empty CID together with the error when the digest cannot be
// encoded as a multihash.
func createCIDFromKey(key string) (cid.Cid, error) {
	digest := sha256.Sum256([]byte(key))

	encoded, err := multihash.Encode(digest[:], multihash.SHA2_256)
	if err == nil {
		return cid.NewCidV1(cid.Raw, encoded), nil
	}
	return cid.Cid{}, err
}
// CleanupPeer is a stub: it only logs a warning and reports success.
func CleanupPeer(id peer.ID) error {
	const stubNotice = "CleanupPeer: Stub"
	zlog.Warn(stubNotice)
	return nil
}
// PingPeer is a stub: it only logs a warning and always returns false with a
// nil result.
func PingPeer(ctx context.Context, target peer.ID) (bool, *ping.Result) {
	const stubNotice = "PingPeer: Stub"
	zlog.Warn(stubNotice)
	return false, nil
}
// DumpKademliaDHT is a stub: it only logs a warning and returns no data and
// no error.
func DumpKademliaDHT(ctx context.Context) ([]models.PeerData, error) {
	const stubNotice = "DumpKademliaDHT: Stub"
	zlog.Warn(stubNotice)
	return nil, nil
}
// OldPingPeer is a stub: it only logs a warning and always returns false with
// a nil result.
func OldPingPeer(ctx context.Context, target peer.ID) (bool, *models.PingResult) {
	const stubNotice = "OldPingPeer: Stub"
	zlog.Warn(stubNotice)
	return false, nil
}
package libp2p
import (
"bytes"
"encoding/base64"
"fmt"
"os"
"path/filepath"
"github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/pnet"
"github.com/spf13/afero"
)
/*
** Swarm key **
By default, the swarm key shall be stored in a file named `swarm.key`
using the following path-based codec:
`/key/swarm/psk/1.0.0/<base_encoding>/<256_bits_key>`
`<base_encoding>` is either bin, base16 or base64.
*/
// getBasePath returns the `.nunet` directory inside the current user's home
// directory. The fs argument is currently not used.
//
// TODO-pnet-1: we shouldn't handle configuration paths here, a general configuration path
// should be provided by /internal/config.go
func getBasePath(fs afero.Fs) (string, error) {
	home, err := os.UserHomeDir()
	if err == nil {
		return filepath.Join(home, ".nunet"), nil
	}
	return "", fmt.Errorf("error getting home directory: %w", err)
}
// configureSwarmKey tries to read the swarm key from the `<config_path>/swarm.key`
// file. When that fails for any reason, a new swarm key is generated instead.
//
// TODO-ask: should we continue to generate a new swarm key if one is not found?
// Or we should enforce the user to use some cmd/API rpc to generate a new one?
func configureSwarmKey(fs afero.Fs) (pnet.PSK, error) {
	if existing, err := getSwarmKey(fs); err == nil {
		return existing, nil
	}

	generated, err := generateSwarmKey(fs)
	if err != nil {
		return nil, fmt.Errorf("failed to generate new swarm key: %w", err)
	}
	return generated, nil
}
// getSwarmKey reads the swarm key from `<base_path>/swarm.key` on fs and
// decodes it as a V1 pre-shared key.
//
// It returns an error when the base path cannot be resolved, the file cannot
// be read, or its content is not a valid V1 PSK.
func getSwarmKey(fs afero.Fs) (pnet.PSK, error) {
	baseDir, err := getBasePath(fs)
	if err != nil {
		// Previously this error was silently overwritten and the key was
		// looked up relative to an empty base path.
		return nil, fmt.Errorf("failed to resolve swarm key location: %w", err)
	}

	swarmkey, err := afero.ReadFile(fs, filepath.Join(baseDir, "swarm.key"))
	if err != nil {
		return nil, fmt.Errorf("failed to read swarm key file: %w", err)
	}

	psk, err := pnet.DecodeV1PSK(bytes.NewReader(swarmkey))
	if err != nil {
		return nil, fmt.Errorf("failed to configure private network: %w", err)
	}

	// TODO-ask: should we return psk fingerprint?
	return psk, nil
}
// generateSwarmKey generates a new swarm key, storing it within
// `<nunet_config_dir>/swarm.key` (file mode 0600), and returns the decoded PSK.
//
// The key material is the base64 encoding of a freshly generated, marshalled
// Secp256k1 private key, written using the V1 PSK text codec. The config
// directory is created when it does not exist yet.
//
// NOTE(review): the marshalled private key is presumably longer than the PSK
// the decoder reads, so part of the PSK may be a fixed serialization prefix —
// confirm, and consider generating the PSK from random bytes directly.
func generateSwarmKey(fs afero.Fs) (pnet.PSK, error) {
	priv, _, err := crypto.GenerateKeyPair(crypto.Secp256k1, 256)
	if err != nil {
		return nil, fmt.Errorf("failed to generate key pair for swarm key: %w", err)
	}

	privBytes, err := crypto.MarshalPrivateKey(priv)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal private key for swarm key: %w", err)
	}

	encodedKey := base64.StdEncoding.EncodeToString(privBytes)
	swarmKeyWithCodec := fmt.Sprintf("/key/swarm/psk/1.0.0/\n/base64/\n%s\n", encodedKey)

	// TODO-pnet-1
	nunetDir, err := getBasePath(fs)
	if err != nil {
		return nil, err
	}

	// Writing the key fails when the config directory is missing (e.g. on a
	// first run), so make sure it exists. The directory holds a secret,
	// hence the owner-only permissions.
	if err := fs.MkdirAll(nunetDir, 0700); err != nil {
		return nil, fmt.Errorf("error creating swarm key directory: %w", err)
	}

	swarmKeyPath := filepath.Join(nunetDir, "swarm.key")
	if err := afero.WriteFile(fs, swarmKeyPath, []byte(swarmKeyWithCodec), 0600); err != nil {
		return nil, fmt.Errorf("error writing swarm key to file: %w", err)
	}

	psk, err := pnet.DecodeV1PSK(bytes.NewReader([]byte(swarmKeyWithCodec)))
	if err != nil {
		return nil, fmt.Errorf("failed to decode generated swarm key: %w", err)
	}

	zlog.Sugar().Infof("A new Swarm key was generated and written to %s\n"+
		"IMPORTANT: If you'd like to create the swarm key using a cryptography algorithm "+
		"of your choice, just modify the swarm.key file with your own key.\n"+
		"The content of `swarm.key` should look like: `/key/swarm/psk/1.0.0/<base_encoding>/<your_key>`\n"+
		"where `<base_encoding>` is either `bin`, `base16`, or `base64`.\n",
		swarmKeyPath,
	)

	return psk, nil
}
package network
import (
"context"
"errors"
"fmt"
"time"
"github.com/spf13/afero"
commonproto "gitlab.com/nunet/device-management-service/proto/generated/v1/common"
"gitlab.com/nunet/device-management-service/models"
"gitlab.com/nunet/device-management-service/network/libp2p"
)
// Messenger defines the interface for sending messages.
type Messenger interface {
// SendMessage sends msg to the given list of addresses.
// NOTE(review): whether the message goes to every address or only to the
// first reachable one is decided by the implementation — confirm there.
SendMessage(ctx context.Context, addrs []string, msg models.MessageEnvelope) error
}
// Network is the abstraction over a concrete networking backend. It covers
// lifecycle management, direct messaging, advertisements and
// publish/subscribe.
type Network interface {
// Messenger embedded interface
Messenger
// Init initializes the network
Init(context.Context) error
// Start starts the network
Start(context context.Context) error
// Stat returns the network information
Stat() models.NetworkStats
// Ping pings the given address and returns the PingResult
Ping(ctx context.Context, address string, timeout time.Duration) (models.PingResult, error)
// HandleMessage is responsible for registering a message type and its handler.
HandleMessage(messageType string, handler func(data []byte)) error
// ResolveAddress returns the addresses of the peer with the given id.
// In libp2p, id represents the peerID and the response is the addrinfo
ResolveAddress(ctx context.Context, id string) ([]string, error)
// Advertise advertises the given data under the given key,
// such as advertising device capabilities on the DHT
Advertise(ctx context.Context, key string, data []byte) error
// Unadvertise stops advertising the data corresponding to the given key
Unadvertise(ctx context.Context, key string) error
// Query returns the network advertisements found for the given key
Query(ctx context.Context, key string) ([]*commonproto.Advertisement, error)
// Publish publishes the given data to the given topic if the network
// type allows publish/subscribe functionality such as gossipsub or nats
Publish(ctx context.Context, topic string, data []byte) error
// Subscribe subscribes to the given topic and calls the handler function
// if the network type allows it, similar to Publish()
Subscribe(ctx context.Context, topic string, handler func(data []byte)) error
// Unsubscribe from a topic
Unsubscribe(topic string) error
// Stop stops the network including any existing advertisements and subscriptions
Stop() error
}
// NewNetwork returns a new network given the configuration.
//
// Only models.Libp2pNetwork is implemented. models.NATSNetwork yields a
// "not implemented" error and any other type an "unsupported network type"
// error. A nil netConfig is rejected. On error the returned Network is always
// a nil interface value.
func NewNetwork(netConfig *models.NetworkConfig, fs afero.Fs) (Network, error) {
	// TODO: probable additional params to receive: DB, FileSystem
	if netConfig == nil {
		return nil, errors.New("network configuration is nil")
	}

	switch netConfig.Type {
	case models.Libp2pNetwork:
		ln, err := libp2p.New(&netConfig.Libp2pConfig, fs)
		if err != nil {
			// Return a literal nil: passing a nil concrete pointer through
			// would produce a non-nil Network interface value.
			return nil, err
		}
		return ln, nil
	case models.NATSNetwork:
		return nil, errors.New("not implemented")
	default:
		return nil, fmt.Errorf("unsupported network type: %s", netConfig.Type)
	}
}
package basic_controller
import (
"context"
"fmt"
"github.com/spf13/afero"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/models"
"gitlab.com/nunet/device-management-service/storage"
"gitlab.com/nunet/device-management-service/utils"
)
// BasicVolumeController is the default implementation of the VolumeController.
// It persists storage volumes information using the StorageVolumeRepository
// and creates the volume directories on the configured file system.
type BasicVolumeController struct {
// repo is the repository for storage volume operations
repo repositories.StorageVolumeRepository
// basePath is the base path where volumes are stored under
basePath string
// FS is the file system to act upon. It is exported so that other packages
// can operate on the same file system instance as the controller.
FS afero.Fs
}
// NewDefaultVolumeController builds a BasicVolumeController from the given
// repository, base path and file system. The returned error is always nil.
//
// TODO-BugFix [path]: volBasePath might not end with `/`, causing errors when calling methods.
// We need to validate it using the `path` library or just verifying the string.
func NewDefaultVolumeController(repo repositories.StorageVolumeRepository, volBasePath string, fs afero.Fs) (*BasicVolumeController, error) {
	controller := &BasicVolumeController{
		FS:       fs,
		basePath: volBasePath,
		repo:     repo,
	}
	return controller, nil
}
// CreateVolume creates a new storage volume given a source (S3, IPFS, job, etc). The
// creation of a storage volume effectively creates an empty directory in the local filesystem
// and writes a record in the database.
//
// The directory is created inside the controller's base path and its name
// follows the format: `<volSource> + "-" + <name>` where `name` is random.
// The volume defaults to non-private, writable and unencrypted; opts are
// applied on top of these defaults.
//
// It returns the volume as stored by the repository. If the repository write
// fails, the already created directory is left in place.
//
// TODO-maybe [withName]: allow callers to specify custom name for path
func (vc *BasicVolumeController) CreateVolume(volSource storage.VolumeSource, opts ...storage.CreateVolOpt) (models.StorageVolume, error) {
	vol := models.StorageVolume{
		Private:        false,
		ReadOnly:       false,
		EncryptionType: models.EncryptionTypeNull,
	}

	for _, opt := range opts {
		opt(&vol)
	}

	// basePath may be configured without a trailing separator. Add one so the
	// volume directory is created inside basePath instead of next to it.
	// Paths that already end with a separator are used unchanged.
	base := vc.basePath
	if n := len(base); n > 0 && base[n-1] != '/' && base[n-1] != '\\' {
		base += "/"
	}
	vol.Path = base + string(volSource) + "-" + utils.RandomString(16)

	if err := vc.FS.Mkdir(vol.Path, 0770); err != nil {
		return models.StorageVolume{}, fmt.Errorf("failed to create storage volume: %w", err)
	}

	createdVol, err := vc.repo.Create(context.Background(), vol)
	if err != nil {
		return models.StorageVolume{}, fmt.Errorf("failed to create storage volume in repository: %w", err)
	}

	return createdVol, nil
}
// LockVolume marks the volume stored at pathToVol as read-only: the database
// record gets ReadOnly set and the directory permissions are changed to 0400.
// Call it once all data has been written. The opts may additionally set the
// CID or mark the volume as private.
//
// TODO-maybe [CID]: maybe calculate CID of every volume in case WithCID opt is not provided
func (vc *BasicVolumeController) LockVolume(pathToVol string, opts ...storage.LockVolOpt) error {
	byPath := vc.repo.GetQuery()
	byPath.Conditions = append(byPath.Conditions, repositories.EQ("Path", pathToVol))

	target, findErr := vc.repo.Find(context.Background(), byPath)
	if findErr != nil {
		return fmt.Errorf("failed to find storage volume with path %s - Error: %w", pathToVol, findErr)
	}

	for i := range opts {
		opts[i](&target)
	}
	target.ReadOnly = true

	// persist the new state before touching the file system
	stored, updateErr := vc.repo.Update(context.Background(), target.ID, target)
	if updateErr != nil {
		return fmt.Errorf("failed to update storage volume with path %s - Error: %w", pathToVol, updateErr)
	}

	chmodErr := vc.FS.Chmod(stored.Path, 0400)
	if chmodErr != nil {
		return fmt.Errorf("failed to make storage volume read-only (path: %s): %w", stored.Path, chmodErr)
	}
	return nil
}
// WithPrivate designates a given volume as private. It can be used both
// when creating or locking a volume: the type parameter T selects which of
// the two option types (storage.CreateVolOpt or storage.LockVolOpt) is
// returned. The returned option sets the volume's Private field to true.
func WithPrivate[T storage.CreateVolOpt | storage.LockVolOpt]() T {
return func(v *models.StorageVolume) {
v.Private = true
}
}
// WithCID returns a lock option that stores the given, already calculated
// CID on the volume.
//
// TODO [validate]: check if CID provided is valid
func WithCID(cid string) storage.LockVolOpt {
	setCID := func(vol *models.StorageVolume) {
		vol.CID = cid
	}
	return setCID
}
// DeleteVolume removes the database record of a storage volume. The volume
// is looked up either by path or by CID, depending on idType, and the record
// found is deleted by its ID. Nothing is removed from the file system here.
//
// Note [CID]: if we start to type CID as cid.CID, we may have to use generics here
// as in `[T string | cid.CID]`
func (vc *BasicVolumeController) DeleteVolume(identifier string, idType storage.IDType) error {
	var field string
	switch idType {
	case storage.IDTypePath:
		field = "Path"
	case storage.IDTypeCID:
		field = "CID"
	default:
		return fmt.Errorf("identifier type not supported")
	}

	lookup := vc.repo.GetQuery()
	lookup.Conditions = append(lookup.Conditions, repositories.EQ(field, identifier))

	found, findErr := vc.repo.Find(context.Background(), lookup)
	switch {
	case findErr == nil:
		// record located, fall through to deletion below
	case findErr == repositories.NotFoundError:
		return fmt.Errorf("volume not found: %w", findErr)
	default:
		return fmt.Errorf("failed to find volume: %w", findErr)
	}

	if delErr := vc.repo.Delete(context.Background(), found.ID); delErr != nil {
		return fmt.Errorf("failed to delete volume: %w", delErr)
	}
	return nil
}
// ListVolumes fetches every storage volume record from the repository using
// an unfiltered query.
//
// TODO [filter]: maybe add opts to filter results by certain values
func (vc *BasicVolumeController) ListVolumes() ([]models.StorageVolume, error) {
	everything := vc.repo.GetQuery()
	all, err := vc.repo.FindAll(context.Background(), everything)
	if err == nil {
		return all, nil
	}
	return nil, fmt.Errorf("failed to list volumes: %w", err)
}
// GetSize looks up a volume by path or CID (depending on idType) and returns
// the size of its directory as reported by utils.GetDirectorySize.
// TODO-minor: identify which measurement type will be used
func (vc *BasicVolumeController) GetSize(identifier string, idType storage.IDType) (int64, error) {
	lookup := vc.repo.GetQuery()

	if idType == storage.IDTypePath {
		lookup.Conditions = append(lookup.Conditions, repositories.EQ("Path", identifier))
	} else if idType == storage.IDTypeCID {
		lookup.Conditions = append(lookup.Conditions, repositories.EQ("CID", identifier))
	} else {
		return 0, fmt.Errorf("unsupported ID type: %d", idType)
	}

	found, findErr := vc.repo.Find(context.Background(), lookup)
	if findErr != nil {
		return 0, fmt.Errorf("failed to find volume: %w", findErr)
	}

	total, sizeErr := utils.GetDirectorySize(vc.FS, found.Path)
	if sizeErr != nil {
		return 0, fmt.Errorf("failed to get directory size: %w", sizeErr)
	}
	return total, nil
}
// EncryptVolume is meant to encrypt a given volume. It is not implemented
// yet and always returns an error.
func (vc *BasicVolumeController) EncryptVolume(path string, encryptor models.Encryptor, encryptionType models.EncryptionType) error {
	notImplemented := fmt.Errorf("not implemented")
	return notImplemented
}
// DecryptVolume is meant to decrypt a given volume. It is not implemented
// yet and always returns an error.
func (vc *BasicVolumeController) DecryptVolume(path string, decryptor models.Decryptor, decryptionType models.EncryptionType) error {
	notImplemented := fmt.Errorf("not implemented")
	return notImplemented
}
// TODO-minor: compile-time check for interface implementation
var _ storage.VolumeController = (*BasicVolumeController)(nil)
package basic_controller
import (
"context"
"fmt"
"os"
"testing"
clover "github.com/ostafen/clover/v2"
"github.com/spf13/afero"
"gitlab.com/nunet/device-management-service/db/repositories/clover"
"gitlab.com/nunet/device-management-service/models"
)
// VolControllerTestSuiteHelper bundles everything a volume controller test
// needs: the controller under test plus the resources backing it.
type VolControllerTestSuiteHelper struct {
// BasicVolController is the controller under test
BasicVolController *BasicVolumeController
// Fs is the file system handed to the controller
Fs afero.Fs
// DB is the clover database backing the volume repository
DB *clover.DB
// Volumes are the volumes the suite was seeded with
Volumes map[string]*models.StorageVolume
// TempDBDir is the temporary directory holding the database files
TempDBDir string
}
// SetupVolControllerTestSuite sets up a volume controller with 0-n volumes given a base path.
// If volumes are provided, their directories are created on an in-memory file
// system and their records are stored in the database.
//
// The database lives in a fresh temporary directory. On success, a cleanup
// that closes the database and removes that directory is registered on t.
// On failure, everything created so far is released before returning.
func SetupVolControllerTestSuite(t *testing.T, basePath string,
	volumes map[string]*models.StorageVolume) (*VolControllerTestSuiteHelper, error) {
	t.Helper()

	tempDir, err := os.MkdirTemp("", "clover-test-*")
	if err != nil {
		return nil, fmt.Errorf("failed to create temp directory: %w", err)
	}

	db, err := repositories_clover.NewDB(tempDir, []string{"storage_volume"})
	if err != nil {
		os.RemoveAll(tempDir)
		return nil, fmt.Errorf("failed to open clover db: %w", err)
	}

	// release undoes the setup on error paths (best effort). RemoveAll is
	// required because the directory already contains the database files;
	// a plain Remove would fail on the non-empty directory and leak it.
	release := func() {
		db.Close()
		os.RemoveAll(tempDir)
	}

	fs := afero.NewMemMapFs()
	if err = fs.MkdirAll(basePath, 0755); err != nil {
		release()
		return nil, fmt.Errorf("failed to create base path: %w", err)
	}

	repo := repositories_clover.NewStorageVolumeRepository(db)
	vc, err := NewDefaultVolumeController(repo, basePath, fs)
	if err != nil {
		release()
		return nil, fmt.Errorf("failed to create volume controller: %w", err)
	}

	for _, vol := range volumes {
		// create root volume dir
		if err = fs.MkdirAll(vol.Path, 0755); err != nil {
			release()
			return nil, fmt.Errorf("failed to create volume dir: %w", err)
		}

		// create volume record in db
		if _, err = repo.Create(context.Background(), *vol); err != nil {
			release()
			return nil, fmt.Errorf("failed to create volume record: %w", err)
		}
	}

	helper := &VolControllerTestSuiteHelper{vc, fs, db, volumes, tempDir}
	t.Cleanup(func() {
		TearDownVolControllerTestSuite(helper)
	})

	return helper, nil
}
// TearDownVolControllerTestSuite releases what the setup created: it closes
// the database (when set) and removes the temporary database directory
// (when set).
func TearDownVolControllerTestSuite(helper *VolControllerTestSuiteHelper) {
	if db := helper.DB; db != nil {
		db.Close()
	}
	if dir := helper.TempDBDir; dir != "" {
		os.RemoveAll(dir)
	}
}
package s3
import (
"context"
"fmt"
"os"
"path/filepath"
"strings"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/spf13/afero"
"gitlab.com/nunet/device-management-service/models"
"gitlab.com/nunet/device-management-service/storage"
"gitlab.com/nunet/device-management-service/storage/basic_controller"
)
// Download fetches files from a given S3 bucket into a newly created storage
// volume. The key may be a directory ending with `/` or have a wildcard (`*`)
// so it handles normal S3 folders but it does not handle x-directory.
//
// Steps: decode the source spec, create a volume, resolve the key into the
// list of objects, download each object into the volume and finally lock the
// volume. The returned volume is the value obtained at creation time, before
// locking. When a later step fails, the created volume is not removed.
//
// Warning: the implementation should rely on the FS provided by the volume controller,
// be careful if managing files with `os` (the volume controller might be
// using an in-memory one)
func (s *S3Storage) Download(ctx context.Context, sourceSpecs *models.SpecConfig) (
	models.StorageVolume, error) {
	var storageVol models.StorageVolume

	source, err := DecodeInputSpec(sourceSpecs)
	if err != nil {
		return models.StorageVolume{}, err
	}

	storageVol, err = s.volController.CreateVolume(storage.VolumeSourceS3)
	if err != nil {
		return models.StorageVolume{}, fmt.Errorf("failed to create storage volume: %w", err)
	}

	resolvedObjects, err := resolveStorageKey(ctx, s.Client, &source)
	if err != nil {
		return models.StorageVolume{}, fmt.Errorf("failed to resolve storage key: %w", err)
	}

	for _, resolvedObject := range resolvedObjects {
		err = s.downloadObject(ctx, &source, resolvedObject, storageVol.Path)
		if err != nil {
			return models.StorageVolume{}, fmt.Errorf("failed to download s3 object: %w", err)
		}
	}

	// after data is filled within the volume, we have to lock it
	err = s.volController.LockVolume(storageVol.Path)
	if err != nil {
		return models.StorageVolume{}, fmt.Errorf("failed to lock storage volume: %w", err)
	}

	return storageVol, nil
}
// downloadObject stores a single resolved S3 object below volPath, using the
// object key as relative path.
//
// Directory objects only result in the directory being created. For regular
// objects the parent directory is created and the content is downloaded into
// the file; the object's eTag (when set) is sent as IfMatch condition.
//
// An error is returned when the volume controller does not expose a file
// system, or when creating directories, opening the file or the download
// itself fails.
func (s *S3Storage) downloadObject(ctx context.Context, source *S3InputSource,
	object s3Object, volPath string) error {
	// use the same file system instance used by the Volume Controller
	basicVolController, ok := s.volController.(*basic_controller.BasicVolumeController)
	if !ok || basicVolController.FS == nil {
		// Without this guard a different controller implementation would
		// lead to a nil file system and a panic below.
		return fmt.Errorf("volume controller of type %T does not provide a file system", s.volController)
	}
	fs := basicVolController.FS

	outputPath := filepath.Join(volPath, *object.key)

	if object.isDir {
		// if object is a directory, we don't need to download it (just create the dir)
		if err := fs.MkdirAll(outputPath, 0755); err != nil {
			return fmt.Errorf("failed to create directory: %w", err)
		}
		return nil
	}

	// Create only the parent directory: creating a directory at outputPath
	// itself would prevent the file from being opened there.
	if err := fs.MkdirAll(filepath.Dir(outputPath), 0755); err != nil {
		return fmt.Errorf("failed to create directory: %w", err)
	}

	outputFile, err := fs.OpenFile(outputPath, os.O_RDWR|os.O_CREATE, 0755)
	if err != nil {
		return err
	}
	defer outputFile.Close()

	zlog.Sugar().Debugf("Downloading s3 object %s to %s", *object.key, outputPath)
	_, err = s.downloader.Download(ctx, outputFile, &s3.GetObjectInput{
		Bucket:  aws.String(source.Bucket),
		Key:     object.key,
		IfMatch: object.eTag,
	})
	if err != nil {
		// The download error used to be dropped, reporting success for
		// files that were never written.
		return fmt.Errorf("failed to download object %s: %w", *object.key, err)
	}

	return nil
}
// resolveStorageKey turns the key of source into the list of S3 objects it
// refers to. A key ending in `/` or containing `*` is resolved as a prefix
// listing, any other non-empty key as a single object. An empty key is an
// error.
func resolveStorageKey(ctx context.Context, client *s3.Client, source *S3InputSource) ([]s3Object, error) {
	switch key := source.Key; {
	case key == "":
		return nil, fmt.Errorf("key is required")
	case strings.HasSuffix(key, "/") || strings.Contains(key, "*"):
		// key represents multiple objects
		return resolveObjectsWithPrefix(ctx, client, source)
	default:
		// key represents a single object
		return resolveSingleObject(ctx, client, source)
	}
}
// resolveSingleObject resolves a key that names exactly one object. It
// fetches the object's metadata with a HEAD request and returns a one-element
// list carrying the sanitized key, the eTag and the size.
//
// Objects whose content type starts with "application/x-directory" are
// rejected. Missing content type or content length in the response are
// treated as empty / zero.
func resolveSingleObject(ctx context.Context, client *s3.Client, source *S3InputSource) ([]s3Object, error) {
	key := sanitizeKey(source.Key)

	headObjectInput := &s3.HeadObjectInput{
		Bucket: aws.String(source.Bucket),
		Key:    aws.String(key),
	}

	headObjectOut, err := client.HeadObject(ctx, headObjectInput)
	if err != nil {
		return nil, fmt.Errorf("failed to retrieve object metadata: %w", err)
	}

	// TODO-minor: validate checksum if provided

	// aws.ToString / aws.ToInt64 tolerate nil pointers, which a direct
	// dereference would turn into a panic.
	if strings.HasPrefix(aws.ToString(headObjectOut.ContentType), "application/x-directory") {
		return nil, fmt.Errorf("x-directory is not yet handled!")
	}

	return []s3Object{
		{
			// Use the same (sanitized) key the metadata was fetched for, so
			// that the later download addresses the same object.
			key:  aws.String(key),
			eTag: headObjectOut.ETag,
			size: aws.ToInt64(headObjectOut.ContentLength),
		},
	}, nil
}
// resolveObjectsWithPrefix lists every object whose key starts with the
// sanitized key of source, following pagination until all pages are read.
// Entries whose key ends in `/` are flagged as directories. No eTag is
// recorded for listed objects.
func resolveObjectsWithPrefix(ctx context.Context, client *s3.Client, source *S3InputSource) ([]s3Object, error) {
	prefix := sanitizeKey(source.Key)

	// List objects with the given prefix
	pages := s3.NewListObjectsV2Paginator(client, &s3.ListObjectsV2Input{
		Bucket: aws.String(source.Bucket),
		Prefix: aws.String(prefix),
	})

	var found []s3Object
	for pages.HasMorePages() {
		page, err := pages.NextPage(ctx)
		if err != nil {
			return nil, fmt.Errorf("failed to list objects: %v", err)
		}

		for i := range page.Contents {
			name := *page.Contents[i].Key
			found = append(found, s3Object{
				key:   aws.String(name),
				size:  *page.Contents[i].Size,
				isDir: strings.HasSuffix(name, "/"),
			})
		}
	}

	return found, nil
}
package s3
import (
"context"
"strings"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
)
// GetAWSDefaultConfig loads the default AWS configuration (environment
// variables, shared configuration and shared credentials files) without any
// additional load options.
func GetAWSDefaultConfig() (aws.Config, error) {
	return config.LoadDefaultConfig(context.Background())
}
// hasValidCredentials reports whether the credentials provider of config can
// be resolved and yields credentials that contain keys. A config without a
// credentials provider, or a failing retrieval, counts as invalid.
func hasValidCredentials(config aws.Config) bool {
	// A zero-value aws.Config carries no provider; calling Retrieve on it
	// would panic instead of reporting "no valid credentials".
	if config.Credentials == nil {
		return false
	}

	credentials, err := config.Credentials.Retrieve(context.Background())
	if err != nil {
		return false
	}
	return credentials.HasKeys()
}
// sanitizeKey strips surrounding whitespace from key and then drops a single
// trailing wildcard (`*`), if present.
func sanitizeKey(key string) string {
	trimmed := strings.TrimSpace(key)
	if strings.HasSuffix(trimmed, "*") {
		return trimmed[:len(trimmed)-1]
	}
	return trimmed
}
package s3
import (
"github.com/uptrace/opentelemetry-go-extra/otelzap"
"gitlab.com/nunet/device-management-service/telemetry/logger"
)
// zlog is the package-level logger of the s3 package.
var zlog *otelzap.Logger
// init assigns zlog the logger obtained from logger.OtelZapLogger for the
// name "s3".
func init() {
zlog = logger.OtelZapLogger("s3")
}
package s3
import (
"context"
"fmt"
"github.com/aws/aws-sdk-go-v2/aws"
s3Manager "github.com/aws/aws-sdk-go-v2/feature/s3/manager"
"github.com/aws/aws-sdk-go-v2/service/s3"
"gitlab.com/nunet/device-management-service/models"
"gitlab.com/nunet/device-management-service/storage"
)
// S3Storage bundles an S3 SDK client with the helpers needed to move data
// between S3 and local storage volumes.
type S3Storage struct {
// embedded S3 client; its methods are callable directly on S3Storage
*s3.Client
// volController manages the local volumes being acted upon
volController storage.VolumeController
// downloader and uploader are the transfer managers built on the client
downloader *s3Manager.Downloader
uploader *s3Manager.Uploader
}
// s3Object describes one object resolved from an S3 key.
type s3Object struct {
// key is the object key within the bucket
key *string
// eTag, when set, is sent as IfMatch condition on download
eTag *string
// versionID is not populated by the resolvers in this package
// NOTE(review): appears unused — confirm before relying on it
versionID *string
// size is the object size as reported by S3
size int64
// isDir marks keys ending in `/`, which only result in a directory
isDir bool
}
// NewClient builds an S3Storage around a new S3 SDK client created from
// config, together with a downloader and an uploader sharing that client.
// volController is used to manage the volumes being acted upon.
// It fails with "invalid credentials" when config has no usable credentials.
func NewClient(config aws.Config, volController storage.VolumeController) (*S3Storage, error) {
	if !hasValidCredentials(config) {
		return nil, fmt.Errorf("invalid credentials")
	}

	client := s3.NewFromConfig(config)
	store := &S3Storage{
		Client:        client,
		volController: volController,
		downloader:    s3Manager.NewDownloader(client),
		uploader:      s3Manager.NewUploader(client),
	}
	return store, nil
}
// Size returns the content length of the single object addressed by source,
// as reported by a HEAD request on the bucket and key of the decoded spec.
//
// It fails when the spec cannot be decoded, the HEAD request fails, or the
// reported content length is negative. A missing content length is reported
// as 0.
func (s *S3Storage) Size(ctx context.Context, source *models.SpecConfig) (uint64, error) {
	inputSource, err := DecodeInputSpec(source)
	if err != nil {
		return 0, fmt.Errorf("failed to decode input spec: %w", err)
	}

	input := &s3.HeadObjectInput{
		Bucket: aws.String(inputSource.Bucket),
		Key:    aws.String(inputSource.Key),
	}

	output, err := s.HeadObject(ctx, input)
	if err != nil {
		return 0, fmt.Errorf("failed to get object size: %w", err)
	}

	// aws.ToInt64 tolerates a nil ContentLength; the sign check keeps a
	// bogus negative value from wrapping into a huge uint64.
	length := aws.ToInt64(output.ContentLength)
	if length < 0 {
		return 0, fmt.Errorf("failed to get object size: invalid content length %d", length)
	}

	return uint64(length), nil
}
// Compile time interface check
// var _ storage.StorageProvider = (*S3Storage)(nil)
package s3
import (
"fmt"
"github.com/fatih/structs"
"github.com/mitchellh/mapstructure"
"gitlab.com/nunet/device-management-service/models"
)
// S3InputSource holds the S3 parameters decoded from a storage spec.
// Only Bucket is checked by Validate; Key is required by the download path.
type S3InputSource struct {
Bucket string
Key string
Filter string
Region string
Endpoint string
}
// Validate checks the input source parameters. Currently the only rule is
// that Bucket must not be empty.
func (s S3InputSource) Validate() error {
	if s.Bucket != "" {
		return nil
	}
	return fmt.Errorf("invalid s3 storage params: bucket cannot be empty")
}
// ToMap converts the input source into a map using structs.Map.
func (s S3InputSource) ToMap() map[string]interface{} {
	asMap := structs.Map(s)
	return asMap
}
// DecodeInputSpec converts a generic spec config into an S3InputSource.
// The spec must be of type models.StorageProviderS3 and carry non-nil params.
// The params are decoded with mapstructure and the result is validated; the
// validation outcome is returned as the error.
func DecodeInputSpec(spec *models.SpecConfig) (S3InputSource, error) {
	var decoded S3InputSource

	if !spec.IsType(models.StorageProviderS3) {
		return decoded, fmt.Errorf("invalid storage source type. Expected %s but received %s", models.StorageProviderS3, spec.Type)
	}
	if spec.Params == nil {
		return decoded, fmt.Errorf("invalid storage input source params. cannot be nil")
	}

	decodeErr := mapstructure.Decode(spec.Params, &decoded)
	if decodeErr != nil {
		return decoded, decodeErr
	}
	return decoded, decoded.Validate()
}
package s3
import (
"context"
"fmt"
"os"
"path/filepath"
"github.com/spf13/afero"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/s3"
"gitlab.com/nunet/device-management-service/models"
"gitlab.com/nunet/device-management-service/storage/basic_controller"
)
// Upload uploads all files (recursively) from a local volume to an S3 bucket.
// It handles directories: they are walked, and every regular file is uploaded
// under `<sanitized key>/<path relative to the volume>`, always using `/` as
// separator in the S3 key.
//
// It fails when the destination spec cannot be decoded, when the volume
// controller does not expose a file system, or when walking, opening or
// uploading a file fails. On failure some files may already have been
// uploaded.
//
// Warning: the implementation should rely on the FS provided by the volume controller,
// be careful if managing files with `os` (the volume controller might be
// using an in-memory one)
func (s *S3Storage) Upload(ctx context.Context, vol models.StorageVolume,
	destinationSpecs *models.SpecConfig) error {
	target, err := DecodeInputSpec(destinationSpecs)
	if err != nil {
		return fmt.Errorf("failed to decode input spec: %w", err)
	}

	sanitizedKey := sanitizeKey(target.Key)

	// set file system to act upon based on the volume controller implementation
	basicVolController, ok := s.volController.(*basic_controller.BasicVolumeController)
	if !ok || basicVolController.FS == nil {
		// Without this guard a different controller implementation would
		// lead to a nil file system and a panic while walking.
		return fmt.Errorf("volume controller of type %T does not provide a file system", s.volController)
	}
	fs := basicVolController.FS

	zlog.Sugar().Debugf("Uploading files from %s to s3://%s/%s", vol.Path, target.Bucket, sanitizedKey)
	err = afero.Walk(fs, vol.Path, func(filePath string, info os.FileInfo, err error) error {
		if err != nil {
			return err
		}

		// Skip directories
		if info.IsDir() {
			return nil
		}

		relPath, err := filepath.Rel(vol.Path, filePath)
		if err != nil {
			return fmt.Errorf("failed to get relative path: %w", err)
		}

		// Construct the S3 key by joining the sanitized key and the relative path.
		// S3 keys use forward slashes regardless of the local OS separator.
		s3Key := filepath.ToSlash(filepath.Join(sanitizedKey, relPath))

		file, err := fs.Open(filePath)
		if err != nil {
			return fmt.Errorf("failed to open file: %w", err)
		}
		// runs when this callback returns, i.e. once per visited file
		defer file.Close()

		zlog.Sugar().Debugf("Uploading %s to s3://%s/%s", filePath, target.Bucket, s3Key)
		_, err = s.uploader.Upload(ctx, &s3.PutObjectInput{
			Bucket: aws.String(target.Bucket),
			Key:    aws.String(s3Key),
			Body:   file,
		})
		if err != nil {
			return fmt.Errorf("failed to upload file to S3: %w", err)
		}

		return nil
	})
	if err != nil {
		return fmt.Errorf("upload failed. It's possible that some files were uploaded; Error: %w", err)
	}

	return nil
}
package utils
import (
"bytes"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"time"
"github.com/cosmos/btcutil/bech32"
"github.com/ethereum/go-ethereum/common"
"github.com/fivebinaries/go-cardano-serialization/address"
"gitlab.com/nunet/device-management-service/db"
"gitlab.com/nunet/device-management-service/models"
)
// KoiosEndpoint is the type for Koios REST API endpoints.
// The predefined values are host names without a URL scheme.
type KoiosEndpoint string
const (
// KoiosMainnet - mainnet Koios rest api endpoint
KoiosMainnet KoiosEndpoint = "api.koios.rest"
// KoiosPreProd - testnet preprod Koios rest api endpoint
KoiosPreProd KoiosEndpoint = "preprod.koios.rest"
)
// UTXOs is the JSON shape of a UTXO entry: the transaction hash and whether
// the output has been spent.
// NOTE(review): presumably decoded from Koios responses — confirm against the caller.
type UTXOs struct {
TxHash string `json:"tx_hash"`
IsSpent bool `json:"is_spent"`
}
// TxHashResp is one entry of the list returned by GetJobTxHashes: the
// transaction hash of a service, its transaction type and its creation time
// rendered as a string.
type TxHashResp struct {
TxHash string `json:"tx_hash"`
TransactionType string `json:"transaction_type"`
DateTime string `json:"date_time"`
}
// ClaimCardanoTokenBody is the input of RequestReward. RequestReward looks
// the service up by TxHash.
type ClaimCardanoTokenBody struct {
ComputeProviderAddress string `json:"compute_provider_address"`
TxHash string `json:"tx_hash"`
}
// rewardRespToCPD is the result of RequestReward. Every field is copied from
// the matching field of the stored service record; RewardType carries the
// service's transaction type. All fields except the two addresses are
// omitted from the JSON output when empty.
type rewardRespToCPD struct {
ServiceProviderAddr string `json:"service_provider_addr"`
ComputeProviderAddr string `json:"compute_provider_addr"`
RewardType string `json:"reward_type,omitempty"`
SignatureDatum string `json:"signature_datum,omitempty"`
MessageHashDatum string `json:"message_hash_datum,omitempty"`
Datum string `json:"datum,omitempty"`
SignatureAction string `json:"signature_action,omitempty"`
MessageHashAction string `json:"message_hash_action,omitempty"`
Action string `json:"action,omitempty"`
}
// UpdateTxStatusBody is a JSON body carrying a single, optional address.
// NOTE(review): presumably the request body of the tx status update call — confirm against the caller.
type UpdateTxStatusBody struct {
Address string `json:"address,omitempty"`
}
// GetJobTxHashes returns the transaction hash, transaction type and creation
// time of stored services.
//
// clean must be one of "done", "refund", "withdraw" or "". Before listing,
// services whose transaction_type equals clean are deleted; a failure of
// that deletion is only logged.
// NOTE(review): with clean == "" the deletion still runs and matches rows
// with an empty transaction type — confirm this is intended.
//
// With size == 0 all services that have a tx hash, a transaction type and a
// log URL containing "log.nunet.io" are returned; otherwise the selection is
// delegated to getLimitedTransactions(size).
func GetJobTxHashes(size int, clean string) ([]TxHashResp, error) {
	if clean != "done" && clean != "refund" && clean != "withdraw" && clean != "" {
		return nil, fmt.Errorf("invalid clean_tx parameter")
	}

	err := db.DB.Where("transaction_type = ?", clean).Delete(&models.Services{}).Error
	if err != nil {
		// %w is only understood by fmt.Errorf; in a log format string it
		// renders as "%!w(...)", so log the error with %v instead.
		zlog.Sugar().Errorf("%v", err)
	}

	var resp []TxHashResp
	var services []models.Services
	if size == 0 {
		err = db.DB.
			Where("tx_hash IS NOT NULL").
			Where("log_url LIKE ?", "%log.nunet.io%").
			Where("transaction_type is NOT NULL").
			Find(&services).Error
		if err != nil {
			zlog.Sugar().Errorf("%v", err)
			return nil, fmt.Errorf("no job deployed to request reward for: %w", err)
		}
	} else {
		services, err = getLimitedTransactions(size)
		if err != nil {
			zlog.Sugar().Errorf("%v", err)
			return nil, fmt.Errorf("could not get limited transactions: %w", err)
		}
	}

	for _, service := range services {
		resp = append(resp, TxHashResp{
			TxHash:          service.TxHash,
			TransactionType: service.TransactionType,
			DateTime:        service.CreatedAt.String(),
		})
	}
	return resp, nil
}
// RequestReward looks up the stored service matching claim.TxHash and returns
// the reward data recorded for it.
//
// It fails when the DB query reports an error or when the service's JobStatus
// is still "running".
//
// NOTE(review): at some point the management dashboard should send a
// container ID so the reward can be tied to a specific container.
func RequestReward(claim ClaimCardanoTokenBody) (*rewardRespToCPD, error) {
	service := models.Services{TxHash: claim.TxHash}

	if err := db.DB.Where("tx_hash = ?", claim.TxHash).Find(&service).Error; err != nil {
		zlog.Sugar().Errorln(err)
		return nil, fmt.Errorf("unknown tx hash: %w", err)
	}
	zlog.Sugar().Infof("service found from txHash: %+v", service)

	if service.JobStatus == "running" {
		return nil, fmt.Errorf("job is still running")
	}

	return &rewardRespToCPD{
		ServiceProviderAddr: service.ServiceProviderAddr,
		ComputeProviderAddr: service.ComputeProviderAddr,
		RewardType:          service.TransactionType,
		SignatureDatum:      service.SignatureDatum,
		MessageHashDatum:    service.MessageHashDatum,
		Datum:               service.Datum,
		SignatureAction:     service.SignatureAction,
		MessageHashAction:   service.MessageHashAction,
		Action:              service.Action,
	}, nil
}
// SendStatus records a successful blockchain transaction in the local DB and
// returns status.TransactionStatus unchanged.
//
// When the status is "success", the service stored under status.TxHash gets
// its TransactionType set to "done". DB failures are logged, never returned.
// If the lookup fails nothing is saved, so a zero-value service is never
// written in place of the one that could not be loaded.
func SendStatus(status models.BlockchainTxStatus) string {
	if status.TransactionStatus != "success" {
		return status.TransactionStatus
	}
	zlog.Sugar().Infof("withdraw transaction successful - updating DB")

	var service models.Services
	if err := db.DB.Where("tx_hash = ?", status.TxHash).Find(&service).Error; err != nil {
		zlog.Sugar().Errorln(err)
		return status.TransactionStatus
	}

	// Marking the entry as "done" acts as a partial deletion of the record.
	service.TransactionType = "done"
	if err := db.DB.Save(&service).Error; err != nil {
		zlog.Sugar().Errorln(err)
	}
	return status.TransactionStatus
}
// UpdateStatus reconciles locally stored transactions with the UTXOs currently
// held at body.Address.
//
// It fetches the address's UTXO hashes from the KoiosPreProd endpoint, selects
// the services older than five minutes that have a tx hash, a log.nunet.io log
// URL and a transaction type other than "" or "done", and hands both lists to
// UpdateTransactionStatus. Every failure is logged and returned with its cause
// wrapped.
func UpdateStatus(body UpdateTxStatusBody) error {
	utxoHashes, err := GetUTXOsOfSmartContract(body.Address, KoiosPreProd)
	if err != nil {
		zlog.Sugar().Errorln(err)
		return fmt.Errorf("failed to fetch UTXOs from Blockchain: %w", err)
	}

	fiveMinAgo := time.Now().Add(-5 * time.Minute)
	var services []models.Services
	err = db.DB.
		Where("tx_hash IS NOT NULL").
		Where("log_url LIKE ?", "%log.nunet.io%").
		Where("transaction_type IS NOT NULL").
		Where("deleted_at IS NULL").
		Where("created_at <= ?", fiveMinAgo).
		Not("transaction_type = ?", "done").
		Not("transaction_type = ?", "").
		Find(&services).Error
	if err != nil {
		zlog.Sugar().Errorln(err)
		return fmt.Errorf("no job deployed to request reward for: %w", err)
	}

	if err := UpdateTransactionStatus(services, utxoHashes); err != nil {
		zlog.Sugar().Errorln(err)
		// Keep the cause in the chain so callers can inspect it.
		return fmt.Errorf("failed to update transaction status: %w", err)
	}
	return nil
}
// getLimitedTransactions returns every service whose transaction type is set
// and is neither "" nor "done", followed by at most sizeDone of the most
// recently created "done" services. Only services with a tx hash and a
// log.nunet.io log URL are considered. On a query error an empty, non-nil
// slice is returned together with the error.
func getLimitedTransactions(sizeDone int) ([]models.Services, error) {
	var finished []models.Services
	finishedQuery := db.DB.
		Where("tx_hash IS NOT NULL").
		Where("log_url LIKE ?", "%log.nunet.io%").
		Where("transaction_type = ?", "done").
		Order("created_at DESC").
		Limit(sizeDone).
		Find(&finished)
	if finishedQuery.Error != nil {
		return []models.Services{}, finishedQuery.Error
	}

	var pending []models.Services
	pendingQuery := db.DB.
		Where("tx_hash IS NOT NULL").
		Where("log_url LIKE ?", "%log.nunet.io%").
		Where("transaction_type IS NOT NULL").
		Not("transaction_type = ?", "done").
		Not("transaction_type = ?", "").
		Find(&pending)
	if pendingQuery.Error != nil {
		return []models.Services{}, pendingQuery.Error
	}

	return append(pending, finished...), nil
}
// isValidCardano reports through valid whether addr parses as a Cardano
// address. *valid is set to true when address.NewAddress succeeds and to false
// when parsing panics; on an ordinary parse error it is left untouched.
func isValidCardano(addr string, valid *bool) {
	defer func() {
		// A panic inside the parser is treated as "not a valid address".
		if recover() != nil {
			*valid = false
		}
	}()

	_, parseErr := address.NewAddress(addr)
	if parseErr != nil {
		return
	}
	*valid = true
}
// ValidateAddress accepts only Cardano wallet addresses. An address that looks
// like an Ethereum hex address is rejected explicitly; anything that does not
// parse as a Cardano address is rejected as invalid.
func ValidateAddress(addr string) error {
	if common.IsHexAddress(addr) {
		return errors.New("ethereum wallet address not allowed")
	}

	isCardano := false
	isValidCardano(addr, &isCardano)
	if !isCardano {
		return errors.New("invalid cardano wallet address")
	}
	return nil
}
// GetAddressPaymentCredential decodes a bech32 address and returns bytes 1..28
// of its payload as a 56-character hex string (the same range the previous
// implementation cut out of the hex encoding with [2:58]).
//
// NOTE(review): presumably byte 0 is the address header and the following 28
// bytes are the payment credential hash - confirm against the address format.
//
// An error is returned when the address is not valid bech32 or when its
// payload is too short to contain that range; short payloads used to panic.
func GetAddressPaymentCredential(addr string) (string, error) {
	const (
		skipLen       = 1  // leading byte that is not part of the credential
		credentialLen = 28 // number of bytes returned, hex encoded
	)

	_, data, err := bech32.Decode(addr, 1023)
	if err != nil {
		return "", fmt.Errorf("decoding bech32 failed: %w", err)
	}

	converted, err := bech32.ConvertBits(data, 5, 8, false)
	if err != nil {
		return "", fmt.Errorf("decoding bech32 failed: %w", err)
	}

	if len(converted) < skipLen+credentialLen {
		return "", fmt.Errorf(
			"decoding bech32 failed: payload has %d bytes, need at least %d",
			len(converted), skipLen+credentialLen,
		)
	}
	return hex.EncodeToString(converted[skipLen : skipLen+credentialLen]), nil
}
// GetTxReceiver returns the receiver recorded in a transaction, looked up by
// transaction hash through the tx_info API of the given Koios endpoint.
//
// The receiver is read from the second field of the inline datum attached to
// the transaction's second output. "unable to find receiver" is returned when
// the response does not contain that output or field.
func GetTxReceiver(txHash string, endpoint KoiosEndpoint) (string, error) {
	type Request struct {
		TxHashes []string `json:"_tx_hashes"`
	}

	reqBody, err := json.Marshal(Request{TxHashes: []string{txHash}})
	if err != nil {
		return "", err
	}

	resp, err := http.Post(
		fmt.Sprintf("https://%s/api/v1/tx_info", endpoint),
		"application/json",
		bytes.NewBuffer(reqBody))
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	res := []struct {
		Outputs []struct {
			InlineDatum struct {
				Value struct {
					Fields []struct {
						Bytes string `json:"bytes"`
					} `json:"fields"`
				} `json:"value"`
			} `json:"inline_datum"`
		} `json:"outputs"`
	}{}

	// An empty body (io.EOF) is tolerated and reported as "no receiver" below.
	if err := json.NewDecoder(resp.Body).Decode(&res); err != nil && err != io.EOF {
		return "", err
	}

	// Index 1 is read from both Outputs and Fields, so each needs at least two
	// elements; checking only for emptiness would panic on one-element slices.
	if len(res) == 0 || len(res[0].Outputs) < 2 {
		return "", fmt.Errorf("unable to find receiver")
	}
	fields := res[0].Outputs[1].InlineDatum.Value.Fields
	if len(fields) < 2 {
		return "", fmt.Errorf("unable to find receiver")
	}
	return fields[1].Bytes, nil
}
// GetTxConfirmations returns the number of confirmations of a transaction,
// looked up by transaction hash through the tx_status API of the given Koios
// endpoint. When the response lists several entries the last one is used; an
// empty list is reported as an error instead of panicking.
func GetTxConfirmations(txHash string, endpoint KoiosEndpoint) (int, error) {
	type Request struct {
		TxHashes []string `json:"_tx_hashes"`
	}

	reqBody, err := json.Marshal(Request{TxHashes: []string{txHash}})
	if err != nil {
		return 0, err
	}

	resp, err := http.Post(
		fmt.Sprintf("https://%s/api/v1/tx_status", endpoint),
		"application/json",
		bytes.NewBuffer(reqBody))
	if err != nil {
		return 0, err
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return 0, err
	}

	var res []struct {
		TxHash        string `json:"tx_hash"`
		Confirmations int    `json:"num_confirmations"`
	}
	if err := json.Unmarshal(body, &res); err != nil {
		return 0, err
	}
	if len(res) == 0 {
		return 0, fmt.Errorf("no status returned for transaction %s", txHash)
	}
	return res[len(res)-1].Confirmations, nil
}
// WaitForTxConfirmation polls GetTxConfirmations every five seconds until the
// transaction has at least the requested number of confirmations.
//
// It returns nil once that happens, the polling error if a lookup fails, or a
// "timeout" error when timeout elapses first. The timeout covers the whole
// wait: the timer is created once, outside the loop, so a poll tick cannot
// restart it.
func WaitForTxConfirmation(confirmations int, timeout time.Duration, txHash string, endpoint KoiosEndpoint) error {
	ticker := time.NewTicker(5 * time.Second)
	defer ticker.Stop()

	deadline := time.NewTimer(timeout)
	defer deadline.Stop()

	for {
		select {
		case <-ticker.C:
			conf, err := GetTxConfirmations(txHash, endpoint)
			if err != nil {
				return err
			}
			if conf >= confirmations {
				return nil
			}
		case <-deadline.C:
			return errors.New("timeout")
		}
	}
}
// GetUTXOsOfSmartContract queries the address_utxos API of the given Koios
// endpoint for address (with the extended flag set) and returns the tx_hash of
// every UTXO in the response. The result is nil when the response holds no
// entries or the body is empty.
func GetUTXOsOfSmartContract(address string, endpoint KoiosEndpoint) ([]string, error) {
	type utxoQuery struct {
		Address  []string `json:"_addresses"`
		Extended bool     `json:"_extended"`
	}

	payload, _ := json.Marshal(utxoQuery{Address: []string{address}, Extended: true})
	target := fmt.Sprintf("https://%s/api/v1/address_utxos", endpoint)

	resp, err := http.Post(target, "application/json", bytes.NewBuffer(payload))
	if err != nil {
		return nil, fmt.Errorf("error making POST request: %v", err)
	}
	defer resp.Body.Close()

	var entries []UTXOs
	decodeErr := json.NewDecoder(resp.Body).Decode(&entries)
	if decodeErr != nil && decodeErr != io.EOF {
		return nil, decodeErr
	}

	var hashes []string
	for i := range entries {
		hashes = append(hashes, entries[i].TxHash)
	}
	return hashes, nil
}
// UpdateTransactionStatus updates the local DB for every service whose tx hash
// is no longer among utxoHashes. Such a service has its transaction type
// mapped from "withdraw", "refund" or "distribute-50"/"distribute-75" to the
// corresponding withdrawn/refunded/distributed status (other types are kept)
// and is then saved. The first save error aborts the loop and is returned.
func UpdateTransactionStatus(services []models.Services, utxoHashes []string) error {
	for i := range services {
		// Work on a copy so the caller's slice is left untouched.
		current := services[i]
		if SliceContains(utxoHashes, current.TxHash) {
			continue
		}

		switch current.TransactionType {
		case "withdraw":
			current.TransactionType = transactionWithdrawnStatus
		case "refund":
			current.TransactionType = transactionRefundedStatus
		case "distribute-50", "distribute-75":
			current.TransactionType = transactionDistributedStatus
		}

		if saveErr := db.DB.Save(&current).Error; saveErr != nil {
			return saveErr
		}
	}
	return nil
}
package utils
import (
"fmt"
"os"
"github.com/spf13/afero"
)
// GetDirectorySize walks path on the given filesystem and returns the summed
// size in bytes of every non-directory entry found. Any error reported during
// the walk aborts it and is returned wrapped, with a size of 0.
func GetDirectorySize(fs afero.Fs, path string) (int64, error) {
	var total int64

	visit := func(_ string, entry os.FileInfo, walkErr error) error {
		if walkErr != nil {
			return walkErr
		}
		if entry.IsDir() {
			return nil
		}
		total += entry.Size()
		return nil
	}

	if err := afero.Walk(fs, path, visit); err != nil {
		return 0, fmt.Errorf("failed to calculate volume size: %w", err)
	}
	return total, nil
}
package utils
import (
"github.com/uptrace/opentelemetry-go-extra/otelzap"
"gitlab.com/nunet/device-management-service/telemetry/logger"
)
// zlog is the logger shared by the utils package; init assigns it.
var zlog *otelzap.Logger

// Transaction types written by UpdateTransactionStatus once a transaction's
// UTXO is no longer present at the queried address.
const (
	transactionWithdrawnStatus   = "withdrawn"
	transactionRefundedStatus    = "refunded"
	transactionDistributedStatus = "distributed"
)

// init creates the package logger under the name "utils".
func init() {
	zlog = logger.OtelZapLogger("utils")
}
package utils
import (
"bytes"
"fmt"
"io"
"net/http"
"net/url"
"github.com/gin-gonic/gin"
"gitlab.com/nunet/device-management-service/internal/config"
)
// InternalAPIURL composes a URL for the service's own REST API on localhost.
//
// protocol must be "http" or "ws" and endpoint must be non-empty; the port is
// taken from the REST section of the configuration and must be non-zero.
// query is used verbatim as the raw query string.
func InternalAPIURL(protocol, endpoint, query string) (string, error) {
	switch {
	case protocol == "" || endpoint == "":
		return "", fmt.Errorf("protocol and endpoint values must be specified")
	case protocol != "http" && protocol != "ws":
		return "", fmt.Errorf("invalid protocol: %s", protocol)
	}

	restPort := config.GetConfig().Rest.Port
	if restPort == 0 {
		return "", fmt.Errorf("port is not configured")
	}

	target := url.URL{
		Scheme:   protocol,
		Host:     fmt.Sprintf("localhost:%d", restPort),
		Path:     endpoint,
		RawQuery: query,
	}
	return target.String(), nil
}
// MakeInternalRequest sends a request to the service's own HTTP API and
// returns the raw response; the caller is responsible for closing its body.
//
// When c is non-nil the outgoing request carries a copy of the headers of
// c.Request and nothing else. Only when c is nil are the JSON Content-Type and
// Accept defaults applied.
func MakeInternalRequest(c *gin.Context, methodType, internalEndpoint, query string, body []byte) (*http.Response, error) {
	target, err := InternalAPIURL("http", internalEndpoint, query)
	if err != nil {
		return nil, err
	}

	request, err := http.NewRequest(methodType, target, bytes.NewBuffer(body))
	if err != nil {
		return nil, err
	}

	if c != nil {
		// Forward the incoming headers as-is; they replace the whole header set.
		request.Header = c.Request.Header.Clone()
	} else {
		request.Header.Set("Content-Type", "application/json; charset=utf-8")
		request.Header.Set("Accept", "application/json")
	}

	httpClient := http.Client{}
	response, err := httpClient.Do(request)
	if err != nil {
		return nil, err
	}
	return response, nil
}
// MakeRequest sends body as a JSON PUT request to uri using client and reports
// whether the request could be built and sent. The response status is not
// inspected.
//
// c and errMsg are not used by the current implementation; they are kept so
// existing callers keep compiling.
func MakeRequest(c *gin.Context, client *http.Client, uri string, body []byte, errMsg string) error {
	req, err := http.NewRequest(http.MethodPut, uri, bytes.NewBuffer(body))
	if err != nil {
		return fmt.Errorf("unable to make new request: %v", err)
	}
	req.Header.Set("Content-Type", "application/json; charset=utf-8")
	req.Header.Set("Accept", "application/json")

	resp, err := client.Do(req)
	if err != nil {
		return err
	}
	// The body must always be closed, otherwise the connection leaks. Draining
	// it first lets the transport reuse the connection.
	defer resp.Body.Close()
	_, _ = io.Copy(io.Discard, resp.Body)
	return nil
}
// ResponseBody performs an internal API call through MakeInternalRequest and
// returns the complete response body. Request and read failures are returned
// as descriptive errors; the response body is always closed.
func ResponseBody(c *gin.Context, methodType, internalEndpoint, query string, body []byte) (responseBody []byte, errMsg error) {
	response, requestErr := MakeInternalRequest(c, methodType, internalEndpoint, query, body)
	if requestErr != nil {
		return nil, fmt.Errorf("unable to make internal request: %v", requestErr)
	}
	defer response.Body.Close()

	payload, readErr := io.ReadAll(response.Body)
	if readErr != nil {
		return nil, fmt.Errorf("cannot read response body: %v", readErr)
	}
	return payload, nil
}
package utils
import (
"io"
"sync"
"time"
)
type IOProgress struct {
n float64
size float64
started time.Time
estimated time.Time
err error
}
type Reader struct {
reader io.Reader
lock sync.RWMutex
Progress IOProgress
}
type Writer struct {
writer io.Writer
lock sync.RWMutex
Progress IOProgress
}
func ReaderWithProgress(r io.Reader, size int64) *Reader {
return &Reader{
reader: r,
Progress: IOProgress{started: time.Now(), size: float64(size)},
}
}
func WriterWithProgress(w io.Writer, size int64) *Writer {
return &Writer{
writer: w,
Progress: IOProgress{started: time.Now(), size: float64(size)},
}
}
func (r *Reader) Read(p []byte) (n int, err error) {
n, err = r.reader.Read(p)
r.lock.Lock()
r.Progress.n += float64(n)
r.Progress.err = err
r.lock.Unlock()
return n, err
}
func (w *Writer) Write(p []byte) (n int, err error) {
n, err = w.writer.Write(p)
w.lock.Lock()
w.Progress.n += float64(n)
w.Progress.err = err
w.lock.Unlock()
return n, err
}
func (p IOProgress) Size() float64 {
return p.size
}
func (p IOProgress) N() float64 {
return p.n
}
func (p IOProgress) Complete() bool {
if p.err == io.EOF {
return true
}
if p.size == -1 {
return false
}
return p.n >= p.size
}
// Percent calculates the percentage complete.
func (p IOProgress) Percent() float64 {
if p.n == 0 {
return 0
}
if p.n >= p.size {
return 100
}
return 100.0 / (p.size / p.n)
}
func (p IOProgress) Remaining() time.Duration {
if p.estimated.IsZero() {
return time.Until(p.Estimated())
}
return time.Until(p.estimated)
}
func (p IOProgress) Estimated() time.Time {
ratio := p.n / p.size
past := float64(time.Since(p.started))
if p.n > 0.0 {
total := time.Duration(past / ratio)
p.estimated = p.started.Add(total)
}
return p.estimated
}
package utils
import (
"fmt"
"strings"
"sync"
)
// A SyncMap is a concurrency-safe sync.Map that uses strongly-typed
// method signatures to ensure the types of its stored data are known.
type SyncMap[K comparable, V any] struct {
	sync.Map
}

// SyncMapFromMap converts a standard Go map to a concurrency-safe SyncMap.
func SyncMapFromMap[K comparable, V any](m map[K]V) *SyncMap[K, V] {
	ret := &SyncMap[K, V]{}
	for k, v := range m {
		ret.Put(k, v)
	}
	return ret
}

// Get retrieves the value associated with the given key from the map.
// It returns the value and a boolean indicating whether the key was found.
func (m *SyncMap[K, V]) Get(key K) (V, bool) {
	value, ok := m.Load(key)
	if !ok {
		var empty V
		return empty, false
	}
	return value.(V), true
}

// Put inserts or updates a key-value pair in the map.
func (m *SyncMap[K, V]) Put(key K, value V) {
	m.Store(key, value)
}

// Iter iterates over each key-value pair in the map, executing the provided function on each pair.
// The iteration stops if the provided function returns false.
func (m *SyncMap[K, V]) Iter(ranger func(key K, value V) bool) {
	m.Range(func(key, value any) bool {
		return ranger(key.(K), value.(V))
	})
}

// Keys returns a slice containing all the keys present in the map.
// The result is nil for an empty map and its order is unspecified.
func (m *SyncMap[K, V]) Keys() []K {
	var keys []K
	m.Iter(func(key K, _ V) bool {
		keys = append(keys, key)
		return true
	})
	return keys
}

// String provides a string representation of the map in the form
// "{k1=v1, k2=v2}". Keys and values are rendered with %v so that non-string
// types print correctly; the order of the pairs is unspecified.
func (m *SyncMap[K, V]) String() string {
	var sb strings.Builder
	sb.WriteByte('{')
	first := true
	m.Range(func(key, value any) bool {
		if !first {
			sb.WriteString(", ")
		}
		first = false
		fmt.Fprintf(&sb, "%v=%v", key, value)
		return true
	})
	sb.WriteByte('}')
	return sb.String()
}
package utils
import (
"archive/tar"
"bufio"
"compress/gzip"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"log"
"math/rand"
"net/http"
"os"
"path/filepath"
"strings"
"time"
"github.com/google/uuid"
"gitlab.com/nunet/device-management-service/db"
"gitlab.com/nunet/device-management-service/models"
"golang.org/x/exp/slices"
"reflect"
)
// Download locations and local paths for the kernel file and the filesystem
// file. They are variables, so they can be overridden at runtime.
var (
	KernelFileURL  = "https://d.nunet.io/fc/vmlinux"
	KernelFilePath = "/etc/nunet/vmlinux"
	FilesystemURL  = "https://d.nunet.io/fc/nunet-fc-ubuntu-20.04-0.ext4"
	FilesystemPath = "/etc/nunet/nunet-fc-ubuntu-20.04-0.ext4"
)
// DownloadFile downloads url and saves the response body to filepath.
//
// The request is made before the destination is created, so a failed request
// leaves an existing file untouched. A non-2xx response is reported as an
// error instead of being written to disk. Errors from closing the destination
// file are returned as well.
func DownloadFile(url string, filepath string) (err error) {
	zlog.Sugar().Infof("Downloading file '%s' from '%s'", filepath, url)

	resp, err := http.Get(url)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		return fmt.Errorf("downloading %s: unexpected HTTP status %s", url, resp.Status)
	}

	file, err := os.Create(filepath)
	if err != nil {
		return err
	}
	defer func() {
		// A failed close can mean lost data, so surface it unless an earlier
		// error is already being returned.
		if closeErr := file.Close(); closeErr != nil && err == nil {
			err = closeErr
		}
	}()

	if _, err = io.Copy(file, resp.Body); err != nil {
		return err
	}

	log.Printf("Finished downloading file '%s'", filepath)
	return nil
}
// ReadHttpString issues a GET request to url and returns the entire response
// body as a string. The response status code is not inspected.
func ReadHttpString(url string) (string, error) {
	response, getErr := http.Get(url)
	if getErr != nil {
		return "", getErr
	}
	defer response.Body.Close()

	payload, readErr := io.ReadAll(response.Body)
	if readErr != nil {
		return "", readErr
	}
	return string(payload), nil
}
// RandomString returns a string of n characters drawn from the ASCII letters
// a-z and A-Z using the math/rand package-level generator.
func RandomString(n int) string {
	const letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"

	var out strings.Builder
	out.Grow(n)
	for written := 0; written < n; written++ {
		pick := rand.Intn(len(letters))
		out.WriteByte(letters[pick])
	}
	return out.String()
}
// GenerateMachineUUID generates a machine uuid
func GenerateMachineUUID() (string, error) {
var machine models.MachineUUID
machineUUID, err := uuid.NewDCEGroup()
if err != nil {
return "", err
}
machine.UUID = machineUUID.String()
return machine.UUID, nil
}
// GetMachineUUID returns the machine uuid from the DB
func GetMachineUUID() string {
var machine models.MachineUUID
uuid, err := GenerateMachineUUID()
if err != nil {
zlog.Sugar().Errorf("could not generate machine uuid: %v", err)
}
machine.UUID = uuid
result := db.DB.FirstOrCreate(&machine)
if result.Error != nil {
zlog.Sugar().Errorf("could not find or create machine uuid record in DB: %v", result.Error)
}
return machine.UUID
}
// SliceContains reports whether str is an element of s.
func SliceContains(s []string, str string) bool {
	for i := 0; i < len(s); i++ {
		if s[i] != str {
			continue
		}
		return true
	}
	return false
}
// DeleteFile deletes a file, with or without a backup
func DeleteFile(path string, backup bool) (err error) {
if backup {
err = os.Rename(path, fmt.Sprintf("%s.bk.%d", path, time.Now().Unix()))
} else {
err = os.Remove(path)
}
return
}
// ReadyForElastic reports whether the elastic token loaded from the DB has
// both a node ID and a channel name. The result of the DB query itself is not
// checked.
func ReadyForElastic() bool {
	var token models.ElasticToken
	db.DB.Find(&token)
	if token.NodeId == "" {
		return false
	}
	return token.ChannelName != ""
}
// PromptYesNo loops on confirmation from user until valid answer
func PromptYesNo(in io.Reader, out io.Writer, prompt string) (bool, error) {
reader := bufio.NewReader(in)
for {
fmt.Fprintf(out, "%s (y/N): ", prompt)
response, err := reader.ReadString('\n')
if err != nil {
return false, fmt.Errorf("read response string failed: %w", err)
}
response = strings.ToLower(strings.TrimSpace(response))
if response == "y" || response == "yes" {
return true, nil
} else if response == "n" || response == "no" {
return false, nil
}
}
}
// CreateDirectoryIfNotExists creates a directory if it does not exist
func CreateDirectoryIfNotExists(path string) error {
if _, err := os.Stat(path); os.IsNotExist(err) {
err := os.MkdirAll(path, 0755)
if err != nil {
return err
}
}
return nil
}
// CalculateSHA256Checksum calculates the SHA256 checksum of a file
func CalculateSHA256Checksum(filePath string) (string, error) {
// Open the file for reading
file, err := os.Open(filePath)
if err != nil {
return "", err
}
defer file.Close()
// Create a new SHA-256 hash
hash := sha256.New()
// Copy the file's contents into the hash object
if _, err := io.Copy(hash, file); err != nil {
return "", err
}
// Calculate the checksum and return it as a hexadecimal string
checksum := hex.EncodeToString(hash.Sum(nil))
return checksum, nil
}
// put checksum in file
func CreateCheckSumFile(filePath string, checksum string) (string, error) {
sha256FilePath := fmt.Sprintf("%s.sha256.txt", filePath)
sha256File, err := os.Create(sha256FilePath)
if err != nil {
return "", fmt.Errorf("unable to create SHA-256 checksum file: %v", err)
}
defer sha256File.Close()
_, err = sha256File.WriteString(checksum)
if err != nil {
return "", fmt.Errorf("unable to write to SHA-256 checksum file: %v", err)
}
return sha256FilePath, nil
}
// ExtractTarGzToPath extracts a tar.gz file to a specified path
func ExtractTarGzToPath(tarGzFilePath, extractedPath string) error {
// Ensure the target directory exists; create it if it doesn't.
if err := os.MkdirAll(extractedPath, os.ModePerm); err != nil {
return fmt.Errorf("error creating target directory: %v", err)
}
tarGzFile, err := os.Open(tarGzFilePath)
if err != nil {
return fmt.Errorf("error opening tar.gz file: %v", err)
}
defer tarGzFile.Close()
gzipReader, err := gzip.NewReader(tarGzFile)
if err != nil {
return fmt.Errorf("error creating gzip reader: %v", err)
}
defer gzipReader.Close()
tarReader := tar.NewReader(gzipReader)
for {
header, err := tarReader.Next()
if err == io.EOF {
break
}
if err != nil {
return fmt.Errorf("error reading tar header: %v", err)
}
// Construct the full target path by joining the target directory with
// the name of the file or directory from the archive.
fullTargetPath := filepath.Join(extractedPath, header.Name)
// Ensure that the directory path leading to the file exists.
if header.FileInfo().IsDir() {
// Create the directory and any parent directories as needed.
if err := os.MkdirAll(fullTargetPath, os.ModePerm); err != nil {
return fmt.Errorf("error creating directory: %v", err)
}
} else {
// Create the file and any parent directories as needed.
if err := os.MkdirAll(filepath.Dir(fullTargetPath), os.ModePerm); err != nil {
return fmt.Errorf("error creating directory: %v", err)
}
// Create a new file with the specified path.
newFile, err := os.Create(fullTargetPath)
if err != nil {
return fmt.Errorf("error creating file: %v", err)
}
defer newFile.Close()
// Copy the file contents from the tar archive to the new file.
if _, err := io.Copy(newFile, tarReader); err != nil {
return fmt.Errorf("error copying file contents: %v", err)
}
}
}
return nil
}
// CheckWSL check if running in WSL
func CheckWSL() (bool, error) {
file, err := os.Open("/proc/version")
if err != nil {
return false, err
}
defer file.Close()
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := scanner.Text()
if strings.Contains(line, "Microsoft") || strings.Contains(line, "WSL") {
return true, nil
}
}
if scanner.Err() != nil {
return false, scanner.Err()
}
return false, nil
}
// SaveServiceInfo updates the locally stored service that has the same tx hash
// as cpService with the values of cpService, so that the SP user can claim the
// reward. The stored record's ID and CreatedAt are carried over before the
// update is applied.
func SaveServiceInfo(cpService models.Services) error {
	var existing models.Services
	lookup := db.DB.Model(&models.Services{}).Where("tx_hash = ?", cpService.TxHash).Find(&existing)
	if lookup.Error != nil {
		return fmt.Errorf("Unable to find service on SP side: %v", lookup.Error)
	}

	cpService.ID, cpService.CreatedAt = existing.ID, existing.CreatedAt

	update := db.DB.Model(&models.Services{}).Where("tx_hash = ?", cpService.TxHash).Updates(&cpService)
	if updateErr := update.Error; updateErr != nil {
		return fmt.Errorf("Unable to update service info on SP side: %v", updateErr.Error())
	}
	return nil
}
// RandomBool returns true or false at random.
//
// NOTE(review): the package-level math/rand generator is reseeded with the
// current time on every call, which also affects every other user of that
// generator - confirm this is intended.
func RandomBool() bool {
	seed := time.Now().UnixNano()
	rand.Seed(seed)
	if rand.Intn(2) == 1 {
		return true
	}
	return false
}
// IsExecutorType reports whether the dynamic type of v is models.ExecutorType.
func IsExecutorType(v interface{}) bool {
	switch v.(type) {
	case models.ExecutorType:
		return true
	default:
		return false
	}
}
// IsGPUVendor reports whether the dynamic type of v is models.GPUVendor.
func IsGPUVendor(v interface{}) bool {
	switch v.(type) {
	case models.GPUVendor:
		return true
	default:
		return false
	}
}
// IsJobType reports whether the dynamic type of v is models.JobType.
func IsJobType(v interface{}) bool {
	switch v.(type) {
	case models.JobType:
		return true
	default:
		return false
	}
}
// IsJobTypes reports whether the dynamic type of v is models.JobTypes.
func IsJobTypes(v interface{}) bool {
	switch v.(type) {
	case models.JobTypes:
		return true
	default:
		return false
	}
}
// IsExecutor reports whether the dynamic type of v is models.Executor.
func IsExecutor(v interface{}) bool {
	switch v.(type) {
	case models.Executor:
		return true
	default:
		return false
	}
}
// IsStrictlyContained checks if all elements of rightSlice are contained in
// leftSlice. An empty rightSlice yields false.
func IsStrictlyContained(leftSlice, rightSlice []interface{}) bool {
	if len(rightSlice) == 0 {
		return false
	}
	for _, wanted := range rightSlice {
		if !slices.Contains(leftSlice, wanted) {
			return false
		}
	}
	return true
}
// NoIntersectionSlices reports whether slice1 and slice2 have no element in
// common. Every element of slice1 is checked; the function returns false as
// soon as one of them is found in slice2.
//
// For backward compatibility an empty slice1 still yields false.
func NoIntersectionSlices(slice1, slice2 []interface{}) bool {
	if len(slice1) == 0 {
		return false
	}
	for _, element := range slice1 {
		if slices.Contains(slice2, element) {
			return false
		}
	}
	return true
}
// IntersectionSlices returns the elements of slice2 that also occur in slice1.
// The result follows the order of slice2, contains each common element once
// and is never nil.
func IntersectionSlices(slice1, slice2 []interface{}) []interface{} {
	// Elements of slice1 that have not been reported yet.
	pending := make(map[interface{}]bool, len(slice1))
	for i := range slice1 {
		pending[slice1[i]] = true
	}

	common := make([]interface{}, 0)
	for i := range slice2 {
		candidate := slice2[i]
		if !pending[candidate] {
			continue
		}
		common = append(common, candidate)
		// Drop the element so a repeated occurrence in slice2 is not added twice.
		delete(pending, candidate)
	}
	return common
}
// IsSameShallowType reports whether a and b have the same dynamic type as seen
// by reflect.TypeOf. Two nil interface values compare as equal.
func IsSameShallowType(a, b interface{}) bool {
	return reflect.TypeOf(a) == reflect.TypeOf(b)
}
// ConvertTypedSliceToUntypedSlice copies the elements of any slice value into
// a new []interface{}. It returns nil when typedSlice is not a slice.
func ConvertTypedSliceToUntypedSlice(typedSlice interface{}) []interface{} {
	value := reflect.ValueOf(typedSlice)
	if value.Kind() != reflect.Slice {
		return nil
	}

	count := value.Len()
	untyped := make([]interface{}, 0, count)
	for idx := 0; idx < count; idx++ {
		untyped = append(untyped, value.Index(idx).Interface())
	}
	return untyped
}
package validate
import (
"reflect"
)
// ConvertNumericToFloat64 converts a value of any built-in integer or
// floating-point type to float64. The boolean result is false, and the number
// 0, when n has any other type (including named types derived from numeric
// ones).
func ConvertNumericToFloat64(n any) (float64, bool) {
	switch number := n.(type) {
	case float64:
		return number, true
	case float32:
		return float64(number), true
	case int, int8, int16, int32, int64:
		// number is still an interface value here; reflect reads its int64 form.
		signed := reflect.ValueOf(number).Int()
		return float64(signed), true
	case uint, uint8, uint16, uint32, uint64:
		unsigned := reflect.ValueOf(number).Uint()
		return float64(unsigned), true
	}
	return 0, false
}
package validate
import (
"strings"
)
// IsBlank reports whether s is empty or consists solely of whitespace.
func IsBlank(s string) bool {
	return strings.TrimSpace(s) == ""
}
// IsNotBlank reports whether s contains at least one non-whitespace character.
// It is the negation of IsBlank.
func IsNotBlank(s string) bool {
	if IsBlank(s) {
		return false
	}
	return true
}
// IsLiteral reports whether the dynamic type of s is string.
func IsLiteral(s interface{}) bool {
	_, isString := s.(string)
	return isString
}