// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package actor
import (
"context"
"encoding/json"
"errors"
"fmt"
"sync"
bt "gitlab.com/nunet/device-management-service/internal/background_tasks"
"gitlab.com/nunet/device-management-service/network"
"gitlab.com/nunet/device-management-service/types"
)
// BasicActor is an Actor implementation that combines a Dispatch kernel with
// a network transport, a security context and a rate limiter.
type BasicActor struct {
	dispatch      *Dispatch         // behavior kernel; Receive hands envelopes to it
	scheduler     *bt.Scheduler     // background task scheduler, started in Start
	registry      *registry         // actor registry created in New
	network       network.Network   // transport used to send, publish and subscribe
	security      SecurityContext   // provides nonces, signing and verification
	limiter       RateLimiter       // consulted before accepting incoming messages
	params        BasicActorParams  // construction parameters
	self          Handle            // this actor's own handle
	mx            sync.Mutex        // guards subscriptions
	subscriptions map[string]uint64 // topic -> subscription ID returned by network.Subscribe
}
// BasicActorParams carries construction parameters for a BasicActor; it
// currently declares no fields.
type BasicActorParams struct{}

// Compile-time check that *BasicActor satisfies the Actor interface.
var _ Actor = (*BasicActor)(nil)
// New creates a new basic actor.
//
// scheduler, net and security are mandatory; an error is returned when any of
// them is nil. A nil limiter is replaced by NoRateLimiter{} so that the actor
// never dereferences a nil RateLimiter when messages arrive. The limiter is
// also installed in the dispatch kernel ahead of the caller supplied options,
// so an explicit WithRateLimiter in opt still takes precedence there.
func New(scheduler *bt.Scheduler, net network.Network, security *BasicSecurityContext, limiter RateLimiter, params BasicActorParams, self Handle, opt ...DispatchOption) (*BasicActor, error) {
	switch {
	case scheduler == nil:
		return nil, errors.New("scheduler is nil")
	case net == nil:
		return nil, errors.New("network is nil")
	case security == nil:
		return nil, errors.New("security is nil")
	}

	// Without this default a nil limiter would override the dispatch default
	// and panic on the first call to limiter.Allow.
	if limiter == nil {
		limiter = NoRateLimiter{}
	}

	dispatchOptions := make([]DispatchOption, 0, len(opt)+1)
	dispatchOptions = append(dispatchOptions, WithRateLimiter(limiter))
	dispatchOptions = append(dispatchOptions, opt...)

	return &BasicActor{
		dispatch:      NewDispatch(security, dispatchOptions...),
		scheduler:     scheduler,
		registry:      newRegistry(),
		network:       net,
		security:      security,
		limiter:       limiter,
		params:        params,
		self:          self,
		subscriptions: make(map[string]uint64),
	}, nil
}
// Start registers the actor's message handler with the network under the
// protocol derived from its inbox address and then starts the dispatch kernel
// and the scheduler. Nothing is started if handler registration fails.
func (a *BasicActor) Start() error {
	protocol := fmt.Sprintf("actor/%s/messages/0.0.1", a.self.Address.InboxAddress)

	err := a.network.HandleMessage(protocol, a.handleMessage)
	if err != nil {
		return fmt.Errorf("starting actor: %s: %w", a.self.ID, err)
	}

	// internal goroutines
	a.dispatch.Start()
	a.scheduler.Start()

	return nil
}
// handleMessage is the callback registered with the network in Start. It
// decodes the envelope, drops it when it is addressed to another actor or
// rejected by the rate limiter, and otherwise passes it to Receive.
func (a *BasicActor) handleMessage(data []byte) {
	var env Envelope
	if err := json.Unmarshal(data, &env); err != nil {
		log.Debugf("error unmarshaling message: %s", err)
		return
	}

	switch {
	case !a.self.ID.Equal(env.To.ID):
		log.Warnf("message is not for ourselves: %s %s", a.self.ID, env.To.ID)
	case !a.limiter.Allow(env):
		log.Warnf("incoming message invoking %s not allowed by limiter", env.Behavior)
	default:
		// the error from Receive is deliberately ignored here
		_ = a.Receive(env)
	}
}
// Context returns the context owned by the actor's dispatch kernel.
func (a *BasicActor) Context() context.Context {
	ctx := a.dispatch.Context()
	return ctx
}
// Handle returns the handle this actor was constructed with.
func (a *BasicActor) Handle() Handle {
	self := a.self
	return self
}
// Security returns the security context used by this actor.
func (a *BasicActor) Security() SecurityContext {
	sctx := a.security
	return sctx
}
// AddBehavior registers continuation under the given behavior name by
// delegating to the dispatch kernel.
func (a *BasicActor) AddBehavior(behavior string, continuation Behavior, opt ...BehaviorOption) error {
	err := a.dispatch.AddBehavior(behavior, continuation, opt...)
	return err
}
// RemoveBehavior unregisters the named behavior from the dispatch kernel.
func (a *BasicActor) RemoveBehavior(behavior string) {
	kernel := a.dispatch
	kernel.RemoveBehavior(behavior)
}
// Receive hands msg to the dispatch kernel when it is addressed to this actor
// or is a broadcast message; any other message is rejected with an error
// wrapping ErrInvalidMessage.
func (a *BasicActor) Receive(msg Envelope) error {
	addressedToUs := a.self.ID.Equal(msg.To.ID)
	if !addressedToUs && !msg.IsBroadcast() {
		return fmt.Errorf("bad receiver: %w", ErrInvalidMessage)
	}
	return a.dispatch.Receive(msg)
}
// Send delivers msg to its destination.
//
// Messages addressed to this actor are short-circuited into Receive. Unsigned
// messages get a nonce (when none is set) and are signed through the security
// context, which is asked to provide the invoked behavior capability and, when
// a reply behavior is set, to delegate it. The envelope is then JSON encoded
// and sent over the network to the destination host and inbox.
func (a *BasicActor) Send(msg Envelope) error {
	// local delivery
	if msg.To.ID.Equal(a.self.ID) {
		return a.Receive(msg)
	}

	if msg.Signature == nil {
		if msg.Nonce == 0 {
			msg.Nonce = a.security.Nonce()
		}

		invoke := []Capability{Capability(msg.Behavior)}
		var delegate []Capability
		if replyTo := msg.Options.ReplyTo; replyTo != "" {
			delegate = []Capability{Capability(replyTo)}
		}

		if err := a.security.Provide(&msg, invoke, delegate); err != nil {
			return fmt.Errorf("providing behavior capability for %s: %w", msg.Behavior, err)
		}
	}

	payload, err := json.Marshal(msg)
	if err != nil {
		return fmt.Errorf("marshaling message: %w", err)
	}

	protocol := fmt.Sprintf("actor/%s/messages/0.0.1", msg.To.Address.InboxAddress)
	wire := types.MessageEnvelope{
		Type: types.MessageType(protocol),
		Data: payload,
	}

	if err := a.network.SendMessage(a.Context(), msg.To.Address.HostID, wire, msg.Expiry()); err != nil {
		return fmt.Errorf("sending message to %s: %w", msg.To.ID, err)
	}
	return nil
}
// Invoke sends msg and returns a channel on which the reply is delivered.
//
// When the message has no reply behavior, one is generated from a fresh
// nonce. A one-shot behavior with the message's expiry is registered for the
// reply; it forwards the reply into the returned channel and closes it. If
// sending fails the reply behavior is removed again.
func (a *BasicActor) Invoke(msg Envelope) (<-chan Envelope, error) {
	if msg.Options.ReplyTo == "" {
		msg.Options.ReplyTo = fmt.Sprintf("/dms/actor/replyto/%d", a.security.Nonce())
	}
	replyTo := msg.Options.ReplyTo

	// buffered so the continuation never blocks on a slow reader
	replies := make(chan Envelope, 1)
	deliver := func(reply Envelope) {
		replies <- reply
		close(replies)
	}

	err := a.dispatch.AddBehavior(
		replyTo,
		deliver,
		WithBehaviorExpiry(msg.Options.Expire),
		WithBehaviorOneShot(true),
	)
	if err != nil {
		return nil, fmt.Errorf("adding reply behavior: %w", err)
	}

	if err := a.Send(msg); err != nil {
		a.dispatch.RemoveBehavior(replyTo)
		return nil, fmt.Errorf("sending message: %w", err)
	}

	return replies, nil
}
// Publish broadcasts msg on its topic. Non-broadcast messages are rejected
// with ErrInvalidMessage. Unsigned messages get a nonce (when none is set)
// and are signed with the broadcast capability for the invoked behavior
// before being JSON encoded and published on the network.
func (a *BasicActor) Publish(msg Envelope) error {
	if !msg.IsBroadcast() {
		return ErrInvalidMessage
	}

	topic := msg.Options.Topic

	if msg.Signature == nil {
		if msg.Nonce == 0 {
			msg.Nonce = a.security.Nonce()
		}

		caps := []Capability{Capability(msg.Behavior)}
		if err := a.security.ProvideBroadcast(&msg, topic, caps); err != nil {
			return fmt.Errorf("providing behavior capability for %s: %w", msg.Behavior, err)
		}
	}

	payload, err := json.Marshal(msg)
	if err != nil {
		return fmt.Errorf("marshaling message: %w", err)
	}

	err = a.network.Publish(a.Context(), topic, payload)
	if err != nil {
		return fmt.Errorf("publishing message: %w", err)
	}
	return nil
}
// Subscribe subscribes the actor to a broadcast topic. Subscribing to an
// already subscribed topic is a no-op. After the network subscription is
// established each setup function is run with the topic; if one fails the
// subscription is undone and the error is returned.
func (a *BasicActor) Subscribe(topic string, setup ...BroadcastSetup) error {
	a.mx.Lock()
	defer a.mx.Unlock()

	if _, subscribed := a.subscriptions[topic]; subscribed {
		return nil
	}

	// bind the topic so the validator can check it against the message
	validator := func(data []byte, validatorData interface{}) (network.ValidationResult, interface{}) {
		return a.validateBroadcast(topic, data, validatorData)
	}

	subID, err := a.network.Subscribe(a.Context(), topic, a.handleBroadcast, validator)
	if err != nil {
		return fmt.Errorf("subscribe: %w", err)
	}

	for _, configure := range setup {
		if err := configure(topic); err != nil {
			// best effort rollback; the setup error is the one reported
			_ = a.network.Unsubscribe(topic, subID)
			return fmt.Errorf("setup broadcast topic: %w", err)
		}
	}

	a.subscriptions[topic] = subID
	return nil
}
// validateBroadcast is the pubsub validator installed by Subscribe.
//
// Non-nil validatorData is taken as the result of an earlier validation: it
// is accepted as is when it holds an Envelope and rejected otherwise.
// Otherwise the payload is decoded and rejected unless it is a broadcast on
// the expected topic with a valid signature; expired or rate limited messages
// are ignored. On acceptance the decoded envelope is returned as validator
// data.
func (a *BasicActor) validateBroadcast(topic string, data []byte, validatorData interface{}) (network.ValidationResult, interface{}) {
	if validatorData != nil {
		if _, isEnvelope := validatorData.(Envelope); isEnvelope {
			// we have already validated the message, just short-circuit
			return network.ValidationAccept, validatorData
		}
		log.Warnf("bogus pubsub validation data: %v", validatorData)
		return network.ValidationReject, nil
	}

	var msg Envelope
	if err := json.Unmarshal(data, &msg); err != nil {
		return network.ValidationReject, nil
	}

	switch {
	case !msg.IsBroadcast(), msg.Options.Topic != topic:
		return network.ValidationReject, nil
	case msg.Expired():
		return network.ValidationIgnore, nil
	}

	if err := a.security.Verify(msg); err != nil {
		return network.ValidationReject, nil
	}

	if !a.limiter.Allow(msg) {
		log.Warnf("incoming broadcast message in %s not allowed by limiter", topic)
		return network.ValidationIgnore, nil
	}

	return network.ValidationAccept, msg
}
// handleBroadcast is the subscription callback installed by Subscribe. It
// decodes the envelope, skips messages sent by this actor and passes
// everything else to Receive, logging any error.
func (a *BasicActor) handleBroadcast(data []byte) {
	var env Envelope
	if err := json.Unmarshal(data, &env); err != nil {
		log.Debugf("error unmarshaling broadcast message: %s", err)
		return
	}

	// don't receive message from self
	if fromSelf := env.From.Equal(a.Handle()); fromSelf {
		return
	}

	err := a.Receive(env)
	if err != nil {
		log.Warnf("error receiving broadcast message: %s", err)
	}
}
// Stop cancels the dispatch kernel and unsubscribes from every topic the
// actor is subscribed to. Unsubscribe failures are logged and do not abort
// the shutdown; Stop always returns nil.
//
// The subscription map is accessed under a.mx, the same mutex Subscribe
// holds, so Stop does not race with a concurrent Subscribe. Entries are
// removed as they are processed, which makes a repeated Stop a no-op for
// subscriptions.
func (a *BasicActor) Stop() error {
	a.dispatch.close()

	a.mx.Lock()
	defer a.mx.Unlock()

	for topic, subID := range a.subscriptions {
		if err := a.network.Unsubscribe(topic, subID); err != nil {
			log.Debugf("error unsubscribing from %s: %s", topic, err)
		}
		// deleting the current key while ranging over a map is safe in Go
		delete(a.subscriptions, topic)
	}

	return nil
}
// Limiter returns the rate limiter this actor was constructed with.
func (a *BasicActor) Limiter() RateLimiter {
	limiter := a.limiter
	return limiter
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package actor
import (
"context"
"fmt"
"sync"
"time"
)
// Defaults applied by NewDispatch before caller options are evaluated.
var (
	// DefaultDispatchGCInterval is the default period between sweeps that
	// remove expired behaviors.
	DefaultDispatchGCInterval = 120 * time.Second
	// DefaultDispatchWorkers is the default number of recv goroutines
	// started by Dispatch.Start.
	DefaultDispatchWorkers = 1
)
// Dispatch provides a reaction kernel with multithreaded dispatch and oneshot
// continuations.
type Dispatch struct {
	ctx       context.Context           // cancelled by close; stops all kernel goroutines
	close     func()                    // cancel function for ctx
	sctx      SecurityContext           // verifies messages and checks capabilities
	mx        sync.Mutex                // guards behaviors and started
	q         chan Envelope             // incoming message queue
	vq        chan Envelope             // verified message queue
	behaviors map[string]*BehaviorState // registered behaviors keyed by behavior name
	started   bool                      // set once Start has launched the goroutines
	options   DispatchOptions           // effective options after applying DispatchOption values
}
// DispatchOptions holds the tunables of a Dispatch kernel.
type DispatchOptions struct {
	Limiter    RateLimiter   // acquired before and released after each continuation runs
	GCInterval time.Duration // period of the expired-behavior sweep
	Workers    int           // number of recv goroutines started by Start
}
// BehaviorState is the registration record of a single behavior.
type BehaviorState struct {
	cont Behavior        // continuation invoked for matching messages
	opt  BehaviorOptions // options the behavior was registered with
}
// DispatchOption mutates the DispatchOptions of a kernel under construction.
type DispatchOption func(o *DispatchOptions)
// WithDispatchWorkers sets the number of recv goroutines the kernel starts.
func WithDispatchWorkers(count int) DispatchOption {
	setWorkers := func(opts *DispatchOptions) {
		opts.Workers = count
	}
	return setWorkers
}
// WithDispatchGCInterval sets the period of the expired-behavior sweep.
func WithDispatchGCInterval(dt time.Duration) DispatchOption {
	setInterval := func(opts *DispatchOptions) {
		opts.GCInterval = dt
	}
	return setInterval
}
// WithRateLimiter sets the rate limiter used by the kernel.
func WithRateLimiter(limiter RateLimiter) DispatchOption {
	setLimiter := func(opts *DispatchOptions) {
		opts.Limiter = limiter
	}
	return setLimiter
}
// NewDispatch creates a dispatch kernel bound to sctx. The kernel starts from
// the package defaults (GC interval, worker count, no rate limiting) and then
// applies opt in order. It owns a cancellable background context and
// unbuffered incoming and verified queues. No goroutines run until Start.
func NewDispatch(sctx SecurityContext, opt ...DispatchOption) *Dispatch {
	options := DispatchOptions{
		GCInterval: DefaultDispatchGCInterval,
		Workers:    DefaultDispatchWorkers,
		Limiter:    NoRateLimiter{},
	}
	for _, apply := range opt {
		apply(&options)
	}

	ctx, cancel := context.WithCancel(context.Background())

	return &Dispatch{
		sctx:      sctx,
		ctx:       ctx,
		close:     cancel,
		q:         make(chan Envelope),
		vq:        make(chan Envelope),
		behaviors: make(map[string]*BehaviorState),
		options:   options,
	}
}
// Start launches the kernel goroutines: the configured number of recv
// workers, the dispatcher and the garbage collector. Calling Start again
// after the first call has no effect.
func (k *Dispatch) Start() {
	k.mx.Lock()
	defer k.mx.Unlock()

	if k.started {
		return
	}

	for worker := 0; worker < k.options.Workers; worker++ {
		go k.recv()
	}
	go k.dispatch()
	go k.gc()

	k.started = true
}
// AddBehavior registers continuation under the given behavior name,
// replacing any previous registration with that name. By default the
// required capability is the behavior name itself; opt may override this and
// other options. The behavior is not registered if an option fails.
func (k *Dispatch) AddBehavior(behavior string, continuation Behavior, opt ...BehaviorOption) error {
	options := BehaviorOptions{
		Capability: []Capability{Capability(behavior)},
	}
	state := &BehaviorState{cont: continuation, opt: options}

	for _, apply := range opt {
		err := apply(&state.opt)
		if err != nil {
			return fmt.Errorf("adding behavior: %w", err)
		}
	}

	k.mx.Lock()
	k.behaviors[behavior] = state
	k.mx.Unlock()

	return nil
}
// RemoveBehavior drops the registration for the named behavior, if any.
func (k *Dispatch) RemoveBehavior(behavior string) {
	k.mx.Lock()
	delete(k.behaviors, behavior)
	k.mx.Unlock()
}
// Receive enqueues msg on the incoming queue, blocking until a worker takes
// it. If the kernel context is cancelled first, the context error is
// returned instead.
func (k *Dispatch) Receive(msg Envelope) error {
	done := k.ctx.Done()

	select {
	case <-done:
		return k.ctx.Err()
	case k.q <- msg:
		return nil
	}
}
// Context returns the kernel's context, which is cancelled on close.
func (k *Dispatch) Context() context.Context {
	ctx := k.ctx
	return ctx
}
// recv is the worker loop that verifies incoming messages. Each message taken
// from the incoming queue is verified with the security context; messages
// that fail verification are logged and dropped, the rest are forwarded to
// the verified queue. The loop exits when the kernel context is cancelled.
func (k *Dispatch) recv() {
	for {
		select {
		case msg := <-k.q:
			if err := k.sctx.Verify(msg); err != nil {
				log.Debugf("failed to verify message from %s: %s", msg.From, err)
				continue
			}

			// The verified queue is unbuffered and its only reader exits on
			// cancellation, so the hand-off must also watch the context or
			// this goroutine would block forever after shutdown.
			select {
			case k.vq <- msg:
			case <-k.ctx.Done():
				return
			}
		case <-k.ctx.Done():
			return
		}
	}
}
// dispatch is the loop that routes verified messages to their behaviors.
//
// For each message it looks up the behavior under the kernel mutex, drops
// the message when the behavior is unknown or expired (removing expired
// behaviors), checks the required capabilities through the security context,
// and removes one-shot behaviors. After releasing the mutex it acquires the
// rate limiter and runs the continuation in its own goroutine, releasing the
// limiter when the continuation returns. The loop exits when the kernel
// context is cancelled.
func (k *Dispatch) dispatch() {
	for {
		select {
		case msg := <-k.vq:
			k.mx.Lock()
			b, ok := k.behaviors[msg.Behavior]
			if !ok {
				k.mx.Unlock()
				log.Debugf("unknown behavior %s", msg.Behavior)
				continue
			}
			if b.Expired(time.Now()) {
				// lazily collect the expired behavior instead of waiting for gc
				delete(k.behaviors, msg.Behavior)
				k.mx.Unlock()
				log.Debugf("expired behavior %s", msg.Behavior)
				continue
			}
			// capability checks run while the mutex is still held
			if msg.IsBroadcast() {
				if err := k.sctx.RequireBroadcast(msg, b.opt.Topic, b.opt.Capability); err != nil {
					k.mx.Unlock()
					log.Warnf("broadcast message from %s does not have the required capability %s %s: %s", msg.From, b.opt.Capability, string(msg.Capability), err)
					continue
				}
			} else if err := k.sctx.Require(msg, b.opt.Capability); err != nil {
				k.mx.Unlock()
				log.Warnf("message from %s does not have the required capability %s %s: %s", msg.From, b.opt.Capability, string(msg.Capability), err)
				continue
			}
			// a one-shot behavior is removed before its continuation runs
			if b.opt.OneShot {
				delete(k.behaviors, msg.Behavior)
			}
			k.mx.Unlock()
			if err := k.options.Limiter.Acquire(msg); err != nil {
				// the message is not going to be processed; discard it
				k.sctx.Discard(msg)
				log.Debugf("limiter rejected message from %s: %s", msg.From, err)
				continue
			}
			// give the continuation a way to discard the message itself
			msg.Discard = func() {
				k.sctx.Discard(msg)
			}
			log.Debugf("dispatching message from %s to %s", msg.From, msg.Behavior)
			// run the continuation off the dispatch loop so it cannot stall it
			go func() {
				defer k.options.Limiter.Release(msg)
				b.cont(msg)
			}()
		case <-k.ctx.Done():
			return
		}
	}
}
// gc periodically removes expired behaviors. It wakes up every GCInterval,
// sweeps the behavior table under the kernel mutex, and exits when the kernel
// context is cancelled.
func (k *Dispatch) gc() {
	sweep := time.NewTicker(k.options.GCInterval)
	defer sweep.Stop()

	for {
		select {
		case <-k.ctx.Done():
			return

		case <-sweep.C:
			k.mx.Lock()
			cutoff := time.Now()
			for name, state := range k.behaviors {
				if !state.Expired(cutoff) {
					continue
				}
				delete(k.behaviors, name)
			}
			k.mx.Unlock()
		}
	}
}
// Expired reports whether the behavior's expiry, a Unix time in nanoseconds,
// lies before now. A zero expiry means the behavior never expires.
func (b *BehaviorState) Expired(now time.Time) bool {
	deadline := b.opt.Expire
	if deadline == 0 {
		return false
	}
	return uint64(now.UnixNano()) > deadline
}
// WithBehaviorExpiry sets the behavior's expiry, compared against Unix
// nanoseconds by BehaviorState.Expired.
func WithBehaviorExpiry(expire uint64) BehaviorOption {
	return func(options *BehaviorOptions) error {
		options.Expire = expire
		return nil
	}
}
// WithBehaviorCapability replaces the capabilities required to invoke the
// behavior.
func WithBehaviorCapability(require ...Capability) BehaviorOption {
	return func(options *BehaviorOptions) error {
		options.Capability = require
		return nil
	}
}
// WithBehaviorOneShot marks the behavior as one-shot; the dispatcher removes
// a one-shot behavior when it dispatches a message to it.
func WithBehaviorOneShot(oneShot bool) BehaviorOption {
	return func(options *BehaviorOptions) error {
		options.OneShot = oneShot
		return nil
	}
}
// WithBehaviorTopic sets the broadcast topic the behavior is bound to.
func WithBehaviorTopic(topic string) BehaviorOption {
	return func(options *BehaviorOptions) error {
		options.Topic = topic
		return nil
	}
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package actor
import (
"fmt"
"github.com/libp2p/go-libp2p/core/peer"
"gitlab.com/nunet/device-management-service/lib/crypto"
"gitlab.com/nunet/device-management-service/lib/did"
)
// Empty reports whether the handle's ID, DID and address are all empty.
func (h *Handle) Empty() bool {
	if !h.ID.Empty() {
		return false
	}
	if !h.DID.Empty() {
		return false
	}
	return h.Address.Empty()
}
// String formats the handle as "<id>[<did>]@<address>", where <id> is the DID
// form of the handle's ID, or empty when the ID cannot be converted.
func (h *Handle) String() string {
	idStr := ""
	if idDID, err := did.FromID(h.ID); err == nil {
		idStr = idDID.String()
	}
	return fmt.Sprintf("%s[%s]@%s", idStr, h.DID, h.Address)
}
// Equal reports whether both handles have equal IDs, equal DIDs and the same
// host and inbox address.
func (h Handle) Equal(other Handle) bool {
	return h.ID.Equal(other.ID) &&
		h.DID.Equal(other.DID) &&
		h.Address.HostID == other.Address.HostID &&
		h.Address.InboxAddress == other.Address.InboxAddress
}
// HandleFromString is not implemented yet; it always returns an empty handle
// and ErrTODO.
func HandleFromString(_ string) (Handle, error) {
	// TODO
	var none Handle
	return none, ErrTODO
}
// Empty reports whether both the host ID and the inbox address are empty.
func (a *Address) Empty() bool {
	if a.HostID != "" {
		return false
	}
	return a.InboxAddress == ""
}
// String formats the address as "<host id>:<inbox address>".
func (a *Address) String() string {
	host, inbox := a.HostID, a.InboxAddress
	return host + ":" + inbox
}
// AddressFromString is not implemented yet; it always returns an empty
// address and ErrTODO.
func AddressFromString(_ string) (Address, error) {
	// TODO
	var none Address
	return none, ErrTODO
}
// HandleFromPeerID builds the handle of the "root" inbox on the peer
// identified by dest, an encoded peer ID. The peer's public key is extracted
// from the peer ID and used to derive the actor ID and DID. An error is
// returned when dest cannot be decoded, the key cannot be extracted or
// converted, or the key type is not allowed.
func HandleFromPeerID(dest string) (Handle, error) {
	host, err := peer.Decode(dest)
	if err != nil {
		return Handle{}, err
	}

	key, err := host.ExtractPublicKey()
	if err != nil {
		return Handle{}, err
	}

	if keyType := key.Type(); !crypto.AllowedKey(int(keyType)) {
		return Handle{}, fmt.Errorf("unexpected key type: %d", key.Type())
	}

	id, err := crypto.IDFromPublicKey(key)
	if err != nil {
		return Handle{}, err
	}

	return Handle{
		ID:  id,
		DID: did.FromPublicKey(key),
		Address: Address{
			HostID:       host.String(),
			InboxAddress: "root",
		},
	}, nil
}
// HandleFromDID builds the handle of the "root" inbox for the DID given in
// string form. The DID's public key is used to derive both the actor ID and
// the peer ID that serves as host ID. Any parsing or derivation error is
// returned unchanged.
func HandleFromDID(dest string) (Handle, error) {
	parsed, err := did.FromString(dest)
	if err != nil {
		return Handle{}, err
	}

	key, err := did.PublicKeyFromDID(parsed)
	if err != nil {
		return Handle{}, err
	}

	id, err := crypto.IDFromPublicKey(key)
	if err != nil {
		return Handle{}, err
	}

	host, err := peer.IDFromPublicKey(key)
	if err != nil {
		return Handle{}, err
	}

	return Handle{
		ID:  id,
		DID: parsed,
		Address: Address{
			HostID:       host.String(),
			InboxAddress: "root",
		},
	}, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package actor
import (
"strings"
"sync"
)
// NoRateLimiter is the null limiter: it allows every message and never
// limits.
type NoRateLimiter struct{}

// Compile-time check that NoRateLimiter satisfies RateLimiter.
var _ RateLimiter = NoRateLimiter{}
// BasicRateLimiter is a RateLimiter that bounds the number of concurrently
// acquired public and broadcast messages according to a RateLimiterConfig.
type BasicRateLimiter struct {
	cfg             RateLimiterConfig // current limits
	mx              sync.Mutex        // guards cfg and all counters
	activeBroadcast int               // broadcast messages currently acquired
	activeTopics    map[string]int    // broadcast messages currently acquired, per topic
	activePublic    int               // public-behavior messages currently acquired
}

// Compile-time check that *BasicRateLimiter satisfies RateLimiter.
var _ RateLimiter = (*BasicRateLimiter)(nil)
// implementation

// Allow always permits the message.
func (l NoRateLimiter) Allow(_ Envelope) bool {
	return true
}

// Acquire always succeeds.
func (l NoRateLimiter) Acquire(_ Envelope) error {
	return nil
}

// Release does nothing.
func (l NoRateLimiter) Release(_ Envelope) {
}

// Config returns the zero RateLimiterConfig.
func (l NoRateLimiter) Config() RateLimiterConfig {
	var none RateLimiterConfig
	return none
}

// SetConfig ignores the given configuration.
func (l NoRateLimiter) SetConfig(_ RateLimiterConfig) {
}
// DefaultRateLimiterConfig returns the default limits. The acquire limits sit
// slightly above the corresponding allow limits.
func DefaultRateLimiterConfig() RateLimiterConfig {
	var cfg RateLimiterConfig

	cfg.PublicLimitAllow = 4096
	cfg.PublicLimitAcquire = 4112

	cfg.BroadcastLimitAllow = 1024
	cfg.BroadcastLimitAcquire = 1040

	cfg.TopicDefaultLimit = 128

	return cfg
}
// Valid reports whether the configuration is usable: all allow limits and the
// default topic limit must be positive, and each acquire limit must be at
// least as large as its allow limit.
func (cfg *RateLimiterConfig) Valid() bool {
	publicOK := cfg.PublicLimitAllow > 0 &&
		cfg.PublicLimitAcquire >= cfg.PublicLimitAllow
	broadcastOK := cfg.BroadcastLimitAllow > 0 &&
		cfg.BroadcastLimitAcquire >= cfg.BroadcastLimitAllow
	topicOK := cfg.TopicDefaultLimit > 0

	return publicOK && broadcastOK && topicOK
}
// NewRateLimiter creates a BasicRateLimiter with the given configuration and
// no active messages. The configuration is not validated here.
func NewRateLimiter(cfg RateLimiterConfig) RateLimiter {
	limiter := &BasicRateLimiter{
		activeTopics: make(map[string]int),
	}
	limiter.cfg = cfg
	return limiter
}
// Allow reports whether msg may be accepted under the current load.
// Broadcast and public-behavior messages are checked against their limits;
// all other messages are always allowed. Allow does not change any counter.
func (l *BasicRateLimiter) Allow(msg Envelope) bool {
	switch {
	case msg.IsBroadcast():
		return l.allowBroadcast(msg)
	case isPublicBehavior(msg):
		return l.allowPublic()
	default:
		return true
	}
}
// allowPublic reports whether the number of active public messages is below
// the public allow limit.
func (l *BasicRateLimiter) allowPublic() bool {
	l.mx.Lock()
	belowLimit := l.activePublic < l.cfg.PublicLimitAllow
	l.mx.Unlock()
	return belowLimit
}
// allowBroadcast reports whether a broadcast message may be accepted: the
// total number of active broadcasts must be below the broadcast allow limit
// and the message's topic must be below its own limit, which is the
// configured per-topic limit or the default topic limit when none is set.
func (l *BasicRateLimiter) allowBroadcast(msg Envelope) bool {
	l.mx.Lock()
	defer l.mx.Unlock()

	if l.activeBroadcast >= l.cfg.BroadcastLimitAllow {
		return false
	}

	topic := msg.Options.Topic
	limit, configured := l.cfg.TopicLimit[topic]
	if !configured {
		limit = l.cfg.TopicDefaultLimit
	}

	return l.activeTopics[topic] < limit
}
// Acquire reserves capacity for processing msg. Broadcast and public-behavior
// messages are counted against their limits and may be refused; all other
// messages are not limited.
func (l *BasicRateLimiter) Acquire(msg Envelope) error {
	switch {
	case msg.IsBroadcast():
		return l.acquireBroadcast(msg)
	case isPublicBehavior(msg):
		return l.acquirePublic()
	default:
		return nil
	}
}
// acquirePublic counts one more active public message, or returns
// ErrRateLimitExceeded when the public acquire limit has been reached.
func (l *BasicRateLimiter) acquirePublic() error {
	l.mx.Lock()
	defer l.mx.Unlock()

	if l.activePublic < l.cfg.PublicLimitAcquire {
		l.activePublic++
		return nil
	}
	return ErrRateLimitExceeded
}
// acquireBroadcast counts one more active broadcast message, globally and for
// the message's topic. It returns ErrRateLimitExceeded when the broadcast
// acquire limit or the topic's limit (the configured per-topic limit, or the
// default topic limit when none is set) has been reached.
func (l *BasicRateLimiter) acquireBroadcast(msg Envelope) error {
	l.mx.Lock()
	defer l.mx.Unlock()

	if l.activeBroadcast >= l.cfg.BroadcastLimitAcquire {
		return ErrRateLimitExceeded
	}

	topic := msg.Options.Topic
	limit, configured := l.cfg.TopicLimit[topic]
	if !configured {
		limit = l.cfg.TopicDefaultLimit
	}

	current := l.activeTopics[topic]
	if current >= limit {
		return ErrRateLimitExceeded
	}

	l.activeTopics[topic] = current + 1
	l.activeBroadcast++
	return nil
}
// Release returns the capacity reserved by Acquire for msg. Messages that are
// neither broadcast nor public are not counted, so nothing is released for
// them.
func (l *BasicRateLimiter) Release(msg Envelope) {
	switch {
	case msg.IsBroadcast():
		l.releaseBroadcast(msg)
	case isPublicBehavior(msg):
		l.releasePublic()
	}
}
// releasePublic counts one fewer active public message.
//
// The counter is never decremented below zero: like releaseBroadcast, which
// ignores topics with no active message, an unmatched release is ignored
// rather than allowed to drive the counter negative and thereby raise the
// effective limit.
func (l *BasicRateLimiter) releasePublic() {
	l.mx.Lock()
	defer l.mx.Unlock()

	if l.activePublic <= 0 {
		return
	}
	l.activePublic--
}
// releaseBroadcast counts one fewer active broadcast message, globally and
// for the message's topic. Topics without any active message are ignored.
// The topic entry is removed once its count is no longer positive.
func (l *BasicRateLimiter) releaseBroadcast(msg Envelope) {
	l.mx.Lock()
	defer l.mx.Unlock()

	topic := msg.Options.Topic
	current, tracked := l.activeTopics[topic]
	if !tracked {
		return
	}

	if remaining := current - 1; remaining > 0 {
		l.activeTopics[topic] = remaining
	} else {
		delete(l.activeTopics, topic)
	}

	l.activeBroadcast--
}
// Config returns the limiter's current configuration.
func (l *BasicRateLimiter) Config() RateLimiterConfig {
	l.mx.Lock()
	current := l.cfg
	l.mx.Unlock()
	return current
}
// SetConfig replaces the limiter's configuration. Active counters are left
// untouched.
func (l *BasicRateLimiter) SetConfig(cfg RateLimiterConfig) {
	l.mx.Lock()
	l.cfg = cfg
	l.mx.Unlock()
}
// isPublicBehavior reports whether the message invokes a behavior under the
// "/public/" prefix.
func isPublicBehavior(msg Envelope) bool {
	const publicPrefix = "/public/"
	return strings.HasPrefix(msg.Behavior, publicPrefix)
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package actor
import (
"encoding/json"
"fmt"
"time"
)
const (
	// heartbeatBehavior is the behavior name used for heartbeat messages.
	heartbeatBehavior = "/dms/actor/heartbeat"
	// defaultMessageTimeout is the expiry applied by Message when no option
	// overrides it.
	defaultMessageTimeout = 30 * time.Second
)

// signaturePrefix is prepended to the encoded envelope to form the data that
// is signed and verified (see SignatureData).
var signaturePrefix = []byte("dms:msg:")

// HeartbeatMessage is an empty message type.
type HeartbeatMessage struct{}
// Message constructs a new message envelope from src to dest invoking
// behavior, with payload JSON encoded as the message body, and applies the
// options in order. The envelope expires after defaultMessageTimeout unless
// an option changes the expiry, and starts with a no-op Discard function.
func Message(src Handle, dest Handle, behavior string, payload interface{}, opt ...MessageOption) (Envelope, error) {
	body, err := json.Marshal(payload)
	if err != nil {
		return Envelope{}, fmt.Errorf("marshaling payload: %w", err)
	}

	expire := uint64(time.Now().Add(defaultMessageTimeout).UnixNano())
	env := Envelope{
		From:     src,
		To:       dest,
		Behavior: behavior,
		Message:  body,
		Options:  EnvelopeOptions{Expire: expire},
		Discard:  func() {},
	}

	for _, apply := range opt {
		if err := apply(&env); err != nil {
			return Envelope{}, fmt.Errorf("setting message option: %w", err)
		}
	}

	return env, nil
}
// ReplyTo constructs the reply to msg: an envelope from msg.To to msg.From
// invoking the reply behavior requested by msg and inheriting its expiry,
// which later options may override. It fails with an error wrapping
// ErrInvalidMessage when msg does not name a reply behavior.
func ReplyTo(msg Envelope, payload interface{}, opt ...MessageOption) (Envelope, error) {
	replyBehavior := msg.Options.ReplyTo
	if replyBehavior == "" {
		return Envelope{}, fmt.Errorf("no behavior to reply to: %w", ErrInvalidMessage)
	}

	options := make([]MessageOption, 0, len(opt)+1)
	options = append(options, WithMessageExpiry(msg.Options.Expire))
	options = append(options, opt...)

	return Message(msg.To, msg.From, replyBehavior, payload, options...)
}
// WithMessageSignature assigns a fresh nonce to the envelope and signs it
// with the given security context: broadcast envelopes are signed through
// ProvideBroadcast with the invoke capabilities, all others through Provide
// with the invoke and delegate capabilities. It fails with
// ErrInvalidSecurityContext when the envelope's sender is not the security
// context's identity.
//
// NOTE: This option must be passed last, otherwise the signature will be
// invalidated by further modifications.
//
// NOTE: Signing is implicit in Send.
func WithMessageSignature(sctx SecurityContext, invoke []Capability, delegate []Capability) MessageOption {
	return func(msg *Envelope) error {
		if sender := sctx.ID(); !msg.From.ID.Equal(sender) {
			return ErrInvalidSecurityContext
		}

		msg.Nonce = sctx.Nonce()

		if !msg.IsBroadcast() {
			return sctx.Provide(msg, invoke, delegate)
		}
		return sctx.ProvideBroadcast(msg, msg.Options.Topic, invoke)
	}
}
// WithMessageTimeout sets the message expiration from a timeout relative to
// the moment the option is applied.
//
// NOTE: messages created with Message have an implicit timeout of
// defaultMessageTimeout.
func WithMessageTimeout(timeo time.Duration) MessageOption {
	return func(msg *Envelope) error {
		deadline := time.Now().Add(timeo)
		msg.Options.Expire = uint64(deadline.UnixNano())
		return nil
	}
}
// WithMessageExpiry sets the message expiry to the given value.
//
// NOTE: messages created with Message have an implicit timeout of
// defaultMessageTimeout.
func WithMessageExpiry(expiry uint64) MessageOption {
	return func(env *Envelope) error {
		env.Options.Expire = expiry
		return nil
	}
}
// WithMessageReplyTo sets the behavior the receiver should reply to.
//
// NOTE: ReplyTo is set implicitly in Invoke and the appropriate capability
// tokens are delegated by Provide.
func WithMessageReplyTo(replyto string) MessageOption {
	return func(env *Envelope) error {
		env.Options.ReplyTo = replyto
		return nil
	}
}
// WithMessageTopic sets the broadcast topic of the message.
func WithMessageTopic(topic string) MessageOption {
	return func(env *Envelope) error {
		env.Options.Topic = topic
		return nil
	}
}
// WithMessageSource overrides the sender of the message.
func WithMessageSource(source Handle) MessageOption {
	return func(env *Envelope) error {
		env.From = source
		return nil
	}
}
// SignatureData returns the bytes that are signed and verified for this
// envelope: the signature prefix followed by the JSON encoding of a copy of
// the envelope with its Signature cleared. The envelope itself is not
// modified.
func (msg *Envelope) SignatureData() ([]byte, error) {
	unsigned := *msg
	unsigned.Signature = nil

	encoded, err := json.Marshal(&unsigned)
	if err != nil {
		return nil, fmt.Errorf("signature data: %w", err)
	}

	out := make([]byte, 0, len(signaturePrefix)+len(encoded))
	out = append(out, signaturePrefix...)
	out = append(out, encoded...)
	return out, nil
}
// Expired reports whether the current time, in Unix nanoseconds, is past the
// envelope's expiry.
func (msg *Envelope) Expired() bool {
	now := uint64(time.Now().UnixNano())
	return now > msg.Options.Expire
}
// Expiry converts the envelope's expiry, a Unix time in nanoseconds, to a
// time.Time. The value is split into whole seconds and remaining nanoseconds
// before conversion.
func (msg *Envelope) Expiry() time.Time {
	const nanosPerSecond = uint64(time.Second)

	expire := msg.Options.Expire
	whole := int64(expire / nanosPerSecond)
	frac := int64(expire % nanosPerSecond)

	return time.Unix(whole, frac)
}
// IsBroadcast reports whether the envelope is a broadcast message: it has no
// destination handle and names a topic.
func (msg *Envelope) IsBroadcast() bool {
	if !msg.To.Empty() {
		return false
	}
	return msg.Options.Topic != ""
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package actor
import (
"errors"
"sync"
)
// Info is the registry record of a single actor.
type Info struct {
	Addr     *Handle  // handle of the actor itself
	Parent   *Handle  // handle of the actor's parent
	Children []Handle // handles of the actor's children
}
// Registry keeps track of actors and their parent/child relations.
type Registry interface {
	// Actors returns the registered actors.
	Actors() map[string]Info
	// Add registers actor a with the given parent and children.
	Add(a Handle, parent Handle, children []Handle) error
	// Get returns the record of actor a and whether it is registered.
	Get(a Handle) (Info, bool)
	// SetParent replaces the parent of actor a.
	SetParent(a Handle, parent Handle) error
	// GetParent returns the parent of actor a and whether a is registered.
	GetParent(a Handle) (*Handle, bool)
}
// registry is a mutex-protected, in-memory actor table keyed by inbox
// address.
type registry struct {
	mx     sync.Mutex      // guards actors
	actors map[string]Info // inbox address -> actor record
}
func newRegistry() *registry {
return ®istry{
actors: make(map[string]Info),
}
}
// Actors returns a copy of the actor table. The map itself is fresh, but the
// Info values are copied shallowly.
func (r *registry) Actors() map[string]Info {
	r.mx.Lock()
	defer r.mx.Unlock()

	snapshot := make(map[string]Info, len(r.actors))
	for inbox, info := range r.actors {
		snapshot[inbox] = info
	}
	return snapshot
}
// Add registers actor a under its inbox address with the given parent and
// children. It fails when an actor with the same inbox address is already
// registered. A nil children slice is stored as an empty slice.
func (r *registry) Add(a Handle, parent Handle, children []Handle) error {
	inbox := a.Address.InboxAddress

	r.mx.Lock()
	defer r.mx.Unlock()

	if _, exists := r.actors[inbox]; exists {
		return errors.New("actor already exists")
	}

	if children == nil {
		children = []Handle{}
	}

	// a and parent are this call's own copies, so taking their addresses
	// does not alias the caller's variables
	r.actors[inbox] = Info{
		Addr:     &a,
		Parent:   &parent,
		Children: children,
	}
	return nil
}
// Get returns the record stored under a's inbox address and whether such a
// record exists.
func (r *registry) Get(a Handle) (Info, bool) {
	inbox := a.Address.InboxAddress

	r.mx.Lock()
	defer r.mx.Unlock()

	record, found := r.actors[inbox]
	return record, found
}
// SetParent replaces the parent of the actor stored under a's inbox address.
// It fails when no such actor is registered.
func (r *registry) SetParent(a Handle, parent Handle) error {
	inbox := a.Address.InboxAddress

	r.mx.Lock()
	defer r.mx.Unlock()

	record, found := r.actors[inbox]
	if !found {
		return errors.New("actor not found")
	}

	// Info is stored by value, so the updated copy must be written back
	record.Parent = &parent
	r.actors[inbox] = record
	return nil
}
// GetParent returns the parent of the actor stored under a's inbox address.
// The boolean is false, and the handle nil, when no such actor is registered.
func (r *registry) GetParent(a Handle) (*Handle, bool) {
	inbox := a.Address.InboxAddress

	r.mx.Lock()
	defer r.mx.Unlock()

	if record, found := r.actors[inbox]; found {
		return record.Parent, true
	}
	return nil, false
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package actor
import (
"fmt"
"sync"
"time"
"gitlab.com/nunet/device-management-service/lib/crypto"
"gitlab.com/nunet/device-management-service/lib/ucan"
)
// BasicSecurityContext is a SecurityContext backed by a private key and a
// UCAN capability context.
type BasicSecurityContext struct {
	id    ID                     // identity derived from the public key
	privk crypto.PrivKey         // key used to sign outgoing envelopes
	cap   ucan.CapabilityContext // capability context used to provide and require tokens
	mx    sync.Mutex             // guards nonce
	nonce uint64                 // next nonce to hand out
}

// Compile-time check that *BasicSecurityContext satisfies SecurityContext.
var _ SecurityContext = (*BasicSecurityContext)(nil)
// NewBasicSecurityContext creates a security context for the given key pair
// and capability context. The identity is derived from pubk and the nonce
// counter is seeded with the current time in Unix nanoseconds. It fails when
// the identity cannot be derived from the public key.
func NewBasicSecurityContext(pubk crypto.PubKey, privk crypto.PrivKey, cap ucan.CapabilityContext) (*BasicSecurityContext, error) {
	id, err := crypto.IDFromPublicKey(pubk)
	if err != nil {
		return nil, fmt.Errorf("creating security context: %w", err)
	}

	return &BasicSecurityContext{
		id:    id,
		privk: privk,
		cap:   cap,
		nonce: uint64(time.Now().UnixNano()),
	}, nil
}
// ID returns the identity of this security context.
func (s *BasicSecurityContext) ID() ID {
	id := s.id
	return id
}
// DID returns the DID of the underlying capability context.
func (s *BasicSecurityContext) DID() DID {
	capCtx := s.cap
	return capCtx.DID()
}
// Nonce returns the next nonce and advances the counter. It is safe for
// concurrent use.
func (s *BasicSecurityContext) Nonce() uint64 {
	s.mx.Lock()
	defer s.mx.Unlock()

	s.nonce++
	return s.nonce - 1
}
// Require checks that msg carries the capabilities needed to invoke cap.
// Messages this context sent to itself pass without a capability check.
// Otherwise the tokens in the envelope are consumed and the required
// capabilities checked; if the check fails the consumed tokens are discarded
// again.
func (s *BasicSecurityContext) Require(msg Envelope, cap []Capability) error {
	// if we are sending to self, nothing to do, signature is already verified
	toSelf := s.id.Equal(msg.From.ID) && s.id.Equal(msg.To.ID)
	if toSelf {
		return nil
	}

	// first consume the capability tokens in the envelope
	err := s.cap.Consume(msg.From.DID, msg.Capability)
	if err != nil {
		return fmt.Errorf("consuming capabilities: %w", err)
	}

	// check if any of the requested invocation capabilities are delegated
	err = s.cap.Require(s.DID(), msg.From.ID, s.id, cap)
	if err != nil {
		s.cap.Discard(msg.Capability)
		return fmt.Errorf("requiring capabilities: %w", err)
	}

	return nil
}
// Provide attaches the capability tokens for the invoke and delegate
// capabilities to msg and signs it. Messages this context sends to itself
// are only signed.
func (s *BasicSecurityContext) Provide(msg *Envelope, invoke []Capability, delegate []Capability) error {
	// if we are sending to self, nothing to do, just Sign
	toSelf := s.id.Equal(msg.From.ID) && s.id.Equal(msg.To.ID)
	if !toSelf {
		tokens, err := s.cap.Provide(msg.To.DID, s.id, msg.To.ID, msg.Options.Expire, invoke, delegate)
		if err != nil {
			return fmt.Errorf("providing capabilities: %w", err)
		}
		msg.Capability = tokens
	}

	return s.Sign(msg)
}
// RequireBroadcast checks that the broadcast message msg on the given topic
// carries the required broadcast capabilities. It rejects non-broadcast
// messages and topic mismatches with errors wrapping ErrInvalidMessage. The
// tokens in the envelope are consumed first and discarded again if the
// capability check fails.
func (s *BasicSecurityContext) RequireBroadcast(msg Envelope, topic string, broadcast []Capability) error {
	switch {
	case !msg.IsBroadcast():
		return fmt.Errorf("not a broadcast message: %w", ErrInvalidMessage)
	case topic != msg.Options.Topic:
		return fmt.Errorf("broadcast topic mismatch: %w", ErrInvalidMessage)
	}

	// first consume the capability tokens in the envelope
	err := s.cap.Consume(msg.From.DID, msg.Capability)
	if err != nil {
		return fmt.Errorf("consuming capabilities: %w", err)
	}

	// check if any of the requested invocation capabilities are delegated
	err = s.cap.RequireBroadcast(s.DID(), msg.From.ID, topic, broadcast)
	if err != nil {
		s.cap.Discard(msg.Capability)
		return fmt.Errorf("requiring capabilities: %w", err)
	}

	return nil
}
// ProvideBroadcast attaches the broadcast capability tokens for the given
// topic to msg and signs it. It rejects non-broadcast messages and topic
// mismatches with errors wrapping ErrInvalidMessage.
func (s *BasicSecurityContext) ProvideBroadcast(msg *Envelope, topic string, broadcast []Capability) error {
	switch {
	case !msg.IsBroadcast():
		return fmt.Errorf("not a broadcast message: %w", ErrInvalidMessage)
	case topic != msg.Options.Topic:
		return fmt.Errorf("broadcast topic mismatch: %w", ErrInvalidMessage)
	}

	granted, err := s.cap.ProvideBroadcast(msg.From.ID, topic, msg.Options.Expire, broadcast)
	if err != nil {
		return fmt.Errorf("providing capabilities: %w", err)
	}

	msg.Capability = granted
	return s.Sign(msg)
}
// Verify checks that msg has not expired and that its signature matches the
// public key derived from the sender's ID. It returns ErrMessageExpired for
// expired messages and ErrSignatureVerification when the signature does not
// match.
func (s *BasicSecurityContext) Verify(msg Envelope) error {
	if msg.Expired() {
		return ErrMessageExpired
	}

	senderKey, err := crypto.PublicKeyFromID(msg.From.ID)
	if err != nil {
		return fmt.Errorf("public key from id: %w", err)
	}

	signed, err := msg.SignatureData()
	if err != nil {
		return fmt.Errorf("signature data: %w", err)
	}

	valid, err := senderKey.Verify(signed, msg.Signature)
	switch {
	case err != nil:
		return fmt.Errorf("verify message signature: %w", err)
	case !valid:
		return ErrSignatureVerification
	}

	return nil
}
// Sign signs msg with the context's private key and stores the signature in
// the envelope. It returns ErrBadSender when the envelope's sender is not
// this context's identity.
func (s *BasicSecurityContext) Sign(msg *Envelope) error {
	if ours := s.id.Equal(msg.From.ID); !ours {
		return ErrBadSender
	}

	toSign, err := msg.SignatureData()
	if err != nil {
		return fmt.Errorf("signature data: %w", err)
	}

	signature, err := s.privk.Sign(toSign)
	if err != nil {
		return fmt.Errorf("signing message: %w", err)
	}

	msg.Signature = signature
	return nil
}
// Discard discards the capability tokens carried by msg from the capability
// context.
func (s *BasicSecurityContext) Discard(msg Envelope) {
	tokens := msg.Capability
	s.cap.Discard(tokens)
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package actor
import (
"testing"
"time"
"github.com/google/uuid"
"github.com/multiformats/go-multiaddr"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
backgroundtasks "gitlab.com/nunet/device-management-service/internal/background_tasks"
"gitlab.com/nunet/device-management-service/lib/crypto"
"gitlab.com/nunet/device-management-service/lib/did"
"gitlab.com/nunet/device-management-service/lib/ucan"
"gitlab.com/nunet/device-management-service/network"
"gitlab.com/nunet/device-management-service/network/libp2p"
"gitlab.com/nunet/device-management-service/types"
)
// MakeRootTrustContext generates a fresh Ed25519 key pair and returns the DID
// and trust context that MakeTrustContext builds from its private key.
// The test fails immediately if key generation fails.
func MakeRootTrustContext(t *testing.T) (did.DID, did.TrustContext) {
	rootKey, _, keyErr := crypto.GenerateKeyPair(crypto.Ed25519)
	require.NoError(t, keyErr)

	rootDID, rootTrust := MakeTrustContext(t, rootKey)
	return rootDID, rootTrust
}
// MakeTrustContext builds a DID provider from privk, registers it in a new
// trust context, and returns the provider's DID together with that context.
// The test fails immediately if the provider cannot be created.
func MakeTrustContext(t *testing.T, privk crypto.PrivKey) (did.DID, did.TrustContext) {
	t.Helper()

	provider, err := did.ProviderFromPrivateKey(privk)
	// The provider is derived from the private key, so say so on failure.
	require.NoError(t, err, "provider from private key")

	ctx := did.NewTrustContext()
	ctx.AddProvider(provider)
	return provider.DID(), ctx
}
// MakeCapabilityContext creates a capability context for actorDID and anchors
// it to rootDID.
//
// A second capability context is created for the root; it grants actorDID a
// one-hour delegation of ucan.Root, and the resulting tokens are added to the
// actor's context together with rootDID as a root. Any failure aborts the test.
func MakeCapabilityContext(t *testing.T, actorDID, rootDID did.DID, trust, root did.TrustContext) ucan.CapabilityContext {
	actorCtx, err := ucan.NewCapabilityContext(trust, actorDID, nil, ucan.TokenList{}, ucan.TokenList{})
	require.NoError(t, err)

	rootCtx, err := ucan.NewCapabilityContext(root, rootDID, nil, ucan.TokenList{}, ucan.TokenList{})
	require.NoError(t, err)

	expiry := MakeExpiry(time.Hour)
	granted := []ucan.Capability{ucan.Root}
	delegation, err := rootCtx.Grant(ucan.Delegate, actorDID, did.DID{}, nil, expiry, 0, granted)
	require.NoError(t, err)

	require.NoError(t, actorCtx.AddRoots([]did.DID{rootDID}, ucan.TokenList{}, delegation))
	return actorCtx
}
// MakeExpiry returns the instant d from now as Unix nanoseconds.
func MakeExpiry(d time.Duration) uint64 {
	deadline := time.Now().Add(d)
	return uint64(deadline.UnixNano())
}
// AllowReciprocal has rootDID grant otherRootDID a one-hour delegation of the
// capability named by cap, then adds the resulting tokens to actorCap through
// the second argument of AddRoots. Any failure aborts the test.
func AllowReciprocal(t *testing.T, actorCap ucan.CapabilityContext, rootTrust did.TrustContext, rootDID, otherRootDID did.DID, cap string) {
	issuer, err := ucan.NewCapabilityContext(rootTrust, rootDID, nil, ucan.TokenList{}, ucan.TokenList{})
	require.NoError(t, err)

	expiry := MakeExpiry(time.Hour)
	granted := []ucan.Capability{ucan.Capability(cap)}
	delegation, err := issuer.Grant(ucan.Delegate, otherRootDID, did.DID{}, nil, expiry, 0, granted)
	require.NoError(t, err)

	require.NoError(t, actorCap.AddRoots(nil, delegation, ucan.TokenList{}))
}
// AllowBroadcast wires up broadcast capabilities on topic between two actors.
//
// Both roots grant actor1's DID a 120-second delegation of actorCap scoped to
// topic. The tokens issued by root1 are added to actor1 through the third
// argument of AddRoots; the tokens issued by root2 are added to actor2 through
// the second argument. Any failure aborts the test.
func AllowBroadcast(t *testing.T, actor1, actor2 ucan.CapabilityContext, root1, root2 did.TrustContext, root1DID, root2DID did.DID, topic string, actorCap ...Capability) {
	firstRoot, err := ucan.NewCapabilityContext(root1, root1DID, nil, ucan.TokenList{}, ucan.TokenList{})
	require.NoError(t, err)

	secondRoot, err := ucan.NewCapabilityContext(root2, root2DID, nil, ucan.TokenList{}, ucan.TokenList{})
	require.NoError(t, err)

	// Tokens issued by the first root go to actor1.
	issued, err := firstRoot.Grant(ucan.Delegate, actor1.DID(), did.DID{}, []string{topic}, MakeExpiry(120*time.Second), 0, actorCap)
	require.NoError(t, err, "granting broadcast capability")
	require.NoError(t, actor1.AddRoots(nil, ucan.TokenList{}, issued), "add roots")

	// Tokens issued by the second root (still for actor1's DID) go to actor2.
	issued, err = secondRoot.Grant(ucan.Delegate, actor1.DID(), did.DID{}, []string{topic}, MakeExpiry(120*time.Second), 0, actorCap)
	require.NoError(t, err, "grant broadcast capability")
	require.NoError(t, actor2.AddRoots(nil, issued, ucan.TokenList{}), "add roots")
}
// CreateActor builds a BasicActor for tests on top of the given libp2p peer
// and capability context.
//
// A fresh Ed25519 key pair backs the actor's security context, and a new UUID
// is used as its inbox address. Setup steps whose result is dereferenced
// afterwards use require, so a failure stops the test instead of panicking on
// a nil value.
func CreateActor(t *testing.T, peer *libp2p.Libp2p, cap ucan.CapabilityContext) *BasicActor {
	t.Helper()

	privk, pubk, err := crypto.GenerateKeyPair(crypto.Ed25519)
	require.NoError(t, err)

	// sctx.id is read below; continuing with a nil context would panic.
	sctx, err := NewBasicSecurityContext(pubk, privk, cap)
	require.NoError(t, err)

	// Named inboxID so the uuid package is not shadowed.
	inboxID, err := uuid.NewUUID()
	require.NoError(t, err)

	handle := Handle{
		ID:  sctx.id,
		DID: cap.DID(),
		Address: Address{
			HostID:       peer.Host.ID().String(),
			InboxAddress: inboxID.String(),
		},
	}

	actor, err := New(backgroundtasks.NewScheduler(1), peer, sctx, NewRateLimiter(DefaultRateLimiterConfig()), BasicActorParams{}, handle)
	require.NoError(t, err)
	assert.NotNil(t, actor)
	return actor
}
// NewLibp2pNetwork creates, initializes and starts a libp2p network for tests,
// listening on a random loopback TCP port and bootstrapping from bootstrap.
//
// It returns the node's multiaddresses, the generated private key and the
// concrete libp2p instance. Every step whose result is used afterwards is
// checked with require, so a failure stops the test instead of causing a nil
// dereference on the next call.
func NewLibp2pNetwork(t *testing.T, bootstrap []multiaddr.Multiaddr) ([]multiaddr.Multiaddr, crypto.PrivKey, *libp2p.Libp2p) {
	t.Helper()

	priv, _, err := crypto.GenerateKeyPair(crypto.Ed25519)
	require.NoError(t, err)

	net, err := network.NewNetwork(&types.NetworkConfig{
		Type: types.Libp2pNetwork,
		Libp2pConfig: types.Libp2pConfig{
			PrivateKey:              priv,
			BootstrapPeers:          bootstrap,
			Rendezvous:              "nunet-randevouz",
			Server:                  false,
			Scheduler:               backgroundtasks.NewScheduler(1),
			CustomNamespace:         "/nunet-dht-1/",
			ListenAddress:           []string{"/ip4/127.0.0.1/tcp/0"},
			PeerCountDiscoveryLimit: 40,
			GossipMaxMessageSize:    2 << 16,
		},
	}, afero.NewMemMapFs())
	require.NoError(t, err)

	require.NoError(t, net.Init())
	require.NoError(t, net.Start())

	// The helper promises a *libp2p.Libp2p; fail loudly if the factory
	// returned a different implementation rather than handing back nil.
	libp2pInstance, ok := net.(*libp2p.Libp2p)
	require.True(t, ok, "network is not a libp2p instance")

	multi, err := libp2pInstance.GetMultiaddr()
	assert.NoError(t, err)
	return multi, priv, libp2pInstance
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package api
import (
"context"
"encoding/json"
"fmt"
"net/http"
"sync"
"time"
"github.com/gin-gonic/gin"
logging "github.com/ipfs/go-log/v2"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/lib/crypto"
"gitlab.com/nunet/device-management-service/lib/did"
"gitlab.com/nunet/device-management-service/network/libp2p"
"gitlab.com/nunet/device-management-service/observability"
"gitlab.com/nunet/device-management-service/types"
)
// log is the package-level logger used by the actor API handlers and the REST
// server in this package; it is created with the name "actor-api".
var log = logging.Logger("actor-api")
// ActorHandle godoc
//
// @Summary Retrieve actor handle
// @Description Retrieve actor handle with ID, DID, and inbox address
// @Tags actor
// @Produce json
// @Success 200 {object} actor.Handle
// @Failure 500 {object} object "host node hasn't yet been initialized"
// @Failure 500 {object} object "handle id is invalid"
// @Router /actor/handle [get]
//
// ActorHandle derives the handle from the host's own public key as stored in
// its peerstore and always reports "root" as the inbox address. It answers
// with HTTP 500 when the P2P host is not configured or when no ID can be
// derived from the public key.
func (rs RESTServer) ActorHandle(c *gin.Context) {
	endTrace := observability.StartTrace("actor_handle_retrieve_duration")
	defer endTrace()

	node := rs.config.P2P
	if node == nil {
		const reason = "host node hasn't yet been initialized"
		log.Errorw("actor_handle_retrieve_failure", "error", reason)
		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": reason})
		return
	}

	// The handle identity comes from the host's own public key.
	hostKey := node.Host.Peerstore().PubKey(node.Host.ID())
	actorID, err := crypto.IDFromPublicKey(hostKey)
	if err != nil {
		const reason = "handle id is invalid"
		log.Errorw("actor_handle_retrieve_failure", "error", reason)
		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": reason})
		return
	}
	actorDID := did.FromPublicKey(hostKey)

	result := actor.Handle{
		ID:  actorID,
		DID: actorDID,
		Address: actor.Address{
			HostID:       node.Host.ID().String(),
			InboxAddress: "root",
		},
	}

	log.Infow("actor_handle_retrieve_success", "id", actorID, "DID", actorDID)
	c.JSON(http.StatusOK, result)
}
// ActorSendMessage godoc
//
// @Summary Send message to actor
// @Description Send message to actor
// @Tags actor
// @Accept json
// @Produce json
// @Param message body actor.Envelope true "Message to send"
// @Success 200 {object} object "message sent"
// @Failure 400 {object} object "invalid request data"
// @Failure 500 {object} object "host node hasn't yet been initialized"
// @Failure 500 {object} object "failed to marshal message"
// @Failure 500 {object} object "destination address can't be resolved"
// @Failure 500 {object} object "failed to send message to destination"
// @Router /actor/send [post]
//
// ActorSendMessage decodes an envelope from the request body and forwards it
// with SendMessage. It does not wait for any reply from the destination.
func (rs RESTServer) ActorSendMessage(c *gin.Context) {
	endTrace := observability.StartTrace("actor_send_message_duration")
	defer endTrace()

	var envelope actor.Envelope
	if bindErr := c.ShouldBindJSON(&envelope); bindErr != nil {
		log.Errorw("actor_send_message_failure", "error", bindErr.Error())
		c.JSON(http.StatusBadRequest, gin.H{"error": bindErr.Error()})
		return
	}

	node := rs.config.P2P
	if node == nil {
		const reason = "host node hasn't yet been initialized"
		log.Errorw("actor_send_message_failure", "error", reason)
		c.JSON(http.StatusInternalServerError, gin.H{"error": reason})
		return
	}

	destination := envelope.To.Address.HostID
	if sendErr := SendMessage(c.Request.Context(), node, envelope); sendErr != nil {
		log.Errorw("actor_send_message_failure", "error", sendErr.Error(), "destination", destination)
		c.JSON(http.StatusInternalServerError, gin.H{"error": sendErr.Error()})
		return
	}

	log.Infow("actor_send_message_success", "destination", destination)
	c.JSON(http.StatusOK, gin.H{"message": "message sent"})
}
// ActorInvoke godoc
//
// @Summary Invoke actor
// @Description Invoke actor with message
// @Tags actor
// @Accept json
// @Produce json
// @Param message body actor.Envelope true "Message to send"
// @Success 200 {object} object "response message"
// @Failure 400 {object} object "invalid request data"
// @Failure 500 {object} object "host node hasn't yet been initialized"
// @Failure 500 {object} object "failed to marshal message"
// @Failure 500 {object} object "destination address can't be resolved"
// @Failure 500 {object} object "failed to send message to destination"
// @Router /actor/invoke [post]
//
// ActorInvoke sends the envelope to its destination and waits for the first
// reply delivered to the sender's inbox protocol. It answers with HTTP 408
// when the message expiry passes or the request context ends before a reply
// arrives.
func (rs RESTServer) ActorInvoke(c *gin.Context) {
	endTrace := observability.StartTrace("actor_invoke_duration")
	defer endTrace()

	var msg actor.Envelope
	if err := c.ShouldBindJSON(&msg); err != nil {
		log.Errorw("actor_invoke_failure", "error", err.Error())
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	p2p := rs.config.P2P
	if p2p == nil {
		log.Errorw("actor_invoke_failure", "error", "host node hasn't yet been initialized")
		c.JSON(http.StatusInternalServerError, gin.H{"error": "host node hasn't yet been initialized"})
		return
	}

	// Register a message handler that captures the reply on the sender's inbox.
	protocol := fmt.Sprintf("actor/%s/messages/0.0.1", msg.From.Address.InboxAddress)
	responseCh := make(chan actor.Envelope, 1)
	err := p2p.HandleMessage(protocol, func(data []byte) {
		var envelope actor.Envelope
		if err := json.Unmarshal(data, &envelope); err != nil {
			log.Errorw("actor_invoke_response_failure", "error", err.Error())
			return
		}
		// Only the first reply is consumed. A blocking send would leave the
		// handler stuck forever on any further message once the single
		// buffer slot is taken, so extra replies are dropped instead.
		select {
		case responseCh <- envelope:
		default:
		}
	})
	if err != nil {
		log.Errorw("actor_invoke_failure", "error", err.Error())
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	// Unregister the message handler before returning
	defer p2p.UnregisterMessageHandler(protocol)

	err = SendMessage(c.Request.Context(), p2p, msg)
	if err != nil {
		log.Errorw("actor_invoke_failure", "error", err.Error())
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	// Stop the expiry timer on every exit path instead of leaving it pending.
	expiry := time.NewTimer(time.Until(msg.Expiry()))
	defer expiry.Stop()

	select {
	case responseMsg := <-responseCh:
		log.Infow("actor_invoke_success", "destination", msg.To.Address.HostID)
		c.JSON(http.StatusOK, responseMsg)
	case <-expiry.C:
		log.Errorw("actor_invoke_failure", "error", "request timeout")
		c.JSON(http.StatusRequestTimeout, gin.H{"error": "request timeout"})
	case <-c.Request.Context().Done():
		log.Errorw("actor_invoke_failure", "error", "request timeout")
		c.JSON(http.StatusRequestTimeout, gin.H{"error": "request timeout"})
	}
}
// ActorBroadcast godoc
//
// @Summary Broadcast message to actors
// @Description Broadcast message to actors
// @Tags actor
// @Accept json
// @Produce json
// @Param message body actor.Envelope true "Message to send"
// @Success 200 {object} object "received responses"
// @Failure 400 {object} object "invalid request data"
// @Failure 500 {object} object "host node hasn't yet been initialized"
// @Failure 500 {object} object "failed to marshal message"
// @Failure 500 {object} object "failed to publish message"
// @Router /actor/broadcast [post]
//
// ActorBroadcast publishes the envelope on its topic and collects every reply
// delivered to the sender's inbox protocol until the message expiry passes or
// the request context ends, then returns the collected replies.
func (rs RESTServer) ActorBroadcast(c *gin.Context) {
	endTrace := observability.StartTrace("actor_broadcast_duration")
	defer endTrace()

	var msg actor.Envelope
	if err := c.ShouldBindJSON(&msg); err != nil {
		log.Errorw("actor_broadcast_failure", "error", err.Error())
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}

	p2p := rs.config.P2P
	if p2p == nil {
		log.Errorw("actor_broadcast_failure", "error", "host node hasn't yet been initialized")
		c.JSON(http.StatusInternalServerError, gin.H{"error": "host node hasn't yet been initialized"})
		return
	}

	if !msg.IsBroadcast() {
		log.Errorw("actor_broadcast_failure", "error", "message is not a broadcast message")
		c.JSON(http.StatusBadRequest, gin.H{"error": "message is not a broadcast message"})
		return
	}

	data, err := json.Marshal(msg)
	if err != nil {
		log.Errorw("actor_broadcast_failure", "error", "failed to marshal message")
		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to marshal message"})
		return
	}

	// register message handler to collect responses
	protocol := fmt.Sprintf("actor/%s/messages/0.0.1", msg.From.Address.InboxAddress)
	var (
		mu       sync.Mutex
		messages []actor.Envelope
	)
	err = p2p.HandleMessage(protocol, func(raw []byte) {
		var envelope actor.Envelope
		// Keep the error local to the handler: writing to the enclosing
		// function's err from here would race with the request goroutine.
		if err := json.Unmarshal(raw, &envelope); err != nil {
			log.Errorw("actor_broadcast_failure", "error", "failed to unmarshal response message")
			return
		}
		mu.Lock()
		messages = append(messages, envelope)
		mu.Unlock()
	})
	if err != nil {
		log.Errorw("actor_broadcast_failure", "error", err.Error())
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	// Unregister the message handler before returning
	defer p2p.UnregisterMessageHandler(protocol)

	// Publish the message
	if err := p2p.Publish(c.Request.Context(), msg.Options.Topic, data); err != nil {
		log.Errorw("actor_broadcast_failure", "error", "failed to publish message")
		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to publish message"})
		return
	}

	// Wait for either context done or message expiry.
	expiry := time.NewTimer(time.Until(msg.Expiry()))
	defer expiry.Stop()
	select {
	case <-expiry.C:
		// message expiry time reached
	case <-c.Request.Context().Done():
		// request context done
	}

	// The handler stays registered until this function returns, so take the
	// snapshot under the lock. Elements below the snapshot's length are never
	// rewritten by later appends, so reading it afterwards is safe.
	mu.Lock()
	responses := messages
	mu.Unlock()

	log.Infow("actor_broadcast_success", "fromAddress", msg.From.Address.HostID, "responsesCount", len(responses))
	c.JSON(http.StatusOK, responses)
}
// SendMessage JSON-encodes msg and sends it synchronously over net to the host
// named in msg.To.Address.HostID.
//
// The network message type is "actor/<inbox>/messages/0.0.1", where <inbox>
// is the destination's inbox address, and msg.Expiry() is passed along to the
// send call. Marshal and send failures are returned wrapped.
func SendMessage(ctx context.Context, net *libp2p.Libp2p, msg actor.Envelope) (err error) {
	encoded, marshalErr := json.Marshal(msg)
	if marshalErr != nil {
		return fmt.Errorf("failed to marshal message: %w", marshalErr)
	}

	outgoing := types.MessageEnvelope{
		Type: types.MessageType(
			fmt.Sprintf("actor/%s/messages/0.0.1", msg.To.Address.InboxAddress),
		),
		Data: encoded,
	}

	if sendErr := net.SendMessageSync(ctx, msg.To.Address.HostID, outgoing, msg.Expiry()); sendErr != nil {
		return fmt.Errorf("failed to send message to %s: %w", msg.To.ID, sendErr)
	}
	return nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package api
import (
"fmt"
"time"
"github.com/gin-contrib/cors"
"github.com/gin-gonic/gin"
"gitlab.com/nunet/device-management-service/dms/onboarding"
"gitlab.com/nunet/device-management-service/network/libp2p"
"gitlab.com/nunet/device-management-service/observability"
"gitlab.com/nunet/device-management-service/types"
)
// RESTServerConfig carries the dependencies and listen settings of a
// RESTServer.
type RESTServerConfig struct {
	// P2P is the libp2p host the actor endpoints operate on; handlers answer
	// with HTTP 500 while it is nil.
	P2P *libp2p.Libp2p
	// Onboarding is not referenced by the code visible in this file.
	// NOTE(review): presumably consumed by handlers defined elsewhere — confirm.
	Onboarding *onboarding.Onboarding
	// Resource is not referenced by the code visible in this file.
	// NOTE(review): presumably consumed by handlers defined elsewhere — confirm.
	Resource types.ResourceManager
	// MidW lists gin middleware installed on the router ahead of the CORS handler.
	MidW []gin.HandlerFunc
	// Port is the TCP port Run listens on.
	Port uint32
	// Addr is the host part of the listen address; Run listens on "Addr:Port".
	Addr string
}
// RESTServer represents an HTTP server built on a gin engine.
type RESTServer struct {
	// router is the gin engine that routes are registered on and that Run starts.
	router *gin.Engine
	// config holds the dependencies and listen settings supplied to NewRESTServer.
	config *RESTServerConfig
}
// NewRESTServer builds a RESTServer whose router has config.MidW and the CORS
// middleware installed, and returns a pointer to it. Routes are not registered
// here; see InitializeRoutes. config must be non-nil.
func NewRESTServer(config *RESTServerConfig) *RESTServer {
	endTrace := observability.StartTrace("rest_server_init_duration")
	defer endTrace()

	server := new(RESTServer)
	server.router = setupRouter(config.MidW)
	server.config = config

	log.Infow("rest_server_init_success", "addr", config.Addr, "port", config.Port)
	return server
}
// setupRouter creates a default gin engine and installs the supplied
// middleware followed by the CORS handler.
//
// The handler list is assembled in a fresh slice: appending to mid directly
// could write into the caller's backing array when it has spare capacity.
func setupRouter(mid []gin.HandlerFunc) *gin.Engine {
	handlers := make([]gin.HandlerFunc, 0, len(mid)+1)
	handlers = append(handlers, mid...)
	handlers = append(handlers, cors.New(getCustomCorsConfig()))

	router := gin.Default()
	router.Use(handlers...)
	return router
}
// InitializeRoutes registers the actor endpoints under /api/v1/actor:
// GET /handle, POST /send, POST /invoke and POST /broadcast.
func (rs *RESTServer) InitializeRoutes() {
	endTrace := observability.StartTrace("rest_server_route_init_duration")
	defer endTrace()

	// /actor routes
	actorRoutes := rs.router.Group("/api/v1").Group("/actor")
	actorRoutes.GET("/handle", rs.ActorHandle)
	actorRoutes.POST("/send", rs.ActorSendMessage)
	actorRoutes.POST("/invoke", rs.ActorInvoke)
	actorRoutes.POST("/broadcast", rs.ActorBroadcast)

	log.Infow("rest_server_route_init_success", "endpoint", "/api/v1/actor")
}
// Run starts the gin engine on "Addr:Port" from the server configuration.
// The error returned by the engine is logged and passed back to the caller;
// the success entry is logged only when the engine returns without error.
func (rs *RESTServer) Run() error {
	endTrace := observability.StartTrace("rest_server_run_duration")
	defer endTrace()

	listenAddr := fmt.Sprintf("%s:%d", rs.config.Addr, rs.config.Port)
	runErr := rs.router.Run(listenAddr)
	if runErr == nil {
		log.Infow("rest_server_run_success", "addr", listenAddr)
		return nil
	}

	log.Errorw("rest_server_run_failure", "addr", listenAddr, "error", runErr)
	return runErr
}
// getCustomCorsConfig returns defaultConfig with the allowed origins limited
// to two hard-coded localhost URLs (ports 9991 and 9992).
func getCustomCorsConfig() cors.Config {
	corsCfg := defaultConfig()

	// FIXME: This is a security concern.
	allowed := []string{"http://localhost:9991", "http://localhost:9992"}
	corsCfg.AllowOrigins = allowed

	return corsCfg
}
// defaultConfig returns the base CORS settings: the common HTTP methods, a
// small header allow-list, credentials disabled and a 12-hour max age. No
// allowed origins are set here.
func defaultConfig() cors.Config {
	var base cors.Config
	base.AllowMethods = []string{"GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"}
	base.AllowHeaders = []string{"Access-Control-Allow-Origin", "Origin", "Content-Length", "Content-Type"}
	base.AllowCredentials = false
	base.MaxAge = 12 * time.Hour
	return base
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package actor
import (
"github.com/spf13/afero"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/utils"
)
// Names shared by the actor commands. None of them is referenced by the code
// visible in this file, so the notes below are assumptions to confirm.
const (
	// CapstoreDir is a relative directory name.
	// NOTE(review): presumably where capability contexts are stored — confirm.
	CapstoreDir = "cap/"
	// DefaultUserContextName is "user".
	// NOTE(review): presumably the fallback capability context name — confirm.
	DefaultUserContextName = "user"
	// KeystoreDir is a relative directory name.
	// NOTE(review): presumably where keys are stored — confirm.
	KeystoreDir = "key/"
)
// NewActorCmd builds the `actor` parent command and attaches its subcommands:
// msg, send, invoke, broadcast and cmd. The parent has no run function of its
// own.
func NewActorCmd(client *utils.HTTPClient, afs afero.Afero) *cobra.Command {
	parent := &cobra.Command{
		Use:   "actor",
		Short: "Interact with the actor system",
		Long: `Interact with the actor system
Actors are the entities which compose the NuActor system, a secure decentralized programming framework based on the Actor Model.
Actors are connected through the libp2p network substrate and communication is achieved via immutable messages.
For more information on the actor system, please refer to actor/README.md`,
	}

	parent.AddCommand(
		newActorMsgCmd(client, afs),
		newActorSendCmd(client),
		newActorInvokeCmd(client),
		newActorBroadcastCmd(client),
		newActorCmdGroup(client, afs),
	)
	return parent
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package actor
import (
"encoding/json"
"fmt"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/utils"
)
// newActorBroadcastCmd builds the `actor broadcast` subcommand.
//
// Its single argument must be a JSON-encoded actor envelope. The argument is
// decoded only to validate it; the raw bytes are POSTed to /actor/broadcast.
// The response body is always printed, and any status other than 200 is
// reported as an error.
func newActorBroadcastCmd(client *utils.HTTPClient) *cobra.Command {
	run := func(cmd *cobra.Command, args []string) error {
		raw := args[0]

		// Decode only to reject malformed envelopes before calling the API.
		var envelope actor.Envelope
		if err := json.Unmarshal([]byte(raw), &envelope); err != nil {
			return fmt.Errorf("could not unmarshal message: %w", err)
		}

		body, status, err := client.MakeRequest("POST", "/actor/broadcast", []byte(raw))
		fmt.Fprintln(cmd.OutOrStdout(), string(body))
		switch {
		case err != nil:
			return fmt.Errorf("unable to make internal request: %w", err)
		case status != 200:
			return fmt.Errorf("request failed with status code: %d", status)
		}
		return nil
	}

	return &cobra.Command{
		Use:   "broadcast <msg>",
		Short: "Broadcast a message",
		Long: `Broadcast a message to a topic
If a topic is specified in the message's payload, the message will be published to all subscribers of that topic.`,
		Args: cobra.ExactArgs(1),
		RunE: run,
	}
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package actor
import (
"encoding/json"
"fmt"
"strings"
"time"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/cmd/utils"
dmsUtil "gitlab.com/nunet/device-management-service/utils"
)
// Flag names (fn*) and behavior types (b*) used by the `actor cmd` commands.
const (
	// fnTimeout names the --timeout duration flag.
	fnTimeout = "timeout"
	// fnExpiry names the --expiry time flag; mutually exclusive with --timeout.
	fnExpiry = "expiry"
	// fnContextName names the --context flag (capability context name).
	fnContextName = "context"
	// fnDest names the --dest flag (destination of the behavior).
	fnDest = "dest"
	// bBroadcast marks a behavior sent through the /actor/broadcast endpoint.
	bBroadcast = "broadcast"
	// bInvoke marks a behavior sent through /actor/invoke as an invocation.
	bInvoke = "invoke"
	// bSend marks a send-type behavior.
	// NOTE(review): presumably routed to /actor/send via the Type field — confirm.
	bSend = "send"
)
// newActorCmdGroup builds the `actor cmd` command, which groups one subcommand
// per entry in the behaviors table.
//
// Run without a subcommand it prints its help text. The persistent flags
// --context, --timeout, --expiry and --dest are inherited by every behavior
// subcommand; --timeout and --expiry are mutually exclusive.
func newActorCmdGroup(client *dmsUtil.HTTPClient, afs afero.Afero) *cobra.Command {
	cmd := &cobra.Command{
		Use:   "cmd",
		Short: "Invoke a predefined behavior on an actor",
		Long: `Invoke a predefined behavior on an actor
Example:
nunet actor cmd --context user /broadcast/hello
Adding the --dest flag will cause the behavior to be invoked on the specified actor.
For more information on behaviors, refer to cmd/actor/README.md`,
		ValidArgsFunction: func(_ *cobra.Command, args []string, _ string) ([]string, cobra.ShellCompDirective) {
			if len(args) > 0 {
				return nil, cobra.ShellCompDirectiveDefault
			}
			var completions []string
			for k := range behaviors {
				// Completion offers the third "/"-separated segment of the
				// behavior name. Skip names with fewer segments instead of
				// panicking on an out-of-range index.
				segments := strings.Split(k, "/")
				if len(segments) < 3 {
					continue
				}
				completions = append(completions, segments[2])
			}
			return completions, cobra.ShellCompDirectiveNoFileComp
		},
		Run: func(cmd *cobra.Command, _ []string) {
			err := cmd.Help()
			if err != nil {
				cmd.Println(err)
			}
		},
	}

	// Range yields the value directly; a second map lookup is unnecessary.
	for behavior, behaviorCfg := range behaviors {
		cmd.AddCommand(newActorCmdCmd(client, afs, behavior, behaviorCfg))
	}

	cmd.PersistentFlags().StringP(fnContextName, "c", "", "capability context name")
	cmd.PersistentFlags().DurationP(fnTimeout, "t", 0, "timeout duration")
	cmd.PersistentFlags().VarP(utils.NewTimeValue(&time.Time{}), fnExpiry, "e", "expiration time")
	cmd.PersistentFlags().StringP(fnDest, "d", "", "destination DMS DID, peer ID or handle")
	cmd.MarkFlagsMutuallyExclusive(fnTimeout, fnExpiry)
	return cmd
}
// newActorCmdCmd builds the subcommand for a single behavior.
//
// The payload value is created once from behaviorCfg.Payload (when set) and
// shared by the flag setup, the pre-run hook and the run function. On run the
// command resolves the local DMS handle, optionally encodes the payload,
// builds and marshals the actor message, POSTs it to /actor/<Type> and prints
// the decoded response as indented JSON. For non-broadcast behaviors a
// response that cannot be decoded is ignored and nothing is printed.
func newActorCmdCmd(client *dmsUtil.HTTPClient, afs afero.Afero, behavior string, behaviorCfg behaviorConfig) *cobra.Command {
	payload := &Payload{}
	if factory := behaviorCfg.Payload; factory != nil {
		payload.val = factory()
	}

	preRun := func(cmd *cobra.Command, _ []string) error {
		if behaviorCfg.PreRunE == nil {
			return nil
		}
		return behaviorCfg.PreRunE(cmd, payload.val)
	}

	run := func(cmd *cobra.Command, _ []string) error {
		flags := cmd.Flags()
		timeout, _ := flags.GetDuration(fnTimeout)
		expiry, _ := utils.GetTime(flags, fnExpiry)
		contextName, _ := flags.GetString(fnContextName)
		dest, _ := flags.GetString(fnDest)

		source, err := getDMSHandle(client)
		if err != nil {
			return fmt.Errorf("could not get source DMS handle: %w", err)
		}

		// Only broadcast behaviors carry a topic.
		isBroadcast := behaviorCfg.Type == bBroadcast
		var topic string
		if isBroadcast {
			topic = behaviorCfg.Topic
		}

		if encode := behaviorCfg.PayloadEnc; encode != nil {
			payload.val, err = encode(payload.val)
			if err != nil {
				return fmt.Errorf("could not marshal payload: %w", err)
			}
		}

		isInvocation := behaviorCfg.Type == bInvoke
		msg, err := newActorMessage(afs, source, dest, topic, behavior, payload.val, timeout, expiry, isInvocation, contextName)
		if err != nil {
			return fmt.Errorf("could not create message: %w", err)
		}

		encodedMsg, err := json.Marshal(msg)
		if err != nil {
			return fmt.Errorf("could not marshal message: %w", err)
		}

		body, status, err := client.MakeRequest("POST", fmt.Sprintf("/actor/%s", behaviorCfg.Type), encodedMsg)
		switch {
		case err != nil:
			return fmt.Errorf("unable to make internal request: %w", err)
		case status != 200:
			return fmt.Errorf("request failed with status code: %d", status)
		}

		out := json.NewEncoder(cmd.OutOrStdout())
		out.SetIndent("", " ")

		// Broadcasts answer with a list of responses, everything else with one.
		if isBroadcast {
			var replies []cmdResponse
			if err := json.Unmarshal(body, &replies); err != nil {
				return fmt.Errorf("could not unmarshal response: %w", err)
			}
			return out.Encode(replies)
		}

		var reply cmdResponse
		if err := json.Unmarshal(body, &reply); err != nil {
			return nil
		}
		return out.Encode(reply)
	}

	cmd := &cobra.Command{
		Use:               fmt.Sprintf("%s [<param> ...]", behavior),
		Short:             behaviorCfg.Short,
		Long:              behaviorCfg.Long,
		ValidArgsFunction: behaviorCfg.ValidArgsFn,
		Args:              behaviorCfg.Args,
		PreRunE:           preRun,
		RunE:              run,
	}

	if setFlags := behaviorCfg.SetFlags; setFlags != nil {
		setFlags(cmd, payload.val)
	}
	return cmd
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package actor
import (
"encoding/json"
"fmt"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/utils"
)
// newActorInvokeCmd builds the `actor invoke` subcommand.
//
// Its single argument must be a JSON-encoded actor envelope with a non-empty
// replyTo option. The raw bytes are POSTed to /actor/invoke. The response body
// is always printed, and any status other than 200 is reported as an error.
func newActorInvokeCmd(client *utils.HTTPClient) *cobra.Command {
	run := func(cmd *cobra.Command, args []string) error {
		raw := args[0]

		var envelope actor.Envelope
		if err := json.Unmarshal([]byte(raw), &envelope); err != nil {
			return fmt.Errorf("could not unmarshal message: %w", err)
		}
		// Without a reply address the invocation result cannot come back.
		if envelope.Options.ReplyTo == "" {
			return fmt.Errorf("missing replyTo field in message")
		}

		body, status, err := client.MakeRequest("POST", "/actor/invoke", []byte(raw))
		fmt.Fprintln(cmd.OutOrStdout(), string(body))
		switch {
		case err != nil:
			return fmt.Errorf("unable to make internal request: %w", err)
		case status != 200:
			return fmt.Errorf("request failed with status code: %d", status)
		}
		return nil
	}

	return &cobra.Command{
		Use:   "invoke <msg>",
		Short: "Invoke a behaviour in an actor and return the result",
		Args:  cobra.ExactArgs(1),
		RunE:  run,
	}
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package actor
import (
"encoding/json"
"fmt"
"time"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/cmd/utils"
dmsUtils "gitlab.com/nunet/device-management-service/utils"
)
// newActorMsgCmd builds the `actor msg` subcommand, which constructs a message
// for a behavior and payload and prints it as JSON without sending it.
//
// Flags choose the destination (--dest) or broadcast topic (--broadcast),
// whether the message is an invocation (--invoke), the capability context
// (--context) and the lifetime (--timeout or --expiry). --dest/--broadcast,
// --invoke/--broadcast and --timeout/--expiry are mutually exclusive pairs.
func newActorMsgCmd(client *dmsUtils.HTTPClient, afs afero.Afero) *cobra.Command {
	const (
		destFlag      = "dest"
		broadcastFlag = "broadcast"
		timeoutFlag   = "timeout"
		expiryFlag    = "expiry"
		invokeFlag    = "invoke"
		contextFlag   = "context"
	)

	run := func(cmd *cobra.Command, args []string) error {
		flags := cmd.Flags()
		dest, _ := flags.GetString(destFlag)
		topic, _ := flags.GetString(broadcastFlag)
		timeout, _ := flags.GetDuration(timeoutFlag)
		expiry, _ := utils.GetTime(flags, expiryFlag)
		isInvocation, _ := flags.GetBool(invokeFlag)
		contextName, _ := flags.GetString(contextFlag)

		behavior, payload := args[0], args[1]

		source, err := getDMSHandle(client)
		if err != nil {
			return fmt.Errorf("could not get source handle: %w", err)
		}

		msg, err := newActorMessage(afs, source, dest, topic, behavior, payload, timeout, expiry, isInvocation, contextName)
		if err != nil {
			return fmt.Errorf("could not create message: %w", err)
		}

		encoded, err := json.Marshal(msg)
		if err != nil {
			return err
		}

		fmt.Fprintln(cmd.OutOrStdout(), string(encoded))
		return nil
	}

	cmd := &cobra.Command{
		Use:   "msg <behavior> <payload>",
		Short: "Construct a message",
		Long: `Construct and sign a message that can be communicated to an actor.
The constructed message is returned as a JSON object that can be used stored or piped into another command, for instance the the send, invoke, or broadcast command.
Example:
nunet actor msg --broadcast /nunet/hello /broadcast/hello 'Hello, World!'`,
		Args: cobra.ExactArgs(2),
		RunE: run,
	}

	flagSet := cmd.Flags()
	flagSet.StringP(destFlag, "d", "", "destination handle")
	flagSet.StringP(broadcastFlag, "b", "", "broadcast topic")
	flagSet.BoolP(invokeFlag, "i", false, "construct an invocation")
	flagSet.StringP(contextFlag, "c", "", "capability context name")
	flagSet.DurationP(timeoutFlag, "t", 0, "timeout duration")
	flagSet.VarP(utils.NewTimeValue(&time.Time{}), expiryFlag, "e", "expiration time")

	cmd.MarkFlagsMutuallyExclusive(destFlag, broadcastFlag)
	cmd.MarkFlagsMutuallyExclusive(invokeFlag, broadcastFlag)
	cmd.MarkFlagsMutuallyExclusive(timeoutFlag, expiryFlag)
	return cmd
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package actor
import (
"encoding/json"
"fmt"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/utils"
)
// newActorSendCmd builds the `actor send` subcommand.
//
// Its single argument must be a JSON-encoded actor envelope. The argument is
// decoded only to validate it; the raw bytes are POSTed to /actor/send. The
// response body is always printed, and any status other than 200 is reported
// as an error.
func newActorSendCmd(client *utils.HTTPClient) *cobra.Command {
	run := func(cmd *cobra.Command, args []string) error {
		raw := args[0]

		// Decode only to reject malformed envelopes before calling the API.
		var envelope actor.Envelope
		if err := json.Unmarshal([]byte(raw), &envelope); err != nil {
			return fmt.Errorf("could not unmarshal message: %w", err)
		}

		body, status, err := client.MakeRequest("POST", "/actor/send", []byte(raw))
		fmt.Fprintln(cmd.OutOrStdout(), string(body))
		switch {
		case err != nil:
			return fmt.Errorf("unable to make internal request: %w", err)
		case status != 200:
			return fmt.Errorf("request failed with status code: %d", status)
		}
		return nil
	}

	return &cobra.Command{
		Use:   "send <msg>",
		Short: "Send a message",
		Long: `Send a message to an actor
Actors only communicate via messages. For more information on constructing a message, see:
nunet actor msg --help
The message is encoded into an actor envelope, which then is sent across the network through the API.`,
		Args: cobra.ExactArgs(1),
		RunE: run,
	}
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package actor
import (
	"errors"
	"fmt"
	"math"
	"os"
	"strconv"

	"github.com/spf13/cobra"

	"gitlab.com/nunet/device-management-service/dms/hardware"
	"gitlab.com/nunet/device-management-service/dms/jobs"
	"gitlab.com/nunet/device-management-service/dms/jobs/parser"
	"gitlab.com/nunet/device-management-service/dms/node"
	"gitlab.com/nunet/device-management-service/types"
)
// ErrInvalidArgument reports that a value of an unexpected kind was supplied;
// for example, onboardBehaviorPreRun returns it when its payload is not a
// *node.OnboardRequest.
var ErrInvalidArgument = errors.New("invalid argument")
// Command is an alias for cobra.Command, so the two names are interchangeable.
type Command = cobra.Command
// Payload holds a single value of arbitrary type in an unexported field.
// NOTE(review): no use of this type is visible in this part of the package —
// confirm it is still needed.
type Payload struct {
val any
}
// behaviorConfig is the value type of the behaviors table: the metadata and
// hooks attached to one behavior name.
//
// The notes on each field describe how the behaviors table populates it; the
// code that consumes these fields lives elsewhere in the package.
type behaviorConfig struct {
// Behavior is presumably the behavior name; entries in the behaviors table
// leave it unset and use the map key instead — TODO confirm.
Behavior string
// Type is set to bInvoke or bBroadcast by the table entries.
Type string
// Topic is set only by the broadcast entry in the table.
Topic string
// Payload returns a new pointer to the request value for the behavior.
// Entries without a request body leave it nil.
Payload func() any
// PayloadEnc receives the value produced by Payload and returns the value
// to send, or an error if the payload has an unexpected type or cannot be
// prepared.
PayloadEnc func(payload any) (any, error)
// SetFlags registers command-line flags on cmd, binding them to fields of
// payload.
SetFlags func(cmd *Command, payload any)
// PreRunE is an optional hook that receives the command and the payload and
// may fail; presumably it runs before the behavior is invoked — TODO confirm.
PreRunE func(cmd *Command, payload any) error
// ValidArgsFn has the signature of a cobra argument-completion function.
ValidArgsFn func(cmd *Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective)
// Args is a cobra positional-argument validator.
Args cobra.PositionalArgs
// Long is the long help text of the behavior.
Long string
// Short is the one-line help text of the behavior.
Short string
}
// NewDeploymentRequestCmd carries the command-line input of the new-deployment
// behavior.
type NewDeploymentRequestCmd struct {
// Config is the path of the ensemble specification file; it is bound to the
// --spec-file flag and read from disk when the payload is encoded.
Config string
}
// behaviors maps each supported behavior name to its command-line
// configuration: help texts, the request payload constructor, the flags bound
// to that payload, and the encoder that turns the payload into the value that
// is sent.
//
// Errors from MarkFlagRequired/MarkFlagFilename are deliberately discarded:
// they only fail when the named flag does not exist, so every name passed to
// them must match a flag registered in the same SetFlags function.
var behaviors = map[string]behaviorConfig{
	// /public/hello
	node.PublicHelloBehavior: {
		Type:  bInvoke,
		Short: "Broadcast a 'hello' message",
		Long: `Invoke the /public/hello behavior on an actor
This behavior broadcasts a "hello" for a polite introduction.
Examples:
nunet actor cmd --context user /public/hello
nunet actor cmd --context user /public/hello --dest <did/peer_id/actor_handle>`,
	},
	// /broadcast/hello
	node.BroadcastHelloBehavior: {
		Type:  bBroadcast,
		Topic: node.BroadcastHelloTopic,
		Short: "Broadcast a 'hello' message to a topic",
		Long: `Invokes the /broadcast/hello behavior on an actor
This behavior sends a "hello" message to a broadcast topic for polite introduction.
Examples:
nunet actor cmd --context user /broadcast/hello`,
	},
	// /public/status
	node.PublicStatusBehavior: {
		Type:  bInvoke,
		Short: "Retrieve actor status",
		Long: `Invokes the /public/status behavior on an actor
This behavior retrieves the status and resources information.
Examples:
nunet actor cmd --context user /public/status # own actor status
nunet actor cmd --context user /public/status --dest <did/peer_id/actor_handle> # status of specified destination`,
	},
	// /dms/node/peers/list
	node.PeersListBehavior: {
		Type:  bInvoke,
		Short: "List connected peers",
		Long: `Invokes the /dms/node/peers/list behavior on an actor
This behavior retrieves a list of connected peers.
Examples:
nunet actor cmd --context user /dms/node/peers/list # own node actor peer list
nunet actor cmd --context user /dms/node/peers/list --dest <did/peer_id/actor_handle> # specified node actor peer list`,
	},
	// /dms/node/peers/self
	node.PeerAddrInfoBehavior: {
		Type:  bInvoke,
		Short: "Get peer's ID and addresses",
		Long: `Invokes the /dms/node/peers/self behavior on an actor
This behavior retrieves information about the node itself, such as its ID or addresses.
Examples:
nunet actor cmd --context user /dms/node/peers/self # own node actor peer ID
nunet actor cmd --context user /dms/node/peers/self --dest <did/peer_id/actor_handle> # specified node actor peer ID`,
	},
	// /dms/node/peers/ping
	node.PeerPingBehavior: {
		Type:    bInvoke,
		Payload: func() any { return &node.PingRequest{} },
		SetFlags: func(cmd *cobra.Command, payload any) {
			p := payload.(*node.PingRequest)
			cmd.Flags().StringVarP(&p.Host, "host", "H", "", "host address to ping (required)")
			_ = cmd.MarkFlagRequired("host")
		},
		PayloadEnc: func(payload any) (any, error) {
			req, ok := payload.(*node.PingRequest)
			if !ok {
				return nil, fmt.Errorf("failed to encode payload")
			}
			return req, nil
		},
		Short: "Ping a peer",
		Long: `Invokes the /dms/node/peers/ping behavior on an actor
This behavior establishes a ping connection with a peer.
Examples:
nunet actor cmd --context user /dms/node/peers/ping --host <peer_id>`,
	},
	// /dms/node/peers/dht
	node.PeerDHTBehavior: {
		Type:  bInvoke,
		Short: "List peers connected to DHT",
		Long: `Invokes the /dms/node/peers/dht behavior on an actor
This behavior returns a list of peers from the Distributed Hash Table (DHT) used for peer discovery and content routing.
Examples:
nunet actor cmd --context user /dms/node/peers/dht`,
	},
	// /dms/node/peers/connect
	node.PeerConnectBehavior: {
		Type:    bInvoke,
		Payload: func() any { return &node.PeerConnectRequest{} },
		SetFlags: func(cmd *cobra.Command, payload any) {
			p := payload.(*node.PeerConnectRequest)
			cmd.Flags().StringVarP(&p.Address, "address", "a", "", "peer address to connect to (required)")
			_ = cmd.MarkFlagRequired("address")
		},
		PayloadEnc: func(payload any) (any, error) {
			req, ok := payload.(*node.PeerConnectRequest)
			if !ok {
				return nil, fmt.Errorf("failed to encode payload")
			}
			return req, nil
		},
		Short: "Connect to a peer",
		Long: `Invokes the /dms/node/peers/connect behavior on an actor
This behavior initiates a connection to a specified peer.
Examples:
nunet actor cmd --context user /dms/node/peers/connect --address /p2p/<peer_id>`,
	},
	// /dms/node/peers/score
	node.PeerScoreBehavior: {
		Type:  bInvoke,
		Short: "Retrieves gossipsub broadcast score",
		Long: `Invokes the /dms/node/peers/score behavior on an actor
This behavior retrieves a snapshot of the peer's gossipsub broadcast score.
Examples:
nunet actor cmd --context user /dms/node/peers/score`,
	},
	// /dms/node/onboarding/onboard
	node.OnboardBehavior: {
		Type:    bInvoke,
		Payload: func() any { return &node.OnboardRequest{} },
		SetFlags: func(cmd *Command, payload any) {
			p := payload.(*node.OnboardRequest)
			cmd.Flags().Float64VarP(&p.Config.OnboardedResources.RAM.Size, "ram", "R", 0, "set the amount of memory in GB to reserve for NuNet")
			cmd.Flags().Float32VarP(&p.Config.OnboardedResources.CPU.Cores, "cpu", "C", 0, "set the number of CPU cores to reserve for NuNet")
			cmd.Flags().Float64VarP(&p.Config.OnboardedResources.Disk.Size, "disk", "D", 0, "set the amount of disk size in GB to reserve for NuNet")
			// Together these two constraints make all three flags mandatory.
			cmd.MarkFlagsOneRequired("ram", "cpu", "disk")
			cmd.MarkFlagsRequiredTogether("ram", "cpu", "disk")
		},
		PreRunE: onboardBehaviorPreRun,
		PayloadEnc: func(payload any) (any, error) {
			req, ok := payload.(*node.OnboardRequest)
			if !ok {
				return nil, fmt.Errorf("failed to encode payload")
			}
			// The flags are expressed in GB; the request carries bytes.
			req.Config.OnboardedResources.RAM.Size = types.ConvertGBToBytes(req.Config.OnboardedResources.RAM.Size)
			req.Config.OnboardedResources.Disk.Size = types.ConvertGBToBytes(req.Config.OnboardedResources.Disk.Size)
			return req, nil
		},
		Short: "Onboard a node to the network",
		Long: `Invokes the /dms/node/onboarding/onboard behavior on an actor
This behavior is used to onboard a node to the DMS, making its resources available for use.
Examples:
nunet actor cmd --context user /dms/node/onboarding/onboard --ram 1 --cpu 2 --disk 10`,
	},
	// /dms/node/onboarding/offboard
	node.OffboardBehavior: {
		Type:    bInvoke,
		Payload: func() any { return &node.OffboardRequest{} },
		SetFlags: func(cmd *cobra.Command, payload any) {
			p := payload.(*node.OffboardRequest)
			cmd.Flags().BoolVarP(&p.Force, "force", "f", false, "force offboard")
		},
		PayloadEnc: func(payload any) (any, error) {
			req, ok := payload.(*node.OffboardRequest)
			if !ok {
				return nil, fmt.Errorf("failed to encode payload")
			}
			return req, nil
		},
		Short: "Offboard a node from the network",
		Long: `Invokes the /dms/node/onboarding/offboard behavior on an actor
This behavior is used to offboard a node from the DMS (Device Management Service).
Examples:
nunet actor cmd --context user /dms/node/onboarding/offboard
nunet actor cmd --context user /dms/node/onboarding/offboard --force`,
	},
	// /dms/node/onboarding/status
	node.OnboardStatusBehavior: {
		Type:  bInvoke,
		Short: "Retrieve onboarding status of a node",
		Long: `Invokes the /dms/node/onboarding/status behavior on an actor
This behavior is used to check the onboarding status of a node.
Examples:
nunet actor cmd --context user /dms/node/onboarding/status`,
	},
	// /dms/node/deployment/list
	node.DeploymentListBehavior: {
		Type:  bInvoke,
		Short: "List deployments",
		Long: `Invokes the /dms/node/deployment/list behavior on an actor
This behavior retrieves a list of all deployments on the node.
Examples:
nunet actor cmd --context user /dms/node/deployment/list`,
	},
	// /dms/node/deployment/status
	node.DeploymentStatusBehavior: {
		Type:    bInvoke,
		Payload: func() any { return &node.DeploymentStatusRequest{} },
		SetFlags: func(cmd *cobra.Command, payload any) {
			p := payload.(*node.DeploymentStatusRequest)
			cmd.Flags().StringVarP(&p.ID, "id", "i", "", "deployment ID (required)")
			_ = cmd.MarkFlagRequired("id")
		},
		PayloadEnc: func(payload any) (any, error) {
			req, ok := payload.(*node.DeploymentStatusRequest)
			if !ok {
				return nil, fmt.Errorf("failed to encode payload")
			}
			return req, nil
		},
		Short: "Get deployment status",
		Long: `Invokes the /dms/node/deployment/status behavior on an actor
This behavior retrieves the status of a specific deployment.
Examples:
nunet actor cmd --context user /dms/node/deployment/status --id <deployment_id>`,
	},
	// /dms/node/deployment/manifest
	node.DeploymentManifestBehavior: {
		Type:    bInvoke,
		Payload: func() any { return &node.DeploymentManifestRequest{} },
		SetFlags: func(cmd *cobra.Command, payload any) {
			p := payload.(*node.DeploymentManifestRequest)
			cmd.Flags().StringVarP(&p.ID, "id", "i", "", "deployment ID (required)")
			_ = cmd.MarkFlagRequired("id")
		},
		PayloadEnc: func(payload any) (any, error) {
			req, ok := payload.(*node.DeploymentManifestRequest)
			if !ok {
				return nil, fmt.Errorf("failed to encode payload")
			}
			return req, nil
		},
		Short: "Get deployment manifest",
		Long: `Invokes the /dms/node/deployment/manifest behavior on an actor
This behavior retrieves the manifest of a specific deployment.
Examples:
nunet actor cmd --context user /dms/node/deployment/manifest --id <deployment_id>`,
	},
	// /dms/node/deployment/shutdown
	node.DeploymentShutdownBehavior: {
		Type:    bInvoke,
		Payload: func() any { return &node.DeploymentShutdownRequest{} },
		SetFlags: func(cmd *cobra.Command, payload any) {
			p := payload.(*node.DeploymentShutdownRequest)
			cmd.Flags().StringVarP(&p.ID, "id", "i", "", "deployment ID (required)")
			_ = cmd.MarkFlagRequired("id")
		},
		PayloadEnc: func(payload any) (any, error) {
			req, ok := payload.(*node.DeploymentShutdownRequest)
			if !ok {
				return nil, fmt.Errorf("failed to encode payload")
			}
			return req, nil
		},
		Short: "Shutdown a deployment",
		Long: `Invokes the /dms/node/deployment/shutdown behavior on an actor
This behavior shuts down a specific deployment.
Examples:
nunet actor cmd --context user /dms/node/deployment/shutdown --id <deployment_id>`,
	},
	// /dms/node/vm/start/custom
	node.VMStartBehavior: {
		Type:    bInvoke,
		Payload: func() any { return &vmStartOpts{} },
		SetFlags: func(cmd *cobra.Command, payload any) {
			p := payload.(*vmStartOpts)
			cmd.Flags().StringVarP(&p.Engine.KernelImage, "kernel", "k", "", "path to kernel image file (required)")
			cmd.Flags().StringVarP(&p.Engine.RootFileSystem, "rootfs", "r", "", "path to root fs image file (required)")
			cmd.Flags().StringVarP(&p.Engine.Initrd, "initrd", "i", "", "path to initial ram disk")
			cmd.Flags().StringVarP(&p.Engine.KernelArgs, "args", "a", "", "arguments to pass to the kernel")
			cmd.Flags().Float32Var(&p.Resources.CPU.Cores, "cpu", 1, "CPU cores to allocate")
			cmd.Flags().Float64VarP(&p.Resources.RAM.Size, "ram", "m", 1, "Memory to allocate in GB")
			// NOTE(review): this flag is a numeric size, not a path; the unit is
			// presumably GB like --ram — confirm and add it to the help text.
			cmd.Flags().Float64Var(&p.Resources.Disk.Size, "disk", 0.5, "Disk size to allocate")
			_ = cmd.MarkFlagRequired("kernel")
			_ = cmd.MarkFlagFilename("kernel")
			_ = cmd.MarkFlagRequired("rootfs")
			_ = cmd.MarkFlagFilename("rootfs")
		},
		PayloadEnc: func(payload any) (any, error) {
			opts, ok := payload.(*vmStartOpts)
			if !ok {
				return nil, fmt.Errorf("failed to encode payload")
			}
			return newCustomVMStartRequest(opts)
		},
		Short: "Starts a custom VM",
		Long: `Invokes the /dms/node/vm/start/custom behavior on an actor
This behavior starts a new VM with custom configurations.
Examples:
nunet actor cmd --context user /dms/node/vm/start/custom --kernel /path/to/kernel --rootfs /path/to/rootfs --cpu 2 --ram 2`,
	},
	// /dms/node/vm/stop
	node.VMStopBehavior: {
		Payload: func() any { return &node.VMStopRequest{} },
		SetFlags: func(cmd *cobra.Command, payload any) {
			p := payload.(*node.VMStopRequest)
			// The executor type is fixed; only the execution ID comes from a flag.
			p.ExecutionType = jobs.ExecutorFirecracker
			cmd.Flags().StringVarP(&p.ExecutionID, "id", "i", "", "execution ID of the VM (required)")
			_ = cmd.MarkFlagRequired("id")
		},
		PayloadEnc: func(payload any) (any, error) {
			req, ok := payload.(*node.VMStopRequest)
			if !ok {
				return nil, fmt.Errorf("failed to encode payload")
			}
			return req, nil
		},
		Type:  bInvoke,
		Short: "Stops a running VM",
		Long: `Invokes the /dms/node/vm/stop behavior on an actor
This behavior stops a running VM.
Examples:
nunet actor cmd --context user /dms/node/vm/stop --id <execution_id>`,
	},
	// /dms/node/vm/list
	node.VMListBehavior: {
		Payload: func() any {
			return &node.ListVMResponse{
				ExecutionType: jobs.ExecutorFirecracker,
			}
		},
		Type:  bInvoke,
		Short: "List running VMs",
		Long: `Invokes the /dms/node/vm/list behavior on an actor
This behavior retrieves a list of virtual machines (VMs) running on the node.
Examples:
nunet actor cmd --context user /dms/node/vm/list`,
	},
	// /dms/node/deployment/new
	node.NewDeploymentBehavior: {
		Type:  bInvoke,
		Short: "Create a new deployment",
		Long: `Invokes the /dms/node/deployment/new behavior on an actor
This behavior creates a new deployment.
Examples:
nunet actor cmd --context user /dms/node/deployment/new --spec-file <path to ensemble specification file>`,
		Payload: func() any { return &NewDeploymentRequestCmd{} },
		SetFlags: func(cmd *cobra.Command, payload any) {
			p := payload.(*NewDeploymentRequestCmd)
			cmd.Flags().StringVarP(&p.Config, "spec-file", "f", "ensemble.yaml", "path of the ensemble specification file (required)")
		},
		PayloadEnc: func(payload any) (any, error) {
			req, ok := payload.(*NewDeploymentRequestCmd)
			if !ok {
				return nil, fmt.Errorf("failed to encode payload")
			}
			data, err := os.ReadFile(req.Config)
			if err != nil {
				return nil, fmt.Errorf("failed to read config file: %w", err)
			}
			cfg := &node.NewDeploymentRequest{}
			if err := parser.Parse(parser.SpecTypeEnsembleV1, data, &cfg.Ensemble); err != nil {
				return nil, err
			}
			// After parsing, each script and key entry holds a file path;
			// replace every path with the contents of the file it names.
			for name, scriptPath := range cfg.Ensemble.V1.Scripts {
				scriptData, err := os.ReadFile(string(scriptPath))
				if err != nil {
					return nil, fmt.Errorf("failed to read script file: %w", err)
				}
				cfg.Ensemble.V1.Scripts[name] = scriptData
			}
			for name, keyPath := range cfg.Ensemble.V1.Keys {
				keyData, err := os.ReadFile(keyPath)
				if err != nil {
					return nil, fmt.Errorf("failed to read key file: %w", err)
				}
				cfg.Ensemble.V1.Keys[name] = string(keyData)
			}
			return cfg, nil
		},
	},
	// /dms/node/subnet/create
	jobs.SubnetCreateBehavior: {
		Payload: func() any { return &jobs.SubnetCreateRequest{} },
		SetFlags: func(cmd *cobra.Command, payload any) {
			p := payload.(*jobs.SubnetCreateRequest)
			cmd.Flags().StringVarP(&p.SubnetID, "subnet-id", "s", "", "subnet ID (required)")
			cmd.Flags().StringToStringVarP(&p.RoutingTable, "routing-table", "r", nil, "subnet routing table (required)")
			_ = cmd.MarkFlagRequired("subnet-id")
		},
		PayloadEnc: func(payload any) (any, error) {
			req, ok := payload.(*jobs.SubnetCreateRequest)
			if !ok {
				return nil, fmt.Errorf("failed to encode payload")
			}
			return req, nil
		},
		Type:  bInvoke,
		Short: "Create a subnet",
		Long: `Invokes the /dms/node/subnet/create behavior on an actor
This behavior creates a new subnet with the specified subnet ID and routing table.
Examples:
nunet actor cmd --context user /dms/node/subnet/create --subnet-id <subnet_id> --routing-table <routing_table>`,
	},
	// /dms/node/subnet/destroy
	jobs.SubnetDestroyBehavior: {
		Payload: func() any { return &jobs.SubnetDestroyRequest{} },
		SetFlags: func(cmd *cobra.Command, payload any) {
			p := payload.(*jobs.SubnetDestroyRequest)
			cmd.Flags().StringVarP(&p.SubnetID, "subnet-id", "s", "", "subnet ID (required)")
			_ = cmd.MarkFlagRequired("subnet-id")
		},
		PayloadEnc: func(payload any) (any, error) {
			req, ok := payload.(*jobs.SubnetDestroyRequest)
			if !ok {
				return nil, fmt.Errorf("failed to encode payload")
			}
			return req, nil
		},
		Type:  bInvoke,
		Short: "Destroy a subnet",
		Long: `Invokes the /dms/node/subnet/destroy behavior on an actor
This behavior destroys the specified subnet.
Examples:
nunet actor cmd --context user /dms/node/subnet/destroy --subnet-id <subnet_id>`,
	},
	// /dms/node/subnet/add-peer
	jobs.SubnetAddPeerBehavior: {
		Payload: func() any { return &jobs.SubnetAddPeerRequest{} },
		SetFlags: func(cmd *cobra.Command, payload any) {
			p := payload.(*jobs.SubnetAddPeerRequest)
			cmd.Flags().StringVarP(&p.SubnetID, "subnet-id", "s", "", "subnet ID (required)")
			cmd.Flags().StringVarP(&p.PeerID, "peer-id", "p", "", "peer ID (required)")
			cmd.Flags().StringVarP(&p.IP, "ip", "i", "", "peer IP address (required)")
			_ = cmd.MarkFlagRequired("subnet-id")
			_ = cmd.MarkFlagRequired("peer-id")
			_ = cmd.MarkFlagRequired("ip")
		},
		PayloadEnc: func(payload any) (any, error) {
			req, ok := payload.(*jobs.SubnetAddPeerRequest)
			if !ok {
				return nil, fmt.Errorf("failed to encode payload")
			}
			return req, nil
		},
		Type:  bInvoke,
		Short: "Add a peer to a subnet",
		Long: `Invokes the /dms/node/subnet/add-peer behavior on an actor
This behavior adds a peer to the specified subnet.
Examples:
nunet actor cmd --context user /dms/node/subnet/add-peer --subnet-id <subnet_id> --peer-id <peer_id> --ip <ip>`,
	},
	// /dms/node/subnet/remove-peer
	jobs.SubnetRemovePeerBehavior: {
		Payload: func() any { return &jobs.SubnetRemovePeerRequest{} },
		SetFlags: func(cmd *cobra.Command, payload any) {
			p := payload.(*jobs.SubnetRemovePeerRequest)
			cmd.Flags().StringVarP(&p.SubnetID, "subnet-id", "s", "", "subnet ID (required)")
			cmd.Flags().StringVarP(&p.PeerID, "peer-id", "p", "", "peer ID (required)")
			_ = cmd.MarkFlagRequired("subnet-id")
			_ = cmd.MarkFlagRequired("peer-id")
		},
		PayloadEnc: func(payload any) (any, error) {
			req, ok := payload.(*jobs.SubnetRemovePeerRequest)
			if !ok {
				return nil, fmt.Errorf("failed to encode payload")
			}
			return req, nil
		},
		Type:  bInvoke,
		Short: "Remove a peer from a subnet",
		Long: `Invokes the /dms/node/subnet/remove-peer behavior on an actor
This behavior removes a peer from the specified subnet.
Examples:
nunet actor cmd --context user /dms/node/subnet/remove-peer --subnet-id <subnet_id> --peer-id <peer_id>`,
	},
	// /dms/node/subnet/accept-peer
	jobs.SubnetAcceptPeerBehavior: {
		Payload: func() any { return &jobs.SubnetAcceptPeerRequest{} },
		SetFlags: func(cmd *cobra.Command, payload any) {
			p := payload.(*jobs.SubnetAcceptPeerRequest)
			cmd.Flags().StringVarP(&p.SubnetID, "subnet-id", "s", "", "subnet ID (required)")
			cmd.Flags().StringVarP(&p.PeerID, "peer-id", "p", "", "peer ID (required)")
			cmd.Flags().StringVarP(&p.IP, "ip", "i", "", "peer IP address (required)")
			_ = cmd.MarkFlagRequired("subnet-id")
			_ = cmd.MarkFlagRequired("peer-id")
			_ = cmd.MarkFlagRequired("ip")
		},
		PayloadEnc: func(payload any) (any, error) {
			req, ok := payload.(*jobs.SubnetAcceptPeerRequest)
			if !ok {
				return nil, fmt.Errorf("failed to encode payload")
			}
			return req, nil
		},
		Type:  bInvoke,
		Short: "Accept a peer to a subnet",
		Long: `Invokes the /dms/node/subnet/accept-peer behavior on an actor
This behavior accepts a peer to the specified subnet.
Examples:
nunet actor cmd --context user /dms/node/subnet/accept-peer --subnet-id <subnet_id> --peer-id <peer_id> --ip <ip>`,
	},
	// /dms/node/subnet/map-port
	jobs.SubnetMapPortBehavior: {
		Payload: func() any { return &jobs.SubnetMapPortRequest{} },
		SetFlags: func(cmd *cobra.Command, payload any) {
			p := payload.(*jobs.SubnetMapPortRequest)
			cmd.Flags().StringVarP(&p.SubnetID, "subnet-id", "i", "", "subnet-id (required)")
			cmd.Flags().StringVarP(&p.Protocol, "protocol", "p", "", "protocol (required)")
			cmd.Flags().StringVarP(&p.SourceIP, "source-ip", "s", "", "source IP address (required)")
			cmd.Flags().StringVarP(&p.SourcePort, "source-port", "o", "", "source port (required)")
			cmd.Flags().StringVarP(&p.DestIP, "dest-ip", "d", "", "destination IP address (required)")
			cmd.Flags().StringVarP(&p.DestPort, "dest-port", "n", "", "destination port (required)")
			// subnet-id is documented as required, exactly as for unmap-port.
			_ = cmd.MarkFlagRequired("subnet-id")
			_ = cmd.MarkFlagRequired("protocol")
			_ = cmd.MarkFlagRequired("source-ip")
			_ = cmd.MarkFlagRequired("source-port")
			_ = cmd.MarkFlagRequired("dest-ip")
			_ = cmd.MarkFlagRequired("dest-port")
		},
		PayloadEnc: func(payload any) (any, error) {
			req, ok := payload.(*jobs.SubnetMapPortRequest)
			if !ok {
				return nil, fmt.Errorf("failed to encode payload")
			}
			return req, nil
		},
		Type:  bInvoke,
		Short: "Map a port",
		Long: `Invokes the /dms/node/subnet/map-port behavior on an actor
This behavior maps a port from the source to the destination.
Examples:
nunet actor cmd --context user /dms/node/subnet/map-port --subnet-id <subnet_id> --protocol <protocol> --source-ip <source_ip> --source-port <source_port> --dest-ip <dest_ip> --dest-port <dest_port>`,
	},
	// /dms/node/subnet/dns/add-record
	jobs.SubnetDNSAddRecordBehavior: {
		Payload: func() any { return &jobs.SubnetDNSAddRecordRequest{} },
		SetFlags: func(cmd *cobra.Command, payload any) {
			p := payload.(*jobs.SubnetDNSAddRecordRequest)
			cmd.Flags().StringVarP(&p.SubnetID, "subnet-id", "s", "", "subnet ID (required)")
			cmd.Flags().StringVarP(&p.DomainName, "domain-name", "n", "", "A record name (required)")
			cmd.Flags().StringVarP(&p.IP, "ip", "i", "", "IP address (required)")
			_ = cmd.MarkFlagRequired("subnet-id")
			// Must name the flag as registered above ("domain-name").
			_ = cmd.MarkFlagRequired("domain-name")
			_ = cmd.MarkFlagRequired("ip")
		},
		PayloadEnc: func(payload any) (any, error) {
			req, ok := payload.(*jobs.SubnetDNSAddRecordRequest)
			if !ok {
				return nil, fmt.Errorf("failed to encode payload")
			}
			return req, nil
		},
		Type:  bInvoke,
		Short: "Add a DNS record",
		Long: `Invokes the /dms/node/subnet/dns/add-record behavior on an actor
This behavior adds a DNS record to the local resolver.
Examples:
nunet actor cmd --context user /dms/node/subnet/dns/add-record --subnet-id <subnet_id> --domain-name <record_name> --ip <ip>`,
	},
	// /dms/node/subnet/unmap-port
	jobs.SubnetUnmapPortBehavior: {
		Payload: func() any { return &jobs.SubnetUnmapPortRequest{} },
		SetFlags: func(cmd *cobra.Command, payload any) {
			p := payload.(*jobs.SubnetUnmapPortRequest)
			cmd.Flags().StringVarP(&p.SubnetID, "subnet-id", "i", "", "subnet-id (required)")
			cmd.Flags().StringVarP(&p.Protocol, "protocol", "p", "", "protocol (required)")
			cmd.Flags().StringVarP(&p.SourceIP, "source-ip", "s", "", "source IP address (required)")
			cmd.Flags().StringVarP(&p.SourcePort, "source-port", "o", "", "source port (required)")
			cmd.Flags().StringVarP(&p.DestIP, "dest-ip", "d", "", "destination IP address (required)")
			cmd.Flags().StringVarP(&p.DestPort, "dest-port", "n", "", "destination port (required)")
			_ = cmd.MarkFlagRequired("subnet-id")
			_ = cmd.MarkFlagRequired("protocol")
			_ = cmd.MarkFlagRequired("source-ip")
			_ = cmd.MarkFlagRequired("source-port")
			_ = cmd.MarkFlagRequired("dest-ip")
			_ = cmd.MarkFlagRequired("dest-port")
		},
		PayloadEnc: func(payload any) (any, error) {
			req, ok := payload.(*jobs.SubnetUnmapPortRequest)
			if !ok {
				return nil, fmt.Errorf("failed to encode payload")
			}
			return req, nil
		},
		Type:  bInvoke,
		Short: "Unmap a port",
		Long: `Invokes the /dms/node/subnet/unmap-port behavior on an actor
This behavior removes a port mapping.
Examples:
nunet actor cmd --context user /dms/node/subnet/unmap-port --subnet-id <subnet_id> --protocol <protocol> --source-ip <source_ip> --source-port <source_port> --dest-ip <dest_ip> --dest-port <dest_port>`,
	},
	// /dms/node/subnet/dns/remove-record
	jobs.SubnetDNSRemoveRecordBehavior: {
		Payload: func() any { return &jobs.SubnetDNSRemoveRecordRequest{} },
		SetFlags: func(cmd *cobra.Command, payload any) {
			p := payload.(*jobs.SubnetDNSRemoveRecordRequest)
			cmd.Flags().StringVarP(&p.SubnetID, "subnet-id", "s", "", "subnet ID (required)")
			cmd.Flags().StringVarP(&p.DomainName, "domain-name", "n", "", "A record name (required)")
			_ = cmd.MarkFlagRequired("subnet-id")
			// Must name the flag as registered above ("domain-name").
			_ = cmd.MarkFlagRequired("domain-name")
		},
		PayloadEnc: func(payload any) (any, error) {
			req, ok := payload.(*jobs.SubnetDNSRemoveRecordRequest)
			if !ok {
				return nil, fmt.Errorf("failed to encode payload")
			}
			return req, nil
		},
		Type:  bInvoke,
		Short: "Remove a DNS record",
		Long: `Invokes the /dms/node/subnet/dns/remove-record behavior on an actor
This behavior removes a DNS record from the local resolver.
Examples:
nunet actor cmd --context user /dms/node/subnet/dns/remove-record --subnet-id <subnet_id> --domain-name <record_name>`,
	},
}
// onboardBehaviorPreRun completes an onboarding request with information that
// is not supplied through flags.
//
// payload must be a *node.OnboardRequest, otherwise ErrInvalidArgument is
// returned. The CPU clock speed is copied from the machine resources reported
// by the hardware manager. When the machine reports GPUs, the user is asked
// interactively which of them to onboard and how much VRAM (in GB) to reserve
// for each; the selected GPUs are appended to the request with VRAM converted
// to bytes. When no GPU is reported, a notice is printed and selection is
// skipped.
//
// An error is returned if the machine resources cannot be read or if any
// prompt fails.
func onboardBehaviorPreRun(_ *Command, payload any) error {
	p, ok := payload.(*node.OnboardRequest)
	if !ok {
		return ErrInvalidArgument
	}

	// TODO: we need to have single instance of hardware manager
	// Should we do an api call here?
	hardwareManager := hardware.NewHardwareManager()
	machineResources, err := hardwareManager.GetMachineResources()
	if err != nil {
		return fmt.Errorf("could not get machine resources: %w", err)
	}
	p.Config.OnboardedResources.CPU.ClockSpeed = machineResources.CPU.ClockSpeed

	if len(machineResources.GPUs) == 0 {
		fmt.Println("No GPUs found. Skipping GPU selection.")
		return nil
	}

	// NOTE(review): GPUs are keyed by model name, so two cards of the same
	// model share one entry and cannot be told apart in the prompt.
	gpuMap := make(map[string]types.GPU, len(machineResources.GPUs))
	gpuPromptItems := make([]*selectPromptItem, 0, len(machineResources.GPUs))
	for _, gpu := range machineResources.GPUs {
		gpuMap[gpu.Model] = gpu
		gpuPromptItems = append(gpuPromptItems, &selectPromptItem{
			Label: gpu.Model,
		})
	}

	res, err := selectPrompt("Select GPU", gpuPromptItems)
	if err != nil {
		return fmt.Errorf("could not select GPU: %w", err)
	}

	// parseVRAM accepts only finite, non-negative amounts. strconv.ParseFloat
	// on its own also accepts "NaN", "Inf" and negative numbers, none of which
	// is a meaningful amount of memory.
	parseVRAM := func(input string) (float64, error) {
		vram, err := strconv.ParseFloat(input, 64)
		if err != nil {
			return 0, err
		}
		if math.IsNaN(vram) || math.IsInf(vram, 0) || vram < 0 {
			return 0, fmt.Errorf("value must be a finite, non-negative number, got %q", input)
		}
		return vram, nil
	}
	vramValidator := func(input string) error {
		if _, err := parseVRAM(input); err != nil {
			return fmt.Errorf("invalid input: %w", err)
		}
		return nil
	}

	for _, gpuName := range res {
		input, err := prompt("Enter VRAM in GB", vramValidator)
		if err != nil {
			return fmt.Errorf("could not prompt for VRAM: %w", err)
		}
		vram, err := parseVRAM(input)
		if err != nil {
			return fmt.Errorf("could not parse VRAM: %w", err)
		}
		gpu := gpuMap[gpuName]
		gpu.VRAM = types.ConvertGBToBytes(vram)
		p.Config.OnboardedResources.GPUs = append(p.Config.OnboardedResources.GPUs, gpu)
	}
	return nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
//go:build linux
// +build linux
package actor
import (
"github.com/google/uuid"
"gitlab.com/nunet/device-management-service/dms/node"
"gitlab.com/nunet/device-management-service/executor/firecracker"
"gitlab.com/nunet/device-management-service/types"
)
// newCustomVMStartRequest turns the command-line VM options into a
// node.CustomVMStartRequest.
//
// The engine spec is assembled with the firecracker engine builder from the
// root filesystem, kernel image, kernel arguments and initrd in opts. The
// execution gets a freshly generated UUID as its ID and points at
// opts.Resources, so later changes to opts are visible through the request.
// The returned error is always nil.
func newCustomVMStartRequest(opts *vmStartOpts) (node.CustomVMStartRequest, error) {
	spec := firecracker.NewFirecrackerEngineBuilder(opts.Engine.RootFileSystem).
		WithKernelImage(opts.Engine.KernelImage).
		WithKernelArgs(opts.Engine.KernelArgs).
		WithInitrd(opts.Engine.Initrd).
		Build()

	var req node.CustomVMStartRequest
	req.Execution = types.ExecutionRequest{
		ExecutionID: uuid.New().String(),
		EngineSpec:  spec,
		Resources:   &opts.Resources,
	}
	return req, nil
}
// vmStartOpts collects the options needed to build a custom VM start request.
type vmStartOpts struct {
// Engine holds the firecracker engine settings; newCustomVMStartRequest reads
// its RootFileSystem, KernelImage, KernelArgs and Initrd fields.
Engine firecracker.EngineSpec
// Resources describes the resources for the execution; the request built by
// newCustomVMStartRequest references it by pointer.
Resources types.Resources
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package actor
import (
"fmt"
"github.com/manifoldco/promptui"
)
// selectPromptItem is one entry of the multi-select menu shown by selectPrompt.
type selectPromptItem struct {
// Label is the text displayed for the entry and the value reported back when
// the entry is selected.
Label string
// Selected records whether the entry is currently ticked; selectPrompt
// toggles it each time the entry is chosen.
Selected bool
}
// selectPrompt shows an interactive multi-select menu and returns the labels
// of the entries that were ticked.
//
// A "Done" entry is placed in front of a non-empty item list unless its first
// entry already carries that label. The menu is shown repeatedly: choosing a
// regular entry flips its Selected flag (mutating the caller's item), and
// choosing "Done" ends the interaction. A failure of the underlying prompt is
// returned wrapped.
func selectPrompt(label string, items []*selectPromptItem) ([]string, error) {
	const doneLabel = "Done"

	if len(items) > 0 && items[0].Label != doneLabel {
		withDone := make([]*selectPromptItem, 0, len(items)+1)
		withDone = append(withDone, &selectPromptItem{Label: doneLabel})
		items = append(withDone, items...)
	}

	for {
		menu := promptui.Select{
			Label: label,
			Items: items,
			Templates: &promptui.SelectTemplates{
				Label:    "{{ .Label }}",
				Active:   "{{ if .Selected }}{{ \"✔\" | green }} {{ end }}→ {{ .Label | cyan | bold }}",
				Inactive: "{{ if .Selected }}{{ \"✔\" | green }} {{ end }} {{ .Label | faint }}",
				Selected: "{{ .Label | green | bold }}",
			},
			// Show every entry at once instead of scrolling.
			Size: len(items),
		}

		picked, _, err := menu.Run()
		if err != nil {
			return nil, fmt.Errorf("prompt failed %w", err)
		}

		choice := items[picked]
		if choice.Label == doneLabel {
			break
		}
		// Toggle the entry and show the menu again.
		choice.Selected = !choice.Selected
	}

	var chosen []string
	for _, it := range items {
		if !it.Selected {
			continue
		}
		chosen = append(chosen, it.Label)
	}
	return chosen, nil
}
// prompt asks for a single line of free-text input under the given label.
// validate is handed to the prompt as its input validator. The entered text
// and any error from running the prompt are returned unchanged.
func prompt(label string, validate func(string) error) (string, error) {
	input := promptui.Prompt{Label: label, Validate: validate}
	return input.Run()
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package actor
import (
"encoding/json"
"fmt"
"strings"
"time"
"github.com/spf13/afero"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/cmd/cap"
"gitlab.com/nunet/device-management-service/lib/crypto"
"gitlab.com/nunet/device-management-service/lib/did"
"gitlab.com/nunet/device-management-service/lib/ucan"
"gitlab.com/nunet/device-management-service/utils"
)
// cmdResponse holds the decoded message of a command response.
//
// On input it expects a JSON object whose "msg" member is a base64 string
// (the JSON encoding of a byte slice) containing a JSON document; that inner
// document is what gets stored. On output only the stored value is encoded,
// without the surrounding object.
type cmdResponse struct {
	val any
}

// UnmarshalJSON extracts the "msg" member from data and decodes its bytes as
// JSON into the response value. It fails if data is not valid JSON, or if the
// message bytes are missing or are not valid JSON themselves.
func (r *cmdResponse) UnmarshalJSON(data []byte) error {
	var wire struct {
		Message []byte `json:"msg"`
	}
	if err := json.Unmarshal(data, &wire); err != nil {
		return err
	}

	var decoded any
	if err := json.Unmarshal(wire.Message, &decoded); err != nil {
		return err
	}

	r.val = decoded
	return nil
}

// MarshalJSON encodes the stored value directly.
func (r cmdResponse) MarshalJSON() ([]byte, error) {
	return json.Marshal(r.val)
}
// getDMSHandle fetches the actor handle served at GET /actor/handle through
// client and decodes it from the JSON response body.
//
// An error is returned when the request fails, when the status code is not
// 200, or when the body cannot be decoded into an actor.Handle.
func getDMSHandle(client *utils.HTTPClient) (actor.Handle, error) {
	var handle actor.Handle

	payload, status, err := client.MakeRequest("GET", "/actor/handle", nil)
	switch {
	case err != nil:
		return handle, fmt.Errorf("unable to get source handle: %w", err)
	case status != 200:
		return handle, fmt.Errorf("request failed with status code: %d", status)
	}

	if decodeErr := json.Unmarshal(payload, &handle); decodeErr != nil {
		return handle, fmt.Errorf("could not unmarshal response body: %w", decodeErr)
	}
	return handle, nil
}
// newUserHandle builds an actor handle with the given ID and DID. Its address
// reuses the host ID of dmsHandle and uses inbox as the inbox address.
func newUserHandle(id crypto.ID, did did.DID, dmsHandle actor.Handle, inbox string) actor.Handle {
	var h actor.Handle
	h.ID = id
	h.DID = did
	h.Address.HostID = dmsHandle.Address.HostID
	h.Address.InboxAddress = inbox
	return h
}
// newSecurityContext creates a security context for the named capability
// context, falling back to DefaultUserContextName when context is empty.
//
// A fresh Ed25519 key pair is generated for every call. The trust context
// comes from a ledger wallet provider when cap.IsLedgerContext reports a
// ledger context (the name is then mapped through cap.LedgerContext), and from
// the key store on fs otherwise. The capability context is loaded for the
// resulting name and combined with the key pair into a basic security context.
func newSecurityContext(fs afero.Afero, context string) (actor.SecurityContext, error) {
	if context == "" {
		context = DefaultUserContextName
	}

	// Ephemeral key pair.
	privk, pubk, err := crypto.GenerateKeyPair(crypto.Ed25519)
	if err != nil {
		return nil, fmt.Errorf("failed to generate ephemeral key pair: %w", err)
	}

	// Trust context: key store by default, ledger wallet for ledger contexts.
	var trustCtx did.TrustContext
	if !cap.IsLedgerContext(context) {
		trustCtx, _, err = cap.CreateTrustContextFromKeyStore(fs, context)
		if err != nil {
			return nil, fmt.Errorf("failed to create trust context: %w", err)
		}
	} else {
		provider, provErr := did.NewLedgerWalletProvider(0)
		if provErr != nil {
			return nil, provErr
		}
		trustCtx = did.NewTrustContextWithProvider(provider)
		context = cap.LedgerContext(context)
	}

	// Capability context for the (possibly remapped) context name.
	capCtx, err := cap.LoadCapabilityContext(trustCtx, context)
	if err != nil {
		return nil, fmt.Errorf("failed to load capability context: %w", err)
	}
	return actor.NewBasicSecurityContext(pubk, privk, capCtx)
}
// newActorMessage constructs a signed envelope that invokes behavior with the
// given payload, sent from a per-call user handle.
//
// Destination selection:
//   - a non-empty topic makes this a topic message; the destination handle is
//     left at its zero value and replies go to /public/user/<nonce>;
//   - otherwise a non-empty destStr is parsed as a DID ("did:" prefix), as a
//     JSON-encoded handle ("{" prefix), or else as a peer ID;
//   - otherwise the message is addressed to dmsHandle.
//
// When invocation is set the reply address becomes /private/user/<nonce>,
// replacing any topic reply address. A non-zero expiry and a positive timeout
// are attached as message options. For non-topic messages the reply address
// is also delegated as a capability in the signature.
func newActorMessage(fs afero.Afero, dmsHandle actor.Handle, destStr string, topic, behavior string, payload interface{}, timeout time.Duration, expiry time.Time, invocation bool, context string) (actor.Envelope, error) {
	var (
		envelope actor.Envelope
		target   actor.Handle
	)

	secCtx, err := newSecurityContext(fs, context)
	if err != nil {
		return envelope, fmt.Errorf("failed to create security context: %w", err)
	}

	// The same nonce names the sender inbox and any reply address.
	nonce := secCtx.Nonce()
	inbox := fmt.Sprintf("user-%d", nonce)
	sender := newUserHandle(secCtx.ID(), secCtx.DID(), dmsHandle, inbox)

	var (
		options []actor.MessageOption
		replyTo string
	)

	switch {
	case topic != "":
		options = append(options, actor.WithMessageTopic(topic))
		replyTo = fmt.Sprintf("/public/user/%d", nonce)
	case destStr == "":
		target = dmsHandle
	default:
		var parseErr error
		if strings.HasPrefix(destStr, "did:") {
			target, parseErr = actor.HandleFromDID(destStr)
		} else if strings.HasPrefix(destStr, "{") {
			parseErr = json.Unmarshal([]byte(destStr), &target)
		} else {
			target, parseErr = actor.HandleFromPeerID(destStr)
		}
		if parseErr != nil {
			return envelope, fmt.Errorf("could not create destination handle: %w", parseErr)
		}
	}

	if invocation {
		replyTo = fmt.Sprintf("/private/user/%d", nonce)
	}
	if !expiry.IsZero() {
		options = append(options, actor.WithMessageExpiry(uint64(expiry.UnixNano())))
	}
	if timeout > 0 {
		options = append(options, actor.WithMessageTimeout(timeout))
	}

	// Kept as an empty, non-nil slice when nothing is delegated.
	delegated := []ucan.Capability{}
	if replyTo != "" {
		options = append(options, actor.WithMessageReplyTo(replyTo))
		if topic == "" {
			delegated = append(delegated, ucan.Capability(replyTo))
		}
	}

	// The signature option is appended last, after all other options.
	invoked := []ucan.Capability{ucan.Capability(behavior)}
	options = append(options, actor.WithMessageSignature(secCtx, invoked, delegated))

	envelope, err = actor.Message(sender, target, behavior, payload, options...)
	if err != nil {
		return envelope, fmt.Errorf("could not construct message: %w", err)
	}
	return envelope, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package cmd
import (
"os"
"github.com/spf13/cobra"
)
// newAutoCompleteCmd builds the hidden "autocomplete" command. It takes
// exactly one argument, "bash" or "zsh", and writes the matching completion
// script for the root command to standard output. Errors returned by the
// script generators are discarded.
func newAutoCompleteCmd() *cobra.Command {
	supported := []string{"bash", "zsh"}

	generate := func(cmd *cobra.Command, args []string) {
		root := cmd.Root()
		if shell := args[0]; shell == "bash" {
			_ = root.GenBashCompletion(os.Stdout)
		} else if shell == "zsh" {
			_ = root.GenZshCompletion(os.Stdout)
		}
	}

	return &cobra.Command{
		Use:   "autocomplete [shell]",
		Short: "Generate autocomplete script for your shell",
		Long: `Generate an autocomplete script for the nunet CLI.
This command supports Bash and Zsh shells.`,
		DisableFlagsInUseLine: true,
		Hidden:                true,
		ValidArgs:             supported,
		Args:                  cobra.MatchAll(cobra.ExactArgs(1), cobra.OnlyValidArgs),
		Run:                   generate,
	}
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package backend
import gonet "github.com/shirou/gopsutil/net"
// Network is a field-less type whose methods query network information
// through the gopsutil net package.
type Network struct{}

// GetConnections returns whatever gonet.Connections reports for kind; the
// kind string is forwarded unchanged.
func (n *Network) GetConnections(kind string) ([]gonet.ConnectionStat, error) {
	stats, err := gonet.Connections(kind)
	return stats, err
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package cap
import (
"encoding/json"
"fmt"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/lib/did"
"gitlab.com/nunet/device-management-service/lib/ucan"
)
// newAnchorCmd builds the "anchor" subcommand, which adds a single root,
// require or provide anchor to a stored capability context and saves the
// result.
//
// The context is loaded with a ledger wallet trust context when its name is a
// ledger context, and with a keystore-backed trust context otherwise. The
// --root value is parsed as a DID; --require and --provide values are decoded
// as a JSON token list.
func newAnchorCmd(afs afero.Afero) *cobra.Command {
	const (
		fnProvide = "provide"
		fnRoot    = "root"
		fnRequire = "require"
	)

	var (
		ctxName    string
		rootArg    string
		provideArg string
		requireArg string
	)

	run := func(_ *cobra.Command, _ []string) error {
		var trust did.TrustContext
		if !IsLedgerContext(ctxName) {
			var ksErr error
			trust, _, ksErr = CreateTrustContextFromKeyStore(afs, ctxName)
			if ksErr != nil {
				return fmt.Errorf("failed to create trust context: %w", ksErr)
			}
		} else {
			wallet, walletErr := did.NewLedgerWalletProvider(0)
			if walletErr != nil {
				return walletErr
			}
			trust = did.NewTrustContextWithProvider(wallet)
			ctxName = LedgerContext(ctxName)
		}

		capCtx, err := LoadCapabilityContext(trust, ctxName)
		if err != nil {
			return fmt.Errorf("failed to load capability context: %w", err)
		}

		// The first non-empty flag wins, checked in the order root, require,
		// provide.
		switch {
		case rootArg != "":
			anchorDID, parseErr := did.FromString(rootArg)
			if parseErr != nil {
				return fmt.Errorf("invalid root DID: %w", parseErr)
			}
			err = capCtx.AddRoots([]did.DID{anchorDID}, ucan.TokenList{}, ucan.TokenList{})
			if err != nil {
				return fmt.Errorf("failed to add root anchors: %w", err)
			}
		case requireArg != "":
			var required ucan.TokenList
			if decodeErr := json.Unmarshal([]byte(requireArg), &required); decodeErr != nil {
				return fmt.Errorf("unmarshal tokens: %w", decodeErr)
			}
			err = capCtx.AddRoots(nil, required, ucan.TokenList{})
			if err != nil {
				return fmt.Errorf("failed to add require anchors: %w", err)
			}
		case provideArg != "":
			var provided ucan.TokenList
			if decodeErr := json.Unmarshal([]byte(provideArg), &provided); decodeErr != nil {
				return fmt.Errorf("unmarshal tokens: %w", decodeErr)
			}
			err = capCtx.AddRoots(nil, ucan.TokenList{}, provided)
			if err != nil {
				return fmt.Errorf("failed to add provide anchors: %w", err)
			}
		default:
			return fmt.Errorf("one of --provide, --root, or --require must be specified")
		}

		if saveErr := SaveCapabilityContext(capCtx, ctxName); saveErr != nil {
			return fmt.Errorf("save capability context: %w", saveErr)
		}
		return nil
	}

	cmd := &cobra.Command{
		Use:   "anchor",
		Short: "Manage capability anchors",
		Long: `Add or modify capability anchors in a capability context.
An anchor is a basis of trust in the capability system. There are three types of anchors:
1. Root anchor: Represents absolute trust or effective root capability.
Use the --root flag with a DID value to add a root anchor.
2. Require anchor: Represents input trust. We verify incoming messages based on the require anchor.
Use the --require flag with a token to add a require anchor.
3. Provide anchor: Represents output trust. We emit invocation tokens based on our provide anchors and sign output.
Use the --provide flag with a token to add a provide anchor.
Only one type of anchor can be added or modified per command execution.
Usage examples:
nunet cap anchor --context user --root did:example:123456789abcdefghi
nunet cap anchor --context dms --require '{"some": "json", "token": "here"}'
nunet cap anchor --context user --provide '{"another": "json", "token": "example"}'
Note: The --context flag is required to specify the capability context.`,
		RunE: run,
	}

	useFlagContext(cmd, &ctxName)
	useFlagRoot(cmd, &rootArg)
	useFlagRequire(cmd, &requireArg)
	useFlagProvide(cmd, &provideArg)

	_ = cmd.MarkFlagRequired(fnContext)
	cmd.MarkFlagsOneRequired(fnProvide, fnRoot, fnRequire)
	cmd.MarkFlagsMutuallyExclusive(fnProvide, fnRoot, fnRequire)
	return cmd
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package cap
import (
"github.com/spf13/afero"
"github.com/spf13/cobra"
)
// Names of the command-line flags used by the cap subcommands.
const (
	// Capability context selection.
	fnContext = "context"

	// Token contents.
	fnAudience = "audience"
	fnAction   = "action"
	fnCap      = "cap"
	fnTopic    = "topic"
	fnDepth    = "depth"
	fnSelfSign = "self-sign"

	// Token lifetime.
	fnExpiry     = "expiry"
	fnDuration   = "duration"
	fnAutoExpire = "auto-expire"

	// Anchor kinds.
	fnRoot    = "root"
	fnRequire = "require"
	fnProvide = "provide"
)
// NewCapCmd builds the "cap" parent command and registers the grant, anchor,
// new, delegate, list and remove subcommands, each constructed with afs.
func NewCapCmd(afs afero.Afero) *cobra.Command {
	capCmd := &cobra.Command{
		Use:   "cap",
		Short: "Manage capabilities",
		Long:  `Manage capabilities for the Device Management Service`,
	}

	// Listed in registration order.
	builders := []func(afero.Afero) *cobra.Command{
		newGrantCmd,
		newAnchorCmd,
		newNewCmd,
		newDelegateCmd,
		newListCmd,
		newRemoveCmd,
	}
	for _, build := range builders {
		capCmd.AddCommand(build(afs))
	}
	return capCmd
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package cap
import (
"encoding/json"
"fmt"
"time"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/lib/did"
"gitlab.com/nunet/device-management-service/lib/ucan"
)
// newDelegateCmd builds the "delegate" subcommand. It delegates the
// capabilities named by --cap from the selected capability context to the
// subject DID given as the only positional argument, and prints the resulting
// tokens to standard output as JSON.
//
// The expiration is taken from --expiry, else computed from --duration, and is
// zero when only --auto-expire is set; it is expressed in Unix nanoseconds.
// --self-sign must be one of "no", "also" or "only".
func newDelegateCmd(afs afero.Afero) *cobra.Command {
	var (
		ctxName    string
		capNames   []string
		topicList  []string
		audienceID string
		expiresOn  time.Time
		validFor   time.Duration
		autoExpire bool
		maxDepth   uint64
		selfSign   string
	)

	run := func(_ *cobra.Command, args []string) error {
		subject := args[0]

		// With --auto-expire alone the expiration stays at zero.
		var expiresAt uint64
		if !expiresOn.IsZero() {
			expiresAt = uint64(expiresOn.UnixNano())
		} else if validFor != 0 {
			expiresAt = uint64(time.Now().Add(validFor).UnixNano())
		} else if !autoExpire {
			return fmt.Errorf("either expiration or duration must be specified")
		}

		subjectDID, err := did.FromString(subject)
		if err != nil {
			return fmt.Errorf("invalid subject DID: %w", err)
		}

		var audienceDID did.DID
		if audienceID != "" {
			if audienceDID, err = did.FromString(audienceID); err != nil {
				return fmt.Errorf("invalid audience DID: %w", err)
			}
		}

		granted := make([]ucan.Capability, 0, len(capNames))
		for _, name := range capNames {
			granted = append(granted, ucan.Capability(name))
		}

		modes := map[string]ucan.SelfSignMode{
			"no":   ucan.SelfSignNo,
			"also": ucan.SelfSignAlso,
			"only": ucan.SelfSignOnly,
		}
		mode, known := modes[selfSign]
		if !known {
			return fmt.Errorf("invalid self-sign option: %s", selfSign)
		}

		var trust did.TrustContext
		if !IsLedgerContext(ctxName) {
			trust, _, err = CreateTrustContextFromKeyStore(afs, ctxName)
			if err != nil {
				return fmt.Errorf("failed to create trust context: %w", err)
			}
		} else {
			wallet, walletErr := did.NewLedgerWalletProvider(0)
			if walletErr != nil {
				return walletErr
			}
			trust = did.NewTrustContextWithProvider(wallet)
			ctxName = LedgerContext(ctxName)
		}

		capCtx, err := LoadCapabilityContext(trust, ctxName)
		if err != nil {
			return fmt.Errorf("failed to load capability context: %w", err)
		}

		tokens, err := capCtx.Delegate(subjectDID, audienceDID, topicList, expiresAt, maxDepth, granted, mode)
		if err != nil {
			return fmt.Errorf("failed to delegate capabilities: %w", err)
		}

		encoded, err := json.Marshal(tokens)
		if err != nil {
			return fmt.Errorf("unable to marshal tokens to json: %w", err)
		}
		fmt.Println(string(encoded))
		return nil
	}

	cmd := &cobra.Command{
		Use:   "delegate <did>",
		Short: "Delegate capabilities",
		Long: `Delegate capabilities to a subject
Capabilities are delegated based on provide anchors. No capabilities are delegated by default, you need to use --cap flag to explicitly specify the capabilities to delegate.
Example:
nunet cap anchor --context user --provide '<token>'
nunet cap delegate --context user --cap /public --duration 1h did:key:<some-key>`,
		Args: cobra.ExactArgs(1),
		RunE: run,
	}

	useFlagContext(cmd, &ctxName)
	useFlagAudience(cmd, &audienceID)
	useFlagCap(cmd, &capNames)
	useFlagTopic(cmd, &topicList)
	useFlagExpiry(cmd, &expiresOn)
	useFlagDuration(cmd, &validFor)
	useFlagAutoExpire(cmd, &autoExpire)
	useFlagDepth(cmd, &maxDepth)
	cmd.Flags().StringVar(&selfSign, fnSelfSign, "no", "Self-sign option: 'no', 'also', or 'only'")

	_ = cmd.MarkFlagRequired(fnContext)
	cmd.MarkFlagsOneRequired(fnExpiry, fnDuration, fnAutoExpire)
	cmd.MarkFlagsMutuallyExclusive(fnExpiry, fnDuration, fnAutoExpire)
	cmd.MarkFlagsMutuallyExclusive(fnSelfSign, fnAutoExpire)
	return cmd
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package cap
import (
"time"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/cmd/utils"
"gitlab.com/nunet/device-management-service/dms"
)
// useFlagContext registers the --context/-c string flag on cmd, bound to
// context and defaulting to dms.UserContextName.
func useFlagContext(cmd *cobra.Command, context *string) {
	flags := cmd.Flags()
	flags.StringVarP(context, fnContext, "c", dms.UserContextName, "specifies capability context")
}
// useFlagAudience registers the --audience/-a string flag on cmd, bound to
// audience and empty by default.
func useFlagAudience(cmd *cobra.Command, audience *string) {
	flags := cmd.Flags()
	flags.StringVarP(audience, fnAudience, "a", "", "audience DID (optional)")
}
// useFlagCap registers the --cap string-slice flag on cmd, bound to caps and
// empty by default.
func useFlagCap(cmd *cobra.Command, caps *[]string) {
	flags := cmd.Flags()
	flags.StringSliceVar(caps, fnCap, []string{}, "capabilities to grant/delegate (can be specified multiple times)")
}
// useFlagTopic registers the --topic/-t string-slice flag on cmd, bound to
// topics and empty by default.
func useFlagTopic(cmd *cobra.Command, topics *[]string) {
	flags := cmd.Flags()
	flags.StringSliceVarP(topics, fnTopic, "t", []string{}, "topics to grant/delegate (can be specified multiple times)")
}
// useFlagExpiry registers the --expiry/-e flag on cmd. Parsing is delegated
// to the flag value produced by utils.NewTimeValue, which writes into expiry.
func useFlagExpiry(cmd *cobra.Command, expiry *time.Time) {
	value := utils.NewTimeValue(expiry)
	cmd.Flags().VarP(value, fnExpiry, "e", "set expiration date (ISO 8601 format)")
}
// useFlagDuration registers the --duration flag on cmd, bound to duration and
// defaulting to zero.
func useFlagDuration(cmd *cobra.Command, duration *time.Duration) {
	flags := cmd.Flags()
	flags.DurationVar(duration, fnDuration, 0, "set duration time (specify unit)")
}
// useFlagAutoExpire registers the --auto-expire boolean flag on cmd, bound to
// autoExpire and false by default.
func useFlagAutoExpire(cmd *cobra.Command, autoExpire *bool) {
	flags := cmd.Flags()
	flags.BoolVar(autoExpire, fnAutoExpire, false, "set auto expiration")
}
// useFlagDepth registers the --depth/-d unsigned integer flag on cmd, bound
// to depth and defaulting to zero.
func useFlagDepth(cmd *cobra.Command, depth *uint64) {
	flags := cmd.Flags()
	flags.Uint64VarP(depth, fnDepth, "d", 0, "delegation depth (optional, default=0)")
}
// useFlagRoot registers the --root string flag on cmd, bound to root and
// empty by default.
func useFlagRoot(cmd *cobra.Command, root *string) {
	flags := cmd.Flags()
	flags.StringVar(root, fnRoot, "", "DID to add as root anchor (represents absolute trust)")
}
// useFlagRequire registers the --require string flag on cmd, bound to require
// and empty by default.
func useFlagRequire(cmd *cobra.Command, require *string) {
	flags := cmd.Flags()
	flags.StringVar(require, fnRequire, "", "JWT to add as require anchor (for input trust)")
}
// useFlagProvide registers the --provide string flag on cmd, bound to provide
// and empty by default.
func useFlagProvide(cmd *cobra.Command, provide *string) {
	flags := cmd.Flags()
	flags.StringVar(provide, fnProvide, "", "JWT to add as provide anchor (for output trust)")
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package cap
import (
"encoding/json"
"fmt"
"time"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/lib/did"
"gitlab.com/nunet/device-management-service/lib/ucan"
)
// newGrantCmd builds the "grant" subcommand. It issues tokens that delegate
// the capabilities named by --cap from the selected capability context to the
// subject DID given as the only positional argument, and prints them to
// standard output as JSON.
//
// Exactly one of --expiry and --duration must be given. The two flags are
// mutually exclusive, matching the delegate command, so that --duration is
// never silently ignored when --expiry is also present.
func newGrantCmd(afs afero.Afero) *cobra.Command {
	var (
		context  string
		caps     []string
		topics   []string
		audience string
		expiry   time.Time
		duration time.Duration
		depth    uint64
	)
	cmd := &cobra.Command{
		Use:   "grant <did>",
		Short: "Grant capabilities",
		Long: `Grant a self-sign token delegating capabilities
It is not necessary to set up an anchor before granting a capability because this operation is self-signed.
Example:
nunet cap grant --context user --cap /public --duration 1h did:key:<some-key>
The above command emits a self-signed token with the specified capabilities delegated from 'user' to the subject's DID.`,
		Args: cobra.ExactArgs(1),
		RunE: func(_ *cobra.Command, args []string) error {
			subject := args[0]

			// The expiration is expressed in Unix nanoseconds: either the
			// absolute --expiry instant or now plus --duration.
			var expirationTime uint64
			switch {
			case !expiry.IsZero():
				expirationTime = uint64(expiry.UnixNano())
			case duration != 0:
				expirationTime = uint64(time.Now().Add(duration).UnixNano())
			default:
				return fmt.Errorf("either expiration or duration must be specified")
			}

			subjectDID, err := did.FromString(subject)
			if err != nil {
				return fmt.Errorf("invalid subject DID: %w", err)
			}

			// The audience is optional; it stays at its zero value when the
			// flag is not set.
			var audienceDID did.DID
			if audience != "" {
				audienceDID, err = did.FromString(audience)
				if err != nil {
					return fmt.Errorf("invalid audience DID: %w", err)
				}
			}

			capabilities := make([]ucan.Capability, len(caps))
			for i, c := range caps {
				capabilities[i] = ucan.Capability(c)
			}

			// Ledger contexts use the wallet provider as trust context and are
			// mapped to their capability context name; all other contexts read
			// their key from the keystore.
			var trustCtx did.TrustContext
			if IsLedgerContext(context) {
				provider, err := did.NewLedgerWalletProvider(0)
				if err != nil {
					return err
				}
				trustCtx = did.NewTrustContextWithProvider(provider)
				context = LedgerContext(context)
			} else {
				trustCtx, _, err = CreateTrustContextFromKeyStore(afs, context)
				if err != nil {
					return fmt.Errorf("failed to create trust context: %w", err)
				}
			}

			capCtx, err := LoadCapabilityContext(trustCtx, context)
			if err != nil {
				return fmt.Errorf("failed to load capability context: %w", err)
			}

			tokens, err := capCtx.Grant(ucan.Delegate, subjectDID, audienceDID, topics, expirationTime, depth, capabilities)
			if err != nil {
				return fmt.Errorf("failed to grant capabilities: %w", err)
			}

			tokensJSON, err := json.Marshal(tokens)
			if err != nil {
				return fmt.Errorf("unable to marshal tokens to json: %w", err)
			}
			fmt.Println(string(tokensJSON))
			return nil
		},
	}

	useFlagContext(cmd, &context)
	useFlagAudience(cmd, &audience)
	useFlagCap(cmd, &caps)
	useFlagTopic(cmd, &topics)
	useFlagExpiry(cmd, &expiry)
	useFlagDuration(cmd, &duration)
	useFlagDepth(cmd, &depth)

	_ = cmd.MarkFlagRequired(fnContext)
	cmd.MarkFlagsOneRequired(fnExpiry, fnDuration)
	// Without this, passing both flags silently discards --duration.
	cmd.MarkFlagsMutuallyExclusive(fnExpiry, fnDuration)
	return cmd
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package cap
import (
"encoding/json"
"fmt"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/lib/did"
)
// newListCmd builds the "list" subcommand. It loads the selected capability
// context and prints its anchors to standard output in three sections: the
// root entries, then the require tokens and the provide tokens, each token as
// one line of JSON.
func newListCmd(afs afero.Afero) *cobra.Command {
	var ctxName string

	run := func(_ *cobra.Command, _ []string) error {
		var trust did.TrustContext
		if !IsLedgerContext(ctxName) {
			var ksErr error
			trust, _, ksErr = CreateTrustContextFromKeyStore(afs, ctxName)
			if ksErr != nil {
				return fmt.Errorf("failed to create trust context: %w", ksErr)
			}
		} else {
			wallet, walletErr := did.NewLedgerWalletProvider(0)
			if walletErr != nil {
				return walletErr
			}
			trust = did.NewTrustContextWithProvider(wallet)
			ctxName = LedgerContext(ctxName)
		}

		capCtx, err := LoadCapabilityContext(trust, ctxName)
		if err != nil {
			return fmt.Errorf("failed to load capability context: %w", err)
		}

		rootAnchors, requireAnchors, provideAnchors := capCtx.ListRoots()

		fmt.Println("roots:")
		for i := range rootAnchors {
			fmt.Printf("\t%s\n", rootAnchors[i])
		}

		fmt.Println("require:")
		for _, token := range requireAnchors.Tokens {
			encoded, encErr := json.Marshal(token)
			if encErr != nil {
				return fmt.Errorf("failed to marshal capability token: %w", encErr)
			}
			fmt.Printf("\t%s\n", string(encoded))
		}

		fmt.Println("provide:")
		for _, token := range provideAnchors.Tokens {
			encoded, encErr := json.Marshal(token)
			if encErr != nil {
				return fmt.Errorf("failed to marshal capability token: %w", encErr)
			}
			fmt.Printf("\t%s\n", string(encoded))
		}
		return nil
	}

	cmd := &cobra.Command{
		Use:   "list",
		Short: "List capability anchors",
		Long: `List all capability anchors in a capability context
It outputs DIDs and capability tokens set for root, provide and require anchors.`,
		RunE: run,
	}
	useFlagContext(cmd, &ctxName)
	return cmd
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package cap
import (
"fmt"
"path/filepath"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/cmd/utils"
"gitlab.com/nunet/device-management-service/dms"
"gitlab.com/nunet/device-management-service/internal/config"
"gitlab.com/nunet/device-management-service/lib/crypto"
"gitlab.com/nunet/device-management-service/lib/did"
"gitlab.com/nunet/device-management-service/lib/ucan"
)
// newNewCmd builds the "new" subcommand, which creates a capability context
// named by its positional argument and saves it under the capability store
// directory as "<name>.cap".
//
// The root DID comes from the ledger wallet provider for ledger contexts and
// from the keystore key's public key otherwise. If the target file already
// exists the user is asked to confirm the overwrite; if it does not, the
// store directory is created with mode 0700.
func newNewCmd(afs afero.Afero) *cobra.Command {
	create := func(cmd *cobra.Command, args []string) error {
		ctxName := dms.UserContextName
		if len(args) != 0 {
			ctxName = args[0]
		}

		var (
			trust  did.TrustContext
			anchor did.DID
		)
		if !IsLedgerContext(ctxName) {
			var (
				key   crypto.PrivKey
				ksErr error
			)
			trust, key, ksErr = CreateTrustContextFromKeyStore(afs, ctxName)
			if ksErr != nil {
				return fmt.Errorf("failed to create trust context: %w", ksErr)
			}
			anchor = did.FromPublicKey(key.GetPublic())
		} else {
			wallet, walletErr := did.NewLedgerWalletProvider(0)
			if walletErr != nil {
				return walletErr
			}
			trust = did.NewTrustContextWithProvider(wallet)
			anchor = wallet.DID()
			ctxName = LedgerContext(ctxName)
		}

		storeDir := filepath.Join(config.GetConfig().General.UserDir, dms.CapstoreDir)
		storeFile := filepath.Join(storeDir, fmt.Sprintf("%s.cap", ctxName))

		present, err := afs.Exists(storeFile)
		if err != nil {
			return fmt.Errorf("unable to check if capability context file exists: %w", err)
		}

		if !present {
			if mkErr := afs.MkdirAll(storeDir, 0o700); mkErr != nil {
				return fmt.Errorf("unable to create capability store directory: %w", mkErr)
			}
		} else {
			proceed, promptErr := utils.PromptYesNo(
				cmd.InOrStdin(),
				cmd.OutOrStdout(),
				fmt.Sprintf(
					"WARNING: A capability context file already exists at %s. Creating a new one will overwrite the existing context. Do you want to proceed?",
					storeFile,
				),
			)
			if promptErr != nil {
				return fmt.Errorf("failed to get user confirmation: %w", promptErr)
			}
			if !proceed {
				return fmt.Errorf("operation cancelled by user")
			}
		}

		capCtx, err := ucan.NewCapabilityContext(trust, anchor, nil, ucan.TokenList{}, ucan.TokenList{})
		if err != nil {
			return fmt.Errorf("unable to create capability context: %w", err)
		}

		if saveErr := SaveCapabilityContext(capCtx, ctxName); saveErr != nil {
			return fmt.Errorf("save capability context: %w", saveErr)
		}
		return nil
	}

	return &cobra.Command{
		Use:   "new <name>",
		Short: "Create a new capability context",
		Long: `Create a new persistent capability context
Example:
nunet cap new user
nunet cap new ledger:user # if using ledger`,
		Args: cobra.ExactArgs(1),
		RunE: create,
	}
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package cap
import (
"encoding/json"
"fmt"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/lib/did"
"gitlab.com/nunet/device-management-service/lib/ucan"
)
// newRemoveCmd builds the "remove" subcommand, which removes a single root,
// require or provide anchor from a stored capability context and saves the
// result.
//
// The --root value is parsed as a DID; --require and --provide values are
// decoded as one JSON token each.
func newRemoveCmd(afs afero.Afero) *cobra.Command {
	const (
		fnProvide = "provide"
		fnRoot    = "root"
		fnRequire = "require"
	)

	var (
		ctxName    string
		rootArg    string
		provideArg string
		requireArg string
	)

	run := func(_ *cobra.Command, _ []string) error {
		var trust did.TrustContext
		if !IsLedgerContext(ctxName) {
			var ksErr error
			trust, _, ksErr = CreateTrustContextFromKeyStore(afs, ctxName)
			if ksErr != nil {
				return fmt.Errorf("failed to create trust context: %w", ksErr)
			}
		} else {
			wallet, walletErr := did.NewLedgerWalletProvider(0)
			if walletErr != nil {
				return walletErr
			}
			trust = did.NewTrustContextWithProvider(wallet)
			ctxName = LedgerContext(ctxName)
		}

		capCtx, err := LoadCapabilityContext(trust, ctxName)
		if err != nil {
			return fmt.Errorf("failed to load capability context: %w", err)
		}

		// The first non-empty flag wins, checked in the order root, require,
		// provide.
		switch {
		case rootArg != "":
			anchorDID, parseErr := did.FromString(rootArg)
			if parseErr != nil {
				return fmt.Errorf("invalid root DID: %w", parseErr)
			}
			capCtx.RemoveRoots([]did.DID{anchorDID}, ucan.TokenList{}, ucan.TokenList{})
		case requireArg != "":
			var required ucan.Token
			if decodeErr := json.Unmarshal([]byte(requireArg), &required); decodeErr != nil {
				return fmt.Errorf("unmarshal tokens: %w", decodeErr)
			}
			requireList := ucan.TokenList{Tokens: []*ucan.Token{&required}}
			capCtx.RemoveRoots(nil, requireList, ucan.TokenList{})
		case provideArg != "":
			var provided ucan.Token
			if decodeErr := json.Unmarshal([]byte(provideArg), &provided); decodeErr != nil {
				return fmt.Errorf("unmarshal tokens: %w", decodeErr)
			}
			provideList := ucan.TokenList{Tokens: []*ucan.Token{&provided}}
			capCtx.RemoveRoots(nil, ucan.TokenList{}, provideList)
		default:
			return fmt.Errorf("one of --provide, --root, or --require must be specified")
		}

		if saveErr := SaveCapabilityContext(capCtx, ctxName); saveErr != nil {
			return fmt.Errorf("save capability context: %w", saveErr)
		}
		return nil
	}

	cmd := &cobra.Command{
		Use:   "remove",
		Short: "Remove capability anchors",
		Long: `Remove capability anchors in a capability context
One capability anchor must be specified at a time.
Example:
nunet cap remove --context user --root did:key:abcd1234
nunet cap remove --context user --require '<the-token>'`,
		RunE: run,
	}

	useFlagContext(cmd, &ctxName)
	useFlagRoot(cmd, &rootArg)
	useFlagRequire(cmd, &requireArg)
	useFlagProvide(cmd, &provideArg)

	_ = cmd.MarkFlagRequired(fnContext)
	cmd.MarkFlagsOneRequired(fnProvide, fnRoot, fnRequire)
	cmd.MarkFlagsMutuallyExclusive(fnProvide, fnRoot, fnRequire)
	return cmd
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package cap
import (
"fmt"
"os"
"path/filepath"
"strings"
"time"
"github.com/spf13/afero"
"gitlab.com/nunet/device-management-service/cmd/utils"
"gitlab.com/nunet/device-management-service/dms"
"gitlab.com/nunet/device-management-service/internal/config"
"gitlab.com/nunet/device-management-service/lib/crypto"
"gitlab.com/nunet/device-management-service/lib/crypto/keystore"
"gitlab.com/nunet/device-management-service/lib/did"
"gitlab.com/nunet/device-management-service/lib/ucan"
)
// ledger is the marker that selects the ledger wallet, and also the context
// name returned when no explicit name follows the marker.
const ledger = "ledger"

// IsLedgerContext reports whether context selects the ledger wallet: it must
// be exactly "ledger" or have the form "ledger:<name>".
//
// A bare prefix match is deliberately avoided, so that an ordinary context
// whose name merely begins with "ledger" (for example "ledgerbook") is not
// treated as a ledger context.
func IsLedgerContext(context string) bool {
	return context == ledger || strings.HasPrefix(context, ledger+":")
}

// LedgerContext returns the context name carried by a ledger context string.
// For "ledger:<name>" it returns <name>. For any other shape, including a
// bare "ledger", an empty name ("ledger:") or more than one colon, it returns
// "ledger".
func LedgerContext(context string) string {
	parts := strings.Split(context, ":")
	// An empty name would produce a nameless ".cap" file, so fall back to the
	// default instead.
	if len(parts) == 2 && parts[1] != "" {
		return parts[1]
	}
	return ledger
}
// CreateTrustContextFromKeyStore opens the keystore under the configured user
// directory, fetches the key stored under contextKey and returns a trust
// context built from it together with the private key itself.
//
// The passphrase is read from the DMS_PASSPHRASE environment variable; when
// that is empty it is obtained through utils.PromptForPassphrase.
func CreateTrustContextFromKeyStore(afs afero.Afero, contextKey string) (did.TrustContext, crypto.PrivKey, error) {
	ksDir := filepath.Join(config.GetConfig().General.UserDir, dms.KeystoreDir)
	store, err := keystore.New(afs.Fs, ksDir)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to open keystore: %w", err)
	}

	secret := os.Getenv("DMS_PASSPHRASE")
	if len(secret) == 0 {
		prompted, promptErr := utils.PromptForPassphrase(false)
		if promptErr != nil {
			return nil, nil, fmt.Errorf("failed to get passphrase: %w", promptErr)
		}
		secret = prompted
	}

	stored, err := store.Get(contextKey, secret)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to get key from keystore: %w", err)
	}

	privKey, err := stored.PrivKey()
	if err != nil {
		return nil, nil, fmt.Errorf("unable to convert key from keystore to private key: %w", err)
	}

	trust, err := did.NewTrustContextWithPrivateKey(privKey)
	if err != nil {
		return nil, nil, fmt.Errorf("unable to create trust context: %w", err)
	}
	return trust, privKey, nil
}
// LoadCapabilityContext opens "<name>.cap" in the capability store directory
// under the configured user directory and loads a capability context from it,
// using trustCtx as the trust context.
func LoadCapabilityContext(trustCtx did.TrustContext, name string) (ucan.CapabilityContext, error) {
	storeDir := filepath.Join(config.GetConfig().General.UserDir, dms.CapstoreDir)
	storePath := filepath.Join(storeDir, fmt.Sprintf("%s.cap", name))

	file, openErr := os.Open(storePath)
	if openErr != nil {
		return nil, fmt.Errorf("unable to open capability context file: %w", openErr)
	}
	defer file.Close()

	loaded, loadErr := ucan.LoadCapabilityContext(trustCtx, file)
	if loadErr != nil {
		return nil, fmt.Errorf("unable to load capability context: %w", loadErr)
	}
	return loaded, nil
}
// SaveCapabilityContext writes capCtx to "<name>.cap" in the capability store
// directory under the configured user directory.
//
// If a file of that name already exists it is first renamed to
// "<name>.cap.<unix-seconds>" as a backup. The new file is then created and
// written. An error from closing the file is reported, because a failed close
// on a freshly written file can mean the data never reached disk.
func SaveCapabilityContext(capCtx ucan.CapabilityContext, name string) error {
	capStoreDir := filepath.Join(config.GetConfig().General.UserDir, dms.CapstoreDir)
	capCtxFile := filepath.Join(capStoreDir, fmt.Sprintf("%s.cap", name))
	capCtxBackup := filepath.Join(capStoreDir, fmt.Sprintf("%s.cap.%d", name, time.Now().Unix()))

	// First take a backup: move the current context out of the way.
	if _, err := os.Stat(capCtxFile); err == nil {
		if err := os.Rename(capCtxFile, capCtxBackup); err != nil {
			return fmt.Errorf("error backing up current capability context: %w", err)
		}
	}

	// Now open for writing.
	f, err := os.Create(capCtxFile)
	if err != nil {
		return fmt.Errorf("error creating new capability context file: %w", err)
	}

	if err := ucan.SaveCapabilityContext(capCtx, f); err != nil {
		// The save error is the one worth reporting; closing is best effort.
		_ = f.Close()
		return fmt.Errorf("error saving capability context: %w", err)
	}

	if err := f.Close(); err != nil {
		return fmt.Errorf("error closing capability context file: %w", err)
	}
	return nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package cmd
import (
"encoding/json"
"fmt"
"os"
"os/exec"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/internal/config"
)
// newConfigCmd builds the "config" parent command and registers its
// "get", "set" and "edit" subcommands. The filesystem fs is handed to the
// "set" subcommand; a nil fs is reported through cobra.CheckErr.
func newConfigCmd(fs afero.Fs) *cobra.Command {
	if fs == nil {
		cobra.CheckErr("Fs is nil")
	}

	configCmd := &cobra.Command{
		Use:   "config",
		Short: "Manage configuration file",
		Long: `Utility to manage user's configuration file via command-line
Search for the configuration file is done in the following locations and order:
1. "." (current directory)
2. "$HOME/.nunet"
3. "/etc/nunet"`,
	}

	// Subcommands are registered in the same order they are documented.
	configCmd.AddCommand(
		newConfigGetCmd(),
		newConfigSetCmd(fs),
		newConfigEditCmd(),
	)
	return configCmd
}
// newConfigGetCmd builds the "config get" subcommand.
//
// Without arguments it prints the whole configuration returned by
// config.GetConfig as indented JSON. With one argument it looks the key up
// through config.Get and prints that single value as indented JSON.
func newConfigGetCmd() *cobra.Command {
	run := func(cmd *cobra.Command, args []string) error {
		// A key was given: print only that value.
		if len(args) != 0 {
			value, err := config.Get(args[0])
			if err != nil {
				return fmt.Errorf("could not get key's value: %w", err)
			}

			pretty, err := json.MarshalIndent(value, "", " ")
			if err != nil {
				return fmt.Errorf("failed to indent JSON: %w", err)
			}

			cmd.Println(string(pretty))
			return nil
		}

		// No key: dump the complete configuration.
		whole, err := json.MarshalIndent(config.GetConfig(), "", " ")
		if err != nil {
			return fmt.Errorf("failed to indent config JSON: %w", err)
		}

		cmd.Println(string(whole))
		return nil
	}

	return &cobra.Command{
		Use:   "get <key>",
		Short: "Display configuration",
		Long: `Display the value for a configuration key
It reads the value from configuration file, otherwise it return default values
Example:
nunet config get rest.port`,
		Args: cobra.MaximumNArgs(1),
		RunE: run,
	}
}
// newConfigSetCmd builds the "config set" subcommand, which takes exactly a
// key and a value and applies them through config.Set on the filesystem fs.
// Before writing, it tells the user whether an existing configuration file
// is being updated or a new one is about to be created.
func newConfigSetCmd(fs afero.Fs) *cobra.Command {
	run := func(cmd *cobra.Command, args []string) error {
		found, err := config.FileExists(fs)
		if err != nil {
			return fmt.Errorf("failed to check if config file exists: %w", err)
		}

		if found {
			cmd.Println("Updating existing config file...")
		} else {
			cmd.Println("Config file did not exist. Creating new file...")
		}

		key, value := args[0], args[1]
		if err := config.Set(fs, key, value); err != nil {
			return fmt.Errorf("failed to set config: %w", err)
		}

		cmd.Println("Applied changes.")
		return nil
	}

	return &cobra.Command{
		Use:   "set <key> <value>",
		Short: "Update configuration",
		Long: `Set value for a configuration key
It creates a configuration file if does not exists, otherwise it updates the existing file
Example:
nunet config set rest.port 4444`,
		Args: cobra.ExactArgs(2),
		RunE: run,
	}
}
// newConfigEditCmd builds the "config edit" subcommand. It launches the
// program named by the EDITOR environment variable on the path returned by
// config.GetPath, wiring the command's stdin/stdout/stderr to the editor
// process. It fails when EDITOR is empty.
func newConfigEditCmd() *cobra.Command {
	run := func(cmd *cobra.Command, _ []string) error {
		editorBin := os.Getenv("EDITOR")
		if editorBin == "" {
			return fmt.Errorf("$EDITOR not set")
		}

		cmd.Printf("Text editor: %s\n", editorBin)
		cmd.Printf("Config path: %s\n", config.GetPath())

		// do we need better sanitization?
		// do we check if editor is valid?
		editorProc := exec.Command(editorBin, config.GetPath())
		editorProc.Stdin = cmd.InOrStdin()
		editorProc.Stdout = cmd.OutOrStdout()
		editorProc.Stderr = cmd.OutOrStderr()

		return editorProc.Run()
	}

	return &cobra.Command{
		Use:   "edit",
		Short: "Edit configuration",
		Long: `Open configuration file with default text editor
This command search the configuration file and open it with the default text editor
It reads the $EDITOR environment variable and it fails if it's not set`,
		Args: cobra.NoArgs,
		RunE: run,
	}
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package cmd
import (
"context"
"fmt"
"os"
"github.com/docker/docker/api/types/container"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/dms/hardware"
"gitlab.com/nunet/device-management-service/dms/hardware/gpu"
"gitlab.com/nunet/device-management-service/executor/docker"
"gitlab.com/nunet/device-management-service/types"
)
// newGPUCommand builds the "gpu" parent command and attaches the "list" and
// "test" subcommands to it.
func newGPUCommand() *cobra.Command {
	root := &cobra.Command{
		Use:   "gpu <operation>",
		Short: "Manage GPU resources",
		Long: `Available operations:
- list: List all available GPUs
- test: Test GPU deployment by running a docker container with GPU resources
`,
	}

	// Register the operations advertised in the long description.
	root.AddCommand(
		newGPUListCommand(),
		newGPUTestCommand(),
	)
	return root
}
// newGPUListCommand builds the "gpu list" subcommand. It queries the GPUs and
// their current usage, checks that both lists are non-empty and of equal
// length, and prints one line per GPU to standard output.
func newGPUListCommand() *cobra.Command {
	listGPUs := func(_ *cobra.Command, _ []string) error {
		devices, err := gpu.GetGPUs()
		if err != nil {
			return fmt.Errorf("error getting GPUs: %w", err)
		}

		inUse, err := gpu.GetGPUUsage()
		if err != nil {
			return fmt.Errorf("error getting GPU usage: %w", err)
		}

		switch {
		case len(devices) == 0:
			return fmt.Errorf("no gpus found")
		case len(devices) != len(inUse):
			// Usage entries are matched to devices by position below.
			return fmt.Errorf("GPU and GPU usage counts do not match. This is a bug")
		}

		fmt.Println("GPU Details:")
		for idx := range devices {
			dev := devices[idx]
			totalGB := types.ConvertBytesToGB(dev.VRAM)
			usedGB := types.ConvertBytesToGB(inUse[idx].VRAM)
			fmt.Printf("Model: %s, Total VRAM: %.2f GB, Used VRAM: %.2f GB, Vendor: %s, PCI Address: %s, Index: %d\n",
				dev.Model, totalGB, usedGB, dev.Vendor, dev.PCIAddress, dev.Index)
		}
		return nil
	}

	return &cobra.Command{
		Use:   "list",
		Short: "List all available GPUs",
		RunE:  listGPUs,
	}
}
// newGPUTestCommand builds the "gpu test" subcommand.
//
// The command:
//  1. reads the machine resources and picks the GPU with the most free VRAM,
//  2. for NVIDIA GPUs, checks that /usr/bin/nvidia-container-toolkit exists,
//  3. creates and starts an "ubuntu:20.04" container named "nunet-gpu-test"
//     with a vendor-specific host configuration (device request for NVIDIA,
//     /dev/kfd and /dev/dri for AMD, /dev/dri for Intel),
//  4. waits for the container to finish and copies its output to stdout.
//
// Every failure is returned as a wrapped error so callers can inspect the
// underlying cause with errors.Is / errors.As.
func newGPUTestCommand() *cobra.Command {
	return &cobra.Command{
		Use:   "test",
		Short: "Test GPU deployment by running a Docker container with GPU resources",
		RunE: func(_ *cobra.Command, _ []string) error {
			ctx := context.Background()

			hardwareManager := hardware.NewHardwareManager()
			machineResources, err := hardwareManager.GetMachineResources()
			if err != nil {
				return fmt.Errorf("getting machine resources: %w", err)
			}

			if len(machineResources.GPUs) == 0 {
				return fmt.Errorf("no GPUs detected on the host")
			}

			maxFreeVRAMGpu, err := machineResources.GPUs.MaxFreeVRAMGPU()
			if err != nil {
				return fmt.Errorf("getting GPU with highest free VRAM: %w", err)
			}

			fmt.Printf("Selected Vendor: %s, Device: %+v\n", maxFreeVRAMGpu.Vendor, maxFreeVRAMGpu)

			if maxFreeVRAMGpu.Vendor == types.GPUVendorNvidia {
				// Check if NVIDIA container toolkit is installed
				// We specifically look for the nvidia-container-toolkit executable because:
				// 1. It's the name of the main package installed via apt (nvidia-container-toolkit)
				// 2. It's the most reliable indicator of a proper toolkit installation
				// 3. Checking for this single file reduces the risk of false positives
				_, err = os.Stat("/usr/bin/nvidia-container-toolkit")
				if os.IsNotExist(err) {
					return fmt.Errorf("nvidia container toolkit is not installed. Please install it before running this command")
				}
			}

			imageName := "ubuntu:20.04"

			client, err := docker.NewDockerClient()
			if err != nil {
				return fmt.Errorf("creating Docker executor: %w", err)
			}

			if !client.IsInstalled(ctx) {
				return fmt.Errorf("docker is not installed or running. Cannot run GPU deployment test")
			}

			fmt.Printf("Creating the docker container for the image: %s\n", imageName)

			containerConfig := &container.Config{
				Image:        imageName,
				User:         "root",
				Tty:          true,         // Enable TTY
				AttachStdout: true,         // Attach stdout
				AttachStderr: true,         // Attach stderr
				Entrypoint:   []string{""}, // Set entrypoint to run shell commands
				Cmd: []string{
					// This will show both the integrated and discrete GPUs
					"sh", "-c",
					"apt-get update && apt-get install -y pciutils && lspci | grep 'VGA compatible controller'",
				},
			}

			// The host configuration exposes the GPU to the container in the
			// way each vendor requires.
			var hostConfig *container.HostConfig
			switch maxFreeVRAMGpu.Vendor {
			case types.GPUVendorNvidia:
				hostConfig = &container.HostConfig{
					AutoRemove: true,
					Resources: container.Resources{
						DeviceRequests: []container.DeviceRequest{
							{
								Driver:       "nvidia",
								Count:        -1,
								Capabilities: [][]string{{"gpu"}},
							},
						},
					},
				}
			case types.GPUVendorAMDATI:
				hostConfig = &container.HostConfig{
					AutoRemove: true,
					Binds: []string{
						"/dev/kfd:/dev/kfd",
						"/dev/dri:/dev/dri",
					},
					Resources: container.Resources{
						Devices: []container.DeviceMapping{
							{
								PathOnHost:        "/dev/kfd",
								PathInContainer:   "/dev/kfd",
								CgroupPermissions: "rwm",
							},
							{
								PathOnHost:        "/dev/dri",
								PathInContainer:   "/dev/dri",
								CgroupPermissions: "rwm",
							},
						},
					},
					GroupAdd: []string{"video"},
				}
			case types.GPUVendorIntel:
				hostConfig = &container.HostConfig{
					AutoRemove: true,
					Binds: []string{
						"/dev/dri:/dev/dri",
					},
					Resources: container.Resources{
						Devices: []container.DeviceMapping{
							{
								PathOnHost:        "/dev/dri",
								PathInContainer:   "/dev/dri",
								CgroupPermissions: "rwm",
							},
						},
					},
				}
			default:
				return fmt.Errorf("unknown GPU vendor: %s", maxFreeVRAMGpu.Vendor)
			}

			// NOTE(review): the error text below says "pulling Docker image"
			// although the failing call is CreateContainer; presumably the
			// trailing `true` asks the client to pull the image first - confirm.
			containerID, err := client.CreateContainer(ctx,
				containerConfig,
				hostConfig,
				nil,
				nil,
				"nunet-gpu-test",
				true,
			)
			if err != nil {
				return fmt.Errorf("pulling Docker image: %w", err)
			}

			fmt.Println("Container created with ID: ", containerID)

			if err := client.StartContainer(ctx, "nunet-gpu-test"); err != nil {
				return fmt.Errorf("starting docker container: %w", err)
			}

			// Wait for the container to finish execution
			statusCh, errCh := client.WaitContainer(ctx, containerID)
			select {
			case err := <-errCh:
				// A wait error is only reported; the output is still fetched below.
				if err != nil {
					fmt.Printf("Container exited with error: %v\n", err)
				}
			case <-statusCh:
				fmt.Println("Container execution completed.")
			}

			reader, err := client.GetOutputStream(ctx, containerID, "", true)
			if err != nil {
				return fmt.Errorf("getting output stream: %w", err)
			}

			// Print the output stream
			if _, err := os.Stdout.ReadFrom(reader); err != nil {
				return fmt.Errorf("reading output stream: %w", err)
			}

			return nil
		},
	}
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package cmd
import (
"fmt"
"os"
"path/filepath"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/cmd/utils"
"gitlab.com/nunet/device-management-service/dms"
"gitlab.com/nunet/device-management-service/internal/config"
"gitlab.com/nunet/device-management-service/lib/crypto/keystore"
"gitlab.com/nunet/device-management-service/lib/did"
dmsUtils "gitlab.com/nunet/device-management-service/utils"
)
// newKeyCmd builds the "key" parent command and attaches the "new" and "did"
// subcommands, both of which operate on the filesystem fs.
func newKeyCmd(fs afero.Afero) *cobra.Command {
	keyCmd := &cobra.Command{
		Use:   "key",
		Short: "Manage cryptographic keys",
		Long: `Manage cryptographic keys for the Device Management Service (DMS).
This command provides subcommands for creating new keys and retrieving Decentralized Identifiers (DIDs) associated with existing keys.`,
	}

	keyCmd.AddCommand(
		newKeyNewCmd(fs),
		newKeyDIDCmd(fs),
	)
	return keyCmd
}
// newKeyNewCmd builds the "key new <name>" subcommand.
//
// The command opens the keystore located under the configured user directory,
// asks for confirmation when a key called <name> already exists, obtains a
// passphrase (from the DMS_PASSPHRASE environment variable, or by prompting
// with confirmation when that variable is empty), generates and stores a new
// private key, and prints the DID derived from its public key.
//
// All failures are returned as wrapped errors, including the cause of a
// failed key generation.
func newKeyNewCmd(fs afero.Afero) *cobra.Command {
	return &cobra.Command{
		Use:   "new <name>",
		Short: "Generate a key pair",
		Long: `Generate a key pair and save the private key into the user's local keystore.
This command creates a new cryptographic key pair, stores the private key securely, and displays the associated Decentralized Identifier (DID). If a key with the specified name already exists, the user will be prompted to confirm before overwriting it.
Example:
nunet key new user`,
		Args: cobra.ExactArgs(1),
		RunE: func(cmd *cobra.Command, args []string) error {
			keyStoreDir := filepath.Join(config.GetConfig().General.UserDir, dms.KeystoreDir)
			ks, err := keystore.New(fs, keyStoreDir)
			if err != nil {
				return fmt.Errorf("failed to create keystore: %w", err)
			}

			keyID := dms.UserContextName
			if len(args) > 0 {
				keyID = args[0]
			}

			keys, err := ks.ListKeys()
			if err != nil {
				return fmt.Errorf("failed to list keys: %w", err)
			}

			// Never overwrite an existing key without explicit consent.
			if dmsUtils.SliceContains(keys, keyID) {
				confirmed, err := utils.PromptYesNo(
					cmd.InOrStdin(),
					cmd.OutOrStdout(),
					fmt.Sprintf("A key with name '%s' already exists. Do you want to overwrite it with a new one?", keyID),
				)
				if err != nil {
					return fmt.Errorf("failed to get user confirmation: %w", err)
				}
				if !confirmed {
					return fmt.Errorf("operation cancelled by user")
				}
			}

			passphrase := os.Getenv("DMS_PASSPHRASE")
			if passphrase == "" {
				passphrase, err = utils.PromptForPassphrase(true)
				if err != nil {
					return fmt.Errorf("failed to get passphrase: %w", err)
				}
			}

			priv, err := dms.GenerateAndStorePrivKey(ks, passphrase, keyID)
			if err != nil {
				// Keep the cause in the chain; it was previously discarded.
				return fmt.Errorf("failed to generate and store new private key: %w", err)
			}

			// Named keyDID so the local does not shadow the did package.
			keyDID := did.FromPublicKey(priv.GetPublic())
			fmt.Println(keyDID)
			return nil
		},
	}
}
// newKeyDIDCmd builds the "key did <name>" subcommand.
//
// When <name> is "ledger" the DID comes from a ledger wallet provider created
// for account index 0. Otherwise the key is read from the keystore under the
// configured user directory, using the passphrase from the DMS_PASSPHRASE
// environment variable or, when that is empty, an interactive prompt; the DID
// derived from the key's public part is printed.
//
// Keystore failures are returned as wrapped errors so the cause stays
// inspectable with errors.Is / errors.As.
func newKeyDIDCmd(fs afero.Afero) *cobra.Command {
	return &cobra.Command{
		Use:   "did <name>",
		Short: "Get a key's DID",
		Long: `Get the DID (Decentralized Identifier) for a specified key.
This command retrieves the DID associated with either a key stored in the local keystore or a hardware ledger.
For keys in the local keystore, the user will be prompted for the passphrase to decrypt the key. To avoid passphrase prompting, it's possible to set a DMS_PASSPHRASE environment variable. For the ledger option, it uses the first account (index 0) on the connected hardware wallet.
Example:
nunet key did user
nunet key did ledger # if using ledger`,
		Args: cobra.ExactArgs(1),
		RunE: func(_ *cobra.Command, args []string) error {
			keyName := args[0]

			// "ledger" is a reserved name that bypasses the local keystore.
			if keyName == "ledger" {
				provider, err := did.NewLedgerWalletProvider(0)
				if err != nil {
					return err
				}
				fmt.Println(provider.DID())
				return nil
			}

			keyStoreDir := filepath.Join(config.GetConfig().General.UserDir, dms.KeystoreDir)
			ks, err := keystore.New(fs, keyStoreDir)
			if err != nil {
				return fmt.Errorf("failed to open keystore: %w", err)
			}

			passphrase := os.Getenv("DMS_PASSPHRASE")
			if passphrase == "" {
				passphrase, err = utils.PromptForPassphrase(false)
				if err != nil {
					return fmt.Errorf("failed to get passphrase: %w", err)
				}
			}

			key, err := ks.Get(keyName, passphrase)
			if err != nil {
				return fmt.Errorf("failed to get key: %w", err)
			}

			priv, err := key.PrivKey()
			if err != nil {
				return fmt.Errorf("unable to convert key from keystore to private key: %w", err)
			}

			// Named keyDID so the local does not shadow the did package.
			keyDID := did.FromPublicKey(priv.GetPublic())
			fmt.Println(keyDID)
			return nil
		},
	}
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package cmd
import (
"github.com/spf13/afero"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/cmd/actor"
"gitlab.com/nunet/device-management-service/cmd/cap"
"gitlab.com/nunet/device-management-service/utils"
)
// newRootCmd assembles the top-level "nunet" command and all of its
// subcommands. Invoked without a subcommand it prints the help text.
// client is passed to the actor commands; afs is the filesystem handed to
// the key, cap, actor and config commands.
func newRootCmd(client *utils.HTTPClient, afs afero.Afero) *cobra.Command {
	showHelp := func(cmd *cobra.Command, _ []string) {
		_ = cmd.Help()
	}

	root := &cobra.Command{
		Use:   "nunet",
		Short: "NuNet Device Management Service",
		Long:  `The Device Management Service (DMS) Command Line Interface (CLI)`,
		CompletionOptions: cobra.CompletionOptions{
			DisableDefaultCmd: false,
			HiddenDefaultCmd:  true,
		},
		SilenceErrors: true,
		SilenceUsage:  true,
		Run:           showHelp,
	}

	// Constructors run in the same order as before (arguments are evaluated
	// left to right), then everything is attached in one call.
	root.AddCommand(
		newRunCmd(),
		newKeyCmd(afs),
		cap.NewCapCmd(afs),
		actor.NewActorCmd(client, afs),
		newConfigCmd(afs.Fs),
		newAutoCompleteCmd(),
		newVersionCmd(),
		newTapCommand(),
		newGPUCommand(),
	)
	return root
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
//go:build linux
package cmd
import (
"fmt"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/internal/config"
"gitlab.com/nunet/device-management-service/utils"
)
// Execute is a wrapper for cobra.Command method with same name
// It makes use of cobra.CheckErr to facilitate error handling
func Execute() {
afs := afero.Afero{Fs: afero.NewOsFs()}
client := utils.NewHTTPClient(
fmt.Sprintf("http://%s:%d",
config.GetConfig().Rest.Addr,
config.GetConfig().Rest.Port),
"/api/v1",
)
cobra.CheckErr(newRootCmd(client, afs).Execute())
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package cmd
import (
"fmt"
"net/http"
_ "net/http/pprof" //#nosec
"os"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/cmd/utils"
"gitlab.com/nunet/device-management-service/dms"
"gitlab.com/nunet/device-management-service/internal/config"
)
// newRunCmd builds the "run" command that starts the Device Management
// Service.
//
// The passphrase is taken from the DMS_PASSPHRASE environment variable or,
// when that is empty, read interactively; an empty passphrase is rejected.
// When profiling is enabled (configuration default, overridable with
// --pprof, --pprof-addr and --pprof-port) an HTTP server exposing the
// handlers registered on http.DefaultServeMux is started in the background.
// Finally dms.Run is called with the passphrase and the capability context
// selected with --context/-c.
func newRunCmd() *cobra.Command {
	// Named capContext to avoid confusion with the standard context package.
	var capContext string

	pprof := config.GetConfig().Profiler.Enabled
	pprofAddr := config.GetConfig().Profiler.Addr
	pprofPort := config.GetConfig().Profiler.Port

	cmd := &cobra.Command{
		Use:   "run",
		Short: "Start the Device Management Service",
		Long: `Start the Device Management Service
The Device Management Service (DMS) is a system application for running a node in the NuNet decentralized network of compute providers.
By default, DMS listens on port 9999. For more information on configuration, see:
nunet config --help
Or manually create a dms_config.json file and refer to the README for available settings.`,
		RunE: func(_ *cobra.Command, _ []string) error {
			passphrase := os.Getenv("DMS_PASSPHRASE")

			var err error
			if passphrase == "" {
				fmt.Print("Please enter the DMS passphrase. This will be used to encrypt/decrypt the keystore containing necessary secrets for DMS:\n")
				passphrase, err = utils.PromptForPassphrase(false)
				if err != nil {
					return fmt.Errorf("error reading passphrase from stdin: %w", err)
				}

				// TODO: validate passphrase (minimum x characters)
				if passphrase == "" {
					return fmt.Errorf("invalid passphrase")
				}
			}

			if pprof {
				// Take the mux holding the pprof handlers and install a fresh
				// default mux BEFORE spawning the goroutine: doing the swap
				// inside the goroutine raced with whatever dms.Run registers
				// on, or serves from, http.DefaultServeMux.
				pprofMux := http.DefaultServeMux
				http.DefaultServeMux = http.NewServeMux()
				profilerAddr := fmt.Sprintf("%s:%d", pprofAddr, pprofPort)

				go func() {
					log.Infof("Starting profiler on %s\n", profilerAddr)
					// #nosec
					if err := http.ListenAndServe(profilerAddr, pprofMux); err != nil {
						log.Errorf("Error starting profiler: %v\n", err)
					}
				}()
			}

			return dms.Run(passphrase, capContext)
		},
	}

	cmd.Flags().BoolVar(&pprof, "pprof", pprof, "enable profiling")
	cmd.Flags().StringVar(&pprofAddr, "pprof-addr", pprofAddr, "profiler listen address")
	cmd.Flags().Uint32Var(&pprofPort, "pprof-port", pprofPort, "profiler listen port")
	cmd.Flags().StringVarP(&capContext, "context", "c", dms.DefaultContextName, "specify a capability context")

	return cmd
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package cmd
import (
"fmt"
"os"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/lib/sys"
)
// newTapCommand creates the Cobra command to set up a TAP interface.
//
// The command takes exactly three arguments: the existing main interface,
// the name of the TAP interface to create, and the CIDR to assign to it.
// It requires either the capabilities checked by sys.RequiredCaps or uid 0.
// Steps: verify the main interface exists and the TAP name is free, create
// the TAP interface, assign the address, bring it up, add the iptables
// rules, and finally verify that IP forwarding is enabled.
func newTapCommand() *cobra.Command {
	return &cobra.Command{
		Use:   "tap [main_interface] [vm_interface] [CIDR]",
		Short: "Create and configure a TAP interface",
		Long: `Create a TAP interface using the provided interface name and configure IP forwarding and iptables rules.
Example:
nunet tap eth0 tap0 172.16.0.1/24
Note: The command requires root privileges or CAP_NET_ADMIN=ep capability.
`,
		Args: cobra.ExactArgs(3),
		RunE: func(cmd *cobra.Command, args []string) error {
			// check if running with privilege
			if err := sys.RequiredCaps(); err != nil && os.Getuid() != 0 {
				return fmt.Errorf("this command requires the CAP_NET_ADMIN=ep capability or run as root")
			}

			mainInterface := args[0]
			vmInterface := args[1]
			cidr := args[2]

			// The main interface must already exist...
			_, err := sys.GetNetInterfaceByName(mainInterface)
			if err != nil {
				return fmt.Errorf("couldn't read main interface %q: %w", mainInterface, err)
			}

			// ...while the TAP interface must not exist yet.
			_, err = sys.GetNetInterfaceByName(vmInterface)
			if err == nil {
				return fmt.Errorf("interface %q already exists", vmInterface)
			}

			// Create the TAP interface
			iface, err := sys.NewTunTapInterface(vmInterface, sys.NetTapMode, true)
			if err != nil {
				return err
			}
			fmt.Fprintf(cmd.OutOrStdout(), "TAP interface %s created\n", iface.Iface.Name())

			// Assign IP address to the TAP interface
			err = iface.SetAddress(cidr)
			if err != nil {
				return err
			}
			fmt.Fprintf(cmd.OutOrStdout(), "IP address %s assigned to TAP interface %s\n", cidr, iface.Iface.Name())

			// Bring the TAP interface up
			err = iface.Up()
			if err != nil {
				return err
			}

			// Add iptables rules for connection tracking
			err = sys.AddRelEstRule("FORWARD")
			if err != nil {
				return err
			}

			// Add iptables rules to allow forwarding between interfaces
			err = sys.AddForwardIntRule(vmInterface, mainInterface)
			if err != nil {
				return err
			}

			// Check IP Forwarding kernel parameter. A failed check is treated
			// like "disabled". The hint names the real sysctl key
			// (net.ipv4.ip_forward).
			if enabled, err := sys.ForwardingEnabled(); !enabled || err != nil {
				return fmt.Errorf("IP forwarding looks to be disabled. Please enable it using 'sysctl -w net.ipv4.ip_forward=1'")
			}

			fmt.Fprintf(cmd.OutOrStdout(), "TAP interface %s created and configured successfully\n", vmInterface)
			return nil
		},
	}
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package utils
import (
"fmt"
"strings"
"time"
"github.com/spf13/pflag"
)
// TimeValue adapts time.Time for use as a flag.
//
// Time points at the variable that receives the parsed value; Formats lists
// the layouts (in time.Parse syntax) that Set tries, in order.
type TimeValue struct {
	Time    *time.Time
	Formats []string
}

// NewTimeValue creates a new TimeValue that stores parsed times into t.
//
// When no formats are given (or the list is empty) a default set is used:
// RFC822, RFC822Z, RFC3339, RFC3339Nano, DateTime and DateOnly.
func NewTimeValue(t *time.Time, formats ...string) *TimeValue {
	// An empty list would make every Set call fail, so it is treated the
	// same as "not specified".
	if len(formats) == 0 {
		formats = []string{
			time.RFC822,
			time.RFC822Z,
			time.RFC3339,
			time.RFC3339Nano,
			time.DateTime,
			time.DateOnly,
		}
	}

	return &TimeValue{
		Time:    t,
		Formats: formats,
	}
}

// Set parses s (surrounding whitespace is ignored) with each accepted format
// in turn and stores the first successful result in the destination.
//
// It returns an error when no format matches, or when the TimeValue has no
// destination (nil Time) to write to.
func (t TimeValue) Set(s string) error {
	// Guard against a nil destination instead of panicking on the store below.
	if t.Time == nil {
		return fmt.Errorf("time value has no destination")
	}

	s = strings.TrimSpace(s)
	for _, format := range t.Formats {
		v, err := time.Parse(format, s)
		if err == nil {
			*t.Time = v
			return nil
		}
	}

	return fmt.Errorf("format must be one of: %v", strings.Join(t.Formats, ", "))
}

// Type name for time.Time flags.
func (t TimeValue) Type() string {
	return "time"
}

// String returns the string representation of the time.Time value, or the
// empty string when the destination is nil or holds the zero time.
func (t TimeValue) String() string {
	if t.Time == nil || t.Time.IsZero() {
		return ""
	}
	return t.Time.String()
}
// GetTime returns the time stored in the flag called name of flag set f.
//
// The flag's value must be a TimeValue (or a pointer to one) with a non-nil
// destination. An error is returned when the flag does not exist, when its
// value is of another type, or when it has no destination; in all error
// cases the zero time is returned.
func GetTime(f *pflag.FlagSet, name string) (time.Time, error) {
	flag := f.Lookup(name)
	if flag == nil {
		return time.Time{}, fmt.Errorf("flag %s not found", name)
	}

	// TimeValue has value receivers, so both TimeValue and *TimeValue satisfy
	// pflag.Value. A checked type switch replaces the former unchecked
	// assertion, which panicked for any other value reporting type "time".
	var stored *time.Time
	switch v := flag.Value.(type) {
	case *TimeValue:
		if v != nil {
			stored = v.Time
		}
	case TimeValue:
		stored = v.Time
	}

	if stored == nil {
		return time.Time{}, fmt.Errorf("flag %s has wrong type or no value", name)
	}
	return *stored, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package utils
import (
"bufio"
"errors"
"fmt"
"io"
"os"
"os/signal"
"strings"
"syscall"
"github.com/howeyc/gopass"
)
// PromptReonboard asks, through PromptYesNo, whether an already onboarded
// machine should be onboarded again. It returns nil only when the user
// agrees; a declined prompt or a read failure is reported as an error.
func PromptReonboard(r io.Reader, w io.Writer) error {
	const question = "Looks like your machine is already onboarded. Proceed with reonboarding?"

	accepted, err := PromptYesNo(r, w, question)
	switch {
	case err != nil:
		return fmt.Errorf("could not confirm reonboarding: %w", err)
	case !accepted:
		return fmt.Errorf("reonboarding aborted by user")
	default:
		return nil
	}
}
// PromptForPassphrase reads a masked passphrase from the terminal.
//
// When confirm is true the passphrase must be typed twice and both entries
// must match; a mismatch is printed and the prompt is repeated. After three
// unsuccessful rounds an error is returned. A failure to read from the
// terminal aborts immediately with that error.
//
// The prompt runs in a separate goroutine so that SIGINT or SIGTERM can
// interrupt it: when a signal arrives first, an empty passphrase and an
// error are returned.
func PromptForPassphrase(confirm bool) (string, error) {
	maxTries := 3

	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)
	// Undo the registration on return. Without this, SIGINT/SIGTERM would keep
	// being routed to sigChan - which nobody reads anymore - for the rest of
	// the process lifetime instead of getting their previous handling.
	defer signal.Stop(sigChan)

	done := make(chan bool)
	var passphrase string
	var err error

	// Start a goroutine to handle passphrase input
	// NOTE(review): on the signal path below this goroutine is left running
	// and may still write to err/passphrase; callers only see the returned
	// values, but the goroutine itself is not cancelled.
	go func() {
		defer close(done)
		var bytePassphrase, byteConfirmation []byte

		for i := 0; i < maxTries; i++ {
			fmt.Print("Passphrase: ")
			bytePassphrase, err = gopass.GetPasswdMasked()
			if err != nil {
				err = fmt.Errorf("failed to read passphrase: %w", err)
				return
			}

			if confirm {
				fmt.Print("Please confirm your passphrase: ")
				byteConfirmation, err = gopass.GetPasswdMasked()
				if err != nil {
					err = fmt.Errorf("failed to read passphrase confirmation: %w", err)
					return
				}

				if string(bytePassphrase) != string(byteConfirmation) {
					err = fmt.Errorf("passphrases do not match")
				}
			}

			if err == nil {
				passphrase = string(bytePassphrase)
				return
			}

			// Only a mismatch reaches this point: report it and try again.
			fmt.Println(err)
			fmt.Println("")
		}

		err = fmt.Errorf("user failed to input passphrase")
	}()

	// Wait for either the passphrase input to complete or an interrupt signal
	select {
	case <-done:
		return passphrase, err
	case <-sigChan:
		return "", errors.New("sigterm received")
	}
}
// PromptYesNo writes prompt followed by " (y/N): " to out and reads one line
// from in, repeating until the trimmed, case-insensitive answer is "y"/"yes"
// (true) or "n"/"no" (false). Any read error, including end of input, is
// returned wrapped.
func PromptYesNo(in io.Reader, out io.Writer, prompt string) (bool, error) {
	lines := bufio.NewReader(in)

	for {
		fmt.Fprintf(out, "%s (y/N): ", prompt)

		raw, err := lines.ReadString('\n')
		if err != nil {
			return false, fmt.Errorf("read response string failed: %w", err)
		}

		switch strings.ToLower(strings.TrimSpace(raw)) {
		case "y", "yes":
			return true, nil
		case "n", "no":
			return false, nil
		}
		// Anything else (including an empty line) asks again.
	}
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package cmd
import (
"fmt"
"github.com/spf13/cobra"
)
// Version information set by the build system (see Makefile).
// NOTE(review): presumably injected at link time via -ldflags -X; confirm in
// the Makefile. Each value is the empty string when it is not set.
var (
// Version is the DMS release version.
Version string
// GoVersion is the Go toolchain version used for the build.
GoVersion string
// BuildDate is the date the binary was built.
BuildDate string
// Commit is the source control commit the binary was built from.
Commit string
)
// newVersionCmd builds the "version" command, which prints the product name
// followed by the Version, Commit, GoVersion and BuildDate package variables
// to standard output.
func newVersionCmd() *cobra.Command {
	printVersion := func(_ *cobra.Command, _ []string) {
		fmt.Println("NuNet Device Management Service")
		fmt.Printf(
			"Version: %s\nCommit: %s\n\nGo Version: %s\nBuild Date: %s\n",
			Version,
			Commit,
			GoVersion,
			BuildDate,
		)
	}

	return &cobra.Command{
		Use:   "version",
		Short: "Information about current version",
		Long:  `Display information about the current Device Management Service (DMS) version`,
		Run:   printVersion,
	}
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package db
import (
"fmt"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/types"
)
// ConnectDatabase opens the SQLite database "<dbPath>/nunet.db" through gorm
// and runs AutoMigrate for the resource, tracker and GPU models.
//
// It returns the gorm handle, or a wrapped error when the database cannot be
// opened. Migration failures are ignored.
func ConnectDatabase(dbPath string) (*gorm.DB, error) {
	database, err := gorm.Open(sqlite.Open(fmt.Sprintf("%s/nunet.db", dbPath)), &gorm.Config{})
	if err != nil {
		// The signature already promises an error; report the failure to the
		// caller instead of panicking and taking the whole process down.
		return nil, fmt.Errorf("failed to connect to database: %w", err)
	}

	// Migrations stay best effort, one model at a time and in this order,
	// exactly as before: a failing model does not stop the remaining ones.
	models := []interface{}{
		&types.FreeResources{},
		&types.RequestTracker{},
		&types.OnboardedResources{},
		&types.MachineResources{},
		&types.OnboardingConfig{},
		&types.ResourceAllocation{},
		&types.GPU{},
	}
	for _, model := range models {
		_ = database.AutoMigrate(model)
	}

	return database, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package clover
import (
"fmt"
clover "github.com/ostafen/clover/v2"
"gitlab.com/nunet/device-management-service/observability"
)
// NewDB opens the clover database at path and creates every collection named
// in collections.
//
// The call is traced as "clover_db_init_duration". Failures are logged and
// returned as wrapped errors; when a collection cannot be created the
// database handle is closed again before returning, so it is not leaked.
func NewDB(path string, collections []string) (*clover.DB, error) {
	endTrace := observability.StartTrace("clover_db_init_duration")
	defer endTrace()

	db, err := clover.Open(path)
	if err != nil {
		logger.Errorw("clover_db_init_failure", "error", fmt.Errorf("failed to connect to database: %w", err))
		return nil, fmt.Errorf("failed to connect to database: %w", err)
	}

	// NOTE(review): CreateCollection is called unconditionally; confirm how it
	// behaves for a collection that already exists when reopening a database.
	for _, collection := range collections {
		if err := db.CreateCollection(collection); err != nil {
			logger.Errorw("clover_db_init_failure", "collection", collection, "error", fmt.Errorf("failed to create collection %s: %w", collection, err))

			// The caller gets no handle back, so release it here.
			if closeErr := db.Close(); closeErr != nil {
				logger.Errorw("clover_db_init_failure", "error", fmt.Errorf("failed to close database: %w", closeErr))
			}
			return nil, fmt.Errorf("failed to create collection %s: %w", collection, err)
		}
	}

	logger.Infow("clover_db_init_success", "path", path, "collections", collections)
	return db, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package clover
import (
clover "github.com/ostafen/clover/v2"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/observability"
"gitlab.com/nunet/device-management-service/types"
)
// RequestTrackerClover is a Clover implementation of the RequestTracker interface.
// It adds no fields or methods of its own: everything it offers comes from
// the embedded generic repository specialised for types.RequestTracker.
type RequestTrackerClover struct {
repositories.GenericRepository[types.RequestTracker]
}
// NewRequestTracker creates a new instance of RequestTrackerClover.
// It initializes and returns a Clover-based repository for RequestTracker entities.
func NewRequestTracker(db *clover.DB) repositories.RequestTracker {
endTrace := observability.StartTrace("clover_db_repo_init_duration")
defer endTrace()
repo := &RequestTrackerClover{
NewGenericRepository[types.RequestTracker](db),
}
logger.Infow("clover_db_repo_init_success", "repository", "RequestTracker")
return repo
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package clover
import (
"context"
"fmt"
"reflect"
"time"
"github.com/iancoleman/strcase"
clover "github.com/ostafen/clover/v2"
clover_d "github.com/ostafen/clover/v2/document"
clover_q "github.com/ostafen/clover/v2/query"
"gitlab.com/nunet/device-management-service/db/repositories"
)
// GenericEntityRepositoryClover is a generic single entity repository implementation using Clover.
// It is intended to be embedded in single entity model repositories to provide basic database operations.
// Every Save inserts a new document, Get returns the most recently created one,
// and History/Clear operate on all documents of the collection.
type GenericEntityRepositoryClover[T repositories.ModelType] struct {
db *clover.DB // db is the Clover database instance.
collection string // collection is the name of the collection in the database (snake_case of T's type name).
}
// NewGenericEntityRepository returns a GenericEntityRepositoryClover for T
// backed by db. The collection name is the snake_case form of T's type name.
func NewGenericEntityRepository[T repositories.ModelType](
	db *clover.DB,
) repositories.GenericEntityRepository[T] {
	var zero T
	typeName := reflect.TypeOf(zero).Name()

	return &GenericEntityRepositoryClover[T]{
		db:         db,
		collection: strcase.ToSnake(typeName),
	}
}
// GetQuery returns a zero-valued Query for T that callers can fill in.
func (repo *GenericEntityRepositoryClover[T]) GetQuery() repositories.Query[T] {
	var empty repositories.Query[T]
	return empty
}
// query returns a fresh Clover query targeting this repository's collection.
func (repo *GenericEntityRepositoryClover[T]) query() *clover_q.Query {
	collectionQuery := clover_q.NewQuery(repo.collection)
	return collectionQuery
}
// Save stores data as a new document stamped with the current time in its
// "CreatedAt" field and returns the model decoded back from that document.
//
// Each call inserts a new document; earlier documents are kept and remain
// reachable through History, while Get returns the newest one.
//
// If the insert fails, data is returned together with the mapped database
// error. If the stored document cannot be converted back to T, the error wraps
// repositories.ErrParsingModel so callers can detect it with errors.Is.
func (repo *GenericEntityRepositoryClover[T]) Save(_ context.Context, data T) (T, error) {
	var model T

	doc := toCloverDoc(data)
	doc.Set("CreatedAt", time.Now())

	_, err := repo.db.InsertOne(repo.collection, doc)
	if err != nil {
		return data, handleDBError(err)
	}

	model, err = toModel[T](doc, true)
	if err != nil {
		// %w (not %v) keeps the sentinel in the chain; the rendered message is unchanged.
		return model, handleDBError(fmt.Errorf("%w: %v", repositories.ErrParsingModel, err))
	}

	return model, nil
}
// Get returns the most recently created document of the collection, i.e. the
// first one when sorting by "CreatedAt" in descending order.
//
// A lookup error, or the absence of any document, is reported through
// handleDBError; a document that cannot be converted to T yields a
// "failed to convert document to model" error.
func (repo *GenericEntityRepositoryClover[T]) Get(_ context.Context) (T, error) {
	var zero T

	newestFirst := clover_q.SortOption{Field: "CreatedAt", Direction: -1}
	doc, err := repo.db.FindFirst(repo.query().Sort(newestFirst))
	switch {
	case err != nil:
		return zero, handleDBError(err)
	case doc == nil:
		return zero, handleDBError(clover.ErrDocumentNotExist)
	}

	model, err := toModel[T](doc, true)
	if err != nil {
		return model, fmt.Errorf("failed to convert document to model: %v", err)
	}
	return model, nil
}
// Clear deletes every document of the repository's collection, which removes
// the current record together with all earlier versions. The error from the
// database is returned as is.
func (repo *GenericEntityRepositoryClover[T]) Clear(_ context.Context) error {
	everything := repo.query()
	return repo.db.Delete(everything)
}
// History returns the stored versions of the record that match query.
//
// Iteration stops at the first document that cannot be unmarshalled into T.
// In that case the models collected so far are returned together with an
// error wrapping repositories.ErrParsingModel, instead of silently returning a
// truncated result with a nil error. Errors reported by the database itself
// are mapped through handleDBError.
func (repo *GenericEntityRepositoryClover[T]) History(_ context.Context, query repositories.Query[T]) ([]T, error) {
	var (
		models   []T
		parseErr error
	)

	q := applyConditions(repo.query(), query)

	err := repo.db.ForEach(q, func(doc *clover_d.Document) bool {
		var model T
		if unmarshalErr := doc.Unmarshal(&model); unmarshalErr != nil {
			// Remember why the iteration was aborted so the caller is told.
			parseErr = fmt.Errorf("%w: %v", repositories.ErrParsingModel, unmarshalErr)
			return false
		}
		models = append(models, model)
		return true
	})
	if err != nil {
		return models, handleDBError(err)
	}
	if parseErr != nil {
		return models, handleDBError(parseErr)
	}

	return models, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package clover
import (
"context"
"fmt"
"reflect"
"time"
"github.com/iancoleman/strcase"
clover "github.com/ostafen/clover/v2"
clover_d "github.com/ostafen/clover/v2/document"
clover_q "github.com/ostafen/clover/v2/query"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/observability"
)
// Document field names used when building Clover queries.
const (
// pkField is the document field matched against the id passed to queryWithID.
pkField = "_id"
// deletedAtField is the document field inspected by query to leave out deleted records.
deletedAtField = "DeletedAt"
)
// GenericRepositoryClover is a generic repository implementation using Clover.
// It is intended to be embedded in model repositories to provide basic database operations
// (Create, Get, Update, Delete, Find and FindAll) for the model type T.
type GenericRepositoryClover[T repositories.ModelType] struct {
db *clover.DB // db is the Clover database instance.
collection string // collection is the name of the collection in the database (snake_case of T's type name).
}
// NewGenericRepository returns a GenericRepositoryClover for T backed by db.
// The collection name is the snake_case form of T's type name. The
// construction is traced and the chosen collection is logged.
func NewGenericRepository[T repositories.ModelType](
	db *clover.DB,
) repositories.GenericRepository[T] {
	stopTrace := observability.StartTrace("clover_db_repo_init_duration")
	defer stopTrace()

	var zero T
	name := strcase.ToSnake(reflect.TypeOf(zero).Name())
	logger.Infow("clover_db_repo_init_success", "collection", name)

	return &GenericRepositoryClover[T]{
		db:         db,
		collection: name,
	}
}
// GetQuery returns a zero-valued Query for T that callers can fill in.
func (repo *GenericRepositoryClover[T]) GetQuery() repositories.Query[T] {
	var empty repositories.Query[T]
	return empty
}
// query returns a Clover query on the repository's collection. Unless
// includeDeleted is true, it only matches documents whose "DeletedAt" field is
// at or before the Unix epoch (presumably the marker for "not deleted").
func (repo *GenericRepositoryClover[T]) query(includeDeleted bool) *clover_q.Query {
	base := clover_q.NewQuery(repo.collection)
	if includeDeleted {
		return base
	}

	notDeleted := clover_q.Field(deletedAtField).LtEq(time.Unix(0, 0))
	return base.Where(notDeleted)
}
// queryWithID returns the repository query (see query) narrowed to the
// document whose primary key field equals id.
//
// The key is compared as a string. A string id is used as is; any other value
// is rendered with fmt.Sprint rather than being force-asserted to string, so a
// caller passing a non-string id gets "no match" instead of a panic.
func (repo *GenericRepositoryClover[T]) queryWithID(
	id interface{},
	includeDeleted bool,
) *clover_q.Query {
	key, ok := id.(string)
	if !ok {
		key = fmt.Sprint(id)
	}
	return repo.query(includeDeleted).Where(clover_q.Field(pkField).Eq(key))
}
// Create inserts data as a new document stamped with the current time in its
// "CreatedAt" field and returns the model decoded back from that document,
// with its ID field set from the document's object id.
//
// On failure data is returned unchanged together with the error. A failure to
// convert the stored document back to T wraps repositories.ErrParsingModel so
// callers can detect it with errors.Is.
func (repo *GenericRepositoryClover[T]) Create(_ context.Context, data T) (T, error) {
	endTrace := observability.StartTrace("clover_db_create_duration")
	defer endTrace()

	var model T

	doc := toCloverDoc(data)
	doc.Set("CreatedAt", time.Now())

	_, err := repo.db.InsertOne(repo.collection, doc)
	if err != nil {
		logger.Errorw("clover_db_create_failure", "error", err)
		return data, handleDBError(err)
	}

	model, err = toModel[T](doc, false)
	if err != nil {
		logger.Errorw("clover_db_create_failure", "error", err)
		// %w (not %v) keeps the sentinel in the chain; the rendered message is unchanged.
		return data, handleDBError(fmt.Errorf("%w: %v", repositories.ErrParsingModel, err))
	}

	logger.Infow("clover_db_create_success", "collection", repo.collection)
	return model, nil
}
// Get returns the non-deleted record whose primary key equals id.
//
// Errors:
//   - a lookup failure is mapped through handleDBError;
//   - when no document matches, the result of
//     handleDBError(clover.ErrDocumentNotExist) is returned rather than a zero
//     model with a nil error;
//   - a document that cannot be converted to T yields an error wrapping
//     repositories.ErrParsingModel.
func (repo *GenericRepositoryClover[T]) Get(_ context.Context, id interface{}) (T, error) {
	endTrace := observability.StartTrace("clover_db_get_duration")
	defer endTrace()

	var model T

	doc, err := repo.db.FindFirst(repo.queryWithID(id, false))
	if err != nil {
		logger.Errorw("clover_db_get_failure", "id", id, "error", err)
		return model, handleDBError(err)
	}
	// A nil document with a nil error means nothing matched; that must be
	// reported as "not found", not as success.
	if doc == nil {
		logger.Errorw("clover_db_get_failure", "id", id, "error", clover.ErrDocumentNotExist)
		return model, handleDBError(clover.ErrDocumentNotExist)
	}

	model, err = toModel[T](doc, false)
	if err != nil {
		logger.Errorw("clover_db_get_failure", "id", id, "error", err)
		return model, handleDBError(fmt.Errorf("%w: %v", repositories.ErrParsingModel, err))
	}

	logger.Infow("clover_db_get_success", "id", id)
	return model, nil
}
// Update writes the fields of data, plus an "UpdatedAt" timestamp, to the
// non-deleted record whose primary key equals id, then re-reads the record
// and returns it.
//
// If the update itself fails, data is returned with the mapped database
// error. If the subsequent read fails, its error is returned unchanged (Get
// has already mapped it) and the failure is logged; success is only logged
// when both steps succeeded.
func (repo *GenericRepositoryClover[T]) Update(
	ctx context.Context,
	id interface{},
	data T,
) (T, error) {
	endTrace := observability.StartTrace("clover_db_update_duration")
	defer endTrace()

	updates := toCloverDoc(data).AsMap()
	updates["UpdatedAt"] = time.Now()

	err := repo.db.Update(repo.queryWithID(id, false), updates)
	if err != nil {
		logger.Errorw("clover_db_update_failure", "id", id, "error", err)
		return data, handleDBError(err)
	}

	updated, err := repo.Get(ctx, id)
	if err != nil {
		logger.Errorw("clover_db_update_failure", "id", id, "error", err)
		// Get already returns repository-level errors; mapping them again
		// would re-label e.g. a not-found error as a generic database error.
		return updated, err
	}

	logger.Infow("clover_db_update_success", "id", id)
	return updated, nil
}
// Delete removes the non-deleted record whose primary key equals id. The
// database error, if any, is logged and returned as is.
func (repo *GenericRepositoryClover[T]) Delete(_ context.Context, id interface{}) error {
	stopTrace := observability.StartTrace("clover_db_delete_duration")
	defer stopTrace()

	target := repo.queryWithID(id, false)
	if err := repo.db.Delete(target); err != nil {
		logger.Errorw("clover_db_delete_failure", "id", id, "error", err)
		return err
	}

	logger.Infow("clover_db_delete_success", "id", id)
	return nil
}
// Find returns the first non-deleted record matching query.
//
// Errors:
//   - a lookup failure is mapped through handleDBError;
//   - when no document matches, the result of
//     handleDBError(clover.ErrDocumentNotExist) is returned rather than a zero
//     model with a nil error;
//   - a document that cannot be converted to T yields a
//     "failed to convert document to model" error.
func (repo *GenericRepositoryClover[T]) Find(
	_ context.Context,
	query repositories.Query[T],
) (T, error) {
	endTrace := observability.StartTrace("clover_db_find_duration")
	defer endTrace()

	var model T

	q := repo.query(false)
	q = applyConditions(q, query)

	doc, err := repo.db.FindFirst(q)
	if err != nil {
		logger.Errorw("clover_db_find_failure", "error", err)
		return model, handleDBError(err)
	}
	// A nil document with a nil error means nothing matched; that must be
	// reported as "not found", not as success.
	if doc == nil {
		logger.Errorw("clover_db_find_failure", "error", clover.ErrDocumentNotExist)
		return model, handleDBError(clover.ErrDocumentNotExist)
	}

	model, err = toModel[T](doc, false)
	if err != nil {
		logger.Errorw("clover_db_find_failure", "error", err)
		return model, fmt.Errorf("failed to convert document to model: %v", err)
	}

	logger.Infow("clover_db_find_success", "collection", repo.collection)
	return model, nil
}
// FindAll returns all non-deleted records matching query.
//
// Iteration stops at the first document that cannot be converted to T; the
// records collected so far are then returned together with an error wrapping
// repositories.ErrParsingModel (detectable with errors.Is). Errors reported by
// the database itself take precedence and are mapped through handleDBError.
func (repo *GenericRepositoryClover[T]) FindAll(
	_ context.Context,
	query repositories.Query[T],
) ([]T, error) {
	endTrace := observability.StartTrace("clover_db_find_all_duration")
	defer endTrace()

	var (
		models          []T
		modelParsingErr error
	)

	q := repo.query(false)
	q = applyConditions(q, query)

	err := repo.db.ForEach(q, func(doc *clover_d.Document) bool {
		model, internalErr := toModel[T](doc, false)
		if internalErr != nil {
			// %w (not %v) keeps the sentinel in the chain; the rendered message is unchanged.
			modelParsingErr = handleDBError(fmt.Errorf("%w: %v", repositories.ErrParsingModel, internalErr))
			logger.Errorw("clover_db_find_all_failure", "error", internalErr)
			return false
		}
		models = append(models, model)
		return true
	})
	if err != nil {
		logger.Errorw("clover_db_find_all_failure", "error", err)
		return models, handleDBError(err)
	}
	if modelParsingErr != nil {
		return models, modelParsingErr
	}

	logger.Infow("clover_db_find_all_success", "collection", repo.collection)
	return models, nil
}
// applyConditions translates a repositories.Query into additional clauses on
// the Clover query q and returns the resulting query.
//
// The following are applied, in this order:
//   - every entry of query.Conditions, for the operators "=", ">", ">=", "<",
//     "<=", "!=", "IN" and "LIKE". Unknown operators are ignored, as are an
//     "IN" value that is not a []interface{} and a "LIKE" value that is not a
//     string;
//   - an equality filter for every non-empty field of query.Instance;
//   - sorting by query.SortBy, where a leading '-' selects descending order;
//   - query.Limit as the maximum number of results and query.Offset as the
//     number of results to skip, each only when greater than zero.
//
// Field names (in conditions, the instance and SortBy) are mapped through
// fieldJSONTag, i.e. replaced by the name in the struct field's json tag when
// T declares one.
func applyConditions[T repositories.ModelType](
	q *clover_q.Query,
	query repositories.Query[T],
) *clover_q.Query {
	// Apply conditions specified in the query.
	for _, condition := range query.Conditions {
		// change the field name to json tag name if specified in the struct
		condition.Field = fieldJSONTag[T](condition.Field)
		switch condition.Operator {
		case "=":
			q = q.Where(clover_q.Field(condition.Field).Eq(condition.Value))
		case ">":
			q = q.Where(clover_q.Field(condition.Field).Gt(condition.Value))
		case ">=":
			q = q.Where(clover_q.Field(condition.Field).GtEq(condition.Value))
		case "<":
			q = q.Where(clover_q.Field(condition.Field).Lt(condition.Value))
		case "<=":
			q = q.Where(clover_q.Field(condition.Field).LtEq(condition.Value))
		case "!=":
			q = q.Where(clover_q.Field(condition.Field).Neq(condition.Value))
		case "IN":
			if values, ok := condition.Value.([]interface{}); ok {
				q = q.Where(clover_q.Field(condition.Field).In(values...))
			}
		case "LIKE":
			if value, ok := condition.Value.(string); ok {
				q = q.Where(clover_q.Field(condition.Field).Like(value))
			}
		}
	}

	// Apply conditions based on non-zero values in the query instance.
	if !repositories.IsEmptyValue(query.Instance) {
		exampleType := reflect.TypeOf(query.Instance)
		exampleValue := reflect.ValueOf(query.Instance)
		for i := 0; i < exampleType.NumField(); i++ {
			fieldName := exampleType.Field(i).Name
			fieldName = fieldJSONTag[T](fieldName)
			fieldValue := exampleValue.Field(i).Interface()
			if !repositories.IsEmptyValue(fieldValue) {
				q = q.Where(clover_q.Field(fieldName).Eq(fieldValue))
			}
		}
	}

	// Apply sorting if specified in the query.
	if query.SortBy != "" {
		dir := 1
		sortField := query.SortBy
		if sortField[0] == '-' {
			dir = -1
			sortField = sortField[1:]
		}
		// Map the name for both directions, so ascending and descending
		// sorts address the same document field.
		sortField = fieldJSONTag[T](sortField)
		q = q.Sort(clover_q.SortOption{Field: sortField, Direction: dir})
	}

	// Apply limit if specified in the query.
	if query.Limit > 0 {
		q = q.Limit(query.Limit)
	}

	// Apply offset if specified in the query. Skip (not Limit) is what drops
	// leading results; using Limit here would overwrite the limit set above.
	if query.Offset > 0 {
		q = q.Skip(query.Offset)
	}

	return q
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package clover
import (
"github.com/ostafen/clover/v2"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/db/repositories"
)
// MachineResourcesRepositoryClover stores the types.MachineResources entity in
// Clover by embedding the generic single entity repository for that model.
type MachineResourcesRepositoryClover struct {
	repositories.GenericEntityRepository[types.MachineResources]
}

// NewMachineResourcesRepository builds a MachineResourcesRepositoryClover on
// top of db and returns it as a repositories.MachineResources.
func NewMachineResourcesRepository(db *clover.DB) repositories.MachineResources {
	entityRepo := NewGenericEntityRepository[types.MachineResources](db)
	return &MachineResourcesRepositoryClover{GenericEntityRepository: entityRepo}
}
// FreeResourcesClover stores the types.FreeResources entity in Clover by
// embedding the generic single entity repository for that model.
type FreeResourcesClover struct {
	repositories.GenericEntityRepository[types.FreeResources]
}

// NewFreeResources builds a FreeResourcesClover on top of db and returns it as
// a repositories.FreeResources.
func NewFreeResources(db *clover.DB) repositories.FreeResources {
	entityRepo := NewGenericEntityRepository[types.FreeResources](db)
	return &FreeResourcesClover{GenericEntityRepository: entityRepo}
}
// OnboardedResourcesClover stores the types.OnboardedResources entity in
// Clover by embedding the generic single entity repository for that model.
type OnboardedResourcesClover struct {
	repositories.GenericEntityRepository[types.OnboardedResources]
}

// NewOnboardedResources builds an OnboardedResourcesClover on top of db and
// returns it as a repositories.OnboardedResources.
func NewOnboardedResources(db *clover.DB) repositories.OnboardedResources {
	entityRepo := NewGenericEntityRepository[types.OnboardedResources](db)
	return &OnboardedResourcesClover{GenericEntityRepository: entityRepo}
}
// ResourceAllocationClover stores types.ResourceAllocation records in Clover
// by embedding the generic repository for that model.
type ResourceAllocationClover struct {
	repositories.GenericRepository[types.ResourceAllocation]
}

// NewResourceAllocation builds a ResourceAllocationClover on top of db and
// returns it as a repositories.ResourceAllocation.
func NewResourceAllocation(db *clover.DB) repositories.ResourceAllocation {
	genericRepo := NewGenericRepository[types.ResourceAllocation](db)
	return &ResourceAllocationClover{GenericRepository: genericRepo}
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package clover
import (
clover "github.com/ostafen/clover/v2"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/observability"
"gitlab.com/nunet/device-management-service/types"
)
// StorageVolumeClover stores types.StorageVolume records in Clover by
// embedding the generic repository for that model.
type StorageVolumeClover struct {
	repositories.GenericRepository[types.StorageVolume]
}

// NewStorageVolume builds a StorageVolumeClover on top of db and returns it as
// a repositories.StorageVolume. The construction is traced; the success
// message is logged before the repository is assembled.
func NewStorageVolume(db *clover.DB) repositories.StorageVolume {
	stopTrace := observability.StartTrace("clover_storage_volume_init_duration")
	defer stopTrace()

	logger.Infow("clover_storage_volume_init_success", "collection", "storage_volume")

	volumes := &StorageVolumeClover{
		GenericRepository: NewGenericRepository[types.StorageVolume](db),
	}
	return volumes
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package clover
import (
"encoding/json"
"errors"
"reflect"
"strings"
"github.com/ostafen/clover/v2"
clover_d "github.com/ostafen/clover/v2/document"
"gitlab.com/nunet/device-management-service/db/repositories"
)
// handleDBError maps an error coming from Clover (or from model conversion)
// to the repository-level error vocabulary:
//
//   - nil stays nil;
//   - clover.ErrDocumentNotExist becomes repositories.ErrNotFound;
//   - clover.ErrDuplicateKey becomes repositories.ErrInvalidData;
//   - an error that is (or wraps) repositories.ErrParsingModel is returned
//     unchanged;
//   - anything else is joined with repositories.ErrDatabase.
//
// Sentinels are matched with errors.Is, so errors that wrap one of them (for
// example via fmt.Errorf with %w) are recognized as well.
func handleDBError(err error) error {
	switch {
	case err == nil:
		return nil
	case errors.Is(err, clover.ErrDocumentNotExist):
		return repositories.ErrNotFound
	case errors.Is(err, clover.ErrDuplicateKey):
		return repositories.ErrInvalidData
	case errors.Is(err, repositories.ErrParsingModel):
		return err
	default:
		return errors.Join(repositories.ErrDatabase, err)
	}
}
// toCloverDoc converts data into a Clover document by round-tripping it
// through JSON: data is marshalled, decoded into a generic map, and the map is
// wrapped in a document. Document keys therefore follow the model's json
// tags. If either JSON step fails, an empty document is returned instead.
func toCloverDoc[T repositories.ModelType](data T) *clover_d.Document {
	encoded, err := json.Marshal(data)
	if err != nil {
		return clover_d.NewDocument()
	}

	fields := make(map[string]interface{})
	if err := json.Unmarshal(encoded, &fields); err != nil {
		return clover_d.NewDocument()
	}

	return clover_d.NewDocumentOf(fields)
}
// toModel decodes doc into a value of type T.
//
// For regular repositories (isEntityRepo == false) the model's "ID" field is
// additionally set to the document's object id via repositories.UpdateField.
// Entity repositories skip that step because their models might not have an
// ID field at all.
func toModel[T repositories.ModelType](doc *clover_d.Document, isEntityRepo bool) (T, error) {
	var decoded T
	if err := doc.Unmarshal(&decoded); err != nil {
		return decoded, err
	}

	if isEntityRepo {
		return decoded, nil
	}

	withID, err := repositories.UpdateField(decoded, "ID", doc.ObjectId())
	if err != nil {
		return withID, err
	}
	return withID, nil
}
// fieldJSONTag returns the name under which the struct field named field of T
// is serialized to JSON, i.e. the part of its json tag before the first comma.
//
// The given name is returned unchanged when T is not a struct type, when T
// has no such field, when the field has no json tag, or when the tag does not
// specify a name (for example `json:",omitempty"`).
func fieldJSONTag[T repositories.ModelType](field string) string {
	modelType := reflect.TypeOf(*new(T))
	// FieldByName panics for non-struct types; a nil type results from an
	// interface-typed T.
	if modelType == nil || modelType.Kind() != reflect.Struct {
		return field
	}

	structField, ok := modelType.FieldByName(field)
	if !ok {
		return field
	}

	tag, ok := structField.Tag.Lookup("json")
	if !ok {
		return field
	}

	// An empty tag name means the Go field name is kept as the JSON key.
	if name := strings.Split(tag, ",")[0]; name != "" {
		return name
	}
	return field
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package repositories
import (
"context"
)
// QueryCondition is a struct representing a single query condition:
// a field, a comparison operator and the value to compare against.
// Use the helper constructors (EQ, GT, GTE, LT, LTE, IN, LIKE) to build one.
type QueryCondition struct {
Field string // Field specifies the database or struct field to which the condition applies.
Operator string // Operator defines the comparison operator (e.g., "=", ">", "<").
Value interface{} // Value is the expected value for the given field.
}
// ModelType is the type constraint for models handled by the repositories.
// It is the empty interface, so it is satisfied by any type.
type ModelType interface{}
// Query is a struct that wraps both the instance of type T and additional query parameters.
// It is used to construct queries with conditions, sorting, limiting, and offsetting.
// The zero value is an empty query without any constraint.
type Query[T any] struct {
Instance T // Instance is an optional object of type T used to build conditions from its fields.
Conditions []QueryCondition // Conditions represent the conditions applied to the query.
SortBy string // SortBy specifies the field by which the query results should be sorted.
Limit int // Limit specifies the maximum number of results to return.
Offset int // Offset specifies the number of results to skip before starting to return data.
}
// GenericRepository is an interface defining basic CRUD operations and standard querying methods
// for models of type T. Records are addressed by an identifier or selected with a Query[T].
type GenericRepository[T ModelType] interface {
// Create adds a new record to the repository and returns the stored record.
Create(ctx context.Context, data T) (T, error)
// Get retrieves a record by its identifier.
Get(ctx context.Context, id interface{}) (T, error)
// Update modifies a record by its identifier and returns the resulting record.
Update(ctx context.Context, id interface{}, data T) (T, error)
// Delete removes a record by its identifier.
Delete(ctx context.Context, id interface{}) error
// Find retrieves a single record based on a query.
Find(ctx context.Context, query Query[T]) (T, error)
// FindAll retrieves multiple records based on a query.
FindAll(ctx context.Context, query Query[T]) ([]T, error)
// GetQuery returns an empty query instance for the repository's type.
GetQuery() Query[T]
}
// EQ returns a QueryCondition requiring field to be equal to value
// (operator "=").
func EQ(field string, value interface{}) QueryCondition {
	cond := QueryCondition{Operator: "="}
	cond.Field, cond.Value = field, value
	return cond
}
// GT returns a QueryCondition requiring field to be greater than value
// (operator ">").
func GT(field string, value interface{}) QueryCondition {
	cond := QueryCondition{Operator: ">"}
	cond.Field, cond.Value = field, value
	return cond
}
// GTE returns a QueryCondition requiring field to be greater than or equal to
// value (operator ">=").
func GTE(field string, value interface{}) QueryCondition {
	cond := QueryCondition{Operator: ">="}
	cond.Field, cond.Value = field, value
	return cond
}
// LT returns a QueryCondition requiring field to be less than value
// (operator "<").
func LT(field string, value interface{}) QueryCondition {
	cond := QueryCondition{Operator: "<"}
	cond.Field, cond.Value = field, value
	return cond
}
// LTE returns a QueryCondition requiring field to be less than or equal to
// value (operator "<=").
func LTE(field string, value interface{}) QueryCondition {
	cond := QueryCondition{Operator: "<="}
	cond.Field, cond.Value = field, value
	return cond
}
// IN returns a QueryCondition requiring field to be one of values
// (operator "IN"). The slice is stored as the condition's Value.
func IN(field string, values []interface{}) QueryCondition {
	cond := QueryCondition{Operator: "IN"}
	cond.Field, cond.Value = field, values
	return cond
}
// LIKE returns a QueryCondition requiring field to match pattern
// (operator "LIKE"). The pattern is stored as the condition's Value.
func LIKE(field, pattern string) QueryCondition {
	cond := QueryCondition{Operator: "LIKE"}
	cond.Field, cond.Value = field, pattern
	return cond
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package gorm
import (
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/observability"
"gitlab.com/nunet/device-management-service/types"
)
// RequestTrackerGORM persists types.RequestTracker records through GORM by
// embedding the generic repository for that model.
type RequestTrackerGORM struct {
	repositories.GenericRepository[types.RequestTracker]
}

// NewRequestTracker builds a RequestTrackerGORM on top of db and returns it as
// a repositories.RequestTracker. The construction is traced; the success
// message is logged before the repository is assembled.
func NewRequestTracker(db *gorm.DB) repositories.RequestTracker {
	stopTrace := observability.StartTrace("gorm_db_request_tracker_init_duration")
	defer stopTrace()

	logger.Infow("gorm_db_request_tracker_init_success", "repository", "RequestTracker")

	tracker := &RequestTrackerGORM{
		GenericRepository: NewGenericRepository[types.RequestTracker](db),
	}
	return tracker
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package gorm
import (
"context"
"fmt"
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
)
// Note: our GORM implementation does not support:
//
// - Structs with maps
//
// - Structs with nested NAMED structs, e.g.:
// type ComputerSpecs struct {
// types.BaseDBModel
// CPU int
// Another AnotherStruct
// }
const (
// createdAtField is the field name GenericEntityRepositoryGORM.Get sorts by
// (prefixed with "-") to select the current record.
createdAtField = "CreatedAt"
)
// GenericEntityRepositoryGORM is a generic single entity repository implementation using GORM as an ORM.
// It is intended to be embedded in single entity model repositories to provide basic database operations
// (Save, Get, Clear and History) for the model type T.
type GenericEntityRepositoryGORM[T repositories.ModelType] struct {
db *gorm.DB // db is the GORM database instance.
}
// NewGenericEntityRepository returns a GenericEntityRepositoryGORM for T that
// performs all operations on db.
func NewGenericEntityRepository[T repositories.ModelType](
	db *gorm.DB,
) repositories.GenericEntityRepository[T] {
	entityRepo := new(GenericEntityRepositoryGORM[T])
	entityRepo.db = db
	return entityRepo
}
// GetQuery returns a zero-valued Query for T that callers can fill in.
func (repo *GenericEntityRepositoryGORM[T]) GetQuery() repositories.Query[T] {
	var empty repositories.Query[T]
	return empty
}
// Save inserts data as a new row using GORM's Create and returns data as left
// by that call, together with the error mapped through handleDBError.
func (repo *GenericEntityRepositoryGORM[T]) Save(ctx context.Context, data T) (T, error) {
	result := repo.db.WithContext(ctx).Create(&data)
	saveErr := result.Error
	return data, handleDBError(saveErr)
}
// Get returns the first row obtained after applying a query whose SortBy is
// "-CreatedAt" (the '-' prefix presumably requests descending order, i.e. the
// newest record; see applyConditions). The error is mapped through
// handleDBError.
func (repo *GenericEntityRepositoryGORM[T]) Get(ctx context.Context) (T, error) {
	newest := repo.GetQuery()
	newest.SortBy = "-" + createdAtField

	tx := applyConditions(repo.db.WithContext(ctx), newest)

	var record T
	lookupErr := tx.First(&record).Error
	return record, handleDBError(lookupErr)
}
// Clear deletes every row of T's table whose id is not NULL, removing the
// record together with its history. The GORM error is returned as is.
func (repo *GenericEntityRepositoryGORM[T]) Clear(ctx context.Context) error {
	result := repo.db.WithContext(ctx).Delete(new(T), "id IS NOT NULL")
	return result.Error
}
// History returns the stored records of T selected by query. The error is
// mapped through handleDBError.
func (repo *GenericEntityRepositoryGORM[T]) History(
	ctx context.Context,
	query repositories.Query[T],
) ([]T, error) {
	tx := applyConditions(repo.db.WithContext(ctx).Model(new(T)), query)

	var records []T
	findErr := tx.Find(&records).Error
	return records, handleDBError(findErr)
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package gorm
import (
"context"
"fmt"
"reflect"
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/observability"
)
// Note: our GORM implementation does not support:
//
// - Structs with maps
//
// - Structs with nested NAMED structs, e.g.:
// type ComputerSpecs struct {
// types.BaseDBModel
// CPU int
// Another AnotherStruct
// }
// GenericRepositoryGORM is a generic repository implementation using GORM as an ORM.
// It is intended to be embedded in model repositories to provide basic database operations
// (Create, Get, Update, Delete, Find and FindAll) for the model type T.
type GenericRepositoryGORM[T repositories.ModelType] struct {
db *gorm.DB // db is the GORM database instance all operations run on.
}
// NewGenericRepository returns a GenericRepositoryGORM for T that performs all
// operations on db. The construction is traced and the model's type name is
// logged.
func NewGenericRepository[T repositories.ModelType](db *gorm.DB) repositories.GenericRepository[T] {
	stopTrace := observability.StartTrace("gorm_db_repo_init_duration")
	defer stopTrace()

	var zero T
	logger.Infow("gorm_db_repo_init_success", "repository", fmt.Sprintf("%T", zero))

	genericRepo := new(GenericRepositoryGORM[T])
	genericRepo.db = db
	return genericRepo
}
// GetQuery returns a zero-valued Query for T that callers can fill in.
func (repo *GenericRepositoryGORM[T]) GetQuery() repositories.Query[T] {
	var empty repositories.Query[T]
	return empty
}
// Create inserts data as a new row using GORM's Create and returns data as
// left by that call. The outcome is logged and the error is mapped through
// handleDBError.
func (repo *GenericRepositoryGORM[T]) Create(ctx context.Context, data T) (T, error) {
	stopTrace := observability.StartTrace("gorm_db_create_duration")
	defer stopTrace()

	createErr := repo.db.WithContext(ctx).Create(&data).Error
	if createErr == nil {
		logger.Infow("gorm_db_create_success", "record", fmt.Sprintf("%+v", data))
	} else {
		logger.Errorw("gorm_db_create_failure", "error", createErr)
	}
	return data, handleDBError(createErr)
}
// Get returns the first row of T whose id column equals id. The outcome is
// logged and the error is mapped through handleDBError.
func (repo *GenericRepositoryGORM[T]) Get(ctx context.Context, id interface{}) (T, error) {
	stopTrace := observability.StartTrace("gorm_db_get_duration")
	defer stopTrace()

	var record T
	lookupErr := repo.db.WithContext(ctx).First(&record, "id = ?", id).Error
	if lookupErr == nil {
		logger.Infow("gorm_db_get_success", "record", fmt.Sprintf("%+v", record))
	} else {
		logger.Errorw("gorm_db_get_failure", "id", id, "error", lookupErr)
	}
	return record, handleDBError(lookupErr)
}
// Update applies data, via GORM's Updates, to the rows of T whose id column
// equals id. It returns the data argument as passed in (the row is not read
// back). The outcome is logged and the error is mapped through handleDBError.
func (repo *GenericRepositoryGORM[T]) Update(ctx context.Context, id interface{}, data T) (T, error) {
	stopTrace := observability.StartTrace("gorm_db_update_duration")
	defer stopTrace()

	target := repo.db.WithContext(ctx).Model(new(T)).Where("id = ?", id)
	updateErr := target.Updates(data).Error
	if updateErr == nil {
		logger.Infow("gorm_db_update_success", "id", id, "data", fmt.Sprintf("%+v", data))
	} else {
		logger.Errorw("gorm_db_update_failure", "id", id, "error", updateErr)
	}
	return data, handleDBError(updateErr)
}
// Delete removes the record whose id column equals id.
//
// Like the other repository methods, a failure is logged and translated by
// handleDBError so callers receive repository-level errors rather than raw
// GORM errors. It returns nil on success.
func (repo *GenericRepositoryGORM[T]) Delete(ctx context.Context, id interface{}) error {
	endTrace := observability.StartTrace("gorm_db_delete_duration")
	defer endTrace()

	err := repo.db.WithContext(ctx).Delete(new(T), "id = ?", id).Error
	if err != nil {
		logger.Errorw("gorm_db_delete_failure", "id", id, "error", err)
		return handleDBError(err)
	}

	logger.Infow("gorm_db_delete_success", "id", id)
	return nil
}
// Find returns the first record matching query.
// The query is turned into GORM clauses by applyConditions; a failed lookup
// is logged and reported through handleDBError.
func (repo *GenericRepositoryGORM[T]) Find(
	ctx context.Context,
	query repositories.Query[T],
) (T, error) {
	stopTrace := observability.StartTrace("gorm_db_find_duration")
	defer stopTrace()

	var record T
	scoped := applyConditions(repo.db.WithContext(ctx).Model(new(T)), query)
	if dbErr := scoped.First(&record).Error; dbErr != nil {
		logger.Errorw("gorm_db_find_failure", "error", dbErr)
		return record, handleDBError(dbErr)
	}

	logger.Infow("gorm_db_find_success", "record", fmt.Sprintf("%+v", record))
	return record, nil
}
// FindAll returns every record matching query.
// The query is turned into GORM clauses by applyConditions; a failed lookup
// is logged and reported through handleDBError.
func (repo *GenericRepositoryGORM[T]) FindAll(
	ctx context.Context,
	query repositories.Query[T],
) ([]T, error) {
	stopTrace := observability.StartTrace("gorm_db_find_all_duration")
	defer stopTrace()

	var records []T
	scoped := applyConditions(repo.db.WithContext(ctx).Model(new(T)), query)
	if dbErr := scoped.Find(&records).Error; dbErr != nil {
		logger.Errorw("gorm_db_find_all_failure", "error", dbErr)
		return records, handleDBError(dbErr)
	}

	logger.Infow("gorm_db_find_all_success", "recordCount", len(records))
	return records, nil
}
// applyConditions applies conditions, sorting, limiting, and offsetting to a GORM database query.
//
// db is the GORM handle to extend and query is the generic repository query.
// The WHERE clause is built from two sources:
//   - every entry in query.Conditions becomes "<column> <operator> ?" with
//     the condition value bound as a parameter;
//   - every non-empty field of query.Instance becomes "<column> = ?".
//
// Column names are derived from field names through the GORM naming strategy.
// A SortBy value with a leading '-' sorts descending, otherwise ascending.
// Limit and Offset are applied only when greater than zero.
// The extended GORM handle is returned.
func applyConditions[T any](db *gorm.DB, query repositories.Query[T]) *gorm.DB {
	// Retrieve the table name using the GORM naming strategy.
	tableName := db.NamingStrategy.TableName(reflect.TypeOf(*new(T)).Name())

	// Apply conditions specified in the query.
	for _, condition := range query.Conditions {
		columnName := db.NamingStrategy.ColumnName(tableName, condition.Field)
		// NOTE(review): the operator is interpolated into the SQL text; only
		// the value is bound. Operators must never come from untrusted input.
		db = db.Where(
			fmt.Sprintf("%s %s ?", columnName, condition.Operator),
			condition.Value,
		)
	}

	// Apply conditions based on non-zero values in the query instance.
	if !repositories.IsEmptyValue(query.Instance) {
		exampleType := reflect.TypeOf(query.Instance)
		exampleValue := reflect.ValueOf(query.Instance)
		for i := 0; i < exampleType.NumField(); i++ {
			fieldName := exampleType.Field(i).Name
			fieldValue := exampleValue.Field(i).Interface()
			if !repositories.IsEmptyValue(fieldValue) {
				columnName := db.NamingStrategy.ColumnName(tableName, fieldName)
				db = db.Where(fmt.Sprintf("%s = ?", columnName), fieldValue)
			}
		}
	}

	// Apply sorting if specified in the query.
	if query.SortBy != "" {
		sortField := query.SortBy
		dir := "ASC"
		if sortField[0] == '-' {
			sortField = sortField[1:]
			dir = "DESC"
		}
		columnName := db.NamingStrategy.ColumnName(tableName, sortField)
		db = db.Order(fmt.Sprintf("%s.%s %s", tableName, columnName, dir))
	}

	// Apply limit if specified in the query.
	if query.Limit > 0 {
		db = db.Limit(query.Limit)
	}

	// Apply offset if specified in the query. This must use Offset: calling
	// Limit here would overwrite the limit above and never skip any rows.
	if query.Offset > 0 {
		db = db.Offset(query.Offset)
	}

	return db
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package gorm
import (
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// OnboardingConfigGORM is the GORM-backed repository for types.OnboardingConfig.
// It embeds the generic entity repository for that type.
type OnboardingConfigGORM struct {
	repositories.GenericEntityRepository[types.OnboardingConfig]
}

// NewOnboardingConfig builds an OnboardingConfigGORM on top of db and returns
// it as a repositories.OnboardingConfig.
func NewOnboardingConfig(db *gorm.DB) repositories.OnboardingConfig {
	repo := new(OnboardingConfigGORM)
	repo.GenericEntityRepository = NewGenericEntityRepository[types.OnboardingConfig](db)
	return repo
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package gorm
import (
"gitlab.com/nunet/device-management-service/types"
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
)
// MachineResourcesGORM is the GORM-backed repository for types.MachineResources.
// It embeds the generic entity repository for that type.
type MachineResourcesGORM struct {
	repositories.GenericEntityRepository[types.MachineResources]
}

// NewMachineResources builds a MachineResourcesGORM on top of db and returns
// it as a repositories.MachineResources.
func NewMachineResources(db *gorm.DB) repositories.MachineResources {
	repo := new(MachineResourcesGORM)
	repo.GenericEntityRepository = NewGenericEntityRepository[types.MachineResources](db)
	return repo
}
// FreeResourcesGORM is the GORM-backed repository for types.FreeResources.
// It embeds the generic entity repository for that type.
type FreeResourcesGORM struct {
	repositories.GenericEntityRepository[types.FreeResources]
}

// NewFreeResources builds a FreeResourcesGORM on top of db and returns it as
// a repositories.FreeResources.
func NewFreeResources(db *gorm.DB) repositories.FreeResources {
	repo := new(FreeResourcesGORM)
	repo.GenericEntityRepository = NewGenericEntityRepository[types.FreeResources](db)
	return repo
}
// OnboardedResourcesGORM is the GORM-backed repository for types.OnboardedResources.
// It embeds the generic entity repository for that type.
type OnboardedResourcesGORM struct {
	repositories.GenericEntityRepository[types.OnboardedResources]
}

// NewOnboardedResources builds an OnboardedResourcesGORM on top of db and
// returns it as a repositories.OnboardedResources.
func NewOnboardedResources(db *gorm.DB) repositories.OnboardedResources {
	repo := new(OnboardedResourcesGORM)
	repo.GenericEntityRepository = NewGenericEntityRepository[types.OnboardedResources](db)
	return repo
}
// ResourceAllocationGORM is the GORM-backed repository for types.ResourceAllocation.
// It embeds the generic repository for that type.
type ResourceAllocationGORM struct {
	repositories.GenericRepository[types.ResourceAllocation]
}

// NewResourceAllocation builds a ResourceAllocationGORM on top of db and
// returns it as a repositories.ResourceAllocation.
func NewResourceAllocation(db *gorm.DB) repositories.ResourceAllocation {
	repo := new(ResourceAllocationGORM)
	repo.GenericRepository = NewGenericRepository[types.ResourceAllocation](db)
	return repo
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package gorm
import (
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/observability"
"gitlab.com/nunet/device-management-service/types"
)
// StorageVolumeGORM is the GORM-backed repository for types.StorageVolume.
// It embeds the generic repository for that type.
type StorageVolumeGORM struct {
	repositories.GenericRepository[types.StorageVolume]
}

// NewStorageVolume builds a StorageVolumeGORM on top of db and returns it as
// a repositories.StorageVolume. The construction is traced and logged.
func NewStorageVolume(db *gorm.DB) repositories.StorageVolume {
	stopTrace := observability.StartTrace("gorm_db_storage_volume_init_duration")
	defer stopTrace()

	logger.Infow("gorm_db_storage_volume_init_success", "repository", "StorageVolume")

	repo := new(StorageVolumeGORM)
	repo.GenericRepository = NewGenericRepository[types.StorageVolume](db)
	return repo
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package gorm
import (
"errors"
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
)
// handleDBError is a utility function that translates GORM database errors into custom repository errors.
//
// Mapping:
//   - nil                                   -> nil
//   - gorm.ErrRecordNotFound                -> repositories.ErrNotFound
//   - gorm.ErrInvalidData/Field/Value       -> repositories.ErrInvalidData
//   - repositories.ErrParsingModel          -> returned unchanged
//   - anything else                         -> repositories.ErrDatabase joined with err
//
// Matching uses errors.Is so that sentinel errors are still recognised when
// they arrive wrapped with additional context.
func handleDBError(err error) error {
	switch {
	case err == nil:
		return nil
	case errors.Is(err, gorm.ErrRecordNotFound):
		return repositories.ErrNotFound
	case errors.Is(err, gorm.ErrInvalidData),
		errors.Is(err, gorm.ErrInvalidField),
		errors.Is(err, gorm.ErrInvalidValue):
		return repositories.ErrInvalidData
	case errors.Is(err, repositories.ErrParsingModel):
		return err
	default:
		return errors.Join(repositories.ErrDatabase, err)
	}
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package repositories
import (
"fmt"
"reflect"
)
// UpdateField is a generic function that updates a field of a struct or a pointer to a struct.
// The function uses reflection to dynamically update the specified field of the input struct.
//
// When input is a pointer to a struct the pointed-to struct is modified in
// place; when input is a struct value, a modified copy is returned and the
// caller's value is left untouched.
//
// newValue is converted to the field's type before being assigned. A nil
// newValue resets pointer, interface, map, slice, channel and function fields
// to nil and is rejected for every other field kind.
//
// An error is returned when input is not a struct (or pointer to one), when
// the field does not exist or cannot be set (e.g. it is unexported), or when
// newValue cannot be converted to the field's type.
func UpdateField[T interface{}](input T, fieldName string, newValue interface{}) (T, error) {
	// Use reflection to get the struct's field
	val := reflect.ValueOf(input)
	if val.Kind() == reflect.Ptr {
		// If input is a pointer, get the underlying element
		val = val.Elem()
	} else {
		// If input is not a pointer, ensure it's addressable
		val = reflect.ValueOf(&input).Elem()
	}

	// Check if the input is a struct
	if val.Kind() != reflect.Struct {
		return input, fmt.Errorf("not a struct: %T", input)
	}

	// Get the field by name
	field := val.FieldByName(fieldName)
	if !field.IsValid() {
		return input, fmt.Errorf("field not found: %v", fieldName)
	}

	// Check if the field is settable
	if !field.CanSet() {
		return input, fmt.Errorf("field not settable: %v", fieldName)
	}

	// reflect.TypeOf(nil) is a nil Type, so an untyped nil must be handled
	// before any type comparison to avoid a panic.
	if newValue == nil {
		switch field.Kind() {
		case reflect.Ptr, reflect.Interface, reflect.Map, reflect.Slice, reflect.Chan, reflect.Func:
			field.Set(reflect.Zero(field.Type()))
			return input, nil
		default:
			return input, fmt.Errorf("incompatible conversion: nil -> %v", field.Type())
		}
	}

	// Check if types are compatible
	newType := reflect.TypeOf(newValue)
	if !newType.ConvertibleTo(field.Type()) {
		return input, fmt.Errorf(
			"incompatible conversion: %v -> %v; value: %v",
			newType, field.Type(), newValue,
		)
	}

	// Convert the new value to the field type and assign it
	field.Set(reflect.ValueOf(newValue).Convert(field.Type()))

	return input, nil
}
// IsEmptyValue checks if value represents a zero-value struct (or pointer to a zero-value struct) using reflection.
// The function is useful for determining if a struct or its pointer is empty, i.e., all fields have their zero-values.
//
// It returns true for an untyped nil, for a nil pointer of any type, for a
// pointer whose target is the zero value, and for any non-pointer value that
// equals the zero value of its type. It returns false otherwise.
func IsEmptyValue(value interface{}) bool {
	// Check if the value is nil
	if value == nil {
		return true
	}

	val := reflect.ValueOf(value)

	// If the value is a pointer, dereference it to get the underlying element.
	// A typed nil pointer has no element; calling IsZero on the resulting
	// invalid Value would panic, so treat it as empty here.
	if val.Kind() == reflect.Ptr {
		if val.IsNil() {
			return true
		}
		val = val.Elem()
	}

	// Check if the value is zero (empty) based on its kind
	return val.IsZero()
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package crypto
import (
"github.com/fivebinaries/go-cardano-serialization/address"
"github.com/fivebinaries/go-cardano-serialization/bip32"
"github.com/fivebinaries/go-cardano-serialization/network"
"github.com/tyler-smith/go-bip39"
"gitlab.com/nunet/device-management-service/types"
)
// harden adds 0x80000000 (2^31) to num. It is used for the hardened
// components of the key derivation path. The addition is plain uint32
// arithmetic and wraps on overflow.
func harden(num uint32) uint32 {
	const hardenedOffset uint32 = 0x80000000
	return hardenedOffset + num
}
// GetCardanoAddressAndMnemonic generates a new mnemonic from 256 bits of
// entropy and derives a Cardano mainnet base address from it.
//
// The account key is derived along harden(1852)/harden(1815)/harden(0); the
// payment key hash comes from child path 0/0 and the stake key hash from
// child path 2/0 of that account key.
//
// It returns an Account holding the mnemonic and the address string, or the
// error reported by the bip39 library when entropy or mnemonic generation
// fails.
func GetCardanoAddressAndMnemonic() (*types.Account, error) {
	var pair types.Account

	entropy, err := bip39.NewEntropy(256)
	if err != nil {
		return nil, err
	}

	mnemonic, err := bip39.NewMnemonic(entropy)
	if err != nil {
		return nil, err
	}
	pair.Mnemonic = mnemonic

	rootKey := bip32.FromBip39Entropy(
		entropy,
		[]byte{},
	)

	accountKey := rootKey.Derive(harden(1852)).Derive(harden(1815)).Derive(harden(0))

	utxoPubKey := accountKey.Derive(0).Derive(0).Public()
	utxoPubKeyHash := utxoPubKey.PublicKey().Hash()

	stakeKey := accountKey.Derive(2).Derive(0).Public()
	stakeKeyHash := stakeKey.PublicKey().Hash()

	addr := address.NewBaseAddress(
		network.MainNet(),
		&address.StakeCredential{
			Kind:    address.KeyStakeCredentialType,
			Payload: utxoPubKeyHash[:],
		},
		&address.StakeCredential{
			Kind:    address.KeyStakeCredentialType,
			Payload: stakeKeyHash[:],
		})

	pair.Address = addr.String()

	return &pair, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package crypto
import (
"fmt"
"gitlab.com/nunet/device-management-service/types"
)
// CreatePaymentAddress creates a new account (address plus mnemonic) for the
// given wallet type. The only accepted wallet value is "cardano"; any other
// value yields an "invalid wallet" error.
func CreatePaymentAddress(wallet string) (*types.Account, error) {
	switch wallet {
	case "cardano":
		account, genErr := GetCardanoAddressAndMnemonic()
		if genErr != nil {
			return nil, fmt.Errorf("could not generate %s address: %w", wallet, genErr)
		}
		return account, nil
	default:
		return nil, fmt.Errorf("invalid wallet")
	}
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package dms
import (
_ "embed"
"errors"
"fmt"
"os"
"path/filepath"
"time"
"github.com/libp2p/go-libp2p/core/crypto"
"github.com/multiformats/go-multiaddr"
"github.com/oschwald/geoip2-golang"
"github.com/spf13/afero"
"gitlab.com/nunet/device-management-service/api"
"gitlab.com/nunet/device-management-service/db"
gdb "gitlab.com/nunet/device-management-service/db/repositories/gorm"
"gitlab.com/nunet/device-management-service/dms/hardware"
"gitlab.com/nunet/device-management-service/dms/node"
"gitlab.com/nunet/device-management-service/dms/onboarding"
"gitlab.com/nunet/device-management-service/dms/resources"
"gitlab.com/nunet/device-management-service/internal"
backgroundtasks "gitlab.com/nunet/device-management-service/internal/background_tasks"
"gitlab.com/nunet/device-management-service/internal/config"
"gitlab.com/nunet/device-management-service/lib/crypto/keystore"
"gitlab.com/nunet/device-management-service/lib/did"
"gitlab.com/nunet/device-management-service/lib/ucan"
"gitlab.com/nunet/device-management-service/network/libp2p"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils"
)
const (
// DefaultContextName is the context name Run falls back to when it is called with an empty contextName.
DefaultContextName = "dms"
// UserContextName is the name of the user context.
// NOTE(review): not referenced in the visible part of this file — confirm its use against callers.
UserContextName = "user"
// KeystoreDir is joined onto the configured user directory by Run to locate the keystore.
KeystoreDir = "key/"
// CapstoreDir is joined onto the configured user directory by Run to locate the "<context>.cap" capability context files.
CapstoreDir = "cap/"
)
// geoLite2Country holds the bytes of data/GeoLite2-Country.mmdb, embedded at
// build time. Run passes it to geoip2.FromBytes.
//
//go:embed data/GeoLite2-Country.mmdb
var geoLite2Country []byte
// NewP2P is a stub that returns a zero-value libp2p.Libp2p. A real
// implementation is still needed so it can be handed to the routers whose
// handlers access it.
func NewP2P() libp2p.Libp2p {
	var stub libp2p.Libp2p
	return stub
}
// Run wires up and starts the DMS process for the given context and then
// blocks until a shutdown signal arrives.
//
// In order it: loads the embedded GeoIP database, opens the keystore under
// the user directory and loads (or generates and stores) the private key for
// contextName, connects to the database, builds the resource manager and the
// onboarding service, creates/initializes/starts the libp2p network, loads or
// creates the capability context file, starts the node and finally serves the
// REST API from a background goroutine.
//
// ksPassphrase is the keystore passphrase. An empty contextName falls back to
// DefaultContextName. Any setup failure is returned as an error. Once the
// service is running the function only leaves through os.Exit: with status 0
// after a clean stop, or with status 1 when a second shutdown signal arrives
// before the cleanup has finished.
//
// QUESTION(dms-initialization): should the db instance be constructed here?
func Run(ksPassphrase string, contextName string) error {
	if contextName == "" {
		contextName = DefaultContextName
	}

	gcfg := config.GetConfig()

	// load geoip2 database
	geoip2db, err := geoip2.FromBytes(geoLite2Country)
	if err != nil {
		return fmt.Errorf("unable to load geoip2 database: %w", err)
	}

	fs := afero.NewOsFs()

	keyStoreDir := filepath.Join(gcfg.UserDir, KeystoreDir)
	keyStore, err := keystore.New(fs, keyStoreDir)
	if err != nil {
		return fmt.Errorf("unable to create keystore: %w", err)
	}

	// Load the private key for this context, generating one on first use.
	var priv crypto.PrivKey
	ksPrivKey, err := keyStore.Get(contextName, ksPassphrase)
	if err != nil {
		if errors.Is(err, keystore.ErrKeyNotFound) {
			priv, err = GenerateAndStorePrivKey(keyStore, ksPassphrase, contextName)
			if err != nil {
				return fmt.Errorf("couldn't generate and store priv key into keystore: %w", err)
			}
		} else {
			return fmt.Errorf("failed to get private key from keystore; Error: %w", err)
		}
	} else {
		priv, err = ksPrivKey.PrivKey()
		if err != nil {
			return fmt.Errorf("unable to convert key from keystore to private key: %w", err)
		}
	}

	pubKey := priv.GetPublic()

	// Named "database" so the db package is not shadowed.
	database, err := db.ConnectDatabase(gcfg.WorkDir)
	if err != nil {
		return fmt.Errorf("unable to connect to database: %w", err)
	}

	hardwareManager := hardware.NewHardwareManager()

	repos := resources.ManagerRepos{
		OnboardedResources: gdb.NewOnboardedResources(database),
		ResourceAllocation: gdb.NewResourceAllocation(database),
	}
	resourceManager, err := resources.NewResourceManager(repos, hardwareManager)
	if err != nil {
		return fmt.Errorf("unable to create resource manager: %w", err)
	}

	onboardR := gdb.NewOnboardingConfig(database)
	onboard := onboarding.New(&onboarding.Config{
		Fs:              afero.Afero{Fs: fs},
		ConfigRepo:      onboardR,
		Hardware:        hardwareManager,
		ResourceManager: resourceManager,
		WorkDir:         gcfg.WorkDir,
		DatabasePath:    fmt.Sprintf("%s/nunet.db", gcfg.WorkDir),
	})

	var p2pNet *libp2p.Libp2p

	// Parse the configured bootstrap peers. Addresses that fail to parse are
	// logged and skipped instead of being stored as nil entries.
	bootstrapPeers := make([]multiaddr.Multiaddr, 0, len(gcfg.P2P.BootstrapPeers))
	for _, addr := range gcfg.P2P.BootstrapPeers {
		maddr, err := multiaddr.NewMultiaddr(addr)
		if err != nil {
			log.Errorf("skipping invalid bootstrap peer address %q: %s", addr, err)
			continue
		}
		bootstrapPeers = append(bootstrapPeers, maddr)
	}

	cfg := &types.Libp2pConfig{
		PrivateKey:              priv,
		BootstrapPeers:          bootstrapPeers,
		Rendezvous:              "nunet-test",
		Server:                  false,
		Scheduler:               backgroundtasks.NewScheduler(10),
		CustomNamespace:         "/nunet-dht-1/",
		ListenAddress:           gcfg.P2P.ListenAddress,
		PeerCountDiscoveryLimit: 40,
		Memory:                  gcfg.P2P.Memory,
		FileDescriptors:         gcfg.P2P.FileDescriptors,
	}

	p2p, err := libp2p.New(cfg, fs)
	if err != nil {
		return fmt.Errorf("unable to create libp2p instance: %w", err)
	}
	if err = p2p.Init(); err != nil {
		return fmt.Errorf("unable to initialize libp2p: %w", err)
	}
	if err = p2p.Start(); err != nil {
		return fmt.Errorf("unable to start libp2p: %w", err)
	}
	p2pNet = p2p

	trustCtx, err := did.NewTrustContextWithPrivateKey(priv)
	if err != nil {
		return fmt.Errorf("unable to create trust context: %w", err)
	}

	capStoreDir := filepath.Join(gcfg.UserDir, CapstoreDir)
	capStoreFile := filepath.Join(capStoreDir, fmt.Sprintf("%s.cap", contextName))

	// Load the capability context from disk, or create and persist a fresh
	// one. NOTE(review): any Stat error (not only "does not exist") takes the
	// create branch.
	var capCtx ucan.CapabilityContext
	if _, err := os.Stat(capStoreFile); err != nil {
		if err := fs.MkdirAll(capStoreDir, os.FileMode(0o700)); err != nil {
			return fmt.Errorf("unable to create capability context directory: %w", err)
		}

		// does not exist; create it
		rootDID := did.FromPublicKey(pubKey)
		capCtx, err = ucan.NewCapabilityContext(trustCtx, rootDID, nil, ucan.TokenList{}, ucan.TokenList{})
		if err != nil {
			return fmt.Errorf("unable to create capability context: %w", err)
		}

		// Save it!
		f, err := os.Create(capStoreFile)
		if err != nil {
			return fmt.Errorf("unable to create capability context file: %w", err)
		}
		err = ucan.SaveCapabilityContext(capCtx, f)
		_ = f.Close()
		if err != nil {
			return fmt.Errorf("unable to save capability context: %w", err)
		}
	} else {
		f, err := os.Open(capStoreFile)
		if err != nil {
			return fmt.Errorf("unable to open capability context: %w", err)
		}
		capCtx, err = ucan.LoadCapabilityContext(trustCtx, f)
		_ = f.Close()
		if err != nil {
			return fmt.Errorf("unable to load capability context: %w", err)
		}
	}

	trustCtx.Start(time.Hour)
	capCtx.Start(5 * time.Minute)

	hostLocation := node.HostGeolocation{
		HostCountry:   gcfg.HostCountry,
		HostCity:      gcfg.HostCity,
		HostContinent: gcfg.HostContinent,
	}
	portConfig := node.PortConfig{
		AvailableRangeFrom: gcfg.PortAvailableRangeFrom,
		AvailableRangeTo:   gcfg.PortAvailableRangeTo,
	}

	hostID := p2p.Host.ID().String()

	// Named "dmsNode" so the node package is not shadowed.
	dmsNode, err := node.New(onboard, capCtx, hostID, p2p, resourceManager, cfg.Scheduler, hardwareManager, geoip2db, hostLocation, portConfig)
	if err != nil {
		return fmt.Errorf("failed to create node: %w", err)
	}

	err = dmsNode.Start()
	if err != nil {
		return fmt.Errorf("failed to start node: %w", err)
	}

	// initialize rest api server
	restConfig := api.RESTServerConfig{
		P2P:        p2pNet,
		Onboarding: onboard,
		Resource:   resourceManager,
		MidW:       nil,
		Port:       gcfg.Rest.Port,
		Addr:       gcfg.Rest.Addr,
	}

	rServer := api.NewRESTServer(&restConfig)
	rServer.InitializeRoutes()

	go func() {
		err := rServer.Run()
		if err != nil {
			log.Fatal(err)
		}
	}()

	// wait for SIGINT or SIGTERM
	sig := <-internal.ShutdownChan

	// clean up; the goroutine uses its own error variables and only reads
	// sig, so it shares no mutable state with the code below.
	go func() {
		if err := dmsNode.Stop(); err != nil {
			log.Errorf("failed to stop node: %s", err)
		}
		if err := p2p.Stop(); err != nil {
			log.Errorf("failed to stop libp2p network: %s", err)
		}

		log.Infof("Shutting down after receiving %v...\n", sig)
		os.Exit(0)
	}()

	// A second signal while cleanup is still running forces an immediate exit.
	secondSig := <-internal.ShutdownChan
	log.Infof("Shutting down after receiving %v...\n", secondSig)
	os.Exit(1)

	return nil
}
// GenerateAndStorePrivKey generates a new Ed25519 key pair, marshals the
// private key and saves it in ks under keyID, protected by passphrase.
// The generated private key is returned on success.
func GenerateAndStorePrivKey(ks keystore.KeyStore, passphrase string, keyID string) (crypto.PrivKey, error) {
	privKey, _, genErr := crypto.GenerateKeyPair(crypto.Ed25519, 256)
	if genErr != nil {
		return nil, fmt.Errorf("unable to generate key pair: %w", genErr)
	}

	encoded, marshalErr := crypto.MarshalPrivateKey(privKey)
	if marshalErr != nil {
		return nil, fmt.Errorf("unable to marshal private key: %w", marshalErr)
	}

	if _, saveErr := ks.Save(keyID, encoded, passphrase); saveErr != nil {
		return nil, fmt.Errorf("unable to save private key into the keystore: %w", saveErr)
	}

	return privKey, nil
}
// ValidateOnboarding runs utils.ValidateAddress on oConf.PublicKey and logs
// two error lines when validation fails. It returns nothing; despite the
// "exiting DMS" log text it does not terminate the process itself.
func ValidateOnboarding(oConf *types.OnboardingConfig) {
	// Check 1: Check if payment address is valid
	if utils.ValidateAddress(oConf.PublicKey) == nil {
		return
	}

	log.Errorf("the payment address %s is not valid", oConf.PublicKey)
	log.Error("exiting DMS")
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package cpu
import (
"fmt"
"time"
"github.com/shirou/gopsutil/v4/cpu"
"gitlab.com/nunet/device-management-service/types"
)
// GetUsage returns the CPU usage for the system.
//
// It samples the aggregate CPU utilisation with cpu.Percent over a one-second
// interval, fetches the CPU description from GetCPU and replaces its Cores
// field with the number of cores in use (total cores * usage percent / 100).
// An error is returned when either lookup fails or no usage sample comes back.
func GetUsage() (types.CPU, error) {
	cpuUsage, err := cpu.Percent(time.Second, false)
	if err != nil {
		return types.CPU{}, fmt.Errorf("failed to get CPU usage: %w", err)
	}
	// Guard the index below: an empty result would otherwise panic.
	if len(cpuUsage) == 0 {
		return types.CPU{}, fmt.Errorf("failed to get CPU usage: no samples returned")
	}

	cpuInfo, err := GetCPU()
	if err != nil {
		return types.CPU{}, fmt.Errorf("failed to get CPU info: %w", err)
	}

	usedCores := float64(cpuInfo.Cores) * cpuUsage[0] / 100
	cpuInfo.Cores = float32(usedCores)
	return cpuInfo, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package cpu
import (
"fmt"
"github.com/shirou/gopsutil/v4/cpu"
"gitlab.com/nunet/device-management-service/types"
)
// GetCPU returns the CPU information for the system.
//
// Cores is the number of entries reported by cpu.Info and ClockSpeed is the
// first entry's Mhz value multiplied by 1,000,000. An error is returned when
// cpu.Info fails or reports no entries.
func GetCPU() (types.CPU, error) {
	cores, err := cpu.Info()
	if err != nil {
		return types.CPU{}, fmt.Errorf("failed to get CPU info: %w", err)
	}
	// Guard the cores[0] access below: an empty result would otherwise panic.
	if len(cores) == 0 {
		return types.CPU{}, fmt.Errorf("failed to get CPU info: no CPU entries reported")
	}

	return types.CPU{
		Cores:      float32(len(cores)),
		ClockSpeed: cores[0].Mhz * 1000000,
	}, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package hardware
import (
"context"
"fmt"
"github.com/shirou/gopsutil/v4/disk"
"gitlab.com/nunet/device-management-service/types"
)
// GetDisk reports the combined total size of the partitions returned by
// disk.PartitionsWithContext. The sum of each partition's Total is stored in
// the Size field of the result. Any partition or usage lookup failure aborts
// with an error.
func GetDisk() (types.Disk, error) {
	parts, err := disk.PartitionsWithContext(context.Background(), false)
	if err != nil {
		return types.Disk{}, fmt.Errorf("failed to get partitions: %w", err)
	}

	var total uint64
	for _, part := range parts {
		stat, statErr := disk.UsageWithContext(context.Background(), part.Mountpoint)
		if statErr != nil {
			return types.Disk{}, fmt.Errorf("failed to get disk usage: %w", statErr)
		}
		total += stat.Total
	}

	result := types.Disk{Size: float64(total)}
	return result, nil
}
// GetDiskUsage reports the combined used space of the partitions returned by
// disk.PartitionsWithContext. The sum of each partition's Used is stored in
// the Size field of the result. Any partition or usage lookup failure aborts
// with an error.
func GetDiskUsage() (types.Disk, error) {
	parts, err := disk.PartitionsWithContext(context.Background(), false)
	if err != nil {
		return types.Disk{}, fmt.Errorf("failed to get partitions: %w", err)
	}

	var used uint64
	for _, part := range parts {
		stat, statErr := disk.UsageWithContext(context.Background(), part.Mountpoint)
		if statErr != nil {
			return types.Disk{}, fmt.Errorf("failed to get disk usage: %w", statErr)
		}
		used += stat.Used
	}

	result := types.Disk{Size: float64(used)}
	return result, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
//go:build linux && amd64
package gpu
import (
"fmt"
"os/exec"
"regexp"
"strconv"
"gitlab.com/nunet/device-management-service/types"
)
// runROCmSmiCommand runs rocm-smi asking for GPU ids, product names, VRAM
// memory info and bus info, and returns its combined stdout/stderr as a
// string. A failure to run the command is reported as an error.
func runROCmSmiCommand() (string, error) {
	args := []string{"--showid", "--showproductname", "--showmeminfo", "vram", "--showbus"}
	raw, err := exec.Command("rocm-smi", args...).CombinedOutput()
	if err != nil {
		return "", fmt.Errorf("AMD ROCm not installed, initialized, or configured (reboot recommended for newly installed AMD GPU Drivers): %s", err)
	}
	return string(raw), nil
}
// parseRegex compiles pattern and returns every match in output together
// with its capture groups. It panics when pattern is not a valid regular
// expression and returns nil when nothing matches.
func parseRegex(pattern, output string) [][]string {
	return regexp.MustCompile(pattern).FindAllStringSubmatch(output, -1)
}
// getAMDGPUPCIAddress collects the value of every "GPU[n] : PCI Bus:" line
// in output, in order of appearance. It returns an error when no such line
// is present.
func getAMDGPUPCIAddress(output string) ([]string, error) {
	found := parseRegex(`GPU\[\d+\]\s+: PCI Bus:\s+([^\n]+)`, output)
	if len(found) == 0 {
		return nil, fmt.Errorf("find PCI bus IDs in the output")
	}

	addrs := make([]string, 0, len(found))
	for _, groups := range found {
		addrs = append(addrs, groups[1])
	}
	return addrs, nil
}
// getAMDGPUTotalVRAM collects the number on every
// "GPU[n] : VRAM Total Memory (B):" line in output, in order of appearance.
// The values are returned as parsed, with no unit conversion. An error is
// returned when no such line is present or a number fails to parse.
func getAMDGPUTotalVRAM(output string) ([]float64, error) {
	found := parseRegex(`GPU\[\d+\]\s+: VRAM Total Memory \(B\):\s+(\d+)`, output)
	if len(found) == 0 {
		return nil, fmt.Errorf("find total VRAM in the output")
	}

	totals := make([]float64, 0, len(found))
	for idx, groups := range found {
		value, parseErr := strconv.ParseFloat(groups[1], 64)
		if parseErr != nil {
			return nil, fmt.Errorf("parse total VRAM for GPU %d: %s", idx, parseErr)
		}
		totals = append(totals, value)
	}
	return totals, nil
}
// getAMDGPUUsedVRAM collects the number on every
// "GPU[n] : VRAM Total Used Memory (B):" line in output, in order of
// appearance. The values are returned as parsed, with no unit conversion. An
// error is returned when no such line is present or a number fails to parse.
func getAMDGPUUsedVRAM(output string) ([]float64, error) {
	found := parseRegex(`GPU\[\d+\]\s+: VRAM Total Used Memory \(B\):\s+(\d+)`, output)
	if len(found) == 0 {
		return nil, fmt.Errorf("find used VRAM in the output")
	}

	used := make([]float64, 0, len(found))
	for idx, groups := range found {
		value, parseErr := strconv.ParseFloat(groups[1], 64)
		if parseErr != nil {
			return nil, fmt.Errorf("parse used VRAM for GPU %d: %s", idx, parseErr)
		}
		used = append(used, value)
	}
	return used, nil
}
// getAMDGPUName collects the value of every "GPU[n] : Card Series:" line in
// output, in order of appearance. It returns an error when no such line is
// present.
func getAMDGPUName(output string) ([]string, error) {
	found := parseRegex(`GPU\[\d+\]\s+: Card Series:\s+([^\n]+)`, output)
	if len(found) == 0 {
		return nil, fmt.Errorf("find GPU names in the output")
	}

	series := make([]string, 0, len(found))
	for _, groups := range found {
		series = append(series, groups[1])
	}
	return series, nil
}
// getAMDGPUs runs rocm-smi once and builds one types.GPU per reported AMD
// GPU, carrying its model name, total VRAM and PCI address. It fails when
// the command fails, when any of the three details cannot be extracted, or
// when the three lists differ in length.
func getAMDGPUs() ([]types.GPU, error) {
	smiOutput, err := runROCmSmiCommand()
	if err != nil {
		return nil, fmt.Errorf("failed to get AMD GPU information: %s", err)
	}

	models, err := getAMDGPUName(smiOutput)
	if err != nil {
		return nil, fmt.Errorf("failed to get AMD GPU names: %s", err)
	}

	vramTotals, err := getAMDGPUTotalVRAM(smiOutput)
	if err != nil {
		return nil, fmt.Errorf("failed to get AMD GPU total VRAM: %s", err)
	}

	busAddrs, err := getAMDGPUPCIAddress(smiOutput)
	if err != nil {
		return nil, fmt.Errorf("failed to get AMD GPU PCI addresses: %s", err)
	}

	// All three lists must describe the same set of GPUs.
	count := len(models)
	if count != len(vramTotals) || count != len(busAddrs) {
		return nil, fmt.Errorf("failed to get AMD GPU information: mismatched GPU details")
	}

	detected := make([]types.GPU, count)
	for idx, model := range models {
		detected[idx] = types.GPU{
			Model:      model,
			VRAM:       vramTotals[idx],
			Vendor:     types.GPUVendorAMDATI,
			PCIAddress: busAddrs[idx],
		}
	}
	return detected, nil
}
// getAMDGPUUsage builds one types.GPU per AMD GPU whose VRAM field carries the
// currently used VRAM (as returned by getAMDGPUUsedVRAM) instead of the total.
// Model name and PCI address are parsed from the same command output and
// matched up by position.
//
// It fails if the command fails, if any attribute cannot be parsed, or if the
// attribute lists differ in length.
func getAMDGPUUsage() ([]types.GPU, error) {
	smiOutput, err := runROCmSmiCommand()
	if err != nil {
		return nil, fmt.Errorf("failed to get AMD GPU information: %s", err)
	}

	models, err := getAMDGPUName(smiOutput)
	if err != nil {
		return nil, fmt.Errorf("failed to get AMD GPU names: %s", err)
	}

	addresses, err := getAMDGPUPCIAddress(smiOutput)
	if err != nil {
		return nil, fmt.Errorf("failed to get AMD GPU PCI addresses: %s", err)
	}

	used, err := getAMDGPUUsedVRAM(smiOutput)
	if err != nil {
		return nil, fmt.Errorf("failed to get AMD GPU used VRAM: %s", err)
	}

	// The lists are correlated by index, so they must all have the same length.
	count := len(models)
	if count != len(used) || count != len(addresses) {
		return nil, fmt.Errorf("failed to get AMD GPU information: mismatched GPU details")
	}

	gpus := make([]types.GPU, len(used))
	for i := range used {
		gpus[i] = types.GPU{
			Model:      models[i],
			VRAM:       used[i],
			Vendor:     types.GPUVendorAMDATI,
			PCIAddress: addresses[i],
		}
	}
	return gpus, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
//go:build linux && amd64
package gpu
import (
"fmt"
"sync"
"gitlab.com/nunet/device-management-service/types"
)
var (
	// gpuInfoCache stores, per vendor, the GPU list returned by the first
	// successful fetch performed in getGPUsHelper. It is created lazily there.
	gpuInfoCache map[types.GPUVendor][]types.GPU
	// mu is held by getGPUsHelper for its whole run and guards gpuInfoCache.
	mu sync.Mutex
)
// GetGPUs returns the GPU information for the requested vendors, or for every
// supported vendor (Intel, NVIDIA, AMD/ATI) when none is given. The lookup,
// caching and index assignment are delegated to getGPUsHelper.
func GetGPUs(vendors ...types.GPUVendor) ([]types.GPU, error) {
	fetchers := map[types.GPUVendor]func() ([]types.GPU, error){
		types.GPUVendorIntel:  getIntelGPUs,
		types.GPUVendorNvidia: getNVIDIAGPUs,
		types.GPUVendorAMDATI: getAMDGPUs,
	}
	return getGPUsHelper(assignIndexToGPUs, fetchers, vendors...)
}
// GetGPUUsage returns the current GPU usage for the requested vendors.
// If no vendors are provided, usage is reported for every supported vendor.
//
// Usage is a live measurement, so every call queries the vendor tooling
// again. This function deliberately does not go through getGPUsHelper: that
// helper stores results in gpuInfoCache keyed by vendor only, which would
// freeze usage at its first reading and mix usage entries with the static GPU
// information cached for GetGPUs.
//
// A vendor whose query fails is skipped (best effort); requesting a vendor
// without a usage function yields an error. The returned GPUs are indexed by
// assignIndexToGPUs.
func GetGPUUsage(vendors ...types.GPUVendor) ([]types.GPU, error) {
	fetchFuncs := map[types.GPUVendor]func() ([]types.GPU, error){
		types.GPUVendorIntel:  getIntelGPUUsage,
		types.GPUVendorNvidia: getNVIDIAGPUUsage,
		types.GPUVendorAMDATI: getAMDGPUUsage,
	}

	requested := vendors
	if len(requested) == 0 {
		// Fixed order so that the indices assigned below are stable across
		// calls of this function.
		requested = []types.GPUVendor{
			types.GPUVendorIntel,
			types.GPUVendorNvidia,
			types.GPUVendorAMDATI,
		}
	}

	// Keep vendor queries serialised with the ones issued by getGPUsHelper.
	mu.Lock()
	defer mu.Unlock()

	var gpus []types.GPU
	for _, vendor := range requested {
		fetchFunc, ok := fetchFuncs[vendor]
		if !ok {
			return nil, fmt.Errorf("unsupported GPU vendor: %v", vendor)
		}
		gpuList, err := fetchFunc()
		if err != nil {
			// TODO: log a warning here
			continue
		}
		gpus = append(gpus, gpuList...)
	}

	// Note: The index is internal to dms and is not the same as the device index
	return assignIndexToGPUs(gpus), nil
}
// getGPUsHelper gathers GPU lists from the per-vendor fetch functions and
// passes the combined list through assignFunc.
//
// When vendors is empty every key of fetchFuncs is queried (in map iteration
// order); otherwise only the listed vendors are queried and a vendor without
// an entry in fetchFuncs yields an error. A fetch function that fails is
// skipped silently.
//
// Successful results are stored in the package-level gpuInfoCache keyed by
// vendor and are served from there on later calls. The cache is shared by
// every caller regardless of which fetchFuncs are passed in.
func getGPUsHelper(assignFunc func([]types.GPU) []types.GPU, fetchFuncs map[types.GPUVendor]func() ([]types.GPU, error), vendors ...types.GPUVendor) ([]types.GPU, error) {
	mu.Lock()
	defer mu.Unlock()

	if gpuInfoCache == nil {
		gpuInfoCache = make(map[types.GPUVendor][]types.GPU)
	}

	// No specific vendor requested: query every vendor we have a fetcher for.
	requested := vendors
	if len(requested) == 0 {
		requested = make([]types.GPUVendor, 0, len(fetchFuncs))
		for vendor := range fetchFuncs {
			requested = append(requested, vendor)
		}
	}

	var collected []types.GPU
	for _, vendor := range requested {
		fetch, supported := fetchFuncs[vendor]
		if !supported {
			return nil, fmt.Errorf("unsupported GPU vendor: %v", vendor)
		}

		list, cached := gpuInfoCache[vendor]
		if !cached {
			fetched, err := fetch()
			if err != nil {
				// TODO: log a warning here
				continue
			}
			gpuInfoCache[vendor] = fetched
			list = fetched
		}
		collected = append(collected, list...)
	}

	// Assign index to GPUs and return
	// Note: The index is internal to dms and is not the same as the device index
	return assignFunc(collected), nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
//go:build linux && amd64
package gpu
import (
"fmt"
"os/exec"
"regexp"
"strconv"
"strings"
"gitlab.com/nunet/device-management-service/types"
)
// runXpuSmiCommand executes "xpu-smi" with args and returns its combined
// stdout/stderr as a string. On failure the output is discarded and an error
// describing the failure is returned.
func runXpuSmiCommand(args ...string) (string, error) {
	raw, err := exec.Command("xpu-smi", args...).CombinedOutput()
	if err != nil {
		return "", fmt.Errorf("xpu-smi command failed: %s", err)
	}
	return string(raw), nil
}
// getIntelGPUDeviceIDs returns the numeric value of every
// "| Device ID | <n> |" table row found in output (case-insensitive), in order
// of appearance. It fails when no such row exists.
func getIntelGPUDeviceIDs(output string) ([]string, error) {
	rowPattern := regexp.MustCompile(`(?i)\| Device ID\s+\|\s+(\d+)\s+\|`)

	rows := rowPattern.FindAllStringSubmatch(output, -1)
	if len(rows) == 0 {
		return nil, fmt.Errorf("failed to find any Intel GPUs")
	}

	ids := make([]string, 0, len(rows))
	for _, groups := range rows {
		// groups[1] is the captured device ID.
		ids = append(ids, groups[1])
	}
	return ids, nil
}
// getIntelGPUDiscoveryInfo runs "xpu-smi discovery -d <deviceID>" and returns
// the device name (whitespace-trimmed) and the "Memory Physical Size" value.
// The size is reported by the tool in MiB and is passed through
// types.ConvertMibToBytes before being returned.
//
// It fails if the command fails, if either field is missing from the output,
// or if the size is not a valid number.
func getIntelGPUDiscoveryInfo(deviceID string) (string, float64, error) {
	discovery, err := runXpuSmiCommand("discovery", "-d", deviceID)
	if err != nil {
		return "", 0, fmt.Errorf("failed to get discovery info for Intel GPU %s: %s", deviceID, err)
	}

	// Extract the GPU name and total memory
	nameFields := regexp.MustCompile(`(?i)Device Name:\s+([^\n|]+)`).FindStringSubmatch(discovery)
	sizeFields := regexp.MustCompile(`(?i)Memory Physical Size:\s+([^\s]+)\s+MiB`).FindStringSubmatch(discovery)
	if nameFields == nil || sizeFields == nil {
		return "", 0, fmt.Errorf("failed to parse discovery info for Intel GPU %s", deviceID)
	}

	sizeMiB, err := strconv.ParseFloat(sizeFields[1], 64)
	if err != nil {
		return "", 0, fmt.Errorf("failed to parse total memory for Intel GPU %s: %s", deviceID, err)
	}

	return strings.TrimSpace(nameFields[1]), types.ConvertMibToBytes(sizeMiB), nil
}
// getIntelGPUPCIAddress runs "xpu-smi discovery -d <deviceID>" and returns the
// whitespace-trimmed value of the "PCI BDF Address" field. It fails if the
// command fails or the field is absent.
func getIntelGPUPCIAddress(deviceID string) (string, error) {
	discovery, err := runXpuSmiCommand("discovery", "-d", deviceID)
	if err != nil {
		return "", fmt.Errorf("failed to get PCI address for Intel GPU %s: %s", deviceID, err)
	}

	// Extract the PCI bus address
	fields := regexp.MustCompile(`(?i)PCI\s+BDF\s+Address:\s+([^\n|]+)`).FindStringSubmatch(discovery)
	if fields == nil {
		return "", fmt.Errorf("failed to parse PCI bus address for Intel GPU %s", deviceID)
	}
	return strings.TrimSpace(fields[1]), nil
}
// getIntelGPUUsedMemory runs "xpu-smi stats -d <deviceID>" and returns the
// "GPU Memory Used (MiB)" value after passing it through
// types.ConvertMibToBytes. It fails if the command fails, the row is missing,
// or the value is not a valid number.
func getIntelGPUUsedMemory(deviceID string) (float64, error) {
	stats, err := runXpuSmiCommand("stats", "-d", deviceID)
	if err != nil {
		return 0, fmt.Errorf("failed to get stats for Intel GPU %s: %s", deviceID, err)
	}

	// Extract the used memory
	fields := regexp.MustCompile(`(?i)GPU Memory Used \(MiB\)\s+\|\s+(\d+)\s+\|`).FindStringSubmatch(stats)
	if fields == nil {
		return 0, fmt.Errorf("failed to parse used memory for Intel GPU %s", deviceID)
	}

	usedMiB, err := strconv.ParseFloat(fields[1], 64)
	if err != nil {
		return 0, fmt.Errorf("failed to parse used memory for Intel GPU %s: %s", deviceID, err)
	}
	return types.ConvertMibToBytes(usedMiB), nil
}
// getIntelGPUs lists the Intel GPUs reported by "xpu-smi health -l" and, for
// each device ID, queries its name, total memory and PCI address. The first
// failing query aborts the whole call.
func getIntelGPUs() ([]types.GPU, error) {
	// Get the list of Intel GPU devices
	listing, err := runXpuSmiCommand("health", "-l")
	if err != nil {
		return nil, fmt.Errorf("xpu-smi not installed, initialized, or configured: %s", err)
	}

	ids, err := getIntelGPUDeviceIDs(listing)
	if err != nil {
		return nil, err
	}

	gpus := make([]types.GPU, len(ids))
	for i, id := range ids {
		// Total memory is the value returned by getIntelGPUDiscoveryInfo.
		model, totalMemory, err := getIntelGPUDiscoveryInfo(id)
		if err != nil {
			return nil, err
		}

		address, err := getIntelGPUPCIAddress(id)
		if err != nil {
			return nil, fmt.Errorf("get PCI address for Intel GPU %s: %s", id, err)
		}

		gpus[i] = types.GPU{
			Model:      model,
			VRAM:       totalMemory,
			Vendor:     types.GPUVendorIntel,
			PCIAddress: address,
		}
	}
	return gpus, nil
}
// getIntelGPUUsage lists the Intel GPUs reported by "xpu-smi health -l" and,
// for each device ID, queries its name, used memory and PCI address. The VRAM
// field of each returned GPU carries the used memory. The first failing query
// aborts the whole call.
func getIntelGPUUsage() ([]types.GPU, error) {
	listing, err := runXpuSmiCommand("health", "-l")
	if err != nil {
		return nil, fmt.Errorf("xpu-smi not installed, initialized, or configured: %s", err)
	}

	ids, err := getIntelGPUDeviceIDs(listing)
	if err != nil {
		return nil, err
	}

	gpus := make([]types.GPU, len(ids))
	for i, id := range ids {
		// Only the name is needed from the discovery info here.
		model, _, err := getIntelGPUDiscoveryInfo(id)
		if err != nil {
			return nil, err
		}

		used, err := getIntelGPUUsedMemory(id)
		if err != nil {
			return nil, err
		}

		address, err := getIntelGPUPCIAddress(id)
		if err != nil {
			return nil, fmt.Errorf("get PCI address for Intel GPU %s: %s", id, err)
		}

		gpus[i] = types.GPU{
			PCIAddress: address,
			Model:      model,
			VRAM:       used,
			Vendor:     types.GPUVendorIntel,
		}
	}
	return gpus, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
//go:build linux && amd64
package gpu
import (
"errors"
"fmt"
"strings"
"github.com/NVIDIA/go-nvml/pkg/nvml"
"gitlab.com/nunet/device-management-service/types"
)
// initNVML calls nvml.Init and converts any result other than nvml.SUCCESS
// into an error carrying the NVML error string.
func initNVML() error {
	if ret := nvml.Init(); !errors.Is(ret, nvml.SUCCESS) {
		return fmt.Errorf("NVIDIA Management Library not installed, initialized, or configured (reboot recommended for newly installed NVIDIA GPU drivers): %s", nvml.ErrorString(ret))
	}
	return nil
}
// shutdownNVML shuts down the NVIDIA Management Library.
// The result of nvml.Shutdown is deliberately discarded; callers in this file
// invoke it via defer, where nothing could act on a failure.
func shutdownNVML() {
	_ = nvml.Shutdown()
}
// getNVIDIADeviceCount reports how many devices nvml.DeviceGetCount sees, or an
// error carrying the NVML error string when the call does not succeed.
func getNVIDIADeviceCount() (int, error) {
	count, ret := nvml.DeviceGetCount()
	if errors.Is(ret, nvml.SUCCESS) {
		return count, nil
	}
	return 0, fmt.Errorf("failed to get device count: %s", nvml.ErrorString(ret))
}
// getNVIDIADeviceHandle looks up the NVML device at the given index. On
// failure it returns a nil device and an error carrying the NVML error string.
func getNVIDIADeviceHandle(index int) (nvml.Device, error) {
	handle, ret := nvml.DeviceGetHandleByIndex(index)
	if errors.Is(ret, nvml.SUCCESS) {
		return handle, nil
	}
	return nil, fmt.Errorf("failed to get device handle for device %d: %s", index, nvml.ErrorString(ret))
}
// getNVIDIADeviceName returns device.GetName(), or an error carrying the NVML
// error string when the call does not succeed.
func getNVIDIADeviceName(device nvml.Device) (string, error) {
	model, ret := device.GetName()
	if errors.Is(ret, nvml.SUCCESS) {
		return model, nil
	}
	return "", fmt.Errorf("failed to get name for device: %s", nvml.ErrorString(ret))
}
// getNVIDIADeviceMemory returns device.GetMemoryInfo(), or a zero nvml.Memory
// and an error carrying the NVML error string when the call does not succeed.
func getNVIDIADeviceMemory(device nvml.Device) (nvml.Memory, error) {
	info, ret := device.GetMemoryInfo()
	if errors.Is(ret, nvml.SUCCESS) {
		return info, nil
	}
	return nvml.Memory{}, fmt.Errorf("failed to get NVIDIA GPU memory info: %s", nvml.ErrorString(ret))
}
// convertBusID converts a NUL-terminated bus ID character array into a PCI
// address string.
//
// The array is read as a C string: conversion stops at the first NUL byte, so
// anything stored after the terminator is ignored (trimming only trailing NULs
// would let such bytes leak into the address).
//
// If the resulting string starts with an all-zero 8-digit domain
// ("00000000:"), the domain is shortened to the 4-digit form, producing
// "0000:XX:YY.Z". Any other string is returned unchanged.
func convertBusID(busID [32]int8) string {
	var sb strings.Builder
	sb.Grow(len(busID))
	for _, c := range busID {
		if c == 0 {
			break
		}
		sb.WriteByte(byte(c))
	}
	busIDStr := sb.String()

	// Check if the string starts with extra zero groups and correct the format
	if strings.HasPrefix(busIDStr, "00000000:") {
		// Drop the first eight zeros and re-add four; the ':' at index 8 is kept.
		return "0000" + busIDStr[8:]
	}
	return busIDStr
}
// getNVIDIAPCIAddress reads the device's PCI info and returns its bus ID
// formatted by convertBusID, or an error carrying the NVML error string when
// the query does not succeed.
func getNVIDIAPCIAddress(device nvml.Device) (string, error) {
	info, ret := device.GetPciInfo()
	if errors.Is(ret, nvml.SUCCESS) {
		return convertBusID(info.BusId), nil
	}
	return "", fmt.Errorf("failed to get PCI info for device: %s", nvml.ErrorString(ret))
}
// getNVIDIAGPUs initialises NVML, enumerates every device and returns one
// types.GPU per device with its name, PCI address and total memory
// (memory.Total). NVML is shut down again before returning. The first failing
// query aborts the whole call.
func getNVIDIAGPUs() ([]types.GPU, error) {
	if err := initNVML(); err != nil {
		return nil, err
	}
	defer shutdownNVML()

	total, err := getNVIDIADeviceCount()
	if err != nil {
		return nil, err
	}

	var found []types.GPU
	for idx := 0; idx < total; idx++ {
		dev, err := getNVIDIADeviceHandle(idx)
		if err != nil {
			return nil, err
		}

		model, err := getNVIDIADeviceName(dev)
		if err != nil {
			return nil, err
		}

		mem, err := getNVIDIADeviceMemory(dev)
		if err != nil {
			return nil, err
		}

		address, err := getNVIDIAPCIAddress(dev)
		if err != nil {
			return nil, err
		}

		found = append(found, types.GPU{
			PCIAddress: address,
			Model:      model,
			VRAM:       float64(mem.Total),
			Vendor:     types.GPUVendorNvidia,
		})
	}
	return found, nil
}
// getNVIDIAGPUUsage initialises NVML, enumerates every device and returns one
// types.GPU per device whose VRAM field carries the used memory (memory.Used).
// NVML is shut down again before returning. The first failing query aborts
// the whole call.
func getNVIDIAGPUUsage() ([]types.GPU, error) {
	if err := initNVML(); err != nil {
		return nil, err
	}
	defer shutdownNVML()

	total, err := getNVIDIADeviceCount()
	if err != nil {
		return nil, err
	}

	var found []types.GPU
	for idx := 0; idx < total; idx++ {
		dev, err := getNVIDIADeviceHandle(idx)
		if err != nil {
			return nil, err
		}

		model, err := getNVIDIADeviceName(dev)
		if err != nil {
			return nil, err
		}

		mem, err := getNVIDIADeviceMemory(dev)
		if err != nil {
			return nil, err
		}

		address, err := getNVIDIAPCIAddress(dev)
		if err != nil {
			return nil, err
		}

		found = append(found, types.GPU{
			PCIAddress: address,
			Model:      model,
			VRAM:       float64(mem.Used),
			Vendor:     types.GPUVendorNvidia,
		})
	}
	return found, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package gpu
import (
"gitlab.com/nunet/device-management-service/types"
)
// assignIndexToGPUs sets the Index field of each GPU to its position in the
// slice (starting at 0). The slice is modified in place and returned.
func assignIndexToGPUs(gpus []types.GPU) []types.GPU {
	for pos := 0; pos < len(gpus); pos++ {
		gpus[pos].Index = pos
	}
	return gpus
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package hardware
import (
"fmt"
"sync"
"gitlab.com/nunet/device-management-service/dms/hardware/cpu"
"gitlab.com/nunet/device-management-service/dms/hardware/gpu"
"gitlab.com/nunet/device-management-service/types"
)
// defaultHardwareManager manages the machine's hardware resources.
// It implements types.HardwareManager.
type defaultHardwareManager struct {
	// machineResources caches the result of the first successful
	// GetMachineResources call; it is nil until then.
	machineResources *types.MachineResources
	// mu guards machineResources.
	mu sync.Mutex
}
// NewHardwareManager creates a new instance of defaultHardwareManager.
// No hardware is probed here; the manager starts with an empty cache.
func NewHardwareManager() types.HardwareManager {
	return &defaultHardwareManager{}
}
// Compile-time check that *defaultHardwareManager satisfies types.HardwareManager.
var _ types.HardwareManager = (*defaultHardwareManager)(nil)
// GetMachineResources returns the resources of the machine in a thread-safe manner.
//
// The hardware is probed on the first successful call (CPU, RAM, GPUs, disk,
// in that order) and the result is cached; later calls return the cached
// value. If any probe fails, nothing is cached and the error is returned.
//
// The returned value is a snapshot: its GPU slice is a copy, so callers that
// modify the result in place cannot alter the cached resources.
func (m *defaultHardwareManager) GetMachineResources() (types.MachineResources, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	if m.machineResources == nil {
		cpuDetails, err := cpu.GetCPU()
		if err != nil {
			return types.MachineResources{}, fmt.Errorf("failed to get CPU: %w", err)
		}

		ram, err := GetRAM()
		if err != nil {
			return types.MachineResources{}, fmt.Errorf("failed to get RAM: %w", err)
		}

		gpus, err := gpu.GetGPUs()
		if err != nil {
			return types.MachineResources{}, fmt.Errorf("failed to get GPUs: %w", err)
		}

		diskDetails, err := GetDisk()
		if err != nil {
			return types.MachineResources{}, fmt.Errorf("failed to get Disk: %w", err)
		}

		m.machineResources = &types.MachineResources{
			Resources: types.Resources{
				CPU:  cpuDetails,
				RAM:  ram,
				Disk: diskDetails,
				GPUs: gpus,
			},
		}
	}

	// Copying the struct alone would still share the GPU slice's backing array
	// with the cache, so detach it explicitly (nil stays nil).
	snapshot := *m.machineResources
	if cached := snapshot.Resources.GPUs; cached != nil {
		detached := make([]types.GPU, len(cached))
		copy(detached, cached)
		snapshot.Resources.GPUs = detached
	}
	return snapshot, nil
}
// GetUsage collects the current CPU, RAM, disk and GPU usage (queried in that
// order) and returns them as a types.Resources value. The first failing query
// aborts the call with a wrapped error.
func (m *defaultHardwareManager) GetUsage() (types.Resources, error) {
	var usage types.Resources
	var err error

	if usage.CPU, err = cpu.GetUsage(); err != nil {
		return types.Resources{}, fmt.Errorf("failed to get CPU usage: %w", err)
	}
	if usage.RAM, err = GetRAMUsage(); err != nil {
		return types.Resources{}, fmt.Errorf("failed to get RAM usage: %w", err)
	}
	if usage.Disk, err = GetDiskUsage(); err != nil {
		return types.Resources{}, fmt.Errorf("failed to get Disk usage: %w", err)
	}
	if usage.GPUs, err = gpu.GetGPUUsage(); err != nil {
		return types.Resources{}, fmt.Errorf("failed to get GPU usage: %w", err)
	}
	return usage, nil
}
// GetFreeResources returns the machine resources minus the current usage.
// Usage is queried first, then the machine resources; the difference is
// computed with Subtract on the machine-resources value, and a failure of
// that step is reported as "no free resources".
func (m *defaultHardwareManager) GetFreeResources() (types.Resources, error) {
	var none types.Resources

	used, err := m.GetUsage()
	if err != nil {
		return none, fmt.Errorf("failed to get usage: %w", err)
	}

	total, err := m.GetMachineResources()
	if err != nil {
		return none, fmt.Errorf("failed to get machine resources: %w", err)
	}

	if err = total.Subtract(used); err != nil {
		return none, fmt.Errorf("no free resources: %w", err)
	}
	return total.Resources, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package hardware
import (
"fmt"
"github.com/shirou/gopsutil/v4/mem"
"gitlab.com/nunet/device-management-service/types"
)
// GetRAM returns a types.RAM whose Size is the total memory reported by
// mem.VirtualMemory.
func GetRAM() (types.RAM, error) {
	stats, err := mem.VirtualMemory()
	if err != nil {
		return types.RAM{}, fmt.Errorf("failed to get total memory: %s", err)
	}

	total := float64(stats.Total)
	return types.RAM{Size: total}, nil
}
// GetRAMUsage returns a types.RAM whose Size is the used memory reported by
// mem.VirtualMemory.
func GetRAMUsage() (types.RAM, error) {
	stats, err := mem.VirtualMemory()
	if err != nil {
		return types.RAM{}, fmt.Errorf("failed to get total memory: %s", err)
	}

	used := float64(stats.Used)
	return types.RAM{Size: used}, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package dms
import (
"os"
"github.com/spf13/afero"
"gitlab.com/nunet/device-management-service/internal/config"
)
// init prepares the process environment for the dms package:
//   - creates the configured work, data and user directories (mode 0o700)
//     when their paths are non-empty, logging a warning on failure;
//   - silences libp2p logging unless LIBP2P_LOGGING is set to something other
//     than "false" or the empty string.
func init() {
	fs := afero.NewOsFs()

	if workDir := config.GetConfig().WorkDir; workDir != "" {
		if err := fs.MkdirAll(workDir, os.FileMode(0o700)); err != nil {
			log.Warnf("unable to create work directory: %v", err)
		}
	}

	if dataDir := config.GetConfig().DataDir; dataDir != "" {
		if err := fs.MkdirAll(dataDir, os.FileMode(0o700)); err != nil {
			log.Warnf("unable to create data directory: %v", err)
		}
	}

	if userDir := config.GetConfig().UserDir; userDir != "" {
		if err := fs.MkdirAll(userDir, os.FileMode(0o700)); err != nil {
			log.Warnf("unable to create user directory: %v", err)
		}
	}

	switch os.Getenv("LIBP2P_LOGGING") {
	case "false", "":
		if err := silenceLibp2pLogging(); err != nil {
			log.Warnf("unable to set libp2p logging: %v", err)
		}
	}
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package jobs
import (
"context"
"errors"
"fmt"
"sync"
"github.com/google/uuid"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/executor"
"gitlab.com/nunet/device-management-service/executor/docker"
"gitlab.com/nunet/device-management-service/types"
)
const (
	// pending is the status given to an allocation by NewAllocation.
	pending AllocationStatus = "pending"
	// running is set by Allocation.Run once the executor has been started.
	running AllocationStatus = "running"
	// stopped is set by Allocation.Stop once the execution has been cancelled.
	stopped AllocationStatus = "stopped"
)
// AllocationStatus is a representation of the execution status.
// The values used in this file are pending, running and stopped.
type AllocationStatus string
// Status holds the status of an allocation, as returned by Allocation.Status.
type Status struct {
	// JobResources is a copy of the allocation's Job.Resources.
	JobResources types.Resources
	// Status is the allocation's execution status.
	Status AllocationStatus
}
// AllocationDetails encapsulates the dependencies to the constructor.
// NewAllocation copies each field into the new Allocation.
type AllocationDetails struct {
	// Job is the job the allocation will run.
	Job Job
	// NodeID is stored as the allocation's node identifier.
	NodeID string
	// SourceID is stored as the allocation's source identifier.
	SourceID string
}
// Job describes the workload executed by an Allocation.
type Job struct {
	// ID identifies the job; it is used as the JobID for resource
	// allocation/deallocation and in the execution request.
	ID string
	// Resources are the resources requested from the resource manager and
	// passed to the executor.
	Resources types.Resources
	// Execution is the engine spec; its Type selects the executor.
	Execution types.SpecConfig
	// ProvisionScripts is forwarded unchanged in the execution request.
	ProvisionScripts map[string][]byte
}
// Allocation represents an allocation: a job together with the actor,
// executor and resource manager used to run it.
type Allocation struct {
	// ID is a UUID generated by NewAllocation.
	ID string
	// mx is held by Run, Stop and Start while they read or change the state below.
	mx sync.Mutex
	// status is pending after construction, running after a successful Run
	// and stopped after a successful Stop.
	status AllocationStatus
	nodeID string
	sourceID string
	// executionID is a UUID generated by NewAllocation; it is passed to the
	// executor when starting and cancelling the execution.
	executionID string
	// Actor is started by Start and stopped by Stop.
	Actor actor.Actor
	// executor is nil until Run creates it via createExecutor.
	executor executor.Executor
	// resourceManager allocates the job's resources in Run and releases them in Stop.
	resourceManager types.ResourceManager
	// actorRunning records whether Start succeeded and Stop has not run since.
	actorRunning bool
	Job Job
}
// NewAllocation builds a pending Allocation for the given actor and details.
// A fresh UUID is generated for the allocation ID and another one for the
// execution ID. It fails when resourceManager is nil or when UUID generation
// fails.
func NewAllocation(actor actor.Actor, details AllocationDetails, resourceManager types.ResourceManager) (*Allocation, error) {
	if resourceManager == nil {
		return nil, errors.New("resource manager is nil")
	}

	allocationUUID, err := uuid.NewUUID()
	if err != nil {
		return nil, fmt.Errorf("failed to generate uuid for allocation: %w", err)
	}

	executionUUID, err := uuid.NewUUID()
	if err != nil {
		return nil, fmt.Errorf("failed to create executor id: %w", err)
	}

	alloc := &Allocation{
		ID:              allocationUUID.String(),
		status:          pending,
		nodeID:          details.NodeID,
		sourceID:        details.SourceID,
		executionID:     executionUUID.String(),
		Actor:           actor,
		resourceManager: resourceManager,
		Job:             details.Job,
	}
	return alloc, nil
}
// Run allocates the job's resources and starts its execution.
//
// It is a no-op when the allocation is already running. Otherwise it asks the
// resource manager to allocate the job's resources, creates the executor for
// the configured execution engine if none exists yet, and starts the
// execution. On success the status becomes running.
//
// If any step after the resource allocation fails, the resources are released
// again. A failure of that release is joined onto the returned error instead
// of being dropped.
func (a *Allocation) Run(ctx context.Context) (err error) {
	a.mx.Lock()
	defer a.mx.Unlock()

	if a.status == running {
		return nil
	}

	resourceAllocation := types.ResourceAllocation{JobID: a.Job.ID, Resources: a.Job.Resources}
	if err = a.resourceManager.AllocateResources(ctx, resourceAllocation); err != nil {
		return fmt.Errorf("failed to allocate resources: %w", err)
	}

	defer func() {
		if a.status == running {
			return
		}
		// Not running: give the resources back. err is the named result, so
		// the deallocation failure actually reaches the caller.
		if derr := a.resourceManager.DeallocateResources(ctx, a.Job.ID); derr != nil {
			err = errors.Join(err, fmt.Errorf("failed to deallocate resources: %w", derr))
		}
	}()

	// if executor is nil create it
	if a.executor == nil {
		if err = a.createExecutor(ctx, a.Job.Execution); err != nil {
			return fmt.Errorf("failed to create executor: %w", err)
		}
	}

	err = a.executor.Start(ctx, &types.ExecutionRequest{
		JobID:            a.Job.ID,
		ExecutionID:      a.executionID,
		EngineSpec:       &a.Job.Execution,
		Resources:        &a.Job.Resources,
		ProvisionScripts: a.Job.ProvisionScripts,

		// TODO add the following
		Inputs:     []*types.StorageVolumeExecutor{}, // Question: what are those?
		Outputs:    []*types.StorageVolumeExecutor{},
		ResultsDir: "",
	})
	if err != nil {
		return fmt.Errorf("failed to start executor: %w", err)
	}

	a.status = running
	return nil
}
// Stop cancels the running execution and releases the job's resources.
//
// Regardless of the outcome, the allocation's actor is stopped on the way out
// if it was running (a failure to stop it is only logged). When the
// allocation is not running, nothing else happens. The status becomes stopped
// as soon as the execution has been cancelled, before resources are released.
func (a *Allocation) Stop(ctx context.Context) error {
	a.mx.Lock()
	defer a.mx.Unlock()

	// Runs before the unlock above, i.e. still under the mutex.
	defer func() {
		if !a.actorRunning {
			return
		}
		if err := a.Actor.Stop(); err != nil {
			log.Warnf("error stopping allocation actor: %s", err)
		}
		a.actorRunning = false
	}()

	if a.status != running {
		return nil
	}

	err := a.executor.Cancel(ctx, a.executionID)
	if err != nil {
		return fmt.Errorf("failed to stop execution: %w", err)
	}
	a.status = stopped

	err = a.resourceManager.DeallocateResources(ctx, a.Job.ID)
	if err != nil {
		return fmt.Errorf("failed to deallocate resources: %w", err)
	}
	return nil
}
// Status returns the job's resources and the current execution status of the
// allocation.
//
// The status field is written by Run and Stop while holding the allocation's
// mutex, so it is read under the same mutex here to avoid a data race.
func (a *Allocation) Status(_ context.Context) Status {
	a.mx.Lock()
	defer a.mx.Unlock()

	return Status{
		JobResources: a.Job.Resources,
		Status:       a.status,
	}
}
// Start starts the allocation's actor. It is a no-op when the actor is
// already marked as running; on success the actor is marked as running.
func (a *Allocation) Start() error {
	a.mx.Lock()
	defer a.mx.Unlock()

	if !a.actorRunning {
		if err := a.Actor.Start(); err != nil {
			return fmt.Errorf("failed to start allocation actor: %w", err)
		}
		a.actorRunning = true
	}
	return nil
}
// createExecutor instantiates the executor matching execution.Type and stores
// it on the allocation. Only the docker executor type is handled; any other
// type results in an error. The docker executor is created with a fresh UUID
// as its ID.
func (a *Allocation) createExecutor(ctx context.Context, execution types.SpecConfig) error {
	if execution.Type != types.ExecutorTypeDocker.String() {
		return fmt.Errorf("unsupported executor type: %s", execution.Type)
	}

	dockerExec, err := docker.NewExecutor(ctx, uuid.New().String())
	if err != nil {
		return fmt.Errorf("failed to create executor: %w", err)
	}
	a.executor = dockerExec
	return nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package jobs
import (
"encoding/json"
"errors"
"fmt"
"github.com/libp2p/go-libp2p/core/peer"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/lib/did"
"gitlab.com/nunet/device-management-service/types"
)
// EnsembleBidRequest is a request for bids pertaining to an ensemble.
//
// Note: At the moment, we embed a bid request for each node.
// This is fine for small deployments, and a small network, which is what we have.
// For large deployments however, this won't scale and we will have to create aggregate
// bid requests for related groups of nodes and also handle them with bid request
// aggregators who control multiple nodes.
//
// Validate requires a non-empty ID and at least one entry in Request.
type EnsembleBidRequest struct {
	ID string // unique identifier of an ensemble (in the context of the orchestrator)
	Request []BidRequest // list of node bid requests
	PeerExclusion []string // list of peers to exclude from bidding
}
// BidRequest is a versioned bid request.
type BidRequest struct {
	// V1 holds the version 1 payload of the request.
	V1 *BidRequestV1
}
// BidRequestV1 is v1 of bid requests for a node to use for deployment
type BidRequestV1 struct {
	NodeID string // unique identifier for a node, within the context of an ensemble
	Executors []AllocationExecutor // list of required executors to support the allocation(s)
	Resources types.Resources // (aggregate) required hardware resources
	Location LocationConstraints // node location constraints
	// PublicPorts describes the public ports requested for the node.
	PublicPorts struct {
		Static []int // statically configured public ports
		Dynamic int // number of dynamic ports
	}
}
// Bid is the versioned struct for Bids in response to a bid request.
type Bid struct {
	// V1 holds the version 1 payload. Validate rejects a bid whose V1 is nil;
	// the accessor methods dereference it without checking.
	V1 *BidV1
}
// BidV1 is v1 of the bid structure
type BidV1 struct {
	EnsembleID string // unique identifier for the ensemble
	NodeID string // unique identifier for a node; matches the id of the BidRequest to which this bid pertains
	Peer string // the peer ID of the node
	Location Location // the location of the node
	Handle actor.Handle // the handle of the actor submitting the bid
	// Signature is set by Bid.Sign and checked by Bid.Validate. It is cleared
	// on the copy that Bid.SignatureData marshals, so it never covers itself.
	Signature []byte
}
// bidPrefix is prepended to the JSON-encoded bid by Bid.SignatureData to form
// the bytes that are signed and verified.
const bidPrefix = "dms-bid-"
// Validate checks that the ensemble bid request is usable: the receiver must
// be non-nil, the ensemble ID non-empty, and at least one node bid request
// present. The first violated rule is reported.
func (b *EnsembleBidRequest) Validate() error {
	switch {
	case b == nil:
		return errors.New("ensemble bid request is nil")
	case b.ID == "":
		return errors.New("ensemble id is empty")
	case len(b.Request) == 0:
		return errors.New("ensemble with empty requests")
	default:
		return nil
	}
}
// SignatureData returns the bytes over which a bid is signed and verified:
// bidPrefix followed by the JSON encoding of the V1 payload with its
// Signature field cleared. The bid itself is not modified.
//
// It returns an error when the bid or its V1 payload is nil (instead of
// panicking on the dereference) or when JSON encoding fails.
func (b *Bid) SignatureData() ([]byte, error) {
	if b == nil || b.V1 == nil {
		return nil, errors.New("bid V1 is nil")
	}

	// Work on a copy so the caller's signature is left untouched.
	bidV1Copy := *b.V1
	bidV1Copy.Signature = nil

	data, err := json.Marshal(&bidV1Copy)
	if err != nil {
		return nil, fmt.Errorf("signature data: %w", err)
	}

	result := make([]byte, 0, len(bidPrefix)+len(data))
	result = append(result, bidPrefix...)
	result = append(result, data...)
	return result, nil
}
// Sign signs the bid's signature data with key and stores the result in
// V1.Signature.
//
// It returns an error when the bid or its V1 payload is nil, when the
// signature data cannot be built, or when signing fails. Underlying errors are
// wrapped so the cause is not lost.
func (b *Bid) Sign(key did.Provider) error {
	if b == nil || b.V1 == nil {
		return errors.New("bid V1 is nil")
	}

	data, err := b.SignatureData()
	if err != nil {
		return fmt.Errorf("unable to create bid signature data: %w", err)
	}

	sig, err := key.Sign(data)
	if err != nil {
		return fmt.Errorf("unable to sign the bid: %w", err)
	}

	b.V1.Signature = sig
	return nil
}
// Validate verifies the bid's signature.
//
// The public key is extracted from the peer ID stored in V1.Peer and used to
// verify V1.Signature against the bid's signature data. An error is returned
// when the bid or its V1 payload is nil, the peer ID cannot be decoded, no
// public key can be extracted from it, the signature data cannot be built, or
// the signature does not verify.
func (b *Bid) Validate() error {
	if b == nil || b.V1 == nil {
		return fmt.Errorf("bid V1 is nil")
	}

	p, err := peer.Decode(b.V1.Peer)
	if err != nil {
		return fmt.Errorf("failed to decode bid's peer id: %w", err)
	}

	pubKey, err := p.ExtractPublicKey()
	if err != nil {
		return fmt.Errorf("failed to extract public key: %w", err)
	}

	signData, err := b.SignatureData()
	if err != nil {
		// Keep the cause instead of discarding it.
		return fmt.Errorf("unable to get bid signature data: %w", err)
	}

	ok, err := pubKey.Verify(signData, b.V1.Signature)
	if err != nil {
		return fmt.Errorf("failed to verify signature: %w", err)
	}
	if !ok {
		return errors.New("signature verification failed")
	}
	return nil
}
// EnsembleID returns the ensemble identifier stored in the V1 payload.
// It dereferences b.V1 without a nil check and panics if V1 is nil.
func (b *Bid) EnsembleID() string {
	return b.V1.EnsembleID
}
// NodeID returns the node identifier stored in the V1 payload.
// It dereferences b.V1 without a nil check and panics if V1 is nil.
func (b *Bid) NodeID() string {
	return b.V1.NodeID
}
// Peer returns the peer ID string stored in the V1 payload.
// It dereferences b.V1 without a nil check and panics if V1 is nil.
func (b *Bid) Peer() string {
	return b.V1.Peer
}
// Handle returns the actor handle stored in the V1 payload.
// It dereferences b.V1 without a nil check and panics if V1 is nil.
func (b *Bid) Handle() actor.Handle {
	return b.V1.Handle
}
// Location returns the node location stored in the V1 payload.
// It dereferences b.V1 without a nil check and panics if V1 is nil.
func (b *Bid) Location() Location {
	return b.V1.Location
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package jobs
import (
"encoding/json"
"errors"
"gitlab.com/nunet/device-management-service/types"
)
// EnsembleConfig is the versioned structure that contains the ensemble configuration.
// Only version 1 is defined; the accessor methods on this type read from V1.
type EnsembleConfig struct {
V1 *EnsembleConfigV1
}
// EnsembleConfigV1 is version 1 of the configuration for an ensemble.
// Allocations, nodes, keys and scripts are keyed by name.
type EnsembleConfigV1 struct {
Allocations map[string]AllocationConfig // (named) allocations in the ensemble
Nodes map[string]NodeConfig // (named) nodes in the ensemble
Edges []EdgeConstraint // network edge constraints
Supervisor SupervisorConfig // supervision structure
Keys map[string]string // (named) ssh public keys relevant to the allocation
Scripts map[string][]byte // (named) provisioning scripts
}
// AllocationConfig is the configuration of an allocation.
// Keys, Provision and HealthCheck hold names rather than contents.
type AllocationConfig struct {
Executor AllocationExecutor // the executor of the allocation
Resources types.Resources // the HW resources required by the allocation
Execution types.SpecConfig // the allocation execution configuration
DNSName string // the internal DNS name of the allocation
Keys []string // names of the authorized ssh keys for the allocation
Provision []string // names of provisioning scripts to run (in order)
HealthCheck string // name of the script to run for health checks
}
// AllocationExecutor is the executor required for the allocation
type AllocationExecutor string
// Executor names declared for AllocationExecutor.
const (
ExecutorFirecracker AllocationExecutor = "firecracker"
ExecutorDocker AllocationExecutor = "docker"
ExecutorNull AllocationExecutor = "null"
)
// NodeConfig is the configuration of a distinct DMS node.
// A non-empty Peer pins the node to that specific peer.
type NodeConfig struct {
Allocations []string // the list of (named) allocations in the node
Ports []PortConfig // the port mapping configuration for the node
Location LocationConstraints // the geographical location constraints for the node
Peer string // (optional) a fixed peer for the node
// TODO contract information
}
// LocationConstraints provides the node location placement constraints
type LocationConstraints struct {
Accept []Location // acceptable location constraints (disjunction)
Reject []Location // negative location constraints (conjunction); eg !USA for GDPR purposes
}
// Location is a geographical location on Planet Earth.
// In Includes, an empty string field or a zero ASN on the receiver matches any value.
type Location struct {
Region string // geographic region of the location (required)
Country string // country (code or name) of the location (optional)
City string // city of the location; optional but country must be specified if not empty
ASN uint // Autonomous System Number for the location (optional)
ISP string // Internet Service Provider name for the location (optional)
}
// PortConfig is the configuration for a port mapping a public port to a private port
// in an allocation
type PortConfig struct {
Public int // the public port; 0 for any
Private int // the private port the public port maps to
Allocation string // the allocation where the port is mapped
}
// EdgeConstraint is a constraint for a network edge between two nodes
type EdgeConstraint struct {
S, T string // (named) nodes connected by the edge
RTT uint // maximum edge RTT in milliseconds
BW uint // minimum edge bandwidth in Kbps
}
// SupervisorConfig is the supervisory structure configuration for the ensemble.
// The structure is recursive: each group may contain child groups.
type SupervisorConfig struct {
Strategy SupervisorStrategy // the strategy for the supervision group
Allocations []string // allocations in this supervision group
Children []SupervisorConfig // allocation children for recursive groups
}
// SupervisorStrategy is the name of a supervision strategy
type SupervisorStrategy string
// Strategy names declared for SupervisorStrategy.
const (
StrategyOneForOne SupervisorStrategy = "OneForOne"
StrategyAllForOne SupervisorStrategy = "AllForOne"
StrategyRestForOne SupervisorStrategy = "RestForOne"
)
// Validate checks that the configuration can be used at all: both the
// receiver and its V1 payload must be non-nil. No further checks are made.
func (e *EnsembleConfig) Validate() error {
	if e != nil && e.V1 != nil {
		return nil
	}
	return errors.New("invalid ensemble config")
}
// Allocations returns the V1 allocation map itself (not a copy).
// e.V1 is dereferenced without a nil check.
func (e *EnsembleConfig) Allocations() map[string]AllocationConfig {
return e.V1.Allocations
}
// Allocation looks up the allocation named allocID in the V1 configuration.
// The boolean result reports whether the name is present.
func (e *EnsembleConfig) Allocation(allocID string) (AllocationConfig, bool) {
a, ok := e.V1.Allocations[allocID]
return a, ok
}
// Nodes returns the V1 node map itself (not a copy).
// e.V1 is dereferenced without a nil check.
func (e *EnsembleConfig) Nodes() map[string]NodeConfig {
return e.V1.Nodes
}
// Node looks up the node named nodeID in the V1 configuration.
// The boolean result reports whether the name is present.
func (e *EnsembleConfig) Node(nodeID string) (NodeConfig, bool) {
n, ok := e.V1.Nodes[nodeID]
return n, ok
}
// EdgeConstraints returns the V1 edge constraint slice itself (not a copy).
// e.V1 is dereferenced without a nil check.
func (e *EnsembleConfig) EdgeConstraints() []EdgeConstraint {
return e.V1.Edges
}
// Includes reports whether other falls within l. The regions must be equal;
// every other field of l is compared only when it is set (non-empty string,
// non-zero ASN), in which case it must equal the corresponding field of other.
func (l *Location) Includes(other Location) bool {
	// an unset string field on l acts as a wildcard
	wildcardOrEqual := func(want, got string) bool {
		return want == "" || want == got
	}

	switch {
	case l.Region != other.Region:
		return false
	case !wildcardOrEqual(l.Country, other.Country),
		!wildcardOrEqual(l.City, other.City),
		!wildcardOrEqual(l.ISP, other.ISP):
		return false
	default:
		return l.ASN == 0 || l.ASN == other.ASN
	}
}
// Clone returns a deep copy of the configuration produced by a JSON
// round-trip. Failures are only logged: a marshaling error yields the zero
// value, and an unmarshaling error yields whatever was decoded so far.
func (e *EnsembleConfig) Clone() EnsembleConfig {
	var duplicate EnsembleConfig

	encoded, err := json.Marshal(e)
	if err == nil {
		if err = json.Unmarshal(encoded, &duplicate); err != nil {
			log.Errorf("error unmarshaling ensemble config: %s", err)
		}
		return duplicate
	}

	log.Errorf("error marshaling ensemble config: %s", err)
	return duplicate
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package jobs
import (
"bufio"
"bytes"
_ "embed"
"fmt"
"math"
"strconv"
"strings"
)
// cities5000 holds the embedded contents of cities5000.txt.
// Currently only a GeoNames file for cities with population > 5000 is used;
// download it here: https://download.geonames.org/export/dump/cities5000.zip
//
//go:embed cities5000.txt
var cities5000 string
// lightSpeed is the speed of light, used to derive a lower bound for the RTT
// between two coordinates from their geodesic distance.
const lightSpeed = 299792.458 // in km/s
// Coordinate is a latitude/longitude pair.
type Coordinate struct {
	lat  float64
	long float64
}

// Empty reports whether both the latitude and the longitude are zero, i.e.
// whether the coordinate equals the zero value.
func (c *Coordinate) Empty() bool {
	return *c == (Coordinate{})
}
// GeoLocator resolves a (country, city) pair to a coordinate using an
// in-memory table.
type GeoLocator struct {
coord map[string]map[string]Coordinate // country -> city -> coordinate
}
// NewGeoLocator builds a GeoLocator from the embedded cities5000 data.
//
// Each line is split on tabs; lines with fewer than 19 fields are skipped,
// and lines whose coordinates fail to parse are logged and skipped. Field 1
// is used as the city name and field 8 as the country key. When the same
// country/city pair occurs more than once, the last line wins.
//
// It returns an error only if scanning the data fails.
func NewGeoLocator() (*GeoLocator, error) {
	locator := &GeoLocator{coord: make(map[string]map[string]Coordinate)}

	lines := bufio.NewScanner(bytes.NewBufferString(cities5000))
	lines.Buffer(make([]byte, 64*1024), 1024*1024) // increase buffer size for large lines

	for lines.Scan() {
		record := strings.SplitN(lines.Text(), "\t", 20) // limit to 20 fields in each entry
		if len(record) < 19 {
			continue
		}
		city, country := record[1], record[8]

		position, err := parseCoordinate(record)
		if err != nil {
			log.Warnf("error parsing coordiates for %s in %s: %s", city, country, err)
			continue
		}

		// create the per-country table on first use
		if _, known := locator.coord[country]; !known {
			locator.coord[country] = make(map[string]Coordinate)
		}
		locator.coord[country][city] = position
	}

	if err := lines.Err(); err != nil {
		return nil, fmt.Errorf("error reading cities file: %w", err)
	}
	return locator, nil
}
// parseCoordinate extracts a coordinate from a split record: fields[4] is
// parsed as the latitude and fields[5] as the longitude.
//
// It returns an error if the record has fewer than six fields or if either
// value is not a valid floating point number.
func parseCoordinate(fields []string) (Coordinate, error) {
	// guard the fixed indexes below so a short record cannot panic
	if len(fields) < 6 {
		return Coordinate{}, fmt.Errorf("failed to parse coordinate: expected at least 6 fields, got %d", len(fields))
	}

	lat, err := strconv.ParseFloat(fields[4], 64)
	if err != nil {
		return Coordinate{}, fmt.Errorf("failed to parse latitude: %w", err)
	}

	long, err := strconv.ParseFloat(fields[5], 64)
	if err != nil {
		return Coordinate{}, fmt.Errorf("failed to parse longitude: %w", err)
	}

	return Coordinate{lat: lat, long: long}, nil
}
// Coordinate looks up the coordinate of loc by its Country and City fields.
// It returns an error when either field is empty or when the pair is not in
// the table.
func (geo *GeoLocator) Coordinate(loc Location) (Coordinate, error) {
	var none Coordinate

	if loc.City == "" || loc.Country == "" {
		return none, fmt.Errorf("no city in location")
	}

	// indexing a missing country yields a nil map, whose lookup reports !ok
	if found, ok := geo.coord[loc.Country][loc.City]; ok {
		return found, nil
	}
	return none, fmt.Errorf("unknown city")
}
// computeGeodesic returns the great-circle distance in kilometres between p1
// and p2, using the haversine formula with an Earth radius of 6371 km.
// Coordinates are taken to be in degrees.
//
// If either coordinate is empty (both components zero) the distance is
// reported as 0.
func computeGeodesic(p1, p2 Coordinate) float64 {
	const earthRadius = 6371 // km

	if p1.Empty() || p2.Empty() {
		return 0
	}

	lat1 := p1.lat * math.Pi / 180
	lat2 := p2.lat * math.Pi / 180
	dLat := (p2.lat - p1.lat) * math.Pi / 180
	dLon := (p2.long - p1.long) * math.Pi / 180

	a := math.Sin(dLat/2)*math.Sin(dLat/2) +
		math.Cos(lat1)*math.Cos(lat2)*
			math.Sin(dLon/2)*math.Sin(dLon/2)

	// floating point rounding can push a marginally outside [0, 1] (e.g. for
	// antipodal points), which would make the square roots below return NaN
	if a > 1 {
		a = 1
	} else if a < 0 {
		a = 0
	}

	c := 2 * math.Atan2(math.Sqrt(a), math.Sqrt(1-a))
	return earthRadius * c
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package jobs
import (
"encoding/json"
"gitlab.com/nunet/device-management-service/actor"
)
// EnsembleManifest describes an ensemble deployment: its ID, the orchestrator
// actor, and the manifests of its allocations and nodes keyed by name.
type EnsembleManifest struct {
ID string // ensemble globally unique id
Orchestrator actor.Handle // orchestrator actor
Allocations map[string]AllocationManifest // allocation name -> manifest
Nodes map[string]NodeManifest // node name -> manifest
}
// AllocationManifest describes a single allocation within an ensemble manifest.
type AllocationManifest struct {
ID string // allocation unique id
NodeID string // allocation node
Handle actor.Handle // handle of the allocation control actor
DNSName string // (internal) DNS name of the allocation
PrivAddr string // (VPN) private IP address of the allocation peer
Ports map[int]int // port mapping, public -> private
}
// NodeManifest describes a single node within an ensemble manifest.
type NodeManifest struct {
ID string // node unique id
Peer string // peer where the node is running
Handle actor.Handle // handle of the control actor for the node
PubAddrss []string // public IP4/6 address of the node peer
Location Location // location of the peer
Allocations []string // allocations in the node
}
// Clone returns a deep copy of the manifest produced by a JSON round-trip.
// Failures are only logged: a marshaling error yields the zero value, and an
// unmarshaling error yields whatever was decoded so far.
func (mf *EnsembleManifest) Clone() EnsembleManifest {
	var duplicate EnsembleManifest

	encoded, err := json.Marshal(mf)
	if err == nil {
		if err = json.Unmarshal(encoded, &duplicate); err != nil {
			log.Errorf("error unmarshaling ensemble manifest: %s", err)
		}
		return duplicate
	}

	log.Errorf("error marshaling ensemble manifest: %s", err)
	return duplicate
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package jobs
import (
"context"
crand "crypto/rand"
"encoding/json"
"fmt"
"math"
"math/big"
"math/rand"
"net"
"strconv"
"sync"
"time"
"github.com/google/uuid"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/network"
"gitlab.com/nunet/device-management-service/network/utils"
"gitlab.com/nunet/device-management-service/types"
)
// MaxPermutations caps how many candidate permutations a candidate generator
// will try before giving up.
const MaxPermutations = 1_000_000
// DeploymentStatus is the phase an ensemble deployment is currently in.
type DeploymentStatus int

// Deployment phases, in the order Deploy moves through them; Failed is set
// when Deploy returns without reaching Running.
const (
	DeploymentStatusPreparing DeploymentStatus = iota
	DeploymentStatusGenerating
	DeploymentStatusCommitting
	DeploymentStatusProvisioning
	DeploymentStatusRunning
	DeploymentStatusFailed
)

// DeploymentStatusString returns the human readable name of d, or "Unknown"
// for a value that is not one of the declared statuses.
func DeploymentStatusString(d DeploymentStatus) string {
	names := [...]string{
		DeploymentStatusPreparing:    "Preparing",
		DeploymentStatusGenerating:   "Generating",
		DeploymentStatusCommitting:   "Committing",
		DeploymentStatusProvisioning: "Provisioning",
		DeploymentStatusRunning:      "Running",
		DeploymentStatusFailed:       "Failed",
	}
	if d < 0 || int(d) >= len(names) {
		return "Unknown"
	}
	return names[d]
}
// Orchestrator drives the deployment of a single ensemble: it collects bids,
// selects a candidate assignment of peers to nodes, commits it and provisions it.
type Orchestrator struct {
actor actor.Actor
network network.Network
geo *GeoLocator
// mx is held by setStatus/Status, Manifest and Config, and by Deploy when it
// stores manifest, ctx and cancel.
mx sync.Mutex
id string
cfg EnsembleConfig
manifest EnsembleManifest
status DeploymentStatus
// ctx and cancel are created by Deploy once the deployment is running.
ctx context.Context
cancel func()
}
// NewOrchestrator creates an orchestrator for the ensemble described by cfg.
//
// The configuration is validated, a fresh UUID is generated as the
// orchestrator (ensemble) ID, and a geolocator is built. An error is returned
// if any of these steps fails.
func NewOrchestrator(actor actor.Actor, network network.Network, cfg EnsembleConfig) (*Orchestrator, error) {
	if err := cfg.Validate(); err != nil {
		return nil, fmt.Errorf("failed to validate ensemble configuration: %w", err)
	}

	ensembleID, err := uuid.NewUUID()
	if err != nil {
		return nil, fmt.Errorf("failed to create orchestrator id: %w", err)
	}

	locator, err := NewGeoLocator()
	if err != nil {
		return nil, fmt.Errorf("failed to create geolocator: %w", err)
	}

	return &Orchestrator{
		actor:   actor,
		network: network,
		geo:     locator,
		id:      ensembleID.String(),
		cfg:     cfg,
	}, nil
}
// setStatus records the deployment status while holding the orchestrator mutex.
func (o *Orchestrator) setStatus(status DeploymentStatus) {
	o.mx.Lock()
	o.status = status
	o.mx.Unlock()
}
// Status returns the current deployment status, read while holding the
// orchestrator mutex.
func (o *Orchestrator) Status() DeploymentStatus {
	o.mx.Lock()
	current := o.status
	o.mx.Unlock()
	return current
}
// Manifest returns a clone of the current ensemble manifest, taken while
// holding the orchestrator mutex.
func (o *Orchestrator) Manifest() EnsembleManifest {
o.mx.Lock()
defer o.mx.Unlock()
return o.manifest.Clone()
}
// Config returns a clone of the ensemble configuration, taken while holding
// the orchestrator mutex.
func (o *Orchestrator) Config() EnsembleConfig {
o.mx.Lock()
defer o.mx.Unlock()
return o.cfg.Clone()
}
// ID returns the orchestrator ID assigned in NewOrchestrator. The same value
// is used as the ensemble ID in bids, requests and the manifest.
func (o *Orchestrator) ID() string {
return o.id
}
// Deploy tries to deploy the ensemble before expiry, retrying the whole
// procedure from scratch after any recoverable failure.
//
// Each attempt:
//  1. creates the bid requests for the nodes,
//  2. collects bids (targeted for nodes with a fixed peer, broadcast otherwise),
//  3. builds a candidate generator, issuing residual bid requests while some
//     node has no bid,
//  4. draws candidates until one passes edge constraint verification,
//  5. commits the candidate,
//  6. provisions the deployment and starts supervision.
//
// The status is updated along the way and is set to DeploymentStatusFailed on
// return unless the deployment reached DeploymentStatusRunning.
//
// It returns nil once the deployment is running, an error if a bid request
// cannot be created or sent, and ErrDeploymentFailed if no deployment could be
// completed before expiry.
func (o *Orchestrator) Deploy(expiry time.Time) error {
	defer func() {
		if o.Status() != DeploymentStatusRunning {
			o.setStatus(DeploymentStatusFailed)
		}
	}()

	o.setStatus(DeploymentStatusPreparing)
	// edge verification results are kept across attempts
	edgeConstraintCache := make(map[string]bool)

deploy:
	for time.Now().Before(expiry) {
		o.setStatus(DeploymentStatusPreparing)

		// 0. check if one of the ensemble nodes have peer specified
		// If bid request to peer specified fails, the entire deployment must fail
		nodeForTargetPeer := make(map[string]string)
		for nodeID, node := range o.cfg.Nodes() {
			if node.Peer != "" {
				nodeForTargetPeer[node.Peer] = nodeID
			}
		}

		// 1. Create bid requests for nodes
		bidrq, err := o.makeInitialBidRequest()
		if err != nil {
			return fmt.Errorf("creating bid request: %w", err)
		}

		// 2. Collect bids
		bidMap := make(map[string][]Bid)
		peerExclusion := make(map[string]struct{})

		addBid := func(bid Bid) bool {
			// if peer is already specified on another node, ignore the bid
			if target, ok := nodeForTargetPeer[bid.Peer()]; ok && target != bid.NodeID() {
				return false
			}

			// check that the peer has not already submitted a bid
			peerID := bid.Peer()
			if _, exclude := peerExclusion[peerID]; exclude {
				log.Debugf("ignoring duplicate bid from peer %s", peerID)
				return false
			}

			err := bid.Validate()
			if err != nil {
				log.Debugf("failed to validate bid from peer %s: %s", peerID, err)
				return false
			}

			// verify that this is a node in the ensemble
			nodeID := bid.NodeID()
			if _, ok := o.cfg.Node(nodeID); !ok {
				log.Debugf("ignoring bid from peer %s for unknown node %s", peerID, nodeID)
				return false
			}

			// verify the location constraints of the node
			loc := bid.Location()
			if !o.acceptPeerLocation(nodeID, peerID, loc) {
				log.Debugf("ignoring out of location bid from peer %s for node %s", peerID, nodeID)
				return false
			}

			// don't bloat the permutation space
			if len(bidMap[nodeID]) >= MaxBidMultiplier {
				log.Debugf("ignore bid from peer %s for saturated node %s", peerID, nodeID)
				return false
			}

			log.Debugf("added bid to bitMap from peer %s for %s", peerID, nodeID)
			bidMap[nodeID] = append(bidMap[nodeID], bid)
			peerExclusion[peerID] = struct{}{}

			return true
		}

		bidCh, bidDoneCh, bidExpiryTime, err := o.requestBids(bidrq, expiry)
		if err != nil {
			return fmt.Errorf("collecting bids: %w", err)
		}
		maxBids := MaxBidMultiplier * len(o.cfg.Nodes())
		o.collectBids(bidCh, bidDoneCh, bidExpiryTime, addBid, maxBids)

		// 3. Create a candidate deployment
		var nextCandidate func() (map[string]Bid, bool)
		var ok bool
		for time.Now().Before(expiry) {
			nextCandidate, ok = o.makeCandidateDeployments(bidMap)
			if ok {
				break
			}

			// we don't have bids for some of our nodes so we don't have a candidate
			// we need to make a residual bid request for the remaining nodes
			// Note: in order to facilitate random selection, the residual bid requests
			//       can drop some of the original bids
			bidrq, err := o.makeResidualBidRequest(bidMap, peerExclusion)
			if err != nil {
				return fmt.Errorf("creating residual bid request: %w", err)
			}

			bidCh, bidDoneCh, bidExpiryTime, err := o.requestBids(bidrq, expiry)
			if err != nil {
				return fmt.Errorf("collecting residual bids: %w", err)
			}
			maxBids := MaxBidMultiplier * (len(o.cfg.Nodes()) - len(bidMap))
			o.collectBids(bidCh, bidDoneCh, bidExpiryTime, addBid, maxBids)
		}

		if !ok {
			log.Debugf("failed to create candidate deployments")
			continue deploy
		}

		// 4. Iterate through the candidates trying to find one that satisfies the
		//    edge constraints
		o.setStatus(DeploymentStatusGenerating)
		var candidate map[string]Bid
		verified := false
		for time.Now().Before(expiry) {
			candidate, ok = nextCandidate()
			if !ok {
				log.Debugf("failed to find candidate that satisfies edge constraints")
				continue deploy
			}

			log.Debugf("candidate deployment: %+v", candidate)
			if ok := o.verifyEdgeConstraints(candidate, edgeConstraintCache); !ok {
				log.Debugf("candidate does not satisfy edge constraints")
				continue
			}

			verified = true
			break
		}

		if !verified {
			// the deadline passed before any candidate was verified; candidate is
			// either nil or was rejected, so it must never be committed
			log.Debugf("deployment expired before a candidate satisfied edge constraints")
			break deploy
		}

		// 5. Commit the deployment
		o.setStatus(DeploymentStatusCommitting)
		manifest, err := o.commit(candidate)
		if err != nil {
			log.Warnf("failed to commit deployment: %s", err)
			continue deploy
		}

		o.mx.Lock()
		o.manifest = manifest
		o.mx.Unlock()

		// 6. provision the network and start the allocations
		o.setStatus(DeploymentStatusProvisioning)
		if err := o.provision(manifest); err != nil {
			log.Errorf("failed to privision network: %s", err)
			o.revert(manifest)
			continue deploy
		}

		// We are done! start the supervisor return the manifest.
		o.mx.Lock()
		o.manifest = manifest
		o.ctx, o.cancel = context.WithCancel(context.Background())
		o.mx.Unlock()

		o.setStatus(DeploymentStatusRunning)
		go o.supervise()

		return nil
	}

	// we failed to create the deployment in time
	return ErrDeploymentFailed
}
// Shutdown is currently a stub: it does nothing and always returns nil.
func (o *Orchestrator) Shutdown() error {
// TODO shutdown the deployment
return nil
}
// requestBids sends the bid request and installs the behavior that receives
// the replies.
//
// Requests for nodes with a fixed peer are sent directly to that peer, one
// message per node; all remaining requests are published together as a single
// broadcast. Requests without a V1 payload or for unknown nodes are dropped.
//
// It returns the channel on which bids are delivered, a done channel the
// consumer closes to release blocked reply handlers (collectBids does this),
// and the time at which the bid request expires. An error is returned if
// there is less than BidRequestTimeout left before expiry, or if sending a
// request or adding the behavior fails.
func (o *Orchestrator) requestBids(bidrq EnsembleBidRequest, expiry time.Time) (chan Bid, chan struct{}, time.Time, error) {
log.Debugf("requesting bids: %+v", bidrq)
bidExpiryTime := time.Now().Add(BidRequestTimeout)
if expiry.Before(bidExpiryTime) {
return nil, nil, time.Time{}, fmt.Errorf("not enough time for deployment: %w", ErrDeploymentFailed)
}
// expiry as unix nanoseconds, the form taken by the message/behavior options
bidExpiry := uint64(bidExpiryTime.UnixNano())
// Split requests into direct peer requests and broadcast requests
var directRequests []BidRequest
var broadcastRequests []BidRequest
for _, req := range bidrq.Request {
if req.V1 == nil {
continue
}
nodeConfig, ok := o.cfg.Node(req.V1.NodeID)
if !ok {
continue
}
if nodeConfig.Peer != "" {
// This node has a specific peer target
directRequests = append(directRequests, req)
} else {
// This node needs broadcast
broadcastRequests = append(broadcastRequests, req)
}
}
// Send direct peer requests
// NOTE(review): these are sent before the reply behavior below is added;
// whether a fast reply could be missed depends on the actor's dispatch
// semantics — confirm.
for _, req := range directRequests {
// the lookup cannot fail: the node was found when the requests were split
nodeConfig, _ := o.cfg.Node(req.V1.NodeID)
targetedReq := EnsembleBidRequest{
ID: bidrq.ID,
Request: []BidRequest{req},
PeerExclusion: bidrq.PeerExclusion,
}
err := o.requestBidPeer(targetedReq, nodeConfig, bidExpiry)
if err != nil {
return nil, nil, time.Time{}, fmt.Errorf("requesting bid to targeted peer: %w", err)
}
}
// create reply behavior for this specific ensemble bid request
bidCh := make(chan Bid)
bidDoneCh := make(chan struct{})
if err := o.actor.AddBehavior(
BidReplyBehavior,
func(msg actor.Envelope) {
defer msg.Discard()
var bid Bid
if err := json.Unmarshal(msg.Message, &bid); err != nil {
log.Debugf("failed to unmarshal bid from %s: %s", msg.From, err)
return
}
// bidCh is unbuffered: hand the bid over, but give up when the bid
// request expires or the collector has finished
timer := time.NewTimer(time.Until(bidExpiryTime))
defer timer.Stop()
select {
case bidCh <- bid:
case <-timer.C:
case <-bidDoneCh:
}
},
actor.WithBehaviorExpiry(bidExpiry),
); err != nil {
return nil, nil, time.Time{}, fmt.Errorf("adding bid behavior: %w", err)
}
// Send broadcast
if len(broadcastRequests) > 0 {
broadcastReq := EnsembleBidRequest{
ID: bidrq.ID,
Request: broadcastRequests,
PeerExclusion: bidrq.PeerExclusion,
}
err := o.broadcastBid(broadcastReq, bidExpiry)
if err != nil {
return nil, nil, time.Time{}, fmt.Errorf("broadcasting bid request: %w", err)
}
}
return bidCh, bidDoneCh, bidExpiryTime, nil
}
// broadcastBid publishes bidrq on the bid request topic, asking for replies on
// the bid reply behavior. The message expires at bidExpiry. It returns an
// error if the message cannot be created or published.
func (o *Orchestrator) broadcastBid(bidrq EnsembleBidRequest, bidExpiry uint64) error {
	request, err := actor.Message(
		o.actor.Handle(),
		actor.Handle{},
		BidRequestBehavior,
		bidrq,
		actor.WithMessageTopic(BidRequestTopic),
		actor.WithMessageReplyTo(BidReplyBehavior),
		actor.WithMessageExpiry(bidExpiry),
	)
	if err != nil {
		return fmt.Errorf("creating broadcast bid message: %w", err)
	}

	err = o.actor.Publish(request)
	if err != nil {
		return fmt.Errorf("publishing bid request: %w", err)
	}
	return nil
}
// requestBidPeer sends targetedReq directly to the peer fixed in nodeConfig,
// asking for a reply on the bid reply behavior. The message expires at
// bidExpiry. It returns an error if the peer's handle cannot be derived or if
// the message cannot be created or sent.
func (o *Orchestrator) requestBidPeer(targetedReq EnsembleBidRequest, nodeConfig NodeConfig, bidExpiry uint64) error {
	target := nodeConfig.Peer

	destination, err := actor.HandleFromPeerID(target)
	if err != nil {
		return fmt.Errorf("getting handle of selected peer %s: %w", target, err)
	}
	log.Infof("sending direct peer request to %s: %+v", target, targetedReq)

	request, err := actor.Message(
		o.actor.Handle(),
		destination,
		BidRequestBehavior,
		targetedReq,
		actor.WithMessageReplyTo(BidReplyBehavior),
		actor.WithMessageExpiry(bidExpiry),
	)
	if err != nil {
		return fmt.Errorf("creating targeted bid message: %w", err)
	}

	err = o.actor.Send(request)
	if err != nil {
		return fmt.Errorf("sending targeted bid request: %w", err)
	}
	return nil
}
// collectBids drains bidCh until bidExpiryTime is reached or maxBids bids have
// been accepted by addBid. Bids that fail validation or that are for another
// ensemble are logged and ignored. bidDoneCh is closed on return.
func (o *Orchestrator) collectBids(bidCh chan Bid, bidDoneCh chan struct{}, bidExpiryTime time.Time, addBid func(Bid) bool, maxBids int) {
	defer close(bidDoneCh)

	deadline := time.NewTimer(time.Until(bidExpiryTime))
	defer deadline.Stop()

	accepted := 0
	for {
		select {
		case <-deadline.C:
			return

		case bid := <-bidCh:
			err := bid.Validate()
			switch {
			case err != nil:
				log.Debugf("got invalid bid: %s", err)
			case bid.EnsembleID() != o.id:
				log.Debugf("got bid for unexpected ensemble ID: %s", bid.EnsembleID())
			case addBid(bid):
				accepted++
				if accepted >= maxBids {
					return
				}
			}
		}
	}
}
// makeCandidateDeployments returns a generator of candidate deployments (one
// bid per node) drawn from bids. The boolean result is false when no
// candidate can exist: the number of nodes with bids differs from the number
// of nodes in the configuration, or some node has an empty bid list.
//
// Every bid list is shuffled in place. The int64 based generator is used when
// the permutation space is guaranteed to fit in an int64, the big.Int based
// one otherwise.
func (o *Orchestrator) makeCandidateDeployments(bids map[string][]Bid) (func() (map[string]Bid, bool), bool) {
	// immediate satisfaction check: we need a bid for every node
	if len(o.cfg.Nodes()) != len(bids) {
		return nil, false
	}
	for _, blst := range bids {
		// an empty list means the node has no bid; it would also collapse the
		// permutation space to zero and make log2 below undefined
		if len(blst) == 0 {
			return nil, false
		}
	}

	// first shuffle all the bids to seed the permutation generator
	for _, blst := range bids {
		rand.Shuffle(len(blst), func(i, j int) {
			blst[i], blst[j] = blst[j], blst[i]
		})
	}

	// count the bits in the permutation space; an int64 holds at most 2^63-1,
	// so anything that may reach 63 bits needs the bignum based generator or
	// it will overflow.
	bits := 0
	for _, blst := range bids {
		bits += int(math.Ceil(math.Log2(float64(len(blst)))))
	}

	if bits >= 63 {
		return o.makeCandidateDeploymentBig(bids)
	}

	return o.makeCandidateDeploymentSmall(bids)
}
// makeCandidateDeploymentSmall returns a generator of random candidate
// deployments for a permutation space that fits in an int64.
//
// The space is treated as a mixed-radix number with one digit per node: an
// index in [0, modulus) selects one bid for every node. Each call of the
// returned function draws random indexes until one passes
// checkPermutationEdgeConstraints. At most min(modulus, MaxPermutations)
// indexes are drawn over the lifetime of the generator, after which it
// reports false.
func (o *Orchestrator) makeCandidateDeploymentSmall(bids map[string][]Bid) (func() (map[string]Bid, bool), bool) {
	// fix the order of permutation
	type permutator struct {
		mod  int64
		node string
		bids []Bid
	}

	permutators := make([]permutator, 0, len(bids))
	modulus := int64(1)
	for n, blst := range bids {
		permutators = append(permutators, permutator{mod: modulus, node: n, bids: blst})
		modulus *= int64(len(blst))
	}

	// function to get a permutation by index
	getPermutation := func(index int64) map[string]Bid {
		result := make(map[string]Bid, len(permutators))
		for _, permutator := range permutators {
			selection := (index / permutator.mod) % int64(len(permutator.bids))
			result[permutator.node] = permutator.bids[selection]
		}
		return result
	}

	// and return a function that gets a random next permutation
	// note that we cache the constraint results, so potential duplication is ok.
	// also note that the permutation space is large enough so that it's ok to skip
	// some permutations.
	// final note: Obviously we can deterministically generate all permutations in order
	// (and we were doing that initially) but this has the problem that we are not
	// modifying the network structure enough to get meaningful variance in a reasonable
	// time.
	nperm := modulus
	if nperm > MaxPermutations {
		nperm = MaxPermutations
	}
	count := int64(0)

	return func() (map[string]Bid, bool) {
		for count < nperm {
			count++
			// nperm only bounds the number of attempts; the index must be drawn
			// from the whole space, otherwise the high order nodes would always
			// be assigned their first bid when the space exceeds MaxPermutations
			nextPerm := rand.Int63n(modulus)
			perm := getPermutation(nextPerm)
			if !o.checkPermutationEdgeConstraints(perm) {
				continue
			}
			return perm, true
		}

		return nil, false
	}, true
}
// makeCandidateDeploymentBig returns a generator of random candidate
// deployments for a permutation space too large for an int64, using big.Int
// arithmetic.
//
// The space is treated as a mixed-radix number with one digit per node. Each
// call of the returned function draws random indexes until one passes
// checkPermutationEdgeConstraints. At most MaxPermutations indexes are drawn
// over the lifetime of the generator, after which it reports false; it also
// reports false if random bytes cannot be read.
func (o *Orchestrator) makeCandidateDeploymentBig(bids map[string][]Bid) (func() (map[string]Bid, bool), bool) {
	// one digit of the mixed-radix index; the iteration order fixes the layout
	type digit struct {
		weight  *big.Int
		node    string
		choices []Bid
	}

	digits := make([]digit, 0, len(bids))
	space := big.NewInt(1)
	for node, choices := range bids {
		digits = append(digits, digit{weight: space, node: node, choices: choices})
		// allocate a new value: the previous one is still referenced as a weight
		space = new(big.Int).Mul(space, big.NewInt(int64(len(choices))))
	}

	// decode maps an index to the bid selected for every node
	decode := func(index *big.Int) map[string]Bid {
		selected := make(map[string]Bid)
		for _, d := range digits {
			quotient := new(big.Int).Div(index, d.weight)
			radix := big.NewInt(int64(len(d.choices)))
			position := int(new(big.Int).Mod(quotient, radix).Int64())
			selected[d.node] = d.choices[position]
		}
		return selected
	}

	// Candidates are drawn at random rather than enumerated in order: results
	// of the constraint checks are cached, so duplicates are harmless, the
	// space is large enough that skipping permutations is fine, and in-order
	// enumeration does not vary the network structure enough in reasonable time.
	attempts := 0
	raw := make([]byte, (space.BitLen()+7)/8)

	return func() (map[string]Bid, bool) {
		for attempts < MaxPermutations {
			attempts++

			if _, err := crand.Read(raw); err != nil {
				log.Errorf("error reading random bytes: %s", err)
				return nil, false
			}

			candidate := decode(new(big.Int).SetBytes(raw))
			if o.checkPermutationEdgeConstraints(candidate) {
				return candidate, true
			}
		}

		return nil, false
	}, true
}
// checkPermutationEdgeConstraints cheaply filters a candidate using geography
// alone: for every edge constraint with a non-zero RTT it computes the
// physical lower bound of the round trip time between the two bids'
// locations (geodesic distance at the speed of light, both ways) and rejects
// the candidate when that bound already exceeds the constraint.
//
// Constraints whose endpoints cannot be geolocated are logged and skipped.
func (o *Orchestrator) checkPermutationEdgeConstraints(candidate map[string]Bid) bool {
	for _, edge := range o.cfg.EdgeConstraints() {
		if edge.RTT == 0 {
			continue
		}

		source, target := candidate[edge.S], candidate[edge.T]

		from, err := o.geo.Coordinate(source.Location())
		if err != nil {
			log.Errorf("Failed to get location for bid %s: %v", source.NodeID(), err)
			continue
		}

		to, err := o.geo.Coordinate(target.Location())
		if err != nil {
			log.Errorf("Failed to get location for bid %s: %v", target.NodeID(), err)
			continue
		}

		// in milliseconds
		lowerBound := computeGeodesic(from, to) / lightSpeed * 2 * 1000
		if lowerBound > float64(edge.RTT) {
			log.Debugf("Edge constraint not satisfied: min RTT %.2f ms > constraint %d ms for %s -> %s", lowerBound, edge.RTT, edge.S, edge.T)
			return false
		}

		// TODO: add bandwidth check when that information becomes available
	}

	return true
}
// verifyEdgeConstraints reports whether candidate satisfies every edge
// constraint of the ensemble.
//
// Results are memoized in cache under the key "<source peer>:<target peer>".
// A cached rejection fails the candidate immediately; constraints without a
// cached result are verified concurrently and their results are added to the
// cache. cache is mutated and must not be used concurrently by the caller.
func (o *Orchestrator) verifyEdgeConstraints(candidate map[string]Bid, cache map[string]bool) bool {
	var mx sync.Mutex
	var wg sync.WaitGroup
	var toVerify []EdgeConstraint

	for _, cst := range o.cfg.EdgeConstraints() {
		bidS := candidate[cst.S]
		bidT := candidate[cst.T]
		key := bidS.Peer() + ":" + bidT.Peer()
		accept, ok := cache[key]
		if !ok {
			toVerify = append(toVerify, cst)
			continue
		}
		if !accept {
			return false
		}
	}

	if len(toVerify) == 0 {
		return true
	}

	accept := true
	wg.Add(len(toVerify))
	for _, cst := range toVerify {
		go func(cst EdgeConstraint) {
			// without this the Wait below would block forever
			defer wg.Done()

			result := o.verifyEdgeConstraint(candidate, cst)

			bidS := candidate[cst.S]
			bidT := candidate[cst.T]
			key := bidS.Peer() + ":" + bidT.Peer()

			// cache and accept are shared between the verification goroutines
			mx.Lock()
			cache[key] = result
			accept = accept && result
			mx.Unlock()
		}(cst)
	}
	wg.Wait()

	return accept
}
// verifyEdgeConstraint asks the actor of the source bid to verify a single
// edge constraint against the target bid's peer and reports the result.
//
// Any failure is treated as a rejection: message creation or invocation
// errors, no reply within VerifyEdgeConstraintTimeout, or an undecodable
// response.
func (o *Orchestrator) verifyEdgeConstraint(candidate map[string]Bid, cst EdgeConstraint) bool {
	bidS := candidate[cst.S]
	bidT := candidate[cst.T]
	key := bidS.Peer() + ":" + bidT.Peer()

	log.Debugf("verify edge constraint %s %v", key, cst)

	handle := bidS.Handle()
	msg, err := actor.Message(
		o.actor.Handle(),
		handle,
		VerifyEdgeConstraintBehavior,
		VerifyEdgeConstraintRequest{
			EnsembleID: o.id,
			S:          bidS.Peer(),
			T:          bidT.Peer(),
			RTT:        cst.RTT,
			BW:         cst.BW,
		},
		actor.WithMessageTimeout(VerifyEdgeConstraintTimeout),
	)
	if err != nil {
		log.Warnf("error creating constraint check message for %s: %s", key, err)
		return false
	}

	replyCh, err := o.actor.Invoke(msg)
	if err != nil {
		log.Warnf("error invoking constraint check for %s: %s", key, err)
		return false
	}

	var reply actor.Envelope
	select {
	case reply = <-replyCh:
	case <-time.After(VerifyEdgeConstraintTimeout):
		return false
	}
	defer reply.Discard()

	var response VerifyEdgeConstraintResponse
	if err := json.Unmarshal(reply.Message, &response); err != nil {
		log.Warnf("error unmarshalling bid constraint response for %s: %s", key, err)
		return false
	}

	if response.Error != "" {
		// report the error carried by the response; err is nil at this point
		log.Debugf("error verifying bid constraint for %s: %s", key, response.Error)
	}

	return response.OK
}
// commit commits the candidate deployment on all its nodes and builds the
// (partial) ensemble manifest.
//
// Both phases contact every node of the candidate concurrently and wait for
// all of them. If any commit fails, the nodes that did commit are reverted;
// if any allocation fails, every node of the candidate is reverted. In both
// cases an error wrapping ErrDeploymentFailed is returned.
func (o *Orchestrator) commit(candidate map[string]Bid) (EnsembleManifest, error) {
// This is a two phase commit:
// - first commit the resources in all the nodes to ensure the deployment is (still)
// feasible.
// - then create all the allocations for provisioning
// - if there are any failures, we need to revert this deployment and start anew
// mx guards ok, committed and allocations, which are shared with the goroutines below
var mx sync.Mutex
// Phase 1: commit
var wg1 sync.WaitGroup
ok := true
committed := make([]string, 0, len(candidate))
wg1.Add(len(candidate))
for n, bid := range candidate {
go func(n string, bid Bid) {
defer wg1.Done()
err := o.commitDeployment(n, bid.Handle())
mx.Lock()
if err != nil {
log.Errorf("error committing bid for %s: %s", n, err)
ok = false
} else {
log.Debugf("committed resources for %s", n)
committed = append(committed, n)
}
mx.Unlock()
}(n, bid)
}
wg1.Wait()
if !ok {
// only the nodes that actually committed need to be reverted
for _, n := range committed {
bid := candidate[n]
o.revertDeployment(n, bid.Handle())
}
return EnsembleManifest{}, fmt.Errorf("failed to commit resources: %w", ErrDeploymentFailed)
}
// Phase 2: allocate
var wg2 sync.WaitGroup
allocations := make(map[string]actor.Handle)
wg2.Add(len(candidate))
for n, bid := range candidate {
go func(n string, bid Bid) {
defer wg2.Done()
allocated, err := o.allocate(n, bid.Handle())
mx.Lock()
if err != nil {
log.Errorf("error allocating deployment for %s: %s", n, err)
ok = false
} else {
log.Debugf("allocating deployment for %s", n)
for a, h := range allocated {
allocations[a] = h
}
}
mx.Unlock()
}(n, bid)
}
wg2.Wait()
if !ok {
// every node committed in phase 1, so all of them are reverted
for n, bid := range candidate {
o.revertDeployment(n, bid.Handle())
}
return EnsembleManifest{}, fmt.Errorf("failed to allocate resources: %w", ErrDeploymentFailed)
}
// We are done, create the (partial) manifest
// There are certain details that are filled during provisioning, e.g. allocation
// VPN addresses and public port mappings
mf := EnsembleManifest{
ID: o.id,
Orchestrator: o.actor.Handle(),
Allocations: make(map[string]AllocationManifest),
Nodes: make(map[string]NodeManifest),
}
// allocation name -> name of the node that hosts it
allocationNodes := make(map[string]string)
for n, bid := range candidate {
ncfg, _ := o.cfg.Node(n)
nmf := NodeManifest{
ID: n,
Peer: bid.Peer(),
Handle: bid.Handle(),
Location: bid.Location(),
Allocations: ncfg.Allocations,
}
for _, a := range nmf.Allocations {
allocationNodes[a] = n
}
mf.Nodes[n] = nmf
}
for a := range o.cfg.Allocations() {
amf := AllocationManifest{
ID: a,
NodeID: allocationNodes[a],
Handle: allocations[a],
}
mf.Allocations[a] = amf
}
return mf, nil
}
// commitDeployment asks the actor behind h to commit the resources for node n
// of this ensemble and waits for its answer.
//
// It returns an error if the message cannot be created or invoked, if no reply
// arrives within CommitDeploymentTimeout, if the reply cannot be decoded, or
// if the peer reports that the commit failed.
func (o *Orchestrator) commitDeployment(n string, h actor.Handle) error {
	request, err := actor.Message(
		o.actor.Handle(),
		h,
		CommitDeploymentBehavior,
		CommitDeploymentRequest{
			EnsembleID: o.id,
			NodeID:     n,
		},
		actor.WithMessageTimeout(CommitDeploymentTimeout),
	)
	if err != nil {
		return fmt.Errorf("failed to create commit message for %s: %w", n, err)
	}

	replies, err := o.actor.Invoke(request)
	if err != nil {
		return fmt.Errorf("failed to invoke commit for %s: %w", n, err)
	}

	var reply actor.Envelope
	select {
	case <-time.After(CommitDeploymentTimeout):
		return fmt.Errorf("timeout committing for %s: %w", n, ErrDeploymentFailed)
	case reply = <-replies:
	}
	defer reply.Discard()

	var response CommitDeploymentResponse
	switch err := json.Unmarshal(reply.Message, &response); {
	case err != nil:
		return fmt.Errorf("error unmarshalling commit response for %s: %w", n, err)
	case !response.OK:
		return fmt.Errorf("error committing for %s: %s: %w", n, response.Error, ErrDeploymentFailed)
	}

	return nil
}
// revertDeployment sends a RevertDeploymentMessage for node n of this ensemble
// to the actor h. It is fire-and-forget: no reply is awaited, and a failure to
// build or send the message is only logged at debug level.
func (o *Orchestrator) revertDeployment(n string, h actor.Handle) {
	payload := RevertDeploymentMessage{
		EnsembleID: o.id,
		NodeID:     n,
	}
	envelope, err := actor.Message(o.actor.Handle(), h, RevertDeploymentBehavior, payload)
	if err != nil {
		log.Debugf("failed to create revert message for %s: %s", n, err)
		return
	}

	err = o.actor.Send(envelope)
	if err != nil {
		log.Debugf("failed to send revert message for %s: %s", n, err)
	}
}
// allocate asks the actor h to deploy every allocation configured for node n
// and returns the handles reported in the reply, keyed by allocation name.
//
// For each allocation listed in the node's configuration the request carries
// its executor, resources, execution spec and the provision scripts looked up
// in the ensemble's script table.
//
// An error is returned when the message cannot be built or invoked, when no
// reply arrives within AllocationDeploymentTimeout, when the reply cannot be
// decoded or is negative, or when the reply lacks one of the requested
// allocations. Timeouts, negative replies and incomplete replies wrap
// ErrDeploymentFailed.
func (o *Orchestrator) allocate(n string, h actor.Handle) (map[string]actor.Handle, error) {
	allocs := make(map[string]AllocationDeploymentConfig)
	// NOTE(review): the second results of Node and Allocation are ignored, so a
	// name missing from the config silently yields a zero-value entry — confirm
	// that callers only pass names known to the config.
	ncfg, _ := o.cfg.Node(n)
	for _, a := range ncfg.Allocations {
		acfg, _ := o.cfg.Allocation(a)
		provisionScripts := make(map[string][]byte)
		for _, p := range acfg.Provision {
			provisionScripts[p] = o.cfg.V1.Scripts[p]
		}
		allocs[a] = AllocationDeploymentConfig{
			Executor:         acfg.Executor,
			Resources:        acfg.Resources,
			Execution:        acfg.Execution,
			ProvisionScripts: provisionScripts,
		}
	}

	msg, err := actor.Message(
		o.actor.Handle(),
		h,
		AllocationDeploymentBehavior,
		AllocationDeploymentRequest{
			EnsembleID:  o.id,
			NodeID:      n,
			Allocations: allocs,
		},
		actor.WithMessageTimeout(AllocationDeploymentTimeout),
	)
	if err != nil {
		return nil, fmt.Errorf("failed to create allocation message for %s: %w", n, err)
	}

	replyCh, err := o.actor.Invoke(msg)
	if err != nil {
		return nil, fmt.Errorf("failed to invoke allocate for %s: %w", n, err)
	}

	var reply actor.Envelope
	select {
	case reply = <-replyCh:
	case <-time.After(AllocationDeploymentTimeout):
		// err is nil at this point, so wrap the sentinel instead; this keeps the
		// message well-formed and lets callers match the timeout with errors.Is.
		return nil, fmt.Errorf("timeout in allocation for %s: %w", n, ErrDeploymentFailed)
	}
	defer reply.Discard()

	var response AllocationDeploymentResponse
	if err := json.Unmarshal(reply.Message, &response); err != nil {
		return nil, fmt.Errorf("unmarshalling allocation response: %w", err)
	}
	if !response.OK {
		return nil, fmt.Errorf("allocation for %s failed: %s: %w", n, response.Error, ErrDeploymentFailed)
	}

	// verify that the allocation map has all the allocations
	for a := range allocs {
		if _, ok := response.Allocations[a]; !ok {
			return nil, fmt.Errorf("missing allocation %s for %s: %w", a, n, ErrDeploymentFailed)
		}
	}
	return response.Allocations, nil
}
// provision runs the provisioning stages for an allocated ensemble, in order:
//
//  0. start every allocation,
//  1. set up the subnet: (a) build the routing table and create the subnet on
//     every allocation, (b) add each allocation's peer with its IP,
//     (c) register its DNS name, (d) map its ports.
//
// Within a stage all requests run concurrently; a stage must succeed for every
// allocation before the next stage starts. When a stage fails, the errors of
// all its failed requests are merged into the returned error.
func (o *Orchestrator) provision(em EnsembleManifest) error {
	// perAllocation runs step once for every allocation in the manifest, all
	// concurrently, and merges the failures.
	perAllocation := func(step func(id string, manifest AllocationManifest) error) error {
		tasks := make([]func() error, 0, len(em.Allocations))
		for id, manifest := range em.Allocations {
			// Per-iteration copies: before Go 1.22 the range variables are
			// shared by all closures created in the loop.
			id, manifest := id, manifest
			tasks = append(tasks, func() error { return step(id, manifest) })
		}
		return runProvisionTasks(tasks)
	}

	// 0. start the allocations
	allocCfgs := o.cfg.Allocations()
	err := perAllocation(func(id string, manifest AllocationManifest) error {
		config, ok := allocCfgs[id]
		if !ok {
			return fmt.Errorf("error retreiving allocation from config: %w", ErrProvisioningFailed)
		}
		return o.provisionRequest(
			manifest.Handle,
			AllocationStartBehavior,
			AllocationStartRequest{
				Resources: config.Resources,
				Execution: config.Execution,
				Executor:  config.Executor,
			},
			"allocation start",
		)
	})
	if err != nil {
		return err
	}

	// 1. create subnet
	// 1.a generate routing table
	cidr, err := utils.GetRandomCIDRInRange(
		24,
		net.ParseIP("10.0.0.0"),
		net.ParseIP("10.255.255.255"),
		[]string{},
	)
	if err != nil {
		return fmt.Errorf("error getting random CIDR: %w", err)
	}

	usedIPs := make(map[string]bool)
	routingTable := make(map[string]string)      // IP -> peer
	indexRoutingTable := make(map[string]string) // node ID -> IP
	// NOTE(review): one IP is drawn per allocation but indexRoutingTable is
	// keyed by node ID, so for a node hosting several allocations the last IP
	// drawn wins — confirm this is intended.
	for _, manifest := range em.Allocations {
		ip, err := utils.GetNextIP(cidr, usedIPs)
		if err != nil {
			return fmt.Errorf("error getting next IP: %w", err)
		}
		routingTable[ip.String()] = em.Nodes[manifest.NodeID].Peer
		indexRoutingTable[manifest.NodeID] = ip.String()
	}

	err = perAllocation(func(_ string, manifest AllocationManifest) error {
		return o.provisionRequest(
			manifest.Handle,
			SubnetCreateBehavior,
			SubnetCreateRequest{
				SubnetID:     em.ID,
				RoutingTable: routingTable,
			},
			"subnet create",
		)
	})
	if err != nil {
		return err
	}

	// 1.b create and plug IPs
	err = perAllocation(func(_ string, manifest AllocationManifest) error {
		return o.provisionRequest(
			manifest.Handle,
			SubnetAddPeerBehavior,
			SubnetAddPeerRequest{
				SubnetID: em.ID,
				IP:       indexRoutingTable[manifest.NodeID],
				PeerID:   em.Nodes[manifest.NodeID].Peer,
			},
			"subnet add-peer",
		)
	})
	if err != nil {
		return err
	}

	// 1.c configure DNS
	err = perAllocation(func(_ string, manifest AllocationManifest) error {
		return o.provisionRequest(
			manifest.Handle,
			SubnetDNSAddRecordBehavior,
			SubnetDNSAddRecordRequest{
				SubnetID:   em.ID,
				DomainName: manifest.DNSName,
				IP:         indexRoutingTable[manifest.NodeID],
			},
			"subnet DNS add-record",
		)
	})
	if err != nil {
		return err
	}

	// 1.d configure port mapping: one request per (allocation, port pair)
	var portTasks []func() error
	for _, manifest := range em.Allocations {
		for srcPort, destPort := range manifest.Ports {
			manifest, srcPort, destPort := manifest, srcPort, destPort
			portTasks = append(portTasks, func() error {
				return o.provisionRequest(
					manifest.Handle,
					SubnetMapPortBehavior,
					SubnetMapPortRequest{
						Protocol:   "TCP", // TODO: add support in AllocationManifest for protocol
						SourceIP:   "0.0.0.0",
						SourcePort: strconv.Itoa(srcPort),
						DestIP:     indexRoutingTable[manifest.NodeID],
						DestPort:   strconv.Itoa(destPort),
					},
					"subnet map-port",
				)
			})
		}
	}
	return runProvisionTasks(portTasks)
}

// provisionRequest sends payload to dest under the given behavior and waits
// for a reply of the form {OK, Error}. step names the operation in error
// messages. A timeout or a negative reply wraps ErrDeploymentFailed, the
// sentinel this code has always used for them, so errors.Is checks on that
// sentinel keep matching.
func (o *Orchestrator) provisionRequest(dest actor.Handle, behavior string, payload interface{}, step string) error {
	// NOTE(review): the message is given a 5s expiry while the reply is awaited
	// for up to 2 minutes — confirm the two are meant to differ.
	msg, err := actor.Message(
		o.actor.Handle(),
		dest,
		behavior,
		payload,
		actor.WithMessageExpiry(uint64(time.Now().Add(5*time.Second).UnixNano())),
	)
	if err != nil {
		return fmt.Errorf("error creating %s message: %w", step, err)
	}

	replyCh, err := o.actor.Invoke(msg)
	if err != nil {
		return fmt.Errorf("error invoking %s message: %w", step, err)
	}

	var reply actor.Envelope
	select {
	case reply = <-replyCh:
	case <-time.After(2 * time.Minute):
		return fmt.Errorf("timeout waiting for %s reply: %w", step, ErrDeploymentFailed)
	}
	defer reply.Discard()

	var response struct {
		OK    bool
		Error string
	}
	if err := json.Unmarshal(reply.Message, &response); err != nil {
		return fmt.Errorf("error unmarshalling %s response: %w", step, err)
	}
	if !response.OK {
		return fmt.Errorf("%s failed: %s: %w", step, response.Error, ErrDeploymentFailed)
	}
	return nil
}

// runProvisionTasks runs all tasks concurrently, waits for them to finish and
// merges their errors (in task order) into one; it returns nil when every
// task succeeded.
func runProvisionTasks(tasks []func() error) error {
	errs := make([]error, len(tasks))
	var wg sync.WaitGroup
	for i, task := range tasks {
		wg.Add(1)
		go func(i int, task func() error) {
			defer wg.Done()
			errs[i] = task() // each goroutine writes only its own slot
		}(i, task)
	}
	wg.Wait()

	var aggErr error
	for _, err := range errs {
		switch {
		case err == nil:
		case aggErr == nil:
			aggErr = err
		default:
			aggErr = fmt.Errorf("%w\n%w", aggErr, err)
		}
	}
	return aggErr
}
// revert calls revertDeployment for every node listed in the manifest, using
// the handle recorded for that node.
func (o *Orchestrator) revert(mf EnsembleManifest) {
	for nodeID, node := range mf.Nodes {
		handle := node.Handle
		o.revertDeployment(nodeID, handle)
	}
}
// acceptPeerLocation reports whether the peer peerID at location loc may host
// the node nodeID.
//
// Unknown nodes are never accepted. A node pinned to an explicit peer accepts
// exactly that peer, regardless of location. Otherwise the location must be
// included in at least one entry of the accept list (when that list is not
// empty) and in no entry of the reject list.
func (o *Orchestrator) acceptPeerLocation(nodeID, peerID string, loc Location) bool {
	nodeCfg, known := o.cfg.Node(nodeID)
	if !known {
		return false
	}

	// explicit peer placement overrides any location constraint
	if nodeCfg.Peer != "" {
		return nodeCfg.Peer == peerID
	}

	// an accept list, when present, must contain the location
	if allowed := nodeCfg.Location.Accept; len(allowed) > 0 {
		matched := false
		for _, candidate := range allowed {
			if candidate.Includes(loc) {
				matched = true
				break
			}
		}
		if !matched {
			return false
		}
	}

	// the reject list must not contain the location
	for _, banned := range nodeCfg.Location.Reject {
		if banned.Includes(loc) {
			return false
		}
	}
	return true
}
// makeInitialBidRequest builds the bid request from the orchestrator's
// complete ensemble configuration.
func (o *Orchestrator) makeInitialBidRequest() (EnsembleBidRequest, error) {
	fullConfig := &o.cfg
	return o.ensembleConfigToBidRequest(fullConfig)
}
// makeResidualBidRequest builds a bid request restricted to the nodes that
// have no entry in candidate, together with the allocations those nodes
// reference. Every key of exclusion is appended to the request's peer
// exclusion list.
func (o *Orchestrator) makeResidualBidRequest(candidate map[string][]Bid, exclusion map[string]struct{}) (EnsembleBidRequest, error) {
	remaining := &EnsembleConfigV1{
		Allocations: make(map[string]AllocationConfig),
		Nodes:       make(map[string]NodeConfig),
	}

	for nodeID, nodeCfg := range o.cfg.V1.Nodes {
		if _, covered := candidate[nodeID]; covered {
			continue
		}
		remaining.Nodes[nodeID] = nodeCfg
		log.Debugf("still looking for node %s", nodeID)
		for _, name := range nodeCfg.Allocations {
			remaining.Allocations[name] = o.cfg.V1.Allocations[name]
		}
	}

	request, err := o.ensembleConfigToBidRequest(&EnsembleConfig{V1: remaining})
	if err != nil {
		return request, err
	}

	for peer := range exclusion {
		request.PeerExclusion = append(request.PeerExclusion, peer)
	}
	return request, nil
}
// ensembleConfigToBidRequest converts an ensemble configuration into a bid
// request carrying one BidRequest per node.
//
// For every node it collects, over the node's allocations that exist in the
// config: the distinct executors, the sum of the resources, the static public
// ports and the number of dynamic ones (port entries whose Public is 0).
// Allocation names missing from the config are skipped. An error from adding
// resources aborts the conversion and yields an empty request.
func (o *Orchestrator) ensembleConfigToBidRequest(config *EnsembleConfig) (EnsembleBidRequest, error) {
	spec := config.V1
	out := EnsembleBidRequest{ID: o.id}

	for nodeID, node := range spec.Nodes {
		var (
			total    types.Resources
			execs    []AllocationExecutor
			fixed    []int
			floating int
		)

		for _, name := range node.Allocations {
			alloc, known := spec.Allocations[name]
			if !known {
				continue
			}

			if !containsExecutor(execs, alloc.Executor) {
				execs = append(execs, alloc.Executor)
			}

			if err := total.Add(alloc.Resources); err != nil {
				return EnsembleBidRequest{}, err
			}

			// count the node's port entries that belong to this allocation
			for _, port := range node.Ports {
				switch {
				case port.Allocation != name:
					// belongs to another allocation
				case port.Public == 0:
					floating++
				default:
					fixed = append(fixed, port.Public)
				}
			}
		}

		req := BidRequest{
			V1: &BidRequestV1{
				NodeID:    nodeID,
				Location:  node.Location,
				Executors: execs,
				Resources: total,
			},
		}
		req.V1.PublicPorts.Static = fixed
		req.V1.PublicPorts.Dynamic = floating

		out.Request = append(out.Request, req)
	}
	return out, nil
}
// supervise is a placeholder: it is not implemented yet and does nothing.
func (o *Orchestrator) supervise() {
// TODO
}
// Stop is a placeholder: it is not implemented yet and does nothing.
func (o *Orchestrator) Stop() {
// TODO
}
// containsExecutor reports whether executor is already present in executors.
func containsExecutor(executors []AllocationExecutor, executor AllocationExecutor) bool {
	for i := 0; i < len(executors); i++ {
		if executors[i] == executor {
			return true
		}
	}
	return false
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package ensemblev1
import (
"fmt"
"strings"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/transform"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/tree"
)
// NewEnsemblev1Transformer builds the transformer for ensemble v1 specs. The
// transformer is assembled from four maps of path pattern to transformer
// function, handed to transform.NewTransformer in the order listed below.
func NewEnsemblev1Transformer() transform.Transformer {
	type stage = map[tree.Path]transform.TransformerFunc

	// maps of named entries become slices
	namedSlices := stage{
		"allocations.*.volumes": transform.MapToNamedSliceTransformer("volume"),
		"volumes":               transform.MapToNamedSliceTransformer("volume"),
		"resources":             transform.MapToNamedSliceTransformer("resource"),
	}

	// per-allocation volumes and resources, and script contents
	allocationDetails := stage{
		"allocations.*.volumes.[]": TransformVolume,
		"allocations.*.resources":  TransformResources,
		"scripts.*":                TransformStringToBytes,
	}

	// spec configs and edge constraints
	specConfigs := stage{
		"allocations.*.execution":         transform.SpecConfigTransformer("execution"),
		"allocations.*.volumes.[].remote": transform.SpecConfigTransformer("remote volume"),
		"edge_constraints.[]":             TransformEdgeConstraint,
	}

	// the root of the spec
	root := stage{
		"": TransformSpec,
	}

	return transform.NewTransformer([]stage{namedSlices, allocationDetails, specConfigs, root})
}
// TransformStringToBytes converts a string value into its []byte form.
// Any non-string input is rejected with an error.
func TransformStringToBytes(_ *map[string]interface{}, data any, _ tree.Path) (any, error) {
	if text, isString := data.(string); isString {
		return []byte(text), nil
	}
	return nil, fmt.Errorf("invalid string data: %v", data)
}
// TransformSpec renames the "edge_constraints" key of the spec to "edges"
// (when present) and returns the spec wrapped in a map under the "V1" key.
// The input must be a map; anything else is an error.
func TransformSpec(_ *map[string]interface{}, data any, _ tree.Path) (any, error) {
	spec, isMap := data.(map[string]any)
	if !isMap {
		return nil, fmt.Errorf("invalid spec configuration: %v", data)
	}

	// move edge_constraints to edges
	constraints, present := spec["edge_constraints"]
	if present {
		delete(spec, "edge_constraints")
		spec["edges"] = constraints
	}

	wrapped := map[string]any{"V1": spec}
	return wrapped, nil
}
// TransformEdgeConstraint replaces the two-element "edges" list of an edge
// constraint with the properties "S" (first element) and "T" (second element).
// A constraint without "edges" is returned unchanged; an "edges" value that is
// not a list of exactly two elements is an error.
func TransformEdgeConstraint(_ *map[string]interface{}, data any, _ tree.Path) (any, error) {
	constraint, isMap := data.(map[string]any)
	if !isMap {
		return nil, fmt.Errorf("invalid edge constraints: %v", data)
	}

	rawEdges, present := constraint["edges"]
	if !present {
		return constraint, nil
	}

	pair, isList := rawEdges.([]any)
	if !isList || len(pair) != 2 {
		return nil, fmt.Errorf("invalid edges parameter: %v", rawEdges)
	}

	constraint["S"], constraint["T"] = pair[0], pair[1]
	delete(constraint, "edges")
	return constraint, nil
}
// TransformVolume transforms the volume configuration and handles inheritance.
// The volume configuration can be a string in the format "name:mountpoint" or a map.
// If a volume with the same name is defined in the root "volumes" list, the
// result is the root definition overlaid with the allocation's own settings.
//
// The overlay is built in a fresh map: the root definition is never modified,
// so overrides made by one allocation cannot leak into the root or into other
// allocations that inherit the same volume.
func TransformVolume(root *map[string]interface{}, data any, path tree.Path) (any, error) {
	var config map[string]any

	// If the data is a string, split it into name and mountpoint.
	switch v := data.(type) {
	case string:
		mapping := strings.Split(v, ":")
		if len(mapping) != 2 {
			return nil, fmt.Errorf("invalid volume configuration: %v", data)
		}
		config = map[string]any{
			"name":       mapping[0],
			"mountpoint": mapping[1],
		}
	case map[string]any:
		config = v
	default:
		return nil, fmt.Errorf("invalid volume configuration: %v", data)
	}

	if path.Matches("allocations.*.volumes.[]") {
		// Handle volume inheritance
		parent := tree.NewPath("")
		c, err := transform.GetConfigAtPath(*root, parent.Next("volumes"))
		if err != nil {
			// the root volumes could not be read: nothing to inherit
			return config, nil
		}
		// A value that is not a slice yields no elements, i.e. nothing to inherit.
		volumes, _ := transform.ToAnySlice(c)
		for _, v := range volumes {
			volume, ok := v.(map[string]any)
			if !ok || volume["name"] != config["name"] {
				continue
			}
			// Merge into a copy: root definition first, allocation settings on top.
			merged := make(map[string]any, len(volume)+len(config))
			for k, val := range volume {
				merged[k] = val
			}
			for k, val := range config {
				merged[k] = val
			}
			config = merged
		}
	}
	return config, nil
}
// TransformResources transforms the resources configuration and handles inheritance.
// The resources configuration can be a string reference (the name of a root
// resource definition) or a map.
// If a resource with the same name is defined in the root "resources" list,
// the result is the root definition overlaid with the allocation's own settings.
//
// The overlay is built in a fresh map: the root definition is never modified,
// so overrides made by one allocation cannot leak into the root or into other
// allocations that reference the same resource.
func TransformResources(root *map[string]interface{}, data any, path tree.Path) (any, error) {
	var config map[string]any

	// If the data is a string, transform it to a map with the name as the reference
	switch v := data.(type) {
	case string:
		config = map[string]any{
			"name": v,
		}
	case map[string]any:
		config = v
	default:
		return nil, fmt.Errorf("invalid resources configuration: %v", data)
	}

	if path.Matches("allocations.*.resources") {
		// Handle resource inheritance
		parent := tree.NewPath("")
		c, err := transform.GetConfigAtPath(*root, parent.Next("resources"))
		if err != nil {
			// the root resources could not be read: nothing to inherit
			return config, nil
		}
		// A value that is not a slice yields no elements, i.e. nothing to inherit.
		resources, _ := transform.ToAnySlice(c)
		for _, v := range resources {
			rcs, ok := v.(map[string]any)
			if !ok || rcs["name"] != config["name"] {
				continue
			}
			// Merge into a copy: root definition first, allocation settings on top.
			merged := make(map[string]any, len(rcs)+len(config))
			for k, val := range rcs {
				merged[k] = val
			}
			for k, val := range config {
				merged[k] = val
			}
			config = merged
		}
	}
	return config, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package ensemblev1
import (
"fmt"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/tree"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/validate"
)
// NewEnsembleV1Validator creates the validator for ensemble v1 configurations.
// It checks the "V1" node with ValidateSpec and every "V1.allocations.*" node
// with ValidateAllocation.
func NewEnsembleV1Validator() validate.Validator {
	rules := map[tree.Path]validate.ValidatorFunc{
		"V1":               ValidateSpec,
		"V1.allocations.*": ValidateAllocation,
	}
	return validate.NewValidator(rules)
}
// ValidateSpec checks the root configuration for consistency.
//
// The spec must be a map containing a non-empty "allocations" map. A missing,
// nil or empty "allocations" entry is reported as required; an entry of any
// other type is reported as invalid instead of causing a panic.
func ValidateSpec(_ *map[string]any, data any, _ tree.Path) error {
	spec, ok := data.(map[string]any)
	if !ok {
		return fmt.Errorf("invalid spec configuration: %v", data)
	}
	// TODO: Specify and complete validation - Dawit Abate
	// Check if the allocations map is present and not empty.
	raw := spec["allocations"]
	if raw == nil {
		return fmt.Errorf("allocations list is required")
	}
	allocations, ok := raw.(map[string]any)
	if !ok {
		return fmt.Errorf("invalid allocations configuration: %v", raw)
	}
	if len(allocations) == 0 {
		return fmt.Errorf("allocations list is required")
	}
	return nil
}
// ValidateAllocation checks a single allocation: it must be a map and must
// carry a non-nil "execution" entry.
func ValidateAllocation(_ *map[string]any, data any, _ tree.Path) error {
	allocation, isMap := data.(map[string]any)
	// TODO: Specify and complete validation - Dawit Abate
	switch {
	case !isMap:
		return fmt.Errorf("invalid allocation configuration: %v", data)
	case allocation["execution"] == nil:
		// every allocation needs an execution
		return fmt.Errorf("allocation must have an execution")
	default:
		return nil
	}
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package parser
import (
"gitlab.com/nunet/device-management-service/dms/jobs/parser/ensemblev1"
)
// registry holds the parsers available to this package, keyed by spec type.
var registry *Registry

// init creates the package registry and registers the ensemble v1 parser
// under SpecTypeEnsembleV1.
func init() {
	registry = &Registry{parsers: map[SpecType]Parser{}}

	// Register Nunet parser.
	registry.RegisterParser(
		SpecTypeEnsembleV1,
		NewBasicParser(
			ensemblev1.NewEnsemblev1Transformer(),
			ensemblev1.NewEnsembleV1Validator(),
		),
	)
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package parser
import (
"fmt"
)
// Parse decodes data into result using the parser registered for specType.
// It fails when no parser is registered for that type, and otherwise returns
// whatever the parser returns.
func Parse(specType SpecType, data []byte, result any) error {
	if parser, exists := registry.GetParser(specType); exists {
		return parser.Parse(data, result)
	}
	return fmt.Errorf("parser for spec type %s not found", specType)
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package parser
import (
"encoding/json"
"fmt"
"github.com/go-viper/mapstructure/v2"
yaml "gopkg.in/yaml.v3"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/transform"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/validate"
)
// SpecType names a specification format; it is the key under which parsers
// are stored.
type SpecType string
const (
// SpecTypeEnsembleV1 identifies the ensemble v1 specification format.
SpecTypeEnsembleV1 SpecType = "ensembleV1"
)
// DefaultTagName is the struct tag name BasicParser passes to the mapstructure
// decoder when filling the destination value.
const DefaultTagName = "json"
// Parser decodes raw specification data into dest.
type Parser interface {
Parse(data []byte, dest any) error
}
// BasicParser is a Parser that transforms the raw configuration, validates the
// transformed result and then decodes it into the destination value.
type BasicParser struct {
// validator checks the transformed configuration.
validator validate.Validator
// transformer reshapes the raw configuration before validation.
transformer transform.Transformer
}
// NewBasicParser returns a BasicParser that uses the given transformer and
// validator.
func NewBasicParser(transformer transform.Transformer, validator validate.Validator) Parser {
	var p BasicParser
	p.validator = validator
	p.transformer = transformer
	return p
}
// Parse decodes data into result.
//
// The data is read as YAML and, if that fails, as JSON. The raw configuration
// is then transformed, validated, and decoded into result with mapstructure,
// using DefaultTagName as the struct tag. result must be a value mapstructure
// can decode into (its Result field).
//
// Every failure wraps the underlying error, so callers can inspect it with
// errors.Is / errors.As. When neither YAML nor JSON decoding succeeds, both
// errors are reported.
func (p BasicParser) Parse(data []byte, result any) error {
	var rawConfig map[string]any

	// Try to unmarshal as YAML first
	if yamlErr := yaml.Unmarshal(data, &rawConfig); yamlErr != nil {
		// If YAML fails, try JSON, starting from a clean map so nothing left
		// behind by the failed YAML attempt survives into the result.
		rawConfig = nil
		if jsonErr := json.Unmarshal(data, &rawConfig); jsonErr != nil {
			return fmt.Errorf("failed to parse config: yaml: %w; json: %w", yamlErr, jsonErr)
		}
	}

	// Apply transformers
	transformed, err := p.transformer.Transform(&rawConfig)
	if err != nil {
		return fmt.Errorf("failed to transform config: %w", err)
	}
	transformedMap, ok := transformed.(map[string]any)
	if !ok {
		return fmt.Errorf("transformed config is not a map: %v", transformed)
	}

	// Validate the transformed configuration
	if err = p.validator.Validate(&transformedMap); err != nil {
		return err
	}

	// Create a new mapstructure decoder
	decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
		Result:  result,
		TagName: DefaultTagName,
	})
	if err != nil {
		return fmt.Errorf("failed to create decoder: %w", err)
	}

	// Decode the transformed configuration
	if err = decoder.Decode(transformed); err != nil {
		return fmt.Errorf("failed to decode config: %w", err)
	}
	return nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package parser
import (
"sync"
)
// Registry maps spec types to parsers. Its methods take mu around every
// access to the parsers map.
type Registry struct {
// parsers holds the registered parser for each spec type.
parsers map[SpecType]Parser
// mu guards parsers.
mu sync.RWMutex
}
// RegisterParser stores p as the parser for specType, replacing any parser
// previously registered for that type. It holds the write lock while doing so.
func (r *Registry) RegisterParser(specType SpecType, p Parser) {
r.mu.Lock()
defer r.mu.Unlock()
r.parsers[specType] = p
}
// GetParser looks up the parser registered for specType under the read lock.
// The boolean reports whether one was found; when it is false the returned
// parser is nil.
func (r *Registry) GetParser(specType SpecType) (Parser, bool) {
	r.mu.RLock()
	defer r.mu.RUnlock()

	if found, ok := r.parsers[specType]; ok {
		return found, true
	}
	return nil, false
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package transform
import (
"fmt"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/tree"
)
// SpecConfigTransformer returns a transformer that reshapes a map into
// {"type": <the map's "type" value>, "params": <all other entries>}.
// A nil input is passed through as nil; a non-map input is an error whose
// message names the configuration with specName.
func SpecConfigTransformer(specName string) TransformerFunc {
	return func(_ *map[string]interface{}, data any, _ tree.Path) (any, error) {
		if data == nil {
			return nil, nil
		}

		spec, isMap := data.(map[string]any)
		if !isMap {
			return nil, fmt.Errorf("invalid %s configuration: %v", specName, data)
		}

		// everything except "type" becomes a parameter
		params := map[string]any{}
		for key, value := range spec {
			if key == "type" {
				continue
			}
			params[key] = value
		}

		return map[string]any{
			"type":   spec["type"],
			"params": params,
		}, nil
	}
}
// MapToNamedSliceTransformer returns a transformer that turns a map into a
// slice of its values. Values that are maps get their key stored in a "name"
// field (the value map itself is modified); nil values are replaced by a map
// holding only that name; other values are appended unchanged. The slice
// follows map iteration order, so its order is not fixed.
// A nil input is passed through as nil; a non-map input is an error whose
// message uses name to identify the configuration.
func MapToNamedSliceTransformer(name string) TransformerFunc {
	return func(_ *map[string]interface{}, data any, _ tree.Path) (any, error) {
		if data == nil {
			return nil, nil
		}

		entries, isMap := data.(map[string]any)
		if !isMap {
			return nil, fmt.Errorf("invalid %s configuration: %v", name, data)
		}

		named := make([]any, 0, len(entries))
		for key, entry := range entries {
			if entry == nil {
				entry = map[string]any{}
			}
			if fields, ok := entry.(map[string]any); ok {
				fields["name"] = key
			}
			named = append(named, entry)
		}
		return named, nil
	}
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package transform
import (
"fmt"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/tree"
)
// TransformerFunc is a function that transforms a part of the configuration.
// It modifies the data to conform to the expected structure and returns the transformed data.
// It takes the root configuration, the data to transform and the current path in the tree.
// A non-nil error aborts the whole transformation.
type TransformerFunc func(*map[string]interface{}, interface{}, tree.Path) (any, error)
// Transformer is a configuration transformer: it takes the raw configuration
// and returns its transformed form, or an error.
type Transformer interface {
Transform(*map[string]interface{}) (interface{}, error)
}
// TransformerImpl is the implementation of the Transformer interface.
type TransformerImpl struct {
// transformers is an ordered list of passes. Each pass maps path patterns to
// the function applied to the nodes whose path matches the pattern.
transformers []map[tree.Path]TransformerFunc
}
// NewTransformer returns a Transformer that applies the given passes, in
// slice order, to a configuration.
func NewTransformer(transformers []map[tree.Path]TransformerFunc) Transformer {
	var impl TransformerImpl
	impl.transformers = transformers
	return impl
}
// Transform runs every pass over the configuration, one after another, each
// starting at the root path and working on the output of the previous pass.
// It stops at the first error. The final result is returned as produced by
// the last pass, without normalization.
func (t TransformerImpl) Transform(rawConfig *map[string]interface{}) (interface{}, error) {
	var current interface{} = *rawConfig
	for i := range t.transformers {
		next, err := t.transform(rawConfig, current, tree.NewPath(), t.transformers[i])
		if err != nil {
			return nil, err
		}
		current = next
	}
	return current, nil
}
// transform applies one pass to data, which sits at path, and then to all of
// its descendants.
//
// First every function whose pattern matches path is applied to the node
// itself (in map iteration order, which is not fixed). Then, if the resulting
// node is a map, each value is transformed in place under path.Next(key); if
// it is a slice, it is converted to []any and each element is transformed
// under path.Next("[i]"). Any other value is returned as is.
func (t TransformerImpl) transform(root *map[string]interface{}, data any, path tree.Path, transformers map[tree.Path]TransformerFunc) (interface{}, error) {
	var err error

	// the node itself
	for pattern, apply := range transformers {
		if !path.Matches(pattern) {
			continue
		}
		if data, err = apply(root, data, path); err != nil {
			return nil, err
		}
	}

	// map children
	if asMap, isMap := data.(map[string]interface{}); isMap {
		for key, child := range asMap {
			childPath := path.Next(key)
			asMap[key], err = t.transform(root, child, childPath, transformers)
			if err != nil {
				return nil, err
			}
		}
		return asMap, nil
	}

	// slice children; anything that is not a slice is a leaf
	items, sliceErr := ToAnySlice(data)
	if sliceErr != nil {
		return data, nil
	}
	for i, child := range items {
		childPath := path.Next(fmt.Sprintf("[%d]", i))
		items[i], err = t.transform(root, child, childPath, transformers)
		if err != nil {
			return nil, err
		}
	}
	return items, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package transform
import (
"fmt"
"reflect"
"sort"
"strconv"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/tree"
)
// GetConfigAtPath retrieves the part of the configuration found at path.
//
// Each part of the path is resolved against the current node: a map is
// indexed by the part itself (a missing key yields nil), a slice by the
// integer enclosed in the part's first and last character (e.g. "[2]").
// An error is returned when a part has to be resolved against a value that is
// neither a map nor a slice, when a slice index cannot be parsed, or when it
// lies outside the slice.
func GetConfigAtPath(config map[string]interface{}, path tree.Path) (any, error) {
	current := any(config)
	for _, key := range path.Parts() {
		switch v := current.(type) {
		case map[string]any:
			current = v[key]
		case []any:
			i, err := parsePathIndex(key, len(v))
			if err != nil {
				return nil, err
			}
			current = v[i]
		case []map[string]any:
			i, err := parsePathIndex(key, len(v))
			if err != nil {
				return nil, err
			}
			current = v[i]
		default:
			return nil, fmt.Errorf("invalid data type: %v", current)
		}
	}
	return current, nil
}

// parsePathIndex extracts the slice index from a path part such as "[2]" and
// checks that it addresses an element of a slice with the given length.
func parsePathIndex(key string, length int) (int, error) {
	// The part needs an opening and a closing character around the number;
	// anything shorter cannot be sliced safely.
	if len(key) < 2 {
		return 0, fmt.Errorf("invalid index: %v", key)
	}
	i, err := strconv.Atoi(key[1 : len(key)-1])
	if err != nil {
		return 0, fmt.Errorf("invalid index: %v", key)
	}
	if i < 0 || i >= length {
		return 0, fmt.Errorf("index out of range: %v (length %d)", key, length)
	}
	return i, nil
}
// ToAnySlice copies the elements of a slice of any element type into a new
// []any. It returns an error when the input is not a slice (this includes nil
// and arrays).
func ToAnySlice(slice any) ([]any, error) {
	rv := reflect.ValueOf(slice)
	if rv.Kind() != reflect.Slice {
		return nil, fmt.Errorf("input is not a slice. type: %T", slice)
	}

	out := make([]any, 0, rv.Len())
	for i := 0; i < rv.Len(); i++ {
		out = append(out, rv.Index(i).Interface())
	}
	return out, nil
}
// normalizeMap returns a normalized deep copy of m.
//
//   - Maps are rebuilt with the same key type and interface{} values, each
//     value normalized recursively. Entries whose value is nil are preserved.
//   - Slices are converted to []interface{} with every element normalized, and
//     then sorted by the fmt.Sprint rendering of their elements, so the
//     original element order is not retained.
//   - Every other value is returned unchanged.
func normalizeMap(m interface{}) interface{} {
	v := reflect.ValueOf(m)
	switch v.Kind() {
	case reflect.Map:
		// Create a new map to hold normalized values
		elemType := reflect.TypeOf((*interface{})(nil)).Elem()
		newMap := reflect.MakeMap(reflect.MapOf(v.Type().Key(), elemType))
		for _, key := range v.MapKeys() {
			newValue := normalizeMap(v.MapIndex(key).Interface())
			// reflect.ValueOf(nil) is the zero Value, and SetMapIndex treats a
			// zero Value as "delete this key". Store an explicit nil interface
			// instead so that nil-valued entries survive normalization.
			elem := reflect.Zero(elemType)
			if newValue != nil {
				elem = reflect.ValueOf(newValue)
			}
			newMap.SetMapIndex(key, elem)
		}
		return newMap.Interface()
	case reflect.Slice:
		// Create a new []interface{} slice to hold normalized values
		newSlice := make([]interface{}, v.Len())
		for i := 0; i < v.Len(); i++ {
			newSlice[i] = normalizeMap(v.Index(i).Interface())
		}
		// Order elements by their printed form so that slices holding the same
		// elements in a different order normalize to the same result.
		sort.Slice(newSlice, func(i, j int) bool {
			return fmt.Sprint(newSlice[i]) < fmt.Sprint(newSlice[j])
		})
		return newSlice
	default:
		// For other types, return as is
		return m
	}
}
// Normalize is the exported entry point to normalizeMap: it returns the
// normalized form of m, in which maps are rebuilt with interface{} values and
// slices become sorted []interface{} values.
func Normalize(m any) interface{} {
	normalized := normalizeMap(m)
	return normalized
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package tree
import (
"strings"
)
// Tokens that make up the path and pattern syntax used by Path.
const (
// configPathSeparator separates the parts of a path.
configPathSeparator = "."
// configPathMatchAny is the pattern part that matches any single path part.
configPathMatchAny = "*"
// configPathMatchAnyMultiple is the pattern part that may span several
// consecutive path parts (see matchParts).
configPathMatchAnyMultiple = "**"
// configPathList is the pattern part that matches a path part starting with
// '[' and ending with ']', i.e. a list index such as "[0]".
configPathList = "[]"
)
// Path is a custom type for representing paths in the configuration.
// It is a plain string whose parts are separated by configPathSeparator;
// the same type is used for concrete paths and for match patterns.
type Path string
// NewPath builds a Path by joining the given parts with the path separator.
// Called without arguments it yields the empty Path.
func NewPath(path ...string) Path {
	joined := strings.Join(path, configPathSeparator)
	return Path(joined)
}
// Parts splits the path on the path separator and returns the pieces.
// Because it relies on strings.Split, the empty Path yields a single empty
// part rather than an empty slice.
func (p Path) Parts() []string {
	raw := string(p)
	return strings.Split(raw, configPathSeparator)
}
// Parent returns the path with its last part removed, i.e. everything before
// the final separator. A path without any separator has the empty Path as its
// parent.
func (p Path) Parent() Path {
	cut := strings.LastIndex(string(p), configPathSeparator)
	if cut < 0 {
		return ""
	}
	return p[:cut]
}
// Next returns the child path obtained by appending path to p.
// An empty argument leaves p unchanged, and appending to the empty Path
// yields the argument itself without a leading separator.
func (p Path) Next(path string) Path {
	switch {
	case path == "":
		return p
	case p == "":
		return Path(path)
	default:
		return p + configPathSeparator + Path(path)
	}
}
// Last returns the final part of the path: the text after the last separator,
// or the whole path when it contains no separator.
func (p Path) Last() string {
	s := string(p)
	if cut := strings.LastIndex(s, configPathSeparator); cut >= 0 {
		return s[cut+len(configPathSeparator):]
	}
	return s
}
// Matches reports whether the path satisfies pattern. Both are split into
// their parts and compared by matchParts, which defines the wildcard rules.
func (p Path) Matches(pattern Path) bool {
	return matchParts(p.Parts(), pattern.Parts())
}
// String converts the path back to its plain string form.
func (p Path) String() string {
	str := string(p)
	return str
}
// matchParts reports whether pathParts satisfies patternParts.
//
// Pattern parts are compared with the path part at the same position:
//   - "*" accepts any single part;
//   - "[]" accepts a part that starts with '[' and ends with ']';
//   - "**" as the last pattern part accepts the remainder of the path;
//     elsewhere, the rest of the pattern is tried against every suffix of the
//     path starting at the current position;
//   - any other pattern part must equal the path part exactly.
//
// A pattern with more parts than the path never matches, and a path that is
// longer than a pattern not ending in "**" does not match either.
func matchParts(pathParts, patternParts []string) bool {
	// If the pattern is longer than the path, it can't match
	if len(pathParts) < len(patternParts) {
		return false
	}

	lastPattern := len(patternParts) - 1
	for i, part := range patternParts {
		switch part {
		case configPathMatchAnyMultiple:
			// if it is the last part of the pattern, it matches
			if i == lastPattern {
				return true
			}
			// Otherwise, try to match the rest of the pattern against each
			// suffix of the path. If none fits, evaluation continues
			// positionally with the next pattern part.
			for j := i; j < len(pathParts); j++ {
				if matchParts(pathParts[j:], patternParts[i+1:]) {
					return true
				}
			}
		case configPathList:
			// The part must be enclosed by []. Prefix/suffix checks are used
			// instead of byte indexing because a path part can be empty (the
			// root path consists of a single empty part).
			if !strings.HasPrefix(pathParts[i], "[") || !strings.HasSuffix(pathParts[i], "]") {
				return false
			}
		default:
			// If the part doesn't match, it doesn't match
			if part != configPathMatchAny && part != pathParts[i] {
				return false
			}
		}

		// If it is the last part of the pattern and the path is longer, it doesn't match
		if i == lastPattern && i < len(pathParts)-1 {
			return false
		}
	}

	return true
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package validate
import (
"fmt"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/tree"
)
// ValidatorFunc is a function that validates a part of the configuration.
// It takes the root configuration, the data to validate and the current path in the tree.
// A non-nil error aborts the validation run and is returned to the caller of Validate.
type ValidatorFunc func(*map[string]any, any, tree.Path) error
// Validator is a configuration validator.
// Implementations check a whole configuration tree and report the first
// problem they find as an error.
type Validator interface {
// Validate checks the given configuration and returns nil when it is valid.
Validate(*map[string]any) error
}
// ValidatorImpl is the implementation of the Validator interface.
type ValidatorImpl struct {
// validators maps a path pattern to the function that is run on every
// configuration node whose path matches that pattern.
validators map[tree.Path]ValidatorFunc
}
// NewValidator returns a Validator backed by the given pattern-to-function
// map. The map is stored as is, not copied.
func NewValidator(validators map[tree.Path]ValidatorFunc) Validator {
	var impl ValidatorImpl
	impl.validators = validators
	return impl
}
// Validate applies the validators to the configuration, starting at the root
// (empty) path and descending into nested maps and lists.
//
// It returns the first error reported by a validator, or an error when
// rawConfig itself is nil.
func (v ValidatorImpl) Validate(rawConfig *map[string]any) error {
	// Dereferencing a nil pointer below would panic; report it instead.
	if rawConfig == nil {
		return fmt.Errorf("validate: config is nil")
	}
	data := any(*rawConfig)
	return v.validate(rawConfig, data, tree.NewPath(), v.validators)
}
// validate checks one node of the configuration tree and then its children.
//
// Every validator whose pattern matches path is run on data first; afterwards
// the function recurses into the entries of a map[string]interface{} (child
// path = key) or the elements of a []interface{} (child path = "[index]").
// Other value types have no children. The first error stops the walk.
func (v ValidatorImpl) validate(root *map[string]interface{}, data any, path tree.Path, validators map[tree.Path]ValidatorFunc) error {
	for pattern, check := range validators {
		if !path.Matches(pattern) {
			continue
		}
		if err := check(root, data, path); err != nil {
			return err
		}
	}

	if children, isMap := data.(map[string]interface{}); isMap {
		for name, child := range children {
			if err := v.validate(root, child, path.Next(name), validators); err != nil {
				return err
			}
		}
		return nil
	}

	if items, isList := data.([]interface{}); isList {
		for idx, item := range items {
			if err := v.validate(root, item, path.Next(fmt.Sprintf("[%d]", idx)), validators); err != nil {
				return err
			}
		}
	}

	return nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package dms
import (
logging "github.com/ipfs/go-log/v2"
"go.uber.org/multierr"
)
// log is the package-level logger, created under the "dms" logger name.
var log = logging.Logger("dms")
// silenceLibp2pLogging is used to silence logs coming from libp2p
// imported libraries as they're enabled by default.
//
// Every listed logger is set to the "panic" level. Failures do not stop the
// process: all errors are collected and returned together, and the result is
// nil when every call succeeded.
//
// TODO: move this to observability?
func silenceLibp2pLogging() error {
	log.Debug("silencing libp2p logging")

	// Loggers addressed by their exact name.
	subsystems := []string{
		"libp2p",
		"swarm2",
		"basichost",
		"pubsub",
		"p2p-config",
		"routedhost",
		"relay",
		"autorelay",
		"autonat",
		"node",
		"p2p-holepunch",
		"rcmgr",
	}
	// Loggers addressed by expression.
	// NOTE(review): if these are evaluated as regular expressions, "dht/*"
	// means "dht" followed by any number of slashes and therefore matches
	// every logger name containing "dht" (likewise for "net") — confirm this
	// is the intended scope.
	patterns := []string{
		"dht/*",
		"net/*",
	}

	var errs error
	for _, name := range subsystems {
		errs = multierr.Append(errs, logging.SetLogLevel(name, "panic"))
	}
	for _, expr := range patterns {
		errs = multierr.Append(errs, logging.SetLogLevelRegex(expr, "panic"))
	}

	return errs
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package node
import (
"encoding/json"
"fmt"
"math/rand"
"time"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/dms/jobs"
"gitlab.com/nunet/device-management-service/types"
)
const (
// NewDeploymentBehavior is the behavior name for new deployment requests.
NewDeploymentBehavior = "/dms/node/deployment/new"
// MinDeploymentTime is the minimum time that must remain before a deployment
// message expires; newDeployment rejects messages with less time left.
MinDeploymentTime = time.Minute - time.Second
// bidStateGCInterval is the period at which gcBidState sweeps expired bids.
bidStateGCInterval = time.Minute
// bidStateTimeout is how long a remembered bid stays valid after rememberBid
// stores it.
bidStateTimeout = 5 * time.Minute
)
// NewDeploymentRequest is the message decoded by newDeployment.
type NewDeploymentRequest struct {
// Ensemble is the configuration the orchestrator is created from.
Ensemble jobs.EnsembleConfig
}
// NewDeploymentResponse is the reply sent by newDeployment.
type NewDeploymentResponse struct {
// Status is "OK" when the deployment was accepted and "ERROR" otherwise.
Status string
// EnsembleID is the ID of the created orchestrator; set on success only.
EnsembleID string `json:",omitempty"`
// Error describes why the request was rejected; set on failure only.
Error string `json:",omitempty"`
}
// newDeployment handles a request to deploy an ensemble.
//
// The request is rejected with an "ERROR" reply when the message expires in
// less than MinDeploymentTime, when it cannot be decoded, or when the
// orchestrator cannot be created. Otherwise the orchestrator is registered in
// n.deployments and an "OK" reply carrying its ID is sent before the actual
// deployment starts. If Deploy then fails, the orchestrator is stopped and
// unregistered (the failure is only logged, as the reply has already gone
// out); if it succeeds, the deployment is saved.
func (n *Node) newDeployment(msg actor.Envelope) {
	defer msg.Discard()

	// replyError sends the failure response shared by all rejection paths.
	replyError := func(reason string) {
		n.sendReply(msg, NewDeploymentResponse{
			Status: "ERROR",
			Error:  reason,
		})
	}

	if remaining := time.Until(msg.Expiry()); remaining < MinDeploymentTime {
		log.Debugf("deployment time too short")
		replyError("requested deployment time too short")
		return
	}

	var request NewDeploymentRequest
	if err := json.Unmarshal(msg.Message, &request); err != nil {
		log.Debugf("unmarshalling deployment request: %s", err)
		replyError(err.Error())
		return
	}

	orchestrator, err := n.createOrchestrator(request.Ensemble)
	if err != nil {
		log.Warnf("creating orchestrator: %s", err)
		replyError(err.Error())
		return
	}

	n.mx.Lock()
	n.deployments[orchestrator.ID()] = orchestrator
	n.mx.Unlock()

	log.Infof("deploying ensemble: %s", orchestrator.ID())
	n.sendReply(msg, NewDeploymentResponse{
		Status:     "OK",
		EnsembleID: orchestrator.ID(),
	})

	// The deployment deadline is derived from the message expiry.
	deadline := msg.Expiry().Add(-jobs.MinEnsembleDeploymentTime)
	if err := orchestrator.Deploy(deadline); err != nil {
		orchestrator.Stop()
		log.Errorf("error creating ensemble: %s", err)

		n.mx.Lock()
		delete(n.deployments, orchestrator.ID())
		n.mx.Unlock()
		return
	}

	// save the deployment
	n.mx.Lock()
	defer n.mx.Unlock()
	if err := n.saveDeployment(orchestrator.ID()); err != nil {
		log.Errorf("error saving deployment %s: %s", orchestrator.ID(), err)
	}
}
// deploymentVerifyEdgeConstraint decodes an edge constraint verification
// request. The verification itself is not implemented yet: a request that
// cannot be decoded is answered with a negative response carrying the decode
// error, while a well-formed request currently receives no reply.
func (n *Node) deploymentVerifyEdgeConstraint(msg actor.Envelope) {
	defer msg.Discard()

	var request jobs.VerifyEdgeConstraintRequest
	if err := json.Unmarshal(msg.Message, &request); err != nil {
		log.Warnf("error unmarshalling constraint request: %s", err)
		n.sendReply(msg, jobs.VerifyEdgeConstraintResponse{
			OK:    false,
			Error: err.Error(),
		})
		// Stop here so that the (future) verification logic below never runs
		// on a request that failed to decode.
		return
	}

	// TODO
}
// createOrchestrator builds an orchestrator for the given ensemble using the
// node's actor and network. On failure it returns a nil orchestrator together
// with the error from jobs.NewOrchestrator.
func (n *Node) createOrchestrator(ensemble jobs.EnsembleConfig) (*jobs.Orchestrator, error) {
	orch, err := jobs.NewOrchestrator(n.actor, n.network, ensemble)
	if err == nil {
		return orch, nil
	}
	return nil, err
}
// saveDeployments saves every registered deployment while holding n.mx.
// A failure for one deployment is logged and does not stop the others; the
// returned error lists the IDs of all deployments that could not be saved.
func (n *Node) saveDeployments() error {
	n.mx.Lock()
	defer n.mx.Unlock()

	var failed []string
	for oid := range n.deployments {
		err := n.saveDeployment(oid)
		if err == nil {
			continue
		}
		log.Errorf("error saving deployment %s: %s", oid, err)
		failed = append(failed, oid)
	}

	if len(failed) == 0 {
		return nil
	}
	return fmt.Errorf("failed to save deployments: %v", failed)
}
// saveDeployment is meant to persist the deployment with the given ID.
// It is not implemented yet: the argument is ignored and nil is returned.
// Both callers in this file invoke it with n.mx held.
func (n *Node) saveDeployment(_ string) error {
// TODO
return nil
}
// restoreDeployments is the counterpart of saveDeployments.
// It is not implemented yet and always returns nil.
func (n *Node) restoreDeployments() error {
// TODO
return nil
}
// handleBidRequest answers an ensemble bid request with a signed bid when this
// node can serve one of the requested nodes.
//
// Nothing is sent back when the message cannot be decoded, when the machine
// resources cannot be read, when this host is listed in the request's peer
// exclusion list, or when no candidate qualifies. Candidates are shuffled and
// the first one passing every check is selected: it must be a V1 request,
// must not ask for static public ports, every executor it lists must be
// available through getExecutor, and the machine resources must compare as
// types.Better against the requested ones.
//
// Dynamic ports are allocated under the request ID before the bid is signed
// and are released again if the provider lookup or the signing fails. On
// success the bid is sent as the reply and recorded with rememberBid.
func (n *Node) handleBidRequest(msg actor.Envelope) {
	defer msg.Discard()

	log.Debugf("got a bid request from: %s", &msg.From.Address)

	var request jobs.EnsembleBidRequest
	if err := json.Unmarshal(msg.Message, &request); err != nil {
		log.Debugf("unmarshalling bid request: %s", err)
		return
	}

	machineResources, err := n.hardware.GetMachineResources()
	if err != nil {
		log.Debugf("failed to get machine resources: %s", err)
		return
	}

	// The exclusion list applies to the request as a whole, so it is checked
	// once up front instead of once per candidate.
	hostID := n.actor.Handle().Address.HostID
	for _, p := range request.PeerExclusion {
		if p == hostID {
			log.Debug("bid request has exclusion")
			return
		}
	}

	// randomize the order of bid request checks
	rand.Shuffle(len(request.Request), func(i, j int) {
		request.Request[i], request.Request[j] = request.Request[j], request.Request[i]
	})

	// find the first bid request that matches
	var toAnswer jobs.BidRequest
	var found bool
loop:
	for _, v := range request.Request {
		// check if it is a V1 request
		if v.V1 == nil {
			log.Debug("bid request not v1")
			continue
		}

		// TODO allow static ports
		if len(v.V1.PublicPorts.Static) > 0 {
			log.Debug("bid request has static public ports")
			continue
		}

		// every requested executor must be available on this node
		for _, e := range v.V1.Executors {
			if _, err := n.getExecutor(e); err != nil {
				log.Debugf("failed to get executor: %v", e)
				continue loop
			}
		}

		comparisonResult, err := machineResources.Compare(v.V1.Resources)
		if err != nil {
			log.Debugf("failed to compare machine resources")
			continue
		}
		if comparisonResult != types.Better {
			log.Debugf("resource comparison - not better - result: %v", comparisonResult)
			continue
		}

		found = true
		toAnswer = v
		break
	}

	if !found {
		log.Debugf("bid requirements were not satisfied")
		return
	}

	// handle dynamic port allocs
	allocKey := request.ID
	ports, err := n.portAllocator.Allocate(allocKey, toAnswer.V1.PublicPorts.Dynamic)
	if err != nil {
		log.Debugf("failed to allocate ports")
		return
	}
	// cleanup gives the ports back when the bid cannot be completed.
	cleanup := func() {
		n.portAllocator.Release(allocKey)
	}

	log.Debugf("signing bid with did: %+v", n.actor.Security().DID())
	provider, err := n.rootCap.Trust().GetProvider(n.actor.Security().DID())
	if err != nil {
		log.Debugf("failed to get signing provider: %s", err)
		cleanup()
		return
	}
	log.Debugf("signing bid with provider: %+v", provider)

	bid := jobs.Bid{
		V1: &jobs.BidV1{
			EnsembleID: request.ID,
			NodeID:     toAnswer.V1.NodeID,
			Peer:       n.hostID,
			Location: jobs.Location{
				Region:  n.hostLocation.HostContinent,
				Country: n.hostLocation.HostCountry,
				City:    n.hostLocation.HostCity,
			},
			Handle: n.actor.Handle(),
		},
	}
	if err := bid.Sign(provider); err != nil {
		log.Debugf("failed to sign bid: %s", err)
		cleanup()
		return
	}

	n.sendReply(msg, bid)
	n.rememberBid(request.ID, toAnswer, ports)
}
// rememberBid records the bid made for ensemble eid together with the ports
// reserved for it. The entry expires bidStateTimeout from now. If a bid for
// the same ensemble is already recorded, the port allocator is asked to
// release the allocation keyed by eid before the entry is replaced.
func (n *Node) rememberBid(eid string, req jobs.BidRequest, ports []int) {
	n.mx.Lock()
	defer n.mx.Unlock()

	if _, alreadyBid := n.bids[eid]; alreadyBid {
		// we have an older bid
		n.portAllocator.Release(eid)
	}

	state := &bidState{
		expire:  time.Now().Add(bidStateTimeout),
		request: req,
		ports:   ports,
	}
	n.bids[eid] = state
}
// gcBidState runs doGCBidState every bidStateGCInterval and returns once the
// node's context is done. It blocks, so it is meant to run in its own
// goroutine.
func (n *Node) gcBidState() {
	tick := time.NewTicker(bidStateGCInterval)
	defer tick.Stop()

	for {
		select {
		case <-n.ctx.Done():
			return
		case <-tick.C:
			n.doGCBidState()
		}
	}
}
// doGCBidState removes every recorded bid whose expiry lies before the moment
// the sweep started and releases the port allocation keyed by its ensemble ID.
// The sweep runs with n.mx held.
func (n *Node) doGCBidState() {
	now := time.Now()

	n.mx.Lock()
	defer n.mx.Unlock()

	for eid, state := range n.bids {
		if !state.expire.Before(now) {
			continue
		}
		n.portAllocator.Release(eid)
		delete(n.bids, eid)
	}
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package node
import (
"context"
"encoding/json"
"time"
kbucket "github.com/libp2p/go-libp2p-kbucket"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/multiformats/go-multiaddr"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/dms/jobs"
"gitlab.com/nunet/device-management-service/network"
"gitlab.com/nunet/device-management-service/network/libp2p"
"gitlab.com/nunet/device-management-service/types"
)
// Behavior names exposed by the node, grouped by area, plus the ping timeout.
// NOTE(review): the mapping from these names to the handlers in this file is
// established elsewhere — confirm against the behavior registration.
const (
// Peer inspection and connectivity.
PeersListBehavior = "/dms/node/peers/list"
PeerAddrInfoBehavior = "/dms/node/peers/self"
PeerPingBehavior = "/dms/node/peers/ping"
PeerDHTBehavior = "/dms/node/peers/dht"
PeerConnectBehavior = "/dms/node/peers/connect"
PeerScoreBehavior = "/dms/node/peers/score"
// Onboarding.
OnboardBehavior = "/dms/node/onboarding/onboard"
OffboardBehavior = "/dms/node/onboarding/offboard"
OnboardStatusBehavior = "/dms/node/onboarding/status"
OnboardResourceBehavior = "/dms/node/onboarding/resource"
// Containers.
ContainerStartBehavior = "/dms/node/container/start"
ContainerStopBehavior = "/dms/node/container/stop"
ContainerListBehavior = "/dms/node/container/list"
// Virtual machines.
VMStartBehavior = "/dms/node/vm/start/custom"
VMStopBehavior = "/dms/node/vm/stop"
VMListBehavior = "/dms/node/vm/list"
// Deployments.
DeploymentListBehavior = "/dms/node/deployment/list"
DeploymentStatusBehavior = "/dms/node/deployment/status"
DeploymentManifestBehavior = "/dms/node/deployment/manifest"
DeploymentShutdownBehavior = "/dms/node/deployment/shutdown"
// pingTimeout is the timeout handlePeerPing passes to the network's Ping.
pingTimeout = 1 * time.Second
)
// PingRequest is the message decoded by handlePeerPing.
type PingRequest struct {
// Host identifies the target handed to the network's Ping call.
Host string
}
// PingResponse is the reply sent by handlePeerPing.
type PingResponse struct {
// Error is the failure message of the ping, empty on success.
Error string
// RTT is the measured round-trip time in milliseconds.
RTT int64
}
// handlePeerPing pings the host named in the request and replies with the
// round-trip time in milliseconds or the error that occurred. A request that
// cannot be decoded is logged and receives no reply.
func (n *Node) handlePeerPing(msg actor.Envelope) {
	defer msg.Discard()

	var request PingRequest
	if err := json.Unmarshal(msg.Message, &request); err != nil {
		log.Debugf("unmarshalling ping request: %s", err)
		return
	}

	resp := PingResponse{}
	res, err := n.network.Ping(context.Background(), request.Host, pingTimeout)
	if err != nil {
		resp.Error = err.Error()
		n.sendReply(msg, resp)
		return
	}

	// The ping call itself succeeded, but the result may still carry an error.
	if res.Error != nil {
		resp.Error = res.Error.Error()
	}
	resp.RTT = res.RTT.Milliseconds()
	n.sendReply(msg, resp)
}
// PeersListResponse is the reply sent by handlePeersList.
type PeersListResponse struct {
// Peers holds the peer IDs reported by the libp2p instance's PS.Peers().
Peers []peer.ID
}
// handlePeersList replies with the peer IDs returned by the libp2p instance's
// PS.Peers(). The list is always non-nil. When the node's network is not a
// libp2p network the request is logged and receives no reply.
func (n *Node) handlePeersList(msg actor.Envelope) {
	defer msg.Discard()

	// get the underlying libp2p instance to read its peers
	libp2pNet, ok := n.network.(*libp2p.Libp2p)
	if !ok {
		log.Debugf("peers list requested but network is not libp2p: %T", n.network)
		return
	}

	resp := PeersListResponse{
		Peers: make([]peer.ID, 0),
	}
	for _, v := range libp2pNet.PS.Peers() {
		resp.Peers = append(resp.Peers, v)
	}

	n.sendReply(msg, resp)
}
// PeerAddrInfoResponse is the reply sent by handlePeerAddrInfo.
type PeerAddrInfoResponse struct {
// ID is the ID field of the network's Stat() result.
ID string `json:"id"`
// Address is the ListenAddr field of the network's Stat() result.
Address string `json:"listen_addr"`
}
// handlePeerAddrInfo replies with the node's own ID and listen address as
// reported by the network's Stat().
func (n *Node) handlePeerAddrInfo(msg actor.Envelope) {
	defer msg.Discard()

	stats := n.network.Stat()
	n.sendReply(msg, PeerAddrInfoResponse{
		ID:      stats.ID,
		Address: stats.ListenAddr,
	})
}
// PeerDHTResponse is the reply sent by handlePeerDHT.
type PeerDHTResponse struct {
// Peers holds the entries of the DHT routing table.
Peers []kbucket.PeerInfo
}
// handlePeerDHT replies with the peers in the DHT routing table of the libp2p
// instance. When the node's network is not a libp2p network the request is
// logged and receives no reply.
func (n *Node) handlePeerDHT(msg actor.Envelope) {
	defer msg.Discard()

	// get the underlying libp2p instance and extract the DHT data
	libp2pNet, ok := n.network.(*libp2p.Libp2p)
	if !ok {
		log.Debugf("dht peers requested but network is not libp2p: %T", n.network)
		return
	}

	resp := PeerDHTResponse{
		Peers: libp2pNet.DHT.RoutingTable().GetPeerInfos(),
	}

	n.sendReply(msg, resp)
}
// PeerConnectRequest is the message decoded by handlePeerConnect.
type PeerConnectRequest struct {
// Address is the multiaddr of the peer to connect to; it must contain the
// peer ID, because it is parsed with peer.AddrInfoFromP2pAddr.
Address string
}
// PeerConnectResponse is the reply sent by handlePeerConnect.
type PeerConnectResponse struct {
// Status is "CONNECTED" on success and empty otherwise.
Status string
// Error describes why the connection attempt failed, empty on success.
Error string
}
// handlePeerConnect connects the libp2p host to the peer whose multiaddr is
// given in the request. The reply carries Status "CONNECTED" on success, or
// the error from parsing the address or from the connection attempt. A
// request that cannot be decoded, or one received while the network is not a
// libp2p network, is logged and receives no reply.
func (n *Node) handlePeerConnect(msg actor.Envelope) {
	defer msg.Discard()

	var request PeerConnectRequest
	if err := json.Unmarshal(msg.Message, &request); err != nil {
		log.Debugf("unmarshalling peer connect request: %s", err)
		return
	}

	libp2pNet, ok := n.network.(*libp2p.Libp2p)
	if !ok {
		log.Debugf("peer connect requested but network is not libp2p: %T", n.network)
		return
	}

	resp := PeerConnectResponse{}

	peerAddr, err := multiaddr.NewMultiaddr(request.Address)
	if err != nil {
		resp.Error = err.Error()
		n.sendReply(msg, resp)
		return
	}

	addrInfo, err := peer.AddrInfoFromP2pAddr(peerAddr)
	if err != nil {
		resp.Error = err.Error()
		n.sendReply(msg, resp)
		return
	}

	if err := libp2pNet.Host.Connect(context.Background(), *addrInfo); err != nil {
		resp.Error = err.Error()
		n.sendReply(msg, resp)
		return
	}

	resp.Status = "CONNECTED"
	n.sendReply(msg, resp)
}
// OnboardRequest is the message decoded by handleOnboard.
type OnboardRequest struct {
// Config is the onboarding configuration passed to the onboarder.
Config types.OnboardingConfig
}
// OnboardResponse is the reply sent by handleOnboard.
type OnboardResponse struct {
// Error describes why onboarding failed, empty on success.
Error string
// Config echoes the requested configuration; set on success only.
Config types.OnboardingConfig
}
// handleOnboard onboards the node with the configuration from the request.
// The reply echoes the configuration on success; if decoding or onboarding
// fails it carries the error message instead.
func (n *Node) handleOnboard(msg actor.Envelope) {
	defer msg.Discard()

	var (
		request OnboardRequest
		resp    OnboardResponse
	)

	if err := json.Unmarshal(msg.Message, &request); err != nil {
		resp.Error = err.Error()
	} else if err := n.onboarder.Onboard(context.Background(), request.Config); err != nil {
		resp.Error = err.Error()
	} else {
		resp.Config = request.Config
	}

	n.sendReply(msg, resp)
}
// OffboardRequest is the message decoded by handleOffboard.
type OffboardRequest struct {
// Force is passed through to the onboarder's Offboard call.
Force bool
}
// OffboardResponse is the reply sent by handleOffboard.
type OffboardResponse struct {
// Success reports whether offboarding completed without error.
Success bool
}
// handleOffboard offboards the node and replies whether that succeeded. The
// response has no error field, so the cause of a failure is written to the
// log. A request that cannot be decoded is logged and receives no reply.
func (n *Node) handleOffboard(msg actor.Envelope) {
	defer msg.Discard()

	var request OffboardRequest
	if err := json.Unmarshal(msg.Message, &request); err != nil {
		log.Debugf("unmarshalling offboard request: %s", err)
		return
	}

	resp := OffboardResponse{}
	if err := n.onboarder.Offboard(context.Background(), request.Force); err != nil {
		log.Warnf("offboarding failed: %s", err)
		resp.Success = false
		n.sendReply(msg, resp)
		return
	}

	resp.Success = true
	n.sendReply(msg, resp)
}
// OnboardStatusResponse is the reply sent by handleOnboardStatus.
type OnboardStatusResponse struct {
// Onboarded is the result of the onboarder's IsOnboarded check.
Onboarded bool
// Error is set when the check itself failed.
Error string
}
// handleOnboardStatus replies with the onboarding state of the node, or with
// the error returned by the onboarder when the state cannot be determined.
func (n *Node) handleOnboardStatus(msg actor.Envelope) {
	defer msg.Discard()

	var resp OnboardStatusResponse
	onboarded, err := n.onboarder.IsOnboarded(context.Background())
	if err != nil {
		resp.Error = err.Error()
	} else {
		resp.Onboarded = onboarded
	}

	n.sendReply(msg, resp)
}
// OnboardResourceRequest carries an onboarding configuration.
// NOTE(review): no handler in this section uses it; presumably it belongs to
// OnboardResourceBehavior — confirm.
type OnboardResourceRequest struct {
Config types.OnboardingConfig
}
// OnboardResourceResponse carries either an error message or a resulting
// onboarding configuration.
// NOTE(review): no handler in this section uses it; presumably it is the
// reply to OnboardResourceRequest — confirm.
type OnboardResourceResponse struct {
Error string
Result types.OnboardingConfig
}
// DeploymentListResponse is the reply sent by handleDeploymentList.
type DeploymentListResponse struct {
// Deployment ID -> Deployment Status (as rendered by
// jobs.DeploymentStatusString).
Deployments map[string]string
}
// handleDeploymentList replies with the status string of every deployment
// known to the node, keyed by deployment ID. The map in the reply is always
// non-nil, so it is empty (not null) when there are no deployments.
func (n *Node) handleDeploymentList(msg actor.Envelope) {
	defer msg.Discard()

	// Copy the registry under the lock that guards n.deployments elsewhere in
	// this package, and query the orchestrators only after releasing it.
	n.mx.Lock()
	snapshot := make(map[string]*jobs.Orchestrator, len(n.deployments))
	for id, dep := range n.deployments {
		snapshot[id] = dep
	}
	n.mx.Unlock()

	// The response map must be allocated before use: writing to the nil map
	// of a zero-value response panics.
	resp := DeploymentListResponse{
		Deployments: make(map[string]string, len(snapshot)),
	}
	for id, dep := range snapshot {
		resp.Deployments[id] = jobs.DeploymentStatusString(dep.Status())
	}

	n.sendReply(msg, resp)
}
// DeploymentStatusRequest is the message decoded by handleDeploymentStatus.
type DeploymentStatusRequest struct {
// ID is the key of the deployment in the node's deployment registry.
ID string
}
// DeploymentStatusResponse is the reply sent by handleDeploymentStatus.
type DeploymentStatusResponse struct {
// Status is the current status of the deployment; only meaningful when
// Error is empty.
Status jobs.DeploymentStatus
// Error is set when the request could not be decoded or the deployment is
// unknown.
Error string
}
// handleDeploymentStatus replies with the status of the deployment named in
// the request. The reply carries an error when the request cannot be decoded
// or when no deployment with that ID is registered.
func (n *Node) handleDeploymentStatus(msg actor.Envelope) {
	defer msg.Discard()

	var request DeploymentStatusRequest
	var resp DeploymentStatusResponse

	if err := json.Unmarshal(msg.Message, &request); err != nil {
		resp.Error = err.Error()
		n.sendReply(msg, resp)
		return
	}

	// n.deployments is written under n.mx elsewhere in this package, so the
	// lookup takes the same lock.
	n.mx.Lock()
	d, ok := n.deployments[request.ID]
	n.mx.Unlock()
	if !ok {
		// TODO: check database for persisted deployments data
		resp.Error = ErrDeploymentNotFound.Error()
		n.sendReply(msg, resp)
		return
	}

	resp.Status = d.Status()
	n.sendReply(msg, resp)
}
// DeploymentManifestRequest is the message decoded by handleDeploymentManifest.
type DeploymentManifestRequest struct {
// ID is the key of the deployment in the node's deployment registry.
ID string
}
// DeploymentManifestResponse is the reply sent by handleDeploymentManifest.
type DeploymentManifestResponse struct {
// Manifest is the manifest of the deployment; only meaningful when Error is
// empty.
Manifest jobs.EnsembleManifest
// Error is set when the request could not be decoded or the deployment is
// unknown.
Error string
}
// handleDeploymentManifest replies with the manifest of the deployment named
// in the request. The reply carries an error when the request cannot be
// decoded or when no deployment with that ID is registered.
func (n *Node) handleDeploymentManifest(msg actor.Envelope) {
	defer msg.Discard()

	var request DeploymentManifestRequest
	var resp DeploymentManifestResponse

	if err := json.Unmarshal(msg.Message, &request); err != nil {
		resp.Error = err.Error()
		n.sendReply(msg, resp)
		return
	}

	// n.deployments is written under n.mx elsewhere in this package, so the
	// lookup takes the same lock.
	n.mx.Lock()
	d, ok := n.deployments[request.ID]
	n.mx.Unlock()
	if !ok {
		// TODO: check database for persisted deployments data
		resp.Error = ErrDeploymentNotFound.Error()
		n.sendReply(msg, resp)
		return
	}

	resp.Manifest = d.Manifest()
	n.sendReply(msg, resp)
}
// DeploymentShutdownRequest is the message decoded by handleDeploymentShutdown.
type DeploymentShutdownRequest struct {
// ID is the key of the deployment in the node's deployment registry.
ID string
}
// DeploymentShutdownResponse is the reply sent by handleDeploymentShutdown.
type DeploymentShutdownResponse struct {
// Error is empty when the shutdown succeeded.
Error string
}
// handleDeploymentShutdown shuts down the deployment named in the request.
// The reply carries an error when the request cannot be decoded, when the
// deployment is unknown, when it is not in the running state, or when the
// shutdown itself fails; an empty error means success.
func (n *Node) handleDeploymentShutdown(msg actor.Envelope) {
	defer msg.Discard()

	var request DeploymentShutdownRequest
	var resp DeploymentShutdownResponse

	if err := json.Unmarshal(msg.Message, &request); err != nil {
		resp.Error = err.Error()
		n.sendReply(msg, resp)
		return
	}

	// n.deployments is written under n.mx elsewhere in this package, so the
	// lookup takes the same lock. The lock is released before the
	// orchestrator is called.
	n.mx.Lock()
	d, ok := n.deployments[request.ID]
	n.mx.Unlock()
	if !ok {
		resp.Error = ErrDeploymentNotFound.Error()
		n.sendReply(msg, resp)
		return
	}

	if d.Status() != jobs.DeploymentStatusRunning {
		// maybe-TODO: if it's still provisioning/committing,
		// we should stop the deployment process anyway
		resp.Error = ErrDeploymentNotRunning.Error()
		n.sendReply(msg, resp)
		return
	}

	if err := d.Shutdown(); err != nil {
		resp.Error = err.Error()
	}

	n.sendReply(msg, resp)
}
// CustomVMStartRequest is the message decoded by handleVMContainerStart.
type CustomVMStartRequest struct {
// Execution is handed to the selected executor's Start; its EngineSpec
// determines which executor is used.
Execution types.ExecutionRequest
}
// CustomVMStartResponse is the reply sent by handleVMContainerStart.
type CustomVMStartResponse struct {
// Error is empty when the execution was started.
Error string
}
// handleVMContainerStart starts the execution described in the request on the
// executor matching its engine type (Firecracker or Docker). For any other
// engine type the executor lookup is made with the zero-value executor type.
// The reply carries the error from the executor lookup or from starting the
// execution, and is empty on success. A request that cannot be decoded is
// logged and receives no reply.
func (n *Node) handleVMContainerStart(msg actor.Envelope) {
	defer msg.Discard()

	var request CustomVMStartRequest
	if err := json.Unmarshal(msg.Message, &request); err != nil {
		log.Debugf("unmarshalling start request: %s", err)
		return
	}

	var executionType jobs.AllocationExecutor
	if request.Execution.EngineSpec.IsType(types.ExecutorTypeFirecracker.String()) {
		executionType = jobs.ExecutorFirecracker
	} else if request.Execution.EngineSpec.IsType(types.ExecutorTypeDocker.String()) {
		executionType = jobs.ExecutorDocker
	}

	resp := CustomVMStartResponse{}
	e, err := n.getExecutor(executionType)
	if err != nil {
		resp.Error = err.Error()
		n.sendReply(msg, resp)
		return
	}

	if err := e.executor.Start(context.Background(), &request.Execution); err != nil {
		resp.Error = err.Error()
		n.sendReply(msg, resp)
		return
	}

	n.sendReply(msg, resp)
}
// VMStopRequest is the message decoded by handleVMContainerStop.
type VMStopRequest struct {
// ExecutionID identifies the execution to cancel.
ExecutionID string
// ExecutionType selects the executor that owns the execution.
ExecutionType jobs.AllocationExecutor
}
// VMStopResponse is the reply sent by handleVMContainerStop.
type VMStopResponse struct {
// Error is empty when the execution was cancelled.
Error string
}
// handleVMContainerStop cancels the execution named in the request on the
// executor of the requested type. The reply carries the error from the
// executor lookup or from the cancellation, and is empty on success. A
// request that cannot be decoded is logged and receives no reply.
func (n *Node) handleVMContainerStop(msg actor.Envelope) {
	defer msg.Discard()

	var request VMStopRequest
	if err := json.Unmarshal(msg.Message, &request); err != nil {
		log.Debugf("unmarshalling stop request: %s", err)
		return
	}

	resp := VMStopResponse{}
	e, err := n.getExecutor(request.ExecutionType)
	if err != nil {
		resp.Error = err.Error()
		n.sendReply(msg, resp)
		return
	}

	if err := e.executor.Cancel(context.Background(), request.ExecutionID); err != nil {
		resp.Error = err.Error()
		n.sendReply(msg, resp)
		return
	}

	n.sendReply(msg, resp)
}
// ListVMResponse is the reply sent by handleVMContainerList.
type ListVMResponse struct {
// Error is set when the executor lookup failed.
Error string
// VMS holds the items returned by the executor's List().
VMS []types.ExecutionListItem
// ExecutionType is the executor type used for the lookup; the handler never
// assigns it, so it stays at its zero value.
ExecutionType jobs.AllocationExecutor
}
// handleVMContainerList replies with the executions listed by an executor, or
// with the error from the executor lookup.
//
// NOTE(review): the message body is never decoded and resp.ExecutionType is
// never assigned, so the lookup always uses the zero-value executor type —
// confirm whether the type was meant to come from the request.
func (n *Node) handleVMContainerList(msg actor.Envelope) {
	defer msg.Discard()

	var resp ListVMResponse
	e, err := n.getExecutor(resp.ExecutionType)
	if err == nil {
		resp.VMS = e.executor.List()
	} else {
		resp.Error = err.Error()
	}

	n.sendReply(msg, resp)
}
// PeerScoreResponse is the reply sent by handlePeerScore.
type PeerScoreResponse struct {
// Score maps the string form of a peer's key in the broadcast score
// snapshot to its score snapshot.
Score map[string]*network.PeerScoreSnapshot
}
// handlePeerScore replies with the network's broadcast score snapshot,
// re-keyed by the string form of each peer key.
func (n *Node) handlePeerScore(msg actor.Envelope) {
	defer msg.Discard()

	scores := make(map[string]*network.PeerScoreSnapshot)
	for id, snap := range n.network.GetBroadcastScore() {
		scores[id.String()] = snap
	}

	n.sendReply(msg, PeerScoreResponse{Score: scores})
}
// handleSubnetCreate creates the subnet described in the request and replies
// with the outcome; the reply's Error is set when creation fails. A request
// that cannot be decoded is dropped without a reply.
func (n *Node) handleSubnetCreate(msg actor.Envelope) {
	defer msg.Discard()

	var request jobs.SubnetCreateRequest
	if json.Unmarshal(msg.Message, &request) != nil {
		return
	}

	var resp jobs.SubnetCreateResponse
	if err := n.network.CreateSubnet(context.Background(), request.SubnetID, request.RoutingTable); err != nil {
		resp.Error = err.Error()
	}
	n.sendReply(msg, resp)
}
// handleSubnetAddPeer adds the given peer and IP to a subnet and replies with
// the outcome; the reply's Error is set when the network call fails. A
// request that cannot be decoded is dropped without a reply.
func (n *Node) handleSubnetAddPeer(msg actor.Envelope) {
	defer msg.Discard()

	var request jobs.SubnetAddPeerRequest
	if json.Unmarshal(msg.Message, &request) != nil {
		return
	}

	var resp jobs.SubnetAddPeerResponse
	if err := n.network.AddSubnetPeer(request.SubnetID, request.PeerID, request.IP); err != nil {
		resp.Error = err.Error()
	}
	n.sendReply(msg, resp)
}
// handleSubnetAcceptPeer accepts the given peer and IP into a subnet and
// replies with the outcome; the reply's Error is set when the network call
// fails. A request that cannot be decoded is dropped without a reply.
func (n *Node) handleSubnetAcceptPeer(msg actor.Envelope) {
	defer msg.Discard()

	var request jobs.SubnetAcceptPeerRequest
	if json.Unmarshal(msg.Message, &request) != nil {
		return
	}

	var resp jobs.SubnetAcceptPeerResponse
	if err := n.network.AcceptSubnetPeer(request.SubnetID, request.PeerID, request.IP); err != nil {
		resp.Error = err.Error()
	}
	n.sendReply(msg, resp)
}
// handleSubnetMapPort maps a port inside a subnet as described in the request
// and replies with the outcome; the reply's Error is set when the mapping
// fails. A request that cannot be decoded is dropped without a reply.
func (n *Node) handleSubnetMapPort(msg actor.Envelope) {
	defer msg.Discard()

	var request jobs.SubnetMapPortRequest
	if json.Unmarshal(msg.Message, &request) != nil {
		return
	}

	var resp jobs.SubnetMapPortResponse
	err := n.network.MapPort(
		request.SubnetID, request.Protocol,
		request.SourceIP, request.SourcePort,
		request.DestIP, request.DestPort,
	)
	if err != nil {
		resp.Error = err.Error()
	}
	n.sendReply(msg, resp)
}
// handleSubnetDNSAddRecord adds a DNS record (domain name to IP) to a subnet
// and replies with the outcome; the reply's Error is set when the network
// call fails. A request that cannot be decoded is dropped without a reply.
func (n *Node) handleSubnetDNSAddRecord(msg actor.Envelope) {
	defer msg.Discard()

	var request jobs.SubnetDNSAddRecordRequest
	if json.Unmarshal(msg.Message, &request) != nil {
		return
	}

	var resp jobs.SubnetDNSAddRecordResponse
	if err := n.network.AddSubnetDNSRecord(request.SubnetID, request.DomainName, request.IP); err != nil {
		resp.Error = err.Error()
	}
	n.sendReply(msg, resp)
}
// handleAllocationDeployment creates the allocations requested for a node of
// an ensemble. On success the reply has OK set and carries the value returned
// by createAllocations; on failure it carries the error. A request that
// cannot be decoded is dropped without a reply.
func (n *Node) handleAllocationDeployment(msg actor.Envelope) {
	defer msg.Discard()

	var request jobs.AllocationDeploymentRequest
	if json.Unmarshal(msg.Message, &request) != nil {
		return
	}

	var resp jobs.AllocationDeploymentResponse
	allocations, err := n.createAllocations(request.EnsembleID, request.NodeID, request.Allocations)
	if err != nil {
		resp.Error = err.Error()
	} else {
		resp.OK = true
		resp.Allocations = allocations
	}
	n.sendReply(msg, resp)
}
// handleSubnetUnmapPort removes the port mapping described in the request and
// replies with the outcome; the reply's Error is set when the network call
// fails. A request that cannot be decoded is dropped without a reply.
func (n *Node) handleSubnetUnmapPort(msg actor.Envelope) {
	defer msg.Discard()

	var request jobs.SubnetUnmapPortRequest
	if json.Unmarshal(msg.Message, &request) != nil {
		return
	}

	var resp jobs.SubnetUnmapPortResponse
	err := n.network.UnmapPort(
		request.SubnetID, request.Protocol,
		request.SourceIP, request.SourcePort,
		request.DestIP, request.DestPort,
	)
	if err != nil {
		resp.Error = err.Error()
	}
	n.sendReply(msg, resp)
}
// handleSubnetDNSRemoveRecord removes the DNS record for a domain name from a
// subnet and replies with the outcome; the reply's Error is set when the
// network call fails. A request that cannot be decoded is dropped without a
// reply.
func (n *Node) handleSubnetDNSRemoveRecord(msg actor.Envelope) {
	defer msg.Discard()

	var request jobs.SubnetDNSRemoveRecordRequest
	if json.Unmarshal(msg.Message, &request) != nil {
		return
	}

	var resp jobs.SubnetDNSRemoveRecordResponse
	if err := n.network.RemoveSubnetDNSRecord(request.SubnetID, request.DomainName); err != nil {
		resp.Error = err.Error()
	}
	n.sendReply(msg, resp)
}
// handleSubnetDestroy destroys the subnet named in the request and replies
// with the outcome; the reply's Error is set when the network call fails. A
// request that cannot be decoded is dropped without a reply.
func (n *Node) handleSubnetDestroy(msg actor.Envelope) {
	defer msg.Discard()

	var request jobs.SubnetDestroyRequest
	if json.Unmarshal(msg.Message, &request) != nil {
		return
	}

	var resp jobs.SubnetDestroyResponse
	if err := n.network.DestroySubnet(request.SubnetID); err != nil {
		resp.Error = err.Error()
	}
	n.sendReply(msg, resp)
}
// handleSubnetRemovePeer removes the given peer from a subnet and replies
// with the outcome; the reply's Error is set when the network call fails. A
// request that cannot be decoded is dropped without a reply.
func (n *Node) handleSubnetRemovePeer(msg actor.Envelope) {
	defer msg.Discard()

	var request jobs.SubnetRemovePeerRequest
	if json.Unmarshal(msg.Message, &request) != nil {
		return
	}

	var resp jobs.SubnetRemovePeerResponse
	if err := n.network.RemoveSubnetPeer(request.SubnetID, request.PeerID); err != nil {
		resp.Error = err.Error()
	}
	n.sendReply(msg, resp)
}
// handleCommitDeployment commits the deployment of the ensemble named in the
// request. The reply has OK set on success and carries the error otherwise. A
// request that cannot be decoded is dropped without a reply.
func (n *Node) handleCommitDeployment(msg actor.Envelope) {
	defer msg.Discard()

	var request jobs.CommitDeploymentRequest
	if json.Unmarshal(msg.Message, &request) != nil {
		return
	}

	var resp jobs.CommitDeploymentResponse
	if err := n.commitDeployment(request.EnsembleID); err != nil {
		resp.Error = err.Error()
	} else {
		resp.OK = true
	}
	n.sendReply(msg, resp)
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package node
import (
"context"
"errors"
"fmt"
"math/rand"
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
pubsub "github.com/libp2p/go-libp2p-pubsub"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/protocol"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/dms/jobs"
"gitlab.com/nunet/device-management-service/dms/onboarding"
"gitlab.com/nunet/device-management-service/executor"
bt "gitlab.com/nunet/device-management-service/internal/background_tasks"
"gitlab.com/nunet/device-management-service/lib/crypto"
"gitlab.com/nunet/device-management-service/lib/did"
"gitlab.com/nunet/device-management-service/lib/ucan"
"gitlab.com/nunet/device-management-service/network"
"gitlab.com/nunet/device-management-service/types"
)
const (
// helloMinDelay and helloMaxDelay bound the random delay sayHello waits
// before greeting a peer.
helloMinDelay = 10 * time.Second
helloMaxDelay = 20 * time.Second
// helloTimeout is the message timeout applied to an outgoing hello.
helloTimeout = 3 * time.Second
// helloAttempts is the maximum number of hello attempts per peer.
helloAttempts = 3
// clearCommitedResourcesFrequency is the period of the background sweep
// that runs clearCommitedResources.
clearCommitedResourcesFrequency = 60 * time.Second
// rootProto is the protocol ID a peer must advertise to be treated as
// having a root actor (see includesRootProtocol).
rootProto = "actor/root/messages/0.0.1"
)
// Node is the structure that holds the node's dependencies.
type Node struct {
// rootCap is the root capability context; it is also handed to the
// security context of every allocation actor the node creates.
rootCap ucan.CapabilityContext
// actor is the node's root actor, created in New.
actor actor.Actor
scheduler *bt.Scheduler
network network.Network
resourceManager types.ResourceManager
// hardware is not validated by New and is used by publicStatusBehavior.
hardware types.HardwareManager
hostID string
onboarder *onboarding.Onboarding
// executors maps an executor type name to its metadata; it is read under
// rumutex.
executors map[string]executorMetadata
rumutex sync.RWMutex
// ctx and cancel are created in New; cancel is invoked by Stop.
ctx context.Context
cancel func()
// mx is held by the visible code when touching peers, bids and
// commitedResources.
mx sync.Mutex
// allocmx is held by GetAllocation and updateAllocations when touching
// allocations.
allocmx sync.Mutex
peers map[peer.ID]*peerState
// bids is keyed by ensemble ID.
bids map[string]*bidState
deployments map[string]*jobs.Orchestrator
allocations map[string]*jobs.Allocation
// running is a 0/1 flag flipped atomically by Start and Stop.
running int32
geoip types.GeoIPLocator
hostLocation HostGeolocation
portConfig PortConfig
portAllocator *PortAllocator
// commitedResources holds, per ensemble ID, the bid whose resources have
// been committed with the resource manager.
commitedResources map[string]*bidState
}
// peerState tracks what the node knows about a connected peer.
type peerState struct {
// conns is the connection count; the peer is forgotten when it drops to
// zero or below.
conns int
// hasRoot is set once the peer advertises rootProto.
hasRoot bool
// helloIn is set when the peer has said hello to us, helloOut when the peer
// answered our hello, helloPending while a sayHello attempt is scheduled.
helloIn, helloOut, helloPending bool
// helloAttempts counts hello attempts and is compared against the
// helloAttempts constant before retrying.
helloAttempts int
}
// bidState is the per-ensemble record of a bid request the node has seen.
type bidState struct {
// expire is the instant after which the bid can no longer be committed.
expire time.Time
// request is the original bid request; its V1.Resources are what gets
// committed by commitDeployment.
request jobs.BidRequest
// ports is not used in the visible code — presumably the ports reserved
// for the bid; verify against the bid handler.
ports []int
}
// executorMetadata pairs an executor instance with the execution type it was
// registered under.
type executorMetadata struct {
executor executor.Executor
executionType jobs.AllocationExecutor
}
// HostGeolocation describes where the host is located. The values are passed
// into New by the caller and stored on the node as-is.
type HostGeolocation struct {
HostContinent string
HostCountry string
HostCity string
}
// PortConfig defines the port range the node may hand out. Both bounds are
// inclusive: the port allocator scans from AvailableRangeFrom up to and
// including AvailableRangeTo.
type PortConfig struct {
AvailableRangeFrom int
AvailableRangeTo int
}
// New creates a new node and attaches a root actor to it.
//
// It validates the mandatory dependencies (onboarder, rootCap, hostID, net,
// resourceManager, scheduler, geoip), derives the root security context from
// rootCap, creates the root actor with inbox "root", initialises the
// supported executors, registers all DMS behaviors on the actor and attempts
// to restore previously saved deployments (a restore failure is only logged).
//
// A background goroutine periodically releases expired committed resources;
// it terminates when the node context is cancelled (see Stop). hardware is
// not validated here.
//
// On any failure after the node context has been created, the context is
// cancelled before the error is returned.
func New(onboarder *onboarding.Onboarding,
	rootCap ucan.CapabilityContext,
	hostID string, net network.Network,
	resourceManager types.ResourceManager,
	scheduler *bt.Scheduler,
	hardware types.HardwareManager,
	geoip types.GeoIPLocator, hostLocation HostGeolocation, portConfig PortConfig,
) (*Node, error) {
	if onboarder == nil {
		return nil, errors.New("onboarder is nil")
	}
	if rootCap == nil {
		return nil, errors.New("root capability context is nil")
	}
	if hostID == "" {
		return nil, errors.New("host id is nil")
	}
	if net == nil {
		return nil, errors.New("network is nil")
	}
	if resourceManager == nil {
		return nil, errors.New("resource manager is nil")
	}
	if scheduler == nil {
		return nil, errors.New("scheduler is nil")
	}
	if geoip == nil {
		return nil, errors.New("geoip is nil")
	}

	// derive the root actor's key material from the root capability context
	rootDID := rootCap.DID()
	rootTrust := rootCap.Trust()
	anchor, err := rootTrust.GetAnchor(rootDID)
	if err != nil {
		return nil, fmt.Errorf("failed to get root DID anchor: %w", err)
	}
	pubk := anchor.PublicKey()
	provider, err := rootTrust.GetProvider(rootDID)
	if err != nil {
		return nil, fmt.Errorf("failed to get root DID provider: %w", err)
	}
	privk, err := provider.PrivateKey()
	if err != nil {
		return nil, fmt.Errorf("failed to get root private key: %w", err)
	}
	rootSec, err := actor.NewBasicSecurityContext(pubk, privk, rootCap)
	if err != nil {
		return nil, fmt.Errorf("failed to create security context: %w", err)
	}
	nodeActor, err := createActor(rootSec, actor.NewRateLimiter(actor.DefaultRateLimiterConfig()), hostID, "root", net, scheduler)
	if err != nil {
		return nil, fmt.Errorf("failed to create node actor: %w", err)
	}

	ctx, cancel := context.WithCancel(context.Background())
	n := &Node{
		hostID:            hostID,
		network:           net,
		bids:              make(map[string]*bidState),
		deployments:       make(map[string]*jobs.Orchestrator),
		allocations:       make(map[string]*jobs.Allocation),
		peers:             make(map[peer.ID]*peerState),
		resourceManager:   resourceManager,
		hardware:          hardware,
		actor:             nodeActor,
		rootCap:           rootCap,
		scheduler:         scheduler,
		onboarder:         onboarder,
		executors:         make(map[string]executorMetadata),
		ctx:               ctx,
		cancel:            cancel,
		geoip:             geoip,
		hostLocation:      hostLocation,
		portConfig:        portConfig,
		portAllocator:     NewPortAllocator(portConfig),
		commitedResources: make(map[string]*bidState),
	}
	if err := n.initSupportedExecutors(ctx); err != nil {
		cancel()
		return nil, fmt.Errorf("new executor: %w", err)
	}

	dmsBehaviors := map[string]struct {
		fn   func(actor.Envelope)
		opts []actor.BehaviorOption
	}{
		PublicHelloBehavior: {
			fn: n.publicHelloBehavior,
		},
		PublicStatusBehavior: {
			fn: n.publicStatusBehavior,
		},
		BroadcastHelloBehavior: {
			fn: n.broadcastHelloBehavior,
			opts: []actor.BehaviorOption{
				actor.WithBehaviorTopic(BroadcastHelloTopic),
			},
		},
		PeersListBehavior: {
			fn: n.handlePeersList,
		},
		PeerAddrInfoBehavior: {
			fn: n.handlePeerAddrInfo,
		},
		PeerPingBehavior: {
			fn: n.handlePeerPing,
		},
		PeerDHTBehavior: {
			fn: n.handlePeerDHT,
		},
		PeerConnectBehavior: {
			fn: n.handlePeerConnect,
		},
		PeerScoreBehavior: {
			fn: n.handlePeerScore,
		},
		OnboardBehavior: {
			fn: n.handleOnboard,
		},
		OffboardBehavior: {
			fn: n.handleOffboard,
		},
		OnboardStatusBehavior: {
			fn: n.handleOnboardStatus,
		},
		VMStartBehavior: {
			fn: n.handleVMContainerStart,
		},
		VMStopBehavior: {
			fn: n.handleVMContainerStop,
		},
		VMListBehavior: {
			fn: n.handleVMContainerList,
		},
		ContainerStartBehavior: {
			fn: n.handleVMContainerStart,
		},
		ContainerStopBehavior: {
			fn: n.handleVMContainerStop,
		},
		ContainerListBehavior: {
			fn: n.handleVMContainerList,
		},
		NewDeploymentBehavior: {
			fn: n.newDeployment,
		},
		DeploymentListBehavior: {
			fn: n.handleDeploymentList,
		},
		DeploymentStatusBehavior: {
			fn: n.handleDeploymentStatus,
		},
		DeploymentManifestBehavior: {
			fn: n.handleDeploymentManifest,
		},
		DeploymentShutdownBehavior: {
			fn: n.handleDeploymentShutdown,
		},
		jobs.VerifyEdgeConstraintBehavior: {
			fn: n.deploymentVerifyEdgeConstraint,
		},
		jobs.BidRequestBehavior: {
			fn: n.handleBidRequest,
			opts: []actor.BehaviorOption{
				actor.WithBehaviorTopic(jobs.BidRequestTopic),
			},
		},
		jobs.SubnetCreateBehavior: {
			fn: n.handleSubnetCreate,
		},
		jobs.SubnetDestroyBehavior: {
			fn: n.handleSubnetDestroy,
		},
		jobs.SubnetAddPeerBehavior: {
			fn: n.handleSubnetAddPeer,
		},
		jobs.SubnetRemovePeerBehavior: {
			fn: n.handleSubnetRemovePeer,
		},
		jobs.SubnetAcceptPeerBehavior: {
			fn: n.handleSubnetAcceptPeer,
		},
		jobs.SubnetMapPortBehavior: {
			fn: n.handleSubnetMapPort,
		},
		jobs.SubnetDNSAddRecordBehavior: {
			fn: n.handleSubnetDNSAddRecord,
		},
		jobs.SubnetUnmapPortBehavior: {
			fn: n.handleSubnetUnmapPort,
		},
		jobs.SubnetDNSRemoveRecordBehavior: {
			fn: n.handleSubnetDNSRemoveRecord,
		},
		jobs.AllocationDeploymentBehavior: {
			fn: n.handleAllocationDeployment,
		},
		jobs.CommitDeploymentBehavior: {
			fn: n.handleCommitDeployment,
		},
	}
	for behavior, handler := range dmsBehaviors {
		if err := nodeActor.AddBehavior(behavior, handler.fn, handler.opts...); err != nil {
			// release the node context; nothing else will ever cancel it
			cancel()
			return nil, fmt.Errorf("adding %s behavior: %w", behavior, err)
		}
	}

	if err := n.restoreDeployments(); err != nil {
		log.Errorf("restoring deployments: %s", err)
	}

	// Periodically release committed resources whose bid expired. The loop is
	// bound to the node context so it (and its ticker) go away on Stop.
	go func() {
		ticker := time.NewTicker(clearCommitedResourcesFrequency)
		defer ticker.Stop()
		for {
			select {
			case <-ctx.Done():
				return
			case <-ticker.C:
				n.clearCommitedResources()
			}
		}
	}()

	return n, nil
}
// GetAllocation looks up an allocation by its id and returns an error when
// the node does not hold one under that id.
func (n *Node) GetAllocation(id string) (*jobs.Allocation, error) {
	n.allocmx.Lock()
	defer n.allocmx.Unlock()

	if found, exists := n.allocations[id]; exists {
		return found, nil
	}
	return nil, errors.New("allocation not found")
}
// Start starts the node: it starts the root actor, subscribes to the hello
// and bid-request broadcast topics and launches the bid-state collector.
// Calling Start on a node whose running flag is already set is a no-op.
func (n *Node) Start() error {
	alreadyRunning := !atomic.CompareAndSwapInt32(&n.running, 0, 1)
	if alreadyRunning {
		return nil
	}

	startErr := n.actor.Start()
	if startErr != nil {
		return fmt.Errorf("failed to start node actor: %w", startErr)
	}

	subErr := n.subscribe(BroadcastHelloTopic, jobs.BidRequestTopic)
	if subErr != nil {
		// best-effort rollback of the actor start; the subscribe error wins
		_ = n.actor.Stop()
		return subErr
	}

	go n.gcBidState()
	return nil
}
// ExecutorAvailable reports whether an executor of the given type has been
// registered on this node.
func (n *Node) ExecutorAvailable(execType jobs.AllocationExecutor) bool {
	key := string(execType)

	n.rumutex.RLock()
	_, registered := n.executors[key]
	n.rumutex.RUnlock()

	return registered
}
// subscribe subscribes the root actor to every given topic (configuring each
// through setupBroadcast), installs the node's broadcast application score
// and registers the peer lifecycle callbacks with the network.
func (n *Node) subscribe(topics ...string) error {
	for i := range topics {
		name := topics[i]
		err := n.actor.Subscribe(name, n.setupBroadcast)
		if err != nil {
			return fmt.Errorf("error subscribing to %s: %w", name, err)
		}
	}

	n.network.SetBroadcastAppScore(n.broadcastScore)

	// NOTE(review): peerIdentified is passed for both of the last two
	// callbacks — confirm against network.Notify's signature.
	err := n.network.Notify(
		n.actor.Context(),
		n.peerPreConnected,
		n.peerConnected,
		n.peerDisconnected,
		n.peerIdentified,
		n.peerIdentified,
	)
	if err != nil {
		return fmt.Errorf("error setting up peer notifications: %w", err)
	}
	return nil
}
// setupBroadcast configures a broadcast topic on the network by installing
// the node's pubsub score parameters on it.
func (n *Node) setupBroadcast(topic string) error {
	configure := func(t *network.Topic) error {
		params := pubsub.TopicScoreParams{
			SkipAtomicValidation:           true,
			TopicWeight:                    1.0,
			TimeInMeshWeight:               0.00027, // ~1/3600
			TimeInMeshQuantum:              time.Second,
			TimeInMeshCap:                  1.0,
			InvalidMessageDeliveriesWeight: -1000,
			InvalidMessageDeliveriesDecay:  pubsub.ScoreParameterDecay(time.Hour),
		}
		return t.SetScoreParams(&params)
	}
	return n.network.SetupBroadcastTopic(topic, configure)
}
// broadcastScore is the application score the node assigns to a peer:
// 5 once hellos have been exchanged in both directions, 1 for a peer that
// only advertises the root protocol, and 0 for everything else (including
// unknown peers).
func (n *Node) broadcastScore(p peer.ID) float64 {
	n.mx.Lock()
	defer n.mx.Unlock()

	state, known := n.peers[p]
	switch {
	case !known:
		return 0
	case state.helloIn && state.helloOut:
		return 5
	case state.hasRoot:
		return 1
	default:
		return 0
	}
}
// peerConnected records a new connection to p, creating the peer's state on
// first contact.
func (n *Node) peerConnected(p peer.ID) {
	log.Debugf("peer connected: %s", p)

	n.mx.Lock()
	defer n.mx.Unlock()

	if _, known := n.peers[p]; !known {
		n.peers[p] = &peerState{}
	}
	n.peers[p].conns++
}
// peerPreConnected registers a peer together with its connection count,
// replacing any state already held for it. When the peer advertises the root
// protocol, a first hello attempt is scheduled.
func (n *Node) peerPreConnected(p peer.ID, protos []protocol.ID, conns int) {
	log.Debugf("peer preconnected: %s %s (%d)", p, protos, conns)

	n.mx.Lock()
	defer n.mx.Unlock()

	state := &peerState{conns: conns}
	n.peers[p] = state
	if !includesRootProtocol(protos) {
		return
	}

	state.hasRoot, state.helloPending, state.helloAttempts = true, true, 1
	go n.sayHello(p)
}
// peerIdentified updates a peer's state with the protocols it advertises.
// For a peer with the root protocol that has neither answered a hello nor has
// one pending, a new hello attempt is scheduled.
func (n *Node) peerIdentified(p peer.ID, protos []protocol.ID) {
	log.Debugf("peer identified: %s %s", p, protos)

	n.mx.Lock()
	defer n.mx.Unlock()

	state, known := n.peers[p]
	if !known {
		state = &peerState{}
		n.peers[p] = state
	}

	if !includesRootProtocol(protos) {
		return
	}
	state.hasRoot = true

	if state.helloOut || state.helloPending {
		return
	}
	state.helloPending = true
	state.helloAttempts++
	go n.sayHello(p)
}
// peerDisconnected drops one connection from p's count and forgets the peer
// once no connections remain. Unknown peers are ignored.
func (n *Node) peerDisconnected(p peer.ID) {
	log.Debugf("peer disconnected: %s", p)

	n.mx.Lock()
	defer n.mx.Unlock()

	state, known := n.peers[p]
	if !known {
		return
	}
	if state.conns--; state.conns > 0 {
		return
	}
	delete(n.peers, p)
}
// sayHello greets the root actor of peer p.
//
// The peer's actor handle is derived from the public key embedded in the peer
// ID. After a random delay between helloMinDelay and helloMaxDelay, and only
// if the peer is still known and connected, a PublicHelloBehavior message is
// invoked. A reply marks the peer as helloOut; an invoke error or a timeout
// schedules another attempt until helloAttempts is reached, after which the
// pending flag is cleared.
//
// Both waits are bound to the node context, so the goroutine exits promptly
// once the node is stopped instead of sleeping on and invoking a stopped
// actor.
func (n *Node) sayHello(p peer.ID) {
	pubk, err := p.ExtractPublicKey()
	if err != nil {
		log.Debugf("failed to extract public key: %s", err)
		return
	}
	if !crypto.AllowedKey(int(pubk.Type())) {
		log.Debugf("unexpected key type: %d", pubk.Type())
		return
	}
	actorID, err := crypto.IDFromPublicKey(pubk)
	if err != nil {
		log.Debugf("failed to extract actor ID: %s", err)
		return
	}
	actorDID := did.FromPublicKey(pubk)
	handle := actor.Handle{
		ID:  actorID,
		DID: actorDID,
		Address: actor.Address{
			HostID:       p.String(),
			InboxAddress: "root",
		},
	}

	// retryOrGiveUp schedules another attempt while the budget allows it and
	// clears the pending flag otherwise. Peers that vanished are ignored.
	retryOrGiveUp := func() {
		n.mx.Lock()
		defer n.mx.Unlock()
		st, ok := n.peers[p]
		if !ok {
			return
		}
		if st.helloAttempts < helloAttempts {
			st.helloAttempts++
			go n.sayHello(p)
			return
		}
		st.helloPending = false
	}

	// random delay before greeting, abandoned when the node shuts down
	wait := helloMinDelay + time.Duration(rand.Int63n(int64(helloMaxDelay-helloMinDelay)))
	delay := time.NewTimer(wait)
	defer delay.Stop()
	select {
	case <-delay.C:
	case <-n.ctx.Done():
		return
	}

	n.mx.Lock()
	st, ok := n.peers[p]
	if !ok {
		n.mx.Unlock()
		return
	}
	if !n.network.PeerConnected(p) {
		st.helloPending = false
		n.mx.Unlock()
		return
	}
	n.mx.Unlock()

	msg, err := actor.Message(
		n.actor.Handle(),
		handle,
		PublicHelloBehavior,
		nil,
		actor.WithMessageTimeout(helloTimeout),
	)
	if err != nil {
		log.Debugf("failed to construct hello message: %s", err)
		return
	}

	log.Debugf("saying hello to %s", handle.Address.HostID)
	replyCh, err := n.actor.Invoke(msg)
	if err != nil {
		retryOrGiveUp()
		log.Debugf("error invoking hello: %s", err)
		return
	}

	select {
	case reply := <-replyCh:
		reply.Discard()
		n.mx.Lock()
		if st, ok = n.peers[p]; ok {
			st.helloOut = true
			st.helloPending = false
		} else if n.network.PeerConnected(p) {
			// race with connected notification
			st = &peerState{helloOut: true}
			n.peers[p] = st
		}
		n.mx.Unlock()
		log.Infof("got hello response from %s", handle.Address.HostID)
	case <-time.After(time.Until(msg.Expiry())):
		retryOrGiveUp()
		log.Debugf("hello timeout for %s", handle.Address.HostID)
	case <-n.ctx.Done():
		// node is shutting down; nothing left to update
	}
}
// Stop stops the node. It is a no-op unless the node is running.
//
// All allocations are stopped (failures are logged, not returned), active
// deployments are saved, the node context is cancelled, the broadcast
// application score is cleared and finally the root actor is stopped. Only a
// failure to stop the actor is returned.
func (n *Node) Stop() error {
	n.mx.Lock()
	defer n.mx.Unlock()
	if !atomic.CompareAndSwapInt32(&n.running, 1, 0) {
		return nil
	}

	// Snapshot the allocations under allocmx (the lock that guards the map in
	// GetAllocation/updateAllocations) and stop them outside of it, so that a
	// slow Stop does not block allocation lookups.
	n.allocmx.Lock()
	allocs := make(map[string]*jobs.Allocation, len(n.allocations))
	for id, alloc := range n.allocations {
		allocs[id] = alloc
	}
	n.allocmx.Unlock()

	// stop all allocations
	for id, alloc := range allocs {
		if err := alloc.Stop(n.ctx); err != nil {
			log.Warnf("error stopping allocation %s: %s", id, err)
		}
	}

	if err := n.saveDeployments(); err != nil {
		log.Errorf("error saving active deployments: %s", err)
	}

	n.cancel()

	// clear the broadcast app score
	n.network.SetBroadcastAppScore(nil)

	// stop the actor
	if err := n.actor.Stop(); err != nil {
		return fmt.Errorf("failed to stop node actor: %w", err)
	}
	return nil
}
// sendReply builds a reply to msg carrying payload and sends it through the
// root actor. For broadcast messages the reply source is set explicitly to
// the node's own handle. Failures are logged at debug level only.
func (n *Node) sendReply(msg actor.Envelope, payload interface{}) {
	var options []actor.MessageOption
	if msg.IsBroadcast() {
		options = []actor.MessageOption{actor.WithMessageSource(n.actor.Handle())}
	}

	reply, err := actor.ReplyTo(msg, payload, options...)
	if err != nil {
		log.Debugf("error creating reply: %s", err)
		return
	}

	err = n.actor.Send(reply)
	if err != nil {
		log.Debugf("error sending reply: %s", err)
	}
}
// getExecutor returns the metadata of the executor registered for execType,
// or an error when no such executor is available on this node.
func (n *Node) getExecutor(execType jobs.AllocationExecutor) (executorMetadata, error) {
	n.rumutex.RLock()
	meta, found := n.executors[string(execType)]
	n.rumutex.RUnlock()

	if found {
		return meta, nil
	}
	return executorMetadata{}, errors.New("executor not available")
}
// createAllocations creates and runs one allocation per entry of allocations
// for the given ensemble and returns the actor handle of each newly created
// allocation, keyed by allocation ID. Entries whose ID is already present in
// the node's allocation map are skipped and do not appear in the result.
//
// The first creation or run failure aborts the call; allocations created
// before the failure are not rolled back.
func (n *Node) createAllocations(ensembleID string, _ string, allocations map[string]jobs.AllocationDeploymentConfig) (map[string]actor.Handle, error) {
	allocHandles := make(map[string]actor.Handle, len(allocations))
	for allocationID, config := range allocations {
		// The map is guarded by allocmx everywhere else; take it for the
		// lookup only, since createAllocation locks it again when it
		// registers the new allocation.
		n.allocmx.Lock()
		_, exists := n.allocations[allocationID]
		n.allocmx.Unlock()
		if exists {
			continue
		}

		allocation, err := n.createAllocation(jobs.Job{
			ID:               ensembleID,
			Resources:        config.Resources,
			Execution:        config.Execution,
			ProvisionScripts: config.ProvisionScripts,
		})
		if err != nil {
			return nil, fmt.Errorf("failed to create allocation %s: %w", allocationID, err)
		}
		if err := allocation.Run(n.ctx); err != nil {
			return nil, fmt.Errorf("failed to run allocation %s: %w", allocationID, err)
		}
		allocHandles[allocationID] = allocation.Actor.Handle()
	}
	return allocHandles, nil
}
// createAllocation builds an allocation for job: it generates a fresh Ed25519
// key pair, wraps it in a security context rooted at the node's capability
// context, creates a dedicated actor with a UUID inbox, starts the allocation
// and registers it on the node.
func (n *Node) createAllocation(job jobs.Job) (*jobs.Allocation, error) {
	// every allocation gets its own random identity
	privKey, pubKey, err := crypto.GenerateKeyPair(crypto.Ed25519)
	if err != nil {
		return nil, fmt.Errorf("failed to generate random keypair for allocation job %s: %w", job.ID, err)
	}

	sctx, err := actor.NewBasicSecurityContext(pubKey, privKey, n.rootCap)
	if err != nil {
		return nil, fmt.Errorf("failed to create security context: %w", err)
	}

	inboxID, err := uuid.NewUUID()
	if err != nil {
		return nil, fmt.Errorf("failed to generate uuid for allocation inbox: %w", err)
	}

	allocActor, err := createActor(sctx, n.actor.Limiter(), n.hostID, inboxID.String(), n.network, n.scheduler)
	if err != nil {
		return nil, fmt.Errorf("failed to create allocation actor: %w", err)
	}

	details := jobs.AllocationDetails{Job: job, NodeID: n.hostID}
	alloc, err := jobs.NewAllocation(allocActor, details, n.resourceManager)
	if err != nil {
		return nil, fmt.Errorf("failed to create allocation: %w", err)
	}

	if err = alloc.Start(); err != nil {
		return nil, fmt.Errorf("failed to start the allocation: %w", err)
	}

	n.updateAllocations(alloc)
	return alloc, nil
}
// updateAllocations stores alloc in the node's allocation map under its ID,
// replacing any previous entry.
func (n *Node) updateAllocations(alloc *jobs.Allocation) {
	n.allocmx.Lock()
	defer n.allocmx.Unlock()

	n.allocations[alloc.ID] = alloc
}
// commitDeployment commits the resources of the bid recorded for ensembleID
// with the resource manager. It fails when no bid is known or the bid has
// expired, and is a no-op when the ensemble's resources are already
// committed.
func (n *Node) commitDeployment(ensembleID string) error {
	n.mx.Lock()
	defer n.mx.Unlock()

	bid, found := n.bids[ensembleID]
	if !found {
		return fmt.Errorf("no bid requests for ensemble id: %s", ensembleID)
	}
	if time.Now().After(bid.expire) {
		return fmt.Errorf("bid request for ensemble id: %s has expired", ensembleID)
	}
	if _, done := n.commitedResources[ensembleID]; done {
		return nil
	}

	commitment := types.ResourceAllocation{
		JobID:     ensembleID,
		Resources: bid.request.V1.Resources,
	}
	if err := n.resourceManager.CommitResources(context.TODO(), commitment); err != nil {
		return fmt.Errorf("failed to preallocate resources for ensemble id: %s: %w", ensembleID, err)
	}

	n.commitedResources[ensembleID] = bid
	return nil
}
// clearCommitedResources releases committed resources that were never turned
// into an allocation: for every commitment whose bid has expired and for
// which no allocation exists under the same ID, the resources are released
// with the resource manager and both the bid and the commitment are dropped.
// A release failure is logged; the entries are dropped regardless.
func (n *Node) clearCommitedResources() {
	n.mx.Lock()
	defer n.mx.Unlock()

	now := time.Now()
	for ensembleID, bid := range n.commitedResources {
		if !now.After(bid.expire) {
			continue
		}

		// allocations is guarded by allocmx, not mx
		n.allocmx.Lock()
		_, allocFound := n.allocations[ensembleID]
		n.allocmx.Unlock()
		if allocFound {
			continue
		}

		if err := n.resourceManager.ReleaseCommittedResources(context.Background(), ensembleID); err != nil {
			log.Errorf("failed to release committed resources for ensemble id: %s: %s", ensembleID, err)
		}
		delete(n.bids, ensembleID)
		delete(n.commitedResources, ensembleID)
	}
}
// createActor builds a basic actor whose handle combines the identity of the
// given security context with the hostID/inboxAddress address.
func createActor(sctx *actor.BasicSecurityContext, limiter actor.RateLimiter, hostID, inboxAddress string, net network.Network, scheduler *bt.Scheduler) (*actor.BasicActor, error) {
	address := actor.Address{
		HostID:       hostID,
		InboxAddress: inboxAddress,
	}
	handle := actor.Handle{
		ID:      sctx.ID(),
		DID:     sctx.DID(),
		Address: address,
	}

	created, err := actor.New(scheduler, net, sctx, limiter, actor.BasicActorParams{}, handle)
	if err != nil {
		return nil, fmt.Errorf("failed to create actor: %w", err)
	}
	return created, nil
}
// includesRootProtocol reports whether protos contains the root actor
// protocol.
func includesRootProtocol(protos []protocol.ID) bool {
	for i := 0; i < len(protos); i++ {
		if protos[i] == rootProto {
			return true
		}
	}
	return false
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
//go:build linux
// +build linux
package node
import (
"context"
"gitlab.com/nunet/device-management-service/dms/jobs"
"gitlab.com/nunet/device-management-service/executor/docker"
"gitlab.com/nunet/device-management-service/executor/firecracker"
)
// initSupportedExecutors registers the executors supported on Linux
// (firecracker and docker). An executor that fails to initialise is simply
// left unregistered; the function itself always returns nil.
func (n *Node) initSupportedExecutors(ctx context.Context) error {
	if fc, fcErr := firecracker.NewExecutor(ctx, "root"); fcErr == nil {
		n.executors[string(jobs.ExecutorFirecracker)] = executorMetadata{
			executor:      fc,
			executionType: jobs.ExecutorFirecracker,
		}
	}

	if dk, dkErr := docker.NewExecutor(ctx, "root"); dkErr == nil {
		n.executors[string(jobs.ExecutorDocker)] = executorMetadata{
			executor:      dk,
			executionType: jobs.ExecutorDocker,
		}
	}

	return nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package node
import (
"fmt"
"sync"
)
// PortAllocator keeps track of port allocations and manages state.
//
// Ports are handed out from the inclusive range described by its PortConfig.
// All exported methods are guarded by an internal mutex.
type PortAllocator struct {
	config PortConfig
	mx     sync.Mutex
	// allocs maps an allocation ID to the ports handed out for it.
	allocs map[string][]int
	// reserved is the set of ports currently handed out.
	reserved map[int]struct{}
}

// NewPortAllocator initializes a new PortAllocator with a PortConfig.
func NewPortAllocator(config PortConfig) *PortAllocator {
	return &PortAllocator{
		config:   config,
		allocs:   make(map[string][]int),
		reserved: make(map[int]struct{}),
	}
}

// Allocate reserves numPorts free ports and associates them with
// allocationID. The ports are returned in ascending order.
//
// The operation is all-or-nothing: when the range cannot satisfy the request,
// every port reserved during this call is released again and an error is
// returned.
//
// NOTE(review): calling Allocate twice with the same allocationID replaces the
// recorded ports, so the earlier ports stay reserved and are no longer
// reachable through Release — confirm whether callers can do that.
func (pa *PortAllocator) Allocate(allocationID string, numPorts int) ([]int, error) {
	pa.mx.Lock()
	defer pa.mx.Unlock()

	var allocated []int
	for i := 0; i < numPorts; i++ {
		port, err := pa.allocate()
		if err != nil {
			// roll back the partial reservation so the ports are not leaked
			for _, p := range allocated {
				delete(pa.reserved, p)
			}
			return nil, err
		}
		allocated = append(allocated, port)
	}

	pa.allocs[allocationID] = allocated
	return allocated, nil
}

// allocate reserves and returns the lowest free port of the configured range.
// The caller must hold pa.mx.
func (pa *PortAllocator) allocate() (int, error) {
	for i := pa.config.AvailableRangeFrom; i <= pa.config.AvailableRangeTo; i++ {
		if _, reserved := pa.reserved[i]; reserved {
			continue
		}
		pa.reserved[i] = struct{}{}
		return i, nil
	}
	return 0, fmt.Errorf("no available ports")
}

// Release frees every port recorded for allocationID and forgets the
// allocation. Unknown IDs are ignored.
func (pa *PortAllocator) Release(allocationID string) {
	pa.mx.Lock()
	defer pa.mx.Unlock()

	allocated, ok := pa.allocs[allocationID]
	if !ok {
		return
	}
	for _, p := range allocated {
		delete(pa.reserved, p)
	}
	delete(pa.allocs, allocationID)
}

// GetAllocation returns the allocated ports for a specific allocation ID, or
// an error when the ID is unknown.
func (pa *PortAllocator) GetAllocation(allocationID string) ([]int, error) {
	pa.mx.Lock()
	defer pa.mx.Unlock()

	ports, exists := pa.allocs[allocationID]
	if !exists {
		return nil, fmt.Errorf("port allocation ID not found: %s", allocationID)
	}
	return ports, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package node
import (
"github.com/libp2p/go-libp2p/core/peer"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/lib/did"
"gitlab.com/nunet/device-management-service/types"
)
// Behavior names and the broadcast topic used for hello/status exchanges.
const (
// PublicHelloBehavior is the behavior invoked by sayHello on a peer's root
// actor.
PublicHelloBehavior = "/public/hello"
// PublicStatusBehavior is the behavior served by publicStatusBehavior.
PublicStatusBehavior = "/public/status"
// BroadcastHelloBehavior is the hello behavior registered on
// BroadcastHelloTopic.
BroadcastHelloBehavior = "/broadcast/hello"
// BroadcastHelloTopic is the broadcast topic the node subscribes to on
// Start.
BroadcastHelloTopic = "/nunet/hello"
)
// HelloResponse is the reply payload to a hello; it carries the DID of the
// responding node's security context.
type HelloResponse struct {
DID did.DID
}
// PublicStatusResponse is the reply payload to a public status request.
type PublicStatusResponse struct {
// Status is "OK" when the machine resources could be read and "ERROR"
// otherwise.
Status string
// Resources holds the machine resources; it is left at its zero value when
// Status is "ERROR".
Resources types.Resources
}
// publicHelloBehavior handles a direct hello from a peer. The sender's peer
// ID is derived from its DID, the peer is marked as having said hello
// (helloIn) and the hello is answered through handleHello.
//
// When the peer ID cannot be derived, the envelope is discarded and no reply
// is sent; on the success path handleHello takes care of discarding it.
func (n *Node) publicHelloBehavior(msg actor.Envelope) {
	pubk, err := did.PublicKeyFromDID(msg.From.DID)
	if err != nil {
		msg.Discard()
		log.Debugf("failed to extract public key from DID: %s", err)
		return
	}
	p, err := peer.IDFromPublicKey(pubk)
	if err != nil {
		msg.Discard()
		log.Debugf("failed to extract peer ID from public key: %s", err)
		return
	}

	n.mx.Lock()
	if st, ok := n.peers[p]; ok {
		st.helloIn = true
	} else if n.network.PeerConnected(p) {
		// race with connected notification
		st = &peerState{helloIn: true}
		n.peers[p] = st
	}
	n.mx.Unlock()

	n.handleHello(msg)
}
// broadcastHelloBehavior handles a hello received on the broadcast topic by
// answering it through handleHello; unlike publicHelloBehavior it does not
// touch the peer state.
func (n *Node) broadcastHelloBehavior(msg actor.Envelope) {
n.handleHello(msg)
}
// handleHello answers a hello with the node's DID and discards the envelope
// when done.
func (n *Node) handleHello(msg actor.Envelope) {
	defer msg.Discard()

	log.Debugf("hello from %s", msg.From.Address.HostID)

	n.sendReply(msg, HelloResponse{DID: n.actor.Security().DID()})
}
// publicStatusBehavior replies with the machine's resources. The status is
// "OK" when the hardware manager could report them and "ERROR" (with empty
// resources) otherwise.
func (n *Node) publicStatusBehavior(msg actor.Envelope) {
	defer msg.Discard()

	resp := PublicStatusResponse{Status: "ERROR"}
	if machine, err := n.hardware.GetMachineResources(); err == nil {
		resp.Status = "OK"
		resp.Resources = machine.Resources
	}
	n.sendReply(msg, resp)
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package onboarding
import (
"context"
"errors"
"fmt"
"github.com/spf13/afero"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
var (
// ErrMachineNotOnboarded is returned when an operation requires an
// onboarded machine.
ErrMachineNotOnboarded = errors.New("machine is not onboarded")
// ErrOutOfRange is returned by validateRange when a value lies outside the
// allowed interval.
ErrOutOfRange = errors.New("out of range")
)
// Config holds the dependencies and settings of the onboarding package.
type Config struct {
// Fs is the filesystem used to check that WorkDir exists.
Fs afero.Afero
WorkDir string
DatabasePath string
// ConfigRepo persists the onboarding configuration.
ConfigRepo repositories.OnboardingConfig
// ResourceManager and Hardware are not set by NewConfig; they must be
// assigned by the caller before the onboarding methods that use them run.
ResourceManager types.ResourceManager
Hardware types.HardwareManager
}
// NewConfig builds a Config from the filesystem, the working directory, the
// database path and the onboarding config repository. ResourceManager and
// Hardware are left unset.
func NewConfig(
	fs afero.Afero,
	workDir, dbPath string,
	configRepo repositories.OnboardingConfig,
) *Config {
	cfg := new(Config)
	cfg.Fs = fs
	cfg.WorkDir = workDir
	cfg.DatabasePath = dbPath
	cfg.ConfigRepo = configRepo
	return cfg
}
// Onboarding acts as the receiver for methods related to onboarding. It
// embeds a copy of the Config it was created with.
type Onboarding struct {
Config
}
// New builds an Onboarding from a copy of the given config. config must not
// be nil.
func New(config *Config) *Onboarding {
	onboarding := Onboarding{Config: *config}
	return &onboarding
}
// IsOnboarded reports whether the machine is onboarded, which currently means
// that an onboarding config can be read from the repository. Any repository
// error is returned together with false.
func (o *Onboarding) IsOnboarded(ctx context.Context) (bool, error) {
	if _, err := o.ConfigRepo.Get(ctx); err != nil {
		return false, err
	}
	// TODO: validate onboarding params
	return true, nil
}
// Info returns the onboarding configuration.
// The stored config is read from the repository and then enriched with the
// onboarded resources reported by the resource manager and the machine
// resources reported by the hardware manager.
func (o *Onboarding) Info(ctx context.Context) (types.OnboardingConfig, error) {
	var none types.OnboardingConfig

	cfg, err := o.ConfigRepo.Get(ctx)
	if err != nil {
		return none, fmt.Errorf("could not get onboarding config: %w", err)
	}

	// onboarded resources come from the resource manager
	onboarded, err := o.ResourceManager.GetOnboardedResources(ctx)
	if err != nil {
		return none, fmt.Errorf("could not get onboarded resources: %w", err)
	}
	cfg.OnboardedResources = onboarded.Resources

	// machine resources come from the hardware manager
	machine, err := o.Hardware.GetMachineResources()
	if err != nil {
		return none, fmt.Errorf("could not get machine resources: %w", err)
	}
	cfg.MachineResources = machine.Resources

	return cfg, nil
}
// Onboard onboards the machine with the given config.
// The prerequisites are validated first; then the onboarded resources are
// updated in the resource manager and finally the config is saved to the
// repository.
func (o *Onboarding) Onboard(ctx context.Context, config types.OnboardingConfig) error {
	log.Debugf("onboarding the machine with the config: %+v", config)

	err := o.validatePrerequisites(config)
	if err != nil {
		return fmt.Errorf("could not validate onboarding prerequisites: %w", err)
	}

	err = o.ResourceManager.UpdateOnboardedResources(ctx, config.OnboardedResources)
	if err != nil {
		return fmt.Errorf("could not update onboarded resources: %w", err)
	}

	_, err = o.ConfigRepo.Save(ctx, config)
	if err != nil {
		return fmt.Errorf("could not save onboarding config: %w", err)
	}
	return nil
}
// Offboard offboards the machine from the network by clearing the onboarding
// config from the database and resetting the onboarded resources.
//
// Without force, a failure to determine the onboarding status or to clear the
// config aborts the operation. With force, those failures are logged and
// offboarding continues. A failure to reset the onboarded resources is always
// returned.
func (o *Onboarding) Offboard(ctx context.Context, force bool) error {
	onboarded, err := o.IsOnboarded(ctx)
	if err != nil {
		if !force {
			if errors.Is(err, ErrMachineNotOnboarded) {
				return ErrMachineNotOnboarded
			}
			return fmt.Errorf("could not retrieve onboard status: %w", err)
		}
		log.Errorf("problem with onboarding state: %v", err)
		log.Info("continuing with offboarding because forced")
	}

	// IsOnboarded reports false whenever it returns an error, so a forced
	// offboard must not be rejected here or force would have no effect.
	if !onboarded && !force {
		return ErrMachineNotOnboarded
	}

	// TODO: shutdown routine to stop networking etc... here

	if err := o.ConfigRepo.Clear(ctx); err != nil {
		if !force {
			return fmt.Errorf("failed to clear onboarding config from db: %w", err)
		}
		log.Errorf("failed to clear onboarding config from db: %v", err)
		log.Info("continuing with offboarding because forced")
	}

	// clear the onboarded resources
	if err := o.ResourceManager.UpdateOnboardedResources(ctx, types.Resources{}); err != nil {
		return fmt.Errorf("could not clear onboarded resources: %w", err)
	}
	return nil
}
// validateRange returns ErrOutOfRange when actual lies below min or above
// max; both bounds are inclusive.
func validateRange(actual, min, max float64) error {
	switch {
	case actual < min:
		return ErrOutOfRange
	case actual > max:
		return ErrOutOfRange
	}
	return nil
}
// validateCapacity checks the requested resources against the machine's
// total resources: the core count must be between 1 and the machine's cores,
// the RAM between 10% and 90% of the machine's RAM, and each requested GPU's
// VRAM between 10% and 90% of the VRAM of the machine GPU with the same index.
func (o *Onboarding) validateCapacity(resources types.Resources) error {
	// TODO: https://gitlab.com/nunet/device-management-service/-/merge_requests/563#note_2139212199
	machine, err := o.Hardware.GetMachineResources()
	if err != nil {
		return fmt.Errorf("retrieve provisioned machine resources: %w", err)
	}

	maxCores := machine.CPU.Cores
	if cores := resources.CPU.Cores; cores < 1 || cores > maxCores {
		return fmt.Errorf("cores must be between %d and %.0f", 1, maxCores)
	}

	ramLow, ramHigh := machine.RAM.Size/10, machine.RAM.Size*9/10
	if rangeErr := validateRange(resources.RAM.Size, ramLow, ramHigh); rangeErr != nil {
		if !errors.Is(rangeErr, ErrOutOfRange) {
			return fmt.Errorf("validating resource range for RAM: %w", rangeErr)
		}
		return fmt.Errorf("expected RAM to be between %.2f and %.2f, got %.2f ",
			types.ConvertBytesToGB(ramLow),
			types.ConvertBytesToGB(ramHigh),
			types.ConvertBytesToGB(resources.RAM.Size),
		)
	}

	for _, requested := range resources.GPUs {
		physical, err := machine.GPUs.GetWithIndex(requested.Index)
		if err != nil {
			return fmt.Errorf("could not get find gpu: %w", err)
		}

		vramLow, vramHigh := physical.VRAM/10, physical.VRAM*9/10
		rangeErr := validateRange(requested.VRAM, vramLow, vramHigh)
		if rangeErr == nil {
			continue
		}
		if !errors.Is(rangeErr, ErrOutOfRange) {
			return fmt.Errorf("validating resource range for GPU %d: %w", requested.Index, rangeErr)
		}
		return fmt.Errorf("expected GPU %d VRAM to be between %.2f and %.2f, got %.2f",
			requested.Index,
			types.ConvertBytesToGB(vramLow),
			types.ConvertBytesToGB(vramHigh),
			types.ConvertBytesToGB(requested.VRAM),
		)
	}
	return nil
}
// validateUsage checks the requested resources against what is currently
// free on the machine: CPU compute, RAM size and the VRAM of every requested
// GPU must not exceed the free amounts reported by the hardware manager.
func (o *Onboarding) validateUsage(resources types.Resources) error {
	free, err := o.Hardware.GetFreeResources()
	if err != nil {
		return fmt.Errorf("could not get usage data: %w", err)
	}

	switch {
	case resources.CPU.Compute() > free.CPU.Compute():
		return fmt.Errorf("CPU usage is too high: %.2f", free.CPU.Compute())
	case resources.RAM.Size > free.RAM.Size:
		return fmt.Errorf("memory usage is too high: %.2f", free.RAM.Size)
	}

	for _, requested := range resources.GPUs {
		available, err := free.GPUs.GetWithIndex(requested.Index)
		if err != nil {
			return fmt.Errorf("could not find gpu: %w", err)
		}
		if requested.VRAM > available.VRAM {
			return fmt.Errorf("GPU %s usage is too high: %.2f", requested.Model, requested.VRAM)
		}
	}
	return nil
}
// validatePrerequisites runs every check that must pass before onboarding:
// the working directory has to exist, and the requested resources have to
// pass both the capacity and the usage validation.
func (o *Onboarding) validatePrerequisites(config types.OnboardingConfig) error {
	exists, err := o.Fs.DirExists(o.WorkDir)
	switch {
	case err != nil:
		return fmt.Errorf("could not check if config directory exists: %w", err)
	case !exists:
		return fmt.Errorf("working directory does not exist")
	}

	requested := config.OnboardedResources
	if err := o.validateCapacity(requested); err != nil {
		return fmt.Errorf("could not validate capacity data: %w", err)
	}
	if err := o.validateUsage(requested); err != nil {
		return fmt.Errorf("could not validate usage data: %w", err)
	}
	return nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package resources
import (
"context"
"fmt"
"sync"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// ManagerRepos holds all the repositories needed for resource management
type ManagerRepos struct {
// OnboardedResources is the repository for the onboarded resources.
OnboardedResources repositories.OnboardedResources
// ResourceAllocation is the repository for resource allocations.
ResourceAllocation repositories.ResourceAllocation
}
// DefaultManager implements the ResourceManager interface
// TODO: Add telemetry for the methods https://gitlab.com/nunet/device-management-service/-/issues/535
type DefaultManager struct {
// repos are the repositories passed to NewResourceManager.
repos ManagerRepos
// store is the in-memory store holding the committed resources and the
// allocations, keyed by job ID.
store *store
// hardware is queried for the free resources of the machine.
hardware types.HardwareManager
// allocationLock is used to synchronize access to the allocation pool during allocation and deallocation
// it ensures that resource allocation and deallocation are atomic operations
allocationLock sync.Mutex
// committedLock is used to synchronize access to the committed resources pool during committing and releasing
// it ensures that resource committing and releasing are atomic operations
committedLock sync.Mutex
}
// NewResourceManager builds a DefaultManager that uses the given repositories
// and hardware manager together with a fresh, empty in-memory store.
// It returns an error when hardware is nil.
func NewResourceManager(repos ManagerRepos, hardware types.HardwareManager) (*DefaultManager, error) {
	if hardware == nil {
		return nil, fmt.Errorf("hardware manager cannot be nil")
	}

	// TODO: load the allocations from db on startup
	manager := &DefaultManager{
		repos:    repos,
		store:    newStore(),
		hardware: hardware,
	}
	return manager, nil
}
// Compile-time assertion that *DefaultManager satisfies types.ResourceManager.
var _ types.ResourceManager = (*DefaultManager)(nil)
// CommitResources reserves the resources requested by a job ahead of its
// allocation.
//
// The call fails when the job already has committed or allocated resources,
// when the DMS pool does not have enough free resources, or when the machine
// itself reports too few free resources. On success the commitment is recorded
// in the in-memory store only.
func (d *DefaultManager) CommitResources(ctx context.Context, allocation types.ResourceAllocation) error {
	d.committedLock.Lock()
	defer d.committedLock.Unlock()

	// A job may hold at most one commitment.
	var alreadyCommitted bool
	d.store.withCommittedRLock(func() {
		_, alreadyCommitted = d.store.committedResources[allocation.JobID]
	})
	if alreadyCommitted {
		return fmt.Errorf("resources already committed for job %s", allocation.JobID)
	}

	// A job that is already allocated cannot be committed again.
	var alreadyAllocated bool
	d.store.withAllocationsLock(func() {
		_, alreadyAllocated = d.store.allocations[allocation.JobID]
	})
	if alreadyAllocated {
		return fmt.Errorf("resources already allocated for job %s", allocation.JobID)
	}

	// The request must fit into what is still free in the DMS pool.
	poolFree, err := d.GetFreeResources(ctx)
	if err != nil {
		return fmt.Errorf("getting free resources: %w", err)
	}
	if subErr := poolFree.Subtract(allocation.Resources); subErr != nil {
		return fmt.Errorf("no free resources: %w", subErr)
	}

	// The request must also fit into what the machine reports as free.
	machineFree, err := d.hardware.GetFreeResources()
	if err != nil {
		return fmt.Errorf("get system free resources: %w", err)
	}
	if subErr := machineFree.Subtract(allocation.Resources); subErr != nil {
		return fmt.Errorf("no free resources on the machine: %w", subErr)
	}

	// Record the commitment.
	commitment := &types.CommittedResources{
		Resources: allocation.Resources,
		JobID:     allocation.JobID,
	}
	d.store.withCommittedLock(func() {
		d.store.committedResources[allocation.JobID] = commitment
	})
	return nil
}
// ReleaseCommittedResources drops the commitment recorded for jobID.
// It returns an error when no resources are committed for that job.
func (d *DefaultManager) ReleaseCommittedResources(_ context.Context, jobID string) error {
	d.committedLock.Lock()
	defer d.committedLock.Unlock()

	// There must be a commitment to release.
	committed := false
	d.store.withCommittedLock(func() {
		_, committed = d.store.committedResources[jobID]
	})
	if !committed {
		return fmt.Errorf("resources not committed for job %s", jobID)
	}

	// Remove the commitment from the store.
	d.store.withCommittedLock(func() {
		delete(d.store.committedResources, jobID)
	})
	return nil
}
// AllocateResources allocates resources for a job.
//
// The call fails when the job already has an allocation, when the DMS pool or
// the machine does not have enough free resources, or when the allocation
// cannot be persisted.
func (d *DefaultManager) AllocateResources(ctx context.Context, allocation types.ResourceAllocation) error {
	d.allocationLock.Lock()
	defer d.allocationLock.Unlock()

	// A job may hold at most one allocation.
	var alreadyAllocated bool
	d.store.withAllocationsRLock(func() {
		_, alreadyAllocated = d.store.allocations[allocation.JobID]
	})
	if alreadyAllocated {
		return fmt.Errorf("resources already allocated for job %s", allocation.JobID)
	}

	// The request must fit into what is still free in the DMS pool.
	poolFree, err := d.GetFreeResources(ctx)
	if err != nil {
		return fmt.Errorf("getting free resources: %w", err)
	}
	if subErr := poolFree.Subtract(allocation.Resources); subErr != nil {
		return fmt.Errorf("no free resources: %w", subErr)
	}

	// The request must also fit into what the machine reports as free.
	machineFree, err := d.hardware.GetFreeResources()
	if err != nil {
		return fmt.Errorf("get system free resources: %w", err)
	}
	if subErr := machineFree.Subtract(allocation.Resources); subErr != nil {
		return fmt.Errorf("no free resources on the machine: %w", subErr)
	}

	// Persist the allocation and mirror it in the store.
	if storeErr := d.storeAllocation(ctx, allocation); storeErr != nil {
		return fmt.Errorf("storing allocations in db: %w", storeErr)
	}
	return nil
}
// DeallocateResources removes the allocation held by jobID.
// It returns an error when the job has no allocation in the store or when the
// allocation cannot be deleted.
func (d *DefaultManager) DeallocateResources(ctx context.Context, jobID string) error {
	d.allocationLock.Lock()
	defer d.allocationLock.Unlock()

	// There must be an allocation to remove.
	allocated := false
	d.store.withAllocationsRLock(func() {
		_, allocated = d.store.allocations[jobID]
	})
	if !allocated {
		return fmt.Errorf("resources not allocated for job %s", jobID)
	}

	if delErr := d.deleteAllocation(ctx, jobID); delErr != nil {
		return fmt.Errorf("deleting allocations from db: %w", delErr)
	}
	return nil
}
// GetFreeResources returns the free resources in the allocation pool, computed
// as the onboarded resources minus the total allocations minus the committed
// resources.
//
// An error is returned when the onboarded resources or the total allocation
// cannot be obtained, when the committed resources cannot be summed, or when
// either subtraction fails.
func (d *DefaultManager) GetFreeResources(ctx context.Context) (types.FreeResources, error) {
	// get onboarded resources
	onboardedResources, err := d.GetOnboardedResources(ctx)
	if err != nil {
		return types.FreeResources{}, fmt.Errorf("getting onboarded resources: %w", err)
	}

	// get allocated resources
	totalAllocation, err := d.GetTotalAllocation()
	if err != nil {
		return types.FreeResources{}, fmt.Errorf("getting total allocations: %w", err)
	}

	// get committed resources; a failed addition is reported instead of being
	// dropped, because an incomplete sum would over-report the free resources
	var (
		committedResources types.Resources
		addErr             error
	)
	d.store.withCommittedRLock(func() {
		for _, committedResource := range d.store.committedResources {
			if addErr = committedResources.Add(committedResource.Resources); addErr != nil {
				return
			}
		}
	})
	if addErr != nil {
		return types.FreeResources{}, fmt.Errorf("summing committed resources: %w", addErr)
	}

	// calculate the free resources
	var freeResources types.FreeResources
	freeResources.Resources = onboardedResources.Resources
	if err := freeResources.Resources.Subtract(totalAllocation); err != nil {
		return types.FreeResources{}, fmt.Errorf("subtracting total allocation: %w", err)
	}
	if err := freeResources.Resources.Subtract(committedResources); err != nil {
		return types.FreeResources{}, fmt.Errorf("subtracting committed resources: %w", err)
	}
	return freeResources, nil
}
// GetTotalAllocation sums the resources of every allocation currently held in
// the store. Summation stops at the first addition that fails; the partial sum
// and that error are returned.
func (d *DefaultManager) GetTotalAllocation() (types.Resources, error) {
	var (
		sum    types.Resources
		addErr error
	)
	d.store.withAllocationsRLock(func() {
		for _, held := range d.store.allocations {
			if addErr = sum.Add(held.Resources); addErr != nil {
				return
			}
		}
	})
	return sum, addErr
}
// GetOnboardedResources returns the onboarded resources of the machine.
// The value held in the store is returned when present; otherwise it is
// fetched from the repository and stored for subsequent calls.
func (d *DefaultManager) GetOnboardedResources(ctx context.Context) (types.OnboardedResources, error) {
	var (
		cached types.OnboardedResources
		hit    bool
	)
	d.store.withOnboardedRLock(func() {
		if current := d.store.onboardedResources; current != nil {
			cached, hit = *current, true
		}
	})
	if hit {
		return cached, nil
	}

	loaded, err := d.repos.OnboardedResources.Get(ctx)
	if err != nil {
		return types.OnboardedResources{}, fmt.Errorf("failed to get onboarded resources: %w", err)
	}

	// The callback never fails, so the returned error is always nil.
	_ = d.store.withOnboardedLock(func() error {
		d.store.onboardedResources = &loaded
		return nil
	})
	return loaded, nil
}
// UpdateOnboardedResources replaces the onboarded resources of the machine,
// both in the database and in the store. The update is rejected when the new
// resources cannot cover the current total allocation.
func (d *DefaultManager) UpdateOnboardedResources(ctx context.Context, resources types.Resources) error {
	return d.store.withOnboardedLock(func() error {
		// the new amount must still cover everything that is allocated
		totalAllocation, err := d.GetTotalAllocation()
		if err != nil {
			return fmt.Errorf("getting total allocations: %w", err)
		}

		// QUESTION(@kanishka): should this line be here or after the Subtract? I think after.
		// NOTE(review): the value is captured before Subtract runs; elsewhere in this
		// file Subtract appears to modify its receiver, so capturing afterwards would
		// presumably persist the remainder instead of the onboarded amount - confirm.
		updated := types.OnboardedResources{Resources: resources}

		// Check if the demand is too high
		if subErr := resources.Subtract(totalAllocation); subErr != nil {
			return fmt.Errorf("couldn't subtract allocation: %w. Demand too high", subErr)
		}

		// Potential issue: if the onboarded resources are updated in the db, the free resources should be updated as well
		// If the free resources update fails, the onboarded resources should not be updated
		// Since we have no concept of transactions in the current implementation of db, we cannot handle this scenario
		// without writing a custom transaction manager
		if _, saveErr := d.repos.OnboardedResources.Save(ctx, updated); saveErr != nil {
			return fmt.Errorf("failed to update onboarded resources: %w", saveErr)
		}

		d.store.onboardedResources = &updated
		return nil
	})
}
// loadAllocationsFromDB fetches all allocations from the database and copies
// them into the store, keyed by job ID. It returns an error when the database
// query fails.
func (d *DefaultManager) loadAllocationsFromDB(ctx context.Context) error { //nolint:unused
	allocations, err := d.repos.ResourceAllocation.FindAll(ctx, d.repos.ResourceAllocation.GetQuery())
	if err != nil {
		return fmt.Errorf("loading allocations from db: %w", err)
	}

	// Write under the allocations lock, like every other access to the map.
	d.store.withAllocationsLock(func() {
		for _, allocation := range allocations {
			d.store.allocations[allocation.JobID] = allocation
		}
	})
	return nil
}
// storeAllocation persists the allocation in the database and, once that
// succeeded, records it in the store under its job ID.
func (d *DefaultManager) storeAllocation(ctx context.Context, allocation types.ResourceAllocation) error {
	if _, createErr := d.repos.ResourceAllocation.Create(ctx, allocation); createErr != nil {
		return fmt.Errorf("storing allocations in db: %w", createErr)
	}

	d.store.withAllocationsLock(func() {
		d.store.allocations[allocation.JobID] = allocation
	})
	return nil
}
// deleteAllocation looks up the allocation of jobID in the database, deletes
// it there and then removes it from the store. It returns an error when the
// allocation cannot be found or deleted; the store is left untouched in that
// case.
func (d *DefaultManager) deleteAllocation(ctx context.Context, jobID string) error {
	query := d.repos.ResourceAllocation.GetQuery()
	query.Conditions = append(query.Conditions, repositories.EQ("JobID", jobID))

	// Use the caller's context so cancellation and deadlines also apply to the lookup.
	allocation, err := d.repos.ResourceAllocation.Find(ctx, query)
	if err != nil {
		return fmt.Errorf("finding allocations in db: %w", err)
	}

	if err := d.repos.ResourceAllocation.Delete(ctx, allocation.ID); err != nil {
		return fmt.Errorf("deleting allocations from db: %w", err)
	}

	d.store.withAllocationsLock(func() {
		delete(d.store.allocations, jobID)
	})
	return nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package resources
import (
"sync"
"gitlab.com/nunet/device-management-service/types"
)
// locks holds the locks for the resource manager store
// allocations: lock for the allocations map
// onboarded: lock for the onboarded resources
// committed: lock for the committed resources map
type locks struct {
allocations sync.RWMutex
onboarded sync.RWMutex
committed sync.RWMutex
}
// newLocks returns a pointer to a zero-valued locks instance.
func newLocks() *locks {
	return new(locks)
}
// store holds the resources of the machine
// onboardedResources: resources that are onboarded to the machine (nil until set)
// committedResources: resources committed to jobs, keyed by job ID
// allocations: resources that are requested by the jobs, keyed by job ID
// locks: the locks guarding the fields above
type store struct {
onboardedResources *types.OnboardedResources
committedResources map[string]*types.CommittedResources
allocations map[string]types.ResourceAllocation
locks *locks
}
// newStore returns a store with empty allocation and commitment maps, a fresh
// set of locks and no onboarded resources.
func newStore() *store {
	s := &store{locks: newLocks()}
	s.allocations = make(map[string]types.ResourceAllocation)
	s.committedResources = make(map[string]*types.CommittedResources)
	return s
}
// withAllocationsLock runs fn while holding the allocations write lock.
func (s *store) withAllocationsLock(fn func()) {
	mu := &s.locks.allocations
	mu.Lock()
	defer mu.Unlock()

	fn()
}
// withOnboardedLock runs fn while holding the onboarded write lock and returns
// whatever fn returns.
func (s *store) withOnboardedLock(fn func() error) error {
	mu := &s.locks.onboarded
	mu.Lock()
	defer mu.Unlock()

	return fn()
}
// withCommittedLock runs fn while holding the committed write lock.
func (s *store) withCommittedLock(fn func()) {
	mu := &s.locks.committed
	mu.Lock()
	defer mu.Unlock()

	fn()
}
// withCommittedRLock runs fn while holding the committed read lock.
func (s *store) withCommittedRLock(fn func()) {
	mu := &s.locks.committed
	mu.RLock()
	defer mu.RUnlock()

	fn()
}
// withAllocationsRLock runs fn while holding the allocations read lock.
// It returns nothing; fn communicates results through captured variables.
func (s *store) withAllocationsRLock(fn func()) {
	mu := &s.locks.allocations
	mu.RLock()
	defer mu.RUnlock()

	fn()
}
// withOnboardedRLock runs fn while holding the onboarded read lock.
// It returns nothing; fn communicates results through captured variables.
func (s *store) withOnboardedRLock(fn func()) {
	mu := &s.locks.onboarded
	mu.RLock()
	defer mu.RUnlock()

	fn()
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package dms
import (
"gorm.io/gorm"
)
// SanityCheck is currently a no-op placeholder.
//
// Before its body was removed, it performed basic consistency checks before
// starting the DMS, in the following sequence:
//   - find services marked as running in the database, stop them and remove them;
//   - update their status to 'finished with errors';
//   - recalculate the free resources and update the database.
//
// The body was removed because dependencies such as the docker package have
// been replaced with executor/docker. The database handle is ignored.
func SanityCheck(_ *gorm.DB) {
// TODO: sanity check of DMS last exit and correction of invalid states
// resources.CalcFreeResAndUpdateDB()
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package docker
import (
"bufio"
"context"
"encoding/json"
"fmt"
"io"
"os"
"strings"
"sync"
"github.com/docker/docker/api/types"
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/api/types/filters"
"github.com/docker/docker/api/types/image"
"github.com/docker/docker/api/types/network"
"github.com/docker/docker/client"
"github.com/docker/docker/pkg/jsonmessage"
"github.com/docker/docker/pkg/stdcopy"
v1 "github.com/opencontainers/image-spec/specs-go/v1"
"github.com/pkg/errors"
"go.uber.org/multierr"
)
// Client wraps the Docker client to provide high-level operations on Docker containers and networks.
type Client struct {
client *client.Client // Underlying Docker SDK client (a named field, not embedded).
}
// NewDockerClient creates a Client whose underlying Docker client is configured
// from the environment with API version negotiation enabled. Creation errors
// are logged and returned.
func NewDockerClient() (*Client, error) {
	log.Infow("docker_client_init_started")

	sdkClient, err := client.NewClientWithOpts(
		client.FromEnv,
		client.WithAPIVersionNegotiation(),
		client.WithHostFromEnv(),
	)
	if err != nil {
		log.Errorw("docker_client_init_failure", "error", err)
		return nil, err
	}

	log.Infow("docker_client_init_success")
	wrapped := &Client{client: sdkClient}
	return wrapped, nil
}
// IsInstalled reports whether the Docker daemon answers a ping.
// A failed ping is logged and reported as false.
func (c *Client) IsInstalled(ctx context.Context) bool {
	log.Infow("docker_client_is_installed_check_started")

	if _, pingErr := c.client.Ping(ctx); pingErr != nil {
		log.Errorw("docker_client_is_installed_failure", "error", pingErr)
		return false
	}

	log.Infow("docker_client_is_installed_success")
	return true
}
// CreateContainer creates a Docker container from the given configuration and
// returns its ID. When pullImage is true, config.Image is pulled first and a
// pull failure aborts the creation.
func (c *Client) CreateContainer(
	ctx context.Context,
	config *container.Config,
	hostConfig *container.HostConfig,
	networkingConfig *network.NetworkingConfig,
	platform *v1.Platform,
	name string,
	pullImage bool,
) (string, error) {
	if pullImage {
		log.Infow("docker_create_container_started", "image", config.Image)
		if _, pullErr := c.PullImage(ctx, config.Image); pullErr != nil {
			log.Errorw("docker_create_container_failure", "error", pullErr)
			return "", pullErr
		}
	}

	created, createErr := c.client.ContainerCreate(ctx, config, hostConfig, networkingConfig, platform, name)
	if createErr != nil {
		log.Errorw("docker_create_container_failure", "error", createErr)
		return "", createErr
	}

	log.Infow("docker_create_container_success", "containerID", created.ID)
	return created.ID, nil
}
// InspectContainer returns the Docker inspection result for the container with
// the given ID, or the error reported by the Docker client.
func (c *Client) InspectContainer(ctx context.Context, id string) (types.ContainerJSON, error) {
	log.Infow("docker_inspect_container_started", "containerID", id)

	details, err := c.client.ContainerInspect(ctx, id)
	return details, err
}
// FollowLogs tails the logs of the container with the given ID and returns
// separate readers for its stdout and stderr.
//
// The multiplexed Docker log stream is demultiplexed in a background goroutine
// that writes into two pipes. When copying ends, the goroutine closes the log
// stream, flushes its buffers and closes both pipe writers.
//
// An error is returned when the container cannot be inspected or its log
// stream cannot be opened.
func (c *Client) FollowLogs(ctx context.Context, id string) (stdout, stderr io.Reader, err error) {
	log.Infow("docker_follow_logs_started", "containerID", id)
	cont, err := c.InspectContainer(ctx, id)
	if err != nil {
		log.Errorw("docker_follow_logs_failure", "error", err)
		return nil, nil, errors.Wrap(err, "failed to get container")
	}

	logOptions := container.LogsOptions{
		ShowStdout: true,
		ShowStderr: true,
		Follow:     true,
	}
	logsReader, err := c.client.ContainerLogs(ctx, cont.ID, logOptions)
	if err != nil {
		log.Errorw("docker_follow_logs_failure", "error", err)
		return nil, nil, errors.Wrap(err, "failed to get container logs")
	}

	stdoutReader, stdoutWriter := io.Pipe()
	stderrReader, stderrWriter := io.Pipe()

	go func() {
		stdoutBuffer := bufio.NewWriter(stdoutWriter)
		stderrBuffer := bufio.NewWriter(stderrWriter)
		defer func() {
			logsReader.Close()
			stdoutBuffer.Flush()
			stdoutWriter.Close()
			stderrBuffer.Flush()
			stderrWriter.Close()
		}()

		// Keep the copy error local to the goroutine: assigning to the named
		// result err here would race with the enclosing function's return.
		_, copyErr := stdcopy.StdCopy(stdoutBuffer, stderrBuffer, logsReader)
		if copyErr != nil && !errors.Is(copyErr, context.Canceled) {
			log.Warnf("context closed while getting logs: %v\n", copyErr)
		}
	}()

	return stdoutReader, stderrReader, nil
}
// StartContainer starts the container with the given ID using default start
// options.
func (c *Client) StartContainer(ctx context.Context, containerID string) error {
	log.Infow("docker_start_container_started", "containerID", containerID)

	var startOpts container.StartOptions
	return c.client.ContainerStart(ctx, containerID, startOpts)
}
// WaitContainer waits until the container is no longer running. It returns the
// Docker client's channels for the wait response and for errors.
func (c *Client) WaitContainer(
	ctx context.Context,
	containerID string,
) (<-chan container.WaitResponse, <-chan error) {
	log.Infow("docker_wait_container_started", "containerID", containerID)

	responses, failures := c.client.ContainerWait(ctx, containerID, container.WaitConditionNotRunning)
	return responses, failures
}
// PauseContainer asks Docker to pause the given container and returns the
// Docker client's error, if any.
func (c *Client) PauseContainer(ctx context.Context, containerID string) error {
	pauseErr := c.client.ContainerPause(ctx, containerID)
	return pauseErr
}
// ResumeContainer asks Docker to unpause the given container and returns the
// Docker client's error, if any.
func (c *Client) ResumeContainer(ctx context.Context, containerID string) error {
	unpauseErr := c.client.ContainerUnpause(ctx, containerID)
	return unpauseErr
}
// StopContainer stops the given container using the supplied stop options.
func (c *Client) StopContainer(
	ctx context.Context,
	containerID string,
	options container.StopOptions,
) error {
	log.Infow("docker_stop_container_started", "containerID", containerID)

	stopErr := c.client.ContainerStop(ctx, containerID, options)
	return stopErr
}
// RemoveContainer removes the given container. Removal is always requested
// with Force and RemoveVolumes set.
func (c *Client) RemoveContainer(ctx context.Context, containerID string) error {
	log.Infow("docker_remove_container_started", "containerID", containerID)

	removeOpts := container.RemoveOptions{RemoveVolumes: true, Force: true}
	return c.client.ContainerRemove(ctx, containerID, removeOpts)
}
// removeContainers removes every container (running or not) that matches the
// given filters. Containers are removed concurrently, one goroutine each, and
// all removal errors are combined into the returned error.
func (c *Client) removeContainers(ctx context.Context, filterz filters.Args) error {
	log.Infow("docker_remove_containers_started")

	listOpts := container.ListOptions{All: true, Filters: filterz}
	matches, err := c.client.ContainerList(ctx, listOpts)
	if err != nil {
		log.Errorw("docker_remove_containers_failure", "error", err)
		return err
	}

	var wg sync.WaitGroup
	results := make(chan error, len(matches))
	for _, match := range matches {
		wg.Add(1)
		go func(id string) {
			defer wg.Done()
			results <- c.RemoveContainer(ctx, id)
		}(match.ID)
	}

	// Close the channel once every removal has reported, ending the loop below.
	go func() {
		wg.Wait()
		close(results)
	}()

	var combined error
	for removeErr := range results {
		combined = multierr.Append(combined, removeErr)
	}

	if combined == nil {
		log.Infow("docker_remove_containers_success")
	} else {
		log.Errorw("docker_remove_containers_failure", "error", combined)
	}
	return combined
}
// removeNetworks removes every network that matches the given filters.
// Networks are removed concurrently, one goroutine each, and all removal
// errors are combined into the returned error.
func (c *Client) removeNetworks(ctx context.Context, filterz filters.Args) error {
	log.Infow("docker_remove_networks_started")

	matches, err := c.client.NetworkList(ctx, network.ListOptions{Filters: filterz})
	if err != nil {
		log.Errorw("docker_remove_networks_failure", "error", err)
		return err
	}

	var wg sync.WaitGroup
	results := make(chan error, len(matches))
	for _, match := range matches {
		wg.Add(1)
		go func(id string) {
			defer wg.Done()
			results <- c.client.NetworkRemove(ctx, id)
		}(match.ID)
	}

	// Close the channel once every removal has reported, ending the loop below.
	go func() {
		wg.Wait()
		close(results)
	}()

	var combined error
	for removeErr := range results {
		combined = multierr.Append(combined, removeErr)
	}

	if combined == nil {
		log.Infow("docker_remove_networks_success")
	} else {
		log.Errorw("docker_remove_networks_failure", "error", combined)
	}
	return combined
}
// RemoveObjectsWithLabel removes all Docker containers and networks that carry
// the label label=value. Containers are removed first, then networks; network
// removal is attempted even when container removal failed. The errors of both
// steps are combined into the returned error.
func (c *Client) RemoveObjectsWithLabel(ctx context.Context, label string, value string) error {
	log.Infow("docker_remove_objects_with_label_started", "label", label, "value", value)
	filterz := filters.NewArgs(
		filters.Arg("label", fmt.Sprintf("%s=%s", label, value)),
	)

	containerErr := c.removeContainers(ctx, filterz)
	networkErr := c.removeNetworks(ctx, filterz)
	if containerErr != nil || networkErr != nil {
		log.Errorw("docker_remove_objects_with_label_failure", "containerErr", containerErr, "networkErr", networkErr)
		return multierr.Combine(containerErr, networkErr)
	}

	// Only report success when both removals succeeded.
	log.Infow("docker_remove_objects_with_label_success")
	return nil
}
// GetOutputStream opens the combined stdout/stderr log stream of a container.
// The 'since' value is passed to Docker as the point from which to return logs.
// The 'follow' flag is passed to Docker to keep the stream open for new output.
// It returns the stream reader, or a wrapped error when the stream cannot be opened.
func (c *Client) GetOutputStream(
	ctx context.Context,
	containerID string,
	since string,
	follow bool,
) (io.ReadCloser, error) {
	log.Infow("docker_get_output_stream_started", "containerID", containerID)

	stream, err := c.client.ContainerLogs(ctx, containerID, container.LogsOptions{
		ShowStdout: true,
		ShowStderr: true,
		Follow:     follow,
		Since:      since,
	})
	if err != nil {
		log.Errorw("docker_get_output_stream_failure", "error", err)
		return nil, errors.Wrap(err, "failed to get container logs")
	}

	log.Infow("docker_get_output_stream_success", "containerID", containerID)
	return stream, nil
}
// FindContainer lists all containers (running or not) and returns the ID of
// the first one whose label equals value. It returns an error when listing
// fails or no container matches.
func (c *Client) FindContainer(ctx context.Context, label string, value string) (string, error) {
	log.Infow("docker_find_container_started", "label", label, "value", value)

	all, err := c.client.ContainerList(ctx, container.ListOptions{All: true})
	if err != nil {
		log.Errorw("docker_find_container_failure", "error", err)
		return "", err
	}

	for i := range all {
		if all[i].Labels[label] != value {
			continue
		}
		log.Infow("docker_find_container_success", "containerID", all[i].ID)
		return all[i].ID, nil
	}

	notFound := fmt.Errorf("unable to find container for %s=%s", label, value)
	log.Errorw("docker_find_container_failure", "error", notFound)
	return "", notFound
}
// GetImage returns the summary of the local image whose repo tags contain
// imageName. When imageName carries no tag, ":latest" is appended before the
// comparison. It returns an error when the image list cannot be obtained or no
// image matches.
func (c *Client) GetImage(ctx context.Context, imageName string) (image.Summary, error) {
	images, err := c.client.ImageList(ctx, image.ListOptions{All: true})
	if err != nil {
		return image.Summary{}, err
	}

	// Only the part after the last "/" can carry a tag. A ":" earlier in the
	// name is a registry port (e.g. "localhost:5000/app") and must not stop
	// ":latest" from being appended.
	lastSegment := imageName[strings.LastIndex(imageName, "/")+1:]
	if !strings.Contains(lastSegment, ":") {
		imageName = fmt.Sprintf("%s:latest", imageName)
	}

	for _, summary := range images {
		for _, tag := range summary.RepoTags {
			if tag == imageName {
				return summary, nil
			}
		}
	}

	return image.Summary{}, fmt.Errorf("unable to find image %s", imageName)
}
// PullImage pulls imageName from its registry and returns the digest reported
// in the pull progress stream (empty when no "Digest" status line was seen).
//
// The raw progress stream is also copied to os.Stdout while it is decoded.
// An error is returned when the pull cannot be started, when the stream cannot
// be decoded, or when the stream contains an error message.
func (c *Client) PullImage(ctx context.Context, imageName string) (string, error) {
	log.Infow("docker_pull_image_started", "image", imageName)
	out, err := c.client.ImagePull(ctx, imageName, image.PullOptions{})
	if err != nil {
		log.Errorw("docker_pull_image_failure", "error", err)
		return "", err
	}
	defer out.Close()

	d := json.NewDecoder(io.TeeReader(out, os.Stdout))

	var digest string
	for {
		// Decode into a fresh value each time: decoding leaves fields that are
		// absent from the JSON untouched, so a reused message would carry a
		// stale Aux or Error pointer into later iterations.
		var message jsonmessage.JSONMessage
		if err := d.Decode(&message); err != nil {
			if errors.Is(err, io.EOF) {
				break
			}
			log.Errorw("docker_pull_image_failure", "error", err)
			return "", err
		}

		if message.Aux != nil {
			continue
		}

		if message.Error != nil {
			log.Errorw("docker_pull_image_failure", "error", message.Error.Message)
			return "", errors.New(message.Error.Message)
		}

		if strings.HasPrefix(message.Status, "Digest") {
			digest = strings.TrimPrefix(message.Status, "Digest: ")
		}
	}

	log.Infow("docker_pull_image_success", "digest", digest)
	return digest, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package docker
import (
"context"
"fmt"
"io"
"os"
"path/filepath"
"strings"
"sync/atomic"
"time"
"gitlab.com/nunet/device-management-service/dms/hardware"
"gitlab.com/nunet/device-management-service/observability"
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/api/types/mount"
"github.com/pkg/errors"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils"
)
// ErrNotInstalled is returned by NewExecutor when the Docker client reports
// that Docker is not installed or not reachable.
var ErrNotInstalled = errors.New("docker is not installed")
// Constants of the Docker executor. Most of their uses lie outside this part
// of the file, so the per-constant notes below are limited to what the names
// and values state.
const (
// nanoCPUsPerCore is the number of nano-CPUs that make up one CPU core.
nanoCPUsPerCore = 1e9
// label* are presumably Docker label keys attached to executor-managed
// objects - TODO confirm against the container creation code.
labelExecutorName = "nunet-executor"
labelJobID = "nunet-jobID"
labelExecutionID = "nunet-executionID"
// outputStreamCheckTickTime is the polling interval used while waiting for
// an execution handler in GetLogStream.
outputStreamCheckTickTime = 100 * time.Millisecond
outputStreamCheckTimeout = 5 * time.Second
statusWaitTickTime = 100 * time.Millisecond
statusWaitTimeout = 10 * time.Second
// initScriptsBaseDir is a path prefix under /tmp for init scripts.
initScriptsBaseDir = "/tmp/nunet-init-scripts-"
)
// Executor manages the lifecycle of Docker containers for execution requests.
type Executor struct {
// ID is the identifier passed to NewExecutor; it is copied into every
// execution handler this executor creates.
ID string
handlers utils.SyncMap[string, *executionHandler] // Maps execution IDs to their handlers.
client *Client // Docker client for container management.
}
// NewExecutor creates an Executor with the given ID and a new Docker client.
// It returns the client creation error, or ErrNotInstalled when the Docker
// daemon cannot be reached.
func NewExecutor(ctx context.Context, id string) (*Executor, error) {
	cli, err := NewDockerClient()
	switch {
	case err != nil:
		return nil, err
	case !cli.IsInstalled(ctx):
		return nil, ErrNotInstalled
	}

	executor := &Executor{ID: id, client: cli}
	return executor, nil
}
// Start begins the execution of a request.
//
// If a running container already exists for the job/execution (for example
// after a restart), it is reused. Otherwise the call fails when a handler is
// already registered for the execution ID, and creates a new container when
// none is. In both successful cases a new handler is registered for the
// execution ID and run in its own goroutine.
func (e *Executor) Start(ctx context.Context, request *types.ExecutionRequest) error {
	defer observability.StartTrace("docker_executor_start_duration")()

	log.Infow("docker_executor_start_begin", "jobID", request.JobID, "executionID", request.ExecutionID)

	containerID, findErr := e.FindRunningContainer(ctx, request.JobID, request.ExecutionID)
	if findErr != nil {
		// No running container: a registered handler means the execution was
		// started before, so it must not be started again.
		if existing, known := e.handlers.Get(request.ExecutionID); known {
			if existing.active() {
				log.Errorw("docker_executor_start_failure", "executionID", request.ExecutionID, "error", "execution already started")
				return fmt.Errorf("execution is already started")
			}
			log.Errorw("docker_executor_start_failure", "executionID", request.ExecutionID, "error", "execution completed")
			return fmt.Errorf("execution is already completed")
		}

		// Nothing known about this execution yet: create its container.
		var createErr error
		containerID, createErr = e.newDockerExecutionContainer(ctx, request)
		if createErr != nil {
			log.Errorw("docker_executor_start_failure", "executionID", request.ExecutionID, "error", createErr)
			return fmt.Errorf("failed to create new container: %w", createErr)
		}
	}

	h := &executionHandler{
		client:      e.client,
		ID:          e.ID,
		executionID: request.ExecutionID,
		containerID: containerID,
		resultsDir:  request.ResultsDir,
		waitCh:      make(chan bool),
		activeCh:    make(chan bool),
		running:     &atomic.Bool{},
		TTYEnabled:  true,
		initScripts: request.ProvisionScripts,
	}

	// Register the handler before running it so that lookups by execution ID find it.
	e.handlers.Put(request.ExecutionID, h)

	go h.run(ctx)
	return nil
}
// Pause pauses the execution registered under executionID.
// It returns an error when no handler is registered for that ID.
func (e *Executor) Pause(
	ctx context.Context,
	executionID string,
) error {
	if h, ok := e.handlers.Get(executionID); ok {
		return h.pause(ctx)
	}
	return fmt.Errorf("execution (%s) not found", executionID)
}
// Resume resumes the execution registered under executionID.
// It returns an error when no handler is registered for that ID.
func (e *Executor) Resume(
	ctx context.Context,
	executionID string,
) error {
	if h, ok := e.handlers.Get(executionID); ok {
		return h.resume(ctx)
	}
	return fmt.Errorf("execution (%s) not found", executionID)
}
// Wait starts waiting for the execution registered under executionID and
// returns two channels: one that delivers the execution result and one that
// delivers an error.
//
// When no handler is registered for the ID, a "not found" error is placed on
// the error channel immediately. Otherwise doWait is started in a goroutine
// and reports on the two channels once the execution completes or ctx is done.
func (e *Executor) Wait(
	ctx context.Context,
	executionID string,
) (<-chan *types.ExecutionResult, <-chan error) {
	defer observability.StartTrace("docker_executor_wait_duration")()

	log.Infow("docker_executor_wait_begin", "executionID", executionID)

	h, ok := e.handlers.Get(executionID)
	results := make(chan *types.ExecutionResult, 1)
	failures := make(chan error, 1)

	if ok {
		log.Infow("docker_executor_wait_success", "executionID", executionID)
		go e.doWait(ctx, results, failures, h)
		return results, failures
	}

	log.Errorw("docker_executor_wait_failure", "executionID", executionID, "error", "execution not found")
	failures <- fmt.Errorf("execution (%s) not found", executionID)
	return results, failures
}
// doWait blocks until the handler signals completion on its wait channel or
// ctx is done, whichever happens first.
//
// On completion the handler's result is sent on out; a nil result is reported
// on errCh instead. When ctx is done first, ctx.Err() is sent on errCh.
// Both channels are closed before doWait returns.
func (e *Executor) doWait(
	ctx context.Context,
	out chan *types.ExecutionResult,
	errCh chan error,
	handler *executionHandler,
) {
	log.Infof("executionID %s waiting for execution", handler.executionID)

	defer func() {
		close(errCh)
		close(out)
	}()

	select {
	case <-handler.waitCh:
		if handler.result == nil {
			errCh <- fmt.Errorf("execution (%s) result is nil", handler.executionID)
			return
		}
		log.Debugf("executionID %s received results from execution ", handler.executionID)
		out <- handler.result
	case <-ctx.Done():
		// Relay the cancellation to the caller.
		errCh <- ctx.Err()
	}
}
// Cancel stops the execution identified by executionID by delegating to its
// handler's kill method. An error is returned when no handler is registered
// for that ID, or when kill itself fails.
func (e *Executor) Cancel(ctx context.Context, executionID string) error {
	if h, ok := e.handlers.Get(executionID); ok {
		return h.kill(ctx)
	}
	return fmt.Errorf("failed to cancel execution (%s). execution not found", executionID)
}
// GetLogStream provides a stream of output logs for a specific execution.
// The request's fields control whether past logs are included and whether the
// stream stays open for new logs (see executionHandler.outputStream).
//
// The execution may already be recorded as running while its handler has not
// yet been added to the handler map (the container is still starting). The
// handler map is therefore polled every outputStreamCheckTickTime for up to
// outputStreamCheckTimeout.
//
// It returns an error if the execution is not found within that window, or
// ctx.Err() if ctx is cancelled while polling.
func (e *Executor) GetLogStream(
	ctx context.Context,
	request types.LogStreamRequest,
) (io.ReadCloser, error) {
	// Poll inline rather than from a helper goroutine: the previous
	// goroutine/channel hand-off could deadlock when the timeout fired while
	// the goroutine was blocked sending the handler on an unbuffered channel.
	ticker := time.NewTicker(outputStreamCheckTickTime)
	defer ticker.Stop()

	timeout := time.NewTimer(outputStreamCheckTimeout)
	defer timeout.Stop()

	for {
		select {
		case <-ticker.C:
			if handler, found := e.handlers.Get(request.ExecutionID); found {
				return handler.outputStream(ctx, request)
			}
		case <-timeout.C:
			return nil, fmt.Errorf("execution (%s) not found", request.ExecutionID)
		case <-ctx.Done():
			return nil, ctx.Err()
		}
	}
}
// List is meant to report the executions known to this executor.
// It is not implemented yet and always returns a nil slice.
func (e *Executor) List() []types.ExecutionListItem {
	// TODO: list dms containers
	var items []types.ExecutionListItem
	return items
}
// GetStatus reports the current status of the execution identified by
// executionID, as computed by its handler. An empty status and an error are
// returned when no handler is registered for that ID.
func (e *Executor) GetStatus(ctx context.Context, executionID string) (types.ExecutionStatus, error) {
	if h, ok := e.handlers.Get(executionID); ok {
		return h.status(ctx)
	}
	return "", fmt.Errorf("execution (%s) not found", executionID)
}
// WaitForStatus polls the execution identified by executionID every
// statusWaitTickTime until it reports the requested status.
//
// timeout bounds the total wait; when nil, statusWaitTimeout is used.
//
// It returns nil once the status matches, and an error if the execution is
// not found, the handler fails to report a status, ctx is cancelled, or the
// status is not reached before the timeout elapses.
func (e *Executor) WaitForStatus(
	ctx context.Context,
	executionID string,
	status types.ExecutionStatus,
	timeout *time.Duration,
) error {
	waitTimeout := statusWaitTimeout
	if timeout != nil {
		waitTimeout = *timeout
	}

	handler, found := e.handlers.Get(executionID)
	if !found {
		return fmt.Errorf("execution (%s) not found", executionID)
	}

	ticker := time.NewTicker(statusWaitTickTime)
	defer ticker.Stop()

	// The deadline timer is created once, outside the loop. Calling
	// time.After inside the select would restart the countdown on every tick,
	// so the timeout would never fire while ticks arrive faster than it.
	deadline := time.NewTimer(waitTimeout)
	defer deadline.Stop()

	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-ticker.C:
			s, err := handler.status(ctx)
			if err != nil {
				return err
			}
			if s == status {
				return nil
			}
		case <-deadline.C:
			return fmt.Errorf("execution (%s) did not reach status %s", executionID, status)
		}
	}
}
// Run initiates and waits for the completion of an execution in one call.
// It is a convenience wrapper that calls Start and then Wait.
//
// It returns the execution result, or an error if starting or waiting fails
// or if ctx is cancelled.
func (e *Executor) Run(
	ctx context.Context, request *types.ExecutionRequest,
) (*types.ExecutionResult, error) {
	if err := e.Start(ctx, request); err != nil {
		return nil, err
	}

	resCh, errCh := e.Wait(ctx, request.ExecutionID)

	// doWait sends on exactly one channel and then closes both. A receive
	// from a closed, empty channel succeeds immediately with the zero value,
	// so the ok flag must be checked: otherwise the select may pick the
	// closed channel and report (nil, nil). A closed channel is set to nil,
	// which disables its case and leaves the other channel to deliver the
	// real outcome.
	for resCh != nil || errCh != nil {
		select {
		case <-ctx.Done():
			return nil, ctx.Err()
		case out, ok := <-resCh:
			if ok {
				return out, nil
			}
			resCh = nil
		case err, ok := <-errCh:
			if ok {
				return nil, err
			}
			errCh = nil
		}
	}

	// Both channels were closed without carrying a value.
	return nil, fmt.Errorf("execution (%s) finished without a result", request.ExecutionID)
}
// Cleanup releases everything this executor created.
//
// It first asks the Docker client to remove all objects carrying the
// executor's label; a failure there aborts the cleanup and is returned.
// It then deletes every path matching initScriptsBaseDir + "*". Failures to
// delete an individual directory are only logged as warnings.
func (e *Executor) Cleanup(ctx context.Context) error {
	defer observability.StartTrace("docker_executor_cleanup_duration")()
	log.Infow("docker_executor_cleanup_begin", "executorID", e.ID)

	if err := e.client.RemoveObjectsWithLabel(ctx, labelExecutorName, e.ID); err != nil {
		log.Errorw("docker_executor_cleanup_failure", "executorID", e.ID, "error", err)
		return fmt.Errorf("failed to remove containers: %w", err)
	}
	log.Infow("docker_executor_cleanup_success", "executorID", e.ID)

	// Remove all provision scripts used for mounting
	scriptDirs, globErr := filepath.Glob(initScriptsBaseDir + "*")
	if globErr != nil {
		return fmt.Errorf("failed to find init script directories: %w", globErr)
	}

	for i := range scriptDirs {
		dir := scriptDirs[i]
		rmErr := os.RemoveAll(dir)
		if rmErr == nil {
			log.Infof("Removed init script directory: %s", dir)
			continue
		}
		log.Warnf("Failed to remove init script directory %s: %v", dir, rmErr)
	}
	return nil
}
// newDockerExecutionContainer sets up a new Docker container for the job
// execution described by params. It decodes the engine spec, picks a GPU
// vendor for the host config, builds the mounts, optionally wraps the
// entrypoint so provision scripts run first, and creates the container
// without starting it.
//
// It returns the value produced by the client's CreateContainer call
// (a string identifying the created container) or an error if any part of
// the setup fails.
func (e *Executor) newDockerExecutionContainer(
	ctx context.Context,
	params *types.ExecutionRequest,
) (string, error) {
	dockerArgs, err := DecodeSpec(params.EngineSpec)
	if err != nil {
		return "", fmt.Errorf("failed to decode docker engine spec: %w", err)
	}

	// TODO: Move the GPU selection block below to the allocator in future
	// Select the GPU with the highest available free VRAM and choose the GPU vendor for container's host config
	// TODO: use the hardware manager instantiated in the dms package
	hardwareManager := hardware.NewHardwareManager()
	machineResources, err := hardwareManager.GetMachineResources()
	if err != nil {
		return "", fmt.Errorf("failed to get machine resources: %w", err)
	}

	var chosenGPUVendor types.GPUVendor
	if len(machineResources.GPUs) == 0 {
		log.Infow("no GPUs available on the machine")
		chosenGPUVendor = types.GPUVendorNone
	} else {
		// Essential for multi-vendor GPU nodes. For example,
		// if a machine has an 8 GB NVIDIA and a 16 GB Intel GPU, the latter should be used first.
		// Even for machines with a single GPU, this is important as integrated GPUs would also be commonly detected.
		maxFreeVRAMGpu, gpuErr := machineResources.GPUs.MaxFreeVRAMGPU()
		if gpuErr != nil {
			// Fall back to a GPU-less host config, but leave a trace of why.
			log.Warnf("failed to select GPU with the most free VRAM, continuing without GPU: %v", gpuErr)
			chosenGPUVendor = types.GPUVendorNone
		} else {
			chosenGPUVendor = maxFreeVRAMGpu.Vendor
		}
	}

	containerConfig := container.Config{
		Image:      dockerArgs.Image,
		Env:        dockerArgs.Environment,
		Entrypoint: dockerArgs.Entrypoint,
		Cmd:        dockerArgs.Cmd,
		Labels:     e.containerLabels(params.JobID, params.ExecutionID),
		WorkingDir: dockerArgs.WorkingDirectory,
		// TODO (Tty): tty currently breaks the logs and consequently the `Run()` methods and `GetLogStream()`.
		// to enable Tty, besides setting to true, we must handle the logs correctly.
		// Needs to be true for applications such as Jupyter or Gradio to work correctly. See issue #459 for details.
		// Tty: true,
	}

	mounts, err := makeContainerMounts(params.Inputs, params.Outputs, params.ResultsDir)
	if err != nil {
		return "", fmt.Errorf("failed to create container mounts: %w", err)
	}

	initScriptsDir, err := prepareInitScripts(params.ProvisionScripts, params.ExecutionID)
	if err != nil {
		return "", fmt.Errorf("failed to prepare init scripts: %w", err)
	}

	if initScriptsDir != "" {
		oldEntryPoint := containerConfig.Entrypoint
		oldCmd := containerConfig.Cmd

		// Execute init scripts first, then chain the original entrypoint and
		// command through the shell.
		containerConfig.Entrypoint = []string{"/bin/sh", "-c"}
		containerConfig.Cmd = []string{
			fmt.Sprintf("%s/run_provision_scripts.sh && %s %s",
				initScriptsDir,
				strings.Join(oldEntryPoint, " "),
				strings.Join(oldCmd, " ")),
		}

		// Add a mount for the init scripts; the directory is exposed inside
		// the container at the same path it has on the host.
		mounts = append(mounts, mount.Mount{
			Type:   mount.TypeBind,
			Source: initScriptsDir,
			Target: initScriptsDir,
		})
	}

	log.Infof("Adding %d GPUs to request", len(params.Resources.GPUs))

	hostConfig := configureHostConfig(chosenGPUVendor, params, mounts)

	executionContainer, err := e.client.CreateContainer(
		ctx,
		&containerConfig,
		&hostConfig,
		nil,
		nil,
		labelExecutionValue(e.ID, params.JobID, params.ExecutionID),
		true,
	)
	if err != nil {
		return "", fmt.Errorf("failed to create container: %w", err)
	}

	return executionContainer, nil
}
// prepareInitScripts writes the given init scripts into a per-execution
// directory (initScriptsBaseDir + id) together with a wrapper script,
// run_provision_scripts.sh, that announces and runs each of them.
//
// scripts maps a file name to the script content. Names must be plain file
// names: empty names, "." / "..", names containing a path separator and the
// wrapper's own name are rejected, so a script can neither be written outside
// the directory nor overwrite the wrapper. Scripts are executed in
// lexicographic name order so the run order is deterministic.
//
// It returns the directory path, or "" with a nil error when scripts is empty.
func prepareInitScripts(scripts map[string][]byte, id string) (string, error) {
	if len(scripts) == 0 {
		return "", nil
	}

	const wrapperName = "run_provision_scripts.sh"

	names := make([]string, 0, len(scripts))
	for name := range scripts {
		if name == "." || name == ".." || name == wrapperName || filepath.Base(name) != name {
			return "", fmt.Errorf("invalid init script name %q", name)
		}
		names = append(names, name)
	}

	// Map iteration order is random; sort the names so the wrapper always
	// runs the scripts in the same order. A small in-place insertion sort is
	// used so this block needs no import beyond those already in scope.
	for i := 1; i < len(names); i++ {
		for j := i; j > 0 && names[j] < names[j-1]; j-- {
			names[j], names[j-1] = names[j-1], names[j]
		}
	}

	tempDir := initScriptsBaseDir + id
	if err := os.MkdirAll(tempDir, 0o700); err != nil {
		return "", fmt.Errorf("failed to create init scripts base directory: %w", err)
	}

	// Build the wrapper while writing the individual scripts.
	var wrapper strings.Builder
	wrapper.WriteString("#!/bin/sh\n\n")
	for _, name := range names {
		filename := filepath.Join(tempDir, name)
		if err := os.WriteFile(filename, scripts[name], 0o700); err != nil {
			return "", fmt.Errorf("failed to write init script %s: %w", name, err)
		}
		fmt.Fprintf(&wrapper, "echo 'Executing %s'\n", name)
		fmt.Fprintf(&wrapper, "%s\n", filename)
	}

	wrapperPath := filepath.Join(tempDir, wrapperName)
	if err := os.WriteFile(wrapperPath, []byte(wrapper.String()), 0o700); err != nil {
		return "", fmt.Errorf("failed to write wrapper script: %w", err)
	}

	return tempDir, nil
}
// configureHostConfig builds the container host configuration for an
// execution. Every configuration carries the given mounts and the CPU limits
// derived from the requested core count; the GPU vendor then decides which
// device access is added on top:
//   - NVIDIA: a device request with the "gpu" capability listing the indices
//     of the requested GPUs;
//   - AMD/ATI: /dev/kfd and /dev/dri are bound and mapped, and the "video"
//     group is added;
//   - Intel: /dev/dri is bound and mapped;
//   - anything else: CPU-only configuration.
func configureHostConfig(vendor types.GPUVendor, params *types.ExecutionRequest, mounts []mount.Mount) container.HostConfig {
	// Settings shared by every vendor.
	hostConfig := container.HostConfig{
		Mounts: mounts,
		Resources: container.Resources{
			NanoCPUs: int64(params.Resources.CPU.Cores * nanoCPUsPerCore),
			CPUCount: int64(params.Resources.CPU.Cores),
		},
	}

	// rwDevice maps a host device node to the same path in the container.
	rwDevice := func(path string) container.DeviceMapping {
		return container.DeviceMapping{
			PathOnHost:        path,
			PathInContainer:   path,
			CgroupPermissions: "rwm",
		}
	}

	switch vendor {
	case types.GPUVendorNvidia:
		requested := params.Resources.GPUs
		ids := make([]string, 0, len(requested))
		for _, gpu := range requested {
			ids = append(ids, fmt.Sprint(gpu.Index))
		}
		hostConfig.Resources.DeviceRequests = []container.DeviceRequest{
			{
				DeviceIDs:    ids,
				Capabilities: [][]string{{"gpu"}},
			},
		}

	case types.GPUVendorAMDATI:
		hostConfig.Binds = []string{
			"/dev/kfd:/dev/kfd",
			"/dev/dri:/dev/dri",
		}
		hostConfig.Resources.Devices = []container.DeviceMapping{
			rwDevice("/dev/kfd"),
			rwDevice("/dev/dri"),
		}
		hostConfig.GroupAdd = []string{"video"}

	// Intel GPUs are exposed by binding the whole /dev/dri directory instead
	// of resolving individual device paths from PCI addresses and symlinks.
	// This makes every Intel GPU visible to the container, which is simpler
	// and sufficient when granular per-GPU control is not required.
	case types.GPUVendorIntel:
		hostConfig.Binds = []string{
			"/dev/dri:/dev/dri",
		}
		hostConfig.Resources.Devices = []container.DeviceMapping{
			rwDevice("/dev/dri"),
		}
	}

	return hostConfig
}
// makeContainerMounts creates the bind mounts for the container from the
// input and output volumes of the execution request.
//
// Every input must be a bind volume (types.StorageVolumeTypeBind); any other
// type is rejected as unsupported. Inputs keep their ReadOnly flag.
// Every output must have a non-empty source and requires resultsDir to be
// set; resultsDir is created if it does not exist. Outputs are always
// mounted writable.
//
// It returns the list of mounts, or an error if any volume is invalid or the
// results directory cannot be created.
func makeContainerMounts(
	inputs []*types.StorageVolumeExecutor,
	outputs []*types.StorageVolumeExecutor,
	resultsDir string,
) ([]mount.Mount, error) {
	// the actual mounts we will give to the container
	// these are paths for both input and output data
	mounts := make([]mount.Mount, 0, len(inputs)+len(outputs))

	for _, input := range inputs {
		// Only bind volumes can be turned into bind mounts. The previous
		// check was inverted: it rejected bind volumes and mounted everything
		// else as if it were one.
		if input.Type != types.StorageVolumeTypeBind {
			return nil, fmt.Errorf("unsupported storage volume type: %s", input.Type)
		}
		mounts = append(mounts, mount.Mount{
			Type:     mount.TypeBind,
			Source:   input.Source,
			Target:   input.Target,
			ReadOnly: input.ReadOnly,
		})
	}

	for _, output := range outputs {
		if output.Source == "" {
			return nil, fmt.Errorf("output source is empty")
		}
		if resultsDir == "" {
			return nil, fmt.Errorf("results directory is empty")
		}
		if err := os.MkdirAll(resultsDir, os.ModePerm); err != nil {
			return nil, fmt.Errorf("failed to create results directory: %w", err)
		}
		mounts = append(mounts, mount.Mount{
			Type:   mount.TypeBind,
			Source: output.Source,
			Target: output.Target,
			// this is an output volume so can be written to
			ReadOnly: false,
		})
	}

	return mounts, nil
}
// containerLabels builds the label set attached to a container: the executor
// ID, the job label value and the execution label value for the given job and
// execution.
func (e *Executor) containerLabels(jobID string, executionID string) map[string]string {
	labels := make(map[string]string, 3)
	labels[labelExecutorName] = e.ID
	labels[labelJobID] = labelJobValue(e.ID, jobID)
	labels[labelExecutionID] = labelExecutionValue(e.ID, jobID, executionID)
	return labels
}
// labelJobValue builds the job label value: the executor ID and the job ID
// joined by an underscore.
func labelJobValue(executorID string, jobID string) string {
	// All operands are strings, so Sprint concatenates them without spaces.
	return fmt.Sprint(executorID, "_", jobID)
}
// labelExecutionValue builds the execution label value: the executor ID, the
// job ID and the execution ID joined by underscores.
func labelExecutionValue(executorID string, jobID string, executionID string) string {
	// All operands are strings, so Sprint concatenates them without spaces.
	return fmt.Sprint(executorID, "_", jobID, "_", executionID)
}
// FindRunningContainer asks the Docker client for the container whose
// execution label matches this executor, jobID and executionID.
// The client's result (a string identifying the container, or an error) is
// returned unchanged.
func (e *Executor) FindRunningContainer(
	ctx context.Context,
	jobID string,
	executionID string,
) (string, error) {
	return e.client.FindContainer(
		ctx,
		labelExecutionID,
		labelExecutionValue(e.ID, jobID, executionID),
	)
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package docker
import (
"bufio"
"context"
"errors"
"fmt"
"io"
"os"
"strconv"
"sync/atomic"
"time"
"github.com/docker/docker/api/types/container"
"gitlab.com/nunet/device-management-service/observability"
"gitlab.com/nunet/device-management-service/types"
)
// DestroyTimeout is the time budget handed to executionHandler.destroy when
// run tears the container down. executionHandler.kill also derives the
// container stop timeout from it.
var DestroyTimeout = time.Second * 10
// executionHandler manages the lifecycle and execution of a Docker container for a specific job.
// run drives the container from start to completion and records the outcome in result;
// the remaining methods (pause, resume, kill, destroy, outputStream, status) act on the
// same container through client.
type executionHandler struct {
// provided by the executor
// ID is used together with jobID and executionID to build the execution label value in destroy.
ID string
client *Client // Docker client for container management.
// meta data about the task
jobID string
executionID string
containerID string // Container targeted by every client call made by this handler.
resultsDir string // Directory to store execution results.
// synchronization
activeCh chan bool // Closed by run once the container has been started; outputStream waits on it.
waitCh chan bool // Closed by run's deferred cleanup when run returns, whether it succeeded or failed.
running *atomic.Bool // Set to true when run begins and back to false when it returns.
// result of the execution
// Written by run; nil until run has recorded an outcome. Read by status and by the executor's doWait.
result *types.ExecutionResult
// TTY setting
TTYEnabled bool // Indicates if TTY is enabled for the container; when true run captures only the stdout pipe.
// others
initScripts map[string][]byte
}
// active reports the current value of the handler's running flag, which run
// sets while it is executing.
func (h *executionHandler) active() bool {
	isRunning := h.running.Load()
	return isRunning
}
// run starts the container and handles its execution lifecycle.
//
// It marks the handler as running, starts the container, closes activeCh to
// signal that the container is up, and then waits for the container to
// finish, for the wait to fail, or for ctx to be cancelled. The outcome is
// stored in h.result: a failed result for start/wait/cancellation errors, a
// result carrying the exit code and an error message for inspect, OOM or log
// failures, and otherwise the exit code together with the captured
// stdout/stderr.
//
// On return, in every case, the container is destroyed (bounded by
// DestroyTimeout), the running flag is cleared and waitCh is closed.
func (h *executionHandler) run(ctx context.Context) {
	endTrace := observability.StartTrace("docker_execution_handler_run_duration")
	defer endTrace()

	h.running.Store(true)

	defer func() {
		if err := h.destroy(DestroyTimeout); err != nil {
			log.Warnf("failed to destroy container: %v\n", err)
		}
		h.running.Store(false)
		close(h.waitCh)
	}()

	if err := h.client.StartContainer(ctx, h.containerID); err != nil {
		h.result = types.NewFailedExecutionResult(fmt.Errorf("failed to start container: %v", err))
		log.Errorw("docker_execution_handler_run_failure", "error", err)
		return
	}

	close(h.activeCh) // Indicate that the container has started.
	log.Infow("docker_execution_handler_run_success", "executionID", h.executionID)

	var containerError error
	var containerExitStatusCode int64

	// Wait for the container to finish or for an execution error.
	statusCh, errCh := h.client.WaitContainer(ctx, h.containerID)
	select {
	case <-ctx.Done():
		// The value received from Done() is an empty struct; the reason for
		// the cancellation is available from ctx.Err().
		h.result = types.NewFailedExecutionResult(fmt.Errorf("execution cancelled: %w", ctx.Err()))
		log.Errorw("docker_execution_handler_run_failure_cancelled", "executionID", h.executionID)
		return
	case err := <-errCh:
		log.Errorf("error while waiting for container: %v\n", err)
		h.result = types.NewFailedExecutionResult(
			fmt.Errorf("failed to wait for container: %v", err),
		)
		return
	case exitStatus := <-statusCh:
		containerExitStatusCode = exitStatus.StatusCode
		containerJSON, err := h.client.InspectContainer(ctx, h.containerID)
		if err != nil {
			h.result = &types.ExecutionResult{
				ExitCode: int(containerExitStatusCode),
				ErrorMsg: err.Error(),
			}
			log.Errorw("docker_execution_handler_inspect_container_failure", "error", err)
			return
		}
		if containerJSON.State.OOMKilled {
			containerError = errors.New("container was killed due to OOM")
			h.result = &types.ExecutionResult{
				ExitCode: int(containerExitStatusCode),
				ErrorMsg: containerError.Error(),
			}
			log.Errorw("docker_execution_handler_container_oom_killed", "executionID", h.executionID)
			return
		}
		if exitStatus.Error != nil {
			containerError = errors.New(exitStatus.Error.Message)
		}
	}

	// Follow container logs to capture stdout and stderr.
	stdoutPipe, stderrPipe, logsErr := h.client.FollowLogs(ctx, h.containerID)
	if logsErr != nil {
		followError := fmt.Errorf("failed to follow container logs: %w", logsErr)
		if containerError != nil {
			// Report both the container's own error and the log failure.
			h.result = &types.ExecutionResult{
				ExitCode: int(containerExitStatusCode),
				ErrorMsg: fmt.Sprintf(
					"container error: '%s'. logs error: '%s'",
					containerError,
					followError,
				),
			}
		} else {
			h.result = &types.ExecutionResult{
				ExitCode: int(containerExitStatusCode),
				ErrorMsg: followError.Error(),
			}
		}
		log.Errorw("docker_execution_handler_follow_logs_failure", "error", logsErr)
		return
	}

	// Initialize the result with the exit status code.
	h.result = types.NewExecutionResult(int(containerExitStatusCode))

	// Capture the logs based on the TTY setting. The read error is ignored on
	// purpose: reading up to a NUL delimiter is expected to end with EOF, and
	// whatever was read before that is kept as the output.
	if h.TTYEnabled {
		// TTY combines stdout and stderr, read from stdoutPipe only.
		h.result.STDOUT, _ = bufio.NewReader(stdoutPipe).ReadString('\x00') // EOF delimiter
	} else {
		// Read from stdout and stderr separately.
		h.result.STDOUT, _ = bufio.NewReader(stdoutPipe).ReadString('\x00') // EOF delimiter
		h.result.STDERR, _ = bufio.NewReader(stderrPipe).ReadString('\x00')
	}

	log.Infow("docker_execution_handler_run_logs_success", "executionID", h.executionID)
}
// pause asks the Docker client to pause the handler's container and returns
// the client's error, if any.
func (h *executionHandler) pause(ctx context.Context) error {
	err := h.client.PauseContainer(ctx, h.containerID)
	return err
}
// resume asks the Docker client to resume the handler's container and returns
// the client's error, if any.
func (h *executionHandler) resume(ctx context.Context) error {
	err := h.client.ResumeContainer(ctx, h.containerID)
	return err
}
// kill asks the Docker client to stop the handler's container, allowing it
// DestroyTimeout to stop before the stop options' timeout expires.
// It returns the client's error if the stop request fails.
func (h *executionHandler) kill(ctx context.Context) error {
	endTrace := observability.StartTrace("docker_execution_handler_kill_duration")
	defer endTrace()

	// StopOptions.Timeout is expressed in whole seconds. Converting the
	// Duration directly with int() would pass its nanosecond count instead
	// (10s would become 10,000,000,000 seconds).
	timeout := int(DestroyTimeout.Seconds())
	stopOptions := container.StopOptions{
		Timeout: &timeout,
	}

	err := h.client.StopContainer(ctx, h.containerID, stopOptions)
	if err != nil {
		log.Errorw("docker_execution_handler_kill_failure", "error", err, "executionID", h.executionID)
		return err
	}

	log.Infow("docker_execution_handler_kill_success", "executionID", h.executionID)
	return nil
}
// destroy tears down everything belonging to this execution, using a fresh
// background context limited to timeout. The steps run in order and the first
// failure aborts the rest:
//  1. stop the container (kill);
//  2. remove the container;
//  3. delete the execution's init script directory;
//  4. remove remaining Docker objects labelled with this execution.
func (h *executionHandler) destroy(timeout time.Duration) error {
	defer observability.StartTrace("docker_execution_handler_destroy_duration")()

	cleanupCtx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	// stop the container
	if killErr := h.kill(cleanupCtx); killErr != nil {
		log.Errorw("docker_execution_handler_destroy_failure", "error", killErr, "executionID", h.executionID)
		return fmt.Errorf("failed to kill container (%s): %w", h.containerID, killErr)
	}

	if rmErr := h.client.RemoveContainer(cleanupCtx, h.containerID); rmErr != nil {
		log.Errorw("docker_execution_handler_destroy_failure", "error", rmErr, "executionID", h.executionID)
		return rmErr
	}

	if dirErr := os.RemoveAll(initScriptsBaseDir + h.executionID); dirErr != nil {
		return dirErr
	}

	// Remove related objects like networks or volumes created for this execution.
	executionLabel := labelExecutionValue(h.ID, h.jobID, h.executionID)
	if objErr := h.client.RemoveObjectsWithLabel(cleanupCtx, labelExecutionID, executionLabel); objErr != nil {
		log.Errorw("docker_execution_handler_destroy_failure", "error", objErr, "executionID", h.executionID)
		return objErr
	}

	log.Infow("docker_execution_handler_destroy_success", "executionID", h.executionID)
	return nil
}
// outputStream returns a reader over the container's output.
//
// When request.Tail is set, only output produced from now on is requested;
// otherwise output is requested from timestamp "1", i.e. effectively all of
// it. request.Follow is passed through to the client.
//
// The call blocks until the container has been started (activeCh closed) or
// ctx is done, in which case ctx.Err() is returned.
func (h *executionHandler) outputStream(
	ctx context.Context,
	request types.LogStreamRequest,
) (io.ReadCloser, error) {
	defer observability.StartTrace("docker_execution_handler_output_stream_duration")()

	// The timestamp is taken before waiting for the container to start.
	var since string
	switch {
	case request.Tail:
		since = strconv.FormatInt(time.Now().Unix(), 10)
	default:
		// Start of UNIX time: get all logs.
		since = "1"
	}

	// Ensure the container is active before attempting to stream logs.
	select {
	case <-h.activeCh:
	case <-ctx.Done():
		log.Errorw("docker_execution_handler_output_stream_canceled", "executionID", h.executionID)
		return nil, ctx.Err()
	}

	// Gets the underlying reader, and provides data since the value of the `since` timestamp.
	stream, err := h.client.GetOutputStream(ctx, h.containerID, since, request.Follow)
	if err == nil {
		log.Infow("docker_execution_handler_output_stream_success", "executionID", h.executionID)
		return stream, nil
	}

	log.Errorw("docker_execution_handler_output_stream_failure", "error", err, "executionID", h.executionID)
	return nil, err
}
// status reports the execution status.
//
// Once run has recorded a result, the status is derived from its exit code
// alone. Before that, the container is inspected and its Docker state is
// mapped: created -> pending, running -> running, paused -> paused,
// exited -> success or failed depending on the exit code, dead -> failed.
// Any other state, or an inspect failure, yields a failed status together
// with an error.
func (h *executionHandler) status(ctx context.Context) (types.ExecutionStatus, error) {
	if res := h.result; res != nil {
		if res.ExitCode != types.ExecutionStatusCodeSuccess {
			return types.ExecutionStatusFailed, nil
		}
		return types.ExecutionStatusSuccess, nil
	}

	info, err := h.client.InspectContainer(ctx, h.containerID)
	if err != nil {
		return types.ExecutionStatusFailed, fmt.Errorf("failed to get container status: %v", err)
	}

	state := info.State.Status
	switch {
	case state == "created":
		return types.ExecutionStatusPending, nil
	case state == "running":
		return types.ExecutionStatusRunning, nil
	case state == "paused":
		return types.ExecutionStatusPaused, nil
	case state == "exited" && info.State.ExitCode == 0:
		return types.ExecutionStatusSuccess, nil
	case state == "exited", state == "dead":
		return types.ExecutionStatusFailed, nil
	}
	return types.ExecutionStatusFailed, fmt.Errorf("unknown container status: %s", info.State.Status)
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package docker
import (
"encoding/json"
"fmt"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils/validate"
)
// Parameter keys of a docker engine spec. The builder methods in this file
// store values under these keys, and each key equals the JSON field name of
// the corresponding EngineSpec field, which is how DecodeSpec maps the params
// back onto the struct.
const (
EngineKeyImage = "image"
EngineKeyEntrypoint = "entrypoint"
EngineKeyCmd = "cmd"
EngineKeyEnvironment = "environment"
EngineKeyWorkingDirectory = "working_directory"
)
// EngineSpec contains necessary parameters to execute a docker job.
// Only Image is mandatory (see Validate); every field is omitted from the
// JSON encoding when empty.
type EngineSpec struct {
// Image is the container image to run; it should be pullable by docker.
Image string `json:"image,omitempty"`
// Entrypoint optionally overrides the image's default entrypoint.
Entrypoint []string `json:"entrypoint,omitempty"`
// Cmd specifies the command to run in the container.
Cmd []string `json:"cmd,omitempty"`
// Environment is the list of environment variables to run the container with.
Environment []string `json:"environment,omitempty"`
// WorkingDirectory is the working directory inside the container.
WorkingDirectory string `json:"working_directory,omitempty"`
}
// Validate reports whether the engine spec can be used. The only requirement
// is a non-blank Image; an error is returned otherwise.
func (c EngineSpec) Validate() error {
	if !validate.IsBlank(c.Image) {
		return nil
	}
	return fmt.Errorf("invalid docker engine params: image cannot be empty")
}
// DecodeSpec decodes a spec config into a docker engine spec.
//
// The spec must be non-nil, of the docker executor type, and carry non-nil
// params. The params are round-tripped through JSON into an EngineSpec, which
// is then validated.
//
// It returns the decoded EngineSpec together with the validation error, if
// any. For every other failure a zero EngineSpec is returned with the error.
func DecodeSpec(spec *types.SpecConfig) (EngineSpec, error) {
	// Guard against a nil spec: the checks below dereference it.
	if spec == nil {
		return EngineSpec{}, fmt.Errorf("invalid docker engine spec: spec cannot be nil")
	}

	if !spec.IsType(string(types.ExecutorTypeDocker)) {
		return EngineSpec{}, fmt.Errorf(
			"invalid docker engine type. expected %s, but received: %s",
			types.ExecutorTypeDocker,
			spec.Type,
		)
	}

	inputParams := spec.Params
	if inputParams == nil {
		return EngineSpec{}, fmt.Errorf("invalid docker engine params: params cannot be nil")
	}

	paramBytes, err := json.Marshal(inputParams)
	if err != nil {
		return EngineSpec{}, fmt.Errorf("failed to encode docker engine params: %w", err)
	}

	// Decode into a value rather than a nil pointer: a JSON null would leave
	// a pointer nil and make the dereference below panic.
	var dockerSpec EngineSpec
	if err := json.Unmarshal(paramBytes, &dockerSpec); err != nil {
		return EngineSpec{}, fmt.Errorf("failed to decode docker engine params: %w", err)
	}

	return dockerSpec, dockerSpec.Validate()
}
// EngineBuilder constructs the SpecConfig of a Docker engine using the
// Builder pattern. It wraps a *types.SpecConfig: each With* method stores one
// parameter on it, and Build returns it.
type EngineBuilder struct {
// eb is the spec config being built.
eb *types.SpecConfig
}
// NewDockerEngineBuilder returns an EngineBuilder whose spec config has the
// docker executor type and the given image stored under EngineKeyImage.
func NewDockerEngineBuilder(image string) *EngineBuilder {
	spec := types.NewSpecConfig(string(types.ExecutorTypeDocker))
	spec.WithParam(EngineKeyImage, image)

	builder := &EngineBuilder{eb: spec}
	return builder
}
// WithEntrypoint stores the given entrypoint under EngineKeyEntrypoint.
// The builder itself is returned so calls can be chained.
func (b *EngineBuilder) WithEntrypoint(e ...string) *EngineBuilder {
	spec := b.eb
	spec.WithParam(EngineKeyEntrypoint, e)
	return b
}
// WithCmd stores the given command under EngineKeyCmd.
// The builder itself is returned so calls can be chained.
func (b *EngineBuilder) WithCmd(c ...string) *EngineBuilder {
	spec := b.eb
	spec.WithParam(EngineKeyCmd, c)
	return b
}
// WithEnvironment stores the given environment variables under
// EngineKeyEnvironment. The builder itself is returned so calls can be chained.
func (b *EngineBuilder) WithEnvironment(e ...string) *EngineBuilder {
	spec := b.eb
	spec.WithParam(EngineKeyEnvironment, e)
	return b
}
// WithWorkingDirectory stores the given working directory under
// EngineKeyWorkingDirectory. The builder itself is returned so calls can be
// chained.
func (b *EngineBuilder) WithWorkingDirectory(w string) *EngineBuilder {
	spec := b.eb
	spec.WithParam(EngineKeyWorkingDirectory, w)
	return b
}
// Build returns the spec config assembled so far. The same *types.SpecConfig
// the builder holds is returned, not a copy.
func (b *EngineBuilder) Build() *types.SpecConfig {
	spec := b.eb
	return spec
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
//go:build linux
// +build linux
package firecracker
import (
"context"
"fmt"
"os"
"syscall"
"time"
firecracker "github.com/firecracker-microvm/firecracker-go-sdk"
fcmodels "github.com/firecracker-microvm/firecracker-go-sdk/client/models"
)
// pidCheckTickTime is the interval at which DestroyVM re-checks whether the
// Firecracker process still has a PID after a shutdown was requested.
const pidCheckTickTime = 100 * time.Millisecond
// Client wraps the Firecracker SDK to provide high-level operations on Firecracker VMs.
// It has no fields; the machine to operate on is passed to each method.
type Client struct{}
// NewFirecrackerClient returns a new, empty Firecracker client. The returned
// error is always nil; start and success of the initialization are logged.
func NewFirecrackerClient() (*Client, error) {
	log.Infow("firecracker_client_init_started")
	c := new(Client)
	log.Infow("firecracker_client_init_success")
	return c, nil
}
// IsInstalled reports whether Firecracker is available on the host. It runs
// the Firecracker command with "--version" under a two-second timeout and
// returns true only when the command succeeds and prints a non-empty version.
func (c *Client) IsInstalled(ctx context.Context) bool {
	log.Infow("firecracker_client_is_installed_check_started")

	ctx, cancel := context.WithTimeout(ctx, 2*time.Second)
	defer cancel()

	versionCmd := firecracker.VMCommandBuilder{}.WithArgs([]string{"--version"}).Build(ctx)
	output, err := versionCmd.Output()
	// ProcessState is only consulted when Output returned no error.
	if err != nil || !versionCmd.ProcessState.Success() {
		log.Errorw("firecracker_client_is_installed_failure", "error", err)
		return false
	}

	if len(output) == 0 {
		log.Errorw("firecracker_client_is_installed_failure", "error", "version check failed")
		return false
	}

	log.Infow("firecracker_client_is_installed_success")
	return true
}
// CreateVM builds a Firecracker machine from cfg. The machine uses a process
// runner created for cfg.SocketPath. The machine is returned as created by
// the SDK; StartVM starts it. Any error from the SDK is logged and returned.
func (c *Client) CreateVM(
	ctx context.Context,
	cfg firecracker.Config,
) (*firecracker.Machine, error) {
	log.Infow("firecracker_create_vm_started", "socketPath", cfg.SocketPath)

	runner := firecracker.VMCommandBuilder{}.
		WithSocketPath(cfg.SocketPath).
		Build(ctx)

	machine, err := firecracker.NewMachine(ctx, cfg, firecracker.WithProcessRunner(runner))
	if err != nil {
		log.Errorw("firecracker_create_vm_failure", "error", err)
		return nil, err
	}

	log.Infow("firecracker_create_vm_success", "socketPath", cfg.SocketPath)
	return machine, nil
}
// StartVM starts the given Firecracker machine. A start failure is logged and
// returned to the caller.
func (c *Client) StartVM(ctx context.Context, m *firecracker.Machine) error {
	log.Infow("firecracker_start_vm_started")

	if err := m.Start(ctx); err != nil {
		log.Errorw("firecracker_start_vm_failure", "error", err)
		return err
	}

	log.Infow("firecracker_start_vm_success")
	return nil
}
// StopVM stops the Firecracker VM by calling m.StopVMM and returns its error.
// The context argument is accepted for signature symmetry with the other
// client methods but is not used.
func (c *Client) StopVM(_ context.Context, m *firecracker.Machine) error {
return m.StopVMM()
}
// ShutdownVM asks the VM to shut down through m.Shutdown and logs the
// outcome. The SDK error is returned unchanged.
func (c *Client) ShutdownVM(ctx context.Context, m *firecracker.Machine) error {
	log.Infow("firecracker_shutdown_vm_started")

	if shutdownErr := m.Shutdown(ctx); shutdownErr != nil {
		log.Errorw("firecracker_shutdown_vm_failure", "error", shutdownErr)
		return shutdownErr
	}

	log.Infow("firecracker_shutdown_vm_success")
	return nil
}
// DestroyVM shuts the VM down and makes sure its Firecracker process is gone.
//
// The sequence is:
//  1. If the machine reports no PID, there is nothing to do.
//  2. Request a shutdown through ShutdownVM; a failure is returned as is.
//  3. Poll m.PID every pidCheckTickTime until it stops reporting a PID or
//     timeout elapses.
//  4. On timeout, send SIGKILL to the last known PID.
//
// The socket file at m.Cfg.SocketPath is removed on every exit path.
// The returned error wraps the kill failure so callers can inspect it with
// errors.Is / errors.As.
func (c *Client) DestroyVM(
	ctx context.Context,
	m *firecracker.Machine,
	timeout time.Duration,
) error {
	log.Infow("firecracker_destroy_vm_started")
	// Best-effort removal: a missing socket file is not an error worth reporting.
	defer os.Remove(m.Cfg.SocketPath)

	// If the process is not running, return early.
	pid, _ := m.PID()
	if pid <= 0 {
		return nil
	}

	if err := c.ShutdownVM(ctx, m); err != nil {
		return err
	}

	pid, _ = m.PID()
	if pid <= 0 {
		log.Infow("firecracker_destroy_vm_no_pid")
		return nil
	}

	// Poll in the calling goroutine: the caller blocks on the outcome anyway,
	// so no helper goroutine or channel is needed.
	ticker := time.NewTicker(pidCheckTickTime)
	defer ticker.Stop()
	deadline := time.NewTimer(timeout)
	defer deadline.Stop()

	exited := false
poll:
	for {
		select {
		case <-deadline.C:
			break poll
		case <-ticker.C:
			if current, _ := m.PID(); current <= 0 {
				exited = true
				break poll
			}
		}
	}

	if !exited {
		// The shutdown request timed out, kill the process with SIGKILL.
		err := syscall.Kill(pid, syscall.SIGKILL)
		switch {
		case err == nil:
			log.Infow("firecracker_destroy_vm_kill_success")
		case err == syscall.ESRCH:
			// The process exited between the last poll and the kill; the goal
			// (no running process) is already met, so this is not a failure.
		default:
			log.Errorw("firecracker_destroy_vm_kill_failure", "error", err)
			return fmt.Errorf("failed to kill process: %w", err)
		}
	}

	log.Infow("firecracker_destroy_vm_success")
	return nil
}
// PauseVM pauses the Firecracker VM by delegating to m.PauseVM and returns
// its error unchanged.
func (c *Client) PauseVM(ctx context.Context, m *firecracker.Machine) error {
return m.PauseVM(ctx)
}
// ResumeVM resumes the Firecracker VM by delegating to m.ResumeVM and returns
// its error unchanged.
func (c *Client) ResumeVM(ctx context.Context, m *firecracker.Machine) error {
return m.ResumeVM(ctx)
}
// FindVM looks up a running Firecracker VM by its API socket path.
//
// It first checks that the socket file exists, then builds a Machine bound to
// that socket and queries the Firecracker API for the instance info. The
// machine is returned only when the reported state is "Running"; every other
// outcome yields an error. Errors from the SDK are wrapped with %w so callers
// can inspect the underlying cause.
func (c *Client) FindVM(ctx context.Context, socketPath string) (*firecracker.Machine, error) {
	log.Infow("firecracker_find_vm_started", "socketPath", socketPath)

	// Check if the socket file exists.
	if _, err := os.Stat(socketPath); err != nil {
		log.Errorw("firecracker_find_vm_failure", "error", err)
		return nil, fmt.Errorf("VM with socket path %v not found", socketPath)
	}

	// Create a new Firecracker machine instance.
	cmd := firecracker.VMCommandBuilder{}.WithSocketPath(socketPath).Build(ctx)
	machine, err := firecracker.NewMachine(
		ctx,
		firecracker.Config{SocketPath: socketPath},
		firecracker.WithProcessRunner(cmd),
	)
	if err != nil {
		log.Errorw("firecracker_find_vm_failure", "error", err)
		return nil, fmt.Errorf("failed to create machine with socket %s: %w", socketPath, err)
	}

	// Check if the VM is running by getting its instance info.
	info, err := machine.DescribeInstanceInfo(ctx)
	if err != nil {
		log.Errorw("firecracker_find_vm_failure", "error", err)
		return nil, fmt.Errorf("failed to get instance info for socket %s: %w", socketPath, err)
	}

	// State is a pointer in the API model; guard against a response without it
	// instead of panicking on the dereference below.
	if info.State == nil {
		log.Errorw("firecracker_find_vm_failure", "error", "instance info has no state")
		return nil, fmt.Errorf("VM with socket %s reported no state", socketPath)
	}

	if *info.State != fcmodels.InstanceInfoStateRunning {
		log.Errorw("firecracker_find_vm_failure", "error", "VM is not running", "state", *info.State)
		return nil, fmt.Errorf(
			"VM with socket %s is not running, current state: %s",
			socketPath,
			*info.State,
		)
	}

	log.Infow("firecracker_find_vm_success", "socketPath", socketPath)
	return machine, nil
}
// Original Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// Modified Copyright 2024, NuNet;
//
// Licensed under the Apache License, Version 2.0 (the "License"). You may
// not use this file except in compliance with the License. A copy of the
// License is located at
//
// http://aws.amazon.com/apache2.0/
//
// or in the "license" file accompanying this file. This file is distributed
// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
// express or implied. See the License for the specific language governing
// permissions and limitations under the License.
//
// NuNet modifications:
//
// Inserting custom scripts to be executed from within Firecracker when booting the VM
// with `/sbin/overlay-init`
package firecracker
import (
"fmt"
"os"
"path/filepath"
)
const (
	// overlayInit is the shell script used as the base of the custom init that
	// prepareCustomInit writes to disk. It mounts an overlay filesystem on top
	// of the read-only root and pivots into it.
	//
	// The script text starts immediately after the opening backtick: the "#!"
	// interpreter line is only honoured when it is the very first bytes of the
	// file, so no newline may precede it.
	overlayInit = `#!/bin/sh
# Parameters:
# 1. rw_root -- path where the read/write root is mounted
# 2. work_dir -- path to the overlay workdir (must be on same filesystem as rw_root)
# Overlay will be set up on /mnt, original root on /mnt/rom
pivot() {
local rw_root work_dir
rw_root="$1"
work_dir="$2"
/bin/mount \
-o noatime,lowerdir=/,upperdir=${rw_root},workdir=${work_dir} \
-t overlay "overlayfs:${rw_root}" /mnt
pivot_root /mnt /mnt/rom
}
# Overlay is configured under /overlay
# Global variable $overlay_root is expected to be set to either:
# "ram", which configures a tmpfs as the rw overlay layer (this is
# the default, if the variable is unset)
# - or -
# A block device name, relative to /dev, in which case it is assumed
# to contain an ext4 filesystem suitable for use as a rw overlay
# layer. e.g. "vdb"
do_overlay() {
local overlay_dir="/overlay"
if [ "$overlay_root" = ram ] ||
[ -z "$overlay_root" ]; then
/bin/mount -t tmpfs -o noatime,mode=0755 tmpfs /overlay
else
/bin/mount -t ext4 "/dev/$overlay_root" /overlay
fi
mkdir -p /overlay/root /overlay/work
pivot /overlay/root /overlay/work
}
# If we're given an overlay, ensure that it really exists. Panic if not.
if [ -n "$overlay_root" ] &&
[ "$overlay_root" != ram ] &&
[ ! -b "/dev/$overlay_root" ]; then
echo -n "FATAL: "
echo "Overlay root given as $overlay_root but /dev/$overlay_root does not exist"
exit 1
fi
do_overlay
# firecracker-containerd itself doesn't need /volumes but volume package
# uses that to share files between in-VM snapshotters.
mkdir /volumes
`
)
// prepareCustomInit writes the custom init script and the provisioning
// scripts it runs.
//
// The init script is overlayInit followed by a loop that executes every
// executable file found in initScriptsDir and finally execs /usr/sbin/init.
// It is written to customInitPath. Each entry of initScripts is written to
// initScriptsDir under its map key.
//
// Both the init script and the provisioning scripts are created with mode
// 0755: the generated loop skips any file that fails the `[ -x ]` test, so a
// non-executable script would never run.
//
// Script names must be plain file names. A name containing a path separator
// (or "." / "..") is rejected before anything is written, so a script can
// never be placed outside initScriptsDir.
func prepareCustomInit(initScripts map[string][]byte, customInitPath, initScriptsDir string) error {
	for name := range initScripts {
		if name == "." || name == ".." || filepath.Base(name) != name {
			return fmt.Errorf("invalid init script name %q: must be a plain file name", name)
		}
	}

	customInitContent := overlayInit +
		"\n# (NuNet) Execute init scripts\n" +
		fmt.Sprintf("for script in %s/*; do\n", initScriptsDir) +
		" if [ -x \"$script\" ]; then\n" +
		" \"$script\"\n" +
		" fi\n" +
		"done\n\n" +
		"exec /usr/sbin/init $@\n"

	if err := os.WriteFile(customInitPath, []byte(customInitContent), 0o755); err != nil {
		return fmt.Errorf("failed to write custom init: %w", err)
	}

	// TODO: is 0755 necessary here?
	if err := os.MkdirAll(initScriptsDir, 0o755); err != nil {
		return fmt.Errorf("failed to create init scripts directory: %w", err)
	}

	for name, content := range initScripts {
		if err := os.WriteFile(filepath.Join(initScriptsDir, name), content, 0o755); err != nil {
			return fmt.Errorf("failed to write init script %s: %w", name, err)
		}
	}
	return nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
//go:build linux
// +build linux
package firecracker
import (
"context"
"errors"
"fmt"
"io"
"os"
"strings"
"sync"
"sync/atomic"
"time"
firecracker "github.com/firecracker-microvm/firecracker-go-sdk"
fcmodels "github.com/firecracker-microvm/firecracker-go-sdk/client/models"
"go.uber.org/multierr"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils"
)
const (
// socketDir is the directory in which generateSocketPath places VM API sockets.
socketDir = "/tmp"
// customInitPrefixPath and initScriptsPrefixPath are suffixed with
// "-<executionID>" to form the per-execution custom init file and init
// scripts directory.
customInitPrefixPath = "/tmp/custom-init"
initScriptsPrefixPath = "/tmp/init_scripts"
// DefaultCPUCount and DefaultMemSize are exported defaults.
// NOTE(review): neither is referenced in the code visible here, and the unit
// of DefaultMemSize is not stated — confirm against callers.
DefaultCPUCount int64 = 1
DefaultMemSize int64 = 50 * 1024
// statusWaitTickTime is the polling interval and statusWaitTimeout the
// default timeout used by WaitForStatus.
statusWaitTickTime = 100 * time.Millisecond
statusWaitTimeout = 10 * time.Second
// defaultKernelArgs is prepended to the kernel arguments from the engine spec.
defaultKernelArgs = "ro console=ttyS0 noapic reboot=k panic=1 pci=off nomodules systemd.unified_cgroup_hierarchy=0 systemd.journald.forward_to_console systemd.log_color=false systemd.unit=firecracker.target"
)
// ErrNotInstalled is returned by NewExecutor when the Firecracker
// installation check fails.
var ErrNotInstalled = errors.New("firecracker is not installed")
// Executor manages the lifecycle of Firecracker VMs for execution requests.
// It keeps one executionHandler per execution ID and drives the VMs through
// a shared Client.
type Executor struct {
// ID identifies this executor; it is part of every socket path it generates.
ID string
handlers utils.SyncMap[string, *executionHandler] // Maps execution IDs to their handlers.
client *Client // Firecracker client for VM management.
}
// NewExecutor creates an Executor identified by id. It builds a Firecracker
// client and verifies that Firecracker is installed on the host; when the
// check fails ErrNotInstalled is returned.
func NewExecutor(ctx context.Context, id string) (*Executor, error) {
	log.Infow("firecracker_executor_init_started", "executorID", id)

	fcClient, clientErr := NewFirecrackerClient()
	switch {
	case clientErr != nil:
		log.Errorw("firecracker_executor_init_failure", "error", clientErr)
		return nil, clientErr
	case !fcClient.IsInstalled(ctx):
		log.Errorw("firecracker_executor_not_installed", "executorID", id)
		return nil, ErrNotInstalled
	}

	executor := &Executor{ID: id, client: fcClient}
	log.Infow("firecracker_executor_init_success", "executorID", id)
	return executor, nil
}
// List returns one entry per registered execution, carrying the execution ID
// and whether its handler is currently marked as running. The result is an
// empty, non-nil slice when there are no executions.
func (e *Executor) List() []types.ExecutionListItem {
	items := make([]types.ExecutionListItem, 0)

	collect := func(key, value any) bool {
		id := key.(string)
		h := value.(*executionHandler)
		items = append(items, types.ExecutionListItem{
			ExecutionID: id,
			Running:     h.running.Load(),
		})
		return true // keep iterating over all handlers
	}
	e.handlers.Range(collect)

	log.Infow("firecracker_executor_list", "executionCount", len(items))
	return items
}
// Start begins the execution of a request.
//
// An execution ID that already has a handler is rejected up front, whether or
// not its VM is still running. This guarantees that a second Start for the
// same execution can never replace a live handler and spawn a second run
// goroutine for it.
//
// Otherwise Start first looks for an already running VM for the execution
// (for example after a restart of this process, when no handler exists yet)
// and, failing that, creates a new VM. A handler is then registered and run
// in its own goroutine; Start returns without waiting for the VM to finish.
func (e *Executor) Start(ctx context.Context, request *types.ExecutionRequest) error {
	log.Infow("firecracker_start_execution", "jobID", request.JobID, "executionID", request.ExecutionID)

	if handler, ok := e.handlers.Get(request.ExecutionID); ok {
		if handler.active() {
			return errors.New("execution is already started")
		}
		return errors.New("execution is already completed")
	}

	// It's possible that this is being called due to a restart. We should check if the
	// VM is already running before creating a new one.
	machine, err := e.FindRunningVM(ctx, request.JobID, request.ExecutionID)
	if err != nil {
		machine, err = e.newFirecrackerExecutionVM(ctx, request)
		if err != nil {
			log.Errorw("firecracker_create_vm_failure", "error", err)
			return fmt.Errorf("failed to create new firecracker VM: %w", err)
		}
	}

	handler := &executionHandler{
		client:      e.client,
		ID:          e.ID,
		JobID:       request.JobID,
		executionID: request.ExecutionID,
		machine:     machine,
		resultsDir:  request.ResultsDir,
		waitCh:      make(chan bool),
		activeCh:    make(chan bool),
		running:     &atomic.Bool{},
	}

	// register the handler for this executionID
	e.handlers.Put(request.ExecutionID, handler)

	// run the VM.
	go handler.run(ctx)

	log.Infow("firecracker_start_execution_success", "jobID", request.JobID, "executionID", request.ExecutionID)
	return nil
}
// Pause pauses the VM of the execution identified by executionID.
// It returns an error when no handler is registered for that ID.
func (e *Executor) Pause(
	ctx context.Context,
	executionID string,
) error {
	if handler, found := e.handlers.Get(executionID); found {
		return handler.pause(ctx)
	}
	return fmt.Errorf("execution (%s) not found", executionID)
}
// Resume resumes the VM of the execution identified by executionID.
// It returns an error when no handler is registered for that ID.
func (e *Executor) Resume(
	ctx context.Context,
	executionID string,
) error {
	if handler, found := e.handlers.Get(executionID); found {
		return handler.resume(ctx)
	}
	return fmt.Errorf("execution (%s) not found", executionID)
}
// Wait returns a result channel and an error channel for the execution
// identified by executionID. Both channels are buffered with capacity one.
//
// When the execution is unknown, a "not found" error is placed on the error
// channel right away and neither channel is closed. Otherwise doWait is
// started in a goroutine; it delivers either the result or an error and then
// closes both channels.
func (e *Executor) Wait(
	ctx context.Context,
	executionID string,
) (<-chan *types.ExecutionResult, <-chan error) {
	resultCh := make(chan *types.ExecutionResult, 1)
	errCh := make(chan error, 1)

	handler, found := e.handlers.Get(executionID)
	if found {
		go e.doWait(ctx, resultCh, errCh, handler)
		return resultCh, errCh
	}

	errCh <- fmt.Errorf("execution (%s) not found", executionID)
	log.Errorw("firecracker_wait_execution_not_found", "executionID", executionID)
	return resultCh, errCh
}
// doWait blocks until the handler's wait channel is closed or ctx is done.
//
// On completion it forwards handler.result to out, or, when the result is
// nil, sends an error to errCh. On cancellation it sends ctx.Err() to errCh.
// Both channels are closed before the function returns (errCh first).
func (e *Executor) doWait(
	ctx context.Context,
	out chan *types.ExecutionResult,
	errCh chan error,
	handler *executionHandler,
) {
	log.Infow("firecracker_wait_execution", "executionID", handler.executionID)
	defer func() {
		close(errCh)
		close(out)
	}()

	select {
	case <-ctx.Done():
		// Relay the cancellation to the caller.
		errCh <- ctx.Err()
	case <-handler.waitCh:
		result := handler.result
		if result == nil {
			log.Errorw("firecracker_wait_execution_result_nil", "executionID", handler.executionID)
			errCh <- fmt.Errorf("execution (%s) result is nil", handler.executionID)
			return
		}
		log.Infow("firecracker_wait_execution_result_received", "executionID", handler.executionID)
		out <- result
	}
}
// Cancel asks the handler of the given execution to stop its VM (handler.kill).
// It returns an error when no handler is registered for executionID.
func (e *Executor) Cancel(ctx context.Context, executionID string) error {
	if handler, found := e.handlers.Get(executionID); found {
		return handler.kill(ctx)
	}

	log.Errorw("firecracker_cancel_execution_not_found", "executionID", executionID)
	return fmt.Errorf("failed to cancel execution (%s). execution not found", executionID)
}
// Run starts an execution and waits for it to complete in one call.
// It returns the execution result, or an error if starting or waiting fails
// or if the context is canceled.
//
// The channels returned by Wait are closed once the outcome has been
// delivered, and a receive from a closed channel succeeds immediately with
// the zero value. Each receive therefore checks whether the channel was
// closed: a closed channel is disabled (set to nil) and the other one is
// consulted, so a closed channel can never be mistaken for a nil result or a
// nil error.
func (e *Executor) Run(
	ctx context.Context,
	request *types.ExecutionRequest,
) (*types.ExecutionResult, error) {
	if err := e.Start(ctx, request); err != nil {
		return nil, err
	}

	resCh, errCh := e.Wait(ctx, request.ExecutionID)
	for resCh != nil || errCh != nil {
		select {
		case <-ctx.Done():
			return nil, ctx.Err()
		case out, ok := <-resCh:
			if !ok {
				resCh = nil // closed without a value; only errCh can still deliver
				continue
			}
			return out, nil
		case err, ok := <-errCh:
			if !ok {
				errCh = nil // closed without a value; only resCh can still deliver
				continue
			}
			return nil, err
		}
	}

	// Both channels were closed without delivering anything.
	return nil, fmt.Errorf("execution (%s) finished without a result", request.ExecutionID)
}
// GetLogStream exists only to satisfy the Executor interface. Log streaming
// is not implemented for Firecracker, so every call logs the fact and returns
// a nil reader together with an error.
func (e *Executor) GetLogStream(_ context.Context, _ types.LogStreamRequest) (io.ReadCloser, error) {
	log.Error("firecracker_get_log_stream_not_implemented")

	notImplemented := fmt.Errorf("GetLogStream is not implemented for Firecracker")
	return nil, notImplemented
}
// GetStatus reports the status of the execution identified by executionID by
// delegating to its handler. An empty status and an error are returned when
// no handler is registered for that ID.
func (e *Executor) GetStatus(ctx context.Context, executionID string) (types.ExecutionStatus, error) {
	if handler, found := e.handlers.Get(executionID); found {
		return handler.status(ctx)
	}
	return "", fmt.Errorf("execution (%s) not found", executionID)
}
// WaitForStatus polls the execution every statusWaitTickTime until its status
// compares greater than or equal to the requested one.
//
// timeout bounds the whole wait; when nil, statusWaitTimeout is used. The
// function returns an error when the execution is unknown, when the handler
// reports an error, when ctx is done, or when the timeout elapses.
//
// The timeout timer is created once, before the loop. Creating it inside the
// select would restart it on every ticker tick, so it would never fire for
// any timeout longer than the tick interval.
func (e *Executor) WaitForStatus(
	ctx context.Context,
	executionID string,
	status types.ExecutionStatus,
	timeout *time.Duration,
) error {
	waitTimeout := statusWaitTimeout
	if timeout != nil {
		waitTimeout = *timeout
	}

	handler, found := e.handlers.Get(executionID)
	if !found {
		return fmt.Errorf("execution (%s) not found", executionID)
	}

	ticker := time.NewTicker(statusWaitTickTime)
	defer ticker.Stop()
	deadline := time.NewTimer(waitTimeout)
	defer deadline.Stop()

	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-deadline.C:
			return fmt.Errorf("execution (%s) did not reach status %s", executionID, status)
		case <-ticker.C:
			s, err := handler.status(ctx)
			if err != nil {
				return err
			}
			// NOTE(review): this relies on the ordering of ExecutionStatus
			// values matching the lifecycle order — confirm against the type.
			if s >= status {
				return nil
			}
		}
	}
}
// Cleanup destroys the VM of every registered handler. The handlers are
// destroyed concurrently, each with vmDestroyTimeout, and all resulting
// errors are combined into the returned error (nil when every destroy
// succeeded).
func (e *Executor) Cleanup() error {
	log.Infow("firecracker_cleanup_started", "executorID", e.ID)

	var pending sync.WaitGroup
	results := make(chan error, len(e.handlers.Keys()))

	e.handlers.Iter(func(_ string, h *executionHandler) bool {
		pending.Add(1)
		go func() {
			defer pending.Done()
			results <- h.destroy(vmDestroyTimeout)
		}()
		return true
	})

	// Close the channel once every destroy has reported, ending the loop below.
	go func() {
		pending.Wait()
		close(results)
	}()

	var combined error
	for destroyErr := range results {
		combined = multierr.Append(combined, destroyErr)
	}

	log.Infow("firecracker_cleanup_complete", "executorID", e.ID)
	return combined
}
// newFirecrackerExecutionVM is an internal method called by Start to set up a
// new Firecracker VM for the job execution. It decodes the engine spec,
// writes the per-execution custom init file and init scripts, assembles the
// kernel arguments, machine configuration and drives, and creates the VM
// without starting it.
//
// The custom init file and init scripts directory are normally removed by the
// execution handler when the VM is destroyed. When this function fails after
// they have been (even partially) written, no handler will ever exist for the
// execution, so they are removed here before returning the error.
func (e *Executor) newFirecrackerExecutionVM(
	ctx context.Context,
	params *types.ExecutionRequest,
) (*firecracker.Machine, error) {
	log.Infow("firecracker_create_vm_started", "executionID", params.ExecutionID)

	fcArgs, err := DecodeSpec(params.EngineSpec)
	if err != nil {
		log.Errorw("firecracker_create_vm_decode_spec_failure", "error", err)
		return nil, fmt.Errorf("failed to decode firecracker engine spec: %w", err)
	}

	customInitPath := fmt.Sprintf("%s-%s", customInitPrefixPath, params.ExecutionID)
	initScriptsPath := fmt.Sprintf("%s-%s", initScriptsPrefixPath, params.ExecutionID)

	// removeInitFiles undoes prepareCustomInit on the failure paths below.
	removeInitFiles := func() {
		for _, path := range []string{customInitPath, initScriptsPath} {
			if rmErr := os.RemoveAll(path); rmErr != nil {
				log.Errorw("firecracker_create_vm_cleanup_failure", "path", path, "error", rmErr)
			}
		}
	}

	if err = prepareCustomInit(params.ProvisionScripts, customInitPath, initScriptsPath); err != nil {
		removeInitFiles()
		return nil, fmt.Errorf("failed to prepare custom init: %w", err)
	}

	// Default arguments first, then the ones from the spec, then the init override.
	// NOTE(review): init= points at customInitPrefixPath, not at the
	// per-execution customInitPath written above — confirm which path the
	// guest is expected to find.
	fcArgs.KernelArgs = strings.Join(
		[]string{defaultKernelArgs, fcArgs.KernelArgs, fmt.Sprintf("init=%s", customInitPrefixPath)},
		" ",
	)

	fcConfig := firecracker.Config{
		VMID:            params.ExecutionID,
		SocketPath:      e.generateSocketPath(params.JobID, params.ExecutionID),
		KernelImagePath: fcArgs.KernelImage,
		InitrdPath:      fcArgs.Initrd,
		KernelArgs:      fcArgs.KernelArgs,
		LogLevel:        "Error",
		MachineCfg: fcmodels.MachineConfiguration{
			VcpuCount:  firecracker.Int64(int64(params.Resources.CPU.Cores)),
			MemSizeMib: firecracker.Int64(int64(params.Resources.RAM.Size)),
		},
	}

	mounts, err := makeVMMounts(
		fcArgs.RootFileSystem,
		params.Inputs,
		params.Outputs,
		params.ResultsDir,
		customInitPath,
		initScriptsPath,
	)
	if err != nil {
		removeInitFiles()
		log.Errorw("firecracker_create_vm_mounts_failure", "error", err)
		return nil, fmt.Errorf("failed to create VM mounts: %w", err)
	}
	fcConfig.Drives = mounts

	machine, err := e.client.CreateVM(ctx, fcConfig)
	if err != nil {
		removeInitFiles()
		log.Errorw("firecracker_create_vm_failure", "error", err)
		return nil, fmt.Errorf("failed to create VM: %w", err)
	}

	log.Infow("firecracker_create_vm_success", "executionID", params.ExecutionID)
	return machine, nil
}
// makeVMMounts builds the drive list for a VM: the root file system, one
// drive per input volume, the custom init file, the init scripts path, and
// one writable drive per output volume. When there are outputs, resultsDir
// must be non-empty and is created if missing.
//
// DrivesBuilder.AddDrive returns the updated builder rather than modifying
// the one it is called on, so every call's result is assigned back. Dropping
// the return value would leave the builder with the root drive only.
//
// NOTE(review): newFirecrackerExecutionVM passes a directory as
// initScriptsPath — confirm that Firecracker accepts it as a drive backing
// path.
//
// On error the returned drive slice is nil.
func makeVMMounts(
	rootFileSystem string,
	inputs []*types.StorageVolumeExecutor,
	outputs []*types.StorageVolumeExecutor,
	resultsDir, customInitPath, initScriptsPath string,
) ([]fcmodels.Drive, error) {
	drivesBuilder := firecracker.NewDrivesBuilder(rootFileSystem)

	for _, input := range inputs {
		drivesBuilder = drivesBuilder.AddDrive(input.Source, input.ReadOnly)
	}

	drivesBuilder = drivesBuilder.AddDrive(customInitPath, true)
	drivesBuilder = drivesBuilder.AddDrive(initScriptsPath, true)

	for _, output := range outputs {
		if output.Source == "" {
			return nil, fmt.Errorf("output source is empty")
		}
		if resultsDir == "" {
			return nil, fmt.Errorf("results directory is empty")
		}
		if err := os.MkdirAll(resultsDir, os.ModePerm); err != nil {
			return nil, fmt.Errorf("failed to create results directory: %w", err)
		}
		drivesBuilder = drivesBuilder.AddDrive(output.Source, false)
	}

	return drivesBuilder.Build(), nil
}
// FindRunningVM looks up the running VM for the given job and execution by
// the socket path this executor would have generated for it. It returns the
// Machine instance if found, or an error otherwise.
func (e *Executor) FindRunningVM(
	ctx context.Context,
	jobID string,
	executionID string,
) (*firecracker.Machine, error) {
	socketPath := e.generateSocketPath(jobID, executionID)
	return e.client.FindVM(ctx, socketPath)
}
// generateSocketPath returns the API socket path for an execution:
// "<socketDir>/<executorID>_<jobID>_<executionID>.sock".
func (e *Executor) generateSocketPath(jobID string, executionID string) string {
	fileName := fmt.Sprintf("%s_%s_%s.sock", e.ID, jobID, executionID)
	return socketDir + "/" + fileName
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
//go:build linux
// +build linux
package firecracker
import (
"context"
"fmt"
"os"
"sync/atomic"
"time"
firecracker "github.com/firecracker-microvm/firecracker-go-sdk"
fcmodels "github.com/firecracker-microvm/firecracker-go-sdk/client/models"
"gitlab.com/nunet/device-management-service/types"
)
const (
// vmDestroyTimeout is the timeout passed to executionHandler.destroy when a
// VM is torn down after a run or during executor cleanup.
vmDestroyTimeout = time.Second * 10
)
// executionHandler is a struct that holds the necessary information to manage the execution of a firecracker VM.
// One handler exists per execution; its run method drives the VM and stores
// the outcome in result before closing waitCh.
type executionHandler struct {
//
// provided by the executor
ID string
client *Client
// meta data about the task
JobID string
executionID string
machine *firecracker.Machine
resultsDir string
// synchronization
activeCh chan bool // Closed by run once the VM has been started successfully.
waitCh chan bool // Closed by run when it returns, after result has been set.
running *atomic.Bool // True from the start of run until its cleanup has finished.
// result of the execution
result *types.ExecutionResult
}
// active reports the handler's running flag, which run sets to true on entry
// and back to false when it finishes.
func (h *executionHandler) active() bool {
return h.running.Load()
}
// run starts the VM, blocks until it exits, and records the outcome in
// h.result.
//
// Regardless of the outcome, the deferred cleanup destroys the VM, clears the
// running flag and closes waitCh, in that order. activeCh is closed only when
// the VM started successfully.
func (h *executionHandler) run(ctx context.Context) {
	h.running.Store(true)
	defer func() {
		destroyErr := h.destroy(vmDestroyTimeout)
		if destroyErr != nil {
			log.Warnf("failed to destroy VM: %v", destroyErr)
		}
		h.running.Store(false)
		close(h.waitCh)
	}()

	log.Infow("firecracker_execution_starting", "executionID", h.executionID)
	startErr := h.client.StartVM(ctx, h.machine)
	if startErr != nil {
		h.result = types.NewFailedExecutionResult(fmt.Errorf("failed to start VM: %v", startErr))
		log.Errorw("firecracker_vm_start_failure", "error", startErr, "executionID", h.executionID)
		return
	}

	// Signal that the VM has started.
	close(h.activeCh)

	waitErr := h.machine.Wait(ctx)
	switch {
	case waitErr == nil:
		h.result = types.NewExecutionResult(types.ExecutionStatusCodeSuccess)
		log.Infow("firecracker_execution_success", "executionID", h.executionID)
	case ctx.Err() != nil:
		// The wait failed because the caller's context ended.
		h.result = types.NewFailedExecutionResult(fmt.Errorf("context closed while waiting on VM: %v", waitErr))
		log.Errorw("firecracker_execution_context_closed", "error", waitErr, "executionID", h.executionID)
	default:
		h.result = types.NewFailedExecutionResult(fmt.Errorf("failed to wait on VM: %v", waitErr))
		log.Errorw("firecracker_vm_wait_failure", "error", waitErr, "executionID", h.executionID)
	}
}
// kill stops the firecracker VM by requesting a shutdown through
// client.ShutdownVM; it does not remove any of the VM's resources.
func (h *executionHandler) kill(ctx context.Context) error {
log.Infow("firecracker_kill_vm", "executionID", h.executionID)
return h.client.ShutdownVM(ctx, h.machine)
}
// destroy removes the per-execution custom init file and init scripts
// directory, then destroys the VM through client.DestroyVM. Removal failures
// are only logged. The VM teardown runs under a fresh background context
// limited by timeout, and the same timeout is passed to DestroyVM.
func (h *executionHandler) destroy(timeout time.Duration) error {
	log.Infow("firecracker_destroy_vm", "executionID", h.executionID)

	destroyCtx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	// TODO: move this to executionHandler field
	initFile := fmt.Sprintf("%s-%s", customInitPrefixPath, h.executionID)
	scriptsDir := fmt.Sprintf("%s-%s", initScriptsPrefixPath, h.executionID)

	if rmErr := os.RemoveAll(initFile); rmErr != nil {
		log.Errorf("failed to remove custom init: %v", rmErr)
	}
	if rmErr := os.RemoveAll(scriptsDir); rmErr != nil {
		log.Errorf("failed to remove init scripts: %v", rmErr)
	}

	return h.client.DestroyVM(destroyCtx, h.machine, timeout)
}
// pause pauses the firecracker VM by delegating to client.PauseVM.
func (h *executionHandler) pause(ctx context.Context) error {
return h.client.PauseVM(ctx, h.machine)
}
// resume resumes the firecracker VM by delegating to client.ResumeVM.
func (h *executionHandler) resume(ctx context.Context) error {
return h.client.ResumeVM(ctx, h.machine)
}
// status reports the current status of the execution.
//
// While the handler is not running, the status comes from the stored result:
// success or failed when a result exists, pending otherwise. While it is
// running, the VM is queried through the Firecracker API and its instance
// state is mapped to an execution status. A failed query, a response without
// a state, or an unrecognized state yields ExecutionStatusFailed with an
// error.
func (h *executionHandler) status(ctx context.Context) (types.ExecutionStatus, error) {
	if !h.active() {
		if h.result == nil {
			return types.ExecutionStatusPending, nil
		}
		if h.result.ExitCode == types.ExecutionStatusCodeSuccess {
			return types.ExecutionStatusSuccess, nil
		}
		return types.ExecutionStatusFailed, fmt.Errorf("VM exited: %v", h.result.ErrorMsg)
	}

	info, err := h.machine.DescribeInstanceInfo(ctx)
	if err != nil {
		return types.ExecutionStatusFailed, fmt.Errorf("failed to get VM status: %v", err)
	}

	// State is a pointer in the API model; report a missing state as an error
	// instead of panicking on the dereference below.
	if info.State == nil {
		return types.ExecutionStatusFailed, fmt.Errorf("unknown VM state: state is missing")
	}

	switch *info.State {
	case fcmodels.InstanceInfoStateNotStarted:
		return types.ExecutionStatusPending, nil
	case fcmodels.InstanceInfoStateRunning:
		return types.ExecutionStatusRunning, nil
	case fcmodels.InstanceInfoStatePaused:
		return types.ExecutionStatusPaused, nil
	default:
		return types.ExecutionStatusFailed, fmt.Errorf("unknown VM state: %s", *info.State)
	}
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
//go:build linux
// +build linux
package firecracker
import (
"encoding/json"
"fmt"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils/validate"
)
// Parameter keys used by EngineBuilder when populating a SpecConfig.
// DecodeSpec later maps the parameters onto EngineSpec through its JSON tags.
const (
EngineKeyKernelImage = "kernel_image"
EngineKeyKernelArgs = "kernel_args"
EngineKeyRootFileSystem = "root_file_system"
// NOTE(review): this key is "initrd" while EngineSpec.Initrd carries the
// JSON tag "initrd_path"; the two do not match.
EngineKeyInitrd = "initrd"
EngineKeyMMDSMessage = "mmds_message"
)
// EngineSpec contains necessary parameters to execute a firecracker job.
// Validate requires RootFileSystem and KernelImage to be non-blank.
type EngineSpec struct {
// KernelImage is the path to the kernel image file.
KernelImage string `json:"kernel_image,omitempty"`
// Initrd is the path to the initial ramdisk file (JSON key "initrd_path").
Initrd string `json:"initrd_path,omitempty"`
// KernelArgs is the kernel command line arguments.
KernelArgs string `json:"kernel_args,omitempty"`
// RootFileSystem is the path to the root file system.
RootFileSystem string `json:"root_file_system,omitempty"`
// MMDSMessage is the MMDS message to be sent to the Firecracker VM.
MMDSMessage string `json:"mmds_message,omitempty"`
}
// Validate reports an error when a mandatory field of the engine spec is
// blank. The root file system is checked before the kernel image.
func (c EngineSpec) Validate() error {
	switch {
	case validate.IsBlank(c.RootFileSystem):
		return fmt.Errorf("invalid firecracker engine params: root_file_system cannot be empty")
	case validate.IsBlank(c.KernelImage):
		return fmt.Errorf("invalid firecracker engine params: kernel_image cannot be empty")
	default:
		return nil
	}
}
// DecodeSpec decodes a spec config into a firecracker engine spec.
//
// The spec must be non-nil, of the Firecracker executor type, and carry
// non-nil params. The params are round-tripped through JSON into an
// EngineSpec, which is then validated. The decoded spec is returned together
// with the validation error, if any.
//
// The initial ramdisk path is accepted under both the "initrd_path" key (the
// JSON tag of EngineSpec.Initrd) and the "initrd" key (EngineKeyInitrd, the
// key EngineBuilder.WithInitrd writes). "initrd_path" wins when both are set.
func DecodeSpec(spec *types.SpecConfig) (EngineSpec, error) {
	if spec == nil {
		return EngineSpec{}, fmt.Errorf("invalid firecracker engine spec: spec cannot be nil")
	}

	if !spec.IsType(types.ExecutorTypeFirecracker.String()) {
		return EngineSpec{}, fmt.Errorf(
			"invalid firecracker engine type. expected %s, but received: %s",
			types.ExecutorTypeFirecracker,
			spec.Type,
		)
	}

	inputParams := spec.Params
	if inputParams == nil {
		return EngineSpec{}, fmt.Errorf("invalid firecracker engine params: params cannot be nil")
	}

	paramBytes, err := json.Marshal(inputParams)
	if err != nil {
		return EngineSpec{}, fmt.Errorf("failed to encode firecracker engine params: %w", err)
	}

	// Decode into a value (not a nil pointer) so a JSON "null" payload cannot
	// lead to a nil dereference; it simply fails validation below.
	var decoded struct {
		EngineSpec
		BuilderInitrd string `json:"initrd,omitempty"`
	}
	if err = json.Unmarshal(paramBytes, &decoded); err != nil {
		return EngineSpec{}, fmt.Errorf("failed to decode firecracker engine params: %w", err)
	}

	firecrackerSpec := decoded.EngineSpec
	if firecrackerSpec.Initrd == "" {
		firecrackerSpec.Initrd = decoded.BuilderInitrd
	}

	return firecrackerSpec, firecrackerSpec.Validate()
}
// EngineBuilder is a struct that is used for constructing a SpecConfig
// specifically for Firecracker engines using the Builder pattern.
// It holds the SpecConfig being built; every With* method sets one parameter
// on it and returns the builder for chaining.
type EngineBuilder struct {
eb *types.SpecConfig
}
// NewFirecrackerEngineBuilder returns an EngineBuilder whose SpecConfig has
// the Firecracker executor type and the root file system parameter set to
// rootFileSystem.
func NewFirecrackerEngineBuilder(rootFileSystem string) *EngineBuilder {
	specConfig := types.NewSpecConfig(types.ExecutorTypeFirecracker.String())
	specConfig.WithParam(EngineKeyRootFileSystem, rootFileSystem)

	builder := &EngineBuilder{eb: specConfig}
	return builder
}
// WithRootFileSystem sets the root file system parameter
// (EngineKeyRootFileSystem) to e, replacing the value given to the constructor.
// It returns the builder for further chaining.
func (b *EngineBuilder) WithRootFileSystem(e string) *EngineBuilder {
b.eb.WithParam(EngineKeyRootFileSystem, e)
return b
}
// WithKernelImage sets the kernel image parameter (EngineKeyKernelImage) to e.
// It returns the builder for further chaining.
func (b *EngineBuilder) WithKernelImage(e string) *EngineBuilder {
b.eb.WithParam(EngineKeyKernelImage, e)
return b
}
// WithInitrd sets the initial ramdisk parameter (EngineKeyInitrd) to e.
// It returns the builder for further chaining.
func (b *EngineBuilder) WithInitrd(e string) *EngineBuilder {
b.eb.WithParam(EngineKeyInitrd, e)
return b
}
// WithKernelArgs sets the kernel arguments parameter (EngineKeyKernelArgs) to e.
// It returns the builder for further chaining.
func (b *EngineBuilder) WithKernelArgs(e string) *EngineBuilder {
b.eb.WithParam(EngineKeyKernelArgs, e)
return b
}
// WithMMDSMessage sets the MMDS message parameter (EngineKeyMMDSMessage) to e.
// It returns the builder for further chaining.
func (b *EngineBuilder) WithMMDSMessage(e string) *EngineBuilder {
b.eb.WithParam(EngineKeyMMDSMessage, e)
return b
}
// Build returns the SpecConfig accumulated by the builder. The returned
// pointer is the builder's own config, not a copy, so later builder calls
// operate on the same object.
func (b *EngineBuilder) Build() *types.SpecConfig {
return b.eb
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package null
import (
"context"
"io"
"time"
"gitlab.com/nunet/device-management-service/executor"
"gitlab.com/nunet/device-management-service/types"
)
// Executor is a no-op implementation of the executor.Executor interface.
// It holds no state; every method returns immediately.
type Executor struct{}
// NewExecutor creates a new no-op Executor. Both arguments are ignored and
// the returned error is always nil.
func NewExecutor(_ context.Context, _ string) (executor.Executor, error) {
return &Executor{}, nil
}
// Compile-time check that *Executor satisfies executor.Executor.
var _ executor.Executor = (*Executor)(nil)
// Start ignores the request, does nothing and returns nil.
func (e *Executor) Start(_ context.Context, _ *types.ExecutionRequest) error {
return nil
}
// Run ignores the request and returns a nil result together with a nil error.
// Callers must therefore not dereference the result without a nil check.
func (e *Executor) Run(_ context.Context, _ *types.ExecutionRequest) (*types.ExecutionResult, error) {
return nil, nil
}
// Wait returns a result channel and an error channel that are both already
// closed, so a receive on either one yields the zero value immediately.
func (e *Executor) Wait(_ context.Context, _ string) (<-chan *types.ExecutionResult, <-chan error) {
	results := make(chan *types.ExecutionResult)
	close(results)

	errs := make(chan error)
	close(errs)

	return results, errs
}
// Cancel ignores the execution ID, does nothing and returns nil.
func (e *Executor) Cancel(_ context.Context, _ string) error {
return nil
}
// GetLogStream returns an empty log stream and a nil error.
//
// The returned ReadCloser reports io.EOF on the first Read and its Close is a
// no-op. An empty io.MultiReader is used as the underlying reader because
// wrapping a nil io.Reader in io.NopCloser would panic as soon as a caller
// tried to read from the stream.
func (e *Executor) GetLogStream(_ context.Context, _ types.LogStreamRequest) (io.ReadCloser, error) {
	return io.NopCloser(io.MultiReader()), nil
}
// List returns a non-nil, zero-length slice of ExecutionListItem.
func (e *Executor) List() []types.ExecutionListItem {
	items := make([]types.ExecutionListItem, 0)
	return items
}
// Cleanup does nothing and returns nil; the null executor holds no resources.
func (e *Executor) Cleanup(_ context.Context) error {
return nil
}
// GetStatus ignores the execution ID and returns the empty ExecutionStatus
// ("") with a nil error.
func (e *Executor) GetStatus(_ context.Context, _ string) (types.ExecutionStatus, error) {
return "", nil
}
// Pause ignores the execution ID, does nothing and returns nil.
func (e *Executor) Pause(_ context.Context, _ string) error {
return nil
}
// Resume ignores the execution ID, does nothing and returns nil.
func (e *Executor) Resume(_ context.Context, _ string) error {
return nil
}
// WaitForStatus ignores all arguments and returns nil immediately, without
// waiting for anything.
func (e *Executor) WaitForStatus(_ context.Context, _ string, _ types.ExecutionStatus, _ *time.Duration) error {
return nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package backgroundtasks
import (
"sort"
"sync"
"time"
)
// Scheduler orchestrates the execution of tasks based on their triggers and priority.
// On every ticker tick the registered tasks are examined in descending
// priority order and those with a ready trigger are started in their own
// goroutine, up to maxRunningTasks at a time.
type Scheduler struct {
tasks map[int]*Task // Map of tasks by their ID.
runningTasks map[int]bool // Map to keep track of running tasks.
ticker *time.Ticker // Ticker for periodic checks of task triggers.
stopChan chan struct{} // Channel to signal stopping the scheduler.
maxRunningTasks int // Maximum number of tasks that can run concurrently.
lastTaskID int // Counter for assigning unique IDs to tasks.
mu sync.Mutex // Mutex to protect access to task maps.
}
// NewScheduler builds a Scheduler that allows at most maxRunningTasks tasks to
// run at the same time. The one-second ticker is created (and starts ticking)
// here; the scheduling loop itself only begins once Start is called.
func NewScheduler(maxRunningTasks int) *Scheduler {
	s := new(Scheduler)
	s.tasks = make(map[int]*Task)
	s.runningTasks = make(map[int]bool)
	s.ticker = time.NewTicker(1 * time.Second)
	s.stopChan = make(chan struct{})
	s.maxRunningTasks = maxRunningTasks
	s.lastTaskID = 0
	return s
}
// AddTask registers task with the scheduler. The task receives the next free
// ID, is marked enabled, and every one of its triggers is reset so that
// timing starts from the moment of registration. The same task pointer is
// returned for convenience.
func (s *Scheduler) AddTask(task *Task) *Task {
	s.mu.Lock()
	defer s.mu.Unlock()

	id := s.lastTaskID
	task.ID = id
	task.Enabled = true

	for _, tr := range task.Triggers {
		tr.Reset()
	}

	s.tasks[id] = task
	s.lastTaskID = id + 1
	return task
}
// RemoveTask removes the task with the given ID from the scheduler while
// holding the scheduler mutex. Removing an unknown ID is a no-op. Any entry
// for the task in runningTasks is left untouched.
func (s *Scheduler) RemoveTask(taskID int) {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.tasks, taskID)
}
// Start launches the scheduling loop in a new goroutine and returns
// immediately. The loop calls runTasks on every ticker tick and exits once
// stopChan is closed.
func (s *Scheduler) Start() {
	loop := func() {
		for {
			select {
			case <-s.ticker.C:
				s.runTasks()
			case <-s.stopChan:
				return
			}
		}
	}
	go loop()
}
// runningTasksCount reports how many entries in runningTasks are currently
// flagged as running. It acquires the scheduler mutex, so callers must not
// already hold it.
func (s *Scheduler) runningTasksCount() int {
	s.mu.Lock()
	defer s.mu.Unlock()

	var running int
	for id := range s.runningTasks {
		if s.runningTasks[id] {
			running++
		}
	}
	return running
}
// runTasks checks and runs tasks based on their triggers and priority.
//
// Tasks are visited in descending Priority order. A task is skipped when it
// is disabled or already running, and removed when it has no triggers. For
// the first trigger that is ready while fewer than maxRunningTasks tasks are
// running, the task is marked running, started in its own goroutine, and the
// trigger is reset.
//
// The task and runningTasks maps are shared with runTask goroutines and with
// AddTask/RemoveTask, so every access here is done under s.mu. The mutex is
// never held across calls to RemoveTask or runningTasksCount, which lock it
// themselves.
func (s *Scheduler) runTasks() {
	// Snapshot the task set under the lock so the map is not iterated while
	// another goroutine mutates it.
	s.mu.Lock()
	sortedTasks := make([]*Task, 0, len(s.tasks))
	for _, task := range s.tasks {
		sortedTasks = append(sortedTasks, task)
	}
	s.mu.Unlock()

	// Sort tasks by priority, highest first.
	sort.Slice(sortedTasks, func(i, j int) bool {
		return sortedTasks[i].Priority > sortedTasks[j].Priority
	})

	for _, task := range sortedTasks {
		s.mu.Lock()
		alreadyRunning := s.runningTasks[task.ID]
		s.mu.Unlock()

		if !task.Enabled || alreadyRunning {
			continue
		}
		if len(task.Triggers) == 0 {
			s.RemoveTask(task.ID)
			continue
		}
		for _, trigger := range task.Triggers {
			// IsReady is evaluated first on purpose: some triggers consume
			// their signal when polled.
			if trigger.IsReady() && s.runningTasksCount() < s.maxRunningTasks {
				s.mu.Lock()
				s.runningTasks[task.ID] = true
				s.mu.Unlock()

				go s.runTask(task.ID)
				trigger.Reset()
				break
			}
		}
	}
}
// Stop signals the scheduling loop to exit and stops the ticker so that its
// resources are released. Tasks that are already running are not interrupted.
//
// Stop must be called at most once: closing stopChan a second time panics.
func (s *Scheduler) Stop() {
	close(s.stopChan)
	s.ticker.Stop()
}
// runTask executes the task with the given ID and records the outcome.
//
// The task function is attempted up to RetryPolicy.MaxRetries+1 times. On the
// first success the execution is recorded as "SUCCESS"; if every attempt
// fails it is recorded as "FAILED" with the last error message. In both cases
// the execution is appended to the task's ExecutionHist and the task's
// running flag is cleared.
//
// If the task was removed from the scheduler before this goroutine got to
// run, nothing is executed and only the running flag is cleared.
func (s *Scheduler) runTask(taskID int) {
	// Always clear the running flag, whatever path is taken below.
	defer func() {
		s.mu.Lock()
		defer s.mu.Unlock()
		s.runningTasks[taskID] = false
	}()

	// The task map is shared with AddTask/RemoveTask; read it under the lock.
	s.mu.Lock()
	task := s.tasks[taskID]
	s.mu.Unlock()
	if task == nil {
		// Removed between scheduling and execution.
		return
	}

	execution := Execution{StartedAt: time.Now()}
	defer func() {
		s.mu.Lock()
		defer s.mu.Unlock()
		// task is a pointer, so appending through it is enough; the task is
		// deliberately not written back into s.tasks, which would re-register
		// a task that was removed while it was running.
		task.ExecutionHist = append(task.ExecutionHist, execution)
	}()

	for attempt := 0; attempt < task.RetryPolicy.MaxRetries+1; attempt++ {
		err := runTaskWithRetry(task.Function, task.Args, task.RetryPolicy.Delay)
		if err == nil {
			execution.Status = "SUCCESS"
			execution.EndedAt = time.Now()
			return
		}
		execution.Error = err.Error()
	}
	execution.Status = "FAILED"
	execution.EndedAt = time.Now()
}
// runTaskWithRetry performs a single attempt of fn, passing the whole args
// slice as fn's one argument. On success it returns nil right away. On
// failure it sleeps for delay and then returns the error; the retry loop
// itself lives in the caller.
func runTaskWithRetry(
	fn func(args interface{}) error,
	args []interface{},
	delay time.Duration,
) error {
	attemptErr := fn(args)
	if attemptErr == nil {
		return nil
	}
	// Back off before handing control back to the caller's retry loop.
	time.Sleep(delay)
	return attemptErr
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package backgroundtasks
import (
"time"
"github.com/robfig/cron/v3"
)
// Trigger decides when a task should run. The scheduler polls IsReady on each
// tick and calls Reset when a task is added and after a task has been started
// because of this trigger.
type Trigger interface {
IsReady() bool // Returns true if the trigger condition is met.
Reset() // Resets the trigger state.
}
// PeriodicTrigger triggers at regular intervals or based on a cron expression.
// Both conditions are measured from lastTriggered, which is updated by Reset.
type PeriodicTrigger struct {
Interval time.Duration // Interval for periodic triggering.
CronExpr string // Cron expression for triggering.
lastTriggered time.Time // Last time the trigger was activated.
}
// IsReady reports whether the trigger should fire now.
//
// The trigger is ready when Interval has elapsed since the last Reset, or
// when CronExpr is set and its next scheduled time after the last Reset has
// passed. An unparsable cron expression never fires.
//
// The interval check is skipped when Interval is not positive and a cron
// expression is configured. Otherwise a cron-only trigger would have an
// interval of zero, be ready on every poll, and the cron schedule would never
// be consulted. A trigger with neither a positive Interval nor a CronExpr
// keeps its previous behavior of being ready on every poll.
func (t *PeriodicTrigger) IsReady() bool {
	now := time.Now()

	// Trigger based on interval.
	intervalApplies := t.Interval > 0 || t.CronExpr == ""
	if intervalApplies && t.lastTriggered.Add(t.Interval).Before(now) {
		return true
	}

	// Trigger based on cron expression.
	if t.CronExpr == "" {
		return false
	}
	schedule, err := cron.ParseStandard(t.CronExpr)
	if err != nil {
		return false
	}
	return schedule.Next(t.lastTriggered).Before(now)
}
// Reset sets lastTriggered to the current time, restarting both the interval
// and the cron measurement from now.
func (t *PeriodicTrigger) Reset() {
t.lastTriggered = time.Now()
}
// PeriodicTriggerWithJitter triggers at regular intervals or based on a cron
// expression, like PeriodicTrigger, but adds the duration returned by Jitter
// to the interval each time readiness is checked.
type PeriodicTriggerWithJitter struct {
Interval time.Duration // Interval for periodic triggering.
CronExpr string // Cron expression for triggering.
lastTriggered time.Time // Last time the trigger was activated.
Jitter func() time.Duration // Extra delay added to Interval on each IsReady call.
}
// IsReady reports whether the trigger should fire now.
//
// The trigger is ready when Interval plus the current jitter has elapsed
// since the last Reset, or when CronExpr is set and its next scheduled time
// after the last Reset has passed. An unparsable cron expression never fires.
//
// A nil Jitter function is treated as zero jitter instead of panicking. The
// interval check is skipped when Interval is not positive and a cron
// expression is configured, so that a cron-only trigger follows its cron
// schedule rather than firing on every poll.
func (t *PeriodicTriggerWithJitter) IsReady() bool {
	now := time.Now()

	var jitter time.Duration
	if t.Jitter != nil {
		jitter = t.Jitter()
	}

	// Trigger based on interval.
	intervalApplies := t.Interval > 0 || t.CronExpr == ""
	if intervalApplies && t.lastTriggered.Add(t.Interval+jitter).Before(now) {
		return true
	}

	// Trigger based on cron expression.
	if t.CronExpr == "" {
		return false
	}
	schedule, err := cron.ParseStandard(t.CronExpr)
	if err != nil {
		return false
	}
	return schedule.Next(t.lastTriggered).Before(now)
}
// Reset sets lastTriggered to the current time, restarting both the interval
// and the cron measurement from now.
func (t *PeriodicTriggerWithJitter) Reset() {
t.lastTriggered = time.Now()
}
// EventTrigger triggers based on an external event signaled through a channel.
// Each value received from Trigger makes IsReady return true once.
type EventTrigger struct {
Trigger chan bool // Channel to signal an event.
}
// IsReady performs a non-blocking receive on the Trigger channel. It returns
// true when a value (or a channel close) was received, consuming that signal,
// and false when nothing is pending.
func (t *EventTrigger) IsReady() bool {
	select {
	case <-t.Trigger:
		// A signal was pending and has now been consumed.
	default:
		return false
	}
	return true
}
// Reset is a no-op for EventTrigger: the trigger keeps no state of its own
// beyond the signals pending on its channel.
func (t *EventTrigger) Reset() {}
// OneTimeTrigger becomes ready once Delay has elapsed since the last Reset.
// NOTE(review): Scheduler.runTasks calls Reset after starting a task, which
// re-arms this trigger — confirm whether firing again after another Delay is
// intended for a "one time" trigger.
type OneTimeTrigger struct {
Delay time.Duration // The delay after which to trigger.
registeredAt time.Time // Time when the trigger was set.
}
// Reset sets the trigger registration time to the current time, so the delay
// is measured from now.
func (t *OneTimeTrigger) Reset() {
t.registeredAt = time.Now()
}
// IsReady reports whether the current time is strictly later than the
// registration time plus Delay.
func (t *OneTimeTrigger) IsReady() bool {
	deadline := t.registeredAt.Add(t.Delay)
	return time.Now().After(deadline)
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package config
import (
"bytes"
"encoding/json"
"fmt"
"os"
"reflect"
"github.com/spf13/afero"
"github.com/spf13/viper"
)
// Package-level configuration state.
var (
// cfg holds the loaded configuration; it is filled by LoadConfig and Set.
cfg Config
// homeDir is the current user's home directory. The error from
// os.UserHomeDir is discarded, so homeDir is empty when lookup fails.
homeDir, _ = os.UserHomeDir()
)
// getViper returns a fresh viper instance configured to look for a JSON file
// named "dms_config". Search paths are tried in order: the current working
// directory, then $HOME/.nunet, and finally /etc/nunet/.
func getViper() *viper.Viper {
	v := viper.New()
	v.SetConfigName("dms_config")
	v.SetConfigType("json")

	searchPaths := []string{
		".",                                // current working directory first
		fmt.Sprintf("%s/.nunet", homeDir), // then home directory
		"/etc/nunet/",                      // finally /etc/nunet
	}
	for _, dir := range searchPaths {
		v.AddConfigPath(dir)
	}
	return v
}
// setDefaultConfig returns a viper instance from getViper with default values
// registered for every configuration section (general, rest, profiler, p2p,
// job, observability and apm). It does not read any config file.
func setDefaultConfig() *viper.Viper {
v := getViper()
// general settings; directories are derived from the user's home directory
v.SetDefault("general.user_dir", fmt.Sprintf("%s/.nunet", homeDir))
v.SetDefault("general.work_dir", fmt.Sprintf("%s/nunet", homeDir))
v.SetDefault("general.data_dir", fmt.Sprintf("%s/nunet/data", homeDir))
v.SetDefault("general.debug", false)
v.SetDefault("general.port_available_range_from", 16384)
v.SetDefault("general.port_available_range_to", 32768)
// REST API and profiler endpoints listen on loopback by default
v.SetDefault("rest.addr", "127.0.0.1")
v.SetDefault("rest.port", 9999)
v.SetDefault("profiler.enabled", true)
v.SetDefault("profiler.addr", "127.0.0.1")
v.SetDefault("profiler.port", 6060)
// default p2p settings
v.SetDefault("p2p.listen_address", []string{
"/ip4/0.0.0.0/tcp/9000",
"/ip4/0.0.0.0/udp/9000/quic-v1",
})
v.SetDefault("p2p.bootstrap_peers", []string{
"/dnsaddr/bootstrap.p2p.nunet.io/p2p/QmQ2irHa8aFTLRhkbkQCRrounE4MbttNp8ki7Nmys4F9NP",
"/dnsaddr/bootstrap.p2p.nunet.io/p2p/Qmf16N2ecJVWufa29XKLNyiBxKWqVPNZXjbL3JisPcGqTw",
"/dnsaddr/bootstrap.p2p.nunet.io/p2p/QmTkWP72uECwCsiiYDpCFeTrVeUM9huGTPsg3m6bHxYQFZ",
})
v.SetDefault("p2p.memory", 1024)
v.SetDefault("p2p.fd", 512)
// default job settings
v.SetDefault("job.log_update_interval", 2)
v.SetDefault("job.target_peer", "")
v.SetDefault("job.cleanup_interval", 3)
// default observability settings
v.SetDefault("observability.log_level", "INFO")
v.SetDefault("observability.log_file", fmt.Sprintf("%s/nunet/logs/nunet-dms.log", homeDir))
v.SetDefault("observability.max_size", 100) // megabytes
v.SetDefault("observability.max_backups", 3)
v.SetDefault("observability.max_age", 28) // days
v.SetDefault("observability.elasticsearch_url", "http://localhost:9200")
v.SetDefault("observability.elasticsearch_index", "nunet-dms")
v.SetDefault("observability.flush_interval", 5) // Default flush interval is 5 seconds
// default APM settings
v.SetDefault("apm.server_url", "http://apm.telemetry.nunet.io")
v.SetDefault("apm.service_name", "nunet-dms")
v.SetDefault("apm.environment", "production")
v.SetDefault("apm.certificate", "/usr/share/elasticsearch/config/client.crt")
v.SetDefault("apm.key", "/usr/share/elasticsearch/config/client.key")
v.SetDefault("apm.ca", "/usr/share/elasticsearch/config/ca.crt")
return v
}
// LoadConfig populates the package-level configuration.
//
// When a config file can be read, its values (on top of the defaults) are
// unmarshaled into cfg. When reading fails for any reason, the read error is
// dropped and cfg is filled from a fresh set of defaults instead. An error is
// returned only when unmarshaling fails.
func LoadConfig() error {
	v := setDefaultConfig()

	if readErr := v.ReadInConfig(); readErr == nil {
		if err := v.UnmarshalExact(&cfg); err != nil {
			return fmt.Errorf("failed to unmarshal config: %w", err)
		}
		return nil
	}

	// No usable config file: fall back to pure defaults.
	if err := setDefaultConfig().UnmarshalExact(&cfg); err != nil {
		return fmt.Errorf("failed to unmarshal default config: %w", err)
	}
	return nil
}
// GetConfig returns a pointer to the package-level configuration. If the
// configuration is still the zero value, LoadConfig is attempted first; a
// load failure is ignored and whatever cfg then holds is returned.
func GetConfig() *Config {
	if reflect.DeepEqual(cfg, Config{}) {
		// Best-effort lazy load: the error is deliberately not surfaced.
		_ = LoadConfig()
	}
	return &cfg
}
func Get(key string) (interface{}, error) {
v := getViper()
loadedConfig, err := json.Marshal(GetConfig())
if err != nil {
return nil, fmt.Errorf("could not marshal config: %w", err)
}
if err := v.ReadConfig(bytes.NewReader(loadedConfig)); err != nil {
return nil, fmt.Errorf("could not read config: %w", err)
}
if !v.IsSet(key) {
return nil, fmt.Errorf("key '%s' not found in configuration", key)
}
return v.Get(key), nil
}
// Set assigns value to key in the configuration and writes the resulting
// configuration to the config file on the given filesystem.
//
// The key/value pair is applied to a fresh viper instance and unmarshaled
// into the package-level cfg; the full cfg is then marshaled to JSON, merged
// back into viper and written out. If no config file exists yet, one is
// created with SafeWriteConfig.
func Set(fs afero.Fs, key string, value interface{}) error {
v := getViper()
v.SetFs(fs)
v.Set(key, value)
// UnmarshalExact rejects keys that do not map onto a Config field.
if err := v.UnmarshalExact(&cfg); err != nil {
return fmt.Errorf("failed to unmarshal config: %w", err)
}
// Round-trip the complete config so the written file contains every
// setting, not only the key that was just changed.
loadedConfig, err := json.Marshal(GetConfig())
if err != nil {
return fmt.Errorf("could not marshal config: %w", err)
}
if err := v.MergeConfig(bytes.NewReader(loadedConfig)); err != nil {
return fmt.Errorf("failed to merge config: %w", err)
}
if err := v.WriteConfig(); err != nil {
if _, ok := err.(viper.ConfigFileNotFoundError); ok {
// Config file does not exist, create it.
return v.SafeWriteConfig()
}
return fmt.Errorf("failed to write config: %w", err)
}
return nil
}
// FileExists reports whether a readable config file can be found on fs in
// any of the configured search paths. A missing file yields (false, nil);
// any other read failure yields (false, err).
func FileExists(fs afero.Fs) (bool, error) {
	v := getViper()
	v.SetFs(fs)

	err := v.ReadInConfig()
	if err == nil {
		return true, nil
	}
	if _, notFound := err.(viper.ConfigFileNotFoundError); notFound {
		return false, nil
	}
	return false, fmt.Errorf("could not read config file: %w", err)
}
// GetPath returns the path of the config file in use. When no config file
// can be read, it returns whatever ConfigFileUsed reports for a defaults-only
// viper instance.
func GetPath() string {
	v := getViper()
	if v.ReadInConfig() == nil {
		return v.ConfigFileUsed()
	}
	return setDefaultConfig().ConfigFileUsed()
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package internal
import (
"os"
"os/signal"
"syscall"
)
// ShutdownChan receives SIGINT and SIGTERM so that the process can react to
// shutdown requests. It is created and registered in init.
var ShutdownChan chan os.Signal
// init creates ShutdownChan with a buffer of one signal and subscribes it to
// SIGINT and SIGTERM.
func init() {
ShutdownChan = make(chan os.Signal, 1)
signal.Notify(ShutdownChan, syscall.SIGINT, syscall.SIGTERM)
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
// Package internal is a work in progress. It is planned to accommodate
// modules such as db and types.
package internal
import (
"log"
"net/http"
"github.com/gorilla/websocket"
)
// UpgradeConnection is generic protocol upgrader for entire DMS.
// It uses 1024-byte read and write buffers. CheckOrigin always returns true,
// so the Origin header of upgrade requests is not validated.
// NOTE(review): accepting every origin allows cross-site WebSocket
// connections — confirm this is intended for all endpoints using it.
var UpgradeConnection = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: func(_ *http.Request) bool { return true },
}
// WebSocketConnection wraps a pointer to gorilla/websocket.Conn by embedding
// it, so all Conn methods are available directly on the wrapper.
type WebSocketConnection struct {
*websocket.Conn
}
// Command represents a command to be executed, together with the connection
// it arrived on.
type Command struct {
Command string // text of the command, taken from the received message
NodeID string // ID of the node where command will be executed
Result string // not populated anywhere in the code visible here
Conn *WebSocketConnection // connection used to write the reply
}
// commandChan is the unbuffered channel that carries commands from
// ListenForWs to SendCommandForExecution.
var commandChan = make(chan Command)
// clients maps a connection to its node ID. It is read by ListenForWs.
// NOTE(review): the map is accessed without any synchronization in the code
// visible here — confirm it is not written concurrently.
var clients = make(map[WebSocketConnection]string)
// ListenForWs reads messages from conn in a loop and forwards each one, as a
// Command, to commandChan. Every incoming message is treated as a command to
// execute. The loop ends on the first read error, which is logged. A panic
// inside the function is recovered and logged.
func ListenForWs(conn *WebSocketConnection) {
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		log.Printf("Recovered from panic: %v", r) // Log the panic if needed
	}()

	cmd := Command{NodeID: clients[*conn], Conn: conn}
	for {
		_, payload, err := conn.ReadMessage()
		if err != nil {
			log.Printf("Error reading message: %v", err) // Handle the error if needed
			return
		}
		// Hand the command over for execution.
		cmd.Command = string(payload)
		commandChan <- cmd
	}
}
// SendCommandForExecution consumes commands from commandChan forever.
// Execution itself is not implemented yet: for now each command's text is
// written straight back to the connection it came from as a text message.
// Write failures are logged and do not stop the loop.
func SendCommandForExecution() {
	for {
		cmd := <-commandChan
		// TO BE IMPLEMENTED
		// send command
		// fetch result
		// send back result
		if err := cmd.Conn.WriteMessage(websocket.TextMessage, []byte(cmd.Command)); err != nil {
			log.Printf("Error writing message: %v", err) // Log the error when message fails to send
		}
	}
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package crypto
import (
"crypto/rand"
"errors"
"io"
"golang.org/x/crypto/sha3"
)
// RandomEntropy returns length bytes read from crypto/rand.Reader.
//
// A negative length returns an error instead of panicking in make. A length
// of zero returns an empty, non-nil slice. If the random source cannot supply
// the requested number of bytes, an error is returned and no bytes are
// handed back.
func RandomEntropy(length int) ([]byte, error) {
	if length < 0 {
		return nil, errors.New("negative entropy length")
	}
	buf := make([]byte, length)
	// io.ReadFull returns a nil error only when buf was filled completely,
	// so no separate byte-count check is needed.
	if _, err := io.ReadFull(rand.Reader, buf); err != nil {
		return nil, errors.New("failed to read random bytes")
	}
	return buf, nil
}
// Sha3 returns the SHA3-256 digest of the given byte slices, hashed in the
// order supplied as if they were concatenated. An error from writing to the
// hasher is returned as is.
func Sha3(data ...[]byte) ([]byte, error) {
	hasher := sha3.New256()
	for i := range data {
		if _, err := hasher.Write(data[i]); err != nil {
			return nil, err
		}
	}
	return hasher.Sum(nil), nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package crypto
import (
"crypto/subtle"
"fmt"
"github.com/decred/dcrd/dcrec/secp256k1/v4"
"github.com/decred/dcrd/dcrec/secp256k1/v4/ecdsa"
"github.com/libp2p/go-libp2p/core/crypto/pb"
"golang.org/x/crypto/sha3"
)
// ethSignMagic is the prefix hashed in front of the message length and the
// message itself when verifying a signature in EthPublicKey.Verify.
var ethSignMagic = []byte(
"\x19Ethereum Signed Message:\n",
)
// EthPublicKey is a PubKey implementation backed by a secp256k1 public key.
// Its Verify method checks signatures over the prefixed Keccak-256 hash of
// the message.
type EthPublicKey struct {
key *secp256k1.PublicKey // underlying secp256k1 public key
}
// Compile-time check that *EthPublicKey satisfies PubKey.
var _ PubKey = (*EthPublicKey)(nil)
// UnmarshalEthPublicKey parses data as a serialized secp256k1 public key and
// wraps it in an EthPublicKey. The parse error is returned unchanged when
// data is not a valid key.
func UnmarshalEthPublicKey(data []byte) (_k PubKey, err error) {
	parsed, parseErr := secp256k1.ParsePubKey(data)
	if parseErr != nil {
		return nil, parseErr
	}
	return &EthPublicKey{key: parsed}, nil
}
// Verify checks a DER-encoded ECDSA signature over data.
//
// The digest is the legacy Keccak-256 hash of the ethSignMagic prefix, the
// decimal length of data, and data itself. It returns an error only when
// sigStr cannot be parsed as a DER signature; a signature that parses but
// does not match yields (false, nil).
func (k *EthPublicKey) Verify(data []byte, sigStr []byte) (success bool, err error) {
	signature, parseErr := ecdsa.ParseDERSignature(sigStr)
	if parseErr != nil {
		return false, parseErr
	}

	digest := sha3.NewLegacyKeccak256()
	digest.Write(ethSignMagic)
	// Length of the message as decimal text.
	fmt.Fprintf(digest, "%d", len(data))
	digest.Write(data)

	return signature.Verify(digest.Sum(nil), k.key), nil
}
// Raw returns the compressed serialization of the public key. The error is
// always nil.
func (k *EthPublicKey) Raw() (res []byte, err error) {
return k.key.SerializeCompressed(), nil
}
// Type returns the Eth key type constant.
func (k *EthPublicKey) Type() pb.KeyType {
return Eth
}
// Equals reports whether o represents the same key as k. Two EthPublicKeys
// are compared directly through their secp256k1 keys; any other Key is
// compared by type and raw bytes via basicEquals.
func (k *EthPublicKey) Equals(o Key) bool {
	if other, isEth := o.(*EthPublicKey); isEth {
		return k.key.IsEqual(other.key)
	}
	return basicEquals(k, o)
}
// basicEquals reports whether two keys have the same type and the same raw
// bytes. The byte comparison runs in constant time. If either key fails to
// produce its raw form, the keys are considered different.
func basicEquals(k1, k2 Key) bool {
	if k1.Type() != k2.Type() {
		return false
	}

	left, err := k1.Raw()
	if err != nil {
		return false
	}

	right, err := k2.Raw()
	if err != nil {
		return false
	}

	equal := subtle.ConstantTimeCompare(left, right)
	return equal == 1
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package crypto
import (
"bytes"
"encoding/base32"
"encoding/json"
"fmt"
)
// ID is the encoding of the actor's public key. It carries the serialized
// public key bytes.
type ID struct{ PublicKey []byte }
// Equal reports whether both IDs hold the same public key bytes.
func (id ID) Equal(other ID) bool {
return bytes.Equal(id.PublicKey, other.PublicKey)
}
// Empty reports whether the ID holds no public key bytes (nil or zero length).
func (id ID) Empty() bool {
return len(id.PublicKey) == 0
}
// IDJSONView is the on-the-wire JSON representation of an ID. Pub carries the
// base32 string form of the public key.
type IDJSONView struct {
Pub string `json:"pub"`
}
// String returns the public key bytes encoded with standard base32.
func (id ID) String() string {
return base32.StdEncoding.EncodeToString(id.PublicKey)
}
// IDFromString is the inverse of ID.String: it decodes a standard base32
// string into an ID. An invalid encoding yields the zero ID and a wrapped
// decode error.
func IDFromString(s string) (ID, error) {
	raw, err := base32.StdEncoding.DecodeString(s)
	if err == nil {
		return ID{PublicKey: raw}, nil
	}
	return ID{}, fmt.Errorf("decode ID: %w", err)
}
// MarshalJSON encodes the ID as an IDJSONView, i.e. a JSON object whose "pub"
// field holds the base32 string form of the public key.
func (id ID) MarshalJSON() ([]byte, error) {
return json.Marshal(IDJSONView{Pub: id.String()})
}
// Compile-time check that ID satisfies json.Marshaler.
var _ json.Marshaler = ID{}
// UnmarshalJSON decodes an ID from its IDJSONView form: the "pub" field is
// read and base32-decoded into the public key bytes. The receiver is only
// modified when decoding succeeds.
func (id *ID) UnmarshalJSON(data []byte) error {
	var view IDJSONView
	if err := json.Unmarshal(data, &view); err != nil {
		return fmt.Errorf("unmarshaling ID: %w", err)
	}

	decoded, err := IDFromString(view.Pub)
	if err != nil {
		return fmt.Errorf("unmarshaling ID: %w", err)
	}

	*id = decoded
	return nil
}

// Compile-time check that *ID satisfies json.Unmarshaler.
var _ json.Unmarshaler = (*ID)(nil)
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package crypto
import (
"crypto/rand"
"fmt"
"github.com/libp2p/go-libp2p/core/crypto"
)
// Key type identifiers. Ed25519 and Secp256k1 are re-exported from libp2p;
// Eth is a type defined by this package with the value 127.
const (
Ed25519 = crypto.Ed25519
Secp256k1 = crypto.Secp256k1
Eth = 127
)
// Aliases for the libp2p key interfaces, so that callers of this package do
// not have to import libp2p's crypto package themselves.
type (
Key = crypto.Key
PrivKey = crypto.PrivKey
PubKey = crypto.PubKey
)
// AllowedKey reports whether t is one of the accepted key types: Ed25519 or
// Secp256k1. Every other value, including Eth, is rejected.
func AllowedKey(t int) bool {
	switch t {
	case Ed25519, Secp256k1:
		return true
	}
	return false
}
// GenerateKeyPair creates a new private/public key pair of type t using
// crypto/rand as the entropy source. Only Ed25519 and Secp256k1 are
// supported; any other type yields an error wrapping ErrUnsupportedKeyType.
func GenerateKeyPair(t int) (PrivKey, PubKey, error) {
	if t == Ed25519 {
		return crypto.GenerateEd25519Key(rand.Reader)
	}
	if t == Secp256k1 {
		return crypto.GenerateSecp256k1Key(rand.Reader)
	}
	return nil, nil, fmt.Errorf("unsupported key type %d: %w", t, ErrUnsupportedKeyType)
}
// PublicKeyToBytes serializes a public key using libp2p's MarshalPublicKey.
func PublicKeyToBytes(k PubKey) ([]byte, error) {
return crypto.MarshalPublicKey(k)
}
// BytesToPublicKey parses a public key using libp2p's UnmarshalPublicKey.
// It is the inverse of PublicKeyToBytes.
func BytesToPublicKey(data []byte) (PubKey, error) {
return crypto.UnmarshalPublicKey(data)
}
// PrivateKeyToBytes serializes a private key using libp2p's MarshalPrivateKey.
func PrivateKeyToBytes(k PrivKey) ([]byte, error) {
return crypto.MarshalPrivateKey(k)
}
// BytesToPrivateKey parses a private key using libp2p's UnmarshalPrivateKey.
// It is the inverse of PrivateKeyToBytes.
func BytesToPrivateKey(data []byte) (PrivKey, error) {
return crypto.UnmarshalPrivateKey(data)
}
// IDFromPublicKey builds an ID from the serialized form of k. A
// serialization failure yields the zero ID and a wrapped error.
func IDFromPublicKey(k PubKey) (ID, error) {
	raw, err := PublicKeyToBytes(k)
	if err == nil {
		return ID{PublicKey: raw}, nil
	}
	return ID{}, fmt.Errorf("id from public key: %w", err)
}
// PublicKeyFromID parses the public key bytes stored in id. It is the inverse
// of IDFromPublicKey.
func PublicKeyFromID(id ID) (PubKey, error) {
return BytesToPublicKey(id.PublicKey)
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package keystore
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"encoding/hex"
"encoding/json"
"fmt"
"path/filepath"
"strings"
libp2p_crypto "github.com/libp2p/go-libp2p/core/crypto"
"golang.org/x/crypto/scrypt"
"gitlab.com/nunet/device-management-service/lib/crypto"
)
// Parameters used when encrypting keys and checked when decrypting them.
var (
// KDF parameters
nameKDF = "scrypt"
scryptKeyLen = 64 // derived key length in bytes: first 32 encrypt, last 32 feed the MAC
scryptN = 1 << 18
scryptR = 8
scryptP = 1
ksVersion = 3 // version written to key files; UnmarshalKey rejects any other value
ksCipher = "aes-256-ctr" // cipher name written to key files; UnmarshalKey rejects any other value
)
// Key represents a keypair to be stored in a keystore.
type Key struct {
ID string // identifier; the keystore uses it as the file name (without extension)
Data []byte // raw key material; encrypted by MarshalToJSON before being stored
}
// NewKey returns a Key holding the given id and data. The data slice is
// stored as is, without copying. The returned error is always nil.
func NewKey(id string, data []byte) (*Key, error) {
	k := new(Key)
	k.ID = id
	k.Data = data
	return k, nil
}
// PrivKey interprets the key's Data as a marshaled libp2p private key and
// unmarshals it.
//
// The underlying error is wrapped with %w so that callers can inspect it with
// errors.Is / errors.As.
func (key *Key) PrivKey() (crypto.PrivKey, error) {
	priv, err := libp2p_crypto.UnmarshalPrivateKey(key.Data)
	if err != nil {
		return nil, fmt.Errorf("unable to unmarshal private key: %w", err)
	}
	return priv, nil
}
// MarshalToJSON encrypts and marshals a key to json byte array.
//
// A 64-byte derived key is produced from the passphrase and a fresh random
// 64-byte salt using scrypt. The first 32 bytes encrypt key.Data with AES in
// CTR mode under a random IV; the last 32 bytes are hashed together with the
// ciphertext (SHA3-256) to form the MAC. The result is an indented JSON
// document containing the cipher text, cipher and KDF parameters, MAC, key ID
// and keystore version.
//
// It returns ErrEmptyPassphrase when passphrase is empty.
func (key *Key) MarshalToJSON(passphrase string) ([]byte, error) {
if passphrase == "" {
return nil, ErrEmptyPassphrase
}
salt, err := crypto.RandomEntropy(64)
if err != nil {
return nil, err
}
dk, err := scrypt.Key([]byte(passphrase), salt, scryptN, scryptR, scryptP, scryptKeyLen)
if err != nil {
return nil, err
}
iv, err := crypto.RandomEntropy(aes.BlockSize)
if err != nil {
return nil, err
}
// First half of the derived key is the AES-256 encryption key.
enckey := dk[:32]
aesBlock, err := aes.NewCipher(enckey)
if err != nil {
return nil, err
}
stream := cipher.NewCTR(aesBlock, iv)
cipherText := make([]byte, len(key.Data))
stream.XORKeyStream(cipherText, key.Data)
// Second half of the derived key authenticates the ciphertext.
mac, err := crypto.Sha3(dk[32:64], cipherText)
if err != nil {
return nil, err
}
cipherParamsJSON := cipherparamsJSON{
IV: hex.EncodeToString(iv),
}
sp := ScryptParams{
N: scryptN,
R: scryptR,
P: scryptP,
DKeyLength: scryptKeyLen,
Salt: hex.EncodeToString(salt),
}
keyjson := cryptoJSON{
Cipher: ksCipher,
CipherText: hex.EncodeToString(cipherText),
CipherParams: cipherParamsJSON,
KDF: nameKDF,
KDFParams: sp,
MAC: hex.EncodeToString(mac),
}
encjson := encryptedKeyJSON{
Crypto: keyjson,
ID: key.ID,
Version: ksVersion,
}
data, err := json.MarshalIndent(&encjson, "", " ")
if err != nil {
return nil, err
}
return data, nil
}
// UnmarshalKey decrypts and unmarshals a key produced by Key.MarshalToJSON.
//
// The JSON document is parsed, its version and cipher are checked, the key
// is re-derived from the passphrase with the stored scrypt parameters, the
// MAC over the ciphertext is verified, and finally the ciphertext is
// decrypted with AES-CTR.
//
// It returns ErrEmptyPassphrase for an empty passphrase, ErrVersionMismatch
// or ErrCipherMismatch for unsupported files, and ErrMACMismatch when the
// passphrase is wrong or the file was altered. Structurally invalid files
// (bad hex, derived key length below 64 bytes, IV of the wrong size) yield a
// descriptive error rather than a panic.
func UnmarshalKey(data []byte, passphrase string) (*Key, error) {
	if passphrase == "" {
		return nil, ErrEmptyPassphrase
	}

	encjson := encryptedKeyJSON{}
	if err := json.Unmarshal(data, &encjson); err != nil {
		return nil, fmt.Errorf("failed to unmarshal key data: %w", err)
	}

	if encjson.Version != ksVersion {
		return nil, ErrVersionMismatch
	}
	if encjson.Crypto.Cipher != ksCipher {
		return nil, ErrCipherMismatch
	}

	mac, err := hex.DecodeString(encjson.Crypto.MAC)
	if err != nil {
		return nil, fmt.Errorf("failed to decode mac: %w", err)
	}
	iv, err := hex.DecodeString(encjson.Crypto.CipherParams.IV)
	if err != nil {
		return nil, fmt.Errorf("failed to decode cipher params iv: %w", err)
	}
	// cipher.NewCTR panics when the IV length differs from the block size.
	if len(iv) != aes.BlockSize {
		return nil, fmt.Errorf("invalid iv length: got %d bytes, want %d", len(iv), aes.BlockSize)
	}
	salt, err := hex.DecodeString(encjson.Crypto.KDFParams.Salt)
	if err != nil {
		return nil, fmt.Errorf("failed to decode salt: %w", err)
	}
	ciphertext, err := hex.DecodeString(encjson.Crypto.CipherText)
	if err != nil {
		return nil, fmt.Errorf("failed to decode cipher text: %w", err)
	}

	// The derived key is split into a 32-byte encryption key and a 32-byte
	// MAC key below, so anything shorter than 64 bytes cannot be used.
	if encjson.Crypto.KDFParams.DKeyLength < 64 {
		return nil, fmt.Errorf("invalid derived key length: got %d, need at least 64", encjson.Crypto.KDFParams.DKeyLength)
	}
	dk, err := scrypt.Key([]byte(passphrase), salt, encjson.Crypto.KDFParams.N, encjson.Crypto.KDFParams.R, encjson.Crypto.KDFParams.P, encjson.Crypto.KDFParams.DKeyLength)
	if err != nil {
		return nil, fmt.Errorf("failed to derive key: %w", err)
	}

	hash, err := crypto.Sha3(dk[32:64], ciphertext)
	if err != nil {
		return nil, fmt.Errorf("failed to hash key and ciphertext: %w", err)
	}
	if !bytes.Equal(hash, mac) {
		return nil, ErrMACMismatch
	}

	aesBlock, err := aes.NewCipher(dk[:32])
	if err != nil {
		return nil, fmt.Errorf("failed to create cipher block: %w", err)
	}
	stream := cipher.NewCTR(aesBlock, iv)
	outputkey := make([]byte, len(ciphertext))
	stream.XORKeyStream(outputkey, ciphertext)

	return &Key{
		ID:   encjson.ID,
		Data: outputkey,
	}, nil
}
// removeFileExtension returns filename without its final extension, where
// the extension is whatever filepath.Ext reports (the suffix starting at the
// last dot of the final path element). Names without a dot are returned
// unchanged.
func removeFileExtension(filename string) string {
	return strings.TrimSuffix(filename, filepath.Ext(filename))
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package keystore
import (
"fmt"
"os"
"path/filepath"
"sync"
"github.com/spf13/afero"
"gitlab.com/nunet/device-management-service/utils"
)
// KeyStore manages a local store of passphrase-protected keys.
type KeyStore interface {
// Save encrypts data under passphrase, stores it under id and returns the
// name of the file written.
Save(id string, data []byte, passphrase string) (string, error)
// Get loads and decrypts the key stored under keyID.
Get(keyID string, passphrase string) (*Key, error)
// Delete removes the key stored under keyID after checking the passphrase.
Delete(keyID string, passphrase string) error
// ListKeys returns the IDs of the stored keys.
ListKeys() ([]string, error)
}
// BasicKeyStore handles keypair storage. Each key is kept as an encrypted
// JSON file named "<id>.json" inside keysDir on the given filesystem.
// TODO: add cache?
type BasicKeyStore struct {
fs afero.Fs // filesystem abstraction used for all file access
keysDir string // directory that holds the key files
mu sync.RWMutex // guards file operations on the keystore
}
// Compile-time check that *BasicKeyStore satisfies KeyStore.
var _ KeyStore = (*BasicKeyStore)(nil)
// New returns a BasicKeyStore rooted at keysDir on the given filesystem. The
// directory (and any missing parents) is created with mode 0700. An empty
// keysDir yields ErrEmptyKeysDir.
func New(fs afero.Fs, keysDir string) (*BasicKeyStore, error) {
	if keysDir == "" {
		return nil, ErrEmptyKeysDir
	}

	err := fs.MkdirAll(keysDir, 0o700)
	if err != nil {
		return nil, fmt.Errorf("failed to create keystore directory: %w", err)
	}

	store := &BasicKeyStore{fs: fs, keysDir: keysDir}
	return store, nil
}
// Save encrypts data with passphrase and writes it to "<id>.json" inside the
// keystore directory. It returns the name of the file written.
//
// ErrEmptyPassphrase is returned for an empty passphrase. Errors from
// encryption and from writing the file are wrapped with %w so callers can
// inspect them. The write is performed under the keystore's write lock, in
// line with Delete, so that a key file is not written while it is being
// removed.
func (ks *BasicKeyStore) Save(id string, data []byte, passphrase string) (string, error) {
	if passphrase == "" {
		return "", ErrEmptyPassphrase
	}

	key := &Key{
		ID:   id,
		Data: data,
	}

	// Key derivation is slow; do it before taking the lock.
	keyDataJSON, err := key.MarshalToJSON(passphrase)
	if err != nil {
		return "", fmt.Errorf("failed to marshal key: %w", err)
	}

	ks.mu.Lock()
	defer ks.mu.Unlock()

	filename, err := utils.WriteToFile(ks.fs, keyDataJSON, filepath.Join(ks.keysDir, key.ID+".json"))
	if err != nil {
		return "", fmt.Errorf("failed to write key to file: %w", err)
	}
	return filename, nil
}
// Get unlocks a key by keyID.
//
// It reads "<keyID>.json" from the keystore directory and decrypts it with
// passphrase via UnmarshalKey. A missing file yields ErrKeyNotFound.
func (ks *BasicKeyStore) Get(keyID string, passphrase string) (*Key, error) {
	// Read under the shared lock so the file cannot be removed by Delete
	// while it is being read.
	ks.mu.RLock()
	bts, err := afero.ReadFile(ks.fs, filepath.Join(ks.keysDir, keyID+".json"))
	ks.mu.RUnlock()
	if err != nil {
		if os.IsNotExist(err) {
			return nil, ErrKeyNotFound
		}
		return nil, fmt.Errorf("failed to read keystore file: %w", err)
	}

	key, err := UnmarshalKey(bts, passphrase)
	if err != nil {
		return nil, fmt.Errorf("failed to unmarshal keystore file: %w", err)
	}
	return key, nil
}
// Delete removes the file referencing the given key.
//
// The file is first read and decrypted with passphrase; it is removed only
// when that succeeds. A missing file yields ErrKeyNotFound.
func (ks *BasicKeyStore) Delete(keyID string, passphrase string) error {
	ks.mu.Lock()
	defer ks.mu.Unlock()

	target := filepath.Join(ks.keysDir, keyID+".json")

	raw, readErr := afero.ReadFile(ks.fs, target)
	switch {
	case readErr == nil:
		// file is present and readable
	case os.IsNotExist(readErr):
		return ErrKeyNotFound
	default:
		return fmt.Errorf("failed to read keystore file: %w", readErr)
	}

	// Decrypting proves the caller knows the passphrase.
	if _, err := UnmarshalKey(raw, passphrase); err != nil {
		return fmt.Errorf("invalid passphrase or corrupted key file: %w", err)
	}

	if err := ks.fs.Remove(target); err != nil {
		return fmt.Errorf("failed to delete key file: %w", err)
	}
	return nil
}
// ListKeys lists the keys in the keysDir.
//
// Every directory entry that can be read as a file is reported by its name
// with the extension removed; unreadable entries are skipped silently.
func (ks *BasicKeyStore) ListKeys() ([]string, error) {
	entries, err := afero.ReadDir(ks.fs, ks.keysDir)
	if err != nil {
		return nil, fmt.Errorf("failed to read keystore directory: %w", err)
	}

	ids := make([]string, 0)
	for _, e := range entries {
		name := e.Name()
		if _, readErr := afero.ReadFile(ks.fs, filepath.Join(ks.keysDir, name)); readErr == nil {
			ids = append(ids, removeFileExtension(name))
		}
	}
	return ids, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package crypto
// ReadVault is a placeholder that is not implemented yet: it ignores path and
// passphrase and always returns nil data together with ErrTODO.
func ReadVault(path string, passphrase string) ([]byte, error) { //nolint:revive // its a todo
// TODO
return nil, ErrTODO
}
// WriteVault is a placeholder that is not implemented yet: it ignores all of
// its arguments and always returns ErrTODO.
func WriteVault(path string, passphrase string, data []byte) error { //nolint:revive // its a todo
// TODO
return ErrTODO
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package did
// GetAnchorFunc builds the Anchor for a DID of one particular method.
type GetAnchorFunc func(did DID) (Anchor, error)
// anchorMethods maps a DID method name to the function that builds its
// anchor. It is populated in init.
var anchorMethods map[string]GetAnchorFunc
func init() {
anchorMethods = map[string]GetAnchorFunc{
"key": makeKeyAnchor,
}
}
// GetAnchorForDID builds the Anchor for did using the constructor registered
// for its method. It returns ErrNoAnchorMethod when the method is unknown.
func GetAnchorForDID(did DID) (Anchor, error) {
	if factory, found := anchorMethods[did.Method()]; found {
		return factory(did)
	}
	return nil, ErrNoAnchorMethod
}
// makeKeyAnchor builds an Anchor from the public key embedded in a did:key
// identifier.
func makeKeyAnchor(did DID) (Anchor, error) {
	key, err := PublicKeyFromDID(did)
	if err == nil {
		return NewAnchor(did, key), nil
	}
	return nil, err
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package did
import (
"context"
"fmt"
"sync"
"time"
"gitlab.com/nunet/device-management-service/lib/crypto"
)
// anchorEntryTTL is how long a cached anchor stays valid after it was added
// or last looked up; expired entries are dropped by the gc pass.
const anchorEntryTTL = time.Hour
// Anchor is a DID anchor that encapsulates a public key that can be used
// for verification of signatures.
type Anchor interface {
// DID returns the DID this anchor belongs to.
DID() DID
// Verify checks sig over data; a nil error means the signature is valid.
Verify(data []byte, sig []byte) error
// PublicKey returns the encapsulated public key.
PublicKey() crypto.PubKey
}
// Provider holds the private key material necessary to sign statements for
// a DID.
type Provider interface {
// DID returns the DID this provider signs for.
DID() DID
// Sign returns a signature over data.
Sign(data []byte) ([]byte, error)
// Anchor returns the verification anchor matching this provider.
Anchor() Anchor
// PrivateKey returns the private key, or an error when it cannot be
// exported (as the ledger-backed provider does).
PrivateKey() (crypto.PrivKey, error)
}
// TrustContext keeps track of the known anchors and providers, keyed by DID.
type TrustContext interface {
// Anchors lists the DIDs that currently have an anchor.
Anchors() []DID
// Providers lists the DIDs that have a provider.
Providers() []DID
// GetAnchor returns the anchor for did.
GetAnchor(did DID) (Anchor, error)
// GetProvider returns the provider for did.
GetProvider(did DID) (Provider, error)
// AddAnchor registers anchor under its DID.
AddAnchor(anchor Anchor)
// AddProvider registers provider under its DID.
AddProvider(provider Provider)
// Start launches periodic garbage collection at the given interval.
Start(gcInterval time.Duration)
// Stop halts garbage collection started by Start.
Stop()
}
// anchorEntry is a cached anchor together with its expiry time.
type anchorEntry struct {
anchor Anchor
expire time.Time // entry is garbage collected once this instant has passed
}
// BasicTrustContext is an in-memory TrustContext. Anchors are cached with an
// expiry; providers are kept until the context is discarded.
type BasicTrustContext struct {
mx sync.Mutex // guards every field below
anchors map[DID]*anchorEntry
providers map[DID]Provider
stop func() // cancels the running gc goroutine; nil when none is running
}
// Compile-time check that *BasicTrustContext implements TrustContext.
var _ TrustContext = (*BasicTrustContext)(nil)
// NewTrustContext returns an empty trust context with no anchors or
// providers. Garbage collection is not running until Start is called.
func NewTrustContext() TrustContext {
	tc := new(BasicTrustContext)
	tc.anchors = make(map[DID]*anchorEntry)
	tc.providers = make(map[DID]Provider)
	return tc
}
// NewTrustContextWithPrivateKey returns a trust context that already holds
// the provider derived from privk.
func NewTrustContextWithPrivateKey(privk crypto.PrivKey) (TrustContext, error) {
	provider, err := ProviderFromPrivateKey(privk)
	if err != nil {
		return nil, fmt.Errorf("provide from private key: %w", err)
	}

	tc := NewTrustContext()
	tc.AddProvider(provider)
	return tc, nil
}
// NewTrustContextWithProvider returns a trust context that already holds p.
func NewTrustContextWithProvider(p Provider) TrustContext {
	tc := NewTrustContext()
	tc.AddProvider(p)
	return tc
}
// Anchors returns the DIDs of all cached anchors, in no particular order.
func (ctx *BasicTrustContext) Anchors() []DID {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	dids := make([]DID, len(ctx.anchors))
	i := 0
	for id := range ctx.anchors {
		dids[i] = id
		i++
	}
	return dids
}
// Providers returns the DIDs of all registered providers, in no particular
// order.
func (ctx *BasicTrustContext) Providers() []DID {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	dids := make([]DID, len(ctx.providers))
	i := 0
	for id := range ctx.providers {
		dids[i] = id
		i++
	}
	return dids
}
// GetAnchor returns the anchor for did. A cached anchor is used when present;
// otherwise one is built with GetAnchorForDID and added to the cache.
func (ctx *BasicTrustContext) GetAnchor(did DID) (Anchor, error) {
	if cached, hit := ctx.getAnchor(did); hit {
		return cached, nil
	}

	resolved, err := GetAnchorForDID(did)
	if err != nil {
		return nil, fmt.Errorf("get anchor for did: %w", err)
	}

	ctx.AddAnchor(resolved)
	return resolved, nil
}
// getAnchor looks up a cached anchor and, on a hit, extends its expiry by
// anchorEntryTTL from now.
func (ctx *BasicTrustContext) getAnchor(did DID) (Anchor, bool) {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	entry, found := ctx.anchors[did]
	if !found {
		return nil, false
	}

	entry.expire = time.Now().Add(anchorEntryTTL)
	return entry.anchor, true
}
// GetProvider returns the provider registered for did, or ErrNoProvider when
// there is none.
func (ctx *BasicTrustContext) GetProvider(did DID) (Provider, error) {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	if p, found := ctx.providers[did]; found {
		return p, nil
	}
	return nil, ErrNoProvider
}
// AddAnchor caches anchor under its DID, replacing any previous entry. The
// entry expires anchorEntryTTL from now.
func (ctx *BasicTrustContext) AddAnchor(anchor Anchor) {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	id := anchor.DID()
	deadline := time.Now().Add(anchorEntryTTL)
	ctx.anchors[id] = &anchorEntry{anchor: anchor, expire: deadline}
}
// AddProvider registers provider under its DID, replacing any previous one.
func (ctx *BasicTrustContext) AddProvider(provider Provider) {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	id := provider.DID()
	ctx.providers[id] = provider
}
// Start launches the anchor garbage collector, which runs every gcInterval.
// A collector started earlier is cancelled first.
func (ctx *BasicTrustContext) Start(gcInterval time.Duration) {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	if previous := ctx.stop; previous != nil {
		previous()
	}

	gcCtx, cancel := context.WithCancel(context.Background())
	ctx.stop = cancel
	go ctx.gc(gcCtx, gcInterval)
}
// Stop cancels the garbage collector started by Start. It does nothing when
// no collector is running.
func (ctx *BasicTrustContext) Stop() {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	if ctx.stop == nil {
		return
	}
	ctx.stop()
	ctx.stop = nil
}
// gc runs gcAnchorEntries on every tick of gcInterval until gcCtx is
// cancelled.
func (ctx *BasicTrustContext) gc(gcCtx context.Context, gcInterval time.Duration) {
	ticker := time.NewTicker(gcInterval)
	defer ticker.Stop()

	for {
		select {
		case <-gcCtx.Done():
			return
		case <-ticker.C:
			ctx.gcAnchorEntries()
		}
	}
}
// gcAnchorEntries removes every cached anchor whose expiry lies in the past.
func (ctx *BasicTrustContext) gcAnchorEntries() {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	cutoff := time.Now()
	for id, entry := range ctx.anchors {
		if cutoff.After(entry.expire) {
			delete(ctx.anchors, id)
		}
	}
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package did
import (
"strings"
)
// DID is a decentralized identifier kept in its URI string form, made of
// three colon-separated segments (for example "did:key:<identifier>").
type DID struct {
	URI string `json:"uri,omitempty"`
}

// Equal reports whether both DIDs carry the same URI.
func (did DID) Equal(other DID) bool {
	return other.URI == did.URI
}

// Empty reports whether the DID has no URI.
func (did DID) Empty() bool {
	return len(did.URI) == 0
}

// String returns the URI.
func (did DID) String() string {
	return did.URI
}

// Method returns the second of the three colon-separated URI segments, or ""
// when the URI does not consist of exactly three segments.
func (did DID) Method() string {
	_, rest, found := strings.Cut(did.URI, ":")
	if !found {
		return ""
	}
	method, identifier, found := strings.Cut(rest, ":")
	if !found || strings.Contains(identifier, ":") {
		return ""
	}
	return method
}

// Identifier returns the third of the three colon-separated URI segments, or
// "" when the URI does not consist of exactly three segments.
func (did DID) Identifier() string {
	_, rest, found := strings.Cut(did.URI, ":")
	if !found {
		return ""
	}
	_, identifier, found := strings.Cut(rest, ":")
	if !found || strings.Contains(identifier, ":") {
		return ""
	}
	return identifier
}
// FromString parses s into a DID.
//
// The empty string is accepted and yields an empty DID. Any other value must
// consist of exactly three non-empty, colon-separated segments; otherwise
// ErrInvalidDID is returned.
func FromString(s string) (DID, error) {
	if s == "" {
		return DID{URI: s}, nil
	}

	segments := strings.Split(s, ":")
	if len(segments) != 3 {
		return DID{}, ErrInvalidDID
	}
	for i := range segments {
		if len(segments[i]) == 0 {
			return DID{}, ErrInvalidDID
		}
	}
	// TODO validate parts according to spec: https://www.w3.org/TR/did-core/

	return DID{URI: s}, nil
}
// Original Copyright 2024, Ucan Working Group; Modified Copyright 2024, NuNet;
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
//
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package did
import (
"fmt"
"strings"
libp2p_crypto "github.com/libp2p/go-libp2p/core/crypto"
mb "github.com/multiformats/go-multibase"
varint "github.com/multiformats/go-varint"
"gitlab.com/nunet/device-management-service/lib/crypto"
)
// PublicKeyAnchor is an Anchor backed directly by a public key.
type PublicKeyAnchor struct {
did DID
pubk crypto.PubKey // key used by Verify
}
// Compile-time check that *PublicKeyAnchor implements Anchor.
var _ Anchor = (*PublicKeyAnchor)(nil)
// PrivateKeyProvider is a Provider backed by an in-memory private key.
type PrivateKeyProvider struct {
did DID
privk crypto.PrivKey // key used by Sign
}
// Compile-time check that *PrivateKeyProvider implements Provider.
var _ Provider = (*PrivateKeyProvider)(nil)
// NewAnchor returns an Anchor for did that verifies signatures with pubk.
func NewAnchor(did DID, pubk crypto.PubKey) Anchor {
	anchor := &PublicKeyAnchor{did: did, pubk: pubk}
	return anchor
}
// NewProvider returns a Provider for did that signs with privk.
func NewProvider(did DID, privk crypto.PrivKey) Provider {
	provider := &PrivateKeyProvider{did: did, privk: privk}
	return provider
}
// DID returns the DID this anchor was created for.
func (a *PublicKeyAnchor) DID() DID {
	return a.did
}
// Verify checks sig over data with the anchor's public key. It returns the
// key's own error if verification could not be performed,
// ErrInvalidSignature if the signature does not match, and nil otherwise.
func (a *PublicKeyAnchor) Verify(data []byte, sig []byte) error {
	valid, err := a.pubk.Verify(data, sig)
	switch {
	case err != nil:
		return err
	case !valid:
		return ErrInvalidSignature
	default:
		return nil
	}
}
// PublicKey returns the public key used for verification.
func (a *PublicKeyAnchor) PublicKey() crypto.PubKey {
	return a.pubk
}
// DID returns the DID this provider signs for.
func (p *PrivateKeyProvider) DID() DID {
	return p.did
}
// Sign signs data with the provider's private key.
func (p *PrivateKeyProvider) Sign(data []byte) ([]byte, error) {
	signature, err := p.privk.Sign(data)
	return signature, err
}
// PrivateKey returns the underlying private key; it never fails.
func (p *PrivateKeyProvider) PrivateKey() (crypto.PrivKey, error) {
	return p.privk, nil
}
// Anchor returns a verification anchor built from the public half of the
// provider's key.
func (p *PrivateKeyProvider) Anchor() Anchor {
	public := p.privk.GetPublic()
	return NewAnchor(p.did, public)
}
// FromID derives the did:key DID for the public key recovered from id.
func FromID(id crypto.ID) (DID, error) {
	pubk, err := crypto.PublicKeyFromID(id)
	if err == nil {
		return FromPublicKey(pubk), nil
	}
	return DID{}, fmt.Errorf("public key from id: %w", err)
}
// FromPublicKey returns the DID whose URI is FormatKeyURI(pubk). The URI is
// empty when FormatKeyURI cannot encode the key.
func FromPublicKey(pubk crypto.PubKey) DID {
	return DID{URI: FormatKeyURI(pubk)}
}
// PublicKeyFromDID extracts the public key encoded in a did:key DID. DIDs of
// any other method are rejected with ErrInvalidDID.
func PublicKeyFromDID(did DID) (crypto.PubKey, error) {
	if method := did.Method(); method != "key" {
		return nil, ErrInvalidDID
	}

	key, err := ParseKeyURI(did.URI)
	if err != nil {
		return nil, fmt.Errorf("parsing did key identifier: %w", err)
	}
	return key, nil
}
// AnchorFromPublicKey returns an anchor for the did:key DID derived from
// pubk. The error is always nil.
func AnchorFromPublicKey(pubk crypto.PubKey) (Anchor, error) {
	return NewAnchor(FromPublicKey(pubk), pubk), nil
}
// ProviderFromPrivateKey returns a provider for the did:key DID derived from
// the public half of privk. The error is always nil.
func ProviderFromPrivateKey(privk crypto.PrivKey) (Provider, error) {
	id := FromPublicKey(privk.GetPublic())
	return NewProvider(id, privk), nil
}
// Note: this code originated in https://github.com/ucan-wg/go-ucan/blob/main/didkey/key.go
// Copyright applies; some superficial modifications by vyzo.
const (
// Key-type codes written as a varint in front of the raw public key bytes
// by FormatKeyURI and read back by ParseKeyURI.
multicodecKindEd25519PubKey uint64 = 0xed
multicodecKindSecp256k1PubKey uint64 = 0xe7
multicodecKindEthPubKey uint64 = 0xef01
// keyPrefix is the scheme and method part of a did:key URI.
keyPrefix = "did:key"
)
// FormatKeyURI encodes pubk as a did:key URI.
//
// The raw key bytes are prefixed with the varint-encoded key-type code,
// base58btc multibase encoded and appended to "did:key:". It returns "" when
// the raw key cannot be obtained, the key type is unsupported, or the
// multibase encoding fails.
func FormatKeyURI(pubk crypto.PubKey) string {
	raw, err := pubk.Raw()
	if err != nil {
		return ""
	}

	var t uint64
	switch pubk.Type() {
	case crypto.Ed25519:
		t = multicodecKindEd25519PubKey
	case crypto.Secp256k1:
		t = multicodecKindSecp256k1PubKey
	case crypto.Eth:
		t = multicodecKindEthPubKey
	default:
		// we don't support those yet
		// Log the key's actual type: t is still zero in this branch.
		log.Errorf("unsupported key type: %v", pubk.Type())
		return ""
	}

	size := varint.UvarintSize(t)
	data := make([]byte, size+len(raw))
	n := varint.PutUvarint(data, t)
	copy(data[n:], raw)

	b58BKeyStr, err := mb.Encode(mb.Base58BTC, data)
	if err != nil {
		return ""
	}

	return fmt.Sprintf("%s:%s", keyPrefix, b58BKeyStr)
}
// ParseKeyURI decodes the public key embedded in a did:key URI.
//
// The part after "did:key:" must be base58btc multibase data consisting of a
// varint key-type code followed by the raw key bytes. Unknown key-type codes
// yield ErrInvalidKeyType.
func ParseKeyURI(uri string) (crypto.PubKey, error) {
	// Require the full "did:key:" prefix, separator included, so that inputs
	// such as "did:keyfoo" are rejected here rather than handed to the
	// multibase decoder untrimmed.
	encoded, ok := strings.CutPrefix(uri, keyPrefix+":")
	if !ok {
		return nil, fmt.Errorf("decentralized identifier is not a 'key' type")
	}

	enc, data, err := mb.Decode(encoded)
	if err != nil {
		return nil, fmt.Errorf("decoding multibase: %w", err)
	}

	if enc != mb.Base58BTC {
		return nil, fmt.Errorf("unexpected multibase encoding: %s", mb.EncodingToStr[enc])
	}

	keyType, n, err := varint.FromUvarint(data)
	if err != nil {
		return nil, fmt.Errorf("decoding key type varint: %w", err)
	}

	switch keyType {
	case multicodecKindEd25519PubKey:
		return libp2p_crypto.UnmarshalEd25519PublicKey(data[n:])
	case multicodecKindSecp256k1PubKey:
		return libp2p_crypto.UnmarshalSecp256k1PublicKey(data[n:])
	case multicodecKindEthPubKey:
		return crypto.UnmarshalEthPublicKey(data[n:])
	default:
		return nil, ErrInvalidKeyType
	}
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package did
import (
"encoding/hex"
"encoding/json"
"fmt"
"os"
"os/exec"
"github.com/decred/dcrd/dcrec/secp256k1/v4"
"github.com/decred/dcrd/dcrec/secp256k1/v4/ecdsa"
"gitlab.com/nunet/device-management-service/lib/crypto"
)
// ledgerCLI is the name of the external executable that ledgerExec looks up
// in PATH.
const ledgerCLI = "ledger-cli"
// LedgerWalletProvider is a Provider that signs by invoking the ledger CLI.
type LedgerWalletProvider struct {
did DID // DID derived from pubk
pubk crypto.PubKey // public key reported by the ledger CLI
acct int // account index passed to the CLI with -a
}
// Compile-time check that *LedgerWalletProvider implements Provider.
var _ Provider = (*LedgerWalletProvider)(nil)
// LedgerKeyOutput is the JSON document written by the ledger CLI "key"
// command.
type LedgerKeyOutput struct {
Key string `json:"key"` // hex-encoded public key
Address string `json:"address"`
}
// LedgerSignOutput is the JSON document written by the ledger CLI "sign"
// command.
type LedgerSignOutput struct {
ECDSA LedgerSignECDSAOutput `json:"ecdsa"`
}
// LedgerSignECDSAOutput carries the components of an ECDSA signature as
// reported by the ledger CLI.
type LedgerSignECDSAOutput struct {
V uint `json:"v"` // not used by LedgerWalletProvider.Sign
R string `json:"r"` // hex-encoded r
S string `json:"s"` // hex-encoded s
}
// NewLedgerWalletProvider creates a provider for the ledger account acct.
//
// It runs the ledger CLI "key" command, which writes its result to a
// temporary file, decodes the hex public key from that output and derives
// the DID from it. The temporary file is removed before returning.
func NewLedgerWalletProvider(acct int) (Provider, error) {
	outPath, err := getLedgerTmpFile()
	if err != nil {
		return nil, err
	}
	defer os.Remove(outPath)

	var keyOut LedgerKeyOutput
	cliArgs := []string{"key", "-o", outPath, "-a", fmt.Sprintf("%d", acct)}
	if err := ledgerExec(outPath, &keyOut, cliArgs...); err != nil {
		return nil, fmt.Errorf("error executing ledger cli: %w", err)
	}

	// decode the hex key
	rawKey, err := hex.DecodeString(keyOut.Key)
	if err != nil {
		return nil, fmt.Errorf("decode ledger key: %w", err)
	}

	pubk, err := crypto.UnmarshalEthPublicKey(rawKey)
	if err != nil {
		return nil, fmt.Errorf("unmarshal ledger raw key: %w", err)
	}

	return &LedgerWalletProvider{
		did:  FromPublicKey(pubk),
		pubk: pubk,
		acct: acct,
	}, nil
}
// ledgerExec runs the ledger CLI with args and decodes the JSON it wrote to
// the file tmp into output.
//
// The CLI's stdout and stderr are connected to those of this process. The
// caller is responsible for passing the same path in args (via -o) and for
// removing the file afterwards.
func ledgerExec(tmp string, output interface{}, args ...string) error {
	binPath, err := exec.LookPath(ledgerCLI)
	if err != nil {
		return fmt.Errorf("can't find %s in PATH: %w", ledgerCLI, err)
	}

	cmd := exec.Command(binPath, args...)
	cmd.Stdout, cmd.Stderr = os.Stdout, os.Stderr
	if err := cmd.Run(); err != nil {
		return fmt.Errorf("get ledger key: %w", err)
	}

	resultFile, err := os.Open(tmp)
	if err != nil {
		return fmt.Errorf("open ledger output: %w", err)
	}
	defer resultFile.Close()

	if err := json.NewDecoder(resultFile).Decode(&output); err != nil {
		return fmt.Errorf("parse ledger output: %w", err)
	}
	return nil
}
// getLedgerTmpFile creates an empty temporary file for the ledger CLI to
// write its output to and returns its path. The file is closed before
// returning; the caller is responsible for removing it.
func getLedgerTmpFile() (string, error) {
	tmp, err := os.CreateTemp("", "ledger.out")
	if err != nil {
		return "", fmt.Errorf("creating temporary file: %w", err)
	}

	tmpPath := tmp.Name()
	if err := tmp.Close(); err != nil {
		// Do not leave a stray file behind when it cannot be handed out;
		// removal is best effort.
		_ = os.Remove(tmpPath)
		return "", fmt.Errorf("closing temporary file: %w", err)
	}

	return tmpPath, nil
}
// DID returns the DID derived from the ledger account's public key.
func (p *LedgerWalletProvider) DID() DID {
	return p.did
}
// Sign signs data with the ledger account.
//
// The hex-encoded data is passed to the ledger CLI "sign" command, which
// writes its result to a temporary file. The hex r and s values from that
// output are turned into scalars and returned as a serialized ECDSA
// signature. The temporary file is removed before returning.
func (p *LedgerWalletProvider) Sign(data []byte) ([]byte, error) {
	outPath, err := getLedgerTmpFile()
	if err != nil {
		return nil, err
	}
	defer os.Remove(outPath)

	cliArgs := []string{
		"sign",
		"-o", outPath,
		"-a", fmt.Sprintf("%d", p.acct),
		hex.EncodeToString(data),
	}
	var signed LedgerSignOutput
	if err := ledgerExec(outPath, &signed, cliArgs...); err != nil {
		return nil, fmt.Errorf("error executing ledger cli: %w", err)
	}

	rRaw, err := hex.DecodeString(signed.ECDSA.R)
	if err != nil {
		return nil, fmt.Errorf("error decoding signature r: %w", err)
	}
	sRaw, err := hex.DecodeString(signed.ECDSA.S)
	if err != nil {
		return nil, fmt.Errorf("error decoding signature s: %w", err)
	}

	var r, s secp256k1.ModNScalar
	if r.SetByteSlice(rRaw) {
		return nil, fmt.Errorf("signature r overflowed")
	}
	if s.SetByteSlice(sRaw) {
		return nil, fmt.Errorf("signature s overflowed")
	}

	return ecdsa.NewSignature(&r, &s).Serialize(), nil
}
// Anchor returns a verification anchor for the ledger account's public key.
func (p *LedgerWalletProvider) Anchor() Anchor {
	return NewAnchor(p.did, p.pubk)
}
// PrivateKey always fails with an error wrapping ErrHardwareKey: this
// provider never holds the private key itself.
func (p *LedgerWalletProvider) PrivateKey() (crypto.PrivKey, error) {
	return nil, fmt.Errorf("ledger private key cannot be exported: %w", ErrHardwareKey)
}
//go:build linux
package sys
import (
"fmt"
"kernel.org/pub/linux/libs/security/libcap/cap"
)
// RequiredCaps checks if the required capabilities are set
func RequiredCaps() error {
caps := cap.GetProc()
adminP, err := caps.GetFlag(cap.Permitted, cap.NET_ADMIN)
if err != nil {
return fmt.Errorf("error getting NET_ADMIN flag: %w", err)
}
adminE, err := caps.GetFlag(cap.Effective, cap.NET_ADMIN)
if err != nil {
return fmt.Errorf("error getting NET_ADMIN flag: %w", err)
}
if adminP && adminE {
return nil
}
return fmt.Errorf("required capability NET_ADMIN not set")
}
package sys
import (
"net"
"strings"
"github.com/songgao/water"
)
// TunTapMode selects which kind of virtual network device to create.
type TunTapMode int
const (
// NetTunMode selects a TUN device.
NetTunMode TunTapMode = iota
// NetTapMode selects a TAP device.
NetTapMode
)
// NetInterface is a struct containing the fields necessary
// to configure a system TUN/TAP device. Access the
// internal device through NetInterface.Iface.
type NetInterface struct {
Iface *water.Interface // underlying device
Src string // NOTE(review): presumably the local address; not set by the code visible here — confirm
Dst string // NOTE(review): presumably the peer address; not set by the code visible here — confirm
}
// GetNetInterfaces returns the system's network interfaces as reported by
// net.Interfaces.
func GetNetInterfaces() ([]net.Interface, error) {
	ifaces, err := net.Interfaces()
	return ifaces, err
}
// GetNetInterfaceByName returns the network interface called name, as
// reported by net.InterfaceByName.
func GetNetInterfaceByName(name string) (*net.Interface, error) {
	iface, err := net.InterfaceByName(name)
	return iface, err
}
// GetUsedAddresses returns the textual form of every address assigned to a
// network interface, skipping those that contain a colon (which leaves out
// IPv6 addresses). It fails when the interfaces or any interface's
// addresses cannot be listed.
func GetUsedAddresses() ([]string, error) {
	ifaces, err := GetNetInterfaces()
	if err != nil {
		return nil, err
	}

	var networks []string
	for i := range ifaces {
		addrs, err := ifaces[i].Addrs()
		if err != nil {
			return nil, err
		}
		for _, a := range addrs {
			text := a.String()
			if strings.Contains(text, ":") {
				continue
			}
			networks = append(networks, text)
		}
	}
	return networks, nil
}
//go:build linux
package sys
import (
"fmt"
"net"
"os"
"os/exec"
"syscall"
"github.com/songgao/water"
"github.com/vishvananda/netlink"
"golang.org/x/sys/unix"
)
// NewTunTapInterface creates a new tun/tap interface called name.
//
// mode selects TUN for NetTunMode and TAP for anything else; persist is
// forwarded to the device's platform-specific Persist option.
func NewTunTapInterface(name string, mode TunTapMode, persist bool) (*NetInterface, error) {
	deviceType := water.DeviceType(water.TAP)
	if mode == NetTunMode {
		deviceType = water.TUN
	}

	config := water.Config{DeviceType: deviceType}
	config.Name = name
	config.PlatformSpecificParams.Persist = persist

	iface, err := water.New(config)
	if err != nil {
		return nil, fmt.Errorf("error creating interface: %w", err)
	}
	return &NetInterface{Iface: iface}, nil
}
// Up brings the network interface up.
func (n *NetInterface) Up() error {
	link, err := netlink.LinkByName(n.Iface.Name())
	if err != nil {
		return fmt.Errorf("error getting network interface by name: %w", err)
	}

	err = netlink.LinkSetUp(link)
	if err != nil {
		return fmt.Errorf("error setting network interface up: %w", err)
	}
	return nil
}
// Down brings the network interface down.
func (n *NetInterface) Down() error {
	link, err := netlink.LinkByName(n.Iface.Name())
	if err != nil {
		return fmt.Errorf("error getting network interface by name: %w", err)
	}

	err = netlink.LinkSetDown(link)
	if err != nil {
		return fmt.Errorf("error setting network interface down: %w", err)
	}
	return nil
}
// Delete deletes the network interface.
func (n *NetInterface) Delete() error {
	link, err := netlink.LinkByName(n.Iface.Name())
	if err != nil {
		return fmt.Errorf("error getting network interface by name: %w", err)
	}

	err = netlink.LinkDel(link)
	if err != nil {
		return fmt.Errorf("error deleting network interface: %w", err)
	}
	return nil
}
// SetMTU sets the MTU of the network interface.
func (n *NetInterface) SetMTU(mtu int) error {
	link, err := netlink.LinkByName(n.Iface.Name())
	if err != nil {
		return fmt.Errorf("error getting network interface by name: %w", err)
	}

	err = netlink.LinkSetMTU(link, mtu)
	if err != nil {
		return fmt.Errorf("error setting network interface mtu: %w", err)
	}
	return nil
}
// SetAddress adds address, given in CIDR notation, to the network interface.
// The address is parsed before the interface is looked up.
func (n *NetInterface) SetAddress(address string) error {
	parsed, err := netlink.ParseAddr(address)
	if err != nil {
		return fmt.Errorf("error parsing address: %w", err)
	}

	link, err := netlink.LinkByName(n.Iface.Name())
	if err != nil {
		return fmt.Errorf("error getting network interface by name: %w", err)
	}

	err = netlink.AddrAdd(link, parsed)
	if err != nil {
		return fmt.Errorf("error setting network interface address: %w", err)
	}
	return nil
}
// AddRoute adds a route through the network interface, using route (a single
// IP address in text form) as the gateway, with priority 3000.
//
// NOTE(review): route is parsed as a plain IP and used as the gateway here,
// while DelRoute parses its argument as a CIDR destination — confirm the two
// are meant to pair up.
func (n *NetInterface) AddRoute(route string) error {
	link, err := netlink.LinkByName(n.Iface.Name())
	if err != nil {
		return fmt.Errorf("error getting network interface by name: %w", err)
	}

	var gateway net.IP
	if err := gateway.UnmarshalText([]byte(route)); err != nil {
		return fmt.Errorf("error parsing route: %w", err)
	}

	newRoute := &netlink.Route{
		LinkIndex: link.Attrs().Index,
		Gw:        gateway,
		Priority:  3000,
	}
	if err := netlink.RouteAdd(newRoute); err != nil {
		return fmt.Errorf("error adding route: %w", err)
	}
	return nil
}
// DelRoute deletes the route to the destination network route (CIDR
// notation) with priority 3000 from the network interface.
func (n *NetInterface) DelRoute(route string) error {
	link, err := netlink.LinkByName(n.Iface.Name())
	if err != nil {
		return fmt.Errorf("error getting network interface by name: %w", err)
	}

	netRoute, err := netlink.ParseIPNet(route)
	if err != nil {
		return fmt.Errorf("error parsing route: %w", err)
	}

	err = netlink.RouteDel(&netlink.Route{
		LinkIndex: link.Attrs().Index,
		Dst:       netRoute,
		Priority:  3000,
	})
	if err != nil {
		// Wrap with context like every other method on NetInterface.
		return fmt.Errorf("error deleting route: %w", err)
	}
	return nil
}
// ForwardingEnabled checks if IP forwarding is enabled by reading
// /proc/sys/net/ipv4/ip_forward. It reports true only when the file content
// is exactly "1\n".
func ForwardingEnabled() (bool, error) {
	data, err := os.ReadFile("/proc/sys/net/ipv4/ip_forward")
	if err != nil {
		// %w keeps the cause (e.g. fs.ErrNotExist) visible to errors.Is.
		return false, fmt.Errorf("error reading ip_forward: %w", err)
	}
	return string(data) == "1\n", nil
}
// AddDNATRule adds a DNAT rule to the iptables PREROUTING chain that
// redirects protocol traffic for sourceIP:sourcePort to destIP:destPort.
// Nothing is done when the rule already exists.
func AddDNATRule(protocol, sourceIP, sourcePort, destIP, destPort string) error {
	rule := []string{
		"PREROUTING", "-t", "nat",
		"-d", sourceIP, "-p", protocol,
		"--dport", sourcePort, "-j", "DNAT",
		"--to-destination", destIP + ":" + destPort,
	}

	if iptRuleExist(rule...) {
		return nil
	}
	if err := iptAppendRule(rule...); err != nil {
		return fmt.Errorf("error adding DNAT rule: %w", err)
	}
	return nil
}
// DelDNATRule deletes a DNAT rule from the iptables PREROUTING chain if it
// exists; a missing rule is not an error.
func DelDNATRule(protocol, sourceIP, sourcePort, destIP, destPort string) error {
	rule := []string{
		"PREROUTING", "-t", "nat",
		"-d", sourceIP, "-p", protocol,
		"--dport", sourcePort, "-j", "DNAT",
		"--to-destination", destIP + ":" + destPort,
	}

	if !iptRuleExist(rule...) {
		return nil
	}
	if err := iptDeleteRule(rule...); err != nil {
		return fmt.Errorf("error deleting DNAT rule: %w", err)
	}
	return nil
}
// AddForwardRule adds an ACCEPT rule for protocol traffic to destIP:destPort
// to the iptables FORWARD chain. Nothing is done when the rule already
// exists.
func AddForwardRule(protocol, destIP, destPort string) error {
	rule := []string{
		"FORWARD", "-t", "filter",
		"-p", protocol, "-d", destIP,
		"--dport", destPort, "-j", "ACCEPT",
	}

	if iptRuleExist(rule...) {
		return nil
	}
	if err := iptAppendRule(rule...); err != nil {
		return fmt.Errorf("error adding forward rule: %w", err)
	}
	return nil
}
// DelForwardRule deletes the ACCEPT rule for protocol traffic to
// destIP:destPort from the iptables FORWARD chain if it exists; a missing
// rule is not an error.
func DelForwardRule(protocol, destIP, destPort string) error {
	args := []string{
		"FORWARD", "-t", "filter",
		"-p", protocol, "-d", destIP,
		"--dport", destPort, "-j", "ACCEPT",
	}

	if iptRuleExist(args...) {
		err := iptDeleteRule(args...)
		if err != nil {
			// The message previously read "error adding froward rule",
			// which misreported the failing operation.
			return fmt.Errorf("error deleting forward rule: %w", err)
		}
	}

	return nil
}
// AddForwardIntRule adds an ACCEPT rule for traffic entering on inInt and
// leaving on outInt to the iptables FORWARD chain. Nothing is done when the
// rule already exists.
func AddForwardIntRule(inInt, outInt string) error {
	rule := []string{
		"FORWARD", "-t", "filter",
		"-i", inInt, "-o", outInt, "-j", "ACCEPT",
	}

	if iptRuleExist(rule...) {
		return nil
	}
	if err := iptAppendRule(rule...); err != nil {
		return fmt.Errorf("error adding interface forward rule: %w", err)
	}
	return nil
}
// DelForwardIntRule deletes the ACCEPT rule for traffic entering on inInt
// and leaving on outInt from the iptables FORWARD chain if it exists; a
// missing rule is not an error.
func DelForwardIntRule(inInt, outInt string) error {
	rule := []string{
		"FORWARD", "-t", "filter",
		"-i", inInt, "-o", outInt, "-j", "ACCEPT",
	}

	if !iptRuleExist(rule...) {
		return nil
	}
	if err := iptDeleteRule(rule...); err != nil {
		return fmt.Errorf("error deleting interface forward rule: %w", err)
	}
	return nil
}
// AddMasqueradeRule adds a MASQUERADE rule to the iptables POSTROUTING chain
// of the nat table. Nothing is done when the rule already exists.
func AddMasqueradeRule() error {
	rule := []string{"POSTROUTING", "-t", "nat", "-j", "MASQUERADE"}

	if iptRuleExist(rule...) {
		return nil
	}
	if err := iptAppendRule(rule...); err != nil {
		return fmt.Errorf("error adding rule: %w", err)
	}
	return nil
}
// DelMasqueradeRule deletes the MASQUERADE rule from the iptables
// POSTROUTING chain of the nat table if it exists; a missing rule is not an
// error.
func DelMasqueradeRule() error {
	rule := []string{"POSTROUTING", "-t", "nat", "-j", "MASQUERADE"}

	if !iptRuleExist(rule...) {
		return nil
	}
	if err := iptDeleteRule(rule...); err != nil {
		return fmt.Errorf("error deleting masquerade rule: %w", err)
	}
	return nil
}
// AddRelEstRule adds an ACCEPT rule for RELATED,ESTABLISHED connections to
// the given chain of the iptables filter table. Nothing is done when the
// rule already exists.
func AddRelEstRule(chain string) error {
	rule := []string{
		chain, "-t", "filter",
		"-m", "conntrack", "--ctstate",
		"RELATED,ESTABLISHED", "-j", "ACCEPT",
	}

	if iptRuleExist(rule...) {
		return nil
	}
	if err := iptAppendRule(rule...); err != nil {
		return fmt.Errorf("error adding related,established rule: %w", err)
	}
	return nil
}
// DelRelEstRule deletes the ACCEPT rule for RELATED,ESTABLISHED connections
// from the given chain of the iptables filter table if it exists; a missing
// rule is not an error.
func DelRelEstRule(chain string) error {
	rule := []string{
		chain, "-t", "filter",
		"-m", "conntrack", "--ctstate",
		"RELATED,ESTABLISHED", "-j", "ACCEPT",
	}

	if !iptRuleExist(rule...) {
		return nil
	}
	if err := iptDeleteRule(rule...); err != nil {
		return fmt.Errorf("error deleting related,established rule: %w", err)
	}
	return nil
}
// iptDeleteRule runs "iptables -D" with the given rule specification (chain
// first). On failure the command output is included in the error.
func iptDeleteRule(rule ...string) error {
	args := make([]string, 0, len(rule)+1)
	args = append(args, "-D")
	args = append(args, rule...)

	if out, err := execCmd("iptables", args); err != nil {
		return fmt.Errorf("error deleting rule: %w, output: %s", err, out)
	}
	return nil
}
// iptAppendRule runs "iptables -A" with the given rule specification (chain
// first). On failure the command output is included in the error.
func iptAppendRule(rule ...string) error {
	args := make([]string, 0, len(rule)+1)
	args = append(args, "-A")
	args = append(args, rule...)

	if out, err := execCmd("iptables", args); err != nil {
		return fmt.Errorf("error appending rule: %w, output: %s", err, out)
	}
	return nil
}
// iptRuleExist reports whether "iptables -C" succeeds for the given rule
// specification. Any failure of the command counts as "does not exist".
func iptRuleExist(rule ...string) bool {
	args := make([]string, 0, len(rule)+1)
	args = append(args, "-C")
	args = append(args, rule...)

	if _, err := execCmd("iptables", args); err != nil {
		return false
	}
	return true
}
// execCmd runs command with args, passing CAP_NET_ADMIN to the child as an
// ambient capability, and returns its combined stdout and stderr. The output
// is returned even when the command fails.
func execCmd(command string, args []string) (string, error) {
	cmd := exec.Command(command, args...)
	cmd.SysProcAttr = &syscall.SysProcAttr{
		AmbientCaps: []uintptr{unix.CAP_NET_ADMIN},
	}

	raw, err := cmd.CombinedOutput()
	out := string(raw)
	if err != nil {
		return out, fmt.Errorf("failed to execute command: %q: %w", command, err)
	}
	return out, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package ucan
import (
"strings"
)
// Capability is a slash-separated capability path.
type Capability string

// Root is the capability that implies every other capability.
const Root = Capability("/")

// Implies reports whether holding c grants other: this is the case when the
// two are equal, when c is Root, or when other lies below c, i.e. starts
// with c followed by "/".
func (c Capability) Implies(other Capability) bool {
	switch c {
	case other, Root:
		return true
	}
	return strings.HasPrefix(string(other), string(c)+"/")
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package ucan
import (
"bytes"
"context"
"crypto/rand"
"encoding/json"
"fmt"
"slices"
"sync"
"time"
"gitlab.com/nunet/device-management-service/lib/crypto"
"gitlab.com/nunet/device-management-service/lib/did"
)
const (
// maxCapabilitySize is a size limit of 16384.
// NOTE(review): presumably the maximum accepted size in bytes of a
// capability token blob — confirm against its users.
maxCapabilitySize = 16384
// NOTE(review): iota counts the const specs from the top of this block and
// maxCapabilitySize occupies index 0, so SelfSignNo == 1, SelfSignAlso == 2
// and SelfSignOnly == 3. The zero value of SelfSignMode therefore matches
// none of them — confirm this is intended before relying on it.
SelfSignNo SelfSignMode = iota
SelfSignAlso
SelfSignOnly
)
// SelfSignMode is the self-signing option accepted by the Delegate*
// methods of CapabilityContext.
type SelfSignMode int
// CapabilityContext manages capability tokens on behalf of a controlling
// DID: it consumes and checks tokens presented by others, and creates
// tokens that prove or delegate authority.
type CapabilityContext interface {
// DID returns the context's controlling DID
DID() did.DID
// Trust returns the context's did trust context
Trust() did.TrustContext
// Consume ingests the provided capability tokens
Consume(origin did.DID, cap []byte) error
// Discard discards previously consumed capability tokens
Discard(cap []byte)
// Require ensures that at least one of the capabilities is delegated from
// the subject to the audience, with an appropriate anchor.
// An empty list means that no capabilities are required, which is
// vacuously true.
Require(anchor did.DID, subject crypto.ID, audience crypto.ID, require []Capability) error
// RequireBroadcast ensures that at least one of the capabilities is delegated
// to the subject for the specified broadcast topic
RequireBroadcast(origin did.DID, subject crypto.ID, topic string, require []Capability) error
// Provide prepares the appropriate capability tokens to prove and delegate authority
// to a subject for an audience.
// - It delegates invocations to the subject with an audience and invoke capabilities
// - It delegates the delegate capabilities to the target with audience the subject
Provide(target did.DID, subject crypto.ID, audience crypto.ID, expire uint64, invoke []Capability, delegate []Capability) ([]byte, error)
// ProvideBroadcast prepares the appropriate capability tokens to prove authority
// to broadcast to a topic
ProvideBroadcast(subject crypto.ID, topic string, expire uint64, broadcast []Capability) ([]byte, error)
// AddRoots adds trust anchors
AddRoots(trust []did.DID, require, provide TokenList) error
// ListRoots lists the current trust anchors
ListRoots() ([]did.DID, TokenList, TokenList)
// RemoveRoots removes the specified trust anchors
RemoveRoots(trust []did.DID, require, provide TokenList)
// Delegate creates the appropriate delegation tokens anchored in our roots
Delegate(subject, audience did.DID, topics []string, expire, depth uint64, cap []Capability, selfSign SelfSignMode) (TokenList, error)
// DelegateInvocation creates the appropriate invocation tokens anchored in anchor
DelegateInvocation(target, subject, audience did.DID, expire uint64, provide []Capability, selfSign SelfSignMode) (TokenList, error)
// DelegateBroadcast creates the appropriate broadcast token anchored in our roots
DelegateBroadcast(subject did.DID, topic string, expire uint64, provide []Capability, selfSign SelfSignMode) (TokenList, error)
// Grant creates the appropriate delegation tokens considering ourselves as the root
Grant(action Action, subject, audience did.DID, topic []string, expire, depth uint64, provide []Capability) (TokenList, error)
// Start starts a token garbage collector goroutine that clears expired tokens
Start(gcInterval time.Duration)
// Stop stops a previously started gc goroutine
Stop()
}
// BasicCapabilityContext is the in-memory CapabilityContext implementation.
// The roots, require, provide and tokens maps are accessed under mx.
type BasicCapabilityContext struct {
mx sync.Mutex // guards the maps below
provider did.Provider // signs tokens we issue and supplies our DID
trust did.TrustContext // used to resolve providers and verify tokens
roots map[did.DID]struct{} // our root anchors of trust
require map[did.DID][]*Token // our acceptance side-roots; keyed by token issuer
provide map[did.DID][]*Token // root capabilities -> tokens; keyed by token issuer
tokens map[did.DID][]*Token // our context dependent capabilities; subject -> tokens
stop func() // cancels the gc goroutine; set by Start (not guarded by mx)
}
// Compile-time check that BasicCapabilityContext implements CapabilityContext.
var _ CapabilityContext = (*BasicCapabilityContext)(nil)
// NewCapabilityContext builds a BasicCapabilityContext controlled by ctxDID.
//
// The signing provider for ctxDID is resolved through trust, and the initial
// roots and require/provide tokens are installed with AddRoots. Any failure
// is wrapped with "new capability context".
func NewCapabilityContext(trust did.TrustContext, ctxDID did.DID, roots []did.DID, require, provide TokenList) (CapabilityContext, error) {
	provider, err := trust.GetProvider(ctxDID)
	if err != nil {
		return nil, fmt.Errorf("new capability context: %w", err)
	}

	result := &BasicCapabilityContext{
		provider: provider,
		trust:    trust,
		roots:    map[did.DID]struct{}{},
		require:  map[did.DID][]*Token{},
		provide:  map[did.DID][]*Token{},
		tokens:   map[did.DID][]*Token{},
	}

	if err = result.AddRoots(roots, require, provide); err != nil {
		return nil, fmt.Errorf("new capability context: %w", err)
	}

	return result, nil
}
// DID returns the DID of the provider that controls this context.
func (ctx *BasicCapabilityContext) DID() did.DID {
	provider := ctx.provider
	return provider.DID()
}
// Trust returns the trust context this capability context was created with.
func (ctx *BasicCapabilityContext) Trust() did.TrustContext {
	trust := ctx.trust
	return trust
}
// Start launches the background goroutine that periodically removes expired
// tokens, running one sweep every gcInterval.
//
// It is a no-op when the collector has already been started. The goroutine
// runs until Stop is called. gcInterval must be positive: the collector
// passes it to time.NewTicker, which panics otherwise.
//
// The previous condition was inverted (it only started the collector when
// ctx.stop was already set), so the collector never ran on a fresh context.
func (ctx *BasicCapabilityContext) Start(gcInterval time.Duration) {
	if ctx.stop != nil {
		// already running
		return
	}

	gcCtx, cancel := context.WithCancel(context.Background())
	ctx.stop = cancel
	go ctx.gc(gcCtx, gcInterval)
}
// Stop cancels the garbage collector goroutine if one was registered by Start.
// It does nothing when no stop function is set.
func (ctx *BasicCapabilityContext) Stop() {
	cancel := ctx.stop
	if cancel == nil {
		return
	}
	cancel()
}
// AddRoots registers the given root anchors and then ingests the require and
// provide tokens, in that order.
//
// Every token is verified against the trust context at the current time before
// it is stored. The first verification failure aborts the call and is returned
// wrapped with "verify token"; roots and tokens processed before the failure
// remain installed.
func (ctx *BasicCapabilityContext) AddRoots(roots []did.DID, require, provide TokenList) error {
	ctx.addRoots(roots)

	timestamp := uint64(time.Now().UnixNano())

	// ingest verifies each token and hands it to consume.
	ingest := func(tokens []*Token, consume func(*Token)) error {
		for _, token := range tokens {
			if err := token.Verify(ctx.trust, timestamp); err != nil {
				return fmt.Errorf("verify token: %w", err)
			}
			consume(token)
		}
		return nil
	}

	if err := ingest(require.Tokens, ctx.consumeRequireToken); err != nil {
		return err
	}
	return ingest(provide.Tokens, ctx.consumeProvideToken)
}
// ListRoots returns the current root anchors together with all require tokens
// and all provide tokens, each flattened across their anchors.
func (ctx *BasicCapabilityContext) ListRoots() ([]did.DID, TokenList, TokenList) {
	anchors := ctx.getRoots()

	var requireTokens []*Token
	for _, issuer := range ctx.getRequireAnchors() {
		requireTokens = append(requireTokens, ctx.getRequireTokens(issuer)...)
	}

	var provideTokens []*Token
	for _, issuer := range ctx.getProvideAnchors() {
		provideTokens = append(provideTokens, ctx.getProvideTokens(issuer)...)
	}

	return anchors, TokenList{Tokens: requireTokens}, TokenList{Tokens: provideTokens}
}
// RemoveRoots removes the given root anchors and the given require/provide
// tokens from the context.
//
// Tokens are matched by nonce within the list stored under their issuer. An
// issuer whose list becomes empty is removed from the map.
//
// The stored slice is cloned before filtering: getTokens returns the stored
// slice to callers that read it without holding the lock, and
// slices.DeleteFunc modifies its argument in place. This matches what
// discardTokens and gcTokens already do.
func (ctx *BasicCapabilityContext) RemoveRoots(trust []did.DID, require, provide TokenList) {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	for _, root := range trust {
		delete(ctx.roots, root)
	}

	// remove drops every token in tokens from the per-issuer lists in store.
	remove := func(store map[did.DID][]*Token, tokens []*Token) {
		for _, t := range tokens {
			issuer := t.Issuer()
			stored, ok := store[issuer]
			if !ok {
				continue
			}

			nonce := t.Nonce()
			remaining := slices.DeleteFunc(slices.Clone(stored), func(ot *Token) bool {
				return bytes.Equal(nonce, ot.Nonce())
			})

			if len(remaining) == 0 {
				delete(store, issuer)
				continue
			}
			store[issuer] = remaining
		}
	}

	remove(ctx.require, require.Tokens)
	remove(ctx.provide, provide.Tokens)
}
// Grant issues a single self-signed DMS token with this context's DID as the
// issuer and no chain.
//
// The token carries the given action, subject, audience, topics, capabilities,
// expiration and depth, plus a fresh random nonce. It is signed with the
// context's provider and returned as a one-element TokenList.
func (ctx *BasicCapabilityContext) Grant(action Action, subject, audience did.DID, topics []string, expire, depth uint64, provide []Capability) (TokenList, error) {
	nonce := make([]byte, nonceLength)
	if _, err := rand.Read(nonce); err != nil {
		return TokenList{}, fmt.Errorf("nonce: %w", err)
	}

	topicCap := make([]Capability, len(topics))
	for i, topic := range topics {
		topicCap[i] = Capability(topic)
	}

	token := &DMSToken{
		Action:     action,
		Issuer:     ctx.DID(),
		Subject:    subject,
		Audience:   audience,
		Topic:      topicCap,
		Capability: provide,
		Nonce:      nonce,
		Expire:     expire,
		Depth:      depth,
	}

	payload, err := token.SignatureData()
	if err != nil {
		return TokenList{}, fmt.Errorf("grant: %w", err)
	}

	signature, err := ctx.provider.Sign(payload)
	if err != nil {
		return TokenList{}, fmt.Errorf("sign: %w", err)
	}
	token.Signature = signature

	return TokenList{Tokens: []*Token{{DMS: token}}}, nil
}
// Delegate issues delegation tokens for the provide capabilities to subject,
// with the given audience, topics, expiration and depth.
//
// An empty provide list yields an empty TokenList and no error. Unless
// selfSign is SelfSignOnly, every provide token that is anchored in its trust
// anchor and allows the delegation is used as a chain parent; an expire of 0
// inherits the parent's expiration. With SelfSignNo only those chained tokens
// are returned (ErrNotAuthorized when there are none). In every other mode a
// self-signed token from Grant is appended.
func (ctx *BasicCapabilityContext) Delegate(subject, audience did.DID, topics []string, expire, depth uint64, provide []Capability, selfSign SelfSignMode) (TokenList, error) {
	if len(provide) == 0 {
		return TokenList{}, nil
	}

	topicCap := make([]Capability, 0, len(topics))
	for _, topic := range topics {
		topicCap = append(topicCap, Capability(topic))
	}

	var issued []*Token

	if selfSign != SelfSignOnly {
		for _, trustAnchor := range ctx.getProvideAnchors() {
			for _, parent := range ctx.getProvideTokens(trustAnchor) {
				deadline := expire
				if deadline == 0 {
					deadline = parent.Expire()
				}

				// keep only the capabilities this parent may pass on
				var providing []Capability
				for _, c := range provide {
					if parent.Anchor(trustAnchor) && parent.AllowDelegation(Delegate, ctx.DID(), audience, topicCap, deadline, c) {
						providing = append(providing, c)
					}
				}
				if len(providing) == 0 {
					continue
				}

				token, err := parent.Delegate(ctx.provider, subject, audience, topicCap, deadline, depth, providing)
				if err != nil {
					log.Debugf("error delegating %s to %s: %s", providing, subject, err)
					continue
				}
				issued = append(issued, token)
			}
		}

		if selfSign == SelfSignNo {
			if len(issued) == 0 {
				return TokenList{}, ErrNotAuthorized
			}
			return TokenList{Tokens: issued}, nil
		}
	}

	granted, err := ctx.Grant(Delegate, subject, audience, topics, expire, depth, provide)
	if err != nil {
		return TokenList{}, fmt.Errorf("error granting invocation: %w", err)
	}
	issued = append(issued, granted.Tokens...)

	return TokenList{Tokens: issued}, nil
}
// DelegateInvocation issues invocation tokens for the provide capabilities to
// subject for the given audience.
//
// An empty provide list yields an empty TokenList and no error. Tokens the
// context holds about itself are always tried first, anchored in target, in
// every mode. Unless selfSign is SelfSignOnly, tokens chained on the provide
// anchors are added next. With SelfSignNo the collected tokens are returned
// (ErrNotAuthorized when there are none). In every other mode a self-signed
// Invoke token from Grant is appended.
func (ctx *BasicCapabilityContext) DelegateInvocation(target, subject, audience did.DID, expire uint64, provide []Capability, selfSign SelfSignMode) (TokenList, error) {
	if len(provide) == 0 {
		return TokenList{}, nil
	}

	var issued []*Token

	// tokens we hold about ourselves that allow delegating to the subject
	ownTokens := ctx.getSubjectTokens(ctx.DID())
	issued = append(issued, ctx.delegateInvocation(ownTokens, target, subject, audience, expire, provide)...)

	if selfSign != SelfSignOnly {
		// tokens chained on our provide anchors
		for _, trustAnchor := range ctx.getProvideAnchors() {
			anchored := ctx.getProvideTokens(trustAnchor)
			issued = append(issued, ctx.delegateInvocation(anchored, trustAnchor, subject, audience, expire, provide)...)
		}

		if selfSign == SelfSignNo {
			if len(issued) == 0 {
				return TokenList{}, ErrNotAuthorized
			}
			return TokenList{Tokens: issued}, nil
		}
	}

	selfTokens, err := ctx.Grant(Invoke, subject, audience, nil, expire, 0, provide)
	if err != nil {
		return TokenList{}, fmt.Errorf("error granting invocation: %w", err)
	}
	issued = append(issued, selfTokens.Tokens...)

	return TokenList{Tokens: issued}, nil
}
// delegateInvocation derives an invocation token from every token in tokenList
// that is anchored in anchor and allows delegating at least one of the provide
// capabilities. Parents whose delegation fails are logged at debug level and
// skipped.
func (ctx *BasicCapabilityContext) delegateInvocation(tokenList []*Token, anchor, subject, audience did.DID, expire uint64, provide []Capability) []*Token {
	var issued []*Token //nolint

	for _, parent := range tokenList {
		var providing []Capability
		for _, wanted := range provide {
			if parent.Anchor(anchor) && parent.AllowDelegation(Invoke, ctx.DID(), audience, nil, expire, wanted) {
				providing = append(providing, wanted)
			}
		}

		if len(providing) > 0 {
			child, err := parent.DelegateInvocation(ctx.provider, subject, audience, expire, providing)
			if err != nil {
				log.Debugf("error delegating invocation %s to %s: %s", providing, subject, err)
				continue
			}
			issued = append(issued, child)
		}
	}

	return issued
}
// DelegateBroadcast issues broadcast tokens for the provide capabilities to
// subject on the given topic.
//
// An empty provide list yields an empty TokenList and no error. Unless
// selfSign is SelfSignOnly, tokens chained on the provide anchors are issued
// first. With SelfSignNo only those are returned (ErrNotAuthorized when there
// are none). In every other mode a self-signed Broadcast token from Grant,
// with an empty audience, is appended.
func (ctx *BasicCapabilityContext) DelegateBroadcast(subject did.DID, topic string, expire uint64, provide []Capability, selfSign SelfSignMode) (TokenList, error) {
	if len(provide) == 0 {
		return TokenList{}, nil
	}

	var issued []*Token

	if selfSign != SelfSignOnly {
		// tokens chained on our provide anchors
		for _, trustAnchor := range ctx.getProvideAnchors() {
			anchored := ctx.getProvideTokens(trustAnchor)
			issued = append(issued, ctx.delegateBroadcast(anchored, trustAnchor, subject, topic, expire, provide)...)
		}

		if selfSign == SelfSignNo {
			if len(issued) == 0 {
				return TokenList{}, ErrNotAuthorized
			}
			return TokenList{Tokens: issued}, nil
		}
	}

	selfTokens, err := ctx.Grant(Broadcast, subject, did.DID{}, []string{topic}, expire, 0, provide)
	if err != nil {
		return TokenList{}, fmt.Errorf("error granting broadcast: %w", err)
	}
	issued = append(issued, selfTokens.Tokens...)

	return TokenList{Tokens: issued}, nil
}
// delegateBroadcast derives a broadcast token for topic from every token in
// tokenList that is anchored in anchor and allows delegating at least one of
// the provide capabilities. Parents whose delegation fails are logged at debug
// level and skipped.
func (ctx *BasicCapabilityContext) delegateBroadcast(tokenList []*Token, anchor did.DID, subject did.DID, topic string, expire uint64, provide []Capability) []*Token {
	topicCap := Capability(topic)

	var issued []*Token //nolint

	for _, parent := range tokenList {
		var providing []Capability
		for _, wanted := range provide {
			if parent.Anchor(anchor) && parent.AllowDelegation(Broadcast, ctx.DID(), did.DID{}, []Capability{topicCap}, expire, wanted) {
				providing = append(providing, wanted)
			}
		}

		if len(providing) > 0 {
			child, err := parent.DelegateBroadcast(ctx.provider, subject, topicCap, expire, providing)
			if err != nil {
				log.Debugf("error delegating invocation %s to %s: %s", providing, subject, err)
				continue
			}
			issued = append(issued, child)
		}
	}

	return issued
}
// Consume decodes a serialized TokenList received from origin and stores every
// token that is both acceptable and valid under the token's subject.
//
// A token is acceptable when it is anchored in our own DID, in origin, or in
// one of our roots, or when one of our require tokens allows it. Acceptable
// tokens are then verified against the trust context at the current time.
// Tokens that are not acceptable or fail verification are logged and skipped;
// they do not fail the call.
//
// It returns ErrTooBig when data exceeds maxCapabilitySize, and a wrapped
// error when data is not valid JSON.
func (ctx *BasicCapabilityContext) Consume(origin did.DID, data []byte) error {
	if len(data) > maxCapabilitySize {
		return ErrTooBig
	}

	var tokens TokenList
	if err := json.Unmarshal(data, &tokens); err != nil {
		return fmt.Errorf("unmarshaling payload: %w", err)
	}

	rootAnchors := ctx.getRoots()
	requireAnchors := ctx.getRequireAnchors()
	now := uint64(time.Now().UnixNano())

	// acceptable reports whether t chains to an anchor we are willing to trust.
	acceptable := func(t *Token) bool {
		if t.Anchor(ctx.DID()) || t.Anchor(origin) {
			return true
		}
		for _, anchor := range rootAnchors {
			if t.Anchor(anchor) {
				return true
			}
		}
		for _, anchor := range requireAnchors {
			for _, rt := range ctx.getRequireTokens(anchor) {
				if rt.AllowAction(t) {
					return true
				}
			}
		}
		return false
	}

	for _, t := range tokens.Tokens {
		if t == nil {
			// The payload is untrusted. A JSON null in the token array decodes
			// to a nil pointer, which would panic when dereferenced below.
			continue
		}

		if !acceptable(t) {
			log.Debugf("ignoring token %+v", *t)
			continue
		}

		if err := t.Verify(ctx.trust, now); err != nil {
			log.Warnf("failed to verify token issued by %s: %s", t.Issuer(), err)
			continue
		}

		ctx.consumeSubjectToken(t)
	}

	return nil
}
// Discard decodes a serialized TokenList and removes the listed tokens from
// the subject token store. Payloads that fail to decode are silently ignored.
func (ctx *BasicCapabilityContext) Discard(data []byte) {
	var payload TokenList
	if json.Unmarshal(data, &payload) != nil {
		return
	}

	ctx.discardTokens(payload.Tokens)
}
// consumeAnchorToken merges t into the token list exposed by getf/setf while
// holding the context lock.
//
// If an existing token already subsumes t, the list is left untouched.
// Otherwise every existing token that t subsumes is dropped and t is appended.
func (ctx *BasicCapabilityContext) consumeAnchorToken(getf func() []*Token, setf func(result []*Token), t *Token) {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	current := getf()
	kept := make([]*Token, 0, len(current)+1)

	for _, existing := range current {
		switch {
		case existing.Subsumes(t):
			// nothing new to record
			return
		case t.Subsumes(existing):
			// superseded by t; drop it
		default:
			kept = append(kept, existing)
		}
	}

	setf(append(kept, t))
}
// consumeRequireToken merges t into the require tokens stored under its issuer.
func (ctx *BasicCapabilityContext) consumeRequireToken(t *Token) {
	load := func() []*Token {
		return ctx.require[t.Issuer()]
	}
	store := func(merged []*Token) {
		ctx.require[t.Issuer()] = merged
	}

	ctx.consumeAnchorToken(load, store, t)
}
// consumeProvideToken merges t into the provide tokens stored under its issuer.
func (ctx *BasicCapabilityContext) consumeProvideToken(t *Token) {
	load := func() []*Token {
		return ctx.provide[t.Issuer()]
	}
	store := func(merged []*Token) {
		ctx.provide[t.Issuer()] = merged
	}

	ctx.consumeAnchorToken(load, store, t)
}
// consumeSubjectToken appends t to the tokens stored under its subject while
// holding the context lock. No deduplication is performed.
func (ctx *BasicCapabilityContext) consumeSubjectToken(t *Token) {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	key := t.Subject()
	ctx.tokens[key] = append(ctx.tokens[key], t)
}
// Require checks that the subject holds a token allowing it to invoke at least
// one of the listed capabilities on the audience.
//
// A token counts when it allows the invocation and is anchored in the given
// anchor or in one of our roots, or is allowed by one of our require tokens.
// It returns nil on the first match and ErrNotAuthorized otherwise. An empty
// capability list is rejected with ErrNotAuthorized.
func (ctx *BasicCapabilityContext) Require(anchor did.DID, subject crypto.ID, audience crypto.ID, cap []Capability) error {
	if len(cap) == 0 {
		return fmt.Errorf("no capabilities: %w", ErrNotAuthorized)
	}

	subjectDID, err := did.FromID(subject)
	if err != nil {
		return fmt.Errorf("DID for subject: %w", err)
	}

	audienceDID, err := did.FromID(audience)
	if err != nil {
		return fmt.Errorf("DID for audience: %w", err)
	}

	held := ctx.getSubjectTokens(subjectDID)
	rootAnchors := ctx.getRoots()
	sideAnchors := ctx.getRequireAnchors()

	// authorized reports whether t, under any accepted trust source, lets the
	// subject invoke c on the audience.
	authorized := func(t *Token, c Capability) bool {
		if t.Anchor(anchor) && t.AllowInvocation(subjectDID, audienceDID, c) {
			return true
		}
		for _, root := range rootAnchors {
			if t.Anchor(root) && t.AllowInvocation(subjectDID, audienceDID, c) {
				return true
			}
		}
		for _, side := range sideAnchors {
			for _, rt := range ctx.getRequireTokens(side) {
				if rt.AllowAction(t) && t.AllowInvocation(subjectDID, audienceDID, c) {
					return true
				}
			}
		}
		return false
	}

	for _, t := range held {
		for _, c := range cap {
			if authorized(t, c) {
				return nil
			}
		}
	}

	return ErrNotAuthorized
}
// RequireBroadcast checks that the subject holds a token allowing it to
// broadcast at least one of the required capabilities on topic.
//
// A token counts when it allows the broadcast and is anchored in the given
// anchor or in one of our roots, or is allowed by one of our require tokens.
// It returns nil on the first match and ErrNotAuthorized otherwise. An empty
// capability list is rejected with ErrNotAuthorized.
func (ctx *BasicCapabilityContext) RequireBroadcast(anchor did.DID, subject crypto.ID, topic string, require []Capability) error {
	if len(require) == 0 {
		return fmt.Errorf("no capabilities: %w", ErrNotAuthorized)
	}

	subjectDID, err := did.FromID(subject)
	if err != nil {
		return fmt.Errorf("DID for subject: %w", err)
	}

	held := ctx.getSubjectTokens(subjectDID)
	rootAnchors := ctx.getRoots()
	sideAnchors := ctx.getRequireAnchors()
	topicCap := Capability(topic)

	// authorized reports whether t, under any accepted trust source, lets the
	// subject broadcast c on the topic.
	authorized := func(t *Token, c Capability) bool {
		if t.Anchor(anchor) && t.AllowBroadcast(subjectDID, topicCap, c) {
			return true
		}
		for _, root := range rootAnchors {
			if t.Anchor(root) && t.AllowBroadcast(subjectDID, topicCap, c) {
				return true
			}
		}
		for _, side := range sideAnchors {
			for _, rt := range ctx.getRequireTokens(side) {
				if rt.AllowAction(t) && t.AllowBroadcast(subjectDID, topicCap, c) {
					return true
				}
			}
		}
		return false
	}

	for _, t := range held {
		for _, c := range require {
			if authorized(t, c) {
				return nil
			}
		}
	}

	return ErrNotAuthorized
}
// Provide builds the serialized token list that proves and delegates authority
// for an interaction with target.
//
// Invocation tokens for the invoke capabilities are issued to the subject for
// the audience (SelfSignAlso). When provide is non-empty, self-signed
// delegation tokens (depth 1) are issued to target with the subject as
// audience. At least one invoke capability is required; failing to produce any
// token of a needed kind yields ErrNotAuthorized.
func (ctx *BasicCapabilityContext) Provide(target did.DID, subject crypto.ID, audience crypto.ID, expire uint64, invoke []Capability, provide []Capability) ([]byte, error) {
	if len(invoke) == 0 && len(provide) == 0 {
		return nil, fmt.Errorf("no capabilities: %w", ErrNotAuthorized)
	}

	subjectDID, err := did.FromID(subject)
	if err != nil {
		return nil, fmt.Errorf("DID for subject: %w", err)
	}

	audienceDID, err := did.FromID(audience)
	if err != nil {
		return nil, fmt.Errorf("DID for audience: %w", err)
	}

	if len(invoke) == 0 {
		return nil, fmt.Errorf("no invocation capabilities: %w", ErrNotAuthorized)
	}

	invocation, err := ctx.DelegateInvocation(target, subjectDID, audienceDID, expire, invoke, SelfSignAlso)
	if err != nil {
		return nil, fmt.Errorf("cannot provide invocation tokens: %w", err)
	}
	if len(invocation.Tokens) == 0 {
		return nil, fmt.Errorf("cannot provide the necessary invocation tokens: %w", ErrNotAuthorized)
	}

	var collected []*Token
	collected = append(collected, invocation.Tokens...)

	if len(provide) > 0 {
		delegation, err := ctx.Delegate(target, subjectDID, nil, expire, 1, provide, SelfSignOnly)
		if err != nil {
			return nil, fmt.Errorf("cannot provide delegation tokens: %w", err)
		}
		if len(delegation.Tokens) == 0 {
			return nil, fmt.Errorf("cannot provide the necessary delegation tokens: %w", ErrNotAuthorized)
		}
		collected = append(collected, delegation.Tokens...)
	}

	encoded, err := json.Marshal(TokenList{Tokens: collected})
	if err != nil {
		return nil, fmt.Errorf("marshaling payload: %w", err)
	}

	return encoded, nil
}
// ProvideBroadcast builds the serialized token list that proves the subject's
// authority to broadcast the provide capabilities on topic.
//
// Tokens are issued with DelegateBroadcast in SelfSignAlso mode. An empty
// capability list, or an empty result, yields ErrNotAuthorized.
func (ctx *BasicCapabilityContext) ProvideBroadcast(subject crypto.ID, topic string, expire uint64, provide []Capability) ([]byte, error) {
	if len(provide) == 0 {
		return nil, fmt.Errorf("no capabilities: %w", ErrNotAuthorized)
	}

	subjectDID, err := did.FromID(subject)
	if err != nil {
		return nil, fmt.Errorf("DID for subject: %w", err)
	}

	tokens, err := ctx.DelegateBroadcast(subjectDID, topic, expire, provide, SelfSignAlso)
	switch {
	case err != nil:
		return nil, fmt.Errorf("cannot provide broadcast tokens: %w", err)
	case len(tokens.Tokens) == 0:
		return nil, fmt.Errorf("cannot provide the necessary broadcast tokens: %w", ErrNotAuthorized)
	}

	encoded, err := json.Marshal(tokens)
	if err != nil {
		return nil, fmt.Errorf("marshaling payload: %w", err)
	}

	return encoded, nil
}
// getRoots returns a snapshot of the root anchors, taken under the lock.
// The order is unspecified because it follows map iteration.
func (ctx *BasicCapabilityContext) getRoots() []did.DID {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	snapshot := make([]did.DID, len(ctx.roots))
	next := 0
	for root := range ctx.roots {
		snapshot[next] = root
		next++
	}

	return snapshot
}
// addRoots records every given anchor as a root of trust, under the lock.
func (ctx *BasicCapabilityContext) addRoots(anchors []did.DID) {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	for i := range anchors {
		ctx.roots[anchors[i]] = struct{}{}
	}
}
// getRequireAnchors returns a snapshot of the keys of the require map, taken
// under the lock. The order is unspecified because it follows map iteration.
func (ctx *BasicCapabilityContext) getRequireAnchors() []did.DID {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	snapshot := make([]did.DID, len(ctx.require))
	next := 0
	for issuer := range ctx.require {
		snapshot[next] = issuer
		next++
	}

	return snapshot
}
// getProvideAnchors returns a snapshot of the keys of the provide map, taken
// under the lock. The order is unspecified because it follows map iteration.
func (ctx *BasicCapabilityContext) getProvideAnchors() []did.DID {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	snapshot := make([]did.DID, len(ctx.provide))
	next := 0
	for issuer := range ctx.provide {
		snapshot[next] = issuer
		next++
	}

	return snapshot
}
// getTokens reads a token list through getf, drops tokens that have expired as
// of now, writes the filtered list back through setf and returns it. All of
// this happens under the context lock. It returns nil, without calling setf,
// when getf reports that the list does not exist.
func (ctx *BasicCapabilityContext) getTokens(getf func() ([]*Token, bool), setf func([]*Token)) []*Token {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	stored, found := getf()
	if !found {
		return nil
	}

	now := uint64(time.Now().UnixNano())
	expired := func(t *Token) bool {
		return t.ExpireBefore(now)
	}

	// filter a copy so the previously stored slice is left untouched
	live := slices.DeleteFunc(slices.Clone(stored), expired)
	setf(live)

	return live
}
// getRequireTokens returns the unexpired require tokens stored under anchor,
// pruning expired ones from the store as a side effect.
func (ctx *BasicCapabilityContext) getRequireTokens(anchor did.DID) []*Token {
	load := func() ([]*Token, bool) {
		tokens, found := ctx.require[anchor]
		return tokens, found
	}
	store := func(tokens []*Token) {
		ctx.require[anchor] = tokens
	}

	return ctx.getTokens(load, store)
}
// getProvideTokens returns the unexpired provide tokens stored under anchor,
// pruning expired ones from the store as a side effect.
func (ctx *BasicCapabilityContext) getProvideTokens(anchor did.DID) []*Token {
	load := func() ([]*Token, bool) {
		tokens, found := ctx.provide[anchor]
		return tokens, found
	}
	store := func(tokens []*Token) {
		ctx.provide[anchor] = tokens
	}

	return ctx.getTokens(load, store)
}
// getSubjectTokens returns the unexpired tokens stored for subject, pruning
// expired ones from the store as a side effect.
func (ctx *BasicCapabilityContext) getSubjectTokens(subject did.DID) []*Token {
	load := func() ([]*Token, bool) {
		tokens, found := ctx.tokens[subject]
		return tokens, found
	}
	store := func(tokens []*Token) {
		ctx.tokens[subject] = tokens
	}

	return ctx.getTokens(load, store)
}
// discardTokens removes the given tokens from the subject token store while
// holding the context lock.
//
// A stored token is removed when both its issuer and its nonce match a
// discarded token. Subjects left without tokens are deleted from the map.
// Nil entries are skipped: the list comes from a decoded payload (see
// Discard), where a JSON null becomes a nil pointer that would otherwise
// panic in t.Subject().
func (ctx *BasicCapabilityContext) discardTokens(tokens []*Token) {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	for _, t := range tokens {
		if t == nil {
			continue
		}

		subject := t.Subject()
		// filter a copy: the stored slice may still be read by earlier callers
		// of getSubjectTokens
		subjectTokens := slices.DeleteFunc(slices.Clone(ctx.tokens[subject]), func(ot *Token) bool {
			return t.Issuer() == ot.Issuer() && bytes.Equal(t.Nonce(), ot.Nonce())
		})

		if len(subjectTokens) == 0 {
			delete(ctx.tokens, subject)
			continue
		}
		ctx.tokens[subject] = subjectTokens
	}
}
// gc runs the garbage collection loop: it sweeps expired tokens once per
// gcInterval tick and returns when gcCtx is cancelled. The ticker is released
// on return.
func (ctx *BasicCapabilityContext) gc(gcCtx context.Context, gcInterval time.Duration) {
	ticker := time.NewTicker(gcInterval)
	defer ticker.Stop()

	done := gcCtx.Done()
	for {
		select {
		case <-done:
			return
		case <-ticker.C:
			ctx.gcTokens()
		}
	}
}
// gcTokens removes every token that has expired as of now from the require,
// provide and subject token stores, under the context lock. Keys whose token
// list becomes empty are deleted.
func (ctx *BasicCapabilityContext) gcTokens() {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	now := uint64(time.Now().UnixNano())
	expired := func(t *Token) bool {
		return t.ExpireBefore(now)
	}

	// sweep filters each list of one store; it works on a copy so slices
	// already handed out to readers are not modified.
	sweep := func(store map[did.DID][]*Token) {
		for key, tokens := range store {
			live := slices.DeleteFunc(slices.Clone(tokens), expired)
			if len(live) == 0 {
				delete(store, key)
				continue
			}
			store[key] = live
		}
	}

	sweep(ctx.require)
	sweep(ctx.provide)
	sweep(ctx.tokens)
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package ucan
import (
"encoding/json"
"fmt"
"io"
"gitlab.com/nunet/device-management-service/lib/did"
)
// Saver is implemented by capability contexts that can serialize themselves.
// SaveCapabilityContext uses it for contexts that are not a
// *BasicCapabilityContext.
type Saver interface {
// Save writes the context to wr.
Save(wr io.Writer) error
}
// CapabilityContextView is the persisted form of a capability context, as
// written by saveCapabilityContext and read by LoadCapabilityContext. It is
// encoded as JSON with the default (untagged) field names.
type CapabilityContextView struct {
// DID is the context's controlling DID.
DID did.DID
// Roots are the root anchors of trust.
Roots []did.DID
// Require holds the require (acceptance side) tokens.
Require TokenList
// Provide holds the provide tokens.
Provide TokenList
}
// SaveCapabilityContext writes ctx to wr.
//
// A *BasicCapabilityContext is serialized as a CapabilityContextView. Any
// other context must implement Saver. Contexts that do neither are rejected
// with ErrBadContext.
func SaveCapabilityContext(ctx CapabilityContext, wr io.Writer) error {
	switch c := ctx.(type) {
	case *BasicCapabilityContext:
		return saveCapabilityContext(c, wr)
	case Saver:
		return c.Save(wr)
	default:
		return fmt.Errorf("cannot save context: %w", ErrBadContext)
	}
}
// saveCapabilityContext JSON-encodes the context's DID, roots and
// require/provide tokens to wr as a CapabilityContextView.
func saveCapabilityContext(ctx *BasicCapabilityContext, wr io.Writer) error {
	var view CapabilityContextView
	view.Roots, view.Require, view.Provide = ctx.ListRoots()
	view.DID = ctx.provider.DID()

	if err := json.NewEncoder(wr).Encode(&view); err != nil {
		return fmt.Errorf("encoding capability context view: %w", err)
	}

	return nil
}
// LoadCapabilityContext decodes a CapabilityContextView from rd and builds a
// capability context from it with NewCapabilityContext. Require and provide
// tokens that report themselves as expired are dropped before the context is
// built.
func LoadCapabilityContext(trust did.TrustContext, rd io.Reader) (CapabilityContext, error) {
	var view CapabilityContextView
	if err := json.NewDecoder(rd).Decode(&view); err != nil {
		return nil, fmt.Errorf("decoding capability context view: %w", err)
	}

	// unexpired keeps only the tokens that have not expired yet.
	unexpired := func(in TokenList) TokenList {
		var out TokenList
		for _, t := range in.Tokens {
			if t.Expired() {
				continue
			}
			out.Tokens = append(out.Tokens, t)
		}
		return out
	}

	require := unexpired(view.Require)
	provide := unexpired(view.Provide)

	return NewCapabilityContext(trust, view.DID, view.Roots, require, provide)
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package ucan
import (
"crypto/rand"
"encoding/json"
"fmt"
"slices"
"time"
"gitlab.com/nunet/device-management-service/lib/did"
)
// Action identifies what a token authorizes.
type Action string
const (
// Invoke tokens authorize the subject to invoke capabilities.
Invoke Action = "invoke"
// Delegate tokens authorize the subject to issue further tokens.
Delegate Action = "delegate"
// Broadcast tokens authorize the subject to broadcast on topics.
Broadcast Action = "broadcast"
// Revoke Action = "revoke" // TODO
// nonceLength is the size, in bytes, of the random nonce put in each token.
nonceLength = 12 // 96 bits
)
// signaturePrefix is prepended to the JSON encoding of a DMS token to form the
// data that is signed and verified.
var signaturePrefix = []byte("dms:token:")
// Token is an envelope holding one token variant. The methods on Token
// dispatch to the DMS token when it is set and fall back to a zero or error
// result otherwise.
type Token struct {
// DMS tokens
DMS *DMSToken `json:"dms,omitempty"`
// UCAN standard (when it is done) envelope for BYO anchors
UCAN *BYOToken `json:"ucan,omitempty"`
}
// DMSToken is the native capability token. It is signed by its issuer over
// the data returned by SignatureData.
type DMSToken struct {
// Action is what the token authorizes (invoke, delegate or broadcast).
Action Action `json:"act"`
// Issuer is the DID that signed the token.
Issuer did.DID `json:"iss"`
// Subject is the DID the authority is given to.
Subject did.DID `json:"sub"`
// Audience restricts whom the authority applies towards; an empty audience
// is treated as unrestricted by the Allow* checks.
Audience did.DID `json:"aud"`
// Topic lists the broadcast topics covered by the token.
Topic []Capability `json:"topic,omitempty"`
// Capability lists the capabilities granted.
Capability []Capability `json:"cap"`
// Nonce is a random value generated when the token is issued.
Nonce []byte `json:"nonce"`
// Expire is the expiration time; callers in this package compare it with
// time.Now().UnixNano().
Expire uint64 `json:"exp"`
// Depth limits the length of the chain built on this token; 0 means no limit.
Depth uint64 `json:"depth,omitempty"`
// Chain is the parent token this token was delegated from, if any.
Chain *Token `json:"chain,omitempty"`
// Signature is the issuer's signature; it is cleared when computing
// the signature data.
Signature []byte `json:"sig,omitempty"`
}
// BYOToken is a placeholder for the UCAN envelope variant; it has no fields
// yet and every Token method treats it as unsupported.
type BYOToken struct {
// TODO followup
}
// TokenList is the serialized container for a set of tokens; it is the JSON
// payload produced by Provide/ProvideBroadcast and read by Consume/Discard.
type TokenList struct {
// Tokens are the contained tokens; the field is omitted from JSON when empty.
Tokens []*Token `json:"tok,omitempty"`
}
// SignatureData returns the bytes that are signed for this token. Only DMS
// tokens are supported; anything else yields ErrBadToken.
func (t *Token) SignatureData() ([]byte, error) {
	if dms := t.DMS; dms != nil {
		return dms.SignatureData()
	}

	// TODO UCAN envelopes for BYO trust; followup
	return nil, ErrBadToken
}
// SignatureData returns signaturePrefix followed by the JSON encoding of the
// token with its Signature field cleared. The receiver is not modified.
func (t *DMSToken) SignatureData() ([]byte, error) {
	unsigned := *t
	unsigned.Signature = nil

	payload, err := json.Marshal(&unsigned)
	if err != nil {
		return nil, fmt.Errorf("signature data: %w", err)
	}

	out := make([]byte, 0, len(signaturePrefix)+len(payload))
	out = append(out, signaturePrefix...)
	out = append(out, payload...)

	return out, nil
}
// Issuer returns the issuer of the DMS token, or the zero DID when the
// envelope holds no DMS token.
func (t *Token) Issuer() did.DID {
	if dms := t.DMS; dms != nil {
		return dms.Issuer
	}

	// TODO UCAN envelopes for BYO trust; followup
	return did.DID{}
}
// Subject returns the subject of the DMS token, or the zero DID when the
// envelope holds no DMS token.
func (t *Token) Subject() did.DID {
	if dms := t.DMS; dms != nil {
		return dms.Subject
	}

	// TODO UCAN envelopes for BYO trust; followup
	return did.DID{}
}
// Audience returns the audience of the DMS token, or the zero DID when the
// envelope holds no DMS token.
func (t *Token) Audience() did.DID {
	if dms := t.DMS; dms != nil {
		return dms.Audience
	}

	// TODO UCAN envelopes for BYO trust; followup
	return did.DID{}
}
// Capability returns the capabilities of the DMS token, or nil when the
// envelope holds no DMS token.
func (t *Token) Capability() []Capability {
	if dms := t.DMS; dms != nil {
		return dms.Capability
	}

	// TODO UCAN envelopes for BYO trust; followup
	return nil
}
// Topic returns the topics of the DMS token, or nil when the envelope holds
// no DMS token.
func (t *Token) Topic() []Capability {
	if dms := t.DMS; dms != nil {
		return dms.Topic
	}

	// TODO UCAN envelopes for BYO trust; followup
	return nil
}
// Expire returns the expiration time of the DMS token, or 0 when the envelope
// holds no DMS token.
func (t *Token) Expire() uint64 {
	if dms := t.DMS; dms != nil {
		return dms.Expire
	}

	// TODO UCAN envelopes for BYO trust; followup
	return 0
}
// Nonce returns the nonce of the DMS token, or nil when the envelope holds no
// DMS token.
func (t *Token) Nonce() []byte {
	if dms := t.DMS; dms != nil {
		return dms.Nonce
	}

	// TODO UCAN envelopes for BYO trust; followup
	return nil
}
// Action returns the action of the DMS token, or the empty Action when the
// envelope holds no DMS token.
func (t *Token) Action() Action {
	if dms := t.DMS; dms != nil {
		return dms.Action
	}

	// TODO UCAN envelopes for BYO trust; followup
	return Action("")
}
// Verify checks the token and its whole chain against the trust context at
// time now, starting at chain depth 0. It returns nil when the token is valid.
func (t *Token) Verify(trust did.TrustContext, now uint64) error {
	const startDepth uint64 = 0
	return t.verify(trust, now, startDepth)
}
// verify checks the token at the given chain depth. Only DMS tokens are
// supported; anything else yields ErrBadToken.
func (t *Token) verify(trust did.TrustContext, now, depth uint64) error {
	if dms := t.DMS; dms != nil {
		return dms.verify(trust, now, depth)
	}

	// TODO UCAN envelopes for BYO trust; followup
	return ErrBadToken
}
// verify validates the token at chain position depth (0 for the leaf).
//
// Checks, in order: the token has not expired as of now; its own depth limit
// (if any) is not exceeded; when chained, the parent is a Delegate token that
// does not expire before this one, is itself valid one level deeper, names our
// issuer as its subject and allows delegating every capability we carry; and
// finally the issuer's signature over SignatureData is valid.
func (t *DMSToken) verify(trust did.TrustContext, now, depth uint64) error {
	if t.ExpireBefore(now) {
		return ErrCapabilityExpired
	}

	if t.Depth > 0 && depth > t.Depth {
		return fmt.Errorf("max token depth exceeded: %w", ErrNotAuthorized)
	}

	if parent := t.Chain; parent != nil {
		if parent.Action() != Delegate {
			return fmt.Errorf("verify: chain does not allow delegation: %w", ErrNotAuthorized)
		}

		if parent.ExpireBefore(t.Expire) {
			return ErrCapabilityExpired
		}

		if err := parent.verify(trust, now, depth+1); err != nil {
			return err
		}

		if !t.Issuer.Equal(parent.Subject()) {
			return fmt.Errorf("verify: issuer/chain subject misnmatch: %w", ErrNotAuthorized)
		}

		// every capability we carry must be delegable by the parent
		for _, c := range t.Capability {
			if !parent.allowDelegation(t.Issuer, t.Audience, t.Topic, t.Expire, c) {
				return fmt.Errorf("verify: capabilities are not allowed by the chain: %w", ErrNotAuthorized)
			}
		}
	}

	anchor, err := trust.GetAnchor(t.Issuer)
	if err != nil {
		return fmt.Errorf("verify: anchor: %w", err)
	}

	signed, err := t.SignatureData()
	if err != nil {
		return fmt.Errorf("verify: signature data: %w", err)
	}

	if err := anchor.Verify(signed, t.Signature); err != nil {
		return fmt.Errorf("verify: signature: %w", err)
	}

	return nil
}
// AllowAction reports whether this token authorizes the other token. Only DMS
// tokens can allow anything; other envelopes report false.
func (t *Token) AllowAction(ot *Token) bool {
	if dms := t.DMS; dms != nil {
		return dms.AllowAction(ot)
	}

	// TODO UCAN envelopes for BYO trust; followup
	return false
}
// AllowAction reports whether this Delegate token authorizes ot.
//
// All of the following must hold: this token does not expire before ot; ot is
// anchored in our subject; our depth limit (if any) is not exceeded by ot's
// anchor depth; our audience is empty or equal to ot's; and every capability
// and every topic of ot is implied by one of ours.
func (t *DMSToken) AllowAction(ot *Token) bool {
	if t.Action != Delegate || t.ExpireBefore(ot.Expire()) || !ot.Anchor(t.Subject) {
		return false
	}

	if t.Depth > 0 {
		if depth, ok := ot.AnchorDepth(t.Subject); ok && depth > t.Depth {
			return false
		}
	}

	if !t.Audience.Empty() && !t.Audience.Equal(ot.Audience()) {
		return false
	}

	// covered reports whether any granted entry implies the wanted one.
	covered := func(granted []Capability, wanted Capability) bool {
		for _, g := range granted {
			if g.Implies(wanted) {
				return true
			}
		}
		return false
	}

	for _, oc := range ot.Capability() {
		if !covered(t.Capability, oc) {
			return false
		}
	}

	for _, otherTopic := range ot.Topic() {
		if !covered(t.Topic, otherTopic) {
			return false
		}
	}

	return true
}
// Size returns the length of the token's signature data, or 0 when the
// signature data cannot be produced.
func (t *Token) Size() int {
	if data, err := t.SignatureData(); err == nil {
		return len(data)
	}
	return 0
}
// Subsumes reports whether this token makes ot redundant. Only DMS tokens can
// subsume anything; other envelopes report false.
func (t *Token) Subsumes(ot *Token) bool {
	if dms := t.DMS; dms != nil {
		return dms.Subsumes(ot)
	}

	// TODO UCAN envelopes for BYO trust; followup
	return false
}
// Subsumes reports whether this token makes ot redundant: same issuer, subject
// and audience, an expiration at least as late, and every capability of ot
// implied by one of ours.
//
// NOTE(review): action, topic, depth and chain are not compared here; confirm
// that this is intended.
func (t *DMSToken) Subsumes(ot *Token) bool {
	sameParties := t.Issuer.Equal(ot.Issuer()) &&
		t.Subject.Equal(ot.Subject()) &&
		t.Audience.Equal(ot.Audience())
	if !sameParties || t.Expire < ot.Expire() {
		return false
	}

	for _, oc := range ot.Capability() {
		implied := false
		for _, c := range t.Capability {
			if c.Implies(oc) {
				implied = true
				break
			}
		}
		if !implied {
			return false
		}
	}

	return true
}
// AllowInvocation reports whether this token lets subject invoke c on
// audience. Only DMS tokens can allow anything; other envelopes report false.
func (t *Token) AllowInvocation(subject, audience did.DID, c Capability) bool {
	if dms := t.DMS; dms != nil {
		return dms.AllowInvocation(subject, audience, c)
	}

	// TODO UCAN envelopes for BYO trust; followup
	return false
}
// AllowInvocation reports whether this Invoke token, issued to subject, covers
// invoking c on audience. An empty token audience matches any audience.
func (t *DMSToken) AllowInvocation(subject, audience did.DID, c Capability) bool {
	switch {
	case t.Action != Invoke:
		return false
	case !t.Subject.Equal(subject):
		return false
	case !t.Audience.Empty() && !t.Audience.Equal(audience):
		return false
	}

	for i := range t.Capability {
		if t.Capability[i].Implies(c) {
			return true
		}
	}

	return false
}
// AllowBroadcast reports whether this token lets subject broadcast c on topic.
// Only DMS tokens can allow anything; other envelopes report false.
func (t *Token) AllowBroadcast(subject did.DID, topic Capability, c Capability) bool {
	if dms := t.DMS; dms != nil {
		return dms.AllowBroadcast(subject, topic, c)
	}

	// TODO UCAN envelopes for BYO trust; followup
	return false
}
// AllowBroadcast reports whether this Broadcast token, issued to subject,
// covers broadcasting c on topic. The token must have an empty audience, one
// of its topics must imply topic, and one of its capabilities must imply c.
func (t *DMSToken) AllowBroadcast(subject did.DID, topic Capability, c Capability) bool {
	switch {
	case t.Action != Broadcast:
		return false
	case !t.Subject.Equal(subject):
		return false
	case !t.Audience.Empty():
		return false
	}

	topicCovered := false
	for i := range t.Topic {
		if t.Topic[i].Implies(topic) {
			topicCovered = true
			break
		}
	}
	if !topicCovered {
		return false
	}

	for i := range t.Capability {
		if t.Capability[i].Implies(c) {
			return true
		}
	}

	return false
}
// AllowDelegation reports whether this token may be used by issuer as the
// chain parent of a new token with the given action, audience, topics,
// expiration and capability. Only DMS tokens can allow anything; other
// envelopes report false.
func (t *Token) AllowDelegation(action Action, issuer, audience did.DID, topics []Capability, expire uint64, c Capability) bool {
	if dms := t.DMS; dms != nil {
		return dms.AllowDelegation(action, issuer, audience, topics, expire, c)
	}

	// TODO UCAN envelopes for BYO trust; followup
	return false
}
// AllowDelegation reports whether issuer may chain a new token of the given
// action on this one.
//
// The chain's depth limits must leave room for the new token: two more levels
// for a Delegate token (a delegation with only one level left would be a dead
// end), one more level for any other action. The remaining checks are those of
// allowDelegation.
func (t *DMSToken) AllowDelegation(action Action, issuer, audience did.DID, topics []Capability, expire uint64, c Capability) bool {
	needed := uint64(1)
	if action == Delegate {
		// certificate would be dead end with 1
		needed = 2
	}

	if !t.verifyDepth(needed) {
		return false
	}

	return t.allowDelegation(issuer, audience, topics, expire, c)
}
// allowDelegation is AllowDelegation without the depth check. Only DMS tokens
// can allow anything; other envelopes report false.
func (t *Token) allowDelegation(issuer, audience did.DID, topics []Capability, expire uint64, c Capability) bool {
	if dms := t.DMS; dms != nil {
		return dms.allowDelegation(issuer, audience, topics, expire, c)
	}

	// TODO UCAN envelopes for BYO trust; followup
	return false
}
// allowDelegation reports whether this Delegate token lets issuer hand out
// capability c for the given audience, topics and expiration.
//
// The token must not expire before expire, must name issuer as its subject,
// must have an empty audience or one equal to audience, must imply every
// requested topic, and must hold a capability that implies c.
func (t *DMSToken) allowDelegation(issuer, audience did.DID, topics []Capability, expire uint64, c Capability) bool {
	switch {
	case t.Action != Delegate:
		return false
	case t.ExpireBefore(expire):
		return false
	case !t.Subject.Equal(issuer):
		return false
	case !t.Audience.Empty() && !t.Audience.Equal(audience):
		return false
	}

	// covered reports whether any granted entry implies the wanted one.
	covered := func(granted []Capability, wanted Capability) bool {
		for _, g := range granted {
			if g.Implies(wanted) {
				return true
			}
		}
		return false
	}

	for _, topic := range topics {
		if !covered(t.Topic, topic) {
			return false
		}
	}

	return covered(t.Capability, c)
}
// verifyDepth forwards the depth check to the DMS token when one is present.
// Tokens without a DMS payload always fail the check.
func (t *Token) verifyDepth(depth uint64) bool {
	if t.DMS != nil {
		return t.DMS.verifyDepth(depth)
	}
	// TODO UCAN envelopes for BYO trust; followup
	return false
}
// verifyDepth checks depth against this token's Depth limit and then checks
// depth+1 against the rest of the chain. A Depth of zero imposes no limit at
// this link.
func (t *DMSToken) verifyDepth(depth uint64) bool {
	limited := t.Depth > 0
	if limited && depth > t.Depth {
		return false
	}
	if t.Chain == nil {
		return true
	}
	return t.Chain.verifyDepth(depth + 1)
}
// Delegate issues a new delegation token chained to the wrapped DMS token and
// returns it wrapped in a Token.
//
// It returns ErrBadToken when the token carries no DMS payload (UCAN tokens
// are not supported yet); errors from the DMS token are wrapped with the
// "delegate" context.
func (t *Token) Delegate(provider did.Provider, subject, audience did.DID, topics []Capability, expire, depth uint64, c []Capability) (*Token, error) {
	switch {
	case t.DMS != nil:
		result, err := t.DMS.Delegate(provider, subject, audience, topics, expire, depth, c)
		if err != nil {
			// This is a plain delegation, not an invocation; label it as such.
			return nil, fmt.Errorf("delegate: %w", err)
		}
		return &Token{DMS: result}, nil
	case t.UCAN != nil:
		// TODO UCAN envelopes for BYO trust; followup
		fallthrough
	default:
		return nil, ErrBadToken
	}
}
// Delegate issues a new token with the Delegate action chained to t. All
// arguments are passed through unchanged to delegate.
func (t *DMSToken) Delegate(provider did.Provider, subject, audience did.DID, topics []Capability, expire, depth uint64, c []Capability) (*DMSToken, error) {
	return t.delegate(Delegate, provider, subject, audience, topics, expire, depth, c)
}
// delegate builds and signs a new token with the given action, chained to t.
//
// t itself must be a Delegate token and must have enough remaining depth (two
// levels when the new token is itself a Delegate token, one otherwise);
// otherwise ErrNotAuthorized is returned. The new token gets a fresh random
// nonce, is issued by provider.DID() and is signed with provider.
func (t *DMSToken) delegate(action Action, provider did.Provider, subject, audience did.DID, topics []Capability, expire, depth uint64, c []Capability) (*DMSToken, error) {
	if t.Action != Delegate {
		return nil, ErrNotAuthorized
	}

	requiredDepth := uint64(1)
	if action == Delegate {
		// certificate would be dead end with 1
		requiredDepth = 2
	}
	if !t.verifyDepth(requiredDepth) {
		return nil, ErrNotAuthorized
	}

	nonce := make([]byte, nonceLength)
	if _, err := rand.Read(nonce); err != nil {
		return nil, fmt.Errorf("nonce: %w", err)
	}

	issued := &DMSToken{
		Action:     action,
		Issuer:     provider.DID(),
		Subject:    subject,
		Audience:   audience,
		Topic:      topics,
		Capability: c,
		Nonce:      nonce,
		Expire:     expire,
		Depth:      depth,
		Chain:      &Token{DMS: t},
	}

	payload, err := issued.SignatureData()
	if err != nil {
		return nil, fmt.Errorf("delegate: %w", err)
	}

	signature, err := provider.Sign(payload)
	if err != nil {
		return nil, fmt.Errorf("sign: %w", err)
	}
	issued.Signature = signature

	return issued, nil
}
// DelegateInvocation issues an invocation token chained to the wrapped DMS
// token and returns it wrapped in a Token. Tokens without a DMS payload yield
// ErrBadToken.
func (t *Token) DelegateInvocation(provider did.Provider, subject, audience did.DID, expire uint64, c []Capability) (*Token, error) {
	if t.DMS == nil {
		// TODO UCAN envelopes for BYO trust; followup
		return nil, ErrBadToken
	}

	issued, err := t.DMS.DelegateInvocation(provider, subject, audience, expire, c)
	if err != nil {
		return nil, fmt.Errorf("delegate invocation: %w", err)
	}
	return &Token{DMS: issued}, nil
}
// DelegateInvocation issues a new token with the Invoke action chained to t.
// The new token is created with no topics and a depth of 0.
func (t *DMSToken) DelegateInvocation(provider did.Provider, subject, audience did.DID, expire uint64, c []Capability) (*DMSToken, error) {
	return t.delegate(Invoke, provider, subject, audience, nil, expire, 0, c)
}
// DelegateBroadcast issues a broadcast token for topic chained to the wrapped
// DMS token and returns it wrapped in a Token.
//
// It returns ErrBadToken when the token carries no DMS payload (UCAN tokens
// are not supported yet); errors from the DMS token are wrapped with the
// "delegate broadcast" context.
func (t *Token) DelegateBroadcast(provider did.Provider, subject did.DID, topic Capability, expire uint64, c []Capability) (*Token, error) {
	switch {
	case t.DMS != nil:
		result, err := t.DMS.DelegateBroadcast(provider, subject, topic, expire, c)
		if err != nil {
			// Name the operation that actually failed (broadcast, not invocation).
			return nil, fmt.Errorf("delegate broadcast: %w", err)
		}
		return &Token{DMS: result}, nil
	case t.UCAN != nil:
		// TODO UCAN envelopes for BYO trust; followup
		fallthrough
	default:
		return nil, ErrBadToken
	}
}
// DelegateBroadcast issues a new token with the Broadcast action chained to
// t. The new token has an empty audience (did.DID{}), the single given topic
// and a depth of 0.
func (t *DMSToken) DelegateBroadcast(provider did.Provider, subject did.DID, topic Capability, expire uint64, c []Capability) (*DMSToken, error) {
	return t.delegate(Broadcast, provider, subject, did.DID{}, []Capability{topic}, expire, 0, c)
}
// Anchor reports whether anchor issued any token in the wrapped DMS chain.
// Tokens without a DMS payload are never anchored.
func (t *Token) Anchor(anchor did.DID) bool {
	if t.DMS != nil {
		return t.DMS.Anchor(anchor)
	}
	// TODO UCAN envelopes for BYO trust; followup
	return false
}
// Anchor reports whether anchor is the issuer of this token or of any token
// further down its chain.
func (t *DMSToken) Anchor(anchor did.DID) bool {
	return t.Issuer.Equal(anchor) || (t.Chain != nil && t.Chain.Anchor(anchor))
}
// AnchorDepth returns the depth at which anchor appears as an issuer in the
// wrapped DMS chain and whether it was found at all. Tokens without a DMS
// payload report (0, false).
func (t *Token) AnchorDepth(anchor did.DID) (uint64, bool) {
	if t.DMS != nil {
		return t.DMS.AnchorDepth(anchor)
	}
	// TODO UCAN envelopes for BYO trust; followup
	return 0, false
}
// AnchorDepth returns how many chain links separate this token from the
// deepest token issued by anchor, and whether anchor issued any token in the
// chain. A match further down the chain takes precedence over a match on
// this token itself.
func (t *DMSToken) AnchorDepth(anchor did.DID) (depth uint64, have bool) {
	have = t.Issuer.Equal(anchor)
	if t.Chain == nil {
		return 0, have
	}
	if below, found := t.Chain.AnchorDepth(anchor); found {
		return below + 1, true
	}
	return 0, have
}
// Expired reports whether the token has expired as of now. The current time
// is passed to ExpireBefore as Unix nanoseconds.
func (t *Token) Expired() bool {
	return t.ExpireBefore(uint64(time.Now().UnixNano()))
}
// ExpireBefore reports whether the wrapped DMS chain expires before deadline.
// Tokens without a DMS payload are treated as already expired.
func (t *Token) ExpireBefore(deadline uint64) bool {
	if t.DMS != nil {
		return t.DMS.ExpireBefore(deadline)
	}
	// TODO UCAN envelopes for BYO trust; followup
	return true
}
// ExpireBefore reports whether this token, or any token further down its
// chain, has an Expire value smaller than deadline.
func (t *DMSToken) ExpireBefore(deadline uint64) bool {
	if t.Expire < deadline {
		return true
	}
	if t.Chain == nil {
		return false
	}
	return t.Chain.ExpireBefore(deadline)
}
// SelfSigned reports whether the root of the wrapped DMS chain was issued by
// origin. Tokens without a DMS payload are never self-signed.
func (t *Token) SelfSigned(origin did.DID) bool {
	if t.DMS != nil {
		return t.DMS.SelfSigned(origin)
	}
	// TODO UCAN envelopes for BYO trust; followup
	return false
}
// SelfSigned walks to the end of the chain and reports whether that last
// token was issued by origin.
func (t *DMSToken) SelfSigned(origin did.DID) bool {
	if t.Chain == nil {
		return t.Issuer.Equal(origin)
	}
	return t.Chain.SelfSigned(origin)
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package main
import "gitlab.com/nunet/device-management-service/cmd"
// @title Device Management Service
// @version 0.4.185
// @description A dashboard application for computing providers.
// @termsOfService https://nunet.io/tos
// @contact.name Support
// @contact.url https://devexchange.nunet.io/
// @contact.email support@nunet.io
// @license.name Apache 2.0
// @license.url http://www.apache.org/licenses/LICENSE-2.0.html
// @host localhost:9999
//
// @Schemes http
//
// @BasePath /api/v1
func main() {
	// The binary does nothing itself; it only calls cmd.Execute.
	cmd.Execute()
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package libp2p
import (
"bytes"
"context"
"errors"
"fmt"
"math/rand"
"strings"
"sync"
"time"
dht_pb "github.com/libp2p/go-libp2p-kad-dht/pb"
"github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/peerstore"
"github.com/libp2p/go-libp2p/core/protocol"
msgio "github.com/libp2p/go-msgio"
"github.com/libp2p/go-msgio/protoio" //nolint:staticcheck
multiaddr "github.com/multiformats/go-multiaddr"
"google.golang.org/protobuf/proto"
"gitlab.com/nunet/device-management-service/observability"
commonproto "gitlab.com/nunet/device-management-service/proto/generated/v1/common"
)
// kadv1 is the protocol suffix that startRandomWalk appends to the configured
// DHT prefix to build the DHT protocol ID.
const kadv1 = "/kad/1.0.0"
// ConnectToBootstrapNodes dials every configured bootstrap peer concurrently
// and waits until all attempts have finished. The attempts share a single
// 10 second timeout derived from ctx. Failures are only logged; the method
// always returns nil.
func (l *Libp2p) ConnectToBootstrapNodes(ctx context.Context) error {
	if len(l.config.BootstrapPeers) == 0 {
		return nil
	}

	// bootstrap all nodes at the same time.
	dialCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
	defer cancel()

	var pending sync.WaitGroup
	for _, addr := range l.config.BootstrapPeers {
		pending.Add(1)
		go func(target multiaddr.Multiaddr) {
			defer pending.Done()

			info, err := peer.AddrInfoFromP2pAddr(target)
			if err != nil {
				log.Errorf("failed to convert multi addr to addr info %v - %v", target, err)
				return
			}
			if err := l.Host.Connect(dialCtx, *info); err != nil {
				log.Errorf("failed to connect to bootstrap node %s - %v", info.ID.String(), err)
				return
			}
			log.Infof("connected to Bootstrap Node %s", info.ID.String())
		}(addr)
	}
	pending.Wait()

	return nil
}
// BootstrapDHT runs the DHT bootstrap under a trace span. It logs the outcome
// and returns the bootstrap error unchanged.
func (l *Libp2p) BootstrapDHT(ctx context.Context) error {
	endTrace := observability.StartTrace("libp2p_bootstrap_duration")
	defer endTrace()

	err := l.DHT.Bootstrap(ctx)
	if err != nil {
		log.Errorw("libp2p_bootstrap_failure", "error", err)
		return err
	}

	log.Infow("libp2p_bootstrap_success")
	return nil
}
// startRandomWalk starts a background goroutine that crawls the dht by
// resolving random keys, which as a side effect discovers other peers.
//
// The goroutine runs until ctx is cancelled. Every wait it performs (the
// initial warm-up, the backoff after a failed lookup and the cooldown between
// walks) is interrupted by ctx, so cancellation stops the goroutine promptly
// instead of after the current sleep has elapsed.
func (l *Libp2p) startRandomWalk(ctx context.Context) {
	go func() {
		log.Debug("starting bootstrap process")
		// A simple mechanism to improve our bootstrap and peer discovery:
		// 1. initiate a background, never ending, random walk which tries to resolve
		//    random keys in the dht and by extension discovers other peers.
		interval := 5 * time.Minute
		delayOnError := 10 * time.Second

		// wait blocks for d and reports false when ctx was cancelled first.
		wait := func(d time.Duration) bool {
			timer := time.NewTimer(d)
			defer timer.Stop()
			select {
			case <-timer.C:
				return true
			case <-ctx.Done():
				return false
			}
		}

		// wait for dht ready
		if !wait(interval) {
			log.Debugf("bootstrap: context done, stopping bootstrap")
			return
		}

		dhtProto := protocol.ID(l.config.DHTPrefix + kadv1)
		sender := newDHTMessageSender(l.Host, dhtProto)
		messenger, err := dht_pb.NewProtocolMessenger(sender)
		if err != nil {
			log.Errorf("bootstrap: creating protocol messenger: %s", err)
			return
		}

		// depth counts crawl hops and failures since the last cooldown.
		var depth int
		var key string
		for {
			select {
			case <-ctx.Done():
				log.Debugf("bootstrap: context done, stopping bootstrap")
				return
			default:
				randomPeerID, err := l.DHT.RoutingTable().GenRandPeerID(0)
				if err != nil {
					log.Debugf("bootstrap: failed to generate random peer ID: %s", err)
					continue
				}
				key = randomPeerID.String()
				log.Debugf("bootstrap: crawling from %s", key)
				peers, err := l.DHT.GetClosestPeers(ctx, key)
				if err != nil {
					log.Debugf("bootstrap: failed to get closest peers with key=%s - error: %s", randomPeerID.String(), err)
					if !wait(delayOnError) {
						log.Debugf("bootstrap: context done, stopping bootstrap")
						return
					}
					// Back off by 25% per consecutive failure, capped at 5 minutes.
					delayOnError = time.Duration(float64(delayOnError) * 1.25)
					if delayOnError > 5*time.Minute {
						delayOnError = 5 * time.Minute
					}
					continue
				}
				delayOnError = 10 * time.Second
				if len(peers) == 0 {
					continue
				}
				peerID := peers[rand.Intn(len(peers))] //nolint:gosec
				if peerID == l.Host.ID() {
					log.Debugf("bootstrap: skipping self")
					continue
				}
				log.Debugf("bootstrap: starting random walk from %s", peerID)
				peerAddrInfo, err := l.resolvePeerAddress(ctx, peerID)
				if err != nil {
					log.Debugf("bootstrap: failed to resolve address for peer %s - %v", peerID, err)
					continue
				}
				var peerInfos []*peer.AddrInfo
				selected := &peerAddrInfo
			crawl:
				log.Debugf("bootstrap: crawling %s", selected.ID)
				if err := l.Host.Connect(ctx, *selected); err != nil {
					// Report the peer that was actually dialed, which differs
					// from peerID after the first hop.
					log.Debugf("bootstrap: failed to connect to peer %s: %s", selected.ID, err)
					depth++
					continue
				}
				peerInfos, err = messenger.GetClosestPeers(ctx, selected.ID, randomPeerID)
				if err != nil {
					log.Debugf("bootstrap: failed to get closest peers from %s: %s", selected.ID, err)
					depth++
					continue
				}
				if len(peerInfos) == 0 {
					depth++
					continue
				}
				selected = peerInfos[rand.Intn(len(peerInfos))] //nolint:gosec
				if selected.ID == l.Host.ID() {
					log.Debugf("bootstrap: skipping self")
					depth++
					continue
				}
				if depth < 20 {
					randomPeerID, err = l.DHT.RoutingTable().GenRandPeerID(0)
					if err != nil {
						log.Debugf("bootstrap: failed to generate random peer ID: %s", err)
						goto cooldown
					}
					depth++
					goto crawl
				}
			// cooldown
			cooldown:
				depth = 0
				// Pick a random delay in [interval/2, 3*interval/2).
				minDelay := interval / 2
				maxDelay := (3 * interval) / 2
				delay := minDelay + time.Duration(rand.Int63n(int64(maxDelay-minDelay))) //nolint:gosec
				log.Debugf("bootstrap: cooling down for %s", delay)
				if !wait(delay) {
					return
				}
				// Walk less often over time, up to once every 4 hours.
				interval = interval * 3 / 2
				if interval > 4*time.Hour {
					interval = 4 * time.Hour
				}
			}
		}
	}()
}
// dhtValidator validates records placed into the dht (see Validate and
// Select).
type dhtValidator struct {
	// PS is a peerstore handle.
	// NOTE(review): not read by Validate or Select — confirm it is still needed.
	PS peerstore.Peerstore
	// customNamespace is the prefix that the key of every non-empty record
	// must start with.
	customNamespace string
}
// Validate checks a record before it is accepted into the dht.
//
// An empty value is treated as a deletion and always accepted. Otherwise the
// key must start with the validator's namespace, the value must decode as an
// Advertisement, and the advertisement's signature must verify against the
// secp256k1 public key embedded in it. Every outcome is logged.
func (d dhtValidator) Validate(key string, value []byte) error {
	endTrace := observability.StartTrace("libp2p_dht_validate_duration")
	defer endTrace()

	// empty value is considered deleting an item from the dht
	if len(value) == 0 {
		log.Infow("libp2p_dht_validate_success", "key", key)
		return nil
	}

	if !strings.HasPrefix(key, d.customNamespace) {
		log.Errorw("libp2p_dht_validate_failure", "key", key, "error", "invalid key namespace")
		return errors.New("invalid key namespace")
	}

	// verify signature
	var advert commonproto.Advertisement
	if err := proto.Unmarshal(value, &advert); err != nil {
		log.Errorw("libp2p_dht_validate_failure", "key", key, "error", fmt.Sprintf("failed to unmarshal envelope: %v", err))
		return fmt.Errorf("failed to unmarshal envelope: %w", err)
	}

	signer, err := crypto.UnmarshalSecp256k1PublicKey(advert.PublicKey)
	if err != nil {
		log.Errorw("libp2p_dht_validate_failure", "key", key, "error", fmt.Sprintf("failed to unmarshal public key: %v", err))
		return fmt.Errorf("failed to unmarshal public key: %w", err)
	}

	// The signed payload is the peer id, the timestamp converted to a single
	// byte, the data and the public key, concatenated in that order.
	signedPayload := bytes.Join([][]byte{
		[]byte(advert.PeerId),
		{byte(advert.Timestamp)},
		advert.Data,
		advert.PublicKey,
	}, nil)

	valid, err := signer.Verify(signedPayload, advert.Signature)
	switch {
	case err != nil:
		log.Errorw("libp2p_dht_validate_failure", "key", key, "error", fmt.Sprintf("failed to verify envelope: %v", err))
		return fmt.Errorf("failed to verify envelope: %w", err)
	case !valid:
		log.Errorw("libp2p_dht_validate_failure", "key", key, "error", "public key didn't sign the payload")
		return errors.New("failed to verify envelope, public key didn't sign payload")
	}

	log.Infow("libp2p_dht_validate_success", "key", key)
	return nil
}
// Select ignores its arguments and always picks the first value (index 0).
func (dhtValidator) Select(_ string, _ [][]byte) (int, error) { return 0, nil }
// dhtMessenger sends dht protocol messages to peers over streams opened on
// host.
type dhtMessenger struct {
	// host opens the streams used by SendRequest and SendMessage.
	host host.Host
	// proto is the protocol ID requested for every new stream.
	proto protocol.ID
}
// newDHTMessageSender returns a dht_pb.MessageSender that opens streams on h
// using the protocol ID proto.
func newDHTMessageSender(h host.Host, proto protocol.ID) dht_pb.MessageSender {
	return &dhtMessenger{host: h, proto: proto}
}
// SendRequest opens a stream to p, writes msg as a length-delimited message
// and reads back a single varint-prefixed reply. The stream is reset when
// writing, reading or decoding fails and is closed when the call returns.
func (m *dhtMessenger) SendRequest(ctx context.Context, p peer.ID, msg *dht_pb.Message) (*dht_pb.Message, error) {
	stream, err := m.host.NewStream(ctx, p, m.proto)
	if err != nil {
		return nil, fmt.Errorf("open stream: %w", err)
	}
	defer stream.Close()

	writer := protoio.NewDelimitedWriter(stream)
	if err := writer.WriteMsg(msg); err != nil {
		_ = stream.Reset()
		return nil, fmt.Errorf("write message: %w", err)
	}

	reader := msgio.NewVarintReaderSize(stream, network.MessageSizeMax)
	raw, err := reader.ReadMsg()
	if err != nil {
		_ = stream.Reset()
		return nil, fmt.Errorf("read message: %w", err)
	}
	// Hand the read buffer back to the reader once decoding is done.
	defer reader.ReleaseMsg(raw)

	reply := new(dht_pb.Message)
	if err := reply.Unmarshal(raw); err != nil {
		_ = stream.Reset()
		return nil, fmt.Errorf("unmarshal message: %w", err)
	}

	return reply, nil
}
// SendMessage opens a stream to p, writes msg as a length-delimited message
// and closes the write side without waiting for a reply. The stream is reset
// when the write fails and is closed when the call returns.
func (m *dhtMessenger) SendMessage(ctx context.Context, p peer.ID, msg *dht_pb.Message) error {
	stream, err := m.host.NewStream(ctx, p, m.proto)
	if err != nil {
		return fmt.Errorf("open stream: %w", err)
	}
	defer stream.Close()

	if err := protoio.NewDelimitedWriter(stream).WriteMsg(msg); err != nil {
		_ = stream.Reset()
		return fmt.Errorf("write message: %w", err)
	}

	return stream.CloseWrite()
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package libp2p
import (
"context"
"fmt"
"math/rand"
"time"
"github.com/libp2p/go-libp2p/core/discovery"
"github.com/libp2p/go-libp2p/core/peer"
dutil "github.com/libp2p/go-libp2p/p2p/discovery/util"
"gitlab.com/nunet/device-management-service/observability"
)
// DiscoverDialPeers looks up peers at the rendezvous point, remembers them,
// drops entries without addresses as well as the local host, and then dials
// the remaining peers. A discovery error is logged and returned; when no
// peers are found the previously discovered peers are kept.
func (l *Libp2p) DiscoverDialPeers(ctx context.Context) error {
	endTrace := observability.StartTrace("libp2p_peer_discover_duration")
	defer endTrace()

	found, err := l.findPeersFromRendezvousDiscovery(ctx)
	if err != nil {
		log.Errorw("libp2p_peer_discover_failure", "error", err)
		return err
	}

	if len(found) == 0 {
		log.Debug("No peers found during discovery")
	} else {
		l.discoveredPeers = found
		log.Infow("libp2p_peer_discover_success", "foundPeers", len(found))
	}

	// filter out peers with no listening addresses and self host
	l.discoveredPeers = PeerPassFilter(l.discoveredPeers, NoAddrIDFilter{ID: l.Host.ID()})

	l.dialPeers(ctx)
	return nil
}
// advertiseForRendezvousDiscovery advertises this node at the configured
// rendezvous point through the discovery service and returns the error of
// that call; the first return value of Advertise is discarded.
func (l *Libp2p) advertiseForRendezvousDiscovery(ctx context.Context) error {
	_, err := l.discovery.Advertise(ctx, l.config.Rendezvous)
	return err
}
// findPeersFromRendezvousDiscovery queries the discovery service for peers
// registered at the configured rendezvous point, limited to
// PeerCountDiscoveryLimit results. Failures are logged and returned wrapped.
func (l *Libp2p) findPeersFromRendezvousDiscovery(ctx context.Context) ([]peer.AddrInfo, error) {
	endTrace := observability.StartTrace("libp2p_find_peers_duration")
	defer endTrace()

	limit := discovery.Limit(l.config.PeerCountDiscoveryLimit)
	found, err := dutil.FindPeers(ctx, l.discovery, l.config.Rendezvous, limit)
	if err != nil {
		log.Errorw("libp2p_find_peers_failure", "error", err)
		return nil, fmt.Errorf("failed to discover peers: %w", err)
	}

	log.Infow("libp2p_find_peers_success", "peersCount", len(found))
	return found, nil
}
// dialPeers connects to the discovered peers in the background. When more
// than 16 peers are known the list is shuffled in place and only the first 16
// are tried. The local host and already connected peers are skipped; every
// dial runs in its own goroutine with a 10 second timeout and only logs its
// outcome.
func (l *Libp2p) dialPeers(ctx context.Context) {
	const maxPeers = 16

	candidates := l.discoveredPeers
	if len(candidates) > maxPeers {
		//nolint:gosec
		rng := rand.New(rand.NewSource(time.Now().UnixNano()))
		rng.Shuffle(len(candidates), func(a, b int) {
			candidates[a], candidates[b] = candidates[b], candidates[a]
		})
		// Take only the first maxPeers
		candidates = candidates[:maxPeers]
	}

	for _, candidate := range candidates {
		if candidate.ID == l.Host.ID() || l.PeerConnected(candidate.ID) {
			continue
		}

		go func(target peer.AddrInfo) {
			dialCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
			defer cancel()

			err := l.Host.Connect(dialCtx, target)
			if err != nil {
				log.Debugf("couldn't establish connection with: %s - error: %v", target.ID, err)
				return
			}
			log.Debugf("connected with: %s", target.ID)
		}(candidate)
	}
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package libp2p
import (
"fmt"
"net"
"github.com/miekg/dns"
)
// resolveDNS builds the reply to query from the static records map, which
// maps a question name to an IPv4 address string.
//
// Only A questions are supported: other types set the response code to
// NotImplemented, names missing from records set it to NameError (NXDOMAIN),
// and a record whose value is not a valid IPv4 address sets it to
// ServerFailure instead of emitting a malformed A record. Every answerable
// question adds one A record to the answer section.
func resolveDNS(query *dns.Msg, records map[string]string) *dns.Msg {
	// Create a response message
	m := new(dns.Msg)
	m.SetReply(query)
	for _, question := range query.Question {
		if question.Qtype != dns.TypeA {
			// We only support A records
			m.SetRcode(query, dns.RcodeNotImplemented)
			continue
		}
		ip, ok := records[question.Name]
		if !ok {
			// Not found in our map, set answer to NXDOMAIN
			m.SetRcode(query, dns.RcodeNameError)
			continue
		}
		// ParseIP returns nil for garbage and To4 returns nil for IPv6, so a
		// nil result means the configured value cannot be served as an A record.
		ipv4 := net.ParseIP(ip).To4()
		if ipv4 == nil {
			m.SetRcode(query, dns.RcodeServerFailure)
			continue
		}
		// Found record, add A record to the answer section
		a := &dns.A{
			Hdr: dns.RR_Header{Name: question.Name, Rrtype: dns.TypeA, Class: dns.ClassINET},
			A:   ipv4,
		}
		m.Answer = append(m.Answer, a)
	}
	return m
}
// handleDNSQuery decodes packet as a DNS message, resolves it against records
// with resolveDNS and returns the packed response bytes. It does not send
// anything itself. Decode and encode failures are returned wrapped.
func handleDNSQuery(packet []byte, records map[string]string) ([]byte, error) {
	// Parse the UDP packet into a DNS message
	msg := new(dns.Msg)
	if err := msg.Unpack(packet); err != nil {
		return nil, fmt.Errorf("failed to decode DNS message: %w", err)
	}

	// Resolve the DNS query
	response := resolveDNS(msg, records)
	// Debugw attaches the response as a structured field; Debug would merely
	// concatenate the key and the value into the message text.
	log.Debugw("DNS query resolved successfully", "response", response)

	// Encode the response message into a UDP packet
	responseBytes, err := response.Pack()
	if err != nil {
		return nil, fmt.Errorf("failed to encode DNS response: %w", err)
	}
	return responseBytes, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package libp2p
import (
"github.com/libp2p/go-libp2p/core/connmgr"
"github.com/libp2p/go-libp2p/core/control"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/multiformats/go-multiaddr"
mafilt "github.com/whyrusleeping/multiaddr-filter"
)
// defaultServerFilters lists IPv4 and IPv6 CIDR ranges in multiaddr-filter
// notation (private, link-local, documentation and other special-purpose
// ranges). NewHost turns them into deny filters and, in server mode, also
// passes them to makeAddrsFactory as addresses that must not be announced.
var defaultServerFilters = []string{
	"/ip4/10.0.0.0/ipcidr/8",
	"/ip4/100.64.0.0/ipcidr/10",
	"/ip4/169.254.0.0/ipcidr/16",
	"/ip4/172.16.0.0/ipcidr/12",
	"/ip4/192.0.0.0/ipcidr/24",
	"/ip4/192.0.2.0/ipcidr/24",
	"/ip4/192.168.0.0/ipcidr/16",
	"/ip4/198.18.0.0/ipcidr/15",
	"/ip4/198.51.100.0/ipcidr/24",
	"/ip4/203.0.113.0/ipcidr/24",
	"/ip4/240.0.0.0/ipcidr/4",
	"/ip6/100::/ipcidr/64",
	"/ip6/2001:2::/ipcidr/48",
	"/ip6/2001:db8::/ipcidr/32",
	"/ip6/fc00::/ipcidr/7",
	"/ip6/fe80::/ipcidr/10",
}
// PeerFilter is an interface for filtering peers.
// A peer passes the filter when it satisfies the filter's criteria.
type PeerFilter interface {
	// satisfies reports whether p meets the filter criteria and should be kept.
	satisfies(p peer.AddrInfo) bool
}
// NoAddrIDFilter filters out peers that have no listening addresses
// as well as the peer with a specific ID.
type NoAddrIDFilter struct {
	// ID is the peer ID that never passes the filter.
	ID peer.ID
}
// satisfies reports whether p has at least one address and is not the peer
// excluded by the filter.
func (f NoAddrIDFilter) satisfies(p peer.AddrInfo) bool {
	if len(p.Addrs) == 0 {
		return false
	}
	return p.ID != f.ID
}
// PeerPassFilter returns the peers that satisfy pf, in their original order.
// The result is nil when no peer passes.
func PeerPassFilter(peers []peer.AddrInfo, pf PeerFilter) []peer.AddrInfo {
	var kept []peer.AddrInfo
	for _, candidate := range peers {
		if !pf.satisfies(candidate) {
			continue
		}
		kept = append(kept, candidate)
	}
	return kept
}
// filtersConnectionGater is a multiaddr.Filters value used as a libp2p
// connection gater: its methods consult the filters' AddrBlocked result.
type filtersConnectionGater multiaddr.Filters

// Compile-time check that *filtersConnectionGater implements connmgr.ConnectionGater.
var _ connmgr.ConnectionGater = (*filtersConnectionGater)(nil)
// InterceptAddrDial allows dialing addr unless the filters block it. The peer
// ID is ignored.
func (f *filtersConnectionGater) InterceptAddrDial(_ peer.ID, addr multiaddr.Multiaddr) (allow bool) {
	filters := (*multiaddr.Filters)(f)
	return !filters.AddrBlocked(addr)
}
// InterceptPeerDial always allows the dial; filtering happens per address.
func (f *filtersConnectionGater) InterceptPeerDial(_ peer.ID) (allow bool) {
	return true
}
// InterceptAccept allows an inbound connection unless its remote multiaddr is
// blocked by the filters.
func (f *filtersConnectionGater) InterceptAccept(connAddr network.ConnMultiaddrs) (allow bool) {
	filters := (*multiaddr.Filters)(f)
	return !filters.AddrBlocked(connAddr.RemoteMultiaddr())
}
// InterceptSecured allows a secured connection unless its remote multiaddr is
// blocked by the filters. Direction and peer ID are ignored.
func (f *filtersConnectionGater) InterceptSecured(_ network.Direction, _ peer.ID, connAddr network.ConnMultiaddrs) (allow bool) {
	filters := (*multiaddr.Filters)(f)
	return !filters.AddrBlocked(connAddr.RemoteMultiaddr())
}
// InterceptUpgraded always allows the upgraded connection and returns a zero
// disconnect reason.
func (f *filtersConnectionGater) InterceptUpgraded(_ network.Conn) (allow bool, reason control.DisconnectReason) {
	return true, 0
}
// makeAddrsFactory builds the function that decides which addresses are
// announced.
//
//   - announce, when non-empty, replaces the addresses passed to the factory.
//   - appendAnnouce entries are added to the result, skipping any string that
//     already appears in announce.
//   - noAnnounce entries remove addresses from the result; each entry is
//     either a CIDR mask (multiaddr-filter notation) or an exact multiaddr.
//
// It returns nil when any entry cannot be parsed.
func makeAddrsFactory(announce []string, appendAnnouce []string, noAnnounce []string) func([]multiaddr.Multiaddr) []multiaddr.Multiaddr {
	// existing tracks the announce strings to avoid duplicates below.
	existing := make(map[string]bool, len(announce))
	annAddrs := make([]multiaddr.Multiaddr, 0, len(announce))
	for _, addr := range announce {
		maddr, err := multiaddr.NewMultiaddr(addr)
		if err != nil {
			return nil
		}
		annAddrs = append(annAddrs, maddr)
		existing[addr] = true
	}

	var appendAnnAddrs []multiaddr.Multiaddr
	for _, addr := range appendAnnouce {
		if existing[addr] {
			// skip AppendAnnounce that is on the Announce list already
			continue
		}
		appendAddr, err := multiaddr.NewMultiaddr(addr)
		if err != nil {
			return nil
		}
		appendAnnAddrs = append(appendAnnAddrs, appendAddr)
	}

	filters := multiaddr.NewFilters()
	noAnnAddrs := map[string]bool{}
	for _, addr := range noAnnounce {
		// Try the entry as a CIDR mask first, then as an exact multiaddr.
		f, err := mafilt.NewMask(addr)
		if err == nil {
			filters.AddFilter(*f, multiaddr.ActionDeny)
			continue
		}
		maddr, err := multiaddr.NewMultiaddr(addr)
		if err != nil {
			return nil
		}
		noAnnAddrs[string(maddr.Bytes())] = true
	}

	return func(allAddrs []multiaddr.Multiaddr) []multiaddr.Multiaddr {
		base := allAddrs
		if len(annAddrs) > 0 {
			base = annAddrs
		}

		// Copy into a fresh slice: appending directly to base could write
		// into spare capacity of the caller's (or our shared) backing array.
		addrs := make([]multiaddr.Multiaddr, 0, len(base)+len(appendAnnAddrs))
		addrs = append(addrs, base...)
		addrs = append(addrs, appendAnnAddrs...)

		var out []multiaddr.Multiaddr
		for _, maddr := range addrs {
			// check for exact matches
			ok := noAnnAddrs[string(maddr.Bytes())]
			// check for /ipcidr matches
			if !ok && !filters.AddrBlocked(maddr) {
				out = append(out, maddr)
			}
		}
		return out
	}
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package libp2p
import (
"errors"
"sync"
"github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/protocol"
"gitlab.com/nunet/device-management-service/types"
)
// StreamHandler is a function type that processes an incoming network stream.
type StreamHandler func(stream network.Stream)
// HandlerRegistry manages the registration of stream handlers for different
// protocols on a host.
type HandlerRegistry struct {
	// host is where stream handlers are installed and removed.
	host host.Host
	// handlers holds the stream callbacks registered per protocol ID.
	handlers map[protocol.ID]StreamHandler
	// bytesHandlers holds the byte callbacks registered per protocol ID; they
	// are invoked by SendMessageToLocalHandler.
	bytesHandlers map[protocol.ID]func(data []byte)
	// mu guards both maps.
	mu sync.RWMutex
}
// NewHandlerRegistry returns an empty registry that installs its handlers on
// host.
func NewHandlerRegistry(host host.Host) *HandlerRegistry {
	registry := new(HandlerRegistry)
	registry.host = host
	registry.handlers = make(map[protocol.ID]StreamHandler)
	registry.bytesHandlers = make(map[protocol.ID]func(data []byte))
	return registry
}
// RegisterHandlerWithStreamCallback records handler for the protocol derived
// from messageType and installs it as the host's stream handler. It fails
// when a stream callback is already registered for that protocol.
func (r *HandlerRegistry) RegisterHandlerWithStreamCallback(messageType types.MessageType, handler StreamHandler) error {
	id := protocol.ID(messageType)

	r.mu.Lock()
	defer r.mu.Unlock()

	if _, exists := r.handlers[id]; exists {
		return errors.New("stream with this protocol is already registered")
	}

	r.handlers[id] = handler
	r.host.SetStreamHandler(id, network.StreamHandler(handler))
	return nil
}
// RegisterHandlerWithBytesCallback installs s as the host's stream handler
// for the protocol derived from messageType and records handler as the byte
// callback used by SendMessageToLocalHandler. It fails when a byte callback
// is already registered for that protocol.
//
// NOTE(review): only the byte-callback map is checked for duplicates, so a
// protocol already registered with a stream callback is silently replaced on
// the host — confirm this is intended.
func (r *HandlerRegistry) RegisterHandlerWithBytesCallback(messageType types.MessageType, s StreamHandler, handler func(data []byte)) error {
	id := protocol.ID(messageType)

	r.mu.Lock()
	defer r.mu.Unlock()

	if _, exists := r.bytesHandlers[id]; exists {
		return errors.New("stream with this protocol is already registered")
	}

	r.bytesHandlers[id] = handler
	r.host.SetStreamHandler(id, network.StreamHandler(s))
	return nil
}
// SendMessageToLocalHandler passes data to the byte callback registered for
// messageType, if any. The callback runs in its own goroutine; nothing
// happens when no callback is registered.
func (r *HandlerRegistry) SendMessageToLocalHandler(messageType types.MessageType, data []byte) {
	r.mu.RLock()
	defer r.mu.RUnlock()

	callback, found := r.bytesHandlers[protocol.ID(messageType)]
	if found {
		// we need this goroutine to avoid blocking the caller goroutine
		go callback(data)
	}
}
// UnregisterHandler forgets both the stream callback and the byte callback
// for the protocol derived from messageType and removes the stream handler
// from the host.
func (r *HandlerRegistry) UnregisterHandler(messageType types.MessageType) {
	id := protocol.ID(messageType)

	r.mu.Lock()
	defer r.mu.Unlock()

	delete(r.handlers, id)
	delete(r.bytesHandlers, id)
	r.host.RemoveStreamHandler(id)
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package libp2p
import (
"context"
"strings"
"time"
"github.com/libp2p/go-libp2p"
dht "github.com/libp2p/go-libp2p-kad-dht"
pubsub "github.com/libp2p/go-libp2p-pubsub"
"github.com/libp2p/go-libp2p/core/event"
"github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/protocol"
"github.com/libp2p/go-libp2p/core/routing"
"github.com/libp2p/go-libp2p/p2p/host/autorelay"
"github.com/libp2p/go-libp2p/p2p/host/peerstore/pstoremem"
"github.com/libp2p/go-libp2p/p2p/host/resource-manager"
"github.com/libp2p/go-libp2p/p2p/net/connmgr"
"github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay"
"github.com/libp2p/go-libp2p/p2p/security/noise"
libp2ptls "github.com/libp2p/go-libp2p/p2p/security/tls"
quic "github.com/libp2p/go-libp2p/p2p/transport/quic"
"github.com/libp2p/go-libp2p/p2p/transport/tcp"
ws "github.com/libp2p/go-libp2p/p2p/transport/websocket"
webtransport "github.com/libp2p/go-libp2p/p2p/transport/webtransport"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
mafilt "github.com/whyrusleeping/multiaddr-filter"
"gitlab.com/nunet/device-management-service/types"
)
// NewHost returns a new libp2p host with dht and other related settings.
// NewHost builds the libp2p host together with its Kademlia DHT and GossipSub
// instance.
//
// The host is configured with a connection manager, a resource manager whose
// limits are scaled by config.Memory (interpreted as MiB, defaulting to 1 GiB)
// and config.FileDescriptors (defaulting to 512), TLS and Noise security,
// TCP/QUIC/WebTransport/WebSocket transports, NAT service, relay, hole
// punching and auto-relay. Auto-relay candidates are fed from peers reported
// by watchForNewPeers. When config.Server is set, the default server address
// filters are applied both to advertised addresses and as a connection gater.
//
// appScore and scoreInspect are handed to GossipSub peer scoring.
//
// On any failure all returned values except the error are nil.
func NewHost(ctx context.Context, config *types.Libp2pConfig, appScore func(p peer.ID) float64, scoreInspect pubsub.ExtendedPeerScoreInspectFn) (host.Host, *dht.IpfsDHT, *pubsub.PubSub, error) {
	newPeer := make(chan peer.AddrInfo)
	var idht *dht.IpfsDHT
	connmgr, err := connmgr.NewConnManager(
		100,
		400,
		connmgr.WithGracePeriod(time.Duration(config.GracePeriodMs)*time.Millisecond),
	)
	if err != nil {
		return nil, nil, nil, err
	}
	filter := ma.NewFilters()
	for _, s := range defaultServerFilters {
		f, err := mafilt.NewMask(s)
		if err != nil {
			log.Errorf("incorrectly formatted address filter in config: %s - %v", s, err)
			// f is nil when parsing fails; skip it instead of dereferencing.
			continue
		}
		filter.AddFilter(*f, ma.ActionDeny)
	}
	ps, err := pstoremem.NewPeerstore()
	if err != nil {
		return nil, nil, nil, err
	}
	var libp2pOpts []libp2p.Option
	dhtOpts := []dht.Option{
		dht.ProtocolPrefix(protocol.ID(config.DHTPrefix)),
		dht.NamespacedValidator(strings.ReplaceAll(config.CustomNamespace, "/", ""), dhtValidator{PS: ps}),
		dht.Mode(dht.ModeAutoServer),
	}
	// set up the resource manager
	mem := int64(config.Memory)
	if mem > 0 {
		mem = 1024 * 1024 * mem
	} else {
		mem = 1024 * 1024 * 1024 // 1GB
	}
	fds := config.FileDescriptors
	if fds == 0 {
		fds = 512
	}
	limits := rcmgr.DefaultLimits
	limits.SystemBaseLimit.ConnsInbound = 512
	limits.SystemBaseLimit.ConnsOutbound = 512
	limits.SystemBaseLimit.Conns = 1024
	limits.SystemBaseLimit.StreamsInbound = 8192
	limits.SystemBaseLimit.StreamsOutbound = 8192
	limits.SystemBaseLimit.Streams = 16384
	scaled := limits.Scale(mem, fds)
	log.Infof("libp2p limits: %+v", scaled)
	mgr, err := rcmgr.NewResourceManager(rcmgr.NewFixedLimiter(scaled))
	if err != nil {
		return nil, nil, nil, err
	}
	libp2pOpts = append(libp2pOpts, libp2p.ListenAddrStrings(config.ListenAddress...),
		libp2p.ResourceManager(mgr),
		libp2p.Identity(config.PrivateKey),
		// The DHT is created lazily by libp2p once the host exists; the
		// result is captured in idht so it can be returned to the caller.
		libp2p.Routing(func(h host.Host) (routing.PeerRouting, error) {
			idht, err = dht.New(ctx, h, dhtOpts...)
			return idht, err
		}),
		libp2p.Peerstore(ps),
		libp2p.Security(libp2ptls.ID, libp2ptls.New),
		libp2p.Security(noise.ID, noise.New),
		// libp2p.NoListenAddrs,
		libp2p.ChainOptions(
			libp2p.Transport(tcp.NewTCPTransport),
			libp2p.Transport(quic.NewTransport),
			libp2p.Transport(webtransport.New),
			libp2p.Transport(ws.New),
		),
		libp2p.EnableNATService(),
		libp2p.ConnectionManager(connmgr),
		libp2p.EnableRelay(),
		libp2p.EnableHolePunching(),
		libp2p.EnableRelayService(
			relay.WithLimit(&relay.RelayLimit{
				Duration: 5 * time.Minute,
				Data:     1 << 21, // 2 MiB
			}),
		),
		libp2p.EnableAutoRelayWithPeerSource(
			// Forward up to num peers from newPeer to the auto-relay
			// subsystem, stopping early when ctx is cancelled.
			func(ctx context.Context, num int) <-chan peer.AddrInfo {
				r := make(chan peer.AddrInfo)
				go func() {
					defer close(r)
					for i := 0; i < num; i++ {
						select {
						case p := <-newPeer:
							select {
							case r <- p:
							case <-ctx.Done():
								return
							}
						case <-ctx.Done():
							return
						}
					}
				}()
				return r
			},
			autorelay.WithBootDelay(time.Minute),
			autorelay.WithBackoff(30*time.Second),
			autorelay.WithMinCandidates(2),
			autorelay.WithMaxCandidates(3),
			autorelay.WithNumRelays(2),
		),
	)
	if config.Server {
		libp2pOpts = append(libp2pOpts, libp2p.AddrsFactory(makeAddrsFactory([]string{}, []string{}, defaultServerFilters)))
		libp2pOpts = append(libp2pOpts, libp2p.ConnectionGater((*filtersConnectionGater)(filter)))
	}
	host, err := libp2p.New(libp2pOpts...)
	if err != nil {
		return nil, nil, nil, err
	}
	go watchForNewPeers(ctx, host, newPeer)
	optsPS := []pubsub.Option{
		pubsub.WithFloodPublish(true),
		pubsub.WithMessageSigning(true),
		pubsub.WithPeerScore(
			&pubsub.PeerScoreParams{
				SkipAtomicValidation: true,
				Topics:               make(map[string]*pubsub.TopicScoreParams),
				TopicScoreCap:        10,
				AppSpecificScore:     appScore,
				AppSpecificWeight:    1,
				DecayInterval:        time.Hour,
				DecayToZero:          0.001,
				RetainScore:          6 * time.Hour,
			},
			&pubsub.PeerScoreThresholds{
				GossipThreshold:             -500,
				PublishThreshold:            -1000,
				GraylistThreshold:           -2500,
				AcceptPXThreshold:           0, // TODO for public mainnet we should limit to bootstrappers and set them up without a mesh
				OpportunisticGraftThreshold: 2.5,
			},
		),
		pubsub.WithPeerExchange(true),
		pubsub.WithPeerScoreInspect(scoreInspect, time.Second),
	}
	if config.GossipMaxMessageSize > 0 {
		optsPS = append(optsPS, pubsub.WithMaxMessageSize(config.GossipMaxMessageSize))
	}
	gossip, err := pubsub.NewGossipSub(ctx, host, optsPS...)
	// gossip, err := pubsub.NewGossipSubWithRouter(ctx, host, pubsub.DefaultGossipSubRouter(host), optsPS...)
	if err != nil {
		return nil, nil, nil, err
	}
	return host, idht, gossip, nil
}
// watchForNewPeers listens on the host event bus for completed peer
// identifications and forwards every peer that advertises at least one public
// listen address to newPeer.
//
// The function blocks until ctx is cancelled, so it is meant to be run in its
// own goroutine. newPeer is consumed on demand by the auto-relay peer source,
// therefore the send also watches ctx: otherwise this goroutine (and its
// event-bus subscription) would stay blocked forever once nobody reads.
func watchForNewPeers(ctx context.Context, host host.Host, newPeer chan peer.AddrInfo) {
	sub, err := host.EventBus().Subscribe([]interface{}{
		&event.EvtPeerIdentificationCompleted{},
		&event.EvtPeerProtocolsUpdated{},
	})
	if err != nil {
		log.Errorf("failed to subscribe to peer identification events: %v", err)
		return
	}
	defer sub.Close()
	for ctx.Err() == nil {
		var ev any
		select {
		case <-ctx.Done():
			return
		case ev = <-sub.Out():
		}
		identified, ok := ev.(event.EvtPeerIdentificationCompleted)
		if !ok {
			continue
		}
		var publicAddrs []ma.Multiaddr
		for _, addr := range identified.ListenAddrs {
			if manet.IsPublicAddr(addr) {
				publicAddrs = append(publicAddrs, addr)
			}
		}
		if len(publicAddrs) == 0 {
			continue
		}
		select {
		case newPeer <- peer.AddrInfo{ID: identified.Peer, Addrs: publicAddrs}:
		case <-ctx.Done():
			return
		}
	}
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package libp2p
import (
"bufio"
"bytes"
"context"
"crypto/sha256"
"encoding/binary"
"errors"
"fmt"
"io"
"strings"
"sync"
"time"
cid "github.com/ipfs/go-cid"
dht "github.com/libp2p/go-libp2p-kad-dht"
kbucket "github.com/libp2p/go-libp2p-kbucket"
pubsub "github.com/libp2p/go-libp2p-pubsub"
libp2pdiscovery "github.com/libp2p/go-libp2p/core/discovery"
"github.com/libp2p/go-libp2p/core/event"
"github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/peerstore"
"github.com/libp2p/go-libp2p/core/protocol"
drouting "github.com/libp2p/go-libp2p/p2p/discovery/routing"
"github.com/libp2p/go-libp2p/p2p/protocol/ping"
multiaddr "github.com/multiformats/go-multiaddr"
multihash "github.com/multiformats/go-multihash"
msmux "github.com/multiformats/go-multistream"
"github.com/spf13/afero"
"google.golang.org/protobuf/proto"
bt "gitlab.com/nunet/device-management-service/internal/background_tasks"
"gitlab.com/nunet/device-management-service/observability"
commonproto "gitlab.com/nunet/device-management-service/proto/generated/v1/common"
"gitlab.com/nunet/device-management-service/types"
)
// Size limits, timeouts and re-exported pubsub validation results used by the
// Libp2p implementation in this file.
const (
// MB is the number of bytes in one mebibyte.
MB = 1024 * 1024
// maxMessageLengthMB caps, in units of MB, the size of a stream message
// accepted by handleReadBytesFromStream and sent by sendMessage.
maxMessageLengthMB = 1
// ValidationAccept, ValidationReject and ValidationIgnore re-export the
// pubsub validation results so callers need not import pubsub directly.
ValidationAccept = pubsub.ValidationAccept
ValidationReject = pubsub.ValidationReject
ValidationIgnore = pubsub.ValidationIgnore
// readTimeout is the read deadline applied to an incoming stream before
// its message is read.
readTimeout = 30 * time.Second
// sendSemaphoreLimit is the capacity of the send semaphore channel and
// therefore the maximum number of concurrent asynchronous sends.
sendSemaphoreLimit = 4096
)
// Aliases for libp2p/pubsub types so users of this package can refer to them
// without importing the underlying libraries.
type (
// PeerID aliases peer.ID.
PeerID = peer.ID
// ProtocolID aliases protocol.ID.
ProtocolID = protocol.ID
// Topic aliases pubsub.Topic.
Topic = pubsub.Topic
// PubSub aliases pubsub.PubSub.
PubSub = pubsub.PubSub
// ValidationResult aliases pubsub.ValidationResult.
ValidationResult = pubsub.ValidationResult
// Validator inspects a message payload together with the validator data
// produced so far and returns a validation result plus the validator
// data to attach to the message (see Libp2p.validate).
Validator func([]byte, interface{}) (ValidationResult, interface{})
// PeerScoreSnapshot aliases pubsub.PeerScoreSnapshot.
PeerScoreSnapshot = pubsub.PeerScoreSnapshot
)
// Libp2p contains the configuration for a Libp2p instance.
//
// A value is created with New, wired to a live host with Init and brought
// online with Start; Stop releases the resources again.
//
// TODO-suggestion: maybe we should call it something else like Libp2pPeer,
// Libp2pHost or just Peer (callers would use libp2p.Peer...)
type Libp2p struct {
// Host, DHT and PS are populated by Init from the host built by NewHost.
Host host.Host
DHT *dht.IpfsDHT
PS peerstore.Peerstore
// pubsub is the GossipSub instance created by NewHost.
pubsub *PubSub
// ctx and cancel are created in Init; cancel is invoked by Stop.
ctx context.Context
cancel func()
// mx guards pubsubAppScore and pubsubScore.
mx sync.Mutex
pubsubAppScore func(peer.ID) float64
pubsubScore map[peer.ID]*PeerScoreSnapshot
// topicMux is taken by Subscribe, Unsubscribe, getOrJoinTopicHandler and
// validate around the topic maps and the subscription id counter below.
topicMux sync.RWMutex
pubsubTopics map[string]*Topic
topicValidators map[string]map[uint64]Validator
topicSubscription map[string]map[uint64]*pubsub.Subscription
// nextTopicSubID is the id handed to the next Subscribe call.
nextTopicSubID uint64
// send backpressure semaphore
sendSemaphore chan struct{}
// a list of peers discovered by discovery
discoveredPeers []peer.AddrInfo
discovery libp2pdiscovery.Discovery
// services
pingService *ping.PingService
// tasks
// Both tasks are registered with the scheduler in Start and removed in Stop.
discoveryTask *bt.Task
advertiseRendezvousTask *bt.Task
// handlerRegistry holds the stream handlers registered per message type.
handlerRegistry *HandlerRegistry
config *types.Libp2pConfig
// dependencies (db, filesystem...)
fs afero.Fs
// subnets is initialised empty by New.
// NOTE(review): presumably keyed by subnet id — confirm against the subnet code.
subnets map[string]*subnet
// NOTE(review): presumably an atomically accessed flag — confirm against the subnet code.
isSubnetWriteProtocolRegistered int32
}
// New returns an uninitialised Libp2p instance bound to config and fs.
//
// It fails when config or config.Scheduler is nil. No network resources are
// allocated here; call Init and then Start to bring the instance online.
//
// TODO-Suggestion: move types.Libp2pConfig to here for better readability.
// Unless there is a reason to keep within types.
func New(config *types.Libp2pConfig, fs afero.Fs) (*Libp2p, error) {
	switch {
	case config == nil:
		return nil, errors.New("config is nil")
	case config.Scheduler == nil:
		return nil, errors.New("scheduler is nil")
	}

	instance := &Libp2p{
		config:            config,
		fs:                fs,
		discoveredPeers:   make([]peer.AddrInfo, 0),
		pubsubTopics:      make(map[string]*pubsub.Topic),
		topicSubscription: make(map[string]map[uint64]*pubsub.Subscription),
		topicValidators:   make(map[string]map[uint64]Validator),
		subnets:           make(map[string]*subnet),
		// Buffered channel used as a counting semaphore for async sends.
		sendSemaphore: make(chan struct{}, sendSemaphoreLimit),
	}
	return instance, nil
}
// Init creates the libp2p host, DHT and pubsub via NewHost and stores them,
// along with a cancellable context, on the receiver. It then initialises the
// observability package with the new host.
//
// If NewHost fails the context is cancelled and the error is logged and
// returned; an observability failure is returned wrapped.
func (l *Libp2p) Init() error {
	ctx, cancel := context.WithCancel(context.Background())

	h, kad, gossip, err := NewHost(ctx, l.config, l.broadcastAppScore, l.broadcastScoreInspect)
	if err != nil {
		cancel()
		log.Error(err)
		return err
	}

	l.ctx, l.cancel = ctx, cancel
	l.Host, l.DHT = h, kad
	l.PS = h.Peerstore()
	l.discovery = drouting.NewRoutingDiscovery(kad)
	l.pubsub = gossip
	l.handlerRegistry = NewHandlerRegistry(h)

	// Initialize the observability package with the host
	if obsErr := observability.Initialize(l.Host); obsErr != nil {
		return fmt.Errorf("failed to initialize observability: %w", obsErr)
	}
	return nil
}
// Start performs network bootstrapping, peer discovery and protocols handling.
//
// It registers the stream handlers, connects to the bootstrap nodes,
// bootstraps the DHT, starts the random walk and the local address watcher,
// and registers two periodic scheduler tasks (peer discovery every 15 minutes
// and rendezvous advertisement every 6 hours) before starting the scheduler.
// A background goroutine performs the first advertisement and discovery one
// minute after start, unless the instance context is cancelled first.
//
// An error is returned only when connecting to the bootstrap nodes or
// bootstrapping the DHT fails; in that case no tasks have been registered.
func (l *Libp2p) Start() error {
	// set stream handlers
	l.registerStreamHandlers()

	// connect to bootstrap nodes
	if err := l.ConnectToBootstrapNodes(l.ctx); err != nil {
		// Structured (key/value) logging needs the *w variant; the *f
		// variant would treat the key/value pairs as format arguments.
		log.Errorw("libp2p_bootstrap_failure", "error", err)
		return err
	}
	log.Infow("libp2p_bootstrap_success")

	if err := l.BootstrapDHT(l.ctx); err != nil {
		log.Errorw("libp2p_bootstrap_failure", "error", err)
		return err
	}
	log.Infow("libp2p_bootstrap_success")

	// Start random walk
	l.startRandomWalk(l.ctx)

	// watch for local address change
	go l.watchForAddrsChange(l.ctx)

	// discover
	go func() {
		// wait for dht bootstrap, but give up as soon as we are stopped
		select {
		case <-time.After(1 * time.Minute):
		case <-l.ctx.Done():
			return
		}

		// advertise rendezvous discovery; errors are kept local to this
		// goroutine so they never race with the caller's variables.
		if err := l.advertiseForRendezvousDiscovery(l.ctx); err != nil {
			log.Warnw("libp2p_advertise_rendezvous_failure", "error", err)
		} else {
			log.Infow("libp2p_advertise_rendezvous_success")
		}

		if err := l.DiscoverDialPeers(l.ctx); err != nil {
			log.Warnw("libp2p_peer_discover_failure", "error", err)
		} else {
			log.Infow("libp2p_peer_discover_success", "foundPeers", len(l.discoveredPeers))
		}
	}()

	// register periodic peer discovery task
	discoveryTask := &bt.Task{
		Name:        "Peer Discovery",
		Description: "Periodic task to discover new peers every 15 minutes",
		Function: func(_ interface{}) error {
			return l.DiscoverDialPeers(l.ctx)
		},
		Triggers: []bt.Trigger{&bt.PeriodicTrigger{Interval: 15 * time.Minute}},
	}
	l.discoveryTask = l.config.Scheduler.AddTask(discoveryTask)

	// register rendezvous advertisement task
	advertiseRendezvousTask := &bt.Task{
		Name:        "Rendezvous advertisement",
		Description: "Periodic task to advertise a rendezvous point every 6 hours",
		Function: func(_ interface{}) error {
			return l.advertiseForRendezvousDiscovery(l.ctx)
		},
		Triggers: []bt.Trigger{&bt.PeriodicTrigger{Interval: 6 * time.Hour}},
	}
	l.advertiseRendezvousTask = l.config.Scheduler.AddTask(advertiseRendezvousTask)

	l.config.Scheduler.Start()
	return nil
}
// RegisterStreamMessageHandler installs handler as the stream callback for
// messageType in the handler registry.
//
// An empty messageType is rejected; registry failures are returned wrapped.
func (l *Libp2p) RegisterStreamMessageHandler(messageType types.MessageType, handler StreamHandler) error {
	if messageType == "" {
		return errors.New("message type is empty")
	}

	err := l.handlerRegistry.RegisterHandlerWithStreamCallback(messageType, handler)
	if err == nil {
		return nil
	}
	return fmt.Errorf("failed to register handler %s: %w", messageType, err)
}
// RegisterBytesMessageHandler installs handler as the bytes callback for
// messageType. Incoming streams are read by handleReadBytesFromStream, which
// passes the decoded payload to handler.
//
// An empty messageType is rejected; registry failures are returned wrapped.
func (l *Libp2p) RegisterBytesMessageHandler(messageType types.MessageType, handler func(data []byte)) error {
	if messageType == "" {
		return errors.New("message type is empty")
	}

	err := l.handlerRegistry.RegisterHandlerWithBytesCallback(messageType, l.handleReadBytesFromStream, handler)
	if err == nil {
		return nil
	}
	return fmt.Errorf("failed to register handler %s: %w", messageType, err)
}
// HandleMessage is the string-typed convenience form of
// RegisterBytesMessageHandler: messageType is converted to types.MessageType
// and handler receives the raw bytes of each incoming message.
func (l *Libp2p) HandleMessage(messageType string, handler func(data []byte)) error {
	typed := types.MessageType(messageType)
	return l.RegisterBytesMessageHandler(typed, handler)
}
// handleReadBytesFromStream reads one length-prefixed message from s and
// passes its payload to the bytes handler registered for the stream's
// protocol.
//
// Wire format: an 8-byte little-endian length followed by that many payload
// bytes. The stream is reset when no handler is registered, when the read
// deadline cannot be set, when the announced length exceeds
// maxMessageLengthMB*MB, or when reading fails. On success the stream is
// closed before the callback runs.
func (l *Libp2p) handleReadBytesFromStream(s network.Stream) {
	l.handlerRegistry.mu.RLock()
	callback, ok := l.handlerRegistry.bytesHandlers[s.Protocol()]
	l.handlerRegistry.mu.RUnlock()
	if !ok {
		_ = s.Reset()
		return
	}
	if err := s.SetReadDeadline(time.Now().Add(readTimeout)); err != nil {
		_ = s.Reset()
		log.Warnf("error setting read deadline: %s", err)
		return
	}
	c := bufio.NewReader(s)
	defer s.Close()

	// Read exactly 8 bytes for the size prefix. A plain Read may return
	// fewer bytes than requested, which would corrupt the decoded length.
	msgLengthBuffer := make([]byte, 8)
	if _, err := io.ReadFull(c, msgLengthBuffer); err != nil {
		log.Debugf("error reading message length: %s", err)
		_ = s.Reset()
		return
	}
	lengthPrefix := binary.LittleEndian.Uint64(msgLengthBuffer)

	// check if the message length is greater than max allowed
	if lengthPrefix > maxMessageLengthMB*MB {
		_ = s.Reset()
		log.Warnf("message length exceeds maximum: %d", lengthPrefix)
		return
	}

	// create a buffer with the size of the message and read until it is full
	buf := make([]byte, lengthPrefix)
	if _, err := io.ReadFull(c, buf); err != nil {
		log.Debugf("error reading message: %s", err)
		_ = s.Reset()
		return
	}

	// Release the stream before running the (possibly slow) callback.
	_ = s.Close()
	callback(buf)
}
// UnregisterMessageHandler removes the handler registered for messageType
// from the handler registry.
func (l *Libp2p) UnregisterMessageHandler(messageType string) {
	typed := types.MessageType(messageType)
	l.handlerRegistry.UnregisterHandler(typed)
}
// SendMessage queues msg for asynchronous delivery to the peer identified by
// hostID.
//
// Messages addressed to the local host are handed straight to the locally
// registered handler. For remote peers a slot in the send semaphore is
// acquired and the actual send runs in its own goroutine; the call returns
// nil as soon as the send has been scheduled. If no slot becomes available
// before ctx is done or expiry passes, the context error is returned.
func (l *Libp2p) SendMessage(ctx context.Context, hostID string, msg types.MessageEnvelope, expiry time.Time) error {
	target, decodeErr := peer.Decode(hostID)
	if decodeErr != nil {
		return fmt.Errorf("send: invalid peer ID: %w", decodeErr)
	}

	// Local delivery: bypass the network and call the registered handler.
	if target == l.Host.ID() {
		l.handlerRegistry.SendMessageToLocalHandler(msg.Type, msg.Data)
		return nil
	}

	sendCtx, cancel := context.WithTimeout(ctx, time.Until(expiry))
	deliver := func() {
		defer cancel()
		// Free the semaphore slot once the send attempt is over.
		defer func() { <-l.sendSemaphore }()
		l.sendMessage(sendCtx, target, msg, expiry, nil)
	}

	select {
	case l.sendSemaphore <- struct{}{}:
		go deliver()
		return nil
	case <-sendCtx.Done():
		cancel()
		return sendCtx.Err()
	}
}
// SendMessageSync delivers msg to the peer identified by hostID and waits for
// the outcome of the send.
//
// Messages addressed to the local host are handed straight to the locally
// registered handler. Otherwise the send runs on the calling goroutine under a
// context that expires at expiry, and the error reported by sendMessage is
// returned.
func (l *Libp2p) SendMessageSync(ctx context.Context, hostID string, msg types.MessageEnvelope, expiry time.Time) error {
	target, decodeErr := peer.Decode(hostID)
	if decodeErr != nil {
		return fmt.Errorf("send: invalid peer ID: %w", decodeErr)
	}

	if target == l.Host.ID() {
		l.handlerRegistry.SendMessageToLocalHandler(msg.Type, msg.Data)
		return nil
	}

	sendCtx, cancel := context.WithTimeout(ctx, time.Until(expiry))
	defer cancel()

	// Capacity 1 lets sendMessage report its result without blocking.
	outcome := make(chan error, 1)
	l.sendMessage(sendCtx, target, msg, expiry, outcome)
	return <-outcome
}
// newStream opens a stream to pid over an existing connection (no dialing)
// and negotiates proto on it. The stream is reset when negotiation or
// setting the protocol fails.
//
// workaround for https://github.com/libp2p/go-libp2p/issues/2983
func (l *Libp2p) newStream(ctx context.Context, pid peer.ID, proto protocol.ID) (network.Stream, error) {
	noDialCtx := network.WithNoDial(ctx, "already dialed")
	stream, err := l.Host.Network().NewStream(noDialCtx, pid)
	if err != nil {
		return nil, err
	}

	negotiated, err := msmux.SelectOneOf([]protocol.ID{proto}, stream)
	if err == nil {
		err = stream.SetProtocol(negotiated)
	}
	if err != nil {
		_ = stream.Reset()
		return nil, err
	}
	return stream, nil
}
// sendMessage delivers msg to pid over a freshly opened stream.
//
// If the peer is not connected its address is resolved and a connection is
// established first. The payload is written as an 8-byte little-endian
// length followed by msg.Data; payloads whose framed size exceeds
// maxMessageLengthMB*MB are rejected. expiry is used as the write deadline.
//
// When result is non-nil the final error (nil on success) is sent on it
// exactly once; the channel must be able to accept that value without
// blocking the caller's progress.
func (l *Libp2p) sendMessage(ctx context.Context, pid peer.ID, msg types.MessageEnvelope, expiry time.Time, result chan error) {
	var err error
	defer func() {
		if result != nil {
			result <- err
		}
	}()

	if !l.PeerConnected(pid) {
		var ai peer.AddrInfo
		ai, err = l.resolvePeerAddress(ctx, pid)
		if err != nil {
			log.Warnf("send: error resolving addresses for peer %s: %s", pid, err)
			return
		}
		if err = l.Host.Connect(ctx, ai); err != nil {
			log.Warnf("send: failed to connect to peer %s: %s", pid, err)
			return
		}
	}

	requestBufferSize := 8 + len(msg.Data)
	if requestBufferSize > maxMessageLengthMB*MB {
		log.Warnf("send: message size %d is greater than limit %d bytes", requestBufferSize, maxMessageLengthMB*MB)
		err = fmt.Errorf("message too large")
		return
	}

	ctx = network.WithAllowLimitedConn(ctx, "send message")
	var stream network.Stream
	stream, err = l.newStream(ctx, pid, protocol.ID(msg.Type))
	if err != nil {
		log.Warnf("send: failed to open stream to peer %s: %s", pid, err)
		return
	}
	defer stream.Close()

	if err = stream.SetWriteDeadline(expiry); err != nil {
		_ = stream.Reset()
		log.Warnf("send: failed to set write deadline to peer %s: %s", pid, err)
		return
	}

	requestPayloadWithLength := make([]byte, requestBufferSize)
	binary.LittleEndian.PutUint64(requestPayloadWithLength, uint64(len(msg.Data)))
	copy(requestPayloadWithLength[8:], msg.Data)

	if _, err = stream.Write(requestPayloadWithLength); err != nil {
		_ = stream.Reset()
		log.Warnf("send: failed to send message to peer %s: %s", pid, err)
		// Stop here: continuing would overwrite err with the CloseWrite
		// result and log a successful send.
		return
	}
	if err = stream.CloseWrite(); err != nil {
		_ = stream.Reset()
		log.Warnf("send: failed to flush output to peer %s: %s", pid, err)
		return
	}
	log.Debugf("send %d bytes to peer %s", len(requestPayloadWithLength), pid)
}
// OpenStream connects to the peer described by the p2p multiaddress addr and
// opens a stream for messageType on it. The caller owns the returned stream.
//
// Errors from parsing the address, extracting the peer info, connecting and
// opening the stream are each returned wrapped.
func (l *Libp2p) OpenStream(ctx context.Context, addr string, messageType types.MessageType) (network.Stream, error) {
	parsed, err := multiaddr.NewMultiaddr(addr)
	if err != nil {
		return nil, fmt.Errorf("invalid multiaddress: %w", err)
	}

	remote, err := peer.AddrInfoFromP2pAddr(parsed)
	if err != nil {
		return nil, fmt.Errorf("could not resolve peer info: %w", err)
	}

	err = l.Host.Connect(ctx, *remote)
	if err != nil {
		return nil, fmt.Errorf("failed to connect to peer: %w", err)
	}

	opened, err := l.Host.NewStream(ctx, remote.ID, protocol.ID(messageType))
	if err != nil {
		return nil, fmt.Errorf("failed to open stream: %w", err)
	}
	return opened, nil
}
// GetMultiaddr returns the host's addresses as full p2p multiaddresses, as
// produced by peer.AddrInfoToP2pAddrs for the host ID and its current
// addresses.
func (l *Libp2p) GetMultiaddr() ([]multiaddr.Multiaddr, error) {
	self := &peer.AddrInfo{
		ID:    l.Host.ID(),
		Addrs: l.Host.Addrs(),
	}
	return peer.AddrInfoToP2pAddrs(self)
}
// Stop performs a cleanup of any resources used in this package.
//
// It cancels the instance context, removes the periodic scheduler tasks and
// closes the DHT and the host. Errors from closing are collected and returned
// as a single error joined with "; ".
//
// Stop tolerates a partially started instance: components that were never
// created (for example because Init was not called or Start returned early)
// are skipped instead of causing a nil dereference.
func (l *Libp2p) Stop() error {
	var errorMessages []string

	if l.cancel != nil {
		l.cancel()
	}
	if l.discoveryTask != nil {
		l.config.Scheduler.RemoveTask(l.discoveryTask.ID)
	}
	if l.advertiseRendezvousTask != nil {
		l.config.Scheduler.RemoveTask(l.advertiseRendezvousTask.ID)
	}
	if l.DHT != nil {
		if err := l.DHT.Close(); err != nil {
			errorMessages = append(errorMessages, err.Error())
		}
	}
	if l.Host != nil {
		if err := l.Host.Close(); err != nil {
			errorMessages = append(errorMessages, err.Error())
		}
	}

	if len(errorMessages) > 0 {
		return errors.New(strings.Join(errorMessages, "; "))
	}
	return nil
}
// Stat reports the host's peer ID and its addresses, the latter rendered as a
// single ", "-separated string.
func (l *Libp2p) Stat() types.NetworkStats {
	rendered := make([]string, 0, len(l.Host.Addrs()))
	for _, a := range l.Host.Addrs() {
		rendered = append(rendered, a.String())
	}

	var stats types.NetworkStats
	stats.ID = l.Host.ID().String()
	stats.ListenAddr = strings.Join(rendered, ", ")
	return stats
}
// GetPeerIP scans the peerstore addresses of p and returns the value of the
// first "ip4" or "ip6" component found, or "" when there is none.
func (l *Libp2p) GetPeerIP(p PeerID) string {
	for _, addr := range l.Host.Peerstore().Addrs(p) {
		// A multiaddr string looks like "/ip4/1.2.3.4/tcp/4001": the value
		// follows its protocol name.
		segments := strings.Split(addr.String(), "/")
		for idx := range segments {
			switch segments[idx] {
			case "ip4", "ip6":
				return segments[idx+1]
			}
		}
	}
	return ""
}
// Ping pings the peer whose encoded peer ID is peerIDAddress, waiting at most
// timeout for a reply.
//
// Pinging the local host is refused. The outcome is reported both in the
// returned PingResult and as the error return value.
//
// TODO (Return error once): something that was confusing me when using this method is that the error is
// returned twice if any. Once as a field of PingResult and one as a return value.
func (l *Libp2p) Ping(ctx context.Context, peerIDAddress string, timeout time.Duration) (types.PingResult, error) {
	// avoid dial to self attempt
	if l.Host.ID().String() == peerIDAddress {
		selfErr := errors.New("can't ping self")
		return types.PingResult{Success: false, Error: selfErr}, selfErr
	}

	pingCtx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()

	target, err := peer.Decode(peerIDAddress)
	if err != nil {
		return types.PingResult{}, err
	}

	select {
	case res := <-ping.Ping(pingCtx, l.Host, target):
		if res.Error == nil {
			return types.PingResult{RTT: res.RTT, Success: true}, nil
		}
		log.Errorf("failed to ping peer %s: %v", peerIDAddress, res.Error)
		failed := types.PingResult{
			Success: false,
			RTT:     res.RTT,
			Error:   res.Error,
		}
		return failed, res.Error
	case <-pingCtx.Done():
		return types.PingResult{Error: pingCtx.Err()}, pingCtx.Err()
	}
}
// ResolveAddress looks up the addresses of the peer with the given encoded id
// and returns each of them as a string with "/p2p/<id>" appended.
func (l *Libp2p) ResolveAddress(ctx context.Context, id string) ([]string, error) {
	info, err := l.resolveAddress(ctx, id)
	if err != nil {
		return nil, err
	}

	full := make([]string, len(info.Addrs))
	for i, a := range info.Addrs {
		full[i] = fmt.Sprintf("%s/p2p/%s", a, id)
	}
	return full, nil
}
// resolveAddress decodes id into a peer ID and delegates to
// resolvePeerAddress. An undecodable id yields a wrapped error and an empty
// AddrInfo.
func (l *Libp2p) resolveAddress(ctx context.Context, id string) (peer.AddrInfo, error) {
	decoded, err := peer.Decode(id)
	if err == nil {
		return l.resolvePeerAddress(ctx, decoded)
	}
	return peer.AddrInfo{}, fmt.Errorf("failed to resolve invalid peer: %w", err)
}
// resolvePeerAddress returns the known addresses for pid.
//
// Resolution order: the local host's own multiaddresses when pid is the local
// host; the peerstore addresses when the peer is currently connected;
// otherwise a DHT FindPeer lookup bounded to three seconds.
func (l *Libp2p) resolvePeerAddress(ctx context.Context, pid peer.ID) (peer.AddrInfo, error) {
	switch {
	case l.Host.ID() == pid:
		// resolve ourself
		own, err := l.GetMultiaddr()
		if err != nil {
			return peer.AddrInfo{}, fmt.Errorf("failed to resolve self: %w", err)
		}
		return peer.AddrInfo{ID: pid, Addrs: own}, nil

	case l.PeerConnected(pid):
		return peer.AddrInfo{ID: pid, Addrs: l.Host.Peerstore().Addrs(pid)}, nil
	}

	lookupCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
	defer cancel()

	found, err := l.DHT.FindPeer(lookupCtx, pid)
	if err != nil {
		return peer.AddrInfo{}, fmt.Errorf("failed to resolve address for peer %s: %w", pid, err)
	}
	return found, nil
}
// Query returns the advertisements published under key.
//
// The DHT is asked for providers of the CID derived from key, and for each
// provider the advertisement stored under that provider's namespaced key is
// fetched. Providers whose value cannot be fetched are skipped; a value that
// cannot be unmarshalled aborts the query with an error.
func (l *Libp2p) Query(ctx context.Context, key string) ([]*commonproto.Advertisement, error) {
	if key == "" {
		return nil, errors.New("advertisement key is empty")
	}

	keyCID, err := createCIDFromKey(key)
	if err != nil {
		return nil, fmt.Errorf("failed to create cid for key %s: %w", key, err)
	}

	providers, err := l.DHT.FindProviders(ctx, keyCID)
	if err != nil {
		return nil, fmt.Errorf("failed to find providers for key %s: %w", key, err)
	}

	advertisements := make([]*commonproto.Advertisement, 0)
	for i := range providers {
		// TODO: use go routines to get the values in parallel.
		dhtKey := l.getCustomNamespace(key, providers[i].ID.String())
		raw, getErr := l.DHT.GetValue(ctx, dhtKey)
		if getErr != nil {
			continue
		}

		ad := new(commonproto.Advertisement)
		if decodeErr := proto.Unmarshal(raw, ad); decodeErr != nil {
			return nil, fmt.Errorf("failed to unmarshal advertisement payload: %w", decodeErr)
		}
		advertisements = append(advertisements, ad)
	}
	return advertisements, nil
}
// Advertise publishes data under key in the DHT.
//
// The data is wrapped in a signed Advertisement envelope (peer ID, Unix
// timestamp, data, public key), stored under this host's namespaced key and
// the host is then announced as a provider of the CID derived from key.
func (l *Libp2p) Advertise(ctx context.Context, key string, data []byte) error {
	if key == "" {
		return errors.New("advertisement key is empty")
	}

	pubKeyBytes, err := l.getPublicKey()
	if err != nil {
		return fmt.Errorf("failed to get public key: %w", err)
	}

	envelope := &commonproto.Advertisement{
		PeerId:    l.Host.ID().String(),
		Timestamp: time.Now().Unix(),
		Data:      data,
		PublicKey: pubKeyBytes,
	}

	// Signed payload: peer id | timestamp | data | public key.
	// NOTE(review): byte(envelope.Timestamp) keeps only the low 8 bits of the
	// timestamp, so that is all the signature covers. Whatever verifies the
	// signature must build the payload identically — confirm this is intended.
	signedParts := [][]byte{
		[]byte(envelope.PeerId),
		{byte(envelope.Timestamp)},
		envelope.Data,
		pubKeyBytes,
	}
	signature, err := l.sign(bytes.Join(signedParts, nil))
	if err != nil {
		return fmt.Errorf("failed to sign advertisement envelope content: %w", err)
	}
	envelope.Signature = signature

	encoded, err := proto.Marshal(envelope)
	if err != nil {
		return fmt.Errorf("failed to marshal advertise envelope: %w", err)
	}

	keyCID, err := createCIDFromKey(key)
	if err != nil {
		return fmt.Errorf("failed to create cid for key %s: %w", key, err)
	}

	dhtKey := l.getCustomNamespace(key, l.DHT.PeerID().String())
	if err = l.DHT.PutValue(ctx, dhtKey, encoded); err != nil {
		return fmt.Errorf("failed to put key %s into the dht: %w", key, err)
	}

	if err = l.DHT.Provide(ctx, keyCID, true); err != nil {
		return fmt.Errorf("failed to provide key %s into the dht: %w", key, err)
	}
	return nil
}
// Unadvertise withdraws this host's advertisement for key by storing a nil
// value under the host's namespaced DHT key.
func (l *Libp2p) Unadvertise(ctx context.Context, key string) error {
	dhtKey := l.getCustomNamespace(key, l.DHT.PeerID().String())
	if err := l.DHT.PutValue(ctx, dhtKey, nil); err != nil {
		return fmt.Errorf("failed to remove key %s from the DHT: %w", key, err)
	}
	return nil
}
// Publish sends data to the given pubsub topic, joining the topic first if
// this instance has not joined it yet.
// The requirements are that only one topic handler should exist per topic.
func (l *Libp2p) Publish(ctx context.Context, topic string, data []byte) error {
	handle, err := l.getOrJoinTopicHandler(topic)
	if err != nil {
		return fmt.Errorf("failed to publish: %w", err)
	}

	if pubErr := handle.Publish(ctx, data); pubErr != nil {
		return fmt.Errorf("failed to publish to topic %s: %w", topic, pubErr)
	}
	return nil
}
// Subscribe subscribes to a topic and sends the messages to the handler.
//
// The topic is joined if necessary. When validator is non-nil it is recorded
// for the subscription, and the shared topic validator (l.validate) is
// registered with pubsub the first time a validator is added for the topic.
// The returned id identifies the subscription for Unsubscribe.
//
// Messages are delivered to handler from a dedicated goroutine, which exits
// when the subscription is cancelled or ctx is done.
func (l *Libp2p) Subscribe(ctx context.Context, topic string, handler func(data []byte), validator Validator) (uint64, error) {
	topicHandler, err := l.getOrJoinTopicHandler(topic)
	if err != nil {
		return 0, fmt.Errorf("failed to subscribe to topic: %w", err)
	}
	sub, err := topicHandler.Subscribe()
	if err != nil {
		return 0, fmt.Errorf("failed to subscribe to topic %s: %w", topic, err)
	}

	l.topicMux.Lock()
	subID := l.nextTopicSubID
	l.nextTopicSubID++
	if validator != nil {
		validatorMap, ok := l.topicValidators[topic]
		if !ok {
			if err := l.pubsub.RegisterTopicValidator(topic, l.validate); err != nil {
				// Release the lock on this error path too; otherwise every
				// later topic operation would block forever.
				l.topicMux.Unlock()
				sub.Cancel()
				return 0, fmt.Errorf("failed to register topic validator: %w", err)
			}
			validatorMap = make(map[uint64]Validator)
			l.topicValidators[topic] = validatorMap
		}
		validatorMap[subID] = validator
	}
	topicMap, ok := l.topicSubscription[topic]
	if !ok {
		topicMap = make(map[uint64]*pubsub.Subscription)
		l.topicSubscription[topic] = topicMap
	}
	topicMap[subID] = sub
	l.topicMux.Unlock()

	go func() {
		for {
			msg, err := sub.Next(ctx)
			if err != nil {
				// Next only fails once the subscription is cancelled or ctx
				// is done; both are permanent, so retrying would spin.
				log.Debugf("subscription %d to topic %s ended: %s", subID, topic, err)
				return
			}
			handler(msg.Data)
		}
	}()
	return subID, nil
}
// validate is the topic validator registered with pubsub. It runs every
// validator recorded for the message's topic and returns the first result
// that is not ValidationAccept; each accepting validator's returned data is
// stored on the message before the next validator runs. Topics without
// validators are accepted.
//
// The validators are copied while holding the read lock and invoked after it
// is released, so Subscribe and Unsubscribe can modify the underlying map
// concurrently without racing with the iteration here.
func (l *Libp2p) validate(_ context.Context, _ peer.ID, msg *pubsub.Message) ValidationResult {
	l.topicMux.RLock()
	registered := l.topicValidators[msg.GetTopic()]
	validators := make([]Validator, 0, len(registered))
	for _, validator := range registered {
		validators = append(validators, validator)
	}
	l.topicMux.RUnlock()

	for _, validator := range validators {
		result, validatorData := validator(msg.Data, msg.ValidatorData)
		if result != ValidationAccept {
			return result
		}
		msg.ValidatorData = validatorData
	}
	return ValidationAccept
}
// SetupBroadcastTopic runs setup against the handle of an already joined
// topic and returns its result. It fails when the topic has not been joined
// through this instance.
//
// The topic lookup takes the topic read lock, like the other accessors of the
// topic map; setup itself runs after the lock has been released.
func (l *Libp2p) SetupBroadcastTopic(topic string, setup func(*Topic) error) error {
	l.topicMux.RLock()
	t, ok := l.pubsubTopics[topic]
	l.topicMux.RUnlock()
	if !ok {
		return fmt.Errorf("not subscribed to %s", topic)
	}
	return setup(t)
}
// SetBroadcastAppScore installs f as the application-specific peer score
// function consulted by broadcastAppScore.
func (l *Libp2p) SetBroadcastAppScore(f func(peer.ID) float64) {
	l.mx.Lock()
	l.pubsubAppScore = f
	l.mx.Unlock()
}
// broadcastAppScore returns the application-specific score for p using the
// function installed via SetBroadcastAppScore, or 0 when none is installed.
// The score function is invoked after the mutex has been released.
func (l *Libp2p) broadcastAppScore(p peer.ID) float64 {
	l.mx.Lock()
	scoreFn := l.pubsubAppScore
	l.mx.Unlock()

	if scoreFn == nil {
		return 0
	}
	return scoreFn(p)
}
// GetBroadcastScore returns the peer score map most recently stored by
// broadcastScoreInspect. The map itself is returned, not a copy.
func (l *Libp2p) GetBroadcastScore() map[peer.ID]*PeerScoreSnapshot {
	l.mx.Lock()
	latest := l.pubsubScore
	l.mx.Unlock()
	return latest
}
// broadcastScoreInspect stores score as the latest peer score snapshot; it is
// passed to NewHost as the pubsub score inspection callback.
func (l *Libp2p) broadcastScoreInspect(score map[peer.ID]*PeerScoreSnapshot) {
	l.mx.Lock()
	l.pubsubScore = score
	l.mx.Unlock()
}
// watchForAddrsChange reconnects to the bootstrap nodes every time the host
// reports that its local addresses changed. It blocks until ctx is cancelled
// and is meant to run in its own goroutine.
func (l *Libp2p) watchForAddrsChange(ctx context.Context) {
	sub, err := l.Host.EventBus().Subscribe(&event.EvtLocalAddressesUpdated{})
	if err != nil {
		log.Errorf("failed to subscribe to event bus: %v", err)
		return
	}
	// Release the event-bus subscription when the watcher exits.
	defer sub.Close()

	for {
		select {
		case <-ctx.Done():
			return
		case <-sub.Out():
			log.Debug("network address changed. trying to be bootstrap again.")
			if err = l.ConnectToBootstrapNodes(l.ctx); err != nil {
				log.Errorf("failed to start network: %v", err)
			}
		}
	}
}
// Notify reports peer lifecycle events through the supplied callbacks.
//
// Before returning, preconnected is invoked synchronously for every peer that
// is currently connected (fully or limited), with its known protocols and
// number of connections. Afterwards a goroutine dispatches event-bus events
// until ctx is done: connected/disconnected on connectedness changes,
// identified when identification completes and updated (with the added
// protocols) when a peer's protocols change.
func (l *Libp2p) Notify(ctx context.Context, preconnected func(peer.ID, []protocol.ID, int), connected, disconnected func(peer.ID), identified, updated func(peer.ID, []protocol.ID)) error {
	sub, err := l.Host.EventBus().Subscribe([]interface{}{
		&event.EvtPeerConnectednessChanged{},
		&event.EvtPeerIdentificationCompleted{},
		&event.EvtPeerProtocolsUpdated{},
	})
	if err != nil {
		return fmt.Errorf("failed to subscribe to event bus: %w", err)
	}

	for _, p := range l.Host.Network().Peers() {
		state := l.Host.Network().Connectedness(p)
		if state != network.Limited && state != network.Connected {
			continue
		}
		protos, _ := l.Host.Peerstore().GetProtocols(p)
		preconnected(p, protos, len(l.Host.Network().ConnsToPeer(p)))
	}

	go func() {
		defer sub.Close()
		for ctx.Err() == nil {
			select {
			case <-ctx.Done():
				return
			case raw := <-sub.Out():
				switch evt := raw.(type) {
				case event.EvtPeerConnectednessChanged:
					switch evt.Connectedness {
					case network.Limited, network.Connected:
						connected(evt.Peer)
					case network.NotConnected:
						disconnected(evt.Peer)
					}
				case event.EvtPeerIdentificationCompleted:
					identified(evt.Peer, evt.Protocols)
				case event.EvtPeerProtocolsUpdated:
					updated(evt.Peer, evt.Added)
				}
			}
		}
	}()
	return nil
}
// PeerConnected reports whether the host currently has a connection to p,
// counting both full (Connected) and Limited connectedness.
func (l *Libp2p) PeerConnected(p PeerID) bool {
	state := l.Host.Network().Connectedness(p)
	return state == network.Limited || state == network.Connected
}
// getOrJoinTopicHandler returns the cached handle for topic, joining the
// topic and caching the new handle when none exists yet. Both publishing and
// subscribing need a handle, hence the shared helper. The topic lock is held
// for the whole call.
func (l *Libp2p) getOrJoinTopicHandler(topic string) (*pubsub.Topic, error) {
	l.topicMux.Lock()
	defer l.topicMux.Unlock()

	if existing, found := l.pubsubTopics[topic]; found {
		return existing, nil
	}

	joined, err := l.pubsub.Join(topic)
	if err != nil {
		return nil, fmt.Errorf("failed to join topic %s: %w", topic, err)
	}
	l.pubsubTopics[topic] = joined
	return joined, nil
}
// Unsubscribe cancels subscription subID on topic and forgets its validator.
// When no subscriptions remain for the topic, the topic handle is dropped
// from the cache and closed.
//
// It fails when the topic was never joined through this instance or when
// closing the topic handle fails.
func (l *Libp2p) Unsubscribe(topic string, subID uint64) error {
	l.topicMux.Lock()
	defer l.topicMux.Unlock()

	handle, joined := l.pubsubTopics[topic]
	if !joined {
		return fmt.Errorf("not subscribed to topic: %s", topic)
	}

	// Deleting from a missing (nil) inner map is a no-op, so no existence
	// check is required here.
	delete(l.topicValidators[topic], subID)

	// delete subscription handler and subscription
	subs := l.topicSubscription[topic]
	if sub, found := subs[subID]; found {
		sub.Cancel()
		delete(subs, subID)
	}

	if len(subs) > 0 {
		return nil
	}

	delete(l.pubsubTopics, topic)
	if err := handle.Close(); err != nil {
		return fmt.Errorf("failed to close topic handler: %w", err)
	}
	return nil
}
// VisiblePeers returns the list of peers found by discovery. The internal
// slice is returned as-is, not a copy.
func (l *Libp2p) VisiblePeers() []peer.AddrInfo {
	discovered := l.discoveredPeers
	return discovered
}
// KnownPeers lists every peer ID present in the host's peerstore. Only the ID
// field of each AddrInfo is filled in; the returned error is always nil.
func (l *Libp2p) KnownPeers() ([]peer.AddrInfo, error) {
	ids := l.Host.Peerstore().Peers()
	infos := make([]peer.AddrInfo, len(ids))
	for i, id := range ids {
		infos[i].ID = id
	}
	return infos, nil
}
// DumpDHTRoutingTable returns the peer infos currently held in the DHT
// routing table. The returned error is always nil.
func (l *Libp2p) DumpDHTRoutingTable() ([]kbucket.PeerInfo, error) {
	infos := l.DHT.RoutingTable().GetPeerInfos()
	return infos, nil
}
// registerStreamHandlers creates the ping service for the host and installs
// its handler for the "/ipfs/ping/1.0.0" protocol.
func (l *Libp2p) registerStreamHandlers() {
	svc := ping.NewPingService(l.Host)
	l.pingService = svc
	l.Host.SetStreamHandler(protocol.ID("/ipfs/ping/1.0.0"), svc.PingHandler)
}
// sign signs data with the host's private key taken from the peerstore.
// It fails when the peerstore holds no private key for the host or when
// signing itself fails.
func (l *Libp2p) sign(data []byte) ([]byte, error) {
	hostKey := l.Host.Peerstore().PrivKey(l.Host.ID())
	if hostKey == nil {
		return nil, errors.New("private key not found for the host")
	}

	sig, signErr := hostKey.Sign(data)
	if signErr != nil {
		return nil, fmt.Errorf("failed to sign data: %w", signErr)
	}
	return sig, nil
}
// getPublicKey returns the raw bytes of the public key that belongs to the
// host's private key in the peerstore. It fails when no private key is stored
// for the host.
func (l *Libp2p) getPublicKey() ([]byte, error) {
	hostKey := l.Host.Peerstore().PrivKey(l.Host.ID())
	if hostKey == nil {
		return nil, errors.New("private key not found for the host")
	}
	return hostKey.GetPublic().Raw()
}
// getCustomNamespace builds the DHT record key for an advertisement in the
// form "<CustomNamespace>-<key>-<peerID>".
func (l *Libp2p) getCustomNamespace(key, peerID string) string {
	namespace := l.config.CustomNamespace
	return fmt.Sprintf("%s-%s-%s", namespace, key, peerID)
}
// createCIDFromKey derives a CIDv1 (raw codec) from the SHA-256 digest of key.
// An empty Cid and the error are returned when multihash encoding fails.
func createCIDFromKey(key string) (cid.Cid, error) {
	digest := sha256.Sum256([]byte(key))

	encoded, err := multihash.Encode(digest[:], multihash.SHA2_256)
	if err == nil {
		return cid.NewCidV1(cid.Raw, encoded), nil
	}
	return cid.Cid{}, err
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package libp2p
import (
"sync"
"github.com/libp2p/go-libp2p/core/peer"
)
// SubnetRoutingTable maps peers to address strings and back.
// The implementation in this file is rtable, created with NewRoutingTable.
type SubnetRoutingTable interface {
// Add records addr for peerID.
Add(peerID peer.ID, addr string)
// Remove deletes the entry for peerID, if any.
Remove(peerID peer.ID)
// Get returns the address recorded for peerID and whether one exists.
Get(peerID peer.ID) (string, bool)
// RemoveByIP deletes the entry recorded for addr, if any.
RemoveByIP(addr string)
// GetByIP returns the peer recorded for addr and whether one exists.
GetByIP(addr string) (peer.ID, bool)
// All returns the peer-to-address mapping.
All() map[peer.ID]string
// Clear removes every entry.
Clear()
}
// rtable is the SubnetRoutingTable implementation returned by NewRoutingTable.
// It keeps a forward and a reverse index, both accessed under mx.
type rtable struct {
// mx guards idx and revIdx.
mx sync.RWMutex
// idx maps a peer to its address.
idx map[peer.ID]string
// revIdx maps an address back to its peer.
revIdx map[string]peer.ID
}
// NewRoutingTable returns an empty SubnetRoutingTable backed by rtable.
func NewRoutingTable() SubnetRoutingTable {
	table := new(rtable)
	table.idx = make(map[peer.ID]string)
	table.revIdx = make(map[string]peer.ID)
	return table
}
// Add records addr for peerID, replacing any previous mapping that involves
// either of them so the forward and reverse indexes stay consistent:
// the peer's old address (if different) no longer resolves to it, and a
// different peer that previously owned addr loses its entry.
func (rt *rtable) Add(peerID peer.ID, addr string) {
	rt.mx.Lock()
	defer rt.mx.Unlock()

	// Drop the stale reverse entry when the peer moves to a new address.
	if oldAddr, ok := rt.idx[peerID]; ok && oldAddr != addr {
		delete(rt.revIdx, oldAddr)
	}
	// Drop the stale forward entry when the address changes owner.
	if oldPeer, ok := rt.revIdx[addr]; ok && oldPeer != peerID {
		delete(rt.idx, oldPeer)
	}

	rt.idx[peerID] = addr
	rt.revIdx[addr] = peerID
}
// Remove deletes the entry for peerID from both indexes. Unknown peers are
// ignored.
func (rt *rtable) Remove(peerID peer.ID) {
	rt.mx.Lock()
	defer rt.mx.Unlock()

	if addr, found := rt.idx[peerID]; found {
		delete(rt.idx, peerID)
		delete(rt.revIdx, addr)
	}
}
// Get looks up the address recorded for peerID. The boolean reports whether
// the peer is present in the table.
func (rt *rtable) Get(peerID peer.ID) (string, bool) {
	rt.mx.RLock()
	defer rt.mx.RUnlock()

	address, found := rt.idx[peerID]
	return address, found
}
// RemoveByIP deletes the entry for addr and the forward entry of the peer
// recorded for it. Unknown addresses are ignored.
func (rt *rtable) RemoveByIP(addr string) {
	rt.mx.Lock()
	defer rt.mx.Unlock()

	if owner, known := rt.revIdx[addr]; known {
		delete(rt.idx, owner)
		delete(rt.revIdx, addr)
	}
}
// GetByIP looks up the peer recorded for addr. The boolean reports whether
// the address is present in the table.
func (rt *rtable) GetByIP(addr string) (peer.ID, bool) {
	rt.mx.RLock()
	defer rt.mx.RUnlock()

	owner, found := rt.revIdx[addr]
	return owner, found
}
// All returns a copy of the peer-to-address index, so callers may modify the
// result without affecting the table.
func (rt *rtable) All() map[peer.ID]string {
	rt.mx.RLock()
	defer rt.mx.RUnlock()

	snapshot := make(map[peer.ID]string)
	for id := range rt.idx {
		snapshot[id] = rt.idx[id]
	}
	return snapshot
}
// Clear empties the table by replacing both indexes with fresh maps.
func (rt *rtable) Clear() {
	rt.mx.Lock()
	defer rt.mx.Unlock()

	rt.idx, rt.revIdx = make(map[peer.ID]string), make(map[string]peer.ID)
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package libp2p
import (
"context"
"encoding/binary"
"errors"
"fmt"
"io/fs"
"math/rand"
"net"
"sync"
"sync/atomic"
"time"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
network "github.com/libp2p/go-libp2p/core/network"
peer "github.com/libp2p/go-libp2p/core/peer"
"gitlab.com/nunet/device-management-service/lib/sys"
"gitlab.com/nunet/device-management-service/types"
)
const (
// IfaceMTU is the MTU set on every TUN interface created by AddSubnetPeer.
IfaceMTU = 1420
// PacketExchangeProtocolID identifies the stream protocol over which
// subnet peers exchange IP packets.
PacketExchangeProtocolID = "/dms/subnet/packet-exchange/0.0.1"
)
// subnet holds the per-subnet state of the overlay network: the routing
// table, the local TUN interfaces, the outgoing streams to remote peers, the
// local DNS records and the port mappings created for it.
type subnet struct {
// ctx is the context given to CreateSubnet; interface contexts are
// derived from it and the stream reader stops when it is done.
ctx context.Context
// network is the Libp2p instance that owns this subnet.
network *Libp2p
info struct {
// id is the subnet identifier.
id string
// rtable maps subnet peers to their addresses and back.
rtable SubnetRoutingTable
}
// mx guards ifaces.
mx sync.Mutex
// ifaces holds the local TUN interfaces keyed by their IP address,
// together with the context and cancel function of their read loop.
ifaces map[string]struct {
tun *sys.NetInterface
ctx context.Context
cancel context.CancelFunc
}
io struct {
// mx guards streams.
mx sync.RWMutex
// streams holds the open outgoing stream per destination peer,
// keyed by the peer ID string; each entry has its own write lock.
streams map[string]*struct {
mx sync.Mutex
stream network.Stream
}
}
// dnsmx guards dnsRecords.
dnsmx sync.RWMutex
// dnsRecords maps a name to the IP it resolves to.
dnsRecords map[string]string
// portMapping records the active port mappings keyed by source port.
portMapping map[string]*struct {
destPort string
destIP string
srcIP string
}
}
// CreateSubnet creates a new subnet named subnetID and seeds its routing
// table from routingTable, which maps an IP address to the string form of the
// peer ID that owns it.
//
// The first successful call also registers the packet-exchange stream
// handler, which is shared by all subnets of this host. It returns an error
// when the subnet already exists, when a peer ID cannot be decoded or when
// the handler cannot be registered.
func (l *Libp2p) CreateSubnet(ctx context.Context, subnetID string, routingTable map[string]string) error {
	if _, ok := l.subnets[subnetID]; ok {
		return fmt.Errorf("subnet with ID %s already exists", subnetID)
	}

	s := newSubnet(ctx, l)
	s.info.id = subnetID

	for ip, peerctx := range routingTable {
		peerID, err := peer.Decode(peerctx)
		if err != nil {
			return fmt.Errorf("failed to decode peer ID %s: %w", peerctx, err)
		}
		s.info.rtable.Add(peerID, ip)
	}

	if atomic.CompareAndSwapInt32(&l.isSubnetWriteProtocolRegistered, 0, 1) {
		err := s.network.RegisterStreamMessageHandler(
			types.MessageType(PacketExchangeProtocolID),
			func(stream network.Stream) {
				l.writePackets(stream)
			})
		if err != nil {
			// Roll the flag back, otherwise every later call would skip
			// registration and no subnet could ever receive packets.
			atomic.StoreInt32(&l.isSubnetWriteProtocolRegistered, 0)
			return err
		}
	}

	l.subnets[subnetID] = s
	return nil
}
// DestroySubnet tears down the subnet named subnetID: it stops and deletes
// its TUN interfaces, resets its outgoing streams, clears its DNS records and
// routing table, removes its port mappings and finally forgets the subnet.
//
// The packet-exchange stream handler is shared by all subnets, so it is only
// unregistered once the last subnet is gone; the registration flag is reset
// at the same time so that a later CreateSubnet registers it again.
// Teardown is best effort: errors from individual steps are ignored.
func (l *Libp2p) DestroySubnet(subnetID string) error {
	s, ok := l.subnets[subnetID]
	if !ok {
		return fmt.Errorf("subnet with ID %s does not exist", subnetID)
	}

	// ifaces is guarded by s.mx everywhere else, so hold it while walking
	// and replacing the map.
	s.mx.Lock()
	for _, iface := range s.ifaces {
		iface.cancel()
		_ = iface.tun.Down()
		_ = iface.tun.Delete()
	}
	s.ifaces = make(map[string]struct {
		tun    *sys.NetInterface
		ctx    context.Context
		cancel context.CancelFunc
	})
	s.mx.Unlock()

	s.io.mx.Lock()
	for _, ms := range s.io.streams {
		ms.mx.Lock()
		_ = ms.stream.Reset()
		ms.mx.Unlock()
	}
	s.io.streams = make(map[string]*struct {
		mx     sync.Mutex
		stream network.Stream
	})
	s.io.mx.Unlock()

	s.dnsmx.Lock()
	s.dnsRecords = make(map[string]string)
	s.dnsmx.Unlock()

	s.info.rtable.Clear()

	// UnmapPort looks the subnet up by ID, so this must run before the
	// subnet is removed from l.subnets.
	for sourcePort, mapping := range s.portMapping {
		_ = l.UnmapPort(subnetID, "tcp", mapping.srcIP, sourcePort, mapping.destIP, mapping.destPort)
	}

	delete(l.subnets, subnetID)

	if len(l.subnets) == 0 {
		l.UnregisterMessageHandler(PacketExchangeProtocolID)
		atomic.StoreInt32(&l.isSubnetWriteProtocolRegistered, 0)
	}
	return nil
}
// AddSubnetPeer adds a local peer to the subnet: it records the peer in the
// routing table, creates a TUN interface with the given IP (as a /24) and
// IfaceMTU, brings it up and starts a goroutine reading packets from it.
//
// It returns an error when the subnet does not exist, when ip or peerID is
// invalid, or when the interface cannot be created or configured. An
// interface that was created but could not be fully configured is deleted
// again so that a failed call does not leave a stray TUN device behind.
func (l *Libp2p) AddSubnetPeer(subnetID, peerID, ip string) error {
	s, ok := l.subnets[subnetID]
	if !ok {
		return fmt.Errorf("subnet with ID %s does not exist", subnetID)
	}

	ipAddr := net.ParseIP(ip)
	if ipAddr == nil {
		return fmt.Errorf("invalid IP address %s", ip)
	}

	peerIDObj, err := peer.Decode(peerID)
	if err != nil {
		return fmt.Errorf("failed to decode peer ID %s: %w", peerID, err)
	}

	s.info.rtable.Add(peerIDObj, ip)

	ifaces, err := sys.GetNetInterfaces()
	if err != nil {
		return err
	}

	takenNames := make([]string, 0, len(ifaces))
	for _, iface := range ifaces {
		takenNames = append(takenNames, iface.Name)
	}

	log.Debug("finding proper iface name for TUN interface", "taken_names", takenNames)
	name, err := generateUniqueName(takenNames)
	if err != nil {
		return fmt.Errorf("failed to generate unique name for TUN interface: %w", err)
	}

	log.Debug("Creating TUN interface", "name", name)
	address := fmt.Sprintf("%s/24", ipAddr.String())
	iface, err := sys.NewTunTapInterface(name, sys.NetTunMode, false)
	if err != nil {
		return fmt.Errorf("failed to create tun interface: %w", err)
	}

	// discard removes the half-configured interface on a setup failure.
	discard := func() { _ = iface.Delete() }

	if err := iface.SetAddress(address); err != nil {
		discard()
		return fmt.Errorf("failed to set address on tun interface: %w", err)
	}

	if err := iface.SetMTU(IfaceMTU); err != nil {
		discard()
		return fmt.Errorf("failed to set MTU on tun interface: %w", err)
	}

	if err := iface.Up(); err != nil {
		discard()
		return fmt.Errorf("failed to bring up tun interface: %w", err)
	}

	ctx, cancel := context.WithCancel(s.ctx)

	s.mx.Lock()
	s.ifaces[ipAddr.String()] = struct {
		tun    *sys.NetInterface
		ctx    context.Context
		cancel context.CancelFunc
	}{
		tun:    iface,
		ctx:    ctx,
		cancel: cancel,
	}
	s.mx.Unlock()

	go s.readPackets(ctx, iface)

	return nil
}
// RemoveSubnetPeer removes a peer from the subnet. If a local TUN interface
// exists for the peer's address, its read loop is cancelled and the interface
// is brought down and deleted; the peer is then dropped from the routing
// table.
//
// It returns an error when the subnet does not exist, when peerID cannot be
// decoded or when the peer is not part of the subnet.
func (l *Libp2p) RemoveSubnetPeer(subnetID, peerID string) error {
	s, exists := l.subnets[subnetID]
	if !exists {
		return fmt.Errorf("subnet with ID %s does not exist", subnetID)
	}

	decoded, err := peer.Decode(peerID)
	if err != nil {
		return fmt.Errorf("failed to decode peer ID %s: %w", peerID, err)
	}

	ip, known := s.info.rtable.Get(decoded)
	if !known {
		return fmt.Errorf("peer with ID %s is not in the subnet", peerID)
	}

	s.mx.Lock()
	if entry, present := s.ifaces[ip]; present {
		entry.cancel()
		// Interface teardown is best effort.
		_ = entry.tun.Down()
		_ = entry.tun.Delete()
		delete(s.ifaces, ip)
	}
	s.mx.Unlock()

	s.info.rtable.Remove(decoded)
	return nil
}
// AcceptSubnetPeer records a peer and its IP in the subnet's routing table
// without creating any local interface for it.
//
// It returns an error when the subnet does not exist or when ip or peerID is
// invalid.
func (l *Libp2p) AcceptSubnetPeer(subnetID, peerID, ip string) error {
	s, exists := l.subnets[subnetID]
	if !exists {
		return fmt.Errorf("subnet with ID %s does not exist", subnetID)
	}

	// The parsed value is only used for validation; the table stores the
	// address exactly as it was given.
	if net.ParseIP(ip) == nil {
		return fmt.Errorf("invalid IP address %s", ip)
	}

	decoded, err := peer.Decode(peerID)
	if err != nil {
		return fmt.Errorf("failed to decode peer ID %s: %w", peerID, err)
	}

	s.info.rtable.Add(decoded, ip)
	return nil
}
// AddSubnetDNSRecord adds a DNS record mapping name to ip to the local
// resolver of the given subnet, replacing any record already stored for name.
// It returns an error when the subnet does not exist.
func (l *Libp2p) AddSubnetDNSRecord(subnetID, name, ip string) error {
	s, exists := l.subnets[subnetID]
	if !exists {
		return fmt.Errorf("subnet with ID %s does not exist", subnetID)
	}

	s.dnsmx.Lock()
	defer s.dnsmx.Unlock()
	s.dnsRecords[name] = ip
	return nil
}
// RemoveSubnetDNSRecord removes the DNS record stored for name from the local
// resolver of the given subnet. Removing a name that has no record is not an
// error; a missing subnet is.
func (l *Libp2p) RemoveSubnetDNSRecord(subnetID, name string) error {
	s, exists := l.subnets[subnetID]
	if !exists {
		return fmt.Errorf("subnet with ID %s does not exist", subnetID)
	}

	s.dnsmx.Lock()
	defer s.dnsmx.Unlock()
	delete(s.dnsRecords, name)
	return nil
}
// writePackets is the stream handler for the packet-exchange protocol. It
// reads the stream preamble — a little-endian uint16 length followed by that
// many bytes of subnet ID — and hands the stream to the matching subnet.
// The stream is reset when the preamble cannot be read or the subnet is not
// known on this host.
func (l *Libp2p) writePackets(stream network.Stream) {
	// read_subnet_id
	// A single Read may return fewer bytes than requested, so keep reading
	// until the two length bytes are complete.
	header := make([]byte, 2)
	for filled := 0; filled < len(header); {
		n, err := stream.Read(header[filled:])
		filled += n
		if err != nil && filled < len(header) {
			log.Error("failed to read subnet id size from stream", err)
			_ = stream.Reset()
			return
		}
	}

	// Decode the subnet ID length from binary.
	size := int(binary.LittleEndian.Uint16(header))
	subnetID := make([]byte, size)

	// Read the subnet ID until completion. An error that arrives together
	// with the final bytes is ignored here; the next read reports it again.
	for filled := 0; filled < size; {
		n, err := stream.Read(subnetID[filled:])
		filled += n
		if err != nil && filled < size {
			log.Error("failed to read subnet id from stream", err)
			_ = stream.Reset()
			return
		}
	}

	// retrieve subnet object
	target, ok := l.subnets[string(subnetID)]
	if !ok {
		log.Errorf("unrecognized subnet id %s, subnet does not exist on this host", string(subnetID))
		_ = stream.Reset()
		return
	}

	target.writePackets(stream)
}
// newSubnet returns a subnet bound to ctx and l with an empty routing table
// and all of its maps initialised. The subnet ID is left empty for the caller
// to fill in.
func newSubnet(ctx context.Context, l *Libp2p) *subnet {
	s := &subnet{ctx: ctx, network: l}

	s.info.rtable = NewRoutingTable()
	s.ifaces = make(map[string]struct {
		tun    *sys.NetInterface
		ctx    context.Context
		cancel context.CancelFunc
	})
	s.io.streams = make(map[string]*struct {
		mx     sync.Mutex
		stream network.Stream
	})
	s.dnsRecords = make(map[string]string)
	s.portMapping = make(map[string]*struct {
		destPort string
		destIP   string
		srcIP    string
	})

	return s
}
// readPackets reads IP packets from the given TUN interface until ctx is done
// or the device is closed. Every packet is either routed towards its
// destination or answered by the local DNS handler.
//
// Cancellation is only checked between reads, so a read that is already in
// progress has to return before the loop can stop.
func (s *subnet) readPackets(ctx context.Context, iface *sys.NetInterface) {
	for {
		select {
		case <-ctx.Done():
			log.Debug("context done, abandoning read loop...", "subnet", s.info.id)
			return
		default:
		}

		// A fresh buffer is allocated for every packet because Route may hand
		// the slice to a goroutine that outlives this iteration. It is sized
		// to the interface MTU.
		packet := make([]byte, IfaceMTU)

		// Read in a packet from the tun device.
		plen, err := iface.Iface.Read(packet)
		if errors.Is(err, fs.ErrClosed) {
			time.Sleep(1 * time.Second)
			log.Debug("tun device closed, abandoning read loop...", err, "subnet", s.info.id)
			return
		}
		if err != nil {
			log.Error("failed to read packet from tun device ", err, "subnet", s.info.id)
			continue
		}
		if plen == 0 {
			continue
		}

		srcPort, destPort, srcIP, destIP, err := s.parseIPPacket(packet)
		if err != nil {
			log.Error("failed to parse IP packet", err)
			continue
		}

		log.Debug(
			"read packet from tun device",
			"tun", iface.Iface.Name(),
			"subnet", s.info.id,
			"destIP", destIP,
			"destPort", destPort,
			"srcIP", srcIP,
			"srcPort", srcPort,
		)

		// NOTE(review): a packet is routed only when BOTH the destination IP
		// and the destination port differ from the resolver's, so anything
		// sent to 10.0.0.1 on any port, and anything sent to UDP port 53 on
		// any IP, ends up in the DNS handler. Confirm whether "||" was meant.
		if destIP != "10.0.0.1" && destPort != 53 {
			s.Route(destIP, packet, plen)
			continue
		}

		log.Debug(
			"handling DNS query",
			"subnet", s.info.id,
			"destIP", destIP,
			"destPort", destPort,
			"srcIP", srcIP,
			"srcPort", srcPort,
		)
		if err := s.handleDNSQueries(iface, packet, plen); err != nil {
			log.Error("failed to handle DNS query", err)
		}
	}
}
// handleDNSQueries answers a DNS query found in packet using the subnet's
// local records and writes the reply back to the TUN interface it came from.
//
// packet holds an IPv4/UDP datagram of packetlen bytes. The reply swaps the
// source and destination addresses and ports of the query. An error is
// returned when the packet is too short to hold a query, when the query
// cannot be handled or parsed, or when the reply cannot be built or written.
func (s *subnet) handleDNSQueries(iface *sys.NetInterface, packet []byte, packetlen int) error {
	// The DNS payload is taken to start after a 20-byte IPv4 header (no
	// options assumed) and the 8-byte UDP header. Guard the slice below so
	// a short packet is rejected instead of panicking.
	const dnsPayloadOffset = 28
	if packetlen < dnsPayloadOffset || packetlen > len(packet) {
		return fmt.Errorf("packet of %d bytes cannot hold a DNS query", packetlen)
	}

	s.dnsmx.RLock()
	payload, err := handleDNSQuery(packet[dnsPayloadOffset:packetlen], s.dnsRecords)
	s.dnsmx.RUnlock()
	if err != nil {
		return err
	}

	srcPort, destPort, srcIP, destIP, err := s.parseIPPacket(packet)
	if err != nil {
		return err
	}

	// The reply travels in the opposite direction of the query.
	ipLayer := &layers.IPv4{
		Version:  4,
		TTL:      64,
		SrcIP:    net.ParseIP(destIP),
		DstIP:    net.ParseIP(srcIP),
		Protocol: layers.IPProtocolUDP,
	}

	udpLayer := &layers.UDP{
		SrcPort: layers.UDPPort(destPort),
		DstPort: layers.UDPPort(srcPort),
	}

	// The UDP checksum covers a pseudo-header taken from the IP layer.
	if err := udpLayer.SetNetworkLayerForChecksum(ipLayer); err != nil {
		return err
	}

	// Create the packet
	buffer := gopacket.NewSerializeBuffer()
	opts := gopacket.SerializeOptions{
		ComputeChecksums: true,
		FixLengths:       true,
	}

	err = gopacket.SerializeLayers(buffer, opts,
		ipLayer,
		udpLayer,
		gopacket.Payload(payload),
	)
	if err != nil {
		return err
	}

	if _, err := iface.Iface.Write(buffer.Bytes()); err != nil {
		return fmt.Errorf("failed to write DNS response to tun device: %w", err)
	}
	return nil
}
// Route delivers the first plen bytes of packet to destIP.
//
// If destIP belongs to one of this host's TUN interfaces the packet is
// written straight to that interface. Otherwise the routing table is
// consulted and the packet is forwarded to the owning peer over a stream, in
// a separate goroutine. Packets for unknown destinations are dropped.
func (s *subnet) Route(destIP string, packet []byte, plen int) {
	log.Debug("routing packet", "subnet", s.info.id, "dstIP", destIP)

	s.mx.Lock()
	defer s.mx.Unlock()

	// check if present in our tuns table first
	if local, ok := s.ifaces[destIP]; ok {
		log.Debug("found destination ip in tuns table", "subnet", s.info.id, "dstIP", destIP)
		// if so, write to the tun; delivery is best effort
		_, _ = local.tun.Iface.Write(packet[:plen])
		return
	}

	// otherwise check if present in our routing table
	peerID, ok := s.info.rtable.GetByIP(destIP)
	if !ok {
		log.Debug("unrecognized destination ip", "subnet", s.info.id, "dstIP", destIP)
		return
	}

	// Logged with Debug like the other key/value messages: the message has no
	// format verbs, so Debugf would render the pairs as "%!(EXTRA ...)".
	log.Debug("found destination ip in routing table", "subnet", s.info.id, "dstIP", destIP, "peerID", peerID.String())
	go s.redirectPacketToStream(s.ctx, peerID, packet, plen)
}
// redirectPacketToStream sends the first plen bytes of packet to the peer dst
// over a packet-exchange stream.
//
// An already open stream to dst is reused when possible. If there is none, or
// writing to it fails, the peer's address is resolved, a new stream is opened
// and — once the packet has been written successfully — the stream is kept
// for later packets. Every message consists of a little-endian uint16 length
// followed by the bytes it announces.
//
// The streams lock is held for the whole call, including address resolution
// and stream setup.
func (s *subnet) redirectPacketToStream(ctx context.Context, dst peer.ID, packet []byte, plen int) {
	s.io.mx.Lock()
	defer s.io.mx.Unlock()

	// Check if we already have an open connection to the destination peer.
	ms, ok := s.io.streams[dst.String()]
	if ok {
		log.Debug("found existing stream to destination peer", "subnet", s.info.id, "dst", dst.String())
		if func() bool {
			ms.mx.Lock()
			defer ms.mx.Unlock()
			_ = ms.stream.SetWriteDeadline(time.Now().Add(time.Second))

			// NOTE(review): the subnet ID preamble is sent again for every
			// packet on a reused stream, while the receiving loop in
			// (*subnet).writePackets only expects length-prefixed packets
			// after the first preamble. Confirm the intended framing.
			err := binary.Write(ms.stream, binary.LittleEndian, uint16(len(s.info.id)))
			if err == nil {
				_, _ = (ms.stream).Write([]byte(s.info.id))
			} else {
				// If we encounter an error when writing to a stream we should
				// close that stream and delete it from the active stream map.
				_ = ms.stream.Reset()
				delete(s.io.streams, dst.String())
				return false
			}

			// Write out the packet's length to the libp2p stream to ensure
			// we know the full size of the packet at the other end.
			err = binary.Write(ms.stream, binary.LittleEndian, uint16(plen))
			if err == nil {
				// Write the packet out to the libp2p stream.
				_, err = (ms.stream).Write(packet[:plen])
				if err == nil {
					return true
				}
			}

			// If we encounter an error when writing to a stream we should
			// close that stream and delete it from the active stream map.
			ms.stream.Close()
			delete(s.io.streams, dst.String())
			return false
		}() {
			return
		}
	}

	log.Debug("no existing stream to destination peer", "subnet", s.info.id, "dst", dst.String())

	addrs, err := s.network.ResolveAddress(ctx, dst.String())
	if err != nil {
		log.Error("failed to resolve peer address", err, "subnet", s.info.id, "dst", dst.String())
		return
	}
	// A successful lookup may still yield no address; indexing addrs[0]
	// would then panic.
	if len(addrs) == 0 {
		log.Error("no address resolved for peer", "subnet", s.info.id, "dst", dst.String())
		return
	}

	protocolID := types.MessageType(PacketExchangeProtocolID)
	stream, err := s.network.OpenStream(ctx, addrs[0], protocolID)
	if err != nil {
		log.Error("failed to open stream", err, "subnet", s.info.id, "dst", dst.String())
		return
	}

	_ = stream.SetWriteDeadline(time.Now().Add(time.Second))

	// Write the subnet ID length
	err = binary.Write(stream, binary.LittleEndian, uint16(len([]byte(s.info.id))))
	if err != nil {
		log.Error("failed to write subnet id length", err, "subnet", s.info.id, "dst", dst.String())
		stream.Close()
		return
	}

	// Write the subnet ID
	_, err = stream.Write([]byte(s.info.id))
	if err != nil {
		log.Error("failed to write subnet id", err, "subnet", s.info.id, "dst", dst.String())
		stream.Close()
		return
	}

	// Write packet length
	err = binary.Write(stream, binary.LittleEndian, uint16(plen))
	if err != nil {
		log.Error("failed to write packet length", err, "subnet", s.info.id, "dst", dst.String())
		stream.Close()
		return
	}

	// Write the packet
	_, err = stream.Write(packet[:plen])
	if err != nil {
		log.Error("failed to write packet", err, "subnet", s.info.id, "dst", dst.String())
		stream.Close()
		return
	}

	// If all succeeds when writing the packet to the stream
	// we should reuse this stream by adding it to the active streams map.
	s.io.streams[dst.String()] = &struct {
		mx     sync.Mutex
		stream network.Stream
	}{
		mx:     sync.Mutex{},
		stream: stream,
	}
}
// writePackets receives length-prefixed IP packets from stream and writes
// each one to the local TUN interface that owns the packet's destination
// address. Packets for which no local interface exists are dropped.
//
// The stream is rejected when the remote peer is not in the routing table.
// The loop ends when the subnet context is done, when reading fails or when
// the peer announces a packet larger than the receive buffer.
func (s *subnet) writePackets(stream network.Stream) {
	defer stream.Close()

	src := stream.Conn().RemotePeer()
	if _, ok := s.info.rtable.Get(src); !ok {
		log.Debug("unrecognized source peer", "subnet", s.info.id, "src", src.String())
		_ = stream.Reset()
		return
	}

	// Bytes 16..19 of an IPv4 header hold the destination address, so at
	// least 20 bytes are needed before the destination can be read.
	const ipv4HeaderLen = 20

	packet := make([]byte, IfaceMTU)
	header := make([]byte, 2)
	for {
		select {
		case <-s.ctx.Done():
			log.Debug("context done", "subnet", s.info.id)
			_ = stream.Reset()
			return
		default:
		}

		// read_packet
		// Read the incoming packet's size as a binary value. A single Read
		// may deliver only part of it, so loop until both bytes are in.
		for filled := 0; filled < len(header); {
			n, err := stream.Read(header[filled:])
			filled += n
			if err != nil && filled < len(header) {
				log.Error("failed to read packet size from stream", err, "subnet", s.info.id)
				_ = stream.Reset()
				return
			}
		}

		// Decode the incoming packet's size from binary.
		size := int(binary.LittleEndian.Uint16(header))

		// The size comes from the remote peer; never let it index past the
		// end of the receive buffer.
		if size > len(packet) {
			log.Error("announced packet size exceeds receive buffer", "size", size, "subnet", s.info.id)
			_ = stream.Reset()
			return
		}

		// Read in the packet until completion.
		plen := 0
		for plen < size {
			n, err := stream.Read(packet[plen:size])
			plen += n
			if err != nil && plen < size {
				log.Error("failed to read packet from stream", err, "subnet", s.info.id)
				_ = stream.Reset()
				return
			}
		}

		_ = stream.SetWriteDeadline(time.Now().Add(time.Second))
		log.Debug("read packet from stream", "subnet", s.info.id, "src", src.String())

		// The buffer is reused between packets, so a short packet must not
		// be routed using header bytes left over from the previous one.
		if plen < ipv4HeaderLen {
			log.Debug("dropping packet shorter than an IPv4 header", "subnet", s.info.id, "size", plen)
			continue
		}

		// write_packet
		destIP := net.IPv4(packet[16], packet[17], packet[18], packet[19]).String()

		// retrieve proper tun and write to it
		// if no tun is found, drop the packet
		s.mx.Lock()
		if iface, ok := s.ifaces[destIP]; ok {
			log.Debug("writing packet to tun device", "tun", iface.tun.Iface.Name(), "subnet", s.info.id, "dstIP", destIP)
			_, _ = iface.tun.Iface.Write(packet[:plen])
		} else {
			// drop the packet
			log.Debug("unrecognized destination ip, no tun device found for ip", "subnet", s.info.id, "dstIP", destIP)
		}
		s.mx.Unlock()
	}
}
// parseIPPacket decodes rawPacket as an IPv4 packet and extracts its source
// and destination addresses and, when it carries UDP, its source and
// destination ports. Ports are reported as 0 for packets without a UDP
// layer. An error is returned when the packet cannot be decoded.
func (s *subnet) parseIPPacket(rawPacket []byte) (srcPort int, destPort int, srcIP string, destIP string, err error) {
	// Create a packet object from the raw data
	decoded := gopacket.NewPacket(rawPacket, layers.LayerTypeIPv4, gopacket.Default)
	if failure := decoded.ErrorLayer(); failure != nil {
		return 0, 0, "", "", fmt.Errorf("failed to decode packet: %s", failure)
	}

	// Addresses come from the IP layer
	if l3 := decoded.Layer(layers.LayerTypeIPv4); l3 != nil {
		hdr, _ := l3.(*layers.IPv4)
		srcIP, destIP = hdr.SrcIP.String(), hdr.DstIP.String()
	}

	// Ports come from the UDP layer
	if l4 := decoded.Layer(layers.LayerTypeUDP); l4 != nil {
		hdr, _ := l4.(*layers.UDP)
		srcPort, destPort = int(hdr.SrcPort), int(hdr.DstPort)
	}

	return srcPort, destPort, srcIP, destIP, nil
}
// stringSliceContains reports whether word is one of the elements of slice.
// The comparison is exact; a nil or empty slice contains nothing.
func stringSliceContains(slice []string, word string) bool {
	for i := 0; i < len(slice); i++ {
		if slice[i] == word {
			return true
		}
	}
	return false
}
// generateUniqueName picks an interface name of the form "dms<N>" that does
// not appear in takenList.
//
// Candidates are drawn at random, starting with N in [0, 30). After 30 failed
// attempts the range grows by 30 on every further failure, and the search
// gives up with an error after more than 100 failed attempts.
func generateUniqueName(takenList []string) (string, error) {
	span := 30
	for failed := 1; ; failed++ {
		name := fmt.Sprintf("dms%d", rand.Intn(span)) //nolint:gosec
		if !stringSliceContains(takenList, name) {
			return name, nil
		}

		// failed now counts the unsuccessful attempts made so far.
		if failed > 30 {
			span += 30
		}
		if failed > 100 {
			return "", fmt.Errorf("failed to generate unique name")
		}
	}
}
//go:build linux
// +build linux
package libp2p
import (
"fmt"
"gitlab.com/nunet/device-management-service/lib/sys"
)
// MapPort forwards traffic arriving for sourceIP:sourcePort to
// destIP:destPort for the given protocol and records the mapping on the
// subnet so that it can be removed again later.
//
// It installs a DNAT rule, a forward rule and a masquerade rule. If a later
// rule cannot be installed, the DNAT and forward rules installed by this call
// are removed again (best effort) before the error is returned. It also
// fails when the subnet does not exist or sourcePort is already mapped.
func (l *Libp2p) MapPort(subnetID, protocol, sourceIP, sourcePort, destIP, destPort string) error {
	s, ok := l.subnets[subnetID]
	if !ok {
		return fmt.Errorf("subnet with ID %s does not exist", subnetID)
	}

	if _, ok := s.portMapping[sourcePort]; ok {
		return fmt.Errorf("port %s is already mapped", sourcePort)
	}

	if err := sys.AddDNATRule(protocol, sourceIP, sourcePort, destIP, destPort); err != nil {
		return err
	}

	// The forward rule must use the same protocol as the DNAT rule,
	// otherwise non-TCP mappings are rewritten but never let through.
	if err := sys.AddForwardRule(protocol, destIP, destPort); err != nil {
		_ = sys.DelDNATRule(protocol, sourceIP, sourcePort, destIP, destPort)
		return err
	}

	if err := sys.AddMasqueradeRule(); err != nil {
		_ = sys.DelForwardRule(protocol, destIP, destPort)
		_ = sys.DelDNATRule(protocol, sourceIP, sourcePort, destIP, destPort)
		return err
	}

	s.portMapping[sourcePort] = &struct {
		destPort string
		destIP   string
		srcIP    string
	}{
		destPort: destPort,
		destIP:   destIP,
		srcIP:    sourceIP,
	}

	return nil
}

// UnmapPort removes a mapping previously created with MapPort.
//
// The arguments must describe the recorded mapping exactly; otherwise an
// error is returned and nothing is changed. The DNAT, forward and masquerade
// rules are removed in that order and the first failure aborts the call,
// leaving the mapping recorded.
func (l *Libp2p) UnmapPort(subnetID, protocol, sourceIP, sourcePort, destIP, destPort string) error {
	s, ok := l.subnets[subnetID]
	if !ok {
		return fmt.Errorf("subnet with ID %s does not exist", subnetID)
	}

	mapping, ok := s.portMapping[sourcePort]
	if !ok {
		return fmt.Errorf("port %s is not mapped", sourcePort)
	}

	if mapping.destIP != destIP || mapping.destPort != destPort || mapping.srcIP != sourceIP {
		return fmt.Errorf("port %s is not mapped to %s:%s", sourcePort, destIP, destPort)
	}

	if err := sys.DelDNATRule(protocol, sourceIP, sourcePort, destIP, destPort); err != nil {
		return err
	}

	// Mirror MapPort: the forward rule was added for the same protocol.
	if err := sys.DelForwardRule(protocol, destIP, destPort); err != nil {
		return err
	}

	if err := sys.DelMasqueradeRule(); err != nil {
		return err
	}

	delete(s.portMapping, sourcePort)

	return nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package network
import (
"context"
"errors"
"fmt"
"time"
"github.com/spf13/afero"
commonproto "gitlab.com/nunet/device-management-service/proto/generated/v1/common"
"gitlab.com/nunet/device-management-service/network/libp2p"
"gitlab.com/nunet/device-management-service/types"
)
// Aliases for the libp2p types that appear in the Network interface, so that
// users of this package can name them without importing the libp2p package.
type (
// PeerID is an alias for libp2p.PeerID.
PeerID = libp2p.PeerID
// ProtocolID is an alias for libp2p.ProtocolID.
ProtocolID = libp2p.ProtocolID
// Topic is an alias for libp2p.Topic.
Topic = libp2p.Topic
// Validator is an alias for libp2p.Validator.
Validator = libp2p.Validator
// ValidationResult is an alias for libp2p.ValidationResult.
ValidationResult = libp2p.ValidationResult
// PeerScoreSnapshot is an alias for libp2p.PeerScoreSnapshot.
PeerScoreSnapshot = libp2p.PeerScoreSnapshot
)
// Validation results re-exported from the libp2p package.
const (
// ValidationAccept mirrors libp2p.ValidationAccept.
ValidationAccept = libp2p.ValidationAccept
// ValidationReject mirrors libp2p.ValidationReject.
ValidationReject = libp2p.ValidationReject
// ValidationIgnore mirrors libp2p.ValidationIgnore.
ValidationIgnore = libp2p.ValidationIgnore
)
// Messenger defines the interface for sending messages.
type Messenger interface {
// SendMessage asynchronously sends a message to the given peer.
// NOTE(review): expiry is presumably the time after which the message
// should no longer be delivered — confirm against the implementation.
SendMessage(ctx context.Context, hostID string, msg types.MessageEnvelope, expiry time.Time) error
// SendMessageSync synchronously sends a message to the given peer.
// This method blocks until the message has been sent.
SendMessageSync(ctx context.Context, hostID string, msg types.MessageEnvelope, expiry time.Time) error
}
// Network is the abstraction over the underlying peer-to-peer transport. It
// covers messaging, discovery, publish/subscribe, connection notifications
// and subnet management.
type Network interface {
// Messenger embedded interface
Messenger
// Init initializes the network
Init() error
// Start starts the network
Start() error
// Stat returns the network information
Stat() types.NetworkStats
// Ping pings the given address and returns the PingResult
Ping(ctx context.Context, address string, timeout time.Duration) (types.PingResult, error)
// HandleMessage is responsible for registering a message type and its handler.
HandleMessage(messageType string, handler func(data []byte)) error
// UnregisterMessageHandler unregisters a stream handler for a specific protocol.
UnregisterMessageHandler(messageType string)
// ResolveAddress returns the addresses of the peer with the given id.
// In libp2p, id represents the peerID and the response is the addrinfo
ResolveAddress(ctx context.Context, id string) ([]string, error)
// Advertise advertises the given data under the given key,
// such as advertising device capabilities on the DHT
Advertise(ctx context.Context, key string, data []byte) error
// Unadvertise stops advertising the data corresponding to the given key
Unadvertise(ctx context.Context, key string) error
// Query returns the network advertisements for the given key
Query(ctx context.Context, key string) ([]*commonproto.Advertisement, error)
// Publish publishes the given data to the given topic if the network
// type allows publish/subscribe functionality such as gossipsub or nats
Publish(ctx context.Context, topic string, data []byte) error
// Subscribe subscribes to the given topic and calls the handler function
// if the network type allows it similar to Publish()
Subscribe(ctx context.Context, topic string, handler func(data []byte), validator libp2p.Validator) (uint64, error)
// Unsubscribe removes the subscription subID from the given topic
Unsubscribe(topic string, subID uint64) error
// SetupBroadcastTopic allows the application to configure a pubsub topic directly
SetupBroadcastTopic(topic string, setup func(*Topic) error) error
// SetBroadcastAppScore allows the application to configure application level
// scoring for pubsub
SetBroadcastAppScore(func(PeerID) float64)
// GetBroadcastScore returns the latest broadcast score snapshot
GetBroadcastScore() map[PeerID]*PeerScoreSnapshot
// Notify allows the application to receive notifications about peer connections
// and disconnections
Notify(ctx context.Context, preconnected func(PeerID, []ProtocolID, int), connected, disconnected func(PeerID), identified, updated func(PeerID, []ProtocolID)) error
// PeerConnected returns true if the peer is currently connected
PeerConnected(p PeerID) bool
// Stop stops the network including any existing advertisements and subscriptions
Stop() error
// GetPeerIP returns the ipv4 or v6 of a peer
GetPeerIP(p PeerID) string
// CreateSubnet creates a subnet with the given subnetID and initial
// routing table
CreateSubnet(ctx context.Context, subnetID string, routingTable map[string]string) error
// DestroySubnet removes a subnet
DestroySubnet(subnetID string) error
// AddSubnetPeer adds a peer to the subnet
AddSubnetPeer(subnetID, peerID, ip string) error
// RemoveSubnetPeer removes a peer from the subnet
RemoveSubnetPeer(subnetID, peerID string) error
// AcceptSubnetPeer accepts a peer to the subnet
AcceptSubnetPeer(subnetID, peerID, ip string) error
// MapPort maps a sourceIP:sourcePort to destIP:destPort
MapPort(subnetID, protocol, sourceIP, sourcePort, destIP, destPort string) error
// UnmapPort removes a previous port map
UnmapPort(subnetID, protocol, sourceIP, sourcePort, destIP, destPort string) error
// AddSubnetDNSRecord adds a dns record to the subnet's local resolver
AddSubnetDNSRecord(subnetID, name, ip string) error
// RemoveSubnetDNSRecord removes a dns record from the subnet's local resolver
RemoveSubnetDNSRecord(subnetID, name string) error
}
// NewNetwork returns a new network for the given configuration.
//
// Only the libp2p network type is implemented. A nil configuration, the NATS
// type and unknown types yield an error. On error the returned Network is
// always a plain nil, never an interface wrapping a nil pointer, so callers
// may safely compare it with nil.
func NewNetwork(netConfig *types.NetworkConfig, fs afero.Fs) (Network, error) {
	// TODO: probable additional params to receive: DB, FileSystem
	if netConfig == nil {
		return nil, errors.New("network configuration is nil")
	}

	switch netConfig.Type {
	case types.Libp2pNetwork:
		ln, err := libp2p.New(&netConfig.Libp2pConfig, fs)
		if err != nil {
			return nil, err
		}
		return ln, nil
	case types.NATSNetwork:
		return nil, errors.New("not implemented")
	default:
		return nil, fmt.Errorf("unsupported network type: %s", netConfig.Type)
	}
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package utils
import (
"fmt"
"net"
"strconv"
"strings"
"golang.org/x/exp/rand"
)
// GetNextIP returns the first address of the CIDR range that is not marked as
// used in usedIPs, whose keys are the string forms of the addresses.
//
// The search starts at the network address plus one and advances with nextIP.
// An error is returned when cidr is malformed or when no free address is
// found.
func GetNextIP(cidr string, usedIPs map[string]bool) (net.IP, error) {
	parts := strings.Split(cidr, "/")
	if len(parts) != 2 {
		return nil, fmt.Errorf("invalid CIDR %s", cidr)
	}

	prefixLen, err := strconv.Atoi(parts[1])
	if err != nil {
		return nil, err
	}

	_, ipnet, err := net.ParseCIDR(cidr)
	if err != nil {
		return nil, err
	}

	base := ipnet.IP.Mask(ipnet.Mask)
	candidate := net.IP{base[0], base[1], base[2], base[3] + byte(1)}

	// nextIP signals exhaustion with nil, which Contains rejects, so the
	// loop condition also covers the end of the range.
	for ipnet.Contains(candidate) {
		if !usedIPs[candidate.String()] {
			return candidate, nil
		}
		candidate = nextIP(candidate, prefixLen)
	}

	return nil, fmt.Errorf("no available IPs in CIDR %s", cidr)
}
// nextIP returns the host address that follows ip in a network with the given
// prefix length, or nil when ip is not an IPv4 address or the network has no
// further host address.
//
// Prefix lengths 0, 8, 16 and 24 are understood: when the last octet is 254
// the next higher host octet is incremented and the address continues at
// ".1", so ".255" and ".0" are skipped. For any other prefix length only the
// last octet is incremented and the caller has to check that the result is
// still inside its network.
//
// The address is advanced in place: the returned slice shares its storage
// with ip.
func nextIP(ip net.IP, netmask int) net.IP {
	ip4 := ip.To4()
	if ip4 == nil {
		return nil
	}

	// firstHost is the index of the most significant octet that belongs to
	// the host part.
	var firstHost int
	switch netmask {
	case 0:
		firstHost = 0
	case 8:
		firstHost = 1
	case 16:
		firstHost = 2
	case 24:
		firstHost = 3
	default:
		ip4[3]++
		return ip4
	}

	if ip4[3] != 254 {
		ip4[3]++
		return ip4
	}

	// The last octet is exhausted: carry into the nearest host octet that
	// still has room and restart the lower octets. Every octet is read via
	// ip4 so that 16-byte representations of IPv4 addresses work as well.
	for i := 2; i >= firstHost; i-- {
		if ip4[i] < 255 {
			ip4[i]++
			for j := i + 1; j < 3; j++ {
				ip4[j] = 0
			}
			ip4[3] = 1
			return ip4
		}
	}

	return nil // no more IPs for this network
}
// GetRandomCIDR returns a random CIDR with the given mask
// and not in the blacklist.
// If the blacklist is empty, it will return a random CIDR with the given mask.
// This function supports mask 0, 8, 16, 24.
// If you need more elaborate masks to get more subnets (i.e: 0<mask<32)
// refactor this to use bitwise operations on the IP.
//
// The search gives up with an error after more than 2**mask rejected
// candidates (capped at 65536; 255 for mask 0), in which case "0.0.0.0/<mask>"
// is returned alongside the error. An error is also returned when the chosen
// candidate cannot be parsed, e.g. because mask is outside 0..32.
func GetRandomCIDR(mask int, blacklist []string) (string, error) {
	// In Go "2^mask" is a bitwise XOR, not a power, so the limit is
	// computed with a shift. It is capped to keep the loop short and the
	// shift well defined for any mask.
	maxAttempts := 255
	switch {
	case mask > 16:
		maxAttempts = 1 << 16
	case mask > 0:
		maxAttempts = 1 << mask
	}

	var breakCounter int
	for {
		if breakCounter > maxAttempts {
			return fmt.Sprintf("%s/%d", "0.0.0.0", mask), fmt.Errorf("could not find a CIDR after %d attempts", breakCounter)
		}

		ip := net.IP{byte(rand.Intn(256)), byte(rand.Intn(256)), byte(rand.Intn(256)), byte(rand.Intn(256))}
		candidate := fmt.Sprintf("%s/%d", ip, mask)
		if isOnBlacklist(candidate, blacklist) {
			breakCounter++
			continue
		}

		// Normalise the candidate to its network address.
		_, ipnet, err := net.ParseCIDR(candidate)
		if err != nil {
			return "", err
		}
		return ipnet.String(), nil
	}
}
// isOnBlacklist returns true if the network address of the given CIDR lies
// inside one of the blacklisted CIDRs.
//
// A CIDR that cannot be parsed is never considered blacklisted. Blacklist
// entries that cannot be parsed are skipped, so that one malformed entry does
// not hide the entries after it.
func isOnBlacklist(cidr string, blacklist []string) bool {
	// The candidate does not change between iterations; parse it once.
	_, subnet, err := net.ParseCIDR(cidr)
	if err != nil {
		return false // Ignore errors in generated CIDR for safety
	}

	for _, blacklistedCIDR := range blacklist {
		_, blacklistedSubnet, err := net.ParseCIDR(blacklistedCIDR)
		if err != nil {
			continue // Ignore malformed blacklist entries
		}
		if blacklistedSubnet.Contains(subnet.IP) {
			return true
		}
	}
	return false
}
// GetRandomCIDRInRange returns a random CIDR with the given mask that is not
// in the blacklist. Only the first mask/8 octets are randomised: each is
// drawn from the inclusive range spanned by the matching octets of start and
// end. The remaining octets are zero.
//
// An error is returned when mask is outside 0..32, when start or end is not
// an IPv4 address, when an octet of end is smaller than the matching octet of
// start, or when more than 2**mask candidates (capped at 65536; 255 for
// mask 0) were rejected. In the last case "0.0.0.0/<mask>" is returned
// alongside the error.
func GetRandomCIDRInRange(mask int, start, end net.IP, blacklist []string) (string, error) {
	if mask < 0 || mask > 32 {
		return "", fmt.Errorf("invalid mask %d", mask)
	}

	lo, hi := start.To4(), end.To4()
	if lo == nil || hi == nil {
		return "", fmt.Errorf("start and end must be IPv4 addresses, got %s and %s", start, end)
	}

	networkBitsIndex := mask / 8
	for i := 0; i < networkBitsIndex; i++ {
		// An inverted range would make the random draw panic.
		if hi[i] < lo[i] {
			return "", fmt.Errorf("invalid range: octet %d of end (%d) is below start (%d)", i, hi[i], lo[i])
		}
	}

	// In Go "2^mask" is a bitwise XOR, not a power, so the limit is
	// computed with a shift and capped to keep the loop short.
	maxAttempts := 255
	switch {
	case mask > 16:
		maxAttempts = 1 << 16
	case mask > 0:
		maxAttempts = 1 << mask
	}

	var breakCounter int
	for {
		if breakCounter > maxAttempts {
			return fmt.Sprintf("%s/%d", "0.0.0.0", mask), fmt.Errorf("could not find a CIDR after %d attempts", breakCounter)
		}

		ip := net.IP{byte(0), byte(0), byte(0), byte(0)}
		for i := 0; i < networkBitsIndex; i++ {
			ip[i] = byte(randRange(int(lo[i]), int(hi[i])))
		}

		candidate := fmt.Sprintf("%s/%d", ip, mask)
		if isOnBlacklist(candidate, blacklist) {
			breakCounter++
			continue
		}

		// Normalise the candidate to its network address.
		_, ipnet, err := net.ParseCIDR(candidate)
		if err != nil {
			return "", err
		}
		return ipnet.String(), nil
	}
}
// randRange returns a random integer in the inclusive range [min, max].
// max must not be smaller than min.
func randRange(min, max int) int {
	width := max - min + 1
	return min + rand.Intn(width)
}
package observability
import (
"context"
"fmt"
"os"
"path/filepath"
"sync"
"time"
logging "github.com/ipfs/go-log/v2"
"github.com/libp2p/go-libp2p/core/event"
"github.com/libp2p/go-libp2p/core/host"
"github.com/olivere/elastic/v7"
"gitlab.com/nunet/device-management-service/internal/config"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"gopkg.in/natefinch/lumberjack.v2"
)
// timestampKey is the key under which encoders emit the time of a log entry.
const timestampKey = "timestamp"
// Package-level event bus state.
var (
// EventBus is the global event bus instance
EventBus event.Bus
// customEventEmitter is the emitter for CustomEvent
customEventEmitter event.Emitter
)
// CustomEvent represents a custom event structure
type CustomEvent struct {
// Name identifies the event.
Name string
// Timestamp is the time associated with the event.
Timestamp time.Time
// Data carries arbitrary key/value details of the event.
Data map[string]interface{}
}
// Initialize sets up the event bus, the logger and tracing, in that order.
//
// It does nothing when observability is running in no-op mode. A failure to
// set up the event bus is returned to the caller; failures of the logger or
// of tracing are only reported as warnings on stderr.
func Initialize(host host.Host) error {
	if isNoOp() {
		return nil
	}

	// Load the configuration
	cfg := config.GetConfig()

	// The event bus is mandatory.
	err := initEventBus(host)
	if err != nil {
		return err
	}

	// Logger and tracing are optional: warn and carry on.
	if logErr := initLogger(cfg.Observability); logErr != nil {
		fmt.Fprintf(os.Stderr, "Warning: Failed to initialize logger: %v\n", logErr)
	}
	if apmErr := initTracing(cfg.APM); apmErr != nil {
		fmt.Fprintf(os.Stderr, "Warning: Failed to initialize tracing: %v\n", apmErr)
	}

	return nil
}
// OverrideLoggerForTesting replaces the global logger core with one that
// writes to the console only, using the log level from the current
// configuration. It returns an error when that log level is invalid.
func OverrideLoggerForTesting() error {
	levelName := config.GetConfig().Observability.LogLevel

	level, err := parseLogLevel(levelName)
	if err != nil {
		return fmt.Errorf("invalid log level: %w", err)
	}

	// A tee with a single console core becomes the new primary core.
	combinedCore = zapcore.NewTee(createConsoleCore(level))
	logging.SetPrimaryCore(combinedCore)

	return nil
}
// Global variables to hold references to cores for dynamic updates
var (
// combinedCore is the core most recently installed as the primary
// logging core.
combinedCore zapcore.Core
// esSyncerInstance refers to the buffered Elasticsearch syncer.
// NOTE(review): it is not assigned anywhere in this part of the file.
esSyncerInstance *bufferedElasticsearchSyncer
)
// initLogger builds the console, file, Elasticsearch and event-emitter cores
// from observabilityConfig, tees them together and installs the result as the
// primary logging core.
//
// It returns an error when the log directory cannot be created or the log
// level is invalid. A failure to create the Elasticsearch core is only
// reported on stderr; logging proceeds without that core.
func initLogger(observabilityConfig config.Observability) error {
	// The rotating file writer needs its parent directory to exist.
	logDir := filepath.Dir(observabilityConfig.LogFile)
	if err := os.MkdirAll(logDir, 0o755); err != nil {
		return fmt.Errorf("failed to create log directory: %w", err)
	}

	logLevel, err := parseLogLevel(observabilityConfig.LogLevel)
	if err != nil {
		return fmt.Errorf("invalid log level: %w", err)
	}

	cores := []zapcore.Core{
		createConsoleCore(logLevel),
		createFileCore(observabilityConfig, logLevel),
	}

	// Elasticsearch is optional: warn and leave it out on failure.
	if esCore, esErr := createElasticsearchCore(observabilityConfig, logLevel); esErr != nil {
		fmt.Fprintf(os.Stderr, "Warning: Unable to create Elasticsearch logger: %v\n", esErr)
	} else if esCore != nil {
		cores = append(cores, esCore)
	}

	cores = append(cores, newEventEmitterCore(logLevel))

	combinedCore = zapcore.NewTee(cores...)
	logging.SetPrimaryCore(combinedCore)
	return nil
}
// parseLogLevel converts levelStr into a zapcore.Level using the level's own
// text unmarshalling. On failure it returns level 0 and the unmarshal error.
func parseLogLevel(levelStr string) (zapcore.Level, error) {
	var parsed zapcore.Level
	if err := parsed.UnmarshalText([]byte(levelStr)); err != nil {
		return 0, err
	}
	return parsed, nil
}
// createConsoleCore returns a core that writes console-encoded entries to
// stdout at logLevel and above, with ISO8601 timestamps under timestampKey
// and capitalized level names.
func createConsoleCore(logLevel zapcore.Level) zapcore.Core {
	cfg := zap.NewProductionEncoderConfig()
	cfg.TimeKey = timestampKey
	cfg.EncodeTime = zapcore.ISO8601TimeEncoder
	cfg.EncodeLevel = zapcore.CapitalLevelEncoder

	return zapcore.NewCore(
		zapcore.NewConsoleEncoder(cfg),
		zapcore.AddSync(os.Stdout),
		logLevel,
	)
}
// createFileCore returns a core that writes JSON-encoded entries at logLevel
// and above to a lumberjack-rotated file. The file name and rotation limits
// (size in megabytes, backup count, age in days) come from
// observabilityConfig; rotated files are compressed.
func createFileCore(observabilityConfig config.Observability, logLevel zapcore.Level) zapcore.Core {
	rotatingFile := &lumberjack.Logger{
		Filename:   observabilityConfig.LogFile,
		MaxSize:    observabilityConfig.MaxSize, // megabytes
		MaxBackups: observabilityConfig.MaxBackups,
		MaxAge:     observabilityConfig.MaxAge, // days
		Compress:   true,
	}

	cfg := zap.NewProductionEncoderConfig()
	cfg.TimeKey = timestampKey
	cfg.EncodeTime = zapcore.ISO8601TimeEncoder
	cfg.EncodeLevel = zapcore.CapitalLevelEncoder

	return zapcore.NewCore(zapcore.NewJSONEncoder(cfg), zapcore.AddSync(rotatingFile), logLevel)
}
// createElasticsearchCore returns a core that JSON-encodes entries at
// logLevel and above and hands them to a buffered Elasticsearch write syncer.
// The URL, index and flush interval (in seconds) are taken from
// observabilityConfig. It returns the syncer's creation error unchanged.
func createElasticsearchCore(observabilityConfig config.Observability, logLevel zapcore.Level) (zapcore.Core, error) {
	flushEvery := time.Duration(observabilityConfig.FlushInterval) * time.Second
	sink, err := newElasticsearchWriteSyncer(
		observabilityConfig.ElasticsearchURL,
		observabilityConfig.ElasticsearchIndex,
		flushEvery,
	)
	if err != nil {
		return nil, err
	}

	cfg := zap.NewProductionEncoderConfig()
	cfg.TimeKey = timestampKey
	cfg.EncodeTime = zapcore.ISO8601TimeEncoder
	cfg.EncodeLevel = zapcore.CapitalLevelEncoder

	return zapcore.NewCore(zapcore.NewJSONEncoder(cfg), sink, logLevel), nil
}
// newElasticsearchWriteSyncer creates an Elasticsearch client for url and
// wraps it in a buffering WriteSyncer that bulk-indexes entries into index
// every flushInterval.
//
// A non-positive flushInterval is rejected with an error, because
// time.NewTicker would panic on it. The new syncer is stored in
// esSyncerInstance; a previously stored syncer is closed (stopping its
// goroutine and flushing its buffer) so that rebuilding the logger does not
// leak one goroutine and ticker per rebuild.
func newElasticsearchWriteSyncer(url string, index string, flushInterval time.Duration) (zapcore.WriteSyncer, error) {
	if flushInterval <= 0 {
		return nil, fmt.Errorf("elasticsearch flush interval must be positive, got %v", flushInterval)
	}

	client, err := elastic.NewClient(
		elastic.SetURL(url),
		elastic.SetSniff(false),       // Disable sniffing if not using a cluster
		elastic.SetHealthcheck(false), // Disable initial health check
	)
	if err != nil {
		return nil, err
	}

	esSyncer := newBufferedElasticsearchSyncer(client, index, flushInterval)

	// Replace the global instance only once the new one exists, then release
	// the old one. Entries written to the old syncer after this point are only
	// sent if something calls Sync on it.
	previous := esSyncerInstance
	esSyncerInstance = esSyncer
	if previous != nil {
		previous.Close()
	}
	return esSyncer, nil
}

// bufferedElasticsearchSyncer implements zapcore.WriteSyncer. Writes are
// appended to an in-memory buffer which a background goroutine bulk-sends to
// Elasticsearch on every ticker tick.
type bufferedElasticsearchSyncer struct {
	client *elastic.Client
	index  string
	ctx    context.Context // context used for every bulk request

	// bufferMutex guards buffer, flushInterval and closed.
	bufferMutex   sync.Mutex
	buffer        []string
	flushInterval time.Duration
	closed        bool

	ticker *time.Ticker  // drives periodic flushes; never replaced after construction
	done   chan struct{} // closed by Close to stop the flush goroutine
}

// newBufferedElasticsearchSyncer creates a syncer for client and index and
// starts its flush goroutine. flushInterval must be positive; time.NewTicker
// panics otherwise.
func newBufferedElasticsearchSyncer(client *elastic.Client, index string, flushInterval time.Duration) *bufferedElasticsearchSyncer {
	syncer := &bufferedElasticsearchSyncer{
		client:        client,
		index:         index,
		ctx:           context.Background(),
		flushInterval: flushInterval,
		ticker:        time.NewTicker(flushInterval),
		done:          make(chan struct{}),
	}
	go syncer.start()
	return syncer
}

// start flushes the buffer on every tick until Close closes done.
func (b *bufferedElasticsearchSyncer) start() {
	for {
		select {
		case <-b.ticker.C:
			b.Flush()
		case <-b.done:
			return
		}
	}
}

// Write buffers a copy of p as one log entry. It never fails and always
// reports len(p) bytes written.
func (b *bufferedElasticsearchSyncer) Write(p []byte) (n int, err error) {
	b.bufferMutex.Lock()
	b.buffer = append(b.buffer, string(p))
	b.bufferMutex.Unlock()
	return len(p), nil
}

// Sync flushes the buffer to Elasticsearch. It always returns nil; flush
// failures are reported on stderr by Flush.
func (b *bufferedElasticsearchSyncer) Sync() error {
	b.Flush()
	return nil
}

// Flush sends all buffered entries to Elasticsearch in a single bulk request.
// The buffer is detached under the lock, so writers are not blocked during
// the request. Entries of a failed request are dropped after the error has
// been printed to stderr.
func (b *bufferedElasticsearchSyncer) Flush() {
	b.bufferMutex.Lock()
	pending := b.buffer
	b.buffer = nil
	b.bufferMutex.Unlock()

	if len(pending) == 0 {
		return
	}

	bulkRequest := b.client.Bulk()
	for _, logEntry := range pending {
		bulkRequest = bulkRequest.Add(elastic.NewBulkIndexRequest().Index(b.index).Doc(logEntry))
	}

	if _, err := bulkRequest.Do(b.ctx); err != nil {
		fmt.Fprintf(os.Stderr, "Error flushing logs to Elasticsearch: %v\n", err)
	}
}

// Close stops the ticker and the flush goroutine, then flushes whatever is
// still buffered. It is safe to call more than once; only the first call has
// an effect.
func (b *bufferedElasticsearchSyncer) Close() {
	b.bufferMutex.Lock()
	if b.closed {
		b.bufferMutex.Unlock()
		return
	}
	b.closed = true
	b.bufferMutex.Unlock()

	b.ticker.Stop()
	close(b.done)
	b.Flush()
}

// setFlushInterval changes how often the buffer is flushed.
//
// The existing ticker is reset rather than replaced: the flush goroutine is
// blocked receiving from that ticker's channel, so swapping in a new ticker
// would leave it waiting on a stopped one forever. Non-positive intervals
// (which would make the ticker panic) and calls after Close are ignored.
func (b *bufferedElasticsearchSyncer) setFlushInterval(interval time.Duration) {
	if interval <= 0 {
		return
	}

	b.bufferMutex.Lock()
	defer b.bufferMutex.Unlock()
	if b.closed {
		return
	}
	b.flushInterval = interval
	b.ticker.Reset(interval)
}
// initEventBus points the global EventBus at the host's event bus and
// creates the emitter used for CustomEvent values. It returns a wrapped error
// when the emitter cannot be created.
func initEventBus(host host.Host) error {
	EventBus = host.EventBus()

	emitter, err := EventBus.Emitter(new(CustomEvent))
	customEventEmitter = emitter
	if err != nil {
		return fmt.Errorf("failed to create custom event emitter: %w", err)
	}
	return nil
}
// newEventEmitterCore returns a zapcore.Core that forwards entries enabled by
// level to the event bus. The core starts without any accumulated fields.
func newEventEmitterCore(level zapcore.LevelEnabler) zapcore.Core {
	core := &eventEmitterCore{LevelEnabler: level}
	return core
}
// eventEmitterCore is a zapcore.Core that emits log entries to the event bus
type eventEmitterCore struct {
// LevelEnabler decides which entry levels this core accepts (see Check).
zapcore.LevelEnabler
// fields holds the fields accumulated through With.
fields []zapcore.Field
}
// With implements zapcore.Core. It returns a new core that carries the
// receiver's fields followed by fields.
//
// The combined slice is freshly allocated: appending directly to e.fields
// could write into spare capacity shared with other cores derived from the
// same parent, letting siblings overwrite each other's fields.
func (e *eventEmitterCore) With(fields []zapcore.Field) zapcore.Core {
	merged := make([]zapcore.Field, 0, len(e.fields)+len(fields))
	merged = append(merged, e.fields...)
	merged = append(merged, fields...)

	return &eventEmitterCore{
		LevelEnabler: e.LevelEnabler,
		fields:       merged,
	}
}
// Check implements zapcore.Core. It registers this core on ce when the
// entry's level is enabled, and returns ce unchanged otherwise.
func (e *eventEmitterCore) Check(entry zapcore.Entry, ce *zapcore.CheckedEntry) *zapcore.CheckedEntry {
	if !e.Enabled(entry.Level) {
		return ce
	}
	return ce.AddCore(entry, e)
}
// Write implements zapcore.Core. It turns the entry into a CustomEvent named
// "log_event" and emits it on the event bus.
//
// The event data holds the entry's level, timestamp (UTC, RFC 3339), message,
// logger name and caller, followed by the fields accumulated through With and
// then the fields of this call; later values override earlier ones with the
// same key. Fields are decoded through a map encoder because a zap field
// keeps strings, integers, booleans and similar values outside its Interface
// member, which would otherwise be emitted as nil.
//
// Write always returns nil: an emit failure is reported on stderr, and the
// entry is skipped silently when the emitter has not been initialized.
func (e *eventEmitterCore) Write(entry zapcore.Entry, fields []zapcore.Field) error {
	if customEventEmitter == nil {
		return nil
	}

	eventData := map[string]interface{}{
		"level":     entry.Level.String(),
		"timestamp": entry.Time.UTC().Format(time.RFC3339),
		"message":   entry.Message,
		"logger":    entry.LoggerName,
		"caller":    entry.Caller.String(),
	}

	// Let zap decode each field into its real value.
	enc := zapcore.NewMapObjectEncoder()
	for _, field := range e.fields {
		field.AddTo(enc)
	}
	for _, field := range fields {
		field.AddTo(enc)
	}
	for key, value := range enc.Fields {
		eventData[key] = value
	}

	customEvent := CustomEvent{
		Name:      "log_event",
		Timestamp: entry.Time,
		Data:      eventData,
	}

	if err := customEventEmitter.Emit(customEvent); err != nil {
		fmt.Fprintf(os.Stderr, "Error emitting event: %v\n", err)
	}
	return nil
}
// Sync implements zapcore.Core. It is a no-op that always returns nil.
func (e *eventEmitterCore) Sync() error {
return nil
}
// SetLogLevel validates level, writes it into the configuration obtained from
// config.GetConfig and rebuilds the logger with it. It returns the parse
// error for an invalid level, or any error from rebuilding the logger.
func SetLogLevel(level string) error {
	// Validate first so an invalid level never reaches the configuration.
	if _, err := parseLogLevel(level); err != nil {
		return err
	}

	cfg := config.GetConfig()
	cfg.Observability.LogLevel = level

	return rebuildLogger(cfg.Observability)
}
// rebuildLogger rebuilds the combined core and updates the global logger.
// It delegates to initLogger and returns its error.
func rebuildLogger(observabilityConfig config.Observability) error {
// Re-initialize the logger with the updated configuration
return initLogger(observabilityConfig)
}
// SetFlushInterval sets the Elasticsearch flush interval, in seconds, in the
// configuration obtained from config.GetConfig and applies it to the active
// Elasticsearch syncer, if there is one.
//
// It returns an error, and changes nothing, when seconds is not positive:
// a ticker cannot run with a non-positive period, so such a value would
// otherwise panic when applied.
func SetFlushInterval(seconds int) error {
	if seconds <= 0 {
		return fmt.Errorf("flush interval must be a positive number of seconds, got %d", seconds)
	}

	cfg := config.GetConfig()
	cfg.Observability.FlushInterval = seconds

	if esSyncerInstance != nil {
		esSyncerInstance.setFlushInterval(time.Duration(seconds) * time.Second)
	}
	return nil
}
// EmitCustomEvent allows developers to emit custom events with variadic key-value pairs
func EmitCustomEvent(eventName string, keyValues ...interface{}) error {
if len(keyValues)%2 != 0 {
return fmt.Errorf("keyValues must be in key-value pairs")
}
eventData := make(map[string]interface{})
for i := 0; i < len(keyValues); i += 2 {
key, ok := keyValues[i].(string)
if !ok {
return fmt.Errorf("key must be a string")
}
eventData[key] = keyValues[i+1]
}
// Create the custom event
customEvent := &CustomEvent{
Name: eventName,
Timestamp: time.Now(),
Data: eventData,
}
// Emit the event using the customEventEmitter
if err := customEventEmitter.Emit(customEvent); err != nil {
return fmt.Errorf("failed to emit custom event: %w", err)
}
return nil
}
// Shutdown releases the observability resources: it closes the custom event
// emitter and the Elasticsearch syncer, if they exist.
//
// A failure to close the emitter is reported on stderr. The syncer reference
// is cleared after closing so a repeated Shutdown does not close the same
// syncer twice. The emitter reference is left in place because the logging
// core may still be installed and use it.
func Shutdown() {
	if customEventEmitter != nil {
		if err := customEventEmitter.Close(); err != nil {
			fmt.Fprintf(os.Stderr, "Error closing custom event emitter: %v\n", err)
		}
	}

	if esSyncerInstance != nil {
		esSyncerInstance.Close()
		esSyncerInstance = nil
	}
}
package observability
import (
"fmt"
"net/url"
"os"
"sync"
"time"
logging "github.com/ipfs/go-log/v2"
"gitlab.com/nunet/device-management-service/internal/config"
"go.elastic.co/apm"
"go.elastic.co/apm/transport"
)
var (
// noOpMode indicates whether tracing is in no-op mode; Initialize and
// StartTrace do nothing while it is set.
noOpMode bool
// mutex to protect access to noOpMode
mutex sync.RWMutex
// log is the logger for the observability package
log = logging.Logger("observability")
)
// initTracing configures the default Elastic APM tracer from apmConfig: it
// creates an HTTP transport pointed at apmConfig.ServerURL and a tracer with
// the configured service name and environment (the service version is fixed
// at "1.0.0").
//
// It returns an error when the transport cannot be created or the server URL
// cannot be parsed. When the tracer itself cannot be created, a warning is
// printed, no-op mode is enabled and nil is returned so the caller proceeds
// without tracing. In that case apm.DefaultTracer is left untouched rather
// than being overwritten with a nil tracer.
func initTracing(apmConfig config.APM) error {
	tr, err := transport.NewHTTPTransport()
	if err != nil {
		return fmt.Errorf("failed to create APM transport: %w", err)
	}

	serverURL, err := url.Parse(apmConfig.ServerURL)
	if err != nil {
		return fmt.Errorf("failed to parse APM server URL: %w", err)
	}
	tr.SetServerURL(serverURL)

	tracer, err := apm.NewTracerOptions(apm.TracerOptions{
		ServiceName:        apmConfig.ServiceName,
		ServiceVersion:     "1.0.0",
		ServiceEnvironment: apmConfig.Environment,
		Transport:          tr,
	})
	if err != nil {
		fmt.Fprintf(os.Stderr, "Warning: Failed to initialize APM tracing: %v\n", err)
		SetNoOpMode(true)
		return nil // Proceed without tracing
	}

	// Publish the tracer only once it is known to be usable.
	apm.DefaultTracer = tracer
	return nil
}
// StartTrace opens an Elastic APM transaction of type "custom" named
// operationName and logs "<operationName>_start" with the start time, the
// trace and transaction ids, and keyValues.
//
// The returned function must be called (typically deferred) to finish the
// trace: it logs "<operationName>_end" with the end time, the duration, the
// same ids and keyValues, and then ends the transaction. In no-op mode
// nothing is started and the returned function does nothing.
func StartTrace(operationName string, keyValues ...interface{}) func() {
	if isNoOp() {
		return func() {}
	}

	tx := apm.DefaultTracer.StartTransaction(operationName, "custom")
	startTime := time.Now()

	startFields := []interface{}{
		"startTime", startTime,
		"trace.id", tx.TraceContext().Trace.String(),
		"transaction.id", tx.TraceContext().Span.String(),
	}
	startFields = append(startFields, keyValues...)
	log.Infow(operationName+"_start", startFields...)

	return func() {
		endTime := time.Now()

		endFields := []interface{}{
			"endTime", endTime,
			"duration", endTime.Sub(startTime),
			"trace.id", tx.TraceContext().Trace.String(),
			"transaction.id", tx.TraceContext().Span.String(),
		}
		endFields = append(endFields, keyValues...)
		log.Infow(operationName+"_end", endFields...)

		tx.End()
	}
}
// SetNoOpMode switches no-op mode on or off. The flag is written under the
// package mutex's write lock.
func SetNoOpMode(enabled bool) {
	mutex.Lock()
	noOpMode = enabled
	mutex.Unlock()
}
// isNoOp reports whether no-op mode is enabled. The flag is read under the
// package mutex's read lock.
func isNoOp() bool {
	mutex.RLock()
	current := noOpMode
	mutex.RUnlock()
	return current
}
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.34.2
// protoc v5.28.0
// source: common.proto
package common
import (
reflect "reflect"
sync "sync"
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
// Advertisement is the envelope to advertise peers payload.
type Advertisement struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
PeerId string `protobuf:"bytes,1,opt,name=peer_id,json=peerId,proto3" json:"peer_id,omitempty"`
Timestamp int64 `protobuf:"varint,2,opt,name=timestamp,proto3" json:"timestamp,omitempty"`
Data []byte `protobuf:"bytes,3,opt,name=data,proto3" json:"data,omitempty"`
Signature []byte `protobuf:"bytes,4,opt,name=signature,proto3" json:"signature,omitempty"`
PublicKey []byte `protobuf:"bytes,5,opt,name=public_key,json=publicKey,proto3" json:"public_key,omitempty"`
}
func (x *Advertisement) Reset() {
*x = Advertisement{}
if protoimpl.UnsafeEnabled {
mi := &file_common_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *Advertisement) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*Advertisement) ProtoMessage() {}
func (x *Advertisement) ProtoReflect() protoreflect.Message {
mi := &file_common_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use Advertisement.ProtoReflect.Descriptor instead.
func (*Advertisement) Descriptor() ([]byte, []int) {
return file_common_proto_rawDescGZIP(), []int{0}
}
func (x *Advertisement) GetPeerId() string {
if x != nil {
return x.PeerId
}
return ""
}
func (x *Advertisement) GetTimestamp() int64 {
if x != nil {
return x.Timestamp
}
return 0
}
func (x *Advertisement) GetData() []byte {
if x != nil {
return x.Data
}
return nil
}
func (x *Advertisement) GetSignature() []byte {
if x != nil {
return x.Signature
}
return nil
}
func (x *Advertisement) GetPublicKey() []byte {
if x != nil {
return x.PublicKey
}
return nil
}
var File_common_proto protoreflect.FileDescriptor
var file_common_proto_rawDesc = []byte{
0x0a, 0x0c, 0x63, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x06,
0x63, 0x6f, 0x6d, 0x6d, 0x6f, 0x6e, 0x22, 0x97, 0x01, 0x0a, 0x0d, 0x41, 0x64, 0x76, 0x65, 0x72,
0x74, 0x69, 0x73, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x12, 0x17, 0x0a, 0x07, 0x70, 0x65, 0x65, 0x72,
0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x70, 0x65, 0x65, 0x72, 0x49,
0x64, 0x12, 0x1c, 0x0a, 0x09, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x18, 0x02,
0x20, 0x01, 0x28, 0x03, 0x52, 0x09, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x12,
0x12, 0x0a, 0x04, 0x64, 0x61, 0x74, 0x61, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x64,
0x61, 0x74, 0x61, 0x12, 0x1c, 0x0a, 0x09, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72, 0x65,
0x18, 0x04, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x73, 0x69, 0x67, 0x6e, 0x61, 0x74, 0x75, 0x72,
0x65, 0x12, 0x1d, 0x0a, 0x0a, 0x70, 0x75, 0x62, 0x6c, 0x69, 0x63, 0x5f, 0x6b, 0x65, 0x79, 0x18,
0x05, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x70, 0x75, 0x62, 0x6c, 0x69, 0x63, 0x4b, 0x65, 0x79,
0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (
file_common_proto_rawDescOnce sync.Once
file_common_proto_rawDescData = file_common_proto_rawDesc
)
func file_common_proto_rawDescGZIP() []byte {
file_common_proto_rawDescOnce.Do(func() {
file_common_proto_rawDescData = protoimpl.X.CompressGZIP(file_common_proto_rawDescData)
})
return file_common_proto_rawDescData
}
var (
file_common_proto_msgTypes = make([]protoimpl.MessageInfo, 1)
file_common_proto_goTypes = []any{
(*Advertisement)(nil), // 0: common.Advertisement
}
)
var file_common_proto_depIdxs = []int32{
0, // [0:0] is the sub-list for method output_type
0, // [0:0] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name
}
func init() { file_common_proto_init() }
func file_common_proto_init() {
if File_common_proto != nil {
return
}
if !protoimpl.UnsafeEnabled {
file_common_proto_msgTypes[0].Exporter = func(v any, i int) any {
switch v := v.(*Advertisement); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_common_proto_rawDesc,
NumEnums: 0,
NumMessages: 1,
NumExtensions: 0,
NumServices: 0,
},
GoTypes: file_common_proto_goTypes,
DependencyIndexes: file_common_proto_depIdxs,
MessageInfos: file_common_proto_msgTypes,
}.Build()
File_common_proto = out.File
file_common_proto_rawDesc = nil
file_common_proto_goTypes = nil
file_common_proto_depIdxs = nil
}
package basiccontroller
import (
	"context"
	"errors"
	"os"
	"path/filepath"

	"github.com/spf13/afero"

	"gitlab.com/nunet/device-management-service/db/repositories"
	"gitlab.com/nunet/device-management-service/observability"
	"gitlab.com/nunet/device-management-service/storage"
	"gitlab.com/nunet/device-management-service/types"
	"gitlab.com/nunet/device-management-service/utils"
)
// BasicVolumeController is the default implementation of the VolumeController.
// It persists storage volume records through a StorageVolume repository and
// manages the volume directories on the configured file system.
type BasicVolumeController struct {
// repo is the repository for storage volume operations
repo repositories.StorageVolume
// basePath is the base path where volumes are stored under
basePath string
// FS is the file system to act upon. It is exported so that other packages
// can operate on the same file system instance as the controller.
FS afero.Fs
}
// NewDefaultVolumeController builds a BasicVolumeController that stores
// volume records in repo and creates volume directories under volBasePath on
// fs. The returned error is always nil.
//
// TODO-BugFix [path]: volBasePath might not end with `/`, causing errors when calling methods.
// We need to validate it using the `path` library or just verifying the string.
func NewDefaultVolumeController(repo repositories.StorageVolume, volBasePath string, fs afero.Fs) (*BasicVolumeController, error) {
	defer observability.StartTrace(TraceVolumeControllerInitDuration)()

	controller := &BasicVolumeController{repo: repo, basePath: volBasePath, FS: fs}

	log.Infow(LogVolumeControllerInitSuccess, LogKeyBasePath, volBasePath)
	return controller, nil
}
// CreateVolume creates a new storage volume for the given source (S3, IPFS,
// job, etc). It creates an empty directory on the controller's file system
// and writes a record to the database.
//
// The directory is `<basePath>/<volSource>-<name>`, where name is a random
// 16-character string. The path is built with filepath.Join, so it is correct
// whether or not the base path ends with a separator.
//
// The volume starts as non-private, writable and unencrypted; opts may change
// those fields before the record is stored. If the database insert fails, the
// directory that was just created is removed again so no orphan is left
// behind.
//
// It returns the stored volume, or a zero volume and the error from the
// random generator, the file system or the repository.
//
// TODO-maybe [withName]: allow callers to specify custom name for path
func (vc *BasicVolumeController) CreateVolume(volSource storage.VolumeSource, opts ...storage.CreateVolOpt) (types.StorageVolume, error) {
	endTrace := observability.StartTrace(TraceVolumeCreateDuration)
	defer endTrace()

	vol := types.StorageVolume{
		Private:        false,
		ReadOnly:       false,
		EncryptionType: types.EncryptionTypeNull,
	}
	for _, opt := range opts {
		opt(&vol)
	}

	randomStr, err := utils.RandomString(16)
	if err != nil {
		log.Errorw(LogVolumeCreateFailure, LogKeyError, err)
		return types.StorageVolume{}, err
	}

	vol.Path = filepath.Join(vc.basePath, string(volSource)+"-"+randomStr)

	if err := vc.FS.Mkdir(vol.Path, os.ModePerm); err != nil {
		log.Errorw(LogVolumeCreateFailure, LogKeyPath, vol.Path, LogKeyError, err)
		return types.StorageVolume{}, err
	}

	createdVol, err := vc.repo.Create(context.TODO(), vol)
	if err != nil {
		log.Errorw(LogVolumeCreateFailure, LogKeyPath, vol.Path, LogKeyError, err)
		// Best-effort rollback of the directory; the repository error is what
		// the caller needs to see.
		if rmErr := vc.FS.RemoveAll(vol.Path); rmErr != nil {
			log.Errorw(LogVolumeCreateFailure, LogKeyPath, vol.Path, LogKeyError, rmErr)
		}
		return types.StorageVolume{}, err
	}

	log.Infow(LogVolumeCreateSuccess, LogKeyVolumeID, createdVol.ID, LogKeyPath, vol.Path)
	return createdVol, nil
}
// LockVolume marks the volume stored at pathToVol as read-only, both in its
// database record and on the file system (mode 0o400). It is meant to be
// called once all data has been written. opts are applied to the record
// before it is updated, e.g. to set the CID or mark the volume private.
//
// The record is updated before the permissions change, so a Chmod failure
// leaves the record already marked read-only.
//
// NOTE(review): 0o400 on a directory drops the execute bit, which is usually
// required to traverse it — confirm this is intended.
//
// TODO-maybe [CID]: maybe calculate CID of every volume in case WithCID opt is not provided
func (vc *BasicVolumeController) LockVolume(pathToVol string, opts ...storage.LockVolOpt) error {
	defer observability.StartTrace(TraceVolumeLockDuration)()

	byPath := vc.repo.GetQuery()
	byPath.Conditions = append(byPath.Conditions, repositories.EQ("Path", pathToVol))

	vol, err := vc.repo.Find(context.TODO(), byPath)
	if err != nil {
		log.Errorw(LogVolumeLockFailure, LogKeyPath, pathToVol, LogKeyError, err)
		return err
	}

	for i := range opts {
		opts[i](&vol)
	}
	vol.ReadOnly = true

	locked, err := vc.repo.Update(context.TODO(), vol.ID, vol)
	if err != nil {
		log.Errorw(LogVolumeLockFailure, LogKeyVolumeID, vol.ID, LogKeyError, err)
		return err
	}

	// Mirror the read-only flag on the file system.
	if chmodErr := vc.FS.Chmod(locked.Path, 0o400); chmodErr != nil {
		log.Errorw(LogVolumeLockFailure, LogKeyPath, locked.Path, LogKeyError, chmodErr)
		return chmodErr
	}

	log.Infow(LogVolumeLockSuccess, LogKeyVolumeID, locked.ID, LogKeyPath, locked.Path)
	return nil
}
// WithPrivate designates a given volume as private. It can be used both
// when creating or locking a volume: the type parameter selects whether the
// returned option is a storage.CreateVolOpt or a storage.LockVolOpt.
func WithPrivate[T storage.CreateVolOpt | storage.LockVolOpt]() T {
return func(v *types.StorageVolume) {
v.Private = true
}
}
// WithCID returns a lock option that stores cid on the volume. It is meant
// for callers that have already calculated the CID; the value is not checked.
//
// TODO [validate]: check if CID provided is valid
func WithCID(cid string) storage.LockVolOpt {
	setCID := func(v *types.StorageVolume) {
		v.CID = cid
	}
	return setCID
}
// DeleteVolume removes a volume's directory from the file system and then
// deletes its record from the database. identifier is either the path or the
// CID of the volume, as selected by idType.
//
// It returns ErrInvalidIdentifier for an unsupported idType,
// repositories.ErrNotFound when no matching record exists (also when the
// repository wraps that error), or the error of the failing lookup, removal
// or delete. The directory is removed first, so a failing database delete
// leaves a record without a directory.
//
// Note [CID]: if we start to type CID as cid.CID, we may have to use generics here
// as in `[T string | cid.CID]`
func (vc *BasicVolumeController) DeleteVolume(identifier string, idType storage.IDType) error {
	endTrace := observability.StartTrace(TraceVolumeDeleteDuration)
	defer endTrace()

	query := vc.repo.GetQuery()
	switch idType {
	case storage.IDTypePath:
		query.Conditions = append(query.Conditions, repositories.EQ("Path", identifier))
	case storage.IDTypeCID:
		query.Conditions = append(query.Conditions, repositories.EQ("CID", identifier))
	default:
		log.Errorw(LogVolumeDeleteFailure, LogKeyIdentifier, identifier, LogKeyError, ErrMsgInvalidIdentifier)
		return ErrInvalidIdentifier
	}

	vol, err := vc.repo.Find(context.TODO(), query)
	if err != nil {
		// errors.Is also matches a wrapped ErrNotFound, unlike ==.
		if errors.Is(err, repositories.ErrNotFound) {
			log.Errorw(LogVolumeDeleteFailure, LogKeyIdentifier, identifier, LogKeyError, ErrMsgVolumeNotFound)
			return repositories.ErrNotFound
		}
		log.Errorw(LogVolumeDeleteFailure, LogKeyIdentifier, identifier, LogKeyError, err)
		return err
	}

	// Remove the directory
	if err := vc.FS.RemoveAll(vol.Path); err != nil {
		log.Errorw(LogVolumeDeleteFailure, LogKeyPath, vol.Path, LogKeyError, err)
		return err
	}

	if err := vc.repo.Delete(context.TODO(), vol.ID); err != nil {
		log.Errorw(LogVolumeDeleteFailure, LogKeyVolumeID, vol.ID, LogKeyError, err)
		return err
	}

	log.Infow(LogVolumeDeleteSuccess, LogKeyVolumeID, vol.ID, LogKeyPath, vol.Path)
	return nil
}
// ListVolumes returns every storage volume record found by the repository's
// default query, or nil and the repository error.
//
// TODO [filter]: maybe add opts to filter results by certain values
func (vc *BasicVolumeController) ListVolumes() ([]types.StorageVolume, error) {
	defer observability.StartTrace(TraceVolumeListDuration)()

	all, err := vc.repo.FindAll(context.TODO(), vc.repo.GetQuery())
	if err != nil {
		log.Errorw(LogVolumeListFailure, LogKeyError, err)
		return nil, err
	}

	log.Infow(LogVolumeListSuccess, LogKeyVolumeCount, len(all))
	return all, nil
}
// GetSize returns the size of a volume's directory as reported by
// utils.GetDirectorySize. identifier is either the path or the CID of the
// volume, as selected by idType.
//
// It returns 0 and ErrInvalidIdentifier for an unsupported idType, or 0 and
// the error of the failing lookup or size calculation.
//
// TODO-minor: identify which measurement type will be used
func (vc *BasicVolumeController) GetSize(identifier string, idType storage.IDType) (int64, error) {
	defer observability.StartTrace(TraceVolumeGetSizeDuration)()

	lookup := vc.repo.GetQuery()
	switch idType {
	case storage.IDTypeCID:
		lookup.Conditions = append(lookup.Conditions, repositories.EQ("CID", identifier))
	case storage.IDTypePath:
		lookup.Conditions = append(lookup.Conditions, repositories.EQ("Path", identifier))
	default:
		log.Errorw(LogVolumeGetSizeFailure, LogKeyIdentifier, identifier, LogKeyError, ErrMsgInvalidIdentifier)
		return 0, ErrInvalidIdentifier
	}

	vol, findErr := vc.repo.Find(context.TODO(), lookup)
	if findErr != nil {
		log.Errorw(LogVolumeGetSizeFailure, LogKeyIdentifier, identifier, LogKeyError, findErr)
		return 0, findErr
	}

	total, sizeErr := utils.GetDirectorySize(vc.FS, vol.Path)
	if sizeErr != nil {
		log.Errorw(LogVolumeGetSizeFailure, LogKeyPath, vol.Path, LogKeyError, sizeErr)
		return 0, sizeErr
	}

	log.Infow(LogVolumeGetSizeSuccess, LogKeyVolumeID, vol.ID, LogKeySize, total)
	return total, nil
}
// EncryptVolume is a placeholder for volume encryption: it logs that the
// operation is not implemented for path and returns ErrNotImplemented. The
// encryptor and encryption type are ignored.
func (vc *BasicVolumeController) EncryptVolume(path string, _ types.Encryptor, _ types.EncryptionType) error {
	defer observability.StartTrace(TraceVolumeEncryptDuration)()

	log.Errorw(LogVolumeEncryptNotImplemented, LogKeyPath, path)
	return ErrNotImplemented
}
// DecryptVolume is a placeholder for volume decryption: it logs that the
// operation is not implemented for path and returns ErrNotImplemented. The
// decryptor and encryption type are ignored.
func (vc *BasicVolumeController) DecryptVolume(path string, _ types.Decryptor, _ types.EncryptionType) error {
	defer observability.StartTrace(TraceVolumeDecryptDuration)()

	log.Errorw(LogVolumeDecryptNotImplemented, LogKeyPath, path)
	return ErrNotImplemented
}
// Compile-time check that *BasicVolumeController implements storage.VolumeController.
var _ storage.VolumeController = (*BasicVolumeController)(nil)
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package basiccontroller
import (
"context"
"fmt"
"github.com/spf13/afero"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
rGorm "gitlab.com/nunet/device-management-service/db/repositories/gorm"
"gitlab.com/nunet/device-management-service/types"
)
// VolumeControllerTestKit bundles what SetupVolumeControllerTestKit creates
// for a test.
type VolumeControllerTestKit struct {
// BasicVolController is the controller under test.
BasicVolController *BasicVolumeController
// Fs is the file system the controller operates on.
Fs afero.Fs
// Volumes are the volumes passed to the setup function, keyed as given by
// the caller.
Volumes map[string]*types.StorageVolume
}
// SetupVolumeControllerTestKit builds a volume controller for tests, backed
// by an in-memory SQLite database and an in-memory file system rooted at
// basePath. For every entry in volumes it creates the volume's directory and
// stores its record in the database. Every volume pointer must be non-nil.
//
// It returns a wrapped error for the first step that fails.
func SetupVolumeControllerTestKit(basePath string, volumes map[string]*types.StorageVolume) (*VolumeControllerTestKit, error) {
	// NOTE(review): the database uses SQLite's shared in-memory cache, so it
	// may be visible to other connections in the same process — confirm tests
	// do not rely on isolation.
	gormCfg := &gorm.Config{Logger: logger.Default.LogMode(logger.Silent)}
	db, err := gorm.Open(sqlite.Open("file:?mode=memory&cache=shared"), gormCfg)
	if err != nil {
		return nil, fmt.Errorf("failed to create in-memory mock database: %w", err)
	}
	if err = db.AutoMigrate(&types.StorageVolume{}); err != nil {
		return nil, fmt.Errorf("failed to automigrate: %w", err)
	}

	memFs := afero.NewMemMapFs()
	if err = memFs.MkdirAll(basePath, 0o755); err != nil {
		return nil, fmt.Errorf("failed to create base path: %w", err)
	}

	repo := rGorm.NewStorageVolume(db)
	controller, err := NewDefaultVolumeController(repo, basePath, memFs)
	if err != nil {
		return nil, fmt.Errorf("failed to create volume controller: %w", err)
	}

	// Seed the requested volumes: directory first, then the database record.
	for _, seed := range volumes {
		if err = memFs.MkdirAll(seed.Path, 0o755); err != nil {
			return nil, fmt.Errorf("failed to create volume dir: %w", err)
		}
		if _, err = repo.Create(context.Background(), *seed); err != nil {
			return nil, fmt.Errorf("failed to create volume record: %w", err)
		}
	}

	kit := &VolumeControllerTestKit{
		BasicVolController: controller,
		Fs:                 memFs,
		Volumes:            volumes,
	}
	return kit, nil
}
package s3
import (
"context"
"fmt"
"os"
"path/filepath"
"strings"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/spf13/afero"
"gitlab.com/nunet/device-management-service/observability"
"gitlab.com/nunet/device-management-service/storage"
basicController "gitlab.com/nunet/device-management-service/storage/basic_controller"
"gitlab.com/nunet/device-management-service/types"
)
// Download fetches files from a given S3 bucket into a newly created storage
// volume and locks the volume afterwards. The key may be a directory ending
// with `/` or have a wildcard (`*`) so it handles normal S3 folders but it
// does not handle x-directory.
//
// It returns the volume as it was returned by CreateVolume, i.e. before
// locking. On failure it returns a zero volume and an error that wraps the
// underlying cause (usable with errors.Is / errors.As).
//
// NOTE(review): when a step after volume creation fails, the created volume
// is not removed here — confirm whether callers clean it up.
//
// Warning: the implementation should rely on the FS provided by the volume controller,
// be careful if managing files with `os` (the volume controller might be
// using an in-memory one)
func (s *Storage) Download(ctx context.Context, sourceSpecs *types.SpecConfig) (types.StorageVolume, error) {
	endTrace := observability.StartTrace("s3_download_duration")
	defer endTrace()

	source, err := DecodeInputSpec(sourceSpecs)
	if err != nil {
		log.Errorw("s3_download_failure", "error", err)
		return types.StorageVolume{}, err
	}

	storageVol, err := s.volController.CreateVolume(storage.VolumeSourceS3)
	if err != nil {
		err = fmt.Errorf("failed to create storage volume: %w", err)
		log.Errorw("s3_volume_create_failure", "error", err)
		return types.StorageVolume{}, err
	}

	resolvedObjects, err := resolveStorageKey(ctx, s.Client, &source)
	if err != nil {
		err = fmt.Errorf("failed to resolve storage key: %w", err)
		log.Errorw("s3_resolve_key_failure", "error", err)
		return types.StorageVolume{}, err
	}

	for _, resolvedObject := range resolvedObjects {
		if err := s.downloadObject(ctx, &source, resolvedObject, storageVol.Path); err != nil {
			err = fmt.Errorf("failed to download s3 object: %w", err)
			log.Errorw("s3_download_object_failure", "error", err)
			return types.StorageVolume{}, err
		}
	}

	// after data is filled within the volume, we have to lock it
	if err := s.volController.LockVolume(storageVol.Path); err != nil {
		err = fmt.Errorf("failed to lock storage volume: %w", err)
		log.Errorw("s3_volume_lock_failure", "error", err)
		return types.StorageVolume{}, err
	}

	log.Infow("s3_download_success", "volumeID", storageVol.ID, "path", storageVol.Path)
	return storageVol, nil
}
// downloadObject materializes one resolved S3 object inside the volume at
// volPath, using the file system of the volume controller.
//
// For a directory object only the directory is created. For a file object the
// parent directory is created and the object is downloaded into
// <volPath>/<key>; the download is conditional on the object's ETag.
//
// It returns an error when the key is nil, when the key would resolve outside
// volPath (e.g. through `..` segments), when the volume controller is not a
// BasicVolumeController with a file system, or when creating directories,
// opening, downloading or closing the file fails.
func (s *Storage) downloadObject(ctx context.Context, source *InputSource, object s3Object, volPath string) error {
	endTrace := observability.StartTrace("s3_download_object_duration")
	defer endTrace()

	if object.key == nil {
		err := fmt.Errorf("s3 object key is nil")
		log.Errorw("s3_download_failure", "error", err)
		return err
	}
	outputPath := filepath.Join(volPath, *object.key)

	// Object keys come from the bucket; refuse any that escape the volume.
	rel, relErr := filepath.Rel(volPath, outputPath)
	if relErr != nil || rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) {
		err := fmt.Errorf("object key %q resolves outside of the volume", *object.key)
		log.Errorw("s3_download_failure", "objectKey", *object.key, "error", err)
		return err
	}

	// use the same file system instance used by the Volume Controller
	basicVolController, ok := s.volController.(*basicController.BasicVolumeController)
	if !ok || basicVolController.FS == nil {
		err := fmt.Errorf("volume controller does not provide a file system")
		log.Errorw("s3_download_failure", "objectKey", *object.key, "error", err)
		return err
	}
	fs := basicVolController.FS

	// A directory object is created as-is. A file object only needs its
	// parent: creating a directory at the file's own path would make the
	// OpenFile below target a directory.
	dirPath := outputPath
	if !object.isDir {
		dirPath = filepath.Dir(outputPath)
	}
	if err := fs.MkdirAll(dirPath, 0o755); err != nil {
		log.Errorw("s3_create_directory_failure", "path", dirPath, "error", fmt.Errorf("failed to create directory: %v", err))
		return fmt.Errorf("failed to create directory: %w", err)
	}

	if object.isDir {
		// if object is a directory, we don't need to download it (just create the dir)
		return nil
	}

	outputFile, err := fs.OpenFile(outputPath, os.O_RDWR|os.O_CREATE, 0o755)
	if err != nil {
		log.Errorw("s3_open_file_failure", "path", outputPath, "error", err)
		return err
	}

	log.Debugw("Downloading s3 object", "objectKey", *object.key, "outputPath", outputPath)
	_, err = s.downloader.Download(ctx, outputFile, &s3.GetObjectInput{
		Bucket:  aws.String(source.Bucket),
		Key:     object.key,
		IfMatch: object.eTag,
	})
	// Close before inspecting errors: a failed close of a written file can
	// mean lost data and must not go unnoticed.
	closeErr := outputFile.Close()
	if err != nil {
		err = fmt.Errorf("failed to download file: %w", err)
		log.Errorw("s3_download_failure", "objectKey", *object.key, "error", err)
		return err
	}
	if closeErr != nil {
		closeErr = fmt.Errorf("failed to close file: %w", closeErr)
		log.Errorw("s3_download_failure", "objectKey", *object.key, "error", closeErr)
		return closeErr
	}

	log.Infow("s3_download_object_success", "objectKey", *object.key)
	return nil
}
// resolveStorageKey expands source.Key into the list of S3 objects it refers to.
// An empty key is rejected. A key ending in "/" or containing "*" is treated as
// a prefix matching several objects; any other key is resolved as one object.
func resolveStorageKey(ctx context.Context, client *s3.Client, source *InputSource) ([]s3Object, error) {
	if source.Key == "" {
		err := fmt.Errorf("key is required")
		log.Errorw("s3_resolve_key_failure", "error", err)
		return nil, err
	}

	// A trailing slash or a wildcard anywhere marks a multi-object key.
	multiObject := strings.HasSuffix(source.Key, "/") || strings.Contains(source.Key, "*")
	if multiObject {
		return resolveObjectsWithPrefix(ctx, client, source)
	}

	return resolveSingleObject(ctx, client, source)
}
// resolveSingleObject looks up the metadata of the single object addressed by
// source.Key (after sanitizing it) and returns it as a one-element slice.
//
// It returns an error when the HEAD request fails or when the object's content
// type marks it as a directory ("application/x-directory"), which is not handled.
func resolveSingleObject(ctx context.Context, client *s3.Client, source *InputSource) ([]s3Object, error) {
	key := sanitizeKey(source.Key)

	headObjectOut, err := client.HeadObject(ctx, &s3.HeadObjectInput{
		Bucket: aws.String(source.Bucket),
		Key:    aws.String(key),
	})
	if err != nil {
		log.Errorw("s3_head_object_failure", "key", key, "error", err)
		return []s3Object{}, fmt.Errorf("failed to retrieve object metadata: %w", err)
	}

	// ContentType and ContentLength are optional pointers in the response; the
	// nil-safe helpers avoid a panic when the service omits them.
	if strings.HasPrefix(aws.ToString(headObjectOut.ContentType), "application/x-directory") {
		err := fmt.Errorf("x-directory is not yet handled")
		log.Errorw("s3_directory_handling_failure", "key", key, "error", err)
		return []s3Object{}, err
	}

	return []s3Object{
		{
			// Use the same sanitized key that was just verified by HeadObject so
			// the later download addresses the object the ETag belongs to.
			key:  aws.String(key),
			eTag: headObjectOut.ETag,
			size: aws.ToInt64(headObjectOut.ContentLength),
		},
	}, nil
}
// resolveObjectsWithPrefix lists every object in source.Bucket whose key starts
// with the sanitized source.Key, following pagination until exhausted. Keys
// ending in "/" are flagged as directories. No ETag is recorded for listed objects.
func resolveObjectsWithPrefix(ctx context.Context, client *s3.Client, source *InputSource) ([]s3Object, error) {
	prefix := sanitizeKey(source.Key)

	pages := s3.NewListObjectsV2Paginator(client, &s3.ListObjectsV2Input{
		Bucket: aws.String(source.Bucket),
		Prefix: aws.String(prefix),
	})

	var objects []s3Object
	for pages.HasMorePages() {
		page, err := pages.NextPage(ctx)
		if err != nil {
			listErr := fmt.Errorf("failed to list objects: %v", err)
			log.Errorw("s3_list_objects_failure", "error", listErr)
			return nil, listErr
		}

		for i := range page.Contents {
			entry := page.Contents[i]
			objectKey := *entry.Key
			objects = append(objects, s3Object{
				key:   aws.String(objectKey),
				size:  *entry.Size,
				isDir: strings.HasSuffix(objectKey, "/"),
			})
		}
	}

	log.Infow("s3_resolve_objects_with_prefix_success", "objectCount", len(objects))
	return objects, nil
}
// s3/aws_config.go
package s3
import (
"context"
"strings"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"gitlab.com/nunet/device-management-service/observability"
)
// GetAWSDefaultConfig loads the AWS SDK default configuration (environment
// variables, shared configuration and shared credentials files) without any
// extra load options. The load error is logged and returned as-is.
func GetAWSDefaultConfig() (aws.Config, error) {
	stopTrace := observability.StartTrace("get_aws_default_config_duration")
	defer stopTrace()

	awsCfg, loadErr := config.LoadDefaultConfig(context.Background())
	if loadErr == nil {
		log.Infow("get_aws_default_config_success")
		return awsCfg, nil
	}

	log.Errorw("get_aws_default_config_failure", "error", loadErr)
	return aws.Config{}, loadErr
}
// hasValidCredentials reports whether credentials can be retrieved from the
// given AWS config and contain keys. Every failure path is logged and yields
// false, including a config that carries no credentials provider at all.
func hasValidCredentials(config aws.Config) bool {
	endTrace := observability.StartTrace("has_valid_credentials_duration")
	defer endTrace()

	// A zero-value aws.Config has a nil provider; calling Retrieve on it would panic.
	if config.Credentials == nil {
		log.Errorw("has_valid_credentials_failure", "error", "credentials provider is nil")
		return false
	}

	credentials, err := config.Credentials.Retrieve(context.Background())
	if err != nil {
		log.Errorw("has_valid_credentials_failure", "error", err)
		return false
	}

	if !credentials.HasKeys() {
		log.Errorw("has_valid_credentials_failure_no_keys")
		return false
	}

	log.Infow("has_valid_credentials_success")
	return true
}
// sanitizeKey strips leading and trailing whitespace from key and then removes
// a single trailing "*" wildcard, if present. The result is logged and returned.
func sanitizeKey(key string) string {
	stopTrace := observability.StartTrace("sanitize_key_duration")
	defer stopTrace()

	trimmed := strings.TrimSpace(key)
	sanitizedKey := strings.TrimSuffix(trimmed, "*")

	log.Infow("sanitize_key_success", "sanitizedKey", sanitizedKey)
	return sanitizedKey
}
// s3/client.go
package s3
import (
"context"
"fmt"
"github.com/aws/aws-sdk-go-v2/aws"
s3Manager "github.com/aws/aws-sdk-go-v2/feature/s3/manager"
"github.com/aws/aws-sdk-go-v2/service/s3"
"gitlab.com/nunet/device-management-service/observability"
"gitlab.com/nunet/device-management-service/storage"
"gitlab.com/nunet/device-management-service/types"
)
// Storage is the S3 storage provider. It embeds the SDK client and carries the
// volume controller and transfer managers used by Download, Upload and Size.
type Storage struct {
// Embedded S3 SDK client; its methods (e.g. HeadObject) are promoted onto Storage.
*s3.Client
// volController creates and locks the volumes that downloads are written into.
volController storage.VolumeController
// downloader fetches object bodies into files.
downloader *s3Manager.Downloader
// uploader sends local files to a bucket.
uploader *s3Manager.Uploader
}
// s3Object describes one object resolved from a HEAD request or a bucket listing.
type s3Object struct {
// key is the object's key within the bucket.
key *string
// eTag, when non-nil, is passed as the IfMatch condition on download.
eTag *string
// size is the length reported by S3 (presumably bytes — confirm).
size int64
// isDir is set for keys ending in "/"; such objects are only created as directories.
isDir bool
}
// NewClient builds a Storage backed by an S3 SDK client created from config.
// The supplied VolumeController is used to manage the volumes being acted upon.
// It fails with "invalid credentials" when config does not yield usable credentials.
func NewClient(config aws.Config, volController storage.VolumeController) (*Storage, error) {
	stopTrace := observability.StartTrace("new_client_duration")
	defer stopTrace()

	if credentialsOK := hasValidCredentials(config); !credentialsOK {
		err := fmt.Errorf("invalid credentials")
		log.Errorw("new_client_invalid_credentials", "error", err)
		return nil, err
	}

	// The downloader and uploader share the same underlying SDK client.
	sdkClient := s3.NewFromConfig(config)
	client := &Storage{
		Client:        sdkClient,
		volController: volController,
		downloader:    s3Manager.NewDownloader(sdkClient),
		uploader:      s3Manager.NewUploader(sdkClient),
	}

	log.Infow("new_client_success")
	return client, nil
}
// Size returns the size of the object addressed by source, as reported by an
// S3 HEAD request on the decoded bucket and key.
//
// It returns an error when the spec cannot be decoded, when the HEAD request
// fails, or when the response carries no (or a negative) content length.
func (s *Storage) Size(ctx context.Context, source *types.SpecConfig) (uint64, error) {
	endTrace := observability.StartTrace("s3_size_duration")
	defer endTrace()

	inputSource, err := DecodeInputSpec(source)
	if err != nil {
		log.Errorw("s3_size_decode_input_spec_failure", "error", err)
		return 0, fmt.Errorf("failed to decode input spec: %w", err)
	}

	output, err := s.HeadObject(ctx, &s3.HeadObjectInput{
		Bucket: aws.String(inputSource.Bucket),
		Key:    aws.String(inputSource.Key),
	})
	if err != nil {
		log.Errorw("s3_size_head_object_failure", "error", err)
		return 0, fmt.Errorf("failed to get object size: %w", err)
	}

	// ContentLength is an optional pointer; guard it before dereferencing and
	// refuse negative values, which would wrap around when converted to uint64.
	size := aws.ToInt64(output.ContentLength)
	if output.ContentLength == nil || size < 0 {
		err := fmt.Errorf("failed to get object size: content length unavailable for key %q", inputSource.Key)
		log.Errorw("s3_size_head_object_failure", "error", err)
		return 0, err
	}

	log.Infow("s3_size_success", "bucket", inputSource.Bucket, "key", inputSource.Key, "size", size)
	return uint64(size), nil
}
// Compile time interface check
// var _ storage.StorageProvider = (*S3Storage)(nil)
// s3/input_source.go
package s3
import (
"fmt"
"github.com/fatih/structs"
"github.com/go-viper/mapstructure/v2"
"gitlab.com/nunet/device-management-service/observability"
"gitlab.com/nunet/device-management-service/types"
)
// InputSource holds the S3 parameters decoded from a storage spec
// (see DecodeInputSpec).
type InputSource struct {
// Bucket is the bucket name; Validate rejects an empty value.
Bucket string
// Key is the object key, or a prefix/wildcard key addressing several objects.
Key string
// Filter is not read by the code visible here — NOTE(review): confirm usage.
Filter string
// Region is not read by the code visible here — NOTE(review): confirm usage.
Region string
// Endpoint is not read by the code visible here — NOTE(review): confirm usage.
Endpoint string
}
// Validate checks the decoded parameters. Currently the only rule is that
// Bucket must be non-empty; a violation is logged and returned as an error.
func (s InputSource) Validate() error {
	if s.Bucket != "" {
		return nil
	}

	validationErr := fmt.Errorf("invalid s3 storage params: bucket cannot be empty")
	log.Errorw("s3_input_source_validation_failure", "error", validationErr)
	return validationErr
}
// ToMap converts the input source into a generic map via structs.Map.
func (s InputSource) ToMap() map[string]interface{} {
	asMap := structs.Map(s)
	return asMap
}
// DecodeInputSpec converts a generic storage spec into an InputSource.
//
// The spec must be of type types.StorageProviderS3 and carry non-nil params.
// The params are decoded with mapstructure and then validated; on a decode or
// validation failure the partially populated InputSource is returned together
// with the error.
func DecodeInputSpec(spec *types.SpecConfig) (InputSource, error) {
	stopTrace := observability.StartTrace("decode_input_spec_duration")
	defer stopTrace()

	if isS3 := spec.IsType(types.StorageProviderS3); !isS3 {
		typeErr := fmt.Errorf("invalid storage source type. Expected %s but received %s", types.StorageProviderS3, spec.Type)
		log.Errorw("decode_input_spec_invalid_type_failure", "error", typeErr)
		return InputSource{}, typeErr
	}

	if spec.Params == nil {
		paramsErr := fmt.Errorf("invalid storage input source params. cannot be nil")
		log.Errorw("decode_input_spec_nil_params_failure", "error", paramsErr)
		return InputSource{}, paramsErr
	}

	var decoded InputSource
	decodeErr := mapstructure.Decode(spec.Params, &decoded)
	if decodeErr != nil {
		log.Errorw("decode_input_spec_decode_failure", "error", decodeErr)
		return decoded, decodeErr
	}

	validateErr := decoded.Validate()
	if validateErr != nil {
		log.Errorw("decode_input_spec_validation_failure", "error", validateErr)
		return decoded, validateErr
	}

	log.Infow("decode_input_spec_success", "bucket", decoded.Bucket)
	return decoded, nil
}
// s3/upload.go
package s3
import (
"context"
"fmt"
"os"
"path/filepath"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/spf13/afero"
"gitlab.com/nunet/device-management-service/observability"
basicController "gitlab.com/nunet/device-management-service/storage/basic_controller"
"gitlab.com/nunet/device-management-service/types"
)
// Upload uploads all files (recursively) from a local volume to an S3 bucket.
// Directories are traversed but not uploaded themselves. Each file is stored
// under the sanitized destination key joined with the file's path relative to
// vol.Path, always using "/" as the key separator.
//
// The walk stops at the first failure; files uploaded before that point remain
// in the bucket. An error is also returned when the destination spec cannot be
// decoded or when the volume controller exposes no file system.
//
// Warning: the implementation should rely on the FS provided by the volume controller,
// be careful if managing files with `os` (the volume controller might be
// using an in-memory one)
func (s *Storage) Upload(ctx context.Context, vol types.StorageVolume, destinationSpecs *types.SpecConfig) error {
	endTrace := observability.StartTrace("s3_upload_duration")
	defer endTrace()

	target, err := DecodeInputSpec(destinationSpecs)
	if err != nil {
		log.Errorw("s3_upload_decode_spec_failure", "error", err)
		return fmt.Errorf("failed to decode input spec: %w", err)
	}

	sanitizedKey := sanitizeKey(target.Key)

	// set file system to act upon based on the volume controller implementation
	var fs afero.Fs
	if basicVolController, ok := s.volController.(*basicController.BasicVolumeController); ok {
		fs = basicVolController.FS
	}
	if fs == nil {
		// Walking a nil file system would panic; fail explicitly instead.
		err := fmt.Errorf("upload failed: volume controller %T exposes no file system", s.volController)
		log.Errorw("s3_upload_failure", "error", err)
		return err
	}

	log.Debugw("Uploading files", "sourcePath", vol.Path, "bucket", target.Bucket, "key", sanitizedKey)
	err = afero.Walk(fs, vol.Path, func(filePath string, info os.FileInfo, walkErr error) error {
		if walkErr != nil {
			log.Errorw("s3_upload_walk_failure", "error", walkErr)
			return walkErr
		}

		// Skip directories
		if info.IsDir() {
			return nil
		}

		relPath, err := filepath.Rel(vol.Path, filePath)
		if err != nil {
			log.Errorw("s3_upload_relative_path_failure", "error", err)
			return fmt.Errorf("failed to get relative path: %w", err)
		}

		// S3 keys are slash-separated regardless of the host OS, so normalise
		// the OS-specific separators produced by filepath.Join.
		s3Key := filepath.ToSlash(filepath.Join(sanitizedKey, relPath))

		file, err := fs.Open(filePath)
		if err != nil {
			log.Errorw("s3_upload_open_file_failure", "filePath", filePath, "error", err)
			return fmt.Errorf("failed to open file: %w", err)
		}
		// The defer runs when this per-file callback returns, not at the end of the walk.
		defer file.Close()

		log.Debugw("Uploading file to S3", "filePath", filePath, "bucket", target.Bucket, "key", s3Key)
		_, err = s.uploader.Upload(ctx, &s3.PutObjectInput{
			Bucket: aws.String(target.Bucket),
			Key:    aws.String(s3Key),
			Body:   file,
		})
		if err != nil {
			log.Errorw("s3_upload_put_object_failure", "filePath", filePath, "error", err)
			return fmt.Errorf("failed to upload file to S3: %w", err)
		}

		log.Infow("s3_upload_file_success", "filePath", filePath, "bucket", target.Bucket, "key", s3Key)
		return nil
	})
	if err != nil {
		log.Errorw("s3_upload_failure", "error", err)
		return fmt.Errorf("upload failed. It's possible that some files were uploaded; Error: %w", err)
	}

	log.Infow("s3_upload_success", "sourcePath", vol.Path, "bucket", target.Bucket)
	return nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package types
import (
"fmt"
"reflect"
"slices"
"github.com/hashicorp/go-version"
)
// Connectivity represents the network configuration: the set of ports that
// must be open and whether a VPN is required.
type Connectivity struct {
Ports []int `json:"ports" description:"Ports that need to be open for the job to run"`
VPN bool `json:"vpn" description:"Whether VPN is required"`
}
// Compile-time checks that *Connectivity implements the Comparable and
// Calculable interfaces.
var (
_ Comparable[Connectivity] = (*Connectivity)(nil)
_ Calculable[Connectivity] = (*Connectivity)(nil)
)
// Compare ranks c against other. Deeply equal values are Equal. c is Better
// when IsStrictlyContainedInt(c.Ports, other.Ports) holds and c has VPN
// enabled (other's VPN flag does not influence this branch). Anything else is None.
func (c *Connectivity) Compare(other Connectivity) (Comparison, error) {
	switch {
	case reflect.DeepEqual(*c, other):
		return Equal, nil
	case IsStrictlyContainedInt(c.Ports, other.Ports) && c.VPN:
		return Better, nil
	default:
		return None, nil
	}
}
// Add merges other into c: every port of other that c does not already list
// is appended exactly once (in other's order), and VPN becomes true if either
// side has it set. It always returns nil.
func (c *Connectivity) Add(other Connectivity) error {
	portSet := make(map[int]struct{}, len(c.Ports)+len(other.Ports))
	for _, port := range c.Ports {
		portSet[port] = struct{}{}
	}

	for _, port := range other.Ports {
		if _, exists := portSet[port]; exists {
			continue
		}
		// Record the port as seen so a repeated entry in other.Ports is not
		// appended a second time.
		portSet[port] = struct{}{}
		c.Ports = append(c.Ports, port)
	}

	// Set VPN to true if other has VPN
	c.VPN = c.VPN || other.VPN
	return nil
}
// Subtract removes other from c: VPN is cleared when other has VPN set, and
// every port listed in other is dropped from c.Ports. The filtering reuses the
// existing backing array; Ports becomes nil when nothing is left. Always returns nil.
func (c *Connectivity) Subtract(other Connectivity) error {
	c.VPN = c.VPN && !other.VPN

	// In-place filter: kept never grows past the position currently being read.
	kept := c.Ports[:0]
	for i := range c.Ports {
		if port := c.Ports[i]; !slices.Contains(other.Ports, port) {
			kept = append(kept, port)
		}
	}

	c.Ports = nil
	if len(kept) > 0 {
		c.Ports = kept
	}
	return nil
}
// TimeInformation represents the time constraints. MaxTime is expressed in
// Units; TotalTime recognises "seconds", "minutes", "hours" and "days".
type TimeInformation struct {
Units string `json:"units" description:"Time units"`
MaxTime int `json:"max_time" description:"Maximum time that job should run"`
Preference int `json:"preference" description:"Time preference"`
}
// Compile-time checks that *TimeInformation implements the Comparable and
// Calculable interfaces.
var (
_ Comparable[TimeInformation] = (*TimeInformation)(nil)
_ Calculable[TimeInformation] = (*TimeInformation)(nil)
)
// Add increases t.MaxTime by other.MaxTime. The raw numbers are summed without
// looking at Units; Units and Preference are left untouched. Always returns nil.
func (t *TimeInformation) Add(other TimeInformation) error {
	// Any future additive time fields belong here as well.
	t.MaxTime = t.MaxTime + other.MaxTime
	return nil
}
// Subtract decreases t.MaxTime by other.MaxTime. The raw numbers are used
// without looking at Units; Units and Preference are left untouched, and the
// result may become negative. Always returns nil.
func (t *TimeInformation) Subtract(other TimeInformation) error {
	// Any future subtractive time fields belong here as well.
	t.MaxTime = t.MaxTime - other.MaxTime
	return nil
}
// TotalTime converts MaxTime to seconds according to Units ("seconds",
// "minutes", "hours" or "days"). Any other unit string is treated like
// seconds, i.e. MaxTime is returned unchanged.
func (t *TimeInformation) TotalTime() int {
	secondsPerUnit := 1
	switch t.Units {
	case "minutes":
		secondsPerUnit = 60
	case "hours":
		secondsPerUnit = 60 * 60
	case "days":
		secondsPerUnit = 60 * 60 * 24
	}
	return t.MaxTime * secondsPerUnit
}
// Compare ranks t against other by total time in seconds (see TotalTime).
// Identical values and values with the same total time are Equal; a smaller
// total time is Worse and a larger one is Better. The error is always nil.
func (t *TimeInformation) Compare(other TimeInformation) (Comparison, error) {
	// Dereference the receiver: comparing the pointer itself against a value
	// can never be deeply equal, which made this fast path unreachable.
	if reflect.DeepEqual(*t, other) {
		return Equal, nil
	}

	ownTotalTime := t.TotalTime()
	otherTotalTime := other.TotalTime()

	switch {
	case ownTotalTime == otherTotalTime:
		return Equal, nil
	case ownTotalTime < otherTotalTime:
		return Worse, nil
	default:
		return Better, nil
	}
}
// PriceInformation represents the pricing information: a currency, an hourly
// rate, a per-job total and a preference value.
type PriceInformation struct {
Currency string `json:"currency" description:"Currency used for pricing"`
CurrencyPerHour int `json:"currency_per_hour" description:"Price charged per hour"`
TotalPerJob int `json:"total_per_job" description:"Maximum total price or budget of the job"`
Preference int `json:"preference" description:"Pricing preference"`
}
// Compile-time check that *PriceInformation implements the Comparable interface.
var _ Comparable[PriceInformation] = (*PriceInformation)(nil)
// Compare ranks p against other. Lower prices are better.
//
//   - identical values are Equal;
//   - different currencies cannot be ranked and yield None;
//   - with the same TotalPerJob, the hourly rate decides (equal, lower = Better,
//     higher = Worse);
//   - with a lower TotalPerJob, p is Better only if its hourly rate is not higher;
//   - with a higher TotalPerJob, p is Worse.
//
// Preference is not taken into account. The error is always nil.
func (p *PriceInformation) Compare(other PriceInformation) (Comparison, error) {
	// Dereference the receiver: a pointer is never deeply equal to a value, so
	// the original fast path could not trigger.
	if reflect.DeepEqual(*p, other) {
		return Equal, nil
	}

	if p.Currency != other.Currency {
		return None, nil
	}

	switch {
	case p.TotalPerJob == other.TotalPerJob:
		switch {
		case p.CurrencyPerHour == other.CurrencyPerHour:
			return Equal, nil
		case p.CurrencyPerHour < other.CurrencyPerHour:
			return Better, nil
		default:
			return Worse, nil
		}
	case p.TotalPerJob < other.TotalPerJob:
		if p.CurrencyPerHour <= other.CurrencyPerHour {
			return Better, nil
		}
		return Worse, nil
	default:
		return Worse, nil
	}
}
// Equal reports whether p and price agree on every field (Currency,
// CurrencyPerHour, TotalPerJob and Preference).
func (p *PriceInformation) Equal(price PriceInformation) bool {
	// All fields are comparable, so struct equality covers exactly these four checks.
	return *p == price
}
// Library represents a single library: its name, a version and the constraint
// operator that Compare combines with the version (e.g. "=").
type Library struct {
Name string `json:"name" description:"Name of the library"`
Constraint string `json:"constraint" description:"Constraint of the library"`
Version string `json:"version" description:"Version of the library"`
}
// Compile-time check that *Library implements the Comparable interface.
var _ Comparable[Library] = (*Library)(nil)
// Compare checks lib's version against the constraint expressed by other
// (other.Constraint followed by other.Version).
//
// Both the version and the constraint are parsed first; a parse failure is
// returned as an error even when the names differ. Libraries with different
// names yield None. Otherwise a satisfied "=" constraint yields Equal, any
// other satisfied constraint yields Better, and an unsatisfied one yields Worse.
func (lib *Library) Compare(other Library) (Comparison, error) {
	ownVersion, err := version.NewVersion(lib.Version)
	if err != nil {
		return None, fmt.Errorf("error parsing version: %v", err)
	}

	required, err := version.NewConstraint(other.Constraint + " " + other.Version)
	if err != nil {
		return None, fmt.Errorf("error parsing constraint: %v", err)
	}

	// Different libraries cannot be ranked against each other.
	if lib.Name != other.Name {
		return None, nil
	}

	satisfied := required.Check(ownVersion)
	switch {
	case !satisfied:
		return Worse, nil
	case other.Constraint == "=":
		return Equal, nil
	default:
		return Better, nil
	}
}
// Equal reports whether lib and library have the same Name, Constraint and Version.
func (lib *Library) Equal(library Library) bool {
	// The struct consists solely of these three strings, so value equality suffices.
	return *lib == library
}
// Locality represents a region, identified by the kind of region and its name.
type Locality struct {
Kind string `json:"kind" description:"Kind of the region (geographic, nunet-defined, etc)"`
Name string `json:"name" description:"Name of the region"`
}
// Compile-time check that *Locality implements the Comparable interface.
var _ Comparable[Locality] = (*Locality)(nil)
// Compare ranks loc against other: localities of different kinds are not
// comparable (None); within the same kind, matching names are Equal and
// differing names are Worse. The error is always nil.
func (loc *Locality) Compare(other Locality) (Comparison, error) {
	switch {
	case loc.Kind != other.Kind:
		return None, nil
	case loc.Name == other.Name:
		return Equal, nil
	default:
		return Worse, nil
	}
}
// Equal reports whether loc and locality share both Kind and Name.
func (loc *Locality) Equal(locality Locality) bool {
	// Kind and Name are the only fields, so value equality is the same check.
	return *loc == locality
}
// KYC represents one piece of KYC data: its type and the associated payload.
type KYC struct {
Type string `json:"type" description:"Type of KYC"`
Data string `json:"data" description:"Data required for KYC"`
}
// Compile-time check that *KYC implements the Comparable interface.
var _ Comparable[KYC] = (*KYC)(nil)
// Compare returns Equal when k and other are deeply equal and None otherwise;
// KYC entries are never ranked as Better or Worse. The error is always nil.
func (k *KYC) Compare(other KYC) (Comparison, error) {
	if !reflect.DeepEqual(*k, other) {
		return None, nil
	}
	return Equal, nil
}
// Equal reports whether k and kyc have the same Type and Data.
func (k *KYC) Equal(kyc KYC) bool {
	// Type and Data are the only fields, so value equality is the same check.
	return *k == kyc
}
// JobType represents the type of the job
type JobType string
// Known job types and their string values.
const (
Batch JobType = "batch"
SingleRun JobType = "single_run"
Recurring JobType = "recurring"
LongRunning JobType = "long_running"
)
// Compile-time check that *JobType implements the Comparable interface
// (only Comparable is asserted for JobType).
var _ Comparable[JobType] = (*JobType)(nil)
// Compare returns Equal when both job types are the same value and None
// otherwise; job types are never ranked. The error is always nil.
func (j JobType) Compare(other JobType) (Comparison, error) {
	if !reflect.DeepEqual(j, other) {
		return None, nil
	}
	return Equal, nil
}
// JobTypes is a slice of JobType.
type JobTypes []JobType
// Compile-time checks that *JobTypes implements the Comparable and
// Calculable interfaces.
var (
_ Comparable[JobTypes] = (*JobTypes)(nil)
_ Calculable[JobTypes] = (*JobTypes)(nil)
)
// Add appends every job type of other that is not yet present in j, each at
// most once. Existing entries (including any duplicates already in j) are
// kept as they are. The resulting slice has its capacity clipped to its
// length. Always returns nil.
func (j *JobTypes) Add(other JobTypes) error {
	merged := *j

	known := make(map[JobType]struct{}, len(merged))
	for i := range merged {
		known[merged[i]] = struct{}{}
	}

	for _, candidate := range other {
		if _, dup := known[candidate]; dup {
			continue
		}
		known[candidate] = struct{}{}
		merged = append(merged, candidate)
	}

	// Clip capacity so later appends by callers cannot write into shared storage.
	*j = merged[:len(merged):len(merged)]
	return nil
}
// Subtract removes from j every job type that occurs in other, filtering in
// place on j's backing array. The resulting slice has its capacity clipped to
// its length. Always returns nil.
func (j *JobTypes) Subtract(other JobTypes) error {
	drop := make(map[JobType]struct{}, len(other))
	for i := range other {
		drop[other[i]] = struct{}{}
	}

	current := *j
	kept := current[:0]
	for i := range current {
		if _, gone := drop[current[i]]; gone {
			continue
		}
		kept = append(kept, current[i])
	}

	*j = kept[:len(kept):len(kept)]
	return nil
}
// Compare ranks the available job types in j against the required ones in other.
//
// Both slices are converted to untyped slices first. If their shallow types
// differ the result is None. Deeply equal slices are Equal; if j strictly
// contains other the result is Better; if other strictly contains j the result
// is Worse; otherwise None. The error is always nil.
func (j *JobTypes) Compare(other JobTypes) (Comparison, error) {
	// The helpers operate on untyped slices, so convert both sides up front.
	available := ConvertTypedSliceToUntypedSlice(*j)
	required := ConvertTypedSliceToUntypedSlice(other)

	if !IsSameShallowType(available, required) {
		// cannot compare different types
		return None, nil
	}

	// Available and required capabilities match exactly.
	if reflect.DeepEqual(available, required) {
		return Equal, nil
	}

	// The machine offers everything that is required, and more.
	if IsStrictlyContained(available, required) {
		return Better, nil
	}

	// The requirement covers everything the machine offers, and more
	// ("available capabilities are worse than required"); equality was
	// already handled above.
	if IsStrictlyContained(required, available) {
		return Worse, nil
	}

	// TODO: this comparator does not take into account options when several job types are available and several job types are required
	// in the same data structure; this is why the test fails;
	return None, nil
}
// Contains reports whether jobType occurs in j.
func (j *JobTypes) Contains(jobType JobType) bool {
	entries := *j
	for i := range entries {
		if entries[i] == jobType {
			return true
		}
	}
	return false
}
// Libraries is a slice of Library.
type Libraries []Library
// Compile-time checks that *Libraries implements the Comparable and
// Calculable interfaces.
var (
_ Comparable[Libraries] = (*Libraries)(nil)
_ Calculable[Libraries] = (*Libraries)(nil)
)
// Add appends every library of other that is not yet present in l (matching on
// all fields), each at most once. Always returns nil.
func (l *Libraries) Add(other Libraries) error {
	seen := make(map[Library]struct{}, len(*l))
	for i := range *l {
		seen[(*l)[i]] = struct{}{}
	}

	for _, candidate := range other {
		if _, dup := seen[candidate]; dup {
			continue
		}
		seen[candidate] = struct{}{}
		*l = append(*l, candidate)
	}
	return nil
}
// Subtract removes from l every library that also occurs in other (matching on
// all fields). Filtering happens in place on l's backing array. Always returns nil.
func (l *Libraries) Subtract(other Libraries) error {
	drop := make(map[Library]struct{}, len(other))
	for i := range other {
		drop[other[i]] = struct{}{}
	}

	// Reuse the existing storage: kept never overtakes the read position.
	kept := (*l)[:0]
	for _, lib := range *l {
		if _, gone := drop[lib]; gone {
			continue
		}
		kept = append(kept, lib)
	}

	*l = kept
	return nil
}
// Compare ranks the libraries in l against the libraries in other.
//
// Every library of l is compared with every library of other; an error from
// any pairwise comparison aborts with that error. A best match is then picked
// per row of the resulting matrix. The overall result is Worse if any best
// match is Worse, Equal if SliceContainsOneValue(finalComparison, Equal) holds,
// and Better otherwise.
func (l *Libraries) Compare(other Libraries) (Comparison, error) {
interimComparison1 := make([][]Comparison, 0)
for _, otherLibrary := range other {
var interimComparison2 []Comparison
for _, ownLibrary := range *l {
c, err := ownLibrary.Compare(otherLibrary)
if err != nil {
return None, fmt.Errorf("error comparing library: %v", err)
}
interimComparison2 = append(interimComparison2, c)
}
// this matrix holds the comparison results of each own library against
// each library of `other`, in the order they appear in the slices:
// the first dimension indexes the libraries of `other`,
// the second dimension indexes the libraries of `l`
interimComparison1 = append(interimComparison1, interimComparison2)
}
// pick, for each row (one library of `other`), the best comparison result
finalComparison := make([]Comparison, 0)
for i := 0; i < len(interimComparison1); i++ {
// NOTE(review): this stops once a row has fewer than i entries; the intent
// of comparing the row length with the row index is unclear — confirm.
if len(interimComparison1[i]) < i {
break
}
c := interimComparison1[i]
bestMatch, index := returnBestMatch(c)
finalComparison = append(finalComparison, bestMatch)
// NOTE(review): index comes from a row (own-library position) but is used
// to remove from the row matrix, which is also being iterated by i and
// re-measured by len() each pass — verify against removeIndex's semantics.
interimComparison1 = removeIndex(interimComparison1, index)
}
if slices.Contains(finalComparison, Worse) {
return Worse, nil
}
if SliceContainsOneValue(finalComparison, Equal) {
return Equal, nil
}
return Better, nil
}
// Contains reports whether l holds a library equal to library (per Library.Equal).
func (l *Libraries) Contains(library Library) bool {
	entries := *l
	for i := range entries {
		if entries[i].Equal(library) {
			return true
		}
	}
	return false
}
// Localities is a slice of Locality.
type Localities []Locality
// Compile-time checks that *Localities implements the Comparable and
// Calculable interfaces.
var (
_ Comparable[Localities] = (*Localities)(nil)
_ Calculable[Localities] = (*Localities)(nil)
)
// Compare ranks the localities in l against those in other.
//
// For each locality of other, the outcome starts as None and is replaced by
// the comparison with every own locality of the same Kind (the last such own
// locality wins). The overall result is Worse if any outcome is Worse, Equal
// if SliceContainsOneValue(outcomes, Equal) holds, and Better otherwise.
func (l *Localities) Compare(other Localities) (Comparison, error) {
	var outcomes []Comparison

	for _, required := range other {
		// None is recorded when no own locality has the required kind, so
		// every required locality contributes exactly one outcome.
		var outcome Comparison = None
		for _, own := range *l {
			if own.Kind != required.Kind {
				continue
			}
			c, err := own.Compare(required)
			if err != nil {
				return None, fmt.Errorf("error comparing locality: %v", err)
			}
			outcome = c
		}
		outcomes = append(outcomes, outcome)
	}

	switch {
	case slices.Contains(outcomes, Worse):
		return Worse, nil
	case SliceContainsOneValue(outcomes, Equal):
		return Equal, nil
	default:
		return Better, nil
	}
}
// Add appends every locality of other that is not yet present in l, each at
// most once. Always returns nil.
func (l *Localities) Add(other Localities) error {
	seen := make(map[Locality]struct{}, len(*l))
	for i := range *l {
		seen[(*l)[i]] = struct{}{}
	}

	for _, candidate := range other {
		if _, dup := seen[candidate]; dup {
			continue
		}
		seen[candidate] = struct{}{}
		*l = append(*l, candidate)
	}
	return nil
}
// Subtract removes from l every locality that also occurs in other. Filtering
// happens in place on l's backing array. Always returns nil.
func (l *Localities) Subtract(other Localities) error {
	drop := make(map[Locality]struct{}, len(other))
	for i := range other {
		drop[other[i]] = struct{}{}
	}

	// Reuse the existing storage: kept never overtakes the read position.
	kept := (*l)[:0]
	for _, loc := range *l {
		if _, gone := drop[loc]; gone {
			continue
		}
		kept = append(kept, loc)
	}

	*l = kept
	return nil
}
// Contains reports whether l holds a locality equal to locality (per Locality.Equal).
func (l *Localities) Contains(locality Locality) bool {
	entries := *l
	for i := range entries {
		if entries[i].Equal(locality) {
			return true
		}
	}
	return false
}
// KYCs is a slice of KYC.
type KYCs []KYC
// Compile-time checks that *KYCs implements the Comparable and
// Calculable interfaces.
var (
_ Comparable[KYCs] = (*KYCs)(nil)
_ Calculable[KYCs] = (*KYCs)(nil)
)
// Compare ranks the KYC entries in k against those in other.
//
// Deeply equal slices are Equal. When other is empty and k is not, k is
// Better. Otherwise the pairs (own, other) are visited in order and the first
// pairwise result that is not None is returned; if there is none, the result
// is None.
func (k *KYCs) Compare(other KYCs) (Comparison, error) {
	switch {
	case reflect.DeepEqual(*k, other):
		return Equal, nil
	case len(other) == 0 && len(*k) != 0:
		return Better, nil
	}

	for _, own := range *k {
		for _, required := range other {
			result, err := own.Compare(required)
			if err != nil {
				return None, fmt.Errorf("error comparing KYC: %v", err)
			}
			if result == None {
				continue
			}
			return result, nil
		}
	}

	return None, nil
}
// Add appends every KYC entry of other that is not yet present in k, each at
// most once. Always returns nil.
func (k *KYCs) Add(other KYCs) error {
	seen := make(map[KYC]struct{}, len(*k))
	for i := range *k {
		seen[(*k)[i]] = struct{}{}
	}

	for _, candidate := range other {
		if _, dup := seen[candidate]; dup {
			continue
		}
		seen[candidate] = struct{}{}
		*k = append(*k, candidate)
	}
	return nil
}
// Subtract removes from k every KYC entry that also occurs in other. Filtering
// happens in place on k's backing array. Always returns nil.
func (k *KYCs) Subtract(other KYCs) error {
	drop := make(map[KYC]struct{}, len(other))
	for i := range other {
		drop[other[i]] = struct{}{}
	}

	// Reuse the existing storage: kept never overtakes the read position.
	kept := (*k)[:0]
	for _, entry := range *k {
		if _, gone := drop[entry]; gone {
			continue
		}
		kept = append(kept, entry)
	}

	*k = kept
	return nil
}
// Contains reports whether k holds an entry equal to kyc (per KYC.Equal).
func (k *KYCs) Contains(kyc KYC) bool {
	entries := *k
	for i := range entries {
		if entries[i].Equal(kyc) {
			return true
		}
	}
	return false
}
// PricesInformation is a slice of PriceInformation.
type PricesInformation []PriceInformation
// Compile-time checks that *PricesInformation implements the Comparable and
// Calculable interfaces.
var (
_ Comparable[PricesInformation] = (*PricesInformation)(nil)
_ Calculable[PricesInformation] = (*PricesInformation)(nil)
)
// Add appends every price of other that is not yet present in ps (matching on
// all fields), each at most once. Always returns nil.
func (ps *PricesInformation) Add(other PricesInformation) error {
	seen := make(map[PriceInformation]struct{}, len(*ps))
	for i := range *ps {
		seen[(*ps)[i]] = struct{}{}
	}

	for _, candidate := range other {
		if _, dup := seen[candidate]; dup {
			continue
		}
		seen[candidate] = struct{}{}
		*ps = append(*ps, candidate)
	}
	return nil
}
// Subtract removes from ps every price that also occurs in other (matching on
// all fields). With an empty other, ps is left completely untouched; otherwise
// filtering happens in place and the resulting slice has its capacity clipped
// to its length. Always returns nil.
func (ps *PricesInformation) Subtract(other PricesInformation) error {
	if len(other) == 0 {
		// Nothing to remove.
		return nil
	}

	drop := make(map[PriceInformation]struct{}, len(other))
	for i := range other {
		drop[other[i]] = struct{}{}
	}

	// Reuse the existing storage: kept never overtakes the read position.
	kept := (*ps)[:0]
	for _, price := range *ps {
		if _, gone := drop[price]; gone {
			continue
		}
		kept = append(kept, price)
	}

	*ps = kept[:len(kept):len(kept)]
	return nil
}
// Compare ranks the prices in ps against those in other.
//
// Deeply equal slices are Equal. Otherwise the pairs (own, other) are visited
// in order and the first pairwise result that is not None is returned; if
// every pair yields None (or either slice is empty), the result is None.
func (ps *PricesInformation) Compare(other PricesInformation) (Comparison, error) {
	if reflect.DeepEqual(*ps, other) {
		return Equal, nil
	}

	for _, own := range *ps {
		for _, offered := range other {
			result, err := own.Compare(offered)
			if err != nil {
				return None, fmt.Errorf("error comparing price: %v", err)
			}
			if result == None {
				continue
			}
			return result, nil
		}
	}

	return None, nil
}
// Contains reports whether ps holds a price equal to price (per PriceInformation.Equal).
func (ps *PricesInformation) Contains(price PriceInformation) bool {
	entries := *ps
	for i := range entries {
		if entries[i].Equal(price) {
			return true
		}
	}
	return false
}
// HardwareCapability represents the hardware capability of the machine.
// Compare, Add and Subtract operate on it field by field, delegating to the
// corresponding method of each field's type.
type HardwareCapability struct {
Executors Executors `json:"executor" description:"Executor type required for the job (docker, vm, wasm, or others)"`
JobTypes JobTypes `json:"type" description:"Details about type of the job (One time, batch, recurring, long running). Refer to dms.jobs package for jobType data model"`
Resources Resources `json:"resources" description:"Resources required for the job"`
Libraries Libraries `json:"libraries" description:"Libraries required for the job"`
Localities Localities `json:"locality" description:"Preferred localities of the machine for execution"`
Connectivity Connectivity `json:"connectivity" description:"Network configuration required"`
Price PricesInformation `json:"price" description:"Pricing information"`
Time TimeInformation `json:"time" description:"Time constraints"`
// KYCs carries no struct tag, unlike the fields above.
KYCs KYCs
}
// Compile-time checks that *HardwareCapability implements the Comparable and
// Calculable interfaces.
var (
_ Comparable[HardwareCapability] = (*HardwareCapability)(nil)
_ Calculable[HardwareCapability] = (*HardwareCapability)(nil)
)
// Compare compares two HardwareCapability objects field by field, in the order
// Executors, JobTypes, Resources, Libraries, Localities, Connectivity, Price,
// Time, KYCs. Each field's result is recorded under the field name in a
// ComplexComparison whose Result() is returned. The first field comparison
// that fails aborts with an error naming that field.
func (c *HardwareCapability) Compare(other HardwareCapability) (Comparison, error) {
	results := make(ComplexComparison)

	// record stores one field's outcome, or turns its error into the
	// field-specific error message.
	record := func(field string, outcome Comparison, err error) error {
		if err != nil {
			return fmt.Errorf("error comparing %s: %v", field, err)
		}
		results[field] = outcome
		return nil
	}

	// Steps run lazily and in order, so a failure stops later comparisons.
	steps := []func() error{
		func() error {
			r, err := c.Executors.Compare(other.Executors)
			return record("Executors", r, err)
		},
		func() error {
			r, err := c.JobTypes.Compare(other.JobTypes)
			return record("JobTypes", r, err)
		},
		func() error {
			r, err := c.Resources.Compare(other.Resources)
			return record("Resources", r, err)
		},
		func() error {
			r, err := c.Libraries.Compare(other.Libraries)
			return record("Libraries", r, err)
		},
		func() error {
			r, err := c.Localities.Compare(other.Localities)
			return record("Localities", r, err)
		},
		func() error {
			r, err := c.Connectivity.Compare(other.Connectivity)
			return record("Connectivity", r, err)
		},
		func() error {
			r, err := c.Price.Compare(other.Price)
			return record("Price", r, err)
		},
		func() error {
			r, err := c.Time.Compare(other.Time)
			return record("Time", r, err)
		},
		func() error {
			r, err := c.KYCs.Compare(other.KYCs)
			return record("KYCs", r, err)
		},
	}

	for _, step := range steps {
		if err := step(); err != nil {
			return None, err
		}
	}

	return results.Result(), nil
}
// Add adds the resources of the given HardwareCapability to the current
// HardwareCapability, field by field, in the order Executors, JobTypes,
// Resources, Libraries, Localities, Connectivity, Price, Time, KYCs.
//
// The first field that fails aborts the operation; fields processed before it
// have already been modified. Every error names the failing field and wraps
// the underlying cause.
func (c *HardwareCapability) Add(other HardwareCapability) error {
	// Executors
	if err := c.Executors.Add(other.Executors); err != nil {
		// Include the cause, which was previously dropped for this field.
		return fmt.Errorf("error adding Executors: %w", err)
	}

	// JobTypes
	if err := c.JobTypes.Add(other.JobTypes); err != nil {
		return fmt.Errorf("error adding JobTypes: %w", err)
	}

	// Resources
	if err := c.Resources.Add(other.Resources); err != nil {
		// Add the same field context the other fields get.
		return fmt.Errorf("error adding Resources: %w", err)
	}

	// Libraries
	if err := c.Libraries.Add(other.Libraries); err != nil {
		return fmt.Errorf("error adding Libraries: %w", err)
	}

	// Localities
	if err := c.Localities.Add(other.Localities); err != nil {
		return fmt.Errorf("error adding Localities: %w", err)
	}

	// Connectivity
	if err := c.Connectivity.Add(other.Connectivity); err != nil {
		return fmt.Errorf("error adding Connectivity: %w", err)
	}

	// Price
	if err := c.Price.Add(other.Price); err != nil {
		return fmt.Errorf("error adding Price: %w", err)
	}

	// Time
	if err := c.Time.Add(other.Time); err != nil {
		return fmt.Errorf("error adding Time: %w", err)
	}

	// KYCs
	if err := c.KYCs.Add(other.KYCs); err != nil {
		return fmt.Errorf("error adding KYCs: %w", err)
	}

	return nil
}
// Subtract removes every capability dimension of cap from c, in place.
//
// Dimensions are processed in a fixed order: Executors, JobTypes, Resources,
// Libraries, Localities, Connectivity, Price, Time, KYCs. Processing stops at
// the first dimension whose Subtract fails, so c may already be partially
// updated when an error is returned. Every returned error names the failing
// dimension and includes the underlying cause.
func (c *HardwareCapability) Subtract(cap HardwareCapability) error {
	// Executors
	if err := c.Executors.Subtract(cap.Executors); err != nil {
		return fmt.Errorf("error subtracting Executors: %v", err)
	}

	// JobTypes
	if err := c.JobTypes.Subtract(cap.JobTypes); err != nil {
		// This is a subtraction failure, so it is reported as one (it used to
		// be mislabelled as a comparison error).
		return fmt.Errorf("error subtracting JobTypes: %v", err)
	}

	// Resources
	if err := c.Resources.Subtract(cap.Resources); err != nil {
		return fmt.Errorf("error subtracting Resources: %v", err)
	}

	// Libraries
	if err := c.Libraries.Subtract(cap.Libraries); err != nil {
		return fmt.Errorf("error subtracting Libraries: %v", err)
	}

	// Localities
	if err := c.Localities.Subtract(cap.Localities); err != nil {
		return fmt.Errorf("error subtracting Localities: %v", err)
	}

	// Connectivity
	if err := c.Connectivity.Subtract(cap.Connectivity); err != nil {
		return fmt.Errorf("error subtracting Connectivity: %v", err)
	}

	// Price
	if err := c.Price.Subtract(cap.Price); err != nil {
		return fmt.Errorf("error subtracting Price: %v", err)
	}

	// Time
	if err := c.Time.Subtract(cap.Time); err != nil {
		return fmt.Errorf("error subtracting Time: %v", err)
	}

	// KYCs
	if err := c.KYCs.Subtract(cap.KYCs); err != nil {
		return fmt.Errorf("error subtracting KYCs: %v", err)
	}

	return nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package types
import "reflect"
// Comparable is the interface enforced on types that can be compared.
// Compare relates the receiver (left side) to other (right side) and
// returns the resulting Comparison, or an error if no comparison could be made.
type Comparable[T any] interface {
Compare(other T) (Comparison, error)
}
// Calculable is the interface enforced on types that support in-place
// arithmetic with another value of type T.
type Calculable[T any] interface {
// Add combines other into the receiver.
Add(other T) error
// Subtract removes other from the receiver.
Subtract(other T) error
}
// PreferenceString names the strength of a Preference.
type PreferenceString string
// Known preference strengths.
// NOTE(review): how Hard and Soft influence a comparison is not visible in
// this file; confirm against the comparators that consume Preference.
const (
Hard PreferenceString = "Hard"
Soft PreferenceString = "Soft"
)
// Preference carries optional settings that can be passed to a Comparator.
type Preference struct {
// TypeName presumably identifies the type the preference applies to; verify against callers.
TypeName string
// Strength is the preference strength (Hard or Soft).
Strength PreferenceString
// DefaultComparatorOverride is a comparator supplied in place of the default one.
// NOTE(review): where the override is honoured is not shown here.
DefaultComparatorOverride Comparator
}
// Comparator is the signature of a function that compares a left value l with
// a right value r, optionally guided by preferences, and returns a Comparison.
type Comparator func(l, r interface{}, preference ...Preference) Comparison
// ComplexCompare is a helper that builds a per-field comparison of two complex
// values using reflection (which could become a performance bottleneck).
//
// For every field i of l it takes field i of r, looks up a method named
// "Compare" on a pointer to the left field, and, when that method takes exactly
// one argument of the field's own type, calls it with the right field. The
// first return value of the call is stored under the field's name. Fields
// without a suitable Compare method are recorded as None.
//
// NOTE(review): only the first return value of Compare is used; any second
// (error) return value is ignored.
// NOTE(review): the code assumes l and r are struct values of the same type;
// reflect.Value.NumField and Field panic for non-struct values, and Field(i)
// on r panics if r has fewer fields. Unexported fields are presumably also a
// problem for Set/Call; verify before passing such types.
func ComplexCompare(l, r interface{}) ComplexComparison {
// Complex comparison is a comparison of two complex types
// which have nested fields that need to be considered together
// before a final comparison for the whole complex type can be made.
// It is a helper function used in some type comparators.
complexComparison := make(map[string]Comparison)
val1 := reflect.ValueOf(l)
val2 := reflect.ValueOf(r)
for i := 0; i < val1.NumField(); i++ {
field1 := val1.Field(i)
field2 := val2.Field(i)
// We need a pointer to field1 so that Compare methods declared on a
// pointer receiver can be found. If the field is not addressable we
// copy it into a freshly allocated value and use a pointer to that copy.
var field1Ptr reflect.Value
if field1.CanAddr() {
field1Ptr = field1.Addr()
} else {
ptr := reflect.New(field1.Type())
ptr.Elem().Set(field1)
field1Ptr = ptr
}
compareMethod := field1Ptr.MethodByName("Compare")
// Only call Compare when it exists and accepts exactly one argument of
// the field's own type; anything else is recorded as None.
if compareMethod.IsValid() &&
compareMethod.Type().NumIn() == 1 &&
compareMethod.Type().In(0) == field1.Type() {
result := compareMethod.Call([]reflect.Value{field2})[0].Interface().(Comparison)
complexComparison[val1.Type().Field(i).Name] = result
} else {
complexComparison[val1.Type().Field(i).Name] = None
}
}
return complexComparison
}
// NumericComparator compares two numbers of the same numeric type.
//
// l represents the machine capabilities and r the required capabilities.
// It returns Equal when both are the same, Worse when l is smaller than r,
// and Better in every other case. Preferences are accepted but ignored.
func NumericComparator[T float64 | float32 | int | int32 | int64 | uint64](l, r T, _ ...Preference) Comparison {
	if l == r {
		return Equal
	}
	if l < r {
		return Worse
	}
	return Better
}
// LiteralComparator compares two literal (string-like) values.
//
// l represents the machine capabilities and r the required capabilities.
// Literals can only match or not match: the result is Equal when both values
// are identical and None otherwise. Preferences are accepted but ignored.
func LiteralComparator[T ~string](l, r T, _ ...Preference) Comparison {
	if l != r {
		return None
	}
	return Equal
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package types
// Comparison is the result of comparing a left object with a right object.
type Comparison string
// The possible comparison results.
const (
// Worse means left object is 'worse' than right object
Worse Comparison = "Worse"
// Better means left object is 'better' than right object
Better Comparison = "Better"
// Equal means objects on the left and right are 'equally good'
Equal Comparison = "Equal"
// None means comparison could not be performed
None Comparison = "None"
)
// And returns the result of the AND operation of two Comparison values.
// It respects the following truth table:
//
//	| AND    | Better | Worse  | Equal  | None   |
//	| ------ | ------ | ------ | ------ | ------ |
//	| Better | Better | Worse  | Better | None   |
//	| Worse  | Worse  | Worse  | Worse  | None   |
//	| Equal  | Better | Worse  | Equal  | None   |
//	| None   | None   | None   | None   | None   |
//
// The operation is commutative and associative, so the result of folding a
// set of comparisons does not depend on the order in which they are combined.
// Two identical values always yield that value. Otherwise, any operand that is
// not one of Better, Worse or Equal makes the result None.
func (c Comparison) And(cmp Comparison) Comparison {
	if c == cmp {
		return c
	}

	isRanked := func(v Comparison) bool {
		return v == Better || v == Worse || v == Equal
	}

	switch {
	case !isRanked(c) || !isRanked(cmp):
		// None (or an unrecognised value) on either side dominates everything,
		// including Worse, as the table prescribes.
		return None
	case c == Worse || cmp == Worse:
		return Worse
	default:
		// The only remaining distinct pair is Better and Equal.
		return Better
	}
}
// ComplexComparison maps a field name to the Comparison obtained for it.
type ComplexComparison map[string]Comparison

// Result folds all stored comparisons into one value using Comparison.And,
// starting from Equal. An empty ComplexComparison therefore yields Equal.
func (c *ComplexComparison) Result() Comparison {
	entries := *c
	verdict := Equal
	for field := range entries {
		verdict = verdict.And(entries[field])
	}
	return verdict
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package types
// ExecutionStatus is the status of an execution, stored as a lowercase string.
type ExecutionStatus string
// The execution statuses defined by this package.
const (
ExecutionStatusPending ExecutionStatus = "pending"
ExecutionStatusRunning ExecutionStatus = "running"
ExecutionStatusPaused ExecutionStatus = "paused"
ExecutionStatusFailed ExecutionStatus = "failed"
ExecutionStatusSuccess ExecutionStatus = "success"
)
// ExecutionRequest is the request object for executing a job. It bundles the
// job and execution identifiers with the engine spec, the resources, the input
// and output volumes, the results directory and any provisioning scripts.
type ExecutionRequest struct {
JobID string // ID of the job to execute
ExecutionID string // ID of the execution
EngineSpec *SpecConfig // Engine spec for the execution
Resources *Resources // Resources for the execution
Inputs []*StorageVolumeExecutor // Input volumes for the execution
Outputs []*StorageVolumeExecutor // Output volumes for the results
ResultsDir string // Directory to store the results
ProvisionScripts map[string][]byte // (named) Scripts to run when initiating the execution
}
// ExecutionListItem is one entry in a listing of the current executions.
type ExecutionListItem struct {
ExecutionID string // ID of the execution
Running bool // presumably true while the execution is running; verify against the producer
}
// ExecutionResult is the result of an execution: its captured output streams,
// its exit code and, for failed executions, an error message.
type ExecutionResult struct {
STDOUT string `json:"stdout"` // STDOUT of the execution
STDERR string `json:"stderr"` // STDERR of the execution
ExitCode int `json:"exit_code"` // Exit code of the execution
ErrorMsg string `json:"error_msg"` // Error message if the execution failed
}
// NewExecutionResult returns a fresh ExecutionResult whose exit code is code.
// STDOUT, STDERR and ErrorMsg are left empty.
func NewExecutionResult(code int) *ExecutionResult {
	result := new(ExecutionResult)
	result.ExitCode = code
	return result
}
// NewFailedExecutionResult creates a new ExecutionResult for a failed execution.
// The exit code is set to -1 and the error message is taken from err.
// A nil err is tolerated: the result then carries an empty error message
// instead of the call panicking.
func NewFailedExecutionResult(err error) *ExecutionResult {
	result := &ExecutionResult{ExitCode: -1}
	if err != nil {
		result.ErrorMsg = err.Error()
	}
	return result
}
// LogStreamRequest is the request object for streaming logs from an execution.
// It identifies the execution and carries the Tail and Follow options.
type LogStreamRequest struct {
JobID string // ID of the job
ExecutionID string // ID of the execution
Tail bool // Tail the logs
Follow bool // Follow the logs
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package types
import (
"reflect"
)
// ExecutorType is the type of the executor, identified by a lowercase name.
type ExecutorType string
// Known executor types, plus the exit status code that denotes success.
const (
ExecutorTypeDocker ExecutorType = "docker"
ExecutorTypeFirecracker ExecutorType = "firecracker"
ExecutorTypeWasm ExecutorType = "wasm"
// ExecutionStatusCodeSuccess is an untyped constant, not an ExecutorType.
ExecutionStatusCodeSuccess = 0
)
// Compile-time assertion that *ExecutorType implements Comparable.
var _ Comparable[ExecutorType] = (*ExecutorType)(nil)
// Compare relates two executor types as literals via LiteralComparator:
// the result is Equal when both are the same type and None otherwise.
// The returned error is always nil.
func (e ExecutorType) Compare(other ExecutorType) (Comparison, error) {
	// ExecutorType is a string type, so it can be handed to the generic
	// literal comparator directly.
	return LiteralComparator(e, other), nil
}
// String returns the executor type as a plain string, which makes
// ExecutorType usable wherever a fmt.Stringer is expected.
func (e ExecutorType) String() string {
return string(e)
}
// Executor describes a single executor; at present it consists only of the
// executor's type.
type Executor struct {
ExecutorType ExecutorType `json:"executor_type"`
}
// Compile-time assertion that *Executor implements Comparable.
var _ Comparable[Executor] = (*Executor)(nil)
// Compare compares two Executor objects by delegating to the comparison of
// their ExecutorType fields and returns that result unchanged.
func (e *Executor) Compare(other Executor) (Comparison, error) {
// comparator for Executor types
// it is needed because executor type is defined as enum of ExecutorType's in types.execution.go
// left represents machine capabilities
// right represents required capabilities
// it is not so complex as the type has only one field
// therefore this method just passes it through...
return e.ExecutorType.Compare(other.ExecutorType)
}
// Equal reports whether e and executor have the same ExecutorType, which is
// the only field taken into account.
func (e *Executor) Equal(executor Executor) bool {
return e.ExecutorType == executor.ExecutorType
}
// Executors is a list of Executor objects
type Executors []Executor
// Compile-time assertions that *Executors implements Comparable and Calculable.
var (
_ Comparable[Executors] = (*Executors)(nil)
_ Calculable[Executors] = (*Executors)(nil)
)
// Add appends every executor in other to the receiver's list.
// Duplicates are not filtered out. The returned error is always nil.
func (e *Executors) Add(other Executors) error {
	merged := append(*e, other...)
	*e = merged
	return nil
}
// Subtract removes from the receiver every executor whose type occurs in
// other. All occurrences of such a type are removed. The remaining executors
// keep their relative order and are compacted into the receiver's existing
// backing array; the resulting slice has its capacity clipped to its length.
// Subtracting an empty list is a no-op. The returned error is always nil.
func (e *Executors) Subtract(other Executors) error {
	if len(other) == 0 {
		return nil
	}

	excluded := make(map[ExecutorType]struct{})
	for i := range other {
		excluded[other[i].ExecutorType] = struct{}{}
	}

	current := *e
	kept := current[:0]
	for i := range current {
		if _, drop := excluded[current[i].ExecutorType]; drop {
			continue
		}
		kept = append(kept, current[i])
	}

	*e = kept[:len(kept):len(kept)]
	return nil
}
// Contains reports whether the list holds an executor that is Equal to the
// given executor.
func (e *Executors) Contains(executor Executor) bool {
	list := *e
	for i := 0; i < len(list); i++ {
		candidate := list[i]
		if candidate.Equal(executor) {
			return true
		}
	}
	return false
}
// Compare relates the receiver (machine capabilities, left) to other
// (required capabilities, right).
//
// The outcome is determined as follows, in this order:
//   - Equal when both lists are deeply equal;
//   - None when IsSameShallowType rejects the two lists;
//   - Equal when the lists are deeply equal after boxing their elements;
//   - Better when IsStrictlyContained(left, right) holds, i.e. the machine
//     capabilities contain all the required capabilities;
//   - Worse when IsStrictlyContained(right, left) holds, i.e. the required
//     capabilities contain all the machine capabilities;
//   - None otherwise.
//
// The returned error is always nil.
func (e *Executors) Compare(other Executors) (Comparison, error) {
	if reflect.DeepEqual(*e, other) {
		return Equal, nil
	}

	// The containment helpers operate on []interface{}, so box both lists.
	box := func(list Executors) []interface{} {
		boxed := make([]interface{}, 0, len(list))
		for i := range list {
			boxed = append(boxed, list[i])
		}
		return boxed
	}
	available, required := box(*e), box(other)

	if !IsSameShallowType(available, required) {
		return None, nil
	}
	if reflect.DeepEqual(available, required) {
		return Equal, nil
	}
	if IsStrictlyContained(available, required) {
		return Better, nil
	}
	if IsStrictlyContained(required, available) {
		// available capabilities are worse than required
		return Worse, nil
	}
	return None, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package types
import (
"fmt"
"slices"
"strings"
)
// HardwareManager defines the interface for managing machine resources.
type HardwareManager interface {
// GetMachineResources returns the machine's resources as MachineResources.
GetMachineResources() (MachineResources, error)
// GetUsage returns the resources currently in use.
GetUsage() (Resources, error)
// GetFreeResources returns the resources that are currently free.
GetFreeResources() (Resources, error)
}
// GPUVendor names the maker of a GPU.
type GPUVendor string
// The GPU vendors known to this package.
const (
GPUVendorNvidia GPUVendor = "NVIDIA"
GPUVendorAMDATI GPUVendor = "AMD/ATI"
GPUVendorIntel GPUVendor = "Intel"
GPUVendorUnknown GPUVendor = "Unknown"
GPUVendorNone GPUVendor = "None"
)
// Compile-time assertion that *GPUVendor implements Comparable.
var _ Comparable[GPUVendor] = (*GPUVendor)(nil)
// Compare reports Equal when both vendors are identical and None otherwise;
// vendors are never ranked against each other. The error is always nil.
func (g GPUVendor) Compare(other GPUVendor) (Comparison, error) {
	if g != other {
		return None, nil
	}
	return Equal, nil
}
// ParseGPUVendor parses the GPU vendor string and returns the corresponding
// GPUVendor value. Matching is case-insensitive and checked in this order:
//
//   - "NVIDIA" anywhere in the string yields GPUVendorNvidia;
//   - "AMD" anywhere in the string, or "ATI" as a standalone word, yields
//     GPUVendorAMDATI;
//   - "INTEL" anywhere in the string yields GPUVendorIntel;
//   - anything else yields GPUVendorUnknown.
//
// "ATI" is matched as a whole word only, because as a plain substring it also
// occurs in words such as "Corporation" and would misclassify strings like
// "Intel Corporation".
func ParseGPUVendor(vendor string) GPUVendor {
	upper := strings.ToUpper(vendor)

	// Split into alphanumeric words so that short markers can be matched exactly.
	words := strings.FieldsFunc(upper, func(r rune) bool {
		isLetter := r >= 'A' && r <= 'Z'
		isDigit := r >= '0' && r <= '9'
		return !isLetter && !isDigit
	})
	hasWord := func(target string) bool {
		for _, word := range words {
			if word == target {
				return true
			}
		}
		return false
	}

	switch {
	case strings.Contains(upper, "NVIDIA"):
		return GPUVendorNvidia
	case strings.Contains(upper, "AMD") || hasWord("ATI"):
		return GPUVendorAMDATI
	case strings.Contains(upper, "INTEL"):
		return GPUVendorIntel
	default:
		return GPUVendorUnknown
	}
}
// GPU represents the GPU information: where the device sits in the system,
// who made it, which model it is and how much VRAM it has.
type GPU struct {
// Index is the self-reported index of the device in the system
Index int `json:"index" description:"GPU index in the system"`
// Vendor is the maker of the GPU, e.g. NVidia, AMD, Intel
Vendor GPUVendor `json:"vendor" description:"GPU vendor, e.g., NVidia, AMD, Intel"`
// PCIAddress is the PCI address of the device, in the format AAAA:BB:CC.C
// Used to discover the correct device rendering cards
PCIAddress string `json:"pci_address" description:"PCI address of the device, in the format AAAA:BB:CC.C"`
// Model represents the GPU model name, e.g., "Tesla T4", "A100"
Model string `json:"model" description:"GPU model, e.g., Tesla T4, A100"`
// VRAM is the total amount of VRAM on the device
// NOTE(review): the unit of VRAM is not stated anywhere in this file; confirm with the producer.
VRAM float64 `json:"vram" description:"Total amount of VRAM on the device"`
// Gorm fields
// Team, is this the right way to do this? What is the best practice we're following?
ResourceID string `json:"resource_id" gorm:"foreignKey:ID"`
}
// Compile-time assertions that *GPU implements Comparable and Calculable.
var (
_ Comparable[GPU] = (*GPU)(nil)
_ Calculable[GPU] = (*GPU)(nil)
)
// Compare ranks g against other by VRAM only: Better when g has more VRAM,
// Worse when it has less, and Equal otherwise. The error is always nil.
//
// This is deliberately a very simple comparison, based on the assumption that
// the same amount of VRAM or more is acceptable, but nothing less. A more
// elaborate ranking would need specific hardware knowledge, e.g. benchmarking
// data for the individual GPU models.
func (g *GPU) Compare(other GPU) (Comparison, error) {
	if g.VRAM > other.VRAM {
		return Better, nil
	}
	if g.VRAM < other.VRAM {
		return Worse, nil
	}
	return Equal, nil
}
// Add increases g's VRAM by other's VRAM. No other field is touched and the
// returned error is always nil.
func (g *GPU) Add(other GPU) error {
	g.VRAM = g.VRAM + other.VRAM
	return nil
}
// Subtract decreases g's VRAM by other's VRAM. No other field is touched.
// If other has more VRAM than g, g is left unchanged and an underflow error
// is returned that states the amount to subtract first and the available
// amount second.
func (g *GPU) Subtract(other GPU) error {
	if g.VRAM < other.VRAM {
		return fmt.Errorf("total VRAM: underflow, cannot subtract %v from %v", other.VRAM, g.VRAM)
	}

	g.VRAM -= other.VRAM
	return nil
}
// Equal reports whether g and other agree on Model, VRAM, Index, Vendor and
// PCIAddress. ResourceID is not part of the comparison.
func (g *GPU) Equal(other GPU) bool {
	switch {
	case g.Model != other.Model,
		g.VRAM != other.VRAM,
		g.Index != other.Index,
		g.Vendor != other.Vendor,
		g.PCIAddress != other.PCIAddress:
		return false
	}
	return true
}
// GPUs is a list of GPU objects.
type GPUs []GPU
// Compile-time assertions that *GPUs implements Calculable and Comparable.
var (
_ Calculable[GPUs] = (*GPUs)(nil)
_ Comparable[GPUs] = (*GPUs)(nil)
)
// Compare relates the receiver's GPUs (left) to the GPUs in other (right).
//
// It first builds a matrix of pairwise comparisons with one row per GPU in
// other and one column per GPU in the receiver, then picks one result per row
// with returnBestMatch and reduces the picked results:
//   - Worse if any picked result is Worse;
//   - Equal if SliceContainsOneValue(picked, Equal) holds;
//   - Better otherwise.
//
// An error is returned only when comparing an individual pair of GPUs fails.
//
// NOTE(review): returnBestMatch, removeIndex and SliceContainsOneValue are
// defined elsewhere; their exact semantics (and their behaviour for empty
// input) are not visible here and should be confirmed.
func (gpus GPUs) Compare(other GPUs) (Comparison, error) {
interimComparison1 := make([][]Comparison, 0)
for _, otherGPU := range other {
var interimComparison2 []Comparison
for _, ownGPU := range gpus {
c, err := ownGPU.Compare(otherGPU)
if err != nil {
return None, fmt.Errorf("error comparing GPU: %v", err)
}
interimComparison2 = append(interimComparison2, c)
}
// this matrix structure holds the comparison results of each GPU on the left
// with each GPU on the right, in the order they appear in the slices:
// the first dimension (rows) is indexed by the right (other) GPUs,
// the second dimension (columns) by the left (own) GPUs
interimComparison1 = append(interimComparison1, interimComparison2)
}
// figure out whether each GPU on the right has a matching GPU on the left
var finalComparison []Comparison
for i := 0; i < len(interimComparison1); i++ {
// we need to find the best match for each GPU on the right
// NOTE(review): this guard compares the row length with the row index and
// stops the loop early when the row is shorter; confirm that this is the
// intended termination condition.
if len(interimComparison1[i]) < i {
break
}
c := interimComparison1[i]
bestMatch, index := returnBestMatch(c)
finalComparison = append(finalComparison, bestMatch)
// NOTE(review): the matrix is replaced while it is being iterated, so the
// loop bound can change between iterations; presumably removeIndex drops
// the already matched left GPU — verify.
interimComparison1 = removeIndex(interimComparison1, index)
}
if slices.Contains(finalComparison, Worse) {
return Worse, nil
}
if SliceContainsOneValue(finalComparison, Equal) {
return Equal, nil
}
return Better, nil
}
// Add adds each GPU in other to the receiver's GPU with the same Index,
// modifying the receiver's elements in place. GPUs in other whose Index does
// not occur in the receiver are ignored, and if other holds several GPUs with
// the same Index only the last one is used.
//
// TODO: this logic probably needs to change:
//  1. if other gpu is in own gpus, add the total vram
//  2. if other gpu is not in own gpus, append it to own gpus
//
// It also assumes that the GPUs are ordered by index, which may not be the case.
func (gpus GPUs) Add(other GPUs) error {
	byIndex := make(map[int]GPU)
	for i := range other {
		byIndex[other[i].Index] = other[i]
	}

	for i := range gpus {
		addend, found := byIndex[gpus[i].Index]
		if !found {
			continue
		}
		model := gpus[i].Model
		if err := gpus[i].Add(addend); err != nil {
			return fmt.Errorf("failed to add GPU %s: %w", model, err)
		}
	}

	return nil
}
// Subtract subtracts each GPU in other from the receiver's GPU with the same
// Index, modifying the receiver's elements in place. GPUs in other whose Index
// does not occur in the receiver are ignored, and if other holds several GPUs
// with the same Index only the last one is used. The first failing
// subtraction aborts the operation; earlier elements stay modified.
//
// It assumes that the GPUs are ordered by index, which may not be the case.
func (gpus GPUs) Subtract(other GPUs) error {
	byIndex := make(map[int]GPU)
	for i := range other {
		byIndex[other[i].Index] = other[i]
	}

	for i := range gpus {
		subtrahend, found := byIndex[gpus[i].Index]
		if !found {
			continue
		}
		model := gpus[i].Model
		if err := gpus[i].Subtract(subtrahend); err != nil {
			return fmt.Errorf("failed to subtract GPU %s: %w", model, err)
		}
	}

	return nil
}
// MaxFreeVRAMGPU returns the GPU with the largest VRAM value in the list.
// When several GPUs share the largest value, the first of them is returned.
// The result is always an element of the list, even when no GPU has any VRAM.
// An error is returned when the list is empty.
func (gpus GPUs) MaxFreeVRAMGPU() (GPU, error) {
	if len(gpus) == 0 {
		return GPU{}, fmt.Errorf("no GPUs found")
	}

	// Seed with the first element rather than a zero-value GPU, so that a list
	// whose GPUs all report zero VRAM still yields a real device.
	best := gpus[0]
	for _, gpu := range gpus[1:] {
		if gpu.VRAM > best.VRAM {
			best = gpu
		}
	}
	return best, nil
}
// GetWithIndex returns the first GPU in the list whose Index field equals
// index. If there is no such GPU, a zero-value GPU and an error are returned.
func (gpus GPUs) GetWithIndex(index int) (GPU, error) {
	for i := range gpus {
		if gpus[i].Index != index {
			continue
		}
		return gpus[i], nil
	}
	return GPU{}, fmt.Errorf("GPU with index %d not found", index)
}
// CPU represents the CPU information of a machine.
type CPU struct {
// ClockSpeed represents the CPU clock speed in Hz
ClockSpeed float64 `json:"clock_speed" description:"CPU clock speed in Hz"`
// Cores represents the number of physical CPU cores
// (a float, so fractional core counts can be expressed)
Cores float32 `json:"cores" description:"Number of physical CPU cores"`
// TODO: capture the below fields if required
// Model represents the CPU model, e.g., "Intel Core i7-9700K", "AMD Ryzen 9 5900X"
Model string `json:"model" description:"CPU model, e.g., Intel Core i7-9700K, AMD Ryzen 9 5900X"`
// Vendor represents the CPU manufacturer, e.g., "Intel", "AMD"
Vendor string `json:"vendor" description:"CPU manufacturer, e.g., Intel, AMD"`
// Threads represents the number of logical CPU threads (including hyperthreading)
Threads int `json:"threads" description:"Number of logical CPU threads (including hyperthreading)"`
// Architecture represents the CPU architecture, e.g., "x86", "x86_64", "arm64"
Architecture string `json:"architecture" description:"CPU architecture, e.g., x86, x86_64, arm64"`
// Cache size in bytes
CacheSize uint64 `json:"cache_size" description:"CPU cache size in bytes"`
}
// Compile-time assertions that *CPU implements Calculable and Comparable.
var (
_ Calculable[CPU] = (*CPU)(nil)
_ Comparable[CPU] = (*CPU)(nil)
)
// Compare ranks c against other. CPUs with different architectures cannot be
// compared and yield None. For the same architecture the result is the
// numeric comparison of cores multiplied by clock speed. The error is always nil.
//
// This is deliberately a very simple comparison, based on the assumption that
// the same or more cores and frequency is acceptable, but nothing less. A more
// elaborate ranking would need specific hardware knowledge, e.g. benchmarking
// data for the individual CPU models.
func (c *CPU) Compare(other CPU) (Comparison, error) {
	if LiteralComparator(c.Architecture, other.Architecture) != Equal {
		return None, nil
	}

	ownCompute := float64(c.Cores) * c.ClockSpeed
	otherCompute := float64(other.Cores) * other.ClockSpeed
	return NumericComparator(ownCompute, otherCompute), nil
}
// Add increases c's core count by other's core count; the sum is passed
// through round with an argument of 2 before being stored. No other field is
// touched and the returned error is always nil.
func (c *CPU) Add(other CPU) error {
	total := c.Cores + other.Cores
	c.Cores = round(total, 2)
	return nil
}
// Subtract decreases c's core count by other's core count; the difference is
// passed through round with an argument of 2 before being stored. If other has
// more cores than c, c is left unchanged and an underflow error is returned
// that states the amount to subtract first and the available amount second.
func (c *CPU) Subtract(other CPU) error {
	if c.Cores < other.Cores {
		return fmt.Errorf("core: underflow, cannot subtract %v from %v", other.Cores, c.Cores)
	}

	c.Cores = round(c.Cores-other.Cores, 2)
	return nil
}
// Compute returns the number of cores multiplied by the clock speed.
func (c *CPU) Compute() float64 {
	cores := float64(c.Cores)
	return cores * c.ClockSpeed
}
// RAM represents the RAM information of a machine.
type RAM struct {
// Size in bytes
Size float64 `json:"size" description:"Size of the RAM in bytes"`
// TODO: capture the below fields if required
// Clock speed in Hz
ClockSpeed uint64 `json:"clock_speed" description:"Clock speed of the RAM in Hz"`
// Type represents the RAM type, e.g., "DDR4", "DDR5", "LPDDR4"
Type string `json:"type" description:"RAM type, e.g., DDR4, DDR5, LPDDR4"`
}
// Compile-time assertions that *RAM implements Calculable and Comparable.
var (
_ Calculable[RAM] = (*RAM)(nil)
_ Comparable[RAM] = (*RAM)(nil)
)
// Compare ranks r against other by Size alone, using NumericComparator.
// Clock speed and type do not influence the result. The error is always nil.
func (r *RAM) Compare(other RAM) (Comparison, error) {
	bySize := NumericComparator(r.Size, other.Size)
	return bySize, nil
}
// Add increases r's Size by other's Size. No other field is touched and the
// returned error is always nil.
func (r *RAM) Add(other RAM) error {
	r.Size = r.Size + other.Size
	return nil
}
// Subtract decreases r's Size by other's Size. No other field is touched.
// If other is larger than r, r is left unchanged and an underflow error is
// returned that states the amount to subtract first and the available amount
// second.
func (r *RAM) Subtract(other RAM) error {
	if r.Size < other.Size {
		return fmt.Errorf("size: underflow, cannot subtract %v from %v", other.Size, r.Size)
	}

	r.Size -= other.Size
	return nil
}
// Disk represents the disk information of a machine.
type Disk struct {
// Size in bytes
Size float64 `json:"size" description:"Size of the disk in bytes"`
// TODO: capture the below fields if required
// Model represents the disk model, e.g., "Samsung 970 EVO Plus", "Western Digital Blue SN550"
Model string `json:"model" description:"Disk model, e.g., Samsung 970 EVO Plus, Western Digital Blue SN550"`
// Vendor represents the disk manufacturer, e.g., "Samsung", "Western Digital"
Vendor string `json:"vendor" description:"Disk manufacturer, e.g., Samsung, Western Digital"`
// Type represents the disk type, e.g., "SSD", "HDD", "NVMe"
Type string `json:"type" description:"Disk type, e.g., SSD, HDD, NVMe"`
// Interface represents the disk interface, e.g., "SATA", "PCIe", "M.2"
Interface string `json:"interface" description:"Disk interface, e.g., SATA, PCIe, M.2"`
// Read speed in bytes per second
ReadSpeed uint64 `json:"read_speed" description:"Read speed in bytes per second"`
// Write speed in bytes per second
WriteSpeed uint64 `json:"write_speed" description:"Write speed in bytes per second"`
}
// Compile-time assertions that *Disk implements Calculable and Comparable.
var (
_ Calculable[Disk] = (*Disk)(nil)
_ Comparable[Disk] = (*Disk)(nil)
)
// Compare ranks d against other by Size alone, using NumericComparator.
// No other disk property influences the result. The error is always nil.
func (d *Disk) Compare(other Disk) (Comparison, error) {
	bySize := NumericComparator(d.Size, other.Size)
	return bySize, nil
}
// Add increases d's Size by other's Size. No other field is touched and the
// returned error is always nil.
func (d *Disk) Add(other Disk) error {
	d.Size = d.Size + other.Size
	return nil
}
// Subtract decreases d's Size by other's Size. No other field is touched.
// If other is larger than d, d is left unchanged and an underflow error is
// returned that states the amount to subtract first and the available amount
// second.
func (d *Disk) Subtract(other Disk) error {
	if d.Size < other.Size {
		return fmt.Errorf("size: underflow, cannot subtract %v from %v", other.Size, d.Size)
	}

	d.Size -= other.Size
	return nil
}
// NetworkInfo represents the network information of a machine.
// TODO: not yet used, but can be used to capture the network information
type NetworkInfo struct {
// Bandwidth in bits per second (b/s)
Bandwidth uint64
// NetworkType represents the network type, e.g., "Ethernet", "Wi-Fi", "Cellular"
NetworkType string
}
// ConvertBytesToGB converts a byte count to decimal gigabytes
// (1 GB = 1e9 bytes).
func ConvertBytesToGB(bytes float64) float64 {
	const bytesPerGB = 1e9
	return bytes / bytesPerGB
}
// ConvertGBToBytes converts decimal gigabytes to bytes (1 GB = 1e9 bytes).
func ConvertGBToBytes(gb float64) float64 {
	const bytesPerGB = 1e9
	return gb * bytesPerGB
}
// ConvertMiBToGB converts mebibytes (1 MiB = 1024*1024 bytes) to decimal
// gigabytes (1 GB = 1e9 bytes).
func ConvertMiBToGB(mib float64) float64 {
	const (
		bytesPerMiB = 1024 * 1024
		bytesPerGB  = 1_000_000_000
	)
	sizeInBytes := mib * bytesPerMiB
	return sizeInBytes / bytesPerGB
}
// ConvertMibToBytes converts mebibytes to bytes (1 MiB = 1024*1024 bytes).
func ConvertMibToBytes(mib float64) float64 {
	const bytesPerMiB = 1024 * 1024
	return mib * bytesPerMiB
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package types
import (
"time"
)
// Network type identifiers.
const (
// NetP2P is the identifier "p2p".
// NOTE(review): presumably selects the peer-to-peer network implementation; verify against callers.
NetP2P = "p2p"
)
// NetworkSpec is a stub. Please expand based on requirements.
type NetworkSpec struct{}
// NetConfig is a stub. Please expand it or completely change it based on requirements.
type NetConfig struct {
NetworkSpec SpecConfig `json:"network_spec"` // Network specification
}
// GetNetworkConfig returns a pointer to the NetworkSpec stored in nc. The
// pointer refers to the field itself, so changes made through it are visible
// in nc.
func (nc *NetConfig) GetNetworkConfig() *SpecConfig {
return &nc.NetworkSpec
}
// NetworkStats should contain all network info the user is interested in.
// For now there are only the peer ID and the listening address, but
// reachability, local and remote address etc. can be added when necessary.
type NetworkStats struct {
ID string `json:"id"`
ListenAddr string `json:"listen_addr"`
}
// MessageInfo is a stub. Please expand it or completely change it based on requirements.
type MessageInfo struct {
Info string `json:"info"` // Message information
}
// PingResult holds the outcome of a ping: the measured round-trip time,
// whether the ping succeeded and the error, if any.
type PingResult struct {
RTT time.Duration
Success bool
Error error
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package types
import (
"reflect"
)
// ConvertNumericToFloat64 converts a value of one of Go's built-in integer or
// floating-point types to float64.
//
// The second result reports whether n held such a value. For anything else
// (including nil, strings and user-defined named numeric types) it returns
// 0 and false.
func ConvertNumericToFloat64(n any) (float64, bool) {
	switch v := n.(type) {
	case float64:
		return v, true
	case float32:
		return float64(v), true
	case int, int8, int16, int32, int64:
		// v is still an interface value here, so read it through reflection.
		return float64(reflect.ValueOf(v).Int()), true
	case uint, uint8, uint16, uint32, uint64:
		return float64(reflect.ValueOf(v).Uint()), true
	}
	return 0, false
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package types
import (
"context"
"fmt"
)
// Resources represents the resources of the machine: CPU, GPUs, RAM and disk.
// The gorm tags embed CPU, RAM and Disk with a column prefix and link GPUs
// through the ResourceID foreign key.
type Resources struct {
CPU CPU `json:"cpu" gorm:"embedded;embeddedPrefix:cpu_"`
GPUs GPUs `json:"gpus" gorm:"foreignKey:ResourceID"`
RAM RAM `json:"ram" gorm:"embedded;embeddedPrefix:ram_"`
Disk Disk `json:"disk" gorm:"embedded;embeddedPrefix:disk_"`
}
// Compile-time assertions that *Resources implements Calculable and Comparable.
var (
_ Calculable[Resources] = (*Resources)(nil)
_ Comparable[Resources] = (*Resources)(nil)
)
// Compare relates r to other component by component, in the order CPU, RAM,
// Disk, GPUs, and combines the four results with Comparison.And.
// The first component whose comparison fails aborts the operation: None is
// returned together with an error naming that component.
func (r *Resources) Compare(other Resources) (Comparison, error) {
	verdict, err := r.CPU.Compare(other.CPU)
	if err != nil {
		return None, fmt.Errorf("error comparing CPU: %v", err)
	}

	ram, err := r.RAM.Compare(other.RAM)
	if err != nil {
		return None, fmt.Errorf("error comparing RAM: %v", err)
	}
	verdict = verdict.And(ram)

	disk, err := r.Disk.Compare(other.Disk)
	if err != nil {
		return None, fmt.Errorf("error comparing Disk: %v", err)
	}
	verdict = verdict.And(disk)

	gpu, err := r.GPUs.Compare(other.GPUs)
	if err != nil {
		return None, fmt.Errorf("error comparing GPUs: %v", err)
	}

	return verdict.And(gpu), nil
}
// Equal reports whether r and other have the same RAM size, the same number
// of CPU cores and the same disk size. GPUs and all remaining fields are not
// taken into account.
func (r *Resources) Equal(other Resources) bool {
	sameRAM := r.RAM.Size == other.RAM.Size
	sameCores := r.CPU.Cores == other.CPU.Cores
	sameDisk := r.Disk.Size == other.Disk.Size
	return sameRAM && sameCores && sameDisk
}
// Add adds other to r in place, component by component, in the order CPU,
// RAM, Disk, GPUs. The first failing component aborts the operation with an
// error naming it; components handled before it stay modified.
func (r *Resources) Add(other Resources) error {
	err := r.CPU.Add(other.CPU)
	if err != nil {
		return fmt.Errorf("error adding CPU: %v", err)
	}

	err = r.RAM.Add(other.RAM)
	if err != nil {
		return fmt.Errorf("error adding RAM: %v", err)
	}

	err = r.Disk.Add(other.Disk)
	if err != nil {
		return fmt.Errorf("error adding Disk: %v", err)
	}

	err = r.GPUs.Add(other.GPUs)
	if err != nil {
		return fmt.Errorf("error adding GPUs: %v", err)
	}

	return nil
}
// Subtract subtracts other from r one dimension at a time, in the order CPU,
// RAM, Disk, GPUs, by delegating to each component's Subtract method.
//
// It stops at the first failing dimension and returns an error wrapping the
// cause; dimensions processed before the failure are not rolled back.
// NOTE(review): the components presumably update r in place; confirm against
// their Subtract implementations.
func (r *Resources) Subtract(other Resources) error {
	if err := r.CPU.Subtract(other.CPU); err != nil {
		return fmt.Errorf("error subtracting CPU: %w", err)
	}
	if err := r.RAM.Subtract(other.RAM); err != nil {
		return fmt.Errorf("error subtracting RAM: %w", err)
	}
	if err := r.Disk.Subtract(other.Disk); err != nil {
		return fmt.Errorf("error subtracting Disk: %w", err)
	}
	if err := r.GPUs.Subtract(other.GPUs); err != nil {
		return fmt.Errorf("error subtracting GPUs: %w", err)
	}
	return nil
}
// MachineResources represents the total resources of the machine.
// It combines the common database fields (BaseDBModel) with a Resources value.
type MachineResources struct {
BaseDBModel
Resources
}
// FreeResources represents the free resources of the machine.
// It combines the common database fields (BaseDBModel) with a Resources value.
type FreeResources struct {
BaseDBModel
Resources
}
// CommittedResources represents the committed resources of the machine.
// It combines the common database fields (BaseDBModel) with a Resources value
// and the identifier of the job they belong to.
type CommittedResources struct {
BaseDBModel
Resources
// JobID is serialized as "job_id" in JSON.
JobID string `json:"job_id"`
}
// OnboardedResources represents the onboarded resources of the machine.
// It combines the common database fields (BaseDBModel) with a Resources value.
type OnboardedResources struct {
BaseDBModel
Resources
}
// ResourceAllocation represents the allocation of resources for a job.
// It combines the common database fields (BaseDBModel), the job identifier and
// the allocated Resources.
type ResourceAllocation struct {
BaseDBModel
// JobID is serialized as "job_id" in JSON.
JobID string `json:"job_id"`
Resources
}
// ResourceManager is an interface that defines the methods to manage the
// resources of the machine: committing, allocating and releasing resources for
// jobs, and querying or updating the free, allocated and onboarded totals.
type ResourceManager interface {
// CommitResources preallocates the resources required by the jobs
CommitResources(context.Context, ResourceAllocation) error
// ReleaseCommittedResources releases the resources that were preallocated.
// NOTE(review): the string argument is presumably the job ID; confirm against implementations.
ReleaseCommittedResources(context.Context, string) error
// AllocateResources allocates the resources required by a job
AllocateResources(context.Context, ResourceAllocation) error
// DeallocateResources deallocates the resources required by a job.
// NOTE(review): the string argument is presumably the job ID; confirm against implementations.
DeallocateResources(context.Context, string) error
// GetTotalAllocation returns the total allocations for the jobs
GetTotalAllocation() (Resources, error)
// GetFreeResources returns the free resources in the allocation pool
GetFreeResources(ctx context.Context) (FreeResources, error)
// GetOnboardedResources returns the onboarded resources of the machine
GetOnboardedResources(context.Context) (OnboardedResources, error)
// UpdateOnboardedResources updates the onboarded resources of the machine in the database
UpdateOnboardedResources(context.Context, Resources) error
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package types
import (
"errors"
"strings"
"gitlab.com/nunet/device-management-service/utils/validate"
)
// SpecConfig represents a configuration for a spec.
// A SpecConfig can be used to define an engine spec, a storage volume, etc.
type SpecConfig struct {
// Type of the spec (e.g. docker, firecracker, storage, etc.)
Type string `json:"type"`
// Params holds the free-form parameters of the spec; it is omitted from
// the JSON output when empty.
Params map[string]interface{} `json:"params,omitempty"`
}
// Config is implemented by configuration objects that can provide a network
// SpecConfig.
type Config interface {
// GetNetworkConfig returns the SpecConfig describing the network.
GetNetworkConfig() *SpecConfig
}
// NewSpecConfig returns a SpecConfig of type t with an initialized, empty
// parameter map.
func NewSpecConfig(t string) *SpecConfig {
	cfg := new(SpecConfig)
	cfg.Type = t
	cfg.Params = map[string]interface{}{}
	return cfg
}
// WithParam stores value under key in the spec params, allocating the map
// first when it is nil, and returns the receiver so calls can be chained.
func (s *SpecConfig) WithParam(key string, value interface{}) *SpecConfig {
	if s.Params != nil {
		s.Params[key] = value
		return s
	}
	s.Params = map[string]interface{}{key: value}
	return s
}
// Normalize brings the spec config into a canonical state: the type is
// stripped of surrounding whitespace and a nil or empty params map is replaced
// by a fresh empty map. Calling it on a nil receiver is a no-op.
func (s *SpecConfig) Normalize() {
	if s != nil {
		s.Type = strings.TrimSpace(s.Type)
		// A nil map and an empty map are treated the same.
		if len(s.Params) == 0 {
			s.Params = map[string]interface{}{}
		}
	}
}
// Validate returns an error when the receiver is nil or its type is blank,
// and nil otherwise.
func (s *SpecConfig) Validate() error {
	switch {
	case s == nil:
		return errors.New("nil spec config")
	case validate.IsBlank(s.Type):
		return errors.New("missing spec type")
	default:
		return nil
	}
}
// IsType reports whether the spec's type equals t, ignoring case and any
// whitespace surrounding t. A nil receiver never matches.
func (s *SpecConfig) IsType(t string) bool {
	if s != nil {
		return strings.EqualFold(s.Type, strings.TrimSpace(t))
	}
	return false
}
// IsEmpty reports whether the spec config is nil, or has both a blank type and
// no params.
func (s *SpecConfig) IsEmpty() bool {
	if s == nil {
		return true
	}
	if len(s.Params) != 0 {
		return false
	}
	return validate.IsBlank(s.Type)
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package types
import (
"net"
"time"
"github.com/google/uuid"
"github.com/oschwald/geoip2-golang"
"gorm.io/gorm"
)
// BaseDBModel is a base model for all entities. It'll be mainly used for database
// records.
type BaseDBModel struct {
// ID is stored with the GORM column type "uuid".
ID string `gorm:"type:uuid"`
CreatedAt time.Time
UpdatedAt time.Time
// DeletedAt is indexed; gorm.DeletedAt is GORM's soft-delete field type.
DeletedAt gorm.DeletedAt `gorm:"index"`
}
// GetID returns the ID of the entity.
// It uses a value receiver, so it can be called on both values and pointers.
func (m BaseDBModel) GetID() string {
return m.ID
}
// BeforeCreate sets the ID and CreatedAt fields before creating a new entity.
// The ID is always replaced by a freshly generated UUID, even if one was
// already set, and CreatedAt is set to the current time.
// This is a GORM hook and should not be called directly.
// We can move this to generic repository create methods.
func (m *BaseDBModel) BeforeCreate(_ *gorm.DB) error {
m.ID = uuid.NewString()
m.CreatedAt = time.Now()
return nil
}
// BeforeUpdate sets the UpdatedAt field to the current time before updating
// an entity. It always returns nil.
// This is a GORM hook and should not be called directly.
// We can move this to generic repository create methods.
func (m *BaseDBModel) BeforeUpdate(_ *gorm.DB) error {
m.UpdatedAt = time.Now()
return nil
}
// GeoIPLocator returns the info about an IP.
// Both methods take the IP address to look up and return the corresponding
// geoip2 record or an error.
type GeoIPLocator interface {
// Country returns the country-level record for ipAddress.
Country(ipAddress net.IP) (*geoip2.Country, error)
// City returns the city-level record for ipAddress.
City(ipAddress net.IP) (*geoip2.City, error)
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package types
import (
"math"
"reflect"
"slices"
)
// IsStrictlyContained reports whether every element of rightSlice is also
// present in leftSlice. An empty rightSlice yields false.
func IsStrictlyContained(leftSlice, rightSlice []interface{}) bool {
	if len(rightSlice) == 0 {
		// Nothing to look for: the result defaults to false.
		return false
	}
	for _, wanted := range rightSlice {
		if !slices.Contains(leftSlice, wanted) {
			return false
		}
	}
	return true
}
// IsSameShallowType reports whether a and b have the same dynamic type, as
// given by reflect.TypeOf. Two nil interfaces compare as the same type.
func IsSameShallowType(a, b interface{}) bool {
	return reflect.TypeOf(a) == reflect.TypeOf(b)
}
// round rounds value to the given number of decimal places. The computation
// is carried out in float64 and converted back to T.
func round[T float32 | float64](value T, places int) T {
	scale := math.Pow(10, float64(places))
	return T(math.Round(float64(value)*scale) / scale)
}
// SliceContainsOneValue reports whether every element of slice equals value.
// An empty slice yields true.
func SliceContainsOneValue(slice []Comparison, value Comparison) bool {
	for i := 0; i < len(slice); i++ {
		if slice[i] == value {
			continue
		}
		return false
	}
	return true
}
// returnBestMatch picks the most favourable comparison in dimension and
// returns it together with its index.
//
// Preference order is Equal, then Better, then Worse:
//   - the first Equal wins immediately;
//   - otherwise the first Better is returned;
//   - otherwise the first Worse is returned;
//   - if none of these occur, (None, -1) is returned.
func returnBestMatch(dimension []Comparison) (Comparison, int) {
	firstBetter, firstWorse := -1, -1
	for idx, cmp := range dimension {
		switch {
		case cmp == Equal:
			// Equal cannot be beaten, so stop scanning.
			return cmp, idx
		case cmp == Better && firstBetter < 0:
			firstBetter = idx
		case cmp == Worse && firstWorse < 0:
			firstWorse = idx
		}
	}

	if firstBetter >= 0 {
		return Better, firstBetter
	}
	if firstWorse >= 0 {
		return Worse, firstWorse
	}
	return None, -1
}
// removeIndex removes the element at index from every sub-slice of slice and
// returns slice. Sub-slices for which index is out of bounds are left
// unmodified. The removal happens in place, reusing each sub-slice's backing
// array.
func removeIndex(slice [][]Comparison, index int) [][]Comparison {
	for row := range slice {
		current := slice[row]
		if index >= 0 && index < len(current) {
			slice[row] = append(current[:index], current[index+1:]...)
		}
	}
	return slice
}
// ConvertTypedSliceToUntypedSlice copies the elements of typedSlice into a new
// []interface{} using reflection. It returns nil when typedSlice is not a
// slice.
func ConvertTypedSliceToUntypedSlice(typedSlice interface{}) []interface{} {
	value := reflect.ValueOf(typedSlice)
	if value.Kind() == reflect.Slice {
		length := value.Len()
		untyped := make([]interface{}, 0, length)
		for i := 0; i < length; i++ {
			untyped = append(untyped, value.Index(i).Interface())
		}
		return untyped
	}
	return nil
}
// IsStrictlyContainedInt reports whether every element of rightSlice is also
// present in leftSlice. An empty rightSlice yields false.
func IsStrictlyContainedInt(leftSlice, rightSlice []int) bool {
	if len(rightSlice) == 0 {
		// Nothing to look for: the result defaults to false.
		return false
	}
	for _, wanted := range rightSlice {
		if !slices.Contains(leftSlice, wanted) {
			return false
		}
	}
	return true
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package utils
import (
"errors"
"github.com/fivebinaries/go-cardano-serialization/address"
)
// ValidateAddress returns nil when addr is accepted as a Cardano address by
// isValidCardano, and an error otherwise.
func ValidateAddress(addr string) error {
	var ok bool
	isValidCardano(addr, &ok)
	if !ok {
		return errors.New("invalid cardano wallet address")
	}
	return nil
}
// isValidCardano checks if the cardano address is valid and reports the
// outcome through valid: it is set to true when address.NewAddress parses addr
// without error, and to false when parsing panics. When parsing returns an
// ordinary error, valid is left untouched, so callers should initialize it to
// false.
func isValidCardano(addr string, valid *bool) {
// Convert a panic raised while parsing into "invalid" instead of crashing.
// NOTE(review): presumably address.NewAddress can panic on malformed input — confirm.
defer func() {
if r := recover(); r != nil {
*valid = false
}
}()
if _, err := address.NewAddress(addr); err == nil {
*valid = true
}
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package utils
import (
"errors"
"fmt"
"os"
"path/filepath"
"github.com/spf13/afero"
)
// GetDirectorySize walks the tree rooted at path on fs and returns the summed
// size in bytes of every non-directory entry. Any error reported during the
// walk aborts it and is returned wrapped.
func GetDirectorySize(fs afero.Fs, path string) (int64, error) {
	var total int64

	visit := func(_ string, info os.FileInfo, walkErr error) error {
		if walkErr != nil {
			return walkErr
		}
		if info.IsDir() {
			return nil
		}
		total += info.Size()
		return nil
	}

	if err := afero.Walk(fs, path, visit); err != nil {
		return 0, fmt.Errorf("failed to calculate volume size: %w", err)
	}
	return total, nil
}
// WriteToFile creates (or truncates) filePath on fs, creating parent
// directories as needed, and writes data to it. On success it returns
// filePath; on failure it returns an empty string and an error.
func WriteToFile(fs afero.Fs, data []byte, filePath string) (string, error) {
	parent := filepath.Dir(filePath)
	if err := fs.MkdirAll(parent, os.ModePerm); err != nil {
		return "", fmt.Errorf("failed to open path: %w", err)
	}

	file, err := fs.Create(filePath)
	if err != nil {
		return "", fmt.Errorf("failed to create path: %w", err)
	}
	defer file.Close()

	switch written, err := file.Write(data); {
	case err != nil:
		return "", fmt.Errorf("failed to write data to path: %w", err)
	case written != len(data):
		// A short write without an error is still a failure.
		return "", errors.New("failed to write the size of data to file")
	}

	return filePath, nil
}
// FileExists reports whether filename exists on fs and is not a directory.
//
// Any Stat failure is treated as "does not exist". This covers not-exist as
// well as other errors such as permission problems; in all of those cases the
// returned FileInfo is nil and must not be dereferenced.
func FileExists(fs afero.Fs, filename string) bool {
	info, err := fs.Stat(filename)
	if err != nil {
		return false
	}
	return !info.IsDir()
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package utils
import (
"bytes"
"fmt"
"io"
"net/http"
"net/url"
"path"
)
// HTTPClient is a small helper for issuing requests against a versioned HTTP API.
type HTTPClient struct {
// BaseURL is parsed for every request made by MakeRequest.
BaseURL string
// APIVersion is joined in front of the relative path of every request.
APIVersion string
// Client is the underlying HTTP client used to send requests.
Client *http.Client
}
// NewHTTPClient returns an HTTPClient for the given base URL and API version
// that sends its requests through http.DefaultClient.
func NewHTTPClient(baseURL, version string) *HTTPClient {
	client := new(HTTPClient)
	client.BaseURL = baseURL
	client.APIVersion = version
	client.Client = http.DefaultClient
	return client
}
// MakeRequest performs an HTTP request with the given method, path, and body.
//
// The request URL is built from BaseURL with its path replaced by
// APIVersion joined with relativePath. The request is sent with JSON
// Content-Type and Accept headers.
//
// It returns the response body and status code. On failure it returns a nil
// body, status 0 and an error wrapping the underlying cause.
func (c *HTTPClient) MakeRequest(method, relativePath string, body []byte) ([]byte, int, error) {
	// Named endpoint rather than url so the net/url package is not shadowed.
	endpoint, err := url.Parse(c.BaseURL)
	if err != nil {
		return nil, 0, fmt.Errorf("failed to parse base URL: %w", err)
	}
	endpoint.Path = path.Join(c.APIVersion, relativePath)

	req, err := http.NewRequest(method, endpoint.String(), bytes.NewBuffer(body))
	if err != nil {
		return nil, 0, fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json; charset=utf-8")
	req.Header.Set("Accept", "application/json")

	resp, err := c.Client.Do(req)
	if err != nil {
		return nil, 0, fmt.Errorf("request failed: %w", err)
	}
	defer resp.Body.Close()

	// Read the response body
	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, 0, fmt.Errorf("failed to read response body: %w", err)
	}

	return respBody, resp.StatusCode, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package utils
import (
"io"
"sync"
"time"
)
// IOProgress holds the progress state of a tracked read or write.
type IOProgress struct {
// n is the number of bytes transferred so far.
n float64
// size is the expected total number of bytes; Complete treats -1 as unknown.
size float64
// started is the time the tracker was created.
started time.Time
// estimated is the estimated completion time.
estimated time.Time
// err is the error returned by the most recent underlying Read or Write.
err error
}
// Reader wraps an io.Reader and records transfer progress in Progress.
type Reader struct {
reader io.Reader
// lock is held by Read while it updates Progress.
lock sync.RWMutex
// Progress is the current progress state.
// NOTE(review): the field is exported and can be read without holding lock — confirm whether callers need synchronization.
Progress IOProgress
}
// Writer wraps an io.Writer and records transfer progress in Progress.
type Writer struct {
writer io.Writer
// lock is held by Write while it updates Progress.
lock sync.RWMutex
// Progress is the current progress state.
// NOTE(review): the field is exported and can be read without holding lock — confirm whether callers need synchronization.
Progress IOProgress
}
// ReaderWithProgress wraps r in a Reader that tracks how many of the expected
// size bytes have been read. The start time is set to the moment of the call.
func ReaderWithProgress(r io.Reader, size int64) *Reader {
	tracked := &Reader{reader: r}
	tracked.Progress.started = time.Now()
	tracked.Progress.size = float64(size)
	return tracked
}
// WriterWithProgress wraps w in a Writer that tracks how many of the expected
// size bytes have been written. The start time is set to the moment of the call.
func WriterWithProgress(w io.Writer, size int64) *Writer {
	tracked := &Writer{writer: w}
	tracked.Progress.started = time.Now()
	tracked.Progress.size = float64(size)
	return tracked
}
// Read reads from the wrapped reader into p, then records the number of bytes
// read and the returned error in Progress while holding the lock.
func (r *Reader) Read(p []byte) (n int, err error) {
	n, err = r.reader.Read(p)

	r.lock.Lock()
	defer r.lock.Unlock()
	r.Progress.err = err
	r.Progress.n += float64(n)

	return n, err
}
// Write writes p to the wrapped writer, then records the number of bytes
// written and the returned error in Progress while holding the lock.
func (w *Writer) Write(p []byte) (n int, err error) {
	n, err = w.writer.Write(p)

	w.lock.Lock()
	defer w.lock.Unlock()
	w.Progress.err = err
	w.Progress.n += float64(n)

	return n, err
}
// Size returns the expected total number of bytes.
func (p IOProgress) Size() float64 {
return p.size
}
// N returns the number of bytes transferred so far.
func (p IOProgress) N() float64 {
return p.n
}
// Complete reports whether the transfer has finished: either the last recorded
// error is io.EOF, or the size is known (not -1) and at least that many bytes
// have been transferred.
func (p IOProgress) Complete() bool {
	switch {
	case p.err == io.EOF:
		return true
	case p.size == -1:
		// Unknown size: only EOF can signal completion.
		return false
	default:
		return p.n >= p.size
	}
}
// Percent returns the completed share of the transfer as a value between 0
// and 100. It is 0 while nothing has been transferred and 100 once the
// transferred byte count reaches or exceeds the size.
func (p IOProgress) Percent() float64 {
	switch {
	case p.n == 0:
		return 0
	case p.n >= p.size:
		return 100
	}
	return 100.0 / (p.size / p.n)
}
// Remaining returns the time left until the estimated completion time. The
// stored estimate is used when it is set; otherwise one is computed with
// Estimated.
func (p IOProgress) Remaining() time.Duration {
	target := p.estimated
	if target.IsZero() {
		target = p.Estimated()
	}
	return time.Until(target)
}
// Estimated returns the projected completion time, extrapolated from the time
// elapsed since the start and the fraction transferred so far. While nothing
// has been transferred it returns the stored estimate unchanged. Because the
// receiver is a value, the computed estimate is not persisted.
func (p IOProgress) Estimated() time.Time {
	if !(p.n > 0.0) {
		return p.estimated
	}
	elapsed := float64(time.Since(p.started))
	fraction := p.n / p.size
	return p.started.Add(time.Duration(elapsed / fraction))
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package utils
import (
"fmt"
"strings"
"sync"
)
// A SyncMap is a concurrency-safe sync.Map that uses strongly-typed
// method signatures to ensure the types of its stored data are known.
// The embedded sync.Map remains directly accessible; the typed methods assume
// every stored key is a K and every stored value is a V.
type SyncMap[K comparable, V any] struct {
sync.Map
}
// SyncMapFromMap builds a new SyncMap holding a copy of every key-value pair
// of m.
func SyncMapFromMap[K comparable, V any](m map[K]V) *SyncMap[K, V] {
	converted := new(SyncMap[K, V])
	for key, val := range m {
		converted.Store(key, val)
	}
	return converted
}
// Get looks up key and returns its value together with true. When the key is
// absent it returns the zero value of V and false.
func (m *SyncMap[K, V]) Get(key K) (V, bool) {
	if raw, found := m.Load(key); found {
		return raw.(V), true
	}
	var zero V
	return zero, false
}
// Put inserts or updates a key-value pair in the map.
// It delegates to the embedded sync.Map's Store.
func (m *SyncMap[K, V]) Put(key K, value V) {
m.Store(key, value)
}
// Iter calls ranger for each key-value pair in the map, converting both to
// their typed form first. Iteration stops as soon as ranger returns false.
func (m *SyncMap[K, V]) Iter(ranger func(key K, value V) bool) {
	typed := func(rawKey, rawValue any) bool {
		return ranger(rawKey.(K), rawValue.(V))
	}
	m.Range(typed)
}
// Keys returns all keys currently in the map, in the order Iter visits them.
// The result is nil when the map is empty.
func (m *SyncMap[K, V]) Keys() []K {
	var collected []K
	collect := func(key K, _ V) bool {
		collected = append(collected, key)
		return true
	}
	m.Iter(collect)
	return collected
}
// String provides a string representation of the map, listing every entry as
// key=value between braces. Entries are written back to back with no
// separator between them.
//
// Keys and values are formatted with %v so that non-string types print their
// natural representation instead of fmt's "%!s(...)" error output.
func (m *SyncMap[K, V]) String() string {
	var sb strings.Builder
	sb.WriteByte('{')
	m.Range(func(key, value any) bool {
		// Write straight into the builder to avoid an intermediate string.
		fmt.Fprintf(&sb, "%v=%v", key, value)
		return true
	})
	sb.WriteByte('}')
	return sb.String()
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package utils
import (
"archive/tar"
"bufio"
"compress/gzip"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"log"
"math/big"
"net/http"
"os"
"path/filepath"
"reflect"
"strings"
"time"
"github.com/spf13/afero"
"golang.org/x/exp/slices"
"gitlab.com/nunet/device-management-service/types"
)
// Download locations and local paths of the kernel and root filesystem images.
// NOTE(review): the "fc" URL segment and file names suggest these are Firecracker images — confirm.
const (
// KernelFileURL is the URL the kernel image is downloaded from.
KernelFileURL = "https://d.nunet.io/fc/vmlinux"
// KernelFilePath is the local path of the kernel image.
KernelFilePath = "/etc/nunet/vmlinux"
// FilesystemURL is the URL the root filesystem image is downloaded from.
FilesystemURL = "https://d.nunet.io/fc/nunet-fc-ubuntu-20.04-0.ext4"
// FilesystemPath is the local path of the root filesystem image.
FilesystemPath = "/etc/nunet/nunet-fc-ubuntu-20.04-0.ext4"
)
// DownloadFile downloads url with an HTTP GET (10 second client timeout) and
// saves the response body to filepath, accepting at most maxBytes bytes.
//
// The destination file is only created once the server has answered with
// 200 OK, so a failed request never truncates or creates the target. If the
// download fails afterwards, including when the body is larger than maxBytes
// or the file cannot be closed cleanly, the partial file is removed and an
// error is returned.
func DownloadFile(url string, filepath string, maxBytes int64) (err error) {
	req, err := http.NewRequest(http.MethodGet, url, nil)
	if err != nil {
		return err
	}

	client := &http.Client{
		Timeout: 10 * time.Second,
	}
	resp, err := client.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("failed to download file: server returned %s", resp.Status)
	}

	file, err := os.Create(filepath)
	if err != nil {
		return err
	}
	defer func() {
		// A failed Close can mean buffered data was never flushed.
		if closeErr := file.Close(); closeErr != nil && err == nil {
			err = closeErr
		}
		// Never leave a partial or truncated download behind.
		if err != nil {
			_ = os.Remove(filepath)
		}
	}()

	written, err := io.Copy(file, io.LimitReader(resp.Body, maxBytes))
	if err != nil {
		return err
	}
	if written == maxBytes {
		// The limit was reached exactly: probe for one more byte to tell a
		// body of exactly maxBytes apart from one that was cut off.
		extra, _ := io.Copy(io.Discard, io.LimitReader(resp.Body, 1))
		if extra > 0 {
			return fmt.Errorf("failed to download file: response body exceeds the maximum allowed size of %d bytes", maxBytes)
		}
	}

	log.Println("Finished downloading file '", filepath, "'")
	return nil
}
// ReadHTTPString sends a GET request to url using http.DefaultClient and
// returns the whole response body as a string. The response status code is not
// inspected.
func ReadHTTPString(url string) (string, error) {
	request, reqErr := http.NewRequest(http.MethodGet, url, nil)
	if reqErr != nil {
		return "", reqErr
	}

	response, doErr := http.DefaultClient.Do(request)
	if doErr != nil {
		return "", doErr
	}
	defer response.Body.Close()

	payload, readErr := io.ReadAll(response.Body)
	if readErr != nil {
		return "", readErr
	}
	return string(payload), nil
}
// RandomString generates a random string of length n made of ASCII letters
// (a-z, A-Z), drawing every character from crypto/rand.
//
// It returns an error when n is negative or when the random source fails.
func RandomString(n int) (string, error) {
	const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"

	// strings.Builder.Grow panics on a negative count, so reject it up front.
	if n < 0 {
		return "", fmt.Errorf("invalid random string length %d", n)
	}

	// The upper bound is loop-invariant; build it once.
	limit := big.NewInt(int64(len(charset)))

	var sb strings.Builder
	sb.Grow(n)
	for i := 0; i < n; i++ {
		idx, err := rand.Int(rand.Reader, limit)
		if err != nil {
			return "", err
		}
		sb.WriteByte(charset[idx.Int64()])
	}
	return sb.String(), nil
}
// SliceContains reports whether str is an element of s.
func SliceContains(s []string, str string) bool {
	for i := 0; i < len(s); i++ {
		if s[i] != str {
			continue
		}
		return true
	}
	return false
}
// DeleteFile removes the file at path. When backup is true the file is not
// deleted but renamed to "<path>.bk.<unix timestamp>".
func DeleteFile(path string, backup bool) (err error) {
	if !backup {
		return os.Remove(path)
	}
	backupPath := fmt.Sprintf("%s.bk.%d", path, time.Now().Unix())
	return os.Rename(path, backupPath)
}
// CreateDirectoryIfNotExists creates path, including any missing parents, with
// mode 0o755 when stat reports that it does not exist. In every other case,
// including stat errors other than not-exist, it does nothing and returns nil.
func CreateDirectoryIfNotExists(path string) error {
	_, statErr := os.Stat(path)
	if !os.IsNotExist(statErr) {
		return nil
	}
	return os.MkdirAll(path, 0o755)
}
// CalculateSHA256Checksum streams the file at filePath through SHA-256 and
// returns the digest as a lowercase hexadecimal string.
func CalculateSHA256Checksum(filePath string) (string, error) {
	source, openErr := os.Open(filePath)
	if openErr != nil {
		return "", openErr
	}
	defer source.Close()

	digest := sha256.New()
	// Stream the contents so large files are never held in memory.
	if _, copyErr := io.Copy(digest, source); copyErr != nil {
		return "", copyErr
	}

	return hex.EncodeToString(digest.Sum(nil)), nil
}
// CreateCheckSumFile writes checksum to "<filePath>.sha256.txt", creating or
// truncating that file, and returns its path.
func CreateCheckSumFile(filePath string, checksum string) (string, error) {
	target := filePath + ".sha256.txt"

	out, createErr := os.Create(target)
	if createErr != nil {
		return "", fmt.Errorf("unable to create SHA-256 checksum file: %v", createErr)
	}
	defer out.Close()

	if _, writeErr := out.WriteString(checksum); writeErr != nil {
		return "", fmt.Errorf("unable to write to SHA-256 checksum file: %v", writeErr)
	}
	return target, nil
}
// SanitizeArchivePath joins the archive entry name t onto the destination
// directory d and guards against "G305: Zip Slip vulnerability".
//
// It returns the joined path when that path is d itself or lies below d, and
// an error when the entry would escape d (for example through ".." segments).
// Containment is decided on the relative path from d to the result, so a
// sibling directory that merely shares d as a string prefix (d="/tmp/foo",
// t="../foobar/x") is rejected.
func SanitizeArchivePath(d, t string) (v string, err error) {
	v = filepath.Join(d, t)

	rel, relErr := filepath.Rel(filepath.Clean(d), v)
	escapes := relErr != nil ||
		rel == ".." ||
		strings.HasPrefix(rel, ".."+string(filepath.Separator))
	if escapes {
		return "", fmt.Errorf("%s: %s", "content filepath is tainted", t)
	}
	return v, nil
}
// ExtractTarGzToPath extracts the gzip-compressed tar archive at tarGzFilePath
// into extractedPath, creating the directory if needed.
//
// maxBytes limits both the declared size of any single entry and the total
// number of bytes written across all entries; exceeding either aborts the
// extraction with an error. Entry names are checked with SanitizeArchivePath
// so entries cannot be written outside extractedPath. Directory entries are
// created as directories; every other entry is written as a regular file.
//
// Each extracted file is closed as soon as it has been written, so the number
// of open file descriptors does not grow with the number of entries.
func ExtractTarGzToPath(tarGzFilePath, extractedPath string, maxBytes int64) error {
	// Ensure the target directory exists; create it if it doesn't.
	if err := os.MkdirAll(extractedPath, os.ModePerm); err != nil {
		return fmt.Errorf("error creating target directory: %w", err)
	}

	tarGzFile, err := os.Open(tarGzFilePath)
	if err != nil {
		return fmt.Errorf("error opening tar.gz file: %w", err)
	}
	defer tarGzFile.Close()

	gzipReader, err := gzip.NewReader(tarGzFile)
	if err != nil {
		return fmt.Errorf("error creating gzip reader: %w", err)
	}
	defer gzipReader.Close()

	tarReader := tar.NewReader(gzipReader)
	var totalSize int64

	for {
		header, err := tarReader.Next()
		if err == io.EOF {
			return nil
		}
		if err != nil {
			return fmt.Errorf("error reading tar header: %w", err)
		}

		if header.Size > maxBytes {
			return fmt.Errorf("file %s exceeds the maximum allowed size of %d bytes", header.Name, maxBytes)
		}

		// Construct the full target path by joining the target directory with
		// the name of the file or directory from the archive.
		fullTargetPath, err := SanitizeArchivePath(extractedPath, header.Name)
		if err != nil {
			return fmt.Errorf("failed to santize path %w", err)
		}

		if header.FileInfo().IsDir() {
			// Create the directory and any parent directories as needed.
			if err := os.MkdirAll(fullTargetPath, os.ModePerm); err != nil {
				return fmt.Errorf("error creating directory: %w", err)
			}
			continue
		}

		// Ensure that the directory path leading to the file exists.
		if err := os.MkdirAll(filepath.Dir(fullTargetPath), os.ModePerm); err != nil {
			return fmt.Errorf("error creating directory: %w", err)
		}

		if err := extractTarEntryToFile(tarReader, fullTargetPath, &totalSize, maxBytes); err != nil {
			return err
		}
	}
}

// extractTarEntryToFile copies the current tar entry from src into a new file
// at targetPath, adding the bytes written to *totalSize and failing once that
// running total exceeds maxBytes. The file is closed before returning.
func extractTarEntryToFile(src io.Reader, targetPath string, totalSize *int64, maxBytes int64) (err error) {
	newFile, err := os.Create(targetPath)
	if err != nil {
		return fmt.Errorf("error creating file: %w", err)
	}
	defer func() {
		// Report a failed Close unless an earlier error already explains the failure.
		if closeErr := newFile.Close(); closeErr != nil && err == nil {
			err = fmt.Errorf("error closing file: %w", closeErr)
		}
	}()

	// Copy in small chunks so the size limit is enforced while streaming.
	for {
		n, copyErr := io.CopyN(newFile, src, 1024)
		*totalSize += n
		if *totalSize > maxBytes {
			return fmt.Errorf("extracted data exceeds allowed limit of %d bytes", maxBytes)
		}
		if copyErr == io.EOF {
			return nil
		}
		if copyErr != nil {
			return copyErr
		}
	}
}
// CheckWSL reports whether the process appears to run under WSL, judged by
// whether any line of /proc/version contains "Microsoft" or "WSL". It returns
// an error when the file cannot be opened or scanned.
func CheckWSL(afs afero.Afero) (bool, error) {
	procVersion, err := afs.Open("/proc/version")
	if err != nil {
		return false, err
	}
	defer procVersion.Close()

	lines := bufio.NewScanner(procVersion)
	for lines.Scan() {
		text := lines.Text()
		for _, marker := range [...]string{"Microsoft", "WSL"} {
			if strings.Contains(text, marker) {
				return true, nil
			}
		}
	}

	// Err is nil when the scan simply reached the end of the file.
	return false, lines.Err()
}
// RandomBool returns a random boolean drawn from crypto/rand, or false and an
// error when the random source fails.
func RandomBool() (bool, error) {
	bit, err := rand.Int(rand.Reader, big.NewInt(2))
	if err == nil {
		// bit is either 0 or 1; 1 maps to true.
		return bit.Int64() == 1, nil
	}
	return false, err
}
// IsExecutorType reports whether v holds a types.ExecutorType.
func IsExecutorType(v interface{}) bool {
	switch v.(type) {
	case types.ExecutorType:
		return true
	default:
		return false
	}
}
// IsGPUVendor reports whether v holds a types.GPUVendor.
func IsGPUVendor(v interface{}) bool {
	switch v.(type) {
	case types.GPUVendor:
		return true
	default:
		return false
	}
}
// IsJobType reports whether v holds a types.JobType.
func IsJobType(v interface{}) bool {
	switch v.(type) {
	case types.JobType:
		return true
	default:
		return false
	}
}
// IsJobTypes reports whether v holds a types.JobTypes.
func IsJobTypes(v interface{}) bool {
	switch v.(type) {
	case types.JobTypes:
		return true
	default:
		return false
	}
}
// IsExecutor reports whether v holds a types.Executor.
func IsExecutor(v interface{}) bool {
	switch v.(type) {
	case types.Executor:
		return true
	default:
		return false
	}
}
// IsStrictlyContained reports whether every element of rightSlice is also
// present in leftSlice. An empty rightSlice yields false.
func IsStrictlyContained(leftSlice, rightSlice []interface{}) bool {
	if len(rightSlice) == 0 {
		// Nothing to look for: the result defaults to false.
		return false
	}
	for _, wanted := range rightSlice {
		if !slices.Contains(leftSlice, wanted) {
			return false
		}
	}
	return true
}
// IsStrictlyContainedInt reports whether every element of rightSlice is also
// present in leftSlice. An empty rightSlice yields false.
func IsStrictlyContainedInt(leftSlice, rightSlice []int) bool {
	if len(rightSlice) == 0 {
		// Nothing to look for: the result defaults to false.
		return false
	}
	for _, wanted := range rightSlice {
		if !slices.Contains(leftSlice, wanted) {
			return false
		}
	}
	return true
}
// NoIntersectionSlices reports whether slice1 and slice2 have no element in
// common. It returns false as soon as any element of slice1 is found in
// slice2, and true when none is.
//
// An empty slice1 yields false, which is this function's default result.
func NoIntersectionSlices(slice1, slice2 []interface{}) bool {
	if len(slice1) == 0 {
		return false
	}
	for _, element := range slice1 {
		// One shared element is enough to decide; later elements must not
		// overwrite that verdict.
		if slices.Contains(slice2, element) {
			return false
		}
	}
	return true
}
// IntersectionSlices returns the elements that occur in both slice1 and
// slice2. Each common element appears once, in the order of its first
// occurrence in slice2. The result is never nil.
func IntersectionSlices(slice1, slice2 []interface{}) []interface{} {
	// Elements of slice1 that have not been matched yet.
	pending := make(map[interface{}]struct{}, len(slice1))
	for _, item := range slice1 {
		pending[item] = struct{}{}
	}

	common := make([]interface{}, 0)
	for _, item := range slice2 {
		if _, waiting := pending[item]; !waiting {
			continue
		}
		common = append(common, item)
		// Drop the element so a repeat in slice2 is not reported twice.
		delete(pending, item)
	}
	return common
}
// IsSameShallowType reports whether a and b have the same dynamic type, as
// given by reflect.TypeOf. Two nil interfaces compare as the same type.
func IsSameShallowType(a, b interface{}) bool {
	return reflect.TypeOf(a) == reflect.TypeOf(b)
}
// ConvertTypedSliceToUntypedSlice copies the elements of typedSlice into a new
// []interface{} using reflection. It returns nil when typedSlice is not a
// slice.
func ConvertTypedSliceToUntypedSlice(typedSlice interface{}) []interface{} {
	value := reflect.ValueOf(typedSlice)
	if value.Kind() == reflect.Slice {
		length := value.Len()
		untyped := make([]interface{}, 0, length)
		for i := 0; i < length; i++ {
			untyped = append(untyped, value.Index(i).Interface())
		}
		return untyped
	}
	return nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package validate
import (
"strings"
)
// IsBlank reports whether s is empty or consists solely of whitespace.
func IsBlank(s string) bool {
	return strings.TrimSpace(s) == ""
}
// IsNotBlank checks if a string is not empty and does not contain only whitespace.
// It is the negation of IsBlank.
func IsNotBlank(s string) bool {
return !IsBlank(s)
}
// IsLiteral reports whether the dynamic type of s is exactly string. Named
// types whose underlying type is string do not count.
func IsLiteral(s interface{}) bool {
	_, isString := s.(string)
	return isString
}