package actor
import (
"context"
"encoding/json"
"errors"
"fmt"
"sync"
bt "gitlab.com/nunet/device-management-service/internal/background_tasks"
"gitlab.com/nunet/device-management-service/network"
"gitlab.com/nunet/device-management-service/types"
)
// BasicActor is the Actor implementation declared in this file. It couples a
// message Dispatch with a network transport, a security context and a rate
// limiter.
type BasicActor struct {
// dispatch queues incoming envelopes and runs the registered behaviors.
dispatch *Dispatch
// scheduler is started together with the dispatch in Start.
scheduler *bt.Scheduler
// registry is created by New through newRegistry.
registry *registry
// network carries outgoing messages, broadcasts and topic subscriptions.
network network.Network
// security signs outgoing envelopes and supplies nonces.
security SecurityContext
// limiter is consulted before incoming messages are handed to the dispatch.
limiter RateLimiter
params BasicActorParams
// self is the handle this actor is addressed by.
self Handle
// mx is held by Subscribe while it reads and updates subscriptions.
mx sync.Mutex
// subscriptions maps a topic to the subscription ID returned by the network.
subscriptions map[string]uint64
}
// BasicActorParams carries construction parameters for a BasicActor; it
// currently has no fields.
type BasicActorParams struct{}
// Compile-time check that *BasicActor implements Actor.
var _ Actor = (*BasicActor)(nil)
// New creates a new basic actor.
//
// scheduler, net and security are mandatory; an error is returned when any of
// them is nil. limiter is optional: when it is nil no rate-limiter option is
// forced onto the dispatch, and the actor adopts whatever limiter the dispatch
// ends up with (its built-in default unless one of opt overrides it). This
// keeps the actor and its dispatch on the same limiter and avoids a nil
// dereference on the first incoming message.
//
// opt is forwarded to NewDispatch after the limiter option, so callers can
// still override dispatch settings.
func New(scheduler *bt.Scheduler, net network.Network, security *BasicSecurityContext, limiter RateLimiter, params BasicActorParams, self Handle, opt ...DispatchOption) (*BasicActor, error) {
	if scheduler == nil {
		return nil, errors.New("scheduler is nil")
	}
	if net == nil {
		return nil, errors.New("network is nil")
	}
	if security == nil {
		return nil, errors.New("security is nil")
	}

	dispatchOptions := make([]DispatchOption, 0, len(opt)+1)
	if limiter != nil {
		dispatchOptions = append(dispatchOptions, WithRateLimiter(limiter))
	}
	dispatchOptions = append(dispatchOptions, opt...)
	dispatch := NewDispatch(security, dispatchOptions...)

	if limiter == nil {
		// Share the dispatch's limiter so Allow (here) and Acquire/Release
		// (in the dispatch) operate on the same instance.
		limiter = dispatch.options.Limiter
	}

	actor := &BasicActor{
		dispatch:      dispatch,
		scheduler:     scheduler,
		registry:      newRegistry(),
		network:       net,
		security:      security,
		limiter:       limiter,
		params:        params,
		self:          self,
		subscriptions: make(map[string]uint64),
	}
	return actor, nil
}
// Start registers the actor's inbox handler with the network and then starts
// the dispatch and the scheduler. If the handler cannot be registered, nothing
// is started and the wrapped error is returned.
func (a *BasicActor) Start() error {
	// Protocol on which messages addressed to this actor's inbox arrive.
	protocol := fmt.Sprintf("actor/%s/messages/0.0.1", a.self.Address.InboxAddress)

	err := a.network.HandleMessage(protocol, a.handleMessage)
	if err != nil {
		return fmt.Errorf("starting actor: %s: %w", a.self.ID, err)
	}

	// The handler is in place; start the internal goroutines.
	a.dispatch.Start()
	a.scheduler.Start()
	return nil
}
// handleMessage is the network callback for direct messages. It decodes the
// envelope, drops it when it is addressed to another actor or rejected by the
// rate limiter, and otherwise hands it to Receive.
//
// A failure in Receive is logged rather than discarded, matching
// handleBroadcast.
func (a *BasicActor) handleMessage(data []byte) {
	var msg Envelope
	if err := json.Unmarshal(data, &msg); err != nil {
		log.Debugf("error unmarshaling message: %s", err)
		return
	}
	if !a.self.ID.Equal(msg.To.ID) {
		log.Warnf("message is not for ourselves: %s %s", a.self.ID, msg.To.ID)
		return
	}
	if !a.limiter.Allow(msg) {
		log.Warnf("incoming message invoking %s not allowed by limiter", msg.Behavior)
		return
	}
	if err := a.Receive(msg); err != nil {
		log.Warnf("error receiving message: %s", err)
	}
}
// Context returns the context of the actor's dispatch.
func (a *BasicActor) Context() context.Context {
return a.dispatch.Context()
}
// Handle returns the handle the actor was created with.
func (a *BasicActor) Handle() Handle {
return a.self
}
// Security returns the actor's security context.
func (a *BasicActor) Security() SecurityContext {
return a.security
}
// AddBehavior registers continuation under the given behavior name by
// delegating to the dispatch; opt is passed through unchanged.
func (a *BasicActor) AddBehavior(behavior string, continuation Behavior, opt ...BehaviorOption) error {
return a.dispatch.AddBehavior(behavior, continuation, opt...)
}
// RemoveBehavior removes the named behavior from the dispatch.
func (a *BasicActor) RemoveBehavior(behavior string) {
a.dispatch.RemoveBehavior(behavior)
}
// Receive queues msg on the dispatch when it is addressed to this actor or is
// a broadcast; any other message is rejected with ErrInvalidMessage.
func (a *BasicActor) Receive(msg Envelope) error {
	addressedToUs := a.self.ID.Equal(msg.To.ID)
	if !addressedToUs && !msg.IsBroadcast() {
		return fmt.Errorf("bad receiver: %w", ErrInvalidMessage)
	}
	return a.dispatch.Receive(msg)
}
// Send delivers msg to its destination.
//
// A message without a signature is completed first: a nonce is assigned if it
// has none, and the security context provides capabilities for the invoked
// behavior (plus a delegation for the reply-to behavior, if set) and signs the
// envelope. This happens before the destination is examined, so a message
// addressed to this actor is signed too; the dispatch verifies the signature
// of every queued message, including locally delivered ones.
//
// Messages addressed to this actor are handed to Receive directly. All others
// are marshaled to JSON and sent over the network to the destination host on
// the destination's inbox protocol.
func (a *BasicActor) Send(msg Envelope) error {
	if msg.Signature == nil {
		if msg.Nonce == 0 {
			msg.Nonce = a.security.Nonce()
		}

		invoke := []Capability{Capability(msg.Behavior)}
		var delegate []Capability
		if msg.Options.ReplyTo != "" {
			delegate = append(delegate, Capability(msg.Options.ReplyTo))
		}
		if err := a.security.Provide(&msg, invoke, delegate); err != nil {
			return fmt.Errorf("providing behavior capability for %s: %w", msg.Behavior, err)
		}
	}

	// Local delivery bypasses the network but still goes through the dispatch.
	if msg.To.ID.Equal(a.self.ID) {
		return a.Receive(msg)
	}

	data, err := json.Marshal(msg)
	if err != nil {
		return fmt.Errorf("marshaling message: %w", err)
	}

	err = a.network.SendMessage(
		a.Context(),
		msg.To.Address.HostID,
		types.MessageEnvelope{
			Type: types.MessageType(
				fmt.Sprintf("actor/%s/messages/0.0.1", msg.To.Address.InboxAddress),
			),
			Data: data,
		})
	if err != nil {
		return fmt.Errorf("sending message to %s: %w", msg.To.ID, err)
	}
	return nil
}
// Invoke sends msg and returns a channel on which the reply is delivered.
//
// When msg names no reply-to behavior, one is generated from a fresh nonce. A
// one-shot behavior with the message's expiry is registered under that name;
// it forwards the reply into the returned channel (capacity 1) and closes it.
// If sending fails, the reply behavior is removed again.
func (a *BasicActor) Invoke(msg Envelope) (<-chan Envelope, error) {
	if msg.Options.ReplyTo == "" {
		msg.Options.ReplyTo = fmt.Sprintf("/dms/actor/replyto/%d", a.security.Nonce())
	}
	replyBehavior := msg.Options.ReplyTo

	replies := make(chan Envelope, 1)
	err := a.dispatch.AddBehavior(
		replyBehavior,
		func(reply Envelope) {
			replies <- reply
			close(replies)
		},
		WithBehaviorExpiry(msg.Options.Expire),
		WithBehaviorOneShot(true),
	)
	if err != nil {
		return nil, fmt.Errorf("adding reply behavior: %w", err)
	}

	err = a.Send(msg)
	if err != nil {
		a.dispatch.RemoveBehavior(replyBehavior)
		return nil, fmt.Errorf("sending message: %w", err)
	}
	return replies, nil
}
// Publish broadcasts msg on its topic.
//
// Only broadcast envelopes are accepted; anything else yields
// ErrInvalidMessage. An unsigned envelope gets a nonce (if missing) and is
// given broadcast capabilities and a signature by the security context before
// being marshaled to JSON and published on the network.
func (a *BasicActor) Publish(msg Envelope) error {
	if !msg.IsBroadcast() {
		return ErrInvalidMessage
	}

	if msg.Signature == nil {
		if msg.Nonce == 0 {
			msg.Nonce = a.security.Nonce()
		}
		caps := []Capability{Capability(msg.Behavior)}
		err := a.security.ProvideBroadcast(&msg, msg.Options.Topic, caps)
		if err != nil {
			return fmt.Errorf("providing behavior capability for %s: %w", msg.Behavior, err)
		}
	}

	payload, err := json.Marshal(msg)
	if err != nil {
		return fmt.Errorf("marshaling message: %w", err)
	}

	err = a.network.Publish(a.Context(), msg.Options.Topic, payload)
	if err != nil {
		return fmt.Errorf("publishing message: %w", err)
	}
	return nil
}
// Subscribe subscribes the actor to topic, once.
//
// If the topic is already recorded in subscriptions the call is a no-op.
// Otherwise the actor subscribes on the network with handleBroadcast as the
// handler and validateBroadcast as the validator, then runs every setup
// function in order. If a setup function fails, the fresh subscription is
// cancelled and the error is returned; the subscription ID is recorded only
// after all of them succeed.
func (a *BasicActor) Subscribe(topic string, setup ...BroadcastSetup) error {
	a.mx.Lock()
	defer a.mx.Unlock()

	if _, subscribed := a.subscriptions[topic]; subscribed {
		return nil
	}

	id, err := a.network.Subscribe(
		a.Context(),
		topic,
		a.handleBroadcast,
		func(raw []byte, cached interface{}) (network.ValidationResult, interface{}) {
			return a.validateBroadcast(topic, raw, cached)
		},
	)
	if err != nil {
		return fmt.Errorf("subscribe: %w", err)
	}

	for _, configure := range setup {
		if setupErr := configure(topic); setupErr != nil {
			_ = a.network.Unsubscribe(topic, id)
			return fmt.Errorf("setup broadcast topic: %w", setupErr)
		}
	}

	a.subscriptions[topic] = id
	return nil
}
// validateBroadcast is the pubsub validator installed by Subscribe.
//
// Non-nil validatorData short-circuits validation: it is accepted when it
// holds an Envelope and rejected otherwise. Without it, data is decoded and
// rejected unless it is a broadcast for topic with a valid signature; expired
// messages and messages refused by the rate limiter are ignored. On acceptance
// the decoded envelope is returned as the validator data.
func (a *BasicActor) validateBroadcast(topic string, data []byte, validatorData interface{}) (network.ValidationResult, interface{}) {
	if validatorData != nil {
		if _, isEnvelope := validatorData.(Envelope); isEnvelope {
			// we have already validated the message, just short-circuit
			return network.ValidationAccept, validatorData
		}
		log.Warnf("bogus pubsub validation data: %v", validatorData)
		return network.ValidationReject, nil
	}

	var msg Envelope
	if err := json.Unmarshal(data, &msg); err != nil {
		return network.ValidationReject, nil
	}

	if !msg.IsBroadcast() || msg.Options.Topic != topic {
		return network.ValidationReject, nil
	}
	if msg.Expired() {
		return network.ValidationIgnore, nil
	}
	if err := a.security.Verify(msg); err != nil {
		return network.ValidationReject, nil
	}
	if allowed := a.limiter.Allow(msg); !allowed {
		log.Warnf("incoming broadcast message in %s not allowed by limiter", topic)
		return network.ValidationIgnore, nil
	}

	return network.ValidationAccept, msg
}
// handleBroadcast is the subscription callback for broadcast messages: it
// decodes the envelope and passes it to Receive, logging either failure.
func (a *BasicActor) handleBroadcast(data []byte) {
	var envelope Envelope
	if err := json.Unmarshal(data, &envelope); err != nil {
		log.Debugf("error unmarshaling broadcast message: %s", err)
		return
	}

	err := a.Receive(envelope)
	if err != nil {
		log.Warnf("error receiving broadcast message: %s", err)
	}
}
// Stop cancels the dispatch context and unsubscribes from every topic the
// actor subscribed to. Unsubscribe failures are logged and do not abort the
// shutdown; Stop always returns nil.
//
// The subscription table is walked under mx, the same mutex Subscribe holds,
// so a concurrent Subscribe cannot race with the iteration. Entries are
// removed as they are processed, so calling Stop again does not unsubscribe a
// second time.
func (a *BasicActor) Stop() error {
	a.dispatch.close()

	a.mx.Lock()
	defer a.mx.Unlock()

	for topic, subID := range a.subscriptions {
		if err := a.network.Unsubscribe(topic, subID); err != nil {
			log.Debugf("error unsubscribing from %s: %s", topic, err)
		}
		// Deleting the current key while ranging over a map is well defined in Go.
		delete(a.subscriptions, topic)
	}
	return nil
}
// Limiter returns the actor's rate limiter.
func (a *BasicActor) Limiter() RateLimiter {
return a.limiter
}
package actor
import (
"context"
"fmt"
"sync"
"time"
)
// Defaults that NewDispatch applies before processing its options.
var (
	// DefaultDispatchWorkers is the default number of receive workers.
	DefaultDispatchWorkers = 1
	// DefaultDispatchGCInterval is the default period between sweeps that
	// remove expired behaviors.
	DefaultDispatchGCInterval = 2 * time.Minute
)
// Dispatch provides a reaction kernel with multithreaded dispatch and oneshot
// continuations.
//
// Messages enter through q, are verified by the recv workers, move on to vq
// and are then matched against the registered behaviors by the dispatch loop.
type Dispatch struct {
ctx context.Context // context returned by Context; all dispatch goroutines exit when it is done
close func() // cancel function of ctx
sctx SecurityContext // verifies messages and checks required capabilities
mx sync.Mutex // held while accessing behaviors and started
q chan Envelope // incoming message queue
vq chan Envelope // verified message queue
behaviors map[string]*BehaviorState // registered behaviors keyed by behavior name
started bool // set by Start once the goroutines have been launched
options DispatchOptions
}
// DispatchOptions holds the tunable settings of a Dispatch.
type DispatchOptions struct {
// Limiter is acquired before, and released after, each behavior invocation.
Limiter RateLimiter
// GCInterval is the period of the ticker that drives expired-behavior cleanup.
GCInterval time.Duration
// Workers is the number of recv goroutines launched by Start.
Workers int
}
// BehaviorState is a registered behavior: its continuation and its options.
type BehaviorState struct {
cont Behavior
opt BehaviorOptions
}
// DispatchOption mutates DispatchOptions; NewDispatch applies them in order.
type DispatchOption func(o *DispatchOptions)
// WithDispatchWorkers returns an option that sets the number of receive
// workers to count.
func WithDispatchWorkers(count int) DispatchOption {
	setWorkers := func(options *DispatchOptions) {
		options.Workers = count
	}
	return setWorkers
}
// WithDispatchGCInterval returns an option that sets the garbage-collection
// interval for expired behaviors to dt.
func WithDispatchGCInterval(dt time.Duration) DispatchOption {
	setInterval := func(options *DispatchOptions) {
		options.GCInterval = dt
	}
	return setInterval
}
// WithRateLimiter returns an option that installs limiter as the dispatch's
// rate limiter.
func WithRateLimiter(limiter RateLimiter) DispatchOption {
	setLimiter := func(options *DispatchOptions) {
		options.Limiter = limiter
	}
	return setLimiter
}
// NewDispatch creates a dispatch that verifies messages with sctx.
//
// The dispatch starts from the package defaults (DefaultDispatchGCInterval,
// DefaultDispatchWorkers, NoRateLimiter) and then applies opt in order; nil
// options are skipped. Settings that would make the dispatch unusable are
// replaced with their defaults afterwards:
//   - Workers < 1 would start no receive goroutine, so Receive would block
//     until the context is cancelled;
//   - GCInterval <= 0 would make time.NewTicker panic in the gc goroutine;
//   - a nil Limiter would be dereferenced when a message is dispatched.
//
// The returned dispatch is idle until Start is called.
func NewDispatch(sctx SecurityContext, opt ...DispatchOption) *Dispatch {
	ctx, cancel := context.WithCancel(context.Background())
	k := &Dispatch{
		sctx:      sctx,
		ctx:       ctx,
		close:     cancel,
		q:         make(chan Envelope),
		vq:        make(chan Envelope),
		behaviors: make(map[string]*BehaviorState),
		options: DispatchOptions{
			GCInterval: DefaultDispatchGCInterval,
			Workers:    DefaultDispatchWorkers,
			Limiter:    NoRateLimiter{},
		},
	}

	for _, f := range opt {
		if f == nil {
			continue
		}
		f(&k.options)
	}

	// Sanitize what the options left behind.
	if k.options.Workers < 1 {
		k.options.Workers = DefaultDispatchWorkers
	}
	if k.options.GCInterval <= 0 {
		k.options.GCInterval = DefaultDispatchGCInterval
	}
	if k.options.Limiter == nil {
		k.options.Limiter = NoRateLimiter{}
	}

	return k
}
// Start launches the dispatch goroutines: one recv worker per configured
// worker, the dispatch loop and the behavior garbage collector. Calls after
// the first are no-ops.
func (k *Dispatch) Start() {
	k.mx.Lock()
	defer k.mx.Unlock()

	if k.started {
		return
	}

	for worker := 0; worker < k.options.Workers; worker++ {
		go k.recv()
	}
	go k.dispatch()
	go k.gc()

	k.started = true
}
// AddBehavior registers continuation under the name behavior, replacing any
// behavior already registered under that name.
//
// The required capability defaults to the behavior name itself; opt is applied
// in order on top of that default. If an option fails, nothing is registered
// and the wrapped error is returned.
func (k *Dispatch) AddBehavior(behavior string, continuation Behavior, opt ...BehaviorOption) error {
	state := &BehaviorState{cont: continuation}
	state.opt.Capability = []Capability{Capability(behavior)}

	for _, apply := range opt {
		if err := apply(&state.opt); err != nil {
			return fmt.Errorf("adding behavior: %w", err)
		}
	}

	k.mx.Lock()
	k.behaviors[behavior] = state
	k.mx.Unlock()
	return nil
}
// RemoveBehavior deletes the named behavior under the dispatch mutex; removing
// a name that is not registered is a no-op.
func (k *Dispatch) RemoveBehavior(behavior string) {
k.mx.Lock()
defer k.mx.Unlock()
delete(k.behaviors, behavior)
}
// Receive places msg on the incoming queue. The queue is unbuffered, so the
// call blocks until a recv worker takes the message or the dispatch context is
// done, in which case the context's error is returned.
func (k *Dispatch) Receive(msg Envelope) error {
select {
case k.q <- msg:
return nil
case <-k.ctx.Done():
return k.ctx.Err()
}
}
// Context returns the dispatch context, which is cancelled by close.
func (k *Dispatch) Context() context.Context {
return k.ctx
}
// recv is the receive worker loop. It takes messages from the incoming queue,
// verifies them with the security context and forwards the valid ones to the
// verified queue; messages that fail verification are logged and dropped.
//
// Both the receive and the forward select on the dispatch context, so the
// worker exits on shutdown even when the dispatch loop has already stopped
// reading from the (unbuffered) verified queue.
func (k *Dispatch) recv() {
	for {
		select {
		case msg := <-k.q:
			if err := k.sctx.Verify(msg); err != nil {
				log.Debugf("failed to verify message from %s: %s", msg.From, err)
				continue
			}

			select {
			case k.vq <- msg:
			case <-k.ctx.Done():
				return
			}
		case <-k.ctx.Done():
			return
		}
	}
}
// dispatch is the loop that takes verified messages from vq and runs the
// matching behavior. It returns when the dispatch context is done.
//
// For each message it looks up the behavior by name, drops the message if the
// behavior is unknown or expired, checks the required capability, removes
// one-shot behaviors, acquires the rate limiter and finally runs the
// continuation in its own goroutine.
func (k *Dispatch) dispatch() {
for {
select {
case msg := <-k.vq:
// The behavior table is read and pruned under the mutex; every early exit
// below unlocks before logging and moving on to the next message.
k.mx.Lock()
b, ok := k.behaviors[msg.Behavior]
if !ok {
k.mx.Unlock()
log.Debugf("unknown behavior %s", msg.Behavior)
continue
}
if b.Expired(time.Now()) {
// Expired behaviors are removed on sight rather than waiting for gc.
delete(k.behaviors, msg.Behavior)
k.mx.Unlock()
log.Debugf("expired behavior %s", msg.Behavior)
continue
}
// Capability check: broadcasts are checked against the behavior's topic,
// direct messages against the behavior's capability list only.
if msg.IsBroadcast() {
if err := k.sctx.RequireBroadcast(msg, b.opt.Topic, b.opt.Capability); err != nil {
k.mx.Unlock()
log.Warnf("broadcast message from %s does not have the required capability %s: %s", msg.From, b.opt.Capability, err)
continue
}
} else if err := k.sctx.Require(msg, b.opt.Capability); err != nil {
k.mx.Unlock()
log.Warnf("message from %s does not have the required capability %s: %s", msg.From, b.opt.Capability, err)
continue
}
// A one-shot behavior is unregistered as soon as a message qualifies for
// it, before the continuation runs.
if b.opt.OneShot {
delete(k.behaviors, msg.Behavior)
}
k.mx.Unlock()
// The limiter is acquired outside the mutex; on rejection the security
// context is told to discard the message.
if err := k.options.Limiter.Acquire(msg); err != nil {
k.sctx.Discard(msg)
log.Debugf("limiter rejected message from %s: %s", msg.From, err)
continue
}
// Give the continuation a way to discard the message through the security
// context.
msg.Discard = func() {
k.sctx.Discard(msg)
}
// The continuation runs in its own goroutine so this loop is not blocked;
// the limiter is released when it returns.
go func() {
defer k.options.Limiter.Release(msg)
b.cont(msg)
}()
case <-k.ctx.Done():
return
}
}
}
// gc periodically removes expired behaviors. It wakes up every GCInterval,
// sweeps the behavior table under the mutex and returns when the dispatch
// context is done.
func (k *Dispatch) gc() {
	sweep := time.NewTicker(k.options.GCInterval)
	defer sweep.Stop()

	for {
		select {
		case <-k.ctx.Done():
			return

		case <-sweep.C:
			k.mx.Lock()
			cutoff := time.Now()
			for name, state := range k.behaviors {
				if !state.Expired(cutoff) {
					continue
				}
				delete(k.behaviors, name)
			}
			k.mx.Unlock()
		}
	}
}
// Expired reports whether the behavior's expiry, a Unix time in nanoseconds,
// lies before now. A zero expiry means the behavior never expires.
func (b *BehaviorState) Expired(now time.Time) bool {
	deadline := b.opt.Expire
	if deadline == 0 {
		return false
	}
	return uint64(now.UnixNano()) > deadline
}
// WithBehaviorExpiry returns an option that sets the behavior's expiry to
// expire, a Unix time in nanoseconds.
func WithBehaviorExpiry(expire uint64) BehaviorOption {
	return func(options *BehaviorOptions) error {
		options.Expire = expire
		return nil
	}
}
// WithBehaviorCapability returns an option that replaces the capabilities
// required to invoke the behavior with require.
func WithBehaviorCapability(require ...Capability) BehaviorOption {
	return func(options *BehaviorOptions) error {
		options.Capability = require
		return nil
	}
}
// WithBehaviorOneShot returns an option that marks the behavior as one-shot
// (or not); the dispatch removes a one-shot behavior when it is invoked.
func WithBehaviorOneShot(oneShot bool) BehaviorOption {
	return func(options *BehaviorOptions) error {
		options.OneShot = oneShot
		return nil
	}
}
// WithBehaviorTopic returns an option that sets the broadcast topic the
// behavior is checked against.
func WithBehaviorTopic(topic string) BehaviorOption {
	return func(options *BehaviorOptions) error {
		options.Topic = topic
		return nil
	}
}
package actor
import (
"fmt"
"gitlab.com/nunet/device-management-service/lib/did"
)
// Empty reports whether the ID, the DID and the address of the handle are all
// empty.
func (h *Handle) Empty() bool {
	if !h.ID.Empty() {
		return false
	}
	if !h.DID.Empty() {
		return false
	}
	return h.Address.Empty()
}
// String formats the handle as "<id-did>[<did>]@<address>". The leading part
// is the DID derived from the handle's ID and is left empty when that
// conversion fails.
func (h *Handle) String() string {
	derived := ""
	if fromID, err := did.FromID(h.ID); err == nil {
		derived = fromID.String()
	}
	return fmt.Sprintf("%s[%s]@%s", derived, h.DID, h.Address)
}
// HandleFromString is not implemented yet: it ignores its argument and always
// returns an empty Handle together with ErrTODO.
func HandleFromString(_ string) (Handle, error) {
// TODO
return Handle{}, ErrTODO
}
// Empty reports whether both the host ID and the inbox address are unset.
func (a *Address) Empty() bool {
	if a.HostID != "" {
		return false
	}
	return a.InboxAddress == ""
}
// String renders the address as "<host id>:<inbox address>".
func (a *Address) String() string {
	host, inbox := a.HostID, a.InboxAddress
	return host + ":" + inbox
}
// AddressFromString is not implemented yet: it ignores its argument and always
// returns an empty Address together with ErrTODO.
func AddressFromString(_ string) (Address, error) {
// TODO
return Address{}, ErrTODO
}
package actor
import (
"strings"
"sync"
)
// NoRateLimiter is the null limiter, which does not rate limit.
type NoRateLimiter struct{}
// Compile-time check that NoRateLimiter implements RateLimiter.
var _ RateLimiter = NoRateLimiter{}
// BasicRateLimiter limits the number of concurrently active public-behavior
// and broadcast messages according to a RateLimiterConfig.
type BasicRateLimiter struct {
cfg RateLimiterConfig
// mx is held by every method that reads or updates the fields of the limiter.
mx sync.Mutex
// activeBroadcast counts acquired, not yet released broadcast messages.
activeBroadcast int
// activeTopics counts acquired, not yet released broadcast messages per topic.
activeTopics map[string]int
// activePublic counts acquired, not yet released public-behavior messages.
activePublic int
}
// Compile-time check that *BasicRateLimiter implements RateLimiter.
var _ RateLimiter = (*BasicRateLimiter)(nil)
// implementation
// Allow always permits the message.
func (l NoRateLimiter) Allow(_ Envelope) bool { return true }
// Acquire always succeeds.
func (l NoRateLimiter) Acquire(_ Envelope) error { return nil }
// Release does nothing.
func (l NoRateLimiter) Release(_ Envelope) {}
// Config returns the zero RateLimiterConfig.
func (l NoRateLimiter) Config() RateLimiterConfig { return RateLimiterConfig{} }
// SetConfig ignores the given configuration.
func (l NoRateLimiter) SetConfig(_ RateLimiterConfig) {}
// DefaultRateLimiterConfig returns the default limits. In each pair the
// acquire limit is 16 above the allow limit; no per-topic overrides are set.
func DefaultRateLimiterConfig() RateLimiterConfig {
	var cfg RateLimiterConfig

	cfg.PublicLimitAllow = 4096
	cfg.PublicLimitAcquire = 4112

	cfg.BroadcastLimitAllow = 1024
	cfg.BroadcastLimitAcquire = 1040

	cfg.TopicDefaultLimit = 128

	return cfg
}
// Valid reports whether the configuration is usable: all allow limits and the
// default topic limit must be positive, and each acquire limit must be at
// least as large as its allow limit.
func (cfg *RateLimiterConfig) Valid() bool {
	switch {
	case cfg.PublicLimitAllow <= 0:
		return false
	case cfg.PublicLimitAcquire < cfg.PublicLimitAllow:
		return false
	case cfg.BroadcastLimitAllow <= 0:
		return false
	case cfg.BroadcastLimitAcquire < cfg.BroadcastLimitAllow:
		return false
	case cfg.TopicDefaultLimit <= 0:
		return false
	}
	return true
}
// NewRateLimiter returns a BasicRateLimiter that uses cfg and starts with no
// active messages. cfg is stored as given; it is not checked with Valid.
func NewRateLimiter(cfg RateLimiterConfig) RateLimiter {
	limiter := &BasicRateLimiter{cfg: cfg}
	limiter.activeTopics = make(map[string]int)
	return limiter
}
// Allow reports whether msg may currently be accepted. Broadcasts and messages
// for public behaviors are checked against their allow limits; every other
// message is always allowed. No counter is changed.
func (l *BasicRateLimiter) Allow(msg Envelope) bool {
	switch {
	case msg.IsBroadcast():
		return l.allowBroadcast(msg)
	case isPublicBehavior(msg):
		return l.allowPublic(msg)
	default:
		return true
	}
}
// allowPublic reports whether the number of active public-behavior messages is
// below PublicLimitAllow. The message itself is not inspected.
func (l *BasicRateLimiter) allowPublic(_ Envelope) bool {
l.mx.Lock()
defer l.mx.Unlock()
return l.activePublic < l.cfg.PublicLimitAllow
}
// allowBroadcast reports whether a broadcast on msg's topic may be accepted:
// the total of active broadcasts must be below BroadcastLimitAllow and the
// topic's active count below its limit (the per-topic override when
// configured, TopicDefaultLimit otherwise).
func (l *BasicRateLimiter) allowBroadcast(msg Envelope) bool {
	l.mx.Lock()
	defer l.mx.Unlock()

	if l.activeBroadcast >= l.cfg.BroadcastLimitAllow {
		return false
	}

	topic := msg.Options.Topic
	inFlight := l.activeTopics[topic]

	limit, overridden := l.cfg.TopicLimit[topic]
	if !overridden {
		limit = l.cfg.TopicDefaultLimit
	}
	return inFlight < limit
}
// Acquire reserves a slot for msg. Broadcasts and public-behavior messages are
// counted against their acquire limits and may fail with
// ErrRateLimitExceeded; all other messages succeed without being counted.
func (l *BasicRateLimiter) Acquire(msg Envelope) error {
	switch {
	case msg.IsBroadcast():
		return l.acquireBroadcast(msg)
	case isPublicBehavior(msg):
		return l.acquirePublic(msg)
	default:
		return nil
	}
}
// acquirePublic counts one more active public-behavior message, or returns
// ErrRateLimitExceeded when PublicLimitAcquire has been reached.
func (l *BasicRateLimiter) acquirePublic(_ Envelope) error {
	l.mx.Lock()
	defer l.mx.Unlock()

	if l.activePublic < l.cfg.PublicLimitAcquire {
		l.activePublic++
		return nil
	}
	return ErrRateLimitExceeded
}
// acquireBroadcast counts one more active broadcast on msg's topic. It fails
// with ErrRateLimitExceeded when either the global BroadcastLimitAcquire or
// the topic's limit (per-topic override, else TopicDefaultLimit) has been
// reached; in that case no counter is changed.
func (l *BasicRateLimiter) acquireBroadcast(msg Envelope) error {
	l.mx.Lock()
	defer l.mx.Unlock()

	if l.activeBroadcast >= l.cfg.BroadcastLimitAcquire {
		return ErrRateLimitExceeded
	}

	topic := msg.Options.Topic
	inFlight := l.activeTopics[topic]

	limit, overridden := l.cfg.TopicLimit[topic]
	if !overridden {
		limit = l.cfg.TopicDefaultLimit
	}
	if inFlight >= limit {
		return ErrRateLimitExceeded
	}

	l.activeTopics[topic] = inFlight + 1
	l.activeBroadcast++
	return nil
}
// Release returns the slot held by msg. Only broadcasts and public-behavior
// messages are counted, so any other message is ignored.
func (l *BasicRateLimiter) Release(msg Envelope) {
	switch {
	case msg.IsBroadcast():
		l.releaseBroadcast(msg)
	case isPublicBehavior(msg):
		l.releasePublic(msg)
	}
}
// releasePublic counts one active public-behavior message less.
//
// The counter never drops below zero: a Release without a matching Acquire is
// ignored, as releaseBroadcast does for untracked topics. A negative counter
// would silently raise the effective public limits.
func (l *BasicRateLimiter) releasePublic(_ Envelope) {
	l.mx.Lock()
	defer l.mx.Unlock()

	if l.activePublic > 0 {
		l.activePublic--
	}
}
// releaseBroadcast counts one active broadcast less on msg's topic and in the
// global broadcast counter. A topic with no tracked messages is ignored, and a
// topic entry is removed once its count reaches zero.
func (l *BasicRateLimiter) releaseBroadcast(msg Envelope) {
	l.mx.Lock()
	defer l.mx.Unlock()

	topic := msg.Options.Topic
	inFlight, tracked := l.activeTopics[topic]
	if !tracked {
		return
	}

	if remaining := inFlight - 1; remaining > 0 {
		l.activeTopics[topic] = remaining
	} else {
		delete(l.activeTopics, topic)
	}
	l.activeBroadcast--
}
// Config returns the current configuration, read under the limiter's mutex.
func (l *BasicRateLimiter) Config() RateLimiterConfig {
l.mx.Lock()
defer l.mx.Unlock()
return l.cfg
}
// SetConfig replaces the configuration under the limiter's mutex. The active
// counters are left untouched and cfg is not checked with Valid.
func (l *BasicRateLimiter) SetConfig(cfg RateLimiterConfig) {
l.mx.Lock()
defer l.mx.Unlock()
l.cfg = cfg
}
// isPublicBehavior reports whether the behavior invoked by msg lives under the
// "/public/" namespace.
func isPublicBehavior(msg Envelope) bool {
	const publicPrefix = "/public/"
	behavior := msg.Behavior
	return strings.HasPrefix(behavior, publicPrefix)
}
package actor
import (
"encoding/json"
"fmt"
"time"
)
// heartbeatBehavior is the behavior name used for actor heartbeats.
const heartbeatBehavior = "/dms/actor/heartbeat"

// defaultMessageTimeout is the lifetime Message gives an envelope when no
// option overrides its expiry.
const defaultMessageTimeout = 30 * time.Second
// signaturePrefix is prepended to the serialized envelope by SignatureData
// before it is signed or verified.
var signaturePrefix = []byte("dms:msg:")
// HeartbeatMessage is an empty message payload type.
type HeartbeatMessage struct{}
// Message builds an envelope from src to dest that invokes behavior with the
// JSON encoding of payload.
//
// The envelope expires defaultMessageTimeout from now and has a no-op Discard
// unless an option changes that. Options are applied in order; the first one
// that fails aborts construction and its error is returned wrapped.
func Message(src Handle, dest Handle, behavior string, payload interface{}, opt ...MessageOption) (Envelope, error) {
	body, err := json.Marshal(payload)
	if err != nil {
		return Envelope{}, fmt.Errorf("marshaling payload: %w", err)
	}

	deadline := time.Now().Add(defaultMessageTimeout)
	envelope := Envelope{
		From:     src,
		To:       dest,
		Behavior: behavior,
		Message:  body,
		Options: EnvelopeOptions{
			Expire: uint64(deadline.UnixNano()),
		},
		Discard: func() {},
	}

	for _, apply := range opt {
		if optErr := apply(&envelope); optErr != nil {
			return Envelope{}, fmt.Errorf("setting message option: %w", optErr)
		}
	}
	return envelope, nil
}
// ReplyTo builds the reply to msg: source and destination are swapped, the
// behavior is msg's reply-to behavior and the expiry is inherited from msg
// (later options in opt may override it). A message without a reply-to
// behavior cannot be answered and yields ErrInvalidMessage.
func ReplyTo(msg Envelope, payload interface{}, opt ...MessageOption) (Envelope, error) {
	replyBehavior := msg.Options.ReplyTo
	if replyBehavior == "" {
		return Envelope{}, fmt.Errorf("no behavior to reply to: %w", ErrInvalidMessage)
	}

	options := make([]MessageOption, 0, len(opt)+1)
	options = append(options, WithMessageExpiry(msg.Options.Expire))
	options = append(options, opt...)

	return Message(msg.To, msg.From, replyBehavior, payload, options...)
}
// WithMessageSignature assigns a fresh nonce to the envelope, provides the
// given capabilities through sctx and signs the envelope. For broadcast
// envelopes only invoke is used; delegate is ignored.
//
// The option fails with ErrInvalidSecurityContext when the envelope's sender
// is not the identity of sctx.
//
// NOTE: This option must be passed last, otherwise the signature will be
// invalidated by further modifications.
//
// NOTE: Signing is implicit in Send.
func WithMessageSignature(sctx SecurityContext, invoke []Capability, delegate []Capability) MessageOption {
	return func(envelope *Envelope) error {
		if sameSender := envelope.From.ID.Equal(sctx.ID()); !sameSender {
			return ErrInvalidSecurityContext
		}

		envelope.Nonce = sctx.Nonce()
		if !envelope.IsBroadcast() {
			return sctx.Provide(envelope, invoke, delegate)
		}
		return sctx.ProvideBroadcast(envelope, envelope.Options.Topic, invoke)
	}
}
// WithMessageTimeout sets the message expiration to timeo from the moment the
// option is applied.
//
// NOTE: messages created with Message have an implicit timeout of
// defaultMessageTimeout.
func WithMessageTimeout(timeo time.Duration) MessageOption {
	return func(envelope *Envelope) error {
		deadline := time.Now().Add(timeo)
		envelope.Options.Expire = uint64(deadline.UnixNano())
		return nil
	}
}
// WithMessageExpiry sets the message expiry to the given absolute value, a
// Unix time in nanoseconds.
//
// NOTE: messages created with Message have an implicit timeout of
// defaultMessageTimeout.
func WithMessageExpiry(expiry uint64) MessageOption {
	return func(envelope *Envelope) error {
		envelope.Options.Expire = expiry
		return nil
	}
}
// WithMessageReplyTo sets the behavior that replies to the message should
// invoke.
//
// NOTE: ReplyTo is set implicitly in Invoke and the appropriate capability
// tokens are delegated by Provide.
func WithMessageReplyTo(replyto string) MessageOption {
	return func(envelope *Envelope) error {
		envelope.Options.ReplyTo = replyto
		return nil
	}
}
// WithMessageTopic sets the topic the message is broadcast on.
func WithMessageTopic(topic string) MessageOption {
	return func(envelope *Envelope) error {
		envelope.Options.Topic = topic
		return nil
	}
}
// WithMessageSource overrides the sender of the message with source.
func WithMessageSource(source Handle) MessageOption {
	return func(envelope *Envelope) error {
		envelope.From = source
		return nil
	}
}
// SignatureData returns the bytes that are signed and verified for the
// envelope: signaturePrefix followed by the JSON encoding of a copy of the
// envelope whose Signature field is cleared. The envelope itself is not
// modified.
func (msg *Envelope) SignatureData() ([]byte, error) {
	unsigned := *msg
	unsigned.Signature = nil

	encoded, err := json.Marshal(&unsigned)
	if err != nil {
		return nil, fmt.Errorf("signature data: %w", err)
	}

	out := make([]byte, 0, len(signaturePrefix)+len(encoded))
	out = append(out, signaturePrefix...)
	out = append(out, encoded...)
	return out, nil
}
// Expired reports whether the current time, in Unix nanoseconds, is past the
// envelope's Options.Expire. An envelope whose Expire is zero is therefore
// always reported as expired.
func (msg *Envelope) Expired() bool {
return uint64(time.Now().UnixNano()) > msg.Options.Expire
}
// Expiry converts the envelope's expiration, a Unix time in nanoseconds, to a
// time.Time.
func (msg *Envelope) Expiry() time.Time {
	const nanosPerSecond = uint64(time.Second)

	expire := msg.Options.Expire
	seconds := expire / nanosPerSecond
	nanos := expire % nanosPerSecond
	return time.Unix(int64(seconds), int64(nanos)) //nolint
}
// IsBroadcast reports whether the envelope is a broadcast: it has no
// destination handle and a non-empty topic.
func (msg *Envelope) IsBroadcast() bool {
return msg.To.Empty() && msg.Options.Topic != ""
}
package actor
import (
"errors"
"sync"
)
// Info is the registry record of an actor.
type Info struct {
// Addr is the handle of the actor itself.
Addr *Handle
// Parent is the handle of the actor's parent.
Parent *Handle
// Children are the handles of the actor's children.
Children []Handle
}
// Registry keeps track of actors together with their parents and children.
type Registry interface {
// Actors returns the registered actors keyed by string.
Actors() map[string]Info
// Add registers actor a with the given parent and children.
Add(a Handle, parent Handle, children []Handle) error
// Get returns the record of actor a and whether it was found.
Get(a Handle) (Info, bool)
// SetParent changes the parent of actor a.
SetParent(a Handle, parent Handle) error
// GetParent returns the parent of actor a and whether a was found.
GetParent(a Handle) (*Handle, bool)
}
// registry is a mutex-protected, in-memory actor table.
type registry struct {
// mx is held by every method while it accesses actors.
mx sync.Mutex
// actors holds the records, keyed by the actor's inbox address.
actors map[string]Info
}
func newRegistry() *registry {
return ®istry{
actors: make(map[string]Info),
}
}
// Actors returns a copy of the actor table. The map is new, but its Info
// values are copied shallowly and share handle pointers and child slices with
// the registry.
func (r *registry) Actors() map[string]Info {
	r.mx.Lock()
	defer r.mx.Unlock()

	snapshot := make(map[string]Info, len(r.actors))
	for inbox, info := range r.actors {
		snapshot[inbox] = info
	}
	return snapshot
}
// Add registers actor a, keyed by its inbox address, with the given parent and
// children. It fails when an actor with the same inbox address is already
// registered. A nil children slice is stored as an empty one.
func (r *registry) Add(a Handle, parent Handle, children []Handle) error {
	r.mx.Lock()
	defer r.mx.Unlock()

	inbox := a.Address.InboxAddress
	if _, exists := r.actors[inbox]; exists {
		return errors.New("actor already exists")
	}

	if children == nil {
		children = []Handle{}
	}

	record := Info{Addr: &a, Parent: &parent, Children: children}
	r.actors[inbox] = record
	return nil
}
// Get looks up the record stored under a's inbox address and reports whether
// it exists.
func (r *registry) Get(a Handle) (Info, bool) {
r.mx.Lock()
defer r.mx.Unlock()
info, ok := r.actors[a.Address.InboxAddress]
return info, ok
}
// SetParent replaces the parent recorded for actor a. It fails when no actor
// is registered under a's inbox address.
func (r *registry) SetParent(a Handle, parent Handle) error {
	r.mx.Lock()
	defer r.mx.Unlock()

	inbox := a.Address.InboxAddress
	record, found := r.actors[inbox]
	if !found {
		return errors.New("actor not found")
	}

	// Info is stored by value, so the updated copy has to be written back.
	record.Parent = &parent
	r.actors[inbox] = record
	return nil
}
// GetParent returns the parent recorded for actor a, or (nil, false) when no
// actor is registered under a's inbox address.
func (r *registry) GetParent(a Handle) (*Handle, bool) {
	r.mx.Lock()
	defer r.mx.Unlock()

	if record, found := r.actors[a.Address.InboxAddress]; found {
		return record.Parent, true
	}
	return nil, false
}
package actor
import (
"fmt"
"sync"
"time"
"gitlab.com/nunet/device-management-service/lib/crypto"
"gitlab.com/nunet/device-management-service/lib/ucan"
)
// BasicSecurityContext is the SecurityContext implementation in this file. It
// signs and verifies envelopes and delegates capability handling to a UCAN
// capability context.
type BasicSecurityContext struct {
// id is derived from the public key passed to NewBasicSecurityContext.
id ID
// privk signs outgoing envelopes.
privk crypto.PrivKey
// cap provides, consumes, requires and discards capability tokens.
cap ucan.CapabilityContext
// mx is held by Nonce while it reads and increments nonce.
mx sync.Mutex
// nonce is the next nonce to hand out.
nonce uint64
}
// Compile-time check that *BasicSecurityContext implements SecurityContext.
var _ SecurityContext = (*BasicSecurityContext)(nil)
// NewBasicSecurityContext creates a security context for the given key pair
// and capability context. The identity is derived from pubk, and the nonce
// counter starts at the current Unix time in nanoseconds.
func NewBasicSecurityContext(pubk crypto.PubKey, privk crypto.PrivKey, cap ucan.CapabilityContext) (*BasicSecurityContext, error) {
	firstNonce := uint64(time.Now().UnixNano())

	id, err := crypto.IDFromPublicKey(pubk)
	if err != nil {
		return nil, fmt.Errorf("creating security context: %w", err)
	}

	return &BasicSecurityContext{
		id:    id,
		privk: privk,
		cap:   cap,
		nonce: firstNonce,
	}, nil
}
// ID returns the identity derived from the context's public key.
func (s *BasicSecurityContext) ID() ID {
return s.id
}
// DID returns the DID of the underlying capability context.
func (s *BasicSecurityContext) DID() DID {
return s.cap.DID()
}
// Nonce returns the current nonce and increments the counter, under the
// context's mutex, so successive calls return increasing values.
func (s *BasicSecurityContext) Nonce() uint64 {
s.mx.Lock()
defer s.mx.Unlock()
nonce := s.nonce
s.nonce++
return nonce
}
// Require checks that the sender of msg holds at least one of the capabilities
// in cap.
//
// Messages this context sends to itself need no capability. For all others
// the tokens carried by the envelope are consumed first and the capability
// check is run afterwards; when the check fails the consumed tokens are
// discarded again.
func (s *BasicSecurityContext) Require(msg Envelope, cap []Capability) error {
	// Sending to self: nothing to do, the signature is already verified.
	toSelf := s.id.Equal(msg.From.ID) && s.id.Equal(msg.To.ID)
	if toSelf {
		return nil
	}

	// Consume the capability tokens carried by the envelope.
	err := s.cap.Consume(msg.From.DID, msg.Capability)
	if err != nil {
		return fmt.Errorf("consuming capabilities: %w", err)
	}

	// Check whether any of the requested invocation capabilities is delegated.
	err = s.cap.Require(s.DID(), msg.From.ID, s.id, cap)
	if err != nil {
		s.cap.Discard(msg.Capability)
		return fmt.Errorf("requiring capabilities: %w", err)
	}
	return nil
}
// Provide attaches capability tokens for the invoke and delegate capabilities
// to msg and signs it. A message this context sends to itself is only signed;
// no tokens are attached.
func (s *BasicSecurityContext) Provide(msg *Envelope, invoke []Capability, delegate []Capability) error {
	toSelf := s.id.Equal(msg.From.ID) && s.id.Equal(msg.To.ID)
	if !toSelf {
		tokens, err := s.cap.Provide(msg.To.DID, s.id, msg.To.ID, msg.Options.Expire, invoke, delegate)
		if err != nil {
			return fmt.Errorf("providing capabilities: %w", err)
		}
		msg.Capability = tokens
	}

	return s.Sign(msg)
}
// RequireBroadcast checks that the sender of the broadcast msg holds at least
// one of the broadcast capabilities for topic.
//
// msg must be a broadcast on exactly that topic, otherwise ErrInvalidMessage
// is returned. The envelope's tokens are consumed before the check and
// discarded again when the check fails.
func (s *BasicSecurityContext) RequireBroadcast(msg Envelope, topic string, broadcast []Capability) error {
	switch {
	case !msg.IsBroadcast():
		return fmt.Errorf("not a broadcast message: %w", ErrInvalidMessage)
	case topic != msg.Options.Topic:
		return fmt.Errorf("broadcast topic mismatch: %w", ErrInvalidMessage)
	}

	// Consume the capability tokens carried by the envelope.
	err := s.cap.Consume(msg.From.DID, msg.Capability)
	if err != nil {
		return fmt.Errorf("consuming capabilities: %w", err)
	}

	// Check whether any of the requested broadcast capabilities is delegated.
	err = s.cap.RequireBroadcast(s.DID(), msg.From.ID, topic, broadcast)
	if err != nil {
		s.cap.Discard(msg.Capability)
		return fmt.Errorf("requiring capabilities: %w", err)
	}
	return nil
}
// ProvideBroadcast attaches broadcast capability tokens for topic to msg and
// signs it. msg must be a broadcast on exactly that topic, otherwise
// ErrInvalidMessage is returned.
func (s *BasicSecurityContext) ProvideBroadcast(msg *Envelope, topic string, broadcast []Capability) error {
	switch {
	case !msg.IsBroadcast():
		return fmt.Errorf("not a broadcast message: %w", ErrInvalidMessage)
	case topic != msg.Options.Topic:
		return fmt.Errorf("broadcast topic mismatch: %w", ErrInvalidMessage)
	}

	granted, err := s.cap.ProvideBroadcast(msg.From.ID, topic, msg.Options.Expire, broadcast)
	if err != nil {
		return fmt.Errorf("providing capabilities: %w", err)
	}
	msg.Capability = granted

	return s.Sign(msg)
}
// Verify checks that msg has not expired and that its signature matches the
// public key recovered from the sender's ID. It returns ErrMessageExpired for
// expired messages and ErrSignatureVerification for a signature that does not
// match.
func (s *BasicSecurityContext) Verify(msg Envelope) error {
	if msg.Expired() {
		return ErrMessageExpired
	}

	senderKey, err := crypto.PublicKeyFromID(msg.From.ID)
	if err != nil {
		return fmt.Errorf("public key from id: %w", err)
	}

	signed, err := msg.SignatureData()
	if err != nil {
		return fmt.Errorf("signature data: %w", err)
	}

	valid, err := senderKey.Verify(signed, msg.Signature)
	if err != nil {
		return fmt.Errorf("verify message signature: %w", err)
	}
	if valid {
		return nil
	}
	return ErrSignatureVerification
}
// Sign signs msg with the context's private key and stores the signature in
// the envelope. Only messages sent by this context's identity can be signed;
// others are refused with ErrBadSender.
func (s *BasicSecurityContext) Sign(msg *Envelope) error {
	if isSender := s.id.Equal(msg.From.ID); !isSender {
		return ErrBadSender
	}

	toSign, err := msg.SignatureData()
	if err != nil {
		return fmt.Errorf("signature data: %w", err)
	}

	signature, err := s.privk.Sign(toSign)
	if err != nil {
		return fmt.Errorf("signing message: %w", err)
	}

	msg.Signature = signature
	return nil
}
// Discard hands the capability tokens carried by msg to the capability
// context's Discard.
func (s *BasicSecurityContext) Discard(msg Envelope) {
s.cap.Discard(msg.Capability)
}
package cmd
import (
"os"
"github.com/spf13/cobra"
)
// newAutoCompleteCmd builds the hidden "autocomplete" command, which writes a
// shell completion script for the root command to standard output.
//
// Exactly one argument is accepted and it must be "bash" or "zsh". A failure
// while generating the script is returned to cobra instead of being dropped,
// so the command exits with an error rather than silently emitting a
// truncated script.
func newAutoCompleteCmd() *cobra.Command {
	return &cobra.Command{
		Use:   "autocomplete [shell]",
		Short: "Generate autocomplete script for your shell",
		Long: `Generate an autocomplete script for the nunet CLI.
This command supports Bash and Zsh shells.`,
		DisableFlagsInUseLine: true,
		Hidden:                true,
		ValidArgs:             []string{"bash", "zsh"},
		Args:                  cobra.MatchAll(cobra.ExactArgs(1), cobra.OnlyValidArgs),
		RunE: func(cmd *cobra.Command, args []string) error {
			switch args[0] {
			case "bash":
				return cmd.Root().GenBashCompletion(os.Stdout)
			case "zsh":
				return cmd.Root().GenZshCompletion(os.Stdout)
			}
			// Unreachable: Args only lets the two shells above through.
			return nil
		},
	}
}
package cmd
import (
"encoding/json"
"fmt"
"os"
"os/exec"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/internal/config"
)
// newConfigCmd builds the "config" command with its get, set and edit
// subcommands. fs is handed to the set subcommand; a nil fs is reported
// through cobra.CheckErr.
func newConfigCmd(fs afero.Fs) *cobra.Command {
	if fs == nil {
		cobra.CheckErr("Fs is nil")
	}

	configCmd := &cobra.Command{
		Use:   "config",
		Short: "Manage configuration file",
		Long: `Utility to manage user's configuration file via command-line
Search for the configuration file is done in the following locations and order:
1. "." (current directory)
2. "$HOME/.nunet"
3. "/etc/nunet"`,
	}

	configCmd.AddCommand(
		newConfigGetCmd(),
		newConfigSetCmd(fs),
		newConfigEditCmd(),
	)
	return configCmd
}
// newConfigGetCmd builds "config get". Without an argument it prints the whole
// configuration as indented JSON; with a key it prints that key's value.
func newConfigGetCmd() *cobra.Command {
	run := func(cmd *cobra.Command, args []string) error {
		// No key given: dump the complete configuration.
		if len(args) == 0 {
			whole, err := json.MarshalIndent(config.GetConfig(), "", " ")
			if err != nil {
				return fmt.Errorf("failed to indent config JSON: %w", err)
			}
			cmd.Println(string(whole))
			return nil
		}

		key := args[0]
		value, err := config.Get(key)
		if err != nil {
			return fmt.Errorf("could not get key's value: %w", err)
		}

		formatted, err := json.MarshalIndent(value, "", " ")
		if err != nil {
			return fmt.Errorf("failed to indent JSON: %w", err)
		}
		cmd.Println(string(formatted))
		return nil
	}

	return &cobra.Command{
		Use:   "get <key>",
		Short: "Display configuration",
		Long: `Display the value for a configuration key
It reads the value from configuration file, otherwise it return default values
Example:
nunet config get rest.port`,
		Args: cobra.MaximumNArgs(1),
		RunE: run,
	}
}
// newConfigSetCmd builds "config set", which stores a value for a
// configuration key through config.Set on fs. It reports whether the file is
// being created or updated before applying the change.
func newConfigSetCmd(fs afero.Fs) *cobra.Command {
	run := func(cmd *cobra.Command, args []string) error {
		exists, err := config.FileExists(fs)
		if err != nil {
			return fmt.Errorf("failed to check if config file exists: %w", err)
		}

		notice := "Updating existing config file..."
		if !exists {
			notice = "Config file did not exist. Creating new file..."
		}
		cmd.Println(notice)

		key, value := args[0], args[1]
		if err := config.Set(fs, key, value); err != nil {
			return fmt.Errorf("failed to set config: %w", err)
		}

		cmd.Println("Applied changes.")
		return nil
	}

	return &cobra.Command{
		Use:   "set <key> <value>",
		Short: "Update configuration",
		Long: `Set value for a configuration key
It creates a configuration file if does not exists, otherwise it updates the existing file
Example:
nunet config set rest.port 4444`,
		Args: cobra.ExactArgs(2),
		RunE: run,
	}
}
// newConfigEditCmd builds "config edit", which opens the configuration file in
// the program named by $EDITOR. The editor is wired to the command's input and
// output streams, and the command fails when $EDITOR is unset.
func newConfigEditCmd() *cobra.Command {
	run := func(cmd *cobra.Command, _ []string) error {
		editor := os.Getenv("EDITOR")
		if editor == "" {
			return fmt.Errorf("$EDITOR not set")
		}

		cmd.Printf("Text editor: %s\n", editor)
		cmd.Printf("Config path: %s\n", config.GetPath())

		// do we need better sanitization?
		// do we check if editor is valid?
		editorProc := exec.Command(editor, config.GetPath())
		editorProc.Stdout = cmd.OutOrStdout()
		editorProc.Stdin = cmd.InOrStdin()
		editorProc.Stderr = cmd.OutOrStderr()
		return editorProc.Run()
	}

	return &cobra.Command{
		Use:   "edit",
		Short: "Edit configuration",
		Long: `Open configuration file with default text editor
This command search the configuration file and open it with the default text editor
It reads the $EDITOR environment variable and it fails if it's not set`,
		Args: cobra.NoArgs,
		RunE: run,
	}
}
package cmd
import (
"fmt"
"os"
"path/filepath"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/cmd/utils"
"gitlab.com/nunet/device-management-service/dms"
"gitlab.com/nunet/device-management-service/internal/config"
"gitlab.com/nunet/device-management-service/lib/crypto/keystore"
"gitlab.com/nunet/device-management-service/lib/did"
dmsUtils "gitlab.com/nunet/device-management-service/utils"
)
// newKeyCmd builds the "key" command with its "new" and "did" subcommands,
// both of which operate on the keystore found on fs.
func newKeyCmd(fs afero.Afero) *cobra.Command {
	keyCmd := &cobra.Command{
		Use:   "key",
		Short: "Manage cryptographic keys",
		Long: `Manage cryptographic keys for the Device Management Service (DMS).
This command provides subcommands for creating new keys and retrieving Decentralized Identifiers (DIDs) associated with existing keys.`,
	}

	keyCmd.AddCommand(
		newKeyNewCmd(fs),
		newKeyDIDCmd(fs),
	)
	return keyCmd
}
// newKeyNewCmd builds "key new", which generates a key pair, stores the
// private key in the user's keystore on fs and prints the key's DID.
//
// If a key with the requested name already exists the user is asked to confirm
// the overwrite. The passphrase is taken from the DMS_PASSPHRASE environment
// variable and prompted for when that is empty.
//
// The error from generating and storing the key is wrapped into the returned
// error, so the user sees why the operation failed.
func newKeyNewCmd(fs afero.Afero) *cobra.Command {
	return &cobra.Command{
		Use:   "new <name>",
		Short: "Generate a key pair",
		Long: `Generate a key pair and save the private key into the user's local keystore.
This command creates a new cryptographic key pair, stores the private key securely, and displays the associated Decentralized Identifier (DID). If a key with the specified name already exists, the user will be prompted to confirm before overwriting it.
Example:
nunet key new user`,
		Args: cobra.ExactArgs(1),
		RunE: func(cmd *cobra.Command, args []string) error {
			keyStoreDir := filepath.Join(config.GetConfig().General.UserDir, dms.KeystoreDir)
			ks, err := keystore.New(fs, keyStoreDir)
			if err != nil {
				return fmt.Errorf("failed to create keystore: %w", err)
			}

			keyID := dms.UserContextName
			if len(args) > 0 {
				keyID = args[0]
			}

			keys, err := ks.ListKeys()
			if err != nil {
				return fmt.Errorf("failed to list keys: %w", err)
			}

			if dmsUtils.SliceContains(keys, keyID) {
				confirmed, err := utils.PromptYesNo(
					cmd.InOrStdin(),
					cmd.OutOrStdout(),
					fmt.Sprintf("A key with name '%s' already exists. Do you want to overwrite it with a new one?", keyID),
				)
				if err != nil {
					return fmt.Errorf("failed to get user confirmation: %w", err)
				}
				if !confirmed {
					return fmt.Errorf("operation cancelled by user")
				}
			}

			passphrase := os.Getenv("DMS_PASSPHRASE")
			if passphrase == "" {
				passphrase, err = utils.PromptForPassphrase(true)
				if err != nil {
					return fmt.Errorf("failed to get passphrase: %w", err)
				}
			}

			priv, err := dms.GenerateAndStorePrivKey(ks, passphrase, keyID)
			if err != nil {
				return fmt.Errorf("failed to generate and store new private key: %w", err)
			}

			// Named keyDID so the did package is not shadowed.
			keyDID := did.FromPublicKey(priv.GetPublic())
			fmt.Println(keyDID)
			return nil
		},
	}
}
// newKeyDIDCmd builds the "key did <name>" subcommand.
//
// For the special name "ledger" the DID is read from a ledger wallet provider
// created for account index 0. For any other name the key is loaded from the
// keystore under the configured user directory, using the passphrase from the
// DMS_PASSPHRASE environment variable or an interactive prompt, and the DID
// derived from its public key is printed to standard output.
func newKeyDIDCmd(fs afero.Afero) *cobra.Command {
	return &cobra.Command{
		Use:   "did <name>",
		Short: "Get a key's DID",
		Long: `Get the DID (Decentralized Identifier) for a specified key.
This command retrieves the DID associated with either a key stored in the local keystore or a hardware ledger.
For keys in the local keystore, the user will be prompted for the passphrase to decrypt the key. To avoid passphrase prompting, it's possible to set a DMS_PASSPHRASE environment variable. For the ledger option, it uses the first account (index 0) on the connected hardware wallet.
Example:
nunet key did user
nunet key did ledger # if using ledger`,
		Args: cobra.ExactArgs(1),
		RunE: func(_ *cobra.Command, args []string) error {
			keyName := args[0]

			if keyName == "ledger" {
				provider, err := did.NewLedgerWalletProvider(0)
				if err != nil {
					return err
				}
				fmt.Println(provider.DID())
				return nil
			}

			keyStoreDir := filepath.Join(config.GetConfig().General.UserDir, dms.KeystoreDir)
			ks, err := keystore.New(fs, keyStoreDir)
			if err != nil {
				return fmt.Errorf("failed to open keystore: %w", err)
			}

			passphrase := os.Getenv("DMS_PASSPHRASE")
			if passphrase == "" {
				passphrase, err = utils.PromptForPassphrase(false)
				if err != nil {
					return fmt.Errorf("failed to get passphrase: %w", err)
				}
			}

			key, err := ks.Get(keyName, passphrase)
			if err != nil {
				return fmt.Errorf("failed to get key: %w", err)
			}

			priv, err := key.PrivKey()
			if err != nil {
				// Wrap with %w like every other error in this command so the
				// cause stays inspectable with errors.Is / errors.As.
				return fmt.Errorf("unable to convert key from keystore to private key: %w", err)
			}

			keyDID := did.FromPublicKey(priv.GetPublic())
			fmt.Println(keyDID)
			return nil
		},
	}
}
package cmd
import (
"github.com/spf13/afero"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/cmd/actor"
"gitlab.com/nunet/device-management-service/cmd/cap"
"gitlab.com/nunet/device-management-service/utils"
)
// newRootCmd assembles the top-level "nunet" command and attaches every
// subcommand to it. Invoking the root command without a subcommand prints
// the help text.
func newRootCmd(client *utils.HTTPClient, afs afero.Afero) *cobra.Command {
	showHelp := func(c *cobra.Command, _ []string) {
		_ = c.Help()
	}

	root := &cobra.Command{
		Use:   "nunet",
		Short: "NuNet Device Management Service",
		Long:  `The Device Management Service (DMS) Command Line Interface (CLI)`,
		CompletionOptions: cobra.CompletionOptions{
			DisableDefaultCmd: false,
			HiddenDefaultCmd:  true,
		},
		SilenceErrors: true,
		SilenceUsage:  true,
		Run:           showHelp,
	}

	// Subcommands are registered in the same order as before.
	root.AddCommand(
		newRunCmd(),
		newKeyCmd(afs),
		cap.NewCapCmd(afs),
		actor.NewActorCmd(client, afs),
		newConfigCmd(afs.Fs),
		newAutoCompleteCmd(),
		newVersionCmd(),
		newTapCommand(),
	)

	return root
}
//go:build linux
package cmd
import (
"fmt"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/internal/config"
"gitlab.com/nunet/device-management-service/utils"
)
// Execute builds the root command and runs it.
// It wires an OS-backed filesystem and an HTTP client pointed at the
// configured address and port (API prefix "/api/v1"), and hands any error
// returned by the command to cobra.CheckErr.
func Execute() {
	afs := afero.Afero{Fs: afero.NewOsFs()}

	host := config.GetConfig().Addr
	port := config.GetConfig().Port
	baseURL := fmt.Sprintf("http://%s:%d", host, port)
	client := utils.NewHTTPClient(baseURL, "/api/v1")

	root := newRootCmd(client, afs)
	cobra.CheckErr(root.Execute())
}
package cmd
import (
"fmt"
"os"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/cmd/utils"
"gitlab.com/nunet/device-management-service/dms"
)
// newRunCmd builds the "run" subcommand, which starts DMS via dms.Run.
// The passphrase comes from the DMS_PASSPHRASE environment variable when it
// is set and from an interactive prompt otherwise; an empty prompted
// passphrase is rejected. The capability context is taken from the
// --context/-c flag.
func newRunCmd() *cobra.Command {
	var capContext string

	run := func(_ *cobra.Command, _ []string) error {
		if fromEnv := os.Getenv("DMS_PASSPHRASE"); fromEnv != "" {
			return dms.Run(fromEnv, capContext)
		}

		fmt.Print("Please enter the DMS passphrase. This will be used to encrypt/decrypt the keystore containing necessary secrets for DMS:\n")
		prompted, err := utils.PromptForPassphrase(false)
		if err != nil {
			return fmt.Errorf("error reading passphrase from stdin: %w", err)
		}
		// TODO: validate passphrase (minimum x characters)
		if prompted == "" {
			return fmt.Errorf("invalid passphrase")
		}
		return dms.Run(prompted, capContext)
	}

	cmd := &cobra.Command{
		Use:   "run",
		Short: "Start the Device Management Service",
		Long: `Start the Device Management Service
The Device Management Service (DMS) is a system application for running a node in the NuNet decentralized network of compute providers.
By default, DMS listens on port 9999. For more information on configuration, see:
nunet config --help
Or manually create a dms_config.json file and refer to the README for available settings.`,
		RunE: run,
	}
	cmd.Flags().StringVarP(&capContext, "context", "c", dms.DefaultContextName, "specify a capability context")
	return cmd
}
package cmd
import (
"fmt"
"io"
"os"
"os/exec"
"github.com/spf13/cobra"
)
// newTapCommand creates the Cobra command to set up a TAP interface.
//
// The command takes exactly three arguments (main interface, VM interface and
// CIDR), refuses to run unless the effective user is root, and then runs a
// fixed sequence of shell commands: create the TAP device, assign the address,
// bring the link up, enable IPv4 forwarding and add the iptables FORWARD
// rules if they are not present yet.
//
// Because the arguments are interpolated into strings executed by `sh -c` as
// root, they are validated against a conservative character allow-list first
// so that shell metacharacters or option-like values can never reach the
// shell.
func newTapCommand() *cobra.Command {
	return &cobra.Command{
		Use:   "tap [main_interface] [vm_interface] [CIDR]",
		Short: "Create and configure a TAP interface",
		Long: `Create a TAP interface using the provided interface name and configure IP forwarding and iptables rules.
Example:
nunet tap eth0 tap0 172.16.0.1/24
Note: The command requires root privileges.
`,
		Args: cobra.ExactArgs(3),
		RunE: func(cmd *cobra.Command, args []string) error {
			// check if the user is root
			if os.Getuid() != 0 {
				return fmt.Errorf("this command requires root privileges to execute")
			}

			mainInterface, vmInterface, cidr := args[0], args[1], args[2]

			// Reject anything that could be interpreted by the shell before
			// a single command is executed.
			if err := validateTapArg("main interface", mainInterface, isInterfaceNameRune); err != nil {
				return err
			}
			if err := validateTapArg("vm interface", vmInterface, isInterfaceNameRune); err != nil {
				return err
			}
			if err := validateTapArg("CIDR", cidr, isCIDRRune); err != nil {
				return err
			}

			steps := []string{
				// Step 1: Create the TAP interface
				fmt.Sprintf("ip tuntap add %s mode tap", vmInterface),
				// Step 2: Assign IP address to the TAP interface
				fmt.Sprintf("ip addr add %s dev %s", cidr, vmInterface),
				// Step 3: Bring the TAP interface up
				fmt.Sprintf("ip link set %s up", vmInterface),
				// Step 4: Enable IP forwarding
				"echo 1 > /proc/sys/net/ipv4/ip_forward",
				// Step 5: Add iptables rules for connection tracking
				"iptables -C FORWARD -m conntrack --ctstate RELATED,ESTABLISHED -j ACCEPT || iptables -A FORWARD -m conntrack --ctstate RELATED,ESTABLISHED -j ACCEPT",
				// Step 6: Add iptables rules to allow forwarding between interfaces
				fmt.Sprintf("iptables -C FORWARD -i %s -o %s -j ACCEPT || iptables -A FORWARD -i %s -o %s -j ACCEPT", vmInterface, mainInterface, vmInterface, mainInterface),
			}

			// Stop at the first failing step, exactly as before.
			for _, step := range steps {
				if err := runCommand(cmd.OutOrStdout(), step); err != nil {
					return err
				}
			}

			fmt.Fprintf(cmd.OutOrStdout(), "TAP interface %s created and configured successfully\n", vmInterface)
			return nil
		},
	}
}

// validateTapArg returns an error when value is empty, starts with '-' (and
// could therefore be parsed as a command-line option) or contains a rune that
// the allowed predicate rejects.
func validateTapArg(label, value string, allowed func(rune) bool) error {
	if value == "" {
		return fmt.Errorf("invalid %s: value is empty", label)
	}
	if value[0] == '-' {
		return fmt.Errorf("invalid %s %q: must not start with '-'", label, value)
	}
	for _, r := range value {
		if !allowed(r) {
			return fmt.Errorf("invalid %s %q: unexpected character %q", label, value, r)
		}
	}
	return nil
}

// isInterfaceNameRune reports whether r is accepted in a network interface
// name: ASCII letters, digits and the punctuation '_', '-', '.', ':' and '@'.
func isInterfaceNameRune(r rune) bool {
	switch {
	case r >= 'a' && r <= 'z', r >= 'A' && r <= 'Z', r >= '0' && r <= '9':
		return true
	}
	return r == '_' || r == '-' || r == '.' || r == ':' || r == '@'
}

// isCIDRRune reports whether r may appear in an IPv4 or IPv6 CIDR string:
// hexadecimal digits plus '.', ':' and '/'.
func isCIDRRune(r rune) bool {
	switch {
	case r >= '0' && r <= '9', r >= 'a' && r <= 'f', r >= 'A' && r <= 'F':
		return true
	}
	return r == '.' || r == ':' || r == '/'
}
// runCommand runs command through `sh -c`.
// On success it writes "Executed: <command>" to stdout and returns nil; the
// command's own output is discarded. On failure it returns an error that
// wraps the execution error and includes the combined stdout/stderr output.
func runCommand(stdout io.Writer, command string) error {
	combined, runErr := exec.Command("sh", "-c", command).CombinedOutput()
	if runErr == nil {
		fmt.Fprintf(stdout, "Executed: %s\n", command)
		return nil
	}
	return fmt.Errorf("failed to execute command: %s, Error: %w, Output: %s", command, runErr, combined)
}
package cmd
import (
"fmt"
"github.com/spf13/cobra"
)
// Version is the version of the Nunet Device Management Service
// TODO: use git describe after this release
var Version = "v0.5.0-boot"

// Build metadata printed by the version command.
// NOTE(review): nothing visible here assigns these variables; presumably they
// are injected at build time (e.g. via -ldflags "-X ...") — confirm against
// the build scripts.
var (
	GoVersion string // Go toolchain version string shown as "Go Version"
	BuildDate string // build timestamp shown as "Build Date"
	Commit    string // source revision shown as "Commit"
)
// newVersionCmd builds the "version" subcommand, which prints the service
// name followed by Version, Commit, GoVersion and BuildDate to standard
// output.
func newVersionCmd() *cobra.Command {
	printVersion := func(_ *cobra.Command, _ []string) {
		// TODO get the version from git; make a top level version.go file
		fmt.Println("NuNet Device Management Service")
		fmt.Printf("Version: %s\nCommit: %s\n\nGo Version: %s\nBuild Date: %s\n",
			Version, Commit, GoVersion, BuildDate)
	}

	cmd := new(cobra.Command)
	cmd.Use = "version"
	cmd.Short = "Information about current version"
	cmd.Long = `Display information about the current Device Management Service (DMS) version`
	cmd.Run = printVersion
	return cmd
}
package clover
import (
"fmt"
clover "github.com/ostafen/clover/v2"
)
// NewDB opens the clover database at path and makes sure every collection in
// collections exists.
//
// Collections that are already present are left untouched, so the function
// can be called again on a database created by an earlier run. If anything
// fails after the database was opened, the handle is closed before the error
// is returned so the underlying store is not leaked.
func NewDB(path string, collections []string) (*clover.DB, error) {
	db, err := clover.Open(path)
	if err != nil {
		return nil, fmt.Errorf("failed to connect to database: %w", err)
	}

	for _, collection := range collections {
		exists, err := db.HasCollection(collection)
		if err != nil {
			// Best-effort cleanup; the lookup error is the one worth reporting.
			_ = db.Close()
			return nil, fmt.Errorf("failed to check collection %s: %w", collection, err)
		}
		if exists {
			continue
		}

		if err := db.CreateCollection(collection); err != nil {
			// Best-effort cleanup; the creation error is the one worth reporting.
			_ = db.Close()
			return nil, fmt.Errorf("failed to create collection %s: %w", collection, err)
		}
	}

	return db, nil
}
package clover
import (
clover "github.com/ostafen/clover/v2"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// DeploymentRequestFlatClover is the Clover-backed repository for
// types.DeploymentRequestFlat. All operations come from the embedded generic
// repository.
type DeploymentRequestFlatClover struct {
	repositories.GenericRepository[types.DeploymentRequestFlat]
}

// NewDeploymentRequestFlat returns a DeploymentRequestFlatClover whose
// embedded generic repository is bound to db.
func NewDeploymentRequestFlat(db *clover.DB) repositories.DeploymentRequestFlat {
	generic := NewGenericRepository[types.DeploymentRequestFlat](db)
	return &DeploymentRequestFlatClover{GenericRepository: generic}
}
package clover
import (
"github.com/ostafen/clover/v2"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// RequestTrackerClover is the Clover-backed repository for
// types.RequestTracker. All operations come from the embedded generic
// repository.
type RequestTrackerClover struct {
	repositories.GenericRepository[types.RequestTracker]
}

// NewRequestTracker returns a RequestTrackerClover whose embedded generic
// repository is bound to db.
func NewRequestTracker(db *clover.DB) repositories.RequestTracker {
	generic := NewGenericRepository[types.RequestTracker](db)
	return &RequestTrackerClover{GenericRepository: generic}
}
package clover
import (
"github.com/ostafen/clover/v2"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// VirtualMachineClover is the Clover-backed repository for
// types.VirtualMachine. All operations come from the embedded generic
// repository.
type VirtualMachineClover struct {
	repositories.GenericRepository[types.VirtualMachine]
}

// NewVirtualMachine returns a VirtualMachineClover whose embedded generic
// repository is bound to db.
func NewVirtualMachine(db *clover.DB) repositories.VirtualMachine {
	generic := NewGenericRepository[types.VirtualMachine](db)
	return &VirtualMachineClover{GenericRepository: generic}
}
package clover
import (
"context"
"fmt"
"reflect"
"time"
"github.com/iancoleman/strcase"
clover "github.com/ostafen/clover/v2"
clover_d "github.com/ostafen/clover/v2/document"
clover_q "github.com/ostafen/clover/v2/query"
"gitlab.com/nunet/device-management-service/db/repositories"
)
// GenericEntityRepositoryClover is the Clover implementation of a
// single-entity repository. Concrete entity repositories embed it to get
// their database operations.
type GenericEntityRepositoryClover[T repositories.ModelType] struct {
	db         *clover.DB // db is the Clover database instance.
	collection string     // collection is the name of the collection in the database.
}

// NewGenericEntityRepository returns a repository bound to db whose
// collection name is the snake_case form of T's type name.
func NewGenericEntityRepository[T repositories.ModelType](
	db *clover.DB,
) repositories.GenericEntityRepository[T] {
	var zero T

	repo := new(GenericEntityRepositoryClover[T])
	repo.db = db
	repo.collection = strcase.ToSnake(reflect.TypeOf(zero).Name())
	return repo
}
// GetQuery returns a zero-valued Query for T, ready to be filled in by the
// caller.
func (repo *GenericEntityRepositoryClover[T]) GetQuery() repositories.Query[T] {
	var empty repositories.Query[T]
	return empty
}
// query returns a new Clover query scoped to this repository's collection.
func (repo *GenericEntityRepositoryClover[T]) query() *clover_q.Query {
	scoped := clover_q.NewQuery(repo.collection)
	return scoped
}
// Save inserts data as a new document stamped with the current time in its
// "CreatedAt" field and returns the model decoded back from that document.
//
// On an insert failure the input data is returned together with the mapped
// database error. A decoding failure is reported with
// repositories.ErrParsingModel wrapped in the error chain, so callers can
// detect it with errors.Is.
func (repo *GenericEntityRepositoryClover[T]) Save(_ context.Context, data T) (T, error) {
	doc := toCloverDoc(data)
	doc.Set("CreatedAt", time.Now())

	if _, err := repo.db.InsertOne(repo.collection, doc); err != nil {
		return data, handleDBError(err)
	}

	model, err := toModel[T](doc, true)
	if err != nil {
		// %w (not %v) keeps the sentinel reachable through errors.Is.
		return model, handleDBError(fmt.Errorf("%w: %v", repositories.ErrParsingModel, err))
	}
	return model, nil
}
// Get returns the most recently saved record, i.e. the first document when
// the collection is sorted by "CreatedAt" in descending order.
//
// It returns repositories.ErrNotFound when the collection holds no document;
// previously that case produced a zero-valued model with a nil error, which
// callers could not tell apart from a real record.
func (repo *GenericEntityRepositoryClover[T]) Get(_ context.Context) (T, error) {
	var model T

	newestFirst := repo.query().Sort(clover_q.SortOption{
		Field:     "CreatedAt",
		Direction: -1,
	})

	doc, err := repo.db.FindFirst(newestFirst)
	if err != nil {
		return model, handleDBError(err)
	}
	if doc == nil {
		return model, repositories.ErrNotFound
	}

	model, err = toModel[T](doc, true)
	if err != nil {
		return model, fmt.Errorf("failed to convert document to model: %w", err)
	}
	return model, nil
}
// Clear deletes every document in the repository's collection, i.e. the
// record together with its history. Errors are mapped through handleDBError
// like in the other repository methods.
func (repo *GenericEntityRepositoryClover[T]) Clear(_ context.Context) error {
	return handleDBError(repo.db.Delete(repo.query()))
}
// History returns the documents of the collection that match query, decoded
// into models.
//
// Iteration stops at the first document that cannot be decoded. That failure
// is now reported (with repositories.ErrParsingModel in the error chain)
// instead of being silently dropped, which used to yield a truncated result
// with a nil error. The models decoded before the failure are still returned.
func (repo *GenericEntityRepositoryClover[T]) History(_ context.Context, query repositories.Query[T]) ([]T, error) {
	var (
		models   []T
		parseErr error
	)

	q := applyConditions(repo.query(), query)

	err := repo.db.ForEach(q, func(doc *clover_d.Document) bool {
		var model T
		if unmarshalErr := doc.Unmarshal(&model); unmarshalErr != nil {
			parseErr = handleDBError(fmt.Errorf("%w: %v", repositories.ErrParsingModel, unmarshalErr))
			return false
		}
		models = append(models, model)
		return true
	})
	if err != nil {
		return models, handleDBError(err)
	}
	if parseErr != nil {
		return models, parseErr
	}
	return models, nil
}
package clover
import (
"context"
"fmt"
"reflect"
"time"
"github.com/iancoleman/strcase"
clover "github.com/ostafen/clover/v2"
clover_d "github.com/ostafen/clover/v2/document"
clover_q "github.com/ostafen/clover/v2/query"
"gitlab.com/nunet/device-management-service/db/repositories"
)
// Document field names used by the generic repository queries.
const (
	// pkField is the document field matched when looking a record up by id.
	pkField = "_id"
	// deletedAtField is the field the repository filters on to exclude
	// soft-deleted records.
	deletedAtField = "DeletedAt"
)
// GenericRepositoryClover is the Clover implementation of the generic
// repository. Concrete model repositories embed it to get their database
// operations.
type GenericRepositoryClover[T repositories.ModelType] struct {
	db         *clover.DB // db is the Clover database instance.
	collection string     // collection is the name of the collection in the database.
}

// NewGenericRepository returns a repository bound to db whose collection
// name is the snake_case form of T's type name.
func NewGenericRepository[T repositories.ModelType](
	db *clover.DB,
) repositories.GenericRepository[T] {
	var zero T

	repo := new(GenericRepositoryClover[T])
	repo.db = db
	repo.collection = strcase.ToSnake(reflect.TypeOf(zero).Name())
	return repo
}
// GetQuery returns a zero-valued Query for T, ready to be filled in by the
// caller.
func (repo *GenericRepositoryClover[T]) GetQuery() repositories.Query[T] {
	var empty repositories.Query[T]
	return empty
}
// query returns a Clover query scoped to this repository's collection.
// Unless includeDeleted is true, the query only matches documents whose
// DeletedAt value is at or before the Unix epoch.
func (repo *GenericRepositoryClover[T]) query(includeDeleted bool) *clover_q.Query {
	base := clover_q.NewQuery(repo.collection)
	if includeDeleted {
		return base
	}

	notDeleted := clover_q.Field(deletedAtField).LtEq(time.Unix(0, 0))
	return base.Where(notDeleted)
}
// queryWithID returns the repository query narrowed to the document whose
// primary key equals id.
//
// The id is handed to the equality criterion as-is. String ids behave exactly
// as before; an id of any other type no longer panics on a failed type
// assertion and simply produces a query for that value.
func (repo *GenericRepositoryClover[T]) queryWithID(
	id interface{},
	includeDeleted bool,
) *clover_q.Query {
	return repo.query(includeDeleted).Where(clover_q.Field(pkField).Eq(id))
}
// Create inserts data as a new document stamped with the current time in its
// "CreatedAt" field and returns the model decoded back from that document,
// including the document's object id in its ID field.
//
// On failure the input data is returned together with the error. A decoding
// failure carries repositories.ErrParsingModel in its error chain so callers
// can detect it with errors.Is.
func (repo *GenericRepositoryClover[T]) Create(_ context.Context, data T) (T, error) {
	doc := toCloverDoc(data)
	doc.Set("CreatedAt", time.Now())

	if _, err := repo.db.InsertOne(repo.collection, doc); err != nil {
		return data, handleDBError(err)
	}

	model, err := toModel[T](doc, false)
	if err != nil {
		// %w (not %v) keeps the sentinel reachable through errors.Is.
		return data, handleDBError(fmt.Errorf("%w: %v", repositories.ErrParsingModel, err))
	}
	return model, nil
}
// Get returns the non-deleted record whose primary key equals id.
//
// It returns repositories.ErrNotFound when no such document exists;
// previously that case produced a zero-valued model with a nil error. A
// decoding failure carries repositories.ErrParsingModel in its error chain.
func (repo *GenericRepositoryClover[T]) Get(_ context.Context, id interface{}) (T, error) {
	var model T

	doc, err := repo.db.FindFirst(repo.queryWithID(id, false))
	if err != nil {
		return model, handleDBError(err)
	}
	if doc == nil {
		return model, repositories.ErrNotFound
	}

	model, err = toModel[T](doc, false)
	if err != nil {
		return model, handleDBError(fmt.Errorf("%w: %v", repositories.ErrParsingModel, err))
	}
	return model, nil
}
// Update overwrites the fields of the non-deleted record identified by id
// with the fields of data, sets "UpdatedAt" to the current time and returns
// the record as re-read from the database.
//
// When the update itself fails, data is returned with the mapped error. The
// error of the follow-up Get is returned unchanged: Get already maps database
// errors, and mapping them a second time would turn e.g. a not-found error
// into a generic database error.
func (repo *GenericRepositoryClover[T]) Update(
	ctx context.Context,
	id interface{},
	data T,
) (T, error) {
	updates := toCloverDoc(data).AsMap()
	updates["UpdatedAt"] = time.Now()

	if err := repo.db.Update(repo.queryWithID(id, false), updates); err != nil {
		return data, handleDBError(err)
	}

	return repo.Get(ctx, id)
}
// Delete soft-deletes the record identified by id by setting its DeletedAt
// field to the current time; the document itself stays in the collection.
// Errors are mapped through handleDBError like in the other repository
// methods.
func (repo *GenericRepositoryClover[T]) Delete(_ context.Context, id interface{}) error {
	err := repo.db.Update(
		repo.queryWithID(id, false),
		map[string]interface{}{deletedAtField: time.Now()},
	)
	return handleDBError(err)
}
// Find returns the first non-deleted record matching query.
//
// A missing document is reported through handleDBError as a not-found error.
// A decoding failure is wrapped with %w so its cause stays inspectable.
func (repo *GenericRepositoryClover[T]) Find(
	_ context.Context,
	query repositories.Query[T],
) (T, error) {
	var model T

	q := applyConditions(repo.query(false), query)

	doc, err := repo.db.FindFirst(q)
	if err != nil {
		return model, handleDBError(err)
	}
	if doc == nil {
		return model, handleDBError(clover.ErrDocumentNotExist)
	}

	model, err = toModel[T](doc, false)
	if err != nil {
		return model, fmt.Errorf("failed to convert document to model: %w", err)
	}
	return model, nil
}
// FindAll returns every non-deleted record matching query.
//
// Iteration stops at the first document that cannot be decoded; the models
// decoded so far are returned together with an error that carries
// repositories.ErrParsingModel in its chain. A database error takes
// precedence over a decoding error.
func (repo *GenericRepositoryClover[T]) FindAll(
	_ context.Context,
	query repositories.Query[T],
) ([]T, error) {
	var (
		models   []T
		parseErr error
	)

	q := applyConditions(repo.query(false), query)

	err := repo.db.ForEach(q, func(doc *clover_d.Document) bool {
		model, convErr := toModel[T](doc, false)
		if convErr != nil {
			// %w (not %v) keeps the sentinel reachable through errors.Is.
			parseErr = handleDBError(fmt.Errorf("%w: %v", repositories.ErrParsingModel, convErr))
			return false
		}
		models = append(models, model)
		return true
	})
	if err != nil {
		return models, handleDBError(err)
	}
	if parseErr != nil {
		return models, parseErr
	}
	return models, nil
}
// applyConditions narrows the Clover query q according to a generic
// repositories.Query and returns the resulting query.
//
// It applies, in order:
//   - every explicit condition ("=", ">", ">=", "<", "<=", "!=", "IN", "LIKE");
//     field names are translated to their JSON tag name first. An "IN" value
//     that is not a []interface{} and a "LIKE" value that is not a string are
//     ignored, as are unknown operators;
//   - an equality condition for every exported, non-empty field of
//     query.Instance (only when Instance is a struct value);
//   - sorting by query.SortBy, descending when the name is prefixed with '-'.
//     The field name is translated to its JSON tag name in both directions;
//   - query.Limit as the maximum number of results;
//   - query.Offset as the number of results to skip. This used to be applied
//     with Limit, which overwrote the limit instead of skipping anything.
func applyConditions[T repositories.ModelType](
	q *clover_q.Query,
	query repositories.Query[T],
) *clover_q.Query {
	// Apply conditions specified in the query.
	for _, condition := range query.Conditions {
		// change the field name to json tag name if specified in the struct
		field := clover_q.Field(fieldJSONTag[T](condition.Field))

		switch condition.Operator {
		case "=":
			q = q.Where(field.Eq(condition.Value))
		case ">":
			q = q.Where(field.Gt(condition.Value))
		case ">=":
			q = q.Where(field.GtEq(condition.Value))
		case "<":
			q = q.Where(field.Lt(condition.Value))
		case "<=":
			q = q.Where(field.LtEq(condition.Value))
		case "!=":
			q = q.Where(field.Neq(condition.Value))
		case "IN":
			if values, ok := condition.Value.([]interface{}); ok {
				q = q.Where(field.In(values...))
			}
		case "LIKE":
			if value, ok := condition.Value.(string); ok {
				q = q.Where(field.Like(value))
			}
		}
	}

	// Apply conditions based on non-zero values in the query instance.
	if !repositories.IsEmptyValue(query.Instance) {
		exampleType := reflect.TypeOf(query.Instance)
		exampleValue := reflect.ValueOf(query.Instance)

		// NumField/Field panic on anything but a struct.
		if exampleType != nil && exampleType.Kind() == reflect.Struct {
			for i := 0; i < exampleType.NumField(); i++ {
				structField := exampleType.Field(i)
				// Interface() panics on unexported fields, and they are never
				// part of the stored document anyway.
				if !structField.IsExported() {
					continue
				}

				fieldValue := exampleValue.Field(i).Interface()
				if repositories.IsEmptyValue(fieldValue) {
					continue
				}
				q = q.Where(clover_q.Field(fieldJSONTag[T](structField.Name)).Eq(fieldValue))
			}
		}
	}

	// Apply sorting if specified in the query.
	if query.SortBy != "" {
		sortField, dir := query.SortBy, 1
		if sortField[0] == '-' {
			dir = -1
			sortField = sortField[1:]
		}
		q = q.Sort(clover_q.SortOption{Field: fieldJSONTag[T](sortField), Direction: dir})
	}

	// Apply limit if specified in the query.
	if query.Limit > 0 {
		q = q.Limit(query.Limit)
	}

	// Apply offset if specified in the query.
	if query.Offset > 0 {
		q = q.Skip(query.Offset)
	}

	return q
}
package clover
import (
"github.com/ostafen/clover/v2"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// PeerInfoClover is the Clover-backed repository for types.PeerInfo. All
// operations come from the embedded generic repository.
type PeerInfoClover struct {
	repositories.GenericRepository[types.PeerInfo]
}

// NewPeerInfo returns a PeerInfoClover whose embedded generic repository is
// bound to db.
func NewPeerInfo(db *clover.DB) repositories.PeerInfo {
	generic := NewGenericRepository[types.PeerInfo](db)
	return &PeerInfoClover{GenericRepository: generic}
}
// MachineClover is the Clover-backed repository for types.Machine. All
// operations come from the embedded generic repository.
type MachineClover struct {
	repositories.GenericRepository[types.Machine]
}

// NewMachine returns a MachineClover whose embedded generic repository is
// bound to db.
func NewMachine(db *clover.DB) repositories.Machine {
	generic := NewGenericRepository[types.Machine](db)
	return &MachineClover{GenericRepository: generic}
}
// AvailableResourcesClover is the Clover-backed single-entity repository for
// types.AvailableResources. All operations come from the embedded generic
// entity repository.
type AvailableResourcesClover struct {
	repositories.GenericEntityRepository[types.AvailableResources]
}

// AvailableResourcesRepositoryClover embeds the same generic entity
// repository as AvailableResourcesClover.
// NOTE(review): no constructor for this type is visible in this part of the
// file — confirm whether it is still used.
type AvailableResourcesRepositoryClover struct {
	repositories.GenericEntityRepository[types.AvailableResources]
}

// NewAvailableResources returns an AvailableResourcesClover whose embedded
// generic entity repository is bound to db.
func NewAvailableResources(db *clover.DB) repositories.AvailableResources {
	entity := NewGenericEntityRepository[types.AvailableResources](db)
	return &AvailableResourcesClover{GenericEntityRepository: entity}
}
// ServicesClover is the Clover-backed repository for types.Services. All
// operations come from the embedded generic repository.
type ServicesClover struct {
	repositories.GenericRepository[types.Services]
}

// NewServices returns a ServicesClover whose embedded generic repository is
// bound to db.
func NewServices(db *clover.DB) repositories.Services {
	generic := NewGenericRepository[types.Services](db)
	return &ServicesClover{GenericRepository: generic}
}
// ServiceResourceRequirementsClover is the Clover-backed repository for
// types.ServiceResourceRequirements. All operations come from the embedded
// generic repository.
type ServiceResourceRequirementsClover struct {
	repositories.GenericRepository[types.ServiceResourceRequirements]
}

// NewServiceResourceRequirements returns a ServiceResourceRequirementsClover
// whose embedded generic repository is bound to db.
func NewServiceResourceRequirements(
	db *clover.DB,
) repositories.ServiceResourceRequirements {
	generic := NewGenericRepository[types.ServiceResourceRequirements](db)
	return &ServiceResourceRequirementsClover{GenericRepository: generic}
}
// Libp2pInfoClover is the Clover-backed single-entity repository for
// types.Libp2pInfo. All operations come from the embedded generic entity
// repository.
type Libp2pInfoClover struct {
	repositories.GenericEntityRepository[types.Libp2pInfo]
}

// NewLibp2pInfo returns a Libp2pInfoClover whose embedded generic entity
// repository is bound to db.
func NewLibp2pInfo(db *clover.DB) repositories.Libp2pInfo {
	entity := NewGenericEntityRepository[types.Libp2pInfo](db)
	return &Libp2pInfoClover{GenericEntityRepository: entity}
}
// MachineUUIDClover is the Clover-backed single-entity repository for
// types.MachineUUID. All operations come from the embedded generic entity
// repository.
type MachineUUIDClover struct {
	repositories.GenericEntityRepository[types.MachineUUID]
}

// NewMachineUUID returns a MachineUUIDClover whose embedded generic entity
// repository is bound to db.
func NewMachineUUID(db *clover.DB) repositories.MachineUUID {
	entity := NewGenericEntityRepository[types.MachineUUID](db)
	return &MachineUUIDClover{GenericEntityRepository: entity}
}
// ConnectionClover is the Clover-backed repository for types.Connection. All
// operations come from the embedded generic repository.
type ConnectionClover struct {
	repositories.GenericRepository[types.Connection]
}

// NewConnection returns a ConnectionClover whose embedded generic repository
// is bound to db.
func NewConnection(db *clover.DB) repositories.Connection {
	generic := NewGenericRepository[types.Connection](db)
	return &ConnectionClover{GenericRepository: generic}
}
// ElasticTokenClover is the Clover-backed repository for types.ElasticToken.
// All operations come from the embedded generic repository.
type ElasticTokenClover struct {
	repositories.GenericRepository[types.ElasticToken]
}

// NewElasticToken returns an ElasticTokenClover whose embedded generic
// repository is bound to db.
func NewElasticToken(db *clover.DB) repositories.ElasticToken {
	generic := NewGenericRepository[types.ElasticToken](db)
	return &ElasticTokenClover{GenericRepository: generic}
}
package clover
import (
"github.com/ostafen/clover/v2"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/db/repositories"
)
// MachineResourcesRepositoryClover is the Clover-backed single-entity
// repository for types.MachineResources. All operations come from the
// embedded generic entity repository.
type MachineResourcesRepositoryClover struct {
	repositories.GenericEntityRepository[types.MachineResources]
}

// NewMachineResourcesRepository returns a MachineResourcesRepositoryClover
// whose embedded generic entity repository is bound to db.
func NewMachineResourcesRepository(db *clover.DB) repositories.MachineResources {
	entity := NewGenericEntityRepository[types.MachineResources](db)
	return &MachineResourcesRepositoryClover{GenericEntityRepository: entity}
}
// FreeResourcesClover is the Clover-backed single-entity repository for
// types.FreeResources. All operations come from the embedded generic entity
// repository.
type FreeResourcesClover struct {
	repositories.GenericEntityRepository[types.FreeResources]
}

// NewFreeResources returns a FreeResourcesClover whose embedded generic
// entity repository is bound to db.
func NewFreeResources(db *clover.DB) repositories.FreeResources {
	entity := NewGenericEntityRepository[types.FreeResources](db)
	return &FreeResourcesClover{GenericEntityRepository: entity}
}
// OnboardedResourcesClover is the Clover-backed single-entity repository for
// types.OnboardedResources. All operations come from the embedded generic
// entity repository.
type OnboardedResourcesClover struct {
	repositories.GenericEntityRepository[types.OnboardedResources]
}

// NewOnboardedResources returns an OnboardedResourcesClover whose embedded
// generic entity repository is bound to db.
func NewOnboardedResources(db *clover.DB) repositories.OnboardedResources {
	entity := NewGenericEntityRepository[types.OnboardedResources](db)
	return &OnboardedResourcesClover{GenericEntityRepository: entity}
}
// ResourceAllocationClover is the Clover-backed repository for
// types.ResourceAllocation. All operations come from the embedded generic
// repository.
type ResourceAllocationClover struct {
	repositories.GenericRepository[types.ResourceAllocation]
}

// NewResourceAllocation returns a ResourceAllocationClover whose embedded
// generic repository is bound to db.
func NewResourceAllocation(db *clover.DB) repositories.ResourceAllocation {
	generic := NewGenericRepository[types.ResourceAllocation](db)
	return &ResourceAllocationClover{GenericRepository: generic}
}
package clover
import (
clover "github.com/ostafen/clover/v2"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// StorageVolumeClover is the Clover-backed repository for
// types.StorageVolume. All operations come from the embedded generic
// repository.
type StorageVolumeClover struct {
	repositories.GenericRepository[types.StorageVolume]
}

// NewStorageVolume returns a StorageVolumeClover whose embedded generic
// repository is bound to db.
func NewStorageVolume(db *clover.DB) repositories.StorageVolume {
	generic := NewGenericRepository[types.StorageVolume](db)
	return &StorageVolumeClover{GenericRepository: generic}
}
package clover
import (
"encoding/json"
"errors"
"reflect"
"strings"
"github.com/ostafen/clover/v2"
clover_d "github.com/ostafen/clover/v2/document"
"gitlab.com/nunet/device-management-service/db/repositories"
)
func handleDBError(err error) error {
if err != nil {
switch err {
case clover.ErrDocumentNotExist:
return repositories.ErrNotFound
case clover.ErrDuplicateKey:
return repositories.ErrInvalidData
case repositories.ErrParsingModel:
return err
default:
return errors.Join(repositories.ErrDatabase, err)
}
}
return nil
}
// toCloverDoc converts data into a Clover document by round-tripping it
// through JSON into a map. If either the marshal or the unmarshal step fails,
// an empty document is returned instead.
func toCloverDoc[T repositories.ModelType](data T) *clover_d.Document {
	encoded, err := json.Marshal(data)
	if err != nil {
		return clover_d.NewDocument()
	}

	fields := map[string]interface{}{}
	if err := json.Unmarshal(encoded, &fields); err != nil {
		return clover_d.NewDocument()
	}

	return clover_d.NewDocumentOf(fields)
}
// toModel decodes doc into a value of type T.
// For non-entity repositories the model's "ID" field is additionally set to
// the document's object id; entity repositories skip that step because their
// models might not have an ID at all.
func toModel[T repositories.ModelType](doc *clover_d.Document, isEntityRepo bool) (T, error) {
	var decoded T
	if err := doc.Unmarshal(&decoded); err != nil {
		return decoded, err
	}

	if isEntityRepo {
		return decoded, nil
	}

	withID, err := repositories.UpdateField(decoded, "ID", doc.ObjectId())
	if err != nil {
		return withID, err
	}
	return withID, nil
}
// fieldJSONTag translates the Go struct field name field of T into the name
// given by that field's `json` tag.
//
// The input is returned unchanged when T is not a struct (or pointer to
// struct), when T has no such field, when the field has no json tag, or when
// the tag does not specify a name (for example `json:",omitempty"`), since in
// all of those cases there is no tag name to translate to.
func fieldJSONTag[T repositories.ModelType](field string) string {
	modelType := reflect.TypeOf(*new(T))
	if modelType != nil && modelType.Kind() == reflect.Ptr {
		modelType = modelType.Elem()
	}
	// FieldByName panics on non-struct types.
	if modelType == nil || modelType.Kind() != reflect.Struct {
		return field
	}

	structField, ok := modelType.FieldByName(field)
	if !ok {
		return field
	}
	tag, ok := structField.Tag.Lookup("json")
	if !ok {
		return field
	}

	// Only the part before the first comma is the name; the rest are options.
	name, _, _ := strings.Cut(tag, ",")
	if name == "" {
		return field
	}
	return name
}
package repositories
import (
"context"
)
// QueryCondition is a single filter of a Query: a field, a comparison
// operator and the value to compare the field against. The helper functions
// EQ, GT, GTE, LT, LTE, IN and LIKE build conditions with the matching
// operator string.
type QueryCondition struct {
	Field    string      // Field specifies the database or struct field to which the condition applies.
	Operator string      // Operator defines the comparison operator (e.g., "=", ">", "<").
	Value    interface{} // Value is the expected value for the given field.
}
// ModelType is the type constraint used by the generic repositories. It is
// the empty interface, so it places no restriction on the model type.
type ModelType interface{}
// Query is a struct that wraps both the instance of type T and additional query parameters.
// It is used to construct queries with conditions, sorting, limiting, and offsetting.
// The zero value is an empty query with no conditions.
type Query[T any] struct {
	Instance   T                // Instance is an optional object of type T used to build conditions from its fields.
	Conditions []QueryCondition // Conditions represent the conditions applied to the query.
	SortBy     string           // SortBy specifies the field by which the query results should be sorted.
	Limit      int              // Limit specifies the maximum number of results to return.
	Offset     int              // Offset specifies the number of results to skip before starting to return data.
}
// GenericRepository is an interface defining basic CRUD operations and standard querying methods.
// It is generic over the model type T; every method takes a context as its
// first argument except GetQuery.
type GenericRepository[T ModelType] interface {
	// Create adds a new record to the repository.
	// It returns the stored record.
	Create(ctx context.Context, data T) (T, error)
	// Get retrieves a record by its identifier.
	Get(ctx context.Context, id interface{}) (T, error)
	// Update modifies a record by its identifier.
	// It returns the record after the update.
	Update(ctx context.Context, id interface{}, data T) (T, error)
	// Delete removes a record by its identifier.
	Delete(ctx context.Context, id interface{}) error
	// Find retrieves a single record based on a query.
	Find(ctx context.Context, query Query[T]) (T, error)
	// FindAll retrieves multiple records based on a query.
	FindAll(ctx context.Context, query Query[T]) ([]T, error)
	// GetQuery returns an empty query instance for the repository's type.
	GetQuery() Query[T]
}
// EQ builds an equality condition: field "=" value.
func EQ(field string, value interface{}) QueryCondition {
	cond := QueryCondition{Operator: "="}
	cond.Field = field
	cond.Value = value
	return cond
}
// GT builds a greater-than condition: field ">" value.
func GT(field string, value interface{}) QueryCondition {
	cond := QueryCondition{Operator: ">"}
	cond.Field = field
	cond.Value = value
	return cond
}
// GTE builds a greater-than-or-equal condition: field ">=" value.
func GTE(field string, value interface{}) QueryCondition {
	cond := QueryCondition{Operator: ">="}
	cond.Field = field
	cond.Value = value
	return cond
}
// LT builds the condition "field < value".
func LT(field string, value interface{}) QueryCondition {
	cond := QueryCondition{Operator: "<"}
	cond.Field = field
	cond.Value = value
	return cond
}
// LTE builds the condition "field <= value".
func LTE(field string, value interface{}) QueryCondition {
	cond := QueryCondition{Operator: "<="}
	cond.Field = field
	cond.Value = value
	return cond
}
// IN builds the condition "field IN values"; the whole slice is stored as the
// condition value.
func IN(field string, values []interface{}) QueryCondition {
	cond := QueryCondition{Operator: "IN"}
	cond.Field = field
	cond.Value = values
	return cond
}
// LIKE builds the condition "field LIKE pattern".
func LIKE(field, pattern string) QueryCondition {
	cond := QueryCondition{Operator: "LIKE"}
	cond.Field = field
	cond.Value = pattern
	return cond
}
package gorm
import (
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// DeploymentRequestFlatGORM is a GORM implementation of the DeploymentRequestFlat interface.
// It declares no methods of its own here; its behavior comes from the embedded repository.
type DeploymentRequestFlatGORM struct {
// Embedded generic repository for types.DeploymentRequestFlat; its methods are promoted.
repositories.GenericRepository[types.DeploymentRequestFlat]
}
// NewDeploymentRequestFlat returns a DeploymentRequestFlatGORM backed by a
// generic GORM repository bound to db.
func NewDeploymentRequestFlat(db *gorm.DB) repositories.DeploymentRequestFlat {
	base := NewGenericRepository[types.DeploymentRequestFlat](db)
	return &DeploymentRequestFlatGORM{GenericRepository: base}
}
package gorm
import (
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// RequestTrackerGORM is a GORM implementation of the RequestTracker interface.
// It declares no methods of its own here; its behavior comes from the embedded repository.
type RequestTrackerGORM struct {
// Embedded generic repository for types.RequestTracker; its methods are promoted.
repositories.GenericRepository[types.RequestTracker]
}
// NewRequestTracker returns a RequestTrackerGORM backed by a generic GORM
// repository bound to db.
func NewRequestTracker(db *gorm.DB) repositories.RequestTracker {
	base := NewGenericRepository[types.RequestTracker](db)
	return &RequestTrackerGORM{GenericRepository: base}
}
package gorm
import (
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// VirtualMachineGORM is a GORM implementation of the VirtualMachine interface.
// It declares no methods of its own here; its behavior comes from the embedded repository.
type VirtualMachineGORM struct {
// Embedded generic repository for types.VirtualMachine; its methods are promoted.
repositories.GenericRepository[types.VirtualMachine]
}
// NewVirtualMachine returns a VirtualMachineGORM backed by a generic GORM
// repository bound to db.
func NewVirtualMachine(db *gorm.DB) repositories.VirtualMachine {
	base := NewGenericRepository[types.VirtualMachine](db)
	return &VirtualMachineGORM{GenericRepository: base}
}
package gorm
import (
"context"
"fmt"
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
)
const (
// createdAtField is the model field name used by GenericEntityRepositoryGORM.Get
// to sort records (descending) before taking the first one.
createdAtField = "CreatedAt"
)
// GenericEntityRepositoryGORM is a generic single entity repository implementation using GORM as an ORM.
// It is intended to be embedded in single entity model repositories to provide basic database operations.
// Its methods are Save, Get, Clear, History and GetQuery; none of them takes a record identifier.
type GenericEntityRepositoryGORM[T repositories.ModelType] struct {
db *gorm.DB // db is the GORM database instance.
}
// NewGenericEntityRepository returns a GenericEntityRepositoryGORM for model
// type T that runs all of its operations against db.
func NewGenericEntityRepository[T repositories.ModelType](
	db *gorm.DB,
) repositories.GenericEntityRepository[T] {
	repo := new(GenericEntityRepositoryGORM[T])
	repo.db = db
	return repo
}
// GetQuery returns the zero-value Query for T, ready to be filled in by the caller.
func (repo *GenericEntityRepositoryGORM[T]) GetQuery() repositories.Query[T] {
	var empty repositories.Query[T]
	return empty
}
// Save inserts data through GORM's Create and returns it together with the
// translated database error, if any.
func (repo *GenericEntityRepositoryGORM[T]) Save(ctx context.Context, data T) (T, error) {
	tx := repo.db.WithContext(ctx).Create(&data)
	return data, handleDBError(tx.Error)
}
// Get returns the first record when sorting by the CreatedAt field in
// descending order (presumably the most recently saved one).
func (repo *GenericEntityRepositoryGORM[T]) Get(ctx context.Context) (T, error) {
	newestFirst := repo.GetQuery()
	newestFirst.SortBy = fmt.Sprintf("-%s", createdAtField)

	var latest T
	tx := applyConditions(repo.db.WithContext(ctx), newestFirst).First(&latest)
	return latest, handleDBError(tx.Error)
}
// Clear deletes every row of T's table whose id is not NULL and returns the
// raw GORM error, if any.
func (repo *GenericEntityRepositoryGORM[T]) Clear(ctx context.Context) error {
	var model T
	tx := repo.db.WithContext(ctx).Delete(&model, "id IS NOT NULL")
	return tx.Error
}
// History returns all stored records of T that satisfy query.
func (repo *GenericEntityRepositoryGORM[T]) History(
	ctx context.Context,
	query repositories.Query[T],
) ([]T, error) {
	scoped := applyConditions(repo.db.WithContext(ctx).Model(new(T)), query)

	var records []T
	tx := scoped.Find(&records)
	return records, handleDBError(tx.Error)
}
package gorm
import (
"context"
"fmt"
"reflect"
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
)
// GenericRepositoryGORM is a generic repository implementation using GORM as an ORM.
// It is intended to be embedded in model repositories to provide basic database operations.
// Records are addressed through an "id" column in Get, Update and Delete.
type GenericRepositoryGORM[T repositories.ModelType] struct {
// db is the GORM database handle every operation runs against.
db *gorm.DB
}
// NewGenericRepository returns a GenericRepositoryGORM for model type T that
// runs all of its operations against db.
func NewGenericRepository[T repositories.ModelType](db *gorm.DB) repositories.GenericRepository[T] {
	repo := new(GenericRepositoryGORM[T])
	repo.db = db
	return repo
}
// GetQuery returns the zero-value Query for T, ready to be filled in by the caller.
func (repo *GenericRepositoryGORM[T]) GetQuery() repositories.Query[T] {
	var empty repositories.Query[T]
	return empty
}
// Create inserts data through GORM's Create and returns it together with the
// translated database error, if any.
func (repo *GenericRepositoryGORM[T]) Create(ctx context.Context, data T) (T, error) {
	tx := repo.db.WithContext(ctx).Create(&data)
	return data, handleDBError(tx.Error)
}
// Get loads the first record whose id column equals id. The error, if any, is
// translated by handleDBError.
func (repo *GenericRepositoryGORM[T]) Get(ctx context.Context, id interface{}) (T, error) {
	var record T
	tx := repo.db.WithContext(ctx).First(&record, "id = ?", id)
	return record, handleDBError(tx.Error)
}
// Update applies data to the row whose id column equals id via GORM's Updates
// and returns the data that was passed in.
func (repo *GenericRepositoryGORM[T]) Update(ctx context.Context, id interface{}, data T) (T, error) {
	target := repo.db.WithContext(ctx).Model(new(T)).Where("id = ?", id)
	tx := target.Updates(data)
	return data, handleDBError(tx.Error)
}
// Delete removes the row whose id column equals id and returns the raw GORM
// error, if any.
func (repo *GenericRepositoryGORM[T]) Delete(ctx context.Context, id interface{}) error {
	var model T
	tx := repo.db.WithContext(ctx).Delete(&model, "id = ?", id)
	return tx.Error
}
// Find returns the first record of T that satisfies query.
func (repo *GenericRepositoryGORM[T]) Find(
	ctx context.Context,
	query repositories.Query[T],
) (T, error) {
	scoped := applyConditions(repo.db.WithContext(ctx).Model(new(T)), query)

	var record T
	tx := scoped.First(&record)
	return record, handleDBError(tx.Error)
}
// FindAll returns every record of T that satisfies query.
func (repo *GenericRepositoryGORM[T]) FindAll(
	ctx context.Context,
	query repositories.Query[T],
) ([]T, error) {
	scoped := applyConditions(repo.db.WithContext(ctx).Model(new(T)), query)

	var records []T
	tx := scoped.Find(&records)
	return records, handleDBError(tx.Error)
}
// applyConditions translates a repositories.Query into GORM clauses on db and
// returns the resulting *gorm.DB.
//
// The following parts of the query are applied, in this order:
//   - query.Conditions: one "column operator ?" WHERE clause per condition;
//   - query.Instance: for every exported field holding a non-empty value, a
//     "column = ?" WHERE clause (query by example);
//   - query.SortBy: ORDER BY on the named field, descending when the name is
//     prefixed with '-';
//   - query.Limit and query.Offset: applied only when greater than zero.
//
// Column and table names are derived from Go names with db.NamingStrategy.
// Condition operators are interpolated into the SQL text as-is, so they must
// come from trusted code (for example the EQ/GT/... helpers).
func applyConditions[T any](db *gorm.DB, query repositories.Query[T]) *gorm.DB {
	// Retrieve the table name using the GORM naming strategy.
	tableName := db.NamingStrategy.TableName(reflect.TypeOf(*new(T)).Name())

	// Apply conditions specified in the query.
	for _, condition := range query.Conditions {
		columnName := db.NamingStrategy.ColumnName(tableName, condition.Field)
		db = db.Where(
			fmt.Sprintf("%s %s ?", columnName, condition.Operator),
			condition.Value,
		)
	}

	// Apply conditions based on non-zero values in the query instance.
	if !repositories.IsEmptyValue(query.Instance) {
		exampleType := reflect.TypeOf(query.Instance)
		exampleValue := reflect.ValueOf(query.Instance)
		for i := 0; i < exampleType.NumField(); i++ {
			field := exampleType.Field(i)
			// Interface() panics on unexported fields, and they cannot map
			// to a queryable column anyway.
			if !field.IsExported() {
				continue
			}
			fieldValue := exampleValue.Field(i).Interface()
			if repositories.IsEmptyValue(fieldValue) {
				continue
			}
			columnName := db.NamingStrategy.ColumnName(tableName, field.Name)
			db = db.Where(fmt.Sprintf("%s = ?", columnName), fieldValue)
		}
	}

	// Apply sorting if specified in the query.
	if query.SortBy != "" {
		dir := "ASC"
		if query.SortBy[0] == '-' {
			query.SortBy = query.SortBy[1:]
			dir = "DESC"
		}
		columnName := db.NamingStrategy.ColumnName(tableName, query.SortBy)
		db = db.Order(fmt.Sprintf("%s.%s %s", tableName, columnName, dir))
	}

	// Apply limit if specified in the query.
	if query.Limit > 0 {
		db = db.Limit(query.Limit)
	}

	// Apply offset if specified in the query. This must be Offset: calling
	// Limit here would overwrite the limit set above and skip no rows.
	if query.Offset > 0 {
		db = db.Offset(query.Offset)
	}

	return db
}
package gorm
import (
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// PeerInfoGORM is a GORM implementation of the PeerInfo interface.
// It declares no methods of its own here; its behavior comes from the embedded repository.
type PeerInfoGORM struct {
// Embedded generic repository for types.PeerInfo; its methods are promoted.
repositories.GenericRepository[types.PeerInfo]
}
// NewPeerInfo returns a PeerInfoGORM backed by a generic GORM repository
// bound to db.
func NewPeerInfo(db *gorm.DB) repositories.PeerInfo {
	base := NewGenericRepository[types.PeerInfo](db)
	return &PeerInfoGORM{GenericRepository: base}
}
// MachineGORM is a GORM implementation of the Machine interface.
// It declares no methods of its own here; its behavior comes from the embedded repository.
type MachineGORM struct {
// Embedded generic repository for types.Machine; its methods are promoted.
repositories.GenericRepository[types.Machine]
}
// NewMachine returns a MachineGORM backed by a generic GORM repository bound
// to db.
func NewMachine(db *gorm.DB) repositories.Machine {
	base := NewGenericRepository[types.Machine](db)
	return &MachineGORM{GenericRepository: base}
}
// AvailableResourcesGORM is a GORM implementation of the AvailableResources interface.
// It declares no methods of its own here; its behavior comes from the embedded repository.
type AvailableResourcesGORM struct {
// Embedded single-entity repository for types.AvailableResources; its methods are promoted.
repositories.GenericEntityRepository[types.AvailableResources]
}
// NewAvailableResources returns an AvailableResourcesGORM backed by a generic
// single-entity GORM repository bound to db.
func NewAvailableResources(db *gorm.DB) repositories.AvailableResources {
	base := NewGenericEntityRepository[types.AvailableResources](db)
	return &AvailableResourcesGORM{GenericEntityRepository: base}
}
// ServicesGORM is a GORM implementation of the Services interface.
// It declares no methods of its own here; its behavior comes from the embedded repository.
type ServicesGORM struct {
// Embedded generic repository for types.Services; its methods are promoted.
repositories.GenericRepository[types.Services]
}
// NewServices returns a ServicesGORM backed by a generic GORM repository
// bound to db.
func NewServices(db *gorm.DB) repositories.Services {
	base := NewGenericRepository[types.Services](db)
	return &ServicesGORM{GenericRepository: base}
}
// ServiceResourceRequirementsGORM is a GORM implementation of the ServiceResourceRequirements interface.
// It declares no methods of its own here; its behavior comes from the embedded repository.
type ServiceResourceRequirementsGORM struct {
// Embedded generic repository for types.ServiceResourceRequirements; its methods are promoted.
repositories.GenericRepository[types.ServiceResourceRequirements]
}
// NewServiceResourceRequirements returns a ServiceResourceRequirementsGORM
// backed by a generic GORM repository bound to db.
func NewServiceResourceRequirements(
	db *gorm.DB,
) repositories.ServiceResourceRequirements {
	base := NewGenericRepository[types.ServiceResourceRequirements](db)
	return &ServiceResourceRequirementsGORM{GenericRepository: base}
}
// Libp2pInfoGORM is a GORM implementation of the Libp2pInfo interface.
// It declares no methods of its own here; its behavior comes from the embedded repository.
type Libp2pInfoGORM struct {
// Embedded single-entity repository for types.Libp2pInfo; its methods are promoted.
repositories.GenericEntityRepository[types.Libp2pInfo]
}
// NewLibp2pInfo returns a Libp2pInfoGORM backed by a generic single-entity
// GORM repository bound to db.
func NewLibp2pInfo(db *gorm.DB) repositories.Libp2pInfo {
	base := NewGenericEntityRepository[types.Libp2pInfo](db)
	return &Libp2pInfoGORM{GenericEntityRepository: base}
}
// MachineUUIDGORM is a GORM implementation of the MachineUUID interface.
// It declares no methods of its own here; its behavior comes from the embedded repository.
type MachineUUIDGORM struct {
// Embedded single-entity repository for types.MachineUUID; its methods are promoted.
repositories.GenericEntityRepository[types.MachineUUID]
}
// NewMachineUUID returns a MachineUUIDGORM backed by a generic single-entity
// GORM repository bound to db.
func NewMachineUUID(db *gorm.DB) repositories.MachineUUID {
	base := NewGenericEntityRepository[types.MachineUUID](db)
	return &MachineUUIDGORM{GenericEntityRepository: base}
}
// ConnectionGORM is a GORM implementation of the Connection interface.
// It declares no methods of its own here; its behavior comes from the embedded repository.
type ConnectionGORM struct {
// Embedded generic repository for types.Connection; its methods are promoted.
repositories.GenericRepository[types.Connection]
}
// NewConnection returns a ConnectionGORM backed by a generic GORM repository
// bound to db.
func NewConnection(db *gorm.DB) repositories.Connection {
	base := NewGenericRepository[types.Connection](db)
	return &ConnectionGORM{GenericRepository: base}
}
// ElasticTokenGORM is a GORM implementation of the ElasticToken interface.
// It declares no methods of its own here; its behavior comes from the embedded repository.
type ElasticTokenGORM struct {
// Embedded generic repository for types.ElasticToken; its methods are promoted.
repositories.GenericRepository[types.ElasticToken]
}
// NewElasticToken returns an ElasticTokenGORM backed by a generic GORM
// repository bound to db.
func NewElasticToken(db *gorm.DB) repositories.ElasticToken {
	base := NewGenericRepository[types.ElasticToken](db)
	return &ElasticTokenGORM{GenericRepository: base}
}
package gorm
import (
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// OnboardingParamsGORM is the GORM-backed repository returned by NewOnboardingParams.
// It stores types.OnboardingConfig values through the embedded single-entity repository.
type OnboardingParamsGORM struct {
// Embedded single-entity repository for types.OnboardingConfig; its methods are promoted.
repositories.GenericEntityRepository[types.OnboardingConfig]
}
// NewOnboardingParams returns an OnboardingParamsGORM backed by a generic
// single-entity GORM repository for types.OnboardingConfig bound to db.
func NewOnboardingParams(db *gorm.DB) repositories.OnboardingParams {
	base := NewGenericEntityRepository[types.OnboardingConfig](db)
	return &OnboardingParamsGORM{GenericEntityRepository: base}
}
package gorm
import (
"gitlab.com/nunet/device-management-service/types"
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
)
// MachineResourcesGORM is a GORM implementation of the MachineResources interface.
// It declares no methods of its own here; its behavior comes from the embedded repository.
type MachineResourcesGORM struct {
// Embedded single-entity repository for types.MachineResources; its methods are promoted.
repositories.GenericEntityRepository[types.MachineResources]
}
// NewMachineResources returns a MachineResourcesGORM backed by a generic
// single-entity GORM repository bound to db.
func NewMachineResources(db *gorm.DB) repositories.MachineResources {
	base := NewGenericEntityRepository[types.MachineResources](db)
	return &MachineResourcesGORM{GenericEntityRepository: base}
}
// FreeResourcesGORM is a GORM implementation of the FreeResources interface.
// It declares no methods of its own here; its behavior comes from the embedded repository.
type FreeResourcesGORM struct {
// Embedded single-entity repository for types.FreeResources; its methods are promoted.
repositories.GenericEntityRepository[types.FreeResources]
}
// NewFreeResources returns a FreeResourcesGORM backed by a generic
// single-entity GORM repository bound to db.
func NewFreeResources(db *gorm.DB) repositories.FreeResources {
	base := NewGenericEntityRepository[types.FreeResources](db)
	return &FreeResourcesGORM{GenericEntityRepository: base}
}
// OnboardedResourcesGORM is a GORM implementation of the OnboardedResources interface.
// It declares no methods of its own here; its behavior comes from the embedded repository.
type OnboardedResourcesGORM struct {
// Embedded single-entity repository for types.OnboardedResources; its methods are promoted.
repositories.GenericEntityRepository[types.OnboardedResources]
}
// NewOnboardedResources returns an OnboardedResourcesGORM backed by a generic
// single-entity GORM repository bound to db.
func NewOnboardedResources(db *gorm.DB) repositories.OnboardedResources {
	base := NewGenericEntityRepository[types.OnboardedResources](db)
	return &OnboardedResourcesGORM{GenericEntityRepository: base}
}
// ResourceAllocationGORM is a GORM implementation of the ResourceAllocation interface.
// It declares no methods of its own here; its behavior comes from the embedded repository.
type ResourceAllocationGORM struct {
// Embedded generic repository for types.ResourceAllocation; its methods are promoted.
repositories.GenericRepository[types.ResourceAllocation]
}
// NewResourceAllocation returns a ResourceAllocationGORM backed by a generic
// GORM repository bound to db.
func NewResourceAllocation(db *gorm.DB) repositories.ResourceAllocation {
	base := NewGenericRepository[types.ResourceAllocation](db)
	return &ResourceAllocationGORM{GenericRepository: base}
}
package gorm
import (
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// StorageVolumeGORM is a GORM implementation of the StorageVolume interface.
// It declares no methods of its own here; its behavior comes from the embedded repository.
type StorageVolumeGORM struct {
// Embedded generic repository for types.StorageVolume; its methods are promoted.
repositories.GenericRepository[types.StorageVolume]
}
// NewStorageVolume returns a StorageVolumeGORM backed by a generic GORM
// repository bound to db.
func NewStorageVolume(db *gorm.DB) repositories.StorageVolume {
	base := NewGenericRepository[types.StorageVolume](db)
	return &StorageVolumeGORM{GenericRepository: base}
}
package gorm
import (
"errors"
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
)
// handleDBError translates an error returned by GORM into one of the
// repositories package errors.
//
//   - nil stays nil;
//   - gorm.ErrRecordNotFound becomes repositories.ErrNotFound;
//   - gorm.ErrInvalidData, gorm.ErrInvalidField and gorm.ErrInvalidValue
//     become repositories.ErrInvalidData;
//   - repositories.ErrParsingModel is returned unchanged;
//   - anything else is joined with repositories.ErrDatabase, so callers can
//     match either the generic database error or the underlying cause.
//
// Matching uses errors.Is rather than ==, so errors that wrap one of the
// sentinels above are recognised as well.
func handleDBError(err error) error {
	switch {
	case err == nil:
		return nil
	case errors.Is(err, gorm.ErrRecordNotFound):
		return repositories.ErrNotFound
	case errors.Is(err, gorm.ErrInvalidData),
		errors.Is(err, gorm.ErrInvalidField),
		errors.Is(err, gorm.ErrInvalidValue):
		return repositories.ErrInvalidData
	case errors.Is(err, repositories.ErrParsingModel):
		return err
	default:
		return errors.Join(repositories.ErrDatabase, err)
	}
}
package repositories
import (
"fmt"
"reflect"
)
// UpdateField sets the field named fieldName of input to newValue using
// reflection and returns input.
//
// input may be a struct or a pointer to a struct. With a pointer, the
// pointed-to struct is modified in place; with a struct value, only the
// returned copy carries the change. newValue is converted to the field's type
// following reflect's conversion rules.
//
// An error is returned, and nothing is modified, when input is not a struct
// (or a pointer to one), when the field does not exist or is not settable,
// when newValue is nil, or when newValue's type cannot be converted to the
// field's type.
func UpdateField[T interface{}](input T, fieldName string, newValue interface{}) (T, error) {
	// Use reflection to get the struct's field
	val := reflect.ValueOf(input)
	if val.Kind() == reflect.Ptr {
		// If input is a pointer, get the underlying element
		val = val.Elem()
	} else {
		// If input is not a pointer, ensure it's addressable
		val = reflect.ValueOf(&input).Elem()
	}

	// Check if the input is a struct
	if val.Kind() != reflect.Struct {
		return input, fmt.Errorf("not a struct: %T", input)
	}

	// Get the field by name
	field := val.FieldByName(fieldName)
	if !field.IsValid() {
		return input, fmt.Errorf("field not found: %v", fieldName)
	}

	// Check if the field is settable
	if !field.CanSet() {
		return input, fmt.Errorf("field not settable: %v", fieldName)
	}

	// reflect.TypeOf(nil) is a nil Type; calling a method on it would panic.
	newType := reflect.TypeOf(newValue)
	if newType == nil {
		return input, fmt.Errorf("nil value for field: %v", fieldName)
	}

	// Check if types are compatible
	if !newType.ConvertibleTo(field.Type()) {
		return input, fmt.Errorf(
			"incompatible conversion: %v -> %v; value: %v",
			newType, field.Type(), newValue,
		)
	}

	// Convert the new value to the field type and store it
	field.Set(reflect.ValueOf(newValue).Convert(field.Type()))

	return input, nil
}
// IsEmptyValue reports whether value is empty.
//
// A value is empty when it is nil, when it is a nil pointer, or when it (or
// the value a pointer points to) equals the zero value of its type as
// reported by reflect's IsZero. Only one level of pointer is dereferenced.
func IsEmptyValue(value interface{}) bool {
	// An untyped nil interface carries nothing to inspect.
	if value == nil {
		return true
	}

	val := reflect.ValueOf(value)

	// Look through one pointer level. A nil pointer has no element: Elem
	// would return an invalid Value on which IsZero panics.
	if val.Kind() == reflect.Ptr {
		if val.IsNil() {
			return true
		}
		val = val.Elem()
	}

	return val.IsZero()
}
package jobs
import (
"context"
"errors"
"fmt"
"sync"
"github.com/google/uuid"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/executor"
"gitlab.com/nunet/device-management-service/types"
)
const (
// pending is the status a new Allocation starts with (set by NewAllocation).
pending AllocationStatus = "pending"
// running is set by Allocation.Run once the executor has started successfully.
running AllocationStatus = "running"
// stopped is set by Allocation.Stop after the execution has been cancelled.
stopped AllocationStatus = "stopped"
)
// AllocationStatus is a representation of the execution status
// of an allocation; the defined values are pending, running and stopped.
type AllocationStatus string
// Status holds the status of an allocation.
// It is the value returned by Allocation.Status.
type Status struct {
// JobResources is the resources of the allocation's job (Job.Resources).
JobResources types.Resources
// Status is the allocation's current execution status.
Status AllocationStatus
}
// AllocationDetails encapsulates the dependencies to the constructor.
// NewAllocation copies each field into the Allocation it creates.
type AllocationDetails struct {
// Job is the job the allocation runs.
Job Job
// NodeID is stored as the allocation's nodeID.
NodeID string
// SourceID is stored as the allocation's sourceID.
SourceID string
}
// Allocation represents an allocation
// of a Job: it owns the job's executor and its actor and tracks their state.
type Allocation struct {
// ID is a UUID string generated by NewAllocation.
ID string
// mx is held by Run, Stop and Start while they read or change status,
// executor and actorRunning.
mx sync.Mutex
// status is the current execution status (pending, running or stopped).
status AllocationStatus
// nodeID and sourceID are copied from AllocationDetails.
nodeID string
sourceID string
// executionID is a UUID string generated by NewAllocation; it is passed to the
// executor when starting and cancelling the execution.
executionID string
// Actor is the actor started by Start and stopped by Stop.
Actor *actor.BasicActor
// executor is created lazily by Run when it is still nil.
executor executor.Executor
// resourceManager allocates and deallocates the job's resources.
resourceManager types.ResourceManager
// actorRunning records whether Actor has been started and not yet stopped.
actorRunning bool
// Job is the job this allocation runs.
Job Job
}
// NewAllocation builds a pending Allocation for details.Job that uses actor
// and resourceManager.
//
// Fresh UUIDs are generated for the allocation ID and for the execution ID.
// An error is returned when resourceManager is nil or when a UUID cannot be
// generated.
func NewAllocation(actor *actor.BasicActor, details AllocationDetails, resourceManager types.ResourceManager) (*Allocation, error) {
	if resourceManager == nil {
		return nil, errors.New("resource manager is nil")
	}

	allocationID, err := uuid.NewUUID()
	if err != nil {
		return nil, fmt.Errorf("failed to generate uuid for allocation: %w", err)
	}

	executionID, err := uuid.NewUUID()
	if err != nil {
		return nil, fmt.Errorf("failed to create executor id: %w", err)
	}

	alloc := &Allocation{
		status:          pending,
		Actor:           actor,
		resourceManager: resourceManager,
	}
	alloc.ID = allocationID.String()
	alloc.executionID = executionID.String()
	alloc.Job = details.Job
	alloc.nodeID = details.NodeID
	alloc.sourceID = details.SourceID
	return alloc, nil
}
// Run allocates the job's resources and starts its execution.
//
// It is a no-op when the allocation is already running. Otherwise it asks the
// resource manager to allocate the job's resources, creates the executor for
// the job's execution engine if none exists yet, and starts the execution.
// On success the status becomes running.
//
// If any step after the resource allocation fails, the resources are
// deallocated again; a deallocation failure is joined into the returned
// error instead of being dropped.
func (a *Allocation) Run(ctx context.Context) (err error) {
	a.mx.Lock()
	defer a.mx.Unlock()

	if a.status == running {
		return nil
	}

	resourceAllocation := types.ResourceAllocation{JobID: a.Job.ID, Resources: a.Job.Resources}
	if err = a.resourceManager.AllocateResources(ctx, resourceAllocation); err != nil {
		return fmt.Errorf("failed to allocate resources: %w", err)
	}

	// Roll the allocation back on every path that does not end up running.
	// err is the named result, so the rollback error reaches the caller.
	defer func() {
		if a.status == running {
			return
		}
		if derr := a.resourceManager.DeallocateResources(ctx, a.Job.ID); derr != nil {
			err = errors.Join(err, fmt.Errorf("failed to deallocate resources: %w", derr))
		}
	}()

	// if executor is nil create it
	if a.executor == nil {
		if err = a.createExecutor(ctx, a.Job.Execution); err != nil {
			return fmt.Errorf("failed to create executor: %w", err)
		}
		// Guard against createExecutor succeeding without setting an
		// executor; calling Start on a nil executor would panic.
		if a.executor == nil {
			return fmt.Errorf("failed to create executor: no executor for execution type %v", a.Job.Execution.Type)
		}
	}

	err = a.executor.Start(ctx, &types.ExecutionRequest{
		JobID:       a.Job.ID,
		ExecutionID: a.executionID,
		EngineSpec:  &a.Job.Execution,
		Resources:   &a.Job.Resources,

		// TODO add the following
		Inputs:     []*types.StorageVolumeExecutor{},
		Outputs:    []*types.StorageVolumeExecutor{},
		ResultsDir: "",
	})
	if err != nil {
		return fmt.Errorf("failed to start executor: %w", err)
	}

	a.status = running
	return nil
}
// Stop cancels the running execution and releases the job's resources.
//
// When the allocation is not running, only the actor is stopped. The actor is
// stopped on every return path if it was started; a failure to stop it is
// logged rather than returned.
func (a *Allocation) Stop(ctx context.Context) error {
	a.mx.Lock()
	defer a.mx.Unlock()

	defer func() {
		if !a.actorRunning {
			return
		}
		if stopErr := a.Actor.Stop(); stopErr != nil {
			log.Warnf("error stopping allocation actor: %s", stopErr)
		}
		a.actorRunning = false
	}()

	if a.status != running {
		return nil
	}

	cancelErr := a.executor.Cancel(ctx, a.executionID)
	if cancelErr != nil {
		return fmt.Errorf("failed to stop execution: %w", cancelErr)
	}
	a.status = stopped

	deallocErr := a.resourceManager.DeallocateResources(ctx, a.Job.ID)
	if deallocErr != nil {
		return fmt.Errorf("failed to deallocate resources: %w", deallocErr)
	}
	return nil
}
// Status returns the job's resources together with the allocation's current
// execution status.
//
// The allocation mutex is held while reading, because Run and Stop change the
// status under that same mutex and may run concurrently.
func (a *Allocation) Status(_ context.Context) Status {
	a.mx.Lock()
	defer a.mx.Unlock()

	return Status{
		JobResources: a.Job.Resources,
		Status:       a.status,
	}
}
// Start starts the allocation's actor unless it is already running.
// actorRunning is only set once the actor has started successfully.
func (a *Allocation) Start() error {
	a.mx.Lock()
	defer a.mx.Unlock()

	if !a.actorRunning {
		if startErr := a.Actor.Start(); startErr != nil {
			return fmt.Errorf("failed to start allocation actor: %w", startErr)
		}
		a.actorRunning = true
	}
	return nil
}
//go:build linux
// +build linux
package jobs
import (
"context"
"fmt"
"gitlab.com/nunet/device-management-service/executor/docker"
"gitlab.com/nunet/device-management-service/executor/firecracker"
"gitlab.com/nunet/device-management-service/types"
)
// createExecutor creates the executor matching conf.Type and stores it in
// a.executor.
//
// The Firecracker and Docker executor types are supported. Any other type
// yields an error; previously it returned nil and left a.executor unset,
// which made the caller panic when starting the execution.
func (a *Allocation) createExecutor(ctx context.Context, conf types.SpecConfig) error {
	switch conf.Type {
	case string(types.ExecutorTypeFirecracker):
		exec, err := firecracker.NewExecutor(ctx, a.executionID)
		if err != nil {
			return fmt.Errorf("firecracker executor: %w", err)
		}
		a.executor = exec
	case string(types.ExecutorTypeDocker):
		exec, err := docker.NewExecutor(ctx, a.executionID)
		if err != nil {
			return fmt.Errorf("docker executor: %w", err)
		}
		a.executor = exec
	default:
		return fmt.Errorf("unsupported executor type: %q", conf.Type)
	}
	return nil
}
package parser
import (
"gitlab.com/nunet/device-management-service/dms/jobs"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/nunet"
)
// registry is the package-level parser registry for jobs.JobSpec.
// It is populated in init and consulted by Parse.
var registry Registry[jobs.JobSpec]
func init() {
registry = &RegistryImpl[jobs.JobSpec]{
parsers: make(map[SpecType]Parser[jobs.JobSpec]),
}
// Register Nunet parser.
nunetParser := NewParser[jobs.JobSpec](
nunet.NewNuNetTransformer(),
nunet.NewNuNetValidator(),
)
registry.RegisterParser(specTypeNuNet, nunetParser)
}
package nunet
import (
"fmt"
"strconv"
"strings"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/transform"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/tree"
)
// NewNuNetTransformer creates the transformer for the NuNet configuration.
//
// It hands transform.NewTransformer three groups of path → function mappings:
// first the map-shaped collections (jobs, children, volumes, networks), then
// the individual volume/network/library entries, and last the execution and
// remote-volume settings. NOTE(review): the groups are presumably applied in
// this order — confirm against transform.NewTransformer.
func NewNuNetTransformer() transform.Transformer {
	collections := map[tree.Path]transform.TransformerFunc{
		"jobs":             TransformJobs,
		"jobs.**.children": TransformJobs,
		"jobs.**.volumes":  TransformVolumes,
		"jobs.**.networks": TransformNetworks,
	}
	entries := map[tree.Path]transform.TransformerFunc{
		"jobs.**.volumes.[]":   TransformVolume,
		"jobs.**.networks.[]":  TransformNetwork,
		"jobs.**.libraries.[]": TransformLibrary,
	}
	settings := map[tree.Path]transform.TransformerFunc{
		"jobs.**.execution":         TransformExecution,
		"jobs.**.volumes.[].remote": TransformVolumeRemote,
	}

	stages := []map[tree.Path]transform.TransformerFunc{collections, entries, settings}
	return transform.NewTransformer(stages)
}
// TransformJobs converts the jobs map into a slice via transform.MapToSlice.
// nil data is passed through as nil; anything that is not a map is rejected.
func TransformJobs(_ *map[string]interface{}, data any, _ tree.Path) (any, error) {
	switch jobs := data.(type) {
	case nil:
		return nil, nil
	case map[string]any:
		return transform.MapToSlice(jobs)
	default:
		return nil, fmt.Errorf("invalid jobs configuration: %v", data)
	}
}
// TransformVolumes converts the volumes map into a slice via
// transform.MapToSlice. nil data is passed through as nil; anything that is
// not a map is rejected.
func TransformVolumes(_ *map[string]interface{}, data any, _ tree.Path) (any, error) {
	switch volumes := data.(type) {
	case nil:
		return nil, nil
	case map[string]any:
		return transform.MapToSlice(volumes)
	default:
		return nil, fmt.Errorf("invalid volumes configuration: %v", data)
	}
}
// TransformNetworks converts the networks map into a slice via
// transform.MapToSlice. nil data is passed through as nil; anything that is
// not a map is rejected.
func TransformNetworks(_ *map[string]interface{}, data any, _ tree.Path) (any, error) {
	switch networks := data.(type) {
	case nil:
		return nil, nil
	case map[string]any:
		return transform.MapToSlice(networks)
	default:
		return nil, fmt.Errorf("invalid networks configuration: %v", data)
	}
}
// TransformExecution reshapes a flat engine map into the form
// {"type": <type>, "params": {<every other key>}}.
// nil data is passed through as nil; anything that is not a map is rejected.
func TransformExecution(_ *map[string]interface{}, data any, _ tree.Path) (any, error) {
	if data == nil {
		return nil, nil
	}

	engine, isMap := data.(map[string]any)
	if !isMap {
		return nil, fmt.Errorf("invalid engine configuration: %v", data)
	}

	// Everything except the engine type is treated as an engine parameter.
	params := make(map[string]any)
	for key, value := range engine {
		if key == "type" {
			continue
		}
		params[key] = value
	}

	return map[string]any{
		"type":   engine["type"],
		"params": params,
	}, nil
}
// TransformVolume transforms the volume configuration and handles inheritance.
//
// The volume configuration can be a string in the format "name:mountpoint" or
// a map; any other input, or a string that does not split into exactly two
// parts, is rejected.
//
// For every "children" segment in path, the volumes of the job preceding that
// segment are looked up in root. When one of them has the same name as this
// volume, its keys are copied into this volume's configuration, overwriting
// keys that are already present. Ancestors are visited from the outermost to
// the innermost. A map input is modified in place and returned.
func TransformVolume(root *map[string]interface{}, data any, path tree.Path) (any, error) {
	var config map[string]any

	// If the data is a string, split it into name and mountpoint.
	switch v := data.(type) {
	case string:
		mapping := strings.Split(v, ":")
		if len(mapping) != 2 {
			return nil, fmt.Errorf("invalid volume configuration: %v", data)
		}
		config = map[string]any{
			"name":       mapping[0],
			"mountpoint": mapping[1],
		}
	case map[string]any:
		config = v
	default:
		return nil, fmt.Errorf("invalid volume configuration: %v", data)
	}

	// Without a root document there is nothing to inherit from.
	if root == nil {
		return config, nil
	}

	// Every "children" segment marks an ancestor job whose volumes may define
	// this volume; merge from each of them in path order.
	pathParts := path.Parts()
	for i, part := range pathParts {
		if part != "children" {
			continue
		}
		parent := tree.NewPath(pathParts[:i]...)

		c, err := transform.GetConfigAtPath(*root, parent.Next("volumes"))
		if err != nil {
			// The lookup failed for this ancestor (presumably because it
			// declares no volumes), so it contributes nothing to the merge.
			continue
		}

		// A value that is not a slice yields no volumes to merge.
		volumes, _ := transform.ToAnySlice(c)
		for _, v := range volumes {
			volume, ok := v.(map[string]any)
			if !ok || volume["name"] != config["name"] {
				continue
			}
			// Merge the configurations
			for k, v := range volume {
				config[k] = v
			}
		}
	}

	return config, nil
}
// TransformVolumeRemote normalises a remote volume configuration to the form
// {"type": <type>, "params": {...}}.
//
// When the input already has a "params" key, its value must be a map and is
// used as-is; otherwise every key except "type" is collected into params.
// nil data is passed through as nil. An error is returned when data is not a
// map or when "params" is present but is not a map.
func TransformVolumeRemote(_ *map[string]interface{}, data any, _ tree.Path) (any, error) {
	if data == nil {
		return nil, nil
	}

	config, ok := data.(map[string]any)
	if !ok {
		return nil, fmt.Errorf("invalid volume configuration: %v", data)
	}

	remoteConfig := map[string]any{}
	remoteConfig["type"] = config["type"]

	if raw, ok := config["params"]; ok {
		// Checked assertion: a malformed spec must produce an error, not a panic.
		params, ok := raw.(map[string]any)
		if !ok {
			return nil, fmt.Errorf("invalid volume remote params: %v", raw)
		}
		remoteConfig["params"] = params
		return remoteConfig, nil
	}

	params := map[string]any{}
	for k, v := range config {
		if k != "type" {
			params[k] = v
		}
	}
	remoteConfig["params"] = params

	return remoteConfig, nil
}
// TransformNetwork rewrites a network's "ports" list into "port_map".
//
// Each port entry becomes a map with "protocol" (default "tcp"), "host_port"
// and "container_port". Accepted entry forms:
//   - int: used for both host and container port;
//   - string: "port", "host:container" or "protocol:host:container";
//   - map: keys "host_port" and "container_port" (int or numeric string) and
//     an optional string "protocol".
//
// Parts that cannot be parsed are not reported as errors; the corresponding
// value is whatever strconv.Atoi returned. The input map is modified in place:
// "port_map" is set and "ports" is removed. nil data is passed through as nil.
func TransformNetwork(_ *map[string]interface{}, data any, _ tree.Path) (any, error) {
	if data == nil {
		return nil, nil
	}

	config, isMap := data.(map[string]any)
	if !isMap {
		return nil, fmt.Errorf("invalid network configuration: %v", data)
	}

	// portOf reads a port given either as an int or as a numeric string;
	// any other type counts as 0.
	portOf := func(value any) int {
		switch p := value.(type) {
		case int:
			return p
		case string:
			n, _ := strconv.Atoi(p)
			return n
		}
		return 0
	}

	rawPorts, _ := transform.ToAnySlice(config["ports"])
	portMap := []map[string]any{}
	for _, raw := range rawPorts {
		proto := "tcp"
		var hostPort, containerPort int

		switch spec := raw.(type) {
		case int:
			hostPort, containerPort = spec, spec
		case string:
			fields := strings.Split(spec, ":")
			last := fields[len(fields)-1]
			switch n := len(fields); {
			case n <= 2:
				hostPort = portOf(fields[0])
				containerPort = portOf(last)
			case n == 3:
				proto = fields[0]
				hostPort = portOf(fields[1])
				containerPort = portOf(last)
			}
		case map[string]any:
			hostPort = portOf(spec["host_port"])
			containerPort = portOf(spec["container_port"])
			if p, isString := spec["protocol"].(string); isString {
				proto = p
			}
		}

		portMap = append(portMap, map[string]any{
			"protocol":       proto,
			"host_port":      hostPort,
			"container_port": containerPort,
		})
	}

	config["port_map"] = portMap
	delete(config, "ports")
	return config, nil
}
// TransformLibrary transforms the library configuration to a map.
//
// A string is split on ":" into "name" and "version"; a string without ":"
// gets an empty version, and more than one ":" is rejected. A map is returned
// unchanged and nil data is passed through as nil. Any other type is rejected.
func TransformLibrary(_ *map[string]interface{}, data any, _ tree.Path) (any, error) {
	switch lib := data.(type) {
	case nil:
		return nil, nil
	case map[string]any:
		return lib, nil
	case string:
		parts := strings.Split(lib, ":")
		switch len(parts) {
		case 1:
			return map[string]any{"name": parts[0], "version": ""}, nil
		case 2:
			return map[string]any{"name": parts[0], "version": parts[1]}, nil
		}
	}
	return nil, fmt.Errorf("invalid library configuration: %v", data)
}
package nunet
import (
"fmt"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/tree"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/validate"
)
// NewNuNetValidator creates the validator for the NuNet configuration: the
// root path is checked by ValidateSpec, and every top-level job and nested
// child job by ValidateJob.
func NewNuNetValidator() validate.Validator {
	rules := map[tree.Path]validate.ValidatorFunc{
		"":                    ValidateSpec,
		"jobs.[]":             ValidateJob,
		"jobs.**.children.[]": ValidateJob,
	}
	return validate.NewValidator(rules)
}
// ValidateSpec checks the root configuration for consistency.
//
// data must be a map whose "jobs" entry is a non-empty list. An error is
// returned when data is not a map, when "jobs" is missing or empty, or when
// "jobs" is not a list (this used to panic on the type assertion).
func ValidateSpec(_ *map[string]any, data any, _ tree.Path) error {
	spec, ok := data.(map[string]any)
	if !ok {
		return fmt.Errorf("invalid spec configuration: %v", data)
	}

	// Check if the jobs list is present and not empty.
	switch jobs := spec["jobs"].(type) {
	case nil:
		return fmt.Errorf("jobs list is required")
	case []any:
		if len(jobs) == 0 {
			return fmt.Errorf("jobs list is required")
		}
	default:
		return fmt.Errorf("invalid jobs configuration: expected a list, got %T", jobs)
	}

	return nil
}
// ValidateJob checks the job configuration.
//
// A job is valid when it has at least one child or, failing that, an
// "execution" entry. An error is returned when data is not a map, when
// "children" is present but not a list (previously a panic on the unchecked
// type assertion), or when the job has neither children nor an execution.
func ValidateJob(_ *map[string]any, data any, _ tree.Path) error {
	job, ok := data.(map[string]any)
	if !ok {
		return fmt.Errorf("invalid job configuration: %v", data)
	}

	if rawChildren := job["children"]; rawChildren != nil {
		children, ok := rawChildren.([]any)
		if !ok {
			return fmt.Errorf("job children must be a list, got %T", rawChildren)
		}
		if len(children) > 0 {
			return nil
		}
	}

	// No children: the job has to describe an execution itself.
	if job["execution"] == nil {
		return fmt.Errorf("job must have either children or an execution")
	}
	return nil
}
package parser
import (
"fmt"
"gitlab.com/nunet/device-management-service/dms/jobs"
)
// Parse looks up the parser registered for specType and runs it on data.
// It returns an error when no parser is registered for the spec type or when
// the parser itself fails.
func Parse(specType SpecType, data []byte) (jobs.JobSpec, error) {
	p, found := registry.GetParser(specType)
	if !found {
		return jobs.JobSpec{}, fmt.Errorf("parser for spec type %s not found", specType)
	}

	var spec jobs.JobSpec
	spec, err := p.Parse(data)
	return spec, err
}
package parser
import (
"encoding/json"
"fmt"
"github.com/mitchellh/mapstructure"
yaml "gopkg.in/yaml.v3"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/transform"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/validate"
)
// SpecType names the specification format a parser is registered for.
type SpecType string
// Spec types declared by this package.
const (
specTypeNuNet SpecType = "nunet"
specTypeNomad SpecType = "nomad"
specTypeK8s SpecType = "k8s"
)
// DefaultTagName is the struct tag name handed to the mapstructure decoder
// when a parsed configuration is decoded into its target type.
const DefaultTagName = "json"
// Parser turns raw specification bytes into a value of type T.
type Parser[T any] interface {
Parse(data []byte) (T, error)
}
// Impl is the Parser implementation built by NewParser. It transforms the
// raw configuration with transformer and checks it with validator before
// decoding it into T.
type Impl[T any] struct {
validator validate.Validator
transformer transform.Transformer
}
// NewParser returns a Parser for T that uses the given transformer and
// validator.
func NewParser[T any](transformer transform.Transformer, validator validate.Validator) Parser[T] {
return Impl[T]{
transformer: transformer,
validator: validator,
}
}
// Parse decodes a raw specification into a value of type T.
//
// The input is unmarshaled as YAML first and as JSON when that fails. The
// resulting map is run through the transformer, the raw map is then validated
// and the transformer output is finally decoded into T with mapstructure,
// using DefaultTagName as the struct tag.
//
// Parse, transform and decode failures are wrapped with %w so callers can
// inspect the cause with errors.Is / errors.As; validation errors are returned
// as produced by the validator. The returned T must be ignored on error.
func (p Impl[T]) Parse(data []byte) (T, error) {
	var rawConfig map[string]any
	var config T

	// Try to unmarshal as YAML first.
	if err := yaml.Unmarshal(data, &rawConfig); err != nil {
		// Start the JSON attempt from a clean map so that nothing decoded by
		// the failed YAML pass leaks into the result.
		rawConfig = nil
		if err = json.Unmarshal(data, &rawConfig); err != nil {
			return config, fmt.Errorf("failed to parse config: %w", err)
		}
	}

	// Apply transformers.
	transformed, err := p.transformer.Transform(&rawConfig)
	if err != nil {
		return config, fmt.Errorf("failed to transform config: %w", err)
	}

	// Validate the configuration. The validator receives the raw map that the
	// transformer was given.
	if err = p.validator.Validate(&rawConfig); err != nil {
		return config, err
	}

	// Decode the transformed configuration into the target type.
	decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
		Result:  &config,
		TagName: DefaultTagName,
	})
	if err != nil {
		return config, fmt.Errorf("failed to create decoder: %w", err)
	}
	if err = decoder.Decode(transformed); err != nil {
		return config, fmt.Errorf("failed to decode config: %w", err)
	}
	return config, nil
}
package parser
import (
"sync"
)
// Registry stores parsers keyed by the spec type they handle.
type Registry[T any] interface {
GetParser(specType SpecType) (Parser[T], bool)
RegisterParser(specType SpecType, p Parser[T])
}
// RegistryImpl is a Registry backed by a map. Its methods take mu around
// every access to parsers.
type RegistryImpl[T any] struct {
parsers map[SpecType]Parser[T]
mu sync.RWMutex
}
// RegisterParser associates p with specType, replacing any parser that was
// registered for that type before. The parser map is created on first use so
// that a zero-value RegistryImpl does not panic on its first registration.
func (r *RegistryImpl[T]) RegisterParser(specType SpecType, p Parser[T]) {
	r.mu.Lock()
	defer r.mu.Unlock()
	if r.parsers == nil {
		r.parsers = make(map[SpecType]Parser[T])
	}
	r.parsers[specType] = p
}
// GetParser returns the parser registered for specType. The boolean reports
// whether such a parser exists. The lookup happens under the read lock.
func (r *RegistryImpl[T]) GetParser(specType SpecType) (Parser[T], bool) {
r.mu.RLock()
defer r.mu.RUnlock()
p, exists := r.parsers[specType]
return p, exists
}
package transform
import (
"fmt"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/tree"
)
// TransformerFunc is a function that transforms a part of the configuration.
// It modifies the data to conform to the expected structure and returns the transformed data.
// It takes the root configuration, the data to transform and the current path in the tree.
// A non-nil error aborts the whole transformation.
type TransformerFunc func(*map[string]interface{}, interface{}, tree.Path) (any, error)
// Transformer is a configuration transformer.
// Transform takes the root configuration and returns the transformed data.
type Transformer interface {
Transform(*map[string]interface{}) (interface{}, error)
}
// TransformerImpl is the implementation of the Transformer interface.
// transformers is a list of passes; each pass maps path patterns to the
// functions applied to the nodes whose path matches the pattern.
type TransformerImpl struct {
transformers []map[tree.Path]TransformerFunc
}
// NewTransformer creates a new transformer with the given transformers.
// Each element of transformers is one pass over the configuration; Transform
// runs the passes in slice order.
func NewTransformer(transformers []map[tree.Path]TransformerFunc) Transformer {
return TransformerImpl{
transformers: transformers,
}
}
// Transform runs every transformer pass over the configuration, feeding the
// output of one pass into the next, and returns the normalized result of the
// last pass. The first failing pass aborts the run with its error.
func (t TransformerImpl) Transform(rawConfig *map[string]interface{}) (interface{}, error) {
	var current any = *rawConfig
	root := tree.NewPath()
	for _, pass := range t.transformers {
		next, err := t.transform(rawConfig, current, root, pass)
		if err != nil {
			return nil, err
		}
		current = next
	}
	return Normalize(current), nil
}
// transform applies one pass of transformers to data and, recursively, to
// everything below it.
//
// Every transformer whose pattern matches path is applied to the node first.
// Afterwards the (possibly replaced) node is descended into: map values are
// transformed in place under path.Next(key), slice elements under
// path.Next("[i]"). Scalars are returned as they are.
func (t TransformerImpl) transform(root *map[string]interface{}, data any, path tree.Path, transformers map[tree.Path]TransformerFunc) (interface{}, error) {
	// Apply the transformers registered for this path.
	for pattern, apply := range transformers {
		if !path.Matches(pattern) {
			continue
		}
		replaced, err := apply(root, data, path)
		if err != nil {
			return nil, err
		}
		data = replaced
	}

	// Descend into maps.
	if object, isMap := data.(map[string]interface{}); isMap {
		for key, value := range object {
			var err error
			object[key], err = t.transform(root, value, path.Next(key), transformers)
			if err != nil {
				return nil, err
			}
		}
		return object, nil
	}

	// Descend into slices of any element type; anything else is a leaf.
	list, convErr := ToAnySlice(data)
	if convErr != nil {
		return data, nil
	}
	for i, value := range list {
		var err error
		list[i], err = t.transform(root, value, path.Next(fmt.Sprintf("[%d]", i)), transformers)
		if err != nil {
			return nil, err
		}
	}
	return list, nil
}
package transform
import (
"fmt"
"reflect"
"sort"
"strconv"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/tree"
)
// MapToSlice converts a map of entries to a slice and stores each entry's key
// in its "name" field.
//
// Entries are emitted in ascending key order so the result is deterministic
// (map iteration order is random). A nil entry, or a nil map entry, is
// replaced by a fresh map that only holds the name. Entries that are not maps
// are appended unchanged. Map entries are modified in place.
//
// A nil input yields a nil slice. The returned error is always nil.
func MapToSlice(data map[string]any) ([]any, error) {
	if data == nil {
		return nil, nil
	}

	keys := make([]string, 0, len(data))
	for k := range data {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	result := make([]any, 0, len(keys))
	for _, k := range keys {
		v := data[k]
		switch entry := v.(type) {
		case nil:
			v = map[string]any{"name": k}
		case map[string]any:
			if entry == nil {
				// Writing to a nil map would panic; substitute a fresh one.
				v = map[string]any{"name": k}
			} else {
				entry["name"] = k
			}
		}
		result = append(result, v)
	}
	return result, nil
}
// GetConfigAtPath retrieves the part of the configuration found at path.
//
// Each path segment is resolved against the current node: map nodes are
// indexed by the segment itself, list nodes by the number enclosed in the
// segment's first and last character (e.g. "[3]"). A missing map key yields
// nil. An error is returned when a list segment is malformed or out of range
// (both used to panic) and when a segment has to be resolved against a node
// that is neither a map nor a list.
func GetConfigAtPath(config map[string]interface{}, path tree.Path) (any, error) {
	current := any(config)
	for _, key := range path.Parts() {
		switch node := current.(type) {
		case map[string]any:
			current = node[key]
		case []any:
			i, err := parseListIndex(key, len(node))
			if err != nil {
				return nil, err
			}
			current = node[i]
		case []map[string]any:
			i, err := parseListIndex(key, len(node))
			if err != nil {
				return nil, err
			}
			current = node[i]
		default:
			return nil, fmt.Errorf("invalid data type: %v", current)
		}
	}
	return current, nil
}

// parseListIndex extracts the index from a "[N]" path segment and checks that
// it addresses an element of a list with the given length.
func parseListIndex(key string, length int) (int, error) {
	// The segment needs at least its two enclosing characters.
	if len(key) < 2 {
		return 0, fmt.Errorf("invalid index: %v", key)
	}
	i, err := strconv.Atoi(key[1 : len(key)-1])
	if err != nil {
		return 0, fmt.Errorf("invalid index: %v", key)
	}
	if i < 0 || i >= length {
		return 0, fmt.Errorf("index out of range: %v", key)
	}
	return i, nil
}
// ToAnySlice copies the elements of a slice of any element type into a new
// []any. It returns an error when the argument is not a slice (nil and arrays
// included).
func ToAnySlice(slice any) ([]any, error) {
	rv := reflect.ValueOf(slice)
	if rv.Kind() == reflect.Slice {
		out := make([]any, rv.Len())
		for i := range out {
			out[i] = rv.Index(i).Interface()
		}
		return out, nil
	}
	return nil, fmt.Errorf("input is not a slice. type: %T", slice)
}
// normalizeMap rebuilds maps and slices into a uniform representation.
//
// A map of any value type becomes a map with the same key type and
// interface{} values, each value normalized recursively; entries whose
// normalized value is nil are not carried over. A slice of any element type
// becomes a []interface{} of normalized elements, ordered by the fmt.Sprint
// representation of the elements. Every other value is returned unchanged.
func normalizeMap(m interface{}) interface{} {
	rv := reflect.ValueOf(m)

	if rv.Kind() == reflect.Map {
		anyType := reflect.TypeOf((*interface{})(nil)).Elem()
		normalized := reflect.MakeMap(reflect.MapOf(rv.Type().Key(), anyType))
		for it := rv.MapRange(); it.Next(); {
			value := normalizeMap(it.Value().Interface())
			// A nil value turns into the zero reflect.Value here, which
			// SetMapIndex treats as "delete this key".
			normalized.SetMapIndex(it.Key(), reflect.ValueOf(value))
		}
		return normalized.Interface()
	}

	if rv.Kind() == reflect.Slice {
		items := make([]interface{}, rv.Len())
		for i := range items {
			items[i] = normalizeMap(rv.Index(i).Interface())
		}
		// Order the elements by their printed form.
		sort.Slice(items, func(i, j int) bool {
			return fmt.Sprint(items[i]) < fmt.Sprint(items[j])
		})
		return items
	}

	// Scalars and everything else pass through.
	return m
}
// Normalize is the exported entry point to normalizeMap: it returns the
// normalized form of m.
func Normalize(m any) interface{} {
return normalizeMap(m)
}
package tree
import (
"strings"
)
// Building blocks of paths and path patterns.
const (
// configPathSeparator separates the parts of a path.
configPathSeparator = "."
// configPathMatchAny is the pattern part that matches any single path part.
configPathMatchAny = "*"
// configPathMatchAnyMultiple is the pattern part that can span several path parts.
configPathMatchAnyMultiple = "**"
// configPathList is the pattern part that matches a part enclosed in brackets.
configPathList = "[]"
)
// Path is a custom type for representing paths in the configuration.
// Parts are joined with configPathSeparator.
type Path string
// NewPath builds a path by joining the given parts with the path separator.
// Called without arguments it returns the empty path.
func NewPath(path ...string) Path {
return Path(strings.Join(path, configPathSeparator))
}
// Parts returns the parts of the path, split at the path separator.
// The empty path yields a single empty part (the behavior of strings.Split).
func (p Path) Parts() []string {
return strings.Split(string(p), configPathSeparator)
}
// Parent returns the path without its last part. A path made of a single
// part has the empty path as its parent.
func (p Path) Parent() Path {
	segments := p.Parts()
	if len(segments) <= 1 {
		return ""
	}
	head := segments[:len(segments)-1]
	return Path(strings.Join(head, configPathSeparator))
}
// Next returns the path extended by one more part. An empty part leaves the
// path unchanged; extending the empty path yields the part alone.
func (p Path) Next(path string) Path {
	switch {
	case path == "":
		return p
	case p == "":
		return Path(path)
	default:
		return Path(string(p) + configPathSeparator + path)
	}
}
// Last returns the final part of the path, or the empty string when the path
// has no parts.
func (p Path) Last() string {
	segments := p.Parts()
	if len(segments) == 0 {
		return ""
	}
	return segments[len(segments)-1]
}
// Matches checks if the path matches a given pattern.
// Both are split into their parts and compared by matchParts.
func (p Path) Matches(pattern Path) bool {
pathParts := p.Parts()
patternParts := pattern.Parts()
return matchParts(pathParts, patternParts)
}
// String returns the string representation of the path, i.e. the path
// converted back to a plain string.
func (p Path) String() string {
return string(p)
}
// matchParts checks if the path parts match the pattern parts.
//
// Pattern parts are compared with the path parts position by position:
//   - configPathMatchAny accepts any single part;
//   - configPathList accepts a part enclosed in '[' and ']';
//   - configPathMatchAnyMultiple accepts everything that follows when it is the
//     last pattern part; otherwise the rest of the pattern is tried against
//     every suffix of the remaining path;
//   - any other pattern part has to equal the path part.
//
// A path with fewer parts than the pattern never matches, and neither does a
// path with parts left over after the last pattern part.
func matchParts(pathParts, patternParts []string) bool {
	// If the pattern is longer than the path, it can't match.
	if len(pathParts) < len(patternParts) {
		return false
	}
	last := len(patternParts) - 1
	for i, part := range patternParts {
		switch part {
		case configPathMatchAnyMultiple:
			// As the last part of the pattern it matches whatever is left.
			if i == last {
				return true
			}
			// Otherwise, try to match the rest of the pattern against every
			// remaining suffix of the path.
			for j := i; j < len(pathParts); j++ {
				if matchParts(pathParts[j:], patternParts[i+1:]) {
					return true
				}
			}
			// No suffix matched. Comparing the remaining parts position by
			// position would only repeat the failed j == i+1 attempt.
			return false
		case configPathList:
			// Check that the path part is enclosed by []. The length guard
			// keeps an empty part (e.g. the root path) from being indexed.
			segment := pathParts[i]
			if len(segment) < 2 || segment[0] != '[' || segment[len(segment)-1] != ']' {
				return false
			}
		default:
			if part != configPathMatchAny && part != pathParts[i] {
				return false
			}
		}
		// The pattern is exhausted but the path has parts left: no match.
		if i == last && i < len(pathParts)-1 {
			return false
		}
	}
	return true
}
package validate
import (
"fmt"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/tree"
)
// ValidatorFunc is a function that validates a part of the configuration.
// It takes the root configuration, the data to validate and the current path in the tree.
// A non-nil error marks the configuration as invalid.
type ValidatorFunc func(*map[string]any, any, tree.Path) error
// Validator is a configuration validator.
// Validate takes the root configuration and returns an error when it is invalid.
type Validator interface {
Validate(*map[string]any) error
}
// ValidatorImpl is the implementation of the Validator interface.
// validators maps path patterns to the functions run on the nodes whose path
// matches the pattern.
type ValidatorImpl struct {
validators map[tree.Path]ValidatorFunc
}
// NewValidator creates a new validator with the given validators.
// The map keys are the path patterns the validator functions are applied to.
func NewValidator(validators map[tree.Path]ValidatorFunc) Validator {
return ValidatorImpl{
validators: validators,
}
}
// Validate applies the validators to the configuration.
// It starts the recursive walk at the root with an empty path and returns the
// first error reported.
func (v ValidatorImpl) Validate(rawConfig *map[string]any) error {
data := any(*rawConfig)
return v.validate(rawConfig, data, tree.NewPath(), v.validators)
}
// validate checks data and, recursively, everything below it.
//
// Every validator whose pattern matches path is run on the node; the first
// error stops the walk. Map values are then visited under path.Next(key) and
// []interface{} elements under path.Next("[i]"). Other values are leaves.
func (v ValidatorImpl) validate(root *map[string]interface{}, data any, path tree.Path, validators map[tree.Path]ValidatorFunc) error {
	// Run the validators registered for this path.
	for pattern, check := range validators {
		if !path.Matches(pattern) {
			continue
		}
		if err := check(root, data, path); err != nil {
			return err
		}
	}

	// Descend into maps.
	if object, isMap := data.(map[string]interface{}); isMap {
		for key, child := range object {
			if err := v.validate(root, child, path.Next(key), validators); err != nil {
				return err
			}
		}
		return nil
	}

	// Descend into lists.
	if list, isList := data.([]interface{}); isList {
		for i := range list {
			segment := fmt.Sprintf("[%d]", i)
			if err := v.validate(root, list[i], path.Next(segment), validators); err != nil {
				return err
			}
		}
	}
	return nil
}
package node
import (
"context"
"encoding/json"
"time"
kbucket "github.com/libp2p/go-libp2p-kbucket"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/multiformats/go-multiaddr"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/network"
"gitlab.com/nunet/device-management-service/network/libp2p"
"gitlab.com/nunet/device-management-service/types"
)
// Names of the behaviors the node registers on its actor, plus the timeout
// handed to the network when pinging a peer.
const (
// Peer inspection and management.
PeersListBehavior = "/dms/node/peers/list"
PeerAddrInfoBehavior = "/dms/node/peers/self"
PeerPingBehavior = "/dms/node/peers/ping"
PeerDHTBehavior = "/dms/node/peers/dht"
PeerConnectBehavior = "/dms/node/peers/connect"
PeerScoreBehavior = "/dms/node/peers/score"
// Onboarding.
OnboardBehaviour = "/dms/node/onboarding/onboard"
OffboardBehaviour = "/dms/node/onboarding/offboard"
OnboardStatusBehaviour = "/dms/node/onboarding/status"
OnboardResourceBehaviour = "/dms/node/onboarding/resource"
// Execution management.
CustomVMStart = "/dms/node/vm/start/custom"
VMStop = "/dms/node/vm/stop"
VMList = "/dms/node/vm/list"
// pingTimeout is passed to the network's Ping call by handlePeerPing.
pingTimeout = 1 * time.Second
)
// PingRequest is the payload of a peer ping message.
type PingRequest struct {
// Host is the target handed to the network's Ping call.
Host string
}
// PingResponse is the reply to a peer ping message.
type PingResponse struct {
// Error is the error text when the ping failed; empty otherwise.
Error string
// RTT is the round-trip time in milliseconds.
RTT int64
}
// handlePeerPing pings the host named in the request and replies with the
// round-trip time in milliseconds, or with the error text when the ping
// failed. A request that cannot be decoded is answered with the decoding
// error instead of being dropped, so the requester does not have to wait for
// its own timeout.
func (n *Node) handlePeerPing(msg actor.Envelope) {
	defer msg.Discard()

	resp := PingResponse{}
	var request PingRequest
	if err := json.Unmarshal(msg.Message, &request); err != nil {
		resp.Error = err.Error()
		n.sendReply(msg, resp)
		return
	}

	res, err := n.network.Ping(context.Background(), request.Host, pingTimeout)
	if err != nil {
		resp.Error = err.Error()
		n.sendReply(msg, resp)
		return
	}
	// The ping result can carry its own error next to the measured RTT.
	if res.Error != nil {
		resp.Error = res.Error.Error()
	}
	resp.RTT = res.RTT.Milliseconds()
	n.sendReply(msg, resp)
}
// PeersListResponse is the reply to a peers list message.
type PeersListResponse struct {
// Peers holds the peer IDs reported by the libp2p instance's PS.Peers().
Peers []peer.ID
}
// handlePeersList replies with the peers known to the underlying libp2p
// instance. Nothing is sent when the network is not a libp2p network.
func (n *Node) handlePeersList(msg actor.Envelope) {
	defer msg.Discard()

	host, isLibp2p := n.network.(*libp2p.Libp2p)
	if !isLibp2p {
		return
	}

	// Start from an empty, non-nil slice and collect the peer IDs.
	peers := make([]peer.ID, 0)
	for _, id := range host.PS.Peers() {
		peers = append(peers, id)
	}
	n.sendReply(msg, PeersListResponse{Peers: peers})
}
// PeerAddrInfoResponse is the reply to a peer self-info message. It carries
// the ID and the listen address taken from the network's Stat().
type PeerAddrInfoResponse struct {
ID string `json:"id"`
Address string `json:"listen_addr"`
}
// handlePeerAddrInfo replies with the ID and listen address reported by the
// network statistics.
func (n *Node) handlePeerAddrInfo(msg actor.Envelope) {
	defer msg.Discard()

	stats := n.network.Stat()
	n.sendReply(msg, PeerAddrInfoResponse{
		ID:      stats.ID,
		Address: stats.ListenAddr,
	})
}
// PeerDHTResponse is the reply to a peer DHT message.
type PeerDHTResponse struct {
// Peers holds the peer infos of the DHT routing table.
Peers []kbucket.PeerInfo
}
// handlePeerDHT replies with the peer infos of the DHT routing table of the
// underlying libp2p instance. Nothing is sent when the network is not a
// libp2p network.
func (n *Node) handlePeerDHT(msg actor.Envelope) {
	defer msg.Discard()

	host, isLibp2p := n.network.(*libp2p.Libp2p)
	if !isLibp2p {
		return
	}

	infos := host.DHT.RoutingTable().GetPeerInfos()
	n.sendReply(msg, PeerDHTResponse{Peers: infos})
}
// PeerConnectRequest is the payload of a peer connect message.
type PeerConnectRequest struct {
// Address is parsed as a multiaddr from which the peer's address info is derived.
Address string
}
// PeerConnectResponse is the reply to a peer connect message.
type PeerConnectResponse struct {
// Status is "CONNECTED" on success and empty otherwise.
Status string
// Error is the error text when connecting failed.
Error string
}
// handlePeerConnect connects the libp2p host to the peer whose multiaddr is
// given in the request and replies with Status "CONNECTED" on success.
//
// Every failure is reported in the reply's Error field. This includes a
// request that cannot be decoded and a network that is not libp2p; both were
// previously dropped without any reply, leaving the requester to time out.
func (n *Node) handlePeerConnect(msg actor.Envelope) {
	defer msg.Discard()

	resp := PeerConnectResponse{}
	// fail sends the reason back to the requester.
	fail := func(reason string) {
		resp.Error = reason
		n.sendReply(msg, resp)
	}

	var request PeerConnectRequest
	if err := json.Unmarshal(msg.Message, &request); err != nil {
		fail(err.Error())
		return
	}
	libp2pNet, ok := n.network.(*libp2p.Libp2p)
	if !ok {
		fail("peer connect requires a libp2p network")
		return
	}

	peerAddr, err := multiaddr.NewMultiaddr(request.Address)
	if err != nil {
		fail(err.Error())
		return
	}
	addrInfo, err := peer.AddrInfoFromP2pAddr(peerAddr)
	if err != nil {
		fail(err.Error())
		return
	}
	if err := libp2pNet.Host.Connect(context.Background(), *addrInfo); err != nil {
		fail(err.Error())
		return
	}

	resp.Status = "CONNECTED"
	n.sendReply(msg, resp)
}
// OnboardRequest is the payload of an onboard message.
type OnboardRequest struct {
// Config is passed to the onboarder's Onboard call.
Config types.CapacityForNunet
}
// OnboardResponse is the reply to an onboard message.
type OnboardResponse struct {
// Error is the error text when the request could not be decoded or onboarding failed.
Error string
// Config is the onboarding configuration returned by the onboarder on success.
Config types.OnboardingConfig
}
// handleOnboard onboards the node with the capacity given in the request and
// replies with the resulting onboarding configuration. Decoding and
// onboarding failures are reported in the reply's Error field.
func (n *Node) handleOnboard(msg actor.Envelope) {
	defer msg.Discard()

	var request OnboardRequest
	if err := json.Unmarshal(msg.Message, &request); err != nil {
		n.sendReply(msg, OnboardResponse{Error: err.Error()})
		return
	}

	config, err := n.onboarder.Onboard(context.Background(), request.Config)
	if err != nil {
		n.sendReply(msg, OnboardResponse{Error: err.Error()})
		return
	}
	n.sendReply(msg, OnboardResponse{Config: *config})
}
// OffboardRequest is the payload of an offboard message.
type OffboardRequest struct {
// Force is passed through to the onboarder's Offboard call.
Force bool
}
// OffboardResponse is the reply to an offboard message.
type OffboardResponse struct {
// Success is true when the onboarder's Offboard call returned no error.
Success bool
}
// handleOffboard offboards the node and replies whether that succeeded.
// A request that cannot be decoded is dropped without a reply.
func (n *Node) handleOffboard(msg actor.Envelope) {
	defer msg.Discard()

	var request OffboardRequest
	if json.Unmarshal(msg.Message, &request) != nil {
		return
	}

	err := n.onboarder.Offboard(context.Background(), request.Force)
	n.sendReply(msg, OffboardResponse{Success: err == nil})
}
// OnboardStatusResponse is the reply to an onboard status message.
type OnboardStatusResponse struct {
// Onboarded is the value returned by the onboarder's IsOnboarded call.
Onboarded bool
// Error is the error text when the status could not be determined.
Error string
}
// handleOnboardStatus replies with the onboarding status of the node, or
// with the error text when the status could not be determined.
func (n *Node) handleOnboardStatus(msg actor.Envelope) {
	defer msg.Discard()

	var resp OnboardStatusResponse
	if onboarded, err := n.onboarder.IsOnboarded(context.Background()); err != nil {
		resp.Error = err.Error()
	} else {
		resp.Onboarded = onboarded
	}
	n.sendReply(msg, resp)
}
// OnboardResourceRequest is the payload of an onboard resource message.
type OnboardResourceRequest struct {
// Config is passed to the onboarder's ResourceConfig call.
Config types.CapacityForNunet
}
// OnboardResourceResponse is the reply to an onboard resource message.
type OnboardResourceResponse struct {
// Error is the error text when the resource configuration failed.
Error string
// Result is the onboarding configuration returned by the onboarder on success.
Result types.OnboardingConfig
}
// handleOnboardResource applies the resource configuration given in the
// request and replies with the resulting onboarding configuration.
//
// Failures are reported in the reply's Error field. That now includes a
// request that cannot be decoded, which used to be dropped without a reply
// (handleOnboard already answers in that case).
func (n *Node) handleOnboardResource(msg actor.Envelope) {
	defer msg.Discard()

	resp := OnboardResourceResponse{}
	var request OnboardResourceRequest
	if err := json.Unmarshal(msg.Message, &request); err != nil {
		resp.Error = err.Error()
		n.sendReply(msg, resp)
		return
	}

	result, err := n.onboarder.ResourceConfig(context.Background(), request.Config)
	if err != nil {
		resp.Error = err.Error()
		n.sendReply(msg, resp)
		return
	}
	resp.Result = *result
	n.sendReply(msg, resp)
}
// CustomVMStartRequest is the payload of a custom VM start message.
type CustomVMStartRequest struct {
// Execution is handed to the executor's Start call.
Execution types.ExecutionRequest
}
// CustomVMStartResponse is the reply to a custom VM start message.
type CustomVMStartResponse struct {
// Error is the error text when the start failed; empty on success.
Error string
}
// handleCustomVMStart starts the execution described in the request and
// replies with an empty Error on success.
//
// Failures are reported in the reply's Error field. That now includes a
// request that cannot be decoded, which used to be dropped without a reply.
func (n *Node) handleCustomVMStart(msg actor.Envelope) {
	defer msg.Discard()

	resp := CustomVMStartResponse{}
	var request CustomVMStartRequest
	if err := json.Unmarshal(msg.Message, &request); err != nil {
		resp.Error = err.Error()
		n.sendReply(msg, resp)
		return
	}

	if err := n.executor.Start(context.Background(), &request.Execution); err != nil {
		resp.Error = err.Error()
	}
	n.sendReply(msg, resp)
}
// VMStopRequest is the payload of a VM stop message.
type VMStopRequest struct {
// ExecutionID is handed to the executor's Cancel call.
ExecutionID string
}
// VMStopResponse is the reply to a VM stop message.
type VMStopResponse struct {
// Error is the error text when cancelling failed; empty on success.
Error string
}
// handleVMStop cancels the execution named in the request and replies with
// an empty Error on success.
//
// Failures are reported in the reply's Error field. That now includes a
// request that cannot be decoded, which used to be dropped without a reply.
func (n *Node) handleVMStop(msg actor.Envelope) {
	defer msg.Discard()

	resp := VMStopResponse{}
	var request VMStopRequest
	if err := json.Unmarshal(msg.Message, &request); err != nil {
		resp.Error = err.Error()
		n.sendReply(msg, resp)
		return
	}

	if err := n.executor.Cancel(context.Background(), request.ExecutionID); err != nil {
		resp.Error = err.Error()
	}
	n.sendReply(msg, resp)
}
// ListVMResponse is the reply to a VM list message.
type ListVMResponse struct {
// Error is not set by handleListVM.
Error string
// VMS holds the items returned by the executor's List call.
VMS []types.ExecutionListItem
}
// handleListVM replies with the executions listed by the executor.
func (n *Node) handleListVM(msg actor.Envelope) {
	defer msg.Discard()

	vms := n.executor.List()
	n.sendReply(msg, ListVMResponse{VMS: vms})
}
// PeerScoreResponse is the reply to a peer score message.
type PeerScoreResponse struct {
// Score is the per-peer snapshot map returned by the network's GetBroadcastScore.
Score map[peer.ID]*network.PeerScoreSnapshot
}
// handlePeerScore replies with the broadcast score snapshots of the network.
func (n *Node) handlePeerScore(msg actor.Envelope) {
	defer msg.Discard()

	scores := n.network.GetBroadcastScore()
	n.sendReply(msg, PeerScoreResponse{Score: scores})
}
package node
import (
"context"
"errors"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
pubsub "github.com/libp2p/go-libp2p-pubsub"
"github.com/libp2p/go-libp2p/core/peer"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/dms/jobs"
"gitlab.com/nunet/device-management-service/dms/onboarding"
"gitlab.com/nunet/device-management-service/executor"
bt "gitlab.com/nunet/device-management-service/internal/background_tasks"
"gitlab.com/nunet/device-management-service/lib/crypto"
"gitlab.com/nunet/device-management-service/lib/did"
"gitlab.com/nunet/device-management-service/lib/ucan"
"gitlab.com/nunet/device-management-service/network"
"gitlab.com/nunet/device-management-service/types"
)
// Node is the structure that holds the node's dependencies.
type Node struct {
rootCap ucan.CapabilityContext
actor actor.Actor
scheduler *bt.Scheduler
network network.Network
resourceManager types.ResourceManager
hostID string
onboarder *onboarding.Onboarding
executor executor.Executor
// mx is held around accesses to peers and allocations.
mx sync.Mutex
peers map[peer.ID]*peerState
allocations map[string]*jobs.Allocation
// running is switched between 0 and 1 with atomic compare-and-swap by Start and Stop.
running int32
}
// peerState is what the node tracks per connected peer.
type peerState struct {
// conns is incremented on every connect and decremented on every disconnect notification.
conns int
// helloIn is set when the peer's public hello was received,
// helloOut when the peer answered this node's hello.
helloIn, helloOut bool
}
// New creates a new node and attaches a root actor to it.
//
// All dependencies are mandatory; a nil dependency or an empty hostID is
// rejected. The actor's security context is built from the public and private
// key of the root capability's DID, the executor is created with NewExecutor,
// and the node's behaviors are registered on the actor. The node is not
// started; call Start for that.
//
// An error is returned when a dependency is missing, when the root keys cannot
// be obtained, when the actor or the executor cannot be created, or when a
// behavior cannot be registered.
func New(ctx context.Context, onboarder *onboarding.Onboarding, rootCap ucan.CapabilityContext, hostID string, net network.Network, resourceManager types.ResourceManager, scheduler *bt.Scheduler) (*Node, error) {
	if onboarder == nil {
		return nil, errors.New("onboarder is nil")
	}
	if rootCap == nil {
		return nil, errors.New("root capability context is nil")
	}
	if hostID == "" {
		return nil, errors.New("host id is empty")
	}
	if net == nil {
		return nil, errors.New("network is nil")
	}
	if resourceManager == nil {
		return nil, errors.New("resource manager is nil")
	}
	if scheduler == nil {
		return nil, errors.New("scheduler is nil")
	}

	// Derive the actor's key pair from the root DID.
	rootDID := rootCap.DID()
	rootTrust := rootCap.Trust()
	anchor, err := rootTrust.GetAnchor(rootDID)
	if err != nil {
		return nil, fmt.Errorf("failed to get root DID anchor: %w", err)
	}
	pubk := anchor.PublicKey()
	provider, err := rootTrust.GetProvider(rootDID)
	if err != nil {
		return nil, fmt.Errorf("failed to get root DID provider: %w", err)
	}
	privk, err := provider.PrivateKey()
	if err != nil {
		return nil, fmt.Errorf("failed to get root private key: %w", err)
	}

	rootSec, err := actor.NewBasicSecurityContext(pubk, privk, rootCap)
	if err != nil {
		return nil, fmt.Errorf("failed to create security context: %w", err)
	}
	nodeActor, err := createActor(rootSec, actor.NewRateLimiter(actor.DefaultRateLimiterConfig()), hostID, "root", net, scheduler)
	if err != nil {
		return nil, fmt.Errorf("failed to create node actor: %w", err)
	}

	// Named exec so the local does not shadow the executor package.
	exec, err := NewExecutor(ctx)
	if err != nil {
		return nil, fmt.Errorf("new executor: %w", err)
	}

	n := &Node{
		hostID:          hostID,
		network:         net,
		allocations:     make(map[string]*jobs.Allocation),
		peers:           make(map[peer.ID]*peerState),
		resourceManager: resourceManager,
		actor:           nodeActor,
		rootCap:         rootCap,
		scheduler:       scheduler,
		onboarder:       onboarder,
		executor:        exec,
	}

	if err := nodeActor.AddBehavior(PublicHelloBehavior, n.publicHelloBehavior); err != nil {
		return nil, fmt.Errorf("adding public hello behavior: %w", err)
	}
	if err := nodeActor.AddBehavior(PublicStatusBehavior, n.publicStatusBehavior); err != nil {
		return nil, fmt.Errorf("adding public status behavior: %w", err)
	}
	if err := nodeActor.AddBehavior(BroadcastHelloBehavior, n.broadcastHelloBehavior, actor.WithBehaviorTopic(BroadcastHelloTopic)); err != nil {
		return nil, fmt.Errorf("adding broadcast hello behavior: %w", err)
	}

	// The remaining behaviors take no options; register them in order.
	handlers := []struct {
		behavior string
		label    string
		handle   func(actor.Envelope)
	}{
		{PeersListBehavior, "peers list", n.handlePeersList},
		{PeerAddrInfoBehavior, "peers addr info", n.handlePeerAddrInfo},
		{PeerPingBehavior, "peer ping", n.handlePeerPing},
		{PeerDHTBehavior, "peer dht", n.handlePeerDHT},
		{PeerConnectBehavior, "peer connect", n.handlePeerConnect},
		{PeerScoreBehavior, "peer score", n.handlePeerScore},
		{OnboardBehaviour, "onboard", n.handleOnboard},
		{OffboardBehaviour, "offboard", n.handleOffboard},
		{OnboardStatusBehaviour, "onboard status", n.handleOnboardStatus},
		{OnboardResourceBehaviour, "onboard resource", n.handleOnboardResource},
		{CustomVMStart, "custom vm start", n.handleCustomVMStart},
		{VMStop, "vm stop", n.handleVMStop},
		{VMList, "vm list", n.handleListVM},
	}
	for _, h := range handlers {
		if err := nodeActor.AddBehavior(h.behavior, h.handle); err != nil {
			return nil, fmt.Errorf("adding %s behavior: %w", h.label, err)
		}
	}
	return n, nil
}
// CreateAllocation creates an allocation for job, starts it and records it on
// the node under its ID.
//
// The allocation gets its own actor: a fresh Ed25519 key pair backs its
// security context, its inbox address is a newly generated UUID and it shares
// the node actor's rate limiter. An error is returned when any of these steps,
// the allocation's construction or its start fails.
func (n *Node) CreateAllocation(job jobs.Job) (*jobs.Allocation, error) {
	// Generate a random key pair for the allocation's actor.
	priv, pub, err := crypto.GenerateKeyPair(crypto.Ed25519)
	if err != nil {
		return nil, fmt.Errorf("failed to generate random keypair for allocation job %s: %w", job.ID, err)
	}
	security, err := actor.NewBasicSecurityContext(pub, priv, n.rootCap)
	if err != nil {
		return nil, fmt.Errorf("failed to create security context: %w", err)
	}
	allocationInbox, err := uuid.NewUUID()
	if err != nil {
		return nil, fmt.Errorf("failed to generate uuid for allocation inbox: %w", err)
	}

	// Named allocActor so the local does not shadow the actor package.
	allocActor, err := createActor(security, n.actor.Limiter(), n.hostID, allocationInbox.String(), n.network, n.scheduler)
	if err != nil {
		return nil, fmt.Errorf("failed to create allocation actor: %w", err)
	}
	allocation, err := jobs.NewAllocation(allocActor, jobs.AllocationDetails{Job: job, NodeID: n.hostID}, n.resourceManager)
	if err != nil {
		return nil, fmt.Errorf("failed to create allocation: %w", err)
	}
	if err := allocation.Start(); err != nil {
		return nil, fmt.Errorf("failed to start the allocation: %w", err)
	}

	n.mx.Lock()
	n.allocations[allocation.ID] = allocation
	n.mx.Unlock()
	return allocation, nil
}
// GetAllocation returns the allocation recorded under id, or an error when
// the node knows no such allocation.
func (n *Node) GetAllocation(id string) (*jobs.Allocation, error) {
	n.mx.Lock()
	alloc, found := n.allocations[id]
	n.mx.Unlock()

	if !found {
		return nil, errors.New("allocation not found")
	}
	return alloc, nil
}
// Start starts the node: it starts the node actor and subscribes to the
// broadcast hello topic. Calling Start on a node that is already running is a
// no-op.
//
// When starting fails the running flag is cleared again. Otherwise a later
// Start would report success without having started anything, and Stop would
// try to stop an actor that never ran.
func (n *Node) Start() error {
	if !atomic.CompareAndSwapInt32(&n.running, 0, 1) {
		return nil
	}
	if err := n.actor.Start(); err != nil {
		atomic.StoreInt32(&n.running, 0)
		return fmt.Errorf("failed to start node actor: %w", err)
	}
	if err := n.subscribe(BroadcastHelloTopic); err != nil {
		// Best-effort rollback; the subscription error is the one reported.
		_ = n.actor.Stop()
		atomic.StoreInt32(&n.running, 0)
		return err
	}
	return nil
}
// subscribe subscribes the node actor to the given topics, installs the
// node's broadcast application score and registers the peer connect and
// disconnect notifications. It stops at the first error.
func (n *Node) subscribe(topics ...string) error {
	for i := range topics {
		topic := topics[i]
		err := n.actor.Subscribe(topic, n.setupBroadcast)
		if err != nil {
			return fmt.Errorf("error subscribing to %s: %w", topic, err)
		}
	}

	n.network.SetBroadcastAppScore(n.broadcastScore)

	err := n.network.Notify(n.actor.Context(), n.peerConnected, n.peerDisconnected)
	if err != nil {
		return fmt.Errorf("error setting up peer notifications: %w", err)
	}
	return nil
}
// setupBroadcast sets up the broadcast topic on the network and applies the
// node's pubsub score parameters to it.
func (n *Node) setupBroadcast(topic string) error {
	configure := func(t *network.Topic) error {
		params := &pubsub.TopicScoreParams{
			SkipAtomicValidation:           true,
			TopicWeight:                    1.0,
			TimeInMeshWeight:               0.00027, // ~1/3600
			TimeInMeshQuantum:              time.Second,
			TimeInMeshCap:                  1.0,
			InvalidMessageDeliveriesWeight: -1000,
			InvalidMessageDeliveriesDecay:  pubsub.ScoreParameterDecay(time.Hour),
		}
		return t.SetScoreParams(params)
	}
	return n.network.SetupBroadcastTopic(topic, configure)
}
// broadcastScore is the application score of a peer: 0.01 for a known peer
// with which hellos were exchanged in both directions, -100 for everyone else.
func (n *Node) broadcastScore(p peer.ID) float64 {
	n.mx.Lock()
	defer n.mx.Unlock()

	if st, known := n.peers[p]; !known || !st.helloIn || !st.helloOut {
		return -100
	}
	return 0.01
}
// peerConnected records a new connection to p. The peer's state is created
// on first contact, and a hello is sent in the background unless the peer has
// already answered one.
func (n *Node) peerConnected(p peer.ID) {
	n.mx.Lock()
	defer n.mx.Unlock()

	st, seen := n.peers[p]
	if !seen {
		st = &peerState{}
		n.peers[p] = st
	}
	st.conns++
	if !st.helloOut {
		go n.sayHello(p)
	}
}
// peerDisconnected records the loss of a connection to p and forgets the
// peer once its last connection is gone. Unknown peers are ignored.
func (n *Node) peerDisconnected(p peer.ID) {
	n.mx.Lock()
	defer n.mx.Unlock()

	st, seen := n.peers[p]
	if !seen {
		return
	}
	if st.conns--; st.conns <= 0 {
		delete(n.peers, p)
	}
}
// sayHello invokes the public hello behavior on the root actor of peer p and
// marks the peer as having answered once the reply arrives.
//
// The remote actor's handle is derived from the public key embedded in the
// peer ID. The call gives up silently (debug log only) when the key cannot be
// extracted or is of a disallowed type, when the message cannot be built or
// invoked, or when no reply arrives before the message expires.
func (n *Node) sayHello(p peer.ID) {
	pubk, err := p.ExtractPublicKey()
	if err != nil {
		log.Debugf("failed to extract public key: %s", err)
		return
	}
	if allowed := crypto.AllowedKey(int(pubk.Type())); !allowed {
		log.Debugf("unexpected key type: %d", pubk.Type())
		return
	}
	actorID, err := crypto.IDFromPublicKey(pubk)
	if err != nil {
		log.Debugf("failed to extract actor ID: %s", err)
		return
	}

	remote := actor.Handle{
		ID:  actorID,
		DID: did.FromPublicKey(pubk),
		Address: actor.Address{
			HostID:       p.String(),
			InboxAddress: "root",
		},
	}
	hello, err := actor.Message(
		n.actor.Handle(),
		remote,
		PublicHelloBehavior,
		nil,
		actor.WithMessageTimeout(time.Second),
	)
	if err != nil {
		log.Debugf("failed to construct hello message: %s", err)
		return
	}
	replies, err := n.actor.Invoke(hello)
	if err != nil {
		log.Debugf("error invoking hello: %s", err)
		return
	}

	expired := time.After(time.Until(hello.Expiry()))
	select {
	case reply := <-replies:
		reply.Discard()

		n.mx.Lock()
		st, known := n.peers[p]
		switch {
		case known:
			st.helloOut = true
		case n.network.PeerConnected(p):
			// race with the connected notification: the reply arrived
			// before the peer was recorded
			n.peers[p] = &peerState{helloOut: true}
		}
		n.mx.Unlock()

		log.Infof("got hello from %s", remote)
	case <-expired:
		log.Debugf("hello timeout for %s", remote)
	}
}
// Stop stops the node: it stops every allocation, clears the broadcast
// application score and stops the node actor. Calling Stop on a node that is
// not running is a no-op.
//
// Allocations that fail to stop are logged and skipped; only a failure to
// stop the node actor is returned. The node mutex is held for the whole call.
func (n *Node) Stop() error {
	n.mx.Lock()
	defer n.mx.Unlock()

	if !atomic.CompareAndSwapInt32(&n.running, 1, 0) {
		return nil
	}

	// stop all allocations
	for k, alloc := range n.allocations {
		if err := alloc.Stop(context.Background()); err != nil {
			// %w is only meaningful to fmt.Errorf; a logger needs %v.
			log.Warnf("error stopping allocation %s: %v", k, err)
		}
	}

	// clear the broadcast app score
	n.network.SetBroadcastAppScore(nil)

	// stop the actor
	if err := n.actor.Stop(); err != nil {
		return fmt.Errorf("failed to stop node actor: %w", err)
	}
	return nil
}
// sendReply builds a reply to msg carrying payload and sends it through the
// node actor. Replies to broadcast messages name the node actor as their
// source. Failures are only logged at debug level.
//
// The log messages used to mention a "peers list reply" although this helper
// serves every handler; they are now neutral.
func (n *Node) sendReply(msg actor.Envelope, payload interface{}) {
	var opt []actor.MessageOption
	if msg.IsBroadcast() {
		opt = append(opt, actor.WithMessageSource(n.actor.Handle()))
	}

	reply, err := actor.ReplyTo(msg, payload, opt...)
	if err != nil {
		log.Debugf("error creating reply: %s", err)
		return
	}
	if err := n.actor.Send(reply); err != nil {
		log.Debugf("error sending reply: %s", err)
	}
}
// createActor builds a basic actor whose handle combines the identity of the
// security context with the given host ID and inbox address.
func createActor(sctx *actor.BasicSecurityContext, limiter actor.RateLimiter, hostID, inboxAddress string, net network.Network, scheduler *bt.Scheduler) (*actor.BasicActor, error) {
	address := actor.Address{
		HostID:       hostID,
		InboxAddress: inboxAddress,
	}
	self := actor.Handle{
		ID:      sctx.ID(),
		DID:     sctx.DID(),
		Address: address,
	}

	created, err := actor.New(scheduler, net, sctx, limiter, actor.BasicActorParams{}, self)
	if err != nil {
		return nil, fmt.Errorf("failed to create actor: %w", err)
	}
	return created, nil
}
//go:build linux
// +build linux
package node
import (
"context"
"errors"
"fmt"
"gitlab.com/nunet/device-management-service/executor"
"gitlab.com/nunet/device-management-service/executor/firecracker"
"gitlab.com/nunet/device-management-service/executor/null"
)
// NewExecutor returns the executor used by the node on Linux: the firecracker
// executor, or the null executor when firecracker reports that it is not
// installed. Any other firecracker error, and a failure to set up the null
// executor, is returned wrapped.
func NewExecutor(ctx context.Context) (executor.Executor, error) {
	fc, err := firecracker.NewExecutor(ctx, "root")
	if err == nil {
		return fc, nil
	}
	if !errors.Is(err, firecracker.ErrNotInstalled) {
		return nil, fmt.Errorf("failed to create executor: %w", err)
	}

	// Firecracker is missing: fall back to the null executor.
	fallback, err := null.NewExecutor(ctx, "root")
	if err != nil {
		return nil, fmt.Errorf("failed to setup null executor: %w", err)
	}
	return fallback, nil
}
package node
import (
"github.com/libp2p/go-libp2p/core/peer"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/lib/did"
"gitlab.com/nunet/device-management-service/types"
)
// Names of the hello and status behaviors and of the topic on which the
// broadcast hello behavior is registered.
const (
PublicHelloBehavior = "/public/hello"
PublicStatusBehavior = "/public/status"
BroadcastHelloBehavior = "/broadcast/hello"
BroadcastHelloTopic = "/nunet/hello"
)
// HelloResponse is the reply to a hello message.
type HelloResponse struct {
// DID is the DID of the replying node's actor security context.
DID did.DID
}
// PublicStatusResponse is the reply to a public status message.
type PublicStatusResponse struct {
// Status is "OK" when the machine resources could be read and "ERROR" otherwise.
Status string
// Resources holds the machine resources; only set when Status is "OK".
Resources types.Resources
}
// publicHelloBehavior handles a direct hello message. It derives the sender's
// peer ID from the DID in the envelope, records that a hello was received from
// that peer, and then answers via handleHello. Messages whose DID cannot be
// mapped to a peer ID are dropped with a debug log.
func (n *Node) publicHelloBehavior(msg actor.Envelope) {
	pubk, err := did.PublicKeyFromDID(msg.From.DID)
	if err != nil {
		log.Debugf("failed to extract public key from DID: %s", err)
		return
	}

	p, err := peer.IDFromPublicKey(pubk)
	if err != nil {
		log.Debugf("failed to extract peer ID from public key: %s", err)
		return
	}

	n.mx.Lock()
	state, known := n.peers[p]
	switch {
	case known:
		state.helloIn = true
	case n.network.PeerConnected(p):
		// Race with the connected notification: the peer is connected but
		// has not been recorded yet, so create its state here.
		n.peers[p] = &peerState{helloIn: true}
	}
	n.mx.Unlock()

	n.handleHello(msg)
}
// broadcastHelloBehavior handles a broadcast hello message by delegating to
// handleHello; unlike publicHelloBehavior it does not touch the peer table.
func (n *Node) broadcastHelloBehavior(msg actor.Envelope) {
n.handleHello(msg)
}
// handleHello logs the sender of a hello message and replies with a
// HelloResponse carrying this node's DID. The envelope is discarded when the
// function returns.
func (n *Node) handleHello(msg actor.Envelope) {
	defer msg.Discard()

	log.Debugf("hello from %s", msg.From)

	n.sendReply(msg, HelloResponse{
		DID: n.actor.Security().DID(),
	})
}
// publicStatusBehavior replies to a status request with the machine's
// resources. When the resources cannot be read the reply carries the status
// "ERROR" and no resources; otherwise the status is "OK". The envelope is
// discarded when the function returns.
func (n *Node) publicStatusBehavior(msg actor.Envelope) {
	defer msg.Discard()

	reply := PublicStatusResponse{Status: "ERROR"}
	if specs, err := n.resourceManager.SystemSpecs().GetMachineResources(); err == nil {
		reply.Status = "OK"
		reply.Resources = specs.Resources
	}

	n.sendReply(msg, reply)
}
package onboarding
import (
"crypto/ecdsa"
"errors"
"github.com/ethereum/go-ethereum/common/hexutil"
"github.com/ethereum/go-ethereum/crypto"
"github.com/fivebinaries/go-cardano-serialization/address"
"github.com/fivebinaries/go-cardano-serialization/bip32"
"github.com/fivebinaries/go-cardano-serialization/network"
"github.com/tyler-smith/go-bip39"
"gitlab.com/nunet/device-management-service/types"
)
// GetEthereumAddressAndPrivateKey generates a fresh ECDSA key and returns its
// hex-encoded private key together with the hex address derived from the
// public key. It fails if key generation fails or the public key is not an
// *ecdsa.PublicKey.
func GetEthereumAddressAndPrivateKey() (*types.BlockchainAddressPrivKey, error) {
	key, err := crypto.GenerateKey()
	if err != nil {
		return nil, err
	}

	encodedKey := hexutil.Encode(crypto.FromECDSA(key))

	pub, isECDSA := key.Public().(*ecdsa.PublicKey)
	if !isECDSA {
		return nil, errors.New("publicKey is not of type *ecdsa.PublicKey")
	}

	return &types.BlockchainAddressPrivKey{
		Address:    crypto.PubkeyToAddress(*pub).Hex(),
		PrivateKey: encodedKey,
	}, nil
}
// harden adds 0x80000000 to num, producing the index used for hardened key
// derivation in GetCardanoAddressAndMnemonic. The addition wraps around on
// uint32 overflow.
func harden(num uint32) uint32 {
	const hardenedOffset uint32 = 0x80000000
	return num + hardenedOffset
}
// GetCardanoAddressAndMnemonic generates 256 bits of entropy, derives the
// matching BIP-39 mnemonic and a mainnet base address from it, and returns
// both. The account key is derived along the hardened path 1852'/1815'/0';
// the payment key hash comes from child 0/0 and the stake key hash from
// child 2/0.
//
// It returns an error if entropy or mnemonic generation fails, rather than
// continuing with empty key material.
func GetCardanoAddressAndMnemonic() (*types.BlockchainAddressPrivKey, error) {
	entropy, err := bip39.NewEntropy(256)
	if err != nil {
		return nil, err
	}
	mnemonic, err := bip39.NewMnemonic(entropy)
	if err != nil {
		return nil, err
	}

	var pair types.BlockchainAddressPrivKey
	pair.Mnemonic = mnemonic

	rootKey := bip32.FromBip39Entropy(
		entropy,
		[]byte{},
	)
	accountKey := rootKey.Derive(harden(1852)).Derive(harden(1815)).Derive(harden(0))

	utxoPubKey := accountKey.Derive(0).Derive(0).Public()
	utxoPubKeyHash := utxoPubKey.PublicKey().Hash()

	stakeKey := accountKey.Derive(2).Derive(0).Public()
	stakeKeyHash := stakeKey.PublicKey().Hash()

	addr := address.NewBaseAddress(
		network.MainNet(),
		&address.StakeCredential{
			Kind:    address.KeyStakeCredentialType,
			Payload: utxoPubKeyHash[:],
		},
		&address.StakeCredential{
			Kind:    address.KeyStakeCredentialType,
			Payload: stakeKeyHash[:],
		})
	pair.Address = addr.String()

	return &pair, nil
}
package onboarding
import (
"github.com/shirou/gopsutil/cpu"
"github.com/shirou/gopsutil/mem"
"gitlab.com/nunet/device-management-service/types"
)
// totalRAMInMB returns the total memory installed on the host machine, in
// MiB (bytes / 1024 / 1024). If the memory statistics cannot be read it
// returns 0 instead of dereferencing a nil result.
func totalRAMInMB() uint64 {
	v, err := mem.VirtualMemory()
	if err != nil || v == nil {
		// The signature has no error return, so report "unknown" as 0.
		return 0
	}
	return v.Total / 1024 / 1024
}
// totalCPUInMHz sums the Mhz value of every entry reported by cpu.Info. The
// error from cpu.Info is ignored, so a failed lookup yields 0.
func totalCPUInMHz() float64 {
	entries, _ := cpu.Info()

	var sum float64
	for _, entry := range entries {
		sum += entry.Mhz
	}
	return sum
}
// GetTotalProvisioned reports the host's provisioned capacity: summed CPU
// MHz, total memory in MiB, and the number of entries returned by cpu.Info.
// The error from cpu.Info is ignored, so a failed lookup yields zero cores.
func GetTotalProvisioned() *types.Provisioned {
	entries, _ := cpu.Info()

	return &types.Provisioned{
		CPU:      totalCPUInMHz(),
		Memory:   totalRAMInMB(),
		NumCores: uint64(len(entries)),
	}
}
package onboarding
import (
"github.com/uptrace/opentelemetry-go-extra/otelzap"
"gitlab.com/nunet/device-management-service/telemetry/logger"
)
// zlog is the package-level logger for the onboarding package.
var zlog *otelzap.Logger
// init creates the package logger under the name "onboarding".
func init() {
zlog = logger.OtelZapLogger("onboarding")
}
package onboarding
import (
"context"
"errors"
"fmt"
"os"
"slices"
"time"
"github.com/spf13/afero"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils"
)
// ErrMachineNotOnboarded is returned by ResourceConfig when the machine has no
// stored onboarding parameters.
var ErrMachineNotOnboarded = errors.New("machine is not onboarded")
// Config bundles the dependencies of Onboarding: the filesystem and working
// directory checked before onboarding, the repositories that persist
// onboarding state, the resource manager used to read machine specs, and the
// list of accepted channels.
type Config struct {
Fs afero.Afero
WorkDir string
DatabasePath string
ParamsRepo repositories.OnboardingParams
P2PRepo repositories.Libp2pInfo
ResourceManager types.ResourceManager
AvResourceRepo repositories.AvailableResources
UUIDRepo repositories.MachineUUID
Channels []string // supported channels such as nunet-test and nunet-team
}
// NewConfig assembles a Config from the given filesystem, paths, repositories
// and supported channels. ResourceManager is not a parameter here and is left
// at its zero value in the returned Config.
func NewConfig(
	fs afero.Afero,
	workDir, dbPath string,
	onboardingRepo repositories.OnboardingParams,
	p2pRepo repositories.Libp2pInfo,
	avResourceRepo repositories.AvailableResources,
	uuidRepo repositories.MachineUUID,
	channels []string,
) *Config {
	cfg := new(Config)
	cfg.Fs = fs
	cfg.WorkDir = workDir
	cfg.DatabasePath = dbPath
	cfg.ParamsRepo = onboardingRepo
	cfg.P2PRepo = p2pRepo
	cfg.AvResourceRepo = avResourceRepo
	cfg.UUIDRepo = uuidRepo
	cfg.Channels = channels
	return cfg
}
// Onboarding is the receiver for the onboarding operations. It embeds Config,
// so all configured dependencies are available as fields on the value.
type Onboarding struct {
Config
}
// New returns an Onboarding holding a copy of the given Config; config must
// be non-nil because it is dereferenced.
func New(config *Config) *Onboarding {
	onboarding := new(Onboarding)
	onboarding.Config = *config
	return onboarding
}
// IsOnboarded reports whether onboarding parameters can be read from the
// params repository. Any read error is treated as "not onboarded", so the
// returned error is always nil.
func (o *Onboarding) IsOnboarded(ctx context.Context) (bool, error) {
	if _, err := o.ParamsRepo.Get(ctx); err != nil {
		return false, nil
	}

	// TODO: validate onboarding params
	return true, nil
}
// Info returns the onboarding configuration stored in the params repository,
// or the repository's error when it cannot be read.
func (o *Onboarding) Info(ctx context.Context) (*types.OnboardingConfig, error) {
	stored, err := o.ParamsRepo.Get(ctx)
	if err == nil {
		return &stored, nil
	}
	return nil, err
}
// Onboard validates the onboarding prerequisites and capacity, then persists
// the onboarding configuration, the available resources, and the libp2p info.
//
// The stored configuration combines the host name, the machine's total RAM,
// CPU and GPU information, and the requested capacity, channel, payment
// address and price. All repository writes use the caller's ctx, so they
// honor its cancellation and deadline.
//
// It returns the configuration as saved by the params repository, or a
// wrapped error from the first step that fails. Writes made before a failing
// step are not rolled back.
func (o *Onboarding) Onboard(ctx context.Context, capacity types.CapacityForNunet) (*types.OnboardingConfig, error) {
	if err := o.validateOnboardingPrerequisites(capacity); err != nil {
		return nil, err
	}

	hostname, err := os.Hostname()
	if err != nil {
		return nil, fmt.Errorf("unable to get hostname: %w", err)
	}

	machineResources, err := o.ResourceManager.SystemSpecs().GetMachineResources()
	if err != nil {
		return nil, fmt.Errorf("cannot get provisioned resources: %w", err)
	}

	var oConf types.OnboardingConfig
	oConf.Name = hostname
	oConf.UpdateTimestamp = time.Now().Unix()
	oConf.TotalResources.RAM = machineResources.RAM
	oConf.TotalResources.CPU = machineResources.CPU
	oConf.GpuInfo = machineResources.GPUs
	oConf.OnboardedResources.RAM = types.RAM{Size: capacity.Memory}
	oConf.OnboardedResources.CPU = types.CPU{Cores: float32(capacity.CPU)}
	oConf.Network = capacity.Channel
	oConf.PublicKey = capacity.PaymentAddress
	oConf.NTXPricePerMinute = capacity.NTXPricePerMinute

	savedConfig, err := o.ParamsRepo.Save(ctx, oConf)
	if err != nil {
		return nil, fmt.Errorf("could not save onboarding params: %w", err)
	}

	// TODO: call the resource manager directly instead
	if err := o.updateAvailableResources(ctx, capacity); err != nil {
		return nil, fmt.Errorf("failed to update available resources: %w", err)
	}

	_, err = o.P2PRepo.Save(ctx, types.Libp2pInfo{
		ServerMode: capacity.ServerMode,
		Available:  capacity.IsAvailable,
	})
	if err != nil {
		return nil, fmt.Errorf("unable to save libp2pInfo: %w", err)
	}

	return &savedConfig, nil
}
// ResourceConfig allows changing onboarding parameters
func (o *Onboarding) ResourceConfig(ctx context.Context, capacity types.CapacityForNunet) (*types.OnboardingConfig, error) {
onboarded, err := o.IsOnboarded(ctx)
if err != nil {
return nil, fmt.Errorf("could not check onboard status: %w", err)
}
if !onboarded {
return nil, ErrMachineNotOnboarded
}
if err := o.validateCapacityForNunet(capacity); err != nil {
return nil, fmt.Errorf("could not validate capacity data: %w", err)
}
params, err := o.ParamsRepo.Get(ctx)
if err != nil {
return nil, fmt.Errorf("could not read onboarding params from db: %w", err)
}
params.OnboardedResources.CPU = types.CPU{Cores: float32(capacity.CPU)}
params.OnboardedResources.RAM = types.RAM{Size: capacity.Memory}
params.NTXPricePerMinute = capacity.NTXPricePerMinute
available, err := o.AvResourceRepo.Get(ctx)
if err != nil {
return nil, fmt.Errorf("could not get available resources info: %w", err)
}
available.TotCPUHz = capacity.CPU
available.RAM = capacity.Memory
available.NTXPricePerMinute = capacity.NTXPricePerMinute
if _, err := o.AvResourceRepo.Save(ctx, available); err != nil {
return nil, fmt.Errorf("could not save available resources info: %w", err)
}
if _, err := o.ParamsRepo.Save(ctx, params); err != nil {
return nil, fmt.Errorf("could not save onboarding params in db: %w", err)
}
// TODO: change the way the resources are being onboarded
// _, err = o.ResourceManager.UpdateFreeResources(ctx)
// if err != nil {
// return nil, fmt.Errorf("could not calculate free resources and update database: %w", err)
// }
return ¶ms, nil
}
// Offboard deletes the stored onboarding parameters and the available
// resources record.
//
// When force is false, the first failing step aborts offboarding and its
// error is returned wrapped. When force is true, failures are logged and
// offboarding continues. In both modes an error is returned if the machine is
// not onboarded.
func (o *Onboarding) Offboard(ctx context.Context, force bool) error {
	onboarded, err := o.IsOnboarded(ctx)
	if err != nil {
		if !force {
			return fmt.Errorf("could not retrieve onboard status: %w", err)
		}
		// %w is only meaningful to fmt.Errorf; log calls must use %v.
		zlog.Sugar().Errorf("problem with onboarding state: %v", err)
		zlog.Info("continuing with offboarding because forced")
	}

	if !onboarded {
		return fmt.Errorf("machine is not onboarded")
	}

	// TODO: shutdown routine to stop networking etc... here

	if err := o.ParamsRepo.Clear(ctx); err != nil {
		if !force {
			return fmt.Errorf("failed to remove onboarding params from db: %w", err)
		}
		zlog.Sugar().Errorf("failed to delete onboarding params from db - problem with onboarding state: %v", err)
		zlog.Info("continuing with offboarding because forced")
	}

	// delete the available resources from database
	if err := o.AvResourceRepo.Clear(ctx); err != nil {
		if !force {
			return fmt.Errorf("failed to remove reserved resource from db: %w", err)
		}
		zlog.Sugar().Errorf("failed to delete reserved resource from db - problem with onboarding state: %v", err)
		zlog.Info("continuing with offboarding because forced")
	}

	return nil
}
// validateCapacityForNunet checks that the requested CPU and memory each lie
// between one tenth and nine tenths of the machine's reported CPU compute and
// RAM size. It returns an error naming the allowed range when a value falls
// outside it, or a wrapped error when the machine resources cannot be read.
func (o *Onboarding) validateCapacityForNunet(capacity types.CapacityForNunet) error {
	specs, err := o.ResourceManager.SystemSpecs().GetMachineResources()
	if err != nil {
		return fmt.Errorf("could not get provisioned resources: %w", err)
	}

	cpuMin, cpuMax := int64(specs.CPU.Compute/10), int64(specs.CPU.Compute*9/10)
	if capacity.CPU < cpuMin || capacity.CPU > cpuMax {
		return fmt.Errorf("CPU should be between 10%% and 90%% of the available CPU (%d and %d)", cpuMin, cpuMax)
	}

	ramMin, ramMax := specs.RAM.Size/10, specs.RAM.Size*9/10
	//nolint:gosec // to be fixed in TODO: 553
	if capacity.Memory < ramMin || capacity.Memory > ramMax {
		return fmt.Errorf("memory should be between 10%% and 90%% of the available memory (%d and %d)", int64(ramMin), int64(ramMax))
	}

	return nil
}
// validateOnboardingPrerequisites runs the checks required before onboarding,
// in order: the working directory exists, the payment address is valid, the
// capacity is within range, and the channel is one of the configured
// channels. The first failing check determines the returned error.
func (o *Onboarding) validateOnboardingPrerequisites(capacity types.CapacityForNunet) error {
	exists, err := o.Fs.DirExists(o.WorkDir)
	switch {
	case err != nil:
		return fmt.Errorf("could not check if config directory exists: %w", err)
	case !exists:
		return fmt.Errorf("working directory does not exist")
	}

	if addrErr := utils.ValidateAddress(capacity.PaymentAddress); addrErr != nil {
		return fmt.Errorf("could not validate payment address: %w", addrErr)
	}

	if capErr := o.validateCapacityForNunet(capacity); capErr != nil {
		return fmt.Errorf("could not validate capacity data: %w", capErr)
	}

	if slices.Contains(o.Channels, capacity.Channel) {
		return nil
	}
	return fmt.Errorf("invalid channel data: '%s' channel does not exist", capacity.Channel)
}
// updateAvailableResources builds an AvailableResources record from the
// requested capacity and the machine's CPU specs and saves it to the
// available-resources repository. Vcpu is the requested CPU divided by the
// machine's clock speed, truncated to an int. Prices and disk are stored as 0.
func (o *Onboarding) updateAvailableResources(ctx context.Context, capacity types.CapacityForNunet) error {
	specs, err := o.ResourceManager.SystemSpecs().GetMachineResources()
	if err != nil {
		return fmt.Errorf("could not get provisioned resources: %w", err)
	}

	vcpus := int(float64(capacity.CPU) / specs.CPU.ClockSpeed)

	record := types.AvailableResources{
		TotCPUHz:          capacity.CPU,
		CPUNo:             int(specs.CPU.Cores),
		CPUHz:             specs.CPU.ClockSpeed,
		PriceCPU:          0, // TODO: Get price of CPU
		RAM:               capacity.Memory,
		PriceRAM:          0, // TODO: Get price of RAM
		Vcpu:              vcpus,
		Disk:              0,
		PriceDisk:         0,
		NTXPricePerMinute: capacity.NTXPricePerMinute,
	}

	if _, err := o.AvResourceRepo.Save(ctx, record); err != nil {
		return fmt.Errorf("failed to save available resources: %w", err)
	}

	// if _, err := o.ResourceManager.UpdateFreeResources(ctx); err != nil {
	// zlog.Sugar().Errorf("could not calculate free resources and update database: %w", err)
	// }

	return nil
}
// CreatePaymentAddress generates a new address for the given wallet type.
// Supported values are "ethereum" and "cardano"; any other value yields an
// "invalid wallet" error. Generator failures are wrapped with the wallet name.
// TODO: This should be moved to utils-related package. It's a utility function independent of onboarding
func CreatePaymentAddress(wallet string) (*types.BlockchainAddressPrivKey, error) {
	var generate func() (*types.BlockchainAddressPrivKey, error)

	switch wallet {
	case "ethereum":
		generate = GetEthereumAddressAndPrivateKey
	case "cardano":
		generate = GetCardanoAddressAndMnemonic
	default:
		return nil, fmt.Errorf("invalid wallet")
	}

	generated, err := generate()
	if err != nil {
		return nil, fmt.Errorf("could not generate %s address: %w", wallet, err)
	}
	return generated, nil
}
package resources
import (
"github.com/uptrace/opentelemetry-go-extra/otelzap"
"gitlab.com/nunet/device-management-service/db"
gormRepo "gitlab.com/nunet/device-management-service/db/repositories/gorm"
"gitlab.com/nunet/device-management-service/telemetry/logger"
"gitlab.com/nunet/device-management-service/types"
)
// Package-level state of the resources package; both variables are assigned
// in init.
var (
// zlog is the logger for the resources package
zlog *otelzap.Logger
// ManagerInstance is the ResourceManager instance
ManagerInstance types.ResourceManager
)
// TODO: This needs to be initialized in `dms` package and removed from here
// https://gitlab.com/nunet/device-management-service/-/issues/536
// it is being initialized in `dms` package now but there is usage in executor
// in executor/docker/executor.go:262:25 in function newDockerExecutionContainer
// which heavily depends on this var and any attempt to fix it will involve
// too many changes. Once that code moves to allocations, this can be removed.
//
// init creates the package logger and builds ManagerInstance from gorm-backed
// repositories over the global db.DB handle.
func init() {
	zlog = logger.OtelZapLogger("resources")

	ManagerInstance = NewResourceManager(ManagerRepos{
		FreeResources:      gormRepo.NewFreeResources(db.DB),
		OnboardedResources: gormRepo.NewOnboardedResources(db.DB),
		ResourceAllocation: gormRepo.NewResourceAllocation(db.DB),
	})
}
package resources
import (
"context"
"fmt"
"sync"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// gpuMetadata holds the metadata of a single GPU; currently only its PCI
// address, which the vendor-specific GPU info functions copy into types.GPU.
type gpuMetadata struct {
PCIAddress string
}
// ManagerRepos holds all the repositories needed for resource management:
// persisted free resources, onboarded resources, and per-job allocations.
type ManagerRepos struct {
FreeResources repositories.FreeResources
OnboardedResources repositories.OnboardedResources
ResourceAllocation repositories.ResourceAllocation
}
// DefaultManager implements the ResourceManager interface. It keeps an
// in-memory store in front of the repositories and exposes the system specs
// and usage monitor created in NewResourceManager.
// TODO: Add telemetry for the methods https://gitlab.com/nunet/device-management-service/-/issues/535
type DefaultManager struct {
usageMonitor types.UsageMonitor
systemSpecs types.SystemSpecs
repos ManagerRepos
store *store
// allocationLock is used to synchronize access to the allocation pool during allocation and deallocation
// it ensures that resource allocation and deallocation are atomic operations
allocationLock sync.RWMutex
}
// NewResourceManager returns a DefaultManager backed by the given
// repositories. A fresh in-memory store is created and shared between the
// manager and its system specs; a new usage monitor is attached as well.
func NewResourceManager(repos ManagerRepos) *DefaultManager {
	shared := newStore()
	specs := newSystemSpecs(shared)

	manager := &DefaultManager{
		usageMonitor: newUsageMonitor(),
		systemSpecs:  specs,
		repos:        repos,
		store:        shared,
	}
	return manager
}
// Compile-time check that *DefaultManager satisfies types.ResourceManager.
var _ types.ResourceManager = (*DefaultManager)(nil)
// AllocateResources reserves the resources requested by allocation for its
// job. It holds allocationLock for the whole operation, rejects a job that
// already has an allocation in the store, subtracts the request from the
// current free resources, and then persists the new free resources followed
// by the allocation itself.
func (d *DefaultManager) AllocateResources(ctx context.Context, allocation types.ResourceAllocation) error {
	d.allocationLock.Lock()
	defer d.allocationLock.Unlock()

	// Refuse to allocate twice for the same job.
	alreadyAllocated := false
	d.store.withAllocationsRLock(func() {
		_, alreadyAllocated = d.store.allocations[allocation.JobID]
	})
	if alreadyAllocated {
		return fmt.Errorf("resources already allocated for job %s", allocation.JobID)
	}

	// Start from what is currently free and take the request out of it.
	remaining, err := d.GetFreeResources(ctx)
	if err != nil {
		return fmt.Errorf("getting free resources: %w", err)
	}
	if err = remaining.Subtract(allocation.Resources); err != nil {
		return fmt.Errorf("subtracting resources: %w", err)
	}

	// Known limitation: the free-resources write and the allocation write are
	// not transactional. If storing the allocation fails, the free resources
	// have already been updated; fixing that needs a transaction manager that
	// the current db layer does not provide.
	if err = d.updateFreeResources(ctx, remaining); err != nil {
		return fmt.Errorf("updating free resources in db: %w", err)
	}
	if err = d.storeAllocation(ctx, allocation); err != nil {
		return fmt.Errorf("storing allocations in db: %w", err)
	}
	return nil
}
// DeallocateResources releases the resources held by the given job. It holds
// allocationLock for the whole operation, fails if the store has no
// allocation for the job, adds the job's resources back to the free
// resources, and then persists the new free resources followed by the removal
// of the allocation.
func (d *DefaultManager) DeallocateResources(ctx context.Context, jobID string) error {
	d.allocationLock.Lock()
	defer d.allocationLock.Unlock()

	// Look up what the job currently holds.
	var held types.ResourceAllocation
	var found bool
	d.store.withAllocationsRLock(func() {
		held, found = d.store.allocations[jobID]
	})
	if !found {
		return fmt.Errorf("resources not allocated for job %s", jobID)
	}

	available, err := d.GetFreeResources(ctx)
	if err != nil {
		return fmt.Errorf("getting free resources: %w", err)
	}

	// Known limitation: the free-resources write and the allocation delete
	// are not transactional. If the delete fails, the free resources have
	// already been updated; fixing that needs a transaction manager that the
	// current db layer does not provide.
	if err = available.Add(held.Resources); err != nil {
		return fmt.Errorf("adding resources: %w", err)
	}
	if err = d.updateFreeResources(ctx, available); err != nil {
		return fmt.Errorf("updating free resources in db: %w", err)
	}
	if err = d.deleteAllocation(ctx, jobID); err != nil {
		return fmt.Errorf("deleting allocations from db: %w", err)
	}
	return nil
}
// GetFreeResources returns the free resources. A value cached in the store is
// returned as a copy; otherwise the value is loaded from the repository,
// cached in the store, and returned.
func (d *DefaultManager) GetFreeResources(ctx context.Context) (types.FreeResources, error) {
	var (
		snapshot types.FreeResources
		hit      bool
	)
	d.store.withFreeRLock(func() {
		if cached := d.store.freeResources; cached != nil {
			snapshot, hit = *cached, true
		}
	})
	if hit {
		return snapshot, nil
	}

	// Cache miss: read from the repository and remember the result.
	loaded, err := d.repos.FreeResources.Get(ctx)
	if err != nil {
		return types.FreeResources{}, fmt.Errorf("failed to get free resources: %w", err)
	}
	d.store.withFreeLock(func() {
		d.store.freeResources = &loaded
	})
	return loaded, nil
}
// GetTotalAllocation returns the sum of the resources of all allocations in
// the store. When the store holds no allocations they are first loaded from
// the database.
//
// The emptiness check is performed under the allocations read lock so it does
// not race with concurrent writers of the allocations map. The lock is
// released before loading from the database, because that path takes the
// write lock.
//
// If adding an allocation fails, summation stops and the partial total is
// returned together with the error.
func (d *DefaultManager) GetTotalAllocation() (types.Resources, error) {
	var empty bool
	d.store.withAllocationsRLock(func() {
		empty = len(d.store.allocations) == 0
	})
	if empty {
		if err := d.getAllocationsFromDB(context.Background()); err != nil {
			return types.Resources{}, fmt.Errorf("getting allocations from db: %w", err)
		}
	}

	var (
		totalAllocation types.Resources
		err             error
	)
	d.store.withAllocationsRLock(func() {
		for _, allocation := range d.store.allocations {
			err = totalAllocation.Add(allocation.Resources)
			if err != nil {
				break
			}
		}
	})
	return totalAllocation, err
}
// GetOnboardedResources returns the onboarded resources of the machine. A
// value cached in the store is returned as a copy; otherwise the value is
// loaded from the repository, cached in the store, and returned.
func (d *DefaultManager) GetOnboardedResources(ctx context.Context) (types.OnboardedResources, error) {
	var (
		snapshot types.OnboardedResources
		hit      bool
	)
	d.store.withOnboardedRLock(func() {
		if cached := d.store.onboardedResources; cached != nil {
			snapshot, hit = *cached, true
		}
	})
	if hit {
		return snapshot, nil
	}

	// Cache miss: read from the repository and remember the result.
	loaded, err := d.repos.OnboardedResources.Get(ctx)
	if err != nil {
		return types.OnboardedResources{}, fmt.Errorf("failed to get onboarded resources: %w", err)
	}
	// The closure never fails, so the returned error is always nil.
	_ = d.store.withOnboardedLock(func() error {
		d.store.onboardedResources = &loaded
		return nil
	})
	return loaded, nil
}
// UpdateOnboardedResources saves new onboarded resources and recomputes the
// free resources, all while holding the onboarded write lock. The total of
// the current allocations is subtracted from resources.Resources; the
// resulting value is saved as the onboarded resources, cached in the store,
// and also written as the free resources.
//
// NOTE(review): Subtract is applied to resources.Resources before the
// onboarded resources are saved, so if Subtract mutates its receiver the
// persisted onboarded value is already reduced by the allocations — confirm
// this is intended.
func (d *DefaultManager) UpdateOnboardedResources(ctx context.Context, resources types.OnboardedResources) error {
	return d.store.withOnboardedLock(func() error {
		// The new free resources are the onboarded ones minus what is allocated.
		allocated, err := d.GetTotalAllocation()
		if err != nil {
			return fmt.Errorf("getting total allocations: %w", err)
		}
		if subErr := resources.Resources.Subtract(allocated); subErr != nil {
			return fmt.Errorf("couldn't subtract allocation: %w. Demand too high", subErr)
		}

		// Known limitation: the onboarded write and the free-resources write
		// are not transactional. If the second fails, the first has already
		// been persisted; fixing that needs a transaction manager that the
		// current db layer does not provide.
		if _, saveErr := d.repos.OnboardedResources.Save(ctx, resources); saveErr != nil {
			return fmt.Errorf("failed to update onboarded resources: %w", saveErr)
		}
		d.store.onboardedResources = &resources

		free := types.FreeResources{Resources: resources.Resources}
		if freeErr := d.updateFreeResources(ctx, free); freeErr != nil {
			return fmt.Errorf("updating free resources in db: %w", freeErr)
		}
		return nil
	})
}
// SystemSpecs returns the SystemSpecs instance created in NewResourceManager.
func (d *DefaultManager) SystemSpecs() types.SystemSpecs {
return d.systemSpecs
}
// UsageMonitor returns the UsageMonitor instance created in NewResourceManager.
func (d *DefaultManager) UsageMonitor() types.UsageMonitor {
return d.usageMonitor
}
// updateFreeResources saves freeResources to the repository and, once that
// succeeds, replaces the cached value in the store under the free write lock.
func (d *DefaultManager) updateFreeResources(ctx context.Context, freeResources types.FreeResources) error {
	if _, err := d.repos.FreeResources.Save(ctx, freeResources); err != nil {
		return fmt.Errorf("updating free resources: %w", err)
	}

	// Only refresh the cache after the database write succeeded.
	d.store.withFreeLock(func() {
		d.store.freeResources = &freeResources
	})
	return nil
}
// getAllocationsFromDB loads every allocation from the repository and stores
// each one in the in-memory allocations map, keyed by job ID, under the
// allocations write lock.
func (d *DefaultManager) getAllocationsFromDB(ctx context.Context) error {
	query := d.repos.ResourceAllocation.GetQuery()
	loaded, err := d.repos.ResourceAllocation.FindAll(ctx, query)
	if err != nil {
		return fmt.Errorf("getting allocations from db: %w", err)
	}

	d.store.withAllocationsLock(func() {
		for i := range loaded {
			d.store.allocations[loaded[i].JobID] = loaded[i]
		}
	})
	return nil
}
// storeAllocation creates the allocation in the repository and, once that
// succeeds, records it in the in-memory allocations map under its job ID.
func (d *DefaultManager) storeAllocation(ctx context.Context, allocation types.ResourceAllocation) error {
	if _, err := d.repos.ResourceAllocation.Create(ctx, allocation); err != nil {
		return fmt.Errorf("storing allocations in db: %w", err)
	}

	jobID := allocation.JobID
	d.store.withAllocationsLock(func() {
		d.store.allocations[jobID] = allocation
	})
	return nil
}
// deleteAllocation removes the allocation of the given job from the
// repository and from the in-memory store. It looks the allocation up by
// JobID, deletes it by its ID, and only then drops the map entry. Both
// repository calls use the caller's ctx, so they honor its cancellation and
// deadline.
func (d *DefaultManager) deleteAllocation(ctx context.Context, jobID string) error {
	query := d.repos.ResourceAllocation.GetQuery()
	query.Conditions = append(query.Conditions, repositories.EQ("JobID", jobID))

	allocation, err := d.repos.ResourceAllocation.Find(ctx, query)
	if err != nil {
		return fmt.Errorf("finding allocations in db: %w", err)
	}

	if err := d.repos.ResourceAllocation.Delete(ctx, allocation.ID); err != nil {
		return fmt.Errorf("deleting allocations from db: %w", err)
	}

	d.store.withAllocationsLock(func() {
		delete(d.store.allocations, jobID)
	})
	return nil
}
package resources
import (
"sync"
"gitlab.com/nunet/device-management-service/types"
)
// locks holds the locks for the resource manager
// allocations: lock for the allocations map; the store's GPU-metadata and
// machine-resources helpers also use this lock
// onboarded: lock for the onboarded resources
// free: lock for the free resources
type locks struct {
allocations sync.RWMutex
onboarded sync.RWMutex
free sync.RWMutex
}
// newLocks returns a pointer to a zero-valued locks struct; all mutexes start
// unlocked.
func newLocks() *locks {
	return new(locks)
}
// store holds the resources of the machine
// onboardedResources: resources that are onboarded to the machine (nil until cached)
// freeResources: resources that are free to be allocated (nil until cached)
// allocations: resources that are requested by the jobs, keyed by job ID
// gpuMetadata: per-vendor GPU metadata
// machineResources: machine resources (nil until set)
// locks: the mutexes guarding the fields above; access goes through the with*Lock helpers
type store struct {
onboardedResources *types.OnboardedResources
freeResources *types.FreeResources
allocations map[string]types.ResourceAllocation
gpuMetadata map[types.GPUVendor][]gpuMetadata
machineResources *types.MachineResources
locks *locks
}
// newStore returns a store with empty allocations and GPU-metadata maps and a
// fresh set of locks; the cached resource pointers start out nil.
func newStore() *store {
	s := new(store)
	s.allocations = make(map[string]types.ResourceAllocation)
	s.gpuMetadata = make(map[types.GPUVendor][]gpuMetadata)
	s.locks = newLocks()
	return s
}
// withAllocationsLock runs fn while holding the allocations write lock. The
// lock is released when fn returns.
func (s *store) withAllocationsLock(fn func()) {
s.locks.allocations.Lock()
defer s.locks.allocations.Unlock()
fn()
}
// withOnboardedLock runs fn while holding the onboarded write lock and
// returns fn's error. The lock is released when fn returns.
func (s *store) withOnboardedLock(fn func() error) error {
s.locks.onboarded.Lock()
defer s.locks.onboarded.Unlock()
return fn()
}
// withFreeLock runs fn while holding the free write lock. The lock is
// released when fn returns.
func (s *store) withFreeLock(fn func()) {
s.locks.free.Lock()
defer s.locks.free.Unlock()
fn()
}
// withGpuMetadataLock runs fn while holding the allocations lock in read
// mode; there is no dedicated GPU-metadata lock.
// NOTE(review): only a read lock is taken, so if fn writes gpuMetadata it is
// not excluded from other read-lock holders — confirm the intended usage.
func (s *store) withGpuMetadataLock(fn func()) {
s.locks.allocations.RLock()
defer s.locks.allocations.RUnlock()
fn()
}
// withMachineResourcesLock runs fn while holding the allocations write lock;
// there is no dedicated machine-resources lock. The lock is released when fn
// returns.
func (s *store) withMachineResourcesLock(fn func()) {
s.locks.allocations.Lock()
defer s.locks.allocations.Unlock()
fn()
}
// withAllocationsRLock runs fn while holding the allocations read lock. It
// returns nothing; fn must pass results out through captured variables.
func (s *store) withAllocationsRLock(fn func()) {
s.locks.allocations.RLock()
defer s.locks.allocations.RUnlock()
fn()
}
// withOnboardedRLock runs fn while holding the onboarded read lock. It
// returns nothing; fn must pass results out through captured variables.
func (s *store) withOnboardedRLock(fn func()) {
s.locks.onboarded.RLock()
defer s.locks.onboarded.RUnlock()
fn()
}
// withFreeRLock runs fn while holding the free read lock. It returns nothing;
// fn must pass results out through captured variables.
func (s *store) withFreeRLock(fn func()) {
s.locks.free.RLock()
defer s.locks.free.RUnlock()
fn()
}
// withGpuMetadataRLock takes the allocations read lock, executes the function and returns its result.
// This function is commented out because it is not used yet, but it will be used in the future.
// func (s *store) withGpuMetadataRLock(fn func() map[types.GPUVendor][]gpuMetadata) map[types.GPUVendor][]gpuMetadata {
// s.locks.allocations.RLock()
// defer s.locks.allocations.RUnlock()
// return fn()
// }
// withMachineResourcesRLock runs fn while holding the allocations read lock;
// there is no dedicated machine-resources lock. The lock is released when fn
// returns.
func (s *store) withMachineResourcesRLock(fn func()) {
s.locks.allocations.RLock()
defer s.locks.allocations.RUnlock()
fn()
}
//go:build linux && amd64
package resources
import (
"context"
"errors"
"fmt"
"os/exec"
"regexp"
"strconv"
"strings"
"github.com/NVIDIA/go-nvml/pkg/nvml"
"github.com/jaypipes/ghw"
"github.com/shirou/gopsutil/v4/cpu"
"github.com/shirou/gopsutil/v4/disk"
"github.com/shirou/gopsutil/v4/mem"
"gitlab.com/nunet/device-management-service/types"
)
// linuxSystemSpecs implements the SystemSpecs interface for Linux systems.
// It keeps a reference to the resource manager's in-memory store.
type linuxSystemSpecs struct {
store *store
}
// newSystemSpecs returns a linuxSystemSpecs that uses the given store.
func newSystemSpecs(store *store) *linuxSystemSpecs {
	specs := new(linuxSystemSpecs)
	specs.store = store
	return specs
}
// Compile-time check that *linuxSystemSpecs satisfies types.SystemSpecs.
var _ types.SystemSpecs = (*linuxSystemSpecs)(nil)
// getRAM returns the types.RAM information for the system; Size is the total
// memory reported by mem.VirtualMemory. The underlying error is wrapped so
// callers can inspect it with errors.Is / errors.As.
func getRAM() (types.RAM, error) {
	v, err := mem.VirtualMemory()
	if err != nil {
		return types.RAM{}, fmt.Errorf("failed to get total memory: %w", err)
	}

	return types.RAM{
		Size: v.Total,
	}, nil
}
// getDisk returns the types.Disk for the system. Size is the sum of the total
// size of every partition returned by disk.PartitionsWithContext (called with
// all=false). A failure to list partitions or to read the usage of any
// mountpoint aborts with a wrapped error.
func getDisk() (types.Disk, error) {
	ctx := context.Background()

	parts, err := disk.PartitionsWithContext(ctx, false)
	if err != nil {
		return types.Disk{}, fmt.Errorf("failed to get partitions: %w", err)
	}

	var total uint64
	for _, part := range parts {
		stat, usageErr := disk.UsageWithContext(ctx, part.Mountpoint)
		if usageErr != nil {
			return types.Disk{}, fmt.Errorf("failed to get disk usage: %w", usageErr)
		}
		total += stat.Total
	}

	return types.Disk{Size: total}, nil
}
// getCPU returns the CPU information for the system. Compute is the sum of
// the Mhz value of every entry reported by cpu.Info, Cores is the number of
// entries, and ClockSpeed is the first entry's Mhz multiplied by 1,000,000.
//
// It returns an error when cpu.Info fails or reports no entries, instead of
// panicking on the first-entry lookup.
func getCPU() (types.CPU, error) {
	cores, err := cpu.Info()
	if err != nil {
		return types.CPU{}, fmt.Errorf("failed to get CPU info: %w", err)
	}
	if len(cores) == 0 {
		return types.CPU{}, errors.New("failed to get CPU info: no CPU entries reported")
	}

	var totalCompute float64
	for i := range cores {
		totalCompute += cores[i].Mhz
	}

	return types.CPU{
		Compute:    totalCompute,
		Cores:      float32(len(cores)),
		ClockSpeed: cores[0].Mhz * 1000000,
	}, nil
}
// TODO: move the following functions to the `gpu` sub-package
// https://gitlab.com/nunet/device-management-service/-/issues/546
// assignIndexToGPUs sets the Index field of each GPU to its position in the
// slice, starting from 0. The slice is modified in place and returned.
func assignIndexToGPUs(gpus []types.GPU) []types.GPU {
	for pos := 0; pos < len(gpus); pos++ {
		entry := &gpus[pos]
		entry.Index = pos
	}
	return gpus
}
// getAMDGPUInfo returns the GPU information for AMD GPUs by running rocm-smi
// and parsing the card series, total VRAM and used VRAM of each GPU from its
// output. VRAM values are converted from bytes to MiB. The i-th parsed GPU
// takes its PCI address from metadata[i].
//
// It returns an error when rocm-smi fails, when the output lacks or has
// inconsistent GPU entries, when a VRAM value cannot be parsed, or when
// metadata has fewer entries than GPUs were parsed (which would otherwise
// index out of range).
func getAMDGPUInfo(metadata []gpuMetadata) ([]types.GPU, error) {
	cmd := exec.Command("rocm-smi", "--showid", "--showproductname", "--showmeminfo", "vram")
	output, err := cmd.CombinedOutput()
	if err != nil {
		return nil, fmt.Errorf("AMD ROCm not installed, initialized, or configured (reboot recommended for newly installed AMD GPU Drivers): %s", err)
	}

	outputStr := string(output)

	gpuNameRegex := regexp.MustCompile(`GPU\[\d+\]\s+: Card Series:\s+([^\n]+)`)
	totalRegex := regexp.MustCompile(`GPU\[\d+\]\s+: VRAM Total Memory \(B\):\s+(\d+)`)
	usedRegex := regexp.MustCompile(`GPU\[\d+\]\s+: VRAM Total Used Memory \(B\):\s+(\d+)`)

	gpuNameMatches := gpuNameRegex.FindAllStringSubmatch(outputStr, -1)
	totalMatches := totalRegex.FindAllStringSubmatch(outputStr, -1)
	usedMatches := usedRegex.FindAllStringSubmatch(outputStr, -1)

	if len(gpuNameMatches) == 0 || len(totalMatches) == 0 || len(usedMatches) == 0 {
		return nil, fmt.Errorf("failed to find AMD GPU information or vram information in the output")
	}

	if len(gpuNameMatches) != len(totalMatches) || len(totalMatches) != len(usedMatches) {
		return nil, fmt.Errorf("inconsistent AMD GPU information detected")
	}

	// Every parsed GPU needs a metadata entry for its PCI address.
	if len(metadata) < len(gpuNameMatches) {
		return nil, fmt.Errorf("failed to find AMD GPU information for all GPUs")
	}

	gpuInfos := make([]types.GPU, 0, len(gpuNameMatches))
	for i := range gpuNameMatches {
		gpuName := gpuNameMatches[i][1]

		totalMemoryBytes, err := strconv.ParseInt(totalMatches[i][1], 10, 64)
		if err != nil {
			return nil, fmt.Errorf("failed to parse total amdgpu vram: %s", err)
		}

		usedMemoryBytes, err := strconv.ParseInt(usedMatches[i][1], 10, 64)
		if err != nil {
			return nil, fmt.Errorf("failed to parse used amdgpu vram: %s", err)
		}

		totalMemoryMiB := totalMemoryBytes / 1024 / 1024
		usedMemoryMiB := usedMemoryBytes / 1024 / 1024
		freeMemoryMiB := totalMemoryMiB - usedMemoryMiB

		gpuInfos = append(gpuInfos, types.GPU{
			PCIAddress: metadata[i].PCIAddress,
			Model:      gpuName,
			TotalVRAM:  uint64(totalMemoryMiB),
			UsedVRAM:   uint64(usedMemoryMiB),
			FreeVRAM:   uint64(freeMemoryMiB),
			Vendor:     types.GPUVendorAMDATI,
		})
	}

	return gpuInfos, nil
}
// getNVIDIAGPUInfo collects name and VRAM figures for every NVIDIA device
// reported by NVML and returns one types.GPU per device.
//
// NVML is initialised at the start of the call and shut down when it returns.
// The NVML device count must equal len(metadata); the PCI address for device
// i is taken from metadata[i]. VRAM values are divided by 1024*1024 before
// being stored. Any NVML failure aborts the whole call with an error.
func getNVIDIAGPUInfo(metadata []gpuMetadata) ([]types.GPU, error) {
	const mib = 1024 * 1024

	if initRet := nvml.Init(); !errors.Is(initRet, nvml.SUCCESS) {
		return nil, fmt.Errorf("NVIDIA Management Library not installed, initialized or configured (reboot recommended for newly installed NVIDIA GPU drivers): %s", nvml.ErrorString(initRet))
	}
	defer func() {
		// Shutdown failures are deliberately ignored: nothing useful can be done with them here.
		_ = nvml.Shutdown()
	}()

	total, countRet := nvml.DeviceGetCount()
	if !errors.Is(countRet, nvml.SUCCESS) {
		return nil, fmt.Errorf("failed to get device count: %s", nvml.ErrorString(countRet))
	}
	// Metadata entries are matched to devices by position, so the counts must agree.
	if total != len(metadata) {
		return nil, fmt.Errorf("failed to find NVIDIA GPU information for all GPUs")
	}

	var found []types.GPU
	for idx := 0; idx < total; idx++ {
		handle, handleRet := nvml.DeviceGetHandleByIndex(idx)
		if !errors.Is(handleRet, nvml.SUCCESS) {
			return nil, fmt.Errorf("failed to get device handle for device %d: %s", idx, nvml.ErrorString(handleRet))
		}

		deviceName, nameRet := handle.GetName()
		if !errors.Is(nameRet, nvml.SUCCESS) {
			return nil, fmt.Errorf("failed to get name for device %d: %s", idx, nvml.ErrorString(nameRet))
		}

		vram, memRet := handle.GetMemoryInfo()
		if !errors.Is(memRet, nvml.SUCCESS) {
			return nil, fmt.Errorf("failed to get nvidiagpu vram info for device %d: %s", idx, nvml.ErrorString(memRet))
		}

		found = append(found, types.GPU{
			PCIAddress: metadata[idx].PCIAddress,
			Name:       deviceName,
			Model:      deviceName,
			TotalVRAM:  vram.Total / mib,
			UsedVRAM:   vram.Used / mib,
			FreeVRAM:   vram.Free / mib,
			Vendor:     types.GPUVendorNvidia,
		})
	}
	return found, nil
}
// getIntelGPUInfo returns the GPU information for Intel GPUs by shelling out
// to the `xpu-smi` command-line tool.
//
// It first runs `xpu-smi health -l` to enumerate device IDs, then for every
// device runs `xpu-smi discovery -d <id>` (name, physical memory size) and
// `xpu-smi stats -d <id>` (used memory). The number of devices found must
// equal len(metadata); the PCI address for the i-th device is taken from
// metadata[i].
//
// An error is returned if any command fails or if its output cannot be
// parsed. Free memory is computed as total minus used and is clamped at zero
// so that an inconsistent reading can never be converted from a negative
// float to an unsigned value.
func getIntelGPUInfo(metadata []gpuMetadata) ([]types.GPU, error) {
	// Determine the number of discrete Intel GPUs
	cmd := exec.Command("xpu-smi", "health", "-l")
	output, err := cmd.CombinedOutput()
	if err != nil {
		return nil, fmt.Errorf("xpu-smi not installed, initialized, or configured: %s", err)
	}
	outputStr := string(output)

	// Use regex to find all instances of Device ID
	deviceIDRegex := regexp.MustCompile(`(?i)\| Device ID\s+\|\s+(\d+)\s+\|`)
	deviceIDMatches := deviceIDRegex.FindAllStringSubmatch(outputStr, -1)
	if len(deviceIDMatches) == 0 {
		return nil, fmt.Errorf("failed to find any Intel GPUs")
	}
	if len(deviceIDMatches) != len(metadata) {
		return nil, fmt.Errorf("failed to find Intel GPU information for all GPUs")
	}

	// The per-device patterns are loop-invariant, so compile them once here
	// rather than on every iteration.
	nameRegex := regexp.MustCompile(`(?i)Device Name:\s+([^\n|]+)`)
	totalMemRegex := regexp.MustCompile(`(?i)Memory Physical Size:\s+([^\s]+)\s+MiB`)
	usedMemRegex := regexp.MustCompile(`(?i)GPU Memory Used \(MiB\)\s+\|\s+(\d+)\s+\|`)

	gpuInfos := make([]types.GPU, 0)
	for i, match := range deviceIDMatches {
		deviceID := match[1]

		// Get GPU details using xpu-smi discovery
		cmd = exec.Command("xpu-smi", "discovery", "-d", deviceID)
		output, err = cmd.CombinedOutput()
		if err != nil {
			return nil, fmt.Errorf("failed to get discovery info for Intel GPU %s: %s", deviceID, err)
		}
		outputStr = string(output)

		nameMatch := nameRegex.FindStringSubmatch(outputStr)
		totalMemMatch := totalMemRegex.FindStringSubmatch(outputStr)
		if nameMatch == nil || totalMemMatch == nil {
			return nil, fmt.Errorf("failed to parse discovery info for Intel GPU %s", deviceID)
		}
		gpuName := strings.TrimSpace(nameMatch[1])
		totalMemoryMiB, err := strconv.ParseFloat(totalMemMatch[1], 64)
		if err != nil {
			return nil, fmt.Errorf("failed to parse total memory for Intel GPU %s: %s", deviceID, err)
		}

		// Get used memory using xpu-smi stats
		cmd = exec.Command("xpu-smi", "stats", "-d", deviceID)
		output, err = cmd.CombinedOutput()
		if err != nil {
			return nil, fmt.Errorf("failed to get stats for Intel GPU %s: %s", deviceID, err)
		}
		outputStr = string(output)

		usedMemMatch := usedMemRegex.FindStringSubmatch(outputStr)
		if usedMemMatch == nil {
			return nil, fmt.Errorf("failed to parse used memory for Intel GPU %s", deviceID)
		}
		usedMemoryMiB, err := strconv.ParseFloat(usedMemMatch[1], 64)
		if err != nil {
			return nil, fmt.Errorf("failed to parse used memory for Intel GPU %s: %s", deviceID, err)
		}

		// Converting a negative float to uint64 is implementation-dependent,
		// so never let an inconsistent reading produce a negative value.
		freeMemoryMiB := totalMemoryMiB - usedMemoryMiB
		if freeMemoryMiB < 0 {
			freeMemoryMiB = 0
		}

		gpuInfos = append(gpuInfos, types.GPU{
			PCIAddress: metadata[i].PCIAddress,
			Model:      gpuName,
			TotalVRAM:  uint64(totalMemoryMiB),
			UsedVRAM:   uint64(usedMemoryMiB),
			FreeVRAM:   uint64(freeMemoryMiB),
			Vendor:     types.GPUVendorIntel,
		})
	}
	return gpuInfos, nil
}
// getGPUs gathers GPU details for the requested vendors and returns them with
// DMS-internal indices assigned by assignIndexToGPUs (these indices are not
// the device indices).
//
// With no vendors given, Intel, NVIDIA and AMD are queried in that order.
// A vendor with no entry in metadata is logged and skipped; a vendor whose
// query fails is logged as a warning and skipped. The only error returned is
// for a vendor value that is not one of the three supported ones.
func getGPUs(metadata map[types.GPUVendor][]gpuMetadata, vendors ...types.GPUVendor) ([]types.GPU, error) {
	type fetcher func(metadata []gpuMetadata) ([]types.GPU, error)

	var collected []types.GPU

	// collect runs fetch for one vendor and accumulates whatever it yields.
	collect := func(vendor types.GPUVendor, fetch fetcher) {
		entries, present := metadata[vendor]
		if !present {
			zlog.Sugar().Infof("No %s GPUs found", vendor)
			return
		}
		list, fetchErr := fetch(entries)
		if fetchErr != nil {
			zlog.Sugar().Warnf("Failed to retrieve %s GPU information: %v", vendor, fetchErr)
			return
		}
		collected = append(collected, list...)
	}

	if len(vendors) == 0 {
		// Nothing requested explicitly: query every supported vendor.
		collect(types.GPUVendorIntel, getIntelGPUInfo)
		collect(types.GPUVendorNvidia, getNVIDIAGPUInfo)
		collect(types.GPUVendorAMDATI, getAMDGPUInfo)
		return assignIndexToGPUs(collected), nil
	}

	for _, requested := range vendors {
		switch requested {
		case types.GPUVendorIntel:
			collect(requested, getIntelGPUInfo)
		case types.GPUVendorNvidia:
			collect(requested, getNVIDIAGPUInfo)
		case types.GPUVendorAMDATI:
			collect(requested, getAMDGPUInfo)
		default:
			return nil, fmt.Errorf("unsupported GPU vendor: %v", requested)
		}
	}
	return assignIndexToGPUs(collected), nil
}
// fetchGPUMetadata returns the per-vendor GPU metadata (PCI addresses) for the
// system, using `ghw.GPU()` on a cache miss.
//
// A non-empty map already held by the store is returned as-is. Otherwise the
// graphics cards reported by ghw are grouped by parsed vendor name, cards
// without device info are skipped, and the result is written back to the
// store before being returned.
//
// TODO: Use one single library to fetch GPU information or improve the match criteria
// https://gitlab.com/nunet/device-management-service/-/issues/548
// TODO: write tests by mocking the gpu snapshot
// https://gitlab.com/nunet/device-management-service/-/issues/534
func (l *linuxSystemSpecs) fetchGPUMetadata() (map[types.GPUVendor][]gpuMetadata, error) {
	var cached map[types.GPUVendor][]gpuMetadata
	l.store.withGpuMetadataLock(func() {
		cached = l.store.gpuMetadata
	})
	if len(cached) > 0 {
		return cached, nil
	}

	// Keep filling the stored map when it exists but is empty; otherwise start fresh.
	byVendor := cached
	if byVendor == nil {
		byVendor = make(map[types.GPUVendor][]gpuMetadata)
	}

	info, err := ghw.GPU()
	if err != nil {
		return nil, err
	}

	for _, gfx := range info.GraphicsCards {
		if gfx.DeviceInfo == nil {
			continue
		}
		key := types.ParseGPUVendor(gfx.DeviceInfo.Vendor.Name)
		byVendor[key] = append(byVendor[key], gpuMetadata{PCIAddress: gfx.Address})
	}

	l.store.withGpuMetadataLock(func() {
		l.store.gpuMetadata = byVendor
	})
	return byVendor, nil
}
// GetMachineResources returns the CPU, RAM, disk and GPU resources of the
// machine.
//
// A value already held by the store is copied under the read lock and
// returned. Otherwise GPU metadata, CPU, RAM, GPUs and disk are gathered in
// that order, the first failure is returned as an error, and a successful
// result is saved to the store before being returned.
func (l *linuxSystemSpecs) GetMachineResources() (types.MachineResources, error) {
	var cached *types.MachineResources
	l.store.withMachineResourcesRLock(func() {
		if stored := l.store.machineResources; stored != nil {
			// Copy while the lock is held so the caller gets a stable snapshot.
			snapshot := *stored
			cached = &snapshot
		}
	})
	if cached != nil {
		return *cached, nil
	}

	var none types.MachineResources

	gpuMeta, err := l.fetchGPUMetadata()
	if err != nil {
		return none, fmt.Errorf("failed to fetch GPU metadata: %s", err)
	}

	cpu, err := getCPU()
	if err != nil {
		return none, fmt.Errorf("failed to get CPU: %s", err)
	}

	memory, err := getRAM()
	if err != nil {
		return none, fmt.Errorf("failed to get RAM: %s", err)
	}

	gpuList, err := getGPUs(gpuMeta)
	if err != nil {
		return none, fmt.Errorf("failed to get GPUs: %s", err)
	}

	disk, err := getDisk()
	if err != nil {
		return none, fmt.Errorf("failed to get DISK: %s", err)
	}

	result := types.MachineResources{
		Resources: types.Resources{
			CPU:  cpu,
			RAM:  memory,
			Disk: disk,
			GPUs: gpuList,
		},
	}

	l.store.withMachineResourcesLock(func() {
		l.store.machineResources = &result
	})

	// TODO: do we wanna store it in the db?
	return result, nil
}
package resources
import (
"context"
"gitlab.com/nunet/device-management-service/types"
)
// defaultUsageMonitor is the stock UsageMonitor implementation. It carries no
// state.
type defaultUsageMonitor struct{}

// newUsageMonitor returns a pointer to a fresh defaultUsageMonitor.
func newUsageMonitor() *defaultUsageMonitor {
	return new(defaultUsageMonitor)
}
// Compile-time assertion that *defaultUsageMonitor satisfies types.UsageMonitor.
var _ types.UsageMonitor = (*defaultUsageMonitor)(nil)
// GetUsage is meant to return the resources used by the machine.
//
// It is currently an unimplemented stub: every call panics with
// "implement me" and the context argument is ignored, so callers must not
// invoke it until a real implementation exists.
func (um *defaultUsageMonitor) GetUsage(_ context.Context) (types.Resources, error) {
panic("implement me")
}
package docker
import (
"bufio"
"context"
"encoding/json"
"fmt"
"io"
"os"
"strings"
"sync"
"time"
"github.com/docker/docker/api/types"
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/api/types/filters"
"github.com/docker/docker/api/types/network"
"github.com/docker/docker/client"
"github.com/docker/docker/pkg/jsonmessage"
"github.com/docker/docker/pkg/stdcopy"
v1 "github.com/opencontainers/image-spec/specs-go/v1"
"github.com/pkg/errors"
"go.uber.org/multierr"
)
// Client wraps the Docker client to provide high-level operations on Docker containers and networks.
type Client struct {
client *client.Client // Underlying Docker SDK client (a named field, not an embedded one).
}
// NewDockerClient builds a Client around a Docker SDK client that is
// configured from the environment and negotiates the API version with the
// daemon. The SDK constructor's error is returned unchanged.
func NewDockerClient() (*Client, error) {
	sdk, err := client.NewClientWithOpts(
		client.FromEnv,
		client.WithAPIVersionNegotiation(),
	)
	if err != nil {
		return nil, err
	}
	wrapped := &Client{client: sdk}
	return wrapped, nil
}
// IsInstalled reports whether the Docker daemon answers a ping. Any ping
// error, whatever its cause, yields false.
func (c *Client) IsInstalled(ctx context.Context) bool {
	if _, pingErr := c.client.Ping(ctx); pingErr != nil {
		return false
	}
	return true
}
// CreateContainer pulls config.Image and then creates (but does not start) a
// container from the given configuration, returning the new container's ID.
// A failure in either the pull or the create step is returned as-is with an
// empty ID.
func (c *Client) CreateContainer(
	ctx context.Context,
	config *container.Config,
	hostConfig *container.HostConfig,
	networkingConfig *network.NetworkingConfig,
	platform *v1.Platform,
	name string,
) (string, error) {
	// The image must be available locally before the container can be created.
	if _, pullErr := c.PullImage(ctx, config.Image); pullErr != nil {
		return "", pullErr
	}

	created, createErr := c.client.ContainerCreate(ctx, config, hostConfig, networkingConfig, platform, name)
	if createErr != nil {
		return "", createErr
	}
	return created.ID, nil
}
// InspectContainer returns detailed information about a Docker container.
// It forwards id to the Docker client's ContainerInspect and returns that
// call's result and error unchanged.
func (c *Client) InspectContainer(ctx context.Context, id string) (types.ContainerJSON, error) {
return c.client.ContainerInspect(ctx, id)
}
// FollowLogs tails the logs of a specified container, returning separate
// readers for stdout and stderr.
//
// The container is inspected first; if that or the log request fails, both
// readers are nil and a wrapped error is returned. Otherwise a background
// goroutine demultiplexes the Docker log stream into two pipes until the
// stream ends or fails, then closes the log stream and both pipe writers so
// the returned readers observe EOF.
//
// The goroutine keeps its copy error in a local variable. It must not assign
// to the named return value err, because by then this function has already
// returned and the write would race with the caller.
func (c *Client) FollowLogs(ctx context.Context, id string) (stdout, stderr io.Reader, err error) {
	cont, err := c.InspectContainer(ctx, id)
	if err != nil {
		return nil, nil, errors.Wrap(err, "failed to get container")
	}

	logOptions := types.ContainerLogsOptions{
		ShowStdout: true,
		ShowStderr: true,
		Follow:     true,
	}
	logsReader, err := c.client.ContainerLogs(ctx, cont.ID, logOptions)
	if err != nil {
		return nil, nil, errors.Wrap(err, "failed to get container logs")
	}

	stdoutReader, stdoutWriter := io.Pipe()
	stderrReader, stderrWriter := io.Pipe()

	go func() {
		stdoutBuffer := bufio.NewWriter(stdoutWriter)
		stderrBuffer := bufio.NewWriter(stderrWriter)
		defer func() {
			logsReader.Close()
			// Flush buffered output before closing each pipe so no log data is lost.
			stdoutBuffer.Flush()
			stdoutWriter.Close()
			stderrBuffer.Flush()
			stderrWriter.Close()
		}()

		// Cancellation is the expected way for a follow to end; only other failures are logged.
		_, copyErr := stdcopy.StdCopy(stdoutBuffer, stderrBuffer, logsReader)
		if copyErr != nil && !errors.Is(copyErr, context.Canceled) {
			zlog.Sugar().Warnf("error while copying container logs: %v\n", copyErr)
		}
	}()

	return stdoutReader, stderrReader, nil
}
// StartContainer starts a specified Docker container.
// It calls the Docker client's ContainerStart with zero-value start options
// and returns that call's error unchanged.
func (c *Client) StartContainer(ctx context.Context, containerID string) error {
return c.client.ContainerStart(ctx, containerID, types.ContainerStartOptions{})
}
// WaitContainer waits for a container to stop, returning channels for the result and errors.
// The wait condition passed to the Docker client is
// container.WaitConditionNotRunning; both channels come straight from the
// client's ContainerWait call.
func (c *Client) WaitContainer(
ctx context.Context,
containerID string,
) (<-chan container.ContainerWaitOKBody, <-chan error) {
return c.client.ContainerWait(ctx, containerID, container.WaitConditionNotRunning)
}
// StopContainer stops a running Docker container with a specified timeout.
// The timeout is handed to the Docker client's ContainerStop by pointer, and
// that call's error is returned unchanged.
func (c *Client) StopContainer(
ctx context.Context,
containerID string,
timeout time.Duration,
) error {
return c.client.ContainerStop(ctx, containerID, &timeout)
}
// RemoveContainer deletes the given container. Removal is always forced and
// always removes the container's associated volumes; neither is configurable
// by the caller.
func (c *Client) RemoveContainer(ctx context.Context, containerID string) error {
	opts := types.ContainerRemoveOptions{
		Force:         true,
		RemoveVolumes: true,
	}
	return c.client.ContainerRemove(ctx, containerID, opts)
}
// removeContainers deletes every container (running or not) that matches
// filterz. Removals run concurrently, one goroutine per container, and all
// removal errors are combined into a single returned error. A failure to list
// the containers is returned immediately.
func (c *Client) removeContainers(ctx context.Context, filterz filters.Args) error {
	listed, err := c.client.ContainerList(
		ctx,
		types.ContainerListOptions{All: true, Filters: filterz},
	)
	if err != nil {
		return err
	}

	// One buffered slot per container, so no worker ever blocks on send.
	results := make(chan error, len(listed))
	var pending sync.WaitGroup
	pending.Add(len(listed))
	for i := range listed {
		id := listed[i].ID
		go func() {
			defer pending.Done()
			results <- c.RemoveContainer(ctx, id)
		}()
	}
	pending.Wait()
	close(results)

	var combined error
	for removeErr := range results {
		combined = multierr.Append(combined, removeErr)
	}
	return combined
}
// removeNetworks deletes every Docker network that matches filterz. Removals
// run concurrently, one goroutine per network, and all removal errors are
// combined into a single returned error. A failure to list the networks is
// returned immediately.
func (c *Client) removeNetworks(ctx context.Context, filterz filters.Args) error {
	listed, err := c.client.NetworkList(ctx, types.NetworkListOptions{Filters: filterz})
	if err != nil {
		return err
	}

	// One buffered slot per network, so no worker ever blocks on send.
	results := make(chan error, len(listed))
	var pending sync.WaitGroup
	pending.Add(len(listed))
	for i := range listed {
		id := listed[i].ID
		go func() {
			defer pending.Done()
			results <- c.client.NetworkRemove(ctx, id)
		}()
	}
	pending.Wait()
	close(results)

	var combined error
	for removeErr := range results {
		combined = multierr.Append(combined, removeErr)
	}
	return combined
}
// RemoveObjectsWithLabel deletes all containers, and then all networks, that
// carry the label `label=value`. Network removal is attempted even when
// container removal failed; errors from both steps are combined.
func (c *Client) RemoveObjectsWithLabel(ctx context.Context, label string, value string) error {
	selector := fmt.Sprintf("%s=%s", label, value)
	byLabel := filters.NewArgs(filters.Arg("label", selector))

	// Arguments are evaluated left to right, so containers go before networks.
	return multierr.Combine(
		c.removeContainers(ctx, byLabel),
		c.removeNetworks(ctx, byLabel),
	)
}
// GetOutputStream opens the combined stdout/stderr log stream of a container.
//
// since is passed through as the Docker `Since` log option and follow as the
// `Follow` option. The container must be inspectable and currently running;
// otherwise an error is returned. On success the caller owns the returned
// io.ReadCloser.
func (c *Client) GetOutputStream(
	ctx context.Context,
	containerID string,
	since string,
	follow bool,
) (io.ReadCloser, error) {
	inspected, err := c.InspectContainer(ctx, containerID)
	if err != nil {
		return nil, errors.Wrap(err, "failed to get container")
	}
	if running := inspected.State.Running; !running {
		return nil, fmt.Errorf("cannot get logs for a container that is not running")
	}

	stream, err := c.client.ContainerLogs(ctx, containerID, types.ContainerLogsOptions{
		ShowStdout: true,
		ShowStderr: true,
		Follow:     follow,
		Since:      since,
	})
	if err != nil {
		return nil, errors.Wrap(err, "failed to get container logs")
	}
	return stream, nil
}
// FindContainer lists all containers (including stopped ones) and returns the
// ID of the first whose label `label` equals value. If none matches, or the
// listing fails, an empty ID and an error are returned.
func (c *Client) FindContainer(ctx context.Context, label string, value string) (string, error) {
	listed, err := c.client.ContainerList(ctx, types.ContainerListOptions{All: true})
	if err != nil {
		return "", err
	}

	for i := range listed {
		if listed[i].Labels[label] != value {
			continue
		}
		return listed[i].ID, nil
	}
	return "", fmt.Errorf("unable to find container for %s=%s", label, value)
}
// PullImage pulls a Docker image from a registry and returns the digest
// reported in the pull progress stream (empty if no "Digest: ..." status line
// was seen).
//
// The raw progress stream is echoed to os.Stdout while it is decoded. Progress
// messages carrying an Aux payload are skipped; a message carrying an error
// aborts the pull and is returned as an error.
//
// Each progress message is decoded into a fresh JSONMessage. Decoding into a
// reused struct leaves fields that are absent from the current JSON object at
// their previous values, so a single Aux-bearing message would otherwise make
// every later message (including errors and the digest line) look like an Aux
// message and be skipped.
func (c *Client) PullImage(ctx context.Context, imageName string) (string, error) {
	out, err := c.client.ImagePull(ctx, imageName, types.ImagePullOptions{})
	if err != nil {
		zlog.Sugar().Errorf("unable to pull image: %v", err)
		return "", err
	}
	defer out.Close()

	decoder := json.NewDecoder(io.TeeReader(out, os.Stdout))
	var digest string
	for {
		var message jsonmessage.JSONMessage
		if err := decoder.Decode(&message); err != nil {
			if errors.Is(err, io.EOF) {
				break
			}
			zlog.Sugar().Errorf("unable pull image: %v", err)
			return "", err
		}

		if message.Aux != nil {
			continue
		}
		if message.Error != nil {
			zlog.Sugar().Errorf("unable pull image: %v", message.Error.Message)
			return "", errors.New(message.Error.Message)
		}
		if strings.HasPrefix(message.Status, "Digest") {
			digest = strings.TrimPrefix(message.Status, "Digest: ")
		}
	}
	return digest, nil
}
package docker
import (
"context"
"fmt"
"io"
"os"
"sync/atomic"
"time"
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/api/types/mount"
"gitlab.com/nunet/device-management-service/dms/resources"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils"
)
// Container label keys and log-stream polling parameters used by the executor.
const (
// labelExecutorName is the label key whose value is the executor ID.
labelExecutorName = "nunet-executor"
// labelJobID is the label key whose value is built by labelJobValue.
labelJobID = "nunet-jobID"
// labelExecutionID is the label key whose value is built by labelExecutionValue.
labelExecutionID = "nunet-executionID"
// outputStreamCheckTickTime is how often GetLogStream polls for an execution handler.
outputStreamCheckTickTime = 100 * time.Millisecond
// outputStreamCheckTimeout is how long GetLogStream polls before giving up.
outputStreamCheckTimeout = 5 * time.Second
)
// Executor manages the lifecycle of Docker containers for execution requests.
type Executor struct {
// ID identifies this executor; it is used as the value of the
// labelExecutorName label and as a prefix of the job/execution label values.
ID string
handlers utils.SyncMap[string, *executionHandler] // Maps execution IDs to their handlers.
client *Client // Docker client for container management.
}
// NewExecutor creates an Executor with the given ID backed by a new Docker
// client. It fails if the client cannot be constructed or if the Docker
// daemon does not answer a ping.
func NewExecutor(ctx context.Context, id string) (*Executor, error) {
	cli, err := NewDockerClient()
	if err != nil {
		return nil, err
	}

	if reachable := cli.IsInstalled(ctx); !reachable {
		return nil, fmt.Errorf("docker is not installed")
	}

	executor := &Executor{
		ID:     id,
		client: cli,
	}
	return executor, nil
}
// Start begins the execution of a request and returns without waiting for it
// to finish.
//
// If a container labelled for this job/execution already exists (for example
// after a restart) it is reused. Otherwise, an execution that already has a
// registered handler is rejected as already started or already completed,
// and a new container is created for anything else. A handler is then
// registered under the execution ID and run in its own goroutine.
func (e *Executor) Start(ctx context.Context, request *types.ExecutionRequest) error {
	zlog.Sugar().
		Infof("Starting execution for job %s, execution %s", request.JobID, request.ExecutionID)

	containerID, findErr := e.FindRunningContainer(ctx, request.JobID, request.ExecutionID)
	if findErr != nil {
		// No container for this execution: a known handler means the execution
		// was seen before, so refuse to start it a second time.
		if existing, known := e.handlers.Get(request.ExecutionID); known {
			if !existing.active() {
				return fmt.Errorf("execution is already completed")
			}
			return fmt.Errorf("execution is already started")
		}

		created, createErr := e.newDockerExecutionContainer(ctx, request)
		if createErr != nil {
			return fmt.Errorf("failed to create new container: %w", createErr)
		}
		containerID = created
	}

	execHandler := &executionHandler{
		client:      e.client,
		ID:          e.ID,
		executionID: request.ExecutionID,
		containerID: containerID,
		resultsDir:  request.ResultsDir,
		waitCh:      make(chan bool),
		activeCh:    make(chan bool),
		running:     &atomic.Bool{},
		TTYEnabled:  true,
	}

	// Register first so Wait/Cancel/GetLogStream can find the handler, then run it.
	e.handlers.Put(request.ExecutionID, execHandler)
	go execHandler.run(ctx)

	return nil
}
// Wait returns a result channel and an error channel for the execution with
// the given executionID. Both channels have a buffer of one.
//
// When no handler is registered for executionID, a "not found" error is placed
// on the error channel immediately and neither channel is closed. Otherwise
// doWait is started in a goroutine and will deliver either the result or an
// error. Callers use the two channels to synchronise with an execution that
// was begun by Start.
func (e *Executor) Wait(
	ctx context.Context,
	executionID string,
) (<-chan *types.ExecutionResult, <-chan error) {
	resultCh := make(chan *types.ExecutionResult, 1)
	errCh := make(chan error, 1)

	if handler, found := e.handlers.Get(executionID); found {
		go e.doWait(ctx, resultCh, errCh, handler)
	} else {
		errCh <- fmt.Errorf("execution (%s) not found", executionID)
	}

	return resultCh, errCh
}
// doWait blocks until the handler signals completion on its wait channel or
// ctx is done, whichever happens first, and reports the outcome:
//
//   - completion with a non-nil result: the result is sent on out;
//   - completion with a nil result: an error is sent on errCh (this points to
//     a flaw in the executor logic);
//   - ctx done: ctx.Err() is sent on errCh.
//
// Before returning, errCh is closed and then out is closed.
func (e *Executor) doWait(
	ctx context.Context,
	out chan *types.ExecutionResult,
	errCh chan error,
	handler *executionHandler,
) {
	zlog.Sugar().Infof("executionID %s waiting for execution", handler.executionID)

	defer func() {
		close(errCh)
		close(out)
	}()

	select {
	case <-handler.waitCh:
		result := handler.result
		if result == nil {
			errCh <- fmt.Errorf("execution (%s) result is nil", handler.executionID)
			return
		}
		zlog.Sugar().
			Infof("executionID %s received results from execution", handler.executionID)
		out <- result
	case <-ctx.Done():
		// Relay the cancellation reason to the caller.
		errCh <- ctx.Err()
	}
}
// Cancel kills the execution registered under executionID and returns the
// handler's kill result. If no handler is registered for that ID, an error is
// returned instead.
func (e *Executor) Cancel(ctx context.Context, executionID string) error {
	if handler, found := e.handlers.Get(executionID); found {
		return handler.kill(ctx)
	}
	return fmt.Errorf("failed to cancel execution (%s). execution not found", executionID)
}
// GetLogStream provides a stream of output logs for a specific execution.
//
// The execution may already be recorded as running before its handler has been
// added to the handler map, because the container is still starting. The
// handler is therefore looked up immediately and then re-checked every
// outputStreamCheckTickTime for up to outputStreamCheckTimeout. As soon as it
// is found, the request is delegated to the handler's outputStream.
//
// Returns an "execution not found" error if no handler appears within the
// timeout, or ctx.Err() if ctx is cancelled while waiting.
//
// Polling happens inline rather than in a helper goroutine talking over
// unbuffered channels. With that older arrangement, a timeout firing at the
// moment the poller found the handler left both sides blocked on a send
// forever.
func (e *Executor) GetLogStream(
	ctx context.Context,
	request types.LogStreamRequest,
) (io.ReadCloser, error) {
	if handler, found := e.handlers.Get(request.ExecutionID); found {
		return handler.outputStream(ctx, request)
	}

	ticker := time.NewTicker(outputStreamCheckTickTime)
	defer ticker.Stop()
	deadline := time.NewTimer(outputStreamCheckTimeout)
	defer deadline.Stop()

	for {
		select {
		case <-ticker.C:
			if handler, found := e.handlers.Get(request.ExecutionID); found {
				return handler.outputStream(ctx, request)
			}
		case <-deadline.C:
			return nil, fmt.Errorf("execution (%s) not found", request.ExecutionID)
		case <-ctx.Done():
			return nil, ctx.Err()
		}
	}
}
// List returns a slice of ExecutionListItem containing information about current executions.
// This implementation is a placeholder: it always returns nil and should be
// updated in the future to list the containers managed by DMS.
func (e *Executor) List() []types.ExecutionListItem {
// TODO: list dms containers
return nil
}
// Run initiates and waits for the completion of an execution in one call.
// It is a convenience wrapper around Start and Wait.
//
// It returns the execution result, or an error if starting fails, if waiting
// reports an error, or if ctx is cancelled first.
//
// The waiter closes both channels once it has delivered its single value, so
// a plain receive on either channel can succeed with a zero value. Each
// receive therefore checks whether it got a real value; when it did not, the
// outcome is read from the other channel. That second read cannot block: the
// waiter buffers its value before closing anything and closes the error
// channel before the result channel.
func (e *Executor) Run(
	ctx context.Context,
	request *types.ExecutionRequest,
) (*types.ExecutionResult, error) {
	if err := e.Start(ctx, request); err != nil {
		return nil, err
	}

	resCh, errCh := e.Wait(ctx, request.ExecutionID)
	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	case out, ok := <-resCh:
		if ok {
			return out, nil
		}
		// Result channel closed empty: the error channel holds the reason.
		if err, ok := <-errCh; ok {
			return nil, err
		}
	case err, ok := <-errCh:
		if ok {
			return nil, err
		}
		// Error channel closed empty: the result was buffered before the close.
		if out, ok := <-resCh; ok {
			return out, nil
		}
	}

	// Defensive fallback: the waiter is expected to always deliver one or the other.
	return nil, fmt.Errorf("execution (%s) finished without a result or an error", request.ExecutionID)
}
// Cleanup removes every Docker container and network labelled with this
// executor's ID (container removal also removes the containers' volumes).
// Any removal failure is returned wrapped; success is logged.
func (e *Executor) Cleanup(ctx context.Context) error {
	if err := e.client.RemoveObjectsWithLabel(ctx, labelExecutorName, e.ID); err != nil {
		return fmt.Errorf("failed to remove containers: %w", err)
	}

	zlog.Info("Cleaned up all Docker resources")
	return nil
}
// newDockerExecutionContainer is an internal method called by Start to set up a new Docker container
// for the job execution. It configures the container based on the provided ExecutionRequest.
// This includes decoding engine specifications, setting up environment variables, mounts and resource
// constraints. It then pulls the image and creates the container but does not start it.
// The method returns the ID of the created container, or an empty string and an error if any
// part of the setup fails.
func (e *Executor) newDockerExecutionContainer(
ctx context.Context,
params *types.ExecutionRequest,
) (string, error) {
dockerArgs, err := DecodeSpec(params.EngineSpec)
if err != nil {
return "", fmt.Errorf("failed to decode docker engine spec: %w", err)
}
// TODO: Move the GPU vendor selection below to the allocator in future
// Select the GPU with the highest available free VRAM and choose the GPU vendor for container's host config
machineResources, err := resources.ManagerInstance.SystemSpecs().GetMachineResources()
if err != nil {
return "", fmt.Errorf("failed to get machine resources: %w", err)
}
var chosenGPUVendor types.GPUVendor
if len(machineResources.GPUs) == 0 {
zlog.Info("no GPUs available on the machine")
chosenGPUVendor = types.GPUVendorUnknown
} else {
maxFreeVRAMGpu, err := machineResources.GPUs.GetGPUWithHighestFreeVRAM()
if err != nil {
return "", fmt.Errorf("failed to get GPU with highest free VRAM: %w", err)
}
// Essential for multi-vendor GPU nodes. For example,
// if a machine has an 8 GB NVIDIA and a 16 GB Intel GPU, the latter should be used first.
// Even for machines with a single GPU, this is important as integrated GPUs would also be commonly detected.
chosenGPUVendor = maxFreeVRAMGpu.Vendor
}
containerConfig := container.Config{
Image: dockerArgs.Image,
Tty: true, // Needs to be true for applications such as Jupyter or Gradio to work correctly. See issue #459 for details.
Env: dockerArgs.Environment,
Entrypoint: dockerArgs.Entrypoint,
Cmd: dockerArgs.Cmd,
Labels: e.containerLabels(params.JobID, params.ExecutionID),
WorkingDir: dockerArgs.WorkingDirectory,
}
// Building the mounts may create the results directory on disk as a side effect.
mounts, err := makeContainerMounts(params.Inputs, params.Outputs, params.ResultsDir)
if err != nil {
return "", fmt.Errorf("failed to create container mounts: %w", err)
}
zlog.Sugar().Infof("Adding %d GPUs to request", len(params.Resources.GPUs))
hostConfig := configureHostConfig(chosenGPUVendor, params, mounts)
// NOTE(review): Client.CreateContainer also calls PullImage on config.Image, so the
// image is pulled twice on this path; this explicit pull only changes which error
// wrapping the caller sees on a pull failure.
if _, err = e.client.PullImage(ctx, dockerArgs.Image); err != nil {
return "", fmt.Errorf("failed to pull docker image: %w", err)
}
// The container name is the same string used as the execution label value.
executionContainer, err := e.client.CreateContainer(
ctx,
&containerConfig,
&hostConfig,
nil,
nil,
labelExecutionValue(e.ID, params.JobID, params.ExecutionID),
)
if err != nil {
return "", fmt.Errorf("failed to create container: %w", err)
}
return executionContainer, nil
}
// configureHostConfig builds the container host configuration for an
// execution. Every configuration carries the given mounts and the requested
// CPU core count; the GPU vendor then decides which device access is added:
//
//   - NVIDIA: a device request with the "gpu" capability listing the index of
//     each requested GPU;
//   - AMD: /dev/kfd and /dev/dri bound and mapped with "rwm" permissions, plus
//     membership of the "video" group;
//   - Intel: the whole /dev/dri directory bound and mapped with "rwm"
//     permissions, which exposes all Intel GPUs to the container instead of
//     resolving individual device paths from PCI addresses;
//   - any other vendor: CPU-only, no devices.
//
// NOTE(review): Docker documents NanoCPUs in units of 1e-9 CPUs, so passing
// the raw core count looks too small by a factor of 1e9 — confirm intent.
func configureHostConfig(vendor types.GPUVendor, params *types.ExecutionRequest, mounts []mount.Mount) container.HostConfig {
	cpuCores := int64(params.Resources.CPU.Cores)

	// Settings common to every vendor.
	hostConfig := container.HostConfig{
		Mounts: mounts,
		Resources: container.Resources{
			NanoCPUs: cpuCores,
			CPUCount: cpuCores,
		},
	}

	driDevice := container.DeviceMapping{
		PathOnHost:        "/dev/dri",
		PathInContainer:   "/dev/dri",
		CgroupPermissions: "rwm",
	}

	switch vendor {
	case types.GPUVendorNvidia:
		deviceIDs := make([]string, len(params.Resources.GPUs))
		for idx, requested := range params.Resources.GPUs {
			deviceIDs[idx] = fmt.Sprint(requested.Index)
		}
		hostConfig.Resources.DeviceRequests = []container.DeviceRequest{
			{
				DeviceIDs:    deviceIDs,
				Capabilities: [][]string{{"gpu"}},
			},
		}

	case types.GPUVendorAMDATI:
		kfdDevice := container.DeviceMapping{
			PathOnHost:        "/dev/kfd",
			PathInContainer:   "/dev/kfd",
			CgroupPermissions: "rwm",
		}
		hostConfig.Binds = []string{
			"/dev/kfd:/dev/kfd",
			"/dev/dri:/dev/dri",
		}
		hostConfig.Resources.Devices = []container.DeviceMapping{kfdDevice, driDevice}
		hostConfig.GroupAdd = []string{"video"}

	case types.GPUVendorIntel:
		hostConfig.Binds = []string{
			"/dev/dri:/dev/dri",
		}
		hostConfig.Resources.Devices = []container.DeviceMapping{driDevice}
	}

	return hostConfig
}
// makeContainerMounts creates the mounts for the container based on the input
// and output volumes provided in the execution request.
//
// Every input must be a bind volume (types.StorageVolumeTypeBind) and is
// mounted with its own read-only flag; any other input type is rejected as
// unsupported. Every output must have a non-empty source and is mounted
// read-write. When there is at least one output, resultsDir must be non-empty
// and is created if it does not exist.
//
// Returns the list of mounts (inputs first, then outputs), or an error if any
// volume is invalid or the results directory cannot be created.
func makeContainerMounts(
	inputs []*types.StorageVolumeExecutor,
	outputs []*types.StorageVolumeExecutor,
	resultsDir string,
) ([]mount.Mount, error) {
	// the actual mounts we will give to the container
	// these are paths for both input and output data
	mounts := make([]mount.Mount, 0, len(inputs)+len(outputs))

	for _, input := range inputs {
		// Only bind volumes can be expressed as a bind mount.
		if input.Type != types.StorageVolumeTypeBind {
			return nil, fmt.Errorf("unsupported storage volume type: %s", input.Type)
		}
		mounts = append(mounts, mount.Mount{
			Type:     mount.TypeBind,
			Source:   input.Source,
			Target:   input.Target,
			ReadOnly: input.ReadOnly,
		})
	}

	for _, output := range outputs {
		if output.Source == "" {
			return nil, fmt.Errorf("output source is empty")
		}
		if resultsDir == "" {
			return nil, fmt.Errorf("results directory is empty")
		}
		if err := os.MkdirAll(resultsDir, os.ModePerm); err != nil {
			return nil, fmt.Errorf("failed to create results directory: %w", err)
		}
		mounts = append(mounts, mount.Mount{
			Type:   mount.TypeBind,
			Source: output.Source,
			Target: output.Target,
			// this is an output volume so can be written to
			ReadOnly: false,
		})
	}

	return mounts, nil
}
// containerLabels builds the three labels attached to a container created for
// the given job and execution: the executor ID, the job label value and the
// execution label value.
func (e *Executor) containerLabels(jobID string, executionID string) map[string]string {
	labels := make(map[string]string, 3)
	labels[labelExecutorName] = e.ID
	labels[labelJobID] = labelJobValue(e.ID, jobID)
	labels[labelExecutionID] = labelExecutionValue(e.ID, jobID, executionID)
	return labels
}
// labelJobValue builds the job label value: the executor ID and the job ID
// joined by an underscore.
func labelJobValue(executorID string, jobID string) string {
	// All operands are strings, so Sprint concatenates them without spaces.
	return fmt.Sprint(executorID, "_", jobID)
}
// labelExecutionValue builds the execution label value: the executor ID, the
// job ID and the execution ID joined by underscores.
func labelExecutionValue(executorID string, jobID string, executionID string) string {
	// All operands are strings, so Sprint concatenates them without spaces.
	return fmt.Sprint(executorID, "_", jobID, "_", executionID)
}
// FindRunningContainer looks up the container of the given job and execution
// by its execution label. It returns whatever the client's FindContainer
// reports: the container ID on success, or an error.
func (e *Executor) FindRunningContainer(
	ctx context.Context,
	jobID string,
	executionID string,
) (string, error) {
	return e.client.FindContainer(
		ctx,
		labelExecutionID,
		labelExecutionValue(e.ID, jobID, executionID),
	)
}
package docker
import (
"bufio"
"context"
"errors"
"fmt"
"io"
"strconv"
"sync/atomic"
"time"
"gitlab.com/nunet/device-management-service/types"
)
// DestroyTimeout is the time budget (10 seconds) used when tearing a container
// down: it is the stop timeout passed to the client by kill, and the overall
// cleanup deadline used by run's deferred call to destroy.
const DestroyTimeout = time.Second * 10
// executionHandler manages the lifecycle and execution of a Docker container for a specific job.
// The run method drives the container and records the outcome in result;
// activeCh and waitCh are closed by run to signal "started" and "finished".
type executionHandler struct {
// provided by the executor
ID string // Executor ID; used to build the execution label during cleanup.
client *Client // Docker client for container management.
// meta data about the task
jobID string // ID of the job this execution belongs to.
executionID string // ID of this execution.
containerID string // ID of the container started, stopped and removed by this handler.
resultsDir string // Directory to store execution results.
// synchronization
activeCh chan bool // Closed once the container has been started; blocks readers until then.
waitCh chan bool // Closed when run returns; blocks readers until execution completes or fails.
running *atomic.Bool // Indicates if the container is currently running.
// result of the execution
result *types.ExecutionResult // Set by run before waitCh is closed.
// TTY setting
TTYEnabled bool // Indicates if TTY is enabled for the container.
}
// active reports the current value of the handler's running flag, which run
// sets on entry and clears when it returns.
func (h *executionHandler) active() bool {
	isRunning := h.running.Load()
	return isRunning
}
// run starts the container and drives it through its whole lifecycle.
//
// It sets the running flag, starts the container, closes activeCh once the
// start succeeded, and then blocks until the container exits, waiting on it
// fails, or ctx is cancelled. The outcome is always stored in h.result:
//   - start/wait failures and cancellation produce a failed execution result;
//   - an inspect failure or an OOM kill produce a result carrying the exit code
//     and an error message;
//   - otherwise the exit code plus the captured stdout/stderr are recorded.
//
// When run returns, the deferred cleanup destroys the container (bounded by
// DestroyTimeout), clears the running flag and closes waitCh.
func (h *executionHandler) run(ctx context.Context) {
	h.running.Store(true)
	defer func() {
		if err := h.destroy(DestroyTimeout); err != nil {
			zlog.Sugar().Warnf("failed to destroy container: %v\n", err)
		}
		h.running.Store(false)
		// Closed last so that waiters observe the final result and state.
		close(h.waitCh)
	}()

	if err := h.client.StartContainer(ctx, h.containerID); err != nil {
		h.result = types.NewFailedExecutionResult(fmt.Errorf("failed to start container: %v", err))
		return
	}
	close(h.activeCh) // Indicate that the container has started.

	var containerError error
	var containerExitStatusCode int64

	// Wait for the container to finish or for an execution error.
	statusCh, errCh := h.client.WaitContainer(ctx, h.containerID)
	select {
	case <-ctx.Done():
		// ctx.Done() only delivers an empty struct; the cancellation reason
		// is available from ctx.Err().
		h.result = types.NewFailedExecutionResult(
			fmt.Errorf("execution cancelled: %v", ctx.Err()),
		)
		return
	case err := <-errCh:
		zlog.Sugar().Errorf("error while waiting for container: %v\n", err)
		h.result = types.NewFailedExecutionResult(
			fmt.Errorf("failed to wait for container: %v", err),
		)
		return
	case exitStatus := <-statusCh:
		containerExitStatusCode = exitStatus.StatusCode
		containerJSON, err := h.client.InspectContainer(ctx, h.containerID)
		if err != nil {
			h.result = &types.ExecutionResult{
				ExitCode: int(containerExitStatusCode),
				ErrorMsg: err.Error(),
			}
			return
		}
		if containerJSON.ContainerJSONBase.State.OOMKilled {
			containerError = errors.New("container was killed due to OOM")
			h.result = &types.ExecutionResult{
				ExitCode: int(containerExitStatusCode),
				ErrorMsg: containerError.Error(),
			}
			return
		}
		if exitStatus.Error != nil {
			containerError = errors.New(exitStatus.Error.Message)
		}
	}

	// Follow container logs to capture stdout and stderr.
	stdoutPipe, stderrPipe, logsErr := h.client.FollowLogs(ctx, h.containerID)
	if logsErr != nil {
		followError := fmt.Errorf("failed to follow container logs: %w", logsErr)
		errorMsg := followError.Error()
		if containerError != nil {
			errorMsg = fmt.Sprintf(
				"container error: '%s'. logs error: '%s'",
				containerError,
				followError,
			)
		}
		h.result = &types.ExecutionResult{
			ExitCode: int(containerExitStatusCode),
			ErrorMsg: errorMsg,
		}
		return
	}

	// readLogs drains a log stream up to the first NUL byte or the end of the
	// stream. Reaching io.EOF is the normal way for the read to finish; any
	// other error is logged and the data read so far is still returned.
	readLogs := func(stream string, r io.Reader) string {
		out, err := bufio.NewReader(r).ReadString('\x00') // EOF delimiter
		if err != nil && !errors.Is(err, io.EOF) {
			zlog.Sugar().Warnf("failed to read container %s: %v", stream, err)
		}
		return out
	}

	// Initialize the result with the exit status code.
	h.result = types.NewExecutionResult(int(containerExitStatusCode))

	// With TTY enabled stdout and stderr are combined, so only stdoutPipe is
	// read; otherwise both streams are captured separately.
	h.result.STDOUT = readLogs("stdout", stdoutPipe)
	if !h.TTYEnabled {
		h.result.STDERR = readLogs("stderr", stderrPipe)
	}
}
// kill asks the Docker client to stop the handler's container, using
// DestroyTimeout as the stop timeout.
func (h *executionHandler) kill(ctx context.Context) error {
	stopErr := h.client.StopContainer(ctx, h.containerID, DestroyTimeout)
	return stopErr
}
// destroy tears down everything belonging to this execution within timeout:
// it stops the container, removes it, and finally removes all objects carrying
// this execution's label. The first failing step aborts the sequence and its
// error is returned.
func (h *executionHandler) destroy(timeout time.Duration) error {
	// A fresh background context is used so cleanup is bounded only by timeout.
	cleanupCtx, release := context.WithTimeout(context.Background(), timeout)
	defer release()

	// stop the container
	stopErr := h.kill(cleanupCtx)
	if stopErr != nil {
		return fmt.Errorf("failed to kill container (%s): %w", h.containerID, stopErr)
	}

	removeErr := h.client.RemoveContainer(cleanupCtx, h.containerID)
	if removeErr != nil {
		return removeErr
	}

	// Remove related objects like networks or volumes created for this execution.
	executionLabel := labelExecutionValue(h.ID, h.jobID, h.executionID)
	return h.client.RemoveObjectsWithLabel(cleanupCtx, labelExecutionID, executionLabel)
}
// outputStream returns a reader over the container's output.
//
// With request.Tail set, only output produced from "now" (the time of this
// call) onwards is requested; otherwise everything since UNIX time 1 is. The
// call blocks until the container has been started (activeCh closed) or ctx is
// done, in which case ctx.Err() is returned. request.Follow is forwarded to
// the client unchanged.
func (h *executionHandler) outputStream(
	ctx context.Context,
	request types.LogStreamRequest,
) (io.ReadCloser, error) {
	var since string
	if request.Tail {
		since = strconv.FormatInt(time.Now().Unix(), 10)
	} else {
		since = "1" // Start of UNIX time, i.e. all logs.
	}

	// Ensure the container is active before attempting to stream logs.
	select {
	case <-h.activeCh:
	case <-ctx.Done():
		return nil, ctx.Err()
	}

	// Gets the underlying reader, and provides data since the `since` timestamp.
	return h.client.GetOutputStream(ctx, h.containerID, since, request.Follow)
}
package docker
import (
"gitlab.com/nunet/device-management-service/telemetry/logger"
)
// zlog is the package-level logger of the docker executor; it is assigned in init.
var zlog *logger.Logger
// init creates the package logger under the name "docker.executor".
func init() {
zlog = logger.New("docker.executor")
}
package docker
import (
"encoding/json"
"fmt"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils/validate"
)
// Parameter keys of a docker engine spec. The builder methods store values
// under these keys, and they match the JSON tags of EngineSpec.
const (
EngineKeyImage = "image"
EngineKeyEntrypoint = "entrypoint"
EngineKeyCmd = "cmd"
EngineKeyEnvironment = "environment"
EngineKeyWorkingDirectory = "working_directory"
)
// EngineSpec contains necessary parameters to execute a docker job.
// Only Image is mandatory (see Validate); all fields are omitted from JSON when empty.
type EngineSpec struct {
// Image this should be pullable by docker
Image string `json:"image,omitempty"`
// Entrypoint optionally override the default entrypoint
Entrypoint []string `json:"entrypoint,omitempty"`
// Cmd specifies the command to run in the container
Cmd []string `json:"cmd,omitempty"`
// Environment is a slice of env to run the container with
Environment []string `json:"environment,omitempty"`
// WorkingDirectory inside the container
WorkingDirectory string `json:"working_directory,omitempty"`
}
// Validate reports whether the engine spec is usable: it returns an error when
// Image is blank and nil otherwise.
func (c EngineSpec) Validate() error {
	if !validate.IsBlank(c.Image) {
		return nil
	}
	return fmt.Errorf("invalid docker engine params: image cannot be empty")
}
// DecodeSpec decodes a spec config into a docker engine spec.
//
// The spec must be non-nil, of the docker executor type, and carry non-nil
// params. The params are round-tripped through JSON into an EngineSpec, which
// is then validated. The decoded spec is returned together with the result of
// Validate; for every earlier failure a zero EngineSpec and an error are
// returned.
func DecodeSpec(spec *types.SpecConfig) (EngineSpec, error) {
	// Guard against a nil spec instead of dereferencing it below.
	if spec == nil {
		return EngineSpec{}, fmt.Errorf("invalid docker engine spec: spec cannot be nil")
	}

	if !spec.IsType(string(types.ExecutorTypeDocker)) {
		return EngineSpec{}, fmt.Errorf(
			"invalid docker engine type. expected %s, but received: %s",
			types.ExecutorTypeDocker,
			spec.Type,
		)
	}

	inputParams := spec.Params
	if inputParams == nil {
		return EngineSpec{}, fmt.Errorf("invalid docker engine params: params cannot be nil")
	}

	paramBytes, err := json.Marshal(inputParams)
	if err != nil {
		return EngineSpec{}, fmt.Errorf("failed to encode docker engine params: %w", err)
	}

	// Decode into a value rather than a pointer: a JSON "null" would leave a
	// pointer nil and the dereference on return would panic.
	var dockerSpec EngineSpec
	if err := json.Unmarshal(paramBytes, &dockerSpec); err != nil {
		return EngineSpec{}, fmt.Errorf("failed to decode docker engine params: %w", err)
	}

	return dockerSpec, dockerSpec.Validate()
}
// EngineBuilder is a struct that is used for constructing a docker engine
// SpecConfig using the Builder pattern.
// It wraps the SpecConfig being built; every With* method stores one parameter in it.
type EngineBuilder struct {
eb *types.SpecConfig // spec config under construction, returned by Build
}
// NewDockerEngineBuilder creates a builder for a docker engine spec.
// The underlying spec config is created with the docker executor type and the
// given image stored under EngineKeyImage.
func NewDockerEngineBuilder(image string) *EngineBuilder {
	builder := &EngineBuilder{
		eb: types.NewSpecConfig(string(types.ExecutorTypeDocker)),
	}
	builder.eb.WithParam(EngineKeyImage, image)
	return builder
}
// WithEntrypoint stores the given entrypoint under EngineKeyEntrypoint.
// It returns the same builder so calls can be chained.
func (b *EngineBuilder) WithEntrypoint(e ...string) *EngineBuilder {
	spec := b.eb
	spec.WithParam(EngineKeyEntrypoint, e)
	return b
}
// WithCmd stores the given command under EngineKeyCmd.
// It returns the same builder so calls can be chained.
func (b *EngineBuilder) WithCmd(c ...string) *EngineBuilder {
	spec := b.eb
	spec.WithParam(EngineKeyCmd, c)
	return b
}
// WithEnvironment stores the given environment variables under
// EngineKeyEnvironment. It returns the same builder so calls can be chained.
func (b *EngineBuilder) WithEnvironment(e ...string) *EngineBuilder {
	spec := b.eb
	spec.WithParam(EngineKeyEnvironment, e)
	return b
}
// WithWorkingDirectory stores the given working directory under
// EngineKeyWorkingDirectory. It returns the same builder so calls can be chained.
func (b *EngineBuilder) WithWorkingDirectory(w string) *EngineBuilder {
	spec := b.eb
	spec.WithParam(EngineKeyWorkingDirectory, w)
	return b
}
// Build returns the SpecConfig accumulated by the builder. The builder's own
// pointer is returned, not a copy.
func (b *EngineBuilder) Build() *types.SpecConfig {
	spec := b.eb
	return spec
}
//go:build linux
// +build linux
package firecracker
import (
"context"
"fmt"
"os"
"syscall"
"time"
"github.com/firecracker-microvm/firecracker-go-sdk"
)
// pidCheckTickTime is the polling interval DestroyVM uses to check whether the
// Firecracker process is still alive after a shutdown request.
const pidCheckTickTime = 100 * time.Millisecond
// Client wraps the Firecracker SDK to provide high-level operations on Firecracker VMs.
// It carries no state of its own.
type Client struct{}
// NewFirecrackerClient returns a new, empty Client. The returned error is
// always nil.
func NewFirecrackerClient() (*Client, error) {
	client := new(Client)
	return client, nil
}
// IsInstalled reports whether the Firecracker binary can be run on the host.
// It runs the VM command with "--version" under a 2-second timeout and returns
// true only if the command succeeds and prints something.
func (c *Client) IsInstalled(ctx context.Context) bool {
	probeCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
	defer cancel()

	versionCmd := firecracker.VMCommandBuilder{}.WithArgs([]string{"--version"}).Build(probeCtx)
	output, err := versionCmd.Output()
	if err != nil {
		return false
	}
	if !versionCmd.ProcessState.Success() {
		return false
	}
	return len(output) > 0
}
// CreateVM builds a Firecracker machine from cfg. The VM process command is
// bound to cfg.SocketPath and installed as the machine's process runner. The
// result of firecracker.NewMachine is returned unchanged.
func (c *Client) CreateVM(
	ctx context.Context,
	cfg firecracker.Config,
) (*firecracker.Machine, error) {
	runner := firecracker.VMCommandBuilder{}.
		WithSocketPath(cfg.SocketPath).
		Build(ctx)

	return firecracker.NewMachine(ctx, cfg, firecracker.WithProcessRunner(runner))
}
// StartVM starts the given machine and returns the error reported by its
// Start method.
func (c *Client) StartVM(ctx context.Context, m *firecracker.Machine) error {
	startErr := m.Start(ctx)
	return startErr
}
// ShutdownVM asks the given machine to shut down and returns the error
// reported by its Shutdown method.
func (c *Client) ShutdownVM(ctx context.Context, m *firecracker.Machine) error {
	shutdownErr := m.Shutdown(ctx)
	return shutdownErr
}
// DestroyVM shuts the VM down and makes sure its process is gone.
//
// It first requests a shutdown; if that request fails the error is returned
// and nothing else is done. Otherwise the machine's PID is polled every
// pidCheckTickTime. If the process is still alive when timeout elapses or ctx
// is done, it is killed with SIGKILL. The socket file at m.Cfg.SocketPath is
// removed (best effort) when the function returns after a successful shutdown
// request.
func (c *Client) DestroyVM(
	ctx context.Context,
	m *firecracker.Machine,
	timeout time.Duration,
) error {
	if err := c.ShutdownVM(ctx, m); err != nil {
		return fmt.Errorf("failed to shutdown vm: %w", err)
	}

	// Best effort: the socket file may already have been removed.
	defer os.Remove(m.Cfg.SocketPath)

	// If the process is not running, return early.
	pid, _ := m.PID()
	if pid <= 0 {
		return nil
	}

	ticker := time.NewTicker(pidCheckTickTime)
	defer ticker.Stop()
	deadline := time.NewTimer(timeout)
	defer deadline.Stop()

	// Poll until the process has exited, the timeout elapsed, or the caller
	// gave up. The loop runs inline; no helper goroutine is needed because
	// the caller blocks on the outcome anyway.
	for {
		select {
		case <-ticker.C:
			if current, _ := m.PID(); current <= 0 {
				return nil
			}
			continue
		case <-deadline.C:
		case <-ctx.Done():
		}
		break
	}

	// The graceful shutdown did not complete in time: force-kill the process.
	// ESRCH means the process exited between the last poll and the kill,
	// which is the outcome we wanted.
	if err := syscall.Kill(pid, syscall.SIGKILL); err != nil && err != syscall.ESRCH {
		return fmt.Errorf("failed to kill process: %w", err)
	}
	return nil
}
// FindVM finds a running Firecracker VM by its socket path.
//
// It checks that the socket file exists, creates a machine handle bound to
// that socket, and queries the instance info through the Firecracker API. The
// machine is returned only if the reported state is "Running"; in every other
// case (missing socket, handle creation failure, API failure, missing or
// different state) an error is returned.
func (c *Client) FindVM(ctx context.Context, socketPath string) (*firecracker.Machine, error) {
	// Check if the socket file exists.
	if _, err := os.Stat(socketPath); err != nil {
		return nil, fmt.Errorf("VM with socket path %v not found: %w", socketPath, err)
	}

	// Create a new Firecracker machine instance.
	cmd := firecracker.VMCommandBuilder{}.WithSocketPath(socketPath).Build(ctx)
	machine, err := firecracker.NewMachine(
		ctx,
		firecracker.Config{SocketPath: socketPath},
		firecracker.WithProcessRunner(cmd),
	)
	if err != nil {
		return nil, fmt.Errorf("failed to create machine with socket %s: %w", socketPath, err)
	}

	// Check if the VM is running by getting its instance info.
	info, err := machine.DescribeInstanceInfo(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to get instance info for socket %s: %w", socketPath, err)
	}

	// State is a pointer; guard against a response that does not carry it.
	if info.State == nil {
		return nil, fmt.Errorf("VM with socket %s reported no state", socketPath)
	}
	if *info.State != "Running" {
		return nil, fmt.Errorf(
			"VM with socket %s is not running, current state: %s",
			socketPath,
			*info.State,
		)
	}

	return machine, nil
}
//go:build linux
// +build linux
package firecracker
import (
"context"
"errors"
"fmt"
"io"
"os"
"sync"
"sync/atomic"
"time"
"github.com/firecracker-microvm/firecracker-go-sdk"
fcModels "github.com/firecracker-microvm/firecracker-go-sdk/client/models"
"go.uber.org/multierr"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils"
)
const (
// socketDir is the directory in which generateSocketPath places VM API sockets.
socketDir = "/tmp"
)
// ErrNotInstalled is returned by NewExecutor when the Firecracker binary check fails.
var ErrNotInstalled = errors.New("firecracker is not installed")
// Executor manages the lifecycle of Firecracker VMs for execution requests.
type Executor struct {
ID string // Executor identifier; part of every generated socket path.
handlers utils.SyncMap[string, *executionHandler] // Maps execution IDs to their handlers.
client *Client // Firecracker client for VM management.
}
// NewExecutor creates an executor for Firecracker VMs with the given ID.
// It returns the client creation error if there is one, and ErrNotInstalled
// when the Firecracker installation check fails.
func NewExecutor(ctx context.Context, id string) (*Executor, error) {
	client, err := NewFirecrackerClient()
	switch {
	case err != nil:
		return nil, err
	case !client.IsInstalled(ctx):
		return nil, ErrNotInstalled
	}

	return &Executor{ID: id, client: client}, nil
}
// List returns one entry per registered execution handler, carrying the
// execution ID and whether the handler is currently running. The result is
// never nil; it is empty when there are no handlers.
func (e *Executor) List() []types.ExecutionListItem {
	executions := make([]types.ExecutionListItem, 0)
	// Use the typed iterator (as Cleanup does) instead of the untyped Range,
	// so no unchecked type assertions are needed.
	e.handlers.Iter(func(executionID string, handler *executionHandler) bool {
		executions = append(executions, types.ExecutionListItem{
			ExecutionID: executionID,
			Running:     handler.active(),
		})
		return true // keep iterating
	})
	return executions
}
// Start begins the execution of a request by running a Firecracker VM for it.
//
// It first looks for an already running VM for the job/execution pair (the
// call may be the result of a restart). If none is found and a handler is
// already registered for the execution ID, an error is returned ("already
// started" or "already completed" depending on the handler state); otherwise a
// new VM is created. A handler is then registered under the execution ID and
// its run loop is started in a new goroutine using ctx.
func (e *Executor) Start(ctx context.Context, request *types.ExecutionRequest) error {
	zlog.Sugar().
		Infof("Starting execution for job %s, execution %s", request.JobID, request.ExecutionID)

	// It's possible that this is being called due to a restart. We should check if the
	// VM is already running.
	machine, err := e.FindRunningVM(ctx, request.JobID, request.ExecutionID)
	if err != nil {
		// Unable to find a running VM for this execution, we will instead check for a handler, and
		// failing that will create a new VM.
		if handler, ok := e.handlers.Get(request.ExecutionID); ok {
			if handler.active() {
				return fmt.Errorf("execution is already started")
			}
			return fmt.Errorf("execution is already completed")
		}

		// Create a new VM for the execution.
		machine, err = e.newFirecrackerExecutionVM(ctx, request)
		if err != nil {
			return fmt.Errorf("failed to create new firecracker VM: %w", err)
		}
	}

	handler := &executionHandler{
		client: e.client,
		ID:     e.ID,
		// JobID was previously left empty although the handler declares it
		// as task metadata.
		JobID:       request.JobID,
		executionID: request.ExecutionID,
		machine:     machine,
		resultsDir:  request.ResultsDir,
		waitCh:      make(chan bool),
		activeCh:    make(chan bool),
		running:     &atomic.Bool{},
	}

	// register the handler for this executionID
	e.handlers.Put(request.ExecutionID, handler)

	// run the VM.
	go handler.run(ctx)
	return nil
}
// Wait starts waiting for the execution with the given ID and returns two
// channels, one for the result and one for an error, each buffered with
// capacity 1.
//
// If no handler is registered for executionID, a "not found" error is placed
// on the error channel right away and no goroutine is started; in that case
// neither channel is closed. Otherwise doWait is started in a goroutine and
// delivers either the result or an error.
func (e *Executor) Wait(
	ctx context.Context,
	executionID string,
) (<-chan *types.ExecutionResult, <-chan error) {
	results := make(chan *types.ExecutionResult, 1)
	failures := make(chan error, 1)

	handler, ok := e.handlers.Get(executionID)
	if ok {
		go e.doWait(ctx, results, failures, handler)
		return results, failures
	}

	failures <- fmt.Errorf("execution (%s) not found", executionID)
	return results, failures
}
// doWait blocks until the handler's waitCh is closed or ctx is done, then
// reports exactly one value:
//   - ctx.Err() on errCh if the context finished first;
//   - the handler's result on out if it is non-nil;
//   - an error on errCh if the handler finished without a result.
//
// Both channels are closed when doWait returns.
func (e *Executor) doWait(
	ctx context.Context,
	out chan *types.ExecutionResult,
	errCh chan error,
	handler *executionHandler,
) {
	zlog.Sugar().Infof("executionID %s waiting for execution", handler.executionID)

	defer close(out)
	defer close(errCh)

	select {
	case <-handler.waitCh:
		// Execution finished; fall through to report its result.
	case <-ctx.Done():
		errCh <- ctx.Err() // Send the cancellation error to the error channel
		return
	}

	result := handler.result
	if result == nil {
		errCh <- fmt.Errorf("execution (%s) result is nil", handler.executionID)
		return
	}

	zlog.Sugar().
		Infof("executionID %s received results from execution", handler.executionID)
	out <- result
}
// Cancel stops the execution with the given ID by calling kill on its handler.
// It returns an error if no handler is registered for executionID, otherwise
// whatever kill returns.
func (e *Executor) Cancel(ctx context.Context, executionID string) error {
	if handler, ok := e.handlers.Get(executionID); ok {
		return handler.kill(ctx)
	}
	return fmt.Errorf("failed to cancel execution (%s). execution not found", executionID)
}
// Run initiates and waits for the completion of an execution in one call.
// It calls Start and Wait internally and returns the execution result, or an
// error if starting fails, waiting fails, or ctx is cancelled.
func (e *Executor) Run(
	ctx context.Context,
	request *types.ExecutionRequest,
) (*types.ExecutionResult, error) {
	if err := e.Start(ctx, request); err != nil {
		return nil, err
	}

	resCh, errCh := e.Wait(ctx, request.ExecutionID)

	// The waiter closes both channels after delivering its single value. A
	// receive from a closed channel succeeds immediately with the zero value,
	// so a plain select could pick the closed, empty channel and return
	// (nil, nil). Closed channels are therefore disabled (set to nil) until
	// the real value is found.
	for resCh != nil || errCh != nil {
		select {
		case <-ctx.Done():
			return nil, ctx.Err()
		case out, ok := <-resCh:
			if !ok {
				resCh = nil
				continue
			}
			return out, nil
		case err, ok := <-errCh:
			if !ok {
				errCh = nil
				continue
			}
			return nil, err
		}
	}

	// Defensive: both channels were closed without carrying a value.
	return nil, fmt.Errorf("execution (%s) finished without a result", request.ExecutionID)
}
// GetLogStream exists only to satisfy the Executor interface: log streaming is
// not implemented for Firecracker, so it always returns a nil reader and an
// error.
func (e *Executor) GetLogStream(_ context.Context, _ types.LogStreamRequest) (io.ReadCloser, error) {
	notImplemented := errors.New("GetLogStream is not implemented for Firecracker")
	return nil, notImplemented
}
// Cleanup destroys every registered execution handler concurrently, giving
// each one 10 seconds, and waits for all of them to finish. The individual
// destroy errors are combined with multierr and returned; the result is nil
// when every destroy succeeded.
func (e *Executor) Cleanup(_ context.Context) error {
	// Buffer sized to the current number of handlers so senders normally do
	// not block.
	results := make(chan error, len(e.handlers.Keys()))

	var pending sync.WaitGroup
	e.handlers.Iter(func(_ string, h *executionHandler) bool {
		pending.Add(1)
		go func() {
			defer pending.Done()
			results <- h.destroy(time.Second * 10)
		}()
		return true
	})

	// Close the channel once all destroy calls have reported, which ends the
	// collection loop below.
	go func() {
		pending.Wait()
		close(results)
	}()

	var combined error
	for destroyErr := range results {
		combined = multierr.Append(combined, destroyErr)
	}

	zlog.Info("Cleaned up all firecracker resources")
	return combined
}
// newFirecrackerExecutionVM prepares a Firecracker VM for the given execution
// request. It decodes the engine spec, builds the drive list from the root
// file system and the request's inputs/outputs, assembles the machine
// configuration (VM ID, socket path, kernel, initrd, kernel args, vCPU count
// and memory size) and asks the client to create the VM. The VM is created
// but not started. Any failing step is returned as a wrapped error.
func (e *Executor) newFirecrackerExecutionVM(
	ctx context.Context,
	params *types.ExecutionRequest,
) (*firecracker.Machine, error) {
	spec, err := DecodeSpec(params.EngineSpec)
	if err != nil {
		return nil, fmt.Errorf("failed to decode firecracker engine spec: %w", err)
	}

	drives, err := makeVMMounts(
		spec.RootFileSystem,
		params.Inputs,
		params.Outputs,
		params.ResultsDir,
	)
	if err != nil {
		return nil, fmt.Errorf("failed to create VM mounts: %w", err)
	}

	vmConfig := firecracker.Config{
		VMID:            params.ExecutionID,
		SocketPath:      e.generateSocketPath(params.JobID, params.ExecutionID),
		KernelImagePath: spec.KernelImage,
		InitrdPath:      spec.Initrd,
		KernelArgs:      spec.KernelArgs,
		MachineCfg: fcModels.MachineConfiguration{
			VcpuCount:  firecracker.Int64(int64(params.Resources.CPU.Cores)),
			MemSizeMib: firecracker.Int64(int64(params.Resources.RAM.Size)), //nolint
		},
		Drives: drives,
	}

	vm, err := e.client.CreateVM(ctx, vmConfig)
	if err != nil {
		return nil, fmt.Errorf("failed to create VM: %w", err)
	}

	// e.client.VMPassMMDs(ctx, machine, fcArgs.MMDSMessage)

	return vm, nil
}
// makeVMMounts builds the drive list for a VM: the root file system first,
// then one drive per input volume (with the input's read-only flag) and one
// writable drive per output volume.
//
// Every output needs a non-empty Source, and as soon as there is at least one
// output resultsDir must be non-empty; it is created with os.ModePerm if it
// does not exist. On any failure a nil slice and an error are returned.
func makeVMMounts(
	rootFileSystem string,
	inputs []*types.StorageVolumeExecutor,
	outputs []*types.StorageVolumeExecutor,
	resultsDir string,
) ([]fcModels.Drive, error) {
	drivesBuilder := firecracker.NewDrivesBuilder(rootFileSystem)

	for _, input := range inputs {
		// AddDrive has a value receiver and returns the updated builder, so
		// its result must be kept; discarding it silently drops the drive.
		drivesBuilder = drivesBuilder.AddDrive(input.Source, input.ReadOnly)
	}

	for _, output := range outputs {
		if output.Source == "" {
			return nil, fmt.Errorf("output source is empty")
		}
		if resultsDir == "" {
			return nil, fmt.Errorf("results directory is empty")
		}
		// MkdirAll is a no-op when the directory already exists.
		if err := os.MkdirAll(resultsDir, os.ModePerm); err != nil {
			return nil, fmt.Errorf("failed to create results directory: %w", err)
		}
		// Output drives are always writable.
		drivesBuilder = drivesBuilder.AddDrive(output.Source, false)
	}

	return drivesBuilder.Build(), nil
}
// FindRunningVM looks up the VM of the given job and execution through the
// socket path derived from their IDs. It returns whatever the client's FindVM
// reports: the machine on success, or an error.
func (e *Executor) FindRunningVM(
	ctx context.Context,
	jobID string,
	executionID string,
) (*firecracker.Machine, error) {
	socketPath := e.generateSocketPath(jobID, executionID)
	return e.client.FindVM(ctx, socketPath)
}
// generateSocketPath returns the socket path for a job/execution pair:
// "<socketDir>/<executorID>_<jobID>_<executionID>.sock".
func (e *Executor) generateSocketPath(jobID string, executionID string) string {
	socketName := e.ID + "_" + jobID + "_" + executionID + ".sock"
	return socketDir + "/" + socketName
}
//go:build linux
// +build linux
package firecracker
import (
"context"
"fmt"
"sync/atomic"
"time"
"github.com/firecracker-microvm/firecracker-go-sdk"
"gitlab.com/nunet/device-management-service/types"
)
// executionHandler is a struct that holds the necessary information to manage the execution of a firecracker VM.
// The run method drives the VM and records the outcome in result;
// activeCh and waitCh are closed by run to signal "started" and "finished".
type executionHandler struct {
//
// provided by the executor
ID string // Executor ID.
client *Client // Firecracker client used to start, shut down and destroy the VM.
// meta data about the task
JobID string // ID of the job this execution belongs to.
executionID string // ID of this execution.
machine *firecracker.Machine // The VM managed by this handler.
resultsDir string // Results directory taken from the execution request.
// synchronization
activeCh chan bool // Closed once the VM has been started; blocks readers until then.
waitCh chan bool // Closed when run returns; blocks readers until execution completes or fails.
running *atomic.Bool // Indicates if the VM is currently running.
// result of the execution
result *types.ExecutionResult // Set by run before waitCh is closed.
}
// active reports the current value of the handler's running flag, which run
// sets on entry and clears when it returns.
func (h *executionHandler) active() bool {
	isRunning := h.running.Load()
	return isRunning
}
// run starts the firecracker VM and blocks until it finishes.
//
// The outcome is stored in h.result: a failed result if the VM cannot be
// started or waiting on it fails (with a distinct message when ctx is already
// done), and a success result otherwise. activeCh is closed once the VM has
// started. When run returns, the deferred cleanup destroys the VM (10-second
// budget), clears the running flag and closes waitCh.
func (h *executionHandler) run(ctx context.Context) {
	h.running.Store(true)

	defer func() {
		if destroyErr := h.destroy(time.Second * 10); destroyErr != nil {
			zlog.Sugar().Warnf("failed to destroy container: %v\n", destroyErr)
		}
		h.running.Store(false)
		// Closed last so that waiters observe the final result and state.
		close(h.waitCh)
	}()

	// start the VM
	zlog.Sugar().Info("starting firecracker execution")
	startErr := h.client.StartVM(ctx, h.machine)
	if startErr != nil {
		h.result = types.NewFailedExecutionResult(fmt.Errorf("failed to start VM: %v", startErr))
		return
	}

	close(h.activeCh) // Indicate that the VM has started.

	waitErr := h.machine.Wait(ctx)
	switch {
	case waitErr == nil:
		h.result = types.NewExecutionResult(types.ExecutionStatusCodeSuccess)
	case ctx.Err() != nil:
		// The wait failed and the context is done.
		h.result = types.NewFailedExecutionResult(
			fmt.Errorf("context closed while waiting on VM: %v", waitErr),
		)
	default:
		h.result = types.NewFailedExecutionResult(fmt.Errorf("failed to wait on VM: %v", waitErr))
	}
}
// kill asks the client to shut down the handler's VM and returns the
// resulting error.
func (h *executionHandler) kill(ctx context.Context) error {
	shutdownErr := h.client.ShutdownVM(ctx, h.machine)
	return shutdownErr
}
// destroy asks the client to destroy the handler's VM. The call runs under a
// fresh background context limited to timeout, and the same timeout is passed
// on to DestroyVM.
func (h *executionHandler) destroy(timeout time.Duration) error {
	cleanupCtx, release := context.WithTimeout(context.Background(), timeout)
	defer release()

	destroyErr := h.client.DestroyVM(cleanupCtx, h.machine, timeout)
	return destroyErr
}
//go:build linux
// +build linux
package firecracker
import (
"gitlab.com/nunet/device-management-service/telemetry/logger"
)
// zlog is the package-level logger of the firecracker executor; it is assigned in init.
var zlog *logger.Logger
// init creates the package logger under the name "executor.firecracker".
func init() {
zlog = logger.New("executor.firecracker")
}
//go:build linux
// +build linux
package firecracker
import (
"encoding/json"
"fmt"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils/validate"
)
// Parameter keys of a firecracker engine spec. The builder methods store
// values under these keys, and DecodeSpec reads them back through the JSON
// tags of EngineSpec, so every key must equal the corresponding tag.
const (
	EngineKeyKernelImage    = "kernel_image"
	EngineKeyKernelArgs     = "kernel_args"
	EngineKeyRootFileSystem = "root_file_system"
	// EngineKeyInitrd must match the `initrd_path` JSON tag of
	// EngineSpec.Initrd; with the former value "initrd" anything set through
	// WithInitrd was dropped when the spec was decoded.
	EngineKeyInitrd      = "initrd_path"
	EngineKeyMMDSMessage = "mmds_message"
)
// EngineSpec contains necessary parameters to execute a firecracker job.
// RootFileSystem and KernelImage are mandatory (see Validate); all fields are
// omitted from JSON when empty.
type EngineSpec struct {
// KernelImage is the path to the kernel image file.
KernelImage string `json:"kernel_image,omitempty"`
// Initrd is the path to the initial ramdisk file (JSON key "initrd_path").
Initrd string `json:"initrd_path,omitempty"`
// KernelArgs is the kernel command line arguments.
KernelArgs string `json:"kernel_args,omitempty"`
// RootFileSystem is the path to the root file system.
RootFileSystem string `json:"root_file_system,omitempty"`
// MMDSMessage is the MMDS message to be sent to the Firecracker VM.
MMDSMessage string `json:"mmds_message,omitempty"`
}
// Validate reports whether the engine spec is usable. RootFileSystem is
// checked first, then KernelImage; the first blank one yields an error, and
// nil is returned when both are set.
func (c EngineSpec) Validate() error {
	switch {
	case validate.IsBlank(c.RootFileSystem):
		return fmt.Errorf("invalid firecracker engine params: root_file_system cannot be empty")
	case validate.IsBlank(c.KernelImage):
		return fmt.Errorf("invalid firecracker engine params: kernel_image cannot be empty")
	default:
		return nil
	}
}
// DecodeSpec decodes a spec config into a firecracker engine spec.
//
// The spec must be non-nil, of the firecracker executor type, and carry
// non-nil params. The params are round-tripped through JSON into an
// EngineSpec, which is then validated. The decoded spec is returned together
// with the result of Validate; for every earlier failure a zero EngineSpec and
// an error are returned.
func DecodeSpec(spec *types.SpecConfig) (EngineSpec, error) {
	// Guard against a nil spec instead of dereferencing it below.
	if spec == nil {
		return EngineSpec{}, fmt.Errorf("invalid firecracker engine spec: spec cannot be nil")
	}

	if !spec.IsType(types.ExecutorTypeFirecracker.String()) {
		return EngineSpec{}, fmt.Errorf(
			"invalid firecracker engine type. expected %s, but received: %s",
			types.ExecutorTypeFirecracker,
			spec.Type,
		)
	}

	inputParams := spec.Params
	if inputParams == nil {
		return EngineSpec{}, fmt.Errorf("invalid firecracker engine params: params cannot be nil")
	}

	paramBytes, err := json.Marshal(inputParams)
	if err != nil {
		return EngineSpec{}, fmt.Errorf("failed to encode firecracker engine params: %w", err)
	}

	// Decode into a value rather than a pointer: a JSON "null" would leave a
	// pointer nil and the dereference on return would panic.
	var firecrackerSpec EngineSpec
	if err := json.Unmarshal(paramBytes, &firecrackerSpec); err != nil {
		return EngineSpec{}, fmt.Errorf("failed to decode firecracker engine params: %w", err)
	}

	return firecrackerSpec, firecrackerSpec.Validate()
}
// EngineBuilder is a struct that is used for constructing a firecracker engine
// SpecConfig using the Builder pattern.
// It wraps the SpecConfig being built; every With* method stores one parameter in it.
type EngineBuilder struct {
eb *types.SpecConfig // spec config under construction, returned by Build
}
// NewFirecrackerEngineBuilder creates a builder for a firecracker engine spec.
// The underlying spec config is created with the firecracker executor type and
// the given root file system stored under EngineKeyRootFileSystem.
func NewFirecrackerEngineBuilder(rootFileSystem string) *EngineBuilder {
	builder := &EngineBuilder{
		eb: types.NewSpecConfig(types.ExecutorTypeFirecracker.String()),
	}
	builder.eb.WithParam(EngineKeyRootFileSystem, rootFileSystem)
	return builder
}
// WithRootFileSystem stores the given root file system under
// EngineKeyRootFileSystem. It returns the same builder so calls can be chained.
func (b *EngineBuilder) WithRootFileSystem(e string) *EngineBuilder {
	spec := b.eb
	spec.WithParam(EngineKeyRootFileSystem, e)
	return b
}
// WithKernelImage stores the given kernel image under EngineKeyKernelImage.
// It returns the same builder so calls can be chained.
func (b *EngineBuilder) WithKernelImage(e string) *EngineBuilder {
	spec := b.eb
	spec.WithParam(EngineKeyKernelImage, e)
	return b
}
// WithInitrd stores the given initial ramdisk under EngineKeyInitrd.
// It returns the same builder so calls can be chained.
func (b *EngineBuilder) WithInitrd(e string) *EngineBuilder {
	spec := b.eb
	spec.WithParam(EngineKeyInitrd, e)
	return b
}
// WithKernelArgs stores the given kernel arguments under EngineKeyKernelArgs.
// It returns the same builder so calls can be chained.
func (b *EngineBuilder) WithKernelArgs(e string) *EngineBuilder {
	spec := b.eb
	spec.WithParam(EngineKeyKernelArgs, e)
	return b
}
// WithMMDSMessage stores the given MMDS message under EngineKeyMMDSMessage.
// It returns the same builder so calls can be chained.
func (b *EngineBuilder) WithMMDSMessage(e string) *EngineBuilder {
	spec := b.eb
	spec.WithParam(EngineKeyMMDSMessage, e)
	return b
}
// Build returns the SpecConfig accumulated by the builder. The builder's own
// pointer is returned, not a copy.
func (b *EngineBuilder) Build() *types.SpecConfig {
	spec := b.eb
	return spec
}
package backgroundtasks
import (
"github.com/uptrace/opentelemetry-go-extra/otelzap"
"gitlab.com/nunet/device-management-service/telemetry/logger"
)
// zlog is the package-level logger of the background tasks package; it is assigned in init.
var zlog *otelzap.Logger
// init obtains the package logger under the name "background_tasks".
func init() {
zlog = logger.OtelZapLogger("background_tasks")
}
package backgroundtasks
import (
"sort"
"sync"
"time"
)
// Scheduler orchestrates the execution of tasks based on their triggers and priority.
// Create it with NewScheduler, register tasks with AddTask and launch the
// periodic loop with Start.
type Scheduler struct {
tasks map[int]*Task // Map of tasks by their ID.
runningTasks map[int]bool // Map to keep track of running tasks.
ticker *time.Ticker // Ticker for periodic checks of task triggers.
stopChan chan struct{} // Channel to signal stopping the scheduler.
maxRunningTasks int // Maximum number of tasks that can run concurrently.
lastTaskID int // Counter for assigning unique IDs to tasks; holds the next ID to hand out.
mu sync.Mutex // Mutex to protect access to task maps.
}
// NewScheduler builds a Scheduler that allows at most maxRunningTasks tasks to
// execute at the same time.
//
// The scheduler checks triggers once per second; the ticker is created here,
// but no task is examined until Start is called. Task IDs start at zero.
func NewScheduler(maxRunningTasks int) *Scheduler {
	s := new(Scheduler)
	s.tasks = map[int]*Task{}
	s.runningTasks = map[int]bool{}
	s.ticker = time.NewTicker(time.Second)
	s.stopChan = make(chan struct{})
	s.maxRunningTasks = maxRunningTasks
	s.lastTaskID = 0
	return s
}
// AddTask registers task with the scheduler.
//
// The task receives the next free ID, is marked enabled, and every one of its
// triggers is reset so that timing starts from the moment of registration.
// The same task pointer is returned for convenience.
func (s *Scheduler) AddTask(task *Task) *Task {
	s.mu.Lock()
	defer s.mu.Unlock()

	id := s.lastTaskID
	task.ID = id
	task.Enabled = true

	for _, tr := range task.Triggers {
		tr.Reset()
	}

	s.tasks[id] = task
	s.lastTaskID = id + 1
	return task
}
// RemoveTask removes the task with the given ID from the scheduler's task map.
// Removing an unknown ID is a no-op. The entry in the running-tasks map is not
// touched here.
func (s *Scheduler) RemoveTask(taskID int) {
	s.mu.Lock()
	defer s.mu.Unlock()
	delete(s.tasks, taskID)
}
// Start launches the scheduler loop in its own goroutine and returns
// immediately. On every ticker tick the loop evaluates the registered tasks;
// it exits once the stop channel is closed.
func (s *Scheduler) Start() {
	loop := func() {
		for {
			select {
			case <-s.ticker.C:
				s.runTasks()
			case <-s.stopChan:
				return
			}
		}
	}
	go loop()
}
// runningTasksCount reports how many tasks are currently flagged as running.
// Entries whose flag is false (tasks that have finished) are not counted.
func (s *Scheduler) runningTasksCount() int {
	s.mu.Lock()
	defer s.mu.Unlock()

	running := 0
	for id := range s.runningTasks {
		if s.runningTasks[id] {
			running++
		}
	}
	return running
}
// runTasks checks and runs tasks based on their triggers and priority.
//
// Tasks are examined from highest to lowest priority. A task is skipped when
// it is disabled or already running, and it is removed from the scheduler when
// it has no triggers. Otherwise the first ready trigger that finds a free
// execution slot starts the task in its own goroutine and is reset.
//
// All access to the task and running-task maps goes through the mutex, because
// AddTask, RemoveTask and runTask touch the same maps from other goroutines.
func (s *Scheduler) runTasks() {
	// Snapshot the task set under the lock; iterate over the copy afterwards
	// so that the helpers below (which lock themselves) cannot deadlock.
	s.mu.Lock()
	sortedTasks := make([]*Task, 0, len(s.tasks))
	for _, task := range s.tasks {
		sortedTasks = append(sortedTasks, task)
	}
	s.mu.Unlock()

	// Sort tasks by priority, highest first.
	sort.Slice(sortedTasks, func(i, j int) bool {
		return sortedTasks[i].Priority > sortedTasks[j].Priority
	})

	for _, task := range sortedTasks {
		if !task.Enabled || s.isTaskRunning(task.ID) {
			continue
		}
		if len(task.Triggers) == 0 {
			s.RemoveTask(task.ID)
			continue
		}
		for _, trigger := range task.Triggers {
			// tryMarkRunning checks the concurrency limit and claims the
			// slot in a single critical section.
			if trigger.IsReady() && s.tryMarkRunning(task.ID) {
				go s.runTask(task.ID)
				trigger.Reset()
				break
			}
		}
	}
}

// isTaskRunning reports whether the task with the given ID is flagged as running.
func (s *Scheduler) isTaskRunning(taskID int) bool {
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.runningTasks[taskID]
}

// tryMarkRunning flags the task as running if it is not running already and
// the number of running tasks is below maxRunningTasks. It reports whether the
// slot was claimed.
func (s *Scheduler) tryMarkRunning(taskID int) bool {
	s.mu.Lock()
	defer s.mu.Unlock()

	if s.runningTasks[taskID] {
		return false
	}
	running := 0
	for _, isRunning := range s.runningTasks {
		if isRunning {
			running++
		}
	}
	if running >= s.maxRunningTasks {
		return false
	}
	s.runningTasks[taskID] = true
	return true
}
// Stop signals the scheduler loop to exit and stops the internal ticker.
//
// Stop is idempotent: calling it more than once is safe and has no further
// effect. Tasks that are already executing are not interrupted.
func (s *Scheduler) Stop() {
	s.mu.Lock()
	defer s.mu.Unlock()

	// A closed channel is always ready to receive; use that to detect a
	// previous Stop and avoid the panic caused by closing twice.
	select {
	case <-s.stopChan:
		return
	default:
	}

	close(s.stopChan)
	// Release the ticker's resources; the loop no longer reads from it.
	s.ticker.Stop()
}
// runTask executes a task and manages its lifecycle and retry policy.
//
// The task function is attempted up to RetryPolicy.MaxRetries+1 times. The
// outcome is appended to the task's execution history, and the running flag is
// cleared when the function returns. If the task was removed from the scheduler
// before execution started, only the running flag is cleared.
func (s *Scheduler) runTask(taskID int) {
	// Registered first so that it runs last: the task stays flagged as running
	// until its execution record has been stored.
	defer func() {
		s.mu.Lock()
		defer s.mu.Unlock()
		s.runningTasks[taskID] = false
	}()

	s.mu.Lock()
	task, ok := s.tasks[taskID]
	s.mu.Unlock()
	if !ok || task == nil {
		// The task was removed between scheduling and execution.
		return
	}

	execution := Execution{StartedAt: time.Now()}
	defer func() {
		s.mu.Lock()
		defer s.mu.Unlock()
		// task is a pointer, so appending here is enough; the map is not
		// written, which would resurrect a task removed during execution.
		task.ExecutionHist = append(task.ExecutionHist, execution)
	}()

	for i := 0; i < task.RetryPolicy.MaxRetries+1; i++ {
		err := runTaskWithRetry(task.Function, task.Args, task.RetryPolicy.Delay)
		if err == nil {
			execution.Status = "SUCCESS"
			execution.EndedAt = time.Now()
			return
		}
		// Remember the most recent failure for the execution record.
		execution.Error = err.Error()
	}
	execution.Status = "FAILED"
	execution.EndedAt = time.Now()
}
// runTaskWithRetry performs a single attempt of a task function.
//
// fn is called once with args. When it fails, the function sleeps for delay
// before handing the error back, so that a caller looping over attempts is
// paced by the retry delay. A nil error is returned without sleeping.
func runTaskWithRetry(
	fn func(args interface{}) error,
	args []interface{},
	delay time.Duration,
) error {
	if failure := fn(args); failure != nil {
		time.Sleep(delay)
		return failure
	}
	return nil
}
package backgroundtasks
import (
"time"
"github.com/robfig/cron/v3"
)
// Trigger describes a condition that decides when a task should run.
// The scheduler polls IsReady and calls Reset after it has started the task.
type Trigger interface {
	IsReady() bool // Returns true if the trigger condition is met.
	Reset()        // Resets the trigger state.
}
// PeriodicTrigger triggers at regular intervals or based on a cron expression.
//
// When both Interval and CronExpr are set, the trigger is ready as soon as
// either condition is met. When only CronExpr is set (Interval is zero or
// negative), the interval check is skipped so that the cron schedule alone
// decides.
type PeriodicTrigger struct {
	Interval      time.Duration // Interval for periodic triggering.
	CronExpr      string        // Cron expression for triggering.
	lastTriggered time.Time     // Last time the trigger was activated.
}

// IsReady checks if the trigger should activate based on time or cron expression.
// An unparsable cron expression is logged and treated as "not ready".
func (t *PeriodicTrigger) IsReady() bool {
	// Trigger based on interval. A non-positive interval combined with a cron
	// expression would be "ready" on every poll, so in that case only the cron
	// schedule is consulted.
	if t.Interval > 0 || t.CronExpr == "" {
		if t.lastTriggered.Add(t.Interval).Before(time.Now()) {
			return true
		}
	}

	// Trigger based on cron expression.
	if t.CronExpr != "" {
		cronExpr, err := cron.ParseStandard(t.CronExpr)
		if err != nil {
			zlog.Sugar().Errorf("Error parsing CronExpr: %v", err)
			return false
		}
		// Ready once the first scheduled time after the last activation has passed.
		nextCronTriggerTime := cronExpr.Next(t.lastTriggered)
		return nextCronTriggerTime.Before(time.Now())
	}
	return false
}

// Reset updates the last triggered time to the current time.
func (t *PeriodicTrigger) Reset() {
	t.lastTriggered = time.Now()
}
// PeriodicTriggerWithJitter triggers at regular intervals, extended by a
// jitter value, or based on a cron expression.
//
// Jitter is called on every readiness check and its result is added to
// Interval; a nil Jitter means no jitter. When only CronExpr is set (Interval
// is zero or negative), the interval check is skipped so that the cron
// schedule alone decides.
type PeriodicTriggerWithJitter struct {
	Interval      time.Duration // Interval for periodic triggering.
	CronExpr      string        // Cron expression for triggering.
	lastTriggered time.Time     // Last time the trigger was activated.
	Jitter        func() time.Duration
}

// IsReady checks if the trigger should activate based on time or cron expression.
// An unparsable cron expression is logged and treated as "not ready".
func (t *PeriodicTriggerWithJitter) IsReady() bool {
	// Trigger based on interval. A non-positive interval combined with a cron
	// expression would be "ready" on (almost) every poll, so in that case only
	// the cron schedule is consulted.
	if t.Interval > 0 || t.CronExpr == "" {
		wait := t.Interval
		if t.Jitter != nil { // guard against an unset jitter function
			wait += t.Jitter()
		}
		if t.lastTriggered.Add(wait).Before(time.Now()) {
			return true
		}
	}

	// Trigger based on cron expression.
	if t.CronExpr != "" {
		cronExpr, err := cron.ParseStandard(t.CronExpr)
		if err != nil {
			zlog.Sugar().Errorf("Error parsing CronExpr: %v", err)
			return false
		}
		// Ready once the first scheduled time after the last activation has passed.
		nextCronTriggerTime := cronExpr.Next(t.lastTriggered)
		return nextCronTriggerTime.Before(time.Now())
	}
	return false
}

// Reset updates the last triggered time to the current time.
func (t *PeriodicTriggerWithJitter) Reset() {
	t.lastTriggered = time.Now()
}
// EventTrigger fires when an external party sends a value on its channel.
type EventTrigger struct {
	Trigger chan bool // Channel to signal an event.
}

// IsReady performs a non-blocking receive on the trigger channel and reports
// whether a value was available. The received value itself is ignored, and a
// pending signal is consumed by the check.
func (t *EventTrigger) IsReady() bool {
	select {
	case <-t.Trigger:
		return true
	default:
	}
	return false
}

// Reset is a no-op: the state of an EventTrigger lives in its channel, which
// is managed by whoever sends the events.
func (t *EventTrigger) Reset() {}
// OneTimeTrigger becomes ready once Delay has elapsed since its last Reset.
type OneTimeTrigger struct {
	Delay        time.Duration // The delay after which to trigger.
	registeredAt time.Time     // Time when the trigger was set.
}

// Reset restarts the delay by recording the current time as the registration time.
func (t *OneTimeTrigger) Reset() {
	t.registeredAt = time.Now()
}

// IsReady reports whether the current time is strictly later than the
// registration time plus Delay.
func (t *OneTimeTrigger) IsReady() bool {
	deadline := t.registeredAt.Add(t.Delay)
	return time.Now().After(deadline)
}
package keystore
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"encoding/hex"
"encoding/json"
"fmt"
"path/filepath"
"strings"
libp2p_crypto "github.com/libp2p/go-libp2p/core/crypto"
"golang.org/x/crypto/scrypt"
"gitlab.com/nunet/device-management-service/lib/crypto"
)
// Keystore format parameters used when encrypting keys; decryption reads the
// KDF parameters back from the stored file and checks version and cipher
// against the values below.
var (
	// KDF parameters
	nameKDF      = "scrypt" // name of the key derivation function written to the file
	scryptKeyLen = 64       // derived key length in bytes: 32 for AES, 32 for the MAC
	scryptN      = 1 << 18  // scrypt cost parameter N
	scryptR      = 8        // scrypt parameter r
	scryptP      = 1        // scrypt parameter p
	ksVersion    = 3             // keystore file format version
	ksCipher     = "aes-256-ctr" // cipher identifier written to and expected in the file
)
// Key is an identified blob of key material as handled by the keystore.
type Key struct {
	ID   string
	Data []byte
}

// NewKey wraps id and data in a Key. The data slice is stored as is (not
// copied), and the returned error is always nil.
func NewKey(id string, data []byte) (*Key, error) {
	k := new(Key)
	k.ID = id
	k.Data = data
	return k, nil
}
// PrivKey interprets the key's Data as a marshalled private key and returns
// the unmarshalled private key.
//
// The underlying unmarshalling error is wrapped, so callers can inspect it
// with errors.Is / errors.As.
func (key *Key) PrivKey() (crypto.PrivKey, error) {
	priv, err := libp2p_crypto.UnmarshalPrivateKey(key.Data)
	if err != nil {
		return nil, fmt.Errorf("unable to unmarshal private key: %w", err)
	}
	return priv, nil
}
// MarshalToJSON encrypts and marshals a key to json byte array.
//
// The passphrase is stretched with scrypt using a fresh random salt. The first
// 32 bytes of the derived key encrypt key.Data with AES in CTR mode under a
// fresh random IV; the last 32 bytes are combined with the ciphertext via
// crypto.Sha3 to form the MAC stored alongside it.
//
// It returns ErrEmptyPassphrase for an empty passphrase, or the error of the
// failing step.
func (key *Key) MarshalToJSON(passphrase string) ([]byte, error) {
	if passphrase == "" {
		return nil, ErrEmptyPassphrase
	}
	// Fresh 64-byte salt for every encryption.
	salt, err := crypto.RandomEntropy(64)
	if err != nil {
		return nil, err
	}
	dk, err := scrypt.Key([]byte(passphrase), salt, scryptN, scryptR, scryptP, scryptKeyLen)
	if err != nil {
		return nil, err
	}
	// The CTR IV has the size of one AES block.
	iv, err := crypto.RandomEntropy(aes.BlockSize)
	if err != nil {
		return nil, err
	}
	// First half of the derived key is the AES-256 key.
	enckey := dk[:32]

	aesBlock, err := aes.NewCipher(enckey)
	if err != nil {
		return nil, err
	}
	stream := cipher.NewCTR(aesBlock, iv)
	cipherText := make([]byte, len(key.Data))
	stream.XORKeyStream(cipherText, key.Data)

	// Second half of the derived key authenticates the ciphertext.
	mac, err := crypto.Sha3(dk[32:64], cipherText)
	if err != nil {
		return nil, err
	}
	cipherParamsJSON := cipherparamsJSON{
		IV: hex.EncodeToString(iv),
	}

	// Store the KDF parameters so that decryption can re-derive the same key.
	sp := ScryptParams{
		N:          scryptN,
		R:          scryptR,
		P:          scryptP,
		DKeyLength: scryptKeyLen,
		Salt:       hex.EncodeToString(salt),
	}

	keyjson := cryptoJSON{
		Cipher:       ksCipher,
		CipherText:   hex.EncodeToString(cipherText),
		CipherParams: cipherParamsJSON,
		KDF:          nameKDF,
		KDFParams:    sp,
		MAC:          hex.EncodeToString(mac),
	}

	encjson := encryptedKeyJSON{
		Crypto:  keyjson,
		ID:      key.ID,
		Version: ksVersion,
	}
	data, err := json.MarshalIndent(&encjson, "", "  ")
	if err != nil {
		return nil, err
	}
	return data, nil
}
// UnmarshalKey decrypts and unmarshals the private key.
//
// data is the JSON document produced by Key.MarshalToJSON. The passphrase is
// stretched with scrypt using the parameters stored in the document, the MAC
// over the ciphertext is checked, and only then is the ciphertext decrypted.
//
// Errors:
//   - ErrEmptyPassphrase when passphrase is empty
//   - ErrVersionMismatch / ErrCipherMismatch when the document uses another
//     format version or cipher
//   - ErrMACMismatch when the passphrase is wrong or the document was altered
//   - a descriptive error when the document is malformed (bad JSON, bad hex,
//     wrong IV size, derived key too short)
func UnmarshalKey(data []byte, passphrase string) (*Key, error) {
	// Layout of the derived key: bytes [0,32) are the AES key, bytes [32,64)
	// are the MAC key.
	const (
		encKeyLen = 32
		minDKLen  = 64
	)

	if passphrase == "" {
		return nil, ErrEmptyPassphrase
	}

	encjson := encryptedKeyJSON{}
	if err := json.Unmarshal(data, &encjson); err != nil {
		return nil, fmt.Errorf("failed to unmarshal key data: %w", err)
	}
	if encjson.Version != ksVersion {
		return nil, ErrVersionMismatch
	}
	if encjson.Crypto.Cipher != ksCipher {
		return nil, ErrCipherMismatch
	}

	mac, err := hex.DecodeString(encjson.Crypto.MAC)
	if err != nil {
		return nil, fmt.Errorf("failed to decode mac: %w", err)
	}
	iv, err := hex.DecodeString(encjson.Crypto.CipherParams.IV)
	if err != nil {
		return nil, fmt.Errorf("failed to decode cipher params iv: %w", err)
	}
	// cipher.NewCTR panics on an IV of the wrong size; reject it up front,
	// before the expensive key derivation.
	if len(iv) != aes.BlockSize {
		return nil, fmt.Errorf("invalid cipher params iv length: got %d, want %d", len(iv), aes.BlockSize)
	}
	salt, err := hex.DecodeString(encjson.Crypto.KDFParams.Salt)
	if err != nil {
		return nil, fmt.Errorf("failed to decode salt: %w", err)
	}
	ciphertext, err := hex.DecodeString(encjson.Crypto.CipherText)
	if err != nil {
		return nil, fmt.Errorf("failed to decode cipher text: %w", err)
	}

	dk, err := scrypt.Key([]byte(passphrase), salt, encjson.Crypto.KDFParams.N, encjson.Crypto.KDFParams.R, encjson.Crypto.KDFParams.P, encjson.Crypto.KDFParams.DKeyLength)
	if err != nil {
		return nil, fmt.Errorf("failed to derive key: %w", err)
	}
	// The key length comes from the file; slicing a shorter key would panic.
	if len(dk) < minDKLen {
		return nil, fmt.Errorf("failed to derive key: derived key length %d is shorter than %d", len(dk), minDKLen)
	}

	hash, err := crypto.Sha3(dk[encKeyLen:minDKLen], ciphertext)
	if err != nil {
		return nil, fmt.Errorf("failed to hash key and ciphertext: %w", err)
	}
	if !bytes.Equal(hash, mac) {
		return nil, ErrMACMismatch
	}

	aesBlock, err := aes.NewCipher(dk[:encKeyLen])
	if err != nil {
		return nil, fmt.Errorf("failed to create cipher block: %w", err)
	}
	stream := cipher.NewCTR(aesBlock, iv)
	outputkey := make([]byte, len(ciphertext))
	stream.XORKeyStream(outputkey, ciphertext)

	return &Key{
		ID:   encjson.ID,
		Data: outputkey,
	}, nil
}
// removeFileExtension returns filename without its final extension, as
// determined by filepath.Ext. Names without an extension are returned
// unchanged.
func removeFileExtension(filename string) string {
	return strings.TrimSuffix(filename, filepath.Ext(filename))
}
package keystore
import (
"fmt"
"os"
"path/filepath"
"sync"
"github.com/spf13/afero"
"gitlab.com/nunet/device-management-service/utils"
)
// KeyStore manages a local keystore in which keys are stored encrypted with a
// passphrase.
type KeyStore interface {
	// Save encrypts data with passphrase and stores it under id; it returns
	// the name of the written file.
	Save(id string, data []byte, passphrase string) (string, error)
	// Get loads the key stored under keyID and decrypts it with passphrase.
	Get(keyID string, passphrase string) (*Key, error)
	// Delete removes the key stored under keyID.
	Delete(keyID string, passphrase string) error
	// ListKeys returns the identifiers of the stored keys.
	ListKeys() ([]string, error)
}
// BasicKeyStore handles keypair storage.
// Each key lives in its own "<id>.json" file inside keysDir on the given
// file system.
// TODO: add cache?
type BasicKeyStore struct {
	fs      afero.Fs     // file system holding the key files
	keysDir string       // directory that contains the key files
	mu      sync.RWMutex // guards file operations in keysDir
}

// Compile-time check that BasicKeyStore implements KeyStore.
var _ KeyStore = (*BasicKeyStore)(nil)
// New returns a BasicKeyStore that keeps its key files in keysDir on fs.
//
// The directory is created (mode 0700) when it does not exist yet. An empty
// keysDir yields ErrEmptyKeysDir.
func New(fs afero.Fs, keysDir string) (*BasicKeyStore, error) {
	if keysDir == "" {
		return nil, ErrEmptyKeysDir
	}

	err := fs.MkdirAll(keysDir, 0o700)
	if err != nil {
		return nil, fmt.Errorf("failed to create keystore directory: %w", err)
	}

	store := &BasicKeyStore{fs: fs, keysDir: keysDir}
	return store, nil
}
// Save encrypts a key and writes it to a file.
//
// data is encrypted with passphrase and written to "<id>.json" inside the
// keystore directory. The name of the written file is returned.
// An empty passphrase yields ErrEmptyPassphrase.
func (ks *BasicKeyStore) Save(id string, data []byte, passphrase string) (string, error) {
	if passphrase == "" {
		return "", ErrEmptyPassphrase
	}

	key := &Key{
		ID:   id,
		Data: data,
	}

	// Encrypt before taking the lock: key derivation is slow and touches no
	// shared state.
	keyDataJSON, err := key.MarshalToJSON(passphrase)
	if err != nil {
		return "", fmt.Errorf("failed to marshal key: %w", err)
	}

	// Serialize the write with Delete and with readers of the directory.
	ks.mu.Lock()
	defer ks.mu.Unlock()

	filename, err := utils.WriteToFile(ks.fs, keyDataJSON, filepath.Join(ks.keysDir, key.ID+".json"))
	if err != nil {
		return "", fmt.Errorf("failed to write key to file: %w", err)
	}

	return filename, nil
}
// Get unlocks a key by keyID.
//
// The file "<keyID>.json" is read from the keystore directory and decrypted
// with passphrase. ErrKeyNotFound is returned when the file does not exist.
func (ks *BasicKeyStore) Get(keyID string, passphrase string) (*Key, error) {
	// Hold the read lock only for the file access; decryption is slow and
	// works on the bytes already read.
	ks.mu.RLock()
	bts, err := afero.ReadFile(ks.fs, filepath.Join(ks.keysDir, keyID+".json"))
	ks.mu.RUnlock()
	if err != nil {
		if os.IsNotExist(err) {
			return nil, ErrKeyNotFound
		}
		return nil, fmt.Errorf("failed to read keystore file: %w", err)
	}

	key, err := UnmarshalKey(bts, passphrase)
	if err != nil {
		return nil, fmt.Errorf("failed to unmarshal keystore file: %w", err)
	}

	return key, nil
}
// Delete removes the key file that belongs to keyID.
//
// The file is only removed after it has been decrypted successfully with
// passphrase. ErrKeyNotFound is returned when the file does not exist.
func (ks *BasicKeyStore) Delete(keyID string, passphrase string) error {
	ks.mu.Lock()
	defer ks.mu.Unlock()

	target := filepath.Join(ks.keysDir, keyID+".json")

	raw, readErr := afero.ReadFile(ks.fs, target)
	switch {
	case readErr == nil:
		// file read, continue below
	case os.IsNotExist(readErr):
		return ErrKeyNotFound
	default:
		return fmt.Errorf("failed to read keystore file: %w", readErr)
	}

	// Proof of ownership: the caller must be able to decrypt the key.
	if _, err := UnmarshalKey(raw, passphrase); err != nil {
		return fmt.Errorf("invalid passphrase or corrupted key file: %w", err)
	}

	if err := ks.fs.Remove(target); err != nil {
		return fmt.Errorf("failed to delete key file: %w", err)
	}
	return nil
}
// ListKeys lists the keys in the keysDir.
//
// Every directory entry that can be read as a file is reported by its name
// without the file extension; entries that cannot be read (for example
// sub-directories) are skipped. The result is never nil on success.
func (ks *BasicKeyStore) ListKeys() ([]string, error) {
	// Keep the listing consistent with concurrent Delete calls.
	ks.mu.RLock()
	defer ks.mu.RUnlock()

	keys := make([]string, 0)
	dirEntries, err := afero.ReadDir(ks.fs, ks.keysDir)
	if err != nil {
		return nil, fmt.Errorf("failed to read keystore directory: %w", err)
	}
	for _, entry := range dirEntries {
		// Readability check only; the content is not inspected.
		if _, err := afero.ReadFile(ks.fs, filepath.Join(ks.keysDir, entry.Name())); err != nil {
			continue
		}
		keys = append(keys, removeFileExtension(entry.Name()))
	}
	return keys, nil
}
package did
// GetAnchorFunc creates an Anchor for a DID of one particular DID method.
type GetAnchorFunc func(did DID) (Anchor, error)

// anchorMethods maps a DID method name to the function that builds anchors
// for that method.
var anchorMethods map[string]GetAnchorFunc

// init registers the supported DID methods; currently only "key".
func init() {
	anchorMethods = map[string]GetAnchorFunc{
		"key": makeKeyAnchor,
	}
}
// GetAnchorForDID builds an Anchor for did using the factory registered for
// the DID's method. ErrNoAnchorMethod is returned when no factory is
// registered for that method.
func GetAnchorForDID(did DID) (Anchor, error) {
	if factory, registered := anchorMethods[did.Method()]; registered {
		return factory(did)
	}
	return nil, ErrNoAnchorMethod
}
// makeKeyAnchor is the anchor factory for the "key" method: it extracts the
// public key encoded in did and wraps both in an Anchor.
func makeKeyAnchor(did DID) (Anchor, error) {
	publicKey, err := PublicKeyFromDID(did)
	if err != nil {
		return nil, err
	}

	anchor := NewAnchor(did, publicKey)
	return anchor, nil
}
package did
import (
"context"
"fmt"
"sync"
"time"
"gitlab.com/nunet/device-management-service/lib/crypto"
)
// anchorEntryTTL is how long a cached anchor stays valid after it was added
// or last looked up.
const anchorEntryTTL = time.Hour

// Anchor is a DID anchor that encapsulates a public key that can be used
// for verification of signatures.
type Anchor interface {
	// DID returns the DID this anchor belongs to.
	DID() DID
	// Verify checks sig against data; a nil error means the signature is valid.
	Verify(data []byte, sig []byte) error
	// PublicKey returns the public key of the anchor.
	PublicKey() crypto.PubKey
}
// Provider holds the private key material necessary to sign statements for
// a DID.
type Provider interface {
	// DID returns the DID this provider signs for.
	DID() DID
	// Sign returns a signature over data.
	Sign(data []byte) ([]byte, error)
	// Anchor returns the anchor that verifies this provider's signatures.
	Anchor() Anchor
	// PrivateKey returns the private key, or an error when the implementation
	// cannot hand it out.
	PrivateKey() (crypto.PrivKey, error)
}
// TrustContext keeps track of the anchors and providers known for DIDs.
// The descriptions below follow BasicTrustContext, the implementation in this
// package.
type TrustContext interface {
	// Anchors returns the DIDs of the currently cached anchors.
	Anchors() []DID
	// Providers returns the DIDs for which a provider is registered.
	Providers() []DID
	// GetAnchor returns the anchor for did, resolving and caching it if needed.
	GetAnchor(did DID) (Anchor, error)
	// GetProvider returns the provider registered for did.
	GetProvider(did DID) (Provider, error)
	// AddAnchor caches anchor under its DID.
	AddAnchor(anchor Anchor)
	// AddProvider registers provider under its DID.
	AddProvider(provider Provider)
	// Start begins periodic removal of expired anchors every gcInterval.
	Start(gcInterval time.Duration)
	// Stop ends the periodic removal started by Start.
	Stop()
}
// anchorEntry is a cached anchor together with the time at which it expires.
type anchorEntry struct {
	anchor Anchor
	expire time.Time // entry is dropped by the gc once this time has passed
}
// BasicTrustContext is an in-memory TrustContext. Anchors are cached with an
// expiry time; providers are kept until the context is discarded.
type BasicTrustContext struct {
	mx        sync.Mutex           // guards all fields below
	anchors   map[DID]*anchorEntry // cached anchors by DID
	providers map[DID]Provider     // registered providers by DID
	stop      func()               // cancels the gc goroutine; nil when not running
}

// Compile-time check that BasicTrustContext implements TrustContext.
var _ TrustContext = (*BasicTrustContext)(nil)
// NewTrustContext returns an empty trust context without anchors or
// providers. Garbage collection of anchors is not running until Start is
// called.
func NewTrustContext() TrustContext {
	tc := new(BasicTrustContext)
	tc.anchors = map[DID]*anchorEntry{}
	tc.providers = map[DID]Provider{}
	return tc
}
// NewTrustContextWithPrivateKey returns a trust context that already contains
// the provider derived from privk.
func NewTrustContextWithPrivateKey(privk crypto.PrivKey) (TrustContext, error) {
	provider, err := ProviderFromPrivateKey(privk)
	if err != nil {
		return nil, fmt.Errorf("provide from private key: %w", err)
	}

	tc := NewTrustContext()
	tc.AddProvider(provider)
	return tc, nil
}
// NewTrustContextWithProvider returns a trust context that already contains
// the provider p.
func NewTrustContextWithProvider(p Provider) TrustContext {
	tc := NewTrustContext()
	tc.AddProvider(p)

	return tc
}
// Anchors returns the DIDs of all anchors currently cached, in no particular
// order. The returned slice is a fresh copy.
func (ctx *BasicTrustContext) Anchors() []DID {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	dids := make([]DID, len(ctx.anchors))
	next := 0
	for d := range ctx.anchors {
		dids[next] = d
		next++
	}
	return dids
}
// Providers returns the DIDs of all registered providers, in no particular
// order. The returned slice is a fresh copy.
func (ctx *BasicTrustContext) Providers() []DID {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	dids := make([]DID, len(ctx.providers))
	next := 0
	for d := range ctx.providers {
		dids[next] = d
		next++
	}
	return dids
}
// GetAnchor returns the anchor for did. A cached anchor is used when
// available; otherwise the anchor is resolved from the DID itself and added
// to the cache.
func (ctx *BasicTrustContext) GetAnchor(did DID) (Anchor, error) {
	if cached, hit := ctx.getAnchor(did); hit {
		return cached, nil
	}

	resolved, err := GetAnchorForDID(did)
	if err != nil {
		return nil, fmt.Errorf("get anchor for did: %w", err)
	}

	ctx.AddAnchor(resolved)
	return resolved, nil
}
// getAnchor looks up did in the anchor cache. A hit extends the entry's
// lifetime by anchorEntryTTL from now.
func (ctx *BasicTrustContext) getAnchor(did DID) (Anchor, bool) {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	entry, found := ctx.anchors[did]
	if !found {
		return nil, false
	}

	entry.expire = time.Now().Add(anchorEntryTTL)
	return entry.anchor, true
}
// GetProvider returns the provider registered for did, or ErrNoProvider when
// there is none.
func (ctx *BasicTrustContext) GetProvider(did DID) (Provider, error) {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	if p, registered := ctx.providers[did]; registered {
		return p, nil
	}
	return nil, ErrNoProvider
}
// AddAnchor caches anchor under its DID, replacing any previous entry. The
// entry expires anchorEntryTTL from now unless it is looked up again.
func (ctx *BasicTrustContext) AddAnchor(anchor Anchor) {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	entry := new(anchorEntry)
	entry.anchor = anchor
	entry.expire = time.Now().Add(anchorEntryTTL)
	ctx.anchors[anchor.DID()] = entry
}
// AddProvider registers provider under its DID, replacing any provider that
// was registered for the same DID before.
func (ctx *BasicTrustContext) AddProvider(provider Provider) {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	key := provider.DID()
	ctx.providers[key] = provider
}
// Start launches the goroutine that removes expired anchors every
// gcInterval. A collector that is already running is cancelled first, so
// calling Start again restarts collection with the new interval.
func (ctx *BasicTrustContext) Start(gcInterval time.Duration) {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	if previous := ctx.stop; previous != nil {
		previous()
	}

	gcCtx, cancel := context.WithCancel(context.Background())
	ctx.stop = cancel
	go ctx.gc(gcCtx, gcInterval)
}
// Stop cancels the anchor collector started by Start. It is a no-op when no
// collector is running.
func (ctx *BasicTrustContext) Stop() {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	if ctx.stop == nil {
		return
	}
	ctx.stop()
	ctx.stop = nil
}
// gc removes expired anchors every gcInterval until gcCtx is cancelled.
// It is meant to run in its own goroutine.
func (ctx *BasicTrustContext) gc(gcCtx context.Context, gcInterval time.Duration) {
	ticker := time.NewTicker(gcInterval)
	defer ticker.Stop()

	cancelled := gcCtx.Done()
	for {
		select {
		case <-cancelled:
			return
		case <-ticker.C:
			ctx.gcAnchorEntries()
		}
	}
}
// gcAnchorEntries drops every cached anchor whose expiry time lies in the past.
func (ctx *BasicTrustContext) gcAnchorEntries() {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	cutoff := time.Now()
	for did, entry := range ctx.anchors {
		if cutoff.After(entry.expire) {
			delete(ctx.anchors, did)
		}
	}
}
package did
import (
"strings"
)
// DID is a decentralized identifier held as its URI string. Method and
// Identifier treat the URI as exactly three colon-separated parts
// ("<scheme>:<method>:<identifier>").
type DID struct {
	URI string `json:"uri,omitempty"`
}

// Equal reports whether both DIDs have the same URI.
func (did DID) Equal(other DID) bool {
	return other.URI == did.URI
}

// Empty reports whether the DID has no URI.
func (did DID) Empty() bool {
	return len(did.URI) == 0
}

// String returns the URI of the DID.
func (did DID) String() string {
	return did.URI
}

// Method returns the second colon-separated part of the URI, or "" when the
// URI does not consist of exactly three parts.
func (did DID) Method() string {
	return did.segment(1)
}

// Identifier returns the third colon-separated part of the URI, or "" when
// the URI does not consist of exactly three parts.
func (did DID) Identifier() string {
	return did.segment(2)
}

// segment returns part i of a three-part URI and "" for any other URI shape.
func (did DID) segment(i int) string {
	if pieces := strings.Split(did.URI, ":"); len(pieces) == 3 {
		return pieces[i]
	}
	return ""
}
// FromString converts s into a DID.
//
// The empty string is accepted and yields the empty DID. Any other input must
// consist of exactly three non-empty, colon-separated parts; otherwise
// ErrInvalidDID is returned.
func FromString(s string) (DID, error) {
	if s == "" {
		return DID{URI: s}, nil
	}

	pieces := strings.Split(s, ":")
	if len(pieces) != 3 {
		return DID{}, ErrInvalidDID
	}
	for i := range pieces {
		if pieces[i] == "" {
			return DID{}, ErrInvalidDID
		}
	}

	// TODO validate parts according to spec: https://www.w3.org/TR/did-core/
	return DID{URI: s}, nil
}
package did
import (
"fmt"
"strings"
libp2p_crypto "github.com/libp2p/go-libp2p/core/crypto"
mb "github.com/multiformats/go-multibase"
varint "github.com/multiformats/go-varint"
"gitlab.com/nunet/device-management-service/lib/crypto"
)
// PublicKeyAnchor is an Anchor backed by a public key held in memory.
type PublicKeyAnchor struct {
	did  DID
	pubk crypto.PubKey
}

// Compile-time check that PublicKeyAnchor implements Anchor.
var _ Anchor = (*PublicKeyAnchor)(nil)

// PrivateKeyProvider is a Provider backed by a private key held in memory.
type PrivateKeyProvider struct {
	did   DID
	privk crypto.PrivKey
}

// Compile-time check that PrivateKeyProvider implements Provider.
var _ Provider = (*PrivateKeyProvider)(nil)
// NewAnchor returns an Anchor for did that verifies signatures with pubk.
// No check is made that pubk actually corresponds to did.
func NewAnchor(did DID, pubk crypto.PubKey) Anchor {
	a := new(PublicKeyAnchor)
	a.did = did
	a.pubk = pubk
	return a
}
// NewProvider returns a Provider for did that signs with privk.
// No check is made that privk actually corresponds to did.
func NewProvider(did DID, privk crypto.PrivKey) Provider {
	p := new(PrivateKeyProvider)
	p.did = did
	p.privk = privk
	return p
}
// DID returns the DID this anchor was created for.
func (a *PublicKeyAnchor) DID() DID {
	return a.did
}

// Verify checks sig over data with the anchor's public key. It returns the
// verification error if the key reports one, ErrInvalidSignature when the
// signature does not match, and nil when it does.
func (a *PublicKeyAnchor) Verify(data []byte, sig []byte) error {
	switch valid, err := a.pubk.Verify(data, sig); {
	case err != nil:
		return err
	case !valid:
		return ErrInvalidSignature
	default:
		return nil
	}
}

// PublicKey returns the public key held by the anchor.
func (a *PublicKeyAnchor) PublicKey() crypto.PubKey {
	return a.pubk
}
// DID returns the DID this provider was created for.
func (p *PrivateKeyProvider) DID() DID {
	return p.did
}

// Sign signs data with the provider's private key.
func (p *PrivateKeyProvider) Sign(data []byte) ([]byte, error) {
	return p.privk.Sign(data)
}

// PrivateKey returns the private key held by the provider; the error is
// always nil for this implementation.
func (p *PrivateKeyProvider) PrivateKey() (crypto.PrivKey, error) {
	return p.privk, nil
}

// Anchor returns a new anchor for the provider's DID, built from the public
// half of the private key.
func (p *PrivateKeyProvider) Anchor() Anchor {
	return NewAnchor(p.did, p.privk.GetPublic())
}
// FromID derives the DID of the public key that is embedded in id.
func FromID(id crypto.ID) (DID, error) {
	publicKey, err := crypto.PublicKeyFromID(id)
	if err != nil {
		return DID{}, fmt.Errorf("public key from id: %w", err)
	}

	result := FromPublicKey(publicKey)
	return result, nil
}
// FromPublicKey returns the DID whose URI is the key URI of pubk as produced
// by FormatKeyURI. If the key cannot be formatted, the resulting DID is empty.
func FromPublicKey(pubk crypto.PubKey) DID {
	return DID{URI: FormatKeyURI(pubk)}
}
// PublicKeyFromDID extracts the public key encoded in a DID of method "key".
// DIDs of any other method yield ErrInvalidDID.
func PublicKeyFromDID(did DID) (crypto.PubKey, error) {
	if method := did.Method(); method != "key" {
		return nil, ErrInvalidDID
	}

	publicKey, err := ParseKeyURI(did.URI)
	if err != nil {
		return nil, fmt.Errorf("parsing did key identifier: %w", err)
	}
	return publicKey, nil
}
// AnchorFromPublicKey returns an anchor for the DID derived from pubk.
// The returned error is always nil.
func AnchorFromPublicKey(pubk crypto.PubKey) (Anchor, error) {
	return NewAnchor(FromPublicKey(pubk), pubk), nil
}
// ProviderFromPrivateKey returns a provider for the DID derived from the
// public half of privk. The returned error is always nil.
func ProviderFromPrivateKey(privk crypto.PrivKey) (Provider, error) {
	publicKey := privk.GetPublic()
	return NewProvider(FromPublicKey(publicKey), privk), nil
}
// Note: this code originated in https://github.com/ucan-wg/go-ucan/blob/main/didkey/key.go
// Copyright applies; some superficial modifications by vyzo.

// Codes that prefix the raw public key inside a did:key identifier, and the
// URI prefix shared by all did:key URIs.
const (
	multicodecKindEd25519PubKey   uint64 = 0xed   // Ed25519 public key
	multicodecKindSecp256k1PubKey uint64 = 0xe7   // secp256k1 public key
	multicodecKindEthPubKey       uint64 = 0xef01 // Eth public key
	keyPrefix                            = "did:key"
)
// FormatKeyURI encodes pubk as a did:key URI.
//
// The raw key bytes are prefixed with the varint-encoded code of the key type
// and the result is multibase-encoded as base58btc. Ed25519, Secp256k1 and Eth
// keys are supported.
//
// The empty string is returned when the raw key cannot be obtained, the key
// type is unsupported (this case is logged), or encoding fails.
func FormatKeyURI(pubk crypto.PubKey) string {
	raw, err := pubk.Raw()
	if err != nil {
		return ""
	}

	var t uint64
	switch pubk.Type() {
	case crypto.Ed25519:
		t = multicodecKindEd25519PubKey
	case crypto.Secp256k1:
		t = multicodecKindSecp256k1PubKey
	case crypto.Eth:
		t = multicodecKindEthPubKey
	default:
		// we don't support those yet; report the offending key type (t is
		// still unset at this point)
		log.Errorf("unsupported key type: %d", pubk.Type())
		return ""
	}

	// Layout: varint(type code) followed by the raw key bytes.
	size := varint.UvarintSize(t)
	data := make([]byte, size+len(raw))
	n := varint.PutUvarint(data, t)
	copy(data[n:], raw)

	b58BKeyStr, err := mb.Encode(mb.Base58BTC, data)
	if err != nil {
		return ""
	}

	return fmt.Sprintf("%s:%s", keyPrefix, b58BKeyStr)
}
// ParseKeyURI decodes the public key contained in a did:key URI.
//
// The URI must start with "did:key:" and carry a base58btc multibase payload
// made of a varint key type code followed by the raw key bytes. Ed25519,
// Secp256k1 and Eth keys are supported; other type codes yield
// ErrInvalidKeyType.
func ParseKeyURI(uri string) (crypto.PubKey, error) {
	// Match the full "did:key:" prefix, separator included, so that other
	// methods that merely start with "key" are rejected here.
	prefix := keyPrefix + ":"
	if !strings.HasPrefix(uri, prefix) {
		return nil, fmt.Errorf("decentralized identifier is not a 'key' type")
	}

	uri = strings.TrimPrefix(uri, prefix)

	enc, data, err := mb.Decode(uri)
	if err != nil {
		return nil, fmt.Errorf("decoding multibase: %w", err)
	}

	if enc != mb.Base58BTC {
		return nil, fmt.Errorf("unexpected multibase encoding: %s", mb.EncodingToStr[enc])
	}

	// n is the number of bytes taken by the type code; the key follows it.
	keyType, n, err := varint.FromUvarint(data)
	if err != nil {
		return nil, err
	}

	switch keyType {
	case multicodecKindEd25519PubKey:
		return libp2p_crypto.UnmarshalEd25519PublicKey(data[n:])
	case multicodecKindSecp256k1PubKey:
		return libp2p_crypto.UnmarshalSecp256k1PublicKey(data[n:])
	case multicodecKindEthPubKey:
		return crypto.UnmarshalEthPublicKey(data[n:])
	default:
		return nil, ErrInvalidKeyType
	}
}
package did
import (
"encoding/hex"
"encoding/json"
"fmt"
"os"
"os/exec"
"github.com/decred/dcrd/dcrec/secp256k1/v4"
"github.com/decred/dcrd/dcrec/secp256k1/v4/ecdsa"
"gitlab.com/nunet/device-management-service/lib/crypto"
)
// ledgerCLI is the name of the external executable, looked up in PATH, that
// talks to the ledger wallet.
const ledgerCLI = "ledger-cli"

// LedgerWalletProvider is a Provider whose signatures are produced by the
// ledger command line tool; only the public key is held in memory.
type LedgerWalletProvider struct {
	did  DID
	pubk crypto.PubKey
	acct int // account number passed to the ledger tool with -a
}

// Compile-time check that LedgerWalletProvider implements Provider.
var _ Provider = (*LedgerWalletProvider)(nil)
// LedgerKeyOutput is the JSON document read back after the ledger tool's
// "key" command.
type LedgerKeyOutput struct {
	Key     string `json:"key"` // hex-encoded public key
	Address string `json:"address"`
}

// LedgerSignOutput is the JSON document read back after the ledger tool's
// "sign" command.
type LedgerSignOutput struct {
	ECDSA LedgerSignECDSAOutput `json:"ecdsa"`
}

// LedgerSignECDSAOutput holds the components of an ECDSA signature; R and S
// are hex-encoded.
type LedgerSignECDSAOutput struct {
	V uint   `json:"v"`
	R string `json:"r"`
	S string `json:"s"`
}
// NewLedgerWalletProvider asks the ledger tool for the public key of account
// acct and returns a provider for the DID derived from that key.
//
// The tool writes its answer to a temporary file, which is removed again
// before the function returns.
func NewLedgerWalletProvider(acct int) (Provider, error) {
	outFile, err := getLedgerTmpFile()
	if err != nil {
		return nil, err
	}
	defer os.Remove(outFile)

	var keyInfo LedgerKeyOutput
	account := fmt.Sprintf("%d", acct)
	err = ledgerExec(outFile, &keyInfo, "key", "-o", outFile, "-a", account)
	if err != nil {
		return nil, fmt.Errorf("error executing ledger cli: %w", err)
	}

	// decode the hex key
	rawKey, err := hex.DecodeString(keyInfo.Key)
	if err != nil {
		return nil, fmt.Errorf("decode ledger key: %w", err)
	}

	publicKey, err := crypto.UnmarshalEthPublicKey(rawKey)
	if err != nil {
		return nil, fmt.Errorf("unmarshal ledger raw key: %w", err)
	}

	provider := &LedgerWalletProvider{
		did:  FromPublicKey(publicKey),
		pubk: publicKey,
		acct: acct,
	}
	return provider, nil
}
// ledgerExec runs the ledger tool with args and decodes the JSON document it
// leaves in the file tmp into output.
//
// output must be a non-nil pointer to the value to fill. The tool's standard
// output and standard error are connected to those of the current process.
// The caller is responsible for passing tmp to the tool (via "-o") and for
// removing the file afterwards.
func ledgerExec(tmp string, output interface{}, args ...string) error {
	ledger, err := exec.LookPath(ledgerCLI)
	if err != nil {
		return fmt.Errorf("can't find %s in PATH: %w", ledgerCLI, err)
	}

	cmd := exec.Command(ledger, args...)
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr

	// This helper serves every ledger command, so the message must not claim
	// a specific one.
	if err := cmd.Run(); err != nil {
		return fmt.Errorf("run %s: %w", ledgerCLI, err)
	}

	f, err := os.Open(tmp)
	if err != nil {
		return fmt.Errorf("open ledger output: %w", err)
	}
	defer f.Close()

	// output already is a pointer; decode straight into it.
	decoder := json.NewDecoder(f)
	if err := decoder.Decode(output); err != nil {
		return fmt.Errorf("parse ledger output: %w", err)
	}

	return nil
}
// getLedgerTmpFile creates an empty temporary file for the ledger tool's
// output and returns its path. The file is closed before returning; removing
// it is the caller's responsibility.
func getLedgerTmpFile() (string, error) {
	tmp, err := os.CreateTemp("", "ledger.out")
	if err != nil {
		return "", fmt.Errorf("creating temporary file: %w", err)
	}

	tmpPath := tmp.Name()
	if err := tmp.Close(); err != nil {
		// Do not leave a stray file behind when it cannot be handed out.
		_ = os.Remove(tmpPath)
		return "", fmt.Errorf("closing temporary file: %w", err)
	}

	return tmpPath, nil
}
// DID returns the DID derived from the ledger account's public key.
func (p *LedgerWalletProvider) DID() DID {
	return p.did
}
// Sign has the ledger tool sign data with the provider's account and returns
// the serialized ECDSA signature.
//
// The data is handed to the tool hex-encoded. The r and s components the tool
// reports are hex-decoded and assembled into a signature; the v component is
// not used.
func (p *LedgerWalletProvider) Sign(data []byte) ([]byte, error) {
	outFile, err := getLedgerTmpFile()
	if err != nil {
		return nil, err
	}
	defer os.Remove(outFile)

	var signed LedgerSignOutput
	account := fmt.Sprintf("%d", p.acct)
	payload := hex.EncodeToString(data)
	err = ledgerExec(outFile, &signed, "sign", "-o", outFile, "-a", account, payload)
	if err != nil {
		return nil, fmt.Errorf("error executing ledger cli: %w", err)
	}

	rRaw, err := hex.DecodeString(signed.ECDSA.R)
	if err != nil {
		return nil, fmt.Errorf("error decoding signature r: %w", err)
	}
	sRaw, err := hex.DecodeString(signed.ECDSA.S)
	if err != nil {
		return nil, fmt.Errorf("error decoding signature s: %w", err)
	}

	var r, s secp256k1.ModNScalar
	if r.SetByteSlice(rRaw) {
		return nil, fmt.Errorf("signature r overflowed")
	}
	if s.SetByteSlice(sRaw) {
		return nil, fmt.Errorf("signature s overflowed")
	}

	return ecdsa.NewSignature(&r, &s).Serialize(), nil
}
// Anchor returns a new anchor for the provider's DID and public key.
func (p *LedgerWalletProvider) Anchor() Anchor {
	return NewAnchor(p.did, p.pubk)
}

// PrivateKey always fails with an error wrapping ErrHardwareKey: the provider
// holds no private key material.
func (p *LedgerWalletProvider) PrivateKey() (crypto.PrivKey, error) {
	return nil, fmt.Errorf("ledger private key cannot be exported: %w", ErrHardwareKey)
}
package ucan
import (
"strings"
)
// Capability is a slash-separated capability path.
type Capability string

// Root is the capability that implies every other capability.
const Root = Capability("/")

// Implies reports whether holding c is sufficient for other. That is the case
// when both are equal, when c is Root, or when other lies below c, i.e. other
// starts with c followed by a slash.
func (c Capability) Implies(other Capability) bool {
	switch c {
	case other, Root:
		return true
	}

	parent := string(c) + "/"
	return strings.HasPrefix(string(other), parent)
}
package ucan
import (
"bytes"
"context"
"crypto/rand"
"encoding/json"
"fmt"
"slices"
"sync"
"time"
"gitlab.com/nunet/device-management-service/lib/crypto"
"gitlab.com/nunet/device-management-service/lib/did"
)
const (
	// maxCapabilitySize is a size limit of 16384; it is not used within this
	// part of the file.
	maxCapabilitySize = 16384

	// Self-sign modes accepted by the Delegate* methods.
	//
	// NOTE: iota counts the constant specifications of the enclosing const
	// block, and maxCapabilitySize occupies index 0. SelfSignNo is therefore
	// 1, SelfSignAlso 2 and SelfSignOnly 3; the zero value of SelfSignMode
	// matches none of them.
	SelfSignNo SelfSignMode = iota
	SelfSignAlso
	SelfSignOnly
)

// SelfSignMode selects how the Delegate* methods treat self-signed tokens.
// NOTE(review): the exact meaning of each mode is defined by those methods,
// which are outside this part of the file.
type SelfSignMode int
// CapabilityContext manages capability tokens on behalf of a controlling DID:
// it consumes and checks tokens presented by others and issues tokens that
// prove or delegate authority.
type CapabilityContext interface {
	// DID returns the context's controlling DID
	DID() did.DID

	// Trust returns the context's did trust context
	Trust() did.TrustContext

	// Consume ingests the provided capability tokens
	Consume(origin did.DID, cap []byte) error

	// Discard discards previously consumed capability tokens
	Discard(cap []byte)

	// Require ensures that at least one of the capabilities is delegated from
	// the subject to the audience, with an appropriate anchor
	// An empty list will mean that no capabilities are required and is vacuously
	// true.
	Require(anchor did.DID, subject crypto.ID, audience crypto.ID, require []Capability) error

	// RequireBroadcast ensures that at least one of the capabilities is delegated
	// to the subject for the specified broadcast topics
	RequireBroadcast(origin did.DID, subject crypto.ID, topic string, require []Capability) error

	// Provide prepares the appropriate capability tokens to prove and delegate authority
	// to a subject for an audience.
	// - It delegates invocations to the subject with an audience and invoke capabilities
	// - It delegates the delegate capabilities to the target with audience the subject
	Provide(target did.DID, subject crypto.ID, audience crypto.ID, expire uint64, invoke []Capability, delegate []Capability) ([]byte, error)

	// ProvideBroadcast prepares the appropriate capability tokens to prove authority
	// to broadcast to a topic
	ProvideBroadcast(subject crypto.ID, topic string, expire uint64, broadcast []Capability) ([]byte, error)

	// AddRoots adds trust anchors
	AddRoots(trust []did.DID, require, provide TokenList) error

	// ListRoots lists the current trust anchors
	ListRoots() ([]did.DID, TokenList, TokenList)

	// RemoveRoots removes the specified trust anchors
	RemoveRoots(trust []did.DID, require, provide TokenList)

	// Delegate creates the appropriate delegation tokens anchored in our roots
	Delegate(subject, audience did.DID, topics []string, expire, depth uint64, cap []Capability, selfSign SelfSignMode) (TokenList, error)

	// DelegateInvocation creates the appropriate invocation tokens anchored in anchor
	DelegateInvocation(target, subject, audience did.DID, expire uint64, provide []Capability, selfSign SelfSignMode) (TokenList, error)

	// DelegateBroadcast creates the appropriate broadcast token anchored in our roots
	DelegateBroadcast(subject did.DID, topic string, expire uint64, provide []Capability, selfSign SelfSignMode) (TokenList, error)

	// Grant creates the appropriate delegation tokens considering ourselves as the root
	Grant(action Action, subject, audience did.DID, topic []string, expire, depth uint64, provide []Capability) (TokenList, error)

	// Start starts a token garbage collector goroutine that clears expired tokens
	Start(gcInterval time.Duration)

	// Stop stops a previously started gc goroutine
	Stop()
}
// BasicCapabilityContext is the in-memory implementation of CapabilityContext.
type BasicCapabilityContext struct {
	mx       sync.Mutex           // guards the maps below
	provider did.Provider         // signs tokens for the controlling DID
	trust    did.TrustContext     // used to verify tokens
	roots    map[did.DID]struct{} // our root anchors of trust
	require  map[did.DID][]*Token // our acceptance side-roots
	provide  map[did.DID][]*Token // root capabilities -> tokens
	tokens   map[did.DID][]*Token // our context dependent capabilities; subject -> tokens

	stop func() // cancels the gc goroutine; nil when none was started
}

// Compile-time check that BasicCapabilityContext implements CapabilityContext.
var _ CapabilityContext = (*BasicCapabilityContext)(nil)
// NewCapabilityContext creates a capability context controlled by ctxDID.
//
// The trust context must hold a provider for ctxDID. The given roots and the
// require/provide tokens are installed through AddRoots, so every token has
// to verify against the trust context.
func NewCapabilityContext(trust did.TrustContext, ctxDID did.DID, roots []did.DID, require, provide TokenList) (CapabilityContext, error) {
	provider, err := trust.GetProvider(ctxDID)
	if err != nil {
		return nil, fmt.Errorf("new capability context: %w", err)
	}

	capCtx := &BasicCapabilityContext{
		provider: provider,
		trust:    trust,
		roots:    map[did.DID]struct{}{},
		require:  map[did.DID][]*Token{},
		provide:  map[did.DID][]*Token{},
		tokens:   map[did.DID][]*Token{},
	}

	if err := capCtx.AddRoots(roots, require, provide); err != nil {
		return nil, fmt.Errorf("new capability context: %w", err)
	}
	return capCtx, nil
}
// DID returns the DID of the context's provider, i.e. the controlling DID.
func (ctx *BasicCapabilityContext) DID() did.DID {
	return ctx.provider.DID()
}

// Trust returns the trust context the capability context was created with.
func (ctx *BasicCapabilityContext) Trust() did.TrustContext {
	return ctx.trust
}
// Start launches the token garbage collector goroutine, which runs every
// gcInterval until Stop is called.
//
// A collector that is already running is cancelled first, so calling Start
// again restarts collection with the new interval.
func (ctx *BasicCapabilityContext) Start(gcInterval time.Duration) {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	// Cancel a previous collector instead of leaking it. The earlier guard
	// only started the collector when one was already present, so it never
	// ran at all.
	if ctx.stop != nil {
		ctx.stop()
	}

	gcCtx, cancel := context.WithCancel(context.Background())
	ctx.stop = cancel
	go ctx.gc(gcCtx, gcInterval)
}
// Stop cancels the garbage collector goroutine started by Start. It is a
// no-op when no collector is running, and it may be called repeatedly.
func (ctx *BasicCapabilityContext) Stop() {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	if ctx.stop != nil {
		ctx.stop()
		// Forget the cancelled collector so that the state reflects that
		// nothing is running any more.
		ctx.stop = nil
	}
}
// AddRoots adds roots as trust anchors and installs the given require and
// provide tokens.
//
// Every token is verified against the trust context at the current time
// before it is installed. Processing stops at the first token that fails;
// the roots and the tokens handled before it stay installed.
func (ctx *BasicCapabilityContext) AddRoots(roots []did.DID, require, provide TokenList) error {
	ctx.addRoots(roots)

	now := uint64(time.Now().UnixNano())
	check := func(t *Token) error {
		if err := t.Verify(ctx.trust, now); err != nil {
			return fmt.Errorf("verify token: %w", err)
		}
		return nil
	}

	for _, token := range require.Tokens {
		if err := check(token); err != nil {
			return err
		}
		ctx.consumeRequireToken(token)
	}

	for _, token := range provide.Tokens {
		if err := check(token); err != nil {
			return err
		}
		ctx.consumeProvideToken(token)
	}

	return nil
}
// ListRoots returns the current trust anchors together with all require
// tokens and all provide tokens, each collected across their anchors.
func (ctx *BasicCapabilityContext) ListRoots() ([]did.DID, TokenList, TokenList) {
	var (
		requireTokens []*Token
		provideTokens []*Token
	)

	roots := ctx.getRoots()

	for _, anchor := range ctx.getRequireAnchors() {
		requireTokens = append(requireTokens, ctx.getRequireTokens(anchor)...)
	}
	for _, anchor := range ctx.getProvideAnchors() {
		provideTokens = append(provideTokens, ctx.getProvideTokens(anchor)...)
	}

	return roots, TokenList{Tokens: requireTokens}, TokenList{Tokens: provideTokens}
}
// RemoveRoots removes the given root DIDs and drops the listed require and
// provide anchor tokens. A stored token is dropped when it is filed under the
// same issuer as a listed token and carries the same nonce.
//
// The stored slice is cloned before filtering: getTokens hands the stored
// slice to callers that iterate it without holding the lock, so filtering it
// in place would mutate a backing array that may be concurrently read. This
// matches what discardTokens and gcTokens already do.
func (ctx *BasicCapabilityContext) RemoveRoots(trust []did.DID, require, provide TokenList) {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	for _, root := range trust {
		delete(ctx.roots, root)
	}

	// prune drops every token in removed from store, deleting map entries
	// that become empty.
	prune := func(store map[did.DID][]*Token, removed []*Token) {
		for _, t := range removed {
			issuer := t.Issuer()
			current, ok := store[issuer]
			if !ok {
				continue
			}

			remaining := slices.DeleteFunc(slices.Clone(current), func(ot *Token) bool {
				return bytes.Equal(t.Nonce(), ot.Nonce())
			})

			if len(remaining) > 0 {
				store[issuer] = remaining
			} else {
				delete(store, issuer)
			}
		}
	}

	prune(ctx.require, require.Tokens)
	prune(ctx.provide, provide.Tokens)
}
// Grant issues a single self-signed token (no chain) with this context's DID
// as issuer. The token carries a fresh random nonce of nonceLength bytes, the
// given action, subject, audience, topics, capabilities, expiry and depth.
// It returns a TokenList holding exactly that token.
func (ctx *BasicCapabilityContext) Grant(action Action, subject, audience did.DID, topics []string, expire, depth uint64, provide []Capability) (TokenList, error) {
	nonce := make([]byte, nonceLength)
	if _, err := rand.Read(nonce); err != nil {
		return TokenList{}, fmt.Errorf("nonce: %w", err)
	}

	topicCaps := make([]Capability, len(topics))
	for i, topic := range topics {
		topicCaps[i] = Capability(topic)
	}

	granted := &DMSToken{
		Issuer:     ctx.DID(),
		Subject:    subject,
		Audience:   audience,
		Action:     action,
		Topic:      topicCaps,
		Capability: provide,
		Nonce:      nonce,
		Expire:     expire,
		Depth:      depth,
	}

	payload, err := granted.SignatureData()
	if err != nil {
		return TokenList{}, fmt.Errorf("grant: %w", err)
	}

	signature, err := ctx.provider.Sign(payload)
	if err != nil {
		return TokenList{}, fmt.Errorf("sign: %w", err)
	}
	granted.Signature = signature

	return TokenList{Tokens: []*Token{{DMS: granted}}}, nil
}
// Delegate issues delegation tokens for the provide capabilities.
//
// Unless selfSign is SelfSignOnly, every provide-anchor token that is anchored
// on its trust anchor and allows delegating a capability is used to issue a
// chained token for the allowed subset. With SelfSignNo the chained tokens are
// returned as-is, or ErrNotAuthorized if there are none. Otherwise a
// self-signed token from Grant is appended to the result.
//
// An empty provide list yields an empty TokenList and no error.
func (ctx *BasicCapabilityContext) Delegate(subject, audience did.DID, topics []string, expire, depth uint64, provide []Capability, selfSign SelfSignMode) (TokenList, error) {
	if len(provide) == 0 {
		return TokenList{}, nil
	}

	topicCap := make([]Capability, 0, len(topics))
	for _, topic := range topics {
		topicCap = append(topicCap, Capability(topic))
	}

	var result []*Token

	if selfSign != SelfSignOnly {
		for _, trustAnchor := range ctx.getProvideAnchors() {
			for _, t := range ctx.getProvideTokens(trustAnchor) {
				var providing []Capability
				for _, c := range provide {
					if t.Anchor(trustAnchor) && t.AllowDelegation(Delegate, ctx.DID(), audience, topicCap, expire, c) {
						providing = append(providing, c)
					}
				}

				if len(providing) == 0 {
					continue
				}

				token, err := t.Delegate(ctx.provider, subject, audience, topicCap, expire, depth, providing)
				if err != nil {
					// best effort: one failing chain must not block the others
					log.Debugf("error delegating %s to %s: %s", providing, subject, err)
					continue
				}

				result = append(result, token)
			}
		}

		if selfSign == SelfSignNo {
			if len(result) == 0 {
				return TokenList{}, ErrNotAuthorized
			}
			return TokenList{Tokens: result}, nil
		}
	}

	tokens, err := ctx.Grant(Delegate, subject, audience, topics, expire, depth, provide)
	if err != nil {
		// this path grants a delegation, not an invocation
		return TokenList{}, fmt.Errorf("error granting delegation: %w", err)
	}

	result = append(result, tokens.Tokens...)
	return TokenList{Tokens: result}, nil
}
// DelegateInvocation issues invocation tokens for the provide capabilities.
//
// Tokens held about this context's own DID and anchored on target are always
// tried first. Unless selfSign is SelfSignOnly, the provide-anchor tokens are
// tried next. With SelfSignNo the chained tokens are returned as-is, or
// ErrNotAuthorized if there are none. Otherwise a self-signed Invoke token
// from Grant is appended.
//
// An empty provide list yields an empty TokenList and no error.
func (ctx *BasicCapabilityContext) DelegateInvocation(target, subject, audience did.DID, expire uint64, provide []Capability, selfSign SelfSignMode) (TokenList, error) {
	if len(provide) == 0 {
		return TokenList{}, nil
	}

	// first get tokens we have about ourselves and see if any allows delegation to
	// the subject for the audience
	ownTokens := ctx.getSubjectTokens(ctx.DID())
	result := append([]*Token(nil), ctx.delegateInvocation(ownTokens, target, subject, audience, expire, provide)...)

	if selfSign != SelfSignOnly {
		// then we issue tokens chained on our provide anchors as appropriate
		for _, trustAnchor := range ctx.getProvideAnchors() {
			anchorTokens := ctx.getProvideTokens(trustAnchor)
			chained := ctx.delegateInvocation(anchorTokens, trustAnchor, subject, audience, expire, provide)
			result = append(result, chained...)
		}

		if selfSign == SelfSignNo {
			if len(result) == 0 {
				return TokenList{}, ErrNotAuthorized
			}
			return TokenList{Tokens: result}, nil
		}
	}

	selfTokens, err := ctx.Grant(Invoke, subject, audience, nil, expire, 0, provide)
	if err != nil {
		return TokenList{}, fmt.Errorf("error granting invocation: %w", err)
	}

	result = append(result, selfTokens.Tokens...)
	return TokenList{Tokens: result}, nil
}
// delegateInvocation tries to chain an invocation token on each token in
// tokenList. A token is used only if it is anchored on anchor and allows
// delegating at least one of the provide capabilities; the chained token then
// carries exactly the allowed subset. Delegation errors are logged and skipped.
func (ctx *BasicCapabilityContext) delegateInvocation(tokenList []*Token, anchor, subject, audience did.DID, expire uint64, provide []Capability) []*Token {
	var result []*Token //nolint

	for _, t := range tokenList {
		if !t.Anchor(anchor) {
			continue
		}

		var allowed []Capability
		for _, c := range provide {
			if t.AllowDelegation(Invoke, ctx.DID(), audience, nil, expire, c) {
				allowed = append(allowed, c)
			}
		}
		if len(allowed) == 0 {
			continue
		}

		chained, err := t.DelegateInvocation(ctx.provider, subject, audience, expire, allowed)
		if err != nil {
			log.Debugf("error delegating invocation %s to %s: %s", allowed, subject, err)
			continue
		}
		result = append(result, chained)
	}

	return result
}
// DelegateBroadcast issues broadcast tokens for the provide capabilities on
// the given topic.
//
// Unless selfSign is SelfSignOnly, tokens are chained on the provide anchors.
// With SelfSignNo the chained tokens are returned as-is, or ErrNotAuthorized
// if there are none. Otherwise a self-signed Broadcast token (empty audience)
// from Grant is appended.
//
// An empty provide list yields an empty TokenList and no error.
func (ctx *BasicCapabilityContext) DelegateBroadcast(subject did.DID, topic string, expire uint64, provide []Capability, selfSign SelfSignMode) (TokenList, error) {
	if len(provide) == 0 {
		return TokenList{}, nil
	}

	var result []*Token

	if selfSign != SelfSignOnly {
		// first we issue tokens chained on our provide anchors as appropriate
		for _, trustAnchor := range ctx.getProvideAnchors() {
			anchorTokens := ctx.getProvideTokens(trustAnchor)
			chained := ctx.delegateBroadcast(anchorTokens, trustAnchor, subject, topic, expire, provide)
			result = append(result, chained...)
		}

		if selfSign == SelfSignNo {
			if len(result) == 0 {
				return TokenList{}, ErrNotAuthorized
			}
			return TokenList{Tokens: result}, nil
		}
	}

	selfTokens, err := ctx.Grant(Broadcast, subject, did.DID{}, []string{topic}, expire, 0, provide)
	if err != nil {
		return TokenList{}, fmt.Errorf("error granting broadcast: %w", err)
	}

	result = append(result, selfTokens.Tokens...)
	return TokenList{Tokens: result}, nil
}
// delegateBroadcast tries to chain a broadcast token for topic on each token
// in tokenList. A token is used only if it is anchored on anchor and allows
// delegating at least one of the provide capabilities for the topic; the
// chained token carries exactly the allowed subset. Delegation errors are
// logged and skipped.
func (ctx *BasicCapabilityContext) delegateBroadcast(tokenList []*Token, anchor did.DID, subject did.DID, topic string, expire uint64, provide []Capability) []*Token {
	topicCap := Capability(topic)
	var result []*Token //nolint

	for _, t := range tokenList {
		var providing []Capability
		for _, c := range provide {
			if t.Anchor(anchor) && t.AllowDelegation(Broadcast, ctx.DID(), did.DID{}, []Capability{topicCap}, expire, c) {
				providing = append(providing, c)
			}
		}

		if len(providing) == 0 {
			continue
		}

		token, err := t.DelegateBroadcast(ctx.provider, subject, topicCap, expire, providing)
		if err != nil {
			// this is the broadcast path; the message used to say "invocation"
			log.Debugf("error delegating broadcast %s to %s: %s", providing, subject, err)
			continue
		}

		result = append(result, token)
	}

	return result
}
// Consume ingests a JSON-encoded TokenList received from origin.
//
// Payloads larger than maxCapabilitySize are rejected with ErrTooBig. A token
// is considered only if it is anchored on this context's DID, on origin, on
// one of the roots, or if one of the require-anchor tokens allows it. Accepted
// candidates are verified against the trust context and stored under their
// subject. Tokens that are not accepted or fail verification are skipped; they
// do not fail the call.
//
// The payload is untrusted: a JSON null inside the token array decodes to a
// nil *Token, which is skipped explicitly since calling methods on it would
// dereference nil and panic.
func (ctx *BasicCapabilityContext) Consume(origin did.DID, data []byte) error {
	if len(data) > maxCapabilitySize {
		return ErrTooBig
	}

	var tokens TokenList
	if err := json.Unmarshal(data, &tokens); err != nil {
		return fmt.Errorf("unmarshaling payload: %w", err)
	}

	rootAnchors := ctx.getRoots()
	requireAnchors := ctx.getRequireAnchors()
	now := uint64(time.Now().UnixNano())

	// acceptable reports whether t is anchored somewhere we are willing to
	// consider before spending a signature verification on it.
	acceptable := func(t *Token) bool {
		if t.Anchor(ctx.DID()) || t.Anchor(origin) {
			return true
		}
		for _, anchor := range rootAnchors {
			if t.Anchor(anchor) {
				return true
			}
		}
		for _, anchor := range requireAnchors {
			for _, rt := range ctx.getRequireTokens(anchor) {
				if rt.AllowAction(t) {
					return true
				}
			}
		}
		return false
	}

	for _, t := range tokens.Tokens {
		if t == nil {
			// `null` entry in untrusted input
			continue
		}

		if !acceptable(t) {
			continue
		}

		if err := t.Verify(ctx.trust, now); err != nil {
			log.Warnf("failed to verify token issued by %s: %s", t.Issuer(), err)
			continue
		}

		ctx.consumeSubjectToken(t)
	}

	return nil
}
// Discard removes the tokens listed in the JSON-encoded TokenList data from
// the subject token store. Undecodable payloads are silently ignored.
func (ctx *BasicCapabilityContext) Discard(data []byte) {
	var list TokenList
	if err := json.Unmarshal(data, &list); err == nil {
		ctx.discardTokens(list.Tokens)
	}
}
// consumeAnchorToken merges t into the token list obtained from getf and
// writes the result back with setf, all under the context lock.
//
// If an existing token already subsumes t nothing changes. Existing tokens
// that t subsumes are dropped; the rest are kept and t is appended.
func (ctx *BasicCapabilityContext) consumeAnchorToken(getf func() []*Token, setf func(result []*Token), t *Token) {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	existing := getf()
	merged := make([]*Token, 0, len(existing)+1)

	for _, current := range existing {
		switch {
		case current.Subsumes(t):
			// nothing new to store
			return
		case t.Subsumes(current):
			// superseded by the incoming token
		default:
			merged = append(merged, current)
		}
	}

	setf(append(merged, t))
}
// consumeRequireToken merges t into the require anchors, keyed by its issuer.
func (ctx *BasicCapabilityContext) consumeRequireToken(t *Token) {
	load := func() []*Token {
		return ctx.require[t.Issuer()]
	}
	store := func(merged []*Token) {
		ctx.require[t.Issuer()] = merged
	}
	ctx.consumeAnchorToken(load, store, t)
}
// consumeProvideToken merges t into the provide anchors, keyed by its issuer.
func (ctx *BasicCapabilityContext) consumeProvideToken(t *Token) {
	load := func() []*Token {
		return ctx.provide[t.Issuer()]
	}
	store := func(merged []*Token) {
		ctx.provide[t.Issuer()] = merged
	}
	ctx.consumeAnchorToken(load, store, t)
}
// consumeSubjectToken appends t to the tokens stored for its subject. No
// deduplication is performed.
func (ctx *BasicCapabilityContext) consumeSubjectToken(t *Token) {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	key := t.Subject()
	ctx.tokens[key] = append(ctx.tokens[key], t)
}
// Require checks that subject holds a token allowing it to invoke at least one
// of the capabilities in cap towards audience.
//
// A stored subject token satisfies the check if it allows the invocation and
// is either anchored on anchor, anchored on one of the roots, or allowed by
// one of the require-anchor tokens. It returns nil on the first match and
// ErrNotAuthorized otherwise (also when cap is empty).
func (ctx *BasicCapabilityContext) Require(anchor did.DID, subject crypto.ID, audience crypto.ID, cap []Capability) error {
	if len(cap) == 0 {
		return fmt.Errorf("no capabilities: %w", ErrNotAuthorized)
	}

	subjectDID, err := did.FromID(subject)
	if err != nil {
		return fmt.Errorf("DID for subject: %w", err)
	}

	audienceDID, err := did.FromID(audience)
	if err != nil {
		return fmt.Errorf("DID for audience: %w", err)
	}

	subjectTokens := ctx.getSubjectTokens(subjectDID)
	rootAnchors := ctx.getRoots()
	requireAnchors := ctx.getRequireAnchors()

	for _, tok := range subjectTokens {
		for _, wanted := range cap {
			// evaluated lazily so the trust check still runs first
			permits := func() bool {
				return tok.AllowInvocation(subjectDID, audienceDID, wanted)
			}

			if tok.Anchor(anchor) && permits() {
				return nil
			}

			for _, root := range rootAnchors {
				if tok.Anchor(root) && permits() {
					return nil
				}
			}

			for _, reqAnchor := range requireAnchors {
				for _, reqTok := range ctx.getRequireTokens(reqAnchor) {
					if reqTok.AllowAction(tok) && permits() {
						return nil
					}
				}
			}
		}
	}

	return ErrNotAuthorized
}
// RequireBroadcast checks that subject holds a token allowing it to broadcast
// on topic with at least one of the capabilities in require.
//
// A stored subject token satisfies the check if it allows the broadcast and is
// either anchored on anchor, anchored on one of the roots, or allowed by one
// of the require-anchor tokens. It returns nil on the first match and
// ErrNotAuthorized otherwise (also when require is empty).
func (ctx *BasicCapabilityContext) RequireBroadcast(anchor did.DID, subject crypto.ID, topic string, require []Capability) error {
	if len(require) == 0 {
		return fmt.Errorf("no capabilities: %w", ErrNotAuthorized)
	}

	subjectDID, err := did.FromID(subject)
	if err != nil {
		return fmt.Errorf("DID for subject: %w", err)
	}

	subjectTokens := ctx.getSubjectTokens(subjectDID)
	rootAnchors := ctx.getRoots()
	requireAnchors := ctx.getRequireAnchors()
	topicCap := Capability(topic)

	for _, tok := range subjectTokens {
		for _, wanted := range require {
			// evaluated lazily so the trust check still runs first
			permits := func() bool {
				return tok.AllowBroadcast(subjectDID, topicCap, wanted)
			}

			if tok.Anchor(anchor) && permits() {
				return nil
			}

			for _, root := range rootAnchors {
				if tok.Anchor(root) && permits() {
					return nil
				}
			}

			for _, reqAnchor := range requireAnchors {
				for _, reqTok := range ctx.getRequireTokens(reqAnchor) {
					if reqTok.AllowAction(tok) && permits() {
						return nil
					}
				}
			}
		}
	}

	return ErrNotAuthorized
}
// Provide builds the JSON-encoded TokenList to send along with a request.
//
// Invocation tokens for invoke are always produced (chained where possible,
// plus a self-signed one). When provide is non-empty, a self-signed delegation
// token of depth 1 with target as subject is added.
//
// NOTE(review): the first guard accepts a call that has only provide
// capabilities, yet the later invoke guard rejects any call with an empty
// invoke list. Confirm whether delegation-only calls are meant to be
// supported.
func (ctx *BasicCapabilityContext) Provide(target did.DID, subject crypto.ID, audience crypto.ID, expire uint64, invoke []Capability, provide []Capability) ([]byte, error) {
	if len(invoke) == 0 && len(provide) == 0 {
		return nil, fmt.Errorf("no capabilities: %w", ErrNotAuthorized)
	}

	subjectDID, err := did.FromID(subject)
	if err != nil {
		return nil, fmt.Errorf("DID for subject: %w", err)
	}

	audienceDID, err := did.FromID(audience)
	if err != nil {
		return nil, fmt.Errorf("DID for audience: %w", err)
	}

	if len(invoke) == 0 {
		return nil, fmt.Errorf("no invocation capabilities: %w", ErrNotAuthorized)
	}

	invocation, err := ctx.DelegateInvocation(target, subjectDID, audienceDID, expire, invoke, SelfSignAlso)
	if err != nil {
		return nil, fmt.Errorf("cannot provide invocation tokens: %w", err)
	}
	if len(invocation.Tokens) == 0 {
		return nil, fmt.Errorf("cannot provide the necessary invocation tokens: %w", ErrNotAuthorized)
	}

	collected := append([]*Token(nil), invocation.Tokens...)

	if len(provide) > 0 {
		delegation, err := ctx.Delegate(target, subjectDID, nil, expire, 1, provide, SelfSignOnly)
		if err != nil {
			return nil, fmt.Errorf("cannot provide delegation tokens: %w", err)
		}
		if len(delegation.Tokens) == 0 {
			return nil, fmt.Errorf("cannot provide the necessary delegation tokens: %w", ErrNotAuthorized)
		}
		collected = append(collected, delegation.Tokens...)
	}

	data, err := json.Marshal(TokenList{Tokens: collected})
	if err != nil {
		return nil, fmt.Errorf("marshaling payload: %w", err)
	}

	return data, nil
}
// ProvideBroadcast builds the JSON-encoded TokenList that authorizes subject
// to broadcast on topic with the provide capabilities. Tokens are chained on
// the provide anchors where possible and a self-signed token is added.
// An empty provide list is rejected with ErrNotAuthorized.
func (ctx *BasicCapabilityContext) ProvideBroadcast(subject crypto.ID, topic string, expire uint64, provide []Capability) ([]byte, error) {
	if len(provide) == 0 {
		return nil, fmt.Errorf("no capabilities: %w", ErrNotAuthorized)
	}

	subjectDID, err := did.FromID(subject)
	if err != nil {
		return nil, fmt.Errorf("DID for subject: %w", err)
	}

	tokens, err := ctx.DelegateBroadcast(subjectDID, topic, expire, provide, SelfSignAlso)
	switch {
	case err != nil:
		return nil, fmt.Errorf("cannot provide broadcast tokens: %w", err)
	case len(tokens.Tokens) == 0:
		return nil, fmt.Errorf("cannot provide the necessary broadcast tokens: %w", ErrNotAuthorized)
	}

	encoded, err := json.Marshal(tokens)
	if err != nil {
		return nil, fmt.Errorf("marshaling payload: %w", err)
	}

	return encoded, nil
}
// getRoots returns a snapshot of the root DIDs, taken under the lock.
func (ctx *BasicCapabilityContext) getRoots() []did.DID {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	snapshot := make([]did.DID, len(ctx.roots))
	i := 0
	for root := range ctx.roots {
		snapshot[i] = root
		i++
	}

	return snapshot
}
// addRoots records each DID in anchors as a root, under the lock.
func (ctx *BasicCapabilityContext) addRoots(anchors []did.DID) {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	for i := range anchors {
		ctx.roots[anchors[i]] = struct{}{}
	}
}
// getRequireAnchors returns a snapshot of the DIDs that have require tokens,
// taken under the lock.
func (ctx *BasicCapabilityContext) getRequireAnchors() []did.DID {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	snapshot := make([]did.DID, len(ctx.require))
	i := 0
	for anchor := range ctx.require {
		snapshot[i] = anchor
		i++
	}

	return snapshot
}
// getProvideAnchors returns a snapshot of the DIDs that have provide tokens,
// taken under the lock.
func (ctx *BasicCapabilityContext) getProvideAnchors() []did.DID {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	snapshot := make([]did.DID, len(ctx.provide))
	i := 0
	for anchor := range ctx.provide {
		snapshot[i] = anchor
		i++
	}

	return snapshot
}
// getTokens fetches a token list through getf, drops the expired tokens from a
// copy of it, stores the filtered copy back through setf and returns it. The
// whole operation runs under the lock. It returns nil without calling setf
// when getf reports that no list exists.
func (ctx *BasicCapabilityContext) getTokens(getf func() ([]*Token, bool), setf func([]*Token)) []*Token {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	stored, found := getf()
	if !found {
		return nil
	}

	// filter expired
	now := uint64(time.Now().UnixNano())
	expired := func(t *Token) bool { return t.ExpireBefore(now) }

	live := slices.Clone(stored)
	live = slices.DeleteFunc(live, expired)
	setf(live)

	return live
}
// getRequireTokens returns the unexpired require tokens stored for anchor.
func (ctx *BasicCapabilityContext) getRequireTokens(anchor did.DID) []*Token {
	lookup := func() ([]*Token, bool) {
		tokens, found := ctx.require[anchor]
		return tokens, found
	}
	store := func(tokens []*Token) {
		ctx.require[anchor] = tokens
	}
	return ctx.getTokens(lookup, store)
}
// getProvideTokens returns the unexpired provide tokens stored for anchor.
func (ctx *BasicCapabilityContext) getProvideTokens(anchor did.DID) []*Token {
	lookup := func() ([]*Token, bool) {
		tokens, found := ctx.provide[anchor]
		return tokens, found
	}
	store := func(tokens []*Token) {
		ctx.provide[anchor] = tokens
	}
	return ctx.getTokens(lookup, store)
}
// getSubjectTokens returns the unexpired tokens stored for subject.
func (ctx *BasicCapabilityContext) getSubjectTokens(subject did.DID) []*Token {
	lookup := func() ([]*Token, bool) {
		tokens, found := ctx.tokens[subject]
		return tokens, found
	}
	store := func(tokens []*Token) {
		ctx.tokens[subject] = tokens
	}
	return ctx.getTokens(lookup, store)
}
// discardTokens removes each given token from the subject token store. A
// stored token matches when it has the same issuer and nonce. Subjects left
// without tokens are deleted from the map.
//
// nil entries are skipped: the list comes from decoded JSON (see Discard),
// where a `null` array element becomes a nil *Token, and calling methods on it
// would dereference nil and panic.
func (ctx *BasicCapabilityContext) discardTokens(tokens []*Token) {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	for _, t := range tokens {
		if t == nil {
			continue
		}

		subject := t.Subject()
		// filter a copy so slices already handed out to readers are untouched
		subjectTokens := slices.DeleteFunc(slices.Clone(ctx.tokens[subject]), func(ot *Token) bool {
			return t.Issuer() == ot.Issuer() && bytes.Equal(t.Nonce(), ot.Nonce())
		})

		if len(subjectTokens) == 0 {
			delete(ctx.tokens, subject)
		} else {
			ctx.tokens[subject] = subjectTokens
		}
	}
}
// gc runs gcTokens every gcInterval until gcCtx is cancelled.
func (ctx *BasicCapabilityContext) gc(gcCtx context.Context, gcInterval time.Duration) {
	ticker := time.NewTicker(gcInterval)
	defer ticker.Stop()

	done := gcCtx.Done()
	for {
		select {
		case <-done:
			return
		case <-ticker.C:
			ctx.gcTokens()
		}
	}
}
// gcTokens drops expired tokens from the require, provide and subject token
// stores, deleting map entries that end up empty. It holds the lock for the
// whole sweep.
func (ctx *BasicCapabilityContext) gcTokens() {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	now := uint64(time.Now().UnixNano())
	expired := func(t *Token) bool { return t.ExpireBefore(now) }

	// sweep filters a copy of every list in store and writes it back.
	sweep := func(store map[did.DID][]*Token) {
		for key, list := range store {
			live := slices.DeleteFunc(slices.Clone(list), expired)
			if len(live) == 0 {
				delete(store, key)
				continue
			}
			store[key] = live
		}
	}

	sweep(ctx.require)
	sweep(ctx.provide)
	sweep(ctx.tokens)
}
package ucan
import (
"encoding/json"
"fmt"
"io"
"gitlab.com/nunet/device-management-service/lib/did"
)
// Saver is implemented by capability contexts that know how to persist
// themselves; SaveCapabilityContext uses it for contexts that are not a
// *BasicCapabilityContext.
type Saver interface {
// Save writes the context to wr.
Save(wr io.Writer) error
}
// CapabilityContextView is the JSON-serialized form of a capability context as
// written by saveCapabilityContext and read by LoadCapabilityContext.
type CapabilityContextView struct {
// DID is the DID of the context's provider.
DID did.DID
// Roots are the root DIDs of the context.
Roots []did.DID
// Require holds the require anchor tokens.
Require TokenList
// Provide holds the provide anchor tokens.
Provide TokenList
}
// SaveCapabilityContext writes ctx to wr. A *BasicCapabilityContext is saved
// as a JSON CapabilityContextView; any other context must implement Saver.
// Contexts of neither kind yield an error wrapping ErrBadContext.
func SaveCapabilityContext(ctx CapabilityContext, wr io.Writer) error {
	switch c := ctx.(type) {
	case *BasicCapabilityContext:
		return saveCapabilityContext(c, wr)
	case Saver:
		return c.Save(wr)
	default:
		return fmt.Errorf("cannot save context: %w", ErrBadContext)
	}
}
// saveCapabilityContext JSON-encodes a CapabilityContextView of ctx (provider
// DID, roots, require and provide tokens) to wr.
func saveCapabilityContext(ctx *BasicCapabilityContext, wr io.Writer) error {
	roots, require, provide := ctx.ListRoots()

	snapshot := CapabilityContextView{
		DID:     ctx.provider.DID(),
		Roots:   roots,
		Require: require,
		Provide: provide,
	}

	if err := json.NewEncoder(wr).Encode(&snapshot); err != nil {
		return fmt.Errorf("encoding capability context view: %w", err)
	}

	return nil
}
// LoadCapabilityContext reads a JSON CapabilityContextView from rd and builds
// a capability context from it via NewCapabilityContext. Expired tokens in the
// view are dropped before the context is created.
//
// nil token entries (a JSON `null` inside a token array) are dropped as well;
// calling Expired on a nil *Token would dereference nil and panic.
func LoadCapabilityContext(trust did.TrustContext, rd io.Reader) (CapabilityContext, error) {
	var view CapabilityContextView

	decoder := json.NewDecoder(rd)
	if err := decoder.Decode(&view); err != nil {
		return nil, fmt.Errorf("decoding capability context view: %w", err)
	}

	// live keeps the usable (non-nil, unexpired) tokens of in.
	live := func(in []*Token) TokenList {
		var out TokenList
		for _, t := range in {
			if t == nil || t.Expired() {
				continue
			}
			out.Tokens = append(out.Tokens, t)
		}
		return out
	}

	require := live(view.Require.Tokens)
	provide := live(view.Provide.Tokens)

	return NewCapabilityContext(trust, view.DID, view.Roots, require, provide)
}
package ucan
import (
"crypto/rand"
"encoding/json"
"fmt"
"slices"
"time"
"gitlab.com/nunet/device-management-service/lib/did"
)
// Action names what a token authorizes: invoking, delegating or broadcasting.
type Action string
// Token actions and the length, in bytes, of the random token nonce.
const (
Invoke Action = "invoke"
Delegate Action = "delegate"
Broadcast Action = "broadcast"
// Revoke Action = "revoke" // TODO
nonceLength = 12 // 96 bits
)
// signaturePrefix is prepended to the JSON encoding of a DMS token to form the
// bytes that are signed (see DMSToken.SignatureData).
var signaturePrefix = []byte("dms:token:")
// Token is an envelope holding either a DMS token or a (not yet implemented)
// UCAN token. Methods dispatch on whichever field is set; only the DMS variant
// is currently functional.
type Token struct {
// DMS tokens
DMS *DMSToken `json:"dms,omitempty"`
// UCAN standard (when it is done) envelope for BYO anchors
UCAN *BYOToken `json:"ucan,omitempty"`
}
// DMSToken is the native capability token. It is signed by Issuer over the
// data returned by SignatureData and may be chained onto the token it was
// delegated from.
type DMSToken struct {
// Action is what the token authorizes (invoke, delegate or broadcast).
Action Action `json:"act"`
// Issuer is the DID that signed the token.
Issuer did.DID `json:"iss"`
// Subject is the DID the token is granted to.
Subject did.DID `json:"sub"`
// Audience restricts who the token may be used towards; an empty audience is
// treated as unrestricted by the Allow* checks.
Audience did.DID `json:"aud"`
// Topic lists the topics the token applies to.
Topic []Capability `json:"topic,omitempty"`
// Capability lists the granted capabilities.
Capability []Capability `json:"cap"`
// Nonce is a random value of nonceLength bytes set at issue time.
Nonce []byte `json:"nonce"`
// Expire is the expiry deadline; it is compared against UnixNano timestamps.
Expire uint64 `json:"exp"`
// Depth limits how deep a chain built on this token may be; 0 means no limit
// is enforced.
Depth uint64 `json:"depth,omitempty"`
// Chain is the token this one was delegated from, if any.
Chain *Token `json:"chain,omitempty"`
// Signature is the issuer's signature; it is cleared when computing
// SignatureData.
Signature []byte `json:"sig,omitempty"`
}
// BYOToken is a placeholder for UCAN-standard tokens ("bring your own"
// anchors); it is not implemented yet.
type BYOToken struct {
// TODO followup
}
// TokenList is the JSON wire/storage container for a list of tokens.
type TokenList struct {
Tokens []*Token `json:"tok,omitempty"`
}
// SignatureData returns the bytes covered by the token signature. Tokens
// without a DMS payload yield ErrBadToken.
func (t *Token) SignatureData() ([]byte, error) {
	if t.DMS != nil {
		return t.DMS.SignatureData()
	}
	// TODO UCAN envelopes for BYO trust; followup
	return nil, ErrBadToken
}
// SignatureData returns signaturePrefix followed by the JSON encoding of the
// token with its Signature field cleared. The receiver is not modified.
func (t *DMSToken) SignatureData() ([]byte, error) {
	unsigned := *t
	unsigned.Signature = nil

	encoded, err := json.Marshal(&unsigned)
	if err != nil {
		return nil, fmt.Errorf("signature data: %w", err)
	}

	out := make([]byte, 0, len(signaturePrefix)+len(encoded))
	out = append(out, signaturePrefix...)
	out = append(out, encoded...)
	return out, nil
}
// Issuer returns the issuer DID, or the zero DID if the token has no DMS
// payload.
func (t *Token) Issuer() did.DID {
	if t.DMS != nil {
		return t.DMS.Issuer
	}
	// TODO UCAN envelopes for BYO trust; followup
	return did.DID{}
}
// Subject returns the subject DID, or the zero DID if the token has no DMS
// payload.
func (t *Token) Subject() did.DID {
	if t.DMS != nil {
		return t.DMS.Subject
	}
	// TODO UCAN envelopes for BYO trust; followup
	return did.DID{}
}
// Audience returns the audience DID, or the zero DID if the token has no DMS
// payload.
func (t *Token) Audience() did.DID {
	if t.DMS != nil {
		return t.DMS.Audience
	}
	// TODO UCAN envelopes for BYO trust; followup
	return did.DID{}
}
// Capability returns the granted capabilities, or nil if the token has no DMS
// payload.
func (t *Token) Capability() []Capability {
	if t.DMS != nil {
		return t.DMS.Capability
	}
	// TODO UCAN envelopes for BYO trust; followup
	return nil
}
// Topic returns the token topics, or nil if the token has no DMS payload.
func (t *Token) Topic() []Capability {
	if t.DMS != nil {
		return t.DMS.Topic
	}
	// TODO UCAN envelopes for BYO trust; followup
	return nil
}
// Expire returns the token's expiry deadline. Tokens without a DMS payload
// report 0.
func (t *Token) Expire() uint64 {
	if t.DMS != nil {
		return t.DMS.Expire
	}
	// TODO UCAN envelopes for BYO trust; followup
	return 0 // expired right after the unix big bang
}
// Nonce returns the token nonce, or nil if the token has no DMS payload.
func (t *Token) Nonce() []byte {
	if t.DMS != nil {
		return t.DMS.Nonce
	}
	// TODO UCAN envelopes for BYO trust; followup
	return nil
}
// Action returns the token action, or the empty Action if the token has no DMS
// payload.
func (t *Token) Action() Action {
	if t.DMS != nil {
		return t.DMS.Action
	}
	// TODO UCAN envelopes for BYO trust; followup
	return Action("")
}
// Verify validates the token and its whole chain against trust at time now
// (compared with the tokens' Expire values), starting at chain depth 0.
func (t *Token) Verify(trust did.TrustContext, now uint64) error {
return t.verify(trust, now, 0)
}
// verify validates the token at the given chain depth. Tokens without a DMS
// payload yield ErrBadToken.
func (t *Token) verify(trust did.TrustContext, now, depth uint64) error {
	if t.DMS != nil {
		return t.DMS.verify(trust, now, depth)
	}
	// TODO UCAN envelopes for BYO trust; followup
	return ErrBadToken
}
// verify validates this token and, recursively, its chain.
//
// depth is the number of tokens already traversed below this one: Verify
// starts at 0 and each chain link is verified with depth+1. The checks are, in
// order: expiry at now, this token's depth limit, chain consistency (if
// chained) and finally the issuer's signature over SignatureData.
func (t *DMSToken) verify(trust did.TrustContext, now, depth uint64) error {
if t.ExpireBefore(now) {
return ErrCapabilityExpired
}
// a Depth of 0 means no limit is enforced
if t.Depth > 0 && depth > t.Depth {
return fmt.Errorf("max token depth exceeded: %w", ErrNotAuthorized)
}
if t.Chain != nil {
// only delegation tokens may be chained upon
if t.Chain.Action() != Delegate {
return fmt.Errorf("verify: chain does not allow delegation: %w", ErrNotAuthorized)
}
// the chain must remain valid at least until this token's own expiry
if t.Chain.ExpireBefore(t.Expire) {
return ErrCapabilityExpired
}
if err := t.Chain.verify(trust, now, depth+1); err != nil {
return err
}
// the chain must have been granted to whoever issued this token
if !t.Issuer.Equal(t.Chain.Subject()) {
return fmt.Errorf("verify: issuer/chain subject misnmatch: %w", ErrNotAuthorized)
}
// every capability of this token must be allowed by the chain;
// needCapability tracks the ones not yet covered
needCapability := slices.Clone(t.Capability)
for _, c := range t.Capability {
if t.Chain.allowDelegation(t.Issuer, t.Audience, t.Topic, t.Expire, c) {
needCapability = slices.DeleteFunc(needCapability, func(oc Capability) bool {
return c == oc
})
if len(needCapability) == 0 {
break
}
}
}
if len(needCapability) > 0 {
return fmt.Errorf("verify: capabilities are not allowed by the chain: %w", ErrNotAuthorized)
}
}
// check the signature with the anchor the trust context has for the issuer
anchor, err := trust.GetAnchor(t.Issuer)
if err != nil {
return fmt.Errorf("verify: anchor: %w", err)
}
data, err := t.SignatureData()
if err != nil {
return fmt.Errorf("verify: signature data: %w", err)
}
if err := anchor.Verify(data, t.Signature); err != nil {
return fmt.Errorf("verify: signature: %w", err)
}
return nil
}
// AllowAction reports whether this token authorizes the token ot. Tokens
// without a DMS payload allow nothing.
func (t *Token) AllowAction(ot *Token) bool {
	if t.DMS != nil {
		return t.DMS.AllowAction(ot)
	}
	// TODO UCAN envelopes for BYO trust; followup
	return false
}
// AllowAction reports whether this delegation token authorizes the token ot.
//
// All of the following must hold: this token is a Delegate token; it does not
// expire before ot does; ot is anchored on this token's subject; if this token
// has a depth limit, the anchor depth of the subject within ot does not exceed
// it; this token's audience is empty or equals ot's; and every capability and
// every topic of ot is implied by one of this token's.
func (t *DMSToken) AllowAction(ot *Token) bool {
	if t.Action != Delegate {
		return false
	}
	if t.ExpireBefore(ot.Expire()) {
		return false
	}
	if !ot.Anchor(t.Subject) {
		return false
	}

	if t.Depth > 0 {
		if depth, ok := ot.AnchorDepth(t.Subject); ok && depth > t.Depth {
			return false
		}
	}

	if !t.Audience.Empty() && !t.Audience.Equal(ot.Audience()) {
		return false
	}

	// covers reports whether any granted entry implies wanted.
	covers := func(granted []Capability, wanted Capability) bool {
		for _, g := range granted {
			if g.Implies(wanted) {
				return true
			}
		}
		return false
	}

	for _, otherCap := range ot.Capability() {
		if !covers(t.Capability, otherCap) {
			return false
		}
	}

	for _, otherTopic := range ot.Topic() {
		if !covers(t.Topic, otherTopic) {
			return false
		}
	}

	return true
}
// Size returns the length in bytes of the token's signature data.
func (t *Token) Size() int {
// the error is deliberately ignored: on failure data is nil and Size is 0
data, _ := t.SignatureData()
return len(data)
}
// Subsumes reports whether this token makes ot redundant. Tokens without a DMS
// payload subsume nothing.
func (t *Token) Subsumes(ot *Token) bool {
	if t.DMS != nil {
		return t.DMS.Subsumes(ot)
	}
	// TODO UCAN envelopes for BYO trust; followup
	return false
}
// Subsumes reports whether this token makes ot redundant: same issuer, subject
// and audience, a strictly later expiry, and every capability of ot implied by
// one of this token's capabilities. Tokens with equal expiry never subsume
// each other.
func (t *DMSToken) Subsumes(ot *Token) bool {
	sameParties := t.Issuer.Equal(ot.Issuer()) &&
		t.Subject.Equal(ot.Subject()) &&
		t.Audience.Equal(ot.Audience())
	if !sameParties || t.Expire <= ot.Expire() {
		return false
	}

	for _, otherCap := range ot.Capability() {
		covered := false
		for _, own := range t.Capability {
			if own.Implies(otherCap) {
				covered = true
				break
			}
		}
		if !covered {
			return false
		}
	}

	return true
}
// AllowInvocation reports whether the token lets subject invoke capability c
// towards audience. Tokens without a DMS payload allow nothing.
func (t *Token) AllowInvocation(subject, audience did.DID, c Capability) bool {
	if t.DMS != nil {
		return t.DMS.AllowInvocation(subject, audience, c)
	}
	// TODO UCAN envelopes for BYO trust; followup
	return false
}
// AllowInvocation reports whether this is an Invoke token for subject whose
// audience is empty or equals audience, and which grants a capability that
// implies c.
func (t *DMSToken) AllowInvocation(subject, audience did.DID, c Capability) bool {
	switch {
	case t.Action != Invoke:
		return false
	case !t.Subject.Equal(subject):
		return false
	case !t.Audience.Empty() && !t.Audience.Equal(audience):
		return false
	}

	for _, own := range t.Capability {
		if own.Implies(c) {
			return true
		}
	}

	return false
}
// AllowBroadcast reports whether the token lets subject broadcast on topic
// with capability c. Tokens without a DMS payload allow nothing.
func (t *Token) AllowBroadcast(subject did.DID, topic Capability, c Capability) bool {
	if t.DMS != nil {
		return t.DMS.AllowBroadcast(subject, topic, c)
	}
	// TODO UCAN envelopes for BYO trust; followup
	return false
}
// AllowBroadcast reports whether this is a Broadcast token for subject with an
// empty audience, one of whose topics implies topic and one of whose
// capabilities implies c.
func (t *DMSToken) AllowBroadcast(subject did.DID, topic Capability, c Capability) bool {
	switch {
	case t.Action != Broadcast:
		return false
	case !t.Subject.Equal(subject):
		return false
	case !t.Audience.Empty():
		return false
	}

	topicAllowed := false
	for _, own := range t.Topic {
		if own.Implies(topic) {
			topicAllowed = true
			break
		}
	}
	if !topicAllowed {
		return false
	}

	for _, own := range t.Capability {
		if own.Implies(c) {
			return true
		}
	}

	return false
}
// AllowDelegation reports whether issuer may chain a new token with the given
// action, audience, topics, expiry and capability c on this token. Tokens
// without a DMS payload allow nothing.
func (t *Token) AllowDelegation(action Action, issuer, audience did.DID, topics []Capability, expire uint64, c Capability) bool {
	if t.DMS != nil {
		return t.DMS.AllowDelegation(action, issuer, audience, topics, expire, c)
	}
	// TODO UCAN envelopes for BYO trust; followup
	return false
}
// AllowDelegation first checks that the chain's depth limits leave room for
// the new token (two more levels when the new token is itself a delegation,
// one otherwise) and then defers to allowDelegation.
func (t *DMSToken) AllowDelegation(action Action, issuer, audience did.DID, topics []Capability, expire uint64, c Capability) bool {
	required := uint64(1)
	if action == Delegate {
		// certificate would be dead end with 1
		required = 2
	}

	if !t.verifyDepth(required) {
		return false
	}

	return t.allowDelegation(issuer, audience, topics, expire, c)
}
// allowDelegation is AllowDelegation without the depth check. Tokens without a
// DMS payload allow nothing.
func (t *Token) allowDelegation(issuer, audience did.DID, topics []Capability, expire uint64, c Capability) bool {
	if t.DMS != nil {
		return t.DMS.allowDelegation(issuer, audience, topics, expire, c)
	}
	// TODO UCAN envelopes for BYO trust; followup
	return false
}
// allowDelegation reports whether issuer may chain a token on this one: this
// must be a Delegate token that does not expire before expire, whose subject
// is issuer, whose audience is empty or equals audience, whose topics imply
// every requested topic and which has a capability implying c.
func (t *DMSToken) allowDelegation(issuer, audience did.DID, topics []Capability, expire uint64, c Capability) bool {
	switch {
	case t.Action != Delegate:
		return false
	case t.ExpireBefore(expire):
		return false
	case !t.Subject.Equal(issuer):
		return false
	case !t.Audience.Empty() && !t.Audience.Equal(audience):
		return false
	}

	// covers reports whether any granted entry implies wanted.
	covers := func(granted []Capability, wanted Capability) bool {
		for _, g := range granted {
			if g.Implies(wanted) {
				return true
			}
		}
		return false
	}

	for _, topic := range topics {
		if !covers(t.Topic, topic) {
			return false
		}
	}

	return covers(t.Capability, c)
}
// verifyDepth reports whether the chain's depth limits permit the given depth.
// Tokens without a DMS payload never do.
func (t *Token) verifyDepth(depth uint64) bool {
	if t.DMS != nil {
		return t.DMS.verifyDepth(depth)
	}
	// TODO UCAN envelopes for BYO trust; followup
	return false
}
// verifyDepth checks depth against this token's Depth limit (0 means no limit)
// and then against each chain link with depth increased by one per link.
func (t *DMSToken) verifyDepth(depth uint64) bool {
	withinLimit := t.Depth == 0 || depth <= t.Depth
	if !withinLimit {
		return false
	}

	if t.Chain == nil {
		return true
	}

	return t.Chain.verifyDepth(depth + 1)
}
// Delegate issues a new Delegate token chained on this one, signed by
// provider. Tokens without a DMS payload yield ErrBadToken.
func (t *Token) Delegate(provider did.Provider, subject, audience did.DID, topics []Capability, expire, depth uint64, c []Capability) (*Token, error) {
	switch {
	case t.DMS != nil:
		result, err := t.DMS.Delegate(provider, subject, audience, topics, expire, depth, c)
		if err != nil {
			// this is the delegation path; the wrap used to say "delegate invocation"
			return nil, fmt.Errorf("delegate: %w", err)
		}

		return &Token{DMS: result}, nil

	case t.UCAN != nil:
		// TODO UCAN envelopes for BYO trust; followup
		fallthrough

	default:
		return nil, ErrBadToken
	}
}
// Delegate issues a new Delegate token chained on this one; see delegate.
func (t *DMSToken) Delegate(provider did.Provider, subject, audience did.DID, topics []Capability, expire, depth uint64, c []Capability) (*DMSToken, error) {
return t.delegate(Delegate, provider, subject, audience, topics, expire, depth, c)
}
// delegate issues a new token of the given action chained on this one.
//
// This token must be a Delegate token and the chain's depth limits must leave
// room for the new token (two more levels when the new token is itself a
// delegation, one otherwise); otherwise ErrNotAuthorized is returned. The new
// token gets a fresh random nonce, provider's DID as issuer, and is signed by
// provider.
func (t *DMSToken) delegate(action Action, provider did.Provider, subject, audience did.DID, topics []Capability, expire, depth uint64, c []Capability) (*DMSToken, error) {
	if t.Action != Delegate {
		return nil, ErrNotAuthorized
	}

	required := uint64(1)
	if action == Delegate {
		// certificate would be dead end with 1
		required = 2
	}
	if !t.verifyDepth(required) {
		return nil, ErrNotAuthorized
	}

	nonce := make([]byte, nonceLength)
	if _, err := rand.Read(nonce); err != nil {
		return nil, fmt.Errorf("nonce: %w", err)
	}

	chained := &DMSToken{
		Action:     action,
		Issuer:     provider.DID(),
		Subject:    subject,
		Audience:   audience,
		Topic:      topics,
		Capability: c,
		Nonce:      nonce,
		Expire:     expire,
		Depth:      depth,
		Chain:      &Token{DMS: t},
	}

	payload, err := chained.SignatureData()
	if err != nil {
		return nil, fmt.Errorf("delegate: %w", err)
	}

	signature, err := provider.Sign(payload)
	if err != nil {
		return nil, fmt.Errorf("sign: %w", err)
	}
	chained.Signature = signature

	return chained, nil
}
// DelegateInvocation issues a new Invoke token chained on this one, signed by
// provider. Tokens without a DMS payload yield ErrBadToken.
func (t *Token) DelegateInvocation(provider did.Provider, subject, audience did.DID, expire uint64, c []Capability) (*Token, error) {
	if t.DMS == nil {
		// TODO UCAN envelopes for BYO trust; followup
		return nil, ErrBadToken
	}

	chained, err := t.DMS.DelegateInvocation(provider, subject, audience, expire, c)
	if err != nil {
		return nil, fmt.Errorf("delegate invocation: %w", err)
	}

	return &Token{DMS: chained}, nil
}
// DelegateInvocation issues a new Invoke token chained on this one, with no
// topics and depth 0; see delegate.
func (t *DMSToken) DelegateInvocation(provider did.Provider, subject, audience did.DID, expire uint64, c []Capability) (*DMSToken, error) {
return t.delegate(Invoke, provider, subject, audience, nil, expire, 0, c)
}
// DelegateBroadcast issues a new Broadcast token for topic chained on this
// one, signed by provider. Tokens without a DMS payload yield ErrBadToken.
func (t *Token) DelegateBroadcast(provider did.Provider, subject did.DID, topic Capability, expire uint64, c []Capability) (*Token, error) {
	switch {
	case t.DMS != nil:
		result, err := t.DMS.DelegateBroadcast(provider, subject, topic, expire, c)
		if err != nil {
			// this is the broadcast path; the wrap used to say "delegate invocation"
			return nil, fmt.Errorf("delegate broadcast: %w", err)
		}

		return &Token{DMS: result}, nil

	case t.UCAN != nil:
		// TODO UCAN envelopes for BYO trust; followup
		fallthrough

	default:
		return nil, ErrBadToken
	}
}
// DelegateBroadcast issues a new Broadcast token chained on this one, for the
// single given topic, with an empty audience and depth 0; see delegate.
func (t *DMSToken) DelegateBroadcast(provider did.Provider, subject did.DID, topic Capability, expire uint64, c []Capability) (*DMSToken, error) {
return t.delegate(Broadcast, provider, subject, did.DID{}, []Capability{topic}, expire, 0, c)
}
// Anchor reports whether anchor issued this token or any token in its chain.
// Tokens without a DMS payload are anchored nowhere.
func (t *Token) Anchor(anchor did.DID) bool {
	if t.DMS != nil {
		return t.DMS.Anchor(anchor)
	}
	// TODO UCAN envelopes for BYO trust; followup
	return false
}
// Anchor reports whether anchor is the issuer of this token or of any token
// further up its chain.
func (t *DMSToken) Anchor(anchor did.DID) bool {
	return t.Issuer.Equal(anchor) || (t.Chain != nil && t.Chain.Anchor(anchor))
}
// AnchorDepth returns how many chain links separate this token from a token
// issued by anchor, and whether such a token exists. Tokens without a DMS
// payload report (0, false).
func (t *Token) AnchorDepth(anchor did.DID) (uint64, bool) {
	if t.DMS != nil {
		return t.DMS.AnchorDepth(anchor)
	}
	// TODO UCAN envelopes for BYO trust; followup
	return 0, false
}
// AnchorDepth returns how many chain links separate this token from a token
// issued by anchor, and whether such a token exists. When anchor appears more
// than once along the chain, the occurrence furthest up the chain determines
// the depth.
func (t *DMSToken) AnchorDepth(anchor did.DID) (depth uint64, have bool) {
	issuedHere := t.Issuer.Equal(anchor)

	if t.Chain != nil {
		if chainDepth, found := t.Chain.AnchorDepth(anchor); found {
			// an occurrence further up the chain takes precedence
			return chainDepth + 1, true
		}
	}

	return 0, issuedHere
}
// Expired reports whether the token, or any token in its chain, has expired
// as of the current wall-clock time.
func (t *Token) Expired() bool {
	now := uint64(time.Now().UnixNano())
	return t.ExpireBefore(now)
}
// ExpireBefore reports whether the token, or any token in its chain, expires
// before deadline. Tokens without a DMS payload always count as expired.
func (t *Token) ExpireBefore(deadline uint64) bool {
	if t.DMS != nil {
		return t.DMS.ExpireBefore(deadline)
	}
	// TODO UCAN envelopes for BYO trust; followup
	return true
}
// ExpireBefore reports whether this token's Expire value is below deadline,
// or, failing that, whether its chain (if any) expires before deadline.
func (t *DMSToken) ExpireBefore(deadline uint64) bool {
	if t.Expire < deadline {
		return true
	}
	return t.Chain != nil && t.Chain.ExpireBefore(deadline)
}
// SelfSigned returns the wrapped DMS token's SelfSigned result; it is false
// when there is no DMS token.
func (t *Token) SelfSigned(origin did.DID) bool {
	if t.DMS != nil {
		return t.DMS.SelfSigned(origin)
	}
	// TODO UCAN envelopes for BYO trust; followup
	return false
}
// SelfSigned follows the chain to its end and reports whether the issuer of
// the last token in the chain equals origin.
func (t *DMSToken) SelfSigned(origin did.DID) bool {
	if t.Chain == nil {
		return t.Issuer.Equal(origin)
	}
	return t.Chain.SelfSigned(origin)
}
package network
import (
"context"
"errors"
"fmt"
"time"
"github.com/spf13/afero"
commonproto "gitlab.com/nunet/device-management-service/proto/generated/v1/common"
"gitlab.com/nunet/device-management-service/network/libp2p"
"gitlab.com/nunet/device-management-service/types"
)
// Aliases re-exporting libp2p types so that users of this package do not
// have to import the libp2p package directly.
type (
PeerID = libp2p.PeerID
Topic = libp2p.Topic
Validator = libp2p.Validator
ValidationResult = libp2p.ValidationResult
PeerScoreSnapshot = libp2p.PeerScoreSnapshot
)
// Validation results re-exported from the libp2p package.
const (
ValidationAccept = libp2p.ValidationAccept
ValidationReject = libp2p.ValidationReject
ValidationIgnore = libp2p.ValidationIgnore
)
// Messenger defines the interface for sending messages.
type Messenger interface {
// SendMessage sends the message envelope msg to the host identified by hostID.
SendMessage(ctx context.Context, hostID string, msg types.MessageEnvelope) error
}
// Network is the interface implemented by the supported network backends.
// It embeds Messenger and adds lifecycle, discovery, advertisement and
// publish/subscribe operations.
type Network interface {
// Messenger embedded interface
Messenger
// Init initializes the network
Init(context.Context) error
// Start starts the network
Start(context context.Context) error
// Stat returns the network information
Stat() types.NetworkStats
// Ping pings the given address and returns the PingResult
Ping(ctx context.Context, address string, timeout time.Duration) (types.PingResult, error)
// HandleMessage is responsible for registering a message type and its handler.
HandleMessage(messageType string, handler func(data []byte)) error
// UnregisterMessageHandler unregisters a stream handler for a specific protocol.
UnregisterMessageHandler(messageType string)
// ResolveAddress returns the addresses of the peer with the given id.
// In libp2p, id represents the peerID and the response is the addrinfo
ResolveAddress(ctx context.Context, id string) ([]string, error)
// Advertise advertises the given data under the given key,
// such as advertising device capabilities on the DHT
Advertise(ctx context.Context, key string, data []byte) error
// Unadvertise stops advertising data corresponding to the given key
Unadvertise(ctx context.Context, key string) error
// Query returns the network advertisements for the given key
Query(ctx context.Context, key string) ([]*commonproto.Advertisement, error)
// Publish publishes the given data to the given topic if the network
// type allows publish/subscribe functionality such as gossipsub or nats
Publish(ctx context.Context, topic string, data []byte) error
// Subscribe subscribes to the given topic and calls the handler function
// if the network type allows it similar to Publish()
Subscribe(ctx context.Context, topic string, handler func(data []byte), validator libp2p.Validator) (uint64, error)
// Unsubscribe from a topic
Unsubscribe(topic string, subID uint64) error
// SetupBroadcastTopic allows the application to configure pubsub topic directly
SetupBroadcastTopic(topic string, setup func(*Topic) error) error
// SetBroadcastAppScore allows the application to configure application level
// scoring for pubsub
SetBroadcastAppScore(func(PeerID) float64)
// GetBroadcastScore returns the latest broadcast score snapshot
GetBroadcastScore() map[PeerID]*PeerScoreSnapshot
// Notify allows the application to receive notifications about peer connections
// and disconnections
Notify(ctx context.Context, connected, disconnected func(PeerID)) error
// PeerConnected returns true if the peer is currently connected
PeerConnected(p PeerID) bool
// Stop stops the network including any existing advertisements and subscriptions
Stop() error
}
// NewNetwork builds the Network implementation selected by netConfig.Type.
//
// It returns an error when netConfig is nil, when the NATS backend is
// requested (not implemented) or when the type is not recognised. For the
// libp2p backend the result of libp2p.New is returned unchanged.
func NewNetwork(netConfig *types.NetworkConfig, fs afero.Fs) (Network, error) {
	// TODO: probable additional params to receive: DB, FileSystem
	if netConfig == nil {
		return nil, errors.New("network configuration is nil")
	}

	if netConfig.Type == types.Libp2pNetwork {
		return libp2p.New(&netConfig.Libp2pConfig, fs)
	}
	if netConfig.Type == types.NATSNetwork {
		return nil, errors.New("not implemented")
	}
	return nil, fmt.Errorf("unsupported network type: %s", netConfig.Type)
}
package basiccontroller
import (
"context"
"fmt"
"os"
"github.com/spf13/afero"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/storage"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils"
)
// BasicVolumeController is the default implementation of the VolumeController.
// It persists storage volume information through a repositories.StorageVolume
// and manipulates volume directories through an afero file system.
type BasicVolumeController struct {
// repo is the repository for storage volume operations
repo repositories.StorageVolume
// basePath is the base path where volumes are stored under
basePath string
// FS is the file system on which volume directories are created, locked and removed
FS afero.Fs
}
// NewDefaultVolumeController builds a BasicVolumeController from the given
// repository, base path and file system, and records the initialisation
// through telemetry. The returned error is always nil.
//
// TODO-BugFix [path]: volBasePath might not end with `/`, causing errors when calling methods.
// We need to validate it using the `path` library or just verifying the string.
func NewDefaultVolumeController(repo repositories.StorageVolume, volBasePath string, fs afero.Fs) (*BasicVolumeController, error) {
	spanCtx, end := st.SpanContext(context.Background(), "controller", "volume_controller_init_duration", "opentelemetry", "log")
	defer end()

	controller := new(BasicVolumeController)
	controller.repo = repo
	controller.basePath = volBasePath
	controller.FS = fs

	st.Info(spanCtx, "volume_controller_init_success", nil)
	return controller, nil
}
// CreateVolume creates a new storage volume given a source (S3, IPFS, job, etc). The
// creation of a storage volume effectively creates an empty directory in the local filesystem
// and writes a record in the database.
//
// The directory name follows the format: `<volSource> + "-" + <name>`
// where `name` is random. The directory is placed directly under the
// controller's base path; a path separator is inserted when the base path
// does not already end with one.
//
// If the database record cannot be written, the freshly created directory is
// removed again (best effort) so that no orphaned directory is left behind.
//
// TODO-maybe [withName]: allow callers to specify custom name for path
func (vc *BasicVolumeController) CreateVolume(volSource storage.VolumeSource, opts ...storage.CreateVolOpt) (types.StorageVolume, error) {
	ctx, cancel := st.SpanContext(context.Background(), "controller", "volume_create_duration", "opentelemetry", "log")
	defer cancel()

	vol := types.StorageVolume{
		Private:        false,
		ReadOnly:       false,
		EncryptionType: types.EncryptionTypeNull,
	}
	for _, opt := range opts {
		opt(&vol)
	}

	randomStr, err := utils.RandomString(16)
	if err != nil {
		return types.StorageVolume{}, fmt.Errorf("failed to create random string: %w", err)
	}

	// Make sure the base path and the directory name are separated; without
	// this a base path such as "/data/volumes" would yield "/data/volumess3-…".
	base := vc.basePath
	if base != "" && !os.IsPathSeparator(base[len(base)-1]) {
		base += string(os.PathSeparator)
	}
	vol.Path = base + string(volSource) + "-" + randomStr
	ctx = context.WithValue(ctx, pathKey, vol.Path)

	if err := vc.FS.Mkdir(vol.Path, os.ModePerm); err != nil {
		ctx = context.WithValue(ctx, errorKey, err.Error())
		st.Error(ctx, "volume_create_failure", nil)
		return types.StorageVolume{}, fmt.Errorf("failed to create storage volume: %w", err)
	}

	createdVol, err := vc.repo.Create(ctx, vol)
	if err != nil {
		// Mkdir succeeded, so the directory was created by this call and is
		// still empty. Removal is best effort: the repository error is the
		// one worth reporting to the caller.
		_ = vc.FS.RemoveAll(vol.Path)
		ctx = context.WithValue(ctx, errorKey, err.Error())
		st.Error(ctx, "volume_create_failure", nil)
		return types.StorageVolume{}, fmt.Errorf("failed to create storage volume in repository: %w", err)
	}

	ctx = context.WithValue(ctx, volumeIDKey, createdVol.ID)
	st.Info(ctx, "volume_create_success", nil)
	return createdVol, nil
}
// LockVolume marks the volume stored at pathToVol as read-only: the record
// is updated in the repository first and the directory permissions are then
// changed to 0o400. It should be used after all necessary data has been written.
// The given options are applied to the record before it is saved, which allows
// callers to set the CID or mark the volume as private.
//
// TODO-maybe [CID]: maybe calculate CID of every volume in case WithCID opt is not provided
func (vc *BasicVolumeController) LockVolume(pathToVol string, opts ...storage.LockVolOpt) error {
	spanCtx, end := st.SpanContext(context.Background(), "controller", "volume_lock_duration", "opentelemetry", "log")
	defer end()
	ctx := context.WithValue(spanCtx, pathKey, pathToVol)

	// reportFailure emits the failure event with the error text attached.
	reportFailure := func(cause error) {
		st.Error(context.WithValue(ctx, errorKey, cause.Error()), "volume_lock_failure", nil)
	}

	lookup := vc.repo.GetQuery()
	lookup.Conditions = append(lookup.Conditions, repositories.EQ("Path", pathToVol))
	vol, err := vc.repo.Find(ctx, lookup)
	if err != nil {
		reportFailure(err)
		return fmt.Errorf("failed to find storage volume with path %s - Error: %w", pathToVol, err)
	}

	for _, apply := range opts {
		apply(&vol)
	}
	vol.ReadOnly = true

	updatedVol, err := vc.repo.Update(ctx, vol.ID, vol)
	if err != nil {
		reportFailure(err)
		return fmt.Errorf("failed to update storage volume with path %s - Error: %w", pathToVol, err)
	}

	// change file permissions
	if err := vc.FS.Chmod(updatedVol.Path, 0o400); err != nil {
		reportFailure(err)
		return fmt.Errorf("failed to make storage volume read-only (path: %s): %w", updatedVol.Path, err)
	}

	st.Info(ctx, "volume_lock_success", nil)
	return nil
}
// WithPrivate designates a given volume as private. It can be used both
// when creating or locking a volume: the type parameter selects whether the
// returned option is a storage.CreateVolOpt or a storage.LockVolOpt. In both
// cases the option sets the volume's Private field to true.
func WithPrivate[T storage.CreateVolOpt | storage.LockVolOpt]() T {
return func(v *types.StorageVolume) {
v.Private = true
}
}
// WithCID returns a lock option that stores cid, unvalidated, in the volume's
// CID field. It is meant for callers that have already calculated the CID.
//
// TODO [validate]: check if CID provided is valid
func WithCID(cid string) storage.LockVolOpt {
	setCID := func(vol *types.StorageVolume) {
		vol.CID = cid
	}
	return setCID
}
// DeleteVolume deletes a given storage volume record from the database and removes the corresponding directory.
// Identifier is either a CID or a path of a volume, as selected by idType;
// any other idType is rejected with an error.
//
// The directory is removed first and the database record second. Every
// failure path emits a "volume_delete_failure" telemetry event, and the
// repository delete runs with the same span context as the lookup.
//
// Note [CID]: if we start to type CID as cid.CID, we may have to use generics here
// as in `[T string | cid.CID]`
func (vc *BasicVolumeController) DeleteVolume(identifier string, idType storage.IDType) error {
	ctx, cancel := st.SpanContext(context.Background(), "controller", "volume_delete_duration", "opentelemetry", "log")
	defer cancel()
	ctx = context.WithValue(ctx, identifierKey, identifier)
	ctx = context.WithValue(ctx, idTypeKey, idType)

	query := vc.repo.GetQuery()
	switch idType {
	case storage.IDTypePath:
		query.Conditions = append(query.Conditions, repositories.EQ("Path", identifier))
	case storage.IDTypeCID:
		query.Conditions = append(query.Conditions, repositories.EQ("CID", identifier))
	default:
		ctx = context.WithValue(ctx, errorKey, "identifier type not supported")
		st.Error(ctx, "volume_delete_failure", nil)
		return fmt.Errorf("identifier type not supported")
	}

	vol, err := vc.repo.Find(ctx, query)
	if err != nil {
		ctx = context.WithValue(ctx, errorKey, err.Error())
		st.Error(ctx, "volume_delete_failure", nil)
		if err == repositories.ErrNotFound {
			return fmt.Errorf("volume not found: %w", err)
		}
		return fmt.Errorf("failed to find volume: %w", err)
	}

	// Remove the directory
	if err := vc.FS.RemoveAll(vol.Path); err != nil {
		ctx = context.WithValue(ctx, errorKey, err.Error())
		st.Error(ctx, "volume_delete_failure", nil)
		return fmt.Errorf("failed to remove volume directory: %w", err)
	}

	// Delete the record from the database, keeping the span context so the
	// repository call stays attached to this operation's trace.
	if err := vc.repo.Delete(ctx, vol.ID); err != nil {
		ctx = context.WithValue(ctx, errorKey, err.Error())
		st.Error(ctx, "volume_delete_failure", nil)
		return fmt.Errorf("failed to delete volume: %w", err)
	}

	st.Info(ctx, "volume_delete_success", nil)
	return nil
}
// ListVolumes fetches every storage volume record from the repository using
// an unfiltered query and reports the outcome through telemetry.
//
// TODO [filter]: maybe add opts to filter results by certain values
func (vc *BasicVolumeController) ListVolumes() ([]types.StorageVolume, error) {
	spanCtx, end := st.SpanContext(context.Background(), "controller", "volume_list_duration", "opentelemetry", "log")
	defer end()

	all, err := vc.repo.FindAll(spanCtx, vc.repo.GetQuery())
	if err == nil {
		st.Info(context.WithValue(spanCtx, volumeCountKey, len(all)), "volume_list_success", nil)
		return all, nil
	}

	st.Error(context.WithValue(spanCtx, errorKey, err.Error()), "volume_list_failure", nil)
	return nil, fmt.Errorf("failed to list volumes: %w", err)
}
// GetSize returns the size of a volume as computed by utils.GetDirectorySize
// on the volume's directory. The volume is looked up by path or by CID, as
// selected by idType; any other idType is rejected with an error.
//
// Exactly one telemetry event is emitted per call: "volume_get_size_success"
// or "volume_get_size_failure".
//
// TODO-minor: identify which measurement type will be used
func (vc *BasicVolumeController) GetSize(identifier string, idType storage.IDType) (int64, error) {
	ctx, cancel := st.SpanContext(context.Background(), "controller", "volume_get_size_duration", "opentelemetry", "log")
	defer cancel()
	ctx = context.WithValue(ctx, identifierKey, identifier)
	ctx = context.WithValue(ctx, idTypeKey, idType)

	query := vc.repo.GetQuery()
	switch idType {
	case storage.IDTypePath:
		query.Conditions = append(query.Conditions, repositories.EQ("Path", identifier))
	case storage.IDTypeCID:
		query.Conditions = append(query.Conditions, repositories.EQ("CID", identifier))
	default:
		ctx = context.WithValue(ctx, errorKey, fmt.Sprintf("unsupported ID type: %d", idType))
		st.Error(ctx, "volume_get_size_failure", nil)
		return 0, fmt.Errorf("unsupported ID type: %d", idType)
	}

	vol, err := vc.repo.Find(ctx, query)
	if err != nil {
		ctx = context.WithValue(ctx, errorKey, fmt.Sprintf("failed to find volume: %v", err))
		st.Error(ctx, "volume_get_size_failure", nil)
		return 0, fmt.Errorf("failed to find volume: %w", err)
	}

	size, err := utils.GetDirectorySize(vc.FS, vol.Path)
	if err != nil {
		ctx = context.WithValue(ctx, errorKey, fmt.Sprintf("failed to get directory size: %v", err))
		st.Error(ctx, "volume_get_size_failure", nil)
		return 0, fmt.Errorf("failed to get directory size: %w", err)
	}

	// Report success once; the event used to be emitted twice per call.
	ctx = context.WithValue(ctx, sizeKey, size)
	st.Info(ctx, "volume_get_size_success", nil)
	return size, nil
}
// EncryptVolume is a placeholder: it records a "not implemented" telemetry
// event for the given path and always returns an error.
func (vc *BasicVolumeController) EncryptVolume(path string, _ types.Encryptor, _ types.EncryptionType) error {
	spanCtx, end := st.SpanContext(context.Background(), "controller", "volume_encrypt_duration", "opentelemetry", "log")
	defer end()

	st.Error(context.WithValue(spanCtx, pathKey, path), "volume_encrypt_not_implemented", nil)
	return fmt.Errorf("not implemented")
}
// DecryptVolume is a placeholder: it records a "not implemented" telemetry
// event for the given path and always returns an error.
func (vc *BasicVolumeController) DecryptVolume(path string, _ types.Decryptor, _ types.EncryptionType) error {
	spanCtx, end := st.SpanContext(context.Background(), "controller", "volume_decrypt_duration", "opentelemetry", "log")
	defer end()

	st.Error(context.WithValue(spanCtx, pathKey, path), "volume_decrypt_not_implemented", nil)
	return fmt.Errorf("not implemented")
}
// Compile-time check to ensure BasicVolumeController implements storage.VolumeController
var _ storage.VolumeController = (*BasicVolumeController)(nil)
package basiccontroller
import (
"context"
"fmt"
"github.com/spf13/afero"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
rGorm "gitlab.com/nunet/device-management-service/db/repositories/gorm"
"gitlab.com/nunet/device-management-service/telemetry"
"gitlab.com/nunet/device-management-service/types"
)
// VolumeControllerTestKit bundles the objects produced by
// SetupVolumeControllerTestKit for use in tests.
type VolumeControllerTestKit struct {
// BasicVolController is the controller under test
BasicVolController *BasicVolumeController
// Fs is the file system the controller operates on
Fs afero.Fs
// Volumes is the volume map that was passed to the setup function
Volumes map[string]*types.StorageVolume
}
// SetupVolumeControllerTestKit sets up a volume controller with 0-n volumes given a base path.
// It uses an in-memory SQLite database and an in-memory file system. If
// volumes are provided, a directory is created for each and a record for each
// is stored in the database.
func SetupVolumeControllerTestKit(basePath string, volumes map[string]*types.StorageVolume) (*VolumeControllerTestKit, error) {
	// Initialize telemetry in test mode, replacing the global st
	// It's initiated here too, besides on basic_controller_test.go, because
	// s3 tests depend on basicController (which in turn depends on telemetry instantiation).
	// S3 are calling this SetupVolControllerTestSuite, so it's one way to initialize telemetry
	// for basic controller
	st = telemetry.NewTelemetry(nil, nil, true)

	gormCfg := &gorm.Config{Logger: logger.Default.LogMode(logger.Silent)}
	db, err := gorm.Open(sqlite.Open("file:?mode=memory&cache=shared"), gormCfg)
	if err != nil {
		return nil, fmt.Errorf("failed to create in-memory mock database: %w", err)
	}
	if err := db.AutoMigrate(&types.StorageVolume{}); err != nil {
		return nil, fmt.Errorf("failed to automigrate: %w", err)
	}

	memFs := afero.NewMemMapFs()
	if err := memFs.MkdirAll(basePath, 0o755); err != nil {
		return nil, fmt.Errorf("failed to create base path: %w", err)
	}

	repo := rGorm.NewStorageVolume(db)
	controller, err := NewDefaultVolumeController(repo, basePath, memFs)
	if err != nil {
		return nil, fmt.Errorf("failed to create volume controller: %w", err)
	}

	for _, vol := range volumes {
		// create root volume dir
		if err := memFs.MkdirAll(vol.Path, 0o755); err != nil {
			return nil, fmt.Errorf("failed to create volume dir: %w", err)
		}
		// create volume record in db
		if _, err := repo.Create(context.Background(), *vol); err != nil {
			return nil, fmt.Errorf("failed to create volume record: %w", err)
		}
	}

	kit := &VolumeControllerTestKit{
		BasicVolController: controller,
		Fs:                 memFs,
		Volumes:            volumes,
	}
	return kit, nil
}
package telemetry
import (
"os"
"sync"
"time"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"gitlab.com/nunet/device-management-service/internal/config"
"gitlab.com/nunet/device-management-service/types"
)
// Package-level state set up by InitGlobalTelemetry.
var (
// once guards the one-time initialization performed by InitGlobalTelemetry
once sync.Once
// instance is the global Telemetry instance
instance *Telemetry
// logLevel is the parsed observability level; InitGlobalTelemetry defaults it to INFO
logLevel types.ObservabilityLevel
// zapLogger is the logger built by initZapLogger and installed as zap's global logger
zapLogger *zap.Logger
)
// InitGlobalTelemetry initializes the global telemetry instance with configuration loaded from the configuration package.
//
// The initialization runs at most once per process. If the zap logger cannot
// be built, the error is returned to the caller (instead of panicking) and the
// global instance is left unset; note that only the call that actually ran the
// initialization reports the error. Collector initialization failures are
// logged but do not abort the setup.
func InitGlobalTelemetry() error {
	var initError error
	once.Do(func() {
		// Initialize Zap logger
		zapLogger, initError = initZapLogger()
		if initError != nil {
			// Surface the failure through the return value; the function
			// already advertises an error result.
			return
		}
		zap.ReplaceGlobals(zapLogger)

		cfg := config.GetConfig()
		telemetryConfig := cfg.Telemetry

		logLevel = types.INFO // Default level
		if level, err := types.ParseObservabilityLevel(telemetryConfig.ObservabilityLevel); err == nil {
			logLevel = level
		} else {
			zap.L().Warn("Invalid observability level, defaulting to INFO", zap.Error(err))
		}

		instance = &Telemetry{
			config: &types.TelemetryConfig{
				ServiceName:        telemetryConfig.ServiceName,
				GlobalEndpoint:     telemetryConfig.GlobalEndpoint,
				ObservabilityLevel: telemetryConfig.ObservabilityLevel, // Assign the string value
				TelemetryMode:      telemetryConfig.TelemetryMode,
			},
		}

		opentelemetryCollector := NewOpenTelemetryCollector(instance.config, zap.L())
		logCollector := NewLogCollector(instance.config, zap.L())
		instance.collectors = map[string]Collector{
			logCollector.GetName():           logCollector,
			opentelemetryCollector.GetName(): opentelemetryCollector,
		}

		for _, collector := range instance.collectors {
			if err := collector.Initialize(); err != nil {
				zap.L().Error("Failed to initialize collector", zap.Error(err))
			}
		}

		// Start periodic flush after initializing collectors
		StartPeriodicFlush(5 * time.Minute)
	})
	return initError
}
// initZapLogger builds the zap logger. A development logger with coloured,
// capitalised level names is used when the NUNET_DEBUG environment variable
// is set (to any value) or when the configuration enables debug mode;
// otherwise zap's production logger is used.
func initZapLogger() (*zap.Logger, error) {
	_, debugEnv := os.LookupEnv("NUNET_DEBUG")
	if !debugEnv && !config.GetConfig().General.Debug {
		return zap.NewProduction()
	}

	devCfg := zap.NewDevelopmentConfig()
	devCfg.EncoderConfig.EncodeLevel = zapcore.CapitalColorLevelEncoder
	return devCfg.Build()
}
// StartPeriodicFlush launches a goroutine that flushes the global telemetry
// instance every interval. The ticker is never stopped, so the goroutine runs
// for the remaining lifetime of the process.
func StartPeriodicFlush(interval time.Duration) {
	ticker := time.NewTicker(interval)
	flushLoop := func() {
		for {
			<-ticker.C
			zap.L().Info("Periodic flush started for telemetry")
			instance.Flush()
		}
	}
	go flushLoop()
}
package telemetry
import (
"context"
"gitlab.com/nunet/device-management-service/types"
"go.uber.org/zap"
)
// LogCollector is a Collector that writes telemetry events to a zap logger.
type LogCollector struct {
// config is the telemetry configuration the collector was created with
config *types.TelemetryConfig
// logger receives every handled event
logger *zap.Logger
}
// NewLogCollector returns a LogCollector that uses the given configuration
// and logger.
func NewLogCollector(config *types.TelemetryConfig, logger *zap.Logger) *LogCollector {
	collector := new(LogCollector)
	collector.config = config
	collector.logger = logger
	return collector
}
// Initialize logs that the collector is ready. It always returns nil.
func (c *LogCollector) Initialize() error {
	const readyMsg = "LogCollector initialized."
	c.logger.Info(readyMsg)
	return nil
}
// HandleEvent writes the event to the logger at the zap level matching the
// event's level. TRACE is logged at debug level, and INFO as well as any
// unrecognised level is logged at info level. It always returns nil.
func (c *LogCollector) HandleEvent(event types.Event) error {
	fields := make([]zap.Field, 0, 4)
	fields = append(fields,
		zap.Any("context", event.Context),
		zap.String("message", event.Message),
		zap.String("level", event.Level.String()),
		zap.Any("payload", event.Payload),
	)

	switch event.Level {
	case types.TRACE, types.DEBUG:
		c.logger.Debug(event.Message, fields...)
	case types.WARN:
		c.logger.Warn(event.Message, fields...)
	case types.ERROR:
		c.logger.Error(event.Message, fields...)
	case types.FATAL:
		c.logger.Fatal(event.Message, fields...)
	default:
		// types.INFO and every level not listed above
		c.logger.Info(event.Message, fields...)
	}
	return nil
}
// Flush syncs the underlying logger and returns the error from Sync, if any.
func (c *LogCollector) Flush() error {
	return c.logger.Sync()
}
// Shutdown flushes the collector and returns the result of Flush.
func (c *LogCollector) Shutdown() error {
return c.Flush()
}
// GetName returns the collector's name, "log".
func (c *LogCollector) GetName() string {
return "log"
}
// SpanContext returns ctx unchanged together with a cancel function that does
// nothing, because LogCollector does not support tracing.
func (c *LogCollector) SpanContext(ctx context.Context, _ string) (context.Context, context.CancelFunc) {
	noop := func() {}
	return ctx, noop
}
// Compile-time check to ensure *LogCollector implements the Collector interface.
var _ Collector = (*LogCollector)(nil)
package telemetry
import (
"context"
"sync"
"gitlab.com/nunet/device-management-service/types"
)
// MockCollector is a mock implementation of the Collector interface.
// It records the events and spans it receives so they can be inspected later.
type MockCollector struct {
// mu guards the fields below
mu sync.Mutex
// events holds every event passed to HandleEvent
events []types.Event
// traces holds one entry per SpanContext call
traces []MockTrace
// initialized is set by Initialize
initialized bool
// name is the value returned by GetName
name string
}
// MockTrace records a single MockCollector.SpanContext call.
type MockTrace struct {
// SpanName is the span name passed to SpanContext
SpanName string
// Context is the cancellable context created for the span
Context context.Context
// CancelFunc cancels Context
CancelFunc context.CancelFunc
}
// NewMockCollector returns a MockCollector with the given name and empty,
// non-nil event and trace lists.
func NewMockCollector(name string) *MockCollector {
	collector := new(MockCollector)
	collector.name = name
	collector.events = make([]types.Event, 0)
	collector.traces = make([]MockTrace, 0)
	return collector
}
// Initialize marks the collector as initialized. It always returns nil.
func (m *MockCollector) Initialize() error {
	m.mu.Lock()
	m.initialized = true
	m.mu.Unlock()
	return nil
}
// SpanContext derives a cancellable context from ctx, records it together
// with the span name, and returns the derived context and its cancel function.
func (m *MockCollector) SpanContext(ctx context.Context, spanName string) (context.Context, context.CancelFunc) {
	m.mu.Lock()
	defer m.mu.Unlock()

	spanCtx, stop := context.WithCancel(ctx)
	record := MockTrace{SpanName: spanName, Context: spanCtx, CancelFunc: stop}
	m.traces = append(m.traces, record)
	return spanCtx, stop
}
// HandleEvent appends the event to the recorded events. It always returns nil.
func (m *MockCollector) HandleEvent(event types.Event) error {
	m.mu.Lock()
	m.events = append(m.events, event)
	m.mu.Unlock()
	return nil
}
// Flush is a mock implementation of the Collector interface's Flush method.
// It only acquires and releases the collector's mutex and always returns nil.
func (m *MockCollector) Flush() error { // Added error return
m.mu.Lock()
defer m.mu.Unlock()
return nil
}
// Shutdown is a mock implementation of the Collector interface's Shutdown method.
// It only acquires and releases the collector's mutex and always returns nil.
func (m *MockCollector) Shutdown() error { // Added error return
m.mu.Lock()
defer m.mu.Unlock()
return nil
}
// GetName returns the name the mock collector was created with.
func (m *MockCollector) GetName() string {
return m.name
}
// GetTraces returns the recorded traces. The returned slice is the
// collector's own slice, not a copy.
func (m *MockCollector) GetTraces() []MockTrace {
	m.mu.Lock()
	recorded := m.traces
	m.mu.Unlock()
	return recorded
}
// GetEvents returns the recorded events. The returned slice is the
// collector's own slice, not a copy.
func (m *MockCollector) GetEvents() []types.Event {
	m.mu.Lock()
	recorded := m.events
	m.mu.Unlock()
	return recorded
}
// Reset replaces the recorded events and traces with fresh empty lists.
func (m *MockCollector) Reset() {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.events, m.traces = make([]types.Event, 0), make([]MockTrace, 0)
}
// AssertInitialized reports whether Initialize has been called on the mock
// collector.
func (m *MockCollector) AssertInitialized() bool {
	m.mu.Lock()
	ready := m.initialized
	m.mu.Unlock()
	return ready
}
// MockTelemetry is a mock implementation of the Telemetry system.
// It embeds Telemetry and declares its own collectors map of mock
// collectors, which shadows the embedded Telemetry's collectors field.
type MockTelemetry struct {
Telemetry
// mu is held by AddCollector, GetCollector and Reset while they access collectors
mu sync.Mutex
// collectors maps collector names to the registered mock collectors
collectors map[string]*MockCollector
}
// NewMockTelemetry returns a MockTelemetry whose embedded Telemetry carries
// the given config and an empty collector map, and whose own mock collector
// map is empty.
func NewMockTelemetry(config *types.TelemetryConfig) *MockTelemetry {
	embedded := Telemetry{
		config:     config,
		collectors: make(map[string]Collector),
	}
	mock := &MockTelemetry{Telemetry: embedded}
	mock.collectors = make(map[string]*MockCollector)
	return mock
}
// AddCollector registers the mock collector under its own name, replacing any
// collector previously registered under that name.
func (m *MockTelemetry) AddCollector(collector *MockCollector) {
	name := collector.GetName()

	m.mu.Lock()
	m.collectors[name] = collector
	m.mu.Unlock()
}
// SpanContext starts a span named span on each of the named collectors that
// is registered, chaining the returned contexts in the given order. Unknown
// collector names are skipped. The returned cancel function calls the cancel
// functions of all started spans in the order they were started.
func (m *MockTelemetry) SpanContext(ctx context.Context, _ string, span string, collectors ...string) (context.Context, context.CancelFunc) {
	var stops []context.CancelFunc
	for _, name := range collectors {
		collector, registered := m.collectors[name]
		if !registered {
			continue
		}
		var stop context.CancelFunc
		ctx, stop = collector.SpanContext(ctx, span)
		stops = append(stops, stop)
	}

	return ctx, func() {
		for _, stop := range stops {
			stop()
		}
	}
}
// Trace records a TRACE-level event in every registered mock collector.
func (m *MockTelemetry) Trace(ctx context.Context, message string, payload map[string]interface{}) {
	level := types.TRACE
	m.logEvent(ctx, level, message, payload)
}
// Debug records a DEBUG-level event in every registered mock collector.
func (m *MockTelemetry) Debug(ctx context.Context, message string, payload map[string]interface{}) {
	level := types.DEBUG
	m.logEvent(ctx, level, message, payload)
}
// Info records an INFO-level event in every registered mock collector.
func (m *MockTelemetry) Info(ctx context.Context, message string, payload map[string]interface{}) {
	level := types.INFO
	m.logEvent(ctx, level, message, payload)
}
// Warn records a WARN-level event in every registered mock collector.
func (m *MockTelemetry) Warn(ctx context.Context, message string, payload map[string]interface{}) {
	level := types.WARN
	m.logEvent(ctx, level, message, payload)
}
// Error records an ERROR-level event in every registered mock collector.
func (m *MockTelemetry) Error(ctx context.Context, message string, payload map[string]interface{}) {
	level := types.ERROR
	m.logEvent(ctx, level, message, payload)
}
// Fatal records a FATAL-level event in every registered mock collector.
func (m *MockTelemetry) Fatal(ctx context.Context, message string, payload map[string]interface{}) {
	level := types.FATAL
	m.logEvent(ctx, level, message, payload)
}
// logEvent builds an event from the arguments and hands it to every
// registered mock collector.
func (m *MockTelemetry) logEvent(ctx context.Context, level types.ObservabilityLevel, message string, payload map[string]interface{}) {
	var event types.Event
	event.Context = ctx
	event.Level = level
	event.Message = message
	event.Payload = payload

	for name := range m.collectors {
		_ = m.collectors[name].HandleEvent(event) // HandleEvent error is intentionally ignored
	}
}
// GetCollector returns the mock collector registered under name, or nil when
// there is none.
func (m *MockTelemetry) GetCollector(name string) *MockCollector {
	m.mu.Lock()
	collector := m.collectors[name]
	m.mu.Unlock()
	return collector
}
// Reset calls Reset on every registered mock collector while holding the
// telemetry mutex.
func (m *MockTelemetry) Reset() {
	m.mu.Lock()
	defer m.mu.Unlock()

	for name := range m.collectors {
		m.collectors[name].Reset()
	}
}
// Flush calls Flush on every registered mock collector and discards the
// returned errors.
func (m *MockTelemetry) Flush() {
	for name := range m.collectors {
		_ = m.collectors[name].Flush() // Flush error is intentionally ignored
	}
}
// Shutdown flushes all registered mock collectors and then calls Shutdown on
// each of them, discarding the returned errors.
func (m *MockTelemetry) Shutdown() {
	m.Flush()

	for name := range m.collectors {
		_ = m.collectors[name].Shutdown() // Shutdown error is intentionally ignored
	}
}
package telemetry
import (
"context"
"gitlab.com/nunet/device-management-service/types"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp"
"go.opentelemetry.io/otel/sdk/resource"
sdktrace "go.opentelemetry.io/otel/sdk/trace"
semconv "go.opentelemetry.io/otel/semconv/v1.4.0"
"go.uber.org/zap"
)
// OpenTelemetryCollector is a Collector that exports telemetry events as
// OpenTelemetry spans.
type OpenTelemetryCollector struct {
// config supplies the exporter endpoint and the service name
config *types.TelemetryConfig
// logger is used for the collector's own log output
logger *zap.Logger
// tracerProvider is set by Initialize and stays nil until it succeeds
tracerProvider *sdktrace.TracerProvider
}
// NewOpenTelemetryCollector returns a collector with the given configuration
// and logger. The tracer provider is not created until Initialize is called.
func NewOpenTelemetryCollector(config *types.TelemetryConfig, logger *zap.Logger) *OpenTelemetryCollector {
	collector := new(OpenTelemetryCollector)
	collector.config = config
	collector.logger = logger
	return collector
}
// Initialize creates an OTLP/HTTP trace exporter for the configured endpoint
// (with the insecure option set), builds a tracer provider tagged with the
// configured service name, and installs it as the global OpenTelemetry tracer
// provider. Exporter or resource creation failures are logged and returned.
func (c *OpenTelemetryCollector) Initialize() error {
	endpoint := c.config.GlobalEndpoint
	c.logger.Info("Initializing OpenTelemetry HTTP trace exporter", zap.String("endpoint", endpoint))

	background := context.Background()

	exporter, err := otlptracehttp.New(
		background,
		otlptracehttp.WithEndpoint(endpoint),
		otlptracehttp.WithInsecure(),
	)
	if err != nil {
		c.logger.Error("Failed to create HTTP trace exporter", zap.Error(err))
		return err
	}

	serviceName := semconv.ServiceNameKey.String(c.config.ServiceName)
	res, err := resource.New(background, resource.WithAttributes(serviceName))
	if err != nil {
		c.logger.Error("Failed to create resource", zap.Error(err))
		return err
	}

	provider := sdktrace.NewTracerProvider(sdktrace.WithBatcher(exporter), sdktrace.WithResource(res))
	c.tracerProvider = provider
	otel.SetTracerProvider(provider)

	c.logger.Info("OpenTelemetryCollector initialized.")
	return nil
}
// HandleEvent exports the event as a zero-length OpenTelemetry span named
// after the event message. The message, the level and the payload entries are
// attached as span attributes.
//
// Payload values of type string, bool, int, int64 and float64 are attached
// with their native attribute type; errors and values with a String method
// are attached as strings. Values of any other type are skipped with a
// warning instead of panicking. When the tracer provider has not been
// initialised the event is skipped with a warning. The tracer name is taken
// from the event context when present, otherwise "otel-tracer" is used.
// HandleEvent always returns nil.
func (c *OpenTelemetryCollector) HandleEvent(event types.Event) error {
	fields := make([]attribute.KeyValue, 0, len(event.Payload)+2)
	fields = append(fields,
		attribute.String("message", event.Message),
		attribute.String("level", event.Level.String()),
	)
	for key, value := range event.Payload {
		switch v := value.(type) {
		case string:
			fields = append(fields, attribute.String(key, v))
		case bool:
			fields = append(fields, attribute.Bool(key, v))
		case int:
			fields = append(fields, attribute.Int(key, v))
		case int64:
			fields = append(fields, attribute.Int64(key, v))
		case float64:
			fields = append(fields, attribute.Float64(key, v))
		case error:
			fields = append(fields, attribute.String(key, v.Error()))
		case interface{ String() string }:
			fields = append(fields, attribute.String(key, v.String()))
		default:
			// An unchecked string assertion here would panic on such values.
			c.logger.Warn("Skipping payload value of unsupported type",
				zap.String("key", key),
				zap.Any("value", value),
			)
		}
	}

	c.logger.Info("Handling event",
		zap.String("message", event.Message),
		zap.String("level", event.Level.String()),
		zap.Any("context", event.Context),
		zap.Any("payload", event.Payload),
	)

	// Initialize may have failed or not run yet; mirror Flush and Shutdown
	// and skip instead of dereferencing a nil provider.
	if c.tracerProvider == nil {
		c.logger.Warn("TracerProvider is nil, skipping event")
		return nil
	}

	// Fetch tracer name from context, or default to "otel-tracer"
	tracerName := "otel-tracer"
	if event.Context != nil {
		if name, ok := event.Context.Value(tracerNameKey).(string); ok {
			tracerName = name
		}
	}

	tracer := c.tracerProvider.Tracer(tracerName)
	_, span := tracer.Start(context.Background(), event.Message)
	span.SetAttributes(fields...)
	span.End()

	c.logger.Info("Event sent to OpenTelemetry",
		zap.String("message", event.Message),
		zap.String("level", event.Level.String()),
		zap.Any("context", event.Context),
		zap.Any("payload", event.Payload),
	)
	return nil
}
// Flush force-flushes the tracer provider. It is a no-op that returns nil
// when the tracer provider has not been created; a flush error is logged and
// returned.
func (c *OpenTelemetryCollector) Flush() error {
	provider := c.tracerProvider
	if provider == nil {
		c.logger.Warn("TracerProvider is nil, skipping flush")
		return nil
	}

	c.logger.Info("Flushing tracer provider")
	err := provider.ForceFlush(context.Background())
	if err != nil {
		c.logger.Error("Error flushing tracer provider", zap.Error(err))
		return err
	}

	c.logger.Info("Collector flushed successfully")
	return nil
}
// Shutdown shuts the tracer provider down. It is a no-op that returns nil
// when the tracer provider has not been created; a shutdown error is logged
// and returned.
func (c *OpenTelemetryCollector) Shutdown() error {
	provider := c.tracerProvider
	if provider == nil {
		c.logger.Warn("TracerProvider is nil, skipping shutdown")
		return nil
	}

	c.logger.Info("Shutting down tracer provider")
	err := provider.Shutdown(context.Background())
	if err != nil {
		c.logger.Error("Error shutting down tracer provider", zap.Error(err))
		return err
	}

	c.logger.Info("Collector shutdown successfully")
	return nil
}
// GetName returns the collector's name, "opentelemetry".
func (c *OpenTelemetryCollector) GetName() string {
return "opentelemetry"
}
// SpanContext starts a span named span and returns the context carrying it
// together with a cancel function that ends the span. The tracer name is
// read from ctx when present, otherwise the collector's name is used.
//
// When the tracer provider has not been created (Initialize failed or was
// never called), ctx is returned unchanged with a no-op cancel function, in
// line with how Flush and Shutdown treat a missing provider.
func (c *OpenTelemetryCollector) SpanContext(ctx context.Context, span string) (context.Context, context.CancelFunc) {
	if c.tracerProvider == nil {
		return ctx, func() {}
	}

	tracerName, ok := ctx.Value(tracerNameKey).(string)
	if !ok {
		tracerName = c.GetName()
	}

	tracer := c.tracerProvider.Tracer(tracerName)
	ctx, s := tracer.Start(ctx, span)
	cancel := func() {
		s.End()
	}
	return ctx, cancel
}
// Compile-time check to ensure *OpenTelemetryCollector implements the Collector interface.
var _ Collector = (*OpenTelemetryCollector)(nil)
package telemetry
import (
"context"
"runtime"
"go.uber.org/zap"
"gitlab.com/nunet/device-management-service/types"
)
// Telemetry dispatches spans and events to a set of named collectors.
type Telemetry struct {
// config is the telemetry configuration
config *types.TelemetryConfig
// collectors maps collector names to collectors
collectors map[string]Collector
// testMode turns the telemetry operations into no-ops when true
testMode bool
}
// contextKey is a custom type for context keys, used to avoid conflicts with
// keys defined by other packages.
type contextKey string
// Context keys used by this package.
const (
// collectorsKey stores the collector names passed to Telemetry.SpanContext
collectorsKey contextKey = "collectors"
// tracerNameKey stores the tracer name chosen by Telemetry.SpanContext
tracerNameKey contextKey = "tracerName"
// NOTE(review): versionKey is not used in the code visible here — confirm its purpose
versionKey contextKey = "version"
)
// GetTelemetry returns the package-level Telemetry instance. The result is
// nil until that instance has been assigned.
func GetTelemetry() *Telemetry {
return instance
}
// NewTelemetry initializes a new Telemetry instance.
// If testMode is true, config and collectors are ignored and the telemetry
// operations will be no-ops.
func NewTelemetry(config *types.TelemetryConfig, collectors map[string]Collector, testMode bool) *Telemetry {
	if !testMode {
		return &Telemetry{config: config, collectors: collectors}
	}
	return &Telemetry{testMode: true}
}
// SpanContext opens a span on every named collector that is registered and
// returns the resulting context plus a function that ends all opened spans.
//
// An empty tracerName or span defaults to the fully qualified name of the
// calling function, or "unknown_function" when the caller cannot be resolved.
// The selected collector names and the tracer name are stored in the returned
// context under collectorsKey and tracerNameKey. Names in collectors that are
// not registered are ignored. In test mode ctx is returned unchanged with a
// no-op cancel function.
func (t *Telemetry) SpanContext(ctx context.Context, tracerName string, span string, collectors ...string) (context.Context, context.CancelFunc) {
	if t.testMode {
		return ctx, func() {}
	}

	// Resolve the caller's name. FuncForPC may return nil, and calling Name on
	// a nil *runtime.Func yields "", which would silently replace the fallback.
	functionName := "unknown_function"
	if pc, _, _, ok := runtime.Caller(1); ok {
		if fn := runtime.FuncForPC(pc); fn != nil {
			functionName = fn.Name()
		}
	}

	// Use caller info as default tracer and span names if not provided.
	if tracerName == "" {
		tracerName = functionName
	}
	if span == "" {
		span = functionName
	}

	ctx = context.WithValue(ctx, collectorsKey, collectors)
	ctx = context.WithValue(ctx, tracerNameKey, tracerName)

	cancelFuncs := make([]context.CancelFunc, 0, len(collectors))
	for _, name := range collectors {
		c, ok := t.collectors[name]
		if !ok {
			continue
		}
		var cancelFunc context.CancelFunc
		// Each collector receives the context produced by the previous one.
		ctx, cancelFunc = c.SpanContext(ctx, span)
		cancelFuncs = append(cancelFuncs, cancelFunc)
	}

	cancel := func() {
		for _, cancelFunc := range cancelFuncs {
			cancelFunc()
		}
	}
	return ctx, cancel
}
// Trace forwards message and payload to logEvent at TRACE level; it does
// nothing in test mode.
func (t *Telemetry) Trace(ctx context.Context, message string, payload map[string]interface{}) {
	if !t.testMode {
		t.logEvent(ctx, types.TRACE, message, payload)
	}
}
// Debug forwards message and payload to logEvent at DEBUG level; it does
// nothing in test mode.
func (t *Telemetry) Debug(ctx context.Context, message string, payload map[string]interface{}) {
	if !t.testMode {
		t.logEvent(ctx, types.DEBUG, message, payload)
	}
}
// Info forwards message and payload to logEvent at INFO level; it does
// nothing in test mode.
func (t *Telemetry) Info(ctx context.Context, message string, payload map[string]interface{}) {
	if !t.testMode {
		t.logEvent(ctx, types.INFO, message, payload)
	}
}
// Warn forwards message and payload to logEvent at WARN level; it does
// nothing in test mode.
func (t *Telemetry) Warn(ctx context.Context, message string, payload map[string]interface{}) {
	if !t.testMode {
		t.logEvent(ctx, types.WARN, message, payload)
	}
}
// Error forwards message and payload to logEvent at ERROR level; it does
// nothing in test mode.
func (t *Telemetry) Error(ctx context.Context, message string, payload map[string]interface{}) {
	if !t.testMode {
		t.logEvent(ctx, types.ERROR, message, payload)
	}
}
// Fatal forwards message and payload to logEvent at FATAL level; it does
// nothing in test mode. The method itself only records the event and then
// returns to the caller.
func (t *Telemetry) Fatal(ctx context.Context, message string, payload map[string]interface{}) {
	if !t.testMode {
		t.logEvent(ctx, types.FATAL, message, payload)
	}
}
// logEvent builds a types.Event and hands it to the relevant collectors.
//
// Nothing is emitted when the configured TelemetryMode is "disabled" or when
// level is below the package-level logLevel threshold. The version "v0.5" is
// attached to the event context under versionKey. If ctx carries a []string
// under collectorsKey, only the registered collectors named there receive the
// event; otherwise every registered collector does. Collector failures are
// logged through the global zap logger and do not stop delivery to the rest.
//
// NOTE(review): a context created by SpanContext without collector names
// carries an empty []string, which takes the "named collectors" branch and
// therefore delivers to nobody — confirm this is intended.
func (t *Telemetry) logEvent(ctx context.Context, level types.ObservabilityLevel, message string, payload map[string]interface{}) {
	if t.config.TelemetryMode == "disabled" || level < logLevel {
		return
	}

	ctx = context.WithValue(ctx, versionKey, "v0.5")
	event := types.Event{
		Context: ctx,
		Level:   level,
		Message: message,
		Payload: payload,
	}

	deliver := func(c Collector) {
		if err := c.HandleEvent(event); err != nil {
			zap.L().Error("Failed to handle event", zap.Error(err))
		}
	}

	// Collectors explicitly selected through the context take precedence.
	if names, selected := ctx.Value(collectorsKey).([]string); selected {
		for _, name := range names {
			if c, found := t.collectors[name]; found {
				deliver(c)
			}
		}
		return
	}

	// No selection in the context: forward to all collectors.
	for _, c := range t.collectors {
		deliver(c)
	}
}
// Flush calls Flush on every registered collector. Failures are logged through
// the global zap logger and do not interrupt the loop. It does nothing in test
// mode.
func (t *Telemetry) Flush() {
	if t.testMode {
		return
	}
	for _, c := range t.collectors {
		err := c.Flush()
		if err == nil {
			continue
		}
		zap.L().Error("Failed to flush collector", zap.Error(err))
	}
}
// Shutdown flushes all collectors and then calls Shutdown on each of them.
// Failures are logged through the global zap logger and do not interrupt the
// loop. It does nothing in test mode.
func (t *Telemetry) Shutdown() {
	if t.testMode {
		return
	}

	t.Flush()

	for _, c := range t.collectors {
		err := c.Shutdown()
		if err == nil {
			continue
		}
		zap.L().Error("Failed to shut down collector", zap.Error(err))
	}
}
package types
import (
"fmt"
"reflect"
"slices"
"github.com/hashicorp/go-version"
)
// Connectivity represents the network configuration: the set of ports involved
// and whether a VPN is part of it.
type Connectivity struct {
Ports []int `json:"ports" description:"Ports that need to be open for the job to run"`
VPN bool `json:"vpn" description:"Whether VPN is required"`
}
// Compile-time checks that *Connectivity implements the Comparable and
// Calculable interfaces.
var (
_ Comparable[Connectivity] = (*Connectivity)(nil)
_ Calculable[Connectivity] = (*Connectivity)(nil)
)
// Compare ranks the receiver against other.
//
// It returns Equal when both values are deeply equal, Better when
// IsStrictlyContainedInt(c.Ports, other.Ports) holds and the receiver has VPN
// enabled, and Worse in every other case.
//
// NOTE(review): the VPN condition was written as
// (c.VPN && other.VPN || c.VPN && !other.VPN), which reduces to c.VPN; a
// receiver without VPN is therefore never Better, even when other does not
// ask for one — confirm this is intended.
func (c *Connectivity) Compare(other Connectivity) Comparison {
	switch {
	case reflect.DeepEqual(*c, other):
		return Equal
	case IsStrictlyContainedInt(c.Ports, other.Ports) && c.VPN:
		return Better
	default:
		return Worse
	}
}
// Add merges other into the receiver and always returns nil.
//
// Ports from other that the receiver does not already list are appended in the
// order they appear; each new port is appended once even when other repeats it.
// VPN becomes true when either side has it set.
func (c *Connectivity) Add(other Connectivity) error {
	seen := make(map[int]struct{}, len(c.Ports)+len(other.Ports))
	for _, port := range c.Ports {
		seen[port] = struct{}{}
	}

	for _, port := range other.Ports {
		if _, dup := seen[port]; dup {
			continue
		}
		// Record the port so a repeat later in other.Ports is not appended again.
		seen[port] = struct{}{}
		c.Ports = append(c.Ports, port)
	}

	c.VPN = c.VPN || other.VPN
	return nil
}
// Subtract removes other from the receiver and always returns nil.
//
// VPN is cleared when other has VPN set. Every port listed in other is dropped
// from the receiver, filtering in place on the existing backing array; when no
// port survives, Ports is set to nil.
func (c *Connectivity) Subtract(other Connectivity) error {
	c.VPN = c.VPN && !other.VPN

	kept := c.Ports[:0]
	for _, port := range c.Ports {
		if slices.Contains(other.Ports, port) {
			continue
		}
		kept = append(kept, port)
	}

	c.Ports = nil
	if len(kept) > 0 {
		c.Ports = kept
	}
	return nil
}
// TimeInformation represents the time constraints. TotalTime interprets Units
// as one of "seconds", "minutes", "hours" or "days" and treats any other value
// like seconds.
type TimeInformation struct {
Units string `json:"units" description:"Time units"`
MaxTime int `json:"max_time" description:"Maximum time that job should run"`
Preference int `json:"preference" description:"Time preference"`
}
// Compile-time checks that *TimeInformation implements the Comparable and
// Calculable interfaces.
var (
_ Comparable[TimeInformation] = (*TimeInformation)(nil)
_ Calculable[TimeInformation] = (*TimeInformation)(nil)
)
// Add increases the receiver's MaxTime by other.MaxTime and always returns nil.
// Only MaxTime is touched: Units and Preference of either side are not read,
// so the two values are summed as raw numbers.
func (t *TimeInformation) Add(other TimeInformation) error {
	// Further time fields, if introduced, would be accumulated here as well.
	t.MaxTime = t.MaxTime + other.MaxTime
	return nil
}
// Subtract decreases the receiver's MaxTime by other.MaxTime and always returns
// nil. Only MaxTime is touched: Units and Preference of either side are not
// read, and the result is allowed to become negative.
func (t *TimeInformation) Subtract(other TimeInformation) error {
	// Further time fields, if introduced, would be reduced here as well.
	t.MaxTime = t.MaxTime - other.MaxTime
	return nil
}
// TotalTime converts MaxTime to seconds according to Units. "minutes", "hours"
// and "days" are scaled by 60, 3600 and 86400; "seconds" and any unrecognised
// unit leave MaxTime unchanged.
func (t *TimeInformation) TotalTime() int {
	secondsPerUnit := 1
	switch t.Units {
	case "minutes":
		secondsPerUnit = 60
	case "hours":
		secondsPerUnit = 60 * 60
	case "days":
		secondsPerUnit = 60 * 60 * 24
	}
	return t.MaxTime * secondsPerUnit
}
// Compare ranks the receiver against other by total time in seconds.
//
// It returns Equal when both values are identical or their TotalTime results
// match, Worse when the receiver's total is smaller, and Better when it is
// larger. Preference does not influence the ordering.
func (t *TimeInformation) Compare(other TimeInformation) Comparison {
	// Identical values need no unit conversion. The previous check passed the
	// pointer t to reflect.DeepEqual against a value and so could never match.
	if *t == other {
		return Equal
	}

	ownTotalTime := t.TotalTime()
	otherTotalTime := other.TotalTime()

	switch {
	case ownTotalTime == otherTotalTime:
		return Equal
	case ownTotalTime < otherTotalTime:
		return Worse
	default:
		return Better
	}
}
// PriceInformation represents the pricing information for a single currency.
type PriceInformation struct {
Currency string `json:"currency" description:"Currency used for pricing"`
CurrencyPerHour int `json:"currency_per_hour" description:"Price charged per hour"`
TotalPerJob int `json:"total_per_job" description:"Maximum total price or budget of the job"`
Preference int `json:"preference" description:"Pricing preference"`
}
// Compile-time check that *PriceInformation implements the Comparable interface.
var _ Comparable[PriceInformation] = (*PriceInformation)(nil)
// Compare ranks the receiver's price against other.
//
// Prices in different currencies are incomparable and yield Error. Within one
// currency:
//   - equal TotalPerJob: Equal, Better or Worse depending on whether the
//     receiver's CurrencyPerHour is equal to, below or above other's;
//   - lower TotalPerJob: Better when CurrencyPerHour is not higher than
//     other's, otherwise Worse;
//   - higher TotalPerJob: Worse.
//
// Preference does not influence the ordering.
func (p *PriceInformation) Compare(other PriceInformation) Comparison {
	// Identical values are Equal. The previous check passed the pointer p to
	// reflect.DeepEqual against a value and so could never match.
	if *p == other {
		return Equal
	}

	if p.Currency != other.Currency {
		return Error
	}

	switch {
	case p.TotalPerJob == other.TotalPerJob:
		switch {
		case p.CurrencyPerHour == other.CurrencyPerHour:
			return Equal
		case p.CurrencyPerHour < other.CurrencyPerHour:
			return Better
		default:
			return Worse
		}
	case p.TotalPerJob < other.TotalPerJob:
		if p.CurrencyPerHour <= other.CurrencyPerHour {
			return Better
		}
		return Worse
	default:
		return Worse
	}
}
// Equal reports whether price matches the receiver in Currency,
// CurrencyPerHour, TotalPerJob and Preference.
func (p *PriceInformation) Equal(price PriceInformation) bool {
	if p.Currency != price.Currency {
		return false
	}
	if p.CurrencyPerHour != price.CurrencyPerHour {
		return false
	}
	if p.TotalPerJob != price.TotalPerJob {
		return false
	}
	return p.Preference == price.Preference
}
// Library describes a single library by name, version and a constraint
// operator. Compare joins Constraint and Version with a space to form a
// version constraint expression.
type Library struct {
Name string `json:"name" description:"Name of the library"`
Constraint string `json:"constraint" description:"Constraint of the library"`
Version string `json:"version" description:"Version of the library"`
}
// Compile-time check that *Library implements the Comparable interface.
var _ Comparable[Library] = (*Library)(nil)
// Compare checks the receiver's version against the constraint described by
// other.
//
// It returns Error when the receiver's Version cannot be parsed, when
// other.Constraint + " " + other.Version is not a valid constraint, or when the
// two library names differ. Otherwise the receiver's version is checked against
// the constraint: Worse when it does not satisfy it, Equal when it does and the
// constraint operator is "=", and Better when it does with any other operator.
func (lib *Library) Compare(other Library) Comparison {
	// The receiver's own version must be parseable.
	ownVersion, err := version.NewVersion(lib.Version)
	if err != nil {
		return Error
	}

	// The constraint expression built from other must be valid.
	constraints, err := version.NewConstraint(other.Constraint + " " + other.Version)
	if err != nil {
		return Error
	}

	// Libraries with different names cannot be compared.
	if lib.Name != other.Name {
		return Error
	}

	satisfied := constraints.Check(ownVersion)
	switch {
	case !satisfied:
		return Worse
	case other.Constraint == "=":
		return Equal
	default:
		return Better
	}
}
// Equal reports whether library has the same Name, Constraint and Version as
// the receiver.
func (lib *Library) Equal(library Library) bool {
	return lib.Name == library.Name &&
		lib.Constraint == library.Constraint &&
		lib.Version == library.Version
}
// Locality identifies a region by the kind of region scheme and the region's
// name within that scheme.
type Locality struct {
Kind string `json:"kind" description:"Kind of the region (geographic, nunet-defined, etc)"`
Name string `json:"name" description:"Name of the region"`
}
// Compile-time check that *Locality implements the Comparable interface.
var _ Comparable[Locality] = (*Locality)(nil)
// Compare ranks the receiver against other. Localities of different kinds are
// incomparable (Error); within the same kind the result is Equal when the
// names match and Worse when they do not.
func (loc *Locality) Compare(other Locality) Comparison {
	if loc.Kind != other.Kind {
		return Error
	}
	if loc.Name != other.Name {
		return Worse
	}
	return Equal
}
// Equal reports whether locality has the same Kind and Name as the receiver.
func (loc *Locality) Equal(locality Locality) bool {
	return loc.Kind == locality.Kind && loc.Name == locality.Name
}
// KYC represents one KYC (know-your-customer) entry as a type plus its
// associated data.
type KYC struct {
Type string `json:"type" description:"Type of KYC"`
Data string `json:"data" description:"Data required for KYC"`
}
// Compile-time check that *KYC implements the Comparable interface.
var _ Comparable[KYC] = (*KYC)(nil)
// Compare returns Equal when the receiver and other are deeply equal and Error
// otherwise; KYC entries have no Better/Worse ordering.
func (k *KYC) Compare(other KYC) Comparison {
	if !reflect.DeepEqual(*k, other) {
		return Error
	}
	return Equal
}
// Equal reports whether kyc has the same Type and Data as the receiver.
func (k *KYC) Equal(kyc KYC) bool {
	return k.Type == kyc.Type && k.Data == kyc.Data
}
// JobType represents the type of the job as a string identifier.
type JobType string
// Job type identifiers declared by this package.
const (
Batch JobType = "batch"
SingleRun JobType = "single_run"
Recurring JobType = "recurring"
LongRunning JobType = "long_running"
)
// Compile-time check that *JobType implements the Comparable interface
// (JobType has no Add/Subtract, so it is not Calculable).
var _ Comparable[JobType] = (*JobType)(nil)
// Compare returns Equal when both job types are the same and Error otherwise;
// job types have no Better/Worse ordering.
func (j JobType) Compare(other JobType) Comparison {
	if !reflect.DeepEqual(j, other) {
		return Error
	}
	return Equal
}
// JobTypes is a slice of JobType values.
type JobTypes []JobType
// Compile-time checks that *JobTypes implements the Comparable and Calculable
// interfaces.
var (
_ Comparable[JobTypes] = (*JobTypes)(nil)
_ Calculable[JobTypes] = (*JobTypes)(nil)
)
// Add appends to the receiver every job type from other that it does not
// already contain, preserving order, and always returns nil. Each new job type
// is appended once even when other repeats it; entries already in the receiver
// are left untouched. The resulting slice has its capacity clipped to its
// length so later appends by callers cannot write into shared backing storage.
func (j *JobTypes) Add(other JobTypes) error {
	merged := *j

	// Size the set for the worst case up front; the previous hint was derived
	// from an empty slice and was therefore always zero.
	seen := make(map[JobType]struct{}, len(merged)+len(other))
	for _, job := range merged {
		seen[job] = struct{}{}
	}

	for _, job := range other {
		if _, dup := seen[job]; dup {
			continue
		}
		seen[job] = struct{}{}
		merged = append(merged, job)
	}

	*j = merged[:len(merged):len(merged)]
	return nil
}
// Subtract removes from the receiver every job type that appears in other and
// always returns nil. Filtering happens in place on the existing backing array
// and the resulting slice has its capacity clipped to its length.
func (j *JobTypes) Subtract(other JobTypes) error {
	drop := make(map[JobType]struct{}, len(other))
	for _, job := range other {
		drop[job] = struct{}{}
	}

	kept := (*j)[:0]
	for _, job := range *j {
		if _, gone := drop[job]; gone {
			continue
		}
		kept = append(kept, job)
	}

	*j = kept[:len(kept):len(kept)]
	return nil
}
// Compare ranks the receiver (available job types) against other (required job
// types).
//
// Both slices are converted with ConvertTypedSliceToUntypedSlice. The result is
// Error when IsSameShallowType rejects the pair, Equal when the converted
// slices are deeply equal, Better when IsStrictlyContained(own, required)
// holds, Worse when IsStrictlyContained(required, own) holds, and Error when
// none of these applies.
//
// TODO: this comparator does not take into account options when several job
// types are available and several job types are required in the same data
// structure; this is why the test fails.
func (j *JobTypes) Compare(other JobTypes) Comparison {
	own := ConvertTypedSliceToUntypedSlice(*j)
	required := ConvertTypedSliceToUntypedSlice(other)

	if !IsSameShallowType(own, required) {
		return Error
	}

	// Available capabilities match the required ones exactly.
	if reflect.DeepEqual(own, required) {
		return Equal
	}
	// Available capabilities cover everything that is required.
	if IsStrictlyContained(own, required) {
		return Better
	}
	// Required capabilities cover everything that is available, i.e. the
	// available set falls short (the Equal case was handled above).
	if IsStrictlyContained(required, own) {
		return Worse
	}
	return Error
}
// Contains reports whether jobType is present in the receiver.
func (j *JobTypes) Contains(jobType JobType) bool {
	jobs := *j
	for i := range jobs {
		if jobs[i] == jobType {
			return true
		}
	}
	return false
}
// Libraries is a slice of Library values.
type Libraries []Library
// Compile-time checks that *Libraries implements the Comparable and Calculable
// interfaces.
var (
_ Comparable[Libraries] = (*Libraries)(nil)
_ Calculable[Libraries] = (*Libraries)(nil)
)
// Add appends to the receiver every library from other that it does not
// already contain (all fields must match to count as present) and always
// returns nil. Each new library is appended once even when other repeats it.
func (l *Libraries) Add(other Libraries) error {
	seen := make(map[Library]struct{}, len(*l))
	for _, own := range *l {
		seen[own] = struct{}{}
	}

	for _, candidate := range other {
		if _, dup := seen[candidate]; dup {
			continue
		}
		*l = append(*l, candidate)
		seen[candidate] = struct{}{}
	}
	return nil
}
// Subtract removes from the receiver every library that appears in other (all
// fields must match) and always returns nil. Filtering happens in place on the
// existing backing array to avoid an allocation.
func (l *Libraries) Subtract(other Libraries) error {
	drop := make(map[Library]struct{}, len(other))
	for _, lib := range other {
		drop[lib] = struct{}{}
	}

	kept := (*l)[:0]
	for _, lib := range *l {
		if _, gone := drop[lib]; gone {
			continue
		}
		kept = append(kept, lib)
	}

	*l = kept
	return nil
}
// Compare evaluates the receiver's libraries against the libraries in other and
// collapses the pairwise results into a single Comparison.
//
// A matrix is built with one row per library in other and one entry per library
// in the receiver, each entry being ownLibrary.Compare(otherLibrary). A best
// match is then taken from rows of that matrix via returnBestMatch, and the
// collected matches are folded as: Error if any match is Error, else Worse if
// any match is Worse, else Equal when SliceContainsOneValue(finalComparison,
// Equal) holds, otherwise Better.
func (l *Libraries) Compare(other Libraries) Comparison {
interimComparison1 := make([][]Comparison, 0)
for _, otherLibrary := range other {
var interimComparison2 []Comparison
for _, ownLibrary := range *l {
interimComparison2 = append(interimComparison2, ownLibrary.Compare(otherLibrary))
}
// this matrix structure holds the comparison result of each library in other
// against each library of the receiver, in the order they appear in the slices:
// the first dimension indexes the libraries in other,
// the second dimension indexes the receiver's libraries
interimComparison1 = append(interimComparison1, interimComparison2)
}
// pick the best match per row and collect the picks in finalComparison
finalComparison := make([]Comparison, 0)
// NOTE(review): interimComparison1 is reassigned by removeIndex inside this
// loop while i keeps advancing, and the index passed to removeIndex comes from
// returnBestMatch on a single row — presumably a position within that row, not
// within the outer slice. Confirm the intended matching behaviour.
for i := 0; i < len(interimComparison1); i++ {
// stop when the current row has fewer entries than the row index
if len(interimComparison1[i]) < i {
break
}
c := interimComparison1[i]
bestMatch, index := returnBestMatch(c)
finalComparison = append(finalComparison, bestMatch)
interimComparison1 = removeIndex(interimComparison1, index)
}
if slices.Contains(finalComparison, Error) {
return Error
}
if slices.Contains(finalComparison, Worse) {
return Worse
}
if SliceContainsOneValue(finalComparison, Equal) {
return Equal
}
return Better
}
// Contains reports whether the receiver holds a library that Equal considers
// the same as library.
func (l *Libraries) Contains(library Library) bool {
	libs := *l
	for i := range libs {
		if libs[i].Equal(library) {
			return true
		}
	}
	return false
}
// Localities is a slice of Locality values.
type Localities []Locality
// Compile-time checks that *Localities implements the Comparable and Calculable
// interfaces.
var (
_ Comparable[Localities] = (*Localities)(nil)
_ Calculable[Localities] = (*Localities)(nil)
)
// Compare evaluates the receiver's localities against those in other.
//
// For every locality in other, the receiver's localities of the same Kind are
// compared against it; the result of the last such comparison is kept, and
// Error is recorded when the receiver has no locality of that Kind. The
// per-locality results are folded as: Error if any is Error, else Worse if any
// is Worse, else Equal when SliceContainsOneValue(results, Equal) holds,
// otherwise Better.
func (l *Localities) Compare(other Localities) Comparison {
	var results []Comparison
	for _, required := range other {
		// Defaults to Error so a missing kind still produces a result.
		outcome := Error
		for _, own := range *l {
			if own.Kind == required.Kind {
				outcome = own.Compare(required)
			}
		}
		results = append(results, outcome)
	}

	switch {
	case slices.Contains(results, Error):
		return Error
	case slices.Contains(results, Worse):
		return Worse
	case SliceContainsOneValue(results, Equal):
		return Equal
	default:
		return Better
	}
}
// Add appends to the receiver every locality from other that it does not
// already contain and always returns nil. Each new locality is appended once
// even when other repeats it.
func (l *Localities) Add(other Localities) error {
	seen := make(map[Locality]struct{}, len(*l))
	for _, own := range *l {
		seen[own] = struct{}{}
	}

	for _, candidate := range other {
		if _, dup := seen[candidate]; dup {
			continue
		}
		*l = append(*l, candidate)
		seen[candidate] = struct{}{}
	}
	return nil
}
// Subtract removes from the receiver every locality that appears in other and
// always returns nil. Filtering happens in place on the existing backing array
// to avoid an allocation.
func (l *Localities) Subtract(other Localities) error {
	drop := make(map[Locality]struct{}, len(other))
	for _, loc := range other {
		drop[loc] = struct{}{}
	}

	kept := (*l)[:0]
	for _, loc := range *l {
		if _, gone := drop[loc]; gone {
			continue
		}
		kept = append(kept, loc)
	}

	*l = kept
	return nil
}
// Contains reports whether the receiver holds a locality that Equal considers
// the same as locality.
func (l *Localities) Contains(locality Locality) bool {
	locs := *l
	for i := range locs {
		if locs[i].Equal(locality) {
			return true
		}
	}
	return false
}
// KYCs is a slice of KYC entries.
type KYCs []KYC
// Compile-time checks that *KYCs implements the Comparable and Calculable
// interfaces.
var (
_ Comparable[KYCs] = (*KYCs)(nil)
_ Calculable[KYCs] = (*KYCs)(nil)
)
// Compare ranks the receiver's KYC entries against other.
//
// It returns Equal when both slices are deeply equal, Better when other is
// empty while the receiver is not, Equal when at least one entry of the
// receiver compares Equal to an entry of other, and Error otherwise.
func (k *KYCs) Compare(other KYCs) Comparison {
	own := *k

	switch {
	case reflect.DeepEqual(own, other):
		return Equal
	case len(other) == 0 && len(own) != 0:
		return Better
	}

	// A single matching pair is enough for the lists to count as Equal.
	for i := range own {
		for _, required := range other {
			if own[i].Compare(required) == Equal {
				return Equal
			}
		}
	}
	return Error
}
// Add appends to the receiver every KYC entry from other that it does not
// already contain and always returns nil. Each new entry is appended once even
// when other repeats it.
func (k *KYCs) Add(other KYCs) error {
	seen := make(map[KYC]struct{}, len(*k))
	for _, own := range *k {
		seen[own] = struct{}{}
	}

	for _, candidate := range other {
		if _, dup := seen[candidate]; dup {
			continue
		}
		*k = append(*k, candidate)
		seen[candidate] = struct{}{}
	}
	return nil
}
// Subtract removes from the receiver every KYC entry that appears in other and
// always returns nil. Filtering happens in place on the existing backing array
// to avoid an allocation.
func (k *KYCs) Subtract(other KYCs) error {
	drop := make(map[KYC]struct{}, len(other))
	for _, entry := range other {
		drop[entry] = struct{}{}
	}

	kept := (*k)[:0]
	for _, entry := range *k {
		if _, gone := drop[entry]; gone {
			continue
		}
		kept = append(kept, entry)
	}

	*k = kept
	return nil
}
// Contains reports whether the receiver holds an entry that Equal considers
// the same as kyc.
func (k *KYCs) Contains(kyc KYC) bool {
	entries := *k
	for i := range entries {
		if entries[i].Equal(kyc) {
			return true
		}
	}
	return false
}
// PricesInformation is a slice of PriceInformation values.
type PricesInformation []PriceInformation
// Compile-time checks that *PricesInformation implements the Comparable and
// Calculable interfaces.
var (
_ Comparable[PricesInformation] = (*PricesInformation)(nil)
_ Calculable[PricesInformation] = (*PricesInformation)(nil)
)
// Add appends to the receiver every price from other that it does not already
// contain (all fields must match to count as present) and always returns nil.
// Each new price is appended once even when other repeats it.
func (ps *PricesInformation) Add(other PricesInformation) error {
	seen := make(map[PriceInformation]struct{}, len(*ps))
	for _, own := range *ps {
		seen[own] = struct{}{}
	}

	for _, candidate := range other {
		if _, dup := seen[candidate]; dup {
			continue
		}
		*ps = append(*ps, candidate)
		seen[candidate] = struct{}{}
	}
	return nil
}
// Subtract removes from the receiver every price that appears in other and
// always returns nil. An empty other leaves the receiver untouched. Otherwise
// filtering happens in place on the existing backing array and the resulting
// slice has its capacity clipped to its length.
func (ps *PricesInformation) Subtract(other PricesInformation) error {
	if len(other) == 0 {
		// Nothing to subtract.
		return nil
	}

	drop := make(map[PriceInformation]struct{}, len(other))
	for _, price := range other {
		drop[price] = struct{}{}
	}

	kept := (*ps)[:0]
	for _, price := range *ps {
		if _, gone := drop[price]; gone {
			continue
		}
		kept = append(kept, price)
	}

	*ps = kept[:len(kept):len(kept)]
	return nil
}
// Compare ranks the receiver's prices against other.
//
// It returns Equal when both slices are deeply equal. Otherwise every price of
// the receiver is compared with every price of other, in order, and the first
// result that is not Error is returned. Error is returned when no pair is
// comparable, which includes either slice being empty.
func (ps *PricesInformation) Compare(other PricesInformation) Comparison {
	own := *ps
	if reflect.DeepEqual(own, other) {
		return Equal
	}

	for i := range own {
		for _, required := range other {
			if verdict := own[i].Compare(required); verdict != Error {
				return verdict
			}
		}
	}
	return Error
}
// Contains reports whether the receiver holds a price that Equal considers the
// same as price.
func (ps *PricesInformation) Contains(price PriceInformation) bool {
	prices := *ps
	for i := range prices {
		if prices[i].Equal(price) {
			return true
		}
	}
	return false
}
// HardwareCapability represents the hardware capability of the machine. Every
// field is itself comparable, and Compare combines the per-field results.
type HardwareCapability struct {
Executors Executors `json:"executor" description:"Executor type required for the job (docker, vm, wasm, or others)"`
JobTypes JobTypes `json:"type" description:"Details about type of the job (One time, batch, recurring, long running). Refer to dms.jobs package for jobType data model"`
Resources Resources `json:"resources" description:"Resources required for the job"`
Libraries Libraries `json:"libraries" description:"Libraries required for the job"`
Localities Localities `json:"locality" description:"Preferred localities of the machine for execution"`
Connectivity Connectivity `json:"connectivity" description:"Network configuration required"`
Price PricesInformation `json:"price" description:"Pricing information"`
Time TimeInformation `json:"time" description:"Time constraints"`
// KYCs carries no struct tag, unlike the other fields.
KYCs KYCs
}
// Compile-time checks that *HardwareCapability implements the Comparable and
// Calculable interfaces.
var (
_ Comparable[HardwareCapability] = (*HardwareCapability)(nil)
_ Calculable[HardwareCapability] = (*HardwareCapability)(nil)
)
// Compare ranks the receiver against other field by field and returns the
// combined verdict. Each field's own Compare result is stored under the field
// name in a ComplexComparison, whose Result method folds them into one value.
func (c *HardwareCapability) Compare(other HardwareCapability) Comparison {
	verdicts := make(ComplexComparison, 9)

	verdicts["Executors"] = c.Executors.Compare(other.Executors)
	verdicts["JobTypes"] = c.JobTypes.Compare(other.JobTypes)
	verdicts["Resources"] = c.Resources.Compare(other.Resources)
	verdicts["Libraries"] = c.Libraries.Compare(other.Libraries)
	verdicts["Localities"] = c.Localities.Compare(other.Localities)
	verdicts["Price"] = c.Price.Compare(other.Price)
	verdicts["KYCs"] = c.KYCs.Compare(other.KYCs)
	verdicts["Time"] = c.Time.Compare(other.Time)
	verdicts["Connectivity"] = c.Connectivity.Compare(other.Connectivity)

	return verdicts.Result()
}
// Add merges other into the receiver field by field, in place.
//
// Fields are processed in the order Executors, JobTypes, Resources, Libraries,
// Localities, Connectivity, Price, Time, KYCs. The first failure aborts the
// operation and is returned wrapped with the name of the failing field, so the
// cause stays reachable through errors.Is / errors.As. Fields processed before
// the failure remain modified.
func (c *HardwareCapability) Add(other HardwareCapability) error {
	// Executors
	if err := c.Executors.Add(other.Executors); err != nil {
		return fmt.Errorf("error adding Executors: %w", err)
	}

	// JobTypes
	if err := c.JobTypes.Add(other.JobTypes); err != nil {
		return fmt.Errorf("error adding JobTypes: %w", err)
	}

	// Resources
	if err := c.Resources.Add(other.Resources); err != nil {
		return fmt.Errorf("error adding Resources: %w", err)
	}

	// Libraries
	if err := c.Libraries.Add(other.Libraries); err != nil {
		return fmt.Errorf("error adding Libraries: %w", err)
	}

	// Localities
	if err := c.Localities.Add(other.Localities); err != nil {
		return fmt.Errorf("error adding Localities: %w", err)
	}

	// Connectivity
	if err := c.Connectivity.Add(other.Connectivity); err != nil {
		return fmt.Errorf("error adding Connectivity: %w", err)
	}

	// Price
	if err := c.Price.Add(other.Price); err != nil {
		return fmt.Errorf("error adding Price: %w", err)
	}

	// Time
	if err := c.Time.Add(other.Time); err != nil {
		return fmt.Errorf("error adding Time: %w", err)
	}

	// KYCs
	if err := c.KYCs.Add(other.KYCs); err != nil {
		return fmt.Errorf("error adding KYCs: %w", err)
	}

	return nil
}
// Subtract removes cap from the receiver field by field, in place.
//
// Fields are processed in the order Executors, JobTypes, Resources, Libraries,
// Localities, Connectivity, Price, Time, KYCs. The first failure aborts the
// operation and is returned wrapped with the name of the failing field, so the
// cause stays reachable through errors.Is / errors.As. Fields processed before
// the failure remain modified.
func (c *HardwareCapability) Subtract(cap HardwareCapability) error {
	// Executors
	if err := c.Executors.Subtract(cap.Executors); err != nil {
		return fmt.Errorf("error subtracting Executors: %w", err)
	}

	// JobTypes
	if err := c.JobTypes.Subtract(cap.JobTypes); err != nil {
		return fmt.Errorf("error subtracting JobTypes: %w", err)
	}

	// Resources
	if err := c.Resources.Subtract(cap.Resources); err != nil {
		return fmt.Errorf("error subtracting Resources: %w", err)
	}

	// Libraries
	if err := c.Libraries.Subtract(cap.Libraries); err != nil {
		return fmt.Errorf("error subtracting Libraries: %w", err)
	}

	// Localities
	if err := c.Localities.Subtract(cap.Localities); err != nil {
		return fmt.Errorf("error subtracting Localities: %w", err)
	}

	// Connectivity
	if err := c.Connectivity.Subtract(cap.Connectivity); err != nil {
		return fmt.Errorf("error subtracting Connectivity: %w", err)
	}

	// Price
	if err := c.Price.Subtract(cap.Price); err != nil {
		return fmt.Errorf("error subtracting Price: %w", err)
	}

	// Time
	if err := c.Time.Subtract(cap.Time); err != nil {
		return fmt.Errorf("error subtracting Time: %w", err)
	}

	// KYCs
	if err := c.KYCs.Subtract(cap.KYCs); err != nil {
		return fmt.Errorf("error subtracting KYCs: %w", err)
	}

	return nil
}
package types
import "reflect"
// Comparable is the public interface to be implemented by types that can be
// compared against a value of type T, yielding a Comparison.
type Comparable[T any] interface {
Compare(other T) Comparison
}
// Calculable is the interface to be implemented by types that support adding
// and subtracting a value of type T; both operations report failure through
// the returned error.
type Calculable[T any] interface {
Add(other T) error
Subtract(other T) error
}
// PreferenceString names the strength of a preference.
type PreferenceString string
// Preference strengths declared by this package.
const (
Hard PreferenceString = "Hard"
Soft PreferenceString = "Soft"
)
// Preference describes a comparison preference for a named type: its strength
// and an optional comparator meant to replace the default one.
type Preference struct {
TypeName string
Strength PreferenceString
DefaultComparatorOverride Comparator
}
// Comparator is the signature of a comparison function taking a left value, a
// right value and optional preferences, and returning a Comparison.
type Comparator func(l, r interface{}, preference ...Preference) Comparison
// ComplexCompare is a helper that returns a complex comparison of two complex
// types. It uses reflection, which could become a performance bottleneck;
// with generics this function would not be needed and the Compare methods
// could be used directly.
//
// For every field of l, the field's Compare method is looked up on a pointer to
// the field and invoked with the field at the same index in r. The result is
// stored under the field's name. A field without a suitable single-argument
// Compare method is recorded as Error.
//
// NOTE(review): l and r are expected to be struct values of the same type;
// reflect's NumField and Field panic for other kinds or out-of-range indexes.
func ComplexCompare(l, r interface{}) ComplexComparison {
// A complex comparison is a comparison of two complex types
// which have nested fields that need to be considered together
// before a final comparison for the whole complex type can be made.
// It is a helper function used in some type comparators.
complexComparison := make(map[string]Comparison)
val1 := reflect.ValueOf(l)
val2 := reflect.ValueOf(r)
for i := 0; i < val1.NumField(); i++ {
field1 := val1.Field(i)
field2 := val2.Field(i)
// We need a pointer to field1 to look up its Compare method.
// If the field is not addressable, copy it into a freshly allocated
// value and use the pointer to that copy instead.
var field1Ptr reflect.Value
if field1.CanAddr() {
field1Ptr = field1.Addr()
} else {
ptr := reflect.New(field1.Type())
ptr.Elem().Set(field1)
field1Ptr = ptr
}
compareMethod := field1Ptr.MethodByName("Compare")
// Only call Compare when it exists and takes exactly one argument of the
// field's own type; anything else is recorded as Error.
if compareMethod.IsValid() &&
compareMethod.Type().NumIn() == 1 &&
compareMethod.Type().In(0) == field1.Type() {
result := compareMethod.Call([]reflect.Value{field2})[0].Interface().(Comparison)
complexComparison[val1.Type().Field(i).Name] = result
} else {
complexComparison[val1.Type().Field(i).Name] = Error
}
}
return complexComparison
}
// NumericComparator compares two numeric values, where l represents machine
// capabilities and r represents required capabilities. It returns Equal when
// they match, Worse when l is smaller and Better when l is larger. Error is
// returned only when none of the three relations holds. Preferences are
// accepted for signature compatibility and ignored.
func NumericComparator[T float64 | float32 | int | int32 | int64 | uint64](l, r T, _ ...Preference) Comparison {
	if l == r {
		return Equal
	}
	if l < r {
		return Worse
	}
	if l > r {
		return Better
	}
	return Error
}
// LiteralComparator compares two literal (string-like) values, where l
// represents machine capabilities and r represents required capabilities.
// Literals can only be equal or not: it returns Equal on a match and Error
// otherwise. Preferences are accepted for signature compatibility and ignored.
func LiteralComparator[T ~string](l, r T, _ ...Preference) Comparison {
	if l != r {
		return Error
	}
	return Equal
}
package types
// Comparison is the result type shared by all comparators in this package. It
// is a string so that results read naturally when printed or serialized.
type Comparison string
// The four possible comparison results.
const (
// Worse means left object is 'worse' than right object
Worse Comparison = "Worse"
// Better means left object is 'better' than right object
Better Comparison = "Better"
// Equal means objects on the left and right are 'equally good'
Equal Comparison = "Equal"
// Error means error in comparison or objects incomparable
Error Comparison = "Error"
)
// And combines two Comparison values according to the following truth table:
//
// | AND    | Better | Worse  | Equal  | Error  |
// | ------ | ------ |--------|--------|--------|
// | Better | Better | Worse  | Better | Error  |
// | Worse  | Worse  | Worse  | Worse  | Error  |
// | Equal  | Better | Worse  | Equal  | Error  |
// | Error  | Error  | Error  | Error  | Error  |
//
// Error on either side always wins, identical operands are returned as-is, and
// any remaining combination that is not covered by the table yields Error
// unless the receiver is Worse.
func (c Comparison) And(cmp Comparison) Comparison {
	switch {
	case c == Error || cmp == Error:
		return Error
	case c == cmp:
		return c
	case c == Worse:
		// Worse on the left dominates whatever is on the right.
		return Worse
	case c == Equal || c == Better:
		// The operands differ here, so one of them is Better unless cmp is Worse.
		switch cmp {
		case Worse:
			return Worse
		case Better, Equal:
			return Better
		}
	}
	return Error
}
// ComplexComparison maps a field name to the Comparison obtained for it.
type ComplexComparison map[string]Comparison

// Result folds every Comparison in the map with And, starting from Equal, and
// returns the combined value. An empty map therefore yields Equal.
func (c *ComplexComparison) Result() Comparison {
	combined := Equal
	for field := range *c {
		combined = combined.And((*c)[field])
	}
	return combined
}
package types
// ExecutionRequest is the request object for executing a job. It bundles the
// identifiers of the job and of this particular execution with the engine
// configuration, resources and storage volumes the execution needs.
type ExecutionRequest struct {
JobID string // ID of the job to execute
ExecutionID string // ID of the execution
EngineSpec *SpecConfig // Engine spec for the execution
Resources *Resources // Resources for the execution
Inputs []*StorageVolumeExecutor // Input volumes for the execution
Outputs []*StorageVolumeExecutor // Output volumes for the results
ResultsDir string // Directory to store the results
}
// ExecutionListItem is one entry in a listing of current executions.
type ExecutionListItem struct {
ExecutionID string // ID of the execution
Running bool // Whether the execution is reported as running
}
// ExecutionResult is the result of an execution: its captured output streams,
// exit code and, for failures, an error message. All fields are serialized to
// JSON under the snake_case names given in the tags.
type ExecutionResult struct {
STDOUT string `json:"stdout"` // STDOUT of the execution
STDERR string `json:"stderr"` // STDERR of the execution
ExitCode int `json:"exit_code"` // Exit code of the execution
ErrorMsg string `json:"error_msg"` // Error message if the execution failed
}
// NewExecutionResult creates a new ExecutionResult carrying the given exit
// code; STDOUT, STDERR and ErrorMsg are left empty.
func NewExecutionResult(code int) *ExecutionResult {
	result := new(ExecutionResult)
	result.ExitCode = code
	return result
}
// NewFailedExecutionResult creates a new ExecutionResult for a failed
// execution. The exit code is set to -1 and ErrorMsg to the message of err;
// STDOUT and STDERR are left empty. A nil err is tolerated and yields an empty
// ErrorMsg instead of panicking.
func NewFailedExecutionResult(err error) *ExecutionResult {
	result := &ExecutionResult{ExitCode: -1}
	if err != nil {
		result.ErrorMsg = err.Error()
	}
	return result
}
// LogStreamRequest is the request object for streaming logs from an execution,
// identified by its job and execution IDs.
type LogStreamRequest struct {
JobID string // ID of the job
ExecutionID string // ID of the execution
Tail bool // Tail the logs
Follow bool // Follow the logs
}
package types
import (
"reflect"
)
// ExecutorType is the type of the executor, expressed as a string identifier.
type ExecutorType string
// Executor type identifiers, plus the untyped exit status constant that
// denotes a successful execution.
const (
ExecutorTypeDocker ExecutorType = "docker"
ExecutorTypeFirecracker ExecutorType = "firecracker"
ExecutorTypeWasm ExecutorType = "wasm"
ExecutionStatusCodeSuccess = 0
)
// Compile-time check that *ExecutorType implements the Comparable interface.
var _ Comparable[ExecutorType] = (*ExecutorType)(nil)
// Compare compares two ExecutorType values through LiteralComparator, so the
// result is Equal for identical types and Error for different ones.
func (e ExecutorType) Compare(other ExecutorType) Comparison {
	return LiteralComparator(e, other)
}
// String returns the string representation of the ExecutorType, i.e. the
// underlying identifier converted to a plain string.
func (e ExecutorType) String() string {
return string(e)
}
// Executor describes an executor by its ExecutorType, which is its only field.
type Executor struct {
ExecutorType ExecutorType `json:"executor_type"`
}
// Compile-time check that *Executor implements the Comparable interface.
var _ Comparable[Executor] = (*Executor)(nil)
// Compare compares two Executor objects by delegating to the Compare method of
// their ExecutorType fields and returning its result unchanged.
func (e *Executor) Compare(other Executor) Comparison {
// Comparator for Executor types.
// It is needed because the executor type is defined as an enum of ExecutorType values.
// The left side represents machine capabilities,
// the right side represents required capabilities.
// The type has only one field, so no complex comparison is needed:
// this method simply passes the field comparison through.
return e.ExecutorType.Compare(other.ExecutorType)
}
// Equal reports whether executor has the same ExecutorType as the receiver.
func (e *Executor) Equal(executor Executor) bool {
return e.ExecutorType == executor.ExecutorType
}
// Executors is a list of Executor objects.
type Executors []Executor
// Compile-time checks that *Executors implements the Comparable and Calculable
// interfaces.
var (
_ Comparable[Executors] = (*Executors)(nil)
_ Calculable[Executors] = (*Executors)(nil)
)
// Add appends every executor in other to the receiver and always returns nil.
// NOTE(review): entries are appended as-is, so an executor already present in
// the receiver ends up listed twice, whereas Subtract removes every entry of a
// given type — confirm that keeping duplicates here is intended.
func (e *Executors) Add(other Executors) error {
// append to Executors slice
*e = append(*e, other...)
return nil
}
// Subtract removes from the receiver every executor whose ExecutorType appears
// in other and always returns nil. An empty other leaves the receiver
// untouched. Otherwise filtering happens in place on the existing backing
// array and the resulting slice has its capacity clipped to its length.
func (e *Executors) Subtract(other Executors) error {
	if len(other) == 0 {
		return nil
	}

	dropTypes := make(map[ExecutorType]struct{})
	for i := range other {
		dropTypes[other[i].ExecutorType] = struct{}{}
	}

	kept := (*e)[:0]
	for _, candidate := range *e {
		if _, gone := dropTypes[candidate.ExecutorType]; gone {
			continue
		}
		kept = append(kept, candidate)
	}

	*e = kept[:len(kept):len(kept)]
	return nil
}
// Contains reports whether any Executor in the list is Equal to executor.
func (e *Executors) Contains(executor Executor) bool {
	list := *e
	for i := range list {
		if list[i].Equal(executor) {
			return true
		}
	}
	return false
}
// Compare compares two Executors lists.
// The receiver (left) represents machine capabilities and other (right)
// represents required capabilities. The result is Equal when both lists are
// deeply equal, Better when the receiver contains every element of other,
// Worse when other contains every element of the receiver, and Error otherwise.
func (e *Executors) Compare(other Executors) Comparison {
// Fast path: identical lists (same elements in the same order).
if reflect.DeepEqual(*e, other) {
return Equal
}
// Copy both lists into []interface{} so the generic containment helpers
// can be used.
lSlice := make([]interface{}, 0, len(*e))
rSlice := make([]interface{}, 0, len(other))
for _, ex := range *e {
lSlice = append(lSlice, ex)
}
for _, ex := range other {
rSlice = append(rSlice, ex)
}
// NOTE(review): both arguments are []interface{} values built above, so their
// reflected types are always identical and this branch looks unreachable.
if !IsSameShallowType(lSlice, rSlice) {
return Error
}
switch {
case reflect.DeepEqual(lSlice, rSlice):
// Available capabilities equal the required capabilities. This also
// catches the nil-versus-empty case, which the DeepEqual check at the
// top of the function treats as different.
return Equal
case IsStrictlyContained(lSlice, rSlice):
// Machine capabilities contain all the required capabilities.
return Better
case IsStrictlyContained(rSlice, lSlice):
// Required capabilities contain all the machine capabilities, i.e.
// the available capabilities are worse than required.
// (The Equal case is already handled above.)
return Worse
default:
// NOTE(review): IsStrictlyContained returns false for an empty right-hand
// slice, so a non-empty receiver compared against an empty requirement
// list ends up here. Confirm that Error is the intended result.
return Error
}
}
package types
import (
"fmt"
"slices"
"strings"
)
// GPUVendor names the maker of a GPU.
type GPUVendor string
// Known GPU vendor values.
const (
GPUVendorNvidia GPUVendor = "NVIDIA"
GPUVendorAMDATI GPUVendor = "AMD/ATI"
GPUVendorIntel GPUVendor = "Intel"
GPUVendorUnknown GPUVendor = "Unknown"
// None is the vendor set on the placeholder GPU returned by
// GPUs.GetGPUWithHighestFreeVRAM when the list is empty.
None GPUVendor = "None"
)
// Compile-time check that *GPUVendor implements Comparable[GPUVendor].
var _ Comparable[GPUVendor] = (*GPUVendor)(nil)
// Compare returns Equal when both vendors are identical and Error otherwise;
// vendors have no ordering, so Better and Worse are never returned.
func (g GPUVendor) Compare(other GPUVendor) Comparison {
	if g != other {
		return Error
	}
	return Equal
}
// ParseGPUVendor parses a free-form vendor string and returns the matching
// GPUVendor value.
//
// Matching is case-insensitive. "NVIDIA", "AMD" and "INTEL" are matched as
// substrings. "ATI" is matched only as a whole word: a plain substring match
// would classify any string containing "Corporation" (for example
// "Intel Corporation") as AMD/ATI. Unrecognized input yields GPUVendorUnknown.
func ParseGPUVendor(vendor string) GPUVendor {
	upper := strings.ToUpper(vendor)

	// Split on anything that is not an upper-case ASCII letter or digit and
	// look for "ATI" as a standalone token.
	hasATIWord := false
	words := strings.FieldsFunc(upper, func(r rune) bool {
		return !(r >= 'A' && r <= 'Z') && !(r >= '0' && r <= '9')
	})
	for _, word := range words {
		if word == "ATI" {
			hasATIWord = true
			break
		}
	}

	switch {
	case strings.Contains(upper, "NVIDIA"):
		return GPUVendorNvidia
	case strings.Contains(upper, "AMD") || hasATIWord:
		return GPUVendorAMDATI
	case strings.Contains(upper, "INTEL"):
		return GPUVendorIntel
	default:
		return GPUVendorUnknown
	}
}
// GPU represents the information known about a single GPU device.
type GPU struct {
// Index is the self-reported index of the device in the system
Index int
// Name is the model name of the GPU e.g. Tesla T4
Name string
// Vendor is the maker of the GPU, e.g. NVidia, AMD, Intel
Vendor GPUVendor
// PCIAddress is the PCI address of the device, in the format AAAA:BB:CC.C
// Used to discover the correct device rendering cards
PCIAddress string
// Model of the GPU, e.g. A100. This is the only field carrying a JSON tag.
Model string `json:"model" description:"GPU model, ex A100"`
// TotalVRAM is the total amount of VRAM on the device
TotalVRAM uint64
// UsedVRAM is the amount of VRAM currently in use
UsedVRAM uint64
// FreeVRAM is the amount of VRAM currently free
FreeVRAM uint64
// ResourceID is a GORM-related field (tagged foreignKey:ID).
// TODO(review): the original author questioned whether this is the right
// way to model the relation; confirm against the project's GORM conventions.
ResourceID uint `gorm:"foreignKey:ID"`
}
// Compile-time checks that *GPU implements the Comparable and Calculable interfaces.
var (
_ Comparable[GPU] = (*GPU)(nil)
_ Calculable[GPU] = (*GPU)(nil)
)
// Compare ranks the receiver against other using TotalVRAM only: Better when
// the receiver has more VRAM, Worse when it has less, Equal otherwise.
//
// This is deliberately simple. A richer comparison would need to encode
// specific hardware knowledge, e.g. ranking GPU models using benchmark data.
func (g *GPU) Compare(other GPU) Comparison {
	if g.TotalVRAM > other.TotalVRAM {
		return Better
	}
	if g.TotalVRAM < other.TotalVRAM {
		return Worse
	}
	return Equal
}
// Add increases the receiver's TotalVRAM by other's TotalVRAM; no other field
// is touched. The addition is not checked for uint64 overflow and the
// returned error is always nil.
func (g *GPU) Add(other GPU) error {
g.TotalVRAM += other.TotalVRAM
return nil
}
// Subtract decreases the receiver's TotalVRAM by other's TotalVRAM; no other
// field is touched. It returns an error and leaves the receiver unchanged
// when the subtraction would underflow.
func (g *GPU) Subtract(other GPU) error {
	if g.TotalVRAM < other.TotalVRAM {
		// The subtrahend (other) comes first, the minuend (receiver) second.
		return fmt.Errorf("total VRAM: underflow, cannot subtract %v from %v", other.TotalVRAM, g.TotalVRAM)
	}
	g.TotalVRAM -= other.TotalVRAM
	return nil
}
// Equal reports whether the two GPUs agree on Model, TotalVRAM, UsedVRAM,
// FreeVRAM, Index, Vendor and PCIAddress. Name and ResourceID are not compared.
func (g *GPU) Equal(other GPU) bool {
return g.Model == other.Model &&
g.TotalVRAM == other.TotalVRAM &&
g.UsedVRAM == other.UsedVRAM &&
g.FreeVRAM == other.FreeVRAM &&
g.Index == other.Index &&
g.Vendor == other.Vendor &&
g.PCIAddress == other.PCIAddress
}
// GPUs is a list of GPU devices.
type GPUs []GPU
// Compile-time checks that *GPUs implements the Calculable and Comparable interfaces.
var (
_ Calculable[GPUs] = (*GPUs)(nil)
_ Comparable[GPUs] = (*GPUs)(nil)
)
// Compare matches the GPUs in other against the receiver's GPUs and folds the
// per-GPU results into a single Comparison: Error if any match is Error,
// otherwise Worse if any match is Worse, otherwise Equal if every match is
// Equal, otherwise Better.
func (gpus GPUs) Compare(other GPUs) Comparison {
interimComparison1 := make([][]Comparison, 0)
for _, otherGPU := range other {
var interimComparison2 []Comparison
for _, ownGPU := range gpus {
interimComparison2 = append(interimComparison2, ownGPU.Compare(otherGPU))
}
// The matrix holds one row per GPU in other; each row holds the result of
// comparing every own GPU (in slice order) against that GPU:
// first dimension indexes the GPUs of other,
// second dimension indexes the receiver's GPUs.
interimComparison1 = append(interimComparison1, interimComparison2)
}
// For each row (GPU in other) pick the best remaining own GPU.
var finalComparison []Comparison
for i := 0; i < len(interimComparison1); i++ {
// NOTE(review): this stops once the current row holds fewer candidates than
// the row index. When other has more GPUs than the receiver, the remaining
// rows are skipped without contributing a result; confirm this is intended.
if len(interimComparison1[i]) < i {
break
}
c := interimComparison1[i]
bestMatch, index := returnBestMatch(c)
finalComparison = append(finalComparison, bestMatch)
// Drop the chosen own GPU (column) from every row so it is not matched twice.
interimComparison1 = removeIndex(interimComparison1, index)
}
if slices.Contains(finalComparison, Error) {
return Error
}
if slices.Contains(finalComparison, Worse) {
return Worse
}
// True only when every collected result is Equal.
if SliceContainsOneValue(finalComparison, Equal) {
return Equal
}
return Better
}
// Add adds each GPU in other to the receiver GPU that has the same Index.
// GPUs in other whose Index is not present in the receiver are ignored, and
// when other repeats an Index only its last entry is used.
//
// TODO: this logic probably needs to change:
//  1. if the other GPU is in own GPUs, add the total VRAM
//  2. if the other GPU is not in own GPUs, append it to own GPUs
func (gpus GPUs) Add(other GPUs) error {
	byIndex := make(map[int]GPU)
	for _, candidate := range other {
		byIndex[candidate.Index] = candidate
	}

	for i := range gpus {
		match, found := byIndex[gpus[i].Index]
		if !found {
			continue
		}
		// Capture the model before mutating the element, for the error text.
		model := gpus[i].Model
		if err := gpus[i].Add(match); err != nil {
			return fmt.Errorf("failed to add GPU %s: %w", model, err)
		}
	}
	return nil
}
// Subtract subtracts each GPU in other from the receiver GPU that has the same
// Index. GPUs in other whose Index is not present in the receiver are ignored.
// Processing stops at the first failing GPU; GPUs handled before it stay
// modified.
func (gpus GPUs) Subtract(other GPUs) error {
	byIndex := make(map[int]GPU)
	for _, candidate := range other {
		byIndex[candidate.Index] = candidate
	}

	for i := range gpus {
		match, found := byIndex[gpus[i].Index]
		if !found {
			continue
		}
		// Capture the model before mutating the element, for the error text.
		model := gpus[i].Model
		if err := gpus[i].Subtract(match); err != nil {
			return fmt.Errorf("failed to subtract GPU %s: %w", model, err)
		}
	}
	return nil
}
// GetGPUWithHighestFreeVRAM returns the GPU with the highest FreeVRAM.
// It is useful for selecting the best GPU when several are available.
//
// For an empty list it returns a placeholder GPU whose Vendor is None, which
// callers can use to launch CPU-only workloads. On ties the first GPU with the
// highest value wins; in particular, when every GPU reports zero free VRAM
// the first GPU is returned rather than a zero-value GPU.
// The returned error is always nil.
func (gpus GPUs) GetGPUWithHighestFreeVRAM() (GPU, error) {
	if len(gpus) == 0 {
		return GPU{Vendor: None}, nil
	}

	// Seed with the first GPU so a real device is returned even when no GPU
	// has free VRAM above zero.
	best := gpus[0]
	for _, gpu := range gpus[1:] {
		if gpu.FreeVRAM > best.FreeVRAM {
			best = gpu
		}
	}
	return best, nil
}
// CPU represents the CPU information
type CPU struct {
// ClockSpeed represents the CPU clock speed in Hz
ClockSpeed float64
// Cores represents the number of physical CPU cores
// (a float, so fractional core counts can be expressed)
Cores float32
// Compute represents the total compute power of the CPU
Compute float64
// TODO: capture the below fields if required
// Model represents the CPU model, e.g., "Intel Core i7-9700K", "AMD Ryzen 9 5900X"
Model string
// Vendor represents the CPU manufacturer, e.g., "Intel", "AMD"
Vendor string
// Threads represents the number of logical CPU threads (including hyperthreading)
Threads int
// Architecture represents the CPU architecture, e.g., "x86", "x86_64", "arm64"
Architecture string
// Cache size in bytes
CacheSize uint64
}
// Compile-time checks that *CPU implements the Calculable and Comparable interfaces.
var (
_ Calculable[CPU] = (*CPU)(nil)
_ Comparable[CPU] = (*CPU)(nil)
)
// Compare ranks the receiver against other.
//
// The architectures are compared first via LiteralComparator: an Error result
// is returned as is, and any result other than Equal yields Worse. When the
// architectures are equal the result is the numeric comparison of
// Cores * ClockSpeed.
//
// This is deliberately simple: more or equal cores and frequency is
// acceptable, anything less is not. A richer comparison would need to encode
// specific hardware knowledge, e.g. ranking CPU models using benchmark data.
func (c *CPU) Compare(other CPU) Comparison {
	ownPerf := float64(c.Cores) * c.ClockSpeed
	otherPerf := float64(other.Cores) * other.ClockSpeed
	perf := NumericComparator(ownPerf, otherPerf)

	switch LiteralComparator(c.Architecture, other.Architecture) {
	case Error:
		return Error
	case Equal:
		return perf
	default:
		return Worse
	}
}
// Add adds other's Cores and Compute to the receiver, rounding each sum to
// two decimal places. No other field is touched and the returned error is
// always nil.
func (c *CPU) Add(other CPU) error {
c.Cores = round(c.Cores+other.Cores, 2)
c.Compute = round(c.Compute+other.Compute, 2)
return nil
}
// Subtract subtracts other's Cores and Compute from the receiver, rounding
// each difference to two decimal places. Both underflow checks run before any
// field is modified, so the receiver is unchanged when an error is returned.
func (c *CPU) Subtract(other CPU) error {
	// In the messages below the subtrahend (other) comes first and the
	// minuend (receiver) second.
	if c.Compute < other.Compute {
		return fmt.Errorf("compute: underflow, cannot subtract %v from %v", other.Compute, c.Compute)
	}
	if c.Cores < other.Cores {
		return fmt.Errorf("core: underflow, cannot subtract %v from %v", other.Cores, c.Cores)
	}
	c.Cores = round(c.Cores-other.Cores, 2)
	c.Compute = round(c.Compute-other.Compute, 2)
	return nil
}
// RAM represents the RAM information
type RAM struct {
// Size in bytes
Size uint64
// TODO: capture the below fields if required
// Clock speed in Hz
ClockSpeed uint64
// Type represents the RAM type, e.g., "DDR4", "DDR5", "LPDDR4"
Type string
}
// Compile-time checks that *RAM implements the Calculable and Comparable interfaces.
var (
_ Calculable[RAM] = (*RAM)(nil)
_ Comparable[RAM] = (*RAM)(nil)
)
// Compare ranks the receiver against other by Size only, using
// NumericComparator. ClockSpeed and Type do not influence the result.
func (r *RAM) Compare(other RAM) Comparison {
	return NumericComparator(r.Size, other.Size)
}
// Add increases the receiver's Size by other's Size; no other field is
// touched. The addition is not checked for uint64 overflow and the returned
// error is always nil.
func (r *RAM) Add(other RAM) error {
r.Size += other.Size
return nil
}
// Subtract decreases the receiver's Size by other's Size; no other field is
// touched. It returns an error and leaves the receiver unchanged when the
// subtraction would underflow.
func (r *RAM) Subtract(other RAM) error {
	if r.Size < other.Size {
		// The subtrahend (other) comes first, the minuend (receiver) second.
		return fmt.Errorf("size: underflow, cannot subtract %v from %v", other.Size, r.Size)
	}
	r.Size -= other.Size
	return nil
}
// Disk represents the disk information
type Disk struct {
// Size in bytes
Size uint64
// TODO: capture the below fields if required
// Model represents the disk model, e.g., "Samsung 970 EVO Plus", "Western Digital Blue SN550"
Model string
// Vendor represents the disk manufacturer, e.g., "Samsung", "Western Digital"
Vendor string
// Type represents the disk type, e.g., "SSD", "HDD", "NVMe"
Type string
// Interface represents the disk interface, e.g., "SATA", "PCIe", "M.2"
Interface string
// Read speed in bytes per second
ReadSpeed uint64
// Write speed in bytes per second
WriteSpeed uint64
}
// Compile-time checks that *Disk implements the Calculable and Comparable interfaces.
var (
_ Calculable[Disk] = (*Disk)(nil)
_ Comparable[Disk] = (*Disk)(nil)
)
// Compare ranks the receiver against other by Size only, using
// NumericComparator. No other field influences the result.
func (d *Disk) Compare(other Disk) Comparison {
	return NumericComparator(d.Size, other.Size)
}
// Add increases the receiver's Size by other's Size; no other field is
// touched. The addition is not checked for uint64 overflow and the returned
// error is always nil.
func (d *Disk) Add(other Disk) error {
d.Size += other.Size
return nil
}
// Subtract decreases the receiver's Size by other's Size; no other field is
// touched. It returns an error and leaves the receiver unchanged when the
// subtraction would underflow.
func (d *Disk) Subtract(other Disk) error {
	if d.Size < other.Size {
		// The subtrahend (other) comes first, the minuend (receiver) second.
		return fmt.Errorf("size: underflow, cannot subtract %v from %v", other.Size, d.Size)
	}
	d.Size -= other.Size
	return nil
}
// NetworkInfo represents the network information.
// TODO: not yet used, but can be used to capture the network information.
// Unlike the other hardware types in this file it has no Compare, Add or
// Subtract methods.
type NetworkInfo struct {
// Bandwidth in bits per second (b/s)
Bandwidth uint64
// NetworkType represents the network type, e.g., "Ethernet", "Wi-Fi", "Cellular"
NetworkType string
}
package types
const (
// NetP2P is the identifier string for the "p2p" network type.
NetP2P = "p2p"
)
// NetworkSpec is a stub. Please expand based on requirements.
type NetworkSpec struct{}
// NetConfig is a stub. Please expand it or completely change it based on requirements.
// Note that its NetworkSpec field is a SpecConfig, not the NetworkSpec type above.
type NetConfig struct {
NetworkSpec SpecConfig `json:"network_spec"` // Network specification
}
// GetNetworkConfig returns a pointer to the embedded network SpecConfig.
// The pointer aliases the field, so changes made through it modify nc.
func (nc *NetConfig) GetNetworkConfig() *SpecConfig {
return &nc.NetworkSpec
}
// NetworkStats should contain all network info the user is interested in.
// For now it holds only the peer ID and the listening address; reachability,
// local and remote addresses and similar can be added when necessary.
type NetworkStats struct {
ID string `json:"id"`
ListenAddr string `json:"listen_addr"`
}
// MessageInfo is a stub. Please expand it or completely change it based on requirements.
type MessageInfo struct {
Info string `json:"info"` // Message information
}
package types
import (
"reflect"
)
// ConvertNumericToFloat64 converts a value of any built-in integer or float
// type to float64. The boolean result is false (and the number 0) when n is
// nil or of any other type, including named types derived from the numeric
// built-ins.
func ConvertNumericToFloat64(n any) (float64, bool) {
	switch v := n.(type) {
	case float64:
		return v, true
	case float32:
		return float64(v), true
	case int, int8, int16, int32, int64:
		// v is still an interface here; reflection reads the signed value.
		return float64(reflect.ValueOf(v).Int()), true
	case uint, uint8, uint16, uint32, uint64:
		return float64(reflect.ValueOf(v).Uint()), true
	}
	return 0, false
}
package types
import (
"context"
"fmt"
)
// Resources represents the resources of the machine.
// CPU, RAM and Disk are tagged as GORM embedded structs with column prefixes
// cpu_, ram_ and disk_; GPUs is tagged with the foreign key ResourceID.
type Resources struct {
CPU CPU `gorm:"embedded;embeddedPrefix:cpu_"`
GPUs GPUs `gorm:"foreignKey:ResourceID"`
RAM RAM `gorm:"embedded;embeddedPrefix:ram_"`
Disk Disk `gorm:"embedded;embeddedPrefix:disk_"`
}
// Compile-time checks that *Resources implements the Calculable and Comparable interfaces.
var (
_ Calculable[Resources] = (*Resources)(nil)
_ Comparable[Resources] = (*Resources)(nil)
)
// Compare compares two Resources objects component by component (CPU, RAM,
// Disk, GPUs) and reduces the per-component results with
// ComplexComparison.Result.
func (r *Resources) Compare(other Resources) Comparison {
	perComponent := make(ComplexComparison, 4)
	perComponent["CPU"] = r.CPU.Compare(other.CPU)
	perComponent["RAM"] = r.RAM.Compare(other.RAM)
	perComponent["Disk"] = r.Disk.Compare(other.Disk)
	perComponent["GPUs"] = r.GPUs.Compare(other.GPUs)
	return perComponent.Result()
}
// Add adds other to the receiver in place, component by component, in the
// order CPU, RAM, Disk, GPUs.
//
// It stops at the first failing component and returns its error wrapped with
// %w, so callers can inspect the cause with errors.Is / errors.As.
// Components processed before the failure stay modified.
func (r *Resources) Add(other Resources) error {
	if err := r.CPU.Add(other.CPU); err != nil {
		return fmt.Errorf("error adding CPU: %w", err)
	}
	if err := r.RAM.Add(other.RAM); err != nil {
		return fmt.Errorf("error adding RAM: %w", err)
	}
	if err := r.Disk.Add(other.Disk); err != nil {
		return fmt.Errorf("error adding Disk: %w", err)
	}
	if err := r.GPUs.Add(other.GPUs); err != nil {
		return fmt.Errorf("error adding GPUs: %w", err)
	}
	return nil
}
// Subtract subtracts other from the receiver in place, component by
// component, in the order CPU, RAM, Disk, GPUs.
//
// It stops at the first failing component and returns its error wrapped with
// %w, so callers can inspect the cause with errors.Is / errors.As.
// Components processed before the failure stay modified.
func (r *Resources) Subtract(other Resources) error {
	if err := r.CPU.Subtract(other.CPU); err != nil {
		return fmt.Errorf("error subtracting CPU: %w", err)
	}
	if err := r.RAM.Subtract(other.RAM); err != nil {
		return fmt.Errorf("error subtracting RAM: %w", err)
	}
	if err := r.Disk.Subtract(other.Disk); err != nil {
		return fmt.Errorf("error subtracting Disk: %w", err)
	}
	if err := r.GPUs.Subtract(other.GPUs); err != nil {
		return fmt.Errorf("error subtracting GPUs: %w", err)
	}
	return nil
}
// MachineResources represents the total resources of the machine.
// Like the other models below it embeds BaseDBModel and Resources.
type MachineResources struct {
BaseDBModel
Resources
}
// FreeResources represents the free resources of the machine
type FreeResources struct {
BaseDBModel
Resources
}
// OnboardedResources represents the onboarded resources of the machine
type OnboardedResources struct {
BaseDBModel
Resources
}
// ResourceAllocation represents the allocation of resources for a job
type ResourceAllocation struct {
BaseDBModel
// JobID identifies the job the allocation belongs to.
JobID string
Resources
}
// SystemSpecs is an interface that defines the methods to get the system specifications of the machine
type SystemSpecs interface {
// GetMachineResources returns the machine resources
GetMachineResources() (MachineResources, error)
}
// UsageMonitor defines the methods to monitor the system usage
type UsageMonitor interface {
// GetUsage returns the resources used by the machine
GetUsage(context.Context) (Resources, error)
}
// ResourceManager is an interface that defines the methods to manage the resources of the machine
type ResourceManager interface {
// AllocateResources allocates the resources required by a job
AllocateResources(context.Context, ResourceAllocation) error
// DeallocateResources deallocates the resources required by a job.
// NOTE(review): the string argument is presumably the job ID; confirm against implementations.
DeallocateResources(context.Context, string) error
// GetTotalAllocation returns the total allocations for the jobs
GetTotalAllocation() (Resources, error)
// GetFreeResources returns the free resources in the allocation pool
GetFreeResources(ctx context.Context) (FreeResources, error)
// GetOnboardedResources returns the onboarded resources of the machine
GetOnboardedResources(context.Context) (OnboardedResources, error)
// UpdateOnboardedResources updates the onboarded resources of the machine in the database
UpdateOnboardedResources(context.Context, OnboardedResources) error
// SystemSpecs returns the SystemSpecs instance
SystemSpecs() SystemSpecs
// UsageMonitor returns the UsageMonitor instance
UsageMonitor() UsageMonitor
}
package types
import (
"errors"
"strings"
"gitlab.com/nunet/device-management-service/utils/validate"
)
// SpecConfig represents a configuration for a spec
// A SpecConfig can be used to define an engine spec, a storage volume, etc.
type SpecConfig struct {
// Type of the spec (e.g. docker, firecracker, storage, etc.)
Type string `json:"type"`
// Params of the spec; left out of the JSON output when empty
Params map[string]interface{} `json:"params,omitempty"`
}
// Config is implemented by configuration types that can expose a network
// SpecConfig.
type Config interface {
GetNetworkConfig() *SpecConfig
}
// NewSpecConfig creates a new SpecConfig with the given type and an empty,
// non-nil Params map. The type string is stored as given (not trimmed).
func NewSpecConfig(t string) *SpecConfig {
return &SpecConfig{
Type: t,
Params: make(map[string]interface{}),
}
}
// WithParam stores value under key in the spec params, replacing any existing
// entry, and returns the receiver so calls can be chained. The Params map is
// created on first use.
func (s *SpecConfig) WithParam(key string, value interface{}) *SpecConfig {
	params := s.Params
	if params == nil {
		params = make(map[string]interface{})
		s.Params = params
	}
	params[key] = value
	return s
}
// Normalize brings the spec config into a canonical form: surrounding
// whitespace is trimmed from Type, and a nil or empty Params map is replaced
// by a fresh empty map so both cases are treated the same. A nil receiver is
// a no-op.
func (s *SpecConfig) Normalize() {
	if s == nil {
		return
	}
	s.Type = strings.TrimSpace(s.Type)
	if len(s.Params) > 0 {
		return
	}
	s.Params = map[string]interface{}{}
}
// Validate checks that the spec config is usable: it returns an error for a
// nil receiver or when Type is blank according to validate.IsBlank, and nil
// otherwise. Params are not inspected.
func (s *SpecConfig) Validate() error {
	switch {
	case s == nil:
		return errors.New("nil spec config")
	case validate.IsBlank(s.Type):
		return errors.New("missing spec type")
	default:
		return nil
	}
}
// IsType reports whether the spec's Type equals t, ignoring case. Surrounding
// whitespace is trimmed from t but not from the stored Type. A nil receiver
// never matches.
func (s *SpecConfig) IsType(t string) bool {
	return s != nil && strings.EqualFold(s.Type, strings.TrimSpace(t))
}
// IsEmpty reports whether the spec config carries no information: the
// receiver is nil, or Type is blank (per validate.IsBlank) and Params has no
// entries.
func (s *SpecConfig) IsEmpty() bool {
	if s == nil {
		return true
	}
	if !validate.IsBlank(s.Type) {
		return false
	}
	return len(s.Params) == 0
}
package types
import (
"context"
)
// TelemetryConfig holds the configuration for the telemetry system.
type TelemetryConfig struct {
ServiceName string
GlobalEndpoint string
// ObservabilityLevel is kept as a string here; ParseObservabilityLevel
// converts such strings to an ObservabilityLevel value.
ObservabilityLevel string
// CollectorConfigs holds one CollectorConfig per string key.
CollectorConfigs map[string]CollectorConfig
TelemetryMode string
}
// CollectorConfig holds the configuration for individual collectors.
type CollectorConfig struct {
CollectorType string
CollectorEndpoint string
}
// Event represents a telemetry event with its details.
type Event struct {
// Context is the context.Context associated with the event.
Context context.Context
Level ObservabilityLevel
Message string
Payload map[string]interface{}
}
// ObservabilityLevel defines the levels of observability.
type ObservabilityLevel int

// Observability levels in increasing order, numbered from 1 (TRACE) to 6 (FATAL).
const (
	TRACE ObservabilityLevel = iota + 1
	DEBUG
	INFO
	WARN
	ERROR
	FATAL
)

// ParseObservabilityLevel converts the exact, upper-case name of a level to
// its ObservabilityLevel value. Any other input (including lower-case names)
// yields INFO. The returned error is always nil.
func ParseObservabilityLevel(levelStr string) (ObservabilityLevel, error) {
	for candidate := TRACE; candidate <= FATAL; candidate++ {
		if candidate.String() == levelStr {
			return candidate, nil
		}
	}
	return INFO, nil
}

// String returns the upper-case name of the level, or "UNKNOWN" for values
// outside the TRACE..FATAL range.
func (level ObservabilityLevel) String() string {
	names := [...]string{
		TRACE: "TRACE",
		DEBUG: "DEBUG",
		INFO:  "INFO",
		WARN:  "WARN",
		ERROR: "ERROR",
		FATAL: "FATAL",
	}
	if level < TRACE || level > FATAL {
		return "UNKNOWN"
	}
	return names[level]
}
package types
import (
"time"
"github.com/google/uuid"
"gorm.io/gorm"
)
// BaseDBModel is a base model for all entities. It'll be mainly used for database
// records.
type BaseDBModel struct {
// ID is a string stored with the GORM column type uuid; BeforeCreate fills it.
ID string `gorm:"type:uuid"`
CreatedAt time.Time
UpdatedAt time.Time
// DeletedAt is GORM's deletion timestamp type and is indexed.
DeletedAt gorm.DeletedAt `gorm:"index"`
}
// GetID returns the ID of the entity. It uses a value receiver, so it works
// on both BaseDBModel values and pointers.
func (m BaseDBModel) GetID() string {
return m.ID
}
// BeforeCreate sets the ID and CreatedAt fields before creating a new entity.
// The ID is always replaced with a freshly generated UUID string, even if one
// was already set, and CreatedAt is set to the current time.
// This is a GORM hook and should not be called directly.
// We can move this to generic repository create methods.
func (m *BaseDBModel) BeforeCreate(_ *gorm.DB) error {
m.ID = uuid.NewString()
m.CreatedAt = time.Now()
return nil
}
// BeforeUpdate sets the UpdatedAt field to the current time before updating
// an entity.
// This is a GORM hook and should not be called directly.
// We can move this to generic repository update methods.
func (m *BaseDBModel) BeforeUpdate(_ *gorm.DB) error {
m.UpdatedAt = time.Now()
return nil
}
package types
import (
"math"
"reflect"
"slices"
)
// IsStrictlyContained reports whether every element of rightSlice is also
// present in leftSlice. An empty rightSlice yields false.
func IsStrictlyContained(leftSlice, rightSlice []interface{}) bool {
	if len(rightSlice) == 0 {
		return false
	}
	for _, candidate := range rightSlice {
		if !slices.Contains(leftSlice, candidate) {
			return false
		}
	}
	return true
}
// IsSameShallowType reports whether a and b have the same dynamic type as
// reported by reflect.TypeOf. Only the outermost type is compared; for
// example the element values held inside two []interface{} are not inspected.
func IsSameShallowType(a, b interface{}) bool {
	return reflect.TypeOf(a) == reflect.TypeOf(b)
}
// round rounds value to the given number of decimal places. The computation
// is done in float64 with math.Round (halves round away from zero) and the
// result is converted back to T.
func round[T float32 | float64](value T, places int) T {
	scale := math.Pow(10, float64(places))
	return T(math.Round(float64(value)*scale) / scale)
}
// SliceContainsOneValue reports whether every element of slice equals value.
// It returns true for an empty slice.
func SliceContainsOneValue(slice []Comparison, value Comparison) bool {
	for i := 0; i < len(slice); i++ {
		if slice[i] != value {
			return false
		}
	}
	return true
}
// returnBestMatch picks the most desirable entry of dimension and returns it
// together with its index. Preference order is Equal (the most efficient
// match), then Better, then Worse, then Error; within one preference level
// the first occurrence wins. When dimension contains none of these values
// (e.g. it is empty) it returns Error and -1.
func returnBestMatch(dimension []Comparison) (Comparison, int) {
	for _, wanted := range []Comparison{Equal, Better, Worse, Error} {
		for idx, current := range dimension {
			if current == wanted {
				return current, idx
			}
		}
	}
	return Error, -1
}
// removeIndex removes the element at position index from every sub-slice of
// slice and returns slice. Sub-slices for which index is out of bounds
// (including any negative index) are left untouched. The removal is done in
// place, so the sub-slices' backing arrays are modified.
func removeIndex(slice [][]Comparison, index int) [][]Comparison {
	for i := range slice {
		row := slice[i]
		if index >= 0 && index < len(row) {
			slice[i] = append(row[:index], row[index+1:]...)
		}
	}
	return slice
}
// ConvertTypedSliceToUntypedSlice copies the elements of a slice of any
// element type into a new []interface{} of the same length. It returns nil
// when typedSlice is not a slice (including a nil interface and arrays).
func ConvertTypedSliceToUntypedSlice(typedSlice interface{}) []interface{} {
	source := reflect.ValueOf(typedSlice)
	if source.Kind() != reflect.Slice {
		return nil
	}
	count := source.Len()
	untyped := make([]interface{}, 0, count)
	for i := 0; i < count; i++ {
		untyped = append(untyped, source.Index(i).Interface())
	}
	return untyped
}
// IsStrictlyContainedInt reports whether every element of rightSlice is also
// present in leftSlice. An empty rightSlice yields false.
func IsStrictlyContainedInt(leftSlice, rightSlice []int) bool {
	if len(rightSlice) == 0 {
		return false
	}
	for _, candidate := range rightSlice {
		if !slices.Contains(leftSlice, candidate) {
			return false
		}
	}
	return true
}
package utils
import (
"bytes"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"time"
"github.com/cosmos/btcutil/bech32"
"github.com/ethereum/go-ethereum/common"
"github.com/fivebinaries/go-cardano-serialization/address"
"gitlab.com/nunet/device-management-service/db"
"gitlab.com/nunet/device-management-service/types"
)
// KoiosEndpoint is the host name of a Koios REST API endpoint. Callers in this
// file build request URLs as "https://<endpoint>/api/v1/...".
type KoiosEndpoint string
const (
// KoiosMainnet - mainnet Koios rest api endpoint
KoiosMainnet KoiosEndpoint = "api.koios.rest"
// KoiosPreProd - testnet preprod Koios rest api endpoint
KoiosPreProd KoiosEndpoint = "preprod.koios.rest"
)
// UTXOs is a single entry decoded from the Koios address_utxos response.
type UTXOs struct {
TxHash string `json:"tx_hash"`
IsSpent bool `json:"is_spent"`
}
// TxHashResp is one entry of the list returned by GetJobTxHashes.
type TxHashResp struct {
TxHash string `json:"tx_hash"`
TransactionType string `json:"transaction_type"`
// DateTime is the service's CreatedAt timestamp rendered with time.Time.String.
DateTime string `json:"date_time"`
}
// ClaimCardanoTokenBody is the input of RequestReward; only TxHash is read
// by RequestReward.
type ClaimCardanoTokenBody struct {
ComputeProviderAddress string `json:"compute_provider_address"`
TxHash string `json:"tx_hash"`
}
// RewardRespToCPD is the reward payload returned by RequestReward; every field
// is copied from the stored service record.
type RewardRespToCPD struct {
ServiceProviderAddr string `json:"service_provider_addr"`
ComputeProviderAddr string `json:"compute_provider_addr"`
RewardType string `json:"reward_type,omitempty"`
SignatureDatum string `json:"signature_datum,omitempty"`
MessageHashDatum string `json:"message_hash_datum,omitempty"`
Datum string `json:"datum,omitempty"`
SignatureAction string `json:"signature_action,omitempty"`
MessageHashAction string `json:"message_hash_action,omitempty"`
Action string `json:"action,omitempty"`
}
// UpdateTxStatusBody is the input of UpdateStatus; Address is the address
// whose UTXOs are fetched.
type UpdateTxStatusBody struct {
Address string `json:"address,omitempty"`
}
// GetJobTxHashes returns the tx hash, transaction type and creation time of
// stored services.
//
// clean must be "", "done", "refund" or "withdraw"; any other value is
// rejected. Before listing, services whose transaction_type equals clean are
// deleted (this also runs for the empty string); a failure of that delete is
// logged and otherwise ignored.
//
// With size == 0 every service with a non-NULL tx hash and transaction type
// and a log URL containing "log.nunet.io" is listed; otherwise the selection
// is delegated to getLimitedTransactions(size).
func GetJobTxHashes(size int, clean string) ([]TxHashResp, error) {
	if clean != "done" && clean != "refund" && clean != "withdraw" && clean != "" {
		return nil, fmt.Errorf("invalid clean_tx parameter")
	}

	err := db.DB.Where("transaction_type = ?", clean).Delete(&types.Services{}).Error
	if err != nil {
		// %w is only understood by fmt.Errorf; loggers must use %v.
		zlog.Sugar().Errorf("%v", err)
	}

	resp := make([]TxHashResp, 0)
	services := make([]types.Services, 0)

	if size == 0 {
		err = db.DB.
			Where("tx_hash IS NOT NULL").
			Where("log_url LIKE ?", "%log.nunet.io%").
			Where("transaction_type is NOT NULL").
			Find(&services).Error
		if err != nil {
			zlog.Sugar().Errorf("%v", err)
			return nil, fmt.Errorf("no job deployed to request reward for: %w", err)
		}
	} else {
		services, err = getLimitedTransactions(size)
		if err != nil {
			zlog.Sugar().Errorf("%v", err)
			return nil, fmt.Errorf("could not get limited transactions: %w", err)
		}
	}

	for _, service := range services {
		resp = append(resp, TxHashResp{
			TxHash:          service.TxHash,
			TransactionType: service.TransactionType,
			DateTime:        service.CreatedAt.String(),
		})
	}
	return resp, nil
}
// RequestReward looks up the service stored for claim.TxHash and returns its
// reward data. It returns an error when the lookup fails or when the service's
// JobStatus is "running".
func RequestReward(claim ClaimCardanoTokenBody) (*RewardRespToCPD, error) {
// At some point, management dashboard should send container ID to identify
// against which container we are requesting reward
service := types.Services{
TxHash: claim.TxHash,
}
// Load the service matching the tx hash.
// NOTE(review): this uses Find rather than First; confirm whether a missing
// record is reported as an error here, otherwise an unknown tx hash yields a
// reward built from a mostly empty service.
err := db.DB.Where("tx_hash = ?", claim.TxHash).Find(&service).Error
if err != nil {
zlog.Sugar().Errorln(err)
return nil, fmt.Errorf("unknown tx hash: %w", err)
}
zlog.Sugar().Infof("service found from txHash: %+v", service)
// No reward can be requested while the job is still running.
if service.JobStatus == "running" {
return nil, fmt.Errorf("job is still running")
}
reward := RewardRespToCPD{
ServiceProviderAddr: service.ServiceProviderAddr,
ComputeProviderAddr: service.ComputeProviderAddr,
RewardType: service.TransactionType,
SignatureDatum: service.SignatureDatum,
MessageHashDatum: service.MessageHashDatum,
Datum: service.Datum,
SignatureAction: service.SignatureAction,
MessageHashAction: service.MessageHashAction,
Action: service.Action,
}
return &reward, nil
}
// SendStatus records a successful blockchain transaction and echoes the
// reported status back.
//
// When status.TransactionStatus is "success", the service stored for
// status.TxHash is loaded and its TransactionType is set to "done". Database
// errors are logged and not returned; if the lookup fails nothing is saved,
// so no empty record is written. Any other status leaves the database alone.
func SendStatus(status types.BlockchainTxStatus) string {
	if status.TransactionStatus != "success" {
		return status.TransactionStatus
	}

	zlog.Sugar().Infof("withdraw transaction successful - updating DB")

	// Partial deletion of entry: mark the transaction as done.
	var service types.Services
	if err := db.DB.Where("tx_hash = ?", status.TxHash).Find(&service).Error; err != nil {
		zlog.Sugar().Errorln(err)
		return status.TransactionStatus
	}

	// NOTE(review): Find may succeed without matching any row, in which case
	// service is still the zero value here; confirm whether that case needs
	// an explicit guard.
	service.TransactionType = "done"
	if err := db.DB.Save(&service).Error; err != nil {
		zlog.Sugar().Errorln(err)
	}
	return status.TransactionStatus
}
// UpdateStatus refreshes the transaction status of pending services.
//
// It fetches the UTXOs of body.Address from the Koios pre-production endpoint,
// selects the services created at least five minutes ago that have a tx hash,
// a log URL containing "log.nunet.io" and a transaction type that is neither
// empty nor "done", and passes both to UpdateTransactionStatus.
// Every failure is logged and returned wrapped with its cause.
func UpdateStatus(body UpdateTxStatusBody) error {
	utxoHashes, err := GetUTXOsOfSmartContract(body.Address, KoiosPreProd)
	if err != nil {
		zlog.Sugar().Errorln(err)
		return fmt.Errorf("failed to fetch UTXOs from Blockchain: %w", err)
	}

	fiveMinAgo := time.Now().Add(-5 * time.Minute)
	var services []types.Services
	err = db.DB.
		Where("tx_hash IS NOT NULL").
		Where("log_url LIKE ?", "%log.nunet.io%").
		Where("transaction_type IS NOT NULL").
		Where("deleted_at IS NULL").
		Where("created_at <= ?", fiveMinAgo).
		Not("transaction_type = ?", "done").
		Not("transaction_type = ?", "").
		Find(&services).Error
	if err != nil {
		zlog.Sugar().Errorln(err)
		return fmt.Errorf("no job deployed to request reward for: %w", err)
	}

	err = UpdateTransactionStatus(services, utxoHashes)
	if err != nil {
		zlog.Sugar().Errorln(err)
		// Keep the cause in the chain so callers can inspect it.
		return fmt.Errorf("failed to update transaction status: %w", err)
	}
	return nil
}
// getLimitedTransactions returns every service whose transaction type is
// neither empty nor "done", followed by at most sizeDone of the most recently
// created "done" services. All services must have a tx hash and a log URL
// containing "log.nunet.io". On a query error an empty slice and the error
// are returned.
func getLimitedTransactions(sizeDone int) ([]types.Services, error) {
var doneServices []types.Services
var services []types.Services
// Newest "done" services first, capped at sizeDone.
err := db.DB.
Where("tx_hash IS NOT NULL").
Where("log_url LIKE ?", "%log.nunet.io%").
Where("transaction_type = ?", "done").
Order("created_at DESC").
Limit(sizeDone).
Find(&doneServices).Error
if err != nil {
return []types.Services{}, err
}
// All services that are not done yet (no limit, no explicit ordering).
err = db.DB.
Where("tx_hash IS NOT NULL").
Where("log_url LIKE ?", "%log.nunet.io%").
Where("transaction_type IS NOT NULL").
Not("transaction_type = ?", "done").
Not("transaction_type = ?", "").
Find(&services).Error
if err != nil {
return []types.Services{}, err
}
// Pending services come first, done services are appended after them.
services = append(services, doneServices...)
return services, nil
}
// isValidCardano sets *valid to true when address.NewAddress accepts addr.
// When parsing returns an error, *valid is left as the caller set it; when
// parsing panics, the panic is recovered and *valid is set to false.
func isValidCardano(addr string, valid *bool) {
	defer func() {
		if recovered := recover(); recovered != nil {
			*valid = false
		}
	}()

	_, err := address.NewAddress(addr)
	if err != nil {
		return
	}
	*valid = true
}
// ValidateAddress accepts only valid cardano wallet addresses. Addresses that
// look like ethereum hex addresses are rejected explicitly; everything else
// must pass isValidCardano.
func ValidateAddress(addr string) error {
	if common.IsHexAddress(addr) {
		return errors.New("ethereum wallet address not allowed")
	}

	var isCardano bool
	isValidCardano(addr, &isCardano)
	if !isCardano {
		return errors.New("invalid cardano wallet address")
	}
	return nil
}
// GetAddressPaymentCredential decodes a bech32 address and returns hex
// characters 2 through 57 of the decoded payload, i.e. the 28 bytes that
// follow the first byte.
//
// It returns an error when the address cannot be decoded or when the decoded
// payload is shorter than 29 bytes, instead of panicking on the slice
// expression.
func GetAddressPaymentCredential(addr string) (string, error) {
	_, data, err := bech32.Decode(addr, 1023)
	if err != nil {
		return "", fmt.Errorf("decoding bech32 failed: %w", err)
	}
	converted, err := bech32.ConvertBits(data, 5, 8, false)
	if err != nil {
		return "", fmt.Errorf("decoding bech32 failed: %w", err)
	}

	const credentialStart, credentialEnd = 2, 58
	encoded := hex.EncodeToString(converted)
	if len(encoded) < credentialEnd {
		return "", fmt.Errorf("address payload too short: got %d hex characters, need at least %d", len(encoded), credentialEnd)
	}
	return encoded[credentialStart:credentialEnd], nil
}
// GetTxReceiver returns the receiver of a transaction, looked up by
// transaction hash through the Koios tx_info endpoint.
//
// The receiver is read from the second field of the inline datum of the
// second output of the first returned transaction. An error is returned when
// the request or decoding fails, or when the response has fewer than two
// outputs or fewer than two datum fields.
func GetTxReceiver(txHash string, endpoint KoiosEndpoint) (string, error) {
	type Request struct {
		TxHashes []string `json:"_tx_hashes"`
	}
	reqBody, err := json.Marshal(Request{TxHashes: []string{txHash}})
	if err != nil {
		return "", err
	}

	resp, err := http.Post(
		fmt.Sprintf("https://%s/api/v1/tx_info", endpoint),
		"application/json",
		bytes.NewBuffer(reqBody))
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	res := []struct {
		Outputs []struct {
			InlineDatum struct {
				Value struct {
					Fields []struct {
						Bytes string `json:"bytes"`
					} `json:"fields"`
				} `json:"value"`
			} `json:"inline_datum"`
		} `json:"outputs"`
	}{}
	jsonDecoder := json.NewDecoder(resp.Body)
	// An empty body (io.EOF) is handled by the length checks below.
	if err := jsonDecoder.Decode(&res); err != nil && err != io.EOF {
		return "", err
	}

	// Index 1 is read for both the outputs and the datum fields, so each
	// needs at least two entries.
	if len(res) == 0 || len(res[0].Outputs) < 2 || len(res[0].Outputs[1].InlineDatum.Value.Fields) < 2 {
		return "", fmt.Errorf("unable to find receiver")
	}
	receiver := res[0].Outputs[1].InlineDatum.Value.Fields[1].Bytes
	return receiver, nil
}
// GetTxConfirmations queries the Koios tx_status API at endpoint for txHash
// and returns the num_confirmations value of the last entry in the response.
//
// An error is returned when the request fails, the body cannot be read or
// decoded, or the response contains no entries (previously an empty response
// caused an index-out-of-range panic).
func GetTxConfirmations(txHash string, endpoint KoiosEndpoint) (int, error) {
	type Request struct {
		TxHashes []string `json:"_tx_hashes"`
	}
	reqBody, err := json.Marshal(Request{TxHashes: []string{txHash}})
	if err != nil {
		return 0, fmt.Errorf("encoding tx_status request: %w", err)
	}
	resp, err := http.Post(
		fmt.Sprintf("https://%s/api/v1/tx_status", endpoint),
		"application/json",
		bytes.NewBuffer(reqBody))
	if err != nil {
		return 0, err
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return 0, err
	}
	var res []struct {
		TxHash        string `json:"tx_hash"`
		Confirmations int    `json:"num_confirmations"`
	}
	if err := json.Unmarshal(body, &res); err != nil {
		return 0, err
	}
	if len(res) == 0 {
		return 0, fmt.Errorf("no status returned for transaction %s", txHash)
	}
	return res[len(res)-1].Confirmations, nil
}
// WaitForTxConfirmation polls GetTxConfirmations every 5 seconds until txHash
// has at least the requested number of confirmations.
//
// It returns nil once the threshold is reached, the polling error if a lookup
// fails, or a "timeout" error when timeout elapses first. The first poll
// happens after the first 5-second tick, so a timeout shorter than that
// always expires.
func WaitForTxConfirmation(confirmations int, timeout time.Duration, txHash string, endpoint KoiosEndpoint) error {
	ticker := time.NewTicker(5 * time.Second)
	defer ticker.Stop()

	// One timer for the whole wait. Creating the timeout channel inside the
	// loop would restart the countdown after every tick, so any timeout
	// longer than the tick interval would never fire.
	deadline := time.NewTimer(timeout)
	defer deadline.Stop()

	for {
		select {
		case <-ticker.C:
			conf, err := GetTxConfirmations(txHash, endpoint)
			if err != nil {
				return err
			}
			if conf >= confirmations {
				return nil
			}
		case <-deadline.C:
			return errors.New("timeout")
		}
	}
}
// GetUTXOsOfSmartContract asks the Koios address_utxos API at endpoint for
// the extended UTXO list of address and returns the tx_hash of every entry.
// The returned slice is non-nil even when no UTXOs are reported.
func GetUTXOsOfSmartContract(address string, endpoint KoiosEndpoint) ([]string, error) {
	type utxoQuery struct {
		Address  []string `json:"_addresses"`
		Extended bool     `json:"_extended"`
	}

	payload, _ := json.Marshal(utxoQuery{Address: []string{address}, Extended: true})
	target := fmt.Sprintf("https://%s/api/v1/address_utxos", endpoint)
	resp, err := http.Post(target, "application/json", bytes.NewBuffer(payload))
	if err != nil {
		return nil, fmt.Errorf("error making POST request: %v", err)
	}
	defer resp.Body.Close()

	// An empty body (io.EOF) is treated as "no UTXOs", not as an error.
	var utxos []UTXOs
	if err := json.NewDecoder(resp.Body).Decode(&utxos); err != nil && err != io.EOF {
		return nil, err
	}

	hashes := make([]string, 0)
	for i := range utxos {
		hashes = append(hashes, utxos[i].TxHash)
	}
	return hashes, nil
}
// UpdateTransactionStatus updates the status of claimed transactions in the
// local DB.
//
// Every service whose TxHash is absent from utxoHashes has its
// TransactionType rewritten ("withdraw" -> withdrawn, "refund" -> refunded,
// "distribute-50"/"distribute-75" -> distributed; other values are kept) and
// is then saved. Services whose TxHash is still listed are left alone. The
// first save error aborts the loop and is returned.
func UpdateTransactionStatus(services []types.Services, utxoHashes []string) error {
	for _, service := range services {
		if SliceContains(utxoHashes, service.TxHash) {
			continue
		}

		switch service.TransactionType {
		case "withdraw":
			service.TransactionType = transactionWithdrawnStatus
		case "refund":
			service.TransactionType = transactionRefundedStatus
		// Go cases do not fall through: both values must share one case,
		// otherwise "distribute-50" would match an empty body and be skipped.
		case "distribute-50", "distribute-75":
			service.TransactionType = transactionDistributedStatus
		}

		s := service
		if err := db.DB.Save(&s).Error; err != nil {
			return err
		}
	}
	return nil
}
package utils
import (
"errors"
"fmt"
"os"
"path/filepath"
"github.com/spf13/afero"
)
// GetDirectorySize walks path on fs and returns the summed size of every
// non-directory entry found. Any error reported during the walk aborts it
// and is returned wrapped, with a size of 0.
func GetDirectorySize(fs afero.Fs, path string) (int64, error) {
	var total int64

	accumulate := func(_ string, info os.FileInfo, walkErr error) error {
		if walkErr != nil {
			return walkErr
		}
		if info.IsDir() {
			return nil
		}
		total += info.Size()
		return nil
	}

	if err := afero.Walk(fs, path, accumulate); err != nil {
		return 0, fmt.Errorf("failed to calculate volume size: %w", err)
	}
	return total, nil
}
// WriteToFile creates (or truncates) filePath on fs, creating missing parent
// directories first, and writes data to it. On success it returns filePath;
// on failure it returns an empty string and an error describing which step
// failed, including a short write.
func WriteToFile(fs afero.Fs, data []byte, filePath string) (string, error) {
	parent := filepath.Dir(filePath)
	if err := fs.MkdirAll(parent, os.ModePerm); err != nil {
		return "", fmt.Errorf("failed to open path: %w", err)
	}

	out, err := fs.Create(filePath)
	if err != nil {
		return "", fmt.Errorf("failed to create path: %w", err)
	}
	defer out.Close()

	written, err := out.Write(data)
	switch {
	case err != nil:
		return "", fmt.Errorf("failed to write data to path: %w", err)
	case written != len(data):
		return "", errors.New("failed to write the size of data to file")
	}
	return filePath, nil
}
// FileExists reports whether filename exists on fs and is not a directory.
//
// Any Stat failure yields false — not only "does not exist". Previously an
// error such as permission denied left the file info nil and the subsequent
// IsDir call panicked.
func FileExists(fs afero.Fs, filename string) bool {
	info, err := fs.Stat(filename)
	if err != nil {
		return false
	}
	return !info.IsDir()
}
package utils
import (
"bytes"
"fmt"
"io"
"net/http"
"net/url"
"path"
)
// HTTPClient bundles the settings MakeRequest needs to call a versioned HTTP
// API: where the API lives, which version prefix to use and which client
// performs the calls.
type HTTPClient struct {
// BaseURL is parsed with url.Parse on every MakeRequest call.
BaseURL string
// APIVersion is joined in front of the relative path of each request.
APIVersion string
// Client sends the requests; NewHTTPClient sets it to http.DefaultClient.
Client *http.Client
}
// NewHTTPClient returns an HTTPClient for the API at baseURL using the given
// version prefix. Requests are sent through http.DefaultClient.
func NewHTTPClient(baseURL, version string) *HTTPClient {
	client := new(HTTPClient)
	client.BaseURL = baseURL
	client.APIVersion = version
	client.Client = http.DefaultClient
	return client
}
// MakeRequest sends an HTTP request with the given method and body to the
// configured API and returns the response body, the HTTP status code and an
// error, if any.
//
// The request URL is BaseURL with its path replaced by APIVersion joined
// with relativePath. The request is sent with JSON Content-Type and Accept
// headers. The status code is returned as-is; non-2xx responses are not
// turned into errors. On any error the body is nil and the status code 0.
func (c *HTTPClient) MakeRequest(method, relativePath string, body []byte) ([]byte, int, error) {
	target, err := url.Parse(c.BaseURL)
	if err != nil {
		return nil, 0, fmt.Errorf("failed to parse base URL: %v", err)
	}
	target.Path = path.Join(c.APIVersion, relativePath)

	request, err := http.NewRequest(method, target.String(), bytes.NewBuffer(body))
	if err != nil {
		return nil, 0, fmt.Errorf("failed to create request: %v", err)
	}
	request.Header.Set("Content-Type", "application/json; charset=utf-8")
	request.Header.Set("Accept", "application/json")

	response, err := c.Client.Do(request)
	if err != nil {
		return nil, 0, fmt.Errorf("request failed: %v", err)
	}
	defer response.Body.Close()

	payload, err := io.ReadAll(response.Body)
	if err != nil {
		return nil, 0, fmt.Errorf("failed to read response body: %v", err)
	}
	return payload, response.StatusCode, nil
}
package utils
import (
"github.com/uptrace/opentelemetry-go-extra/otelzap"
"gitlab.com/nunet/device-management-service/telemetry/logger"
)
// zlog is the package-level logger. It is assigned in init.
var zlog *otelzap.Logger

// Values written to a service's TransactionType by UpdateTransactionStatus
// once the corresponding transaction has been claimed.
const (
	transactionWithdrawnStatus   = "withdrawn"
	transactionRefundedStatus    = "refunded"
	transactionDistributedStatus = "distributed"
)

// init creates the package logger under the name "utils".
func init() {
	zlog = logger.OtelZapLogger("utils")
}
package utils
import (
"io"
"sync"
"time"
)
// IOProgress holds the progress counters of a wrapped reader or writer.
// Its methods use value receivers, so they operate on a copy of the state.
type IOProgress struct {
// n is the running total of bytes reported by the wrapped Read/Write calls.
n float64
// size is the expected total, taken from the constructor's size argument.
// Complete treats a size of -1 as "unknown".
size float64
// started is set to time.Now() when the wrapper is constructed.
started time.Time
// estimated is the projected completion time; Remaining uses it when it is
// non-zero and falls back to Estimated otherwise.
estimated time.Time
// err is the error returned by the most recent wrapped Read/Write call.
err error
}
// Reader wraps an io.Reader and records how much data has been read through
// it in Progress.
type Reader struct {
// reader is the wrapped source that Read delegates to.
reader io.Reader
// lock is held by Read while it updates Progress.
lock sync.RWMutex
// Progress is updated on every Read call.
// NOTE(review): the field is exported and Read only locks its own update,
// so concurrent readers of Progress are not synchronised by this type.
Progress IOProgress
}
// Writer wraps an io.Writer and records how much data has been written
// through it in Progress.
type Writer struct {
// writer is the wrapped destination that Write delegates to.
writer io.Writer
// lock is held by Write while it updates Progress.
lock sync.RWMutex
// Progress is updated on every Write call.
// NOTE(review): the field is exported and Write only locks its own update,
// so concurrent readers of Progress are not synchronised by this type.
Progress IOProgress
}
// ReaderWithProgress wraps r in a Reader that tracks progress against an
// expected total of size bytes. The start time is recorded at construction.
func ReaderWithProgress(r io.Reader, size int64) *Reader {
	progress := IOProgress{
		size:    float64(size),
		started: time.Now(),
	}
	return &Reader{reader: r, Progress: progress}
}
// WriterWithProgress wraps w in a Writer that tracks progress against an
// expected total of size bytes. The start time is recorded at construction.
func WriterWithProgress(w io.Writer, size int64) *Writer {
	progress := IOProgress{
		size:    float64(size),
		started: time.Now(),
	}
	return &Writer{writer: w, Progress: progress}
}
// Read reads from the wrapped reader into p and then, under the lock, adds
// the number of bytes read to Progress and stores the returned error there.
// The byte count and error of the wrapped Read are passed through unchanged.
func (r *Reader) Read(p []byte) (n int, err error) {
	n, err = r.reader.Read(p)

	r.lock.Lock()
	defer r.lock.Unlock()
	r.Progress.err = err
	r.Progress.n += float64(n)
	return n, err
}
// Write writes p to the wrapped writer and then, under the lock, adds the
// number of bytes written to Progress and stores the returned error there.
// The byte count and error of the wrapped Write are passed through unchanged.
func (w *Writer) Write(p []byte) (n int, err error) {
	n, err = w.writer.Write(p)

	w.lock.Lock()
	defer w.lock.Unlock()
	w.Progress.err = err
	w.Progress.n += float64(n)
	return n, err
}
// Size returns the expected total size recorded when the progress tracker
// was constructed.
func (p IOProgress) Size() float64 {
return p.size
}
// N returns the number of bytes counted so far.
func (p IOProgress) N() float64 {
return p.n
}
// Complete reports whether the transfer has finished: the last recorded
// error is io.EOF, or the size is known (not -1) and the byte count has
// reached it.
func (p IOProgress) Complete() bool {
	switch {
	case p.err == io.EOF:
		return true
	case p.size == -1:
		return false
	default:
		return p.n >= p.size
	}
}
// Percent returns how much of the expected size has been transferred, as a
// value between 0 and 100. It is 0 while nothing has been counted and 100
// once the count reaches or exceeds the size.
// NOTE(review): with size == -1 (treated as unknown by Complete) any
// non-zero count yields 100 here — confirm that is intended.
func (p IOProgress) Percent() float64 {
	switch {
	case p.n == 0:
		return 0
	case p.n >= p.size:
		return 100
	}
	return 100.0 / (p.size / p.n)
}
// Remaining returns the time left until the estimated completion time. The
// stored estimate is used when present; otherwise a fresh one is computed
// with Estimated.
func (p IOProgress) Remaining() time.Duration {
	eta := p.estimated
	if eta.IsZero() {
		eta = p.Estimated()
	}
	return time.Until(eta)
}
// Estimated projects the completion time by scaling the elapsed time by the
// inverse of the fraction transferred so far. When nothing has been counted
// yet it returns the stored estimate unchanged.
//
// The receiver is a value, so the projection is not stored on the caller's
// IOProgress.
func (p IOProgress) Estimated() time.Time {
	if p.n > 0.0 {
		elapsed := float64(time.Since(p.started))
		fraction := p.n / p.size
		return p.started.Add(time.Duration(elapsed / fraction))
	}
	return p.estimated
}
package utils
import (
"fmt"
"strings"
"sync"
)
// A SyncMap is a concurrency-safe sync.Map that uses strongly-typed
// method signatures to ensure the types of its stored data are known.
//
// sync.Map is embedded, so its untyped methods (Load, Store, Range, ...)
// remain callable on a SyncMap. The typed accessors type-assert stored
// entries to K and V.
type SyncMap[K comparable, V any] struct {
sync.Map
}
// SyncMapFromMap builds a new SyncMap holding a copy of every key-value pair
// of m. A nil or empty m yields an empty SyncMap.
func SyncMapFromMap[K comparable, V any](m map[K]V) *SyncMap[K, V] {
	out := new(SyncMap[K, V])
	for key, val := range m {
		out.Store(key, val)
	}
	return out
}
// Get looks up key and returns its value together with true. When the key is
// absent it returns the zero value of V and false.
func (m *SyncMap[K, V]) Get(key K) (V, bool) {
	if raw, found := m.Load(key); found {
		return raw.(V), true
	}
	var zero V
	return zero, false
}
// Put inserts or updates a key-value pair in the map by delegating to the
// embedded sync.Map's Store.
func (m *SyncMap[K, V]) Put(key K, value V) {
m.Store(key, value)
}
// Iter calls ranger for each key-value pair in the map, with both converted
// to their typed form. Iteration stops as soon as ranger returns false.
func (m *SyncMap[K, V]) Iter(ranger func(key K, value V) bool) {
	visit := func(rawKey, rawValue any) bool {
		return ranger(rawKey.(K), rawValue.(V))
	}
	m.Range(visit)
}
// Keys collects every key currently in the map into a slice. The result is
// nil for an empty map.
func (m *SyncMap[K, V]) Keys() []K {
	var collected []K
	collect := func(k K, _ V) bool {
		collected = append(collected, k)
		return true
	}
	m.Iter(collect)
	return collected
}
// String returns a human-readable representation of the map in the form
// "{k1=v1, k2=v2}". An empty map renders as "{}".
//
// Keys and values are formatted with %v so that non-string types print
// their natural representation rather than "%!s(...)" noise, and pairs are
// separated by ", " so that adjacent entries no longer run together.
func (m *SyncMap[K, V]) String() string {
	var sb strings.Builder
	sb.WriteByte('{')

	first := true
	m.Range(func(key, value any) bool {
		if !first {
			sb.WriteString(", ")
		}
		first = false
		sb.WriteString(fmt.Sprintf("%v=%v", key, value))
		return true
	})

	sb.WriteByte('}')
	return sb.String()
}
package utils
import (
"archive/tar"
"bufio"
"compress/gzip"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"log"
"math/big"
"net/http"
"os"
"path/filepath"
"reflect"
"strings"
"time"
"github.com/google/uuid"
"github.com/spf13/afero"
"golang.org/x/exp/slices"
"gitlab.com/nunet/device-management-service/db"
"gitlab.com/nunet/device-management-service/types"
)
// Download URLs and local destination paths for a kernel image and a root
// filesystem image.
// NOTE(review): the "fc" URL segment and the vmlinux/ext4 file names suggest
// these are Firecracker VM assets — confirm.
const (
// KernelFileURL is the URL the vmlinux kernel file is served from.
KernelFileURL = "https://d.nunet.io/fc/vmlinux"
// KernelFilePath is the local path for the vmlinux kernel file.
KernelFilePath = "/etc/nunet/vmlinux"
// FilesystemURL is the URL the ext4 filesystem image is served from.
FilesystemURL = "https://d.nunet.io/fc/nunet-fc-ubuntu-20.04-0.ext4"
// FilesystemPath is the local path for the ext4 filesystem image.
FilesystemPath = "/etc/nunet/nunet-fc-ubuntu-20.04-0.ext4"
)
// DownloadFile downloads url with an HTTP GET and stores the response body
// in the file at filepath, creating or truncating that file first.
//
// It returns an error when the file cannot be created, the request fails,
// the server answers with a non-2xx status (so an error page is never saved
// as if it were the requested file), or copying the body fails. The
// destination file is created before the request is made and is left in
// place on failure.
func DownloadFile(url string, filepath string) (err error) {
	// Infof needs explicit verbs; passing the pieces Println-style produced
	// "%!(EXTRA ...)" output instead of the intended message.
	zlog.Sugar().Infof("Downloading file '%s' from '%s'", filepath, url)

	file, err := os.Create(filepath)
	if err != nil {
		return err
	}
	defer file.Close()

	req, err := http.NewRequest(http.MethodGet, url, nil)
	if err != nil {
		return err
	}
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
		return fmt.Errorf("downloading %s: unexpected HTTP status %s", url, resp.Status)
	}

	if _, err = io.Copy(file, resp.Body); err != nil {
		return err
	}
	log.Println("Finished downloading file '", filepath, "'")
	return nil
}
// ReadHTTPString performs an HTTP GET on url with the default client and
// returns the whole response body as a string. The status code is not
// inspected. On any failure the returned string is empty.
func ReadHTTPString(url string) (string, error) {
	request, err := http.NewRequest(http.MethodGet, url, nil)
	if err != nil {
		return "", err
	}

	response, err := http.DefaultClient.Do(request)
	if err != nil {
		return "", err
	}
	defer response.Body.Close()

	payload, err := io.ReadAll(response.Body)
	if err != nil {
		return "", err
	}
	return string(payload), nil
}
// RandomString returns a string of n characters drawn uniformly from the
// ASCII letters a-z and A-Z using crypto/rand.
//
// n == 0 yields an empty string. A negative n returns an error (previously
// it panicked inside strings.Builder.Grow). An error is also returned when
// the random source fails.
func RandomString(n int) (string, error) {
	const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"

	if n < 0 {
		return "", fmt.Errorf("invalid random string length %d", n)
	}

	// The upper bound never changes, so build it once outside the loop.
	limit := big.NewInt(int64(len(charset)))

	var sb strings.Builder
	sb.Grow(n)
	for i := 0; i < n; i++ {
		idx, err := rand.Int(rand.Reader, limit)
		if err != nil {
			return "", err
		}
		sb.WriteByte(charset[idx.Int64()])
	}
	return sb.String(), nil
}
// GenerateMachineUUID generates a machine uuid
func GenerateMachineUUID() (string, error) {
var machine types.MachineUUID
machineUUID, err := uuid.NewDCEGroup()
if err != nil {
return "", err
}
machine.UUID = machineUUID.String()
return machine.UUID, nil
}
// GetMachineUUID returns the machine uuid from the DB
func GetMachineUUID() string {
var machine types.MachineUUID
uuid, err := GenerateMachineUUID()
if err != nil {
zlog.Sugar().Errorf("could not generate machine uuid: %v", err)
}
machine.UUID = uuid
result := db.DB.FirstOrCreate(&machine)
if result.Error != nil {
zlog.Sugar().Errorf("could not find or create machine uuid record in DB: %v", result.Error)
}
return machine.UUID
}
// SliceContains reports whether str is one of the elements of s. A nil or
// empty slice never contains anything.
func SliceContains(s []string, str string) bool {
	for i := 0; i < len(s); i++ {
		if s[i] == str {
			return true
		}
	}
	return false
}
// DeleteFile removes the file at path. When backup is true the file is not
// deleted but renamed to "<path>.bk.<unix-seconds>" instead. The error of
// the underlying remove/rename is returned.
func DeleteFile(path string, backup bool) (err error) {
	if !backup {
		return os.Remove(path)
	}
	backupPath := fmt.Sprintf("%s.bk.%d", path, time.Now().Unix())
	return os.Rename(path, backupPath)
}
// ReadyForElastic reports whether the ElasticToken record loaded from the DB
// has both a NodeID and a ChannelName set. The result of the DB lookup
// itself is not checked, so a failed lookup yields false via the empty
// record.
func ReadyForElastic() bool {
	var token types.ElasticToken
	db.DB.Find(&token)
	if token.NodeID == "" {
		return false
	}
	return token.ChannelName != ""
}
// CreateDirectoryIfNotExists creates path (and any missing parents) with
// mode 0755 when Stat reports that it does not exist. In every other case,
// including other Stat errors, nothing is done and nil is returned.
func CreateDirectoryIfNotExists(path string) error {
	_, statErr := os.Stat(path)
	if !os.IsNotExist(statErr) {
		return nil
	}
	return os.MkdirAll(path, 0o755)
}
// CalculateSHA256Checksum streams the file at filePath through SHA-256 and
// returns the digest as a lowercase hexadecimal string. Open and read errors
// are returned unchanged.
func CalculateSHA256Checksum(filePath string) (string, error) {
	source, err := os.Open(filePath)
	if err != nil {
		return "", err
	}
	defer source.Close()

	hasher := sha256.New()
	if _, err = io.Copy(hasher, source); err != nil {
		return "", err
	}
	return hex.EncodeToString(hasher.Sum(nil)), nil
}
// CreateCheckSumFile writes checksum to "<filePath>.sha256.txt", creating or
// truncating that file, and returns the path of the checksum file.
func CreateCheckSumFile(filePath string, checksum string) (string, error) {
	target := filePath + ".sha256.txt"

	out, err := os.Create(target)
	if err != nil {
		return "", fmt.Errorf("unable to create SHA-256 checksum file: %v", err)
	}
	defer out.Close()

	if _, err := out.WriteString(checksum); err != nil {
		return "", fmt.Errorf("unable to write to SHA-256 checksum file: %v", err)
	}
	return target, nil
}
// SanitizeArchivePath joins the archive entry name t onto the destination
// directory d and guards against "G305: Zip Slip" path traversal.
//
// It returns the joined, cleaned path when that path is d itself or lies
// inside d. Otherwise it returns an empty path and a "content filepath is
// tainted" error.
//
// Containment is decided on the relative path from d to the result, not on
// a raw string prefix: a prefix test wrongly accepts sibling directories
// that merely share the prefix (e.g. "/tmp/foobar" for d = "/tmp/foo").
func SanitizeArchivePath(d, t string) (v string, err error) {
	base := filepath.Clean(d)
	v = filepath.Join(base, t)

	rel, relErr := filepath.Rel(base, v)
	escapes := relErr != nil ||
		rel == ".." ||
		strings.HasPrefix(rel, ".."+string(filepath.Separator))
	if escapes {
		return "", fmt.Errorf("%s: %s", "content filepath is tainted", t)
	}
	return v, nil
}
// ExtractTarGzToPath extracts the gzip-compressed tar archive at
// tarGzFilePath into extractedPath, creating that directory if needed.
//
// Every entry name is validated with SanitizeArchivePath before anything is
// written. Directory entries are created as directories; every other entry
// is written as a file with the entry's content. Extraction stops at the
// first error, leaving already extracted entries in place.
func ExtractTarGzToPath(tarGzFilePath, extractedPath string) error {
	// Ensure the target directory exists; create it if it doesn't.
	if err := os.MkdirAll(extractedPath, os.ModePerm); err != nil {
		return fmt.Errorf("error creating target directory: %v", err)
	}

	tarGzFile, err := os.Open(tarGzFilePath)
	if err != nil {
		return fmt.Errorf("error opening tar.gz file: %v", err)
	}
	defer tarGzFile.Close()

	gzipReader, err := gzip.NewReader(tarGzFile)
	if err != nil {
		return fmt.Errorf("error creating gzip reader: %v", err)
	}
	defer gzipReader.Close()

	tarReader := tar.NewReader(gzipReader)
	for {
		header, err := tarReader.Next()
		if err == io.EOF {
			break
		}
		if err != nil {
			return fmt.Errorf("error reading tar header: %v", err)
		}

		// Build the destination path and reject entries escaping the target.
		fullTargetPath, err := SanitizeArchivePath(extractedPath, header.Name)
		if err != nil {
			return fmt.Errorf("failed to santize path %w", err)
		}

		if header.FileInfo().IsDir() {
			if err := os.MkdirAll(fullTargetPath, os.ModePerm); err != nil {
				return fmt.Errorf("error creating directory: %v", err)
			}
			continue
		}

		// Archives may list a file before (or without) its parent directory.
		if err := os.MkdirAll(filepath.Dir(fullTargetPath), os.ModePerm); err != nil {
			return fmt.Errorf("error creating directory: %v", err)
		}
		if err := writeTarEntry(fullTargetPath, tarReader); err != nil {
			return err
		}
	}
	return nil
}

// writeTarEntry creates targetPath and fills it with the remaining content
// of the current tar entry read from src. The file is closed before
// returning, so extracting a large archive does not accumulate open file
// descriptors; a Close failure is reported when no earlier error occurred.
func writeTarEntry(targetPath string, src io.Reader) (err error) {
	newFile, err := os.Create(targetPath)
	if err != nil {
		return fmt.Errorf("error creating file: %v", err)
	}
	defer func() {
		if closeErr := newFile.Close(); closeErr != nil && err == nil {
			err = fmt.Errorf("error closing file: %w", closeErr)
		}
	}()

	// Copy in bounded chunks until the entry is exhausted.
	for {
		if _, copyErr := io.CopyN(newFile, src, 1024); copyErr != nil {
			if copyErr == io.EOF {
				return nil
			}
			return copyErr
		}
	}
}
// CheckWSL reports whether the system appears to run under WSL by scanning
// /proc/version (read through afs) for a line containing "Microsoft" or
// "WSL". Failure to open or scan the file is returned as an error.
func CheckWSL(afs afero.Afero) (bool, error) {
	procVersion, err := afs.Open("/proc/version")
	if err != nil {
		return false, err
	}
	defer procVersion.Close()

	lines := bufio.NewScanner(procVersion)
	for lines.Scan() {
		text := lines.Text()
		for _, marker := range []string{"Microsoft", "WSL"} {
			if strings.Contains(text, marker) {
				return true, nil
			}
		}
	}
	if scanErr := lines.Err(); scanErr != nil {
		return false, scanErr
	}
	return false, nil
}
// SaveServiceInfo updates the locally stored service record that has the
// same tx_hash as cpService with the values of cpService, keeping the stored
// record's ID and CreatedAt. It is used on the SP side so the SP user can
// claim the reward.
func SaveServiceInfo(cpService types.Services) error {
	var existing types.Services
	lookup := db.DB.Model(&types.Services{}).Where("tx_hash = ?", cpService.TxHash).Find(&existing)
	if lookup.Error != nil {
		return fmt.Errorf("unable to find service on SP side: %v", lookup.Error)
	}

	// Carry over the identity of the stored row so the update does not
	// overwrite it with the incoming values.
	cpService.ID = existing.ID
	cpService.CreatedAt = existing.CreatedAt

	update := db.DB.Model(&types.Services{}).Where("tx_hash = ?", cpService.TxHash).Updates(&cpService)
	if update.Error != nil {
		return fmt.Errorf("unable to update service info on SP side: %v", update.Error.Error())
	}
	return nil
}
// RandomBool returns a random boolean drawn from crypto/rand. An error from
// the random source is returned together with false.
func RandomBool() (bool, error) {
	bit, err := rand.Int(rand.Reader, big.NewInt(2))
	if err != nil {
		return false, err
	}
	// bit is either 0 or 1; any non-zero value means true.
	return bit.Int64() != 0, nil
}
// IsExecutorType reports whether v can be type-asserted to
// types.ExecutorType.
func IsExecutorType(v interface{}) bool {
	switch v.(type) {
	case types.ExecutorType:
		return true
	default:
		return false
	}
}
// IsGPUVendor reports whether v can be type-asserted to types.GPUVendor.
func IsGPUVendor(v interface{}) bool {
	switch v.(type) {
	case types.GPUVendor:
		return true
	default:
		return false
	}
}
// IsJobType reports whether v can be type-asserted to types.JobType.
func IsJobType(v interface{}) bool {
	switch v.(type) {
	case types.JobType:
		return true
	default:
		return false
	}
}
// IsJobTypes reports whether v can be type-asserted to types.JobTypes.
func IsJobTypes(v interface{}) bool {
	switch v.(type) {
	case types.JobTypes:
		return true
	default:
		return false
	}
}
// IsExecutor reports whether v can be type-asserted to types.Executor.
func IsExecutor(v interface{}) bool {
	switch v.(type) {
	case types.Executor:
		return true
	default:
		return false
	}
}
// IsStrictlyContained reports whether every element of rightSlice is also
// present in leftSlice. An empty rightSlice yields false.
func IsStrictlyContained(leftSlice, rightSlice []interface{}) bool {
	if len(rightSlice) == 0 {
		return false
	}
	for _, wanted := range rightSlice {
		if !slices.Contains(leftSlice, wanted) {
			return false
		}
	}
	return true
}
// IsStrictlyContainedInt reports whether every element of rightSlice is also
// present in leftSlice. An empty rightSlice yields false.
func IsStrictlyContainedInt(leftSlice, rightSlice []int) bool {
	if len(rightSlice) == 0 {
		return false
	}
	for _, wanted := range rightSlice {
		if !slices.Contains(leftSlice, wanted) {
			return false
		}
	}
	return true
}
// NoIntersectionSlices reports whether slice1 and slice2 have no element in
// common. It returns false as soon as any element of slice1 is found in
// slice2.
//
// An empty slice1 yields false, preserving the long-standing default of this
// function. Previously the result depended only on the last element of
// slice1, so an earlier shared element could be missed.
func NoIntersectionSlices(slice1, slice2 []interface{}) bool {
	if len(slice1) == 0 {
		return false
	}
	for _, element := range slice1 {
		if slices.Contains(slice2, element) {
			return false
		}
	}
	return true
}
// IntersectionSlices returns the elements of slice2 that also occur in
// slice1, in slice2 order and without duplicates. The result is an empty,
// non-nil slice when there is no overlap. Elements are used as map keys, so
// they must be hashable.
func IntersectionSlices(slice1, slice2 []interface{}) []interface{} {
	// Remember every element of the first slice.
	seen := make(map[interface{}]bool)
	for _, item := range slice1 {
		seen[item] = true
	}

	common := []interface{}{}
	for _, candidate := range slice2 {
		if !seen[candidate] {
			continue
		}
		common = append(common, candidate)
		// Drop the entry so a repeated element is reported only once.
		delete(seen, candidate)
	}
	return common
}
// IsSameShallowType reports whether a and b have the same dynamic type as
// seen by reflect.TypeOf. Two nil interfaces compare as equal.
func IsSameShallowType(a, b interface{}) bool {
	return reflect.TypeOf(a) == reflect.TypeOf(b)
}
// ConvertTypedSliceToUntypedSlice copies the elements of typedSlice, which
// must be a slice of any element type, into a new []interface{} of the same
// length. It returns nil when typedSlice is not a slice.
func ConvertTypedSliceToUntypedSlice(typedSlice interface{}) []interface{} {
	val := reflect.ValueOf(typedSlice)
	if val.Kind() != reflect.Slice {
		return nil
	}

	count := val.Len()
	out := make([]interface{}, 0, count)
	for i := 0; i < count; i++ {
		out = append(out, val.Index(i).Interface())
	}
	return out
}
package validate
import (
"strings"
)
// IsBlank reports whether s is empty or consists solely of whitespace.
func IsBlank(s string) bool {
	return strings.TrimSpace(s) == ""
}
// IsNotBlank reports whether s contains at least one non-whitespace
// character.
func IsNotBlank(s string) bool {
	return strings.TrimSpace(s) != ""
}
// IsLiteral reports whether the dynamic type of s is string.
func IsLiteral(s interface{}) bool {
	_, isString := s.(string)
	return isString
}