package actor
import (
"context"
"encoding/json"
"errors"
"fmt"
"sync"
bt "gitlab.com/nunet/device-management-service/internal/background_tasks"
"gitlab.com/nunet/device-management-service/network"
"gitlab.com/nunet/device-management-service/types"
)
// BasicActor is the default Actor implementation. It connects a Dispatch
// (signature verification and behavior dispatch) to the network transport,
// a security context used to sign and verify envelopes, and a rate limiter.
type BasicActor struct {
	dispatch  *Dispatch
	scheduler *bt.Scheduler
	registry  *registry
	network   network.Network
	security  SecurityContext
	limiter   RateLimiter
	params    BasicActorParams
	self      Handle // this actor's own handle
	// mx is held by Subscribe while it reads and writes subscriptions.
	mx sync.Mutex
	// subscriptions maps a pubsub topic to the subscription ID returned by
	// network.Subscribe.
	subscriptions map[string]uint64
}
// BasicActorParams carries construction parameters for BasicActor; it
// currently has no fields.
type BasicActorParams struct{}

// Compile-time check that *BasicActor implements Actor.
var _ Actor = (*BasicActor)(nil)
// New creates a new basic actor.
//
// scheduler, net and security are required; a nil value for any of them is
// reported as an error. A nil limiter is replaced by NoRateLimiter{}: the
// actor (handleMessage, validateBroadcast) and its dispatch loop call the
// limiter unconditionally, so storing nil would panic on the first message.
// The extra dispatch options in opt are applied after the rate-limiter
// option.
func New(scheduler *bt.Scheduler, net network.Network, security *BasicSecurityContext, limiter RateLimiter, params BasicActorParams, self Handle, opt ...DispatchOption) (*BasicActor, error) {
	if scheduler == nil {
		return nil, errors.New("scheduler is nil")
	}
	if net == nil {
		return nil, errors.New("network is nil")
	}
	if security == nil {
		return nil, errors.New("security is nil")
	}
	if limiter == nil {
		limiter = NoRateLimiter{}
	}

	dispatchOptions := make([]DispatchOption, 0, 1+len(opt))
	dispatchOptions = append(dispatchOptions, WithRateLimiter(limiter))
	dispatchOptions = append(dispatchOptions, opt...)

	actor := &BasicActor{
		dispatch:      NewDispatch(security, dispatchOptions...),
		scheduler:     scheduler,
		registry:      newRegistry(),
		network:       net,
		security:      security,
		limiter:       limiter,
		params:        params,
		self:          self,
		subscriptions: make(map[string]uint64),
	}
	return actor, nil
}
// Start registers the handler for this actor's inbox protocol
// ("actor/<inbox>/messages/0.0.1") with the network. It then starts the
// dispatch and scheduler goroutines. A registration failure is returned
// before anything is started.
func (a *BasicActor) Start() error {
	protocol := fmt.Sprintf("actor/%s/messages/0.0.1", a.self.Address.InboxAddress)
	err := a.network.HandleMessage(protocol, a.handleMessage)
	if err != nil {
		return fmt.Errorf("starting actor: %s: %w", a.self.ID, err)
	}

	a.dispatch.Start()
	a.scheduler.Start()
	return nil
}
// handleMessage is the network handler for direct messages to this actor's
// inbox. It decodes the envelope, drops messages addressed to another actor
// or refused by the rate limiter, and passes the rest to Receive.
func (a *BasicActor) handleMessage(data []byte) {
	var msg Envelope
	if err := json.Unmarshal(data, &msg); err != nil {
		log.Debugf("error unmarshaling message: %s", err)
		return
	}
	if !a.self.ID.Equal(msg.To.ID) {
		log.Warnf("message is not for ourselves: %s %s", a.self.ID, msg.To.ID)
		return
	}
	if !a.limiter.Allow(msg) {
		log.Warnf("incoming message invoking %s not allowed by limiter", msg.Behavior)
		return
	}
	// Receive can fail, e.g. once the dispatch context has been cancelled.
	// Report it instead of discarding it silently, as handleBroadcast does.
	if err := a.Receive(msg); err != nil {
		log.Debugf("error receiving message: %s", err)
	}
}
// Context returns the dispatch context; Stop cancels it.
func (a *BasicActor) Context() context.Context {
	ctx := a.dispatch.Context()
	return ctx
}

// Handle returns this actor's own handle.
func (a *BasicActor) Handle() Handle { return a.self }

// Security returns the actor's security context.
func (a *BasicActor) Security() SecurityContext { return a.security }

// AddBehavior registers continuation under the given behavior name in the
// actor's dispatch.
func (a *BasicActor) AddBehavior(behavior string, continuation Behavior, opt ...BehaviorOption) error {
	d := a.dispatch
	return d.AddBehavior(behavior, continuation, opt...)
}

// RemoveBehavior unregisters the named behavior from the actor's dispatch.
func (a *BasicActor) RemoveBehavior(behavior string) {
	d := a.dispatch
	d.RemoveBehavior(behavior)
}
// Receive queues msg on the local dispatch when it is addressed to this
// actor or is a broadcast. Any other message is rejected with an error
// wrapping ErrInvalidMessage.
func (a *BasicActor) Receive(msg Envelope) error {
	accepted := a.self.ID.Equal(msg.To.ID) || msg.IsBroadcast()
	if !accepted {
		return fmt.Errorf("bad receiver: %w", ErrInvalidMessage)
	}
	return a.dispatch.Receive(msg)
}
// Send delivers msg to its destination actor.
//
// Messages addressed to this actor are short-circuited into Receive.
// Unsigned messages are prepared before sending:
//   - they get a nonce if they have none;
//   - they get capability tokens to invoke msg.Behavior, plus delegation of
//     the ReplyTo behavior when one is set;
//   - they are signed via the security context.
//
// The envelope is then JSON-encoded and sent to msg.To.Address.HostID on the
// destination inbox's actor protocol.
func (a *BasicActor) Send(msg Envelope) error {
	if msg.To.ID.Equal(a.self.ID) {
		return a.Receive(msg)
	}

	if msg.Signature == nil {
		if msg.Nonce == 0 {
			msg.Nonce = a.security.Nonce()
		}
		var delegate []Capability
		if replyTo := msg.Options.ReplyTo; replyTo != "" {
			delegate = []Capability{Capability(replyTo)}
		}
		invoke := []Capability{Capability(msg.Behavior)}
		if err := a.security.Provide(&msg, invoke, delegate); err != nil {
			return fmt.Errorf("providing behavior capability for %s: %w", msg.Behavior, err)
		}
	}

	payload, err := json.Marshal(msg)
	if err != nil {
		return fmt.Errorf("marshaling message: %w", err)
	}

	envelope := types.MessageEnvelope{
		Type: types.MessageType(fmt.Sprintf("actor/%s/messages/0.0.1", msg.To.Address.InboxAddress)),
		Data: payload,
	}
	if err := a.network.SendMessage(a.Context(), msg.To.Address.HostID, envelope, msg.Expiry()); err != nil {
		return fmt.Errorf("sending message to %s: %w", msg.To.ID, err)
	}
	return nil
}
// Invoke sends msg and returns a channel on which the single reply arrives.
//
// If msg has no ReplyTo behavior, one is generated from a fresh nonce. A
// one-shot reply behavior that expires with the message is registered
// before sending and removed again if the send fails. The returned channel
// is closed right after the reply is delivered.
func (a *BasicActor) Invoke(msg Envelope) (<-chan Envelope, error) {
	if msg.Options.ReplyTo == "" {
		msg.Options.ReplyTo = fmt.Sprintf("/dms/actor/replyto/%d", a.security.Nonce())
	}
	replyTo := msg.Options.ReplyTo

	replies := make(chan Envelope, 1)
	deliver := func(reply Envelope) {
		replies <- reply
		close(replies)
	}
	err := a.dispatch.AddBehavior(replyTo, deliver,
		WithBehaviorExpiry(msg.Options.Expire),
		WithBehaviorOneShot(true),
	)
	if err != nil {
		return nil, fmt.Errorf("adding reply behavior: %w", err)
	}

	if err := a.Send(msg); err != nil {
		a.dispatch.RemoveBehavior(replyTo)
		return nil, fmt.Errorf("sending message: %w", err)
	}
	return replies, nil
}
// Publish broadcasts msg on msg.Options.Topic. A non-broadcast message is
// rejected with ErrInvalidMessage. An unsigned message first gets a nonce
// (if it has none) and broadcast capability tokens for msg.Behavior; this
// step also signs it.
func (a *BasicActor) Publish(msg Envelope) error {
	if !msg.IsBroadcast() {
		return ErrInvalidMessage
	}

	if msg.Signature == nil {
		if msg.Nonce == 0 {
			msg.Nonce = a.security.Nonce()
		}
		caps := []Capability{Capability(msg.Behavior)}
		if err := a.security.ProvideBroadcast(&msg, msg.Options.Topic, caps); err != nil {
			return fmt.Errorf("providing behavior capability for %s: %w", msg.Behavior, err)
		}
	}

	payload, err := json.Marshal(msg)
	if err != nil {
		return fmt.Errorf("marshaling message: %w", err)
	}
	if err = a.network.Publish(a.Context(), msg.Options.Topic, payload); err != nil {
		return fmt.Errorf("publishing message: %w", err)
	}
	return nil
}
// Subscribe joins the pubsub topic. Accepted broadcasts go to
// handleBroadcast and are validated by validateBroadcast. Subscribing to a
// topic that is already subscribed does nothing.
//
// Each setup function is called with the topic after subscribing. If one
// fails, the subscription is torn down and its error returned. a.mx is held
// for the whole call.
func (a *BasicActor) Subscribe(topic string, setup ...BroadcastSetup) error {
	a.mx.Lock()
	defer a.mx.Unlock()

	if _, exists := a.subscriptions[topic]; exists {
		return nil
	}

	validator := func(data []byte, validatorData interface{}) (network.ValidationResult, interface{}) {
		return a.validateBroadcast(topic, data, validatorData)
	}
	subID, err := a.network.Subscribe(a.Context(), topic, a.handleBroadcast, validator)
	if err != nil {
		return fmt.Errorf("subscribe: %w", err)
	}

	for _, fn := range setup {
		if err := fn(topic); err != nil {
			// best-effort rollback; the setup error is what gets reported
			_ = a.network.Unsubscribe(topic, subID)
			return fmt.Errorf("setup broadcast topic: %w", err)
		}
	}

	a.subscriptions[topic] = subID
	return nil
}
// validateBroadcast is the pubsub validator for topic.
//
// Non-nil validatorData must be an Envelope that was already accepted; it is
// accepted again as-is, and any other value is rejected. Otherwise data is
// decoded and the message is checked:
//   - it must be a broadcast on this topic with a verifiable signature,
//     or it is rejected;
//   - expired or rate-limited messages are ignored rather than rejected.
//
// On acceptance the decoded Envelope is returned as the validator data.
func (a *BasicActor) validateBroadcast(topic string, data []byte, validatorData interface{}) (network.ValidationResult, interface{}) {
	if validatorData != nil {
		if _, isEnvelope := validatorData.(Envelope); isEnvelope {
			// already validated; short-circuit
			return network.ValidationAccept, validatorData
		}
		log.Warnf("bogus pubsub validation data: %v", validatorData)
		return network.ValidationReject, nil
	}

	var msg Envelope
	if err := json.Unmarshal(data, &msg); err != nil {
		return network.ValidationReject, nil
	}

	switch {
	case !msg.IsBroadcast(), msg.Options.Topic != topic:
		return network.ValidationReject, nil
	case msg.Expired():
		return network.ValidationIgnore, nil
	case a.security.Verify(msg) != nil:
		return network.ValidationReject, nil
	case !a.limiter.Allow(msg):
		log.Warnf("incoming broadcast message in %s not allowed by limiter", topic)
		return network.ValidationIgnore, nil
	}
	return network.ValidationAccept, msg
}
// handleBroadcast is the pubsub message handler. It decodes the broadcast
// envelope and passes it to Receive, logging decode and receive failures.
func (a *BasicActor) handleBroadcast(data []byte) {
	var msg Envelope
	err := json.Unmarshal(data, &msg)
	if err != nil {
		log.Debugf("error unmarshaling broadcast message: %s", err)
		return
	}
	if err = a.Receive(msg); err != nil {
		log.Warnf("error receiving broadcast message: %s", err)
	}
}
// Stop cancels the dispatch context, which stops its goroutines. It then
// unsubscribes from every topic joined through Subscribe. Unsubscribe
// failures are logged and otherwise ignored; Stop always returns nil.
func (a *BasicActor) Stop() error {
	a.dispatch.close()

	// subscriptions is written by Subscribe while holding mx, so it must not
	// be iterated unlocked.
	a.mx.Lock()
	defer a.mx.Unlock()
	for topic, subID := range a.subscriptions {
		if err := a.network.Unsubscribe(topic, subID); err != nil {
			log.Debugf("error unsubscribing from %s: %s", topic, err)
		}
		// Forget the entry so the map reflects reality: a repeated Stop does
		// not unsubscribe twice, and a later Subscribe is not a silent no-op.
		delete(a.subscriptions, topic)
	}
	return nil
}
// Limiter returns the rate limiter stored on the actor.
func (a *BasicActor) Limiter() RateLimiter {
	limiter := a.limiter
	return limiter
}
package actor
import (
"context"
"fmt"
"sync"
"time"
)
// Defaults applied by NewDispatch unless a DispatchOption overrides them.
var (
	// DefaultDispatchGCInterval is the period of the expired-behavior sweep.
	DefaultDispatchGCInterval = 120 * time.Second
	// DefaultDispatchWorkers is the number of recv (verification) goroutines.
	DefaultDispatchWorkers = 1
)
// Dispatch provides a reaction kernel with multithreaded dispatch and oneshot
// continuations.
//
// Envelopes flow through three stages:
//   - Receive puts them on q;
//   - recv workers verify signatures and forward them to vq;
//   - a single dispatch goroutine looks up the behavior, checks
//     capabilities and runs the continuation in its own goroutine.
type Dispatch struct {
	ctx       context.Context
	close     func() // cancels ctx
	sctx      SecurityContext
	mx        sync.Mutex    // guards behaviors and started
	q         chan Envelope // incoming message queue
	vq        chan Envelope // verified message queue
	behaviors map[string]*BehaviorState
	started   bool
	options   DispatchOptions
}
// DispatchOptions configures a Dispatch.
type DispatchOptions struct {
	Limiter    RateLimiter   // acquired before a continuation runs, released after
	GCInterval time.Duration // period of the expired-behavior sweep
	Workers    int           // number of recv goroutines launched by Start
}

// BehaviorState is a registered continuation together with its options.
type BehaviorState struct {
	cont Behavior
	opt  BehaviorOptions
}

// DispatchOption mutates DispatchOptions; options are applied by NewDispatch.
type DispatchOption func(o *DispatchOptions)
// WithDispatchWorkers sets how many verification workers Start launches.
func WithDispatchWorkers(count int) DispatchOption {
	return func(opts *DispatchOptions) { opts.Workers = count }
}

// WithDispatchGCInterval sets the period of the expired-behavior sweep.
func WithDispatchGCInterval(dt time.Duration) DispatchOption {
	return func(opts *DispatchOptions) { opts.GCInterval = dt }
}

// WithRateLimiter sets the limiter the dispatch loop consults before running
// a continuation.
func WithRateLimiter(limiter RateLimiter) DispatchOption {
	return func(opts *DispatchOptions) { opts.Limiter = limiter }
}
// NewDispatch creates an unstarted Dispatch that verifies messages with sctx.
// The defaults are DefaultDispatchGCInterval, DefaultDispatchWorkers and
// NoRateLimiter; opt is then applied in order on top of them. The
// dispatch's context is a fresh cancellable child of context.Background().
func NewDispatch(sctx SecurityContext, opt ...DispatchOption) *Dispatch {
	ctx, cancel := context.WithCancel(context.Background())

	options := DispatchOptions{
		GCInterval: DefaultDispatchGCInterval,
		Workers:    DefaultDispatchWorkers,
		Limiter:    NoRateLimiter{},
	}
	for _, apply := range opt {
		apply(&options)
	}

	return &Dispatch{
		ctx:       ctx,
		close:     cancel,
		sctx:      sctx,
		q:         make(chan Envelope),
		vq:        make(chan Envelope),
		behaviors: make(map[string]*BehaviorState),
		options:   options,
	}
}
// Start launches the recv workers, the dispatch loop and the GC loop. It is
// idempotent: only the first call starts anything.
func (k *Dispatch) Start() {
	k.mx.Lock()
	defer k.mx.Unlock()

	if k.started {
		return
	}
	for n := k.options.Workers; n > 0; n-- {
		go k.recv()
	}
	go k.dispatch()
	go k.gc()
	k.started = true
}
// AddBehavior registers continuation under behavior, replacing any previous
// registration for that name. By default the behavior requires the
// capability named after itself; opt may change that and set expiry,
// one-shot or topic. The first failing option aborts registration.
func (k *Dispatch) AddBehavior(behavior string, continuation Behavior, opt ...BehaviorOption) error {
	options := BehaviorOptions{Capability: []Capability{Capability(behavior)}}
	for _, apply := range opt {
		if err := apply(&options); err != nil {
			return fmt.Errorf("adding behavior: %w", err)
		}
	}
	state := &BehaviorState{cont: continuation, opt: options}

	k.mx.Lock()
	k.behaviors[behavior] = state
	k.mx.Unlock()
	return nil
}
// RemoveBehavior unregisters behavior; unknown names are ignored.
func (k *Dispatch) RemoveBehavior(behavior string) {
	k.mx.Lock()
	delete(k.behaviors, behavior)
	k.mx.Unlock()
}

// Receive hands msg to the verification stage. It blocks until a recv
// worker takes it, or returns the context error once the dispatch context
// is cancelled.
func (k *Dispatch) Receive(msg Envelope) error {
	done := k.ctx.Done()
	select {
	case <-done:
		return k.ctx.Err()
	case k.q <- msg:
		return nil
	}
}

// Context returns the dispatch context; it is cancelled by close.
func (k *Dispatch) Context() context.Context { return k.ctx }
// recv is a verification worker. It takes envelopes from q, drops (and
// logs) those that fail signature verification, and forwards the rest to
// vq. It returns when the dispatch context is cancelled.
func (k *Dispatch) recv() {
	for {
		select {
		case msg := <-k.q:
			if err := k.sctx.Verify(msg); err != nil {
				log.Debugf("failed to verify message from %s: %s", msg.From, err)
				continue
			}
			// vq is unbuffered and the dispatch goroutine stops reading it
			// once ctx is cancelled, so an unconditional send would block
			// forever and leak this worker.
			select {
			case k.vq <- msg:
			case <-k.ctx.Done():
				return
			}
		case <-k.ctx.Done():
			return
		}
	}
}
// dispatch is the dispatch loop; Start launches exactly one. For each
// verified envelope read from vq it:
//   - looks up the target behavior under mx and drops the message if the
//     behavior is unknown or expired (expired behaviors are deleted);
//   - checks the sender's capabilities: RequireBroadcast against the
//     behavior's topic for broadcasts, Require otherwise;
//   - removes a one-shot behavior, still under mx;
//   - acquires a rate-limiter slot, discarding the message's capability
//     tokens via the security context if the limiter refuses;
//   - binds msg.Discard and runs the continuation in a new goroutine that
//     releases the limiter slot when the continuation returns.
//
// The loop returns when the dispatch context is cancelled.
func (k *Dispatch) dispatch() {
	for {
		select {
		case msg := <-k.vq:
			k.mx.Lock()
			b, ok := k.behaviors[msg.Behavior]
			if !ok {
				k.mx.Unlock()
				log.Debugf("unknown behavior %s", msg.Behavior)
				continue
			}
			if b.Expired(time.Now()) {
				delete(k.behaviors, msg.Behavior)
				k.mx.Unlock()
				log.Debugf("expired behavior %s", msg.Behavior)
				continue
			}
			// mx is still held: the capability check and the one-shot removal
			// below share the critical section with the lookup above.
			if msg.IsBroadcast() {
				if err := k.sctx.RequireBroadcast(msg, b.opt.Topic, b.opt.Capability); err != nil {
					k.mx.Unlock()
					log.Warnf("broadcast message from %s does not have the required capability %s %s: %s", msg.From, b.opt.Capability, string(msg.Capability), err)
					continue
				}
			} else if err := k.sctx.Require(msg, b.opt.Capability); err != nil {
				k.mx.Unlock()
				log.Warnf("message from %s does not have the required capability %s %s: %s", msg.From, b.opt.Capability, string(msg.Capability), err)
				continue
			}
			if b.opt.OneShot {
				delete(k.behaviors, msg.Behavior)
			}
			k.mx.Unlock()
			// The capabilities were accepted above; if the limiter refuses,
			// hand the message's tokens back to the security context.
			if err := k.options.Limiter.Acquire(msg); err != nil {
				k.sctx.Discard(msg)
				log.Debugf("limiter rejected message from %s: %s", msg.From, err)
				continue
			}
			// Allow the continuation to discard the message's tokens itself.
			msg.Discard = func() {
				k.sctx.Discard(msg)
			}
			log.Debugf("dispatching message from %s to %s", msg.From, msg.Behavior)
			go func() {
				defer k.options.Limiter.Release(msg)
				b.cont(msg)
			}()
		case <-k.ctx.Done():
			return
		}
	}
}
// gc periodically deletes expired behaviors until the dispatch context is
// cancelled. WithDispatchGCInterval accepts any duration, but time.NewTicker
// panics on non-positive ones. A non-positive GCInterval therefore falls
// back to DefaultDispatchGCInterval.
func (k *Dispatch) gc() {
	interval := k.options.GCInterval
	if interval <= 0 {
		interval = DefaultDispatchGCInterval
	}
	ticker := time.NewTicker(interval)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
			k.mx.Lock()
			now := time.Now()
			for name, b := range k.behaviors {
				if b.Expired(now) {
					delete(k.behaviors, name)
				}
			}
			k.mx.Unlock()
		case <-k.ctx.Done():
			return
		}
	}
}
// Expired reports whether the behavior has an expiry, given in Unix
// nanoseconds with 0 meaning none, that now has already passed.
func (b *BehaviorState) Expired(now time.Time) bool {
	deadline := b.opt.Expire
	return deadline > 0 && uint64(now.UnixNano()) > deadline
}
// WithBehaviorExpiry sets the behavior's expiry in Unix nanoseconds; 0 means
// it never expires.
func WithBehaviorExpiry(expire uint64) BehaviorOption {
	return func(o *BehaviorOptions) error {
		o.Expire = expire
		return nil
	}
}

// WithBehaviorCapability replaces the capabilities a sender must hold to
// invoke the behavior.
func WithBehaviorCapability(require ...Capability) BehaviorOption {
	return func(o *BehaviorOptions) error {
		o.Capability = require
		return nil
	}
}

// WithBehaviorOneShot makes the behavior unregister itself after it is
// dispatched once.
func WithBehaviorOneShot(oneShot bool) BehaviorOption {
	return func(o *BehaviorOptions) error {
		o.OneShot = oneShot
		return nil
	}
}

// WithBehaviorTopic sets the broadcast topic checked for broadcast
// invocations of the behavior.
func WithBehaviorTopic(topic string) BehaviorOption {
	return func(o *BehaviorOptions) error {
		o.Topic = topic
		return nil
	}
}
package actor
import (
"fmt"
"gitlab.com/nunet/device-management-service/lib/did"
)
// Empty reports whether the handle has no ID, no DID and an empty address.
func (h *Handle) Empty() bool {
	if !h.ID.Empty() {
		return false
	}
	if !h.DID.Empty() {
		return false
	}
	return h.Address.Empty()
}
// String renders the handle as "<id>[<did>]@<host>:<inbox>". <id> is the ID
// converted to a DID string, or empty when that conversion fails.
func (h *Handle) String() string {
	var idStr string
	if idDID, err := did.FromID(h.ID); err == nil {
		idStr = idDID.String()
	}
	// Address.String has a pointer receiver. A plain h.Address value does not
	// satisfy fmt.Stringer and would print as a raw struct ("{host inbox}"),
	// so pass its address instead.
	// NOTE(review): h.DID is still passed by value; confirm did.DID's String
	// method has a value receiver.
	return fmt.Sprintf("%s[%s]@%s", idStr, h.DID, &h.Address)
}
// HandleFromString is not implemented yet and always returns ErrTODO.
func HandleFromString(_ string) (Handle, error) {
	// TODO: presumably should parse the format produced by Handle.String.
	var h Handle
	return h, ErrTODO
}

// Empty reports whether both the host ID and the inbox address are unset.
func (a *Address) Empty() bool {
	return !(a.HostID != "" || a.InboxAddress != "")
}

// String renders the address as "<host>:<inbox>".
func (a *Address) String() string {
	return fmt.Sprintf("%s:%s", a.HostID, a.InboxAddress)
}

// AddressFromString is not implemented yet and always returns ErrTODO.
func AddressFromString(_ string) (Address, error) {
	// TODO: presumably should parse the format produced by Address.String.
	var addr Address
	return addr, ErrTODO
}
package actor
import (
"strings"
"sync"
)
// NoRateLimiter is the null limiter; it admits every message and keeps no
// state.
type NoRateLimiter struct{}

var _ RateLimiter = NoRateLimiter{}

// BasicRateLimiter is a counting RateLimiter.
//   - Allow compares the current counters with the *LimitAllow limits
//     without changing them.
//   - Acquire reserves a slot against the *LimitAcquire limits; Release
//     frees it.
//   - Broadcasts are counted both globally and per topic; behaviors with the
//     "/public/" prefix share one counter; all other messages are unlimited.
type BasicRateLimiter struct {
	cfg             RateLimiterConfig
	mx              sync.Mutex // guards all fields
	activeBroadcast int
	activeTopics    map[string]int // in-flight broadcasts per topic; entries that reach zero are deleted
	activePublic    int
}

var _ RateLimiter = (*BasicRateLimiter)(nil)
// NoRateLimiter method set: admit everything, track nothing.

// Allow always admits the message.
func (NoRateLimiter) Allow(Envelope) bool { return true }

// Acquire always succeeds.
func (NoRateLimiter) Acquire(Envelope) error { return nil }

// Release does nothing.
func (NoRateLimiter) Release(Envelope) {}

// Config returns the zero configuration.
func (NoRateLimiter) Config() RateLimiterConfig {
	var zero RateLimiterConfig
	return zero
}

// SetConfig ignores the configuration.
func (NoRateLimiter) SetConfig(RateLimiterConfig) {}
// DefaultRateLimiterConfig returns the stock limits. The limits are:
//   - 4096 public messages for Allow, 4112 for Acquire;
//   - 1024 broadcasts for Allow, 1040 for Acquire;
//   - 128 per topic.
//
// No per-topic overrides are set.
func DefaultRateLimiterConfig() RateLimiterConfig {
	var cfg RateLimiterConfig
	cfg.PublicLimitAllow = 4096
	cfg.PublicLimitAcquire = 4112
	cfg.BroadcastLimitAllow = 1024
	cfg.BroadcastLimitAcquire = 1040
	cfg.TopicDefaultLimit = 128
	return cfg
}
// Valid reports whether every global limit is positive and each Acquire
// limit is at least its Allow limit. Per-topic overrides are not checked.
func (cfg *RateLimiterConfig) Valid() bool {
	switch {
	case cfg.PublicLimitAllow <= 0, cfg.PublicLimitAcquire < cfg.PublicLimitAllow:
		return false
	case cfg.BroadcastLimitAllow <= 0, cfg.BroadcastLimitAcquire < cfg.BroadcastLimitAllow:
		return false
	}
	return cfg.TopicDefaultLimit > 0
}
// NewRateLimiter returns a BasicRateLimiter enforcing cfg. cfg is used as
// given, without validation.
func NewRateLimiter(cfg RateLimiterConfig) RateLimiter {
	limiter := &BasicRateLimiter{cfg: cfg}
	limiter.activeTopics = make(map[string]int)
	return limiter
}
// Allow is a non-reserving admission check. Broadcasts are checked against
// the global and per-topic broadcast limits, and "/public/" behaviors
// against the public limit. Everything else is admitted.
func (l *BasicRateLimiter) Allow(msg Envelope) bool {
	switch {
	case msg.IsBroadcast():
		return l.allowBroadcast(msg)
	case isPublicBehavior(msg):
		return l.allowPublic()
	default:
		return true
	}
}

// allowPublic reports whether the public counter is below PublicLimitAllow.
func (l *BasicRateLimiter) allowPublic() bool {
	l.mx.Lock()
	defer l.mx.Unlock()
	return l.activePublic < l.cfg.PublicLimitAllow
}

// allowBroadcast reports whether both the global broadcast counter and the
// counter for msg's topic are below their Allow limits. The topic limit is
// the TopicLimit override or else TopicDefaultLimit.
func (l *BasicRateLimiter) allowBroadcast(msg Envelope) bool {
	l.mx.Lock()
	defer l.mx.Unlock()

	if l.activeBroadcast >= l.cfg.BroadcastLimitAllow {
		return false
	}
	topic := msg.Options.Topic
	limit := l.cfg.TopicDefaultLimit
	if perTopic, ok := l.cfg.TopicLimit[topic]; ok {
		limit = perTopic
	}
	return l.activeTopics[topic] < limit
}
// Acquire reserves a slot for msg. Broadcasts reserve against the global and
// per-topic limits, and "/public/" behaviors against PublicLimitAcquire.
// Other messages are not counted. ErrRateLimitExceeded is returned when no
// slot is available.
func (l *BasicRateLimiter) Acquire(msg Envelope) error {
	switch {
	case msg.IsBroadcast():
		return l.acquireBroadcast(msg)
	case isPublicBehavior(msg):
		return l.acquirePublic()
	default:
		return nil
	}
}

// acquirePublic takes a public slot if the counter is below
// PublicLimitAcquire.
func (l *BasicRateLimiter) acquirePublic() error {
	l.mx.Lock()
	defer l.mx.Unlock()

	if l.activePublic >= l.cfg.PublicLimitAcquire {
		return ErrRateLimitExceeded
	}
	l.activePublic++
	return nil
}

// acquireBroadcast takes a broadcast slot if it fits under both the global
// BroadcastLimitAcquire and the limit for msg's topic.
func (l *BasicRateLimiter) acquireBroadcast(msg Envelope) error {
	l.mx.Lock()
	defer l.mx.Unlock()

	if l.activeBroadcast >= l.cfg.BroadcastLimitAcquire {
		return ErrRateLimitExceeded
	}
	topic := msg.Options.Topic
	limit := l.cfg.TopicDefaultLimit
	if perTopic, ok := l.cfg.TopicLimit[topic]; ok {
		limit = perTopic
	}
	if l.activeTopics[topic] >= limit {
		return ErrRateLimitExceeded
	}
	l.activeTopics[topic]++
	l.activeBroadcast++
	return nil
}
// Release frees the slot taken by a successful Acquire for msg.
func (l *BasicRateLimiter) Release(msg Envelope) {
	switch {
	case msg.IsBroadcast():
		l.releaseBroadcast(msg)
	case isPublicBehavior(msg):
		l.releasePublic()
	}
}
// releasePublic frees a public slot taken by acquirePublic. The counter is
// never driven below zero, so an unmatched Release cannot raise the
// effective public limit. releaseBroadcast guards its counters the same way
// by ignoring untracked topics.
func (l *BasicRateLimiter) releasePublic() {
	l.mx.Lock()
	defer l.mx.Unlock()
	if l.activePublic > 0 {
		l.activePublic--
	}
}
// releaseBroadcast frees a broadcast slot for msg's topic. A topic with no
// tracked slots is ignored, and a topic's entry is deleted once its count
// reaches zero.
func (l *BasicRateLimiter) releaseBroadcast(msg Envelope) {
	l.mx.Lock()
	defer l.mx.Unlock()

	topic := msg.Options.Topic
	active, tracked := l.activeTopics[topic]
	if !tracked {
		return
	}
	if remaining := active - 1; remaining <= 0 {
		delete(l.activeTopics, topic)
	} else {
		l.activeTopics[topic] = remaining
	}
	l.activeBroadcast--
}
// Config returns the current configuration by value; the TopicLimit map,
// if set, is shared with the limiter.
func (l *BasicRateLimiter) Config() RateLimiterConfig {
	l.mx.Lock()
	cfg := l.cfg
	l.mx.Unlock()
	return cfg
}

// SetConfig replaces the configuration; in-flight counters are kept.
func (l *BasicRateLimiter) SetConfig(cfg RateLimiterConfig) {
	l.mx.Lock()
	l.cfg = cfg
	l.mx.Unlock()
}

// isPublicBehavior reports whether msg targets a "/public/" behavior.
func isPublicBehavior(msg Envelope) bool {
	const prefix = "/public/"
	return strings.HasPrefix(msg.Behavior, prefix)
}
package actor
import (
"encoding/json"
"fmt"
"time"
)
const (
	// heartbeatBehavior names the actor heartbeat behavior (not referenced in
	// the code shown here).
	heartbeatBehavior = "/dms/actor/heartbeat"
	// defaultMessageTimeout is the expiry Message applies unless an option
	// overrides it.
	defaultMessageTimeout = 30 * time.Second
)

// signaturePrefix is prepended to the JSON encoding of an envelope to form
// the bytes that are signed and verified (see SignatureData).
var signaturePrefix = []byte("dms:msg:")

// HeartbeatMessage is an empty payload type, presumably sent with
// heartbeatBehavior.
type HeartbeatMessage struct{}
// Message builds an envelope from src to dest that invokes behavior, with
// payload JSON-encoded as the body. The envelope expires
// defaultMessageTimeout from now unless an option changes that. Options are
// applied in order and the first failing option aborts construction.
func Message(src Handle, dest Handle, behavior string, payload interface{}, opt ...MessageOption) (Envelope, error) {
	body, err := json.Marshal(payload)
	if err != nil {
		return Envelope{}, fmt.Errorf("marshaling payload: %w", err)
	}

	var msg Envelope
	msg.From = src
	msg.To = dest
	msg.Behavior = behavior
	msg.Message = body
	msg.Options = EnvelopeOptions{
		Expire: uint64(time.Now().Add(defaultMessageTimeout).UnixNano()),
	}
	msg.Discard = func() {}

	for _, apply := range opt {
		if err := apply(&msg); err != nil {
			return Envelope{}, fmt.Errorf("setting message option: %w", err)
		}
	}
	return msg, nil
}
// ReplyTo builds the response to msg. The response:
//   - travels from msg.To back to msg.From;
//   - invokes msg's ReplyTo behavior;
//   - inherits msg's expiry, though opt is applied afterwards and may
//     override it.
//
// It fails with ErrInvalidMessage when msg has no ReplyTo behavior.
func ReplyTo(msg Envelope, payload interface{}, opt ...MessageOption) (Envelope, error) {
	replyTo := msg.Options.ReplyTo
	if replyTo == "" {
		return Envelope{}, fmt.Errorf("no behavior to reply to: %w", ErrInvalidMessage)
	}
	options := append([]MessageOption{WithMessageExpiry(msg.Options.Expire)}, opt...)
	return Message(msg.To, msg.From, replyTo, payload, options...)
}
// WithMessageSignature assigns a fresh nonce from sctx, attaches capability
// tokens and signs the envelope:
//   - broadcast messages get broadcast tokens for invoke on their topic, and
//     delegate is ignored;
//   - direct messages get invoke and delegate tokens.
//
// It fails with ErrInvalidSecurityContext when the message's sender is not
// sctx's identity.
//
// NOTE: This option must be passed last; any later modification of the
// envelope invalidates the signature.
//
// NOTE: Send and Publish sign unsigned messages implicitly.
func WithMessageSignature(sctx SecurityContext, invoke []Capability, delegate []Capability) MessageOption {
	return func(msg *Envelope) error {
		if !msg.From.ID.Equal(sctx.ID()) {
			return ErrInvalidSecurityContext
		}
		msg.Nonce = sctx.Nonce()
		if msg.IsBroadcast() {
			return sctx.ProvideBroadcast(msg, msg.Options.Topic, invoke)
		}
		return sctx.Provide(msg, invoke, delegate)
	}
}
// WithMessageTimeout sets the message expiry to timeo from now.
//
// NOTE: envelopes built with Message already expire defaultMessageTimeout
// (30s) after creation unless overridden.
func WithMessageTimeout(timeo time.Duration) MessageOption {
	return func(msg *Envelope) error {
		deadline := time.Now().Add(timeo)
		msg.Options.Expire = uint64(deadline.UnixNano())
		return nil
	}
}

// WithMessageExpiry sets the absolute message expiry in Unix nanoseconds.
//
// NOTE: envelopes built with Message already expire defaultMessageTimeout
// (30s) after creation unless overridden.
func WithMessageExpiry(expiry uint64) MessageOption {
	return func(msg *Envelope) error {
		msg.Options.Expire = expiry
		return nil
	}
}

// WithMessageReplyTo sets the behavior the recipient should reply to.
//
// NOTE: Invoke sets ReplyTo implicitly, and Send delegates the matching
// capability token.
func WithMessageReplyTo(replyto string) MessageOption {
	return func(msg *Envelope) error {
		msg.Options.ReplyTo = replyto
		return nil
	}
}

// WithMessageTopic sets the broadcast topic.
func WithMessageTopic(topic string) MessageOption {
	return func(msg *Envelope) error {
		msg.Options.Topic = topic
		return nil
	}
}

// WithMessageSource overrides the message sender.
func WithMessageSource(source Handle) MessageOption {
	return func(msg *Envelope) error {
		msg.From = source
		return nil
	}
}
// SignatureData returns the bytes covered by the envelope's signature:
// signaturePrefix followed by the JSON encoding of a copy of the envelope
// with Signature cleared. The receiver itself is not modified.
func (msg *Envelope) SignatureData() ([]byte, error) {
	unsigned := *msg
	unsigned.Signature = nil
	encoded, err := json.Marshal(&unsigned)
	if err != nil {
		return nil, fmt.Errorf("signature data: %w", err)
	}

	out := make([]byte, 0, len(signaturePrefix)+len(encoded))
	out = append(out, signaturePrefix...)
	out = append(out, encoded...)
	return out, nil
}
// Expired reports whether the current time in Unix nanoseconds is past
// Options.Expire; an Expire of 0 therefore counts as expired.
func (msg *Envelope) Expired() bool {
	now := uint64(time.Now().UnixNano())
	return now > msg.Options.Expire
}

// Expiry converts Options.Expire (Unix nanoseconds) into a time.Time.
func (msg *Envelope) Expiry() time.Time {
	const nsPerSec = uint64(time.Second)
	expire := msg.Options.Expire
	return time.Unix(int64(expire/nsPerSec), int64(expire%nsPerSec)) //nolint
}

// IsBroadcast reports whether the envelope has a topic and no destination.
func (msg *Envelope) IsBroadcast() bool {
	if msg.Options.Topic == "" {
		return false
	}
	return msg.To.Empty()
}
package actor
import (
"errors"
"sync"
)
// Info is the registry record for one actor.
type Info struct {
	Addr     *Handle  // the actor's own handle
	Parent   *Handle  // the actor's parent handle
	Children []Handle // set to an empty (non-nil) slice by registry.Add when none are given
}

// Registry tracks actors and their parent/child relationships.
type Registry interface {
	Actors() map[string]Info
	Add(a Handle, parent Handle, children []Handle) error
	Get(a Handle) (Info, bool)
	SetParent(a Handle, parent Handle) error
	GetParent(a Handle) (*Handle, bool)
}

// registry is an in-memory, mutex-guarded set of actor records keyed by the
// actor's inbox address.
type registry struct {
	mx     sync.Mutex
	actors map[string]Info
}
func newRegistry() *registry {
return ®istry{
actors: make(map[string]Info),
}
}
// Actors returns a snapshot of the registry keyed by inbox address. The Info
// values are shallow copies, so their handle pointers and Children slices
// are shared with the registry.
func (r *registry) Actors() map[string]Info {
	r.mx.Lock()
	defer r.mx.Unlock()

	snapshot := make(map[string]Info, len(r.actors))
	for inbox, info := range r.actors {
		snapshot[inbox] = info
	}
	return snapshot
}
// Add records actor a, its parent and its children under a's inbox address.
// It fails if that address is already registered. A nil children slice is
// stored as an empty slice.
func (r *registry) Add(a Handle, parent Handle, children []Handle) error {
	key := a.Address.InboxAddress

	r.mx.Lock()
	defer r.mx.Unlock()

	if _, exists := r.actors[key]; exists {
		return errors.New("actor already exists")
	}
	if children == nil {
		children = []Handle{}
	}
	r.actors[key] = Info{Addr: &a, Parent: &parent, Children: children}
	return nil
}
// Get returns the record registered under a's inbox address, if any.
func (r *registry) Get(a Handle) (Info, bool) {
	key := a.Address.InboxAddress
	r.mx.Lock()
	defer r.mx.Unlock()
	info, found := r.actors[key]
	return info, found
}

// SetParent replaces the parent in the record registered under a's inbox
// address. It fails if no such record exists.
func (r *registry) SetParent(a Handle, parent Handle) error {
	key := a.Address.InboxAddress
	r.mx.Lock()
	defer r.mx.Unlock()

	info, found := r.actors[key]
	if !found {
		return errors.New("actor not found")
	}
	info.Parent = &parent
	r.actors[key] = info
	return nil
}

// GetParent returns the parent recorded under a's inbox address, if any.
func (r *registry) GetParent(a Handle) (*Handle, bool) {
	key := a.Address.InboxAddress
	r.mx.Lock()
	defer r.mx.Unlock()

	if info, found := r.actors[key]; found {
		return info.Parent, true
	}
	return nil, false
}
package actor
import (
"fmt"
"sync"
"time"
"gitlab.com/nunet/device-management-service/lib/crypto"
"gitlab.com/nunet/device-management-service/lib/ucan"
)
// BasicSecurityContext is a SecurityContext that signs with a private key
// and delegates capability-token handling to a UCAN capability context.
type BasicSecurityContext struct {
	id    ID // derived from the public key at construction
	privk crypto.PrivKey
	cap   ucan.CapabilityContext
	mx    sync.Mutex // guards nonce
	nonce uint64     // next value returned by Nonce
}

// Compile-time check that *BasicSecurityContext implements SecurityContext.
var _ SecurityContext = (*BasicSecurityContext)(nil)
// NewBasicSecurityContext builds a security context with these parts:
//   - its identity is derived from pubk;
//   - it signs with privk;
//   - it handles capabilities through cap.
//
// The nonce counter starts at the current time in Unix nanoseconds.
func NewBasicSecurityContext(pubk crypto.PubKey, privk crypto.PrivKey, cap ucan.CapabilityContext) (*BasicSecurityContext, error) {
	seed := uint64(time.Now().UnixNano())
	id, err := crypto.IDFromPublicKey(pubk)
	if err != nil {
		return nil, fmt.Errorf("creating security context: %w", err)
	}
	return &BasicSecurityContext{
		id:    id,
		privk: privk,
		cap:   cap,
		nonce: seed,
	}, nil
}
// ID returns the identity derived from the public key at construction.
func (s *BasicSecurityContext) ID() ID { return s.id }

// DID returns the DID reported by the capability context.
func (s *BasicSecurityContext) DID() DID { return s.cap.DID() }

// Nonce returns the current counter value and post-increments it under mx.
func (s *BasicSecurityContext) Nonce() uint64 {
	s.mx.Lock()
	n := s.nonce
	s.nonce = n + 1
	s.mx.Unlock()
	return n
}
// Require checks that the sender of msg holds one of the capabilities in
// cap for invoking this context.
//
// A message from this context to itself passes without further checks. For
// any other message, its capability tokens are consumed first, then the
// capability requirement is checked. If the requirement fails, the tokens
// are passed to Discard.
func (s *BasicSecurityContext) Require(msg Envelope, cap []Capability) error {
	selfToSelf := s.id.Equal(msg.From.ID) && s.id.Equal(msg.To.ID)
	if selfToSelf {
		return nil
	}

	if err := s.cap.Consume(msg.From.DID, msg.Capability); err != nil {
		return fmt.Errorf("consuming capabilities: %w", err)
	}
	err := s.cap.Require(s.DID(), msg.From.ID, s.id, cap)
	if err == nil {
		return nil
	}
	// presumably undoes the Consume above; confirm ucan Discard semantics
	s.cap.Discard(msg.Capability)
	return fmt.Errorf("requiring capabilities: %w", err)
}
// Provide attaches capability tokens for msg's recipient and then signs msg.
// The tokens cover invoke and delegate and expire with the message. A
// message from this context to itself is only signed.
func (s *BasicSecurityContext) Provide(msg *Envelope, invoke []Capability, delegate []Capability) error {
	toSelf := s.id.Equal(msg.From.ID) && s.id.Equal(msg.To.ID)
	if !toSelf {
		tokens, err := s.cap.Provide(msg.To.DID, s.id, msg.To.ID, msg.Options.Expire, invoke, delegate)
		if err != nil {
			return fmt.Errorf("providing capabilities: %w", err)
		}
		msg.Capability = tokens
	}
	return s.Sign(msg)
}
// RequireBroadcast checks that msg is a broadcast on topic and that its
// sender holds one of the broadcast capabilities for that topic. The
// message's tokens are consumed first and passed to Discard if the
// requirement fails.
func (s *BasicSecurityContext) RequireBroadcast(msg Envelope, topic string, broadcast []Capability) error {
	switch {
	case !msg.IsBroadcast():
		return fmt.Errorf("not a broadcast message: %w", ErrInvalidMessage)
	case topic != msg.Options.Topic:
		return fmt.Errorf("broadcast topic mismatch: %w", ErrInvalidMessage)
	}

	if err := s.cap.Consume(msg.From.DID, msg.Capability); err != nil {
		return fmt.Errorf("consuming capabilities: %w", err)
	}
	err := s.cap.RequireBroadcast(s.DID(), msg.From.ID, topic, broadcast)
	if err == nil {
		return nil
	}
	s.cap.Discard(msg.Capability)
	return fmt.Errorf("requiring capabilities: %w", err)
}
// ProvideBroadcast attaches broadcast capability tokens for topic and then
// signs msg. It fails with ErrInvalidMessage unless msg is a broadcast on
// that same topic.
func (s *BasicSecurityContext) ProvideBroadcast(msg *Envelope, topic string, broadcast []Capability) error {
	switch {
	case !msg.IsBroadcast():
		return fmt.Errorf("not a broadcast message: %w", ErrInvalidMessage)
	case topic != msg.Options.Topic:
		return fmt.Errorf("broadcast topic mismatch: %w", ErrInvalidMessage)
	}

	tokens, err := s.cap.ProvideBroadcast(msg.From.ID, topic, msg.Options.Expire, broadcast)
	if err != nil {
		return fmt.Errorf("providing capabilities: %w", err)
	}
	msg.Capability = tokens
	return s.Sign(msg)
}
// Verify checks that msg has not expired and that its signature over
// SignatureData() verifies under the public key recovered from msg.From.ID.
// Capability tokens are not examined here.
func (s *BasicSecurityContext) Verify(msg Envelope) error {
	if msg.Expired() {
		return ErrMessageExpired
	}

	pubk, err := crypto.PublicKeyFromID(msg.From.ID)
	if err != nil {
		return fmt.Errorf("public key from id: %w", err)
	}
	signed, err := msg.SignatureData()
	if err != nil {
		return fmt.Errorf("signature data: %w", err)
	}

	valid, err := pubk.Verify(signed, msg.Signature)
	switch {
	case err != nil:
		return fmt.Errorf("verify message signature: %w", err)
	case !valid:
		return ErrSignatureVerification
	default:
		return nil
	}
}
// Sign sets msg.Signature to this context's signature over
// msg.SignatureData(). It returns ErrBadSender if msg was not sent by this
// context's identity.
func (s *BasicSecurityContext) Sign(msg *Envelope) error {
	if !s.id.Equal(msg.From.ID) {
		return ErrBadSender
	}
	payload, err := msg.SignatureData()
	if err != nil {
		return fmt.Errorf("signature data: %w", err)
	}
	signature, err := s.privk.Sign(payload)
	if err != nil {
		return fmt.Errorf("signing message: %w", err)
	}
	msg.Signature = signature
	return nil
}

// Discard passes msg's capability tokens to the capability context's Discard.
func (s *BasicSecurityContext) Discard(msg Envelope) {
	tokens := msg.Capability
	s.cap.Discard(tokens)
}
package api
import (
"context"
"encoding/json"
"fmt"
"net/http"
"sync"
"time"
"github.com/gin-gonic/gin"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/lib/crypto"
"gitlab.com/nunet/device-management-service/lib/did"
"gitlab.com/nunet/device-management-service/network/libp2p"
"gitlab.com/nunet/device-management-service/types"
)
// ActorHandle godoc
//
//	@Summary		Retrieve actor handle
//	@Description	Retrieve actor handle with ID, DID, and inbox address
//	@Tags			actor
//	@Produce		json
//	@Success		200	{object}	actor.Handle
//	@Failure		500	{object}	object	"host node hasn't yet been initialized"
//	@Failure		500	{object}	object	"handle id is invalid"
//	@Router			/actor/handle [get]
//
// The handle is built from the host's own public key and uses the "root"
// inbox on the local host.
func (rs RESTServer) ActorHandle(c *gin.Context) {
	p2p := rs.config.P2P
	if p2p == nil {
		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "host node hasn't yet been initialized"})
		return
	}

	pubk := p2p.Host.Peerstore().PubKey(p2p.Host.ID())
	id, err := crypto.IDFromPublicKey(pubk)
	if err != nil {
		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "handle id is invalid"})
		return
	}
	// Named handleDID rather than did: a local "did" would shadow the
	// imported did package for the rest of the function.
	handleDID := did.FromPublicKey(pubk)

	handle := actor.Handle{
		ID:  id,
		DID: handleDID,
		Address: actor.Address{
			HostID:       p2p.Host.ID().String(),
			InboxAddress: "root",
		},
	}
	c.JSON(http.StatusOK, handle)
}
// ActorSendMessage godoc
//
//	@Summary		Send message to actor
//	@Description	Send message to actor
//	@Tags			actor
//	@Accept			json
//	@Produce		json
//	@Param			message	body		actor.Envelope	true	"Message to send"
//	@Success		200		{object}	object			"message sent"
//	@Failure		400		{object}	object			"invalid request data"
//	@Failure		500		{object}	object			"host node hasn't yet been initialized"
//	@Failure		500		{object}	object			"failed to marshal message"
//	@Failure		500		{object}	object			"destination address can't be resolved"
//	@Failure		500		{object}	object			"failed to send message to destination"
//	@Router			/actor/send [post]
//
// The request body is forwarded as-is via SendMessage; no signing happens here.
func (rs RESTServer) ActorSendMessage(c *gin.Context) {
	var msg actor.Envelope
	if err := c.ShouldBindJSON(&msg); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	p2p := rs.config.P2P
	if p2p == nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "host node hasn't yet been initialized"})
		return
	}

	if err := SendMessage(c.Request.Context(), p2p, msg); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	c.JSON(http.StatusOK, gin.H{"message": "message sent"})
}
// ActorInvoke godoc
//
//	@Summary		Invoke actor
//	@Description	Invoke actor with message
//	@Tags			actor
//	@Accept			json
//	@Produce		json
//	@Param			message	body		actor.Envelope	true	"Message to send"
//	@Success		200		{object}	object			"response message"
//	@Failure		400		{object}	object			"invalid request data"
//	@Failure		500		{object}	object			"host node hasn't yet been initialized"
//	@Failure		500		{object}	object			"failed to marshal message"
//	@Failure		500		{object}	object			"destination address can't be resolved"
//	@Failure		500		{object}	object			"failed to send message to destination"
//	@Router			/actor/invoke [post]
//
// A temporary handler is registered on the sender's inbox protocol to
// capture the first reply. The handler is removed when the request
// completes, times out at the message expiry, or is cancelled.
func (rs RESTServer) ActorInvoke(c *gin.Context) {
	var msg actor.Envelope
	if err := c.ShouldBindJSON(&msg); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}
	p2p := rs.config.P2P
	if p2p == nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "host node hasn't yet been initialized"})
		return
	}

	// Register a message handler for the response.
	protocol := fmt.Sprintf("actor/%s/messages/0.0.1", msg.From.Address.InboxAddress)
	responseCh := make(chan actor.Envelope, 1)
	err := p2p.HandleMessage(protocol, func(data []byte) {
		var envelope actor.Envelope
		if err := json.Unmarshal(data, &envelope); err != nil {
			// TODO log this
			return
		}
		// Only the first reply is used. A blocking send would hang this
		// handler forever on a duplicate reply, or on one arriving after the
		// request returned, because nobody reads responseCh any more.
		select {
		case responseCh <- envelope:
		default:
		}
	})
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	// Unregister the message handler before returning.
	defer p2p.UnregisterMessageHandler(protocol)

	if err := SendMessage(c.Request.Context(), p2p, msg); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	// A stoppable timer rather than time.After, so it is released as soon as
	// the request finishes.
	timer := time.NewTimer(time.Until(msg.Expiry()))
	defer timer.Stop()

	select {
	case responseMsg := <-responseCh:
		c.JSON(http.StatusOK, responseMsg)
	case <-timer.C:
		c.JSON(http.StatusRequestTimeout, gin.H{"error": "request timeout"})
	case <-c.Request.Context().Done():
		c.JSON(http.StatusRequestTimeout, gin.H{"error": "request timeout"})
	}
}
// ActorBroadcast godoc
//
//	@Summary		Broadcast message to actors
//	@Description	Broadcast message to actors
//	@Tags			actor
//	@Accept			json
//	@Produce		json
//	@Param			message	body		actor.Envelope	true	"Message to send"
//	@Success		200		{object}	object			"received responses"
//	@Failure		400		{object}	object			"invalid request data"
//	@Failure		500		{object}	object			"host node hasn't yet been initialized"
//	@Failure		500		{object}	object			"failed to marshal message"
//	@Failure		500		{object}	object			"failed to publish message"
//	@Router			/actor/broadcast [post]
//
// The message is published on its topic, and every reply delivered to the
// sender's inbox protocol is collected until the message expires or the
// request is cancelled. The collected replies are then returned.
func (rs RESTServer) ActorBroadcast(c *gin.Context) {
	var msg actor.Envelope
	if err := c.ShouldBindJSON(&msg); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	p2p := rs.config.P2P
	if p2p == nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "host node hasn't yet been initialized"})
		// Without this return the handler went on to dereference the nil p2p.
		return
	}

	if !msg.IsBroadcast() {
		c.JSON(http.StatusBadRequest, gin.H{"error": "message is not a broadcast message"})
		return
	}

	data, err := json.Marshal(msg)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to marshal message"})
		return
	}

	// Register a message handler to collect responses.
	protocol := fmt.Sprintf("actor/%s/messages/0.0.1", msg.From.Address.InboxAddress)
	var (
		messages []actor.Envelope
		mu       sync.Mutex
	)
	err = p2p.HandleMessage(protocol, func(data []byte) {
		var envelope actor.Envelope
		// Handler-local err: assigning the enclosing err from the handler
		// would race with this request goroutine.
		if err := json.Unmarshal(data, &envelope); err != nil {
			return
		}
		mu.Lock()
		messages = append(messages, envelope)
		mu.Unlock()
	})
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	// Unregister the message handler before returning.
	defer p2p.UnregisterMessageHandler(protocol)

	// Publish the message.
	if err := p2p.Publish(c.Request.Context(), msg.Options.Topic, data); err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to publish message"})
		return
	}

	// Wait until the message expires or the request is cancelled.
	timer := time.NewTimer(time.Until(msg.Expiry()))
	defer timer.Stop()
	select {
	case <-timer.C:
	case <-c.Request.Context().Done():
	}

	// The handler may still append until the deferred unregister runs, so
	// snapshot under mu. A nil result still encodes as null, as before.
	mu.Lock()
	responses := append([]actor.Envelope(nil), messages...)
	mu.Unlock()
	c.JSON(http.StatusOK, responses)
}
// SendMessage JSON-encodes msg and sends it synchronously over net to the
// host named by msg.To.Address.HostID. The envelope type is the recipient's
// inbox protocol ("actor/<inbox>/messages/0.0.1"), and msg.Expiry() is passed
// as the final argument to SendMessageSync. Encoding and send failures are
// returned wrapped with context.
func SendMessage(ctx context.Context, net *libp2p.Libp2p, msg actor.Envelope) (err error) {
	payload, err := json.Marshal(msg)
	if err != nil {
		return fmt.Errorf("failed to marshal message: %w", err)
	}

	inboxProtocol := fmt.Sprintf("actor/%s/messages/0.0.1", msg.To.Address.InboxAddress)
	envelope := types.MessageEnvelope{
		Type: types.MessageType(inboxProtocol),
		Data: payload,
	}

	if err = net.SendMessageSync(ctx, msg.To.Address.HostID, envelope, msg.Expiry()); err != nil {
		return fmt.Errorf("failed to send message to %s: %w", msg.To.ID, err)
	}
	return nil
}
package api
import (
	"fmt"
	"net"
	"time"

	"github.com/gin-contrib/cors"
	"github.com/gin-gonic/gin"

	"gitlab.com/nunet/device-management-service/dms/onboarding"
	"gitlab.com/nunet/device-management-service/network/libp2p"
	"gitlab.com/nunet/device-management-service/telemetry/logger"
	"gitlab.com/nunet/device-management-service/types"
)
// RESTServerConfig holds the dependencies and listen settings used to build
// and run a RESTServer.
type RESTServerConfig struct {
	// P2P is the libp2p host used by the actor endpoints; handlers answer
	// "host node hasn't yet been initialized" when it is nil.
	P2P *libp2p.Libp2p
	// Onboarding is the onboarding service.
	// NOTE(review): not referenced in this file — presumably used by other handlers; confirm.
	Onboarding *onboarding.Onboarding
	// Logger is the service logger.
	// NOTE(review): not referenced in this file; confirm where it is consumed.
	Logger *logger.Logger
	// Resource is the resource manager.
	// NOTE(review): not referenced in this file; confirm where it is consumed.
	Resource types.ResourceManager
	// MidW is extra gin middleware installed on the router (ahead of the
	// CORS middleware) by setupRouter.
	MidW []gin.HandlerFunc
	// Port is the TCP port Run listens on.
	Port uint32
	// Addr is the host/interface Run listens on.
	Addr string
}
// RESTServer is the DMS HTTP API server: a gin router plus the configuration
// its handlers and Run read from. Create it with NewRESTServer.
type RESTServer struct {
	// router is the gin engine built by setupRouter; endpoints are added to
	// it by InitializeRoutes.
	router *gin.Engine
	// config is the configuration passed to NewRESTServer.
	config *RESTServerConfig
}
// NewRESTServer builds a RESTServer around config. The router is created
// immediately with config.MidW and the CORS middleware installed; no routes
// exist until InitializeRoutes is called.
func NewRESTServer(config *RESTServerConfig) *RESTServer {
	router := setupRouter(config.MidW)
	srv := &RESTServer{config: config}
	srv.router = router
	return srv
}
// setupRouter creates a gin engine via gin.Default (which installs gin's
// logger and recovery middleware) and then registers the caller-supplied
// middleware followed by the CORS middleware from getCustomCorsConfig.
//
// The handlers are assembled in a fresh slice: appending directly to mid
// could write the CORS handler into the caller's backing array (e.g.
// RESTServerConfig.MidW) whenever that slice has spare capacity.
func setupRouter(mid []gin.HandlerFunc) *gin.Engine {
	handlers := make([]gin.HandlerFunc, 0, len(mid)+1)
	handlers = append(handlers, mid...)
	handlers = append(handlers, cors.New(getCustomCorsConfig()))

	router := gin.Default()
	router.Use(handlers...)
	return router
}
// InitializeRoutes registers the API endpoints on the router, all under the
// /api/v1 prefix:
//
//	GET  /api/v1/actor/handle
//	POST /api/v1/actor/send
//	POST /api/v1/actor/invoke
//	POST /api/v1/actor/broadcast
func (rs *RESTServer) InitializeRoutes() {
	actorRoutes := rs.router.Group("/api/v1").Group("/actor")

	actorRoutes.GET("/handle", rs.ActorHandle)
	actorRoutes.POST("/send", rs.ActorSendMessage)
	actorRoutes.POST("/invoke", rs.ActorInvoke)
	actorRoutes.POST("/broadcast", rs.ActorBroadcast)
}
// Run starts serving HTTP on config.Addr:config.Port. Per gin's Engine.Run,
// it blocks and only returns if the listener fails.
//
// The address is built with net.JoinHostPort so IPv6 literals are bracketed
// ("[::1]:9999"); the plain "%s:%d" form would produce "::1:9999", which the
// net package rejects. An empty Addr still yields ":<port>" (all interfaces).
func (rs *RESTServer) Run() error {
	addr := net.JoinHostPort(rs.config.Addr, fmt.Sprint(rs.config.Port))
	return rs.router.Run(addr)
}
// getCustomCorsConfig returns the CORS policy used by the REST server: the
// baseline from defaultConfig with allowed origins restricted to two local
// origins (ports 9991 and 9992).
func getCustomCorsConfig() cors.Config {
	cfg := defaultConfig()
	// FIXME: This is a security concern.
	// NOTE(review): presumably refers to the hard-coded origin list — confirm.
	origins := []string{"http://localhost:9991", "http://localhost:9992"}
	cfg.AllowOrigins = origins
	return cfg
}
// defaultConfig returns the baseline CORS settings: the common HTTP methods,
// a small set of allowed request headers, no credentials, and a 12-hour
// preflight cache (MaxAge). It sets no allowed origins; callers such as
// getCustomCorsConfig must add them.
func defaultConfig() cors.Config {
	methods := []string{"GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"}
	// NOTE(review): "Access-Control-Allow-Origin" is a response header, not a
	// request header; listing it here is likely unnecessary — confirm no
	// client sends it before removing.
	headers := []string{"Access-Control-Allow-Origin", "Origin", "Content-Length", "Content-Type"}

	var cfg cors.Config
	cfg.AllowMethods = methods
	cfg.AllowHeaders = headers
	cfg.AllowCredentials = false
	cfg.MaxAge = 12 * time.Hour
	return cfg
}
package actor
import (
"github.com/spf13/afero"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/utils"
)
// Storage directory names and defaults shared by the actor commands.
const (
	// CapstoreDir is the capability-store directory name.
	// NOTE(review): not referenced in this file; confirm the base path it is joined to.
	CapstoreDir = "cap/"
	// KeystoreDir is the keystore directory name.
	// NOTE(review): not referenced in this file; confirm the base path it is joined to.
	KeystoreDir = "key/"
	// DefaultUserContextName is the capability context used when the user
	// does not name one (see newSecurityContext).
	DefaultUserContextName = "user"
)
// NewActorCmd builds the `actor` parent command and attaches its subcommands,
// in order: msg, send, invoke, broadcast and the `cmd` behavior group.
func NewActorCmd(client *utils.HTTPClient, afs afero.Afero) *cobra.Command {
	parent := &cobra.Command{
		Use:   "actor",
		Short: "Interact with the actor system",
		Long: `Interact with the actor system
Actors are the entities which compose the NuActor system, a secure decentralized programming framework based on the Actor Model.
Actors are connected through the libp2p network substrate and communication is achieved via immutable messages.
For more information on the actor system, please refer to actor/README.md`,
	}

	parent.AddCommand(
		newActorMsgCmd(client, afs),
		newActorSendCmd(client),
		newActorInvokeCmd(client),
		newActorBroadcastCmd(client),
		newActorCmdGroup(client, afs),
	)
	return parent
}
package actor
import (
"encoding/json"
"fmt"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/utils"
)
// newActorBroadcastCmd builds `actor broadcast <msg>`. The single argument
// must be a JSON-encoded actor.Envelope, for example one produced by
// `nunet actor msg --broadcast ...`. It is decoded locally only to validate
// it, then posted verbatim to the /actor/broadcast API endpoint. The raw
// response body is always printed, even when the request fails.
func newActorBroadcastCmd(client *utils.HTTPClient) *cobra.Command {
	return &cobra.Command{
		Use:   "broadcast <msg>",
		Short: "Broadcast a message",
		Long: `Broadcast a message to a topic
If a topic is specified in the message's payload, the message will be published to all subscribers of that topic.`,
		Args: cobra.ExactArgs(1),
		RunE: func(cmd *cobra.Command, args []string) error {
			raw := []byte(args[0])

			var envelope actor.Envelope
			if err := json.Unmarshal(raw, &envelope); err != nil {
				return fmt.Errorf("could not unmarshal message: %w", err)
			}

			body, status, err := client.MakeRequest("POST", "/actor/broadcast", raw)
			fmt.Fprintln(cmd.OutOrStdout(), string(body))

			switch {
			case err != nil:
				return fmt.Errorf("unable to make internal request: %w", err)
			case status != 200:
				return fmt.Errorf("request failed with status code: %d", status)
			default:
				return nil
			}
		},
	}
}
package actor
import (
"encoding/json"
"fmt"
"strings"
"time"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/cmd/utils"
dmsUtil "gitlab.com/nunet/device-management-service/utils"
)
// Flag names shared by the `actor cmd` behavior subcommands (registered as
// persistent flags on the parent by newActorCmdGroup).
const (
	fnContextName = "context"
	fnDest        = "dest"
	fnExpiry      = "expiry"
	fnTimeout     = "timeout"
)

// Behavior types. Each value doubles as the API endpoint suffix that
// newActorCmdCmd posts to ("/actor/<type>").
const (
	bBroadcast = "broadcast"
	bInvoke    = "invoke"
	bSend      = "send"
)
// newActorCmdGroup builds the `actor cmd` parent command. It adds one
// subcommand per entry in the behaviors table and defines the persistent
// flags every behavior subcommand inherits: --context, --timeout, --expiry
// and --dest. --timeout and --expiry are mutually exclusive. Run on its own,
// the command just prints its help.
func newActorCmdGroup(client *dmsUtil.HTTPClient, afs afero.Afero) *cobra.Command {
	completeBehavior := func(_ *cobra.Command, args []string, _ string) ([]string, cobra.ShellCompDirective) {
		if len(args) != 0 {
			return nil, cobra.ShellCompDirectiveDefault
		}
		// NOTE(review): this offers the third "/"-separated segment of each
		// behavior key (e.g. "hello" for "/public/hello", "node" for every
		// "/dms/node/..." key) rather than the full behavior name that the
		// subcommands are registered under, and it would panic on a key with
		// fewer than two slashes. Confirm the intended completions.
		var names []string
		for key := range behaviors {
			names = append(names, strings.Split(key, "/")[2])
		}
		return names, cobra.ShellCompDirectiveNoFileComp
	}

	group := &cobra.Command{
		Use:   "cmd",
		Short: "Invoke a predefined behavior on an actor",
		Long: `Invoke a predefined behavior on an actor
Example:
nunet actor cmd --context user /broadcast/hello
Adding the --dest flag will cause the behavior to be invoked on the specified actor.
For more information on behaviors, refer to cmd/actor/README.md`,
		ValidArgsFunction: completeBehavior,
		Run: func(cmd *cobra.Command, _ []string) {
			if err := cmd.Help(); err != nil {
				cmd.Println(err)
			}
		},
	}

	for name, cfg := range behaviors {
		group.AddCommand(newActorCmdCmd(client, afs, name, cfg))
	}

	flags := group.PersistentFlags()
	flags.StringP(fnContextName, "c", "", "capability context name")
	flags.DurationP(fnTimeout, "t", 0, "timeout duration")
	flags.VarP(utils.NewTimeValue(&time.Time{}), fnExpiry, "e", "expiration time")
	flags.StringP(fnDest, "d", "", "destination DMS DID, peer ID or handle")
	group.MarkFlagsMutuallyExclusive(fnTimeout, fnExpiry)

	return group
}
// newActorCmdCmd builds the subcommand for one behavior from the behaviors
// table. When run it:
//  1. fetches the local DMS handle from the API (getDMSHandle),
//  2. optionally transforms the flag-populated payload with PayloadEnc,
//  3. builds and signs an actor message (newActorMessage),
//  4. POSTs it to /actor/<Type> and pretty-prints the decoded reply.
//
// Broadcast behaviors publish on behaviorCfg.Topic and expect a JSON list of
// replies; all other types expect a single reply. A non-200 response is
// reported as an error that includes the server's response body.
func newActorCmdCmd(client *dmsUtil.HTTPClient, afs afero.Afero, behavior string, behaviorCfg behaviorConfig) *cobra.Command {
	// payload is shared by SetFlags (flag bindings), PreRunE and RunE so all
	// of them operate on the same, flag-populated value.
	payload := &Payload{val: nil}
	if behaviorCfg.Payload != nil {
		payload.val = behaviorCfg.Payload()
	}

	cmd := &cobra.Command{
		Use:               fmt.Sprintf("%s [<param> ...]", behavior),
		Short:             behaviorCfg.Short,
		Long:              behaviorCfg.Long,
		ValidArgsFunction: behaviorCfg.ValidArgsFn,
		Args:              behaviorCfg.Args,
		PreRunE: func(cmd *cobra.Command, _ []string) error {
			if behaviorCfg.PreRunE != nil {
				return behaviorCfg.PreRunE(cmd, payload.val)
			}
			return nil
		},
		RunE: func(cmd *cobra.Command, _ []string) error {
			// Lookup errors are ignored: these flags are registered as
			// persistent flags on the parent by newActorCmdGroup.
			timeout, _ := cmd.Flags().GetDuration(fnTimeout)
			expiry, _ := utils.GetTime(cmd.Flags(), fnExpiry)
			contextName, _ := cmd.Flags().GetString(fnContextName)
			dest, _ := cmd.Flags().GetString(fnDest)

			dmsHandle, err := getDMSHandle(client)
			if err != nil {
				return fmt.Errorf("could not get source DMS handle: %w", err)
			}

			topic := ""
			if behaviorCfg.Type == bBroadcast {
				topic = behaviorCfg.Topic
			}

			if behaviorCfg.PayloadEnc != nil {
				payload.val, err = behaviorCfg.PayloadEnc(payload.val)
				if err != nil {
					return fmt.Errorf("could not marshal payload: %w", err)
				}
			}

			invocation := behaviorCfg.Type == bInvoke
			msg, err := newActorMessage(afs, dmsHandle, dest, topic, behavior, payload.val, timeout, expiry, invocation, contextName)
			if err != nil {
				return fmt.Errorf("could not create message: %w", err)
			}

			msgData, err := json.Marshal(msg)
			if err != nil {
				return fmt.Errorf("could not marshal message: %w", err)
			}

			endpoint := fmt.Sprintf("/actor/%s", behaviorCfg.Type)
			resBody, resCode, err := client.MakeRequest("POST", endpoint, msgData)
			if err != nil {
				return fmt.Errorf("unable to make internal request: %w", err)
			}
			if resCode != 200 {
				// The API puts the failure reason in the body (as {"error": ...});
				// surface it instead of only the status code.
				if detail := strings.TrimSpace(string(resBody)); detail != "" {
					return fmt.Errorf("request failed with status code: %d: %s", resCode, detail)
				}
				return fmt.Errorf("request failed with status code: %d", resCode)
			}

			enc := json.NewEncoder(cmd.OutOrStdout())
			enc.SetIndent("", " ")

			if behaviorCfg.Type == bBroadcast {
				var resMsgs []cmdResponse
				if err := json.Unmarshal(resBody, &resMsgs); err != nil {
					return fmt.Errorf("could not unmarshal response: %w", err)
				}
				return enc.Encode(resMsgs)
			}

			var resMsg cmdResponse
			if err := json.Unmarshal(resBody, &resMsg); err != nil {
				// NOTE(review): an undecodable single reply is silently ignored
				// (nothing is printed, exit status 0), unlike the broadcast
				// branch above — confirm this best-effort behavior is intended.
				return nil
			}
			return enc.Encode(resMsg)
		},
	}

	if behaviorCfg.SetFlags != nil {
		behaviorCfg.SetFlags(cmd, payload.val)
	}
	return cmd
}
package actor
import (
"encoding/json"
"fmt"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/utils"
)
// newActorInvokeCmd builds `actor invoke <msg>`. The single argument must be
// a JSON-encoded actor.Envelope whose Options.ReplyTo is set, such as one
// produced by `nunet actor msg --invoke ...`. It is posted verbatim to the
// /actor/invoke API endpoint. The raw response body is always printed, even
// when the request fails.
func newActorInvokeCmd(client *utils.HTTPClient) *cobra.Command {
	return &cobra.Command{
		Use:   "invoke <msg>",
		Short: "Invoke a behaviour in an actor and return the result",
		Args:  cobra.ExactArgs(1),
		RunE: func(cmd *cobra.Command, args []string) error {
			raw := []byte(args[0])

			var envelope actor.Envelope
			if err := json.Unmarshal(raw, &envelope); err != nil {
				return fmt.Errorf("could not unmarshal message: %w", err)
			}
			// An invocation without a reply-to behavior could never be answered.
			if envelope.Options.ReplyTo == "" {
				return fmt.Errorf("missing replyTo field in message")
			}

			body, status, err := client.MakeRequest("POST", "/actor/invoke", raw)
			fmt.Fprintln(cmd.OutOrStdout(), string(body))

			switch {
			case err != nil:
				return fmt.Errorf("unable to make internal request: %w", err)
			case status != 200:
				return fmt.Errorf("request failed with status code: %d", status)
			default:
				return nil
			}
		},
	}
}
package actor
import (
"encoding/json"
"fmt"
"time"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/cmd/utils"
dmsUtils "gitlab.com/nunet/device-management-service/utils"
)
// newActorMsgCmd builds `actor msg <behavior> <payload>`. It constructs and
// signs an actor message locally and prints it as a single line of JSON, so it
// can be stored or piped into `actor send`, `actor invoke` or
// `actor broadcast`. Its only API call fetches the DMS handle, which is used
// as the sender's host and the default destination.
//
// --dest and --broadcast are mutually exclusive, as are --invoke and
// --broadcast, and --timeout and --expiry.
func newActorMsgCmd(client *dmsUtils.HTTPClient, afs afero.Afero) *cobra.Command {
	// Flags specific to this command. The shared flag names (fnDest,
	// fnTimeout, fnExpiry, fnContextName) are package-level constants.
	const (
		fnBroadcast = "broadcast"
		fnInvoke    = "invoke"
	)

	cmd := &cobra.Command{
		Use:   "msg <behavior> <payload>",
		Short: "Construct a message",
		Long: `Construct and sign a message that can be communicated to an actor.
The constructed message is returned as a JSON object that can be stored or piped into another command, for instance the send, invoke, or broadcast command.
Example:
nunet actor msg --broadcast /nunet/hello /broadcast/hello 'Hello, World!'`,
		Args: cobra.ExactArgs(2),
		RunE: func(cmd *cobra.Command, args []string) error {
			// Lookup errors are ignored: all of these flags are defined below.
			destStr, _ := cmd.Flags().GetString(fnDest)
			topic, _ := cmd.Flags().GetString(fnBroadcast)
			timeout, _ := cmd.Flags().GetDuration(fnTimeout)
			expiry, _ := utils.GetTime(cmd.Flags(), fnExpiry)
			invocation, _ := cmd.Flags().GetBool(fnInvoke)
			contextName, _ := cmd.Flags().GetString(fnContextName)

			behavior := args[0]
			payload := args[1]

			dmsHandle, err := getDMSHandle(client)
			if err != nil {
				return fmt.Errorf("could not get source handle: %w", err)
			}

			msg, err := newActorMessage(afs, dmsHandle, destStr, topic, behavior, payload, timeout, expiry, invocation, contextName)
			if err != nil {
				return fmt.Errorf("could not create message: %w", err)
			}

			msgData, err := json.Marshal(msg)
			if err != nil {
				return err
			}
			fmt.Fprintln(cmd.OutOrStdout(), string(msgData))
			return nil
		},
	}

	cmd.Flags().StringP(fnDest, "d", "", "destination handle")
	cmd.Flags().StringP(fnBroadcast, "b", "", "broadcast topic")
	cmd.Flags().BoolP(fnInvoke, "i", false, "construct an invocation")
	cmd.Flags().StringP(fnContextName, "c", "", "capability context name")
	cmd.Flags().DurationP(fnTimeout, "t", 0, "timeout duration")
	cmd.Flags().VarP(utils.NewTimeValue(&time.Time{}), fnExpiry, "e", "expiration time")

	cmd.MarkFlagsMutuallyExclusive(fnDest, fnBroadcast)
	cmd.MarkFlagsMutuallyExclusive(fnInvoke, fnBroadcast)
	cmd.MarkFlagsMutuallyExclusive(fnTimeout, fnExpiry)
	return cmd
}
package actor
import (
"encoding/json"
"fmt"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/utils"
)
// newActorSendCmd builds `actor send <msg>`. The single argument must be a
// JSON-encoded actor.Envelope (see `nunet actor msg --help`). It is decoded
// locally only to validate it, then posted verbatim to the /actor/send API
// endpoint. The raw response body is always printed, even when the request
// fails.
func newActorSendCmd(client *utils.HTTPClient) *cobra.Command {
	return &cobra.Command{
		Use:   "send <msg>",
		Short: "Send a message",
		Long: `Send a message to an actor
Actors only communicate via messages. For more information on constructing a message, see:
nunet actor msg --help
The message is encoded into an actor envelope, which then is sent across the network through the API.`,
		Args: cobra.ExactArgs(1),
		RunE: func(cmd *cobra.Command, args []string) error {
			raw := []byte(args[0])

			var envelope actor.Envelope
			if err := json.Unmarshal(raw, &envelope); err != nil {
				return fmt.Errorf("could not unmarshal message: %w", err)
			}

			body, status, err := client.MakeRequest("POST", "/actor/send", raw)
			fmt.Fprintln(cmd.OutOrStdout(), string(body))

			switch {
			case err != nil:
				return fmt.Errorf("unable to make internal request: %w", err)
			case status != 200:
				return fmt.Errorf("request failed with status code: %d", status)
			default:
				return nil
			}
		},
	}
}
package actor
import (
"errors"
"fmt"
"gitlab.com/nunet/device-management-service/dms/hardware"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/dms/node"
)
// ErrInvalidArgument is returned by behavior hooks (e.g. onboardBehaviorPreRun)
// when they receive a payload of an unexpected type.
var ErrInvalidArgument = errors.New("invalid argument")

// Command is a local alias for cobra.Command, used in the behaviorConfig hook
// signatures.
type Command = cobra.Command

// Payload boxes the value a behavior subcommand sends as its message payload.
// newActorCmdCmd allocates one per subcommand so that the flag bindings
// (SetFlags), PreRunE and PayloadEnc all operate on the same value.
type Payload struct {
	val any
}

// behaviorConfig describes how one behavior is exposed as an
// `nunet actor cmd <behavior>` subcommand (consumed by newActorCmdCmd).
type behaviorConfig struct {
	// Behavior is the behavior name.
	// NOTE(review): not read in this chunk; the behaviors map key is used as the name instead — confirm whether it is still needed.
	Behavior string
	// Type is the behavior kind (bInvoke, bBroadcast, ...). It also selects
	// the API endpoint the message is posted to ("/actor/<Type>").
	Type string
	// Topic is the broadcast topic; only used when Type is bBroadcast.
	Topic string
	// Payload, if set, returns a fresh payload value (in the table, a pointer
	// to a request struct) that SetFlags binds flags into.
	Payload func() any
	// PayloadEnc, if set, converts the flag-populated payload into the value
	// placed in the message.
	PayloadEnc func(payload any) (any, error)
	// SetFlags, if set, registers behavior-specific flags bound to the payload.
	SetFlags func(cmd *Command, payload any)
	// PreRunE, if set, runs before the message is built and may validate or
	// adjust the payload.
	PreRunE func(cmd *Command, payload any) error
	// ValidArgsFn is installed as the subcommand's ValidArgsFunction.
	ValidArgsFn func(cmd *Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective)
	// Args validates the subcommand's positional arguments.
	Args cobra.PositionalArgs
	// Long is the subcommand's detailed help text.
	Long string
	// Short is the subcommand's one-line description.
	Short string
}
// behaviors maps each behavior name to the configuration used to expose it as
// an `nunet actor cmd <behavior>` subcommand (see newActorCmdCmd). Each entry
// defines the behavior type, its help text, and, for behaviors that take
// arguments, the payload type plus the flags that populate it.
var behaviors = map[string]behaviorConfig{
	// /public/hello
	node.PublicHelloBehavior: {
		Type:  bInvoke,
		Short: "Broadcast a 'hello' message",
		Long: `Invoke the /public/hello behavior on an actor
This behavior broadcasts a "hello" for a polite introduction.
Examples:
nunet actor cmd --context user /public/hello
nunet actor cmd --context user /public/hello --dest <did/peer_id/actor_handle>`,
	},
	// /broadcast/hello
	node.BroadcastHelloBehavior: {
		Type:  bBroadcast,
		Topic: node.BroadcastHelloTopic,
		Short: "Broadcast a 'hello' message to a topic",
		Long: `Invokes the /broadcast/hello behavior on an actor
This behavior sends a "hello" message to a broadcast topic for polite introduction.
Examples:
nunet actor cmd --context user /broadcast/hello`,
	},
	// /public/status
	node.PublicStatusBehavior: {
		Type:  bInvoke,
		Short: "Retrieve actor status",
		Long: `Invokes the /public/status behavior on an actor
This behavior retrieves the status and resources information.
Examples:
nunet actor cmd --context user /public/status # own actor status
nunet actor cmd --context user /public/status --dest <did/peer_id/actor_handle> # status of specified destination`,
	},
	// /dms/node/peers/list
	node.PeersListBehavior: {
		Type:  bInvoke,
		Short: "List connected peers",
		Long: `Invokes the /dms/node/peers/list behavior on an actor
This behavior retrieves a list of connected peers.
Examples:
nunet actor cmd --context user /dms/node/peers/list # own node actor peer list
nunet actor cmd --context user /dms/node/peers/list --dest <did/peer_id/actor_handle> # specified node actor peer list`,
	},
	// /dms/node/peers/self
	node.PeerAddrInfoBehavior: {
		Type:  bInvoke,
		Short: "Get peer's ID and addresses",
		Long: `Invokes the /dms/node/peers/self behavior on an actor
This behavior retrieves information about the node itself, such as its ID or addresses.
Examples:
nunet actor cmd --context user /dms/node/peers/self # own node actor peer ID
nunet actor cmd --context user /dms/node/peers/self --dest <did/peer_id/actor_handle> # specified node actor peer ID`,
	},
	// /dms/node/peers/ping
	node.PeerPingBehavior: {
		Type:    bInvoke,
		Payload: func() any { return &node.PingRequest{} },
		SetFlags: func(cmd *cobra.Command, payload any) {
			p := payload.(*node.PingRequest)
			cmd.Flags().StringVarP(&p.Host, "host", "H", "", "host address to ping (required)")
			_ = cmd.MarkFlagRequired("host")
		},
		PayloadEnc: func(payload any) (any, error) {
			req, ok := payload.(*node.PingRequest)
			if !ok {
				return nil, fmt.Errorf("failed to encode payload")
			}
			return req, nil
		},
		Short: "Ping a peer",
		Long: `Invokes the /dms/node/peers/ping behavior on an actor
This behavior establishes a ping connection with a peer.
Examples:
nunet actor cmd --context user /dms/node/peers/ping --host <peer_id>`,
	},
	// /dms/node/peers/dht
	node.PeerDHTBehavior: {
		Type:  bInvoke,
		Short: "List peers connected to DHT",
		Long: `Invokes the /dms/node/peers/dht behavior on an actor
This behavior returns a list of peers from the Distributed Hash Table (DHT) used for peer discovery and content routing.
Examples:
nunet actor cmd --context user /dms/node/peers/dht`,
	},
	// /dms/node/peers/connect
	node.PeerConnectBehavior: {
		Type:    bInvoke,
		Payload: func() any { return &node.PeerConnectRequest{} },
		SetFlags: func(cmd *cobra.Command, payload any) {
			p := payload.(*node.PeerConnectRequest)
			cmd.Flags().StringVarP(&p.Address, "address", "a", "", "peer address to connect to (required)")
			_ = cmd.MarkFlagRequired("address")
		},
		PayloadEnc: func(payload any) (any, error) {
			req, ok := payload.(*node.PeerConnectRequest)
			if !ok {
				return nil, fmt.Errorf("failed to encode payload")
			}
			return req, nil
		},
		Short: "Connect to a peer",
		Long: `Invokes the /dms/node/peers/connect behavior on an actor
This behavior initiates a connection to a specified peer.
Examples:
nunet actor cmd --context user /dms/node/peers/connect --address /p2p/<peer_id>`,
	},
	// /dms/node/peers/score
	node.PeerScoreBehavior: {
		Type:  bInvoke,
		Short: "Retrieves gossipsub broadcast score",
		Long: `Invokes the /dms/node/peers/score behavior on an actor
This behavior retrieves a snapshot of the peer's gossipsub broadcast score.
Examples:
nunet actor cmd --context user /dms/node/peers/score`,
	},
	// /dms/node/onboarding/onboard
	node.OnboardBehavior: {
		Type:    bInvoke,
		Payload: func() any { return &node.OnboardRequest{} },
		SetFlags: func(cmd *Command, payload any) {
			// infer the type of the payload
			p := payload.(*node.OnboardRequest)
			cmd.Flags().Float64VarP(&p.Config.OnboardedResources.RAM.Size, "ram", "m", 0, "set the amount of memory in GB to reserve for NuNet")
			cmd.Flags().Float32Var(&p.Config.OnboardedResources.CPU.Cores, "cpu", 0, "set the number of CPU cores to reserve for NuNet")
			cmd.Flags().Float64Var(&p.Config.OnboardedResources.Disk.Size, "disk", 0, "set the amount of disk size in GB to reserve for NuNet")
			cmd.MarkFlagsOneRequired("ram", "cpu", "disk")
			cmd.MarkFlagsRequiredTogether("ram", "cpu", "disk")
		},
		PreRunE: onboardBehaviorPreRun,
		PayloadEnc: func(payload any) (any, error) {
			req, ok := payload.(*node.OnboardRequest)
			if !ok {
				return nil, fmt.Errorf("failed to encode payload")
			}
			return req, nil
		},
		Short: "Onboard a node to the network",
		// The flags are --ram/--cpu/--disk and must be given together.
		Long: `Invokes the /dms/node/onboarding/onboard behavior on an actor
This behavior is used to onboard a node to the DMS, making its resources available for use.
Examples:
nunet actor cmd --context user /dms/node/onboarding/onboard --ram 1 --cpu 2 --disk 10`,
	},
	// /dms/node/onboarding/offboard
	node.OffboardBehavior: {
		Type:    bInvoke,
		Payload: func() any { return &node.OffboardRequest{} },
		SetFlags: func(cmd *cobra.Command, payload any) {
			p := payload.(*node.OffboardRequest)
			cmd.Flags().BoolVarP(&p.Force, "force", "f", false, "force offboard")
		},
		PayloadEnc: func(payload any) (any, error) {
			req, ok := payload.(*node.OffboardRequest)
			if !ok {
				return nil, fmt.Errorf("failed to encode payload")
			}
			return req, nil
		},
		Short: "Offboard a node from the network",
		Long: `Invokes the /dms/node/onboarding/offboard behavior on an actor
This behavior is used to offboard a node from the DMS (Device Management Service).
Examples:
nunet actor cmd --context user /dms/node/onboarding/offboard
nunet actor cmd --context user /dms/node/onboarding/offboard --force`,
	},
	// /dms/node/onboarding/status
	node.OnboardStatusBehavior: {
		Type:  bInvoke,
		Short: "Retrieve onboarding status of a node",
		Long: `Invokes the /dms/node/onboarding/status behavior on an actor
This behavior is used to check the onboarding status of a node.
Examples:
nunet actor cmd --context user /dms/node/onboarding/status`,
	},
	// /dms/node/onboarding/resource
	node.OnboardResourceBehavior: {
		Type:  bInvoke,
		Short: "Retrieve or manage resources of a node",
		Long: `Invokes the /dms/node/onboarding/resource behavior on an actor
This behavior retrieves or manages resource information related to the onboarding process.
Examples:
nunet actor cmd --context user /dms/node/onboarding/resource`,
		Payload: func() any { return &node.OnboardRequest{} },
		SetFlags: func(cmd *Command, payload any) {
			// infer the type of the payload
			p := payload.(*node.OnboardRequest)
			cmd.Flags().Float64VarP(&p.Config.OnboardedResources.RAM.Size, "ram", "r", 0, "set the RAM size in GB to reserve for NuNet")
			cmd.Flags().Float32Var(&p.Config.OnboardedResources.CPU.Cores, "cpu", 0, "set the number of CPU cores to reserve for NuNet")
			cmd.Flags().Float64Var(&p.Config.OnboardedResources.Disk.Size, "disk", 0, "set the disk size in GB to reserve for NuNet")
			cmd.MarkFlagsRequiredTogether("ram", "disk", "cpu")
		},
		PreRunE: onboardBehaviorPreRun,
		PayloadEnc: func(payload any) (any, error) {
			req, ok := payload.(*node.OnboardRequest)
			if !ok {
				return nil, fmt.Errorf("failed to encode payload")
			}
			return req, nil
		},
	},
	// /dms/node/vm/start/custom
	node.CustomVMStartBehavior: {
		Type:    bInvoke,
		Payload: func() any { return &vmStartOpts{} },
		SetFlags: func(cmd *cobra.Command, payload any) {
			p := payload.(*vmStartOpts)
			cmd.Flags().StringVarP(&p.Engine.KernelImage, "kernel", "k", "", "path to kernel image file (required)")
			cmd.Flags().StringVarP(&p.Engine.RootFileSystem, "rootfs", "r", "", "path to root fs image file (required)")
			cmd.Flags().StringVarP(&p.Engine.Initrd, "initrd", "i", "", "path to initial ram disk")
			cmd.Flags().StringVarP(&p.Engine.KernelArgs, "args", "a", "", "arguments to pass to the kernel")
			cmd.Flags().Float32Var(&p.Resources.CPU.Cores, "cpu", 1, "CPU cores to allocate")
			cmd.Flags().Float64VarP(&p.Resources.RAM.Size, "ram", "m", 1, "Memory to allocate in GB")
			// --disk is a size (bound to Resources.Disk.Size), not a file path.
			cmd.Flags().Float64Var(&p.Resources.Disk.Size, "disk", 0.5, "Disk size to allocate in GB")
			_ = cmd.MarkFlagRequired("kernel")
			_ = cmd.MarkFlagFilename("kernel")
			_ = cmd.MarkFlagRequired("rootfs")
			_ = cmd.MarkFlagFilename("rootfs")
		},
		PayloadEnc: func(payload any) (any, error) {
			opts, ok := payload.(*vmStartOpts)
			if !ok {
				return nil, fmt.Errorf("failed to encode payload")
			}
			return newCustomVMStartRequest(opts)
		},
		Short: "Starts a custom VM",
		Long: `Invokes the /dms/node/vm/start/custom behavior on an actor
This behavior starts a new VM with custom configurations.
Examples:
nunet actor cmd --context user /dms/node/vm/start/custom --kernel /path/to/kernel --rootfs /path/to/rootfs --cpu 2 --ram 2`,
	},
	// /dms/node/vm/stop
	node.VMStopBehavior: {
		Payload: func() any { return &node.VMStopRequest{} },
		SetFlags: func(cmd *cobra.Command, payload any) {
			p := payload.(*node.VMStopRequest)
			cmd.Flags().StringVarP(&p.ExecutionID, "id", "i", "", "execution ID of the VM (required)")
			_ = cmd.MarkFlagRequired("id")
		},
		PayloadEnc: func(payload any) (any, error) {
			req, ok := payload.(*node.VMStopRequest)
			if !ok {
				return nil, fmt.Errorf("failed to encode payload")
			}
			return req, nil
		},
		Type:  bInvoke,
		Short: "Stops a running VM",
		Long: `Invokes the /dms/node/vm/stop behavior on an actor
This behavior stops a running VM.
Examples:
nunet actor cmd --context user /dms/node/vm/stop --id <execution_id>`,
	},
	// /dms/node/vm/list
	node.VMListBehavior: {
		Type:  bInvoke,
		Short: "List running VMs",
		Long: `Invokes the /dms/node/vm/list behavior on an actor
This behavior retrieves a list of virtual machines (VMs) running on the node.
Examples:
nunet actor cmd --context user /dms/node/vm/list`,
	},
}
// onboardBehaviorPreRun finishes an onboarding request before it is sent. It
// copies the local machine's CPU clock speed into the request and converts
// the --ram and --disk values from GB (as entered) to bytes, in place.
//
// It returns ErrInvalidArgument if payload is not a *node.OnboardRequest, and
// a wrapped error if the machine resources cannot be read. In that case the
// payload is left unchanged.
func onboardBehaviorPreRun(_ *Command, payload any) error {
	const bytesPerGB = 1024 * 1024 * 1024

	req, ok := payload.(*node.OnboardRequest)
	if !ok {
		return ErrInvalidArgument
	}

	// TODO: we need to have one instance of the hardware manager
	// could we do an api call here instead?
	hw := hardware.NewHardwareManager()
	machine, err := hw.GetMachineResources()
	if err != nil {
		return fmt.Errorf("could not get machine resources: %w", err)
	}

	// TODO: create helper functions for these conversions
	req.Config.OnboardedResources.CPU.ClockSpeed = machine.CPU.ClockSpeed
	req.Config.OnboardedResources.RAM.Size *= bytesPerGB
	req.Config.OnboardedResources.Disk.Size *= bytesPerGB
	return nil
}
//go:build linux
// +build linux
package actor
import (
"github.com/google/uuid"
"gitlab.com/nunet/device-management-service/dms/node"
"gitlab.com/nunet/device-management-service/executor/firecracker"
"gitlab.com/nunet/device-management-service/types"
)
// newCustomVMStartRequest converts the CLI options of the
// /dms/node/vm/start/custom behavior into a node.CustomVMStartRequest. It
// builds a Firecracker engine spec from opts.Engine and wraps it in an
// execution request with a freshly generated UUID as the execution ID. The
// request's Resources field points at opts.Resources; it is not a copy. The
// error result is always nil.
func newCustomVMStartRequest(opts *vmStartOpts) (node.CustomVMStartRequest, error) {
	builder := firecracker.NewFirecrackerEngineBuilder(opts.Engine.RootFileSystem)
	builder = builder.WithKernelImage(opts.Engine.KernelImage)
	builder = builder.WithKernelArgs(opts.Engine.KernelArgs)
	builder = builder.WithInitrd(opts.Engine.Initrd)
	spec := builder.Build()

	var req node.CustomVMStartRequest
	req.Execution = types.ExecutionRequest{
		ExecutionID: uuid.New().String(),
		EngineSpec:  spec,
		Resources:   &opts.Resources,
	}
	return req, nil
}
// vmStartOpts collects the flag values of the /dms/node/vm/start/custom
// behavior; newCustomVMStartRequest turns it into a request.
type vmStartOpts struct {
	// Engine receives the --kernel, --rootfs, --initrd and --args values.
	Engine firecracker.EngineSpec
	// Resources receives the --cpu, --ram and --disk values.
	Resources types.Resources
}
package actor
import (
"encoding/json"
"fmt"
"strings"
"time"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/spf13/afero"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/cmd/cap"
"gitlab.com/nunet/device-management-service/lib/crypto"
"gitlab.com/nunet/device-management-service/lib/did"
"gitlab.com/nunet/device-management-service/lib/ucan"
"gitlab.com/nunet/device-management-service/utils"
)
// cmdResponse is the CLI's view of an actor reply. Decoding reads the "msg"
// field of a reply envelope and parses the JSON document it contains. Per
// encoding/json's []byte rules, that field is base64-encoded on the wire.
// Encoding emits only that inner value.
type cmdResponse struct {
	val any
}

// UnmarshalJSON decodes {"msg": "<base64 JSON>"} and stores the decoded inner
// value. An absent, null or empty "msg" yields a nil value instead of an
// error. An empty byte slice is not valid JSON, and failing on it would abort
// output for replies that simply carry no payload.
func (r *cmdResponse) UnmarshalJSON(data []byte) error {
	var res struct {
		Message []byte `json:"msg"`
	}
	if err := json.Unmarshal(data, &res); err != nil {
		return err
	}

	if len(res.Message) == 0 {
		*r = cmdResponse{}
		return nil
	}

	var val any
	if err := json.Unmarshal(res.Message, &val); err != nil {
		return err
	}
	*r = cmdResponse{val: val}
	return nil
}

// MarshalJSON encodes only the inner reply value; a nil value encodes as null.
func (r cmdResponse) MarshalJSON() ([]byte, error) {
	return json.Marshal(r.val)
}
// getDMSHandle asks the local DMS API (GET /actor/handle) for its actor
// handle. A transport error, a non-200 status, or a body that does not decode
// into an actor.Handle is returned as a wrapped error.
func getDMSHandle(client *utils.HTTPClient) (actor.Handle, error) {
	body, status, err := client.MakeRequest("GET", "/actor/handle", nil)
	switch {
	case err != nil:
		return actor.Handle{}, fmt.Errorf("unable to get source handle: %w", err)
	case status != 200:
		return actor.Handle{}, fmt.Errorf("request failed with status code: %d", status)
	}

	var handle actor.Handle
	if err := json.Unmarshal(body, &handle); err != nil {
		return handle, fmt.Errorf("could not unmarshal response body: %w", err)
	}
	return handle, nil
}
// newUserHandle builds the handle of an ephemeral CLI user actor. The handle
// has the given identity (id, userDID), lives on the same host as dmsHandle,
// and receives messages at the given inbox.
func newUserHandle(id crypto.ID, userDID did.DID, dmsHandle actor.Handle, inbox string) actor.Handle {
	var h actor.Handle
	h.ID = id
	h.DID = userDID
	h.Address.HostID = dmsHandle.Address.HostID
	h.Address.InboxAddress = inbox
	return h
}
// newSecurityContext creates the security context used to sign one CLI
// message. It generates a fresh Ed25519 key pair and loads the named
// capability context, falling back to DefaultUserContextName when the name is
// empty.
//
// For ledger contexts (per cap.IsLedgerContext), the trust context comes from
// a ledger wallet provider and the name is normalised with cap.LedgerContext.
// Otherwise the trust context is built from the keystore on fs.
func newSecurityContext(fs afero.Afero, context string) (actor.SecurityContext, error) {
	if context == "" {
		context = DefaultUserContextName
	}

	// Generate ephemeral key pair
	privk, pubk, err := crypto.GenerateKeyPair(crypto.Ed25519)
	if err != nil {
		return nil, fmt.Errorf("failed to generate ephemeral key pair: %w", err)
	}

	// Create trust context
	var trustCtx did.TrustContext
	if !cap.IsLedgerContext(context) {
		trustCtx, _, err = cap.CreateTrustContextFromKeyStore(fs, context)
		if err != nil {
			return nil, fmt.Errorf("failed to create trust context: %w", err)
		}
	} else {
		provider, providerErr := did.NewLedgerWalletProvider(0)
		if providerErr != nil {
			return nil, providerErr
		}
		trustCtx = did.NewTrustContextWithProvider(provider)
		context = cap.LedgerContext(context)
	}

	// Load capability context
	capCtx, err := cap.LoadCapabilityContext(trustCtx, context)
	if err != nil {
		return nil, fmt.Errorf("failed to load capability context: %w", err)
	}
	return actor.NewBasicSecurityContext(pubk, privk, capCtx)
}
// newActorMessage builds and signs an actor message on behalf of an
// ephemeral CLI user.
//
// The sender is a new user handle on dmsHandle's host, with inbox
// "user-<nonce>", where the nonce comes from a security context loaded for
// the named capability context (see newSecurityContext).
//
// The destination is chosen as follows:
//   - topic != "": the message is a broadcast on topic, destStr is ignored,
//     dest stays the zero handle, and the reply-to behavior is
//     /public/user/<nonce>;
//   - destStr starting with "did:": resolved with handleFromDID;
//   - destStr starting with "{": decoded as a JSON actor.Handle;
//   - any other non-empty destStr: treated as a peer ID (handleFromPeerID);
//   - otherwise: the message targets dmsHandle.
//
// When invocation is true, the reply-to behavior is /private/user/<nonce>,
// which replaces the broadcast one. A non-zero expiry (as Unix nanoseconds)
// and a positive timeout are passed through as message options. The signature
// grants the capability for behavior. For non-broadcast messages it also
// delegates the reply-to capability.
// NOTE(review): presumably so the receiver is authorized to send the reply — confirm.
func newActorMessage(fs afero.Afero, dmsHandle actor.Handle, destStr string, topic, behavior string, payload interface{}, timeout time.Duration, expiry time.Time, invocation bool, context string) (actor.Envelope, error) {
	var msg actor.Envelope
	var src actor.Handle
	var dest actor.Handle
	sctx, err := newSecurityContext(fs, context)
	if err != nil {
		return msg, fmt.Errorf("failed to create security context: %w", err)
	}
	// The security context's nonce names this user's inbox and reply-to behaviors.
	nonce := sctx.Nonce()
	inbox := fmt.Sprintf("user-%d", nonce)
	src = newUserHandle(sctx.ID(), sctx.DID(), dmsHandle, inbox)
	opts := []actor.MessageOption{}
	replyTo := ""
	// Pick the destination: broadcast topic, explicit destination, or the local DMS.
	switch {
	case topic != "":
		opts = append(opts, actor.WithMessageTopic(topic))
		replyTo = fmt.Sprintf("/public/user/%d", nonce)
	case destStr != "":
		switch {
		case strings.HasPrefix(destStr, "did:"):
			dest, err = handleFromDID(destStr)
		case strings.HasPrefix(destStr, "{"):
			err = json.Unmarshal([]byte(destStr), &dest)
		default:
			dest, err = handleFromPeerID(destStr)
		}
		if err != nil {
			return msg, fmt.Errorf("could not create destination handle: %w", err)
		}
	default:
		dest = dmsHandle
	}
	// An invocation always gets a private reply-to, overriding the broadcast one.
	if invocation {
		replyTo = fmt.Sprintf("/private/user/%d", nonce)
	}
	if !expiry.IsZero() {
		opts = append(opts, actor.WithMessageExpiry(uint64(expiry.UnixNano())))
	}
	if timeout > 0 {
		opts = append(opts, actor.WithMessageTimeout(timeout))
	}
	delegate := []ucan.Capability{}
	if replyTo != "" {
		opts = append(opts, actor.WithMessageReplyTo(replyTo))
		// Broadcast messages do not delegate the reply-to capability.
		if topic == "" {
			delegate = append(delegate, ucan.Capability(replyTo))
		}
	}
	// The signature option is appended last.
	// NOTE(review): presumably it must follow the other options so the signature covers them — confirm before reordering.
	opts = append(opts, actor.WithMessageSignature(sctx, []ucan.Capability{ucan.Capability(behavior)}, delegate))
	msg, err = actor.Message(src, dest, behavior, payload, opts...)
	if err != nil {
		return msg, fmt.Errorf("could not construct message: %w", err)
	}
	return msg, nil
}
// handleFromPeerID derives the root actor handle of the node identified by a
// libp2p peer ID string. The public key embedded in the peer ID supplies the
// actor ID and DID, and the handle addresses the "root" inbox on that peer.
// Decoding failures and key types rejected by crypto.AllowedKey are returned
// as errors.
func handleFromPeerID(dest string) (actor.Handle, error) {
	peerID, err := peer.Decode(dest)
	if err != nil {
		return actor.Handle{}, err
	}

	pubk, err := peerID.ExtractPublicKey()
	if err != nil {
		return actor.Handle{}, err
	}
	if keyType := pubk.Type(); !crypto.AllowedKey(int(keyType)) {
		return actor.Handle{}, fmt.Errorf("unexpected key type: %d", keyType)
	}

	actorID, err := crypto.IDFromPublicKey(pubk)
	if err != nil {
		return actor.Handle{}, err
	}

	return actor.Handle{
		ID:  actorID,
		DID: did.FromPublicKey(pubk),
		Address: actor.Address{
			HostID:       peerID.String(),
			InboxAddress: "root",
		},
	}, nil
}
// handleFromDID derives the root actor handle of the node identified by a DID
// string. The DID's public key yields both the actor ID and the libp2p peer ID
// used as the host, and the handle addresses the "root" inbox on that peer.
// Any parsing or key-derivation failure is returned unwrapped.
func handleFromDID(dest string) (actor.Handle, error) {
	actorDID, err := did.FromString(dest)
	if err != nil {
		return actor.Handle{}, err
	}

	pubk, err := did.PublicKeyFromDID(actorDID)
	if err != nil {
		return actor.Handle{}, err
	}

	actorID, err := crypto.IDFromPublicKey(pubk)
	if err != nil {
		return actor.Handle{}, err
	}

	host, err := peer.IDFromPublicKey(pubk)
	if err != nil {
		return actor.Handle{}, err
	}

	var h actor.Handle
	h.ID = actorID
	h.DID = actorDID
	h.Address = actor.Address{
		HostID:       host.String(),
		InboxAddress: "root",
	}
	return h, nil
}
package cmd
import (
"os"
"github.com/spf13/cobra"
)
// newAutoCompleteCmd builds the hidden `autocomplete [shell]` command. It
// writes a completion script for the root command to standard output. Exactly
// one argument, "bash" or "zsh", is accepted. Errors from generating or
// writing the script are returned to cobra instead of being discarded.
func newAutoCompleteCmd() *cobra.Command {
	return &cobra.Command{
		Use:   "autocomplete [shell]",
		Short: "Generate autocomplete script for your shell",
		Long: `Generate an autocomplete script for the nunet CLI.
This command supports Bash and Zsh shells.`,
		DisableFlagsInUseLine: true,
		Hidden:                true,
		ValidArgs:             []string{"bash", "zsh"},
		Args:                  cobra.MatchAll(cobra.ExactArgs(1), cobra.OnlyValidArgs),
		RunE: func(cmd *cobra.Command, args []string) error {
			root := cmd.Root()
			switch args[0] {
			case "bash":
				return root.GenBashCompletion(os.Stdout)
			case "zsh":
				return root.GenZshCompletion(os.Stdout)
			}
			// Unreachable: Args restricts the argument to ValidArgs.
			return nil
		},
	}
}
package backend
import gonet "github.com/shirou/gopsutil/net"
// Network exposes host network information backed by gopsutil.
type Network struct{}

// GetConnections returns the host's network connections of the given kind.
// kind is passed through unchanged to gopsutil's net.Connections (see its
// documentation for the accepted values); its result and error are returned as-is.
func (n *Network) GetConnections(kind string) ([]gonet.ConnectionStat, error) {
	conns, err := gonet.Connections(kind)
	return conns, err
}
package cap
import (
"encoding/json"
"fmt"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/lib/did"
"gitlab.com/nunet/device-management-service/lib/ucan"
)
// newAnchorCmd returns the "cap anchor" command, which adds exactly one root,
// require or provide anchor to a persisted capability context.
//
// The context is opened through a Ledger wallet when IsLedgerContext reports
// true for --context, otherwise through the local keystore. --root takes a DID;
// --require and --provide take a token list as JSON. The anchor is added with
// AddRoots and the context is then written back with SaveCapabilityContext.
func newAnchorCmd(afs afero.Afero) *cobra.Command {
	var context, root, provide, require string

	cmd := &cobra.Command{
		Use:   "anchor",
		Short: "Manage capability anchors",
		Long: `Add or modify capability anchors in a capability context.
An anchor is a basis of trust in the capability system. There are three types of anchors:
1. Root anchor: Represents absolute trust or effective root capability.
Use the --root flag with a DID value to add a root anchor.
2. Require anchor: Represents input trust. We verify incoming messages based on the require anchor.
Use the --require flag with a token to add a require anchor.
3. Provide anchor: Represents output trust. We emit invocation tokens based on our provide anchors and sign output.
Use the --provide flag with a token to add a provide anchor.
Only one type of anchor can be added or modified per command execution.
Usage examples:
nunet cap anchor --context user --root did:example:123456789abcdefghi
nunet cap anchor --context dms --require '{"some": "json", "token": "here"}'
nunet cap anchor --context user --provide '{"another": "json", "token": "example"}'
Note: The --context flag is required to specify the capability context.`,
		RunE: func(_ *cobra.Command, _ []string) error {
			var (
				trustCtx did.TrustContext
				err      error
			)
			if IsLedgerContext(context) {
				provider, perr := did.NewLedgerWalletProvider(0)
				if perr != nil {
					return perr
				}
				trustCtx = did.NewTrustContextWithProvider(provider)
				context = LedgerContext(context)
			} else {
				trustCtx, _, err = CreateTrustContextFromKeyStore(afs, context)
				if err != nil {
					return fmt.Errorf("failed to create trust context: %w", err)
				}
			}

			capCtx, err := LoadCapabilityContext(trustCtx, context)
			if err != nil {
				return fmt.Errorf("failed to load capability context: %w", err)
			}

			// Only the argument matching the selected anchor kind is filled in;
			// the other two stay at their zero value (nil / empty token list).
			var (
				kind     string
				roots    []did.DID
				requires ucan.TokenList
				provides ucan.TokenList
			)
			switch {
			case root != "":
				rootDID, err := did.FromString(root)
				if err != nil {
					return fmt.Errorf("invalid root DID: %w", err)
				}
				kind, roots = "root", []did.DID{rootDID}
			case require != "":
				if err := json.Unmarshal([]byte(require), &requires); err != nil {
					return fmt.Errorf("unmarshal tokens: %w", err)
				}
				kind = "require"
			case provide != "":
				if err := json.Unmarshal([]byte(provide), &provides); err != nil {
					return fmt.Errorf("unmarshal tokens: %w", err)
				}
				kind = "provide"
			default:
				return fmt.Errorf("one of --provide, --root, or --require must be specified")
			}
			if err := capCtx.AddRoots(roots, requires, provides); err != nil {
				return fmt.Errorf("failed to add %s anchors: %w", kind, err)
			}

			if err := SaveCapabilityContext(capCtx, context); err != nil {
				return fmt.Errorf("save capability context: %w", err)
			}
			return nil
		},
	}

	useFlagContext(cmd, &context)
	useFlagRoot(cmd, &root)
	useFlagRequire(cmd, &require)
	useFlagProvide(cmd, &provide)
	_ = cmd.MarkFlagRequired(fnContext)
	cmd.MarkFlagsOneRequired(fnProvide, fnRoot, fnRequire)
	cmd.MarkFlagsMutuallyExclusive(fnProvide, fnRoot, fnRequire)
	return cmd
}
package cap
import (
"github.com/spf13/afero"
"github.com/spf13/cobra"
)
// Flag names shared by the cap subcommands.
const (
	fnContext    = "context"
	fnAudience   = "audience"
	fnAction     = "action"
	fnCap        = "cap"
	fnTopic      = "topic"
	fnExpiry     = "expiry"
	fnDuration   = "duration"
	fnAutoExpire = "auto-expire"
	fnSelfSign   = "self-sign"
	fnDepth      = "depth"
	fnProvide    = "provide"
	fnRoot       = "root"
	fnRequire    = "require"
)

// NewCapCmd returns the "cap" parent command, which groups the grant, anchor,
// new, delegate, list and remove subcommands. afs is handed to every
// subcommand for keystore and file access.
func NewCapCmd(afs afero.Afero) *cobra.Command {
	capCmd := &cobra.Command{
		Use:   "cap",
		Short: "Manage capabilities",
		Long:  `Manage capabilities for the Device Management Service`,
	}
	capCmd.AddCommand(
		newGrantCmd(afs),
		newAnchorCmd(afs),
		newNewCmd(afs),
		newDelegateCmd(afs),
		newListCmd(afs),
		newRemoveCmd(afs),
	)
	return capCmd
}
package cap
import (
"encoding/json"
"fmt"
"time"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/lib/did"
"gitlab.com/nunet/device-management-service/lib/ucan"
)
// newDelegateCmd returns the "cap delegate <did>" command.
//
// It opens the capability context named by --context (through a Ledger wallet
// when IsLedgerContext reports true, otherwise through the local keystore),
// calls Delegate for the subject DID with the --cap capabilities, --topic
// topics, optional --audience, --depth and --self-sign mode, and prints the
// resulting tokens as JSON on stdout.
//
// Exactly one of --expiry, --duration or --auto-expire must be set; with
// --auto-expire the expiration passed to Delegate is 0. --self-sign accepts
// "no", "also" or "only".
func newDelegateCmd(afs afero.Afero) *cobra.Command {
	var (
		context    string
		caps       []string
		topics     []string
		audience   string
		expiry     time.Time
		duration   time.Duration
		autoExpire bool
		depth      uint64
		selfSign   string
	)
	cmd := &cobra.Command{
		Use:   "delegate <did>",
		Short: "Delegate capabilities",
		Long: `Delegate capabilities to a subject
Capabilities are delegated based on provide anchors. No capabilities are delegated by default, you need to use --cap flag to explicitly specify the capabilities to delegate.
Example:
nunet cap anchor --context user --provide '<token>'
nunet cap delegate --context user --cap /public --duration 1h did:key:<some-key>`,
		Args: cobra.ExactArgs(1),
		RunE: func(_ *cobra.Command, args []string) error {
			subject := args[0]

			var expirationTime uint64
			switch {
			case !expiry.IsZero():
				expirationTime = uint64(expiry.UnixNano())
			case duration != 0:
				expirationTime = uint64(time.Now().Add(duration).UnixNano())
			case autoExpire:
				expirationTime = 0
			default:
				// Reachable when one of the flags is passed with a zero value
				// (e.g. --duration 0 or --auto-expire=false): that still counts as
				// "set" for MarkFlagsOneRequired, so name all three options here.
				return fmt.Errorf("one of --expiry, --duration, or --auto-expire must be specified")
			}

			subjectDID, err := did.FromString(subject)
			if err != nil {
				return fmt.Errorf("invalid subject DID: %w", err)
			}
			var audienceDID did.DID
			if audience != "" {
				audienceDID, err = did.FromString(audience)
				if err != nil {
					return fmt.Errorf("invalid audience DID: %w", err)
				}
			}

			capabilities := make([]ucan.Capability, len(caps))
			for i, c := range caps {
				capabilities[i] = ucan.Capability(c)
			}

			var selfSignMode ucan.SelfSignMode
			switch selfSign {
			case "no":
				selfSignMode = ucan.SelfSignNo
			case "also":
				selfSignMode = ucan.SelfSignAlso
			case "only":
				selfSignMode = ucan.SelfSignOnly
			default:
				return fmt.Errorf("invalid self-sign option: %s", selfSign)
			}

			var trustCtx did.TrustContext
			if IsLedgerContext(context) {
				provider, err := did.NewLedgerWalletProvider(0)
				if err != nil {
					return err
				}
				trustCtx = did.NewTrustContextWithProvider(provider)
				context = LedgerContext(context)
			} else {
				trustCtx, _, err = CreateTrustContextFromKeyStore(afs, context)
				if err != nil {
					return fmt.Errorf("failed to create trust context: %w", err)
				}
			}

			capCtx, err := LoadCapabilityContext(trustCtx, context)
			if err != nil {
				return fmt.Errorf("failed to load capability context: %w", err)
			}

			tokens, err := capCtx.Delegate(subjectDID, audienceDID, topics, expirationTime, depth, capabilities, selfSignMode)
			if err != nil {
				return fmt.Errorf("failed to delegate capabilities: %w", err)
			}
			tokensJSON, err := json.Marshal(tokens)
			if err != nil {
				return fmt.Errorf("unable to marshal tokens to json: %w", err)
			}
			fmt.Println(string(tokensJSON))
			return nil
		},
	}

	useFlagContext(cmd, &context)
	useFlagAudience(cmd, &audience)
	useFlagCap(cmd, &caps)
	useFlagTopic(cmd, &topics)
	useFlagExpiry(cmd, &expiry)
	useFlagDuration(cmd, &duration)
	useFlagAutoExpire(cmd, &autoExpire)
	useFlagDepth(cmd, &depth)
	cmd.Flags().StringVar(&selfSign, fnSelfSign, "no", "Self-sign option: 'no', 'also', or 'only'")
	_ = cmd.MarkFlagRequired(fnContext)
	cmd.MarkFlagsOneRequired(fnExpiry, fnDuration, fnAutoExpire)
	cmd.MarkFlagsMutuallyExclusive(fnExpiry, fnDuration, fnAutoExpire)
	cmd.MarkFlagsMutuallyExclusive(fnSelfSign, fnAutoExpire)
	return cmd
}
package cap
import (
"time"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/cmd/utils"
"gitlab.com/nunet/device-management-service/dms"
)
// useFlagContext registers --context/-c, the capability context a command
// operates on (default dms.UserContextName).
func useFlagContext(cmd *cobra.Command, context *string) {
	flags := cmd.Flags()
	flags.StringVarP(context, fnContext, "c", dms.UserContextName, "specifies capability context")
}

// useFlagAudience registers the optional --audience/-a DID.
func useFlagAudience(cmd *cobra.Command, audience *string) {
	flags := cmd.Flags()
	flags.StringVarP(audience, fnAudience, "a", "", "audience DID (optional)")
}

// useFlagCap registers the repeatable --cap flag (default: empty list).
func useFlagCap(cmd *cobra.Command, caps *[]string) {
	flags := cmd.Flags()
	flags.StringSliceVar(caps, fnCap, []string{}, "capabilities to grant/delegate (can be specified multiple times)")
}

// useFlagTopic registers the repeatable --topic/-t flag (default: empty list).
func useFlagTopic(cmd *cobra.Command, topics *[]string) {
	flags := cmd.Flags()
	flags.StringSliceVarP(topics, fnTopic, "t", []string{}, "topics to grant/delegate (can be specified multiple times)")
}

// useFlagExpiry registers --expiry/-e, parsed by utils.NewTimeValue.
func useFlagExpiry(cmd *cobra.Command, expiry *time.Time) {
	flags := cmd.Flags()
	flags.VarP(utils.NewTimeValue(expiry), fnExpiry, "e", "set expiration date (ISO 8601 format)")
}

// useFlagDuration registers --duration (default 0, i.e. unset).
func useFlagDuration(cmd *cobra.Command, duration *time.Duration) {
	flags := cmd.Flags()
	flags.DurationVar(duration, fnDuration, 0, "set duration time (specify unit)")
}

// useFlagAutoExpire registers the boolean --auto-expire flag.
func useFlagAutoExpire(cmd *cobra.Command, autoExpire *bool) {
	flags := cmd.Flags()
	flags.BoolVar(autoExpire, fnAutoExpire, false, "set auto expiration")
}

// useFlagDepth registers --depth/-d, the delegation depth (default 0).
func useFlagDepth(cmd *cobra.Command, depth *uint64) {
	flags := cmd.Flags()
	flags.Uint64VarP(depth, fnDepth, "d", 0, "delegation depth (optional, default=0)")
}

// useFlagRoot registers --root, a DID to use as a root anchor.
func useFlagRoot(cmd *cobra.Command, root *string) {
	flags := cmd.Flags()
	flags.StringVar(root, fnRoot, "", "DID to add as root anchor (represents absolute trust)")
}

// useFlagRequire registers --require, the token input for a require anchor.
func useFlagRequire(cmd *cobra.Command, require *string) {
	flags := cmd.Flags()
	flags.StringVar(require, fnRequire, "", "JWT to add as require anchor (for input trust)")
}

// useFlagProvide registers --provide, the token input for a provide anchor.
func useFlagProvide(cmd *cobra.Command, provide *string) {
	flags := cmd.Flags()
	flags.StringVar(provide, fnProvide, "", "JWT to add as provide anchor (for output trust)")
}
package cap
import (
"encoding/json"
"fmt"
"time"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/lib/did"
"gitlab.com/nunet/device-management-service/lib/ucan"
)
// newGrantCmd returns the "cap grant <did>" command.
//
// It opens the capability context named by --context (through a Ledger wallet
// when IsLedgerContext reports true, otherwise through the local keystore),
// calls Grant with ucan.Delegate for the subject DID with the --cap
// capabilities, --topic topics, optional --audience and --depth, and prints
// the resulting tokens as JSON on stdout. Exactly one of --expiry or
// --duration must be set.
func newGrantCmd(afs afero.Afero) *cobra.Command {
	var (
		context  string
		caps     []string
		topics   []string
		audience string
		expiry   time.Time
		duration time.Duration
		depth    uint64
	)
	cmd := &cobra.Command{
		Use:   "grant <did>",
		Short: "Grant capabilities",
		Long: `Grant a self-signed token delegating capabilities
It is not necessary to set up an anchor before granting a capability because this operation is self-signed.
Example:
nunet cap grant --context user --cap /public --duration 1h did:key:<some-key>
The above command emits a self-signed token with the specified capabilities delegated from 'user' to the subject's DID.`,
		Args: cobra.ExactArgs(1),
		RunE: func(_ *cobra.Command, args []string) error {
			subject := args[0]

			var expirationTime uint64
			switch {
			case !expiry.IsZero():
				expirationTime = uint64(expiry.UnixNano())
			case duration != 0:
				expirationTime = uint64(time.Now().Add(duration).UnixNano())
			default:
				return fmt.Errorf("either expiration or duration must be specified")
			}

			subjectDID, err := did.FromString(subject)
			if err != nil {
				return fmt.Errorf("invalid subject DID: %w", err)
			}
			var audienceDID did.DID
			if audience != "" {
				audienceDID, err = did.FromString(audience)
				if err != nil {
					return fmt.Errorf("invalid audience DID: %w", err)
				}
			}

			// Named c rather than cap so the builtin cap stays accessible.
			capabilities := make([]ucan.Capability, len(caps))
			for i, c := range caps {
				capabilities[i] = ucan.Capability(c)
			}

			var trustCtx did.TrustContext
			if IsLedgerContext(context) {
				provider, err := did.NewLedgerWalletProvider(0)
				if err != nil {
					return err
				}
				trustCtx = did.NewTrustContextWithProvider(provider)
				context = LedgerContext(context)
			} else {
				trustCtx, _, err = CreateTrustContextFromKeyStore(afs, context)
				if err != nil {
					return fmt.Errorf("failed to create trust context: %w", err)
				}
			}

			capCtx, err := LoadCapabilityContext(trustCtx, context)
			if err != nil {
				return fmt.Errorf("failed to load capability context: %w", err)
			}

			tokens, err := capCtx.Grant(ucan.Delegate, subjectDID, audienceDID, topics, expirationTime, depth, capabilities)
			if err != nil {
				return fmt.Errorf("failed to grant capabilities: %w", err)
			}
			tokensJSON, err := json.Marshal(tokens)
			if err != nil {
				return fmt.Errorf("unable to marshal tokens to json: %w", err)
			}
			fmt.Println(string(tokensJSON))
			return nil
		},
	}

	useFlagContext(cmd, &context)
	useFlagAudience(cmd, &audience)
	useFlagCap(cmd, &caps)
	useFlagTopic(cmd, &topics)
	useFlagExpiry(cmd, &expiry)
	useFlagDuration(cmd, &duration)
	useFlagDepth(cmd, &depth)
	_ = cmd.MarkFlagRequired(fnContext)
	cmd.MarkFlagsOneRequired(fnExpiry, fnDuration)
	return cmd
}
package cap
import (
"encoding/json"
"fmt"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/lib/did"
)
// newListCmd returns the "cap list" command, which prints the root DIDs and the
// require and provide anchor tokens of a capability context. Each token is
// printed on its own tab-indented line as JSON.
func newListCmd(afs afero.Afero) *cobra.Command {
	var context string

	cmd := &cobra.Command{
		Use:   "list",
		Short: "List capability anchors",
		Long: `List all capability anchors in a capability context
It outputs DIDs and capability tokens set for root, provide and require anchors.`,
		RunE: func(_ *cobra.Command, _ []string) error {
			var trustCtx did.TrustContext
			if IsLedgerContext(context) {
				provider, err := did.NewLedgerWalletProvider(0)
				if err != nil {
					return err
				}
				trustCtx = did.NewTrustContextWithProvider(provider)
				context = LedgerContext(context)
			} else {
				keystoreCtx, _, err := CreateTrustContextFromKeyStore(afs, context)
				if err != nil {
					return fmt.Errorf("failed to create trust context: %w", err)
				}
				trustCtx = keystoreCtx
			}

			capCtx, err := LoadCapabilityContext(trustCtx, context)
			if err != nil {
				return fmt.Errorf("failed to load capability context: %w", err)
			}

			// printJSON writes one tab-indented JSON line for a token.
			printJSON := func(v any) error {
				data, err := json.Marshal(v)
				if err != nil {
					return fmt.Errorf("failed to marshal capability token: %w", err)
				}
				fmt.Printf("\t%s\n", string(data))
				return nil
			}

			roots, require, provide := capCtx.ListRoots()

			fmt.Println("roots:")
			for _, r := range roots {
				fmt.Printf("\t%s\n", r)
			}

			fmt.Println("require:")
			for _, tok := range require.Tokens {
				if err := printJSON(tok); err != nil {
					return err
				}
			}

			fmt.Println("provide:")
			for _, tok := range provide.Tokens {
				if err := printJSON(tok); err != nil {
					return err
				}
			}
			return nil
		},
	}

	useFlagContext(cmd, &context)
	return cmd
}
package cap
import (
"fmt"
"path/filepath"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/cmd/utils"
"gitlab.com/nunet/device-management-service/dms"
"gitlab.com/nunet/device-management-service/internal/config"
"gitlab.com/nunet/device-management-service/lib/crypto"
"gitlab.com/nunet/device-management-service/lib/did"
"gitlab.com/nunet/device-management-service/lib/ucan"
)
// newNewCmd returns the "cap new <name>" command, which creates a persistent
// capability context called <name>.
//
// For a Ledger context (IsLedgerContext) the root DID is the wallet's DID;
// otherwise it is derived from the key stored under <name> in the local
// keystore. If <name>.cap already exists the user must confirm before it is
// replaced; otherwise the capability store directory is created with mode 0700.
// The new context contains only the root DID and empty require/provide anchors.
func newNewCmd(afs afero.Afero) *cobra.Command {
	return &cobra.Command{
		Use:   "new <name>",
		Short: "Create a new capability context",
		Long: `Create a new persistent capability context
Example:
nunet cap new user
nunet cap new ledger:user # if using ledger`,
		Args: cobra.ExactArgs(1),
		RunE: func(cmd *cobra.Command, args []string) error {
			context := dms.UserContextName
			if len(args) > 0 {
				context = args[0]
			}

			var (
				trustCtx did.TrustContext
				rootDID  did.DID
			)
			if IsLedgerContext(context) {
				provider, err := did.NewLedgerWalletProvider(0)
				if err != nil {
					return err
				}
				trustCtx = did.NewTrustContextWithProvider(provider)
				rootDID = provider.DID()
				context = LedgerContext(context)
			} else {
				var (
					priv crypto.PrivKey
					err  error
				)
				if trustCtx, priv, err = CreateTrustContextFromKeyStore(afs, context); err != nil {
					return fmt.Errorf("failed to create trust context: %w", err)
				}
				rootDID = did.FromPublicKey(priv.GetPublic())
			}

			capStoreDir := filepath.Join(config.GetConfig().General.UserDir, dms.CapstoreDir)
			capStoreFile := filepath.Join(capStoreDir, fmt.Sprintf("%s.cap", context))

			exists, err := afs.Exists(capStoreFile)
			if err != nil {
				return fmt.Errorf("unable to check if capability context file exists: %w", err)
			}
			if !exists {
				if err := afs.MkdirAll(capStoreDir, 0o700); err != nil {
					return fmt.Errorf("unable to create capability store directory: %w", err)
				}
			} else {
				question := fmt.Sprintf(
					"WARNING: A capability context file already exists at %s. Creating a new one will overwrite the existing context. Do you want to proceed?",
					capStoreFile,
				)
				confirmed, err := utils.PromptYesNo(cmd.InOrStdin(), cmd.OutOrStdout(), question)
				if err != nil {
					return fmt.Errorf("failed to get user confirmation: %w", err)
				}
				if !confirmed {
					return fmt.Errorf("operation cancelled by user")
				}
			}

			capCtx, err := ucan.NewCapabilityContext(trustCtx, rootDID, nil, ucan.TokenList{}, ucan.TokenList{})
			if err != nil {
				return fmt.Errorf("unable to create capability context: %w", err)
			}
			if err := SaveCapabilityContext(capCtx, context); err != nil {
				return fmt.Errorf("save capability context: %w", err)
			}
			return nil
		},
	}
}
package cap
import (
"encoding/json"
"fmt"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/lib/did"
"gitlab.com/nunet/device-management-service/lib/ucan"
)
// newRemoveCmd returns the "cap remove" command, which removes exactly one root
// DID, require token or provide token from a persisted capability context and
// then writes the context back with SaveCapabilityContext.
//
// --root takes a DID; --require and --provide each take a single token as JSON.
func newRemoveCmd(afs afero.Afero) *cobra.Command {
	var context, root, provide, require string

	cmd := &cobra.Command{
		Use:   "remove",
		Short: "Remove capability anchors",
		Long: `Remove capability anchors in a capability context
One capability anchor must be specified at a time.
Example:
nunet cap remove --context user --root did:key:abcd1234
nunet cap remove --context user --require '<the-token>'`,
		RunE: func(_ *cobra.Command, _ []string) error {
			var trustCtx did.TrustContext
			if IsLedgerContext(context) {
				provider, err := did.NewLedgerWalletProvider(0)
				if err != nil {
					return err
				}
				trustCtx = did.NewTrustContextWithProvider(provider)
				context = LedgerContext(context)
			} else {
				keystoreCtx, _, err := CreateTrustContextFromKeyStore(afs, context)
				if err != nil {
					return fmt.Errorf("failed to create trust context: %w", err)
				}
				trustCtx = keystoreCtx
			}

			capCtx, err := LoadCapabilityContext(trustCtx, context)
			if err != nil {
				return fmt.Errorf("failed to load capability context: %w", err)
			}

			// singleToken decodes one JSON token and wraps it in a one-element list.
			singleToken := func(raw string) (ucan.TokenList, error) {
				var token ucan.Token
				if err := json.Unmarshal([]byte(raw), &token); err != nil {
					return ucan.TokenList{}, fmt.Errorf("unmarshal tokens: %w", err)
				}
				return ucan.TokenList{Tokens: []*ucan.Token{&token}}, nil
			}

			switch {
			case root != "":
				rootDID, err := did.FromString(root)
				if err != nil {
					return fmt.Errorf("invalid root DID: %w", err)
				}
				capCtx.RemoveRoots([]did.DID{rootDID}, ucan.TokenList{}, ucan.TokenList{})
			case require != "":
				tokens, err := singleToken(require)
				if err != nil {
					return err
				}
				capCtx.RemoveRoots(nil, tokens, ucan.TokenList{})
			case provide != "":
				tokens, err := singleToken(provide)
				if err != nil {
					return err
				}
				capCtx.RemoveRoots(nil, ucan.TokenList{}, tokens)
			default:
				return fmt.Errorf("one of --provide, --root, or --require must be specified")
			}

			if err := SaveCapabilityContext(capCtx, context); err != nil {
				return fmt.Errorf("save capability context: %w", err)
			}
			return nil
		},
	}

	useFlagContext(cmd, &context)
	useFlagRoot(cmd, &root)
	useFlagRequire(cmd, &require)
	useFlagProvide(cmd, &provide)
	_ = cmd.MarkFlagRequired(fnContext)
	cmd.MarkFlagsOneRequired(fnProvide, fnRoot, fnRequire)
	cmd.MarkFlagsMutuallyExclusive(fnProvide, fnRoot, fnRequire)
	return cmd
}
package cap
import (
"fmt"
"os"
"path/filepath"
"strings"
"time"
"github.com/spf13/afero"
"gitlab.com/nunet/device-management-service/cmd/utils"
"gitlab.com/nunet/device-management-service/dms"
"gitlab.com/nunet/device-management-service/internal/config"
"gitlab.com/nunet/device-management-service/lib/crypto"
"gitlab.com/nunet/device-management-service/lib/crypto/keystore"
"gitlab.com/nunet/device-management-service/lib/did"
"gitlab.com/nunet/device-management-service/lib/ucan"
)
// ledger is the context name (and the "ledger:" prefix) that selects a
// Ledger-wallet-backed trust context instead of the local keystore.
const ledger = "ledger"

// IsLedgerContext reports whether context refers to a Ledger wallet: either the
// bare name "ledger" or the "ledger:<name>" form. A keystore context whose name
// merely starts with the letters "ledger" (e.g. "ledgerbackup") is NOT treated
// as a Ledger context.
func IsLedgerContext(context string) bool {
	return context == ledger || strings.HasPrefix(context, ledger+":")
}

// LedgerContext returns the capability-context name to use for a Ledger
// context: "<name>" for "ledger:<name>", and "ledger" for everything else,
// including the bare "ledger", an empty name ("ledger:") and names containing
// further colons.
func LedgerContext(context string) string {
	parts := strings.Split(context, ":")
	// An empty name would produce a ".cap" file with no basename.
	if len(parts) == 2 && parts[1] != "" {
		return parts[1]
	}
	return ledger
}
// CreateTrustContextFromKeyStore opens the keystore under
// <UserDir>/<dms.KeystoreDir>, decrypts the key stored as contextKey, and
// returns a trust context built from it together with the private key.
//
// The passphrase is read from DMS_PASSPHRASE; if that variable is empty or
// unset the user is prompted for it.
func CreateTrustContextFromKeyStore(afs afero.Afero, contextKey string) (did.TrustContext, crypto.PrivKey, error) {
	ks, err := keystore.New(afs.Fs, filepath.Join(config.GetConfig().General.UserDir, dms.KeystoreDir))
	if err != nil {
		return nil, nil, fmt.Errorf("failed to open keystore: %w", err)
	}

	passphrase := os.Getenv("DMS_PASSPHRASE")
	if passphrase == "" {
		if passphrase, err = utils.PromptForPassphrase(false); err != nil {
			return nil, nil, fmt.Errorf("failed to get passphrase: %w", err)
		}
	}

	stored, err := ks.Get(contextKey, passphrase)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to get key from keystore: %w", err)
	}
	priv, err := stored.PrivKey()
	if err != nil {
		return nil, nil, fmt.Errorf("unable to convert key from keystore to private key: %w", err)
	}

	trustCtx, err := did.NewTrustContextWithPrivateKey(priv)
	if err != nil {
		return nil, nil, fmt.Errorf("unable to create trust context: %w", err)
	}
	return trustCtx, priv, nil
}
// LoadCapabilityContext reads <UserDir>/<dms.CapstoreDir>/<name>.cap from the
// OS filesystem and decodes it into a capability context bound to trustCtx.
func LoadCapabilityContext(trustCtx did.TrustContext, name string) (ucan.CapabilityContext, error) {
	path := filepath.Join(config.GetConfig().General.UserDir, dms.CapstoreDir, fmt.Sprintf("%s.cap", name))

	file, err := os.Open(path)
	if err != nil {
		return nil, fmt.Errorf("unable to open capability context file: %w", err)
	}
	defer file.Close()

	capCtx, err := ucan.LoadCapabilityContext(trustCtx, file)
	if err != nil {
		return nil, fmt.Errorf("unable to load capability context: %w", err)
	}
	return capCtx, nil
}
// SaveCapabilityContext persists capCtx as <UserDir>/<dms.CapstoreDir>/<name>.cap.
//
// The context is first serialized into a temporary file in the same directory
// (created by os.CreateTemp, i.e. mode 0600). Only after the write and Close
// have succeeded is any existing <name>.cap moved aside to a timestamped
// backup (<name>.cap.<unix-seconds>) and the temporary file renamed into
// place. A failed serialization therefore leaves the current context intact,
// whereas the previous implementation had already moved it away and could
// leave a truncated file behind.
func SaveCapabilityContext(capCtx ucan.CapabilityContext, name string) error {
	capStoreDir := filepath.Join(config.GetConfig().General.UserDir, dms.CapstoreDir)
	capCtxFile := filepath.Join(capStoreDir, fmt.Sprintf("%s.cap", name))
	capCtxBackup := filepath.Join(capStoreDir, fmt.Sprintf("%s.cap.%d", name, time.Now().Unix()))

	// Write next to the target so the final rename stays on one filesystem.
	tmp, err := os.CreateTemp(capStoreDir, ".cap-*.tmp")
	if err != nil {
		return fmt.Errorf("error creating new capability context file: %w", err)
	}
	tmpName := tmp.Name()
	committed := false
	defer func() {
		if !committed {
			_ = os.Remove(tmpName) // best-effort cleanup of the unused temp file
		}
	}()

	if err := ucan.SaveCapabilityContext(capCtx, tmp); err != nil {
		_ = tmp.Close()
		return fmt.Errorf("error saving capability context: %w", err)
	}
	// A Close error can mean buffered data never reached the file.
	if err := tmp.Close(); err != nil {
		return fmt.Errorf("error saving capability context: %w", err)
	}

	// Back up the current context (if any) by moving it aside.
	backedUp := false
	if _, err := os.Stat(capCtxFile); err == nil {
		if err := os.Rename(capCtxFile, capCtxBackup); err != nil {
			return fmt.Errorf("error backing up current capability context: %w", err)
		}
		backedUp = true
	}

	if err := os.Rename(tmpName, capCtxFile); err != nil {
		if backedUp {
			// Put the previous context back so <name>.cap still exists.
			_ = os.Rename(capCtxBackup, capCtxFile)
		}
		return fmt.Errorf("error saving capability context: %w", err)
	}
	committed = true
	return nil
}
package cmd
import (
"encoding/json"
"fmt"
"os"
"os/exec"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/internal/config"
)
// newConfigCmd returns the "config" parent command with the get, set and edit
// subcommands. A nil fs is a programming error and aborts via cobra.CheckErr.
func newConfigCmd(fs afero.Fs) *cobra.Command {
	if fs == nil {
		cobra.CheckErr("Fs is nil")
	}

	configCmd := &cobra.Command{
		Use:   "config",
		Short: "Manage configuration file",
		Long: `Utility to manage user's configuration file via command-line
Search for the configuration file is done in the following locations and order:
1. "." (current directory)
2. "$HOME/.nunet"
3. "/etc/nunet"`,
	}
	configCmd.AddCommand(
		newConfigGetCmd(),
		newConfigSetCmd(fs),
		newConfigEditCmd(),
	)
	return configCmd
}
// newConfigGetCmd returns the "config get [key]" command.
//
// Without an argument it prints the whole effective configuration as indented
// JSON; with a key (e.g. rest.port) it prints only that key's value, also as
// indented JSON. Output goes to the command's stdout writer so it can be piped.
// cobra's cmd.Println would fall back to stderr when no output writer is set.
func newConfigGetCmd() *cobra.Command {
	cmd := &cobra.Command{
		Use:   "get [key]",
		Short: "Display configuration",
		Long: `Display the value for a configuration key
It reads the value from the configuration file, otherwise it returns the default value.
If no key is given, the whole configuration is displayed.
Example:
nunet config get rest.port`,
		Args: cobra.MaximumNArgs(1),
		RunE: func(cmd *cobra.Command, args []string) error {
			out := cmd.OutOrStdout()

			if len(args) == 0 {
				info, err := json.MarshalIndent(config.GetConfig(), "", " ")
				if err != nil {
					return fmt.Errorf("failed to indent config JSON: %w", err)
				}
				_, err = fmt.Fprintln(out, string(info))
				return err
			}

			value, err := config.Get(args[0])
			if err != nil {
				return fmt.Errorf("could not get key's value: %w", err)
			}
			pretty, err := json.MarshalIndent(value, "", " ")
			if err != nil {
				return fmt.Errorf("failed to indent JSON: %w", err)
			}
			_, err = fmt.Fprintln(out, string(pretty))
			return err
		},
	}
	return cmd
}
// newConfigSetCmd returns the "config set <key> <value>" command, which writes
// a single key into the configuration file on fs, creating the file if it does
// not exist yet. Progress messages are emitted with cmd.Println.
func newConfigSetCmd(fs afero.Fs) *cobra.Command {
	cmd := &cobra.Command{
		Use:   "set <key> <value>",
		Short: "Update configuration",
		Long: `Set value for a configuration key
It creates a configuration file if it does not exist, otherwise it updates the existing file
Example:
nunet config set rest.port 4444`,
		Args: cobra.ExactArgs(2),
		RunE: func(cmd *cobra.Command, args []string) error {
			exists, err := config.FileExists(fs)
			if err != nil {
				return fmt.Errorf("failed to check if config file exists: %w", err)
			}
			if exists {
				cmd.Println("Updating existing config file...")
			} else {
				cmd.Println("Config file did not exist. Creating new file...")
			}

			if err := config.Set(fs, args[0], args[1]); err != nil {
				return fmt.Errorf("failed to set config: %w", err)
			}
			cmd.Println("Applied changes.")
			return nil
		},
	}
	return cmd
}
// newConfigEditCmd returns the "config edit" command, which opens the
// configuration file (config.GetPath) in the program named by $EDITOR and
// fails if $EDITOR is empty or unset. The editor inherits the command's
// stdin, stdout and stderr writers.
func newConfigEditCmd() *cobra.Command {
	cmd := &cobra.Command{
		Use:   "edit",
		Short: "Edit configuration",
		Long: `Open configuration file with default text editor
This command searches for the configuration file and opens it with the default text editor
It reads the $EDITOR environment variable and it fails if it's not set`,
		Args: cobra.NoArgs,
		RunE: func(cmd *cobra.Command, _ []string) error {
			editor := os.Getenv("EDITOR")
			if editor == "" {
				return fmt.Errorf("$EDITOR not set")
			}
			path := config.GetPath()
			cmd.Printf("Text editor: %s\n", editor)
			cmd.Printf("Config path: %s\n", path)

			// $EDITOR is executed directly, not through a shell, so a value that
			// contains arguments (e.g. "code --wait") is looked up as a single
			// program name and will not be found. The value comes from the
			// invoking user's own environment, so no privilege boundary is crossed.
			proc := exec.Command(editor, path)
			proc.Stdin = cmd.InOrStdin()
			proc.Stdout = cmd.OutOrStdout()
			// ErrOrStderr, not OutOrStderr: the latter returns the stdout writer
			// whenever one has been set with SetOut.
			proc.Stderr = cmd.ErrOrStderr()
			return proc.Run()
		},
	}
	return cmd
}
package cmd
import (
"context"
"fmt"
"os"
"github.com/docker/docker/api/types/container"
"gitlab.com/nunet/device-management-service/dms/hardware"
"gitlab.com/nunet/device-management-service/executor/docker"
"gitlab.com/nunet/device-management-service/dms/hardware/gpu"
"gitlab.com/nunet/device-management-service/types"
"github.com/spf13/cobra"
)
// newGPUCommand returns the "gpu" parent command grouping the list and test
// subcommands.
func newGPUCommand() *cobra.Command {
	cmd := &cobra.Command{
		Use:   "gpu <operation>",
		Short: "Manage GPU resources",
		Long: `Available operations:
- list: List all available GPUs
- test: Test GPU deployment by running a docker container with GPU resources
`,
	}
	cmd.AddCommand(newGPUListCommand(), newGPUTestCommand())
	return cmd
}
// newGPUListCommand returns the "gpu list" command, which prints one line per
// GPU with its model, total and used VRAM (in GB), vendor, PCI address and index.
//
// It fails if no GPUs are found, or if the usage list returned by
// gpu.GetGPUUsage does not have the same length as the GPU list.
func newGPUListCommand() *cobra.Command {
	return &cobra.Command{
		Use:   "list",
		Short: "List all available GPUs",
		RunE: func(_ *cobra.Command, _ []string) error {
			gpus, err := gpu.GetGPUs()
			if err != nil {
				return fmt.Errorf("error getting GPUs: %w", err)
			}
			// Check before querying usage: with no GPUs there is nothing to report,
			// and a usage-query failure would only obscure that.
			if len(gpus) == 0 {
				return fmt.Errorf("no gpus found")
			}

			usage, err := gpu.GetGPUUsage()
			if err != nil {
				return fmt.Errorf("error getting GPU usage: %w", err)
			}
			if len(gpus) != len(usage) {
				return fmt.Errorf("GPU and GPU usage counts do not match. This is a bug")
			}

			fmt.Println("GPU Details:")
			// NOTE(review): usage[i] is assumed to describe gpus[i] (index-aligned);
			// confirm that ordering is guaranteed by the gpu package.
			for i, g := range gpus {
				fmt.Printf("Model: %s, Total VRAM: %.2f GB, Used VRAM: %.2f GB, Vendor: %s, PCI Address: %s, Index: %d\n",
					g.Model, types.ConvertBytesToGB(g.VRAM), types.ConvertBytesToGB(usage[i].VRAM), g.Vendor, g.PCIAddress, g.Index)
			}
			return nil
		},
	}
}
// newGPUTestCommand returns the "gpu test" command.
//
// It selects the host GPU with the most free VRAM and checks the prerequisites:
// the NVIDIA container toolkit for NVIDIA GPUs, and a reachable Docker daemon
// for all vendors. It then runs an ubuntu:20.04 container named
// "nunet-gpu-test" that installs pciutils and lists VGA controllers with
// lspci. The GPU is exposed per vendor:
//   - NVIDIA: a device request
//   - AMD: /dev/kfd and /dev/dri
//   - Intel: /dev/dri
//
// The container's output is copied to stdout after it has finished.
func newGPUTestCommand() *cobra.Command {
	return &cobra.Command{
		Use:   "test",
		Short: "Test GPU deployment by running a Docker container with GPU resources",
		RunE: func(_ *cobra.Command, _ []string) error {
			ctx := context.Background()

			hardwareManager := hardware.NewHardwareManager()
			machineResources, err := hardwareManager.GetMachineResources()
			if err != nil {
				return fmt.Errorf("getting machine resources: %w", err)
			}
			if len(machineResources.GPUs) == 0 {
				return fmt.Errorf("no GPUs detected on the host")
			}
			maxFreeVRAMGpu, err := machineResources.GPUs.MaxFreeVRAMGPU()
			if err != nil {
				return fmt.Errorf("getting GPU with highest free VRAM: %w", err)
			}
			fmt.Printf("Selected Vendor: %s, Device: %+v\n", maxFreeVRAMGpu.Vendor, maxFreeVRAMGpu)

			if maxFreeVRAMGpu.Vendor == types.GPUVendorNvidia {
				// Check if NVIDIA container toolkit is installed
				// We specifically look for the nvidia-container-toolkit executable because:
				// 1. It's the name of the main package installed via apt (nvidia-container-toolkit)
				// 2. It's the most reliable indicator of a proper toolkit installation
				// 3. Checking for this single file reduces the risk of false positives
				if _, err := os.Stat("/usr/bin/nvidia-container-toolkit"); os.IsNotExist(err) {
					return fmt.Errorf("nvidia container toolkit is not installed. Please install it before running this command")
				}
			}

			imageName := "ubuntu:20.04"
			client, err := docker.NewDockerClient()
			if err != nil {
				return fmt.Errorf("creating Docker client: %w", err)
			}
			if !client.IsInstalled(ctx) {
				return fmt.Errorf("docker is not installed or running. Cannot run GPU deployment test")
			}

			fmt.Printf("Creating the docker container for the image: %s\n", imageName)
			containerConfig := &container.Config{
				Image:        imageName,
				User:         "root",
				Tty:          true,
				AttachStdout: true,
				AttachStderr: true,
				// Clear the image's entrypoint so Cmd below runs as given.
				Entrypoint: []string{""},
				Cmd: []string{
					// This will show both the integrated and discrete GPUs
					"sh", "-c",
					"apt-get update && apt-get install -y pciutils && lspci | grep 'VGA compatible controller'",
				},
			}

			var hostConfig *container.HostConfig
			switch maxFreeVRAMGpu.Vendor {
			case types.GPUVendorNvidia:
				hostConfig = &container.HostConfig{
					AutoRemove: true,
					Resources: container.Resources{
						DeviceRequests: []container.DeviceRequest{
							{
								Driver:       "nvidia",
								Count:        -1,
								Capabilities: [][]string{{"gpu"}},
							},
						},
					},
				}
			case types.GPUVendorAMDATI:
				hostConfig = &container.HostConfig{
					AutoRemove: true,
					Binds: []string{
						"/dev/kfd:/dev/kfd",
						"/dev/dri:/dev/dri",
					},
					Resources: container.Resources{
						Devices: []container.DeviceMapping{
							{
								PathOnHost:        "/dev/kfd",
								PathInContainer:   "/dev/kfd",
								CgroupPermissions: "rwm",
							},
							{
								PathOnHost:        "/dev/dri",
								PathInContainer:   "/dev/dri",
								CgroupPermissions: "rwm",
							},
						},
					},
					GroupAdd: []string{"video"},
				}
			case types.GPUVendorIntel:
				hostConfig = &container.HostConfig{
					AutoRemove: true,
					Binds: []string{
						"/dev/dri:/dev/dri",
					},
					Resources: container.Resources{
						Devices: []container.DeviceMapping{
							{
								PathOnHost:        "/dev/dri",
								PathInContainer:   "/dev/dri",
								CgroupPermissions: "rwm",
							},
						},
					},
				}
			default:
				return fmt.Errorf("unknown GPU vendor: %s", maxFreeVRAMGpu.Vendor)
			}

			containerID, err := client.CreateContainer(ctx,
				containerConfig,
				hostConfig,
				nil,
				nil,
				"nunet-gpu-test",
			)
			if err != nil {
				return fmt.Errorf("creating docker container from image %s: %w", imageName, err)
			}
			fmt.Println("Container created with ID: ", containerID)

			// Start by the ID returned above rather than by name, so the command acts
			// on exactly the container it just created.
			if err := client.StartContainer(ctx, containerID); err != nil {
				return fmt.Errorf("starting docker container: %w", err)
			}

			// NOTE(review): the container is created with AutoRemove, so the daemon
			// may delete it as soon as it exits; fetching logs after this wait can
			// race with that removal. Confirm GetOutputStream still succeeds after
			// exit, or stream the logs before waiting.
			statusCh, errCh := client.WaitContainer(ctx, containerID)
			select {
			case err := <-errCh:
				if err != nil {
					fmt.Printf("Container exited with error: %v\n", err)
				}
			case <-statusCh:
				fmt.Println("Container execution completed.")
			}

			reader, err := client.GetOutputStream(ctx, containerID, "", true)
			if err != nil {
				return fmt.Errorf("getting output stream: %w", err)
			}
			if _, err := os.Stdout.ReadFrom(reader); err != nil {
				return fmt.Errorf("reading output stream: %w", err)
			}
			return nil
		},
	}
}
package cmd
import (
"fmt"
"os"
"path/filepath"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/cmd/utils"
"gitlab.com/nunet/device-management-service/dms"
"gitlab.com/nunet/device-management-service/internal/config"
"gitlab.com/nunet/device-management-service/lib/crypto/keystore"
"gitlab.com/nunet/device-management-service/lib/did"
dmsUtils "gitlab.com/nunet/device-management-service/utils"
)
// newKeyCmd builds the `key` command group. It carries no behavior of its own
// and only hosts the `new` and `did` subcommands, both backed by fs.
func newKeyCmd(fs afero.Afero) *cobra.Command {
	keyCmd := &cobra.Command{
		Use:   "key",
		Short: "Manage cryptographic keys",
		Long: `Manage cryptographic keys for the Device Management Service (DMS).
This command provides subcommands for creating new keys and retrieving Decentralized Identifiers (DIDs) associated with existing keys.`,
	}
	keyCmd.AddCommand(newKeyNewCmd(fs), newKeyDIDCmd(fs))
	return keyCmd
}
// newKeyNewCmd builds `key new <name>`. It generates a key pair, stores the
// private key in the user's keystore (under config General.UserDir joined
// with dms.KeystoreDir) and prints the DID of the public key.
//
// If a key named <name> already exists, the user must confirm the overwrite.
// The passphrase is read from DMS_PASSPHRASE or, when unset, prompted for with
// confirmation.
func newKeyNewCmd(fs afero.Afero) *cobra.Command {
	return &cobra.Command{
		Use:   "new <name>",
		Short: "Generate a key pair",
		Long: `Generate a key pair and save the private key into the user's local keystore.
This command creates a new cryptographic key pair, stores the private key securely, and displays the associated Decentralized Identifier (DID). If a key with the specified name already exists, the user will be prompted to confirm before overwriting it.
Example:
nunet key new user`,
		Args: cobra.ExactArgs(1),
		RunE: func(cmd *cobra.Command, args []string) error {
			keyStoreDir := filepath.Join(config.GetConfig().General.UserDir, dms.KeystoreDir)
			ks, err := keystore.New(fs, keyStoreDir)
			if err != nil {
				return fmt.Errorf("failed to create keystore: %w", err)
			}

			// cobra.ExactArgs(1) guarantees exactly one positional argument.
			keyID := args[0]

			keys, err := ks.ListKeys()
			if err != nil {
				return fmt.Errorf("failed to list keys: %w", err)
			}
			if dmsUtils.SliceContains(keys, keyID) {
				confirmed, err := utils.PromptYesNo(
					cmd.InOrStdin(),
					cmd.OutOrStdout(),
					fmt.Sprintf("A key with name '%s' already exists. Do you want to overwrite it with a new one?", keyID),
				)
				if err != nil {
					return fmt.Errorf("failed to get user confirmation: %w", err)
				}
				if !confirmed {
					return fmt.Errorf("operation cancelled by user")
				}
			}

			passphrase := os.Getenv("DMS_PASSPHRASE")
			if passphrase == "" {
				passphrase, err = utils.PromptForPassphrase(true)
				if err != nil {
					return fmt.Errorf("failed to get passphrase: %w", err)
				}
			}

			priv, err := dms.GenerateAndStorePrivKey(ks, passphrase, keyID)
			if err != nil {
				// Keep the underlying cause; it was previously discarded.
				return fmt.Errorf("failed to generate and store new private key: %w", err)
			}

			// Named id (not did) so the did package is not shadowed.
			id := did.FromPublicKey(priv.GetPublic())
			fmt.Println(id)
			return nil
		},
	}
}
// newKeyDIDCmd builds `key did <name>`, which prints the DID for a key.
//
// The name "ledger" is reserved and reads the DID from account index 0 of a
// connected hardware wallet. Any other name is looked up in the local keystore
// (under config General.UserDir joined with dms.KeystoreDir) and decrypted
// with DMS_PASSPHRASE or an interactively prompted passphrase.
func newKeyDIDCmd(fs afero.Afero) *cobra.Command {
	return &cobra.Command{
		Use:   "did <name>",
		Short: "Get a key's DID",
		Long: `Get the DID (Decentralized Identifier) for a specified key.
This command retrieves the DID associated with either a key stored in the local keystore or a hardware ledger.
For keys in the local keystore, the user will be prompted for the passphrase to decrypt the key. To avoid passphrase prompting, it's possible to set a DMS_PASSPHRASE environment variable. For the ledger option, it uses the first account (index 0) on the connected hardware wallet.
Example:
nunet key did user
nunet key did ledger # if using ledger`,
		Args: cobra.ExactArgs(1),
		RunE: func(_ *cobra.Command, args []string) error {
			keyName := args[0]

			if keyName == "ledger" {
				provider, err := did.NewLedgerWalletProvider(0)
				if err != nil {
					return fmt.Errorf("failed to open ledger wallet provider: %w", err)
				}
				fmt.Println(provider.DID())
				return nil
			}

			keyStoreDir := filepath.Join(config.GetConfig().General.UserDir, dms.KeystoreDir)
			ks, err := keystore.New(fs, keyStoreDir)
			if err != nil {
				return fmt.Errorf("failed to open keystore: %w", err)
			}

			passphrase := os.Getenv("DMS_PASSPHRASE")
			if passphrase == "" {
				passphrase, err = utils.PromptForPassphrase(false)
				if err != nil {
					return fmt.Errorf("failed to get passphrase: %w", err)
				}
			}

			key, err := ks.Get(keyName, passphrase)
			if err != nil {
				return fmt.Errorf("failed to get key: %w", err)
			}
			priv, err := key.PrivKey()
			if err != nil {
				// %w (not %v) so callers can inspect the underlying cause.
				return fmt.Errorf("unable to convert key from keystore to private key: %w", err)
			}

			// Named id (not did) so the did package is not shadowed.
			id := did.FromPublicKey(priv.GetPublic())
			fmt.Println(id)
			return nil
		},
	}
}
package cmd
import (
"github.com/spf13/afero"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/cmd/actor"
"gitlab.com/nunet/device-management-service/cmd/cap"
"gitlab.com/nunet/device-management-service/utils"
)
// newRootCmd builds the top-level `nunet` command and attaches every
// subcommand. client is handed to the actor commands that talk to the REST
// API; afs backs the commands that touch the filesystem. Invoked without a
// subcommand, it prints its help.
func newRootCmd(client *utils.HTTPClient, afs afero.Afero) *cobra.Command {
	root := &cobra.Command{
		Use:   "nunet",
		Short: "NuNet Device Management Service",
		Long:  `The Device Management Service (DMS) Command Line Interface (CLI)`,
		CompletionOptions: cobra.CompletionOptions{
			DisableDefaultCmd: false,
			HiddenDefaultCmd:  true,
		},
		SilenceErrors: true,
		SilenceUsage:  true,
		Run: func(c *cobra.Command, _ []string) {
			_ = c.Help()
		},
	}

	root.AddCommand(
		newRunCmd(),
		newKeyCmd(afs),
		cap.NewCapCmd(afs),
		actor.NewActorCmd(client, afs),
		newConfigCmd(afs.Fs),
		newAutoCompleteCmd(),
		newVersionCmd(),
		newTapCommand(),
		newGPUCommand(),
	)
	return root
}
//go:build linux
package cmd
import (
"fmt"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/internal/config"
"gitlab.com/nunet/device-management-service/utils"
)
// Execute is a wrapper for cobra.Command method with same name
// It makes use of cobra.CheckErr to facilitate error handling
func Execute() {
afs := afero.Afero{Fs: afero.NewOsFs()}
client := utils.NewHTTPClient(
fmt.Sprintf("http://%s:%d",
config.GetConfig().Rest.Addr,
config.GetConfig().Rest.Port),
"/api/v1",
)
cobra.CheckErr(newRootCmd(client, afs).Execute())
}
package cmd
import (
"fmt"
"net/http"
_ "net/http/pprof" //#nosec
"os"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/cmd/utils"
"gitlab.com/nunet/device-management-service/dms"
"gitlab.com/nunet/device-management-service/internal/config"
)
// newRunCmd builds `nunet run`, which starts the Device Management Service.
//
// The keystore passphrase comes from DMS_PASSPHRASE or is prompted for; an
// empty prompted passphrase is rejected. When profiling is enabled (config or
// --pprof), a pprof HTTP server is started on --pprof-addr:--pprof-port
// before dms.Run is called with the passphrase and the --context name.
func newRunCmd() *cobra.Command {
	var context string
	pprof := config.GetConfig().Profiler.Enabled
	pprofAddr := config.GetConfig().Profiler.Addr
	pprofPort := config.GetConfig().Profiler.Port

	cmd := &cobra.Command{
		Use:   "run",
		Short: "Start the Device Management Service",
		Long: `Start the Device Management Service
The Device Management Service (DMS) is a system application for running a node in the NuNet decentralized network of compute providers.
By default, DMS listens on port 9999. For more information on configuration, see:
nunet config --help
Or manually create a dms_config.json file and refer to the README for available settings.`,
		RunE: func(_ *cobra.Command, _ []string) error {
			passphrase := os.Getenv("DMS_PASSPHRASE")
			var err error
			if passphrase == "" {
				fmt.Print("Please enter the DMS passphrase. This will be used to encrypt/decrypt the keystore containing necessary secrets for DMS:\n")
				passphrase, err = utils.PromptForPassphrase(false)
				if err != nil {
					return fmt.Errorf("error reading passphrase from stdin: %w", err)
				}
				// TODO: validate passphrase (minimum x characters)
				if passphrase == "" {
					return fmt.Errorf("invalid passphrase")
				}
			}

			if pprof {
				// net/http/pprof registers its handlers on http.DefaultServeMux at
				// import time. Take that mux for the profiler and give the rest of the
				// program a fresh DefaultServeMux, so the debug endpoints are only
				// reachable on the profiler address. The swap happens here, before
				// the goroutine starts, so the process-wide global is not written
				// concurrently with other code that may be using it.
				pprofMux := http.DefaultServeMux
				http.DefaultServeMux = http.NewServeMux()
				profilerAddr := fmt.Sprintf("%s:%d", pprofAddr, pprofPort)

				go func() {
					log.Infof("Starting profiler on %s\n", profilerAddr)
					// #nosec
					if err := http.ListenAndServe(profilerAddr, pprofMux); err != nil {
						log.Errorf("Error starting profiler: %v\n", err)
					}
				}()
			}

			return dms.Run(passphrase, context)
		},
	}

	cmd.Flags().BoolVar(&pprof, "pprof", pprof, "enable profiling")
	cmd.Flags().StringVar(&pprofAddr, "pprof-addr", pprofAddr, "address the profiler listens on")
	cmd.Flags().Uint32Var(&pprofPort, "pprof-port", pprofPort, "port the profiler listens on")
	cmd.Flags().StringVarP(&context, "context", "c", dms.DefaultContextName, "specify a capability context")
	return cmd
}
package cmd
import (
	"fmt"
	"io"
	"net"
	"os"
	"os/exec"

	"github.com/spf13/cobra"
)
// newTapCommand creates the Cobra command that sets up a TAP interface. The
// steps are: create it, assign it an address, bring it up, enable IPv4
// forwarding, and add iptables FORWARD rules for established traffic and for
// traffic from the TAP to the main interface. It requires root.
//
// Every step runs through `sh -c` with the arguments spliced into the command
// line, so the arguments are validated first. Interface names must be short
// and shell-safe, and the address must parse as an IP or CIDR. This prevents
// shell metacharacters from being executed as root.
func newTapCommand() *cobra.Command {
	return &cobra.Command{
		Use:   "tap [main_interface] [vm_interface] [CIDR]",
		Short: "Create and configure a TAP interface",
		Long: `Create a TAP interface using the provided interface name and configure IP forwarding and iptables rules.
Example:
nunet tap eth0 tap0 172.16.0.1/24
Note: The command requires root privileges.
`,
		Args: cobra.ExactArgs(3),
		RunE: func(cmd *cobra.Command, args []string) error {
			// check if the user is root
			if os.Getuid() != 0 {
				return fmt.Errorf("this command requires root privileges to execute")
			}

			mainInterface, vmInterface, cidr := args[0], args[1], args[2]
			if err := validateTapIfaceName(mainInterface); err != nil {
				return fmt.Errorf("invalid main interface: %w", err)
			}
			if err := validateTapIfaceName(vmInterface); err != nil {
				return fmt.Errorf("invalid vm interface: %w", err)
			}
			// `ip addr add` accepts both "addr/prefix" and a bare address.
			if _, _, err := net.ParseCIDR(cidr); err != nil && net.ParseIP(cidr) == nil {
				return fmt.Errorf("invalid CIDR %q: %w", cidr, err)
			}

			steps := []string{
				// Step 1: Create the TAP interface
				fmt.Sprintf("ip tuntap add %s mode tap", vmInterface),
				// Step 2: Assign IP address to the TAP interface
				fmt.Sprintf("ip addr add %s dev %s", cidr, vmInterface),
				// Step 3: Bring the TAP interface up
				fmt.Sprintf("ip link set %s up", vmInterface),
				// Step 4: Enable IP forwarding
				"echo 1 > /proc/sys/net/ipv4/ip_forward",
				// Step 5: Add iptables rules for connection tracking (only if not already present)
				"iptables -C FORWARD -m conntrack --ctstate RELATED,ESTABLISHED -j ACCEPT || iptables -A FORWARD -m conntrack --ctstate RELATED,ESTABLISHED -j ACCEPT",
				// Step 6: Add iptables rules to allow forwarding between interfaces (only if not already present)
				fmt.Sprintf("iptables -C FORWARD -i %s -o %s -j ACCEPT || iptables -A FORWARD -i %s -o %s -j ACCEPT", vmInterface, mainInterface, vmInterface, mainInterface),
			}

			out := cmd.OutOrStdout()
			for _, step := range steps {
				// Stop at the first failing step, as before.
				if err := runCommand(out, step); err != nil {
					return err
				}
			}

			fmt.Fprintf(out, "TAP interface %s created and configured successfully\n", vmInterface)
			return nil
		},
	}
}

// validateTapIfaceName rejects interface names that are not 1-15 characters of
// [A-Za-z0-9._-]. The limit of 15 is the Linux IFNAMSIZ (16) minus the
// terminating NUL. The restricted alphabet keeps the name safe to splice into
// a `sh -c` command line.
func validateTapIfaceName(name string) error {
	if name == "" || len(name) > 15 {
		return fmt.Errorf("interface name %q must be 1-15 characters long", name)
	}
	for _, r := range name {
		switch {
		case r >= 'a' && r <= 'z', r >= 'A' && r <= 'Z', r >= '0' && r <= '9', r == '-', r == '_', r == '.':
		default:
			return fmt.Errorf("interface name %q contains unsupported character %q", name, r)
		}
	}
	return nil
}
// runCommand executes command through `sh -c`. If it fails, the returned error
// includes the command, the exit error and the combined stdout/stderr. If it
// succeeds, a line "Executed: <command>" is written to stdout.
func runCommand(stdout io.Writer, command string) error {
	combined, runErr := exec.Command("sh", "-c", command).CombinedOutput()
	if runErr == nil {
		fmt.Fprintf(stdout, "Executed: %s\n", command)
		return nil
	}
	return fmt.Errorf("failed to execute command: %s, Error: %w, Output: %s", command, runErr, combined)
}
package utils
import (
"fmt"
"strings"
"time"
"github.com/spf13/pflag"
)
// TimeValue adapts time.Time for use as a command-line flag value. It provides
// the Set/String/Type trio that pflag flag values use.
type TimeValue struct {
	// Time receives the parsed value. It must be non-nil for Set to succeed.
	Time *time.Time
	// Formats lists the layouts Set tries, in order.
	Formats []string
}

// NewTimeValue creates a TimeValue that writes into t.
//
// With no formats, or an explicitly empty list, it accepts RFC822, RFC822Z,
// RFC3339, RFC3339Nano, time.DateTime and time.DateOnly.
func NewTimeValue(t *time.Time, formats ...string) *TimeValue {
	// Check the length, not nil. NewTimeValue(t, []string{}...) passes a non-nil
	// empty slice, which would otherwise produce a flag that can never be Set.
	if len(formats) == 0 {
		formats = []string{
			time.RFC822,
			time.RFC822Z,
			time.RFC3339,
			time.RFC3339Nano,
			time.DateTime,
			time.DateOnly,
		}
	}
	return &TimeValue{
		Time:    t,
		Formats: formats,
	}
}

// Set parses s (surrounding whitespace ignored) with the first layout in
// Formats that accepts it and stores the result in *t.Time. It returns an
// error listing the accepted layouts if none match, or if the value has no
// destination.
func (t TimeValue) Set(s string) error {
	if t.Time == nil {
		// Avoid a nil-pointer panic on a zero TimeValue.
		return fmt.Errorf("time flag has no destination to store the value in")
	}
	s = strings.TrimSpace(s)
	for _, format := range t.Formats {
		if v, err := time.Parse(format, s); err == nil {
			*t.Time = v
			return nil
		}
	}
	return fmt.Errorf("format must be one of: %v", strings.Join(t.Formats, ", "))
}

// Type returns the type name shown for time.Time flags.
func (t TimeValue) Type() string {
	return "time"
}

// String returns the time.Time value's string form, or "" when there is no
// value or it is the zero time.
func (t TimeValue) String() string {
	if t.Time == nil || t.Time.IsZero() {
		return ""
	}
	return t.Time.String()
}
// GetTime returns the time.Time stored in the flag called name in f.
//
// It returns an error, rather than panicking, when the flag does not exist,
// is not backed by a TimeValue (pointer or value), or has no destination Time.
func GetTime(f *pflag.FlagSet, name string) (time.Time, error) {
	flag := f.Lookup(name)
	if flag == nil {
		return time.Time{}, fmt.Errorf("flag %s not found", name)
	}

	// TimeValue has value receivers, so both TimeValue and *TimeValue can be
	// registered as flag values. Accept either, and reject anything else
	// instead of panicking in an unchecked type assertion.
	var val *TimeValue
	switch v := flag.Value.(type) {
	case *TimeValue:
		val = v
	case TimeValue:
		val = &v
	}
	if val == nil || val.Time == nil {
		return time.Time{}, fmt.Errorf("flag %s has wrong type or no value", name)
	}
	return *val.Time, nil
}
package utils
import (
"bufio"
"errors"
"fmt"
"io"
"os"
"os/signal"
"strings"
"syscall"
"github.com/howeyc/gopass"
)
// PromptReonboard asks, on w and reading answers from r, whether an already
// onboarded machine should be reonboarded. It returns nil only when the user
// agrees. A declined prompt or a read failure produces an error.
func PromptReonboard(r io.Reader, w io.Writer) error {
	agreed, err := PromptYesNo(r, w, "Looks like your machine is already onboarded. Proceed with reonboarding?")
	switch {
	case err != nil:
		return fmt.Errorf("could not confirm reonboarding: %w", err)
	case !agreed:
		return fmt.Errorf("reonboarding aborted by user")
	default:
		return nil
	}
}
// PromptForPassphrase reads a passphrase from the terminal with masked input.
// When confirm is true, the passphrase must be typed twice and match. The
// user gets up to three attempts; a mismatch prints a message and retries.
//
// A SIGINT or SIGTERM received while waiting aborts the prompt with an error.
// Signal relaying is undone before returning, so the process's normal signal
// handling resumes afterwards.
func PromptForPassphrase(confirm bool) (string, error) {
	const maxTries = 3

	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)
	// Without Stop, later Ctrl+C/SIGTERM would keep being delivered into this
	// unread channel and silently swallowed for the rest of the process.
	defer signal.Stop(sigChan)

	type result struct {
		passphrase string
		err        error
	}
	// Buffered so the reader goroutine can always deliver its result and exit,
	// even if we already returned because of a signal.
	resultCh := make(chan result, 1)

	go func() {
		for i := 0; i < maxTries; i++ {
			fmt.Print("Passphrase: ")
			bytePassphrase, readErr := gopass.GetPasswdMasked()
			if readErr != nil {
				resultCh <- result{err: fmt.Errorf("failed to read passphrase: %w", readErr)}
				return
			}
			if !confirm {
				resultCh <- result{passphrase: string(bytePassphrase)}
				return
			}

			fmt.Print("Please confirm your passphrase: ")
			byteConfirmation, readErr := gopass.GetPasswdMasked()
			if readErr != nil {
				resultCh <- result{err: fmt.Errorf("failed to read passphrase confirmation: %w", readErr)}
				return
			}
			if string(bytePassphrase) == string(byteConfirmation) {
				resultCh <- result{passphrase: string(bytePassphrase)}
				return
			}

			fmt.Println("passphrases do not match")
			fmt.Println("")
		}
		resultCh <- result{err: fmt.Errorf("user failed to input passphrase")}
	}()

	// Wait for either the passphrase input to complete or an interrupt signal.
	select {
	case res := <-resultCh:
		return res.passphrase, res.err
	case <-sigChan:
		return "", errors.New("sigterm received")
	}
}
// PromptYesNo writes "<prompt> (y/N): " to out and reads a line from in,
// repeating until the answer is y/yes (true) or n/no (false). Matching ignores
// case and surrounding whitespace.
//
// A final answer without a trailing newline (e.g. `printf y | nunet ...`) is
// still honoured. A read error, or end of input without a valid answer, is
// returned wrapped.
func PromptYesNo(in io.Reader, out io.Writer, prompt string) (bool, error) {
	reader := bufio.NewReader(in)
	for {
		fmt.Fprintf(out, "%s (y/N): ", prompt)
		response, err := reader.ReadString('\n')
		// At EOF with some text read, evaluate that text before giving up.
		if err != nil && (!errors.Is(err, io.EOF) || response == "") {
			return false, fmt.Errorf("read response string failed: %w", err)
		}

		switch strings.ToLower(strings.TrimSpace(response)) {
		case "y", "yes":
			return true, nil
		case "n", "no":
			return false, nil
		}

		if err != nil {
			// Unrecognised answer and no more input: re-prompting would loop forever.
			return false, fmt.Errorf("read response string failed: %w", err)
		}
	}
}
package cmd
import (
"fmt"
"github.com/spf13/cobra"
)
// Version information set by the build system (see Makefile).
// NOTE(review): presumably injected with `-ldflags "-X ..."`, so these names
// must not be renamed — confirm against the Makefile. Each is "" when not set.
var (
	// Version is the DMS release version printed by `nunet version`.
	Version string
	// GoVersion is the Go toolchain version used for the build.
	GoVersion string
	// BuildDate is when the binary was built.
	BuildDate string
	// Commit is the source revision the binary was built from.
	Commit string
)
// newVersionCmd builds `nunet version`, which prints the build-time version
// metadata (Version, Commit, GoVersion, BuildDate) to stdout.
func newVersionCmd() *cobra.Command {
	printVersion := func(_ *cobra.Command, _ []string) {
		fmt.Printf("NuNet Device Management Service\nVersion: %s\nCommit: %s\n\nGo Version: %s\nBuild Date: %s\n",
			Version, Commit, GoVersion, BuildDate)
	}
	return &cobra.Command{
		Use:   "version",
		Short: "Information about current version",
		Long:  `Display information about the current Device Management Service (DMS) version`,
		Run:   printVersion,
	}
}
package db
import (
"fmt"
"gitlab.com/nunet/device-management-service/types"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
)
// DB is the process-wide gorm handle set by ConnectDatabase.
// TODO remove once all DB usage is transitioned to the repos
var DB *gorm.DB

// ConnectDatabase opens (creating if needed) the SQLite database at
// <dbPath>/nunet.db, auto-migrates the schema for every persisted model, and
// stores the handle in DB.
//
// Errors are returned to the caller instead of panicking. A failed migration
// is reported rather than silently ignored. DB is only assigned on success.
func ConnectDatabase(dbPath string) (*gorm.DB, error) {
	database, err := gorm.Open(sqlite.Open(fmt.Sprintf("%s/nunet.db", dbPath)), &gorm.Config{})
	if err != nil {
		return nil, fmt.Errorf("failed to connect to database: %w", err)
	}

	models := []interface{}{
		&types.ElasticToken{},
		&types.VirtualMachine{},
		&types.Machine{},
		&types.FreeResources{},
		&types.PeerInfo{},
		&types.Services{},
		&types.ServiceResourceRequirements{},
		&types.ContainerImages{},
		&types.RequestTracker{},
		&types.Libp2pInfo{},
		&types.DeploymentRequestFlat{},
		&types.MachineUUID{},
		&types.Connection{},
		&types.OnboardedResources{},
		&types.MachineResources{},
		&types.OnboardingConfig{},
		&types.ResourceAllocation{},
	}
	// Migrate one model at a time so the error names the failing model.
	for _, model := range models {
		if err := database.AutoMigrate(model); err != nil {
			return nil, fmt.Errorf("failed to migrate %T: %w", model, err)
		}
	}

	// TODO remove once all DB usage is transitioned to the repos
	DB = database
	return database, nil
}
package clover
import (
"fmt"
clover "github.com/ostafen/clover/v2"
)
// NewDB opens the Clover database at path and makes sure every named
// collection exists.
//
// Collections that already exist are left untouched, so reopening an existing
// database works. CreateCollection is not called for them, because Clover
// rejects creating a collection that already exists. On any failure the
// database is closed before the error is returned, so the handle is not
// leaked.
func NewDB(path string, collections []string) (*clover.DB, error) {
	db, err := clover.Open(path)
	if err != nil {
		return nil, fmt.Errorf("failed to connect to database: %w", err)
	}

	for _, collection := range collections {
		exists, err := db.HasCollection(collection)
		if err != nil {
			_ = db.Close() // best effort; the lookup error is the one worth reporting
			return nil, fmt.Errorf("failed to check collection %s: %w", collection, err)
		}
		if exists {
			continue
		}
		if err := db.CreateCollection(collection); err != nil {
			_ = db.Close() // best effort; the creation error is the one worth reporting
			return nil, fmt.Errorf("failed to create collection %s: %w", collection, err)
		}
	}
	return db, nil
}
package clover
import (
clover "github.com/ostafen/clover/v2"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// DeploymentRequestFlatClover persists types.DeploymentRequestFlat records in
// Clover. All operations come from the embedded generic repository.
type DeploymentRequestFlatClover struct {
	repositories.GenericRepository[types.DeploymentRequestFlat]
}

// NewDeploymentRequestFlat returns a Clover-backed DeploymentRequestFlat
// repository that uses db.
func NewDeploymentRequestFlat(db *clover.DB) repositories.DeploymentRequestFlat {
	repo := &DeploymentRequestFlatClover{
		GenericRepository: NewGenericRepository[types.DeploymentRequestFlat](db),
	}
	return repo
}
package clover
import (
"github.com/ostafen/clover/v2"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// RequestTrackerClover persists types.RequestTracker records in Clover. All
// operations come from the embedded generic repository.
type RequestTrackerClover struct {
	repositories.GenericRepository[types.RequestTracker]
}

// NewRequestTracker returns a Clover-backed RequestTracker repository that
// uses db.
func NewRequestTracker(db *clover.DB) repositories.RequestTracker {
	repo := &RequestTrackerClover{
		GenericRepository: NewGenericRepository[types.RequestTracker](db),
	}
	return repo
}
package clover
import (
"github.com/ostafen/clover/v2"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// VirtualMachineClover persists types.VirtualMachine records in Clover. All
// operations come from the embedded generic repository.
type VirtualMachineClover struct {
	repositories.GenericRepository[types.VirtualMachine]
}

// NewVirtualMachine returns a Clover-backed VirtualMachine repository that
// uses db.
func NewVirtualMachine(db *clover.DB) repositories.VirtualMachine {
	repo := &VirtualMachineClover{
		GenericRepository: NewGenericRepository[types.VirtualMachine](db),
	}
	return repo
}
package clover
import (
"context"
"fmt"
"reflect"
"time"
"github.com/iancoleman/strcase"
clover "github.com/ostafen/clover/v2"
clover_d "github.com/ostafen/clover/v2/document"
clover_q "github.com/ostafen/clover/v2/query"
"gitlab.com/nunet/device-management-service/db/repositories"
)
// GenericEntityRepositoryClover stores a single logical entity of type T in
// Clover. Every Save appends a new document, so the collection also holds the
// entity's history. It is meant to be embedded by single-entity repositories.
type GenericEntityRepositoryClover[T repositories.ModelType] struct {
	// db is the Clover database handle.
	db *clover.DB
	// collection is the snake_case form of T's type name.
	collection string
}

// NewGenericEntityRepository returns a single-entity repository for T backed
// by db. The collection name is derived from T's type name in snake_case.
func NewGenericEntityRepository[T repositories.ModelType](
	db *clover.DB,
) repositories.GenericEntityRepository[T] {
	var zero T
	typeName := reflect.TypeOf(zero).Name()
	return &GenericEntityRepositoryClover[T]{
		collection: strcase.ToSnake(typeName),
		db:         db,
	}
}
// GetQuery returns an empty Query for the caller to fill in.
func (repo *GenericEntityRepositoryClover[T]) GetQuery() repositories.Query[T] {
	var q repositories.Query[T]
	return q
}

// query starts an unfiltered Clover query on the repository's collection.
func (repo *GenericEntityRepositoryClover[T]) query() *clover_q.Query {
	q := clover_q.NewQuery(repo.collection)
	return q
}
// Save stores data as the newest version of the entity. It always inserts a
// new document stamped with CreatedAt = now; older versions remain available
// via History. It returns the model decoded from the stored document.
func (repo *GenericEntityRepositoryClover[T]) Save(_ context.Context, data T) (T, error) {
	doc := toCloverDoc(data)
	doc.Set("CreatedAt", time.Now())

	if _, insertErr := repo.db.InsertOne(repo.collection, doc); insertErr != nil {
		return data, handleDBError(insertErr)
	}

	saved, convErr := toModel[T](doc, true)
	if convErr != nil {
		return saved, handleDBError(fmt.Errorf("%v: %v", repositories.ErrParsingModel, convErr))
	}
	return saved, nil
}
// Get returns the latest version of the entity: the document with the highest
// CreatedAt. If the lookup fails, or finds no document, it returns the zero
// value of T together with handleDBError(err), where err is the lookup error
// (nil when the collection is simply empty).
func (repo *GenericEntityRepositoryClover[T]) Get(_ context.Context) (T, error) {
	newestFirst := repo.query().Sort(clover_q.SortOption{
		Field:     "CreatedAt",
		Direction: -1,
	})

	doc, err := repo.db.FindFirst(newestFirst)
	if doc == nil || err != nil {
		var empty T
		return empty, handleDBError(err)
	}

	latest, err := toModel[T](doc, true)
	if err != nil {
		return latest, fmt.Errorf("failed to convert document to model: %v", err)
	}
	return latest, nil
}
// Clear deletes every document in the collection, i.e. the entity and all of
// its history. Errors go through handleDBError like the other repository
// methods, so callers see the same error vocabulary everywhere.
func (repo *GenericEntityRepositoryClover[T]) Clear(_ context.Context) error {
	return handleDBError(repo.db.Delete(repo.query()))
}
// History returns the stored versions of the entity that match query. The
// query's conditions, sort, limit and offset are applied via applyConditions.
//
// Documents are decoded with toModel(doc, true), the same conversion Get and
// Save use. The first decode failure stops the iteration and is returned as
// an ErrParsingModel error, instead of silently truncating the result with a
// nil error.
func (repo *GenericEntityRepositoryClover[T]) History(_ context.Context, query repositories.Query[T]) ([]T, error) {
	var models []T
	var parseErr error

	q := applyConditions(repo.query(), query)
	err := repo.db.ForEach(q, func(doc *clover_d.Document) bool {
		model, convErr := toModel[T](doc, true)
		if convErr != nil {
			parseErr = handleDBError(fmt.Errorf("%v: %v", repositories.ErrParsingModel, convErr))
			return false
		}
		models = append(models, model)
		return true
	})
	if err != nil {
		return models, handleDBError(err)
	}
	return models, parseErr
}
package clover
import (
"context"
"fmt"
"reflect"
"time"
"github.com/iancoleman/strcase"
clover "github.com/ostafen/clover/v2"
clover_d "github.com/ostafen/clover/v2/document"
clover_q "github.com/ostafen/clover/v2/query"
"gitlab.com/nunet/device-management-service/db/repositories"
)
const (
	// pkField is the document primary-key field that queryWithID matches on.
	pkField = "_id"
	// deletedAtField is the soft-delete timestamp. query(false) keeps only
	// documents whose value is at or before the Unix epoch, i.e. those that
	// Delete has not stamped with time.Now().
	deletedAtField = "DeletedAt"
)
// GenericRepositoryClover is a Clover-backed repository for many records of
// type T, with soft deletion via the DeletedAt field. It is meant to be
// embedded by model-specific repositories.
type GenericRepositoryClover[T repositories.ModelType] struct {
	// db is the Clover database handle.
	db *clover.DB
	// collection is the snake_case form of T's type name.
	collection string
}

// NewGenericRepository returns a repository for T backed by db. The
// collection name is derived from T's type name in snake_case.
func NewGenericRepository[T repositories.ModelType](
	db *clover.DB,
) repositories.GenericRepository[T] {
	var zero T
	typeName := reflect.TypeOf(zero).Name()
	return &GenericRepositoryClover[T]{
		collection: strcase.ToSnake(typeName),
		db:         db,
	}
}
// GetQuery returns an empty Query for the caller to fill in.
func (repo *GenericRepositoryClover[T]) GetQuery() repositories.Query[T] {
	var q repositories.Query[T]
	return q
}

// query starts a Clover query on the repository's collection. Unless
// includeDeleted is true, it only matches documents whose DeletedAt is at or
// before the Unix epoch, i.e. documents that are not soft-deleted.
func (repo *GenericRepositoryClover[T]) query(includeDeleted bool) *clover_q.Query {
	q := clover_q.NewQuery(repo.collection)
	if includeDeleted {
		return q
	}
	return q.Where(clover_q.Field(deletedAtField).LtEq(time.Unix(0, 0)))
}
// queryWithID narrows query(includeDeleted) to the document whose primary key
// (pkField) equals id.
//
// Keys are compared as strings. A non-string id (e.g. a named string type or
// an integer) is converted with fmt.Sprint rather than type-asserted, so it
// yields a query that matches the right document or nothing, instead of
// panicking.
func (repo *GenericRepositoryClover[T]) queryWithID(
	id interface{},
	includeDeleted bool,
) *clover_q.Query {
	key, ok := id.(string)
	if !ok {
		key = fmt.Sprint(id)
	}
	return repo.query(includeDeleted).Where(clover_q.Field(pkField).Eq(key))
}
// Create inserts data as a new document stamped with CreatedAt = now. It
// returns the model decoded from the stored document. If the insert or the
// decode fails, the input data is returned alongside the error.
func (repo *GenericRepositoryClover[T]) Create(_ context.Context, data T) (T, error) {
	doc := toCloverDoc(data)
	doc.Set("CreatedAt", time.Now())

	if _, insertErr := repo.db.InsertOne(repo.collection, doc); insertErr != nil {
		return data, handleDBError(insertErr)
	}

	created, convErr := toModel[T](doc, false)
	if convErr != nil {
		return data, handleDBError(fmt.Errorf("%v: %v", repositories.ErrParsingModel, convErr))
	}
	return created, nil
}
// Get returns the non-deleted record whose primary key is id.
//
// A missing record is reported as handleDBError(clover.ErrDocumentNotExist),
// the same way Find reports it. Previously a missing id returned the zero
// model with a nil error, because FindFirst signals "no match" with a nil
// document and nil error.
func (repo *GenericRepositoryClover[T]) Get(_ context.Context, id interface{}) (T, error) {
	var model T
	doc, err := repo.db.FindFirst(repo.queryWithID(id, false))
	if err != nil {
		return model, handleDBError(err)
	}
	if doc == nil {
		return model, handleDBError(clover.ErrDocumentNotExist)
	}

	model, err = toModel[T](doc, false)
	if err != nil {
		return model, handleDBError(fmt.Errorf("%v: %v", repositories.ErrParsingModel, err))
	}
	return model, nil
}
// Update overwrites the fields of the non-deleted record with primary key id
// using data, and stamps UpdatedAt = now. It then re-reads the record through
// Get and returns that result. If the write fails, the input data is returned
// with the error.
func (repo *GenericRepositoryClover[T]) Update(
	ctx context.Context,
	id interface{},
	data T,
) (T, error) {
	fields := toCloverDoc(data).AsMap()
	fields["UpdatedAt"] = time.Now()

	if writeErr := repo.db.Update(repo.queryWithID(id, false), fields); writeErr != nil {
		return data, handleDBError(writeErr)
	}

	updated, readErr := repo.Get(ctx, id)
	return updated, handleDBError(readErr)
}
// Delete soft-deletes the record with primary key id by stamping its
// deletedAtField with the current time. Queries built with query(false) skip
// such records from then on. Errors go through handleDBError, consistent with
// the other repository methods.
func (repo *GenericRepositoryClover[T]) Delete(_ context.Context, id interface{}) error {
	err := repo.db.Update(
		repo.queryWithID(id, false),
		map[string]interface{}{deletedAtField: time.Now()},
	)
	return handleDBError(err)
}
// Find returns the first non-deleted record matching query (see
// applyConditions). No match is reported as
// handleDBError(clover.ErrDocumentNotExist). A decode failure is reported as
// an ErrParsingModel error through handleDBError, the same way Get, Create
// and FindAll report it.
func (repo *GenericRepositoryClover[T]) Find(
	_ context.Context,
	query repositories.Query[T],
) (T, error) {
	var model T
	q := applyConditions(repo.query(false), query)

	doc, err := repo.db.FindFirst(q)
	if err != nil {
		return model, handleDBError(err)
	}
	if doc == nil {
		return model, handleDBError(clover.ErrDocumentNotExist)
	}

	model, err = toModel[T](doc, false)
	if err != nil {
		return model, handleDBError(fmt.Errorf("%v: %v", repositories.ErrParsingModel, err))
	}
	return model, nil
}
// FindAll returns every non-deleted record matching query (see
// applyConditions). Iteration stops at the first document that cannot be
// decoded; that failure is returned as an ErrParsingModel error along with
// the records decoded so far. Iteration errors take precedence over decode
// errors.
func (repo *GenericRepositoryClover[T]) FindAll(
	_ context.Context,
	query repositories.Query[T],
) ([]T, error) {
	var (
		results  []T
		parseErr error
	)

	q := applyConditions(repo.query(false), query)
	iterErr := repo.db.ForEach(q, func(doc *clover_d.Document) bool {
		m, convErr := toModel[T](doc, false)
		if convErr != nil {
			parseErr = handleDBError(fmt.Errorf("%v: %v", repositories.ErrParsingModel, convErr))
			return false
		}
		results = append(results, m)
		return true
	})

	switch {
	case iterErr != nil:
		return results, handleDBError(iterErr)
	case parseErr != nil:
		return results, parseErr
	default:
		return results, nil
	}
}
// applyConditions extends the Clover query q with the clauses described by
// the generic query and returns the result. It applies, in order:
//   - each entry in query.Conditions. Supported operators are =, >, >=, <, <=,
//     !=, IN ([]interface{} value) and LIKE (string value). Other operators,
//     and IN/LIKE values of another Go type, are ignored.
//   - an equality clause for every non-empty field of query.Instance.
//   - query.SortBy, ascending, or descending when prefixed with '-'.
//   - query.Limit and query.Offset, when positive.
//
// All field names are mapped through fieldJSONTag before use.
func applyConditions[T repositories.ModelType](
	q *clover_q.Query,
	query repositories.Query[T],
) *clover_q.Query {
	// Apply conditions specified in the query.
	for _, condition := range query.Conditions {
		// Use the json tag name if the struct declares one.
		field := clover_q.Field(fieldJSONTag[T](condition.Field))
		switch condition.Operator {
		case "=":
			q = q.Where(field.Eq(condition.Value))
		case ">":
			q = q.Where(field.Gt(condition.Value))
		case ">=":
			q = q.Where(field.GtEq(condition.Value))
		case "<":
			q = q.Where(field.Lt(condition.Value))
		case "<=":
			q = q.Where(field.LtEq(condition.Value))
		case "!=":
			q = q.Where(field.Neq(condition.Value))
		case "IN":
			if values, ok := condition.Value.([]interface{}); ok {
				q = q.Where(field.In(values...))
			}
		case "LIKE":
			if value, ok := condition.Value.(string); ok {
				q = q.Where(field.Like(value))
			}
		}
	}

	// Apply conditions based on non-zero values in the query instance.
	if !repositories.IsEmptyValue(query.Instance) {
		exampleType := reflect.TypeOf(query.Instance)
		exampleValue := reflect.ValueOf(query.Instance)
		for i := 0; i < exampleType.NumField(); i++ {
			fieldName := fieldJSONTag[T](exampleType.Field(i).Name)
			fieldValue := exampleValue.Field(i).Interface()
			if !repositories.IsEmptyValue(fieldValue) {
				q = q.Where(clover_q.Field(fieldName).Eq(fieldValue))
			}
		}
	}

	// Apply sorting. Map the field name for both directions; previously only
	// descending sorts were translated to the json tag name.
	if query.SortBy != "" {
		sortField, dir := query.SortBy, 1
		if sortField[0] == '-' {
			dir = -1
			sortField = sortField[1:]
		}
		q = q.Sort(clover_q.SortOption{Field: fieldJSONTag[T](sortField), Direction: dir})
	}

	if query.Limit > 0 {
		q = q.Limit(query.Limit)
	}
	// Offset means "skip N documents". It was previously passed to Limit, which
	// overwrote any limit and never skipped anything.
	if query.Offset > 0 {
		q = q.Skip(query.Offset)
	}
	return q
}
package clover
import (
"github.com/ostafen/clover/v2"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// PeerInfoClover persists types.PeerInfo records in Clover. All operations
// come from the embedded generic repository.
type PeerInfoClover struct {
	repositories.GenericRepository[types.PeerInfo]
}

// NewPeerInfo returns a Clover-backed PeerInfo repository that uses db.
func NewPeerInfo(db *clover.DB) repositories.PeerInfo {
	repo := &PeerInfoClover{
		GenericRepository: NewGenericRepository[types.PeerInfo](db),
	}
	return repo
}
// MachineClover persists types.Machine records in Clover. All operations come
// from the embedded generic repository.
type MachineClover struct {
	repositories.GenericRepository[types.Machine]
}

// NewMachine returns a Clover-backed Machine repository that uses db.
func NewMachine(db *clover.DB) repositories.Machine {
	repo := &MachineClover{
		GenericRepository: NewGenericRepository[types.Machine](db),
	}
	return repo
}
// ServicesClover persists types.Services records in Clover. All operations
// come from the embedded generic repository.
type ServicesClover struct {
	repositories.GenericRepository[types.Services]
}

// NewServices returns a Clover-backed Services repository that uses db.
func NewServices(db *clover.DB) repositories.Services {
	repo := &ServicesClover{
		GenericRepository: NewGenericRepository[types.Services](db),
	}
	return repo
}
// ServiceResourceRequirementsClover persists types.ServiceResourceRequirements
// records in Clover. All operations come from the embedded generic repository.
type ServiceResourceRequirementsClover struct {
	repositories.GenericRepository[types.ServiceResourceRequirements]
}

// NewServiceResourceRequirements returns a Clover-backed
// ServiceResourceRequirements repository that uses db.
func NewServiceResourceRequirements(
	db *clover.DB,
) repositories.ServiceResourceRequirements {
	repo := &ServiceResourceRequirementsClover{
		GenericRepository: NewGenericRepository[types.ServiceResourceRequirements](db),
	}
	return repo
}
// Libp2pInfoClover persists the single types.Libp2pInfo entity (with history)
// in Clover. All operations come from the embedded generic entity repository.
type Libp2pInfoClover struct {
	repositories.GenericEntityRepository[types.Libp2pInfo]
}

// NewLibp2pInfo returns a Clover-backed Libp2pInfo repository that uses db.
func NewLibp2pInfo(db *clover.DB) repositories.Libp2pInfo {
	repo := &Libp2pInfoClover{
		GenericEntityRepository: NewGenericEntityRepository[types.Libp2pInfo](db),
	}
	return repo
}
// MachineUUIDClover persists the single types.MachineUUID entity (with
// history) in Clover. All operations come from the embedded generic entity
// repository.
type MachineUUIDClover struct {
	repositories.GenericEntityRepository[types.MachineUUID]
}

// NewMachineUUID returns a Clover-backed MachineUUID repository that uses db.
func NewMachineUUID(db *clover.DB) repositories.MachineUUID {
	repo := &MachineUUIDClover{
		GenericEntityRepository: NewGenericEntityRepository[types.MachineUUID](db),
	}
	return repo
}
// ConnectionClover persists types.Connection records in Clover. All operations
// come from the embedded generic repository.
type ConnectionClover struct {
	repositories.GenericRepository[types.Connection]
}

// NewConnection returns a Clover-backed Connection repository that uses db.
func NewConnection(db *clover.DB) repositories.Connection {
	repo := &ConnectionClover{
		GenericRepository: NewGenericRepository[types.Connection](db),
	}
	return repo
}
// ElasticTokenClover persists types.ElasticToken records in Clover. All
// operations come from the embedded generic repository.
type ElasticTokenClover struct {
	repositories.GenericRepository[types.ElasticToken]
}

// NewElasticToken returns a Clover-backed ElasticToken repository that uses db.
func NewElasticToken(db *clover.DB) repositories.ElasticToken {
	repo := &ElasticTokenClover{
		GenericRepository: NewGenericRepository[types.ElasticToken](db),
	}
	return repo
}
package clover
import (
"github.com/ostafen/clover/v2"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/db/repositories"
)
// MachineResourcesRepositoryClover is the Clover-backed repository for the
// MachineResources entity, built on the generic Clover entity repository.
type MachineResourcesRepositoryClover struct {
	repositories.GenericEntityRepository[types.MachineResources]
}

// NewMachineResourcesRepository returns a MachineResources repository that
// persists to the given Clover database.
func NewMachineResourcesRepository(db *clover.DB) repositories.MachineResources {
	base := NewGenericEntityRepository[types.MachineResources](db)
	return &MachineResourcesRepositoryClover{GenericEntityRepository: base}
}

// FreeResourcesClover is the Clover-backed repository for the FreeResources
// entity, built on the generic Clover entity repository.
type FreeResourcesClover struct {
	repositories.GenericEntityRepository[types.FreeResources]
}

// NewFreeResources returns a FreeResources repository that persists to the
// given Clover database.
func NewFreeResources(db *clover.DB) repositories.FreeResources {
	base := NewGenericEntityRepository[types.FreeResources](db)
	return &FreeResourcesClover{GenericEntityRepository: base}
}

// OnboardedResourcesClover is the Clover-backed repository for the
// OnboardedResources entity, built on the generic Clover entity repository.
type OnboardedResourcesClover struct {
	repositories.GenericEntityRepository[types.OnboardedResources]
}

// NewOnboardedResources returns an OnboardedResources repository that persists
// to the given Clover database.
func NewOnboardedResources(db *clover.DB) repositories.OnboardedResources {
	base := NewGenericEntityRepository[types.OnboardedResources](db)
	return &OnboardedResourcesClover{GenericEntityRepository: base}
}

// ResourceAllocationClover is the Clover-backed repository for
// ResourceAllocation records, built on the generic Clover repository.
type ResourceAllocationClover struct {
	repositories.GenericRepository[types.ResourceAllocation]
}

// NewResourceAllocation returns a ResourceAllocation repository that persists
// to the given Clover database.
func NewResourceAllocation(db *clover.DB) repositories.ResourceAllocation {
	base := NewGenericRepository[types.ResourceAllocation](db)
	return &ResourceAllocationClover{GenericRepository: base}
}
package clover
import (
clover "github.com/ostafen/clover/v2"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// StorageVolumeClover is the Clover-backed repository for StorageVolume
// records, built on the generic Clover repository.
type StorageVolumeClover struct {
	repositories.GenericRepository[types.StorageVolume]
}

// NewStorageVolume returns a StorageVolume repository that persists to the
// given Clover database.
func NewStorageVolume(db *clover.DB) repositories.StorageVolume {
	base := NewGenericRepository[types.StorageVolume](db)
	return &StorageVolumeClover{GenericRepository: base}
}
package clover
import (
"encoding/json"
"errors"
"reflect"
"strings"
"github.com/ostafen/clover/v2"
clover_d "github.com/ostafen/clover/v2/document"
"gitlab.com/nunet/device-management-service/db/repositories"
)
// handleDBError translates Clover errors into the repository-level sentinel errors.
//
// A nil input yields nil. Missing documents map to repositories.ErrNotFound and
// duplicate keys map to repositories.ErrInvalidData. repositories.ErrParsingModel
// (including wrapped occurrences) is passed through unchanged. Any other error
// is joined with repositories.ErrDatabase, so callers can match either with
// errors.Is.
//
// errors.Is is used instead of == so that errors wrapped by Clover or by
// intermediate layers are still recognised.
func handleDBError(err error) error {
	switch {
	case err == nil:
		return nil
	case errors.Is(err, clover.ErrDocumentNotExist):
		return repositories.ErrNotFound
	case errors.Is(err, clover.ErrDuplicateKey):
		return repositories.ErrInvalidData
	case errors.Is(err, repositories.ErrParsingModel):
		return err
	default:
		return errors.Join(repositories.ErrDatabase, err)
	}
}
// toCloverDoc converts a model into a Clover document by round-tripping it
// through JSON, so the document's keys follow the model's json tags.
//
// NOTE(review): marshalling/unmarshalling failures are swallowed and an empty
// document is returned instead. Callers cannot tell this apart from a model
// with no fields; surfacing the error would require changing the signature.
func toCloverDoc[T repositories.ModelType](data T) *clover_d.Document {
	raw, err := json.Marshal(data)
	if err != nil {
		return clover_d.NewDocument()
	}

	fields := map[string]interface{}{}
	if err := json.Unmarshal(raw, &fields); err != nil {
		return clover_d.NewDocument()
	}

	return clover_d.NewDocumentOf(fields)
}
// toModel decodes a Clover document into a fresh value of model type T.
//
// For non-entity repositories, the document's ObjectId is then copied into the
// model's ID field via repositories.UpdateField. Entity repositories skip this
// step because their models are not required to have an ID field at all.
func toModel[T repositories.ModelType](doc *clover_d.Document, isEntityRepo bool) (T, error) {
	var model T
	if err := doc.Unmarshal(&model); err != nil {
		return model, err
	}

	if isEntityRepo {
		return model, nil
	}

	return repositories.UpdateField(model, "ID", doc.ObjectId())
}
// fieldJSONTag returns the JSON key that encoding/json uses for the struct field
// named field on model type T (or on the struct T points to).
//
// The key is the name part of the field's `json` tag. The Go field name is
// returned unchanged in any of these cases, matching encoding/json's own
// fallback:
//   - the field is unknown;
//   - T is not a struct or a pointer to one;
//   - the field has no json tag;
//   - the tag has an empty name part (e.g. `json:",omitempty"`).
func fieldJSONTag[T repositories.ModelType](field string) string {
	// reflect.TypeOf(new(T)).Elem() is safe even when T is an interface type,
	// where reflect.TypeOf(*new(T)) would be nil.
	t := reflect.TypeOf(new(T)).Elem()
	for t.Kind() == reflect.Ptr {
		t = t.Elem()
	}
	if t.Kind() != reflect.Struct {
		// FieldByName panics on non-struct types.
		return field
	}

	structField, ok := t.FieldByName(field)
	if !ok {
		return field
	}
	tag, ok := structField.Tag.Lookup("json")
	if !ok {
		return field
	}
	if name, _, _ := strings.Cut(tag, ","); name != "" {
		return name
	}
	return field
}
package repositories
import (
"context"
)
// QueryCondition is a single filter predicate of the form "<Field> <Operator> <Value>".
// Build one with the EQ, GT, GTE, LT, LTE, IN and LIKE helpers in this package.
type QueryCondition struct {
Field string // Field is the struct field name the condition applies to; the GORM backend maps it to a column name via its naming strategy.
Operator string // Operator is the comparison operator (e.g. "=", ">", ">=", "<", "<=", "IN", "LIKE").
Value interface{} // Value is the operand the field is compared against (a slice for "IN").
}
// ModelType is the (unconstrained) type-parameter constraint for repository
// models: any type may be stored.
type ModelType any
// Query bundles everything needed to select records of type T.
// It carries explicit Conditions, an example Instance whose non-zero fields act
// as extra equality filters (in the GORM backend), and optional sorting and
// paging settings.
type Query[T any] struct {
Instance T // Instance is an optional example of T; the GORM backend adds an equality filter for each of its non-zero fields.
Conditions []QueryCondition // Conditions are explicit predicates; the GORM backend combines them with AND.
SortBy string // SortBy names the field to sort by; the GORM backend sorts ascending, or descending when the name is prefixed with "-".
Limit int // Limit caps the number of returned records; values <= 0 mean no limit in the GORM backend.
Offset int // Offset is the number of records to skip before returning results; values <= 0 mean no offset in the GORM backend.
}
// GenericRepository is the storage-agnostic CRUD and query contract for model
// repositories. Both the GORM and Clover backends provide implementations.
type GenericRepository[T ModelType] interface {
	// GetQuery returns an empty Query for T, ready to be filled in by the caller.
	GetQuery() Query[T]

	// Create stores data as a new record and returns the stored value.
	Create(ctx context.Context, data T) (T, error)
	// Get loads the record with the given identifier.
	Get(ctx context.Context, id interface{}) (T, error)
	// Update changes the record with the given identifier using data.
	Update(ctx context.Context, id interface{}, data T) (T, error)
	// Delete removes the record with the given identifier.
	Delete(ctx context.Context, id interface{}) error

	// Find returns a single record matching query.
	Find(ctx context.Context, query Query[T]) (T, error)
	// FindAll returns every record matching query.
	FindAll(ctx context.Context, query Query[T]) ([]T, error)
}
// EQ builds a condition requiring field to equal value (operator "=").
func EQ(field string, value interface{}) QueryCondition {
	return QueryCondition{
		Field:    field,
		Operator: "=",
		Value:    value,
	}
}

// GT builds a condition requiring field to be strictly greater than value (operator ">").
func GT(field string, value interface{}) QueryCondition {
	return QueryCondition{
		Field:    field,
		Operator: ">",
		Value:    value,
	}
}

// GTE builds a condition requiring field to be greater than or equal to value (operator ">=").
func GTE(field string, value interface{}) QueryCondition {
	return QueryCondition{
		Field:    field,
		Operator: ">=",
		Value:    value,
	}
}

// LT builds a condition requiring field to be strictly less than value (operator "<").
func LT(field string, value interface{}) QueryCondition {
	return QueryCondition{
		Field:    field,
		Operator: "<",
		Value:    value,
	}
}

// LTE builds a condition requiring field to be less than or equal to value (operator "<=").
func LTE(field string, value interface{}) QueryCondition {
	return QueryCondition{
		Field:    field,
		Operator: "<=",
		Value:    value,
	}
}

// IN builds a condition requiring field to equal one of values (operator "IN").
// The whole slice is stored as the condition's Value.
func IN(field string, values []interface{}) QueryCondition {
	return QueryCondition{
		Field:    field,
		Operator: "IN",
		Value:    values,
	}
}

// LIKE builds a pattern-match condition on field (operator "LIKE").
// pattern uses the backend's LIKE syntax.
func LIKE(field, pattern string) QueryCondition {
	return QueryCondition{
		Field:    field,
		Operator: "LIKE",
		Value:    pattern,
	}
}
package gorm
import (
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// DeploymentRequestFlatGORM is the GORM-backed repository for
// DeploymentRequestFlat records, built on the generic GORM repository.
type DeploymentRequestFlatGORM struct {
	repositories.GenericRepository[types.DeploymentRequestFlat]
}

// NewDeploymentRequestFlat returns a DeploymentRequestFlat repository that
// operates on the given GORM database.
func NewDeploymentRequestFlat(db *gorm.DB) repositories.DeploymentRequestFlat {
	base := NewGenericRepository[types.DeploymentRequestFlat](db)
	return &DeploymentRequestFlatGORM{GenericRepository: base}
}
package gorm
import (
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// RequestTrackerGORM is the GORM-backed repository for RequestTracker records,
// built on the generic GORM repository.
type RequestTrackerGORM struct {
	repositories.GenericRepository[types.RequestTracker]
}

// NewRequestTracker returns a RequestTracker repository that operates on the
// given GORM database.
func NewRequestTracker(db *gorm.DB) repositories.RequestTracker {
	base := NewGenericRepository[types.RequestTracker](db)
	return &RequestTrackerGORM{GenericRepository: base}
}
package gorm
import (
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// VirtualMachineGORM is the GORM-backed repository for VirtualMachine records,
// built on the generic GORM repository.
type VirtualMachineGORM struct {
	repositories.GenericRepository[types.VirtualMachine]
}

// NewVirtualMachine returns a VirtualMachine repository that operates on the
// given GORM database.
func NewVirtualMachine(db *gorm.DB) repositories.VirtualMachine {
	base := NewGenericRepository[types.VirtualMachine](db)
	return &VirtualMachineGORM{GenericRepository: base}
}
package gorm
import (
"context"
"fmt"
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
)
const (
// createdAtField is the model field used to order an entity's stored rows.
// GenericEntityRepositoryGORM.Get sorts by it descending and takes the first
// row, i.e. the most recently created one.
createdAtField = "CreatedAt"
)
// GenericEntityRepositoryGORM is a generic single-entity repository backed by GORM.
// Every Save inserts a new row, and Get returns the newest row (by createdAtField),
// so older rows form the entity's history. It is intended to be embedded in
// model-specific single-entity repositories.
type GenericEntityRepositoryGORM[T repositories.ModelType] struct {
db *gorm.DB // db is the GORM database instance every operation runs against.
}
// NewGenericEntityRepository returns a GORM-backed single-entity repository for
// model type T that runs all of its operations against db.
func NewGenericEntityRepository[T repositories.ModelType](
	db *gorm.DB,
) repositories.GenericEntityRepository[T] {
	repo := &GenericEntityRepositoryGORM[T]{}
	repo.db = db
	return repo
}
// GetQuery returns a zero-valued Query for T: no conditions, instance filter,
// sorting or paging.
func (repo *GenericEntityRepositoryGORM[T]) GetQuery() repositories.Query[T] {
	var empty repositories.Query[T]
	return empty
}
// Save stores data as a new row and returns it as populated by GORM's Create
// (for example with the generated primary key filled in).
//
// It always INSERTs; earlier rows are kept and make up the entity's history.
// Get returns the newest of them.
func (repo *GenericEntityRepositoryGORM[T]) Save(ctx context.Context, data T) (T, error) {
	result := repo.db.WithContext(ctx).Create(&data)
	return data, handleDBError(result.Error)
}
// Get returns the current value of the entity: the rows are ordered by
// createdAtField descending and the first one is taken. When nothing has been
// saved yet, GORM reports a missing record, which handleDBError turns into
// repositories.ErrNotFound.
func (repo *GenericEntityRepositoryGORM[T]) Get(ctx context.Context) (T, error) {
	var latest T
	newestFirst := repo.GetQuery()
	newestFirst.SortBy = fmt.Sprintf("-%s", createdAtField)
	err := applyConditions(repo.db.WithContext(ctx), newestFirst).First(&latest).Error
	return latest, handleDBError(err)
}
// Clear deletes every stored row of T, i.e. the current value and its whole history.
//
// GORM rejects a Delete with no conditions, so the always-true `id IS NOT NULL`
// predicate is used to match all rows. This assumes T's table has an `id` column.
// Errors are translated through handleDBError, consistent with the other
// repository methods.
//
// NOTE(review): if T has a gorm.DeletedAt field, GORM performs a soft delete here.
func (repo *GenericEntityRepositoryGORM[T]) Clear(ctx context.Context) error {
	err := repo.db.WithContext(ctx).Delete(new(T), "id IS NOT NULL").Error
	return handleDBError(err)
}
// History returns the stored rows of T that satisfy query. Conditions, instance
// filters, sorting and paging are applied by applyConditions. GORM's Find does
// not report a missing record, so no match yields no rows and a nil error.
func (repo *GenericEntityRepositoryGORM[T]) History(
	ctx context.Context,
	query repositories.Query[T],
) ([]T, error) {
	var rows []T
	scoped := applyConditions(repo.db.WithContext(ctx).Model(new(T)), query)
	err := scoped.Find(&rows).Error
	return rows, handleDBError(err)
}
package gorm
import (
"context"
"fmt"
"reflect"
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
)
// GenericRepositoryGORM is a generic GORM-backed implementation of
// repositories.GenericRepository. Records are addressed by their `id` column.
// It is intended to be embedded in model-specific repositories.
type GenericRepositoryGORM[T repositories.ModelType] struct {
db *gorm.DB // db is the GORM database instance every operation runs against.
}
// NewGenericRepository returns a GORM-backed repository for model type T that
// runs all of its operations against db.
func NewGenericRepository[T repositories.ModelType](db *gorm.DB) repositories.GenericRepository[T] {
	repo := &GenericRepositoryGORM[T]{}
	repo.db = db
	return repo
}
// GetQuery returns a zero-valued Query for T: no conditions, instance filter,
// sorting or paging.
func (repo *GenericRepositoryGORM[T]) GetQuery() repositories.Query[T] {
	var empty repositories.Query[T]
	return empty
}
// Create inserts data as a new record and returns it as populated by GORM's
// Create (for example with the generated primary key filled in).
func (repo *GenericRepositoryGORM[T]) Create(ctx context.Context, data T) (T, error) {
	if err := repo.db.WithContext(ctx).Create(&data).Error; err != nil {
		return data, handleDBError(err)
	}
	return data, nil
}
// Get loads the record whose `id` column equals id. When no such record exists,
// GORM's missing-record error is translated to repositories.ErrNotFound.
func (repo *GenericRepositoryGORM[T]) Get(ctx context.Context, id interface{}) (T, error) {
	var record T
	err := repo.db.WithContext(ctx).First(&record, "id = ?", id).Error
	return record, handleDBError(err)
}
// Update writes data to the record whose `id` column equals id and returns data.
//
// NOTE(review), each following documented GORM behaviour:
//   - The returned value is the caller's data, not a re-read of the stored row.
//   - Updates with a struct argument skips zero-value fields.
//   - Matching no row is not reported as an error.
func (repo *GenericRepositoryGORM[T]) Update(ctx context.Context, id interface{}, data T) (T, error) {
	scoped := repo.db.WithContext(ctx).Model(new(T)).Where("id = ?", id)
	return data, handleDBError(scoped.Updates(data).Error)
}
// Delete removes the record whose `id` column equals id.
//
// Errors are translated through handleDBError, like every other method of this
// repository. Callers therefore see repositories.ErrDatabase (joined with the
// GORM error) instead of a raw GORM error.
func (repo *GenericRepositoryGORM[T]) Delete(ctx context.Context, id interface{}) error {
	err := repo.db.WithContext(ctx).Delete(new(T), "id = ?", id).Error
	return handleDBError(err)
}
// Find returns the first record of T that satisfies query. Conditions, instance
// filters, sorting and paging are applied by applyConditions. When nothing
// matches, GORM's missing-record error becomes repositories.ErrNotFound.
func (repo *GenericRepositoryGORM[T]) Find(
	ctx context.Context,
	query repositories.Query[T],
) (T, error) {
	var match T
	scoped := applyConditions(repo.db.WithContext(ctx).Model(new(T)), query)
	err := scoped.First(&match).Error
	return match, handleDBError(err)
}
// FindAll returns every record of T that satisfies query, with conditions,
// instance filters, sorting and paging applied by applyConditions. GORM's Find
// does not report a missing record, so no match yields no rows and a nil error.
func (repo *GenericRepositoryGORM[T]) FindAll(
	ctx context.Context,
	query repositories.Query[T],
) ([]T, error) {
	var matches []T
	scoped := applyConditions(repo.db.WithContext(ctx).Model(new(T)), query)
	err := scoped.Find(&matches).Error
	return matches, handleDBError(err)
}
// applyConditions translates a repositories.Query into clauses on a GORM query
// and returns the extended *gorm.DB.
//
// It adds, in order:
//   - one WHERE clause per query.Conditions entry, as "<column> <Operator> ?";
//   - one equality WHERE clause per exported, non-empty field of query.Instance
//     (a struct or a pointer to one);
//   - an ORDER BY on query.SortBy, qualified with T's table name (ascending, or
//     descending when the name is prefixed with "-");
//   - LIMIT and OFFSET when query.Limit and query.Offset are positive.
//
// Struct field names are mapped to table/column names with db's naming strategy.
//
// NOTE(review): condition.Operator and the mapped field names are interpolated
// into the SQL text; only values are bound as parameters. They must come from
// trusted code, never from user input.
func applyConditions[T any](db *gorm.DB, query repositories.Query[T]) *gorm.DB {
	tableName := db.NamingStrategy.TableName(reflect.TypeOf(*new(T)).Name())

	for _, condition := range query.Conditions {
		columnName := db.NamingStrategy.ColumnName(tableName, condition.Field)
		db = db.Where(
			fmt.Sprintf("%s %s ?", columnName, condition.Operator),
			condition.Value,
		)
	}

	if !repositories.IsEmptyValue(query.Instance) {
		// Indirect accepts a pointer instance as well; NumField/Field would
		// panic on a pointer reflect.Value.
		instance := reflect.Indirect(reflect.ValueOf(query.Instance))
		if instance.Kind() == reflect.Struct {
			instanceType := instance.Type()
			for i := 0; i < instanceType.NumField(); i++ {
				structField := instanceType.Field(i)
				// reflect's Interface() panics on unexported fields.
				if !structField.IsExported() {
					continue
				}
				fieldValue := instance.Field(i).Interface()
				if !repositories.IsEmptyValue(fieldValue) {
					columnName := db.NamingStrategy.ColumnName(tableName, structField.Name)
					db = db.Where(fmt.Sprintf("%s = ?", columnName), fieldValue)
				}
			}
		}
	}

	if query.SortBy != "" {
		sortField, direction := query.SortBy, "ASC"
		if sortField[0] == '-' {
			sortField, direction = sortField[1:], "DESC"
		}
		columnName := db.NamingStrategy.ColumnName(tableName, sortField)
		db = db.Order(fmt.Sprintf("%s.%s %s", tableName, columnName, direction))
	}

	if query.Limit > 0 {
		db = db.Limit(query.Limit)
	}
	if query.Offset > 0 {
		// Skip the first query.Offset rows. This must be Offset, not Limit: Limit
		// would overwrite the page size and never skip anything.
		db = db.Offset(query.Offset)
	}

	return db
}
package gorm
import (
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// PeerInfoGORM is the GORM-backed repository for PeerInfo records, built on
// the generic GORM repository.
type PeerInfoGORM struct {
	repositories.GenericRepository[types.PeerInfo]
}

// NewPeerInfo returns a PeerInfo repository that operates on the given GORM database.
func NewPeerInfo(db *gorm.DB) repositories.PeerInfo {
	base := NewGenericRepository[types.PeerInfo](db)
	return &PeerInfoGORM{GenericRepository: base}
}

// MachineGORM is the GORM-backed repository for Machine records, built on the
// generic GORM repository.
type MachineGORM struct {
	repositories.GenericRepository[types.Machine]
}

// NewMachine returns a Machine repository that operates on the given GORM database.
func NewMachine(db *gorm.DB) repositories.Machine {
	base := NewGenericRepository[types.Machine](db)
	return &MachineGORM{GenericRepository: base}
}

// ServicesGORM is the GORM-backed repository for Services records, built on
// the generic GORM repository.
type ServicesGORM struct {
	repositories.GenericRepository[types.Services]
}

// NewServices returns a Services repository that operates on the given GORM database.
func NewServices(db *gorm.DB) repositories.Services {
	base := NewGenericRepository[types.Services](db)
	return &ServicesGORM{GenericRepository: base}
}

// ServiceResourceRequirementsGORM is the GORM-backed repository for
// ServiceResourceRequirements records, built on the generic GORM repository.
type ServiceResourceRequirementsGORM struct {
	repositories.GenericRepository[types.ServiceResourceRequirements]
}

// NewServiceResourceRequirements returns a ServiceResourceRequirements
// repository that operates on the given GORM database.
func NewServiceResourceRequirements(db *gorm.DB) repositories.ServiceResourceRequirements {
	base := NewGenericRepository[types.ServiceResourceRequirements](db)
	return &ServiceResourceRequirementsGORM{GenericRepository: base}
}

// Libp2pInfoGORM is the GORM-backed repository for the Libp2pInfo entity,
// built on the generic GORM entity repository.
type Libp2pInfoGORM struct {
	repositories.GenericEntityRepository[types.Libp2pInfo]
}

// NewLibp2pInfo returns a Libp2pInfo repository that operates on the given GORM database.
func NewLibp2pInfo(db *gorm.DB) repositories.Libp2pInfo {
	base := NewGenericEntityRepository[types.Libp2pInfo](db)
	return &Libp2pInfoGORM{GenericEntityRepository: base}
}

// MachineUUIDGORM is the GORM-backed repository for the MachineUUID entity,
// built on the generic GORM entity repository.
type MachineUUIDGORM struct {
	repositories.GenericEntityRepository[types.MachineUUID]
}

// NewMachineUUID returns a MachineUUID repository that operates on the given GORM database.
func NewMachineUUID(db *gorm.DB) repositories.MachineUUID {
	base := NewGenericEntityRepository[types.MachineUUID](db)
	return &MachineUUIDGORM{GenericEntityRepository: base}
}

// ConnectionGORM is the GORM-backed repository for Connection records, built
// on the generic GORM repository.
type ConnectionGORM struct {
	repositories.GenericRepository[types.Connection]
}

// NewConnection returns a Connection repository that operates on the given GORM database.
func NewConnection(db *gorm.DB) repositories.Connection {
	base := NewGenericRepository[types.Connection](db)
	return &ConnectionGORM{GenericRepository: base}
}

// ElasticTokenGORM is the GORM-backed repository for ElasticToken records,
// built on the generic GORM repository.
type ElasticTokenGORM struct {
	repositories.GenericRepository[types.ElasticToken]
}

// NewElasticToken returns an ElasticToken repository that operates on the given GORM database.
func NewElasticToken(db *gorm.DB) repositories.ElasticToken {
	base := NewGenericRepository[types.ElasticToken](db)
	return &ElasticTokenGORM{GenericRepository: base}
}
package gorm
import (
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// OnboardingConfigGORM is the GORM-backed repository for the OnboardingConfig
// entity, built on the generic GORM entity repository.
type OnboardingConfigGORM struct {
	repositories.GenericEntityRepository[types.OnboardingConfig]
}

// NewOnboardingConfig returns an OnboardingConfig repository that operates on
// the given GORM database.
func NewOnboardingConfig(db *gorm.DB) repositories.OnboardingConfig {
	base := NewGenericEntityRepository[types.OnboardingConfig](db)
	return &OnboardingConfigGORM{GenericEntityRepository: base}
}
package gorm
import (
"gitlab.com/nunet/device-management-service/types"
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
)
// MachineResourcesGORM is the GORM-backed repository for the MachineResources
// entity, built on the generic GORM entity repository.
type MachineResourcesGORM struct {
	repositories.GenericEntityRepository[types.MachineResources]
}

// NewMachineResources returns a MachineResources repository that operates on
// the given GORM database.
func NewMachineResources(db *gorm.DB) repositories.MachineResources {
	base := NewGenericEntityRepository[types.MachineResources](db)
	return &MachineResourcesGORM{GenericEntityRepository: base}
}

// FreeResourcesGORM is the GORM-backed repository for the FreeResources entity,
// built on the generic GORM entity repository.
type FreeResourcesGORM struct {
	repositories.GenericEntityRepository[types.FreeResources]
}

// NewFreeResources returns a FreeResources repository that operates on the
// given GORM database.
func NewFreeResources(db *gorm.DB) repositories.FreeResources {
	base := NewGenericEntityRepository[types.FreeResources](db)
	return &FreeResourcesGORM{GenericEntityRepository: base}
}

// OnboardedResourcesGORM is the GORM-backed repository for the
// OnboardedResources entity, built on the generic GORM entity repository.
type OnboardedResourcesGORM struct {
	repositories.GenericEntityRepository[types.OnboardedResources]
}

// NewOnboardedResources returns an OnboardedResources repository that operates
// on the given GORM database.
func NewOnboardedResources(db *gorm.DB) repositories.OnboardedResources {
	base := NewGenericEntityRepository[types.OnboardedResources](db)
	return &OnboardedResourcesGORM{GenericEntityRepository: base}
}

// ResourceAllocationGORM is the GORM-backed repository for ResourceAllocation
// records, built on the generic GORM repository.
type ResourceAllocationGORM struct {
	repositories.GenericRepository[types.ResourceAllocation]
}

// NewResourceAllocation returns a ResourceAllocation repository that operates
// on the given GORM database.
func NewResourceAllocation(db *gorm.DB) repositories.ResourceAllocation {
	base := NewGenericRepository[types.ResourceAllocation](db)
	return &ResourceAllocationGORM{GenericRepository: base}
}
package gorm
import (
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// StorageVolumeGORM is the GORM-backed repository for StorageVolume records,
// built on the generic GORM repository.
type StorageVolumeGORM struct {
	repositories.GenericRepository[types.StorageVolume]
}

// NewStorageVolume returns a StorageVolume repository that operates on the
// given GORM database.
func NewStorageVolume(db *gorm.DB) repositories.StorageVolume {
	base := NewGenericRepository[types.StorageVolume](db)
	return &StorageVolumeGORM{GenericRepository: base}
}
package gorm
import (
"errors"
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
)
// handleDBError translates GORM errors into the repository-level sentinel errors.
//
// A nil input yields nil. The mapping is:
//   - gorm.ErrRecordNotFound becomes repositories.ErrNotFound.
//   - gorm.ErrInvalidData, ErrInvalidField and ErrInvalidValue become
//     repositories.ErrInvalidData.
//   - repositories.ErrParsingModel (including wrapped occurrences) is passed
//     through unchanged.
//   - Any other error is joined with repositories.ErrDatabase, so callers can
//     match either with errors.Is.
//
// errors.Is is used instead of == so wrapped GORM errors are still recognised.
func handleDBError(err error) error {
	switch {
	case err == nil:
		return nil
	case errors.Is(err, gorm.ErrRecordNotFound):
		return repositories.ErrNotFound
	case errors.Is(err, gorm.ErrInvalidData),
		errors.Is(err, gorm.ErrInvalidField),
		errors.Is(err, gorm.ErrInvalidValue):
		return repositories.ErrInvalidData
	case errors.Is(err, repositories.ErrParsingModel):
		return err
	default:
		return errors.Join(repositories.ErrDatabase, err)
	}
}
package repositories
import (
"fmt"
"reflect"
)
// UpdateField sets the field fieldName of a struct (or of the struct a pointer
// points to) to newValue and returns the updated input.
//
// When input is a struct value, a modified copy is returned and the caller's
// value is untouched. When it is a pointer, the pointed-to struct is modified in
// place and the same pointer is returned.
//
// newValue is converted to the field's type using reflect's conversion rules.
// Those rules follow Go's, so e.g. int to int64 works, and an integer converts
// to a string holding that rune. A nil newValue clears fields of nilable kinds
// (pointer, interface, map, slice, chan, func) and is rejected for all others.
//
// An error is returned, with input unchanged, when:
//   - input is not a struct or a pointer to a struct;
//   - the field does not exist or cannot be set (e.g. it is unexported);
//   - the field is only reachable through a nil embedded pointer;
//   - newValue cannot be converted to the field's type.
func UpdateField[T interface{}](input T, fieldName string, newValue interface{}) (T, error) {
	val := reflect.ValueOf(input)
	if val.Kind() == reflect.Ptr {
		// Modify the pointee directly so the change is visible through input.
		val = val.Elem()
	} else {
		// Work on an addressable copy: fields of a non-addressable value can't be set.
		val = reflect.ValueOf(&input).Elem()
	}

	if val.Kind() != reflect.Struct {
		return input, fmt.Errorf("not a struct: %T", input)
	}

	structField, ok := val.Type().FieldByName(fieldName)
	if !ok {
		return input, fmt.Errorf("field not found: %v", fieldName)
	}
	// FieldByIndexErr reports (instead of panicking on) a promoted field behind a
	// nil embedded pointer.
	field, err := val.FieldByIndexErr(structField.Index)
	if err != nil {
		return input, fmt.Errorf("field not settable: %v: %w", fieldName, err)
	}
	if !field.CanSet() {
		return input, fmt.Errorf("field not settable: %v", fieldName)
	}

	if newValue == nil {
		// reflect.TypeOf(nil) is a nil Type, so untyped nil needs explicit handling.
		switch field.Kind() {
		case reflect.Ptr, reflect.Interface, reflect.Map, reflect.Slice, reflect.Chan, reflect.Func:
			field.Set(reflect.Zero(field.Type()))
			return input, nil
		default:
			return input, fmt.Errorf("incompatible conversion: nil -> %v", field.Type())
		}
	}

	src := reflect.ValueOf(newValue)
	if !src.Type().ConvertibleTo(field.Type()) {
		return input, fmt.Errorf(
			"incompatible conversion: %v -> %v; value: %v",
			src.Type(), field.Type(), newValue,
		)
	}

	field.Set(src.Convert(field.Type()))
	return input, nil
}
// IsEmptyValue reports whether value is empty. Empty means one of:
//   - the nil interface;
//   - a nil pointer;
//   - the zero value of its type;
//   - for a non-nil pointer, the zero value of the pointed-to type.
//
// For structs, a zero value means every field holds its zero value.
func IsEmptyValue(value interface{}) bool {
	if value == nil {
		return true
	}

	val := reflect.ValueOf(value)
	if val.Kind() == reflect.Ptr {
		// A typed nil pointer has no element: Elem() would return an invalid
		// reflect.Value, on which IsZero panics. Treat it as empty instead.
		if val.IsNil() {
			return true
		}
		val = val.Elem()
	}

	return val.IsZero()
}
package crypto
import (
"crypto/ecdsa"
"errors"
"github.com/ethereum/go-ethereum/common/hexutil"
"github.com/ethereum/go-ethereum/crypto"
"github.com/fivebinaries/go-cardano-serialization/address"
"github.com/fivebinaries/go-cardano-serialization/bip32"
"github.com/fivebinaries/go-cardano-serialization/network"
"github.com/tyler-smith/go-bip39"
"gitlab.com/nunet/device-management-service/types"
)
// GetEthereumAddressAndPrivateKey generates a new key pair with go-ethereum's
// crypto.GenerateKey. It returns an Account containing:
//   - the address derived from the public key, in common.Address.Hex form;
//   - the private key bytes encoded with hexutil.Encode.
//
// Key generation failures are returned unchanged.
func GetEthereumAddressAndPrivateKey() (*types.Account, error) {
	key, err := crypto.GenerateKey()
	if err != nil {
		return nil, err
	}

	pub, ok := key.Public().(*ecdsa.PublicKey)
	if !ok {
		return nil, errors.New("publicKey is not of type *ecdsa.PublicKey")
	}

	// Named walletAddr so it does not shadow the imported cardano `address` package.
	walletAddr := crypto.PubkeyToAddress(*pub).Hex()
	return &types.Account{
		Address:    walletAddr,
		PrivateKey: hexutil.Encode(crypto.FromECDSA(key)),
	}, nil
}
// harden returns the hardened form of a BIP-32 child index, i.e. num + 2^31,
// with ordinary uint32 wrap-around.
func harden(num uint32) uint32 {
	const hardenedOffset uint32 = 1 << 31
	return num + hardenedOffset
}
// GetCardanoAddressAndMnemonic generates a new BIP-39 mnemonic from 256 bits of
// entropy. From that entropy it derives a Cardano mainnet base address.
//
// Keys are derived along m/1852'/1815'/0':
//   - the payment key is account/0/0;
//   - the stake key is account/2/0.
//
// The returned Account carries only the mnemonic and the address string; no
// private key is stored in it.
//
// Errors from entropy or mnemonic generation are returned instead of being
// ignored. Ignoring them would derive keys from nil entropy.
func GetCardanoAddressAndMnemonic() (*types.Account, error) {
	entropy, err := bip39.NewEntropy(256)
	if err != nil {
		return nil, err
	}
	mnemonic, err := bip39.NewMnemonic(entropy)
	if err != nil {
		return nil, err
	}

	rootKey := bip32.FromBip39Entropy(
		entropy,
		[]byte{},
	)
	accountKey := rootKey.Derive(harden(1852)).Derive(harden(1815)).Derive(harden(0))

	utxoPubKey := accountKey.Derive(0).Derive(0).Public()
	utxoPubKeyHash := utxoPubKey.PublicKey().Hash()

	stakeKey := accountKey.Derive(2).Derive(0).Public()
	stakeKeyHash := stakeKey.PublicKey().Hash()

	addr := address.NewBaseAddress(
		network.MainNet(),
		&address.StakeCredential{
			Kind:    address.KeyStakeCredentialType,
			Payload: utxoPubKeyHash[:],
		},
		&address.StakeCredential{
			Kind:    address.KeyStakeCredentialType,
			Payload: stakeKeyHash[:],
		})

	return &types.Account{
		Mnemonic: mnemonic,
		Address:  addr.String(),
	}, nil
}
package crypto
import (
"fmt"
"gitlab.com/nunet/device-management-service/types"
)
// CreatePaymentAddress generates a new account for the given wallet type.
//   - "ethereum" produces an address and a private key.
//   - "cardano" produces an address and a mnemonic.
//
// Any other value yields an "invalid wallet" error. Generation failures are
// wrapped with the wallet name.
func CreatePaymentAddress(wallet string) (*types.Account, error) {
	var generate func() (*types.Account, error)
	switch wallet {
	case "ethereum":
		generate = GetEthereumAddressAndPrivateKey
	case "cardano":
		generate = GetCardanoAddressAndMnemonic
	default:
		return nil, fmt.Errorf("invalid wallet")
	}

	account, err := generate()
	if err != nil {
		return nil, fmt.Errorf("could not generate %s address: %w", wallet, err)
	}
	return account, nil
}
package dms
import (
"errors"
"fmt"
"os"
"path/filepath"
"time"
"gitlab.com/nunet/device-management-service/dms/hardware"
"github.com/libp2p/go-libp2p/core/crypto"
"github.com/multiformats/go-multiaddr"
"github.com/spf13/afero"
"gitlab.com/nunet/device-management-service/api"
"gitlab.com/nunet/device-management-service/db"
gdb "gitlab.com/nunet/device-management-service/db/repositories/gorm"
"gitlab.com/nunet/device-management-service/dms/node"
"gitlab.com/nunet/device-management-service/dms/onboarding"
"gitlab.com/nunet/device-management-service/dms/resources"
"gitlab.com/nunet/device-management-service/internal"
backgroundtasks "gitlab.com/nunet/device-management-service/internal/background_tasks"
"gitlab.com/nunet/device-management-service/internal/config"
"gitlab.com/nunet/device-management-service/lib/crypto/keystore"
"gitlab.com/nunet/device-management-service/lib/did"
"gitlab.com/nunet/device-management-service/lib/ucan"
"gitlab.com/nunet/device-management-service/network/libp2p"
"gitlab.com/nunet/device-management-service/telemetry/logger"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils"
)
const (
// DefaultContextName is the key/capability context Run falls back to when it is given an empty context name.
DefaultContextName = "dms"
// UserContextName names the user key/capability context.
// NOTE(review): not referenced in the visible part of this file.
UserContextName = "user"
// KeystoreDir is the keystore subdirectory, joined onto the configured General.UserDir.
KeystoreDir = "key/"
// CapstoreDir is the capability-store subdirectory under General.UserDir; Run keeps "<context>.cap" files there.
CapstoreDir = "cap/"
)
// NewP2P is a stub that returns a zero-value libp2p.Libp2p. A real
// implementation is needed so it can be passed to routers, which access it in
// some handlers. The returned value is not configured, initialized or started;
// Run performs that setup for the real node.
func NewP2P() libp2p.Libp2p {
return libp2p.Libp2p{}
}
// Run initializes and starts a DMS instance, then blocks until shutdown.
//
// It loads the private key stored under contextName in the keystore (or
// generates and stores a new one when none exists), connects the database,
// builds the hardware/resource managers and the onboarding service, starts the
// libp2p network, loads (or creates and persists) the capability context kept
// under the user directory, starts the node, and serves the REST API in the
// background.
//
// It then waits on internal.ShutdownChan: the first signal stops the node and
// the network in the background and exits with status 0; a second signal
// received before that finishes exits immediately with status 1. Run
// therefore only returns to its caller when initialization fails.
//
// An empty contextName selects DefaultContextName.
//
// QUESTION(dms-initialization): should the db instance be constructed here?
func Run(ksPassphrase string, contextName string) error {
	if contextName == "" {
		contextName = DefaultContextName
	}
	fs := afero.NewOsFs()

	// Identity: the node's private key lives in the keystore.
	keyStoreDir := filepath.Join(config.GetConfig().General.UserDir, KeystoreDir)
	keyStore, err := keystore.New(fs, keyStoreDir)
	if err != nil {
		return fmt.Errorf("unable to create keystore: %w", err)
	}
	var priv crypto.PrivKey
	ksPrivKey, err := keyStore.Get(contextName, ksPassphrase)
	switch {
	case err == nil:
		priv, err = ksPrivKey.PrivKey()
		if err != nil {
			return fmt.Errorf("unable to convert key from keystore to private key: %w", err)
		}
	case errors.Is(err, keystore.ErrKeyNotFound):
		// First run for this context: mint and persist a new identity.
		priv, err = GenerateAndStorePrivKey(keyStore, ksPassphrase, contextName)
		if err != nil {
			return fmt.Errorf("couldn't generate and store priv key into keystore: %w", err)
		}
	default:
		return fmt.Errorf("failed to get private key from keystore; Error: %w", err)
	}
	pubKey := priv.GetPublic()

	// Persistence and local resources. The connection is named dbConn so the
	// db package is not shadowed for the rest of the function.
	dbConn, err := db.ConnectDatabase(config.GetConfig().General.WorkDir)
	if err != nil {
		return fmt.Errorf("unable to connect to database: %w", err)
	}
	hardwareManager := hardware.NewHardwareManager()
	repos := resources.ManagerRepos{
		FreeResources:      gdb.NewFreeResources(dbConn),
		OnboardedResources: gdb.NewOnboardedResources(dbConn),
		ResourceAllocation: gdb.NewResourceAllocation(dbConn),
	}
	resourceManager, err := resources.NewResourceManager(repos, hardwareManager)
	if err != nil {
		return fmt.Errorf("unable to create resource manager: %w", err)
	}
	onboardR := gdb.NewOnboardingConfig(dbConn)
	p2pR := gdb.NewLibp2pInfo(dbConn)
	uuidR := gdb.NewMachineUUID(dbConn)
	onboard := onboarding.New(&onboarding.Config{
		Fs:              afero.Afero{Fs: fs},
		ConfigRepo:      onboardR,
		P2PRepo:         p2pR,
		UUIDRepo:        uuidR,
		Hardware:        hardwareManager,
		ResourceManager: resourceManager,
		WorkDir:         config.GetConfig().WorkDir,
		DatabasePath:    fmt.Sprintf("%s/nunet.db", config.GetConfig().General.WorkDir),
	})

	// Networking.
	bootstrapPeers := make([]multiaddr.Multiaddr, 0, len(config.GetConfig().P2P.BootstrapPeers))
	for _, addr := range config.GetConfig().P2P.BootstrapPeers {
		maddr, err := multiaddr.NewMultiaddr(addr)
		if err != nil {
			// Reject bad config up front instead of handing libp2p a nil
			// address that would only fail later and far less clearly.
			return fmt.Errorf("invalid bootstrap peer address %q: %w", addr, err)
		}
		bootstrapPeers = append(bootstrapPeers, maddr)
	}
	gcfg := config.GetConfig()
	cfg := &types.Libp2pConfig{
		PrivateKey:              priv,
		BootstrapPeers:          bootstrapPeers,
		Rendezvous:              "nunet-test",
		Server:                  false,
		Scheduler:               backgroundtasks.NewScheduler(10),
		CustomNamespace:         "/nunet-dht-1/",
		ListenAddress:           gcfg.P2P.ListenAddress,
		PeerCountDiscoveryLimit: 40,
		Memory:                  gcfg.P2P.Memory,
		FileDescriptors:         gcfg.P2P.FileDescriptors,
	}
	p2p, err := libp2p.New(cfg, fs)
	if err != nil {
		return fmt.Errorf("unable to create libp2p instance: %w", err)
	}
	if err = p2p.Init(); err != nil {
		return fmt.Errorf("unable to initialize libp2p: %w", err)
	}
	if err = p2p.Start(); err != nil {
		return fmt.Errorf("unable to start libp2p: %w", err)
	}

	// Trust and capabilities.
	trustCtx, err := did.NewTrustContextWithPrivateKey(priv)
	if err != nil {
		return fmt.Errorf("unable to create trust context: %w", err)
	}
	capStoreDir := filepath.Join(config.GetConfig().General.UserDir, CapstoreDir)
	capStoreFile := filepath.Join(capStoreDir, fmt.Sprintf("%s.cap", contextName))
	var capCtx ucan.CapabilityContext
	_, statErr := os.Stat(capStoreFile)
	switch {
	case statErr == nil:
		// Existing context: load it.
		f, err := os.Open(capStoreFile)
		if err != nil {
			return fmt.Errorf("unable to open capability context: %w", err)
		}
		capCtx, err = ucan.LoadCapabilityContext(trustCtx, f)
		_ = f.Close() // read-only handle; nothing to flush
		if err != nil {
			return fmt.Errorf("unable to load capability context: %w", err)
		}
	case errors.Is(statErr, os.ErrNotExist):
		// No context yet: create one rooted at our own DID and persist it.
		if err := fs.MkdirAll(capStoreDir, os.FileMode(0o700)); err != nil {
			return fmt.Errorf("unable to create capability context directory: %w", err)
		}
		rootDID := did.FromPublicKey(pubKey)
		capCtx, err = ucan.NewCapabilityContext(trustCtx, rootDID, nil, ucan.TokenList{}, ucan.TokenList{})
		if err != nil {
			return fmt.Errorf("unable to create capability context: %w", err)
		}
		// 0600: the file holds this node's capability tokens.
		f, err := os.OpenFile(capStoreFile, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0o600)
		if err != nil {
			return fmt.Errorf("unable to create capability context file: %w", err)
		}
		saveErr := ucan.SaveCapabilityContext(capCtx, f)
		// On the write path a failed Close can mean data never reached disk.
		closeErr := f.Close()
		if saveErr != nil {
			return fmt.Errorf("unable to save capability context: %w", saveErr)
		}
		if closeErr != nil {
			return fmt.Errorf("unable to save capability context: %w", closeErr)
		}
	default:
		// Any other stat failure (e.g. permissions) must not be treated as
		// "missing": re-creating the file would truncate existing tokens.
		return fmt.Errorf("unable to stat capability context file: %w", statErr)
	}
	trustCtx.Start(time.Hour)
	capCtx.Start(5 * time.Minute)

	// Node.
	hostID := p2p.Host.ID().String()
	dmsNode, err := node.New(onboard, capCtx, hostID, p2p, resourceManager, cfg.Scheduler, hardwareManager)
	if err != nil {
		return fmt.Errorf("failed to create node: %w", err)
	}
	if err := dmsNode.Start(); err != nil {
		return fmt.Errorf("failed to start node: %w", err)
	}

	// REST API server.
	restConfig := api.RESTServerConfig{
		P2P:        p2p,
		Onboarding: onboard,
		Logger:     logger.New("rest-server"),
		Resource:   resourceManager,
		MidW:       nil,
		Port:       config.GetConfig().Rest.Port,
		Addr:       config.GetConfig().Rest.Addr,
	}
	rServer := api.NewRESTServer(&restConfig)
	rServer.InitializeRoutes()
	go func() {
		if err := rServer.Run(); err != nil {
			log.Fatal(err)
		}
	}()

	// Wait for the first SIGINT/SIGTERM.
	sig := <-internal.ShutdownChan

	// Graceful cleanup runs in the background so that a second signal can
	// still force an immediate exit below. sig is not written again after
	// this point, so the goroutine may read it safely.
	go func() {
		if err := dmsNode.Stop(); err != nil {
			log.Errorf("failed to stop node: %s", err)
		}
		if err := p2p.Stop(); err != nil {
			log.Errorf("failed to stop libp2p network: %s", err)
		}
		log.Infof("Shutting down after receiving %v...\n", sig)
		os.Exit(0)
	}()

	forced := <-internal.ShutdownChan
	log.Infof("Shutting down after receiving %v...\n", forced)
	os.Exit(1)
	return nil // unreachable; required because os.Exit is not a terminating statement
}
// GenerateAndStorePrivKey generates a new Ed25519 key pair, marshals the
// private key, and saves it into ks under keyID using passphrase.
//
// It returns the generated private key. If generation, marshaling or saving
// fails, it returns a nil key and a wrapped error.
func GenerateAndStorePrivKey(ks keystore.KeyStore, passphrase string, keyID string) (crypto.PrivKey, error) {
	// The bits argument is only consulted for RSA keys in go-libp2p; it has
	// no effect on Ed25519.
	priv, _, err := crypto.GenerateKeyPair(crypto.Ed25519, 256)
	if err != nil {
		return nil, fmt.Errorf("unable to generate key pair: %w", err)
	}
	rawPriv, err := crypto.MarshalPrivateKey(priv)
	if err != nil {
		return nil, fmt.Errorf("unable to marshal private key: %w", err)
	}
	if _, err := ks.Save(keyID, rawPriv, passphrase); err != nil {
		return nil, fmt.Errorf("unable to save private key into the keystore: %w", err)
	}
	return priv, nil
}
// ValidateOnboarding checks an onboarding configuration and logs an error
// for each check that fails. It neither returns the failure nor terminates
// the process; callers that need to stop must do so themselves.
//
// Checks performed:
//  1. oConf.PublicKey must be a valid payment address (utils.ValidateAddress).
func ValidateOnboarding(oConf *types.OnboardingConfig) {
	if oConf == nil {
		log.Error("onboarding config is nil")
		return
	}
	// Check 1: Check if payment address is valid
	if err := utils.ValidateAddress(oConf.PublicKey); err != nil {
		log.Errorf("the payment address %s is not valid: %v", oConf.PublicKey, err)
		// NOTE(review): this function returns normally after logging this
		// message; it does not exit DMS itself.
		log.Error("exiting DMS")
	}
}
package cpu
import (
"fmt"
"github.com/shirou/gopsutil/v4/cpu"
"gitlab.com/nunet/device-management-service/types"
)
// GetUsage returns the CPU usage for the system as a types.CPU.
//
// Cores holds the number of cores in use: the total core count from GetCPU
// scaled by the system-wide busy percentage from gopsutil. The other fields
// are copied from GetCPU unchanged.
func GetUsage() (types.CPU, error) {
	// An interval of 0 makes gopsutil compare against the previous call
	// instead of sleeping for a sampling window.
	busy, err := cpu.Percent(0, false)
	if err != nil {
		return types.CPU{}, fmt.Errorf("failed to get CPU usage: %w", err)
	}
	if len(busy) == 0 {
		return types.CPU{}, fmt.Errorf("failed to get CPU usage: no samples returned")
	}
	info, err := GetCPU()
	if err != nil {
		return types.CPU{}, fmt.Errorf("failed to get CPU info: %w", err)
	}
	usedCores := float64(info.Cores) * busy[0] / 100
	info.Cores = float32(usedCores)
	return info, nil
}
package cpu
import (
"fmt"
"github.com/shirou/gopsutil/v4/cpu"
"gitlab.com/nunet/device-management-service/types"
)
// GetCPU returns the CPU information for the system
func GetCPU() (types.CPU, error) {
cores, err := cpu.Info()
if err != nil {
return types.CPU{}, fmt.Errorf("failed to get CPU info: %s", err)
}
var totalCompute float64
for i := 0; i < len(cores); i++ {
totalCompute += cores[i].Mhz
}
return types.CPU{
Cores: float32(len(cores)),
ClockSpeed: cores[0].Mhz * 1000000,
}, nil
}
package hardware
import (
"context"
"fmt"
"github.com/shirou/gopsutil/v4/disk"
"gitlab.com/nunet/device-management-service/types"
)
// sumPartitionUsage sums pick(usage) over the physical partitions reported by
// gopsutil (PartitionsWithContext with all=false), counting each non-empty
// device only once.
func sumPartitionUsage(pick func(*disk.UsageStat) uint64) (uint64, error) {
	ctx := context.Background()
	partitions, err := disk.PartitionsWithContext(ctx, false)
	if err != nil {
		return 0, fmt.Errorf("failed to get partitions: %w", err)
	}
	seenDevices := make(map[string]struct{}, len(partitions))
	var total uint64
	for _, p := range partitions {
		// One device can be mounted at several mountpoints (bind mounts,
		// btrfs subvolumes). Each of those mounts reports the same filesystem
		// totals, so only the first mount of a device is counted.
		if p.Device != "" {
			if _, dup := seenDevices[p.Device]; dup {
				continue
			}
			seenDevices[p.Device] = struct{}{}
		}
		usage, err := disk.UsageWithContext(ctx, p.Mountpoint)
		if err != nil {
			return 0, fmt.Errorf("failed to get disk usage: %w", err)
		}
		total += pick(usage)
	}
	return total, nil
}

// GetDisk returns the total storage capacity of the system's physical
// partitions, in bytes, as types.Disk.Size.
func GetDisk() (types.Disk, error) {
	total, err := sumPartitionUsage(func(u *disk.UsageStat) uint64 { return u.Total })
	if err != nil {
		return types.Disk{}, err
	}
	return types.Disk{
		Size: float64(total),
	}, nil
}

// GetDiskUsage returns the used storage of the system's physical partitions,
// in bytes, as types.Disk.Size.
func GetDiskUsage() (types.Disk, error) {
	used, err := sumPartitionUsage(func(u *disk.UsageStat) uint64 { return u.Used })
	if err != nil {
		return types.Disk{}, err
	}
	return types.Disk{
		Size: float64(used),
	}, nil
}
package gpu
import (
"fmt"
"os/exec"
"regexp"
"strconv"
"gitlab.com/nunet/device-management-service/types"
)
// runROCmSmiCommand runs `rocm-smi --showid --showproductname --showmeminfo vram`
// and returns its combined stdout and stderr.
//
// If the command cannot be run or exits non-zero, the underlying error is
// wrapped with %w, so callers can inspect it with errors.Is / errors.As
// (e.g. exec.ErrNotFound when rocm-smi is not installed).
func runROCmSmiCommand() (string, error) {
	out, err := exec.Command("rocm-smi", "--showid", "--showproductname", "--showmeminfo", "vram").CombinedOutput()
	if err != nil {
		return "", fmt.Errorf("AMD ROCm not installed, initialized, or configured (reboot recommended for newly installed AMD GPU Drivers): %w", err)
	}
	return string(out), nil
}
// convertToGB converts memory from bytes (as a string) to GB.
func convertToGB(memoryBytesStr string) (int64, error) {
memoryBytes, err := strconv.ParseInt(memoryBytesStr, 10, 64)
if err != nil {
return 0, fmt.Errorf("parse memory value: %s", err)
}
return memoryBytes / (1024 * 1024 * 1024), nil
}
// parseRegex compiles pattern and returns every non-overlapping match in
// output. Each match is the full matched text followed by its capture groups.
// It returns nil when nothing matches and panics if pattern is invalid.
func parseRegex(pattern, output string) [][]string {
	re := regexp.MustCompile(pattern)
	matches := re.FindAllStringSubmatch(output, -1)
	return matches
}
// getAMDGPUTotalVRAM extracts each GPU's "VRAM Total Memory (B)" value from
// rocm-smi output. It returns the values in whole GiB (via convertToGB), one
// entry per match, in output order.
//
// NOTE(review): the other vendors report VRAM in different units (NVIDIA
// bytes, Intel MiB) — confirm which unit types.GPU.VRAM is meant to carry.
func getAMDGPUTotalVRAM(output string) ([]int64, error) {
	matches := parseRegex(`GPU\[\d+\]\s+: VRAM Total Memory \(B\):\s+(\d+)`, output)
	if len(matches) == 0 {
		return nil, fmt.Errorf("find total VRAM in the output")
	}
	totals := make([]int64, len(matches))
	for i, m := range matches {
		totalGiB, err := convertToGB(m[1])
		if err != nil {
			return nil, fmt.Errorf("parse total VRAM for GPU %d: %w", i, err)
		}
		totals[i] = totalGiB
	}
	return totals, nil
}
// getAMDGPUUsedVRAM extracts each GPU's "VRAM Total Used Memory (B)" value
// from rocm-smi output. It returns the values in whole GiB (via
// convertToGB), one entry per match, in output order.
func getAMDGPUUsedVRAM(output string) ([]int64, error) {
	matches := parseRegex(`GPU\[\d+\]\s+: VRAM Total Used Memory \(B\):\s+(\d+)`, output)
	if len(matches) == 0 {
		return nil, fmt.Errorf("find used VRAM in the output")
	}
	used := make([]int64, len(matches))
	for i, m := range matches {
		usedGiB, err := convertToGB(m[1])
		if err != nil {
			return nil, fmt.Errorf("parse used VRAM for GPU %d: %w", i, err)
		}
		used[i] = usedGiB
	}
	return used, nil
}
// getAMDGPUName returns the "Card Series" text rocm-smi prints for each GPU,
// in output order, up to the end of the line. It errors when none is found.
func getAMDGPUName(output string) ([]string, error) {
	matches := parseRegex(`GPU\[\d+\]\s+: Card Series:\s+([^\n]+)`, output)
	if len(matches) == 0 {
		return nil, fmt.Errorf("find GPU names in the output")
	}
	names := make([]string, 0, len(matches))
	for _, m := range matches {
		names = append(names, m[1])
	}
	return names, nil
}
// getAMDGPUs returns, for each AMD GPU rocm-smi reports, its model name,
// total VRAM (whole GiB, see getAMDGPUTotalVRAM) and the PCI address from
// metadata, paired by position.
//
// It fails instead of panicking when metadata or the VRAM list has fewer
// entries than the name list.
func getAMDGPUs(metadata []types.GPUMetadata) ([]types.GPU, error) {
	output, err := runROCmSmiCommand()
	if err != nil {
		return nil, err
	}
	names, err := getAMDGPUName(output)
	if err != nil {
		return nil, err
	}
	totalVRAMs, err := getAMDGPUTotalVRAM(output)
	if err != nil {
		return nil, err
	}
	// Entries are paired by index below; refuse to index past either list.
	if len(metadata) < len(names) || len(totalVRAMs) < len(names) {
		return nil, fmt.Errorf("failed to find AMD GPU information for all GPUs")
	}
	gpuInfos := make([]types.GPU, 0, len(names))
	for i, name := range names {
		gpuInfos = append(gpuInfos, types.GPU{
			PCIAddress: metadata[i].PCIAddress,
			Model:      name,
			VRAM:       uint64(totalVRAMs[i]),
			Vendor:     types.GPUVendorAMDATI,
		})
	}
	return gpuInfos, nil
}
// getAMDGPUUsage returns, for each AMD GPU rocm-smi reports, its model name,
// used VRAM (whole GiB, see getAMDGPUUsedVRAM) and vendor. The metadata
// argument is ignored, so PCIAddress is left empty.
//
// It fails instead of panicking when rocm-smi reports fewer names than
// used-VRAM entries.
func getAMDGPUUsage(_ []types.GPUMetadata) ([]types.GPU, error) {
	output, err := runROCmSmiCommand()
	if err != nil {
		return nil, err
	}
	names, err := getAMDGPUName(output)
	if err != nil {
		return nil, err
	}
	usedVRAMs, err := getAMDGPUUsedVRAM(output)
	if err != nil {
		return nil, err
	}
	// Entries are paired by index below; refuse to index past the name list.
	if len(names) < len(usedVRAMs) {
		return nil, fmt.Errorf("failed to find AMD GPU information for all GPUs")
	}
	gpus := make([]types.GPU, 0, len(usedVRAMs))
	for i, used := range usedVRAMs {
		gpus = append(gpus, types.GPU{
			Model:  names[i],
			VRAM:   uint64(used),
			Vendor: types.GPUVendorAMDATI, // matches the NVIDIA/Intel usage getters
		})
	}
	return gpus, nil
}
package gpu
import (
"fmt"
"sync"
"github.com/jaypipes/ghw"
"gitlab.com/nunet/device-management-service/types"
)
// Package-level cache used by fetchGPUMetadata.
var (
// metadata holds the discovered GPUs' PCI addresses grouped by vendor.
metadata map[types.GPUVendor][]types.GPUMetadata
// mu is held by fetchGPUMetadata while it populates metadata.
mu sync.Mutex
)
// fetchGPUMetadata discovers the machine's graphics cards with ghw and groups
// their PCI addresses by vendor. Cards without device info are skipped.
//
// The result is cached in the package-level metadata variable after the
// first successful discovery and returned as-is afterwards. A failed
// discovery is not cached, so a later call retries. The cache check and the
// discovery both run under mu, so concurrent callers never race on metadata.
// Callers share the returned map and must not modify it.
//
// TODO: Use one single library to fetch GPU information or improve the match criteria
// https://gitlab.com/nunet/device-management-service/-/issues/548
// TODO: write tests by mocking the gpu snapshot
// https://gitlab.com/nunet/device-management-service/-/issues/534
func fetchGPUMetadata() (map[types.GPUVendor][]types.GPUMetadata, error) {
	mu.Lock()
	defer mu.Unlock()

	if metadata != nil {
		return metadata, nil
	}

	gpuInfo, err := ghw.GPU()
	if err != nil {
		return nil, err
	}

	// Build into a local map and publish only on success, so an error never
	// leaves a half-filled (or empty) map cached.
	discovered := make(map[types.GPUVendor][]types.GPUMetadata)
	for _, card := range gpuInfo.GraphicsCards {
		if card.DeviceInfo == nil {
			continue
		}
		vendor := types.ParseGPUVendor(card.DeviceInfo.Vendor.Name)
		discovered[vendor] = append(discovered[vendor], types.GPUMetadata{PCIAddress: card.Address})
	}
	metadata = discovered
	return metadata, nil
}
// GetGPUs returns the GPUs (model, total VRAM, vendor and, where available,
// PCI address) of the given vendors. With no vendors it queries every
// supported vendor: Intel, NVIDIA and AMD. Results get dms-internal indexes
// via assignIndexToGPUs.
func GetGPUs(vendors ...types.GPUVendor) ([]types.GPU, error) {
	inventoryFetchers := map[types.GPUVendor]func(metadata []types.GPUMetadata) ([]types.GPU, error){
		types.GPUVendorIntel:  getIntelGPUs,
		types.GPUVendorNvidia: getNVIDIAGPUs,
		types.GPUVendorAMDATI: getAMDGPUs,
	}
	return getGPUsHelper(fetchGPUMetadata, assignIndexToGPUs, inventoryFetchers, vendors...)
}
// GetGPUUsage returns per-GPU usage (VRAM field holds used memory) for the
// given vendors. With no vendors it queries every supported vendor: Intel,
// NVIDIA and AMD. Results get dms-internal indexes via assignIndexToGPUs.
func GetGPUUsage(vendors ...types.GPUVendor) ([]types.GPU, error) {
	usageFetchers := map[types.GPUVendor]func(metadata []types.GPUMetadata) ([]types.GPU, error){
		types.GPUVendorIntel:  getIntelGPUUsage,
		types.GPUVendorNvidia: getNVIDIAGPUUsage,
		types.GPUVendorAMDATI: getAMDGPUUsage,
	}
	return getGPUsHelper(fetchGPUMetadata, assignIndexToGPUs, usageFetchers, vendors...)
}
// getGPUsHelper is the shared implementation of GetGPUs and GetGPUUsage.
//
// It loads GPU metadata with fetchMetadata. It then calls the fetcher from
// fetchFuncs for each requested vendor that has metadata (every vendor in
// fetchFuncs when vendors is empty), concatenates the results and passes
// them through assignFunc.
//
// Vendors without metadata, and vendors whose fetcher fails, are skipped
// silently. An explicitly requested vendor missing from fetchFuncs is an
// error.
//
// When all vendors are queried they are visited in a deterministic order
// (sorted by their printed form). Repeated calls — such as one via GetGPUs
// and one via GetGPUUsage — therefore list the same hardware in the same
// order and receive the same indexes from assignFunc.
func getGPUsHelper(fetchMetadata func() (map[types.GPUVendor][]types.GPUMetadata, error), assignFunc func([]types.GPU) []types.GPU, fetchFuncs map[types.GPUVendor]func(metadata []types.GPUMetadata) ([]types.GPU, error), vendors ...types.GPUVendor) ([]types.GPU, error) {
	gpuMetadata, err := fetchMetadata()
	if err != nil {
		return nil, fmt.Errorf("failed to fetch GPU metadata: %w", err)
	}

	var gpus []types.GPU
	fetchAndAppend := func(vendor types.GPUVendor, fetch func(metadata []types.GPUMetadata) ([]types.GPU, error)) {
		vendorMetadata, ok := gpuMetadata[vendor]
		if !ok {
			// TODO: log a warning here
			return
		}
		found, err := fetch(vendorMetadata)
		if err != nil {
			// TODO: log a warning here
			return
		}
		gpus = append(gpus, found...)
	}

	if len(vendors) == 0 {
		// Map iteration order is random; fix the order before querying.
		order := make([]types.GPUVendor, 0, len(fetchFuncs))
		for vendor := range fetchFuncs {
			order = append(order, vendor)
		}
		// Insertion sort on the printed form: works for any GPUVendor
		// representation, and the map only holds a handful of vendors.
		for i := 1; i < len(order); i++ {
			for j := i; j > 0 && fmt.Sprint(order[j]) < fmt.Sprint(order[j-1]); j-- {
				order[j], order[j-1] = order[j-1], order[j]
			}
		}
		for _, vendor := range order {
			fetchAndAppend(vendor, fetchFuncs[vendor])
		}
	} else {
		// Fetch GPUs for the specified vendors only
		for _, vendor := range vendors {
			fetch, ok := fetchFuncs[vendor]
			if !ok {
				return nil, fmt.Errorf("unsupported GPU vendor: %v", vendor)
			}
			fetchAndAppend(vendor, fetch)
		}
	}

	// Note: the index is internal to dms and is not the same as the device index
	return assignFunc(gpus), nil
}
package gpu
import (
"fmt"
"os/exec"
"regexp"
"strconv"
"strings"
"gitlab.com/nunet/device-management-service/types"
)
// runXpuSmiCommand runs xpu-smi with args and returns its combined stdout and
// stderr.
//
// If the command cannot be run or exits non-zero, the underlying error is
// wrapped with %w, so callers can inspect it with errors.Is / errors.As
// (e.g. exec.ErrNotFound when xpu-smi is not installed).
func runXpuSmiCommand(args ...string) (string, error) {
	out, err := exec.Command("xpu-smi", args...).CombinedOutput()
	if err != nil {
		return "", fmt.Errorf("xpu-smi command failed: %w", err)
	}
	return string(out), nil
}
// getIntelGPUDeviceIDs returns the numeric values of every
// "| Device ID | <n> |" table row in xpu-smi output, in order of appearance.
// The label is matched case-insensitively. It errors when no row is found.
func getIntelGPUDeviceIDs(output string) ([]string, error) {
	rows := regexp.MustCompile(`(?i)\| Device ID\s+\|\s+(\d+)\s+\|`).FindAllStringSubmatch(output, -1)
	if len(rows) == 0 {
		return nil, fmt.Errorf("failed to find any Intel GPUs")
	}
	ids := make([]string, 0, len(rows))
	for _, row := range rows {
		ids = append(ids, row[1])
	}
	return ids, nil
}
// getIntelGPUDiscoveryInfo runs `xpu-smi discovery -d <deviceID>`. It returns
// the device name (trimmed) and the "Memory Physical Size" value in MiB.
// Failures to run the tool or parse the output are returned with the cause
// wrapped via %w.
func getIntelGPUDiscoveryInfo(deviceID string) (string, float64, error) {
	output, err := runXpuSmiCommand("discovery", "-d", deviceID)
	if err != nil {
		return "", 0, fmt.Errorf("failed to get discovery info for Intel GPU %s: %w", deviceID, err)
	}
	// Extract the GPU name and total memory
	nameRegex := regexp.MustCompile(`(?i)Device Name:\s+([^\n|]+)`)
	totalMemRegex := regexp.MustCompile(`(?i)Memory Physical Size:\s+([^\s]+)\s+MiB`)
	nameMatch := nameRegex.FindStringSubmatch(output)
	totalMemMatch := totalMemRegex.FindStringSubmatch(output)
	if nameMatch == nil || totalMemMatch == nil {
		return "", 0, fmt.Errorf("failed to parse discovery info for Intel GPU %s", deviceID)
	}
	gpuName := strings.TrimSpace(nameMatch[1])
	totalMemoryMiB, err := strconv.ParseFloat(totalMemMatch[1], 64)
	if err != nil {
		return "", 0, fmt.Errorf("failed to parse total memory for Intel GPU %s: %w", deviceID, err)
	}
	return gpuName, totalMemoryMiB, nil
}
// getIntelGPUUsedMemory runs `xpu-smi stats -d <deviceID>` and returns the
// "GPU Memory Used (MiB)" value. Failures to run the tool or parse the value
// are returned with the cause wrapped via %w.
func getIntelGPUUsedMemory(deviceID string) (float64, error) {
	output, err := runXpuSmiCommand("stats", "-d", deviceID)
	if err != nil {
		return 0, fmt.Errorf("failed to get stats for Intel GPU %s: %w", deviceID, err)
	}
	// Extract the used memory
	usedMemRegex := regexp.MustCompile(`(?i)GPU Memory Used \(MiB\)\s+\|\s+(\d+)\s+\|`)
	usedMemMatch := usedMemRegex.FindStringSubmatch(output)
	if usedMemMatch == nil {
		return 0, fmt.Errorf("failed to parse used memory for Intel GPU %s", deviceID)
	}
	usedMemoryMiB, err := strconv.ParseFloat(usedMemMatch[1], 64)
	if err != nil {
		return 0, fmt.Errorf("failed to parse used memory for Intel GPU %s: %w", deviceID, err)
	}
	return usedMemoryMiB, nil
}
// getIntelGPUs lists Intel GPUs via xpu-smi. For each one it returns the
// model name, total memory in MiB (see getIntelGPUDiscoveryInfo) and the PCI
// address from metadata. It fails if xpu-smi reports a different number of
// devices than len(metadata).
//
// NOTE(review): device IDs are paired with metadata by position, which
// assumes xpu-smi and ghw enumerate GPUs in the same order.
func getIntelGPUs(metadata []types.GPUMetadata) ([]types.GPU, error) {
	// Get the list of Intel GPU devices
	output, err := runXpuSmiCommand("health", "-l")
	if err != nil {
		return nil, fmt.Errorf("xpu-smi not installed, initialized, or configured: %w", err)
	}
	deviceIDs, err := getIntelGPUDeviceIDs(output)
	if err != nil {
		return nil, err
	}
	if len(deviceIDs) != len(metadata) {
		return nil, fmt.Errorf("failed to find Intel GPU information for all GPUs")
	}
	gpuInfos := make([]types.GPU, 0, len(deviceIDs))
	for i, deviceID := range deviceIDs {
		gpuName, totalMemoryMiB, err := getIntelGPUDiscoveryInfo(deviceID)
		if err != nil {
			return nil, err
		}
		gpuInfos = append(gpuInfos, types.GPU{
			PCIAddress: metadata[i].PCIAddress,
			Model:      gpuName,
			VRAM:       uint64(totalMemoryMiB),
			Vendor:     types.GPUVendorIntel,
		})
	}
	return gpuInfos, nil
}
// getIntelGPUUsage lists Intel GPUs via xpu-smi. For each one it returns the
// model name, used memory in MiB (see getIntelGPUUsedMemory) and the PCI
// address from metadata. It fails if xpu-smi reports a different number of
// devices than len(metadata).
//
// NOTE(review): device IDs are paired with metadata by position, which
// assumes xpu-smi and ghw enumerate GPUs in the same order.
func getIntelGPUUsage(metadata []types.GPUMetadata) ([]types.GPU, error) {
	// Get the list of Intel GPU devices
	output, err := runXpuSmiCommand("health", "-l")
	if err != nil {
		return nil, fmt.Errorf("xpu-smi not installed, initialized, or configured: %w", err)
	}
	deviceIDs, err := getIntelGPUDeviceIDs(output)
	if err != nil {
		return nil, err
	}
	if len(deviceIDs) != len(metadata) {
		return nil, fmt.Errorf("failed to find Intel GPU information for all GPUs")
	}
	gpuInfos := make([]types.GPU, 0, len(deviceIDs))
	for i, deviceID := range deviceIDs {
		// discovery is only needed here for the model name
		gpuName, _, err := getIntelGPUDiscoveryInfo(deviceID)
		if err != nil {
			return nil, err
		}
		usedMemoryMiB, err := getIntelGPUUsedMemory(deviceID)
		if err != nil {
			return nil, err
		}
		gpuInfos = append(gpuInfos, types.GPU{
			PCIAddress: metadata[i].PCIAddress,
			Model:      gpuName,
			VRAM:       uint64(usedMemoryMiB),
			Vendor:     types.GPUVendorIntel,
		})
	}
	return gpuInfos, nil
}
package gpu
import (
"errors"
"fmt"
"github.com/NVIDIA/go-nvml/pkg/nvml"
"gitlab.com/nunet/device-management-service/types"
)
// initNVML initializes the NVIDIA Management Library. The NVIDIA getters in
// this file pair each successful call with a deferred shutdownNVML.
func initNVML() error {
	if ret := nvml.Init(); !errors.Is(ret, nvml.SUCCESS) {
		return fmt.Errorf("NVIDIA Management Library not installed, initialized, or configured (reboot recommended for newly installed NVIDIA GPU drivers): %s", nvml.ErrorString(ret))
	}
	return nil
}
// shutdownNVML releases the NVIDIA Management Library. It is used only for
// cleanup, so its return code is intentionally discarded.
func shutdownNVML() {
	_ = nvml.Shutdown() // best-effort cleanup
}
// getNVIDIADeviceCount reports how many NVIDIA devices NVML can see.
// NVML must already be initialized (see initNVML).
func getNVIDIADeviceCount() (int, error) {
	count, ret := nvml.DeviceGetCount()
	if errors.Is(ret, nvml.SUCCESS) {
		return count, nil
	}
	return 0, fmt.Errorf("failed to get device count: %s", nvml.ErrorString(ret))
}
// getNVIDIADeviceHandle looks up the NVML handle of the device at position
// index in NVML's enumeration.
func getNVIDIADeviceHandle(index int) (nvml.Device, error) {
	handle, ret := nvml.DeviceGetHandleByIndex(index)
	if errors.Is(ret, nvml.SUCCESS) {
		return handle, nil
	}
	return nil, fmt.Errorf("failed to get device handle for device %d: %s", index, nvml.ErrorString(ret))
}
// getNVIDIADeviceName returns the product name NVML reports for device.
func getNVIDIADeviceName(device nvml.Device) (string, error) {
	productName, ret := device.GetName()
	if errors.Is(ret, nvml.SUCCESS) {
		return productName, nil
	}
	return "", fmt.Errorf("failed to get name for device: %s", nvml.ErrorString(ret))
}
// getNVIDIADeviceMemory returns NVML's memory counters (total/free/used) for
// device.
func getNVIDIADeviceMemory(device nvml.Device) (nvml.Memory, error) {
	info, ret := device.GetMemoryInfo()
	if errors.Is(ret, nvml.SUCCESS) {
		return info, nil
	}
	return nvml.Memory{}, fmt.Errorf("failed to get NVIDIA GPU memory info: %s", nvml.ErrorString(ret))
}
// getNVIDIAGPUs queries NVML for every NVIDIA device. For each one it returns
// the name (used as both Name and Model), total memory as reported by NVML,
// and the PCI address from metadata. It fails when NVML's device count
// differs from len(metadata).
//
// NOTE(review): device i is paired with metadata[i], which assumes NVML's
// enumeration order matches the order ghw reported — confirm, or match on the
// device's own PCI bus ID.
func getNVIDIAGPUs(metadata []types.GPUMetadata) ([]types.GPU, error) {
	if err := initNVML(); err != nil {
		return nil, err
	}
	defer shutdownNVML()

	count, err := getNVIDIADeviceCount()
	if err != nil {
		return nil, err
	}
	if count != len(metadata) {
		return nil, fmt.Errorf("failed to find NVIDIA GPU information for all GPUs")
	}

	var gpus []types.GPU
	for idx := 0; idx < count; idx++ {
		dev, err := getNVIDIADeviceHandle(idx)
		if err != nil {
			return nil, err
		}
		model, err := getNVIDIADeviceName(dev)
		if err != nil {
			return nil, err
		}
		mem, err := getNVIDIADeviceMemory(dev)
		if err != nil {
			return nil, err
		}
		gpus = append(gpus, types.GPU{
			PCIAddress: metadata[idx].PCIAddress,
			Name:       model,
			Model:      model,
			VRAM:       mem.Total,
			Vendor:     types.GPUVendorNvidia,
		})
	}
	return gpus, nil
}
// getNVIDIAGPUUsage queries NVML for every NVIDIA device. For each one it
// returns the name (as Name and Model) and the used memory reported by NVML
// in VRAM. The metadata argument is ignored, so PCIAddress is left empty.
func getNVIDIAGPUUsage(_ []types.GPUMetadata) ([]types.GPU, error) {
	if err := initNVML(); err != nil {
		return nil, err
	}
	defer shutdownNVML()

	count, err := getNVIDIADeviceCount()
	if err != nil {
		return nil, err
	}

	var gpus []types.GPU
	for idx := 0; idx < count; idx++ {
		dev, err := getNVIDIADeviceHandle(idx)
		if err != nil {
			return nil, err
		}
		model, err := getNVIDIADeviceName(dev)
		if err != nil {
			return nil, err
		}
		mem, err := getNVIDIADeviceMemory(dev)
		if err != nil {
			return nil, err
		}
		gpus = append(gpus, types.GPU{
			Name:   model,
			Model:  model,
			VRAM:   mem.Used,
			Vendor: types.GPUVendorNvidia,
		})
	}
	return gpus, nil
}
package gpu
import (
"gitlab.com/nunet/device-management-service/types"
)
// assignIndexToGPUs sets each GPU's Index to its position in gpus (0, 1, ...)
// in place, and returns the same slice for convenience.
func assignIndexToGPUs(gpus []types.GPU) []types.GPU {
	for idx := 0; idx < len(gpus); idx++ {
		gpus[idx].Index = idx
	}
	return gpus
}
package hardware
import (
"fmt"
"sync"
"gitlab.com/nunet/device-management-service/dms/hardware/cpu"
"gitlab.com/nunet/device-management-service/dms/hardware/gpu"
"gitlab.com/nunet/device-management-service/types"
)
// defaultHardwareManager manages the machine's hardware resources by probing
// the local machine; it implements types.HardwareManager.
type defaultHardwareManager struct {
// machineResources caches the result of the first successful
// GetMachineResources call; nil until then.
machineResources *types.MachineResources
// mu guards machineResources.
mu sync.Mutex
}
// NewHardwareManager returns a types.HardwareManager backed by the local
// machine. Machine totals are probed on first use and then cached.
func NewHardwareManager() types.HardwareManager {
	m := new(defaultHardwareManager)
	return m
}

// Compile-time check that defaultHardwareManager satisfies the interface.
var _ types.HardwareManager = (*defaultHardwareManager)(nil)
// GetMachineResources returns the resources of the machine in a thread-safe manner.
func (m *defaultHardwareManager) GetMachineResources() (types.MachineResources, error) {
m.mu.Lock()
defer m.mu.Unlock()
if m.machineResources != nil {
return *m.machineResources, nil
}
var err error
var cpuDetails types.CPU
var ram types.RAM
var gpus []types.GPU
var diskDetails types.Disk
if cpuDetails, err = cpu.GetCPU(); err != nil {
return types.MachineResources{}, fmt.Errorf("failed to get CPU: %w", err)
}
if ram, err = GetRAM(); err != nil {
return types.MachineResources{}, fmt.Errorf("failed to get RAM: %w", err)
}
if gpus, err = gpu.GetGPUs(); err != nil {
return types.MachineResources{}, fmt.Errorf("failed to get GPUs: %w", err)
}
if diskDetails, err = GetDisk(); err != nil {
return types.MachineResources{}, fmt.Errorf("failed to get Disk: %w", err)
}
m.machineResources = &types.MachineResources{
Resources: types.Resources{
CPU: cpuDetails,
RAM: ram,
Disk: diskDetails,
GPUs: gpus,
},
}
return *m.machineResources, nil
}
// GetUsage returns the resources currently in use on the machine:
//   - CPU: cores in use;
//   - RAM: used memory;
//   - Disk: used disk space;
//   - GPUs: per-GPU used VRAM.
func (m *defaultHardwareManager) GetUsage() (types.Resources, error) {
	cpuDetails, err := cpu.GetUsage()
	if err != nil {
		return types.Resources{}, fmt.Errorf("failed to get CPU usage: %w", err)
	}
	ram, err := GetRAMUsage()
	if err != nil {
		return types.Resources{}, fmt.Errorf("failed to get RAM usage: %w", err)
	}
	diskDetails, err := GetDiskUsage()
	if err != nil {
		return types.Resources{}, fmt.Errorf("failed to get Disk usage: %w", err)
	}
	// Usage, not inventory: gpu.GetGPUs reports total VRAM, so using it here
	// made GetFreeResources subtract every GPU's full capacity.
	gpus, err := gpu.GetGPUUsage()
	if err != nil {
		return types.Resources{}, fmt.Errorf("failed to get GPU usage: %w", err)
	}
	return types.Resources{
		CPU:  cpuDetails,
		RAM:  ram,
		Disk: diskDetails,
		GPUs: gpus,
	}, nil
}
// GetFreeResources returns the machine's total resources minus current
// usage. It fails if either probe fails or if Subtract reports an error.
func (m *defaultHardwareManager) GetFreeResources() (types.Resources, error) {
	inUse, err := m.GetUsage()
	if err != nil {
		return types.Resources{}, fmt.Errorf("failed to get usage: %w", err)
	}
	total, err := m.GetMachineResources()
	if err != nil {
		return types.Resources{}, fmt.Errorf("failed to get machine resources: %w", err)
	}
	if err = total.Subtract(inUse); err != nil {
		return types.Resources{}, fmt.Errorf("no free resources: %w", err)
	}
	return total.Resources, nil
}
package hardware
import (
"fmt"
"github.com/shirou/gopsutil/v4/mem"
"gitlab.com/nunet/device-management-service/types"
)
// GetRAM returns the system's total memory, as reported by gopsutil's
// mem.VirtualMemory (Total), in types.RAM.Size.
func GetRAM() (types.RAM, error) {
	vm, err := mem.VirtualMemory()
	if err != nil {
		return types.RAM{}, fmt.Errorf("failed to get total memory: %w", err)
	}
	return types.RAM{
		Size: float64(vm.Total),
	}, nil
}
// GetRAMUsage returns the system's used memory, as reported by gopsutil's
// mem.VirtualMemory (Used), in types.RAM.Size.
func GetRAMUsage() (types.RAM, error) {
	vm, err := mem.VirtualMemory()
	if err != nil {
		return types.RAM{}, fmt.Errorf("failed to get memory usage: %w", err)
	}
	return types.RAM{
		Size: float64(vm.Used),
	}, nil
}
package dms
import (
"os"
"github.com/spf13/afero"
"gitlab.com/nunet/device-management-service/internal/config"
)
// init makes sure the configured work, data and user directories exist
// (mode 0700). Empty paths are skipped. Failures are logged as warnings and
// do not abort start-up.
func init() {
	osFs := afero.NewOsFs()
	cfg := config.GetConfig()
	targets := []struct {
		dir    string
		format string
	}{
		{dir: cfg.WorkDir, format: "unable to create work directory: %v"},
		{dir: cfg.DataDir, format: "unable to create data directory: %v"},
		{dir: cfg.UserDir, format: "unable to create user directory: %v"},
	}
	for _, target := range targets {
		if target.dir == "" {
			continue
		}
		if err := osFs.MkdirAll(target.dir, os.FileMode(0o700)); err != nil {
			log.Warnf(target.format, err)
		}
	}
}
package jobs
import (
"context"
"errors"
"fmt"
"sync"
"github.com/google/uuid"
"gitlab.com/nunet/device-management-service/actor"
// "gitlab.com/nunet/device-management-service/executor"
"gitlab.com/nunet/device-management-service/types"
)
// Lifecycle states of an Allocation.
const (
// pending is the initial status set by NewAllocation.
pending AllocationStatus = "pending"
// running and stopped are meant to be set by Run and Stop; those methods
// are currently ErrTODO stubs with their drafts commented out.
running AllocationStatus = "running"
stopped AllocationStatus = "stopped"
)
// AllocationStatus is a representation of the execution status of an
// Allocation (one of pending, running, stopped).
type AllocationStatus string
// Status holds the status of an allocation: the resources reserved for its
// job and its current lifecycle state.
type Status struct {
JobResources types.Resources
Status AllocationStatus
}
// AllocationDetails encapsulates the dependencies to the constructor
// (NewAllocation), which copies NodeID and SourceID into the Allocation.
type AllocationDetails struct {
// Job Job
NodeID string
SourceID string
}
// Allocation represents an allocation: a unit of work bound to an actor and
// tracked against a resource manager. NewAllocation creates it in the pending
// state. Only Start is implemented; Run, Stop and Status are stubs.
type Allocation struct {
// ID is the allocation's UUID string, generated by NewAllocation.
ID string
// mx serializes Start (the commented-out Run/Stop drafts also take it).
mx sync.Mutex
status AllocationStatus
nodeID string
sourceID string
// executionID is a UUID string generated by NewAllocation for the execution.
executionID string
// Actor is the actor started by Start.
Actor *actor.BasicActor
// executor executor.Executor
resourceManager types.ResourceManager
// actorRunning records whether Start has already started Actor.
actorRunning bool
// Job Job
}
// NewAllocation builds a pending Allocation bound to actor, copying the node
// and source IDs from details.
//
// It generates fresh version-1 UUIDs (uuid.NewUUID) for the allocation ID and
// for the execution ID. It fails if resourceManager is nil or UUID generation
// fails.
func NewAllocation(actor *actor.BasicActor, details AllocationDetails, resourceManager types.ResourceManager) (*Allocation, error) {
	if resourceManager == nil {
		return nil, errors.New("resource manager is nil")
	}

	allocationID, err := uuid.NewUUID()
	if err != nil {
		return nil, fmt.Errorf("failed to generate uuid for allocation: %w", err)
	}
	execID, err := uuid.NewUUID()
	if err != nil {
		return nil, fmt.Errorf("failed to create executor id: %w", err)
	}

	alloc := &Allocation{
		ID:              allocationID.String(),
		status:          pending,
		nodeID:          details.NodeID,
		sourceID:        details.SourceID,
		executionID:     execID.String(),
		Actor:           actor,
		resourceManager: resourceManager,
	}
	return alloc, nil
}
// Run is meant to allocate the job's resources, create the executor for the
// configured execution engine, start the workload and mark the allocation
// running.
//
// It is not implemented yet. The draft below is commented out because it
// depends on the Job and executor fields that are also commented out on
// Allocation. The method always returns ErrTODO without touching any state.
func (a *Allocation) Run(_ context.Context) error {
// a.mx.Lock()
// defer a.mx.Unlock()
// if a.status == running {
// return nil
// }
// resourceAllocation := types.ResourceAllocation{JobID: a.Job.ID, Resources: a.Job.Resources}
// err := a.resourceManager.AllocateResources(ctx, resourceAllocation)
// if err != nil {
// return fmt.Errorf("failed to allocate resources: %w", err)
// }
// defer func() {
// if a.status != running {
// // If not running, ensure deallocation of resources
// err = a.resourceManager.DeallocateResources(ctx, a.Job.ID)
// }
// }()
// // if executor is nil create it
// if a.executor == nil {
// err = a.createExecutor(ctx, a.Job.Execution)
// if err != nil {
// return fmt.Errorf("failed to create executor: %w", err)
// }
// }
// err = a.executor.Start(ctx, &types.ExecutionRequest{
// JobID: a.Job.ID,
// ExecutionID: a.executionID,
// EngineSpec: &a.Job.Execution,
// Resources: &a.Job.Resources,
// // TODO add the following
// Inputs: []*types.StorageVolumeExecutor{},
// Outputs: []*types.StorageVolumeExecutor{},
// ResultsDir: "",
// })
// if err != nil {
// return fmt.Errorf("failed to start executor: %w", err)
// }
// a.status = running
// return nil
return ErrTODO
}
// Stop is meant to cancel the running execution, release the job's
// resources and stop the allocation's actor.
//
// It is not implemented yet. The draft below is commented out, and the
// method always returns ErrTODO without touching any state (in particular it
// does not stop Actor).
func (a *Allocation) Stop() error {
// a.mx.Lock()
// defer a.mx.Unlock()
// defer func() {
// if a.actorRunning {
// if err := a.Actor.Stop(); err != nil {
// log.Warnf("error stopping allocation actor: %s", err)
// }
// a.actorRunning = false
// }
// }()
// if a.status != running {
// return nil
// }
// if err := a.executor.Cancel(ctx, a.executionID); err != nil {
// return fmt.Errorf("failed to stop execution: %w", err)
// }
// a.status = stopped
// if err := a.resourceManager.DeallocateResources(ctx, a.Job.ID); err != nil {
// return fmt.Errorf("failed to deallocate resources: %w", err)
// }
// return nil
return ErrTODO
}
// Status is meant to return the resources reserved for the job and the
// allocation's execution status.
//
// Not implemented yet (the draft depends on the commented-out Job field); it
// always returns a zero Status.
func (a *Allocation) Status(_ context.Context) Status {
// return Status{
// JobResources: a.Job.Resources,
// Status: a.status,
// }
// TODO
return Status{}
}
// Start starts the allocation's actor once. Repeated calls after a
// successful start are no-ops. Calls are serialized by a.mx.
func (a *Allocation) Start() error {
	a.mx.Lock()
	defer a.mx.Unlock()

	if !a.actorRunning {
		if err := a.Actor.Start(); err != nil {
			return fmt.Errorf("failed to start allocation actor: %w", err)
		}
		a.actorRunning = true
	}
	return nil
}
package jobs
import (
"time"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/types"
)
// Identifiers and timings for the ensemble deployment bidding protocol.
const (
// BidRequestTopic names the topic bid requests go to (presumably a pub/sub
// topic — confirm against the publisher).
BidRequestTopic = "/nunet/deployment"
// BidRequestBehavior names the behavior that receives bid requests.
BidRequestBehavior = "/dms/deployment/request"
// BidRequestTimeout presumably bounds how long to wait for bids — confirm
// at the call site.
BidRequestTimeout = 5 * time.Second
// BidReplyBehavior names the behavior that receives bids.
BidReplyBehavior = "/dms/deployment/bid"
// MinEnsembleDeploymentTime is presumably the minimum time budget for an
// ensemble deployment — confirm at the call site.
MinEnsembleDeploymentTime = 15 * time.Second
)
// EnsembleBidRequest is a request for bids pertaining to an ensemble.
//
// Note: At the moment, we embed a bid request for each node.
// This is fine for small deployments and a small network, which is what we have.
// For large deployments, however, this won't scale: we will have to create aggregate
// bid requests for related groups of nodes and handle them with bid request
// aggregators that control multiple nodes.
type EnsembleBidRequest struct {
ID string // unique identifier of an ensemble (in the context of the orchestrator)
Request []BidRequest // list of node bid requests
PeerExclusion []string // list of peers to exclude from bidding
}
// BidRequest is a versioned bid request; currently V1 is the only version.
type BidRequest struct {
V1 *BidRequestV1
}
// BidRequestV1 is v1 of bid requests for a node to use for deployment. It
// describes the executors, resources, location and public ports a node must
// provide to bid.
type BidRequestV1 struct {
NodeID string // unique identifier for a node, within the context of an ensemble
Executors []AllocationExecutor // list of required executors to support the allocation(s)
Resources types.Resources // (aggregate) required hardware resources
Location LocationConstraints // node location constraints
// PublicPorts lists the public ports the node must expose.
PublicPorts struct {
Static []int // statically configured public ports
Dynamic int // number of dynamic ports
}
}
// Bid is the version struct for Bids in response to a bid request
type Bid struct {
V1 *BidV1
}
// BidV1 is v1 of the bid structure
type BidV1 struct {
EnsembleID string // unique identifier for the ensemble
NodeID string // unique identifier for a node; matches the id of the BidRequest to which this bid pertains
Peer string // the peer ID of the node
Location Location // the location of the node
Handle actor.Handle // the handle of the actor submitting the bid
// TODO signature from Peer
}
// Validate checks the ensemble bid request for consistency.
//
// TODO: not implemented; every request is currently accepted.
func (b *EnsembleBidRequest) Validate() error {
	// TODO
	return nil
}

// Validate performs structural validation of a bid received from the network.
//
// Bids are decoded from untrusted peer messages, so a payload such as `{}`
// yields a Bid whose V1 is nil. Such bids are rejected here; without this check
// the accessor methods (EnsembleID, NodeID, ...) would dereference a nil
// pointer when the orchestrator processes the bid.
//
// TODO pass the envelope for verification (e.g. signature from Peer).
func (b *Bid) Validate() error {
	if b.V1 == nil {
		return errors.New("bid has no v1 payload")
	}
	return nil
}
// EnsembleID returns the identifier of the ensemble the bid pertains to,
// or "" if the bid carries no v1 payload.
func (b *Bid) EnsembleID() string {
	if b.V1 == nil {
		return ""
	}
	return b.V1.EnsembleID
}

// NodeID returns the ensemble node identifier the bid is for,
// or "" if the bid carries no v1 payload.
func (b *Bid) NodeID() string {
	if b.V1 == nil {
		return ""
	}
	return b.V1.NodeID
}

// Peer returns the peer ID of the bidding node,
// or "" if the bid carries no v1 payload.
func (b *Bid) Peer() string {
	if b.V1 == nil {
		return ""
	}
	return b.V1.Peer
}

// Location returns the advertised location of the bidding node,
// or the zero Location if the bid carries no v1 payload.
func (b *Bid) Location() Location {
	if b.V1 == nil {
		return Location{}
	}
	return b.V1.Location
}
package jobs
import (
"gitlab.com/nunet/device-management-service/types"
)
// EnsembleConfig is the versioned structure that contains the ensemble configuration.
// Only the V1 payload exists so far.
type EnsembleConfig struct {
V1 *EnsembleConfigV1
}
// EnsembleConfigV1 is version 1 of the configuration for an ensemble.
// Allocations, nodes, keys and scripts are referenced from other parts of the
// configuration by their map keys (names).
type EnsembleConfigV1 struct {
Allocations map[string]AllocationConfig // (named) allocations in the ensemble
Nodes map[string]NodeConfig // (named) nodes in the ensemble
Edges []EdgeConstraint // network edge constraints
Supervisor SupervisorConfig // supervision structure
Keys map[string]string // (named) ssh public keys relevant to the allocation
Scripts map[string]string // (named) provisioning scripts
}
// AllocationConfig is the configuration of an allocation
type AllocationConfig struct {
Executor AllocationExecutor // the executor of the allocation
Resources types.Resources // the HW resources required by the allocation
Execution types.SpecConfig // the allocation execution configuration
Volumes []VolumeConfig // premounted external volumes
DNSName string // the internal DNS name of the allocation
Keys []string // names of the authorized ssh keys for the allocation
Provision []string // names of provisioning scripts to run (in order)
HealthCheck string // name of the script to run for health checks
}
// AllocationExecutor is the executor required for the allocation
type AllocationExecutor string
const (
ExecutorFirecracker AllocationExecutor = "firecracker"
ExecutorDocker AllocationExecutor = "docker"
)
// VolumeConfig is an externally mounted volume
type VolumeConfig struct {
Name string // the name of the volume
Type string // the type of the volume
Remote types.SpecConfig // the remote mount config
Mountpoint string // the mountpoint
}
// NodeConfig is the configuration of a distinct DMS node
type NodeConfig struct {
Allocations []string // the list of (named) allocations in the node
Ports []PortConfig // the port mapping configuration for the node
Location LocationConstraints // the geographical location constraints for the node
Peer string // (optional) a fixed peer for the node
// TODO contract information
}
// LocationConstraints provides the node location placement constraints
type LocationConstraints struct {
Accept []Location // acceptable location constraints (disjunction)
Reject []Location // negative location constraints (conjunction); eg !USA for GDPR purposes
}
// Location is a geographical location on Planet Earth
type Location struct {
Region string // geographic region of the location (required)
Country string // country (code or name) of the location (optional)
City string // city of the location; optional but country must be specified if not empty
ASN uint // Autonomous System Number for the location (optional)
ISP string // Internet Service Provider name for the location (optional)
}
// PortConfig is the configuration for a port mapping a public port to a private port
// in an allocation
type PortConfig struct {
Public int // the public port 0 for any
Private int // the private mapping
Allocation string // the allocation where the port is mapped
}
// EdgeConstraint is a constraint for a network edge between two nodes
type EdgeConstraint struct {
S, T string // (named) nodes connected by the edge
RTT uint // maximum edge RTT in milliseconds
BW uint // minimum edge bandwidth in Kbps
Symmetric bool // whether the constraint is symmetric (bidirectional)
}
// SupervisorConfig is the supervisory structure configuration for the ensemble.
// It is recursive: Children describe nested supervision groups.
type SupervisorConfig struct {
Strategy SupervisorStrategy // the strategy for the supervision group
Allocations []string // allocations in this supervision group
Children []SupervisorConfig // allocation children for recursive groups
}
// SupervisorStrategy is the name of a supervision strategy
type SupervisorStrategy string
const (
StrategyOneForOne SupervisorStrategy = "OneForOne"
StrategyAllForOne SupervisorStrategy = "AllForOne"
StrategyRestForOne SupervisorStrategy = "RestForOne"
)
// Validate checks the ensemble configuration for consistency.
//
// TODO: not implemented; every configuration is currently accepted.
func (e *EnsembleConfig) Validate() error {
	return nil
}
package jobs
import (
"context"
"encoding/json"
"fmt"
"time"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/network"
)
// Orchestrator drives the deployment of a single ensemble: it solicits bids,
// builds a candidate deployment and, once deployed, supervises it.
type Orchestrator struct {
// actor used to publish bid requests and receive bids
actor actor.Actor
network network.Network //nolint
// ensemble identifier; bids whose EnsembleID differs are dropped
id string
cfg EnsembleConfig //nolint
// manifest, running, ctx and cancel are set when Deploy succeeds
manifest EnsembleManifest
running bool
ctx context.Context
cancel func()
}
// Deploy attempts to deploy the ensemble before expiry and returns its manifest.
//
// Each attempt of the outer loop:
//  1. builds the initial bid request for the ensemble nodes,
//  2. publishes it and collects bids until the bid deadline, accepting at most
//     one bid per peer and only bids that pass the node location check,
//  3. builds a candidate deployment, issuing residual bid requests for the
//     nodes still lacking bids until a candidate exists or expiry is reached,
//  4. verifies the edge constraints of the candidate,
//  5. commits the deployment,
//  6. provisions the network, and
//  7. starts the deployment (reverting it if step 6 or 7 fails).
//
// Failures in steps 3-7 are logged and the whole process is retried while time
// remains. Failures to build or publish a bid request abort immediately with an
// error. On success the orchestrator records the manifest, marks itself running
// and starts the supervisor goroutine. If expiry passes without a successful
// deployment, ErrDeploymentFailed is returned.
func (o *Orchestrator) Deploy(expiry time.Time) (EnsembleManifest, error) {
deploy:
	for time.Now().Before(expiry) {
		// 1. Create bid requests for nodes
		bidrq, err := o.makeInitialBidRequest()
		if err != nil {
			return EnsembleManifest{}, fmt.Errorf("creating bid request: %w", err)
		}

		// 2. Collect bids
		bidMap := make(map[string][]Bid)
		peerExclusion := make(map[string]struct{})
		addBid := func(bid Bid) {
			// check that the peer has not already submitted a bid
			peerID := bid.Peer()
			if _, exclude := peerExclusion[peerID]; exclude {
				log.Debugf("ignoring duplicate bid from peer %s", peerID)
				return
			}
			// verify the location constraints of the node
			nodeID := bid.NodeID()
			loc := bid.Location()
			if !o.acceptPeerLocation(nodeID, loc) {
				log.Debugf("ignoring out of location bid from peer %s for node %s", peerID, nodeID)
				return
			}
			bidMap[nodeID] = append(bidMap[nodeID], bid)
			peerExclusion[peerID] = struct{}{}
		}

		bidCh, bidExpiryTime, err := o.requestBids(bidrq, expiry)
		if err != nil {
			return EnsembleManifest{}, fmt.Errorf("collecting bids: %w", err)
		}
		o.collectBids(bidCh, bidExpiryTime, addBid)

		// 3. Create a candidate deployment
		var candidate map[string]Bid
		var ok bool
		for time.Now().Before(expiry) {
			candidate, ok = o.makeCandidateDeployment(bidMap)
			if ok {
				break
			}
			// We don't have bids for some of our nodes, so there is no candidate;
			// make a residual bid request for the remaining nodes.
			// Note: in order to facilitate random selection, the residual bid requests
			// can drop some of the original bids.
			bidrq, err = o.makeResidualBidRequest(bidMap, peerExclusion)
			if err != nil {
				return EnsembleManifest{}, fmt.Errorf("creating residual bid request: %w", err)
			}
			bidCh, bidExpiryTime, err = o.requestBids(bidrq, expiry)
			if err != nil {
				return EnsembleManifest{}, fmt.Errorf("collecting residual bids: %w", err)
			}
			o.collectBids(bidCh, bidExpiryTime, addBid)
		}
		if !ok {
			log.Debugf("failed to create candidate deployment")
			continue deploy
		}

		// 4. Check the edge constraints
		if err := o.verifyEdgeConstraints(candidate); err != nil {
			log.Debugf("failed to verify edge constraints: %s", err)
			continue deploy
		}

		// 5. Commit the deployment
		manifest, err := o.commitDeployment(candidate)
		if err != nil {
			log.Warnf("failed to commit deployment: %s", err)
			continue deploy
		}

		// 6. Provision the network
		if err := o.provision(manifest); err != nil {
			log.Errorf("failed to provision network: %s", err)
			o.revertDeployment(manifest)
			continue deploy
		}

		// 7. Start the deployment
		if err := o.start(manifest); err != nil {
			log.Errorf("failed to start the deployment: %s", err)
			o.revertDeployment(manifest)
			continue deploy
		}

		// We are done! Start the supervisor and return the manifest.
		o.manifest = manifest
		o.running = true
		o.ctx, o.cancel = context.WithCancel(context.Background())
		go o.supervise()
		return manifest, nil
	}

	// we failed to create the deployment in time
	return EnsembleManifest{}, ErrDeploymentFailed
}
// requestBids publishes bidrq on BidRequestTopic and registers a temporary
// BidReplyBehavior handler that forwards decoded bids to the returned channel.
//
// The bid deadline is now+BidRequestTimeout; if that is after the deployment
// expiry, an error wrapping ErrDeploymentFailed is returned instead. The
// deadline is used both as the message/behavior expiry and as the returned
// time, which the caller is expected to pass to collectBids.
//
// NOTE(review): each call registers BidReplyBehavior again (e.g. for residual
// requests); whether AddBehavior replaces or rejects an existing behavior of
// the same name is not visible here — confirm.
func (o *Orchestrator) requestBids(bidrq EnsembleBidRequest, expiry time.Time) (chan Bid, time.Time, error) {
log.Debugf("requesting bids: %+v", bidrq)
bidExpiryTime := time.Now().Add(BidRequestTimeout)
if expiry.Before(bidExpiryTime) {
return nil, time.Time{}, fmt.Errorf("not enough time for deployment: %w", ErrDeploymentFailed)
}
// message and behavior expiries are expressed in Unix nanoseconds
bidExpiry := uint64(bidExpiryTime.UnixNano())
msg, err := actor.Message(
o.actor.Handle(),
actor.Handle{},
BidRequestBehavior,
bidrq,
actor.WithMessageTopic(BidRequestTopic),
actor.WithMessageReplyTo(BidReplyBehavior),
actor.WithMessageExpiry(bidExpiry),
)
if err != nil {
return nil, time.Time{}, fmt.Errorf("creating bid request message: %w", err)
}
// Unbuffered: each handler invocation blocks until the collector receives the
// bid or the bid deadline passes, whichever comes first.
bidCh := make(chan Bid)
if err := o.actor.AddBehavior(
BidReplyBehavior,
func(msg actor.Envelope) {
defer msg.Discard()
var bid Bid
if err := json.Unmarshal(msg.Message, &bid); err != nil {
log.Debugf("failed to unmarshal bid from %s: %s", msg.From, err)
return
}
// bound the send by the bid deadline so late bids are dropped rather
// than blocking the handler forever once the collector has returned
timer := time.NewTimer(time.Until(bidExpiryTime))
defer timer.Stop()
select {
case bidCh <- bid:
case <-timer.C:
}
},
actor.WithBehaviorExpiry(bidExpiry),
); err != nil {
return nil, time.Time{}, fmt.Errorf("adding bid behavior: %w", err)
}
if err := o.actor.Publish(msg); err != nil {
return nil, time.Time{}, fmt.Errorf("publishing bid request: %w", err)
}
return bidCh, bidExpiryTime, nil
}
// collectBids receives bids from bidCh until bidExpiryTime and hands every bid
// that validates and belongs to this orchestrator's ensemble to addBid.
// Invalid bids and bids for other ensembles are logged at debug level and
// dropped.
func (o *Orchestrator) collectBids(bidCh chan Bid, bidExpiryTime time.Time, addBid func(Bid)) {
	deadline := time.NewTimer(time.Until(bidExpiryTime))
	defer deadline.Stop()

	for {
		var bid Bid
		select {
		case <-deadline.C:
			return
		case bid = <-bidCh:
		}

		if err := bid.Validate(); err != nil {
			log.Debugf("got invalid bid: %s", err)
			continue
		}
		if ensembleID := bid.EnsembleID(); ensembleID != o.id {
			log.Debugf("got bid for unexpected ensemble ID: %s", ensembleID)
			continue
		}
		addBid(bid)
	}
}
// makeCandidateDeployment selects one bid per ensemble node from the collected
// bids. TODO: not implemented; it never produces a candidate.
func (o *Orchestrator) makeCandidateDeployment(_ map[string][]Bid) (map[string]Bid, bool) {
	return nil, false
}

// verifyEdgeConstraints checks the candidate against the ensemble edge
// constraints. TODO: not implemented; always returns ErrTODO.
func (o *Orchestrator) verifyEdgeConstraints(_ map[string]Bid) error {
	return ErrTODO
}

// commitDeployment commits the candidate deployment and returns its manifest.
// TODO: not implemented; always returns ErrTODO.
func (o *Orchestrator) commitDeployment(_ map[string]Bid) (EnsembleManifest, error) {
	return EnsembleManifest{}, ErrTODO
}

// provision provisions the network for the manifest.
// TODO: not implemented; always returns ErrTODO.
func (o *Orchestrator) provision(_ EnsembleManifest) error {
	return ErrTODO
}

// start starts the deployment described by the manifest.
// TODO: not implemented; always returns ErrTODO.
func (o *Orchestrator) start(_ EnsembleManifest) error {
	return ErrTODO
}

// revertDeployment undoes a committed deployment. TODO: not implemented.
func (o *Orchestrator) revertDeployment(_ EnsembleManifest) {}

// acceptPeerLocation reports whether a bid location satisfies the location
// constraints of the given node. TODO: not implemented; accepts everything.
func (o *Orchestrator) acceptPeerLocation(_ string, _ Location) bool {
	return true
}

// makeInitialBidRequest builds the bid request covering all ensemble nodes.
// TODO: not implemented; always returns ErrTODO.
func (o *Orchestrator) makeInitialBidRequest() (EnsembleBidRequest, error) {
	return EnsembleBidRequest{}, ErrTODO
}

// makeResidualBidRequest builds a bid request for the nodes still lacking
// bids, excluding peers that already bid. TODO: not implemented.
func (o *Orchestrator) makeResidualBidRequest(_ map[string][]Bid, _ map[string]struct{}) (EnsembleBidRequest, error) {
	return EnsembleBidRequest{}, ErrTODO
}

// supervise supervises the running deployment. TODO: not implemented.
func (o *Orchestrator) supervise() {}

// Stop stops the orchestrator. TODO: not implemented.
func (o *Orchestrator) Stop() {}
package node
import (
"encoding/json"
"time"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/dms/jobs"
)
const (
	// NewDeploymentBehavior is the behavior handling new deployment requests.
	NewDeploymentBehavior = "/dms/deployment/new"

	// MinDeploymentTime is the minimum remaining time (59s) before the request
	// expiry for a new deployment request to be accepted.
	MinDeploymentTime = 59 * time.Second
)
// NewDeploymentRequest is the payload of a NewDeploymentBehavior message.
type NewDeploymentRequest struct {
Ensemble jobs.EnsembleConfig
}
// NewDeploymentResponse is the reply to a new deployment request. Status is
// "OK" with Ensemble set on success, or "ERROR" with Error set on failure.
type NewDeploymentResponse struct {
Status string
Ensemble *jobs.EnsembleManifest `json:",omitempty"`
Error string `json:",omitempty"`
}
// newDeployment handles a NewDeploymentBehavior message.
//
// It rejects requests whose remaining time is below MinDeploymentTime, decodes
// the ensemble configuration, creates an orchestrator and runs its Deploy with
// the message expiry as deadline. On success the orchestrator is registered
// under the manifest ID and an "OK" reply carrying the manifest is sent;
// every failure is answered with an "ERROR" reply.
func (n *Node) newDeployment(msg actor.Envelope) {
	defer msg.Discard()

	replyError := func(reason string) {
		n.sendReply(msg, NewDeploymentResponse{Status: "ERROR", Error: reason})
	}

	if time.Until(msg.Expiry()) < MinDeploymentTime {
		log.Debugf("deployment time too short")
		replyError("requested deployment time too short")
		return
	}

	var request NewDeploymentRequest
	if err := json.Unmarshal(msg.Message, &request); err != nil {
		log.Debugf("unmarshalling deployment request: %s", err)
		replyError(err.Error())
		return
	}

	orchestrator, err := n.createOrchestrator(request.Ensemble)
	if err != nil {
		log.Warnf("creating orchestrator: %s", err)
		replyError(err.Error())
		return
	}

	manifest, err := orchestrator.Deploy(msg.Expiry())
	if err != nil {
		orchestrator.Stop()
		log.Warnf("creating ensemble: %s", err)
		replyError(err.Error())
		return
	}

	n.mx.Lock()
	n.deployments[manifest.ID] = orchestrator
	n.mx.Unlock()

	log.Infof("created ensemble: %s", manifest.ID)
	n.sendReply(msg, NewDeploymentResponse{Status: "OK", Ensemble: &manifest})
}
// createOrchestrator builds an orchestrator for the given ensemble config.
// TODO: not implemented; always fails with ErrTODO.
func (n *Node) createOrchestrator(_ jobs.EnsembleConfig) (*jobs.Orchestrator, error) {
	return nil, ErrTODO
}

// saveDeployments persists the active deployments.
// TODO: not implemented; it is a no-op that always succeeds.
func (n *Node) saveDeployments() error {
	return nil
}

// restoreDeployments restores previously saved deployments.
// TODO: not implemented; it is a no-op that always succeeds.
func (n *Node) restoreDeployments() error {
	return nil
}
package node
import (
"context"
"encoding/json"
"time"
kbucket "github.com/libp2p/go-libp2p-kbucket"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/multiformats/go-multiaddr"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/network"
"gitlab.com/nunet/device-management-service/network/libp2p"
"gitlab.com/nunet/device-management-service/types"
)
const (
	// Peer management behaviors.
	PeersListBehavior    = "/dms/node/peers/list"
	PeerAddrInfoBehavior = "/dms/node/peers/self"
	PeerPingBehavior     = "/dms/node/peers/ping"
	PeerDHTBehavior      = "/dms/node/peers/dht"
	PeerConnectBehavior  = "/dms/node/peers/connect"
	PeerScoreBehavior    = "/dms/node/peers/score"

	// Onboarding behaviors.
	OnboardBehavior         = "/dms/node/onboarding/onboard"
	OffboardBehavior        = "/dms/node/onboarding/offboard"
	OnboardStatusBehavior   = "/dms/node/onboarding/status"
	OnboardResourceBehavior = "/dms/node/onboarding/resource"

	// VM behaviors.
	CustomVMStartBehavior = "/dms/node/vm/start/custom"
	VMStopBehavior        = "/dms/node/vm/stop"
	VMListBehavior        = "/dms/node/vm/list"

	// pingTimeout bounds a single peer ping.
	pingTimeout = time.Second
)
// PingRequest is the payload of a PeerPingBehavior message.
type PingRequest struct {
	Host string
}

// PingResponse is the reply to a ping request: the round-trip time in
// milliseconds, or an error message.
type PingResponse struct {
	Error string
	RTT   int64
}

// handlePeerPing pings request.Host (bounded by pingTimeout) and replies with
// the RTT in milliseconds. Malformed requests and ping failures are reported in
// the reply's Error field instead of being dropped, so callers do not wait for
// a timeout.
func (n *Node) handlePeerPing(msg actor.Envelope) {
	defer msg.Discard()

	resp := PingResponse{}
	var request PingRequest
	if err := json.Unmarshal(msg.Message, &request); err != nil {
		log.Debugf("unmarshalling ping request: %s", err)
		resp.Error = err.Error()
		n.sendReply(msg, resp)
		return
	}

	res, err := n.network.Ping(context.Background(), request.Host, pingTimeout)
	if err != nil {
		resp.Error = err.Error()
		n.sendReply(msg, resp)
		return
	}
	if res.Error != nil {
		resp.Error = res.Error.Error()
	}
	resp.RTT = res.RTT.Milliseconds()
	n.sendReply(msg, resp)
}
// PeersListResponse is the reply to a peers list request.
type PeersListResponse struct {
	Peers []peer.ID
}

// handlePeersList replies with the peers known to the underlying libp2p
// instance (libp2pNet.PS.Peers()). The list is always non-nil, so it encodes
// as [] rather than null. On a non-libp2p network nothing is replied; the
// condition is logged.
func (n *Node) handlePeersList(msg actor.Envelope) {
	defer msg.Discard()

	// get the underlying libp2p instance to access its peer set
	libp2pNet, ok := n.network.(*libp2p.Libp2p)
	if !ok {
		log.Warnf("peers list: unsupported network implementation %T", n.network)
		return
	}

	resp := PeersListResponse{
		Peers: make([]peer.ID, 0),
	}
	for _, v := range libp2pNet.PS.Peers() {
		resp.Peers = append(resp.Peers, v)
	}
	n.sendReply(msg, resp)
}
// PeerAddrInfoResponse is the reply describing this node's network identity.
type PeerAddrInfoResponse struct {
	ID      string `json:"id"`
	Address string `json:"listen_addr"`
}

// handlePeerAddrInfo replies with this node's ID and listen address as
// reported by the network stats.
func (n *Node) handlePeerAddrInfo(msg actor.Envelope) {
	defer msg.Discard()
	st := n.network.Stat()
	n.sendReply(msg, PeerAddrInfoResponse{ID: st.ID, Address: st.ListenAddr})
}
// PeerDHTResponse is the reply containing the DHT routing table peers.
type PeerDHTResponse struct {
	Peers []kbucket.PeerInfo
}

// handlePeerDHT replies with the peer infos of the libp2p DHT routing table.
// On a non-libp2p network nothing is replied; the condition is logged.
func (n *Node) handlePeerDHT(msg actor.Envelope) {
	defer msg.Discard()

	// get the underlying libp2p instance and extract the DHT data
	libp2pNet, ok := n.network.(*libp2p.Libp2p)
	if !ok {
		log.Warnf("peer dht: unsupported network implementation %T", n.network)
		return
	}

	resp := PeerDHTResponse{
		Peers: libp2pNet.DHT.RoutingTable().GetPeerInfos(),
	}
	n.sendReply(msg, resp)
}
// PeerConnectRequest is the payload of a PeerConnectBehavior message.
type PeerConnectRequest struct {
	Address string
}

// PeerConnectResponse is the reply to a peer connect request: Status is
// "CONNECTED" on success; otherwise Error describes the failure.
type PeerConnectResponse struct {
	Status string
	Error  string
}

// handlePeerConnect parses request.Address as a multiaddress, converts it with
// peer.AddrInfoFromP2pAddr and connects the libp2p host to that peer.
// Malformed requests and a non-libp2p network are now reported in the reply's
// Error field (and logged) instead of leaving the caller without a reply.
func (n *Node) handlePeerConnect(msg actor.Envelope) {
	defer msg.Discard()

	resp := PeerConnectResponse{}
	var request PeerConnectRequest
	if err := json.Unmarshal(msg.Message, &request); err != nil {
		log.Debugf("unmarshalling peer connect request: %s", err)
		resp.Error = err.Error()
		n.sendReply(msg, resp)
		return
	}

	libp2pNet, ok := n.network.(*libp2p.Libp2p)
	if !ok {
		log.Warnf("peer connect: unsupported network implementation %T", n.network)
		resp.Error = "peer connect is only supported on libp2p networks"
		n.sendReply(msg, resp)
		return
	}

	peerAddr, err := multiaddr.NewMultiaddr(request.Address)
	if err != nil {
		resp.Error = err.Error()
		n.sendReply(msg, resp)
		return
	}
	addrInfo, err := peer.AddrInfoFromP2pAddr(peerAddr)
	if err != nil {
		resp.Error = err.Error()
		n.sendReply(msg, resp)
		return
	}
	if err := libp2pNet.Host.Connect(context.Background(), *addrInfo); err != nil {
		resp.Error = err.Error()
		n.sendReply(msg, resp)
		return
	}

	resp.Status = "CONNECTED"
	n.sendReply(msg, resp)
}
// OnboardRequest is the payload of an OnboardBehavior message.
type OnboardRequest struct {
	Config types.OnboardingConfig
}

// OnboardResponse is the reply to an onboard request: the applied config on
// success, or an error message.
type OnboardResponse struct {
	Error  string
	Config types.OnboardingConfig
}

// handleOnboard onboards the node with the requested configuration and replies
// with that configuration, or with the decoding/onboarding error.
func (n *Node) handleOnboard(msg actor.Envelope) {
	defer msg.Discard()

	var (
		request OnboardRequest
		resp    OnboardResponse
	)
	switch err := json.Unmarshal(msg.Message, &request); {
	case err != nil:
		resp.Error = err.Error()
	default:
		if err := n.onboarder.Onboard(context.Background(), request.Config); err != nil {
			resp.Error = err.Error()
		} else {
			resp.Config = request.Config
		}
	}
	n.sendReply(msg, resp)
}
// OffboardRequest is the payload of an OffboardBehavior message.
type OffboardRequest struct {
	Force bool
}

// OffboardResponse is the reply to an offboard request.
type OffboardResponse struct {
	Success bool
}

// handleOffboard offboards the node (optionally forced) and replies whether it
// succeeded. Malformed requests now get a Success=false reply instead of none.
// Offboarding errors are logged, since the response carries no error field.
func (n *Node) handleOffboard(msg actor.Envelope) {
	defer msg.Discard()

	resp := OffboardResponse{}
	var request OffboardRequest
	if err := json.Unmarshal(msg.Message, &request); err != nil {
		log.Debugf("unmarshalling offboard request: %s", err)
		n.sendReply(msg, resp)
		return
	}

	if err := n.onboarder.Offboard(context.Background(), request.Force); err != nil {
		log.Warnf("offboarding: %s", err)
		resp.Success = false
		n.sendReply(msg, resp)
		return
	}

	resp.Success = true
	n.sendReply(msg, resp)
}
// OnboardStatusResponse is the reply to an onboarding status request.
type OnboardStatusResponse struct {
	Onboarded bool
	Error     string
}

// handleOnboardStatus replies whether the node is onboarded, or with the error
// returned while checking.
func (n *Node) handleOnboardStatus(msg actor.Envelope) {
	defer msg.Discard()

	var resp OnboardStatusResponse
	if onboarded, err := n.onboarder.IsOnboarded(context.Background()); err != nil {
		resp.Error = err.Error()
	} else {
		resp.Onboarded = onboarded
	}
	n.sendReply(msg, resp)
}
// OnboardResourceRequest is the payload of an OnboardResourceBehavior message.
type OnboardResourceRequest struct {
	Config types.OnboardingConfig
}

// OnboardResourceResponse is the reply to an onboard resource update: the
// applied config on success, or an error message.
type OnboardResourceResponse struct {
	Error  string
	Result types.OnboardingConfig
}

// handleOnboardResource updates the onboarded resources with the requested
// configuration. Malformed requests are now answered with Error set (and
// logged) instead of being silently dropped.
func (n *Node) handleOnboardResource(msg actor.Envelope) {
	defer msg.Discard()

	resp := OnboardResourceResponse{}
	var request OnboardResourceRequest
	if err := json.Unmarshal(msg.Message, &request); err != nil {
		log.Debugf("unmarshalling onboard resource request: %s", err)
		resp.Error = err.Error()
		n.sendReply(msg, resp)
		return
	}

	if err := n.onboarder.Update(context.Background(), request.Config); err != nil {
		resp.Error = err.Error()
		n.sendReply(msg, resp)
		return
	}

	resp.Result = request.Config
	n.sendReply(msg, resp)
}
// CustomVMStartRequest is the payload of a CustomVMStartBehavior message.
type CustomVMStartRequest struct {
	Execution types.ExecutionRequest
}

// CustomVMStartResponse is the reply to a custom VM start request; Error is
// empty on success.
type CustomVMStartResponse struct {
	Error string
}

// handleCustomVMStart starts the requested execution on the node executor.
// Malformed requests are now answered with Error set (and logged) instead of
// being silently dropped.
func (n *Node) handleCustomVMStart(msg actor.Envelope) {
	defer msg.Discard()

	resp := CustomVMStartResponse{}
	var request CustomVMStartRequest
	if err := json.Unmarshal(msg.Message, &request); err != nil {
		log.Debugf("unmarshalling custom vm start request: %s", err)
		resp.Error = err.Error()
		n.sendReply(msg, resp)
		return
	}

	if err := n.executor.Start(context.Background(), &request.Execution); err != nil {
		resp.Error = err.Error()
	}
	n.sendReply(msg, resp)
}
// VMStopRequest is the payload of a VMStopBehavior message.
type VMStopRequest struct {
	ExecutionID string
}

// VMStopResponse is the reply to a VM stop request; Error is empty on success.
type VMStopResponse struct {
	Error string
}

// handleVMStop cancels the execution with the requested ID. Malformed requests
// are now answered with Error set (and logged) instead of being silently
// dropped.
func (n *Node) handleVMStop(msg actor.Envelope) {
	defer msg.Discard()

	resp := VMStopResponse{}
	var request VMStopRequest
	if err := json.Unmarshal(msg.Message, &request); err != nil {
		log.Debugf("unmarshalling vm stop request: %s", err)
		resp.Error = err.Error()
		n.sendReply(msg, resp)
		return
	}

	if err := n.executor.Cancel(context.Background(), request.ExecutionID); err != nil {
		resp.Error = err.Error()
	}
	n.sendReply(msg, resp)
}
// ListVMResponse is the reply listing the executions known to the executor.
type ListVMResponse struct {
	Error string
	VMS   []types.ExecutionListItem
}

// handleListVM replies with the executor's current execution list.
func (n *Node) handleListVM(msg actor.Envelope) {
	defer msg.Discard()
	var resp ListVMResponse
	resp.VMS = n.executor.List()
	n.sendReply(msg, resp)
}
// PeerScoreResponse is the reply containing broadcast scores keyed by the
// string form of the peer ID.
type PeerScoreResponse struct {
	Score map[string]*network.PeerScoreSnapshot
}

// handlePeerScore replies with the network's broadcast score snapshot.
func (n *Node) handlePeerScore(msg actor.Envelope) {
	defer msg.Discard()

	snapshot := n.network.GetBroadcastScore()
	scores := make(map[string]*network.PeerScoreSnapshot, len(snapshot))
	for id, s := range snapshot {
		scores[id.String()] = s
	}
	n.sendReply(msg, PeerScoreResponse{Score: scores})
}
package node
import (
"context"
"errors"
"fmt"
"math/rand"
"sync"
"sync/atomic"
"time"
// "github.com/google/uuid"
pubsub "github.com/libp2p/go-libp2p-pubsub"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/protocol"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/dms/jobs"
"gitlab.com/nunet/device-management-service/dms/onboarding"
"gitlab.com/nunet/device-management-service/executor"
bt "gitlab.com/nunet/device-management-service/internal/background_tasks"
"gitlab.com/nunet/device-management-service/lib/crypto"
"gitlab.com/nunet/device-management-service/lib/did"
"gitlab.com/nunet/device-management-service/lib/ucan"
"gitlab.com/nunet/device-management-service/network"
"gitlab.com/nunet/device-management-service/types"
)
const (
	// A hello is sent after a random delay in [helloMinDelay, helloMaxDelay).
	helloMinDelay = 10 * time.Second
	helloMaxDelay = 20 * time.Second
	// helloTimeout bounds how long a single hello waits for a reply.
	helloTimeout = 3 * time.Second
	// helloAttempts is the maximum number of hello attempts per peer.
	helloAttempts = 3

	// rootProto is the protocol served by the root actor inbox.
	rootProto = "actor/root/messages/0.0.1"
)
// Node is the structure that holds the node's dependencies.
type Node struct {
rootCap ucan.CapabilityContext
// root actor ("root" inbox) on which all node behaviors are registered
actor actor.Actor
scheduler *bt.Scheduler
network network.Network
resourceManager types.ResourceManager
hardware types.HardwareManager
hostID string
onboarder *onboarding.Onboarding
executor executor.Executor
// node lifetime context: created in New (also given to the executor) and
// cancelled in Stop
ctx context.Context
cancel func()
// mx guards peers, deployments and allocations
mx sync.Mutex
peers map[peer.ID]*peerState
deployments map[string]*jobs.Orchestrator
allocations map[string]*jobs.Allocation
// 1 while started; only accessed with sync/atomic
running int32
}
// peerState tracks connection and hello-handshake state for a connected peer.
type peerState struct {
conns int // number of open connections to the peer
hasRoot bool // the peer advertises the root actor protocol
// helloOut: our hello was answered; helloPending: a hello goroutine is in
// flight; helloIn is presumably set when the peer greets us (outside this chunk)
helloIn, helloOut, helloPending bool
helloAttempts int // hello attempts made so far (capped by the helloAttempts constant)
}
// New creates a new node and attaches a root actor to it.
//
// It derives the root security context from rootCap (the root DID's anchor
// public key and provider private key), creates the root actor on the "root"
// inbox, creates the executor bound to the node context and registers all node
// behaviors. Previously saved deployments are then restored; a restore error
// is only logged.
//
// The node context is cancelled on every error return, so a failed
// construction does not leak it or anything bound to it (e.g. the executor).
func New(onboarder *onboarding.Onboarding,
	rootCap ucan.CapabilityContext,
	hostID string, net network.Network,
	resourceManager types.ResourceManager,
	scheduler *bt.Scheduler,
	hardware types.HardwareManager,
) (*Node, error) {
	if onboarder == nil {
		return nil, errors.New("onboarder is nil")
	}
	if rootCap == nil {
		return nil, errors.New("root capability context is nil")
	}
	if hostID == "" {
		return nil, errors.New("host id is empty")
	}
	if net == nil {
		return nil, errors.New("network is nil")
	}
	if resourceManager == nil {
		return nil, errors.New("resource manager is nil")
	}
	if scheduler == nil {
		return nil, errors.New("scheduler is nil")
	}

	// Build the root security context from the root DID's trust material.
	rootDID := rootCap.DID()
	rootTrust := rootCap.Trust()
	anchor, err := rootTrust.GetAnchor(rootDID)
	if err != nil {
		return nil, fmt.Errorf("failed to get root DID anchor: %w", err)
	}
	pubk := anchor.PublicKey()
	provider, err := rootTrust.GetProvider(rootDID)
	if err != nil {
		return nil, fmt.Errorf("failed to get root DID provider: %w", err)
	}
	privk, err := provider.PrivateKey()
	if err != nil {
		return nil, fmt.Errorf("failed to get root private key: %w", err)
	}
	rootSec, err := actor.NewBasicSecurityContext(pubk, privk, rootCap)
	if err != nil {
		return nil, fmt.Errorf("failed to create security context: %w", err)
	}

	nodeActor, err := createActor(rootSec, actor.NewRateLimiter(actor.DefaultRateLimiterConfig()), hostID, "root", net, scheduler)
	if err != nil {
		return nil, fmt.Errorf("failed to create node actor: %w", err)
	}

	ctx, cancel := context.WithCancel(context.Background())
	// Only a fully constructed node keeps its context (Stop cancels it later);
	// every error return below releases it here.
	constructed := false
	defer func() {
		if !constructed {
			cancel()
		}
	}()

	exec, err := NewExecutor(ctx)
	if err != nil {
		return nil, fmt.Errorf("new executor: %w", err)
	}

	n := &Node{
		hostID:          hostID,
		network:         net,
		deployments:     make(map[string]*jobs.Orchestrator),
		allocations:     make(map[string]*jobs.Allocation),
		peers:           make(map[peer.ID]*peerState),
		resourceManager: resourceManager,
		hardware:        hardware,
		actor:           nodeActor,
		rootCap:         rootCap,
		scheduler:       scheduler,
		onboarder:       onboarder,
		executor:        exec,
		ctx:             ctx,
		cancel:          cancel,
	}

	if err := nodeActor.AddBehavior(PublicHelloBehavior, n.publicHelloBehavior); err != nil {
		return nil, fmt.Errorf("adding public hello behavior: %w", err)
	}
	if err := nodeActor.AddBehavior(PublicStatusBehavior, n.publicStatusBehavior); err != nil {
		return nil, fmt.Errorf("adding public status behavior: %w", err)
	}
	if err := nodeActor.AddBehavior(BroadcastHelloBehavior, n.broadcastHelloBehavior, actor.WithBehaviorTopic(BroadcastHelloTopic)); err != nil {
		return nil, fmt.Errorf("adding broadcast hello behavior: %w", err)
	}
	if err := nodeActor.AddBehavior(PeersListBehavior, n.handlePeersList); err != nil {
		return nil, fmt.Errorf("adding peers list behavior: %w", err)
	}
	if err := nodeActor.AddBehavior(PeerAddrInfoBehavior, n.handlePeerAddrInfo); err != nil {
		return nil, fmt.Errorf("adding peers addr info behavior: %w", err)
	}
	if err := nodeActor.AddBehavior(PeerPingBehavior, n.handlePeerPing); err != nil {
		return nil, fmt.Errorf("adding peer ping behavior: %w", err)
	}
	if err := nodeActor.AddBehavior(PeerDHTBehavior, n.handlePeerDHT); err != nil {
		return nil, fmt.Errorf("adding peer dht behavior: %w", err)
	}
	if err := nodeActor.AddBehavior(PeerConnectBehavior, n.handlePeerConnect); err != nil {
		return nil, fmt.Errorf("adding peer connect behavior: %w", err)
	}
	if err := nodeActor.AddBehavior(PeerScoreBehavior, n.handlePeerScore); err != nil {
		return nil, fmt.Errorf("adding peer score behavior: %w", err)
	}
	if err := nodeActor.AddBehavior(OnboardBehavior, n.handleOnboard); err != nil {
		return nil, fmt.Errorf("adding onboard behavior: %w", err)
	}
	if err := nodeActor.AddBehavior(OffboardBehavior, n.handleOffboard); err != nil {
		return nil, fmt.Errorf("adding offboard behavior: %w", err)
	}
	if err := nodeActor.AddBehavior(OnboardStatusBehavior, n.handleOnboardStatus); err != nil {
		return nil, fmt.Errorf("adding onboard status behavior: %w", err)
	}
	if err := nodeActor.AddBehavior(OnboardResourceBehavior, n.handleOnboardResource); err != nil {
		return nil, fmt.Errorf("adding onboard resource behavior: %w", err)
	}
	if err := nodeActor.AddBehavior(CustomVMStartBehavior, n.handleCustomVMStart); err != nil {
		return nil, fmt.Errorf("adding custom vm start behavior: %w", err)
	}
	if err := nodeActor.AddBehavior(VMStopBehavior, n.handleVMStop); err != nil {
		return nil, fmt.Errorf("adding vm stop behavior: %w", err)
	}
	if err := nodeActor.AddBehavior(VMListBehavior, n.handleListVM); err != nil {
		return nil, fmt.Errorf("adding vm list behavior: %w", err)
	}
	if err := nodeActor.AddBehavior(NewDeploymentBehavior, n.newDeployment); err != nil {
		return nil, fmt.Errorf("adding new deployment behavior: %w", err)
	}

	if err := n.restoreDeployments(); err != nil {
		log.Errorf("restoring deployments: %s", err)
	}

	constructed = true
	return n, nil
}
// CreateAllocation creates an allocation
// func (n *Node) CreateAllocation(job jobs.Job) (*jobs.Allocation, error) {
// // generate random keypair
// priv, pub, err := crypto.GenerateKeyPair(crypto.Ed25519)
// if err != nil {
// return nil, fmt.Errorf("failed to generate random keypair for allocation job %s: %w", job.ID, err)
// }
// security, err := actor.NewBasicSecurityContext(pub, priv, n.rootCap)
// if err != nil {
// return nil, fmt.Errorf("failed to create security context: %w", err)
// }
// allocationInbox, err := uuid.NewUUID()
// if err != nil {
// return nil, fmt.Errorf("failed to generate uuid for allocation inbox: %w", err)
// }
// actor, err := createActor(security, n.actor.Limiter(), n.hostID, allocationInbox.String(), n.network, n.scheduler)
// if err != nil {
// return nil, fmt.Errorf("failed to create allocation actor: %w", err)
// }
// allocation, err := jobs.NewAllocation(actor, jobs.AllocationDetails{Job: job, NodeID: n.hostID}, n.resourceManager)
// if err != nil {
// return nil, fmt.Errorf("failed to create allocation actor: %w", err)
// }
// err = allocation.Start()
// if err != nil {
// return nil, fmt.Errorf("failed to start the allocation: %w", err)
// }
// n.mx.Lock()
// n.allocations[allocation.ID] = allocation
// n.mx.Unlock()
// return allocation, nil
// }
// GetAllocation returns the allocation registered under id, or an
// "allocation not found" error if there is none.
func (n *Node) GetAllocation(id string) (*jobs.Allocation, error) {
	n.mx.Lock()
	alloc, found := n.allocations[id]
	n.mx.Unlock()

	if found {
		return alloc, nil
	}
	return nil, errors.New("allocation not found")
}
// Start starts the node actor and subscribes to the broadcast hello topic.
//
// It is a no-op if the node is already running. If starting the actor or
// subscribing fails, the running flag is reset so that a later Start can
// retry. Previously the flag stayed set and every subsequent Start silently
// returned nil even though the node was not running.
func (n *Node) Start() error {
	if !atomic.CompareAndSwapInt32(&n.running, 0, 1) {
		return nil
	}

	if err := n.actor.Start(); err != nil {
		atomic.StoreInt32(&n.running, 0)
		return fmt.Errorf("failed to start node actor: %w", err)
	}

	if err := n.subscribe(BroadcastHelloTopic); err != nil {
		// best effort: the subscribe error is the one worth reporting
		_ = n.actor.Stop()
		atomic.StoreInt32(&n.running, 0)
		return err
	}
	return nil
}
// subscribe subscribes the node actor to each topic (configuring its score
// parameters via setupBroadcast), installs the broadcast application score
// function and registers the peer lifecycle notifications for the lifetime of
// the actor context.
//
// NOTE(review): n.peerIdentified is passed for both of the last two callbacks
// of Notify; presumably intentional (identify and protocol-update events) —
// confirm against network.Network.Notify.
func (n *Node) subscribe(topics ...string) error {
	for _, t := range topics {
		err := n.actor.Subscribe(t, n.setupBroadcast)
		if err != nil {
			return fmt.Errorf("error subscribing to %s: %w", t, err)
		}
	}

	n.network.SetBroadcastAppScore(n.broadcastScore)

	err := n.network.Notify(
		n.actor.Context(),
		n.peerPreConnected,
		n.peerConnected,
		n.peerDisconnected,
		n.peerIdentified,
		n.peerIdentified,
	)
	if err != nil {
		return fmt.Errorf("error setting up peer notifications: %w", err)
	}
	return nil
}
// setupBroadcast configures the gossipsub score parameters of a broadcast
// topic: time in mesh is rewarded slowly (weight ~1/3600 per second, capped at
// 1) and invalid message deliveries are heavily penalized, decaying over an
// hour.
func (n *Node) setupBroadcast(topic string) error {
	configure := func(t *network.Topic) error {
		params := &pubsub.TopicScoreParams{
			SkipAtomicValidation:           true,
			TopicWeight:                    1.0,
			TimeInMeshWeight:               0.00027, // ~1/3600
			TimeInMeshQuantum:              time.Second,
			TimeInMeshCap:                  1.0,
			InvalidMessageDeliveriesWeight: -1000,
			InvalidMessageDeliveriesDecay:  pubsub.ScoreParameterDecay(time.Hour),
		}
		return t.SetScoreParams(params)
	}
	return n.network.SetupBroadcastTopic(topic, configure)
}
// broadcastScore is the broadcast application score for peer p: 5 when hellos
// have been exchanged in both directions, 1 when the peer only advertises the
// root protocol, and 0 otherwise (including unknown peers).
func (n *Node) broadcastScore(p peer.ID) float64 {
	n.mx.Lock()
	defer n.mx.Unlock()

	st := n.peers[p]
	switch {
	case st == nil:
		return 0
	case st.helloIn && st.helloOut:
		return 5
	case st.hasRoot:
		return 1
	default:
		return 0
	}
}
// peerConnected records a new connection to p, creating its state if needed.
func (n *Node) peerConnected(p peer.ID) {
	log.Debugf("peer connected: %s", p)

	n.mx.Lock()
	defer n.mx.Unlock()

	st := n.peers[p]
	if st == nil {
		st = new(peerState)
		n.peers[p] = st
	}
	st.conns++
}
// peerPreConnected (re)initializes the state of p with its connection count,
// replacing any existing state. If the peer advertises the root protocol, it
// is marked as such and a first hello attempt is started asynchronously.
func (n *Node) peerPreConnected(p peer.ID, protos []protocol.ID, conns int) {
	log.Debugf("peer preconnected: %s %s (%d)", p, protos, conns)

	n.mx.Lock()
	defer n.mx.Unlock()

	st := &peerState{conns: conns}
	n.peers[p] = st
	if !includesRootProtocol(protos) {
		return
	}
	st.hasRoot = true
	st.helloPending = true
	st.helloAttempts = 1
	go n.sayHello(p)
}
// peerIdentified updates the state of p after protocol identification. When
// the peer advertises the root protocol and no hello has succeeded or is in
// flight, a new hello attempt is started asynchronously.
func (n *Node) peerIdentified(p peer.ID, protos []protocol.ID) {
	log.Debugf("peer identified: %s %s", p, protos)

	n.mx.Lock()
	defer n.mx.Unlock()

	st := n.peers[p]
	if st == nil {
		st = new(peerState)
		n.peers[p] = st
	}
	if !includesRootProtocol(protos) {
		return
	}
	st.hasRoot = true
	if st.helloOut || st.helloPending {
		return
	}
	st.helloPending = true
	st.helloAttempts++
	go n.sayHello(p)
}
// peerDisconnected records a closed connection to p and forgets the peer once
// it has no connections left.
func (n *Node) peerDisconnected(p peer.ID) {
	log.Debugf("peer disconnected: %s", p)

	n.mx.Lock()
	defer n.mx.Unlock()

	st, known := n.peers[p]
	if !known {
		return
	}
	st.conns--
	if st.conns > 0 {
		return
	}
	delete(n.peers, p)
}
// sayHello sends a PublicHelloBehavior message to the root actor of peer p.
//
// The peer's actor handle is derived from its libp2p public key. The hello is
// sent after a random delay in [helloMinDelay, helloMaxDelay). That delay now
// aborts when the node context is cancelled (Stop), so hello goroutines no
// longer outlive the node. On a reply, helloOut is set for the peer. On an
// invocation error or timeout, another attempt is scheduled while fewer than
// helloAttempts attempts were made; otherwise helloPending is cleared.
func (n *Node) sayHello(p peer.ID) {
	pubk, err := p.ExtractPublicKey()
	if err != nil {
		log.Debugf("failed to extract public key: %s", err)
		return
	}
	if !crypto.AllowedKey(int(pubk.Type())) {
		log.Debugf("unexpected key type: %d", pubk.Type())
		return
	}
	actorID, err := crypto.IDFromPublicKey(pubk)
	if err != nil {
		log.Debugf("failed to extract actor ID: %s", err)
		return
	}
	actorDID := did.FromPublicKey(pubk)
	handle := actor.Handle{
		ID:  actorID,
		DID: actorDID,
		Address: actor.Address{
			HostID:       p.String(),
			InboxAddress: "root",
		},
	}

	wait := helloMinDelay + time.Duration(rand.Int63n(int64(helloMaxDelay-helloMinDelay))) //nolint
	delay := time.NewTimer(wait)
	select {
	case <-delay.C:
	case <-n.ctx.Done():
		// the node is shutting down; don't talk to peers on its behalf
		delay.Stop()
		return
	}

	n.mx.Lock()
	st, ok := n.peers[p]
	if !ok {
		n.mx.Unlock()
		return
	}
	if !n.network.PeerConnected(p) {
		st.helloPending = false
		n.mx.Unlock()
		return
	}
	n.mx.Unlock()

	msg, err := actor.Message(
		n.actor.Handle(),
		handle,
		PublicHelloBehavior,
		nil,
		actor.WithMessageTimeout(helloTimeout),
	)
	if err != nil {
		log.Debugf("failed to construct hello message: %s", err)
		return
	}

	log.Debugf("saying hello to %s", handle.Address.HostID)
	replyCh, err := n.actor.Invoke(msg)
	if err != nil {
		n.retryHello(p)
		log.Debugf("error invoking hello: %s", err)
		return
	}

	select {
	case reply := <-replyCh:
		reply.Discard()
		n.mx.Lock()
		if st, ok = n.peers[p]; ok {
			st.helloOut = true
			st.helloPending = false
		} else if n.network.PeerConnected(p) {
			// race with connected notification
			st = &peerState{helloOut: true}
			n.peers[p] = st
		}
		n.mx.Unlock()
		log.Infof("got hello response from %s", handle.Address.HostID)
	case <-time.After(time.Until(msg.Expiry())):
		n.retryHello(p)
		log.Debugf("hello timeout for %s", handle.Address.HostID)
	}
}

// retryHello schedules another hello to p if it is still tracked and fewer
// than helloAttempts attempts were made; otherwise it clears helloPending.
// It acquires n.mx itself, so callers must not hold it.
func (n *Node) retryHello(p peer.ID) {
	n.mx.Lock()
	defer n.mx.Unlock()

	st, ok := n.peers[p]
	if !ok {
		return
	}
	if st.helloAttempts < helloAttempts {
		st.helloAttempts++
		go n.sayHello(p)
		return
	}
	st.helloPending = false
}
// Stop shuts the node down. Only the call that flips n.running from 1 to 0 does
// any work; any other call returns nil immediately, so Stop may be called more
// than once.
//
// While holding n.mx it stops every allocation and saves active deployments
// (failures of either are logged, not returned), cancels the node context,
// clears the broadcast app score and finally stops the node actor. Only an
// actor stop failure is returned.
func (n *Node) Stop() error {
	n.mx.Lock()
	defer n.mx.Unlock()

	if !atomic.CompareAndSwapInt32(&n.running, 1, 0) {
		return nil
	}

	// stop all allocations; keep going on failure so the rest still shuts down
	for k, alloc := range n.allocations {
		if err := alloc.Stop(); err != nil {
			// "%s" for the error: the previous "%err" was the %e verb followed
			// by a literal "rr", which mangled the message.
			log.Warnf("error stopping allocation %s: %s", k, err)
		}
	}

	if err := n.saveDeployments(); err != nil {
		log.Errorf("error saving active deployments: %s", err)
	}

	n.cancel()

	// clear the broadcast app score
	n.network.SetBroadcastAppScore(nil)

	if err := n.actor.Stop(); err != nil {
		return fmt.Errorf("failed to stop node actor: %w", err)
	}
	return nil
}
// sendReply answers msg with payload through the node actor. When msg was a
// broadcast, the reply is stamped with the node's own handle as its source.
// Failures to build or send the reply are logged at debug level only.
func (n *Node) sendReply(msg actor.Envelope, payload interface{}) {
	var opts []actor.MessageOption
	if msg.IsBroadcast() {
		opts = append(opts, actor.WithMessageSource(n.actor.Handle()))
	}

	reply, err := actor.ReplyTo(msg, payload, opts...)
	if err != nil {
		log.Debugf("error creating reply: %s", err)
		return
	}

	if err = n.actor.Send(reply); err != nil {
		log.Debugf("error sending reply: %s", err)
	}
}
// createActor builds a BasicActor whose own handle combines the identity of the
// security context (its ID and DID) with the given host ID and inbox address.
// Construction errors from actor.New are wrapped and returned.
func createActor(sctx *actor.BasicSecurityContext, limiter actor.RateLimiter, hostID, inboxAddress string, net network.Network, scheduler *bt.Scheduler) (*actor.BasicActor, error) {
	self := actor.Handle{
		ID:  sctx.ID(),
		DID: sctx.DID(),
		Address: actor.Address{
			HostID:       hostID,
			InboxAddress: inboxAddress,
		},
	}

	// Named "a" rather than "actor" so the result does not shadow the actor
	// package for the rest of the function.
	a, err := actor.New(scheduler, net, sctx, limiter, actor.BasicActorParams{}, self)
	if err != nil {
		return nil, fmt.Errorf("failed to create actor: %w", err)
	}
	return a, nil
}
// includesRootProtocol reports whether protos contains rootProto.
func includesRootProtocol(protos []protocol.ID) bool {
	found := false
	for i := 0; i < len(protos) && !found; i++ {
		found = protos[i] == rootProto
	}
	return found
}
//go:build linux
// +build linux
package node
import (
"context"
"errors"
"fmt"
"gitlab.com/nunet/device-management-service/executor"
"gitlab.com/nunet/device-management-service/executor/firecracker"
"gitlab.com/nunet/device-management-service/executor/null"
)
// NewExecutor returns the executor used on Linux. It prefers a Firecracker
// executor. If Firecracker reports firecracker.ErrNotInstalled, it falls back
// to the null executor. Any other Firecracker error is wrapped and returned.
func NewExecutor(ctx context.Context) (executor.Executor, error) {
	// Locals are not named "executor" so they do not shadow the executor package.
	fc, err := firecracker.NewExecutor(ctx, "root")
	if err == nil {
		return fc, nil
	}
	if !errors.Is(err, firecracker.ErrNotInstalled) {
		return nil, fmt.Errorf("failed to create executor: %w", err)
	}

	nullExec, err := null.NewExecutor(ctx, "root")
	if err != nil {
		return nil, fmt.Errorf("failed to setup null executor: %w", err)
	}
	return nullExec, nil
}
package node
import (
"github.com/libp2p/go-libp2p/core/peer"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/lib/did"
"gitlab.com/nunet/device-management-service/types"
)
// Behavior and topic names used by the node's hello/status handlers.
const (
// PublicHelloBehavior is the behavior sayHello invokes on a peer's "root" actor.
PublicHelloBehavior = "/public/hello"
// PublicStatusBehavior names the public status query; presumably routed to
// publicStatusBehavior (registration not visible here — confirm).
PublicStatusBehavior = "/public/status"
// BroadcastHelloBehavior names the broadcast hello; presumably routed to
// broadcastHelloBehavior (registration not visible here — confirm).
BroadcastHelloBehavior = "/broadcast/hello"
// BroadcastHelloTopic is presumably the pubsub topic hello broadcasts are
// published on — confirm at the subscription site.
BroadcastHelloTopic = "/nunet/hello"
)
// HelloResponse is the payload handleHello sends back to a greeting peer.
type HelloResponse struct {
// DID is the responding node's DID, taken from its actor's security context.
DID did.DID
}
// PublicStatusResponse is the payload publicStatusBehavior replies with.
type PublicStatusResponse struct {
// Status is "OK" when the machine resources were read successfully and
// "ERROR" otherwise.
Status string
// Resources holds the machine resources; left as the zero value when Status
// is "ERROR".
Resources types.Resources
}
// publicHelloBehavior handles a hello sent directly to this node. It derives
// the sender's libp2p peer ID from the DID in the envelope and marks that peer
// as having greeted us (helloIn). It then replies with our DID via handleHello.
//
// If the sender's DID cannot be mapped to a peer ID, the message is discarded
// without a reply.
func (n *Node) publicHelloBehavior(msg actor.Envelope) {
	pubk, err := did.PublicKeyFromDID(msg.From.DID)
	if err != nil {
		log.Debugf("failed to extract public key from DID: %s", err)
		// every handler path must release the message, as handleHello does
		msg.Discard()
		return
	}
	p, err := peer.IDFromPublicKey(pubk)
	if err != nil {
		log.Debugf("failed to extract peer ID from public key: %s", err)
		msg.Discard()
		return
	}

	n.mx.Lock()
	if st, ok := n.peers[p]; ok {
		st.helloIn = true
	} else if n.network.PeerConnected(p) {
		// race with connected notification: the peer is connected but its
		// state has not been recorded yet
		n.peers[p] = &peerState{helloIn: true}
	}
	n.mx.Unlock()

	n.handleHello(msg)
}
// broadcastHelloBehavior handles a hello received via broadcast. Unlike
// publicHelloBehavior it does not touch per-peer hello state; it only replies
// (and discards the message) through handleHello.
func (n *Node) broadcastHelloBehavior(msg actor.Envelope) {
n.handleHello(msg)
}
// handleHello answers a hello message with a HelloResponse carrying this
// node's DID. The incoming message is always discarded once handled.
func (n *Node) handleHello(msg actor.Envelope) {
	defer msg.Discard()

	log.Debugf("hello from %s", msg.From.Address.HostID)
	n.sendReply(msg, HelloResponse{
		DID: n.actor.Security().DID(),
	})
}
// publicStatusBehavior replies to a status query with this machine's resources.
// The reply's Status is "OK" when hardware.GetMachineResources succeeds. When
// it fails, the Status is "ERROR" with zero Resources, and the failure is
// logged locally. The message is always discarded.
func (n *Node) publicStatusBehavior(msg actor.Envelope) {
	defer msg.Discard()

	var resp PublicStatusResponse
	machineResources, err := n.hardware.GetMachineResources()
	if err != nil {
		// The requester only sees "ERROR"; log the cause so it can be diagnosed here.
		log.Warnf("failed to get machine resources for status: %s", err)
		resp.Status = "ERROR"
	} else {
		resp.Status = "OK"
		resp.Resources = machineResources.Resources
	}
	n.sendReply(msg, resp)
}
package onboarding
import (
"context"
"errors"
"fmt"
"github.com/spf13/afero"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// ErrMachineNotOnboarded is returned by IsOnboarded whenever the onboarding
// config cannot be read from the config repository.
var ErrMachineNotOnboarded = errors.New("machine is not onboarded")
// Config holds the dependencies and paths used by Onboarding.
type Config struct {
// Fs is the filesystem used to check that WorkDir exists.
Fs afero.Afero
// WorkDir must exist for onboarding to succeed (checked by validatePrerequisites).
WorkDir string
// DatabasePath is the database location; not used by the onboarding methods shown here.
DatabasePath string
// ConfigRepo persists the onboarding configuration.
ConfigRepo repositories.OnboardingConfig
// P2PRepo holds libp2p info; not used by the onboarding methods shown here.
P2PRepo repositories.Libp2pInfo
// ResourceManager tracks onboarded resources. NewConfig does not set it.
ResourceManager types.ResourceManager
// Hardware reports the machine's resources. NewConfig does not set it.
Hardware types.HardwareManager
// UUIDRepo holds the machine UUID; not used by the onboarding methods shown here.
UUIDRepo repositories.MachineUUID
}
// NewConfig builds a Config from the given filesystem, paths and repositories.
//
// ResourceManager and Hardware are NOT set and remain nil. Callers must fill
// them in before using Info, Onboard, Update, Offboard or the capacity checks,
// which call methods on them.
func NewConfig(
	fs afero.Afero,
	workDir, dbPath string,
	configRepo repositories.OnboardingConfig,
	p2pRepo repositories.Libp2pInfo,
	uuidRepo repositories.MachineUUID,
) *Config {
	cfg := new(Config)
	cfg.Fs = fs
	cfg.WorkDir = workDir
	cfg.DatabasePath = dbPath
	cfg.ConfigRepo = configRepo
	cfg.P2PRepo = p2pRepo
	cfg.UUIDRepo = uuidRepo
	return cfg
}
// Onboarding is the receiver for the onboarding methods. It embeds Config, so
// its repositories and managers are accessed directly (o.ConfigRepo, o.Hardware, ...).
type Onboarding struct {
Config
}
// New returns an Onboarding holding a copy of *config, so later changes to
// config are not seen by the result. config must be non-nil.
func New(config *Config) *Onboarding {
	o := &Onboarding{}
	o.Config = *config
	return o
}
// IsOnboarded reports whether an onboarding config can be loaded from the
// config repository. Any repository error, whatever its cause, is reported as
// (false, ErrMachineNotOnboarded); the underlying error is not propagated.
func (o *Onboarding) IsOnboarded(ctx context.Context) (bool, error) {
	// TODO: validate onboarding params
	if _, err := o.ConfigRepo.Get(ctx); err != nil {
		return false, ErrMachineNotOnboarded
	}
	return true, nil
}
// Info returns the stored onboarding configuration, filled in with two extras:
// the onboarded resources from the resource manager and the machine's
// resources from the hardware manager. Any failure yields a zero config and a
// wrapped error.
func (o *Onboarding) Info(ctx context.Context) (types.OnboardingConfig, error) {
	var none types.OnboardingConfig

	info, err := o.ConfigRepo.Get(ctx)
	if err != nil {
		return none, fmt.Errorf("could not get onboarding config: %w", err)
	}

	onboarded, err := o.ResourceManager.GetOnboardedResources(ctx)
	if err != nil {
		return none, fmt.Errorf("could not get onboarded resources: %w", err)
	}

	machine, err := o.Hardware.GetMachineResources()
	if err != nil {
		return none, fmt.Errorf("could not get machine resources: %w", err)
	}

	info.OnboardedResources = onboarded.Resources
	info.MachineResources = machine.Resources
	return info, nil
}
// Onboard onboards the machine in three steps:
//  1. Check the prerequisites: the working directory exists and the requested
//     capacity is within bounds.
//  2. Register the requested resources with the resource manager.
//  3. Persist the onboarding config.
//
// The steps are not transactional: if saving the config fails, the resource
// manager has already been updated.
func (o *Onboarding) Onboard(ctx context.Context, config types.OnboardingConfig) error {
	log.Debugf("onboarding the machine with the config: %+v", config)

	err := o.validatePrerequisites(config)
	if err != nil {
		return fmt.Errorf("could not validate onboarding prerequisites: %w", err)
	}

	err = o.ResourceManager.UpdateOnboardedResources(ctx, config.OnboardedResources)
	if err != nil {
		return fmt.Errorf("could not update onboarded resources: %w", err)
	}

	if _, err = o.ConfigRepo.Save(ctx, config); err != nil {
		return fmt.Errorf("could not save onboarding config: %w", err)
	}

	log.Debugf("machine onboarded successfully")
	return nil
}
// Update applies a new onboarding configuration. Currently only the onboarded
// resources are updated; the stored onboarding config itself is not rewritten.
//
// If the requested resources equal the currently onboarded ones, nothing is
// done. Otherwise the new resources are checked against the same capacity
// bounds as Onboard before being handed to the resource manager.
func (o *Onboarding) Update(ctx context.Context, config types.OnboardingConfig) error {
	log.Debugf("updating the onboarding config with the new config: %+v", config)

	onboardedResources, err := o.ResourceManager.GetOnboardedResources(ctx)
	if err != nil {
		return fmt.Errorf("could not get onboarded resources: %w", err)
	}
	if onboardedResources.Equal(config.OnboardedResources) {
		return nil
	}

	// UpdateOnboardedResources only checks the new amount against existing
	// allocations, not against the machine. Apply Onboard's capacity check here
	// so an update cannot onboard more than the machine offers.
	if err := o.validateCapacity(config.OnboardedResources); err != nil {
		return fmt.Errorf("could not validate capacity data: %w", err)
	}

	if err := o.ResourceManager.UpdateOnboardedResources(ctx, config.OnboardedResources); err != nil {
		return fmt.Errorf("could not update onboarded resources: %w", err)
	}

	log.Debugf("onboarding config updated successfully")
	return nil
}
// Offboard offboards the machine in two steps: it clears the onboarding config
// from the database, then resets the onboarded resources to zero.
//
// Without force, Offboard fails if any of these happen:
//   - the onboarding state cannot be read;
//   - the machine is not onboarded;
//   - clearing the config fails.
//
// With force, those problems are logged and offboarding continues. A failure to
// reset the onboarded resources is always returned.
func (o *Onboarding) Offboard(ctx context.Context, force bool) error {
	onboarded, err := o.IsOnboarded(ctx)
	// IsOnboarded returns false together with any error. "Not onboarded" must
	// therefore only be fatal when not forced; otherwise force could never
	// take effect.
	switch {
	case err != nil && !force:
		return fmt.Errorf("could not retrieve onboard status: %w", err)
	case err != nil:
		log.Errorf("problem with onboarding state: %v", err)
		log.Info("continuing with offboarding because forced")
	case !onboarded && !force:
		return ErrMachineNotOnboarded
	}

	// TODO: shutdown routine to stop networking etc... here
	if err := o.ConfigRepo.Clear(ctx); err != nil {
		if !force {
			return fmt.Errorf("failed to clear onboarding config from db: %w", err)
		}
		log.Errorf("failed to clear onboarding config from db: %v", err)
		log.Info("continuing with offboarding because forced")
	}

	// clear the onboarded resources
	if err := o.ResourceManager.UpdateOnboardedResources(ctx, types.Resources{}); err != nil {
		return fmt.Errorf("could not clear onboarded resources: %w", err)
	}
	return nil
}
// validateCapacity checks requested resources against the machine's resources:
//   - at least one CPU core;
//   - CPU compute within [10%, 90%] of the machine's compute;
//   - RAM size within [10%, 90%] of the machine's RAM.
// It returns a descriptive error for the first violated rule.
func (o *Onboarding) validateCapacity(resources types.Resources) error {
	// TODO: https://gitlab.com/nunet/device-management-service/-/merge_requests/563#note_2139212199
	machine, err := o.Hardware.GetMachineResources()
	if err != nil {
		return fmt.Errorf("could not get provisioned resources: %w", err)
	}

	if resources.CPU.Cores < 1 {
		return fmt.Errorf("cores must be between %d and %.0f", 1, machine.CPU.Cores)
	}

	machineCompute := machine.CPU.Compute()
	minCompute, maxCompute := machineCompute/10, machineCompute*9/10
	requestedCompute := resources.CPU.Compute()
	if requestedCompute > maxCompute || requestedCompute < minCompute {
		return fmt.Errorf("CPU should be between 10%% and 90%% of the available CPU (%.2f and %.2f): %.2f",
			minCompute, maxCompute, requestedCompute)
	}

	minRAM, maxRAM := machine.RAM.Size/10, machine.RAM.Size*9/10
	if resources.RAM.Size > maxRAM || resources.RAM.Size < minRAM {
		return fmt.Errorf("memory should be between 10%% and 90%% of the available memory (%.2f and %.2f): %.2f",
			minRAM, maxRAM, resources.RAM.Size)
	}

	return nil
}
// validatePrerequisites checks that the working directory (o.WorkDir) exists
// on o.Fs and that config.OnboardedResources passes validateCapacity.
func (o *Onboarding) validatePrerequisites(config types.OnboardingConfig) error {
	ok, err := o.Fs.DirExists(o.WorkDir)
	if err != nil {
		// the directory checked is WorkDir, not a config directory
		return fmt.Errorf("could not check if working directory exists: %w", err)
	}
	if !ok {
		return fmt.Errorf("working directory does not exist")
	}

	if err := o.validateCapacity(config.OnboardedResources); err != nil {
		return fmt.Errorf("could not validate capacity data: %w", err)
	}
	return nil
}
package resources
import (
"context"
"fmt"
"sync"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// ManagerRepos holds the repositories in which DefaultManager persists its state.
type ManagerRepos struct {
// FreeResources persists the resources still available for allocation.
FreeResources repositories.FreeResources
// OnboardedResources persists the resources onboarded to the DMS.
OnboardedResources repositories.OnboardedResources
// ResourceAllocation persists per-job allocations (looked up by JobID).
ResourceAllocation repositories.ResourceAllocation
}
// DefaultManager implements types.ResourceManager. Its state is persisted
// through repos and cached in store. hardware is consulted for the machine's
// actual free resources when allocating.
// TODO: Add telemetry for the methods https://gitlab.com/nunet/device-management-service/-/issues/535
type DefaultManager struct {
repos ManagerRepos
store *store
hardware types.HardwareManager
// allocationLock serializes allocation and deallocation. It makes each one's
// read-check-write of the free resources and allocations atomic with respect
// to the others.
allocationLock sync.RWMutex
}
// NewResourceManager returns a DefaultManager backed by repos and hardware,
// with an empty in-memory store. It fails if hardware is a nil interface
// (a typed nil pointer wrapped in the interface is not detected).
func NewResourceManager(repos ManagerRepos, hardware types.HardwareManager) (*DefaultManager, error) {
	if hardware == nil {
		return nil, fmt.Errorf("hardware manager cannot be nil")
	}

	m := &DefaultManager{
		repos:    repos,
		hardware: hardware,
	}
	m.store = newStore()
	return m, nil
}

// compile-time check that DefaultManager satisfies types.ResourceManager
var _ types.ResourceManager = (*DefaultManager)(nil)
// AllocateResources reserves allocation.Resources for allocation.JobID.
//
// Holding allocationLock, it fails in any of these cases:
//   - the job already has an allocation;
//   - the DMS free pool cannot cover the request;
//   - the machine's currently free resources cannot cover it.
//
// On success it persists the reduced free pool and the new allocation.
func (d *DefaultManager) AllocateResources(ctx context.Context, allocation types.ResourceAllocation) error {
	d.allocationLock.Lock()
	defer d.allocationLock.Unlock()

	var alreadyAllocated bool
	d.store.withAllocationsRLock(func() {
		_, alreadyAllocated = d.store.allocations[allocation.JobID]
	})
	if alreadyAllocated {
		return fmt.Errorf("resources already allocated for job %s", allocation.JobID)
	}

	// The DMS pool must have room for the request...
	pool, err := d.GetFreeResources(ctx)
	if err != nil {
		return fmt.Errorf("getting free resources: %w", err)
	}
	if err = pool.Subtract(allocation.Resources); err != nil {
		return fmt.Errorf("no free resources: %w", err)
	}

	// ...and so must the machine itself right now.
	machineFree, err := d.hardware.GetFreeResources()
	if err != nil {
		return fmt.Errorf("get system free resources: %w", err)
	}
	if err = machineFree.Subtract(allocation.Resources); err != nil {
		return fmt.Errorf("no free resources on the machine: %w", err)
	}

	// NOTE: these two writes are not transactional. If storing the allocation
	// fails, the free pool has already been reduced. Fixing that needs a
	// transaction mechanism the db layer does not provide.
	if err = d.updateFreeResources(ctx, pool); err != nil {
		return fmt.Errorf("updating free resources in db: %w", err)
	}
	if err = d.storeAllocation(ctx, allocation); err != nil {
		return fmt.Errorf("storing allocations in db: %w", err)
	}
	return nil
}
// DeallocateResources releases the resources held by jobID. While holding
// allocationLock, it does the following:
//  1. Return the job's resources to the free pool.
//  2. Persist the enlarged pool.
//  3. Delete the allocation.
//
// It fails if the job has no allocation.
func (d *DefaultManager) DeallocateResources(ctx context.Context, jobID string) error {
	d.allocationLock.Lock()
	defer d.allocationLock.Unlock()

	var (
		existing types.ResourceAllocation
		found    bool
	)
	d.store.withAllocationsRLock(func() {
		existing, found = d.store.allocations[jobID]
	})
	if !found {
		return fmt.Errorf("resources not allocated for job %s", jobID)
	}

	pool, err := d.GetFreeResources(ctx)
	if err != nil {
		return fmt.Errorf("getting free resources: %w", err)
	}
	if err = pool.Add(existing.Resources); err != nil {
		return fmt.Errorf("adding resources: %w", err)
	}

	// NOTE: these two writes are not transactional. If deleting the allocation
	// fails, the free pool has already been enlarged. Fixing that needs a
	// transaction mechanism the db layer does not provide.
	if err = d.updateFreeResources(ctx, pool); err != nil {
		return fmt.Errorf("updating free resources in db: %w", err)
	}
	if err = d.deleteAllocation(ctx, jobID); err != nil {
		return fmt.Errorf("deleting allocations from db: %w", err)
	}
	return nil
}
// GetFreeResources returns the free resources in the allocation pool. It
// serves the cached value when present; otherwise it loads the value from the
// repository and caches it.
func (d *DefaultManager) GetFreeResources(ctx context.Context) (types.FreeResources, error) {
	var (
		cached types.FreeResources
		hit    bool
	)
	d.store.withFreeRLock(func() {
		if p := d.store.freeResources; p != nil {
			cached, hit = *p, true
		}
	})
	if hit {
		return cached, nil
	}

	fromDB, err := d.repos.FreeResources.Get(ctx)
	if err != nil {
		return types.FreeResources{}, fmt.Errorf("failed to get free resources: %w", err)
	}
	d.store.withFreeLock(func() {
		d.store.freeResources = &fromDB
	})
	return fromDB, nil
}
// GetTotalAllocation returns the sum of the resources of all job allocations.
//
// If the in-memory store holds no allocations, they are first loaded from the
// database. On error the returned Resources is the zero value.
func (d *DefaultManager) GetTotalAllocation() (types.Resources, error) {
	// Read the map length under the allocations lock. An unguarded len() races
	// with storeAllocation/deleteAllocation writing to the map.
	var empty bool
	d.store.withAllocationsRLock(func() {
		empty = len(d.store.allocations) == 0
	})
	if empty {
		if err := d.getAllocationsFromDB(context.Background()); err != nil {
			return types.Resources{}, fmt.Errorf("getting allocations from db: %w", err)
		}
	}

	var (
		totalAllocation types.Resources
		err             error
	)
	d.store.withAllocationsRLock(func() {
		for _, allocation := range d.store.allocations {
			if err = totalAllocation.Add(allocation.Resources); err != nil {
				return
			}
		}
	})
	if err != nil {
		// don't hand back a partial sum
		return types.Resources{}, err
	}
	return totalAllocation, nil
}
// GetOnboardedResources returns the resources onboarded to this machine. It
// serves the cached value when present; otherwise it loads the value from the
// repository and caches it.
func (d *DefaultManager) GetOnboardedResources(ctx context.Context) (types.OnboardedResources, error) {
	var (
		cached types.OnboardedResources
		hit    bool
	)
	d.store.withOnboardedRLock(func() {
		if p := d.store.onboardedResources; p != nil {
			cached, hit = *p, true
		}
	})
	if hit {
		return cached, nil
	}

	fromDB, err := d.repos.OnboardedResources.Get(ctx)
	if err != nil {
		return types.OnboardedResources{}, fmt.Errorf("failed to get onboarded resources: %w", err)
	}
	// the callback never fails, so the helper's error result is ignored
	_ = d.store.withOnboardedLock(func() error {
		d.store.onboardedResources = &fromDB
		return nil
	})
	return fromDB, nil
}
// UpdateOnboardedResources replaces the onboarded resources with resources. It
// recomputes the free resources as resources minus the current total
// allocation.
//
// It fails without persisting anything if the existing allocations do not fit
// in resources. The onboarded and free resources are saved with two separate
// repository writes. If the second write fails, the two are left inconsistent.
func (d *DefaultManager) UpdateOnboardedResources(ctx context.Context, resources types.Resources) error {
	// The new free pool is derived from the allocations read below. Hold
	// allocationLock so AllocateResources/DeallocateResources cannot change the
	// allocations or the free pool between that read and our write. Otherwise
	// their update would be silently overwritten.
	d.allocationLock.Lock()
	defer d.allocationLock.Unlock()

	return d.store.withOnboardedLock(func() error {
		totalAllocation, err := d.GetTotalAllocation()
		if err != nil {
			return fmt.Errorf("getting total allocations: %w", err)
		}
		onboardedResources := types.OnboardedResources{Resources: resources}

		// What remains after the existing allocations becomes the new free
		// pool; fail if the allocations no longer fit.
		if err := resources.Subtract(totalAllocation); err != nil {
			return fmt.Errorf("couldn't subtract allocation: %w. Demand too high", err)
		}

		// NOTE: not transactional (see doc comment); the db layer offers no
		// transactions to make these two writes atomic.
		if _, err := d.repos.OnboardedResources.Save(ctx, onboardedResources); err != nil {
			return fmt.Errorf("failed to update onboarded resources: %w", err)
		}
		d.store.onboardedResources = &onboardedResources

		if err := d.updateFreeResources(ctx, types.FreeResources{
			Resources: resources,
		}); err != nil {
			return fmt.Errorf("updating free resources in db: %w", err)
		}
		return nil
	})
}
// updateFreeResources persists freeResources. Only after the repository write
// succeeds does it replace the cached copy.
func (d *DefaultManager) updateFreeResources(ctx context.Context, freeResources types.FreeResources) error {
	if _, err := d.repos.FreeResources.Save(ctx, freeResources); err != nil {
		return fmt.Errorf("updating free resources: %w", err)
	}

	d.store.withFreeLock(func() {
		d.store.freeResources = &freeResources
	})
	return nil
}
// getAllocationsFromDB loads every stored allocation into the in-memory map,
// keyed by JobID. Existing entries with the same JobID are overwritten; other
// entries are left in place.
func (d *DefaultManager) getAllocationsFromDB(ctx context.Context) error {
	repo := d.repos.ResourceAllocation
	rows, err := repo.FindAll(ctx, repo.GetQuery())
	if err != nil {
		return fmt.Errorf("getting allocations from db: %w", err)
	}

	d.store.withAllocationsLock(func() {
		for i := range rows {
			d.store.allocations[rows[i].JobID] = rows[i]
		}
	})
	return nil
}
// storeAllocation creates allocation in the repository. Only on success does
// it record allocation in the in-memory map under its JobID.
func (d *DefaultManager) storeAllocation(ctx context.Context, allocation types.ResourceAllocation) error {
	if _, err := d.repos.ResourceAllocation.Create(ctx, allocation); err != nil {
		return fmt.Errorf("storing allocations in db: %w", err)
	}

	d.store.withAllocationsLock(func() {
		d.store.allocations[allocation.JobID] = allocation
	})
	return nil
}
// deleteAllocation deletes the allocation for jobID in three steps:
//  1. Look the record up in the repository by JobID.
//  2. Delete it by ID.
//  3. Drop it from the in-memory map.
//
// Both repository calls use the caller's ctx, so cancellation and deadlines
// apply to the lookup as well as the delete.
func (d *DefaultManager) deleteAllocation(ctx context.Context, jobID string) error {
	query := d.repos.ResourceAllocation.GetQuery()
	query.Conditions = append(query.Conditions, repositories.EQ("JobID", jobID))

	// use ctx, not context.Background(), so the lookup honours cancellation
	allocation, err := d.repos.ResourceAllocation.Find(ctx, query)
	if err != nil {
		return fmt.Errorf("finding allocations in db: %w", err)
	}

	if err := d.repos.ResourceAllocation.Delete(ctx, allocation.ID); err != nil {
		return fmt.Errorf("deleting allocations from db: %w", err)
	}

	d.store.withAllocationsLock(func() {
		delete(d.store.allocations, jobID)
	})
	return nil
}
package resources
import (
"sync"
"gitlab.com/nunet/device-management-service/types"
)
// locks groups the mutexes guarding the fields of store:
//   - allocations guards store.allocations
//   - onboarded guards store.onboardedResources
//   - free guards store.freeResources
// Each is an RWMutex so readers (the with*RLock helpers) can proceed concurrently.
type locks struct {
allocations sync.RWMutex
onboarded sync.RWMutex
free sync.RWMutex
}
// newLocks returns a locks value with all mutexes unlocked (their zero value).
func newLocks() *locks {
	return new(locks)
}
// store is the resource manager's in-memory cache. Each field is guarded by the
// matching mutex in locks; access it through the with*Lock / with*RLock helpers.
//   - onboardedResources: resources onboarded to the machine; nil until loaded or set
//   - freeResources: resources still free to be allocated; nil until loaded or set
//   - allocations: resource allocations requested by jobs, keyed by job ID
type store struct {
onboardedResources *types.OnboardedResources
freeResources *types.FreeResources
allocations map[string]types.ResourceAllocation
locks *locks
}
// newStore returns an empty store: no cached onboarded or free resources, an
// empty (non-nil) allocations map, and fresh locks.
func newStore() *store {
	s := &store{locks: newLocks()}
	s.allocations = make(map[string]types.ResourceAllocation)
	return s
}
// withAllocationsLock runs fn while holding the allocations write lock.
func (s *store) withAllocationsLock(fn func()) {
	mu := &s.locks.allocations
	mu.Lock()
	defer mu.Unlock()
	fn()
}
// withOnboardedLock runs fn while holding the onboarded write lock and returns
// fn's error.
func (s *store) withOnboardedLock(fn func() error) error {
	mu := &s.locks.onboarded
	mu.Lock()
	defer mu.Unlock()
	return fn()
}
// withFreeLock runs fn while holding the free-resources write lock.
func (s *store) withFreeLock(fn func()) {
	mu := &s.locks.free
	mu.Lock()
	defer mu.Unlock()
	fn()
}
// withAllocationsRLock runs fn while holding the allocations read lock. fn
// must not modify the allocations map. Results are passed back via captured
// variables; nothing is returned.
func (s *store) withAllocationsRLock(fn func()) {
	mu := &s.locks.allocations
	mu.RLock()
	defer mu.RUnlock()
	fn()
}
// withOnboardedRLock runs fn while holding the onboarded read lock. fn must not
// modify store.onboardedResources. Nothing is returned.
func (s *store) withOnboardedRLock(fn func()) {
	mu := &s.locks.onboarded
	mu.RLock()
	defer mu.RUnlock()
	fn()
}
// withFreeRLock runs fn while holding the free-resources read lock. fn must not
// modify store.freeResources. Nothing is returned.
func (s *store) withFreeRLock(fn func()) {
	mu := &s.locks.free
	mu.RLock()
	defer mu.RUnlock()
	fn()
}
package dms
import (
"gorm.io/gorm"
)
// SanityCheck is currently a no-op placeholder; its argument is ignored.
//
// It used to run basic consistency checks before starting the DMS. In order,
// it would:
//   - stop and remove services marked as running in the database, and set
//     their status to 'finished with errors';
//   - recalculate free resources and update the database.
//
// That implementation was deleted because its dependencies (such as the old
// docker package) were replaced by executor/docker.
func SanityCheck(_ *gorm.DB) {
// TODO: sanity check of DMS last exit and correction of invalid states
// resources.CalcFreeResAndUpdateDB()
}
package docker
import (
"bufio"
"context"
"encoding/json"
"fmt"
"io"
"os"
"strings"
"sync"
"github.com/docker/docker/api/types/image"
"github.com/docker/docker/api/types"
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/api/types/filters"
"github.com/docker/docker/api/types/network"
"github.com/docker/docker/client"
"github.com/docker/docker/pkg/jsonmessage"
"github.com/docker/docker/pkg/stdcopy"
v1 "github.com/opencontainers/image-spec/specs-go/v1"
"github.com/pkg/errors"
"go.uber.org/multierr"
)
// Client wraps the Docker client to provide high-level operations on Docker containers and networks.
type Client struct {
client *client.Client // underlying Docker API client (a named field, not embedded)
}
// NewDockerClient creates a Client configured from the environment
// (client.FromEnv / client.WithHostFromEnv), with API version negotiation
// enabled. The SDK error is returned unwrapped.
func NewDockerClient() (*Client, error) {
	dc, err := client.NewClientWithOpts(
		client.FromEnv,
		client.WithAPIVersionNegotiation(),
		client.WithHostFromEnv(),
	)
	if err != nil {
		return nil, err
	}
	return &Client{client: dc}, nil
}
// IsInstalled reports whether the Docker daemon answers a ping. A false result
// means the daemon was unreachable, not necessarily that Docker is absent.
func (c *Client) IsInstalled(ctx context.Context) bool {
	if _, err := c.client.Ping(ctx); err != nil {
		return false
	}
	return true
}
// CreateContainer pulls config.Image and then creates (but does not start) a
// container from the given configuration. It returns the new container's ID.
// Errors from the pull or the create are returned unwrapped.
func (c *Client) CreateContainer(
	ctx context.Context,
	config *container.Config,
	hostConfig *container.HostConfig,
	networkingConfig *network.NetworkingConfig,
	platform *v1.Platform,
	name string,
) (string, error) {
	if _, err := c.PullImage(ctx, config.Image); err != nil {
		return "", err
	}

	created, err := c.client.ContainerCreate(ctx, config, hostConfig, networkingConfig, platform, name)
	if err != nil {
		return "", err
	}
	return created.ID, nil
}
// InspectContainer returns the Docker inspect data for container id, exactly as
// returned by the SDK's ContainerInspect.
func (c *Client) InspectContainer(ctx context.Context, id string) (types.ContainerJSON, error) {
return c.client.ContainerInspect(ctx, id)
}
// FollowLogs streams the stdout and stderr of container id, following new
// output until the log stream ends or ctx is cancelled.
//
// A background goroutine demultiplexes the Docker log stream into two pipes,
// one per reader; both readers see io.EOF once the stream ends. Writes go
// through a bufio.Writer (4 KiB by default), so short output may reach the
// readers only when the buffer fills or the stream ends.
//
// Callers should drain both readers: the goroutine blocks writing to a pipe
// that nobody reads.
func (c *Client) FollowLogs(ctx context.Context, id string) (stdout, stderr io.Reader, err error) {
	cont, err := c.InspectContainer(ctx, id)
	if err != nil {
		return nil, nil, errors.Wrap(err, "failed to get container")
	}

	logOptions := container.LogsOptions{
		ShowStdout: true,
		ShowStderr: true,
		Follow:     true,
	}
	logsReader, err := c.client.ContainerLogs(ctx, cont.ID, logOptions)
	if err != nil {
		return nil, nil, errors.Wrap(err, "failed to get container logs")
	}

	stdoutReader, stdoutWriter := io.Pipe()
	stderrReader, stderrWriter := io.Pipe()

	go func() {
		stdoutBuffer := bufio.NewWriter(stdoutWriter)
		stderrBuffer := bufio.NewWriter(stderrWriter)
		defer func() {
			logsReader.Close()
			stdoutBuffer.Flush()
			stdoutWriter.Close()
			stderrBuffer.Flush()
			stderrWriter.Close()
		}()

		// copyErr is goroutine-local. Assigning the named result err here would
		// race with FollowLogs returning.
		_, copyErr := stdcopy.StdCopy(stdoutBuffer, stderrBuffer, logsReader)
		if copyErr != nil && !errors.Is(copyErr, context.Canceled) {
			// cancellation is the normal way to stop following; anything else is unexpected
			zlog.Sugar().Warnf("error copying container logs: %v", copyErr)
		}
	}()

	return stdoutReader, stderrReader, nil
}
// StartContainer starts the already-created container containerID using the
// SDK's default start options.
func (c *Client) StartContainer(ctx context.Context, containerID string) error {
return c.client.ContainerStart(ctx, containerID, container.StartOptions{})
}
// WaitContainer waits for containerID to reach the "not running" condition
// (container.WaitConditionNotRunning). It returns the SDK's result and error
// channels unchanged, so callers should select on both.
func (c *Client) WaitContainer(
ctx context.Context,
containerID string,
) (<-chan container.WaitResponse, <-chan error) {
return c.client.ContainerWait(ctx, containerID, container.WaitConditionNotRunning)
}
// StopContainer stops a running container. options is passed to the SDK's
// ContainerStop unchanged and controls how the stop is performed.
func (c *Client) StopContainer(
ctx context.Context,
containerID string,
options container.StopOptions,
) error {
return c.client.ContainerStop(ctx, containerID, options)
}
// RemoveContainer removes containerID. Removal is always forced (Force: true)
// and always removes the container's associated volumes (RemoveVolumes: true);
// neither is optional.
func (c *Client) RemoveContainer(ctx context.Context, containerID string) error {
return c.client.ContainerRemove(
ctx,
containerID,
container.RemoveOptions{RemoveVolumes: true, Force: true},
)
}
// removeContainers removes every container (running or not) that matches
// filterz, with one goroutine per container. All removal errors are combined
// with multierr; nil means every match was removed or there were none.
func (c *Client) removeContainers(ctx context.Context, filterz filters.Args) error {
	found, err := c.client.ContainerList(
		ctx,
		container.ListOptions{All: true, Filters: filterz},
	)
	if err != nil {
		return err
	}

	// buffered for every container, so no send ever blocks
	results := make(chan error, len(found))
	var wg sync.WaitGroup
	wg.Add(len(found))
	for _, ctr := range found {
		go func(id string) {
			defer wg.Done()
			results <- c.RemoveContainer(ctx, id)
		}(ctr.ID)
	}
	wg.Wait()
	close(results)

	var errs error
	for e := range results {
		errs = multierr.Append(errs, e)
	}
	return errs
}
// removeNetworks removes every network matching filterz, with one goroutine
// per network. All removal errors are combined with multierr; nil means every
// match was removed or there were none.
func (c *Client) removeNetworks(ctx context.Context, filterz filters.Args) error {
	found, err := c.client.NetworkList(ctx, network.ListOptions{Filters: filterz})
	if err != nil {
		return err
	}

	// buffered for every network, so no send ever blocks
	results := make(chan error, len(found))
	var wg sync.WaitGroup
	wg.Add(len(found))
	for _, nw := range found {
		go func(id string) {
			defer wg.Done()
			results <- c.client.NetworkRemove(ctx, id)
		}(nw.ID)
	}
	wg.Wait()
	close(results)

	var errs error
	for e := range results {
		errs = multierr.Append(errs, e)
	}
	return errs
}
// RemoveObjectsWithLabel removes all Docker containers, and then all networks,
// carrying the label label=value. Both passes always run; their errors are
// combined.
func (c *Client) RemoveObjectsWithLabel(ctx context.Context, label string, value string) error {
	byLabel := filters.NewArgs(filters.Arg("label", label+"="+value))
	return multierr.Combine(
		c.removeContainers(ctx, byLabel),
		c.removeNetworks(ctx, byLabel),
	)
}
// GetOutputStream returns the raw (still multiplexed) stdout+stderr log stream
// of containerID.
//
// since is passed to the SDK as the starting point for the logs. follow keeps
// the stream open for new output. The caller must close the returned reader.
func (c *Client) GetOutputStream(
	ctx context.Context,
	containerID string,
	since string,
	follow bool,
) (io.ReadCloser, error) {
	reader, err := c.client.ContainerLogs(ctx, containerID, container.LogsOptions{
		ShowStdout: true,
		ShowStderr: true,
		Follow:     follow,
		Since:      since,
	})
	if err != nil {
		return nil, errors.Wrap(err, "failed to get container logs")
	}
	return reader, nil
}
// FindContainer returns the ID of the first container (running or not) that
// carries label with exactly the given value.
//
// A container lacking the label never matches, even when value is "". A plain
// map lookup returns "" for missing keys, which used to match any unlabelled
// container.
func (c *Client) FindContainer(ctx context.Context, label string, value string) (string, error) {
	containers, err := c.client.ContainerList(ctx, container.ListOptions{All: true})
	if err != nil {
		return "", err
	}

	for _, cont := range containers {
		if v, ok := cont.Labels[label]; ok && v == value {
			return cont.ID, nil
		}
	}
	return "", fmt.Errorf("unable to find container for %s=%s", label, value)
}
// PullImage pulls imageName from its registry and returns the digest reported
// in the pull progress stream. The digest is "" if the stream reported none.
//
// The raw progress stream is echoed to os.Stdout while it is decoded. An error
// message in the stream aborts the pull with that message as the error.
func (c *Client) PullImage(ctx context.Context, imageName string) (string, error) {
	out, err := c.client.ImagePull(ctx, imageName, image.PullOptions{})
	if err != nil {
		zlog.Sugar().Errorf("unable to pull image: %v", err)
		return "", err
	}
	defer out.Close()

	d := json.NewDecoder(io.TeeReader(out, os.Stdout))
	var digest string
	for {
		// Decode into a fresh value each time. Decode leaves fields that are
		// absent from the current object untouched, so a reused message would
		// carry Aux or Status over from an earlier line. A stale Aux would then
		// make every later message be skipped, including errors.
		var message jsonmessage.JSONMessage
		if err := d.Decode(&message); err != nil {
			if err == io.EOF {
				break
			}
			zlog.Sugar().Errorf("unable to pull image: %v", err)
			return "", err
		}
		if message.Aux != nil {
			continue
		}
		if message.Error != nil {
			zlog.Sugar().Errorf("unable to pull image: %v", message.Error.Message)
			return "", errors.New(message.Error.Message)
		}
		if strings.HasPrefix(message.Status, "Digest") {
			digest = strings.TrimPrefix(message.Status, "Digest: ")
		}
	}
	return digest, nil
}
package docker
import (
"context"
"fmt"
"io"
"os"
"sync/atomic"
"time"
"gitlab.com/nunet/device-management-service/dms/hardware"
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/api/types/mount"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils"
)
const (
// Container labels set by the docker executor. These are presumably used to
// find its containers again (e.g. via FindContainer / RemoveObjectsWithLabel);
// confirm at the call sites.
labelExecutorName = "nunet-executor"
labelJobID = "nunet-jobID"
labelExecutionID = "nunet-executionID"
// Poll interval and overall timeout for checking a container's output stream;
// the code using them is outside this view.
outputStreamCheckTickTime = 100 * time.Millisecond
outputStreamCheckTimeout = 5 * time.Second
)
// Executor manages the lifecycle of Docker containers for execution requests.
type Executor struct {
// ID identifies this executor; it is stored under labelExecutorName on every
// container it creates and prefixes the job/execution label values.
ID string
handlers utils.SyncMap[string, *executionHandler] // Maps execution IDs to their handlers.
client *Client // Docker client for container management.
}
// NewExecutor builds a Docker-backed Executor identified by id.
//
// It fails if the Docker client cannot be constructed or if the client's
// installation check reports that Docker is unavailable.
func NewExecutor(ctx context.Context, id string) (*Executor, error) {
	dc, err := NewDockerClient()
	switch {
	case err != nil:
		return nil, err
	case !dc.IsInstalled(ctx):
		return nil, fmt.Errorf("docker is not installed")
	}
	executor := &Executor{ID: id, client: dc}
	return executor, nil
}
// Start begins the execution described by request.
//
// Because Start may be called again after a restart, it first looks for an
// existing container labelled for this job/execution and reuses it. If none is
// found and a handler is already registered for the execution ID, an error is
// returned (already started or already completed); otherwise a new container
// is created. A handler is then registered under the execution ID and its run
// loop is launched in a new goroutine, so Start returns without waiting for
// the container to finish.
func (e *Executor) Start(ctx context.Context, request *types.ExecutionRequest) error {
	zlog.Sugar().
		Infof("Starting execution for job %s, execution %s", request.JobID, request.ExecutionID)

	containerID, err := e.FindRunningContainer(ctx, request.JobID, request.ExecutionID)
	if err != nil {
		// No existing container: refuse to start the same execution twice,
		// otherwise create a fresh container.
		if handler, ok := e.handlers.Get(request.ExecutionID); ok {
			if handler.active() {
				return fmt.Errorf("execution is already started")
			}
			return fmt.Errorf("execution is already completed")
		}
		containerID, err = e.newDockerExecutionContainer(ctx, request)
		if err != nil {
			return fmt.Errorf("failed to create new container: %w", err)
		}
	}

	handler := &executionHandler{
		client: e.client,
		ID:     e.ID,
		// destroy rebuilds this execution's label value from ID/jobID/executionID;
		// without jobID it would target the wrong label and leak resources.
		jobID:       request.JobID,
		executionID: request.ExecutionID,
		containerID: containerID,
		resultsDir:  request.ResultsDir,
		waitCh:      make(chan bool),
		activeCh:    make(chan bool),
		running:     &atomic.Bool{},
		TTYEnabled:  true,
	}
	// register the handler for this executionID
	e.handlers.Put(request.ExecutionID, handler)
	// The run loop is bound to ctx: cancelling ctx ends the execution.
	go handler.run(ctx)
	return nil
}
// Wait returns a pair of channels through which the outcome of the execution
// identified by executionID is delivered: the first carries the result, the
// second any error. Both channels are buffered with capacity one.
//
// When no handler is registered for executionID, the "not found" error is
// placed on the error channel before Wait returns. Otherwise a doWait
// goroutine waits for the execution and reports on the channels. This lets a
// caller pair Start with Wait to block until the execution completes.
func (e *Executor) Wait(
	ctx context.Context,
	executionID string,
) (<-chan *types.ExecutionResult, <-chan error) {
	results := make(chan *types.ExecutionResult, 1)
	errs := make(chan error, 1)

	handler, ok := e.handlers.Get(executionID)
	if ok {
		go e.doWait(ctx, results, errs, handler)
	} else {
		errs <- fmt.Errorf("execution (%s) not found", executionID)
	}
	return results, errs
}
// doWait blocks until handler signals completion (waitCh closed) or ctx is
// done.
//
// On completion, handler.result is sent to out, or an error is sent to errCh
// if the result is nil (which would indicate an executor bug). On cancellation
// ctx.Err() is sent to errCh. Before returning, errCh is closed first and then
// out.
func (e *Executor) doWait(
	ctx context.Context,
	out chan *types.ExecutionResult,
	errCh chan error,
	handler *executionHandler,
) {
	zlog.Sugar().Infof("executionID %s waiting for execution", handler.executionID)
	defer func() {
		close(errCh)
		close(out)
	}()

	select {
	case <-handler.waitCh:
		res := handler.result
		if res == nil {
			errCh <- fmt.Errorf("execution (%s) result is nil", handler.executionID)
			return
		}
		zlog.Sugar().
			Infof("executionID %s received results from execution", handler.executionID)
		out <- res
	case <-ctx.Done():
		errCh <- ctx.Err() // Propagate the cancellation reason.
	}
}
// Cancel stops the container of the execution identified by executionID via
// its handler's kill method. An error is returned if no handler is registered
// for that ID or if stopping fails.
func (e *Executor) Cancel(ctx context.Context, executionID string) error {
	if handler, ok := e.handlers.Get(executionID); ok {
		return handler.kill(ctx)
	}
	return fmt.Errorf("failed to cancel execution (%s). execution not found", executionID)
}
// GetLogStream returns a stream of container output for request.ExecutionID.
// The request (including its Tail and Follow settings) is passed unchanged to
// the handler's outputStream.
//
// An execution may already be recorded as running before its handler is
// registered (Start registers it only after the container has been created).
// If the handler is not registered yet, the handler map is polled every
// outputStreamCheckTickTime for up to outputStreamCheckTimeout. An error is
// returned if no handler appears in that window or if ctx is done first.
func (e *Executor) GetLogStream(
	ctx context.Context,
	request types.LogStreamRequest,
) (io.ReadCloser, error) {
	// Fast path: no need to wait a tick when the handler is already registered.
	if handler, found := e.handlers.Get(request.ExecutionID); found {
		return handler.outputStream(ctx, request)
	}

	// Poll inline rather than in a helper goroutine that hands the handler back
	// over an unbuffered channel; that design deadlocks when the handler shows
	// up at the same moment the timeout fires.
	ticker := time.NewTicker(outputStreamCheckTickTime)
	defer ticker.Stop()
	timeout := time.NewTimer(outputStreamCheckTimeout)
	defer timeout.Stop()

	for {
		select {
		case <-ctx.Done():
			return nil, ctx.Err()
		case <-timeout.C:
			return nil, fmt.Errorf("execution (%s) not found", request.ExecutionID)
		case <-ticker.C:
			if handler, found := e.handlers.Get(request.ExecutionID); found {
				return handler.outputStream(ctx, request)
			}
		}
	}
}
// List returns one ExecutionListItem for every execution handler registered
// with this executor, reporting whether its container is currently running.
//
// Handlers live only in this executor's in-memory map. Containers left over
// from a previous process are therefore not reported.
// TODO: also list DMS containers discovered through the Docker API.
func (e *Executor) List() []types.ExecutionListItem {
	executions := make([]types.ExecutionListItem, 0)
	e.handlers.Range(func(key, value any) bool {
		executionID := key.(string)
		handler := value.(*executionHandler)
		executions = append(executions, types.ExecutionListItem{
			ExecutionID: executionID,
			Running:     handler.running.Load(),
		})
		return true
	})
	return executions
}
// Run starts the execution described by request and blocks until it
// finishes, returning its result.
//
// It is a convenience wrapper around Start followed by Wait. An error is
// returned if Start fails, if Wait reports an error, or if ctx is done first.
func (e *Executor) Run(
	ctx context.Context,
	request *types.ExecutionRequest,
) (*types.ExecutionResult, error) {
	if err := e.Start(ctx, request); err != nil {
		return nil, err
	}
	resCh, errCh := e.Wait(ctx, request.ExecutionID)

	// doWait sends on at most one of the two channels and then closes both
	// (errCh first, then resCh). A receive can therefore observe a closed, empty
	// channel while the real value sits in the other one. Check ok and fall back
	// to the other channel instead of returning (nil, nil).
	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	case out, ok := <-resCh:
		if ok {
			return out, nil
		}
		// resCh is closed after errCh, so this receive does not block.
		if err, ok := <-errCh; ok {
			return nil, err
		}
	case err, ok := <-errCh:
		if ok {
			return nil, err
		}
		// errCh closes first; resCh holds the result or is closed right after.
		if out, ok := <-resCh; ok {
			return out, nil
		}
	}
	return nil, fmt.Errorf("execution (%s) finished without a result", request.ExecutionID)
}
// Cleanup asks the Docker client to remove every object labelled with this
// executor's ID under labelExecutorName (containers and, per the client's
// RemoveObjectsWithLabel, related networks and volumes).
func (e *Executor) Cleanup(ctx context.Context) error {
	if err := e.client.RemoveObjectsWithLabel(ctx, labelExecutorName, e.ID); err != nil {
		return fmt.Errorf("failed to remove containers: %w", err)
	}
	zlog.Info("Cleaned up all Docker resources")
	return nil
}
// newDockerExecutionContainer is called by Start to create (but not start)
// the Docker container for an execution.
//
// It does the following, in order:
//   - decodes the docker engine spec from params.EngineSpec;
//   - picks a GPU vendor for the host config;
//   - builds the container config (image, TTY, environment, entrypoint,
//     command, labels, working directory);
//   - prepares the mounts;
//   - pulls the image;
//   - creates the container.
//
// It returns the ID of the created container, or an error if any step fails.
func (e *Executor) newDockerExecutionContainer(
	ctx context.Context,
	params *types.ExecutionRequest,
) (string, error) {
	dockerArgs, err := DecodeSpec(params.EngineSpec)
	if err != nil {
		return "", fmt.Errorf("failed to decode docker engine spec: %w", err)
	}

	// TODO: move the GPU selection below to the allocator in future.
	// TODO: use the hardware manager instantiated in the dms package.
	// Select the vendor of the GPU with the most free VRAM for the host config.
	hardwareManager := hardware.NewHardwareManager()
	machineResources, err := hardwareManager.GetMachineResources()
	if err != nil {
		return "", fmt.Errorf("failed to get machine resources: %w", err)
	}
	var chosenGPUVendor types.GPUVendor
	if len(machineResources.GPUs) == 0 {
		zlog.Info("no GPUs available on the machine")
		chosenGPUVendor = types.GPUVendorNone
	} else {
		// Picking by free VRAM matters on multi-vendor nodes: an 8 GB NVIDIA card
		// next to a 16 GB Intel GPU should use the latter first. It also matters on
		// single-GPU machines, where integrated GPUs are commonly detected too.
		maxFreeVRAMGpu, err := machineResources.GPUs.MaxFreeVRAMGPU()
		if err != nil {
			zlog.Sugar().Warnf("unable to select GPU by free VRAM, continuing without GPU: %v", err)
			chosenGPUVendor = types.GPUVendorNone
		} else {
			chosenGPUVendor = maxFreeVRAMGpu.Vendor
		}
	}

	containerConfig := container.Config{
		Image:      dockerArgs.Image,
		Tty:        true, // Needs to be true for applications such as Jupyter or Gradio to work correctly. See issue #459 for details.
		Env:        dockerArgs.Environment,
		Entrypoint: dockerArgs.Entrypoint,
		Cmd:        dockerArgs.Cmd,
		Labels:     e.containerLabels(params.JobID, params.ExecutionID),
		WorkingDir: dockerArgs.WorkingDirectory,
	}

	mounts, err := makeContainerMounts(params.Inputs, params.Outputs, params.ResultsDir)
	if err != nil {
		return "", fmt.Errorf("failed to create container mounts: %w", err)
	}

	zlog.Sugar().Infof("Adding %d GPUs to request", len(params.Resources.GPUs))
	hostConfig := configureHostConfig(chosenGPUVendor, params, mounts)

	if _, err = e.client.PullImage(ctx, dockerArgs.Image); err != nil {
		return "", fmt.Errorf("failed to pull docker image: %w", err)
	}

	// The final argument is the executor/job/execution label value (presumably
	// used by CreateContainer as the container name — confirm in Client).
	executionContainer, err := e.client.CreateContainer(
		ctx,
		&containerConfig,
		&hostConfig,
		nil,
		nil,
		labelExecutionValue(e.ID, params.JobID, params.ExecutionID),
	)
	if err != nil {
		return "", fmt.Errorf("failed to create container: %w", err)
	}
	return executionContainer, nil
}
// configureHostConfig builds the container HostConfig for an execution.
//
// Every configuration gets the given mounts and a CPU limit derived from
// params.Resources.CPU.Cores. GPU access is then added depending on vendor:
//   - NVIDIA: a device request for the GPU indices in params.Resources.GPUs;
//   - AMD/ATI: /dev/kfd and /dev/dri are bound and mapped as devices, and the
//     container joins the "video" group;
//   - Intel: the whole /dev/dri directory is bound and mapped, exposing every
//     Intel GPU on the host (simpler than resolving per-GPU device paths);
//   - anything else: CPU only.
func configureHostConfig(vendor types.GPUVendor, params *types.ExecutionRequest, mounts []mount.Mount) container.HostConfig {
	cores := params.Resources.CPU.Cores
	hostConfig := container.HostConfig{
		Mounts: mounts,
		Resources: container.Resources{
			// NanoCPUs is measured in units of 1e-9 CPUs, so the core count must be
			// scaled; passing the raw count would request a near-zero CPU quota.
			NanoCPUs: int64(float64(cores) * 1e9),
			CPUCount: int64(cores),
		},
	}

	switch vendor {
	case types.GPUVendorNvidia:
		deviceIDs := make([]string, len(params.Resources.GPUs))
		for i, gpu := range params.Resources.GPUs {
			deviceIDs[i] = fmt.Sprint(gpu.Index)
		}
		hostConfig.Resources.DeviceRequests = []container.DeviceRequest{
			{
				DeviceIDs:    deviceIDs,
				Capabilities: [][]string{{"gpu"}},
			},
		}
	case types.GPUVendorAMDATI:
		hostConfig.Binds = []string{
			"/dev/kfd:/dev/kfd",
			"/dev/dri:/dev/dri",
		}
		hostConfig.Resources.Devices = []container.DeviceMapping{
			{
				PathOnHost:        "/dev/kfd",
				PathInContainer:   "/dev/kfd",
				CgroupPermissions: "rwm",
			},
			{
				PathOnHost:        "/dev/dri",
				PathInContainer:   "/dev/dri",
				CgroupPermissions: "rwm",
			},
		}
		hostConfig.GroupAdd = []string{"video"}
	case types.GPUVendorIntel:
		hostConfig.Binds = []string{
			"/dev/dri:/dev/dri",
		}
		hostConfig.Resources.Devices = []container.DeviceMapping{
			{
				PathOnHost:        "/dev/dri",
				PathInContainer:   "/dev/dri",
				CgroupPermissions: "rwm",
			},
		}
	}
	return hostConfig
}
// makeContainerMounts converts the execution's input and output volumes into
// Docker bind mounts.
//
// Inputs must be of type types.StorageVolumeTypeBind; they are mounted with
// their own ReadOnly flag. Outputs must have a non-empty Source and are always
// mounted writable. They also require a non-empty resultsDir, which is created
// (with os.ModePerm) if it does not exist.
// An error is returned for any violation or if resultsDir cannot be created.
func makeContainerMounts(
	inputs []*types.StorageVolumeExecutor,
	outputs []*types.StorageVolumeExecutor,
	resultsDir string,
) ([]mount.Mount, error) {
	// the actual mounts we will give to the container
	// these are paths for both input and output data
	mounts := make([]mount.Mount, 0, len(inputs)+len(outputs))
	for _, input := range inputs {
		// Only bind volumes can be expressed as a Docker bind mount; reject the rest.
		if input.Type != types.StorageVolumeTypeBind {
			return nil, fmt.Errorf("unsupported storage volume type: %s", input.Type)
		}
		mounts = append(mounts, mount.Mount{
			Type:     mount.TypeBind,
			Source:   input.Source,
			Target:   input.Target,
			ReadOnly: input.ReadOnly,
		})
	}
	for _, output := range outputs {
		if output.Source == "" {
			return nil, fmt.Errorf("output source is empty")
		}
		if resultsDir == "" {
			return nil, fmt.Errorf("results directory is empty")
		}
		if err := os.MkdirAll(resultsDir, os.ModePerm); err != nil {
			return nil, fmt.Errorf("failed to create results directory: %w", err)
		}
		mounts = append(mounts, mount.Mount{
			Type:   mount.TypeBind,
			Source: output.Source,
			Target: output.Target,
			// this is an output volume so can be written to
			ReadOnly: false,
		})
	}
	return mounts, nil
}
// containerLabels builds the label set stamped on every container this
// executor creates: the executor ID, the job label value and the execution
// label value.
func (e *Executor) containerLabels(jobID string, executionID string) map[string]string {
	labels := make(map[string]string, 3)
	labels[labelExecutorName] = e.ID
	labels[labelJobID] = labelJobValue(e.ID, jobID)
	labels[labelExecutionID] = labelExecutionValue(e.ID, jobID, executionID)
	return labels
}
// labelJobValue joins the executor and job IDs with an underscore to form the
// value stored under the job label.
func labelJobValue(executorID string, jobID string) string {
	return executorID + "_" + jobID
}

// labelExecutionValue extends the job label value with the execution ID
// ("<executor>_<job>_<execution>") to form the execution label value.
func labelExecutionValue(executorID string, jobID string, executionID string) string {
	return fmt.Sprintf("%s_%s", labelJobValue(executorID, jobID), executionID)
}
// FindRunningContainer looks up the container labelled for the given job and
// execution and returns its ID. The lookup is delegated to the client's
// FindContainer, which lists stopped containers too. An error is returned when
// no container carries the expected execution label.
func (e *Executor) FindRunningContainer(
	ctx context.Context,
	jobID string,
	executionID string,
) (string, error) {
	return e.client.FindContainer(
		ctx,
		labelExecutionID,
		labelExecutionValue(e.ID, jobID, executionID),
	)
}
package docker
import (
"bufio"
"context"
"errors"
"fmt"
"io"
"strconv"
"sync/atomic"
"time"
"github.com/docker/docker/api/types/container"
"gitlab.com/nunet/device-management-service/types"
)
// DestroyTimeout is passed by run to destroy as the deadline for stopping and
// removing an execution's container and its labelled objects; kill also derives
// the container stop timeout from it.
var DestroyTimeout = time.Second * 10
// executionHandler manages the lifecycle and execution of a Docker container for a specific job.
// One handler is created per execution by Executor.Start and driven by run.
type executionHandler struct {
// provided by the executor
ID string // Executor ID; first component of the execution label value used by destroy.
client *Client // Docker client for container management.
// meta data about the task
jobID string // Second component of the execution label value used by destroy.
executionID string
containerID string
resultsDir string // Directory to store execution results.
// synchronization
activeCh chan bool // Closed by run once the container has started; outputStream waits on it.
waitCh chan bool // Closed by run when it returns; doWait then reads result.
running *atomic.Bool // Set while run is in progress (see active).
// result of the execution; written by run before waitCh is closed.
result *types.ExecutionResult
// TTY setting
TTYEnabled bool // When true, run reads combined output from the stdout pipe only.
}
// active reports whether run is currently in progress for this handler: the
// running flag is set when run begins and cleared when it returns.
func (h *executionHandler) active() bool {
	isRunning := h.running.Load()
	return isRunning
}
// run drives the container through its whole lifecycle and records the
// outcome in h.result.
//
// It proceeds in this order:
//   - starts the container and closes activeCh once it is running;
//   - waits for the container to exit, for WaitContainer to report an error,
//     or for ctx to be done;
//   - inspects the container for an OOM kill;
//   - collects its output via FollowLogs.
//
// On return it always destroys the container, clears the running flag and
// closes waitCh so that waiters can read h.result. If the container never
// started, activeCh is closed during teardown so that log-stream callers
// blocked on it are released instead of hanging.
func (h *executionHandler) run(ctx context.Context) {
	h.running.Store(true)
	started := false
	defer func() {
		if err := h.destroy(DestroyTimeout); err != nil {
			zlog.Sugar().Warnf("failed to destroy container: %v\n", err)
		}
		if !started {
			// activeCh was never closed below; release outputStream waiters now.
			close(h.activeCh)
		}
		h.running.Store(false)
		close(h.waitCh)
	}()

	if err := h.client.StartContainer(ctx, h.containerID); err != nil {
		h.result = types.NewFailedExecutionResult(fmt.Errorf("failed to start container: %v", err))
		return
	}
	close(h.activeCh) // Indicate that the container has started.
	started = true

	var containerError error
	var containerExitStatusCode int64

	// Wait for the container to finish or for an execution error.
	statusCh, errCh := h.client.WaitContainer(ctx, h.containerID)
	select {
	case <-ctx.Done():
		// The value received from Done carries no information; report ctx.Err().
		h.result = types.NewFailedExecutionResult(fmt.Errorf("execution cancelled: %v", ctx.Err()))
		return
	case err := <-errCh:
		zlog.Sugar().Errorf("error while waiting for container: %v\n", err)
		h.result = types.NewFailedExecutionResult(
			fmt.Errorf("failed to wait for container: %v", err),
		)
		return
	case exitStatus := <-statusCh:
		containerExitStatusCode = exitStatus.StatusCode
		containerJSON, err := h.client.InspectContainer(ctx, h.containerID)
		if err != nil {
			h.result = &types.ExecutionResult{
				ExitCode: int(containerExitStatusCode),
				ErrorMsg: err.Error(),
			}
			return
		}
		if containerJSON.ContainerJSONBase.State.OOMKilled {
			containerError = errors.New("container was killed due to OOM")
			h.result = &types.ExecutionResult{
				ExitCode: int(containerExitStatusCode),
				ErrorMsg: containerError.Error(),
			}
			return
		}
		if exitStatus.Error != nil {
			containerError = errors.New(exitStatus.Error.Message)
		}
	}

	// Follow container logs to capture stdout and stderr.
	stdoutPipe, stderrPipe, logsErr := h.client.FollowLogs(ctx, h.containerID)
	if logsErr != nil {
		followError := fmt.Errorf("failed to follow container logs: %w", logsErr)
		if containerError != nil {
			h.result = &types.ExecutionResult{
				ExitCode: int(containerExitStatusCode),
				ErrorMsg: fmt.Sprintf(
					"container error: '%s'. logs error: '%s'",
					containerError,
					followError,
				),
			}
		} else {
			h.result = &types.ExecutionResult{
				ExitCode: int(containerExitStatusCode),
				ErrorMsg: followError.Error(),
			}
		}
		return
	}

	// Initialize the result with the exit status code.
	h.result = types.NewExecutionResult(int(containerExitStatusCode))

	// Capture the logs based on the TTY setting.
	// NOTE(review): ReadString('\x00') stops at the first NUL byte, so output that
	// contains NUL is truncated; io.ReadAll would capture the full stream.
	if h.TTYEnabled {
		// TTY combines stdout and stderr, read from stdoutPipe only.
		h.result.STDOUT, _ = bufio.NewReader(stdoutPipe).ReadString('\x00') // EOF delimiter
	} else {
		// Read from stdout and stderr separately.
		h.result.STDOUT, _ = bufio.NewReader(stdoutPipe).ReadString('\x00') // EOF delimiter
		h.result.STDERR, _ = bufio.NewReader(stderrPipe).ReadString('\x00')
	}
}
// kill asks Docker to stop the container. Docker is given DestroyTimeout,
// converted to whole seconds, as the stop timeout (per the StopOptions
// documentation, the time allowed before SIGKILL).
//
// NOTE(review): destroy runs kill under a context that also expires after
// DestroyTimeout, so a container that ignores the stop signal may hit the
// context deadline at about the same moment Docker kills it.
func (h *executionHandler) kill(ctx context.Context) error {
	// StopOptions.Timeout is a number of seconds, not a time.Duration:
	// int(DestroyTimeout) would be the nanosecond count (10s -> 10,000,000,000).
	timeout := int(DestroyTimeout / time.Second)
	stopOptions := container.StopOptions{
		Timeout: &timeout,
	}
	return h.client.StopContainer(ctx, h.containerID, stopOptions)
}
// destroy tears down the execution's container within timeout.
//
// Under a fresh context bounded by timeout, it stops the container (kill),
// removes it, and then removes any objects carrying this execution's label,
// such as networks or volumes. The first failure is returned.
func (h *executionHandler) destroy(timeout time.Duration) error {
	cleanupCtx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	if err := h.kill(cleanupCtx); err != nil {
		return fmt.Errorf("failed to kill container (%s): %w", h.containerID, err)
	}
	if err := h.client.RemoveContainer(cleanupCtx, h.containerID); err != nil {
		return err
	}
	executionLabel := labelExecutionValue(h.ID, h.jobID, h.executionID)
	return h.client.RemoveObjectsWithLabel(cleanupCtx, labelExecutionID, executionLabel)
}
// outputStream waits until the container has started (activeCh closed) and
// then returns the client's output stream for it.
//
// With request.Tail the stream starts at the Unix time captured when this
// method was called; otherwise it starts at "1" (effectively the beginning).
// request.Follow is passed through. ctx.Err() is returned if ctx is done
// before the container starts.
func (h *executionHandler) outputStream(
	ctx context.Context,
	request types.LogStreamRequest,
) (io.ReadCloser, error) {
	var since string
	if request.Tail {
		since = strconv.FormatInt(time.Now().Unix(), 10)
	} else {
		since = "1" // Start of UNIX time: include every log line.
	}

	select {
	case <-h.activeCh:
		// Container is running; logs can be requested.
	case <-ctx.Done():
		return nil, ctx.Err()
	}
	return h.client.GetOutputStream(ctx, h.containerID, since, request.Follow)
}
package docker
import (
"gitlab.com/nunet/device-management-service/telemetry/logger"
)
// zlog is the package logger for the docker executor, created once with the
// name "docker.executor".
var zlog = logger.New("docker.executor")
package docker
import (
"encoding/json"
"fmt"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils/validate"
)
// Parameter keys EngineBuilder passes to SpecConfig.WithParam. They match the
// JSON tags of EngineSpec, which DecodeSpec relies on when it re-decodes
// SpecConfig.Params into an EngineSpec.
const (
EngineKeyImage = "image"
EngineKeyEntrypoint = "entrypoint"
EngineKeyCmd = "cmd"
EngineKeyEnvironment = "environment"
EngineKeyWorkingDirectory = "working_directory"
)
// EngineSpec contains necessary parameters to execute a docker job.
// Its JSON field names mirror the EngineKey* constants.
type EngineSpec struct {
// Image is the image reference to run; it must be pullable by docker and is
// the only field Validate requires.
Image string `json:"image,omitempty"`
// Entrypoint optionally overrides the image's default entrypoint.
Entrypoint []string `json:"entrypoint,omitempty"`
// Cmd specifies the command to run in the container.
Cmd []string `json:"cmd,omitempty"`
// Environment lists environment variables for the container; it is passed to
// container.Config.Env (presumably KEY=VALUE entries — confirm).
Environment []string `json:"environment,omitempty"`
// WorkingDirectory is the working directory inside the container.
WorkingDirectory string `json:"working_directory,omitempty"`
}
// Validate reports an error when the spec has no usable image, i.e. when
// Image is blank according to validate.IsBlank; other fields are optional.
func (c EngineSpec) Validate() error {
	if !validate.IsBlank(c.Image) {
		return nil
	}
	return fmt.Errorf("invalid docker engine params: image cannot be empty")
}
// DecodeSpec converts a docker-type SpecConfig into an EngineSpec.
//
// spec must be non-nil, of type types.ExecutorTypeDocker, and carry non-nil
// Params. The params are round-tripped through JSON into an EngineSpec, which
// is then validated. As before, the decoded spec is returned together with
// any validation error.
func DecodeSpec(spec *types.SpecConfig) (EngineSpec, error) {
	if spec == nil {
		return EngineSpec{}, fmt.Errorf("invalid docker engine spec: spec cannot be nil")
	}
	if !spec.IsType(string(types.ExecutorTypeDocker)) {
		return EngineSpec{}, fmt.Errorf(
			"invalid docker engine type. expected %s, but received: %s",
			types.ExecutorTypeDocker,
			spec.Type,
		)
	}

	inputParams := spec.Params
	if inputParams == nil {
		return EngineSpec{}, fmt.Errorf("invalid docker engine params: params cannot be nil")
	}

	paramBytes, err := json.Marshal(inputParams)
	if err != nil {
		return EngineSpec{}, fmt.Errorf("failed to encode docker engine params: %w", err)
	}

	// Decode into a value rather than a pointer: a JSON "null" would leave a
	// *EngineSpec nil and panic on dereference, whereas a zero EngineSpec is
	// simply rejected by Validate.
	var dockerSpec EngineSpec
	if err := json.Unmarshal(paramBytes, &dockerSpec); err != nil {
		return EngineSpec{}, fmt.Errorf("failed to decode docker engine params: %w", err)
	}
	return dockerSpec, dockerSpec.Validate()
}
// EngineBuilder incrementally assembles a docker-type *types.SpecConfig using
// the builder pattern. Each With* method records one engine parameter and
// returns the builder so calls can be chained; Build hands back the result.
type EngineBuilder struct {
	eb *types.SpecConfig
}

// NewDockerEngineBuilder returns a builder for a SpecConfig of type
// types.ExecutorTypeDocker whose image parameter is already set to image.
func NewDockerEngineBuilder(image string) *EngineBuilder {
	builder := &EngineBuilder{eb: types.NewSpecConfig(string(types.ExecutorTypeDocker))}
	builder.eb.WithParam(EngineKeyImage, image)
	return builder
}

// WithEntrypoint records the entrypoint override and returns b for chaining.
func (b *EngineBuilder) WithEntrypoint(entrypoint ...string) *EngineBuilder {
	b.eb.WithParam(EngineKeyEntrypoint, entrypoint)
	return b
}

// WithCmd records the container command and returns b for chaining.
func (b *EngineBuilder) WithCmd(command ...string) *EngineBuilder {
	b.eb.WithParam(EngineKeyCmd, command)
	return b
}

// WithEnvironment records the container environment variables and returns b
// for chaining.
func (b *EngineBuilder) WithEnvironment(vars ...string) *EngineBuilder {
	b.eb.WithParam(EngineKeyEnvironment, vars)
	return b
}

// WithWorkingDirectory records the working directory inside the container and
// returns b for chaining.
func (b *EngineBuilder) WithWorkingDirectory(dir string) *EngineBuilder {
	b.eb.WithParam(EngineKeyWorkingDirectory, dir)
	return b
}

// Build returns the builder's own SpecConfig pointer (not a copy).
func (b *EngineBuilder) Build() *types.SpecConfig {
	spec := b.eb
	return spec
}
//go:build linux
// +build linux
package firecracker
import (
"context"
"fmt"
"os"
"syscall"
"time"
"github.com/firecracker-microvm/firecracker-go-sdk"
)
// pidCheckTickTime is how often DestroyVM checks whether the Firecracker
// process has exited.
const pidCheckTickTime = time.Millisecond * 100

// Client wraps the Firecracker SDK to provide high-level operations on
// Firecracker VMs. It has no fields.
type Client struct{}

// NewFirecrackerClient returns a new Client. The error result is always nil
// in the current implementation.
func NewFirecrackerClient() (*Client, error) {
	c := new(Client)
	return c, nil
}
// IsInstalled reports whether the Firecracker binary can be run on the host.
//
// It runs the command produced by a default VMCommandBuilder with the
// argument "--version", bounded by a 2-second timeout. It reports true only
// when the command exits successfully and prints non-empty output.
func (c *Client) IsInstalled(ctx context.Context) bool {
	probeCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
	defer cancel()

	cmd := firecracker.VMCommandBuilder{}.WithArgs([]string{"--version"}).Build(probeCtx)
	out, err := cmd.Output()
	if err != nil {
		return false
	}
	if !cmd.ProcessState.Success() {
		return false
	}
	return len(out) > 0
}
// CreateVM builds (but does not start) a Firecracker machine from cfg.
// The machine's process runner is a Firecracker command bound to ctx that
// listens on cfg.SocketPath.
func (c *Client) CreateVM(
	ctx context.Context,
	cfg firecracker.Config,
) (*firecracker.Machine, error) {
	runner := firecracker.VMCommandBuilder{}.
		WithSocketPath(cfg.SocketPath).
		Build(ctx)
	return firecracker.NewMachine(ctx, cfg, firecracker.WithProcessRunner(runner))
}
// StartVM starts machine m, returning the SDK's error unchanged.
func (c *Client) StartVM(ctx context.Context, m *firecracker.Machine) error {
	if err := m.Start(ctx); err != nil {
		return err
	}
	return nil
}

// ShutdownVM asks machine m to shut down, returning the SDK's error unchanged.
func (c *Client) ShutdownVM(ctx context.Context, m *firecracker.Machine) error {
	if err := m.Shutdown(ctx); err != nil {
		return err
	}
	return nil
}
// DestroyVM shuts the VM down and makes sure its Firecracker process is gone.
//
// It first requests a shutdown via ShutdownVM. If that request fails, the
// error is returned and nothing else is done. Otherwise the socket file at
// m.Cfg.SocketPath is removed on return. The process PID is polled every
// pidCheckTickTime; if the process is still alive after timeout it is sent
// SIGKILL. A process that has already disappeared by then (ESRCH) counts as
// success.
func (c *Client) DestroyVM(
	ctx context.Context,
	m *firecracker.Machine,
	timeout time.Duration,
) error {
	if err := c.ShutdownVM(ctx, m); err != nil {
		return fmt.Errorf("failed to shutdown vm: %w", err)
	}
	defer os.Remove(m.Cfg.SocketPath)

	// A non-positive PID is treated as "process not running".
	pid, _ := m.PID()
	if pid <= 0 {
		return nil
	}

	// Poll inline; no helper goroutine is needed to wait for exit or timeout.
	ticker := time.NewTicker(pidCheckTickTime)
	defer ticker.Stop()
	deadline := time.NewTimer(timeout)
	defer deadline.Stop()

	for {
		select {
		case <-ticker.C:
			if current, _ := m.PID(); current <= 0 {
				return nil
			}
		case <-deadline.C:
			// The shutdown request timed out: kill the process with SIGKILL. ESRCH
			// means it exited in the meantime, which is the outcome we wanted.
			if err := syscall.Kill(pid, syscall.SIGKILL); err != nil && err != syscall.ESRCH {
				return fmt.Errorf("failed to kill process: %w", err)
			}
			return nil
		}
	}
}
// FindVM returns a machine handle for the VM listening on socketPath, provided
// that VM reports the state "Running".
//
// It fails if the socket file does not exist, if the machine handle cannot be
// created, if the instance info request fails, if the reported state is
// missing, or if the state is not "Running".
func (c *Client) FindVM(ctx context.Context, socketPath string) (*firecracker.Machine, error) {
	// Check if the socket file exists.
	if _, err := os.Stat(socketPath); err != nil {
		return nil, fmt.Errorf("VM with socket path %v not found", socketPath)
	}

	// Create a machine handle bound to the existing socket.
	cmd := firecracker.VMCommandBuilder{}.WithSocketPath(socketPath).Build(ctx)
	machine, err := firecracker.NewMachine(
		ctx,
		firecracker.Config{SocketPath: socketPath},
		firecracker.WithProcessRunner(cmd),
	)
	if err != nil {
		return nil, fmt.Errorf("failed to create machine with socket %s: %w", socketPath, err)
	}

	// Check if the VM is running by getting its instance info.
	info, err := machine.DescribeInstanceInfo(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to get instance info for socket %s: %w", socketPath, err)
	}
	// State is a pointer; guard against a response that omits it.
	if info.State == nil {
		return nil, fmt.Errorf("VM with socket %s reported no state", socketPath)
	}
	if state := *info.State; state != "Running" {
		return nil, fmt.Errorf(
			"VM with socket %s is not running, current state: %s",
			socketPath,
			state,
		)
	}
	return machine, nil
}
//go:build linux
// +build linux
package firecracker
import (
"context"
"errors"
"fmt"
"io"
"os"
"sync"
"sync/atomic"
"time"
"github.com/firecracker-microvm/firecracker-go-sdk"
fcModels "github.com/firecracker-microvm/firecracker-go-sdk/client/models"
"go.uber.org/multierr"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils"
)
const (
// socketDir is not referenced in this chunk; presumably the directory for
// Firecracker API socket files — confirm against newFirecrackerExecutionVM.
socketDir = "/tmp"
)
// ErrNotInstalled is returned by NewExecutor when the client's IsInstalled
// check fails.
var ErrNotInstalled = errors.New("firecracker is not installed")
// Executor manages the lifecycle of Firecracker VMs for execution requests.
type Executor struct {
// ID identifies this executor instance; Start copies it into each execution handler.
ID string
handlers utils.SyncMap[string, *executionHandler] // Maps execution IDs to their handlers.
client *Client // Firecracker client for VM management.
}
// NewExecutor builds a Firecracker-backed Executor identified by id.
//
// It returns ErrNotInstalled when the Firecracker binary check fails, and
// propagates any error from creating the client.
func NewExecutor(ctx context.Context, id string) (*Executor, error) {
	fc, err := NewFirecrackerClient()
	if err != nil {
		return nil, err
	}
	if installed := fc.IsInstalled(ctx); !installed {
		return nil, ErrNotInstalled
	}
	return &Executor{ID: id, client: fc}, nil
}
// List returns one ExecutionListItem per registered execution handler,
// reporting whether each one's run loop is currently active. The result is
// never nil.
func (e *Executor) List() []types.ExecutionListItem {
	items := make([]types.ExecutionListItem, 0)
	e.handlers.Range(func(k, v any) bool {
		id := k.(string)
		h := v.(*executionHandler)
		items = append(items, types.ExecutionListItem{ExecutionID: id, Running: h.running.Load()})
		return true
	})
	return items
}
// Start begins the execution of request in a Firecracker VM.
//
// Because Start may be called again after a restart, it first looks for a
// running VM for this job/execution and reuses it. Otherwise, if a handler is
// already registered for the execution ID an error is returned (already
// started or already completed); failing that a new VM is created. A handler
// is then registered and its run loop launched in a new goroutine.
func (e *Executor) Start(ctx context.Context, request *types.ExecutionRequest) error {
	zlog.Sugar().
		Infof("Starting execution for job %s, execution %s", request.JobID, request.ExecutionID)

	machine, err := e.FindRunningVM(ctx, request.JobID, request.ExecutionID)
	if err != nil {
		if existing, ok := e.handlers.Get(request.ExecutionID); ok {
			if !existing.active() {
				return fmt.Errorf("execution is already completed")
			}
			return fmt.Errorf("execution is already started")
		}
		if machine, err = e.newFirecrackerExecutionVM(ctx, request); err != nil {
			return fmt.Errorf("failed to create new firecracker VM: %w", err)
		}
	}

	handler := &executionHandler{
		client:      e.client,
		ID:          e.ID,
		executionID: request.ExecutionID,
		machine:     machine,
		resultsDir:  request.ResultsDir,
		waitCh:      make(chan bool),
		activeCh:    make(chan bool),
		running:     &atomic.Bool{},
	}
	e.handlers.Put(request.ExecutionID, handler)
	go handler.run(ctx) // The VM's run loop is bound to ctx.
	return nil
}
// Wait returns a pair of channels through which the outcome of the execution
// identified by executionID is delivered: the first carries the result, the
// second any error. Both channels are buffered with capacity one.
//
// When no handler is registered for executionID, the "not found" error is
// placed on the error channel before Wait returns. Otherwise a doWait
// goroutine waits for the execution and reports on the channels. This lets a
// caller pair Start with Wait to block until the execution completes.
func (e *Executor) Wait(
	ctx context.Context,
	executionID string,
) (<-chan *types.ExecutionResult, <-chan error) {
	results := make(chan *types.ExecutionResult, 1)
	errs := make(chan error, 1)

	handler, ok := e.handlers.Get(executionID)
	if ok {
		go e.doWait(ctx, results, errs, handler)
	} else {
		errs <- fmt.Errorf("execution (%s) not found", executionID)
	}
	return results, errs
}
// doWait blocks until handler signals completion (its waitCh is closed) or ctx
// is done, whichever comes first.
//
// On ctx cancellation it sends ctx.Err() to errCh. On completion it sends
// handler.result to out, or an error to errCh if the result is nil (which
// would indicate a bug in the executor). Both channels are closed before
// returning, errCh first and then out.
func (e *Executor) doWait(
	ctx context.Context,
	out chan *types.ExecutionResult,
	errCh chan error,
	handler *executionHandler,
) {
	zlog.Sugar().Infof("executionID %s waiting for execution", handler.executionID)
	defer close(out)
	defer close(errCh)

	select {
	case <-handler.waitCh:
		res := handler.result
		if res == nil {
			errCh <- fmt.Errorf("execution (%s) result is nil", handler.executionID)
			return
		}
		zlog.Sugar().
			Infof("executionID %s received results from execution", handler.executionID)
		out <- res
	case <-ctx.Done():
		// Relay the cancellation cause to the caller.
		errCh <- ctx.Err()
	}
}
// Cancel asks the VM of the given execution to shut down via handler.kill.
// It returns an error when no handler is registered for executionID, or
// whatever error the shutdown itself reports.
func (e *Executor) Cancel(ctx context.Context, executionID string) error {
	if handler, ok := e.handlers.Get(executionID); ok {
		return handler.kill(ctx)
	}
	return fmt.Errorf("failed to cancel execution (%s). execution not found", executionID)
}
// Run starts an execution and blocks until it produces a result, fails, or ctx
// is cancelled. It is a convenience wrapper around Start followed by Wait.
//
// It returns the execution result, the error from Start or Wait, or
// ctx.Err() on cancellation.
func (e *Executor) Run(
	ctx context.Context,
	request *types.ExecutionRequest,
) (*types.ExecutionResult, error) {
	if err := e.Start(ctx, request); err != nil {
		return nil, err
	}

	resCh, errCh := e.Wait(ctx, request.ExecutionID)

	// doWait closes both channels after delivering a value. A plain select
	// could then pick the closed (zero-valued) channel and report (nil, nil),
	// losing either the result or the error. Use comma-ok receives and stop
	// selecting on a channel once it is closed; buffered values are still
	// received with ok == true before the close is observed.
	for resCh != nil || errCh != nil {
		select {
		case <-ctx.Done():
			return nil, ctx.Err()
		case out, ok := <-resCh:
			if !ok {
				resCh = nil
				continue
			}
			return out, nil
		case err, ok := <-errCh:
			if !ok {
				errCh = nil
				continue
			}
			return nil, err
		}
	}
	return nil, fmt.Errorf("execution (%s) finished without a result", request.ExecutionID)
}
// GetLogStream exists only to satisfy the Executor interface. Log streaming is
// not supported for Firecracker, so every call returns a nil stream and an error.
func (e *Executor) GetLogStream(_ context.Context, _ types.LogStreamRequest) (io.ReadCloser, error) {
	unsupported := fmt.Errorf("GetLogStream is not implemented for Firecracker")
	return nil, unsupported
}
// Cleanup destroys the VM of every registered handler concurrently, giving
// each destroy a 10 second timeout. All failures are combined with
// multierr.Append and returned; nil means every destroy succeeded.
func (e *Executor) Cleanup() error {
	var pending sync.WaitGroup
	results := make(chan error, len(e.handlers.Keys()))

	e.handlers.Iter(func(_ string, h *executionHandler) bool {
		pending.Add(1)
		// h is a fresh callback parameter on every call, so capturing it is safe.
		go func() {
			defer pending.Done()
			results <- h.destroy(10 * time.Second)
		}()
		return true
	})

	// Close the results channel once every destroy goroutine has reported, so
	// the collection loop below terminates.
	go func() {
		pending.Wait()
		close(results)
	}()

	var combined error
	for err := range results {
		combined = multierr.Append(combined, err)
	}
	zlog.Info("Cleaned up all firecracker resources")
	return combined
}
// newFirecrackerExecutionVM builds the Firecracker configuration for an
// ExecutionRequest and asks the client to create (but not start) the VM.
//
// It decodes the engine spec, builds the drive list from the root file system
// and the request's inputs and outputs, and sizes the machine from
// params.Resources. It returns the created machine, or an error naming the
// step that failed.
func (e *Executor) newFirecrackerExecutionVM(
	ctx context.Context,
	params *types.ExecutionRequest,
) (*firecracker.Machine, error) {
	spec, err := DecodeSpec(params.EngineSpec)
	if err != nil {
		return nil, fmt.Errorf("failed to decode firecracker engine spec: %w", err)
	}

	drives, err := makeVMMounts(spec.RootFileSystem, params.Inputs, params.Outputs, params.ResultsDir)
	if err != nil {
		return nil, fmt.Errorf("failed to create VM mounts: %w", err)
	}

	vm, err := e.client.CreateVM(ctx, firecracker.Config{
		VMID:            params.ExecutionID,
		SocketPath:      e.generateSocketPath(params.JobID, params.ExecutionID),
		KernelImagePath: spec.KernelImage,
		InitrdPath:      spec.Initrd,
		KernelArgs:      spec.KernelArgs,
		Drives:          drives,
		MachineCfg: fcModels.MachineConfiguration{
			VcpuCount: firecracker.Int64(int64(params.Resources.CPU.Cores)),
			// NOTE(review): MemSizeMib is in MiB; this assumes
			// Resources.RAM.Size is already expressed in MiB — TODO confirm.
			MemSizeMib: firecracker.Int64(int64(params.Resources.RAM.Size)),
		},
	})
	if err != nil {
		return nil, fmt.Errorf("failed to create VM: %w", err)
	}
	// TODO: pass spec.MMDSMessage to the VM (e.client.VMPassMMDs) once supported.
	return vm, nil
}
// makeVMMounts builds the VM drive list. rootFileSystem is the root drive;
// each input adds a drive with its own ReadOnly flag, and each output adds a
// read-write drive.
//
// When there is at least one output, resultsDir must be non-empty and is
// created (with parents) if missing. It returns an error if an output has an
// empty Source, if resultsDir is empty while outputs exist, or if resultsDir
// cannot be created.
func makeVMMounts(
	rootFileSystem string,
	inputs []*types.StorageVolumeExecutor,
	outputs []*types.StorageVolumeExecutor,
	resultsDir string,
) ([]fcModels.Drive, error) {
	// firecracker.DrivesBuilder uses value receivers: AddDrive returns an
	// updated copy and leaves the receiver untouched. The result must be
	// assigned back, otherwise every added drive is silently discarded.
	drivesBuilder := firecracker.NewDrivesBuilder(rootFileSystem)

	for _, input := range inputs {
		drivesBuilder = drivesBuilder.AddDrive(input.Source, input.ReadOnly)
	}

	for _, output := range outputs {
		if output.Source == "" {
			return nil, fmt.Errorf("output source is empty")
		}
		if resultsDir == "" {
			return nil, fmt.Errorf("results directory is empty")
		}
		if err := os.MkdirAll(resultsDir, os.ModePerm); err != nil {
			return nil, fmt.Errorf("failed to create results directory: %w", err)
		}
		drivesBuilder = drivesBuilder.AddDrive(output.Source, false)
	}

	return drivesBuilder.Build(), nil
}
// FindRunningVM looks up the VM for the given job and execution through the
// client, using the socket path derived by generateSocketPath. It returns the
// Machine instance if found, or the client's error otherwise.
func (e *Executor) FindRunningVM(
	ctx context.Context,
	jobID string,
	executionID string,
) (*firecracker.Machine, error) {
	socketPath := e.generateSocketPath(jobID, executionID)
	return e.client.FindVM(ctx, socketPath)
}
// generateSocketPath returns "<socketDir>/<executorID>_<jobID>_<executionID>.sock",
// so each execution of each executor gets its own API socket.
func (e *Executor) generateSocketPath(jobID string, executionID string) string {
	fileName := fmt.Sprintf("%s_%s_%s.sock", e.ID, jobID, executionID)
	return fmt.Sprintf("%s/%s", socketDir, fileName)
}
//go:build linux
// +build linux
package firecracker
import (
"context"
"fmt"
"sync/atomic"
"time"
"github.com/firecracker-microvm/firecracker-go-sdk"
"gitlab.com/nunet/device-management-service/types"
)
// executionHandler holds the state needed to drive a single Firecracker VM
// through its lifecycle (start, wait, shutdown, destroy) and to record the
// outcome of that execution.
type executionHandler struct {
	// provided by the executor
	ID     string
	client *Client

	// metadata about the task
	JobID       string
	executionID string
	machine     *firecracker.Machine
	resultsDir  string

	// synchronization
	activeCh chan bool    // Closed once the VM has started.
	waitCh   chan bool    // Closed once the execution has completed or failed.
	running  *atomic.Bool // Indicates if the VM is currently running.

	// result of the execution; written by run before waitCh is closed.
	result *types.ExecutionResult
}

// active reports whether the handler's run loop is currently in progress.
func (h *executionHandler) active() bool {
	return h.running.Load()
}

// run starts the VM, waits for it to exit and stores the outcome in h.result.
//
// Regardless of the outcome, the VM is destroyed (10 second timeout), the
// running flag is cleared, and waitCh is closed. activeCh is closed only if
// the VM started successfully. Errors are wrapped with %w so callers of the
// result can still inspect the underlying cause.
func (h *executionHandler) run(ctx context.Context) {
	h.running.Store(true)
	defer func() {
		destroyTimeout := time.Second * 10
		if err := h.destroy(destroyTimeout); err != nil {
			zlog.Sugar().Warnf("failed to destroy VM: %v", err)
		}
		h.running.Store(false)
		close(h.waitCh)
	}()

	// start the VM
	zlog.Sugar().Info("starting firecracker execution")
	if err := h.client.StartVM(ctx, h.machine); err != nil {
		h.result = types.NewFailedExecutionResult(fmt.Errorf("failed to start VM: %w", err))
		return
	}

	close(h.activeCh) // Indicate that the VM has started.

	if err := h.machine.Wait(ctx); err != nil {
		if ctx.Err() != nil {
			h.result = types.NewFailedExecutionResult(
				fmt.Errorf("context closed while waiting on VM: %w", err),
			)
			return
		}
		h.result = types.NewFailedExecutionResult(fmt.Errorf("failed to wait on VM: %w", err))
		return
	}

	h.result = types.NewExecutionResult(types.ExecutionStatusCodeSuccess)
}

// kill asks the client to shut down the VM.
func (h *executionHandler) kill(ctx context.Context) error {
	return h.client.ShutdownVM(ctx, h.machine)
}

// destroy asks the client to destroy the VM and its resources, bounded by timeout.
func (h *executionHandler) destroy(timeout time.Duration) error {
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()
	return h.client.DestroyVM(ctx, h.machine, timeout)
}
//go:build linux
// +build linux
package firecracker
import (
"gitlab.com/nunet/device-management-service/telemetry/logger"
)
// zlog is the package-wide logger for the Firecracker executor. It is created
// as a package-level initializer so it is ready before any init function in
// the package runs.
var zlog = logger.New("executor.firecracker")
//go:build linux
// +build linux
package firecracker
import (
"encoding/json"
"fmt"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils/validate"
)
// Parameter keys understood by DecodeSpec. DecodeSpec round-trips the spec
// params through encoding/json into EngineSpec, so each key must equal the
// JSON tag of the corresponding EngineSpec field or its value is ignored.
const (
	EngineKeyKernelImage    = "kernel_image"
	EngineKeyKernelArgs     = "kernel_args"
	EngineKeyRootFileSystem = "root_file_system"
	// EngineKeyInitrd must match EngineSpec.Initrd's `initrd_path` JSON tag;
	// a different key (e.g. "initrd") would be dropped during decoding.
	EngineKeyInitrd      = "initrd_path"
	EngineKeyMMDSMessage = "mmds_message"
)
// EngineSpec contains the parameters needed to execute a firecracker job.
type EngineSpec struct {
	// KernelImage is the path to the kernel image file.
	KernelImage string `json:"kernel_image,omitempty"`
	// Initrd is the path to the initial ramdisk file (JSON key "initrd_path").
	Initrd string `json:"initrd_path,omitempty"`
	// KernelArgs is the kernel command line arguments.
	KernelArgs string `json:"kernel_args,omitempty"`
	// RootFileSystem is the path to the root file system.
	RootFileSystem string `json:"root_file_system,omitempty"`
	// MMDSMessage is the MMDS message to be sent to the Firecracker VM.
	MMDSMessage string `json:"mmds_message,omitempty"`
}

// Validate reports an error when a required field is blank. RootFileSystem is
// checked first, then KernelImage; the other fields are optional.
func (c EngineSpec) Validate() error {
	switch {
	case validate.IsBlank(c.RootFileSystem):
		return fmt.Errorf("invalid firecracker engine params: root_file_system cannot be empty")
	case validate.IsBlank(c.KernelImage):
		return fmt.Errorf("invalid firecracker engine params: kernel_image cannot be empty")
	default:
		return nil
	}
}
// DecodeSpec converts a generic SpecConfig into a firecracker EngineSpec and
// validates it.
//
// It returns an error if spec is nil, is not of the firecracker executor type,
// has nil Params, or its params cannot be JSON round-tripped into EngineSpec.
// When only validation fails, the decoded spec is returned together with the
// validation error.
func DecodeSpec(spec *types.SpecConfig) (EngineSpec, error) {
	if spec == nil {
		return EngineSpec{}, fmt.Errorf("invalid firecracker engine spec: spec cannot be nil")
	}
	if !spec.IsType(types.ExecutorTypeFirecracker.String()) {
		return EngineSpec{}, fmt.Errorf(
			"invalid firecracker engine type. expected %s, but received: %s",
			types.ExecutorTypeFirecracker,
			spec.Type,
		)
	}

	inputParams := spec.Params
	if inputParams == nil {
		return EngineSpec{}, fmt.Errorf("invalid firecracker engine params: params cannot be nil")
	}

	paramBytes, err := json.Marshal(inputParams)
	if err != nil {
		return EngineSpec{}, fmt.Errorf("failed to encode firecracker engine params: %w", err)
	}

	// Decode into a value, not a pointer: a JSON `null` payload leaves a
	// pointer nil, and dereferencing it would panic.
	var firecrackerSpec EngineSpec
	if err := json.Unmarshal(paramBytes, &firecrackerSpec); err != nil {
		return EngineSpec{}, fmt.Errorf("failed to decode firecracker engine params: %w", err)
	}

	return firecrackerSpec, firecrackerSpec.Validate()
}
// EngineBuilder assembles a firecracker SpecConfig using the builder pattern.
// Every With* method records one engine parameter on the wrapped SpecConfig
// and returns the same builder so calls can be chained.
type EngineBuilder struct {
	eb *types.SpecConfig
}

// NewFirecrackerEngineBuilder returns a builder for a SpecConfig of type
// ExecutorTypeFirecracker with the root file system parameter already set to
// rootFileSystem.
func NewFirecrackerEngineBuilder(rootFileSystem string) *EngineBuilder {
	b := &EngineBuilder{eb: types.NewSpecConfig(types.ExecutorTypeFirecracker.String())}
	return b.setParam(EngineKeyRootFileSystem, rootFileSystem)
}

// setParam stores value under key on the underlying spec and returns b.
func (b *EngineBuilder) setParam(key, value string) *EngineBuilder {
	b.eb.WithParam(key, value)
	return b
}

// WithRootFileSystem sets the path of the VM's root file system.
func (b *EngineBuilder) WithRootFileSystem(e string) *EngineBuilder {
	return b.setParam(EngineKeyRootFileSystem, e)
}

// WithKernelImage sets the path of the kernel image.
func (b *EngineBuilder) WithKernelImage(e string) *EngineBuilder {
	return b.setParam(EngineKeyKernelImage, e)
}

// WithInitrd sets the path of the initial ram disk.
func (b *EngineBuilder) WithInitrd(e string) *EngineBuilder {
	return b.setParam(EngineKeyInitrd, e)
}

// WithKernelArgs sets the kernel command line arguments.
func (b *EngineBuilder) WithKernelArgs(e string) *EngineBuilder {
	return b.setParam(EngineKeyKernelArgs, e)
}

// WithMMDSMessage sets the MMDS message for the VM.
func (b *EngineBuilder) WithMMDSMessage(e string) *EngineBuilder {
	return b.setParam(EngineKeyMMDSMessage, e)
}

// Build returns the SpecConfig accumulated so far (the builder's own pointer,
// not a copy).
func (b *EngineBuilder) Build() *types.SpecConfig {
	return b.eb
}
package null
import (
"context"
"io"
"gitlab.com/nunet/device-management-service/executor"
"gitlab.com/nunet/device-management-service/types"
)
// Executor is a no-op implementation of the Executor interface.
type Executor struct{}

// NewExecutor creates a new Executor. Both arguments are ignored and the
// error is always nil.
func NewExecutor(_ context.Context, _ string) (executor.Executor, error) {
	return &Executor{}, nil
}

var _ executor.Executor = (*Executor)(nil)

// Start does nothing and returns nil.
func (e *Executor) Start(_ context.Context, _ *types.ExecutionRequest) error {
	return nil
}

// Run returns a nil result and nil error.
func (e *Executor) Run(_ context.Context, _ *types.ExecutionRequest) (*types.ExecutionResult, error) {
	return nil, nil
}

// Wait returns a result channel and an error channel that are both already closed.
func (e *Executor) Wait(_ context.Context, _ string) (<-chan *types.ExecutionResult, <-chan error) {
	resultCh := make(chan *types.ExecutionResult)
	errCh := make(chan error)
	close(resultCh)
	close(errCh)
	return resultCh, errCh
}

// Cancel does nothing and returns nil.
func (e *Executor) Cancel(_ context.Context, _ string) error {
	return nil
}

// GetLogStream returns an empty stream: Read reports io.EOF immediately and
// Close is a no-op.
func (e *Executor) GetLogStream(_ context.Context, _ types.LogStreamRequest) (io.ReadCloser, error) {
	// io.NopCloser(nil) would wrap a nil Reader and panic on the first Read;
	// a MultiReader with no readers returns io.EOF straight away.
	return io.NopCloser(io.MultiReader()), nil
}

// List returns an empty slice of ExecutionListItem.
func (e *Executor) List() []types.ExecutionListItem {
	return []types.ExecutionListItem{}
}

// Cleanup does nothing and returns nil.
func (e *Executor) Cleanup(_ context.Context) error {
	return nil
}
package backgroundtasks
import (
"github.com/uptrace/opentelemetry-go-extra/otelzap"
"gitlab.com/nunet/device-management-service/telemetry/logger"
)
var zlog *otelzap.Logger
func init() {
zlog = logger.OtelZapLogger("background_tasks")
}
package backgroundtasks
import (
"sort"
"sync"
"time"
)
// Scheduler orchestrates the execution of tasks based on their triggers and priority.
//
// On every tick of its ticker, the scheduler evaluates the triggers of enabled,
// idle tasks (highest Priority first). Ready tasks are launched on their own
// goroutines, never exceeding maxRunningTasks concurrent runs. All access to
// tasks, runningTasks and the Task values they hold is serialized through mu.
type Scheduler struct {
	tasks           map[int]*Task // Map of tasks by their ID.
	runningTasks    map[int]bool  // Map to keep track of running tasks.
	ticker          *time.Ticker  // Ticker for periodic checks of task triggers.
	stopChan        chan struct{} // Channel to signal stopping the scheduler.
	maxRunningTasks int           // Maximum number of tasks that can run concurrently.
	lastTaskID      int           // Counter for assigning unique IDs to tasks.
	mu              sync.Mutex    // Mutex to protect access to task maps.
}

// NewScheduler creates a new Scheduler with a specified limit on running tasks.
// Once Start is called, triggers are checked every second.
func NewScheduler(maxRunningTasks int) *Scheduler {
	return &Scheduler{
		tasks:           make(map[int]*Task),
		runningTasks:    make(map[int]bool),
		ticker:          time.NewTicker(1 * time.Second),
		stopChan:        make(chan struct{}),
		maxRunningTasks: maxRunningTasks,
		lastTaskID:      0,
	}
}

// AddTask registers task under the next free ID, enables it and resets all of
// its triggers. The same pointer is returned with its ID filled in.
func (s *Scheduler) AddTask(task *Task) *Task {
	s.mu.Lock()
	defer s.mu.Unlock()

	task.ID = s.lastTaskID
	task.Enabled = true
	for _, trigger := range task.Triggers {
		trigger.Reset()
	}
	s.tasks[task.ID] = task
	s.lastTaskID++
	return task
}

// RemoveTask removes a task from the scheduler. A run already in progress is
// not interrupted.
func (s *Scheduler) RemoveTask(taskID int) {
	s.mu.Lock()
	defer s.mu.Unlock()
	delete(s.tasks, taskID)
}

// Start launches the scheduler loop on a new goroutine. The loop runs
// runTasks on each tick until Stop is called.
func (s *Scheduler) Start() {
	go func() {
		for {
			select {
			case <-s.stopChan:
				return
			case <-s.ticker.C:
				s.runTasks()
			}
		}
	}()
}

// runningTasksCount returns the count of running tasks. It acquires s.mu.
func (s *Scheduler) runningTasksCount() int {
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.runningTasksCountLocked()
}

// runningTasksCountLocked counts running tasks; the caller must hold s.mu.
func (s *Scheduler) runningTasksCountLocked() int {
	count := 0
	for _, isRunning := range s.runningTasks {
		if isRunning {
			count++
		}
	}
	return count
}

// runTasks checks the triggers of all idle, enabled tasks in descending
// priority order and launches each ready task, up to maxRunningTasks
// concurrent runs. Tasks without any trigger are removed.
func (s *Scheduler) runTasks() {
	// Hold the lock for the whole pass. The task maps are also written by
	// AddTask/RemoveTask and by runTask goroutines, so touching them unguarded
	// is a data race (concurrent map writes abort the process).
	s.mu.Lock()
	defer s.mu.Unlock()

	// Sort tasks by priority.
	sortedTasks := make([]*Task, 0, len(s.tasks))
	for _, task := range s.tasks {
		sortedTasks = append(sortedTasks, task)
	}
	sort.Slice(sortedTasks, func(i, j int) bool {
		return sortedTasks[i].Priority > sortedTasks[j].Priority
	})

	running := s.runningTasksCountLocked()
	for _, task := range sortedTasks {
		if !task.Enabled || s.runningTasks[task.ID] {
			continue
		}
		if len(task.Triggers) == 0 {
			delete(s.tasks, task.ID)
			continue
		}
		// Check capacity before consulting triggers: an EventTrigger's IsReady
		// consumes the pending event, which would be lost if the task could
		// not be launched anyway.
		if running >= s.maxRunningTasks {
			continue
		}
		for _, trigger := range task.Triggers {
			if trigger.IsReady() {
				s.runningTasks[task.ID] = true
				running++
				go s.runTask(task.ID)
				trigger.Reset()
				break
			}
		}
	}
}

// Stop signals the scheduler loop to exit and stops its ticker.
// It must be called at most once: closing stopChan twice panics.
func (s *Scheduler) Stop() {
	s.ticker.Stop()
	close(s.stopChan)
}

// runTask executes a task, making up to RetryPolicy.MaxRetries+1 attempts,
// and appends an Execution record to the task's history. The task is marked
// as not running when it finishes.
func (s *Scheduler) runTask(taskID int) {
	defer func() {
		s.mu.Lock()
		defer s.mu.Unlock()
		s.runningTasks[taskID] = false
	}()

	s.mu.Lock()
	task, ok := s.tasks[taskID]
	s.mu.Unlock()
	if !ok {
		// The task was removed between scheduling and running; nothing to do.
		return
	}

	execution := Execution{StartedAt: time.Now()}
	defer func() {
		// Only append to the task's own history. Writing the task back into
		// s.tasks would resurrect a task that was removed while running.
		s.mu.Lock()
		task.ExecutionHist = append(task.ExecutionHist, execution)
		s.mu.Unlock()
	}()

	for i := 0; i < task.RetryPolicy.MaxRetries+1; i++ {
		err := runTaskWithRetry(task.Function, task.Args, task.RetryPolicy.Delay)
		if err == nil {
			execution.Status = "SUCCESS"
			execution.EndedAt = time.Now()
			return
		}
		execution.Error = err.Error()
	}

	execution.Status = "FAILED"
	execution.EndedAt = time.Now()
}
// runTaskWithRetry performs a single attempt: it calls fn with args, passed as
// one interface{} value holding the []interface{}. On failure it sleeps for
// delay before returning the error, so the caller's retry loop is spaced out.
// It returns nil on success.
func runTaskWithRetry(
	fn func(args interface{}) error,
	args []interface{},
	delay time.Duration,
) error {
	if err := fn(args); err != nil {
		time.Sleep(delay)
		return err
	}
	return nil
}
package backgroundtasks
import (
"time"
"github.com/robfig/cron/v3"
)
// Trigger interface defines a method to check if a trigger condition is met.
// The Scheduler calls Reset on every trigger when a task is added and on the
// trigger that fired right after launching a run. It calls IsReady while
// evaluating tasks on each tick. IsReady may consume state (EventTrigger
// drains one value from its channel), so it should not be called speculatively.
type Trigger interface {
IsReady() bool // Returns true if the trigger condition is met.
Reset() // Resets the trigger state.
}
// PeriodicTrigger triggers at regular intervals or based on a cron expression.
//
// If both Interval and CronExpr are set, the trigger is ready as soon as either
// condition holds. A zero Interval with a non-empty CronExpr means the trigger
// follows the cron schedule alone. With neither set, it is always ready.
type PeriodicTrigger struct {
	Interval      time.Duration // Interval for periodic triggering.
	CronExpr      string        // Cron expression for triggering (parsed with cron.ParseStandard).
	lastTriggered time.Time     // Last time the trigger was activated.
}

// IsReady reports whether the trigger should fire, based on the time elapsed
// since the last Reset or on the cron schedule. An unparsable CronExpr is
// logged and treated as not ready.
func (t *PeriodicTrigger) IsReady() bool {
	now := time.Now()

	// Only consult the interval when it is meaningful. Without this guard a
	// cron-only trigger (Interval == 0) would be ready on every check, because
	// lastTriggered.Add(0) is always in the past.
	if t.Interval > 0 || t.CronExpr == "" {
		if t.lastTriggered.Add(t.Interval).Before(now) {
			return true
		}
	}

	if t.CronExpr == "" {
		return false
	}
	schedule, err := cron.ParseStandard(t.CronExpr)
	if err != nil {
		zlog.Sugar().Errorf("Error parsing CronExpr: %v", err)
		return false
	}
	return schedule.Next(t.lastTriggered).Before(now)
}

// Reset updates the last triggered time to the current time.
func (t *PeriodicTrigger) Reset() {
	t.lastTriggered = time.Now()
}
// PeriodicTriggerWithJitter is a PeriodicTrigger whose interval is extended by
// a jitter value, re-sampled from Jitter on every readiness check.
//
// If both Interval and CronExpr are set, the trigger is ready as soon as either
// condition holds. A zero Interval with a non-empty CronExpr means the trigger
// follows the cron schedule alone. A nil Jitter is treated as zero jitter.
type PeriodicTriggerWithJitter struct {
	Interval      time.Duration // Interval for periodic triggering.
	CronExpr      string        // Cron expression for triggering (parsed with cron.ParseStandard).
	lastTriggered time.Time     // Last time the trigger was activated.
	Jitter        func() time.Duration
}

// IsReady reports whether the trigger should fire, based on the jittered
// interval since the last Reset or on the cron schedule. An unparsable
// CronExpr is logged and treated as not ready.
func (t *PeriodicTriggerWithJitter) IsReady() bool {
	now := time.Now()

	// Skip the interval check for cron-only triggers; a zero interval would
	// otherwise make the trigger ready on every tick.
	if t.Interval > 0 || t.CronExpr == "" {
		var jitter time.Duration
		if t.Jitter != nil {
			jitter = t.Jitter()
		}
		if t.lastTriggered.Add(t.Interval + jitter).Before(now) {
			return true
		}
	}

	if t.CronExpr == "" {
		return false
	}
	schedule, err := cron.ParseStandard(t.CronExpr)
	if err != nil {
		zlog.Sugar().Errorf("Error parsing CronExpr: %v", err)
		return false
	}
	return schedule.Next(t.lastTriggered).Before(now)
}

// Reset updates the last triggered time to the current time.
func (t *PeriodicTriggerWithJitter) Reset() {
	t.lastTriggered = time.Now()
}
// EventTrigger fires when a value is available on its Trigger channel.
type EventTrigger struct {
	Trigger chan bool // Channel to signal an event.
}

// IsReady performs a non-blocking receive on Trigger. It reports true (and
// consumes the value) if one was pending, and false otherwise, including for a
// nil channel.
func (t *EventTrigger) IsReady() bool {
	fired := false
	select {
	case <-t.Trigger:
		fired = true
	default:
	}
	return fired
}

// Reset is a no-op: readiness is driven entirely by whoever sends on Trigger.
func (t *EventTrigger) Reset() {}
// OneTimeTrigger becomes ready once Delay has elapsed since it was last Reset.
// Because the Scheduler resets a trigger after it fires, the delay restarts
// after every run.
type OneTimeTrigger struct {
	Delay        time.Duration // The delay after which to trigger.
	registeredAt time.Time     // Time when the trigger was set.
}

// Reset records the current time as the start of the delay period.
func (t *OneTimeTrigger) Reset() {
	t.registeredAt = time.Now()
}

// IsReady reports whether the current time is strictly past registeredAt+Delay.
func (t *OneTimeTrigger) IsReady() bool {
	deadline := t.registeredAt.Add(t.Delay)
	return time.Now().After(deadline)
}
package config
import (
"bytes"
"encoding/json"
"fmt"
"os"
"reflect"
"github.com/spf13/afero"
"github.com/spf13/viper"
)
var (
// cfg is the process-wide configuration; LoadConfig and Set unmarshal into it.
cfg Config
// homeDir is the current user's home directory. The error from
// os.UserHomeDir is discarded, so homeDir is "" when it cannot be determined,
// and paths derived from it (e.g. "%s/.nunet") then resolve to "/.nunet".
homeDir, _ = os.UserHomeDir()
)
// getViper returns a fresh viper instance configured to read "dms_config.json".
// The file is searched for in the current working directory, then
// $HOME/.nunet, then /etc/nunet/, in that order.
func getViper() *viper.Viper {
	v := viper.New()
	v.SetConfigName("dms_config")
	v.SetConfigType("json")
	searchPaths := []string{
		".",
		fmt.Sprintf("%s/.nunet", homeDir),
		"/etc/nunet/",
	}
	for _, dir := range searchPaths {
		v.AddConfigPath(dir)
	}
	return v
}
// setDefaultConfig returns a viper instance (see getViper) pre-loaded with the
// built-in default value for every known configuration key.
func setDefaultConfig() *viper.Viper {
	v := getViper()
	defaults := []struct {
		key   string
		value interface{}
	}{
		{"general.user_dir", fmt.Sprintf("%s/.nunet", homeDir)},
		{"general.work_dir", fmt.Sprintf("%s/nunet", homeDir)},
		{"general.data_dir", "/var/nunet"},
		{"general.debug", false},
		{"rest.addr", "127.0.0.1"},
		{"rest.port", 9999},
		{"profiler.enabled", true},
		{"profiler.addr", "127.0.0.1"},
		{"profiler.port", 6060},
		{"p2p.listen_address", []string{
			"/ip4/0.0.0.0/tcp/9000",
			"/ip4/0.0.0.0/udp/9000/quic-v1",
		}},
		{"p2p.bootstrap_peers", []string{
			"/dnsaddr/bootstrap.p2p.nunet.io/p2p/QmQ2irHa8aFTLRhkbkQCRrounE4MbttNp8ki7Nmys4F9NP",
			"/dnsaddr/bootstrap.p2p.nunet.io/p2p/Qmf16N2ecJVWufa29XKLNyiBxKWqVPNZXjbL3JisPcGqTw",
			"/dnsaddr/bootstrap.p2p.nunet.io/p2p/QmTkWP72uECwCsiiYDpCFeTrVeUM9huGTPsg3m6bHxYQFZ",
		}},
		{"p2p.memory", 1024},
		{"p2p.fd", 512},
		{"job.log_update_interval", 2},
		{"job.target_peer", ""},
		{"job.cleanup_interval", 3},
		{"telemetry.service_name", "NunetDMS"},
		{"telemetry.global_endpoint", "otel-collector.telemetry.nunet.io:4318"},
		{"telemetry.observability_level", "INFO"},
		{"telemetry.telemetry_mode", "production"},
	}
	for _, d := range defaults {
		v.SetDefault(d.key, d.value)
	}
	return v
}
// LoadConfig populates the package-level cfg. If a config file can be read,
// its values (layered over the defaults) are strictly unmarshalled. If
// reading fails for any reason, cfg is filled from the defaults alone.
// Unknown keys cause an error because UnmarshalExact is used.
func LoadConfig() error {
	v := setDefaultConfig()
	if readErr := v.ReadInConfig(); readErr == nil {
		if err := v.UnmarshalExact(&cfg); err != nil {
			return fmt.Errorf("failed to unmarshal config: %w", err)
		}
		return nil
	}

	// No usable config file: fall back to a defaults-only instance.
	if err := setDefaultConfig().UnmarshalExact(&cfg); err != nil {
		return fmt.Errorf("failed to unmarshal default config: %w", err)
	}
	return nil
}
// GetConfig returns a pointer to the package-level cfg. If cfg still equals
// its zero value, it first attempts a LoadConfig. A load error is ignored, and
// cfg is returned in whatever state LoadConfig left it.
//
// NOTE(review): cfg is read and written without synchronization; confirm
// callers never invoke this concurrently before the first load completes.
func GetConfig() *Config {
	if !reflect.DeepEqual(cfg, Config{}) {
		return &cfg
	}
	_ = LoadConfig() // best effort; see doc comment
	return &cfg
}
// Get returns the value stored under key in the current configuration.
// The configuration from GetConfig is serialized to JSON and loaded into a
// fresh viper instance so that dotted keys (e.g. "rest.port") can be resolved.
// It returns an error if the key is not set or if serialization fails.
func Get(key string) (interface{}, error) {
	v := getViper()

	serialized, err := json.Marshal(GetConfig())
	if err != nil {
		return nil, fmt.Errorf("could not marshal config: %w", err)
	}
	if err = v.ReadConfig(bytes.NewReader(serialized)); err != nil {
		return nil, fmt.Errorf("could not read config: %w", err)
	}

	if v.IsSet(key) {
		return v.Get(key), nil
	}
	return nil, fmt.Errorf("key '%s' not found in configuration", key)
}
// Set assigns value to key and persists the resulting configuration through fs.
//
// The key is set on a fresh viper instance and strictly unmarshalled into cfg.
// The full configuration (from GetConfig) is then merged back and written to
// the config file. If no config file exists yet, a new one is created with
// SafeWriteConfig.
func Set(fs afero.Fs, key string, value interface{}) error {
	v := getViper()
	v.SetFs(fs)
	v.Set(key, value)

	if err := v.UnmarshalExact(&cfg); err != nil {
		return fmt.Errorf("failed to unmarshal config: %w", err)
	}

	serialized, err := json.Marshal(GetConfig())
	if err != nil {
		return fmt.Errorf("could not marshal config: %w", err)
	}
	if err := v.MergeConfig(bytes.NewReader(serialized)); err != nil {
		return fmt.Errorf("failed to merge config: %w", err)
	}

	writeErr := v.WriteConfig()
	if writeErr == nil {
		return nil
	}
	if _, missing := writeErr.(viper.ConfigFileNotFoundError); missing {
		// Config file does not exist, create it.
		return v.SafeWriteConfig()
	}
	return fmt.Errorf("failed to write config: %w", writeErr)
}
// FileExists reports whether a readable config file exists on fs in any of
// the search paths. A missing file yields (false, nil); any other read
// failure, such as a parse error, is returned as an error.
func FileExists(fs afero.Fs) (bool, error) {
	v := getViper()
	v.SetFs(fs)

	err := v.ReadInConfig()
	switch err.(type) {
	case nil:
		return true, nil
	case viper.ConfigFileNotFoundError:
		return false, nil
	default:
		return false, fmt.Errorf("could not read config file: %w", err)
	}
}
// GetPath returns the path of the config file viper resolved. If no config
// file could be read, it returns what a defaults-only instance reports.
func GetPath() string {
	v := getViper()
	if err := v.ReadInConfig(); err == nil {
		return v.ConfigFileUsed()
	}
	return setDefaultConfig().ConfigFileUsed()
}
package internal
import (
"os"
"os/signal"
"syscall"
"github.com/uptrace/opentelemetry-go-extra/otelzap"
"gitlab.com/nunet/device-management-service/telemetry/logger"
)
var (
	// zlog is the logger for the internal package.
	zlog *otelzap.Logger
	// ShutdownChan receives SIGINT and SIGTERM, registered in init. It has a
	// buffer of one because signal.Notify never blocks when delivering.
	ShutdownChan chan os.Signal
)

func init() {
	ShutdownChan = make(chan os.Signal, 1)
	signal.Notify(ShutdownChan, syscall.SIGINT, syscall.SIGTERM)
	zlog = logger.OtelZapLogger("internal")
}
// Package internal is a work in progress. It is planned to accommodate
// modules such as db and types.
package internal
import (
"fmt"
"net/http"
"github.com/gorilla/websocket"
)
// UpgradeConnection is the generic HTTP-to-WebSocket protocol upgrader for the
// entire DMS. It uses 1024-byte read and write buffers.
//
// NOTE(review): CheckOrigin unconditionally returns true, so upgrade requests
// from any origin are accepted (with a nil CheckOrigin, gorilla would reject
// cross-origin requests). This allows cross-site WebSocket hijacking if a
// browser can reach the endpoint — confirm this is intentional.
var UpgradeConnection = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: func(_ *http.Request) bool { return true },
}
// WebSocketConnection wraps a gorilla/websocket connection by embedding
// *websocket.Conn, so its methods (ReadMessage, WriteMessage, ...) are
// available directly. Values of this type are used as keys of the clients map.
type WebSocketConnection struct {
*websocket.Conn
}
// Command represents a command to be executed. ListenForWs creates commands
// and SendCommandForExecution consumes them via commandChan.
type Command struct {
Command string // raw text of the received websocket message
NodeID string // ID of the node where command will be executed
Result string // NOTE(review): never set in this file — presumably reserved for command output
Conn *WebSocketConnection // connection the command arrived on; replies are written back to it
}
// commandChan is an unbuffered channel that carries commands from ListenForWs
// to the SendCommandForExecution loop.
var commandChan = make(chan Command)
// clients maps a connection to its node ID; ListenForWs reads it to fill
// Command.NodeID. NOTE(review): nothing in this file writes to it and it is
// accessed without locking — confirm writers are synchronized with readers.
var clients = make(map[WebSocketConnection]string)
// ListenForWs reads messages from conn and forwards each one to commandChan as
// a Command; every incoming message is assumed to be a command to execute.
// The command's NodeID is looked up in clients once, before reading starts.
//
// It returns when a read fails, which includes a normal close by the peer.
// Any panic raised in the loop is recovered and logged.
func ListenForWs(conn *WebSocketConnection) {
	defer func() {
		if r := recover(); r != nil {
			zlog.Sugar().Errorf("Error: %s", fmt.Sprint(r))
		}
	}()

	cmd := Command{NodeID: clients[*conn], Conn: conn}
	for {
		_, msg, err := conn.ReadMessage()
		if err != nil {
			// Read errors on a websocket connection are permanent. Retrying only
			// spins until gorilla panics on repeated reads of a failed connection.
			zlog.Sugar().Infof("stopped listening on websocket: %v", err)
			return
		}
		// logic to send command and fetch the output
		cmd.Command = string(msg)
		commandChan <- cmd
	}
}
// SendCommandForExecution consumes commands from commandChan. For now it logs
// each command and writes the command text back to the originating connection
// as a text message. It loops until commandChan is closed.
//
// TO BE IMPLEMENTED: send the command for execution, fetch the result and send
// the result back instead of echoing the command.
func SendCommandForExecution() {
	for command := range commandChan {
		zlog.Sugar().Infof("%v", command)

		err := command.Conn.WriteMessage(websocket.TextMessage, []byte(command.Command))
		if err != nil {
			// %w is only understood by fmt.Errorf; loggers need %v.
			zlog.Sugar().Warnf("failed to write message: %v", err)
		}
	}
}
package crypto
import (
"crypto/rand"
"errors"
"io"
"golang.org/x/crypto/sha3"
)
// RandomEntropy returns length cryptographically secure random bytes read from
// crypto/rand.Reader. It returns an error when length is negative (which would
// otherwise panic in make) or when the random source cannot fill the buffer.
func RandomEntropy(length int) ([]byte, error) {
	if length < 0 {
		return nil, errors.New("entropy length must not be negative")
	}
	buf := make([]byte, length)
	// io.ReadFull returns a nil error only if buf was filled completely, so no
	// separate byte-count check is needed.
	if _, err := io.ReadFull(rand.Reader, buf); err != nil {
		return nil, errors.New("failed to read random bytes")
	}
	return buf, nil
}
// Sha3 returns the SHA3-256 digest of the concatenation of all data slices.
// The error result is returned only if the hasher's Write fails.
func Sha3(data ...[]byte) ([]byte, error) {
	hasher := sha3.New256()
	for i := range data {
		if _, err := hasher.Write(data[i]); err != nil {
			return nil, err
		}
	}
	return hasher.Sum(nil), nil
}
package crypto
import (
"crypto/subtle"
"fmt"
"github.com/decred/dcrd/dcrec/secp256k1/v4"
"github.com/decred/dcrd/dcrec/secp256k1/v4/ecdsa"
"github.com/libp2p/go-libp2p/core/crypto/pb"
"golang.org/x/crypto/sha3"
)
// ethSignMagic is the Ethereum signed-message prefix. EthPublicKey.Verify
// hashes it before the decimal length of the message and the message itself.
var ethSignMagic = []byte(
"\x19Ethereum Signed Message:\n",
)
// EthPublicKey is a secp256k1 public key that verifies DER-encoded ECDSA
// signatures over Ethereum-prefixed messages (see Verify).
type EthPublicKey struct {
	key *secp256k1.PublicKey
}

var _ PubKey = (*EthPublicKey)(nil)

// UnmarshalEthPublicKey parses a serialized secp256k1 public key and wraps it
// in an EthPublicKey.
func UnmarshalEthPublicKey(data []byte) (PubKey, error) {
	parsed, err := secp256k1.ParsePubKey(data)
	if err != nil {
		return nil, err
	}
	return &EthPublicKey{key: parsed}, nil
}

// Verify checks a DER-encoded signature against the Keccak-256 hash of
// ethSignMagic || decimal(len(data)) || data. A malformed signature yields an
// error; a well-formed but invalid one yields (false, nil).
func (k *EthPublicKey) Verify(data []byte, sigStr []byte) (bool, error) {
	sig, err := ecdsa.ParseDERSignature(sigStr)
	if err != nil {
		return false, err
	}
	hasher := sha3.NewLegacyKeccak256()
	parts := [][]byte{ethSignMagic, []byte(fmt.Sprintf("%d", len(data))), data}
	for _, part := range parts {
		hasher.Write(part) // hash.Hash.Write never returns an error
	}
	return sig.Verify(hasher.Sum(nil), k.key), nil
}

// Raw returns the key in compressed SEC1 form.
func (k *EthPublicKey) Raw() ([]byte, error) {
	compressed := k.key.SerializeCompressed()
	return compressed, nil
}

// Type reports the project-specific Eth key type.
func (k *EthPublicKey) Type() pb.KeyType {
	return Eth
}

// Equals compares point equality with another EthPublicKey, and falls back to
// comparing type and raw bytes for any other Key implementation.
func (k *EthPublicKey) Equals(o Key) bool {
	if other, ok := o.(*EthPublicKey); ok {
		return k.key.IsEqual(other.key)
	}
	return basicEquals(k, o)
}
// basicEquals reports whether two keys have the same type and identical raw
// bytes, comparing the bytes in constant time. If either Raw call fails, the
// keys are considered unequal.
func basicEquals(k1, k2 Key) bool {
	if k1.Type() != k2.Type() {
		return false
	}
	var raws [2][]byte
	for i, k := range [2]Key{k1, k2} {
		b, err := k.Raw()
		if err != nil {
			return false
		}
		raws[i] = b
	}
	return subtle.ConstantTimeCompare(raws[0], raws[1]) == 1
}
package crypto
import (
"bytes"
"encoding/base32"
"encoding/json"
"fmt"
)
// ID is the encoding of the actor's public key.
type ID struct{ PublicKey []byte }

// Equal reports whether both IDs hold byte-identical public keys.
func (id ID) Equal(other ID) bool {
	return bytes.Equal(other.PublicKey, id.PublicKey)
}

// Empty reports whether the ID carries no key bytes (nil or zero length).
func (id ID) Empty() bool {
	return len(id.PublicKey) < 1
}

// IDJSONView is the on-the-wire JSON representation of an ID.
type IDJSONView struct {
	Pub string `json:"pub"`
}

// String returns the public key encoded with standard, padded base32.
func (id ID) String() string {
	enc := base32.StdEncoding
	return enc.EncodeToString(id.PublicKey)
}

// IDFromString is the inverse of String: it decodes standard base32 into an ID.
func IDFromString(s string) (ID, error) {
	var parsed ID
	raw, err := base32.StdEncoding.DecodeString(s)
	if err != nil {
		return ID{}, fmt.Errorf("decode ID: %w", err)
	}
	parsed.PublicKey = raw
	return parsed, nil
}

// MarshalJSON encodes the ID as {"pub": "<base32>"}.
func (id ID) MarshalJSON() ([]byte, error) {
	view := IDJSONView{Pub: id.String()}
	return json.Marshal(view)
}

var _ json.Marshaler = ID{}

// UnmarshalJSON decodes the {"pub": "<base32>"} form produced by MarshalJSON.
// The receiver is left unchanged if decoding fails.
func (id *ID) UnmarshalJSON(data []byte) error {
	var view IDJSONView
	if err := json.Unmarshal(data, &view); err != nil {
		return fmt.Errorf("unmarshaling ID: %w", err)
	}
	parsed, err := IDFromString(view.Pub)
	if err != nil {
		return fmt.Errorf("unmarshaling ID: %w", err)
	}
	*id = parsed
	return nil
}

var _ json.Unmarshaler = (*ID)(nil)
package crypto
import (
"crypto/rand"
"fmt"
"github.com/libp2p/go-libp2p/core/crypto"
)
// Supported key type identifiers. Ed25519 and Secp256k1 reuse libp2p's values;
// Eth (127) is a project-specific identifier reported by EthPublicKey.Type.
// NOTE(review): 127 is presumably chosen to stay clear of libp2p's pb.KeyType
// values — confirm it cannot collide with future additions.
const (
Ed25519 = crypto.Ed25519
Secp256k1 = crypto.Secp256k1
Eth = 127
)
// Key, PrivKey and PubKey are aliases for the libp2p crypto key interfaces.
type (
Key = crypto.Key
PrivKey = crypto.PrivKey
PubKey = crypto.PubKey
)
// AllowedKey reports whether t is a key type accepted by this package:
// Ed25519 or Secp256k1. Eth and all other values are rejected.
func AllowedKey(t int) bool {
	return t == Ed25519 || t == Secp256k1
}
// GenerateKeyPair creates a new key pair of type t from crypto/rand. Only
// Ed25519 and Secp256k1 are supported; any other type returns an error
// wrapping ErrUnsupportedKeyType.
func GenerateKeyPair(t int) (PrivKey, PubKey, error) {
	switch t {
	case Ed25519:
		return crypto.GenerateEd25519Key(rand.Reader)
	case Secp256k1:
		return crypto.GenerateSecp256k1Key(rand.Reader)
	}
	return nil, nil, fmt.Errorf("unsupported key type %d: %w", t, ErrUnsupportedKeyType)
}
// PublicKeyToBytes serializes k with libp2p's MarshalPublicKey.
func PublicKeyToBytes(k PubKey) ([]byte, error) {
	return crypto.MarshalPublicKey(k)
}

// BytesToPublicKey is the inverse of PublicKeyToBytes.
func BytesToPublicKey(data []byte) (PubKey, error) {
	return crypto.UnmarshalPublicKey(data)
}

// PrivateKeyToBytes serializes k with libp2p's MarshalPrivateKey.
func PrivateKeyToBytes(k PrivKey) ([]byte, error) {
	return crypto.MarshalPrivateKey(k)
}

// BytesToPrivateKey is the inverse of PrivateKeyToBytes.
func BytesToPrivateKey(data []byte) (PrivKey, error) {
	return crypto.UnmarshalPrivateKey(data)
}

// IDFromPublicKey builds an ID whose PublicKey field holds the
// PublicKeyToBytes serialization of k.
func IDFromPublicKey(k PubKey) (ID, error) {
	var id ID
	serialized, err := PublicKeyToBytes(k)
	if err != nil {
		return id, fmt.Errorf("id from public key: %w", err)
	}
	id.PublicKey = serialized
	return id, nil
}

// PublicKeyFromID is the inverse of IDFromPublicKey.
func PublicKeyFromID(id ID) (PubKey, error) {
	return BytesToPublicKey(id.PublicKey)
}
package keystore
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"encoding/hex"
"encoding/json"
"fmt"
"path/filepath"
"strings"
libp2p_crypto "github.com/libp2p/go-libp2p/core/crypto"
"golang.org/x/crypto/scrypt"
"gitlab.com/nunet/device-management-service/lib/crypto"
)
// Parameters used when writing keystore files. MarshalToJSON uses the scrypt
// values below. UnmarshalKey uses the scrypt parameters stored in each file
// instead, but still requires Version == ksVersion and Cipher == ksCipher.
var (
// KDF parameters
nameKDF = "scrypt"
// 64 derived bytes: the first 32 are the AES-256 key, the next 32 the MAC key.
scryptKeyLen = 64
// scrypt CPU/memory cost parameter N.
scryptN = 1 << 18
scryptR = 8
scryptP = 1
// Keystore file format version and cipher identifier written to every file.
ksVersion = 3
ksCipher = "aes-256-ctr"
)
// Key is a single keystore entry: an identifier plus opaque key bytes. For
// keys consumed through Key.PrivKey, Data is a libp2p-marshaled private key.
type Key struct {
	ID   string
	Data []byte
}

// NewKey wraps id and data in a Key. data is stored as-is, not copied. The
// returned error is currently always nil.
func NewKey(id string, data []byte) (*Key, error) {
	k := new(Key)
	k.ID = id
	k.Data = data
	return k, nil
}
// PrivKey interprets key.Data as a libp2p-marshaled private key and
// unmarshals it. The underlying error is wrapped with %w so callers can
// inspect it with errors.Is / errors.As.
func (key *Key) PrivKey() (crypto.PrivKey, error) {
	priv, err := libp2p_crypto.UnmarshalPrivateKey(key.Data)
	if err != nil {
		return nil, fmt.Errorf("unable to unmarshal private key: %w", err)
	}
	return priv, nil
}
// MarshalToJSON encrypts key.Data with passphrase and returns an indented
// JSON keystore document (version ksVersion, cipher ksCipher).
//
// A fresh salt and IV are drawn from crypto.RandomEntropy on every call. The
// passphrase is stretched with scrypt (scryptN/R/P, scryptKeyLen bytes).
// dk[:32] is the AES-256-CTR key. crypto.Sha3 over dk[32:64] and the
// ciphertext gives the MAC that UnmarshalKey checks.
//
// It returns ErrEmptyPassphrase when passphrase is empty. Other failures are
// wrapped with context using %w.
func (key *Key) MarshalToJSON(passphrase string) ([]byte, error) {
	if passphrase == "" {
		return nil, ErrEmptyPassphrase
	}

	salt, err := crypto.RandomEntropy(64)
	if err != nil {
		return nil, fmt.Errorf("failed to generate salt: %w", err)
	}
	dk, err := scrypt.Key([]byte(passphrase), salt, scryptN, scryptR, scryptP, scryptKeyLen)
	if err != nil {
		return nil, fmt.Errorf("failed to derive key: %w", err)
	}
	iv, err := crypto.RandomEntropy(aes.BlockSize)
	if err != nil {
		return nil, fmt.Errorf("failed to generate iv: %w", err)
	}

	aesBlock, err := aes.NewCipher(dk[:32])
	if err != nil {
		return nil, fmt.Errorf("failed to create cipher block: %w", err)
	}
	cipherText := make([]byte, len(key.Data))
	cipher.NewCTR(aesBlock, iv).XORKeyStream(cipherText, key.Data)

	mac, err := crypto.Sha3(dk[32:64], cipherText)
	if err != nil {
		return nil, fmt.Errorf("failed to hash key and ciphertext: %w", err)
	}

	encjson := encryptedKeyJSON{
		Crypto: cryptoJSON{
			Cipher:     ksCipher,
			CipherText: hex.EncodeToString(cipherText),
			CipherParams: cipherparamsJSON{
				IV: hex.EncodeToString(iv),
			},
			KDF: nameKDF,
			KDFParams: ScryptParams{
				N:          scryptN,
				R:          scryptR,
				P:          scryptP,
				DKeyLength: scryptKeyLen,
				Salt:       hex.EncodeToString(salt),
			},
			MAC: hex.EncodeToString(mac),
		},
		ID:      key.ID,
		Version: ksVersion,
	}
	data, err := json.MarshalIndent(&encjson, "", " ")
	if err != nil {
		return nil, fmt.Errorf("failed to marshal key json: %w", err)
	}
	return data, nil
}
// UnmarshalKey parses a keystore JSON document produced by Key.MarshalToJSON
// and decrypts it with passphrase.
//
// It returns:
//   - ErrEmptyPassphrase for an empty passphrase;
//   - ErrVersionMismatch or ErrCipherMismatch when the document uses another
//     keystore version or cipher;
//   - ErrMACMismatch when the MAC does not match (wrong passphrase or a
//     tampered file).
//
// Structurally invalid documents are reported as errors instead of panicking.
// That covers bad hex, a derived key length below 64 bytes, and an IV that is
// not exactly one AES block.
func UnmarshalKey(data []byte, passphrase string) (*Key, error) {
	// The derived key is split into an AES-256 key (dk[:32]) and a MAC key
	// (dk[32:64]), so at least 64 bytes must be derived.
	const minDerivedKeyLen = 64

	if passphrase == "" {
		return nil, ErrEmptyPassphrase
	}
	encjson := encryptedKeyJSON{}
	if err := json.Unmarshal(data, &encjson); err != nil {
		return nil, fmt.Errorf("failed to unmarshal key data: %w", err)
	}
	if encjson.Version != ksVersion {
		return nil, ErrVersionMismatch
	}
	if encjson.Crypto.Cipher != ksCipher {
		return nil, ErrCipherMismatch
	}

	mac, err := hex.DecodeString(encjson.Crypto.MAC)
	if err != nil {
		return nil, fmt.Errorf("failed to decode mac: %w", err)
	}
	iv, err := hex.DecodeString(encjson.Crypto.CipherParams.IV)
	if err != nil {
		return nil, fmt.Errorf("failed to decode cipher params iv: %w", err)
	}
	// cipher.NewCTR panics when the IV length differs from the block size, and
	// the IV comes straight from the (untrusted) file.
	if len(iv) != aes.BlockSize {
		return nil, fmt.Errorf("invalid cipher params iv length: got %d, want %d", len(iv), aes.BlockSize)
	}
	salt, err := hex.DecodeString(encjson.Crypto.KDFParams.Salt)
	if err != nil {
		return nil, fmt.Errorf("failed to decode salt: %w", err)
	}
	ciphertext, err := hex.DecodeString(encjson.Crypto.CipherText)
	if err != nil {
		return nil, fmt.Errorf("failed to decode cipher text: %w", err)
	}

	kdfParams := encjson.Crypto.KDFParams
	// Slicing dk[32:64] below would panic on a shorter derived key.
	if kdfParams.DKeyLength < minDerivedKeyLen {
		return nil, fmt.Errorf("invalid kdf derived key length: got %d, want at least %d", kdfParams.DKeyLength, minDerivedKeyLen)
	}
	dk, err := scrypt.Key([]byte(passphrase), salt, kdfParams.N, kdfParams.R, kdfParams.P, kdfParams.DKeyLength)
	if err != nil {
		return nil, fmt.Errorf("failed to derive key: %w", err)
	}

	hash, err := crypto.Sha3(dk[32:64], ciphertext)
	if err != nil {
		return nil, fmt.Errorf("failed to hash key and ciphertext: %w", err)
	}
	// NOTE(review): bytes.Equal is not constant-time; consider
	// crypto/subtle.ConstantTimeCompare for this MAC check.
	if !bytes.Equal(hash, mac) {
		return nil, ErrMACMismatch
	}

	aesBlock, err := aes.NewCipher(dk[:32])
	if err != nil {
		return nil, fmt.Errorf("failed to create cipher block: %w", err)
	}
	outputkey := make([]byte, len(ciphertext))
	cipher.NewCTR(aesBlock, iv).XORKeyStream(outputkey, ciphertext)

	return &Key{
		ID:   encjson.ID,
		Data: outputkey,
	}, nil
}
// removeFileExtension strips the final extension (as reported by
// filepath.Ext) from filename, e.g. "abc.json" -> "abc".
func removeFileExtension(filename string) string {
	return strings.TrimSuffix(filename, filepath.Ext(filename))
}
package keystore
import (
"fmt"
"os"
"path/filepath"
"sync"
"github.com/spf13/afero"
"gitlab.com/nunet/device-management-service/utils"
)
// KeyStore manages a local keystore with lock and unlock functionalities.
// In BasicKeyStore, keys are stored encrypted with a passphrase, and the same
// passphrase is needed to read them back (Get) or remove them (Delete).
type KeyStore interface {
// Save encrypts data with passphrase, stores it under id and returns the
// location it was written to.
Save(id string, data []byte, passphrase string) (string, error)
// Get loads and decrypts the key stored under keyID.
Get(keyID string, passphrase string) (*Key, error)
// Delete removes the key stored under keyID after checking passphrase.
Delete(keyID string, passphrase string) error
// ListKeys returns the IDs of the stored keys.
ListKeys() ([]string, error)
}
// BasicKeyStore handles keypair storage.
// It is a KeyStore that keeps one encrypted JSON file per key, named
// <keysDir>/<id>.json, on the given afero filesystem.
// TODO: add cache?
type BasicKeyStore struct {
fs afero.Fs
keysDir string
// mu guards operations on the key files under keysDir.
mu sync.RWMutex
}
var _ KeyStore = (*BasicKeyStore)(nil)
// New creates a BasicKeyStore rooted at keysDir on fs. It creates the
// directory (mode 0700) if needed and returns ErrEmptyKeysDir when keysDir is
// empty.
func New(fs afero.Fs, keysDir string) (*BasicKeyStore, error) {
	if keysDir == "" {
		return nil, ErrEmptyKeysDir
	}

	const dirPerm = 0o700
	if err := fs.MkdirAll(keysDir, dirPerm); err != nil {
		return nil, fmt.Errorf("failed to create keystore directory: %w", err)
	}

	ks := &BasicKeyStore{fs: fs, keysDir: keysDir}
	return ks, nil
}
// Save encrypts data with passphrase under id and writes it to
// <keysDir>/<id>.json. It returns the filename reported by utils.WriteToFile.
// It returns ErrEmptyPassphrase when passphrase is empty.
func (ks *BasicKeyStore) Save(id string, data []byte, passphrase string) (string, error) {
	if passphrase == "" {
		return "", ErrEmptyPassphrase
	}

	key := &Key{
		ID:   id,
		Data: data,
	}
	// Encrypt (scrypt is expensive) before taking the lock.
	keyDataJSON, err := key.MarshalToJSON(passphrase)
	if err != nil {
		return "", fmt.Errorf("failed to marshal key: %w", err)
	}

	// Take the same write lock Delete holds, so the two never act on the key
	// files at the same time.
	ks.mu.Lock()
	defer ks.mu.Unlock()

	filename, err := utils.WriteToFile(ks.fs, keyDataJSON, filepath.Join(ks.keysDir, key.ID+".json"))
	if err != nil {
		return "", fmt.Errorf("failed to write key to file: %w", err)
	}
	return filename, nil
}
// Get reads <keysDir>/<keyID>.json and decrypts it with passphrase.
// It returns ErrKeyNotFound when the file does not exist.
func (ks *BasicKeyStore) Get(keyID string, passphrase string) (*Key, error) {
	// Hold the read lock only for the file read; decryption runs scrypt and
	// must not block writers.
	ks.mu.RLock()
	bts, err := afero.ReadFile(ks.fs, filepath.Join(ks.keysDir, keyID+".json"))
	ks.mu.RUnlock()
	if err != nil {
		if os.IsNotExist(err) {
			return nil, ErrKeyNotFound
		}
		return nil, fmt.Errorf("failed to read keystore file: %w", err)
	}

	key, err := UnmarshalKey(bts, passphrase)
	if err != nil {
		return nil, fmt.Errorf("failed to unmarshal keystore file: %w", err)
	}
	return key, nil
}
// Delete removes <keysDir>/<keyID>.json, but only after successfully
// decrypting it with passphrase. This proves the caller may destroy the key.
// It returns ErrKeyNotFound when the file does not exist. The whole operation
// runs under the keystore's write lock.
func (ks *BasicKeyStore) Delete(keyID string, passphrase string) error {
	ks.mu.Lock()
	defer ks.mu.Unlock()

	path := filepath.Join(ks.keysDir, keyID+".json")
	contents, err := afero.ReadFile(ks.fs, path)
	switch {
	case os.IsNotExist(err):
		return ErrKeyNotFound
	case err != nil:
		return fmt.Errorf("failed to read keystore file: %w", err)
	}

	if _, err := UnmarshalKey(contents, passphrase); err != nil {
		return fmt.Errorf("invalid passphrase or corrupted key file: %w", err)
	}
	if err := ks.fs.Remove(path); err != nil {
		return fmt.Errorf("failed to delete key file: %w", err)
	}
	return nil
}
// ListKeys returns the IDs of the keys stored in keysDir, i.e. the names of
// the readable "<id>.json" files with the extension removed. Directories,
// files with another extension, and unreadable files are skipped, because
// Get could never load them.
func (ks *BasicKeyStore) ListKeys() ([]string, error) {
	ks.mu.RLock()
	defer ks.mu.RUnlock()

	dirEntries, err := afero.ReadDir(ks.fs, ks.keysDir)
	if err != nil {
		return nil, fmt.Errorf("failed to read keystore directory: %w", err)
	}

	keys := make([]string, 0, len(dirEntries))
	for _, entry := range dirEntries {
		name := entry.Name()
		if entry.IsDir() || filepath.Ext(name) != ".json" {
			continue
		}
		if _, err := afero.ReadFile(ks.fs, filepath.Join(ks.keysDir, name)); err != nil {
			continue
		}
		keys = append(keys, removeFileExtension(name))
	}
	return keys, nil
}
package crypto
// ReadVault is a placeholder for reading the vault at path using passphrase.
// It is not implemented yet and always returns (nil, ErrTODO).
func ReadVault(path string, passphrase string) ([]byte, error) { //nolint:revive // its a todo
// TODO
return nil, ErrTODO
}
// WriteVault is a placeholder for writing data to the vault at path using
// passphrase. It is not implemented yet and always returns ErrTODO.
func WriteVault(path string, passphrase string, data []byte) error { //nolint:revive // its a todo
// TODO
return ErrTODO
}
package did
// GetAnchorFunc builds an Anchor for a DID of one particular DID method.
type GetAnchorFunc func(did DID) (Anchor, error)

// anchorMethods maps a DID method name (as returned by DID.Method) to the
// function that builds anchors for it. It is populated statically rather than
// in an init function; the "key" constructor does not depend on this map, so
// there is no initialization cycle.
var anchorMethods = map[string]GetAnchorFunc{
	"key": makeKeyAnchor,
}
// GetAnchorForDID builds an anchor for did using the constructor registered
// in anchorMethods for the DID's method. It returns ErrNoAnchorMethod for
// methods without a constructor.
func GetAnchorForDID(did DID) (Anchor, error) {
	if makeAnchor, found := anchorMethods[did.Method()]; found {
		return makeAnchor(did)
	}
	return nil, ErrNoAnchorMethod
}

// makeKeyAnchor resolves a did:key DID to an anchor over the public key
// embedded in its identifier.
func makeKeyAnchor(did DID) (Anchor, error) {
	pubk, err := PublicKeyFromDID(did)
	if err != nil {
		return nil, err
	}
	anchor := NewAnchor(did, pubk)
	return anchor, nil
}
package did
import (
"context"
"fmt"
"sync"
"time"
"gitlab.com/nunet/device-management-service/lib/crypto"
)
// anchorEntryTTL is how long a cached anchor stays in a BasicTrustContext
// after it was last added or looked up through the cache.
const anchorEntryTTL = time.Hour
// Anchor is a DID anchor that encapsulates a public key that can be used
// for verification of signatures.
type Anchor interface {
DID() DID
Verify(data []byte, sig []byte) error
PublicKey() crypto.PubKey
}
// Provider holds the private key material necessary to sign statements for
// a DID. PrivateKey may fail when the key cannot be exported (for example a
// hardware wallet key).
type Provider interface {
DID() DID
Sign(data []byte) ([]byte, error)
Anchor() Anchor
PrivateKey() (crypto.PrivKey, error)
}
// TrustContext keeps the anchors (verification keys) and providers (signing
// keys) known to a node, keyed by DID. Start/Stop control a background
// garbage collector for cached anchors.
type TrustContext interface {
Anchors() []DID
Providers() []DID
GetAnchor(did DID) (Anchor, error)
GetProvider(did DID) (Provider, error)
AddAnchor(anchor Anchor)
AddProvider(provider Provider)
Start(gcInterval time.Duration)
Stop()
}
// anchorEntry is a cached anchor together with its eviction deadline.
type anchorEntry struct {
anchor Anchor
expire time.Time
}
// BasicTrustContext is the default TrustContext implementation. mx guards all
// fields; stop cancels the running garbage collector (nil when none runs).
type BasicTrustContext struct {
mx sync.Mutex
anchors map[DID]*anchorEntry
providers map[DID]Provider
stop func()
}
var _ TrustContext = (*BasicTrustContext)(nil)
// NewTrustContext returns an empty BasicTrustContext. Its garbage collector
// is not running until Start is called.
func NewTrustContext() TrustContext {
	tc := new(BasicTrustContext)
	tc.anchors = make(map[DID]*anchorEntry)
	tc.providers = make(map[DID]Provider)
	return tc
}
// NewTrustContextWithPrivateKey returns a new trust context that already
// holds a provider built from privk (see ProviderFromPrivateKey).
func NewTrustContextWithPrivateKey(privk crypto.PrivKey) (TrustContext, error) {
	provider, err := ProviderFromPrivateKey(privk)
	if err != nil {
		return nil, fmt.Errorf("provider from private key: %w", err)
	}
	tc := NewTrustContext()
	tc.AddProvider(provider)
	return tc, nil
}
// NewTrustContextWithProvider returns a new trust context that already holds
// the provider p.
func NewTrustContextWithProvider(p Provider) TrustContext {
	tc := NewTrustContext()
	tc.AddProvider(p)
	return tc
}
// Anchors returns the DIDs of all currently cached anchors, in no particular
// order.
func (ctx *BasicTrustContext) Anchors() []DID {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	dids := make([]DID, len(ctx.anchors))
	i := 0
	for d := range ctx.anchors {
		dids[i] = d
		i++
	}
	return dids
}

// Providers returns the DIDs of all registered providers, in no particular
// order.
func (ctx *BasicTrustContext) Providers() []DID {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	dids := make([]DID, len(ctx.providers))
	i := 0
	for d := range ctx.providers {
		dids[i] = d
		i++
	}
	return dids
}
// GetAnchor returns the anchor for did. Cached anchors are returned directly
// and have their TTL refreshed. Otherwise the anchor is resolved with
// GetAnchorForDID and cached.
func (ctx *BasicTrustContext) GetAnchor(did DID) (Anchor, error) {
	if cached, found := ctx.getAnchor(did); found {
		return cached, nil
	}

	resolved, err := GetAnchorForDID(did)
	if err != nil {
		return nil, fmt.Errorf("get anchor for did: %w", err)
	}
	ctx.AddAnchor(resolved)
	return resolved, nil
}

// getAnchor looks up did in the cache. On a hit it pushes the entry's
// expiry anchorEntryTTL into the future.
func (ctx *BasicTrustContext) getAnchor(did DID) (Anchor, bool) {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	entry, ok := ctx.anchors[did]
	if !ok {
		return nil, false
	}
	entry.expire = time.Now().Add(anchorEntryTTL)
	return entry.anchor, true
}
// GetProvider returns the provider registered for did, or ErrNoProvider.
func (ctx *BasicTrustContext) GetProvider(did DID) (Provider, error) {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	if provider, ok := ctx.providers[did]; ok {
		return provider, nil
	}
	return nil, ErrNoProvider
}

// AddAnchor caches anchor under its DID, replacing any previous entry, with
// an expiry anchorEntryTTL from now.
func (ctx *BasicTrustContext) AddAnchor(anchor Anchor) {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	entry := &anchorEntry{anchor: anchor}
	entry.expire = time.Now().Add(anchorEntryTTL)
	ctx.anchors[anchor.DID()] = entry
}

// AddProvider registers provider under its DID, replacing any previous one.
func (ctx *BasicTrustContext) AddProvider(provider Provider) {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	key := provider.DID()
	ctx.providers[key] = provider
}
// Start (re)starts the background garbage collector, which evicts expired
// anchor entries every gcInterval. A collector started by an earlier Start is
// stopped first.
//
// time.NewTicker panics on a non-positive interval. That panic would happen
// inside the gc goroutine and crash the process, so a gcInterval <= 0 only
// stops any running collector and does not start a new one.
func (ctx *BasicTrustContext) Start(gcInterval time.Duration) {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	if ctx.stop != nil {
		ctx.stop()
		ctx.stop = nil
	}
	if gcInterval <= 0 {
		return
	}

	gcCtx, stop := context.WithCancel(context.Background())
	ctx.stop = stop
	go ctx.gc(gcCtx, gcInterval)
}

// Stop stops the garbage collector started by Start, if any. It is safe to
// call repeatedly.
func (ctx *BasicTrustContext) Stop() {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	if ctx.stop != nil {
		ctx.stop()
		ctx.stop = nil
	}
}
// gc runs gcAnchorEntries on every tick of a gcInterval ticker until gcCtx is
// cancelled.
func (ctx *BasicTrustContext) gc(gcCtx context.Context, gcInterval time.Duration) {
	ticker := time.NewTicker(gcInterval)
	defer ticker.Stop()

	for {
		select {
		case <-gcCtx.Done():
			return
		case <-ticker.C:
			ctx.gcAnchorEntries()
		}
	}
}

// gcAnchorEntries drops every cached anchor whose expiry is in the past.
func (ctx *BasicTrustContext) gcAnchorEntries() {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	now := time.Now()
	for key, entry := range ctx.anchors {
		if now.After(entry.expire) {
			delete(ctx.anchors, key)
		}
	}
}
package did
import (
"strings"
)
// DID is a decentralized identifier, held as its full URI
// (e.g. "did:key:z6Mk...").
type DID struct {
	URI string `json:"uri,omitempty"`
}

// Equal reports whether did and other have the same URI.
func (did DID) Equal(other DID) bool { return did.URI == other.URI }

// Empty reports whether the DID has no URI.
func (did DID) Empty() bool { return len(did.URI) == 0 }

// String returns the DID's URI.
func (did DID) String() string { return did.URI }

// Method returns the middle component of a "did:<method>:<id>" URI. It
// returns "" unless the URI has exactly three colon-separated parts.
func (did DID) Method() string {
	if strings.Count(did.URI, ":") != 2 {
		return ""
	}
	_, rest, _ := strings.Cut(did.URI, ":")
	method, _, _ := strings.Cut(rest, ":")
	return method
}

// Identifier returns the last component of a "did:<method>:<id>" URI. It
// returns "" unless the URI has exactly three colon-separated parts.
func (did DID) Identifier() string {
	if strings.Count(did.URI, ":") != 2 {
		return ""
	}
	return did.URI[strings.LastIndex(did.URI, ":")+1:]
}
// FromString parses s as a DID. The empty string yields the empty DID.
// Otherwise s must consist of exactly three non-empty colon-separated parts,
// else ErrInvalidDID is returned.
func FromString(s string) (DID, error) {
	if s == "" {
		return DID{URI: s}, nil
	}

	parts := strings.Split(s, ":")
	if len(parts) != 3 || parts[0] == "" || parts[1] == "" || parts[2] == "" {
		return DID{}, ErrInvalidDID
	}
	// TODO validate parts according to spec: https://www.w3.org/TR/did-core/
	return DID{URI: s}, nil
}
package did
import (
"fmt"
"strings"
libp2p_crypto "github.com/libp2p/go-libp2p/core/crypto"
mb "github.com/multiformats/go-multibase"
varint "github.com/multiformats/go-varint"
"gitlab.com/nunet/device-management-service/lib/crypto"
)
// PublicKeyAnchor is an Anchor backed by a single public key.
type PublicKeyAnchor struct {
	did  DID
	pubk crypto.PubKey
}

var _ Anchor = (*PublicKeyAnchor)(nil)

// PrivateKeyProvider is a Provider backed by a private key held in memory.
type PrivateKeyProvider struct {
	did   DID
	privk crypto.PrivKey
}

var _ Provider = (*PrivateKeyProvider)(nil)

// NewAnchor returns an Anchor for did that verifies signatures with pubk.
// It does not check that pubk corresponds to did.
func NewAnchor(did DID, pubk crypto.PubKey) Anchor {
	a := new(PublicKeyAnchor)
	a.did, a.pubk = did, pubk
	return a
}

// NewProvider returns a Provider for did that signs with privk.
// It does not check that privk corresponds to did.
func NewProvider(did DID, privk crypto.PrivKey) Provider {
	p := new(PrivateKeyProvider)
	p.did, p.privk = did, privk
	return p
}
// DID returns the anchor's DID.
func (a *PublicKeyAnchor) DID() DID { return a.did }

// Verify checks sig over data with the anchor's public key. It returns the
// key's error if verification could not run, and ErrInvalidSignature if the
// signature is wrong.
func (a *PublicKeyAnchor) Verify(data []byte, sig []byte) error {
	valid, err := a.pubk.Verify(data, sig)
	switch {
	case err != nil:
		return err
	case !valid:
		return ErrInvalidSignature
	default:
		return nil
	}
}

// PublicKey returns the anchor's verification key.
func (a *PublicKeyAnchor) PublicKey() crypto.PubKey { return a.pubk }

// DID returns the provider's DID.
func (p *PrivateKeyProvider) DID() DID { return p.did }

// Sign signs data with the provider's private key.
func (p *PrivateKeyProvider) Sign(data []byte) ([]byte, error) { return p.privk.Sign(data) }

// PrivateKey returns the provider's private key. It never fails.
func (p *PrivateKeyProvider) PrivateKey() (crypto.PrivKey, error) { return p.privk, nil }

// Anchor returns an anchor for the provider's DID built from the public half
// of its private key.
func (p *PrivateKeyProvider) Anchor() Anchor {
	pub := p.privk.GetPublic()
	return NewAnchor(p.did, pub)
}
// FromID converts an actor ID (a serialized public key) to its did:key DID.
func FromID(id crypto.ID) (DID, error) {
	pubk, err := crypto.PublicKeyFromID(id)
	if err != nil {
		return DID{}, fmt.Errorf("public key from id: %w", err)
	}
	return FromPublicKey(pubk), nil
}

// FromPublicKey returns the did:key DID for pubk. The URI is empty when
// FormatKeyURI cannot encode the key.
func FromPublicKey(pubk crypto.PubKey) DID {
	return DID{URI: FormatKeyURI(pubk)}
}

// PublicKeyFromDID extracts the public key embedded in a did:key DID. It
// returns ErrInvalidDID for other DID methods.
func PublicKeyFromDID(did DID) (crypto.PubKey, error) {
	if did.Method() != "key" {
		return nil, ErrInvalidDID
	}
	pubk, err := ParseKeyURI(did.URI)
	if err != nil {
		return nil, fmt.Errorf("parsing did key identifier: %w", err)
	}
	return pubk, nil
}

// AnchorFromPublicKey returns an anchor for pubk under its did:key DID.
// The error is currently always nil.
func AnchorFromPublicKey(pubk crypto.PubKey) (Anchor, error) {
	return NewAnchor(FromPublicKey(pubk), pubk), nil
}

// ProviderFromPrivateKey returns a provider for privk under the did:key DID
// of its public half. The error is currently always nil.
func ProviderFromPrivateKey(privk crypto.PrivKey) (Provider, error) {
	return NewProvider(FromPublicKey(privk.GetPublic()), privk), nil
}
// Note: this code originated in https://github.com/ucan-wg/go-ucan/blob/main/didkey/key.go
// Copyright applies; some superficial modifications by vyzo.
// Multicodec codes that prefix the raw public key inside a did:key identifier,
// and the URI prefix of did:key DIDs. 0xed and 0xe7 are the multicodec codes for
// ed25519-pub and secp256k1-pub. NOTE(review): 0xef01 for Eth keys looks
// project-specific — confirm against the multicodec table.
const (
multicodecKindEd25519PubKey uint64 = 0xed
multicodecKindSecp256k1PubKey uint64 = 0xe7
multicodecKindEthPubKey uint64 = 0xef01
keyPrefix = "did:key"
)
// FormatKeyURI encodes pubk as a did:key URI. The multicodec code for the key
// type (as an unsigned varint) is followed by the raw public key bytes. The
// result is multibase-encoded in base58btc and prefixed with "did:key:".
//
// Ed25519, Secp256k1 and Eth keys are supported. On any failure (unsupported
// key type, raw-key extraction or encoding error) it returns "".
func FormatKeyURI(pubk crypto.PubKey) string {
	raw, err := pubk.Raw()
	if err != nil {
		log.Errorf("error getting raw public key: %s", err)
		return ""
	}

	var t uint64
	switch pubk.Type() {
	case crypto.Ed25519:
		t = multicodecKindEd25519PubKey
	case crypto.Secp256k1:
		t = multicodecKindSecp256k1PubKey
	case crypto.Eth:
		t = multicodecKindEthPubKey
	default:
		// we don't support those yet; report the offending key type (t is
		// still unset at this point, so logging it would always print 0)
		log.Errorf("unsupported key type: %d", pubk.Type())
		return ""
	}

	data := make([]byte, varint.UvarintSize(t)+len(raw))
	n := varint.PutUvarint(data, t)
	copy(data[n:], raw)

	b58BKeyStr, err := mb.Encode(mb.Base58BTC, data)
	if err != nil {
		return ""
	}
	return fmt.Sprintf("%s:%s", keyPrefix, b58BKeyStr)
}
// ParseKeyURI decodes a "did:key:<multibase>" URI produced by FormatKeyURI
// back into a public key. The multibase payload must be base58btc and start
// with one of the supported multicodec key-type codes. Unknown codes yield
// ErrInvalidKeyType.
func ParseKeyURI(uri string) (crypto.PubKey, error) {
	// Require the full "did:key:" prefix. Checking only "did:key" would accept
	// e.g. "did:keyfoo:..." and then try to multibase-decode the whole URI.
	encoded, ok := strings.CutPrefix(uri, keyPrefix+":")
	if !ok {
		return nil, fmt.Errorf("decentralized identifier is not a 'key' type")
	}

	enc, data, err := mb.Decode(encoded)
	if err != nil {
		return nil, fmt.Errorf("decoding multibase: %w", err)
	}
	if enc != mb.Base58BTC {
		return nil, fmt.Errorf("unexpected multibase encoding: %s", mb.EncodingToStr[enc])
	}

	keyType, n, err := varint.FromUvarint(data)
	if err != nil {
		return nil, err
	}
	raw := data[n:]

	switch keyType {
	case multicodecKindEd25519PubKey:
		return libp2p_crypto.UnmarshalEd25519PublicKey(raw)
	case multicodecKindSecp256k1PubKey:
		return libp2p_crypto.UnmarshalSecp256k1PublicKey(raw)
	case multicodecKindEthPubKey:
		return crypto.UnmarshalEthPublicKey(raw)
	default:
		return nil, ErrInvalidKeyType
	}
}
package did
import (
"encoding/hex"
"encoding/json"
"fmt"
"os"
"os/exec"
"github.com/decred/dcrd/dcrec/secp256k1/v4"
"github.com/decred/dcrd/dcrec/secp256k1/v4/ecdsa"
"gitlab.com/nunet/device-management-service/lib/crypto"
)
// ledgerCLI is the executable, looked up in PATH, that is invoked to talk to the
// Ledger device.
const ledgerCLI = "ledger-cli"
// LedgerWalletProvider is a Provider whose signing key lives on a Ledger
// hardware wallet; signatures are produced by invoking ledgerCLI.
type LedgerWalletProvider struct {
did DID
pubk crypto.PubKey
// acct is the account index passed to ledgerCLI with -a.
acct int
}
var _ Provider = (*LedgerWalletProvider)(nil)
// LedgerKeyOutput is the JSON that `ledgerCLI key` writes to its -o file.
type LedgerKeyOutput struct {
// Key is the hex-encoded public key.
Key string `json:"key"`
Address string `json:"address"`
}
// LedgerSignOutput is the JSON that `ledgerCLI sign` writes to its -o file.
type LedgerSignOutput struct {
ECDSA LedgerSignECDSAOutput `json:"ecdsa"`
}
// LedgerSignECDSAOutput holds the ECDSA signature components. R and S are
// hex-encoded; V is not used by LedgerWalletProvider.Sign.
type LedgerSignECDSAOutput struct {
V uint `json:"v"`
R string `json:"r"`
S string `json:"s"`
}
// NewLedgerWalletProvider asks ledgerCLI for the public key of Ledger account
// acct and returns a provider for the corresponding did:key DID. The key is
// expected hex-encoded and is parsed with crypto.UnmarshalEthPublicKey.
func NewLedgerWalletProvider(acct int) (Provider, error) {
	outPath, err := getLedgerTmpFile()
	if err != nil {
		return nil, err
	}
	defer os.Remove(outPath)

	var keyOut LedgerKeyOutput
	args := []string{"key", "-o", outPath, "-a", fmt.Sprintf("%d", acct)}
	if err := ledgerExec(outPath, &keyOut, args...); err != nil {
		return nil, fmt.Errorf("error executing ledger cli: %w", err)
	}

	raw, err := hex.DecodeString(keyOut.Key)
	if err != nil {
		return nil, fmt.Errorf("decode ledger key: %w", err)
	}
	pubk, err := crypto.UnmarshalEthPublicKey(raw)
	if err != nil {
		return nil, fmt.Errorf("unmarshal ledger raw key: %w", err)
	}

	return &LedgerWalletProvider{
		did:  FromPublicKey(pubk),
		pubk: pubk,
		acct: acct,
	}, nil
}
// ledgerExec runs ledgerCLI with args, forwarding its stdout/stderr to ours.
// It then decodes the JSON document the CLI wrote to tmp into output, which
// must be a non-nil pointer (e.g. *LedgerKeyOutput).
func ledgerExec(tmp string, output interface{}, args ...string) error {
	ledger, err := exec.LookPath(ledgerCLI)
	if err != nil {
		return fmt.Errorf("can't find %s in PATH: %w", ledgerCLI, err)
	}

	cmd := exec.Command(ledger, args...)
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr
	if err := cmd.Run(); err != nil {
		// This helper serves both "key" and "sign", so the message must not
		// name one specific operation.
		return fmt.Errorf("run %s: %w", ledgerCLI, err)
	}

	f, err := os.Open(tmp)
	if err != nil {
		return fmt.Errorf("open ledger output: %w", err)
	}
	defer f.Close()

	// Decode into output itself; it already holds the caller's pointer.
	// Decoding into &output would, for a nil output, silently fill a fresh
	// map the caller never sees instead of reporting an error.
	if err := json.NewDecoder(f).Decode(output); err != nil {
		return fmt.Errorf("parse ledger output: %w", err)
	}
	return nil
}
// getLedgerTmpFile creates an empty temporary file (name pattern "ledger.out*"
// in the default temp directory) for ledgerCLI to write its output to. It
// returns the file's path. The caller is responsible for removing it.
func getLedgerTmpFile() (string, error) {
	tmp, err := os.CreateTemp("", "ledger.out")
	if err != nil {
		return "", fmt.Errorf("creating temporary file: %w", err)
	}
	tmpPath := tmp.Name()
	if err := tmp.Close(); err != nil {
		// Best-effort cleanup: the path is not handed to the caller, so nobody
		// else would remove it.
		_ = os.Remove(tmpPath)
		return "", fmt.Errorf("closing temporary file: %w", err)
	}
	return tmpPath, nil
}
// DID returns the did:key DID of the Ledger account's public key.
func (p *LedgerWalletProvider) DID() DID {
	return p.did
}

// Sign has ledgerCLI sign the hex-encoded data with account p.acct. It turns
// the returned hex R/S components into a DER-serialized secp256k1 ECDSA
// signature. It fails when either component does not fit in a scalar
// (SetByteSlice reports overflow).
func (p *LedgerWalletProvider) Sign(data []byte) ([]byte, error) {
	outPath, err := getLedgerTmpFile()
	if err != nil {
		return nil, err
	}
	defer os.Remove(outPath)

	var signed LedgerSignOutput
	args := []string{
		"sign",
		"-o", outPath,
		"-a", fmt.Sprintf("%d", p.acct),
		hex.EncodeToString(data),
	}
	if err := ledgerExec(outPath, &signed, args...); err != nil {
		return nil, fmt.Errorf("error executing ledger cli: %w", err)
	}

	rBytes, err := hex.DecodeString(signed.ECDSA.R)
	if err != nil {
		return nil, fmt.Errorf("error decoding signature r: %w", err)
	}
	sBytes, err := hex.DecodeString(signed.ECDSA.S)
	if err != nil {
		return nil, fmt.Errorf("error decoding signature s: %w", err)
	}

	var r, s secp256k1.ModNScalar
	if r.SetByteSlice(rBytes) {
		return nil, fmt.Errorf("signature r overflowed")
	}
	if s.SetByteSlice(sBytes) {
		return nil, fmt.Errorf("signature s overflowed")
	}
	return ecdsa.NewSignature(&r, &s).Serialize(), nil
}
// Anchor returns an anchor for the Ledger account's DID and public key.
func (p *LedgerWalletProvider) Anchor() Anchor {
return NewAnchor(p.did, p.pubk)
}
// PrivateKey always fails because the key cannot be exported from the hardware
// wallet; the returned error wraps ErrHardwareKey.
func (p *LedgerWalletProvider) PrivateKey() (crypto.PrivKey, error) {
return nil, fmt.Errorf("ledger private key cannot be exported: %w", ErrHardwareKey)
}
package ucan
import (
"strings"
)
// Capability is a slash-separated capability path such as "/dms/deploy".
type Capability string

// Root is the top of the capability hierarchy; it implies every capability.
const Root = Capability("/")

// Implies reports whether holding c grants other. That is the case when c is
// Root, when the two are equal, or when other lies below c in the path
// hierarchy ("/a" implies "/a/b" but not "/ab").
func (c Capability) Implies(other Capability) bool {
	switch {
	case c == Root, c == other:
		return true
	default:
		return strings.HasPrefix(string(other), string(c)+"/")
	}
}
package ucan
import (
"bytes"
"context"
"crypto/rand"
"encoding/json"
"fmt"
"slices"
"sync"
"time"
"gitlab.com/nunet/device-management-service/lib/crypto"
"gitlab.com/nunet/device-management-service/lib/did"
)
// maxCapabilitySize is the largest capability token payload, in bytes, that
// Consume accepts.
const maxCapabilitySize = 16384

// SelfSignMode selects whether the Delegate* helpers chain tokens from our
// provide anchors, self-sign a grant, or both.
type SelfSignMode int

// SelfSignMode values. They live in their own const block so that iota starts
// at 0 and the zero value of SelfSignMode is SelfSignNo. Sharing a block with
// maxCapabilitySize would shift them to 1..3 and leave the zero value matching
// no mode.
const (
	// SelfSignNo issues only chained tokens (ErrNotAuthorized if none).
	SelfSignNo SelfSignMode = iota
	// SelfSignAlso issues chained tokens plus a self-signed grant.
	SelfSignAlso
	// SelfSignOnly skips chaining from provide anchors and self-signs.
	SelfSignOnly
)
// CapabilityContext manages capability tokens for a controlling DID. It
// ingests tokens (Consume), checks that required capabilities are held
// (Require, RequireBroadcast), and issues tokens to others (Provide*,
// Delegate*, Grant).
type CapabilityContext interface {
// DID returns the context's controlling DID
DID() did.DID
// Trust returns the context's did trust context
Trust() did.TrustContext
// Consume ingests the provided capability tokens
Consume(origin did.DID, cap []byte) error
// Discard discards previously consumed capability tokens
Discard(cap []byte)
// Require ensures that at least one of the capabilities is delegated from
// the subject to the audience, with an appropriate anchor.
// An empty list means that no capabilities are required and is vacuously
// true.
Require(anchor did.DID, subject crypto.ID, audience crypto.ID, require []Capability) error
// RequireBroadcast ensures that at least one of the capabilities is delegated
// to the subject for the specified broadcast topic
RequireBroadcast(origin did.DID, subject crypto.ID, topic string, require []Capability) error
// Provide prepares the appropriate capability tokens to prove and delegate authority
// to a subject for an audience.
// - It delegates invocations to the subject with an audience and invoke capabilities
// - It delegates the delegate capabilities to the target with audience the subject
Provide(target did.DID, subject crypto.ID, audience crypto.ID, expire uint64, invoke []Capability, delegate []Capability) ([]byte, error)
// ProvideBroadcast prepares the appropriate capability tokens to prove authority
// to broadcast to a topic
ProvideBroadcast(subject crypto.ID, topic string, expire uint64, broadcast []Capability) ([]byte, error)
// AddRoots adds trust anchors, together with require (acceptance) and provide tokens
AddRoots(trust []did.DID, require, provide TokenList) error
// ListRoots lists the current trust anchors and their require and provide tokens
ListRoots() ([]did.DID, TokenList, TokenList)
// RemoveRoots removes the specified trust anchors and tokens
RemoveRoots(trust []did.DID, require, provide TokenList)
// Delegate creates the appropriate delegation tokens anchored in our roots
Delegate(subject, audience did.DID, topics []string, expire, depth uint64, cap []Capability, selfSign SelfSignMode) (TokenList, error)
// DelegateInvocation creates the appropriate invocation tokens anchored in anchor
DelegateInvocation(target, subject, audience did.DID, expire uint64, provide []Capability, selfSign SelfSignMode) (TokenList, error)
// DelegateBroadcast creates the appropriate broadcast token anchored in our roots
DelegateBroadcast(subject did.DID, topic string, expire uint64, provide []Capability, selfSign SelfSignMode) (TokenList, error)
// Grant creates the appropriate delegation tokens considering ourselves as the root
Grant(action Action, subject, audience did.DID, topic []string, expire, depth uint64, provide []Capability) (TokenList, error)
// Start starts a token garbage collector goroutine that clears expired tokens
Start(gcInterval time.Duration)
// Stop stops a previously started gc goroutine
Stop()
}
// BasicCapabilityContext is the default CapabilityContext implementation.
// provider signs the tokens we issue; trust resolves DIDs for verification.
// mx is held (e.g. by RemoveRoots) while the maps below are mutated.
type BasicCapabilityContext struct {
mx sync.Mutex
provider did.Provider
trust did.TrustContext
roots map[did.DID]struct{} // our root anchors of trust
require map[did.DID][]*Token // our acceptance side-roots
provide map[did.DID][]*Token // root capabilities -> tokens
tokens map[did.DID][]*Token // our context dependent capabilities; subject -> tokens
// stop cancels the running gc goroutine, if any.
stop func()
}
var _ CapabilityContext = (*BasicCapabilityContext)(nil)
// NewCapabilityContext builds a capability context controlled by ctxDID. The
// trust context must hold a provider for ctxDID. The given roots and
// require/provide tokens are installed via AddRoots, which verifies the tokens.
func NewCapabilityContext(trust did.TrustContext, ctxDID did.DID, roots []did.DID, require, provide TokenList) (CapabilityContext, error) {
	provider, err := trust.GetProvider(ctxDID)
	if err != nil {
		return nil, fmt.Errorf("new capability context: %w", err)
	}

	ctx := &BasicCapabilityContext{
		provider: provider,
		trust:    trust,
		roots:    make(map[did.DID]struct{}),
		require:  make(map[did.DID][]*Token),
		provide:  make(map[did.DID][]*Token),
		tokens:   make(map[did.DID][]*Token),
	}
	if err := ctx.AddRoots(roots, require, provide); err != nil {
		return nil, fmt.Errorf("new capability context: %w", err)
	}
	return ctx, nil
}
// DID returns the controlling DID, i.e. the DID of the context's provider.
func (ctx *BasicCapabilityContext) DID() did.DID {
return ctx.provider.DID()
}
// Trust returns the trust context used to resolve and verify DIDs.
func (ctx *BasicCapabilityContext) Trust() did.TrustContext {
return ctx.trust
}
// Start launches the token garbage collector goroutine (ctx.gc) with
// gcInterval, to run until Stop is called. If a collector is already running
// it is stopped first. (The collector must start when no stop func is set yet;
// testing for a non-nil stop would mean it never starts on a fresh context.)
func (ctx *BasicCapabilityContext) Start(gcInterval time.Duration) {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	if ctx.stop != nil {
		ctx.stop()
	}
	gcCtx, cancel := context.WithCancel(context.Background())
	ctx.stop = cancel
	go ctx.gc(gcCtx, gcInterval)
}

// Stop stops the collector started by Start, if any. It is safe to call
// repeatedly, and Start may be called again afterwards.
func (ctx *BasicCapabilityContext) Stop() {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	if ctx.stop != nil {
		ctx.stop()
		ctx.stop = nil
	}
}
// AddRoots adds roots as trust anchors and ingests the require (acceptance)
// and provide tokens.
//
// Every token is verified against the trust context before anything is
// modified. If any token fails verification the context is left untouched,
// instead of ending up with the roots added and only some tokens consumed.
func (ctx *BasicCapabilityContext) AddRoots(roots []did.DID, require, provide TokenList) error {
	now := uint64(time.Now().UnixNano())
	for _, t := range require.Tokens {
		if err := t.Verify(ctx.trust, now); err != nil {
			return fmt.Errorf("verify token: %w", err)
		}
	}
	for _, t := range provide.Tokens {
		if err := t.Verify(ctx.trust, now); err != nil {
			return fmt.Errorf("verify token: %w", err)
		}
	}

	ctx.addRoots(roots)
	for _, t := range require.Tokens {
		ctx.consumeRequireToken(t)
	}
	for _, t := range provide.Tokens {
		ctx.consumeProvideToken(t)
	}
	return nil
}
// ListRoots returns the root anchors, then all require tokens, then all provide
// tokens (each token list flattened across its anchors).
func (ctx *BasicCapabilityContext) ListRoots() ([]did.DID, TokenList, TokenList) {
	roots := ctx.getRoots()

	var required []*Token
	for _, a := range ctx.getRequireAnchors() {
		required = append(required, ctx.getRequireTokens(a)...)
	}

	var provided []*Token
	for _, a := range ctx.getProvideAnchors() {
		provided = append(provided, ctx.getProvideTokens(a)...)
	}

	return roots, TokenList{Tokens: required}, TokenList{Tokens: provided}
}
// RemoveRoots removes the given trust anchors. It also removes the given
// require/provide tokens, matched by issuer and nonce. An issuer's entry is
// deleted once its last token is gone.
func (ctx *BasicCapabilityContext) RemoveRoots(trust []did.DID, require, provide TokenList) {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	for _, root := range trust {
		delete(ctx.roots, root)
	}

	drop := func(index map[did.DID][]*Token, toRemove []*Token) {
		for _, t := range toRemove {
			issuer := t.Issuer()
			current, ok := index[issuer]
			if !ok {
				continue
			}
			nonce := t.Nonce()
			current = slices.DeleteFunc(current, func(other *Token) bool {
				return bytes.Equal(nonce, other.Nonce())
			})
			if len(current) == 0 {
				delete(index, issuer)
			} else {
				index[issuer] = current
			}
		}
	}
	drop(ctx.require, require.Tokens)
	drop(ctx.provide, provide.Tokens)
}
// Grant issues one self-signed DMS token with ctx.DID() as issuer. The given
// action, subject, audience, topics (as capabilities), capabilities, expire and
// depth are copied into it, together with a fresh random nonce of nonceLength
// bytes. The token is signed with the context's provider.
func (ctx *BasicCapabilityContext) Grant(action Action, subject, audience did.DID, topics []string, expire, depth uint64, provide []Capability) (TokenList, error) {
	nonce := make([]byte, nonceLength)
	if _, err := rand.Read(nonce); err != nil {
		return TokenList{}, fmt.Errorf("nonce: %w", err)
	}

	topicCap := make([]Capability, len(topics))
	for i, topic := range topics {
		topicCap[i] = Capability(topic)
	}

	token := &DMSToken{
		Issuer:     ctx.DID(),
		Subject:    subject,
		Audience:   audience,
		Action:     action,
		Topic:      topicCap,
		Capability: provide,
		Nonce:      nonce,
		Expire:     expire,
		Depth:      depth,
	}

	payload, err := token.SignatureData()
	if err != nil {
		return TokenList{}, fmt.Errorf("grant: %w", err)
	}
	sig, err := ctx.provider.Sign(payload)
	if err != nil {
		return TokenList{}, fmt.Errorf("sign: %w", err)
	}
	token.Signature = sig

	return TokenList{Tokens: []*Token{{DMS: token}}}, nil
}
// Delegate issues delegation tokens that let subject exercise the provide
// capabilities toward audience, restricted to topics.
//
// Tokens are chained from those held under our provide anchors and/or
// self-signed, depending on selfSign:
//   - SelfSignNo: only chained tokens; ErrNotAuthorized if none could be issued.
//   - SelfSignOnly: only a self-signed grant.
//   - any other mode: chained tokens plus a self-signed grant.
//
// For chained tokens an expire of 0 inherits the parent token's expiry. An
// empty provide list yields an empty TokenList and no error.
func (ctx *BasicCapabilityContext) Delegate(subject, audience did.DID, topics []string, expire, depth uint64, provide []Capability, selfSign SelfSignMode) (TokenList, error) {
	if len(provide) == 0 {
		return TokenList{}, nil
	}

	var result []*Token
	if selfSign != SelfSignOnly {
		topicCap := make([]Capability, 0, len(topics))
		for _, topic := range topics {
			topicCap = append(topicCap, Capability(topic))
		}
		issuer := ctx.DID()

		for _, trustAnchor := range ctx.getProvideAnchors() {
			for _, t := range ctx.getProvideTokens(trustAnchor) {
				// The anchor check does not depend on the capability; do it once
				// per token instead of once per capability.
				if !t.Anchor(trustAnchor) {
					continue
				}
				definitiveExpire := expire
				if definitiveExpire == 0 {
					definitiveExpire = t.Expire()
				}

				var providing []Capability
				for _, c := range provide {
					if t.AllowDelegation(Delegate, issuer, audience, topicCap, definitiveExpire, c) {
						providing = append(providing, c)
					}
				}
				if len(providing) == 0 {
					continue
				}

				token, err := t.Delegate(ctx.provider, subject, audience, topicCap, definitiveExpire, depth, providing)
				if err != nil {
					log.Debugf("error delegating %s to %s: %s", providing, subject, err)
					continue
				}
				result = append(result, token)
			}
		}

		if selfSign == SelfSignNo {
			if len(result) == 0 {
				return TokenList{}, ErrNotAuthorized
			}
			return TokenList{Tokens: result}, nil
		}
	}

	granted, err := ctx.Grant(Delegate, subject, audience, topics, expire, depth, provide)
	if err != nil {
		return TokenList{}, fmt.Errorf("error granting delegation: %w", err)
	}
	return TokenList{Tokens: append(result, granted.Tokens...)}, nil
}
// DelegateInvocation issues invocation tokens for subject toward audience
// covering the provide capabilities.
//
// Tokens we hold about ourselves that are anchored in target are always
// considered, even for SelfSignOnly. Unless selfSign is SelfSignOnly, tokens
// chained from our provide anchors are added. SelfSignNo returns only those
// (ErrNotAuthorized if there are none). Any other mode also appends a
// self-signed grant. An empty provide list yields an empty TokenList.
func (ctx *BasicCapabilityContext) DelegateInvocation(target, subject, audience did.DID, expire uint64, provide []Capability, selfSign SelfSignMode) (TokenList, error) {
	if len(provide) == 0 {
		return TokenList{}, nil
	}

	var result []*Token
	own := ctx.getSubjectTokens(ctx.DID())
	result = append(result, ctx.delegateInvocation(own, target, subject, audience, expire, provide)...)

	if selfSign != SelfSignOnly {
		for _, anchor := range ctx.getProvideAnchors() {
			chained := ctx.delegateInvocation(ctx.getProvideTokens(anchor), anchor, subject, audience, expire, provide)
			result = append(result, chained...)
		}
		if selfSign == SelfSignNo {
			if len(result) == 0 {
				return TokenList{}, ErrNotAuthorized
			}
			return TokenList{Tokens: result}, nil
		}
	}

	granted, err := ctx.Grant(Invoke, subject, audience, nil, expire, 0, provide)
	if err != nil {
		return TokenList{}, fmt.Errorf("error granting invocation: %w", err)
	}
	return TokenList{Tokens: append(result, granted.Tokens...)}, nil
}
// delegateInvocation derives an invocation token for subject/audience from
// each token in tokenList that is anchored in anchor. Each derived token
// carries the subset of provide that its parent allows us to delegate. Parents
// that allow none of provide, or whose delegation fails (logged at debug
// level), are skipped.
func (ctx *BasicCapabilityContext) delegateInvocation(tokenList []*Token, anchor, subject, audience did.DID, expire uint64, provide []Capability) []*Token {
	var delegated []*Token //nolint
	for _, parent := range tokenList {
		if !parent.Anchor(anchor) {
			continue
		}

		var allowed []Capability
		for _, c := range provide {
			if parent.AllowDelegation(Invoke, ctx.DID(), audience, nil, expire, c) {
				allowed = append(allowed, c)
			}
		}
		if len(allowed) == 0 {
			continue
		}

		token, err := parent.DelegateInvocation(ctx.provider, subject, audience, expire, allowed)
		if err != nil {
			log.Debugf("error delegating invocation %s to %s: %s", allowed, subject, err)
			continue
		}
		delegated = append(delegated, token)
	}
	return delegated
}
// DelegateBroadcast issues broadcast tokens that let subject broadcast on
// topic with the provide capabilities.
//
// Unless selfSign is SelfSignOnly, tokens are chained from our provide
// anchors. SelfSignNo returns only those (ErrNotAuthorized if there are none).
// Any other mode also appends a self-signed grant. An empty provide list
// yields an empty TokenList.
func (ctx *BasicCapabilityContext) DelegateBroadcast(subject did.DID, topic string, expire uint64, provide []Capability, selfSign SelfSignMode) (TokenList, error) {
	if len(provide) == 0 {
		return TokenList{}, nil
	}

	var result []*Token
	if selfSign != SelfSignOnly {
		for _, anchor := range ctx.getProvideAnchors() {
			chained := ctx.delegateBroadcast(ctx.getProvideTokens(anchor), anchor, subject, topic, expire, provide)
			result = append(result, chained...)
		}
		if selfSign == SelfSignNo {
			if len(result) == 0 {
				return TokenList{}, ErrNotAuthorized
			}
			return TokenList{Tokens: result}, nil
		}
	}

	granted, err := ctx.Grant(Broadcast, subject, did.DID{}, []string{topic}, expire, 0, provide)
	if err != nil {
		return TokenList{}, fmt.Errorf("error granting broadcast: %w", err)
	}
	return TokenList{Tokens: append(result, granted.Tokens...)}, nil
}
// delegateBroadcast issues broadcast tokens for subject on topic by chaining
// on the tokens in tokenList that are anchored at anchor. For every such chain
// token, the capabilities from provide that it allows this context to delegate
// for broadcast on topic are collected; when at least one is allowed, a new
// broadcast token carrying exactly those capabilities is derived from it.
// Chain tokens that fail to delegate are logged at debug level and skipped.
func (ctx *BasicCapabilityContext) delegateBroadcast(tokenList []*Token, anchor did.DID, subject did.DID, topic string, expire uint64, provide []Capability) []*Token {
	topicCap := Capability(topic)
	topics := []Capability{topicCap}
	var result []*Token //nolint
	for _, t := range tokenList {
		if !t.Anchor(anchor) {
			continue
		}
		var providing []Capability
		for _, c := range provide {
			// broadcast tokens are never audience-bound, hence the zero DID
			if t.AllowDelegation(Broadcast, ctx.DID(), did.DID{}, topics, expire, c) {
				providing = append(providing, c)
			}
		}
		if len(providing) == 0 {
			continue
		}
		token, err := t.DelegateBroadcast(ctx.provider, subject, topicCap, expire, providing)
		if err != nil {
			log.Debugf("error delegating broadcast %s to %s: %s", providing, subject, err)
			continue
		}
		result = append(result, token)
	}
	return result
}
// Consume ingests a JSON-encoded TokenList received from origin and stores
// every token that is both trusted and valid.
//
// A token is trusted when it is anchored at this context's own DID, at origin,
// or at one of the root anchors, or when one of the require tokens allows its
// action. Trusted tokens are verified against ctx.trust at the current time.
// Untrusted tokens, null entries and tokens that fail verification are logged
// and skipped. Only an oversized payload (ErrTooBig) or a malformed encoding
// produces an error.
func (ctx *BasicCapabilityContext) Consume(origin did.DID, data []byte) error {
	if len(data) > maxCapabilitySize {
		return ErrTooBig
	}

	var tokens TokenList
	if err := json.Unmarshal(data, &tokens); err != nil {
		return fmt.Errorf("unmarshaling payload: %w", err)
	}

	rootAnchors := ctx.getRoots()
	requireAnchors := ctx.getRequireAnchors()
	now := uint64(time.Now().UnixNano())

	// trusted reports whether t chains to an anchor we are willing to accept.
	trusted := func(t *Token) bool {
		if t.Anchor(ctx.DID()) || t.Anchor(origin) {
			return true
		}
		for _, anchor := range rootAnchors {
			if t.Anchor(anchor) {
				return true
			}
		}
		for _, anchor := range requireAnchors {
			for _, rt := range ctx.getRequireTokens(anchor) {
				if rt.AllowAction(t) {
					return true
				}
			}
		}
		return false
	}

	for _, t := range tokens.Tokens {
		if t == nil {
			// A JSON null entry decodes to a nil *Token; any method call on it
			// would panic, and the payload may be peer-supplied.
			log.Debugf("ignoring null token")
			continue
		}
		if !trusted(t) {
			log.Debugf("ignoring token %+v", *t)
			continue
		}
		if err := t.Verify(ctx.trust, now); err != nil {
			log.Warnf("failed to verify token issued by %s: %s", t.Issuer(), err)
			continue
		}
		ctx.consumeSubjectToken(t)
	}
	return nil
}
// Discard decodes a JSON-encoded TokenList and removes the listed tokens from
// the subject token store. Undecodable input is silently ignored.
func (ctx *BasicCapabilityContext) Discard(data []byte) {
	var tokens TokenList
	if json.Unmarshal(data, &tokens) == nil {
		ctx.discardTokens(tokens.Tokens)
	}
}
// consumeAnchorToken adds t to the token list accessed through getf/setf while
// holding ctx.mx.
//
// If an existing token already subsumes t, nothing changes. Otherwise t is
// appended and every existing token that t subsumes is dropped.
func (ctx *BasicCapabilityContext) consumeAnchorToken(getf func() []*Token, setf func(result []*Token), t *Token) {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	current := getf()
	kept := make([]*Token, 0, len(current)+1)
	for _, existing := range current {
		switch {
		case existing.Subsumes(t):
			// already covered; leave the stored list untouched
			return
		case t.Subsumes(existing):
			// superseded by t; drop it
		default:
			kept = append(kept, existing)
		}
	}
	setf(append(kept, t))
}
// consumeRequireToken stores t among the require tokens of its issuer, using
// consumeAnchorToken's subsumption rules.
func (ctx *BasicCapabilityContext) consumeRequireToken(t *Token) {
	issuer := t.Issuer()
	load := func() []*Token {
		return ctx.require[issuer]
	}
	store := func(result []*Token) {
		ctx.require[issuer] = result
	}
	ctx.consumeAnchorToken(load, store, t)
}
// consumeProvideToken stores t among the provide tokens of its issuer, using
// consumeAnchorToken's subsumption rules.
func (ctx *BasicCapabilityContext) consumeProvideToken(t *Token) {
	issuer := t.Issuer()
	load := func() []*Token {
		return ctx.provide[issuer]
	}
	store := func(result []*Token) {
		ctx.provide[issuer] = result
	}
	ctx.consumeAnchorToken(load, store, t)
}
// consumeSubjectToken appends t to the token list of its subject, under ctx.mx.
// No deduplication is performed here.
func (ctx *BasicCapabilityContext) consumeSubjectToken(t *Token) {
	subject := t.Subject()
	ctx.mx.Lock()
	defer ctx.mx.Unlock()
	ctx.tokens[subject] = append(ctx.tokens[subject], t)
}
// Require checks that subject holds a token allowing it to invoke at least
// one of cap against audience.
//
// The allowing token must also be trusted: anchored at anchor or at a root
// anchor, or vouched for by one of the require tokens. It returns nil on
// success and an error wrapping or equal to ErrNotAuthorized otherwise; DID
// conversion failures are returned wrapped.
func (ctx *BasicCapabilityContext) Require(anchor did.DID, subject crypto.ID, audience crypto.ID, cap []Capability) error {
	if len(cap) == 0 {
		return fmt.Errorf("no capabilities: %w", ErrNotAuthorized)
	}
	subjectDID, err := did.FromID(subject)
	if err != nil {
		return fmt.Errorf("DID for subject: %w", err)
	}
	audienceDID, err := did.FromID(audience)
	if err != nil {
		return fmt.Errorf("DID for audience: %w", err)
	}

	tokenList := ctx.getSubjectTokens(subjectDID)
	roots := ctx.getRoots()
	requireAnchors := ctx.getRequireAnchors()

	// trusted does not depend on the capability, so it is evaluated only for
	// tokens that actually allow one.
	trusted := func(t *Token) bool {
		if t.Anchor(anchor) {
			return true
		}
		for _, root := range roots {
			if t.Anchor(root) {
				return true
			}
		}
		for _, requireAnchor := range requireAnchors {
			for _, rt := range ctx.getRequireTokens(requireAnchor) {
				if rt.AllowAction(t) {
					return true
				}
			}
		}
		return false
	}

	for _, t := range tokenList {
		for _, c := range cap {
			if t.AllowInvocation(subjectDID, audienceDID, c) && trusted(t) {
				return nil
			}
		}
	}
	return ErrNotAuthorized
}
// RequireBroadcast checks that subject holds a token allowing it to broadcast
// on topic with at least one of the require capabilities.
//
// The allowing token must also be trusted: anchored at anchor or at a root
// anchor, or vouched for by one of the require tokens. It returns nil on
// success and an error wrapping or equal to ErrNotAuthorized otherwise.
func (ctx *BasicCapabilityContext) RequireBroadcast(anchor did.DID, subject crypto.ID, topic string, require []Capability) error {
	if len(require) == 0 {
		return fmt.Errorf("no capabilities: %w", ErrNotAuthorized)
	}
	subjectDID, err := did.FromID(subject)
	if err != nil {
		return fmt.Errorf("DID for subject: %w", err)
	}

	tokenList := ctx.getSubjectTokens(subjectDID)
	roots := ctx.getRoots()
	requireAnchors := ctx.getRequireAnchors()
	topicCap := Capability(topic)

	trusted := func(t *Token) bool {
		if t.Anchor(anchor) {
			return true
		}
		for _, root := range roots {
			if t.Anchor(root) {
				return true
			}
		}
		for _, requireAnchor := range requireAnchors {
			for _, rt := range ctx.getRequireTokens(requireAnchor) {
				if rt.AllowAction(t) {
					return true
				}
			}
		}
		return false
	}

	for _, t := range tokenList {
		for _, c := range require {
			if t.AllowBroadcast(subjectDID, topicCap, c) && trusted(t) {
				return nil
			}
		}
	}
	return ErrNotAuthorized
}
func (ctx *BasicCapabilityContext) Provide(target did.DID, subject crypto.ID, audience crypto.ID, expire uint64, invoke []Capability, provide []Capability) ([]byte, error) {
if len(invoke) == 0 && len(provide) == 0 {
return nil, fmt.Errorf("no capabilities: %w", ErrNotAuthorized)
}
subjectDID, err := did.FromID(subject)
if err != nil {
return nil, fmt.Errorf("DID for subject: %w", err)
}
audienceDID, err := did.FromID(audience)
if err != nil {
return nil, fmt.Errorf("DID for audience: %w", err)
}
var result []*Token
var invocation, delegation TokenList
if len(invoke) == 0 {
return nil, fmt.Errorf("no invocation capabilities: %w", ErrNotAuthorized)
}
invocation, err = ctx.DelegateInvocation(target, subjectDID, audienceDID, expire, invoke, SelfSignAlso)
if err != nil {
return nil, fmt.Errorf("cannot provide invocation tokens: %w", err)
}
if len(invocation.Tokens) == 0 {
return nil, fmt.Errorf("cannot provide the necessary invocation tokens: %w", ErrNotAuthorized)
}
result = append(result, invocation.Tokens...)
if len(provide) == 0 {
goto marshal
}
delegation, err = ctx.Delegate(target, subjectDID, nil, expire, 1, provide, SelfSignOnly)
if err != nil {
return nil, fmt.Errorf("cannot provide delegation tokens: %w", err)
}
if len(delegation.Tokens) == 0 {
return nil, fmt.Errorf("cannot provide the necessary delegation tokens: %w", ErrNotAuthorized)
}
result = append(result, delegation.Tokens...)
marshal:
payload := TokenList{Tokens: result}
data, err := json.Marshal(payload)
if err != nil {
return nil, fmt.Errorf("marshaling payload: %w", err)
}
return data, nil
}
// ProvideBroadcast builds the JSON-encoded TokenList that lets subject
// broadcast on topic with the provide capabilities. Derived and self-signed
// tokens are both included; an empty result is ErrNotAuthorized.
func (ctx *BasicCapabilityContext) ProvideBroadcast(subject crypto.ID, topic string, expire uint64, provide []Capability) ([]byte, error) {
	if len(provide) == 0 {
		return nil, fmt.Errorf("no capabilities: %w", ErrNotAuthorized)
	}
	subjectDID, err := did.FromID(subject)
	if err != nil {
		return nil, fmt.Errorf("DID for subject: %w", err)
	}

	tokens, err := ctx.DelegateBroadcast(subjectDID, topic, expire, provide, SelfSignAlso)
	switch {
	case err != nil:
		return nil, fmt.Errorf("cannot provide broadcast tokens: %w", err)
	case len(tokens.Tokens) == 0:
		return nil, fmt.Errorf("cannot provide the necessary broadcast tokens: %w", ErrNotAuthorized)
	}

	payload, err := json.Marshal(tokens)
	if err != nil {
		return nil, fmt.Errorf("marshaling payload: %w", err)
	}
	return payload, nil
}
// getRoots returns a snapshot of the root anchors in unspecified order.
func (ctx *BasicCapabilityContext) getRoots() []did.DID {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()
	roots := make([]did.DID, len(ctx.roots))
	i := 0
	for root := range ctx.roots {
		roots[i] = root
		i++
	}
	return roots
}
// addRoots adds every DID in anchors to the root anchor set.
func (ctx *BasicCapabilityContext) addRoots(anchors []did.DID) {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()
	for i := range anchors {
		ctx.roots[anchors[i]] = struct{}{}
	}
}
// getRequireAnchors returns a snapshot of the anchors that have require
// tokens, in unspecified order.
func (ctx *BasicCapabilityContext) getRequireAnchors() []did.DID {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()
	anchors := make([]did.DID, len(ctx.require))
	i := 0
	for anchor := range ctx.require {
		anchors[i] = anchor
		i++
	}
	return anchors
}
// getProvideAnchors returns a snapshot of the anchors that have provide
// tokens, in unspecified order.
func (ctx *BasicCapabilityContext) getProvideAnchors() []did.DID {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()
	anchors := make([]did.DID, len(ctx.provide))
	i := 0
	for anchor := range ctx.provide {
		anchors[i] = anchor
		i++
	}
	return anchors
}
func (ctx *BasicCapabilityContext) getTokens(getf func() ([]*Token, bool), setf func([]*Token)) []*Token {
ctx.mx.Lock()
defer ctx.mx.Unlock()
tokenList, ok := getf()
if !ok {
return nil
}
// filter expired
now := uint64(time.Now().UnixNano())
result := slices.DeleteFunc(slices.Clone(tokenList), func(t *Token) bool {
return t.ExpireBefore(now)
})
setf(result)
return result
}
// getRequireTokens returns the unexpired require tokens for anchor, pruning
// expired ones from the store.
func (ctx *BasicCapabilityContext) getRequireTokens(anchor did.DID) []*Token {
	load := func() ([]*Token, bool) {
		tokens, found := ctx.require[anchor]
		return tokens, found
	}
	store := func(tokens []*Token) {
		ctx.require[anchor] = tokens
	}
	return ctx.getTokens(load, store)
}
// getProvideTokens returns the unexpired provide tokens for anchor, pruning
// expired ones from the store.
func (ctx *BasicCapabilityContext) getProvideTokens(anchor did.DID) []*Token {
	load := func() ([]*Token, bool) {
		tokens, found := ctx.provide[anchor]
		return tokens, found
	}
	store := func(tokens []*Token) {
		ctx.provide[anchor] = tokens
	}
	return ctx.getTokens(load, store)
}
// getSubjectTokens returns the unexpired tokens held for subject, pruning
// expired ones from the store.
func (ctx *BasicCapabilityContext) getSubjectTokens(subject did.DID) []*Token {
	load := func() ([]*Token, bool) {
		tokens, found := ctx.tokens[subject]
		return tokens, found
	}
	store := func(tokens []*Token) {
		ctx.tokens[subject] = tokens
	}
	return ctx.getTokens(load, store)
}
// discardTokens removes the given tokens from the per-subject token store,
// under ctx.mx.
//
// Stored tokens are matched by issuer and nonce. Subjects left without tokens
// are removed from the store entirely. Nil entries, which a JSON null in a
// decoded list produces, are ignored instead of causing a panic.
func (ctx *BasicCapabilityContext) discardTokens(tokens []*Token) {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()
	for _, t := range tokens {
		if t == nil {
			continue
		}
		subject := t.Subject()
		issuer := t.Issuer()
		nonce := t.Nonce()
		remaining := slices.DeleteFunc(slices.Clone(ctx.tokens[subject]), func(ot *Token) bool {
			return issuer == ot.Issuer() && bytes.Equal(nonce, ot.Nonce())
		})
		if len(remaining) == 0 {
			delete(ctx.tokens, subject)
		} else {
			ctx.tokens[subject] = remaining
		}
	}
}
// gc runs gcTokens every gcInterval until gcCtx is cancelled.
func (ctx *BasicCapabilityContext) gc(gcCtx context.Context, gcInterval time.Duration) {
	ticker := time.NewTicker(gcInterval)
	defer ticker.Stop()
	done := gcCtx.Done()
	for {
		select {
		case <-done:
			return
		case <-ticker.C:
			ctx.gcTokens()
		}
	}
}
// gcTokens removes expired tokens from the require, provide and subject
// stores, under ctx.mx. Keys whose lists become empty are deleted.
func (ctx *BasicCapabilityContext) gcTokens() {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	now := uint64(time.Now().UnixNano())
	// unexpired returns a fresh copy of tokens without the entries expired by now.
	unexpired := func(tokens []*Token) []*Token {
		return slices.DeleteFunc(slices.Clone(tokens), func(t *Token) bool {
			return t.ExpireBefore(now)
		})
	}

	for anchor, tokens := range ctx.require {
		if kept := unexpired(tokens); len(kept) > 0 {
			ctx.require[anchor] = kept
		} else {
			delete(ctx.require, anchor)
		}
	}
	for anchor, tokens := range ctx.provide {
		if kept := unexpired(tokens); len(kept) > 0 {
			ctx.provide[anchor] = kept
		} else {
			delete(ctx.provide, anchor)
		}
	}
	for subject, tokens := range ctx.tokens {
		if kept := unexpired(tokens); len(kept) > 0 {
			ctx.tokens[subject] = kept
		} else {
			delete(ctx.tokens, subject)
		}
	}
}
package ucan
import (
"encoding/json"
"fmt"
"io"
"gitlab.com/nunet/device-management-service/lib/did"
)
// Saver is implemented by capability contexts that can serialize themselves.
// SaveCapabilityContext uses it for contexts other than *BasicCapabilityContext.
type Saver interface {
	// Save writes the context's persistent state to wr.
	Save(wr io.Writer) error
}
// CapabilityContextView is the JSON-serialized form of a capability context,
// written by saveCapabilityContext and read by LoadCapabilityContext.
// Field names are the JSON keys, so renaming a field changes the on-disk format.
type CapabilityContextView struct {
	// DID is the DID of the context's provider.
	DID did.DID
	// Roots are the root trust anchors.
	Roots []did.DID
	// Require holds the require (anchor) tokens.
	Require TokenList
	// Provide holds the provide (anchor) tokens.
	Provide TokenList
}
func SaveCapabilityContext(ctx CapabilityContext, wr io.Writer) error {
if bctx, ok := ctx.(*BasicCapabilityContext); ok {
return saveCapabilityContext(bctx, wr)
} else if saver, ok := ctx.(Saver); ok {
return saver.Save(wr)
}
return fmt.Errorf("cannot save context: %w", ErrBadContext)
}
// saveCapabilityContext writes ctx's provider DID, roots, require and provide
// tokens to wr as a single JSON CapabilityContextView.
func saveCapabilityContext(ctx *BasicCapabilityContext, wr io.Writer) error {
	roots, require, provide := ctx.ListRoots()
	err := json.NewEncoder(wr).Encode(&CapabilityContextView{
		DID:     ctx.provider.DID(),
		Roots:   roots,
		Require: require,
		Provide: provide,
	})
	if err != nil {
		return fmt.Errorf("encoding capability context view: %w", err)
	}
	return nil
}
// LoadCapabilityContext reads a JSON CapabilityContextView from rd and builds a
// capability context from it with NewCapabilityContext.
//
// Expired require and provide tokens are dropped while loading. Null token
// entries, which decode to nil, are dropped too instead of panicking.
func LoadCapabilityContext(trust did.TrustContext, rd io.Reader) (CapabilityContext, error) {
	var view CapabilityContextView
	if err := json.NewDecoder(rd).Decode(&view); err != nil {
		return nil, fmt.Errorf("decoding capability context view: %w", err)
	}

	// live keeps the tokens that are present and not yet expired.
	live := func(in TokenList) TokenList {
		var out TokenList
		for _, t := range in.Tokens {
			if t != nil && !t.Expired() {
				out.Tokens = append(out.Tokens, t)
			}
		}
		return out
	}

	return NewCapabilityContext(trust, view.DID, view.Roots, live(view.Require), live(view.Provide))
}
package ucan
import (
"crypto/rand"
"encoding/json"
"fmt"
"slices"
"time"
"gitlab.com/nunet/device-management-service/lib/did"
)
// Action identifies what a token authorizes its subject to do.
type Action string

const (
	// Invoke authorizes the subject to invoke the token's capabilities.
	Invoke Action = "invoke"
	// Delegate authorizes the subject to issue further tokens chained on this one.
	Delegate Action = "delegate"
	// Broadcast authorizes the subject to broadcast on the token's topics.
	Broadcast Action = "broadcast"
	// Revoke Action = "revoke" // TODO
	nonceLength = 12 // 96 bits
)

// signaturePrefix is prepended to a token's JSON encoding to form the bytes
// that are signed and verified. It presumably keeps token signatures distinct
// from other data signed with the same key.
var signaturePrefix = []byte("dms:token:")
// Token is the envelope for a capability token.
//
// Methods use DMS when it is set. A token with only UCAN, or with neither
// field set, is treated as unsupported: it yields ErrBadToken, zero values,
// or false.
type Token struct {
	// DMS tokens
	DMS *DMSToken `json:"dms,omitempty"`
	// UCAN standard (when it is done) envelope for BYO anchors
	UCAN *BYOToken `json:"ucan,omitempty"`
}
// DMSToken is a signed capability token in the native DMS format. The JSON
// tags define the wire format that is signed, so they must not change.
type DMSToken struct {
	// Action is what the token authorizes (invoke, delegate or broadcast).
	Action Action `json:"act"`
	// Issuer is the DID whose anchor verifies the signature.
	Issuer did.DID `json:"iss"`
	// Subject is the DID the token is granted to.
	Subject did.DID `json:"sub"`
	// Audience optionally binds the token to one audience; the empty DID means any.
	Audience did.DID `json:"aud"`
	// Topic lists the topics covered (broadcast and topic-restricted delegation).
	Topic []Capability `json:"topic,omitempty"`
	// Capability lists the granted capabilities.
	Capability []Capability `json:"cap"`
	// Nonce is random data; delegate fills it with nonceLength random bytes.
	Nonce []byte `json:"nonce"`
	// Expire is the expiry instant in nanoseconds since the Unix epoch.
	Expire uint64 `json:"exp"`
	// Depth limits how far below this token a verified chain may extend; 0 means no limit.
	Depth uint64 `json:"depth,omitempty"`
	// Chain is the parent delegation token this one was derived from, if any.
	Chain *Token `json:"chain,omitempty"`
	// Signature is the issuer's signature over SignatureData, which excludes this field.
	Signature []byte `json:"sig,omitempty"`
}
// BYOToken is a placeholder for UCAN-standard envelopes that carry
// bring-your-own trust anchors. It has no fields yet, and every Token method
// treats it as unsupported.
type BYOToken struct {
	// TODO followup
}
// TokenList is the JSON container used to exchange and persist tokens. It is
// encoded as {"tok": [...]}, and the key is omitted when the list is empty.
type TokenList struct {
	Tokens []*Token `json:"tok,omitempty"`
}
// SignatureData returns the bytes that t's signature covers. Only DMS tokens
// are supported; anything else yields ErrBadToken.
func (t *Token) SignatureData() ([]byte, error) {
	if t.DMS == nil {
		// TODO UCAN envelopes for BYO trust; followup
		return nil, ErrBadToken
	}
	return t.DMS.SignatureData()
}
// SignatureData returns signaturePrefix followed by the JSON encoding of t
// with its Signature field cleared. These are the bytes the issuer signs.
func (t *DMSToken) SignatureData() ([]byte, error) {
	unsigned := *t
	unsigned.Signature = nil
	encoded, err := json.Marshal(&unsigned)
	if err != nil {
		return nil, fmt.Errorf("signature data: %w", err)
	}
	out := make([]byte, 0, len(signaturePrefix)+len(encoded))
	out = append(out, signaturePrefix...)
	out = append(out, encoded...)
	return out, nil
}
// Issuer returns the DID that issued t, or the zero DID for non-DMS tokens.
func (t *Token) Issuer() did.DID {
	if t.DMS == nil {
		// TODO UCAN envelopes for BYO trust; followup
		return did.DID{}
	}
	return t.DMS.Issuer
}
// Subject returns the DID t is granted to, or the zero DID for non-DMS tokens.
func (t *Token) Subject() did.DID {
	if t.DMS == nil {
		// TODO UCAN envelopes for BYO trust; followup
		return did.DID{}
	}
	return t.DMS.Subject
}
// Audience returns the audience t is bound to, or the zero DID when unbound
// or for non-DMS tokens.
func (t *Token) Audience() did.DID {
	if t.DMS == nil {
		// TODO UCAN envelopes for BYO trust; followup
		return did.DID{}
	}
	return t.DMS.Audience
}
// Capability returns the capabilities granted by t, or nil for non-DMS tokens.
func (t *Token) Capability() []Capability {
	if t.DMS == nil {
		// TODO UCAN envelopes for BYO trust; followup
		return nil
	}
	return t.DMS.Capability
}
// Topic returns the topics covered by t, or nil for non-DMS tokens.
func (t *Token) Topic() []Capability {
	if t.DMS == nil {
		// TODO UCAN envelopes for BYO trust; followup
		return nil
	}
	return t.DMS.Topic
}
// Expire returns t's expiry in Unix nanoseconds. Non-DMS tokens report 0,
// which puts their expiry right at the Unix epoch.
func (t *Token) Expire() uint64 {
	if t.DMS == nil {
		// TODO UCAN envelopes for BYO trust; followup
		return 0
	}
	return t.DMS.Expire
}
// Nonce returns the nonce carried in t's DMS envelope, or nil when t has no
// DMS envelope. Together with the issuer, it identifies a specific token
// instance.
func (t *Token) Nonce() []byte {
	switch {
	case t.DMS != nil:
		return t.DMS.Nonce
	case t.UCAN != nil:
		// TODO UCAN envelopes for BYO trust; followup
		fallthrough
	default:
		return nil // no nonce available
	}
}
// Action returns what t authorizes, or the empty Action for non-DMS tokens.
func (t *Token) Action() Action {
	if t.DMS == nil {
		// TODO UCAN envelopes for BYO trust; followup
		return ""
	}
	return t.DMS.Action
}
// Verify checks t and its whole chain against trust at instant now, given in
// Unix nanoseconds. t itself is treated as depth 0.
func (t *Token) Verify(trust did.TrustContext, now uint64) error {
	const leafDepth = 0
	return t.verify(trust, now, leafDepth)
}
// verify dispatches to the DMS verifier; non-DMS tokens yield ErrBadToken.
func (t *Token) verify(trust did.TrustContext, now, depth uint64) error {
	if t.DMS == nil {
		// TODO UCAN envelopes for BYO trust; followup
		return ErrBadToken
	}
	return t.DMS.verify(trust, now, depth)
}
// verify checks t at chain position depth against trust at instant now.
// Depth is 0 for the token passed to Token.Verify and grows by one per parent
// link.
//
// It fails when:
//   - t or any link of its chain has expired by now;
//   - depth exceeds t's own Depth limit;
//   - t has a parent that is not a delegation, expires before t, fails its own
//     verification, was not granted to t's issuer, or does not allow every
//     capability t claims;
//   - the issuer's anchor cannot be found in trust, or t's signature does not
//     verify under it.
func (t *DMSToken) verify(trust did.TrustContext, now, depth uint64) error {
	if t.ExpireBefore(now) {
		return ErrCapabilityExpired
	}
	// a Depth of 0 means no limit
	if t.Depth > 0 && depth > t.Depth {
		return fmt.Errorf("max token depth exceeded: %w", ErrNotAuthorized)
	}
	if t.Chain != nil {
		// only delegation tokens may act as a parent
		if t.Chain.Action() != Delegate {
			return fmt.Errorf("verify: chain does not allow delegation: %w", ErrNotAuthorized)
		}
		// the parent must stay valid for at least as long as t claims to be
		if t.Chain.ExpireBefore(t.Expire) {
			return ErrCapabilityExpired
		}
		if err := t.Chain.verify(trust, now, depth+1); err != nil {
			return err
		}
		// t must be issued by the party the parent was granted to
		if !t.Issuer.Equal(t.Chain.Subject()) {
			return fmt.Errorf("verify: issuer/chain subject misnmatch: %w", ErrNotAuthorized)
		}
		// needCapability tracks the claimed capabilities not yet shown to be
		// delegable under the parent; every one of them must be accounted for
		needCapability := slices.Clone(t.Capability)
		for _, c := range t.Capability {
			if t.Chain.allowDelegation(t.Issuer, t.Audience, t.Topic, t.Expire, c) {
				needCapability = slices.DeleteFunc(needCapability, func(oc Capability) bool {
					return c == oc
				})
				if len(needCapability) == 0 {
					break
				}
			}
		}
		if len(needCapability) > 0 {
			return fmt.Errorf("verify: capabilities are not allowed by the chain: %w", ErrNotAuthorized)
		}
	}
	// finally check t's own signature against its issuer's anchor
	anchor, err := trust.GetAnchor(t.Issuer)
	if err != nil {
		return fmt.Errorf("verify: anchor: %w", err)
	}
	data, err := t.SignatureData()
	if err != nil {
		return fmt.Errorf("verify: signature data: %w", err)
	}
	if err := anchor.Verify(data, t.Signature); err != nil {
		return fmt.Errorf("verify: signature: %w", err)
	}
	return nil
}
// AllowAction reports whether t vouches for ot; non-DMS tokens never do.
func (t *Token) AllowAction(ot *Token) bool {
	if t.DMS == nil {
		// TODO UCAN envelopes for BYO trust; followup
		return false
	}
	return t.DMS.AllowAction(ot)
}
// AllowAction reports whether t, acting as a require token, vouches for ot.
// All of the following must hold:
//   - t is a delegation that does not expire before ot does;
//   - t's subject issues some link of ot's chain, within t's Depth limit when
//     one is set;
//   - t's audience, if set, matches ot's audience;
//   - every capability and every topic of ot is implied by one of t's.
func (t *DMSToken) AllowAction(ot *Token) bool {
	if t.Action != Delegate || t.ExpireBefore(ot.Expire()) || !ot.Anchor(t.Subject) {
		return false
	}
	if t.Depth > 0 {
		if anchorDepth, found := ot.AnchorDepth(t.Subject); found && anchorDepth > t.Depth {
			return false
		}
	}
	if !t.Audience.Empty() && !t.Audience.Equal(ot.Audience()) {
		return false
	}

	// covered reports whether every wanted entry is implied by some granted one.
	covered := func(wanted, granted []Capability) bool {
		for _, w := range wanted {
			found := false
			for _, g := range granted {
				if g.Implies(w) {
					found = true
					break
				}
			}
			if !found {
				return false
			}
		}
		return true
	}
	return covered(ot.Capability(), t.Capability) && covered(ot.Topic(), t.Topic)
}
// Size returns the length in bytes of t's signature data, or 0 if that data
// cannot be produced.
func (t *Token) Size() int {
	if data, err := t.SignatureData(); err == nil {
		return len(data)
	}
	return 0
}
// Subsumes reports whether t makes ot redundant; non-DMS tokens never do.
func (t *Token) Subsumes(ot *Token) bool {
	if t.DMS == nil {
		// TODO UCAN envelopes for BYO trust; followup
		return false
	}
	return t.DMS.Subsumes(ot)
}
// Subsumes reports whether t grants at least everything ot grants, so that ot
// is redundant once t is held. All of the following must hold:
//   - both tokens have the same action, issuer, subject and audience;
//   - t does not expire before ot;
//   - t's depth limit is no stricter than ot's (a Depth of 0 means no limit);
//   - every capability and every topic of ot is implied by one of t's.
func (t *DMSToken) Subsumes(ot *Token) bool {
	// The action check prevents a delegate token from replacing an invoke token
	// with the same capabilities, or vice versa: they grant different rights.
	if t.Action != ot.Action() ||
		!t.Issuer.Equal(ot.Issuer()) ||
		!t.Subject.Equal(ot.Subject()) ||
		!t.Audience.Equal(ot.Audience()) ||
		t.Expire < ot.Expire() {
		return false
	}
	// A depth-limited t cannot stand in for an unlimited or less limited ot.
	if ot.DMS != nil && t.Depth > 0 && (ot.DMS.Depth == 0 || ot.DMS.Depth > t.Depth) {
		return false
	}

	// implied reports whether want is implied by one of granted.
	implied := func(want Capability, granted []Capability) bool {
		for _, g := range granted {
			if g.Implies(want) {
				return true
			}
		}
		return false
	}
	for _, oc := range ot.Capability() {
		if !implied(oc, t.Capability) {
			return false
		}
	}
	// Without this check a token for topic A would evict one for topic B.
	for _, topic := range ot.Topic() {
		if !implied(topic, t.Topic) {
			return false
		}
	}
	return true
}
// AllowInvocation reports whether t lets subject invoke c against audience.
// Non-DMS tokens never do.
func (t *Token) AllowInvocation(subject, audience did.DID, c Capability) bool {
	if t.DMS == nil {
		// TODO UCAN envelopes for BYO trust; followup
		return false
	}
	return t.DMS.AllowInvocation(subject, audience, c)
}
// AllowInvocation reports whether t is an invoke token for subject, usable
// with audience (unbound tokens match any audience), that grants a capability
// implying c. Expiry is not checked here.
func (t *DMSToken) AllowInvocation(subject, audience did.DID, c Capability) bool {
	switch {
	case t.Action != Invoke:
		return false
	case !t.Subject.Equal(subject):
		return false
	case !t.Audience.Empty() && !t.Audience.Equal(audience):
		return false
	}
	for i := range t.Capability {
		if t.Capability[i].Implies(c) {
			return true
		}
	}
	return false
}
// AllowBroadcast reports whether t lets subject broadcast c on topic.
// Non-DMS tokens never do.
func (t *Token) AllowBroadcast(subject did.DID, topic Capability, c Capability) bool {
	if t.DMS == nil {
		// TODO UCAN envelopes for BYO trust; followup
		return false
	}
	return t.DMS.AllowBroadcast(subject, topic, c)
}
// AllowBroadcast reports whether t is an audience-less broadcast token for
// subject, with a topic implying topic and a capability implying c.
func (t *DMSToken) AllowBroadcast(subject did.DID, topic Capability, c Capability) bool {
	if t.Action != Broadcast || !t.Subject.Equal(subject) || !t.Audience.Empty() {
		return false
	}
	// anyImplies reports whether some entry of granted implies want.
	anyImplies := func(granted []Capability, want Capability) bool {
		for _, g := range granted {
			if g.Implies(want) {
				return true
			}
		}
		return false
	}
	return anyImplies(t.Topic, topic) && anyImplies(t.Capability, c)
}
// AllowDelegation reports whether issuer may derive a token with the given
// action from t. Non-DMS tokens never allow this.
func (t *Token) AllowDelegation(action Action, issuer, audience did.DID, topics []Capability, expire uint64, c Capability) bool {
	if t.DMS == nil {
		// TODO UCAN envelopes for BYO trust; followup
		return false
	}
	return t.DMS.AllowDelegation(action, issuer, audience, topics, expire, c)
}
// AllowDelegation reports whether issuer may derive a token with the given
// action from t, for audience and topics, expiring at expire and carrying c.
// The chain needs depth headroom for one more link, or two when the derived
// token is itself a delegation.
func (t *DMSToken) AllowDelegation(action Action, issuer, audience did.DID, topics []Capability, expire uint64, c Capability) bool {
	headroom := uint64(1)
	if action == Delegate {
		// certificate would be dead end with 1
		headroom = 2
	}
	return t.verifyDepth(headroom) && t.allowDelegation(issuer, audience, topics, expire, c)
}
// allowDelegation dispatches to the DMS check; non-DMS tokens never allow it.
func (t *Token) allowDelegation(issuer, audience did.DID, topics []Capability, expire uint64, c Capability) bool {
	if t.DMS == nil {
		// TODO UCAN envelopes for BYO trust; followup
		return false
	}
	return t.DMS.allowDelegation(issuer, audience, topics, expire, c)
}
// allowDelegation reports whether t is a delegation to issuer that allows the
// requested derivation, ignoring depth. All of the following must hold:
//   - t does not expire before expire;
//   - t's audience, if set, matches audience;
//   - every requested topic is implied by one of t's topics;
//   - one of t's capabilities implies c.
func (t *DMSToken) allowDelegation(issuer, audience did.DID, topics []Capability, expire uint64, c Capability) bool {
	if t.Action != Delegate || t.ExpireBefore(expire) || !t.Subject.Equal(issuer) {
		return false
	}
	if !t.Audience.Empty() && !t.Audience.Equal(audience) {
		return false
	}
	// implied reports whether want is implied by one of granted.
	implied := func(want Capability, granted []Capability) bool {
		for _, g := range granted {
			if g.Implies(want) {
				return true
			}
		}
		return false
	}
	for _, topic := range topics {
		if !implied(topic, t.Topic) {
			return false
		}
	}
	return implied(c, t.Capability)
}
// verifyDepth dispatches to the DMS depth check; non-DMS tokens fail it.
func (t *Token) verifyDepth(depth uint64) bool {
	if t.DMS == nil {
		// TODO UCAN envelopes for BYO trust; followup
		return false
	}
	return t.DMS.verifyDepth(depth)
}
// verifyDepth reports whether a token placed depth links below t would
// respect the Depth limit of t and of every ancestor.
//
// depth grows by one per parent link, and a Depth of 0 means no limit. A
// chain link without a DMS envelope fails the check.
func (t *DMSToken) verifyDepth(depth uint64) bool {
	for link := t; ; depth++ {
		if link.Depth > 0 && depth > link.Depth {
			return false
		}
		if link.Chain == nil {
			return true
		}
		if link.Chain.DMS == nil {
			return false
		}
		link = link.Chain.DMS
	}
}
// Delegate derives a new delegation token chained on t and signed by
// provider, granting c to subject. The new token is optionally bound to
// audience, restricted to topics, expires at expire and carries the given
// depth limit. Only DMS tokens can be delegated; anything else yields
// ErrBadToken.
func (t *Token) Delegate(provider did.Provider, subject, audience did.DID, topics []Capability, expire, depth uint64, c []Capability) (*Token, error) {
	switch {
	case t.DMS != nil:
		result, err := t.DMS.Delegate(provider, subject, audience, topics, expire, depth, c)
		if err != nil {
			// this path delegates delegation rights, not invocation
			return nil, fmt.Errorf("delegate: %w", err)
		}
		return &Token{DMS: result}, nil
	case t.UCAN != nil:
		// TODO UCAN envelopes for BYO trust; followup
		fallthrough
	default:
		return nil, ErrBadToken
	}
}
// Delegate derives a new delegation token from t; see delegate for the
// checks that are applied.
func (t *DMSToken) Delegate(provider did.Provider, subject, audience did.DID, topics []Capability, expire, depth uint64, c []Capability) (*DMSToken, error) {
	token, err := t.delegate(Delegate, provider, subject, audience, topics, expire, depth, c)
	return token, err
}
// delegate derives a new token of the given action, chained on t and signed
// by provider.
//
// t must itself be a delegation token with enough depth headroom: one more
// link for invocation or broadcast tokens, two for a further delegation. The
// new token gets a fresh random nonce of nonceLength bytes.
func (t *DMSToken) delegate(action Action, provider did.Provider, subject, audience did.DID, topics []Capability, expire, depth uint64, c []Capability) (*DMSToken, error) {
	if t.Action != Delegate {
		return nil, ErrNotAuthorized
	}
	headroom := uint64(1)
	if action == Delegate {
		// certificate would be dead end with 1
		headroom = 2
	}
	if !t.verifyDepth(headroom) {
		return nil, ErrNotAuthorized
	}

	nonce := make([]byte, nonceLength)
	if _, err := rand.Read(nonce); err != nil {
		return nil, fmt.Errorf("nonce: %w", err)
	}

	token := &DMSToken{
		Action:     action,
		Issuer:     provider.DID(),
		Subject:    subject,
		Audience:   audience,
		Topic:      topics,
		Capability: c,
		Nonce:      nonce,
		Expire:     expire,
		Depth:      depth,
		Chain:      &Token{DMS: t},
	}
	payload, err := token.SignatureData()
	if err != nil {
		return nil, fmt.Errorf("delegate: %w", err)
	}
	sig, err := provider.Sign(payload)
	if err != nil {
		return nil, fmt.Errorf("sign: %w", err)
	}
	token.Signature = sig
	return token, nil
}
// DelegateInvocation derives an invocation token for subject from t.
// Non-DMS tokens yield ErrBadToken.
func (t *Token) DelegateInvocation(provider did.Provider, subject, audience did.DID, expire uint64, c []Capability) (*Token, error) {
	if t.DMS == nil {
		// TODO UCAN envelopes for BYO trust; followup
		return nil, ErrBadToken
	}
	inner, err := t.DMS.DelegateInvocation(provider, subject, audience, expire, c)
	if err != nil {
		return nil, fmt.Errorf("delegate invocation: %w", err)
	}
	return &Token{DMS: inner}, nil
}
// DelegateInvocation derives an invocation token from t. It has no topics
// and a Depth of 0.
func (t *DMSToken) DelegateInvocation(provider did.Provider, subject, audience did.DID, expire uint64, c []Capability) (*DMSToken, error) {
	const depth = 0
	var topics []Capability
	return t.delegate(Invoke, provider, subject, audience, topics, expire, depth, c)
}
// DelegateBroadcast derives a broadcast token for subject on topic from t,
// signed by provider. Non-DMS tokens yield ErrBadToken.
func (t *Token) DelegateBroadcast(provider did.Provider, subject did.DID, topic Capability, expire uint64, c []Capability) (*Token, error) {
	switch {
	case t.DMS != nil:
		result, err := t.DMS.DelegateBroadcast(provider, subject, topic, expire, c)
		if err != nil {
			// this path delegates broadcast, not invocation
			return nil, fmt.Errorf("delegate broadcast: %w", err)
		}
		return &Token{DMS: result}, nil
	case t.UCAN != nil:
		// TODO UCAN envelopes for BYO trust; followup
		fallthrough
	default:
		return nil, ErrBadToken
	}
}
// DelegateBroadcast derives a broadcast token from t. It has no audience, a
// single topic, and a Depth of 0.
func (t *DMSToken) DelegateBroadcast(provider did.Provider, subject did.DID, topic Capability, expire uint64, c []Capability) (*DMSToken, error) {
	var noAudience did.DID
	topics := []Capability{topic}
	return t.delegate(Broadcast, provider, subject, noAudience, topics, expire, 0, c)
}
// Anchor reports whether anchor issued t or any link of its chain.
// Non-DMS tokens report false.
func (t *Token) Anchor(anchor did.DID) bool {
	if t.DMS == nil {
		// TODO UCAN envelopes for BYO trust; followup
		return false
	}
	return t.DMS.Anchor(anchor)
}
// Anchor reports whether anchor is the issuer of t or of any DMS link in its
// chain. The walk stops at the first link without a DMS envelope.
func (t *DMSToken) Anchor(anchor did.DID) bool {
	for link := t; ; {
		if link.Issuer.Equal(anchor) {
			return true
		}
		if link.Chain == nil || link.Chain.DMS == nil {
			return false
		}
		link = link.Chain.DMS
	}
}
// AnchorDepth reports how far down t's chain anchor appears; see
// DMSToken.AnchorDepth. Non-DMS tokens report (0, false).
func (t *Token) AnchorDepth(anchor did.DID) (uint64, bool) {
	if t.DMS == nil {
		// TODO UCAN envelopes for BYO trust; followup
		return 0, false
	}
	return t.DMS.AnchorDepth(anchor)
}
// AnchorDepth returns the position of the deepest link in t's chain that was
// issued by anchor, counting t itself as 0, and whether any link matched.
// The walk stops at the first link without a DMS envelope.
func (t *DMSToken) AnchorDepth(anchor did.DID) (depth uint64, have bool) {
	var pos uint64
	for link := t; link != nil; pos++ {
		if link.Issuer.Equal(anchor) {
			depth, have = pos, true
		}
		if link.Chain == nil {
			break
		}
		link = link.Chain.DMS
	}
	return depth, have
}
// Expired reports whether t, or any link of its chain, has expired at the
// current wall-clock time.
func (t *Token) Expired() bool {
	now := time.Now()
	return t.ExpireBefore(uint64(now.UnixNano()))
}
// ExpireBefore reports whether t expires before deadline (Unix nanoseconds).
// Tokens without a DMS envelope always count as expired.
func (t *Token) ExpireBefore(deadline uint64) bool {
	if t.DMS == nil {
		// TODO UCAN envelopes for BYO trust; followup
		return true
	}
	return t.DMS.ExpireBefore(deadline)
}
// ExpireBefore reports whether t or any link of its chain has an Expire
// earlier than deadline. A chain link without a DMS envelope counts as
// expired.
func (t *DMSToken) ExpireBefore(deadline uint64) bool {
	for link := t; ; {
		if deadline > link.Expire {
			return true
		}
		if link.Chain == nil {
			return false
		}
		if link.Chain.DMS == nil {
			return true
		}
		link = link.Chain.DMS
	}
}
// SelfSigned reports whether origin issued the root of t's chain.
// Non-DMS tokens report false.
func (t *Token) SelfSigned(origin did.DID) bool {
	if t.DMS == nil {
		// TODO UCAN envelopes for BYO trust; followup
		return false
	}
	return t.DMS.SelfSigned(origin)
}
// SelfSigned reports whether the first link of t's chain (the one with no
// parent) was issued by origin. A chain link without a DMS envelope yields
// false.
func (t *DMSToken) SelfSigned(origin did.DID) bool {
	root := t
	for root.Chain != nil {
		if root.Chain.DMS == nil {
			return false
		}
		root = root.Chain.DMS
	}
	return root.Issuer.Equal(origin)
}
package main
import "gitlab.com/nunet/device-management-service/cmd"
// @title Device Management Service
// @version 0.4.185
// @description A dashboard application for computing providers.
// @termsOfService https://nunet.io/tos
// @contact.name Support
// @contact.url https://devexchange.nunet.io/
// @contact.email support@nunet.io
// @license.name Apache 2.0
// @license.url http://www.apache.org/licenses/LICENSE-2.0.html
// @host localhost:9999
//
// @Schemes http
//
// @BasePath /api/v1
func main() {
	// Hand control to the cmd package; Execute presumably runs the CLI root
	// command (confirm in package cmd).
	cmd.Execute()
}
package libp2p
import (
"bytes"
"context"
"errors"
"fmt"
"math/rand"
"strings"
"sync"
"time"
dht_pb "github.com/libp2p/go-libp2p-kad-dht/pb"
"github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/peerstore"
"github.com/libp2p/go-libp2p/core/protocol"
msgio "github.com/libp2p/go-msgio"
"github.com/libp2p/go-msgio/protoio" //nolint:staticcheck
multiaddr "github.com/multiformats/go-multiaddr"
"google.golang.org/protobuf/proto"
commonproto "gitlab.com/nunet/device-management-service/proto/generated/v1/common"
)
// kadv1 is the Kademlia protocol suffix. startRandomWalk appends it to the
// configured DHT prefix to form the protocol ID for its direct DHT queries.
const kadv1 = "/kad/1.0.0"
// ConnectToBootstrapNodes dials every configured bootstrap peer concurrently.
//
// All dials share one 10-second timeout derived from ctx, and the call
// returns once every dial has finished. Individual failures are only logged,
// so the returned error is always nil.
func (l *Libp2p) ConnectToBootstrapNodes(ctx context.Context) error {
	bootstrapPeers := l.config.BootstrapPeers
	if len(bootstrapPeers) == 0 {
		return nil
	}

	dialCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
	defer cancel()

	var wg sync.WaitGroup
	wg.Add(len(bootstrapPeers))
	for _, addr := range bootstrapPeers {
		go func(peerAddr multiaddr.Multiaddr) {
			defer wg.Done()
			addrInfo, err := peer.AddrInfoFromP2pAddr(peerAddr)
			if err != nil {
				log.Errorf("failed to convert multi addr to addr info %v - %v", peerAddr, err)
				return
			}
			if err := l.Host.Connect(dialCtx, *addrInfo); err != nil {
				log.Errorf("failed to connect to bootstrap node %s - %v", addrInfo.ID.String(), err)
				return
			}
			log.Infof("connected to Bootstrap Node %s", addrInfo.ID.String())
		}(addr)
	}
	wg.Wait()
	return nil
}
// BootstrapDHT starts the bootstrap process of l.DHT. A failure is logged and
// also returned to the caller.
func (l *Libp2p) BootstrapDHT(ctx context.Context) error {
	err := l.DHT.Bootstrap(ctx)
	if err != nil {
		log.Errorf("failed to prepare this node for bootstraping: %s", err)
	}
	return err
}
// startRandomWalk starts a background process that crawls the dht by
// resolving random keys, improving peer discovery beyond the bootstrap peers.
//
// The goroutine first waits one interval (5 minutes) for the DHT to become
// ready, then loops until ctx is cancelled:
//   - look up the peers closest to a random key and pick one at random;
//   - hop from peer to peer, each time asking the current peer for the peers
//     closest to a fresh random key, for up to 20 hops;
//   - cool down for a random 0.5x-1.5x of the interval, after which the
//     interval grows by 1.5x, up to 4 hours.
//
// Lookup failures back off from 10 seconds up to 5 minutes. Every wait
// observes ctx, so cancellation stops the goroutine promptly.
func (l *Libp2p) startRandomWalk(ctx context.Context) {
	go func() {
		log.Debug("starting bootstrap process")
		// A simple mechanism to improve our bootstrap and peer discovery:
		// 1. initiate a background, never ending, random walk which tries to resolve
		// random keys in the dht and by extension discovers other peers.
		interval := 5 * time.Minute
		delayOnError := 10 * time.Second

		// sleep blocks for d or until ctx is cancelled. It reports whether the
		// walk should continue, which is false once ctx is done.
		sleep := func(d time.Duration) bool {
			timer := time.NewTimer(d)
			defer timer.Stop()
			select {
			case <-timer.C:
				return true
			case <-ctx.Done():
				return false
			}
		}

		// wait for dht ready
		if !sleep(interval) {
			return
		}

		dhtProto := protocol.ID(l.config.DHTPrefix + kadv1)
		sender := newDHTMessageSender(l.Host, dhtProto)
		messenger, err := dht_pb.NewProtocolMessenger(sender)
		if err != nil {
			log.Errorf("bootstrap: creating protocol messenger: %s", err)
			return
		}
		var depth int
		var key string
		for {
			select {
			case <-ctx.Done():
				log.Debugf("bootstrap: context done, stopping bootstrap")
				return
			default:
				randomPeerID, err := l.DHT.RoutingTable().GenRandPeerID(0)
				if err != nil {
					log.Debugf("bootstrap: failed to generate random peer ID: %s", err)
					// back off instead of spinning on a persistent failure
					if !sleep(delayOnError) {
						return
					}
					continue
				}
				key = randomPeerID.String()
				log.Debugf("bootstrap: crawling from %s", key)
				peers, err := l.DHT.GetClosestPeers(ctx, key)
				if err != nil {
					log.Debugf("bootstrap: failed to get closest peers with key=%s - error: %s", randomPeerID.String(), err)
					if !sleep(delayOnError) {
						return
					}
					delayOnError = time.Duration(float64(delayOnError) * 1.25)
					if delayOnError > 5*time.Minute {
						delayOnError = 5 * time.Minute
					}
					continue
				}
				delayOnError = 10 * time.Second
				if len(peers) == 0 {
					continue
				}
				peerID := peers[rand.Intn(len(peers))] //nolint:gosec
				if peerID == l.Host.ID() {
					log.Debugf("bootstrap: skipping self")
					continue
				}
				log.Debugf("bootstrap: starting random walk from %s", peerID)
				peerAddrInfo, err := l.resolvePeerAddress(ctx, peerID)
				if err != nil {
					log.Debugf("bootstrap: failed to resolve address for peer %s - %v", peerID, err)
					continue
				}
				var peerInfos []*peer.AddrInfo
				selected := &peerAddrInfo
			crawl:
				log.Debugf("bootstrap: crawling %s", selected.ID)
				if err := l.Host.Connect(ctx, *selected); err != nil {
					// report the peer actually being dialled, not the walk's start
					log.Debugf("bootstrap: failed to connect to peer %s: %s", selected.ID, err)
					depth++
					continue
				}
				peerInfos, err = messenger.GetClosestPeers(ctx, selected.ID, randomPeerID)
				if err != nil {
					log.Debugf("bootstrap: failed to get closest peers from %s: %s", selected.ID, err)
					depth++
					continue
				}
				if len(peerInfos) == 0 {
					depth++
					continue
				}
				selected = peerInfos[rand.Intn(len(peerInfos))] //nolint:gosec
				if selected.ID == l.Host.ID() {
					log.Debugf("bootstrap: skipping self")
					depth++
					continue
				}
				if depth < 20 {
					randomPeerID, err = l.DHT.RoutingTable().GenRandPeerID(0)
					if err != nil {
						log.Debugf("bootstrap: failed to generate random peer ID: %s", err)
						goto cooldown
					}
					depth++
					goto crawl
				}
			cooldown:
				depth = 0
				minDelay := interval / 2
				maxDelay := (3 * interval) / 2
				delay := minDelay + time.Duration(rand.Int63n(int64(maxDelay-minDelay))) //nolint:gosec
				log.Debugf("bootstrap: cooling down for %s", delay)
				if !sleep(delay) {
					return
				}
				interval = interval * 3 / 2
				if interval > 4*time.Hour {
					interval = 4 * time.Hour
				}
			}
		}
	}()
}
// dhtValidator validates records stored in the DHT under the custom namespace:
// it checks the key prefix and the signature of the commonproto.Advertisement
// envelope carried in the value (see Validate).
type dhtValidator struct {
// PS is the peerstore supplied at construction; Validate does not read it.
PS peerstore.Peerstore
// customNamespace is the prefix every non-empty record key must start with.
// NOTE(review): NewHost builds this validator as dhtValidator{PS: ps}, leaving
// this empty, so the prefix check passes for every key — confirm intent.
customNamespace string
}
// Validate checks a record that is about to be stored in the DHT.
//
// An empty value is always accepted because it represents deleting the item.
// Otherwise key must begin with d.customNamespace, and value must decode as a
// commonproto.Advertisement whose Signature verifies under its embedded
// Secp256k1 PublicKey over the concatenation
// PeerId || byte(Timestamp) || Data || PublicKey.
//
// NOTE(review): only the low byte of Timestamp is part of the signed payload,
// and nothing here checks that PeerId is derived from PublicKey. Changing
// either requires the signing side to change too, so both are left as-is.
func (d dhtValidator) Validate(key string, value []byte) error {
	if len(value) == 0 {
		return nil
	}
	if !strings.HasPrefix(key, d.customNamespace) {
		return errors.New("invalid key namespace")
	}

	var adv commonproto.Advertisement
	if err := proto.Unmarshal(value, &adv); err != nil {
		return fmt.Errorf("failed to unmarshal envelope: %w", err)
	}
	signer, err := crypto.UnmarshalSecp256k1PublicKey(adv.PublicKey)
	if err != nil {
		return fmt.Errorf("failed to unmarshal public key: %w", err)
	}

	signed := bytes.Join([][]byte{
		[]byte(adv.PeerId),
		{byte(adv.Timestamp)},
		adv.Data,
		adv.PublicKey,
	}, nil)
	valid, err := signer.Verify(signed, adv.Signature)
	switch {
	case err != nil:
		return fmt.Errorf("failed to verify envelope: %w", err)
	case !valid:
		return errors.New("failed to verify envelope, public key didn't sign payload")
	default:
		return nil
	}
}
// Select chooses among candidate records for a key; this validator always
// picks the first candidate and never fails.
func (dhtValidator) Select(_ string, _ [][]byte) (int, error) {
	const firstCandidate = 0
	return firstCandidate, nil
}
// dhtMessenger is a minimal dht_pb.MessageSender (built by newDHTMessageSender)
// that opens a new stream for every request or message rather than reusing one.
type dhtMessenger struct {
// host is used to open streams to remote peers.
host host.Host
// proto is the protocol ID requested on every stream.
proto protocol.ID
}
// newDHTMessageSender returns a dht_pb.MessageSender that talks to remote
// peers through h, opening streams on protocol proto.
func newDHTMessageSender(h host.Host, proto protocol.ID) dht_pb.MessageSender {
	sender := new(dhtMessenger)
	sender.host = h
	sender.proto = proto
	return sender
}
// SendRequest opens a new stream to p, writes msg as a varint-delimited
// protobuf, and waits for exactly one delimited reply, which it decodes and
// returns.
//
// Reads and writes on a libp2p stream do not observe ctx by themselves, so
// without extra handling a peer that never answers would block ReadMsg
// indefinitely. To honour ctx, its deadline (if any) is applied to the stream,
// and a watcher goroutine resets the stream if ctx is cancelled before the
// request finishes. On any write/read/decode failure the stream is reset.
func (m *dhtMessenger) SendRequest(ctx context.Context, p peer.ID, msg *dht_pb.Message) (*dht_pb.Message, error) {
	s, err := m.host.NewStream(ctx, p, m.proto)
	if err != nil {
		return nil, fmt.Errorf("open stream: %w", err)
	}
	defer s.Close()

	if deadline, ok := ctx.Deadline(); ok {
		// Best effort: a transport without deadline support still gets the
		// cancellation watcher below.
		_ = s.SetDeadline(deadline)
	}
	done := make(chan struct{})
	defer close(done) // runs before s.Close, stopping the watcher
	go func() {
		select {
		case <-ctx.Done():
			_ = s.Reset() // unblocks a pending write or read
		case <-done:
		}
	}()

	wr := protoio.NewDelimitedWriter(s)
	if err := wr.WriteMsg(msg); err != nil {
		_ = s.Reset()
		return nil, fmt.Errorf("write message: %w", err)
	}

	r := msgio.NewVarintReaderSize(s, network.MessageSizeMax)
	raw, err := r.ReadMsg()
	if err != nil {
		_ = s.Reset()
		return nil, fmt.Errorf("read message: %w", err)
	}
	defer r.ReleaseMsg(raw)

	reply := new(dht_pb.Message)
	if err := reply.Unmarshal(raw); err != nil {
		_ = s.Reset()
		return nil, fmt.Errorf("unmarshal message: %w", err)
	}
	return reply, nil
}
// SendMessage opens a new stream to p, writes msg as a varint-delimited
// protobuf without waiting for a reply, and half-closes the stream.
//
// A write can block when the remote stops reading, and stream writes do not
// observe ctx. So ctx's deadline (if any) becomes the write deadline, and the
// stream is reset if ctx is cancelled before the send completes.
func (m *dhtMessenger) SendMessage(ctx context.Context, p peer.ID, msg *dht_pb.Message) error {
	s, err := m.host.NewStream(ctx, p, m.proto)
	if err != nil {
		return fmt.Errorf("open stream: %w", err)
	}
	defer s.Close()

	if deadline, ok := ctx.Deadline(); ok {
		_ = s.SetWriteDeadline(deadline) // best effort; the watcher covers cancellation
	}
	done := make(chan struct{})
	defer close(done) // runs before s.Close, stopping the watcher
	go func() {
		select {
		case <-ctx.Done():
			_ = s.Reset()
		case <-done:
		}
	}()

	wr := protoio.NewDelimitedWriter(s)
	if err := wr.WriteMsg(msg); err != nil {
		_ = s.Reset()
		return fmt.Errorf("write message: %w", err)
	}
	return s.CloseWrite()
}
package libp2p
import (
"context"
"fmt"
"math/rand"
"time"
"github.com/libp2p/go-libp2p/core/discovery"
"github.com/libp2p/go-libp2p/core/peer"
dutil "github.com/libp2p/go-libp2p/p2p/discovery/util"
)
// DiscoverDialPeers discovers peers through the configured rendezvous point
// and dials them.
//
// If the lookup returns at least one peer, that result replaces
// l.discoveredPeers; otherwise the previously discovered set is kept. Peers
// with no addresses and the local host are then removed, and dialPeers starts
// background dials to (a subset of) what remains. A lookup error is returned
// unchanged and nothing is dialled.
func (l *Libp2p) DiscoverDialPeers(ctx context.Context) error {
	found, err := l.findPeersFromRendezvousDiscovery(ctx)
	if err != nil {
		return err
	}
	if len(found) != 0 {
		l.discoveredPeers = found
	}

	selfAndEmpty := NoAddrIDFilter{ID: l.Host.ID()}
	l.discoveredPeers = PeerPassFilter(l.discoveredPeers, selfAndEmpty)

	l.dialPeers(ctx)
	return nil
}
// advertiseForRendezvousDiscovery advertises this node under the configured
// rendezvous point (l.config.Rendezvous) via the routing-based discovery
// service. The TTL reported by Advertise is discarded; only the error is
// returned.
//
// The parameter was previously named `context`, shadowing the context package
// inside the function; it is renamed to the conventional ctx.
func (l *Libp2p) advertiseForRendezvousDiscovery(ctx context.Context) error {
	_, err := l.discovery.Advertise(ctx, l.config.Rendezvous)
	return err
}
// findPeersFromRendezvousDiscovery queries the discovery service for peers
// advertising l.config.Rendezvous, asking for at most
// l.config.PeerCountDiscoveryLimit results. Failures are wrapped with context.
func (l *Libp2p) findPeersFromRendezvousDiscovery(ctx context.Context) ([]peer.AddrInfo, error) {
	limit := discovery.Limit(l.config.PeerCountDiscoveryLimit)
	found, err := dutil.FindPeers(ctx, l.discovery, l.config.Rendezvous, limit)
	if err != nil {
		return nil, fmt.Errorf("failed to discover peers: %w", err)
	}
	return found, nil
}
// dialPeers starts a background dial to each discovered peer that is neither
// the local host nor already connected, capped at 16 candidates.
//
// When more than 16 peers are known, l.discoveredPeers is shuffled in place and
// the first 16 are used. Each dial gets its own goroutine and a 10-second
// timeout derived from ctx; outcomes are only logged.
func (l *Libp2p) dialPeers(ctx context.Context) {
	const maxPeers = 16

	candidates := l.discoveredPeers
	if len(candidates) > maxPeers {
		//nolint:gosec
		rng := rand.New(rand.NewSource(time.Now().UnixNano()))
		rng.Shuffle(len(candidates), func(i, j int) {
			candidates[i], candidates[j] = candidates[j], candidates[i]
		})
		candidates = candidates[:maxPeers]
	}

	self := l.Host.ID()
	for _, candidate := range candidates {
		if candidate.ID == self || l.PeerConnected(candidate.ID) {
			continue
		}
		go func(target peer.AddrInfo) {
			dialCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
			defer cancel()
			if err := l.Host.Connect(dialCtx, target); err != nil {
				log.Debugf("couldn't establish connection with: %s - error: %v", target.ID, err)
				return
			}
			log.Debugf("connected with: %s", target.ID)
		}(candidate)
	}
}
package libp2p
import (
"github.com/libp2p/go-libp2p/core/connmgr"
"github.com/libp2p/go-libp2p/core/control"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/multiformats/go-multiaddr"
mafilt "github.com/whyrusleeping/multiaddr-filter"
)
// defaultServerFilters lists CIDR ranges (in mafilt.NewMask syntax) that a node
// built with config.Server neither announces (via makeAddrsFactory) nor dials
// or accepts connections from (via filtersConnectionGater); see NewHost.
// They are private, shared, link-local, benchmarking, documentation and
// reserved ranges. The list appears to mirror the Addresses.NoAnnounce /
// Swarm.AddrFilters set of IPFS Kubo's "server" profile. Loopback ranges are
// not included.
var defaultServerFilters = []string{
"/ip4/10.0.0.0/ipcidr/8", // RFC 1918 private
"/ip4/100.64.0.0/ipcidr/10", // RFC 6598 shared address space (CGNAT)
"/ip4/169.254.0.0/ipcidr/16", // IPv4 link-local
"/ip4/172.16.0.0/ipcidr/12", // RFC 1918 private
"/ip4/192.0.0.0/ipcidr/24", // IETF protocol assignments
"/ip4/192.0.2.0/ipcidr/24", // TEST-NET-1 documentation
"/ip4/192.168.0.0/ipcidr/16", // RFC 1918 private
"/ip4/198.18.0.0/ipcidr/15", // benchmarking
"/ip4/198.51.100.0/ipcidr/24", // TEST-NET-2 documentation
"/ip4/203.0.113.0/ipcidr/24", // TEST-NET-3 documentation
"/ip4/240.0.0.0/ipcidr/4", // reserved
"/ip6/100::/ipcidr/64", // IPv6 discard-only
"/ip6/2001:2::/ipcidr/48", // IPv6 benchmarking
"/ip6/2001:db8::/ipcidr/32", // IPv6 documentation
"/ip6/fc00::/ipcidr/7", // IPv6 unique local
"/ip6/fe80::/ipcidr/10", // IPv6 link-local
}
// PeerFilter decides which peers PeerPassFilter keeps: a peer passes when
// satisfies returns true for it.
type PeerFilter interface {
// satisfies reports whether p meets the filter's criteria. Being unexported,
// it can only be implemented by types in this package.
satisfies(p peer.AddrInfo) bool
}
// NoAddrIDFilter is a PeerFilter that rejects peers with no addresses in their
// AddrInfo, and the peer whose ID equals ID (DiscoverDialPeers passes the local
// host's ID here, to drop itself).
type NoAddrIDFilter struct {
// ID is the peer that must not pass the filter.
ID peer.ID
}
// satisfies reports whether p is not the excluded peer and lists at least one
// address.
func (f NoAddrIDFilter) satisfies(p peer.AddrInfo) bool {
	if p.ID == f.ID {
		return false
	}
	return len(p.Addrs) != 0
}
// PeerPassFilter returns, in their original order, the peers for which pf is
// satisfied. The input slice is not modified; the result is nil when no peer
// passes.
func PeerPassFilter(peers []peer.AddrInfo, pf PeerFilter) []peer.AddrInfo {
	var kept []peer.AddrInfo
	for i := range peers {
		if !pf.satisfies(peers[i]) {
			continue
		}
		kept = append(kept, peers[i])
	}
	return kept
}
// filtersConnectionGater adapts a multiaddr.Filters set to the
// connmgr.ConnectionGater interface. Dials to, and inbound or secured
// connections from, addresses blocked by the filters are refused. Peer-level
// and post-upgrade checks always allow.
type filtersConnectionGater multiaddr.Filters

var _ connmgr.ConnectionGater = (*filtersConnectionGater)(nil)

// permits reports whether addr is not blocked by the underlying filters.
func (f *filtersConnectionGater) permits(addr multiaddr.Multiaddr) bool {
	filters := (*multiaddr.Filters)(f)
	return !filters.AddrBlocked(addr)
}

// InterceptAddrDial refuses to dial addr when the filters block it.
func (f *filtersConnectionGater) InterceptAddrDial(_ peer.ID, addr multiaddr.Multiaddr) (allow bool) {
	allow = f.permits(addr)
	return allow
}

// InterceptPeerDial never refuses; filtering is purely address based.
func (f *filtersConnectionGater) InterceptPeerDial(_ peer.ID) (allow bool) {
	allow = true
	return allow
}

// InterceptAccept refuses an inbound connection whose remote address is blocked.
func (f *filtersConnectionGater) InterceptAccept(connAddr network.ConnMultiaddrs) (allow bool) {
	remote := connAddr.RemoteMultiaddr()
	return f.permits(remote)
}

// InterceptSecured repeats the remote-address check after the security
// handshake, in either direction.
func (f *filtersConnectionGater) InterceptSecured(_ network.Direction, _ peer.ID, connAddr network.ConnMultiaddrs) (allow bool) {
	remote := connAddr.RemoteMultiaddr()
	return f.permits(remote)
}

// InterceptUpgraded always allows and reports the zero DisconnectReason.
func (f *filtersConnectionGater) InterceptUpgraded(_ network.Conn) (allow bool, reason control.DisconnectReason) {
	allow, reason = true, 0
	return allow, reason
}
// makeAddrsFactory builds a libp2p AddrsFactory that decides which addresses
// the host announces.
//
//   - announce: when non-empty, replaces the host's own addresses entirely.
//   - appendAnnouce: extra addresses appended to the announced set; entries
//     already present (verbatim) in announce are skipped.
//   - noAnnounce: addresses never announced. Each entry is either a CIDR mask
//     accepted by mafilt.NewMask (e.g. "/ip4/10.0.0.0/ipcidr/8") or an exact
//     multiaddr.
//
// If any entry cannot be parsed, the error is logged and nil is returned.
// NOTE(review): a nil factory handed to libp2p.AddrsFactory presumably means
// "use the default", i.e. no filtering — confirm with go-libp2p.
//
// The returned function never appends into its argument: previously
// append(allAddrs, ...) could write into spare capacity of the caller's
// backing array.
func makeAddrsFactory(announce []string, appendAnnouce []string, noAnnounce []string) func([]multiaddr.Multiaddr) []multiaddr.Multiaddr {
	existing := make(map[string]bool, len(announce)) // announce entries, used to dedupe appendAnnouce
	annAddrs := make([]multiaddr.Multiaddr, 0, len(announce))
	for _, addr := range announce {
		maddr, err := multiaddr.NewMultiaddr(addr)
		if err != nil {
			log.Errorf("addrs factory: invalid announce address %q: %v", addr, err)
			return nil
		}
		annAddrs = append(annAddrs, maddr)
		existing[addr] = true
	}

	appendAnnAddrs := make([]multiaddr.Multiaddr, 0, len(appendAnnouce))
	for _, addr := range appendAnnouce {
		if existing[addr] {
			// skip AppendAnnounce that is on the Announce list already
			continue
		}
		maddr, err := multiaddr.NewMultiaddr(addr)
		if err != nil {
			log.Errorf("addrs factory: invalid append-announce address %q: %v", addr, err)
			return nil
		}
		appendAnnAddrs = append(appendAnnAddrs, maddr)
	}

	filters := multiaddr.NewFilters()
	noAnnAddrs := make(map[string]bool, len(noAnnounce))
	for _, addr := range noAnnounce {
		if mask, err := mafilt.NewMask(addr); err == nil {
			filters.AddFilter(*mask, multiaddr.ActionDeny)
			continue
		}
		maddr, err := multiaddr.NewMultiaddr(addr)
		if err != nil {
			log.Errorf("addrs factory: invalid no-announce address %q: %v", addr, err)
			return nil
		}
		noAnnAddrs[string(maddr.Bytes())] = true
	}

	return func(allAddrs []multiaddr.Multiaddr) []multiaddr.Multiaddr {
		base := allAddrs
		if len(annAddrs) > 0 {
			base = annAddrs
		}
		var out []multiaddr.Multiaddr
		// Walk both sources rather than appending to base, so neither the
		// caller's slice nor annAddrs is ever written to.
		for _, src := range [][]multiaddr.Multiaddr{base, appendAnnAddrs} {
			for _, maddr := range src {
				// exact matches first, then /ipcidr masks
				if noAnnAddrs[string(maddr.Bytes())] || filters.AddrBlocked(maddr) {
					continue
				}
				out = append(out, maddr)
			}
		}
		return out
	}
}
package libp2p
import (
"errors"
"sync"
"github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/protocol"
"gitlab.com/nunet/device-management-service/types"
)
// StreamHandler is a function type that processes data from a stream.
// HandlerRegistry installs it directly as the libp2p stream handler for a
// protocol, so it receives each incoming stream on that protocol.
type StreamHandler func(stream network.Stream)
// HandlerRegistry manages the registration of stream handlers for different protocols.
// It records which protocol IDs have a handler, so it can reject duplicate
// registrations and dispatch locally addressed messages. Every registration is
// mirrored onto the libp2p host. The maps are guarded by mu.
type HandlerRegistry struct {
// host is where stream handlers are installed and removed.
host host.Host
// handlers holds protocols registered with a raw stream callback.
handlers map[protocol.ID]StreamHandler
// bytesHandlers holds protocols registered with a bytes callback; these are
// also used by SendMessageToLocalHandler for messages addressed to this node.
bytesHandlers map[protocol.ID]func(data []byte)
mu sync.RWMutex
}
// NewHandlerRegistry returns an empty registry that installs handlers on h.
func NewHandlerRegistry(h host.Host) *HandlerRegistry {
	reg := &HandlerRegistry{host: h}
	reg.handlers = make(map[protocol.ID]StreamHandler)
	reg.bytesHandlers = make(map[protocol.ID]func(data []byte))
	return reg
}
// RegisterHandlerWithStreamCallback registers a stream handler for a specific protocol.
//
// handler is installed on the host as the stream handler for
// protocol.ID(messageType). Registration fails if this registry already holds
// a handler for that protocol of EITHER kind (stream or bytes callback). Both
// kinds install a host stream handler for the same protocol ID. Previously
// only r.handlers was checked, so registering over a bytes handler silently
// replaced it at the host while leaving a stale r.bytesHandlers entry.
func (r *HandlerRegistry) RegisterHandlerWithStreamCallback(messageType types.MessageType, handler StreamHandler) error {
	r.mu.Lock()
	defer r.mu.Unlock()
	protoID := protocol.ID(messageType)
	_, hasStream := r.handlers[protoID]
	_, hasBytes := r.bytesHandlers[protoID]
	if hasStream || hasBytes {
		return errors.New("stream with this protocol is already registered")
	}
	r.handlers[protoID] = handler
	r.host.SetStreamHandler(protoID, network.StreamHandler(handler))
	return nil
}
// RegisterHandlerWithBytesCallback registers a stream handler for a specific protocol and sends the bytes back to callback.
//
// s is installed on the host as the stream handler for
// protocol.ID(messageType), and handler is recorded so the stream handler
// (and local deliveries) can pass it the message payload. Registration fails
// if this registry already holds a handler for that protocol of EITHER kind.
// Previously only r.bytesHandlers was checked, so a protocol registered with
// a stream callback could be silently replaced at the host.
func (r *HandlerRegistry) RegisterHandlerWithBytesCallback(messageType types.MessageType, s StreamHandler, handler func(data []byte)) error {
	r.mu.Lock()
	defer r.mu.Unlock()
	protoID := protocol.ID(messageType)
	_, hasStream := r.handlers[protoID]
	_, hasBytes := r.bytesHandlers[protoID]
	if hasStream || hasBytes {
		return errors.New("stream with this protocol is already registered")
	}
	r.bytesHandlers[protoID] = handler
	r.host.SetStreamHandler(protoID, network.StreamHandler(s))
	return nil
}
// SendMessageToLocalHandler delivers data to the bytes handler registered for
// messageType, if any; otherwise it does nothing. The handler runs in its own
// goroutine so the caller is never blocked by it.
func (r *HandlerRegistry) SendMessageToLocalHandler(messageType types.MessageType, data []byte) {
	r.mu.RLock()
	defer r.mu.RUnlock()
	if deliver, found := r.bytesHandlers[protocol.ID(messageType)]; found {
		go deliver(data)
	}
}
// UnregisterHandler forgets any handler (stream or bytes) registered for
// messageType and removes the corresponding stream handler from the host.
// Unknown message types are harmless.
func (r *HandlerRegistry) UnregisterHandler(messageType types.MessageType) {
	id := protocol.ID(messageType)
	r.mu.Lock()
	defer r.mu.Unlock()
	delete(r.handlers, id)
	delete(r.bytesHandlers, id)
	r.host.RemoveStreamHandler(id)
}
package libp2p
import (
"context"
"strings"
"time"
"github.com/libp2p/go-libp2p"
dht "github.com/libp2p/go-libp2p-kad-dht"
pubsub "github.com/libp2p/go-libp2p-pubsub"
"github.com/libp2p/go-libp2p/core/event"
"github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/protocol"
"github.com/libp2p/go-libp2p/core/routing"
"github.com/libp2p/go-libp2p/p2p/host/autorelay"
"github.com/libp2p/go-libp2p/p2p/host/peerstore/pstoremem"
"github.com/libp2p/go-libp2p/p2p/host/resource-manager"
"github.com/libp2p/go-libp2p/p2p/net/connmgr"
"github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay"
"github.com/libp2p/go-libp2p/p2p/security/noise"
libp2ptls "github.com/libp2p/go-libp2p/p2p/security/tls"
quic "github.com/libp2p/go-libp2p/p2p/transport/quic"
"github.com/libp2p/go-libp2p/p2p/transport/tcp"
ws "github.com/libp2p/go-libp2p/p2p/transport/websocket"
webtransport "github.com/libp2p/go-libp2p/p2p/transport/webtransport"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
mafilt "github.com/whyrusleeping/multiaddr-filter"
"gitlab.com/nunet/device-management-service/types"
)
// NewHost returns a new libp2p host with dht and other related settings.
//
// It builds, in order:
//   - a connection manager (low/high water 100/400);
//   - a resource manager scaled to config.Memory (MiB, default 1 GiB) and
//     config.FileDescriptors (default 512);
//   - the host, with TCP, QUIC, WebTransport and WebSocket transports, TLS and
//     Noise security, NAT/relay/hole-punching services, and auto-relay fed by
//     newly identified peers that have public addresses;
//   - a Kademlia DHT created through the routing option;
//   - a GossipSub router with peer scoring.
//
// When config.Server is set, addresses in defaultServerFilters are neither
// announced nor dialled/accepted.
//
// Fixes over the previous version:
//   - a malformed filter entry is skipped instead of dereferencing a nil mask;
//   - the host is closed if GossipSub cannot be created;
//   - locals no longer shadow the connmgr and host packages.
func NewHost(ctx context.Context, config *types.Libp2pConfig, appScore func(p peer.ID) float64, scoreInspect pubsub.ExtendedPeerScoreInspectFn) (host.Host, *dht.IpfsDHT, *pubsub.PubSub, error) {
	// newPeer carries peers found by watchForNewPeers to the auto-relay peer source.
	newPeer := make(chan peer.AddrInfo)
	var idht *dht.IpfsDHT

	cm, err := connmgr.NewConnManager(
		100,
		400,
		connmgr.WithGracePeriod(time.Duration(config.GracePeriodMs)*time.Millisecond),
	)
	if err != nil {
		return nil, nil, nil, err
	}

	filter := ma.NewFilters()
	for _, s := range defaultServerFilters {
		f, err := mafilt.NewMask(s)
		if err != nil {
			// f is nil on error; skip the entry rather than dereference it.
			log.Errorf("incorrectly formatted address filter in config: %s - %v", s, err)
			continue
		}
		filter.AddFilter(*f, ma.ActionDeny)
	}

	ps, err := pstoremem.NewPeerstore()
	if err != nil {
		return nil, nil, nil, err
	}

	var libp2pOpts []libp2p.Option
	dhtOpts := []dht.Option{
		dht.ProtocolPrefix(protocol.ID(config.DHTPrefix)),
		// NOTE(review): customNamespace is not set on the validator, so its
		// key-prefix check accepts every key — confirm this is intended.
		dht.NamespacedValidator(strings.ReplaceAll(config.CustomNamespace, "/", ""), dhtValidator{PS: ps}),
		dht.Mode(dht.ModeAutoServer),
	}

	// set up the resource manager; config.Memory is in MiB
	mem := int64(config.Memory)
	if mem > 0 {
		mem = 1024 * 1024 * mem
	} else {
		mem = 1024 * 1024 * 1024 // 1GB
	}
	fds := config.FileDescriptors
	if fds == 0 {
		fds = 512
	}
	limits := rcmgr.DefaultLimits
	limits.SystemBaseLimit.ConnsInbound = 512
	limits.SystemBaseLimit.ConnsOutbound = 512
	limits.SystemBaseLimit.Conns = 1024
	limits.SystemBaseLimit.StreamsInbound = 8192
	limits.SystemBaseLimit.StreamsOutbound = 8192
	limits.SystemBaseLimit.Streams = 16384
	scaled := limits.Scale(mem, fds)
	log.Infof("libp2p limits: %+v", scaled)
	mgr, err := rcmgr.NewResourceManager(rcmgr.NewFixedLimiter(scaled))
	if err != nil {
		return nil, nil, nil, err
	}

	libp2pOpts = append(libp2pOpts, libp2p.ListenAddrStrings(config.ListenAddress...),
		libp2p.ResourceManager(mgr),
		libp2p.Identity(config.PrivateKey),
		libp2p.Routing(func(h host.Host) (routing.PeerRouting, error) {
			var dhtErr error
			idht, dhtErr = dht.New(ctx, h, dhtOpts...)
			return idht, dhtErr
		}),
		libp2p.Peerstore(ps),
		libp2p.Security(libp2ptls.ID, libp2ptls.New),
		libp2p.Security(noise.ID, noise.New),
		// libp2p.NoListenAddrs,
		libp2p.ChainOptions(
			libp2p.Transport(tcp.NewTCPTransport),
			libp2p.Transport(quic.NewTransport),
			libp2p.Transport(webtransport.New),
			libp2p.Transport(ws.New),
		),
		libp2p.EnableNATService(),
		libp2p.ConnectionManager(cm),
		libp2p.EnableRelay(),
		libp2p.EnableHolePunching(),
		libp2p.EnableRelayService(
			relay.WithLimit(&relay.RelayLimit{
				Duration: 5 * time.Minute,
				Data:     1 << 21, // 2 MiB
			}),
		),
		libp2p.EnableAutoRelayWithPeerSource(
			// Peer source: forward up to num peers from newPeer, stopping early
			// when autorelay cancels ctx.
			func(ctx context.Context, num int) <-chan peer.AddrInfo {
				r := make(chan peer.AddrInfo)
				go func() {
					defer close(r)
					for i := 0; i < num; i++ {
						select {
						case p := <-newPeer:
							select {
							case r <- p:
							case <-ctx.Done():
								return
							}
						case <-ctx.Done():
							return
						}
					}
				}()
				return r
			},
			autorelay.WithBootDelay(time.Minute),
			autorelay.WithBackoff(30*time.Second),
			autorelay.WithMinCandidates(2),
			autorelay.WithMaxCandidates(3),
			autorelay.WithNumRelays(2),
		),
	)

	if config.Server {
		libp2pOpts = append(libp2pOpts, libp2p.AddrsFactory(makeAddrsFactory([]string{}, []string{}, defaultServerFilters)))
		libp2pOpts = append(libp2pOpts, libp2p.ConnectionGater((*filtersConnectionGater)(filter)))
	}

	h, err := libp2p.New(libp2pOpts...)
	if err != nil {
		return nil, nil, nil, err
	}

	go watchForNewPeers(ctx, h, newPeer)

	optsPS := []pubsub.Option{
		pubsub.WithFloodPublish(true),
		pubsub.WithMessageSigning(true),
		pubsub.WithPeerScore(
			&pubsub.PeerScoreParams{
				SkipAtomicValidation: true,
				Topics:               make(map[string]*pubsub.TopicScoreParams),
				TopicScoreCap:        10,
				AppSpecificScore:     appScore,
				AppSpecificWeight:    1,
				DecayInterval:        time.Hour,
				DecayToZero:          0.001,
				RetainScore:          6 * time.Hour,
			},
			&pubsub.PeerScoreThresholds{
				GossipThreshold:             -500,
				PublishThreshold:            -1000,
				GraylistThreshold:           -2500,
				AcceptPXThreshold:           0, // TODO for public mainnet we should limit to botostrappers and set them up without a mesh
				OpportunisticGraftThreshold: 2.5,
			},
		),
		pubsub.WithPeerExchange(true),
		pubsub.WithPeerScoreInspect(scoreInspect, time.Second),
	}
	if config.GossipMaxMessageSize > 0 {
		optsPS = append(optsPS, pubsub.WithMaxMessageSize(config.GossipMaxMessageSize))
	}
	gossip, err := pubsub.NewGossipSub(ctx, h, optsPS...)
	if err != nil {
		// Don't leak the freshly created host (listeners, services, goroutines).
		_ = h.Close()
		return nil, nil, nil, err
	}

	return h, idht, gossip, nil
}
// watchForNewPeers offers peers that finish identification and advertise at
// least one public listen address on newPeer. newPeer feeds the auto-relay peer
// source set up in NewHost. It runs until ctx is cancelled or the event
// subscription channel closes.
//
// The send on newPeer is non-blocking. The peer source only reads from the
// (unbuffered) channel while it is collecting candidates. A blocking send
// would park this goroutine indefinitely and stop it draining its event-bus
// subscription. Peers offered while nobody is reading are dropped; later
// identifications offer fresh ones.
func watchForNewPeers(ctx context.Context, h host.Host, newPeer chan peer.AddrInfo) {
	sub, err := h.EventBus().Subscribe([]interface{}{
		&event.EvtPeerIdentificationCompleted{},
		&event.EvtPeerProtocolsUpdated{},
	})
	if err != nil {
		log.Errorf("failed to subscribe to peer identification events: %v", err)
		return
	}
	defer sub.Close()

	for {
		var ev any
		select {
		case <-ctx.Done():
			return
		case e, ok := <-sub.Out():
			if !ok {
				return
			}
			ev = e
		}

		idEv, ok := ev.(event.EvtPeerIdentificationCompleted)
		if !ok {
			continue
		}
		var publicAddrs []ma.Multiaddr
		for _, addr := range idEv.ListenAddrs {
			if manet.IsPublicAddr(addr) {
				publicAddrs = append(publicAddrs, addr)
			}
		}
		if len(publicAddrs) == 0 {
			continue
		}

		select {
		case newPeer <- peer.AddrInfo{ID: idEv.Peer, Addrs: publicAddrs}:
		case <-ctx.Done():
			return
		default:
			// no reader right now; drop rather than block
		}
	}
}
package libp2p
import (
"bufio"
"bytes"
"context"
"crypto/sha256"
"encoding/binary"
"errors"
"fmt"
"io"
"strings"
"sync"
"time"
cid "github.com/ipfs/go-cid"
dht "github.com/libp2p/go-libp2p-kad-dht"
kbucket "github.com/libp2p/go-libp2p-kbucket"
pubsub "github.com/libp2p/go-libp2p-pubsub"
libp2pdiscovery "github.com/libp2p/go-libp2p/core/discovery"
"github.com/libp2p/go-libp2p/core/event"
"github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/peerstore"
"github.com/libp2p/go-libp2p/core/protocol"
drouting "github.com/libp2p/go-libp2p/p2p/discovery/routing"
"github.com/libp2p/go-libp2p/p2p/protocol/ping"
multiaddr "github.com/multiformats/go-multiaddr"
multihash "github.com/multiformats/go-multihash"
msmux "github.com/multiformats/go-multistream"
"github.com/spf13/afero"
"google.golang.org/protobuf/proto"
bt "gitlab.com/nunet/device-management-service/internal/background_tasks"
commonproto "gitlab.com/nunet/device-management-service/proto/generated/v1/common"
"gitlab.com/nunet/device-management-service/types"
)
const (
// MB is one mebibyte, in bytes.
MB = 1024 * 1024
// maxMessageLengthMB is, in MB, the largest message frame that
// handleReadBytesFromStream accepts and sendMessage will send.
maxMessageLengthMB = 1
// Re-exports of pubsub's validation results.
ValidationAccept = pubsub.ValidationAccept
ValidationReject = pubsub.ValidationReject
ValidationIgnore = pubsub.ValidationIgnore
// readTimeout is the read deadline handleReadBytesFromStream sets on an
// incoming stream before reading the message from it.
readTimeout = 30 * time.Second
// sendSemaphoreLimit is the capacity of Libp2p.sendSemaphore, i.e. how many
// asynchronous SendMessage deliveries may be in flight at once.
sendSemaphoreLimit = 4096
)
// Aliases that let users of this package refer to libp2p/pubsub types without
// importing those packages directly.
type (
PeerID = peer.ID
ProtocolID = protocol.ID
Topic = pubsub.Topic
PubSub = pubsub.PubSub
ValidationResult = pubsub.ValidationResult
// Validator is a topic validator over raw message data. NOTE(review): the
// meaning of the extra interface{} argument and result is defined by its
// users outside this view — confirm before relying on it.
Validator func([]byte, interface{}) (ValidationResult, interface{})
PeerScoreSnapshot = pubsub.PeerScoreSnapshot
)
// Libp2p contains the configuration for a Libp2p instance.
// It is created by New, wired to a host/DHT/pubsub by Init, started by Start
// and torn down by Stop.
//
// TODO-suggestion: maybe we should call it something else like Libp2pPeer,
// Libp2pHost or just Peer (callers would use libp2p.Peer...)
type Libp2p struct {
// Host, DHT and PS are set by Init from NewHost; PS is the host's peerstore.
Host host.Host
DHT *dht.IpfsDHT
PS peerstore.Peerstore
pubsub *PubSub
// ctx is created in Init and cancelled by Stop; background work uses it.
ctx context.Context
cancel func()
mx sync.Mutex
pubsubAppScore func(peer.ID) float64
pubsubScore map[peer.ID]*PeerScoreSnapshot
// topicMux presumably guards the topic maps below — confirm at their use sites.
topicMux sync.RWMutex
pubsubTopics map[string]*Topic
topicValidators map[string]map[uint64]Validator
topicSubscription map[string]map[uint64]*pubsub.Subscription
nextTopicSubID uint64
// send backpressure semaphore
// (capacity sendSemaphoreLimit; each async SendMessage holds one slot)
sendSemaphore chan struct{}
// a list of peers discovered by discovery
// (replaced and filtered by DiscoverDialPeers, shuffled in place by dialPeers)
discoveredPeers []peer.AddrInfo
discovery libp2pdiscovery.Discovery
// services
pingService *ping.PingService
// tasks
// (registered with config.Scheduler in Start, removed in Stop)
discoveryTask *bt.Task
advertiseRendezvousTask *bt.Task
handlerRegistry *HandlerRegistry
config *types.Libp2pConfig
// dependencies (db, filesystem...)
fs afero.Fs
}
// New validates config and returns an uninitialised Libp2p; call Init to
// create the underlying host. config and its Scheduler must be non-nil.
//
// TODO-Suggestion: move types.Libp2pConfig to here for better readability.
// Unless there is a reason to keep within types.
func New(config *types.Libp2pConfig, fs afero.Fs) (*Libp2p, error) {
	switch {
	case config == nil:
		return nil, errors.New("config is nil")
	case config.Scheduler == nil:
		return nil, errors.New("scheduler is nil")
	}

	node := &Libp2p{config: config, fs: fs}
	node.discoveredPeers = make([]peer.AddrInfo, 0)
	node.pubsubTopics = make(map[string]*pubsub.Topic)
	node.topicSubscription = make(map[string]map[uint64]*pubsub.Subscription)
	node.topicValidators = make(map[string]map[uint64]Validator)
	node.sendSemaphore = make(chan struct{}, sendSemaphoreLimit)
	return node, nil
}
// Init builds the host, DHT and pubsub router (via NewHost) under a fresh
// cancellable context, and wires the discovery service and handler registry to
// them. On failure the context is cancelled and the error is returned
// (after being logged).
func (l *Libp2p) Init() error {
	ctx, cancel := context.WithCancel(context.Background())
	h, kad, gossip, err := NewHost(ctx, l.config, l.broadcastAppScore, l.broadcastScoreInspect)
	if err != nil {
		cancel()
		log.Error(err)
		return err
	}

	l.ctx, l.cancel = ctx, cancel
	l.Host = h
	l.DHT = kad
	l.PS = h.Peerstore()
	l.discovery = drouting.NewRoutingDiscovery(kad)
	l.pubsub = gossip
	l.handlerRegistry = NewHandlerRegistry(h)
	return nil
}
// Start performs network bootstrapping, peer discovery and protocols handling.
//
// Synchronously it registers stream handlers, connects to the bootstrap nodes
// and bootstraps the DHT; failure of either network step is logged and
// returned. It then:
//   - starts the random walk and the address-change watcher;
//   - schedules a one-off advertise+discover pass one minute later (abandoned
//     if the node is stopped first);
//   - registers periodic discovery (15 min) and rendezvous advertisement (6 h)
//     tasks, then starts the scheduler.
//
// The delayed pass now uses its own error variables instead of writing to
// Start's err from another goroutine, and waits on ctx instead of an
// uninterruptible sleep.
func (l *Libp2p) Start() error {
	// set stream handlers
	l.registerStreamHandlers()

	// connect to bootstrap nodes
	if err := l.ConnectToBootstrapNodes(l.ctx); err != nil {
		log.Errorf("failed to connect to bootstrap nodes: %v", err)
		return err
	}
	if err := l.BootstrapDHT(l.ctx); err != nil {
		log.Errorf("failed to bootstrap dht: %v", err)
		return err
	}

	// Start random walk
	l.startRandomWalk(l.ctx)

	// watch for local address change
	go l.watchForAddrsChange(l.ctx)

	// discover
	go func() {
		// wait for dht bootstrap, unless we are being shut down
		select {
		case <-time.After(1 * time.Minute):
		case <-l.ctx.Done():
			return
		}
		// advertise randevouz discovery
		if err := l.advertiseForRendezvousDiscovery(l.ctx); err != nil {
			log.Warnf("failed to advertise rendezvous point: %v", err)
		}
		if err := l.DiscoverDialPeers(l.ctx); err != nil {
			log.Warnf("failed to discover peers: %v", err)
		}
	}()

	// register period peer discoveryTask task
	discoveryTask := &bt.Task{
		Name:        "Peer Discovery",
		Description: "Periodic task to discover new peers every 15 minutes",
		Function: func(_ interface{}) error {
			return l.DiscoverDialPeers(l.ctx)
		},
		Triggers: []bt.Trigger{&bt.PeriodicTrigger{Interval: 15 * time.Minute}},
	}
	l.discoveryTask = l.config.Scheduler.AddTask(discoveryTask)

	// register rendezvous advertisement task
	advertiseRendezvousTask := &bt.Task{
		Name:        "Rendezvous advertisement",
		Description: "Periodic task to advertise a rendezvous point every 6 hours",
		Function: func(_ interface{}) error {
			return l.advertiseForRendezvousDiscovery(l.ctx)
		},
		Triggers: []bt.Trigger{&bt.PeriodicTrigger{Interval: 6 * time.Hour}},
	}
	l.advertiseRendezvousTask = l.config.Scheduler.AddTask(advertiseRendezvousTask)

	l.config.Scheduler.Start()
	return nil
}
// RegisterStreamMessageHandler installs handler as the raw stream handler for
// messageType. An empty message type is rejected; registry errors are wrapped
// with the message type.
func (l *Libp2p) RegisterStreamMessageHandler(messageType types.MessageType, handler StreamHandler) error {
	if messageType == "" {
		return errors.New("message type is empty")
	}
	err := l.handlerRegistry.RegisterHandlerWithStreamCallback(messageType, handler)
	if err == nil {
		return nil
	}
	return fmt.Errorf("failed to register handler %s: %w", messageType, err)
}
// RegisterBytesMessageHandler arranges for handler to receive the payload of
// every message arriving on messageType. Incoming streams are decoded by
// handleReadBytesFromStream, which then calls handler. An empty message type
// is rejected; registry errors are wrapped with the message type.
func (l *Libp2p) RegisterBytesMessageHandler(messageType types.MessageType, handler func(data []byte)) error {
	if messageType == "" {
		return errors.New("message type is empty")
	}
	err := l.handlerRegistry.RegisterHandlerWithBytesCallback(messageType, l.handleReadBytesFromStream, handler)
	if err == nil {
		return nil
	}
	return fmt.Errorf("failed to register handler %s: %w", messageType, err)
}
// HandleMessage is the string-typed form of RegisterBytesMessageHandler:
// handler receives the payload of each message arriving on messageType.
func (l *Libp2p) HandleMessage(messageType string, handler func(data []byte)) error {
	typed := types.MessageType(messageType)
	return l.RegisterBytesMessageHandler(typed, handler)
}
// handleReadBytesFromStream is the stream handler installed for bytes
// callbacks. It reads one frame from s and passes the payload to the callback
// registered for s.Protocol(). A frame is an 8-byte little-endian length
// followed by that many bytes, as written by sendMessage.
//
// The whole read must finish within readTimeout. Frames longer than
// maxMessageLengthMB MB are rejected. Every failure resets the stream. On
// success the stream is closed before the callback runs.
//
// The length header is now read with io.ReadFull: a single Read may legally
// return fewer than 8 bytes, which previously produced a garbage length.
func (l *Libp2p) handleReadBytesFromStream(s network.Stream) {
	l.handlerRegistry.mu.RLock()
	callback, ok := l.handlerRegistry.bytesHandlers[s.Protocol()]
	l.handlerRegistry.mu.RUnlock()
	if !ok {
		_ = s.Reset()
		return
	}

	if err := s.SetReadDeadline(time.Now().Add(readTimeout)); err != nil {
		_ = s.Reset()
		log.Warnf("error setting read deadline: %s", err)
		return
	}

	r := bufio.NewReader(s)

	// read the first 8 bytes to determine the size of the message
	var header [8]byte
	if _, err := io.ReadFull(r, header[:]); err != nil {
		log.Debugf("error reading message length: %s", err)
		_ = s.Reset()
		return
	}
	lengthPrefix := binary.LittleEndian.Uint64(header[:])

	// check the length before allocating so a peer cannot force a huge buffer
	if lengthPrefix > maxMessageLengthMB*MB {
		_ = s.Reset()
		log.Warnf("message length exceeds maximum: %d", lengthPrefix)
		return
	}

	buf := make([]byte, lengthPrefix)
	if _, err := io.ReadFull(r, buf); err != nil {
		log.Debugf("error reading message: %s", err)
		_ = s.Reset()
		return
	}

	_ = s.Close()
	callback(buf)
}
// UnregisterMessageHandler removes whatever handler is registered for
// messageType, both from the registry and from the host.
func (l *Libp2p) UnregisterMessageHandler(messageType string) {
	typed := types.MessageType(messageType)
	l.handlerRegistry.UnregisterHandler(typed)
}
// SendMessage asynchronously sends msg to the peer whose encoded ID is hostID.
//
// A message addressed to this node goes straight to the locally registered
// bytes handler. Otherwise the delivery runs in a background goroutine, bounded
// by a context that expires at expiry. Starting it requires a free
// sendSemaphore slot. If no slot frees up before the context ends, the context
// error is returned. A nil return means the send was started, not that it
// succeeded.
func (l *Libp2p) SendMessage(ctx context.Context, hostID string, msg types.MessageEnvelope, expiry time.Time) error {
	pid, err := peer.Decode(hostID)
	if err != nil {
		return fmt.Errorf("send: invalid peer ID: %w", err)
	}

	if pid == l.Host.ID() {
		l.handlerRegistry.SendMessageToLocalHandler(msg.Type, msg.Data)
		return nil
	}

	sendCtx, cancel := context.WithTimeout(ctx, time.Until(expiry))
	select {
	case <-sendCtx.Done():
		cancel()
		return sendCtx.Err()
	case l.sendSemaphore <- struct{}{}:
	}

	go func() {
		defer cancel()
		defer func() { <-l.sendSemaphore }()
		l.sendMessage(sendCtx, pid, msg, expiry, nil)
	}()
	return nil
}
// SendMessageSync sends msg to the peer whose encoded ID is hostID and waits
// for the outcome. Self-addressed messages are handed to the local bytes
// handler and report success immediately. Remote delivery is bounded by a
// context that expires at expiry.
func (l *Libp2p) SendMessageSync(ctx context.Context, hostID string, msg types.MessageEnvelope, expiry time.Time) error {
	pid, err := peer.Decode(hostID)
	if err != nil {
		return fmt.Errorf("send: invalid peer ID: %w", err)
	}
	if pid == l.Host.ID() {
		l.handlerRegistry.SendMessageToLocalHandler(msg.Type, msg.Data)
		return nil
	}

	deadlineCtx, cancel := context.WithTimeout(ctx, time.Until(expiry))
	defer cancel()
	outcome := make(chan error, 1)
	l.sendMessage(deadlineCtx, pid, msg, expiry, outcome)
	return <-outcome
}
// newStream opens a stream to an already-connected pid and negotiates proto on
// it with multistream-select. It does this by hand instead of calling
// Host.NewStream, as a workaround for
// https://github.com/libp2p/go-libp2p/issues/2983. No new connection is dialled.
// The stream is reset if negotiation fails.
func (l *Libp2p) newStream(ctx context.Context, pid peer.ID, proto protocol.ID) (network.Stream, error) {
	noDialCtx := network.WithNoDial(ctx, "already dialed")
	s, err := l.Host.Network().NewStream(noDialCtx, pid)
	if err != nil {
		return nil, err
	}

	abort := func(cause error) (network.Stream, error) {
		_ = s.Reset()
		return nil, cause
	}

	selected, err := msmux.SelectOneOf([]protocol.ID{proto}, s)
	if err != nil {
		return abort(err)
	}
	if err = s.SetProtocol(selected); err != nil {
		return abort(err)
	}
	return s, nil
}
// sendMessage delivers msg.Data to pid on protocol msg.Type as one frame. The
// frame is an 8-byte little-endian payload length followed by the payload,
// which is the format handleReadBytesFromStream reads.
//
// Oversized messages are rejected before any network work. If pid is not
// connected, its address is resolved and dialled. The stream's write deadline
// is expiry. When result is non-nil, the final error (nil on success) is sent
// on it exactly once, so it needs buffer room or a waiting reader.
//
// Fixes over the previous version:
//   - a failed Write now returns immediately. Before, it fell through to
//     CloseWrite on the reset stream, overwrote err, and logged a success line.
//   - a failed CloseWrite no longer logs success either.
func (l *Libp2p) sendMessage(ctx context.Context, pid peer.ID, msg types.MessageEnvelope, expiry time.Time, result chan error) {
	var err error
	defer func() {
		if result != nil {
			result <- err
		}
	}()

	requestBufferSize := 8 + len(msg.Data)
	if requestBufferSize > maxMessageLengthMB*MB {
		log.Warnf("send: message size %d is greater than limit %d bytes", requestBufferSize, maxMessageLengthMB*MB)
		err = errors.New("message too large")
		return
	}

	if !l.PeerConnected(pid) {
		var ai peer.AddrInfo
		ai, err = l.resolvePeerAddress(ctx, pid)
		if err != nil {
			log.Warnf("send: error resolving addresses for peer %s: %s", pid, err)
			return
		}
		if err = l.Host.Connect(ctx, ai); err != nil {
			log.Warnf("send: failed to connect to peer %s: %s", pid, err)
			return
		}
	}

	ctx = network.WithAllowLimitedConn(ctx, "send message")
	var stream network.Stream
	stream, err = l.newStream(ctx, pid, protocol.ID(msg.Type))
	if err != nil {
		log.Warnf("send: failed to open stream to peer %s: %s", pid, err)
		return
	}
	defer stream.Close()

	if err = stream.SetWriteDeadline(expiry); err != nil {
		_ = stream.Reset()
		log.Warnf("send: failed to set write deadline to peer %s: %s", pid, err)
		return
	}

	requestPayloadWithLength := make([]byte, requestBufferSize)
	binary.LittleEndian.PutUint64(requestPayloadWithLength, uint64(len(msg.Data)))
	copy(requestPayloadWithLength[8:], msg.Data)

	if _, err = stream.Write(requestPayloadWithLength); err != nil {
		_ = stream.Reset()
		log.Warnf("send: failed to send message to peer %s: %s", pid, err)
		return
	}
	if err = stream.CloseWrite(); err != nil {
		_ = stream.Reset()
		log.Warnf("send: failed to flush output to peer %s: %s", pid, err)
		return
	}

	log.Debugf("send %d bytes to peer %s", len(requestPayloadWithLength), pid)
}
// OpenStream connects to the peer identified by the full /p2p multiaddr addr.
// It then opens a stream on protocol messageType and hands the stream to the
// caller, who owns closing it. Each failure stage is wrapped with its own
// message.
func (l *Libp2p) OpenStream(ctx context.Context, addr string, messageType types.MessageType) (network.Stream, error) {
	target, err := multiaddr.NewMultiaddr(addr)
	if err != nil {
		return nil, fmt.Errorf("invalid multiaddress: %w", err)
	}
	info, err := peer.AddrInfoFromP2pAddr(target)
	if err != nil {
		return nil, fmt.Errorf("could not resolve peer info: %w", err)
	}
	if err = l.Host.Connect(ctx, *info); err != nil {
		return nil, fmt.Errorf("failed to connect to peer: %w", err)
	}
	s, err := l.Host.NewStream(ctx, info.ID, protocol.ID(messageType))
	if err != nil {
		return nil, fmt.Errorf("failed to open stream: %w", err)
	}
	return s, nil
}
// GetMultiaddr returns this host's listen addresses, each with its /p2p/<id>
// component appended.
func (l *Libp2p) GetMultiaddr() ([]multiaddr.Multiaddr, error) {
	self := &peer.AddrInfo{ID: l.Host.ID(), Addrs: l.Host.Addrs()}
	return peer.AddrInfoToP2pAddrs(self)
}
// Stop performs a cleanup of any resources used in this package.
//
// It cancels the background context, removes the scheduled discovery and
// advertisement tasks, and closes the DHT and host. Close errors are collected
// and returned together, joined with "; ".
//
// Each step is skipped when its resource was never created. Previously,
// calling Stop before Init or Start dereferenced nil values and panicked.
func (l *Libp2p) Stop() error {
	var errorMessages []string

	if l.cancel != nil {
		l.cancel()
	}
	if l.discoveryTask != nil {
		l.config.Scheduler.RemoveTask(l.discoveryTask.ID)
	}
	if l.advertiseRendezvousTask != nil {
		l.config.Scheduler.RemoveTask(l.advertiseRendezvousTask.ID)
	}
	if l.DHT != nil {
		if err := l.DHT.Close(); err != nil {
			errorMessages = append(errorMessages, err.Error())
		}
	}
	if l.Host != nil {
		if err := l.Host.Close(); err != nil {
			errorMessages = append(errorMessages, err.Error())
		}
	}

	if len(errorMessages) > 0 {
		return errors.New(strings.Join(errorMessages, "; "))
	}
	return nil
}
// Stat reports this node's peer ID and its listen addresses, joined with ", ".
func (l *Libp2p) Stat() types.NetworkStats {
	addrs := l.Host.Addrs()
	listen := make([]string, len(addrs))
	for i, a := range addrs {
		listen[i] = a.String()
	}
	return types.NetworkStats{
		ID:         l.Host.ID().String(),
		ListenAddr: strings.Join(listen, ", "),
	}
}
// Ping the remote address. The remote address is the encoded peer id which will be decoded and used here.
//
// Pinging this node itself is refused. The ping is bounded by timeout. On
// success the result carries the RTT. On failure the same error is returned
// both in PingResult.Error and as the second return value, except for an
// undecodable ID, which yields an empty PingResult.
//
// TODO (Return error once): something that was confusing me when using this method is that the error is
// returned twice if any. Once as a field of PingResult and one as a return value.
func (l *Libp2p) Ping(ctx context.Context, peerIDAddress string, timeout time.Duration) (types.PingResult, error) {
	if peerIDAddress == l.Host.ID().String() {
		selfErr := errors.New("can't ping self")
		return types.PingResult{Success: false, Error: selfErr}, selfErr
	}

	pingCtx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()

	remotePeer, err := peer.Decode(peerIDAddress)
	if err != nil {
		return types.PingResult{}, err
	}

	var res ping.Result
	select {
	case <-pingCtx.Done():
		ctxErr := pingCtx.Err()
		return types.PingResult{Error: ctxErr}, ctxErr
	case res = <-ping.Ping(pingCtx, l.Host, remotePeer):
	}

	if res.Error != nil {
		log.Errorf("failed to ping peer %s: %v", peerIDAddress, res.Error)
		return types.PingResult{Success: false, RTT: res.RTT, Error: res.Error}, res.Error
	}
	return types.PingResult{RTT: res.RTT, Success: true}, nil
}
// ResolveAddress resolves the address by given a peer id.
//
// Each address known for the peer is returned as a string of the form
// "<addr>/p2p/<id>".
func (l *Libp2p) ResolveAddress(ctx context.Context, id string) ([]string, error) {
	info, err := l.resolveAddress(ctx, id)
	if err != nil {
		return nil, err
	}
	out := make([]string, len(info.Addrs))
	for i := range info.Addrs {
		out[i] = fmt.Sprintf("%s/p2p/%s", info.Addrs[i], id)
	}
	return out, nil
}
// resolveAddress decodes id as a peer ID and resolves that peer's AddrInfo.
// An id that does not decode is reported as an invalid peer.
func (l *Libp2p) resolveAddress(ctx context.Context, id string) (peer.AddrInfo, error) {
	pid, decodeErr := peer.Decode(id)
	if decodeErr != nil {
		return peer.AddrInfo{}, fmt.Errorf("failed to resolve invalid peer: %w", decodeErr)
	}
	return l.resolvePeerAddress(ctx, pid)
}
// resolvePeerAddress returns the AddrInfo for pid. It tries three sources in
// order:
//   - pid is our own host: the host's listen addresses;
//   - pid is currently connected: the addresses held in the peerstore;
//   - otherwise: a DHT FindPeer lookup bounded by a 3 second timeout.
//
// In all cases the addresses are plain transport addresses without a
// /p2p/<id> component. Callers such as ResolveAddress append that component
// themselves, so returning p2p-form addresses here would duplicate it.
func (l *Libp2p) resolvePeerAddress(ctx context.Context, pid peer.ID) (peer.AddrInfo, error) {
	// resolve ourself
	if l.Host.ID() == pid {
		return peer.AddrInfo{ID: pid, Addrs: l.Host.Addrs()}, nil
	}
	if l.PeerConnected(pid) {
		return peer.AddrInfo{
			ID:    pid,
			Addrs: l.Host.Peerstore().Addrs(pid),
		}, nil
	}
	ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
	defer cancel()
	pi, err := l.DHT.FindPeer(ctx, pid)
	if err != nil {
		return peer.AddrInfo{}, fmt.Errorf("failed to resolve address for peer %s: %w", pid, err)
	}
	return pi, nil
}
// Query return all the advertisements in the network related to a key.
// The network is queried to find providers for the given key, and peers which we aren't connected to can be retrieved.
//
// For every provider, the advertisement record stored under
// getCustomNamespace(key, providerID) is fetched and decoded. A record is
// skipped, rather than failing the whole query, in any of these cases:
//   - it cannot be fetched;
//   - it is empty (Unadvertise overwrites the record with a nil value);
//   - it does not decode as an Advertisement.
//
// This way a single stale or malformed record cannot hide the valid ones.
func (l *Libp2p) Query(ctx context.Context, key string) ([]*commonproto.Advertisement, error) {
	if key == "" {
		return nil, errors.New("advertisement key is empty")
	}
	customCID, err := createCIDFromKey(key)
	if err != nil {
		return nil, fmt.Errorf("failed to create cid for key %s: %w", key, err)
	}
	providers, err := l.DHT.FindProviders(ctx, customCID)
	if err != nil {
		return nil, fmt.Errorf("failed to find providers for key %s: %w", key, err)
	}
	advertisements := make([]*commonproto.Advertisement, 0, len(providers))
	for _, provider := range providers {
		// TODO: use go routines to get the values in parallel.
		raw, err := l.DHT.GetValue(ctx, l.getCustomNamespace(key, provider.ID.String()))
		if err != nil || len(raw) == 0 {
			continue
		}
		ad := &commonproto.Advertisement{}
		if err := proto.Unmarshal(raw, ad); err != nil {
			// NOTE(review): malformed records are dropped silently; consider
			// logging them if this turns out to hide real problems.
			continue
		}
		advertisements = append(advertisements, ad)
	}
	return advertisements, nil
}
// Advertise given data and a key pushes the data to the dht.
//
// The data is wrapped in a signed commonproto.Advertisement. The envelope is
// stored in the DHT under getCustomNamespace(key, ourPeerID). We then announce
// ourselves as a provider of the CID derived from key, so Query can find us.
func (l *Libp2p) Advertise(ctx context.Context, key string, data []byte) error {
	if key == "" {
		return errors.New("advertisement key is empty")
	}

	pubKey, err := l.getPublicKey()
	if err != nil {
		return fmt.Errorf("failed to get public key: %w", err)
	}

	ad := &commonproto.Advertisement{
		PeerId:    l.Host.ID().String(),
		Timestamp: time.Now().Unix(),
		Data:      data,
		PublicKey: pubKey,
	}

	// The signed payload is peerID || byte(timestamp) || data || publicKey.
	// NOTE(review): byte(ad.Timestamp) keeps only the lowest 8 bits of the
	// timestamp, so the signature barely covers it. It is left as-is because
	// any verifier has to rebuild exactly this byte layout.
	signedPayload := bytes.Join([][]byte{
		[]byte(ad.PeerId),
		{byte(ad.Timestamp)},
		ad.Data,
		pubKey,
	}, nil)
	if ad.Signature, err = l.sign(signedPayload); err != nil {
		return fmt.Errorf("failed to sign advertisement envelope content: %w", err)
	}

	wire, err := proto.Marshal(ad)
	if err != nil {
		return fmt.Errorf("failed to marshal advertise envelope: %w", err)
	}

	providerCID, err := createCIDFromKey(key)
	if err != nil {
		return fmt.Errorf("failed to create cid for key %s: %w", key, err)
	}

	recordKey := l.getCustomNamespace(key, l.DHT.PeerID().String())
	if err := l.DHT.PutValue(ctx, recordKey, wire); err != nil {
		return fmt.Errorf("failed to put key %s into the dht: %w", key, err)
	}
	if err := l.DHT.Provide(ctx, providerCID, true); err != nil {
		return fmt.Errorf("failed to provide key %s into the dht: %w", key, err)
	}
	return nil
}
// Unadvertise removes the data from the dht by overwriting our record for
// key with a nil value. The provider announcement for key is not withdrawn.
func (l *Libp2p) Unadvertise(ctx context.Context, key string) error {
	recordKey := l.getCustomNamespace(key, l.DHT.PeerID().String())
	if err := l.DHT.PutValue(ctx, recordKey, nil); err != nil {
		return fmt.Errorf("failed to remove key %s from the DHT: %w", key, err)
	}
	return nil
}
// Publish publishes data to a topic, joining the topic first if needed.
// The requirements are that only one topic handler should exist per topic.
func (l *Libp2p) Publish(ctx context.Context, topic string, data []byte) error {
	handle, err := l.getOrJoinTopicHandler(topic)
	if err != nil {
		return fmt.Errorf("failed to publish: %w", err)
	}
	if err := handle.Publish(ctx, data); err != nil {
		return fmt.Errorf("failed to publish to topic %s: %w", topic, err)
	}
	return nil
}
// Subscribe subscribes to a topic and sends the messages to the handler.
//
// It returns an ID that identifies this subscription for Unsubscribe. When
// validator is non-nil it is stored per subscription. The first validator on
// a topic registers l.validate with pubsub, and l.validate then consults every
// stored validator.
//
// Messages are delivered to handler from a background goroutine. That
// goroutine exits once ctx is done or the subscription is cancelled (e.g. by
// Unsubscribe).
func (l *Libp2p) Subscribe(ctx context.Context, topic string, handler func(data []byte), validator Validator) (uint64, error) {
	topicHandler, err := l.getOrJoinTopicHandler(topic)
	if err != nil {
		return 0, fmt.Errorf("failed to subscribe to topic: %w", err)
	}
	sub, err := topicHandler.Subscribe()
	if err != nil {
		return 0, fmt.Errorf("failed to subscribe to topic %s: %w", topic, err)
	}

	l.topicMux.Lock()
	subID := l.nextTopicSubID
	l.nextTopicSubID++

	if validator != nil {
		validatorMap, ok := l.topicValidators[topic]
		if !ok {
			if err := l.pubsub.RegisterTopicValidator(topic, l.validate); err != nil {
				// Release the lock before bailing out. Returning with it held
				// would deadlock every later topic operation.
				l.topicMux.Unlock()
				sub.Cancel()
				return 0, fmt.Errorf("failed to register topic validator: %w", err)
			}
			validatorMap = make(map[uint64]Validator)
			l.topicValidators[topic] = validatorMap
		}
		validatorMap[subID] = validator
	}

	topicMap, ok := l.topicSubscription[topic]
	if !ok {
		topicMap = make(map[uint64]*pubsub.Subscription)
		l.topicSubscription[topic] = topicMap
	}
	topicMap[subID] = sub
	l.topicMux.Unlock()

	go func() {
		for {
			msg, err := sub.Next(ctx)
			if err != nil {
				// Next only fails once ctx is done or the subscription has been
				// cancelled. Neither condition recovers, so retrying would spin
				// forever.
				return
			}
			handler(msg.Data)
		}
	}()
	return subID, nil
}
// validate is the pubsub topic validator registered by Subscribe. It runs
// every validator stored for the message's topic. The first non-accept result
// is returned. The ValidatorData produced by each validator is stored on msg
// and passed to the next one. Topics without validators are accepted.
//
// The validators are copied while the read lock is held. The map itself is
// mutated by Subscribe/Unsubscribe under topicMux, so iterating it after
// RUnlock would be a data race.
func (l *Libp2p) validate(_ context.Context, _ peer.ID, msg *pubsub.Message) ValidationResult {
	l.topicMux.RLock()
	registered := l.topicValidators[msg.GetTopic()]
	validators := make([]Validator, 0, len(registered))
	for _, v := range registered {
		validators = append(validators, v)
	}
	l.topicMux.RUnlock()

	for _, validator := range validators {
		result, validatorData := validator(msg.Data, msg.ValidatorData)
		if result != ValidationAccept {
			return result
		}
		msg.ValidatorData = validatorData
	}
	return ValidationAccept
}
// SetupBroadcastTopic calls setup with the pubsub topic handle for topic,
// letting the application configure the topic directly. The topic must
// already have been joined (e.g. via Subscribe or Publish).
func (l *Libp2p) SetupBroadcastTopic(topic string, setup func(*Topic) error) error {
	// pubsubTopics is mutated under topicMux by join/unsubscribe, so the
	// lookup must hold it too. setup itself runs outside the lock.
	l.topicMux.RLock()
	t, ok := l.pubsubTopics[topic]
	l.topicMux.RUnlock()
	if !ok {
		return fmt.Errorf("not subscribed to %s", topic)
	}
	return setup(t)
}
// SetBroadcastAppScore installs f as the application-level pubsub peer score
// function, guarded by l.mx.
func (l *Libp2p) SetBroadcastAppScore(f func(peer.ID) float64) {
	l.mx.Lock()
	l.pubsubAppScore = f
	l.mx.Unlock()
}
// broadcastAppScore returns the application score for p. It returns 0 when no
// score function has been installed. The function is read under l.mx but
// invoked after the lock is released.
func (l *Libp2p) broadcastAppScore(p peer.ID) float64 {
	l.mx.Lock()
	scoreFn := l.pubsubAppScore
	l.mx.Unlock()
	if scoreFn == nil {
		return 0
	}
	return scoreFn(p)
}
// GetBroadcastScore returns the most recent pubsub peer score snapshot stored
// by broadcastScoreInspect. The returned map is the stored one, not a copy.
func (l *Libp2p) GetBroadcastScore() map[peer.ID]*PeerScoreSnapshot {
	l.mx.Lock()
	snapshot := l.pubsubScore
	l.mx.Unlock()
	return snapshot
}
// broadcastScoreInspect replaces the stored pubsub score snapshot with score.
func (l *Libp2p) broadcastScoreInspect(score map[peer.ID]*PeerScoreSnapshot) {
	l.mx.Lock()
	l.pubsubScore = score
	l.mx.Unlock()
}
// watchForAddrsChange reconnects to the bootstrap nodes every time the host
// reports that its local addresses changed. It blocks until ctx is done or
// the event subscription is closed. The subscription is always released on
// return.
func (l *Libp2p) watchForAddrsChange(ctx context.Context) {
	sub, err := l.Host.EventBus().Subscribe(&event.EvtLocalAddressesUpdated{})
	if err != nil {
		log.Errorf("failed to subscribe to event bus: %v", err)
		return
	}
	defer sub.Close()

	for {
		select {
		case <-ctx.Done():
			return
		case _, ok := <-sub.Out():
			if !ok {
				// A closed channel would otherwise be "ready" forever and spin.
				return
			}
			log.Debug("network address changed. trying to be bootstrap again.")
			if err := l.ConnectToBootstrapNodes(l.ctx); err != nil {
				log.Errorf("failed to start network: %v", err)
			}
		}
	}
}
// Notify reports peer connection events to the given callbacks.
//
// Peers that are already connected (including limited connections) when
// Notify is called are reported synchronously through preconnected, with
// their known protocols and their number of open connections. After that,
// the following events are delivered from a background goroutine until ctx
// is done or the event subscription is closed:
//   - connectedness changes go to connected or disconnected;
//   - identification completion goes to identified;
//   - protocol additions go to updated.
func (l *Libp2p) Notify(ctx context.Context, preconnected func(peer.ID, []protocol.ID, int), connected, disconnected func(peer.ID), identified, updated func(peer.ID, []protocol.ID)) error {
	sub, err := l.Host.EventBus().Subscribe([]interface{}{
		&event.EvtPeerConnectednessChanged{},
		&event.EvtPeerIdentificationCompleted{},
		&event.EvtPeerProtocolsUpdated{},
	})
	if err != nil {
		return fmt.Errorf("failed to subscribe to event bus: %w", err)
	}

	for _, p := range l.Host.Network().Peers() {
		switch l.Host.Network().Connectedness(p) {
		case network.Limited, network.Connected:
			protos, _ := l.Host.Peerstore().GetProtocols(p)
			preconnected(p, protos, len(l.Host.Network().ConnsToPeer(p)))
		}
	}

	go func() {
		defer sub.Close()
		for ctx.Err() == nil {
			select {
			case <-ctx.Done():
				return
			case ev, ok := <-sub.Out():
				if !ok {
					// Closed subscription: stop instead of spinning on nil events.
					return
				}
				switch evt := ev.(type) {
				case event.EvtPeerConnectednessChanged:
					switch evt.Connectedness {
					case network.Limited, network.Connected:
						connected(evt.Peer)
					case network.NotConnected:
						disconnected(evt.Peer)
					}
				case event.EvtPeerIdentificationCompleted:
					identified(evt.Peer, evt.Protocols)
				case event.EvtPeerProtocolsUpdated:
					updated(evt.Peer, evt.Added)
				}
			}
		}
	}()
	return nil
}
// PeerConnected reports whether p currently has a connection to us, counting
// limited connections as connected.
func (l *Libp2p) PeerConnected(p PeerID) bool {
	state := l.Host.Network().Connectedness(p)
	return state == network.Connected || state == network.Limited
}
// getOrJoinTopicHandler returns the pubsub handle for topic, joining the topic
// and caching the handle on first use. Both publishing and subscribing need
// the handle, hence the shared helper. The lookup and join happen under
// topicMux.
func (l *Libp2p) getOrJoinTopicHandler(topic string) (*pubsub.Topic, error) {
	l.topicMux.Lock()
	defer l.topicMux.Unlock()

	if existing, found := l.pubsubTopics[topic]; found {
		return existing, nil
	}
	joined, err := l.pubsub.Join(topic)
	if err != nil {
		return nil, fmt.Errorf("failed to join topic %s: %w", topic, err)
	}
	l.pubsubTopics[topic] = joined
	return joined, nil
}
// Unsubscribe cancels the subscription subID on topic and drops its
// per-subscription validator.
//
// When no subscriptions remain on the topic, the topic handle is closed and
// forgotten. The pubsub-level topic validator stays registered, so a later
// Subscribe with a validator does not try to register it a second time.
func (l *Libp2p) Unsubscribe(topic string, subID uint64) error {
	l.topicMux.Lock()
	defer l.topicMux.Unlock()

	topicHandler, ok := l.pubsubTopics[topic]
	if !ok {
		return fmt.Errorf("not subscribed to topic: %s", topic)
	}

	if validators, ok := l.topicValidators[topic]; ok {
		delete(validators, subID)
	}

	// delete subscription handler and subscription (lookups and delete on a
	// nil map are no-ops)
	subs := l.topicSubscription[topic]
	if sub, ok := subs[subID]; ok {
		sub.Cancel()
		delete(subs, subID)
	}
	if len(subs) > 0 {
		return nil
	}

	// Close before forgetting the handle. If Close fails, the topic is still
	// joined in pubsub, and dropping it from the map would make every later
	// Join of the same topic fail.
	if err := topicHandler.Close(); err != nil {
		return fmt.Errorf("failed to close topic handler: %w", err)
	}
	delete(l.pubsubTopics, topic)
	return nil
}
// VisiblePeers returns the peers recorded in l.discoveredPeers. The slice is
// returned as stored, not copied.
func (l *Libp2p) VisiblePeers() []peer.AddrInfo {
	visible := l.discoveredPeers
	return visible
}
// KnownPeers returns one AddrInfo per peer in the host's peerstore. Only the ID
// is filled in; Addrs is left empty. The error is always nil.
func (l *Libp2p) KnownPeers() ([]peer.AddrInfo, error) {
	ids := l.Host.Peerstore().Peers()
	infos := make([]peer.AddrInfo, len(ids))
	for i, id := range ids {
		infos[i].ID = id
	}
	return infos, nil
}
// DumpDHTRoutingTable returns the peer entries of the DHT routing table. The
// error is always nil.
func (l *Libp2p) DumpDHTRoutingTable() ([]kbucket.PeerInfo, error) {
	return l.DHT.RoutingTable().GetPeerInfos(), nil
}
// registerStreamHandlers creates the ping service for the host, keeps it on l
// and installs its handler for the "/ipfs/ping/1.0.0" protocol.
func (l *Libp2p) registerStreamHandlers() {
	svc := ping.NewPingService(l.Host)
	l.pingService = svc
	l.Host.SetStreamHandler(protocol.ID("/ipfs/ping/1.0.0"), svc.PingHandler)
}
// sign signs data with the host's private key taken from the peerstore.
func (l *Libp2p) sign(data []byte) ([]byte, error) {
	key := l.Host.Peerstore().PrivKey(l.Host.ID())
	if key == nil {
		return nil, errors.New("private key not found for the host")
	}
	sig, signErr := key.Sign(data)
	if signErr != nil {
		return nil, fmt.Errorf("failed to sign data: %w", signErr)
	}
	return sig, nil
}
// getPublicKey returns the raw bytes of the public half of the host's private
// key from the peerstore.
func (l *Libp2p) getPublicKey() ([]byte, error) {
	key := l.Host.Peerstore().PrivKey(l.Host.ID())
	if key == nil {
		return nil, errors.New("private key not found for the host")
	}
	return key.GetPublic().Raw()
}
// getCustomNamespace builds the DHT record key "<namespace>-<key>-<peerID>"
// used to store and look up advertisements.
func (l *Libp2p) getCustomNamespace(key, peerID string) string {
	const layout = "%s-%s-%s"
	return fmt.Sprintf(layout, l.config.CustomNamespace, key, peerID)
}
// createCIDFromKey derives a deterministic CIDv1 (raw codec) for key: the
// SHA-256 digest of key, wrapped as a sha2-256 multihash.
func createCIDFromKey(key string) (cid.Cid, error) {
	digest := sha256.Sum256([]byte(key))
	encoded, err := multihash.Encode(digest[:], multihash.SHA2_256)
	if err != nil {
		return cid.Cid{}, err
	}
	return cid.NewCidV1(cid.Raw, encoded), nil
}
package network
import (
"context"
"errors"
"fmt"
"time"
"github.com/spf13/afero"
commonproto "gitlab.com/nunet/device-management-service/proto/generated/v1/common"
"gitlab.com/nunet/device-management-service/network/libp2p"
"gitlab.com/nunet/device-management-service/types"
)
// PeerID is re-exported from the libp2p package.
type PeerID = libp2p.PeerID

// ProtocolID is re-exported from the libp2p package.
type ProtocolID = libp2p.ProtocolID

// Topic is re-exported from the libp2p package.
type Topic = libp2p.Topic

// Validator is re-exported from the libp2p package.
type Validator = libp2p.Validator

// ValidationResult is re-exported from the libp2p package.
type ValidationResult = libp2p.ValidationResult

// PeerScoreSnapshot is re-exported from the libp2p package.
type PeerScoreSnapshot = libp2p.PeerScoreSnapshot

// Validation results re-exported from the libp2p package.
const (
	ValidationAccept = libp2p.ValidationAccept
	ValidationReject = libp2p.ValidationReject
	ValidationIgnore = libp2p.ValidationIgnore
)
// Messenger defines the interface for sending messages to a peer identified
// by hostID. Both methods take an expiry for the message.
type Messenger interface {
	// SendMessage sends msg to hostID without waiting for delivery.
	SendMessage(ctx context.Context, hostID string, msg types.MessageEnvelope, expiry time.Time) error

	// SendMessageSync sends msg to hostID and blocks until the message has
	// been sent.
	SendMessageSync(ctx context.Context, hostID string, msg types.MessageEnvelope, expiry time.Time) error
}
// Network is the interface the rest of the service uses to interact with the
// peer-to-peer network. It embeds Messenger for direct messaging.
type Network interface {
	// Messenger embedded interface
	Messenger

	// Init initializes the network.
	Init() error
	// Start starts the network.
	Start() error
	// Stat returns the network information.
	Stat() types.NetworkStats
	// Ping pings the given address and returns the PingResult.
	Ping(ctx context.Context, address string, timeout time.Duration) (types.PingResult, error)
	// HandleMessage registers a handler for the given message type.
	HandleMessage(messageType string, handler func(data []byte)) error
	// UnregisterMessageHandler unregisters the stream handler for a specific protocol.
	UnregisterMessageHandler(messageType string)
	// ResolveAddress returns the addresses of the peer identified by id.
	// In libp2p, id is the encoded peer ID and the result is the peer's
	// multiaddresses.
	ResolveAddress(ctx context.Context, id string) ([]string, error)
	// Advertise advertises the given data under the given key,
	// such as advertising device capabilities on the DHT.
	Advertise(ctx context.Context, key string, data []byte) error
	// Unadvertise stops advertising the data corresponding to the given key.
	Unadvertise(ctx context.Context, key string) error
	// Query returns the network advertisements for the given key.
	Query(ctx context.Context, key string) ([]*commonproto.Advertisement, error)
	// Publish publishes the given data to the given topic if the network
	// type allows publish/subscribe functionality such as gossipsub or nats.
	Publish(ctx context.Context, topic string, data []byte) error
	// Subscribe subscribes to the given topic and calls the handler function
	// for each message, if the network type allows it (similar to Publish).
	// It returns a subscription ID to pass to Unsubscribe.
	Subscribe(ctx context.Context, topic string, handler func(data []byte), validator Validator) (uint64, error)
	// Unsubscribe cancels the subscription subID on topic.
	Unsubscribe(topic string, subID uint64) error
	// SetupBroadcastTopic allows the application to configure a pubsub topic directly.
	SetupBroadcastTopic(topic string, setup func(*Topic) error) error
	// SetBroadcastAppScore allows the application to configure application
	// level scoring for pubsub.
	SetBroadcastAppScore(func(PeerID) float64)
	// GetBroadcastScore returns the latest broadcast score snapshot.
	GetBroadcastScore() map[PeerID]*PeerScoreSnapshot
	// Notify allows the application to receive notifications about peer
	// connections and disconnections.
	Notify(ctx context.Context, preconnected func(PeerID, []ProtocolID, int), connected, disconnected func(PeerID), identified, updated func(PeerID, []ProtocolID)) error
	// PeerConnected returns true if the peer is currently connected.
	PeerConnected(p PeerID) bool
	// Stop stops the network including any existing advertisements and subscriptions.
	Stop() error
}
// NewNetwork returns a new network given the configuration.
// Only the libp2p network type is implemented; NATS reports "not implemented"
// and any other type is rejected.
//
// On error the returned Network is always a true nil interface.
func NewNetwork(netConfig *types.NetworkConfig, fs afero.Fs) (Network, error) {
	// TODO: probable additional params to receive: DB, FileSystem
	if netConfig == nil {
		return nil, errors.New("network configuration is nil")
	}
	switch netConfig.Type {
	case types.Libp2pNetwork:
		ln, err := libp2p.New(&netConfig.Libp2pConfig, fs)
		if err != nil {
			// If libp2p.New returns a concrete pointer type, returning ln here
			// would wrap a typed nil in a non-nil Network interface.
			return nil, err
		}
		return ln, nil
	case types.NATSNetwork:
		return nil, errors.New("not implemented")
	default:
		return nil, fmt.Errorf("unsupported network type: %s", netConfig.Type)
	}
}
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.28.1
// protoc v5.26.1
// source: common.proto
package common
import (
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
)
// Version guards checking that this generated code and the protoimpl runtime
// are compatible.
const (
	// Verify that this generated code is sufficiently up-to-date.
	_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
	// Verify that runtime/protoimpl is sufficiently up-to-date.
	_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)

// Advertisement is the envelope to advertise peers payload.
// It carries the advertising peer's ID, a Unix timestamp, the opaque payload
// (Data), a signature and the raw public key of the signer.
type Advertisement struct {
	state         protoimpl.MessageState
	sizeCache     protoimpl.SizeCache
	unknownFields protoimpl.UnknownFields

	PeerId    string `protobuf:"bytes,1,opt,name=peer_id,json=peerId,proto3" json:"peer_id,omitempty"`
	Timestamp int64  `protobuf:"varint,2,opt,name=timestamp,proto3" json:"timestamp,omitempty"`
	Data      []byte `protobuf:"bytes,3,opt,name=data,proto3" json:"data,omitempty"`
	Signature []byte `protobuf:"bytes,4,opt,name=signature,proto3" json:"signature,omitempty"`
	PublicKey []byte `protobuf:"bytes,5,opt,name=public_key,json=publicKey,proto3" json:"public_key,omitempty"`
}
// Reset clears x to its zero value (generated protobuf method).
func (x *Advertisement) Reset() {
	*x = Advertisement{}
	if protoimpl.UnsafeEnabled {
		mi := &file_common_proto_msgTypes[0]
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		ms.StoreMessageInfo(mi)
	}
}

// String returns a text representation of x (generated protobuf method).
func (x *Advertisement) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage marks Advertisement as a protobuf message.
func (*Advertisement) ProtoMessage() {}

// ProtoReflect returns the reflective view of x (generated protobuf method).
func (x *Advertisement) ProtoReflect() protoreflect.Message {
	mi := &file_common_proto_msgTypes[0]
	if protoimpl.UnsafeEnabled && x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use Advertisement.ProtoReflect.Descriptor instead.
func (*Advertisement) Descriptor() ([]byte, []int) {
	return file_common_proto_rawDescGZIP(), []int{0}
}

// The getters below are nil-safe: on a nil receiver they return the field's
// zero value.

func (x *Advertisement) GetPeerId() string {
	if x != nil {
		return x.PeerId
	}
	return ""
}

func (x *Advertisement) GetTimestamp() int64 {
	if x != nil {
		return x.Timestamp
	}
	return 0
}

func (x *Advertisement) GetData() []byte {
	if x != nil {
		return x.Data
	}
	return nil
}

func (x *Advertisement) GetSignature() []byte {
	if x != nil {
		return x.Signature
	}
	return nil
}

func (x *Advertisement) GetPublicKey() []byte {
	if x != nil {
		return x.PublicKey
	}
	return nil
}
// File_common_proto is the descriptor for common.proto, set by
// file_common_proto_init.
var File_common_proto protoreflect.FileDescriptor

// file_common_proto_rawDesc is the serialized FileDescriptorProto for
// common.proto. Generated bytes: do not edit by hand.
var file_common_proto_rawDesc = []byte{
	0x0a, 0x0c, 0x63, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x06,
	0x63, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x22, 0x97, 0x01, 0x0a, 0x0d, 0x41, 0x64, 0x76, 0x65, 0x72,
	0x74, 0x69, 0x73, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x12, 0x17, 0x0a, 0x07, 0x70, 0x65, 0x65, 0x72,
	0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x70, 0x65, 0x65, 0x72, 0x49,
	0x64, 0x12, 0x1c, 0x0a, 0x09, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x18, 0x02,
	0x20, 0x01, 0x28, 0x03, 0x52, 0x09, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x12,
	0x12, 0x0a, 0x04, 0x64, 0x61, 0x74, 0x61, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x64,
	0x61, 0x74, 0x61, 0x12, 0x1c, 0x0a, 0x09, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65,
	0x18, 0x04, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72,
	0x65, 0x12, 0x1d, 0x0a, 0x0a, 0x70, 0x75, 0x62, 0x6c, 0x69, 0x63, 0x5f, 0x6b, 0x65, 0x79, 0x18,
	0x05, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x70, 0x75, 0x62, 0x6c, 0x69, 0x63, 0x4b, 0x65, 0x79,
	0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}

// Lazily gzip-compressed copy of the raw descriptor, used by Descriptor().
var (
	file_common_proto_rawDescOnce sync.Once
	file_common_proto_rawDescData = file_common_proto_rawDesc
)

// file_common_proto_rawDescGZIP compresses the raw descriptor once and
// returns the cached result on every call.
func file_common_proto_rawDescGZIP() []byte {
	file_common_proto_rawDescOnce.Do(func() {
		file_common_proto_rawDescData = protoimpl.X.CompressGZIP(file_common_proto_rawDescData)
	})
	return file_common_proto_rawDescData
}

// Type tables consumed by file_common_proto_init.
var file_common_proto_msgTypes = make([]protoimpl.MessageInfo, 1)
var file_common_proto_goTypes = []interface{}{
	(*Advertisement)(nil), // 0: common.Advertisement
}
var file_common_proto_depIdxs = []int32{
	0, // [0:0] is the sub-list for method output_type
	0, // [0:0] is the sub-list for method input_type
	0, // [0:0] is the sub-list for extension type_name
	0, // [0:0] is the sub-list for extension extendee
	0, // [0:0] is the sub-list for field type_name
}
func init() { file_common_proto_init() }

// file_common_proto_init builds File_common_proto from the raw descriptor and
// the type tables. It runs only once: later calls return early. The build-time
// tables are released afterwards.
func file_common_proto_init() {
	if File_common_proto != nil {
		return
	}
	// Without unsafe support, expose the internal message fields through an
	// exporter function instead.
	if !protoimpl.UnsafeEnabled {
		file_common_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
			switch v := v.(*Advertisement); i {
			case 0:
				return &v.state
			case 1:
				return &v.sizeCache
			case 2:
				return &v.unknownFields
			default:
				return nil
			}
		}
	}
	type x struct{}
	out := protoimpl.TypeBuilder{
		File: protoimpl.DescBuilder{
			GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
			RawDescriptor: file_common_proto_rawDesc,
			NumEnums:      0,
			NumMessages:   1,
			NumExtensions: 0,
			NumServices:   0,
		},
		GoTypes:           file_common_proto_goTypes,
		DependencyIndexes: file_common_proto_depIdxs,
		MessageInfos:      file_common_proto_msgTypes,
	}.Build()
	File_common_proto = out.File
	file_common_proto_rawDesc = nil
	file_common_proto_goTypes = nil
	file_common_proto_depIdxs = nil
}
package basiccontroller
import (
"context"
"fmt"
"os"
"github.com/spf13/afero"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/storage"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils"
)
// BasicVolumeController is the default implementation of the VolumeController.
// It persists storage volumes information using the StorageVolume repository
// and creates, chmods and removes the volume directories on FS.
type BasicVolumeController struct {
	// repo is the repository for storage volume operations
	repo repositories.StorageVolume
	// basePath is the base path where volumes are stored under. CreateVolume
	// concatenates it directly with the volume directory name.
	basePath string
	// file system to act upon
	FS afero.Fs
}
// NewDefaultVolumeController returns a new instance of BasicVolumeController.
//
// volBasePath is normalized to end with "/". CreateVolume builds volume paths
// by plain concatenation (basePath + source + "-" + random), so a base path
// without a trailing separator would glue the volume name onto its last path
// element. An empty volBasePath is kept as-is (volumes relative to the FS's
// working directory) rather than being turned into "/".
func NewDefaultVolumeController(repo repositories.StorageVolume, volBasePath string, fs afero.Fs) (*BasicVolumeController, error) {
	ctx, cancel := st.SpanContext(context.Background(), "controller", "volume_controller_init_duration", "opentelemetry", "log")
	defer cancel()

	if volBasePath != "" && volBasePath[len(volBasePath)-1] != '/' {
		volBasePath += "/"
	}

	vc := &BasicVolumeController{
		repo:     repo,
		basePath: volBasePath,
		FS:       fs,
	}
	st.Info(ctx, "volume_controller_init_success", nil)
	return vc, nil
}
// CreateVolume creates a new storage volume given a source (S3, IPFS, job, etc). The
// creation of a storage volume effectively creates an empty directory in the local filesystem
// and writes a record in the database.
//
// The directory name follows the format: `<volSource> + "-" + <name>`
// where `name` is a random 16-character string.
//
// The volume starts as public, writable and unencrypted; opts may override
// those defaults. If the database insert fails, the freshly created directory
// is removed again, so no untracked directory is left behind.
//
// TODO-maybe [withName]: allow callers to specify custom name for path
func (vc *BasicVolumeController) CreateVolume(volSource storage.VolumeSource, opts ...storage.CreateVolOpt) (types.StorageVolume, error) {
	ctx, cancel := st.SpanContext(context.Background(), "controller", "volume_create_duration", "opentelemetry", "log")
	defer cancel()

	vol := types.StorageVolume{
		Private:        false,
		ReadOnly:       false,
		EncryptionType: types.EncryptionTypeNull,
	}
	for _, opt := range opts {
		opt(&vol)
	}

	randomStr, err := utils.RandomString(16)
	if err != nil {
		return types.StorageVolume{}, fmt.Errorf("failed to create random string: %w", err)
	}
	vol.Path = vc.basePath + string(volSource) + "-" + randomStr
	ctx = context.WithValue(ctx, pathKey, vol.Path)

	if err := vc.FS.Mkdir(vol.Path, os.ModePerm); err != nil {
		ctx = context.WithValue(ctx, errorKey, err.Error())
		st.Error(ctx, "volume_create_failure", nil)
		return types.StorageVolume{}, fmt.Errorf("failed to create storage volume: %w", err)
	}

	createdVol, err := vc.repo.Create(ctx, vol)
	if err != nil {
		// Best-effort rollback of the empty directory created above. The
		// repository error is the one reported to the caller.
		_ = vc.FS.Remove(vol.Path)
		ctx = context.WithValue(ctx, errorKey, err.Error())
		st.Error(ctx, "volume_create_failure", nil)
		return types.StorageVolume{}, fmt.Errorf("failed to create storage volume in repository: %w", err)
	}

	ctx = context.WithValue(ctx, volumeIDKey, createdVol.ID)
	st.Info(ctx, "volume_create_success", nil)
	return createdVol, nil
}
// LockVolume makes the volume read-only, not only changing the field value but also changing file permissions.
// It should be used after all necessary data has been written.
// It optionally can also set the CID and mark the volume as private.
//
// Volumes are directories (CreateVolume uses Mkdir), so the directory is set
// to 0o500 (owner read + search). Read permission alone (0o400) would leave
// even the owner unable to open or stat anything inside, because that needs
// the search (x) bit.
// NOTE(review): only the directory itself is chmod-ed; files already inside
// keep their own permissions.
//
// TODO-maybe [CID]: maybe calculate CID of every volume in case WithCID opt is not provided
func (vc *BasicVolumeController) LockVolume(pathToVol string, opts ...storage.LockVolOpt) error {
	const lockedDirPerm = 0o500

	ctx, cancel := st.SpanContext(context.Background(), "controller", "volume_lock_duration", "opentelemetry", "log")
	defer cancel()
	ctx = context.WithValue(ctx, pathKey, pathToVol)

	query := vc.repo.GetQuery()
	query.Conditions = append(query.Conditions, repositories.EQ("Path", pathToVol))
	vol, err := vc.repo.Find(ctx, query)
	if err != nil {
		ctx = context.WithValue(ctx, errorKey, err.Error())
		st.Error(ctx, "volume_lock_failure", nil)
		return fmt.Errorf("failed to find storage volume with path %s - Error: %w", pathToVol, err)
	}

	for _, opt := range opts {
		opt(&vol)
	}
	vol.ReadOnly = true

	updatedVol, err := vc.repo.Update(ctx, vol.ID, vol)
	if err != nil {
		ctx = context.WithValue(ctx, errorKey, err.Error())
		st.Error(ctx, "volume_lock_failure", nil)
		return fmt.Errorf("failed to update storage volume with path %s - Error: %w", pathToVol, err)
	}

	// change file permissions
	if err := vc.FS.Chmod(updatedVol.Path, lockedDirPerm); err != nil {
		ctx = context.WithValue(ctx, errorKey, err.Error())
		st.Error(ctx, "volume_lock_failure", nil)
		return fmt.Errorf("failed to make storage volume read-only (path: %s): %w", updatedVol.Path, err)
	}

	st.Info(ctx, "volume_lock_success", nil)
	return nil
}
// WithPrivate returns an option that sets Private on a volume. It can be
// instantiated as either a CreateVolOpt or a LockVolOpt, so the same option
// works when creating or when locking a volume.
func WithPrivate[T storage.CreateVolOpt | storage.LockVolOpt]() T {
	markPrivate := func(v *types.StorageVolume) { v.Private = true }
	return markPrivate
}
// WithCID returns a lock option that records an already-calculated CID on the
// volume.
//
// TODO [validate]: check if CID provided is valid
func WithCID(cid string) storage.LockVolOpt {
	setCID := func(v *types.StorageVolume) { v.CID = cid }
	return setCID
}
// DeleteVolume deletes a given storage volume record from the database and removes the corresponding directory.
// Identifier is either a CID or a path of a volume, as selected by idType.
//
// The directory is removed first, then the record. Every failure is reported
// to telemetry as "volume_delete_failure" before returning.
//
// Note [CID]: if we start to type CID as cid.CID, we may have to use generics here
// as in `[T string | cid.CID]`
func (vc *BasicVolumeController) DeleteVolume(identifier string, idType storage.IDType) error {
	ctx, cancel := st.SpanContext(context.Background(), "controller", "volume_delete_duration", "opentelemetry", "log")
	defer cancel()
	ctx = context.WithValue(ctx, identifierKey, identifier)
	ctx = context.WithValue(ctx, idTypeKey, idType)

	query := vc.repo.GetQuery()
	switch idType {
	case storage.IDTypePath:
		query.Conditions = append(query.Conditions, repositories.EQ("Path", identifier))
	case storage.IDTypeCID:
		query.Conditions = append(query.Conditions, repositories.EQ("CID", identifier))
	default:
		ctx = context.WithValue(ctx, errorKey, "identifier type not supported")
		st.Error(ctx, "volume_delete_failure", nil)
		return fmt.Errorf("identifier type not supported")
	}

	vol, err := vc.repo.Find(ctx, query)
	if err != nil {
		ctx = context.WithValue(ctx, errorKey, err.Error())
		st.Error(ctx, "volume_delete_failure", nil)
		if err == repositories.ErrNotFound {
			return fmt.Errorf("volume not found: %w", err)
		}
		return fmt.Errorf("failed to find volume: %w", err)
	}

	// Remove the directory
	if err := vc.FS.RemoveAll(vol.Path); err != nil {
		ctx = context.WithValue(ctx, errorKey, err.Error())
		st.Error(ctx, "volume_delete_failure", nil)
		return fmt.Errorf("failed to remove volume directory: %w", err)
	}

	// Delete the record from the database, within the same span context as
	// the lookup above.
	if err := vc.repo.Delete(ctx, vol.ID); err != nil {
		ctx = context.WithValue(ctx, errorKey, err.Error())
		st.Error(ctx, "volume_delete_failure", nil)
		return fmt.Errorf("failed to delete volume: %w", err)
	}

	st.Info(ctx, "volume_delete_success", nil)
	return nil
}
// ListVolumes returns every storage volume record in the repository. The
// volume count (on success) or the error (on failure) is reported to
// telemetry.
//
// TODO [filter]: maybe add opts to filter results by certain values
func (vc *BasicVolumeController) ListVolumes() ([]types.StorageVolume, error) {
	ctx, cancel := st.SpanContext(context.Background(), "controller", "volume_list_duration", "opentelemetry", "log")
	defer cancel()

	all, err := vc.repo.FindAll(ctx, vc.repo.GetQuery())
	if err == nil {
		st.Info(context.WithValue(ctx, volumeCountKey, len(all)), "volume_list_success", nil)
		return all, nil
	}
	st.Error(context.WithValue(ctx, errorKey, err.Error()), "volume_list_failure", nil)
	return nil, fmt.Errorf("failed to list volumes: %w", err)
}
// GetSize returns the size of a volume, as computed by utils.GetDirectorySize
// over the volume's directory. The volume is looked up by path or by CID,
// depending on idType.
// TODO-minor: identify which measurement type will be used
func (vc *BasicVolumeController) GetSize(identifier string, idType storage.IDType) (int64, error) {
	ctx, cancel := st.SpanContext(context.Background(), "controller", "volume_get_size_duration", "opentelemetry", "log")
	defer cancel()
	ctx = context.WithValue(ctx, identifierKey, identifier)
	ctx = context.WithValue(ctx, idTypeKey, idType)

	query := vc.repo.GetQuery()
	switch idType {
	case storage.IDTypePath:
		query.Conditions = append(query.Conditions, repositories.EQ("Path", identifier))
	case storage.IDTypeCID:
		query.Conditions = append(query.Conditions, repositories.EQ("CID", identifier))
	default:
		ctx = context.WithValue(ctx, errorKey, fmt.Sprintf("unsupported ID type: %d", idType))
		st.Error(ctx, "volume_get_size_failure", nil)
		return 0, fmt.Errorf("unsupported ID type: %d", idType)
	}

	vol, err := vc.repo.Find(ctx, query)
	if err != nil {
		ctx = context.WithValue(ctx, errorKey, fmt.Sprintf("failed to find volume: %v", err))
		st.Error(ctx, "volume_get_size_failure", nil)
		return 0, fmt.Errorf("failed to find volume: %w", err)
	}

	size, err := utils.GetDirectorySize(vc.FS, vol.Path)
	if err != nil {
		ctx = context.WithValue(ctx, errorKey, fmt.Sprintf("failed to get directory size: %v", err))
		st.Error(ctx, "volume_get_size_failure", nil)
		return 0, fmt.Errorf("failed to get directory size: %w", err)
	}

	// Report success exactly once (the event used to be emitted twice).
	ctx = context.WithValue(ctx, sizeKey, size)
	st.Info(ctx, "volume_get_size_success", nil)
	return size, nil
}
// EncryptVolume is intended to encrypt the volume at path. It is not
// implemented yet: it records a telemetry error and always fails.
func (vc *BasicVolumeController) EncryptVolume(path string, _ types.Encryptor, _ types.EncryptionType) error {
	spanCtx, done := st.SpanContext(context.Background(), "controller", "volume_encrypt_duration", "opentelemetry", "log")
	defer done()
	st.Error(context.WithValue(spanCtx, pathKey, path), "volume_encrypt_not_implemented", nil)
	return fmt.Errorf("not implemented")
}
// DecryptVolume is intended to decrypt the volume at path. It is not
// implemented yet: it records a telemetry error and always fails.
func (vc *BasicVolumeController) DecryptVolume(path string, _ types.Decryptor, _ types.EncryptionType) error {
	spanCtx, done := st.SpanContext(context.Background(), "controller", "volume_decrypt_duration", "opentelemetry", "log")
	defer done()
	st.Error(context.WithValue(spanCtx, pathKey, path), "volume_decrypt_not_implemented", nil)
	return fmt.Errorf("not implemented")
}
var _ storage.VolumeController = (*BasicVolumeController)(nil)
package basiccontroller
import (
"context"
"fmt"
"github.com/spf13/afero"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
rGorm "gitlab.com/nunet/device-management-service/db/repositories/gorm"
"gitlab.com/nunet/device-management-service/telemetry"
"gitlab.com/nunet/device-management-service/types"
)
// VolumeControllerTestKit bundles a BasicVolumeController together with the
// filesystem and the volumes it was set up with, for use in tests.
type VolumeControllerTestKit struct {
// BasicVolController is the controller under test.
BasicVolController *BasicVolumeController
// Fs is the filesystem the controller operates on (an in-memory afero
// filesystem when built by SetupVolumeControllerTestKit).
Fs afero.Fs
// Volumes holds the volumes passed to the setup function, keyed as the caller supplied them.
Volumes map[string]*types.StorageVolume
}
// SetupVolumeControllerTestKit builds a BasicVolumeController backed by an
// in-memory afero filesystem and an in-memory SQLite database, optionally
// pre-populated with volumes.
//
// basePath is created on the in-memory filesystem and used as the
// controller's base path. For each entry in volumes, the volume's Path
// directory is created and a matching record is stored in the database.
// volumes may be nil or empty; a nil *StorageVolume entry is rejected with an
// error instead of causing a nil-pointer panic.
//
// It returns the assembled kit, or an error describing the first setup step
// that failed.
func SetupVolumeControllerTestKit(basePath string, volumes map[string]*types.StorageVolume) (*VolumeControllerTestKit, error) {
	// Initialize telemetry in test mode, replacing the package-global st.
	// This is done here in addition to basic_controller_test.go because the
	// s3 tests depend on the basic controller (which in turn depends on
	// telemetry) and build their controller through this helper.
	st = telemetry.NewTelemetry(nil, nil, true)

	db, err := gorm.Open(
		sqlite.Open("file:?mode=memory&cache=shared"),
		&gorm.Config{Logger: logger.Default.LogMode(logger.Silent)},
	)
	if err != nil {
		return nil, fmt.Errorf("failed to create in-memory mock database: %w", err)
	}

	if err = db.AutoMigrate(&types.StorageVolume{}); err != nil {
		return nil, fmt.Errorf("failed to automigrate: %w", err)
	}

	fs := afero.NewMemMapFs()
	if err = fs.MkdirAll(basePath, 0o755); err != nil {
		return nil, fmt.Errorf("failed to create base path: %w", err)
	}

	repo := rGorm.NewStorageVolume(db)
	vc, err := NewDefaultVolumeController(repo, basePath, fs)
	if err != nil {
		return nil, fmt.Errorf("failed to create volume controller: %w", err)
	}

	for name, vol := range volumes {
		if vol == nil {
			return nil, fmt.Errorf("volume %q is nil", name)
		}

		// create root volume dir
		if err = fs.MkdirAll(vol.Path, 0o755); err != nil {
			return nil, fmt.Errorf("failed to create volume dir: %w", err)
		}

		// create volume record in db
		if _, err = repo.Create(context.Background(), *vol); err != nil {
			return nil, fmt.Errorf("failed to create volume record: %w", err)
		}
	}

	return &VolumeControllerTestKit{
		BasicVolController: vc,
		Fs:                 fs,
		Volumes:            volumes,
	}, nil
}
package s3
import (
"context"
"fmt"
"os"
"path/filepath"
"strings"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/spf13/afero"
"gitlab.com/nunet/device-management-service/storage"
basicController "gitlab.com/nunet/device-management-service/storage/basic_controller"
"gitlab.com/nunet/device-management-service/types"
)
// Download fetches files from a given S3 bucket into a newly created volume.
// The key may name a single object, a "directory" ending with `/`, or a prefix
// containing a wildcard (`*`). Objects with the `application/x-directory`
// content type are not handled.
//
// Warning: the implementation should rely on the FS provided by the volume controller,
// be careful if managing files with `os` (the volume controller might be
// using an in-memory one).
//
// On success the returned volume has been locked through the volume
// controller. Errors from the underlying steps are wrapped with %w so callers
// can inspect them with errors.Is / errors.As.
//
// NOTE(review): if a step after CreateVolume fails, the partially filled
// volume is not removed here — confirm whether callers are expected to clean it up.
func (s *Storage) Download(ctx context.Context, sourceSpecs *types.SpecConfig) (types.StorageVolume, error) {
	ctx, cancel := st.SpanContext(ctx, "s3", "s3_download_duration", "opentelemetry", "log")
	defer cancel()

	source, err := DecodeInputSpec(sourceSpecs)
	if err != nil {
		ctx = context.WithValue(ctx, errorKey, err.Error())
		st.Error(ctx, "s3_download_failure", nil)
		return types.StorageVolume{}, err
	}

	storageVol, err := s.volController.CreateVolume(storage.VolumeSourceS3)
	if err != nil {
		ctx = context.WithValue(ctx, errorKey, err.Error())
		st.Error(ctx, "s3_volume_create_failure", nil)
		return types.StorageVolume{}, fmt.Errorf("failed to create storage volume: %w", err)
	}

	resolvedObjects, err := resolveStorageKey(ctx, s.Client, &source)
	if err != nil {
		ctx = context.WithValue(ctx, errorKey, err.Error())
		st.Error(ctx, "s3_resolve_key_failure", nil)
		return types.StorageVolume{}, fmt.Errorf("failed to resolve storage key: %w", err)
	}

	for _, resolvedObject := range resolvedObjects {
		if err = s.downloadObject(ctx, &source, resolvedObject, storageVol.Path); err != nil {
			ctx = context.WithValue(ctx, errorKey, err.Error())
			st.Error(ctx, "s3_download_object_failure", nil)
			return types.StorageVolume{}, fmt.Errorf("failed to download s3 object: %w", err)
		}
	}

	// after data is filled within the volume, we have to lock it
	if err = s.volController.LockVolume(storageVol.Path); err != nil {
		ctx = context.WithValue(ctx, errorKey, err.Error())
		st.Error(ctx, "s3_volume_lock_failure", nil)
		return types.StorageVolume{}, fmt.Errorf("failed to lock storage volume: %w", err)
	}

	st.Info(ctx, "s3_download_success", nil)
	return storageVol, nil
}
// downloadObject downloads one resolved S3 object into volPath, mirroring the
// object key as a relative path inside the volume.
//
// Directory markers (object.isDir) only create the directory. For regular
// objects the parent directories are created and the object body is written
// to the file, truncating any previous content.
//
// Files are written on the filesystem owned by the volume controller, which
// must be a *basicController.BasicVolumeController (possibly in-memory);
// other controller implementations are rejected with an error.
//
// Object keys come from the bucket and are treated as untrusted: keys that
// would resolve outside volPath (e.g. via "..") are rejected.
func (s *Storage) downloadObject(ctx context.Context, source *InputSource, object s3Object, volPath string) error {
	outputPath := filepath.Join(volPath, *object.key)

	// Path-traversal guard: the joined path must stay inside the volume.
	rel, err := filepath.Rel(volPath, outputPath)
	if err != nil || rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) {
		err = fmt.Errorf("object key %q resolves outside of volume path", *object.key)
		ctx = context.WithValue(ctx, errorKey, err.Error())
		st.Error(ctx, "s3_invalid_object_key", nil)
		return err
	}

	// use the same file system instance used by the Volume Controller
	var fs afero.Fs
	if basicVolController, ok := s.volController.(*basicController.BasicVolumeController); ok {
		fs = basicVolController.FS
	}
	if fs == nil {
		err = fmt.Errorf("no filesystem available for volume controller %T", s.volController)
		ctx = context.WithValue(ctx, errorKey, err.Error())
		st.Error(ctx, "s3_create_directory_failure", nil)
		return err
	}

	// For a directory marker the path itself is the directory; for a regular
	// object only its parent must exist (creating outputPath as a directory
	// would make the OpenFile below fail on a real filesystem).
	dir := filepath.Dir(outputPath)
	if object.isDir {
		dir = outputPath
	}
	if err = fs.MkdirAll(dir, 0o755); err != nil {
		ctx = context.WithValue(ctx, errorKey, err.Error())
		st.Error(ctx, "s3_create_directory_failure", nil)
		return fmt.Errorf("failed to create directory: %w", err)
	}
	if object.isDir {
		// if object is a directory, we don't need to download it (just create the dir)
		return nil
	}

	// O_TRUNC so re-downloading over a larger existing file leaves no stale bytes.
	outputFile, err := fs.OpenFile(outputPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0o755)
	if err != nil {
		ctx = context.WithValue(ctx, errorKey, err.Error())
		st.Error(ctx, "s3_open_file_failure", nil)
		return err
	}

	zlog.Sugar().Debugf("Downloading s3 object %s to %s", *object.key, outputPath)
	_, err = s.downloader.Download(ctx, outputFile, &s3.GetObjectInput{
		Bucket:  aws.String(source.Bucket),
		Key:     object.key,
		IfMatch: object.eTag,
	})
	// Close explicitly: for a written file a failed Close can mean lost data.
	closeErr := outputFile.Close()
	if err != nil {
		ctx = context.WithValue(ctx, errorKey, err.Error())
		st.Error(ctx, "s3_download_failure", nil)
		return fmt.Errorf("failed to download file: %w", err)
	}
	if closeErr != nil {
		ctx = context.WithValue(ctx, errorKey, closeErr.Error())
		st.Error(ctx, "s3_close_file_failure", nil)
		return fmt.Errorf("failed to close file: %w", closeErr)
	}

	st.Info(ctx, "s3_download_object_success", nil)
	return nil
}
// resolveStorageKey expands source.Key into the list of S3 objects to fetch.
//
// An empty key is an error. A key ending in "/" or containing "*" is treated
// as a prefix and listed; any other key is looked up as a single object.
func resolveStorageKey(ctx context.Context, client *s3.Client, source *InputSource) ([]s3Object, error) {
	switch key := source.Key; {
	case key == "":
		missing := fmt.Errorf("key is required")
		ctx = context.WithValue(ctx, errorKey, missing.Error())
		st.Error(ctx, "s3_resolve_key_failure", nil)
		return nil, missing
	case strings.HasSuffix(key, "/") || strings.Contains(key, "*"):
		// Folder-like or wildcard key: may match many objects.
		return resolveObjectsWithPrefix(ctx, client, source)
	default:
		return resolveSingleObject(ctx, client, source)
	}
}
// resolveSingleObject looks up exactly one object (source.Key, whitespace
// trimmed via sanitizeKey) with HeadObject and returns it with its ETag and size.
//
// Objects whose content type starts with "application/x-directory" are
// rejected as not yet supported. The returned key is the same sanitized key
// that HeadObject validated, so the later download targets the same object.
func resolveSingleObject(ctx context.Context, client *s3.Client, source *InputSource) ([]s3Object, error) {
	key := sanitizeKey(source.Key)

	headObjectOut, err := client.HeadObject(ctx, &s3.HeadObjectInput{
		Bucket: aws.String(source.Bucket),
		Key:    aws.String(key),
	})
	if err != nil {
		ctx = context.WithValue(ctx, errorKey, err.Error())
		st.Error(ctx, "s3_head_object_failure", nil)
		return []s3Object{}, fmt.Errorf("failed to retrieve object metadata: %w", err)
	}

	// ContentType and ContentLength are optional pointers in the response;
	// aws.ToString / aws.ToInt64 map nil to the zero value instead of panicking.
	if strings.HasPrefix(aws.ToString(headObjectOut.ContentType), "application/x-directory") {
		err := fmt.Errorf("x-directory is not yet handled")
		ctx = context.WithValue(ctx, errorKey, err.Error())
		st.Error(ctx, "s3_directory_handling_failure", nil)
		return []s3Object{}, err
	}

	return []s3Object{
		{
			key:  aws.String(key),
			eTag: headObjectOut.ETag,
			size: aws.ToInt64(headObjectOut.ContentLength),
		},
	}, nil
}
// resolveObjectsWithPrefix lists every object in source.Bucket whose key
// starts with the sanitized source.Key (surrounding whitespace and one
// trailing "*" removed), following all result pages.
//
// Keys ending in "/" are reported as directories. Listed objects carry no
// ETag. Entries without a key are skipped.
func resolveObjectsWithPrefix(ctx context.Context, client *s3.Client, source *InputSource) ([]s3Object, error) {
	prefix := sanitizeKey(source.Key)

	// List objects with the given prefix
	paginator := s3.NewListObjectsV2Paginator(client, &s3.ListObjectsV2Input{
		Bucket: aws.String(source.Bucket),
		Prefix: aws.String(prefix),
	})

	var objects []s3Object
	for paginator.HasMorePages() {
		page, err := paginator.NextPage(ctx)
		if err != nil {
			ctx = context.WithValue(ctx, errorKey, err.Error())
			st.Error(ctx, "s3_list_objects_failure", nil)
			return nil, fmt.Errorf("failed to list objects: %w", err)
		}
		for _, obj := range page.Contents {
			if obj.Key == nil {
				continue
			}
			key := *obj.Key
			objects = append(objects, s3Object{
				key: aws.String(key),
				// Size is a pointer in the SDK response; treat nil as 0.
				size:  aws.ToInt64(obj.Size),
				isDir: strings.HasSuffix(key, "/"),
			})
		}
	}

	st.Info(ctx, "s3_resolve_objects_with_prefix_success", nil)
	return objects, nil
}
package s3
import (
"context"
"strings"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
)
// GetAWSDefaultConfig loads the AWS SDK configuration from the standard
// sources (environment variables, shared config and shared credentials files)
// without any custom load options.
//
// On failure the zero aws.Config is returned together with the loader's error.
func GetAWSDefaultConfig() (aws.Config, error) {
	spanCtx, endSpan := st.SpanContext(context.Background(), "s3", "get_aws_default_config_duration", "opentelemetry", "log")
	defer endSpan()

	loaded, loadErr := config.LoadDefaultConfig(spanCtx)
	if loadErr == nil {
		st.Info(spanCtx, "get_aws_default_config_success", nil)
		return loaded, nil
	}

	st.Error(context.WithValue(spanCtx, errorKey, loadErr.Error()), "get_aws_default_config_failure", nil)
	return aws.Config{}, loadErr
}
// hasValidCredentials reports whether config yields AWS credentials that
// include access keys.
//
// It returns false when the config has no credentials provider, when the
// provider fails to retrieve credentials, or when the retrieved credentials
// carry no keys. Each failure is recorded as a telemetry error.
func hasValidCredentials(config aws.Config) bool {
	ctx, cancel := st.SpanContext(context.Background(), "s3", "has_valid_credentials_duration", "opentelemetry", "log")
	defer cancel()

	// Credentials is an interface; calling Retrieve on a nil provider would panic.
	if config.Credentials == nil {
		ctx = context.WithValue(ctx, errorKey, "no credentials provider configured")
		st.Error(ctx, "has_valid_credentials_failure", nil)
		return false
	}

	credentials, err := config.Credentials.Retrieve(ctx)
	if err != nil {
		ctx = context.WithValue(ctx, errorKey, err.Error())
		st.Error(ctx, "has_valid_credentials_failure", nil)
		return false
	}
	if !credentials.HasKeys() {
		st.Error(ctx, "has_valid_credentials_failure_no_keys", nil)
		return false
	}

	st.Info(ctx, "has_valid_credentials_success", nil)
	return true
}
// sanitizeKey normalizes an S3 key: it strips leading and trailing whitespace
// and then removes a single trailing "*" wildcard, if present.
func sanitizeKey(key string) string {
	spanCtx, endSpan := st.SpanContext(context.Background(), "s3", "sanitize_key_duration", "opentelemetry", "log")
	defer endSpan()

	trimmed := strings.TrimSpace(key)
	cleaned := strings.TrimSuffix(trimmed, "*")

	st.Info(context.WithValue(spanCtx, sanitizedKeyContext, cleaned), "sanitize_key_success", nil)
	return cleaned
}
package s3
import (
"github.com/uptrace/opentelemetry-go-extra/otelzap"
"gitlab.com/nunet/device-management-service/telemetry"
"gitlab.com/nunet/device-management-service/telemetry/logger"
)
// Package-level logging and telemetry handles.
//
// zlog is this package's structured logger; it is assigned in init below.
// st is the telemetry handle used by every function in this package. It is
// created with testMode=true, which per telemetry.NewTelemetry makes the
// telemetry operations no-ops.
// NOTE(review): nothing visible in this package replaces st, so telemetry
// from the s3 package may be disabled outside tests — confirm.
var (
zlog *otelzap.Logger
st = telemetry.NewTelemetry(nil, nil, true)
)
// Context keys used for tracing.
//
// contextKey is a private named type so that values stored with
// context.WithValue by this package cannot collide with keys of other packages.
type contextKey string
const (
pathKey contextKey = "path"
SourceSpecsKey contextKey = "sourceSpecs"
errorKey contextKey = "error"
OutputPathKey contextKey = "outputPath"
bucketKey contextKey = "bucket"
S3KeyKey contextKey = "key"
ContentLength contextKey = "content_length"
FilePathKey contextKey = "file_path"
VolumePathKey contextKey = "volume_path"
sanitizedKeyContext contextKey = "sanitized_key"
)
// init creates the package logger tagged with the package name "s3".
// NOTE(review): this runs during package initialization; logger.OtelZapLogger
// wraps zap.L(), which is zap's no-op logger until zap.ReplaceGlobals is
// called — confirm that logging from this package is not silently dropped.
func init() {
zlog = logger.OtelZapLogger("s3")
}
package s3
import (
"context"
"fmt"
"github.com/aws/aws-sdk-go-v2/aws"
s3Manager "github.com/aws/aws-sdk-go-v2/feature/s3/manager"
"github.com/aws/aws-sdk-go-v2/service/s3"
"gitlab.com/nunet/device-management-service/storage"
"gitlab.com/nunet/device-management-service/types"
)
// Storage is an S3-backed storage provider. It embeds the AWS SDK S3 client
// and relies on a VolumeController to create and manage the local volumes
// that downloads are written into and uploads are read from.
type Storage struct {
*s3.Client
// volController creates, locks and (for the basic controller) provides the filesystem of volumes.
volController storage.VolumeController
// downloader performs object downloads into the volume files.
downloader *s3Manager.Downloader
// uploader performs object uploads from the volume files.
uploader *s3Manager.Uploader
}
// s3Object describes one object resolved from a bucket/key specification.
type s3Object struct {
// key is the object key within the bucket.
key *string
// eTag is the object's ETag from HeadObject, sent as IfMatch on download;
// it is left nil for objects resolved by prefix listing.
eTag *string
// size is the object size in bytes as reported by S3.
size int64
// isDir marks "directory" keys ending in "/".
isDir bool
}
// NewClient builds a Storage around an S3 SDK client created from config,
// together with a downloader and an uploader that share that client.
//
// volController manages the volumes that downloads fill and uploads read.
// An error is returned when config does not provide usable credentials.
func NewClient(config aws.Config, volController storage.VolumeController) (*Storage, error) {
	spanCtx, endSpan := st.SpanContext(context.Background(), "s3", "new_client_duration", "opentelemetry", "log")
	defer endSpan()

	if ok := hasValidCredentials(config); !ok {
		credErr := fmt.Errorf("invalid credentials")
		st.Error(context.WithValue(spanCtx, errorKey, credErr.Error()), "new_client_invalid_credentials", nil)
		return nil, credErr
	}

	client := s3.NewFromConfig(config)
	s3Storage := &Storage{
		Client:        client,
		volController: volController,
		downloader:    s3Manager.NewDownloader(client),
		uploader:      s3Manager.NewUploader(client),
	}

	st.Info(spanCtx, "new_client_success", nil)
	return s3Storage, nil
}
// Size returns the size in bytes of the single S3 object described by
// source, using a HeadObject request.
//
// A missing ContentLength in the response is reported as 0. Errors are
// wrapped with %w so callers can inspect the underlying cause.
func (s *Storage) Size(ctx context.Context, source *types.SpecConfig) (uint64, error) {
	ctx, cancel := st.SpanContext(ctx, "s3", "s3_size_duration", "opentelemetry", "log")
	defer cancel()

	inputSource, err := DecodeInputSpec(source)
	if err != nil {
		ctx = context.WithValue(ctx, errorKey, err.Error())
		st.Error(ctx, "s3_size_decode_input_spec_failure", nil)
		return 0, fmt.Errorf("failed to decode input spec: %w", err)
	}

	output, err := s.HeadObject(ctx, &s3.HeadObjectInput{
		Bucket: aws.String(inputSource.Bucket),
		Key:    aws.String(inputSource.Key),
	})
	if err != nil {
		ctx = context.WithValue(ctx, errorKey, err.Error())
		st.Error(ctx, "s3_size_head_object_failure", nil)
		return 0, fmt.Errorf("failed to get object size: %w", err)
	}

	// ContentLength is an optional pointer; aws.ToInt64 maps nil to 0
	// instead of panicking on dereference.
	length := aws.ToInt64(output.ContentLength)
	if length < 0 {
		length = 0
	}

	st.Info(ctx, "s3_size_success", nil)
	return uint64(length), nil
}
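// Compile-time interface check (currently disabled). The concrete type in this
// package is Storage, so if storage.StorageProvider exists it would read:
// var _ storage.StorageProvider = (*Storage)(nil)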
package s3
import (
"context"
"fmt"
"github.com/fatih/structs"
"github.com/mitchellh/mapstructure"
"gitlab.com/nunet/device-management-service/types"
)
type InputSource struct {
Bucket string
Key string
Filter string
Region string
Endpoint string
}
func (s InputSource) Validate() error {
if s.Bucket == "" {
err := fmt.Errorf("invalid s3 storage params: bucket cannot be empty")
st.Error(context.Background(), "s3_input_source_validation_failure", nil)
return err
}
return nil
}
func (s InputSource) ToMap() map[string]interface{} {
return structs.Map(s)
}
// DecodeInputSpec converts a generic storage SpecConfig into an InputSource.
//
// The spec must be non-nil, of type types.StorageProviderS3, and have
// non-nil Params that decode (via mapstructure) into an InputSource passing
// Validate. On any error the zero InputSource is returned, never a partially
// decoded one.
func DecodeInputSpec(spec *types.SpecConfig) (InputSource, error) {
	ctx, cancel := st.SpanContext(context.Background(), "s3", "decode_input_spec_duration", "opentelemetry", "log")
	defer cancel()

	if spec == nil {
		err := fmt.Errorf("invalid storage input source spec. cannot be nil")
		ctx = context.WithValue(ctx, errorKey, err.Error())
		st.Error(ctx, "decode_input_spec_nil_spec_failure", nil)
		return InputSource{}, err
	}

	if !spec.IsType(types.StorageProviderS3) {
		err := fmt.Errorf("invalid storage source type. Expected %s but received %s", types.StorageProviderS3, spec.Type)
		ctx = context.WithValue(ctx, errorKey, err.Error())
		st.Error(ctx, "decode_input_spec_invalid_type_failure", nil)
		return InputSource{}, err
	}

	inputParams := spec.Params
	if inputParams == nil {
		err := fmt.Errorf("invalid storage input source params. cannot be nil")
		ctx = context.WithValue(ctx, errorKey, err.Error())
		st.Error(ctx, "decode_input_spec_nil_params_failure", nil)
		return InputSource{}, err
	}

	var c InputSource
	if err := mapstructure.Decode(inputParams, &c); err != nil {
		ctx = context.WithValue(ctx, errorKey, err.Error())
		st.Error(ctx, "decode_input_spec_decode_failure", nil)
		return InputSource{}, err
	}

	if err := c.Validate(); err != nil {
		ctx = context.WithValue(ctx, errorKey, err.Error())
		st.Error(ctx, "decode_input_spec_validation_failure", nil)
		return InputSource{}, err
	}

	st.Info(ctx, "decode_input_spec_success", nil)
	return c, nil
}
package s3
import (
"context"
"fmt"
"os"
"path/filepath"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/spf13/afero"
basicController "gitlab.com/nunet/device-management-service/storage/basic_controller"
"gitlab.com/nunet/device-management-service/types"
)
// Upload uploads all regular files (recursively) from a local volume to an S3
// bucket. The object key of each file is the sanitized destination key joined
// with the file's path relative to vol.Path, always using "/" as separator.
// Directories themselves are not uploaded.
//
// Warning: the implementation should rely on the FS provided by the volume controller,
// be careful if managing files with `os` (the volume controller might be
// using an in-memory one). Volume controllers other than
// *basicController.BasicVolumeController are rejected with an error.
//
// Uploading stops at the first failure, so some files may already have been
// uploaded when an error is returned. Errors are wrapped with %w.
func (s *Storage) Upload(ctx context.Context, vol types.StorageVolume, destinationSpecs *types.SpecConfig) error {
	ctx, cancel := st.SpanContext(ctx, "s3", "s3_upload_duration", "opentelemetry", "log")
	defer cancel()

	target, err := DecodeInputSpec(destinationSpecs)
	if err != nil {
		ctx = context.WithValue(ctx, errorKey, err.Error())
		st.Error(ctx, "s3_upload_decode_spec_failure", nil)
		return fmt.Errorf("failed to decode input spec: %w", err)
	}
	sanitizedKey := sanitizeKey(target.Key)

	// set file system to act upon based on the volume controller implementation
	var fs afero.Fs
	if basicVolController, ok := s.volController.(*basicController.BasicVolumeController); ok {
		fs = basicVolController.FS
	}
	if fs == nil {
		err = fmt.Errorf("no filesystem available for volume controller %T", s.volController)
		ctx = context.WithValue(ctx, errorKey, err.Error())
		st.Error(ctx, "s3_upload_failure", nil)
		return err
	}

	zlog.Sugar().Debugf("Uploading files from %s to s3://%s/%s", vol.Path, target.Bucket, sanitizedKey)
	err = afero.Walk(fs, vol.Path, func(filePath string, info os.FileInfo, walkErr error) error {
		if walkErr != nil {
			// Derive a local context: the shared ctx must not accumulate
			// values from one file to the next.
			st.Error(context.WithValue(ctx, errorKey, walkErr.Error()), "s3_upload_walk_failure", nil)
			return walkErr
		}
		// Skip directories
		if info.IsDir() {
			return nil
		}
		return s.uploadVolumeFile(ctx, fs, target.Bucket, sanitizedKey, vol.Path, filePath)
	})
	if err != nil {
		ctx = context.WithValue(ctx, errorKey, err.Error())
		st.Error(ctx, "s3_upload_failure", nil)
		return fmt.Errorf("upload failed. It's possible that some files were uploaded; Error: %w", err)
	}

	st.Info(ctx, "s3_upload_success", nil)
	return nil
}

// uploadVolumeFile uploads the single file at filePath (located under volPath
// on fs) to bucket, under keyPrefix joined with the file's relative path.
func (s *Storage) uploadVolumeFile(ctx context.Context, fs afero.Fs, bucket, keyPrefix, volPath, filePath string) error {
	relPath, err := filepath.Rel(volPath, filePath)
	if err != nil {
		ctx = context.WithValue(ctx, errorKey, err.Error())
		st.Error(ctx, "s3_upload_relative_path_failure", nil)
		return fmt.Errorf("failed to get relative path: %w", err)
	}

	// S3 keys always use "/" whatever the host path separator is.
	s3Key := filepath.ToSlash(filepath.Join(keyPrefix, relPath))

	file, err := fs.Open(filePath)
	if err != nil {
		ctx = context.WithValue(ctx, errorKey, err.Error())
		st.Error(ctx, "s3_upload_open_file_failure", nil)
		return fmt.Errorf("failed to open file: %w", err)
	}
	// Read-only handle: a Close error cannot lose data, so it is ignored.
	defer file.Close()

	// ctx is this function's own copy, so these values stay scoped to this file.
	ctx = context.WithValue(ctx, FilePathKey, filePath)
	ctx = context.WithValue(ctx, S3KeyKey, s3Key)

	zlog.Sugar().Debugf("Uploading %s to s3://%s/%s", filePath, bucket, s3Key)
	_, err = s.uploader.Upload(ctx, &s3.PutObjectInput{
		Bucket: aws.String(bucket),
		Key:    aws.String(s3Key),
		Body:   file,
	})
	if err != nil {
		ctx = context.WithValue(ctx, errorKey, err.Error())
		st.Error(ctx, "s3_upload_put_object_failure", nil)
		return fmt.Errorf("failed to upload file to S3: %w", err)
	}

	st.Info(ctx, "s3_upload_file_success", nil)
	return nil
}
package telemetry
import (
"os"
"sync"
"time"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"gitlab.com/nunet/device-management-service/internal/config"
"gitlab.com/nunet/device-management-service/types"
)
// Package-level state of the global telemetry setup.
var (
// once guards the one-time initialization performed by InitGlobalTelemetry.
once sync.Once
// instance is the global Telemetry returned by GetTelemetry; nil until initialized.
instance *Telemetry
// logLevel is the observability level parsed from configuration during
// initialization; it falls back to INFO when the configured value is invalid.
logLevel types.ObservabilityLevel
// zapLogger is the zap logger built during initialization and installed
// with zap.ReplaceGlobals.
zapLogger *zap.Logger
)
// globalTelemetryInitErr records the outcome of the one-time initialization
// so that every call to InitGlobalTelemetry reports it, not only the first.
var globalTelemetryInitErr error

// InitGlobalTelemetry initializes the global telemetry instance with
// configuration loaded from the configuration package.
//
// The work runs only once per process. It builds the zap logger and installs
// it as zap's global logger, parses the observability level (falling back to
// INFO), creates the log and OpenTelemetry collectors and initializes them,
// then starts the periodic flush. A collector that fails to initialize is
// logged and kept; it does not fail the call.
//
// It returns an error when the zap logger cannot be built; in that case the
// global instance stays nil. Later calls return the same result.
func InitGlobalTelemetry() error {
	once.Do(func() {
		// Initialize Zap logger
		zl, err := initZapLogger()
		if err != nil {
			// Report the failure to the caller instead of panicking.
			globalTelemetryInitErr = err
			return
		}
		zapLogger = zl
		zap.ReplaceGlobals(zapLogger)

		telemetryConfig := config.GetConfig().Telemetry

		level, err := types.ParseObservabilityLevel(telemetryConfig.ObservabilityLevel)
		if err != nil {
			zap.L().Warn("Invalid observability level, defaulting to INFO", zap.Error(err))
			level = types.INFO
		}
		logLevel = level

		instance = &Telemetry{
			config: &types.TelemetryConfig{
				ServiceName:        telemetryConfig.ServiceName,
				GlobalEndpoint:     telemetryConfig.GlobalEndpoint,
				ObservabilityLevel: telemetryConfig.ObservabilityLevel, // Assign the string value
				TelemetryMode:      telemetryConfig.TelemetryMode,
			},
		}

		opentelemetryCollector := NewOpenTelemetryCollector(instance.config, zap.L())
		logCollector := NewLogCollector(instance.config, zap.L())
		instance.collectors = map[string]Collector{
			logCollector.GetName():           logCollector,
			opentelemetryCollector.GetName(): opentelemetryCollector,
		}

		for name, collector := range instance.collectors {
			if err := collector.Initialize(); err != nil {
				zap.L().Error("Failed to initialize collector", zap.String("collector", name), zap.Error(err))
			}
		}

		// Start periodic flush after initializing collectors
		StartPeriodicFlush(5 * time.Minute)
	})
	return globalTelemetryInitErr
}
// initZapLogger builds the process zap logger.
//
// A development logger with colored, capitalized levels is used when the
// NUNET_DEBUG environment variable is set (to any value) or when debug mode
// is enabled in the general configuration; otherwise zap's production logger
// is built. The configuration is only consulted if the variable is absent.
func initZapLogger() (*zap.Logger, error) {
	_, debugFromEnv := os.LookupEnv("NUNET_DEBUG")
	if !debugFromEnv && !config.GetConfig().General.Debug {
		return zap.NewProduction()
	}

	devConfig := zap.NewDevelopmentConfig()
	devConfig.EncoderConfig.EncodeLevel = zapcore.CapitalColorLevelEncoder
	return devConfig.Build()
}
// StartPeriodicFlush starts a goroutine that flushes the global telemetry
// instance every interval.
//
// A non-positive interval disables the periodic flush (time.NewTicker would
// panic on it) and a warning is logged. Ticks that fire while the global
// instance is nil are skipped.
//
// NOTE(review): the goroutine and its ticker live for the rest of the
// process; there is no way to stop them.
func StartPeriodicFlush(interval time.Duration) {
	if interval <= 0 {
		zap.L().Warn("Invalid telemetry flush interval, periodic flush disabled", zap.Duration("interval", interval))
		return
	}

	ticker := time.NewTicker(interval)
	go func() {
		for range ticker.C {
			t := instance
			if t == nil {
				continue
			}
			zap.L().Info("Periodic flush started for telemetry")
			t.Flush()
		}
	}()
}
package telemetry
import (
"context"
"gitlab.com/nunet/device-management-service/types"
"go.uber.org/zap"
)
// LogCollector is a Collector that writes telemetry events to a zap logger.
// It does not support tracing.
type LogCollector struct {
	config *types.TelemetryConfig
	logger *zap.Logger
}

// NewLogCollector returns a LogCollector that writes through logger.
func NewLogCollector(config *types.TelemetryConfig, logger *zap.Logger) *LogCollector {
	collector := new(LogCollector)
	collector.config = config
	collector.logger = logger
	return collector
}

// Initialize announces the collector on its logger. It never fails.
func (c *LogCollector) Initialize() error {
	c.logger.Info("LogCollector initialized.")
	return nil
}

// HandleEvent logs event at the zap level matching its observability level.
// TRACE is logged at debug level; INFO and unknown levels at info level.
// A FATAL event is logged with zap's Fatal, which terminates the process.
func (c *LogCollector) HandleEvent(event types.Event) error {
	fields := []zap.Field{
		zap.Any("context", event.Context),
		zap.String("message", event.Message),
		zap.String("level", event.Level.String()),
		zap.Any("payload", event.Payload),
	}

	switch msg := event.Message; event.Level {
	case types.TRACE, types.DEBUG:
		c.logger.Debug(msg, fields...)
	case types.WARN:
		c.logger.Warn(msg, fields...)
	case types.ERROR:
		c.logger.Error(msg, fields...)
	case types.FATAL:
		c.logger.Fatal(msg, fields...)
	default: // types.INFO and anything unrecognized
		c.logger.Info(msg, fields...)
	}
	return nil
}

// Flush syncs the underlying zap logger and reports its error, if any.
func (c *LogCollector) Flush() error {
	return c.logger.Sync()
}

// Shutdown flushes the collector; there is nothing else to release.
func (c *LogCollector) Shutdown() error {
	return c.Flush()
}

// GetName returns the collector's registry name, "log".
func (c *LogCollector) GetName() string {
	return "log"
}

// SpanContext returns ctx unchanged with a cancel function that does nothing,
// because LogCollector does not trace.
func (c *LogCollector) SpanContext(ctx context.Context, _ string) (context.Context, context.CancelFunc) {
	noop := func() {}
	return ctx, noop
}

// Compile-time check that LogCollector implements Collector.
var _ Collector = (*LogCollector)(nil)
package logger
import (
"sync"
"github.com/uptrace/opentelemetry-go-extra/otelzap"
"go.uber.org/zap"
)
var (
// once guards the one-time creation of logger.
once sync.Once
// logger caches an otelzap logger.
logger *otelzap.Logger
)
// Logger wraps a *zap.Logger; instances are created with New.
type Logger struct {
*zap.Logger
}
// New returns a Logger derived from zap's global logger (zap.L()) with a
// "package" field set to pkg.
//
// It does not configure zap itself. Until zap.ReplaceGlobals has been
// called, zap.L() is zap's no-op logger.
func New(pkg string) *Logger {
	base := zap.L()
	tagged := base.With(zap.String("package", pkg))
	return &Logger{Logger: tagged}
}
// otelZapLoggers caches one otelzap logger per package name, guarded by
// otelZapLoggersMu.
var (
	otelZapLoggersMu sync.Mutex
	otelZapLoggers   = make(map[string]*otelzap.Logger)
)

// OtelZapLogger returns an otelzap logger whose "package" field is pkg.
//
// Loggers are created on first request for a given pkg and cached, so every
// package gets a logger tagged with its own name rather than the name of
// whichever package asked first. It is safe for concurrent use.
//
// NOTE(review): the wrapped zap logger is taken from zap.L() at the first
// call for pkg; if that happens before zap.ReplaceGlobals (e.g. from a
// package init), the cached logger keeps zap's no-op global logger.
func OtelZapLogger(pkg string) *otelzap.Logger {
	otelZapLoggersMu.Lock()
	defer otelZapLoggersMu.Unlock()

	if l, ok := otelZapLoggers[pkg]; ok {
		return l
	}
	l := otelzap.New(New(pkg).Logger)
	otelZapLoggers[pkg] = l
	return l
}
package telemetry
import (
"context"
"sync"
"gitlab.com/nunet/device-management-service/types"
)
// MockCollector is a mock implementation of the Collector interface that
// records every event and span request it receives. It is safe for
// concurrent use.
type MockCollector struct {
	mu          sync.Mutex
	events      []types.Event
	traces      []MockTrace
	initialized bool
	name        string
}

// MockTrace records one SpanContext call made on a MockCollector.
type MockTrace struct {
	SpanName   string
	Context    context.Context
	CancelFunc context.CancelFunc
}

// NewMockCollector creates a MockCollector with the given name and no
// recorded events or traces.
func NewMockCollector(name string) *MockCollector {
	return &MockCollector{
		events: []types.Event{},
		traces: []MockTrace{},
		name:   name,
	}
}

// Initialize marks the collector as initialized. It never fails.
func (m *MockCollector) Initialize() error {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.initialized = true
	return nil
}

// SpanContext derives a cancellable context from ctx, records it as a
// MockTrace for spanName and returns it with its cancel function.
func (m *MockCollector) SpanContext(ctx context.Context, spanName string) (context.Context, context.CancelFunc) {
	m.mu.Lock()
	defer m.mu.Unlock()
	mockCtx, cancel := context.WithCancel(ctx)
	m.traces = append(m.traces, MockTrace{
		SpanName:   spanName,
		Context:    mockCtx,
		CancelFunc: cancel,
	})
	return mockCtx, cancel
}

// HandleEvent records event. It never fails.
func (m *MockCollector) HandleEvent(event types.Event) error {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.events = append(m.events, event)
	return nil
}

// Flush is a no-op that always succeeds.
func (m *MockCollector) Flush() error {
	m.mu.Lock()
	defer m.mu.Unlock()
	return nil
}

// Shutdown is a no-op that always succeeds.
func (m *MockCollector) Shutdown() error {
	m.mu.Lock()
	defer m.mu.Unlock()
	return nil
}

// GetName returns the name of the mock collector.
func (m *MockCollector) GetName() string {
	return m.name
}

// GetTraces returns a copy of the recorded traces, so callers can inspect it
// without racing later SpanContext calls.
func (m *MockCollector) GetTraces() []MockTrace {
	m.mu.Lock()
	defer m.mu.Unlock()
	out := make([]MockTrace, len(m.traces))
	copy(out, m.traces)
	return out
}

// GetEvents returns a copy of the recorded events, so callers can inspect it
// without racing later HandleEvent calls.
func (m *MockCollector) GetEvents() []types.Event {
	m.mu.Lock()
	defer m.mu.Unlock()
	out := make([]types.Event, len(m.events))
	copy(out, m.events)
	return out
}

// Reset clears all recorded events and traces.
func (m *MockCollector) Reset() {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.events = []types.Event{}
	m.traces = []MockTrace{}
}

// AssertInitialized reports whether Initialize has been called.
func (m *MockCollector) AssertInitialized() bool {
	m.mu.Lock()
	defer m.mu.Unlock()
	return m.initialized
}

// MockTelemetry is a mock implementation of the Telemetry system that
// dispatches to registered MockCollectors. Access to the collector registry
// is guarded by mu, so collectors may be added concurrently with logging.
type MockTelemetry struct {
	Telemetry
	mu         sync.Mutex
	collectors map[string]*MockCollector
}

// NewMockTelemetry creates a MockTelemetry with the given config and no collectors.
func NewMockTelemetry(config *types.TelemetryConfig) *MockTelemetry {
	return &MockTelemetry{
		Telemetry: Telemetry{
			config:     config,
			collectors: make(map[string]Collector),
		},
		collectors: make(map[string]*MockCollector),
	}
}

// AddCollector registers collector under its name, replacing any previous
// collector with the same name.
func (m *MockTelemetry) AddCollector(collector *MockCollector) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.collectors[collector.GetName()] = collector
}

// snapshotCollectors returns the registered collectors, read under the lock,
// so callers can iterate them without holding m.mu.
func (m *MockTelemetry) snapshotCollectors() []*MockCollector {
	m.mu.Lock()
	defer m.mu.Unlock()
	out := make([]*MockCollector, 0, len(m.collectors))
	for _, collector := range m.collectors {
		out = append(out, collector)
	}
	return out
}

// SpanContext opens a span named span on each named collector that is
// registered, chaining their contexts. Unknown names are ignored. The
// returned cancel function cancels every opened span.
func (m *MockTelemetry) SpanContext(ctx context.Context, _ string, span string, collectors ...string) (context.Context, context.CancelFunc) {
	var cancelFuncs []context.CancelFunc
	for _, collectorName := range collectors {
		if collector := m.GetCollector(collectorName); collector != nil {
			mockCtx, cancel := collector.SpanContext(ctx, span)
			cancelFuncs = append(cancelFuncs, cancel)
			ctx = mockCtx
		}
	}
	cancel := func() {
		for _, cancelFunc := range cancelFuncs {
			cancelFunc()
		}
	}
	return ctx, cancel
}

// Trace records a TRACE event in all registered collectors.
func (m *MockTelemetry) Trace(ctx context.Context, message string, payload map[string]interface{}) {
	m.logEvent(ctx, types.TRACE, message, payload)
}

// Debug records a DEBUG event in all registered collectors.
func (m *MockTelemetry) Debug(ctx context.Context, message string, payload map[string]interface{}) {
	m.logEvent(ctx, types.DEBUG, message, payload)
}

// Info records an INFO event in all registered collectors.
func (m *MockTelemetry) Info(ctx context.Context, message string, payload map[string]interface{}) {
	m.logEvent(ctx, types.INFO, message, payload)
}

// Warn records a WARN event in all registered collectors.
func (m *MockTelemetry) Warn(ctx context.Context, message string, payload map[string]interface{}) {
	m.logEvent(ctx, types.WARN, message, payload)
}

// Error records an ERROR event in all registered collectors.
func (m *MockTelemetry) Error(ctx context.Context, message string, payload map[string]interface{}) {
	m.logEvent(ctx, types.ERROR, message, payload)
}

// Fatal records a FATAL event in all registered collectors (it does not exit).
func (m *MockTelemetry) Fatal(ctx context.Context, message string, payload map[string]interface{}) {
	m.logEvent(ctx, types.FATAL, message, payload)
}

// logEvent builds an event and hands it to every registered collector.
func (m *MockTelemetry) logEvent(ctx context.Context, level types.ObservabilityLevel, message string, payload map[string]interface{}) {
	event := types.Event{
		Context: ctx,
		Level:   level,
		Message: message,
		Payload: payload,
	}
	for _, collector := range m.snapshotCollectors() {
		_ = collector.HandleEvent(event) // MockCollector.HandleEvent never fails
	}
}

// GetCollector returns the registered collector with the given name, or nil.
func (m *MockTelemetry) GetCollector(name string) *MockCollector {
	m.mu.Lock()
	defer m.mu.Unlock()
	return m.collectors[name]
}

// Reset clears all recorded data in all collectors.
func (m *MockTelemetry) Reset() {
	m.mu.Lock()
	defer m.mu.Unlock()
	for _, collector := range m.collectors {
		collector.Reset()
	}
}

// Flush flushes every registered collector.
func (m *MockTelemetry) Flush() {
	for _, collector := range m.snapshotCollectors() {
		_ = collector.Flush() // MockCollector.Flush never fails
	}
}

// Shutdown flushes and then shuts down every registered collector.
func (m *MockTelemetry) Shutdown() {
	m.Flush()
	for _, collector := range m.snapshotCollectors() {
		_ = collector.Shutdown() // MockCollector.Shutdown never fails
	}
}
package telemetry
import (
	"context"
	"fmt"

	"gitlab.com/nunet/device-management-service/types"
	"go.opentelemetry.io/otel"
	"go.opentelemetry.io/otel/attribute"
	"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp"
	"go.opentelemetry.io/otel/sdk/resource"
	sdktrace "go.opentelemetry.io/otel/sdk/trace"
	semconv "go.opentelemetry.io/otel/semconv/v1.4.0"
	"go.uber.org/zap"
)
type OpenTelemetryCollector struct {
config *types.TelemetryConfig
logger *zap.Logger
tracerProvider *sdktrace.TracerProvider
}
func NewOpenTelemetryCollector(config *types.TelemetryConfig, logger *zap.Logger) *OpenTelemetryCollector {
return &OpenTelemetryCollector{
config: config,
logger: logger,
}
}
func (c *OpenTelemetryCollector) Initialize() error {
c.logger.Info("Initializing OpenTelemetry HTTP trace exporter",
zap.String("endpoint", c.config.GlobalEndpoint),
)
exp, err := otlptracehttp.New(context.Background(),
otlptracehttp.WithEndpoint(c.config.GlobalEndpoint),
otlptracehttp.WithInsecure(),
)
if err != nil {
c.logger.Error("Failed to create HTTP trace exporter", zap.Error(err))
return err
}
res, err := resource.New(context.Background(),
resource.WithAttributes(
semconv.ServiceNameKey.String(c.config.ServiceName),
),
)
if err != nil {
c.logger.Error("Failed to create resource", zap.Error(err))
return err
}
c.tracerProvider = sdktrace.NewTracerProvider(
sdktrace.WithBatcher(exp),
sdktrace.WithResource(res),
)
otel.SetTracerProvider(c.tracerProvider)
c.logger.Info("OpenTelemetryCollector initialized.")
return nil
}
func (c *OpenTelemetryCollector) HandleEvent(event types.Event) error {
fields := []attribute.KeyValue{
attribute.String("message", event.Message),
attribute.String("level", event.Level.String()),
}
for key, value := range event.Payload {
fields = append(fields, attribute.String(key, value.(string)))
}
c.logger.Info("Handling event",
zap.String("message", event.Message),
zap.String("level", event.Level.String()),
zap.Any("context", event.Context),
zap.Any("payload", event.Payload),
)
// Fetch tracer name from context, or default to "otel-tracer"
tracerName, ok := event.Context.Value(tracerNameKey).(string)
if !ok {
tracerName = "otel-tracer"
}
tracer := c.tracerProvider.Tracer(tracerName)
ctx := context.Background()
_, span := tracer.Start(ctx, event.Message)
span.SetAttributes(fields...)
span.End()
c.logger.Info("Event sent to OpenTelemetry",
zap.String("message", event.Message),
zap.String("level", event.Level.String()),
zap.Any("context", event.Context),
zap.Any("payload", event.Payload),
)
return nil
}
func (c *OpenTelemetryCollector) Flush() error {
if c.tracerProvider == nil {
c.logger.Warn("TracerProvider is nil, skipping flush")
return nil
}
c.logger.Info("Flushing tracer provider")
if err := c.tracerProvider.ForceFlush(context.Background()); err != nil {
c.logger.Error("Error flushing tracer provider", zap.Error(err))
return err
}
c.logger.Info("Collector flushed successfully")
return nil
}
// Shutdown shuts down the tracer provider. When no tracer provider exists a
// warning is logged and nil is returned; a Shutdown failure is logged and
// returned to the caller.
func (c *OpenTelemetryCollector) Shutdown() error {
	provider := c.tracerProvider
	if provider == nil {
		c.logger.Warn("TracerProvider is nil, skipping shutdown")
		return nil
	}

	c.logger.Info("Shutting down tracer provider")
	err := provider.Shutdown(context.Background())
	if err != nil {
		c.logger.Error("Error shutting down tracer provider", zap.Error(err))
		return err
	}

	c.logger.Info("Collector shutdown successfully")
	return nil
}
// GetName returns the fixed collector name "opentelemetry". SpanContext also
// uses it as the fallback tracer name.
func (c *OpenTelemetryCollector) GetName() string {
	const collectorName = "opentelemetry"
	return collectorName
}
// SpanContext starts a span named span as a child of ctx and returns the
// derived context together with a function that ends the span.
//
// The tracer name is read from the tracerNameKey value in ctx, falling back to
// GetName(). If no tracer provider has been set up, ctx is returned unchanged
// with a no-op cancel function instead of dereferencing a nil provider; this
// matches the nil handling in Flush and Shutdown.
func (c *OpenTelemetryCollector) SpanContext(ctx context.Context, span string) (context.Context, context.CancelFunc) {
	if c.tracerProvider == nil {
		c.logger.Warn("TracerProvider is nil, skipping span", zap.String("span", span))
		return ctx, func() {}
	}
	tracerName, ok := ctx.Value(tracerNameKey).(string)
	if !ok {
		tracerName = c.GetName()
	}
	ctx, s := c.tracerProvider.Tracer(tracerName).Start(ctx, span)
	return ctx, func() { s.End() }
}
// Compile-time check to ensure OpenTelemetryCollector implements the Collector interface.
// The blank assignment has no runtime effect; a missing or mismatched method fails the build here.
var _ Collector = (*OpenTelemetryCollector)(nil)
package telemetry
import (
"context"
"runtime"
"go.uber.org/zap"
"gitlab.com/nunet/device-management-service/types"
)
// Telemetry dispatches events and spans to a set of named collectors.
// A Telemetry built in test mode holds neither config nor collectors, and its
// SpanContext, logging, Flush and Shutdown methods return early.
type Telemetry struct {
// config supplies TelemetryMode; logEvent drops events when it is "disabled".
config *types.TelemetryConfig
// collectors maps a collector name (as passed to SpanContext) to its implementation.
collectors map[string]Collector
// testMode turns the methods listed above into no-ops.
testMode bool
}
// contextKey is a private type for context keys, so that values stored by this
// package cannot collide with keys set by other packages.
type contextKey string
const (
// collectorsKey holds the []string of collector names chosen in SpanContext.
collectorsKey contextKey = "collectors"
// tracerNameKey holds the tracer name chosen in SpanContext; collectors read it.
tracerNameKey contextKey = "tracerName"
// versionKey holds the version string that logEvent attaches to every event context.
versionKey contextKey = "version"
)
// GetTelemetry returns the package-level Telemetry instance.
// NOTE(review): `instance` is declared outside this chunk and may be nil until
// it has been assigned — confirm the initialisation order with callers.
func GetTelemetry() *Telemetry {
return instance
}
// NewTelemetry builds a Telemetry that routes events to collectors.
//
// In test mode, config and collectors are ignored and the instance's methods
// become no-ops; otherwise both are stored as given.
func NewTelemetry(config *types.TelemetryConfig, collectors map[string]Collector, testMode bool) *Telemetry {
	t := &Telemetry{testMode: testMode}
	if !testMode {
		t.config = config
		t.collectors = collectors
	}
	return t
}
// SpanContext opens a span named span in each of the named collectors and
// returns the derived context plus a function that ends all of those spans.
//
// tracerName and span default to the fully-qualified name of the calling
// function when empty. The tracer name is always stored in the returned
// context. The collector list is stored only when at least one collector is
// named, so that events logged with the context are routed to those
// collectors. With no names, an outer routing choice is kept, or the default
// of "all collectors" applies. Unknown collector names are ignored.
//
// Each collector's span is started from the context produced by the previous
// one, so the spans nest. The returned cancel function therefore ends them
// innermost-first. In test mode ctx is returned unchanged with a no-op cancel.
func (t *Telemetry) SpanContext(ctx context.Context, tracerName string, span string, collectors ...string) (context.Context, context.CancelFunc) {
	if t.testMode {
		return ctx, func() {}
	}

	// Use the caller's function name as the default tracer/span name.
	// runtime.FuncForPC may return nil, so guard before calling Name().
	functionName := "unknown_function"
	if pc, _, _, ok := runtime.Caller(1); ok {
		if fn := runtime.FuncForPC(pc); fn != nil {
			functionName = fn.Name()
		}
	}
	if tracerName == "" {
		tracerName = functionName
	}
	if span == "" {
		span = functionName
	}

	// An empty list would make logEvent send events to no collector at all.
	if len(collectors) > 0 {
		ctx = context.WithValue(ctx, collectorsKey, collectors)
	}
	ctx = context.WithValue(ctx, tracerNameKey, tracerName)

	cancelFuncs := make([]context.CancelFunc, 0, len(collectors))
	for _, name := range collectors {
		c, ok := t.collectors[name]
		if !ok {
			continue
		}
		var cancelFunc context.CancelFunc
		ctx, cancelFunc = c.SpanContext(ctx, span)
		cancelFuncs = append(cancelFuncs, cancelFunc)
	}

	cancel := func() {
		// End child spans before their parents.
		for i := len(cancelFuncs) - 1; i >= 0; i-- {
			cancelFuncs[i]()
		}
	}
	return ctx, cancel
}
// Trace records message and payload as a types.TRACE event. It does nothing in test mode.
func (t *Telemetry) Trace(ctx context.Context, message string, payload map[string]interface{}) {
	if !t.testMode {
		t.logEvent(ctx, types.TRACE, message, payload)
	}
}
// Debug records message and payload as a types.DEBUG event. It does nothing in test mode.
func (t *Telemetry) Debug(ctx context.Context, message string, payload map[string]interface{}) {
	if !t.testMode {
		t.logEvent(ctx, types.DEBUG, message, payload)
	}
}
// Info records message and payload as a types.INFO event. It does nothing in test mode.
func (t *Telemetry) Info(ctx context.Context, message string, payload map[string]interface{}) {
	if !t.testMode {
		t.logEvent(ctx, types.INFO, message, payload)
	}
}
// Warn records message and payload as a types.WARN event. It does nothing in test mode.
func (t *Telemetry) Warn(ctx context.Context, message string, payload map[string]interface{}) {
	if !t.testMode {
		t.logEvent(ctx, types.WARN, message, payload)
	}
}
// Error records message and payload as a types.ERROR event. It does nothing in test mode.
func (t *Telemetry) Error(ctx context.Context, message string, payload map[string]interface{}) {
	if !t.testMode {
		t.logEvent(ctx, types.ERROR, message, payload)
	}
}
// Fatal records message and payload as a types.FATAL event. It does nothing in test mode.
// Unlike log.Fatal, it only records the event and does not terminate the process.
func (t *Telemetry) Fatal(ctx context.Context, message string, payload map[string]interface{}) {
	if !t.testMode {
		t.logEvent(ctx, types.FATAL, message, payload)
	}
}
// logEvent builds a types.Event and hands it to the relevant collectors.
//
// Events are dropped when the config's TelemetryMode is "disabled" or when
// level is below the package-level logLevel. A nil config carries no explicit
// opt-out and does not suppress events. A nil ctx is replaced with
// context.Background(). The version "v0.5" is attached to the event context.
//
// When ctx carries a non-empty collector list (set by SpanContext), only those
// collectors receive the event. Otherwise every configured collector does.
// Collector errors are logged through zap and never returned.
func (t *Telemetry) logEvent(ctx context.Context, level types.ObservabilityLevel, message string, payload map[string]interface{}) {
	if t.config != nil && t.config.TelemetryMode == "disabled" {
		return
	}
	if level < logLevel {
		return
	}
	if ctx == nil {
		// context.WithValue panics on a nil parent.
		ctx = context.Background()
	}

	ctx = context.WithValue(ctx, versionKey, "v0.5")
	event := types.Event{
		Context: ctx,
		Level:   level,
		Message: message,
		Payload: payload,
	}

	// An empty selection must not swallow the event: fall through to all collectors.
	if selected, ok := ctx.Value(collectorsKey).([]string); ok && len(selected) > 0 {
		for _, name := range selected {
			c, found := t.collectors[name]
			if !found {
				continue
			}
			if err := c.HandleEvent(event); err != nil {
				zap.L().Error("Failed to handle event", zap.String("collector", name), zap.Error(err))
			}
		}
		return
	}

	for name, c := range t.collectors {
		if err := c.HandleEvent(event); err != nil {
			zap.L().Error("Failed to handle event", zap.String("collector", name), zap.Error(err))
		}
	}
}
// Flush flushes every configured collector. Failures are logged and do not
// stop the remaining collectors from being flushed. It does nothing in test mode.
func (t *Telemetry) Flush() {
	if t.testMode {
		return
	}
	for _, c := range t.collectors {
		err := c.Flush()
		if err == nil {
			continue
		}
		zap.L().Error("Failed to flush collector", zap.Error(err))
	}
}
// Shutdown flushes all collectors and then shuts each of them down. Failures
// are logged and do not stop the remaining collectors. It does nothing in test mode.
func (t *Telemetry) Shutdown() {
	if t.testMode {
		return
	}
	// Flush first so buffered data is handed off before collectors stop.
	t.Flush()
	for _, c := range t.collectors {
		err := c.Shutdown()
		if err == nil {
			continue
		}
		zap.L().Error("Failed to shut down collector", zap.Error(err))
	}
}
package types
import (
"fmt"
"reflect"
"slices"
"github.com/hashicorp/go-version"
)
// Connectivity represents the network configuration offered by a machine or
// required by a job.
type Connectivity struct {
// Ports lists port numbers; Add treats them as a set, Subtract removes them.
Ports []int `json:"ports" description:"Ports that need to be open for the job to run"`
// VPN reports whether a VPN is offered (machine side) or required (job side).
VPN bool `json:"vpn" description:"Whether VPN is required"`
}
// implementing Comparable and Calculable interfaces (compile-time checks)
var (
_ Comparable[Connectivity] = (*Connectivity)(nil)
_ Calculable[Connectivity] = (*Connectivity)(nil)
)
// Compare reports how the connectivity c offered by a machine relates to the
// connectivity other required by a job.
//
// It returns:
//   - Equal when both values are deeply equal;
//   - Better when IsStrictlyContainedInt(c.Ports, other.Ports) holds and the
//     VPN requirement is met;
//   - Worse otherwise.
//
// The VPN requirement is met when c offers a VPN or other does not require
// one. The previous condition reduced to c.VPN alone, so a machine without a
// VPN could never satisfy a job that does not need one.
// NOTE(review): IsStrictlyContainedInt is presumed to mean "c's ports include
// every required port" — confirm against its definition.
func (c *Connectivity) Compare(other Connectivity) Comparison {
	if reflect.DeepEqual(*c, other) {
		return Equal
	}
	portsOK := IsStrictlyContainedInt(c.Ports, other.Ports)
	vpnOK := c.VPN || !other.VPN
	if portsOK && vpnOK {
		return Better
	}
	return Worse
}
// Add merges other into c. Ports are treated as a set: each port of other
// that is not already present is appended once, even if other lists it
// repeatedly. VPN becomes true if either side has it. It always returns nil.
func (c *Connectivity) Add(other Connectivity) error {
	seen := make(map[int]struct{}, len(c.Ports)+len(other.Ports))
	for _, port := range c.Ports {
		seen[port] = struct{}{}
	}
	for _, port := range other.Ports {
		if _, dup := seen[port]; dup {
			continue
		}
		// Record the port immediately so duplicates within other are skipped too.
		seen[port] = struct{}{}
		c.Ports = append(c.Ports, port)
	}
	c.VPN = c.VPN || other.VPN
	return nil
}
// Subtract removes other's ports from c and clears c.VPN when other has a
// VPN. The remaining ports are compacted in place (reusing c.Ports' backing
// array), and an empty result is stored as nil. It always returns nil.
func (c *Connectivity) Subtract(other Connectivity) error {
	c.VPN = c.VPN && !other.VPN

	kept := c.Ports[:0]
	for _, p := range c.Ports {
		if slices.Contains(other.Ports, p) {
			continue
		}
		kept = append(kept, p)
	}

	c.Ports = nil
	if len(kept) > 0 {
		c.Ports = kept
	}
	return nil
}
// TimeInformation represents the time constraints
type TimeInformation struct {
// Units is interpreted by TotalTime: "seconds", "minutes", "hours", "days";
// any other value is treated like seconds.
Units string `json:"units" description:"Time units"`
// MaxTime is expressed in Units.
MaxTime int `json:"max_time" description:"Maximum time that job should run"`
// Preference is not used by Add, Subtract or Compare.
Preference int `json:"preference" description:"Time preference"`
}
// implementing Comparable and Calculable interfaces (compile-time checks)
var (
_ Comparable[TimeInformation] = (*TimeInformation)(nil)
_ Calculable[TimeInformation] = (*TimeInformation)(nil)
)
// Add adds other's maximum time to t.
//
// When both sides use the same Units, MaxTime values are summed directly.
// When the units differ, both are first converted to seconds with TotalTime,
// and t.Units becomes "seconds". This avoids adding raw numbers in different
// units: 2 hours + 30 minutes gives 9000 seconds, not 32 hours.
// Preference is left unchanged. It always returns nil.
func (t *TimeInformation) Add(other TimeInformation) error {
	if t.Units == other.Units {
		t.MaxTime += other.MaxTime
		return nil
	}
	t.MaxTime = t.TotalTime() + other.TotalTime()
	t.Units = "seconds"
	return nil
}
// Subtract subtracts other's maximum time from t.
//
// When both sides use the same Units, MaxTime values are subtracted directly.
// When the units differ, both are first converted to seconds with TotalTime,
// and t.Units becomes "seconds". The result may be negative; no clamping is
// done. Preference is left unchanged. It always returns nil.
func (t *TimeInformation) Subtract(other TimeInformation) error {
	if t.Units == other.Units {
		t.MaxTime -= other.MaxTime
		return nil
	}
	t.MaxTime = t.TotalTime() - other.TotalTime()
	t.Units = "seconds"
	return nil
}
// TotalTime converts MaxTime to seconds according to Units. "minutes",
// "hours" and "days" are scaled; "seconds" and any unrecognised unit leave
// MaxTime unchanged.
func (t *TimeInformation) TotalTime() int {
	multiplier := 1
	switch t.Units {
	case "minutes":
		multiplier = 60
	case "hours":
		multiplier = 60 * 60
	case "days":
		multiplier = 60 * 60 * 24
	}
	return t.MaxTime * multiplier
}
// Compare compares the total duration of t (offer side) with other
// (requirement side), after converting both to seconds with TotalTime.
// It returns Equal for equal durations, Better when t is longer and Worse when
// t is shorter.
//
// The former reflect.DeepEqual(t, other) shortcut compared a pointer with a
// value and could never be true. Identical values already compare Equal via
// their totals, so the shortcut was removed.
func (t *TimeInformation) Compare(other TimeInformation) Comparison {
	own, required := t.TotalTime(), other.TotalTime()
	switch {
	case own == required:
		return Equal
	case own < required:
		return Worse
	default:
		return Better
	}
}
// PriceInformation represents the pricing information offered by a machine
// or accepted by a job. Compare only relates values with the same Currency.
type PriceInformation struct {
Currency string `json:"currency" description:"Currency used for pricing"`
CurrencyPerHour int `json:"currency_per_hour" description:"Price charged per hour"`
TotalPerJob int `json:"total_per_job" description:"Maximum total price or budget of the job"`
// Preference is compared by Equal but ignored by Compare.
Preference int `json:"preference" description:"Pricing preference"`
}
// implementing Comparable interface (compile-time check)
var _ Comparable[PriceInformation] = (*PriceInformation)(nil)
// Compare compares the price p (offer side) with other (requirement side).
//
// Different currencies are incomparable (Error). With the same currency:
//   - equal TotalPerJob: Equal, Better or Worse by CurrencyPerHour (lower is Better);
//   - lower TotalPerJob: Better if CurrencyPerHour is not higher, else Worse;
//   - higher TotalPerJob: Worse.
//
// The former reflect.DeepEqual(p, other) shortcut compared a pointer with a
// value and was never true. Identical prices already yield Equal through the
// branches below, so the shortcut was removed.
func (p *PriceInformation) Compare(other PriceInformation) Comparison {
	if p.Currency != other.Currency {
		return Error
	}
	switch {
	case p.TotalPerJob == other.TotalPerJob:
		switch {
		case p.CurrencyPerHour == other.CurrencyPerHour:
			return Equal
		case p.CurrencyPerHour < other.CurrencyPerHour:
			return Better
		default:
			return Worse
		}
	case p.TotalPerJob < other.TotalPerJob && p.CurrencyPerHour <= other.CurrencyPerHour:
		return Better
	default:
		return Worse
	}
}
// Equal reports whether p and price agree on every field: Currency,
// CurrencyPerHour, TotalPerJob and Preference.
func (p *PriceInformation) Equal(price PriceInformation) bool {
	return *p == price
}
// Library represents a library offered by a machine or required by a job.
type Library struct {
Name string `json:"name" description:"Name of the library"`
// Constraint is a go-version constraint operator (e.g. "=", ">="); Compare joins it with Version.
Constraint string `json:"constraint" description:"Constraint of the library"`
Version string `json:"version" description:"Version of the library"`
}
// implementing Comparable interface (compile-time check)
var _ Comparable[Library] = (*Library)(nil)
// Compare checks whether the offered library lib satisfies the required
// library other.
//
// It returns Error when the names differ, when lib.Version is not a valid
// version, or when other.Constraint + " " + other.Version is not a valid
// go-version constraint. Otherwise it returns Worse when lib's version fails
// the constraint, and Equal when it passes an exact-match constraint. Both "="
// and an empty operator count as exact match, since go-version treats a
// missing operator as "=". Any other passing constraint returns Better.
func (lib *Library) Compare(other Library) Comparison {
	// Different libraries are incomparable; check this before any parsing.
	if lib.Name != other.Name {
		return Error
	}
	// The offered (left) version must be a valid version string.
	ownVersion, err := version.NewVersion(lib.Version)
	if err != nil {
		return Error
	}
	// The required (right) side must form a valid constraint such as ">= 1.2".
	constraints, err := version.NewConstraint(other.Constraint + " " + other.Version)
	if err != nil {
		return Error
	}
	if !constraints.Check(ownVersion) {
		return Worse
	}
	if other.Constraint == "=" || other.Constraint == "" {
		return Equal
	}
	return Better
}
// Equal reports whether lib and library have the same Name, Constraint and Version.
func (lib *Library) Equal(library Library) bool {
	return *lib == library
}
// Locality represents a region, identified by its Kind and Name.
type Locality struct {
Kind string `json:"kind" description:"Kind of the region (geographic, nunet-defined, etc)"`
Name string `json:"name" description:"Name of the region"`
}
// implementing Comparable interface (compile-time check)
var _ Comparable[Locality] = (*Locality)(nil)
// Compare relates two localities. Different kinds are incomparable (Error).
// Within the same kind, a matching name is Equal and any other name is Worse.
func (loc *Locality) Compare(other Locality) Comparison {
	switch {
	case loc.Kind != other.Kind:
		return Error
	case loc.Name != other.Name:
		return Worse
	default:
		return Equal
	}
}
// Equal reports whether loc and locality have the same Kind and Name.
func (loc *Locality) Equal(locality Locality) bool {
	return *loc == locality
}
// KYC represents a KYC (know-your-customer) entry: its type and the data it requires.
type KYC struct {
Type string `json:"type" description:"Type of KYC"`
Data string `json:"data" description:"Data required for KYC"`
}
// implementing Comparable interface (compile-time check)
var _ Comparable[KYC] = (*KYC)(nil)
// Compare returns Equal when k and other are identical and Error otherwise;
// KYC entries have no ordering.
func (k *KYC) Compare(other KYC) Comparison {
	if !reflect.DeepEqual(*k, other) {
		return Error
	}
	return Equal
}
// Equal reports whether k and kyc have the same Type and Data.
func (k *KYC) Equal(kyc KYC) bool {
	return *k == kyc
}
// JobType represents the type of the job
type JobType string
const (
Batch JobType = "batch"
SingleRun JobType = "single_run"
Recurring JobType = "recurring"
LongRunning JobType = "long_running"
)
// implementing Comparable interface (compile-time check); JobType is not Calculable
var _ Comparable[JobType] = (*JobType)(nil)
// Compare returns Equal for identical job types and Error otherwise;
// job types have no ordering.
func (j JobType) Compare(other JobType) Comparison {
	result := Error
	if reflect.DeepEqual(j, other) {
		result = Equal
	}
	return result
}
// JobTypes is a slice of JobType; Add and Subtract treat it like a set.
type JobTypes []JobType
// implementing Comparable and Calculable interfaces (compile-time checks)
var (
_ Comparable[JobTypes] = (*JobTypes)(nil)
_ Calculable[JobTypes] = (*JobTypes)(nil)
)
// Add appends every job type from other that is not already present in j
// (or earlier in other). Existing entries of j are kept as they are. The
// result's capacity is clipped to its length, so later appends by callers
// reallocate instead of writing into shared spare capacity.
func (j *JobTypes) Add(other JobTypes) error {
	merged := *j
	seen := make(map[JobType]struct{}, len(merged))
	for _, jt := range merged {
		seen[jt] = struct{}{}
	}
	for _, jt := range other {
		if _, dup := seen[jt]; dup {
			continue
		}
		seen[jt] = struct{}{}
		merged = append(merged, jt)
	}
	*j = merged[:len(merged):len(merged)]
	return nil
}
// Subtract removes from j every job type that appears in other, compacting
// j in place. The result's capacity is clipped to its length.
func (j *JobTypes) Subtract(other JobTypes) error {
	drop := make(map[JobType]struct{}, len(other))
	for _, jt := range other {
		drop[jt] = struct{}{}
	}
	kept := (*j)[:0]
	for _, jt := range *j {
		if _, gone := drop[jt]; gone {
			continue
		}
		kept = append(kept, jt)
	}
	*j = kept[:len(kept):len(kept)]
	return nil
}
// Compare relates the job types a machine supports (j) to the job types a
// job requires (other).
//
// Both slices are first converted to untyped slices. If their element types
// differ the result is Error. Identical slices are Equal. When
// IsStrictlyContained(offered, required) holds the result is Better; when
// IsStrictlyContained(required, offered) holds it is Worse. Anything else is
// Error.
//
// TODO: this comparator does not take into account options when several job
// types are available and several job types are required in the same data
// structure; this is why the test fails.
func (j *JobTypes) Compare(other JobTypes) Comparison {
	offered := ConvertTypedSliceToUntypedSlice(*j)
	required := ConvertTypedSliceToUntypedSlice(other)
	if !IsSameShallowType(offered, required) {
		return Error
	}
	if reflect.DeepEqual(offered, required) {
		return Equal
	}
	if IsStrictlyContained(offered, required) {
		return Better
	}
	if IsStrictlyContained(required, offered) {
		return Worse
	}
	return Error
}
// Contains reports whether jobType appears in j.
func (j *JobTypes) Contains(jobType JobType) bool {
	list := *j
	for i := range list {
		if list[i] == jobType {
			return true
		}
	}
	return false
}
// Libraries is a slice of Library; Add and Subtract treat it like a set.
type Libraries []Library
// implementing Comparable and Calculable interfaces (compile-time checks)
var (
_ Comparable[Libraries] = (*Libraries)(nil)
_ Calculable[Libraries] = (*Libraries)(nil)
)
// Add appends each library of other that is not already present in l
// (or earlier in other). It always returns nil.
func (l *Libraries) Add(other Libraries) error {
	seen := make(map[Library]struct{}, len(*l))
	for i := range *l {
		seen[(*l)[i]] = struct{}{}
	}
	for _, lib := range other {
		if _, dup := seen[lib]; dup {
			continue
		}
		seen[lib] = struct{}{}
		*l = append(*l, lib)
	}
	return nil
}
// Subtract removes from l every library that appears in other, compacting
// l in place and reusing its backing array. It always returns nil.
func (l *Libraries) Subtract(other Libraries) error {
	drop := make(map[Library]struct{}, len(other))
	for _, lib := range other {
		drop[lib] = struct{}{}
	}
	kept := (*l)[:0]
	for _, lib := range *l {
		if _, gone := drop[lib]; !gone {
			kept = append(kept, lib)
		}
	}
	*l = kept
	return nil
}
func (l *Libraries) Compare(other Libraries) Comparison {
interimComparison1 := make([][]Comparison, 0)
for _, otherLibrary := range other {
var interimComparison2 []Comparison
for _, ownLibrary := range *l {
interimComparison2 = append(interimComparison2, ownLibrary.Compare(otherLibrary))
}
// this matrix structure will hold the comparison results for each GPU on the right
// with each GPU on the left in the order they are in the slices
// first dimension represents left GPUs
// second dimension represents right GPUs
interimComparison1 = append(interimComparison1, interimComparison2)
}
// we can now implement a logic to figure out if each required GPU on the left has a matching GPU on the right
finalComparison := make([]Comparison, 0)
for i := 0; i < len(interimComparison1); i++ {
// we need to find the best match for each GPU on the right
if len(interimComparison1[i]) < i {
break
}
c := interimComparison1[i]
bestMatch, index := returnBestMatch(c)
finalComparison = append(finalComparison, bestMatch)
interimComparison1 = removeIndex(interimComparison1, index)
}
if slices.Contains(finalComparison, Error) {
return Error
}
if slices.Contains(finalComparison, Worse) {
return Worse
}
if SliceContainsOneValue(finalComparison, Equal) {
return Equal
}
return Better
}
// Contains reports whether some element of l is Equal to library.
func (l *Libraries) Contains(library Library) bool {
	for i := range *l {
		if (*l)[i].Equal(library) {
			return true
		}
	}
	return false
}
// Localities is a slice of Locality; Add and Subtract treat it like a set.
type Localities []Locality
// implementing Comparable and Calculable interfaces (compile-time checks)
var (
_ Comparable[Localities] = (*Localities)(nil)
_ Calculable[Localities] = (*Localities)(nil)
)
// Compare reports how the localities offered by a machine (l) satisfy the
// localities required by a job (other).
//
// Each required locality is matched only against offered localities of the
// same Kind:
//   - Equal if any of them has the same Name;
//   - Worse if the kind is offered but no name matches;
//   - Error if the kind is not offered at all.
//
// Previously the last offered locality of a kind overwrote earlier ones, so
// the result depended on list order.
//
// The per-requirement outcomes are combined as follows:
//   - any Error gives Error;
//   - otherwise any Worse gives Worse;
//   - otherwise Equal when SliceContainsOneValue(outcomes, Equal);
//   - otherwise Better.
func (l *Localities) Compare(other Localities) Comparison {
	var finalComparison []Comparison
	for _, required := range other {
		outcome := Error
		for _, own := range *l {
			if own.Kind != required.Kind {
				continue
			}
			outcome = own.Compare(required)
			if outcome == Equal {
				break
			}
		}
		finalComparison = append(finalComparison, outcome)
	}

	if slices.Contains(finalComparison, Error) {
		return Error
	}
	if slices.Contains(finalComparison, Worse) {
		return Worse
	}
	if SliceContainsOneValue(finalComparison, Equal) {
		return Equal
	}
	return Better
}
// Add appends each locality of other that is not already present in l
// (or earlier in other). It always returns nil.
func (l *Localities) Add(other Localities) error {
	seen := make(map[Locality]struct{}, len(*l))
	for i := range *l {
		seen[(*l)[i]] = struct{}{}
	}
	for _, loc := range other {
		if _, dup := seen[loc]; dup {
			continue
		}
		seen[loc] = struct{}{}
		*l = append(*l, loc)
	}
	return nil
}
// Subtract removes from l every locality that appears in other, compacting
// l in place and reusing its backing array. It always returns nil.
func (l *Localities) Subtract(other Localities) error {
	drop := make(map[Locality]struct{}, len(other))
	for _, loc := range other {
		drop[loc] = struct{}{}
	}
	kept := (*l)[:0]
	for _, loc := range *l {
		if _, gone := drop[loc]; !gone {
			kept = append(kept, loc)
		}
	}
	*l = kept
	return nil
}
// Contains reports whether some element of l is Equal to locality.
func (l *Localities) Contains(locality Locality) bool {
	for i := range *l {
		if (*l)[i].Equal(locality) {
			return true
		}
	}
	return false
}
// KYCs is a slice of KYC; Add and Subtract treat it like a set.
type KYCs []KYC
// implementing Comparable and Calculable interfaces (compile-time checks)
var (
_ Comparable[KYCs] = (*KYCs)(nil)
_ Calculable[KYCs] = (*KYCs)(nil)
)
// Compare relates the KYC entries offered (k) to the ones required (other).
//
// It returns:
//   - Equal when both are empty (nil and empty slices are treated alike;
//     previously this mixed case returned Error) or deeply equal;
//   - Better when nothing is required but something is offered;
//   - Equal when any offered entry equals any required entry;
//   - Error otherwise.
func (k *KYCs) Compare(other KYCs) Comparison {
	switch {
	case len(*k) == 0 && len(other) == 0:
		return Equal
	case reflect.DeepEqual(*k, other):
		return Equal
	case len(other) == 0:
		return Better
	}
	for _, own := range *k {
		for _, required := range other {
			if own.Compare(required) == Equal {
				return Equal
			}
		}
	}
	return Error
}
// Add appends each KYC of other that is not already present in k
// (or earlier in other). It always returns nil.
func (k *KYCs) Add(other KYCs) error {
	seen := make(map[KYC]struct{}, len(*k))
	for i := range *k {
		seen[(*k)[i]] = struct{}{}
	}
	for _, entry := range other {
		if _, dup := seen[entry]; dup {
			continue
		}
		seen[entry] = struct{}{}
		*k = append(*k, entry)
	}
	return nil
}
// Subtract removes from k every KYC that appears in other, compacting k in
// place and reusing its backing array. It always returns nil.
func (k *KYCs) Subtract(other KYCs) error {
	drop := make(map[KYC]struct{}, len(other))
	for _, entry := range other {
		drop[entry] = struct{}{}
	}
	kept := (*k)[:0]
	for _, entry := range *k {
		if _, gone := drop[entry]; !gone {
			kept = append(kept, entry)
		}
	}
	*k = kept
	return nil
}
// Contains reports whether some element of k is Equal to kyc.
func (k *KYCs) Contains(kyc KYC) bool {
	for i := range *k {
		if (*k)[i].Equal(kyc) {
			return true
		}
	}
	return false
}
// PricesInformation is a slice of PriceInformation; Add and Subtract treat it like a set.
type PricesInformation []PriceInformation
// implementing Comparable and Calculable interfaces (compile-time checks)
var (
_ Comparable[PricesInformation] = (*PricesInformation)(nil)
_ Calculable[PricesInformation] = (*PricesInformation)(nil)
)
// Add appends each price of other that is not already present in ps
// (or earlier in other). It always returns nil.
func (ps *PricesInformation) Add(other PricesInformation) error {
	seen := make(map[PriceInformation]struct{}, len(*ps))
	for i := range *ps {
		seen[(*ps)[i]] = struct{}{}
	}
	for _, price := range other {
		if _, dup := seen[price]; dup {
			continue
		}
		seen[price] = struct{}{}
		*ps = append(*ps, price)
	}
	return nil
}
// Subtract removes from ps every price that appears in other, compacting ps
// in place and clipping its capacity to its length. When other is empty, ps
// is left completely untouched. It always returns nil.
func (ps *PricesInformation) Subtract(other PricesInformation) error {
	if len(other) > 0 {
		drop := make(map[PriceInformation]struct{}, len(other))
		for _, price := range other {
			drop[price] = struct{}{}
		}
		kept := (*ps)[:0]
		for _, price := range *ps {
			if _, gone := drop[price]; !gone {
				kept = append(kept, price)
			}
		}
		*ps = kept[:len(kept):len(kept)]
	}
	return nil
}
// Compare returns Equal when ps and other are deeply equal. Otherwise it
// scans the pairs (offered, required) in order and returns the first
// PriceInformation.Compare result that is not Error, or Error if every pair
// is incomparable (or there are no pairs).
// NOTE(review): the result depends on slice order — the first comparable pair
// wins even if a later pair would be Better.
func (ps *PricesInformation) Compare(other PricesInformation) Comparison {
	if reflect.DeepEqual(*ps, other) {
		return Equal
	}
	for _, offered := range *ps {
		for _, required := range other {
			if result := offered.Compare(required); result != Error {
				return result
			}
		}
	}
	return Error
}
// Contains reports whether some element of ps is Equal to price.
func (ps *PricesInformation) Contains(price PriceInformation) bool {
	for i := range *ps {
		if (*ps)[i].Equal(price) {
			return true
		}
	}
	return false
}
// HardwareCapability represents the hardware capability of the machine, or the
// capability a job requires. Compare, Add and Subtract delegate field by field.
type HardwareCapability struct {
Executors Executors `json:"executor" description:"Executor type required for the job (docker, vm, wasm, or others)"`
JobTypes JobTypes `json:"type" description:"Details about type of the job (One time, batch, recurring, long running). Refer to dms.jobs package for jobType data model"`
Resources Resources `json:"resources" description:"Resources required for the job"`
Libraries Libraries `json:"libraries" description:"Libraries required for the job"`
Localities Localities `json:"locality" description:"Preferred localities of the machine for execution"`
Connectivity Connectivity `json:"connectivity" description:"Network configuration required"`
Price PricesInformation `json:"price" description:"Pricing information"`
Time TimeInformation `json:"time" description:"Time constraints"`
// NOTE(review): KYCs has no json tag, so encoding/json uses the Go field name
// "KYCs" unlike the snake_case keys above; adding a tag would change the wire
// format — confirm which is intended.
KYCs KYCs
}
// implementing Comparable and Calculable interfaces (compile-time checks)
var (
_ Comparable[HardwareCapability] = (*HardwareCapability)(nil)
_ Calculable[HardwareCapability] = (*HardwareCapability)(nil)
)
// Compare compares c (machine side) with other (requirement side) by
// comparing every field with its own Compare method. It then combines the
// per-field results with ComplexComparison.Result, which ANDs them together.
func (c *HardwareCapability) Compare(other HardwareCapability) Comparison {
	results := make(ComplexComparison, 9)
	results["Executors"] = c.Executors.Compare(other.Executors)
	results["JobTypes"] = c.JobTypes.Compare(other.JobTypes)
	results["Resources"] = c.Resources.Compare(other.Resources)
	results["Libraries"] = c.Libraries.Compare(other.Libraries)
	results["Localities"] = c.Localities.Compare(other.Localities)
	results["Price"] = c.Price.Compare(other.Price)
	results["KYCs"] = c.KYCs.Compare(other.KYCs)
	results["Time"] = c.Time.Compare(other.Time)
	results["Connectivity"] = c.Connectivity.Compare(other.Connectivity)
	return results.Result()
}
// Add adds every field of other to the corresponding field of c, in field
// order, stopping at the first failure.
//
// Errors are wrapped with %w and name the failing field, so callers can
// still use errors.Is / errors.As. Previously the Executors error was
// discarded entirely and the Resources error was returned without context.
// Fields updated before the failing one stay modified.
func (c *HardwareCapability) Add(other HardwareCapability) error {
	if err := c.Executors.Add(other.Executors); err != nil {
		return fmt.Errorf("error adding Executors: %w", err)
	}
	if err := c.JobTypes.Add(other.JobTypes); err != nil {
		return fmt.Errorf("error adding JobTypes: %w", err)
	}
	if err := c.Resources.Add(other.Resources); err != nil {
		return fmt.Errorf("error adding Resources: %w", err)
	}
	if err := c.Libraries.Add(other.Libraries); err != nil {
		return fmt.Errorf("error adding Libraries: %w", err)
	}
	if err := c.Localities.Add(other.Localities); err != nil {
		return fmt.Errorf("error adding Localities: %w", err)
	}
	if err := c.Connectivity.Add(other.Connectivity); err != nil {
		return fmt.Errorf("error adding Connectivity: %w", err)
	}
	if err := c.Price.Add(other.Price); err != nil {
		return fmt.Errorf("error adding Price: %w", err)
	}
	if err := c.Time.Add(other.Time); err != nil {
		return fmt.Errorf("error adding Time: %w", err)
	}
	if err := c.KYCs.Add(other.KYCs); err != nil {
		return fmt.Errorf("error adding KYCs: %w", err)
	}
	return nil
}
// Subtract subtracts every field of other from the corresponding field of c,
// in field order, stopping at the first failure.
//
// Errors are wrapped with %w and name the failing field; the JobTypes message
// previously said "comparing". The parameter was renamed from cap so it no
// longer shadows the builtin. Fields updated before the failing one stay
// modified.
func (c *HardwareCapability) Subtract(other HardwareCapability) error {
	if err := c.Executors.Subtract(other.Executors); err != nil {
		return fmt.Errorf("error subtracting Executors: %w", err)
	}
	if err := c.JobTypes.Subtract(other.JobTypes); err != nil {
		return fmt.Errorf("error subtracting JobTypes: %w", err)
	}
	if err := c.Resources.Subtract(other.Resources); err != nil {
		return fmt.Errorf("error subtracting Resources: %w", err)
	}
	if err := c.Libraries.Subtract(other.Libraries); err != nil {
		return fmt.Errorf("error subtracting Libraries: %w", err)
	}
	if err := c.Localities.Subtract(other.Localities); err != nil {
		return fmt.Errorf("error subtracting Localities: %w", err)
	}
	if err := c.Connectivity.Subtract(other.Connectivity); err != nil {
		return fmt.Errorf("error subtracting Connectivity: %w", err)
	}
	if err := c.Price.Subtract(other.Price); err != nil {
		return fmt.Errorf("error subtracting Price: %w", err)
	}
	if err := c.Time.Subtract(other.Time); err != nil {
		return fmt.Errorf("error subtracting Time: %w", err)
	}
	if err := c.KYCs.Subtract(other.KYCs); err != nil {
		return fmt.Errorf("error subtracting KYCs: %w", err)
	}
	return nil
}
package types
import "reflect"
// Comparable is implemented by types whose values can be compared with another
// value of type T, producing a Comparison (the receiver is the left side).
type Comparable[T any] interface {
Compare(other T) Comparison
}
// Calculable is implemented by types that support in-place addition and
// subtraction of another value of type T.
type Calculable[T any] interface {
Add(other T) error
Subtract(other T) error
}
// PreferenceString is the strength of a Preference.
// NOTE(review): presumably Hard means "must be satisfied" and Soft "nice to
// have" — confirm with the code that consumes Preference.
type PreferenceString string
const (
Hard PreferenceString = "Hard"
Soft PreferenceString = "Soft"
)
// Preference attaches a strength and an optional comparator override to a type name.
type Preference struct {
TypeName string
Strength PreferenceString
DefaultComparatorOverride Comparator
}
// Comparator compares l with r (by the convention used in this package: l is
// the machine capability, r the requirement) under optional preferences.
type Comparator func(l, r interface{}, preference ...Preference) Comparison
// ComplexCompare compares two values of the same struct type field by field
// and returns the per-field results keyed by field name.
//
// For each exported field, the Compare method of the left field is called
// with the right field. The method is looked up on a pointer to the field,
// so pointer-receiver methods are found. A field whose type has no method
// Compare(T) Comparison is recorded as Error. Unexported fields are skipped,
// because reflect panics when copying them or calling methods obtained
// through them.
//
// l and r may be structs or pointers to structs. If either is nil or not a
// struct, or the two have different types, the result is the single entry
// {"": Error}, so ComplexComparison.Result() reports Error instead of the
// function panicking.
//
// This uses reflection and could become a performance bottleneck; with
// generics, a type-specific comparison could be used instead.
func ComplexCompare(l, r interface{}) ComplexComparison {
	complexComparison := make(ComplexComparison)

	val1 := reflect.Indirect(reflect.ValueOf(l))
	val2 := reflect.Indirect(reflect.ValueOf(r))
	if !val1.IsValid() || !val2.IsValid() ||
		val1.Kind() != reflect.Struct || val1.Type() != val2.Type() {
		complexComparison[""] = Error
		return complexComparison
	}

	comparisonType := reflect.TypeOf(Error)
	structType := val1.Type()
	for i := 0; i < val1.NumField(); i++ {
		fieldInfo := structType.Field(i)
		if !fieldInfo.IsExported() {
			continue
		}
		field1 := val1.Field(i)
		field2 := val2.Field(i)

		// Compare may have a pointer receiver, so call it through a pointer;
		// copy the value into a fresh pointer when it is not addressable.
		var field1Ptr reflect.Value
		if field1.CanAddr() {
			field1Ptr = field1.Addr()
		} else {
			field1Ptr = reflect.New(field1.Type())
			field1Ptr.Elem().Set(field1)
		}

		method := field1Ptr.MethodByName("Compare")
		if method.IsValid() &&
			method.Type().NumIn() == 1 &&
			method.Type().In(0) == field1.Type() &&
			method.Type().NumOut() == 1 &&
			method.Type().Out(0) == comparisonType {
			out := method.Call([]reflect.Value{field2})[0]
			complexComparison[fieldInfo.Name] = out.Interface().(Comparison)
		} else {
			complexComparison[fieldInfo.Name] = Error
		}
	}
	return complexComparison
}
// NumericComparator compares a machine capability l with a requirement r.
// It returns Equal when they are equal, Better when l is larger and Worse
// when l is smaller. Error is returned only for unordered values (NaN).
// Preferences are accepted but ignored.
func NumericComparator[T float64 | float32 | int | int32 | int64 | uint64](l, r T, _ ...Preference) Comparison {
	if l == r {
		return Equal
	}
	if l < r {
		return Worse
	}
	if l > r {
		return Better
	}
	// Reachable only when neither ordering holds, e.g. a NaN operand.
	return Error
}
// LiteralComparator compares two string-like values, where l is the machine
// capability and r the requirement. Literals have no ordering, so the result
// is Equal for identical values and Error otherwise. Preferences are ignored.
func LiteralComparator[T ~string](l, r T, _ ...Preference) Comparison {
	if l != r {
		return Error
	}
	return Equal
}
package types
// Comparison is the result of comparing a left value (by convention the
// machine capability) with a right value (the requirement).
type Comparison string
const (
// Worse means left object is 'worse' than right object
Worse Comparison = "Worse"
// Better means left object is 'better' than right object
Better Comparison = "Better"
// Equal means objects on the left and right are 'equally good'
Equal Comparison = "Equal"
// Error means error in comparison or objects incomparable
Error Comparison = "Error"
)
// And combines two Comparison values according to this truth table:
//
//	| AND    | Better | Worse | Equal  | Error |
//	| ------ | ------ | ----- | ------ | ----- |
//	| Better | Better | Worse | Better | Error |
//	| Worse  | Worse  | Worse | Worse  | Error |
//	| Equal  | Better | Worse | Equal  | Error |
//	| Error  | Error  | Error | Error  | Error |
//
// For values outside the four constants: identical inputs are returned
// unchanged, a left-hand Worse still yields Worse, and every other pairing
// yields Error.
func (c Comparison) And(cmp Comparison) Comparison {
	switch {
	case c == Error || cmp == Error:
		return Error
	case c == cmp:
		return c
	case c == Worse:
		return Worse
	case c == Equal && (cmp == Better || cmp == Worse):
		return cmp
	case c == Better && cmp == Worse:
		return Worse
	case c == Better && cmp == Equal:
		return Better
	default:
		return Error
	}
}
// ComplexComparison maps a field (or aspect) name to its Comparison result.
type ComplexComparison map[string]Comparison
// Result folds every value in the map with Comparison.And, starting from
// Equal. An empty map therefore yields Equal, and any Error value yields Error.
func (c *ComplexComparison) Result() Comparison {
	combined := Equal
	fields := *c
	for key := range fields {
		combined = combined.And(fields[key])
	}
	return combined
}
package types
// ExecutionRequest is the request object for executing a job: what to run
// (EngineSpec), with which Resources, and where inputs, outputs and results live.
type ExecutionRequest struct {
JobID string // ID of the job to execute
ExecutionID string // ID of the execution
EngineSpec *SpecConfig // Engine spec for the execution
Resources *Resources // Resources for the execution
Inputs []*StorageVolumeExecutor // Input volumes for the execution
Outputs []*StorageVolumeExecutor // Output volumes for the results
ResultsDir string // Directory to store the results
}
// ExecutionListItem describes one entry in a listing of current executions.
type ExecutionListItem struct {
ExecutionID string // ID of the execution
Running bool // whether the execution is reported as running
}
// ExecutionResult is the result of an execution. NewFailedExecutionResult
// uses ExitCode -1 to mark executions that failed with an error.
type ExecutionResult struct {
STDOUT string `json:"stdout"` // STDOUT of the execution
STDERR string `json:"stderr"` // STDERR of the execution
ExitCode int `json:"exit_code"` // Exit code of the execution
ErrorMsg string `json:"error_msg"` // Error message if the execution failed
}
// NewExecutionResult returns an ExecutionResult carrying the given exit code,
// with empty STDOUT, STDERR and ErrorMsg.
func NewExecutionResult(code int) *ExecutionResult {
	var result ExecutionResult
	result.ExitCode = code
	return &result
}
// NewFailedExecutionResult creates an ExecutionResult for a failed execution.
// The exit code is set to -1 and ErrorMsg to err.Error(). A nil err is
// tolerated and leaves ErrorMsg empty, instead of panicking on the nil
// interface.
func NewFailedExecutionResult(err error) *ExecutionResult {
	result := &ExecutionResult{
		STDOUT:   "",
		STDERR:   "",
		ExitCode: -1,
	}
	if err != nil {
		result.ErrorMsg = err.Error()
	}
	return result
}
// LogStreamRequest is the request object for streaming logs from an execution.
// NOTE(review): exact Tail/Follow semantics are defined by the executor that
// consumes this request — confirm there.
type LogStreamRequest struct {
JobID string // ID of the job
ExecutionID string // ID of the execution
Tail bool // Tail the logs
Follow bool // Follow the logs
}
package types
import (
"reflect"
)
// ExecutorType is the type of the executor
type ExecutorType string
const (
ExecutorTypeDocker ExecutorType = "docker"
ExecutorTypeFirecracker ExecutorType = "firecracker"
ExecutorTypeWasm ExecutorType = "wasm"
// ExecutionStatusCodeSuccess is an untyped integer constant (its explicit
// "= 0" stops the ExecutorType type from carrying over), presumably the
// exit code of a successful execution.
ExecutionStatusCodeSuccess = 0
)
// implementing Comparable interface (compile-time check)
var _ Comparable[ExecutorType] = (*ExecutorType)(nil)
// Compare compares two ExecutorType objects
func (e ExecutorType) Compare(other ExecutorType) Comparison {
return LiteralComparator(string(e), string(other))
}
// String returns the underlying string value of the ExecutorType.
func (e ExecutorType) String() string {
	s := string(e)
	return s
}
// Executor describes an executor by its ExecutorType.
type Executor struct {
ExecutorType ExecutorType `json:"executor_type"`
}
// implementing Comparable interface
var _ Comparable[Executor] = (*Executor)(nil)

// Compare compares two Executor objects.
//
// The receiver represents machine capabilities and other represents the
// required capabilities. Executor has a single field, so the comparison is
// delegated entirely to ExecutorType.Compare.
func (e *Executor) Compare(other Executor) Comparison {
	required := other.ExecutorType
	return e.ExecutorType.Compare(required)
}
// Equal reports whether both executors have the same ExecutorType.
func (e *Executor) Equal(executor Executor) bool {
	same := e.ExecutorType == executor.ExecutorType
	return same
}
// Executors is a list of Executor objects.
// Add appends to it and Subtract filters it in place by ExecutorType.
type Executors []Executor
// implementing Comparable and Calculable interface
var (
_ Comparable[Executors] = (*Executors)(nil)
_ Calculable[Executors] = (*Executors)(nil)
)
// Add appends every executor in other to e. Duplicates are kept, and the
// method never returns an error.
func (e *Executors) Add(other Executors) error {
	merged := append(*e, other...)
	*e = merged
	return nil
}
// Subtract removes from e every executor whose ExecutorType appears in other.
//
// The filtering reuses e's backing array. The result's capacity is clipped
// to its length, so a later append reallocates instead of writing over the
// leftover tail. It never returns an error.
func (e *Executors) Subtract(other Executors) error {
	if len(other) == 0 {
		return nil
	}
	excluded := make(map[ExecutorType]struct{})
	for i := range other {
		excluded[other[i].ExecutorType] = struct{}{}
	}
	kept := (*e)[:0]
	for i := range *e {
		ex := (*e)[i]
		if _, drop := excluded[ex.ExecutorType]; drop {
			continue
		}
		kept = append(kept, ex)
	}
	*e = kept[:len(kept):len(kept)]
	return nil
}
// Contains reports whether executor is present in e, comparing with Executor.Equal.
func (e *Executors) Contains(executor Executor) bool {
	for i := range *e {
		if (*e)[i].Equal(executor) {
			return true
		}
	}
	return false
}
// Compare compares two Executors lists.
//
// The receiver represents machine capabilities and other represents the
// required capabilities. The result is:
//   - Equal if both lists hold the same executors in the same order;
//   - Better if every required executor is available on the machine;
//   - Worse if every machine executor is among the required ones;
//   - Error otherwise.
//
// NOTE(review): IsStrictlyContained returns false when its second argument
// is empty. As a result, a non-empty machine list compared against an empty
// requirement list yields Error rather than Better. Confirm this is intended.
func (e *Executors) Compare(other Executors) Comparison {
	if reflect.DeepEqual(*e, other) {
		return Equal
	}
	machine := make([]interface{}, len(*e))
	for i, ex := range *e {
		machine[i] = ex
	}
	required := make([]interface{}, len(other))
	for i, ex := range other {
		required[i] = ex
	}
	if !IsSameShallowType(machine, required) {
		return Error
	}
	if reflect.DeepEqual(machine, required) {
		// catches nil-vs-empty inputs, which the typed DeepEqual above
		// treats as different
		return Equal
	}
	if IsStrictlyContained(machine, required) {
		return Better
	}
	if IsStrictlyContained(required, machine) {
		// available capabilities are worse than required
		return Worse
	}
	return Error
}
package types
import (
"fmt"
"slices"
"strings"
)
// HardwareManager defines the interface for managing machine resources.
type HardwareManager interface {
// GetMachineResources returns the machine's resources (presumably the totals — confirm with implementations).
GetMachineResources() (MachineResources, error)
// GetUsage presumably returns the resources currently in use.
GetUsage() (Resources, error)
// GetFreeResources presumably returns the resources currently not in use.
GetFreeResources() (Resources, error)
}
// GPUVendor identifies the manufacturer of a GPU.
type GPUVendor string
// Known GPU vendors. ParseGPUVendor returns GPUVendorUnknown for strings it
// does not recognise. NOTE(review): GPUVendorNone presumably means "no GPU" — confirm.
const (
GPUVendorNvidia GPUVendor = "NVIDIA"
GPUVendorAMDATI GPUVendor = "AMD/ATI"
GPUVendorIntel GPUVendor = "Intel"
GPUVendorUnknown GPUVendor = "Unknown"
GPUVendorNone GPUVendor = "None"
)
// implementing Comparable interface
var _ Comparable[GPUVendor] = (*GPUVendor)(nil)

// Compare returns Equal when both vendors are identical and Error otherwise.
// Vendors have no ordering, so Better/Worse are never returned.
func (g GPUVendor) Compare(other GPUVendor) Comparison {
	if g != other {
		return Error
	}
	return Equal
}
// ParseGPUVendor parses the GPU vendor string and returns the corresponding GPUVendor enum.
//
// Matching is case-insensitive and checks, in order:
//   - "NVIDIA" (substring) -> GPUVendorNvidia;
//   - "AMD" (substring) or "ATI" (whole word) -> GPUVendorAMDATI;
//   - "INTEL" (substring) -> GPUVendorIntel.
//
// Anything else maps to GPUVendorUnknown. "ATI" must be a standalone word,
// because as a bare substring it also matches ordinary words such as
// "Corporation". For example, "Intel Corporation" would otherwise be
// reported as AMD/ATI.
func ParseGPUVendor(vendor string) GPUVendor {
	upper := strings.ToUpper(vendor)
	// split into alphanumeric words, e.g. "[AMD/ATI]" -> ["AMD", "ATI"]
	words := strings.FieldsFunc(upper, func(r rune) bool {
		return (r < 'A' || r > 'Z') && (r < '0' || r > '9')
	})
	switch {
	case strings.Contains(upper, "NVIDIA"):
		return GPUVendorNvidia
	case strings.Contains(upper, "AMD") || slices.Contains(words, "ATI"):
		return GPUVendorAMDATI
	case strings.Contains(upper, "INTEL"):
		return GPUVendorIntel
	default:
		return GPUVendorUnknown
	}
}
// GPU represents the GPU information.
// Compare, Add and Subtract only look at VRAM; Equal compares Model, VRAM,
// Index, Vendor and PCIAddress.
type GPU struct {
// Index is the self-reported index of the device in the system
Index int
// Name is the model name of the GPU e.g. Tesla T4
Name string
// Vendor is the maker of the GPU, e.g. NVidia, AMD, Intel
Vendor GPUVendor
// PCIAddress is the PCI address of the device, in the format AAAA:BB:CC.C
// Used to discover the correct device rendering cards
PCIAddress string
// Model of the GPU, e.g. A100
Model string `json:"model" description:"GPU model, ex A100"`
// VRAM is the amount of VRAM on the device. Subtract lowers it, and
// MaxFreeVRAMGPU treats it as free VRAM, so it is not always the total.
// NOTE(review): units are not established here (presumably bytes) — confirm.
VRAM uint64
// ResourceID links the GPU to its owning Resources record
// (Resources.GPUs is declared with gorm:"foreignKey:ResourceID").
// NOTE(review): the foreignKey tag on this field looks redundant or misplaced
// given that declaration — confirm the intended GORM relation.
ResourceID uint `gorm:"foreignKey:ID"`
}
// implementing Comparable and Calculable interfaces
var (
	_ Comparable[GPU] = (*GPU)(nil)
	_ Calculable[GPU] = (*GPU)(nil)
)

// Compare compares two GPUs by VRAM only. It returns Better if g has more
// VRAM than other, Worse if it has less, and Equal if the amounts match.
//
// This is deliberately simple: equal or more VRAM is acceptable, less is
// not. A richer comparison would need hardware-specific knowledge, e.g.
// ranking GPU models with benchmark data.
func (g *GPU) Compare(other GPU) Comparison {
	if g.VRAM > other.VRAM {
		return Better
	}
	if g.VRAM < other.VRAM {
		return Worse
	}
	return Equal
}
// Add adds other's VRAM to g. Other fields are left untouched, and it never fails.
func (g *GPU) Add(other GPU) error {
	g.VRAM = g.VRAM + other.VRAM
	return nil
}
// Subtract removes other's VRAM from g.
// If that would underflow, it returns an error and leaves g unchanged.
func (g *GPU) Subtract(other GPU) error {
	if g.VRAM < other.VRAM {
		// "cannot subtract <amount> from <current>": the arguments were previously swapped
		return fmt.Errorf("total VRAM: underflow, cannot subtract %v from %v", other.VRAM, g.VRAM)
	}
	g.VRAM -= other.VRAM
	return nil
}
// Equal reports whether two GPUs share Model, VRAM, Index, Vendor and
// PCIAddress. Name and ResourceID are not compared.
func (g *GPU) Equal(other GPU) bool {
	switch {
	case g.Model != other.Model,
		g.VRAM != other.VRAM,
		g.Index != other.Index,
		g.Vendor != other.Vendor,
		g.PCIAddress != other.PCIAddress:
		return false
	default:
		return true
	}
}
// GPUs is a list of GPU objects. Its Add and Subtract match GPUs by Index.
type GPUs []GPU
// implementing Comparable and Calculable interfaces
var (
_ Calculable[GPUs] = (*GPUs)(nil)
_ Comparable[GPUs] = (*GPUs)(nil)
)
// Compare compares the available GPUs (gpus) against the required GPUs (other).
//
// Matching rules:
//   - Each required GPU is greedily matched, in order, with the best remaining
//     available GPU: Equal is preferred, then Better, then Worse, then Error
//     (see returnBestMatch).
//   - A matched GPU cannot be reused.
//   - A required GPU with no available GPU left counts as Error.
//
// The overall result is the first that applies:
//   - Error if any match is Error;
//   - Worse if any match is Worse;
//   - Equal if every match is Equal (including when nothing is required);
//   - Better otherwise.
func (gpus GPUs) Compare(other GPUs) Comparison {
	// matrix[j][k] is the comparison of available GPU k against required GPU j
	matrix := make([][]Comparison, 0, len(other))
	for _, required := range other {
		row := make([]Comparison, 0, len(gpus))
		for _, available := range gpus {
			row = append(row, available.Compare(required))
		}
		matrix = append(matrix, row)
	}

	// Every required GPU must be evaluated; stopping early would let a
	// worse or missing GPU go unnoticed.
	finalComparison := make([]Comparison, 0, len(matrix))
	for j := range matrix {
		// an empty row (all available GPUs already taken) yields (Error, -1)
		bestMatch, index := returnBestMatch(matrix[j])
		finalComparison = append(finalComparison, bestMatch)
		// drop the chosen GPU's column so later required GPUs cannot reuse it
		matrix = removeIndex(matrix, index)
	}

	switch {
	case slices.Contains(finalComparison, Error):
		return Error
	case slices.Contains(finalComparison, Worse):
		return Worse
	case SliceContainsOneValue(finalComparison, Equal):
		return Equal
	default:
		return Better
	}
}
// Add adds the VRAM of each GPU in other to the GPU in gpus with the same
// Index. Elements are updated in place through the shared backing array.
// GPUs in other with no matching Index are ignored, not appended; appending
// would require a pointer receiver.
//
// TODO: revisit this logic:
//  1. if the other GPU is already in gpus, add its VRAM
//  2. if it is not, append it to gpus
//
// Matching relies on Index, which may not reliably identify GPUs.
func (gpus GPUs) Add(other GPUs) error {
	byIndex := make(map[int]GPU)
	for _, o := range other {
		byIndex[o.Index] = o
	}
	for i := range gpus {
		match, found := byIndex[gpus[i].Index]
		if !found {
			continue
		}
		if err := gpus[i].Add(match); err != nil {
			return fmt.Errorf("failed to add GPU %s: %w", gpus[i].Model, err)
		}
	}
	return nil
}
// Subtract subtracts the VRAM of each GPU in other from the GPU in gpus with
// the same Index, updating elements in place. GPUs in other with no matching
// Index are ignored. The first underflow aborts with an error; GPUs processed
// before it stay modified.
//
// Matching relies on Index, which may not reliably identify GPUs.
func (gpus GPUs) Subtract(other GPUs) error {
	byIndex := make(map[int]GPU)
	for _, o := range other {
		byIndex[o.Index] = o
	}
	for i := range gpus {
		match, found := byIndex[gpus[i].Index]
		if !found {
			continue
		}
		if err := gpus[i].Subtract(match); err != nil {
			return fmt.Errorf("failed to subtract GPU %s: %w", gpus[i].Model, err)
		}
	}
	return nil
}
// MaxFreeVRAMGPU returns the GPU with the maximum free VRAM from the list of GPUs.
// Ties go to the GPU that appears first. An error is returned only when the
// list is empty.
func (gpus GPUs) MaxFreeVRAMGPU() (GPU, error) {
	if len(gpus) == 0 {
		return GPU{}, fmt.Errorf("no GPUs found")
	}
	// Start from the first GPU. If every GPU reports 0 VRAM, this still
	// returns a real device instead of an empty GPU{} with a nil error.
	best := gpus[0]
	for _, gpu := range gpus[1:] {
		if gpu.VRAM > best.VRAM {
			best = gpu
		}
	}
	return best, nil
}
// CPU represents the CPU information.
// Compare uses Architecture and Cores*ClockSpeed; Add/Subtract only touch Cores.
type CPU struct {
// ClockSpeed represents the CPU clock speed in Hz
ClockSpeed float64
// Cores represents the number of physical CPU cores.
// It is a float32 that Add/Subtract round to two decimal places, so
// fractional values can occur (presumably for partial allocations).
Cores float32
// TODO: capture the below fields if required
// Model represents the CPU model, e.g., "Intel Core i7-9700K", "AMD Ryzen 9 5900X"
Model string
// Vendor represents the CPU manufacturer, e.g., "Intel", "AMD"
Vendor string
// Threads represents the number of logical CPU threads (including hyperthreading)
Threads int
// Architecture represents the CPU architecture, e.g., "x86", "x86_64", "arm64"
Architecture string
// Cache size in bytes
CacheSize uint64
}
// implementing Comparable and Calculable interfaces
var (
	_ Calculable[CPU] = (*CPU)(nil)
	_ Comparable[CPU] = (*CPU)(nil)
)

// Compare compares two CPUs.
//
// Architectures are compared first. An Error from LiteralComparator is
// returned as-is, and any other non-Equal result is reported as Worse. When
// architectures match, the result is the numeric comparison of total compute
// (Cores * ClockSpeed).
//
// This is deliberately simple: equal or more compute is acceptable, less is
// not. Finer comparisons would need hardware-specific knowledge, e.g.
// ranking CPU models with benchmark data.
func (c *CPU) Compare(other CPU) Comparison {
	switch LiteralComparator(c.Architecture, other.Architecture) {
	case Error:
		return Error
	case Equal:
		return NumericComparator(
			float64(c.Cores)*c.ClockSpeed,
			float64(other.Cores)*other.ClockSpeed,
		)
	default:
		return Worse
	}
}
// Add adds other's core count to c and rounds the sum to two decimal places.
// Other fields are left untouched, and it never fails.
func (c *CPU) Add(other CPU) error {
	sum := c.Cores + other.Cores
	c.Cores = round(sum, 2)
	return nil
}
// Subtract removes other's core count from c and rounds the result to two
// decimal places. If that would underflow, it returns an error and leaves c
// unchanged.
func (c *CPU) Subtract(other CPU) error {
	if c.Cores < other.Cores {
		// "cannot subtract <amount> from <current>": the arguments were previously swapped
		return fmt.Errorf("core: underflow, cannot subtract %v from %v", other.Cores, c.Cores)
	}
	c.Cores = round(c.Cores-other.Cores, 2)
	return nil
}
// Compute returns the CPU's total compute capacity, Cores * ClockSpeed.
func (c *CPU) Compute() float64 {
	cores := float64(c.Cores)
	return cores * c.ClockSpeed
}
// RAM represents the RAM information.
// Compare, Add and Subtract only use Size.
type RAM struct {
// Size in bytes
Size float64
// TODO: capture the below fields if required
// Clock speed in Hz
ClockSpeed uint64
// Type represents the RAM type, e.g., "DDR4", "DDR5", "LPDDR4"
Type string
}
// implementing Comparable and Calculable interfaces
var (
	_ Calculable[RAM] = (*RAM)(nil)
	_ Comparable[RAM] = (*RAM)(nil)
)

// Compare compares two RAM values by Size only. ClockSpeed and Type do not
// affect the result.
func (r *RAM) Compare(other RAM) Comparison {
	return NumericComparator(r.Size, other.Size)
}
// Add adds other's Size to r. It never fails.
func (r *RAM) Add(other RAM) error {
	r.Size = r.Size + other.Size
	return nil
}
// Subtract removes other's Size from r.
// If that would underflow, it returns an error and leaves r unchanged.
func (r *RAM) Subtract(other RAM) error {
	if r.Size < other.Size {
		// "cannot subtract <amount> from <current>": the arguments were previously swapped
		return fmt.Errorf("size: underflow, cannot subtract %v from %v", other.Size, r.Size)
	}
	r.Size -= other.Size
	return nil
}
// Disk represents the disk information.
// Compare, Add and Subtract only use Size.
type Disk struct {
// Size in bytes
Size float64
// TODO: capture the below fields if required
// Model represents the disk model, e.g., "Samsung 970 EVO Plus", "Western Digital Blue SN550"
Model string
// Vendor represents the disk manufacturer, e.g., "Samsung", "Western Digital"
Vendor string
// Type represents the disk type, e.g., "SSD", "HDD", "NVMe"
Type string
// Interface represents the disk interface, e.g., "SATA", "PCIe", "M.2"
Interface string
// Read speed in bytes per second
ReadSpeed uint64
// Write speed in bytes per second
WriteSpeed uint64
}
// implementing Comparable and Calculable interfaces
var (
	_ Calculable[Disk] = (*Disk)(nil)
	_ Comparable[Disk] = (*Disk)(nil)
)

// Compare compares two disks by Size only.
func (d *Disk) Compare(other Disk) Comparison {
	return NumericComparator(d.Size, other.Size)
}
// Add adds other's Size to d. It never fails.
func (d *Disk) Add(other Disk) error {
	d.Size = d.Size + other.Size
	return nil
}
// Subtract removes other's Size from d.
// If that would underflow, it returns an error and leaves d unchanged.
func (d *Disk) Subtract(other Disk) error {
	if d.Size < other.Size {
		// "cannot subtract <amount> from <current>": the arguments were previously swapped
		return fmt.Errorf("size: underflow, cannot subtract %v from %v", other.Size, d.Size)
	}
	d.Size -= other.Size
	return nil
}
// NetworkInfo represents the network information.
// TODO: not yet used, but can be used to capture the network information
type NetworkInfo struct {
// Bandwidth in bits per second (b/s)
Bandwidth uint64
// NetworkType represents the network type, e.g., "Ethernet", "Wi-Fi", "Cellular"
NetworkType string
}
// GPUMetadata holds the metadata of the GPU.
type GPUMetadata struct {
// PCIAddress of the device (presumably in the same AAAA:BB:CC.C format as GPU.PCIAddress)
PCIAddress string
}
// ConvertBytesToGB converts a byte count to decimal gigabytes (1 GB = 1e9 bytes).
func ConvertBytesToGB(bytes uint64) float64 {
	const bytesPerGB = 1e9
	return float64(bytes) / bytesPerGB
}
package types
// NetP2P is the identifier string "p2p" (presumably the peer-to-peer network type — confirm with callers).
const (
NetP2P = "p2p"
)
// NetworkSpec is a stub with no fields. Please expand based on requirements.
type NetworkSpec struct{}
// NetConfig is a stub. Please expand it or completely change it based on requirements.
// Its only field is the network SpecConfig, exposed through GetNetworkConfig.
type NetConfig struct {
NetworkSpec SpecConfig `json:"network_spec"` // Network specification
}
// GetNetworkConfig returns a pointer to nc's NetworkSpec, so changes made
// through it are visible in nc.
func (nc *NetConfig) GetNetworkConfig() *SpecConfig {
	spec := &nc.NetworkSpec
	return spec
}
// NetworkStats should contain all network info the user is interested in.
// For now it only has the peer ID and listening address. Reachability,
// local/remote addresses etc. can be added when necessary.
type NetworkStats struct {
ID string `json:"id"`
ListenAddr string `json:"listen_addr"`
}
// MessageInfo is a stub. Please expand it or completely change it based on requirements.
type MessageInfo struct {
Info string `json:"info"` // Message information
}
package types
import (
"reflect"
)
// ConvertNumericToFloat64 converts a value of a built-in Go integer or float
// type to float64. For anything else it returns (0, false), including named
// types whose underlying type is numeric. Large 64-bit integers may lose
// precision in the conversion.
func ConvertNumericToFloat64(n any) (float64, bool) {
	switch v := n.(type) {
	case float64:
		return v, true
	case float32:
		return float64(v), true
	case uint, uint8, uint16, uint32, uint64:
		return float64(reflect.ValueOf(v).Uint()), true
	case int, int8, int16, int32, int64:
		return float64(reflect.ValueOf(v).Int()), true
	}
	return 0, false
}
package types
import (
"context"
"fmt"
)
// Resources represents the resources of the machine.
// For GORM, CPU, RAM and Disk are embedded as prefixed columns, and GPUs are
// linked through GPU.ResourceID.
type Resources struct {
CPU CPU `gorm:"embedded;embeddedPrefix:cpu_"`
GPUs GPUs `gorm:"foreignKey:ResourceID"`
RAM RAM `gorm:"embedded;embeddedPrefix:ram_"`
Disk Disk `gorm:"embedded;embeddedPrefix:disk_"`
}
// implements the Calculable and Comparable interfaces
var (
	_ Calculable[Resources] = (*Resources)(nil)
	_ Comparable[Resources] = (*Resources)(nil)
)

// Compare compares r against other component by component (CPU, RAM, Disk
// and GPUs) and combines the per-component results with
// ComplexComparison.Result.
func (r *Resources) Compare(other Resources) Comparison {
	results := make(ComplexComparison)
	results["CPU"] = r.CPU.Compare(other.CPU)
	results["RAM"] = r.RAM.Compare(other.RAM)
	results["Disk"] = r.Disk.Compare(other.Disk)
	results["GPUs"] = r.GPUs.Compare(other.GPUs)
	return results.Result()
}
// Equal reports whether r and other have the same RAM size, CPU core count
// and disk size. GPUs and the remaining CPU/RAM/Disk fields are not compared.
func (r *Resources) Equal(other Resources) bool {
	return r.RAM.Size == other.RAM.Size &&
		r.CPU.Cores == other.CPU.Cores &&
		r.Disk.Size == other.Disk.Size
}
// Add adds other to r, component by component (CPU, RAM, Disk, GPUs).
//
// Errors from a component are wrapped with %w, so callers can inspect them
// with errors.Is/errors.As. If a component fails, the components before it
// have already been updated.
func (r *Resources) Add(other Resources) error {
	if err := r.CPU.Add(other.CPU); err != nil {
		return fmt.Errorf("error adding CPU: %w", err)
	}
	if err := r.RAM.Add(other.RAM); err != nil {
		return fmt.Errorf("error adding RAM: %w", err)
	}
	if err := r.Disk.Add(other.Disk); err != nil {
		return fmt.Errorf("error adding Disk: %w", err)
	}
	if err := r.GPUs.Add(other.GPUs); err != nil {
		return fmt.Errorf("error adding GPUs: %w", err)
	}
	return nil
}
// Subtract subtracts other from r, component by component (CPU, RAM, Disk, GPUs).
//
// Errors (e.g. underflow) are wrapped with %w, so callers can inspect them
// with errors.Is/errors.As. If a component fails, the components before it
// have already been decremented. Callers should discard r on error.
func (r *Resources) Subtract(other Resources) error {
	if err := r.CPU.Subtract(other.CPU); err != nil {
		return fmt.Errorf("error subtracting CPU: %w", err)
	}
	if err := r.RAM.Subtract(other.RAM); err != nil {
		return fmt.Errorf("error subtracting RAM: %w", err)
	}
	if err := r.Disk.Subtract(other.Disk); err != nil {
		return fmt.Errorf("error subtracting Disk: %w", err)
	}
	if err := r.GPUs.Subtract(other.GPUs); err != nil {
		return fmt.Errorf("error subtracting GPUs: %w", err)
	}
	return nil
}
// MachineResources represents the total resources of the machine,
// stored as a database record (BaseDBModel).
type MachineResources struct {
BaseDBModel
Resources
}
// FreeResources represents the free resources of the machine,
// stored as a database record (BaseDBModel).
type FreeResources struct {
BaseDBModel
Resources
}
// OnboardedResources represents the onboarded resources of the machine,
// stored as a database record (BaseDBModel).
type OnboardedResources struct {
BaseDBModel
Resources
}
// ResourceAllocation represents the allocation of resources for a job.
// It is stored as a database record keyed by JobID.
type ResourceAllocation struct {
BaseDBModel
JobID string
Resources
}
// ResourceManager is an interface that defines the methods to manage the resources of the machine
type ResourceManager interface {
// AllocateResources allocates the resources required by a job
AllocateResources(context.Context, ResourceAllocation) error
// DeallocateResources deallocates the resources required by a job
// (the string argument is presumably the job ID — confirm with implementations).
DeallocateResources(context.Context, string) error
// GetTotalAllocation returns the total allocations for the jobs
GetTotalAllocation() (Resources, error)
// GetFreeResources returns the free resources in the allocation pool
GetFreeResources(ctx context.Context) (FreeResources, error)
// GetOnboardedResources returns the onboarded resources of the machine
GetOnboardedResources(context.Context) (OnboardedResources, error)
// UpdateOnboardedResources updates the onboarded resources of the machine in the database
UpdateOnboardedResources(context.Context, Resources) error
}
package types
import (
"errors"
"strings"
"gitlab.com/nunet/device-management-service/utils/validate"
)
// SpecConfig represents a configuration for a spec.
// A SpecConfig can be used to define an engine spec, a storage volume, etc.
type SpecConfig struct {
	// Type of the spec (e.g. docker, firecracker, storage, etc.)
	Type string `json:"type"`
	// Params of the spec
	Params map[string]interface{} `json:"params,omitempty"`
}

// Config is implemented by configurations that expose a network spec.
type Config interface {
	GetNetworkConfig() *SpecConfig
}

// NewSpecConfig returns a SpecConfig of type t with an empty, non-nil Params map.
func NewSpecConfig(t string) *SpecConfig {
	spec := new(SpecConfig)
	spec.Type = t
	spec.Params = map[string]interface{}{}
	return spec
}

// WithParam sets Params[key] to value, allocating Params if needed, and
// returns s so calls can be chained.
func (s *SpecConfig) WithParam(key string, value interface{}) *SpecConfig {
	if s.Params == nil {
		s.Params = map[string]interface{}{}
	}
	s.Params[key] = value
	return s
}

// Normalize trims surrounding whitespace from Type and replaces an empty or
// nil Params map with a fresh empty map, so empty and nil are treated the
// same. It is a no-op on a nil receiver.
func (s *SpecConfig) Normalize() {
	if s == nil {
		return
	}
	s.Type = strings.TrimSpace(s.Type)
	if len(s.Params) > 0 {
		return
	}
	s.Params = map[string]interface{}{}
}
// Validate returns an error if s is nil or its Type is blank.
func (s *SpecConfig) Validate() error {
	switch {
	case s == nil:
		return errors.New("nil spec config")
	case validate.IsBlank(s.Type):
		return errors.New("missing spec type")
	default:
		return nil
	}
}
// IsType reports whether s has type t. The comparison ignores case and
// trims whitespace around t (but not around s.Type). A nil s never matches.
func (s *SpecConfig) IsType(t string) bool {
	return s != nil && strings.EqualFold(s.Type, strings.TrimSpace(t))
}
// IsEmpty reports whether s is nil, or has a blank Type and no Params.
func (s *SpecConfig) IsEmpty() bool {
	if s == nil {
		return true
	}
	return validate.IsBlank(s.Type) && len(s.Params) == 0
}
package types
import (
"context"
)
// TelemetryConfig holds the configuration for the telemetry system.
type TelemetryConfig struct {
ServiceName string
GlobalEndpoint string
// ObservabilityLevel is the level name (presumably parsed with ParseObservabilityLevel — confirm).
ObservabilityLevel string
// CollectorConfigs maps a collector name to its configuration (keys unverified here).
CollectorConfigs map[string]CollectorConfig
TelemetryMode string
}
// CollectorConfig holds the configuration for individual collectors:
// the collector type and the endpoint it reports to.
type CollectorConfig struct {
CollectorType string
CollectorEndpoint string
}
// Event represents a telemetry event with its details: the originating
// context, its level, a message and an arbitrary key/value payload.
type Event struct {
Context context.Context
Level ObservabilityLevel
Message string
Payload map[string]interface{}
}
// ObservabilityLevel defines the levels of observability.
type ObservabilityLevel int

// Observability levels, from TRACE (1) to FATAL (6).
const (
	TRACE ObservabilityLevel = 1
	DEBUG ObservabilityLevel = 2
	INFO  ObservabilityLevel = 3
	WARN  ObservabilityLevel = 4
	ERROR ObservabilityLevel = 5
	FATAL ObservabilityLevel = 6
)

// ParseObservabilityLevel converts the exact, case-sensitive name of a level
// (e.g. "DEBUG") to its ObservabilityLevel. Unrecognised names fall back to
// INFO. The returned error is currently always nil.
func ParseObservabilityLevel(levelStr string) (ObservabilityLevel, error) {
	for _, level := range []ObservabilityLevel{TRACE, DEBUG, INFO, WARN, ERROR, FATAL} {
		if level.String() == levelStr {
			return level, nil
		}
	}
	return INFO, nil
}

// String returns the upper-case name of the level, or "UNKNOWN" for values
// outside TRACE..FATAL.
func (level ObservabilityLevel) String() string {
	if level < TRACE || level > FATAL {
		return "UNKNOWN"
	}
	names := [...]string{"TRACE", "DEBUG", "INFO", "WARN", "ERROR", "FATAL"}
	return names[level-TRACE]
}
package types
import (
"time"
"github.com/google/uuid"
"gorm.io/gorm"
)
// BaseDBModel is a base model for all entities. It'll be mainly used for database
// records.
type BaseDBModel struct {
// ID is a UUID string assigned by BeforeCreate.
ID string `gorm:"type:uuid"`
// CreatedAt is set by BeforeCreate.
CreatedAt time.Time
// UpdatedAt is set by BeforeUpdate.
UpdatedAt time.Time
// DeletedAt is GORM's soft-delete marker (indexed).
DeletedAt gorm.DeletedAt `gorm:"index"`
}
// GetID returns the entity's ID.
func (m BaseDBModel) GetID() string {
	id := m.ID
	return id
}
// BeforeCreate is a GORM hook that runs before an entity is inserted. It
// always assigns a fresh random UUID to ID, overwriting any existing value,
// and sets CreatedAt to the current time. It never returns an error.
// It should not be called directly; the logic could move into generic
// repository create methods.
func (m *BaseDBModel) BeforeCreate(_ *gorm.DB) error {
	m.ID, m.CreatedAt = uuid.NewString(), time.Now()
	return nil
}
// BeforeUpdate is a GORM hook that runs before an entity is updated and sets
// UpdatedAt to the current time. It never returns an error and should not be
// called directly; the logic could move into generic repository methods.
func (m *BaseDBModel) BeforeUpdate(_ *gorm.DB) error {
	now := time.Now()
	m.UpdatedAt = now
	return nil
}
package types
import (
"math"
"reflect"
"slices"
)
// IsStrictlyContained reports whether every element of rightSlice is present
// in leftSlice, compared with ==. Contrary to the usual vacuous-truth
// convention, it returns false when rightSlice is empty.
func IsStrictlyContained(leftSlice, rightSlice []interface{}) bool {
	if len(rightSlice) == 0 {
		return false
	}
	for _, want := range rightSlice {
		if !slices.Contains(leftSlice, want) {
			return false
		}
	}
	return true
}
// IsSameShallowType reports whether a and b have the same dynamic type. Only
// the top-level type is compared; values and elements are not inspected.
func IsSameShallowType(a, b interface{}) bool {
	return reflect.TypeOf(a) == reflect.TypeOf(b)
}
// round rounds value to the given number of decimal places, with half-way
// cases rounded away from zero (math.Round). The arithmetic is done in
// float64 and the result converted back to T.
func round[T float32 | float64](value T, places int) T {
	scale := math.Pow(10, float64(places))
	scaled := float64(value) * scale
	return T(math.Round(scaled) / scale)
}
// SliceContainsOneValue reports whether every element of slice equals value.
// An empty slice yields true.
func SliceContainsOneValue(slice []Comparison, value Comparison) bool {
	for i := range slice {
		if slice[i] != value {
			return false
		}
	}
	return true
}
// returnBestMatch picks the most favourable entry of dimension and returns
// it with the index of its first occurrence. Preference order: Equal, then
// Better, then Worse, then Error. Equal comes first because it is the most
// efficient match. If dimension holds none of these values (e.g. it is
// empty), it returns (Error, -1).
func returnBestMatch(dimension []Comparison) (Comparison, int) {
	for _, preferred := range []Comparison{Equal, Better, Worse, Error} {
		if idx := slices.Index(dimension, preferred); idx >= 0 {
			return preferred, idx
		}
	}
	return Error, -1
}
// removeIndex removes the element at index from every sub-slice of slice and
// returns slice. The removal happens in place by shifting elements within
// each sub-slice's backing array. Sub-slices for which index is out of range,
// including any negative index, are left unchanged.
func removeIndex(slice [][]Comparison, index int) [][]Comparison {
	if index < 0 {
		return slice
	}
	for i := range slice {
		row := slice[i]
		if index >= len(row) {
			continue
		}
		slice[i] = append(row[:index], row[index+1:]...)
	}
	return slice
}
// ConvertTypedSliceToUntypedSlice copies the elements of any slice into a
// new []interface{}. It returns nil if typedSlice is not a slice, including
// when it is nil.
func ConvertTypedSliceToUntypedSlice(typedSlice interface{}) []interface{} {
	v := reflect.ValueOf(typedSlice)
	if v.Kind() != reflect.Slice {
		return nil
	}
	n := v.Len()
	out := make([]interface{}, 0, n)
	for i := 0; i < n; i++ {
		out = append(out, v.Index(i).Interface())
	}
	return out
}
// IsStrictlyContainedInt reports whether every element of rightSlice is
// present in leftSlice. Like IsStrictlyContained, it returns false when
// rightSlice is empty.
func IsStrictlyContainedInt(leftSlice, rightSlice []int) bool {
	if len(rightSlice) == 0 {
		return false
	}
	for _, want := range rightSlice {
		if !slices.Contains(leftSlice, want) {
			return false
		}
	}
	return true
}
package utils
import (
"bytes"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"time"
"github.com/cosmos/btcutil/bech32"
"github.com/ethereum/go-ethereum/common"
"github.com/fivebinaries/go-cardano-serialization/address"
"gitlab.com/nunet/device-management-service/db"
"gitlab.com/nunet/device-management-service/types"
)
// KoiosEndpoint type for Koios rest api endpoints (host names, used as https://<host>/api/v1/...)
type KoiosEndpoint string
const (
// KoiosMainnet - mainnet Koios rest api endpoint
KoiosMainnet KoiosEndpoint = "api.koios.rest"
// KoiosPreProd - testnet preprod Koios rest api endpoint
KoiosPreProd KoiosEndpoint = "preprod.koios.rest"
)
// UTXOs describes a single UTXO by its transaction hash and spent flag
// (NOTE(review): presumably as returned by the Koios API — confirm).
type UTXOs struct {
TxHash string `json:"tx_hash"`
IsSpent bool `json:"is_spent"`
}
// TxHashResp is one entry returned by GetJobTxHashes: a service's transaction
// hash, its transaction type and its creation time (formatted with time.Time.String).
type TxHashResp struct {
TxHash string `json:"tx_hash"`
TransactionType string `json:"transaction_type"`
DateTime string `json:"date_time"`
}
// ClaimCardanoTokenBody is the input of RequestReward. Only TxHash is used
// there to look up the service.
type ClaimCardanoTokenBody struct {
ComputeProviderAddress string `json:"compute_provider_address"`
TxHash string `json:"tx_hash"`
}
// RewardRespToCPD is the reward information returned by RequestReward,
// copied from the matching service record.
type RewardRespToCPD struct {
ServiceProviderAddr string `json:"service_provider_addr"`
ComputeProviderAddr string `json:"compute_provider_addr"`
RewardType string `json:"reward_type,omitempty"`
SignatureDatum string `json:"signature_datum,omitempty"`
MessageHashDatum string `json:"message_hash_datum,omitempty"`
Datum string `json:"datum,omitempty"`
SignatureAction string `json:"signature_action,omitempty"`
MessageHashAction string `json:"message_hash_action,omitempty"`
Action string `json:"action,omitempty"`
}
// UpdateTxStatusBody is the input of UpdateStatus; Address is the smart
// contract address whose UTXOs are fetched.
type UpdateTxStatusBody struct {
Address string `json:"address,omitempty"`
}
// GetJobTxHashes returns the transaction hashes of services that report logs
// to log.nunet.io.
//
// clean must be "done", "refund", "withdraw" or "". Before querying, services
// whose transaction_type equals clean are deleted; a failed delete is logged
// and otherwise ignored. NOTE(review): an empty clean still deletes rows with
// an empty transaction_type — confirm this is intended.
//
// When size is 0, every matching service with a non-null transaction type is
// returned. Otherwise the result comes from getLimitedTransactions(size).
func GetJobTxHashes(size int, clean string) ([]TxHashResp, error) {
	if clean != "done" && clean != "refund" && clean != "withdraw" && clean != "" {
		return nil, fmt.Errorf("invalid clean_tx parameter")
	}

	err := db.DB.Where("transaction_type = ?", clean).Delete(&types.Services{}).Error
	if err != nil {
		// The sugared logger (presumably zap's) formats with fmt.Sprintf, which
		// does not support %w and would print "%!w(...)"; use %v instead.
		zlog.Sugar().Errorf("%v", err)
	}

	services := make([]types.Services, 0)
	if size == 0 {
		err = db.DB.
			Where("tx_hash IS NOT NULL").
			Where("log_url LIKE ?", "%log.nunet.io%").
			Where("transaction_type is NOT NULL").
			Find(&services).Error
		if err != nil {
			zlog.Sugar().Errorf("%v", err)
			return nil, fmt.Errorf("no job deployed to request reward for: %w", err)
		}
	} else {
		services, err = getLimitedTransactions(size)
		if err != nil {
			zlog.Sugar().Errorf("%v", err)
			return nil, fmt.Errorf("could not get limited transactions: %w", err)
		}
	}

	// non-nil even when empty, so it encodes as [] rather than null
	resp := make([]TxHashResp, 0, len(services))
	for _, service := range services {
		resp = append(resp, TxHashResp{
			TxHash:          service.TxHash,
			TransactionType: service.TransactionType,
			DateTime:        service.CreatedAt.String(),
		})
	}
	return resp, nil
}
// RequestReward builds the reward response for the service identified by
// claim.TxHash.
//
// It returns an error if the lookup fails, if no service has that hash, or if
// the service's job is still running.
func RequestReward(claim ClaimCardanoTokenBody) (*RewardRespToCPD, error) {
	// At some point, management dashboard should send container ID to identify
	// against which container we are requesting reward
	service := types.Services{
		TxHash: claim.TxHash,
	}
	// Find does not report "record not found" when no row matches, so check
	// RowsAffected explicitly. Otherwise an unknown hash would yield an
	// empty reward.
	result := db.DB.Where("tx_hash = ?", claim.TxHash).Find(&service)
	if result.Error != nil {
		zlog.Sugar().Errorln(result.Error)
		return nil, fmt.Errorf("unknown tx hash: %w", result.Error)
	}
	if result.RowsAffected == 0 {
		return nil, fmt.Errorf("unknown tx hash: %s", claim.TxHash)
	}
	zlog.Sugar().Infof("service found from txHash: %+v", service)

	if service.JobStatus == "running" {
		return nil, fmt.Errorf("job is still running")
	}

	reward := RewardRespToCPD{
		ServiceProviderAddr: service.ServiceProviderAddr,
		ComputeProviderAddr: service.ComputeProviderAddr,
		RewardType:          service.TransactionType,
		SignatureDatum:      service.SignatureDatum,
		MessageHashDatum:    service.MessageHashDatum,
		Datum:               service.Datum,
		SignatureAction:     service.SignatureAction,
		MessageHashAction:   service.MessageHashAction,
		Action:              service.Action,
	}
	return &reward, nil
}
// SendStatus records the outcome of a blockchain transaction and returns
// status.TransactionStatus unchanged.
//
// On "success", the service with the matching tx hash is marked with
// transaction type "done". Lookup and save failures are logged, not returned.
func SendStatus(status types.BlockchainTxStatus) string {
	if status.TransactionStatus != "success" {
		return status.TransactionStatus
	}
	zlog.Sugar().Infof("withdraw transaction successful - updating DB")

	// Partial deletion of entry
	var service types.Services
	result := db.DB.Where("tx_hash = ?", status.TxHash).Find(&service)
	switch {
	case result.Error != nil:
		zlog.Sugar().Errorln(result.Error)
	case result.RowsAffected == 0:
		// Saving the zero-value struct here would insert a brand-new row
		// instead of updating an existing one.
		zlog.Sugar().Errorf("no service found for tx hash %s", status.TxHash)
	default:
		service.TransactionType = "done"
		if err := db.DB.Save(&service).Error; err != nil {
			zlog.Sugar().Errorln(err)
		}
	}
	return status.TransactionStatus
}
// UpdateStatus refreshes the transaction status of pending services.
//
// It works in three steps:
//  1. Fetch the UTXOs of the smart contract at body.Address from the Koios
//     preprod endpoint.
//  2. Load services that have a tx hash and a log.nunet.io log URL, are not
//     soft-deleted, were created at least five minutes ago, and have a
//     transaction type that is neither empty nor "done".
//  3. Pass both to UpdateTransactionStatus.
//
// Errors from each step are logged and returned wrapped.
func UpdateStatus(body UpdateTxStatusBody) error {
	utxoHashes, err := GetUTXOsOfSmartContract(body.Address, KoiosPreProd)
	if err != nil {
		zlog.Sugar().Errorln(err)
		return fmt.Errorf("failed to fetch UTXOs from Blockchain: %w", err)
	}

	fiveMinAgo := time.Now().Add(-5 * time.Minute)
	var services []types.Services
	err = db.DB.
		Where("tx_hash IS NOT NULL").
		Where("log_url LIKE ?", "%log.nunet.io%").
		Where("transaction_type IS NOT NULL").
		Where("deleted_at IS NULL").
		Where("created_at <= ?", fiveMinAgo).
		Not("transaction_type = ?", "done").
		Not("transaction_type = ?", "").
		Find(&services).Error
	if err != nil {
		zlog.Sugar().Errorln(err)
		return fmt.Errorf("no job deployed to request reward for: %w", err)
	}

	if err := UpdateTransactionStatus(services, utxoHashes); err != nil {
		zlog.Sugar().Errorln(err)
		// keep the cause; it was previously dropped from the returned error
		return fmt.Errorf("failed to update transaction status: %w", err)
	}
	return nil
}
// getLimitedTransactions returns two groups of services, in this order:
//   - every service whose transaction type is neither empty nor "done";
//   - at most sizeDone of the most recently created "done" services.
//
// Only services with a tx hash and a log.nunet.io log URL are considered.
func getLimitedTransactions(sizeDone int) ([]types.Services, error) {
	var done []types.Services
	if err := db.DB.
		Where("tx_hash IS NOT NULL").
		Where("log_url LIKE ?", "%log.nunet.io%").
		Where("transaction_type = ?", "done").
		Order("created_at DESC").
		Limit(sizeDone).
		Find(&done).Error; err != nil {
		return []types.Services{}, err
	}

	var pending []types.Services
	if err := db.DB.
		Where("tx_hash IS NOT NULL").
		Where("log_url LIKE ?", "%log.nunet.io%").
		Where("transaction_type IS NOT NULL").
		Not("transaction_type = ?", "done").
		Not("transaction_type = ?", "").
		Find(&pending).Error; err != nil {
		return []types.Services{}, err
	}

	return append(pending, done...), nil
}
// isValidCardano sets *valid to true when addr parses as a Cardano address.
// On a parse error, *valid is left untouched. If the parser panics, the
// panic is recovered and *valid is set to false.
func isValidCardano(addr string, valid *bool) {
	defer func() {
		if recover() != nil {
			*valid = false
		}
	}()
	_, err := address.NewAddress(addr)
	if err != nil {
		return
	}
	*valid = true
}
// ValidateAddress accepts only Cardano wallet addresses. Ethereum (hex)
// addresses are rejected with a dedicated error, and anything that does not
// parse as Cardano is reported as an invalid Cardano address.
func ValidateAddress(addr string) error {
	if common.IsHexAddress(addr) {
		return errors.New("ethereum wallet address not allowed")
	}
	var ok bool
	isValidCardano(addr, &ok)
	if !ok {
		return errors.New("invalid cardano wallet address")
	}
	return nil
}
// GetAddressPaymentCredential bech32-decodes addr, regroups the payload into
// bytes, and returns hex characters [2:58] of it. That is bytes 1..28, the
// 28 bytes following the first (header) byte; presumably the payment
// credential of a Shelley-style address (TODO confirm for every address type
// passed in).
//
// An error is returned when decoding fails, or when the payload is too short
// to contain those bytes (instead of panicking on the slice).
func GetAddressPaymentCredential(addr string) (string, error) {
	// 1023 is passed as the decoder's length limit.
	_, data, err := bech32.Decode(addr, 1023)
	if err != nil {
		return "", fmt.Errorf("decoding bech32 failed: %w", err)
	}
	converted, err := bech32.ConvertBits(data, 5, 8, false)
	if err != nil {
		return "", fmt.Errorf("decoding bech32 failed: %w", err)
	}
	encoded := hex.EncodeToString(converted)
	if len(encoded) < 58 {
		return "", fmt.Errorf("address payload too short for payment credential: %d bytes", len(converted))
	}
	return encoded[2:58], nil
}
// GetTxReceiver returns the receiver recorded in a transaction's inline datum.
//
// It posts txHash to the Koios tx_info endpoint on endpoint. It returns the
// "bytes" value of the second field of the inline datum on the second output
// of the first returned transaction. An error is returned when the request
// or decoding fails, or when the response does not contain that output and
// field.
func GetTxReceiver(txHash string, endpoint KoiosEndpoint) (string, error) {
	type Request struct {
		TxHashes []string `json:"_tx_hashes"`
	}
	// Marshalling a struct that only holds a []string cannot fail.
	reqBody, _ := json.Marshal(Request{TxHashes: []string{txHash}})
	resp, err := http.Post(
		fmt.Sprintf("https://%s/api/v1/tx_info", endpoint),
		"application/json",
		bytes.NewBuffer(reqBody))
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()
	res := []struct {
		Outputs []struct {
			InlineDatum struct {
				Value struct {
					Fields []struct {
						Bytes string `json:"bytes"`
					} `json:"fields"`
				} `json:"value"`
			} `json:"inline_datum"`
		} `json:"outputs"`
	}{}
	jsonDecoder := json.NewDecoder(resp.Body)
	if err := jsonDecoder.Decode(&res); err != nil && err != io.EOF {
		return "", err
	}
	// The receiver sits at Outputs[1] / Fields[1]. Require both indices to
	// exist so a short response yields an error, not an index-out-of-range panic.
	if len(res) == 0 || len(res[0].Outputs) < 2 || len(res[0].Outputs[1].InlineDatum.Value.Fields) < 2 {
		return "", fmt.Errorf("unable to find receiver")
	}
	receiver := res[0].Outputs[1].InlineDatum.Value.Fields[1].Bytes
	return receiver, nil
}
// GetTxConfirmations returns the number of confirmations of a transaction.
// The value is taken from the last entry of the Koios tx_status response for
// txHash. An error is returned when the request, read, or decode fails, or
// when the response contains no entries (which previously panicked).
func GetTxConfirmations(txHash string, endpoint KoiosEndpoint) (int, error) {
	type Request struct {
		TxHashes []string `json:"_tx_hashes"`
	}
	// Marshalling a struct that only holds a []string cannot fail.
	reqBody, _ := json.Marshal(Request{TxHashes: []string{txHash}})
	resp, err := http.Post(
		fmt.Sprintf("https://%s/api/v1/tx_status", endpoint),
		"application/json",
		bytes.NewBuffer(reqBody))
	if err != nil {
		return 0, err
	}
	defer resp.Body.Close()
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return 0, err
	}
	var res []struct {
		TxHash        string `json:"tx_hash"`
		Confirmations int    `json:"num_confirmations"`
	}
	if err := json.Unmarshal(body, &res); err != nil {
		return 0, err
	}
	// An empty array or a JSON null leaves res empty; indexing it would panic.
	if len(res) == 0 {
		return 0, fmt.Errorf("no status returned for transaction %s", txHash)
	}
	return res[len(res)-1].Confirmations, nil
}
// WaitForTxConfirmation polls GetTxConfirmations every 5 seconds until at
// least `confirmations` are reported, then returns nil. It returns the first
// polling error unchanged, or an error "timeout" once `timeout` has elapsed
// since the call. The first poll happens after the first 5 second tick.
func WaitForTxConfirmation(confirmations int, timeout time.Duration, txHash string, endpoint KoiosEndpoint) error {
	ticker := time.NewTicker(5 * time.Second)
	defer ticker.Stop()
	// The deadline is created once. Calling time.After inside the select would
	// restart the countdown on every tick, so a timeout longer than the poll
	// interval would never fire.
	deadline := time.NewTimer(timeout)
	defer deadline.Stop()
	for {
		select {
		case <-ticker.C:
			conf, err := GetTxConfirmations(txHash, endpoint)
			if err != nil {
				return err
			}
			if conf >= confirmations {
				return nil
			}
		case <-deadline.C:
			return errors.New("timeout")
		}
	}
}
// GetUTXOsOfSmartContract asks the Koios address_utxos endpoint (with
// _extended set) for the UTXOs held at address. It returns their tx hashes in
// response order. An empty response body yields an empty, non-nil slice.
func GetUTXOsOfSmartContract(address string, endpoint KoiosEndpoint) ([]string, error) {
	type Request struct {
		Address  []string `json:"_addresses"`
		Extended bool     `json:"_extended"`
	}
	payload, _ := json.Marshal(Request{Address: []string{address}, Extended: true})
	target := fmt.Sprintf("https://%s/api/v1/address_utxos", endpoint)

	resp, postErr := http.Post(target, "application/json", bytes.NewBuffer(payload))
	if postErr != nil {
		return nil, fmt.Errorf("error making POST request: %v", postErr)
	}
	defer resp.Body.Close()

	var utxos []UTXOs
	if decodeErr := json.NewDecoder(resp.Body).Decode(&utxos); decodeErr != nil && decodeErr != io.EOF {
		return nil, decodeErr
	}

	hashes := make([]string, 0, len(utxos))
	for i := range utxos {
		hashes = append(hashes, utxos[i].TxHash)
	}
	return hashes, nil
}
// UpdateTransactionStatus updates the status of claimed transactions in the
// local DB.
//
// Every service whose TxHash is NOT among utxoHashes has its TransactionType
// mapped as follows, and is then saved with db.DB.Save:
//   - "withdraw"                     -> transactionWithdrawnStatus
//   - "refund"                       -> transactionRefundedStatus
//   - "distribute-50"/"distribute-75" -> transactionDistributedStatus
//
// Other types are saved unchanged. Services whose hash is still present in
// utxoHashes are skipped. The first Save error is returned.
func UpdateTransactionStatus(services []types.Services, utxoHashes []string) error {
	for _, service := range services {
		if SliceContains(utxoHashes, service.TxHash) {
			continue
		}
		switch service.TransactionType {
		case "withdraw":
			service.TransactionType = transactionWithdrawnStatus
		case "refund":
			service.TransactionType = transactionRefundedStatus
		// Go cases do not fall through: both distribution variants must share
		// one case, otherwise "distribute-50" is silently left unchanged.
		case "distribute-50", "distribute-75":
			service.TransactionType = transactionDistributedStatus
		}
		// Copy before taking the address so no pointer to the loop variable escapes.
		s := service
		if err := db.DB.Save(&s).Error; err != nil {
			return err
		}
	}
	return nil
}
package utils
import (
"errors"
"fmt"
"os"
"path/filepath"
"github.com/spf13/afero"
)
// GetDirectorySize walks path on fs and returns the sum of Size() over every
// entry that is not a directory. The first walk error stops the walk and is
// returned wrapped, together with a size of 0.
func GetDirectorySize(fs afero.Fs, path string) (int64, error) {
	var total int64
	accumulate := func(_ string, fi os.FileInfo, walkErr error) error {
		switch {
		case walkErr != nil:
			return walkErr
		case fi.IsDir():
			return nil
		default:
			total += fi.Size()
			return nil
		}
	}
	if err := afero.Walk(fs, path, accumulate); err != nil {
		return 0, fmt.Errorf("failed to calculate volume size: %w", err)
	}
	return total, nil
}
// WriteToFile writes data to filePath on fs and returns filePath.
//
// Missing parent directories are created first (os.ModePerm). The file itself
// is created or truncated with fs.Create. An error from closing the file is
// reported too, because a failed close can mean the data did not reach
// storage. On any error the returned path is "".
func WriteToFile(fs afero.Fs, data []byte, filePath string) (written string, err error) {
	if err := fs.MkdirAll(filepath.Dir(filePath), os.ModePerm); err != nil {
		return "", fmt.Errorf("failed to open path: %w", err)
	}
	file, err := fs.Create(filePath)
	if err != nil {
		return "", fmt.Errorf("failed to create path: %w", err)
	}
	defer func() {
		if cerr := file.Close(); cerr != nil && err == nil {
			written, err = "", fmt.Errorf("failed to close file: %w", cerr)
		}
	}()
	n, err := file.Write(data)
	if err != nil {
		return "", fmt.Errorf("failed to write data to path: %w", err)
	}
	// io.Writer must return an error on a short write. This guards against
	// implementations that break that contract.
	if n != len(data) {
		return "", errors.New("failed to write the size of data to file")
	}
	return filePath, nil
}
// FileExists reports whether filename exists on fs and is not a directory.
//
// Any Stat failure is treated as "not available", not only a does-not-exist
// error (e.g. also a permission error). Previously such errors fell through
// to info.IsDir() on a nil FileInfo and panicked.
func FileExists(fs afero.Fs, filename string) bool {
	info, err := fs.Stat(filename)
	if err != nil {
		return false
	}
	return !info.IsDir()
}
package utils
import (
"bytes"
"fmt"
"io"
"net/http"
"net/url"
"path"
)
// HTTPClient sends JSON requests to paths resolved under BaseURL and an
// API version prefix.
type HTTPClient struct {
	BaseURL    string       // scheme and host, optionally followed by a path prefix
	APIVersion string       // path segment placed before every relative path, e.g. "v1"
	Client     *http.Client // client used to send the requests
}

// NewHTTPClient returns an HTTPClient for baseURL and version that sends its
// requests with http.DefaultClient.
func NewHTTPClient(baseURL, version string) *HTTPClient {
	return &HTTPClient{
		BaseURL:    baseURL,
		APIVersion: version,
		Client:     http.DefaultClient,
	}
}

// MakeRequest performs an HTTP request with the given method, path, and body.
// The request URL is BaseURL with its path extended by APIVersion and then
// relativePath. JSON Content-Type and Accept headers are set. It returns the
// response body, the status code, and an error if any. Non-2xx statuses are
// not treated as errors. Returned errors wrap the underlying cause.
func (c *HTTPClient) MakeRequest(method, relativePath string, body []byte) ([]byte, int, error) {
	u, err := url.Parse(c.BaseURL)
	if err != nil {
		return nil, 0, fmt.Errorf("failed to parse base URL: %w", err)
	}
	// Extend, rather than replace, any path already present in BaseURL.
	u.Path = path.Join(u.Path, c.APIVersion, relativePath)
	req, err := http.NewRequest(method, u.String(), bytes.NewBuffer(body))
	if err != nil {
		return nil, 0, fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json; charset=utf-8")
	req.Header.Set("Accept", "application/json")
	resp, err := c.Client.Do(req)
	if err != nil {
		return nil, 0, fmt.Errorf("request failed: %w", err)
	}
	defer resp.Body.Close()
	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, 0, fmt.Errorf("failed to read response body: %w", err)
	}
	return respBody, resp.StatusCode, nil
}
package utils
import (
"github.com/uptrace/opentelemetry-go-extra/otelzap"
"gitlab.com/nunet/device-management-service/telemetry/logger"
)
// zlog is the package-wide logger, named "utils" and set up in init.
var zlog *otelzap.Logger

// TransactionType values assigned by UpdateTransactionStatus.
const (
	transactionWithdrawnStatus   = "withdrawn"
	transactionRefundedStatus    = "refunded"
	transactionDistributedStatus = "distributed"
)

func init() {
	zlog = logger.OtelZapLogger("utils")
}
package utils
import (
"io"
"sync"
"time"
)
// IOProgress is a snapshot of how far a transfer through a Reader or Writer
// has got relative to its expected size. A size of -1 means the total is
// unknown; Complete, Percent and Estimated all honour that convention.
type IOProgress struct {
	n         float64   // bytes transferred so far
	size      float64   // expected total in bytes; -1 when unknown
	started   time.Time // when the Reader/Writer was created
	estimated time.Time // completion estimate used by Remaining when non-zero
	err       error     // last error returned by the wrapped reader/writer
}

// Reader wraps an io.Reader. It adds the bytes read, and records the last
// error seen, into Progress; lock serialises those updates.
// NOTE(review): Progress is exported and read without taking lock; confirm
// it is not polled concurrently with Read.
type Reader struct {
	reader   io.Reader
	lock     sync.RWMutex
	Progress IOProgress
}

// Writer wraps an io.Writer. It adds the bytes written, and records the last
// error seen, into Progress; lock serialises those updates.
// NOTE(review): Progress is exported and read without taking lock; confirm
// it is not polled concurrently with Write.
type Writer struct {
	writer   io.Writer
	lock     sync.RWMutex
	Progress IOProgress
}

// ReaderWithProgress wraps r, expecting size bytes in total (-1 if unknown).
func ReaderWithProgress(r io.Reader, size int64) *Reader {
	return &Reader{
		reader:   r,
		Progress: IOProgress{started: time.Now(), size: float64(size)},
	}
}

// WriterWithProgress wraps w, expecting size bytes in total (-1 if unknown).
func WriterWithProgress(w io.Writer, size int64) *Writer {
	return &Writer{
		writer:   w,
		Progress: IOProgress{started: time.Now(), size: float64(size)},
	}
}

// Read forwards to the wrapped reader and records the bytes read and error.
func (r *Reader) Read(p []byte) (n int, err error) {
	n, err = r.reader.Read(p)
	r.lock.Lock()
	r.Progress.n += float64(n)
	r.Progress.err = err
	r.lock.Unlock()
	return n, err
}

// Write forwards to the wrapped writer and records the bytes written and error.
func (w *Writer) Write(p []byte) (n int, err error) {
	n, err = w.writer.Write(p)
	w.lock.Lock()
	w.Progress.n += float64(n)
	w.Progress.err = err
	w.lock.Unlock()
	return n, err
}

// Size returns the expected total size in bytes (-1 if unknown).
func (p IOProgress) Size() float64 {
	return p.size
}

// N returns the number of bytes transferred so far.
func (p IOProgress) N() float64 {
	return p.n
}

// Complete reports whether the transfer has finished. That is the case when
// the wrapped reader/writer returned io.EOF, or, for a known size, when at
// least size bytes have been transferred.
func (p IOProgress) Complete() bool {
	if p.err == io.EOF {
		return true
	}
	if p.size == -1 {
		return false
	}
	return p.n >= p.size
}

// Percent calculates the percentage complete, from 0 to 100.
// It returns 0 when nothing has been transferred or the size is unknown
// (negative), because no meaningful ratio exists in those cases.
func (p IOProgress) Percent() float64 {
	if p.n == 0 {
		return 0
	}
	// Without this guard a size of -1 satisfies n >= size and reports 100%.
	if p.size < 0 {
		return 0
	}
	if p.n >= p.size {
		return 100
	}
	return 100.0 / (p.size / p.n)
}

// Remaining returns the time left until the stored estimate, or until
// Estimated() when no estimate is stored. It is negative once that instant
// has passed, including when no estimate is available (zero time).
func (p IOProgress) Remaining() time.Duration {
	if p.estimated.IsZero() {
		return time.Until(p.Estimated())
	}
	return time.Until(p.estimated)
}

// Estimated extrapolates the completion time from the elapsed time and the
// fraction transferred so far. When nothing has been transferred, or the size
// is unknown (negative), no extrapolation is possible. In that case the
// stored estimate (zero unless set) is returned.
func (p IOProgress) Estimated() time.Time {
	if p.n <= 0 || p.size < 0 {
		return p.estimated
	}
	ratio := p.n / p.size
	elapsed := float64(time.Since(p.started))
	return p.started.Add(time.Duration(elapsed / ratio))
}
package utils
import (
"fmt"
"strings"
"sync"
)
// A SyncMap is a concurrency-safe sync.Map that uses strongly-typed
// method signatures to ensure the types of its stored data are known.
// Like sync.Map, the zero value is an empty map ready for use.
type SyncMap[K comparable, V any] struct {
	sync.Map
}

// SyncMapFromMap converts a standard Go map to a concurrency-safe SyncMap.
func SyncMapFromMap[K comparable, V any](m map[K]V) *SyncMap[K, V] {
	ret := &SyncMap[K, V]{}
	for k, v := range m {
		ret.Put(k, v)
	}
	return ret
}

// Get retrieves the value associated with the given key from the map.
// It returns the value and a boolean indicating whether the key was found.
func (m *SyncMap[K, V]) Get(key K) (V, bool) {
	value, ok := m.Load(key)
	if !ok {
		var empty V
		return empty, false
	}
	// Comma-ok: when V is an interface type and a nil was stored, a plain
	// assertion panics. This yields V's zero value (nil) instead.
	v, _ := value.(V)
	return v, true
}

// Put inserts or updates a key-value pair in the map.
func (m *SyncMap[K, V]) Put(key K, value V) {
	m.Store(key, value)
}

// Iter iterates over each key-value pair in the map, in unspecified order,
// executing the provided function on each pair.
// The iteration stops if the provided function returns false.
func (m *SyncMap[K, V]) Iter(ranger func(key K, value V) bool) {
	m.Range(func(key, value any) bool {
		// Comma-ok for the same nil-interface reason as in Get.
		k, _ := key.(K)
		v, _ := value.(V)
		return ranger(k, v)
	})
}

// Keys returns a slice containing all the keys present in the map.
func (m *SyncMap[K, V]) Keys() []K {
	var keys []K
	m.Iter(func(key K, _ V) bool {
		keys = append(keys, key)
		return true
	})
	return keys
}

// String lists all key-value pairs as "{k1=v1, k2=v2}", in unspecified order.
func (m *SyncMap[K, V]) String() string {
	var sb strings.Builder
	sb.WriteString("{")
	first := true
	m.Range(func(key, value any) bool {
		if !first {
			sb.WriteString(", ")
		}
		first = false
		// %v formats any type; %s rendered non-strings as "%!s(int=1)".
		fmt.Fprintf(&sb, "%v=%v", key, value)
		return true
	})
	sb.WriteString("}")
	return sb.String()
}
package utils
import (
"archive/tar"
"bufio"
"compress/gzip"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"log"
"math/big"
"net/http"
"os"
"path/filepath"
"reflect"
"strings"
"time"
"github.com/google/uuid"
"github.com/spf13/afero"
"golang.org/x/exp/slices"
"gitlab.com/nunet/device-management-service/db"
"gitlab.com/nunet/device-management-service/types"
)
// Remote download location and local path of the vmlinux kernel image.
const (
	KernelFileURL  = "https://d.nunet.io/fc/vmlinux"
	KernelFilePath = "/etc/nunet/vmlinux"
)

// Remote download location and local path of the Ubuntu 20.04 ext4
// filesystem image.
const (
	FilesystemURL  = "https://d.nunet.io/fc/nunet-fc-ubuntu-20.04-0.ext4"
	FilesystemPath = "/etc/nunet/nunet-fc-ubuntu-20.04-0.ext4"
)
// DownloadFile downloads url with a GET request and saves the body to filepath.
//
// The request uses a client with a 10 second timeout, and any status other
// than 200 OK is an error. The destination is only created (or truncated)
// once the server has answered 200, so a failed request leaves an existing
// file untouched. At most maxBytes bytes are accepted: a larger body is an
// error rather than being silently truncated. After a failure that happens
// once the file exists, the partial file is removed.
func DownloadFile(url string, filepath string, maxBytes int64) (err error) {
	zlog.Sugar().Infof("Downloading file '%s' from '%s'", filepath, url)
	req, err := http.NewRequest(http.MethodGet, url, nil)
	if err != nil {
		return err
	}
	// NOTE(review): Timeout bounds the whole exchange, including reading the
	// body; confirm 10s is enough for large files.
	client := &http.Client{
		Timeout: 10 * time.Second,
	}
	resp, err := client.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("failed to download file: server returned %s", resp.Status)
	}
	file, err := os.Create(filepath)
	if err != nil {
		return err
	}
	defer func() {
		if cerr := file.Close(); err == nil {
			err = cerr
		}
		if err != nil {
			// Do not leave a partial or oversized download behind.
			_ = os.Remove(filepath)
		}
	}()
	// Allow one byte beyond maxBytes so an oversized body can be detected.
	// Skip the increment at MaxInt64 to avoid overflow.
	limit := maxBytes
	if limit+1 > limit {
		limit++
	}
	written, err := io.Copy(file, io.LimitReader(resp.Body, limit))
	if err != nil {
		return err
	}
	if written > maxBytes {
		return fmt.Errorf("failed to download file: body exceeds limit of %d bytes", maxBytes)
	}
	log.Println("Finished downloading file '", filepath, "'")
	return nil
}
// ReadHTTPString sends a GET request for url through http.DefaultClient and
// returns the complete response body as a string, whatever the status code.
func ReadHTTPString(url string) (string, error) {
	request, reqErr := http.NewRequest(http.MethodGet, url, nil)
	if reqErr != nil {
		return "", reqErr
	}
	response, doErr := http.DefaultClient.Do(request)
	if doErr != nil {
		return "", doErr
	}
	defer response.Body.Close()

	payload, readErr := io.ReadAll(response.Body)
	if readErr != nil {
		return "", readErr
	}
	return string(payload), nil
}
// RandomString returns a string of n characters, each chosen uniformly from
// the ASCII letters a-z and A-Z using crypto/rand.
// n == 0 yields "". A negative n is rejected with an error; previously
// strings.Builder.Grow panicked on it.
func RandomString(n int) (string, error) {
	const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
	if n < 0 {
		return "", fmt.Errorf("invalid random string length: %d", n)
	}
	limit := big.NewInt(int64(len(charset)))
	sb := strings.Builder{}
	sb.Grow(n)
	for i := 0; i < n; i++ {
		idx, err := rand.Int(rand.Reader, limit)
		if err != nil {
			return "", err
		}
		sb.WriteByte(charset[idx.Int64()])
	}
	return sb.String(), nil
}
// GenerateMachineUUID generates a machine uuid
func GenerateMachineUUID() (string, error) {
var machine types.MachineUUID
machineUUID, err := uuid.NewDCEGroup()
if err != nil {
return "", err
}
machine.UUID = machineUUID.String()
return machine.UUID, nil
}
// GetMachineUUID returns the machine uuid from the DB
func GetMachineUUID() string {
var machine types.MachineUUID
uuid, err := GenerateMachineUUID()
if err != nil {
zlog.Sugar().Errorf("could not generate machine uuid: %v", err)
}
machine.UUID = uuid
result := db.DB.FirstOrCreate(&machine)
if result.Error != nil {
zlog.Sugar().Errorf("could not find or create machine uuid record in DB: %v", result.Error)
}
return machine.UUID
}
// SliceContains reports whether str is one of the elements of s.
func SliceContains(s []string, str string) bool {
	for i := range s {
		if s[i] == str {
			return true
		}
	}
	return false
}
// DeleteFile removes the file at path. When backup is true the file is
// renamed to "<path>.bk.<unix seconds>" instead of being removed.
func DeleteFile(path string, backup bool) (err error) {
	if !backup {
		return os.Remove(path)
	}
	backupPath := fmt.Sprintf("%s.bk.%d", path, time.Now().Unix())
	return os.Rename(path, backupPath)
}
// ReadyForElastic reports whether the ElasticToken loaded from the DB has
// both a NodeID and a ChannelName. The Find error is not checked;
// presumably a failed lookup leaves the token empty, which yields false.
func ReadyForElastic() bool {
	var token types.ElasticToken
	db.DB.Find(&token)
	if token.NodeID == "" {
		return false
	}
	return token.ChannelName != ""
}
// CreateDirectoryIfNotExists creates path, including parents, with mode
// 0o755 when os.Stat reports that it does not exist. If the path exists
// (even as a regular file), or Stat fails for another reason, nothing is
// done and nil is returned.
func CreateDirectoryIfNotExists(path string) error {
	_, statErr := os.Stat(path)
	if !os.IsNotExist(statErr) {
		return nil
	}
	return os.MkdirAll(path, 0o755)
}
// CalculateSHA256Checksum streams the file at filePath through SHA-256 and
// returns the digest as lowercase hex. Open and read errors are returned as is.
func CalculateSHA256Checksum(filePath string) (string, error) {
	f, openErr := os.Open(filePath)
	if openErr != nil {
		return "", openErr
	}
	defer f.Close()

	digest := sha256.New()
	if _, copyErr := io.Copy(digest, f); copyErr != nil {
		return "", copyErr
	}
	return hex.EncodeToString(digest.Sum(nil)), nil
}
// CreateCheckSumFile writes checksum to "<filePath>.sha256.txt", creating or
// truncating that file, and returns its path.
// The error from closing the file is reported as well, because a failed
// close can mean the checksum was not fully written. On any error the
// returned path is "". Errors wrap their cause.
func CreateCheckSumFile(filePath string, checksum string) (sumPath string, err error) {
	sha256FilePath := fmt.Sprintf("%s.sha256.txt", filePath)
	sha256File, err := os.Create(sha256FilePath)
	if err != nil {
		return "", fmt.Errorf("unable to create SHA-256 checksum file: %w", err)
	}
	defer func() {
		if cerr := sha256File.Close(); cerr != nil && err == nil {
			sumPath, err = "", fmt.Errorf("unable to close SHA-256 checksum file: %w", cerr)
		}
	}()
	if _, err = sha256File.WriteString(checksum); err != nil {
		return "", fmt.Errorf("unable to write to SHA-256 checksum file: %w", err)
	}
	return sha256FilePath, nil
}
// SanitizeArchivePath joins the archive entry name t onto the destination
// directory d and returns the result. It returns an error when the joined
// path would land outside d (the "Zip Slip" issue flagged by gosec G305).
//
// Containment is checked on the path relative to d, not with a plain string
// prefix test. A prefix test wrongly accepts siblings that only share a
// prefix: for example d="/tmp/out" with t="../outside/x" joins to
// "/tmp/outside/x", which still starts with "/tmp/out".
func SanitizeArchivePath(d, t string) (v string, err error) {
	v = filepath.Join(d, t)
	rel, err := filepath.Rel(filepath.Clean(d), v)
	if err == nil && rel != ".." && !strings.HasPrefix(rel, ".."+string(filepath.Separator)) {
		return v, nil
	}
	return "", fmt.Errorf("%s: %s", "content filepath is tainted", t)
}
// ExtractTarGzToPath extracts the gzip-compressed tar archive at
// tarGzFilePath into extractedPath, creating directories as needed.
//
// maxBytes bounds both the declared size of any single entry and the total
// number of bytes written across all entries. Exceeding either aborts with an
// error; files already written are left in place. Entry names are validated
// with SanitizeArchivePath. Every non-directory entry is written as a regular
// file (os.Create) holding the bytes the tar reader yields for it.
// NOTE(review): link entries therefore become (presumably empty) regular
// files, and header modes are not applied.
func ExtractTarGzToPath(tarGzFilePath, extractedPath string, maxBytes int64) error {
	// Ensure the target directory exists; create it if it doesn't.
	if err := os.MkdirAll(extractedPath, os.ModePerm); err != nil {
		return fmt.Errorf("error creating target directory: %v", err)
	}
	tarGzFile, err := os.Open(tarGzFilePath)
	if err != nil {
		return fmt.Errorf("error opening tar.gz file: %v", err)
	}
	defer tarGzFile.Close()
	gzipReader, err := gzip.NewReader(tarGzFile)
	if err != nil {
		return fmt.Errorf("error creating gzip reader: %v", err)
	}
	defer gzipReader.Close()
	tarReader := tar.NewReader(gzipReader)
	var totalSize int64
	for {
		header, err := tarReader.Next()
		if err == io.EOF {
			break
		}
		if err != nil {
			return fmt.Errorf("error reading tar header: %v", err)
		}
		if header.Size > maxBytes {
			return fmt.Errorf("file %s exceeds the maximum allowed size of %d bytes", header.Name, maxBytes)
		}
		fullTargetPath, err := SanitizeArchivePath(extractedPath, header.Name)
		if err != nil {
			return fmt.Errorf("failed to santize path %w", err)
		}
		if header.FileInfo().IsDir() {
			if err := os.MkdirAll(fullTargetPath, os.ModePerm); err != nil {
				return fmt.Errorf("error creating directory: %v", err)
			}
			continue
		}
		if err := os.MkdirAll(filepath.Dir(fullTargetPath), os.ModePerm); err != nil {
			return fmt.Errorf("error creating directory: %v", err)
		}
		if err := extractTarEntryToFile(tarReader, fullTargetPath, maxBytes, &totalSize); err != nil {
			return err
		}
	}
	return nil
}

// extractTarEntryToFile copies the current entry of tr into a newly created
// file at target. It adds the bytes written to *totalSize and fails once
// *totalSize exceeds maxBytes. The file is closed before returning; a defer
// inside the caller's loop would keep every file open until the whole
// extraction finished.
func extractTarEntryToFile(tr *tar.Reader, target string, maxBytes int64, totalSize *int64) (err error) {
	newFile, err := os.Create(target)
	if err != nil {
		return fmt.Errorf("error creating file: %v", err)
	}
	defer func() {
		if cerr := newFile.Close(); cerr != nil && err == nil {
			err = fmt.Errorf("error closing file: %v", cerr)
		}
	}()
	// Copy in 1 KiB chunks so the running total is checked as data is written.
	for {
		n, copyErr := io.CopyN(newFile, tr, 1024)
		*totalSize += n
		if *totalSize > maxBytes {
			return fmt.Errorf("extracted data exceeds allowed limit of %d bytes", maxBytes)
		}
		if copyErr != nil {
			if copyErr == io.EOF {
				return nil
			}
			return copyErr
		}
	}
}
// CheckWSL reports whether the process appears to run under Windows
// Subsystem for Linux. It checks whether any line of /proc/version, read
// through afs, contains "Microsoft" or "WSL". Open and scan errors are
// returned.
func CheckWSL(afs afero.Afero) (bool, error) {
	f, openErr := afs.Open("/proc/version")
	if openErr != nil {
		return false, openErr
	}
	defer f.Close()

	sc := bufio.NewScanner(f)
	for sc.Scan() {
		text := sc.Text()
		for _, marker := range []string{"Microsoft", "WSL"} {
			if strings.Contains(text, marker) {
				return true, nil
			}
		}
	}
	if scanErr := sc.Err(); scanErr != nil {
		return false, scanErr
	}
	return false, nil
}
// SaveServiceInfo updates service info into SP's DMS for claim Reward by SP
// user.
//
// It looks up the local row with the same tx_hash as cpService and copies
// that row's ID and CreatedAt into cpService. It then applies cpService to
// the row with Updates. NOTE(review): GORM's Find does not report "no rows"
// as an error, so a missing SP row is not detected here; confirm intended.
func SaveServiceInfo(cpService types.Services) error {
	var existing types.Services
	lookup := db.DB.Model(&types.Services{}).Where("tx_hash = ?", cpService.TxHash).Find(&existing)
	if lookup.Error != nil {
		return fmt.Errorf("unable to find service on SP side: %v", lookup.Error)
	}

	cpService.ID, cpService.CreatedAt = existing.ID, existing.CreatedAt

	update := db.DB.Model(&types.Services{}).Where("tx_hash = ?", cpService.TxHash).Updates(&cpService)
	if update.Error != nil {
		return fmt.Errorf("unable to update service info on SP side: %v", update.Error.Error())
	}
	return nil
}
// RandomBool returns true or false with equal probability, using crypto/rand.
// The error from the random source is returned, if any.
func RandomBool() (bool, error) {
	bit, err := rand.Int(rand.Reader, big.NewInt(2))
	if err != nil {
		return false, err
	}
	// bit is 0 or 1.
	return bit.Int64() != 0, nil
}
// IsExecutorType reports whether v holds a types.ExecutorType.
func IsExecutorType(v interface{}) bool {
	switch v.(type) {
	case types.ExecutorType:
		return true
	}
	return false
}

// IsGPUVendor reports whether v holds a types.GPUVendor.
func IsGPUVendor(v interface{}) bool {
	switch v.(type) {
	case types.GPUVendor:
		return true
	}
	return false
}

// IsJobType reports whether v holds a types.JobType.
func IsJobType(v interface{}) bool {
	switch v.(type) {
	case types.JobType:
		return true
	}
	return false
}

// IsJobTypes reports whether v holds a types.JobTypes.
func IsJobTypes(v interface{}) bool {
	switch v.(type) {
	case types.JobTypes:
		return true
	}
	return false
}

// IsExecutor reports whether v satisfies the type assertion v.(types.Executor).
func IsExecutor(v interface{}) bool {
	switch v.(type) {
	case types.Executor:
		return true
	}
	return false
}
// IsStrictlyContained reports whether every element of rightSlice also
// appears in leftSlice, compared with ==. An empty rightSlice yields false.
func IsStrictlyContained(leftSlice, rightSlice []interface{}) bool {
	if len(rightSlice) == 0 {
		return false
	}
	for _, want := range rightSlice {
		if !slices.Contains(leftSlice, want) {
			return false
		}
	}
	return true
}
// IsStrictlyContainedInt reports whether every int in rightSlice also appears
// in leftSlice. An empty rightSlice yields false.
func IsStrictlyContainedInt(leftSlice, rightSlice []int) bool {
	if len(rightSlice) == 0 {
		return false
	}
	for _, want := range rightSlice {
		if !slices.Contains(leftSlice, want) {
			return false
		}
	}
	return true
}
// NoIntersectionSlices reports whether slice1 and slice2 have no element in
// common, compared with ==. It returns false as soon as any element of
// slice1 is found in slice2. Previously only the last element of slice1
// decided the result, so an earlier shared element was missed. An empty
// slice1 yields false.
func NoIntersectionSlices(slice1, slice2 []interface{}) bool {
	if len(slice1) == 0 {
		return false
	}
	for _, subElement := range slice1 {
		if slices.Contains(slice2, subElement) {
			return false
		}
	}
	return true
}
// IntersectionSlices returns the elements present in both slices, in the
// order they first appear in slice2, each at most once. Elements are used as
// map keys, so they must be comparable. The result is never nil.
func IntersectionSlices(slice1, slice2 []interface{}) []interface{} {
	pending := make(map[interface{}]struct{}, len(slice1))
	for _, item := range slice1 {
		pending[item] = struct{}{}
	}
	common := []interface{}{}
	for _, item := range slice2 {
		if _, ok := pending[item]; !ok {
			continue
		}
		common = append(common, item)
		// Drop it so later duplicates in slice2 are not added again.
		delete(pending, item)
	}
	return common
}
// IsSameShallowType reports whether a and b have the same dynamic type,
// compared via reflect.TypeOf. Values and contents are not inspected.
func IsSameShallowType(a, b interface{}) bool {
	return reflect.TypeOf(a) == reflect.TypeOf(b)
}
// ConvertTypedSliceToUntypedSlice copies the elements of any slice into a
// new []interface{}. Non-slice input, including nil, yields nil; an empty
// slice yields an empty, non-nil result.
func ConvertTypedSliceToUntypedSlice(typedSlice interface{}) []interface{} {
	rv := reflect.ValueOf(typedSlice)
	if rv.Kind() != reflect.Slice {
		return nil
	}
	length := rv.Len()
	out := make([]interface{}, 0, length)
	for i := 0; i < length; i++ {
		out = append(out, rv.Index(i).Interface())
	}
	return out
}
package validate
import (
"strings"
)
// IsBlank reports whether s is empty or made up only of white space, as
// trimmed by strings.TrimSpace.
func IsBlank(s string) bool {
	return strings.TrimSpace(s) == ""
}
// IsNotBlank reports whether s contains at least one character that
// strings.TrimSpace does not strip, i.e. the opposite of IsBlank.
func IsNotBlank(s string) bool {
	return strings.TrimSpace(s) != ""
}
// IsLiteral reports whether s holds a value whose dynamic type is exactly
// string. Named string types (e.g. `type Name string`) and nil do not count.
func IsLiteral(s interface{}) bool {
	_, ok := s.(string)
	return ok
}