package repositories_clover
import (
"fmt"
clover "github.com/ostafen/clover/v2"
)
// NewDB opens (or creates) the clover database stored at path and makes sure
// every collection named in collections exists.
//
// Collections that already exist are left untouched, so NewDB can be called
// again on a database created by an earlier run. If any step after opening
// fails, the database handle is closed before the error is returned so it is
// not leaked.
func NewDB(path string, collections []string) (*clover.DB, error) {
	db, err := clover.Open(path)
	if err != nil {
		return nil, fmt.Errorf("failed to connect to database: %w", err)
	}

	for _, collection := range collections {
		exists, err := db.HasCollection(collection)
		if err != nil {
			_ = db.Close() // best effort: the lookup failure is the error worth reporting
			return nil, fmt.Errorf("failed to check collection %s: %w", collection, err)
		}
		if exists {
			continue
		}
		if err := db.CreateCollection(collection); err != nil {
			_ = db.Close() // best effort: the creation failure is the error worth reporting
			return nil, fmt.Errorf("failed to create collection %s: %w", collection, err)
		}
	}
	return db, nil
}
package repositories_clover
import (
clover "github.com/ostafen/clover/v2"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// DeploymentRequestFlatClover is the Clover-backed repository for
// types.DeploymentRequestFlat. All of its methods come from the embedded
// generic repository.
type DeploymentRequestFlatClover struct {
	repositories.GenericRepository[types.DeploymentRequestFlat]
}

// NewDeploymentRequestFlat returns a repositories.DeploymentRequestFlat served
// by a generic Clover repository bound to db.
func NewDeploymentRequestFlat(db *clover.DB) repositories.DeploymentRequestFlat {
	generic := NewGenericRepository[types.DeploymentRequestFlat](db)
	return &DeploymentRequestFlatClover{GenericRepository: generic}
}
package repositories_clover
import (
"github.com/ostafen/clover/v2"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// RequestTrackerClover is the Clover-backed repository for
// types.RequestTracker. All of its methods come from the embedded generic
// repository.
type RequestTrackerClover struct {
	repositories.GenericRepository[types.RequestTracker]
}

// NewRequestTracker returns a repositories.RequestTracker served by a generic
// Clover repository bound to db.
func NewRequestTracker(db *clover.DB) repositories.RequestTracker {
	generic := NewGenericRepository[types.RequestTracker](db)
	return &RequestTrackerClover{GenericRepository: generic}
}
package repositories_clover
import (
"github.com/ostafen/clover/v2"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// VirtualMachineClover is the Clover-backed repository for
// types.VirtualMachine. All of its methods come from the embedded generic
// repository.
type VirtualMachineClover struct {
	repositories.GenericRepository[types.VirtualMachine]
}

// NewVirtualMachine returns a repositories.VirtualMachine served by a generic
// Clover repository bound to db.
func NewVirtualMachine(db *clover.DB) repositories.VirtualMachine {
	generic := NewGenericRepository[types.VirtualMachine](db)
	return &VirtualMachineClover{GenericRepository: generic}
}
package repositories_clover
import (
"context"
"fmt"
"reflect"
"time"
"github.com/iancoleman/strcase"
clover "github.com/ostafen/clover/v2"
clover_d "github.com/ostafen/clover/v2/document"
clover_q "github.com/ostafen/clover/v2/query"
"gitlab.com/nunet/device-management-service/db/repositories"
)
const (
// pKField names the "_id" document field.
// NOTE(review): nothing in the visible code references this constant (the
// generic repository uses pkField instead) — confirm whether it is still needed.
pKField = "_id"
)
// GenericEntityRepositoryClover is a generic single entity repository implementation using Clover.
// It is intended to be embedded in single entity model repositories to provide basic database operations.
// Each Save inserts a new document, so the collection holds every saved version of the entity.
type GenericEntityRepositoryClover[T repositories.ModelType] struct {
db *clover.DB // db is the Clover database instance.
collection string // collection is the collection name, derived from T's type name in snake_case.
}
// NewGenericEntityRepository builds a GenericEntityRepositoryClover backed by db.
// The collection it operates on is named after T's type name, converted to
// snake_case.
func NewGenericEntityRepository[T repositories.ModelType](
	db *clover.DB,
) repositories.GenericEntityRepository[T] {
	var zero T
	repo := &GenericEntityRepositoryClover[T]{
		db:         db,
		collection: strcase.ToSnake(reflect.TypeOf(zero).Name()),
	}
	return repo
}
// GetQuery returns the zero-value Query for T, ready to be filled in by the caller.
func (repo *GenericEntityRepositoryClover[T]) GetQuery() repositories.Query[T] {
	var empty repositories.Query[T]
	return empty
}
// query starts a fresh clover query scoped to the repository's collection.
func (repo *GenericEntityRepositoryClover[T]) query() *clover_q.Query {
	collection := repo.collection
	return clover_q.NewQuery(collection)
}
// Save stores data as a new document stamped with the current time in its
// "CreatedAt" field and returns the model decoded back from that document.
// Earlier documents are kept; Get returns the most recently created one.
//
// An insert failure is mapped through handleDBError and data is returned
// unchanged. A decoding failure is reported as an error that wraps both
// repositories.ErrParsingModel and the underlying cause.
func (repo *GenericEntityRepositoryClover[T]) Save(ctx context.Context, data T) (T, error) {
	doc := toCloverDoc(data)
	doc.Set("CreatedAt", time.Now())

	if _, err := repo.db.InsertOne(repo.collection, doc); err != nil {
		return data, handleDBError(err)
	}

	model, err := toModel[T](doc, true)
	if err != nil {
		// %w (not %v) keeps the sentinel and the cause reachable via errors.Is/As.
		return model, handleDBError(fmt.Errorf("%w: %w", repositories.ErrParsingModel, err))
	}
	return model, nil
}
// Get returns the most recently saved version of the entity, i.e. the document
// with the greatest "CreatedAt" value.
//
// When the collection holds no document the result is the error that
// handleDBError maps clover.ErrDocumentNotExist to, rather than a zero model
// with a nil error. A decoding failure wraps repositories.ErrParsingModel.
func (repo *GenericEntityRepositoryClover[T]) Get(ctx context.Context) (T, error) {
	var model T

	newestFirst := clover_q.SortOption{Field: "CreatedAt", Direction: -1}
	doc, err := repo.db.FindFirst(repo.query().Sort(newestFirst))
	if err != nil {
		return model, handleDBError(err)
	}
	if doc == nil {
		// FindFirst reports "nothing matched" as (nil, nil); surface it explicitly.
		return model, handleDBError(clover.ErrDocumentNotExist)
	}

	model, err = toModel[T](doc, true)
	if err != nil {
		return model, handleDBError(fmt.Errorf("%w: %w", repositories.ErrParsingModel, err))
	}
	return model, nil
}
// Clear deletes every document in the repository's collection, i.e. the entity
// together with all of its previously saved versions. Failures are mapped
// through handleDBError like the other repository methods.
func (repo *GenericEntityRepositoryClover[T]) Clear(ctx context.Context) error {
	return handleDBError(repo.db.Delete(repo.query()))
}
// History returns the stored versions of the entity that match query
// (conditions, sorting and limits are applied by applyConditions).
//
// If a document cannot be decoded, iteration stops and an error wrapping
// repositories.ErrParsingModel is returned together with the models decoded so
// far, instead of silently returning a truncated result.
func (repo *GenericEntityRepositoryClover[T]) History(
	ctx context.Context,
	query repositories.Query[T],
) ([]T, error) {
	var (
		models   []T
		parseErr error
	)

	q := applyConditions(repo.query(), query)
	err := repo.db.ForEach(q, func(doc *clover_d.Document) bool {
		model, convErr := toModel[T](doc, true)
		if convErr != nil {
			parseErr = handleDBError(fmt.Errorf("%w: %w", repositories.ErrParsingModel, convErr))
			return false // stop iterating; the failure is reported below
		}
		models = append(models, model)
		return true
	})
	if err != nil {
		return models, handleDBError(err)
	}
	if parseErr != nil {
		return models, parseErr
	}
	return models, nil
}
package repositories_clover
import (
"context"
"fmt"
"reflect"
"time"
"github.com/iancoleman/strcase"
clover "github.com/ostafen/clover/v2"
clover_d "github.com/ostafen/clover/v2/document"
clover_q "github.com/ostafen/clover/v2/query"
"gitlab.com/nunet/device-management-service/db/repositories"
)
const (
// pkField is the document field compared against the id in queryWithID.
pkField = "_id"
// deletedAtField is the soft-delete timestamp field that query filters on.
deletedAtField = "DeletedAt"
)
// GenericRepositoryClover is a generic repository implementation using Clover.
// It is intended to be embedded in model repositories to provide basic database operations.
// Deletion is soft: Delete stamps the DeletedAt field and reads filter such documents out.
type GenericRepositoryClover[T repositories.ModelType] struct {
db *clover.DB // db is the Clover database instance.
collection string // collection is the collection name, derived from T's type name in snake_case.
}
// NewGenericRepository builds a GenericRepositoryClover backed by db.
// The collection it operates on is named after T's type name, converted to
// snake_case.
func NewGenericRepository[T repositories.ModelType](
	db *clover.DB,
) repositories.GenericRepository[T] {
	var zero T
	repo := &GenericRepositoryClover[T]{
		db:         db,
		collection: strcase.ToSnake(reflect.TypeOf(zero).Name()),
	}
	return repo
}
// GetQuery returns the zero-value Query for T, ready to be filled in by the caller.
func (repo *GenericRepositoryClover[T]) GetQuery() repositories.Query[T] {
	var empty repositories.Query[T]
	return empty
}
// query starts a clover query on the repository's collection. Unless
// includeDeleted is true, the query is restricted to documents whose DeletedAt
// value is at or before the Unix epoch, which excludes documents stamped by Delete.
func (repo *GenericRepositoryClover[T]) query(includeDeleted bool) *clover_q.Query {
	base := clover_q.NewQuery(repo.collection)
	if includeDeleted {
		return base
	}
	notDeleted := clover_q.Field(deletedAtField).LtEq(time.Unix(0, 0))
	return base.Where(notDeleted)
}
// queryWithID returns the repository query (see query) narrowed to the
// document whose primary-key field equals id.
//
// id is passed to the criterion as-is instead of being force-asserted to
// string, so a caller supplying a non-string id gets a query that matches
// nothing rather than a panic.
func (repo *GenericRepositoryClover[T]) queryWithID(
	id interface{},
	includeDeleted bool,
) *clover_q.Query {
	return repo.query(includeDeleted).Where(clover_q.Field(pkField).Eq(id))
}
// Create inserts data as a new document stamped with the current time in its
// "CreatedAt" field and returns the model decoded back from that document,
// with its ID field set from the document's object id.
//
// On failure the input data is returned together with the mapped error. A
// decoding failure wraps both repositories.ErrParsingModel and the cause.
func (repo *GenericRepositoryClover[T]) Create(ctx context.Context, data T) (T, error) {
	doc := toCloverDoc(data)
	doc.Set("CreatedAt", time.Now())

	if _, err := repo.db.InsertOne(repo.collection, doc); err != nil {
		return data, handleDBError(err)
	}

	model, err := toModel[T](doc, false)
	if err != nil {
		// %w (not %v) keeps the sentinel and the cause reachable via errors.Is/As.
		return data, handleDBError(fmt.Errorf("%w: %w", repositories.ErrParsingModel, err))
	}
	return model, nil
}
// Get returns the non-deleted record whose primary key equals id.
//
// When no such record exists the result is the error that handleDBError maps
// clover.ErrDocumentNotExist to, rather than a zero model with a nil error.
// A decoding failure wraps repositories.ErrParsingModel.
func (repo *GenericRepositoryClover[T]) Get(ctx context.Context, id interface{}) (T, error) {
	var model T

	doc, err := repo.db.FindFirst(repo.queryWithID(id, false))
	if err != nil {
		return model, handleDBError(err)
	}
	if doc == nil {
		// FindFirst reports "nothing matched" as (nil, nil); surface it explicitly.
		return model, handleDBError(clover.ErrDocumentNotExist)
	}

	model, err = toModel[T](doc, false)
	if err != nil {
		return model, handleDBError(fmt.Errorf("%w: %w", repositories.ErrParsingModel, err))
	}
	return model, nil
}
// Update overwrites the fields of the non-deleted record identified by id with
// the JSON-visible fields of data, stamps "UpdatedAt" with the current time and
// returns the record as re-read from the database.
//
// If the update itself fails, data is returned with the mapped error. Errors
// from the follow-up Get are returned as-is: Get has already mapped them, and
// mapping them a second time would additionally tag them as database errors.
func (repo *GenericRepositoryClover[T]) Update(
	ctx context.Context,
	id interface{},
	data T,
) (T, error) {
	updates := toCloverDoc(data).AsMap()
	updates["UpdatedAt"] = time.Now()

	if err := repo.db.Update(repo.queryWithID(id, false), updates); err != nil {
		return data, handleDBError(err)
	}
	return repo.Get(ctx, id)
}
// Delete soft-deletes the record identified by id by stamping its DeletedAt
// field with the current time; the document itself stays in the collection and
// is filtered out by subsequent reads. Failures are mapped through
// handleDBError like the other repository methods.
func (repo *GenericRepositoryClover[T]) Delete(ctx context.Context, id interface{}) error {
	// Use the shared constant so the field written here is the one query filters on.
	stamp := map[string]interface{}{deletedAtField: time.Now()}
	return handleDBError(repo.db.Update(repo.queryWithID(id, false), stamp))
}
// Find returns the first non-deleted record matching query.
//
// When nothing matches, the error that handleDBError maps
// clover.ErrDocumentNotExist to is returned. A decoding failure wraps
// repositories.ErrParsingModel, consistent with Create, Get and FindAll.
func (repo *GenericRepositoryClover[T]) Find(
	ctx context.Context,
	query repositories.Query[T],
) (T, error) {
	var model T

	q := applyConditions(repo.query(false), query)
	doc, err := repo.db.FindFirst(q)
	if err != nil {
		return model, handleDBError(err)
	}
	if doc == nil {
		return model, handleDBError(clover.ErrDocumentNotExist)
	}

	model, err = toModel[T](doc, false)
	if err != nil {
		return model, handleDBError(fmt.Errorf("%w: %w", repositories.ErrParsingModel, err))
	}
	return model, nil
}
// FindAll returns every non-deleted record matching query.
//
// If a document cannot be decoded, iteration stops and an error wrapping both
// repositories.ErrParsingModel and the cause is returned together with the
// models decoded so far.
func (repo *GenericRepositoryClover[T]) FindAll(
	ctx context.Context,
	query repositories.Query[T],
) ([]T, error) {
	var (
		models   []T
		parseErr error
	)

	q := applyConditions(repo.query(false), query)
	err := repo.db.ForEach(q, func(doc *clover_d.Document) bool {
		model, convErr := toModel[T](doc, false)
		if convErr != nil {
			// %w (not %v) keeps the sentinel and the cause reachable via errors.Is/As.
			parseErr = handleDBError(fmt.Errorf("%w: %w", repositories.ErrParsingModel, convErr))
			return false // stop iterating; the failure is reported below
		}
		models = append(models, model)
		return true
	})
	if err != nil {
		return models, handleDBError(err)
	}
	if parseErr != nil {
		return models, parseErr
	}
	return models, nil
}
// applyConditions narrows the clover query q according to query and returns
// the resulting query.
//
// It applies, in order:
//   - every entry of query.Conditions, translating the operator string to the
//     matching clover criterion (unknown operators, and "IN"/"LIKE" values of
//     the wrong Go type, are ignored);
//   - an equality criterion for every exported, non-empty field of
//     query.Instance;
//   - sorting by query.SortBy, where a leading '-' means descending;
//   - query.Limit and query.Offset when they are greater than zero.
//
// Field names are mapped to their JSON tag name (see fieldJSONTag) because
// documents are stored under the JSON names.
func applyConditions[T repositories.ModelType](
	q *clover_q.Query,
	query repositories.Query[T],
) *clover_q.Query {
	// Explicit conditions.
	for _, condition := range query.Conditions {
		field := clover_q.Field(fieldJSONTag[T](condition.Field))
		switch condition.Operator {
		case "=":
			q = q.Where(field.Eq(condition.Value))
		case ">":
			q = q.Where(field.Gt(condition.Value))
		case ">=":
			q = q.Where(field.GtEq(condition.Value))
		case "<":
			q = q.Where(field.Lt(condition.Value))
		case "<=":
			q = q.Where(field.LtEq(condition.Value))
		case "!=":
			q = q.Where(field.Neq(condition.Value))
		case "IN":
			if values, ok := condition.Value.([]interface{}); ok {
				q = q.Where(field.In(values...))
			}
		case "LIKE":
			if value, ok := condition.Value.(string); ok {
				q = q.Where(field.Like(value))
			}
		}
	}

	// Query-by-example: every non-empty field of the instance must match.
	if !repositories.IsEmptyValue(query.Instance) {
		exampleType := reflect.TypeOf(query.Instance)
		exampleValue := reflect.ValueOf(query.Instance)
		for i := 0; i < exampleType.NumField(); i++ {
			structField := exampleType.Field(i)
			if !structField.IsExported() {
				// Reading an unexported field through Interface() would panic.
				continue
			}
			fieldValue := exampleValue.Field(i).Interface()
			if !repositories.IsEmptyValue(fieldValue) {
				fieldName := fieldJSONTag[T](structField.Name)
				q = q.Where(clover_q.Field(fieldName).Eq(fieldValue))
			}
		}
	}

	// Sorting: map the field name for both directions, not only for descending.
	if query.SortBy != "" {
		sortField, dir := query.SortBy, 1
		if sortField[0] == '-' {
			dir = -1
			sortField = sortField[1:]
		}
		q = q.Sort(clover_q.SortOption{Field: fieldJSONTag[T](sortField), Direction: dir})
	}

	if query.Limit > 0 {
		q = q.Limit(query.Limit)
	}

	// Offset skips leading results; it must not overwrite the limit.
	if query.Offset > 0 {
		q = q.Skip(query.Offset)
	}

	return q
}
package repositories_clover
import (
"github.com/ostafen/clover/v2"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// PeerInfoClover is the Clover-backed repository for types.PeerInfo. All of
// its methods come from the embedded generic repository.
type PeerInfoClover struct {
	repositories.GenericRepository[types.PeerInfo]
}

// NewPeerInfo returns a repositories.PeerInfo served by a generic Clover
// repository bound to db.
func NewPeerInfo(db *clover.DB) repositories.PeerInfo {
	generic := NewGenericRepository[types.PeerInfo](db)
	return &PeerInfoClover{GenericRepository: generic}
}
// MachineClover is the Clover-backed repository for types.Machine. All of its
// methods come from the embedded generic repository.
type MachineClover struct {
	repositories.GenericRepository[types.Machine]
}

// NewMachine returns a repositories.Machine served by a generic Clover
// repository bound to db.
func NewMachine(db *clover.DB) repositories.Machine {
	generic := NewGenericRepository[types.Machine](db)
	return &MachineClover{GenericRepository: generic}
}
// FreeResourcesClover is the Clover-backed single-entity repository for
// types.FreeResources. All of its methods come from the embedded generic
// entity repository.
type FreeResourcesClover struct {
	repositories.GenericEntityRepository[types.FreeResources]
}

// NewFreeResources returns a repositories.FreeResources served by a generic
// Clover entity repository bound to db.
func NewFreeResources(db *clover.DB) repositories.FreeResources {
	generic := NewGenericEntityRepository[types.FreeResources](db)
	return &FreeResourcesClover{GenericEntityRepository: generic}
}
// AvailableResourcesClover is the Clover-backed single-entity repository for
// types.AvailableResources. All of its methods come from the embedded generic
// entity repository.
type AvailableResourcesClover struct {
	repositories.GenericEntityRepository[types.AvailableResources]
}

// NewAvailableResources returns a repositories.AvailableResources served by a
// generic Clover entity repository bound to db.
func NewAvailableResources(db *clover.DB) repositories.AvailableResources {
	generic := NewGenericEntityRepository[types.AvailableResources](db)
	return &AvailableResourcesClover{GenericEntityRepository: generic}
}
// ServicesClover is the Clover-backed repository for types.Services. All of
// its methods come from the embedded generic repository.
type ServicesClover struct {
	repositories.GenericRepository[types.Services]
}

// NewServices returns a repositories.Services served by a generic Clover
// repository bound to db.
func NewServices(db *clover.DB) repositories.Services {
	generic := NewGenericRepository[types.Services](db)
	return &ServicesClover{GenericRepository: generic}
}
// ServiceResourceRequirementsClover is the Clover-backed repository for
// types.ServiceResourceRequirements. All of its methods come from the embedded
// generic repository.
type ServiceResourceRequirementsClover struct {
	repositories.GenericRepository[types.ServiceResourceRequirements]
}

// NewServiceResourceRequirements returns a
// repositories.ServiceResourceRequirements served by a generic Clover
// repository bound to db.
func NewServiceResourceRequirements(
	db *clover.DB,
) repositories.ServiceResourceRequirements {
	generic := NewGenericRepository[types.ServiceResourceRequirements](db)
	return &ServiceResourceRequirementsClover{GenericRepository: generic}
}
// Libp2pInfoClover is the Clover-backed single-entity repository for
// types.Libp2pInfo. All of its methods come from the embedded generic entity
// repository.
type Libp2pInfoClover struct {
	repositories.GenericEntityRepository[types.Libp2pInfo]
}

// NewLibp2pInfo returns a repositories.Libp2pInfo served by a generic Clover
// entity repository bound to db.
func NewLibp2pInfo(db *clover.DB) repositories.Libp2pInfo {
	generic := NewGenericEntityRepository[types.Libp2pInfo](db)
	return &Libp2pInfoClover{GenericEntityRepository: generic}
}
// MachineUUIDClover is the Clover-backed single-entity repository for
// types.MachineUUID. All of its methods come from the embedded generic entity
// repository.
type MachineUUIDClover struct {
	repositories.GenericEntityRepository[types.MachineUUID]
}

// NewMachineUUID returns a repositories.MachineUUID served by a generic Clover
// entity repository bound to db.
func NewMachineUUID(db *clover.DB) repositories.MachineUUID {
	generic := NewGenericEntityRepository[types.MachineUUID](db)
	return &MachineUUIDClover{GenericEntityRepository: generic}
}
// ConnectionClover is the Clover-backed repository for types.Connection. All
// of its methods come from the embedded generic repository.
type ConnectionClover struct {
	repositories.GenericRepository[types.Connection]
}

// NewConnection returns a repositories.Connection served by a generic Clover
// repository bound to db.
func NewConnection(db *clover.DB) repositories.Connection {
	generic := NewGenericRepository[types.Connection](db)
	return &ConnectionClover{GenericRepository: generic}
}
// ElasticTokenClover is the Clover-backed repository for types.ElasticToken.
// All of its methods come from the embedded generic repository.
type ElasticTokenClover struct {
	repositories.GenericRepository[types.ElasticToken]
}

// NewElasticToken returns a repositories.ElasticToken served by a generic
// Clover repository bound to db.
func NewElasticToken(db *clover.DB) repositories.ElasticToken {
	generic := NewGenericRepository[types.ElasticToken](db)
	return &ElasticTokenClover{GenericRepository: generic}
}
package repositories_clover
import (
clover "github.com/ostafen/clover/v2"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// StorageVolumeClover is the Clover-backed repository for types.StorageVolume.
// All of its methods come from the embedded generic repository.
type StorageVolumeClover struct {
	repositories.GenericRepository[types.StorageVolume]
}

// NewStorageVolume returns a repositories.StorageVolume served by a generic
// Clover repository bound to db.
func NewStorageVolume(db *clover.DB) repositories.StorageVolume {
	generic := NewGenericRepository[types.StorageVolume](db)
	return &StorageVolumeClover{GenericRepository: generic}
}
package repositories_clover
import (
"encoding/json"
"errors"
"reflect"
"strings"
"github.com/ostafen/clover/v2"
clover_d "github.com/ostafen/clover/v2/document"
"gitlab.com/nunet/device-management-service/db/repositories"
)
// handleDBError translates errors produced by clover (or by this package) into
// the error values exported by the repositories package:
//
//   - nil stays nil;
//   - clover.ErrDocumentNotExist becomes repositories.NotFoundError;
//   - clover.ErrDuplicateKey becomes repositories.InvalidDataError;
//   - errors that already carry repositories.ErrParsingModel or one of the
//     mapped repository errors are returned unchanged, so mapping an error
//     twice has no further effect;
//   - anything else is joined with repositories.DatabaseError.
//
// Matching uses errors.Is, so wrapped errors are recognised as well.
func handleDBError(err error) error {
	switch {
	case err == nil:
		return nil
	case errors.Is(err, clover.ErrDocumentNotExist):
		return repositories.NotFoundError
	case errors.Is(err, clover.ErrDuplicateKey):
		return repositories.InvalidDataError
	case errors.Is(err, repositories.ErrParsingModel),
		errors.Is(err, repositories.NotFoundError),
		errors.Is(err, repositories.InvalidDataError),
		errors.Is(err, repositories.DatabaseError):
		return err
	default:
		return errors.Join(repositories.DatabaseError, err)
	}
}
// toCloverDoc converts data into a clover document by round-tripping it
// through JSON, so the document's keys are the model's JSON field names.
// If either the marshal or the unmarshal step fails, an empty document is
// returned; the error itself is not reported to the caller.
func toCloverDoc[T repositories.ModelType](data T) *clover_d.Document {
	raw, err := json.Marshal(data)
	if err == nil {
		fields := make(map[string]interface{})
		if err = json.Unmarshal(raw, &fields); err == nil {
			return clover_d.NewDocumentOf(fields)
		}
	}
	return clover_d.NewDocument()
}
// toModel decodes doc into a value of type T. For regular repositories
// (isEntityRepo == false) the model's "ID" field is additionally set to the
// document's object id; entity repositories skip that step because their
// models might not have an ID field at all.
func toModel[T repositories.ModelType](doc *clover_d.Document, isEntityRepo bool) (T, error) {
	var model T
	if err := doc.Unmarshal(&model); err != nil {
		return model, err
	}
	if isEntityRepo {
		return model, nil
	}

	withID, err := repositories.UpdateField(model, "ID", doc.ObjectId())
	if err != nil {
		return withID, err
	}
	return withID, nil
}
// fieldJSONTag maps the Go struct field name field of T to the name used for it
// in JSON, taken from the field's `json` tag.
//
// field is returned unchanged when T is not a struct, when T has no field of
// that name, when the field carries no json tag, or when the tag does not
// provide a usable name (an empty name such as `json:",omitempty"`, or "-").
func fieldJSONTag[T repositories.ModelType](field string) string {
	modelType := reflect.TypeOf(*new(T))
	if modelType == nil || modelType.Kind() != reflect.Struct {
		// FieldByName panics on non-struct types; nothing to look up here.
		return field
	}

	structField, ok := modelType.FieldByName(field)
	if !ok {
		return field
	}
	tag, ok := structField.Tag.Lookup("json")
	if !ok {
		return field
	}

	// The name is everything before the first comma; options follow it.
	name, _, _ := strings.Cut(tag, ",")
	if name == "" || name == "-" {
		return field
	}
	return name
}
package repositories
import (
"context"
)
// QueryCondition is a struct representing a query condition.
// Conditions are usually built with the helper constructors in this package
// (EQ, GT, GTE, LT, LTE, IN, LIKE).
type QueryCondition struct {
Field string // Field specifies the database or struct field to which the condition applies.
Operator string // Operator defines the comparison operator (e.g., "=", ">", "<", "IN", "LIKE").
Value interface{} // Value is the expected value for the given field.
}
// ModelType is the type constraint used by the generic repositories.
// It is an empty interface, so any type satisfies it.
type ModelType interface{}
// Query is a struct that wraps both the instance of type T and additional query parameters.
// It is used to construct queries with conditions, sorting, limiting, and offsetting.
type Query[T any] struct {
Instance T // Instance is an optional object of type T; its non-empty fields become equality conditions.
Conditions []QueryCondition // Conditions represent the conditions applied to the query.
SortBy string // SortBy specifies the field to sort by; a leading '-' requests descending order.
Limit int // Limit specifies the maximum number of results to return; values <= 0 are ignored.
Offset int // Offset specifies the number of results to skip before starting to return data; values <= 0 are ignored.
}
// GenericRepository is an interface defining basic CRUD operations and standard querying methods.
// Implementations in this code base exist for Clover and for GORM.
type GenericRepository[T ModelType] interface {
// Create adds a new record to the repository and returns the stored record.
Create(ctx context.Context, data T) (T, error)
// Get retrieves a record by its identifier.
Get(ctx context.Context, id interface{}) (T, error)
// Update modifies a record by its identifier.
Update(ctx context.Context, id interface{}, data T) (T, error)
// Delete removes a record by its identifier.
Delete(ctx context.Context, id interface{}) error
// Find retrieves a single record based on a query.
Find(ctx context.Context, query Query[T]) (T, error)
// FindAll retrieves multiple records based on a query.
FindAll(ctx context.Context, query Query[T]) ([]T, error)
// GetQuery returns an empty query instance for the repository's type.
GetQuery() Query[T]
}
// EQ builds a QueryCondition that matches when field equals value
// (operator "=").
func EQ(field string, value interface{}) QueryCondition {
	var cond QueryCondition
	cond.Field, cond.Operator, cond.Value = field, "=", value
	return cond
}
// GT builds a QueryCondition that matches when field is greater than value
// (operator ">").
func GT(field string, value interface{}) QueryCondition {
	var cond QueryCondition
	cond.Field, cond.Operator, cond.Value = field, ">", value
	return cond
}
// GTE builds a QueryCondition that matches when field is greater than or equal
// to value (operator ">=").
func GTE(field string, value interface{}) QueryCondition {
	var cond QueryCondition
	cond.Field, cond.Operator, cond.Value = field, ">=", value
	return cond
}
// LT builds a QueryCondition that matches when field is less than value
// (operator "<").
func LT(field string, value interface{}) QueryCondition {
	var cond QueryCondition
	cond.Field, cond.Operator, cond.Value = field, "<", value
	return cond
}
// LTE builds a QueryCondition that matches when field is less than or equal to
// value (operator "<=").
func LTE(field string, value interface{}) QueryCondition {
	var cond QueryCondition
	cond.Field, cond.Operator, cond.Value = field, "<=", value
	return cond
}
// IN builds a QueryCondition that matches when field equals any element of
// values (operator "IN").
func IN(field string, values []interface{}) QueryCondition {
	var cond QueryCondition
	cond.Field, cond.Operator, cond.Value = field, "IN", values
	return cond
}
// LIKE builds a QueryCondition that matches field against pattern
// (operator "LIKE").
func LIKE(field, pattern string) QueryCondition {
	var cond QueryCondition
	cond.Field, cond.Operator, cond.Value = field, "LIKE", pattern
	return cond
}
package repositories_gorm
import (
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// DeploymentRequestFlatGORM is the GORM-backed repository for
// types.DeploymentRequestFlat. All of its methods come from the embedded
// generic repository.
type DeploymentRequestFlatGORM struct {
	repositories.GenericRepository[types.DeploymentRequestFlat]
}

// NewDeploymentRequestFlat returns a repositories.DeploymentRequestFlat served
// by a generic GORM repository bound to db.
func NewDeploymentRequestFlat(db *gorm.DB) repositories.DeploymentRequestFlat {
	generic := NewGenericRepository[types.DeploymentRequestFlat](db)
	return &DeploymentRequestFlatGORM{GenericRepository: generic}
}
package repositories_gorm
import (
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// RequestTrackerGORM is the GORM-backed repository for types.RequestTracker.
// All of its methods come from the embedded generic repository.
type RequestTrackerGORM struct {
	repositories.GenericRepository[types.RequestTracker]
}

// NewRequestTracker returns a repositories.RequestTracker served by a generic
// GORM repository bound to db.
func NewRequestTracker(db *gorm.DB) repositories.RequestTracker {
	generic := NewGenericRepository[types.RequestTracker](db)
	return &RequestTrackerGORM{GenericRepository: generic}
}
package repositories_gorm
import (
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// VirtualMachineGORM is the GORM-backed repository for types.VirtualMachine.
// All of its methods come from the embedded generic repository.
type VirtualMachineGORM struct {
	repositories.GenericRepository[types.VirtualMachine]
}

// NewVirtualMachine returns a repositories.VirtualMachine served by a generic
// GORM repository bound to db.
func NewVirtualMachine(db *gorm.DB) repositories.VirtualMachine {
	generic := NewGenericRepository[types.VirtualMachine](db)
	return &VirtualMachineGORM{GenericRepository: generic}
}
package repositories_gorm
import (
"context"
"fmt"
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
)
const (
// createdAtField is the model field the entity repository sorts on to find
// the most recently created record.
createdAtField = "CreatedAt"
)
// GenericEntityRepositoryGORM is a generic single entity repository implementation using GORM as an ORM.
// It is intended to be embedded in single entity model repositories to provide basic database operations.
// Each Save creates a new row; Get returns the one created most recently.
type GenericEntityRepositoryGORM[T repositories.ModelType] struct {
db *gorm.DB // db is the GORM database instance.
}
// NewGenericEntityRepository builds a GenericEntityRepositoryGORM that runs
// its operations against db.
func NewGenericEntityRepository[T repositories.ModelType](
	db *gorm.DB,
) repositories.GenericEntityRepository[T] {
	repo := new(GenericEntityRepositoryGORM[T])
	repo.db = db
	return repo
}
// GetQuery returns the zero-value Query for T, ready to be filled in by the caller.
func (repo *GenericEntityRepositoryGORM[T]) GetQuery() repositories.Query[T] {
	var empty repositories.Query[T]
	return empty
}
// Save inserts data as a new row and returns it as populated by GORM's Create
// call. Any database error is mapped through handleDBError.
func (repo *GenericEntityRepositoryGORM[T]) Save(ctx context.Context, data T) (T, error) {
	result := repo.db.WithContext(ctx).Create(&data)
	return data, handleDBError(result.Error)
}
// Get returns the most recently created row, found by sorting on the
// CreatedAt field in descending order and taking the first result.
func (repo *GenericEntityRepositoryGORM[T]) Get(ctx context.Context) (T, error) {
	newestFirst := repo.GetQuery()
	newestFirst.SortBy = fmt.Sprintf("-%s", createdAtField)

	var latest T
	tx := applyConditions(repo.db.WithContext(ctx), newestFirst)
	err := tx.First(&latest).Error
	return latest, handleDBError(err)
}
// Clear deletes every row of T whose id is not NULL, i.e. the entity together
// with its previously saved versions. The GORM error is returned as-is.
func (repo *GenericEntityRepositoryGORM[T]) Clear(ctx context.Context) error {
	result := repo.db.WithContext(ctx).Delete(new(T), "id IS NOT NULL")
	return result.Error
}
// History returns the stored rows of T that match query; conditions, sorting
// and limits are applied by applyConditions.
func (repo *GenericEntityRepositoryGORM[T]) History(
	ctx context.Context,
	query repositories.Query[T],
) ([]T, error) {
	var versions []T
	tx := applyConditions(repo.db.WithContext(ctx).Model(new(T)), query)
	err := tx.Find(&versions).Error
	return versions, handleDBError(err)
}
package repositories_gorm
import (
"context"
"fmt"
"reflect"
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
)
// GenericRepositoryGORM is a generic repository implementation using GORM as an ORM.
// It is intended to be embedded in model repositories to provide basic database operations.
type GenericRepositoryGORM[T repositories.ModelType] struct {
// db is the GORM database handle every operation runs against.
db *gorm.DB
}
// NewGenericRepository builds a GenericRepositoryGORM that runs its operations
// against db.
func NewGenericRepository[T repositories.ModelType](db *gorm.DB) repositories.GenericRepository[T] {
	repo := new(GenericRepositoryGORM[T])
	repo.db = db
	return repo
}
// GetQuery returns the zero-value Query for T, ready to be filled in by the caller.
func (repo *GenericRepositoryGORM[T]) GetQuery() repositories.Query[T] {
	var empty repositories.Query[T]
	return empty
}
// Create inserts data as a new row and returns it as populated by GORM's
// Create call. Any database error is mapped through handleDBError.
func (repo *GenericRepositoryGORM[T]) Create(ctx context.Context, data T) (T, error) {
	result := repo.db.WithContext(ctx).Create(&data)
	return data, handleDBError(result.Error)
}
// Get returns the first row whose id column equals id. Any database error is
// mapped through handleDBError.
func (repo *GenericRepositoryGORM[T]) Get(ctx context.Context, id interface{}) (T, error) {
	var record T
	lookup := repo.db.WithContext(ctx).First(&record, "id = ?", id)
	return record, handleDBError(lookup.Error)
}
// Update applies data to the row whose id column equals id using GORM's
// Updates call and returns the data value that was passed in (the row is not
// re-read). Any database error is mapped through handleDBError.
func (repo *GenericRepositoryGORM[T]) Update(ctx context.Context, id interface{}, data T) (T, error) {
	target := repo.db.WithContext(ctx).Model(new(T)).Where("id = ?", id)
	result := target.Updates(data)
	return data, handleDBError(result.Error)
}
// Delete removes the row whose id column equals id. The GORM error is returned
// as-is.
func (repo *GenericRepositoryGORM[T]) Delete(ctx context.Context, id interface{}) error {
	result := repo.db.WithContext(ctx).Delete(new(T), "id = ?", id)
	return result.Error
}
// Find returns the first row of T matching query. Any database error is
// mapped through handleDBError.
func (repo *GenericRepositoryGORM[T]) Find(
	ctx context.Context,
	query repositories.Query[T],
) (T, error) {
	var match T
	tx := applyConditions(repo.db.WithContext(ctx).Model(new(T)), query)
	err := tx.First(&match).Error
	return match, handleDBError(err)
}
// FindAll returns every row of T matching query. Any database error is mapped
// through handleDBError.
func (repo *GenericRepositoryGORM[T]) FindAll(
	ctx context.Context,
	query repositories.Query[T],
) ([]T, error) {
	var matches []T
	tx := applyConditions(repo.db.WithContext(ctx).Model(new(T)), query)
	err := tx.Find(&matches).Error
	return matches, handleDBError(err)
}
// applyConditions narrows the GORM query db according to query and returns the
// resulting *gorm.DB.
//
// It applies, in order:
//   - every entry of query.Conditions as "<column> <operator> ?";
//   - an equality condition for every exported, non-empty field of
//     query.Instance;
//   - ordering by query.SortBy, where a leading '-' means descending;
//   - query.Limit and query.Offset when they are greater than zero.
//
// Table and column names are derived from T's type name and the field names
// through the database's naming strategy.
//
// NOTE(review): condition.Operator is interpolated into the SQL text verbatim;
// it must never come from untrusted input.
func applyConditions[T any](db *gorm.DB, query repositories.Query[T]) *gorm.DB {
	// Retrieve the table name using the GORM naming strategy.
	tableName := db.NamingStrategy.TableName(reflect.TypeOf(*new(T)).Name())

	// Explicit conditions; the value is always bound as a parameter.
	for _, condition := range query.Conditions {
		columnName := db.NamingStrategy.ColumnName(tableName, condition.Field)
		db = db.Where(
			fmt.Sprintf("%s %s ?", columnName, condition.Operator),
			condition.Value,
		)
	}

	// Query-by-example: every non-empty field of the instance must match.
	if !repositories.IsEmptyValue(query.Instance) {
		exampleType := reflect.TypeOf(query.Instance)
		exampleValue := reflect.ValueOf(query.Instance)
		for i := 0; i < exampleType.NumField(); i++ {
			structField := exampleType.Field(i)
			if !structField.IsExported() {
				// Reading an unexported field through Interface() would panic.
				continue
			}
			fieldValue := exampleValue.Field(i).Interface()
			if !repositories.IsEmptyValue(fieldValue) {
				columnName := db.NamingStrategy.ColumnName(tableName, structField.Name)
				db = db.Where(fmt.Sprintf("%s = ?", columnName), fieldValue)
			}
		}
	}

	// Sorting.
	if query.SortBy != "" {
		sortField, dir := query.SortBy, "ASC"
		if sortField[0] == '-' {
			sortField = sortField[1:]
			dir = "DESC"
		}
		columnName := db.NamingStrategy.ColumnName(tableName, sortField)
		db = db.Order(fmt.Sprintf("%s.%s %s", tableName, columnName, dir))
	}

	if query.Limit > 0 {
		db = db.Limit(query.Limit)
	}

	// Offset skips leading rows; it must not overwrite the limit.
	if query.Offset > 0 {
		db = db.Offset(query.Offset)
	}

	return db
}
package repositories_gorm
import (
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// PeerInfoGORM is the GORM-backed repository for types.PeerInfo records.
// It embeds the generic repository for that type.
type PeerInfoGORM struct {
	repositories.GenericRepository[types.PeerInfo]
}

// NewPeerInfo wraps a generic GORM repository for types.PeerInfo built on db
// and returns it as a repositories.PeerInfo.
func NewPeerInfo(db *gorm.DB) repositories.PeerInfo {
	generic := NewGenericRepository[types.PeerInfo](db)
	return &PeerInfoGORM{GenericRepository: generic}
}
// MachineGORM is the GORM-backed repository for types.Machine records.
// It embeds the generic repository for that type.
type MachineGORM struct {
	repositories.GenericRepository[types.Machine]
}

// NewMachine wraps a generic GORM repository for types.Machine built on db
// and returns it as a repositories.Machine.
func NewMachine(db *gorm.DB) repositories.Machine {
	generic := NewGenericRepository[types.Machine](db)
	return &MachineGORM{GenericRepository: generic}
}
// FreeResourcesGORM is the GORM-backed repository for the types.FreeResources
// entity. It embeds the generic entity repository for that type.
type FreeResourcesGORM struct {
	repositories.GenericEntityRepository[types.FreeResources]
}

// NewFreeResources wraps a generic GORM entity repository for
// types.FreeResources built on db and returns it as a repositories.FreeResources.
func NewFreeResources(db *gorm.DB) repositories.FreeResources {
	generic := NewGenericEntityRepository[types.FreeResources](db)
	return &FreeResourcesGORM{GenericEntityRepository: generic}
}
// AvailableResourcesGORM is the GORM-backed repository for the
// types.AvailableResources entity. It embeds the generic entity repository
// for that type.
type AvailableResourcesGORM struct {
	repositories.GenericEntityRepository[types.AvailableResources]
}

// NewAvailableResources wraps a generic GORM entity repository for
// types.AvailableResources built on db and returns it as a
// repositories.AvailableResources.
func NewAvailableResources(db *gorm.DB) repositories.AvailableResources {
	generic := NewGenericEntityRepository[types.AvailableResources](db)
	return &AvailableResourcesGORM{GenericEntityRepository: generic}
}
// ServicesGORM is the GORM-backed repository for types.Services records.
// It embeds the generic repository for that type.
type ServicesGORM struct {
	repositories.GenericRepository[types.Services]
}

// NewServices wraps a generic GORM repository for types.Services built on db
// and returns it as a repositories.Services.
func NewServices(db *gorm.DB) repositories.Services {
	generic := NewGenericRepository[types.Services](db)
	return &ServicesGORM{GenericRepository: generic}
}
// ServiceResourceRequirementsGORM is the GORM-backed repository for
// types.ServiceResourceRequirements records. It embeds the generic repository
// for that type.
type ServiceResourceRequirementsGORM struct {
	repositories.GenericRepository[types.ServiceResourceRequirements]
}

// NewServiceResourceRequirements wraps a generic GORM repository for
// types.ServiceResourceRequirements built on db and returns it as a
// repositories.ServiceResourceRequirements.
func NewServiceResourceRequirements(
	db *gorm.DB,
) repositories.ServiceResourceRequirements {
	generic := NewGenericRepository[types.ServiceResourceRequirements](db)
	return &ServiceResourceRequirementsGORM{GenericRepository: generic}
}
// Libp2pInfoGORM is the GORM-backed repository for the types.Libp2pInfo
// entity. It embeds the generic entity repository for that type.
type Libp2pInfoGORM struct {
	repositories.GenericEntityRepository[types.Libp2pInfo]
}

// NewLibp2pInfo wraps a generic GORM entity repository for types.Libp2pInfo
// built on db and returns it as a repositories.Libp2pInfo.
func NewLibp2pInfo(db *gorm.DB) repositories.Libp2pInfo {
	generic := NewGenericEntityRepository[types.Libp2pInfo](db)
	return &Libp2pInfoGORM{GenericEntityRepository: generic}
}
// MachineUUIDGORM is the GORM-backed repository for the types.MachineUUID
// entity. It embeds the generic entity repository for that type.
type MachineUUIDGORM struct {
	repositories.GenericEntityRepository[types.MachineUUID]
}

// NewMachineUUID wraps a generic GORM entity repository for types.MachineUUID
// built on db and returns it as a repositories.MachineUUID.
func NewMachineUUID(db *gorm.DB) repositories.MachineUUID {
	generic := NewGenericEntityRepository[types.MachineUUID](db)
	return &MachineUUIDGORM{GenericEntityRepository: generic}
}
// ConnectionGORM is the GORM-backed repository for types.Connection records.
// It embeds the generic repository for that type.
type ConnectionGORM struct {
	repositories.GenericRepository[types.Connection]
}

// NewConnection wraps a generic GORM repository for types.Connection built on
// db and returns it as a repositories.Connection.
func NewConnection(db *gorm.DB) repositories.Connection {
	generic := NewGenericRepository[types.Connection](db)
	return &ConnectionGORM{GenericRepository: generic}
}
// ElasticTokenGORM is the GORM-backed repository for types.ElasticToken
// records. It embeds the generic repository for that type.
type ElasticTokenGORM struct {
	repositories.GenericRepository[types.ElasticToken]
}

// NewElasticToken wraps a generic GORM repository for types.ElasticToken built
// on db and returns it as a repositories.ElasticToken.
func NewElasticToken(db *gorm.DB) repositories.ElasticToken {
	generic := NewGenericRepository[types.ElasticToken](db)
	return &ElasticTokenGORM{GenericRepository: generic}
}
package repositories_gorm
import (
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// OnboardingParamsGORM is the GORM-backed repository for the
// types.OnboardingConfig entity. It embeds the generic entity repository for
// that type.
type OnboardingParamsGORM struct {
	repositories.GenericEntityRepository[types.OnboardingConfig]
}

// NewOnboardingParams wraps a generic GORM entity repository for
// types.OnboardingConfig built on db and returns it as a
// repositories.OnboardingParams.
func NewOnboardingParams(db *gorm.DB) repositories.OnboardingParams {
	generic := NewGenericEntityRepository[types.OnboardingConfig](db)
	return &OnboardingParamsGORM{GenericEntityRepository: generic}
}
package repositories_gorm
import (
"errors"
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
)
// structFieldNameDeletedAt holds the Go struct field name "DeletedAt".
// NOTE(review): presumably this is the soft-delete timestamp field looked up by
// name elsewhere in the package; its use is not visible here — confirm.
const structFieldNameDeletedAt = "DeletedAt"
// handleDBError translates a GORM error into the error vocabulary of the
// repositories package.
//
// Mapping:
//   - nil                                   -> nil
//   - gorm.ErrRecordNotFound                -> repositories.NotFoundError
//   - gorm.ErrInvalidData/Field/Value       -> repositories.InvalidDataError
//   - repositories.ErrParsingModel          -> returned unchanged
//   - anything else                         -> repositories.DatabaseError joined with the original error
//
// Matching uses errors.Is so that errors wrapped with additional context are
// still classified correctly instead of falling through to the default case.
func handleDBError(err error) error {
	switch {
	case err == nil:
		return nil
	case errors.Is(err, gorm.ErrRecordNotFound):
		return repositories.NotFoundError
	case errors.Is(err, gorm.ErrInvalidData),
		errors.Is(err, gorm.ErrInvalidField),
		errors.Is(err, gorm.ErrInvalidValue):
		return repositories.InvalidDataError
	case errors.Is(err, repositories.ErrParsingModel):
		return err
	default:
		return errors.Join(repositories.DatabaseError, err)
	}
}
package repositories
import (
"fmt"
"reflect"
)
// UpdateField sets the named field of a struct (or of the struct a pointer
// points to) to newValue using reflection and returns the updated input.
//
// When input is a struct value, the update is applied to the local copy that
// is returned; when input is a pointer, the pointed-to struct is modified in
// place and the same pointer is returned.
//
// newValue is converted to the field's type when the types are convertible.
// A nil newValue resets fields that can hold nil (pointer, interface, map,
// slice, channel, function) to their zero value and is rejected for every
// other kind of field.
//
// An error is returned when input is not a struct (or pointer to one), when
// the field does not exist or cannot be set, or when newValue cannot be
// converted to the field type.
func UpdateField[T interface{}](input T, fieldName string, newValue interface{}) (T, error) {
	// Use reflection to get the struct's field.
	val := reflect.ValueOf(input)
	if val.Kind() == reflect.Ptr {
		// If input is a pointer, work on the underlying element.
		val = val.Elem()
	} else {
		// If input is not a pointer, take the address of the local copy so
		// that its fields are addressable and therefore settable.
		val = reflect.ValueOf(&input).Elem()
	}

	// Check if the input is a struct.
	if val.Kind() != reflect.Struct {
		return input, fmt.Errorf("Not a struct: %T", input)
	}

	// Get the field by name.
	field := val.FieldByName(fieldName)
	if !field.IsValid() {
		return input, fmt.Errorf("Field not found: %v", fieldName)
	}

	// Check if the field is settable (unexported fields are not).
	if !field.CanSet() {
		return input, fmt.Errorf("Field not settable: %v", fieldName)
	}

	// reflect.TypeOf(nil) is a nil Type, so nil has to be handled before any
	// convertibility check to avoid a panic.
	if newValue == nil {
		switch field.Kind() {
		case reflect.Ptr, reflect.Interface, reflect.Map, reflect.Slice, reflect.Chan, reflect.Func:
			field.Set(reflect.Zero(field.Type()))
			return input, nil
		default:
			return input, fmt.Errorf(
				"Incompatible conversion: <nil> -> %v; value: <nil>",
				field.Type(),
			)
		}
	}

	// Check if types are compatible.
	newType := reflect.TypeOf(newValue)
	if !newType.ConvertibleTo(field.Type()) {
		return input, fmt.Errorf(
			"Incompatible conversion: %v -> %v; value: %v",
			newType, field.Type(), newValue,
		)
	}

	// Convert the new value to the field type and store it.
	field.Set(reflect.ValueOf(newValue).Convert(field.Type()))
	return input, nil
}
// IsEmptyValue reports whether value is "empty": a nil interface, a nil
// pointer, a zero value, or a pointer to a zero value.
//
// Pointers are dereferenced one level before the zero check, so a pointer to
// a struct whose fields all hold their zero values counts as empty.
func IsEmptyValue(value interface{}) bool {
	// A nil interface carries no value at all.
	if value == nil {
		return true
	}

	val := reflect.ValueOf(value)

	// Dereference pointers to inspect the pointed-to value. A typed nil
	// pointer has nothing to dereference: Elem would yield an invalid Value
	// and IsZero would panic on it, so treat it as empty right away.
	if val.Kind() == reflect.Ptr {
		if val.IsNil() {
			return true
		}
		val = val.Elem()
	}

	// Check if the value is zero (empty) based on its kind.
	return val.IsZero()
}
package dms
import (
"context"
"errors"
"fmt"
"github.com/google/uuid"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/network"
)
// ActorInterface defines the functionalities of actor: exposing its address,
// sending messages to other actors and spawning new actors.
// NOTE(review): Actor.SendMessage in this file takes a context.Context and
// returns an error, so its signature differs from the SendMessage declared
// here — confirm whether *Actor is meant to satisfy this interface.
type ActorInterface interface {
// Address of actor.
Address() *ActorAddrInfo
// SendMessage to another actor.
SendMessage(destination *ActorAddrInfo, m *Message)
// CreateActor creates a new actor and returns its address.
CreateActor() (*ActorAddrInfo, error)
}
// ActorAddrInfo encapsulates the data required to address an actor.
type ActorAddrInfo struct {
	// HostID identifies the host the actor lives on.
	HostID string
	// InboxAddress identifies the actor's inbox on that host.
	InboxAddress string
}

// Valid reports whether the address is usable: the receiver must be non-nil
// and both HostID and InboxAddress must be set. Accepting a nil receiver lets
// callers validate an optional address without a separate nil check.
func (a *ActorAddrInfo) Valid() bool {
	if a == nil {
		return false
	}
	return a.HostID != "" && a.InboxAddress != ""
}
// Message is passed between actors.
type Message struct {
// msgType labels the kind of message (Hello sets it to "hello").
msgType string
// sender identifies the originator; Start currently fills in the placeholder "sender".
sender string
// data is the raw payload; SendMessage forwards it as the envelope's Data.
data []byte
}
// ActorFactory is an actor factory. It carries the host id, network and actor
// registry that are handed to every actor it creates.
type ActorFactory struct {
// hostID is the id of the host on which created actors live.
hostID string
// network is the transport passed to created actors.
network network.Network
// actorRegistry is the registry in which created actors are recorded.
actorRegistry *ActorRegistry
}
// Actor represents an actor: an addressable unit that receives messages from
// the network into a buffered channel and can send messages to other actors.
type Actor struct {
// hostID is the id of the host this actor lives on.
hostID string
// address is the actor's inbox address (a UUID string assigned in newActor).
address string
// network is used to resolve hosts, send messages and register handlers.
network network.Network
// messages buffers incoming messages until ProcessMessages reads them.
messages chan Message
// actorRegistry records this actor and its parent/child relations.
actorRegistry *ActorRegistry
// factory is used by CreateActor to spawn child actors.
factory *ActorFactory
}
// NewActorFactory returns a factory that keeps the host id, network and actor
// registry needed to create and manage actors.
func NewActorFactory(hostID string, network network.Network, actorRegistry *ActorRegistry) *ActorFactory {
	factory := new(ActorFactory)
	factory.hostID = hostID
	factory.network = network
	factory.actorRegistry = actorRegistry
	return factory
}
// NewActor allows the factory to create a new actor.
// The actor is created without a parent: a nil parent address is passed on.
func (f *ActorFactory) NewActor() (*Actor, error) {
return f.newActor(nil)
}
// newActor creates an actor with the factory's host id, network and registry,
// optionally linked to the given parent address, and records the factory on
// the actor so it can spawn children later.
func (f *ActorFactory) newActor(parentActorAddress *ActorAddrInfo) (*Actor, error) {
return newActor(parentActorAddress, f.hostID, f.network, f.actorRegistry, f)
}
// newActor builds an actor from the given dependencies and records it in the
// actor registry.
//
// The actor receives a freshly generated UUID as its inbox address and a
// message buffer of 100 entries. When parentActorAddress is non-nil, the new
// actor is linked to that parent in the registry in both directions.
//
// An error is returned when hostID is empty, when net or actorRegistry is nil,
// or when the UUID cannot be generated.
func newActor(parentActorAddress *ActorAddrInfo, hostID string, net network.Network, actorRegistry *ActorRegistry, factory *ActorFactory) (*Actor, error) {
	if hostID == "" {
		return nil, errors.New("host id is empty")
	}
	if net == nil {
		return nil, errors.New("network is nil")
	}
	// The registry is dereferenced below, so reject nil instead of panicking.
	if actorRegistry == nil {
		return nil, errors.New("actor registry is nil")
	}

	id, err := uuid.NewUUID()
	if err != nil {
		return nil, fmt.Errorf("failed to generate uuid: %w", err)
	}

	createdActor := &Actor{
		hostID:        hostID,
		network:       net,
		address:       id.String(),
		actorRegistry: actorRegistry,
		factory:       factory,
		messages:      make(chan Message, 100),
	}

	actorRegistry.AddActorAddress(createdActor.Address())
	if parentActorAddress != nil {
		actorRegistry.SetParentAddress(createdActor.address, parentActorAddress)
		actorRegistry.AddChild(parentActorAddress.InboxAddress, createdActor.Address())
	}

	return createdActor, nil
}
// Address returns a newly allocated ActorAddrInfo holding the actor's host id
// and inbox address.
func (a *Actor) Address() *ActorAddrInfo {
	info := new(ActorAddrInfo)
	info.HostID, info.InboxAddress = a.hostID, a.address
	return info
}
// SendMessage sends a message to another actor.
//
// The destination host id is resolved to network addresses and the message
// payload is sent to the first of them, using a message type derived from the
// destination's inbox address.
//
// An error is returned when the destination is nil or incomplete, when m is
// nil, when the host cannot be resolved (or resolves to no address), or when
// the network fails to deliver the message. Underlying errors are wrapped so
// callers can inspect them with errors.Is / errors.As.
func (a *Actor) SendMessage(ctx context.Context, destination *ActorAddrInfo, m *Message) error {
	if destination == nil || !destination.Valid() {
		return errors.New("destination actor addr info is invalid")
	}
	if m == nil {
		return errors.New("message is invalid")
	}

	// Get the multiaddress of a host by resolving the host id.
	addresses, err := a.network.ResolveAddress(ctx, destination.HostID)
	if err != nil {
		return fmt.Errorf("failed to send message to actor %s: %w", destination.HostID, err)
	}
	// Guard the addresses[0] access below against an empty resolution result.
	if len(addresses) == 0 {
		return fmt.Errorf("failed to send message to actor %s: no addresses resolved", destination.HostID)
	}

	err = a.network.SendMessage(ctx, []string{addresses[0]}, types.MessageEnvelope{
		Type: types.MessageType(fmt.Sprintf("actor/%s/messages/0.0.1", destination.InboxAddress)),
		Data: m.data,
	})
	if err != nil {
		return fmt.Errorf("failed to send message to remote actor %s: %w", destination.InboxAddress, err)
	}

	return nil
}
// Start registers the network handler for this actor's message type. Every
// payload delivered to the handler is pushed onto the actor's messages
// channel with the placeholder sender "sender".
func (a *Actor) Start() error {
	protocol := fmt.Sprintf("actor/%s/messages/0.0.1", a.address)
	enqueue := func(data []byte) {
		a.messages <- Message{sender: "sender", data: data}
	}

	if err := a.network.HandleMessage(protocol, enqueue); err != nil {
		return fmt.Errorf("failed to start actor %s: %w", a.address, err)
	}
	return nil
}
// CreateActor creates a child actor through the factory, with this actor as
// its parent, starts it and returns the child's address.
func (a *Actor) CreateActor() (*ActorAddrInfo, error) {
	child, err := a.factory.newActor(a.Address())
	if err != nil {
		return nil, fmt.Errorf("failed to create new actor: %w", err)
	}

	err = child.Start()
	if err != nil {
		return nil, fmt.Errorf("failed to start new actor: %w", err)
	}

	return child.Address(), nil
}
// ProcessMessages reads messages from the incoming messages channel and
// prints the sender of each one. The loop only ends when the channel is
// closed, so the call blocks for as long as the channel stays open.
func (a *Actor) ProcessMessages() {
for msg := range a.messages {
fmt.Printf("received message from %s", msg.sender)
// Dispatch by message type is not wired up yet; the intended shape is kept
// below for reference.
// switch msg.msgType {
// case "hello":
// {
// a.handleHello(msg)
// }
// default:
// fmt.Printf("unhandled message type: %s\n", msg.msgType)
// }
}
}
// Hello sends m to destination as a "hello" message.
//
// The message type of m is set to "hello" before sending. A nil m is passed
// through untouched so that SendMessage can reject it instead of this method
// panicking. Hello has no return value, so a send failure is reported on
// standard output rather than being silently dropped.
func (a *Actor) Hello(ctx context.Context, destination *ActorAddrInfo, m *Message) {
	if m != nil {
		m.msgType = "hello"
	}
	if err := a.SendMessage(ctx, destination, m); err != nil {
		fmt.Printf("failed to send hello message: %v\n", err)
	}
}
// handleHello handles a "hello" message by printing it to standard output.
// Its only reference in this file is inside the commented-out dispatch in
// ProcessMessages.
func (a *Actor) handleHello(m Message) {
fmt.Println("handled hello message", m)
}
package dms
// ActorRegistry represents an actor registry: a map from an actor's inbox
// address to the information recorded about that actor.
// NOTE(review): the map is accessed without any locking in this file —
// confirm that the registry is only used from a single goroutine.
type ActorRegistry struct {
// actors is keyed by inbox address.
actors map[string]*actorInfo
}
// actorInfo is the registry entry for one actor.
type actorInfo struct {
// addrInfo is the actor's own address.
addrInfo *ActorAddrInfo
// parent is the address of the actor that created this one, or nil.
parent *ActorAddrInfo
// childs lists the addresses of actors created by this one.
childs []*ActorAddrInfo
}
// NewActorRegistry returns an empty actor registry that is ready for use.
func NewActorRegistry() *ActorRegistry {
	registry := new(ActorRegistry)
	registry.actors = map[string]*actorInfo{}
	return registry
}
// AddActorAddress records an actor in the registry under its inbox address.
// An address that is already registered is left untouched. New entries start
// with no parent and an empty child list.
func (r *ActorRegistry) AddActorAddress(a *ActorAddrInfo) {
	if _, registered := r.actors[a.InboxAddress]; registered {
		return
	}

	entry := new(actorInfo)
	entry.addrInfo = a
	entry.childs = []*ActorAddrInfo{}
	r.actors[a.InboxAddress] = entry
}
// SetParentAddress records parent as the parent of the actor registered under
// actorID. Unknown actor ids are ignored.
func (r *ActorRegistry) SetParentAddress(actorID string, parent *ActorAddrInfo) {
	// Entries are stored by pointer, so updating the entry is enough.
	if entry, known := r.actors[actorID]; known {
		entry.parent = parent
	}
}
// AddChild appends child to the child list of the actor registered under
// actorID. Unknown actor ids are ignored.
func (r *ActorRegistry) AddChild(actorID string, child *ActorAddrInfo) {
	// Entries are stored by pointer, so updating the entry is enough.
	if entry, known := r.actors[actorID]; known {
		entry.childs = append(entry.childs, child)
	}
}
// GetActorAddress looks up the actor registered under the given inbox address.
// It returns the actor's address info and true when the actor is known, and
// nil and false otherwise.
func (r *ActorRegistry) GetActorAddress(address string) (*ActorAddrInfo, bool) {
	entry, ok := r.actors[address]
	// A missing key yields a nil entry; dereferencing it would panic.
	if !ok || entry == nil {
		return nil, false
	}
	return entry.addrInfo, true
}
package dms
import (
// "context"
"context"
"fmt"
"log"
"time"
"gitlab.com/nunet/device-management-service/api"
gdb "gitlab.com/nunet/device-management-service/db/repositories/gorm"
"gitlab.com/nunet/device-management-service/dms/onboarding"
"gitlab.com/nunet/device-management-service/internal"
"gitlab.com/nunet/device-management-service/internal/config"
"gitlab.com/nunet/device-management-service/telemetry/logger"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
// "gitlab.com/nunet/device-management-service/internal/messaging"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/network/libp2p"
"gitlab.com/nunet/device-management-service/utils"
"github.com/libp2p/go-libp2p/core/crypto"
"github.com/spf13/afero"
)
// NewP2P is stub, real implementation is needed in order to pass it to
// routers which access them in some handlers.
// It currently returns a zero-value libp2p.Libp2p.
func NewP2P() libp2p.Libp2p {
return libp2p.Libp2p{}
}
// Run starts the device management service: it loads the configuration, opens
// the SQLite database through GORM, builds the repositories and the onboarding
// service, checks persisted state when the node is already onboarded, starts
// the REST server in a goroutine and then blocks until a value arrives on
// internal.ShutdownChan. Failures while opening the database or reading the
// persisted configuration terminate the process via log.Fatalf / zlog Fatalf.
// QUESTION(dms-initialization): should the db instance be constructed here?
func Run() {
ctx := context.Background()
log.Println("WARNING: Most parts commented out in dms.Run()")
config.LoadConfig()
// XXX: wait for server to start properly before sending requests below
// TODO: should be removed
time.Sleep(time.Second * 5)
// check if onboarded
db, err := gorm.Open(sqlite.Open(fmt.Sprintf("%s/nunet.db", config.GetConfig().General.WorkDir)), &gorm.Config{})
if err != nil {
log.Fatalf("Failed to connect to database: %v", err)
}
// Repositories backed by the database opened above.
onboardR := gdb.NewOnboardingParams(db)
p2pR := gdb.NewLibp2pInfo(db)
uuidR := gdb.NewMachineUUID(db)
avResR := gdb.NewAvailableResources(db)
oConf, err := onboardR.Get(ctx)
if err != nil {
log.Fatalf("Failed to get onboarding config: %v", err)
}
// NOTE(review): WorkDir below is read from config.GetConfig().WorkDir while
// every other use in this function reads config.GetConfig().General.WorkDir —
// confirm which field is intended.
onboard := onboarding.New(onboarding.OnboardingConfig{
Fs: afero.Afero{Fs: afero.NewOsFs()},
P2PRepo: p2pR,
UUIDRepo: uuidR,
AvResourceRepo: avResR,
WorkDir: config.GetConfig().WorkDir,
DatabasePath: fmt.Sprintf("%s/nunet.db", config.GetConfig().General.WorkDir),
Channels: []string{"nunet", "nunet-test", "nunet-team", "nunet-edge"},
})
// The error from IsOnboarded is discarded, so a failed check is treated the
// same as "not onboarded".
if onboarded, _ := onboard.IsOnboarded(ctx); onboarded {
ValidateOnboarding(&oConf)
p2pParams, err := p2pR.Get(ctx)
if err != nil {
log.Fatalf("Failed to get libp2p info: %v", err)
}
// Only checks that the stored private key can be decoded; the key itself is
// not used here.
_, err = crypto.UnmarshalPrivateKey(p2pParams.PrivateKey)
if err != nil {
zlog.Sugar().Fatalf("unable to unmarshal private key: %v", err)
}
}
// initialize rest api server
// P2p, Onboarding and MidW are left nil in this configuration.
restConfig := api.RESTServerConfig{
P2p: nil,
Onboarding: nil,
Logger: logger.New("rest-server"),
MidW: nil,
Port: config.GetConfig().Rest.Port,
Addr: config.GetConfig().Rest.Addr,
}
rServer := api.NewRESTServer(restConfig)
rServer.InitializeRoutes()
go rServer.Run()
// wait for SIGINT or SIGTERM
sig := <-internal.ShutdownChan
fmt.Printf("Shutting down after receiving %v...\n", sig)
// add cleanup code here
fmt.Println("Cleaning up before shutting down")
return
}
// ValidateOnboarding checks the persisted onboarding configuration.
// Check 1: the payment address stored in oConf.PublicKey must pass
// utils.ValidateAddress; a failure is logged through zlog. The function only
// logs and does not terminate the process itself.
func ValidateOnboarding(oConf *types.OnboardingConfig) {
	if err := utils.ValidateAddress(oConf.PublicKey); err != nil {
		zlog.Sugar().Errorf("the payment address %s is not valid", oConf.PublicKey)
		zlog.Sugar().Error("exiting DMS")
	}
}
package dms
import (
"gitlab.com/nunet/device-management-service/telemetry/logger"
)
// zlog is the package-level logger of the dms package; it is assigned in init.
var zlog *logger.Logger
// init creates the "dms" logger during package initialization.
func init() {
zlog = logger.New("dms")
}
package parser
import (
"gitlab.com/nunet/device-management-service/dms/jobs"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/nunet"
)
// registry holds the job-spec parsers known to this package, keyed by spec
// type. It is populated in init.
var registry Registry[jobs.JobSpec]

// init creates the parser registry and registers the built-in parsers.
func init() {
	parsers := make(map[SpecType]Parser[jobs.JobSpec])
	registry = &RegistryImpl[jobs.JobSpec]{parsers: parsers}

	// Register Nunet parser.
	registry.RegisterParser(
		specTypeNuNet,
		NewParser[jobs.JobSpec](
			nunet.NewNuNetTransformer(),
			nunet.NewNuNetValidator(),
		),
	)

	// Register other parsers here.
}
package nunet
import (
"fmt"
"strconv"
"strings"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/transform"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/tree"
)
// NewNuNetTransformer creates the transformer for the NuNet configuration.
// The transformer functions are grouped into three stages that are handed to
// transform.NewTransformer in order: collection maps, then individual list
// entries, then nested execution and remote-volume settings.
func NewNuNetTransformer() transform.Transformer {
	// Stage 1: collection maps.
	collections := map[tree.Path]transform.TransformerFunc{
		"jobs":             TransformJobs,
		"jobs.**.children": TransformJobs,
		"jobs.**.volumes":  TransformVolumes,
		"jobs.**.networks": TransformNetworks,
	}
	// Stage 2: individual list entries.
	entries := map[tree.Path]transform.TransformerFunc{
		"jobs.**.volumes.[]":   TransformVolume,
		"jobs.**.networks.[]":  TransformNetwork,
		"jobs.**.libraries.[]": TransformLibrary,
	}
	// Stage 3: nested execution and remote volume settings.
	nested := map[tree.Path]transform.TransformerFunc{
		"jobs.**.execution":         TransformExecution,
		"jobs.**.volumes.[].remote": TransformVolumeRemote,
	}

	stages := []map[tree.Path]transform.TransformerFunc{collections, entries, nested}
	return transform.NewTransformer(stages)
}
// TransformJobs converts the jobs map into a slice via transform.MapToSlice,
// which stores each key in the "name" field of its entry. Nil data is passed
// through as nil; anything that is not a map is rejected.
func TransformJobs(root *map[string]interface{}, data any, path tree.Path) (any, error) {
	switch jobs := data.(type) {
	case nil:
		return nil, nil
	case map[string]any:
		return transform.MapToSlice(jobs)
	default:
		return nil, fmt.Errorf("invalid jobs configuration: %v", data)
	}
}
// TransformVolumes converts the volumes map into a slice via
// transform.MapToSlice, which stores each key in the "name" field of its
// entry. Nil data is passed through as nil; anything that is not a map is
// rejected.
func TransformVolumes(root *map[string]interface{}, data any, path tree.Path) (any, error) {
	switch volumes := data.(type) {
	case nil:
		return nil, nil
	case map[string]any:
		return transform.MapToSlice(volumes)
	default:
		return nil, fmt.Errorf("invalid volumes configuration: %v", data)
	}
}
// TransformNetworks converts the networks map into a slice via
// transform.MapToSlice, which stores each key in the "name" field of its
// entry. Nil data is passed through as nil; anything that is not a map is
// rejected.
func TransformNetworks(root *map[string]interface{}, data any, path tree.Path) (any, error) {
	switch networks := data.(type) {
	case nil:
		return nil, nil
	case map[string]any:
		return transform.MapToSlice(networks)
	default:
		return nil, fmt.Errorf("invalid networks configuration: %v", data)
	}
}
// TransformExecution reshapes a flat engine configuration into the
// {"type": ..., "params": {...}} form: the "type" key is kept at the top level
// and every other key is moved under "params". Nil data is passed through as
// nil; anything that is not a map is rejected.
func TransformExecution(root *map[string]interface{}, data any, path tree.Path) (any, error) {
	if data == nil {
		return nil, nil
	}

	engine, isMap := data.(map[string]any)
	if !isMap {
		return nil, fmt.Errorf("invalid engine configuration: %v", data)
	}

	params := make(map[string]any, len(engine))
	for key, value := range engine {
		if key == "type" {
			continue
		}
		params[key] = value
	}

	return map[string]any{
		"type":   engine["type"],
		"params": params,
	}, nil
}
// TransformVolume transforms the volume configuration and handles inheritance.
// The volume configuration can be a string in the format "name:mountpoint" or
// a map. For every "children" segment in the current path, the volumes of the
// enclosing job are looked up and, when one of them carries the same name, its
// keys are copied into this volume's configuration.
func TransformVolume(root *map[string]interface{}, data any, path tree.Path) (any, error) {
	var config map[string]any

	switch spec := data.(type) {
	case map[string]any:
		config = spec
	case string:
		// Short form: "name:mountpoint".
		fields := strings.Split(spec, ":")
		if len(fields) != 2 {
			return nil, fmt.Errorf("invalid volume configuration: %v", data)
		}
		config = map[string]any{
			"name":       fields[0],
			"mountpoint": fields[1],
		}
	default:
		return nil, fmt.Errorf("invalid volume configuration: %v", data)
	}

	// Every "children" segment marks an enclosing job whose volumes may
	// define this volume; visit them from the outermost to the innermost.
	segments := path.Parts()
	for i, segment := range segments {
		if segment != "children" {
			continue
		}
		ancestor := tree.NewPath(segments[:i]...)

		inherited, err := transform.GetConfigAtPath(*root, ancestor.Next("volumes"))
		if err != nil {
			fmt.Println("error: ", err)
			continue
		}

		candidates, _ := transform.ToAnySlice(inherited)
		for _, candidate := range candidates {
			volume, isMap := candidate.(map[string]any)
			if !isMap || volume["name"] != config["name"] {
				continue
			}
			// Merge the configurations.
			for key, value := range volume {
				config[key] = value
			}
		}
	}

	return config, nil
}
// TransformVolumeRemote reshapes the remote section of a volume into the
// {"type": ..., "params": {...}} form.
//
// When the configuration already has a "params" key, that map is used as-is.
// Otherwise every key except "type" is moved under "params". Nil data is
// passed through as nil. An error is returned when the data is not a map or
// when an explicit "params" value is not a map.
func TransformVolumeRemote(root *map[string]interface{}, data any, path tree.Path) (any, error) {
	if data == nil {
		return nil, nil
	}

	config, ok := data.(map[string]any)
	if !ok {
		return nil, fmt.Errorf("invalid volume configuration: %v", data)
	}

	remoteConfig := map[string]any{"type": config["type"]}

	// Explicit params: validate the shape instead of panicking on a failed
	// type assertion when the user supplied something other than a map.
	if raw, present := config["params"]; present {
		params, isMap := raw.(map[string]any)
		if !isMap {
			return nil, fmt.Errorf("invalid volume remote params: %v", raw)
		}
		remoteConfig["params"] = params
		return remoteConfig, nil
	}

	// Flat form: everything except "type" becomes a parameter.
	params := make(map[string]any, len(config))
	for k, v := range config {
		if k != "type" {
			params[k] = v
		}
	}
	remoteConfig["params"] = params

	return remoteConfig, nil
}
// TransformNetwork transforms the network configuration: the "ports" list is
// replaced by a "port_map" list whose entries have "protocol", "host_port" and
// "container_port" keys.
//
// A port entry may be an int (same port on both sides), a string in the form
// "port", "host:container" or "protocol:host:container", or a map with
// "host_port", "container_port" and "protocol" keys. The protocol defaults to
// "tcp". Values that cannot be interpreted are left at 0.
func TransformNetwork(root *map[string]interface{}, data any, path tree.Path) (any, error) {
	if data == nil {
		return nil, nil
	}
	config, ok := data.(map[string]any)
	if !ok {
		return nil, fmt.Errorf("invalid network configuration: %v", data)
	}

	// atoi parses a port number, deliberately ignoring the parse error.
	atoi := func(s string) int {
		n, _ := strconv.Atoi(s)
		return n
	}
	// portNumber reads a port given as int or string; other types yield 0.
	portNumber := func(raw any) int {
		switch p := raw.(type) {
		case int:
			return p
		case string:
			return atoi(p)
		default:
			return 0
		}
	}

	ports, _ := transform.ToAnySlice(config["ports"])
	portMap := []map[string]any{}
	for _, port := range ports {
		protocol, host, container := "tcp", 0, 0

		switch spec := port.(type) {
		case int:
			host, container = spec, spec
		case string:
			fields := strings.Split(spec, ":")
			switch len(fields) {
			case 1, 2:
				host, container = atoi(fields[0]), atoi(fields[len(fields)-1])
			case 3:
				protocol = fields[0]
				host, container = atoi(fields[1]), atoi(fields[2])
			}
		case map[string]any:
			host = portNumber(spec["host_port"])
			container = portNumber(spec["container_port"])
			if name, isString := spec["protocol"].(string); isString {
				protocol = name
			}
		}

		portMap = append(portMap, map[string]any{
			"protocol":       protocol,
			"host_port":      host,
			"container_port": container,
		})
	}

	config["port_map"] = portMap
	delete(config, "ports")
	return config, nil
}
// TransformLibrary transforms the library configuration to a map.
// The library configuration can be a string in the format "name" or
// "name:version" (the version is empty when omitted) or a map, which is
// returned unchanged. Nil data is passed through as nil.
func TransformLibrary(root *map[string]interface{}, data any, path tree.Path) (any, error) {
	switch lib := data.(type) {
	case nil:
		return nil, nil
	case map[string]any:
		return lib, nil
	case string:
		fields := strings.Split(lib, ":")
		switch len(fields) {
		case 1:
			return map[string]any{"name": fields[0], "version": ""}, nil
		case 2:
			return map[string]any{"name": fields[0], "version": fields[1]}, nil
		default:
			return nil, fmt.Errorf("invalid library configuration: %v", data)
		}
	default:
		return nil, fmt.Errorf("invalid library configuration: %v", data)
	}
}
package nunet
import (
"fmt"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/tree"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/validate"
)
// NewNuNetValidator creates a new validator for the NuNet configuration.
// The root is checked by ValidateSpec; top-level jobs and nested children are
// checked by ValidateJob.
func NewNuNetValidator() validate.Validator {
	rules := map[tree.Path]validate.ValidatorFunc{
		"":                    ValidateSpec,
		"jobs.[]":             ValidateJob,
		"jobs.**.children.[]": ValidateJob,
	}
	return validate.NewValidator(rules)
}
// ValidateSpec checks the root configuration for consistency.
// The root must be a map whose "jobs" entry is a non-empty list. A "jobs"
// entry of any other shape is reported as an error rather than causing a
// panic.
func ValidateSpec(root *map[string]any, data any, path tree.Path) error {
	spec, ok := data.(map[string]any)
	if !ok {
		return fmt.Errorf("invalid spec configuration: %v", data)
	}

	// Check if the jobs list is present and not empty.
	rawJobs := spec["jobs"]
	if rawJobs == nil {
		return fmt.Errorf("jobs list is required")
	}
	jobs, ok := rawJobs.([]any)
	if !ok {
		return fmt.Errorf("invalid jobs configuration: %v", rawJobs)
	}
	if len(jobs) == 0 {
		return fmt.Errorf("jobs list is required")
	}

	return nil
}
// ValidateJob checks the job configuration.
// A job must be a map and must have either a non-empty "children" list or an
// "execution" entry. A "children" entry that is not a list is reported as an
// error rather than causing a panic.
func ValidateJob(root *map[string]any, data any, path tree.Path) error {
	job, ok := data.(map[string]any)
	if !ok {
		return fmt.Errorf("invalid job configuration: %v", data)
	}

	hasChildren := false
	if rawChildren := job["children"]; rawChildren != nil {
		children, ok := rawChildren.([]any)
		if !ok {
			return fmt.Errorf("invalid children configuration: %v", rawChildren)
		}
		hasChildren = len(children) > 0
	}

	// Check if the job has either children or an execution.
	if !hasChildren && job["execution"] == nil {
		return fmt.Errorf("job must have either children or an execution")
	}

	return nil
}
package parser
import (
"fmt"
"gitlab.com/nunet/device-management-service/dms/jobs"
)
// Parse decodes data into a jobs.JobSpec using the parser registered for
// specType. It returns an error when no parser is registered for that type,
// and otherwise returns whatever the selected parser returns.
func Parse(specType SpecType, data []byte) (jobs.JobSpec, error) {
	parser, exists := registry.GetParser(specType)
	if !exists {
		return jobs.JobSpec{}, fmt.Errorf("parser for spec type %s not found", specType)
	}
	return parser.Parse(data)
}
package parser
import (
"encoding/json"
"fmt"
"github.com/mitchellh/mapstructure"
yaml "gopkg.in/yaml.v3"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/transform"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/validate"
)
// SpecType names a job specification format; it is the key under which
// parsers are registered.
type SpecType string
// Known specification formats.
// NOTE(review): only specTypeNuNet has a parser registered in this file's
// init; nomad and k8s appear to be placeholders — confirm.
const (
specTypeNuNet SpecType = "nunet"
specTypeNomad SpecType = "nomad"
specTypeK8s SpecType = "k8s"
)
// DefaultTagName is the struct tag that ParserImpl.Parse passes to the
// mapstructure decoder as TagName.
const DefaultTagName = "json"
// Parser turns raw specification bytes into a value of type T.
type Parser[T any] interface {
// Parse decodes data into a T or returns an error.
Parse(data []byte) (T, error)
}
// ParserImpl is the Parser implementation built from a transformer and a
// validator.
type ParserImpl[T any] struct {
// validator checks the configuration after it has been transformed.
validator validate.Validator
// transformer reshapes the raw configuration before validation and decoding.
transformer transform.Transformer
}
// NewParser returns a Parser for T that uses the given transformer and
// validator.
func NewParser[T any](transformer transform.Transformer, validator validate.Validator) Parser[T] {
	var impl ParserImpl[T]
	impl.transformer = transformer
	impl.validator = validator
	return impl
}
// Parse decodes a specification into a value of type T.
//
// Steps:
//  1. Unmarshal data into a generic map, trying YAML first and JSON second.
//  2. Run the transformer over the map.
//  3. Validate the configuration.
//  4. Decode the transformed result into T with mapstructure, using
//     DefaultTagName as the struct tag.
//
// Errors from parsing, transforming and decoding are wrapped with %w so that
// callers can inspect the cause with errors.Is / errors.As; validation errors
// are returned unchanged. On error the returned T is whatever had been decoded
// so far and should not be used.
func (p ParserImpl[T]) Parse(data []byte) (T, error) {
	var rawConfig map[string]any
	var config T

	// Try to unmarshal as YAML first, then fall back to JSON.
	if yamlErr := yaml.Unmarshal(data, &rawConfig); yamlErr != nil {
		if jsonErr := json.Unmarshal(data, &rawConfig); jsonErr != nil {
			return config, fmt.Errorf("failed to parse config: %w", jsonErr)
		}
	}

	// Apply transformers.
	transformed, err := p.transformer.Transform(&rawConfig)
	if err != nil {
		return config, fmt.Errorf("failed to transform config: %w", err)
	}

	// Validate the configuration. The validator receives rawConfig, which the
	// transformer was given by pointer.
	// NOTE(review): presumably the transformer updates rawConfig in place so
	// that this validates the transformed shape — confirm.
	if err = p.validator.Validate(&rawConfig); err != nil {
		return config, err
	}

	// Create a new mapstructure decoder.
	decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
		Result:  &config,
		TagName: DefaultTagName,
	})
	if err != nil {
		return config, fmt.Errorf("failed to create decoder: %w", err)
	}

	// Decode the transformed configuration.
	if err = decoder.Decode(transformed); err != nil {
		return config, fmt.Errorf("failed to decode config: %w", err)
	}

	return config, nil
}
package parser
import (
"sync"
)
// Registry stores parsers by specification type.
type Registry[T any] interface {
// GetParser returns the parser registered for specType and whether one exists.
GetParser(specType SpecType) (Parser[T], bool)
// RegisterParser stores p as the parser for specType.
RegisterParser(specType SpecType, p Parser[T])
}
// RegistryImpl is a Registry backed by a map and guarded by a read/write
// mutex. The zero value is ready to use: the map is created on first
// registration.
type RegistryImpl[T any] struct {
	// parsers maps a specification type to its parser.
	parsers map[SpecType]Parser[T]
	// mu guards parsers.
	mu sync.RWMutex
}

// RegisterParser stores p as the parser for specType, replacing any parser
// previously registered for that type.
func (r *RegistryImpl[T]) RegisterParser(specType SpecType, p Parser[T]) {
	r.mu.Lock()
	defer r.mu.Unlock()

	// Writing to a nil map panics, so create it lazily for zero-value
	// registries.
	if r.parsers == nil {
		r.parsers = make(map[SpecType]Parser[T])
	}
	r.parsers[specType] = p
}

// GetParser returns the parser registered for specType and whether one exists.
func (r *RegistryImpl[T]) GetParser(specType SpecType) (Parser[T], bool) {
	r.mu.RLock()
	defer r.mu.RUnlock()

	p, exists := r.parsers[specType]
	return p, exists
}
package transform
import (
"fmt"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/tree"
)
// TransformerFunc is a function that transforms a part of the configuration.
// It modifies the data to conform to the expected structure and returns the transformed data.
// It takes the root configuration, the data to transform and the current path in the tree.
type TransformerFunc func(*map[string]interface{}, interface{}, tree.Path) (any, error)
// Transformer is a configuration transformer.
type Transformer interface {
// Transform applies the transformer to the given configuration and returns
// the transformed result.
Transform(*map[string]interface{}) (interface{}, error)
}
// TransformerImpl is the implementation of the Transformer interface.
type TransformerImpl struct {
// transformers holds the transformation stages in the order they are applied;
// each stage maps a path pattern to the function run on matching paths.
transformers []map[tree.Path]TransformerFunc
}
// NewTransformer creates a transformer that applies the given stages in order.
func NewTransformer(transformers []map[tree.Path]TransformerFunc) Transformer {
	var impl TransformerImpl
	impl.transformers = transformers
	return impl
}
// Transform runs every transformation stage over the configuration, feeding
// the output of one stage into the next, and returns the normalized result of
// the last stage. The first stage that fails aborts the run.
func (t TransformerImpl) Transform(rawConfig *map[string]interface{}) (interface{}, error) {
	var current any = *rawConfig

	for _, stage := range t.transformers {
		next, err := t.transform(rawConfig, current, tree.NewPath(), stage)
		if err != nil {
			return nil, err
		}
		current = next
	}

	return Normalize(current), nil
}
// transform applies one stage of transformers to data and then recurses into
// its children. Transformers whose pattern matches the current path run first;
// afterwards map values are visited under path.Next(key) and slice elements
// under path.Next("[i]"). Values that are neither maps nor slices are returned
// as they are.
func (t TransformerImpl) transform(root *map[string]interface{}, data any, path tree.Path, transformers map[tree.Path]TransformerFunc) (interface{}, error) {
	var err error

	// Apply transformers that match the current path.
	for pattern, apply := range transformers {
		if !path.Matches(pattern) {
			continue
		}
		if data, err = apply(root, data, path); err != nil {
			return nil, err
		}
	}

	// Maps: transform every value in place.
	if children, isMap := data.(map[string]interface{}); isMap {
		for key, child := range children {
			if children[key], err = t.transform(root, child, path.Next(key), transformers); err != nil {
				return nil, err
			}
		}
		return children, nil
	}

	// Slices: transform every element of a []any copy; anything else is a leaf.
	items, sliceErr := ToAnySlice(data)
	if sliceErr != nil {
		return data, nil
	}
	for i, item := range items {
		if items[i], err = t.transform(root, item, path.Next(fmt.Sprintf("[%d]", i)), transformers); err != nil {
			return nil, err
		}
	}
	return items, nil
}
package transform
import (
"fmt"
"reflect"
"sort"
"strconv"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/tree"
)
// MapToSlice converts a map of entries to a slice and assigns each key to the
// "name" field of its entry.
//
// Entries are emitted in ascending key order so that the result — and the
// list indices derived from it — is deterministic across runs. A nil entry is
// replaced by a new map holding only the name. Entries that are not maps are
// copied to the slice unchanged. A nil input yields a nil slice.
func MapToSlice(data map[string]any) ([]any, error) {
	if data == nil {
		return nil, nil
	}

	// Map iteration order is random; sort the keys for a stable result.
	keys := make([]string, 0, len(data))
	for k := range data {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	result := make([]any, 0, len(keys))
	for _, k := range keys {
		v := data[k]
		if v == nil {
			v = map[string]any{}
		}
		if e, ok := v.(map[string]any); ok {
			// A typed nil map cannot be written to; swap in an empty one.
			if e == nil {
				e = map[string]any{}
				v = e
			}
			e["name"] = k
		}
		result = append(result, v)
	}
	return result, nil
}
// GetConfigAtPath retrieves the part of the configuration found at path.
//
// Each path part is resolved against the current node: map nodes are indexed
// by the part itself, list nodes by the number enclosed in the part (for
// example "[2]"). A missing map key yields nil. An error is returned when a
// list index is malformed or out of range, or when a part has to be resolved
// against a node that is neither a map nor a list.
func GetConfigAtPath(config map[string]interface{}, path tree.Path) (any, error) {
	current := any(config)
	for _, key := range path.Parts() {
		switch v := current.(type) {
		case map[string]any:
			current = v[key]
		case []any:
			i, err := listIndex(key, len(v))
			if err != nil {
				return nil, err
			}
			current = v[i]
		case []map[string]any:
			i, err := listIndex(key, len(v))
			if err != nil {
				return nil, err
			}
			current = v[i]
		default:
			return nil, fmt.Errorf("invalid data type: %v", current)
		}
	}
	return current, nil
}

// listIndex extracts the list index from a path part such as "[3]" and checks
// that it addresses an element of a list with the given length.
func listIndex(key string, length int) (int, error) {
	// The part needs at least the two enclosing characters; slicing a shorter
	// string below would panic.
	if len(key) < 2 {
		return 0, fmt.Errorf("invalid index: %v", key)
	}
	i, err := strconv.Atoi(key[1 : len(key)-1])
	if err != nil {
		return 0, fmt.Errorf("invalid index: %v", key)
	}
	if i < 0 || i >= length {
		return 0, fmt.Errorf("index out of range: %v", key)
	}
	return i, nil
}
// ToAnySlice copies the elements of a slice of any element type into a new
// []any. A value whose reflect kind is not Slice (nil, arrays, strings, ...)
// is rejected with an error that names its type.
func ToAnySlice(slice any) ([]any, error) {
	rv := reflect.ValueOf(slice)
	if rv.Kind() == reflect.Slice {
		out := make([]any, rv.Len())
		for idx := range out {
			out[idx] = rv.Index(idx).Interface()
		}
		return out, nil
	}
	return nil, fmt.Errorf("input is not a slice. type: %T", slice)
}
// normalizeMap recursively rebuilds m into a canonical shape:
//
//   - a map becomes a new map with the same key type and interface{} values,
//     each value normalized in turn;
//   - a slice becomes a new []interface{} whose elements are normalized and
//     then sorted by their fmt.Sprint representation;
//   - any other value is returned unchanged.
//
// Map entries whose value is nil are kept. reflect.ValueOf(nil) is the zero
// Value, and passing the zero Value to SetMapIndex deletes the key, so such
// entries are stored through an explicit nil interface value instead.
func normalizeMap(m interface{}) interface{} {
	v := reflect.ValueOf(m)
	switch v.Kind() {
	case reflect.Map:
		// Create a new map to hold normalized values
		elemType := reflect.TypeOf((*interface{})(nil)).Elem()
		newMap := reflect.MakeMap(reflect.MapOf(v.Type().Key(), elemType))
		for _, key := range v.MapKeys() {
			newValue := normalizeMap(v.MapIndex(key).Interface())
			elem := reflect.ValueOf(newValue)
			if !elem.IsValid() {
				// nil value: store a nil interface rather than deleting the key.
				elem = reflect.Zero(elemType)
			}
			newMap.SetMapIndex(key, elem)
		}
		return newMap.Interface()
	case reflect.Slice:
		// Create a new []interface{} slice to hold normalized values
		newSlice := make([]interface{}, v.Len())
		for i := 0; i < v.Len(); i++ {
			newSlice[i] = normalizeMap(v.Index(i).Interface())
		}
		// Order elements by their printed form so that element order is irrelevant.
		sort.Slice(newSlice, func(i, j int) bool {
			return fmt.Sprint(newSlice[i]) < fmt.Sprint(newSlice[j])
		})
		return newSlice
	default:
		// For other types, return as is
		return m
	}
}
// Normalize is the exported entry point for normalization: it hands m to
// normalizeMap and returns that function's result unchanged.
func Normalize(m any) interface{} {
	normalized := normalizeMap(m)
	return normalized
}
package tree
import (
"strings"
)
const (
	// configPathSeparator separates the parts of a path.
	configPathSeparator = "."
	// configPathMatchAny is the pattern part that matches any single path part.
	configPathMatchAny = "*"
	// configPathMatchAnyMultiple is the pattern part that matches a run of path parts.
	configPathMatchAnyMultiple = "**"
	// configPathList is the pattern part that matches a list element such as "[0]".
	configPathList = "[]"
)

// Path is a custom type for representing paths in the configuration.
// Its parts are joined with configPathSeparator, e.g. "jobs.[0].name".
type Path string

// NewPath builds a path by joining the given parts with the separator.
// Called without arguments it returns the empty path.
func NewPath(path ...string) Path {
	return Path(strings.Join(path, configPathSeparator))
}

// Parts returns the parts of the path. Because it relies on strings.Split,
// the empty path yields a single empty part rather than no parts.
func (p Path) Parts() []string {
	return strings.Split(string(p), configPathSeparator)
}

// Parent returns the path without its last part, or the empty path when
// there is only one part.
func (p Path) Parent() Path {
	parts := p.Parts()
	if len(parts) > 1 {
		return Path(strings.Join(parts[:len(parts)-1], configPathSeparator))
	}
	return ""
}

// Next returns the path extended by one part. An empty part leaves the path
// unchanged, and extending the empty path yields just the part.
func (p Path) Next(path string) Path {
	if path == "" {
		return p
	}
	if p == "" {
		return Path(path)
	}
	return Path(string(p) + configPathSeparator + path)
}

// Last returns the last part of the path.
func (p Path) Last() string {
	parts := p.Parts()
	if len(parts) > 0 {
		return parts[len(parts)-1]
	}
	return ""
}

// Matches checks if the path matches a given pattern.
// See matchParts for the pattern syntax.
func (p Path) Matches(pattern Path) bool {
	pathParts := p.Parts()
	patternParts := pattern.Parts()
	return matchParts(pathParts, patternParts)
}

// String returns the string representation of the path.
func (p Path) String() string {
	return string(p)
}

// matchParts checks if the path parts match the pattern parts.
//
// Pattern parts are compared position by position:
//   - "*" matches any single part;
//   - "[]" matches a part that starts with '[' and ends with ']';
//   - "**" as the last pattern part matches everything that remains; elsewhere
//     the rest of the pattern is tried against every suffix of the path
//     starting at the current position;
//   - anything else must be equal to the path part.
//
// A pattern longer than the path never matches, and a path that continues
// past the last pattern part only matches when that part is "**".
func matchParts(pathParts, patternParts []string) bool {
	// If the pattern is longer than the path, it can't match
	if len(pathParts) < len(patternParts) {
		return false
	}
	last := len(patternParts) - 1
	for i, part := range patternParts {
		switch part {
		case configPathMatchAnyMultiple:
			// if it is the last part of the pattern, it matches
			if i == last {
				return true
			}
			// Otherwise, try to match the rest of the pattern at every offset
			for j := i; j < len(pathParts); j++ {
				if matchParts(pathParts[j:], patternParts[i+1:]) {
					return true
				}
			}
			// Every alignment of the remaining pattern was tried above
			// (including the positional one), so nothing else can match.
			return false
		case configPathList:
			// The part must be enclosed by []. Prefix/suffix checks are used
			// instead of byte indexing so that an empty part (for example the
			// single part of the empty path) is rejected rather than panicking.
			segment := pathParts[i]
			if !strings.HasPrefix(segment, "[") || !strings.HasSuffix(segment, "]") {
				return false
			}
		default:
			// If the part doesn't match, it doesn't match
			if part != configPathMatchAny && part != pathParts[i] {
				return false
			}
		}
		// If it is the last part of the pattern and the path is longer, it doesn't match
		if i == last && i < len(pathParts)-1 {
			return false
		}
	}
	return true
}
package validate
import (
"fmt"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/tree"
)
// ValidatorFunc is a function that validates a part of the configuration.
// It takes, in order, a pointer to the root configuration, the data found at
// the current location and the path of that location in the tree.
// Returning a non-nil error aborts validation with that error.
type ValidatorFunc func(*map[string]any, any, tree.Path) error
// Validator is a configuration validator.
// Validate checks the configuration behind the given pointer and returns an
// error when it is rejected.
type Validator interface {
Validate(*map[string]any) error
}
// ValidatorImpl is the implementation of the Validator interface.
type ValidatorImpl struct {
// validators maps a path pattern to the function that is run on every
// location whose path matches that pattern.
validators map[tree.Path]ValidatorFunc
}
// NewValidator returns a Validator backed by the given pattern-to-function
// map. The map is stored as is, not copied.
func NewValidator(validators map[tree.Path]ValidatorFunc) Validator {
	impl := ValidatorImpl{validators: validators}
	return impl
}
// Validate applies the validators to the configuration, starting at the root
// (the empty path) and walking the whole tree.
//
// It returns the first error reported by a validator, or an error when
// rawConfig is nil (dereferencing it would otherwise panic).
func (v ValidatorImpl) Validate(rawConfig *map[string]any) error {
	if rawConfig == nil {
		return fmt.Errorf("configuration is nil")
	}
	data := any(*rawConfig)
	return v.validate(rawConfig, data, tree.NewPath(), v.validators)
}
// validate recursively applies the validators to the configuration.
//
// Every validator whose pattern matches the current path is run on data, then
// the function descends into the children of data: map values are addressed
// by their key and []interface{} elements as "[index]". Other values have no
// children. The first validator error stops the walk and is returned.
func (v ValidatorImpl) validate(root *map[string]interface{}, data any, path tree.Path, validators map[tree.Path]ValidatorFunc) error {
	// Validators for the current location.
	for pattern, check := range validators {
		if !path.Matches(pattern) {
			continue
		}
		if err := check(root, data, path); err != nil {
			return err
		}
	}

	// Children of a map.
	if fields, isMap := data.(map[string]interface{}); isMap {
		for key, child := range fields {
			if err := v.validate(root, child, path.Next(key), validators); err != nil {
				return err
			}
		}
		return nil
	}

	// Children of a list.
	if items, isList := data.([]interface{}); isList {
		for idx, item := range items {
			if err := v.validate(root, item, path.Next(fmt.Sprintf("[%d]", idx)), validators); err != nil {
				return err
			}
		}
	}
	return nil
}
package matching
import (
"gitlab.com/nunet/device-management-service/types"
)
// CapabilityComparator is the placeholder comparator for the Capability type.
// It ignores its arguments for now and always returns types.Error.
func CapabilityComparator(l, r interface{}, preference ...Preference) types.Comparison {
	// TODO: implement the comparison logic for the Capability type
	// after all the fields are implemented
	return types.Error // error is the default value
}
package matching
import (
"gitlab.com/nunet/device-management-service/types"
)
// CpuComparator compares two types.CPU values (left against right).
//
// It returns types.Error when either argument is not a types.CPU or when the
// architectures cannot be compared, and types.Worse when the architectures
// are not equal. Otherwise the result is the numeric comparison of
// Cores * Freq of the left CPU against that of the right CPU.
//
// This is a deliberately simple comparison, based on the assumption that more
// or an equal amount of cores and frequency is acceptable, but nothing less.
// More complex comparisons would need very specific hardware knowledge, e.g.
// ranking CPU models using benchmarking data from Tom's Hardware or some
// other source.
func CpuComparator(l, r interface{}, preference ...Preference) types.Comparison {
	left, leftOK := l.(types.CPU)
	right, rightOK := r.(types.CPU)
	if !(leftOK && rightOK) {
		return types.Error
	}

	switch LiteralComparator(left.Arch, right.Arch) {
	case types.Error:
		return types.Error
	case types.Equal:
		// Same architecture: performance decides.
		return NumericComparator(
			left.Cores*left.Freq,
			right.Cores*right.Freq,
		)
	default:
		return types.Worse
	}
}
package matching
import (
"gitlab.com/nunet/device-management-service/types"
)
// DiskComparator compares two types.Disk values (left against right) based on
// the per-field results of ReturnComplexComparison.
//
// It returns types.Error when either argument is not a types.Disk or when the
// "Type" field comparison is an error, and types.Worse when the "Type" fields
// are not equal. Otherwise the result is the comparison of the "Size" field.
//
// This is a very simple comparison, based on the assumption that more or an
// equal amount of size is acceptable, but nothing less.
func DiskComparator(l, r interface{}, preference ...Preference) types.Comparison {
	if _, ok := l.(types.Disk); !ok {
		return types.Error
	}
	if _, ok := r.(types.Disk); !ok {
		return types.Error
	}

	fields := ReturnComplexComparison(l, r)
	switch fields["Type"] {
	case types.Error:
		return types.Error
	case types.Equal:
		return fields["Size"]
	}
	return types.Worse
}
package matching
import (
"gitlab.com/nunet/device-management-service/types"
)
// ExecutionResourcesComparator compares two types.ExecutionResources values.
//
// The fields CPU, Memory, Disk and GPUs are compared individually through
// ReturnComplexComparison and then combined:
//   - any field in error (or a wrongly typed argument) -> types.Error
//   - otherwise any field worse                        -> types.Worse
//   - otherwise all fields equal                       -> types.Equal
//   - otherwise                                        -> types.Better
func ExecutionResourcesComparator(l, r interface{}, preference ...Preference) types.Comparison {
	if _, ok := l.(types.ExecutionResources); !ok {
		return types.Error
	}
	if _, ok := r.(types.ExecutionResources); !ok {
		return types.Error
	}

	byField := ReturnComplexComparison(l, r)
	fieldNames := [...]string{"CPU", "Memory", "Disk", "GPUs"}

	anyWorse, allEqual := false, true
	for _, name := range fieldNames {
		outcome := byField[name]
		if outcome == types.Error {
			// An error in any field takes precedence over everything else.
			return types.Error
		}
		if outcome == types.Worse {
			anyWorse = true
		}
		if outcome != types.Equal {
			allEqual = false
		}
	}

	if anyWorse {
		return types.Worse
	}
	if allEqual {
		return types.Equal
	}
	return types.Better // if none of the above returns, then the result is better
}
package matching
import (
"gitlab.com/nunet/device-management-service/types"
)
// ExecutorComparator compares two types.Executor values; left represents
// machine capabilities and right the required capabilities.
//
// The comparison is delegated to Compare on the ExecutorType field of both
// values. types.Error is returned when either argument is not a
// types.Executor.
func ExecutorComparator(lraw, rraw interface{}, preference ...Preference) types.Comparison {
	available, availableOK := lraw.(types.Executor)
	required, requiredOK := rraw.(types.Executor)
	if !(availableOK && requiredOK) {
		return types.Error
	}
	return Compare(available.ExecutorType, required.ExecutorType)
}
package matching
import (
"reflect"
"gitlab.com/nunet/device-management-service/types"
)
// ExecutorTypeComparator compares two types.ExecutorType values.
// It returns types.Equal when both are deeply equal and types.Error in every
// other case, including when an argument has a different type.
func ExecutorTypeComparator(l, r interface{}, preference ...Preference) types.Comparison {
	if _, ok := l.(types.ExecutorType); !ok {
		return types.Error
	}
	if _, ok := r.(types.ExecutorType); !ok {
		return types.Error
	}
	if !reflect.DeepEqual(l, r) {
		return types.Error // default answer is error
	}
	return types.Equal
}
package matching
import (
"reflect"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils"
)
// ExecutorsComparator compares two types.Executors values; left represents
// machine capabilities and right the required capabilities.
//
// Results:
//   - types.Equal  when both sides are deeply equal;
//   - types.Better when utils.IsStrictlyContained(left, right) holds, i.e. the
//     machine capabilities contain all the required capabilities;
//   - types.Worse  when utils.IsStrictlyContained(right, left) holds;
//   - types.Error  otherwise, or when an argument is not a types.Executors.
//
// Both values are converted with utils.ConvertTypedSliceToUntypedSlice, the
// same way JobTypesComparator does it. A direct assertion to []interface{}
// panics here, because the dynamic type of the arguments is types.Executors.
func ExecutorsComparator(lraw, rraw interface{}, preference ...Preference) types.Comparison {
	// validate input type
	_, lrawok := lraw.(types.Executors)
	_, rrawok := rraw.(types.Executors)
	if !lrawok || !rrawok {
		return types.Error
	}

	result := types.Error // error is the default value

	l := utils.ConvertTypedSliceToUntypedSlice(lraw)
	r := utils.ConvertTypedSliceToUntypedSlice(rraw)

	// NOTE(review): a failed shallow-type check only re-assigns the default
	// value and does not stop the comparison below -- confirm whether an early
	// return was intended.
	if !utils.IsSameShallowType(l, r) {
		result = types.Error
	}

	if reflect.DeepEqual(l, r) {
		// available capabilities are equal to required capabilities
		result = types.Equal
	} else if utils.IsStrictlyContained(l, r) {
		// machine capabilities contain all the required capabilities
		result = types.Better
	} else if utils.IsStrictlyContained(r, l) {
		// required capabilities contain all the machine capabilities:
		// "available capabilities are worse than required"
		// (the Equal case is already handled above)
		result = types.Worse
	}

	return result
}
package matching
import (
"reflect"
"gitlab.com/nunet/device-management-service/types"
)
// GPUVendorComparator compares two types.GPUVendor values.
// It returns types.Equal when both are deeply equal and types.Error in every
// other case, including when an argument has a different type.
//
// This comparison only tells whether the vendor is the same or not. There is
// no mechanism yet for externally defined user preferences, which may be
// needed because some compute may prefer one vendor over another and some may
// strictly depend on a specific vendor. Technically this will have to be
// solved on the resource matching level, with a generic mechanism.
func GPUVendorComparator(l, r interface{}, preference ...Preference) types.Comparison {
	if _, ok := l.(types.GPUVendor); !ok {
		return types.Error
	}
	if _, ok := r.(types.GPUVendor); !ok {
		return types.Error
	}
	if !reflect.DeepEqual(l, r) {
		return types.Error // default answer is error
	}
	return types.Equal
}
package matching
import (
"gitlab.com/nunet/device-management-service/types"
"golang.org/x/exp/slices"
)
// GPUsComparator compares two []types.GPU slices; left represents machine
// capabilities and right the required capabilities. The slices are not
// ordered, so every required GPU is paired with an available one.
//
// Pairing is greedy: required GPUs are processed in order, each one takes the
// best remaining available GPU (see returnBestMatch), and that GPU is then
// removed from the candidates of all required GPUs.
//
// Results:
//   - types.Error  when an argument is not a []types.GPU, when any pairing is
//     an error, or when a required GPU has no available GPU left to pair with;
//   - types.Worse  when any pairing is worse;
//   - types.Equal  when all pairings are equal (also when nothing is required);
//   - types.Better otherwise.
//
// Every required GPU is evaluated. An earlier version stopped as soon as the
// number of remaining candidates dropped below the loop index, which silently
// skipped required GPUs (for example the third of three).
func GPUsComparator(lraw, rraw interface{}, preference ...Preference) types.Comparison {
	// validate input type
	l, lrawok := lraw.([]types.GPU)
	r, rrawok := rraw.([]types.GPU)
	if !lrawok || !rrawok {
		return types.Error
	}

	// matrix[i][j] holds the comparison of available GPU j (left) against
	// required GPU i (right): one row per required GPU, one column per
	// available GPU, in the order they appear in the slices.
	matrix := make([][]types.Comparison, 0, len(r))
	for _, rGPU := range r {
		row := make([]types.Comparison, 0, len(l))
		for _, lGPU := range l {
			row = append(row, Compare(lGPU, rGPU))
		}
		matrix = append(matrix, row)
	}

	// Find the best match for each required GPU.
	finalComparison := make([]types.Comparison, 0, len(matrix))
	for i := 0; i < len(matrix); i++ {
		// An empty row yields (types.Error, -1): nothing is left to match with.
		bestMatch, index := returnBestMatch(matrix[i])
		finalComparison = append(finalComparison, bestMatch)
		// The chosen available GPU cannot serve another requirement;
		// removeIndex ignores an index that is out of range.
		matrix = removeIndex(matrix, index)
	}

	if slices.Contains(finalComparison, types.Error) {
		return types.Error
	}
	if slices.Contains(finalComparison, types.Worse) {
		return types.Worse
	}
	if SliceContainsOneValue(finalComparison, types.Equal) {
		return types.Equal
	}
	return types.Better
}
package matching
import (
"gitlab.com/nunet/device-management-service/types"
)
// GpuComparator compares two types.GPU values (left against right).
//
// Only the "VRAM" field result of ReturnComplexComparison is considered: when
// it is one of Error, Worse, Better or Equal it is returned as is. Any other
// value, or a wrongly typed argument, yields types.Error.
//
// This is a very simple comparison, based on the assumption that more or an
// equal amount of VRAM is acceptable, but nothing less. More complex
// comparisons would need very specific hardware knowledge, e.g. ranking GPU
// models using benchmarking data from Tom's Hardware or some other source.
// In the future custom preference parameters may be passed along; for now
// the rules are hardcoded.
func GpuComparator(l, r interface{}, preference ...Preference) types.Comparison {
	if _, ok := l.(types.GPU); !ok {
		return types.Error
	}
	if _, ok := r.(types.GPU); !ok {
		return types.Error
	}

	vram := ReturnComplexComparison(l, r)["VRAM"]
	known := vram == types.Error ||
		vram == types.Worse ||
		vram == types.Better ||
		vram == types.Equal
	if known {
		return vram
	}
	return types.Error // error is the default value
}
package matching
import (
"reflect"
"gitlab.com/nunet/device-management-service/types"
)
// JobTypeComparator compares two types.JobType values.
// It returns types.Equal when both are deeply equal and types.Error in every
// other case, including when an argument has a different type.
func JobTypeComparator(l, r interface{}, preference ...Preference) types.Comparison {
	if _, ok := l.(types.JobType); !ok {
		return types.Error
	}
	if _, ok := r.(types.JobType); !ok {
		return types.Error
	}
	if !reflect.DeepEqual(l, r) {
		return types.Error // default answer is error
	}
	return types.Equal
}
package matching
import (
"reflect"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils"
)
// JobTypesComparator compares two types.JobTypes values; left represents
// machine capabilities and right the required capabilities.
//
// Results:
//   - types.Equal  when both sides are deeply equal;
//   - types.Better when utils.IsStrictlyContained(left, right) holds, i.e. the
//     machine capabilities contain all the required capabilities;
//   - types.Worse  when utils.IsStrictlyContained(right, left) holds
//     ("available capabilities are worse than required");
//   - types.Error  otherwise, or when an argument is not a types.JobTypes.
//
// TODO: this comparator does not take into account options when several job
// types are available and several job types are required in the same data
// structure; this is why the test fails.
func JobTypesComparator(lraw, rraw interface{}, preference ...Preference) types.Comparison {
	if _, ok := lraw.(types.JobTypes); !ok {
		return types.Error
	}
	if _, ok := rraw.(types.JobTypes); !ok {
		return types.Error
	}

	// The arguments are typed slices; work on untyped copies.
	available := utils.ConvertTypedSliceToUntypedSlice(lraw)
	required := utils.ConvertTypedSliceToUntypedSlice(rraw)

	// NOTE(review): the outcome of the shallow-type check has no influence on
	// the result (error is the default value anyway) -- confirm whether an
	// early return was intended.
	_ = utils.IsSameShallowType(available, required)

	switch {
	case reflect.DeepEqual(available, required):
		return types.Equal
	case utils.IsStrictlyContained(available, required):
		return types.Better
	case utils.IsStrictlyContained(required, available):
		// the Equal case is already handled above
		return types.Worse
	}
	return types.Error // error is the default value
}
package matching
import (
"gitlab.com/nunet/device-management-service/types"
"golang.org/x/exp/slices"
)
// LibrariesComparator compares two []types.Library slices (possibly of
// different lengths); left represents machine capabilities and right the
// required capabilities.
//
// Pairing is greedy: required libraries are processed in order, each one
// takes the best remaining available library (see returnBestMatch), and that
// library is then removed from the candidates of all required libraries.
//
// Results:
//   - types.Error  when an argument is not a []types.Library, when any pairing
//     is an error, or when a required library has no available library left
//     to pair with;
//   - types.Worse  when any pairing is worse;
//   - types.Equal  when all pairings are equal (also when nothing is required);
//   - types.Better otherwise.
//
// Every required library is evaluated. An earlier version stopped as soon as
// the number of remaining candidates dropped below the loop index, which
// silently skipped required libraries (for example the third of three).
func LibrariesComparator(lraw, rraw interface{}, preference ...Preference) types.Comparison {
	// validate input type
	l, lrawok := lraw.([]types.Library)
	r, rrawok := rraw.([]types.Library)
	if !lrawok || !rrawok {
		return types.Error
	}

	// matrix[i][j] holds the comparison of available library j (left) against
	// required library i (right): one row per required library, one column
	// per available library, in the order they appear in the slices.
	matrix := make([][]types.Comparison, 0, len(r))
	for _, rLibrary := range r {
		row := make([]types.Comparison, 0, len(l))
		for _, lLibrary := range l {
			row = append(row, Compare(lLibrary, rLibrary))
		}
		matrix = append(matrix, row)
	}

	// Find the best match for each required library.
	finalComparison := make([]types.Comparison, 0, len(matrix))
	for i := 0; i < len(matrix); i++ {
		// An empty row yields (types.Error, -1): nothing is left to match with.
		bestMatch, index := returnBestMatch(matrix[i])
		finalComparison = append(finalComparison, bestMatch)
		// The chosen available library cannot serve another requirement;
		// removeIndex ignores an index that is out of range.
		matrix = removeIndex(matrix, index)
	}

	if slices.Contains(finalComparison, types.Error) {
		return types.Error
	}
	if slices.Contains(finalComparison, types.Worse) {
		return types.Worse
	}
	if SliceContainsOneValue(finalComparison, types.Equal) {
		return types.Equal
	}
	return types.Better
}
package matching
import (
"gitlab.com/nunet/device-management-service/types"
"github.com/hashicorp/go-version"
)
// LibraryComparator compares a single available library (left, machine
// capabilities) against a single required library (right).
//
// The right side's Constraint and Version are joined with a space into a
// version constraint, which is checked against the left side's Version:
//   - types.Error  when an argument is not a types.Library, the left version
//     or the right constraint cannot be parsed, or the names differ;
//   - types.Equal  when the constraint is satisfied and the operator is "=";
//   - types.Better when the constraint is satisfied with any other operator;
//   - types.Worse  when the constraint is not satisfied.
func LibraryComparator(lraw, rraw interface{}, preference ...Preference) types.Comparison {
	available, availableOK := lraw.(types.Library)
	required, requiredOK := rraw.(types.Library)
	if !(availableOK && requiredOK) {
		return types.Error
	}

	// The version of the available library must be valid.
	have, err := version.NewVersion(available.Version)
	if err != nil {
		return types.Error
	}

	// The requirement must form a valid constraint.
	want, err := version.NewConstraint(required.Constraint + " " + required.Version)
	if err != nil {
		return types.Error
	}

	// Libraries with different names cannot be compared.
	if available.Name != required.Name {
		return types.Error
	}

	satisfied := want.Check(have)
	switch {
	case !satisfied:
		return types.Worse
	case required.Constraint == "=":
		return types.Equal
	default:
		return types.Better
	}
}
package matching
import (
"gitlab.com/nunet/device-management-service/types"
)
// LiteralComparator compares literal (string) values; left represents machine
// capabilities and right the required capabilities. Literals can only be
// equal or not: the result is types.Equal for identical strings and
// types.Error otherwise, including when an argument is not a string.
func LiteralComparator(l, r interface{}, preference ...Preference) types.Comparison {
	left, leftOK := l.(string)
	right, rightOK := r.(string)
	if !(leftOK && rightOK) {
		return types.Error
	}
	if left != right {
		return types.Error // error is the default value
	}
	return types.Equal
}
package matching
import (
// "reflect"
"golang.org/x/exp/slices"
"gitlab.com/nunet/device-management-service/types"
)
// LocalitiesComparator is a simplified comparator for slices of
// types.Locality (there is no dedicated Localities type); left represents
// machine capabilities and right the required capabilities.
//
// Each required locality starts out as types.Error and is replaced by the
// result of Compare for every available locality of the same Kind, so the
// last available locality of that Kind decides. The per-requirement results
// are then combined:
//   - any error (or a wrongly typed argument) -> types.Error
//   - otherwise any worse                     -> types.Worse
//   - otherwise all equal                     -> types.Equal (also when
//     nothing is required)
//   - otherwise                               -> types.Better
func LocalitiesComparator(lraw interface{}, rraw interface{}, preference ...Preference) types.Comparison {
	available, availableOK := lraw.([]types.Locality)
	required, requiredOK := rraw.([]types.Locality)
	if !(availableOK && requiredOK) {
		return types.Error
	}

	var outcomes []types.Comparison
	for _, want := range required {
		// Stays an error when no available locality has the required Kind,
		// so there is a result even if the slice dimensions do not match.
		outcome := types.Error
		for _, have := range available {
			if have.Kind == want.Kind {
				outcome = Compare(have, want)
			}
		}
		outcomes = append(outcomes, outcome)
	}

	switch {
	case slices.Contains(outcomes, types.Error):
		return types.Error
	case slices.Contains(outcomes, types.Worse):
		return types.Worse
	case SliceContainsOneValue(outcomes, types.Equal):
		return types.Equal
	}
	return types.Better
}
package matching
import (
"gitlab.com/nunet/device-management-service/types"
)
// LocalityComparator is a simplified (placeholder) comparator for
// types.Locality; left represents machine capabilities and right the required
// capabilities.
//
// Localities of different Kind, or wrongly typed arguments, yield
// types.Error. For the same Kind the result is types.Equal when the names are
// identical and types.Worse otherwise.
func LocalityComparator(lraw interface{}, rraw interface{}, preference ...Preference) types.Comparison {
	available, availableOK := lraw.(types.Locality)
	required, requiredOK := rraw.(types.Locality)
	if !(availableOK && requiredOK) {
		return types.Error
	}

	if available.Kind != required.Kind {
		return types.Error
	}
	if available.Name != required.Name {
		return types.Worse
	}
	return types.Equal
}
package matching
import (
"gitlab.com/nunet/device-management-service/types"
)
// MemoryComparator compares two types.Memory values (left against right)
// based on the per-field results of ReturnComplexComparison.
//
// It returns types.Error when either argument is not a types.Memory. When the
// "Size" comparison is an error or worse, that result is returned; otherwise
// the result is the comparison of the "Speed" field.
//
// This is a very simple comparison, based on the assumption that more or an
// equal amount of size and speed is acceptable, but nothing less.
func MemoryComparator(l, r interface{}, preference ...Preference) types.Comparison {
	if _, ok := l.(types.Memory); !ok {
		return types.Error
	}
	if _, ok := r.(types.Memory); !ok {
		return types.Error
	}

	fields := ReturnComplexComparison(l, r)
	if size := fields["Size"]; size == types.Error || size == types.Worse {
		return size
	}
	return fields["Speed"]
}
package matching
import (
"reflect"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils/validate"
)
// NumericComparator compares two numeric values; left represents machine
// capabilities and right the required capabilities.
//
// Both values are converted with validate.ConvertNumericToFloat64. If either
// conversion fails the result is types.Error. Returning immediately matters:
// carrying on would compare the values of a failed conversion and could
// report them as Equal, Worse or Better.
//
// Results for numeric input:
//   - types.Equal  when both values are equal;
//   - types.Worse  when the available value is less than the required one;
//   - types.Better when the available value is greater than the required one;
//   - types.Error  when none of the above holds (e.g. NaN).
func NumericComparator(lraw, rraw interface{}, preference ...Preference) types.Comparison {
	// validate input type
	l, lnumeric := validate.ConvertNumericToFloat64(lraw)
	r, rnumeric := validate.ConvertNumericToFloat64(rraw)
	if !lnumeric || !rnumeric {
		return types.Error
	}

	switch {
	case reflect.DeepEqual(l, r):
		// available capabilities are equal to required capabilities
		return types.Equal
	case l < r:
		// declared machine capability is less than the job's required capability
		return types.Worse
	case l > r:
		// declared machine capability is more than the job's required capability
		return types.Better
	}
	return types.Error // error is the default value
}
package matching
import (
"reflect"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils/validate"
)
// Comparator is the signature shared by all comparator functions of this
// package: it compares a left value against a right value, optionally taking
// preferences, and reports the outcome as a types.Comparison.
// Comparators are used for simple types (i.e. which are not nested in the
// map[string]interface{} structure).
type Comparator func(l, r interface{}, preference ...Preference) types.Comparison
// Compare is the generic entry point for comparing two values: it selects the
// comparator registered for the type of l and returns its verdict.
//
// Selection rules:
//   - a nil l cannot be classified and yields types.Error;
//   - a numeric l (per validate.ConvertNumericToFloat64) uses the "Numeric"
//     comparator;
//   - []types.GPU, []types.Library and []types.Locality use "GPUs",
//     "Libraries" and "Localities" respectively, as unnamed slices have no
//     type name of their own;
//   - anything else is looked up by the name of its type.
//
// types.Error is returned when no comparator is registered for the type.
// The preferences are passed on to the selected comparator.
func Compare(l, r interface{}, preference ...Preference) types.Comparison {
	// reflect.TypeOf(nil) is nil, so a nil value must be rejected up front.
	if l == nil {
		return types.Error
	}

	// TODO: it would be better to pass a pointer as this is a global structure
	comparatorMap := initComparatorMap()

	// check if the type is numeric
	if _, numeric := validate.ConvertNumericToFloat64(l); numeric {
		comparator := comparatorMap["Numeric"]
		if comparator == nil {
			return types.Error
		}
		return comparator(l, r, preference...)
	}

	lType := reflect.TypeOf(l)
	typeName := lType.Name()

	// An unnamed slice of custom types has an empty type name; map the
	// element type to the matching comparator explicitly.
	if lType.Kind() == reflect.Slice {
		switch l.(type) {
		case []types.GPU:
			typeName = "GPUs"
		case []types.Library:
			typeName = "Libraries"
		case []types.Locality:
			typeName = "Localities"
		}
	}

	// select the comparator based on type
	comparator := comparatorMap[typeName]
	if comparator == nil {
		return types.Error
	}
	return comparator(l, r, preference...)
}
// ComparatorMap maps a type name (or the special key "Numeric") to the
// comparator that handles values of that type.
type ComparatorMap map[string]Comparator
// initComparatorMap builds a fresh registry holding all defined comparators,
// keyed by the name under which Compare looks them up, so that it can be
// passed around and searched / referenced.
func initComparatorMap() ComparatorMap {
	return ComparatorMap{
		"Numeric":            NumericComparator,
		"Capability":         CapabilityComparator,
		"string":             LiteralComparator,
		"Executors":          ExecutorsComparator,
		"ExecutorType":       ExecutorTypeComparator,
		"JobType":            JobTypeComparator,
		"JobTypes":           JobTypesComparator,
		"GPUVendor":          GPUVendorComparator,
		"GPUs":               GPUsComparator,
		"GPU":                GpuComparator,
		"Executor":           ExecutorComparator,
		"ExecutionResources": ExecutionResourcesComparator,
		"CPU":                CpuComparator,
		"Memory":             MemoryComparator,
		"Disk":               DiskComparator,
		"Library":            LibraryComparator,
		"Libraries":          LibrariesComparator,
		"Locality":           LocalityComparator,
		"Localities":         LocalitiesComparator,
	}
}
// Preference describes a comparison preference that can be handed to Compare
// and to the comparators through their variadic parameter.
// NOTE(review): none of the comparators in this part of the package read
// these fields yet -- confirm the intended semantics with the callers.
type Preference struct {
// TypeName presumably names the type the preference applies to -- TODO confirm.
TypeName string
// Strength is one of the PreferenceString values (Hard or Soft).
Strength PreferenceString
// DefaultComparatorOverride presumably replaces the default comparator for
// TypeName -- TODO confirm.
DefaultComparatorOverride Comparator
}
// PreferenceString is the string type used for Preference.Strength.
type PreferenceString string
// Known preference strengths.
// NOTE(review): how Hard and Soft influence matching is not visible here -- confirm.
const (
Hard PreferenceString = "Hard"
Soft PreferenceString = "Soft"
)
package matching
import (
"reflect"
"gitlab.com/nunet/device-management-service/types"
)
// ReturnComplexComparison compares two complex (struct) values field by
// field. Such values have nested fields that need to be considered together
// before a final comparison for the whole type can be made, so this helper
// returns one result per field and leaves the combination to the calling
// comparator.
//
// The result maps every field name of l to Compare(l.field, r.field); fields
// are paired by position. Both arguments are accessed through reflection as
// structs, so callers are expected to pass two values of the same struct type.
func ReturnComplexComparison(l, r interface{}) types.ComplexComparison {
	left, right := reflect.ValueOf(l), reflect.ValueOf(r)
	result := make(types.ComplexComparison)

	fieldCount := left.NumField()
	for idx := 0; idx < fieldCount; idx++ {
		name := left.Type().Field(idx).Name
		result[name] = Compare(left.Field(idx).Interface(), right.Field(idx).Interface())
	}
	return result
}
// removeIndex removes the element at the specified index from each sub-slice
// in the given slice and returns the slice.
// If the index is out of bounds for a sub-slice, that sub-slice is left
// unmodified. The removal happens in place: the remaining elements are
// shifted down inside the sub-slice's existing storage.
func removeIndex(slice [][]types.Comparison, index int) [][]types.Comparison {
	for row, values := range slice {
		inBounds := index >= 0 && index < len(values)
		if !inBounds {
			continue
		}
		// Shift the tail one position to the left and drop the last slot.
		copy(values[index:], values[index+1:])
		slice[row] = values[:len(values)-1]
	}
	return slice
}
// returnBestMatch picks the most desirable comparison result from dimension
// and returns it together with its index.
//
// Results are ranked Equal (the most efficient match), then Better, then
// Worse, then Error; within a rank the first occurrence wins. When dimension
// holds none of these values (in particular when it is empty) the function
// returns types.Error and index -1.
//
// There may be a smarter matrix-based algorithm for this, but a plain ranked
// scan is used for now.
func returnBestMatch(dimension []types.Comparison) (types.Comparison, int) {
	ranking := [...]types.Comparison{types.Equal, types.Better, types.Worse, types.Error}
	for _, wanted := range ranking {
		for idx, candidate := range dimension {
			if candidate == wanted {
				return candidate, idx
			}
		}
	}
	return types.Error, -1
}
// SliceContainsOneValue reports whether every element of slice is equal to
// value, i.e. whether the slice contains that one value only.
// An empty slice yields true.
func SliceContainsOneValue(slice []types.Comparison, value types.Comparison) bool {
	for idx := range slice {
		if slice[idx] != value {
			return false
		}
	}
	return true
}
package dms
import (
"gitlab.com/nunet/device-management-service/dms/resources"
"gorm.io/gorm"
)
// SanityCheck is meant to run basic consistency checks before the DMS starts.
//
// According to the earlier description of this function, it used to work in
// the following sequence:
// It checked for services marked as running in the database, stopped and
// removed them, and updated their status to 'finished with errors'.
// It then recalculated the free resources and updated the database.
//
// The service handling was removed because dependencies such as the docker
// package have been replaced with executor/docker. What remains is the call
// to resources.CalcFreeResAndUpdateDB; the gormDB argument is currently unused.
func SanityCheck(gormDB *gorm.DB) {
// TODO: sanity check of DMS last exit and correction of invalid states
resources.CalcFreeResAndUpdateDB()
}
package docker
import (
"bufio"
"context"
"encoding/json"
"fmt"
"io"
"os"
"strings"
"sync"
"time"
"github.com/docker/docker/api/types"
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/api/types/filters"
"github.com/docker/docker/api/types/network"
"github.com/docker/docker/client"
"github.com/docker/docker/pkg/jsonmessage"
"github.com/docker/docker/pkg/stdcopy"
v1 "github.com/opencontainers/image-spec/specs-go/v1"
"github.com/pkg/errors"
"go.uber.org/multierr"
)
// Client wraps the Docker client to provide high-level operations on Docker containers and networks.
type Client struct {
client *client.Client // Underlying Docker SDK client (a named field, not an embedded one).
}
// NewDockerClient builds a Client whose underlying Docker SDK client is
// configured from the environment with API version negotiation enabled.
// It returns the SDK's error unchanged when the client cannot be created.
func NewDockerClient() (*Client, error) {
	sdk, err := client.NewClientWithOpts(
		client.FromEnv,
		client.WithAPIVersionNegotiation(),
	)
	if err != nil {
		return nil, err
	}

	wrapped := &Client{client: sdk}
	return wrapped, nil
}
// IsInstalled reports whether the Docker daemon answers a ping. Any ping
// error (daemon missing, unreachable, context cancelled, ...) yields false.
func (c *Client) IsInstalled(ctx context.Context) bool {
	if _, err := c.client.Ping(ctx); err != nil {
		return false
	}
	return true
}
// CreateContainer pulls config.Image and then creates (but does not start) a
// container with the given configuration and name. It returns the ID of the
// created container, or the first error from the pull or the create call.
func (c *Client) CreateContainer(
	ctx context.Context,
	config *container.Config,
	hostConfig *container.HostConfig,
	networkingConfig *network.NetworkingConfig,
	platform *v1.Platform,
	name string,
) (string, error) {
	// The image is always pulled first; the returned digest is not needed here.
	if _, err := c.PullImage(ctx, config.Image); err != nil {
		return "", err
	}

	created, err := c.client.ContainerCreate(ctx, config, hostConfig, networkingConfig, platform, name)
	if err != nil {
		return "", err
	}
	return created.ID, nil
}
// InspectContainer returns the inspection data for the container identified
// by id, as reported by the Docker SDK's ContainerInspect call. Any error from
// the SDK is returned unchanged.
func (c *Client) InspectContainer(ctx context.Context, id string) (types.ContainerJSON, error) {
return c.client.ContainerInspect(ctx, id)
}
// FollowLogs tails the logs of the container identified by id and returns
// separate readers for its stdout and stderr streams.
//
// The container is inspected first, then its logs are requested with Follow
// enabled. A background goroutine demultiplexes the log stream into two pipes
// via stdcopy.StdCopy; when copying ends it closes the log stream, flushes the
// buffered writers and closes both pipe writers, so the returned readers
// observe EOF.
//
// An error is returned only when inspecting the container or opening the log
// stream fails. Errors that occur later, while copying, are logged as warnings
// (except context cancellation) and are not reported to the caller.
func (c *Client) FollowLogs(ctx context.Context, id string) (stdout, stderr io.Reader, err error) {
	cont, err := c.InspectContainer(ctx, id)
	if err != nil {
		return nil, nil, errors.Wrap(err, "failed to get container")
	}

	logOptions := types.ContainerLogsOptions{
		ShowStdout: true,
		ShowStderr: true,
		Follow:     true,
	}
	logsReader, err := c.client.ContainerLogs(ctx, cont.ID, logOptions)
	if err != nil {
		return nil, nil, errors.Wrap(err, "failed to get container logs")
	}

	stdoutReader, stdoutWriter := io.Pipe()
	stderrReader, stderrWriter := io.Pipe()

	go func() {
		stdoutBuffer := bufio.NewWriter(stdoutWriter)
		stderrBuffer := bufio.NewWriter(stderrWriter)
		defer func() {
			logsReader.Close()
			stdoutBuffer.Flush()
			stdoutWriter.Close()
			stderrBuffer.Flush()
			stderrWriter.Close()
		}()

		// Use a goroutine-local error: assigning to the function's named
		// result from here would race with the return below, because this
		// goroutine outlives the call.
		_, copyErr := stdcopy.StdCopy(stdoutBuffer, stderrBuffer, logsReader)
		if copyErr != nil && !errors.Is(copyErr, context.Canceled) {
			zlog.Sugar().Warnf("failed to copy container logs: %v\n", copyErr)
		}
	}()

	return stdoutReader, stderrReader, nil
}
// StartContainer starts the container identified by containerID using empty
// (default) start options. Any error from the Docker SDK is returned unchanged.
func (c *Client) StartContainer(ctx context.Context, containerID string) error {
return c.client.ContainerStart(ctx, containerID, types.ContainerStartOptions{})
}
// WaitContainer asks the Docker SDK to wait until the container identified by
// containerID is no longer running (container.WaitConditionNotRunning).
// It returns the SDK's two channels unchanged: one delivering the wait result
// and one delivering an error.
func (c *Client) WaitContainer(
ctx context.Context,
containerID string,
) (<-chan container.ContainerWaitOKBody, <-chan error) {
return c.client.ContainerWait(ctx, containerID, container.WaitConditionNotRunning)
}
// StopContainer stops the container identified by containerID, passing timeout
// to the Docker SDK's ContainerStop call. Any error from the SDK is returned
// unchanged.
func (c *Client) StopContainer(
ctx context.Context,
containerID string,
timeout time.Duration,
) error {
// The SDK takes the timeout by pointer; the address of the local copy is passed.
return c.client.ContainerStop(ctx, containerID, &timeout)
}
// RemoveContainer removes the container identified by containerID. Removal is
// always forced and the container's volumes are always removed with it: both
// Force and RemoveVolumes are set unconditionally.
func (c *Client) RemoveContainer(ctx context.Context, containerID string) error {
return c.client.ContainerRemove(
ctx,
containerID,
types.ContainerRemoveOptions{RemoveVolumes: true, Force: true},
)
}
// removeContainers lists every container (running or not) matching filterz
// and removes them concurrently, one goroutine per container. It returns the
// listing error, or the combination of all removal errors (nil if none).
func (c *Client) removeContainers(ctx context.Context, filterz filters.Args) error {
	listOpts := types.ContainerListOptions{All: true, Filters: filterz}
	found, err := c.client.ContainerList(ctx, listOpts)
	if err != nil {
		return err
	}

	// One buffered slot per container, so no removal goroutine blocks on send.
	results := make(chan error, len(found))
	var pending sync.WaitGroup
	pending.Add(len(found))
	for i := range found {
		go func(id string) {
			defer pending.Done()
			results <- c.RemoveContainer(ctx, id)
		}(found[i].ID)
	}

	// Close the channel once every removal has reported, ending the loop below.
	go func() {
		pending.Wait()
		close(results)
	}()

	var combined error
	for removeErr := range results {
		combined = multierr.Append(combined, removeErr)
	}
	return combined
}
// removeNetworks lists every network matching filterz and removes them
// concurrently, one goroutine per network. It returns the listing error, or
// the combination of all removal errors (nil if none).
func (c *Client) removeNetworks(ctx context.Context, filterz filters.Args) error {
	found, err := c.client.NetworkList(ctx, types.NetworkListOptions{Filters: filterz})
	if err != nil {
		return err
	}

	// One buffered slot per network, so no removal goroutine blocks on send.
	results := make(chan error, len(found))
	var pending sync.WaitGroup
	pending.Add(len(found))
	for i := range found {
		go func(id string) {
			defer pending.Done()
			results <- c.client.NetworkRemove(ctx, id)
		}(found[i].ID)
	}

	// Close the channel once every removal has reported, ending the loop below.
	go func() {
		pending.Wait()
		close(results)
	}()

	var combined error
	for removeErr := range results {
		combined = multierr.Append(combined, removeErr)
	}
	return combined
}
// RemoveObjectsWithLabel removes every Docker container and network carrying
// the label "label=value". Containers are removed first; networks are removed
// afterwards even if container removal reported errors. The errors of both
// steps are combined into the returned error.
func (c *Client) RemoveObjectsWithLabel(ctx context.Context, label string, value string) error {
	selector := fmt.Sprintf("%s=%s", label, value)
	filterz := filters.NewArgs(filters.Arg("label", selector))

	// Arguments are evaluated left to right, so containers go before networks.
	return multierr.Combine(
		c.removeContainers(ctx, filterz),
		c.removeNetworks(ctx, filterz),
	)
}
// GetOutputStream opens the combined stdout/stderr log stream of a running
// container.
//
// since is passed to Docker as the point from which logs are returned, and
// follow selects whether the stream stays open for new output. The call fails
// when the container cannot be inspected, is not running, or the log stream
// cannot be opened. The caller owns the returned io.ReadCloser.
func (c *Client) GetOutputStream(
	ctx context.Context,
	containerID string,
	since string,
	follow bool,
) (io.ReadCloser, error) {
	cont, err := c.InspectContainer(ctx, containerID)
	if err != nil {
		return nil, errors.Wrap(err, "failed to get container")
	}
	if running := cont.State.Running; !running {
		return nil, fmt.Errorf("cannot get logs for a container that is not running")
	}

	stream, err := c.client.ContainerLogs(ctx, containerID, types.ContainerLogsOptions{
		ShowStdout: true,
		ShowStderr: true,
		Follow:     follow,
		Since:      since,
	})
	if err != nil {
		return nil, errors.Wrap(err, "failed to get container logs")
	}
	return stream, nil
}
// FindContainer lists all containers (running or not) and returns the ID of
// the first one whose label has exactly the given value. A container without
// the label is treated as having the empty string for it. An error is
// returned when listing fails or no container matches.
func (c *Client) FindContainer(ctx context.Context, label string, value string) (string, error) {
	all, err := c.client.ContainerList(ctx, types.ContainerListOptions{All: true})
	if err != nil {
		return "", err
	}

	for i := range all {
		if all[i].Labels[label] != value {
			continue
		}
		return all[i].ID, nil
	}
	return "", fmt.Errorf("unable to find container for %s=%s", label, value)
}
// PullImage pulls imageName from its registry and returns the digest reported
// by the daemon (the text following "Digest: " in the progress stream), or an
// empty string if no digest line was seen.
//
// The raw JSON progress stream is mirrored to os.Stdout while it is decoded.
// An error is returned when the pull cannot be started, when the stream cannot
// be decoded, or when the daemon reports an error message in the stream.
func (c *Client) PullImage(ctx context.Context, imageName string) (string, error) {
	out, err := c.client.ImagePull(ctx, imageName, types.ImagePullOptions{})
	if err != nil {
		zlog.Sugar().Errorf("unable to pull image: %v", err)
		return "", err
	}
	defer out.Close()

	d := json.NewDecoder(io.TeeReader(out, os.Stdout))
	var digest string
	for {
		// Decode into a fresh message on every iteration: decoding into a
		// reused struct leaves fields that are absent from the current JSON
		// object untouched, so a stale Aux, Error or Status from an earlier
		// message would otherwise leak into later ones.
		var message jsonmessage.JSONMessage
		if err := d.Decode(&message); err != nil {
			if err == io.EOF {
				break
			}
			zlog.Sugar().Errorf("unable pull image: %v", err)
			return "", err
		}

		// Auxiliary messages carry no status or error of interest here.
		if message.Aux != nil {
			continue
		}

		if message.Error != nil {
			zlog.Sugar().Errorf("unable pull image: %v", message.Error.Message)
			return "", errors.New(message.Error.Message)
		}

		if strings.HasPrefix(message.Status, "Digest") {
			digest = strings.TrimPrefix(message.Status, "Digest: ")
		}
	}
	return digest, nil
}
package docker
import (
"context"
"fmt"
"io"
"os"
"sync/atomic"
"time"
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/api/types/mount"
"gitlab.com/nunet/device-management-service/dms/resources"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils"
)
const (
// labelExecutorName is the container label key whose value is the executor ID.
labelExecutorName = "nunet-executor"
// labelJobID is the container label key whose value is built by labelJobValue.
labelJobID = "nunet-jobID"
// labelExecutionID is the container label key whose value is built by labelExecutionValue.
labelExecutionID = "nunet-executionID"
// outputStreamCheckTickTime is the interval at which GetLogStream polls for an execution's handler.
outputStreamCheckTickTime = 100 * time.Millisecond
// outputStreamCheckTimeout is how long GetLogStream waits for a handler before reporting "not found".
outputStreamCheckTimeout = 5 * time.Second
)
// Executor manages the lifecycle of Docker containers for execution requests.
type Executor struct {
// ID identifies this executor. It is used as the value of the executor label
// and as the prefix of the job and execution label values.
ID string
handlers utils.SyncMap[string, *executionHandler] // Maps execution IDs to their handlers.
client *Client // Docker client for container management.
}
// NewExecutor creates an Executor identified by id and backed by a freshly
// created Docker client. The context argument is accepted but not used.
// It fails only when the Docker client cannot be created.
func NewExecutor(_ context.Context, id string) (*Executor, error) {
	dockerClient, err := NewDockerClient()
	if err != nil {
		return nil, err
	}

	// The handlers map is left at its zero value.
	executor := new(Executor)
	executor.ID = id
	executor.client = dockerClient
	return executor, nil
}
// IsInstalled reports whether Docker is usable by delegating to the
// underlying Client's IsInstalled, which pings the Docker daemon.
func (e *Executor) IsInstalled(ctx context.Context) bool {
return e.client.IsInstalled(ctx)
}
// Start begins the execution of a request by starting a Docker container.
//
// If a container labelled for this job/execution already exists (for example
// after a restart), it is reused. Otherwise, if a handler is already
// registered for the execution ID, an error is returned saying whether the
// execution is already started or already completed. If neither exists, a new
// container is created from the request.
//
// On success a handler is registered under request.ExecutionID and the
// container is run in a background goroutine; Start itself does not wait for
// the execution to finish.
func (e *Executor) Start(ctx context.Context, request *types.ExecutionRequest) error {
	zlog.Sugar().
		Infof("Starting execution for job %s, execution %s", request.JobID, request.ExecutionID)

	// It's possible that this is being called due to a restart. We should check if the
	// container is already running.
	containerID, err := e.FindRunningContainer(ctx, request.JobID, request.ExecutionID)
	if err != nil {
		// Unable to find a container for this execution, we will instead check for a handler, and
		// failing that will create a new container.
		if handler, ok := e.handlers.Get(request.ExecutionID); ok {
			if handler.active() {
				return fmt.Errorf("execution is already started")
			}
			return fmt.Errorf("execution is already completed")
		}

		// Create a new container for the execution.
		containerID, err = e.newDockerExecutionContainer(ctx, request)
		if err != nil {
			return fmt.Errorf("failed to create new container: %w", err)
		}
	}

	handler := &executionHandler{
		client: e.client,
		ID:     e.ID,
		// jobID must be set: the handler uses it to rebuild the execution
		// label when cleaning up, and that label embeds the job ID.
		jobID:       request.JobID,
		executionID: request.ExecutionID,
		containerID: containerID,
		resultsDir:  request.ResultsDir,
		waitCh:      make(chan bool),
		activeCh:    make(chan bool),
		running:     &atomic.Bool{},
		TTYEnabled:  true,
	}

	// register the handler for this executionID
	e.handlers.Put(request.ExecutionID, handler)

	// run the container.
	go handler.run(ctx)
	return nil
}
// Wait returns a result channel and an error channel for the execution
// identified by executionID. Both channels have a buffer of one.
//
// When no handler is registered for executionID, a "not found" error is placed
// on the error channel immediately and neither channel is closed. Otherwise
// doWait is started in a goroutine; it delivers either the result or an error
// and then closes both channels. Callers should select on both channels, which
// lets them synchronize Wait with an earlier call to Start.
func (e *Executor) Wait(
	ctx context.Context,
	executionID string,
) (<-chan *types.ExecutionResult, <-chan error) {
	resultCh := make(chan *types.ExecutionResult, 1)
	errCh := make(chan error, 1)

	handler, found := e.handlers.Get(executionID)
	if found {
		go e.doWait(ctx, resultCh, errCh, handler)
		return resultCh, errCh
	}

	errCh <- fmt.Errorf("execution (%s) not found", executionID)
	return resultCh, errCh
}
// doWait blocks until the handler signals completion on its wait channel or
// ctx is done, whichever happens first.
//
// On completion the handler's result is sent on out; a nil result is reported
// on errCh instead, since it points at a flaw in the executor logic. On
// cancellation ctx.Err() is sent on errCh. In every case both channels are
// closed before doWait returns (errCh first, then out).
func (e *Executor) doWait(
	ctx context.Context,
	out chan *types.ExecutionResult,
	errCh chan error,
	handler *executionHandler,
) {
	zlog.Sugar().Infof("executionID %s waiting for execution", handler.executionID)

	defer func() {
		close(errCh)
		close(out)
	}()

	select {
	case <-handler.waitCh:
		result := handler.result
		if result == nil {
			errCh <- fmt.Errorf("execution (%s) result is nil", handler.executionID)
			return
		}
		zlog.Sugar().
			Infof("executionID %s received results from execution", handler.executionID)
		out <- result
	case <-ctx.Done():
		// Relay the cancellation error to the caller.
		errCh <- ctx.Err()
	}
}
// Cancel stops the container of the execution identified by executionID by
// delegating to its handler's kill method. It returns an error when no handler
// is registered for that ID, or whatever error the stop attempt produces.
func (e *Executor) Cancel(ctx context.Context, executionID string) error {
	if handler, found := e.handlers.Get(executionID); found {
		return handler.kill(ctx)
	}
	return fmt.Errorf("failed to cancel execution (%s). execution not found", executionID)
}
// GetLogStream provides a stream of output logs for a specific execution.
//
// The request's Tail field selects whether only logs produced from now on are
// returned (otherwise logs are returned from the beginning), and its Follow
// field selects whether the stream stays open for new logs.
//
// The execution may already be recorded as running before its handler has been
// registered, so when no handler is found immediately this method polls for it
// every outputStreamCheckTickTime for up to outputStreamCheckTimeout.
//
// It returns an error if no handler appears within the timeout, ctx.Err() if
// ctx is done while polling, or any error from opening the output stream.
func (e *Executor) GetLogStream(
	ctx context.Context,
	request types.LogStreamRequest,
) (io.ReadCloser, error) {
	// Fast path: the handler is usually registered already.
	if handler, found := e.handlers.Get(request.ExecutionID); found {
		return handler.outputStream(ctx, request)
	}

	// Poll inline rather than in a helper goroutine: there is then no channel
	// hand-off that could leave a goroutine (or this call) blocked forever
	// when the handler shows up at the same moment the timeout fires.
	ticker := time.NewTicker(outputStreamCheckTickTime)
	defer ticker.Stop()
	timeout := time.NewTimer(outputStreamCheckTimeout)
	defer timeout.Stop()

	for {
		select {
		case <-ticker.C:
			if handler, found := e.handlers.Get(request.ExecutionID); found {
				return handler.outputStream(ctx, request)
			}
		case <-timeout.C:
			return nil, fmt.Errorf("execution (%s) not found", request.ExecutionID)
		case <-ctx.Done():
			return nil, ctx.Err()
		}
	}
}
// Run initiates and waits for the completion of an execution in one call.
// It is a convenience wrapper around Start followed by Wait.
//
// It returns the execution result, or an error if starting fails, waiting
// fails, or ctx is done first.
func (e *Executor) Run(
	ctx context.Context,
	request *types.ExecutionRequest,
) (*types.ExecutionResult, error) {
	if err := e.Start(ctx, request); err != nil {
		return nil, err
	}

	resCh, errCh := e.Wait(ctx, request.ExecutionID)

	// The waiter delivers exactly one value (a result or an error) and then
	// closes both channels. A receive can therefore succeed on a closed, empty
	// channel; the ok flag distinguishes that from a real value, in which case
	// the outcome is taken from the other channel.
	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	case out, ok := <-resCh:
		if ok {
			return out, nil
		}
		if err, ok := <-errCh; ok {
			return nil, err
		}
	case err, ok := <-errCh:
		if ok {
			return nil, err
		}
		if out, ok := <-resCh; ok {
			return out, nil
		}
	}

	// Both channels were closed without carrying a value.
	return nil, fmt.Errorf(
		"execution (%s) finished without a result or an error",
		request.ExecutionID,
	)
}
// Cleanup removes every Docker container and network labelled with this
// executor's ID. Containers are removed together with their volumes. Any
// failure is wrapped and returned; success is logged.
func (e *Executor) Cleanup(ctx context.Context) error {
	if err := e.client.RemoveObjectsWithLabel(ctx, labelExecutorName, e.ID); err != nil {
		return fmt.Errorf("failed to remove containers: %w", err)
	}

	zlog.Info("Cleaned up all Docker resources")
	return nil
}
// newDockerExecutionContainer is an internal method called by Start to set up a new Docker container
// for the job execution. It configures the container based on the provided ExecutionRequest.
// This includes decoding engine specifications, setting up environment variables, mounts and resource
// constraints. It then creates the container but does not start it.
// The method returns the ID of the created container, or an error if any part of the setup fails.
func (e *Executor) newDockerExecutionContainer(
	ctx context.Context,
	params *types.ExecutionRequest,
) (string, error) {
	dockerArgs, err := DecodeSpec(params.EngineSpec)
	if err != nil {
		return "", fmt.Errorf("failed to decode docker engine spec: %w", err)
	}

	// TODO: Move this GPU selection block to the allocator in future
	// Select the GPU with the highest available free VRAM and choose the GPU vendor for container's host config
	gpus, err := resources.GetGPUInfo()
	if err != nil {
		// Previously this error was silently overwritten by the next call.
		// It is surfaced here as a warning rather than a failure because it is
		// not visible from this code whether a partial GPU list is still valid.
		zlog.Sugar().Warnf("failed to get GPU info: %v", err)
	}
	maxFreeVRAMGpu, err := resources.GPUList(gpus).GetGPUWithHighestFreeVRAM()
	if err != nil {
		return "", fmt.Errorf("failed to get GPU with highest free VRAM: %w", err)
	}

	// Essential for multi-vendor GPU nodes. For example,
	// if a machine has an 8 GB NVIDIA and a 16 GB Intel GPU, the latter should be used first.
	// Even for machines with a single GPU, this is important as integrated GPUs would also be commonly detected.
	chosenGPUVendor := maxFreeVRAMGpu.Vendor

	containerConfig := container.Config{
		Image:      dockerArgs.Image,
		Tty:        true, // Needs to be true for applications such as Jupyter or Gradio to work correctly. See issue #459 for details.
		Env:        dockerArgs.Environment,
		Entrypoint: dockerArgs.Entrypoint,
		Cmd:        dockerArgs.Cmd,
		Labels:     e.containerLabels(params.JobID, params.ExecutionID),
		WorkingDir: dockerArgs.WorkingDirectory,
	}

	mounts, err := makeContainerMounts(params.Inputs, params.Outputs, params.ResultsDir)
	if err != nil {
		return "", fmt.Errorf("failed to create container mounts: %w", err)
	}

	zlog.Sugar().Infof("Adding %d GPUs to request", len(params.Resources.GPUs))

	hostConfig := configureHostConfig(chosenGPUVendor, params, mounts)

	// NOTE(review): Client.CreateContainer also pulls the image before creating
	// the container, so the image is pulled twice on this path. The explicit
	// pull is kept so pull failures keep their dedicated error message.
	if _, err = e.client.PullImage(ctx, dockerArgs.Image); err != nil {
		return "", fmt.Errorf("failed to pull docker image: %w", err)
	}

	executionContainer, err := e.client.CreateContainer(
		ctx,
		&containerConfig,
		&hostConfig,
		nil,
		nil,
		labelExecutionValue(e.ID, params.JobID, params.ExecutionID),
	)
	if err != nil {
		return "", fmt.Errorf("failed to create container: %w", err)
	}

	return executionContainer, nil
}
// configureHostConfig builds the container host configuration for an execution.
//
// Every configuration carries the given mounts plus the CPU settings taken
// from params.Resources.CPU. On top of that, vendor selects how GPUs are
// exposed:
//   - NVIDIA: a device request with the "gpu" capability listing the index of
//     each GPU in params.Resources.GPUs;
//   - AMD/ATI: /dev/kfd and /dev/dri are bound and mapped into the container,
//     and the "video" group is added;
//   - Intel: the whole /dev/dri directory is bound and mapped, exposing all
//     Intel GPUs rather than selecting individual device paths;
//   - any other vendor: no GPU-specific settings (CPU-only configuration).
func configureHostConfig(vendor types.GPUVendor, params *types.ExecutionRequest, mounts []mount.Mount) container.HostConfig {
	// Settings shared by every vendor and by the CPU-only fallback.
	hostConfig := container.HostConfig{
		Mounts: mounts,
		Resources: container.Resources{
			NanoCPUs: int64(params.Resources.CPU.Freq),
			CPUCount: int64(params.Resources.CPU.Cores),
		},
	}

	// passthrough maps a host device path to the same path in the container
	// with read, write and mknod permissions.
	passthrough := func(path string) container.DeviceMapping {
		return container.DeviceMapping{
			PathOnHost:        path,
			PathInContainer:   path,
			CgroupPermissions: "rwm",
		}
	}

	switch vendor {
	case types.GPUVendorNvidia:
		deviceIDs := make([]string, 0, len(params.Resources.GPUs))
		for _, gpu := range params.Resources.GPUs {
			deviceIDs = append(deviceIDs, fmt.Sprint(gpu.Index))
		}
		hostConfig.Resources.DeviceRequests = []container.DeviceRequest{
			{
				DeviceIDs:    deviceIDs,
				Capabilities: [][]string{{"gpu"}},
			},
		}

	case types.GPUVendorAMDATI:
		hostConfig.Binds = []string{
			"/dev/kfd:/dev/kfd",
			"/dev/dri:/dev/dri",
		}
		hostConfig.Resources.Devices = []container.DeviceMapping{
			passthrough("/dev/kfd"),
			passthrough("/dev/dri"),
		}
		hostConfig.GroupAdd = []string{"video"}

	case types.GPUVendorIntel:
		hostConfig.Binds = []string{
			"/dev/dri:/dev/dri",
		}
		hostConfig.Resources.Devices = []container.DeviceMapping{
			passthrough("/dev/dri"),
		}
	}

	return hostConfig
}
// makeContainerMounts creates the bind mounts for the container from the input
// and output volumes of the execution request.
//
// Every input must be a bind volume (types.StorageVolumeTypeBind); any other
// type is rejected as unsupported. Inputs keep their ReadOnly setting.
//
// Every output must have a non-empty source and requires a non-empty
// resultsDir, which is created if it does not exist. Outputs are always
// mounted writable. When there are no outputs, resultsDir is neither checked
// nor created.
//
// It returns the list of mounts, or an error if any volume is invalid or the
// results directory cannot be created.
func makeContainerMounts(
	inputs []*types.StorageVolumeExecutor,
	outputs []*types.StorageVolumeExecutor,
	resultsDir string,
) ([]mount.Mount, error) {
	// the actual mounts we will give to the container
	// these are paths for both input and output data
	var mounts []mount.Mount

	for _, input := range inputs {
		// Only bind volumes can be expressed as the bind mounts built below.
		// (The original condition was inverted: it rejected bind volumes and
		// bind-mounted every other type.)
		if input.Type != types.StorageVolumeTypeBind {
			return nil, fmt.Errorf("unsupported storage volume type: %s", input.Type)
		}
		mounts = append(mounts, mount.Mount{
			Type:     mount.TypeBind,
			Source:   input.Source,
			Target:   input.Target,
			ReadOnly: input.ReadOnly,
		})
	}

	for _, output := range outputs {
		if output.Source == "" {
			return nil, fmt.Errorf("output source is empty")
		}
		if resultsDir == "" {
			return nil, fmt.Errorf("results directory is empty")
		}
		if err := os.MkdirAll(resultsDir, os.ModePerm); err != nil {
			return nil, fmt.Errorf("failed to create results directory: %w", err)
		}

		mounts = append(mounts, mount.Mount{
			Type:   mount.TypeBind,
			Source: output.Source,
			Target: output.Target,
			// this is an output volume so can be written to
			ReadOnly: false,
		})
	}

	return mounts, nil
}
// containerLabels builds the label set attached to a container created for the
// given job and execution: the executor ID, the job label value and the
// execution label value, each under its own label key.
func (e *Executor) containerLabels(jobID string, executionID string) map[string]string {
	labels := make(map[string]string, 3)
	labels[labelExecutorName] = e.ID
	labels[labelJobID] = labelJobValue(e.ID, jobID)
	labels[labelExecutionID] = labelExecutionValue(e.ID, jobID, executionID)
	return labels
}
// labelJobValue builds the job label value: the executor ID and the job ID
// joined by an underscore.
func labelJobValue(executorID, jobID string) string {
	const layout = "%s_%s"
	value := fmt.Sprintf(layout, executorID, jobID)
	return value
}
// labelExecutionValue builds the execution label value: the executor ID, the
// job ID and the execution ID joined by underscores.
func labelExecutionValue(executorID, jobID, executionID string) string {
	const layout = "%s_%s_%s"
	value := fmt.Sprintf(layout, executorID, jobID, executionID)
	return value
}
// FindRunningContainer looks up the container labelled for the given job and
// execution of this executor and returns its ID, or an error when no such
// container is found.
//
// NOTE(review): the lookup goes through Client.FindContainer, which lists all
// containers rather than only running ones, so the returned container is not
// necessarily running despite this method's name.
func (e *Executor) FindRunningContainer(
	ctx context.Context,
	jobID string,
	executionID string,
) (string, error) {
	return e.client.FindContainer(
		ctx,
		labelExecutionID,
		labelExecutionValue(e.ID, jobID, executionID),
	)
}
package docker
import (
"bufio"
"context"
"errors"
"fmt"
"io"
"strconv"
"sync/atomic"
"time"
"gitlab.com/nunet/device-management-service/types"
)
// DestroyTimeout bounds both the stop request sent to a container and the
// overall cleanup of an execution's Docker resources.
const DestroyTimeout = time.Second * 10
// executionHandler manages the lifecycle and execution of a Docker container for a specific job.
type executionHandler struct {
// provided by the executor
// ID is the ID of the executor that owns this handler.
ID string
client *Client // Docker client for container management.
// meta data about the task
// jobID and executionID, together with ID, form the execution label value
// used when removing the execution's Docker objects.
jobID string
executionID string
// containerID is the Docker container this handler starts, waits on and removes.
containerID string
resultsDir string // Directory to store execution results.
// synchronization
// activeCh is closed by run once the container has been started successfully.
activeCh chan bool // Blocks until the container starts running.
// waitCh is closed by run's deferred cleanup, after the result has been set.
waitCh chan bool // Blocks until execution completes or fails.
running *atomic.Bool // Indicates if the container is currently running.
// result of the execution
// result is written by run before waitCh is closed.
result *types.ExecutionResult
// TTY setting
// When set, run captures only the stdout pipe; otherwise stdout and stderr are captured separately.
TTYEnabled bool // Indicates if TTY is enabled for the container.
}
// active reports the current value of the handler's running flag, which run
// sets to true when it begins and back to false in its deferred cleanup.
func (h *executionHandler) active() bool {
return h.running.Load()
}
// run starts the container and handles its execution lifecycle.
//
// It marks the handler as running, starts the container, signals activeCh,
// and then waits for the container to stop, for a wait error, or for ctx to be
// done. When the container has stopped, its exit code and logs are captured
// into h.result.
//
// Whatever the outcome, the deferred cleanup destroys the container, clears
// the running flag and closes waitCh. h.result is always set before waitCh is
// closed, so waiters can read it afterwards.
func (h *executionHandler) run(ctx context.Context) {
	h.running.Store(true)

	defer func() {
		if err := h.destroy(DestroyTimeout); err != nil {
			zlog.Sugar().Warnf("failed to destroy container: %v\n", err)
		}
		h.running.Store(false)
		close(h.waitCh)
	}()

	if err := h.client.StartContainer(ctx, h.containerID); err != nil {
		h.result = types.NewFailedExecutionResult(fmt.Errorf("failed to start container: %v", err))
		return
	}

	close(h.activeCh) // Indicate that the container has started.

	var containerError error
	var containerExitStatusCode int64

	// Wait for the container to finish or for an execution error.
	statusCh, errCh := h.client.WaitContainer(ctx, h.containerID)
	select {
	case <-ctx.Done():
		// Report the context's error: the value received from Done() is an
		// empty struct and carries no information about the cancellation.
		h.result = types.NewFailedExecutionResult(
			fmt.Errorf("execution cancelled: %v", ctx.Err()),
		)
		return

	case err := <-errCh:
		zlog.Sugar().Errorf("error while waiting for container: %v\n", err)
		h.result = types.NewFailedExecutionResult(
			fmt.Errorf("failed to wait for container: %v", err),
		)
		return

	case exitStatus := <-statusCh:
		containerExitStatusCode = exitStatus.StatusCode

		containerJSON, err := h.client.InspectContainer(ctx, h.containerID)
		if err != nil {
			h.result = &types.ExecutionResult{
				ExitCode: int(containerExitStatusCode),
				ErrorMsg: err.Error(),
			}
			return
		}

		if containerJSON.ContainerJSONBase.State.OOMKilled {
			containerError = errors.New("container was killed due to OOM")
			h.result = &types.ExecutionResult{
				ExitCode: int(containerExitStatusCode),
				ErrorMsg: containerError.Error(),
			}
			return
		}

		if exitStatus.Error != nil {
			containerError = errors.New(exitStatus.Error.Message)
		}
	}

	// Follow container logs to capture stdout and stderr.
	stdoutPipe, stderrPipe, logsErr := h.client.FollowLogs(ctx, h.containerID)
	if logsErr != nil {
		followError := fmt.Errorf("failed to follow container logs: %w", logsErr)
		if containerError != nil {
			h.result = &types.ExecutionResult{
				ExitCode: int(containerExitStatusCode),
				ErrorMsg: fmt.Sprintf(
					"container error: '%s'. logs error: '%s'",
					containerError,
					followError,
				),
			}
		} else {
			h.result = &types.ExecutionResult{
				ExitCode: int(containerExitStatusCode),
				ErrorMsg: followError.Error(),
			}
		}
		return
	}

	// Initialize the result with the exit status code.
	h.result = types.NewExecutionResult(int(containerExitStatusCode))

	// Capture the logs. Reading up to a NUL delimiter effectively reads until
	// the pipe is closed; the resulting read error (normally io.EOF) is
	// deliberately ignored and whatever was read is kept.
	h.result.STDOUT, _ = bufio.NewReader(stdoutPipe).ReadString('\x00')
	if !h.TTYEnabled {
		// Without a TTY, stderr arrives separately from stdout. With a TTY
		// both are combined, so only the stdout pipe is read.
		h.result.STDERR, _ = bufio.NewReader(stderrPipe).ReadString('\x00')
	}
}
// kill asks Docker to stop the handler's container, using DestroyTimeout as
// the stop timeout. Any error from the stop request is returned unchanged.
func (h *executionHandler) kill(ctx context.Context) error {
return h.client.StopContainer(ctx, h.containerID, DestroyTimeout)
}
// destroy stops and removes the handler's container and then removes any
// remaining Docker objects carrying this execution's label. All steps share a
// fresh background context limited by timeout. The first failing step aborts
// the sequence and its error is returned.
func (h *executionHandler) destroy(timeout time.Duration) error {
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	// Stop the container before removing it.
	if stopErr := h.kill(ctx); stopErr != nil {
		return fmt.Errorf("failed to kill container (%s): %w", h.containerID, stopErr)
	}

	if removeErr := h.client.RemoveContainer(ctx, h.containerID); removeErr != nil {
		return removeErr
	}

	// Remove related objects, such as networks, created for this execution.
	executionLabel := labelExecutionValue(h.ID, h.jobID, h.executionID)
	return h.client.RemoveObjectsWithLabel(ctx, labelExecutionID, executionLabel)
}
// outputStream returns a reader over the container's logs.
//
// When request.Tail is set, only logs produced from the moment of the call
// onwards are requested; otherwise logs are requested from the start (UNIX
// timestamp 1). request.Follow is forwarded to keep the stream open for new
// output. The call blocks until the container has been started (activeCh is
// closed) or ctx is done, in which case ctx.Err() is returned.
func (h *executionHandler) outputStream(
	ctx context.Context,
	request types.LogStreamRequest,
) (io.ReadCloser, error) {
	// The timestamp is taken before waiting, so output produced while waiting
	// for the container to start is included.
	since := "1"
	if request.Tail {
		since = strconv.FormatInt(time.Now().Unix(), 10)
	}

	select {
	case <-h.activeCh:
		// The container is active; hand back the underlying log reader.
		return h.client.GetOutputStream(ctx, h.containerID, since, request.Follow)
	case <-ctx.Done():
		return nil, ctx.Err()
	}
}
package docker
import (
"gitlab.com/nunet/device-management-service/telemetry/logger"
)
// zlog is the package-level logger used throughout the docker executor package.
var zlog *logger.Logger
// init creates the package logger under the name "docker.executor".
func init() {
zlog = logger.New("docker.executor")
}
package docker
import (
"encoding/json"
"fmt"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils/validate"
)
// Parameter keys of a docker spec config. Each key matches the JSON name of
// the corresponding EngineSpec field.
const (
EngineKeyImage = "image"
EngineKeyEntrypoint = "entrypoint"
EngineKeyCmd = "cmd"
EngineKeyEnvironment = "environment"
EngineKeyWorkingDirectory = "working_directory"
)
// EngineSpec contains necessary parameters to execute a docker job.
type EngineSpec struct {
// Image is the docker image to run; it should be pullable by docker.
Image string `json:"image,omitempty"`
// Entrypoint optionally overrides the default entrypoint.
Entrypoint []string `json:"entrypoint,omitempty"`
// Cmd specifies the command to run in the container.
Cmd []string `json:"cmd,omitempty"`
// Environment is a slice of environment variables to run the container with.
Environment []string `json:"environment,omitempty"`
// WorkingDirectory is the working directory inside the container.
WorkingDirectory string `json:"working_directory,omitempty"`
}
// Validate reports whether the engine spec is usable. The only requirement
// checked is that Image is not blank; a blank image yields an error.
func (c EngineSpec) Validate() error {
	if !validate.IsBlank(c.Image) {
		return nil
	}
	return fmt.Errorf("invalid docker engine params: image cannot be empty")
}
// DecodeSpec decodes a spec config into a docker EngineSpec and validates it.
//
// The spec must be non-nil, be of the docker executor type and carry non-nil
// params. The params are converted by a JSON round trip: they are marshalled
// and then unmarshalled into an EngineSpec.
//
// It returns an error when any of those requirements is not met, when the JSON
// conversion fails, or when the decoded spec fails Validate. In the validation
// case the decoded spec is still returned alongside the error.
func DecodeSpec(spec *types.SpecConfig) (EngineSpec, error) {
	// Guard against a nil spec, which would otherwise be dereferenced below.
	if spec == nil {
		return EngineSpec{}, fmt.Errorf("invalid docker engine spec: spec cannot be nil")
	}

	if !spec.IsType(types.ExecutorTypeDocker) {
		return EngineSpec{}, fmt.Errorf(
			"invalid docker engine type. expected %s, but recieved: %s",
			types.ExecutorTypeDocker,
			spec.Type,
		)
	}

	inputParams := spec.Params
	if inputParams == nil {
		return EngineSpec{}, fmt.Errorf("invalid docker engine params: params cannot be nil")
	}

	paramBytes, err := json.Marshal(inputParams)
	if err != nil {
		return EngineSpec{}, fmt.Errorf("failed to encode docker engine params: %w", err)
	}

	// Decode into a value rather than a nil pointer: a JSON null would leave a
	// pointer nil and make the dereference on return panic.
	var dockerSpec EngineSpec
	if err := json.Unmarshal(paramBytes, &dockerSpec); err != nil {
		return EngineSpec{}, fmt.Errorf("failed to decode docker engine params: %w", err)
	}

	return dockerSpec, dockerSpec.Validate()
}
// DockerEngineBuilder is a struct that is used for constructing a docker spec
// config using the Builder pattern.
// It holds the *types.SpecConfig being built in its eb field and sets
// parameters on it through the With* methods.
type DockerEngineBuilder struct {
eb *types.SpecConfig
}
// NewDockerEngineBuilder returns a builder for a docker spec config. The
// underlying spec config is created with the docker executor type and its
// image parameter is set to image.
func NewDockerEngineBuilder(image string) *DockerEngineBuilder {
	builder := &DockerEngineBuilder{
		eb: types.NewSpecConfig(types.ExecutorTypeDocker),
	}
	builder.eb.WithParam(EngineKeyImage, image)
	return builder
}
// WithEntrypoint is a builder method that sets the Docker engine entrypoint.
// The variadic arguments are stored as a slice under EngineKeyEntrypoint.
// It returns the DockerEngineBuilder for further chaining of builder methods.
func (b *DockerEngineBuilder) WithEntrypoint(e ...string) *DockerEngineBuilder {
b.eb.WithParam(EngineKeyEntrypoint, e)
return b
}
// WithCmd is a builder method that sets the Docker engine's command.
// The variadic arguments are stored as a slice under EngineKeyCmd.
// It returns the DockerEngineBuilder for further chaining of builder methods.
func (b *DockerEngineBuilder) WithCmd(c ...string) *DockerEngineBuilder {
b.eb.WithParam(EngineKeyCmd, c)
return b
}
// WithEnvironment is a builder method that sets the Docker engine's environment variables.
// The variadic arguments are stored as a slice under EngineKeyEnvironment.
// It returns the DockerEngineBuilder for further chaining of builder methods.
func (b *DockerEngineBuilder) WithEnvironment(e ...string) *DockerEngineBuilder {
b.eb.WithParam(EngineKeyEnvironment, e)
return b
}
// WithWorkingDirectory is a builder method that sets the Docker engine's working directory.
// The value is stored under EngineKeyWorkingDirectory.
// It returns the DockerEngineBuilder for further chaining of builder methods.
func (b *DockerEngineBuilder) WithWorkingDirectory(w string) *DockerEngineBuilder {
b.eb.WithParam(EngineKeyWorkingDirectory, w)
return b
}
// Build returns the spec config accumulated by the builder. The returned
// pointer is the builder's own config, not a copy.
func (b *DockerEngineBuilder) Build() *types.SpecConfig {
	spec := b.eb
	return spec
}
package firecracker
import (
"context"
"fmt"
"os"
"os/exec"
"syscall"
"time"
"github.com/firecracker-microvm/firecracker-go-sdk"
)
// pidCheckTickTime is the interval at which DestroyVM polls the machine's PID
// while waiting for the Firecracker process to exit.
const pidCheckTickTime = 100 * time.Millisecond
// Client wraps the Firecracker SDK to provide high-level operations on Firecracker VMs.
// It carries no state of its own.
type Client struct{}
// NewFirecrackerClient returns a new Client. The error result is always nil
// in the current implementation.
func NewFirecrackerClient() (*Client, error) {
	client := new(Client)
	return client, nil
}
// IsInstalled reports whether an executable named "firecracker" can be found
// in the directories listed in the PATH environment variable.
func (c *Client) IsInstalled() bool {
	// There might be a better way to check if Firecracker is installed.
	if _, err := exec.LookPath("firecracker"); err != nil {
		return false
	}
	return true
}
// CreateVM builds a Firecracker machine for cfg. The VM process command is
// built for cfg.SocketPath and handed to the machine as its process runner.
// The machine is created but not started.
func (c *Client) CreateVM(
	ctx context.Context,
	cfg firecracker.Config,
) (*firecracker.Machine, error) {
	builder := firecracker.VMCommandBuilder{}.WithSocketPath(cfg.SocketPath)
	runner := builder.Build(ctx)
	return firecracker.NewMachine(ctx, cfg, firecracker.WithProcessRunner(runner))
}
// StartVM starts the given machine and returns the error from m.Start.
func (c *Client) StartVM(ctx context.Context, m *firecracker.Machine) error {
	err := m.Start(ctx)
	return err
}
// ShutdownVM requests a shutdown of the given machine and returns the error
// from m.Shutdown.
func (c *Client) ShutdownVM(ctx context.Context, m *firecracker.Machine) error {
	err := m.Shutdown(ctx)
	return err
}
// DestroyVM shuts down the Firecracker VM and removes its socket file.
//
// It records the machine's PID, requests a shutdown, and then polls the PID
// every pidCheckTickTime. If the process is still reported as running once
// timeout has elapsed, it is killed with SIGKILL. The socket file at
// m.Cfg.SocketPath is removed on return in every case.
//
// An error is returned only when the kill fails for a reason other than the
// process having already exited (ESRCH).
func (c *Client) DestroyVM(
	ctx context.Context,
	m *firecracker.Machine,
	timeout time.Duration,
) error {
	// Best-effort removal of the socket file; its error is intentionally ignored.
	defer os.Remove(m.Cfg.SocketPath)

	// Capture the PID before the shutdown request so the process can still be
	// killed afterwards. A PID error leaves pid <= 0, which is handled below.
	pid, _ := m.PID()

	// The shutdown request is best-effort: whether or not it succeeds, the
	// polling below decides if the process has to be killed.
	_ = c.ShutdownVM(ctx, m)

	// If the process is not running, return early.
	if pid <= 0 {
		return nil
	}

	ticker := time.NewTicker(pidCheckTickTime)
	defer ticker.Stop()
	deadline := time.NewTimer(timeout)
	defer deadline.Stop()

	for {
		select {
		case <-deadline.C:
			// The shutdown request timed out, kill the process with SIGKILL.
			// ESRCH means the process exited between the last poll and the
			// kill, which is the desired outcome rather than a failure.
			err := syscall.Kill(int(pid), syscall.SIGKILL)
			if err != nil && err != syscall.ESRCH {
				return fmt.Errorf("failed to kill process: %w", err)
			}
			return nil
		case <-ticker.C:
			if current, _ := m.PID(); current <= 0 {
				return nil
			}
		}
	}
}
// FindVM looks up a running Firecracker VM by its socket path.
//
// It first checks that the socket file exists, then creates a machine handle
// bound to that socket and queries its instance info. The machine is returned
// only when the reported state is "Running"; otherwise an error describing
// the failing step is returned.
func (c *Client) FindVM(ctx context.Context, socketPath string) (*firecracker.Machine, error) {
	// Check if the socket file exists.
	if _, err := os.Stat(socketPath); err != nil {
		return nil, fmt.Errorf("VM with socket path %v not found: %w", socketPath, err)
	}

	// Create a new Firecracker machine instance.
	cmd := firecracker.VMCommandBuilder{}.WithSocketPath(socketPath).Build(ctx)
	machine, err := firecracker.NewMachine(
		ctx,
		firecracker.Config{SocketPath: socketPath},
		firecracker.WithProcessRunner(cmd),
	)
	if err != nil {
		return nil, fmt.Errorf("failed to create machine with socket %s: %w", socketPath, err)
	}

	// Check if the VM is running by getting its instance info.
	info, err := machine.DescribeInstanceInfo(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to get instance info for socket %s: %w", socketPath, err)
	}

	// State is a pointer; guard against a response that omits it instead of
	// panicking on the dereference.
	if info.State == nil {
		return nil, fmt.Errorf("VM with socket %s did not report a state", socketPath)
	}
	if *info.State != "Running" {
		return nil, fmt.Errorf(
			"VM with socket %s is not running, current state: %s",
			socketPath,
			*info.State,
		)
	}
	return machine, nil
}
package firecracker
import (
"context"
"fmt"
"io"
"os"
"sync"
"sync/atomic"
"time"
"github.com/firecracker-microvm/firecracker-go-sdk"
fcModels "github.com/firecracker-microvm/firecracker-go-sdk/client/models"
"go.uber.org/multierr"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils"
)
const (
// socketDir is the directory in which generateSocketPath places VM socket files.
socketDir = "/tmp"
// DefaultCpuCount is presumably the default vCPU count for a VM; it is not
// referenced in the visible code — TODO confirm usage.
DefaultCpuCount int64 = 1
// DefaultMemSize is presumably the default VM memory size (units not shown
// here, likely MiB); it is not referenced in the visible code — TODO confirm.
DefaultMemSize int64 = 50
)
// Executor manages the lifecycle of Firecracker VMs for execution requests.
type Executor struct {
// ID identifies this executor; it is also used as a prefix of generated socket paths.
ID string
handlers utils.SyncMap[string, *executionHandler] // Maps execution IDs to their handlers.
client *Client // Firecracker client for VM management.
}
// NewExecutor creates an Executor with the given ID and a fresh Firecracker
// client. The context argument is currently unused. An error is returned only
// if the client cannot be created.
func NewExecutor(
	_ context.Context,
	id string,
) (*Executor, error) {
	client, err := NewFirecrackerClient()
	if err != nil {
		return nil, err
	}
	return &Executor{ID: id, client: client}, nil
}
// IsInstalled reports whether Firecracker is installed on the host by
// delegating to the underlying client. The context argument is unused.
func (e *Executor) IsInstalled(ctx context.Context) bool {
	installed := e.client.IsInstalled()
	return installed
}
// Start begins the execution of a request on a Firecracker VM.
//
// It first looks for an already-running VM for the job/execution pair (Start
// may be called again after a restart). If none is found and a handler is
// already registered for the execution ID, an error is returned; otherwise a
// new VM is created. A handler is then registered under the execution ID and
// run in its own goroutine. Start returns without waiting for the execution.
func (e *Executor) Start(ctx context.Context, request *types.ExecutionRequest) error {
	zlog.Sugar().
		Infof("Starting execution for job %s, execution %s", request.JobID, request.ExecutionID)

	machine, err := e.FindRunningVM(ctx, request.JobID, request.ExecutionID)
	if err != nil {
		// No running VM was found: refuse to start twice if a handler exists,
		// otherwise fall through to creating a new VM.
		if existing, found := e.handlers.Get(request.ExecutionID); found {
			if existing.active() {
				return fmt.Errorf("execution is already started")
			}
			return fmt.Errorf("execution is already completed")
		}

		machine, err = e.newFirecrackerExecutionVM(ctx, request)
		if err != nil {
			return fmt.Errorf("failed to create new firecracker VM: %w", err)
		}
	}

	h := &executionHandler{
		client:      e.client,
		ID:          e.ID,
		executionID: request.ExecutionID,
		machine:     machine,
		resultsDir:  request.ResultsDir,
		waitCh:      make(chan bool),
		activeCh:    make(chan bool),
		running:     &atomic.Bool{},
	}

	// Register the handler for this execution ID, then run the VM.
	e.handlers.Put(request.ExecutionID, h)
	go h.run(ctx)
	return nil
}
// Wait starts waiting for the execution identified by executionID and returns
// a result channel and an error channel, both buffered with capacity one.
//
// If no handler is registered for executionID, an error is placed on the error
// channel immediately and neither channel is closed. Otherwise doWait is run
// in a goroutine; it delivers either a result or an error and then closes
// both channels.
func (e *Executor) Wait(
	ctx context.Context,
	executionID string,
) (<-chan *types.ExecutionResult, <-chan error) {
	results := make(chan *types.ExecutionResult, 1)
	errs := make(chan error, 1)

	handler, ok := e.handlers.Get(executionID)
	if ok {
		go e.doWait(ctx, results, errs, handler)
		return results, errs
	}

	errs <- fmt.Errorf("execution (%s) not found", executionID)
	return results, errs
}
// doWait blocks until the handler's wait channel is closed or ctx is done.
//
// On context cancellation ctx.Err() is sent on errCh. When the handler
// finishes, its result is sent on out; a nil result is reported as an error on
// errCh instead. Both channels are closed when doWait returns.
func (e *Executor) doWait(
	ctx context.Context,
	out chan *types.ExecutionResult,
	errCh chan error,
	handler *executionHandler,
) {
	zlog.Sugar().Infof("executionID %s waiting for execution", handler.executionID)
	defer close(out)
	defer close(errCh)

	select {
	case <-handler.waitCh:
		// Execution finished; the result is inspected below.
	case <-ctx.Done():
		errCh <- ctx.Err()
		return
	}

	result := handler.result
	if result == nil {
		errCh <- fmt.Errorf("execution (%s) result is nil", handler.executionID)
		return
	}
	zlog.Sugar().
		Infof("executionID %s recieved results from execution", handler.executionID)
	out <- result
}
// Cancel asks the handler registered for executionID to stop its VM.
// It returns an error if no handler is registered, otherwise the error
// returned by the handler's kill method.
func (e *Executor) Cancel(ctx context.Context, executionID string) error {
	if handler, ok := e.handlers.Get(executionID); ok {
		return handler.kill(ctx)
	}
	return fmt.Errorf("failed to cancel execution (%s). execution not found", executionID)
}
// Run starts an execution and waits for it to complete in one call.
//
// It calls Start and then Wait, and returns the execution result, or an error
// if starting fails, waiting fails, or ctx is cancelled.
//
// The waiter closes both channels after delivering exactly one value, so a
// receive may observe a closed channel before the channel that carries the
// value. Closed channels are therefore disabled (set to nil) and the select is
// retried, instead of treating the zero value as a result or a nil error.
func (e *Executor) Run(
	ctx context.Context,
	request *types.ExecutionRequest,
) (*types.ExecutionResult, error) {
	if err := e.Start(ctx, request); err != nil {
		return nil, err
	}

	resCh, errCh := e.Wait(ctx, request.ExecutionID)
	for resCh != nil || errCh != nil {
		select {
		case <-ctx.Done():
			return nil, ctx.Err()
		case out, ok := <-resCh:
			if !ok {
				// Closed without a value; the outcome is on the other channel.
				resCh = nil
				continue
			}
			return out, nil
		case err, ok := <-errCh:
			if !ok {
				// Closed without a value; the outcome is on the other channel.
				errCh = nil
				continue
			}
			return nil, err
		}
	}

	// Both channels were closed without delivering anything.
	return nil, fmt.Errorf(
		"execution (%s) finished without a result or an error",
		request.ExecutionID,
	)
}
// GetLogStream is not implemented for Firecracker. It always returns a nil
// reader and an error; both arguments are ignored.
func (e *Executor) GetLogStream(
	ctx context.Context,
	request types.LogStreamRequest,
) (io.ReadCloser, error) {
	err := fmt.Errorf("GetLogStream is not implemented for Firecracker")
	return nil, err
}
// Cleanup destroys every VM tracked by the executor.
//
// Each registered handler is destroyed concurrently with a ten second timeout.
// Cleanup waits for all of them and returns their errors combined with
// multierr (nil when every destroy succeeded). The context argument is unused.
func (e *Executor) Cleanup(ctx context.Context) error {
	const destroyTimeout = time.Second * 10

	var wg sync.WaitGroup
	results := make(chan error, len(e.handlers.Keys()))

	e.handlers.Iter(func(_ string, h *executionHandler) bool {
		wg.Add(1)
		go func() {
			defer wg.Done()
			results <- h.destroy(destroyTimeout)
		}()
		return true
	})

	// Close the results channel once every destroy has reported, so the
	// collection loop below terminates.
	go func() {
		wg.Wait()
		close(results)
	}()

	var combined error
	for err := range results {
		combined = multierr.Append(combined, err)
	}
	zlog.Info("Cleaned up all firecracker resources")
	return combined
}
// newFirecrackerExecutionVM prepares a Firecracker VM for an execution request.
//
// It decodes the engine spec, builds the machine configuration (socket path,
// kernel, initrd, kernel args, vCPU count and memory size taken from the
// request's resources), creates the drive list from the root file system and
// the request's inputs and outputs, and asks the client to create the VM.
// The VM is created but not started.
func (e *Executor) newFirecrackerExecutionVM(
	ctx context.Context,
	params *types.ExecutionRequest,
) (*firecracker.Machine, error) {
	spec, err := DecodeSpec(params.EngineSpec)
	if err != nil {
		return nil, fmt.Errorf("failed to decode firecracker engine spec: %w", err)
	}

	machineCfg := fcModels.MachineConfiguration{
		VcpuCount:  firecracker.Int64(int64(params.Resources.CPU.Cores)),
		MemSizeMib: firecracker.Int64(int64(params.Resources.Memory.Size)),
	}
	cfg := firecracker.Config{
		VMID:            params.ExecutionID,
		SocketPath:      e.generateSocketPath(params.JobID, params.ExecutionID),
		KernelImagePath: spec.KernelImage,
		InitrdPath:      spec.Initrd,
		KernelArgs:      spec.KernelArgs,
		MachineCfg:      machineCfg,
	}

	drives, err := makeVMMounts(spec.RootFileSystem, params.Inputs, params.Outputs, params.ResultsDir)
	if err != nil {
		return nil, fmt.Errorf("failed to create VM mounts: %w", err)
	}
	cfg.Drives = drives

	vm, err := e.client.CreateVM(ctx, cfg)
	if err != nil {
		return nil, fmt.Errorf("failed to create VM: %w", err)
	}
	// e.client.VMPassMMDs(ctx, machine, fcArgs.MMDSMessage)
	return vm, nil
}
// makeVMMounts builds the drive list for a VM.
//
// The root file system becomes the root drive. Every input volume is added as
// a drive with its own read-only flag, and every output volume is added as a
// writable drive. When there is at least one output, resultsDir must be
// non-empty and is created if it does not exist.
//
// It returns the drives, or a nil slice and an error if an output has an empty
// source, resultsDir is empty, or the results directory cannot be created.
func makeVMMounts(
	rootFileSystem string,
	inputs []*types.StorageVolumeExecutor,
	outputs []*types.StorageVolumeExecutor,
	resultsDir string,
) ([]fcModels.Drive, error) {
	// DrivesBuilder methods use value receivers and return the updated
	// builder, so the result of AddDrive must be assigned back; discarding it
	// silently drops the drive.
	drivesBuilder := firecracker.NewDrivesBuilder(rootFileSystem)
	for _, input := range inputs {
		drivesBuilder = drivesBuilder.AddDrive(input.Source, input.ReadOnly)
	}

	for _, output := range outputs {
		if output.Source == "" {
			return nil, fmt.Errorf("output source is empty")
		}
		if resultsDir == "" {
			return nil, fmt.Errorf("results directory is empty")
		}
		if err := os.MkdirAll(resultsDir, os.ModePerm); err != nil {
			return nil, fmt.Errorf("failed to create results directory: %w", err)
		}
		drivesBuilder = drivesBuilder.AddDrive(output.Source, false)
	}

	return drivesBuilder.Build(), nil
}
// FindRunningVM looks up the running VM for the given job and execution IDs
// by asking the client for the VM bound to the corresponding socket path.
// It returns the Machine instance if found, or the client's error otherwise.
func (e *Executor) FindRunningVM(
	ctx context.Context,
	jobID string,
	executionID string,
) (*firecracker.Machine, error) {
	socketPath := e.generateSocketPath(jobID, executionID)
	return e.client.FindVM(ctx, socketPath)
}
// generateSocketPath returns the socket path for an execution, in the form
// "<socketDir>/<executorID>_<jobID>_<executionID>.sock".
func (e *Executor) generateSocketPath(jobID string, executionID string) string {
	fileName := fmt.Sprintf("%s_%s_%s.sock", e.ID, jobID, executionID)
	return fmt.Sprintf("%s/%s", socketDir, fileName)
}
package firecracker
import (
"context"
"fmt"
"sync/atomic"
"time"
"github.com/firecracker-microvm/firecracker-go-sdk"
"gitlab.com/nunet/device-management-service/types"
)
// executionHandler holds the state needed to manage the execution of a single
// firecracker VM.
type executionHandler struct {
// Provided by the executor.
// ID is the ID of the executor that created this handler.
ID string
// client is the Firecracker client used to start, shut down and destroy the VM.
client *Client
// Metadata about the task.
// JobID is not populated by Executor.Start in the visible code.
JobID string
executionID string
machine *firecracker.Machine
resultsDir string
// Synchronization.
activeCh chan bool // Closed by run once the VM has been started successfully.
waitCh chan bool // Closed by run when the execution completes or fails.
running *atomic.Bool // True while run is in progress.
// Result of the execution; set by run before waitCh is closed.
result *types.ExecutionResult
}
// active reports whether the handler's run loop is currently in progress.
func (h *executionHandler) active() bool {
	isRunning := h.running.Load()
	return isRunning
}
// run starts the firecracker VM and blocks until it finishes.
//
// The outcome is stored in h.result: a failed result when the VM cannot be
// started or waiting on it fails (with a distinct message when ctx was
// closed), a success result otherwise. On return the VM is destroyed with a
// ten second timeout, the running flag is cleared and waitCh is closed.
func (h *executionHandler) run(ctx context.Context) {
	h.running.Store(true)
	defer func() {
		if err := h.destroy(time.Second * 10); err != nil {
			zlog.Sugar().Warnf("failed to destroy container: %v\n", err)
		}
		h.running.Store(false)
		close(h.waitCh)
	}()

	zlog.Sugar().Info("starting firecracker execution")
	startErr := h.client.StartVM(ctx, h.machine)
	if startErr != nil {
		h.result = types.NewFailedExecutionResult(fmt.Errorf("failed to start VM: %v", startErr))
		return
	}

	// Signal that the VM has started.
	close(h.activeCh)

	waitErr := h.machine.Wait(ctx)
	switch {
	case waitErr == nil:
		h.result = types.NewExecutionResult(types.ExecutionStatusCodeSuccess)
	case ctx.Err() != nil:
		h.result = types.NewFailedExecutionResult(
			fmt.Errorf("context closed while waiting on VM: %v", waitErr),
		)
	default:
		h.result = types.NewFailedExecutionResult(fmt.Errorf("failed to wait on VM: %v", waitErr))
	}
}
// kill requests a shutdown of the handler's VM through the client and returns
// the client's error.
func (h *executionHandler) kill(ctx context.Context) error {
	err := h.client.ShutdownVM(ctx, h.machine)
	return err
}
// destroy stops the firecracker VM and removes its resources. It runs under a
// fresh background context bounded by timeout, and passes the same timeout to
// the client's DestroyVM.
func (h *executionHandler) destroy(timeout time.Duration) error {
	destroyCtx, release := context.WithTimeout(context.Background(), timeout)
	defer release()
	return h.client.DestroyVM(destroyCtx, h.machine, timeout)
}
package firecracker
import (
"gitlab.com/nunet/device-management-service/telemetry/logger"
)
// zlog is the package-level logger for the firecracker executor.
var zlog *logger.Logger
// init creates the package logger under the name "executor.firecracker".
func init() {
zlog = logger.New("executor.firecracker")
}
package firecracker
import (
"encoding/json"
"fmt"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils/validate"
)
// Parameter keys used by FirecrackerEngineBuilder when storing values on a
// spec config. They match the JSON field names of EngineSpec.
const (
EngineKeyKernelImage = "kernel_image"
EngineKeyKernelArgs = "kernel_args"
EngineKeyRootFileSystem = "root_file_system"
EngineKeyMMDSMessage = "mmds_message"
)
// EngineSpec contains necessary parameters to execute a firecracker job.
type EngineSpec struct {
// KernelImage is the path to the kernel image file. Required by Validate.
KernelImage string `json:"kernel_image,omitempty"`
// Initrd is the path to the initial ramdisk file (JSON key "initrd_path").
Initrd string `json:"initrd_path,omitempty"`
// KernelArgs is the kernel command line arguments.
KernelArgs string `json:"kernel_args,omitempty"`
// RootFileSystem is the path to the root file system. Required by Validate.
RootFileSystem string `json:"root_file_system,omitempty"`
// MMDSMessage is the MMDS message to be sent to the Firecracker VM.
MMDSMessage string `json:"mmds_message,omitempty"`
}
// Validate reports an error when a required field of the engine spec is
// blank. The root file system is checked first, then the kernel image.
func (c EngineSpec) Validate() error {
	switch {
	case validate.IsBlank(c.RootFileSystem):
		return fmt.Errorf("invalid firecracker engine params: root_file_system cannot be empty")
	case validate.IsBlank(c.KernelImage):
		return fmt.Errorf("invalid firecracker engine params: kernel_image cannot be empty")
	default:
		return nil
	}
}
// DecodeSpec decodes a spec config into a firecracker EngineSpec.
//
// The spec must be non-nil, of type types.ExecutorTypeFirecracker, and carry
// non-nil params. The params are round-tripped through JSON into an EngineSpec,
// which is then validated. On a validation failure the decoded spec is
// returned together with the validation error; on every other failure a zero
// EngineSpec is returned.
func DecodeSpec(spec *types.SpecConfig) (EngineSpec, error) {
	if spec == nil {
		return EngineSpec{}, fmt.Errorf("invalid firecracker engine spec: spec cannot be nil")
	}
	if !spec.IsType(types.ExecutorTypeFirecracker) {
		return EngineSpec{}, fmt.Errorf(
			"invalid firecracker engine type. expected %s, but received: %s",
			types.ExecutorTypeFirecracker,
			spec.Type,
		)
	}

	if spec.Params == nil {
		return EngineSpec{}, fmt.Errorf("invalid firecracker engine params: params cannot be nil")
	}

	paramBytes, err := json.Marshal(spec.Params)
	if err != nil {
		return EngineSpec{}, fmt.Errorf("failed to encode firecracker engine params: %w", err)
	}

	// Decode into a value rather than a nil pointer: a JSON "null" payload
	// would otherwise leave the pointer nil and panic on dereference.
	var firecrackerSpec EngineSpec
	if err := json.Unmarshal(paramBytes, &firecrackerSpec); err != nil {
		return EngineSpec{}, fmt.Errorf("failed to decode firecracker engine params: %w", err)
	}
	return firecrackerSpec, firecrackerSpec.Validate()
}
// FirecrackerEngineBuilder constructs a Firecracker engine spec using the
// builder pattern. It wraps a *types.SpecConfig; every With* method stores one
// option on that config as a named parameter and Build returns the config.
type FirecrackerEngineBuilder struct {
// eb is the spec config being populated by the builder methods.
eb *types.SpecConfig
}
// NewFirecrackerEngineBuilder returns a FirecrackerEngineBuilder whose
// underlying spec config is created with types.ExecutorTypeFirecracker and has
// the given root file system stored under EngineKeyRootFileSystem.
func NewFirecrackerEngineBuilder(rootFileSystem string) *FirecrackerEngineBuilder {
	builder := &FirecrackerEngineBuilder{eb: types.NewSpecConfig(types.ExecutorTypeFirecracker)}
	builder.eb.WithParam(EngineKeyRootFileSystem, rootFileSystem)
	return builder
}
// WithRootFileSystem stores the given root file system under
// EngineKeyRootFileSystem and returns the same builder so calls can be chained.
func (b *FirecrackerEngineBuilder) WithRootFileSystem(e string) *FirecrackerEngineBuilder {
	rootFS := e
	b.eb.WithParam(EngineKeyRootFileSystem, rootFS)
	return b
}
// WithKernelImage stores the given kernel image under EngineKeyKernelImage and
// returns the same builder so calls can be chained.
func (b *FirecrackerEngineBuilder) WithKernelImage(e string) *FirecrackerEngineBuilder {
	kernelImage := e
	b.eb.WithParam(EngineKeyKernelImage, kernelImage)
	return b
}
// WithKernelArgs stores the given kernel arguments under EngineKeyKernelArgs
// and returns the same builder so calls can be chained.
func (b *FirecrackerEngineBuilder) WithKernelArgs(e string) *FirecrackerEngineBuilder {
	kernelArgs := e
	b.eb.WithParam(EngineKeyKernelArgs, kernelArgs)
	return b
}
// WithMMDSMessage stores the given MMDS message under EngineKeyMMDSMessage and
// returns the same builder so calls can be chained.
func (b *FirecrackerEngineBuilder) WithMMDSMessage(e string) *FirecrackerEngineBuilder {
	message := e
	b.eb.WithParam(EngineKeyMMDSMessage, message)
	return b
}
// Build returns the spec config accumulated by the builder. The returned
// pointer is the builder's own config, not a copy.
func (b *FirecrackerEngineBuilder) Build() *types.SpecConfig {
	spec := b.eb
	return spec
}
package background_tasks
import (
"github.com/uptrace/opentelemetry-go-extra/otelzap"
"gitlab.com/nunet/device-management-service/telemetry/logger"
)
// zlog is the package-level logger for background tasks.
var zlog *otelzap.Logger
// init creates the package logger under the name "background_tasks".
func init() {
zlog = logger.OtelZapLogger("background_tasks")
}
package background_tasks
import (
"sort"
"sync"
"time"
)
// Scheduler orchestrates the execution of tasks based on their triggers and priority.
// On every ticker tick it checks the registered tasks and runs those whose
// triggers are ready, up to maxRunningTasks at a time.
type Scheduler struct {
tasks map[int]*Task // Map of tasks by their ID.
runningTasks map[int]bool // Map to keep track of running tasks; entries are set to false when a task finishes.
ticker *time.Ticker // Ticker for periodic checks of task triggers.
stopChan chan struct{} // Channel closed by Stop to end the scheduler loop.
maxRunningTasks int // Maximum number of tasks that can run concurrently.
lastTaskID int // Next ID to assign; incremented after each AddTask.
mu sync.Mutex // Mutex to protect access to task maps.
}
// NewScheduler creates a Scheduler that runs at most maxRunningTasks tasks
// concurrently. Its ticker fires once per second; the loop itself is only
// started by Start.
func NewScheduler(maxRunningTasks int) *Scheduler {
	s := new(Scheduler)
	s.tasks = make(map[int]*Task)
	s.runningTasks = make(map[int]bool)
	s.ticker = time.NewTicker(1 * time.Second)
	s.stopChan = make(chan struct{})
	s.maxRunningTasks = maxRunningTasks
	s.lastTaskID = 0
	return s
}
// AddTask registers task with the scheduler. The task is assigned the next
// free ID, marked enabled, and every one of its triggers is reset. The same
// task pointer is returned.
func (s *Scheduler) AddTask(task *Task) *Task {
	s.mu.Lock()
	defer s.mu.Unlock()

	id := s.lastTaskID
	task.ID = id
	task.Enabled = true
	for _, trig := range task.Triggers {
		trig.Reset()
	}
	s.tasks[id] = task
	s.lastTaskID = id + 1
	return task
}
// RemoveTask deletes the task with the given ID from the scheduler. Removing
// an unknown ID is a no-op.
func (s *Scheduler) RemoveTask(taskID int) {
	s.mu.Lock()
	delete(s.tasks, taskID)
	s.mu.Unlock()
}
// Start launches the scheduler loop in a goroutine. On every ticker tick the
// loop checks and runs tasks; it exits when the stop channel is closed.
func (s *Scheduler) Start() {
	loop := func() {
		for {
			select {
			case <-s.ticker.C:
				s.runTasks()
			case <-s.stopChan:
				return
			}
		}
	}
	go loop()
}
// runningTasksCount returns how many tasks are currently marked as running.
// It takes the scheduler mutex, so it must not be called with the mutex held.
func (s *Scheduler) runningTasksCount() int {
	s.mu.Lock()
	defer s.mu.Unlock()

	var running int
	for id := range s.runningTasks {
		if s.runningTasks[id] {
			running++
		}
	}
	return running
}
// runTasks checks every registered task and starts those that are due.
//
// Tasks are visited in descending priority order. Disabled or already-running
// tasks are skipped, tasks without triggers are removed, and a task is started
// (in its own goroutine) when one of its triggers is ready and fewer than
// maxRunningTasks tasks are running. The ready trigger is reset after the task
// is started.
//
// The scheduler mutex is held for the whole pass because the task maps are
// also written by runTask goroutines, AddTask and RemoveTask. For that reason
// the running count and the removal are done inline instead of through the
// locking helpers.
func (s *Scheduler) runTasks() {
	s.mu.Lock()
	defer s.mu.Unlock()

	// Sort tasks by priority.
	sortedTasks := make([]*Task, 0, len(s.tasks))
	for _, task := range s.tasks {
		sortedTasks = append(sortedTasks, task)
	}
	sort.Slice(sortedTasks, func(i, j int) bool {
		return sortedTasks[i].Priority > sortedTasks[j].Priority
	})

	running := 0
	for _, isRunning := range s.runningTasks {
		if isRunning {
			running++
		}
	}

	for _, task := range sortedTasks {
		if !task.Enabled || s.runningTasks[task.ID] {
			continue
		}
		if len(task.Triggers) == 0 {
			delete(s.tasks, task.ID)
			continue
		}
		// Check capacity before polling triggers: IsReady may consume state
		// (an EventTrigger drains its channel), so polling a trigger that
		// cannot be acted on would lose the event.
		if running >= s.maxRunningTasks {
			continue
		}
		for _, trigger := range task.Triggers {
			if trigger.IsReady() {
				s.runningTasks[task.ID] = true
				running++
				go s.runTask(task.ID)
				trigger.Reset()
				break
			}
		}
	}
}
// Stop ends the scheduler loop and releases its ticker. Tasks that are already
// running are not interrupted. Stop must be called at most once: a second call
// panics on closing the already-closed stop channel.
func (s *Scheduler) Stop() {
	// Stop the ticker so it no longer fires once the loop has exited.
	s.ticker.Stop()
	close(s.stopChan)
}
// runTask executes the task with the given ID and records the outcome.
//
// The task function is attempted up to RetryPolicy.MaxRetries+1 times. The
// execution is marked "SUCCESS" on the first attempt that returns nil and
// "FAILED" when every attempt fails; the error of the most recent failed
// attempt is kept on the execution record. The record is appended to the
// task's history and the task is marked as not running when runTask returns.
//
// If the task was removed before this goroutine ran, nothing is executed. A
// task removed while it is executing is not re-inserted into the scheduler.
func (s *Scheduler) runTask(taskID int) {
	defer func() {
		s.mu.Lock()
		defer s.mu.Unlock()
		s.runningTasks[taskID] = false
	}()

	// The task map is shared with other goroutines, so read it under the lock.
	s.mu.Lock()
	task, ok := s.tasks[taskID]
	s.mu.Unlock()
	if !ok || task == nil {
		return
	}

	execution := Execution{StartedAt: time.Now()}
	defer func() {
		s.mu.Lock()
		defer s.mu.Unlock()
		// task is a pointer, so appending here updates the registered task;
		// writing it back into the map would resurrect a removed task.
		task.ExecutionHist = append(task.ExecutionHist, execution)
	}()

	for attempt := 0; attempt < task.RetryPolicy.MaxRetries+1; attempt++ {
		err := runTaskWithRetry(task.Function, task.Args, task.RetryPolicy.Delay)
		if err == nil {
			execution.Status = "SUCCESS"
			execution.EndedAt = time.Now()
			return
		}
		execution.Error = err.Error()
	}
	execution.Status = "FAILED"
	execution.EndedAt = time.Now()
}
// runTaskWithRetry performs a single attempt of fn, passing it the args slice
// as one value. On success it returns nil immediately. On failure it sleeps
// for delay before returning the error, so the caller's next attempt is spaced
// out.
func runTaskWithRetry(
	fn func(args interface{}) error,
	args []interface{},
	delay time.Duration,
) error {
	result := fn(args)
	if result == nil {
		return nil
	}
	time.Sleep(delay)
	return result
}
package background_tasks
import (
"time"
"github.com/robfig/cron/v3"
)
// Trigger decides when a task should run. The scheduler polls IsReady on each
// tick and calls Reset after starting the task and when the task is added.
type Trigger interface {
IsReady() bool // Returns true if the trigger condition is met.
Reset() // Resets the trigger state.
}
// PeriodicTrigger triggers at regular intervals or based on a cron expression.
type PeriodicTrigger struct {
Interval time.Duration // Interval for periodic triggering.
CronExpr string // Cron expression for triggering; parsed with cron.ParseStandard. Empty means unused.
lastTriggered time.Time // Last time the trigger was reset.
}
// IsReady reports whether the trigger is due.
//
// The trigger is due when the interval has elapsed since the last reset, or
// when the next cron occurrence after the last reset lies in the past. The
// interval check is skipped when Interval is not positive and a cron
// expression is set; otherwise a cron-only trigger would be due on every
// call. An unparsable cron expression is logged and treated as not ready.
func (t *PeriodicTrigger) IsReady() bool {
	now := time.Now()

	// Trigger based on interval.
	intervalApplies := t.Interval > 0 || t.CronExpr == ""
	if intervalApplies && t.lastTriggered.Add(t.Interval).Before(now) {
		return true
	}

	// Trigger based on cron expression.
	if t.CronExpr == "" {
		return false
	}
	schedule, err := cron.ParseStandard(t.CronExpr)
	if err != nil {
		zlog.Sugar().Errorf("Error parsing CronExpr: %v", err)
		return false
	}
	return schedule.Next(t.lastTriggered).Before(now)
}
// Reset records the current time as the moment the trigger last fired.
func (t *PeriodicTrigger) Reset() {
	now := time.Now()
	t.lastTriggered = now
}
// EventTrigger triggers based on an external event signaled through a channel.
type EventTrigger struct {
Trigger chan bool // Channel to signal an event; IsReady consumes one value per ready result.
}
// IsReady performs a non-blocking receive on the trigger channel and reports
// whether a value was received. A received value is consumed.
func (t *EventTrigger) IsReady() bool {
	select {
	case <-t.Trigger:
	default:
		return false
	}
	return true
}
// Reset is a no-op for EventTrigger: its state lives in the channel, which is
// managed by whoever sends events.
func (t *EventTrigger) Reset() {
	// Intentionally empty.
}
// OneTimeTrigger becomes ready once Delay has elapsed since its last Reset.
// NOTE(review): nothing in this type prevents it from becoming ready again
// after a later Reset — confirm that callers remove it after the first run.
type OneTimeTrigger struct {
Delay time.Duration // The delay after which to trigger.
registeredAt time.Time // Time when the trigger was last reset.
}
// Reset records the current time as the trigger's registration time, from
// which the delay is measured.
func (t *OneTimeTrigger) Reset() {
	now := time.Now()
	t.registeredAt = now
}
// IsReady reports whether the current time is past the registration time plus
// the delay.
func (t *OneTimeTrigger) IsReady() bool {
	dueAt := t.registeredAt.Add(t.Delay)
	return dueAt.Before(time.Now())
}
package libp2p
import (
"bytes"
"context"
"errors"
"fmt"
"strings"
"sync"
"github.com/gin-gonic/gin"
"github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/peerstore"
"github.com/multiformats/go-multiaddr"
commonproto "gitlab.com/nunet/device-management-service/proto/generated/v1/common"
"google.golang.org/protobuf/proto"
)
// Bootstrap bootstraps the node's DHT and then connects to every given
// bootstrap peer concurrently, waiting for all connection attempts to finish.
// Only a DHT bootstrap failure is returned as an error; failures to parse or
// connect to an individual peer are logged and otherwise ignored.
func (l *Libp2p) Bootstrap(ctx context.Context, bootstrapPeers []multiaddr.Multiaddr) error {
	if err := l.DHT.Bootstrap(ctx); err != nil {
		return fmt.Errorf("failed to prepare this node for bootstraping: %w", err)
	}
	if len(bootstrapPeers) == 0 {
		return nil
	}

	// connect dials a single bootstrap peer and logs the outcome.
	connect := func(peerAddr multiaddr.Multiaddr) {
		addrInfo, err := peer.AddrInfoFromP2pAddr(peerAddr)
		if err != nil {
			zlog.Sugar().Errorf("failed to convert multi addr to addr info %v - %v", peerAddr, err)
			return
		}
		if err := l.Host.Connect(ctx, *addrInfo); err != nil {
			zlog.Sugar().Errorf("failed to connect to bootstrap node %s - %v", addrInfo.ID.String(), err)
			return
		}
		zlog.Sugar().Infof("connected to Bootstrap Node %s", addrInfo.ID.String())
	}

	// Bootstrap all nodes at the same time.
	var wg sync.WaitGroup
	wg.Add(len(bootstrapPeers))
	for _, addr := range bootstrapPeers {
		go func(peerAddr multiaddr.Multiaddr) {
			defer wg.Done()
			connect(peerAddr)
		}(addr)
	}
	wg.Wait()
	return nil
}
// dhtValidator validates records placed into the DHT.
type dhtValidator struct {
// PS is a peerstore; it is not used by the methods visible here.
PS peerstore.Peerstore
// customNamespace is the prefix every validated key must start with.
customNamespace string
}
// Validate checks a key/value pair before it is accepted into the DHT.
//
// An empty value is treated as a deletion and always accepted. Otherwise the
// key must start with the validator's namespace, and the value must decode to
// an Advertisement whose signature verifies, under the embedded secp256k1
// public key, over the concatenation of peer ID, timestamp byte, data and
// public key.
func (d dhtValidator) Validate(key string, value []byte) error {
	// Empty value is considered deleting an item from the dht.
	if len(value) == 0 {
		return nil
	}
	if !strings.HasPrefix(key, d.customNamespace) {
		return errors.New("invalid key namespace")
	}

	var envelope commonproto.Advertisement
	if err := proto.Unmarshal(value, &envelope); err != nil {
		return fmt.Errorf("failed to unmarshal envelope: %w", err)
	}

	pubKey, err := crypto.UnmarshalSecp256k1PublicKey(envelope.PublicKey)
	if err != nil {
		return fmt.Errorf("failed to unmarshal public key: %w", err)
	}

	// NOTE(review): byte(envelope.Timestamp) keeps only the low 8 bits of the
	// timestamp; presumably the signer builds the payload the same way — confirm.
	signedParts := [][]byte{
		[]byte(envelope.PeerId),
		{byte(envelope.Timestamp)},
		envelope.Data,
		envelope.PublicKey,
	}
	payload := bytes.Join(signedParts, nil)

	valid, err := pubKey.Verify(payload, envelope.Signature)
	switch {
	case err != nil:
		return fmt.Errorf("failed to verify envelope: %w", err)
	case !valid:
		return errors.New("failed to verify envelope, public key didn't sign payload")
	}
	return nil
}
// Select always picks the first candidate value (index 0) and never fails.
func (dhtValidator) Select(_ string, _ [][]byte) (int, error) {
	const first = 0
	return first, nil
}
// TODO remove the below when network package is fully implemented

// UpdateKadDHT is a stub: it only logs a warning.
func (l *Libp2p) UpdateKadDHT() {
	const msg = "UpdateKadDHT: Stub"
	zlog.Warn(msg)
}
// ListKadDHTPeers is a stub: it logs a warning and returns a nil peer list
// with a nil error. Both arguments are ignored.
func (l *Libp2p) ListKadDHTPeers(c *gin.Context, ctx context.Context) ([]string, error) {
	const msg = "ListKadDHTPeers: Stub"
	zlog.Warn(msg)
	return nil, nil
}
package libp2p
import (
"context"
"fmt"
"os"
"github.com/libp2p/go-libp2p/core/discovery"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
dutil "github.com/libp2p/go-libp2p/p2p/discovery/util"
)
// DiscoverDialPeers discovers peers through the rendezvous point and dials them.
//
// A non-empty discovery result replaces the stored peer list; an empty result
// keeps the previous list. Peers without listening addresses and the local
// host itself are filtered out before dialing. Only a discovery failure is
// returned as an error.
func (l *Libp2p) DiscoverDialPeers(ctx context.Context) error {
	found, err := l.findPeersFromRendezvousDiscovery(ctx)
	if err != nil {
		return err
	}
	if len(found) != 0 {
		l.discoveredPeers = found
	}

	// Filter out peers with no listening addresses and self host.
	l.discoveredPeers = PeerPassFilter(l.discoveredPeers, NoAddrIDFilter{ID: l.Host.ID()})

	l.dialPeers(ctx)
	return nil
}
// advertiseForRendezvousDiscovery advertises this node under the configured
// rendezvous string through the discovery service. The first value returned
// by Advertise is discarded; only the error is returned.
func (l *Libp2p) advertiseForRendezvousDiscovery(context context.Context) error {
	if _, err := l.discovery.Advertise(context, l.config.Rendezvous); err != nil {
		return err
	}
	return nil
}
// findPeersFromRendezvousDiscovery looks up peers advertised under the
// configured rendezvous string, limited to the configured discovery peer
// count. A lookup failure is returned wrapped.
func (l *Libp2p) findPeersFromRendezvousDiscovery(ctx context.Context) ([]peer.AddrInfo, error) {
	limit := discovery.Limit(l.config.PeerCountDiscoveryLimit)
	found, err := dutil.FindPeers(ctx, l.discovery, l.config.Rendezvous, limit)
	if err != nil {
		return nil, fmt.Errorf("failed to discover peers: %w", err)
	}
	return found, nil
}
// dialPeers dials every discovered peer that is not the local host and is not
// already connected. Dial failures are not returned. Dial outcomes are logged
// at debug level only when the NUNET_DEBUG_VERBOSE environment variable is
// set; the variable is read once per call.
func (l *Libp2p) dialPeers(ctx context.Context) {
	// The environment lookup does not depend on the peer, so do it once
	// instead of up to twice per iteration.
	_, debugMode := os.LookupEnv("NUNET_DEBUG_VERBOSE")

	for _, p := range l.discoveredPeers {
		if p.ID == l.Host.ID() {
			continue
		}
		if l.Host.Network().Connectedness(p.ID) == network.Connected {
			continue
		}

		if _, err := l.Host.Network().DialPeer(ctx, p.ID); err != nil {
			if debugMode {
				zlog.Sugar().Debugf("couldn't establish connection with: %s - error: %v", p.ID.String(), err)
			}
			continue
		}
		if debugMode {
			zlog.Sugar().Debugf("connected with: %s", p.ID.String())
		}
	}
}
package libp2p
import (
"github.com/libp2p/go-libp2p/core/connmgr"
"github.com/libp2p/go-libp2p/core/control"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/multiformats/go-multiaddr"
mafilt "github.com/whyrusleeping/multiaddr-filter"
)
// defaultServerFilters lists IPv4 and IPv6 CIDR ranges in multiaddr form,
// including private, link-local and documentation ranges.
// NOTE(review): presumably used as the default set of address filters; the
// usage is not visible in this part of the file — confirm.
var defaultServerFilters = []string{
"/ip4/10.0.0.0/ipcidr/8",
"/ip4/100.64.0.0/ipcidr/10",
"/ip4/169.254.0.0/ipcidr/16",
"/ip4/172.16.0.0/ipcidr/12",
"/ip4/192.0.0.0/ipcidr/24",
"/ip4/192.0.2.0/ipcidr/24",
"/ip4/192.168.0.0/ipcidr/16",
"/ip4/198.18.0.0/ipcidr/15",
"/ip4/198.51.100.0/ipcidr/24",
"/ip4/203.0.113.0/ipcidr/24",
"/ip4/240.0.0.0/ipcidr/4",
"/ip6/100::/ipcidr/64",
"/ip6/2001:2::/ipcidr/48",
"/ip6/2001:db8::/ipcidr/32",
"/ip6/fc00::/ipcidr/7",
"/ip6/fe80::/ipcidr/10",
}
// PeerFilter is an interface for filtering peers.
// A peer that satisfies the filter criteria is allowed to pass.
type PeerFilter interface {
// satisfies reports whether p passes the filter.
satisfies(p peer.AddrInfo) bool
}
// NoAddrIDFilter filters out peers with no listening addresses
// as well as the peer whose ID equals the configured ID.
type NoAddrIDFilter struct {
// ID is the peer ID to exclude.
ID peer.ID
}
// satisfies reports whether p has at least one address and is not the peer
// identified by f.ID.
func (f NoAddrIDFilter) satisfies(p peer.AddrInfo) bool {
	if len(p.Addrs) == 0 {
		return false
	}
	return p.ID != f.ID
}
// PeerPassFilter returns, in their original order, the peers that satisfy pf.
// The result is nil when no peer passes.
func PeerPassFilter(peers []peer.AddrInfo, pf PeerFilter) []peer.AddrInfo {
	var kept []peer.AddrInfo
	for i := range peers {
		if !pf.satisfies(peers[i]) {
			continue
		}
		kept = append(kept, peers[i])
	}
	return kept
}
// filtersConnectionGater adapts a multiaddr.Filters set into a libp2p
// connection gater: addresses blocked by the filters are refused.
type filtersConnectionGater multiaddr.Filters
// Compile-time check that *filtersConnectionGater implements connmgr.ConnectionGater.
var _ connmgr.ConnectionGater = (*filtersConnectionGater)(nil)
// InterceptAddrDial allows dialing addr unless the filters block it.
// The peer ID is ignored.
func (f *filtersConnectionGater) InterceptAddrDial(_ peer.ID, addr multiaddr.Multiaddr) (allow bool) {
	filters := (*multiaddr.Filters)(f)
	blocked := filters.AddrBlocked(addr)
	return !blocked
}
// InterceptPeerDial always allows the dial: this gater filters on addresses
// only, never on peer IDs.
func (f *filtersConnectionGater) InterceptPeerDial(p peer.ID) (allow bool) {
return true
}
// InterceptAccept allows an inbound connection unless its remote multiaddr is
// blocked by the filters.
func (f *filtersConnectionGater) InterceptAccept(connAddr network.ConnMultiaddrs) (allow bool) {
	filters := (*multiaddr.Filters)(f)
	remote := connAddr.RemoteMultiaddr()
	return !filters.AddrBlocked(remote)
}
// InterceptSecured allows a secured connection unless its remote multiaddr is
// blocked by the filters. Direction and peer ID are ignored.
func (f *filtersConnectionGater) InterceptSecured(_ network.Direction, _ peer.ID, connAddr network.ConnMultiaddrs) (allow bool) {
	filters := (*multiaddr.Filters)(f)
	remote := connAddr.RemoteMultiaddr()
	return !filters.AddrBlocked(remote)
}
// InterceptUpgraded always allows the fully upgraded connection and reports a
// zero disconnect reason.
func (f *filtersConnectionGater) InterceptUpgraded(_ network.Conn) (allow bool, reason control.DisconnectReason) {
return true, 0
}
// makeAddrsFactory builds the callback that decides which addresses the host
// announces.
//
//   - announce: when non-empty, these multiaddrs replace the host's own
//     addresses entirely.
//   - appendAnnouce: multiaddrs always added to the announced set; entries
//     already present (string-equal) in announce are skipped.
//   - noAnnounce: entries that parse as /ipcidr masks become deny filters;
//     any other entry must be a valid multiaddr and is excluded by exact
//     match.
//
// It returns nil when any announce/appendAnnouce entry is not a valid
// multiaddr, or when a noAnnounce entry is neither a mask nor a multiaddr.
func makeAddrsFactory(announce []string, appendAnnouce []string, noAnnounce []string) func([]multiaddr.Multiaddr) []multiaddr.Multiaddr {
	existing := make(map[string]bool, len(announce)) // announce strings already seen, to avoid duplicates
	annAddrs := make([]multiaddr.Multiaddr, 0, len(announce))
	for _, addr := range announce {
		maddr, err := multiaddr.NewMultiaddr(addr)
		if err != nil {
			return nil
		}
		annAddrs = append(annAddrs, maddr)
		existing[addr] = true
	}

	var appendAnnAddrs []multiaddr.Multiaddr
	for _, addr := range appendAnnouce {
		if existing[addr] {
			// skip AppendAnnounce that is on the Announce list already
			continue
		}
		appendAddr, err := multiaddr.NewMultiaddr(addr)
		if err != nil {
			return nil
		}
		appendAnnAddrs = append(appendAnnAddrs, appendAddr)
	}

	filters := multiaddr.NewFilters()
	noAnnAddrs := map[string]bool{}
	for _, addr := range noAnnounce {
		// First try to interpret the entry as a CIDR mask ...
		if f, err := mafilt.NewMask(addr); err == nil {
			filters.AddFilter(*f, multiaddr.ActionDeny)
			continue
		}
		// ... otherwise it must be an exact multiaddr.
		maddr, err := multiaddr.NewMultiaddr(addr)
		if err != nil {
			return nil
		}
		noAnnAddrs[string(maddr.Bytes())] = true
	}

	return func(allAddrs []multiaddr.Multiaddr) []multiaddr.Multiaddr {
		base := allAddrs
		if len(annAddrs) > 0 {
			base = annAddrs
		}

		// Assemble the candidates in a fresh slice. Appending directly to
		// base could write into spare capacity of the caller's backing array
		// and corrupt a slice we do not own.
		addrs := make([]multiaddr.Multiaddr, 0, len(base)+len(appendAnnAddrs))
		addrs = append(addrs, base...)
		addrs = append(addrs, appendAnnAddrs...)

		var out []multiaddr.Multiaddr
		for _, maddr := range addrs {
			// check for exact matches
			if noAnnAddrs[string(maddr.Bytes())] {
				continue
			}
			// check for /ipcidr matches
			if filters.AddrBlocked(maddr) {
				continue
			}
			out = append(out, maddr)
		}
		return out
	}
}
package libp2p
import (
"errors"
"sync"
"github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/protocol"
"gitlab.com/nunet/device-management-service/types"
)
// StreamHandler is a function type that processes an incoming network stream.
// It has the same signature as network.StreamHandler and is converted to it
// when registered on the host.
type StreamHandler func(stream network.Stream)
// HandlerRegistry manages the registration of stream handlers for different protocols.
type HandlerRegistry struct {
// host is the libp2p host on which the stream handlers are installed.
host host.Host
// handlers holds the stream callbacks, keyed by protocol ID.
handlers map[protocol.ID]StreamHandler
// bytesHandlers holds the raw-bytes callbacks, keyed by protocol ID.
bytesHandlers map[protocol.ID]func(data []byte)
// mu is held by the registry methods while they access the maps above.
mu sync.RWMutex
}
// NewHandlerRegistry returns an empty registry bound to the given host.
func NewHandlerRegistry(host host.Host) *HandlerRegistry {
	registry := new(HandlerRegistry)
	registry.host = host
	registry.handlers = map[protocol.ID]StreamHandler{}
	registry.bytesHandlers = map[protocol.ID]func(data []byte){}
	return registry
}
// RegisterHandlerWithStreamCallback installs handler as the stream handler
// for the protocol named by messageType. It returns an error if a stream
// callback is already registered for that protocol.
func (r *HandlerRegistry) RegisterHandlerWithStreamCallback(messageType types.MessageType, handler StreamHandler) error {
	r.mu.Lock()
	defer r.mu.Unlock()

	id := protocol.ID(messageType)
	if _, exists := r.handlers[id]; exists {
		return errors.New("stream with this protocol is already registered")
	}

	r.handlers[id] = handler
	r.host.SetStreamHandler(id, network.StreamHandler(handler))
	return nil
}
// RegisterHandlerWithBytesCallback records handler as the bytes callback for
// the protocol named by messageType and installs s as the stream handler for
// that protocol on the host. It returns an error if a bytes callback is
// already registered for the protocol.
func (r *HandlerRegistry) RegisterHandlerWithBytesCallback(messageType types.MessageType, s StreamHandler, handler func(data []byte)) error {
	r.mu.Lock()
	defer r.mu.Unlock()

	id := protocol.ID(messageType)
	if _, exists := r.bytesHandlers[id]; exists {
		return errors.New("stream with this protocol is already registered")
	}

	r.bytesHandlers[id] = handler
	r.host.SetStreamHandler(id, network.StreamHandler(s))
	return nil
}
// SendMessageToLocalHandler hands data to the locally registered bytes
// callback for messageType, if there is one. When no callback is registered
// the message is silently dropped.
func (r *HandlerRegistry) SendMessageToLocalHandler(messageType types.MessageType, data []byte) {
	r.mu.Lock()
	defer r.mu.Unlock()

	if deliver, found := r.bytesHandlers[protocol.ID(messageType)]; found {
		// run the callback on its own goroutine so the caller is never blocked
		go deliver(data)
	}
}
package libp2p
import (
"context"
"fmt"
"strings"
"time"
"github.com/libp2p/go-libp2p"
dht "github.com/libp2p/go-libp2p-kad-dht"
pubsub "github.com/libp2p/go-libp2p-pubsub"
"github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/pnet"
"github.com/libp2p/go-libp2p/core/protocol"
"github.com/libp2p/go-libp2p/core/routing"
"github.com/libp2p/go-libp2p/p2p/host/autorelay"
"github.com/libp2p/go-libp2p/p2p/host/peerstore/pstoremem"
"github.com/libp2p/go-libp2p/p2p/net/connmgr"
"github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay"
"github.com/libp2p/go-libp2p/p2p/security/noise"
libp2ptls "github.com/libp2p/go-libp2p/p2p/security/tls"
quic "github.com/libp2p/go-libp2p/p2p/transport/quic"
"github.com/libp2p/go-libp2p/p2p/transport/tcp"
ws "github.com/libp2p/go-libp2p/p2p/transport/websocket"
webtransport "github.com/libp2p/go-libp2p/p2p/transport/webtransport"
"github.com/multiformats/go-multiaddr"
"github.com/spf13/afero"
mafilt "github.com/whyrusleeping/multiaddr-filter"
"gitlab.com/nunet/device-management-service/types"
)
// NewHost returns a new libp2p host together with its Kademlia DHT and
// GossipSub instance, all configured from config.
//
// The host is built with TCP and WebSocket transports, TLS and Noise
// security, NAT service, relay, hole punching, a relay service and autorelay
// fed from the package-level newPeer channel. When
// config.PrivateNetwork.WithSwarmKey is set, the swarm key is loaded (or
// generated) through fs and a private network is enforced; otherwise QUIC and
// WebTransport are enabled as well. In server mode (config.Server) addresses
// in defaultServerFilters are neither announced nor accepted; otherwise NAT
// port mapping is enabled.
//
// Any failure while building the connection manager, peerstore, swarm key,
// host or pubsub is returned and the other results are nil.
func NewHost(ctx context.Context, config *types.Libp2pConfig, fs afero.Fs) (host.Host, *dht.IpfsDHT, *pubsub.PubSub, error) {
	var idht *dht.IpfsDHT

	// 100 / 400 are the low and high connection watermarks.
	connManager, err := connmgr.NewConnManager(
		100,
		400,
		connmgr.WithGracePeriod(time.Duration(config.GracePeriodMs)*time.Millisecond),
	)
	if err != nil {
		return nil, nil, nil, err
	}

	filter := multiaddr.NewFilters()
	for _, s := range defaultServerFilters {
		f, err := mafilt.NewMask(s)
		if err != nil {
			zlog.Sugar().Errorf("incorrectly formatted address filter in config: %s - %v", s, err)
			// f is nil when parsing fails; skip the entry instead of
			// dereferencing a nil pointer below.
			continue
		}
		filter.AddFilter(*f, multiaddr.ActionDeny)
	}

	ps, err := pstoremem.NewPeerstore()
	if err != nil {
		return nil, nil, nil, err
	}

	var libp2pOpts []libp2p.Option
	baseOpts := []dht.Option{
		dht.ProtocolPrefix(protocol.ID(config.DHTPrefix)),
		dht.NamespacedValidator(strings.ReplaceAll(config.CustomNamespace, "/", ""), dhtValidator{PS: ps}),
		dht.Mode(dht.ModeServer),
	}

	if config.PrivateNetwork.WithSwarmKey {
		psk, err := configureSwarmKey(fs)
		if err != nil {
			return nil, nil, nil, fmt.Errorf("failed to configure swarm key: %w", err)
		}
		libp2pOpts = append(libp2pOpts, libp2p.PrivateNetwork(psk))
		// guarantee that outer connection will be refused
		pnet.ForcePrivateNetwork = true
	} else {
		// enable quic (it does not work with pnet enabled)
		libp2pOpts = append(libp2pOpts, libp2p.Transport(quic.NewTransport))
		libp2pOpts = append(libp2pOpts, libp2p.Transport(webtransport.New))
		// for some reason, ForcePrivateNetwork was equal to true even without being set to true
		pnet.ForcePrivateNetwork = false
	}

	libp2pOpts = append(libp2pOpts, libp2p.ListenAddrStrings(config.ListenAddress...),
		libp2p.Identity(config.PrivateKey),
		libp2p.Routing(func(h host.Host) (routing.PeerRouting, error) {
			// the DHT is created here and handed back to the caller through idht
			idht, err = dht.New(ctx, h, baseOpts...)
			return idht, err
		}),
		libp2p.Peerstore(ps),
		libp2p.Security(libp2ptls.ID, libp2ptls.New),
		libp2p.Security(noise.ID, noise.New),
		// Do not use DefaultTransports as we can not enable QUIC when pnet is on
		libp2p.Transport(tcp.NewTCPTransport),
		libp2p.Transport(ws.New),
		libp2p.EnableNATService(),
		libp2p.ConnectionManager(connManager),
		libp2p.EnableRelay(),
		libp2p.EnableHolePunching(),
		libp2p.EnableRelayService(
			relay.WithResources(
				relay.Resources{
					MaxReservations:        256,
					MaxCircuits:            32,
					BufferSize:             4096,
					MaxReservationsPerPeer: 8,
					MaxReservationsPerIP:   16,
				},
			),
			relay.WithLimit(&relay.RelayLimit{
				Duration: 5 * time.Minute,
				Data:     1 << 21, // 2 MiB
			}),
		),
		libp2p.EnableAutoRelayWithPeerSource(
			// Forward up to num peers from newPeer to autorelay, stopping
			// early when ctx is cancelled.
			func(ctx context.Context, num int) <-chan peer.AddrInfo {
				r := make(chan peer.AddrInfo)
				go func() {
					defer close(r)
					for i := 0; i < num; i++ {
						select {
						case p := <-newPeer:
							select {
							case r <- p:
							case <-ctx.Done():
								return
							}
						case <-ctx.Done():
							return
						}
					}
				}()
				return r
			},
			autorelay.WithBootDelay(time.Minute),
			autorelay.WithBackoff(30*time.Second),
			autorelay.WithMinCandidates(2),
			autorelay.WithMaxCandidates(3),
			autorelay.WithNumRelays(2),
		),
	)

	if config.Server {
		libp2pOpts = append(libp2pOpts, libp2p.AddrsFactory(makeAddrsFactory([]string{}, []string{}, defaultServerFilters)))
		libp2pOpts = append(libp2pOpts, libp2p.ConnectionGater((*filtersConnectionGater)(filter)))
	} else {
		libp2pOpts = append(libp2pOpts, libp2p.NATPortMap())
	}

	p2pHost, err := libp2p.New(libp2pOpts...)
	if err != nil {
		return nil, nil, nil, err
	}

	optsPS := []pubsub.Option{pubsub.WithMessageSigning(true), pubsub.WithMaxMessageSize(config.GossipMaxMessageSize)}
	gossip, err := pubsub.NewGossipSub(ctx, p2pHost, optsPS...)
	// gossip, err := pubsub.NewGossipSubWithRouter(ctx, host, pubsub.DefaultGossipSubRouter(host), optsPS...)
	if err != nil {
		return nil, nil, nil, err
	}

	return p2pHost, idht, gossip, nil
}
package libp2p
import (
"github.com/libp2p/go-libp2p/core/peer"
"github.com/uptrace/opentelemetry-go-extra/otelzap"
"gitlab.com/nunet/device-management-service/telemetry/logger"
)
const (
// Custom namespace for DHT protocol with version number.
// NOTE(review): no use of this constant is visible in this part of the file;
// the code here reads config.CustomNamespace instead — confirm it is still needed.
customNamespace = "/nunet-dht-1/"
)
// TODO: pass the logger to the constructor and remove from here
var (
// zlog is the package-wide logger; it is assigned in init.
zlog *otelzap.Logger
// newPeer is an unbuffered channel from which the autorelay peer source
// configured in NewHost receives relay candidates.
newPeer = make(chan peer.AddrInfo)
)
// init sets up the package logger under the "network.libp2p" name.
func init() {
zlog = logger.OtelZapLogger("network.libp2p")
}
package libp2p
import (
"context"
"bufio"
"bytes"
"crypto/sha256"
"encoding/binary"
"errors"
"fmt"
"io"
"strings"
"sync"
"time"
"github.com/ipfs/go-cid"
kbucket "github.com/libp2p/go-libp2p-kbucket"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/p2p/protocol/ping"
"gitlab.com/nunet/device-management-service/types"
pubsub "github.com/libp2p/go-libp2p-pubsub"
"github.com/multiformats/go-multiaddr"
"github.com/multiformats/go-multihash"
"github.com/spf13/afero"
"google.golang.org/protobuf/proto"
dht "github.com/libp2p/go-libp2p-kad-dht"
libp2pdiscovery "github.com/libp2p/go-libp2p/core/discovery"
"github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peerstore"
"github.com/libp2p/go-libp2p/core/protocol"
drouting "github.com/libp2p/go-libp2p/p2p/discovery/routing"
bt "gitlab.com/nunet/device-management-service/internal/background_tasks"
commonproto "gitlab.com/nunet/device-management-service/proto/generated/v1/common"
)
const (
// MB is the number of bytes in one mebibyte.
MB = 1024 * 1024
// MaxMessageLengthMB is the maximum size, in MB, of a direct message
// (8-byte length prefix plus payload) that sendMessage will transmit.
MaxMessageLengthMB = 10
)
// Libp2p contains the configuration for a Libp2p instance.
//
// TODO-suggestion: maybe we should call it something else like Libp2pPeer,
// Libp2pHost or just Peer (callers would use libp2p.Peer...)
type Libp2p struct {
// Host, DHT and PS are populated by Init; PS is the host's peerstore.
Host host.Host
DHT *dht.IpfsDHT
PS peerstore.Peerstore
// pubsub is the pubsub instance returned by NewHost.
pubsub *pubsub.PubSub
// pubsubTopics holds the joined topic handles, keyed by topic name.
pubsubTopics map[string]*pubsub.Topic
// topicSubscription holds the active subscription per topic name.
topicSubscription map[string]*pubsub.Subscription
// topicMux guards pubsubTopics and topicSubscription.
topicMux sync.RWMutex
// a list of peers discovered by discovery
discoveredPeers []peer.AddrInfo
// discovery is the routing-based discovery service built on top of the DHT.
discovery libp2pdiscovery.Discovery
// services
pingService *ping.PingService
// tasks
// discoveryTask is the periodic peer-discovery task registered by Start.
discoveryTask *bt.Task
// handlerRegistry tracks the message handlers registered on the host.
handlerRegistry *HandlerRegistry
config *types.Libp2pConfig
// dependencies (db, filesystem...)
fs afero.Fs
}
// New creates a libp2p instance from the given configuration and filesystem.
// It fails when config is nil or carries no scheduler. The returned instance
// still has to be initialised with Init before it can be started.
//
// TODO-Suggestion: move types.Libp2pConfig to here for better readability.
// Unless there is a reason to keep within types.
func New(config *types.Libp2pConfig, fs afero.Fs) (*Libp2p, error) {
	switch {
	case config == nil:
		return nil, errors.New("config is nil")
	case config.Scheduler == nil:
		return nil, errors.New("scheduler is nil")
	}

	node := &Libp2p{config: config, fs: fs}
	node.discoveredPeers = []peer.AddrInfo{}
	node.pubsubTopics = map[string]*pubsub.Topic{}
	node.topicSubscription = map[string]*pubsub.Subscription{}
	return node, nil
}
// Init builds the libp2p host and wires its dependencies (DHT, peerstore,
// routing discovery, pubsub and handler registry) into l. A host creation
// failure is logged and returned.
func (l *Libp2p) Init(context context.Context) error {
	h, kadDHT, ps, err := NewHost(context, l.config, l.fs)
	if err != nil {
		zlog.Sugar().Error(err)
		return err
	}

	l.Host, l.DHT, l.pubsub = h, kadDHT, ps
	l.PS = h.Peerstore()
	l.discovery = drouting.NewRoutingDiscovery(kadDHT)
	l.handlerRegistry = NewHandlerRegistry(h)
	return nil
}
// Start registers the stream handlers, bootstraps the network, advertises the
// node for rendezvous discovery, runs a first round of peer discovery and
// schedules discovery to repeat every 15 minutes.
//
// Only a bootstrap failure aborts the start; advertise and discovery errors
// are logged and otherwise ignored.
func (l *Libp2p) Start(context context.Context) error {
	// set stream handlers
	l.registerStreamHandlers()

	// bootstrapping is mandatory
	if err := l.Bootstrap(context, l.config.BootstrapPeers); err != nil {
		zlog.Sugar().Errorf("failed to start network: %v", err)
		return err
	}

	// TODO: the error might be misleading as a peer can normally work well if an error
	// is returned here (e.g.: the error is yielded in tests even though all tests pass).
	if err := l.advertiseForRendezvousDiscovery(context); err != nil {
		zlog.Sugar().Errorf("failed to start network with randevouz discovery: %v", err)
	}

	// initial discovery round
	if err := l.DiscoverDialPeers(context); err != nil {
		zlog.Sugar().Errorf("failed to discover peers: %v", err)
	}

	// periodic discovery rounds, driven by the configured scheduler
	rediscover := func(args interface{}) error {
		return l.DiscoverDialPeers(context)
	}
	task := &bt.Task{
		Name:        "Peer Discovery",
		Description: "Periodic task to discover new peers every 15 minutes",
		Function:    rediscover,
		Triggers:    []bt.Trigger{&bt.PeriodicTrigger{Interval: 15 * time.Minute}},
	}

	scheduler := l.config.Scheduler
	l.discoveryTask = scheduler.AddTask(task)
	scheduler.Start()
	return nil
}
// RegisterStreamMessageHandler registers handler as the stream handler for
// the protocol named by messageType. An empty message type or an already
// registered protocol yields an error.
func (l *Libp2p) RegisterStreamMessageHandler(messageType types.MessageType, handler StreamHandler) error {
	if messageType == "" {
		return errors.New("message type is empty")
	}

	err := l.handlerRegistry.RegisterHandlerWithStreamCallback(messageType, handler)
	if err == nil {
		return nil
	}
	return fmt.Errorf("failed to register handler %s: %w", messageType, err)
}
// RegisterBytesMessageHandler registers handler for the protocol named by
// messageType. Incoming streams on that protocol are read by
// handleReadBytesFromStream, which passes the decoded payload to handler.
// An empty message type or an already registered protocol yields an error.
func (l *Libp2p) RegisterBytesMessageHandler(messageType types.MessageType, handler func(data []byte)) error {
	if messageType == "" {
		return errors.New("message type is empty")
	}

	err := l.handlerRegistry.RegisterHandlerWithBytesCallback(messageType, l.handleReadBytesFromStream, handler)
	if err == nil {
		return nil
	}
	return fmt.Errorf("failed to register handler %s: %w", messageType, err)
}
// HandleMessage is the string-typed variant of RegisterBytesMessageHandler:
// it converts messageType to types.MessageType and registers handler for it.
func (l *Libp2p) HandleMessage(messageType string, handler func(data []byte)) error {
	typed := types.MessageType(messageType)
	return l.RegisterBytesMessageHandler(typed, handler)
}
// handleReadBytesFromStream reads one length-prefixed message from s and
// passes its payload to the bytes callback registered for the stream's
// protocol. The stream is always closed before returning.
//
// Wire format: an 8-byte little-endian payload length followed by the payload
// itself (the counterpart of what sendMessage writes). Streams with no
// registered callback, a truncated message, or an announced length above
// MaxMessageLengthMB*MB are dropped without invoking the callback.
func (l *Libp2p) handleReadBytesFromStream(s network.Stream) {
	defer s.Close()

	// Registration mutates the map under the registry lock, so the lookup
	// must hold it as well to avoid a data race.
	l.handlerRegistry.mu.RLock()
	callback, ok := l.handlerRegistry.bytesHandlers[s.Protocol()]
	l.handlerRegistry.mu.RUnlock()
	if !ok {
		return
	}

	c := bufio.NewReader(s)

	// read exactly 8 bytes to determine the size of the message; a plain Read
	// may legitimately return fewer bytes than requested
	var msgLengthBuffer [8]byte
	if _, err := io.ReadFull(c, msgLengthBuffer[:]); err != nil {
		return
	}

	// The length comes from the remote peer: bound it before allocating so a
	// malicious or corrupt prefix cannot trigger a huge allocation or a panic.
	msgLength := binary.LittleEndian.Uint64(msgLengthBuffer[:])
	if msgLength > MaxMessageLengthMB*MB {
		zlog.Sugar().Warnf("dropping message on protocol %s: announced size %d exceeds limit of %d bytes",
			s.Protocol(), msgLength, MaxMessageLengthMB*MB)
		return
	}

	// read the full message
	buf := make([]byte, msgLength)
	if _, err := io.ReadFull(c, buf); err != nil {
		return
	}

	callback(buf)
}
// SendMessage sends msg to every address in addrs concurrently and waits for
// all sends to finish. It returns nil when every send succeeded; otherwise
// the individual failures are combined into a single error, separated by
// "; " (in no particular order).
func (l *Libp2p) SendMessage(ctx context.Context, addrs []string, msg types.MessageEnvelope) error {
	// buffered for the worst case so that no sender ever blocks
	failures := make(chan error, len(addrs))

	var pending sync.WaitGroup
	pending.Add(len(addrs))
	for _, target := range addrs {
		go func(target string) {
			defer pending.Done()
			if err := l.sendMessage(ctx, target, msg); err != nil {
				failures <- err
			}
		}(target)
	}
	pending.Wait()
	close(failures)

	var combined error
	for err := range failures {
		if combined != nil {
			combined = fmt.Errorf("%v; %v", combined, err)
			continue
		}
		combined = err
	}
	return combined
}
// OpenStream connects to the peer described by the p2p multiaddress addr and
// opens a stream for the protocol named by messageType. The caller owns the
// returned stream and is responsible for closing it.
func (l *Libp2p) OpenStream(ctx context.Context, addr string, messageType types.MessageType) (network.Stream, error) {
	target, err := multiaddr.NewMultiaddr(addr)
	if err != nil {
		return nil, fmt.Errorf("invalid multiaddress: %w", err)
	}

	remote, err := peer.AddrInfoFromP2pAddr(target)
	if err != nil {
		return nil, fmt.Errorf("could not resolve peer info: %w", err)
	}

	if err = l.Host.Connect(ctx, *remote); err != nil {
		return nil, fmt.Errorf("failed to connect to peer: %w", err)
	}

	opened, err := l.Host.NewStream(ctx, remote.ID, protocol.ID(messageType))
	if err == nil {
		return opened, nil
	}
	return nil, fmt.Errorf("failed to open stream: %w", err)
}
// GetMultiaddr returns the host's addresses as full p2p multiaddrs, i.e. each
// address combined with the host's peer ID.
func (l *Libp2p) GetMultiaddr() ([]multiaddr.Multiaddr, error) {
	self := &peer.AddrInfo{ID: l.Host.ID(), Addrs: l.Host.Addrs()}
	return peer.AddrInfoToP2pAddrs(self)
}
// Stop performs a cleanup of any resources used in this package: it removes
// the periodic discovery task from the scheduler and closes the DHT and the
// host.
//
// It is safe to call on an instance whose Init or Start did not complete;
// components that were never created are skipped. Errors from closing the DHT
// and the host are combined into one error, separated by "; ".
func (l *Libp2p) Stop() error {
	var errorMessages []string

	// discoveryTask is only set at the end of Start; it is nil when Start
	// failed early or was never called.
	if l.discoveryTask != nil {
		l.config.Scheduler.RemoveTask(l.discoveryTask.ID)
	}

	// DHT and Host are only set by a successful Init.
	if l.DHT != nil {
		if err := l.DHT.Close(); err != nil {
			errorMessages = append(errorMessages, err.Error())
		}
	}
	if l.Host != nil {
		if err := l.Host.Close(); err != nil {
			errorMessages = append(errorMessages, err.Error())
		}
	}

	if len(errorMessages) > 0 {
		return errors.New(strings.Join(errorMessages, "; "))
	}
	return nil
}
// Stat reports the host's peer ID and its addresses, the latter joined into a
// single ", "-separated string.
func (l *Libp2p) Stat() types.NetworkStats {
	hostAddrs := l.Host.Addrs()
	listen := make([]string, len(hostAddrs))
	for i, a := range hostAddrs {
		listen[i] = a.String()
	}

	var stats types.NetworkStats
	stats.ID = l.Host.ID().String()
	stats.ListenAddr = strings.Join(listen, ", ")
	return stats
}
// Ping pings the peer whose encoded peer ID is peerIDAddress and waits for a
// single reply or until timeout elapses.
//
// Pinging the local host is rejected. An undecodable peer ID returns an empty
// PingResult together with the decode error. In every other failure case the
// error is reported both in PingResult.Error and as the returned error.
//
// TODO (Return error once): something that was confusing me when using this method is that the error is
// returned twice if any. Once as a field of PingResult and one as a return value.
func (l *Libp2p) Ping(ctx context.Context, peerIDAddress string, timeout time.Duration) (types.PingResult, error) {
	// avoid dial to self attempt
	if l.Host.ID().String() == peerIDAddress {
		selfErr := errors.New("can't ping self")
		return types.PingResult{Success: false, Error: selfErr}, selfErr
	}

	pingCtx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()

	target, err := peer.Decode(peerIDAddress)
	if err != nil {
		return types.PingResult{}, err
	}

	select {
	case reply := <-ping.Ping(pingCtx, l.Host, target):
		if reply.Error == nil {
			return types.PingResult{RTT: reply.RTT, Success: true}, nil
		}
		zlog.Sugar().Errorf("failed to ping peer %s: %v", peerIDAddress, reply.Error)
		failed := types.PingResult{
			Success: false,
			RTT:     reply.RTT,
			Error:   reply.Error,
		}
		return failed, reply.Error
	case <-pingCtx.Done():
		return types.PingResult{Error: pingCtx.Err()}, pingCtx.Err()
	}
}
// ResolveAddress returns the p2p multiaddrs (as strings) of the peer with the
// given encoded peer ID. The local host is resolved from its own addresses;
// any other peer is looked up on the DHT with a 3-second timeout.
func (l *Libp2p) ResolveAddress(ctx context.Context, id string) ([]string, error) {
	pid, err := peer.Decode(id)
	if err != nil {
		return nil, fmt.Errorf("failed to resolve invalid peer: %w", err)
	}

	// toStrings renders each multiaddr in its textual form.
	toStrings := func(maddrs []multiaddr.Multiaddr) []string {
		rendered := make([]string, len(maddrs))
		for i := range maddrs {
			rendered[i] = maddrs[i].String()
		}
		return rendered
	}

	// resolve ourself
	if id == l.Host.ID().String() {
		own, err := l.GetMultiaddr()
		if err != nil {
			return nil, fmt.Errorf("failed to resolve self: %w", err)
		}
		return toStrings(own), nil
	}

	lookupCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
	defer cancel()

	found, err := l.DHT.FindPeer(lookupCtx, pid)
	if err != nil {
		return nil, fmt.Errorf("failed to resolve address %s: %w", id, err)
	}

	remote, err := peer.AddrInfoToP2pAddrs(&peer.AddrInfo{ID: found.ID, Addrs: found.Addrs})
	if err != nil {
		return nil, fmt.Errorf("failed to convert to p2p address: %w", err)
	}
	return toStrings(remote), nil
}
// Query returns all the advertisements in the network related to key.
// It finds the providers of the CID derived from key and fetches each
// provider's advertisement from the DHT. Providers whose value cannot be
// fetched are skipped, whereas a value that cannot be decoded aborts the
// query with an error.
func (l *Libp2p) Query(ctx context.Context, key string) ([]*commonproto.Advertisement, error) {
	if key == "" {
		return nil, errors.New("advertisement key is empty")
	}

	customCID, err := createCIDFromKey(key)
	if err != nil {
		return nil, fmt.Errorf("failed to create cid for key %s: %w", key, err)
	}

	providers, err := l.DHT.FindProviders(ctx, customCID)
	if err != nil {
		return nil, fmt.Errorf("failed to find providers for key %s: %w", key, err)
	}

	var found []*commonproto.Advertisement
	for _, provider := range providers {
		// TODO: use go routines to get the values in parallel.
		dhtKey := l.getCustomNamespace(key, provider.ID.String())
		raw, err := l.DHT.GetValue(ctx, dhtKey)
		if err != nil {
			continue
		}

		ad := new(commonproto.Advertisement)
		if err := proto.Unmarshal(raw, ad); err != nil {
			return nil, fmt.Errorf("failed to unmarshal advertisement payload: %w", err)
		}
		found = append(found, ad)
	}
	return found, nil
}
// Advertise publishes data on the DHT under key.
//
// The data is wrapped in a signed Advertisement envelope (peer ID, Unix
// timestamp, data, public key), stored as a DHT value under a key derived
// from the custom namespace, key and this node's peer ID, and the node then
// announces itself as a provider of the CID derived from key.
// An empty key, or a failure in any of these steps, returns an error.
func (l *Libp2p) Advertise(ctx context.Context, key string, data []byte) error {
if key == "" {
return errors.New("advertisement key is empty")
}
pubKeyBytes, err := l.getPublicKey()
if err != nil {
return fmt.Errorf("failed to get public key: %w", err)
}
envelope := &commonproto.Advertisement{
PeerId: l.Host.ID().String(),
Timestamp: time.Now().Unix(),
Data: data,
PublicKey: pubKeyBytes,
}
// Bytes covered by the signature: peer ID, timestamp, data, public key.
// NOTE(review): byte(envelope.Timestamp) keeps only the lowest 8 bits of the
// timestamp, so the signature barely binds it. Changing this would have to be
// mirrored in whatever verifies the signature — confirm against the validator.
concatenatedBytes := bytes.Join([][]byte{
[]byte(envelope.PeerId),
{byte(envelope.Timestamp)},
envelope.Data,
pubKeyBytes,
}, nil)
sig, err := l.sign(concatenatedBytes)
if err != nil {
return fmt.Errorf("failed to sign advertisement envelope content: %w", err)
}
envelope.Signature = sig
envelopeBytes, err := proto.Marshal(envelope)
if err != nil {
return fmt.Errorf("failed to marshal advertise envelope: %w", err)
}
customCID, err := createCIDFromKey(key)
if err != nil {
return fmt.Errorf("failed to create cid for key %s: %w", key, err)
}
// store the envelope as a value first ...
err = l.DHT.PutValue(ctx, l.getCustomNamespace(key, l.DHT.PeerID().String()), envelopeBytes)
if err != nil {
return fmt.Errorf("failed to put key %s into the dht: %w", key, err)
}
// ... then announce this node as a provider so that Query can find it
err = l.DHT.Provide(ctx, customCID, true)
if err != nil {
return fmt.Errorf("failed to provide key %s into the dht: %w", key, err)
}
return nil
}
// Unadvertise withdraws this node's advertisement for key by overwriting its
// DHT value with nil.
func (l *Libp2p) Unadvertise(ctx context.Context, key string) error {
	dhtKey := l.getCustomNamespace(key, l.DHT.PeerID().String())
	if err := l.DHT.PutValue(ctx, dhtKey, nil); err != nil {
		return fmt.Errorf("failed to remove key %s from the DHT: %w", key, err)
	}
	return nil
}
// Publish publishes data to a topic, joining the topic first if needed.
// The requirements are that only one topic handler should exist per topic.
func (l *Libp2p) Publish(ctx context.Context, topic string, data []byte) error {
	handle, err := l.getOrJoinTopicHandler(topic)
	if err != nil {
		return fmt.Errorf("failed to publish: %w", err)
	}

	if err := handle.Publish(ctx, data); err != nil {
		return fmt.Errorf("failed to publish to topic %s: %w", topic, err)
	}
	return nil
}
// Subscribe subscribes to a topic (joining it first if needed) and delivers
// the data of every received message to handler.
//
// Messages are consumed on a background goroutine that calls handler
// sequentially. The goroutine ends when ctx is done or when the subscription
// is cancelled (see Unsubscribe).
func (l *Libp2p) Subscribe(ctx context.Context, topic string, handler func(data []byte)) error {
	topicHandler, err := l.getOrJoinTopicHandler(topic)
	if err != nil {
		return fmt.Errorf("failed to subscribe to topic: %w", err)
	}

	sub, err := topicHandler.Subscribe()
	if err != nil {
		return fmt.Errorf("failed to subscribe to topic %s: %w", topic, err)
	}

	l.topicMux.Lock()
	l.topicSubscription[topic] = sub
	l.topicMux.Unlock()

	go func() {
		for {
			msg, err := sub.Next(ctx)
			if err != nil {
				// Next fails once ctx is done or the subscription has been
				// cancelled. Both conditions are permanent, so retrying would
				// spin forever; stop the goroutine instead.
				if ctx.Err() == nil && !errors.Is(err, pubsub.ErrSubscriptionCancelled) {
					zlog.Sugar().Errorf("subscription to topic %s stopped: %v", topic, err)
				}
				return
			}
			handler(msg.Data)
		}
	}()

	return nil
}
// sendMessage delivers msg to the peer identified by the p2p multiaddress
// addr.
//
// A message addressed to the local host is handed directly to the locally
// registered bytes handler. Otherwise the peer is dialled, a stream for the
// msg.Type protocol is opened, and the payload is written as an 8-byte
// little-endian length prefix followed by msg.Data. Messages whose framed
// size exceeds MaxMessageLengthMB*MB are rejected before any connection is
// attempted.
func (l *Libp2p) sendMessage(ctx context.Context, addr string, msg types.MessageEnvelope) error {
	peerAddr, err := multiaddr.NewMultiaddr(addr)
	if err != nil {
		return fmt.Errorf("invalid multiaddr %s: %w", addr, err)
	}

	peerInfo, err := peer.AddrInfoFromP2pAddr(peerAddr)
	if err != nil {
		return fmt.Errorf("failed to get peer info %s: %w", addr, err)
	}

	// we are delivering a message to ourself
	// we should use the handler to send the message to the handler directly which has been previously registered.
	if peerInfo.ID == l.Host.ID() {
		l.handlerRegistry.SendMessageToLocalHandler(msg.Type, msg.Data)
		return nil
	}

	// validate the size up front so an oversized message never costs a dial
	// or leaves the remote side with an empty stream
	requestBufferSize := 8 + len(msg.Data)
	if requestBufferSize > MaxMessageLengthMB*MB {
		return fmt.Errorf("message size %d is greater than limit %d bytes", requestBufferSize, MaxMessageLengthMB*MB)
	}

	if err := l.Host.Connect(ctx, *peerInfo); err != nil {
		return fmt.Errorf("failed to connect to peer %v: %w", peerInfo.ID, err)
	}

	stream, err := l.Host.NewStream(ctx, peerInfo.ID, protocol.ID(msg.Type))
	if err != nil {
		return fmt.Errorf("failed to open stream to peer %v: %w", peerInfo.ID, err)
	}
	defer stream.Close()

	// frame: [8-byte little-endian payload length][payload]
	requestPayloadWithLength := make([]byte, requestBufferSize)
	binary.LittleEndian.PutUint64(requestPayloadWithLength, uint64(len(msg.Data)))
	copy(requestPayloadWithLength[8:], msg.Data)

	if _, err = stream.Write(requestPayloadWithLength); err != nil {
		return fmt.Errorf("failed to send message to peer %v: %w", peerInfo.ID, err)
	}
	return nil
}
// getOrJoinTopicHandler returns the handle of an already joined topic, or
// joins the topic and caches the new handle. Both publishing and subscribing
// need a topic handle, hence the shared helper.
func (l *Libp2p) getOrJoinTopicHandler(topic string) (*pubsub.Topic, error) {
	l.topicMux.Lock()
	defer l.topicMux.Unlock()

	if existing, found := l.pubsubTopics[topic]; found {
		return existing, nil
	}

	joined, err := l.pubsub.Join(topic)
	if err != nil {
		return nil, fmt.Errorf("failed to join topic %s: %w", topic, err)
	}
	l.pubsubTopics[topic] = joined
	return joined, nil
}
// Unsubscribe cancels the subscription to a topic and closes the topic
// handle. It fails when the topic was never joined or when the handle cannot
// be closed; in the latter case the handle stays cached.
func (l *Libp2p) Unsubscribe(topic string) error {
	l.topicMux.Lock()
	defer l.topicMux.Unlock()

	handle, joined := l.pubsubTopics[topic]
	if !joined {
		return fmt.Errorf("not subscribed to topic: %s", topic)
	}

	// cancel and forget the subscription, if there is one
	if sub, subscribed := l.topicSubscription[topic]; subscribed {
		sub.Cancel()
		delete(l.topicSubscription, topic)
	}

	err := handle.Close()
	if err != nil {
		return fmt.Errorf("failed to close topic handler: %w", err)
	}
	delete(l.pubsubTopics, topic)
	return nil
}
// VisiblePeers returns the list of peers found by discovery. The internal
// slice is returned as is, not a copy.
func (l *Libp2p) VisiblePeers() []peer.AddrInfo {
return l.discoveredPeers
}
// KnownPeers returns every peer present in the host's peerstore. Only the ID
// of each entry is populated; addresses are left empty. The returned error is
// always nil.
func (l *Libp2p) KnownPeers() ([]peer.AddrInfo, error) {
	ids := l.Host.Peerstore().Peers()
	known := make([]peer.AddrInfo, len(ids))
	for i, id := range ids {
		known[i] = peer.AddrInfo{ID: id}
	}
	return known, nil
}
// DumpDHTRoutingTable returns the peer infos currently held in the DHT
// routing table. The returned error is always nil.
func (l *Libp2p) DumpDHTRoutingTable() ([]kbucket.PeerInfo, error) {
	return l.DHT.RoutingTable().GetPeerInfos(), nil
}
// registerStreamHandlers creates the ping service for the host and installs
// its handler for the "/ipfs/ping/1.0.0" protocol.
func (l *Libp2p) registerStreamHandlers() {
l.pingService = ping.NewPingService(l.Host)
l.Host.SetStreamHandler(protocol.ID("/ipfs/ping/1.0.0"), l.pingService.PingHandler)
}
// sign signs data with the host's private key taken from the peerstore.
// It fails when the peerstore holds no private key for the host or when
// signing itself fails.
func (l *Libp2p) sign(data []byte) ([]byte, error) {
	key := l.Host.Peerstore().PrivKey(l.Host.ID())
	if key == nil {
		return nil, errors.New("private key not found for the host")
	}

	sig, err := key.Sign(data)
	if err == nil {
		return sig, nil
	}
	return nil, fmt.Errorf("failed to sign data: %w", err)
}
// getPublicKey returns the raw bytes of the host's public key, derived from
// the private key stored in the peerstore. It fails when the peerstore holds
// no private key for the host.
func (l *Libp2p) getPublicKey() ([]byte, error) {
	key := l.Host.Peerstore().PrivKey(l.Host.ID())
	if key == nil {
		return nil, errors.New("private key not found for the host")
	}
	return key.GetPublic().Raw()
}
// getCustomNamespace builds the DHT value key for an advertisement:
// "<custom namespace>-<key>-<peerID>".
func (l *Libp2p) getCustomNamespace(key, peerID string) string {
	parts := []string{l.config.CustomNamespace, key, peerID}
	return strings.Join(parts, "-")
}
// createCIDFromKey derives a CIDv1 (raw codec) from the SHA-256 digest of
// key, so the same key always maps to the same CID.
func createCIDFromKey(key string) (cid.Cid, error) {
	digest := sha256.Sum256([]byte(key))

	encoded, err := multihash.Encode(digest[:], multihash.SHA2_256)
	if err == nil {
		return cid.NewCidV1(cid.Raw, encoded), nil
	}
	return cid.Cid{}, err
}
// CleanupPeer is a stub: it only logs a warning and returns nil.
func CleanupPeer(id peer.ID) error {
zlog.Warn("CleanupPeer: Stub")
return nil
}
// PingPeer is a stub: it only logs a warning and returns (false, nil).
func PingPeer(ctx context.Context, target peer.ID) (bool, *ping.Result) {
zlog.Warn("PingPeer: Stub")
return false, nil
}
// DumpKademliaDHT is a stub: it only logs a warning and returns (nil, nil).
func DumpKademliaDHT(ctx context.Context) ([]types.PeerData, error) {
zlog.Warn("DumpKademliaDHT: Stub")
return nil, nil
}
// OldPingPeer is a stub: it only logs a warning and returns (false, nil).
func OldPingPeer(ctx context.Context, target peer.ID) (bool, *types.PingResult) {
zlog.Warn("OldPingPeer: Stub")
return false, nil
}
package libp2p
import (
"bytes"
"encoding/base64"
"fmt"
"os"
"path/filepath"
"github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/pnet"
"github.com/spf13/afero"
)
/*
** Swarm key **
By default, the swarm key shall be stored in a file named `swarm.key`
using the following path-based codec:
`/key/swarm/psk/1.0.0/<base_encoding>/<256_bits_key>`
`<base_encoding>` is either bin, base16 or base64.
*/
// getBasePath returns the nunet configuration directory, `<home>/.nunet`,
// where `<home>` is the current user's home directory. The fs argument is
// currently unused.
//
// TODO-pnet-1: we shouldn't handle configuration paths here, a general configuration path
// should be provided by /internal/config.go
func getBasePath(fs afero.Fs) (string, error) {
	home, err := os.UserHomeDir()
	if err == nil {
		return filepath.Join(home, ".nunet"), nil
	}
	return "", fmt.Errorf("error getting home directory: %w", err)
}
// configureSwarmKey tries to read the swarm key from the
// `<config_path>/swarm.key` file. If the key cannot be read for any reason,
// a new one is generated and stored instead.
//
// TODO-ask: should we continue to generate a new swarm key if one is not found?
// Or we should enforce the user to use some cmd/API rpc to generate a new one?
func configureSwarmKey(fs afero.Fs) (pnet.PSK, error) {
	if stored, err := getSwarmKey(fs); err == nil {
		return stored, nil
	}

	fresh, err := generateSwarmKey(fs)
	if err != nil {
		return nil, fmt.Errorf("failed to generate new swarm key: %w", err)
	}
	return fresh, nil
}
// getSwarmKey reads and decodes the swarm key stored in
// `<nunet_config_dir>/swarm.key`. It fails when the configuration directory
// cannot be determined, the file cannot be read, or its content is not a
// valid V1 pre-shared key.
func getSwarmKey(fs afero.Fs) (pnet.PSK, error) {
	nunetDir, err := getBasePath(fs)
	if err != nil {
		// previously this error was silently overwritten by the read below
		return nil, fmt.Errorf("failed to resolve swarm key directory: %w", err)
	}

	swarmkey, err := afero.ReadFile(fs, filepath.Join(nunetDir, "swarm.key"))
	if err != nil {
		return nil, fmt.Errorf("failed to read swarm key file: %w", err)
	}

	psk, err := pnet.DecodeV1PSK(bytes.NewReader(swarmkey))
	if err != nil {
		return nil, fmt.Errorf("failed to configure private network: %w", err)
	}

	// TODO-ask: should we return psk fingerprint?
	return psk, nil
}
// generateSwarmKey generates a new swarm key, storing it within
// `<nunet_config_dir>/swarm.key` (the directory is created if missing), and
// returns the decoded pre-shared key.
//
// The key material is a freshly generated Secp256k1 private key, marshalled
// and base64-encoded behind the `/key/swarm/psk/1.0.0/` header.
func generateSwarmKey(fs afero.Fs) (pnet.PSK, error) {
	priv, _, err := crypto.GenerateKeyPair(crypto.Secp256k1, 256)
	if err != nil {
		return nil, err
	}

	privBytes, err := crypto.MarshalPrivateKey(priv)
	if err != nil {
		return nil, err
	}

	encodedKey := base64.StdEncoding.EncodeToString(privBytes)
	swarmKeyWithCodec := fmt.Sprintf("/key/swarm/psk/1.0.0/\n/base64/\n%s\n", encodedKey)

	// TODO-pnet-1
	nunetDir, err := getBasePath(fs)
	if err != nil {
		return nil, err
	}

	// The config directory may not exist yet on a fresh install; writing the
	// key file would then fail.
	if err := fs.MkdirAll(nunetDir, 0700); err != nil {
		return nil, fmt.Errorf("error creating swarm key directory: %w", err)
	}

	swarmKeyPath := filepath.Join(nunetDir, "swarm.key")
	if err := afero.WriteFile(fs, swarmKeyPath, []byte(swarmKeyWithCodec), 0600); err != nil {
		return nil, fmt.Errorf("error writing swarm key to file: %w", err)
	}

	psk, err := pnet.DecodeV1PSK(bytes.NewReader([]byte(swarmKeyWithCodec)))
	if err != nil {
		return nil, fmt.Errorf("failed to decode generated swarm key: %w", err)
	}

	zlog.Sugar().Infof("A new Swarm key was generated and written to %s\n"+
		"IMPORTANT: If you'd like to create the swarm key using a cryptography algorithm "+
		"of your choice, just modify the swarm.key file with your own key.\n"+
		"The content of `swarm.key` should look like: `/key/swarm/psk/1.0.0/<base_encoding>/<your_key>`\n"+
		"where `<base_encoding>` is either `bin`, `base16`, or `base64`.\n",
		swarmKeyPath,
	)

	return psk, nil
}
package network
import (
"context"
"errors"
"fmt"
"time"
"github.com/spf13/afero"
commonproto "gitlab.com/nunet/device-management-service/proto/generated/v1/common"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/network/libp2p"
)
// Messenger defines the interface for sending messages.
type Messenger interface {
// SendMessage sends msg to the given addresses and reports any failure
// through the returned error.
SendMessage(ctx context.Context, addrs []string, msg types.MessageEnvelope) error
}
// Network is the abstraction every network backend has to implement.
// It embeds Messenger and adds lifecycle, discovery, advertisement and
// publish/subscribe operations.
type Network interface {
// Messenger embedded interface
Messenger
// Init initializes the network
Init(context.Context) error
// Start starts the network
Start(context context.Context) error
// Stat returns the network information
Stat() types.NetworkStats
// Ping pings the given address and returns the PingResult
Ping(ctx context.Context, address string, timeout time.Duration) (types.PingResult, error)
// HandleMessage is responsible for registering a message type and its handler.
HandleMessage(messageType string, handler func(data []byte)) error
// ResolveAddress returns the addresses of the peer identified by id.
// In libp2p, id represents the peerID and the response is the addrinfo
ResolveAddress(ctx context.Context, id string) ([]string, error)
// Advertise advertises the given data under the given key,
// such as advertising device capabilities on the DHT
Advertise(ctx context.Context, key string, data []byte) error
// Unadvertise stops advertising data corresponding to the given key
Unadvertise(ctx context.Context, key string) error
// Query returns the network advertisements stored under the given key
Query(ctx context.Context, key string) ([]*commonproto.Advertisement, error)
// Publish publishes the given data to the given topic if the network
// type allows publish/subscribe functionality such as gossipsub or nats
Publish(ctx context.Context, topic string, data []byte) error
// Subscribe subscribes to the given topic and calls the handler function
// if the network type allows it, similar to Publish()
Subscribe(ctx context.Context, topic string, handler func(data []byte)) error
// Unsubscribe from a topic
Unsubscribe(topic string) error
// Stop stops the network including any existing advertisements and subscriptions
Stop() error
}
// NewNetwork returns a new network given the configuration.
//
// It fails when netConfig is nil, when the requested network type is not
// implemented (NATS) or unknown, or when the backend constructor fails.
func NewNetwork(netConfig *types.NetworkConfig, fs afero.Fs) (Network, error) {
	// TODO: probable additional params to receive: DB, FileSystem
	if netConfig == nil {
		return nil, errors.New("network configuration is nil")
	}

	switch netConfig.Type {
	case types.Libp2pNetwork:
		ln, err := libp2p.New(&netConfig.Libp2pConfig, fs)
		if err != nil {
			// Return an untyped nil: forwarding ln directly could wrap a nil
			// concrete value in a non-nil Network interface.
			return nil, err
		}
		return ln, nil
	case types.NATSNetwork:
		return nil, errors.New("not implemented")
	default:
		return nil, fmt.Errorf("unsupported network type: %s", netConfig.Type)
	}
}
package basic_controller
import (
"context"
"fmt"
"github.com/spf13/afero"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/storage"
"gitlab.com/nunet/device-management-service/utils"
)
// BasicVolumeController is the default implementation of the VolumeController.
// It persists storage volume information through a StorageVolume repository
// and manipulates volume directories through an afero file system.
type BasicVolumeController struct {
// repo is the repository for storage volume operations
repo repositories.StorageVolume
// basePath is the base path where volumes are stored under
basePath string
// FS is the file system to act upon. It is exported, and other packages in
// this file set (the s3 storage code) read it directly.
FS afero.Fs
}
// NewDefaultVolumeController builds a BasicVolumeController from a storage
// volume repository, the base path volumes live under, and the file system
// to operate on. The returned error is currently always nil.
//
// TODO-BugFix [path]: volBasePath is stored as given, without validation or
// normalization; it might not end with `/`.
func NewDefaultVolumeController(repo repositories.StorageVolume, volBasePath string, fs afero.Fs) (*BasicVolumeController, error) {
	controller := new(BasicVolumeController)
	controller.repo = repo
	controller.basePath = volBasePath
	controller.FS = fs
	return controller, nil
}
// CreateVolume creates a new storage volume given a source (S3, IPFS, job, etc). The
// creation of a storage volume effectively creates an empty directory in the local filesystem
// and writes a record in the database.
//
// The directory name follows the format: `<volSource> + "-" + <name>`
// where `name` is a random 16-character string. The directory is placed
// under the controller's base path; a path separator is inserted when the
// base path does not already end with one.
//
// If the database record cannot be created, the freshly created (empty)
// directory is removed again on a best-effort basis.
//
// TODO-maybe [withName]: allow callers to specify custom name for path
func (vc *BasicVolumeController) CreateVolume(volSource storage.VolumeSource, opts ...storage.CreateVolOpt) (types.StorageVolume, error) {
	vol := types.StorageVolume{
		Private:        false,
		ReadOnly:       false,
		EncryptionType: types.EncryptionTypeNull,
	}

	for _, opt := range opts {
		opt(&vol)
	}

	// Guard against a base path without trailing separator, which would
	// otherwise glue the directory name onto the last path element.
	base := vc.basePath
	if base != "" && base[len(base)-1:] != afero.FilePathSeparator {
		base += afero.FilePathSeparator
	}
	vol.Path = base + string(volSource) + "-" + utils.RandomString(16)

	if err := vc.FS.Mkdir(vol.Path, 0770); err != nil {
		return types.StorageVolume{}, fmt.Errorf("failed to create storage volume: %w", err)
	}

	createdVol, err := vc.repo.Create(context.Background(), vol)
	if err != nil {
		// Best effort: do not leave an untracked directory behind. The
		// repository error is the one worth reporting.
		_ = vc.FS.Remove(vol.Path)
		return types.StorageVolume{}, fmt.Errorf("failed to create storage volume in repository: %w", err)
	}

	return createdVol, nil
}
// LockVolume marks the volume stored at pathToVol as read-only: the database
// record is updated first and the directory permissions are then changed to
// 0400. Options may additionally set the CID or flag the volume as private.
// It is meant to be called once all data has been written to the volume.
//
// TODO-maybe [CID]: maybe calculate CID of every volume in case WithCID opt is not provided
func (vc *BasicVolumeController) LockVolume(pathToVol string, opts ...storage.LockVolOpt) error {
	ctx := context.Background()

	q := vc.repo.GetQuery()
	q.Conditions = append(q.Conditions, repositories.EQ("Path", pathToVol))

	target, err := vc.repo.Find(ctx, q)
	if err != nil {
		return fmt.Errorf("failed to find storage volume with path %s - Error: %w", pathToVol, err)
	}

	for i := range opts {
		opts[i](&target)
	}

	// Persist the new state before touching the file system.
	target.ReadOnly = true
	saved, err := vc.repo.Update(ctx, target.ID, target)
	if err != nil {
		return fmt.Errorf("failed to update storage volume with path %s - Error: %w", pathToVol, err)
	}

	// Restrict permissions on disk.
	if chmodErr := vc.FS.Chmod(saved.Path, 0400); chmodErr != nil {
		return fmt.Errorf("failed to make storage volume read-only (path: %s): %w", saved.Path, chmodErr)
	}
	return nil
}
// WithPrivate returns an option that flags a volume as private. The type
// parameter selects whether the result is used as a create option or as a
// lock option.
func WithPrivate[T storage.CreateVolOpt | storage.LockVolOpt]() T {
	return func(vol *types.StorageVolume) {
		vol.Private = true
	}
}
// WithCID returns a lock option that stores the given, already calculated,
// CID on the volume.
//
// TODO [validate]: check if CID provided is valid
func WithCID(cid string) storage.LockVolOpt {
	return func(vol *types.StorageVolume) {
		vol.CID = cid
	}
}
// DeleteVolume removes the database record of a storage volume. The volume is
// looked up either by its path or by its CID, depending on idType; any other
// idType is rejected.
//
// Note [CID]: if we start to type CID as cid.CID, we may have to use generics here
// as in `[T string | cid.CID]`
func (vc *BasicVolumeController) DeleteVolume(identifier string, idType storage.IDType) error {
	ctx := context.Background()
	q := vc.repo.GetQuery()

	// Pick the record field the identifier refers to.
	var field string
	switch idType {
	case storage.IDTypePath:
		field = "Path"
	case storage.IDTypeCID:
		field = "CID"
	default:
		return fmt.Errorf("identifier type not supported")
	}
	q.Conditions = append(q.Conditions, repositories.EQ(field, identifier))

	found, err := vc.repo.Find(ctx, q)
	switch {
	case err == repositories.NotFoundError:
		return fmt.Errorf("volume not found: %w", err)
	case err != nil:
		return fmt.Errorf("failed to find volume: %w", err)
	}

	if delErr := vc.repo.Delete(ctx, found.ID); delErr != nil {
		return fmt.Errorf("failed to delete volume: %w", delErr)
	}
	return nil
}
// ListVolumes fetches every storage volume record from the repository using
// an unfiltered query.
//
// TODO [filter]: maybe add opts to filter results by certain values
func (vc *BasicVolumeController) ListVolumes() ([]types.StorageVolume, error) {
	all, err := vc.repo.FindAll(context.Background(), vc.repo.GetQuery())
	if err == nil {
		return all, nil
	}
	return nil, fmt.Errorf("failed to list volumes: %w", err)
}
// GetSize looks up a volume by path or CID (selected by idType) and returns
// the size of its directory as computed by utils.GetDirectorySize.
// TODO-minor: identify which measurement type will be used
func (vc *BasicVolumeController) GetSize(identifier string, idType storage.IDType) (int64, error) {
	q := vc.repo.GetQuery()

	// Pick the record field the identifier refers to.
	var field string
	switch idType {
	case storage.IDTypePath:
		field = "Path"
	case storage.IDTypeCID:
		field = "CID"
	default:
		return 0, fmt.Errorf("unsupported ID type: %d", idType)
	}
	q.Conditions = append(q.Conditions, repositories.EQ(field, identifier))

	found, err := vc.repo.Find(context.Background(), q)
	if err != nil {
		return 0, fmt.Errorf("failed to find volume: %w", err)
	}

	total, err := utils.GetDirectorySize(vc.FS, found.Path)
	if err != nil {
		return 0, fmt.Errorf("failed to get directory size: %w", err)
	}
	return total, nil
}
// EncryptVolume is a placeholder for volume encryption; it always returns a
// "not implemented" error.
func (vc *BasicVolumeController) EncryptVolume(path string, encryptor types.Encryptor, encryptionType types.EncryptionType) error {
	const reason = "not implemented"
	return fmt.Errorf(reason)
}
// DecryptVolume is a placeholder for volume decryption; it always returns a
// "not implemented" error.
func (vc *BasicVolumeController) DecryptVolume(path string, decryptor types.Decryptor, decryptionType types.EncryptionType) error {
	const reason = "not implemented"
	return fmt.Errorf(reason)
}
// Compile-time check that *BasicVolumeController implements storage.VolumeController.
var _ storage.VolumeController = (*BasicVolumeController)(nil)
package basic_controller
import (
"context"
"fmt"
"os"
"testing"
clover "github.com/ostafen/clover/v2"
"github.com/spf13/afero"
"gitlab.com/nunet/device-management-service/db/repositories/clover"
"gitlab.com/nunet/device-management-service/types"
)
// VolControllerTestSuiteHelper bundles everything a volume controller test
// needs: the controller under test, the file system and database it runs on,
// the volumes it was seeded with, and the temporary database directory that
// must be removed on teardown.
type VolControllerTestSuiteHelper struct {
BasicVolController *BasicVolumeController
Fs afero.Fs
DB *clover.DB
Volumes map[string]*types.StorageVolume
TempDBDir string
}
// SetupVolControllerTestSuite sets up a volume controller with 0-n volumes given a base path.
// If volumes are provided, their directories are created on an in-memory file
// system and their records are stored in a Clover database located in a
// temporary directory.
//
// On success a cleanup is registered on t that closes the database and
// removes the temporary directory. On failure everything created so far is
// released before the error is returned.
func SetupVolControllerTestSuite(t *testing.T, basePath string,
	volumes map[string]*types.StorageVolume) (*VolControllerTestSuiteHelper, error) {
	t.Helper()

	tempDir, err := os.MkdirTemp("", "clover-test-*")
	if err != nil {
		return nil, fmt.Errorf("failed to create temp directory: %w", err)
	}

	db, err := repositories_clover.NewDB(tempDir, []string{"storage_volume"})
	if err != nil {
		os.RemoveAll(tempDir)
		return nil, fmt.Errorf("failed to open clover db: %w", err)
	}

	// abort releases the database and its directory. RemoveAll is required
	// because the directory is no longer empty once the database was opened;
	// plain os.Remove would fail and leak it.
	abort := func() {
		db.Close()
		os.RemoveAll(tempDir)
	}

	fs := afero.NewMemMapFs()
	if err := fs.MkdirAll(basePath, 0755); err != nil {
		abort()
		return nil, fmt.Errorf("failed to create base path: %w", err)
	}

	repo := repositories_clover.NewStorageVolume(db)
	vc, err := NewDefaultVolumeController(repo, basePath, fs)
	if err != nil {
		abort()
		return nil, fmt.Errorf("failed to create volume controller: %w", err)
	}

	for _, vol := range volumes {
		// create root volume dir
		if err := fs.MkdirAll(vol.Path, 0755); err != nil {
			abort()
			return nil, fmt.Errorf("failed to create volume dir: %w", err)
		}

		// create volume record in db
		if _, err := repo.Create(context.Background(), *vol); err != nil {
			abort()
			return nil, fmt.Errorf("failed to create volume record: %w", err)
		}
	}

	helper := &VolControllerTestSuiteHelper{vc, fs, db, volumes, tempDir}
	t.Cleanup(func() {
		TearDownVolControllerTestSuite(helper)
	})

	return helper, nil
}
// TearDownVolControllerTestSuite releases what the setup created: it closes
// the database when one is set and removes the temporary database directory
// when its path is non-empty.
func TearDownVolControllerTestSuite(helper *VolControllerTestSuiteHelper) {
	if db := helper.DB; db != nil {
		db.Close()
	}
	if dir := helper.TempDBDir; dir != "" {
		os.RemoveAll(dir)
	}
}
package s3
import (
"context"
"fmt"
"os"
"path/filepath"
"strings"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/spf13/afero"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/storage"
"gitlab.com/nunet/device-management-service/storage/basic_controller"
)
// Download fetches files from a given S3 bucket into a newly created storage
// volume and locks the volume afterwards. The key may be a directory ending
// with `/` or have a wildcard (`*`) so it handles normal S3 folders but it does
// not handle x-directory.
//
// On any failure an empty StorageVolume is returned together with an error
// that wraps the underlying cause.
// NOTE(review): a volume created before a later step fails is not removed
// here — confirm whether callers are expected to clean it up.
//
// Warning: the implementation should rely on the FS provided by the volume controller,
// be careful if managing files with `os` (the volume controller might be
// using an in-memory one)
func (s *S3Storage) Download(ctx context.Context, sourceSpecs *types.SpecConfig) (
	types.StorageVolume, error) {
	source, err := DecodeInputSpec(sourceSpecs)
	if err != nil {
		return types.StorageVolume{}, err
	}

	storageVol, err := s.volController.CreateVolume(storage.VolumeSourceS3)
	if err != nil {
		return types.StorageVolume{}, fmt.Errorf("failed to create storage volume: %w", err)
	}

	resolvedObjects, err := resolveStorageKey(ctx, s.Client, &source)
	if err != nil {
		return types.StorageVolume{}, fmt.Errorf("failed to resolve storage key: %w", err)
	}

	for _, resolvedObject := range resolvedObjects {
		if err := s.downloadObject(ctx, &source, resolvedObject, storageVol.Path); err != nil {
			return types.StorageVolume{}, fmt.Errorf("failed to download s3 object: %w", err)
		}
	}

	// after data is filled within the volume, we have to lock it
	if err := s.volController.LockVolume(storageVol.Path); err != nil {
		return types.StorageVolume{}, fmt.Errorf("failed to lock storage volume: %w", err)
	}

	return storageVol, nil
}
// downloadObject stores a single resolved S3 object below volPath, using the
// file system of the volume controller.
//
// Directory objects only result in the directory being created. For regular
// objects the parent directory is created, the file is opened (created if
// missing) and the object content is downloaded into it. An error is returned
// when the controller exposes no file system, when the key would resolve
// outside of volPath, or when any file system or download step fails.
func (s *S3Storage) downloadObject(ctx context.Context, source *S3InputSource,
	object s3Object, volPath string) error {
	// use the same file system instance used by the Volume Controller
	basicVolController, ok := s.volController.(*basic_controller.BasicVolumeController)
	if !ok || basicVolController.FS == nil {
		return fmt.Errorf("volume controller does not provide a file system")
	}
	fs := basicVolController.FS

	key := aws.ToString(object.key)
	outputPath := filepath.Join(volPath, key)

	// Object keys come from the bucket; never write outside of the volume.
	rel, err := filepath.Rel(volPath, outputPath)
	if err != nil || rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) {
		return fmt.Errorf("object key %q resolves outside of the volume", key)
	}

	if object.isDir {
		// if object is a directory, we don't need to download it (just create the dir)
		if err := fs.MkdirAll(outputPath, 0755); err != nil {
			return fmt.Errorf("failed to create directory: %w", err)
		}
		return nil
	}

	// Create the parent directory only; creating outputPath itself as a
	// directory would make it impossible to open it as a file.
	if err := fs.MkdirAll(filepath.Dir(outputPath), 0755); err != nil {
		return fmt.Errorf("failed to create directory: %w", err)
	}

	outputFile, err := fs.OpenFile(outputPath, os.O_RDWR|os.O_CREATE, 0755)
	if err != nil {
		return err
	}
	defer outputFile.Close()

	zlog.Sugar().Debugf("Downloading s3 object %s to %s", key, outputPath)
	_, err = s.downloader.Download(ctx, outputFile, &s3.GetObjectInput{
		Bucket:  aws.String(source.Bucket),
		Key:     object.key,
		IfMatch: object.eTag,
	})
	if err != nil {
		return fmt.Errorf("failed to download object %s: %w", key, err)
	}

	return nil
}
// resolveStorageKey lists the S3 objects the source key refers to. An empty
// key is rejected. A key ending in `/` or containing `*` is resolved as a
// prefix covering multiple objects; any other key is resolved as one object.
func resolveStorageKey(ctx context.Context, client *s3.Client, source *S3InputSource) ([]s3Object, error) {
	key := source.Key

	switch {
	case key == "":
		return nil, fmt.Errorf("key is required")
	case strings.HasSuffix(key, "/"), strings.Contains(key, "*"):
		// key represents multiple objects
		return resolveObjectsWithPrefix(ctx, client, source)
	default:
		// key represents a single object
		return resolveSingleObject(ctx, client, source)
	}
}
// resolveSingleObject fetches the metadata of the single object addressed by
// the (sanitized) source key and returns it as a one-element list.
//
// Objects whose content type starts with `application/x-directory` are
// rejected. Missing optional response fields (content type, content length)
// are treated as empty / zero instead of causing a nil dereference. The
// returned key is the sanitized key that was actually checked.
func resolveSingleObject(ctx context.Context, client *s3.Client, source *S3InputSource) ([]s3Object, error) {
	key := sanitizeKey(source.Key)

	headObjectInput := &s3.HeadObjectInput{
		Bucket: aws.String(source.Bucket),
		Key:    aws.String(key),
	}

	headObjectOut, err := client.HeadObject(ctx, headObjectInput)
	if err != nil {
		return []s3Object{}, fmt.Errorf("failed to retrieve object metadata: %v", err)
	}

	// TODO-minor: validate checksum if provided
	if strings.HasPrefix(aws.ToString(headObjectOut.ContentType), "application/x-directory") {
		return []s3Object{}, fmt.Errorf("x-directory is not yet handled!")
	}

	return []s3Object{
		{
			key:  aws.String(key),
			eTag: headObjectOut.ETag,
			size: aws.ToInt64(headObjectOut.ContentLength),
		},
	}, nil
}
// resolveObjectsWithPrefix lists, page by page, every object in the bucket
// whose key starts with the sanitized source key. Entries whose key ends in
// `/` are flagged as directories.
func resolveObjectsWithPrefix(ctx context.Context, client *s3.Client, source *S3InputSource) ([]s3Object, error) {
	prefix := sanitizeKey(source.Key)

	// List objects with the given prefix
	pages := s3.NewListObjectsV2Paginator(client, &s3.ListObjectsV2Input{
		Bucket: aws.String(source.Bucket),
		Prefix: aws.String(prefix),
	})

	var found []s3Object
	for pages.HasMorePages() {
		page, err := pages.NextPage(ctx)
		if err != nil {
			return nil, fmt.Errorf("failed to list objects: %v", err)
		}

		for i := range page.Contents {
			entry := page.Contents[i]
			name := *entry.Key
			found = append(found, s3Object{
				key:   aws.String(name),
				size:  *entry.Size,
				isDir: strings.HasSuffix(name, "/"),
			})
		}
	}

	return found, nil
}
package s3
import (
"context"
"strings"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
)
// GetAWSDefaultConfig loads the AWS configuration with the SDK defaults
// (environment variables, shared configuration and shared credentials files)
// and no additional load options.
func GetAWSDefaultConfig() (aws.Config, error) {
	return config.LoadDefaultConfig(context.Background())
}
// hasValidCredentials reports whether the given AWS config can produce
// credentials that carry keys. It returns false when no credentials provider
// is configured or when retrieving the credentials fails.
func hasValidCredentials(config aws.Config) bool {
	// A config without provider would otherwise panic on Retrieve.
	if config.Credentials == nil {
		return false
	}

	credentials, err := config.Credentials.Retrieve(context.Background())
	if err != nil {
		return false
	}
	return credentials.HasKeys()
}
// sanitizeKey strips surrounding whitespace from key and then drops a single
// trailing `*` wildcard, if present.
func sanitizeKey(key string) string {
	trimmed := strings.TrimSpace(key)
	if strings.HasSuffix(trimmed, "*") {
		return trimmed[:len(trimmed)-1]
	}
	return trimmed
}
package s3
import (
"github.com/uptrace/opentelemetry-go-extra/otelzap"
"gitlab.com/nunet/device-management-service/telemetry/logger"
)
// zlog is the package-level logger of the s3 package.
var zlog *otelzap.Logger
// init assigns zlog a logger obtained from logger.OtelZapLogger for the
// name "s3".
func init() {
zlog = logger.OtelZapLogger("s3")
}
package s3
import (
"context"
"fmt"
"github.com/aws/aws-sdk-go-v2/aws"
s3Manager "github.com/aws/aws-sdk-go-v2/feature/s3/manager"
"github.com/aws/aws-sdk-go-v2/service/s3"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/storage"
)
// S3Storage couples an S3 SDK client with the volume controller that manages
// the local volumes being downloaded to or uploaded from.
type S3Storage struct {
// Client is the embedded S3 SDK client; its methods are promoted.
*s3.Client
// volController manages the storage volumes acted upon.
volController storage.VolumeController
// downloader is the SDK download manager built on the same client.
downloader *s3Manager.Downloader
// uploader is the SDK upload manager built on the same client.
uploader *s3Manager.Uploader
}
// s3Object describes one object resolved from a bucket before it is
// downloaded.
type s3Object struct {
// key is the object key inside the bucket.
key *string
// eTag is the entity tag reported for the object; it is passed as the
// IfMatch condition when downloading. It is only set for single-object
// lookups in the visible code.
eTag *string
// versionID is not set anywhere in the visible code.
versionID *string
// size is the object size as reported by S3.
size int64
// isDir is true when the key ends with `/`.
isDir bool
}
// NewClient builds an S3Storage around a fresh S3 SDK client created from
// config, together with a downloader and an uploader sharing that client.
// The given VolumeController manages the volumes being acted upon.
// It fails with "invalid credentials" when the config carries no usable
// credentials.
func NewClient(config aws.Config, volController storage.VolumeController) (*S3Storage, error) {
	if !hasValidCredentials(config) {
		return nil, fmt.Errorf("invalid credentials")
	}

	client := s3.NewFromConfig(config)
	store := &S3Storage{
		Client:        client,
		volController: volController,
		downloader:    s3Manager.NewDownloader(client),
		uploader:      s3Manager.NewUploader(client),
	}
	return store, nil
}
// Size returns the size, in bytes as reported by S3, of the single object
// addressed by the bucket and key of the given source spec.
//
// It fails when the spec cannot be decoded, when the metadata request fails,
// or when the response carries no (or a negative) content length.
func (s *S3Storage) Size(ctx context.Context, source *types.SpecConfig) (uint64, error) {
	inputSource, err := DecodeInputSpec(source)
	if err != nil {
		return 0, fmt.Errorf("failed to decode input spec: %w", err)
	}

	input := &s3.HeadObjectInput{
		Bucket: aws.String(inputSource.Bucket),
		Key:    aws.String(inputSource.Key),
	}

	output, err := s.HeadObject(ctx, input)
	if err != nil {
		return 0, fmt.Errorf("failed to get object size: %w", err)
	}

	// ContentLength is optional in the response; avoid a nil dereference and
	// a wrap-around when converting to an unsigned value.
	if output.ContentLength == nil {
		return 0, fmt.Errorf("failed to get object size: content length missing in response")
	}
	size := *output.ContentLength
	if size < 0 {
		return 0, fmt.Errorf("failed to get object size: invalid content length %d", size)
	}

	return uint64(size), nil
}
// Compile time interface check
// var _ storage.StorageProvider = (*S3Storage)(nil)
package s3
import (
"fmt"
"github.com/fatih/structs"
"github.com/mitchellh/mapstructure"
"gitlab.com/nunet/device-management-service/types"
)
// S3InputSource holds the parameters of an S3 storage source or destination
// as decoded from a spec's params.
type S3InputSource struct {
// Bucket is the bucket name; Validate rejects an empty value.
Bucket string
// Key is the object key, a prefix ending in `/`, or a key with a `*`
// wildcard.
Key string
Filter string
Region string
Endpoint string
}
// Validate checks the input source. The only requirement enforced is a
// non-empty bucket name.
func (s S3InputSource) Validate() error {
	if s.Bucket != "" {
		return nil
	}
	return fmt.Errorf("invalid s3 storage params: bucket cannot be empty")
}
// ToMap converts the input source into a generic map using structs.Map.
func (s S3InputSource) ToMap() map[string]interface{} {
	asMap := structs.Map(s)
	return asMap
}
// DecodeInputSpec turns a generic spec config into an S3InputSource.
//
// The spec must be non-nil, of type types.StorageProviderS3, and carry
// non-nil params. The params are decoded with mapstructure and the result is
// validated. On a decoding failure a zero S3InputSource is returned together
// with the wrapped error.
func DecodeInputSpec(spec *types.SpecConfig) (S3InputSource, error) {
	if spec == nil {
		return S3InputSource{}, fmt.Errorf("invalid storage input source spec. cannot be nil")
	}

	if !spec.IsType(types.StorageProviderS3) {
		return S3InputSource{}, fmt.Errorf("invalid storage source type. Expected %s but received %s", types.StorageProviderS3, spec.Type)
	}

	if spec.Params == nil {
		return S3InputSource{}, fmt.Errorf("invalid storage input source params. cannot be nil")
	}

	var c S3InputSource
	if err := mapstructure.Decode(spec.Params, &c); err != nil {
		return S3InputSource{}, fmt.Errorf("failed to decode s3 input source params: %w", err)
	}

	return c, c.Validate()
}
package s3
import (
"context"
"fmt"
"os"
"path/filepath"
"github.com/spf13/afero"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/s3"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/storage/basic_controller"
)
// Upload uploads all files (recursively) from a local volume to an S3 bucket.
// Directories are traversed but not uploaded themselves; each file is stored
// under `<sanitized key>/<path relative to the volume>`, always using `/` as
// separator in the object key.
//
// It fails when the destination spec cannot be decoded, when the volume
// controller exposes no file system, or when walking/uploading fails. In the
// latter case some files may already have been uploaded.
//
// Warning: the implementation should rely on the FS provided by the volume controller,
// be careful if managing files with `os` (the volume controller might be
// using an in-memory one)
func (s *S3Storage) Upload(ctx context.Context, vol types.StorageVolume,
	destinationSpecs *types.SpecConfig) error {
	target, err := DecodeInputSpec(destinationSpecs)
	if err != nil {
		return fmt.Errorf("failed to decode input spec: %w", err)
	}

	sanitizedKey := sanitizeKey(target.Key)

	// set file system to act upon based on the volume controller implementation
	basicVolController, ok := s.volController.(*basic_controller.BasicVolumeController)
	if !ok || basicVolController.FS == nil {
		return fmt.Errorf("volume controller does not provide a file system")
	}
	fs := basicVolController.FS

	zlog.Sugar().Debugf("Uploading files from %s to s3://%s/%s", vol.Path, target.Bucket, sanitizedKey)
	err = afero.Walk(fs, vol.Path, func(filePath string, info os.FileInfo, err error) error {
		if err != nil {
			return err
		}

		// Skip directories
		if info.IsDir() {
			return nil
		}

		relPath, err := filepath.Rel(vol.Path, filePath)
		if err != nil {
			return fmt.Errorf("failed to get relative path: %w", err)
		}

		// Construct the S3 key by joining the sanitized key and the relative
		// path. Object keys use forward slashes regardless of the host OS.
		s3Key := filepath.ToSlash(filepath.Join(sanitizedKey, relPath))

		file, err := fs.Open(filePath)
		if err != nil {
			return fmt.Errorf("failed to open file: %w", err)
		}
		// Runs when this callback returns, i.e. once per visited file.
		defer file.Close()

		zlog.Sugar().Debugf("Uploading %s to s3://%s/%s", filePath, target.Bucket, s3Key)
		_, err = s.uploader.Upload(ctx, &s3.PutObjectInput{
			Bucket: aws.String(target.Bucket),
			Key:    aws.String(s3Key),
			Body:   file,
		})
		if err != nil {
			return fmt.Errorf("failed to upload file to S3: %w", err)
		}

		return nil
	})
	if err != nil {
		return fmt.Errorf("upload failed. It's possible that some files were uploaded; Error: %w", err)
	}

	return nil
}
package utils
import (
"bytes"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"time"
"github.com/cosmos/btcutil/bech32"
"github.com/ethereum/go-ethereum/common"
"github.com/fivebinaries/go-cardano-serialization/address"
"gitlab.com/nunet/device-management-service/db"
"gitlab.com/nunet/device-management-service/types"
)
// KoiosEndpoint is the host name of a Koios REST API endpoint. The functions
// in this file insert it into `https://<endpoint>/api/v1/...` URLs.
type KoiosEndpoint string
const (
// KoiosMainnet is the mainnet Koios REST API endpoint.
KoiosMainnet KoiosEndpoint = "api.koios.rest"
// KoiosPreProd is the testnet (preprod) Koios REST API endpoint.
KoiosPreProd KoiosEndpoint = "preprod.koios.rest"
)
// UTXOs is one entry of the Koios `address_utxos` response as decoded by
// GetUTXOsOfSmartContract.
type UTXOs struct {
TxHash string `json:"tx_hash"`
IsSpent bool `json:"is_spent"`
}
// TxHashResp is one element of the list returned by GetJobTxHashes: the
// transaction hash, its transaction type and the creation time of the
// service record, formatted as a string.
type TxHashResp struct {
TxHash string `json:"tx_hash"`
TransactionType string `json:"transaction_type"`
DateTime string `json:"date_time"`
}
// ClaimCardanoTokenBody is the input of RequestReward. Only TxHash is read
// by the code visible in this file.
type ClaimCardanoTokenBody struct {
ComputeProviderAddress string `json:"compute_provider_address"`
TxHash string `json:"tx_hash"`
}
// rewardRespToCPD is the result of RequestReward. Every field is copied from
// the matching service record; RewardType carries the record's transaction
// type.
type rewardRespToCPD struct {
ServiceProviderAddr string `json:"service_provider_addr"`
ComputeProviderAddr string `json:"compute_provider_addr"`
RewardType string `json:"reward_type,omitempty"`
SignatureDatum string `json:"signature_datum,omitempty"`
MessageHashDatum string `json:"message_hash_datum,omitempty"`
Datum string `json:"datum,omitempty"`
SignatureAction string `json:"signature_action,omitempty"`
MessageHashAction string `json:"message_hash_action,omitempty"`
Action string `json:"action,omitempty"`
}
// UpdateTxStatusBody is the input of UpdateStatus. Address is the address
// whose UTXOs are fetched from Koios.
type UpdateTxStatusBody struct {
Address string `json:"address,omitempty"`
}
// GetJobTxHashes returns the transaction hashes of stored services.
//
// clean must be one of "done", "refund", "withdraw" or ""; service records
// whose transaction type equals clean are deleted first (a failure of that
// deletion is only logged).
// NOTE(review): with clean == "" this deletes records whose transaction type
// is the empty string — confirm this is intended.
//
// With size == 0 all matching services are returned; otherwise the result is
// produced by getLimitedTransactions(size).
func GetJobTxHashes(size int, clean string) ([]TxHashResp, error) {
	if clean != "done" && clean != "refund" && clean != "withdraw" && clean != "" {
		return nil, fmt.Errorf("invalid clean_tx parameter")
	}

	err := db.DB.Where("transaction_type = ?", clean).Delete(&types.Services{}).Error
	if err != nil {
		// %w is only understood by fmt.Errorf; loggers need %v.
		zlog.Sugar().Errorf("%v", err)
	}

	var resp []TxHashResp
	var services []types.Services
	if size == 0 {
		err = db.DB.
			Where("tx_hash IS NOT NULL").
			Where("log_url LIKE ?", "%log.nunet.io%").
			Where("transaction_type is NOT NULL").
			Find(&services).Error
		if err != nil {
			zlog.Sugar().Errorf("%v", err)
			return nil, fmt.Errorf("no job deployed to request reward for: %w", err)
		}
	} else {
		services, err = getLimitedTransactions(size)
		if err != nil {
			zlog.Sugar().Errorf("%v", err)
			return nil, fmt.Errorf("could not get limited transactions: %w", err)
		}
	}

	for _, service := range services {
		resp = append(resp, TxHashResp{
			TxHash:          service.TxHash,
			TransactionType: service.TransactionType,
			DateTime:        service.CreatedAt.String(),
		})
	}
	return resp, nil
}
// RequestReward looks up the service stored for the claimed transaction hash
// and returns the reward data recorded for it. It fails when the lookup
// returns an error or when the job status is still "running".
//
// At some point, management dashboard should send container ID to identify
// against which container we are requesting reward.
func RequestReward(claim ClaimCardanoTokenBody) (*rewardRespToCPD, error) {
	service := types.Services{TxHash: claim.TxHash}

	// SELECTs the first record; first record which is not marked as delete
	if err := db.DB.Where("tx_hash = ?", claim.TxHash).Find(&service).Error; err != nil {
		zlog.Sugar().Errorln(err)
		return nil, fmt.Errorf("unknown tx hash: %w", err)
	}
	zlog.Sugar().Infof("service found from txHash: %+v", service)

	if service.JobStatus == "running" {
		return nil, fmt.Errorf("job is still running")
	}

	return &rewardRespToCPD{
		ServiceProviderAddr: service.ServiceProviderAddr,
		ComputeProviderAddr: service.ComputeProviderAddr,
		RewardType:          service.TransactionType,
		SignatureDatum:      service.SignatureDatum,
		MessageHashDatum:    service.MessageHashDatum,
		Datum:               service.Datum,
		SignatureAction:     service.SignatureAction,
		MessageHashAction:   service.MessageHashAction,
		Action:              service.Action,
	}, nil
}
// SendStatus records a successful blockchain transaction locally and echoes
// the transaction status back.
//
// When the status is "success", the service stored for the transaction hash
// is loaded and its transaction type is set to "done". If loading fails the
// record is left untouched (nothing is saved); a failing save is logged.
// The returned value is always status.TransactionStatus.
func SendStatus(status types.BlockchainTxStatus) string {
	if status.TransactionStatus != "success" {
		return status.TransactionStatus
	}

	zlog.Sugar().Infof("withdraw transaction successful - updating DB")

	// Partial deletion of entry
	var service types.Services
	if err := db.DB.Where("tx_hash = ?", status.TxHash).Find(&service).Error; err != nil {
		zlog.Sugar().Errorln(err)
		// Do not persist a zero-valued record after a failed lookup.
		return status.TransactionStatus
	}

	service.TransactionType = "done"
	if err := db.DB.Save(&service).Error; err != nil {
		zlog.Sugar().Errorln(err)
	}

	return status.TransactionStatus
}
// UpdateStatus synchronizes the local transaction status with the blockchain.
//
// It fetches the UTXOs of body.Address from the Koios preprod endpoint,
// selects the stored services that are older than five minutes and whose
// transaction type is set but neither "done" nor empty, and passes both to
// UpdateTransactionStatus. Every returned error wraps its cause.
func UpdateStatus(body UpdateTxStatusBody) error {
	utxoHashes, err := GetUTXOsOfSmartContract(body.Address, KoiosPreProd)
	if err != nil {
		zlog.Sugar().Errorln(err)
		return fmt.Errorf("failed to fetch UTXOs from Blockchain: %w", err)
	}

	fiveMinAgo := time.Now().Add(-5 * time.Minute)
	var services []types.Services
	err = db.DB.
		Where("tx_hash IS NOT NULL").
		Where("log_url LIKE ?", "%log.nunet.io%").
		Where("transaction_type IS NOT NULL").
		Where("deleted_at IS NULL").
		Where("created_at <= ?", fiveMinAgo).
		Not("transaction_type = ?", "done").
		Not("transaction_type = ?", "").
		Find(&services).Error
	if err != nil {
		zlog.Sugar().Errorln(err)
		return fmt.Errorf("no job deployed to request reward for: %w", err)
	}

	if err := UpdateTransactionStatus(services, utxoHashes); err != nil {
		zlog.Sugar().Errorln(err)
		return fmt.Errorf("failed to update transaction status: %w", err)
	}

	return nil
}
// getLimitedTransactions returns every service whose transaction type is set
// and is neither "done" nor empty, followed by at most sizeDone of the most
// recently created "done" services. On a query error an empty slice is
// returned together with the error.
func getLimitedTransactions(sizeDone int) ([]types.Services, error) {
	// Newest finished transactions, capped at sizeDone.
	var finished []types.Services
	if err := db.DB.
		Where("tx_hash IS NOT NULL").
		Where("log_url LIKE ?", "%log.nunet.io%").
		Where("transaction_type = ?", "done").
		Order("created_at DESC").
		Limit(sizeDone).
		Find(&finished).Error; err != nil {
		return []types.Services{}, err
	}

	// All transactions that are not finished yet.
	var pending []types.Services
	if err := db.DB.
		Where("tx_hash IS NOT NULL").
		Where("log_url LIKE ?", "%log.nunet.io%").
		Where("transaction_type IS NOT NULL").
		Not("transaction_type = ?", "done").
		Not("transaction_type = ?", "").
		Find(&pending).Error; err != nil {
		return []types.Services{}, err
	}

	return append(pending, finished...), nil
}
// isValidCardano sets *valid to true when addr parses as a cardano address.
// A panic raised while parsing is recovered and reported as *valid = false;
// on a plain parse error *valid is left unchanged.
func isValidCardano(addr string, valid *bool) {
	defer func() {
		if recover() != nil {
			*valid = false
		}
	}()

	_, err := address.NewAddress(addr)
	if err != nil {
		return
	}
	*valid = true
}
// ValidateAddress accepts only valid cardano wallet addresses. Addresses in
// ethereum hex form are rejected explicitly; anything else that does not
// parse as a cardano address is reported as invalid.
func ValidateAddress(addr string) error {
	if common.IsHexAddress(addr) {
		return errors.New("ethereum wallet address not allowed")
	}

	var ok bool
	isValidCardano(addr, &ok)
	if !ok {
		return errors.New("invalid cardano wallet address")
	}
	return nil
}
// GetAddressPaymentCredential decodes a bech32 address, converts its payload
// from 5-bit to 8-bit groups and returns hex characters 2 to 57 of the result
// (28 bytes following the first byte).
//
// It fails when the address cannot be decoded or when the decoded payload is
// too short to contain that range.
func GetAddressPaymentCredential(addr string) (string, error) {
	_, data, err := bech32.Decode(addr, 1023)
	if err != nil {
		return "", fmt.Errorf("decoding bech32 failed: %w", err)
	}

	converted, err := bech32.ConvertBits(data, 5, 8, false)
	if err != nil {
		return "", fmt.Errorf("decoding bech32 failed: %w", err)
	}

	// Slicing a shorter string would panic; report it as an error instead.
	encoded := hex.EncodeToString(converted)
	if len(encoded) < 58 {
		return "", fmt.Errorf("decoded address too short: got %d hex characters, need at least 58", len(encoded))
	}

	return encoded[2:58], nil
}
// GetTxReceiver returns the receiver of a transaction given its hash.
//
// It posts the hash to the `tx_info` API of the given Koios endpoint and
// reads the receiver from the second field of the inline datum of the
// second output of the first returned transaction. When the response does
// not contain that element, "unable to find receiver" is returned.
func GetTxReceiver(txHash string, endpoint KoiosEndpoint) (string, error) {
	type Request struct {
		TxHashes []string `json:"_tx_hashes"`
	}
	reqBody, err := json.Marshal(Request{TxHashes: []string{txHash}})
	if err != nil {
		return "", fmt.Errorf("failed to encode request body: %w", err)
	}

	resp, err := http.Post(
		fmt.Sprintf("https://%s/api/v1/tx_info", endpoint),
		"application/json",
		bytes.NewBuffer(reqBody))
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	res := []struct {
		Outputs []struct {
			InlineDatum struct {
				Value struct {
					Fields []struct {
						Bytes string `json:"bytes"`
					} `json:"fields"`
				} `json:"value"`
			} `json:"inline_datum"`
		} `json:"outputs"`
	}{}
	jsonDecoder := json.NewDecoder(resp.Body)
	// An empty body (io.EOF) is handled by the length checks below.
	if err := jsonDecoder.Decode(&res); err != nil && err != io.EOF {
		return "", err
	}

	// Index 1 is read on both levels, so at least two elements are required.
	if len(res) == 0 || len(res[0].Outputs) < 2 {
		return "", fmt.Errorf("unable to find receiver")
	}
	fields := res[0].Outputs[1].InlineDatum.Value.Fields
	if len(fields) < 2 {
		return "", fmt.Errorf("unable to find receiver")
	}

	return fields[1].Bytes, nil
}
// GetTxConfirmations returns the number of confirmations of a transaction
// given its hash.
//
// It posts the hash to the `tx_status` API of the given Koios endpoint and
// returns the confirmation count of the last entry of the response. An error
// is returned when the request or decoding fails, or when the response
// contains no entry.
func GetTxConfirmations(txHash string, endpoint KoiosEndpoint) (int, error) {
	type Request struct {
		TxHashes []string `json:"_tx_hashes"`
	}
	reqBody, err := json.Marshal(Request{TxHashes: []string{txHash}})
	if err != nil {
		return 0, fmt.Errorf("failed to encode request body: %w", err)
	}

	resp, err := http.Post(
		fmt.Sprintf("https://%s/api/v1/tx_status", endpoint),
		"application/json",
		bytes.NewBuffer(reqBody))
	if err != nil {
		return 0, err
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return 0, err
	}

	var res []struct {
		TxHash        string `json:"tx_hash"`
		Confirmations int    `json:"num_confirmations"`
	}
	if err := json.Unmarshal(body, &res); err != nil {
		return 0, err
	}

	// Indexing an empty result would panic.
	if len(res) == 0 {
		return 0, fmt.Errorf("no status returned for transaction %s", txHash)
	}

	return res[len(res)-1].Confirmations, nil
}
// WaitForTxConfirmation polls the confirmation count of a transaction every
// five seconds until it reaches the requested number of confirmations.
//
// It returns nil once enough confirmations are seen, the polling error if a
// poll fails, and a "timeout" error when timeout elapses first. The timeout
// covers the whole wait, not each individual poll interval.
func WaitForTxConfirmation(confirmations int, timeout time.Duration, txHash string, endpoint KoiosEndpoint) error {
	ticker := time.NewTicker(5 * time.Second)
	defer ticker.Stop()

	// Created once: a timer created inside the loop would be reset by every
	// tick and could never fire for timeouts longer than the tick interval.
	deadline := time.NewTimer(timeout)
	defer deadline.Stop()

	for {
		select {
		case <-ticker.C:
			conf, err := GetTxConfirmations(txHash, endpoint)
			if err != nil {
				return err
			}
			if conf >= confirmations {
				return nil
			}
		case <-deadline.C:
			return errors.New("timeout")
		}
	}
}
// GetUTXOsOfSmartContract posts the given address to the `address_utxos` API
// of the Koios endpoint (extended mode) and returns the transaction hash of
// every UTXO in the response. An empty response body yields no hashes.
func GetUTXOsOfSmartContract(address string, endpoint KoiosEndpoint) ([]string, error) {
	type Request struct {
		Address  []string `json:"_addresses"`
		Extended bool     `json:"_extended"`
	}

	payload, _ := json.Marshal(Request{Address: []string{address}, Extended: true})
	target := fmt.Sprintf("https://%s/api/v1/address_utxos", endpoint)

	resp, err := http.Post(target, "application/json", bytes.NewBuffer(payload))
	if err != nil {
		return nil, fmt.Errorf("error making POST request: %v", err)
	}
	defer resp.Body.Close()

	var entries []UTXOs
	if decodeErr := json.NewDecoder(resp.Body).Decode(&entries); decodeErr != nil && decodeErr != io.EOF {
		return nil, decodeErr
	}

	var hashes []string
	for i := range entries {
		hashes = append(hashes, entries[i].TxHash)
	}
	return hashes, nil
}
// UpdateTransactionStatus saves every service whose transaction hash is not
// among utxoHashes. Before saving, the transaction type is mapped:
// "withdraw" becomes the withdrawn status, "refund" the refunded status and
// "distribute-50"/"distribute-75" the distributed status; other types are
// saved unchanged. The first save error aborts the loop and is returned.
func UpdateTransactionStatus(services []types.Services, utxoHashes []string) error {
	for i := range services {
		svc := services[i]
		if SliceContains(utxoHashes, svc.TxHash) {
			continue
		}

		switch svc.TransactionType {
		case "withdraw":
			svc.TransactionType = transactionWithdrawnStatus
		case "refund":
			svc.TransactionType = transactionRefundedStatus
		case "distribute-50", "distribute-75":
			svc.TransactionType = transactionDistributedStatus
		}

		if err := db.DB.Save(&svc).Error; err != nil {
			return err
		}
	}
	return nil
}
package utils
import (
"fmt"
"os"
"github.com/spf13/afero"
)
// GetDirectorySize walks path on the given filesystem and returns the summed
// size of every non-directory entry found. Any error reported during the walk
// is wrapped and returned together with a size of 0.
func GetDirectorySize(fs afero.Fs, path string) (int64, error) {
	var total int64
	accumulate := func(_ string, entry os.FileInfo, walkErr error) error {
		if walkErr != nil {
			return walkErr
		}
		if entry.IsDir() {
			return nil
		}
		total += entry.Size()
		return nil
	}
	if err := afero.Walk(fs, path, accumulate); err != nil {
		return 0, fmt.Errorf("failed to calculate volume size: %w", err)
	}
	return total, nil
}
package utils
import (
"github.com/uptrace/opentelemetry-go-extra/otelzap"
"gitlab.com/nunet/device-management-service/telemetry/logger"
)
// zlog is the logger shared by the utils package; it is assigned in init.
var zlog *otelzap.Logger

// Status values written to a service's TransactionType by
// UpdateTransactionStatus.
const (
	transactionWithdrawnStatus   = "withdrawn"
	transactionRefundedStatus    = "refunded"
	transactionDistributedStatus = "distributed"
)

// init creates the package logger under the name "utils".
func init() {
	zlog = logger.OtelZapLogger("utils")
}
package utils
import (
"bytes"
"fmt"
"io"
"net/http"
"net/url"
"github.com/gin-gonic/gin"
"gitlab.com/nunet/device-management-service/internal/config"
)
// InternalAPIURL builds a URL that points at this process's own REST API on
// localhost, using the port from the loaded configuration.
//
// protocol must be "http" or "ws" and endpoint must be non-empty; query is
// used verbatim as the raw query string. An error is returned when either
// required value is missing, the protocol is not supported, or the configured
// port is 0.
func InternalAPIURL(protocol, endpoint, query string) (string, error) {
	switch {
	case protocol == "", endpoint == "":
		return "", fmt.Errorf("protocol and endpoint values must be specified")
	case protocol != "http" && protocol != "ws":
		return "", fmt.Errorf("invalid protocol: %s", protocol)
	}

	port := config.GetConfig().Rest.Port
	if port == 0 {
		return "", fmt.Errorf("port is not configured")
	}

	target := url.URL{
		Scheme:   protocol,
		Host:     fmt.Sprintf("localhost:%d", port),
		Path:     endpoint,
		RawQuery: query,
	}
	return target.String(), nil
}
// MakeInternalRequest sends an HTTP request to this process's own API and
// returns the raw response; the caller is responsible for closing its body.
//
// methodType is the HTTP method, internalEndpoint and query are resolved
// through InternalAPIURL, and body is sent as the request payload. When c is
// non-nil the headers of the incoming request are forwarded. The JSON
// Content-Type and Accept defaults are applied only where the forwarded
// headers do not already provide a value. Previously the defaults were set
// first and then discarded when the incoming headers replaced them.
func MakeInternalRequest(c *gin.Context, methodType, internalEndpoint, query string, body []byte) (*http.Response, error) {
	endpoint, err := InternalAPIURL("http", internalEndpoint, query)
	if err != nil {
		return nil, err
	}
	req, err := http.NewRequest(
		methodType,
		endpoint,
		bytes.NewBuffer(body),
	)
	if err != nil {
		return nil, err
	}
	if c != nil && c.Request != nil {
		req.Header = c.Request.Header.Clone()
	}
	// Header.Clone returns nil for a nil header; keep the map writable.
	if req.Header == nil {
		req.Header = make(http.Header)
	}
	if req.Header.Get("Content-Type") == "" {
		req.Header.Set("Content-Type", "application/json; charset=utf-8")
	}
	if req.Header.Get("Accept") == "" {
		req.Header.Set("Accept", "application/json")
	}
	client := http.Client{}
	resp, err := client.Do(req)
	if err != nil {
		return nil, err
	}
	return resp, nil
}
// MakeRequest sends body as a JSON PUT request to uri using client and reports
// whether the request could be built and sent.
//
// The response is not inspected, but its body is closed so the underlying
// connection is released (it was previously leaked). c and errMsg are
// currently unused and are kept so existing callers continue to compile.
func MakeRequest(c *gin.Context, client *http.Client, uri string, body []byte, errMsg string) error {
	req, err := http.NewRequest(http.MethodPut, uri, bytes.NewBuffer(body))
	if err != nil {
		return fmt.Errorf("unable to make new request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json; charset=utf-8")
	req.Header.Set("Accept", "application/json")
	resp, err := client.Do(req)
	if err != nil {
		return err
	}
	// The close error is ignored on purpose: the response is discarded.
	_ = resp.Body.Close()
	return nil
}
// ResponseBody performs an internal API call through MakeInternalRequest and
// returns the complete response body. It returns an error when the request
// fails or the body cannot be read. The response body is always closed.
func ResponseBody(c *gin.Context, methodType, internalEndpoint, query string, body []byte) (responseBody []byte, errMsg error) {
	resp, reqErr := MakeInternalRequest(c, methodType, internalEndpoint, query, body)
	if reqErr != nil {
		return nil, fmt.Errorf("unable to make internal request: %v", reqErr)
	}
	defer resp.Body.Close()

	payload, readErr := io.ReadAll(resp.Body)
	if readErr != nil {
		return nil, fmt.Errorf("cannot read response body: %v", readErr)
	}
	return payload, nil
}
package utils
import (
"io"
"sync"
"time"
)
// IOProgress records how far a tracked transfer has advanced.
type IOProgress struct {
	n         float64   // bytes transferred so far
	size      float64   // expected total in bytes; a negative value means unknown
	started   time.Time // set when the tracking wrapper is created
	estimated time.Time // stored completion estimate; Remaining recomputes while zero
	err       error     // error returned by the most recent Read/Write
}

// Reader wraps an io.Reader and records transfer progress in Progress.
//
// NOTE(review): Read updates Progress while holding lock, but Progress is an
// exported field that can be read without it; confirm how callers synchronise.
type Reader struct {
	reader   io.Reader
	lock     sync.RWMutex
	Progress IOProgress
}

// Writer wraps an io.Writer and records transfer progress in Progress.
//
// NOTE(review): Write updates Progress while holding lock, but Progress is an
// exported field that can be read without it; confirm how callers synchronise.
type Writer struct {
	writer   io.Writer
	lock     sync.RWMutex
	Progress IOProgress
}

// ReaderWithProgress wraps r in a Reader whose progress is measured against
// size bytes, starting the clock now.
func ReaderWithProgress(r io.Reader, size int64) *Reader {
	return &Reader{
		reader:   r,
		Progress: IOProgress{started: time.Now(), size: float64(size)},
	}
}

// WriterWithProgress wraps w in a Writer whose progress is measured against
// size bytes, starting the clock now.
func WriterWithProgress(w io.Writer, size int64) *Writer {
	return &Writer{
		writer:   w,
		Progress: IOProgress{started: time.Now(), size: float64(size)},
	}
}

// Read reads from the wrapped reader, then adds the byte count to the
// progress and stores the returned error.
func (r *Reader) Read(p []byte) (n int, err error) {
	n, err = r.reader.Read(p)
	r.lock.Lock()
	r.Progress.n += float64(n)
	r.Progress.err = err
	r.lock.Unlock()
	return n, err
}

// Write writes to the wrapped writer, then adds the byte count to the
// progress and stores the returned error.
func (w *Writer) Write(p []byte) (n int, err error) {
	n, err = w.writer.Write(p)
	w.lock.Lock()
	w.Progress.n += float64(n)
	w.Progress.err = err
	w.lock.Unlock()
	return n, err
}

// Size returns the expected total number of bytes.
func (p IOProgress) Size() float64 {
	return p.size
}

// N returns the number of bytes transferred so far.
func (p IOProgress) N() float64 {
	return p.n
}

// Complete reports whether the transfer has finished: the last operation
// returned io.EOF, or the total is known and has been reached.
func (p IOProgress) Complete() bool {
	if p.err == io.EOF {
		return true
	}
	if p.size == -1 {
		return false
	}
	return p.n >= p.size
}

// Percent calculates the percentage complete, in the range 0 to 100.
//
// When the total is unknown (negative size) no ratio can be computed, so the
// result is 0 until io.EOF has been seen and 100 afterwards. This matches
// Complete; previously any transferred byte reported 100.
func (p IOProgress) Percent() float64 {
	if p.n == 0 {
		return 0
	}
	if p.size < 0 {
		if p.err == io.EOF {
			return 100
		}
		return 0
	}
	if p.n >= p.size {
		return 100
	}
	return 100.0 / (p.size / p.n)
}

// Remaining returns the time left until the estimated completion time. The
// result is negative when that time lies in the past or no estimate exists.
func (p IOProgress) Remaining() time.Duration {
	if p.estimated.IsZero() {
		return time.Until(p.Estimated())
	}
	return time.Until(p.estimated)
}

// Estimated extrapolates the completion time from the elapsed time and the
// fraction transferred so far. When nothing has been transferred yet, or the
// total is unknown (negative size), the stored estimate is returned, which is
// the zero time unless it was set elsewhere.
func (p IOProgress) Estimated() time.Time {
	if p.n > 0.0 && p.size >= 0 {
		ratio := p.n / p.size
		elapsed := float64(time.Since(p.started))
		return p.started.Add(time.Duration(elapsed / ratio))
	}
	return p.estimated
}
package utils
import (
"fmt"
"strings"
"sync"
)
// A SyncMap is a concurrency-safe sync.Map that uses strongly-typed
// method signatures to ensure the types of its stored data are known.
//
// The embedded sync.Map stays accessible; values stored through it with other
// types make the typed accessors panic on their type assertion.
type SyncMap[K comparable, V any] struct {
	sync.Map
}

// SyncMapFromMap converts a standard Go map to a concurrency-safe SyncMap.
func SyncMapFromMap[K comparable, V any](m map[K]V) *SyncMap[K, V] {
	ret := &SyncMap[K, V]{}
	for k, v := range m {
		ret.Put(k, v)
	}
	return ret
}

// Get retrieves the value associated with the given key from the map.
// It returns the value and a boolean indicating whether the key was found;
// for a missing key the zero value of V is returned.
func (m *SyncMap[K, V]) Get(key K) (V, bool) {
	value, ok := m.Load(key)
	if !ok {
		var empty V
		return empty, false
	}
	return value.(V), true
}

// Put inserts or updates a key-value pair in the map.
func (m *SyncMap[K, V]) Put(key K, value V) {
	m.Store(key, value)
}

// Iter iterates over each key-value pair in the map, executing the provided function on each pair.
// The iteration stops if the provided function returns false.
func (m *SyncMap[K, V]) Iter(ranger func(key K, value V) bool) {
	m.Range(func(key, value any) bool {
		return ranger(key.(K), value.(V))
	})
}

// Keys returns a slice containing all the keys present in the map, in the
// order Range visits them. The result is nil for an empty map.
func (m *SyncMap[K, V]) Keys() []K {
	var keys []K
	m.Iter(func(key K, _ V) bool {
		keys = append(keys, key)
		return true
	})
	return keys
}

// String provides a string representation of the map in the form
// "{k1=v1, k2=v2}", listing entries in the order Range visits them.
//
// Keys and values are formatted with %v so that non-string types render
// correctly (with %s an int value was printed as "%!s(int=1)"), and entries
// are separated by ", " (they were previously run together).
func (m *SyncMap[K, V]) String() string {
	var sb strings.Builder
	sb.WriteByte('{')
	first := true
	m.Range(func(key, value any) bool {
		if !first {
			sb.WriteString(", ")
		}
		first = false
		fmt.Fprintf(&sb, "%v=%v", key, value)
		return true
	})
	sb.WriteByte('}')
	return sb.String()
}
package utils
import (
"archive/tar"
"bufio"
"compress/gzip"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"log"
"math/rand"
"net/http"
"os"
"path/filepath"
"strings"
"time"
"github.com/google/uuid"
"gitlab.com/nunet/device-management-service/db"
"gitlab.com/nunet/device-management-service/types"
"golang.org/x/exp/slices"
"reflect"
)
// Default download URLs and local paths for the kernel image and the root
// filesystem image. They are variables, so they can be overridden at runtime.
var (
	// KernelFileURL is the URL of the vmlinux kernel image.
	KernelFileURL = "https://d.nunet.io/fc/vmlinux"
	// KernelFilePath is the local path of the kernel image.
	KernelFilePath = "/etc/nunet/vmlinux"
	// FilesystemURL is the URL of the ext4 filesystem image.
	FilesystemURL = "https://d.nunet.io/fc/nunet-fc-ubuntu-20.04-0.ext4"
	// FilesystemPath is the local path of the filesystem image.
	FilesystemPath = "/etc/nunet/nunet-fc-ubuntu-20.04-0.ext4"
)
// DownloadFile downloads a file from a url and saves it to a filepath.
//
// The request is made before the destination is created, so a failed request
// no longer leaves an empty or truncated file behind. A non-2xx response is
// reported as an error instead of having its body saved as the file. An error
// from closing the written file is returned when nothing else failed.
func DownloadFile(url string, filepath string) (err error) {
	// Infof needs format verbs; without them the arguments were rendered as
	// "%!(EXTRA ...)" noise in the log.
	zlog.Sugar().Infof("Downloading file '%s' from '%s'", filepath, url)
	resp, err := http.Get(url)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		return fmt.Errorf("downloading %s: unexpected status %s", url, resp.Status)
	}
	file, err := os.Create(filepath)
	if err != nil {
		return err
	}
	defer func() {
		if closeErr := file.Close(); closeErr != nil && err == nil {
			err = closeErr
		}
	}()
	if _, err = io.Copy(file, resp.Body); err != nil {
		return err
	}
	log.Println("Finished downloading file '", filepath, "'")
	return nil
}
// ReadHttpString issues a GET request to url and returns the whole response
// body as a string. The status code is not inspected. An empty string and the
// error are returned when the request or the read fails.
func ReadHttpString(url string) (string, error) {
	resp, err := http.Get(url)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	var content strings.Builder
	if _, copyErr := io.Copy(&content, resp.Body); copyErr != nil {
		return "", copyErr
	}
	return content.String(), nil
}
// RandomString generates a random string of length n made of ASCII letters
// (a-z, A-Z), drawn from the package-level math/rand generator.
//
// A non-positive n yields the empty string; a negative n previously panicked
// inside strings.Builder.Grow. The result is not suitable for secrets.
func RandomString(n int) string {
	const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
	if n <= 0 {
		return ""
	}
	buf := make([]byte, n)
	for i := range buf {
		buf[i] = charset[rand.Intn(len(charset))]
	}
	return string(buf)
}
// GenerateMachineUUID generates a machine uuid
func GenerateMachineUUID() (string, error) {
var machine types.MachineUUID
machineUUID, err := uuid.NewDCEGroup()
if err != nil {
return "", err
}
machine.UUID = machineUUID.String()
return machine.UUID, nil
}
// GetMachineUUID returns the machine uuid from the DB
func GetMachineUUID() string {
var machine types.MachineUUID
uuid, err := GenerateMachineUUID()
if err != nil {
zlog.Sugar().Errorf("could not generate machine uuid: %v", err)
}
machine.UUID = uuid
result := db.DB.FirstOrCreate(&machine)
if result.Error != nil {
zlog.Sugar().Errorf("could not find or create machine uuid record in DB: %v", result.Error)
}
return machine.UUID
}
// SliceContains reports whether str is equal to at least one element of s.
// A nil or empty slice never contains anything.
func SliceContains(s []string, str string) bool {
	for i := 0; i < len(s); i++ {
		if s[i] == str {
			return true
		}
	}
	return false
}
// DeleteFile removes the file at path. When backup is true the file is not
// deleted but renamed to "<path>.bk.<unix-seconds>"; otherwise it is removed.
// The error from the rename or removal is returned.
func DeleteFile(path string, backup bool) (err error) {
	if !backup {
		return os.Remove(path)
	}
	backupPath := fmt.Sprintf("%s.bk.%d", path, time.Now().Unix())
	return os.Rename(path, backupPath)
}
// ReadyForElastic reports whether the ElasticToken loaded from the database
// has both a NodeId and a ChannelName set. The result of the database lookup
// itself is not checked.
func ReadyForElastic() bool {
	var token types.ElasticToken
	db.DB.Find(&token)
	if token.NodeId == "" {
		return false
	}
	return token.ChannelName != ""
}
// PromptYesNo writes "<prompt> (y/N): " to out and reads one line from in,
// repeating until the trimmed, lower-cased answer is "y"/"yes" (true) or
// "n"/"no" (false). Any other answer, including an empty line, asks again.
// A read failure, such as end of input, is wrapped and returned.
func PromptYesNo(in io.Reader, out io.Writer, prompt string) (bool, error) {
	input := bufio.NewReader(in)
	for {
		fmt.Fprintf(out, "%s (y/N): ", prompt)
		line, err := input.ReadString('\n')
		if err != nil {
			return false, fmt.Errorf("read response string failed: %w", err)
		}
		switch strings.ToLower(strings.TrimSpace(line)) {
		case "y", "yes":
			return true, nil
		case "n", "no":
			return false, nil
		}
	}
}
// CreateDirectoryIfNotExists creates path, including missing parents, with
// mode 0755 when os.Stat reports that it does not exist. In every other case
// (the path exists, or Stat failed for a different reason) nothing is done and
// nil is returned.
func CreateDirectoryIfNotExists(path string) error {
	_, statErr := os.Stat(path)
	if !os.IsNotExist(statErr) {
		return nil
	}
	return os.MkdirAll(path, 0755)
}
// CalculateSHA256Checksum streams the file at filePath through SHA-256 and
// returns the digest as a lower-case hexadecimal string. An empty string and
// the error are returned when the file cannot be opened or read.
func CalculateSHA256Checksum(filePath string) (string, error) {
	src, err := os.Open(filePath)
	if err != nil {
		return "", err
	}
	defer src.Close()

	hasher := sha256.New()
	if _, copyErr := io.Copy(hasher, src); copyErr != nil {
		return "", copyErr
	}
	return hex.EncodeToString(hasher.Sum(nil)), nil
}
// CreateCheckSumFile writes checksum to "<filePath>.sha256.txt", creating or
// truncating that file, and returns the path of the checksum file. An empty
// path and an error are returned when the file cannot be created or written.
func CreateCheckSumFile(filePath string, checksum string) (string, error) {
	checksumPath := filePath + ".sha256.txt"
	out, err := os.Create(checksumPath)
	if err != nil {
		return "", fmt.Errorf("unable to create SHA-256 checksum file: %v", err)
	}
	defer out.Close()

	if _, writeErr := out.WriteString(checksum); writeErr != nil {
		return "", fmt.Errorf("unable to write to SHA-256 checksum file: %v", writeErr)
	}
	return checksumPath, nil
}
// ExtractTarGzToPath extracts a tar.gz file to a specified path.
//
// extractedPath is created if needed. Directory entries are created as
// directories; every other entry is written as a regular file, with missing
// parent directories created on the way. An entry whose name would resolve
// outside extractedPath (for example "../x") is rejected with an error and
// extraction stops, since archive contents are not trusted. Each extracted
// file is closed as soon as it has been written, so large archives no longer
// hold every file open until the function returns.
func ExtractTarGzToPath(tarGzFilePath, extractedPath string) error {
	// Ensure the target directory exists; create it if it doesn't.
	if err := os.MkdirAll(extractedPath, os.ModePerm); err != nil {
		return fmt.Errorf("error creating target directory: %w", err)
	}
	tarGzFile, err := os.Open(tarGzFilePath)
	if err != nil {
		return fmt.Errorf("error opening tar.gz file: %w", err)
	}
	defer tarGzFile.Close()
	gzipReader, err := gzip.NewReader(tarGzFile)
	if err != nil {
		return fmt.Errorf("error creating gzip reader: %w", err)
	}
	defer gzipReader.Close()

	// writeEntry stores one archive entry at target and closes the file
	// before returning, keeping at most one output file open at a time.
	writeEntry := func(target string, contents io.Reader) error {
		if err := os.MkdirAll(filepath.Dir(target), os.ModePerm); err != nil {
			return fmt.Errorf("error creating directory: %w", err)
		}
		newFile, err := os.Create(target)
		if err != nil {
			return fmt.Errorf("error creating file: %w", err)
		}
		if _, err := io.Copy(newFile, contents); err != nil {
			newFile.Close()
			return fmt.Errorf("error copying file contents: %w", err)
		}
		if err := newFile.Close(); err != nil {
			return fmt.Errorf("error closing file: %w", err)
		}
		return nil
	}

	root := filepath.Clean(extractedPath)
	tarReader := tar.NewReader(gzipReader)
	for {
		header, err := tarReader.Next()
		if err == io.EOF {
			return nil
		}
		if err != nil {
			return fmt.Errorf("error reading tar header: %w", err)
		}
		// Join cleans the result, so ".." components are resolved here and
		// the relative path below reveals any escape from the target root.
		fullTargetPath := filepath.Join(root, header.Name)
		rel, err := filepath.Rel(root, fullTargetPath)
		if err != nil || rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) {
			return fmt.Errorf("illegal path %q in archive: entry escapes %q", header.Name, extractedPath)
		}
		if header.FileInfo().IsDir() {
			if err := os.MkdirAll(fullTargetPath, os.ModePerm); err != nil {
				return fmt.Errorf("error creating directory: %w", err)
			}
			continue
		}
		if err := writeEntry(fullTargetPath, tarReader); err != nil {
			return err
		}
	}
}
// CheckWSL reports whether any line of /proc/version contains "Microsoft" or
// "WSL". An error is returned when the file cannot be opened or scanned.
func CheckWSL() (bool, error) {
	versionFile, err := os.Open("/proc/version")
	if err != nil {
		return false, err
	}
	defer versionFile.Close()

	lines := bufio.NewScanner(versionFile)
	for lines.Scan() {
		text := lines.Text()
		for _, marker := range []string{"Microsoft", "WSL"} {
			if strings.Contains(text, marker) {
				return true, nil
			}
		}
	}
	if scanErr := lines.Err(); scanErr != nil {
		return false, scanErr
	}
	return false, nil
}
// SaveServiceInfo overwrites the locally stored service that has the same
// tx_hash as cpService with the values carried by cpService.
//
// The stored record is looked up first so that its ID and CreatedAt are
// carried over into cpService, and the record is then updated. An error is
// returned when either the lookup or the update reports one.
func SaveServiceInfo(cpService types.Services) error {
	var stored types.Services
	lookup := db.DB.Model(&types.Services{}).Where("tx_hash = ?", cpService.TxHash).Find(&stored)
	if lookup.Error != nil {
		return fmt.Errorf("Unable to find service on SP side: %v", lookup.Error)
	}

	cpService.ID = stored.ID
	cpService.CreatedAt = stored.CreatedAt

	update := db.DB.Model(&types.Services{}).Where("tx_hash = ?", cpService.TxHash).Updates(&cpService)
	if updateErr := update.Error; updateErr != nil {
		return fmt.Errorf("Unable to update service info on SP side: %v", updateErr.Error())
	}
	return nil
}
// RandomBool returns true or false, chosen by a generator seeded with the
// current time in nanoseconds.
//
// A private generator is used instead of re-seeding the package-level one on
// every call: rand.Seed is deprecated, and re-seeding the shared generator
// reset the random sequence seen by every other user of math/rand. For a
// given seed the value returned is the same as before.
func RandomBool() bool {
	generator := rand.New(rand.NewSource(time.Now().UnixNano()))
	return generator.Intn(2) == 1
}
// IsExecutorType reports whether the dynamic type of v is types.ExecutorType.
func IsExecutorType(v interface{}) bool {
	switch v.(type) {
	case types.ExecutorType:
		return true
	}
	return false
}
// IsGPUVendor reports whether the dynamic type of v is types.GPUVendor.
func IsGPUVendor(v interface{}) bool {
	switch v.(type) {
	case types.GPUVendor:
		return true
	}
	return false
}
// IsJobType reports whether the dynamic type of v is types.JobType.
func IsJobType(v interface{}) bool {
	switch v.(type) {
	case types.JobType:
		return true
	}
	return false
}
// IsJobTypes reports whether the dynamic type of v is types.JobTypes.
func IsJobTypes(v interface{}) bool {
	switch v.(type) {
	case types.JobTypes:
		return true
	}
	return false
}
// IsExecutor reports whether the dynamic type of v is types.Executor.
func IsExecutor(v interface{}) bool {
	switch v.(type) {
	case types.Executor:
		return true
	}
	return false
}
// IsStrictlyContained reports whether every element of rightSlice is also
// present in leftSlice. An empty rightSlice yields false, not true.
func IsStrictlyContained(leftSlice, rightSlice []interface{}) bool {
	if len(rightSlice) == 0 {
		return false
	}
	for _, wanted := range rightSlice {
		if !slices.Contains(leftSlice, wanted) {
			return false
		}
	}
	return true
}
// NoIntersectionSlices reports whether slice1 and slice2 have no element in
// common.
//
// The check stops at the first shared element. Previously the loop kept going
// and the result reflected only the last element of slice1, so a shared
// element followed by an unshared one was reported as "no intersection". As
// before, an empty slice1 yields false.
func NoIntersectionSlices(slice1, slice2 []interface{}) bool {
	if len(slice1) == 0 {
		return false
	}
	for _, element := range slice1 {
		if slices.Contains(slice2, element) {
			return false
		}
	}
	return true
}
// IntersectionSlices returns the elements of slice2 that also occur in
// slice1, in slice2 order and without duplicates. The result is an empty,
// non-nil slice when nothing is shared.
func IntersectionSlices(slice1, slice2 []interface{}) []interface{} {
	// pending holds the elements of slice1 that have not been matched yet.
	pending := make(map[interface{}]bool)
	for i := range slice1 {
		pending[slice1[i]] = true
	}

	common := []interface{}{}
	for _, candidate := range slice2 {
		if !pending[candidate] {
			continue
		}
		common = append(common, candidate)
		// Dropping the entry keeps a repeated element from being added twice.
		delete(pending, candidate)
	}
	return common
}
// IsSameShallowType reports whether a and b have the same dynamic type as
// given by reflect.TypeOf. Two nil interfaces compare as the same type.
func IsSameShallowType(a, b interface{}) bool {
	return reflect.TypeOf(a) == reflect.TypeOf(b)
}
// ConvertTypedSliceToUntypedSlice copies the elements of a slice of any
// element type into a new []interface{} of the same length. It returns nil
// when typedSlice is not a slice (including when it is nil).
func ConvertTypedSliceToUntypedSlice(typedSlice interface{}) []interface{} {
	source := reflect.ValueOf(typedSlice)
	if source.Kind() != reflect.Slice {
		return nil
	}
	count := source.Len()
	untyped := make([]interface{}, 0, count)
	for i := 0; i < count; i++ {
		untyped = append(untyped, source.Index(i).Interface())
	}
	return untyped
}
package validate
import (
"reflect"
)
// ConvertNumericToFloat64 converts a value of one of Go's built-in integer or
// floating-point types to float64. The boolean is false, and the number 0,
// for any other type, including named types based on a numeric type.
func ConvertNumericToFloat64(n any) (float64, bool) {
	switch num := n.(type) {
	case float64:
		return num, true
	case float32:
		return float64(num), true
	case int, int8, int16, int32, int64:
		signed := reflect.ValueOf(num).Int()
		return float64(signed), true
	case uint, uint8, uint16, uint32, uint64:
		unsigned := reflect.ValueOf(num).Uint()
		return float64(unsigned), true
	}
	return 0, false
}
package validate
import (
"strings"
)
// IsBlank reports whether s is empty or consists only of whitespace, as
// defined by strings.TrimSpace.
func IsBlank(s string) bool {
	return strings.TrimSpace(s) == ""
}
// IsNotBlank reports whether s contains at least one non-whitespace
// character; it is the negation of IsBlank.
func IsNotBlank(s string) bool {
	blank := IsBlank(s)
	return !blank
}
// IsLiteral reports whether the dynamic type of s is string.
func IsLiteral(s interface{}) bool {
	_, isString := s.(string)
	return isString
}