// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package actor
import (
"context"
"encoding/json"
"errors"
"fmt"
"sync"
"time"
"github.com/libp2p/go-libp2p/core/peer"
"gitlab.com/nunet/device-management-service/lib/crypto"
"gitlab.com/nunet/device-management-service/lib/did"
"gitlab.com/nunet/device-management-service/lib/ucan"
"gitlab.com/nunet/device-management-service/network"
"gitlab.com/nunet/device-management-service/types"
)
const (
HealthCheckBehavior = "/dms/actor/healthcheck"
HealthCheckGrantDuration = 2 * time.Hour
)
var HealthCheckInterval = 30 * time.Second
// BasicActor is the default Actor implementation. It couples a Dispatch
// reaction kernel with a network transport, a security context, and
// supervision/children bookkeeping.
type BasicActor struct {
	dispatch   *Dispatch
	registry   Registry
	network    network.Network
	security   SecurityContext
	supervisor Handle
	limiter    RateLimiter
	parent     Handle
	children   map[did.DID]Handle
	params     BasicActorParams
	self       Handle
	// mx guards security, children, parent, and subscriptions.
	mx            sync.Mutex
	subscriptions map[string]uint64 // topic -> network subscription ID
}

// BasicActorParams holds construction parameters; currently empty.
type BasicActorParams struct{}

// compile-time interface check
var _ Actor = (*BasicActor)(nil)
// New creates a new basic actor.
//
// It validates the network and security dependencies, builds the message
// dispatch (the rate limiter option is prepended so caller-supplied
// options can override it), and grants the supervisor the healthcheck
// capability for this actor.
func New(
	supervisor Handle,
	net network.Network,
	security *BasicSecurityContext,
	limiter RateLimiter,
	params BasicActorParams,
	self Handle,
	opt ...DispatchOption,
) (*BasicActor, error) {
	if net == nil {
		return nil, errors.New("network is nil")
	}
	if security == nil {
		return nil, errors.New("security is nil")
	}
	dispatchOptions := []DispatchOption{WithRateLimiter(limiter)}
	dispatchOptions = append(dispatchOptions, opt...)
	dispatch := NewDispatch(security, dispatchOptions...)
	actor := &BasicActor{
		dispatch:      dispatch,
		registry:      newRegistry(),
		network:       net,
		security:      security,
		limiter:       limiter,
		supervisor:    supervisor,
		params:        params,
		self:          self,
		subscriptions: make(map[string]uint64),
		children:      make(map[did.DID]Handle),
	}
	if err := actor.grantSupervisorCapabilities(supervisor); err != nil {
		return nil, fmt.Errorf("granting supervisor capabilities: %w", err)
	}
	return actor, nil
}
// grantSupervisorCapabilities grants the supervisor the healthcheck
// capability on this actor for HealthCheckGrantDuration, and anchors the
// resulting tokens in the actor's capability context. It is a no-op when
// there is no supervisor or the actor supervises itself.
func (a *BasicActor) grantSupervisorCapabilities(supervisor Handle) error {
	if supervisor.Empty() || supervisor.ID.Equal(a.self.ID) {
		return nil
	}
	actorDID, err := did.FromID(a.self.ID)
	if err != nil {
		return fmt.Errorf("actor did: %w", err)
	}
	// Expiry is an absolute UnixNano timestamp.
	expiry := time.Now().Add(HealthCheckGrantDuration)
	actorCap := a.security.Capability()
	tokens, err := actorCap.Grant(
		ucan.Delegate,
		supervisor.DID,
		actorDID,
		nil,
		uint64(expiry.UnixNano()),
		0,
		[]ucan.Capability{
			ucan.Capability(HealthCheckBehavior),
		},
	)
	if err != nil {
		return fmt.Errorf("error granting healthcheck capability to supervisor: %w", err)
	}
	// NOTE(review): tokens are added as require-anchors only (second
	// argument); confirm this matches the intended UCAN anchor semantics.
	if err := actorCap.AddRoots(nil, tokens, ucan.TokenList{}, ucan.TokenList{}); err != nil {
		return fmt.Errorf("error adding supervisor anchor: %w", err)
	}
	return nil
}
// Start registers the actor's inbound message handler on the network and
// launches the dispatch goroutines. It also starts a background goroutine
// that periodically re-grants the supervisor healthcheck capability so the
// grant does not lapse; the goroutine exits when the actor's context is
// cancelled (i.e. on Stop).
//
// Fix: the original goroutine used a bare select on time.After, so it
// re-granted exactly once and then exited — the supervisor's capability
// would expire after the second HealthCheckGrantDuration. The loop below
// keeps re-granting for the actor's whole lifetime.
func (a *BasicActor) Start() error {
	// Network messages
	if err := a.network.HandleMessage(
		fmt.Sprintf("actor/%s/messages/0.0.1", a.self.Address.InboxAddress),
		a.handleMessage,
	); err != nil {
		return fmt.Errorf("starting actor: %s: %w", a.self.ID, err)
	}
	// and start the internal goroutines
	a.dispatch.Start()
	go func() {
		ticker := time.NewTicker(HealthCheckGrantDuration)
		defer ticker.Stop()
		for {
			select {
			case <-ticker.C:
				if err := a.grantSupervisorCapabilities(a.supervisor); err != nil {
					log.Errorf("error granting supervisor capabilities: %s", err)
				}
			case <-a.Context().Done():
				return
			}
		}
	}()
	return nil
}
// handleMessage is the network callback for point-to-point messages. It
// decodes the envelope, drops messages not addressed to this actor or
// whose claimed sender host does not match the transport-level peer ID,
// applies the rate limiter, and forwards the rest to Receive.
func (a *BasicActor) handleMessage(data []byte, srcPeerID peer.ID) {
	var msg Envelope
	if err := json.Unmarshal(data, &msg); err != nil {
		log.Debugf("error unmarshaling message: %s", err)
		return
	}
	// Addressing check: the envelope must target this actor's ID.
	if !a.self.ID.Equal(msg.To.ID) {
		log.Warnf("message is not for ourselves: %s %s", a.self.ID, msg.To.ID)
		return
	}
	// Anti-spoofing: the sender's advertised host must be the peer that
	// actually delivered the message.
	if msg.From.Address.HostID != srcPeerID.String() {
		log.Warnf("message from %s not matching peer id %s", msg.From.Address.HostID, srcPeerID)
		return
	}
	if !a.limiter.Allow(msg) {
		log.Warnf("incoming message invoking %s not allowed by limiter", msg.Behavior)
		return
	}
	// NOTE(review): the Receive error is dropped; there is no reply path
	// for transport-level delivery failures — confirm this is intended.
	_ = a.Receive(msg)
}
// Context returns the dispatch context; it is cancelled when the actor stops.
func (a *BasicActor) Context() context.Context {
	return a.dispatch.Context()
}

// Handle returns this actor's own handle.
func (a *BasicActor) Handle() Handle {
	return a.self
}

// Supervisor returns the supervisor handle set at construction.
func (a *BasicActor) Supervisor() Handle {
	return a.supervisor
}

// Security returns the current security context.
// NOTE(review): this read is not guarded by a.mx although
// UpdateSecurityContext writes the field under the lock — confirm callers
// tolerate the race or guard this accessor as well.
func (a *BasicActor) Security() SecurityContext {
	return a.security
}
// UpdateSecurityContext swaps the actor's security context and propagates
// the change to the dispatch. A nil context is rejected with an error.
func (a *BasicActor) UpdateSecurityContext(newSecurity SecurityContext) error {
	if newSecurity == nil {
		return errors.New("new security context cannot be nil")
	}
	a.mx.Lock()
	defer a.mx.Unlock()
	// Swap the actor's context, then keep the dispatch in sync.
	a.security = newSecurity
	a.dispatch.UpdateSecurityContext(newSecurity)
	return nil
}
// AddBehavior registers a continuation for a behavior name; it delegates
// to the dispatch (see Dispatch.AddBehavior for option semantics).
func (a *BasicActor) AddBehavior(behavior string, continuation Behavior, opt ...BehaviorOption) error {
	return a.dispatch.AddBehavior(behavior, continuation, opt...)
}

// RemoveBehavior unregisters a behavior from the dispatch.
func (a *BasicActor) RemoveBehavior(behavior string) {
	a.dispatch.RemoveBehavior(behavior)
}
// Receive hands an inbound envelope to the dispatch. Only messages
// addressed to this actor, or broadcast messages, are accepted; anything
// else is rejected as an invalid message.
func (a *BasicActor) Receive(msg Envelope) error {
	if a.self.ID.Equal(msg.To.ID) || msg.IsBroadcast() {
		return a.dispatch.Receive(msg)
	}
	return fmt.Errorf("bad receiver: %w", ErrInvalidMessage)
}
// Send delivers an envelope to a remote actor over the network. Messages
// addressed to this actor itself are short-circuited through Receive.
// Unsigned envelopes are completed in place: a nonce is assigned if
// missing, and capability tokens for the behavior (plus a delegation for
// the reply behavior, if any) are provided and the envelope signed.
func (a *BasicActor) Send(msg Envelope) error {
	if msg.To.ID.Equal(a.self.ID) {
		return a.Receive(msg)
	}
	if msg.Signature == nil {
		if msg.Nonce == 0 {
			msg.Nonce = a.security.Nonce()
		}
		// The receiver must be able to invoke the behavior; if a reply is
		// expected, delegate the reply-to capability as well.
		invoke := []Capability{Capability(msg.Behavior)}
		var delegate []Capability
		if msg.Options.ReplyTo != "" {
			delegate = append(delegate, Capability(msg.Options.ReplyTo))
		}
		if err := a.security.Provide(&msg, invoke, delegate); err != nil {
			return fmt.Errorf("providing behavior capability for %s: %w", msg.Behavior, err)
		}
	}
	data, err := json.Marshal(msg)
	if err != nil {
		return fmt.Errorf("marshaling message: %w", err)
	}
	// Route to the destination actor's inbox-specific protocol.
	err = a.network.SendMessage(
		a.Context(),
		msg.To.Address.HostID,
		types.MessageEnvelope{
			Type: types.MessageType(
				fmt.Sprintf("actor/%s/messages/0.0.1", msg.To.Address.InboxAddress),
			),
			Data: data,
		},
		msg.Expiry(),
	)
	if err != nil {
		return fmt.Errorf("sending message to %s: %w", msg.To.ID, err)
	}
	return nil
}
// Invoke sends a message and returns a channel on which the single reply
// will be delivered. If the caller did not set ReplyTo, a unique
// nonce-based reply behavior name is generated. The reply behavior is
// registered one-shot with the message's expiry, and removed again if the
// send fails.
func (a *BasicActor) Invoke(msg Envelope) (<-chan Envelope, error) {
	if msg.Options.ReplyTo == "" {
		msg.Options.ReplyTo = fmt.Sprintf("/dms/actor/replyto/%d", a.security.Nonce())
	}
	// Buffered so the dispatch goroutine never blocks delivering the reply.
	result := make(chan Envelope, 1)
	if err := a.dispatch.AddBehavior(
		msg.Options.ReplyTo,
		func(reply Envelope) {
			result <- reply
			close(result)
		},
		WithBehaviorExpiry(msg.Options.Expire),
		WithBehaviorOneShot(true),
	); err != nil {
		return nil, fmt.Errorf("adding reply behavior: %w", err)
	}
	if err := a.Send(msg); err != nil {
		// Roll back the reply behavior so it cannot fire spuriously.
		a.dispatch.RemoveBehavior(msg.Options.ReplyTo)
		return nil, fmt.Errorf("sending message: %w", err)
	}
	return result, nil
}
// CreateChild creates a child actor addressed at inbox `id` on the same
// host, supervised by `super`. A fresh Ed25519 keypair is generated unless
// one is supplied via WithPrivKey. The child's security context is rooted
// in this actor's capability context, and the child is recorded in the
// registry and in this actor's children map. The child is returned
// unstarted.
func (a *BasicActor) CreateChild(
	id string,
	super Handle,
	opts ...CreateChildOption,
) (Actor, error) {
	// Create default options
	privk, _, err := crypto.GenerateKeyPair(crypto.Ed25519)
	if err != nil {
		return nil, fmt.Errorf("failed to create a child actor: %w", err)
	}
	options := &CreateChildOptions{
		PrivKey: privk,
	}
	// apply caller's options
	for _, opt := range opts {
		opt(options)
	}
	sctx, err := NewBasicSecurityContext(options.PrivKey.GetPublic(), options.PrivKey, a.security.Capability())
	if err != nil {
		return nil, fmt.Errorf("failed to create a child actor: %w", err)
	}
	child, err := New(
		super,
		a.network,
		sctx,
		a.limiter,
		BasicActorParams{},
		Handle{
			ID:  sctx.id,
			DID: sctx.DID(),
			Address: Address{
				HostID:       a.self.Address.HostID,
				InboxAddress: id,
			},
		},
	)
	if err != nil {
		return nil, fmt.Errorf("failed to create a child actor: %w", err)
	}
	if err := a.registry.Add(child.Handle(), a.self, nil); err != nil {
		return nil, fmt.Errorf("failed to add child to actor registry: %w", err)
	}
	// Setting child.parent directly is safe here: the child has not been
	// started or shared with other goroutines yet.
	a.mx.Lock()
	child.parent = a.Handle()
	a.children[child.Handle().DID] = child.Handle()
	a.mx.Unlock()
	return child, nil
}
// Parent returns the parent actor's handle (the zero Handle when the
// actor has no parent). The field is written only by the creating parent
// in CreateChild, before the child is returned.
func (a *BasicActor) Parent() Handle {
	return a.parent
}
// Children returns a snapshot copy of the children handles, keyed by DID.
// The copy protects callers from concurrent mutation by CreateChild.
func (a *BasicActor) Children() map[did.DID]Handle {
	a.mx.Lock()
	defer a.mx.Unlock()
	snapshot := make(map[did.DID]Handle, len(a.children))
	for key, h := range a.children {
		snapshot[key] = h
	}
	return snapshot
}
// Publish broadcasts an envelope to its topic. Non-broadcast envelopes
// are rejected. Unsigned envelopes are completed in place: a nonce is
// assigned if missing and broadcast capability tokens are provided and
// the envelope signed.
func (a *BasicActor) Publish(msg Envelope) error {
	if !msg.IsBroadcast() {
		return ErrInvalidMessage
	}
	if msg.Signature == nil {
		if msg.Nonce == 0 {
			msg.Nonce = a.security.Nonce()
		}
		broadcast := []Capability{Capability(msg.Behavior)}
		if err := a.security.ProvideBroadcast(&msg, msg.Options.Topic, broadcast); err != nil {
			return fmt.Errorf("providing behavior capability for %s: %w", msg.Behavior, err)
		}
	}
	data, err := json.Marshal(msg)
	if err != nil {
		return fmt.Errorf("marshaling message: %w", err)
	}
	if err := a.network.Publish(a.Context(), msg.Options.Topic, data); err != nil {
		return fmt.Errorf("publishing message: %w", err)
	}
	return nil
}
// Subscribe joins a broadcast topic, wiring validation and delivery to
// this actor. Subscribing twice to the same topic is a no-op. Setup hooks
// run after the network subscription succeeds; on hook failure the
// subscription is rolled back.
func (a *BasicActor) Subscribe(topic string, setup ...BroadcastSetup) error {
	a.mx.Lock()
	defer a.mx.Unlock()
	_, ok := a.subscriptions[topic]
	if ok {
		return nil
	}
	subID, err := a.network.Subscribe(
		a.Context(),
		topic,
		a.handleBroadcast,
		func(data []byte, validatorData interface{}) (network.ValidationResult, interface{}) {
			return a.validateBroadcast(topic, data, validatorData)
		},
	)
	if err != nil {
		return fmt.Errorf("subscribe: %w", err)
	}
	for _, f := range setup {
		if err := f(topic); err != nil {
			// best-effort rollback; the unsubscribe error is ignored
			_ = a.network.Unsubscribe(topic, subID)
			return fmt.Errorf("setup broadcast topic: %w", err)
		}
	}
	a.subscriptions[topic] = subID
	return nil
}
// validateBroadcast is the pubsub validator for a topic. It accepts a
// message only if it is a well-formed, unexpired broadcast envelope for
// this topic with a valid signature, and the limiter admits it.
// Previously validated messages (validatorData holding an Envelope) are
// short-circuited to accept.
func (a *BasicActor) validateBroadcast(topic string, data []byte, validatorData interface{}) (network.ValidationResult, interface{}) {
	var msg Envelope
	if validatorData != nil {
		if _, ok := validatorData.(Envelope); !ok {
			log.Warnf("bogus pubsub validation data: %v", validatorData)
			return network.ValidationReject, nil
		}
		// we have already validated the message, just short-circuit
		return network.ValidationAccept, validatorData
	} else if err := json.Unmarshal(data, &msg); err != nil {
		return network.ValidationReject, nil
	}
	if !msg.IsBroadcast() {
		return network.ValidationReject, nil
	}
	if msg.Options.Topic != topic {
		return network.ValidationReject, nil
	}
	// Expired and rate-limited messages are ignored (not propagated) but
	// do not penalize the sender the way a reject would.
	if msg.Expired() {
		return network.ValidationIgnore, nil
	}
	if err := a.security.Verify(msg); err != nil {
		return network.ValidationReject, nil
	}
	if !a.limiter.Allow(msg) {
		log.Warnf("incoming broadcast message in %s not allowed by limiter", topic)
		return network.ValidationIgnore, nil
	}
	return network.ValidationAccept, msg
}
// handleBroadcast delivers an already-validated pubsub message to the
// actor, skipping messages the actor broadcast itself.
func (a *BasicActor) handleBroadcast(data []byte) {
	var msg Envelope
	if err := json.Unmarshal(data, &msg); err != nil {
		log.Debugf("error unmarshaling broadcast message: %s", err)
		return
	}
	// don't receive message from self
	if msg.From.Equal(a.Handle()) {
		return
	}
	if err := a.Receive(msg); err != nil {
		log.Warnf("error receiving broadcast message: %s", err)
	}
}
// Stop shuts the actor down: it stops the dispatch, unsubscribes from all
// broadcast topics, and unregisters the inbound message handler. It
// always returns nil; unsubscribe failures are only logged.
//
// Fix: the original iterated a.subscriptions without holding a.mx, racing
// with Subscribe (which writes the map under the lock). The map is now
// drained into a snapshot under the lock before unsubscribing, so Stop is
// also idempotent with respect to subscriptions.
func (a *BasicActor) Stop() error {
	a.dispatch.Stop()
	// Snapshot and clear the subscriptions under the lock; do the network
	// calls outside it.
	a.mx.Lock()
	subs := make(map[string]uint64, len(a.subscriptions))
	for topic, subID := range a.subscriptions {
		subs[topic] = subID
		delete(a.subscriptions, topic)
	}
	a.mx.Unlock()
	for topic, subID := range subs {
		if err := a.network.Unsubscribe(topic, subID); err != nil {
			log.Debugf("error unsubscribing from %s: %s", topic, err)
		}
	}
	a.network.UnregisterMessageHandler(fmt.Sprintf("actor/%s/messages/0.0.1", a.self.Address.InboxAddress))
	return nil
}
// Limiter returns the actor's rate limiter.
func (a *BasicActor) Limiter() RateLimiter {
	return a.limiter
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package actor
import (
"context"
"fmt"
"sync"
"time"
"gitlab.com/nunet/device-management-service/observability"
)
var (
	// DefaultDispatchGCInterval is how often expired behaviors are swept.
	DefaultDispatchGCInterval = 120 * time.Second
	// DefaultDispatchWorkers is the default number of verification workers.
	DefaultDispatchWorkers = 1
)

// Dispatch provides a reaction kernel with multithreaded dispatch and oneshot
// continuations.
type Dispatch struct {
	ctx   context.Context
	close func() // cancels ctx; set by Start
	sctx  SecurityContext
	// mx guards behaviors, running, and sctx.
	mx        sync.Mutex
	q         chan Envelope // incoming message queue
	vq        chan Envelope // verified message queue
	behaviors map[string]*BehaviorState
	running   bool
	options   DispatchOptions
}

// DispatchOptions configures a Dispatch.
type DispatchOptions struct {
	Limiter    RateLimiter
	GCInterval time.Duration
	Workers    int
}

// BehaviorState pairs a registered continuation with its options.
type BehaviorState struct {
	cont Behavior
	opt  BehaviorOptions
}
// DispatchOption customizes a Dispatch at construction time.
type DispatchOption func(o *DispatchOptions)

// WithDispatchWorkers sets the number of verification worker goroutines.
func WithDispatchWorkers(count int) DispatchOption {
	return func(opts *DispatchOptions) {
		opts.Workers = count
	}
}

// WithDispatchGCInterval sets how often expired behaviors are swept.
func WithDispatchGCInterval(dt time.Duration) DispatchOption {
	return func(opts *DispatchOptions) {
		opts.GCInterval = dt
	}
}

// WithRateLimiter sets the rate limiter used by the dispatch.
func WithRateLimiter(limiter RateLimiter) DispatchOption {
	return func(opts *DispatchOptions) {
		opts.Limiter = limiter
	}
}
// NewDispatch creates a dispatch kernel with the given security context.
// Defaults (GC interval, one worker, no rate limiting) apply unless
// overridden by options. The dispatch is not running until Start.
func NewDispatch(sctx SecurityContext, opt ...DispatchOption) *Dispatch {
	k := &Dispatch{
		sctx:      sctx,
		q:         make(chan Envelope, 100),
		vq:        make(chan Envelope, 100),
		behaviors: make(map[string]*BehaviorState),
		options: DispatchOptions{
			GCInterval: DefaultDispatchGCInterval,
			Workers:    DefaultDispatchWorkers,
			Limiter:    NoRateLimiter{},
		},
	}
	for _, f := range opt {
		f(&k.options)
	}
	return k
}
// Start launches the dispatch goroutines: Workers verification workers,
// one dispatcher, and the GC sweeper. It is idempotent while running; a
// stopped dispatch can be restarted with a fresh context.
func (k *Dispatch) Start() {
	k.mx.Lock()
	defer k.mx.Unlock()
	if !k.running {
		k.ctx, k.close = context.WithCancel(context.Background())
		for i := 0; i < k.options.Workers; i++ {
			go k.recv()
		}
		go k.dispatch()
		go k.gc()
		k.running = true
	}
}
// Stop cancels the dispatch context, terminating all of its goroutines.
// It is a no-op when the dispatch is not running.
func (k *Dispatch) Stop() {
	k.mx.Lock()
	defer k.mx.Unlock()
	if !k.running {
		return
	}
	k.close()
	k.running = false
}
// AddBehavior registers a continuation under a behavior name. By default
// the behavior requires a capability equal to its own name; options may
// change the required capabilities, expiry, one-shot semantics, or topic.
// Re-registering a name replaces the previous behavior.
func (k *Dispatch) AddBehavior(behavior string, continuation Behavior, opt ...BehaviorOption) error {
	st := &BehaviorState{
		cont: continuation,
		opt: BehaviorOptions{
			Capability: []Capability{Capability(behavior)},
		},
	}
	for _, f := range opt {
		if err := f(&st.opt); err != nil {
			return fmt.Errorf("adding behavior: %w", err)
		}
	}
	log.Debugw("registered_behaviour", "labels", []string{string(observability.LabelNode)}, "behavior", behavior)
	k.mx.Lock()
	defer k.mx.Unlock()
	k.behaviors[behavior] = st
	return nil
}

// RemoveBehavior unregisters a behavior; unknown names are a no-op.
func (k *Dispatch) RemoveBehavior(behavior string) {
	k.mx.Lock()
	defer k.mx.Unlock()
	delete(k.behaviors, behavior)
}
// Receive enqueues an envelope for verification and dispatch. It never
// blocks: when the incoming queue is full the message is rejected with an
// error, and when the dispatch has stopped the context error is returned.
// NOTE(review): k.ctx is nil until Start has been called — calling
// Receive on an unstarted dispatch panics; confirm callers always Start
// first.
func (k *Dispatch) Receive(msg Envelope) error {
	select {
	case k.q <- msg:
		// check if closed to avoid cancellation/select race; the message
		// may stay queued but will never be dispatched after cancel.
		if k.ctx.Err() != nil {
			log.Debugf("context closed, dropping message from %s", msg.From)
			return k.ctx.Err()
		}
		return nil
	case <-k.ctx.Done():
		return k.ctx.Err()
	default:
		return fmt.Errorf("k.queue full")
	}
}
// Context returns the dispatch's lifecycle context (nil before Start).
func (k *Dispatch) Context() context.Context {
	return k.ctx
}

// UpdateSecurityContext updates the security context used by the dispatch.
// This should be called when the underlying capability context changes.
func (k *Dispatch) UpdateSecurityContext(newSecurity SecurityContext) {
	k.mx.Lock()
	defer k.mx.Unlock()
	k.sctx = newSecurity
}
// recv is the verification worker loop: it pulls messages off the raw
// queue, verifies their signatures via the security context, and forwards
// valid messages to the verified queue, dropping them when it is full.
// NOTE(review): k.sctx is read here without holding k.mx while
// UpdateSecurityContext writes it under the lock — confirm this race is
// acceptable.
func (k *Dispatch) recv() {
	for {
		select {
		case msg := <-k.q:
			if err := k.sctx.Verify(msg); err != nil {
				log.Debugf("failed to verify message from %s: %s", msg.From, err)
				continue
			}
			select {
			case k.vq <- msg:
				// ok
			default:
				log.Errorf("k.vq full, dropping message from %s", msg.From)
			}
		case <-k.ctx.Done():
			return
		}
	}
}
// dispatch is the main reaction loop. For each verified message it looks
// up the behavior, enforces expiry and capability requirements, applies
// the limiter, and runs the continuation in its own goroutine. One-shot
// behaviors are removed before the continuation runs. The lock is held
// only for the behavior-table operations, never across the limiter or
// the continuation.
func (k *Dispatch) dispatch() {
	for {
		select {
		case msg := <-k.vq:
			k.mx.Lock()
			b, ok := k.behaviors[msg.Behavior]
			if !ok {
				k.mx.Unlock()
				log.Debugf("unknown behavior %s", msg.Behavior)
				continue
			}
			if b.Expired(time.Now()) {
				// Expired behaviors are swept lazily here as well as by gc.
				delete(k.behaviors, msg.Behavior)
				k.mx.Unlock()
				log.Debugf("expired behavior %s", msg.Behavior)
				continue
			}
			// Capability checks: broadcast messages are checked against
			// the behavior's topic, point-to-point against its capability
			// list.
			if msg.IsBroadcast() {
				if err := k.sctx.RequireBroadcast(msg, b.opt.Topic, b.opt.Capability); err != nil {
					k.mx.Unlock()
					// TODO Warnw
					log.Debugf("broadcast message from %s does not have the required capability %s %s: %s", msg.From, b.opt.Capability, string(msg.Capability), err)
					continue
				}
			} else if err := k.sctx.Require(msg, b.opt.Capability); err != nil {
				k.mx.Unlock()
				log.Debugw("message does not have the required capability",
					"from", msg.From,
					"capabilities", b.opt.Capability,
					"error", err,
					"labels", string(observability.LabelNode),
				)
				// TODO? msg without an answer
				continue
			}
			if b.opt.OneShot {
				// Remove before running so the continuation fires at most once.
				delete(k.behaviors, msg.Behavior)
			}
			k.mx.Unlock()
			if err := k.options.Limiter.Acquire(msg); err != nil {
				k.sctx.Discard(msg)
				log.Debugf("limiter rejected message from %s: %s", msg.From, err)
				continue
			}
			// Let the continuation release consumed capability tokens.
			msg.Discard = func() {
				k.sctx.Discard(msg)
			}
			log.Debugw("dispatching_message", "labels", string(observability.LabelNode),
				"msg_from", msg.From,
				"behavior", msg.Behavior,
				"broadcast", msg.IsBroadcast(),
			)
			go func() {
				defer k.options.Limiter.Release(msg)
				endSpan := observability.StartSpan("Dispatch: "+msg.Behavior, "FromDID", msg.From.DID)
				defer endSpan()
				// exec the behavior's handler
				b.cont(msg)
			}()
		case <-k.ctx.Done():
			return
		}
	}
}
// gc periodically sweeps expired behaviors from the registry until the
// dispatch context is cancelled.
func (k *Dispatch) gc() {
	ticker := time.NewTicker(k.options.GCInterval)
	defer ticker.Stop()
	for {
		select {
		case <-k.ctx.Done():
			return
		case <-ticker.C:
			now := time.Now()
			k.mx.Lock()
			for name, st := range k.behaviors {
				if st.Expired(now) {
					delete(k.behaviors, name)
				}
			}
			k.mx.Unlock()
		}
	}
}
// Expired reports whether the behavior's expiry has passed at the given
// instant. Expire is an absolute UnixNano timestamp; zero means the
// behavior never expires.
func (b *BehaviorState) Expired(now time.Time) bool {
	return b.opt.Expire > 0 && uint64(now.UnixNano()) > b.opt.Expire
}
// WithBehaviorExpiry sets an absolute expiry (UnixNano) for the behavior.
func WithBehaviorExpiry(expire uint64) BehaviorOption {
	return func(options *BehaviorOptions) error {
		options.Expire = expire
		return nil
	}
}

// WithBehaviorCapability replaces the behavior's required capabilities.
func WithBehaviorCapability(require ...Capability) BehaviorOption {
	return func(options *BehaviorOptions) error {
		options.Capability = require
		return nil
	}
}

// WithBehaviorOneShot marks the behavior to fire at most once.
func WithBehaviorOneShot(oneShot bool) BehaviorOption {
	return func(options *BehaviorOptions) error {
		options.OneShot = oneShot
		return nil
	}
}

// WithBehaviorTopic binds the behavior to a broadcast topic.
func WithBehaviorTopic(topic string) BehaviorOption {
	return func(options *BehaviorOptions) error {
		options.Topic = topic
		return nil
	}
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package actor
import (
"fmt"
"github.com/libp2p/go-libp2p/core/peer"
"gitlab.com/nunet/device-management-service/lib/crypto"
"gitlab.com/nunet/device-management-service/lib/did"
)
// Empty reports whether the handle carries no ID, no DID, and no address.
func (h *Handle) Empty() bool {
	switch {
	case !h.ID.Empty():
		return false
	case !h.DID.Empty():
		return false
	default:
		return h.Address.Empty()
	}
}
// String renders the handle as "<id-as-did>[<did>]@<address>"; the ID
// part is left blank when the ID cannot be converted to a DID.
func (h *Handle) String() string {
	idStr := ""
	if idDID, err := did.FromID(h.ID); err == nil {
		idStr = idDID.String()
	}
	return fmt.Sprintf("%s[%s]@%s", idStr, h.DID, h.Address)
}
// Equal reports whether two handles name the same actor: same ID, same
// DID, and the same host/inbox address.
func (h Handle) Equal(other Handle) bool {
	return h.ID.Equal(other.ID) &&
		h.DID.Equal(other.DID) &&
		h.Address.HostID == other.Address.HostID &&
		h.Address.InboxAddress == other.Address.InboxAddress
}
// HandleFromString parses a Handle from its string representation.
// Not yet implemented; always returns ErrTODO.
func HandleFromString(_ string) (Handle, error) {
	// TODO
	return Handle{}, ErrTODO
}

// Empty reports whether the address has neither a host nor an inbox.
func (a *Address) Empty() bool {
	return a.HostID == "" && a.InboxAddress == ""
}

// String renders the address as "<host>:<inbox>".
func (a *Address) String() string {
	return a.HostID + ":" + a.InboxAddress
}

// AddressFromString parses an Address from its string representation.
// Not yet implemented; always returns ErrTODO.
func AddressFromString(_ string) (Address, error) {
	// TODO
	return Address{}, ErrTODO
}
// HandleFromPeerID builds a root Handle from a libp2p peer ID string.
// It fails when the peer ID cannot be decoded or when the public key
// cannot be extracted from it.
func HandleFromPeerID(dest string) (Handle, error) {
	peerID, err := peer.Decode(dest)
	if err != nil {
		return Handle{}, err
	}
	pubk, err := peerID.ExtractPublicKey()
	if err != nil {
		return Handle{}, err
	}
	return handleFromPublicKey(pubk)
}

// HandleFromDID builds a root Handle from a DID string by recovering the
// public key embedded in the DID.
func HandleFromDID(dest string) (Handle, error) {
	actorDID, err := did.FromString(dest)
	if err != nil {
		return Handle{}, err
	}
	pubk, err := did.PublicKeyFromDID(actorDID)
	if err != nil {
		return Handle{}, err
	}
	return handleFromPublicKey(pubk)
}
// handleFromPublicKey converts a verified public key into the canonical Handle.
// All common logic for HandleFromPeerID / HandleFromDID lives here.
// The resulting handle is addressed at the "root" inbox of the peer
// derived from the key; disallowed key types are rejected.
func handleFromPublicKey(pubk crypto.PubKey) (Handle, error) {
	if !crypto.AllowedKey(int(pubk.Type())) {
		return Handle{}, fmt.Errorf("unexpected key type: %d", pubk.Type())
	}
	actorID, err := crypto.IDFromPublicKey(pubk)
	if err != nil {
		return Handle{}, err
	}
	actorDID := did.FromPublicKey(pubk)
	peerID, err := peer.IDFromPublicKey(pubk)
	if err != nil {
		return Handle{}, err
	}
	return Handle{
		ID:  actorID,
		DID: actorDID,
		Address: Address{
			HostID:       peerID.String(),
			InboxAddress: "root",
		},
	}, nil
}
// HandleFromPublicKeyWithInboxAddress converts a verified public key into the canonical Handle.
// Unlike handleFromPublicKey, the caller supplies both the host ID and
// the inbox address the handle should route to; disallowed key types are
// rejected.
func HandleFromPublicKeyWithInboxAddress(pubk crypto.PubKey, inboxAddress, host string) (Handle, error) {
	if !crypto.AllowedKey(int(pubk.Type())) {
		return Handle{}, fmt.Errorf("unexpected key type: %d", pubk.Type())
	}
	actorID, err := crypto.IDFromPublicKey(pubk)
	if err != nil {
		return Handle{}, err
	}
	actorDID := did.FromPublicKey(pubk)
	return Handle{
		ID:  actorID,
		DID: actorDID,
		Address: Address{
			HostID:       host,
			InboxAddress: inboxAddress,
		},
	}, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package actor
import (
"context"
"time"
"gitlab.com/nunet/device-management-service/lib/crypto"
"gitlab.com/nunet/device-management-service/lib/did"
"gitlab.com/nunet/device-management-service/lib/ucan"
)
type (
	// ID is an actor's cryptographic identity.
	ID = crypto.ID
	// DID is a decentralized identifier.
	DID = did.DID
	// Capability is a UCAN capability string.
	Capability = ucan.Capability
)

// Handle names an actor reachable in the network: its crypto ID, DID,
// and transport address.
type Handle struct {
	ID      ID      `json:"id"`
	DID     DID     `json:"did"`
	Address Address `json:"addr"`
}

// Address is a raw actor address representation: the host peer and the
// inbox on that host.
type Address struct {
	HostID       string `json:"host,omitempty"`
	InboxAddress string `json:"inbox,omitempty"`
}

// Envelope is the envelope for messages in the actor system
type Envelope struct {
	To Handle `json:"to"`
	// Behavior names the continuation to invoke at the destination.
	Behavior string `json:"be"`
	From     Handle `json:"from"`
	// Nonce makes each signed envelope unique.
	Nonce   uint64          `json:"nonce"`
	Options EnvelopeOptions `json:"opt"`
	Message []byte          `json:"msg"`
	// Capability carries the serialized capability token(s) authorizing
	// the invocation.
	Capability []byte `json:"cap,omitempty"`
	Signature  []byte `json:"sig,omitempty"`
	// Discard releases capability tokens consumed by this envelope; set
	// by the dispatch, never serialized.
	Discard func() `json:"-"`
}

// EnvelopeOptions are sender specified options for processing an envelope
type EnvelopeOptions struct {
	// Expire is an absolute expiry timestamp; 0 means no expiry.
	// NOTE(review): assumed UnixNano, mirroring BehaviorState.Expired —
	// confirm.
	Expire  uint64 `json:"exp"`
	ReplyTo string `json:"cont,omitempty"`
	Topic   string `json:"topic,omitempty"`
}
// Actor is the local interface to the actor system
type Actor interface {
	// Context returns the actor's lifecycle context.
	Context() context.Context
	// Handle returns the actor's own handle.
	Handle() Handle
	// Security returns the actor's security context.
	Security() SecurityContext
	// Supervisor returns the supervising actor's handle.
	Supervisor() Handle
	// AddBehavior registers a continuation for a behavior name.
	AddBehavior(behavior string, continuation Behavior, opt ...BehaviorOption) error
	// RemoveBehavior unregisters a behavior.
	RemoveBehavior(behavior string)
	// Receive delivers an inbound envelope to this actor.
	Receive(msg Envelope) error
	// Send delivers an envelope to another actor.
	Send(msg Envelope) error
	// Invoke sends a message and returns a channel for the reply.
	Invoke(msg Envelope) (<-chan Envelope, error)
	// Publish broadcasts an envelope to its topic.
	Publish(msg Envelope) error
	// Subscribe joins a broadcast topic.
	Subscribe(topic string, setup ...BroadcastSetup) error
	// Start begins message processing.
	Start() error
	// Stop terminates the actor.
	Stop() error
	// TODO: add child termination strategies
	// e.g.: childSelfRelease which relies on a func `f` to self release and terminate
	// the child actor
	CreateChild(id string, super Handle, opts ...CreateChildOption) (Actor, error)
	// Parent returns the parent actor's handle.
	Parent() Handle
	// Children returns the children actors' handles keyed by DID.
	Children() map[did.DID]Handle
	// Limiter returns the actor's rate limiter.
	Limiter() RateLimiter
	// UpdateSecurityContext updates the actor's security context.
	// This is used when the underlying capability context changes (e.g., after reloading from disk).
	UpdateSecurityContext(newSecurity SecurityContext) error
}
// SecurityContext provides a context in which to perform cryptographic
// operations for an actor.
// This includes:
// - signing messages
// - verifying message signatures
// - requiring capabilities
// - granting capabilities
type SecurityContext interface {
	ID() ID
	DID() DID
	Nonce() uint64
	PrivKey() crypto.PrivKey
	// Require checks the capability token(s).
	// It succeeds if and only if
	// - the signature is valid
	// - the capability token(s) in the envelope grants the origin actor ID/DID
	//   any of the specified capabilities.
	Require(msg Envelope, invoke []Capability) error
	// Provide populates the envelope with necessary capability tokens and signs it.
	// the envelope is modified in place
	Provide(msg *Envelope, invoke []Capability, delegate []Capability) error
	// RequireBroadcast verifies the envelope and checks the capability tokens
	// for a broadcast topic
	RequireBroadcast(msg Envelope, topic string, broadcast []Capability) error
	// ProvideBroadcast populates the envelope with the necessary capability tokens
	// for broadcast in the topic and signs it
	ProvideBroadcast(msg *Envelope, topic string, broadcast []Capability) error
	// Verify verifies the message signature in an envelope
	Verify(msg Envelope) error
	// Sign signs an envelope; the envelope is modified in place.
	Sign(msg *Envelope) error
	// Grant grants the specified capabilities to the specified audience.
	//
	// Useful for granting capabilities between actors without sending
	// tokens to each other.
	Grant(sub, aud did.DID, caps []ucan.Capability, expiry time.Duration) error
	// Discard discards unwanted tokens from a consumed envelope
	Discard(msg Envelope)
	// Capability returns the underlying capability context.
	Capability() ucan.CapabilityContext
}
// RateLimiter implements a stateful resource access limiter
// This is necessary to combat spam attacks and ensure that our system does not
// become overloaded with too many goroutines.
type RateLimiter interface {
	// Allow reports whether a message may be admitted at all.
	Allow(msg Envelope) bool
	// Acquire reserves processing capacity for a message; pair with Release.
	Acquire(msg Envelope) error
	// Release returns the capacity reserved by Acquire.
	Release(msg Envelope)
	// Config returns the current limiter configuration.
	Config() RateLimiterConfig
	// SetConfig replaces the limiter configuration.
	SetConfig(cfg RateLimiterConfig)
}

// RateLimiterConfig holds the limiter's admission thresholds.
type RateLimiterConfig struct {
	PublicLimitAllow      int
	PublicLimitAcquire    int
	BroadcastLimitAllow   int
	BroadcastLimitAcquire int
	// TopicDefaultLimit applies to topics without an explicit TopicLimit.
	TopicDefaultLimit int
	TopicLimit        map[string]int
}

type (
	// Behavior is a message continuation invoked by the dispatch.
	Behavior func(msg Envelope)
	// MessageOption mutates an envelope before it is sent.
	MessageOption func(msg *Envelope) error
)

// BehaviorOption configures a behavior at registration time.
type BehaviorOption func(opt *BehaviorOptions) error

// BehaviorOptions holds per-behavior settings.
type BehaviorOptions struct {
	Capability []Capability
	Expire     uint64
	OneShot    bool
	Topic      string
}

// BroadcastSetup is a per-topic hook run when subscribing.
type BroadcastSetup func(topic string) error
// CreateChildOption customizes child actor creation.
type CreateChildOption func(*CreateChildOptions)

// CreateChildOptions holds child-creation settings.
type CreateChildOptions struct {
	// PrivKey is the child's identity key; when unset, CreateChild
	// generates a fresh Ed25519 key.
	PrivKey crypto.PrivKey
}

// WithPrivKey sets a specific private key for the child actor
func WithPrivKey(privKey crypto.PrivKey) CreateChildOption {
	return func(o *CreateChildOptions) {
		o.PrivKey = privKey
	}
}
// CapabilitiesJoin renders a capability list as a single string, each
// capability followed by one space (including a trailing space after the
// last element, matching the historical format).
func CapabilitiesJoin(caps []Capability) string {
	var buf []byte
	for _, c := range caps {
		buf = append(buf, string(c)...)
		buf = append(buf, ' ')
	}
	return string(buf)
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package actor
import (
"strings"
"sync"
)
// NoRateLimiter is the null limiter, that does not rate limit
type NoRateLimiter struct{}

var _ RateLimiter = NoRateLimiter{}

// BasicRateLimiter enforces the thresholds in RateLimiterConfig using
// simple mutex-guarded counters.
type BasicRateLimiter struct {
	cfg RateLimiterConfig
	// mx guards all counters below and cfg.
	mx              sync.Mutex
	activeBroadcast int
	activeTopics    map[string]int
	activePublic    int
}

var _ RateLimiter = (*BasicRateLimiter)(nil)
// NoRateLimiter implementation: every operation is a no-op that admits
// all messages.
func (l NoRateLimiter) Allow(_ Envelope) bool          { return true }
func (l NoRateLimiter) Acquire(_ Envelope) error       { return nil }
func (l NoRateLimiter) Release(_ Envelope)             {}
func (l NoRateLimiter) Config() RateLimiterConfig      { return RateLimiterConfig{} }
func (l NoRateLimiter) SetConfig(_ RateLimiterConfig)  {}
// DefaultRateLimiterConfig returns the stock limiter thresholds; each
// Acquire limit is slightly above its Allow limit to absorb in-flight
// messages admitted before acquisition.
func DefaultRateLimiterConfig() RateLimiterConfig {
	return RateLimiterConfig{
		PublicLimitAllow:      4096,
		PublicLimitAcquire:    4112,
		BroadcastLimitAllow:   1024,
		BroadcastLimitAcquire: 1040,
		TopicDefaultLimit:     128,
	}
}
// Valid reports whether the configuration is usable: all limits positive
// and each Acquire limit at least as large as its Allow limit.
func (cfg *RateLimiterConfig) Valid() bool {
	switch {
	case cfg.PublicLimitAllow <= 0:
		return false
	case cfg.PublicLimitAcquire < cfg.PublicLimitAllow:
		return false
	case cfg.BroadcastLimitAllow <= 0:
		return false
	case cfg.BroadcastLimitAcquire < cfg.BroadcastLimitAllow:
		return false
	default:
		return cfg.TopicDefaultLimit > 0
	}
}
// NewRateLimiter creates a BasicRateLimiter with the given configuration.
// The configuration is not validated here; see RateLimiterConfig.Valid.
func NewRateLimiter(cfg RateLimiterConfig) RateLimiter {
	return &BasicRateLimiter{
		cfg:          cfg,
		activeTopics: make(map[string]int),
	}
}
// Allow reports whether a message may be admitted: broadcast messages are
// checked against broadcast/topic limits, public behaviors against the
// public limit, and everything else is always admitted.
func (l *BasicRateLimiter) Allow(msg Envelope) bool {
	switch {
	case msg.IsBroadcast():
		return l.allowBroadcast(msg)
	case isPublicBehavior(msg):
		return l.allowPublic()
	default:
		return true
	}
}
// allowPublic admits a public-behavior message while the active count is
// below the Allow threshold.
func (l *BasicRateLimiter) allowPublic() bool {
	l.mx.Lock()
	defer l.mx.Unlock()
	return l.activePublic < l.cfg.PublicLimitAllow
}
// allowBroadcast admits a broadcast message when both the global
// broadcast count and the per-topic count are below their Allow limits;
// topics without an explicit limit use the default.
func (l *BasicRateLimiter) allowBroadcast(msg Envelope) bool {
	l.mx.Lock()
	defer l.mx.Unlock()
	if l.activeBroadcast >= l.cfg.BroadcastLimitAllow {
		return false
	}
	topic := msg.Options.Topic
	limit, explicit := l.cfg.TopicLimit[topic]
	if !explicit {
		limit = l.cfg.TopicDefaultLimit
	}
	return l.activeTopics[topic] < limit
}
// Acquire reserves capacity for processing a message; the reservation
// must be returned with Release. Non-broadcast, non-public messages are
// never limited.
func (l *BasicRateLimiter) Acquire(msg Envelope) error {
	switch {
	case msg.IsBroadcast():
		return l.acquireBroadcast(msg)
	case isPublicBehavior(msg):
		return l.acquirePublic()
	default:
		return nil
	}
}
// acquirePublic reserves one public-behavior slot, failing when the
// Acquire threshold is reached.
func (l *BasicRateLimiter) acquirePublic() error {
	l.mx.Lock()
	defer l.mx.Unlock()
	if l.activePublic >= l.cfg.PublicLimitAcquire {
		return ErrRateLimitExceeded
	}
	l.activePublic++
	return nil
}
// acquireBroadcast reserves one broadcast slot, enforcing both the global
// Acquire threshold and the per-topic limit (explicit or default).
func (l *BasicRateLimiter) acquireBroadcast(msg Envelope) error {
	l.mx.Lock()
	defer l.mx.Unlock()
	if l.activeBroadcast >= l.cfg.BroadcastLimitAcquire {
		return ErrRateLimitExceeded
	}
	topic := msg.Options.Topic
	limit, explicit := l.cfg.TopicLimit[topic]
	if !explicit {
		limit = l.cfg.TopicDefaultLimit
	}
	if l.activeTopics[topic] >= limit {
		return ErrRateLimitExceeded
	}
	l.activeTopics[topic]++
	l.activeBroadcast++
	return nil
}
// Release returns the capacity reserved by Acquire for the given message.
func (l *BasicRateLimiter) Release(msg Envelope) {
	switch {
	case msg.IsBroadcast():
		l.releaseBroadcast(msg)
	case isPublicBehavior(msg):
		l.releasePublic()
	}
}
// releasePublic returns one public-behavior slot.
//
// Fix: the original decremented unconditionally, so an unmatched Release
// could drive the counter negative and permanently inflate the effective
// limit; releaseBroadcast already guards against unmatched releases, and
// this now does too.
func (l *BasicRateLimiter) releasePublic() {
	l.mx.Lock()
	defer l.mx.Unlock()
	if l.activePublic > 0 {
		l.activePublic--
	}
}
// releaseBroadcast returns the capacity taken by acquireBroadcast. An
// unmatched release (topic has no active count) is ignored entirely,
// leaving activeBroadcast untouched as well; the topic entry is removed
// when its count drops to zero.
func (l *BasicRateLimiter) releaseBroadcast(msg Envelope) {
	l.mx.Lock()
	defer l.mx.Unlock()
	topic := msg.Options.Topic
	active, ok := l.activeTopics[topic]
	if !ok {
		return
	}
	active--
	if active > 0 {
		l.activeTopics[topic] = active
	} else {
		delete(l.activeTopics, topic)
	}
	l.activeBroadcast--
}
// Config returns the current configuration. Note this is a shallow copy:
// the TopicLimit map is shared with the limiter.
func (l *BasicRateLimiter) Config() RateLimiterConfig {
	l.mx.Lock()
	defer l.mx.Unlock()
	return l.cfg
}

// SetConfig replaces the limiter configuration; active counters are kept.
func (l *BasicRateLimiter) SetConfig(cfg RateLimiterConfig) {
	l.mx.Lock()
	defer l.mx.Unlock()
	l.cfg = cfg
}

// isPublicBehavior reports whether the message invokes a public behavior,
// i.e. one whose name is under the "/public/" prefix.
func isPublicBehavior(msg Envelope) bool {
	return strings.HasPrefix(msg.Behavior, "/public/")
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package actor
import (
"context"
"encoding/json"
"errors"
"fmt"
"sync"
"testing"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/stretchr/testify/require"
"gitlab.com/nunet/device-management-service/lib/crypto"
"gitlab.com/nunet/device-management-service/lib/did"
"gitlab.com/nunet/device-management-service/network"
"gitlab.com/nunet/device-management-service/types"
)
// MockActor is a simplified actor implementation that uses virtualNet for communications
// but without the dispatch and registry complexity of BasicActor
// MockActor is a simplified actor implementation that uses virtualNet for communications
// but without the dispatch and registry complexity of BasicActor.
type MockActor struct {
	network    network.Network // transport used for sends, pubsub, and inbox handlers
	security   SecurityContext // signing, verification, and capability checks
	supervisor Handle          // supervising actor's handle
	limiter    RateLimiter     // used by validateBroadcast to throttle broadcasts
	parent     Handle          // set on children created via CreateChild
	children   map[did.DID]Handle
	self       Handle // this actor's identity and inbox address
	mx         sync.Mutex // guards behaviors, subscriptions, children, and security
	behaviors     map[string]Behavior // behavior name -> continuation
	subscriptions map[string]uint64   // topic -> network subscription ID
}
// NewMockActor creates a new mock actor
// NewMockActor creates a new mock actor backed by the given network and
// security context; both are mandatory.
func NewMockActor(
	supervisor Handle,
	net network.Network,
	security SecurityContext,
	limiter RateLimiter,
	self Handle,
) (*MockActor, error) {
	switch {
	case net == nil:
		return nil, errors.New("network is nil")
	case security == nil:
		return nil, errors.New("security is nil")
	}
	m := &MockActor{
		network:       net,
		security:      security,
		limiter:       limiter,
		supervisor:    supervisor,
		self:          self,
		behaviors:     map[string]Behavior{},
		subscriptions: map[string]uint64{},
		children:      map[did.DID]Handle{},
	}
	return m, nil
}
// Context returns the actor's base context; the mock always uses Background.
func (a *MockActor) Context() context.Context {
	return context.Background()
}

// Handle returns this actor's own handle (identity + inbox address).
func (a *MockActor) Handle() Handle {
	return a.self
}

// Supervisor returns the handle of this actor's supervisor.
func (a *MockActor) Supervisor() Handle {
	return a.supervisor
}
// Security returns the actor's security context.
//
// The read is guarded by the actor mutex because UpdateSecurityContext
// replaces the field under the same lock; an unguarded read here is a data
// race under `go test -race`.
func (a *MockActor) Security() SecurityContext {
	a.mx.Lock()
	defer a.mx.Unlock()
	return a.security
}
// UpdateSecurityContext swaps in a new security context; nil is rejected.
func (a *MockActor) UpdateSecurityContext(newSecurity SecurityContext) error {
	if newSecurity == nil {
		return errors.New("new security context cannot be nil")
	}
	a.mx.Lock()
	a.security = newSecurity
	a.mx.Unlock()
	return nil
}
// AddBehavior registers a continuation for the named behavior.
//
// Options are applied to a scratch BehaviorState so invalid options still
// surface an error, but — unlike BasicActor — the resulting state is not
// retained: the mock stores only the continuation itself.
func (a *MockActor) AddBehavior(behavior string, continuation Behavior, opt ...BehaviorOption) error {
	st := &BehaviorState{
		cont: continuation,
		opt: BehaviorOptions{
			// by default the behavior name doubles as its required capability
			Capability: []Capability{Capability(behavior)},
		},
	}
	for _, f := range opt {
		if err := f(&st.opt); err != nil {
			return fmt.Errorf("adding behavior: %w", err)
		}
	}
	a.mx.Lock()
	defer a.mx.Unlock()
	a.behaviors[behavior] = continuation
	return nil
}
// RemoveBehavior unregisters the named behavior; unknown names are a no-op.
func (a *MockActor) RemoveBehavior(behavior string) {
	a.mx.Lock()
	delete(a.behaviors, behavior)
	a.mx.Unlock()
}
// Receive dispatches an envelope to its registered behavior, failing when no
// behavior matches. The envelope is rebuilt field by field so the delivered
// copy carries a safe no-op Discard.
func (a *MockActor) Receive(msg Envelope) error {
	a.mx.Lock()
	cont, found := a.behaviors[msg.Behavior]
	a.mx.Unlock()
	if !found {
		return fmt.Errorf("no behavior registered for %s", msg.Behavior)
	}
	delivered := Envelope{
		From:       msg.From,
		To:         msg.To,
		Message:    msg.Message,
		Behavior:   msg.Behavior,
		Options:    msg.Options,
		Signature:  msg.Signature,
		Capability: msg.Capability,
		Nonce:      msg.Nonce,
		Discard:    func() {},
	}
	cont(delivered)
	return nil
}
// Send delivers a message. Envelopes addressed to this actor short-circuit
// to Receive without signing; everything else is signed, JSON-encoded, and
// handed to the network layer on the destination inbox protocol.
func (a *MockActor) Send(msg Envelope) error {
	if msg.To.ID.Equal(a.self.ID) {
		return a.Receive(msg)
	}
	if err := a.security.Sign(&msg); err != nil {
		return fmt.Errorf("signing message: %w", err)
	}
	payload, err := json.Marshal(msg)
	if err != nil {
		return fmt.Errorf("marshaling message: %w", err)
	}
	envelope := types.MessageEnvelope{
		Type: types.MessageType(fmt.Sprintf("actor/%s/messages/0.0.1", msg.To.Address.InboxAddress)),
		Data: payload,
	}
	if err := a.network.SendMessage(a.Context(), msg.To.Address.HostID, envelope, msg.Expiry()); err != nil {
		return fmt.Errorf("sending message to %s: %w", msg.To.ID, err)
	}
	return nil
}
// Invoke sends a message and returns a single-shot channel carrying the
// reply. A reply-to behavior is generated when the caller did not set one,
// and it is unregistered again if the send fails.
func (a *MockActor) Invoke(msg Envelope) (<-chan Envelope, error) {
	if msg.Options.ReplyTo == "" {
		msg.Options.ReplyTo = fmt.Sprintf("/dms/actor/replyto/%d", a.security.Nonce())
	}
	replyTo := msg.Options.ReplyTo
	out := make(chan Envelope, 1)
	onReply := func(reply Envelope) {
		out <- reply
		close(out)
	}
	if err := a.AddBehavior(replyTo, onReply); err != nil {
		return nil, fmt.Errorf("adding reply behavior: %w", err)
	}
	if err := a.Send(msg); err != nil {
		a.RemoveBehavior(replyTo)
		return nil, fmt.Errorf("sending message: %w", err)
	}
	return out, nil
}
// CreateChild spawns a new MockActor as a child of this actor.
//
// A fresh Ed25519 key pair is generated unless the caller overrides it via
// options; the child inherits this actor's capability context and network,
// and is addressed at the same host under the given inbox id. The child is
// recorded in this actor's children map and its parent is set to this actor.
//
// Parameters:
//   - id: the child's inbox address on the shared host.
//   - super: the supervisor handle passed to the child.
//   - opts: optional overrides (e.g. a pre-made private key).
func (a *MockActor) CreateChild(
	id string,
	super Handle,
	opts ...CreateChildOption,
) (Actor, error) {
	// Create default options: a fresh key pair, replaceable via opts.
	privk, _, err := crypto.GenerateKeyPair(crypto.Ed25519)
	if err != nil {
		return nil, fmt.Errorf("failed to create a child actor: %w", err)
	}
	options := &CreateChildOptions{
		PrivKey: privk,
	}
	// apply caller's options
	for _, opt := range opts {
		opt(options)
	}
	// the child shares the parent's capability context
	sctx, err := NewBasicSecurityContext(options.PrivKey.GetPublic(), options.PrivKey, a.security.Capability())
	if err != nil {
		return nil, fmt.Errorf("failed to create a child actor: %w", err)
	}
	child, err := NewMockActor(
		super,
		a.network,
		sctx,
		a.limiter,
		Handle{
			ID:  sctx.ID(),
			DID: sctx.DID(),
			Address: Address{
				// child lives on the same host, different inbox
				HostID:       a.self.Address.HostID,
				InboxAddress: id,
			},
		},
	)
	if err != nil {
		return nil, fmt.Errorf("failed to create a child actor: %w", err)
	}
	// link both directions under the parent's lock
	a.mx.Lock()
	child.parent = a.Handle()
	a.children[child.Handle().DID] = child.Handle()
	a.mx.Unlock()
	return child, nil
}
// Parent returns the handle of this actor's parent (zero value for roots).
func (a *MockActor) Parent() Handle {
	return a.parent
}

// Children returns a snapshot copy of this actor's children keyed by DID;
// mutating the result does not affect the actor.
func (a *MockActor) Children() map[did.DID]Handle {
	a.mx.Lock()
	defer a.mx.Unlock()
	snapshot := make(map[did.DID]Handle, len(a.children))
	for childDID, handle := range a.children {
		snapshot[childDID] = handle
	}
	return snapshot
}
// Publish broadcasts an envelope on its topic; non-broadcast envelopes are
// rejected with ErrInvalidMessage.
func (a *MockActor) Publish(msg Envelope) error {
	if !msg.IsBroadcast() {
		return ErrInvalidMessage
	}
	payload, err := json.Marshal(msg)
	if err != nil {
		return fmt.Errorf("marshaling message: %w", err)
	}
	if err = a.network.Publish(a.Context(), msg.Options.Topic, payload); err != nil {
		return fmt.Errorf("publishing message: %w", err)
	}
	return nil
}
// Subscribe joins a broadcast topic, wiring handleBroadcast as the message
// sink and validateBroadcast as the pubsub validator. Subscribing to a topic
// the actor already follows is a silent no-op. Setup hooks run after the
// network subscription succeeds; if any hook fails the subscription is
// rolled back.
func (a *MockActor) Subscribe(topic string, setup ...BroadcastSetup) error {
	a.mx.Lock()
	defer a.mx.Unlock()
	// already subscribed: nothing to do
	_, ok := a.subscriptions[topic]
	if ok {
		return nil
	}
	subID, err := a.network.Subscribe(
		a.Context(),
		topic,
		a.handleBroadcast,
		func(data []byte, validatorData interface{}) (network.ValidationResult, interface{}) {
			// bind the topic so the validator can reject cross-topic envelopes
			return a.validateBroadcast(topic, data, validatorData)
		},
	)
	if err != nil {
		return fmt.Errorf("subscribe: %w", err)
	}
	for _, f := range setup {
		if err := f(topic); err != nil {
			// best-effort rollback; the setup error is what the caller sees
			_ = a.network.Unsubscribe(topic, subID)
			return fmt.Errorf("setup broadcast topic: %w", err)
		}
	}
	a.subscriptions[topic] = subID
	return nil
}
// validateBroadcast is the pubsub validator for a subscribed topic.
//
// When validatorData is already present (a previously validated Envelope) it
// is accepted as-is. Otherwise the raw data is decoded and checked in order:
// broadcast shape, topic match, expiry, signature, and finally the rate
// limiter. Malformed or invalid messages are rejected; expired or
// rate-limited ones are ignored (not propagated, but not penalized).
func (a *MockActor) validateBroadcast(topic string, data []byte, validatorData interface{}) (network.ValidationResult, interface{}) {
	var msg Envelope
	if validatorData != nil {
		// fast path: already validated upstream — just sanity-check the type
		if _, ok := validatorData.(Envelope); !ok {
			return network.ValidationReject, nil
		}
		return network.ValidationAccept, validatorData
	} else if err := json.Unmarshal(data, &msg); err != nil {
		return network.ValidationReject, nil
	}
	if !msg.IsBroadcast() {
		return network.ValidationReject, nil
	}
	if msg.Options.Topic != topic {
		return network.ValidationReject, nil
	}
	if msg.Expired() {
		return network.ValidationIgnore, nil
	}
	if err := a.security.Verify(msg); err != nil {
		return network.ValidationReject, nil
	}
	if !a.limiter.Allow(msg) {
		return network.ValidationIgnore, nil
	}
	// pass the decoded envelope downstream as validatorData
	return network.ValidationAccept, msg
}
// handleBroadcast delivers an incoming broadcast to the local behaviors,
// silently dropping malformed payloads and the actor's own messages.
func (a *MockActor) handleBroadcast(data []byte) {
	var msg Envelope
	if json.Unmarshal(data, &msg) != nil {
		return
	}
	// never deliver our own broadcasts back to us
	if msg.From.Equal(a.Handle()) {
		return
	}
	_ = a.Receive(msg)
}
// Start registers the actor's inbox protocol handler with the network layer.
func (a *MockActor) Start() error {
	proto := fmt.Sprintf("actor/%s/messages/0.0.1", a.self.Address.InboxAddress)
	if err := a.network.HandleMessage(proto, a.handleMessage); err != nil {
		return fmt.Errorf("starting actor: %s: %w", a.self.ID, err)
	}
	return nil
}
// handleMessage decodes an inbound network message and delivers it, dropping
// malformed payloads, envelopes addressed to someone else, and envelopes
// whose claimed sender host does not match the transport peer.
func (a *MockActor) handleMessage(data []byte, peerID peer.ID) {
	var msg Envelope
	if json.Unmarshal(data, &msg) != nil {
		return
	}
	switch {
	case !a.self.ID.Equal(msg.To.ID):
		return
	case msg.From.Address.HostID != peerID.String():
		return
	}
	_ = a.Receive(msg)
}
// Stop unsubscribes from all broadcast topics and unregisters the inbox
// protocol handler, returning the first unsubscribe error encountered.
//
// The subscriptions map is iterated under the actor mutex (Subscribe writes
// it under the same lock, so an unguarded range here was a data race), and
// successfully released topics are removed so a failed Stop can be retried
// without double-unsubscribing.
func (a *MockActor) Stop() error {
	a.mx.Lock()
	defer a.mx.Unlock()
	for topic, subID := range a.subscriptions {
		if err := a.network.Unsubscribe(topic, subID); err != nil {
			return err
		}
		// deleting during range is safe in Go
		delete(a.subscriptions, topic)
	}
	a.network.UnregisterMessageHandler(fmt.Sprintf("actor/%s/messages/0.0.1", a.self.Address.InboxAddress))
	return nil
}
// Limiter returns the rate limiter this actor was constructed with (may be
// nil for mocks created without one).
func (a *MockActor) Limiter() RateLimiter {
	return a.limiter
}
// NewMockActorForTest wires a MockActor to a fresh peer on the given virtual
// network substrate and starts both, failing the test on any error.
//
// It returns the actor, its network peer, its handle, and the generated key
// pair so callers can sign or verify messages on the actor's behalf.
func NewMockActorForTest(t *testing.T, supervisor Handle, substrate *network.Substrate) (Actor, network.Network, Handle, crypto.PrivKey, crypto.PubKey) {
	t.Helper()
	priv, pub, err := crypto.GenerateKeyPair(crypto.Ed25519)
	require.NoError(t, err)
	rootDID, root1Trust := MakeRootTrustContext(t)
	mockActorDID, actor1Trust := MakeTrustContext(t, priv)
	actorCap := MakeCapabilityContext(t, mockActorDID, rootDID, actor1Trust, root1Trust)
	securityCtx, err := NewBasicSecurityContext(pub, priv, actorCap)
	require.NoError(t, err)
	id := securityCtx.ID()
	peerID, err := peer.IDFromPublicKey(pub)
	require.NoError(t, err)
	// named wiredPeer (not "peer") so the libp2p peer package is not shadowed
	wiredPeer := substrate.AddWiredPeer(peerID)
	require.NoError(t, wiredPeer.Start())
	handle := Handle{
		ID:  id,
		DID: mockActorDID,
		Address: Address{
			HostID:       peerID.String(),
			InboxAddress: "root",
		},
	}
	mockActor, err := NewMockActor(
		supervisor,
		wiredPeer,
		securityCtx,
		nil,
		handle,
	)
	require.NoError(t, err)
	require.NoError(t, mockActor.Start())
	return mockActor, wiredPeer, handle, priv, pub
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package actor
import (
"encoding/json"
"fmt"
"reflect"
"time"
)
const (
defaultMessageTimeout = 30 * time.Second
)
var signaturePrefix = []byte("dms:msg:")
// HealthCheckResponse is the payload returned by the healthcheck behavior.
type HealthCheckResponse struct {
	OK    bool   // true when the actor is healthy
	Error string // human-readable failure reason when OK is false
}
// Message constructs a new message envelope and applies the options
func Message(src Handle, dest Handle, behavior string, payload interface{}, opt ...MessageOption) (Envelope, error) {
var data []byte
if payload == nil || (reflect.ValueOf(payload).Kind() == reflect.Ptr && reflect.ValueOf(payload).IsNil()) {
data = []byte{}
} else {
var err error
data, err = json.Marshal(payload)
if err != nil {
return Envelope{}, fmt.Errorf("marshaling payload: %w", err)
}
}
msg := Envelope{
To: dest,
Behavior: behavior,
From: src,
Message: data,
Options: EnvelopeOptions{
Expire: uint64(time.Now().Add(defaultMessageTimeout).UnixNano()),
},
Discard: func() {},
}
for _, f := range opt {
if err := f(&msg); err != nil {
return Envelope{}, fmt.Errorf("setting message option: %w", err)
}
}
return msg, nil
}
// ReplyTo constructs a reply envelope addressed back at the sender of msg,
// targeting its ReplyTo behavior and inheriting its expiry. Caller options
// run after the inherited expiry and may override it.
func ReplyTo(msg Envelope, payload interface{}, opt ...MessageOption) (Envelope, error) {
	if msg.Options.ReplyTo == "" {
		return Envelope{}, fmt.Errorf("no behavior to reply to: %w", ErrInvalidMessage)
	}
	options := append([]MessageOption{WithMessageExpiry(msg.Options.Expire)}, opt...)
	return Message(msg.To, msg.From, msg.Options.ReplyTo, payload, options...)
}
// WithMessageSignature provides the necessary capability tokens for the
// envelope and signs it.
//
// NOTE: this option must be passed last, otherwise the signature will be
// invalidated by further modifications.
//
// NOTE: Signing is implicit in Send.
func WithMessageSignature(sctx SecurityContext, invoke []Capability, delegate []Capability) MessageOption {
	return func(msg *Envelope) error {
		// only the owner of the From identity may sign
		if !msg.From.ID.Equal(sctx.ID()) {
			return ErrInvalidSecurityContext
		}
		msg.Nonce = sctx.Nonce()
		if msg.IsBroadcast() {
			return sctx.ProvideBroadcast(msg, msg.Options.Topic, invoke)
		}
		return sctx.Provide(msg, invoke, delegate)
	}
}
// WithMessageTimeout sets the message expiration from a relative timeout.
//
// NOTE: messages created with Message have an implicit timeout of
// defaultMessageTimeout; this option overrides it.
func WithMessageTimeout(timeo time.Duration) MessageOption {
	return func(msg *Envelope) error {
		msg.Options.Expire = uint64(time.Now().Add(timeo).UnixNano())
		return nil
	}
}
// WithMessageExpiry sets the message expiry to an absolute nanosecond stamp.
//
// NOTE: messages created with Message have an implicit timeout of
// defaultMessageTimeout; this option overrides it.
func WithMessageExpiry(expiry uint64) MessageOption {
	return func(msg *Envelope) error {
		msg.Options.Expire = expiry
		return nil
	}
}
// WithMessageExpiryTime sets the message expiry from an absolute time.Time,
// converted to the envelope's nanosecond stamp representation.
func WithMessageExpiryTime(t time.Time) MessageOption {
	return func(msg *Envelope) error {
		msg.Options.Expire = uint64(t.UnixNano())
		return nil
	}
}
// WithMessageReplyTo sets the behavior the recipient should reply to.
//
// NOTE: ReplyTo is set implicitly in Invoke, and the appropriate capability
// tokens are delegated by Provide.
func WithMessageReplyTo(replyto string) MessageOption {
	return func(msg *Envelope) error {
		msg.Options.ReplyTo = replyto
		return nil
	}
}

// WithMessageTopic sets the broadcast topic.
func WithMessageTopic(topic string) MessageOption {
	return func(msg *Envelope) error {
		msg.Options.Topic = topic
		return nil
	}
}

// WithMessageSource sets the message source handle.
func WithMessageSource(source Handle) MessageOption {
	return func(msg *Envelope) error {
		msg.From = source
		return nil
	}
}
// SignatureData returns the byte string that is signed/verified for this
// envelope: the signature prefix followed by the JSON encoding of the
// envelope with its Signature field cleared.
func (msg *Envelope) SignatureData() ([]byte, error) {
	unsigned := *msg
	unsigned.Signature = nil
	encoded, err := json.Marshal(&unsigned)
	if err != nil {
		return nil, fmt.Errorf("signature data: %w", err)
	}
	out := make([]byte, 0, len(signaturePrefix)+len(encoded))
	out = append(out, signaturePrefix...)
	out = append(out, encoded...)
	return out, nil
}
// Expired reports whether the envelope's expiration time has passed.
func (msg *Envelope) Expired() bool {
	return uint64(time.Now().UnixNano()) > msg.Options.Expire
}

// Expiry converts the nanosecond expiration stamp to a time.Time object.
func (msg *Envelope) Expiry() time.Time {
	// split the stamp into whole seconds and remaining nanoseconds
	sec := msg.Options.Expire / uint64(time.Second)
	nsec := msg.Options.Expire % uint64(time.Second)
	return time.Unix(int64(sec), int64(nsec))
}

// IsBroadcast reports whether the envelope is a broadcast: no destination
// handle and a non-empty topic.
func (msg *Envelope) IsBroadcast() bool {
	return msg.To.Empty() && msg.Options.Topic != ""
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package actor
import (
"context"
"errors"
"sync"
"gitlab.com/nunet/device-management-service/lib/did"
)
// NoopActor is a do-nothing Actor implementation for tests: it records the
// messages it is asked to send/receive instead of touching any network.
type NoopActor struct {
	mx         sync.Mutex // guards all mutable fields below
	ctx        context.Context
	security   SecurityContext
	handle     Handle
	supervisor Handle
	parent     Handle
	children   map[did.DID]Handle
	// Testing helpers
	sentMessages     []Envelope          // everything passed to Send/Publish
	receivedMessages []Envelope          // everything passed to Receive
	behaviors        map[string]Behavior // behaviors registered via AddBehavior
	invokeResponses  map[string]Envelope // canned replies served by Invoke
}
var _ Actor = (*NoopActor)(nil)
// NewNoopActor builds a NoopActor with zero-valued handles, a bare security
// context, and freshly initialized recording slices and maps.
func NewNoopActor() *NoopActor {
	a := &NoopActor{}
	a.ctx = context.Background()
	a.security = &BasicSecurityContext{}
	a.children = map[did.DID]Handle{}
	a.sentMessages = make([]Envelope, 0)
	a.receivedMessages = make([]Envelope, 0)
	a.behaviors = map[string]Behavior{}
	a.invokeResponses = map[string]Envelope{}
	return a
}
// Context returns the actor's base context (never mutated after New).
func (c *NoopActor) Context() context.Context { return c.ctx }

// Handle returns the actor's handle. Guarded by the mutex because SetHandle
// mutates the field under the same lock; an unguarded read is a data race.
func (c *NoopActor) Handle() Handle {
	c.mx.Lock()
	defer c.mx.Unlock()
	return c.handle
}

// Security returns the security context. Guarded because
// UpdateSecurityContext replaces the field under the same lock.
func (c *NoopActor) Security() SecurityContext {
	c.mx.Lock()
	defer c.mx.Unlock()
	return c.security
}

// Supervisor returns the supervisor handle (never mutated after New).
func (c *NoopActor) Supervisor() Handle { return c.supervisor }
// UpdateSecurityContext swaps in a new security context; nil is rejected.
func (c *NoopActor) UpdateSecurityContext(newSecurity SecurityContext) error {
	if newSecurity == nil {
		return errors.New("new security context cannot be nil")
	}
	c.mx.Lock()
	c.security = newSecurity
	c.mx.Unlock()
	return nil
}
// AddBehavior registers a behavior under name; behavior options are ignored.
func (c *NoopActor) AddBehavior(name string, behavior Behavior, _ ...BehaviorOption) error {
	c.mx.Lock()
	c.behaviors[name] = behavior
	c.mx.Unlock()
	return nil
}

// RemoveBehavior drops the named behavior if present.
func (c *NoopActor) RemoveBehavior(name string) {
	c.mx.Lock()
	delete(c.behaviors, name)
	c.mx.Unlock()
}
// Receive records the message and runs the matching registered behavior, if
// any. It never fails.
func (c *NoopActor) Receive(msg Envelope) error {
	c.mx.Lock()
	c.receivedMessages = append(c.receivedMessages, msg)
	cont, ok := c.behaviors[msg.Behavior]
	c.mx.Unlock()
	if ok && cont != nil {
		cont(msg)
	}
	return nil
}
// Send records the outgoing message without delivering it anywhere.
func (c *NoopActor) Send(msg Envelope) error {
	c.mx.Lock()
	c.sentMessages = append(c.sentMessages, msg)
	c.mx.Unlock()
	return nil
}
// Invoke returns an already-closed channel pre-loaded with the canned
// response configured for the message's behavior, when one exists.
func (c *NoopActor) Invoke(msg Envelope) (<-chan Envelope, error) {
	c.mx.Lock()
	defer c.mx.Unlock()
	out := make(chan Envelope, 1)
	if canned, ok := c.invokeResponses[msg.Behavior]; ok {
		out <- canned
	}
	close(out)
	return out, nil
}
// Publish records the broadcast in sentMessages without touching a network.
func (c *NoopActor) Publish(msg Envelope) error {
	c.mx.Lock()
	c.sentMessages = append(c.sentMessages, msg)
	c.mx.Unlock()
	return nil
}
// Subscribe is a no-op that always succeeds.
func (c *NoopActor) Subscribe(_ string, _ ...BroadcastSetup) error { return nil }

// Start is a no-op that always succeeds.
func (c *NoopActor) Start() error { return nil }

// Stop is a no-op that always succeeds.
func (c *NoopActor) Stop() error { return nil }

// CreateChild ignores its arguments and returns a fresh NoopActor.
func (c *NoopActor) CreateChild(_ string, _ Handle, _ ...CreateChildOption) (Actor, error) {
	return NewNoopActor(), nil
}

// Parent returns the configured parent handle.
func (c *NoopActor) Parent() Handle { return c.parent }

// Children returns the internal children map directly (not a copy); callers
// must not mutate it.
func (c *NoopActor) Children() map[did.DID]Handle { return c.children }

// Limiter returns a rate limiter that never limits.
func (c *NoopActor) Limiter() RateLimiter { return NoRateLimiter{} }
// Testing helper methods
// GetSentMessages returns a snapshot copy of every message recorded by Send
// and Publish.
func (c *NoopActor) GetSentMessages() []Envelope {
	c.mx.Lock()
	defer c.mx.Unlock()
	snapshot := make([]Envelope, len(c.sentMessages))
	copy(snapshot, c.sentMessages)
	return snapshot
}

// SetHandle replaces the actor's handle.
func (c *NoopActor) SetHandle(handle Handle) {
	c.mx.Lock()
	c.handle = handle
	c.mx.Unlock()
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package actor
import (
"errors"
"sync"
)
// Info is a registry entry describing one actor and its place in the tree.
type Info struct {
	Addr     Handle   // the actor's own handle
	Parent   Handle   // the actor's parent (zero value for roots)
	Children []Handle // the actor's children
}

// Registry tracks actors by inbox address together with their family links.
type Registry interface {
	// Actors returns a snapshot of all entries keyed by inbox address.
	Actors() map[string]Info
	// Add registers an actor with its parent and children.
	Add(a Handle, parent Handle, children []Handle) error
	// Get looks up the entry for a handle's inbox address.
	Get(a Handle) (Info, bool)
	// SetParent rewrites a registered actor's parent.
	SetParent(a Handle, parent Handle) error
	// GetParent reports a registered actor's parent.
	GetParent(a Handle) (Handle, bool)
}

// registry is the in-memory, mutex-guarded Registry implementation.
type registry struct {
	mx     sync.Mutex      // guards actors
	actors map[string]Info // inbox address -> entry
}
func newRegistry() *registry {
return ®istry{
actors: make(map[string]Info),
}
}
// Actors returns a snapshot copy of the registry keyed by inbox address;
// mutating the result does not affect the registry.
func (r *registry) Actors() map[string]Info {
	r.mx.Lock()
	defer r.mx.Unlock()
	snapshot := make(map[string]Info, len(r.actors))
	for addr, info := range r.actors {
		snapshot[addr] = info
	}
	return snapshot
}
// Add registers actor a under its inbox address together with its parent and
// children. An existing entry for the same address is overwritten (with a
// warning). A nil children slice is normalized to an empty slice so callers
// can range over Info.Children unconditionally.
func (r *registry) Add(a Handle, parent Handle, children []Handle) error {
	r.mx.Lock()
	defer r.mx.Unlock()
	if _, ok := r.actors[a.Address.InboxAddress]; ok {
		// fixed typo in the log message ("overwritting" -> "overwriting")
		log.Warnf("overwriting actor %s already registered", a.Address.InboxAddress)
	}
	if children == nil {
		children = []Handle{}
	}
	r.actors[a.Address.InboxAddress] = Info{
		Addr:     a,
		Parent:   parent,
		Children: children,
	}
	return nil
}
// Get looks up the registry entry for the handle's inbox address.
func (r *registry) Get(a Handle) (Info, bool) {
	r.mx.Lock()
	info, ok := r.actors[a.Address.InboxAddress]
	r.mx.Unlock()
	return info, ok
}
// SetParent rewrites the parent handle of a registered actor; it fails when
// the actor is unknown.
func (r *registry) SetParent(a Handle, parent Handle) error {
	r.mx.Lock()
	defer r.mx.Unlock()
	addr := a.Address.InboxAddress
	info, ok := r.actors[addr]
	if !ok {
		return errors.New("actor not found")
	}
	info.Parent = parent
	r.actors[addr] = info
	return nil
}
// GetParent reports the parent handle of a registered actor, with ok=false
// when the actor is unknown.
func (r *registry) GetParent(a Handle) (Handle, bool) {
	r.mx.Lock()
	defer r.mx.Unlock()
	if info, ok := r.actors[a.Address.InboxAddress]; ok {
		return info.Parent, true
	}
	return Handle{}, false
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package actor
import (
"fmt"
"sync"
"time"
"gitlab.com/nunet/device-management-service/lib/crypto"
"gitlab.com/nunet/device-management-service/lib/did"
"gitlab.com/nunet/device-management-service/lib/ucan"
)
// BasicSecurityContext implements SecurityContext on top of a key pair and a
// UCAN capability context.
type BasicSecurityContext struct {
	id    ID                     // actor ID derived from the public key
	privk crypto.PrivKey         // signing key
	cap   ucan.CapabilityContext // capability provide/require/consume
	mx    sync.Mutex             // guards nonce
	nonce uint64                 // monotonically increasing message nonce
}
var _ SecurityContext = (*BasicSecurityContext)(nil)
// NewBasicSecurityContext builds a security context around the given key
// pair and capability context. The actor ID is derived from the public key,
// and the nonce counter is seeded from the wall clock.
func NewBasicSecurityContext(pubk crypto.PubKey, privk crypto.PrivKey, capCxt ucan.CapabilityContext) (*BasicSecurityContext, error) {
	id, err := crypto.IDFromPublicKey(pubk)
	if err != nil {
		return nil, fmt.Errorf("creating security context: %w", err)
	}
	return &BasicSecurityContext{
		id:    id,
		privk: privk,
		cap:   capCxt,
		nonce: uint64(time.Now().UnixNano()),
	}, nil
}
// ID returns the actor ID derived from the context's public key.
func (s *BasicSecurityContext) ID() ID {
	return s.id
}

// DID returns the DID of the underlying capability context.
func (s *BasicSecurityContext) DID() DID {
	return s.cap.DID()
}

// Nonce returns the current nonce and advances the counter; safe for
// concurrent use.
func (s *BasicSecurityContext) Nonce() uint64 {
	s.mx.Lock()
	defer s.mx.Unlock()
	nonce := s.nonce
	s.nonce++
	return nonce
}

// PrivKey returns the context's signing key.
func (s *BasicSecurityContext) PrivKey() crypto.PrivKey {
	return s.privk
}
// Require checks that the sender of msg is entitled to the requested
// invocation capabilities, consuming the capability tokens carried in the
// envelope. Tokens are discarded again if the requirement fails.
func (s *BasicSecurityContext) Require(msg Envelope, invoke []Capability) error {
	// if we are sending to self, nothing to do, signature is already verified
	if s.id.Equal(msg.From.ID) && s.id.Equal(msg.To.ID) {
		return nil
	}
	// first consume the capability tokens in the envelope
	if err := s.cap.Consume(msg.From.DID, msg.Capability); err != nil {
		return fmt.Errorf("consuming capabilities: %w", err)
	}
	// check if any of the requested invocation capabilities are delegated
	if err := s.cap.Require(s.DID(), msg.From.ID, s.id, invoke); err != nil {
		s.cap.Discard(msg.Capability)
		return fmt.Errorf("requiring capabilities: %w", err)
	}
	return nil
}
// Provide attaches capability tokens to the envelope (invocation plus any
// onward delegation, bounded by the envelope's expiry) and signs it.
func (s *BasicSecurityContext) Provide(msg *Envelope, invoke []Capability, delegate []Capability) error {
	// if we are sending to self, nothing to do, just Sign
	if s.id.Equal(msg.From.ID) && s.id.Equal(msg.To.ID) {
		return s.Sign(msg)
	}
	tokens, err := s.cap.Provide(msg.To.DID, s.id, msg.To.ID, msg.Options.Expire, invoke, delegate)
	if err != nil {
		return fmt.Errorf("providing capabilities: %w", err)
	}
	msg.Capability = tokens
	return s.Sign(msg)
}
// RequireBroadcast checks that the sender of a broadcast envelope holds the
// requested broadcast capabilities for the topic, consuming the envelope's
// tokens. Tokens are discarded again if the requirement fails.
func (s *BasicSecurityContext) RequireBroadcast(msg Envelope, topic string, broadcast []Capability) error {
	if !msg.IsBroadcast() {
		return fmt.Errorf("not a broadcast message: %w", ErrInvalidMessage)
	}
	if topic != msg.Options.Topic {
		return fmt.Errorf("broadcast topic mismatch: %w", ErrInvalidMessage)
	}
	// first consume the capability tokens in the envelope
	if err := s.cap.Consume(msg.From.DID, msg.Capability); err != nil {
		return fmt.Errorf("consuming capabilities: %w", err)
	}
	// check if any of the requested invocation capabilities are delegated
	if err := s.cap.RequireBroadcast(s.DID(), msg.From.ID, topic, broadcast); err != nil {
		s.cap.Discard(msg.Capability)
		return fmt.Errorf("requiring capabilities: %w", err)
	}
	return nil
}
// ProvideBroadcast attaches broadcast capability tokens for the topic to the
// envelope (bounded by its expiry) and signs it. The envelope must already
// be a broadcast on the same topic.
func (s *BasicSecurityContext) ProvideBroadcast(msg *Envelope, topic string, broadcast []Capability) error {
	if !msg.IsBroadcast() {
		return fmt.Errorf("not a broadcast message: %w", ErrInvalidMessage)
	}
	if topic != msg.Options.Topic {
		return fmt.Errorf("broadcast topic mismatch: %w", ErrInvalidMessage)
	}
	tokens, err := s.cap.ProvideBroadcast(msg.From.ID, topic, msg.Options.Expire, broadcast)
	if err != nil {
		return fmt.Errorf("providing capabilities: %w", err)
	}
	msg.Capability = tokens
	return s.Sign(msg)
}
// Verify checks the envelope signature against the sender's public key
// (recovered from the sender ID), rejecting expired envelopes first.
func (s *BasicSecurityContext) Verify(msg Envelope) error {
	if msg.Expired() {
		return ErrMessageExpired
	}
	pubk, err := crypto.PublicKeyFromID(msg.From.ID)
	if err != nil {
		return fmt.Errorf("public key from id: %w", err)
	}
	payload, err := msg.SignatureData()
	if err != nil {
		return fmt.Errorf("signature data: %w", err)
	}
	valid, err := pubk.Verify(payload, msg.Signature)
	switch {
	case err != nil:
		return fmt.Errorf("verify message signature: %w", err)
	case !valid:
		return ErrSignatureVerification
	}
	return nil
}
// Sign signs the envelope in place with the context's private key. Only the
// owner of the envelope's From identity may sign it.
func (s *BasicSecurityContext) Sign(msg *Envelope) error {
	if !s.id.Equal(msg.From.ID) {
		return ErrBadSender
	}
	payload, err := msg.SignatureData()
	if err != nil {
		return fmt.Errorf("signature data: %w", err)
	}
	signature, err := s.privk.Sign(payload)
	if err != nil {
		return fmt.Errorf("signing message: %w", err)
	}
	msg.Signature = signature
	return nil
}
// Grant creates a delegation token chain granting caps from sub to aud for
// the given duration, then anchors the tokens as roots of this context so
// future Provide/Require calls can use them.
//
// The token is created with an empty topic list and a delegation depth of 1.
func (s *BasicSecurityContext) Grant(
	sub, aud did.DID, caps []ucan.Capability, expiry time.Duration,
) error {
	tokens, err := s.cap.Grant(
		ucan.Delegate,
		sub,
		aud,
		[]string{},
		MakeExpiry(expiry),
		1,
		caps,
	)
	if err != nil {
		return fmt.Errorf("create granting token for audience %s caps: %w", aud, err)
	}
	err = s.cap.AddRoots([]did.DID{}, tokens, ucan.TokenList{}, ucan.TokenList{})
	if err != nil {
		return fmt.Errorf("add roots for audience %s: %w", aud, err)
	}
	return nil
}
// Discard drops the capability tokens carried by the envelope from the
// capability context.
func (s *BasicSecurityContext) Discard(msg Envelope) {
	s.cap.Discard(msg.Capability)
}

// Capability exposes the underlying UCAN capability context.
func (s *BasicSecurityContext) Capability() ucan.CapabilityContext {
	return s.cap
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package actor
import (
"fmt"
"testing"
"time"
"github.com/google/uuid"
"github.com/multiformats/go-multiaddr"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
backgroundtasks "gitlab.com/nunet/device-management-service/internal/background_tasks"
"gitlab.com/nunet/device-management-service/internal/config"
"gitlab.com/nunet/device-management-service/lib/crypto"
"gitlab.com/nunet/device-management-service/lib/did"
"gitlab.com/nunet/device-management-service/lib/ucan"
"gitlab.com/nunet/device-management-service/network"
"gitlab.com/nunet/device-management-service/network/libp2p"
"gitlab.com/nunet/device-management-service/types"
)
// MakeRootTrustContext generates a fresh Ed25519 key pair and returns a
// trust context (and its DID) rooted at it.
func MakeRootTrustContext(t *testing.T) (did.DID, did.TrustContext) {
	t.Helper()
	rootKey, _, err := crypto.GenerateKeyPair(crypto.Ed25519)
	require.NoError(t, err)
	return MakeTrustContext(t, rootKey)
}
// MakeTrustContext builds a DID trust context holding a provider derived
// from the given private key, returning the provider's DID and the context.
func MakeTrustContext(t *testing.T, privk crypto.PrivKey) (did.DID, did.TrustContext) {
	t.Helper()
	provider, err := did.ProviderFromPrivateKey(privk)
	// message corrected: the provider is derived from the private key
	require.NoError(t, err, "provider from private key")
	ctx := did.NewTrustContext()
	ctx.AddProvider(provider)
	return provider.DID(), ctx
}
// MakeCapabilityContext builds a capability context for actorDID whose
// authority is anchored at rootDID: the root grants the actor the Root
// capability (1h expiry, unlimited depth), and that token is installed as a
// "provide" root of the actor's context.
func MakeCapabilityContext(t *testing.T, actorDID, rootDID did.DID, trust, root did.TrustContext) ucan.CapabilityContext {
	t.Helper()
	actorCap, err := ucan.NewCapabilityContext(trust, actorDID, nil, ucan.TokenList{}, ucan.TokenList{}, ucan.TokenList{})
	require.NoError(t, err)
	rootCap, err := ucan.NewCapabilityContext(root, rootDID, nil, ucan.TokenList{}, ucan.TokenList{}, ucan.TokenList{})
	require.NoError(t, err)
	// root delegates the Root capability to the actor
	tokens, err := rootCap.Grant(
		ucan.Delegate,
		actorDID,
		did.DID{},
		nil,
		MakeExpiry(time.Hour),
		0,
		[]ucan.Capability{ucan.Root},
	)
	require.NoError(t, err)
	// anchor the grant under the root DID in the actor's context
	err = actorCap.AddRoots([]did.DID{rootDID}, ucan.TokenList{}, tokens, ucan.TokenList{})
	require.NoError(t, err)
	return actorCap
}
// MakeExpiry returns the wall-clock time d from now as a nanosecond stamp.
func MakeExpiry(d time.Duration) uint64 {
	deadline := time.Now().Add(d)
	return uint64(deadline.UnixNano())
}
// AllowReciprocal makes actorCap accept the named capability when presented
// by actors anchored at otherRootDID: rootDID issues a delegation token to
// otherRootDID and the token is added to actorCap's "require" roots.
func AllowReciprocal(t *testing.T, actorCap ucan.CapabilityContext, rootTrust did.TrustContext, rootDID, otherRootDID did.DID, capability string) {
	t.Helper()
	rootCap, err := ucan.NewCapabilityContext(rootTrust, rootDID, nil, ucan.TokenList{}, ucan.TokenList{}, ucan.TokenList{})
	require.NoError(t, err)
	tokens, err := rootCap.Grant(
		ucan.Delegate,
		otherRootDID,
		did.DID{},
		nil,
		MakeExpiry(time.Hour),
		0,
		[]ucan.Capability{ucan.Capability(capability)},
	)
	require.NoError(t, err)
	err = actorCap.AddRoots(nil, tokens, ucan.TokenList{}, ucan.TokenList{})
	require.NoError(t, err)
}
// AllowBroadcast wires up broadcast capabilities on a topic between two
// actors anchored at different roots: root1 grants actor1 the capabilities
// (installed as actor1's provide roots), and root2 issues a matching grant
// installed in actor2's require roots so actor2 accepts actor1's broadcasts.
func AllowBroadcast(t *testing.T, actor1, actor2 ucan.CapabilityContext, root1, root2 did.TrustContext, root1DID, root2DID did.DID, topic string, actorCap ...Capability) {
	t.Helper()
	root1Cap, err := ucan.NewCapabilityContext(root1, root1DID, nil, ucan.TokenList{}, ucan.TokenList{}, ucan.TokenList{})
	require.NoError(t, err)
	root2Cap, err := ucan.NewCapabilityContext(root2, root2DID, nil, ucan.TokenList{}, ucan.TokenList{}, ucan.TokenList{})
	require.NoError(t, err)
	tokens, err := root1Cap.Grant(
		ucan.Delegate,
		actor1.DID(),
		did.DID{},
		[]string{topic},
		MakeExpiry(120*time.Second),
		0,
		actorCap,
	)
	require.NoError(t, err, "granting broadcast capability")
	err = actor1.AddRoots(nil, ucan.TokenList{}, tokens, ucan.TokenList{})
	require.NoError(t, err, "add roots")
	// NOTE(review): this second grant is also issued to actor1.DID() (not
	// actor2) — presumably intentional, since actor2 only needs a token
	// proving actor1 may broadcast; confirm against the ucan package docs.
	tokens, err = root2Cap.Grant(
		ucan.Delegate,
		actor1.DID(),
		did.DID{},
		[]string{topic},
		MakeExpiry(120*time.Second),
		0,
		actorCap,
	)
	require.NoError(t, err, "grant broadcast capability")
	err = actor2.AddRoots(nil, tokens, ucan.TokenList{}, ucan.TokenList{})
	require.NoError(t, err, "add roots")
}
// CreateActor builds a BasicActor bound to the given network peer, with a
// fresh Ed25519 identity, the supplied capability context, a default rate
// limiter, and a random UUID inbox address.
func CreateActor(t *testing.T, peer network.Network, capCxt ucan.CapabilityContext) *BasicActor {
	t.Helper()
	privk, pubk, err := crypto.GenerateKeyPair(crypto.Ed25519)
	require.NoError(t, err)
	sctx, err := NewBasicSecurityContext(pubk, privk, capCxt)
	assert.NoError(t, err)
	params := BasicActorParams{}
	// named inboxID (not "uuid") so the uuid package is not shadowed
	inboxID, err := uuid.NewUUID()
	assert.NoError(t, err)
	handle := Handle{
		ID:  sctx.id,
		DID: capCxt.DID(),
		Address: Address{
			HostID:       peer.GetHostID().String(),
			InboxAddress: inboxID.String(),
		},
	}
	actor, err := New(Handle{}, peer, sctx, NewRateLimiter(DefaultRateLimiterConfig()), params, handle)
	assert.NoError(t, err)
	assert.NotNil(t, actor)
	return actor
}
// NewLibp2pNetwork spins up a real libp2p network instance for tests,
// listening on TCP 3001 and QUIC on the given port, bootstrapped from the
// provided peers. It returns the node's multiaddrs, its private key, and the
// started libp2p instance.
func NewLibp2pNetwork(t *testing.T, quicPort int, bootstrap []multiaddr.Multiaddr) ([]multiaddr.Multiaddr, crypto.PrivKey, *libp2p.Libp2p) {
	t.Helper()
	priv, _, err := crypto.GenerateKeyPair(crypto.Ed25519)
	assert.NoError(t, err)
	net, err := network.NewNetwork(&types.NetworkConfig{
		Type: types.Libp2pNetwork,
		Libp2pConfig: types.Libp2pConfig{
			Env:                     "test",
			PrivateKey:              priv,
			BootstrapPeers:          bootstrap,
			Rendezvous:              "nunet-randevouz",
			Server:                  false,
			Scheduler:               backgroundtasks.NewScheduler(1, time.Second),
			CustomNamespace:         "/nunet-dht-1/",
			PeerCountDiscoveryLimit: 40,
			GossipMaxMessageSize:    2 << 16,
			ListenAddress:           []string{"/ip4/0.0.0.0/tcp/3001", fmt.Sprintf("/ip4/0.0.0.0/udp/%d/quic-v1", quicPort)},
		},
	}, afero.NewMemMapFs())
	assert.NoError(t, err)
	// Init mirrors the listen addresses into the node configuration.
	err = net.Init(&config.Config{
		General: config.General{
			Env: "test",
		},
		P2P: config.P2P{
			ListenAddress: []string{"/ip4/0.0.0.0/tcp/3001", fmt.Sprintf("/ip4/0.0.0.0/udp/%d/quic-v1", quicPort)},
		},
	})
	assert.NoError(t, err)
	err = net.Start()
	assert.NoError(t, err)
	libp2pInstance, _ := net.(*libp2p.Libp2p)
	multi, err := libp2pInstance.GetMultiaddr()
	assert.NoError(t, err)
	return multi, priv, libp2pInstance
}
// nonZeroID returns a crypto.ID with a non-empty public key, for tests that
// need an ID distinct from the zero value.
func nonZeroID() crypto.ID { return crypto.ID{PublicKey: []byte{1}} }
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package api
import (
"context"
"encoding/json"
"fmt"
"net/http"
"sync"
"time"
"github.com/gin-gonic/gin"
logging "github.com/ipfs/go-log/v2"
"github.com/libp2p/go-libp2p/core/peer"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/lib/crypto"
"gitlab.com/nunet/device-management-service/lib/did"
"gitlab.com/nunet/device-management-service/network"
"gitlab.com/nunet/device-management-service/observability"
"gitlab.com/nunet/device-management-service/types"
)
// log is the logger for the actor API package
var log = logging.Logger("actor-api")
const (
ErrHostNotInitialized = "host node hasn't yet been initialized"
)
// ActorHandle godoc
//
// @Summary Retrieve actor handle
// @Description Retrieve actor handle with ID, DID, and inbox address
// @Tags actor
// @Produce json
// @Success 200 {object} actor.Handle
// @Failure 500 {object} object "host node hasn't yet been initialized"
// @Failure 500 {object} object "handle id is invalid"
// @Router /actor/handle [get]
func (rs *Server) ActorHandle(c *gin.Context) {
	// Open a span for consistency with the other actor handlers.
	endSpan := observability.StartSpan(c, "actor_handle")
	defer endSpan()

	if rs.config.P2P == nil {
		log.Errorw("actor_handle_retrieve_failure", "error", ErrHostNotInitialized)
		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": ErrHostNotInitialized})
		return
	}

	// Derive the actor identity from the host's public key.
	pubk := rs.config.P2P.GetPeerPubKey(rs.config.P2P.GetHostID())
	id, err := crypto.IDFromPublicKey(pubk)
	if err != nil {
		log.Errorw("actor_handle_retrieve_failure", "error", "handle id is invalid")
		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "handle id is invalid"})
		return
	}
	actorDID := did.FromPublicKey(pubk)

	// The root actor listens on the well-known "root" inbox.
	handle := actor.Handle{
		ID:  id,
		DID: actorDID,
		Address: actor.Address{
			HostID:       rs.config.P2P.GetHostID().String(),
			InboxAddress: "root",
		},
	}
	log.Debugw("actor_handle_retrieve_success", "id", id, "DID", actorDID)
	c.JSON(http.StatusOK, handle)
}
// ActorSendMessage godoc
//
// @Summary Send message to actor
// @Description Send message to actor
// @Tags actor
// @Accept json
// @Produce json
// @Param message body actor.Envelope true "Message to send"
// @Success 200 {object} object "message sent"
// @Failure 400 {object} object "invalid request data"
// @Failure 500 {object} object "host node hasn't yet been initialized"
// @Failure 500 {object} object "failed to marshal message"
// @Failure 500 {object} object "destination address can't be resolved"
// @Failure 500 {object} object "failed to send message to destination"
// @Router /actor/send [post]
func (rs *Server) ActorSendMessage(c *gin.Context) {
	endSpan := observability.StartSpan(c, "actor_send_message")
	defer endSpan()

	// Decode the envelope from the request body.
	var envelope actor.Envelope
	if bindErr := c.ShouldBindJSON(&envelope); bindErr != nil {
		log.Errorw("actor_send_message_failure", "error", bindErr.Error())
		c.JSON(http.StatusBadRequest, gin.H{"error": bindErr.Error()})
		return
	}

	p2pNet := rs.config.P2P
	if p2pNet == nil {
		log.Errorw("actor_send_message_failure", "error", ErrHostNotInitialized)
		c.JSON(http.StatusInternalServerError, gin.H{"error": ErrHostNotInitialized})
		return
	}

	// Deliver the envelope synchronously to its destination host.
	if sendErr := sendMessage(c.Request.Context(), p2pNet, envelope); sendErr != nil {
		log.Errorw("actor_send_message_failure", "error", sendErr, "destination", envelope.To.Address.HostID)
		c.JSON(http.StatusInternalServerError, gin.H{"error": sendErr.Error()})
		return
	}

	log.Infow("actor_send_message_success", "destination", envelope.To.Address.HostID)
	c.JSON(http.StatusOK, gin.H{"message": "message sent"})
}
// ActorInvoke godoc
//
// @Summary Invoke actor
// @Description Invoke actor with message
// @Tags actor
// @Accept json
// @Produce json
// @Param message body actor.Envelope true "Message to send"
// @Success 200 {object} object "response message"
// @Failure 400 {object} object "invalid request data"
// @Failure 500 {object} object "host node hasn't yet been initialized"
// @Failure 500 {object} object "failed to marshal message"
// @Failure 500 {object} object "destination address can't be resolved"
// @Failure 500 {object} object "failed to send message to destination"
// @Router /actor/invoke [post]
func (rs *Server) ActorInvoke(c *gin.Context) {
	// Open a span for consistency with the other actor handlers.
	endSpan := observability.StartSpan(c, "actor_invoke")
	defer endSpan()

	var msg actor.Envelope
	if err := c.ShouldBindJSON(&msg); err != nil {
		log.Errorw("actor_invoke_failure", "error", err)
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}
	p2p := rs.config.P2P
	if p2p == nil {
		log.Errorw("actor_invoke_failure", "error", ErrHostNotInitialized)
		c.JSON(http.StatusInternalServerError, gin.H{"error": ErrHostNotInitialized})
		return
	}

	// Register a message handler on the sender's reply protocol so the
	// invocation response can be captured.
	protocol := fmt.Sprintf("actor/%s/messages/0.0.1", msg.From.Address.InboxAddress)
	responseCh := make(chan actor.Envelope, 1)
	err := p2p.HandleMessage(protocol, func(data []byte, _ peer.ID) {
		var envelope actor.Envelope
		if err := json.Unmarshal(data, &envelope); err != nil {
			log.Errorw("actor_invoke_response_failure", "error", err)
			return
		}
		// Only the first response is consumed below; drop any extras instead
		// of blocking the network handler goroutine forever on a full channel.
		select {
		case responseCh <- envelope:
		default:
		}
	})
	if err != nil {
		log.Errorw("actor_invoke_failure", "error", err, "behavior", msg.Behavior)
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	// Unregister the message handler before returning
	defer p2p.UnregisterMessageHandler(protocol)

	err = sendMessage(c.Request.Context(), p2p, msg)
	if err != nil {
		log.Errorw("actor_invoke_failure", "error", err, "behavior", msg.Behavior)
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	// Wait for the response, the message expiry, or request cancellation.
	select {
	case responseMsg := <-responseCh:
		log.Debugw("actor_invoke_success", "destination", msg.To.Address.HostID)
		c.JSON(http.StatusOK, responseMsg)
		return
	case <-time.After(time.Until(msg.Expiry())):
		log.Errorw("actor_invoke_failure", "error", "request timeout")
		c.JSON(http.StatusRequestTimeout, gin.H{"error": "request timeout"})
		return
	case <-c.Request.Context().Done():
		log.Errorw("actor_invoke_failure", "error", "request timeout")
		c.JSON(http.StatusRequestTimeout, gin.H{"error": "request timeout"})
		return
	}
}
// ActorBroadcast godoc
//
// @Summary Broadcast message to actors
// @Description Broadcast message to actors
// @Tags actor
// @Accept json
// @Produce json
// @Param message body actor.Envelope true "Message to send"
// @Success 200 {object} object "received responses"
// @Failure 400 {object} object "invalid request data"
// @Failure 500 {object} object "host node hasn't yet been initialized"
// @Failure 500 {object} object "failed to marshal message"
// @Failure 500 {object} object "failed to publish message"
// @Router /actor/broadcast [post]
func (rs *Server) ActorBroadcast(c *gin.Context) {
	endSpan := observability.StartSpan(c, "actor_broadcast")
	defer endSpan()

	var msg actor.Envelope
	if err := c.ShouldBindJSON(&msg); err != nil {
		log.Errorw("actor_broadcast_failure", "error", err)
		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		return
	}
	p2p := rs.config.P2P
	if p2p == nil {
		log.Errorw("actor_broadcast_failure", "error", ErrHostNotInitialized)
		c.JSON(http.StatusInternalServerError, gin.H{"error": ErrHostNotInitialized})
		return
	}
	if !msg.IsBroadcast() {
		log.Errorw("actor_broadcast_failure", "error", "message is not a broadcast message")
		c.JSON(http.StatusBadRequest, gin.H{"error": "message is not a broadcast message"})
		return
	}
	data, err := json.Marshal(msg)
	if err != nil {
		log.Errorw("actor_broadcast_failure", "error", "failed to marshal message")
		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to marshal message"})
		return
	}

	// Register a handler on the sender's reply protocol to collect responses.
	// The handler runs on network goroutines, so every access to messages
	// must go through mu.
	protocol := fmt.Sprintf("actor/%s/messages/0.0.1", msg.From.Address.InboxAddress)
	var messages []actor.Envelope
	var mu sync.Mutex
	err = p2p.HandleMessage(protocol, func(data []byte, _ peer.ID) {
		var envelope actor.Envelope
		// Use a handler-local error: the previous code assigned to the
		// enclosing err variable, racing with the request goroutine.
		if uerr := json.Unmarshal(data, &envelope); uerr != nil {
			log.Errorw("actor_broadcast_failure", "error", "failed to unmarshal response message")
			return
		}
		mu.Lock()
		messages = append(messages, envelope)
		mu.Unlock()
	})
	if err != nil {
		log.Errorw("actor_broadcast_failure", "error", err)
		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	// Unregister the message handler before returning
	defer p2p.UnregisterMessageHandler(protocol)

	// Publish the message
	if err := p2p.Publish(c.Request.Context(), msg.Options.Topic, data); err != nil {
		log.Errorw("actor_broadcast_failure", "error", "failed to publish message")
		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to publish message"})
		return
	}

	// Wait for either the message expiry or request cancellation.
	select {
	case <-time.After(time.Until(msg.Expiry())):
		// message expiry time reached
	case <-c.Request.Context().Done():
		// request context done
	}

	// Snapshot under the lock: the handler can still fire until the deferred
	// unregister runs after this function returns.
	mu.Lock()
	responses := messages
	mu.Unlock()

	log.Infow("actor_broadcast_success", "fromAddress", msg.From.Address.HostID, "responsesCount", len(responses))
	c.JSON(http.StatusOK, responses)
}
// sendMessage serializes msg and delivers it synchronously to the inbox of
// the destination actor, waiting at most until the message expires.
func sendMessage(ctx context.Context, net network.Network, msg actor.Envelope) (err error) {
	payload, err := json.Marshal(msg)
	if err != nil {
		return fmt.Errorf("failed to marshal message: %w", err)
	}

	// The wire protocol is derived from the destination inbox address.
	wireEnvelope := types.MessageEnvelope{
		Type: types.MessageType(
			fmt.Sprintf("actor/%s/messages/0.0.1", msg.To.Address.InboxAddress),
		),
		Data: payload,
	}

	if err = net.SendMessageSync(ctx, msg.To.Address.HostID, wireEnvelope, msg.Expiry()); err != nil {
		return fmt.Errorf("failed to send message to %s: %w", msg.To.ID, err)
	}
	return nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package api
import (
"errors"
"fmt"
"net/http"
"time"
"github.com/gin-contrib/cors"
"github.com/gin-gonic/gin"
"gitlab.com/nunet/device-management-service/dms/onboarding"
"gitlab.com/nunet/device-management-service/internal/config"
"gitlab.com/nunet/device-management-service/network"
"gitlab.com/nunet/device-management-service/types"
)
// ServerConfig carries the dependencies and listen parameters of the REST server.
type ServerConfig struct {
	P2P         network.Network        // P2P network backing the actor endpoints; nil until the host is initialized
	Onboarding  *onboarding.Onboarding // onboarding subsystem handle
	Resource    types.ResourceManager  // resource manager handle
	Middlewares []gin.HandlerFunc      // extra middlewares installed before the CORS handler
	Port        uint32                 // TCP port to listen on
	Addr        string                 // address/interface to bind
}
// getCorsConfig returns the default CORS configuration used by the REST server.
func getCorsConfig() cors.Config {
	cfg := cors.Config{
		AllowCredentials: false,
		MaxAge:           12 * time.Hour,
	}
	cfg.AllowOrigins = []string{"http://localhost:9991", "http://localhost:9992"} // TODO: this is a security risk
	cfg.AllowMethods = []string{"GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"}
	cfg.AllowHeaders = []string{"Access-Control-Allow-Origin", "Origin", "Content-Length", "Content-Type"}
	return cfg
}
// setupRouter builds the gin engine with the caller-supplied middlewares
// followed by the default CORS handler.
func setupRouter(middlewares []gin.HandlerFunc) *gin.Engine {
	engine := gin.Default()
	engine.Use(append(middlewares, cors.New(getCorsConfig()))...)
	return engine
}
// Server represents a REST server
type Server struct {
	router    *gin.Engine    // gin engine with all routes and middlewares installed
	config    *ServerConfig  // server dependencies and listen address/port
	dmsConfig *config.Config // global DMS configuration (exposed via /config in debug mode)
}
// NewServer creates a new REST server with its router pre-configured from
// the given middlewares; call SetupRoutes and Run afterwards.
func NewServer(config *ServerConfig, dmsConfig *config.Config) *Server {
	server := &Server{
		config:    config,
		dmsConfig: dmsConfig,
		router:    setupRouter(config.Middlewares),
	}
	log.Infow("rest_server_init_success", "addr", config.Addr, "port", config.Port)
	return server
}
// HealthCheck is a health check endpoint reporting server liveness.
func (rs *Server) HealthCheck(c *gin.Context) {
	// Use the named status constant instead of a magic number, consistent
	// with the rest of the package.
	c.JSON(http.StatusOK, gin.H{"status": "ok"})
}
// Config returns the full DMS configuration, but only when debug mode is
// enabled; otherwise it responds with a placeholder string.
func (rs *Server) Config(c *gin.Context) {
	if rs.dmsConfig.General.Debug {
		c.JSON(http.StatusOK, gin.H{"config": rs.dmsConfig})
		return
	}
	c.JSON(http.StatusOK, gin.H{"config": "allowed in debug mode"})
}
// SetupRoutes sets up all the endpoint routes: /health, /config, and the
// /api/v1/actor group.
func (rs *Server) SetupRoutes() {
	rs.router.GET("/health", rs.HealthCheck)
	rs.router.GET("/config", rs.Config)

	// Actor endpoints live under the versioned API prefix.
	actorGroup := rs.router.Group("/api/v1").Group("/actor")
	actorGroup.GET("/handle", rs.ActorHandle)
	actorGroup.POST("/send", rs.ActorSendMessage)
	actorGroup.POST("/invoke", rs.ActorInvoke)
	actorGroup.POST("/broadcast", rs.ActorBroadcast)

	log.Infow("rest_server_route_init_success", "endpoint", "/api/v1/actor")
}
// Run starts the server on the configured address and port, blocking until
// it stops. A graceful shutdown (http.ErrServerClosed) is not an error.
func (rs *Server) Run() error {
	addr := fmt.Sprintf("%s:%d", rs.config.Addr, rs.config.Port)
	log.Infow("rest_server_starting", "addr", addr)

	err := rs.router.Run(addr)
	switch {
	case err == nil, errors.Is(err, http.ErrServerClosed):
		log.Infow("rest_server_run_success", "addr", addr)
		return nil
	default:
		log.Errorw("rest_server_run_failure", "addr", addr, "error", err)
		return err
	}
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package client
import (
"context"
"encoding/json"
"fmt"
"strings"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/lib/crypto"
"gitlab.com/nunet/device-management-service/lib/did"
)
// REST endpoints of the DMS actor API used by this client.
const (
	ActorHandleEndpoint      = "/actor/handle"
	ActorSendMessageEndpoint = "/actor/send"
	ActorInvokeEndpoint      = "/actor/invoke"
	ActorBroadcastEndpoint   = "/actor/broadcast"
)
// GetDMSHandle retrieves the DMS handle from the server, fetching it over
// the REST API on first use and serving the cached copy afterwards.
// NOTE(review): the cache field is written without synchronization —
// presumably the client is used from a single goroutine; verify.
func (c *Client) GetDMSHandle(ctx context.Context) (actor.Handle, error) {
	if c.dmsHandle.Empty() {
		if err := c.get(ctx, ActorHandleEndpoint, nil, &c.dmsHandle); err != nil {
			return actor.Handle{}, fmt.Errorf("get source handle: %w", err)
		}
	}
	return c.dmsHandle, nil
}
// parseDestinationHandle parses a destination string into a handle. It tries,
// in order: a DID ("did:" prefix), a JSON-encoded handle, and finally a raw
// peer ID.
func (c *Client) parseDestinationHandle(destStr string) (actor.Handle, error) {
	if len(destStr) == 0 {
		return actor.Handle{}, fmt.Errorf("empty destination string")
	}

	// DIDs are unambiguous — parse and return immediately.
	if strings.HasPrefix(destStr, "did:") {
		handle, err := actor.HandleFromDID(destStr)
		if err != nil {
			return actor.Handle{}, fmt.Errorf("failed to parse DID handle: %w", err)
		}
		return handle, nil
	}

	// A successful JSON decode wins; a failure just falls through.
	var decoded actor.Handle
	if json.Unmarshal([]byte(destStr), &decoded) == nil {
		return decoded, nil
	}

	// Last resort: interpret the string as a peer ID.
	handle, err := actor.HandleFromPeerID(destStr)
	if err != nil {
		return actor.Handle{}, fmt.Errorf("failed to parse peer ID handle: %w", err)
	}
	return handle, nil
}
// newUserHandle creates a new user handle that shares the DMS host but uses
// the given inbox address.
func (c *Client) newUserHandle(id crypto.ID, userDID did.DID, dmsHandle actor.Handle, inbox string) actor.Handle {
	address := actor.Address{
		HostID:       dmsHandle.Address.HostID,
		InboxAddress: inbox,
	}
	return actor.Handle{ID: id, DID: userDID, Address: address}
}
// unmarshalResponse decodes the JSON payload of a response envelope into v.
// An envelope with a nil message is treated as an empty response: v is left
// untouched and no error is returned.
// (The previous comment, "newClient creates a new client", described a
// different function.)
func (c *Client) unmarshalResponse(resp actor.Envelope, v any) error {
	if resp.Message == nil {
		return nil
	}
	if err := json.Unmarshal(resp.Message, v); err != nil {
		return fmt.Errorf("unmarshal response: %w", err)
	}
	return nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package client
import (
"context"
"fmt"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/lib/ucan"
)
// NewActorMessage builds a signed actor envelope for the given behavior and
// payload. The source is a fresh per-call user handle whose inbox is derived
// from a nonce; the destination comes from msgOpts — a broadcast topic, an
// explicit destination string, or the local DMS handle by default.
func (c *Client) NewActorMessage(ctx context.Context, behavior string, payload any, msgOpts MessageOptions) (actor.Envelope, error) {
	dmsHandle, err := c.GetDMSHandle(ctx)
	if err != nil {
		return actor.Envelope{}, fmt.Errorf("get DMS handle: %w", err)
	}
	// Create user handle
	nonce := c.sctx.Nonce()
	inbox := fmt.Sprintf("user-%d", nonce)
	src := c.newUserHandle(c.sctx.ID(), c.sctx.DID(), dmsHandle, inbox)
	// Handle destination
	var dest actor.Handle
	opts := []actor.MessageOption{}
	replyTo := ""
	// Configure behavior based on message type
	switch {
	case msgOpts.Topic != "":
		// Broadcast: replies are expected on a public per-user path.
		opts = append(opts, actor.WithMessageTopic(msgOpts.Topic))
		replyTo = fmt.Sprintf("/public/user/%d", nonce)
	case msgOpts.Destination != "":
		// Parse destination string into a handle
		var err error
		dest, err = c.parseDestinationHandle(msgOpts.Destination)
		if err != nil {
			return actor.Envelope{}, fmt.Errorf("create destination handle: %w", err)
		}
	default:
		dest = dmsHandle
	}
	// Handle invocation flag: invocations reply on a private per-user path
	// (this overrides the public path set for broadcasts above).
	if msgOpts.IsInvocation {
		replyTo = fmt.Sprintf("/private/user/%d", nonce)
	}
	// Handle expiry
	if !msgOpts.Expiry.IsZero() {
		opts = append(opts, actor.WithMessageExpiry(uint64(msgOpts.Expiry.UnixNano())))
	}
	// Handle timeout
	if msgOpts.Timeout > 0 {
		opts = append(opts, actor.WithMessageTimeout(msgOpts.Timeout))
	}
	// TODO: Do we delegate capabilities here?
	// Handle reply address; an explicit msgOpts.ReplyTo overrides any
	// derived reply path.
	// delegate := []ucan.Capability{}
	if replyTo != "" || msgOpts.ReplyTo != "" {
		if msgOpts.ReplyTo != "" {
			replyTo = msgOpts.ReplyTo
		}
		opts = append(opts, actor.WithMessageReplyTo(replyTo))
		// if msgOpts.Topic == "" {
		// 	delegate = append(delegate, ucan.Capability(replyTo))
		// }
	}
	// Add message signature; the signature must be the last option so it
	// covers the options added above.
	opts = append(opts, actor.WithMessageSignature(c.sctx, []ucan.Capability{ucan.Capability(behavior)}, []ucan.Capability{}))
	// Create the message
	msg, err := actor.Message(src, dest, behavior, payload, opts...)
	if err != nil {
		return actor.Envelope{}, fmt.Errorf("construct message: %w", err)
	}
	return msg, nil
}
// SendMessage builds an actor message for behavior/payload and posts it to
// the DMS send endpoint.
func (c *Client) SendMessage(ctx context.Context, behavior string, payload any, msgOpts ...Option) (actor.Envelope, error) {
	options := NewMessageOptions(msgOpts...)
	envelope, err := c.NewActorMessage(ctx, behavior, payload, options)
	if err != nil {
		return actor.Envelope{}, fmt.Errorf("create actor message: %w", err)
	}
	return c.SendMessageRaw(ctx, envelope)
}
// SendMessageRaw posts a prebuilt envelope to the DMS send endpoint and
// returns the response envelope.
func (c *Client) SendMessageRaw(ctx context.Context, msg actor.Envelope) (actor.Envelope, error) {
	var reply actor.Envelope
	if err := c.post(ctx, ActorSendMessageEndpoint, nil, msg, &reply); err != nil {
		return reply, fmt.Errorf("send message: %w", err)
	}
	return reply, nil
}
// InvokeBehavior invokes a behavior on an actor, marking the message as an
// invocation and bounding the call with the optional timeout.
func (c *Client) InvokeBehavior(ctx context.Context, behavior string, payload any, msgOpts ...Option) (actor.Envelope, error) {
	options := NewMessageOptions(msgOpts...)
	options.IsInvocation = true

	// Bound the whole round trip when a timeout was requested.
	if options.Timeout > 0 {
		timeoutCtx, cancel := context.WithTimeout(ctx, options.Timeout)
		defer cancel()
		ctx = timeoutCtx
	}

	envelope, err := c.NewActorMessage(ctx, behavior, payload, options)
	if err != nil {
		return actor.Envelope{}, fmt.Errorf("create actor message: %w", err)
	}
	return c.InvokeBehaviorRaw(ctx, envelope)
}
// InvokeBehaviorRaw posts a prebuilt invocation envelope to the DMS invoke
// endpoint and returns the response envelope.
func (c *Client) InvokeBehaviorRaw(ctx context.Context, msg actor.Envelope) (actor.Envelope, error) {
	var reply actor.Envelope
	if err := c.post(ctx, ActorInvokeEndpoint, nil, msg, &reply); err != nil {
		return reply, fmt.Errorf("invoke behavior: %w", err)
	}
	return reply, nil
}
// BroadcastMessage publishes a message for behavior on the given topic and
// returns the collected response envelopes.
func (c *Client) BroadcastMessage(ctx context.Context, behavior, topic string, payload any, msgOpts ...Option) ([]actor.Envelope, error) {
	options := NewMessageOptions(msgOpts...)
	options.Topic = topic
	// A broadcast without a topic is meaningless.
	if options.Topic == "" {
		return nil, fmt.Errorf("broadcast requires a topic")
	}
	envelope, err := c.NewActorMessage(ctx, behavior, payload, options)
	if err != nil {
		return nil, fmt.Errorf("create actor message: %w", err)
	}
	return c.BroadcastMessageRaw(ctx, envelope)
}
// BroadcastMessageRaw posts a prebuilt broadcast envelope to the DMS
// broadcast endpoint and returns the collected responses.
func (c *Client) BroadcastMessageRaw(ctx context.Context, msg actor.Envelope) ([]actor.Envelope, error) {
	var replies []actor.Envelope
	if err := c.post(ctx, ActorBroadcastEndpoint, nil, msg, &replies); err != nil {
		return nil, fmt.Errorf("broadcast message: %w", err)
	}
	return replies, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package client
import (
"context"
"fmt"
"gitlab.com/nunet/device-management-service/dms/behaviors"
"gitlab.com/nunet/device-management-service/dms/node"
)
// AllocationsList asks the DMS node for its allocations via the
// AllocationsListBehavior.
func (c *Client) AllocationsList(
	ctx context.Context,
	opts ...Option,
) (node.AllocationsListResponse, error) {
	var out node.AllocationsListResponse
	envelope, err := c.InvokeBehavior(ctx, behaviors.AllocationsListBehavior, nil, opts...)
	if err != nil {
		return out, fmt.Errorf("%s: %w", behaviors.AllocationsListBehavior, err)
	}
	err = c.unmarshalResponse(envelope, &out)
	return out, err
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package client
import (
"context"
"fmt"
"gitlab.com/nunet/device-management-service/dms/behaviors"
"gitlab.com/nunet/device-management-service/dms/node"
)
// CapList invokes the CapListBehavior and decodes the node's reply.
func (c *Client) CapList(ctx context.Context, req node.CapListRequest, opts ...Option) (node.CapListResponse, error) {
	var out node.CapListResponse
	envelope, err := c.InvokeBehavior(ctx, behaviors.CapListBehavior, req, opts...)
	if err != nil {
		return out, fmt.Errorf("%s: %w", behaviors.CapListBehavior, err)
	}
	err = c.unmarshalResponse(envelope, &out)
	return out, err
}
// ProvideCapAnchor invokes the ProvideCapAnchorBehavior and decodes the reply.
func (c *Client) ProvideCapAnchor(ctx context.Context, req node.CapTokenAnchorRequest, opts ...Option) (node.CapAnchorResponse, error) {
	var out node.CapAnchorResponse
	envelope, err := c.InvokeBehavior(ctx, behaviors.ProvideCapAnchorBehavior, req, opts...)
	if err != nil {
		return out, fmt.Errorf("%s: %w", behaviors.ProvideCapAnchorBehavior, err)
	}
	err = c.unmarshalResponse(envelope, &out)
	return out, err
}
// RequireCapAnchor invokes the RequireCapAnchorBehavior and decodes the reply.
func (c *Client) RequireCapAnchor(ctx context.Context, req node.CapTokenAnchorRequest, opts ...Option) (node.CapAnchorResponse, error) {
	var out node.CapAnchorResponse
	envelope, err := c.InvokeBehavior(ctx, behaviors.RequireCapAnchorBehavior, req, opts...)
	if err != nil {
		return out, fmt.Errorf("%s: %w", behaviors.RequireCapAnchorBehavior, err)
	}
	err = c.unmarshalResponse(envelope, &out)
	return out, err
}
// RevokeCapAnchor invokes the RevokeCapAnchorBehavior and decodes the reply.
func (c *Client) RevokeCapAnchor(ctx context.Context, req node.CapTokenAnchorRequest, opts ...Option) (node.CapAnchorResponse, error) {
	var out node.CapAnchorResponse
	envelope, err := c.InvokeBehavior(ctx, behaviors.RevokeCapAnchorBehavior, req, opts...)
	if err != nil {
		return out, fmt.Errorf("%s: %w", behaviors.RevokeCapAnchorBehavior, err)
	}
	err = c.unmarshalResponse(envelope, &out)
	return out, err
}
// BroadcastCapRevoke broadcasts a capability revocation message on the
// revocation topic and decodes every response received.
func (c *Client) BroadcastCapRevoke(ctx context.Context, req node.CapTokenAnchorRequest, msgOpts ...Option) ([]node.CapAnchorResponse, error) {
	envelopes, err := c.BroadcastMessage(
		ctx,
		behaviors.BroadcastRevokeCapBehavior,
		behaviors.BroadcastRevocationTopic,
		req,
		msgOpts...,
	)
	if err != nil {
		return nil, fmt.Errorf("%s: %w", behaviors.BroadcastRevokeCapBehavior, err)
	}

	out := make([]node.CapAnchorResponse, 0, len(envelopes))
	for _, envelope := range envelopes {
		var decoded node.CapAnchorResponse
		if err := c.unmarshalResponse(envelope, &decoded); err != nil {
			return nil, err
		}
		out = append(out, decoded)
	}
	return out, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package client
import (
"context"
"fmt"
"gitlab.com/nunet/device-management-service/dms/behaviors"
"gitlab.com/nunet/device-management-service/tokenomics/contracts"
)
// NewContract invokes the ContractCreateBehavior and decodes the reply.
func (c *Client) NewContract(ctx context.Context, req contracts.CreateContractRequest, opts ...Option) (contracts.CreateContractResponse, error) {
	var out contracts.CreateContractResponse
	envelope, err := c.InvokeBehavior(ctx, behaviors.ContractCreateBehavior, req, opts...)
	if err != nil {
		return out, fmt.Errorf("%s: %w", behaviors.ContractCreateBehavior, err)
	}
	err = c.unmarshalResponse(envelope, &out)
	return out, err
}
// ContractStatus invokes the ContractStatusBehavior and decodes the reply.
func (c *Client) ContractStatus(ctx context.Context, req contracts.ContractStatusRequest, opts ...Option) (contracts.ContractStatusResponse, error) {
	var out contracts.ContractStatusResponse
	envelope, err := c.InvokeBehavior(ctx, behaviors.ContractStatusBehavior, req, opts...)
	if err != nil {
		return out, fmt.Errorf("%s: %w", behaviors.ContractStatusBehavior, err)
	}
	err = c.unmarshalResponse(envelope, &out)
	return out, err
}
// ApproveLocal invokes the ContractApproveLocalBehavior and decodes the reply.
func (c *Client) ApproveLocal(ctx context.Context, req contracts.ContractApproveLocalRequest, opts ...Option) (contracts.ContractApproveLocalResponse, error) {
	var out contracts.ContractApproveLocalResponse
	envelope, err := c.InvokeBehavior(ctx, behaviors.ContractApproveLocalBehavior, req, opts...)
	if err != nil {
		return out, fmt.Errorf("%s: %w", behaviors.ContractApproveLocalBehavior, err)
	}
	err = c.unmarshalResponse(envelope, &out)
	return out, err
}
// ListIncoming invokes the ContractListBehavior and decodes the reply.
func (c *Client) ListIncoming(ctx context.Context, req contracts.ContractListIncomingRequest, opts ...Option) (contracts.ContractListIncomingResponse, error) {
	var out contracts.ContractListIncomingResponse
	envelope, err := c.InvokeBehavior(ctx, behaviors.ContractListBehavior, req, opts...)
	if err != nil {
		return out, fmt.Errorf("%s: %w", behaviors.ContractListBehavior, err)
	}
	err = c.unmarshalResponse(envelope, &out)
	return out, err
}
// ListTransactions invokes the ContractListLocalTransactionsBehavior (with an
// empty payload) and decodes the reply.
func (c *Client) ListTransactions(ctx context.Context, opts ...Option) (contracts.ContractListLocalTransactionsResponse, error) {
	var out contracts.ContractListLocalTransactionsResponse
	envelope, err := c.InvokeBehavior(ctx, behaviors.ContractListLocalTransactionsBehavior, struct{}{}, opts...)
	if err != nil {
		return out, fmt.Errorf("%s: %w", behaviors.ContractListLocalTransactionsBehavior, err)
	}
	err = c.unmarshalResponse(envelope, &out)
	return out, err
}
// ConfirmTransaction invokes the ContractConfirmLocalTransactionBehavior and
// decodes the reply.
func (c *Client) ConfirmTransaction(ctx context.Context, req contracts.ContractConfirmLocalTransactionRequest, opts ...Option) (contracts.ContractConfirmLocalTransactionResponse, error) {
	var out contracts.ContractConfirmLocalTransactionResponse
	envelope, err := c.InvokeBehavior(ctx, behaviors.ContractConfirmLocalTransactionBehavior, req, opts...)
	if err != nil {
		return out, fmt.Errorf("%s: %w", behaviors.ContractConfirmLocalTransactionBehavior, err)
	}
	err = c.unmarshalResponse(envelope, &out)
	return out, err
}
// CollectUsagesAndForwardToPaymentProviders invokes the
// ContractUsagesCalculateBehavior and decodes the reply.
func (c *Client) CollectUsagesAndForwardToPaymentProviders(ctx context.Context, req contracts.CollectUsagesAndForwardToPaymentProvidersRequest, opts ...Option) (contracts.CollectUsagesAndForwardToPaymentProvidersReponse, error) {
	var out contracts.CollectUsagesAndForwardToPaymentProvidersReponse
	envelope, err := c.InvokeBehavior(ctx, behaviors.ContractUsagesCalculateBehavior, req, opts...)
	if err != nil {
		return out, fmt.Errorf("%s: %w", behaviors.ContractUsagesCalculateBehavior, err)
	}
	err = c.unmarshalResponse(envelope, &out)
	return out, err
}
// GetPaymentStatus invokes the ContractPaymentStatusBehavior and decodes the reply.
func (c *Client) GetPaymentStatus(ctx context.Context, req contracts.ContractPaymentStatusRequest, opts ...Option) (contracts.ContractPaymentStatusResponse, error) {
	var out contracts.ContractPaymentStatusResponse
	envelope, err := c.InvokeBehavior(ctx, behaviors.ContractPaymentStatusBehavior, req, opts...)
	if err != nil {
		return out, fmt.Errorf("%s: %w", behaviors.ContractPaymentStatusBehavior, err)
	}
	err = c.unmarshalResponse(envelope, &out)
	return out, err
}
// TerminateContract invokes the ContractTerminationBehavior and decodes the reply.
func (c *Client) TerminateContract(ctx context.Context, req contracts.ContractTerminationRequest, opts ...Option) (contracts.ContractTerminationResponse, error) {
	var out contracts.ContractTerminationResponse
	envelope, err := c.InvokeBehavior(ctx, behaviors.ContractTerminationBehavior, req, opts...)
	if err != nil {
		return out, fmt.Errorf("%s: %w", behaviors.ContractTerminationBehavior, err)
	}
	err = c.unmarshalResponse(envelope, &out)
	return out, err
}
// SettleContract invokes the ContractSettleBehavior and decodes the reply.
func (c *Client) SettleContract(ctx context.Context, req contracts.ContractSettleRequest, opts ...Option) (contracts.ContractSettleResponse, error) {
	var out contracts.ContractSettleResponse
	envelope, err := c.InvokeBehavior(ctx, behaviors.ContractSettleBehavior, req, opts...)
	if err != nil {
		return out, fmt.Errorf("%s: %w", behaviors.ContractSettleBehavior, err)
	}
	err = c.unmarshalResponse(envelope, &out)
	return out, err
}
// CompleteContract invokes the ContractCompleteBehavior and decodes the reply.
func (c *Client) CompleteContract(ctx context.Context, req contracts.ContractCompletionRequest, opts ...Option) (contracts.ContractCompletionResponse, error) {
	var out contracts.ContractCompletionResponse
	envelope, err := c.InvokeBehavior(ctx, behaviors.ContractCompleteBehavior, req, opts...)
	if err != nil {
		return out, fmt.Errorf("%s: %w", behaviors.ContractCompleteBehavior, err)
	}
	err = c.unmarshalResponse(envelope, &out)
	return out, err
}
// ValidateContract invokes the ContractValidationBehavior and decodes the reply.
func (c *Client) ValidateContract(ctx context.Context, req contracts.ContractValidateRequest, opts ...Option) (contracts.ContractValidateResponse, error) {
	var out contracts.ContractValidateResponse
	envelope, err := c.InvokeBehavior(ctx, behaviors.ContractValidationBehavior, req, opts...)
	if err != nil {
		return out, fmt.Errorf("%s: %w", behaviors.ContractValidationBehavior, err)
	}
	err = c.unmarshalResponse(envelope, &out)
	return out, err
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package client
import (
"context"
"encoding/json"
"fmt"
"gitlab.com/nunet/device-management-service/dms/behaviors"
"gitlab.com/nunet/device-management-service/dms/node"
)
// DeploymentList invokes the DeploymentListBehavior and decodes the reply.
func (c *Client) DeploymentList(
	ctx context.Context, req node.DeploymentListRequest,
	opts ...Option,
) (node.DeploymentListResponse, error) {
	var out node.DeploymentListResponse
	envelope, err := c.InvokeBehavior(ctx, behaviors.DeploymentListBehavior, req, opts...)
	if err != nil {
		return out, fmt.Errorf("%s: %w", behaviors.DeploymentListBehavior, err)
	}
	err = c.unmarshalResponse(envelope, &out)
	return out, err
}
// DeploymentStatus invokes the DeploymentStatusBehavior and decodes the reply.
func (c *Client) DeploymentStatus(ctx context.Context, req node.DeploymentStatusRequest, opts ...Option) (node.DeploymentStatusResponse, error) {
	var out node.DeploymentStatusResponse
	envelope, err := c.InvokeBehavior(ctx, behaviors.DeploymentStatusBehavior, req, opts...)
	if err != nil {
		return out, fmt.Errorf("%s: %w", behaviors.DeploymentStatusBehavior, err)
	}
	err = c.unmarshalResponse(envelope, &out)
	return out, err
}
// DeploymentLogs invokes the deployment-logs behavior on the target DMS
// node and decodes the reply.
func (c *Client) DeploymentLogs(ctx context.Context, req node.DeploymentLogsRequest, opts ...Option) (node.DeploymentLogsResponse, error) {
	behavior := behaviors.DeploymentLogsBehavior
	var out node.DeploymentLogsResponse
	env, err := c.InvokeBehavior(ctx, behavior, req, opts...)
	if err != nil {
		return out, fmt.Errorf("%s: %w", behavior, err)
	}
	if err := c.unmarshalResponse(env, &out); err != nil {
		return out, err
	}
	return out, nil
}
// DeploymentManifest invokes the deployment-manifest behavior on the target
// DMS node and decodes the reply.
func (c *Client) DeploymentManifest(ctx context.Context, req node.DeploymentManifestRequest, opts ...Option) (node.DeploymentManifestResponse, error) {
	behavior := behaviors.DeploymentManifestBehavior
	var out node.DeploymentManifestResponse
	env, err := c.InvokeBehavior(ctx, behavior, req, opts...)
	if err != nil {
		return out, fmt.Errorf("%s: %w", behavior, err)
	}
	// NOTE(review): this method decodes the envelope inline with
	// json.Unmarshal while its siblings use c.unmarshalResponse — kept as-is
	// to preserve behavior (and this file's encoding/json import); confirm
	// whether the two decode paths are equivalent before unifying.
	if err := json.Unmarshal(env.Message, &out); err != nil {
		return out, fmt.Errorf("unmarshal response: %w", err)
	}
	return out, nil
}
// DeploymentInfo invokes the deployment-info behavior on the target DMS
// node and decodes the reply.
func (c *Client) DeploymentInfo(ctx context.Context, req node.DeploymentInfoRequest, opts ...Option) (node.DeploymentInfoResponse, error) {
	behavior := behaviors.DeploymentInfoBehavior
	var out node.DeploymentInfoResponse
	env, err := c.InvokeBehavior(ctx, behavior, req, opts...)
	if err != nil {
		return out, fmt.Errorf("%s: %w", behavior, err)
	}
	if err := c.unmarshalResponse(env, &out); err != nil {
		return out, err
	}
	return out, nil
}
// DeploymentShutdown invokes the deployment-shutdown behavior on the target
// DMS node and decodes the reply.
func (c *Client) DeploymentShutdown(ctx context.Context, req node.DeploymentShutdownRequest, opts ...Option) (node.DeploymentShutdownResponse, error) {
	behavior := behaviors.DeploymentShutdownBehavior
	var out node.DeploymentShutdownResponse
	env, err := c.InvokeBehavior(ctx, behavior, req, opts...)
	if err != nil {
		return out, fmt.Errorf("%s: %w", behavior, err)
	}
	if err := c.unmarshalResponse(env, &out); err != nil {
		return out, err
	}
	return out, nil
}
// DeploymentNew invokes the new-deployment behavior on the target DMS node
// and decodes the reply.
func (c *Client) DeploymentNew(ctx context.Context, req node.NewDeploymentRequest, opts ...Option) (node.NewDeploymentResponse, error) {
	behavior := behaviors.NewDeploymentBehavior
	var out node.NewDeploymentResponse
	env, err := c.InvokeBehavior(ctx, behavior, req, opts...)
	if err != nil {
		return out, fmt.Errorf("%s: %w", behavior, err)
	}
	if err := c.unmarshalResponse(env, &out); err != nil {
		return out, err
	}
	return out, nil
}
// DeploymentUpdate invokes the deployment-update behavior on the target DMS
// node and decodes the reply.
func (c *Client) DeploymentUpdate(
	ctx context.Context, req node.UpdateDeploymentRequest,
	opts ...Option,
) (node.UpdateDeploymentResponse, error) {
	behavior := behaviors.DeploymentUpdateBehavior
	var out node.UpdateDeploymentResponse
	env, err := c.InvokeBehavior(ctx, behavior, req, opts...)
	if err != nil {
		return out, fmt.Errorf("%s: %w", behavior, err)
	}
	if err := c.unmarshalResponse(env, &out); err != nil {
		return out, err
	}
	return out, nil
}
// DeploymentPrune invokes the deployment-prune behavior on the target DMS
// node and decodes the reply.
func (c *Client) DeploymentPrune(ctx context.Context, req node.DeploymentPruneRequest, opts ...Option) (node.DeploymentPruneResponse, error) {
	behavior := behaviors.DeploymentPruneBehavior
	var out node.DeploymentPruneResponse
	env, err := c.InvokeBehavior(ctx, behavior, req, opts...)
	if err != nil {
		return out, fmt.Errorf("%s: %w", behavior, err)
	}
	if err := c.unmarshalResponse(env, &out); err != nil {
		return out, err
	}
	return out, nil
}
// DeploymentDelete invokes the deployment-delete behavior on the target DMS
// node and decodes the reply.
func (c *Client) DeploymentDelete(ctx context.Context, req node.DeploymentDeleteRequest, opts ...Option) (node.DeploymentDeleteResponse, error) {
	behavior := behaviors.DeploymentDeleteBehavior
	var out node.DeploymentDeleteResponse
	env, err := c.InvokeBehavior(ctx, behavior, req, opts...)
	if err != nil {
		return out, fmt.Errorf("%s: %w", behavior, err)
	}
	if err := c.unmarshalResponse(env, &out); err != nil {
		return out, err
	}
	return out, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package client
import (
"context"
"fmt"
"gitlab.com/nunet/device-management-service/dms/behaviors"
"gitlab.com/nunet/device-management-service/dms/node"
)
// LoggerConfig invokes the logger-config behavior on the target DMS node
// and decodes the reply.
func (c *Client) LoggerConfig(ctx context.Context, req node.LoggerConfigRequest, opts ...Option) (node.LoggerConfigResponse, error) {
	behavior := behaviors.LoggerConfigBehavior
	var out node.LoggerConfigResponse
	env, err := c.InvokeBehavior(ctx, behavior, req, opts...)
	if err != nil {
		return out, fmt.Errorf("%s: %w", behavior, err)
	}
	if err := c.unmarshalResponse(env, &out); err != nil {
		return out, err
	}
	return out, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package client
import (
"context"
"fmt"
"gitlab.com/nunet/device-management-service/dms/behaviors"
"gitlab.com/nunet/device-management-service/dms/node"
)
// Onboard invokes the onboard behavior on the target DMS node and decodes
// the reply.
func (c *Client) Onboard(ctx context.Context, req node.OnboardRequest, opts ...Option) (node.OnboardResponse, error) {
	behavior := behaviors.OnboardBehavior
	var out node.OnboardResponse
	env, err := c.InvokeBehavior(ctx, behavior, req, opts...)
	if err != nil {
		return out, fmt.Errorf("%s: %w", behavior, err)
	}
	if err := c.unmarshalResponse(env, &out); err != nil {
		return out, err
	}
	return out, nil
}
// Offboard invokes the offboard behavior on the target DMS node and decodes
// the reply.
func (c *Client) Offboard(ctx context.Context, req node.OffboardRequest, opts ...Option) (node.OffboardResponse, error) {
	behavior := behaviors.OffboardBehavior
	var out node.OffboardResponse
	env, err := c.InvokeBehavior(ctx, behavior, req, opts...)
	if err != nil {
		return out, fmt.Errorf("%s: %w", behavior, err)
	}
	if err := c.unmarshalResponse(env, &out); err != nil {
		return out, err
	}
	return out, nil
}
// OnboardStatus invokes the onboard-status behavior (no payload) on the
// target DMS node and decodes the reply.
func (c *Client) OnboardStatus(ctx context.Context, opts ...Option) (node.OnboardStatusResponse, error) {
	behavior := behaviors.OnboardStatusBehavior
	var out node.OnboardStatusResponse
	env, err := c.InvokeBehavior(ctx, behavior, nil, opts...)
	if err != nil {
		return out, fmt.Errorf("%s: %w", behavior, err)
	}
	if err := c.unmarshalResponse(env, &out); err != nil {
		return out, err
	}
	return out, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package client
import (
"context"
"fmt"
"gitlab.com/nunet/device-management-service/dms/behaviors"
"gitlab.com/nunet/device-management-service/dms/node"
)
// PeersSelf invokes the peer-addr-info behavior (no payload) and decodes
// the node's own address info from the reply.
func (c *Client) PeersSelf(ctx context.Context, msgOpts ...Option) (node.PeerAddrInfoResponse, error) {
	behavior := behaviors.PeerAddrInfoBehavior
	var out node.PeerAddrInfoResponse
	env, err := c.InvokeBehavior(ctx, behavior, nil, msgOpts...)
	if err != nil {
		return out, fmt.Errorf("%s: %w", behavior, err)
	}
	if err := c.unmarshalResponse(env, &out); err != nil {
		return out, err
	}
	return out, nil
}
// PeersList invokes the peers-list behavior (no payload) and decodes the
// reply.
func (c *Client) PeersList(ctx context.Context, msgOpts ...Option) (node.PeersListResponse, error) {
	behavior := behaviors.PeersListBehavior
	var out node.PeersListResponse
	env, err := c.InvokeBehavior(ctx, behavior, nil, msgOpts...)
	if err != nil {
		return out, fmt.Errorf("%s: %w", behavior, err)
	}
	if err := c.unmarshalResponse(env, &out); err != nil {
		return out, err
	}
	return out, nil
}
// PeersListFromDHT invokes the peer-DHT behavior (no payload) and decodes
// the reply.
func (c *Client) PeersListFromDHT(ctx context.Context, msgOpts ...Option) (node.PeerDHTResponse, error) {
	behavior := behaviors.PeerDHTBehavior
	var out node.PeerDHTResponse
	env, err := c.InvokeBehavior(ctx, behavior, nil, msgOpts...)
	if err != nil {
		return out, fmt.Errorf("%s: %w", behavior, err)
	}
	if err := c.unmarshalResponse(env, &out); err != nil {
		return out, err
	}
	return out, nil
}
// PeerPing invokes the peer-ping behavior with the given request and
// decodes the reply.
func (c *Client) PeerPing(ctx context.Context, req node.PingRequest, msgOpts ...Option) (node.PingResponse, error) {
	behavior := behaviors.PeerPingBehavior
	var out node.PingResponse
	env, err := c.InvokeBehavior(ctx, behavior, req, msgOpts...)
	if err != nil {
		return out, fmt.Errorf("%s: %w", behavior, err)
	}
	if err := c.unmarshalResponse(env, &out); err != nil {
		return out, err
	}
	return out, nil
}
// PeerConnect invokes the peer-connect behavior with the given request and
// decodes the reply.
func (c *Client) PeerConnect(ctx context.Context, req node.PeerConnectRequest, msgOpts ...Option) (node.PeerConnectResponse, error) {
	behavior := behaviors.PeerConnectBehavior
	var out node.PeerConnectResponse
	env, err := c.InvokeBehavior(ctx, behavior, req, msgOpts...)
	if err != nil {
		return out, fmt.Errorf("%s: %w", behavior, err)
	}
	if err := c.unmarshalResponse(env, &out); err != nil {
		return out, err
	}
	return out, nil
}
// PeerScore invokes the peer-score behavior (no payload) and decodes the
// reply.
func (c *Client) PeerScore(ctx context.Context, msgOpts ...Option) (node.PeerScoreResponse, error) {
	behavior := behaviors.PeerScoreBehavior
	var out node.PeerScoreResponse
	env, err := c.InvokeBehavior(ctx, behavior, nil, msgOpts...)
	if err != nil {
		return out, fmt.Errorf("%s: %w", behavior, err)
	}
	if err := c.unmarshalResponse(env, &out); err != nil {
		return out, err
	}
	return out, nil
}
// Flightrec invokes the debug flight-recorder behavior (no payload).
// NOTE(review): the reply is decoded into node.PingResponse, which is what
// the DmsClient interface declares — confirm this is the intended response
// type for a flight-recorder dump.
func (c *Client) Flightrec(ctx context.Context, msgOpts ...Option) (node.PingResponse, error) {
	behavior := behaviors.DebugFlightrecBehavior
	var out node.PingResponse
	env, err := c.InvokeBehavior(ctx, behavior, nil, msgOpts...)
	if err != nil {
		return out, fmt.Errorf("%s: %w", behavior, err)
	}
	if err := c.unmarshalResponse(env, &out); err != nil {
		return out, err
	}
	return out, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package client
import (
"context"
"fmt"
"gitlab.com/nunet/device-management-service/dms/behaviors"
"gitlab.com/nunet/device-management-service/dms/node"
)
// Hello invokes the public hello behavior (no payload) on the target DMS
// node and decodes the reply.
func (c *Client) Hello(ctx context.Context, msgOpts ...Option) (node.HelloResponse, error) {
	var response node.HelloResponse
	resp, err := c.InvokeBehavior(
		ctx,
		behaviors.PublicHelloBehavior,
		nil,
		msgOpts...,
	)
	if err != nil {
		// Fix: the error was previously wrapped with PublicStatusBehavior
		// (copy-paste from Status) although PublicHelloBehavior was invoked.
		return response, fmt.Errorf("%s: %w", behaviors.PublicHelloBehavior, err)
	}
	err = c.unmarshalResponse(resp, &response)
	return response, err
}
// BroadcastHello broadcasts a hello message on the hello topic and decodes
// every reply that comes back.
func (c *Client) BroadcastHello(ctx context.Context, msgOpts ...Option) ([]node.HelloResponse, error) {
	envelopes, err := c.BroadcastMessage(
		ctx,
		behaviors.BroadcastHelloBehavior,
		behaviors.BroadcastHelloTopic,
		nil,
		msgOpts...,
	)
	if err != nil {
		return nil, fmt.Errorf("%s: %w", behaviors.BroadcastHelloBehavior, err)
	}
	replies := make([]node.HelloResponse, 0, len(envelopes))
	for _, env := range envelopes {
		var reply node.HelloResponse
		if err := c.unmarshalResponse(env, &reply); err != nil {
			return nil, err
		}
		replies = append(replies, reply)
	}
	return replies, nil
}
// Status invokes the public status behavior (no payload) and decodes the
// reply.
func (c *Client) Status(ctx context.Context, msgOpts ...Option) (node.PublicStatusResponse, error) {
	behavior := behaviors.PublicStatusBehavior
	var out node.PublicStatusResponse
	env, err := c.InvokeBehavior(ctx, behavior, nil, msgOpts...)
	if err != nil {
		return out, fmt.Errorf("%s: %w", behavior, err)
	}
	if err := c.unmarshalResponse(env, &out); err != nil {
		return out, err
	}
	return out, nil
}
// Discovery invokes the status-discovery behavior (no payload) and decodes
// the reply.
func (c *Client) Discovery(ctx context.Context, msgOpts ...Option) (node.DiscoveryStatusResponse, error) {
	behavior := behaviors.StatusDiscoveryBehavior
	var out node.DiscoveryStatusResponse
	env, err := c.InvokeBehavior(ctx, behavior, nil, msgOpts...)
	if err != nil {
		return out, fmt.Errorf("%s: %w", behavior, err)
	}
	if err := c.unmarshalResponse(env, &out); err != nil {
		return out, err
	}
	return out, nil
}
// DiscoveryBroadcast broadcasts a status-discovery request on the discovery
// topic and decodes every reply that comes back.
func (c *Client) DiscoveryBroadcast(ctx context.Context, msgOpts ...Option) ([]node.DiscoveryStatusResponse, error) {
	resp, err := c.BroadcastMessage(
		ctx,
		behaviors.BroadcastStatusDiscoveryBehavior,
		behaviors.BroadcastStatusDiscoveryTopic,
		nil,
		msgOpts...,
	)
	if err != nil {
		return nil, fmt.Errorf("%s: %w", behaviors.BroadcastStatusDiscoveryBehavior, err)
	}
	response := make([]node.DiscoveryStatusResponse, 0, len(resp))
	for _, r := range resp {
		var msg node.DiscoveryStatusResponse
		if err := c.unmarshalResponse(r, &msg); err != nil {
			return nil, err
		}
		response = append(response, msg)
	}
	// err is necessarily nil at this point; return nil explicitly for
	// clarity and consistency with BroadcastHello (was `return response, err`).
	return response, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package client
import (
"context"
"fmt"
"gitlab.com/nunet/device-management-service/dms/behaviors"
"gitlab.com/nunet/device-management-service/dms/node"
)
// ResourcesAllocated invokes the resources-allocated behavior (no payload)
// and decodes the reply.
func (c *Client) ResourcesAllocated(ctx context.Context, opts ...Option) (node.ResourcesResponse, error) {
	behavior := behaviors.ResourcesAllocatedBehavior
	var out node.ResourcesResponse
	env, err := c.InvokeBehavior(ctx, behavior, nil, opts...)
	if err != nil {
		return out, fmt.Errorf("%s: %w", behavior, err)
	}
	if err := c.unmarshalResponse(env, &out); err != nil {
		return out, err
	}
	return out, nil
}
// ResourcesFree invokes the resources-free behavior (no payload) and
// decodes the reply.
func (c *Client) ResourcesFree(ctx context.Context, opts ...Option) (node.ResourcesResponse, error) {
	behavior := behaviors.ResourcesFreeBehavior
	var out node.ResourcesResponse
	env, err := c.InvokeBehavior(ctx, behavior, nil, opts...)
	if err != nil {
		return out, fmt.Errorf("%s: %w", behavior, err)
	}
	if err := c.unmarshalResponse(env, &out); err != nil {
		return out, err
	}
	return out, nil
}
// ResourcesOnboarded invokes the resources-onboarded behavior (no payload)
// and decodes the reply.
func (c *Client) ResourcesOnboarded(ctx context.Context, opts ...Option) (node.ResourcesResponse, error) {
	behavior := behaviors.ResourcesOnboardedBehavior
	var out node.ResourcesResponse
	env, err := c.InvokeBehavior(ctx, behavior, nil, opts...)
	if err != nil {
		return out, fmt.Errorf("%s: %w", behavior, err)
	}
	if err := c.unmarshalResponse(env, &out); err != nil {
		return out, err
	}
	return out, nil
}
// HardwareSpec invokes the hardware-spec behavior (no payload) and decodes
// the reply.
func (c *Client) HardwareSpec(ctx context.Context, opts ...Option) (node.ResourcesResponse, error) {
	behavior := behaviors.HardwareSpecBehavior
	var out node.ResourcesResponse
	env, err := c.InvokeBehavior(ctx, behavior, nil, opts...)
	if err != nil {
		return out, fmt.Errorf("%s: %w", behavior, err)
	}
	if err := c.unmarshalResponse(env, &out); err != nil {
		return out, err
	}
	return out, nil
}
// HardwareUsage invokes the hardware-usage behavior (no payload) and
// decodes the reply.
func (c *Client) HardwareUsage(ctx context.Context, opts ...Option) (node.ResourcesResponse, error) {
	behavior := behaviors.HardwareUsageBehavior
	var out node.ResourcesResponse
	env, err := c.InvokeBehavior(ctx, behavior, nil, opts...)
	if err != nil {
		return out, fmt.Errorf("%s: %w", behavior, err)
	}
	if err := c.unmarshalResponse(env, &out); err != nil {
		return out, err
	}
	return out, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package client
import (
"context"
"fmt"
"gitlab.com/nunet/device-management-service/dms/behaviors"
"gitlab.com/nunet/device-management-service/dms/node"
)
// CreateVolume invokes the volume-create behavior with the given request
// and decodes the reply.
func (c *Client) CreateVolume(ctx context.Context, req node.CreateVolumeRequest, opts ...Option) (node.CreateVolumeResponse, error) {
	behavior := behaviors.VolumeCreateBehavior
	var out node.CreateVolumeResponse
	env, err := c.InvokeBehavior(ctx, behavior, req, opts...)
	if err != nil {
		return out, fmt.Errorf("%s: %w", behavior, err)
	}
	if err := c.unmarshalResponse(env, &out); err != nil {
		return out, err
	}
	return out, nil
}
// DeleteVolume invokes the volume-delete behavior with the given request
// and decodes the reply.
func (c *Client) DeleteVolume(ctx context.Context, req node.DeleteVolumeRequest, opts ...Option) (node.DeleteVolumeResponse, error) {
	behavior := behaviors.VolumeDeleteBehavior
	var out node.DeleteVolumeResponse
	env, err := c.InvokeBehavior(ctx, behavior, req, opts...)
	if err != nil {
		return out, fmt.Errorf("%s: %w", behavior, err)
	}
	if err := c.unmarshalResponse(env, &out); err != nil {
		return out, err
	}
	return out, nil
}
// StartVolume invokes the volume-start behavior with the given request and
// decodes the reply.
func (c *Client) StartVolume(ctx context.Context, req node.StartVolumeRequest, opts ...Option) (node.StartVolumeResponse, error) {
	behavior := behaviors.VolumeStartBehavior
	var out node.StartVolumeResponse
	env, err := c.InvokeBehavior(ctx, behavior, req, opts...)
	if err != nil {
		return out, fmt.Errorf("%s: %w", behavior, err)
	}
	if err := c.unmarshalResponse(env, &out); err != nil {
		return out, err
	}
	return out, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package client
import (
"bytes"
"context"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/url"
"os"
"path"
"reflect"
"strings"
"time"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/lib/crypto"
"gitlab.com/nunet/device-management-service/lib/did"
"gitlab.com/nunet/device-management-service/lib/ucan"
)
// ConnectionType represents the different ways to connect to the DMS
type ConnectionType string

const (
	// ConnectionTCP connects over a TCP socket (HTTPS when Config.TLSConfig is set).
	ConnectionTCP ConnectionType = "tcp"
	// ConnectionUnixSocket connects over a local Unix domain socket.
	ConnectionUnixSocket ConnectionType = "unix"
	// ConnectionNPipe connects over a named pipe.
	ConnectionNPipe ConnectionType = "npipe"
)
// Config provides configuration for the DMS client
type Config struct {
	// Connection details
	// Host is used verbatim as the request URL host.
	Host string
	// Protocol selects the transport type (tcp, unix, npipe); defaults to tcp.
	Protocol ConnectionType
	// APIPrefix is prepended to every request path.
	APIPrefix string
	// Version, when set, is inserted into the path as "/v<version>".
	Version string
	// TLS configuration; when non-nil, requests use the https scheme.
	TLSConfig *tls.Config
	// Timeouts
	// ConnectTimeout bounds dialing a connection (defaults to 10s).
	ConnectTimeout time.Duration
	// RequestTimeout bounds a whole request (defaults to 30s).
	RequestTimeout time.Duration
}
// Client represents the main client for interacting with the DMS
type Client struct {
	// HTTP client for making requests
	httpClient *http.Client
	// Connection details (copied from Config at construction time)
	host      string
	protocol  ConnectionType
	apiPrefix string
	version   string
	// Options for client behavior (the full Config the client was built with)
	options Config
	// Actor options
	// sctx is the actor security context for this client session.
	sctx actor.SecurityContext
	// dmsHandle is an actor handle — presumably the local DMS actor; its use
	// is not visible in this file, confirm against the message-sending code.
	dmsHandle actor.Handle
}
// Compile-time assertion that *Client implements the DmsClient interface.
var _ DmsClient = (*Client)(nil)
// NewClientSecurityContext builds an actor security context for a client
// session: it reads a long-term private key from priv, loads UCAN
// capability data from capData, and pairs them with a freshly generated
// ephemeral Ed25519 key pair for this session.
func NewClientSecurityContext(priv io.Reader, capData io.Reader) (actor.SecurityContext, error) {
	// Ephemeral session key pair.
	privk, pubk, err := crypto.GenerateKeyPair(crypto.Ed25519)
	if err != nil {
		return nil, fmt.Errorf("generate ephemeral key pair: %w", err)
	}

	// Long-term trust key -> trust context.
	keyBytes, err := io.ReadAll(priv)
	if err != nil {
		return nil, fmt.Errorf("read private key: %w", err)
	}
	trustKey, err := crypto.BytesToPrivateKey(keyBytes)
	if err != nil {
		return nil, fmt.Errorf("unmarshal private key: %w", err)
	}
	trustCtx, err := did.NewTrustContextWithPrivateKey(trustKey)
	if err != nil {
		return nil, fmt.Errorf("create trust context: %w", err)
	}

	// Capability context layered on top of the trust context.
	capCtx, err := ucan.LoadCapabilityContext(trustCtx, capData)
	if err != nil {
		return nil, fmt.Errorf("create capability context: %w", err)
	}

	return actor.NewBasicSecurityContext(pubk, privk, capCtx)
}
// NewClient creates a new DMS client with the given options, building a
// transport appropriate for the configured connection type.
func NewClient(cfg Config, securityContext actor.SecurityContext) (*Client, error) {
	transport, err := createTransport(cfg)
	if err != nil {
		return nil, fmt.Errorf("failed to create transport: %w", err)
	}
	return NewClientWithTransport(cfg, transport, securityContext)
}
// NewClientWithTransport creates a DMS client that uses the provided HTTP
// transport, filling in default timeouts and protocol when unset.
func NewClientWithTransport(cfg Config, transport http.RoundTripper, securityContext actor.SecurityContext) (*Client, error) {
	// Apply defaults for anything left at the zero value.
	if cfg.ConnectTimeout == 0 {
		cfg.ConnectTimeout = 10 * time.Second
	}
	if cfg.RequestTimeout == 0 {
		cfg.RequestTimeout = 30 * time.Second
	}
	if cfg.Protocol == "" {
		cfg.Protocol = ConnectionTCP
	}

	return &Client{
		httpClient: &http.Client{
			Transport: transport,
			// Slightly longer than the per-request timeout to avoid a race.
			Timeout: cfg.RequestTimeout + 100*time.Millisecond,
		},
		host:      cfg.Host,
		protocol:  cfg.Protocol,
		apiPrefix: cfg.APIPrefix,
		version:   cfg.Version,
		sctx:      securityContext,
		options:   cfg,
	}, nil
}
// createTransport creates a custom transport based on the configured
// connection type; TCP is the fallback.
func createTransport(opts Config) (http.RoundTripper, error) {
	if opts.Protocol == ConnectionUnixSocket {
		return createUnixSocketTransport(opts)
	}
	if opts.Protocol == ConnectionNPipe {
		return createNPipeTransport(opts)
	}
	return createTCPTransport(opts)
}
// createTCPTransport sets up a TCP transport with TLS support and
// connection-pool limits.
func createTCPTransport(opts Config) (http.RoundTripper, error) {
	// Dialer bounded by the configured connect timeout.
	dialer := net.Dialer{Timeout: opts.ConnectTimeout}
	return &http.Transport{
		DialContext:         dialer.DialContext,
		TLSClientConfig:     opts.TLSConfig,
		MaxIdleConns:        10,
		IdleConnTimeout:     90 * time.Second,
		TLSHandshakeTimeout: 10 * time.Second,
	}, nil
}
// getAPIPath returns the versioned request path to call the API, with the
// query string encoded.
func (c *Client) getAPIPath(p string, query url.Values) string {
	segments := []string{c.apiPrefix}
	if c.version != "" {
		// Normalize to a single leading "v", e.g. "V1" or "1" -> "/v1".
		segments = append(segments, "/v"+strings.TrimPrefix(strings.ToLower(c.version), "v"))
	}
	segments = append(segments, p)
	u := url.URL{Path: path.Join(segments...), RawQuery: query.Encode()}
	return u.String()
}
// encodeBody JSON-encodes obj into a reader and sets the Content-Type
// header. A nil obj — including a typed-nil pointer — yields no body.
func (c *Client) encodeBody(obj interface{}, headers http.Header) (io.Reader, http.Header, error) {
	if obj == nil {
		return nil, headers, nil
	}
	// encoding/json serializes a nil pointer as the literal `null`,
	// irrespective of json.Marshaler/encoding.TextMarshaler implementations;
	// that is almost never the intended request body, so treat it as no body.
	if v := reflect.ValueOf(obj); v.Kind() == reflect.Ptr && v.IsNil() {
		return nil, headers, nil
	}
	var buf bytes.Buffer
	if err := json.NewEncoder(&buf).Encode(obj); err != nil {
		return nil, headers, err
	}
	if headers == nil {
		headers = make(http.Header)
	}
	headers["Content-Type"] = []string{"application/json"}
	return &buf, headers, nil
}
// addHeaders copies the given headers onto req, canonicalizing each key,
// and returns req for chaining.
func (c *Client) addHeaders(req *http.Request, headers http.Header) *http.Request {
	for key, values := range headers {
		req.Header[http.CanonicalHeaderKey(key)] = values
	}
	return req
}
// buildRequest constructs an HTTP request for the given method, path and
// query, JSON-encoding body when non-nil and filling in scheme and host
// from the client configuration.
func (c *Client) buildRequest(ctx context.Context, method, path string, query url.Values, body any) (*http.Request, error) {
	reqBody, headers, err := c.encodeBody(body, nil)
	if err != nil {
		return nil, err
	}
	req, err := http.NewRequestWithContext(ctx, method, c.getAPIPath(path, query), reqBody)
	if err != nil {
		return nil, err
	}
	req = c.addHeaders(req, headers)

	// Scheme follows the TLS configuration; host comes from the client.
	scheme := "http"
	if c.options.TLSConfig != nil {
		scheme = "https"
	}
	req.URL.Scheme = scheme
	req.URL.Host = c.host

	// Fall back to text/plain when a body is present but no Content-Type set.
	if reqBody != nil && req.Header.Get("Content-Type") == "" {
		req.Header.Set("Content-Type", "text/plain")
	}
	return req, nil
}
// ParseResponse is a utility method to parse JSON response body
func parseResponse(resp *http.Response, target any) error {
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("failed to read response body: %w", err)
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
fmt.Println(resp)
return fmt.Errorf("request failed with status %d: %s", resp.StatusCode, string(body))
}
if body != nil && target != nil {
if err := json.Unmarshal(body, target); err != nil {
return fmt.Errorf("failed to parse response JSON: %w", err)
}
}
return nil
}
// do executes the request via the underlying HTTP client and translates
// common transport failures into more descriptive errors. The errors.Is/As
// checks are ordered: context cancellation/deadline first, then scheme
// mismatches, TLS verification, and Unix-socket specific failures; the
// original error is always wrapped so callers can still errors.Is/As it.
// NOTE: despite the original comment, no retry logic is implemented here.
func (c *Client) do(req *http.Request) (*http.Response, error) {
	resp, err := c.httpClient.Do(req)
	// Check for specific error types
	if err != nil {
		// Handle context errors
		if errors.Is(err, context.Canceled) {
			return nil, fmt.Errorf("request canceled: %w", err)
		}
		if errors.Is(err, context.DeadlineExceeded) {
			return nil, fmt.Errorf("request deadline exceeded: %w", err)
		}
		// If error is EOF it is probably trying to connect to https server with http client
		if errors.Is(err, io.EOF) {
			return nil, fmt.Errorf("tried to connect to https server with http client: %w", err)
		}
		if errors.Is(err, http.ErrSchemeMismatch) {
			return nil, fmt.Errorf("tried to connect to http server with https client: %w", err)
		}
		// Handle TLS errors
		var tlsErr *tls.CertificateVerificationError
		if errors.As(err, &tlsErr) {
			return nil, fmt.Errorf("tls certificate verification failed: %w", err)
		}
		// Handle socket/permission-related errors
		if c.protocol == ConnectionUnixSocket {
			if os.IsPermission(err) {
				return nil, fmt.Errorf("permission denied for Unix socket: %w", err)
			}
			if os.IsNotExist(err) {
				return nil, fmt.Errorf("unix socket does not exist: %w", err)
			}
		}
		// Fallback for anything not recognized above.
		return nil, fmt.Errorf("request failed: %w", err)
	}
	return resp, nil
}
// sendRequest builds the request for method/path/query/body and sends it,
// returning the raw response.
func (c *Client) sendRequest(ctx context.Context, method, path string, query url.Values, body any) (*http.Response, error) {
	req, err := c.buildRequest(ctx, method, path, query, body)
	if err != nil {
		return nil, err
	}
	return c.do(req)
}
// get performs a GET request and decodes the JSON response into target.
func (c *Client) get(ctx context.Context, path string, query url.Values, target any) error {
	resp, err := c.sendRequest(ctx, http.MethodGet, path, query, nil)
	if err != nil {
		return err
	}
	return parseResponse(resp, target)
}
// post performs a POST request with the given body and decodes the JSON
// response into target.
func (c *Client) post(ctx context.Context, path string, query url.Values, body any, target any) error {
	resp, err := c.sendRequest(ctx, http.MethodPost, path, query, body)
	if err != nil {
		return err
	}
	return parseResponse(resp, target)
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package client
import (
"context"
"time"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/dms/node"
"gitlab.com/nunet/device-management-service/tokenomics/contracts"
)
// MessageOptions contains common options for actor message operations.
type MessageOptions struct {
	// Timeout is the timeout duration for the operation (equivalent to the --timeout/-t flag).
	Timeout time.Duration
	// Expiry is the expiration time for the message (equivalent to the --expiry/-e flag).
	Expiry time.Time
	// Destination is the destination DMS DID, peer ID or handle (equivalent to the --dest/-d flag).
	Destination string
	// Topic is the pubsub topic for broadcast messages; setting it via
	// WithTopic also clears IsInvocation.
	Topic string
	// IsInvocation indicates whether the message is an invocation.
	IsInvocation bool
	// ReplyTo specifies the reply address for the message.
	ReplyTo string
}
// NewMessageOptions builds a MessageOptions starting from the zero value
// (no timeout, no expiry, unicast) and applies each supplied Option in order.
func NewMessageOptions(msgOpts ...Option) MessageOptions {
	var opts MessageOptions
	for _, apply := range msgOpts {
		apply(&opts)
	}
	return opts
}
// Option is a function that configures a MessageOptions in place.
type Option func(*MessageOptions)
// WithTimeout sets the timeout duration for the operation.
func WithTimeout(timeout time.Duration) Option {
	return func(opts *MessageOptions) {
		opts.Timeout = timeout
	}
}
// WithExpiry sets the expiration time for the message.
func WithExpiry(expiry time.Time) Option {
	return func(opts *MessageOptions) {
		opts.Expiry = expiry
	}
}
// WithDestination sets the destination DMS DID, peer ID or handle.
func WithDestination(destination string) Option {
	return func(opts *MessageOptions) {
		opts.Destination = destination
	}
}
// WithTopic sets the topic for broadcast messages. Setting a topic also
// clears IsInvocation, marking the message as a broadcast.
func WithTopic(topic string) Option {
	return func(opts *MessageOptions) {
		opts.Topic = topic
		opts.IsInvocation = false
	}
}
// WithInvocation sets whether the message is an invocation.
func WithInvocation(isInvocation bool) Option {
	return func(opts *MessageOptions) {
		opts.IsInvocation = isInvocation
	}
}
// WithReplyTo sets the reply address for the message.
func WithReplyTo(replyTo string) Option {
	return func(opts *MessageOptions) {
		opts.ReplyTo = replyTo
	}
}
// DmsClient is the top-level client interface for the DMS service. It
// combines raw actor messaging (ActorClient) with the typed behavior
// wrappers (ActorBehaviorsClient).
type DmsClient interface {
	ActorClient
	ActorBehaviorsClient
}
// ActorClient provides methods for general actor message operations. The
// Raw variants take a pre-built envelope; the non-Raw variants build the
// envelope from a behavior name, payload and options first.
type ActorClient interface {
	// NewActorMessage creates a new actor message envelope.
	NewActorMessage(ctx context.Context, behavior string, payload any, msgOpts MessageOptions) (actor.Envelope, error)
	// SendMessageRaw sends a pre-built message to a specific actor.
	SendMessageRaw(ctx context.Context, msg actor.Envelope) (actor.Envelope, error)
	// SendMessage creates a new actor message and sends it to a specific actor.
	SendMessage(ctx context.Context, behavior string, payload any, msgOpts ...Option) (actor.Envelope, error)
	// InvokeBehaviorRaw invokes a behavior on an actor with a pre-built message.
	InvokeBehaviorRaw(ctx context.Context, msg actor.Envelope) (actor.Envelope, error)
	// InvokeBehavior creates a new actor message and invokes a behavior on an actor.
	InvokeBehavior(ctx context.Context, behavior string, payload any, msgOpts ...Option) (actor.Envelope, error)
	// BroadcastMessageRaw broadcasts a pre-built message to a topic.
	BroadcastMessageRaw(ctx context.Context, msg actor.Envelope) ([]actor.Envelope, error)
	// BroadcastMessage creates a new actor message and broadcasts it to a topic.
	BroadcastMessage(ctx context.Context, behavior string, topic string, payload any, msgOpts ...Option) ([]actor.Envelope, error)
}
// ActorBehaviorsClient provides access to actor behavior methods, grouped
// by functional area (public info, peers, onboarding, deployments, etc.).
type ActorBehaviorsClient interface {
	ActorPublicBehaviorClient
	ActorPeersBehaviorClient
	ActorOnboardingBehaviorClient
	ActorDeploymentBehaviorClient
	ActorAllocationsBehaviorClient
	ActorResourcesBehaviorClient
	ActorHardwareBehaviorClient
	ActorCapBehaviorClient
	ActorLoggerBehaviorClient
	ActorVolumeBehaviorClient
	ActorContractBehaviorClient
}
// ActorPublicBehaviorClient provides methods for public behaviors.
type ActorPublicBehaviorClient interface {
	// Hello sends a hello message.
	Hello(ctx context.Context, opts ...Option) (node.HelloResponse, error)
	// BroadcastHello broadcasts a hello message to a topic.
	BroadcastHello(ctx context.Context, opts ...Option) ([]node.HelloResponse, error)
	// Status retrieves the status of the actor.
	Status(ctx context.Context, opts ...Option) (node.PublicStatusResponse, error)
	// Discovery retrieves the discovery information of the actor.
	Discovery(ctx context.Context, opts ...Option) (node.DiscoveryStatusResponse, error)
	// DiscoveryBroadcast broadcasts the discovery information of the actor.
	DiscoveryBroadcast(ctx context.Context, opts ...Option) ([]node.DiscoveryStatusResponse, error)
}
// ActorPeersBehaviorClient provides methods for peer-related behaviors.
type ActorPeersBehaviorClient interface {
// PeersSelf retrieves information about the actor's own peer.
PeersSelf(ctx context.Context, opts ...Option) (node.PeerAddrInfoResponse, error)
// PeersList lists the peers connected to the actor.
PeersList(ctx context.Context, opts ...Option) (node.PeersListResponse, error)
// PeersListFromDHT lists peers known through the DHT.
PeersListFromDHT(ctx context.Context, opts ...Option) (node.PeerDHTResponse, error)
// PeerPing pings a peer.
PeerPing(ctx context.Context, req node.PingRequest, opts ...Option) (node.PingResponse, error)
// PeerConnect connects to a peer.
PeerConnect(ctx context.Context, req node.PeerConnectRequest, opts ...Option) (node.PeerConnectResponse, error)
// PeerScore retrieves the score of peers.
PeerScore(ctx context.Context, opts ...Option) (node.PeerScoreResponse, error)
// Flightrec dumps a flight recorder snapshot.
Flightrec(ctx context.Context, opts ...Option) (node.PingResponse, error)
}
// ActorOnboardingBehaviorClient provides methods for onboarding a node
// onto the network and inspecting its onboarding state.
type ActorOnboardingBehaviorClient interface {
// Onboard performs onboarding.
Onboard(ctx context.Context, req node.OnboardRequest, opts ...Option) (node.OnboardResponse, error)
// Offboard performs offboarding.
Offboard(ctx context.Context, req node.OffboardRequest, opts ...Option) (node.OffboardResponse, error)
// OnboardStatus retrieves the current onboarding status.
OnboardStatus(ctx context.Context, opts ...Option) (node.OnboardStatusResponse, error)
}
// ActorDeploymentBehaviorClient provides methods for managing the full
// deployment lifecycle: creation, inspection, update, and removal.
type ActorDeploymentBehaviorClient interface {
// DeploymentList lists deployments.
DeploymentList(ctx context.Context, req node.DeploymentListRequest, opts ...Option) (node.DeploymentListResponse, error)
// DeploymentStatus retrieves the status of a deployment.
DeploymentStatus(ctx context.Context, req node.DeploymentStatusRequest, opts ...Option) (node.DeploymentStatusResponse, error)
// DeploymentLogs retrieves the logs of a deployment.
DeploymentLogs(ctx context.Context, req node.DeploymentLogsRequest, opts ...Option) (node.DeploymentLogsResponse, error)
// DeploymentManifest retrieves the manifest of a deployment.
DeploymentManifest(ctx context.Context, req node.DeploymentManifestRequest, opts ...Option) (node.DeploymentManifestResponse, error)
// DeploymentInfo retrieves comprehensive information about a deployment.
DeploymentInfo(ctx context.Context, req node.DeploymentInfoRequest, opts ...Option) (node.DeploymentInfoResponse, error)
// DeploymentShutdown shuts down a deployment.
DeploymentShutdown(ctx context.Context, req node.DeploymentShutdownRequest, opts ...Option) (node.DeploymentShutdownResponse, error)
// DeploymentNew creates a new deployment.
DeploymentNew(ctx context.Context, req node.NewDeploymentRequest, opts ...Option) (node.NewDeploymentResponse, error)
// DeploymentUpdate updates a running deployment.
DeploymentUpdate(ctx context.Context, req node.UpdateDeploymentRequest, opts ...Option) (node.UpdateDeploymentResponse, error)
// DeploymentPrune removes old deployments.
DeploymentPrune(ctx context.Context, req node.DeploymentPruneRequest, opts ...Option) (node.DeploymentPruneResponse, error)
// DeploymentDelete removes a specific deployment.
DeploymentDelete(ctx context.Context, req node.DeploymentDeleteRequest, opts ...Option) (node.DeploymentDeleteResponse, error)
}
// ActorAllocationsBehaviorClient provides methods for the allocations view.
type ActorAllocationsBehaviorClient interface {
// AllocationsList lists the actor's allocations.
AllocationsList(ctx context.Context, opts ...Option) (node.AllocationsListResponse, error)
}
// ActorResourcesBehaviorClient provides methods for resource management.
// All three views share the same response shape (node.ResourcesResponse).
type ActorResourcesBehaviorClient interface {
// ResourcesAllocated retrieves resources currently allocated to workloads.
ResourcesAllocated(ctx context.Context, opts ...Option) (node.ResourcesResponse, error)
// ResourcesFree retrieves resources still free for allocation.
ResourcesFree(ctx context.Context, opts ...Option) (node.ResourcesResponse, error)
// ResourcesOnboarded retrieves the total resources onboarded to the network.
ResourcesOnboarded(ctx context.Context, opts ...Option) (node.ResourcesResponse, error)
}
// ActorHardwareBehaviorClient provides methods for hardware information.
type ActorHardwareBehaviorClient interface {
// HardwareSpec retrieves the machine's hardware specifications.
HardwareSpec(ctx context.Context, opts ...Option) (node.ResourcesResponse, error)
// HardwareUsage retrieves current hardware usage.
HardwareUsage(ctx context.Context, opts ...Option) (node.ResourcesResponse, error)
}
// ActorCapBehaviorClient provides methods for capability (UCAN token)
// management: listing and anchoring tokens in the provide/require/revoke sets.
type ActorCapBehaviorClient interface {
// CapList retrieves the capability list.
CapList(ctx context.Context, req node.CapListRequest, opts ...Option) (node.CapListResponse, error)
// ProvideCapAnchor anchors a token in the provide set.
ProvideCapAnchor(ctx context.Context, req node.CapTokenAnchorRequest, opts ...Option) (node.CapAnchorResponse, error)
// RequireCapAnchor anchors a token in the require set.
RequireCapAnchor(ctx context.Context, req node.CapTokenAnchorRequest, opts ...Option) (node.CapAnchorResponse, error)
// RevokeCapAnchor anchors a token in the revoke set.
RevokeCapAnchor(ctx context.Context, req node.CapTokenAnchorRequest, opts ...Option) (node.CapAnchorResponse, error)
// BroadcastCapRevoke broadcasts capabilities that should be revoked and
// returns one response per acknowledging actor.
BroadcastCapRevoke(ctx context.Context, req node.CapTokenAnchorRequest, opts ...Option) ([]node.CapAnchorResponse, error)
}
// ActorLoggerBehaviorClient provides methods for logger configuration.
type ActorLoggerBehaviorClient interface {
// LoggerConfig configures the remote actor's logger.
LoggerConfig(ctx context.Context, req node.LoggerConfigRequest, opts ...Option) (node.LoggerConfigResponse, error)
}
// ActorVolumeBehaviorClient provides methods for volume management.
type ActorVolumeBehaviorClient interface {
// CreateVolume creates a new volume.
CreateVolume(ctx context.Context, req node.CreateVolumeRequest, opts ...Option) (node.CreateVolumeResponse, error)
// DeleteVolume deletes a volume.
DeleteVolume(ctx context.Context, req node.DeleteVolumeRequest, opts ...Option) (node.DeleteVolumeResponse, error)
// StartVolume starts a volume.
StartVolume(ctx context.Context, req node.StartVolumeRequest, opts ...Option) (node.StartVolumeResponse, error)
}
// ActorContractBehaviorClient provides methods for tokenomics contracts:
// creation, status, approval, settlement, and lifecycle transitions.
type ActorContractBehaviorClient interface {
// NewContract creates a new contract.
NewContract(ctx context.Context, req contracts.CreateContractRequest, opts ...Option) (contracts.CreateContractResponse, error)
// ContractStatus retrieves the status of a contract.
ContractStatus(ctx context.Context, req contracts.ContractStatusRequest, opts ...Option) (contracts.ContractStatusResponse, error)
// ApproveLocal approves a contract locally.
ApproveLocal(ctx context.Context, req contracts.ContractApproveLocalRequest, opts ...Option) (contracts.ContractApproveLocalResponse, error)
// ListIncoming lists incoming contracts.
ListIncoming(ctx context.Context, req contracts.ContractListIncomingRequest, opts ...Option) (contracts.ContractListIncomingResponse, error)
// ListTransactions lists local contract transactions.
ListTransactions(ctx context.Context, opts ...Option) (contracts.ContractListLocalTransactionsResponse, error)
// CollectUsagesAndForwardToPaymentProviders collects usage records and
// forwards them to the payment providers.
// NOTE(review): the response type name carries a typo ("Reponse"); it is
// declared in the contracts package and cannot be fixed here.
CollectUsagesAndForwardToPaymentProviders(ctx context.Context, req contracts.CollectUsagesAndForwardToPaymentProvidersRequest, opts ...Option) (contracts.CollectUsagesAndForwardToPaymentProvidersReponse, error)
// ConfirmTransaction confirms a local contract transaction.
ConfirmTransaction(ctx context.Context, req contracts.ContractConfirmLocalTransactionRequest, opts ...Option) (contracts.ContractConfirmLocalTransactionResponse, error)
// GetPaymentStatus retrieves the payment status of a contract.
GetPaymentStatus(ctx context.Context, req contracts.ContractPaymentStatusRequest, opts ...Option) (contracts.ContractPaymentStatusResponse, error)
// TerminateContract requests termination of a contract.
TerminateContract(ctx context.Context, req contracts.ContractTerminationRequest, opts ...Option) (contracts.ContractTerminationResponse, error)
// CompleteContract requests completion of a contract.
CompleteContract(ctx context.Context, req contracts.ContractCompletionRequest, opts ...Option) (contracts.ContractCompletionResponse, error)
// ValidateContract requests validation of a contract.
ValidateContract(ctx context.Context, req contracts.ContractValidateRequest, opts ...Option) (contracts.ContractValidateResponse, error)
// SettleContract requests settlement of a contract.
SettleContract(ctx context.Context, req contracts.ContractSettleRequest, opts ...Option) (contracts.ContractSettleResponse, error)
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
//go:build !windows
package client
import (
"context"
"net"
"net/http"
)
// createUnixSocketTransport sets up an http.RoundTripper that dials the
// Unix domain socket at opts.Host. The dial honors both the per-request
// context and the configured connect timeout, whichever fires first.
func createUnixSocketTransport(opts Config) (http.RoundTripper, error) {
	transport := &http.Transport{
		DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) {
			// Use net.Dialer.DialContext so request cancellation is
			// respected; net.DialTimeout ignored the context entirely.
			dialer := net.Dialer{Timeout: opts.ConnectTimeout}
			return dialer.DialContext(ctx, "unix", opts.Host)
		},
	}
	return transport, nil
}
// createNPipeTransport sets up a Windows named pipe transport.
// This file is compiled only on non-Windows platforms (see the !windows
// build tag above), where named pipes are unsupported, so it always
// returns ErrProtocolNotAvailable.
func createNPipeTransport(_ Config) (http.RoundTripper, error) {
return nil, ErrProtocolNotAvailable
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package actor
import (
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/cmd/cli"
)
// NewActorCmd is a constructor for the `actor` parent command, which
// groups all actor-system subcommands (msg, send, invoke, broadcast, cmd).
func NewActorCmd(
	dmsCli *cli.DmsCLI,
) *cobra.Command {
	actorCmd := &cobra.Command{
		Use:   "actor",
		Short: "Interact with the actor system",
		Long: `Interact with the actor system
Actors are the entities which compose the NuActor system, a secure decentralized programming framework based on the Actor Model.
Actors are connected through the libp2p network substrate and communication is achieved via immutable messages.
For more information on the actor system, please refer to actor/README.md`,
	}
	actorCmd.AddCommand(
		newActorMsgCmd(dmsCli),
		newActorSendCmd(dmsCli),
		newActorInvokeCmd(dmsCli),
		newActorBroadcastCmd(dmsCli),
		newActorCmdGroup(dmsCli),
	)
	return actorCmd
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package actor
import (
"context"
"encoding/json"
"fmt"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/cmd/cli"
)
// newActorBroadcastCmd is a constructor for the `actor broadcast` subcommand.
func newActorBroadcastCmd(dmsCLI *cli.DmsCLI) *cobra.Command {
	broadcastCmd := &cobra.Command{
		Use:   "broadcast <msg>",
		Short: "Broadcast a message",
		Long: `Broadcast a message to a topic
If a topic is specified in the message's payload, the message will be published to all subscribers of that topic.`,
		Args: cobra.ExactArgs(1),
		RunE: func(cmd *cobra.Command, args []string) error {
			return runActorBroadcastCmd(cmd.Context(), dmsCLI, args[0], cli.CmdStreams(cmd))
		},
	}
	return broadcastCmd
}
// runActorBroadcastCmd is the testable core logic for the broadcast command.
// It decodes the JSON-encoded envelope, publishes it through the DMS API,
// and prints every response received from subscribers.
func runActorBroadcastCmd(
	ctx context.Context,
	dmsCLI *cli.DmsCLI,
	msgArg string,
	streams cli.Streams,
) error {
	var envelope actor.Envelope
	if err := json.Unmarshal([]byte(msgArg), &envelope); err != nil {
		return fmt.Errorf("could not unmarshal message: %w", err)
	}
	dmsClient, err := dmsCLI.NewClient(nil)
	if err != nil {
		return fmt.Errorf("could not create client: %w", err)
	}
	responses, err := dmsClient.BroadcastMessageRaw(ctx, envelope)
	if err != nil {
		return fmt.Errorf("could not broadcast message: %w", err)
	}
	// Render each subscriber response; stop at the first display failure.
	for _, response := range responses {
		if err := displayResponse(streams.Out, response); err != nil {
			return err
		}
	}
	return nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package actor
import (
"fmt"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/client"
"gitlab.com/nunet/device-management-service/cmd/cli"
)
// newActorCmdGroup is a constructor for the `actor cmd` command group, with
// one subcommand per entry in the registeredBehaviors map.
func newActorCmdGroup(dmsCli *cli.DmsCLI) *cobra.Command {
	// Behavior names double as shell-completion candidates for this group.
	behaviorNames := make([]string, 0, len(registeredBehaviors))
	for name := range registeredBehaviors {
		behaviorNames = append(behaviorNames, name)
	}
	groupCmd := &cobra.Command{
		Use:   "cmd",
		Short: "Invoke a predefined behavior on an actor",
		Long: `Invoke a predefined behavior on an actor
Example:
nunet actor cmd --context user /broadcast/hello
Adding the --dest flag will cause the behavior to be invoked on the specified actor.
For more information on behaviors, refer to cmd/actor/README.md`,
		ValidArgs: behaviorNames,
	}
	// Register one subcommand per behavior configuration.
	for name, cfg := range registeredBehaviors {
		groupCmd.AddCommand(newActorCmdCmd(dmsCli, name, cfg))
	}
	useMessageOptsFlags(groupCmd, true)
	return groupCmd
}
// actorCmdOptions carries the per-invocation state of a behavior subcommand,
// populated in PreRunE and consumed by the behavior's Run function.
type actorCmdOptions struct {
// Context is the security-context name selected via the --context flag.
Context string
// Payload is the behavior-specific request struct (set from flags).
Payload any
// Args are the positional command-line arguments.
Args []string
// MsgOpts are the message options gathered from the shared message flags.
MsgOpts []client.Option
// Streams are the command's input/output streams.
Streams cli.Streams
}
// newActorCmdCmd builds the cobra subcommand for a single registered behavior.
func newActorCmdCmd(dmsCli *cli.DmsCLI, behavior string, behaviorCfg *behaviorConfig) *cobra.Command {
	cmdOpts := actorCmdOptions{}
	if behaviorCfg.Payload != nil {
		cmdOpts.Payload = behaviorCfg.Payload()
	}
	behaviorCmd := &cobra.Command{
		Use:               fmt.Sprintf("%s [<param> ...]", behavior),
		Short:             behaviorCfg.Short,
		Long:              behaviorCfg.Long,
		ValidArgsFunction: behaviorCfg.ValidArgsFn,
		Args:              behaviorCfg.Args,
		PreRunE: func(cmd *cobra.Command, args []string) error {
			// Capture CLI state into the shared options before RunE fires.
			cmdOpts.Args = args
			cmdOpts.Context, _ = cmd.Flags().GetString(fnContext)
			cmdOpts.MsgOpts = getBehaviorMsgOpts(cmd)
			cmdOpts.Streams = cli.CmdStreams(cmd)
			if behaviorCfg.PreRunFn == nil {
				return nil
			}
			return behaviorCfg.PreRunFn(cmd, dmsCli, cmdOpts)
		},
		RunE: func(cmd *cobra.Command, _ []string) error {
			return behaviorCfg.Run(cmd.Context(), dmsCli, cmdOpts, cli.CmdStreams(cmd))
		},
	}
	// Let the behavior bind its payload fields to command flags.
	if behaviorCfg.SetFlags != nil {
		behaviorCfg.SetFlags(behaviorCmd, cmdOpts.Payload)
	}
	return behaviorCmd
}
// NewActorCmdWrapper is a factory for creating actor command aliases.
// It fails when the requested behavior is not in the registered set.
func NewActorCmdWrapper(dmsCli *cli.DmsCLI, behavior string) (*cobra.Command, error) {
	cfg, found := registeredBehaviors[behavior]
	if !found {
		return nil, fmt.Errorf("unknown behavior: %s", behavior)
	}
	wrapped := newActorCmdCmd(dmsCli, behavior, cfg)
	useMessageOptsFlags(wrapped, true)
	return wrapped, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package actor
import (
"context"
"encoding/json"
"fmt"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/cmd/cli"
)
// newActorInvokeCmd is a constructor for the `actor invoke` subcommand.
func newActorInvokeCmd(dmsCli *cli.DmsCLI) *cobra.Command {
	invokeCmd := &cobra.Command{
		Use:   "invoke <msg>",
		Short: "Invoke a behaviour in an actor and return the result",
		Args:  cobra.ExactArgs(1),
		RunE: func(cmd *cobra.Command, args []string) error {
			return runActorInvokeCmd(cmd.Context(), dmsCli, args[0], cli.CmdStreams(cmd))
		},
	}
	return invokeCmd
}
// runActorInvokeCmd is the testable core logic for the invoke command.
// It decodes the JSON-encoded envelope, validates that a reply address is
// present (invoke waits for a response), invokes the behavior through the
// DMS API, and prints the raw reply payload.
func runActorInvokeCmd(
	ctx context.Context,
	dmsCli *cli.DmsCLI,
	msgArg string,
	streams cli.Streams,
) error {
	var msg actor.Envelope
	if err := json.Unmarshal([]byte(msgArg), &msg); err != nil {
		return fmt.Errorf("could not unmarshal message: %w", err)
	}
	if msg.Options.ReplyTo == "" {
		return fmt.Errorf("missing replyTo field in message")
	}
	// Named dmsClient rather than "cli" to avoid shadowing the imported
	// cli package used in this function's signature.
	dmsClient, err := dmsCli.NewClient(nil)
	if err != nil {
		return fmt.Errorf("could not create client: %w", err)
	}
	res, err := dmsClient.InvokeBehaviorRaw(ctx, msg)
	if err != nil {
		return fmt.Errorf("could not invoke behaviour: %w", err)
	}
	return displayResponse(streams.Out, json.RawMessage(res.Message))
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package actor
import (
"context"
"fmt"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/client"
"gitlab.com/nunet/device-management-service/cmd/cli"
"gitlab.com/nunet/device-management-service/cmd/utils"
)
// actorMsgOptions carries the inputs for the `actor msg` command.
type actorMsgOptions struct {
// Context is the security-context name selected via the --context flag.
Context string
// Behavior is the behavior path the message targets (first positional arg).
Behavior string
// Payload is the raw message payload (second positional arg).
Payload string
// MsgOpts are the accumulated message options from the message flags.
MsgOpts client.MessageOptions
}
// newActorMsgCmd is a constructor for the `actor msg` subcommand, which
// constructs and signs a message without sending it.
func newActorMsgCmd(dmsCli *cli.DmsCLI) *cobra.Command {
	var opts actorMsgOptions
	cmd := &cobra.Command{
		Use:   "msg <behavior> <payload>",
		Short: "Construct a message",
		// Help text below fixes two typos from the previous revision
		// ("can be used stored" and "the the send").
		Long: `Construct and sign a message that can be communicated to an actor.
The constructed message is returned as a JSON object that can be stored or piped into another command, for instance the send, invoke, or broadcast command.
Example:
nunet actor msg --broadcast /nunet/hello /broadcast/hello 'Hello, World!'`,
		Args: cobra.ExactArgs(2),
		RunE: func(cmd *cobra.Command, args []string) error {
			opts.Context, _ = cmd.Flags().GetString(fnContext)
			opts.Behavior = args[0]
			opts.Payload = args[1]
			// Fold each message-option flag into the MessageOptions struct.
			for _, opt := range getNewMsgOpts(cmd) {
				opt(&opts.MsgOpts)
			}
			return runActorMsgCmd(cmd.Context(), dmsCli, opts, cli.CmdStreams(cmd))
		},
	}
	useNewMsgOptsFlags(cmd, false)
	return cmd
}
// runActorMsgCmd constructs and signs an actor message under the security
// context selected by --context, then prints the resulting envelope.
func runActorMsgCmd(ctx context.Context, dmsCli *cli.DmsCLI, opts actorMsgOptions, streams cli.Streams) error {
	sctx, err := utils.NewSecurityContext(dmsCli, opts.Context)
	if err != nil {
		return fmt.Errorf("could not create security context: %w", err)
	}
	// Named dmsClient rather than "cli" to avoid shadowing the imported
	// cli package used in this function's signature.
	dmsClient, err := dmsCli.NewClient(sctx)
	if err != nil {
		return fmt.Errorf("could not create client: %w", err)
	}
	msg, err := dmsClient.NewActorMessage(ctx, opts.Behavior, opts.Payload, opts.MsgOpts)
	if err != nil {
		return fmt.Errorf("could not create message: %w", err)
	}
	return displayResponse(streams.Out, msg)
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package actor
import (
"context"
"encoding/json"
"fmt"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/cmd/cli"
)
// newActorSendCmd is a constructor for the `actor send` subcommand.
func newActorSendCmd(dmsCli *cli.DmsCLI) *cobra.Command {
	sendCmd := &cobra.Command{
		Use:   "send <msg>",
		Short: "Send a message",
		Long: `Send a message to an actor
Actors only communicate via messages. For more information on constructing a message, see:
nunet actor msg --help
The message is encoded into an actor envelope, which then is sent across the network through the API.`,
		Args: cobra.ExactArgs(1),
		RunE: func(cmd *cobra.Command, args []string) error {
			return runActorSendCmd(cmd.Context(), dmsCli, args[0], cli.CmdStreams(cmd))
		},
	}
	return sendCmd
}
// runActorSendCmd is the testable core logic for the send command.
// It decodes the JSON-encoded envelope, sends it through the DMS API,
// and prints the raw response payload.
func runActorSendCmd(
	ctx context.Context,
	dmsCli *cli.DmsCLI,
	msgArg string,
	streams cli.Streams,
) error {
	var envelope actor.Envelope
	if err := json.Unmarshal([]byte(msgArg), &envelope); err != nil {
		return fmt.Errorf("could not unmarshal message: %w", err)
	}
	dmsClient, err := dmsCli.NewClient(nil)
	if err != nil {
		return fmt.Errorf("could not create client: %w", err)
	}
	res, err := dmsClient.SendMessageRaw(ctx, envelope)
	if err != nil {
		return fmt.Errorf("could not send message: %w", err)
	}
	return displayResponse(streams.Out, json.RawMessage(res.Message))
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package actor
import (
"context"
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/client"
"gitlab.com/nunet/device-management-service/cmd/cli"
"gitlab.com/nunet/device-management-service/cmd/utils"
"gitlab.com/nunet/device-management-service/dms/behaviors"
jobtypes "gitlab.com/nunet/device-management-service/dms/jobs/types"
"gitlab.com/nunet/device-management-service/dms/node"
"gitlab.com/nunet/device-management-service/lib/did"
"gitlab.com/nunet/device-management-service/lib/ucan"
"gitlab.com/nunet/device-management-service/tokenomics/contracts"
"gitlab.com/nunet/device-management-service/utils/convert"
)
// BehaviorAction names the kind of actor operation a behavior command
// performs: broadcast, invoke, or send.
type BehaviorAction string
const (
// bBroadcast publishes the message to a topic.
bBroadcast BehaviorAction = "broadcast"
// bInvoke invokes a behavior and waits for a result.
bInvoke BehaviorAction = "invoke"
// bSend sends the message without waiting for a result.
bSend BehaviorAction = "send"
)
// ErrInvalidArgument is returned when a command receives a malformed argument.
var ErrInvalidArgument = errors.New("invalid argument")
// Command aliases cobra.Command to shorten the behavior-config signatures below.
type Command = cobra.Command
// behaviorConfig describes one entry in the registeredBehaviors map: how to
// build the cobra subcommand for a behavior and how to execute it.
type behaviorConfig struct {
// Payload constructs the behavior-specific request struct (nil if none).
Payload func() any
// Behavior is the behavior path on the actor.
Behavior string
// Action is the operation kind (broadcast, invoke, or send).
Action BehaviorAction
// SetFlags binds the payload's fields to command-line flags.
SetFlags func(cmd *Command, payload any)
// RunFn executes the behavior through the given DMS client.
RunFn func(ctx context.Context, dmsCli *cli.DmsCLI, dmsClient client.DmsClient, opts actorCmdOptions) (any, error)
// PreRunFn optionally validates/prepares options before RunFn.
PreRunFn func(cmd *Command, dmsCli *cli.DmsCLI, opts actorCmdOptions) error
// ValidArgsFn supplies shell-completion candidates.
ValidArgsFn func(cmd *Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective)
// Args validates the positional arguments.
Args cobra.PositionalArgs
// Long and Short are the command's help texts.
Long string
Short string
}
// Run resolves the security context, builds a DMS client (honoring any
// timeout supplied via the message options), executes the behavior's RunFn,
// and renders the result to the output stream.
func (b *behaviorConfig) Run(ctx context.Context, dmsCli *cli.DmsCLI, opts actorCmdOptions, streams cli.Streams) error {
	sctx, err := utils.NewSecurityContext(dmsCli, opts.Context)
	if err != nil {
		return fmt.Errorf("could not create security context: %w", err)
	}
	// Probe the message options for an explicit timeout (-t flag): apply
	// each option to a scratch MessageOptions and see what it sets.
	var timeout time.Duration
	for _, opt := range opts.MsgOpts {
		probe := client.MessageOptions{}
		opt(&probe)
		if probe.Timeout > 0 {
			timeout = probe.Timeout
			break
		}
	}
	var dmsClient client.DmsClient
	if timeout > 0 {
		// An explicit timeout was requested; propagate it to the client.
		dmsClient, err = dmsCli.NewClientWithTimeout(sctx, timeout)
		if err != nil {
			return fmt.Errorf("could not create client with timeout: %w", err)
		}
	} else {
		// No explicit timeout; fall back to the client default.
		dmsClient, err = dmsCli.NewClient(sctx)
		if err != nil {
			return fmt.Errorf("could not create client: %w", err)
		}
	}
	res, err := b.RunFn(ctx, dmsCli, dmsClient, opts)
	if err != nil {
		return fmt.Errorf("could not run behavior: %w", err)
	}
	return displayResponse(streams.Out, res)
}
var registeredBehaviors = map[string]*behaviorConfig{
// /dms/tokenomics/contract/settle
behaviors.ContractSettleBehavior: {
Payload: func() any { return &ContractSettleCmd{} },
SetFlags: func(cmd *cobra.Command, payload any) {
p := payload.(*ContractSettleCmd)
cmd.Flags().StringVarP(&p.ContractDID, "contract-did", "", "", "contract did (required)")
cmd.Flags().StringVarP(&p.ContractHost, "contract-host-did", "", "", "contract host did (required)")
_ = cmd.MarkFlagRequired("contract-did")
},
RunFn: func(ctx context.Context, _ *cli.DmsCLI, dmsClient client.DmsClient, opts actorCmdOptions) (any, error) {
req, ok := opts.Payload.(*ContractSettleCmd)
if !ok {
return nil, fmt.Errorf("failed to decode contract settle payload")
}
request := contracts.ContractSettleRequest{
ContractDID: req.ContractDID,
}
if req.ContractHost != "" {
destination, err := getDestinationHandle(req.ContractDID, req.ContractHost)
if err != nil {
return nil, err
}
opts.MsgOpts = append(opts.MsgOpts, client.WithDestination(destination))
}
resp, err := dmsClient.SettleContract(ctx, request, opts.MsgOpts...)
if err != nil {
return resp, err
}
return resp, nil
},
Action: bInvoke,
Short: "Send a settle request",
Long: `Invoke the /dms/tokenomics/contract/settle behavior on an actor
This behavior calls the contract settle behaviour.
Examples:
nunet actor cmd --context user /dms/tokenomics/contract/settle --contract-did <did> --contract-host-did <hostdid>`,
},
// /dms/tokenomics/contract/terminate
behaviors.ContractTerminationBehavior: {
Payload: func() any { return &ContractTerminateCmd{} },
SetFlags: func(cmd *cobra.Command, payload any) {
p := payload.(*ContractTerminateCmd)
cmd.Flags().StringVarP(&p.ContractDID, "contract-did", "", "", "contract did (required)")
cmd.Flags().StringVarP(&p.ContractHost, "contract-host-did", "", "", "contract host did (required)")
_ = cmd.MarkFlagRequired("contract-did")
},
RunFn: func(ctx context.Context, _ *cli.DmsCLI, dmsClient client.DmsClient, opts actorCmdOptions) (any, error) {
req, ok := opts.Payload.(*ContractTerminateCmd)
if !ok {
return nil, fmt.Errorf("failed to decode contract terminate payload")
}
request := contracts.ContractTerminationRequest{
ContractDID: req.ContractDID,
}
if req.ContractHost != "" {
destination, err := getDestinationHandle(req.ContractDID, req.ContractHost)
if err != nil {
return nil, err
}
opts.MsgOpts = append(opts.MsgOpts, client.WithDestination(destination))
}
resp, err := dmsClient.TerminateContract(ctx, request, opts.MsgOpts...)
if err != nil {
return resp, err
}
return resp, nil
},
Action: bInvoke,
Short: "Send a termination request",
Long: `Invoke the /dms/tokenomics/contract/terminate behavior on an actor
This behavior calls the contract terminate behaviour.
Examples:
nunet actor cmd --context user /dms/tokenomics/contract/terminate --contract-did <did> --contract-host-did <hostdid>`,
},
// /dms/tokenomics/contract/complete
behaviors.ContractCompleteBehavior: {
Payload: func() any { return &ContractCompleteCmd{} },
SetFlags: func(cmd *cobra.Command, payload any) {
p := payload.(*ContractCompleteCmd)
cmd.Flags().StringVarP(&p.ContractDID, "contract-did", "", "", "contract did (required)")
cmd.Flags().StringVarP(&p.ContractHost, "contract-host-did", "", "", "contract host did (required)")
_ = cmd.MarkFlagRequired("contract-did")
},
RunFn: func(ctx context.Context, _ *cli.DmsCLI, dmsClient client.DmsClient, opts actorCmdOptions) (any, error) {
req, ok := opts.Payload.(*ContractCompleteCmd)
if !ok {
return nil, fmt.Errorf("failed to decode contract complete payload")
}
request := contracts.ContractCompletionRequest{
ContractDID: req.ContractDID,
}
if req.ContractHost != "" {
destination, err := getDestinationHandle(req.ContractDID, req.ContractHost)
if err != nil {
return nil, err
}
opts.MsgOpts = append(opts.MsgOpts, client.WithDestination(destination))
}
resp, err := dmsClient.CompleteContract(ctx, request, opts.MsgOpts...)
if err != nil {
return resp, err
}
return resp, nil
},
Action: bInvoke,
Short: "Send a contract complete request",
Long: `Invoke the /dms/tokenomics/contract/complete behavior on an actor
This behavior calls the contract complete behaviour.
Examples:
nunet actor cmd --context user /dms/tokenomics/contract/complete --contract-did <did> --contract-host-did <hostdid>`,
},
// /dms/tokenomics/contract/validate
behaviors.ContractValidationBehavior: {
Payload: func() any { return &ContractValidateCmd{} },
SetFlags: func(cmd *cobra.Command, payload any) {
p := payload.(*ContractValidateCmd)
cmd.Flags().StringVarP(&p.ContractDID, "contract-did", "", "", "contract did (required)")
cmd.Flags().StringVarP(&p.ContractHost, "contract-host-did", "", "", "contract host did (required)")
_ = cmd.MarkFlagRequired("contract-did")
},
RunFn: func(ctx context.Context, _ *cli.DmsCLI, dmsClient client.DmsClient, opts actorCmdOptions) (any, error) {
req, ok := opts.Payload.(*ContractValidateCmd)
if !ok {
return nil, fmt.Errorf("failed to decode contract complete payload")
}
request := contracts.ContractValidateRequest{
ContractDID: req.ContractDID,
}
if req.ContractHost != "" {
destination, err := getDestinationHandle(req.ContractDID, req.ContractHost)
if err != nil {
return nil, err
}
opts.MsgOpts = append(opts.MsgOpts, client.WithDestination(destination))
}
resp, err := dmsClient.ValidateContract(ctx, request, opts.MsgOpts...)
if err != nil {
return resp, err
}
return resp, nil
},
Action: bInvoke,
Short: "Send a contract validate request",
Long: `Invoke the /dms/tokenomics/contract/validate behavior on an actor
This behavior calls the contract validate behaviour.
Examples:
nunet actor cmd --context user /dms/tokenomics/contract/validate --contract-did <did> --contract-host-did <hostdid>`,
},
// /dms/tokenomics/contract/state
behaviors.ContractStatusBehavior: {
    Payload: func() any { return &ContractStatusRequestCmd{} },
    SetFlags: func(cmd *cobra.Command, payload any) {
        p := payload.(*ContractStatusRequestCmd)
        cmd.Flags().StringVarP(&p.ContractDID, "contract-did", "", "", "contract-did (required)")
        cmd.Flags().StringVarP(&p.ContractHost, "contract-host-did", "", "", "contract host did (required)")
        _ = cmd.MarkFlagRequired("contract-did")
    },
    RunFn: func(ctx context.Context, _ *cli.DmsCLI, dmsClient client.DmsClient, opts actorCmdOptions) (any, error) {
        req, ok := opts.Payload.(*ContractStatusRequestCmd)
        if !ok {
            // BUG FIX: message previously said "encode"; the failure here is
            // decoding (type-asserting) the supplied payload.
            return nil, fmt.Errorf("failed to decode contract status payload")
        }
        contractReq := contracts.ContractStatusRequest{
            ContractDID: req.ContractDID,
        }
        // When a contract host is given, route the message to that host
        // instead of the default destination.
        if req.ContractHost != "" {
            destination, err := getDestinationHandle(req.ContractDID, req.ContractHost)
            if err != nil {
                return nil, err
            }
            opts.MsgOpts = append(opts.MsgOpts, client.WithDestination(destination))
        }
        return dmsClient.ContractStatus(ctx, contractReq, opts.MsgOpts...)
    },
    Action: bInvoke,
    Short: "Send a contract state request",
    Long: `Invoke the /dms/tokenomics/contract/state behavior on an actor
This behavior calls the actors contract state behaviour.
Examples:
nunet actor cmd --context user /dms/tokenomics/contract/state --contract-did <did> --contract-host-did <hostdid>`,
},
// /dms/tokenomics/contract/payment/status
behaviors.ContractPaymentStatusBehavior: {
    Payload: func() any { return &ContractPaymentStatusCmd{} },
    SetFlags: func(cmd *cobra.Command, payload any) {
        statusCmd := payload.(*ContractPaymentStatusCmd)
        cmd.Flags().StringVarP(&statusCmd.UniqueID, "unique-id", "", "", "unique id (required)")
        _ = cmd.MarkFlagRequired("unique-id")
    },
    RunFn: func(ctx context.Context, _ *cli.DmsCLI, dms client.DmsClient, opts actorCmdOptions) (any, error) {
        // Decode the CLI payload and forward it as a payment-status request.
        statusCmd, ok := opts.Payload.(*ContractPaymentStatusCmd)
        if !ok {
            return nil, fmt.Errorf("failed to decode payment status payload")
        }
        statusReq := contracts.ContractPaymentStatusRequest{UniqueID: statusCmd.UniqueID}
        return dms.GetPaymentStatus(ctx, statusReq, opts.MsgOpts...)
    },
    Action: bInvoke,
    Short: "Send a payment status request",
    Long: `Invoke the /dms/tokenomics/contract/payment/status behavior on an actor
This behavior calls the payment status behaviour.
Examples:
nunet actor cmd --context user /dms/tokenomics/contract/payment/status --unique-id <uniqueid>`,
},
// /dms/tokenomics/contract/usages/calculate
behaviors.ContractUsagesCalculateBehavior: {
    Payload: func() any { return &CollectUsagesAndForwardToPaymentProvidersCmd{} },
    SetFlags: func(cmd *cobra.Command, payload any) {
        p := payload.(*CollectUsagesAndForwardToPaymentProvidersCmd)
        cmd.Flags().StringVar(&p.ContractDID, "contract-did", "", "Contract DID to process (optional, processes all contracts if not specified)")
    },
    RunFn: func(ctx context.Context, _ *cli.DmsCLI, dmsClient client.DmsClient, opts actorCmdOptions) (any, error) {
        // BUG FIX: the type assertion was previously unchecked and would
        // panic on an unexpected payload type; fail gracefully like the
        // sibling behaviors do.
        p, ok := opts.Payload.(*CollectUsagesAndForwardToPaymentProvidersCmd)
        if !ok {
            return nil, fmt.Errorf("failed to decode usages calculate payload")
        }
        req := contracts.CollectUsagesAndForwardToPaymentProvidersRequest{
            ContractDID: p.ContractDID,
        }
        return dmsClient.CollectUsagesAndForwardToPaymentProviders(ctx, req, opts.MsgOpts...)
    },
    Action: bInvoke,
    Short: "Send a usage calculation request",
    Long: `Invoke the /dms/tokenomics/contract/usages/calculate behavior on an actor
This behavior calls the actors contract calculate usages behaviour.
Examples:
nunet actor cmd --context user /dms/tokenomics/contract/usages/calculate
nunet actor cmd --context user /dms/tokenomics/contract/usages/calculate --contract-did did:key:...`,
},
// /dms/tokenomics/contract/transactions/confirm
behaviors.ContractConfirmLocalTransactionBehavior: {
    Payload: func() any { return &ContractConfirmLocalTransactionCmd{} },
    SetFlags: func(cmd *cobra.Command, payload any) {
        confirmCmd := payload.(*ContractConfirmLocalTransactionCmd)
        cmd.Flags().StringVarP(&confirmCmd.UniqueID, "unique-id", "", "", "transaction unique id (required)")
        cmd.Flags().StringVarP(&confirmCmd.TxHash, "tx-hash", "", "", "transaction hash (required)")
        cmd.Flags().StringVarP(&confirmCmd.Blockchain, "blockchain", "", "", "which blockchain was used (required)")
        cmd.Flags().StringVarP(&confirmCmd.QuoteID, "quote-id", "", "", "payment quote id (optional)")
        _ = cmd.MarkFlagRequired("unique-id")
        _ = cmd.MarkFlagRequired("tx-hash")
        _ = cmd.MarkFlagRequired("blockchain")
    },
    RunFn: func(ctx context.Context, _ *cli.DmsCLI, dms client.DmsClient, opts actorCmdOptions) (any, error) {
        // Decode the CLI payload and forward it as a confirm-transaction request.
        confirmCmd, ok := opts.Payload.(*ContractConfirmLocalTransactionCmd)
        if !ok {
            return nil, fmt.Errorf("failed to decode ContractConfirmLocalTransactionCmd payload")
        }
        confirmReq := contracts.ContractConfirmLocalTransactionRequest{
            UniqueID:   confirmCmd.UniqueID,
            TxHash:     confirmCmd.TxHash,
            Blockchain: confirmCmd.Blockchain,
            QuoteID:    confirmCmd.QuoteID,
        }
        return dms.ConfirmTransaction(ctx, confirmReq, opts.MsgOpts...)
    },
    Action: bInvoke,
    Short: "Send a confirm transactions request",
    Long: `Invoke the /dms/tokenomics/contract/transactions/confirm behavior on an actor
This behavior calls the actors contract confirm transactions behaviour.
Examples:
nunet actor cmd --context user /dms/tokenomics/contract/transactions/confirm --unique-id <uniqueid> --tx-hash <txhash> --blockchain ETHEREUM`,
},
// /dms/tokenomics/contract/transactions/list
behaviors.ContractListLocalTransactionsBehavior: {
    Action: bInvoke,
    RunFn: func(ctx context.Context, _ *cli.DmsCLI, dms client.DmsClient, opts actorCmdOptions) (any, error) {
        // No payload: simply forward the list-transactions call.
        return dms.ListTransactions(ctx, opts.MsgOpts...)
    },
    Short: "Send a list transactions request",
    Long: `Invoke the /dms/tokenomics/contract/transactions/list behavior on an actor
This behavior calls the actors contract list transactions behaviour.
Examples:
nunet actor cmd --context user /dms/tokenomics/contract/transactions/list`,
},
// /dms/tokenomics/contract/payment/quote/get
behaviors.ContractGetPaymentQuoteBehavior: {
    Payload: func() any { return &ContractGetPaymentQuoteCmd{} },
    SetFlags: func(cmd *cobra.Command, payload any) {
        p := payload.(*ContractGetPaymentQuoteCmd)
        cmd.Flags().StringVarP(&p.UniqueID, "unique-id", "", "", "transaction unique id (required)")
        _ = cmd.MarkFlagRequired("unique-id")
    },
    RunFn: func(ctx context.Context, _ *cli.DmsCLI, dmsClient client.DmsClient, opts actorCmdOptions) (any, error) {
        req, ok := opts.Payload.(*ContractGetPaymentQuoteCmd)
        if !ok {
            return nil, fmt.Errorf("failed to decode ContractGetPaymentQuoteCmd payload")
        }
        request := contracts.ContractGetPaymentQuoteRequest{
            UniqueID: req.UniqueID,
        }
        resp, err := dmsClient.InvokeBehavior(ctx, behaviors.ContractGetPaymentQuoteBehavior, request, opts.MsgOpts...)
        if err != nil {
            return nil, fmt.Errorf("failed to invoke behavior: %w", err)
        }
        // BUG FIX: the unmarshal target was corrupted (`"eResp`, not valid Go);
        // decode into the typed quote response as the sibling quote behaviors do.
        var quoteResp contracts.ContractGetPaymentQuoteResponse
        if err := json.Unmarshal(resp.Message, &quoteResp); err != nil {
            return nil, fmt.Errorf("failed to unmarshal response: %w", err)
        }
        return quoteResp, nil
    },
    Action: bInvoke,
    Short: "Get a payment quote for a transaction",
    Long: `Invoke the /dms/tokenomics/contract/payment/quote/get behavior on an actor
This behavior gets a real-time payment quote for a transaction requiring currency conversion.
Examples:
nunet actor cmd --context user /dms/tokenomics/contract/payment/quote/get --unique-id <unique_id>`,
},
// /dms/tokenomics/contract/payment/quote/validate
behaviors.ContractValidatePaymentQuoteBehavior: {
    Payload: func() any { return &ContractValidatePaymentQuoteCmd{} },
    SetFlags: func(cmd *cobra.Command, payload any) {
        quoteCmd := payload.(*ContractValidatePaymentQuoteCmd)
        cmd.Flags().StringVarP(&quoteCmd.QuoteID, "quote-id", "", "", "quote id (required)")
        _ = cmd.MarkFlagRequired("quote-id")
    },
    RunFn: func(ctx context.Context, _ *cli.DmsCLI, dms client.DmsClient, opts actorCmdOptions) (any, error) {
        // Decode the CLI payload, invoke the behavior, and decode the
        // typed validate-quote response from the raw reply message.
        quoteCmd, ok := opts.Payload.(*ContractValidatePaymentQuoteCmd)
        if !ok {
            return nil, fmt.Errorf("failed to decode ContractValidatePaymentQuoteCmd payload")
        }
        validateReq := contracts.ContractValidatePaymentQuoteRequest{QuoteID: quoteCmd.QuoteID}
        reply, err := dms.InvokeBehavior(ctx, behaviors.ContractValidatePaymentQuoteBehavior, validateReq, opts.MsgOpts...)
        if err != nil {
            return nil, fmt.Errorf("failed to invoke behavior: %w", err)
        }
        var validateResp contracts.ContractValidatePaymentQuoteResponse
        if err := json.Unmarshal(reply.Message, &validateResp); err != nil {
            return nil, fmt.Errorf("failed to unmarshal response: %w", err)
        }
        return validateResp, nil
    },
    Action: bInvoke,
    Short: "Validate a payment quote",
    Long: `Invoke the /dms/tokenomics/contract/payment/quote/validate behavior on an actor
This behavior validates a payment quote before payment execution.
Examples:
nunet actor cmd --context user /dms/tokenomics/contract/payment/quote/validate --quote-id <quote_id>`,
},
// /dms/tokenomics/contract/payment/quote/cancel
behaviors.ContractCancelPaymentQuoteBehavior: {
    Payload: func() any { return &ContractCancelPaymentQuoteCmd{} },
    SetFlags: func(cmd *cobra.Command, payload any) {
        cancelCmd := payload.(*ContractCancelPaymentQuoteCmd)
        cmd.Flags().StringVarP(&cancelCmd.QuoteID, "quote-id", "", "", "quote id (required)")
        _ = cmd.MarkFlagRequired("quote-id")
    },
    RunFn: func(ctx context.Context, _ *cli.DmsCLI, dms client.DmsClient, opts actorCmdOptions) (any, error) {
        // Decode the CLI payload, invoke the behavior, and decode the
        // typed cancel-quote response from the raw reply message.
        cancelCmd, ok := opts.Payload.(*ContractCancelPaymentQuoteCmd)
        if !ok {
            return nil, fmt.Errorf("failed to decode ContractCancelPaymentQuoteCmd payload")
        }
        cancelReq := contracts.ContractCancelPaymentQuoteRequest{QuoteID: cancelCmd.QuoteID}
        reply, err := dms.InvokeBehavior(ctx, behaviors.ContractCancelPaymentQuoteBehavior, cancelReq, opts.MsgOpts...)
        if err != nil {
            return nil, fmt.Errorf("failed to invoke behavior: %w", err)
        }
        var cancelResp contracts.ContractCancelPaymentQuoteResponse
        if err := json.Unmarshal(reply.Message, &cancelResp); err != nil {
            return nil, fmt.Errorf("failed to unmarshal response: %w", err)
        }
        return cancelResp, nil
    },
    Action: bInvoke,
    Short: "Cancel a payment quote",
    Long: `Invoke the /dms/tokenomics/contract/payment/quote/cancel behavior on an actor
This behavior cancels/invalidates a payment quote.
Examples:
nunet actor cmd --context user /dms/tokenomics/contract/payment/quote/cancel --quote-id <quote_id>`,
},
// /dms/tokenomics/contract/list_incoming
behaviors.ContractListBehavior: {
    Payload: func() any { return &contracts.ContractListIncomingRequest{} },
    SetFlags: func(cmd *cobra.Command, payload any) {
        listReq := payload.(*contracts.ContractListIncomingRequest)
        cmd.Flags().StringVarP((*string)(&listReq.Role), "role", "", "", "role filter (provider|requestor)")
    },
    RunFn: func(ctx context.Context, _ *cli.DmsCLI, dms client.DmsClient, opts actorCmdOptions) (any, error) {
        // Fall back to an empty request when no payload was provided.
        listReq, _ := opts.Payload.(*contracts.ContractListIncomingRequest)
        if listReq == nil {
            listReq = &contracts.ContractListIncomingRequest{}
        }
        return dms.ListIncoming(ctx, *listReq, opts.MsgOpts...)
    },
    Action: bInvoke,
    Short: "Send a list incoming contract request",
    Long: `Invoke the /dms/tokenomics/contract/list_incoming behavior on an actor
This behavior calls the actors contract list behaviour.
Examples:
nunet actor cmd --context user /dms/tokenomics/contract/list_incoming`,
},
// /dms/tokenomics/contract/aprove_local
behaviors.ContractApproveLocalBehavior: {
    Payload: func() any { return &ContractApproveLocalRequestCmd{} },
    SetFlags: func(cmd *cobra.Command, payload any) {
        p := payload.(*ContractApproveLocalRequestCmd)
        cmd.Flags().StringVarP(&p.ContractDID, "contract-did", "", "", "contract-did (required)")
        _ = cmd.MarkFlagRequired("contract-did")
    },
    RunFn: func(ctx context.Context, _ *cli.DmsCLI, dmsClient client.DmsClient, opts actorCmdOptions) (any, error) {
        req, ok := opts.Payload.(*ContractApproveLocalRequestCmd)
        if !ok {
            // BUG FIX: message previously said "encode"; the failure here is
            // decoding (type-asserting) the supplied payload.
            return nil, fmt.Errorf("failed to decode contract approve payload")
        }
        contractReq := contracts.ContractApproveLocalRequest{
            ContractDID: req.ContractDID,
        }
        return dmsClient.ApproveLocal(ctx, contractReq, opts.MsgOpts...)
    },
    Action: bInvoke,
    Short: "Send a contract approval request",
    Long: `Invoke the /dms/tokenomics/contract/aprove_local behavior on an actor
This behavior calls the actors contract approval behaviour.
Examples:
nunet actor cmd --context user /dms/tokenomics/contract/aprove_local --contract-did <did>`,
},
// /dms/tokenomics/contract/create
behaviors.ContractCreateBehavior: {
    Payload: func() any { return &CreateContractRequestCmd{} },
    SetFlags: func(cmd *cobra.Command, payload any) {
        p := payload.(*CreateContractRequestCmd)
        cmd.Flags().StringVarP(&p.ContractFile, "contract-file", "", "", "contract-file (required)")
        _ = cmd.MarkFlagRequired("contract-file")
    },
    RunFn: func(ctx context.Context, _ *cli.DmsCLI, dmsClient client.DmsClient, opts actorCmdOptions) (any, error) {
        req, ok := opts.Payload.(*CreateContractRequestCmd)
        if !ok {
            // BUG FIX: message previously said "encode"; the failure here is
            // decoding (type-asserting) the supplied payload.
            return nil, fmt.Errorf("failed to decode create contract payload")
        }
        // Read the contract definition from the file given on the command line
        // and decode it into a typed create-contract request.
        data, err := os.ReadFile(req.ContractFile)
        if err != nil {
            return nil, fmt.Errorf("failed to read contract file: %w", err)
        }
        var contractReq contracts.CreateContractRequest
        if err := json.Unmarshal(data, &contractReq); err != nil {
            return nil, fmt.Errorf("failed to unmarshal create contract request payload: %w", err)
        }
        return dmsClient.NewContract(ctx, contractReq, opts.MsgOpts...)
    },
    Action: bInvoke,
    Short: "Send a create contract message",
    Long: `Invoke the /dms/tokenomics/contract/create behavior on an actor
This behavior calls the actors create contract behaviour.
Examples:
nunet actor cmd --context user /dms/tokenomics/contract/create --contract-file <file> --dest <did_of_solution_enabler>`,
},
// /dms/volume/create
behaviors.VolumeCreateBehavior: {
    Payload: func() any { return &CreateVolumeRequestCmd{} },
    SetFlags: func(cmd *cobra.Command, payload any) {
        p := payload.(*CreateVolumeRequestCmd)
        cmd.Flags().StringVarP(&p.VolumeName, "name", "n", "", "name (required)")
        cmd.Flags().StringVarP(&p.ClientPEMFile, "client-pem-file", "p", "", "client-pem-file (required)")
        cmd.Flags().StringVarP(&p.CAOutputDir, "ca-output-dir", "", "", "ca-output-dir (required)")
        _ = cmd.MarkFlagRequired("name")
        _ = cmd.MarkFlagRequired("client-pem-file")
        _ = cmd.MarkFlagRequired("ca-output-dir")
    },
    RunFn: func(ctx context.Context, dmsCli *cli.DmsCLI, dmsClient client.DmsClient, opts actorCmdOptions) (any, error) {
        req, ok := opts.Payload.(*CreateVolumeRequestCmd)
        if !ok {
            // BUG FIX: message previously said "encode"; the failure here is
            // decoding (type-asserting) the supplied payload.
            return nil, fmt.Errorf("failed to decode create volume payload")
        }
        // Read the client PEM through the CLI's filesystem abstraction so the
        // command is testable with an in-memory FS.
        afs := afero.Afero{Fs: dmsCli.FS()}
        data, err := afs.ReadFile(req.ClientPEMFile)
        if err != nil {
            return nil, fmt.Errorf("failed to read config file: %w", err)
        }
        // validate client pem
        cfg := &node.CreateVolumeRequest{
            Name:      req.VolumeName,
            ClientPEM: string(data),
        }
        resp, err := dmsClient.CreateVolume(ctx, *cfg, opts.MsgOpts...)
        if err != nil {
            return resp, err
        }
        // Persist the CA returned by the DMS next to the requested output dir.
        err = afs.WriteFile(filepath.Join(req.CAOutputDir, "glusterfs.ca"), []byte(resp.CAData), 0o775)
        if err != nil {
            // Wrap so the user can tell a local write failure from a remote one.
            return resp, fmt.Errorf("failed to write CA file: %w", err)
        }
        return resp, nil
    },
    Action: bInvoke,
    Short: "Send a create volume message",
    Long: `Invoke the /dms/volume/create behavior on an actor
This behavior calls the actors create volume behaviour.
Examples:
nunet actor cmd --context user /dms/volume/create --name <volname> --client-pem-file <filename>`,
},
// /dms/volume/delete
behaviors.VolumeDeleteBehavior: {
    Payload: func() any { return &node.DeleteVolumeRequest{} },
    SetFlags: func(cmd *cobra.Command, payload any) {
        p := payload.(*node.DeleteVolumeRequest)
        cmd.Flags().StringVarP(&p.Name, "name", "n", "", "name (required)")
        _ = cmd.MarkFlagRequired("name")
    },
    RunFn: func(ctx context.Context, _ *cli.DmsCLI, dmsClient client.DmsClient, opts actorCmdOptions) (any, error) {
        req, ok := opts.Payload.(*node.DeleteVolumeRequest)
        if !ok {
            // BUG FIX: message previously said "encode"; the failure here is
            // decoding (type-asserting) the supplied payload.
            return nil, fmt.Errorf("failed to decode delete volume payload")
        }
        return dmsClient.DeleteVolume(ctx, *req, opts.MsgOpts...)
    },
    Action: bInvoke,
    Short: "Send a delete volume message",
    Long: `Invoke the /dms/volume/delete behavior on an actor
This behavior calls the actors delete volume behaviour.
Examples:
nunet actor cmd --context user /dms/volume/delete --name <volname>`,
},
// /dms/volume/start
behaviors.VolumeStartBehavior: {
    Payload: func() any { return &node.StartVolumeRequest{} },
    SetFlags: func(cmd *cobra.Command, payload any) {
        p := payload.(*node.StartVolumeRequest)
        cmd.Flags().StringVarP(&p.Name, "name", "n", "", "name (required)")
        _ = cmd.MarkFlagRequired("name")
    },
    RunFn: func(ctx context.Context, _ *cli.DmsCLI, dmsClient client.DmsClient, opts actorCmdOptions) (any, error) {
        req, ok := opts.Payload.(*node.StartVolumeRequest)
        if !ok {
            // BUG FIX: message previously said "encode"; the failure here is
            // decoding (type-asserting) the supplied payload.
            return nil, fmt.Errorf("failed to decode start volume payload")
        }
        return dmsClient.StartVolume(ctx, *req, opts.MsgOpts...)
    },
    Action: bInvoke,
    Short: "Send a start volume message",
    Long: `Invoke the /dms/volume/start behavior on an actor
This behavior calls the actors start volume behaviour.
Examples:
nunet actor cmd --context user /dms/volume/start --name <volname>`,
},
// /public/hello
behaviors.PublicHelloBehavior: {
    RunFn: func(ctx context.Context, _ *cli.DmsCLI, dms client.DmsClient, opts actorCmdOptions) (any, error) {
        // Thin wrapper: forward the hello invocation to the client.
        return dms.Hello(ctx, opts.MsgOpts...)
    },
    Action: bInvoke,
    Short: "Invoke a 'hello' message",
    Long: `Invoke the /public/hello behavior on an actor
This behavior invokes a "hello" for a polite introduction.
Examples:
nunet actor cmd --context user /public/hello
nunet actor cmd --context user /public/hello --dest <did/peer_id/actor_handle>`,
},
// /broadcast/hello
behaviors.BroadcastHelloBehavior: {
    RunFn: func(ctx context.Context, _ *cli.DmsCLI, dms client.DmsClient, opts actorCmdOptions) (any, error) {
        // Thin wrapper: publish the hello broadcast via the client.
        return dms.BroadcastHello(ctx, opts.MsgOpts...)
    },
    Action: bBroadcast,
    Short: "Broadcast a 'hello' message to a topic",
    Long: `Invokes the /broadcast/hello behavior on an actor
This behavior sends a "hello" message to a broadcast topic for polite introduction.
Examples:
nunet actor cmd --context user /broadcast/hello`,
},
// /public/status
behaviors.PublicStatusBehavior: {
    RunFn: func(ctx context.Context, _ *cli.DmsCLI, dms client.DmsClient, opts actorCmdOptions) (any, error) {
        // Thin wrapper: forward the status query to the client.
        return dms.Status(ctx, opts.MsgOpts...)
    },
    Action: bInvoke,
    Short: "Retrieve actor status",
    Long: `Invokes the /public/status behavior on an actor
This behavior retrieves the status and resources information.
Examples:
nunet actor cmd --context user /public/status # own actor status
nunet actor cmd --context user /public/status --dest <did/peer_id/actor_handle> # status of specified destination`,
},
// /dms/node/status
behaviors.StatusDiscoveryBehavior: {
    RunFn: func(ctx context.Context, _ *cli.DmsCLI, dms client.DmsClient, opts actorCmdOptions) (any, error) {
        // Thin wrapper: forward the discovery request to the client.
        return dms.Discovery(ctx, opts.MsgOpts...)
    },
    Action: bInvoke,
    Short: "Invoke a 'status discovery' message",
    Long: `Invoke the /dms/node/status behavior on an actor
This behavior invokes a "status discovery" behavior for fleet discovery.
Examples:
nunet actor cmd --context user /dms/node/status
nunet actor cmd --context user /dms/node/status --dest <did/peer_id/actor_handle>`,
},
// /broadcast/dms/status
behaviors.BroadcastStatusDiscoveryBehavior: {
    RunFn: func(ctx context.Context, _ *cli.DmsCLI, dms client.DmsClient, opts actorCmdOptions) (any, error) {
        // Thin wrapper: publish the discovery broadcast via the client.
        return dms.DiscoveryBroadcast(ctx, opts.MsgOpts...)
    },
    Action: bBroadcast,
    Short: "Broadcast a 'status discovery' message to a topic",
    Long: `Broadcast the /broadcast/dms/status behavior to nodes in the network
This behavior broadcasts a "status discovery" message to topic /nunet/status for fleet discovery.
Examples:
nunet actor cmd --context user /broadcast/dms/status`,
},
// /dms/node/peers/list
behaviors.PeersListBehavior: {
    RunFn: func(ctx context.Context, _ *cli.DmsCLI, dms client.DmsClient, opts actorCmdOptions) (any, error) {
        // Thin wrapper: forward the peers-list request to the client.
        return dms.PeersList(ctx, opts.MsgOpts...)
    },
    Action: bInvoke,
    Short: "List connected peers",
    Long: `Invokes the /dms/node/peers/list behavior on an actor
This behavior retrieves a list of connected peers.
Examples:
nunet actor cmd --context user /dms/node/peers/list # own node actor peer list
nunet actor cmd --context user /dms/node/peers/list --dest <did/peer_id/actor_handle> # specified node actor peer list`,
},
// /dms/node/peers/self
behaviors.PeerAddrInfoBehavior: {
    RunFn: func(ctx context.Context, _ *cli.DmsCLI, dms client.DmsClient, opts actorCmdOptions) (any, error) {
        // Thin wrapper: forward the self-addr-info request to the client.
        return dms.PeersSelf(ctx, opts.MsgOpts...)
    },
    Action: bInvoke,
    Short: "Get peer's ID and addresses",
    Long: `Invokes the /dms/node/peers/self behavior on an actor
This behavior retrieves information about the node itself, such as its ID or addresses.
Examples:
nunet actor cmd --context user /dms/node/peers/self # own node actor peer ID
nunet actor cmd --context user /dms/node/peers/self --dest <did/peer_id/actor_handle> # specified node actor peer ID`,
},
// /dms/node/peers/ping
behaviors.PeerPingBehavior: {
    Action:  bInvoke,
    Payload: func() any { return &node.PingRequest{} },
    SetFlags: func(cmd *cobra.Command, payload any) {
        pingReq := payload.(*node.PingRequest)
        cmd.Flags().StringVarP(&pingReq.Host, "host", "H", "", "host address to ping (required)")
        _ = cmd.MarkFlagRequired("host")
    },
    RunFn: func(ctx context.Context, _ *cli.DmsCLI, dms client.DmsClient, opts actorCmdOptions) (any, error) {
        // Decode the CLI payload and forward the ping request.
        pingReq, ok := opts.Payload.(*node.PingRequest)
        if !ok {
            return nil, fmt.Errorf("failed to decode payload")
        }
        return dms.PeerPing(ctx, *pingReq, opts.MsgOpts...)
    },
    Short: "Ping a peer",
    Long: `Invokes the /dms/node/peers/ping behavior on an actor
This behavior establishes a ping connection with a peer.
Examples:
nunet actor cmd --context user /dms/node/peers/ping --host <peer_id>`,
},
// /dms/node/peers/dht
behaviors.PeerDHTBehavior: {
    Action: bInvoke,
    // TODO: Check the actual implementation?
    RunFn: func(ctx context.Context, _ *cli.DmsCLI, dms client.DmsClient, opts actorCmdOptions) (any, error) {
        // Thin wrapper: forward the DHT peer-list request to the client.
        return dms.PeersListFromDHT(ctx, opts.MsgOpts...)
    },
    Short: "List peers connected to DHT",
    Long: `Invokes the /dms/node/peers/dht behavior on an actor
This behavior returns a list of peers from the Distributed Hash Table (DHT) used for peer discovery and content routing.
Examples:
nunet actor cmd --context user /dms/node/peers/dht`,
},
// /dms/node/peers/connect
behaviors.PeerConnectBehavior: {
    Action:  bInvoke,
    Payload: func() any { return &node.PeerConnectRequest{} },
    SetFlags: func(cmd *cobra.Command, payload any) {
        connectReq := payload.(*node.PeerConnectRequest)
        cmd.Flags().StringVarP(&connectReq.Address, "address", "a", "", "peer address to connect to (required)")
        _ = cmd.MarkFlagRequired("address")
    },
    RunFn: func(ctx context.Context, _ *cli.DmsCLI, dms client.DmsClient, opts actorCmdOptions) (any, error) {
        // Decode the CLI payload and forward the connect request.
        connectReq, ok := opts.Payload.(*node.PeerConnectRequest)
        if !ok {
            return nil, fmt.Errorf("failed to decode payload")
        }
        return dms.PeerConnect(ctx, *connectReq, opts.MsgOpts...)
    },
    Short: "Connect to a peer",
    Long: `Invokes the /dms/node/peers/connect behavior on an actor
This behavior initiates a connection to a specified peer.
Examples:
nunet actor cmd --context user /dms/node/peers/connect --address /p2p/<peer_id>`,
},
// /dms/node/peers/score
behaviors.PeerScoreBehavior: {
    RunFn: func(ctx context.Context, _ *cli.DmsCLI, dms client.DmsClient, opts actorCmdOptions) (any, error) {
        // Thin wrapper: forward the peer-score request to the client.
        return dms.PeerScore(ctx, opts.MsgOpts...)
    },
    Action: bInvoke,
    Short: "Retrieves gossipsub broadcast score",
    Long: `Invokes the /dms/node/peers/score behavior on an actor
This behavior retrieves a snapshot of the peer's gossipsub broadcast score.
Examples:
nunet actor cmd --context user /dms/node/peers/score`,
},
// /dms/debug/flightrec
behaviors.DebugFlightrecBehavior: {
    RunFn: func(ctx context.Context, _ *cli.DmsCLI, dms client.DmsClient, opts actorCmdOptions) (any, error) {
        // Thin wrapper: forward the flight-recorder dump request to the client.
        return dms.Flightrec(ctx, opts.MsgOpts...)
    },
    Action: bInvoke,
    Short: "Dumps a flight recorder snapshot",
    Long: `Invokes the /dms/debug/flightrec behavior on an actor
This behavior dumps a flight recorder snapshot.
Examples:
nunet actor cmd --context user /dms/debug/flightrec`,
},
// /dms/node/onboarding/onboard
behaviors.OnboardBehavior: {
    Action:  bInvoke,
    Payload: func() any { return &onboardingInput{} },
    // BUG FIX: the flag-setter previously took `*Command` (unresolved name);
    // every sibling behavior uses `*cobra.Command`.
    SetFlags: func(cmd *cobra.Command, payload any) {
        // infer the type of the payload
        p := payload.(*onboardingInput)
        cmd.Flags().StringVarP(&p.RAMSize, "ram", "R", "0GiB", "set the amount of memory to reserve for NuNet (defaults to GiB)")
        cmd.Flags().Float32VarP(&p.CPUCores, "cpu", "C", 0, "set the number of CPU cores to reserve for NuNet")
        cmd.Flags().StringVarP(&p.DiskSize, "disk", "D", "0GiB", "set the amount of disk size to reserve for NuNet (defaults to GiB)")
        cmd.Flags().StringVarP(&p.GPUsStr, "gpus", "G", "", "comma-separated list of GPU Index and VRAM in GiB (e.g. 0:4,1:8). The gpu index can be obtained from 'nunet gpu list' command. Unit can be specified for the VRAM but defaults to GiB")
        cmd.Flags().BoolVarP(&p.NoGPU, "no-gpu", "N", false, "do not reserve any GPU resources")
        cmd.MarkFlagsOneRequired("ram", "cpu", "disk")
        cmd.MarkFlagsRequiredTogether("ram", "cpu", "disk")
    },
    RunFn: func(ctx context.Context, _ *cli.DmsCLI, dmsClient client.DmsClient, opts actorCmdOptions) (any, error) {
        p, ok := opts.Payload.(*onboardingInput)
        if !ok {
            // BUG FIX: message previously said "encode"; the failure here is
            // decoding (type-asserting) the supplied payload.
            return nil, fmt.Errorf("failed to decode onboarding payload")
        }
        if err := processOnboardInput(ctx, dmsClient, opts); err != nil {
            return nil, err
        }
        req := node.OnboardRequest{}
        req.Config.OnboardedResources.CPU.Cores = p.CPUCores
        req.Config.OnboardedResources.CPU.ClockSpeed = p.CPUCLock
        req.NoGPU = p.NoGPU
        req.Config.OnboardedResources.GPUs = p.GPUs
        var err error
        // convert RAM and Disk from specified unit to bytes if specified otherwise, default to GiB
        req.Config.OnboardedResources.RAM.Size, err = convert.ParseBytesWithDefaultUnit(p.RAMSize, "GiB")
        if err != nil {
            return nil, fmt.Errorf("failed to decode RAM size. Expected Unit in GiB")
        }
        req.Config.OnboardedResources.Disk.Size, err = convert.ParseBytesWithDefaultUnit(p.DiskSize, "GiB")
        if err != nil {
            return nil, fmt.Errorf("failed to decode Disk size. Expected Unit in GiB")
        }
        return dmsClient.Onboard(ctx, req, opts.MsgOpts...)
    },
    Short: "Onboard a node to the network",
    Long: `Invokes the /dms/node/onboarding/onboard behavior on an actor
This behavior is used to onboard a node to the DMS, making its resources available for use.
Examples:
nunet actor cmd --context user /dms/node/onboarding/onboard --disk 1 --ram 1 --cpu 2`,
},
// /dms/node/onboarding/offboard
behaviors.OffboardBehavior: {
    RunFn: func(ctx context.Context, _ *cli.DmsCLI, dms client.DmsClient, opts actorCmdOptions) (any, error) {
        // No flags map onto the request yet; send an empty offboard request.
        return dms.Offboard(ctx, node.OffboardRequest{}, opts.MsgOpts...)
    },
    Action: bInvoke,
    Short: "Offboard a node from the network",
    Long: `Invokes the /dms/node/onboarding/offboard behavior on an actor
This behavior is used to offboard a node from the DMS (Device Management Service).
Examples:
nunet actor cmd --context user /dms/node/onboarding/offboard
nunet actor cmd --context user /dms/node/onboarding/offboard --force`,
    // TODO: there is no flag set for --force
},
// /dms/node/onboarding/status
behaviors.OnboardStatusBehavior: {
    RunFn: func(ctx context.Context, _ *cli.DmsCLI, dms client.DmsClient, opts actorCmdOptions) (any, error) {
        // Thin wrapper: forward the onboarding-status query to the client.
        return dms.OnboardStatus(ctx, opts.MsgOpts...)
    },
    Action: bInvoke,
    Short: "Retrieve onboarding status of a node",
    Long: `Invokes the /dms/node/onboarding/status behavior on an actor
This behavior is used to check the onboarding status of a node.
Examples:
nunet actor cmd --context user /dms/node/onboarding/status`,
},
// /dms/node/deployment/list
behaviors.DeploymentListBehavior: {
    Action:  bInvoke,
    Payload: func() any { return &DeploymentListCmd{} },
    SetFlags: func(cmd *cobra.Command, payload any) {
        p := payload.(*DeploymentListCmd)
        // Existing metadata filter
        cmd.Flags().StringToStringVarP(&p.Metadata, "filter", "f", nil, "metadata filter to filter deployments (optional)")
        // Pagination
        cmd.Flags().IntVar(&p.Limit, "limit", 0, "Maximum number of results to return (0 = no limit)")
        cmd.Flags().IntVar(&p.Offset, "offset", 0, "Number of results to skip")
        // Status filter
        cmd.Flags().StringSliceVar(&p.Status, "status", nil, "Filter by deployment status (can specify multiple, e.g., --status Running --status Failed)")
        // Date filters
        cmd.Flags().StringVar(&p.CreatedAfter, "created-after", "", "Filter deployments created after this date (RFC3339 or relative: 1h, 1d, etc.)")
        cmd.Flags().StringVar(&p.CreatedBefore, "created-before", "", "Filter deployments created before this date (RFC3339 or relative)")
        cmd.Flags().StringVar(&p.UpdatedAfter, "updated-after", "", "Filter deployments updated after this date (RFC3339 or relative)")
        cmd.Flags().StringVar(&p.UpdatedBefore, "updated-before", "", "Filter deployments updated before this date (RFC3339 or relative)")
        // Sorting
        cmd.Flags().StringVar(&p.SortBy, "sort", "-created_at", "Sort field and direction (e.g., 'created_at', '-created_at', 'status')")
    },
    RunFn: func(ctx context.Context, _ *cli.DmsCLI, dmsClient client.DmsClient, opts actorCmdOptions) (any, error) {
        payload, ok := opts.Payload.(*DeploymentListCmd)
        if !ok {
            return nil, fmt.Errorf("failed to decode payload")
        }
        req := &node.DeploymentListRequest{
            Metadata: payload.Metadata,
            Limit:    payload.Limit,
            Offset:   payload.Offset,
            SortBy:   payload.SortBy,
        }
        // Map the user-supplied status names (case-insensitive) onto the
        // deployment status enum.
        if len(payload.Status) > 0 {
            req.Status = make([]jobtypes.DeploymentStatus, 0, len(payload.Status))
            for _, statusStr := range payload.Status {
                statusStr = strings.TrimSpace(statusStr)
                matched := false
                for i := jobtypes.DeploymentStatusPreparing; i <= jobtypes.DeploymentStatusCompleted; i++ {
                    if strings.EqualFold(i.String(), statusStr) {
                        req.Status = append(req.Status, i)
                        matched = true
                        break
                    }
                }
                // BUG FIX: unknown status values were silently dropped from
                // the filter, so a typo produced wrong results without any
                // diagnostic; reject them explicitly instead.
                if !matched {
                    return nil, fmt.Errorf("unknown deployment status: %q", statusStr)
                }
            }
        }
        // Parse date strings from CLI if provided
        if payload.CreatedAfter != "" {
            parsed, err := parseDateString(payload.CreatedAfter)
            if err != nil {
                return nil, fmt.Errorf("invalid created-after date: %w", err)
            }
            req.CreatedAfter = &parsed
        }
        if payload.CreatedBefore != "" {
            parsed, err := parseDateString(payload.CreatedBefore)
            if err != nil {
                return nil, fmt.Errorf("invalid created-before date: %w", err)
            }
            req.CreatedBefore = &parsed
        }
        if payload.UpdatedAfter != "" {
            parsed, err := parseDateString(payload.UpdatedAfter)
            if err != nil {
                return nil, fmt.Errorf("invalid updated-after date: %w", err)
            }
            req.UpdatedAfter = &parsed
        }
        if payload.UpdatedBefore != "" {
            parsed, err := parseDateString(payload.UpdatedBefore)
            if err != nil {
                return nil, fmt.Errorf("invalid updated-before date: %w", err)
            }
            req.UpdatedBefore = &parsed
        }
        return dmsClient.DeploymentList(ctx, *req, opts.MsgOpts...)
    },
    Short: "List deployments",
    Long: `Invokes the /dms/node/deployment/list behavior on an actor
This behavior retrieves a list of deployments on the node with support for pagination, filtering, and sorting.
Examples:
# List first 10 deployments
nunet actor cmd --context user /dms/node/deployment/list --limit 10
# List with pagination
nunet actor cmd --context user /dms/node/deployment/list --limit 10 --offset 0
# Filter by status
nunet actor cmd --context user /dms/node/deployment/list --status Running --status Failed
# Filter by creation date (relative)
nunet actor cmd --context user /dms/node/deployment/list --created-after "7d"
# Filter by creation date (absolute)
nunet actor cmd --context user /dms/node/deployment/list --created-after "2024-01-01T00:00:00Z"
# Combine filters with pagination
nunet actor cmd --context user /dms/node/deployment/list --status Running --created-after "1d" --limit 50 --sort "-created_at"
# With metadata filter
nunet actor cmd --context user /dms/node/deployment/list --filter "environment=production" --status Running --limit 10`,
},
// /dms/node/deployment/prune
behaviors.DeploymentPruneBehavior: {
Action: bInvoke,
Payload: func() any { return &node.DeploymentPruneRequest{} },
SetFlags: func(cmd *cobra.Command, payload any) {
p := payload.(*node.DeploymentPruneRequest)
cmd.Flags().StringVar(&p.Before, "before", "", "remove deployments created before this time: RFC3339 or duration (e.g. 1m, 1h, 1s, 1d)")
cmd.Flags().BoolVarP(&p.All, "all", "a", false, "remove all deployments whose status is greater than Running")
},
RunFn: func(ctx context.Context, _ *cli.DmsCLI, dmsClient client.DmsClient, opts actorCmdOptions) (any, error) {
req, ok := opts.Payload.(*node.DeploymentPruneRequest)
if !ok {
return nil, fmt.Errorf("failed to decode payload")
}
// Require an explicit selector so a bare invocation cannot prune anything by accident.
if req.Before == "" && !req.All {
return nil, fmt.Errorf("must provide --before or --all")
}
return dmsClient.DeploymentPrune(ctx, *req, opts.MsgOpts...)
},
Short: "Prune old deployments",
Long: `Invokes the /dms/node/deployment/prune behavior on an actor
This behavior removes deployments before a specified datetime or duration, or deletes all deployments with status greater than Running when --all is used.
Examples:
nunet actor cmd --context user /dms/node/deployment/prune --before 2025-01-01T00:00:00Z
nunet actor cmd --context user /dms/node/deployment/prune --all`,
},
// /dms/node/deployment/status
behaviors.DeploymentStatusBehavior: {
Action: bInvoke,
Payload: func() any { return &node.DeploymentStatusRequest{} },
SetFlags: func(cmd *cobra.Command, payload any) {
p := payload.(*node.DeploymentStatusRequest)
cmd.Flags().StringVarP(&p.ID, "id", "i", "", "deployment ID (required)")
cmd.Flags().BoolVarP(&p.IncludeUsage, "include-usage", "u", false, "include allocation resource usage statistics")
_ = cmd.MarkFlagRequired("id")
},
RunFn: func(ctx context.Context, _ *cli.DmsCLI, dmsClient client.DmsClient, opts actorCmdOptions) (any, error) {
req, ok := opts.Payload.(*node.DeploymentStatusRequest)
if !ok {
return nil, fmt.Errorf("failed to decode payload")
}
return dmsClient.DeploymentStatus(ctx, *req, opts.MsgOpts...)
},
Short: "Get deployment status",
Long: `Invokes the /dms/node/deployment/status behavior on an actor
This behavior retrieves the status of a specific deployment.
Examples:
nunet actor cmd --context user /dms/node/deployment/status --id <deployment_id>
nunet actor cmd --context user /dms/node/deployment/status --id <deployment_id> --include-usage`,
},
// /dms/node/deployment/logs
behaviors.DeploymentLogsBehavior: {
Action: bInvoke,
Payload: func() any { return &node.DeploymentLogsRequest{} },
SetFlags: func(cmd *cobra.Command, payload any) {
p := payload.(*node.DeploymentLogsRequest)
// Note: unlike the other deployment entries, --id here binds to the ensemble ID.
cmd.Flags().StringVarP(&p.EnsembleID, "id", "i", "", "ensemble ID (required)")
cmd.Flags().StringVarP(&p.AllocationName, "allocation", "a", "", "allocation name (required)")
_ = cmd.MarkFlagRequired("id")
_ = cmd.MarkFlagRequired("allocation")
},
RunFn: func(ctx context.Context, _ *cli.DmsCLI, dmsClient client.DmsClient, opts actorCmdOptions) (any, error) {
req, ok := opts.Payload.(*node.DeploymentLogsRequest)
if !ok {
return nil, fmt.Errorf("failed to decode payload")
}
return dmsClient.DeploymentLogs(ctx, *req, opts.MsgOpts...)
},
Short: "Get deployment logs",
Long: `Invokes the /dms/node/deployment/logs behavior on an actor
This behavior retrieves the logs of a specific deployment, writing it to a file
with path returned in the response.
Examples:
nunet actor cmd --context user /dms/node/deployment/logs --id <deployment_id> --allocation <allocation_name>`,
},
// /dms/node/deployment/manifest
behaviors.DeploymentManifestBehavior: {
Action: bInvoke,
Payload: func() any { return &node.DeploymentManifestRequest{} },
SetFlags: func(cmd *cobra.Command, payload any) {
p := payload.(*node.DeploymentManifestRequest)
cmd.Flags().StringVarP(&p.ID, "id", "i", "", "deployment ID (required)")
_ = cmd.MarkFlagRequired("id")
},
RunFn: func(ctx context.Context, _ *cli.DmsCLI, dmsClient client.DmsClient, opts actorCmdOptions) (any, error) {
req, ok := opts.Payload.(*node.DeploymentManifestRequest)
if !ok {
return nil, fmt.Errorf("failed to decode payload")
}
return dmsClient.DeploymentManifest(ctx, *req, opts.MsgOpts...)
},
Short: "Get deployment manifest",
Long: `Invokes the /dms/node/deployment/manifest behavior on an actor
This behavior retrieves the manifest of a specific deployment.
Examples:
nunet actor cmd --context user /dms/node/deployment/manifest --id <deployment_id>`,
},
// /dms/node/deployment/info
behaviors.DeploymentInfoBehavior: {
Action: bInvoke,
Payload: func() any { return &node.DeploymentInfoRequest{} },
SetFlags: func(cmd *cobra.Command, payload any) {
p := payload.(*node.DeploymentInfoRequest)
cmd.Flags().StringVarP(&p.ID, "id", "i", "", "deployment ID (required)")
cmd.Flags().BoolVar(&p.IncludeUsage, "usage", false, "include resource usage statistics")
cmd.Flags().BoolVar(&p.IncludeLogs, "logs", false, "include log file paths for allocations")
cmd.Flags().StringSliceVar(&p.AllocationNames, "allocations", nil, "specific allocation names to include logs for (empty = all)")
_ = cmd.MarkFlagRequired("id")
},
RunFn: func(ctx context.Context, _ *cli.DmsCLI, dmsClient client.DmsClient, opts actorCmdOptions) (any, error) {
req, ok := opts.Payload.(*node.DeploymentInfoRequest)
if !ok {
return nil, fmt.Errorf("failed to decode payload")
}
return dmsClient.DeploymentInfo(ctx, *req, opts.MsgOpts...)
},
Short: "Get comprehensive deployment information",
Long: `Invokes the /dms/node/deployment/info behavior on an actor
This behavior retrieves comprehensive information about a deployment including status,
manifest, allocation details, optional resource usage, and optional log file paths.
Logs are returned as file paths (not content) for optimal performance.
Examples:
nunet actor cmd --context user /dms/node/deployment/info --id <deployment_id>
nunet actor cmd --context user /dms/node/deployment/info --id <deployment_id> --usage
nunet actor cmd --context user /dms/node/deployment/info --id <deployment_id> --logs
nunet actor cmd --context user /dms/node/deployment/info --id <deployment_id> --usage --logs
nunet actor cmd --context user /dms/node/deployment/info --id <deployment_id> --logs --allocations alloc1 alloc2`,
},
// /dms/node/deployment/shutdown
behaviors.DeploymentShutdownBehavior: {
Action: bInvoke,
Payload: func() any { return &node.DeploymentShutdownRequest{} },
SetFlags: func(cmd *cobra.Command, payload any) {
p := payload.(*node.DeploymentShutdownRequest)
cmd.Flags().StringVarP(&p.ID, "id", "i", "", "deployment ID (required)")
_ = cmd.MarkFlagRequired("id")
},
RunFn: func(ctx context.Context, _ *cli.DmsCLI, dmsClient client.DmsClient, opts actorCmdOptions) (any, error) {
req, ok := opts.Payload.(*node.DeploymentShutdownRequest)
if !ok {
return nil, fmt.Errorf("failed to decode payload")
}
return dmsClient.DeploymentShutdown(ctx, *req, opts.MsgOpts...)
},
Short: "Shutdown a deployment",
Long: `Invokes the /dms/node/deployment/shutdown behavior on an actor
This behavior shuts down a specific deployment.
Examples:
nunet actor cmd --context user /dms/node/deployment/shutdown --id <deployment_id>`,
},
// /dms/node/deployment/new
behaviors.NewDeploymentBehavior: {
Action: bInvoke,
Short: "Create a new deployment",
Long: `Invokes the /dms/node/deployment/new behavior on an actor
This behavior creates a new deployment.
Examples:
nunet actor cmd --context user /dms/node/deployment/new --spec-file <path to ensemble specification file>`,
Payload: func() any { return &NewDeploymentRequestCmd{} },
SetFlags: func(cmd *cobra.Command, payload any) {
p := payload.(*NewDeploymentRequestCmd)
cmd.Flags().StringVarP(&p.Config, "spec-file", "f", "ensemble.yaml", "path of the ensemble specification file")
},
RunFn: func(ctx context.Context, dmsCli *cli.DmsCLI, dmsClient client.DmsClient, opts actorCmdOptions) (any, error) {
reqCmd, ok := opts.Payload.(*NewDeploymentRequestCmd)
if !ok {
return nil, fmt.Errorf("failed to decode payload")
}
req := &node.NewDeploymentRequest{}
// Parse and expand the ensemble spec locally before submitting the deployment.
cfg, err := ProcessEnsembleYaml(afero.Afero{Fs: dmsCli.FS()}, dmsCli.Env(), reqCmd.Config)
if err != nil {
return nil, fmt.Errorf("failed to process ensemble config file: %w", err)
}
req.Ensemble = *cfg
return dmsClient.DeploymentNew(ctx, *req, opts.MsgOpts...)
},
},
// /dms/node/deployment/update
behaviors.DeploymentUpdateBehavior: {
	Action: bInvoke,
	Short: "Updates an existing deployment",
	Long: `Invokes the /dms/node/deployment/update behavior on an actor
This behavior updates an existing deployment.
Examples:
nunet actor cmd --context user /dms/node/deployment/update --spec-file <path to ensemble specification file> --id <ensemble_id>`,
	Payload: func() any { return &UpdateDeploymentRequestCmd{} },
	SetFlags: func(cmd *cobra.Command, payload any) {
		p := payload.(*UpdateDeploymentRequestCmd)
		cmd.Flags().StringVarP(&p.Config, "spec-file", "f", "ensemble.yaml", "path of the ensemble specification file")
		cmd.Flags().StringVarP(&p.EnsembleID, "id", "i", "", "id of the ensemble to update (required)")
		_ = cmd.MarkFlagRequired("id")
	},
	RunFn: func(
		ctx context.Context, dmsCli *cli.DmsCLI,
		dmsClient client.DmsClient, opts actorCmdOptions,
	) (any, error) {
		reqCmd, ok := opts.Payload.(*UpdateDeploymentRequestCmd)
		if !ok {
			// Fixed: this message previously said "failed to encode payload";
			// it is a decode failure, matching every other behavior entry.
			return nil, fmt.Errorf("failed to decode payload")
		}
		req := &node.UpdateDeploymentRequest{
			EnsembleID: reqCmd.EnsembleID,
		}
		// Parse and expand the ensemble spec locally before sending the update.
		cfg, err := ProcessEnsembleYaml(afero.Afero{Fs: dmsCli.FS()}, dmsCli.Env(), reqCmd.Config)
		if err != nil {
			return nil, fmt.Errorf("failed to process ensemble config file: %w", err)
		}
		req.Ensemble = *cfg
		return dmsClient.DeploymentUpdate(ctx, *req, opts.MsgOpts...)
	},
},
// /dms/node/resources/allocated
behaviors.ResourcesAllocatedBehavior: {
Action: bInvoke,
RunFn: func(ctx context.Context, _ *cli.DmsCLI, dmsClient client.DmsClient, opts actorCmdOptions) (any, error) {
return dmsClient.ResourcesAllocated(ctx, opts.MsgOpts...)
},
Short: "Get allocated resources",
Long: `Invokes the /dms/node/resources/allocated behavior on an actor
This behavior retrieves the resources allocated by the node. The resources include CPU, RAM, GPU and disk space.
The returned units are in Hz for CPU clock speed, bytes for RAM, VRAM and disk space.
Examples:
nunet actor cmd --context user /dms/node/resources/allocated`,
},
// /dms/node/resources/free
behaviors.ResourcesFreeBehavior: {
Action: bInvoke,
RunFn: func(ctx context.Context, _ *cli.DmsCLI, dmsClient client.DmsClient, opts actorCmdOptions) (any, error) {
return dmsClient.ResourcesFree(ctx, opts.MsgOpts...)
},
Short: "Get free resources",
Long: `Invokes the /dms/node/resources/free behavior on an actor
This behavior retrieves the free resources available on the node. The resources include CPU, RAM, GPU and disk space.
The returned units are in Hz for CPU clock speed, bytes for RAM, VRAM and disk space.
Examples:
nunet actor cmd --context user /dms/node/resources/free`,
},
// /dms/node/resources/onboarded
behaviors.ResourcesOnboardedBehavior: {
Action: bInvoke,
RunFn: func(ctx context.Context, _ *cli.DmsCLI, dmsClient client.DmsClient, opts actorCmdOptions) (any, error) {
return dmsClient.ResourcesOnboarded(ctx, opts.MsgOpts...)
},
Short: "Get onboarded resources",
Long: `Invokes the /dms/node/resources/onboarded behavior on an actor
This behavior retrieves the resources onboarded to the node. The resources include CPU, RAM, GPU and disk space.
The returned units are in Hz for CPU clock speed, bytes for RAM, VRAM and disk space.
Examples:
nunet actor cmd --context user /dms/node/resources/onboarded`,
},
// /dms/node/allocations/list
behaviors.AllocationsListBehavior: {
Action: bInvoke,
RunFn: func(ctx context.Context, _ *cli.DmsCLI, dmsClient client.DmsClient, opts actorCmdOptions) (any, error) {
return dmsClient.AllocationsList(ctx, opts.MsgOpts...)
},
Short: "List allocations",
Long: `Invokes the /dms/node/allocations/list behavior on an actor
This behavior retrieves information about all running allocations within your onboarded DMS.
The information includes allocation ID, status, executor type, container ID, resources, and port mappings.
Examples:
nunet actor cmd --context user /dms/node/allocations/list`,
},
// /dms/node/logger/config
behaviors.LoggerConfigBehavior: {
Payload: func() any { return &node.LoggerConfigRequest{} },
SetFlags: func(cmd *cobra.Command, payload any) {
p := payload.(*node.LoggerConfigRequest)
cmd.Flags().StringVarP(&p.URL, "url", "u", "", "Elasticsearch URL")
cmd.Flags().StringVarP(&p.Level, "level", "l", "", "logging level (info, warn, debug etc.)")
cmd.Flags().IntVarP(&p.Interval, "interval", "i", 0, "flush interval in seconds")
cmd.MarkFlagsOneRequired("url", "level", "interval")
cmd.Flags().StringVar(&p.APIKey, "api-key", "", "API Key for Elasticsearch and APM")
cmd.Flags().StringVar(&p.APMURL, "apm-url", "", "APM Server URL")
// Deliberately not bound to the request struct: PreRunFn below only copies the
// value into p.ElasticEnabled when the flag was explicitly set, keeping it tri-state.
cmd.Flags().Bool("enable-elastic", false, "Enable Elasticsearch logging")
},
PreRunFn: func(cmd *cobra.Command, _ *cli.DmsCLI, opts actorCmdOptions) error {
p, ok := opts.Payload.(*node.LoggerConfigRequest)
if !ok {
return fmt.Errorf("failed to decode payload")
}
flag := cmd.Flags().Lookup("enable-elastic")
// Only propagate --enable-elastic when the user actually passed it, so an
// unset flag leaves ElasticEnabled nil (no change requested).
if flag != nil && flag.Changed {
val, err := strconv.ParseBool(flag.Value.String())
if err != nil {
return fmt.Errorf("invalid value for --enable-elastic: %v", err)
}
p.ElasticEnabled = &val
}
return nil
},
RunFn: func(ctx context.Context, _ *cli.DmsCLI, dmsClient client.DmsClient, opts actorCmdOptions) (any, error) {
req, ok := opts.Payload.(*node.LoggerConfigRequest)
if !ok {
return nil, fmt.Errorf("failed to decode payload")
}
return dmsClient.LoggerConfig(ctx, *req, opts.MsgOpts...)
},
Action: bInvoke,
Short: "Adjust logger settings",
Long: `Invokes the /dms/node/logger/config behavior on an actor
This behavior allows the user to adjust logger settings, i.e. logging level, flush interval and Elasticsearch URL.
Examples:
nunet actor cmd --context user /dms/node/logger/config --level debug # set debug level
nunet actor cmd --context user /dms/node/logger/config --url <elasticsearch-url>
nunet actor cmd --context user /dms/node/logger/config --interval 10 # flush logs each 10 seconds
nunet actor cmd --context user /dms/node/logger/config --api-key <api-key>
nunet actor cmd --context user /dms/node/logger/config --apm-url <apm-url>
nunet actor cmd --context user /dms/node/logger/config --enable-elastic`,
},
// /dms/node/hardware/spec
behaviors.HardwareSpecBehavior: {
Action: bInvoke,
RunFn: func(ctx context.Context, _ *cli.DmsCLI, dmsClient client.DmsClient, opts actorCmdOptions) (any, error) {
return dmsClient.HardwareSpec(ctx, opts.MsgOpts...)
},
Short: "Get hardware specifications",
Long: `Invokes the /dms/node/hardware/spec behavior on an actor
This behavior retrieves the hardware specifications of the system.
Examples:
nunet actor cmd --context user /dms/node/hardware/spec`,
},
// /dms/node/hardware/usage
behaviors.HardwareUsageBehavior: {
Action: bInvoke,
RunFn: func(ctx context.Context, _ *cli.DmsCLI, dmsClient client.DmsClient, opts actorCmdOptions) (any, error) {
return dmsClient.HardwareUsage(ctx, opts.MsgOpts...)
},
Short: "Get hardware usage",
Long: `Invokes the /dms/node/hardware/usage behavior on an actor
This behavior retrieves the hardware usage of the system.
Examples:
nunet actor cmd --context user /dms/node/hardware/usage`,
},
// /dms/cap/list
behaviors.CapListBehavior: {
Action: bInvoke,
Payload: func() any { return &node.CapListRequest{} },
SetFlags: func(cmd *cobra.Command, payload any) {
p := payload.(*node.CapListRequest)
cmd.Flags().StringVarP(&p.Context, "context", "c", "", "context name")
},
RunFn: func(ctx context.Context, _ *cli.DmsCLI, dmsClient client.DmsClient, opts actorCmdOptions) (any, error) {
req, ok := opts.Payload.(*node.CapListRequest)
if !ok {
return nil, fmt.Errorf("failed to decode payload")
}
return dmsClient.CapList(ctx, *req, opts.MsgOpts...)
},
Short: "List capabilities",
Long: `Invokes the /dms/cap/list behavior on an actor
This behavior retrieves a list of capabilities available on the node.
Examples:
nunet actor cmd --context user /dms/cap/list`,
},
// /dms/cap/provide/anchor
behaviors.ProvideCapAnchorBehavior: {
	Action: bInvoke,
	Payload: func() any { return &CapAnchorRequestCmd{} },
	SetFlags: func(cmd *cobra.Command, payload any) {
		p := payload.(*CapAnchorRequestCmd)
		// Fixed: the help text previously read "add revoke anchor" — a
		// copy-paste from the revoke entry; this flag carries the capability
		// token to anchor on the provide anchor.
		cmd.Flags().StringVar(&p.Token, "token", "", "capability token (JSON) to anchor on the provide anchor")
		cmd.MarkFlagsOneRequired("token")
	},
	RunFn: func(ctx context.Context, _ *cli.DmsCLI, dmsClient client.DmsClient, opts actorCmdOptions) (any, error) {
		payload, ok := opts.Payload.(*CapAnchorRequestCmd)
		if !ok {
			return nil, fmt.Errorf("failed to decode payload")
		}
		// The token is supplied as a JSON-encoded ucan.Token on the command line.
		var token ucan.Token
		if err := json.Unmarshal([]byte(payload.Token), &token); err != nil {
			return nil, err
		}
		req := &node.CapTokenAnchorRequest{
			Token: ucan.TokenList{
				Tokens: []*ucan.Token{
					&token,
				},
			},
		}
		return dmsClient.ProvideCapAnchor(ctx, *req, opts.MsgOpts...)
	},
	Short: "Anchors a capability token on the provide anchor of a node",
	Long: `Invokes the /dms/cap/provide/anchor behavior on an actor and requests to anchor on provide anchor.
This behavior invokes a node to anchor a token on the provide anchor.
Examples:
nunet actor cmd --context user /dms/cap/provide/anchor --dest <peerID|did> --token <token>`,
},
// /dms/cap/require/anchor
behaviors.RequireCapAnchorBehavior: {
	Action: bInvoke,
	Payload: func() any { return &CapAnchorRequestCmd{} },
	SetFlags: func(cmd *cobra.Command, payload any) {
		p := payload.(*CapAnchorRequestCmd)
		// Fixed: the help text previously read "add revoke anchor" — a
		// copy-paste from the revoke entry; this flag carries the capability
		// token to anchor on the require anchor.
		cmd.Flags().StringVar(&p.Token, "token", "", "capability token (JSON) to anchor on the require anchor")
		cmd.MarkFlagsOneRequired("token")
	},
	RunFn: func(ctx context.Context, _ *cli.DmsCLI, dmsClient client.DmsClient, opts actorCmdOptions) (any, error) {
		payload, ok := opts.Payload.(*CapAnchorRequestCmd)
		if !ok {
			return nil, fmt.Errorf("failed to decode payload")
		}
		// The token is supplied as a JSON-encoded ucan.Token on the command line.
		var token ucan.Token
		if err := json.Unmarshal([]byte(payload.Token), &token); err != nil {
			return nil, err
		}
		req := &node.CapTokenAnchorRequest{
			Token: ucan.TokenList{
				Tokens: []*ucan.Token{
					&token,
				},
			},
		}
		return dmsClient.RequireCapAnchor(ctx, *req, opts.MsgOpts...)
	},
	Short: "Anchors a capability token on the require anchor of a node",
	Long: `Invokes the /dms/cap/require/anchor behavior on an actor and anchors a token on the require anchor.
This behavior invokes a node to anchor a token on the require anchor.
Examples:
nunet actor cmd --context user /dms/cap/require/anchor --dest <peerID|did> --token <token>`,
},
// /dms/cap/revoke/anchor
// NOTE(review): the token-decoding logic in this RunFn is duplicated across all
// four cap-anchor entries; a shared helper would remove the repetition.
behaviors.RevokeCapAnchorBehavior: {
Action: bInvoke,
Payload: func() any { return &CapAnchorRequestCmd{} },
SetFlags: func(cmd *cobra.Command, payload any) {
p := payload.(*CapAnchorRequestCmd)
cmd.Flags().StringVar(&p.Token, "token", "", "add revoke anchor")
cmd.MarkFlagsOneRequired("token")
},
RunFn: func(ctx context.Context, _ *cli.DmsCLI, dmsClient client.DmsClient, opts actorCmdOptions) (any, error) {
payload, ok := opts.Payload.(*CapAnchorRequestCmd)
if !ok {
return nil, fmt.Errorf("failed to decode payload")
}
// The token is supplied as a JSON-encoded ucan.Token on the command line.
var token ucan.Token
if err := json.Unmarshal([]byte(payload.Token), &token); err != nil {
return nil, err
}
req := &node.CapTokenAnchorRequest{
Token: ucan.TokenList{
Tokens: []*ucan.Token{
&token,
},
},
}
return dmsClient.RevokeCapAnchor(ctx, *req, opts.MsgOpts...)
},
Short: "Anchors revocation tokens on a node",
Long: `Invokes the /dms/cap/revoke/anchor behavior on an actor and anchors a revocation token.
This behavior invokes a node to anchor a revocation token.
Examples:
nunet actor cmd --context user /dms/cap/revoke/anchor --dest <peerID|did> --token <revocation_token>`,
},
// /dms/cap/revoke/broadcast
behaviors.BroadcastRevokeCapBehavior: {
Action: bInvoke,
Payload: func() any { return &CapAnchorRequestCmd{} },
SetFlags: func(cmd *cobra.Command, payload any) {
p := payload.(*CapAnchorRequestCmd)
cmd.Flags().StringVar(&p.Token, "token", "", "add revoke token")
cmd.MarkFlagsOneRequired("token")
},
RunFn: func(ctx context.Context, _ *cli.DmsCLI, dmsClient client.DmsClient, opts actorCmdOptions) (any, error) {
payload, ok := opts.Payload.(*CapAnchorRequestCmd)
if !ok {
return nil, fmt.Errorf("failed to decode payload")
}
// The token is supplied as a JSON-encoded ucan.Token on the command line.
var token ucan.Token
if err := json.Unmarshal([]byte(payload.Token), &token); err != nil {
return nil, err
}
req := &node.CapTokenAnchorRequest{
Token: ucan.TokenList{
Tokens: []*ucan.Token{
&token,
},
},
}
return dmsClient.BroadcastCapRevoke(ctx, *req, opts.MsgOpts...)
},
Short: "Broadcast revocation capability anchors",
Long: `Invokes the /dms/cap/revoke/broadcast behavior on an actor
This behavior broadcasts a revocation token.
Examples:
nunet actor cmd --context user /dms/cap/revoke/broadcast --token <revocation_token>`,
},
// /dms/node/deployment/delete
behaviors.DeploymentDeleteBehavior: {
Action: bInvoke,
Payload: func() any { return &node.DeploymentDeleteRequest{} },
SetFlags: func(cmd *cobra.Command, payload any) {
p := payload.(*node.DeploymentDeleteRequest)
// --id binds to OrchestratorID on the request struct.
cmd.Flags().StringVar(&p.OrchestratorID, "id", "", "deployment id to delete (required)")
_ = cmd.MarkFlagRequired("id")
},
RunFn: func(ctx context.Context, _ *cli.DmsCLI, dmsClient client.DmsClient, opts actorCmdOptions) (any, error) {
req, ok := opts.Payload.(*node.DeploymentDeleteRequest)
if !ok {
return nil, fmt.Errorf("failed to decode payload")
}
return dmsClient.DeploymentDelete(ctx, *req, opts.MsgOpts...)
},
Short: "Delete a specific deployment",
Long: `Invokes the /dms/node/deployment/delete behavior on an actor
This behavior removes a specific deployment by its deployment id.
Examples:
nunet actor cmd --context user /dms/node/deployment/delete --id <deployment-id>`,
},
}
// getDestinationHandle builds the destination actor handle for a contract:
// it derives the actor public key from cDID and the inbox peer ID from the
// contract host DID cHost, then returns the handle JSON-encoded.
//
// All failures now wrap the underlying error with %w so callers can see (and
// errors.Is/As can inspect) the root cause; previously the cause was dropped.
func getDestinationHandle(cDID, cHost string) (string, error) {
	contractDID, err := did.FromString(cDID)
	if err != nil {
		return "", fmt.Errorf("failed to get contract did: %w", err)
	}
	pubKey, err := did.PublicKeyFromDID(contractDID)
	if err != nil {
		return "", fmt.Errorf("failed to get contract did public key: %w", err)
	}
	contractHostDID, err := did.FromString(cHost)
	if err != nil {
		return "", fmt.Errorf("failed to get contract host did: %w", err)
	}
	hostPubKey, err := did.PublicKeyFromDID(contractHostDID)
	if err != nil {
		return "", fmt.Errorf("failed to get contract host public key: %w", err)
	}
	hostPeerID, err := peer.IDFromPublicKey(hostPubKey)
	if err != nil {
		return "", fmt.Errorf("failed to get contract host peer id: %w", err)
	}
	// Fixed wording: was "failed to get create remote handle".
	destination, err := actor.HandleFromPublicKeyWithInboxAddress(pubKey, cDID, hostPeerID.String())
	if err != nil {
		return "", fmt.Errorf("failed to create remote handle: %w", err)
	}
	d, err := json.Marshal(destination)
	if err != nil {
		return "", fmt.Errorf("failed to marshal destination handle: %w", err)
	}
	return string(d), nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package actor
import (
"time"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/client"
"gitlab.com/nunet/device-management-service/cmd/utils"
)
// Flag names shared by the actor message commands.
const (
fnContext = "context" // capability context name
fnDest = "dest" // destination DMS DID, peer ID or handle
fnTimeout = "timeout" // message timeout duration
fnExpiry = "expiry" // absolute message expiration time
fnBroadcast = "broadcast" // pubsub broadcast topic
fnInvoke = "invoke" // construct an invocation message
)
// useMessageOptsFlags registers the flags shared by all actor message
// commands: capability context, destination, timeout and expiry. When
// persistent is true the flags are attached to the command's persistent
// flag set so subcommands inherit them.
func useMessageOptsFlags(cmd *cobra.Command, persistent bool) {
	fs := cmd.Flags()
	if persistent {
		fs = cmd.PersistentFlags()
	}
	fs.StringP(fnContext, "c", "", "capability context name")
	fs.StringP(fnDest, "d", "", "destination DMS DID, peer ID or handle")
	fs.DurationP(fnTimeout, "t", 0, "timeout duration")
	fs.VarP(utils.NewTimeValue(&time.Time{}), fnExpiry, "e", "expiration time")
	// A relative timeout and an absolute expiry both bound the message
	// lifetime, so only one of them may be given.
	cmd.MarkFlagsMutuallyExclusive(fnTimeout, fnExpiry)
}
// useNewMsgOptsFlags registers the common message-option flags plus the two
// flags that only apply when constructing a new message: a broadcast topic
// and the invocation switch.
func useNewMsgOptsFlags(cmd *cobra.Command, persistent bool) {
	useMessageOptsFlags(cmd, persistent)
	fs := cmd.Flags()
	if persistent {
		fs = cmd.PersistentFlags()
	}
	fs.StringP(fnBroadcast, "b", "", "broadcast topic")
	fs.BoolP(fnInvoke, "i", false, "construct an invocation")
	// A broadcast targets a topic rather than a single destination, and a
	// broadcast message cannot also be an invocation.
	cmd.MarkFlagsMutuallyExclusive(fnDest, fnBroadcast)
	cmd.MarkFlagsMutuallyExclusive(fnInvoke, fnBroadcast)
}
// getBehaviorMsgOpts translates the message-option flags registered on cmd
// into client options. A flag that cannot be read (e.g. it was never
// registered on this command) is skipped; flags left at their zero value are
// still forwarded unchanged.
func getBehaviorMsgOpts(cmd *cobra.Command) []client.Option {
	var opts []client.Option
	flags := cmd.Flags()
	dest, err := flags.GetString(fnDest)
	if err == nil {
		opts = append(opts, client.WithDestination(dest))
	}
	timeout, err := flags.GetDuration(fnTimeout)
	if err == nil {
		opts = append(opts, client.WithTimeout(timeout))
	}
	expiry, err := utils.GetTime(flags, fnExpiry)
	if err == nil {
		opts = append(opts, client.WithExpiry(expiry))
	}
	return opts
}
// getNewMsgOpts extends getBehaviorMsgOpts with the broadcast-topic and
// invocation options used when constructing a new message.
func getNewMsgOpts(cmd *cobra.Command) []client.Option {
	opts := getBehaviorMsgOpts(cmd)
	flags := cmd.Flags()
	topic, err := flags.GetString(fnBroadcast)
	if err == nil {
		opts = append(opts, client.WithTopic(topic))
	}
	invocation, err := flags.GetBool(fnInvoke)
	if err == nil {
		opts = append(opts, client.WithInvocation(invocation))
	}
	return opts
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package actor
import (
"context"
"fmt"
"strconv"
"strings"
"gitlab.com/nunet/device-management-service/client"
"gitlab.com/nunet/device-management-service/cmd/cli"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils/convert"
)
// onboardingInput used for command line onboarding parameters input
// onboardingInput used for command line onboarding parameters input
type onboardingInput struct {
NoGPU bool // skip GPU selection entirely
GPUsStr string // raw --gpus argument in the form "index:VRAM,index:VRAM,..."
RAMSize string // RAM size as entered on the command line — units/format not shown here
DiskSize string // disk size as entered on the command line — units/format not shown here
CPUCores float32 // number of CPU cores to onboard
CPUCLock float64 // CPU clock speed, filled in from detected hardware (field name kept as-is for compatibility)
GPUs types.GPUs // resolved GPU selection with VRAM allocations applied
}
// processOnboardInput fills in the machine-derived fields of the onboarding
// payload: the CPU clock speed and, unless disabled, the GPU selection
// (either parsed from the --gpus argument or gathered interactively).
// Returns ErrInvalidArgument when the payload has an unexpected type.
func processOnboardInput(ctx context.Context, dmsClient client.DmsClient, opts actorCmdOptions) error {
	p, ok := opts.Payload.(*onboardingInput)
	if !ok {
		return ErrInvalidArgument
	}
	r, err := dmsClient.HardwareSpec(ctx, opts.MsgOpts...)
	if err != nil {
		// Fixed typo: was "resourcs".
		return fmt.Errorf("could not get machine resources: %w", err)
	}
	if !r.OK {
		// Separate check so we never wrap a nil error with %w (which would
		// render as "%!w(<nil>)").
		return fmt.Errorf("could not get machine resources: response not OK")
	}
	res := r.Resources
	// Set the CPU clock speed from the detected hardware.
	p.CPUCLock = res.CPU.ClockSpeed
	if p.NoGPU {
		fmt.Println("Skipping GPU selection.")
		return nil
	}
	// Handle GPU onboarding.
	//
	// If no GPUs are found, skip GPU selection.
	if len(res.GPUs) == 0 {
		fmt.Println("No usable GPUs detected; prerequisites may not be met. Skipping GPU selection.\n" +
			"Read more: https://gitlab.com/nunet/device-management-service#gpu-machines")
		return nil
	}
	// Non-interactive path: GPUs specified on the command line.
	if p.GPUsStr != "" {
		p.GPUs, err = commandLineGPUOnboarding(res, p.GPUsStr, opts.Streams)
		if err != nil {
			return fmt.Errorf("onboard GPUs: %w", err)
		}
		return nil
	}
	// Interactive GPU onboarding needs current usage to show available VRAM.
	r, err = dmsClient.HardwareUsage(ctx, opts.MsgOpts...)
	if err != nil {
		return fmt.Errorf("could not get machine resource usage: %w", err)
	}
	if !r.OK {
		return fmt.Errorf("could not get machine resource usage: response not OK")
	}
	usage := r.Resources
	p.GPUs, err = interactiveGPUOnboarding(res, usage, opts.Streams)
	if err != nil {
		return fmt.Errorf("interactive GPU onboarding: %w", err)
	}
	return nil
}
// commandLineGPUOnboarding parses the GPU arguments from the command line and allocates VRAM for each selected GPU
// The GPU arguments are in the format "index:VRAM,index:VRAM,..."
// commandLineGPUOnboarding parses the GPU selections supplied on the command
// line in the form "index:VRAM[,index:VRAM...]" and returns the matching GPUs
// from machineResources with the requested VRAM allocation applied. The VRAM
// portion defaults to GiB when no unit is given.
func commandLineGPUOnboarding(machineResources types.Resources, gpuArgs string, _ cli.Streams) (types.GPUs, error) {
	var selected types.GPUs
	for _, entry := range strings.Split(gpuArgs, ",") {
		parts := strings.Split(entry, ":")
		// Each entry must be exactly "index:VRAM".
		if len(parts) != 2 {
			return nil, fmt.Errorf("invalid GPU format: %s", entry)
		}
		idx, err := strconv.Atoi(parts[0])
		if err != nil {
			return nil, fmt.Errorf("invalid GPU index: %w", err)
		}
		gpu, err := machineResources.GPUs.GetWithIndex(idx)
		if err != nil {
			return nil, fmt.Errorf("invalid GPU index: %w", err)
		}
		vram, err := convert.ParseBytesWithDefaultUnit(parts[1], "GiB")
		if err != nil {
			return nil, fmt.Errorf("invalid GPU VRAM: %w", err)
		}
		gpu.VRAM = vram
		selected = append(selected, gpu)
	}
	return selected, nil
}
// interactiveGPUOnboarding prompts the user to select GPUs and allocate VRAM for each selected GPU
// interactiveGPUOnboarding prompts the user to select GPUs and allocate VRAM
// for each selected GPU. All informational output is written to streams.Out
// (previously it went straight to os.Stdout via fmt.Printf, bypassing the
// injected streams that the prompts themselves already use).
//
// NOTE(review): gpuMap is keyed by GPU model, so two GPUs with the same model
// string would collide — confirm models are unique per machine.
func interactiveGPUOnboarding(machineResources types.Resources, machineResourceUsage types.Resources, streams cli.Streams) (types.GPUs, error) {
	var (
		gpuMap         = make(map[string]types.GPU)
		gpuPromptItems = make([]*selectPromptItem, 0, len(machineResources.GPUs))
		selectedGPUs   types.GPUs
	)
	for _, gpu := range machineResources.GPUs {
		gpuMap[gpu.Model] = gpu
		gpuPromptItems = append(gpuPromptItems, &selectPromptItem{
			Label: gpu.Model,
		})
	}
	// Prompt for GPU selection.
	res, err := selectPromptMultiple("Select GPU", gpuPromptItems, streams)
	if err != nil {
		return nil, fmt.Errorf("could not select GPU: %w", err)
	}
	// Validate VRAM input: must parse as a number.
	vramValidator := func(input string) error {
		if _, err := strconv.ParseFloat(input, 64); err != nil {
			return fmt.Errorf("invalid input: %w", err)
		}
		return nil
	}
	// Update the VRAM allocation for each selected GPU.
	for _, gpuName := range res {
		gpu := gpuMap[gpuName]
		fmt.Fprintf(streams.Out, "-----------------------------------\n")
		fmt.Fprintf(streams.Out, "Selected GPU: %s\n", gpuName)
		fmt.Fprintf(streams.Out, "Total VRAM: %d GB\n", gpu.VRAMInGB())
		gpuUsage, err := machineResourceUsage.GPUs.GetWithIndex(gpu.Index)
		if err != nil {
			return nil, fmt.Errorf("could not get GPU usage: %w", err)
		}
		fmt.Fprintf(streams.Out, "Used VRAM: %d GB\n", gpuUsage.VRAMInGB())
		fmt.Fprintf(streams.Out, "Available VRAM: %d GB\n", types.ConvertBytesToGB(gpu.VRAM-gpuUsage.VRAM))
		// Prompt for VRAM allocation.
		input, err := prompt("Enter new VRAM allocation in GB", vramValidator, streams)
		if err != nil {
			return nil, fmt.Errorf("could not prompt for VRAM: %w", err)
		}
		fmt.Fprintf(streams.Out, "-----------------------------------\n")
		// Update the GPU with the new VRAM allocation.
		vram, err := strconv.ParseUint(input, 10, 64)
		if err != nil {
			return nil, fmt.Errorf("could not parse VRAM: %w", err)
		}
		gpu.VRAM = types.ConvertGBToBytes(vram)
		selectedGPUs = append(selectedGPUs, gpu)
	}
	return selectedGPUs, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package actor
import (
"fmt"
"io"
"github.com/manifoldco/promptui"
"gitlab.com/nunet/device-management-service/cmd/cli"
)
// selectPromptItem is a single entry in a selection prompt. Selected tracks
// whether the user has toggled the entry in multiple-selection mode.
type selectPromptItem struct {
Label string
Selected bool
}
// writerCloser adapts a plain io.Writer to io.WriteCloser with a no-op Close,
// so an output stream can be handed to promptui's Stdout field.
type writerCloser struct {
io.Writer
}
// Close implements io.Closer and does nothing; the underlying writer is not
// owned by this adapter.
func (wc *writerCloser) Close() error {
return nil
}
// runSelectPrompt renders a promptui selection list over the given streams
// and returns the labels of all items marked Selected when the prompt ends.
//
// In multiple-selection mode a synthetic "Done" entry is prepended; picking
// an item toggles its mark and the prompt loops until "Done" is chosen. In
// single-selection mode the first pick ends the prompt. Previously the
// single-select pick was never marked Selected, so the function always
// returned an empty result in that mode; it now returns the picked label.
func runSelectPrompt(label string, items []*selectPromptItem, multiple bool, streams cli.Streams) ([]string, error) {
	const doneLabel = "Done"
	if multiple && len(items) > 0 && items[0].Label != doneLabel {
		items = append([]*selectPromptItem{{Label: doneLabel}}, items...)
	}
	template := &promptui.SelectTemplates{
		Label:    "{{ .Label }}",
		Active:   "{{ if .Selected }}{{ \"✓\" | green }} {{ end }}→ {{ .Label | cyan | bold }}",
		Inactive: "{{ if .Selected }}{{ \"✓\" | green }} {{ end }} {{ .Label | faint }}",
		Selected: "{{ .Label | green | bold }}",
	}
	p := promptui.Select{
		Label:        label,
		Items:        items,
		Templates:    template,
		HideSelected: true,
		Size:         len(items),
		Stdin:        io.NopCloser(streams.In),
		Stdout:       &writerCloser{streams.Out},
	}
	for done := false; !done; {
		index, _, err := p.Run()
		if err != nil {
			return nil, fmt.Errorf("prompt failed: %w", err)
		}
		picked := items[index]
		switch {
		case !multiple:
			// Single-select: record the choice and finish.
			picked.Selected = true
			done = true
		case picked.Label == doneLabel:
			done = true
		default:
			picked.Selected = !picked.Selected
		}
	}
	var selected []string
	for _, item := range items {
		if item.Selected {
			selected = append(selected, item.Label)
		}
	}
	return selected, nil
}
// selectPromptMultiple runs a multi-select prompt over items and returns the
// labels of all toggled entries. Convenience wrapper for runSelectPrompt with
// multiple=true.
func selectPromptMultiple(label string, items []*selectPromptItem, streams cli.Streams) ([]string, error) {
	return runSelectPrompt(label, items, true, streams)
}
// prompt asks the user for a single line of input, validating it with the
// supplied validate function, and returns the entered value. Input and output
// are bound to the provided CLI streams.
func prompt(label string, validate func(string) error, streams cli.Streams) (string, error) {
	input := promptui.Prompt{
		Label: label,
		Templates: &promptui.PromptTemplates{
			Valid: "{{ \"✓\" | green }} {{ . }} {{ \":\" | bold}} ",
		},
		Validate: validate,
		Stdin:    io.NopCloser(streams.In),
		Stdout:   &writerCloser{streams.Out},
	}
	return input.Run()
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package actor
import (
"encoding/json"
"fmt"
"io"
"strconv"
"strings"
"time"
"github.com/spf13/afero"
"gitlab.com/nunet/device-management-service/dms/jobs/parser"
jobtypes "gitlab.com/nunet/device-management-service/dms/jobs/types"
"gitlab.com/nunet/device-management-service/lib/env"
)
func displayResponse(w io.Writer, resp any) error {
encoder := json.NewEncoder(w)
encoder.SetIndent("", " ")
return encoder.Encode(resp)
}
// ProcessEnsembleYaml reads the ensemble config file at path and decodes it
// into an EnsembleConfig using the v1 ensemble spec parser. The provided
// environment and filesystem are passed through to the parser.
func ProcessEnsembleYaml(fs afero.Afero, env env.EnvironmentProvider, path string) (
	*jobtypes.EnsembleConfig, error,
) {
	raw, err := fs.ReadFile(path)
	if err != nil {
		return nil, fmt.Errorf("failed to read config file: %w", err)
	}
	opts := &parser.Options{
		Env:        env,
		Fs:         fs,
		WorkingDir: "",
	}
	cfg := &jobtypes.EnsembleConfig{}
	if err := parser.Decode(parser.SpecTypeEnsembleV1, raw, &cfg, opts); err != nil {
		return nil, err
	}
	return cfg, nil
}
// parseDateString parses a date string supporting both relative formats (e.g., "7d", "12h") and absolute formats (RFC3339, common date formats)
func parseDateString(dateStr string) (time.Time, error) {
dateStr = strings.TrimSpace(dateStr)
if dateStr == "" {
return time.Time{}, fmt.Errorf("empty date string")
}
// Try relative formats first
if strings.HasSuffix(dateStr, "d") {
// days is not a standard Go duration; handle explicitly
daysStr := strings.TrimSuffix(dateStr, "d")
if daysStr == "" {
return time.Time{}, fmt.Errorf("invalid date duration: %s", dateStr)
}
if nDays, err := strconv.Atoi(daysStr); err == nil && nDays > 0 {
return time.Now().AddDate(0, 0, -nDays), nil
}
return time.Time{}, fmt.Errorf("invalid date duration days: %s", dateStr)
}
// Try standard duration formats
if dur, err := time.ParseDuration(dateStr); err == nil {
return time.Now().Add(-dur), nil
}
// Try datetime formats
var parseErr error
for _, layout := range []string{time.RFC3339, "2006-01-02 15:04:05", "2006-01-02"} {
t, err := time.Parse(layout, dateStr)
if err == nil {
return t, nil
}
parseErr = err
}
return time.Time{}, fmt.Errorf("invalid date format: %w", parseErr)
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package cap
import (
"context"
"encoding/json"
"fmt"
"os"
"os/exec"
"strconv"
"strings"
"syscall"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/cmd/cli"
"gitlab.com/nunet/device-management-service/cmd/utils"
"gitlab.com/nunet/device-management-service/internal/config"
"gitlab.com/nunet/device-management-service/lib/did"
"gitlab.com/nunet/device-management-service/lib/ucan"
)
// AnchorOptions holds the command-line options for the anchor command.
type AnchorOptions struct {
	Context  string // name of the capability context to modify (required)
	Root     string // DID to add as a root anchor
	Provide  string // JSON token list to add as provide anchors
	Require  string // JSON token list to add as require anchors
	Revoke   string // JSON token to add as a revoke anchor
	PrismURL string // PRISM resolver URL; falls back to PRISM_RESOLVER_URL env var
}
// newAnchorCmd builds the "cap anchor" cobra command, which adds or modifies
// capability anchors (root, require, provide, revoke) in a capability context.
// Exactly one anchor flag must be supplied; the flags are mutually exclusive.
func newAnchorCmd(dmsCLI *cli.DmsCLI) *cobra.Command {
	var opts AnchorOptions
	cmd := &cobra.Command{
		Use:   "anchor",
		Short: "Manage capability anchors",
		Long: `Add or modify capability anchors in a capability context.
An anchor is a basis of trust in the capability system. There are three types of anchors:
1. Root anchor: Represents absolute trust or effective root capability.
Use the --root flag with a DID value to add a root anchor.
2. Require anchor: Represents input trust. We verify incoming messages based on the require anchor.
Use the --require flag with a token to add a require anchor.
3. Provide anchor: Represents output trust. We emit invocation tokens based on our provide anchors and sign output.
Use the --provide flag with a token to add a provide anchor.
Only one type of anchor can be added or modified per command execution.
Usage examples:
nunet cap anchor --context user --root did:example:123456789abcdefghi
nunet cap anchor --context dms --require '{"some": "json", "token": "here"}'
nunet cap anchor --context user --provide '{"another": "json", "token": "example"}'
nunet cap anchor --context user --revoke '{"another": "revocation", "token": "example"}'
Note: The --context flag is required to specify the capability context.`,
		RunE: func(cmd *cobra.Command, _ []string) error {
			return runAnchorCmd(cmd.Context(), dmsCLI, opts, cli.CmdStreams(cmd))
		},
	}
	useFlagContext(cmd, &opts.Context)
	useFlagRoot(cmd, &opts.Root)
	useFlagRequire(cmd, &opts.Require)
	useFlagProvide(cmd, &opts.Provide)
	useFlagRevoke(cmd, &opts.Revoke)
	cmd.Flags().StringVar(&opts.PrismURL, "prism-url", "", "PRISM resolver URL (e.g., http://localhost:8080). Required when adding require/provide/revoke anchors with PRISM DIDs. Can also be set via PRISM_RESOLVER_URL environment variable.")
	// Context is mandatory; exactly one of the anchor flags must be given.
	_ = cmd.MarkFlagRequired(fnContext)
	cmd.MarkFlagsOneRequired(fnProvide, fnRoot, fnRequire, fnRevoke)
	cmd.MarkFlagsMutuallyExclusive(fnProvide, fnRoot, fnRequire, fnRevoke)
	return cmd
}
// runAnchorCmd loads the capability context named in opts, applies exactly one
// anchor mutation (root, require, provide, or revoke), persists the context,
// and best-effort signals a running DMS to reload its capability contexts.
//
// If a PRISM resolver URL is supplied (flag, or the PRISM_RESOLVER_URL
// environment variable) it is installed for the duration of the call and the
// previous resolver config is restored on return.
func runAnchorCmd(_ context.Context, dmsCLI *cli.DmsCLI, opts AnchorOptions, streams cli.Streams) error {
	capCtx, err := utils.LoadCapabilityContext(dmsCLI, opts.Context)
	if err != nil {
		return err
	}
	// Configure PRISM resolver if URL is provided (needed for require/provide/revoke with PRISM DIDs)
	// Check flag first, then environment variable
	prismURL := opts.PrismURL
	if prismURL == "" {
		prismURL = dmsCLI.Env().Getenv("PRISM_RESOLVER_URL")
	}
	originalConfig := did.GetPRISMResolverConfig()
	if prismURL != "" {
		did.SetPRISMResolverConfig(did.PRISMResolverConfig{
			ResolverURL:                 prismURL,
			PreferredVerificationMethod: originalConfig.PreferredVerificationMethod,
			HTTPClient:                  originalConfig.HTTPClient,
		})
		// Restore the process-wide resolver config once we're done.
		defer did.SetPRISMResolverConfig(originalConfig)
	}
	switch {
	case opts.Root != "":
		rootDID, err := did.FromString(opts.Root)
		if err != nil {
			return fmt.Errorf("invalid root DID: %w", err)
		}
		if err := capCtx.AddRoots([]did.DID{rootDID}, ucan.TokenList{}, ucan.TokenList{}, ucan.TokenList{}); err != nil {
			return fmt.Errorf("failed to add root anchors: %w", err)
		}
	case opts.Require != "":
		var tokens ucan.TokenList
		if err := json.Unmarshal([]byte(opts.Require), &tokens); err != nil {
			return fmt.Errorf("unmarshal tokens: %w", err)
		}
		if err := capCtx.AddRoots(nil, tokens, ucan.TokenList{}, ucan.TokenList{}); err != nil {
			return fmt.Errorf("failed to add require anchors: %w", err)
		}
	case opts.Provide != "":
		var tokens ucan.TokenList
		if err := json.Unmarshal([]byte(opts.Provide), &tokens); err != nil {
			return fmt.Errorf("unmarshal tokens: %w", err)
		}
		if err := capCtx.AddRoots(nil, ucan.TokenList{}, tokens, ucan.TokenList{}); err != nil {
			return fmt.Errorf("failed to add provide anchors: %w", err)
		}
	case opts.Revoke != "":
		var token ucan.Token
		if err := json.Unmarshal([]byte(opts.Revoke), &token); err != nil {
			return fmt.Errorf("unmarshal tokens: %w", err)
		}
		if err := capCtx.AddRoots(nil, ucan.TokenList{}, ucan.TokenList{}, ucan.TokenList{Tokens: []*ucan.Token{&token}}); err != nil {
			return fmt.Errorf("failed to add revoke anchors: %w", err)
		}
	default:
		// Unreachable when invoked through cobra (MarkFlagsOneRequired), but
		// kept as a guard for direct callers.
		return fmt.Errorf("one of --provide, --root, --require, or --revoke must be specified")
	}
	if err := utils.SaveCapabilityContext(dmsCLI, capCtx); err != nil {
		return err
	}
	// Send SIGUSR1 to running DMS to reload contexts
	if err := signalDMSReload(dmsCLI); err != nil {
		// Log the error but don't fail - DMS might not be running (expected during initial setup)
		fmt.Fprintf(os.Stderr, "Warning: Could not signal DMS to reload (DMS may not be running): %v\n", err)
	} else {
		// BUG FIX: write to the command's output stream instead of the
		// process-wide stdout; the streams argument was previously ignored.
		fmt.Fprintln(streams.Out, "Successfully signaled DMS to reload capability contexts")
	}
	return nil
}
// signalDMSReload sends SIGUSR1 to the running DMS process so it reloads its
// capability contexts. The DMS process is located by finding whatever is
// listening on the configured REST port (via lsof, falling back to ss).
// A DMS that is not running is treated as success (nil error), since that is
// expected during initial setup; only genuinely unexpected failures (bad PID,
// signal delivery errors) are returned.
func signalDMSReload(dmsCLI *cli.DmsCLI) error {
	// Get DMS config to find the port
	cfg, err := dmsCLI.Config()
	if err != nil {
		// If we can't load config, fall back to default port
		cfg = &config.Config{}
		cfg.Rest.Port = 9999
	}
	port := cfg.Rest.Port
	if port == 0 {
		port = 9999 // default
	}
	// Find process listening on the configured port
	var pidBytes []byte
	// Try lsof first (more widely available)
	pidBytes, err = exec.Command("sh", "-c", fmt.Sprintf("lsof -ti :%d -sTCP:LISTEN", port)).Output()
	if err != nil {
		// Try ss as fallback (on systems without lsof)
		pidBytes, err = exec.Command("sh", "-c", fmt.Sprintf("ss -tlnp | grep :%d | grep -oP 'pid=\\K[0-9]+'", port)).Output()
		if err != nil {
			// DMS not running - this is OK during initial setup
			return nil
		}
	}
	pidStr := strings.TrimSpace(string(pidBytes))
	if pidStr == "" {
		// No process listening - DMS not running
		return nil
	}
	// Handle multiple PIDs (take the first one)
	pids := strings.Split(pidStr, "\n")
	pid, err := strconv.Atoi(pids[0])
	if err != nil {
		return fmt.Errorf("invalid PID: %w", err)
	}
	// NOTE(review): on Unix os.FindProcess never fails; the error path exists
	// for portability.
	process, err := os.FindProcess(pid)
	if err != nil {
		return fmt.Errorf("failed to find process: %w", err)
	}
	// Send SIGUSR1
	if err := process.Signal(syscall.SIGUSR1); err != nil {
		return fmt.Errorf("failed to send SIGUSR1 signal: %w", err)
	}
	return nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package cap
import (
"context"
"encoding/json"
"fmt"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/cmd/cli"
"gitlab.com/nunet/device-management-service/cmd/utils"
"gitlab.com/nunet/device-management-service/dms/behaviors"
"gitlab.com/nunet/device-management-service/dms/node"
)
// BroadcastOptions holds the command-line options for the broadcast command.
type BroadcastOptions struct {
	Context string // capability context whose revocation tokens are broadcast
}
// newBroadcastCmd builds the "cap broadcast" cobra command, which broadcasts
// all revocation tokens of a capability context to connected peers.
func newBroadcastCmd(dmsCLI *cli.DmsCLI) *cobra.Command {
	var opts BroadcastOptions
	cmd := &cobra.Command{
		Use:   "broadcast",
		Short: "Broadcast revocation tokens to all peers",
		Long: `Broadcast all revocation tokens from a capability context to all peers in the network.
This command retrieves all revocation tokens from the specified capability context
and broadcasts them to all connected peers via the /nunet/revocation topic.
Each peer will receive the revocation tokens and update their local capability contexts.
The --context flag specifies which capability context to broadcast revocations from.
If not specified, the DMS context is used by default.
Usage examples:
nunet cap broadcast --context dms
nunet cap broadcast --context org`,
		RunE: func(cmd *cobra.Command, _ []string) error {
			return runBroadcastCmd(cmd.Context(), dmsCLI, opts)
		},
	}
	useFlagContext(cmd, &opts.Context)
	return cmd
}
// runBroadcastCmd invokes the broadcast-revocation behavior on the local DMS
// node via a client authenticated with the named capability context, then
// reports how many tokens were broadcast. Returns an error if the client
// cannot be created, the invocation fails, or the node reports failure.
func runBroadcastCmd(ctx context.Context, dmsCLI *cli.DmsCLI, opts BroadcastOptions) error {
	// Create security context for the specified capability context
	sctx, err := utils.NewSecurityContext(dmsCLI, opts.Context)
	if err != nil {
		return fmt.Errorf("failed to create security context: %w", err)
	}
	// Get DMS client with the security context
	client, err := dmsCLI.NewClient(sctx)
	if err != nil {
		return fmt.Errorf("failed to create DMS client: %w", err)
	}
	// Prepare the broadcast request
	req := node.CapBroadcastRequest{
		Context: opts.Context,
	}
	// Invoke the /dms/cap/broadcast behavior
	respEnvelope, err := client.InvokeBehavior(ctx, behaviors.BroadcastRevokeCapBehavior, req)
	if err != nil {
		return fmt.Errorf("failed to broadcast revocation tokens: %w", err)
	}
	var resp node.CapBroadcastResponse
	if err := json.Unmarshal(respEnvelope.Message, &resp); err != nil {
		return fmt.Errorf("failed to unmarshal response: %w", err)
	}
	if !resp.OK {
		return fmt.Errorf("broadcast failed: %s", resp.Error)
	}
	fmt.Printf("Successfully broadcast %d revocation token(s) to all peers\n", resp.TokensCount)
	return nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package cap
import (
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/cmd/cli"
)
// Flag names shared across the cap subcommands, declared once so commands
// can reference them consistently (e.g. in MarkFlagRequired calls).
const (
	fnContext    = "context"
	fnAudience   = "audience"
	fnAction     = "action"
	fnCap        = "cap"
	fnTopic      = "topic"
	fnExpiry     = "expiry"
	fnDuration   = "duration"
	fnAutoExpire = "auto-expire"
	fnSelfSign   = "self-sign"
	fnDepth      = "depth"
	fnProvide    = "provide"
	fnRevoke     = "revoke"
	fnRoot       = "root"
	fnRequire    = "require"
	fnForce      = "force"
	fnToken      = "token"
)
// NewCapCmd returns the parent "cap" command with every capability
// subcommand attached.
func NewCapCmd(dmsCli *cli.DmsCLI) *cobra.Command {
	capCmd := &cobra.Command{
		Use:   "cap",
		Short: "Manage capabilities",
		Long:  `Manage capabilities for the Device Management Service`,
	}
	capCmd.AddCommand(
		newGrantCmd(dmsCli),
		newAnchorCmd(dmsCli),
		newRevokeCmd(dmsCli),
		newNewCmd(dmsCli),
		newDelegateCmd(dmsCli),
		newListCmd(dmsCli),
		newRemoveCmd(dmsCli),
		newBroadcastCmd(dmsCli),
		newHelpCmd(),
	)
	return capCmd
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package cap
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/cmd/cli"
"gitlab.com/nunet/device-management-service/cmd/utils"
"gitlab.com/nunet/device-management-service/lib/did"
"gitlab.com/nunet/device-management-service/lib/ucan"
)
// DelegateCapOptions holds the command-line options for the delegate command.
type DelegateCapOptions struct {
	Context    string        // capability context to delegate from (required)
	Caps       []string      // capabilities to delegate
	Topics     []string      // topics to delegate
	Audience   string        // optional audience DID
	Expiry     time.Time     // absolute expiration time (mutually exclusive with Duration/AutoExpire)
	Duration   time.Duration // relative expiration (mutually exclusive with Expiry/AutoExpire)
	AutoExpire bool          // let the system pick expiration (mutually exclusive with the two above)
	Depth      uint64        // delegation depth
	SelfSign   string        // self-sign mode: "no", "also", or "only"
	Subject    string        // subject DID (positional argument)
}
// newDelegateCmd builds the "cap delegate" cobra command, which delegates
// capabilities from a context's provide anchors to a subject DID. Exactly one
// of --expiry, --duration, or --auto-expire must be given.
func newDelegateCmd(dmsCLI *cli.DmsCLI) *cobra.Command {
	var opts DelegateCapOptions
	cmd := &cobra.Command{
		Use:   "delegate <did>",
		Short: "Delegate capabilities",
		Long: `Delegate capabilities to a subject
Capabilities are delegated based on provide anchors. No capabilities are delegated by default, you need to use --cap flag to explicitly specify the capabilities to delegate.
Example:
nunet cap anchor --context user --provide '<token>'
nunet cap delegate --context user --cap /public --duration 1h did:key:<some-key>`,
		Args: cobra.ExactArgs(1),
		RunE: func(cmd *cobra.Command, args []string) error {
			// The single positional argument is the subject DID.
			opts.Subject = args[0]
			return runDelegateCap(cmd.Context(), dmsCLI, opts, cli.CmdStreams(cmd))
		},
	}
	useFlagContext(cmd, &opts.Context)
	useFlagAudience(cmd, &opts.Audience)
	useFlagCap(cmd, &opts.Caps)
	useFlagTopic(cmd, &opts.Topics)
	useFlagExpiry(cmd, &opts.Expiry)
	useFlagDuration(cmd, &opts.Duration)
	useFlagAutoExpire(cmd, &opts.AutoExpire)
	useFlagDepth(cmd, &opts.Depth)
	cmd.Flags().StringVar(&opts.SelfSign, fnSelfSign, "no", "Self-sign option: 'no', 'also', or 'only'")
	// Context is mandatory; exactly one expiration mode must be chosen, and
	// self-signing is incompatible with auto-expiration.
	_ = cmd.MarkFlagRequired(fnContext)
	cmd.MarkFlagsOneRequired(fnExpiry, fnDuration, fnAutoExpire)
	cmd.MarkFlagsMutuallyExclusive(fnExpiry, fnDuration, fnAutoExpire)
	cmd.MarkFlagsMutuallyExclusive(fnSelfSign, fnAutoExpire)
	return cmd
}
// runDelegateCap validates the delegate options, resolves the subject (and
// optional audience) DIDs, and emits delegation tokens from the named
// capability context's provide anchors, printing them as JSON to streams.Out.
//
// Expiration is encoded as nanoseconds since the Unix epoch; 0 means
// auto-expire (the context decides).
func runDelegateCap(_ context.Context, dmsCLI *cli.DmsCLI, opts DelegateCapOptions, streams cli.Streams) error {
	var expirationTime uint64
	switch {
	case !opts.Expiry.IsZero():
		expirationTime = uint64(opts.Expiry.UnixNano())
	case opts.Duration != 0:
		expirationTime = uint64(time.Now().Add(opts.Duration).UnixNano())
	case opts.AutoExpire:
		// Zero signals auto-expiration to the capability context.
		expirationTime = 0
	default:
		return fmt.Errorf("either expiration or duration must be specified")
	}
	subjectDID, err := did.FromString(opts.Subject)
	if err != nil {
		return fmt.Errorf("invalid subject DID: %w", err)
	}
	var audienceDID did.DID
	if opts.Audience != "" {
		audienceDID, err = did.FromString(opts.Audience)
		if err != nil {
			return fmt.Errorf("invalid audience DID: %w", err)
		}
	}
	capabilities := make([]ucan.Capability, len(opts.Caps))
	// Renamed loop variable: "cap" shadowed the builtin cap().
	for i, c := range opts.Caps {
		capabilities[i] = ucan.Capability(c)
	}
	var selfSignMode ucan.SelfSignMode
	switch opts.SelfSign {
	case "no":
		selfSignMode = ucan.SelfSignNo
	case "also":
		selfSignMode = ucan.SelfSignAlso
	case "only":
		selfSignMode = ucan.SelfSignOnly
	default:
		return fmt.Errorf("invalid self-sign option: %s", opts.SelfSign)
	}
	capCtx, err := utils.LoadCapabilityContext(dmsCLI, opts.Context)
	if err != nil {
		return err
	}
	tokens, err := capCtx.Delegate(subjectDID, audienceDID, opts.Topics, expirationTime, opts.Depth, capabilities, selfSignMode)
	if err != nil {
		return fmt.Errorf("failed to delegate capabilities: %w", err)
	}
	tokensJSON, err := json.Marshal(tokens)
	if err != nil {
		return fmt.Errorf("unable to marshal tokens to json: %w", err)
	}
	fmt.Fprintln(streams.Out, string(tokensJSON))
	return nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package cap
import (
"time"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/cmd/utils"
"gitlab.com/nunet/device-management-service/dms/node"
)
// useFlagContext registers the --context/-c flag selecting the capability
// context; defaults to the user context.
func useFlagContext(cmd *cobra.Command, context *string) {
	cmd.Flags().StringVarP(context, fnContext, "c", node.UserContextName, "specifies capability context")
}

// useFlagAudience registers the optional --audience/-a flag (audience DID).
func useFlagAudience(cmd *cobra.Command, audience *string) {
	cmd.Flags().StringVarP(audience, fnAudience, "a", "", "audience DID (optional)")
}

// useFlagCap registers the repeatable --cap flag listing capabilities.
func useFlagCap(cmd *cobra.Command, caps *[]string) {
	cmd.Flags().StringSliceVar(caps, fnCap, []string{}, "capabilities to grant/delegate (can be specified multiple times)")
}

// useFlagTopic registers the repeatable --topic/-t flag listing topics.
func useFlagTopic(cmd *cobra.Command, topics *[]string) {
	cmd.Flags().StringSliceVarP(topics, fnTopic, "t", []string{}, "topics to grant/delegate (can be specified multiple times)")
}

// useFlagExpiry registers the --expiry/-e flag (ISO 8601 expiration time).
func useFlagExpiry(cmd *cobra.Command, expiry *time.Time) {
	cmd.Flags().VarP(utils.NewTimeValue(expiry), fnExpiry, "e", "set expiration date (ISO 8601 format)")
}

// useFlagDuration registers the --duration flag (relative expiration).
func useFlagDuration(cmd *cobra.Command, duration *time.Duration) {
	cmd.Flags().DurationVar(duration, fnDuration, 0, "set duration time (specify unit)")
}

// useFlagAutoExpire registers the --auto-expire boolean flag.
func useFlagAutoExpire(cmd *cobra.Command, autoExpire *bool) {
	cmd.Flags().BoolVar(autoExpire, fnAutoExpire, false, "set auto expiration")
}

// useFlagDepth registers the --depth/-d flag (delegation depth).
func useFlagDepth(cmd *cobra.Command, depth *uint64) {
	cmd.Flags().Uint64VarP(depth, fnDepth, "d", 0, "delegation depth (optional, default=0)")
}

// useFlagRoot registers the --root flag (DID added as a root anchor).
func useFlagRoot(cmd *cobra.Command, root *string) {
	cmd.Flags().StringVar(root, fnRoot, "", "DID to add as root anchor (represents absolute trust)")
}

// useFlagRequire registers the --require flag (JWT added as a require anchor).
func useFlagRequire(cmd *cobra.Command, require *string) {
	cmd.Flags().StringVar(require, fnRequire, "", "JWT to add as require anchor (for input trust)")
}

// useFlagProvide registers the --provide flag (JWT added as a provide anchor).
func useFlagProvide(cmd *cobra.Command, provide *string) {
	cmd.Flags().StringVar(provide, fnProvide, "", "JWT to add as provide anchor (for output trust)")
}

// useFlagRevoke registers the --revoke flag (JWT added as a revoke anchor).
func useFlagRevoke(cmd *cobra.Command, revoke *string) {
	cmd.Flags().StringVar(revoke, fnRevoke, "", "JWT to add as revoke anchor (for output trust)")
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package cap
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/cmd/cli"
"gitlab.com/nunet/device-management-service/cmd/utils"
"gitlab.com/nunet/device-management-service/lib/did"
"gitlab.com/nunet/device-management-service/lib/ucan"
)
// grantOptions holds the command-line options for the grant command.
type grantOptions struct {
	Context  string        // capability context to grant from (required)
	Caps     []string      // capabilities to grant
	Topics   []string      // topics to grant
	Audience string        // optional audience DID
	Expiry   time.Time     // absolute expiration time (mutually exclusive with Duration)
	Duration time.Duration // relative expiration (mutually exclusive with Expiry)
	Depth    uint64        // delegation depth
	Subject  string        // subject DID (positional argument)
}
// newGrantCmd builds the "cap grant" cobra command, which issues a self-signed
// token delegating capabilities to a subject DID. One of --expiry or
// --duration must be given.
func newGrantCmd(dmsCLI *cli.DmsCLI) *cobra.Command {
	var opts grantOptions
	cmd := &cobra.Command{
		Use:   "grant <did>",
		Short: "Grant capabilities",
		// NOTE: fixed grammar in the help text ("a anchor" -> "an anchor").
		Long: `Grant a self-sign token delegating capabilities
It is not necessary to set up an anchor before granting a capability because this operation is self-signed.
Example:
nunet cap grant --context user --cap /public --duration 1h did:key:<some-key>
`,
		Args: cobra.ExactArgs(1),
		RunE: func(cmd *cobra.Command, args []string) error {
			// The single positional argument is the subject DID.
			opts.Subject = args[0]
			return runGrantCmd(cmd.Context(), dmsCLI, opts, cli.CmdStreams(cmd))
		},
	}
	useFlagContext(cmd, &opts.Context)
	useFlagCap(cmd, &opts.Caps)
	useFlagTopic(cmd, &opts.Topics)
	useFlagAudience(cmd, &opts.Audience)
	useFlagExpiry(cmd, &opts.Expiry)
	useFlagDuration(cmd, &opts.Duration)
	useFlagDepth(cmd, &opts.Depth)
	// Context is mandatory; exactly one expiration mode must be chosen.
	_ = cmd.MarkFlagRequired(fnContext)
	cmd.MarkFlagsOneRequired(fnExpiry, fnDuration)
	return cmd
}
// runGrantCmd validates the grant options, resolves the subject (and optional
// audience) DIDs, and issues a self-signed delegation token from the named
// capability context, printing the tokens as JSON to streams.Out.
//
// Expiration is encoded as nanoseconds since the Unix epoch.
func runGrantCmd(_ context.Context, dmsCLI *cli.DmsCLI, opts grantOptions, streams cli.Streams) error {
	var expirationTime uint64
	switch {
	case !opts.Expiry.IsZero():
		expirationTime = uint64(opts.Expiry.UnixNano())
	case opts.Duration != 0:
		expirationTime = uint64(time.Now().Add(opts.Duration).UnixNano())
	default:
		return fmt.Errorf("either expiration or duration must be specified")
	}
	subjectDID, err := did.FromString(opts.Subject)
	if err != nil {
		return fmt.Errorf("invalid subject DID: %w", err)
	}
	var audienceDID did.DID
	if opts.Audience != "" {
		audienceDID, err = did.FromString(opts.Audience)
		if err != nil {
			return fmt.Errorf("invalid audience DID: %w", err)
		}
	}
	capabilities := make([]ucan.Capability, len(opts.Caps))
	// Renamed loop variable: "cap" shadowed the builtin cap().
	for i, c := range opts.Caps {
		capabilities[i] = ucan.Capability(c)
	}
	capCtx, err := utils.LoadCapabilityContext(dmsCLI, opts.Context)
	if err != nil {
		return err
	}
	tokens, err := capCtx.Grant(ucan.Delegate, subjectDID, audienceDID, opts.Topics, expirationTime, opts.Depth, capabilities)
	if err != nil {
		return fmt.Errorf("failed to grant capabilities: %w", err)
	}
	tokensJSON, err := json.Marshal(tokens)
	if err != nil {
		return fmt.Errorf("unable to marshal tokens to json: %w", err)
	}
	fmt.Fprintln(streams.Out, string(tokensJSON))
	return nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package cap
import (
"fmt"
"github.com/spf13/cobra"
)
// behaviorsAndCapsHelp is the reference text printed by "cap help": every
// implemented behavior path with a one-line description. Keep in sync with
// the behaviors registered by the node.
var behaviorsAndCapsHelp = `
The following are the implemented behaviors and their associated capabilities:
/dms/node/peers/ping
- PeerPingBehavior: Ping a peer to check if it is alive.
/dms/node/peers/list
- PeersListBehavior: List peers visible to the node.
/dms/node/peers/self
- PeerSelfBehavior: Get the peer id and listening address of the node.
/dms/node/peers/dht
- PeerDHTBehavior: Get the peers in DHT of the node along with their DHT parameters.
/dms/node/peers/connect
- PeerConnectBehavior: Connect to a peer.
/dms/node/peers/score
- PeerScoreBehavior: Get the libp2p pubsub peer score of peers.
/dms/node/onboarding/onboard
- OnboardBehavior: Onboard the node as a compute provider.
/dms/node/onboarding/offboard
- OffboardBehavior: Offboard the node as a compute provider.
/dms/node/onboarding/status
- OnboardStatusBehavior: Get the onboarding status.
/dms/node/deployment/new
- NewDeploymentBehavior: Start a new deployment on the node.
/dms/node/deployment/list
- DeploymentListBehavior: List all the deployments orchestrated by the node.
/dms/node/deployment/logs
- DeploymentLogsBehavior: Get the logs of a particular deployment.
/dms/node/deployment/status
- DeploymentStatusBehavior: Get the status of a deployment.
/dms/node/deployment/manifest
- DeploymentManifestBehavior: Get the manifest of a deployment.
/dms/node/deployment/info
- DeploymentInfoBehavior: Get comprehensive information about a deployment in a single call.
/dms/node/deployment/shutdown
- DeploymentShutdownBehavior: Shutdown a deployment.
/dms/node/resources/allocated
- ResourcesAllocatedBehavior: Get the amount of resources allocated to Allocations running on the node.
/dms/node/resources/free
- ResourcesFreeBehavior: Get the amount of resources that are free to be allocated on the node.
/dms/node/resources/onboarded
- ResourcesOnboardedBehavior: Get the amount of resources the node is onboarded with.
/dms/node/hardware/spec
- HardwareSpecBehavior: Get the hardware resource specification of the machine.
/dms/node/hardware/usage
- HardwareUsageBehavior: Get the full resource usage on the machine.
/dms/node/logger/config
- LoggerConfigBehavior: Configure the logger/observability config of the node.
/dms/deployment/request
- BidRequestBehavior: Request a bid from the compute provider for a specific ensemble.
/dms/deployment/bid
- BidReplyBehavior: Reply to a bid request from an orchestrator.
/dms/deployment/commit
- CommitDeploymentBehavior: Temporarily commit the resources the provider bid on.
/dms/deployment/allocate
- AllocationDeploymentBehavior: Allocate the resources the provider bid on.
/dms/deployment/revert
- RevertDeploymentBehavior: Revert any commit or allocation done during a deployment.
/dms/cap/list
- CapListBehavior: Get a list of all the capabilities another node had.
/dms/cap/anchor
- CapAnchorBehavior: Anchor capability tokens on another node.
/public/hello
- PublicHelloBehavior: Get a hello message from a node.
/public/status
- PublicStatusBehavior: Get the total resource amount on the machine.
/broadcast/hello
- BroadcastHelloBehavior: Broadcast a hello message and get replies from nodes.
/dms/allocation/start
- AllocationStartBehavior: Start an allocation after a deployment.
/dms/allocation/restart
- AllocationRestartBehavior: Restart an allocation after a deployment has been started.
/dms/actor/healthcheck/register
- RegisterHealthcheckBehavior: Register a new healthcheck mechanism for an allocation.
/dms/allocation/subnet/add-peer
- SubnetAddPeerBehavior: Add a peer to a subnet.
/dms/allocation/subnet/remove-peer
- SubnetRemovePeerBehavior: Remove a peer from a subnet.
/dms/allocation/subnet/accept-peer
- SubnetAcceptPeerBehavior: Accept a peer in a subnet.
/dms/allocation/subnet/map-port
- SubnetMapPortBehavior: Map a port in a subnet.
/dms/allocation/subnet/unmap-port
- SubnetUnmapPortBehavior: Unmap a port in a subnet.
/dms/allocation/subnet/dns/add-records
- SubnetDNSAddRecordsBehavior: Add DNS records to a subnet.
/dms/allocation/subnet/dns/remove-records
- SubnetDNSRemoveRecordsBehavior: Remove a DNS record from a subnet.
/dms/ensemble/<ENSEMBLE_ID>
- EnsembleNamespace: Interact with ensembles on the node.
/dms/ensemble/<ENSEMBLE_ID>/allocation/logs
- AllocationLogsBehavior: Get the logs of an allocation in an ensemble.
/dms/ensemble/<ENSEMBLE_ID>/allocation/shutdown
- AllocationShutdownBehavior: Shutdown an allocation in an ensemble.
/dms/ensemble/<ENSEMBLE_ID>/node/subnet/create
- SubnetCreateBehavior: Create a new subnet for an ensemble.
/dms/ensemble/<ENSEMBLE_ID>/node/subnet/destroy
- SubnetDestroyBehavior: Destroy a subnet for an ensemble.
`
// newHelpCmd returns the "cap help" command, which prints the static
// behavior/capability reference text.
func newHelpCmd() *cobra.Command {
	return &cobra.Command{
		Use:   "help",
		Short: "Description of Behaviors and Associated Capabilities",
		Long: "List of all implemented behaviors and their associated capabilities\n" +
			"with a short description of each behavior and intended use.\n",
		Args: cobra.ExactArgs(0),
		Run: func(cmd *cobra.Command, _ []string) {
			fmt.Fprintln(cmd.OutOrStdout(), behaviorsAndCapsHelp)
		},
	}
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package cap
import (
"context"
"encoding/json"
"fmt"
"strings"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/cmd/cli"
"gitlab.com/nunet/device-management-service/cmd/utils"
"gitlab.com/nunet/device-management-service/lib/did"
"gitlab.com/nunet/device-management-service/lib/ucan"
)
// ListCapOptions holds the command-line options for the list command.
type ListCapOptions struct {
	Context string // capability context whose anchors are listed
}
// newListCmd builds the "cap list" cobra command, which prints all anchors
// (root, require, provide, revoke) of a capability context.
func newListCmd(dmsCLI *cli.DmsCLI) *cobra.Command {
	var opts ListCapOptions
	cmd := &cobra.Command{
		Use:   "list",
		Short: "List capability anchors",
		Long: `List all capability anchors in a capability context
It outputs DIDs and capability tokens set for root, provide, require and revoke anchors.`,
		RunE: func(cmd *cobra.Command, _ []string) error {
			return runListCap(cmd.Context(), dmsCLI, opts, cli.CmdStreams(cmd))
		},
	}
	useFlagContext(cmd, &opts.Context)
	return cmd
}
// runListCap loads the named capability context, formats its anchors, and
// writes the listing to the command's output stream.
func runListCap(_ context.Context, dmsCLI *cli.DmsCLI, opts ListCapOptions, streams cli.Streams) error {
	capCtx, err := utils.LoadCapabilityContext(dmsCLI, opts.Context)
	if err != nil {
		return err
	}
	roots, require, provide, revoke := capCtx.ListRoots()
	list, err := formatCapabilityList(roots, require, provide, revoke)
	if err != nil {
		return fmt.Errorf("failed to format capability list: %w", err)
	}
	fmt.Fprint(streams.Out, list)
	return nil
}
// formatCapabilityList renders the capability anchors of a context as a
// human-readable string: the root DIDs followed by the require, provide and
// revoke token lists, one JSON-encoded token per line, tab-indented.
// It returns an error only if a token fails to marshal to JSON.
func formatCapabilityList(roots []did.DID, require, provide, revoke ucan.TokenList) (string, error) {
	var sb strings.Builder
	sb.WriteString("roots:\n")
	for _, root := range roots {
		fmt.Fprintf(&sb, "\t%s\n", root)
	}
	// The three token sections were previously three identical copies of the
	// same loop; render them through a single loop over named sections.
	sections := []struct {
		name string
		list ucan.TokenList
	}{
		{"require", require},
		{"provide", provide},
		{"revoke", revoke},
	}
	for _, section := range sections {
		sb.WriteString(section.name)
		sb.WriteString(":\n")
		for _, t := range section.list.Tokens {
			data, err := json.Marshal(t)
			if err != nil {
				return "", fmt.Errorf("failed to marshal capability token: %w", err)
			}
			fmt.Fprintf(&sb, "\t%s\n", data)
		}
	}
	return sb.String(), nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package cap
import (
"context"
"fmt"
"path/filepath"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/internal/config"
"gitlab.com/nunet/device-management-service/cmd/cli"
"gitlab.com/nunet/device-management-service/dms"
"gitlab.com/nunet/device-management-service/dms/node"
"gitlab.com/nunet/device-management-service/lib/crypto"
"gitlab.com/nunet/device-management-service/lib/crypto/keystore"
"gitlab.com/nunet/device-management-service/lib/did"
"gitlab.com/nunet/device-management-service/lib/ucan"
dmsUtils "gitlab.com/nunet/device-management-service/utils"
)
// NewCapOptions holds the command-line options for the new command.
type NewCapOptions struct {
	// Force overwrites an existing context file without prompting.
	Force bool
	// Context is the capability context name (also the keystore key name).
	Context string
	// PrismURL is an optional PRISM resolver URL; empty uses the default resolver.
	PrismURL string
}
// newNewCmd builds the "cap new <name>" command that creates a persistent
// capability context for a key, ledger account, or Eternl wallet.
func newNewCmd(dmsCLI *cli.DmsCLI) *cobra.Command {
	opts := NewCapOptions{}
	newCmd := &cobra.Command{
		Use:   "new <name>",
		Short: "Create a new capability context",
		Long: `Create a new persistent capability context
If the specified key has a PRISM DID association (created via 'nunet key create-prism'),
the capability context will use the PRISM DID. Otherwise, it will use a did:key identity.
Example:
nunet cap new user
nunet cap new myprism # Uses PRISM DID if associated
nunet cap new ledger # ledger account 0
nunet cap new ledger:3 # ledger account 3
nunet cap new ledger:finance # ledger alias "finance"`,
		Args: cobra.ExactArgs(1),
		RunE: func(c *cobra.Command, args []string) error {
			// Default to the user context, then take the positional name.
			// (Args enforces exactly one argument, so the guard is defensive.)
			opts.Context = node.UserContextName
			if len(args) > 0 {
				opts.Context = args[0]
			}
			return runNewCap(c.Context(), dmsCLI, opts, cli.CmdStreams(c))
		},
	}
	newCmd.Flags().BoolVarP(&opts.Force, fnForce, "f", false, "force overwrite of existing context")
	newCmd.Flags().StringVar(&opts.PrismURL, "prism-url", "", "PRISM resolver URL (e.g., http://localhost:8080). If not specified, uses default resolver.")
	return newCmd
}
// runNewCap prepares the signing identity for a new capability context and
// delegates to GenCaps. Ledger/Eternl contexts need no keystore; otherwise it
// reuses an existing keystore entry (prompting for its passphrase) or creates
// a fresh key under a new passphrase.
func runNewCap(ctx context.Context, dmsCLI *cli.DmsCLI, opts NewCapOptions, streams cli.Streams) error {
	var passphrase string
	var privKey crypto.PrivKey
	fs := dmsCLI.FS()
	cfg, err := dmsCLI.Config()
	if err != nil {
		return fmt.Errorf("unable to get config: %w", err)
	}
	// ledger doesn't need keystore
	if node.IsLedgerContext(opts.Context) || node.IsEternlContext(opts.Context) {
		// Wallet-backed contexts sign via the wallet, so privKey stays nil.
		return GenCaps(ctx, cfg, fs, opts, streams, nil)
	}
	// set up keystore
	ks := dmsCLI.Keystore()
	if ks == nil {
		// No injected keystore (normal CLI path): open the on-disk one.
		keyStoreDir := filepath.Join(cfg.General.UserDir, node.KeystoreDir)
		ks, err = keystore.New(fs, keyStoreDir, false)
		if err != nil {
			return fmt.Errorf("failed to open keystore: %w", err)
		}
	}
	// extract priv key
	if ks.Exists(opts.Context) {
		// Reuse the existing key; the passphrase unlocks the stored entry.
		fmt.Fprintf(streams.Out, "Using identity at %s/%s.json...\n", ks.Dir(), opts.Context)
		passphrase, err = dmsCLI.Passphrase(opts.Context)
		if err != nil {
			return fmt.Errorf("failed to get passphrase: %w", err)
		}
		key, err := ks.Get(opts.Context, passphrase)
		if err != nil {
			return fmt.Errorf("failed to get key from keystore: %w", err)
		}
		privKey, err = key.PrivKey()
		if err != nil {
			return fmt.Errorf("unable to convert key from keystore to private key: %v", err)
		}
	} else {
		// First use of this context name: mint a key under a new passphrase.
		fmt.Fprintf(streams.Out, "A new identity will be created for '%s' context...\n", opts.Context)
		passphrase, err = dmsCLI.NewPassphrase(opts.Context)
		if err != nil {
			return fmt.Errorf("failed to create new passphrase: %w", err)
		}
		privKey, err = dms.GenerateAndStorePrivKey(ks, passphrase, opts.Context)
		if err != nil {
			return fmt.Errorf("failed to create new key: %w", err)
		}
	}
	// Defensive re-check that the key was actually persisted before
	// generating capability files that depend on it.
	if !ks.Exists(opts.Context) {
		return fmt.Errorf("key missing: %s", opts.Context)
	}
	return GenCaps(ctx, cfg, fs, opts, streams, privKey)
}
// GenCaps generates capability files.
//
// It builds a trust context and root DID for the given context name — from a
// hardware ledger wallet, the Eternl wallet, an associated PRISM DID, or (by
// default) a did:key derived from privKey — then creates and persists a fresh
// capability context in the user's capability store. privKey may be nil for
// ledger/Eternl contexts, which sign via the wallet provider instead.
func GenCaps(
	_ context.Context, cfg *config.Config, fs afero.Fs, opts NewCapOptions, streams cli.Streams, privKey crypto.PrivKey,
) error {
	var err error
	var trustCtx did.TrustContext
	var rootDID did.DID
	switch {
	case node.IsLedgerContext(opts.Context):
		// need userDir for the resolver
		idx, err := node.ResolveLedgerIndex(fs, cfg.General.UserDir, node.GetContextKey(opts.Context))
		if err != nil {
			return err
		}
		provider, err := did.NewLedgerWalletProvider(idx)
		if err != nil {
			return err
		}
		trustCtx = did.NewTrustContextWithProvider(provider)
		rootDID = provider.DID()
		opts.Context = node.GetContextKey(opts.Context) // normalize context name
	case node.IsEternlContext(opts.Context):
		provider, err := did.NewEternlWalletProvider()
		if err != nil {
			return err
		}
		trustCtx = did.NewTrustContextWithProvider(provider)
		rootDID = provider.DID()
		// Normalize the context name, same as the ledger case above.
		opts.Context = node.GetContextKey(opts.Context)
	default:
		// Check if this key has a PRISM DID association
		prismDIDStr, err := node.GetPrismDID(fs, cfg.General.UserDir, opts.Context)
		if err != nil {
			return fmt.Errorf("unable to check for PRISM DID association: %w", err)
		}
		if prismDIDStr != "" {
			// Use PRISM DID if available
			prismDID, err := did.FromString(prismDIDStr)
			if err != nil {
				return fmt.Errorf("invalid PRISM DID for key %s: %w", opts.Context, err)
			}
			// Create PRISM provider
			provider, err := did.ProviderFromPRISMPrivateKey(prismDID, privKey)
			if err != nil {
				return fmt.Errorf("unable to create PRISM provider: %w", err)
			}
			// Create trust context with PRISM provider
			trustCtx = did.NewTrustContextWithProvider(provider)
			// Configure PRISM resolver if URL is provided
			// NOTE(review): the resolver config is process-global; we save the
			// original here so it can be restored after the anchor lookup below.
			originalConfig := did.GetPRISMResolverConfig()
			if opts.PrismURL != "" {
				did.SetPRISMResolverConfig(did.PRISMResolverConfig{
					ResolverURL:                 opts.PrismURL,
					PreferredVerificationMethod: originalConfig.PreferredVerificationMethod,
					HTTPClient:                  originalConfig.HTTPClient,
				})
			}
			// Try to resolve PRISM DID to get anchor for verification
			// This is optional - if resolution fails, we'll continue without the anchor
			// The anchor will be resolved on-demand when needed for verification
			anchor, err := did.GetAnchorForDID(prismDID)
			// Restore original config if we changed it
			if opts.PrismURL != "" {
				did.SetPRISMResolverConfig(originalConfig)
			}
			if err != nil {
				fmt.Fprintf(streams.Out, "⚠️  Warning: Could not resolve PRISM DID anchor (will resolve on-demand): %v\n", err)
				fmt.Fprintf(streams.Out, "   This is normal if the DID was just created or the resolver is unavailable.\n")
			} else {
				// Add PRISM anchor to trust context for faster verification
				trustCtx.AddAnchor(anchor)
			}
			// Use PRISM DID as root
			rootDID = prismDID
			fmt.Fprintf(streams.Out, "✅ Using PRISM DID: %s\n", prismDIDStr)
			if opts.PrismURL != "" {
				fmt.Fprintf(streams.Out, "   Resolver URL: %s\n", opts.PrismURL)
			}
		} else {
			// Fall back to did:key if no PRISM DID association
			trustCtx, err = did.NewTrustContextWithPrivateKey(privKey)
			if err != nil {
				return fmt.Errorf("unable to create trust context: %w", err)
			}
			rootDID = did.FromPublicKey(privKey.GetPublic())
			fmt.Fprintf(streams.Out, "Using did:key identity (no PRISM DID association found)\n")
		}
	}
	capStoreDir := filepath.Join(cfg.General.UserDir, node.CapstoreDir)
	capStoreFile := filepath.Join(capStoreDir, fmt.Sprintf("%s.cap", opts.Context))
	fileExists, err := afero.Exists(fs, capStoreFile)
	if err != nil {
		return fmt.Errorf("unable to check if capability context file exists: %w", err)
	}
	// An existing context file is only overwritten after explicit confirmation,
	// unless --force was given.
	if fileExists && !opts.Force {
		confirmed, err := dmsUtils.PromptYesNo(
			streams.In,
			streams.Out,
			fmt.Sprintf(
				"WARNING: A capability context file already exists at %s. Creating a new one will overwrite the existing context. Do you want to proceed?",
				capStoreFile,
			),
		)
		if err != nil {
			return fmt.Errorf("failed to get user confirmation: %w", err)
		}
		if !confirmed {
			return fmt.Errorf("operation cancelled by user")
		}
	} else {
		if err := fs.MkdirAll(capStoreDir, 0o700); err != nil {
			return fmt.Errorf("unable to create capability store directory: %w", err)
		}
	}
	capCtx, err := ucan.NewCapabilityContextWithName(opts.Context, trustCtx, rootDID, nil, ucan.TokenList{}, ucan.TokenList{}, ucan.TokenList{})
	if err != nil {
		return fmt.Errorf("unable to create capability context: %w", err)
	}
	// NOTE(review): this passes cfg.UserDir while the rest of this function
	// uses cfg.General.UserDir — confirm both resolve to the same path.
	if err := node.SaveCapabilityContext(capCtx, fs, cfg.UserDir); err != nil {
		return fmt.Errorf("save capability context: %w", err)
	}
	return nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package cap
import (
"context"
"encoding/json"
"fmt"
"os"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/cmd/cli"
"gitlab.com/nunet/device-management-service/cmd/utils"
"gitlab.com/nunet/device-management-service/lib/did"
"gitlab.com/nunet/device-management-service/lib/ucan"
)
// RemoveCapOptions holds the command-line options for the remove command.
type RemoveCapOptions struct {
	// Context is the capability context to modify.
	Context string
	// Root is a DID string to remove from the root anchors.
	Root string
	// Provide is a JSON-encoded token to remove from the provide anchors.
	Provide string
	// Require is a JSON-encoded token to remove from the require anchors.
	Require string
	// Revoke is a JSON-encoded token to remove from the revoke anchors.
	Revoke string
}
// newRemoveCmd builds the "cap remove" command. Exactly one anchor flag
// (--root, --require, --provide, --revoke) must be given per invocation;
// cobra enforces this via MarkFlagsOneRequired/MarkFlagsMutuallyExclusive.
func newRemoveCmd(dmsCLI *cli.DmsCLI) *cobra.Command {
	var opts RemoveCapOptions
	// Local flag-name constants; fnContext and fnRevoke are referenced from
	// package scope below — presumably declared elsewhere in this package.
	const (
		fnProvide = "provide"
		fnRoot    = "root"
		fnRequire = "require"
	)
	cmd := &cobra.Command{
		Use:   "remove",
		Short: "Remove capability anchors",
		Long: `Remove capability anchors in a capability context
One capability anchor must be specified at a time.
Example:
nunet cap remove --context user --root did:key:abcd1234
nunet cap remove --context user --require '<the-token>'`,
		RunE: func(cmd *cobra.Command, _ []string) error {
			return runRemoveCap(cmd.Context(), dmsCLI, opts, cli.CmdStreams(cmd))
		},
	}
	useFlagContext(cmd, &opts.Context)
	useFlagRoot(cmd, &opts.Root)
	useFlagRequire(cmd, &opts.Require)
	useFlagProvide(cmd, &opts.Provide)
	useFlagRevoke(cmd, &opts.Revoke)
	_ = cmd.MarkFlagRequired(fnContext)
	cmd.MarkFlagsOneRequired(fnProvide, fnRoot, fnRequire, fnRevoke)
	cmd.MarkFlagsMutuallyExclusive(fnProvide, fnRoot, fnRequire, fnRevoke)
	return cmd
}
// runRemoveCap removes a single capability anchor (a root DID, or a require/
// provide/revoke token) from the named capability context, persists the
// result, and signals a running DMS to reload. The cobra flag constraints in
// newRemoveCmd guarantee exactly one anchor option is set; the default branch
// is a defensive fallback.
func runRemoveCap(_ context.Context, dmsCLI *cli.DmsCLI, opts RemoveCapOptions, streams cli.Streams) error {
	capCtx, err := utils.LoadCapabilityContext(dmsCLI, opts.Context)
	if err != nil {
		return err
	}
	switch {
	case opts.Root != "":
		rootDID, err := did.FromString(opts.Root)
		if err != nil {
			return fmt.Errorf("invalid root DID: %w", err)
		}
		capCtx.RemoveRoots([]did.DID{rootDID}, ucan.TokenList{}, ucan.TokenList{}, ucan.TokenList{})
	case opts.Require != "":
		var token ucan.Token
		if err := json.Unmarshal([]byte(opts.Require), &token); err != nil {
			return fmt.Errorf("unmarshal tokens: %w", err)
		}
		capCtx.RemoveRoots(nil, ucan.TokenList{Tokens: []*ucan.Token{&token}}, ucan.TokenList{}, ucan.TokenList{})
	case opts.Provide != "":
		var token ucan.Token
		if err := json.Unmarshal([]byte(opts.Provide), &token); err != nil {
			return fmt.Errorf("unmarshal tokens: %w", err)
		}
		capCtx.RemoveRoots(nil, ucan.TokenList{}, ucan.TokenList{Tokens: []*ucan.Token{&token}}, ucan.TokenList{})
	case opts.Revoke != "":
		var token ucan.Token
		if err := json.Unmarshal([]byte(opts.Revoke), &token); err != nil {
			return fmt.Errorf("unmarshal tokens: %w", err)
		}
		capCtx.RemoveRoots(nil, ucan.TokenList{}, ucan.TokenList{}, ucan.TokenList{Tokens: []*ucan.Token{&token}})
	default:
		return fmt.Errorf("one of --provide, --root, --require, or --revoke must be specified")
	}
	if err := utils.SaveCapabilityContext(dmsCLI, capCtx); err != nil {
		return err
	}
	// Send SIGUSR1 to running DMS to reload contexts
	if err := signalDMSReload(dmsCLI); err != nil {
		// Log the error but don't fail - DMS might not be running (expected during initial setup)
		fmt.Fprintf(os.Stderr, "Warning: Could not signal DMS to reload (DMS may not be running): %v\n", err)
	} else {
		// Write to the command's output stream (previously fmt.Println, which
		// bypassed the streams abstraction used everywhere else in this package).
		fmt.Fprintln(streams.Out, "Successfully signaled DMS to reload capability contexts")
	}
	return nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package cap
import (
"context"
"encoding/json"
"fmt"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/cmd/cli"
"gitlab.com/nunet/device-management-service/cmd/utils"
"gitlab.com/nunet/device-management-service/lib/ucan"
)
// RevokeCapOptions holds the command-line options for the revoke command.
type RevokeCapOptions struct {
	// Context is the capability context used to issue the revocation.
	Context string
	// Token is the JSON-encoded token list to revoke (positional argument).
	Token string
}
// newRevokeCmd builds the "cap revoke <token>" command, which revokes a
// granted or delegated token within the selected capability context.
func newRevokeCmd(dmsCLI *cli.DmsCLI) *cobra.Command {
	var opts RevokeCapOptions
	cmd := &cobra.Command{
		Use:   "revoke <token>",
		Short: "Revoke a token",
		// Fixed typo in the help text: "deleated" -> "delegated".
		Long: `Revoke a granted or delegated token
Example:
nunet cap revoke --context user '{"some": "json", "token": "here"}'
The above command revokes a token`,
		Args: cobra.ExactArgs(1),
		RunE: func(cmd *cobra.Command, args []string) error {
			opts.Token = args[0]
			return runRevokeCap(cmd.Context(), dmsCLI, opts, cli.CmdStreams(cmd))
		},
	}
	useFlagContext(cmd, &opts.Context)
	_ = cmd.MarkFlagRequired(fnContext)
	return cmd
}
// runRevokeCap revokes every token in the JSON-encoded token list given in
// opts.Token and prints the resulting revocation tokens to the output stream,
// one JSON document per line.
func runRevokeCap(_ context.Context, dmsCLI *cli.DmsCLI, opts RevokeCapOptions, streams cli.Streams) error {
	capCtx, err := utils.LoadCapabilityContext(dmsCLI, opts.Context)
	if err != nil {
		return err
	}
	var tokens ucan.TokenList
	if err := json.Unmarshal([]byte(opts.Token), &tokens); err != nil {
		return fmt.Errorf("unmarshal tokens: %w", err)
	}
	var rendered []byte
	for _, tok := range tokens.Tokens {
		revoked, err := capCtx.Revoke(tok)
		if err != nil {
			return fmt.Errorf("failed to revoke: %w", err)
		}
		encoded, err := json.Marshal(revoked)
		if err != nil {
			return fmt.Errorf("unable to marshal tokens to json: %w", err)
		}
		rendered = append(rendered, encoded...)
		rendered = append(rendered, '\n')
	}
	fmt.Fprintln(streams.Out, string(rendered))
	return nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package cli
import (
"fmt"
"io"
"time"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/lib/crypto/keystore"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/client"
"gitlab.com/nunet/device-management-service/cmd/cli/passphrase"
"gitlab.com/nunet/device-management-service/internal/config"
env "gitlab.com/nunet/device-management-service/lib/env"
)
// Streams bundles the input, output and error streams a command reads from
// and writes to, so commands can be exercised without touching the process's
// real stdin/stdout/stderr.
type Streams struct {
	In  io.Reader
	Out io.Writer
	Err io.Writer
}
// CmdStreams extracts the cobra command's configured streams (falling back to
// the process's stdin/stdout/stderr when none were set) into a Streams value.
func CmdStreams(cmd *cobra.Command) Streams {
	return Streams{
		In:  cmd.InOrStdin(),
		Out: cmd.OutOrStdout(),
		Err: cmd.ErrOrStderr(),
	}
}
// DmsCLI aggregates the injectable dependencies of the nunet CLI: filesystem,
// environment, configuration loader, passphrase and keystore providers, and
// factories for the DMS API client. Zero fields are filled with defaults by New.
type DmsCLI struct {
	env                 env.EnvironmentProvider
	fs                  afero.Fs
	defaultConfig       *config.Config
	configLoader        *config.Loader
	passphraseProvider  passphrase.Provider
	keystoreProvider    keystore.KeyStore
	clientFn            func(cfg *config.Config, sctx actor.SecurityContext) (client.DmsClient, error)
	clientFnWithTimeout func(cfg *config.Config, sctx actor.SecurityContext, timeout time.Duration) (client.DmsClient, error)
}
// Env returns the environment provider used to read process environment variables.
func (c *DmsCLI) Env() env.EnvironmentProvider {
	return c.env
}

// FS returns the filesystem abstraction the CLI operates on.
func (c *DmsCLI) FS() afero.Fs {
	return c.fs
}

// ConfigLoader returns the configuration loader.
func (c *DmsCLI) ConfigLoader() *config.Loader {
	return c.configLoader
}

// Config loads and returns the effective configuration via the loader.
func (c *DmsCLI) Config() (*config.Config, error) {
	return c.configLoader.GetConfig()
}

// Passphrase obtains the passphrase for an existing key.
func (c *DmsCLI) Passphrase(key string) (string, error) {
	return c.passphraseProvider.GetPassphrase(key)
}

// NewPassphrase obtains a passphrase for a key being created.
func (c *DmsCLI) NewPassphrase(key string) (string, error) {
	return c.passphraseProvider.NewPassphrase(key)
}

// Keystore returns the configured keystore provider; may be nil if none was injected.
func (c *DmsCLI) Keystore() keystore.KeyStore {
	return c.keystoreProvider
}
// NewClient constructs a DMS API client for the given security context using
// the current configuration and the configured client factory.
func (c *DmsCLI) NewClient(sctx actor.SecurityContext) (client.DmsClient, error) {
	cfg, err := c.Config()
	if err != nil {
		return nil, err
	}
	return c.clientFn(cfg, sctx)
}
// NewClientWithTimeout is like NewClient but sets a per-request timeout on
// the constructed client.
func (c *DmsCLI) NewClientWithTimeout(sctx actor.SecurityContext, timeout time.Duration) (client.DmsClient, error) {
	cfg, err := c.Config()
	if err != nil {
		return nil, err
	}
	return c.clientFnWithTimeout(cfg, sctx, timeout)
}
// New constructs a DmsCLI, applying the functional options first and then
// filling any dependency left unset with its default. Order matters: the
// filesystem must be resolved before the config loader (which wraps it), and
// the environment before the passphrase provider (which reads it).
func New(opts ...func(*DmsCLI)) *DmsCLI {
	cli := &DmsCLI{}
	for _, opt := range opts {
		opt(cli)
	}
	if cli.fs == nil {
		cli.fs = afero.NewOsFs()
	}
	if cli.configLoader == nil {
		cli.configLoader = config.NewLoader(config.WithFS(cli.fs))
	}
	// An explicitly provided default config seeds the loader even when the
	// loader itself was injected.
	if cli.defaultConfig != nil {
		cli.configLoader.SetConfig(*cli.defaultConfig)
	}
	if cli.env == nil {
		cli.env = env.NewOSEnvironment()
	}
	if cli.passphraseProvider == nil {
		cli.passphraseProvider = passphrase.DefaultProvider(cli.env)
	}
	// Default client factories target the REST endpoint from the config.
	if cli.clientFn == nil {
		cli.clientFn = func(cfg *config.Config, sctx actor.SecurityContext) (client.DmsClient, error) {
			return client.NewClient(client.Config{
				Host:      fmt.Sprintf("%s:%d", cfg.Rest.Addr, cfg.Rest.Port),
				APIPrefix: "/api",
				Version:   "v1",
			}, sctx)
		}
	}
	if cli.clientFnWithTimeout == nil {
		cli.clientFnWithTimeout = func(cfg *config.Config, sctx actor.SecurityContext, timeout time.Duration) (client.DmsClient, error) {
			return client.NewClient(client.Config{
				Host:           fmt.Sprintf("%s:%d", cfg.Rest.Addr, cfg.Rest.Port),
				APIPrefix:      "/api",
				Version:        "v1",
				RequestTimeout: timeout,
			}, sctx)
		}
	}
	return cli
}
// WithEnv injects a custom environment provider.
func WithEnv(env env.EnvironmentProvider) func(*DmsCLI) {
	return func(cli *DmsCLI) {
		cli.env = env
	}
}

// WithFS injects a custom filesystem (e.g. an in-memory one for tests).
func WithFS(fs afero.Fs) func(*DmsCLI) {
	return func(cli *DmsCLI) {
		cli.fs = fs
	}
}

// WithConfig seeds the config loader with a pre-built configuration.
func WithConfig(cfg *config.Config) func(*DmsCLI) {
	return func(cli *DmsCLI) {
		cli.defaultConfig = cfg
	}
}

// WithPassphraseProvider injects a custom passphrase provider.
func WithPassphraseProvider(pp passphrase.Provider) func(*DmsCLI) {
	return func(cli *DmsCLI) {
		cli.passphraseProvider = pp
	}
}

// WithKeystoreProvider injects a pre-opened keystore.
func WithKeystoreProvider(ks keystore.KeyStore) func(*DmsCLI) {
	return func(cli *DmsCLI) {
		cli.keystoreProvider = ks
	}
}

// WithClientFn injects a custom DMS client factory.
func WithClientFn(clientFn func(cfg *config.Config, sctx actor.SecurityContext) (client.DmsClient, error)) func(*DmsCLI) {
	return func(cli *DmsCLI) {
		cli.clientFn = clientFn
	}
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package passphrase
import (
"gitlab.com/nunet/device-management-service/dms/node"
"gitlab.com/nunet/device-management-service/lib/env"
"gitlab.com/nunet/device-management-service/utils"
)
// Provider supplies passphrases for keystore entries. GetPassphrase retrieves
// the passphrase for an existing key; NewPassphrase obtains one for a key
// being created.
type Provider interface {
	GetPassphrase(key string) (string, error)
	NewPassphrase(key string) (string, error)
}
// envPassphraseProvider reads the passphrase from the DMS passphrase
// environment variable, ignoring the key name.
type envPassphraseProvider struct {
	env env.EnvironmentProvider
}

// GetPassphrase returns the value of the passphrase environment variable, or
// ErrPassphraseNotFound when it is unset or empty.
func (e *envPassphraseProvider) GetPassphrase(_ string) (string, error) {
	passphrase := e.env.Getenv(node.DMSPassphraseEnv)
	if passphrase == "" {
		return "", ErrPassphraseNotFound
	}
	return passphrase, nil
}

// NewPassphrase delegates to GetPassphrase: the environment supplies a single
// passphrase regardless of whether the key is new.
func (e *envPassphraseProvider) NewPassphrase(key string) (string, error) {
	return e.GetPassphrase(key)
}
// promptPassphraseProvider interactively prompts the user for a passphrase.
type promptPassphraseProvider struct{}

// GetPassphrase prompts once (no confirmation) for an existing key's passphrase.
func (p *promptPassphraseProvider) GetPassphrase(_ string) (string, error) {
	return utils.PromptForPassphrase(false)
}

// NewPassphrase prompts with confirmation for a new key's passphrase.
func (p *promptPassphraseProvider) NewPassphrase(_ string) (string, error) {
	return utils.PromptForPassphrase(true)
}
// DefaultPassphraseProvider chains multiple providers, returning the first
// successful result and falling through to the next provider on error.
type DefaultPassphraseProvider struct {
	providers []Provider
}

// GetPassphrase tries each provider in order and returns the first passphrase
// obtained without error; ErrPassphraseNotFound if all providers fail.
func (d *DefaultPassphraseProvider) GetPassphrase(key string) (string, error) {
	for _, provider := range d.providers {
		passphrase, err := provider.GetPassphrase(key)
		if err == nil {
			return passphrase, nil
		}
	}
	return "", ErrPassphraseNotFound
}

// NewPassphrase tries each provider in order for a new-key passphrase;
// ErrNewPassphraseFailed if all providers fail.
func (d *DefaultPassphraseProvider) NewPassphrase(key string) (string, error) {
	for _, provider := range d.providers {
		passphrase, err := provider.NewPassphrase(key)
		if err == nil {
			return passphrase, nil
		}
	}
	return "", ErrNewPassphraseFailed
}
// DefaultProvider builds the standard passphrase chain: the environment
// variable is consulted first, then the user is prompted interactively.
func DefaultProvider(env env.EnvironmentProvider) Provider {
	return &DefaultPassphraseProvider{
		providers: []Provider{
			&envPassphraseProvider{env: env},
			&promptPassphraseProvider{},
		},
	}
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package cmd
import (
"encoding/json"
"fmt"
"os/exec"
"strconv"
"strings"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/cmd/cli"
"gitlab.com/nunet/device-management-service/internal/config"
)
// newConfigCmd builds the "config" parent command with its get/set/edit subcommands.
func newConfigCmd(dmsCli *cli.DmsCLI) *cobra.Command {
	configCmd := &cobra.Command{
		Use:   "config",
		Short: "Manage configuration file",
		Long: `Utility to manage user's configuration file via command-line
Search for the configuration file is done in the following locations and order:
1. "." (current directory)
2. "$HOME/.nunet"
3. "/etc/nunet"`,
	}
	configCmd.AddCommand(
		newConfigGetCmd(dmsCli),
		newConfigSetCmd(dmsCli),
		newConfigEditCmd(dmsCli),
	)
	return configCmd
}
// newConfigGetCmd builds "config get [key]": with no key it prints the whole
// configuration as JSON; with a key it prints just that value.
func newConfigGetCmd(dmsCli *cli.DmsCLI) *cobra.Command {
	cmd := &cobra.Command{
		Use:   "get <key>",
		Short: "Display configuration",
		Long: `Display the value for a configuration key
It reads the value from configuration file, otherwise it return default values
Example:
nunet config get rest.port`,
		Args: cobra.MaximumNArgs(1),
		RunE: func(cmd *cobra.Command, args []string) error {
			ldr := dmsCli.ConfigLoader()
			// Best effort: create the config file if missing so ConfigFile()
			// reports a real path; errors here are non-fatal for "get".
			_ = ensureConfigFile(dmsCli.FS(), ldr)
			cfg, err := ldr.GetConfig()
			if err != nil {
				return fmt.Errorf("failed to load config: %w", err)
			}
			cmd.Println("Found config file at:", ldr.ConfigFile())
			// No key print the whole struct as JSON
			if len(args) == 0 {
				all, err := json.MarshalIndent(cfg, "", " ")
				if err != nil {
					return fmt.Errorf("indent config JSON: %w", err)
				}
				cmd.Println(string(all))
				return nil
			}
			val, found := ldr.GetValue(strings.ToLower(args[0]))
			if !found {
				return fmt.Errorf("key %q not found", args[0])
			}
			// Previously the marshal error was silently discarded; report it
			// like the whole-config path above.
			pretty, err := json.MarshalIndent(val, "", " ")
			if err != nil {
				return fmt.Errorf("indent value JSON: %w", err)
			}
			cmd.Println(string(pretty))
			return nil
		},
	}
	return cmd
}
// newConfigSetCmd builds "config set <key> <value>", which parses the value
// as a JSON string array, integer, or boolean (in that order, falling back to
// a plain string) and writes it into the configuration file, creating the
// file first when necessary.
func newConfigSetCmd(dmsCli *cli.DmsCLI) *cobra.Command {
	cmd := &cobra.Command{
		Use:   "set <key> <value>",
		Short: "Update configuration",
		Long: `Set value for a configuration key.
Creates the configuration file if it does not yet exist.
Examples:
nunet config set rest.port 4444
nunet config set general.work_dir ~/.config/dms
nunet config set observability.elastic.enabled true
nunet config set p2p.listen_address '["/ip4/0.0.0.0/tcp/9889", "/ip4/0.0.0.0/udp/9889/quic-v1"]'`,
		Args: cobra.ExactArgs(2),
		RunE: func(cmd *cobra.Command, args []string) error {
			key := strings.ToLower(args[0])
			raw := args[1]
			ldr := dmsCli.ConfigLoader()
			// Best effort: create the file if missing; existence is re-checked
			// below purely for the user-facing message.
			_ = ensureConfigFile(dmsCli.FS(), ldr)
			exists, err := afero.Exists(dmsCli.FS(), ldr.ConfigFile())
			if err != nil {
				return fmt.Errorf("stat config file: %w", err)
			}
			if !exists {
				cmd.Println("Config file did not exist. Creating new file...")
			} else {
				cmd.Println("Updating existing config file...")
			}
			// Parse numeric and bool literals, keep string otherwise.
			value := parseLiteral(raw)
			if err := ldr.Set(key, value); err != nil {
				return fmt.Errorf("failed to set config: %w", err)
			}
			cmd.Println("Applied changes.")
			return nil
		},
	}
	return cmd
}
// newConfigEditCmd builds "config edit", which opens the configuration file
// in the editor named by $EDITOR, wiring the editor's stdio to the command's
// streams so it works under redirection and in tests.
func newConfigEditCmd(dmsCli *cli.DmsCLI) *cobra.Command {
	cmd := &cobra.Command{
		Use:   "edit",
		Short: "Edit configuration",
		Long: `Open configuration file with the default text editor.
The command reads the $EDITOR environment variable and fails if it is unset.`,
		Args: cobra.NoArgs,
		RunE: func(cmd *cobra.Command, _ []string) error {
			editor := dmsCli.Env().Getenv("EDITOR")
			if editor == "" {
				return fmt.Errorf("$EDITOR not set")
			}
			ldr := dmsCli.ConfigLoader()
			// Best effort: make sure the file exists before the editor opens it.
			_ = ensureConfigFile(dmsCli.FS(), ldr)
			cmd.Printf("Text editor: %s\n", editor)
			cmd.Printf("Config path: %s\n", ldr.ConfigFile())
			proc := exec.Command(editor, ldr.ConfigFile())
			proc.Stdout = cmd.OutOrStdout()
			proc.Stdin = cmd.InOrStdin()
			// Was cmd.OutOrStderr(), which returns the *output* writer when one
			// is set and would fold the editor's stderr into stdout; use the
			// matching error-stream accessor instead.
			proc.Stderr = cmd.ErrOrStderr()
			return proc.Run()
		},
	}
	return cmd
}
// Helpers
// ensureConfigFile writes a default configuration file if the loader's
// config file path is unknown or the file does not exist yet.
func ensureConfigFile(fs afero.Fs, ldr *config.Loader) error {
	path := ldr.ConfigFile()
	if path != "" {
		ok, err := afero.Exists(fs, path)
		if err == nil && ok {
			// File already present; nothing to do.
			return nil
		}
	}
	return ldr.Write(false)
}
// parseLiteral interprets a raw command-line value: a JSON string array,
// then a base-10 integer, then a boolean; anything else stays a string.
func parseLiteral(s string) interface{} {
	// try string slice first
	var strSlice []string
	if json.Unmarshal([]byte(s), &strSlice) == nil {
		return strSlice
	}
	if n, err := strconv.ParseInt(s, 10, 64); err == nil {
		return int(n)
	}
	if v, err := strconv.ParseBool(s); err == nil {
		return v
	}
	return s
}
package cmd
import (
"fmt"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/cmd/actor"
"gitlab.com/nunet/device-management-service/cmd/cli"
"gitlab.com/nunet/device-management-service/dms/behaviors"
"gitlab.com/nunet/device-management-service/tokenomics/contracts"
)
// newContractCmd builds the "contracts" parent command.
func newContractCmd(dmsCli *cli.DmsCLI) *cobra.Command {
	contractsCmd := &cobra.Command{
		Use:   "contracts",
		Short: "Interact with contracts",
	}
	contractsCmd.AddCommand(newContractListCmd(dmsCli))
	return contractsCmd
}
// newContractListCmd builds "contracts list" with its incoming/outgoing aliases.
func newContractListCmd(dmsCli *cli.DmsCLI) *cobra.Command {
	listCmd := &cobra.Command{
		Use:   "list",
		Short: "List contracts",
	}
	listCmd.AddCommand(
		newContractListAlias(dmsCli, "incoming", "List contracts where this node is the provider", contracts.ContractRoleProvider),
		newContractListAlias(dmsCli, "outgoing", "List contracts where this node is the requestor", contracts.ContractRoleRequestor),
	)
	return listCmd
}
// newContractListAlias wraps the ContractList actor behavior in a command
// named `use`, pre-setting its --role flag to the given role via PreRunE
// while preserving any PreRunE the wrapper already installed.
func newContractListAlias(dmsCli *cli.DmsCLI, use, short string, role contracts.ContractListIncomingRole) *cobra.Command {
	cmd, err := actor.NewActorCmdWrapper(dmsCli, behaviors.ContractListBehavior)
	if err != nil {
		// Wrapper construction failed: return a stub that surfaces the error on run.
		return &cobra.Command{
			Use: use,
			RunE: func(_ *cobra.Command, _ []string) error {
				return err
			},
		}
	}
	cmd.Use = fmt.Sprintf("%s [flags]", use)
	cmd.Short = short
	cmd.Long = `This command lists contracts for the given role.
Examples:
nunet contract list incoming
nunet contract list outgoing`
	// Chain onto the wrapper's PreRunE: set the role first, then run the
	// original hook so its behavior is preserved.
	prevPreRun := cmd.PreRunE
	cmd.PreRunE = func(cmd *cobra.Command, args []string) error {
		if err := cmd.Flags().Set("role", string(role)); err != nil {
			return fmt.Errorf("failed to set role flag: %w", err)
		}
		if prevPreRun != nil {
			return prevPreRun(cmd, args)
		}
		return nil
	}
	return cmd
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package cmd
import (
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/cmd/actor"
"gitlab.com/nunet/device-management-service/cmd/cli"
"gitlab.com/nunet/device-management-service/dms/behaviors"
)
// newDeployCmd builds the "deploy" command, backed by the NewDeployment actor behavior.
func newDeployCmd(dmsCli *cli.DmsCLI) *cobra.Command {
	cmd, err := actor.NewActorCmdWrapper(dmsCli, behaviors.NewDeploymentBehavior)
	if err != nil {
		// Wrapper construction failed: return a stub that surfaces the error on run.
		return &cobra.Command{
			Use: "deploy",
			RunE: func(_ *cobra.Command, _ []string) error {
				return err
			},
		}
	}
	cmd.Use = "deploy"
	cmd.Short = "Create a deployment"
	cmd.Long = `This command creates a new deployment. It receives an ensemble file as argument.
Example:
Deploy ensemble with a 5 minute timeout
nunet -c alice deploy -f foo.yaml -t 5m`
	return cmd
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package cmd
import (
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/cmd/actor"
"gitlab.com/nunet/device-management-service/cmd/cli"
"gitlab.com/nunet/device-management-service/dms/behaviors"
)
// newGetCmd builds the "get" parent command grouping read-only queries.
func newGetCmd(dmsCli *cli.DmsCLI) *cobra.Command {
	getCmd := &cobra.Command{
		Use:   "get",
		Short: "Get deployments, allocations etc.",
	}
	getCmd.AddCommand(
		newGetDeployments(dmsCli),
		newGetAllocations(dmsCli),
	)
	return getCmd
}
// newGetDeployments builds "get deployments", backed by the DeploymentList actor behavior.
func newGetDeployments(dmsCli *cli.DmsCLI) *cobra.Command {
	cmd, err := actor.NewActorCmdWrapper(dmsCli, behaviors.DeploymentListBehavior)
	if err != nil {
		// Wrapper construction failed: return a stub that surfaces the error on run.
		return &cobra.Command{
			Use: "deployments",
			RunE: func(_ *cobra.Command, _ []string) error {
				return err
			},
		}
	}
	cmd.Use = "deployments"
	cmd.Short = "Get all deployments"
	cmd.Long = `Get all deployments. It will show running deployments as well as completed or stopped ones.
Each deployment will be referenced by its ensemble ID along with their status.`
	return cmd
}
// newGetAllocations builds "get allocations", backed by the AllocationsList
// actor behavior.
func newGetAllocations(dmsCli *cli.DmsCLI) *cobra.Command {
	behavior := behaviors.AllocationsListBehavior
	cmd, err := actor.NewActorCmdWrapper(dmsCli, behavior)
	if err != nil {
		// Wrapper construction failed: return a stub that surfaces the error on run.
		return &cobra.Command{
			Use: "allocations",
			RunE: func(_ *cobra.Command, _ []string) error {
				return err
			},
		}
	}
	cmd.Use = "allocations"
	// Short was missing; added for parity with the "get deployments" sibling.
	cmd.Short = "Get all allocations"
	// Fixed double negative ("will not show not the") in the help text.
	cmd.Long = `Get all allocations. It will show running allocations as well as completed or stopped ones.
This returns all allocations running on the host, from one acting as Compute Provider. This will not show the allocations from deployed ensembles.`
	return cmd
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package cmd
import (
"context"
"fmt"
"os"
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/api/types/image"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/executor/docker"
"gitlab.com/nunet/device-management-service/lib/hardware/gpu"
"gitlab.com/nunet/device-management-service/observability"
"gitlab.com/nunet/device-management-service/types"
)
// newGPUCommand builds the "gpu" parent command, grouping the GPU
// management operations (list, test).
func newGPUCommand() *cobra.Command {
	cmd := &cobra.Command{
		Use:   "gpu <operation>",
		Short: "Manage GPU resources",
		Long: `Available operations:
- list: List all available GPUs
- test: Test GPU deployment by running a docker container with GPU resources
`,
	}
	// AddCommand is variadic; register both subcommands in one call.
	cmd.AddCommand(newGPUListCommand(), newGPUTestCommand())
	return cmd
}
// newGPUListCommand builds the "gpu list" subcommand, which prints the
// model, total/used VRAM, core count, vendor, PCI address, UUID and index
// of every GPU detected on the host.
func newGPUListCommand() *cobra.Command {
	return &cobra.Command{
		Use:   "list",
		Short: "List all available GPUs",
		RunE: func(_ *cobra.Command, _ []string) error {
			manager := gpu.NewGPUManager()
			available, err := manager.GetGPUs()
			if err != nil {
				return fmt.Errorf("get gpus: %w", err)
			}
			usage, err := manager.GetGPUUsage()
			if err != nil {
				return fmt.Errorf("get GPU usage: %w", err)
			}
			if len(available) == 0 {
				// Nothing to print; record the fact in the structured log.
				log.Infow("no_gpus_detected_on_host",
					"labels", string(observability.LabelNode))
				return nil
			}
			// GetGPUs and GetGPUUsage must describe the same device set;
			// a mismatch indicates an internal inconsistency.
			if len(available) != len(usage) {
				return fmt.Errorf("internal error: GPU count mismatch")
			}
			fmt.Println("GPU Details:")
			for i := range available {
				g := available[i]
				fmt.Printf("Model: %s, Total VRAM: %d GB, Used VRAM: %d GB, Cores: %d, Vendor: %s, PCI Address: %s, UUID: %s, Index: %d\n",
					g.Model, g.VRAMInGB(), usage[i].VRAMInGB(), g.Cores, g.Vendor, g.PCIAddress, g.UUID, g.Index)
			}
			return nil
		},
	}
}
// newGPUTestCommand builds the "gpu test" subcommand. It selects the GPU
// with the most free VRAM, builds a vendor-specific Docker host
// configuration (NVIDIA, AMD/ATI or Intel), runs an ubuntu:20.04 container
// that lists the VGA controllers visible inside it, and streams the
// container output to stdout.
func newGPUTestCommand() *cobra.Command {
	return &cobra.Command{
		Use:   "test",
		Short: "Test GPU deployment by running a Docker container with GPU resources",
		RunE: func(_ *cobra.Command, _ []string) error {
			// One background context reused for all Docker calls below
			// (previously context.Background() was re-created per call).
			ctx := context.Background()
			dockerClient, err := docker.NewDockerClient()
			if err != nil {
				return fmt.Errorf("new docker client: %w", err)
			}
			gpus, err := gpu.NewGPUManager().GetGPUs()
			if err != nil {
				return fmt.Errorf("get gpus: %w", err)
			}
			if len(gpus) == 0 {
				return fmt.Errorf("no GPUs found")
			}
			// Run the test on the GPU with the most free VRAM.
			maxFreeVRAMGpu, err := gpus.MaxFreeVRAMGPU()
			if err != nil {
				return fmt.Errorf("get GPU with max free VRAM: %w", err)
			}
			// Fix: the original Printf was missing a trailing newline.
			fmt.Printf("Selected Vendor: %s, Device: %s\n", maxFreeVRAMGpu.Vendor, maxFreeVRAMGpu.Model)
			if maxFreeVRAMGpu.Vendor == types.GPUVendorNvidia {
				// Check if NVIDIA container toolkit is installed
				// We specifically look for the nvidia-container-toolkit executable because:
				// 1. It's the name of the main package installed via apt (nvidia-container-toolkit)
				// 2. It's the most reliable indicator of a proper toolkit installation
				// 3. Checking for this single file reduces the risk of false positives
				_, err = os.Stat("/usr/bin/nvidia-container-toolkit")
				if os.IsNotExist(err) {
					return fmt.Errorf("nvidia container toolkit is not installed. Please install it before running this command")
				}
			}
			imageName := "ubuntu:20.04"
			if !dockerClient.IsInstalled(ctx) {
				return fmt.Errorf("docker is not installed or running")
			}
			// Fix: "conainer" typo in the user-facing message.
			fmt.Printf("Creating the docker container for the image: %s\n", imageName)
			containerConfig := &container.Config{
				Image:        imageName,
				User:         "root",
				Tty:          true, // Enable TTY
				AttachStdout: true, // Attach stdout
				AttachStderr: true, // Attach stderr
				Entrypoint:   []string{""}, // Set entrypoint to run shell commands
				Cmd: []string{
					// This will show both the integrated and discrete GPUs
					"sh", "-c",
					"apt-get update && apt-get install -y pciutils && lspci | grep 'VGA compatible controller'",
				},
			}
			// Vendor-specific device wiring for the container.
			var hostConfig *container.HostConfig
			switch maxFreeVRAMGpu.Vendor {
			case types.GPUVendorNvidia:
				hostConfig = &container.HostConfig{
					AutoRemove: true,
					Resources: container.Resources{
						DeviceRequests: []container.DeviceRequest{
							{
								Driver:       "nvidia",
								Count:        -1, // -1 exposes all NVIDIA GPUs
								Capabilities: [][]string{{"gpu"}},
							},
						},
					},
				}
			case types.GPUVendorAMDATI:
				hostConfig = &container.HostConfig{
					AutoRemove: true,
					Binds: []string{
						"/dev/kfd:/dev/kfd",
						"/dev/dri:/dev/dri",
					},
					Resources: container.Resources{
						Devices: []container.DeviceMapping{
							{
								PathOnHost:        "/dev/kfd",
								PathInContainer:   "/dev/kfd",
								CgroupPermissions: "rwm",
							},
							{
								PathOnHost:        "/dev/dri",
								PathInContainer:   "/dev/dri",
								CgroupPermissions: "rwm",
							},
						},
					},
					GroupAdd: []string{"video"},
				}
			case types.GPUVendorIntel:
				hostConfig = &container.HostConfig{
					AutoRemove: true,
					Binds: []string{
						"/dev/dri:/dev/dri",
					},
					Resources: container.Resources{
						Devices: []container.DeviceMapping{
							{
								PathOnHost:        "/dev/dri",
								PathInContainer:   "/dev/dri",
								CgroupPermissions: "rwm",
							},
						},
					},
				}
			default:
				return fmt.Errorf("unknown GPU vendor: %s", maxFreeVRAMGpu.Vendor)
			}
			containerID, err := dockerClient.CreateContainer(ctx,
				containerConfig,
				hostConfig,
				nil,
				image.PullOptions{},
				nil,
				"nunet-gpu-test",
				true,
			)
			if err != nil {
				// Fix: the previous message said "pulling Docker image" for a
				// container-creation failure; also wrap with %w.
				return fmt.Errorf("creating docker container: %w", err)
			}
			fmt.Println("Container created with ID: ", containerID)
			if err := dockerClient.StartContainer(ctx, "nunet-gpu-test"); err != nil {
				return fmt.Errorf("starting docker container: %w", err)
			}
			// Wait for the container to finish execution
			statusCh, errCh := dockerClient.WaitContainer(ctx, containerID)
			select {
			case err := <-errCh:
				if err != nil {
					fmt.Printf("Container exited with error: %v\n", err)
				}
			case <-statusCh:
				fmt.Println("Container execution completed.")
			}
			reader, err := dockerClient.GetOutputStream(ctx, containerID, "", true)
			if err != nil {
				return fmt.Errorf("getting output stream: %w", err)
			}
			// Print the output stream
			if _, err := os.Stdout.ReadFrom(reader); err != nil {
				return fmt.Errorf("reading output stream: %w", err)
			}
			return nil
		},
	}
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package cmd
import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net/http"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/spf13/cobra"
"google.golang.org/protobuf/proto"
"gitlab.com/nunet/device-management-service/cmd/cli"
"gitlab.com/nunet/device-management-service/cmd/utils"
"gitlab.com/nunet/device-management-service/dms"
"gitlab.com/nunet/device-management-service/dms/node"
"gitlab.com/nunet/device-management-service/lib/crypto"
"gitlab.com/nunet/device-management-service/lib/crypto/keystore"
"gitlab.com/nunet/device-management-service/lib/did"
prismpb "gitlab.com/nunet/device-management-service/proto/generated/prism"
dmsUtils "gitlab.com/nunet/device-management-service/utils"
)
// newKeyCmd builds the "key" parent command and registers all key
// management subcommands (new, import, import-prism, create-prism, did,
// ledger-alias).
func newKeyCmd(
	dmsCli *cli.DmsCLI,
) *cobra.Command {
	keyCmd := &cobra.Command{
		Use:   "key",
		Short: "Manage cryptographic keys",
		Long: `Manage cryptographic keys for the Device Management Service (DMS).
This command provides subcommands for creating new keys and retrieving Decentralized Identifiers (DIDs) associated with existing keys.`,
	}
	for _, sub := range []*cobra.Command{
		newKeyNewCmd(dmsCli),
		newKeyImportCmd(dmsCli),
		newKeyImportPrismCmd(dmsCli),
		newKeyCreatePrismCmd(dmsCli),
		newKeyDIDCmd(dmsCli),
		newKeyLedgerAliasCmd(dmsCli),
	} {
		keyCmd.AddCommand(sub)
	}
	return keyCmd
}
// newKeyNewCmd builds the "key new" subcommand, which generates a fresh
// key pair, stores the private key in the user's local keystore under the
// given name (prompting before overwriting an existing key), and prints
// the resulting DID.
func newKeyNewCmd(
	dmsCli *cli.DmsCLI,
) *cobra.Command {
	return &cobra.Command{
		Use:   "new <name>",
		Short: "Generate a key pair",
		Long: `Generate a key pair and save the private key into the user's local keystore.
This command creates a new cryptographic key pair, stores the private key securely, and displays the associated Decentralized Identifier (DID). If a key with the specified name already exists, the user will be prompted to confirm before overwriting it.
Example:
nunet key new user`,
		Args: cobra.ExactArgs(1),
		RunE: func(cmd *cobra.Command, args []string) error {
			cfg, err := dmsCli.Config()
			if err != nil {
				return fmt.Errorf("get dms config: %w", err)
			}
			fs := dmsCli.FS()
			keyStoreDir := filepath.Join(cfg.General.UserDir, node.KeystoreDir)
			ks, err := keystore.New(fs, keyStoreDir, false)
			if err != nil {
				return fmt.Errorf("failed to create keystore: %w", err)
			}
			// cobra.ExactArgs(1) guarantees exactly one argument, so the
			// previous fallback to node.UserContextName was dead code.
			keyID := args[0]
			if ks.Exists(keyID) {
				confirmed, err := dmsUtils.PromptYesNo(
					cmd.InOrStdin(),
					cmd.OutOrStdout(),
					fmt.Sprintf("A key with name '%s' already exists. Do you want to overwrite it with a new one?", keyID),
				)
				if err != nil {
					return fmt.Errorf("failed to get user confirmation: %w", err)
				}
				if !confirmed {
					return dmsUtils.ErrOperationCancelled
				}
			}
			passphrase, err := dmsCli.NewPassphrase(keyID)
			if err != nil {
				return fmt.Errorf("get dms passphrase: %w", err)
			}
			priv, err := dms.GenerateAndStorePrivKey(ks, passphrase, keyID)
			if err != nil {
				return fmt.Errorf("failed to generate and store new private key: %w", err)
			}
			// Named userDID to avoid shadowing the imported "did" package.
			userDID := did.FromPublicKey(priv.GetPublic())
			fmt.Fprintln(cmd.OutOrStdout(), userDID)
			return nil
		},
	}
}
// newKeyImportCmd builds the "key import" subcommand, which stores an
// existing hex-encoded private key in the user's local keystore under the
// given name (prompting before overwriting an existing key) and prints the
// associated DID.
func newKeyImportCmd(
	dmsCli *cli.DmsCLI,
) *cobra.Command {
	return &cobra.Command{
		Use:   "import <name> <private-key-hex>",
		Short: "Import a private key",
		Long: `Import an existing private key into the user's local keystore.
This command takes a hex-encoded private key (in libp2p protobuf format or raw Ed25519 seed) and stores it securely with the given name.
If a key with the specified name already exists, the user will be prompted to confirm before overwriting it.
Example:
nunet key import myweb3key 08011240...`,
		Args: cobra.ExactArgs(2),
		RunE: func(cmd *cobra.Command, args []string) error {
			keyID := args[0]
			hexKey := args[1]
			rawPriv, err := hex.DecodeString(hexKey)
			if err != nil {
				return fmt.Errorf("invalid hex string: %w", err)
			}
			cfg, err := dmsCli.Config()
			if err != nil {
				return fmt.Errorf("get dms config: %w", err)
			}
			fs := dmsCli.FS()
			keyStoreDir := filepath.Join(cfg.General.UserDir, node.KeystoreDir)
			ks, err := keystore.New(fs, keyStoreDir, false)
			if err != nil {
				return fmt.Errorf("failed to create keystore: %w", err)
			}
			// Prompt before overwriting an existing key of the same name.
			if ks.Exists(keyID) {
				confirmed, err := dmsUtils.PromptYesNo(
					cmd.InOrStdin(),
					cmd.OutOrStdout(),
					fmt.Sprintf("A key with name '%s' already exists. Do you want to overwrite it?", keyID),
				)
				if err != nil {
					return fmt.Errorf("failed to get user confirmation: %w", err)
				}
				if !confirmed {
					return dmsUtils.ErrOperationCancelled
				}
			}
			passphrase, err := dmsCli.NewPassphrase(keyID)
			if err != nil {
				return fmt.Errorf("get dms passphrase: %w", err)
			}
			priv, err := dms.ImportAndStorePrivKey(ks, rawPriv, passphrase, keyID)
			if err != nil {
				return fmt.Errorf("failed to import and store private key: %w", err)
			}
			// Named importedDID to avoid shadowing the imported "did" package.
			importedDID := did.FromPublicKey(priv.GetPublic())
			fmt.Fprintln(cmd.OutOrStdout(), importedDID)
			return nil
		},
	}
}
// newKeyDIDCmd builds the "key did" subcommand, which prints the DID for a
// keystore key, a Ledger account (selected by index or named alias), or an
// Eternl wallet.
func newKeyDIDCmd(
	dmsCli *cli.DmsCLI,
) *cobra.Command {
	return &cobra.Command{
		Use:   "did <name>",
		Short: "Get a key's DID",
		Long: `Get the DID (Decentralized Identifier) for a specified key.
This command retrieves the DID associated with either a key stored in the local keystore
or a hardware ledger. For the ledger you can now supply an account index or a named alias.
Examples:
nunet key did user # key from keystore
nunet key did ledger # ledger account 0 (default)
nunet key did ledger:3 # ledger account 3
nunet key did ledger:business # ledger alias "business"`,
		Args: cobra.ExactArgs(1),
		RunE: func(cmd *cobra.Command, args []string) error {
			cfg, err := dmsCli.ConfigLoader().Load()
			if err != nil {
				return fmt.Errorf("get dms config: %w", err)
			}
			fs := dmsCli.FS()
			env := dmsCli.Env()
			keyName := args[0]
			// Ledger branch: resolve the account index (numeric or alias)
			// and derive the DID from the hardware wallet.
			if node.IsLedgerContext(keyName) {
				idx, err := node.ResolveLedgerIndex(
					fs, cfg.General.UserDir, node.GetContextKey(keyName),
				)
				if err != nil {
					return err
				}
				provider, err := did.NewLedgerWalletProvider(idx)
				if err != nil {
					return err
				}
				// Fix: write to the command's output stream (not os.Stdout)
				// for consistency with the keystore branch below.
				fmt.Fprintln(cmd.OutOrStdout(), provider.DID())
				return nil
			}
			if node.IsEternlContext(keyName) {
				provider, err := did.NewEternlWalletProvider()
				if err != nil {
					return err
				}
				fmt.Fprintln(cmd.OutOrStdout(), provider.DID())
				return nil
			}
			// Keystore branch: load the named key and derive its DID.
			keyStoreDir := filepath.Join(cfg.General.UserDir, node.KeystoreDir)
			ks, err := keystore.New(fs, keyStoreDir, false)
			if err != nil {
				return fmt.Errorf("failed to open keystore: %w", err)
			}
			passphrase, err := utils.GetDMSPassphrase(env, false)
			if err != nil {
				return fmt.Errorf("get dms passphrase: %w", err)
			}
			key, err := ks.Get(keyName, passphrase)
			if err != nil {
				return fmt.Errorf("failed to get key: %w", err)
			}
			priv, err := key.PrivKey()
			if err != nil {
				return fmt.Errorf("unable to convert key from keystore to private key: %w", err)
			}
			// Named keyDID to avoid shadowing the imported "did" package.
			keyDID := did.FromPublicKey(priv.GetPublic())
			fmt.Fprintln(cmd.OutOrStdout(), keyDID)
			return nil
		},
	}
}
// newKeyImportPrismCmd builds the "key import-prism" subcommand, which
// imports an existing PRISM identity (a PRISM DID plus its hex-encoded
// private key) into the local keystore and records the DID association so
// the key is used with that PRISM DID when signing UCAN tokens.
func newKeyImportPrismCmd(
	dmsCli *cli.DmsCLI,
) *cobra.Command {
	return &cobra.Command{
		Use:   "import-prism <name> <prism-did> <private-key-hex>",
		Short: "Import a PRISM identity (DID + private key)",
		Long: `Import an existing PRISM identity into the user's local keystore.
This command takes a PRISM DID and a hex-encoded private key (in libp2p protobuf format or raw Ed25519/secp256k1 key)
and stores it securely with the given name. The private key will be associated with the PRISM DID for signing UCAN tokens.
Supported key formats:
- libp2p protobuf format (hex encoded)
- Raw Ed25519 seed (32 bytes, hex encoded)
- Raw Ed25519 private key (64 bytes, hex encoded)
- Raw secp256k1 private key (32 bytes, hex encoded)
If a key with the specified name already exists, the user will be prompted to confirm before overwriting it.
Example:
nunet key import-prism myprism did:prism:9b5118411248d9663b6ab15128fba8106511230ff654e7514cdcc4ce919bde9b 08011240...`,
		Args: cobra.ExactArgs(3),
		RunE: func(cmd *cobra.Command, args []string) error {
			keyID := args[0]
			prismDIDStr := args[1]
			hexKey := args[2]
			// Parse PRISM DID
			prismDID, err := did.FromString(prismDIDStr)
			if err != nil {
				return fmt.Errorf("invalid PRISM DID: %w", err)
			}
			// Only DIDs with the "prism" method are accepted by this command.
			if prismDID.Method() != "prism" {
				return fmt.Errorf("expected PRISM DID (did:prism:...), got %s", prismDID.Method())
			}
			// Decode private key
			rawPriv, err := hex.DecodeString(hexKey)
			if err != nil {
				return fmt.Errorf("invalid hex string: %w", err)
			}
			cfg, err := dmsCli.Config()
			if err != nil {
				return fmt.Errorf("get dms config: %w", err)
			}
			fs := dmsCli.FS()
			keyStoreDir := filepath.Join(cfg.General.UserDir, node.KeystoreDir)
			ks, err := keystore.New(fs, keyStoreDir, false)
			if err != nil {
				return fmt.Errorf("failed to create keystore: %w", err)
			}
			// Ask for confirmation before overwriting an existing key entry.
			if ks.Exists(keyID) {
				confirmed, err := dmsUtils.PromptYesNo(
					cmd.InOrStdin(),
					cmd.OutOrStdout(),
					fmt.Sprintf("A key with name '%s' already exists. Do you want to overwrite it?", keyID),
				)
				if err != nil {
					return fmt.Errorf("failed to get user confirmation: %w", err)
				}
				if !confirmed {
					return dmsUtils.ErrOperationCancelled
				}
			}
			passphrase, err := dmsCli.NewPassphrase(keyID)
			if err != nil {
				return fmt.Errorf("get dms passphrase: %w", err)
			}
			// Import and store the private key (same as regular import)
			priv, err := dms.ImportAndStorePrivKey(ks, rawPriv, passphrase, keyID)
			if err != nil {
				return fmt.Errorf("failed to import and store private key: %w", err)
			}
			// Verify we can create a PRISM provider
			provider, err := did.ProviderFromPRISMPrivateKey(prismDID, priv)
			if err != nil {
				return fmt.Errorf("failed to create PRISM provider: %w", err)
			}
			// Store the PRISM DID association
			if err := node.SetPrismDID(fs, cfg.General.UserDir, keyID, prismDIDStr); err != nil {
				return fmt.Errorf("failed to store PRISM DID association: %w", err)
			}
			fmt.Fprintf(cmd.OutOrStdout(), "PRISM identity imported successfully\n")
			fmt.Fprintf(cmd.OutOrStdout(), "DID: %s\n", provider.DID())
			fmt.Fprintf(cmd.OutOrStdout(), "Key name: %s\n", keyID)
			fmt.Fprintf(cmd.OutOrStdout(), "\nThe PRISM DID association has been stored.\n")
			fmt.Fprintf(cmd.OutOrStdout(), "This key will be used with the PRISM DID when signing UCAN tokens.\n")
			return nil
		},
	}
}
// newKeyCreatePrismCmd builds the "key create-prism" subcommand, which
// generates a key pair locally, creates and submits a signed PRISM DID
// operation to NeoPRISM, optionally waits for on-chain confirmation, and
// finally imports the resulting identity into the local keystore with its
// PRISM DID association.
func newKeyCreatePrismCmd(
	dmsCli *cli.DmsCLI,
) *cobra.Command {
	// Flag storage; values are bound to cmd.Flags() at the bottom of this
	// function and read inside RunE.
	var (
		prismURL          string
		keyType           string
		waitConfirmation  bool
		timeout           string
		submissionTimeout string
	)
	cmd := &cobra.Command{
		Use:   "create-prism <name>",
		Short: "Create a new PRISM identity",
		Long: `Create a new PRISM identity by generating keys, creating a PRISM DID operation,
submitting it to the PRISM network via NeoPRISM, and importing it into DMS.
This command automates the complete PRISM identity creation workflow:
1. Generates cryptographic keys locally (Secp256k1 or Ed25519)
2. Creates a PRISM DID operation with the generated keys
3. Submits the operation to NeoPRISM (which handles blockchain transaction fees)
4. Optionally waits for blockchain confirmation
5. Imports the created identity into DMS
The command returns all necessary credentials (DID, public/private keys) and sets up
the identity for use with UCAN capabilities.
Note: PRISM identities are independent of Cardano wallets. NeoPRISM handles wallet
management and transaction fees when submitting operations to the blockchain.
Example:
nunet key create-prism myprism`,
		Args: cobra.ExactArgs(1),
		RunE: func(cmd *cobra.Command, args []string) error {
			keyID := args[0]
			// Validate key type
			if keyType != "secp256k1" && keyType != "ed25519" {
				return fmt.Errorf("invalid key type: %s (supported: secp256k1, ed25519)", keyType)
			}
			// Parse timeout
			timeoutDuration, err := time.ParseDuration(timeout)
			if err != nil {
				return fmt.Errorf("invalid timeout duration: %w", err)
			}
			fmt.Fprintf(cmd.OutOrStdout(), "🔑 Generating %s key pair...\n", keyType)
			// Generate keys
			privKey, pubKey, err := generatePRISMKeys(keyType)
			if err != nil {
				return fmt.Errorf("generate keys: %w", err)
			}
			fmt.Fprintf(cmd.OutOrStdout(), "✅ Key pair generated successfully\n\n")
			fmt.Fprintf(cmd.OutOrStdout(), "📝 Creating PRISM DID operation...\n")
			// Create PRISM operation
			signedOpHex, err := did.CreateSignedPRISMOperationSimple(privKey, pubKey, "master-0")
			if err != nil {
				return fmt.Errorf("create PRISM operation: %w", err)
			}
			fmt.Fprintf(cmd.OutOrStdout(), "✅ PRISM operation created\n\n")
			fmt.Fprintf(cmd.OutOrStdout(), "🔍 Extracting DID from operation...\n")
			// Extract DID from operation
			prismDIDStr, err := extractDIDFromSignedOperation(signedOpHex)
			if err != nil {
				return fmt.Errorf("extract DID from operation: %w", err)
			}
			// Parse PRISM DID
			prismDID, err := did.FromString(prismDIDStr)
			if err != nil {
				return fmt.Errorf("parse PRISM DID: %w", err)
			}
			fmt.Fprintf(cmd.OutOrStdout(), "✅ DID extracted: %s\n\n", prismDIDStr)
			// Parse submission timeout
			submissionTimeoutDuration := 2 * time.Minute // Default 2 minutes (should be fast if working)
			if submissionTimeout != "" {
				parsedTimeout, err := time.ParseDuration(submissionTimeout)
				if err != nil {
					return fmt.Errorf("invalid submission timeout duration: %w", err)
				}
				submissionTimeoutDuration = parsedTimeout
			}
			// Quick connectivity check
			fmt.Fprintf(cmd.OutOrStdout(), "🔍 Checking NeoPRISM connectivity at %s...\n", prismURL)
			if err := checkNeoPRISMConnectivity(prismURL, 5*time.Second); err != nil {
				return fmt.Errorf("NeoPRISM connectivity check failed: %w\n\nTroubleshooting:\n- Ensure NeoPRISM is running: docker ps | grep neoprism\n- Check NeoPRISM is accessible at %s\n- Verify NeoPRISM logs: docker logs neoprism --tail 50", err, prismURL)
			}
			fmt.Fprintf(cmd.OutOrStdout(), "✅ NeoPRISM is reachable\n\n")
			fmt.Fprintf(cmd.OutOrStdout(), "📤 Submitting operation to NeoPRISM (timeout: %s)...\n", submissionTimeoutDuration)
			// Submit to NeoPRISM
			txID, operationIDs, err := submitPRISMOperationToNeoPRISM(prismURL, signedOpHex, submissionTimeoutDuration, cmd.OutOrStdout())
			if err != nil {
				return fmt.Errorf("submit to NeoPRISM: %w\n\nTroubleshooting:\n- Ensure NeoPRISM is running at %s\n- Check NeoPRISM logs for errors\n- Verify NeoPRISM has sufficient funds for transaction fees\n- Try increasing --submission-timeout if the operation is large", err, prismURL)
			}
			fmt.Fprintf(cmd.OutOrStdout(), "✅ Operation submitted successfully\n\n")
			// Print PRISM identity information
			fmt.Fprintf(cmd.OutOrStdout(), "📋 PRISM Identity Information:\n")
			fmt.Fprintf(cmd.OutOrStdout(), " DID: %s\n", prismDIDStr)
			fmt.Fprintf(cmd.OutOrStdout(), " Key Name: %s\n", keyID)
			fmt.Fprintf(cmd.OutOrStdout(), " Key Type: %s\n", keyType)
			if txID != "" {
				fmt.Fprintf(cmd.OutOrStdout(), " Transaction ID: %s\n", txID)
			}
			if len(operationIDs) > 0 {
				fmt.Fprintf(cmd.OutOrStdout(), " Operation IDs: %s\n", strings.Join(operationIDs, ", "))
			}
			fmt.Fprintf(cmd.OutOrStdout(), "\n")
			// Optionally wait for confirmation
			if waitConfirmation {
				fmt.Fprintf(cmd.OutOrStdout(), "⏳ Waiting for blockchain confirmation (timeout: %s)...\n", timeout)
				ctx, cancel := context.WithTimeout(context.Background(), timeoutDuration)
				defer cancel()
				err = waitForDIDDocument(ctx, prismDID, prismURL, timeoutDuration, cmd.OutOrStdout())
				if err != nil {
					return fmt.Errorf("wait for DID document: %w", err)
				}
				fmt.Fprintf(cmd.OutOrStdout(), "✅ DID document confirmed on blockchain\n\n")
			} else {
				fmt.Fprintf(cmd.OutOrStdout(), "⏭️ Skipping confirmation wait (use --wait-confirmation to enable)\n\n")
			}
			// Get config and file system
			cfg, err := dmsCli.Config()
			if err != nil {
				return fmt.Errorf("get dms config: %w", err)
			}
			fs := dmsCli.FS()
			// Setup keystore
			keyStoreDir := filepath.Join(cfg.General.UserDir, node.KeystoreDir)
			ks, err := keystore.New(fs, keyStoreDir, false)
			if err != nil {
				return fmt.Errorf("failed to create keystore: %w", err)
			}
			// Check if key exists; prompt before overwriting it.
			if ks.Exists(keyID) {
				confirmed, err := dmsUtils.PromptYesNo(
					cmd.InOrStdin(),
					cmd.OutOrStdout(),
					fmt.Sprintf("A key with name '%s' already exists. Do you want to overwrite it?", keyID),
				)
				if err != nil {
					return fmt.Errorf("failed to get user confirmation: %w", err)
				}
				if !confirmed {
					return dmsUtils.ErrOperationCancelled
				}
			}
			// Marshal private key for storage
			rawPriv, err := crypto.PrivateKeyToBytes(privKey)
			if err != nil {
				return fmt.Errorf("marshal private key: %w", err)
			}
			fmt.Fprintf(cmd.OutOrStdout(), "💾 Importing identity into DMS...\n")
			// Get passphrase
			passphrase, err := dmsCli.NewPassphrase(keyID)
			if err != nil {
				return fmt.Errorf("get dms passphrase: %w", err)
			}
			// Import and store the private key
			priv, err := dms.ImportAndStorePrivKey(ks, rawPriv, passphrase, keyID)
			if err != nil {
				return fmt.Errorf("failed to import and store private key: %w", err)
			}
			// Verify we can create a PRISM provider
			provider, err := did.ProviderFromPRISMPrivateKey(prismDID, priv)
			if err != nil {
				return fmt.Errorf("failed to create PRISM provider: %w", err)
			}
			// Store the PRISM DID association
			if err := node.SetPrismDID(fs, cfg.General.UserDir, keyID, prismDIDStr); err != nil {
				return fmt.Errorf("failed to store PRISM DID association: %w", err)
			}
			fmt.Fprintf(cmd.OutOrStdout(), "✅ Identity imported successfully\n\n")
			// Output results
			fmt.Fprintf(cmd.OutOrStdout(), "✅ PRISM identity created successfully\n\n")
			fmt.Fprintf(cmd.OutOrStdout(), "DID: %s\n", provider.DID())
			fmt.Fprintf(cmd.OutOrStdout(), "Key name: %s\n", keyID)
			if txID != "" {
				fmt.Fprintf(cmd.OutOrStdout(), "Transaction ID: %s\n", txID)
			}
			if len(operationIDs) > 0 {
				fmt.Fprintf(cmd.OutOrStdout(), "Operation IDs: %s\n", strings.Join(operationIDs, ", "))
			}
			fmt.Fprintf(cmd.OutOrStdout(), "\n⚠️ IMPORTANT: Your private key is stored in the keystore. Make sure to back up\n")
			fmt.Fprintf(cmd.OutOrStdout(), " your keystore directory if you need to recover this identity.\n\n")
			fmt.Fprintf(cmd.OutOrStdout(), "The identity has been imported into DMS and is ready to use with UCAN.\n")
			return nil
		},
	}
	cmd.Flags().StringVar(&prismURL, "prism-url", "http://localhost:8080", "PRISM resolver/submitter URL")
	cmd.Flags().StringVar(&keyType, "key-type", "secp256k1", "Key type for PRISM operation (ed25519 or secp256k1)")
	cmd.Flags().BoolVar(&waitConfirmation, "wait-confirmation", true, "Wait for blockchain confirmation before returning")
	cmd.Flags().StringVar(&timeout, "timeout", "20m", "Timeout for blockchain confirmation")
	cmd.Flags().StringVar(&submissionTimeout, "submission-timeout", "2m", "Timeout for submitting operation to NeoPRISM")
	return cmd
}
// generatePRISMKeys generates a key pair for PRISM identity creation.
// Supported key types are "secp256k1" and "ed25519"; any other value
// yields an error.
func generatePRISMKeys(keyType string) (crypto.PrivKey, crypto.PubKey, error) {
	var kt int
	switch keyType {
	case "secp256k1":
		kt = crypto.Secp256k1
	case "ed25519":
		kt = crypto.Ed25519
	default:
		return nil, nil, fmt.Errorf("unsupported key type: %s", keyType)
	}
	priv, pub, err := crypto.GenerateKeyPair(kt)
	if err != nil {
		return nil, nil, fmt.Errorf("generate key pair: %w", err)
	}
	return priv, pub, nil
}
// checkNeoPRISMConnectivity performs a quick connectivity check against a
// NeoPRISM instance by POSTing a deliberately invalid payload to the
// submission endpoint. Any HTTP response — including an error status such
// as 422 for the bogus test data — proves the service is reachable; only
// transport-level failures are reported as errors.
func checkNeoPRISMConnectivity(neoprismURL string, timeout time.Duration) error {
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()
	endpoint := fmt.Sprintf("%s/api/signed-operation-submissions", strings.TrimSuffix(neoprismURL, "/"))
	payload := []byte(`{"signed_operations":["test"]}`)
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(payload))
	if err != nil {
		return fmt.Errorf("create test request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	httpClient := &http.Client{Timeout: timeout}
	resp, err := httpClient.Do(req)
	if err != nil {
		// Distinguish a deadline hit from other transport failures.
		if ctx.Err() == context.DeadlineExceeded {
			return fmt.Errorf("connection timeout - NeoPRISM may be unreachable or slow to respond")
		}
		return fmt.Errorf("connection failed: %w", err)
	}
	defer resp.Body.Close()
	// Receiving any response at all means NeoPRISM is reachable, so no
	// further status-code inspection is needed here.
	return nil
}
// extractDIDFromSignedOperation derives the canonical PRISM DID from a
// hex-encoded SignedPrismOperation. The DID suffix is the hexadecimal
// SHA256 hash (64 chars) of the serialized inner operation, giving the
// format did:prism:{64-char-hex} expected by NeoPRISM.
func extractDIDFromSignedOperation(signedOpHex string) (string, error) {
	raw, err := hex.DecodeString(signedOpHex)
	if err != nil {
		return "", fmt.Errorf("decode hex: %w", err)
	}
	var signedOp prismpb.SignedPrismOperation
	if err := proto.Unmarshal(raw, &signedOp); err != nil {
		return "", fmt.Errorf("unmarshal signed operation: %w", err)
	}
	if signedOp.Operation == nil {
		return "", fmt.Errorf("operation is nil")
	}
	// The DID is derived from the operation bytes alone, not the signature.
	opBytes, err := proto.Marshal(signedOp.Operation)
	if err != nil {
		return "", fmt.Errorf("marshal operation: %w", err)
	}
	digest := sha256.Sum256(opBytes)
	return "did:prism:" + hex.EncodeToString(digest[:]), nil
}
// submitPRISMOperationToNeoPRISM submits a hex-encoded signed PRISM
// operation to a NeoPRISM instance and returns the resulting transaction
// ID and operation IDs as reported by the service. Progress messages are
// written to output when it is non-nil.
func submitPRISMOperationToNeoPRISM(neoprismURL string, signedOpHex string, timeout time.Duration, output io.Writer) (string, []string, error) {
	// Create context with timeout
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()
	// NeoPRISM API format
	submitURL := fmt.Sprintf("%s/api/signed-operation-submissions", strings.TrimSuffix(neoprismURL, "/"))
	reqBody := map[string]interface{}{
		"signed_operations": []string{signedOpHex},
	}
	bodyBytes, err := json.Marshal(reqBody)
	if err != nil {
		return "", nil, fmt.Errorf("marshal request: %w", err)
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, submitURL, bytes.NewReader(bodyBytes))
	if err != nil {
		return "", nil, fmt.Errorf("create request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	// Use a client with timeout
	client := &http.Client{
		Timeout: timeout,
	}
	// Log that we're sending the request
	if output != nil {
		fmt.Fprintf(output, " Sending request to %s...\n", submitURL)
	}
	// Time the round trip so failures can report how long the request took.
	startTime := time.Now()
	resp, err := client.Do(req)
	requestTime := time.Since(startTime)
	if err != nil {
		// Check if it's a context timeout
		if ctx.Err() == context.DeadlineExceeded {
			return "", nil, fmt.Errorf("submission timeout after %v: NeoPRISM did not respond in time. This usually means:\n- NeoPRISM is processing the transaction (can take time on blockchain)\n- NeoPRISM is unavailable or overloaded\n- Network connectivity issues\n\nTry: curl -X POST %s/api/signed-operation-submissions -H 'Content-Type: application/json' -d '{\"signed_operations\":[\"test\"]}' to test connectivity", requestTime, submitURL)
		}
		return "", nil, fmt.Errorf("submit operation (took %v): %w", requestTime, err)
	}
	defer resp.Body.Close()
	if output != nil {
		fmt.Fprintf(output, " Received response (status %d) in %v\n", resp.StatusCode, requestTime)
	}
	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", nil, fmt.Errorf("read response: %w", err)
	}
	// Only a 200 counts as success; include the body text for diagnostics.
	if resp.StatusCode != http.StatusOK {
		return "", nil, fmt.Errorf("submission failed with status %d: %s", resp.StatusCode, string(respBody))
	}
	// Parse NeoPRISM response
	var submitResponse struct {
		TxID         string   `json:"tx_id"`
		OperationIDs []string `json:"operation_ids"`
	}
	if err := json.Unmarshal(respBody, &submitResponse); err != nil {
		return "", nil, fmt.Errorf("parse response: %w", err)
	}
	return submitResponse.TxID, submitResponse.OperationIDs, nil
}
// waitForDIDDocument polls the PRISM resolver until the DID document for
// prismDID becomes resolvable, the timeout elapses, or ctx is cancelled.
// Progress is logged to output roughly every five seconds. The resolver is
// temporarily pointed at prismURL and its previous configuration restored
// on return.
func waitForDIDDocument(ctx context.Context, prismDID did.DID, prismURL string, timeout time.Duration, output io.Writer) error {
	// Configure resolver, restoring the previous configuration when done.
	originalConfig := did.GetPRISMResolverConfig()
	defer did.SetPRISMResolverConfig(originalConfig)
	did.SetPRISMResolverConfig(did.PRISMResolverConfig{
		ResolverURL:                 prismURL,
		PreferredVerificationMethod: "authentication",
	})
	// Create a context with timeout
	ctx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()
	// Poll for DID document at a fixed interval.
	retryDelay := 2 * time.Second
	maxRetries := int(timeout / retryDelay)
	if maxRetries < 1 {
		maxRetries = 1
	}
	var lastErr error
	// timeoutErr builds the cancellation error, wrapping the last resolver
	// error when one is available (this was previously duplicated inline).
	timeoutErr := func() error {
		if lastErr != nil {
			return fmt.Errorf("timeout waiting for DID document: %w", lastErr)
		}
		return fmt.Errorf("timeout waiting for DID document")
	}
	lastLogTime := time.Now()
	logInterval := 5 * time.Second // Log progress every 5 seconds
	for i := 0; i < maxRetries; i++ {
		// Bail out promptly if the context has already been cancelled.
		select {
		case <-ctx.Done():
			return timeoutErr()
		default:
		}
		// Try to resolve the DID. GetAnchorForDID succeeding implies the
		// DID document exists, so no further inspection of the anchor is
		// needed (the previous `_ = anchor` dead assignment is gone).
		if _, err := did.GetAnchorForDID(prismDID); err == nil {
			return nil
		} else { //nolint:revive // err is scoped to this statement
			lastErr = err
		}
		// Log progress periodically
		if now := time.Now(); now.Sub(lastLogTime) >= logInterval {
			fmt.Fprintf(output, " Still waiting... (attempt %d/%d)\n", i+1, maxRetries)
			lastLogTime = now
		}
		// Wait before retrying (except on last attempt)
		if i < maxRetries-1 {
			select {
			case <-ctx.Done():
				return timeoutErr()
			case <-time.After(retryDelay):
				// Continue to next iteration
			}
		}
	}
	errMsg := fmt.Sprintf("failed to resolve DID document after %d attempts", maxRetries)
	if lastErr != nil {
		return fmt.Errorf("%s: %w", errMsg, lastErr)
	}
	return fmt.Errorf("%s", errMsg)
}
// newKeyLedgerAliasCmd builds the `ledger-alias` parent command that groups
// the subcommands for managing Ledger account aliases.
func newKeyLedgerAliasCmd(dmsCli *cli.DmsCLI) *cobra.Command {
	aliasCmd := &cobra.Command{
		Use:   "ledger-alias",
		Short: "Manage aliases for Ledger accounts",
	}
	aliasCmd.AddCommand(newKeyLedgerAliasSetCmd(dmsCli))
	return aliasCmd
}
// Child: `nunet key ledger-alias set <alias> <index>`
// newKeyLedgerAliasSetCmd builds the `set` subcommand, which creates or
// updates an alias pointing at a Ledger account index.
func newKeyLedgerAliasSetCmd(dmsCli *cli.DmsCLI) *cobra.Command {
	return &cobra.Command{
		Use:   "set <alias> <index>",
		Short: "Create or update a Ledger alias",
		Args:  cobra.ExactArgs(2),
		RunE: func(cmd *cobra.Command, args []string) error {
			alias, rawIndex := args[0], args[1]
			// The index must parse as a non-negative account number.
			idx, convErr := strconv.Atoi(rawIndex)
			if convErr != nil || idx < 0 {
				return fmt.Errorf("index must be a non-negative integer")
			}
			cfg, err := dmsCli.ConfigLoader().Load()
			if err != nil {
				return fmt.Errorf("get dms config: %w", err)
			}
			if err = node.SetLedgerAlias(dmsCli.FS(), cfg.General.UserDir, alias, idx); err != nil {
				return err
			}
			fmt.Fprintf(cmd.OutOrStdout(), "Alias %q → account %d saved\n", alias, idx)
			return nil
		},
	}
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package cmd
import (
"context"
"encoding/base64"
"fmt"
"io"
"net"
"os"
"os/signal"
"path"
"strings"
"syscall"
"time"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"golang.org/x/crypto/ssh"
"golang.org/x/term"
"gitlab.com/nunet/device-management-service/client"
"gitlab.com/nunet/device-management-service/cmd/cli"
"gitlab.com/nunet/device-management-service/cmd/utils"
jobtypes "gitlab.com/nunet/device-management-service/dms/jobs/types"
"gitlab.com/nunet/device-management-service/dms/node"
dmsUtils "gitlab.com/nunet/device-management-service/utils"
)
// Allocation is the CLI-facing view of a single allocation participating in
// a deployment network, built from the deployment manifest.
type Allocation struct {
	// Alloc is the allocation identifier within the deployment.
	Alloc string
	// PortMapping holds the allocation's port mappings from the manifest's
	// Ports field — presumably host→allocation; TODO confirm direction.
	PortMapping map[int]int
	// DNSName is the hostname assigned to the allocation.
	DNSName string
	// IP is the allocation's private address (manifest PrivAddr).
	IP string
	// Status is the allocation's status rendered as a string.
	Status string
}
// DeploymentNetwork is the CLI-facing view of a deployment's network: the
// deployment ID plus the allocations belonging to its subnet.
type DeploymentNetwork struct {
	// ID is the deployment (manifest) identifier.
	ID string
	// Allocations lists the allocations that are part of the network.
	Allocations []Allocation
}
// newNetworkCommand builds the `network` parent command that groups the
// network utility subcommands (ls, show, attach).
//
// Fix: the command variable was misleadingly named gpuCmd (a copy-paste
// leftover from the GPU command); renamed to networkCmd.
func newNetworkCommand(dmsCli *cli.DmsCLI) *cobra.Command {
	networkCmd := &cobra.Command{
		Use:   "network <cmd>",
		Short: "Network Utility Tool",
		Long: `Available operations:
- ls: List all Networks the DMS is part of
- show: Show details of a specific Network
- attach: Attach to the Network.
`,
	}
	networkCmd.AddCommand(newNetworkListCommand(dmsCli))
	networkCmd.AddCommand(newNetworkShowCommand(dmsCli))
	networkCmd.AddCommand(newNetworkAttachCommand(dmsCli))
	return networkCmd
}
// networkListOpts holds the flag values for `network ls`.
type networkListOpts struct {
	// Context is the capability context name (--context).
	Context string
	// Verbose enables per-allocation detail output (--verbose).
	Verbose bool
}
// newNetworkListCommand builds the `ls` subcommand, which lists the
// deployment networks (running deployments whose orchestrator joined the
// subnet) this DMS participates in.
//
// Fix: the "Deployment Networks" header was printed with fmt.Println
// (os.Stdout) while all other output goes to cmd.OutOrStdout(); it now uses
// the command's writer so redirection and tests capture it.
func newNetworkListCommand(dmsCli *cli.DmsCLI) *cobra.Command {
	opts := networkListOpts{}
	cmd := &cobra.Command{
		Use:   "ls",
		Short: "List all Networks",
		RunE: func(cmd *cobra.Command, _ []string) error {
			sctx, err := utils.NewSecurityContext(dmsCli, opts.Context)
			if err != nil {
				return fmt.Errorf("could not create security context: %w", err)
			}
			// Now call newClient with the correct arguments
			client, err := dmsCli.NewClient(sctx)
			if err != nil {
				return fmt.Errorf("could not create client: %w", err)
			}
			ids, err := getDeploymentIDs(cmd.Context(), client)
			if err != nil {
				return fmt.Errorf("error getting deployment IDs: %w", err)
			}
			depNets, err := getNetworkList(cmd.Context(), ids, client)
			if err != nil {
				return fmt.Errorf("error getting network list: %w", err)
			}
			if len(depNets) == 0 {
				fmt.Fprintln(cmd.OutOrStdout(), "No Deployment Networks")
				return nil
			}
			fmt.Fprintln(cmd.OutOrStdout(), "Deployment Networks")
			// TODO Format
			for _, dn := range depNets {
				fmt.Fprintf(cmd.OutOrStdout(), "ID: %s\n", dn.ID)
				fmt.Fprintf(cmd.OutOrStdout(), " Allocations in Network=%d\n", len(dn.Allocations))
				if opts.Verbose {
					for _, a := range dn.Allocations {
						fmt.Fprintf(cmd.OutOrStdout(), " Alloc: %s\n", a.Alloc)
						fmt.Fprintf(cmd.OutOrStdout(), " IP: %s\n", a.IP)
						fmt.Fprintf(cmd.OutOrStdout(), " Hostname: %s\n", a.DNSName)
						fmt.Fprintf(cmd.OutOrStdout(), " Ports: %+v\n", a.PortMapping)
						fmt.Fprintf(cmd.OutOrStdout(), " Status: %+v\n", a.Status)
					}
				}
			}
			fmt.Fprintf(cmd.OutOrStdout(), "\n")
			return nil
		},
	}
	cmd.Flags().StringVarP(&opts.Context, "context", "c", node.DefaultContextName, "specify a capability context")
	cmd.Flags().BoolVarP(&opts.Verbose, "verbose", "v", false, "verbose output")
	if err := cmd.MarkFlagRequired("context"); err != nil {
		log.Fatalf("unable to mark flag 'context' as required: %v", err)
	}
	return cmd
}
// networkShowOpts holds the flag values for `network show`.
type networkShowOpts struct {
	// Context is the capability context name (--context).
	Context string
	// ID is the deployment ID to show (--id).
	ID string
}
// newNetworkShowCommand builds the `show` subcommand, which prints the
// allocations of a single deployment network.
func newNetworkShowCommand(dmsCli *cli.DmsCLI) *cobra.Command {
	opts := networkShowOpts{}
	cmd := &cobra.Command{
		Use:   "show",
		Short: "Show details of a specific Network",
		RunE: func(cmd *cobra.Command, _ []string) error {
			sctx, err := utils.NewSecurityContext(dmsCli, opts.Context)
			if err != nil {
				return fmt.Errorf("could not create security context: %w", err)
			}
			// Build an API client bound to this security context.
			dmsClient, err := dmsCli.NewClient(sctx)
			if err != nil {
				return fmt.Errorf("could not create client: %w", err)
			}
			depNet, err := getNetwork(cmd.Context(), opts.ID, dmsClient)
			if err != nil {
				return fmt.Errorf("error getting network detail: %w", err)
			}
			// TODO Format
			out := cmd.OutOrStdout()
			fmt.Fprintf(out, "ID: %s\n", depNet.ID)
			fmt.Fprintf(out, " Allocations in Network\n")
			for _, a := range depNet.Allocations {
				fmt.Fprintf(out, " Alloc: %s\n", a.Alloc)
				fmt.Fprintf(out, " IP: %s\n", a.IP)
				fmt.Fprintf(out, " Hostname: %s\n", a.DNSName)
				fmt.Fprintf(out, " Ports: %+v\n", a.PortMapping)
				fmt.Fprintf(out, " Status: %+v\n", a.Status)
			}
			return nil
		},
	}
	cmd.Flags().StringVarP(&opts.Context, "context", "c", node.DefaultContextName, "Capability Context")
	cmd.Flags().StringVarP(&opts.ID, "id", "i", "", "Deployment ID")
	if err := cmd.MarkFlagRequired("context"); err != nil {
		log.Fatalf("unable to mark flag 'context' as required: %v", err)
	}
	if err := cmd.MarkFlagRequired("id"); err != nil {
		log.Fatalf("unable to mark flag 'id' as required: %v", err)
	}
	return cmd
}
// networkAttachOpts holds the flag values for `network attach`.
type networkAttachOpts struct {
	Context  string // capability context name (--context)
	ID       string // deployment ID (--id)
	Alloc    string // allocation name to attach to (--alloc)
	Shell    bool   // attach an interactive SSH shell (--shell)
	Forward  bool   // attach a port forwarder (--forward); not yet implemented
	Username string // SSH username (--username)
	Identity string // path to the SSH private key (--identity)
	Port     string // SSH port, defaults to "22" (--port)
	PortMap  string // <host:alloc> port mapping for the forwarder (--portmap)
}
// newNetworkAttachCommand builds the `attach` subcommand, which attaches to
// an allocation in a deployment network. Currently only --shell (an
// interactive SSH session over the network) is implemented; --forward is a
// TODO (#957).
//
// Fixes over the previous version:
//   - fails fast when the requested allocation is not part of the network
//     (previously it silently dialed an empty address);
//   - the SIGWINCH relay goroutine is stopped and its channel closed on
//     return, so it no longer leaks;
//   - the "unknown action flag" message is written with Fprintln (the old
//     Fprintf call had no format arguments and no trailing newline).
func newNetworkAttachCommand(dmsCli *cli.DmsCLI) *cobra.Command {
	opts := networkAttachOpts{}
	cmd := &cobra.Command{
		Use:   "attach",
		Short: "Attach to a specific Network",
		RunE: func(cmd *cobra.Command, _ []string) error {
			ctx := cmd.Context()
			streams := cli.CmdStreams(cmd)
			cfg, err := dmsCli.Config()
			if err != nil {
				return fmt.Errorf("unable to get config: %w", err)
			}
			afs := afero.Afero{Fs: dmsCli.FS()}
			sctx, err := utils.NewSecurityContext(dmsCli, opts.Context)
			if err != nil {
				return fmt.Errorf("could not create security context: %w", err)
			}
			// Now call newClient with the correct arguments
			client, err := dmsCli.NewClient(sctx)
			if err != nil {
				return fmt.Errorf("could not create client: %w", err)
			}
			depNet, err := getNetwork(ctx, opts.ID, client)
			if err != nil {
				return fmt.Errorf("error getting network detail: %w", err)
			}
			// Locate the requested allocation in the network.
			var targetAlloc Allocation
			found := false
			for _, a := range depNet.Allocations {
				if a.Alloc == opts.Alloc {
					targetAlloc = a
					found = true
					break
				}
			}
			switch {
			case opts.Shell:
				if !found {
					return fmt.Errorf("allocation %q not found in deployment %q", opts.Alloc, opts.ID)
				}
				// Load the SSH identity, prompting for a passphrase when the
				// key turns out to be encrypted.
				key, err := afs.ReadFile(opts.Identity)
				if err != nil {
					return fmt.Errorf("unable to read identity priv key file: %w", err)
				}
				signer, err := ssh.ParsePrivateKey(key)
				if err != nil {
					if _, ok := err.(*ssh.PassphraseMissingError); ok {
						passphrase, err := dmsUtils.PromptForPassphrase(false)
						if err != nil {
							return fmt.Errorf("unable to read passphrase: %w", err)
						}
						signer, err = ssh.ParsePrivateKeyWithPassphrase(key, []byte(passphrase))
						if err != nil {
							return fmt.Errorf("unable to parse identity priv key with passphrase: %w", err)
						}
					} else {
						return fmt.Errorf("unable to parse identity priv key: %w", err)
					}
				}
				// Trust-on-first-use host key verification backed by a
				// known_hosts file under the user dir.
				hostKeyManager, err := NewHostKeyManager(afs, path.Join(cfg.UserDir, "ssh", "known_hosts"))
				if err != nil {
					return fmt.Errorf("unable to create host key manager: %w", err)
				}
				sshCliConfig := &ssh.ClientConfig{
					User: opts.Username,
					Auth: []ssh.AuthMethod{
						ssh.PublicKeys(signer),
					},
					HostKeyCallback: hostKeyManager.HostKeyCallback(cmd.OutOrStdout()),
				}
				sshClient, err := ssh.Dial("tcp", targetAlloc.IP+":"+opts.Port, sshCliConfig)
				if err != nil {
					return fmt.Errorf("unable to dial: %w", err)
				}
				session, err := sshClient.NewSession()
				if err != nil {
					return fmt.Errorf("unable to create session: %w", err)
				}
				defer session.Close()
				modes := ssh.TerminalModes{
					ssh.ECHO:          1,     // enable echoing
					ssh.TTY_OP_ISPEED: 14400, // input = 14.4kbaud
					ssh.TTY_OP_OSPEED: 14400, // output = 14.4kbaud
				}
				// Request a remote PTY sized to the local terminal.
				fd := int(os.Stdin.Fd())
				width, height, err := term.GetSize(fd)
				if err != nil {
					return fmt.Errorf("unable to get terminal size: %w", err)
				}
				if err := session.RequestPty("linux", height, width, modes); err != nil {
					return fmt.Errorf("unable to request pseudo terminal: %w", err)
				}
				// set input and output
				session.Stdout = streams.Out
				session.Stdin = streams.In
				session.Stderr = streams.Err
				if err := session.Shell(); err != nil {
					return fmt.Errorf("unable to start shell: %w", err)
				}
				// Relay local terminal resizes to the remote PTY. Stop signal
				// delivery and close the channel on return so the goroutine
				// terminates instead of leaking.
				sigWinChChan := make(chan os.Signal, 1)
				signal.Notify(sigWinChChan, syscall.SIGWINCH)
				defer func() {
					signal.Stop(sigWinChChan)
					close(sigWinChChan)
				}()
				go func() {
					for range sigWinChChan {
						fd := int(os.Stdin.Fd())
						width, height, err := term.GetSize(fd)
						if err != nil {
							log.Warnf("unable to get terminal size: %v", err)
							continue
						}
						if err := session.WindowChange(height, width); err != nil {
							log.Warnf("unable to change remote window size: %v", err)
						}
					}
				}()
				// Put the local terminal into raw mode for the duration of
				// the session; restore it on return.
				oState, err := term.MakeRaw(fd)
				if err != nil {
					return fmt.Errorf("unable to make raw terminal: %w", err)
				}
				defer func() {
					if err := term.Restore(fd, oState); err != nil {
						log.Errorf("unable to restore terminal: %v", err)
					}
				}()
				if err := session.Wait(); err != nil {
					return fmt.Errorf("unable to wait: %w", err)
				}
			case opts.Forward:
				// TODO #957
			default:
				fmt.Fprintln(cmd.OutOrStderr(), "unknown action flag")
			}
			return nil
		},
	}
	cmd.Flags().StringVarP(&opts.Context, "context", "c", node.DefaultContextName, "Capability Context")
	cmd.Flags().StringVarP(&opts.ID, "id", "i", "", "Deployment ID")
	cmd.Flags().BoolVar(&opts.Shell, "shell", false, "Attach a Shell")
	cmd.Flags().BoolVar(&opts.Forward, "forward", false, "Attach a Forwarder")
	cmd.Flags().StringVarP(&opts.Alloc, "alloc", "a", "", "Allocation Name")
	cmd.Flags().StringVarP(&opts.Username, "username", "u", "", "Username for SSH Shell")
	cmd.Flags().StringVarP(&opts.Identity, "identity", "I", "", "SSH Private Key for SSH Shell")
	cmd.Flags().StringVarP(&opts.Port, "port", "p", "22", "Port for SSH Shell")
	cmd.Flags().StringVarP(&opts.PortMap, "portmap", "P", "", "Port Mapping <host:alloc> for Forwarder")
	if err := cmd.MarkFlagRequired("context"); err != nil {
		log.Fatalf("unable to mark flag 'context' as required: %v", err)
	}
	if err := cmd.MarkFlagRequired("id"); err != nil {
		log.Fatalf("unable to mark flag 'id' as required: %v", err)
	}
	cmd.MarkFlagsMutuallyExclusive("shell", "forward")
	cmd.MarkFlagsMutuallyExclusive("shell", "portmap")
	cmd.MarkFlagsMutuallyExclusive("forward", "username")
	cmd.MarkFlagsMutuallyExclusive("forward", "identity")
	cmd.MarkFlagsMutuallyExclusive("forward", "port")
	cmd.MarkFlagsRequiredTogether("shell", "username", "identity", "port")
	cmd.MarkFlagsRequiredTogether("forward", "portmap")
	// network attach --context dag --shell root@alloc1:22 -i <id>
	return cmd
}
// getDeploymentIDs returns the orchestrator IDs of all deployments that are
// currently in the running state.
func getDeploymentIDs(ctx context.Context, dmsClient client.DmsClient) ([]string, error) {
	resp, err := dmsClient.DeploymentList(
		ctx,
		node.DeploymentListRequest{},
		client.WithTimeout(5*time.Second),
	)
	if err != nil {
		return nil, fmt.Errorf("error getting deployment list from client: %w", err)
	}
	running := make([]string, 0, len(resp.Deployments))
	for _, d := range resp.Deployments {
		// Only running deployments are of interest here.
		if d.Status != jobtypes.DeploymentStatusRunning.String() {
			continue
		}
		running = append(running, d.OrchestratorID)
	}
	return running, nil
}
// getNetworkList fetches the manifest of each deployment ID and returns the
// deployment networks for those whose orchestrator joined the subnet.
func getNetworkList(ctx context.Context, ids []string, dmsClient client.DmsClient) ([]DeploymentNetwork, error) {
	networks := make([]DeploymentNetwork, 0)
	for _, id := range ids {
		resp, err := dmsClient.DeploymentManifest(
			ctx,
			node.DeploymentManifestRequest{ID: id},
			client.WithTimeout(5*time.Second),
		)
		if err != nil {
			return nil, fmt.Errorf("unable to get deployment manifest(id=%s) from client: %w", id, err)
		}
		// A "network" only exists if the orchestrator joined the subnet.
		if !resp.Manifest.Subnet.Join {
			continue
		}
		allocs := make([]Allocation, 0, len(resp.Manifest.Allocations))
		for allocID, alloc := range resp.Manifest.Allocations {
			allocs = append(allocs, Allocation{
				Alloc:       allocID,
				DNSName:     alloc.DNSName,
				IP:          alloc.PrivAddr,
				PortMapping: alloc.Ports,
				Status:      string(alloc.Status),
			})
		}
		networks = append(networks, DeploymentNetwork{
			ID:          resp.Manifest.ID,
			Allocations: allocs,
		})
	}
	return networks, nil
}
// getNetwork fetches the manifest for the given deployment ID and returns
// its network view; it errors when the orchestrator did not join the subnet.
func getNetwork(ctx context.Context, id string, dmsClient client.DmsClient) (DeploymentNetwork, error) {
	resp, err := dmsClient.DeploymentManifest(
		ctx,
		node.DeploymentManifestRequest{ID: id},
		client.WithTimeout(5*time.Second),
	)
	if err != nil {
		return DeploymentNetwork{}, fmt.Errorf("unable to get deployment manifest(id=%s) from client: %w", id, err)
	}
	// A "network" only exists if the orchestrator joined the subnet.
	if !resp.Manifest.Subnet.Join {
		return DeploymentNetwork{}, fmt.Errorf("the deployment does not have an accessible network")
	}
	allocs := make([]Allocation, 0, len(resp.Manifest.Allocations))
	for allocID, alloc := range resp.Manifest.Allocations {
		allocs = append(allocs, Allocation{
			Alloc:       allocID,
			DNSName:     alloc.DNSName,
			IP:          alloc.PrivAddr,
			PortMapping: alloc.Ports,
			Status:      string(alloc.Status),
		})
	}
	return DeploymentNetwork{
		ID:          resp.Manifest.ID,
		Allocations: allocs,
	}, nil
}
// HostKeyManager implements trust-on-first-use SSH host key verification
// backed by a known_hosts-style file.
type HostKeyManager struct {
	// knownHostsPath is the path of the known_hosts file that records are
	// loaded from and appended to.
	knownHostsPath string
	// keys maps hostname -> base64-encoded public key, loaded from the file.
	keys map[string]string
}
// NewHostKeyManager creates a HostKeyManager backed by the known_hosts file
// at knownHostsPath, creating the file (and its directory) if necessary and
// loading any existing "<hostname> <key>" records into memory.
//
// Fixes over the previous version:
//   - the file is read through the provided afero filesystem instead of
//     os.ReadFile, so non-OS filesystems (e.g. in-memory) work;
//   - file-creation failure returns an error instead of calling log.Fatalf;
//   - a dead `if err != nil` check inside the parse loop was removed (err
//     was already known to be nil at that point).
func NewHostKeyManager(afs afero.Afero, knownHostsPath string) (*HostKeyManager, error) {
	hostKeyManager := &HostKeyManager{
		knownHostsPath: knownHostsPath,
		keys:           make(map[string]string),
	}
	// create dir+file if not exists
	if _, err := afs.Stat(knownHostsPath); os.IsNotExist(err) {
		if err := afs.MkdirAll(path.Dir(knownHostsPath), 0o700); err != nil {
			return nil, fmt.Errorf("unable to create known_hosts dir: %w", err)
		}
		if _, err := afs.Create(knownHostsPath); err != nil {
			return nil, fmt.Errorf("unable to create known_hosts file: %w", err)
		}
	}
	// read known_hosts file
	data, err := afs.ReadFile(knownHostsPath)
	if err != nil {
		return nil, fmt.Errorf("unable to read known_hosts file: %w", err)
	}
	// Each non-empty line is expected to be "<hostname> <key> ...".
	for _, record := range strings.Split(string(data), "\n") {
		record = strings.TrimSpace(record)
		if record == "" {
			continue
		}
		fields := strings.Fields(record)
		if len(fields) >= 2 {
			hostKeyManager.keys[fields[0]] = fields[1]
		}
	}
	return hostKeyManager, nil
}
// saveRecord appends a "<hostname> <key>" line to the known_hosts file.
// NOTE(review): this writes through the OS filesystem while the manager is
// constructed with an afero filesystem — verify this is intentional.
func (h *HostKeyManager) saveRecord(hostname string, key string) error {
	f, err := os.OpenFile(h.knownHostsPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o600)
	if err != nil {
		return fmt.Errorf("unable to open known_hosts file: %w", err)
	}
	defer f.Close()
	if _, err := fmt.Fprintf(f, "%s %s\n", hostname, key); err != nil {
		return fmt.Errorf("unable to write known_hosts file: %w", err)
	}
	return nil
}
// HostKeyCallback returns an ssh.HostKeyCallback implementing
// trust-on-first-use: unknown hosts prompt the user and, on confirmation,
// are persisted; known hosts must present the previously stored key.
func (h *HostKeyManager) HostKeyCallback(out io.Writer) ssh.HostKeyCallback {
	return func(hostname string, _ net.Addr, key ssh.PublicKey) error {
		// encode key to string
		keyStr := base64.StdEncoding.EncodeToString(key.Marshal())
		if stored, known := h.keys[hostname]; known {
			// Stored key must match exactly.
			if stored != keyStr {
				return fmt.Errorf("host key verification failed")
			}
			return nil
		}
		// First contact: show the fingerprint and ask the user.
		fmt.Fprintf(
			out,
			"Unknown host key for %s\nFingerprint: %s\n\n",
			hostname,
			keyStr,
		)
		yes, err := dmsUtils.PromptYesNo(os.Stdin, out, "Are you sure you want to proceed?")
		if err != nil {
			return fmt.Errorf("unable to prompt for host key verification: %w", err)
		}
		if !yes {
			return fmt.Errorf("host key verification failed")
		}
		// save the key
		if err := h.saveRecord(hostname, keyStr); err != nil {
			return fmt.Errorf("unable to save host key: %w", err)
		}
		return nil
	}
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package cmd
import (
"github.com/spf13/afero"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/cmd/actor"
"gitlab.com/nunet/device-management-service/cmd/cap"
"gitlab.com/nunet/device-management-service/cmd/cli"
"gitlab.com/nunet/device-management-service/lib/env"
)
// NewRootCMD builds the `nunet` root command with all of its subcommands
// attached and the config loader's flags bound to the persistent flag set.
func NewRootCMD(dmsCli *cli.DmsCLI) *cobra.Command {
	root := &cobra.Command{
		Use:   "nunet",
		Short: "NuNet Device Management Service",
		Long:  `The Device Management Service (DMS) Command Line Interface (CLI)`,
		CompletionOptions: cobra.CompletionOptions{
			DisableDefaultCmd: false,
			HiddenDefaultCmd:  true,
		},
		SilenceErrors: true,
		SilenceUsage:  true,
		Run: func(cmd *cobra.Command, _ []string) {
			_ = cmd.Help()
		},
		PersistentPreRun: func(_ *cobra.Command, _ []string) {
			// Best-effort config load before any subcommand runs.
			_, _ = dmsCli.ConfigLoader().Load()
		},
	}
	dmsCli.ConfigLoader().BindFlags(root.PersistentFlags())
	subcommands := []*cobra.Command{
		newRunCmd(dmsCli),
		newKeyCmd(dmsCli),
		cap.NewCapCmd(dmsCli),
		actor.NewActorCmd(dmsCli),
		newConfigCmd(dmsCli),
		newVersionCmd(),
		newGPUCommand(),
		newNetworkCommand(dmsCli),
		newTranslateCmd(dmsCli),
		newValidateCmd(dmsCli),
		newDeployCmd(dmsCli),
		newGetCmd(dmsCli),
		newContractCmd(dmsCli),
	}
	root.AddCommand(subcommands...)
	return root
}
// Execute wires the CLI to the OS filesystem and environment, builds the
// root command, and runs it, exiting via cobra.CheckErr on error.
func Execute() {
	dmsCli := cli.New(
		cli.WithFS(afero.NewOsFs()),
		cli.WithEnv(env.NewOSEnvironment()),
	)
	rootCmd := NewRootCMD(dmsCli)
	cobra.CheckErr(rootCmd.Execute())
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package cmd
import (
"fmt"
"net/http"
_ "net/http/pprof" // #nosec
"os"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/cmd/cli"
"gitlab.com/nunet/device-management-service/dms"
"gitlab.com/nunet/device-management-service/dms/node"
"gitlab.com/nunet/device-management-service/internal"
"gitlab.com/nunet/device-management-service/lib/did"
)
// RunOptions holds the flag values for the `run` command.
type RunOptions struct {
	// Context is the capability context name (--context).
	Context string
	// PrismURL overrides the PRISM resolver URL (--prism-url); when empty,
	// the PRISM_RESOLVER_URL environment variable is consulted instead.
	PrismURL string
}
// newRunCmd builds the `run` command, which starts the Device Management
// Service and blocks until a shutdown signal arrives on
// internal.ShutdownChan.
func newRunCmd(
	dmsCli *cli.DmsCLI,
) *cobra.Command {
	var opts RunOptions
	cmd := &cobra.Command{
		Use:   "run",
		Short: "Start the Device Management Service",
		Long: `Start the Device Management Service
The Device Management Service (DMS) is a system application for running a node in the NuNet decentralized network of compute providers.
By default, DMS listens on port 9999. For more information on configuration, see:
nunet config --help
Or manually create a dms_config.json file and refer to the README for available settings.`,
		RunE: func(_ *cobra.Command, _ []string) error {
			// Configure PRISM resolver URL from flag or environment variable
			prismURL := opts.PrismURL
			if prismURL == "" {
				prismURL = dmsCli.Env().Getenv("PRISM_RESOLVER_URL")
			}
			if prismURL != "" {
				// Override only the resolver URL; keep the existing preferred
				// verification method and HTTP client.
				originalConfig := did.GetPRISMResolverConfig()
				did.SetPRISMResolverConfig(did.PRISMResolverConfig{
					ResolverURL:                 prismURL,
					PreferredVerificationMethod: originalConfig.PreferredVerificationMethod,
					HTTPClient:                  originalConfig.HTTPClient,
				})
			}
			passphrase, err := dmsCli.Passphrase(opts.Context)
			if err != nil {
				return fmt.Errorf("get dms passphrase: %w", err)
			}
			cfg, err := dmsCli.Config()
			if err != nil {
				return fmt.Errorf("get dms config: %w", err)
			}
			if cfg.Profiler.Enabled {
				go func() {
					// Serve the pprof handlers (registered on the default mux
					// by the net/http/pprof import) on a dedicated server,
					// and swap in a fresh default mux so the profiler
					// endpoints do not leak into other servers.
					pprofMux := http.DefaultServeMux
					http.DefaultServeMux = http.NewServeMux()
					profilerAddr := fmt.Sprintf("%s:%d", cfg.Profiler.Addr, cfg.Profiler.Port)
					log.Infof("Starting profiler on %s\n", profilerAddr)
					// #nosec
					if err := http.ListenAndServe(profilerAddr, pprofMux); err != nil {
						log.Errorf("Error starting profiler: %v\n", err)
					}
				}()
			}
			dmsInstance, err := dms.NewDMS(dmsCli.FS(), cfg, dmsCli.Env(), passphrase, opts.Context)
			if err != nil {
				return fmt.Errorf("failed to initialize dms: %w", err)
			}
			go func() {
				// Stop the DMS and exit the process on the first shutdown
				// signal.
				sig := <-internal.ShutdownChan
				log.Infow("Shutting down after a receiving signal", "sig", sig)
				dmsInstance.Stop()
				os.Exit(0)
			}()
			err = dmsInstance.Run()
			if err != nil {
				return err
			}
			// Block until shutdown; the goroutine above performs the actual
			// stop and exit, so this receive normally never completes.
			<-internal.ShutdownChan
			return nil
		},
	}
	cmd.Flags().StringVarP(&opts.Context, "context", "c", node.DefaultContextName, "specify a capability context")
	cmd.Flags().StringVar(&opts.PrismURL, "prism-url", "", "PRISM resolver URL (e.g., http://localhost:8080). Can also be set via PRISM_RESOLVER_URL environment variable.")
	return cmd
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package cmd
import (
"context"
"fmt"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/cmd/cli"
"gitlab.com/nunet/device-management-service/dms/jobs/parser"
jobtypes "gitlab.com/nunet/device-management-service/dms/jobs/types"
"gitlab.com/nunet/device-management-service/dms/translator"
)
// TranslateOptions holds the inputs for the `translate` command.
type TranslateOptions struct {
	// InputFile is the path of the foreign specification (positional arg).
	InputFile string
	// FromFormat is the source format of the input file (--from).
	FromFormat string
	// OutputFile is the destination path (--output); empty means stdout.
	OutputFile string
}
// newTranslateCmd builds the `translate` command, which converts a foreign
// specification (e.g. a Docker Compose file) into a NuNet Ensemble
// configuration.
func newTranslateCmd(
	dmsCli *cli.DmsCLI,
) *cobra.Command {
	var opts TranslateOptions
	cmd := &cobra.Command{
		Use:   "translate <input-file>",
		Short: "Translate a foreign specification to a NuNet Ensemble configuration.",
		Long: `Translate a foreign specification, such as a Docker Compose file,
into a native NuNet DMS Ensemble configuration file.
This allows you to leverage existing development files and easily onboard them
onto the NuNet platform.`,
		Args: cobra.ExactArgs(1),
		RunE: func(cmd *cobra.Command, args []string) error {
			opts.InputFile = args[0]
			return runTranslateCmd(cmd.Context(), dmsCli, opts, cli.CmdStreams(cmd))
		},
	}
	flags := cmd.Flags()
	flags.StringVarP(&opts.FromFormat, "from", "f", "docker-compose", "The source format of the input file (e.g., 'docker-compose').")
	flags.StringVarP(&opts.OutputFile, "output", "o", "", "Path to the output NuNet Ensemble file.")
	return cmd
}
// runTranslateCmd reads the input specification, translates it into a NuNet
// Ensemble configuration, and writes the result to stdout or to the
// requested output file, surfacing any translation warnings on stderr.
func runTranslateCmd(_ context.Context, dmsCli *cli.DmsCLI, opts TranslateOptions, streams cli.Streams) error {
	fs := afero.Afero{Fs: dmsCli.FS()}
	// Read the input file content
	inputBytes, err := fs.ReadFile(opts.InputFile)
	if err != nil {
		return fmt.Errorf("could not read input file '%s': %w", opts.InputFile, err)
	}
	// Perform the translation
	translation, err := translator.Translate(translator.SpecType(opts.FromFormat), inputBytes)
	if err != nil {
		return fmt.Errorf("translation failed: %w", err)
	}
	data, err := parser.Encode(parser.SpecTypeEnsembleV1, translation.Config)
	if err != nil {
		return fmt.Errorf("failed to encode config: %w", err)
	}
	// Emit the result: stdout when no output file was requested, otherwise
	// write the file followed by a confirmation message.
	if opts.OutputFile == "" {
		fmt.Fprintln(streams.Out, string(data))
	} else {
		if err := fs.WriteFile(opts.OutputFile, data, 0o644); err != nil {
			return fmt.Errorf("failed to write output file '%s': %w", opts.OutputFile, err)
		}
		fmt.Fprintf(streams.Out, "Successfully translated '%s' to '%s'.\n", opts.InputFile, opts.OutputFile)
	}
	// Print any warnings to stderr.
	if len(translation.Warnings) > 0 {
		fmt.Fprintln(streams.Err, "\nPlease review the following warnings (also included as comments in the output file):")
		for _, warning := range translation.Warnings {
			fmt.Fprintf(streams.Err, " - %s\n", warning)
		}
	}
	return nil
}
// ValidateOpts holds the inputs for the `validate` command.
type ValidateOpts struct {
	// InputFile is the path of the Ensemble configuration (positional arg).
	InputFile string
}
// newValidateCmd builds the `validate` command, which parses an Ensemble
// configuration file and reports whether it is valid.
func newValidateCmd(
	dmsCli *cli.DmsCLI,
) *cobra.Command {
	var opts ValidateOpts
	return &cobra.Command{
		Use:   "validate <input-file>",
		Short: "Validate a NuNet Ensemble configuration.",
		Long: `Parses a NuNet Ensemble configuration file and validate if the configuration is valid.
Example:
nunet validate ensemble.yaml
`,
		Args: cobra.ExactArgs(1),
		RunE: func(cmd *cobra.Command, args []string) error {
			opts.InputFile = args[0]
			return runValidateCmd(cmd.Context(), dmsCli, opts, cli.CmdStreams(cmd))
		},
	}
}
// runValidateCmd reads an Ensemble configuration file, decodes it, and runs
// its validation, printing a confirmation on success.
func runValidateCmd(_ context.Context, dmsCli *cli.DmsCLI, opts ValidateOpts, streams cli.Streams) error {
	fs := afero.Afero{Fs: dmsCli.FS()}
	// Read the input file content
	inputBytes, err := fs.ReadFile(opts.InputFile)
	if err != nil {
		return fmt.Errorf("could not read input file '%s': %w", opts.InputFile, err)
	}
	// Decode into an ensemble config, handing the CLI's environment and
	// filesystem to the parser.
	var cfg jobtypes.EnsembleConfig
	if err := parser.Decode(parser.SpecTypeEnsembleV1, inputBytes, &cfg, &parser.Options{
		Env:        dmsCli.Env(),
		Fs:         fs,
		WorkingDir: "",
	}); err != nil {
		return err
	}
	if err := cfg.Validate(); err != nil {
		return err
	}
	fmt.Fprintln(streams.Out, "Configuration is valid.")
	return nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package utils
import (
"fmt"
"strings"
"time"
"github.com/spf13/pflag"
)
// TimeValue adapts time.Time for use as a flag.
type TimeValue struct {
	// Time points at the destination value updated by Set.
	Time *time.Time
	// Formats lists the layouts tried, in order, when parsing input.
	Formats []string
}
// NewTimeValue creates a new TimeValue wrapping t. When no formats are
// supplied, a default set of common layouts (RFC822, RFC822Z, RFC3339,
// RFC3339Nano, DateTime, DateOnly) is used.
//
// Fix: the previous version only applied the defaults when formats was nil,
// so explicitly passing an empty (non-nil) slice produced a TimeValue whose
// Set could never succeed; the check is now on length.
func NewTimeValue(t *time.Time, formats ...string) *TimeValue {
	if len(formats) == 0 {
		formats = []string{
			time.RFC822,
			time.RFC822Z,
			time.RFC3339,
			time.RFC3339Nano,
			time.DateTime,
			time.DateOnly,
		}
	}
	return &TimeValue{
		Time:    t,
		Formats: formats,
	}
}
// Set parses s against the configured formats in order, storing the first
// successful result in *t.Time; it errors when no format matches.
func (t TimeValue) Set(s string) error {
	trimmed := strings.TrimSpace(s)
	for _, layout := range t.Formats {
		if parsed, err := time.Parse(layout, trimmed); err == nil {
			*t.Time = parsed
			return nil
		}
	}
	return fmt.Errorf("format must be one of: %v", strings.Join(t.Formats, ", "))
}
// Type returns the flag type name shown in help output.
func (t TimeValue) Type() string {
	const typeName = "time"
	return typeName
}
// String renders the wrapped time, or "" when it is unset or zero.
func (t TimeValue) String() string {
	if t.Time != nil && !t.Time.IsZero() {
		return t.Time.String()
	}
	return ""
}
// GetTime looks up the named flag on f and returns its time.Time value. It
// returns an error when the flag does not exist, has no value, or is not a
// *TimeValue.
//
// Fixes: the previous version used an unchecked type assertion to
// *TimeValue, which panics when the dynamic type differs (the Type() string
// comparison alone does not guarantee it), and it dereferenced val.Time
// without a nil check.
func GetTime(f *pflag.FlagSet, name string) (time.Time, error) {
	var zero time.Time
	flag := f.Lookup(name)
	if flag == nil {
		return zero, fmt.Errorf("flag %s not found", name)
	}
	if flag.Value == nil || flag.Value.Type() != new(TimeValue).Type() {
		return zero, fmt.Errorf("flag %s has wrong type or no value", name)
	}
	val, ok := flag.Value.(*TimeValue)
	if !ok || val.Time == nil {
		return zero, fmt.Errorf("flag %s has wrong type or no value", name)
	}
	return *val.Time, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package utils
import (
"bytes"
"fmt"
"io"
"path/filepath"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/client"
"gitlab.com/nunet/device-management-service/cmd/cli"
"gitlab.com/nunet/device-management-service/dms"
"gitlab.com/nunet/device-management-service/dms/node"
"gitlab.com/nunet/device-management-service/internal/config"
"gitlab.com/nunet/device-management-service/lib/crypto"
"gitlab.com/nunet/device-management-service/lib/crypto/keystore"
"gitlab.com/nunet/device-management-service/lib/did"
"gitlab.com/nunet/device-management-service/lib/env"
"gitlab.com/nunet/device-management-service/lib/ucan"
"gitlab.com/nunet/device-management-service/utils"
)
const (
	// DefaultUserContextName is the capability-context name used when a
	// caller passes an empty context name.
	DefaultUserContextName = "user"
)
// NewSecurityContext builds an actor security context for the named
// capability context (defaulting to DefaultUserContextName when empty),
// generating an ephemeral Ed25519 key pair for it.
func NewSecurityContext(
	dmsCLI *cli.DmsCLI,
	context string,
) (actor.SecurityContext, error) {
	if context == "" {
		context = DefaultUserContextName
	}
	// Ephemeral key pair for this security context.
	privk, pubk, keyErr := crypto.GenerateKeyPair(crypto.Ed25519)
	if keyErr != nil {
		return nil, fmt.Errorf("generate ephemeral key pair: %w", keyErr)
	}
	capCtx, loadErr := LoadCapabilityContext(dmsCLI, context)
	if loadErr != nil {
		return nil, fmt.Errorf("load capability context: %w", loadErr)
	}
	return actor.NewBasicSecurityContext(pubk, privk, capCtx)
}
// NewCapabilityContext creates and persists a capability context for the
// named context (defaulting to DefaultUserContextName when empty): it
// generates and stores a private key in the user keystore, derives the
// context DID from the public key, builds a trust context around the key,
// creates a capability context with empty token lists, and saves it. It
// returns the capability context and the DID it is anchored to.
func NewCapabilityContext(dmsCLI *cli.DmsCLI, context string) (ucan.CapabilityContext, did.DID, error) {
	if context == "" {
		context = DefaultUserContextName
	}
	var ctxDID did.DID
	cfg, err := dmsCLI.Config()
	if err != nil {
		return nil, ctxDID, fmt.Errorf("get config: %w", err)
	}
	fs := dmsCLI.FS()
	// Open the keystore located under the configured user directory.
	keyStoreDir := filepath.Join(cfg.UserDir, node.KeystoreDir)
	ks, err := keystore.New(fs, keyStoreDir, false)
	if err != nil {
		return nil, ctxDID, fmt.Errorf("create keystore: %w", err)
	}
	passphrase, err := dmsCLI.Passphrase(context)
	if err != nil {
		return nil, ctxDID, fmt.Errorf("get passphrase: %w", err)
	}
	priv, err := dms.GenerateAndStorePrivKey(ks, passphrase, context)
	if err != nil {
		return nil, ctxDID, fmt.Errorf("generate and store private key: %w", err)
	}
	// The context's DID is derived from the key's public half.
	ctxDID = did.FromPublicKey(priv.GetPublic())
	trustCtx, err := did.NewTrustContextWithPrivateKey(priv)
	if err != nil {
		return nil, ctxDID, fmt.Errorf("create trust context: %w", err)
	}
	// Create the capability context with no anchors and empty token lists.
	capCtx, err := ucan.NewCapabilityContextWithName(context, trustCtx, ctxDID, nil, ucan.TokenList{}, ucan.TokenList{}, ucan.TokenList{})
	if err != nil {
		return nil, ctxDID, fmt.Errorf("create capability context: %w", err)
	}
	if err := SaveCapabilityContext(dmsCLI, capCtx); err != nil {
		return nil, ctxDID, fmt.Errorf("save capability context: %w", err)
	}
	return capCtx, ctxDID, nil
}
// LoadCapabilityContext is a helper function to reduce boilerplate in
// commands. It resolves the config, obtains a passphrase when the context
// needs one (Ledger/Eternl contexts do not), builds the trust context, and
// loads the capability context from disk.
// TODO slow
func LoadCapabilityContext(dmsCLI *cli.DmsCLI, contextName string) (ucan.CapabilityContext, error) {
	if contextName == "" {
		contextName = DefaultUserContextName
	}
	cfg, err := dmsCLI.Config()
	if err != nil {
		return nil, fmt.Errorf("unable to get config: %w", err)
	}
	fs := dmsCLI.FS()
	// Ledger and Eternl contexts have no local passphrase.
	var passphrase string
	if !node.IsLedgerContext(contextName) && !node.IsEternlContext(contextName) {
		if passphrase, err = dmsCLI.Passphrase(contextName); err != nil {
			return nil, fmt.Errorf("get dms passphrase: %w", err)
		}
	}
	trustCtx, err := node.GetTrustContext(fs, contextName, passphrase, cfg.UserDir)
	if err != nil {
		return nil, fmt.Errorf("get trust context: %w", err)
	}
	// Normalize the context name before loading from storage.
	normalized := node.GetContextKey(contextName)
	capCtx, err := node.LoadCapabilityContext(trustCtx, fs, normalized, cfg.UserDir)
	if err != nil {
		return nil, fmt.Errorf("failed to load capability context: %w", err)
	}
	return capCtx, nil
}
// SaveCapabilityContext persists capCtx to the user's capability store,
// using the CLI's filesystem and configured user directory.
func SaveCapabilityContext(dmsCLI *cli.DmsCLI, capCtx ucan.CapabilityContext) error {
	cfg, err := dmsCLI.Config()
	if err != nil {
		return fmt.Errorf("unable to get config: %w", err)
	}
	if err = node.SaveCapabilityContext(capCtx, dmsCLI.FS(), cfg.UserDir); err != nil {
		return fmt.Errorf("save capability context: %w", err)
	}
	return nil
}
// NewClient builds a DMS API client pointed at the REST endpoint from cfg,
// authenticated with the given security context.
func NewClient(cfg *config.Config, sctx actor.SecurityContext) (client.DmsClient, error) {
	clientCfg := client.Config{
		Host:      fmt.Sprintf("%s:%d", cfg.Rest.Addr, cfg.Rest.Port),
		APIPrefix: "/api",
		Version:   "v1",
	}
	return client.NewClient(clientCfg, sctx)
}
// NewTestCli constructs a DmsCLI suitable for tests: an in-memory filesystem,
// default config rooted under /tmp/nunet, and a mock environment carrying
// DMS_PASSPHRASE. Caller-supplied options are applied last and override the
// defaults.
// TODO test code not in _test.go
func NewTestCli(opts ...func(*cli.DmsCLI)) *cli.DmsCLI {
	var base []func(*cli.DmsCLI)
	mockEnv := env.NewMockEnvironment()
	// Only wire in the mock env when the passphrase could be seeded.
	if setErr := mockEnv.Setenv("DMS_PASSPHRASE", "pass"); setErr == nil {
		base = append(base, cli.WithEnv(mockEnv))
	}
	cfg := config.DefaultConfig
	cfg.General.UserDir = "/tmp/nunet/user"
	cfg.General.WorkDir = "/tmp/nunet/work"
	cfg.General.DataDir = "/tmp/nunet/data"
	base = append(base, cli.WithFS(afero.NewMemMapFs()), cli.WithConfig(&cfg))
	return cli.New(append(base, opts...)...)
}
// GetDMSPassphrase returns the DMS passphrase from the environment when set,
// otherwise it interactively prompts the user (optionally with confirmation).
func GetDMSPassphrase(
	env env.EnvironmentProvider, withConfirm bool,
) (string, error) {
	if fromEnv := env.Getenv(node.DMSPassphraseEnv); fromEnv != "" {
		return fromEnv, nil
	}
	prompted, err := utils.PromptForPassphrase(withConfirm)
	if err != nil {
		return "", fmt.Errorf("failed to get passphrase: %w", err)
	}
	return prompted, nil
}
// ExecuteCommand runs a cobra command with the given args, capturing and
// returning everything written to its stdout and stderr.
func ExecuteCommand(
	command *cobra.Command, args ...string,
) (stdout, stderr string, err error) {
	var outBuf, errBuf bytes.Buffer
	// Capture the command's output streams.
	command.SetOut(&outBuf)
	command.SetErr(&errBuf)
	command.SetArgs(args)
	err = command.Execute()
	return outBuf.String(), errBuf.String(), err
}
// ExecuteCommandWithInput runs a cobra command with the given args, feeding
// the provided byte chunks to its stdin through a pipe and capturing stdout
// and stderr.
//
// The pipe's write end is closed once all chunks are written (or a write
// fails) so the command's stdin reaches EOF; previously the writer was never
// closed, which deadlocked any command that reads stdin to exhaustion and
// leaked the writer goroutine.
func ExecuteCommandWithInput(command *cobra.Command, input [][]byte, args ...string) (stdout, stderr string, err error) {
	if len(input) > 0 {
		in, out := io.Pipe()
		command.SetIn(in)
		go func() {
			// Closing signals EOF to the reader side.
			defer out.Close()
			for _, chunk := range input {
				if _, writeErr := out.Write(chunk); writeErr != nil {
					return
				}
			}
		}()
	}
	return ExecuteCommand(command, args...)
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package cmd
import (
"fmt"
"github.com/spf13/cobra"
)
// Version information set by the build system (see Makefile) via -ldflags.
var (
	Version   string // release version string
	GoVersion string // Go toolchain version used for the build
	BuildDate string // timestamp of the build
	Commit    string // source commit the binary was built from
)
// newVersionCmd returns the `version` subcommand, which prints the build-time
// version metadata (Version, Commit, GoVersion, BuildDate).
//
// Output goes through cmd.OutOrStdout() instead of fmt.Println/Printf so it
// respects any writer installed with SetOut (e.g. the test harness's
// ExecuteCommand capture); the previous code wrote directly to os.Stdout.
func newVersionCmd() *cobra.Command {
	return &cobra.Command{
		Use:   "version",
		Short: "Information about current version",
		Long:  `Display information about the current Device Management Service (DMS) version`,
		Run: func(cmd *cobra.Command, _ []string) {
			out := cmd.OutOrStdout()
			fmt.Fprintln(out, "NuNet Device Management Service")
			fmt.Fprintf(out, "Version: %s\nCommit: %s\n\nGo Version: %s\nBuild Date: %s\n",
				Version, Commit, GoVersion, BuildDate)
		},
	}
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package clover
import (
	"errors"
	"fmt"

	"github.com/dgraph-io/badger/v3"
	clover "github.com/ostafen/clover/v2"
	badgerstore "github.com/ostafen/clover/v2/store/badger"

	"gitlab.com/nunet/device-management-service/observability"
)
// createCollections ensures each named collection exists in db.
// Already-existing collections are skipped; the first other error aborts
// and is returned (after being logged).
//
// Uses errors.Is instead of the previous `==` comparison so the
// already-exists sentinel is still recognized if clover ever wraps it.
func createCollections(db *clover.DB, collections []string) error {
	for _, c := range collections {
		err := db.CreateCollection(c)
		if err == nil || errors.Is(err, clover.ErrCollectionExist) {
			continue
		}
		err = fmt.Errorf("failed to create collection %s: %w", c, err)
		logger.Errorw("clover_db_init_failure", "collection", c, "error", err)
		return err
	}
	return nil
}
// NewDB opens (or creates) the clover database at path with its default
// on-disk store and ensures the given collections exist.
// NOTE(review): an earlier comment said "bbolt under the hood" — the backing
// store is whatever clover.Open defaults to; confirm before relying on it.
func NewDB(path string, collections []string) (*clover.DB, error) {
	endSpan := observability.StartSpan("clover_db_init")
	defer endSpan()
	db, err := clover.Open(path)
	if err != nil {
		wrapped := fmt.Errorf("failed to connect to database: %w", err)
		logger.Errorw("clover_db_init_failure", "error", wrapped)
		return nil, wrapped
	}
	if err = createCollections(db, collections); err != nil {
		return nil, err
	}
	logger.Debugw("clover_db_init_success", "path", path, "collections", collections)
	return db, nil
}
// NewMemDB opens an in-memory clover database backed by a badger store and
// ensures the given collections exist. Intended for tests and ephemeral use.
func NewMemDB(collections []string) (*clover.DB, error) {
	memOpts := badger.DefaultOptions("").WithInMemory(true)
	store, err := badgerstore.Open(memOpts)
	if err != nil {
		return nil, fmt.Errorf("failed to create in-memory store: %w", err)
	}
	db, err := clover.OpenWithStore(store)
	if err != nil {
		return nil, fmt.Errorf("failed to connect to database: %w", err)
	}
	if err = createCollections(db, collections); err != nil {
		return nil, err
	}
	return db, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package clover
import (
"context"
"fmt"
"reflect"
"time"
"github.com/iancoleman/strcase"
clover "github.com/ostafen/clover/v2"
clover_d "github.com/ostafen/clover/v2/document"
clover_q "github.com/ostafen/clover/v2/query"
"gitlab.com/nunet/device-management-service/db/repositories"
)
// GenericEntityRepositoryClover is a generic single entity repository implementation using Clover.
// It is intended to be embedded in single entity model repositories to provide basic database operations.
// It manages one logical record per collection: Save inserts a new timestamped
// version and Get returns the newest one (older versions remain available via History).
type GenericEntityRepositoryClover[T repositories.ModelType] struct {
	db *clover.DB // db is the Clover database instance.
	collection string // collection is the name of the collection in the database, derived from T's type name.
}
// NewGenericEntityRepository builds a GenericEntityRepositoryClover for T,
// naming the collection after T's type name converted to snake_case.
func NewGenericEntityRepository[T repositories.ModelType](
	db *clover.DB,
) repositories.GenericEntityRepository[T] {
	collectionName := strcase.ToSnake(reflect.TypeOf(*new(T)).Name())
	repo := &GenericEntityRepositoryClover[T]{db: db, collection: collectionName}
	return repo
}
// GetQuery hands back an empty Query value to start building queries from.
func (repo *GenericEntityRepositoryClover[T]) GetQuery() repositories.Query[T] {
	var q repositories.Query[T]
	return q
}
// query returns a fresh clover query scoped to this repository's collection.
func (repo *GenericEntityRepositoryClover[T]) query() *clover_q.Query {
	return clover_q.NewQuery(repo.collection)
}
// Save inserts a new version of the record, stamped with the current time,
// and returns the stored data parsed back into the model type.
func (repo *GenericEntityRepositoryClover[T]) Save(_ context.Context, data T) (T, error) {
	doc := toCloverDoc(data)
	doc.Set("CreatedAt", time.Now())
	if _, err := repo.db.InsertOne(repo.collection, doc); err != nil {
		return data, handleDBError(err)
	}
	model, err := toModel[T](doc, true)
	if err != nil {
		return model, handleDBError(fmt.Errorf("%v: %v", repositories.ErrParsingModel, err))
	}
	return model, nil
}
// Get returns the most recent version of the record (newest CreatedAt first).
func (repo *GenericEntityRepositoryClover[T]) Get(_ context.Context) (T, error) {
	var model T
	newestFirst := clover_q.SortOption{Field: "CreatedAt", Direction: -1}
	doc, err := repo.db.FindFirst(repo.query().Sort(newestFirst))
	switch {
	case err != nil:
		return model, handleDBError(err)
	case doc == nil:
		// No versions stored yet.
		return model, handleDBError(clover.ErrDocumentNotExist)
	}
	if model, err = toModel[T](doc, true); err != nil {
		return model, fmt.Errorf("failed to convert document to model: %v", err)
	}
	return model, nil
}
// Clear deletes every stored version of the record from the collection.
func (repo *GenericEntityRepositoryClover[T]) Clear(_ context.Context) error {
	err := repo.db.Delete(repo.query())
	return err
}
// History returns the stored versions of the record matching query.
//
// A document that fails to unmarshal now surfaces an ErrParsingModel-wrapped
// error (consistent with GenericRepositoryClover.FindAll); previously the
// unmarshal error was silently dropped — iteration just stopped and the
// caller received a truncated slice with a nil error.
func (repo *GenericEntityRepositoryClover[T]) History(_ context.Context, query repositories.Query[T]) ([]T, error) {
	var models []T
	var parseErr error
	q := applyConditions(repo.query(), query)
	err := repo.db.ForEach(q, func(doc *clover_d.Document) bool {
		var model T
		if uErr := doc.Unmarshal(&model); uErr != nil {
			parseErr = handleDBError(fmt.Errorf("%v: %v", repositories.ErrParsingModel, uErr))
			return false // stop iterating; report the parse failure below
		}
		models = append(models, model)
		return true
	})
	if err != nil {
		return models, handleDBError(err)
	}
	return models, parseErr
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package clover
import (
"context"
"fmt"
"reflect"
"time"
"github.com/iancoleman/strcase"
clover "github.com/ostafen/clover/v2"
clover_d "github.com/ostafen/clover/v2/document"
clover_q "github.com/ostafen/clover/v2/query"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/observability"
)
const (
	pkField = "_id" // clover's internal document id field, used for lookups by id
	deletedAtField = "DeletedAt" // soft-delete timestamp; query() filters it out by default
)
// GenericRepositoryClover is a generic repository implementation using Clover.
// It is intended to be embedded in model repositories to provide basic database operations
// (CRUD plus Find/FindAll) with soft-delete filtering via the DeletedAt field.
type GenericRepositoryClover[T repositories.ModelType] struct {
	db *clover.DB // db is the Clover database instance.
	collection string // collection is the name of the collection in the database, derived from T's type name.
}
// NewGenericRepository builds a GenericRepositoryClover for T, naming the
// collection after T's type name converted to snake_case.
func NewGenericRepository[T repositories.ModelType](
	db *clover.DB,
) repositories.GenericRepository[T] {
	endSpan := observability.StartSpan("clover_db_repo_init")
	defer endSpan()
	collectionName := strcase.ToSnake(reflect.TypeOf(*new(T)).Name())
	logger.Debugw("clover_db_repo_init_success", "collection", collectionName)
	repo := &GenericRepositoryClover[T]{db: db, collection: collectionName}
	return repo
}
// GetQuery hands back an empty Query value to start building queries from.
func (repo *GenericRepositoryClover[T]) GetQuery() repositories.Query[T] {
	var q repositories.Query[T]
	return q
}
// query returns a clover query over this collection. Unless includeDeleted is
// set, it filters out soft-deleted documents (those with DeletedAt after the
// Unix epoch).
func (repo *GenericRepositoryClover[T]) query(includeDeleted bool) *clover_q.Query {
	q := clover_q.NewQuery(repo.collection)
	if includeDeleted {
		return q
	}
	return q.Where(clover_q.Field(deletedAtField).LtEq(time.Unix(0, 0)))
}
// queryWithID narrows the base query (with or without soft-deleted documents)
// to the single document whose clover object id equals id.
// NOTE(review): id.(string) panics if a caller ever passes a non-string id —
// confirm all call sites supply string ids.
func (repo *GenericRepositoryClover[T]) queryWithID(
	id interface{},
	includeDeleted bool,
) *clover_q.Query {
	return repo.query(includeDeleted).Where(clover_q.Field(pkField).Eq(id.(string)))
}
// Create inserts data as a new document, stamped with the current time, and
// returns the stored record (including its assigned id) as the model type.
func (repo *GenericRepositoryClover[T]) Create(ctx context.Context, data T) (T, error) {
	endSpan := observability.StartSpan(ctx, "clover_db_create")
	defer endSpan()
	doc := toCloverDoc(data)
	doc.Set("CreatedAt", time.Now())
	if _, err := repo.db.InsertOne(repo.collection, doc); err != nil {
		logger.Errorw("clover_db_create_failure", "error", err)
		return data, handleDBError(err)
	}
	model, err := toModel[T](doc, false)
	if err != nil {
		logger.Errorw("clover_db_create_failure", "error", err)
		return data, handleDBError(fmt.Errorf("%v: %v", repositories.ErrParsingModel, err))
	}
	logger.Infow("clover_db_create_success", "collection", repo.collection)
	return model, nil
}
// Get looks a record up by its clover object id and returns it as the model
// type (with the model's ID field populated from the document id).
func (repo *GenericRepositoryClover[T]) Get(ctx context.Context, id interface{}) (T, error) {
	endSpan := observability.StartSpan(ctx, "clover_db_get")
	defer endSpan()
	var model T
	doc, err := repo.db.FindById(repo.collection, id.(string))
	switch {
	case err != nil:
		logger.Errorw("clover_db_get_failure", "id", id, "error", err)
		return model, handleDBError(err)
	case doc == nil:
		// No document with that id.
		return model, handleDBError(clover.ErrDocumentNotExist)
	}
	if model, err = toModel[T](doc, false); err != nil {
		logger.Errorw("clover_db_get_failure", "id", id, "error", err)
		return model, handleDBError(fmt.Errorf("%v: %v", repositories.ErrParsingModel, err))
	}
	logger.Infow("clover_db_get_success", "id", id)
	return model, nil
}
// Update applies data's fields to the record with the given id (stamping
// UpdatedAt) and returns the record re-read from the database.
//
// The success log is now emitted only after the post-update Get succeeds;
// previously "clover_db_update_success" was logged unconditionally, even when
// re-reading the updated record failed.
func (repo *GenericRepositoryClover[T]) Update(
	ctx context.Context,
	id interface{},
	data T,
) (T, error) {
	endSpan := observability.StartSpan(ctx, "clover_db_update")
	defer endSpan()
	updates := toCloverDoc(data).AsMap()
	updates["UpdatedAt"] = time.Now()
	err := repo.db.Update(repo.queryWithID(id, false), updates)
	if err != nil {
		logger.Errorw("clover_db_update_failure", "id", id, "error", err)
		return data, handleDBError(err)
	}
	// Re-read so the caller sees the stored state (including UpdatedAt).
	data, err = repo.Get(ctx, id)
	if err != nil {
		return data, handleDBError(err)
	}
	logger.Infow("clover_db_update_success", "id", id)
	return data, nil
}
// Delete removes the record with the given id from the collection.
func (repo *GenericRepositoryClover[T]) Delete(ctx context.Context, id interface{}) error {
	endSpan := observability.StartSpan(ctx, "clover_db_delete")
	defer endSpan()
	if err := repo.db.Delete(repo.queryWithID(id, false)); err != nil {
		logger.Errorw("clover_db_delete_failure", "id", id, "error", err)
		return err
	}
	logger.Infow("clover_db_delete_success", "id", id)
	return nil
}
// Find returns the first non-deleted record matching query.
func (repo *GenericRepositoryClover[T]) Find(
	ctx context.Context,
	query repositories.Query[T],
) (T, error) {
	endSpan := observability.StartSpan(ctx, "clover_db_find")
	defer endSpan()
	var model T
	doc, err := repo.db.FindFirst(applyConditions(repo.query(false), query))
	switch {
	case err != nil:
		logger.Errorw("clover_db_find_failure", "error", err)
		return model, handleDBError(err)
	case doc == nil:
		// Nothing matched the query.
		return model, handleDBError(clover.ErrDocumentNotExist)
	}
	if model, err = toModel[T](doc, false); err != nil {
		logger.Errorw("clover_db_find_failure", "error", err)
		return model, fmt.Errorf("failed to convert document to model: %v", err)
	}
	logger.Infow("clover_db_find_success", "collection", repo.collection)
	return model, nil
}
// FindAll returns every non-deleted record matching query. Iteration stops on
// the first document that fails to parse, and that parse error is returned.
func (repo *GenericRepositoryClover[T]) FindAll(
	ctx context.Context,
	query repositories.Query[T],
) ([]T, error) {
	endSpan := observability.StartSpan(ctx, "clover_db_find_all")
	defer endSpan()
	var (
		models   []T
		parseErr error
	)
	q := applyConditions(repo.query(false), query)
	iterErr := repo.db.ForEach(q, func(doc *clover_d.Document) bool {
		model, convErr := toModel[T](doc, false)
		if convErr != nil {
			parseErr = handleDBError(fmt.Errorf("%v: %v", repositories.ErrParsingModel, convErr))
			logger.Errorw("clover_db_find_all_failure", "error", convErr)
			return false
		}
		models = append(models, model)
		return true
	})
	if iterErr != nil {
		logger.Errorw("clover_db_find_all_failure", "error", iterErr)
		return models, handleDBError(iterErr)
	}
	if parseErr != nil {
		return models, parseErr
	}
	logger.Infow("clover_db_find_all_success", "collection", repo.collection)
	return models, nil
}
// applyConditions translates a generic repositories.Query into clover query
// clauses: explicit WHERE conditions, equality filters derived from the
// non-zero fields of the example instance, sorting, limit, and offset. The
// modified clover query is returned.
//
// Fixes in this revision:
//   - Offset now calls q.Skip(query.Offset); the previous code called
//     q.Limit(query.Offset), which overwrote the row limit instead of
//     skipping rows, so pagination was broken.
//   - The struct-field -> json-tag translation of the sort field now applies
//     to ascending sorts too; previously only the descending ("-"-prefixed)
//     branch translated the name.
func applyConditions[T repositories.ModelType](
	q *clover_q.Query,
	query repositories.Query[T],
) *clover_q.Query {
	// Apply conditions specified in the query.
	for _, condition := range query.Conditions {
		// change the field name to json tag name if specified in the struct
		condition.Field = fieldJSONTag[T](condition.Field)
		switch condition.Operator {
		case "=":
			q = q.Where(clover_q.Field(condition.Field).Eq(condition.Value))
		case ">":
			q = q.Where(clover_q.Field(condition.Field).Gt(condition.Value))
		case ">=":
			q = q.Where(clover_q.Field(condition.Field).GtEq(condition.Value))
		case "<":
			q = q.Where(clover_q.Field(condition.Field).Lt(condition.Value))
		case "<=":
			q = q.Where(clover_q.Field(condition.Field).LtEq(condition.Value))
		case "!=":
			q = q.Where(clover_q.Field(condition.Field).Neq(condition.Value))
		case "IN":
			// Silently skipped when the value is not []interface{} (original behavior).
			if values, ok := condition.Value.([]interface{}); ok {
				q = q.Where(clover_q.Field(condition.Field).In(values...))
			}
		case "LIKE":
			if value, ok := condition.Value.(string); ok {
				q = q.Where(clover_q.Field(condition.Field).Like(value))
			}
		}
	}
	// Apply conditions based on non-zero values in the query instance.
	if !repositories.IsEmptyValue(query.Instance) {
		exampleType := reflect.TypeOf(query.Instance)
		exampleValue := reflect.ValueOf(query.Instance)
		for i := 0; i < exampleType.NumField(); i++ {
			fieldName := fieldJSONTag[T](exampleType.Field(i).Name)
			fieldValue := exampleValue.Field(i).Interface()
			if !repositories.IsEmptyValue(fieldValue) {
				q = q.Where(clover_q.Field(fieldName).Eq(fieldValue))
			}
		}
	}
	// Apply sorting if specified in the query; a leading '-' means descending.
	if query.SortBy != "" {
		dir := 1
		sortField := query.SortBy
		if sortField[0] == '-' {
			dir = -1
			sortField = sortField[1:]
		}
		sortField = fieldJSONTag[T](sortField)
		q = q.Sort(clover_q.SortOption{Field: sortField, Direction: dir})
	}
	// Apply limit if specified in the query.
	if query.Limit > 0 {
		q = q.Limit(query.Limit)
	}
	// Apply offset if specified in the query.
	if query.Offset > 0 {
		q = q.Skip(query.Offset)
	}
	return q
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package clover
import (
"encoding/json"
"errors"
"reflect"
"strings"
"github.com/ostafen/clover/v2"
clover_d "github.com/ostafen/clover/v2/document"
"gitlab.com/nunet/device-management-service/db/repositories"
)
// handleDBError maps clover sentinel errors onto the repository package's
// error vocabulary; anything unrecognized is joined with ErrDatabase.
// A nil error passes through unchanged.
func handleDBError(err error) error {
	if err == nil {
		return nil
	}
	switch err {
	case clover.ErrDocumentNotExist:
		return repositories.ErrNotFound
	case clover.ErrDuplicateKey:
		return repositories.ErrInvalidData
	case repositories.ErrParsingModel:
		return err
	}
	return errors.Join(repositories.ErrDatabase, err)
}
// toCloverDoc converts a model value into a clover document by round-tripping
// it through JSON. On any marshal/unmarshal failure it falls back to an empty
// document (original behavior — the error is intentionally not surfaced).
func toCloverDoc[T repositories.ModelType](data T) *clover_d.Document {
	encoded, err := json.Marshal(data)
	if err != nil {
		return clover_d.NewDocument()
	}
	fields := map[string]interface{}{}
	if err = json.Unmarshal(encoded, &fields); err != nil {
		return clover_d.NewDocument()
	}
	return clover_d.NewDocumentOf(fields)
}
// toModel unmarshals a clover document into the model type. For non-entity
// repositories it also copies the document's object id into the model's ID
// field; entity repositories skip this, since their models may not have an
// ID field at all.
func toModel[T repositories.ModelType](doc *clover_d.Document, isEntityRepo bool) (T, error) {
	var model T
	if err := doc.Unmarshal(&model); err != nil {
		return model, err
	}
	if isEntityRepo {
		return model, nil
	}
	model, err := repositories.UpdateField(model, "ID", doc.ObjectId())
	if err != nil {
		return model, err
	}
	return model, nil
}
// fieldJSONTag resolves a struct field name of T to its json tag name when
// one is declared; otherwise the original field name is returned unchanged.
func fieldJSONTag[T repositories.ModelType](field string) string {
	structField, found := reflect.TypeOf(*new(T)).FieldByName(field)
	if !found {
		return field
	}
	tag, hasTag := structField.Tag.Lookup("json")
	if !hasTag {
		return field
	}
	if name, _, _ := strings.Cut(tag, ","); name != "" {
		return name
	}
	return field
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package repositories
// EQ builds an equality ("=") condition for the given field and value.
func EQ(field string, value interface{}) QueryCondition {
	cond := QueryCondition{Field: field, Value: value}
	cond.Operator = "="
	return cond
}
// GT builds a greater-than (">") condition for the given field and value.
func GT(field string, value interface{}) QueryCondition {
	cond := QueryCondition{Field: field, Value: value}
	cond.Operator = ">"
	return cond
}
// GTE builds a greater-than-or-equal (">=") condition for the given field and value.
func GTE(field string, value interface{}) QueryCondition {
	cond := QueryCondition{Field: field, Value: value}
	cond.Operator = ">="
	return cond
}
// LT builds a less-than ("<") condition for the given field and value.
func LT(field string, value interface{}) QueryCondition {
	cond := QueryCondition{Field: field, Value: value}
	cond.Operator = "<"
	return cond
}
// LTE builds a less-than-or-equal ("<=") condition for the given field and value.
func LTE(field string, value interface{}) QueryCondition {
	cond := QueryCondition{Field: field, Value: value}
	cond.Operator = "<="
	return cond
}
// IN builds an "IN" membership condition for the given field and candidate values.
func IN(field string, values []interface{}) QueryCondition {
	cond := QueryCondition{Field: field, Value: values}
	cond.Operator = "IN"
	return cond
}
// LIKE builds a "LIKE" pattern-match condition for the given field and pattern.
func LIKE(field, pattern string) QueryCondition {
	cond := QueryCondition{Field: field, Value: pattern}
	cond.Operator = "LIKE"
	return cond
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package repositories
import (
"fmt"
"reflect"
)
// UpdateField is a generic function that updates a field of a struct or a
// pointer to a struct, converting newValue to the field's type when the types
// are compatible. It returns the (possibly updated) input and an error when
// the field is missing, not settable, or the value cannot be converted.
//
// Fixes in this revision:
//   - A nil newValue previously made reflect.TypeOf(newValue) return nil and
//     the subsequent ConvertibleTo call panic; it now returns the
//     incompatible-conversion error instead.
//   - The "incompatible conversion" message printed the types in reversed
//     order (field type -> value type); it now reads value type -> field type.
func UpdateField[T interface{}](input T, fieldName string, newValue interface{}) (T, error) {
	// Use reflection to get the struct's field.
	val := reflect.ValueOf(input)
	if val.Kind() == reflect.Ptr {
		// If input is a pointer, get the underlying element.
		val = val.Elem()
	} else {
		// Non-pointer inputs are not addressable; re-take the value through
		// a pointer so the field can be set.
		val = reflect.ValueOf(&input).Elem()
	}
	// Check if the input is a struct.
	if val.Kind() != reflect.Struct {
		return input, fmt.Errorf("not a struct: %T", input)
	}
	// Get the field by name.
	field := val.FieldByName(fieldName)
	if !field.IsValid() {
		return input, fmt.Errorf("field not found: %v", fieldName)
	}
	// Check if the field is settable (exported, addressable).
	if !field.CanSet() {
		return input, fmt.Errorf("field not settable: %v", fieldName)
	}
	// reflect.TypeOf(nil) is nil; guard before calling ConvertibleTo on it.
	if newValue == nil || !reflect.TypeOf(newValue).ConvertibleTo(field.Type()) {
		return input, fmt.Errorf(
			"incompatible conversion: %v -> %v; value: %v",
			reflect.TypeOf(newValue), field.Type(), newValue,
		)
	}
	// Convert the new value to the field type and set it.
	field.Set(reflect.ValueOf(newValue).Convert(field.Type()))
	return input, nil
}
// IsEmptyValue reports whether value is nil, a zero value, or a pointer to a
// zero value. It is useful for determining whether a struct (or a pointer to
// one) has all fields at their zero values.
//
// Fix: a typed nil pointer previously panicked — Elem() on a nil pointer
// yields an invalid reflect.Value and Value.IsZero panics on it. A nil
// pointer is now reported as empty.
func IsEmptyValue(value interface{}) bool {
	// A nil interface is trivially empty.
	if value == nil {
		return true
	}
	val := reflect.ValueOf(value)
	// If the value is a pointer, dereference it to inspect the pointee.
	if val.Kind() == reflect.Ptr {
		val = val.Elem()
	}
	// A nil pointer dereferences to an invalid Value; treat it as empty
	// instead of letting IsZero panic.
	if !val.IsValid() {
		return true
	}
	return val.IsZero()
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package dms
import (
"context"
"crypto/ed25519"
_ "embed"
"errors"
"fmt"
"os"
"path/filepath"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/libp2p/go-libp2p/core/crypto"
ma "github.com/multiformats/go-multiaddr"
"github.com/oschwald/geoip2-golang"
clover "github.com/ostafen/clover/v2"
"github.com/spf13/afero"
"go.elastic.co/apm/module/apmgin/v2"
"gitlab.com/nunet/device-management-service/api"
clover_db "gitlab.com/nunet/device-management-service/db/clover"
"gitlab.com/nunet/device-management-service/dms/node"
"gitlab.com/nunet/device-management-service/dms/node/geolocation"
"gitlab.com/nunet/device-management-service/dms/onboarding"
"gitlab.com/nunet/device-management-service/dms/orchestrator"
"gitlab.com/nunet/device-management-service/dms/resources"
"gitlab.com/nunet/device-management-service/gateway/provider"
"gitlab.com/nunet/device-management-service/gateway/provider/local"
gatewastore "gitlab.com/nunet/device-management-service/gateway/store"
backgroundtasks "gitlab.com/nunet/device-management-service/internal/background_tasks"
"gitlab.com/nunet/device-management-service/internal/config"
"gitlab.com/nunet/device-management-service/lib/crypto/keystore"
"gitlab.com/nunet/device-management-service/lib/did"
"gitlab.com/nunet/device-management-service/lib/env"
"gitlab.com/nunet/device-management-service/lib/hardware"
"gitlab.com/nunet/device-management-service/lib/ucan"
"gitlab.com/nunet/device-management-service/network/libp2p"
"gitlab.com/nunet/device-management-service/observability"
"gitlab.com/nunet/device-management-service/storage"
"gitlab.com/nunet/device-management-service/storage/volume/glusterfs/controller"
"gitlab.com/nunet/device-management-service/tokenomics/store"
"gitlab.com/nunet/device-management-service/tokenomics/store/payment"
payment_quote "gitlab.com/nunet/device-management-service/tokenomics/store/payment_quote"
"gitlab.com/nunet/device-management-service/tokenomics/store/transaction"
"gitlab.com/nunet/device-management-service/tokenomics/store/usage"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils/sys"
)
// geoLite2Country holds the embedded GeoLite2 country database bytes, loaded
// into a geoip2 reader at startup (see NewDMS).
//
//go:embed node/data/GeoLite2-Country.mmdb
var geoLite2Country []byte
// DMS bundles the top-level runtime components of the Device Management Service.
type DMS struct {
	P2P *libp2p.Libp2p // libp2p networking stack
	Node *node.Node // the DMS actor node
	RestServer *api.Server // REST API server
}
// initialize prepares the host environment before the DMS starts: it creates
// the work/data/user directories, optionally silences libp2p connection logs,
// and (re)builds the iptables NUNET chain. All failures are logged and
// tolerated — the caller proceeds regardless.
func initialize(fs afero.Fs, cfg *config.Config, env env.EnvironmentProvider) {
	// Create each configured directory (0700); empty paths are skipped.
	workDir := cfg.WorkDir
	if workDir != "" {
		err := fs.MkdirAll(workDir, os.FileMode(0o700))
		if err != nil {
			log.Warnf("unable to create work directory: %v", err)
		}
	}
	dataDir := cfg.DataDir
	if dataDir != "" {
		err := fs.MkdirAll(dataDir, os.FileMode(0o700))
		if err != nil {
			log.Warnf("unable to create data directory: %v", err)
		}
	}
	userDir := cfg.UserDir
	if userDir != "" {
		err := fs.MkdirAll(userDir, os.FileMode(0o700))
		if err != nil {
			log.Warnf("unable to create user directory: %v", err)
		}
	}
	// Connection logs are silenced unless DMS_CONN_LOGS is set to something
	// other than "false" (unset also silences them).
	libp2pLogging := env.Getenv("DMS_CONN_LOGS")
	if libp2pLogging == "false" || libp2pLogging == "" {
		err := silenceConnLogs()
		if err != nil {
			log.Warnf("unable to set libp2p logging: %v", err)
		}
	}
	// create the iptables NUNET chain if it doesn't exist, flush any rules in there and create jump rules
	err := sys.CreateNuNetChain()
	if err != nil {
		log.Errorf("unable to create iptables NUNET chain: %v", err)
	}
	err = sys.FlushNuNetChain()
	if err != nil {
		log.Errorf("unable to flush iptables NUNET chain: %v", err)
	}
	err = sys.AddJumpRules()
	if err != nil {
		log.Errorf("unable to add iptables NUNET jump rules: %v", err)
	}
}
func NewDMS(fs afero.Fs, gcfg *config.Config, env env.EnvironmentProvider, ksPassphrase, contextName string) (*DMS, error) {
log.Debugf("starting dms with config: %+v", gcfg)
if contextName == "" {
contextName = node.DefaultContextName
}
// if bootstrap peers were passed by env var then override them
btPeers := env.Getenv("BOOTSTRAP_PEERS")
if btPeers != "" {
peers := strings.Split(btPeers, ",")
gcfg.P2P.BootstrapPeers = peers
}
initialize(fs, gcfg, env)
var volumeController *controller.GlusterController
if gcfg.StorageMode {
var err error
volumeController, err = controller.NewGlusterController(gcfg.StorageGlusterfsHostname, gcfg.StorageBricksDir, gcfg.StorageCADirectory)
if err != nil {
return nil, fmt.Errorf("failed to create glusterfs controller: %w", err)
}
if !volumeController.IsServerWorking() {
return nil, errors.New("failed to start in storage mode")
}
}
geoip2db, err := geoip2.FromBytes(geoLite2Country)
if err != nil {
return nil, fmt.Errorf("unable to load geoip2 database: %w", err)
}
log.Debugf("loaded geoip2 database: %v", geoip2db)
keyStoreDir := filepath.Join(gcfg.UserDir, node.KeystoreDir)
keyStore, err := keystore.New(fs, keyStoreDir, false)
if err != nil {
return nil, fmt.Errorf("unable to create keystore: %w", err)
}
privK, err := GetPrivKeyFromKS(keyStore, ksPassphrase, contextName)
if err != nil {
return nil,
fmt.Errorf("private key from keystore: %w", err)
}
pubKey := privK.GetPublic()
db, err := NewDMSDB(gcfg.General.WorkDir)
if err != nil {
return nil, fmt.Errorf("unable to connect to database: %w", err)
}
contractStore, err := store.New(db)
if err != nil {
return nil, fmt.Errorf("unable to create contract store: %w", err)
}
// payment validator
paymentsStore, err := payment.New(db)
if err != nil {
return nil, fmt.Errorf("unable to create payment store: %w", err)
}
usageStore, err := usage.New(db)
if err != nil {
return nil, fmt.Errorf("unable to create usage store: %w", err)
}
hardwareManager := hardware.NewHardwareManager()
repos := resources.ManagerRepos{
OnboardedResources: clover_db.NewGenericEntityRepository[types.OnboardedResources](db),
ResourceAllocation: clover_db.NewGenericRepository[types.ResourceAllocation](db),
}
resourceManager, err := resources.NewResourceManager(repos, hardwareManager)
if err != nil {
return nil, fmt.Errorf("unable to create resource manager: %w", err)
}
onboardRepo := clover_db.NewGenericEntityRepository[types.OnboardingConfig](db)
// Create deployment store for orchestrator registry
deploymentStore, err := orchestrator.NewCloverDeploymentStore(db)
if err != nil {
return nil, fmt.Errorf("unable to create deployment store: %w", err)
}
onboardingManager, err := onboarding.New(context.Background(), resourceManager, hardwareManager, onboardRepo)
if err != nil {
return nil, fmt.Errorf("unable to create onboarding manager: %w", err)
}
bootstrapPeers := make([]ma.Multiaddr, len(gcfg.BootstrapPeers))
for i, addr := range gcfg.BootstrapPeers {
bootstrapPeers[i], _ = ma.NewMultiaddr(addr)
}
cfg := &types.Libp2pConfig{
Env: gcfg.General.Env,
PrivateKey: privK,
BootstrapPeers: bootstrapPeers,
Rendezvous: "nunet-test",
Server: false,
Scheduler: backgroundtasks.NewScheduler(10, 1*time.Second),
DHTPrefix: "/nunet",
CustomNamespace: "/nunet-dht-1/",
ListenAddress: gcfg.P2P.ListenAddress,
PeerCountDiscoveryLimit: 40,
GracePeriodMs: 20000, // 20 seconds
Memory: gcfg.P2P.Memory,
FileDescriptors: gcfg.P2P.FileDescriptors,
}
p2pNet, err := libp2p.New(cfg, fs)
if err != nil {
return nil, fmt.Errorf("unable to create libp2p instance: %v", err)
}
if err = p2pNet.Init(gcfg); err != nil {
return nil, fmt.Errorf("unable to initialize libp2p: %v", err)
}
trustCtx, err := did.NewTrustContextWithPrivateKey(privK)
if err != nil {
return nil, fmt.Errorf("unable to create trust context: %w", err)
}
capStoreDir := filepath.Join(gcfg.UserDir, node.CapstoreDir)
capStoreFile := filepath.Join(capStoreDir, fmt.Sprintf("%s.cap", contextName))
// Check if capability context exists and if it uses a PRISM DID
// If so, add PRISM provider to trust context before loading
if _, err := fs.Stat(capStoreFile); err == nil {
// File exists, check if it's a PRISM DID
prismDIDStr, err := node.GetPrismDID(fs, gcfg.UserDir, contextName)
if err == nil && prismDIDStr != "" {
// This context has a PRISM DID association
prismDID, err := did.FromString(prismDIDStr)
if err == nil {
// Create PRISM provider and add to trust context
prismProvider, err := did.ProviderFromPRISMPrivateKey(prismDID, privK)
if err == nil {
trustCtx.AddProvider(prismProvider)
}
}
}
}
capCtx, err := LoadOrCreateCapCtx(
fs, capStoreFile, trustCtx, contextName, pubKey)
if err != nil {
return nil,
fmt.Errorf(
"unable to load or create capability context: %w", err)
}
trustCtx.Start(time.Hour)
capCtx.Start(5 * time.Minute)
hostLocation := geolocation.Geolocation{
Continent: gcfg.HostContinent,
Country: gcfg.HostCountry,
City: gcfg.HostCity,
}
portConfig := node.PortConfig{
AvailableRangeFrom: gcfg.PortAvailableRangeFrom,
AvailableRangeTo: gcfg.PortAvailableRangeTo,
}
volumeTracker := storage.NewVolumeTracker()
txStore, err := transaction.New(db)
if err != nil {
return nil, fmt.Errorf("failed to create transaction store: %w", err)
}
paymentQuoteStore, err := payment_quote.New(db)
if err != nil {
return nil, fmt.Errorf("failed to create payment quote store: %w", err)
}
factories := provider.NewProviderFactoryRegistry(capCtx.DID().URI)
// add local incus to the factory
local.RegisterFactory(factories)
provRegistry, err := buildProviderRegistry(gcfg, factories)
if err != nil {
log.Fatalf("failed to build provider registry: %v", err)
}
if gcfg.General.ComputeGateway {
if os.Getenv("DMS_BINARY_PATH") == "" {
log.Fatal("DMS_BINARY_PATH env var not set: compute gateway needs absolute path to dms binary")
}
}
provisionedResourceStore, err := gatewastore.New(db)
if err != nil {
log.Fatalf("failed to prepate gateway store: %v", err)
}
hostID := p2pNet.Host.ID().String()
node, err := node.New(*gcfg, afero.Afero{Fs: fs}, onboardingManager,
capCtx, hostID, p2pNet, resourceManager, cfg.Scheduler, hardwareManager,
geoip2db, hostLocation, portConfig, volumeTracker,
volumeController,
contractStore,
paymentsStore,
usageStore,
txStore,
deploymentStore,
provRegistry,
provisionedResourceStore,
paymentQuoteStore,
)
if err != nil {
return nil, fmt.Errorf("failed to create node: %s", err)
}
// initialize rest api server
restConfig := api.ServerConfig{
P2P: p2pNet,
Onboarding: onboardingManager,
Resource: resourceManager,
Middlewares: nil,
Port: gcfg.Rest.Port,
Addr: gcfg.Rest.Addr,
}
// Add APM middleware by appending to restConfig.MidW
restConfig.Middlewares = append(restConfig.Middlewares, apmgin.Middleware(gin.Default()))
rServer := api.NewServer(&restConfig, gcfg)
rServer.SetupRoutes()
return &DMS{
P2P: p2pNet,
Node: node,
RestServer: rServer,
}, nil
}
// buildProviderRegistry instantiates each provider declared in the
// configuration through the factory registry and registers it into a
// fresh provider registry. It fails on the first provider that cannot
// be created.
func buildProviderRegistry(gcfg *config.Config, factories *provider.FactoryRegistry) (*provider.Registry, error) {
	registry := provider.NewProviderRegistry()
	for _, providerCfg := range gcfg.General.Providers {
		instance, err := factories.Create(providerCfg.Type, providerCfg.Config)
		if err != nil {
			return nil, fmt.Errorf("failed to create provider %q: %w", providerCfg.Type, err)
		}
		registry.Register(instance)
	}
	return registry, nil
}
// Run starts the libp2p network and the node, launches the REST server and
// the capability-context update listener in the background, and finally
// starts any contracts persisted in the database. It returns an error only
// when the network or node fail to start; background failures are logged.
func (d *DMS) Run() error {
	if err := d.P2P.Start(); err != nil {
		return fmt.Errorf("unable to start libp2p: %v", err)
	}
	if err := d.Node.Start(); err != nil {
		return fmt.Errorf("failed to start node: %s", err)
	}
	// Serve the REST API in the background; a server failure is fatal.
	go func() {
		if err := d.RestServer.Run(); err != nil {
			log.Fatal(err)
		}
	}()
	// Listen for SIGUSR1 to reload capability contexts
	go func() {
		if err := d.Node.ListenForCapabilityContextsUpdates(); err != nil {
			log.Errorf("failed to listen for capability contexts updates: %v", err)
		}
	}()
	// Best effort: contract startup failures are logged, not returned.
	if err := d.Node.StartContracts(); err != nil {
		log.Errorf("failed to start contracts from db: %v", err)
	}
	return nil
}
// Stop shuts down the node first and then the libp2p network, logging (but
// never returning) failures, and finally flushes observability state.
func (d *DMS) Stop() {
	log.Infof("Shutting down DMS")
	if n := d.Node; n != nil {
		if err := n.Stop(); err != nil {
			log.Errorf("failed to stop node: %s", err)
		}
	}
	log.Infof("node stopped")
	if p := d.P2P; p != nil {
		if err := p.Stop(); err != nil {
			log.Errorf("failed to stop libp2p network: %s", err)
		}
	}
	log.Infof("network stopped")
	observability.Shutdown()
	// TODO: stop rest server
}
// GenerateAndStorePrivKey generates a new Ed25519 key pair, storing the
// private key into the user's keystore under keyID, protected by passphrase.
// It returns the generated private key.
//
// (The previous doc comment claimed Secp256k1; the key type actually
// generated here is Ed25519.)
func GenerateAndStorePrivKey(ks keystore.KeyStore, passphrase string, keyID string) (crypto.PrivKey, error) {
	// NOTE(review): the bit-size argument is presumably ignored for Ed25519
	// (fixed-size keys) — confirm against lib/crypto.GenerateKeyPair.
	privK, _, err := crypto.GenerateKeyPair(crypto.Ed25519, 256)
	if err != nil {
		return nil, fmt.Errorf("unable to generate key pair: %w", err)
	}
	// Persist in Protobuf-marshaled form.
	rawPriv, err := crypto.MarshalPrivateKey(privK)
	if err != nil {
		return nil, fmt.Errorf("unable to marshal private key: %w", err)
	}
	if _, err = ks.Save(keyID, rawPriv, passphrase); err != nil {
		return nil, fmt.Errorf("unable to save private key into the keystore: %w", err)
	}
	return privK, nil
}
// ImportAndStorePrivKey validates the provided private key bytes and stores
// them into the user's keystore under keyID, protected by passphrase.
//
// The input may be a Protobuf-marshaled key, a raw 32-byte Ed25519 seed, or
// a raw 64-byte Ed25519 private key; regardless of input form, the key is
// always stored in Protobuf format. The parsed key is returned.
func ImportAndStorePrivKey(ks keystore.KeyStore, rawPriv []byte, passphrase string, keyID string) (crypto.PrivKey, error) {
	privK, err := crypto.UnmarshalPrivateKey(rawPriv)
	if err != nil {
		// Not Protobuf-encoded; try to interpret as raw Ed25519 material.
		switch len(rawPriv) {
		case 32:
			// assume it's a seed
			privK, err = crypto.UnmarshalEd25519PrivateKey(ed25519.NewKeyFromSeed(rawPriv))
		case 64:
			// assume it's a full private key
			privK, err = crypto.UnmarshalEd25519PrivateKey(rawPriv)
		}
		// Any other length keeps the original unmarshal error.
		if err != nil {
			return nil, fmt.Errorf("invalid private key format: %w", err)
		}
	}
	// ensure we store the key in Protobuf format, regardless of input
	marshaled, err := crypto.MarshalPrivateKey(privK)
	if err != nil {
		return nil, fmt.Errorf("unable to marshal private key: %w", err)
	}
	if _, err = ks.Save(keyID, marshaled, passphrase); err != nil {
		return nil, fmt.Errorf("unable to save private key into the keystore: %w", err)
	}
	return privK, nil
}
// NewDMSDB creates a clover database at path, pre-declaring all collections
// known to the DMS.
func NewDMSDB(path string) (*clover.DB, error) {
	collections := []string{
		"free_resources",
		"request_tracker",
		"onboarded_resources",
		"machine_resources",
		"onboarding_config",
		"resource_allocation",
		"deployments",
		"gpu",
		"contracts",
		"contracts_keys",
		"provisioned_resources",
		"contracts_payments",
		"service_provider_transactions",
		"contracts_usage",
		"usage_metadata",
		"payment_quotes",
	}
	return clover_db.NewDB(path, collections)
}
// GetPrivKeyFromKS returns the private key stored under contextName in the
// user's keystore. If no key exists yet, a new one is generated, stored and
// returned.
//
// Fix: errors are now wrapped with %w (previously %v), preserving the error
// chain for errors.Is/errors.As at the caller.
func GetPrivKeyFromKS(
	keyStore keystore.KeyStore, ksPassphrase string,
	contextName string,
) (crypto.PrivKey, error) {
	ksPrivKey, err := keyStore.Get(contextName, ksPassphrase)
	if err == nil {
		privK, err := ksPrivKey.PrivKey()
		if err != nil {
			return nil, fmt.Errorf("unable to convert key from keystore to private key: %w", err)
		}
		return privK, nil
	}
	if !errors.Is(err, keystore.ErrKeyNotFound) {
		return nil, fmt.Errorf("failed to get private key from keystore: %w", err)
	}
	// Key does not exist yet: generate and persist a fresh one.
	privK, err := GenerateAndStorePrivKey(keyStore, ksPassphrase, contextName)
	if err != nil {
		return nil, fmt.Errorf("couldn't generate and store privK key into keystore: %w", err)
	}
	return privK, nil
}
// LoadOrCreateCapCtx loads a capability context from a file or creates a new one
// if it does not exist.
//
// On creation, the context is rooted at the DID derived from pubKey and is
// persisted to capStoreFile immediately. On load, the persisted context is
// read under the given contextName.
//
// Note: please use afero 'fs' arg instead of 'os'
func LoadOrCreateCapCtx(
	fs afero.Fs,
	capStoreFile string,
	trustCtx did.TrustContext,
	contextName string,
	pubKey crypto.PubKey,
) (ucan.CapabilityContext, error) {
	var capCtx ucan.CapabilityContext
	// Stat failure is treated as "file does not exist" → create a new context.
	if _, err := fs.Stat(capStoreFile); err != nil {
		capStoreDir := filepath.Dir(capStoreFile)
		// 0o700: the capability store holds sensitive material; owner-only access.
		if err := fs.MkdirAll(capStoreDir, os.FileMode(0o700)); err != nil {
			return nil, fmt.Errorf("unable to create capability context directory: %w", err)
		}
		// does not exist; create it
		rootDID := did.FromPublicKey(pubKey)
		capCtx, err = ucan.NewCapabilityContextWithName(contextName, trustCtx, rootDID, nil, ucan.TokenList{}, ucan.TokenList{}, ucan.TokenList{})
		if err != nil {
			return nil, fmt.Errorf("unable to create capability context: %w", err)
		}
		// Save it!
		f, err := fs.Create(capStoreFile)
		if err != nil {
			return nil, fmt.Errorf("unable to create capability context file: %w", err)
		}
		// Close unconditionally before inspecting the save error so the file
		// handle is always released.
		err = ucan.SaveCapabilityContext(capCtx, f)
		_ = f.Close()
		if err != nil {
			return nil, fmt.Errorf("unable to save capability context: %w", err)
		}
	} else {
		// File exists: load the persisted context.
		f, err := fs.Open(capStoreFile)
		if err != nil {
			return nil, fmt.Errorf("unable to open capability context: %w", err)
		}
		capCtx, err = ucan.LoadCapabilityContextWithName(contextName, trustCtx, f)
		_ = f.Close()
		if err != nil {
			return nil, fmt.Errorf("unable to load capability context: %w", err)
		}
	}
	return capCtx, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package jobs
import (
"context"
"fmt"
"math"
"path/filepath"
"sync"
"time"
"github.com/google/uuid"
"github.com/spf13/afero"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/dms/behaviors"
jobtypes "gitlab.com/nunet/device-management-service/dms/jobs/types"
"gitlab.com/nunet/device-management-service/dms/orchestrator"
"gitlab.com/nunet/device-management-service/network"
"gitlab.com/nunet/device-management-service/observability"
"gitlab.com/nunet/device-management-service/tokenomics/eventhandler"
"gitlab.com/nunet/device-management-service/tokenomics/events"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils"
)
const (
	// deleteLogsAfter bounds how long execution logs are persisted by the
	// executor (passed as PersistLogsDuration in the execution request).
	deleteLogsAfter = 30 * time.Minute
	// Liveness reporting configuration
	// livenessReportInterval is the period between push-based liveness
	// reports for service allocations.
	livenessReportInterval = 30 * time.Second
	// livenessReportTimeout and livenessMaxRetries are presumably consumed by
	// the liveness report consumer/retry logic — confirm usage elsewhere in
	// the package.
	livenessReportTimeout = 2 * time.Minute
	livenessMaxRetries    = 3
)
// AllocationInfo gathers useful internal information for external callers
// (e.g. API/status endpoints); see Allocation.Info.
type AllocationInfo struct {
	ID           string                  `json:"id"`
	Type         jobtypes.AllocationType `json:"type"`
	Resources    types.Resources         `json:"resources"`
	Orchestrator string                  `json:"orchestrator"` // peerID
	Status       string                  `json:"status"`
	// Executor is the execution engine type from the job's execution spec.
	Executor    string `json:"executor"`
	ExecutionID string `json:"execution_id"`
	// UsingPorts lists the host ports currently mapped for this allocation.
	UsingPorts []int     `json:"using_ports,omitempty"`
	CreatedAt  time.Time `json:"created_at"`
	StartedAt  time.Time `json:"started_at"`
}
// Status holds the status of an allocation together with the resources
// requested by its job.
type Status struct {
	JobResources types.Resources  // resources declared in the job spec
	Status       AllocationStatus // current allocation lifecycle status
}
// AllocationDetails encapsulates the dependencies to the constructor.
// TODO: rename and organize general dependencies of allocation
type AllocationDetails struct {
	Job      Job    // workload specification for this allocation
	NodeID   string // identifier of the node hosting the allocation
	SourceID string
}
// Job describes the workload of an allocation: requested resources, execution
// engine spec, provisioning scripts, keys and volumes.
// TODO: remove this struct and move everything to AllocationDetails
type Job struct {
	Resources        types.Resources
	Execution        types.SpecConfig
	ProvisionScripts map[string][]byte
	Keys             []types.AllocationKey
	Volume           []types.VolumeConfig
}
// allocationLiveness contains state for push-based liveness reporting.
type allocationLiveness struct {
	enabled        bool
	interval       time.Duration // period between liveness reports
	sequenceNumber int64
	cancel         context.CancelFunc // stops the reporting goroutine
	lock           sync.Mutex         // guards mutation of cancel
}
// Allocation represents a single workload allocation on this node: its actor,
// executor, job specification, networking state and contract bookkeeping.
type Allocation struct {
	ID        string
	allocType jobtypes.AllocationType
	Actor     actor.Actor
	// actorRunning tracks whether the actor has been started (guarded by lock).
	actorRunning       bool
	status             AllocationStatus
	nodeID             string
	sourceID           string
	computeProviderDID string
	deploymentID       string
	// orchestrator is the handle used to notify the orchestrator of events.
	orchestrator actor.Handle
	executor     types.Executor
	executionID  string
	Job          Job
	network      network.Network
	// TODO: create separated type for vpn info
	state struct {
		subnetIP  string
		gatewayIP string
		// portMapping maps host port -> executor port.
		portMapping map[int]int
	}
	resultsDir  string
	workDir     string
	lock        sync.Mutex
	fs          afero.Afero
	healthcheck func() error
	// selfRelease will use node's releaseAllocation mechanism
	selfRelease func() error
	Contracts   map[string]types.ContractConfig
	contractEventHandler *eventhandler.EventHandler
	contractStore        TailContractGetter
	createdAt time.Time
	startedAt time.Time
	// Liveness reporting state
	liveness allocationLiveness
}
// setStatus swaps the allocation status under the lock and, when notify is
// set and the status actually changed, fires the status-change notification.
func (a *Allocation) setStatus(ns AllocationStatus, msg string, notify bool) {
	a.lock.Lock()
	defer a.lock.Unlock()
	prev := a.status
	a.status = ns
	if notify && prev != ns {
		a.statusChangeNotify(prev, ns, msg)
	}
}
// TailContractGetter is an interface for finding tail contracts associated with a head contract.
// This interface is used to avoid import cycles between jobs and store packages.
// It returns ContractConfig directly to avoid referencing contracts.Contract.
type TailContractGetter interface {
	FindTailContract(headContractConfig types.ContractConfig, computeProviderDID string) (*types.ContractConfig, error)
}
// NewAllocation creates a new allocation given the actor.
//
// network, actor and executor must be non-nil. A fresh execution ID (UUID) is
// generated for the allocation, the compute provider DID is taken from the
// actor's parent handle, and the allocation starts in AllocationPending.
// Push-based liveness reporting is enabled unconditionally for now.
func NewAllocation(
	id string,
	allocType jobtypes.AllocationType,
	orchestrator actor.Handle,
	fs afero.Afero,
	workDir string,
	actor actor.Actor,
	details AllocationDetails,
	network network.Network,
	executor types.Executor,
	selfRelease func() error,
	contractEventHandler *eventhandler.EventHandler,
	contractStore TailContractGetter,
	deploymentID string,
) (*Allocation, error) {
	if network == nil {
		return nil, fmt.Errorf("network is nil")
	}
	if actor == nil {
		return nil, fmt.Errorf("actor is nil")
	}
	if executor == nil {
		return nil, fmt.Errorf("executor is nil")
	}
	executionID, err := uuid.NewUUID()
	if err != nil {
		return nil, fmt.Errorf("create executor id: %w", err)
	}
	allocation := &Allocation{
		ID:           id,
		allocType:    allocType,
		fs:           fs,
		nodeID:       details.NodeID,
		sourceID:     details.SourceID,
		orchestrator: orchestrator,
		Job:          details.Job,
		Actor:        actor,
		executionID:  executionID.String(),
		workDir:      workDir,
		status:       AllocationPending,
		network:      network,
		executor:     executor,
		selfRelease:  selfRelease,
		state: struct {
			subnetIP    string
			gatewayIP   string
			portMapping map[int]int
		}{},
		createdAt:            time.Now(),
		contractEventHandler: contractEventHandler,
		contractStore:        contractStore,
		computeProviderDID:   actor.Parent().DID.URI,
		deploymentID:         deploymentID,
	}
	// Initialize liveness reporting state
	allocation.liveness.enabled = true // hard coded for now
	allocation.liveness.interval = livenessReportInterval
	log.Debugw("allocation_created",
		"labels", string(observability.LabelAllocation),
		"allocationID", allocation.ID,
		"allocDID", allocation.Actor.Handle().DID.String(),
		"executionID", allocation.executionID,
	)
	return allocation, nil
}
// GetPortMapping returns a copy of the allocation's host→executor port
// mapping so callers cannot mutate internal state.
func (a *Allocation) GetPortMapping() map[int]int {
	a.lock.Lock()
	defer a.lock.Unlock()
	out := make(map[int]int, len(a.state.portMapping))
	for hostPort, executorPort := range a.state.portMapping {
		out[hostPort] = executorPort
	}
	return out
}
// findTailContractForHeadContract resolves the tail contract associated with
// the given head contract config through the contract store, scoped to this
// allocation's compute provider DID.
func (a *Allocation) findTailContractForHeadContract(headContractConfig types.ContractConfig) (*types.ContractConfig, error) {
	if a.contractStore == nil {
		return nil, fmt.Errorf("contract store is not available")
	}
	tail, err := a.contractStore.FindTailContract(headContractConfig, a.computeProviderDID)
	if err != nil {
		return nil, fmt.Errorf("failed to find tail contracts for head contract %s: %w", headContractConfig.DID, err)
	}
	return tail, nil
}
// Run builds the execution request from the allocation's job spec (volumes,
// keys, port bindings), starts the executor, and pushes StartAllocation
// events to the relevant contracts (tail contract when a head-contract chain
// exists, otherwise the ensemble contracts).
//
// Fix: the allocation now transitions to AllocationRunning only when the
// executor actually started. Previously the deferred setStatus fired on
// every return path, wrongly marking failed starts (e.g. results-dir
// creation failure) as running.
func (a *Allocation) Run(
	ctx context.Context, subnetIP string,
	gatewayIP string, portMapping map[int]int,
) error {
	a.lock.Lock()
	started := false
	defer func() {
		// setStatus acquires the lock itself, so it must run after Unlock.
		a.lock.Unlock()
		if started {
			a.setStatus(AllocationRunning, "allocation started", true)
		}
	}()
	if a.status == AllocationRunning {
		log.Warnw("allocation_already_running",
			"labels", string(observability.LabelAllocation),
			"allocationID", a.ID)
		// TODO: Should we return error instead?
		return nil
	}
	a.resultsDir = filepath.Join(a.workDir, "jobs", a.ID)
	if err := a.fs.MkdirAll(a.resultsDir, 0o700); err != nil {
		return fmt.Errorf("create results directory: %w", err)
	}
	executionRequest := &types.ExecutionRequest{
		JobID:               a.ID,
		ExecutionID:         a.executionID,
		EngineSpec:          &a.Job.Execution,
		Resources:           &a.Job.Resources,
		ProvisionScripts:    a.Job.ProvisionScripts,
		Keys:                a.Job.Keys,
		ResultsDir:          a.resultsDir,
		PersistLogsDuration: deleteLogsAfter,
		GatewayIP:           gatewayIP,
	}
	// prepare the directories on host
	if len(a.Job.Volume) > 0 {
		executionRequest.Inputs = make([]*types.StorageVolumeExecutor, 0)
		for _, v := range a.Job.Volume {
			src := ""
			if v.Type == "glusterfs" {
				src = filepath.Join(a.workDir, "volumes", a.ID, v.Name)
			} else {
				src = v.Src
			}
			target := v.MountDestination
			if target == "" {
				target = "/" + v.Name
			}
			executionRequest.Inputs = append(executionRequest.Inputs, &types.StorageVolumeExecutor{
				Type:     "bind",
				Source:   src,
				Target:   target,
				ReadOnly: v.ReadOnly,
			})
		}
	}
	for hostPort, executorPort := range portMapping {
		executionRequest.PortsToBind = append(
			executionRequest.PortsToBind,
			types.PortsToBind{
				IP:           subnetIP,
				HostPort:     hostPort,
				ExecutorPort: executorPort,
			},
		)
	}
	if err := a.executor.Start(ctx, executionRequest); err != nil {
		return fmt.Errorf("start executor: %w", err)
	}
	// Executor is up: record the start time and allow the deferred
	// status transition to AllocationRunning.
	started = true
	a.startedAt = time.Now()
	// Find Head Contract config from ensemble contracts (grabs an arbitrary
	// entry from the map).
	var headContractConfig types.ContractConfig
	for _, contractConfig := range a.Contracts {
		headContractConfig = contractConfig
		break
	}
	headContractDID := headContractConfig.DID
	// Find Tail Contracts associated with Head Contract
	var contractsToNotify []types.ContractConfig
	if headContractDID != "" {
		// Contract chain scenario: use Tail Contracts
		tailContract, err := a.findTailContractForHeadContract(headContractConfig)
		if err != nil {
			log.Warnw("failed to find tail contracts, falling back to ensemble contracts",
				"head_contract_did", headContractDID,
				"error", err)
			// Convert map to slice for fallback
			for _, v := range a.Contracts {
				contractsToNotify = append(contractsToNotify, v)
			}
		} else {
			contractsToNotify = []types.ContractConfig{*tailContract}
		}
	} else {
		// P2P scenario: use contracts from ensemble
		for _, v := range a.Contracts {
			contractsToNotify = append(contractsToNotify, v)
		}
	}
	// Send events to Tail Contracts (or P2P contracts)
	for _, v := range contractsToNotify {
		evt := events.StartAllocation{
			EventBase: events.EventBase{Type: events.StartAllocationEvent},
			AllocationBase: events.AllocationBase{
				AllocationID:       a.ID,
				DeploymentID:       a.deploymentID,
				ComputeProviderDID: a.computeProviderDID,
				HeadContractDID:    headContractDID, // Include Head Contract DID in payload
			},
			Resources: a.Job.Resources,
		}
		a.contractEventHandler.Push(eventhandler.Event{
			ContractHostDID: v.Host,
			ContractDID:     v.DID,
			Payload:         evt,
		})
	}
	// Log the resources we've assigned for this run
	log.Infow("allocation_run_started",
		"labels", string(observability.LabelAllocation),
		"allocationID", a.ID,
		"cpuCoresAssigned", a.Job.Resources.CPU.Cores,
		"ramGBAssigned", a.Job.Resources.RAM.SizeInGB(),
		"gpuCountAssigned", len(a.Job.Resources.GPUs),
	)
	if a.allocType == jobtypes.AllocationTypeTask {
		// Tasks report termination via handleTransience when they exit.
		go a.handleExecutionExit(ctx)
	} else {
		// Start periodic liveness reporting for service allocations
		a.startLivenessReporting(ctx)
	}
	return nil
}
// handleExecutionExit blocks until the execution finishes (or the context is
// cancelled) and forwards the outcome to handleTransience.
//
// TODO: retry policy for transient and long-running allocations
func (a *Allocation) handleExecutionExit(ctx context.Context) {
	resChan, errChan := a.executor.Wait(ctx, a.executionID)
	var (
		res     *types.ExecutionResult
		waitErr error
	)
	select {
	case res = <-resChan:
	case waitErr = <-errChan:
	case <-ctx.Done():
		waitErr = ctx.Err()
	}
	a.handleTransience(res, waitErr)
}
// handleTransience handles the exit of an execution for transient
// allocations: it classifies the result, notifies the orchestrator,
// self-releases the allocation and pushes CompleteAllocation events to the
// associated contracts.
//
// TODO: retry policy (meanwhile, we'll teardown everything in case of error)
func (a *Allocation) handleTransience(r *types.ExecutionResult, err error) {
	notifyOrchestrator := func(req behaviors.TaskTerminationNotification) {
		req.AllocationID = a.ID
		req.Status = string(a.status)
		// send logs if existent
		if r != nil {
			if len(r.STDOUT) > 0 {
				req.Stdout = []byte(r.STDOUT)
			}
			if len(r.STDERR) > 0 {
				req.Stderr = []byte(r.STDERR)
			}
		}
		msg, err := actor.Message(
			a.Actor.Handle(),
			a.orchestrator,
			behaviors.NotifyTaskTerminationBehavior,
			req,
			actor.WithMessageExpiry(uint64(time.Now().Add(2*time.Minute).UnixNano())),
		)
		if err != nil {
			log.Errorf("error creating task termination notification: %s", err)
			// Fix: previously the zero-value message was still sent after a
			// construction failure; bail out instead.
			return
		}
		if err := a.Actor.Send(msg); err != nil {
			log.Errorf("error notifying orchestrator: %s", err)
		}
	}
	if err != nil {
		log.Warnf("execution failed: %v", err)
		a.setStatus(AllocationFailed, "execution failed", true)
		exitCode := 0
		if r != nil {
			exitCode = r.ExitCode
		}
		notifyOrchestrator(behaviors.TaskTerminationNotification{
			Error: behaviors.TerminationError{
				ExitCode: exitCode,
				Err:      fmt.Sprintf("general execution failure: %v", err),
			},
		})
	} else if r != nil {
		switch {
		case r.ExitCode != 0:
			log.Infof("execution exited with exit code: %d", r.ExitCode)
			a.setStatus(AllocationFailed, fmt.Sprintf("execution exited with exit code: %d", r.ExitCode), true)
			notifyOrchestrator(behaviors.TaskTerminationNotification{
				Error: behaviors.TerminationError{
					ExitCode: r.ExitCode,
					Err:      fmt.Sprintf("execution exit code != 0, exit code: %d", r.ExitCode),
				},
			})
		case r.ExitCode == 0 && !r.Killed:
			log.Infof("task execution successfully completed")
			a.setStatus(AllocationCompleted, "task execution successfully completed", true)
			notifyOrchestrator(behaviors.TaskTerminationNotification{})
		case r.ExitCode == 0 && r.Killed:
			log.Infof("execution possibly killed")
			a.setStatus(AllocationFailed, "execution possibly killed", true)
			notifyOrchestrator(behaviors.TaskTerminationNotification{
				Error: behaviors.TerminationError{
					ExitCode: r.ExitCode,
					Err:      "execution possibly killed",
					Killed:   true,
				},
			})
		}
	}
	log.Debugf("self releasing: %s", a.ID)
	if err := a.selfRelease(); err != nil {
		log.Errorf("error releasing self: %s", err)
	}
	if len(a.Contracts) == 0 {
		// Fix: this was log.Errorf (printf-style) with structured key/value
		// arguments, which mangles the output; use the structured variant.
		log.Errorw("no contracts (handleTransience)",
			"labels", string(observability.LabelAllocation),
			"allocationID", a.ID)
		return
	}
	// Pick the (arbitrary) head contract config from the ensemble map.
	var headContractConfig types.ContractConfig
	for _, contractConfig := range a.Contracts {
		headContractConfig = contractConfig
		break
	}
	headContractDID := headContractConfig.DID
	// Find Tail Contracts associated with Head Contract
	var contractsToNotify []types.ContractConfig
	if headContractDID != "" {
		// Contract chain scenario: use Tail Contracts
		tailContract, err := a.findTailContractForHeadContract(headContractConfig)
		if err != nil {
			log.Warnw("failed to find tail contracts, falling back to ensemble contracts",
				"head_contract_did", headContractDID,
				"error", err)
			// Convert map to slice for fallback
			for _, v := range a.Contracts {
				contractsToNotify = append(contractsToNotify, v)
			}
		} else {
			contractsToNotify = []types.ContractConfig{*tailContract}
		}
	} else {
		// P2P scenario: use contracts from ensemble
		for _, v := range a.Contracts {
			contractsToNotify = append(contractsToNotify, v)
		}
	}
	// Send events to Tail Contracts (or P2P contracts)
	for _, v := range contractsToNotify {
		evt := events.CompleteAllocation{
			EventBase: events.EventBase{Type: events.CompleteAllocationEvent},
			AllocationBase: events.AllocationBase{
				AllocationID:       a.ID,
				DeploymentID:       a.deploymentID,
				ComputeProviderDID: a.computeProviderDID,
				HeadContractDID:    headContractDID, // Include Head Contract DID in payload
			},
		}
		a.contractEventHandler.Push(eventhandler.Event{
			ContractHostDID: v.Host,
			ContractDID:     v.DID,
			Payload:         evt,
		})
	}
}
// stopExecution cancels the running executor and updates the allocation
// status accordingly. It is a no-op when the allocation is not running or
// has no executor attached.
func (a *Allocation) stopExecution(ctx context.Context) error {
	log.Debugw("allocation_stopping_execution",
		"labels", string(observability.LabelAllocation),
		"allocationID", a.ID)
	if a.Status().Status != AllocationRunning || a.executor == nil {
		return nil
	}
	if err := a.executor.Cancel(ctx, a.executionID); err != nil {
		a.setStatus(AllocationFailed, fmt.Sprintf("error stopping executor: %v", err), true)
		return fmt.Errorf("stop execution: %w", err)
	}
	a.setStatus(AllocationStopped, "allocation stopped", true)
	log.Debugw("allocation_stopped_execution",
		"labels", string(observability.LabelAllocation),
		"allocationID", a.ID)
	return nil
}
// Cleanup removes the execution from the executor, waiting up to the
// orchestrator's allocation shutdown timeout. It is a no-op when no executor
// is attached.
func (a *Allocation) Cleanup() error {
	if a.executor == nil {
		return nil
	}
	err := a.executor.Remove(a.executionID, orchestrator.AllocationShutdownTimeout)
	if err != nil {
		return fmt.Errorf("failed to remove execution: %w", err)
	}
	log.Debugw("allocation_removed_execution",
		"labels", string(observability.LabelAllocation),
		"executionID", a.executionID)
	return nil
}
// Terminate stops the allocation and cleans up after.
//
// It first pushes StopAllocation events to the relevant contracts (the tail
// contract when a head-contract chain exists, otherwise the ensemble
// contracts), then stops the allocation if it is not already stopped or
// completed, and finally cleans up executor state. Cleanup failures are
// logged but not returned.
//
// TODO: shouldn't act on a best effort basis? meaning,
// it won't return errors right away but try to clean up
// all the other steps
func (a *Allocation) Terminate(ctx context.Context) error {
	// Pick the (arbitrary) head contract config from the ensemble map.
	var headContractConfig types.ContractConfig
	for _, contractConfig := range a.Contracts {
		headContractConfig = contractConfig
		break
	}
	headContractDID := headContractConfig.DID
	// Find Tail Contracts associated with Head Contract
	var contractsToNotify []types.ContractConfig
	if headContractDID != "" {
		// Contract chain scenario: use Tail Contracts
		tailContract, err := a.findTailContractForHeadContract(headContractConfig)
		if err != nil {
			log.Warnw("failed to find tail contracts, falling back to ensemble contracts",
				"head_contract_did", headContractDID,
				"error", err)
			// Convert map to slice for fallback
			for _, v := range a.Contracts {
				contractsToNotify = append(contractsToNotify, v)
			}
		} else {
			contractsToNotify = []types.ContractConfig{*tailContract}
		}
	} else {
		// P2P scenario: use contracts from ensemble
		for _, v := range a.Contracts {
			contractsToNotify = append(contractsToNotify, v)
		}
	}
	// Send events to Tail Contracts (or P2P contracts)
	for _, v := range contractsToNotify {
		evt := events.StopAllocation{
			EventBase: events.EventBase{Type: events.StopAllocationEvent},
			AllocationBase: events.AllocationBase{
				AllocationID:       a.ID,
				DeploymentID:       a.deploymentID,
				ComputeProviderDID: a.computeProviderDID,
				HeadContractDID:    headContractDID, // Include Head Contract DID in payload
			},
		}
		a.contractEventHandler.Push(eventhandler.Event{
			ContractHostDID: v.Host,
			ContractDID:     v.DID,
			Payload:         evt,
		})
	}
	status := a.Status().Status
	if status != AllocationStopped && status != AllocationCompleted {
		err := a.Stop(ctx)
		if err != nil {
			log.Warnw("allocation_failed_to_stop",
				"labels", string(observability.LabelAllocation),
				"error", err,
				"allocationID", a.ID)
			return fmt.Errorf("failed to stop allocation: %w", err)
		}
		// terminated status only if had to stop
		a.setStatus(AllocationTerminated, "allocation terminated", true)
	}
	// Best effort: cleanup failures are logged, not returned.
	if err := a.Cleanup(); err != nil {
		log.Warnw("allocation_failed_to_cleanup",
			"labels", string(observability.LabelAllocation),
			"error", err,
			"allocationID", a.ID)
	}
	return nil
}
// stopActor stops the allocation actor if it is running. Stop failures are
// logged but never returned; the method always reports success.
func (a *Allocation) stopActor() error {
	a.lock.Lock()
	defer a.lock.Unlock()
	if !a.actorRunning {
		return nil
	}
	if err := a.Actor.Stop(); err != nil {
		log.Warnw("allocation_actor_stop_failure",
			"labels", string(observability.LabelAllocation),
			"error", err,
			"allocationID", a.ID)
	}
	log.Debugw("allocation_actor_stopped",
		"labels", string(observability.LabelAllocation),
		"allocationID", a.ID)
	a.actorRunning = false
	return nil
}
// Stop halts liveness reporting, stops the allocation actor and then the
// running executor, leaving the allocation in AllocationStopped on success.
// TODO: shouldn't act on a best effort basis? meaning,
// it won't return errors right away but try to clean up
// all the other steps
func (a *Allocation) Stop(ctx context.Context) error {
	// Stop liveness reporting first
	a.stopLivenessReporting()
	if err := a.stopActor(); err != nil {
		return fmt.Errorf("stop actor: %w", err)
	}
	if err := a.stopExecution(ctx); err != nil {
		a.setStatus(AllocationFailed, "failed to stop execution", true)
		return fmt.Errorf("stop execution: %w", err)
	}
	a.setStatus(AllocationStopped, "allocation stopped", true)
	return nil
}
// Status returns a snapshot of the job's declared resources and the current
// execution status of the workload.
func (a *Allocation) Status() Status {
	a.lock.Lock()
	snapshot := Status{
		JobResources: a.Job.Resources,
		Status:       a.status,
	}
	a.lock.Unlock()
	return snapshot
}
// Start registers the allocation behaviors on the actor and starts it.
// Starting is idempotent: a second call while the actor is already running
// is a no-op.
func (a *Allocation) Start() error {
	a.lock.Lock()
	defer a.lock.Unlock()
	if a.actorRunning {
		return nil
	}
	// Behaviors this allocation's actor responds to.
	handlers := map[string]func(actor.Envelope){
		behaviors.AllocationStartBehavior:        a.handleAllocationStart,
		behaviors.AllocationRestartBehavior:      a.handleAllocationRestart,
		behaviors.AllocationStatsBehavior:        a.handleAllocationStats,
		behaviors.SubnetAddPeerBehavior:          a.handleSubnetAddPeer,
		behaviors.SubnetRemovePeersBehavior:      a.handleSubnetRemovePeers,
		behaviors.SubnetAcceptPeersBehavior:      a.handleSubnetAcceptPeers,
		behaviors.SubnetMapPortBehavior:          a.handleSubnetMapPort,
		behaviors.SubnetUnmapPortBehavior:        a.handleSubnetUnmapPort,
		behaviors.SubnetDNSAddRecordsBehavior:    a.handleSubnetDNSAddRecords,
		behaviors.SubnetDNSRemoveRecordsBehavior: a.handleSubnetDNSRemoveRecords,
		behaviors.RegisterHealthcheckBehavior:    a.handleRegisterHealthcheck,
		actor.HealthCheckBehavior:                a.handleHealthcheck,
	}
	for name, handler := range handlers {
		if err := a.Actor.AddBehavior(name, handler); err != nil {
			return fmt.Errorf("add allocation start behavior to allocation actor: %w", err)
		}
	}
	if err := a.Actor.Start(); err != nil {
		return fmt.Errorf("start allocation actor: %w", err)
	}
	a.actorRunning = true
	return nil
}
// Restart stops the allocation (actor and execution) and brings it back up
// using the previously recorded subnet state. It requires that the
// allocation was run at least once so that the subnet IP is known.
func (a *Allocation) Restart(ctx context.Context) error {
	if a.state.subnetIP == "" {
		// if you get this error, did you start the allocation properly before restart?
		return fmt.Errorf("allocation: state is empty, no subnet ip is provided")
	}
	steps := []func() error{
		func() error { return a.Stop(ctx) },
		a.Start,
	}
	for _, step := range steps {
		if err := step(); err != nil {
			return err
		}
	}
	if err := a.Run(ctx, a.state.subnetIP, a.state.gatewayIP, a.state.portMapping); err != nil {
		// Best effort teardown after a failed relaunch.
		_ = a.Stop(ctx)
		return fmt.Errorf("run allocation: %w", err)
	}
	return nil
}
// sendReply answers msg with payload, attaching an explicit source handle
// for broadcast messages. Failures are logged at debug level only.
// TODO: make send reply a helper func from actor pkg
func (a *Allocation) sendReply(msg actor.Envelope, payload interface{}) {
	var opts []actor.MessageOption
	if msg.IsBroadcast() {
		opts = append(opts, actor.WithMessageSource(a.Actor.Handle()))
	}
	reply, err := actor.ReplyTo(msg, payload, opts...)
	if err != nil {
		log.Debugf("creating reply: %s", err)
		return
	}
	if err = a.Actor.Send(reply); err != nil {
		log.Debugf("sending reply: %s", err)
	}
}
// SetHealthCheck installs f as the allocation's healthcheck callback,
// replacing any previously registered one. A nil f clears the check.
func (a *Allocation) SetHealthCheck(f func() error) {
	a.lock.Lock()
	a.healthcheck = f
	a.lock.Unlock()
}
// Info returns a point-in-time snapshot of the allocation's public state.
func (a *Allocation) Info() AllocationInfo {
	a.lock.Lock()
	defer a.lock.Unlock()
	snapshot := AllocationInfo{
		ID:           a.ID,
		Type:         a.allocType,
		Orchestrator: a.orchestrator.Address.HostID,
		Resources:    a.Job.Resources,
		Status:       string(a.status),
		Executor:     a.Job.Execution.Type,
		// NOTE(review): ExecutionID mirrors a.ID here rather than
		// a.executionID — confirm this is intentional.
		ExecutionID: a.ID,
		UsingPorts:  utils.MapKeysToSlice(a.state.portMapping),
		CreatedAt:   a.createdAt,
		StartedAt:   a.startedAt,
	}
	return snapshot
}
// startLivenessReporting starts periodic push-based liveness reporting.
// Only for service allocations - tasks use handleTransience.
//
// The reporting goroutine sends one heartbeat immediately and then one per
// a.liveness.interval, until the derived context is cancelled (via
// stopLivenessReporting) or the parent ctx ends.
func (a *Allocation) startLivenessReporting(ctx context.Context) {
	if a.allocType == jobtypes.AllocationTypeTask {
		return // Tasks already push via handleTransience
	}
	if !a.liveness.enabled {
		log.Debugw("push_liveness_disabled",
			"labels", string(observability.LabelAllocation),
			"allocationID", a.ID)
		return
	}
	// Child context so stopLivenessReporting can cancel just this reporter
	// without touching the parent ctx.
	livenessCtx, cancel := context.WithCancel(ctx)
	a.liveness.lock.Lock()
	a.liveness.cancel = cancel
	a.liveness.lock.Unlock()
	log.Infow("starting_push_liveness_reporting",
		"labels", string(observability.LabelAllocation),
		"allocationID", a.ID,
		"interval", a.liveness.interval,
		"note", "passive collection only, pull checks remain authoritative")
	go func() {
		// Send initial heartbeat immediately
		if err := a.sendLivenessReport(livenessCtx); err != nil {
			log.Debugw("initial_liveness_report_failed",
				"labels", string(observability.LabelAllocation),
				"allocationID", a.ID,
				"error", err)
		}
		ticker := time.NewTicker(a.liveness.interval)
		defer ticker.Stop()
		for {
			select {
			case <-livenessCtx.Done():
				log.Debugw("stopping_liveness_reporting",
					"labels", string(observability.LabelAllocation),
					"allocationID", a.ID)
				return
			case <-ticker.C:
				if err := a.sendLivenessReport(livenessCtx); err != nil {
					log.Debugw("liveness_report_failed",
						"labels", string(observability.LabelAllocation),
						"allocationID", a.ID,
						"error", err)
					// Continue trying - don't stop on failure
				}
			}
		}
	}()
}
// sendLivenessReport builds and pushes a single liveness notification to
// the orchestrator, retrying the send with backoff on failure.
func (a *Allocation) sendLivenessReport(ctx context.Context) error {
	status := a.Status()

	// Bump the monotonically increasing sequence number under the lock.
	a.liveness.lock.Lock()
	a.liveness.sequenceNumber++
	seq := a.liveness.sequenceNumber
	a.liveness.lock.Unlock()

	health := a.performSelfHealthCheck(ctx)

	// Resource usage is best effort; a failed probe simply omits it.
	var usage *jobtypes.AllocationResourceUsage
	if collected, err := a.gatherResourceUsage(ctx); err == nil {
		usage = collected
	}

	report := jobtypes.AllocationLivenessNotification{
		AllocationID:   a.ID,
		Status:         string(status.Status),
		Timestamp:      time.Now().Unix(),
		SequenceNumber: seq,
		Health:         health,
		ResourceUsage:  usage,
		Version:        "0.1",
	}
	return a.sendToOrchestratorWithRetry(
		ctx,
		behaviors.NotifyAllocationLivenessBehavior,
		report,
		livenessMaxRetries,
	)
}
// performSelfHealthCheck runs registered healthcheck (if any) and converts
// the outcome into a jobtypes.HealthStatus.
//
// With no healthcheck configured the allocation reports healthy. The check
// runs in its own goroutine bounded by a 5-second timeout; on timeout the
// goroutine is left to finish on its own — the buffered channel lets it
// complete its send without blocking (no goroutine leak).
func (a *Allocation) performSelfHealthCheck(ctx context.Context) jobtypes.HealthStatus {
	// Snapshot the callback under the lock; SetHealthCheck may swap it
	// concurrently.
	a.lock.Lock()
	healthcheck := a.healthcheck
	a.lock.Unlock()
	if healthcheck == nil {
		return jobtypes.HealthStatus{
			Healthy:       true,
			LastCheckTime: time.Now().Unix(),
			CheckType:     jobtypes.HealthCheckTypeNone,
			Message:       "no healthcheck configured",
		}
	}
	// Run with timeout
	checkCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
	defer cancel()
	errChan := make(chan error, 1)
	go func() {
		errChan <- healthcheck()
		close(errChan)
	}()
	select {
	case err := <-errChan:
		if err != nil {
			return jobtypes.HealthStatus{
				Healthy:       false,
				LastCheckTime: time.Now().Unix(),
				CheckType:     jobtypes.HealthCheckTypeSelf,
				Message:       fmt.Sprintf("healthcheck failed: %v", err),
			}
		}
		return jobtypes.HealthStatus{
			Healthy:       true,
			LastCheckTime: time.Now().Unix(),
			CheckType:     jobtypes.HealthCheckTypeSelf,
			Message:       "healthcheck passed",
		}
	case <-checkCtx.Done():
		return jobtypes.HealthStatus{
			Healthy:       false,
			LastCheckTime: time.Now().Unix(),
			CheckType:     jobtypes.HealthCheckTypeSelf,
			Message:       "healthcheck timeout",
		}
	}
}
// gatherResourceUsage collects resource metrics from the executor for a
// running allocation. A non-running allocation yields an all-zero usage
// value so it contributes nothing to aggregate accounting.
func (a *Allocation) gatherResourceUsage(ctx context.Context) (*jobtypes.AllocationResourceUsage, error) {
	// zero usage if allocation not running
	if a.Status().Status != jobtypes.AllocationRunning {
		return &jobtypes.AllocationResourceUsage{}, nil
	}
	stats, err := a.executor.Stats(ctx, a.executionID)
	switch {
	case err != nil:
		return nil, fmt.Errorf("failed to retrieve allocation stats: %w", err)
	case stats == nil:
		return nil, fmt.Errorf("allocation stats are nil")
	}
	usage := jobtypes.AllocationResourceUsage{
		CPUUsagePercent:  stats.CPUUsage.Percent,
		MemoryUsedBytes:  stats.Memory.Usage,
		MemoryLimitBytes: a.Job.Resources.RAM.Size,
		NetworkRxBytes:   stats.Network.RxBytes,
		NetworkTxBytes:   stats.Network.TxBytes,
	}
	return &usage, nil
}
// sendToOrchestratorWithRetry sends a message with the given behavior and
// payload to the orchestrator, retrying up to maxRetries additional times
// with exponential backoff.
//
// A "send" here only means handing the message to the actor system; the
// message carries an expiry of livenessReportTimeout from creation time,
// recomputed on each attempt. Returns nil on the first successful send, or
// the last error wrapped with the total attempt count.
func (a *Allocation) sendToOrchestratorWithRetry(
	ctx context.Context,
	behavior string,
	payload interface{},
	maxRetries int,
) error {
	var lastErr error
	for attempt := 0; attempt <= maxRetries; attempt++ {
		if attempt > 0 {
			// Exponential backoff: 2^(attempt-1) seconds (1s, 2s, 4s, 8s...)
			backoff := time.Duration(math.Pow(2, float64(attempt-1))) * time.Second
			// Context cancellation aborts the wait immediately.
			select {
			case <-ctx.Done():
				return ctx.Err()
			case <-time.After(backoff):
			}
		}
		// Rebuild the message each attempt so the expiry stays fresh.
		msg, err := actor.Message(
			a.Actor.Handle(),
			a.orchestrator,
			behavior,
			payload,
			actor.WithMessageExpiry(uint64(time.Now().Add(livenessReportTimeout).UnixNano())),
		)
		if err != nil {
			lastErr = fmt.Errorf("create message: %w", err)
			continue
		}
		if err := a.Actor.Send(msg); err != nil {
			lastErr = fmt.Errorf("send attempt %d: %w", attempt+1, err)
			continue
		}
		return nil // Success
	}
	return fmt.Errorf("failed after %d attempts: %w", maxRetries+1, lastErr)
}
// statusChangeNotify pushes an immediate notification of a status
// transition to the orchestrator. It is a no-op when push liveness is
// disabled; send failures are logged and dropped.
func (a *Allocation) statusChangeNotify(oldStatus, newStatus AllocationStatus, reason string) {
	if !a.liveness.enabled {
		return
	}
	payload := jobtypes.AllocationStatusUpdate{
		AllocationID: a.ID,
		OldStatus:    string(oldStatus),
		NewStatus:    string(newStatus),
		Timestamp:    time.Now().Unix(),
		Reason:       reason,
	}
	// The update is only meaningful for a bounded window, so stamp an
	// expiry on the message.
	expiry := actor.WithMessageExpiry(uint64(time.Now().Add(livenessReportTimeout).UnixNano()))
	msg, err := actor.Message(
		a.Actor.Handle(),
		a.orchestrator,
		behaviors.NotifyAllocationStatusBehavior,
		payload,
		expiry,
	)
	if err != nil {
		log.Debugf("failed to create status update message: %v", err)
		return
	}
	if err := a.Actor.Send(msg); err != nil {
		log.Debugf("failed to send status update: %v", err)
	}
}
// stopLivenessReporting cancels the liveness reporting goroutine, if one
// is running. Safe to call multiple times.
func (a *Allocation) stopLivenessReporting() {
	a.liveness.lock.Lock()
	defer a.liveness.lock.Unlock()
	cancel := a.liveness.cancel
	if cancel == nil {
		return
	}
	a.liveness.cancel = nil
	cancel()
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package jobs
import (
"context"
"encoding/json"
"fmt"
"strings"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/dms/behaviors"
jobtypes "gitlab.com/nunet/device-management-service/dms/jobs/types"
"gitlab.com/nunet/device-management-service/observability"
"gitlab.com/nunet/device-management-service/types"
)
// handleAllocationStart decodes an AllocationStartRequest, records the
// subnet parameters, runs the allocation, and replies with an
// AllocationStartResponse. A malformed request is logged and dropped
// without a reply.
func (a *Allocation) handleAllocationStart(msg actor.Envelope) {
	log.Infow("behavior_allocation_start_invoked",
		"labels", string(observability.LabelAllocation),
		"from", msg.From)
	defer msg.Discard()

	var req behaviors.AllocationStartRequest
	if err := json.Unmarshal(msg.Message, &req); err != nil {
		log.Errorw("allocation_start_request_unmarshal_error",
			"labels", string(observability.LabelAllocation),
			"error", err)
		return
	}

	// Store state regardless of whether we're running or in standby.
	a.state.subnetIP = req.SubnetIP
	a.state.gatewayIP = req.GatewayIP
	a.state.portMapping = req.PortMapping

	var resp behaviors.AllocationStartResponse
	// TODO: context should cancel when the actor is stopped to stop monitor
	runErr := a.Run(context.TODO(), req.SubnetIP, req.GatewayIP, req.PortMapping)
	if runErr != nil {
		runErr = fmt.Errorf("failed to run allocation: %w", runErr)
		log.Errorw("allocation_start_run_failure",
			"labels", string(observability.LabelAllocation),
			"error", runErr)
		resp.OK = false
		resp.Error = runErr.Error()
		a.sendReply(msg, resp)
		return
	}
	log.Infow("allocation_start_success",
		"labels", string(observability.LabelAllocation),
		"allocationID", a.ID)
	resp.OK = true
	a.sendReply(msg, resp)
}
// AllocationRestartResponse reports the outcome of a restart request:
// OK is true on success, otherwise Error carries the failure message.
//
// NOTE(review): handleAllocationRestart replies with
// behaviors.AllocationRestartResponse, not this type — confirm whether
// this local declaration is still referenced anywhere.
type AllocationRestartResponse struct {
	OK    bool
	Error string
}
// handleAllocationRestart restarts the allocation and replies with a
// behaviors.AllocationRestartResponse describing the outcome.
func (a *Allocation) handleAllocationRestart(msg actor.Envelope) {
	defer msg.Discard()

	var resp behaviors.AllocationRestartResponse
	restartErr := a.Restart(context.TODO()) // TODO: fix context.TODO()
	if restartErr != nil {
		restartErr = fmt.Errorf("failed to restart allocation: %w", restartErr)
		log.Errorw("allocation_restart_failure",
			"labels", string(observability.LabelAllocation),
			"error", restartErr)
		resp.OK = false
		resp.Error = restartErr.Error()
		a.sendReply(msg, resp)
		return
	}
	log.Infow("allocation_restart_success",
		"labels", string(observability.LabelAllocation),
		"allocationID", a.ID)
	resp.OK = true
	a.sendReply(msg, resp)
}
// handleAllocationStats serves a stats request by querying the executor
// and replying with a behaviors.AllocationStatsResponse. A non-running
// allocation replies OK with nil stats so it adds nothing to usage.
func (a *Allocation) handleAllocationStats(msg actor.Envelope) {
	defer msg.Discard()

	var resp behaviors.AllocationStatsResponse
	// The request body is optional, but when present it must decode.
	if len(msg.Message) > 0 {
		var req behaviors.AllocationStatsRequest
		if err := json.Unmarshal(msg.Message, &req); err != nil {
			resp.Error = err.Error()
			a.sendReply(msg, resp)
			return
		}
	}
	// Do not report usage when allocation is not running; return nil stats so deployment info does not add to usage.
	if a.Status().Status != jobtypes.AllocationRunning {
		resp.OK = true
		resp.Stats = nil
		a.sendReply(msg, resp)
		return
	}
	if a.executor == nil {
		nilErr := fmt.Errorf("allocation executor not initialized")
		log.Errorw("allocation_stats_executor_nil",
			"labels", string(observability.LabelAllocation),
			"allocationID", a.ID,
			"error", nilErr,
		)
		resp.Error = nilErr.Error()
		a.sendReply(msg, resp)
		return
	}
	stats, statsErr := a.executor.Stats(context.TODO(), a.executionID) // TODO: fix context.TODO()
	if statsErr != nil {
		statsErr = fmt.Errorf("failed to retrieve allocation stats: %w", statsErr)
		log.Errorw("allocation_stats_failure",
			"labels", string(observability.LabelAllocation),
			"allocationID", a.ID,
			"executionID", a.executionID,
			"error", statsErr,
		)
		resp.Error = statsErr.Error()
		a.sendReply(msg, resp)
		return
	}
	resp.Stats = stats
	resp.OK = true
	a.sendReply(msg, resp)
}
// handleRegisterHealthcheck installs a healthcheck built from the request
// manifest; once registered, checks execute the manifest's command inside
// the allocation via execHealthCheck.
func (a *Allocation) handleRegisterHealthcheck(msg actor.Envelope) {
	defer msg.Discard()

	var (
		req  behaviors.RegisterHealthcheckRequest
		resp behaviors.RegisterHealthcheckResponse
	)
	if err := json.Unmarshal(msg.Message, &req); err != nil {
		resp.Error = err.Error()
		a.sendReply(msg, resp)
		return
	}
	// Bind the manifest runner to this allocation's executor.
	check, err := types.NewHealthCheck(req.HealthCheck, func(mf types.HealthCheckManifest) error {
		return a.execHealthCheck(mf)
	})
	if err != nil {
		resp.Error = err.Error()
		a.sendReply(msg, resp)
		return
	}
	a.SetHealthCheck(check)
	resp.OK = true
	a.sendReply(msg, resp)
}
// HealthCheckResponse is the reply payload for the healthcheck behavior:
// OK is true when the allocation is healthy (or no check is registered),
// otherwise Error carries the failure message.
type HealthCheckResponse struct {
	OK    bool
	Error string
}
// handleHealthcheck runs the registered healthcheck callback (if any) and
// replies with a HealthCheckResponse describing the result. When no
// healthcheck is registered the allocation is reported healthy.
func (a *Allocation) handleHealthcheck(msg actor.Envelope) {
	defer msg.Discard()

	// Snapshot the callback under the lock; SetHealthCheck may swap it
	// concurrently.
	a.lock.Lock()
	healthcheck := a.healthcheck
	a.lock.Unlock()

	var resp HealthCheckResponse
	if healthcheck == nil {
		// No check registered: treat as healthy.
		resp.OK = true
	} else if err := healthcheck(); err != nil {
		resp.Error = err.Error()
	} else {
		resp.OK = true
	}

	// Reply through the shared helper so broadcast envelopes get a proper
	// source handle, consistent with every other handler in this file
	// (previously this handler duplicated the reply logic inline and
	// omitted the broadcast source option).
	a.sendReply(msg, resp)
}
// execHealthCheck runs the manifest's command inside the allocation's
// execution and verifies both the exit code and that the combined output
// contains the expected response value.
func (a *Allocation) execHealthCheck(mf types.HealthCheckManifest) error {
	exitCode, stdout, stderr, err := a.executor.Exec(context.TODO(), a.executionID, mf.Exec)
	log.Debugw("health_check_command_output",
		"labels", string(observability.LabelAllocation),
		"command", mf.Exec,
		"stdout", stdout,
		"stderr", stderr)
	switch {
	case err != nil:
		log.Warnw("health_check_command_exec_failure",
			"labels", string(observability.LabelAllocation),
			"error", err)
		return fmt.Errorf("health check command failed: %w", err)
	case exitCode != 0:
		log.Warnw("health_check_command_exitcode_failure",
			"labels", string(observability.LabelAllocation),
			"exitCode", exitCode)
		return fmt.Errorf("health check command failed with exit code %d", exitCode)
	case !strings.Contains(stdout+stderr, mf.Response.Value):
		log.Warnw("health_check_command_unexpected_output",
			"labels", string(observability.LabelAllocation),
			"stderr", stderr,
			"expectedValue", mf.Response.Value)
		return fmt.Errorf("unexpected health check command output: %s\nstderr: %s", stdout, stderr)
	}
	log.Debugw("health_check_command_succeeded",
		"labels", string(observability.LabelAllocation))
	return nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package jobs
import (
"encoding/json"
"fmt"
"go.uber.org/multierr"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/dms/behaviors"
)
// handleSubnetAddPeer adds a peer to a subnet's routing table and replies
// with a SubnetAddPeerResponse.
func (a *Allocation) handleSubnetAddPeer(msg actor.Envelope) {
	defer msg.Discard()

	var (
		req  behaviors.SubnetAddPeerRequest
		resp behaviors.SubnetAddPeerResponse
	)
	if err := json.Unmarshal(msg.Message, &req); err != nil {
		resp.Error = err.Error()
		a.sendReply(msg, resp)
		return
	}
	if err := a.network.AddSubnetPeer(req.SubnetID, req.PeerID, req.IP); err != nil {
		resp.Error = err.Error()
		a.sendReply(msg, resp)
		return
	}
	log.Debugw("subnet_peer_added",
		"labels", []string{},
		"peerID", req.PeerID,
		"subnetID", req.SubnetID)
	resp.OK = true
	a.sendReply(msg, resp)
}
// handleSubnetAcceptPeers accepts a partial routing table of peers into a
// subnet and replies with a SubnetAcceptPeersResponse.
func (a *Allocation) handleSubnetAcceptPeers(msg actor.Envelope) {
	defer msg.Discard()

	var (
		req  behaviors.SubnetAcceptPeersRequest
		resp behaviors.SubnetAcceptPeersResponse
	)
	if err := json.Unmarshal(msg.Message, &req); err != nil {
		resp.Error = err.Error()
		a.sendReply(msg, resp)
		return
	}
	if err := a.network.AcceptSubnetPeers(req.SubnetID, req.PartialRoutingTable); err != nil {
		resp.Error = err.Error()
		a.sendReply(msg, resp)
		return
	}
	log.Debugw("subnet_peer_accepted",
		"labels", []string{},
		"peers", req.PartialRoutingTable)
	resp.OK = true
	a.sendReply(msg, resp)
}
// handleSubnetMapPort installs a port mapping on a subnet and replies with
// a SubnetMapPortResponse.
func (a *Allocation) handleSubnetMapPort(msg actor.Envelope) {
	defer msg.Discard()
	log.Debugw("handle_subnet_map_port_invoked", "from", msg.From)

	var (
		req  behaviors.SubnetMapPortRequest
		resp behaviors.SubnetMapPortResponse
	)
	if err := json.Unmarshal(msg.Message, &req); err != nil {
		log.Debugw("subnet_map_port_unmarshal_error",
			"labels", []string{},
			"error", err)
		resp.Error = err.Error()
		a.sendReply(msg, resp)
		return
	}
	mapErr := a.network.MapPort(req.SubnetID, req.Protocol, req.SourceIP, req.SourcePort, req.DestIP, req.DestPort)
	if mapErr != nil {
		log.Debugw("subnet_map_port_error",
			"labels", []string{},
			"error", mapErr)
		resp.Error = mapErr.Error()
		a.sendReply(msg, resp)
		return
	}
	log.Debugw("subnet_port_mapped",
		"labels", []string{},
		"sourcePort", req.SourcePort,
		"subnetID", req.SubnetID)
	resp.OK = true
	a.sendReply(msg, resp)
}
// handleSubnetDNSAddRecords adds DNS records to a subnet and replies with
// a SubnetDNSAddRecordsResponse.
func (a *Allocation) handleSubnetDNSAddRecords(msg actor.Envelope) {
	defer msg.Discard()

	var (
		req  behaviors.SubnetDNSAddRecordsRequest
		resp behaviors.SubnetDNSAddRecordsResponse
	)
	if err := json.Unmarshal(msg.Message, &req); err != nil {
		resp.Error = err.Error()
		a.sendReply(msg, resp)
		return
	}
	if err := a.network.AddSubnetDNSRecords(req.SubnetID, req.Records); err != nil {
		resp.Error = err.Error()
		a.sendReply(msg, resp)
		return
	}
	log.Debugw("subnet_dns_records_added",
		"labels", []string{},
		"records", req.Records,
		"subnetID", req.SubnetID)
	resp.OK = true
	a.sendReply(msg, resp)
}
// handleSubnetUnmapPort removes a port mapping from a subnet and replies
// with a SubnetUnmapPortResponse.
func (a *Allocation) handleSubnetUnmapPort(msg actor.Envelope) {
	defer msg.Discard()

	var (
		req  behaviors.SubnetUnmapPortRequest
		resp behaviors.SubnetUnmapPortResponse
	)
	if err := json.Unmarshal(msg.Message, &req); err != nil {
		resp.Error = err.Error()
		a.sendReply(msg, resp)
		return
	}
	unmapErr := a.network.UnmapPort(
		req.SubnetID, req.Protocol, req.SourceIP, req.SourcePort, req.DestIP, req.DestPort,
	)
	if unmapErr != nil {
		resp.Error = unmapErr.Error()
		a.sendReply(msg, resp)
		return
	}
	log.Debugw("subnet_port_unmapped",
		"labels", []string{},
		"sourcePort", req.SourcePort,
		"subnetID", req.SubnetID)
	resp.OK = true
	a.sendReply(msg, resp)
}
// handleSubnetDNSRemoveRecords removes the requested DNS records from a
// subnet, accumulating per-domain failures into a single reply error.
func (a *Allocation) handleSubnetDNSRemoveRecords(msg actor.Envelope) {
	defer msg.Discard()

	var (
		req  behaviors.SubnetDNSRemoveRecordsRequest
		resp behaviors.SubnetDNSRemoveRecordsResponse
	)
	if err := json.Unmarshal(msg.Message, &req); err != nil {
		resp.Error = err.Error()
		a.sendReply(msg, resp)
		return
	}
	// Attempt every domain; collect failures rather than stopping early.
	var combined error
	for _, domain := range req.DomainNames {
		if err := a.network.RemoveSubnetDNSRecord(req.SubnetID, domain); err != nil {
			combined = multierr.Append(combined, fmt.Errorf("error removing dns record: %w", err))
		}
	}
	if combined != nil {
		resp.Error = combined.Error()
		a.sendReply(msg, resp)
		return
	}
	log.Debugw("subnet_dns_record_removed",
		"labels", []string{},
		"domains", req.DomainNames,
		"subnetID", req.SubnetID)
	resp.OK = true
	a.sendReply(msg, resp)
}
// handleSubnetRemovePeers removes a set of peers from a subnet's routing
// table and replies with a SubnetRemovePeersResponse.
func (a *Allocation) handleSubnetRemovePeers(msg actor.Envelope) {
	defer msg.Discard()

	var (
		req  behaviors.SubnetRemovePeersRequest
		resp behaviors.SubnetRemovePeersResponse
	)
	if err := json.Unmarshal(msg.Message, &req); err != nil {
		resp.Error = err.Error()
		a.sendReply(msg, resp)
		return
	}
	if err := a.network.RemoveSubnetPeers(req.SubnetID, req.PartialRoutingTable); err != nil {
		resp.Error = err.Error()
		a.sendReply(msg, resp)
		return
	}
	log.Debugw("subnet_peer_removed",
		"labels", []string{},
		"peers", req.PartialRoutingTable)
	resp.OK = true
	a.sendReply(msg, resp)
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package ensemblev1
import (
"fmt"
"maps"
"reflect"
"strings"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/transform"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/tree"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/utils"
)
// NewEnsemblev1Decoder builds the transformer pipeline that normalizes a
// raw ensemble v1 document into the canonical internal shape. Each map in
// the slice is one pass; passes run in the order listed, so later passes
// see the output of earlier ones (e.g. volumes are converted to named
// slices before the per-volume transform runs).
func NewEnsemblev1Decoder() transform.Transformer {
	return transform.NewTransformer(
		[]map[tree.Path]transform.TransformerFunc{
			// Transform key value pairs to slices with name as key
			{
				"allocations.*.volumes": transform.MapToNamedSliceTransformer("volume"),
				"volumes":               transform.MapToNamedSliceTransformer("volume"),
				"resources":             transform.MapToNamedSliceTransformer("resource"),
			},
			// Transform configs
			{
				"allocations.*.volumes.[]":            TransformVolume,
				"allocations.*.resources":             TransformResources,
				"scripts.*":                           TransformStringToBytes,
				"allocations.*.execution.environment": TransformEnvironment,
			},
			// Transform numeric values (apply default units where omitted)
			{
				"allocations.*.resources.cpu.clock_speed": transform.ParseWithDefaultUnit("cpu clock_speed", "GHz"),
				"allocations.*.resources.ram.clock_speed": transform.ParseWithDefaultUnit("ram clock_speed", "GHz"),
				"allocations.*.resources.ram.size":        transform.ParseBytesWithDefaultUnit("ram size", "GiB"),
				"allocations.*.resources.disk.size":       transform.ParseBytesWithDefaultUnit("disk size", "GiB"),
				"allocations.*.resources.gpu.[].vram":     transform.ParseBytesWithDefaultUnit("gpu vram", "GiB"),
				"allocations.*.healthcheck.interval":      transform.ParseDuration("healthcheck duration"),
			},
			// Wrap free-form sections into SpecConfig structures
			{
				"allocations.*.execution":         transform.ToSpecConfigTransformer("execution"),
				"allocations.*.volumes.[].remote": transform.ToSpecConfigTransformer("remote volume"),
				"edge_constraints.[]":             TransformEdgeConstraint,
			},
			// Final pass: defaults, edge renaming, and the "v1" envelope
			{
				"": TransformSpec,
			},
		},
	)
}
// TransformStringToBytes converts a string configuration value to a byte
// slice; an existing byte slice is passed through unchanged. Any other
// type is rejected.
func TransformStringToBytes(_ *map[string]interface{}, data any, _ tree.Path) (any, error) {
	switch v := data.(type) {
	case []byte:
		return v, nil
	case string:
		return []byte(v), nil
	default:
		// Fixed: the original format string had a stray trailing comma
		// ("invalid data type: %T,").
		return nil, fmt.Errorf("invalid data type: %T", data)
	}
}
// TransformSpec transforms the spec configuration and wraps it in a "V1"
// key. It fills in per-allocation and per-node defaults and renames
// edge_constraints to edges.
func TransformSpec(_ *map[string]interface{}, data any, _ tree.Path) (any, error) {
	spec, ok := data.(map[string]any)
	if !ok {
		return nil, fmt.Errorf("invalid spec configuration: %v", data)
	}
	// set default values for allocations
	// (checked assertion: the original asserted map[string]any without ok
	// and would panic on a malformed document; a non-map section is now
	// skipped and left for the validator to reject)
	if allocations, ok := spec["allocations"].(map[string]any); ok {
		for allocName, alloc := range allocations {
			allocation, ok := alloc.(map[string]any)
			if !ok {
				continue
			}
			// set dns_name of allocations to the allocation name if not set
			if allocation["dns_name"] == nil {
				allocation["dns_name"] = allocName
			}
			// set failure_recovery to "stay_down" if not set
			if allocation["failure_recovery"] == nil {
				allocation["failure_recovery"] = defaultAllocationFailureStrategy
			}
		}
	}
	// set default values for nodes (same checked-assertion hardening)
	if nodes, ok := spec["nodes"].(map[string]any); ok {
		for _, node := range nodes {
			nodeConfig, ok := node.(map[string]any)
			if !ok {
				continue
			}
			// set failure_recovery to "stay_down" if not set
			if nodeConfig["failure_recovery"] == nil {
				nodeConfig["failure_recovery"] = defaultNodeFailureStrategy
			}
			// set redundancy to 0 if not set
			if nodeConfig["redundancy"] == nil {
				nodeConfig["redundancy"] = 0
			}
		}
	}
	// move edge_constraints to edges
	if edgeConstraints, ok := spec["edge_constraints"]; ok {
		spec["edges"] = edgeConstraints
		delete(spec, "edge_constraints")
	}
	return map[string]any{"v1": spec}, nil
}
// TransformEdgeConstraint maps the edges parameter to Source and Target
// (S and T) properties. Constraints without an "edges" key pass through
// unchanged.
func TransformEdgeConstraint(_ *map[string]interface{}, data any, _ tree.Path) (any, error) {
	constraint, ok := data.(map[string]any)
	if !ok {
		return nil, fmt.Errorf("invalid edge constraints: %v", data)
	}
	edges, present := constraint["edges"]
	if !present {
		return constraint, nil
	}
	// "edges" must be exactly a two-element list: [source, target].
	pair, ok := edges.([]any)
	if !ok || len(pair) != 2 {
		return nil, fmt.Errorf("invalid edges parameter: %v", edges)
	}
	constraint["S"], constraint["T"] = pair[0], pair[1]
	delete(constraint, "edges")
	return constraint, nil
}
// TransformVolume transforms the volume configuration and handles inheritance.
// The volume configuration can be a string in the format "name:mountpoint" or a map.
// If the volume is defined in the parent volumes, the configurations are merged.
func TransformVolume(root *map[string]interface{}, data any, path tree.Path) (any, error) {
	var config map[string]any
	// If the data is a string, split it into name and mountpoint.
	switch v := data.(type) {
	case string:
		mapping := strings.Split(v, ":")
		if len(mapping) != 2 {
			return nil, fmt.Errorf("invalid volume configuration: %v", data)
		}
		config = map[string]any{
			"name":       mapping[0],
			"mountpoint": mapping[1],
		}
	case map[string]any:
		config = v
	default:
		return nil, fmt.Errorf("invalid volume configuration: %v", data)
	}
	if path.Matches("allocations.*.volumes.[]") {
		// Handle volume inheritance from the document's top-level volumes.
		parent := tree.NewPath("")
		if c, err := utils.GetConfigAtPath(*root, parent.Next("volumes")); err == nil {
			// NOTE(review): c is asserted to []any unchecked; this assumes
			// the top-level volumes were already normalized by the earlier
			// MapToNamedSliceTransformer pass — confirm the pass ordering
			// always holds before this runs.
			for _, v := range c.([]any) {
				if volume, ok := v.(map[string]any); ok && volume["name"] == config["name"] {
					// Merge the configurations: allocation-level settings
					// override the shared top-level definition.
					maps.Copy(volume, config)
					config = volume
				}
			}
		}
	}
	return config, nil
}
// TransformResources transforms the resources configuration and handles
// inheritance. A bare string is treated as a named reference to a
// top-level resource entry; a matching top-level entry is merged in, with
// allocation-level values taking precedence.
func TransformResources(root *map[string]interface{}, data any, path tree.Path) (any, error) {
	var config map[string]any
	switch value := data.(type) {
	case string:
		// A bare string is a reference to a shared resource block.
		config = map[string]any{"name": value}
	case map[string]any:
		config = value
	default:
		return nil, fmt.Errorf("invalid resources configuration: %v", data)
	}
	if !path.Matches("allocations.*.resources") {
		return config, nil
	}
	// Handle resource inheritance from the top-level resources section.
	base := tree.NewPath("")
	c, err := utils.GetConfigAtPath(*root, base.Next("resources"))
	if err != nil {
		return config, nil
	}
	for _, entry := range c.([]any) {
		if shared, ok := entry.(map[string]any); ok && shared["name"] == config["name"] {
			// Merge the configurations
			maps.Copy(shared, config)
			config = shared
		}
	}
	return config, nil
}
// TransformEnvironment normalizes the execution environment configuration
// into "KEY=VALUE" entries. Map inputs (with string or any values) are
// flattened; slice inputs are passed through unchanged.
func TransformEnvironment(_ *map[string]interface{}, data any, _ tree.Path) (any, error) {
	switch v := data.(type) {
	case map[string]any, map[string]string:
		envs := make([]string, 0)
		// reflect's Seq2 (Go 1.23+) iterates both map types uniformly,
		// avoiding a separate branch per concrete map type. Map iteration
		// order is random, so the resulting slice order is nondeterministic.
		for k, v := range reflect.ValueOf(v).Seq2() {
			envs = append(envs, fmt.Sprintf("%s=%s", k, v))
		}
		return envs, nil
	case []string, []any:
		return v, nil
	default:
		return nil, fmt.Errorf("invalid environment configuration: %T", data)
	}
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package ensemblev1
import (
"fmt"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/transform"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/tree"
)
// NewEnsemblev1Encoder builds the transformer pipeline that converts the
// canonical internal spec back into the ensemble v1 document format —
// roughly the inverse of NewEnsemblev1Decoder. Passes run in order: first
// the "v1" envelope is unwrapped, then SpecConfigs are flattened, numeric
// values are rendered with units, and finally volume slices become maps.
func NewEnsemblev1Encoder() transform.Transformer {
	return transform.NewTransformer(
		[]map[tree.Path]transform.TransformerFunc{
			{
				"": FormatSpec,
			},
			{
				"allocations.*.execution":         transform.FlattenSpecConfigTransformer("execution"),
				"allocations.*.volumes.[].remote": transform.FlattenSpecConfigTransformer("volume remote"),
			},
			{
				"allocations.*.resources.cpu.clock_speed": transform.ToSIFormatWithUnit("cpu clock_speed", "Hz"),
				"allocations.*.resources.ram.clock_speed": transform.ToSIFormatWithUnit("ram clock_speed", "Hz"),
				"allocations.*.resources.ram.size":        transform.ToBytesFormat("ram size"),
				"allocations.*.resources.disk.size":       transform.ToBytesFormat("disk size"),
				"allocations.*.resources.gpu.[].vram":     transform.ToBytesFormat("gpu vram"),
			},
			{
				"allocations.*.volumes": transform.NamedSliceToMapTransformer("volumes"),
			},
		},
	)
}
// FormatSpec unwraps the "v1" envelope produced during decoding and stamps
// an explicit version field on the resulting spec.
func FormatSpec(_ *map[string]interface{}, data any, _ tree.Path) (any, error) {
	outer, ok := data.(map[string]any)
	if !ok {
		return nil, fmt.Errorf("invalid spec configuration: %v", data)
	}
	inner, ok := outer["v1"].(map[string]any)
	if !ok {
		return nil, fmt.Errorf("invalid spec configuration: %v", data)
	}
	inner["version"] = "v1"
	return inner, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package ensemblev1
import "gitlab.com/nunet/device-management-service/dms/jobs/parser/types"
// NewEnsemblev1Parser assembles the complete ensemble v1 parser: YAML
// syntax, placeholder resolution, the decode/encode transformer pipelines,
// and the v1 validator.
func NewEnsemblev1Parser() types.BasicParser {
	return types.NewBasicParser(
		"yaml",
		resolvePlaceholders,
		NewEnsemblev1Decoder(),
		NewEnsemblev1Encoder(),
		NewEnsembleV1Validator(),
	)
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package ensemblev1
import (
"gitlab.com/nunet/device-management-service/dms/jobs/parser/resolve"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/tree"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/types"
)
// resolvePlaceholders walks the parsed configuration tree and interpolates
// env/file placeholders in every string node, mutating the tree in place.
func resolvePlaceholders(data *any, options *types.Options) error {
	handlers := map[string]resolve.Handler{
		"env":  resolve.NewEnvResolver(options.Env),
		"file": resolve.NewFileResolver(options.Fs, options.WorkingDir),
	}
	resolver := resolve.NewResolver(handlers, nil)
	return tree.Walk(data, tree.NewPath(), func(node *any, _ tree.Path) error {
		str, isString := (*node).(string)
		if !isString {
			// Only string leaves can carry placeholders.
			return nil
		}
		resolved, err := resolver.Process(str)
		if err != nil {
			return err
		}
		*node = resolved
		return nil
	})
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package ensemblev1
import (
"fmt"
"path/filepath"
"reflect"
"regexp"
"slices"
"strings"
"time"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/tree"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/utils"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/validate"
"gitlab.com/nunet/device-management-service/types"
cutils "gitlab.com/nunet/device-management-service/utils/convert"
vutils "gitlab.com/nunet/device-management-service/utils/validate"
)
const (
	// defaultAllocationFailureStrategy is applied by TransformSpec when an
	// allocation omits failure_recovery.
	defaultAllocationFailureStrategy = "stay_down"
	// defaultNodeFailureStrategy is applied by TransformSpec when a node
	// omits failure_recovery.
	defaultNodeFailureStrategy = "stay_down"
)

var (
	// validEscalationStrategies lists accepted values for the spec-level
	// escalation_strategy field.
	validEscalationStrategies = [...]string{"redeploy", "teardown"}
	// validAllocationFailureStrategies lists accepted values for an
	// allocation's failure_recovery field.
	validAllocationFailureStrategies = [...]string{"stay_down", "one_for_one", "one_for_all", "rest_for_one"}
	// validNodeFailureStrategies lists accepted values for a node's
	// failure_recovery field.
	validNodeFailureStrategies = [...]string{"stay_down", "restart", "redeploy"}
)
// NewEnsembleV1Validator creates a new validator for the NuNet configuration.
// Each entry maps a path within the decoded spec to the validator function
// responsible for that section.
func NewEnsembleV1Validator() validate.Validator {
	return validate.NewValidator(
		map[tree.Path]validate.ValidatorFunc{
			"V1":                          ValidateSpec,
			"V1.subnet":                   ValidateSubnet,
			"V1.allocations.*":            ValidateAllocation,
			"V1.edges.[]":                 ValidateEdgeConstraints,
			"V1.nodes.*":                  ValidateNode,
			"V1.supervisor":               ValidateSupervisor,
			"V1.supervisor.children.[]":   ValidateSupervisor,
			"V1.allocations.*.resources":  ValidateResources,
			"V1.allocations.*.execution":  ValidateExecution,
			"V1.allocations.*.healthcheck": ValidateHealthCheck,
			"V1.allocations.*.volume":     ValidateVolume,
			"V1.contracts.*":              ValidateContract,
		},
	)
}
// ValidateSpec checks the root configuration for consistency.
//
// It enforces invariants that span several sections of the spec:
//   - escalation_strategy (optional) must be one of validEscalationStrategies;
//   - at least one allocation exists; allocation names are valid hostnames,
//     unique case-insensitively, and dns_name values are unique;
//   - the allocation depends_on graph is acyclic;
//   - edges require a nodes section;
//   - an allocation belongs to at most one node, and its dependencies must
//     live in the same node.
//
// The root-config and path parameters are unused at this level.
func ValidateSpec(_ *map[string]any, data any, _ tree.Path) error {
	spec, dataOk := data.(map[string]any)
	if !dataOk {
		return fmt.Errorf("invalid spec configuration: %v", data)
	}
	// validate escalation strategy if present
	if es, ok := spec["escalation_strategy"].(string); ok {
		if !slices.Contains(validEscalationStrategies[:], es) {
			return fmt.Errorf("invalid escalation_strategy %q: must be one of %q", es, validEscalationStrategies)
		}
	}
	// Check if the allocations map is present and not empty
	allocs, ok := spec["allocations"].(map[string]any)
	if !ok || len(allocs) == 0 {
		return fmt.Errorf("at least one allocation must be defined")
	}
	// lowercased name -> original spelling; catches case-insensitive duplicates
	allocationNames := make(map[string]string)
	// dns_name -> allocation that claimed it; catches duplicate DNS names
	dnsNames := make(map[string]string)
	for allocName, allocConfigRaw := range allocs {
		// All allocation names must be fully qualified domain names
		if !vutils.IsDNSNameValid(allocName) {
			return fmt.Errorf("invalid allocation name, must be a valid hostname: %s", allocName)
		}
		// Check for duplicate allocation names (case-insensitive)
		lowerName := strings.ToLower(allocName)
		if originalName, exists := allocationNames[lowerName]; exists {
			return fmt.Errorf("duplicate allocation names found: '%s' and '%s'", originalName, allocName)
		}
		allocationNames[lowerName] = allocName
		// Check for duplicate dns_name values
		allocConfig, ok := allocConfigRaw.(map[string]any)
		if !ok {
			continue
		}
		dnsName, ok := allocConfig["dns_name"].(string)
		if !ok || dnsName == "" {
			continue // skip if dns_name is not set or not a string
		}
		if existingAlloc, exists := dnsNames[dnsName]; exists {
			return fmt.Errorf("duplicate dns_name found: '%s' used in allocations '%s' and '%s'", dnsName, existingAlloc, allocName)
		}
		dnsNames[dnsName] = allocName
	}
	// check for cyclic dependencies in the depends_on graph
	graph := utils.CreateAdjencyList(allocs, tree.NewPath("depends_on"))
	if utils.DetectCycles(graph) {
		return fmt.Errorf("cyclic dependencies detected in allocations")
	}
	// Check if nodes are defined when edge_constraints are present
	if edges, ok := spec["edges"].([]any); ok && len(edges) > 0 {
		if spec["nodes"] == nil {
			return fmt.Errorf("nodes must be defined when edge_constraints are present")
		}
	}
	// Check that no allocation is present in multiple nodes and that dependencies are in the same node
	if nodes, ok := spec["nodes"].(map[string]any); ok && len(nodes) > 0 {
		// allocation name -> node that owns it
		allocToNode := make(map[string]string)
		// Build the allocation-to-node map; malformed node entries are skipped
		// here because ValidateNode reports them separately.
		for nodeName, nodeConfig := range nodes {
			nodeMap, ok := nodeConfig.(map[string]any)
			if !ok {
				continue
			}
			nodeAllocs, ok := nodeMap["allocations"].([]any)
			if !ok {
				continue
			}
			for _, alloc := range nodeAllocs {
				allocName, ok := alloc.(string)
				if !ok {
					continue
				}
				// Check if this allocation is already assigned to another node
				if existingNode, exists := allocToNode[allocName]; exists {
					return fmt.Errorf("allocation '%s' is assigned to multiple nodes ('%s' and '%s'): an allocation can only be assigned to one node", allocName, existingNode, nodeName)
				}
				// Record this allocation's node
				allocToNode[allocName] = nodeName
				// Check dependencies immediately
				allocConfig, ok := allocs[allocName].(map[string]any)
				if !ok {
					continue // Should never happen
				}
				dependencies, ok := allocConfig["depends_on"].([]any)
				if !ok {
					continue
				}
				for _, dep := range dependencies {
					depName, ok := dep.(string)
					if !ok {
						continue
					}
					// Check if the dependency is in the same node; dep is
					// compared as `any`, which is fine for string entries.
					if !slices.Contains(nodeAllocs, dep) {
						return fmt.Errorf("allocation '%s' depends on '%s', but '%s' is not in the same node: dependent allocations must be in the same node", allocName, depName, depName)
					}
				}
			}
		}
	}
	return nil
}
// ValidateAllocation checks a single allocation configuration.
//
// It verifies that the allocation has execution and resources sections, that
// the executor matches the execution type, that type / dns_name /
// failure_recovery are well-formed, that depends_on entries reference existing
// allocations (and never the allocation itself), and that keys and provision
// scripts are consistent with the rest of the spec. path identifies the
// allocation in the config tree; its last element is the allocation's name.
func ValidateAllocation(root *map[string]any, data any, path tree.Path) error {
	allocation, ok := data.(map[string]any)
	if !ok {
		return fmt.Errorf("invalid allocation configuration: %v", data)
	}
	// Check if the allocation has an execution
	if allocation["execution"] == nil {
		return fmt.Errorf("allocation must have an execution")
	}
	// Check if the allocation has resources
	if allocation["resources"] == nil {
		return fmt.Errorf("allocation must have resources")
	}
	// Validate executor type matches execution type
	executor, ok := allocation["executor"].(string)
	if !ok || executor == "" {
		return fmt.Errorf("allocation must have an executor")
	}
	execution, ok := allocation["execution"].(map[string]any)
	if !ok {
		return fmt.Errorf("invalid execution configuration")
	}
	execType, ok := execution["type"].(string)
	if !ok || execType == "" {
		return fmt.Errorf("execution must have a type")
	}
	if executor != execType {
		return fmt.Errorf("allocation executor type (%s) must match execution type (%s)", executor, execType)
	}
	allocType, ok := allocation["type"].(string)
	if !ok || allocType == "" {
		return fmt.Errorf("allocation must have a type (service or task)")
	}
	// Validate DNS name if present
	if dnsName, ok := allocation["dns_name"].(string); ok {
		if dnsName == "" {
			return fmt.Errorf("dns_name cannot be empty if specified")
		}
		// Add basic DNS name format validation
		if !vutils.IsDNSNameValid(dnsName) {
			return fmt.Errorf("invalid dns_name format: %s", dnsName)
		}
	}
	// failure_recovery is mandatory and must be a known strategy
	if failureRecovery, ok := allocation["failure_recovery"].(string); ok {
		if !slices.Contains(validAllocationFailureStrategies[:], failureRecovery) {
			return fmt.Errorf("invalid failure_recovery %q: must be one of %q", failureRecovery, validAllocationFailureStrategies)
		}
	} else {
		return fmt.Errorf("failure_recovery must be specified for allocation")
	}
	// validate depends_on if present
	if dependsOn, ok := allocation["depends_on"]; ok {
		dependsOn, ok := dependsOn.([]any)
		if !ok {
			return fmt.Errorf("depends_on must be a list of allocation names")
		}
		// allocs stays nil when V1.allocations is missing; lookups on a nil
		// map simply report every dependency as not found.
		var allocs map[string]any
		if cfg, err := utils.GetConfigAtPath(*root, "V1.allocations"); err == nil {
			allocs, _ = cfg.(map[string]any)
		}
		for _, v := range dependsOn {
			dep, ok := v.(string)
			if !ok {
				return fmt.Errorf("depends_on must be a list of allocation names")
			}
			// path.Last() is this allocation's own name
			if dep != "" && path.Last() == dep {
				return fmt.Errorf("depends_on must not refer to itself")
			}
			// Bug fix: report the offending dependency name, not the whole
			// depends_on list.
			if alloc, exists := allocs[dep]; !exists || alloc == nil {
				return fmt.Errorf("depends_on allocation '%s' not found", dep)
			}
		}
	}
	// Validate keys if specified
	if keys, ok := allocation["keys"].([]any); ok && len(keys) > 0 {
		for i, keyObj := range keys {
			keyMap, ok := keyObj.(map[string]any)
			if !ok {
				return fmt.Errorf("invalid allocation key spec at index %d: must be a map", i)
			}
			keyType, ok := keyMap["type"].(string)
			if !ok || keyType == "" {
				return fmt.Errorf("allocation key spec at index %d must have a type", i)
			}
			if !types.KeySSH.Equal(keyType) && !types.KeyGPG.Equal(keyType) {
				return fmt.Errorf("key at index %d has invalid type: %s (must be 'ssh' or 'gpg')", i, keyType)
			}
			keyFile, ok := keyMap["file"].(string)
			if !ok || keyFile == "" {
				return fmt.Errorf("allocation key at index %d is empty", i)
			}
			// destination not required for ssh keys. However 'user' in execution will be
			// used if defined. When a user isn't defined, we default to root.
			keyDest, ok := keyMap["dest"].(string)
			if (!ok || keyDest == "") && !types.KeySSH.Equal(keyType) {
				return fmt.Errorf("allocation key at index %d is missing a destination", i)
			}
		}
	}
	// Validate provision scripts if specified: every referenced script must
	// exist under V1.scripts.
	if provision, ok := allocation["provision"].([]any); ok && len(provision) > 0 {
		rootScripts, err := utils.GetConfigAtPath(*root, tree.NewPath("V1.scripts"))
		if err != nil || rootScripts == nil {
			return fmt.Errorf("scripts must be defined when provision is defined")
		}
		rootScriptsMap, ok := rootScripts.(map[string]any)
		if !ok {
			return fmt.Errorf("invalid scripts configuration")
		}
		for _, script := range provision {
			scriptStr, ok := script.(string)
			if !ok {
				return fmt.Errorf("invalid script reference: %v", script)
			}
			if _, exists := rootScriptsMap[scriptStr]; !exists {
				return fmt.Errorf("referenced script '%s' not found", scriptStr)
			}
		}
	}
	return nil
}
// ValidateResources validates the resource configuration.
//
// cpu, ram and disk sections are required; gpus are optional. Numeric fields
// are also normalized IN PLACE: cores and sizes are rewritten as float64, and
// ram clock_speed as uint64, after being checked positive.
func ValidateResources(_ *map[string]any, data any, _ tree.Path) error {
	resources, ok := data.(map[string]any)
	if !ok {
		return fmt.Errorf("invalid resources configuration: %v", data)
	}
	// Validate CPU (required)
	cpu, ok := resources["cpu"].(map[string]any)
	if !ok {
		return fmt.Errorf("resources must have cpu configuration")
	}
	// Handle cores as any number type and convert to a positive float64
	cores, ok := cpu["cores"]
	if !ok {
		return fmt.Errorf("cpu must have cores value")
	}
	coresFloat, err := cutils.ToPositiveFloat64(cores, "cpu cores")
	if err != nil {
		return err
	}
	// normalize in place so later consumers see a float64
	cpu["cores"] = coresFloat
	// Optional CPU fields
	if arch, ok := cpu["architecture"].(string); ok && arch == "" {
		return fmt.Errorf("cpu architecture cannot be empty if specified")
	}
	if freq, ok := cpu["clock_speed"]; ok {
		// checked but deliberately not rewritten (unlike ram clock_speed below)
		if _, err := cutils.ToPositiveFloat64(freq, "cpu clock_speed"); err != nil {
			return err
		}
	}
	// Validate RAM (required)
	ram, ok := resources["ram"].(map[string]any)
	if !ok {
		return fmt.Errorf("resources must have ram configuration")
	}
	// Validate RAM size
	size, ok := ram["size"]
	if !ok {
		return fmt.Errorf("ram must have size value")
	}
	sizeFloat, err := cutils.ToPositiveFloat64(size, "ram size")
	if err != nil {
		return err
	}
	ram["size"] = sizeFloat
	// Optional RAM speed
	if speed, ok := ram["clock_speed"]; ok {
		speedFloat, err := cutils.ToPositiveFloat64(speed, "ram clock_speed")
		if err != nil {
			return err
		}
		// stored as uint64, truncating any fractional part
		ram["clock_speed"] = uint64(speedFloat)
	}
	// Validate disk (required)
	disk, ok := resources["disk"].(map[string]any)
	if !ok {
		return fmt.Errorf("resources must have disk configuration")
	}
	// Validate disk size
	diskSize, ok := disk["size"]
	if !ok {
		return fmt.Errorf("disk must have size value")
	}
	diskSizeFloat, err := cutils.ToPositiveFloat64(diskSize, "disk size")
	if err != nil {
		return err
	}
	disk["size"] = diskSizeFloat
	// Optional disk type
	if diskType, ok := disk["type"].(string); ok && diskType == "" {
		return fmt.Errorf("disk type cannot be empty if specified")
	}
	// Optional GPUs: each entry may specify memory (positive number,
	// normalized to float64), vendor and model (non-empty when present)
	if gpusRaw, ok := resources["gpus"]; ok {
		gpus, ok := gpusRaw.([]any)
		if !ok {
			return fmt.Errorf("gpus must be an array")
		}
		for i, gpuRaw := range gpus {
			gpu, ok := gpuRaw.(map[string]any)
			if !ok {
				return fmt.Errorf("gpu at index %d is not a valid configuration", i)
			}
			// Validate GPU memory if specified
			if memory, ok := gpu["memory"]; ok {
				memoryFloat, err := cutils.ToPositiveFloat64(memory, fmt.Sprintf("gpu memory at index %d", i))
				if err != nil {
					return err
				}
				gpu["memory"] = memoryFloat
			}
			// Validate vendor if specified
			if vendor, ok := gpu["vendor"].(string); ok && vendor == "" {
				return fmt.Errorf("gpu vendor cannot be empty if specified at index %d", i)
			}
			// Validate model if specified
			if model, ok := gpu["model"].(string); ok && model == "" {
				return fmt.Errorf("gpu model cannot be empty if specified at index %d", i)
			}
		}
	}
	return nil
}
// ValidateNode checks the node configuration.
//
// A node may reference allocations (each must exist under V1.allocations),
// carry a redundancy count, a mandatory failure_recovery strategy, port
// mappings, location constraints and a target peer.
func ValidateNode(root *map[string]any, data any, _ tree.Path) error {
	node, ok := data.(map[string]any)
	if !ok {
		return fmt.Errorf("invalid node configuration: %v", data)
	}
	// Validate allocation references
	if allocations, ok := node["allocations"].([]any); ok {
		rootAllocs, err := utils.GetConfigAtPath(*root, tree.NewPath("V1.allocations"))
		if err != nil || rootAllocs == nil {
			return fmt.Errorf("allocations must be defined when node is defined")
		}
		rootAllocsMap, ok := rootAllocs.(map[string]any)
		if !ok {
			return fmt.Errorf("invalid allocations configuration")
		}
		for _, alloc := range allocations {
			allocStr, ok := alloc.(string)
			if !ok {
				return fmt.Errorf("invalid allocation reference: %v", alloc)
			}
			if _, exists := rootAllocsMap[allocStr]; !exists {
				return fmt.Errorf("referenced allocation '%s' not found", allocStr)
			}
		}
	}
	// validate redundancy if present
	if redundancy, ok := node["redundancy"]; ok {
		// check if redundancy is a non-negative integer.
		// NOTE(review): the assertion only matches a native int; a decoder
		// producing float64 (e.g. encoding/json) would fail here — confirm
		// what the spec parser emits.
		redundancyInt, ok := redundancy.(int)
		if !ok {
			return fmt.Errorf("redundancy must be a number")
		}
		if redundancyInt < 0 {
			return fmt.Errorf("redundancy must be a positive number")
		}
	}
	// validate failure recovery (mandatory)
	if failureRecovery, ok := node["failure_recovery"].(string); ok {
		if !slices.Contains(validNodeFailureStrategies[:], failureRecovery) {
			return fmt.Errorf("invalid failure_recovery %q: must be one of %q", failureRecovery, validNodeFailureStrategies)
		}
	} else {
		return fmt.Errorf("failure_recovery must be specified for node")
	}
	// Validate ports if specified
	if ports, ok := node["ports"].([]any); ok {
		for _, port := range ports {
			portMap, ok := port.(map[string]any)
			if !ok {
				return fmt.Errorf("invalid port configuration: %v", port)
			}
			// Validate private port
			if private, ok := portMap["private"].(int); !ok || private <= 0 {
				return fmt.Errorf("port must have a valid private port number")
			}
			// Require public port when private port is specified
			if _, ok := portMap["public"].(int); !ok {
				return fmt.Errorf("public port must be specified when private port is defined")
			}
			// Validate public port. Precedence is
			// (public != 0 && public < 1025) || public > 65535,
			// so 0 is accepted — presumably meaning "assign automatically";
			// confirm against the deployment code.
			if public, ok := portMap["public"].(int); ok && (public != 0 && public < 1025 || public > 65535) {
				return fmt.Errorf("port must have a valid public port number between 1025 and 65535 if specified")
			}
			// Validate allocation reference
			if alloc, ok := portMap["allocation"].(string); ok {
				rootAllocs, err := utils.GetConfigAtPath(*root, tree.NewPath("V1.allocations"))
				if err != nil || rootAllocs == nil {
					return fmt.Errorf("allocations must be defined when referenced in port")
				}
				rootAllocsMap, ok := rootAllocs.(map[string]any)
				if !ok {
					return fmt.Errorf("invalid allocations configuration")
				}
				if _, exists := rootAllocsMap[alloc]; !exists {
					return fmt.Errorf("referenced allocation '%s' not found in port configuration", alloc)
				}
			}
		}
	}
	// Validate location constraints if specified
	if location, ok := node["location"].(map[string]any); ok {
		if err := validateLocationConstraints(location); err != nil {
			return fmt.Errorf("invalid location constraints: %v", err)
		}
	}
	// Validate peer if specified
	if peer, ok := node["peer"].(string); ok && peer == "" {
		return fmt.Errorf("peer cannot be empty if specified")
	}
	return nil
}
// ValidateExecution checks the execution configuration.
// The section must carry a non-empty type and a params map; type-specific
// validation is delegated to the matching helper.
func ValidateExecution(_ *map[string]any, data any, _ tree.Path) error {
	cfg, isMap := data.(map[string]any)
	if !isMap {
		return fmt.Errorf("invalid execution configuration: %v", data)
	}
	kind, isString := cfg["type"].(string)
	if !isString || kind == "" {
		return fmt.Errorf("execution must have a type")
	}
	params, hasParams := cfg["params"].(map[string]any)
	if !hasParams {
		return fmt.Errorf("execution must have params")
	}
	// Dispatch on executor kind; helpers return nil on success.
	switch kind {
	case "docker":
		return validateDockerExecution(params)
	case "firecracker":
		return validateFirecrackerExecution(params)
	case "null":
		// The null executor needs no parameter validation.
		return nil
	default:
		return fmt.Errorf("unsupported execution type: %s", kind)
	}
}
// validateDockerExecution validates docker-specific execution configuration.
//
// image is required and must match the expected docker reference format;
// entrypoint and cmd must be lists of strings; environment entries must be
// non-empty KEY=VALUE strings with a valid key; working_directory and
// restart_policy are optional but must be sensible when present.
func validateDockerExecution(execution map[string]any) error {
	// Validate image (required)
	image, ok := execution["image"].(string)
	if !ok || image == "" {
		return fmt.Errorf("docker execution must have an image")
	}
	// Validate image format with a single regex
	if !vutils.IsDockerImageValid(image) {
		return fmt.Errorf("invalid docker image format: %s", image)
	}
	// Validate entrypoint if present
	if entrypoint, ok := execution["entrypoint"].([]any); ok {
		for i, entry := range entrypoint {
			if _, ok := entry.(string); !ok {
				return fmt.Errorf("docker entrypoint at index %d must be a string", i)
			}
		}
	}
	// Validate command if present
	if cmd, ok := execution["cmd"].([]any); ok {
		for i, c := range cmd {
			if _, ok := c.(string); !ok {
				return fmt.Errorf("docker command at index %d must be a string", i)
			}
		}
	}
	// Validate environment variables if present. Reflection is used so both
	// []any and []string input shapes are accepted.
	if envs, ok := execution["environment"]; ok {
		v := reflect.ValueOf(envs)
		if v.Kind() != reflect.Slice {
			return fmt.Errorf("docker environment must be a slice")
		}
		for i := range v.Len() {
			envStr, ok := v.Index(i).Interface().(string)
			if !ok {
				return fmt.Errorf("docker environment variable at index %d must be a string", i)
			}
			if envStr == "" {
				return fmt.Errorf("docker environment variable at index %d cannot be empty", i)
			}
			// Split once on the first '='; strings.Cut replaces the previous
			// Contains + SplitN pair with a single pass.
			key, _, found := strings.Cut(envStr, "=")
			if !found {
				return fmt.Errorf("docker environment variable at index %d must be in KEY=VALUE format", i)
			}
			if key == "" {
				return fmt.Errorf("docker environment variable key at index %d cannot be empty", i)
			}
			// Validate environment variable key format
			if !vutils.IsEnvVarKeyValid(key) {
				return fmt.Errorf("invalid environment variable key format at index %d: %s", i, key)
			}
		}
	}
	// Validate working directory if present
	if workDir, ok := execution["working_directory"].(string); ok && workDir == "" {
		return fmt.Errorf("docker working directory cannot be empty if specified")
	}
	if restartPolicy, ok := execution["restart_policy"].(string); ok {
		if restartPolicy == "" {
			return fmt.Errorf("docker restart_policy cannot be empty if specified")
		}
		// restart_policy must be one of docker's accepted values
		switch restartPolicy {
		case "no", "on-failure", "always", "unless-stopped":
			// valid policies
		default:
			return fmt.Errorf("invalid docker restart_policy: %s", restartPolicy)
		}
	}
	return nil
}
// validateFirecrackerExecution validates firecracker-specific execution
// configuration: kernel_image and root_file_system are mandatory non-empty
// strings, while kernel_args and initrd are optional but must be non-empty
// when present.
func validateFirecrackerExecution(execution map[string]any) error {
	// Required fields, checked in a fixed order so error messages are stable.
	for _, field := range []string{"kernel_image", "root_file_system"} {
		value, isString := execution[field].(string)
		if !isString || value == "" {
			return fmt.Errorf("firecracker execution must have a %s", field)
		}
	}
	// Optional fields: absence is fine, an empty string is not.
	for _, field := range []string{"kernel_args", "initrd"} {
		if value, isString := execution[field].(string); isString && value == "" {
			return fmt.Errorf("firecracker %s cannot be empty if specified", field)
		}
	}
	return nil
}
// ValidateSupervisor checks the supervisor configuration: an optional restart
// strategy, optional allocation references (each must exist under
// V1.allocations), and optionally a list of child supervisors.
func ValidateSupervisor(root *map[string]any, data any, _ tree.Path) error {
	supervisor, isMap := data.(map[string]any)
	if !isMap {
		return fmt.Errorf("invalid supervisor configuration: %v", data)
	}
	// Optional restart strategy: must be non-empty and a known name.
	if strategy, present := supervisor["strategy"].(string); present {
		if strategy == "" {
			return fmt.Errorf("supervisor strategy cannot be empty if specified")
		}
		switch strategy {
		case "OneForOne", "AllForOne", "RestForOne":
			// recognized strategies
		default:
			return fmt.Errorf("invalid supervisor strategy: %s", strategy)
		}
	}
	// Optional allocation references: every entry must name a declared allocation.
	if refs, present := supervisor["allocations"].([]any); present {
		declared, err := utils.GetConfigAtPath(*root, tree.NewPath("V1.allocations"))
		if err != nil || declared == nil {
			return fmt.Errorf("allocations must be defined when supervisor is defined")
		}
		declaredMap, isAllocMap := declared.(map[string]any)
		if !isAllocMap {
			return fmt.Errorf("invalid allocations configuration")
		}
		for _, ref := range refs {
			name, isString := ref.(string)
			if !isString {
				return fmt.Errorf("invalid allocation reference: %v", ref)
			}
			if _, known := declaredMap[name]; !known {
				return fmt.Errorf("referenced allocation '%s' not found", name)
			}
		}
	}
	// Children only need a shape check here: each entry is validated in full
	// by the "V1.supervisor.children.[]" path rule.
	if children, present := supervisor["children"].([]any); present {
		for i, child := range children {
			if _, isChildMap := child.(map[string]any); !isChildMap {
				return fmt.Errorf("invalid child supervisor at index %d: must be a map", i)
			}
		}
	}
	return nil
}
// validateLocationConstraints validates the accept/reject lists of a node's
// location constraint block; every entry must be a valid location spec.
func validateLocationConstraints(location map[string]any) error {
	// Both lists share the same per-entry validation; only the error prefix differs.
	for _, listName := range []string{"accept", "reject"} {
		entries, present := location[listName].([]any)
		if !present {
			continue
		}
		for i, entry := range entries {
			if err := validateLocation(entry, i); err != nil {
				return fmt.Errorf("invalid %s location at index %d: %v", listName, i, err)
			}
		}
	}
	return nil
}
// validateLocation validates a single location configuration.
//
// continent is required; country, city, asn and isp are optional, but a city
// requires a country and no optional string may be empty when present. The
// index parameter is unused (the caller embeds it in its own error message).
func validateLocation(data any, _ int) error {
	location, ok := data.(map[string]any)
	if !ok {
		return fmt.Errorf("location must be a map")
	}
	// continent required if specifying location
	continent, ok := location["continent"].(string)
	if !ok || continent == "" {
		return fmt.Errorf("continent is required when specifying a location")
	}
	// Country is optional
	if country, ok := location["country"].(string); ok && country == "" {
		return fmt.Errorf("country cannot be empty if specified")
	}
	// City is optional but requires country if specified
	if city, ok := location["city"].(string); ok {
		if city == "" {
			return fmt.Errorf("city cannot be empty if specified")
		}
		country, hasCountry := location["country"].(string)
		if !hasCountry || country == "" {
			return fmt.Errorf("country must be specified when city is provided")
		}
	}
	// ASN is optional.
	// NOTE(review): generic decoders emit int/float64, not uint, so this
	// assertion may never match decoded configs — confirm the parser output.
	if asn, ok := location["asn"].(uint); ok {
		// uint can never be negative; `asn <= 0` (flagged by staticcheck) is
		// exactly `asn == 0`.
		if asn == 0 {
			return fmt.Errorf("ASN must be positive if specified")
		}
	}
	// ISP is optional
	if isp, ok := location["isp"].(string); ok && isp == "" {
		return fmt.Errorf("ISP cannot be empty if specified")
	}
	return nil
}
// ValidateEdgeConstraints checks the edge constraints configuration.
//
// An edge names two distinct nodes S and T, both of which must exist under
// V1.nodes. RTT and BW are optional.
func ValidateEdgeConstraints(root *map[string]any, data any, _ tree.Path) error {
	edgeConstraints, ok := data.(map[string]any)
	if !ok {
		return fmt.Errorf("invalid edge constraints configuration: %v", data)
	}
	// Check if "S" and "T" keys exist and are non-empty
	s, sOk := edgeConstraints["S"].(string)
	t, tOk := edgeConstraints["T"].(string)
	if !sOk || s == "" || !tOk || t == "" {
		return fmt.Errorf("invalid edge constraints configuration: edges should be a pair of named nodes")
	}
	// Validate that S and T are different nodes
	if s == t {
		return fmt.Errorf("edge constraint source and target must be different nodes")
	}
	// Check if "nodes" key exists
	nodesConfig, err := utils.GetConfigAtPath(*root, tree.NewPath("V1.nodes"))
	if err != nil || nodesConfig == nil {
		return fmt.Errorf("invalid edge constraints configuration: nodes must be defined")
	}
	nodes, nodesOk := nodesConfig.(map[string]any)
	if !nodesOk {
		return fmt.Errorf("invalid edge constraints configuration: nodes must be a map")
	}
	// Check if S and T are present in the "nodes" map
	if _, ok := nodes[s]; !ok {
		return fmt.Errorf("invalid edge constraints configuration: node '%s' not found", s)
	}
	if _, ok := nodes[t]; !ok {
		return fmt.Errorf("invalid edge constraints configuration: node '%s' not found", t)
	}
	// Validate RTT and BW if present.
	// NOTE(review): these branches re-store the asserted value unchanged, so
	// they are effectively no-ops, and values of other numeric types pass
	// through unvalidated — confirm whether a type/range check was intended.
	if rtt, ok := edgeConstraints["RTT"].(uint); ok {
		edgeConstraints["RTT"] = rtt
	}
	if bw, ok := edgeConstraints["BW"].(uint); ok {
		edgeConstraints["BW"] = bw
	}
	return nil
}
// ValidateHealthCheck checks the healthcheck configuration: a non-empty map
// with a supported type ("http" requires an endpoint, "command" requires
// exec) and an optional non-zero interval.
func ValidateHealthCheck(_ *map[string]any, data any, _ tree.Path) error {
	hc, isMap := data.(map[string]any)
	if !isMap {
		return fmt.Errorf("invalid healthcheck configuration: %v", data)
	}
	if len(hc) == 0 {
		return fmt.Errorf("healthcheck cannot be empty if specified")
	}
	// The type is compared as `any`, which works for string values.
	kind, present := hc["type"]
	if !present || kind == "" {
		return fmt.Errorf("healthcheck must have a type")
	}
	// Each type has one mandatory companion field.
	switch kind {
	case "http":
		if _, hasEndpoint := hc["endpoint"]; !hasEndpoint {
			return fmt.Errorf("http type healthcheck must have an endpoint")
		}
	case "command":
		if _, hasExec := hc["exec"]; !hasExec {
			return fmt.Errorf("command type healthcheck must have exec")
		}
	default:
		return fmt.Errorf("unsupported healthcheck type: %s", kind)
	}
	// NOTE(review): this only matches when the parser already produced a
	// time.Duration; string intervals are not checked here — confirm.
	if interval, isDuration := hc["interval"].(time.Duration); isDuration && interval == 0 {
		return fmt.Errorf("healthcheck interval must be greater than 0")
	}
	return nil
}
// ValidateSubnet validates the subnet config: it must be a non-empty map
// whose "join" key carries a boolean.
func ValidateSubnet(_ *map[string]any, data any, _ tree.Path) error {
	cfg, isMap := data.(map[string]any)
	switch {
	case !isMap:
		return fmt.Errorf("invalid subnet config")
	case len(cfg) == 0:
		return fmt.Errorf("subnet can not be empty if specified")
	}
	_, isBool := cfg["join"].(bool)
	if !isBool {
		return fmt.Errorf("subnet.join expects boolean value")
	}
	return nil
}
// ValidateContract validates the contract configuration.
// It checks that the contract has a valid DID and host format.
// Payment details are validated at contract creation time, not here.
func ValidateContract(_ *map[string]any, data any, _ tree.Path) error {
	contract, isMap := data.(map[string]any)
	if !isMap {
		return fmt.Errorf("invalid contract configuration: %v", data)
	}
	// The contract DID must be a string with the standard "did:" prefix.
	id, isString := contract["did"].(string)
	switch {
	case !isString:
		return fmt.Errorf("contract 'did' must be a string")
	case !strings.HasPrefix(id, "did:"):
		return fmt.Errorf("invalid did format")
	}
	// The optional host must also be DID-shaped when present.
	if host, present := contract["host"].(string); present && !strings.HasPrefix(host, "did:") {
		return fmt.Errorf("invalid host did format")
	}
	return nil
}
// ValidateVolume validates the allocation's volume config.
//
// Each entry must be a map with a supported type ("local" or "glusterfs"),
// the type-specific fields, and a mount_destination.
func ValidateVolume(_ *map[string]any, data any, _ tree.Path) error {
	volumes, ok := data.([]any)
	if !ok {
		return fmt.Errorf("invalid volume configuration: %v", data)
	}
	if len(volumes) == 0 {
		return fmt.Errorf("volume cannot be empty if specified")
	}
	// Named-volume pattern (alphanumeric, dashes, underscores and dot only),
	// compiled once per call instead of once per "local" volume. The pattern
	// is a valid constant, so MustCompile cannot panic here.
	namedVolumeRe := regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9_.-]*$`)
	for _, vol := range volumes {
		volume, ok := vol.(map[string]any)
		if !ok {
			return fmt.Errorf("invalid type: %T - expecting map[string]", vol)
		}
		// Check volume type
		volType, ok := volume["type"]
		if !ok || volType == "" {
			return fmt.Errorf("volume must have a type")
		}
		// types for now are local or glusterfs
		switch volType {
		case "glusterfs":
			// Generic decoders usually produce []any rather than []string, so
			// both shapes are accepted (the old []string-only assertion
			// rejected decoded configs); every entry must be a non-empty string.
			var servers []string
			switch list := volume["servers"].(type) {
			case []string:
				servers = list
			case []any:
				servers = make([]string, 0, len(list))
				for i, entry := range list {
					s, isStr := entry.(string)
					if !isStr {
						return fmt.Errorf("glusterfs volume server at index %d must be a non-empty string", i)
					}
					servers = append(servers, s)
				}
			default:
				return fmt.Errorf("glusterfs type volume must define one or more servers")
			}
			if len(servers) == 0 {
				return fmt.Errorf("glusterfs type volume must define at least one server")
			}
			for i, server := range servers {
				if server == "" {
					return fmt.Errorf("glusterfs volume server at index %d must be a non-empty string", i)
				}
			}
		case "local":
			volumeSrc, ok := volume["src"].(string)
			if !ok {
				return fmt.Errorf("local type volume must define source destination as 'src'")
			}
			if strings.Contains(volumeSrc, "/") {
				// validate as a path
				if !filepath.IsAbs(volumeSrc) {
					return fmt.Errorf("local type volume source must be an absolute path")
				}
			} else if !namedVolumeRe.MatchString(volumeSrc) {
				// validate as a named volume
				return fmt.Errorf("local volume src contains invalid characters")
			}
		default:
			return fmt.Errorf("unsupported volume type: %s", volType)
		}
		if _, ok := volume["mount_destination"]; !ok {
			return fmt.Errorf("volume must define a mount destination")
		}
	}
	return nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package parser
import (
"gitlab.com/nunet/device-management-service/dms/jobs/parser/ensemblev1"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/types"
)
// registry is the package-wide parser registry, populated once in init and
// consumed by Decode/Encode.
var registry *Registry

func init() {
	registry = &Registry{
		parsers: make(map[SpecType]types.Parser),
	}
	// Register Nunet parser.
	registry.RegisterParser(SpecTypeEnsembleV1, ensemblev1.NewEnsemblev1Parser())
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package parser
import (
"fmt"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/types"
)
// SpecType identifies a job specification format handled by a registered parser.
type SpecType string

// Options aliases types.Options so callers of this package do not need to
// import the parser types package directly.
type Options types.Options

const (
	// SpecTypeEnsembleV1 selects the v1 ensemble specification parser.
	SpecTypeEnsembleV1 SpecType = "ensembleV1"
)
// Decode parses data of the given spec type into result using the parser
// registered for that type. opts may be nil and is forwarded unchanged.
// It returns an error when no parser is registered or decoding fails.
func Decode(specType SpecType, data []byte, result any, opts *Options) error {
	parser, exists := registry.GetParser(specType)
	if !exists {
		return fmt.Errorf("parser for spec type %s not found", specType)
	}
	// Return the parser's error directly; the previous
	// check-err-then-return-nil sequence was redundant.
	return parser.Decode(data, result, (*types.Options)(opts))
}
// Encode serializes data using the parser registered for the given spec type.
// It returns an error when no parser is registered for that type.
func Encode(specType SpecType, data any) ([]byte, error) {
	p, ok := registry.GetParser(specType)
	if !ok {
		return nil, fmt.Errorf("parser for spec type %s not found", specType)
	}
	return p.Encode(data)
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package parser
import (
"sync"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/types"
)
// Registry is a concurrency-safe map from spec types to their parsers.
type Registry struct {
	parsers map[SpecType]types.Parser
	mu      sync.RWMutex
}

// RegisterParser associates p with specType, replacing any existing parser
// for that type.
func (r *Registry) RegisterParser(specType SpecType, p types.Parser) {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.parsers[specType] = p
}

// GetParser returns the parser registered for specType and whether one exists.
func (r *Registry) GetParser(specType SpecType) (types.Parser, bool) {
	r.mu.RLock()
	defer r.mu.RUnlock()
	p, exists := r.parsers[specType]
	return p, exists
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package resolve
import (
"bytes"
"fmt"
"os"
"path/filepath"
"regexp"
"strings"
"github.com/spf13/afero"
"gitlab.com/nunet/device-management-service/lib/env"
)
// Handler resolves a key/path from one value source (local files,
// environment variables, ...).
type Handler interface {
	// Resolve returns the bytes for path, or ErrNotExist when the value is absent.
	Resolve(path string) ([]byte, error)
}
// FileResolver implements FileResolver for the local filesystem.
type FileResolver struct {
	Fs       afero.Afero // filesystem abstraction used for reads
	BasePath string      // root that resolved paths are joined onto
	// WorkingDir is not set by NewFileResolver and not read by Resolve.
	// NOTE(review): presumably used elsewhere — confirm before removing.
	WorkingDir string
}

// NewFileResolver returns a Handler that reads files from fs under basePath.
func NewFileResolver(fs afero.Fs, basePath string) Handler {
	return &FileResolver{Fs: afero.Afero{Fs: fs}, BasePath: basePath}
}
// Resolve reads the file at path relative to the resolver's base path.
// A missing file is reported as ErrNotExist; other I/O errors pass through.
func (l *FileResolver) Resolve(path string) ([]byte, error) {
	content, err := l.Fs.ReadFile(filepath.Join(l.BasePath, path))
	if err == nil {
		return content, nil
	}
	if os.IsNotExist(err) {
		return nil, ErrNotExist
	}
	return nil, err
}
// EnvResolver implements Handler for environment variables.
type EnvResolver struct {
	Env env.EnvironmentProvider // source of environment variable lookups
}

// NewEnvResolver returns a Handler backed by the given environment provider.
func NewEnvResolver(env env.EnvironmentProvider) Handler {
	return &EnvResolver{Env: env}
}
// Resolve returns the value of the environment variable key, or ErrNotExist
// when it is unset or empty (the provider cannot distinguish the two).
func (r *EnvResolver) Resolve(key string) ([]byte, error) {
	if val := r.Env.Getenv(key); val != "" {
		return []byte(val), nil
	}
	return nil, ErrNotExist
}
// Resolver expands ${...} expressions by dispatching to named source handlers
// (e.g. "env" — see resolveContent, which defaults to that source).
type Resolver struct {
	SourceHandlers  map[string]Handler // handlers keyed by source name
	expressionRegex *regexp.Regexp     // pattern recognizing resolvable expressions
}
// NewResolver creates a Resolver over the given source handlers. When
// expressionRegex is nil, the default `${...}` expression pattern is used.
func NewResolver(sourceHandlers map[string]Handler, expressionRegex *regexp.Regexp) *Resolver {
	re := expressionRegex
	if re == nil {
		re = regexp.MustCompile(`\${([^{}]+?)}`)
	}
	return &Resolver{SourceHandlers: sourceHandlers, expressionRegex: re}
}
// Process repeatedly resolves `${...}` expressions in input until none remain,
// replacing each occurrence with its resolved value. Resolved values are
// re-scanned, so an expression may expand into further expressions.
func (r *Resolver) Process(input string) (string, error) {
	out := []byte(input)
	for {
		loc := r.expressionRegex.FindSubmatchIndex(out)
		if loc == nil {
			return string(out), nil
		}
		whole := out[loc[0]:loc[1]]           // the full ${...} expression
		inner := string(out[loc[2]:loc[3]]) // the content between the braces
		resolved, err := r.resolveContent(inner)
		if err != nil {
			return "", fmt.Errorf("failed to resolve expression '%s': %w", whole, err)
		}
		out = bytes.Replace(out, whole, resolved, 1)
	}
}
// resolveContent resolves a single expression body of the form
// "source:key", "key" (source defaults to "env"), optionally followed by
// ":-default" which is returned when the handler reports an error.
func (r *Resolver) resolveContent(content string) ([]byte, error) {
	// Split off an optional ":-default" suffix first.
	parts := strings.SplitN(content, ":-", 2)
	mainPart := parts[0]
	var defaultValue []byte
	hasDefault := len(parts) > 1
	if hasDefault {
		defaultValue = []byte(parts[1])
	}
	// NOTE(review): anything after a "|" is parsed off but never used —
	// modifiers appear to be reserved/unimplemented here.
	modifierParts := strings.Split(mainPart, "|")
	sourceAndKey := modifierParts[0]
	source, key, found := strings.Cut(sourceAndKey, ":")
	if !found {
		// Default to 'env' source if no source is specified
		source = "env"
		key = sourceAndKey
	}
	handler, exists := r.SourceHandlers[source]
	if !exists {
		return nil, fmt.Errorf("unknown source: '%s'", source)
	}
	value, err := handler.Resolve(key)
	if err != nil {
		// Any resolution error (including ErrNotExist) falls back to the default.
		if hasDefault {
			return defaultValue, nil
		}
		return nil, err
	}
	return value, err
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package transform
import (
"fmt"
"maps"
"time"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/tree"
"gitlab.com/nunet/device-management-service/utils/convert"
)
// ToSpecConfigTransformer converts a map into the canonical spec-config shape:
// a "type" field plus a "params" map holding every other key.
func ToSpecConfigTransformer(specName string) TransformerFunc {
	return func(_ *map[string]interface{}, data any, _ tree.Path) (any, error) {
		if data == nil {
			return nil, nil
		}
		spec, ok := data.(map[string]any)
		if !ok {
			return nil, fmt.Errorf("invalid %s configuration: %v", specName, data)
		}
		// Everything except "type" moves under "params".
		params := make(map[string]any, len(spec))
		for key, value := range spec {
			if key == "type" {
				continue
			}
			params[key] = value
		}
		return map[string]any{
			"type":   spec["type"],
			"params": params,
		}, nil
	}
}
// FlattenSpecConfigTransformer is the inverse of ToSpecConfigTransformer: it
// lifts the keys under "params" back to the top level, next to "type".
func FlattenSpecConfigTransformer(specName string) TransformerFunc {
	return func(_ *map[string]any, data any, _ tree.Path) (any, error) {
		if data == nil {
			return nil, nil
		}
		spec, ok := data.(map[string]any)
		if !ok {
			return nil, fmt.Errorf("invalid %s configuration: %v", specName, data)
		}
		result := make(map[string]any)
		result["type"] = spec["type"]
		// Guard the assertion: the unchecked spec["params"].(map[string]any)
		// previously panicked when "params" was missing, nil, or not a map.
		if params, ok := spec["params"].(map[string]any); ok {
			maps.Copy(result, params)
		} else if spec["params"] != nil {
			return nil, fmt.Errorf("invalid %s configuration: params is not a map: %v", specName, spec["params"])
		}
		return result, nil
	}
}
// MapToNamedSliceTransformer converts a map of maps to a slice of maps and
// assigns each key to the "name" field of its entry. Nil entries become empty
// maps; non-map entries are kept as-is (without a name).
func MapToNamedSliceTransformer(name string) TransformerFunc {
	return func(_ *map[string]interface{}, data any, _ tree.Path) (any, error) {
		if data == nil {
			return nil, nil
		}
		// Renamed from "maps" to avoid shadowing the imported maps package.
		entries, ok := data.(map[string]any)
		if !ok {
			return nil, fmt.Errorf("invalid %s configuration: %v", name, data)
		}
		result := make([]any, 0, len(entries))
		for k, v := range entries {
			if v == nil {
				v = map[string]any{}
			}
			if e, ok := v.(map[string]any); ok {
				e["name"] = k
			}
			result = append(result, v)
		}
		return result, nil
	}
}
// NamedSliceToMapTransformer converts a slice of maps to a map of maps, keyed
// by each element's "name" field (which is removed from the element). Elements
// without a usable name get a synthetic "<name>_<i>" key.
func NamedSliceToMapTransformer(name string) TransformerFunc {
	return func(_ *map[string]interface{}, data any, _ tree.Path) (any, error) {
		if data == nil {
			return nil, nil
		}
		// Renamed from "maps" to avoid shadowing the imported maps package.
		entries, ok := data.([]any)
		if !ok {
			return nil, fmt.Errorf("invalid %s configuration: %v", name, data)
		}
		result := make(map[string]any)
		for i, v := range entries {
			if v == nil {
				v = map[string]any{}
			}
			m, ok := v.(map[string]any)
			if !ok {
				return nil, fmt.Errorf("invalid %s configuration: %v", name, data)
			}
			k, ok := m["name"].(string)
			if !ok || k == "" {
				// No usable name: synthesize a stable, 1-based key.
				k = fmt.Sprintf("%s_%d", name, i+1)
			} else {
				delete(m, "name")
			}
			result[k] = m
		}
		return result, nil
	}
}
// ParseWithDefaultUnit parses an SI quantity, applying defaultUnit when the
// input carries no unit of its own. Errors are wrapped with the field name.
func ParseWithDefaultUnit(name string, defaultUnit string) TransformerFunc {
	return func(_ *map[string]interface{}, data any, _ tree.Path) (any, error) {
		parsed, err := convert.ParseSIWithDefaultUnit(data, defaultUnit)
		if err == nil {
			return parsed, nil
		}
		return nil, fmt.Errorf("invalid %s value: %v", name, err)
	}
}
// ParseBytesWithDefaultUnit parses a byte quantity, applying defaultUnit when
// the input carries no unit of its own. Errors are wrapped with the field name.
func ParseBytesWithDefaultUnit(name string, defaultUnit string) TransformerFunc {
	return func(_ *map[string]interface{}, data any, _ tree.Path) (any, error) {
		parsed, err := convert.ParseBytesWithDefaultUnit(data, defaultUnit)
		if err == nil {
			return parsed, nil
		}
		return nil, fmt.Errorf("invalid %s value: %v", name, err)
	}
}
// ToBytesWithDefaultUnit parses a byte quantity, applying defaultUnit when
// the input has no unit.
// NOTE(review): implementation is identical to ParseBytesWithDefaultUnit —
// consider delegating to it or removing one of the two (callers permitting).
func ToBytesWithDefaultUnit(name string, defaultUnit string) TransformerFunc {
	return func(_ *map[string]interface{}, data any, _ tree.Path) (any, error) {
		v, err := convert.ParseBytesWithDefaultUnit(data, defaultUnit)
		if err != nil {
			return nil, fmt.Errorf("invalid %s value: %v", name, err)
		}
		return v, nil
	}
}
// ToBytesFormat formats a raw byte quantity into a byte-unit string.
// The leftover debug fmt.Printf to stdout has been removed, and the error now
// wraps the conversion error (not the raw data), matching the sibling
// transformers in this file.
func ToBytesFormat(name string) TransformerFunc {
	return func(_ *map[string]interface{}, data any, _ tree.Path) (any, error) {
		v, err := convert.ToBytesFormat(data)
		if err != nil {
			return nil, fmt.Errorf("invalid %s value: %v", name, err)
		}
		return v, nil
	}
}
// ToSIFormatWithUnit formats a value as an SI string with the given unit.
// Errors are wrapped with the field name for context.
func ToSIFormatWithUnit(name string, unit string) TransformerFunc {
	return func(_ *map[string]interface{}, data any, _ tree.Path) (any, error) {
		formatted, err := convert.ToSIFormatWithUnit(data, unit)
		if err == nil {
			return formatted, nil
		}
		return nil, fmt.Errorf("invalid %s value: %v", name, err)
	}
}
// ParseDuration parses a non-empty duration string (time.ParseDuration
// syntax, e.g. "1h30m"). Parse failures are now wrapped with the field name,
// consistent with the other transformers in this file; previously the raw
// time.ParseDuration error was returned without context.
func ParseDuration(name string) TransformerFunc {
	return func(_ *map[string]interface{}, data any, _ tree.Path) (any, error) {
		s, ok := data.(string)
		if !ok || s == "" {
			return nil, fmt.Errorf("invalid %s value: %v", name, data)
		}
		d, err := time.ParseDuration(s)
		if err != nil {
			return nil, fmt.Errorf("invalid %s value: %v", name, err)
		}
		return d, nil
	}
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package transform
import (
"fmt"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/tree"
)
// TransformerFunc is a function that transforms a part of the configuration.
// It modifies the data to conform to the expected structure and returns the transformed data.
// It takes the root configuration, the data to transform and the current path in the tree.
type TransformerFunc func(*map[string]interface{}, interface{}, tree.Path) (any, error)

// Transformer is a configuration transformer.
type Transformer interface {
	Transform(*map[string]interface{}) (interface{}, error)
}

// TransformerImpl is the implementation of the Transformer interface.
// It holds an ordered list of transformer sets; each set maps a path pattern
// to the function applied at matching paths (see Transform).
type TransformerImpl struct {
	transformers []map[tree.Path]TransformerFunc
}
// NewTransformer creates a new transformer with the given transformers.
// The sets are applied in slice order by Transform.
func NewTransformer(transformers []map[tree.Path]TransformerFunc) Transformer {
	return TransformerImpl{
		transformers: transformers,
	}
}
// Transform applies each transformer set to the configuration in order;
// every set performs a full tree walk before the next set runs.
func (t TransformerImpl) Transform(rawConfig *map[string]interface{}) (interface{}, error) {
	var (
		data = any(*rawConfig)
		err  error
	)
	for _, set := range t.transformers {
		if data, err = t.transform(rawConfig, data, tree.NewPath(), set); err != nil {
			return nil, err
		}
	}
	return data, nil
}
// transform is a recursive function that applies the transformers to the
// configuration, depth-first: the current node is transformed before its
// children are visited, so a transformer may reshape the subtree that is
// subsequently walked.
func (t TransformerImpl) transform(root *map[string]interface{}, data any, path tree.Path, transformers map[tree.Path]TransformerFunc) (interface{}, error) {
	var err error
	// Apply every transformer whose pattern matches the current path.
	for pattern, fn := range transformers {
		if !path.Matches(pattern) {
			continue
		}
		if data, err = fn(root, data, path); err != nil {
			return nil, err
		}
	}
	// Recurse into children, writing transformed values back in place.
	switch node := data.(type) {
	case map[string]interface{}:
		for key, child := range node {
			node[key], err = t.transform(root, child, path.Next(key), transformers)
			if err != nil {
				return nil, err
			}
		}
		return node, nil
	case []any:
		for i, child := range node {
			node[i], err = t.transform(root, child, path.Next(fmt.Sprintf("[%d]", i)), transformers)
			if err != nil {
				return nil, err
			}
		}
		return node, nil
	default:
		return data, nil
	}
}
// Original Copyright 2020 The Compose Specification Authors; Modified Copyright 2024, Nunet;
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tree
import (
"fmt"
"strings"
)
const (
	configPathSeparator        = "."  // separates path parts
	configPathMatchAny         = "*"  // pattern token: matches any single part
	configPathMatchAnyMultiple = "**" // pattern token: matches any number of parts
	configPathList             = "[]" // pattern token: matches a list index part like "[0]"
)

// Path is a custom type for representing paths in the configuration
type Path string

// NewPath joins the given parts into a Path using the separator.
func NewPath(path ...string) Path {
	return Path(strings.Join(path, configPathSeparator))
}
// Parts returns the parts of the path. Note strings.Split on an empty path
// yields [""] — a single empty part, never an empty slice.
func (p Path) Parts() []string {
	return strings.Split(string(p), configPathSeparator)
}
// Parent returns the parent path, or the empty path when there is at most
// one part.
func (p Path) Parent() Path {
	parts := p.Parts()
	if len(parts) <= 1 {
		return ""
	}
	return NewPath(parts[:len(parts)-1]...)
}
// Next returns the path extended by one more part. An empty part returns the
// path unchanged; extending an empty path yields just the new part.
func (p Path) Next(path string) Path {
	switch {
	case path == "":
		return p
	case p == "":
		return Path(path)
	default:
		return Path(string(p) + configPathSeparator + path)
	}
}
// Last returns the last part of the path ("" for an empty path, since
// splitting "" yields a single empty part).
func (p Path) Last() string {
	parts := p.Parts()
	// Parts never returns an empty slice, but guard defensively.
	if len(parts) == 0 {
		return ""
	}
	return parts[len(parts)-1]
}
// Matches checks if the path matches a given pattern. Patterns may use "*"
// (any single part), "**" (any number of parts) and "[]" (a list index part);
// see matchParts for the matching rules.
func (p Path) Matches(pattern Path) bool {
	pathParts := p.Parts()
	patternParts := pattern.Parts()
	return matchParts(pathParts, patternParts)
}
// FindParentWithKey returns the longest prefix of the path whose final part
// equals key, or "" when the key does not occur in the path.
func (p Path) FindParentWithKey(key string) Path {
	if p == "" || key == "" {
		return ""
	}
	parts := p.Parts()
	// Scan from the deepest part upwards so the deepest occurrence wins.
	for i := len(parts) - 1; i >= 0; i-- {
		if parts[i] == key {
			return NewPath(parts[:i+1]...)
		}
	}
	return ""
}
// String returns the string representation of the path (implements fmt.Stringer).
func (p Path) String() string {
	return string(p)
}
// matchParts checks if the path parts match the pattern parts.
// Pattern tokens: "*" matches any single part, "**" matches any number of
// (trailing or intermediate) parts, and "[]" matches a list index segment of
// the form "[n]". Plain parts compare case-insensitively.
func matchParts(pathParts, patternParts []string) bool {
	// A pattern longer than the path can never match.
	if len(pathParts) < len(patternParts) {
		return false
	}
	for i, part := range patternParts {
		switch part {
		case configPathMatchAnyMultiple:
			// A trailing "**" matches the remainder of the path.
			if i == len(patternParts)-1 {
				return true
			}
			// Otherwise try matching the rest of the pattern at every
			// possible suffix of the path.
			for j := i; j < len(pathParts); j++ {
				if matchParts(pathParts[j:], patternParts[i+1:]) {
					return true
				}
			}
		case configPathList:
			// The corresponding path part must look like an index segment.
			// HasPrefix/HasSuffix are safe on an empty part, whereas the
			// previous direct byte indexing (pathParts[i][0]) panicked on "".
			if !strings.HasPrefix(pathParts[i], "[") || !strings.HasSuffix(pathParts[i], "]") {
				return false
			}
		default:
			// A literal part must match (case-insensitively) unless it is "*".
			if part != configPathMatchAny && !strings.EqualFold(part, pathParts[i]) {
				return false
			}
		}
		// The last pattern part must also be the last path part.
		if i == len(patternParts)-1 && i < len(pathParts)-1 {
			return false
		}
	}
	return true
}
// WalkFunc is a function applied to each node in the data structure.
type WalkFunc func(node *any, path Path) error

// Walk recursively traverses a generic data structure and applies a function
// to each node, parent before children. fn receives a pointer to the node so
// it can replace the value in place; any replacement is written back into the
// parent map or slice after the recursive call returns.
func Walk(data *any, path Path, fn WalkFunc) error {
	// Apply the function to the current node first.
	if err := fn(data, path); err != nil {
		return err
	}
	switch v := (*data).(type) {
	case map[string]any:
		for key, val := range v {
			// A temporary variable is needed to pass the address correctly.
			tempVal := val
			if err := Walk(&tempVal, path.Next(key), fn); err != nil {
				return err
			}
			// Persist any replacement made by fn back into the map.
			v[key] = tempVal
		}
	case []any:
		for i, val := range v {
			tempVal := val
			// List elements are addressed as "[i]" in the path.
			next := path.Next(fmt.Sprintf("[%d]", i))
			if err := Walk(&tempVal, next, fn); err != nil {
				return err
			}
			v[i] = tempVal
		}
	}
	return nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package types
import (
"encoding/json"
"fmt"
"github.com/go-viper/mapstructure/v2"
"github.com/spf13/afero"
"go.yaml.in/yaml/v3"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/transform"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/validate"
"gitlab.com/nunet/device-management-service/lib/env"
)
// Options carries the environment and filesystem context used while
// resolving references during parsing.
type Options struct {
	Env        env.EnvironmentProvider // environment variable source
	Fs         afero.Afero             // filesystem used for file references
	WorkingDir string                  // base directory for relative paths
}

// Parser decodes raw bytes into a destination value and encodes values back
// into bytes.
type Parser interface {
	Decode(data []byte, dest any, opts *Options) error
	Encode(data any) ([]byte, error)
}

// DefaultTagName is the struct tag consulted when decoding into structs.
const DefaultTagName = "json"

// resolveFunc resolves external references (e.g. files, env vars) in the raw
// configuration in place.
type resolveFunc func(data *any, options *Options) error

// BasicParser is a format-aware Parser that resolves references, transforms,
// validates and finally decodes configuration data.
type BasicParser struct {
	format    string                // wire format: "json" or "yaml"
	resolveFn resolveFunc           // resolves external references before transforming
	decoder   transform.Transformer // transformers applied on Decode
	encoder   transform.Transformer // transformers applied on Encode
	validator validate.Validator    // validates the transformed configuration
}
// NewBasicParser creates a BasicParser for the given format ("json" or
// "yaml") with the supplied resolver, decode/encode transformers and validator.
func NewBasicParser(format string, resolveFn resolveFunc, decoder, encoder transform.Transformer, validator validate.Validator) BasicParser {
	return BasicParser{
		format:    format,
		resolveFn: resolveFn,
		decoder:   decoder,
		encoder:   encoder,
		validator: validator,
	}
}
// unmarshal decodes data into result according to the parser's format.
func (p BasicParser) unmarshal(data []byte, result any) error {
	if p.format == "json" {
		return json.Unmarshal(data, result)
	}
	if p.format == "yaml" {
		return yaml.Unmarshal(data, result)
	}
	return fmt.Errorf("invalid format: %s", p.format)
}
// marshal encodes data according to the parser's format.
func (p BasicParser) marshal(data any) ([]byte, error) {
	if p.format == "json" {
		return json.Marshal(data)
	}
	if p.format == "yaml" {
		return yaml.Marshal(data)
	}
	return nil, fmt.Errorf("invalid format: %s", p.format)
}
// Decode parses raw bytes in the parser's format, resolves external
// references, applies the decode transformers, validates the result and
// finally decodes it into result using DefaultTagName struct tags.
// Errors now wrap the cause with %w (same messages, unwrappable chains), and
// the trailing `return err` — which always returned a stale nil — is an
// explicit `return nil`.
func (p BasicParser) Decode(data []byte, result any, opts *Options) error {
	var rawConfig map[string]any
	if err := p.unmarshal(data, &rawConfig); err != nil {
		return fmt.Errorf("failed to parse config: %w", err)
	}
	// Resolve files and environment variables
	var rawConfigAny any = rawConfig
	if err := p.resolveFn(&rawConfigAny, opts); err != nil {
		return fmt.Errorf("failed to resolve config: %w", err)
	}
	// NOTE(review): assumes resolveFn keeps the root a map — confirm.
	rawConfig = rawConfigAny.(map[string]any)
	// Apply transformers
	transformed, err := p.decoder.Transform(&rawConfig)
	if err != nil {
		return fmt.Errorf("failed to transform config: %w", err)
	}
	transformedMap, ok := transformed.(map[string]any)
	if !ok {
		return fmt.Errorf("transformed config is not a map: %v", transformed)
	}
	// Validate the transformed configuration
	if err := p.validator.Validate(&transformedMap); err != nil {
		return err
	}
	if err := decodeWithDefaultTagName(transformedMap, &result); err != nil {
		return fmt.Errorf("failed to decode spec: %w", err)
	}
	return nil
}
// decodeWithDefaultTagName decodes input into result via mapstructure,
// honoring DefaultTagName ("json") struct tags.
func decodeWithDefaultTagName(input any, result any) error {
	// Metadata is collected by the decoder but not currently inspected here.
	var m mapstructure.Metadata
	ms, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
		Metadata: &m,
		Result:   result,
		TagName:  DefaultTagName,
	})
	if err != nil {
		return err
	}
	return ms.Decode(input)
}
// Encode serializes data in the parser's format, applying the encode
// transformers first.
func (p BasicParser) Encode(data any) ([]byte, error) {
	// Round-trip struct -> bytes -> map[string]any so that nested maps of
	// structs (e.g. Allocations, Nodes) become plain maps the transformers
	// can walk, instead of remaining opaque structs.
	raw, err := p.marshal(data)
	if err != nil {
		return nil, fmt.Errorf("failed to encode spec: %v", err)
	}
	var config map[string]any
	if err = p.unmarshal(raw, &config); err != nil {
		return nil, fmt.Errorf("failed to encode spec: %v", err)
	}
	transformed, err := p.encoder.Transform(&config)
	if err != nil {
		return nil, fmt.Errorf("failed to transform config: %v", err)
	}
	return p.marshal(transformed)
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package utils
import (
"errors"
"fmt"
"strconv"
"strings"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/tree"
)
// Sentinel errors to enable precise upstream handling (via errors.Is).
var (
	ErrKeyNotFound       = errors.New("key not found")        // map key absent at a path step
	ErrInvalidIndex      = errors.New("invalid index")        // list step not of the form "[n]"
	ErrIndexOutOfRange   = errors.New("index out of range")   // list index outside the slice bounds
	ErrInvalidTypeAtPath = errors.New("invalid type at path") // node is neither a map nor a list
	ErrCycleDetected     = errors.New("cycle detected")       // NOTE(review): not referenced in this file — presumably used by callers of DetectCycles
)
// parseListIndex parses a path segment of the form "[n]" and validates it
// against a list of the given length. Extracted to remove the duplicated
// index-parsing logic for the []any and []map[string]any cases below.
func parseListIndex(key string, length int) (int, error) {
	if len(key) < 3 || key[0] != '[' || key[len(key)-1] != ']' {
		return 0, fmt.Errorf("%w: %s (expected [n])", ErrInvalidIndex, key)
	}
	i, err := strconv.Atoi(key[1 : len(key)-1])
	if err != nil {
		return 0, fmt.Errorf("%w: %s", ErrInvalidIndex, key)
	}
	if i < 0 || i >= length {
		return 0, fmt.Errorf("%w: %d (len=%d)", ErrIndexOutOfRange, i, length)
	}
	return i, nil
}

// GetConfigAtPath retrieves a part of the configuration at a given path.
// Map steps are looked up by exact key, then by the lowercased key; list
// steps must be index segments of the form "[n]".
func GetConfigAtPath(config any, path tree.Path) (any, error) {
	current := config
	for _, key := range path.Parts() {
		switch v := current.(type) {
		case map[string]any:
			val, ok := v[key]
			if !ok {
				// Fall back to a lowercase lookup for case-insensitive keys.
				if val, ok = v[strings.ToLower(key)]; !ok {
					return nil, fmt.Errorf("%w: %q", ErrKeyNotFound, key)
				}
			}
			current = val
		case []any:
			i, err := parseListIndex(key, len(v))
			if err != nil {
				return nil, err
			}
			current = v[i]
		case []map[string]any:
			i, err := parseListIndex(key, len(v))
			if err != nil {
				return nil, err
			}
			current = v[i]
		default:
			return nil, fmt.Errorf("%w at %q: %T", ErrInvalidTypeAtPath, key, current)
		}
	}
	return current, nil
}
// CreateAdjencyList creates an adjacency list from a map: for each entry, the
// value found at path (a single T, a []T, or a []any of T elements) becomes
// that key's neighbor list. Entries whose path lookup fails are skipped
// silently, as are []any elements that are not of type T.
func CreateAdjencyList[T comparable](m map[T]any, path tree.Path) map[T][]T {
	adjencyList := make(map[T][]T)
	for key, value := range m {
		if val, err := GetConfigAtPath(value, path); err == nil {
			switch v := val.(type) {
			case []T:
				adjencyList[key] = v
			case []any:
				// Keep only elements of the expected type.
				for _, v := range v {
					if k, ok := v.(T); ok {
						adjencyList[key] = append(adjencyList[key], k)
					}
				}
			case T:
				adjencyList[key] = []T{v}
			}
		}
	}
	return adjencyList
}
// hasCycle reports whether a cycle is reachable from node, using DFS.
// visited marks nodes whose whole subtree has been explored; recursionStack
// marks nodes on the current DFS path — reaching one of those again means a
// back edge, i.e. a cycle. The checks must run in this order: a node on the
// stack is also marked visited, so testing visited first would hide cycles.
func hasCycle[T comparable](adjencyList map[T][]T, node T, visited, recursionStack map[T]bool) bool {
	switch {
	case recursionStack[node]:
		// Back edge: node is already on the active DFS path.
		return true
	case visited[node]:
		// Subtree fully explored earlier without finding a cycle.
		return false
	}
	visited[node], recursionStack[node] = true, true
	for _, succ := range adjencyList[node] {
		if hasCycle(adjencyList, succ, visited, recursionStack) {
			return true
		}
	}
	// Backtrack: node leaves the active DFS path.
	recursionStack[node] = false
	return false
}
// DetectCycles reports whether the directed graph described by the adjacency
// list contains at least one cycle, running a DFS from every unvisited node.
func DetectCycles[T comparable](adjencyList map[T][]T) bool {
	visited := make(map[T]bool, len(adjencyList))
	onStack := make(map[T]bool, len(adjencyList))
	for start := range adjencyList {
		if visited[start] {
			continue
		}
		if hasCycle(adjencyList, start, visited, onStack) {
			return true
		}
	}
	return false
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package validate
import (
"fmt"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/tree"
)
// ValidatorFunc is a function that validates a part of the configuration.
// It takes the root configuration, the data to validate and the current path in the tree.
type ValidatorFunc func(*map[string]any, any, tree.Path) error

// Validator is a configuration validator.
// It contains a map of patterns to paths to functions that validate the configuration.
type Validator interface {
	Validate(*map[string]any) error
}

// ValidatorImpl is the implementation of the Validator interface.
// It maps path patterns to the validation functions run at matching paths.
type ValidatorImpl struct {
	validators map[tree.Path]ValidatorFunc
}

// NewValidator creates a new validator with the given validators.
func NewValidator(validators map[tree.Path]ValidatorFunc) Validator {
	return ValidatorImpl{
		validators: validators,
	}
}
// Validate applies the validators to the configuration, walking the whole
// tree starting from the root path.
func (v ValidatorImpl) Validate(rawConfig *map[string]any) error {
	data := any(*rawConfig)
	return v.validate(rawConfig, data, tree.NewPath(), v.validators)
}
// validate is a recursive function that applies the validators to the
// configuration, checking the current node before descending into children.
// The first failing validator stops the walk.
func (v ValidatorImpl) validate(root *map[string]interface{}, data any, path tree.Path, validators map[tree.Path]ValidatorFunc) error {
	// Run every validator whose pattern matches the current path.
	for pattern, check := range validators {
		if !path.Matches(pattern) {
			continue
		}
		if err := check(root, data, path); err != nil {
			return err
		}
	}
	// Recurse into child nodes.
	switch node := data.(type) {
	case map[string]interface{}:
		for key, child := range node {
			if err := v.validate(root, child, path.Next(key), validators); err != nil {
				return err
			}
		}
	case []interface{}:
		for i, child := range node {
			if err := v.validate(root, child, path.Next(fmt.Sprintf("[%d]", i)), validators); err != nil {
				return err
			}
		}
	}
	return nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package jobtypes
import (
"encoding/json"
"errors"
"fmt"
"github.com/libp2p/go-libp2p/core/peer"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/lib/did"
"gitlab.com/nunet/device-management-service/types"
)
// BidRequest is a versioned bid request
type BidRequest struct {
	V1 *BidRequestV1
}

// BidRequestV1 is v1 of bid requests for a node to use for deployment
type BidRequestV1 struct {
	// NodeID is the unique identifier for a node, within the context of an ensemble
	NodeID string `json:"node_id"`
	// Executors list of required executors to support the allocation(s)
	Executors []AllocationExecutor `json:"executors"`
	// Resources (aggregate) required hardware resources
	Resources types.Resources `json:"resources"`
	// Location is the node location constraints
	Location LocationConstraints `json:"location,omitempty"`
	// PublicPorts describes the node's public port requirements.
	PublicPorts struct {
		Static  []int `json:"static,omitempty"`  // statically configured public ports
		Dynamic int   `json:"dynamic,omitempty"` // number of dynamic ports
	} `json:"public_ports,omitempty"`
	// GeneralRequirements holds miscellaneous capability requirements.
	GeneralRequirements struct {
		PrivilegedDocker bool `json:"privileged_docker,omitempty"`
	} `json:"general_requirements,omitempty"`
	// Contracts are the contract(s) attached to a bid request, keyed by name.
	Contracts map[string]types.ContractConfig `json:"contracts,omitempty"`
}
// Bid is the version struct for Bids in response to a bid request
type Bid struct {
	V1 *BidV1
}

// BidV1 is v1 of the bid structure
// XXX: note that the json tags are camel case to be compatible with bid structure version before introducing
// the tags. Otherwise signature verification would fail.
type BidV1 struct {
	EnsembleID string                          `json:"EnsembleID"`           // unique identifier for the ensemble
	NodeID     string                          `json:"NodeID"`               // unique identifier for a node; matches the id of the BidRequest to which this bid pertains
	Peer       string                          `json:"Peer"`                 // the peer ID of the node
	Location   Location                        `json:"Location"`             // the location of the node
	Handle     actor.Handle                    `json:"Handle"`               // the handle of the actor submitting the bid
	PubAddress string                          `json:"PubAddress,omitempty"` // observed public address of the node
	Contracts  map[string]types.ContractConfig `json:"Contracts,omitempty"`  // contracts attached to the bid, keyed by name
	Signature  []byte                          `json:"Signature"`            // signature over the payload produced by SignatureData
	PromiseBid bool                            `json:"PromiseBid,omitempty"` // NOTE(review): semantics not evident from this file — confirm with bidding flow
}

// bidPrefix is prepended to the marshaled bid before signing and
// verification (see SignatureData).
const bidPrefix = "dms-bid-"
// EnsembleBidRequest is a request for a bids pertaining to an ensemble
//
// Note: At the moment, we embed a bid request for each node
// This is fine for small deployments, and a small network, which is what we have.
// For large deployments however, this won't scale and we will have to create aggregate
// bid requests for related group of nodes and also handle them with bid request
// aggregators who control multiple nodes.
type EnsembleBidRequest struct {
	// ID is the unique identifier of an ensemble (in the context of the orchestrator)
	ID string `json:"id"`
	// Request is the list of node bid requests
	Request []BidRequest `json:"request"`
	// Nonce is a sequential number for each request sent out
	Nonce uint64 `json:"nonce"`
	// PeerExclusion is the list of peers to exclude from bidding
	PeerExclusion []string `json:"peer_exclusion,omitempty"`
}

// Contracts returns the contracts attached to the bid.
// NOTE(review): panics if b.V1 is nil, like the other V1 accessors below.
func (b *Bid) Contracts() map[string]types.ContractConfig {
	return b.V1.Contracts
}
// Validate checks that the ensemble bid request is well formed: non-nil, with
// a non-empty ensemble id and at least one node bid request.
func (b *EnsembleBidRequest) Validate() error {
	switch {
	case b == nil:
		return errors.New("ensemble bid request is nil")
	case b.ID == "":
		return errors.New("ensemble id is empty")
	case len(b.Request) == 0:
		return errors.New("ensemble with empty requests")
	default:
		return nil
	}
}
// SignatureData returns the canonical byte payload that is signed for a bid:
// the constant bidPrefix followed by the JSON encoding of the bid with its
// Signature field cleared.
func (b *Bid) SignatureData() ([]byte, error) {
	// Work on a copy so the caller's signature is left intact.
	unsigned := *b.V1
	unsigned.Signature = nil
	encoded, err := json.Marshal(&unsigned)
	if err != nil {
		return nil, fmt.Errorf("signature data: %w", err)
	}
	return append([]byte(bidPrefix), encoded...), nil
}
// Sign signs the bid's canonical payload with the given key provider and
// stores the signature on the bid. The underlying errors are now wrapped with
// %w instead of being silently discarded, so callers can inspect the cause.
func (b *Bid) Sign(key did.Provider) error {
	data, err := b.SignatureData()
	if err != nil {
		return fmt.Errorf("unable to create bid signature data: %w", err)
	}
	sig, err := key.Sign(data)
	if err != nil {
		return fmt.Errorf("unable to sign the bid: %w", err)
	}
	b.V1.Signature = sig
	return nil
}
// Validate verifies the bid: V1 must be present, the peer id must decode, and
// the signature must verify against the public key extracted from the peer id
// over the canonical SignatureData payload. The SignatureData error is now
// wrapped with %w instead of being discarded.
func (b *Bid) Validate() error {
	if b.V1 == nil {
		return fmt.Errorf("bid V1 is nil")
	}
	p, err := peer.Decode(b.V1.Peer)
	if err != nil {
		return fmt.Errorf("failed to decode bid's peer id: %w", err)
	}
	pubKey, err := p.ExtractPublicKey()
	if err != nil {
		return fmt.Errorf("failed to extract public key: %w", err)
	}
	signData, err := b.SignatureData()
	if err != nil {
		return fmt.Errorf("unable to get bid signature data: %w", err)
	}
	ok, err := pubKey.Verify(signData, b.V1.Signature)
	if err != nil {
		return fmt.Errorf("failed to verify signature: %w", err)
	}
	if !ok {
		return errors.New("signature verification failed")
	}
	return nil
}
// EnsembleID returns the ensemble identifier from the v1 payload.
// All V1 accessors assume b.V1 is non-nil (see Validate).
func (b *Bid) EnsembleID() string {
	return b.V1.EnsembleID
}

// NodeID returns the node identifier the bid pertains to.
func (b *Bid) NodeID() string {
	return b.V1.NodeID
}

// Peer returns the peer ID string of the bidding node.
func (b *Bid) Peer() string {
	return b.V1.Peer
}

// Handle returns the actor handle that submitted the bid.
func (b *Bid) Handle() actor.Handle {
	return b.V1.Handle
}

// Location returns the node location reported in the bid.
func (b *Bid) Location() Location {
	return b.V1.Location
}

// PubAddress returns the observed public address of the node.
func (b *Bid) PubAddress() string {
	return b.V1.PubAddress
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package jobtypes
import (
"encoding/json"
"errors"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils"
)
type (
	// AllocationExecutor names the runtime used to execute an allocation.
	AllocationExecutor string
	// AllocationType names the lifecycle category of an allocation.
	AllocationType string
)

const (
	// Executor types define the runtime environment for allocations
	ExecutorFirecracker AllocationExecutor = "firecracker" // Firecracker VM-based execution
	ExecutorDocker      AllocationExecutor = "docker"      // Docker container-based execution
	ExecutorNull        AllocationExecutor = "null"        // Null executor for testing
	// AllocationType defines the lifecycle behavior of the allocation
	AllocationTypeService AllocationType = "service" // Long-running process that should restart on failure
	AllocationTypeTask    AllocationType = "task"    // One-off job that runs to completion
)
// EnsembleConfig is the versioned structure that contains the ensemble configuration
type EnsembleConfig struct {
	V1 *EnsembleConfigV1 `json:"v1"`
}

// EnsembleConfigV1 is version 1 of the configuration for an ensemble
type EnsembleConfigV1 struct {
	EscalationStrategy EscalationStrategy              `json:"escalation_strategy" yaml:"escalation_strategy"`           // escalation strategy (redeploy|teardown)
	Allocations        map[string]AllocationConfig     `json:"allocations" yaml:"allocations"`                           // (named) allocations in the ensemble
	Nodes              map[string]NodeConfig           `json:"nodes" yaml:"nodes"`                                       // (named) nodes in the ensemble
	Edges              []EdgeConstraint                `json:"edges,omitempty" yaml:"edges,omitempty"`                   // network edge constraints
	Supervisor         SupervisorConfig                `json:"supervisor,omitempty" yaml:"supervisor,omitempty"`         // supervision structure
	Keys               map[string]string               `json:"keys,omitempty" yaml:"keys,omitempty"`                     // (named) ssh public keys relevant to the allocation
	Scripts            map[string][]byte               `json:"scripts,omitempty" yaml:"scripts,omitempty"`               // (named) provisioning scripts
	Subnet             SubnetConfig                    `json:"subnet,omitempty" yaml:"subnet,omitempty"`                 // subnet config
	ExcludePeers       []string                        `json:"exclude_peers,omitempty" yaml:"exclude_peers,omitempty"`   // list of peers to not deploy on
	Metadata           map[string]string               `json:"metadata,omitempty" yaml:"metadata,omitempty"`             // metadata (arbitrary key-value pairs)
	Contracts          map[string]types.ContractConfig `json:"contracts,omitempty" yaml:"contracts,omitempty"`           // (named) contracts between parties
}
// EscalationStrategy is the name of an ensemble-level failure escalation
// strategy (redeploy|teardown).
type EscalationStrategy string

const (
	EscalationStrategyRedeploy EscalationStrategy = "redeploy" // redeploy the ensemble
	EscalationStrategyTeardown EscalationStrategy = "teardown" // tear the ensemble down
)
// AllocationConfig is the configuration of an allocation
type AllocationConfig struct {
	Executor AllocationExecutor `json:"executor" yaml:"executor"` // the executor of the allocation
	Type AllocationType `json:"type" yaml:"type"` // the type of allocation (service vs task)
	Resources types.Resources `json:"resources" yaml:"resources"` // the HW resources required by the allocation
	Execution types.SpecConfig `json:"execution" yaml:"execution"` // the allocation execution configuration
	DNSName string `json:"dns_name,omitempty" yaml:"dns_name,omitempty"` // the internal DNS name of the allocation
	Keys []types.AllocationKey `json:"keys,omitempty" yaml:"keys,omitempty"` // names of the authorized ssh keys for the allocation
	Provision []string `json:"provision,omitempty" yaml:"provision,omitempty"` // names of provisioning scripts to run (in order)
	HealthCheck types.HealthCheckManifest `json:"healthcheck,omitempty" yaml:"healthcheck,omitempty"` // name of the health check script
	Volume []types.VolumeConfig `json:"volume,omitempty" yaml:"volume,omitempty"` // unified storage configuration (optional)
	FailureRecovery AllocationFailureRecovery `json:"failure_recovery,omitempty" yaml:"failure_recovery,omitempty"` // failure recovery (stay_down|one_for_one|one_for_all|rest_for_one)
	DependsOn []string `json:"depends_on,omitempty" yaml:"depends_on,omitempty"` // list of allocations that this allocation depends on
}
// AllocationFailureRecovery names the failure recovery strategy applied to an
// allocation; valid values are the constants below (see also the
// AllocationConfig.FailureRecovery field comment).
type AllocationFailureRecovery string

const (
	AllocationFailureRecoveryStayDown AllocationFailureRecovery = "stay_down"
	AllocationFailureRecoveryOneForOne AllocationFailureRecovery = "one_for_one"
	AllocationFailureRecoveryOneForAll AllocationFailureRecovery = "one_for_all"
	AllocationFailureRecoveryRestForOne AllocationFailureRecovery = "rest_for_one"
)
// NodeConfig is the configuration of a distinct DMS node
type NodeConfig struct {
	Allocations []string `json:"allocations" yaml:"allocations"` // list of allocation IDs
	Ports []PortConfig `json:"ports,omitempty" yaml:"ports,omitempty"` // list of port mappings
	Location LocationConstraints `json:"location,omitempty" yaml:"location,omitempty"` // location constraints
	Peer string `json:"peer,omitempty" yaml:"peer,omitempty"` // peer ID to use for this node
	Redundancy int `json:"redundancy,omitempty" yaml:"redundancy,omitempty"` // number of redundant (standby) nodes derived from this one
	FailureRecovery NodeFailureRecovery `json:"failure_recovery,omitempty" yaml:"failure_recovery,omitempty"` // failure recovery (stay_down|restart|redeploy)
	// TODO contract information
}
// NodeFailureRecovery is the failure recovery strategy for a node
type NodeFailureRecovery string

const (
	NodeFailureRecoveryStayDown NodeFailureRecovery = "stay_down"
	NodeFailureRecoveryRestart NodeFailureRecovery = "restart"
	NodeFailureRecoveryRedeploy NodeFailureRecovery = "redeploy"
)
// LocationConstraints provides the node location placement constraints.
// A non-empty Accept list acts as an allow-list and takes precedence over
// Reject (see Location.Satisfies).
type LocationConstraints struct {
	Accept []Location `json:"accept,omitempty" yaml:"accept,omitempty"` // list of accepted locations
	Reject []Location `json:"reject,omitempty" yaml:"reject,omitempty"` // list of rejected locations
}
// Location is a geographical location on Planet Earth.
// When used inside LocationConstraints, empty fields (other than Continent)
// act as wildcards during matching (see Equal).
type Location struct {
	Continent string `json:"continent,omitempty" yaml:"continent,omitempty"` // geographical region
	Country string `json:"country,omitempty" yaml:"country,omitempty"` // country code
	City string `json:"city,omitempty" yaml:"city,omitempty"` // city name
	ASN uint `json:"asn,omitempty" yaml:"asn,omitempty"` // autonomous system number
	ISP string `json:"isp,omitempty" yaml:"isp,omitempty"` // internet service provider
}
// PortConfig is the configuration for a port mapping a public port to a private port
// in an allocation
type PortConfig struct {
	Public int `json:"public" yaml:"public"` // public port number
	Private int `json:"private" yaml:"private"` // private port number
	Allocation string `json:"allocation" yaml:"allocation"` // allocation ID
}
// EdgeConstraint is a constraint for a network edge between two nodes
type EdgeConstraint struct {
	S string `json:"s" yaml:"s"` // source node ID
	T string `json:"t" yaml:"t"` // target node ID
	RTT uint `json:"rtt,omitempty" yaml:"rtt,omitempty"` // round trip time in milliseconds
	BW uint `json:"bw,omitempty" yaml:"bw,omitempty"` // bandwidth in bits per second
}
// SupervisorConfig is the supervisory structure configuration for the ensemble.
// Supervisors form a tree via Children.
type SupervisorConfig struct {
	Strategy SupervisorStrategy `json:"strategy,omitempty" yaml:"strategy,omitempty"` // supervision strategy
	Allocations []string `json:"allocations,omitempty" yaml:"allocations,omitempty"` // list of allocation IDs
	Children []SupervisorConfig `json:"children,omitempty" yaml:"children,omitempty"` // list of child supervisors
}
// SupervisorStrategy is the name of a supervision strategy
type SupervisorStrategy string
// SubnetConfig holds subnet-related options for the ensemble.
type SubnetConfig struct {
	Join bool `json:"join,omitempty" yaml:"join,omitempty"` // for orchestrator to join the subnet
}
// Supported supervision strategies.
const (
	StrategyOneForOne SupervisorStrategy = "OneForOne"
	StrategyAllForOne SupervisorStrategy = "AllForOne"
	StrategyRestForOne SupervisorStrategy = "RestForOne"
)
// Validate validates the ensemble configuration: both the wrapper and its
// V1 payload must be non-nil. The other accessor methods assume a config
// that has passed this check.
func (e *EnsembleConfig) Validate() error {
	if e != nil && e.V1 != nil {
		return nil
	}
	return errors.New("invalid ensemble config")
}
// Contracts returns all (named) contracts declared in the configuration.
// Panics if V1 is nil — call Validate first.
func (e *EnsembleConfig) Contracts() map[string]types.ContractConfig {
	return e.V1.Contracts
}
// Contract looks up a single contract by name; the boolean reports presence.
func (e *EnsembleConfig) Contract(contractID string) (types.ContractConfig, bool) {
	c, ok := e.V1.Contracts[contractID]
	return c, ok
}
// Allocations returns all (named) allocation configurations.
func (e *EnsembleConfig) Allocations() map[string]AllocationConfig {
	return e.V1.Allocations
}
// Allocation looks up a single allocation config by name; the boolean reports presence.
func (e *EnsembleConfig) Allocation(name string) (AllocationConfig, bool) {
	a, ok := e.V1.Allocations[name]
	return a, ok
}
// buildStandbyNodes derives standby node configurations for every node that
// declares a non-zero Redundancy. Standby IDs come from nodeIDGenerator; if
// generating an ID fails that standby is silently skipped. Each derived config
// copies the primary's allocations, ports, location and recovery strategy but
// carries no Redundancy or Peer of its own.
func buildStandbyNodes(nodes map[string]NodeConfig, nodeIDGenerator types.NodeIDGenerator) map[string]NodeConfig {
	derived := make(map[string]NodeConfig)
	for primaryID, primary := range nodes {
		// Loop body never runs when Redundancy is zero.
		for idx := 1; idx <= primary.Redundancy; idx++ {
			standbyID, err := nodeIDGenerator.GenerateStandbyNodeID(primaryID, idx)
			if err != nil {
				// cannot derive an ID for this standby; skip it
				continue
			}
			derived[standbyID] = NodeConfig{
				Allocations:     primary.Allocations,
				Ports:           primary.Ports,
				Location:        primary.Location,
				FailureRecovery: primary.FailureRecovery,
			}
		}
	}
	return derived
}
// Nodes returns a map of all nodes including standby nodes (using default generator).
// See NodesWithGenerator for the merge semantics.
func (e *EnsembleConfig) Nodes() map[string]NodeConfig {
	return e.NodesWithGenerator(types.NewDefaultNodeIDGenerator())
}
// NodesWithGenerator returns a map of all nodes — the primaries declared in
// the configuration plus the standby nodes derived from their redundancy
// settings — using the provided node-ID generator for the standby names.
func (e *EnsembleConfig) NodesWithGenerator(nodeIDGenerator types.NodeIDGenerator) map[string]NodeConfig {
	merged := make(map[string]NodeConfig, len(e.V1.Nodes))
	// Primaries first.
	for id, cfg := range e.V1.Nodes {
		merged[id] = cfg
	}
	// Then the derived standbys.
	for id, cfg := range buildStandbyNodes(e.V1.Nodes, nodeIDGenerator) {
		merged[id] = cfg
	}
	return merged
}
// PrimaryNodes returns only the primary nodes (no standby nodes).
func (e *EnsembleConfig) PrimaryNodes() map[string]NodeConfig {
	return e.V1.Nodes
}
// Node looks up a node — primary or standby — by ID using the default node-ID
// generator. See NodeWithGenerator.
func (e *EnsembleConfig) Node(nodeID string) (NodeConfig, bool) {
	return e.NodeWithGenerator(nodeID, types.NewDefaultNodeIDGenerator())
}
// NodeWithGenerator looks up a node configuration by ID. Primary nodes are
// resolved directly from the configuration; standby IDs are parsed with the
// provided generator and, when valid (index within 1..Redundancy of an
// existing primary), resolve to a copy of the primary's configuration.
func (e *EnsembleConfig) NodeWithGenerator(nodeID string, nodeIDGenerator types.NodeIDGenerator) (NodeConfig, bool) {
	// Fast path: a primary node stored directly in the configuration.
	if cfg, found := e.V1.Nodes[nodeID]; found {
		return cfg, true
	}
	// Otherwise the ID may name a standby derived from some primary.
	isStandby, primaryID, idx, err := nodeIDGenerator.ParseNodeID(nodeID)
	if err != nil || !isStandby {
		return NodeConfig{}, false
	}
	primary, found := e.V1.Nodes[primaryID]
	if !found || idx < 1 || idx > primary.Redundancy {
		return NodeConfig{}, false
	}
	// A standby shares its primary's configuration.
	return primary, true
}
// AllocationsForNode returns the allocation configs referenced by the named
// node (primary or standby). Unknown nodes and dangling allocation names
// yield an empty map.
func (e *EnsembleConfig) AllocationsForNode(node string) map[string]AllocationConfig {
	result := make(map[string]AllocationConfig)
	cfg, found := e.Node(node)
	if !found {
		return result
	}
	for _, name := range cfg.Allocations {
		alloc, exists := e.Allocation(name)
		if exists {
			result[name] = alloc
		}
	}
	return result
}
// PortsForAllocation collects every port mapping across all nodes (including
// standbys) that targets the named allocation. Returns nil when none match.
func (e *EnsembleConfig) PortsForAllocation(allocation string) []PortConfig {
	var matched []PortConfig
	for _, nodeCfg := range e.Nodes() {
		for _, pc := range nodeCfg.Ports {
			if pc.Allocation != allocation {
				continue
			}
			matched = append(matched, pc)
		}
	}
	return matched
}
// EdgeConstraints returns the network edge constraints of the ensemble.
func (e *EnsembleConfig) EdgeConstraints() []EdgeConstraint {
	return e.V1.Edges
}
// Subnet returns the subnet configuration of the ensemble.
func (e *EnsembleConfig) Subnet() SubnetConfig {
	return e.V1.Subnet
}
// AddNodeAndAllocations inserts the node under the given name and registers
// those allocation configs from allocs that the node actually references in
// its Allocations list; unreferenced entries are ignored.
func (e *EnsembleConfig) AddNodeAndAllocations(
	name string, node NodeConfig,
	allocs map[string]AllocationConfig,
) {
	e.V1.Nodes[name] = node
	for allocName, allocCfg := range allocs {
		// Keep only allocations the node declares.
		if !utils.SliceContains(node.Allocations, allocName) {
			continue
		}
		e.V1.Allocations[allocName] = allocCfg
	}
}
// RemoveNodeAndAllocations deletes the named node from the configuration
// together with every allocation config the node referenced.
//
// NOTE(review): the lookup goes through e.Node, which also resolves standby
// node IDs to their primary's config. For a standby ID the delete on V1.Nodes
// is a no-op while the primary's allocations are still removed — confirm this
// is intended. Allocations shared with other nodes are removed as well.
func (e *EnsembleConfig) RemoveNodeAndAllocations(
	nodeName string,
) {
	if nodeCfg, ok := e.Node(nodeName); ok {
		delete(e.V1.Nodes, nodeName)
		for _, alloc := range nodeCfg.Allocations {
			delete(e.V1.Allocations, alloc)
		}
	}
}
// Clone returns a deep copy of the configuration produced by a JSON
// marshal/unmarshal round trip. Errors are logged, not propagated: on a
// marshal failure the zero value is returned, and on an unmarshal failure a
// partially populated copy may be returned.
func (e *EnsembleConfig) Clone() EnsembleConfig {
	var copied EnsembleConfig
	data, err := json.Marshal(e)
	if err != nil {
		log.Errorf("error marshaling ensemble config: %s", err)
		return copied
	}
	if err = json.Unmarshal(data, &copied); err != nil {
		log.Errorf("unmarshaling ensemble config: %s", err)
	}
	return copied
}
// Equal reports whether the receiver matches other, treating the receiver's
// empty/zero fields (except Continent) as wildcards: Country, City, ASN and
// ISP are only compared when set on the receiver, while Continent must always
// match exactly. The relation is therefore not symmetric — it behaves as a
// constraint match rather than strict equality (see Satisfies).
func (l *Location) Equal(other Location) bool {
	switch {
	case l.Continent != other.Continent:
		return false
	case l.Country != "" && l.Country != other.Country:
		return false
	case l.City != "" && l.City != other.City:
		return false
	case l.ASN > 0 && l.ASN != other.ASN:
		return false
	case l.ISP != "" && l.ISP != other.ISP:
		return false
	default:
		return true
	}
}
// Satisfies reports whether location l passes the given constraints. A
// non-empty Accept list acts as an allow-list — l must match at least one
// entry and Reject is ignored. With no Accept entries, any match in the
// Reject list disqualifies l.
func (l Location) Satisfies(constraints LocationConstraints) bool {
	// Without an allow-list only the Reject entries apply.
	if len(constraints.Accept) == 0 {
		for _, rejected := range constraints.Reject {
			if rejected.Equal(l) {
				return false
			}
		}
		return true
	}
	// Accept entries take precedence: l must match one of them.
	for _, accepted := range constraints.Accept {
		if accepted.Equal(l) {
			return true
		}
	}
	return false
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package jobtypes
import (
"encoding/json"
"errors"
"fmt"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/types"
)
// ErrAllocationNotFound is returned when a named allocation does not exist in
// the manifest.
var ErrAllocationNotFound = errors.New("allocation not found")
// NodeStatus represents the current status of a node
type NodeStatus string

const (
	NodeStatusActive  NodeStatus = "active"  // Node is active and running allocations
	NodeStatusStandby NodeStatus = "standby" // Node is a standby waiting to be activated
	// NodeStatusFailed NodeStatus = "failed" // Node has failed
)

// RedundancyRole distinguishes a primary node from its standbys.
type RedundancyRole string

// Declared as constants (previously mutable package-level vars): the role
// names are fixed and must not be reassignable at runtime.
const (
	RolePrimary RedundancyRole = "primary"
	RoleStandby RedundancyRole = "standby"
)
// EnsembleManifest describes a deployed ensemble: the orchestrator that
// controls it plus the concrete allocations, nodes, subnet and contracts that
// realize the configuration.
type EnsembleManifest struct {
	ID string `json:"id"` // ensemble globally unique id
	Metadata map[string]string `json:"metadata,omitempty"` // metadata
	Orchestrator actor.Handle `json:"orchestrator"` // orchestrator actor
	Allocations map[string]AllocationManifest `json:"allocations"` // allocation name -> manifest
	Nodes map[string]NodeManifest `json:"nodes"` // node name -> manifest
	Subnet SubnetConfig `json:"subnet"` // subnet configurations
	Contracts map[string]ContractManifest `json:"contracts"` // contract name -> manifest
}
// ContractManifest identifies a deployed contract between parties.
type ContractManifest struct {
	ID string `json:"id"` // contract unique id
	DID string `json:"did"` // DID of the contract
	Host string `json:"host"` // host of the contract
}
// AllocationManifest describes a single deployed allocation and its runtime
// addressing (DNS name, private address, port mapping) plus redundancy state.
type AllocationManifest struct {
	ID string `json:"id"` // allocation unique id
	NodeID string `json:"node_id"` // allocation node
	Type AllocationType `json:"type"` // allocation type
	Handle actor.Handle `json:"handle"` // handle of the allocation control actor
	DNSName string `json:"dns_name"` // (internal) DNS name of the allocation
	PrivAddr string `json:"priv_addr"` // (VPN) private IP address of the allocation peer
	Ports map[int]int `json:"ports,omitempty"` // port mapping, public -> private
	Healthcheck types.HealthCheckManifest `json:"healthcheck"` // healthcheck configuration
	Status AllocationStatus `json:"status"` // current status of the allocation
	RedundancyGroup string `json:"redundancy_group,omitempty"` // base allocation name this belongs to
	IsStandby bool `json:"is_standby"` // whether this is a standby allocation
}
// NodeManifest describes a deployed node, its peer addressing and its place in
// the primary/standby redundancy structure.
type NodeManifest struct {
	ID string `json:"id"` // node unique id
	Peer string `json:"peer,omitempty"` // peer where the node is running
	Handle actor.Handle `json:"handle"` // handle of the control actor for the node
	PubAddress []string `json:"pub_address"` // public IP4/6 address of the node peer
	Location Location `json:"location"` // location of the peer
	Allocations []string `json:"allocations"` // allocations in the node
	RedundancyRole RedundancyRole `json:"redundancy_role,omitempty"` // "primary" or "standby"
	PrimaryNode string `json:"primary_node,omitempty"` // ID of primary node if this is a standby
	StandbyNodes []string `json:"standby_nodes,omitempty"` // IDs of standby nodes if this is a primary
	StandbyIndex int `json:"standby_index,omitempty"` // Index of this standby (1, 2, etc.)
	Status NodeStatus `json:"status"` // current status of the node (active, standby, failed)
}
// Clone returns a deep copy of the manifest produced by a JSON round trip.
// Errors are logged rather than returned: on a marshal failure the zero value
// comes back; on an unmarshal failure a partial copy may come back.
func (mf *EnsembleManifest) Clone() EnsembleManifest {
	var copied EnsembleManifest
	data, err := json.Marshal(mf)
	if err != nil {
		log.Errorf("error marshaling ensemble manifest: %s", err)
		return copied
	}
	if err = json.Unmarshal(data, &copied); err != nil {
		log.Errorf("error unmarshaling ensemble manifest: %s", err)
	}
	return copied
}
// UpdateAllocation applies the provided updater function to the specified
// allocation and writes the updated copy back into the map. It returns
// ErrAllocationNotFound when no allocation with that name exists.
func (mf *EnsembleManifest) UpdateAllocation(name string, update func(*AllocationManifest)) error {
	alloc, found := mf.Allocations[name]
	if !found {
		return ErrAllocationNotFound
	}
	update(&alloc)
	mf.Allocations[name] = alloc
	return nil
}
// IsTerminatedTask reports whether the named allocation is a task that is no
// longer in the running state. Unknown names and service-type allocations
// yield false.
func (mf *EnsembleManifest) IsTerminatedTask(name string) bool {
	alloc, found := mf.Allocations[name]
	return found && alloc.Type == AllocationTypeTask && alloc.Status != AllocationRunning
}
// Allocation looks up an allocation manifest by its map key; the boolean
// reports presence. (The parameter is named id but is the allocation's map
// name/key — presumably the allocation name; confirm at call sites.)
func (mf *EnsembleManifest) Allocation(id string) (AllocationManifest, bool) {
	a, ok := mf.Allocations[id]
	return a, ok
}
// Node looks up a node manifest by name; the boolean reports presence.
func (mf *EnsembleManifest) Node(name string) (NodeManifest, bool) {
	n, ok := mf.Nodes[name]
	return n, ok
}
// JSON renders the manifest as indented JSON, wrapping any marshaling error
// with context.
func (mf *EnsembleManifest) JSON() ([]byte, error) {
	rendered, err := json.MarshalIndent(mf, "", " ")
	if err != nil {
		return nil, fmt.Errorf("unable to marshal manifest: %w", err)
	}
	return rendered, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package jobtypes
import (
"time"
"gitlab.com/nunet/device-management-service/types"
)
// DeploymentStatus is the lifecycle state of an ensemble deployment.
type DeploymentStatus int

// Deployment lifecycle states, in the order they normally progress.
const (
	DeploymentStatusPreparing DeploymentStatus = iota
	DeploymentStatusGenerating
	DeploymentStatusCommitting
	DeploymentStatusProvisioning
	DeploymentStatusRunning
	DeploymentStatusUpdating
	DeploymentStatusFailed
	DeploymentStatusShuttingDown
	DeploymentStatusCompleted
)

// String returns the human-readable name of the status, or "Unknown" for
// values outside the declared range.
func (d DeploymentStatus) String() string {
	// Indexed by the iota values above, which are contiguous from zero.
	names := [...]string{
		DeploymentStatusPreparing:    "Preparing",
		DeploymentStatusGenerating:   "Generating",
		DeploymentStatusCommitting:   "Committing",
		DeploymentStatusProvisioning: "Provisioning",
		DeploymentStatusRunning:      "Running",
		DeploymentStatusUpdating:     "Updating",
		DeploymentStatusFailed:       "Failed",
		DeploymentStatusShuttingDown: "ShuttingDown",
		DeploymentStatusCompleted:    "Completed",
	}
	if d < 0 || int(d) >= len(names) {
		return "Unknown"
	}
	return names[d]
}
// OrchestratorView is the persisted state of an orchestrator's deployment:
// the configuration it was given, the manifest it produced, and progress and
// failure bookkeeping.
type OrchestratorView struct {
	types.BaseDBModel
	OrchestratorID string // ID of the owning orchestrator
	Cfg EnsembleConfig // ensemble configuration being deployed
	Manifest EnsembleManifest // manifest of the deployed ensemble
	SubnetManifest SubnetManifest // subnet manifest for the deployment
	Status DeploymentStatus // current deployment lifecycle status
	DeploymentSnapshot DeploymentSnapshot // in-flight commit state (see below)
	PrivKey []byte // private key material for the deployment
	// Fields for persistence
	CompletedAt *time.Time // nil if not completed
	ErrorMessage string // for failed deployments
}
// DeploymentSnapshot captures the transient state of a deployment in flight so
// it can be resumed or inspected.
type DeploymentSnapshot struct {
	// candidates keeps state of candidates while committing.
	Candidates map[string]Bid
	// Expiry is the time passed as an argument when calling Deploy()
	Expiry time.Time
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package jobtypes
// EnsembleCfgReader provides read-only access to an EnsembleConfig
type EnsembleCfgReader struct {
	cfg EnsembleConfig // private deep copy; never exposed for mutation
}
// NewEnsembleCfgReader creates a new reader with a deep copy of the config
// (via EnsembleConfig.Clone), isolating the reader from later mutations of
// the original.
func NewEnsembleCfgReader(cfg EnsembleConfig) EnsembleCfgReader {
	return EnsembleCfgReader{cfg: cfg.Clone()}
}
// Read returns the Reader's config which was cloned
// from another payload by the constructor.
func (r EnsembleCfgReader) Read() EnsembleConfig {
	return r.cfg
}
// ManifestReader provides read-only access to an EnsembleManifest
type ManifestReader struct {
	manifest EnsembleManifest // private deep copy; never exposed for mutation
}
// NewManifestReader creates a new reader with a deep copy of the manifest
// (via EnsembleManifest.Clone), isolating the reader from later mutations of
// the original.
func NewManifestReader(manifest EnsembleManifest) ManifestReader {
	return ManifestReader{manifest: manifest.Clone()}
}
// Read returns the Reader's config which was cloned
// from another payload by the constructor.
func (r ManifestReader) Read() EnsembleManifest {
	return r.manifest
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package dms
import (
logging "github.com/ipfs/go-log/v2"
"go.uber.org/multierr"
)
var log = logging.Logger("dms")
// silenceConnLogs silences logs coming from libp2p and other imported
// libraries, as they're enabled by default, by raising each subsystem's log
// level to "panic". It returns the multierr-combined result of every
// SetLogLevel/SetLogLevelRegex call.
//
// Fixes: the doc comment previously referred to a non-existent
// silenceLibp2pLogging; the debug message had a typo ("silecing"); and the
// long run of repeated SetLogLevel calls is collapsed into data-driven loops.
//
// TODO: move this to observability?
func silenceConnLogs() error {
	log.Debug("silencing libp2p logging")
	var errs error
	// Subsystems silenced by exact logger name.
	subsystems := []string{
		"libp2p",
		"swarm2",
		"basichost",
		"reuseport-transport",
		"pubsub",
		"p2p-config",
		"routedhost",
		"relay",
		"autorelay",
		"p2p-circuit",
		"autonat",
		"autonatv2",
		"upgrader",
		"p2p-holepunch",
		"rcmgr",
		// dms-specific connections logs
		"node.conn",
	}
	for _, name := range subsystems {
		errs = multierr.Append(errs, logging.SetLogLevel(name, "panic"))
	}
	// Logger families silenced by regex.
	for _, pattern := range []string{"dht/*", "net/*"} {
		errs = multierr.Append(errs, logging.SetLogLevelRegex(pattern, "panic"))
	}
	return errs
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package node
import (
"context"
"fmt"
"path/filepath"
"slices"
"sync"
"time"
"github.com/spf13/afero"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/dms/jobs"
jobtypes "gitlab.com/nunet/device-management-service/dms/jobs/types"
"gitlab.com/nunet/device-management-service/network"
netutils "gitlab.com/nunet/device-management-service/network/utils"
"gitlab.com/nunet/device-management-service/storage"
"gitlab.com/nunet/device-management-service/storage/volume"
"gitlab.com/nunet/device-management-service/tokenomics/eventhandler"
"gitlab.com/nunet/device-management-service/tokenomics/store"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils"
)
// Sentinel errors returned by the allocator.
// NOTE(review): these use fmt.Errorf with no format verbs — errors.New would
// be the conventional choice (would require importing "errors").
var (
	// ErrAllocationNotFound is returned when an allocation is not found but is expected.
	ErrAllocationNotFound = fmt.Errorf("allocation not found")
	// ErrDynamicPortsNotAvailable is returned when dynamic ports are not available for allocation.
	ErrDynamicPortsNotAvailable = fmt.Errorf("dynamic ports not available")
	// ErrPortsBusy is returned when the requested ports are already allocated.
	ErrPortsBusy = fmt.Errorf("ports are already allocated")
	// ErrResourcesNotAvailable is returned when the requested resources are not available.
	ErrResourcesNotAvailable = fmt.Errorf("resources not available")
	// ErrNoHardwareCapacity is returned when there is no capacity left on the hardware.
	ErrNoHardwareCapacity = fmt.Errorf("no capacity left on the hardware")
)
// TODO: move port allocator stuffs to other file
// portAllocator keeps track of port allocations and manages state.
// All state is guarded by lock; the lower-case helpers (allocate, release,
// getAvailablePorts) assume the caller already holds it.
type portAllocator struct {
	config PortConfig // allowed port range configuration
	lock sync.Mutex // guards allocations and reserved
	allocations map[string][]int // allocationID -> ports
	reserved map[int]struct{} // reserved ports
}
// newPortAllocator initializes a new portAllocator with a PortConfig.
// Both tracking maps start empty: an empty reserved set means every port in
// the configured range is initially available.
func newPortAllocator(config PortConfig) *portAllocator {
	pa := &portAllocator{
		config:      config,
		allocations: map[string][]int{},
		reserved:    map[int]struct{}{},
	}
	return pa
}
// allocate reserves a single port after validating it: the port must fall in
// the configured range, must not already be reserved by this allocator, and
// must currently be free on the host. The caller must hold pa.lock.
func (pa *portAllocator) allocate(port int) error {
	lo, hi := pa.config.AvailableRangeFrom, pa.config.AvailableRangeTo
	if port < lo || port > hi {
		return fmt.Errorf("port %d is outside allowed range [%d-%d]", port, lo, hi)
	}
	if _, taken := pa.reserved[port]; taken {
		return fmt.Errorf("port %d is already reserved", port)
	}
	if !netutils.IsFreePort(port) {
		return fmt.Errorf("port %d is not free", port)
	}
	pa.reserved[port] = struct{}{}
	return nil
}
// Allocate allocates the requested ports, associating them with an allocationID.
// If it's not possible to allocate one of the ports, an error is returned and no ports are allocated.
func (pa *portAllocator) Allocate(allocationID string, ports []int) error {
pa.lock.Lock()
defer pa.lock.Unlock()
if len(ports) == 0 {
return nil
}
for _, port := range ports {
if err := pa.allocate(port); err != nil {
pa.release(ports)
return fmt.Errorf("cannot allocate port %d: %w", port, err)
}
}
pa.allocations[allocationID] = ports
return nil
}
// getAvailablePorts scans the configured range and collects up to numPorts
// ports that are neither reserved by this allocator nor in use on the host.
// Fewer than numPorts entries may be returned when the range is exhausted.
// The caller must hold pa.lock.
func (pa *portAllocator) getAvailablePorts(numPorts int) []int {
	available := make([]int, 0, numPorts)
	for candidate := pa.config.AvailableRangeFrom; candidate <= pa.config.AvailableRangeTo; candidate++ {
		if len(available) == numPorts {
			break
		}
		// Skip ports reserved by this allocator.
		if _, taken := pa.reserved[candidate]; taken {
			continue
		}
		// Skip ports in use elsewhere on the system.
		if !netutils.IsFreePort(candidate) {
			continue
		}
		available = append(available, candidate)
	}
	return available
}
// AllocateRandom allocates the requested number of ports from the configured
// range and associates them with the allocation ID. On any failure nothing
// remains reserved for this call and an error is returned.
//
// Cleanup: the previous code called release on the candidate list before any
// candidate had been reserved — a guaranteed no-op (release skips unreserved
// ports) that has been removed.
func (pa *portAllocator) AllocateRandom(allocationID string, numPorts int) ([]int, error) {
	pa.lock.Lock()
	defer pa.lock.Unlock()
	if numPorts == 0 {
		return nil, fmt.Errorf("cannot allocate 0 ports")
	}
	candidates := pa.getAvailablePorts(numPorts)
	if len(candidates) != numPorts {
		// Nothing has been reserved yet (getAvailablePorts only scans),
		// so there is nothing to roll back.
		return nil, fmt.Errorf("failed to allocate %d ports", numPorts)
	}
	// Reserve the candidates; the lock held throughout guarantees they were
	// unreserved at scan time, so rolling back the whole list on failure only
	// affects ports reserved by this call (release skips unreserved ones).
	for _, port := range candidates {
		if err := pa.allocate(port); err != nil {
			pa.release(candidates)
			return nil, fmt.Errorf("failed to allocate port %d: %w", port, err)
		}
	}
	pa.allocations[allocationID] = candidates
	return candidates, nil
}
// release removes the given ports from the reserved set; ports that are not
// currently reserved are skipped. The caller must hold pa.lock.
// NOTE(review): ownership is not checked — callers must only pass ports they
// reserved themselves, otherwise another allocation's reservation is freed.
func (pa *portAllocator) release(ports []int) {
	for _, p := range ports {
		if _, ok := pa.reserved[p]; !ok {
			continue
		}
		delete(pa.reserved, p)
	}
}
// Release releases every port associated with the allocation ID and forgets
// the allocation. Unknown IDs are a no-op.
func (pa *portAllocator) Release(allocationID string) {
	pa.lock.Lock()
	defer pa.lock.Unlock()
	ports, found := pa.allocations[allocationID]
	if !found {
		return
	}
	pa.release(ports)
	delete(pa.allocations, allocationID)
}
// GetAllocation returns the allocated ports for a specific allocation ID.
// The returned slice aliases internal state; callers must not mutate it.
//
// Fix: takes pa.lock while reading the allocations map. Every writer
// (Allocate, AllocateRandom, Release) mutates the map under the lock, so the
// previous unlocked read was a data race.
func (pa *portAllocator) GetAllocation(allocationID string) ([]int, error) {
	pa.lock.Lock()
	defer pa.lock.Unlock()
	ports, exists := pa.allocations[allocationID]
	if !exists {
		return nil, fmt.Errorf("port allocation ID not found: %s", allocationID)
	}
	return ports, nil
}
// isAllocated checks if any of the given ports is already reserved.
// NOTE(review): reads pa.reserved without taking pa.lock — callers must hold
// the lock (call sites are outside this view; confirm).
func (pa *portAllocator) isAllocated(ports []int) bool {
	for _, port := range ports {
		if _, reserved := pa.reserved[port]; reserved {
			return true
		}
	}
	return false
}
// portsAvailable reports whether numPorts ports can currently be satisfied
// from the configured range.
func (pa *portAllocator) portsAvailable(numPorts int) bool {
	pa.lock.Lock()
	defer pa.lock.Unlock()
	return len(pa.getAvailablePorts(numPorts)) == numPorts
}
// Allocator is the interface for the node allocator.
// It is responsible for managing resources and allocations.
type Allocator interface {
	// Run starts the allocator.
	Run() error
	// Commit commits resources and ports for an allocation. The ports map is
	// keyed by public port; expiry is the commit's expiry deadline, cleared by
	// the allocator's background ticker (units/epoch defined by callers —
	// confirm outside this view).
	Commit(ctx context.Context,
		allocationID string,
		resources types.CommittedResources,
		ports map[int]int,
		numDynamicPorts int,
		expiry int64,
	) error
	// Uncommit uncommits resources and ports for an allocation.
	Uncommit(ctx context.Context, allocationID string) error
	// Allocate allocates resources and ports for an allocation.
	Allocate(
		ctx context.Context,
		allocationID string,
		allocType jobtypes.AllocationType,
		actr actor.Actor,
		orchestrator actor.Handle,
		job jobs.Job,
		executor types.Executor,
		contracts map[string]types.ContractConfig,
		contractEventHandler *eventhandler.EventHandler,
		deploymentID string,
	) (*jobs.Allocation, error)
	// Release releases allocated resources and ports for an allocation.
	Release(ctx context.Context, allocationID string) error
	// Stop stops the allocator.
	Stop(ctx context.Context) error
	// CheckAvailability checks if the requested resources and ports are available.
	CheckAvailability(ports []int, numDynamicPorts int, resources types.Resources) error
	// GetAllocations returns all allocations.
	GetAllocations() map[string]*jobs.Allocation
	// GetAllocation returns a specific allocation.
	GetAllocation(allocationID string) (*jobs.Allocation, error)
}
// allocator is the implementation of the Allocator interface.
type allocator struct {
	network network.Network // network layer (subnet create/destroy)
	ports *portAllocator // port reservation state
	resources types.ResourceManager // resource commit/uncommit
	hardware types.HardwareManager // hardware capacity checks
	allocations map[string]*jobs.Allocation // live allocations by ID
	monitoredEnsembles map[string]struct{} // ensembles watched for cleanup
	commits map[string]int64 // allocationID -> commit expiry
	workDir string // working directory for allocations
	hostID string // identity of this host
	lock sync.Mutex // guards the mutable maps above
	fs afero.Afero // filesystem abstraction
	ctx context.Context // lifetime of background goroutines
	cancel context.CancelFunc // stops background goroutines
	volumeTracker *storage.VolumeTracker
	contractStore *store.Store
}
// Compile-time check that allocator satisfies Allocator.
var _ Allocator = (*allocator)(nil)
// newAllocator returns a new default allocator wired to the given managers.
// The internal context governs the lifetime of the allocator's background
// goroutines and is cancelled via the stored cancel func.
func newAllocator(
	vt *storage.VolumeTracker,
	portAllocator *portAllocator,
	resourceManager types.ResourceManager,
	hardwareManager types.HardwareManager,
	network network.Network,
	fs afero.Afero,
	workDir,
	hostID string,
	contractStore *store.Store,
) *allocator {
	ctx, cancel := context.WithCancel(context.Background())
	a := &allocator{
		network:            network,
		ports:              portAllocator,
		resources:          resourceManager,
		hardware:           hardwareManager,
		volumeTracker:      vt,
		contractStore:      contractStore,
		fs:                 fs,
		workDir:            workDir,
		hostID:             hostID,
		allocations:        make(map[string]*jobs.Allocation),
		monitoredEnsembles: make(map[string]struct{}),
		commits:            make(map[string]int64),
		ctx:                ctx,
		cancel:             cancel,
	}
	return a
}
// Run starts the allocator's background work: the ticker that clears expired
// commits and the monitor that cleans up finished ensembles. It always
// returns nil.
func (a *allocator) Run() error {
	// start a ticker to clear the commits after expiry
	a.clearCommits()
	// start monitoring ensemble allocations for cleanup
	a.monitorEnsembleAllocations()
	return nil
}
// registerEnsembleMonitor marks the ensemble so the background monitor
// (monitorEnsembleAllocations) will clean it up once none of its allocations
// are running.
//
// Fix: guard the monitoredEnsembles map with a.lock — the monitor goroutine's
// cleanup path deletes from this map under the same lock, so the previous
// unguarded write was a data race. (Assumes callers do not already hold
// a.lock; Go mutexes are not reentrant — confirm at call sites.)
func (a *allocator) registerEnsembleMonitor(ensembleID string) {
	log.Debugf("Registering monitoring allocations for ensemble %s", ensembleID)
	a.lock.Lock()
	defer a.lock.Unlock()
	a.monitoredEnsembles[ensembleID] = struct{}{}
}
// monitorEnsembleAllocations starts a background goroutine that periodically
// scans the live allocations. When a monitored ensemble no longer has any
// allocation outside the "done" statuses, its subnet is destroyed (if still
// up), its finished allocations are released, and it is removed from the
// monitored set. The goroutine exits when a.ctx is cancelled.
//
// NOTE(review): the ticker branch iterates a.allocations and
// a.monitoredEnsembles WITHOUT holding a.lock, while other methods mutate
// them under the lock — likely a data race; confirm and guard (the cleanup
// closure cannot simply hold the lock across a.Release, which presumably
// locks too).
func (a *allocator) monitorEnsembleAllocations() {
	// Statuses that mean an allocation is finished and safe to release.
	doneStatuses := []jobs.AllocationStatus{jobs.AllocationCompleted, jobs.AllocationTerminated}
	log.Debugf("Starting monitoring ensemble allocations")
	cleanupFinishedEnsemble := func(ensembleID string, allocationIDs []string) {
		log.Debugf("Cleaning up ensemble %s", ensembleID)
		// TODO issue #1154 - better handle transient allocations
		// Tear down the ensemble's subnet exactly once (status 1 = up).
		subnetStatusMx.Lock()
		if stat, ok := subnetStatus[ensembleID]; ok && stat == 1 {
			if err := a.network.DestroySubnet(ensembleID); err != nil {
				log.Warnf("Monitor Ensemble: failed to destroy subnet (it may already be destroyed) %s: %v", ensembleID, err)
			}
			subnetStatus[ensembleID] = 0 // mark as destroyed
		}
		subnetStatusMx.Unlock()
		// Release every finished allocation; failures are logged, not fatal.
		for _, allocID := range allocationIDs {
			if err := a.Release(a.ctx, allocID); err != nil {
				log.Errorf("Monitor Ensemble: failed to release allocation %s: %v", allocID, err)
			}
		}
		// Stop monitoring this ensemble.
		a.lock.Lock()
		defer a.lock.Unlock()
		delete(a.monitoredEnsembles, ensembleID)
	}
	go func() {
		ticker := time.NewTicker(ensembleMonitorFrequency)
		defer ticker.Stop()
		for {
			select {
			case <-ticker.C:
				// Partition allocations by ensemble: finished vs still running.
				doneAllocs := make(map[string][]string)
				running := make(map[string]struct{})
				for id, alloc := range a.allocations {
					ensembleID := types.EnsembleIDFromAllocationID(id)
					status := alloc.Status().Status
					if slices.Contains(doneStatuses, status) {
						doneAllocs[ensembleID] = append(doneAllocs[ensembleID], id)
						continue
					}
					running[ensembleID] = struct{}{}
				}
				// Clean up monitored ensembles with nothing left running.
				for ensembleID := range a.monitoredEnsembles {
					if _, ok := running[ensembleID]; !ok {
						cleanupFinishedEnsemble(ensembleID, doneAllocs[ensembleID])
					}
				}
			case <-a.ctx.Done():
				return
			}
		}
	}()
}
// Commit reserves resources and ports for a future allocation identified by
// allocationID. The reservation is recorded in a.commits with an expiry (unix
// seconds) and is expected to be followed by Allocate before it expires;
// otherwise it is garbage-collected by clearCommits. On any failure after the
// resource commit, the commit is reverted so no partial reservation is left.
func (a *allocator) Commit(ctx context.Context,
	allocationID string,
	resources types.CommittedResources,
	ports map[int]int,
	numDynamicPorts int,
	expiry int64,
) error {
	a.lock.Lock()
	defer a.lock.Unlock()
	// Check against the actual hardware usage to ensure dms can guarantee the commitment
	hasCapacity, err := a.hardware.CheckCapacity(resources.Resources)
	if err != nil {
		return fmt.Errorf("check hardware capacity: %w", err)
	}
	if !hasCapacity {
		return ErrNoHardwareCapacity
	}
	// commit the resources
	if err := a.resources.CommitResources(ctx, resources); err != nil {
		return fmt.Errorf("commit resources: %w", err)
	}
	// revert in case of a failure in following steps
	revertResourceCommit := func() {
		if err := a.resources.UncommitResources(ctx, allocationID); err != nil {
			log.Warnf("failed to revert resource commit for allocation %s: %v", allocationID, err)
		}
	}
	// commit the static ports; the keys of the ports map are the ports to reserve
	if len(ports) > 0 {
		staticPorts := make([]int, 0, len(ports))
		for port := range ports {
			staticPorts = append(staticPorts, port)
		}
		err := a.ports.Allocate(allocationID, staticPorts)
		if err != nil {
			revertResourceCommit()
			return fmt.Errorf("allocate port: %w", err)
		}
	}
	// commit dynamic ports
	if numDynamicPorts > 0 {
		_, err := a.ports.AllocateRandom(allocationID, numDynamicPorts)
		if err != nil {
			// uncommit the resources
			revertResourceCommit()
			// release the static ports if they were allocated
			a.ports.Release(allocationID)
			return fmt.Errorf("failed to allocate ports: %w", err)
		}
	}
	// store the commit
	a.commits[allocationID] = expiry
	return nil
}
// Uncommit releases a pending resource commitment for an allocation: it frees
// the allocation's ports (best effort) and uncommits its resources. A missing
// commit is logged and treated as a no-op.
func (a *allocator) Uncommit(ctx context.Context, allocationID string) error {
	log.Debugf("uncommitting allocation %s", allocationID)
	// BUG FIX: hold the lock for the whole operation. The previous code read
	// a.commits before acquiring the lock, racing with Commit/Allocate which
	// mutate the map under a.lock.
	a.lock.Lock()
	defer a.lock.Unlock()
	// uncommit the ports (do first as it's best-efforts)
	a.ports.Release(allocationID)
	// Check if the allocation is committed
	if _, ok := a.commits[allocationID]; !ok {
		log.Warnf("allocation %s not committed", allocationID)
		return nil
	}
	// uncommit the resources
	if err := a.resources.UncommitResources(ctx, allocationID); err != nil {
		return fmt.Errorf("uncommit resources: %w", err)
	}
	// remove the commit from the state
	delete(a.commits, allocationID)
	log.Debugf("uncommitted allocation %s", allocationID)
	return nil
}
// mountVolumeOnHost mounts each of the job's volumes on the host filesystem
// under <workDir>/volumes/<allocationID>/<volume-name>, creating the target
// directory if needed. Returns on the first failure.
func (a *allocator) mountVolumeOnHost(job jobs.Job, allocationID string) error {
	// (ranging over an empty/nil slice is a no-op, so no length guard needed)
	for _, v := range job.Volume {
		log.Infof("mounting volume %s for allocation %s", v.Name, allocationID)
		mounter, err := volume.New(a.volumeTracker, v, allocationID)
		if err != nil {
			return fmt.Errorf("create volume: %w", err)
		}
		// fixed misspelled local: desginationPath -> destinationPath
		destinationPath := filepath.Join(a.workDir, "volumes", allocationID, v.Name)
		if err := utils.CreateDirIfNotExists(a.fs, destinationPath); err != nil {
			return fmt.Errorf("mount directory: %w", err)
		}
		// mounted with no extra mount options
		if err := mounter.Mount(destinationPath, make(map[string]string)); err != nil {
			return fmt.Errorf("failed to mount volume: %w", err)
		}
	}
	return nil
}
// unmountVolumeOnHost unmounts each of the job's volumes from the host paths
// used by mountVolumeOnHost. Returns on the first failure.
func (a *allocator) unmountVolumeOnHost(job jobs.Job, allocationID string) error {
	// (ranging over an empty/nil slice is a no-op, so no length guard needed)
	for _, v := range job.Volume {
		mounter, err := volume.New(a.volumeTracker, v, allocationID)
		if err != nil {
			return fmt.Errorf("create volume unmounter: %w", err)
		}
		// fixed misspelled local: desginationPath -> destinationPath
		destinationPath := filepath.Join(a.workDir, "volumes", allocationID, v.Name)
		if err := mounter.Unmount(destinationPath); err != nil {
			return fmt.Errorf("failed to unmount volume: %w", err)
		}
	}
	return nil
}
// Allocate turns a previously committed reservation (see Commit) into a
// running allocation: it re-checks hardware capacity, mounts the job's
// volumes, allocates the committed resources, constructs and starts the
// allocation, and registers its ensemble for background cleanup monitoring.
// The commit entry is consumed on success.
func (a *allocator) Allocate(
	ctx context.Context,
	allocationID string,
	allocType jobtypes.AllocationType,
	allocActor actor.Actor,
	orchestrator actor.Handle,
	job jobs.Job,
	executor types.Executor,
	contracts map[string]types.ContractConfig,
	contractEventHandler *eventhandler.EventHandler,
	deploymentID string,
) (*jobs.Allocation, error) {
	// Ensure that the allocation is committed
	a.lock.Lock()
	defer a.lock.Unlock()
	if _, ok := a.commits[allocationID]; !ok {
		return nil, fmt.Errorf("allocation not committed: %s", allocationID)
	}
	// Check against the actual hardware usage to ensure dms can guarantee the allocation
	hasCapacity, err := a.hardware.CheckCapacity(job.Resources)
	if err != nil {
		return nil, fmt.Errorf("check hardware capacity: %w", err)
	}
	if !hasCapacity {
		return nil, ErrNoHardwareCapacity
	}
	err = a.mountVolumeOnHost(job, allocationID)
	if err != nil {
		return nil, err
	}
	// allocate the resources
	err = a.resources.AllocateResources(ctx, allocationID)
	if err != nil {
		return nil, fmt.Errorf("allocate resources: %w", err)
	}
	// NOTE(review): from here on a failure returns without deallocating the
	// resources or unmounting the volumes mounted above — consider reverting
	// on the error paths below.
	allocation, err := jobs.NewAllocation(
		allocationID,
		allocType,
		orchestrator,
		a.fs,
		a.workDir,
		allocActor,
		jobs.AllocationDetails{Job: job, NodeID: a.hostID},
		a.network,
		executor,
		// shutdown callback: the allocation releases itself via the allocator
		func() error { return a.Release(ctx, allocationID) },
		contractEventHandler,
		a.contractStore,
		deploymentID,
	)
	if err != nil {
		return nil, fmt.Errorf("create allocation: %w", err)
	}
	allocation.Contracts = contracts
	// start the allocation
	err = allocation.Start()
	if err != nil {
		return nil, fmt.Errorf("start allocation: %w", err)
	}
	// delete the commit and store the allocation
	delete(a.commits, allocationID)
	a.allocations[allocationID] = allocation
	// track the ensemble so the background monitor can clean it up once all
	// of its allocations finish (we already hold a.lock here)
	a.registerEnsembleMonitor(types.EnsembleIDFromAllocationID(allocation.ID))
	return allocation, nil
}
// TODO: it should release on best-efforts
// Release terminates an allocation and frees everything it holds: its ports
// (best effort, done first), the allocation itself, its resources and its
// mounted volumes. A missing allocation is deliberately a no-op — see the
// inline comment below.
func (a *allocator) Release(ctx context.Context, allocationID string) error {
	log.Debugf("releasing allocation %s", allocationID)
	a.lock.Lock()
	defer a.lock.Unlock()
	// deallocate the resources and ports
	// do first since it's best effort
	a.ports.Release(allocationID)
	// Check if allocated
	allocation, ok := a.allocations[allocationID]
	if !ok {
		log.Warnf("allocation %s not found", allocationID)
		// The reason we are not returning an error is because
		// for instance when shutting down a deployment with allocations of type TASK
		// the allocation is not found in the allocator because it was already terminated
		// and we don't want to return an error in this case
		// return fmt.Errorf("failed to release allocation: allocation %s not found", allocationID)
		return nil
	}
	// stop and cleanup
	// TODO: maybe we should not call Terminate since it sets
	// the status of an allocation to Terminated but sometimes we're
	// releasing Completed task-allocations which should have the
	// status as Completed rather than Terminated
	err := allocation.Terminate(ctx)
	if err != nil {
		log.Errorf("terminate allocation: %v", err)
		return fmt.Errorf("terminate allocation: %w", err)
	}
	err = a.resources.DeallocateResources(ctx, allocationID)
	if err != nil {
		return fmt.Errorf("deallocate resources for allocation id: %s: %w", allocationID, err)
	}
	if err := a.unmountVolumeOnHost(allocation.Job, allocationID); err != nil {
		return fmt.Errorf("unmount volume: %w", err)
	}
	// remove the allocation
	delete(a.allocations, allocationID)
	log.Debugf("successfully released allocation %s", allocationID)
	return nil
}
// Stop shuts down the allocator: it stops every allocation, clears all
// outstanding commits, destroys the subnets of running ensembles, and finally
// cancels the allocator context so the background goroutines exit. Individual
// failures are logged and do not abort the shutdown.
func (a *allocator) Stop(ctx context.Context) error {
	// BUG FIX: iterate snapshots instead of the live maps. The previous code
	// ranged over a.allocations and a.commits without holding a.lock, racing
	// with the goroutines (and with Uncommit below) that mutate them.
	for id, allocation := range a.GetAllocations() {
		if err := allocation.Stop(ctx); err != nil {
			log.Warnf("stop allocation %s: %v", id, err)
		}
	}
	// clear the commits
	for allocationID := range a.getCommits() {
		if err := a.Uncommit(context.Background(), allocationID); err != nil {
			log.Warnf("uncommit allocation %s: %v", allocationID, err)
		}
	}
	for _, ensembleID := range a.getRunningEnsemblesIDs() {
		if err := a.network.DestroySubnet(ensembleID); err != nil {
			log.Warnf("destroy subnet %s: %v", ensembleID, err)
		}
	}
	// cancel the context to stop the allocator goroutines
	a.cancel()
	return nil
}
// CheckAvailability reports whether the requested static ports, number of
// dynamic ports and resources can currently be satisfied. It returns nil when
// everything is available, and a sentinel error for the first constraint that
// fails.
func (a *allocator) CheckAvailability(ports []int, numDynamicPorts int, resources types.Resources) error {
	switch {
	case a.ports.isAllocated(ports):
		// at least one requested static port is already taken
		return ErrPortsBusy
	case !a.ports.portsAvailable(numDynamicPorts):
		// not enough free ports left for the dynamic request
		return ErrDynamicPortsNotAvailable
	}
	// resources fit iff subtracting the request from the free pool succeeds
	free, err := a.resources.GetFreeResources(context.Background())
	if err != nil {
		return fmt.Errorf("get free resources: %w", err)
	}
	if err := free.Subtract(resources); err != nil {
		return ErrResourcesNotAvailable
	}
	return nil
}
// GetAllocations returns a shallow copy of the allocation table keyed by
// allocation ID, taken under the allocator lock.
func (a *allocator) GetAllocations() map[string]*jobs.Allocation {
	a.lock.Lock()
	defer a.lock.Unlock()
	snapshot := make(map[string]*jobs.Allocation, len(a.allocations))
	for id, alloc := range a.allocations {
		snapshot[id] = alloc
	}
	return snapshot
}
// GetAllocation looks up a single allocation by ID, returning
// ErrAllocationNotFound when it does not exist.
func (a *allocator) GetAllocation(allocationID string) (*jobs.Allocation, error) {
	a.lock.Lock()
	defer a.lock.Unlock()
	if alloc, ok := a.allocations[allocationID]; ok {
		return alloc, nil
	}
	return nil, ErrAllocationNotFound
}
// getRunningEnsemblesIDs returns the deduplicated ensemble IDs derived from
// all currently tracked allocations.
func (a *allocator) getRunningEnsemblesIDs() []string {
	a.lock.Lock()
	defer a.lock.Unlock()
	seen := make(map[string]struct{}, len(a.allocations))
	for allocID := range a.allocations {
		seen[types.EnsembleIDFromAllocationID(allocID)] = struct{}{}
	}
	return utils.MapKeysToSlice(seen)
}
// getCommits returns a snapshot of pending commits, mapping allocation ID to
// its expiry (unix seconds).
func (a *allocator) getCommits() map[string]int64 {
	a.lock.Lock()
	defer a.lock.Unlock()
	snapshot := make(map[string]int64, len(a.commits))
	for id, expiry := range a.commits {
		snapshot[id] = expiry
	}
	return snapshot
}
// getCommit returns the expiry (unix seconds) of a pending commit and whether
// the allocation has one.
func (a *allocator) getCommit(allocationID string) (int64, bool) {
	a.lock.Lock()
	defer a.lock.Unlock()
	expiry, found := a.commits[allocationID]
	return expiry, found
}
// clearCommits starts a background goroutine that periodically removes
// expired resource commitments. It exits when the allocator context is
// cancelled.
func (a *allocator) clearCommits() {
	go func() {
		ticker := time.NewTicker(clearCommitsFrequency)
		defer ticker.Stop()
		// BUG FIX: the previous implementation executed the select exactly
		// once, so expired commits were cleared at most a single time (and
		// the ticker leaked if its case fired). Loop until shutdown.
		for {
			select {
			case <-ticker.C:
				now := time.Now().Unix()
				// getCommits takes a snapshot under a.lock, avoiding a racy
				// iteration of the shared commits map.
				for allocationID, expiry := range a.getCommits() {
					if expiry < now {
						if err := a.Uncommit(context.Background(), allocationID); err != nil {
							log.Warnf("uncommit allocation %s: %v", allocationID, err)
						}
					}
				}
			case <-a.ctx.Done():
				return
			}
		}
	}()
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package node
import (
"context"
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"strings"
"time"
"github.com/google/uuid"
"github.com/spf13/afero"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/dms/behaviors"
"gitlab.com/nunet/device-management-service/dms/jobs"
jobtypes "gitlab.com/nunet/device-management-service/dms/jobs/types"
"gitlab.com/nunet/device-management-service/dms/orchestrator"
"gitlab.com/nunet/device-management-service/executor/docker"
"gitlab.com/nunet/device-management-service/lib/did"
"gitlab.com/nunet/device-management-service/lib/ucan"
"gitlab.com/nunet/device-management-service/tokenomics/eventhandler"
"gitlab.com/nunet/device-management-service/tokenomics/events"
"gitlab.com/nunet/device-management-service/types"
)
// handleSubnetCreate services a subnet-create request: it creates the subnet
// on the local network stack, records it as live, and replies to the sender.
func (n *Node) handleSubnetCreate(msg actor.Envelope) {
	defer msg.Discard()
	fail := func(err error) {
		log.Errorf("Error creating subnet: %s", err)
		n.sendReply(msg, orchestrator.SubnetCreateResponse{Error: err.Error()})
	}
	var req orchestrator.SubnetCreateRequest
	if err := json.Unmarshal(msg.Message, &req); err != nil {
		fail(fmt.Errorf("error unmarshalling subnet create request: %s", err))
		return
	}
	if err := n.network.CreateSubnet(context.Background(), req.SubnetID, req.CIDR, req.RoutingTable); err != nil {
		fail(err)
		return
	}
	// TODO issue #1154 - better handle transient allocations
	subnetStatusMx.Lock()
	subnetStatus[req.SubnetID] = 1 // mark the subnet as live
	subnetStatusMx.Unlock()
	n.sendReply(msg, orchestrator.SubnetCreateResponse{OK: true})
}
// handleSubnetDestroy services a subnet-destroy request. If the subnet was
// already destroyed (e.g. by a transient allocation cleaning up after itself)
// it replies OK without touching the network stack; otherwise it destroys the
// subnet and records it as destroyed.
func (n *Node) handleSubnetDestroy(msg actor.Envelope) {
	defer msg.Discard()
	handleErr := func(err error) {
		log.Errorf("Error destroying subnet: %s", err)
		n.sendReply(msg, orchestrator.SubnetDestroyResponse{Error: err.Error()})
	}
	var request orchestrator.SubnetDestroyRequest
	resp := orchestrator.SubnetDestroyResponse{}
	if err := json.Unmarshal(msg.Message, &request); err != nil {
		handleErr(fmt.Errorf("error unmarshalling subnet destroy: %s", err))
		return
	}
	// if subnet already destroyed by a transient alloc cleaning up after itself
	subnetStatusMx.Lock()
	// FIX: renamed the local so it no longer shadows the package-level
	// subnetStatus map it is read from.
	if status, ok := subnetStatus[request.SubnetID]; ok && status == 0 {
		// Subnet is already destroyed
		resp.OK = true
		n.sendReply(msg, resp)
		subnetStatusMx.Unlock()
		return
	}
	subnetStatusMx.Unlock()
	if err := n.network.DestroySubnet(request.SubnetID); err != nil {
		handleErr(err)
		return
	}
	// TODO issue #1154 - better handle transient allocations
	subnetStatusMx.Lock()
	subnetStatus[request.SubnetID] = 0
	subnetStatusMx.Unlock()
	resp.OK = true
	n.sendReply(msg, resp)
}
// handleSubnetJoin services a subnet-join request: it re-registers the
// joining peer, accepts the subnet routing table, installs the subnet DNS
// records, and replies to the sender.
func (n *Node) handleSubnetJoin(msg actor.Envelope) {
	defer msg.Discard()
	fail := func(err error) {
		log.Errorf("Error subnet join: %s", err)
		n.sendReply(msg, orchestrator.SubnetJoinResponse{Error: err.Error()})
	}
	var req orchestrator.SubnetJoinRequest
	if err := json.Unmarshal(msg.Message, &req); err != nil {
		fail(fmt.Errorf("error unmarshalling subnet join: %s", err))
		return
	}
	// drop any stale mapping for this peer/IP before re-adding it (best effort)
	_ = n.network.RemoveSubnetPeers(req.SubnetID, map[string]string{req.IP: req.PeerID})
	if err := n.network.AddSubnetPeer(req.SubnetID, req.PeerID, req.IP); err != nil {
		fail(err)
		return
	}
	if err := n.network.AcceptSubnetPeers(req.SubnetID, req.RoutingTable); err != nil {
		fail(err)
		return
	}
	if err := n.network.AddSubnetDNSRecords(req.SubnetID, req.Records); err != nil {
		fail(err)
		return
	}
	n.sendReply(msg, orchestrator.SubnetJoinResponse{OK: true})
}
// addEnsembleBehaviors registers the ensemble-scoped dynamic behaviors
// (subnet create/destroy, allocation logs, allocation shutdown) on the node
// actor, using each behavior's dynamic template instantiated with ensembleID.
func (n *Node) addEnsembleBehaviors(ensembleID string) error {
	type behaviorSpec struct {
		fn   func(actor.Envelope)
		opts []actor.BehaviorOption
	}
	specs := map[string]behaviorSpec{
		fmt.Sprintf(behaviors.SubnetCreateBehavior.DynamicTemplate, ensembleID):       {fn: n.handleSubnetCreate},
		fmt.Sprintf(behaviors.SubnetDestroyBehavior.DynamicTemplate, ensembleID):      {fn: n.handleSubnetDestroy},
		fmt.Sprintf(behaviors.AllocationLogsBehavior.DynamicTemplate, ensembleID):     {fn: n.handleAllocationLogs},
		fmt.Sprintf(behaviors.AllocationShutdownBehavior.DynamicTemplate, ensembleID): {fn: n.handleAllocationShutdown},
	}
	for name, spec := range specs {
		if err := n.actor.AddBehavior(name, spec.fn, spec.opts...); err != nil {
			return fmt.Errorf("adding %s behavior: %w", name, err)
		}
	}
	return nil
}
// createAllocation creates a single allocation: it builds an executor and a
// child actor for it, hands the job to the allocator, and then pushes a
// CreateAllocation event to the contracts that should be notified (either the
// tail contract found via the head contract's chain, or the ensemble
// contracts directly for P2P).
func (n *Node) createAllocation(
	allocationID string,
	allocType jobtypes.AllocationType,
	job jobs.Job, supervisor actor.Handle,
	contracts map[string]types.ContractConfig,
	deploymentID string,
) (*jobs.Allocation, error) {
	if contracts == nil {
		contracts = make(map[string]types.ContractConfig)
	}
	executor, err := createExecutor(context.Background(), n.fs, job.Execution.Type)
	if err != nil {
		return nil, fmt.Errorf("create executor: %w", err)
	}
	allocActor, err := n.actor.CreateChild(allocationID, supervisor)
	if err != nil {
		return nil, fmt.Errorf("create allocation actor: %w", err)
	}
	allocation, err := n.allocator.Allocate(
		context.Background(), allocationID,
		allocType, allocActor, supervisor,
		job, executor,
		contracts,
		n.contractEventHandler,
		deploymentID,
	)
	if err != nil {
		return nil, fmt.Errorf("allocate: %w", err)
	}
	// Find Head Contract config from ensemble contracts.
	// NOTE(review): this picks an arbitrary map entry — Go map iteration
	// order is random, so with more than one contract the "head" choice is
	// nondeterministic; confirm ensembles carry at most one head contract.
	computeProviderDID := n.actor.Handle().DID.URI
	var headContractConfig types.ContractConfig
	for _, contractConfig := range contracts {
		headContractConfig = contractConfig
		break
	}
	// Determine which contracts to notify
	var contractsToNotify map[string]types.ContractConfig
	if headContractConfig.DID != "" {
		// Contract chain: find and use Tail Contracts using Head Contract config
		tailContract, err := n.contractStore.FindTailContract(
			headContractConfig,
			computeProviderDID,
		)
		if err != nil {
			log.Warnw("failed to find tail contracts, falling back to ensemble contracts",
				"head_contract_did", headContractConfig.DID,
				"error", err)
			contractsToNotify = contracts // Fallback
		} else {
			// Convert to map format
			contractsToNotify = make(map[string]types.ContractConfig)
			contractsToNotify[tailContract.DID] = *tailContract
		}
	} else {
		// P2P: use ensemble contracts
		contractsToNotify = contracts
	}
	// Send a CreateAllocation event to each contract selected above
	headContractDID := headContractConfig.DID
	for _, v := range contractsToNotify {
		evt := events.CreateAllocation{
			EventBase: events.EventBase{Type: events.CreateAllocationEvent},
			Resources: job.Resources,
			AllocationBase: events.AllocationBase{
				AllocationID:       allocationID,
				DeploymentID:       deploymentID,
				ComputeProviderDID: computeProviderDID,
				HeadContractDID:    headContractDID, // Include Head Contract DID in payload
			},
		}
		n.contractEventHandler.Push(eventhandler.Event{
			ContractHostDID: v.Host,
			ContractDID:     v.DID,
			Payload:         evt,
		})
	}
	return allocation, nil
}
// createAllocations creates every allocation of an ensemble deployment and
// returns the created allocations' actor handles keyed by allocation ID. For
// each allocation it grants the orchestrator (supervisor) the ensemble and
// allocation capability namespaces, and spawns a goroutine that keeps
// re-issuing those grants until the allocation stops or the node shuts down.
func (n *Node) createAllocations(
	ensembleID string,
	allocations map[string]jobtypes.AllocationDeploymentConfig,
	supervisor actor.Handle,
) (map[string]actor.Handle, error) {
	if len(allocations) == 0 {
		log.Errorf("no allocations to create for ensembleID: %s", ensembleID)
		return nil, fmt.Errorf("no allocations to create for ensembleID: %s", ensembleID)
	}
	if supervisor.Empty() || supervisor.DID.Empty() {
		log.Errorf("invalid supervisor handle: %+v", supervisor)
		return nil, fmt.Errorf("invalid supervisor handle")
	}
	allocHandlesByID := make(map[string]actor.Handle, len(allocations))
	for allocationID, allocationConfig := range allocations {
		allocation, err := n.createAllocation(
			allocationID,
			allocationConfig.Type,
			jobs.Job{
				Resources:        allocationConfig.Resources,
				Execution:        allocationConfig.Execution,
				ProvisionScripts: allocationConfig.ProvisionScripts,
				Keys:             allocationConfig.Keys,
				Volume:           allocationConfig.Volume,
			},
			supervisor,
			allocationConfig.Contracts,
			ensembleID,
		)
		if err != nil {
			// NOTE(review): allocations created in earlier iterations are not
			// rolled back on failure — confirm the caller cleans them up.
			return nil, fmt.Errorf("create allocation %s: %w", allocationID, err)
		}
		allocHandlesByID[allocationID] = allocation.Actor.Handle()
		// node grants subnet create/destroy caps to the orchestrator
		if err := n.actor.Security().Grant(supervisor.DID, n.actor.Handle().DID, []ucan.Capability{
			ucan.Capability(fmt.Sprintf(behaviors.EnsembleNamespace, ensembleID)),
		}, grantAllocationCapsFreq); err != nil {
			return nil, fmt.Errorf("grant node caps: %w", err)
		}
		allocDID, err := did.FromID(allocation.Actor.Handle().ID)
		if err != nil {
			return nil, fmt.Errorf("deriving allocation did: %w", err)
		}
		if err := n.actor.Security().Grant(supervisor.DID, allocDID, []ucan.Capability{
			behaviors.AllocationNamespace,
		}, grantAllocationCapsFreq); err != nil {
			return nil, fmt.Errorf("grant allocation caps: %w", err)
		}
		// refresh allocation caps grants periodically — renewed on the
		// grantAllocationCapsFreq cadence (the same duration passed to Grant,
		// presumably its TTL — confirm) until the allocation reports
		// AllocationStopped or the node context is cancelled
		go func() {
			ticker := time.NewTicker(grantAllocationCapsFreq)
			defer ticker.Stop()
			for allocation.Status().Status != jobs.AllocationStopped {
				select {
				case <-n.ctx.Done():
					return
				case <-ticker.C:
					// node grants subnet create/destroy caps to the orchestrator
					if err := n.actor.Security().Grant(supervisor.DID, n.actor.Handle().DID, []ucan.Capability{
						ucan.Capability(fmt.Sprintf(behaviors.EnsembleNamespace, ensembleID)),
					}, grantAllocationCapsFreq); err != nil {
						log.Warnf("grant node caps: %v", err)
					}
					// allocation grants subnet manage caps to the orchestrator
					if err := n.actor.Security().Grant(supervisor.DID, allocDID, []ucan.Capability{
						behaviors.AllocationNamespace,
					}, grantAllocationCapsFreq); err != nil {
						log.Warnf("grant allocation caps: %v", err)
					}
				}
			}
		}()
	}
	log.Infof("Finished createAllocations for ensembleID: %s", ensembleID)
	return allocHandlesByID, nil
}
// handleAllocationDeployment handles an ensemble deployment request: it
// registers the ensemble's dynamic behaviors, creates the requested
// allocations, and replies with the created allocations' actor handles.
// TODO (wrong nomenclature): handleAllocationDeployment -> handleEnsembleDeployment
func (n *Node) handleAllocationDeployment(msg actor.Envelope) {
	defer msg.Discard()
	fail := func(err error) {
		log.Errorf("Error handling allocation deployment: %s", err)
		n.sendReply(msg, jobtypes.AllocationDeploymentResponse{Error: err.Error()})
	}
	var req jobtypes.AllocationDeploymentRequest
	if err := json.Unmarshal(msg.Message, &req); err != nil {
		fail(err)
		return
	}
	if err := n.addEnsembleBehaviors(req.EnsembleID); err != nil {
		fail(fmt.Errorf("failed to register dynamic behaviors: %s", err))
		return
	}
	// the message sender (the orchestrator) becomes the supervisor
	handles, err := n.createAllocations(req.EnsembleID, req.Allocations, msg.From)
	if err != nil {
		fail(err)
		return
	}
	n.sendReply(msg, jobtypes.AllocationDeploymentResponse{OK: true, Allocations: handles})
}
// AllocationShutdownRequest asks the node to release a single allocation.
type AllocationShutdownRequest struct {
	// AllocationID identifies the allocation to shut down.
	AllocationID string
}

// AllocationShutdownResponse reports the outcome of an allocation shutdown.
type AllocationShutdownResponse struct {
	OK    bool   // true when the allocation was released
	Error string // human-readable failure description, empty on success
}
// handleAllocationShutdown releases the allocation named in the request via
// the allocator and reports the outcome to the sender.
func (n *Node) handleAllocationShutdown(msg actor.Envelope) {
	log.Debugf("handling allocation shutdown request from %s", msg.From.DID)
	defer msg.Discard()
	fail := func(err error) {
		log.Errorf("Error handling allocation shutdown request: %s", err)
		n.sendReply(msg, AllocationShutdownResponse{Error: err.Error()})
	}
	var req AllocationShutdownRequest
	if err := json.Unmarshal(msg.Message, &req); err != nil {
		fail(fmt.Errorf("error unmarshalling allocation shutdown request: %s", err))
		return
	}
	if err := n.allocator.Release(context.Background(), req.AllocationID); err != nil {
		fail(err)
		return
	}
	n.sendReply(msg, AllocationShutdownResponse{OK: true})
}
// ensembleIDFromBehavior extracts the ensemble ID from a dynamic behavior
// path: the segment at index 3 after splitting on "/" (i.e. the third path
// element of a slash-prefixed behavior path).
func ensembleIDFromBehavior(b string) (string, error) {
	const idIndex = 3
	parts := strings.Split(b, "/")
	if len(parts) <= idIndex {
		return "", fmt.Errorf("invalid ensemble behavior: %s", b)
	}
	return parts[idIndex], nil
}
// handleAllocationLogs replies with the stdout/stderr logs of an allocation.
// The ensemble ID is derived from the behavior path the message arrived on;
// the allocation name comes from the request payload. A missing log file is
// tolerated (its stream is simply empty), but both streams being empty is
// reported as an error.
func (n *Node) handleAllocationLogs(msg actor.Envelope) {
	defer msg.Discard()
	log.Infof("behavior get logs invoked by: %+v", msg.From)
	handleErr := func(err error) {
		log.Errorf("error getting allocation logs: %s", err)
		n.sendReply(msg, orchestrator.AllocationLogsResponse{Error: err.Error()})
	}
	var resp orchestrator.AllocationLogsResponse
	ensembleID, err := ensembleIDFromBehavior(msg.Behavior)
	if err != nil {
		handleErr(fmt.Errorf("error getting ensemble ID from behavior %s: %s", msg.Behavior, err))
		return
	}
	var req orchestrator.AllocationLogsRequest
	if err := json.Unmarshal(msg.Message, &req); err != nil {
		handleErr(fmt.Errorf("allocation logs request: %w", types.ErrUnmarshal))
		return
	}
	allocID := types.ConstructAllocationID(ensembleID, req.AllocName)
	resultsDir := filepath.Join(n.dmsConfig.WorkDir, "jobs", allocID)
	stdout, err := n.fs.ReadFile(filepath.Join(resultsDir, "stdout.log"))
	if err != nil {
		if errors.Is(err, os.ErrNotExist) {
			log.Warnf("stdout file for allocation %s does not exist (ensemble: %s)", req.AllocName, ensembleID)
		} else {
			handleErr(fmt.Errorf("failed to read results file: %s", err))
			return
		}
	}
	stderr, err := n.fs.ReadFile(filepath.Join(resultsDir, "stderr.log"))
	if err != nil {
		// BUG FIX: use errors.Is instead of ==, matching the stdout branch
		// above. ReadFile wraps the sentinel in a *PathError, so the direct
		// equality never matched and a missing stderr file was reported to
		// the requester as a read failure.
		if errors.Is(err, os.ErrNotExist) {
			log.Debugf("stderr file for allocation %s does not exist (ensemble: %s)", req.AllocName, ensembleID)
		} else {
			handleErr(fmt.Errorf("failed to read results file: %s", err))
			return
		}
	}
	if len(stdout) == 0 && len(stderr) == 0 {
		handleErr(
			fmt.Errorf("stdout and stderr files for allocation %s are empty (ensemble: %s)",
				req.AllocName, ensembleID),
		)
		return
	}
	log.Info("sending logs for allocation: ", allocID)
	resp.Stdout = stdout
	resp.Stderr = stderr
	n.sendReply(msg, resp)
}
// AllocationsListResponse represents the response for the allocations list request.
type AllocationsListResponse struct {
	// Allocations holds summary info for every allocation known to the allocator.
	Allocations []jobs.AllocationInfo `json:"allocations"`
	// Error carries a human-readable failure description; empty on success.
	Error string `json:"error,omitempty"`
}
// handleAllocationsList replies with summary information for every allocation
// currently tracked by the allocator.
func (n *Node) handleAllocationsList(msg actor.Envelope) {
	defer msg.Discard()
	// start from a non-nil slice so the JSON field encodes as [] rather than null
	infos := []jobs.AllocationInfo{}
	for _, alloc := range n.allocator.GetAllocations() {
		infos = append(infos, alloc.Info())
	}
	n.sendReply(msg, AllocationsListResponse{Allocations: infos})
}
// createExecutor builds an executor for the given execution type; only the
// Docker executor is supported at the moment.
func createExecutor(ctx context.Context, fs afero.Afero, executionType string) (types.Executor, error) {
	if executionType != types.ExecutorTypeDocker.String() {
		return nil, fmt.Errorf("unsupported executor type: %s", executionType)
	}
	// each executor instance gets a fresh random ID
	exec, err := docker.NewExecutor(ctx, fs, uuid.New().String())
	if err != nil {
		return nil, fmt.Errorf("create executor: %w", err)
	}
	return exec, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package node
import (
"encoding/json"
"errors"
"fmt"
"math/rand"
"slices"
"sync"
"sync/atomic"
"time"
"github.com/libp2p/go-libp2p/core/peer"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/metric"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/dms/behaviors"
"gitlab.com/nunet/device-management-service/dms/jobs"
jobtypes "gitlab.com/nunet/device-management-service/dms/jobs/types"
"gitlab.com/nunet/device-management-service/gateway/provider"
"gitlab.com/nunet/device-management-service/lib/did"
"gitlab.com/nunet/device-management-service/observability"
"gitlab.com/nunet/device-management-service/tokenomics/contracts"
"gitlab.com/nunet/device-management-service/types"
)
const bidStateTimeout = 5 * time.Minute
// getExecutor returns the metadata registered for the given executor type, or
// an error when no such executor is available.
func (n *Node) getExecutor(execType jobs.AllocationExecutor) (executorMetadata, error) {
	n.lock.RLock()
	defer n.lock.RUnlock()
	if meta, ok := n.executors[string(execType)]; ok {
		return meta, nil
	}
	return executorMetadata{}, errors.New("executor not available")
}
// storeBid records a bid we answered for ensemble eid, remembering the
// request, its nonce, and when the bid state expires.
func (n *Node) storeBid(eid string, nonce uint64, req jobtypes.BidRequest) {
	n.lock.Lock()
	defer n.lock.Unlock()
	state := &bidState{
		expire:  time.Now().Add(bidStateTimeout),
		nonce:   nonce,
		request: req,
	}
	n.bids[eid] = state
	// remember the nonce so repeat bids can be detected (see bidAnswered)
	n.answeredBids[eid] = append(n.answeredBids[eid], nonce)
}
// getBid returns the stored bid state for ensemble eid and whether one exists.
func (n *Node) getBid(eid string) (*bidState, bool) {
	// Read-only map access: use the read lock, consistent with getExecutor
	// and location.
	n.lock.RLock()
	defer n.lock.RUnlock()
	b, exists := n.bids[eid]
	return b, exists
}
// bidAnswered reports whether a bid for ensemble eid with the given nonce has
// already been answered (i.e. recorded by storeBid).
func (n *Node) bidAnswered(eid string, nonce uint64) bool {
	// Read-only access: the read lock suffices, matching the other read helpers.
	n.lock.RLock()
	defer n.lock.RUnlock()
	// FIX: direct map lookup instead of scanning every entry for a matching
	// key — the old O(n) loop also shadowed the receiver n with its loop
	// variable. A lookup on a missing key yields a nil slice, for which
	// slices.Contains is false.
	return slices.Contains(n.answeredBids[eid], nonce)
}
// location returns the node's host location (continent, country, city).
func (n *Node) location() jobtypes.Location {
	n.lock.RLock()
	defer n.lock.RUnlock()
	loc := jobtypes.Location{
		Continent: n.hostLocation.Continent,
		Country:   n.hostLocation.Country,
		City:      n.hostLocation.City,
	}
	return loc
}
// verifyContract validates every contract attached to a bid. Chain
// verification (Contract A: Orchestrator ↔ Organization, with the provider's
// Contract B checked by the chain) is attempted first; when it fails, the
// contract is verified directly peer-to-peer instead.
func (n *Node) verifyContract(bidContracts map[string]types.ContractConfig, orchestratorDID did.DID) error {
	for key, cfg := range bidContracts {
		// Chain verification mode (Contract A: Orchestrator ↔ Organization)
		// Provider will verify Contract B (Organization ↔ Provider) via chain verification
		chainErr := n.verifyContractChain(cfg, orchestratorDID, key)
		if chainErr == nil {
			log.Debugw("contract_verification_success",
				"labels", string(observability.LabelDeployment),
				"contract_key", key,
				"contract_chain", true)
			continue
		}
		log.Debugw("contract_chain_verification_failed",
			"labels", string(observability.LabelDeployment),
			"contract_key", key,
			"error", chainErr,
			"reason", "likely not a chain contract",
		)
		log.Debugw("performing p2p contract verification since contract chain verification failed",
			"labels", string(observability.LabelDeployment),
			"contract_key", key,
		)
		// Traditional P2P verification (provider is directly in the contract)
		if err := n.verifyP2PContract(cfg); err != nil {
			return fmt.Errorf("contract verification failed for %s: %w", key, err)
		}
	}
	return nil
}
// verifyP2PContract performs traditional P2P contract verification: it
// derives the contract host's peer ID and the contract actor's handle from
// the DIDs in the config, invokes the contract-validation behavior on the
// host, and returns nil only when the host reports the contract as valid.
func (n *Node) verifyP2PContract(contractConfig types.ContractConfig) error {
	hostDID, err := did.FromString(contractConfig.Host)
	if err != nil {
		return fmt.Errorf("failed to get contracts host did: %w", err)
	}
	hostPubKey, err := did.PublicKeyFromDID(hostDID)
	if err != nil {
		return fmt.Errorf("failed to get contracts host public key from did: %w", err)
	}
	hostPeerID, err := peer.IDFromPublicKey(hostPubKey)
	if err != nil {
		return fmt.Errorf("failed to get peer id: %w", err)
	}
	actorDID, err := did.FromString(contractConfig.DID)
	if err != nil {
		return fmt.Errorf("failed to get contracts actor did: %w", err)
	}
	actorPubKey, err := did.PublicKeyFromDID(actorDID)
	if err != nil {
		return fmt.Errorf("failed to get contracts actor public key from did: %w", err)
	}
	// the validation behavior is addressed to the contract actor's inbox at
	// the host peer
	dest, err := actor.HandleFromPublicKeyWithInboxAddress(actorPubKey, contractConfig.DID, hostPeerID.String())
	if err != nil {
		return fmt.Errorf("failed to get contracts host handle: %w", err)
	}
	req := contracts.ContractValidateRequest{ContractDID: contractConfig.DID}
	reply, err := n.invokeBehaviour(dest, behaviors.ContractValidationBehavior, req, invokeMessageTimeout)
	if err != nil {
		return fmt.Errorf("failed to send message to contract host: %w", err)
	}
	var resp contracts.ContractValidateResponse
	if err := json.Unmarshal(reply.Message, &resp); err != nil {
		return fmt.Errorf("failed to unmarshal contract hosts response payload: %w", err)
	}
	if !resp.Valid {
		return fmt.Errorf("contract is invalid")
	}
	return nil
}
// verifyContractChain performs contract chain verification.
// contractConfig contains Contract A (head contract: Orchestrator ↔ Organization).
// The orchestrator specifies the head contract, and the provider finds Contract B
// with the same organization specified in the head contract.
// It invokes the chain-verification behavior on the contract host and returns
// nil only when the host reports the chain as valid.
func (n *Node) verifyContractChain(contractConfig types.ContractConfig, orchestratorDID did.DID, contractKey string) error {
	hostDID, err := did.FromString(contractConfig.Host)
	if err != nil {
		return fmt.Errorf("failed to parse host DID: %w", err)
	}
	// Get provider DID (self)
	providerDID := n.actor.Handle().DID
	pubKeyContractHost, err := did.PublicKeyFromDID(hostDID)
	if err != nil {
		return fmt.Errorf("failed to get contract actor public key: %w", err)
	}
	pid, err := peer.IDFromPublicKey(pubKeyContractHost)
	if err != nil {
		return fmt.Errorf("failed to get peer id: %w", err)
	}
	// For chain verification, we call the behavior on the ContractActor (not the node)
	// The ContractActor will find Contract A (using contractConfig.DID) and Contract B (Org ↔ Provider)
	destination, err := actor.HandleFromPeerID(pid.String())
	if err != nil {
		return fmt.Errorf("failed to construct contract actor handle: %w", err)
	}
	req := contracts.ContractChainVerificationRequest{
		SolutionEnablerDID: hostDID.String(),
		ContractDID:        contractConfig.DID, // Contract DID from config
		OrchestratorDID:    orchestratorDID.String(),
		ProviderDID:        providerDID.String(),
	}
	reply, err := n.invokeBehaviour(destination, behaviors.ContractChainVerificationBehavior, req, invokeMessageTimeout)
	if err != nil {
		return fmt.Errorf("failed to invoke chain verification: %w", err)
	}
	var resp contracts.ContractChainVerificationResponse
	err = json.Unmarshal(reply.Message, &resp)
	if err != nil {
		return fmt.Errorf("failed to unmarshal chain verification response: %w", err)
	}
	if !resp.Valid {
		return fmt.Errorf("contract chain verification failed: %s", resp.Error)
	}
	log.Infow("contract chain verification successful",
		"labels", string(observability.LabelDeployment),
		"contract_key", contractKey,
		"organization_did", resp.OrganizationDID,
		"orchestrator_contract", resp.OrchestratorContract.ContractDID,
		"provider_contract", resp.ProviderContract.ContractDID)
	return nil
}
// gateway logic to decide a bid or not goes here
// we keep all the restrictions and constraints here as it is for normal bid
//
// handleBidRequest evaluates an incoming ensemble bid request and, when this
// node can satisfy one of the per-node requests, replies with a signed bid.
// Flow: drop self-broadcasts, require onboarding (non-gateway nodes only),
// optionally verify contracts, shuffle the per-node requests, pick the first
// one whose constraints/executors/resources match, check port/resource
// availability, then sign, send and record the bid.
func (n *Node) handleBidRequest(msg actor.Envelope) {
	defer msg.Discard()
	// ignore bid request from self if broadcast
	// only accept self bid if own peer specified on ensemble
	if msg.IsBroadcast() &&
		n.actor.Handle().Address.HostID == msg.From.Address.HostID {
		return
	}
	log.Infow(
		"got a bid request from actor",
		"labels", string(observability.LabelDeployment),
		"from", msg.From.Address,
	)
	// if not a gateway check onboarded
	if !n.dmsConfig.General.ComputeGateway {
		if onboarded := n.onboarding.IsOnboarded(); !onboarded {
			log.Debugw(
				"node_not_onboarded_ignoring_bid_request",
				"labels", string(observability.LabelDeployment),
			)
			return
		}
	}
	var request jobtypes.EnsembleBidRequest
	if err := json.Unmarshal(msg.Message, &request); err != nil {
		log.Debugw(
			"unmarshal_bid_request_error",
			"labels", string(observability.LabelDeployment),
			"error", err,
		)
		return
	}
	log.Infow(
		"bid_request",
		"labels", string(observability.LabelDeployment),
		"from", msg.From.Address,
		"orchestratorID", request.ID,
	)
	// metric
	if m := observability.BidReceived; m != nil {
		m.Add(n.ctx, 1, metric.WithAttributes(
			observability.AttrDID,
			attribute.String("orchestratorID", request.ID),
		))
	}
	if n.dmsConfig.Job.RequireContractsForDeployment {
		// contracts are global at ensemble level so they apply to all nodes
		// (only the first request's contract list is inspected for that reason)
		if len(request.Request) > 0 {
			if len(request.Request[0].V1.Contracts) == 0 {
				log.Debugw(
					"bid_request_missing_contracts_for_deployment",
					"labels", string(observability.LabelDeployment),
					"ensemble_id", request.ID,
				)
				return
			}
		}
	}
	// contracts are global at ensemble level so they apply
	// to all nodes
	if len(request.Request) > 0 {
		if len(request.Request[0].V1.Contracts) > 0 {
			// Extract orchestrator DID from message envelope
			orchestratorDID := msg.From.DID
			err := n.verifyContract(request.Request[0].V1.Contracts, orchestratorDID)
			if err != nil {
				log.Errorw(
					"contract_verification_error",
					"labels", string(observability.LabelDeployment),
					"error", err,
				)
				return
			}
			log.Infow("contract_verification_success",
				"contracts", request.Request[0].V1.Contracts,
				"labels", string(observability.LabelDeployment),
			)
		} else {
			log.Debugw(
				"contracts_empty",
				"labels", string(observability.LabelDeployment),
			)
		}
	}
	machineResources, err := n.hardware.GetMachineResources()
	if err != nil {
		log.Debugw(
			"machine_resources_retrieval_error",
			"labels", string(observability.LabelDeployment),
			"error", err,
		)
		return
	}
	// randomize the order of bid request checks
	rand.Shuffle(len(request.Request), func(i, j int) {
		request.Request[i], request.Request[j] = request.Request[j], request.Request[i]
	})
	// find the first bid request that matches
	var toAnswer jobtypes.BidRequest
	var found bool
loop:
	for _, v := range request.Request {
		// check if it is a V1 request
		if v.V1 == nil {
			log.Debugw("bid_request_not_v1",
				"labels", string(observability.LabelDeployment))
			continue
		}
		// This check does not depend on v; once this ensemble/nonce has been
		// answered the whole handler returns.
		answered := n.bidAnswered(request.ID, request.Nonce)
		if answered {
			log.Debugf("bid already answered: ensembleID: %s, nonce: %d", request.ID, request.Nonce)
			return
		}
		// check if we are excluded
		hostID := n.actor.Handle().Address.HostID
		for _, p := range request.PeerExclusion {
			if p == hostID {
				log.Debugw("bid_request_excluded_peer",
					"labels", string(observability.LabelDeployment),
					"hostID", hostID)
				continue loop
			}
		}
		constraints := v.V1.Location
		if !n.location().Satisfies(constraints) {
			log.Debugw("bid_request_location_constraints_not_satisfied",
				"labels", string(observability.LabelDeployment),
				"nodeID", v.V1.NodeID,
				"ourLocation", n.location(),
				"constraints", constraints,
			)
			continue loop
		}
		// if the desired executable is not found stop
		if !n.dmsConfig.General.ComputeGateway {
			for _, e := range v.V1.Executors {
				executor, err := n.getExecutor(e)
				if err != nil {
					log.Debugw("executor_unavailable",
						"labels", string(observability.LabelDeployment),
						"executor", e,
						"error", err)
					continue loop
				}
				if executor.executionType == jobtypes.ExecutorDocker {
					if v.V1.GeneralRequirements.PrivilegedDocker {
						if !n.dmsConfig.AllowPrivilegedDocker {
							log.Debugw("privileged_docker_not_allowed",
								"labels", string(observability.LabelDeployment))
							continue loop
						}
					}
				}
			}
		}
		if !n.dmsConfig.General.ComputeGateway {
			comparisonResult, err := machineResources.Compare(v.V1.Resources)
			if err != nil {
				log.Debugw("compare_machine_resources_error",
					"labels", string(observability.LabelDeployment),
					"error", err)
				continue loop
			}
			if comparisonResult != types.Better {
				log.Debugw("resource_not_better",
					"labels", string(observability.LabelDeployment),
					"comparisonResult", comparisonResult)
				// plain `continue` is equivalent to `continue loop` here: this
				// statement sits directly inside the labeled loop
				continue
			}
		} else {
			// make this concurrent
			// Gateway path: query every server provider concurrently for a
			// plan matching the requested resources; any single match suffices.
			foundServer := int32(0)
			allProviders := n.serverProviderRegistry.All()
			log.Debugf("server providers %d", len(allProviders))
			var wg sync.WaitGroup
			for _, pp := range allProviders {
				wg.Add(1)
				go func(pp provider.Provider) {
					defer wg.Done()
					// best-effort early exit; does not strictly prevent
					// concurrent goroutines from also querying
					if atomic.LoadInt32(&foundServer) == 1 {
						return
					}
					plans, err := pp.ListPlans(n.ctx)
					if err != nil {
						return
					}
					_, err = pp.SelectMatchingPlan(plans, v.V1.Resources)
					if err != nil {
						return
					}
					atomic.StoreInt32(&foundServer, 1)
				}(pp)
			}
			wg.Wait()
			if atomic.LoadInt32(&foundServer) == 0 {
				log.Debug("couldn't find servers to provision")
				return
			}
		}
		found = true
		toAnswer = v
		break
	}
	if !found {
		log.Debugw("bid_requirements_not_satisfied",
			"labels", string(observability.LabelDeployment))
		return
	}
	if !n.dmsConfig.General.ComputeGateway {
		if err := n.allocator.CheckAvailability(toAnswer.V1.PublicPorts.Static, toAnswer.V1.PublicPorts.Dynamic, toAnswer.V1.Resources); err != nil {
			log.Debugw("no_resource_availability_for_bid",
				"labels", string(observability.LabelDeployment),
				"nodeID", toAnswer.V1.NodeID,
				"staticPorts", toAnswer.V1.PublicPorts.Static,
				"dynamicPorts", toAnswer.V1.PublicPorts.Dynamic,
				"resources", toAnswer.V1.Resources,
				"error", err)
			return
		}
	}
	log.Debugw("signing_bid_with_node_identity",
		"labels", string(observability.LabelDeployment),
		"DID", n.actor.Security().DID())
	provider, err := n.rootCap.Trust().GetProvider(n.actor.Security().DID())
	if err != nil {
		log.Debugw("provider_retrieval_error",
			"labels", string(observability.LabelDeployment),
			"error", err)
		return
	}
	log.Debugw("signing_bid_with_provider",
		"labels", string(observability.LabelDeployment),
		"providerDID", provider.DID())
	bid := jobtypes.Bid{
		V1: &jobtypes.BidV1{
			EnsembleID: request.ID,
			NodeID:     toAnswer.V1.NodeID,
			Peer:       n.hostID,
			PubAddress: n.publicIP.String(),
			Location:   n.location(),
			Handle:     n.actor.Handle(),
		},
	}
	// indicate if its a promise bid
	if n.dmsConfig.General.ComputeGateway {
		bid.V1.PromiseBid = true
	}
	if err := bid.Sign(provider); err != nil {
		log.Debugw("bid_signing_error",
			"labels", string(observability.LabelDeployment),
			"error", err)
		return
	}
	log.Infow("sending_bid_response",
		"labels", string(observability.LabelDeployment),
		"ensembleID", request.ID,
		"nodeID", toAnswer.V1.NodeID,
		"peerID", n.hostID,
		"nonce", request.Nonce)
	n.sendReply(msg, bid)
	n.storeBid(request.ID, request.Nonce, toAnswer)
	// metric
	if m := observability.BidAccepted; m != nil {
		m.Add(n.ctx, 1, metric.WithAttributes(
			observability.AttrDID,
			attribute.String("orchestratorID", request.ID),
		))
	}
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package node
import (
"encoding/json"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/lib/did"
"gitlab.com/nunet/device-management-service/lib/ucan"
)
// CapListRequest is the request payload for the capability-list behavior.
type CapListRequest struct {
	// Context names a capability context; the handler in this file does not
	// currently read it.
	Context string
}

// CapListResponse carries the root anchors of the node's capability context,
// as returned by rootCap.ListRoots.
type CapListResponse struct {
	OK      bool           // true when listing succeeded
	Error   string         // set when the request failed
	Roots   []did.DID      // root DIDs of the capability context
	Require ucan.TokenList // root token list for the "require" slot
	Provide ucan.TokenList // root token list for the "provide" slot
	Revoke  ucan.TokenList // root token list for the "revoke" slot
}
// handleCapList serves the capability-list behavior: it decodes the request,
// reads the root anchors from the node's capability context, and replies with
// the roots plus the require/provide/revoke token lists.
func (n *Node) handleCapList(msg actor.Envelope) {
	defer msg.Discard()
	var req CapListRequest
	if err := json.Unmarshal(msg.Message, &req); err != nil {
		log.Errorf("Error handling capabilities list: %s", err)
		n.sendReply(msg, CapListResponse{Error: err.Error()})
		return
	}
	roots, require, provide, revoke := n.rootCap.ListRoots()
	n.sendReply(msg, CapListResponse{
		OK:      true,
		Roots:   roots,
		Require: require,
		Provide: provide,
		Revoke:  revoke,
	})
}
// CapRootAnchorRequest carries root DIDs to anchor in the capability context.
// NOTE(review): no handler in this fragment consumes this type — confirm its
// usage elsewhere in the package.
type CapRootAnchorRequest struct {
	DID []did.DID
}

// CapTokenAnchorRequest carries tokens to anchor as require/provide/revoke
// roots (which slot is used depends on the behavior invoked).
type CapTokenAnchorRequest struct {
	Token ucan.TokenList
}

// CapAnchorResponse is the common reply for the capability-anchor behaviors.
type CapAnchorResponse struct {
	OK    bool   // true on success
	Error string // set on failure
}
// anchorCapTokens is the shared implementation behind the provide/require/
// revoke capability-anchor behaviors. It decodes a CapTokenAnchorRequest from
// msg, hands the token list to addRoots (which places it in the appropriate
// root slot of the capability context), persists the updated context, and
// replies with a CapAnchorResponse. action is used only in log messages
// ("provide", "require" or "revoke").
func (n *Node) anchorCapTokens(msg actor.Envelope, action string, addRoots func(ucan.TokenList) error) {
	defer msg.Discard()
	handleErr := func(err error) {
		log.Errorf("Error handling %s capability anchor: %s", action, err)
		n.sendReply(msg, CapAnchorResponse{Error: err.Error()})
	}
	var request CapTokenAnchorRequest
	if err := json.Unmarshal(msg.Message, &request); err != nil {
		handleErr(err)
		return
	}
	if err := addRoots(request.Token); err != nil {
		handleErr(err)
		return
	}
	// Persist so the anchored tokens survive a restart.
	if err := SaveCapabilityContext(n.rootCap, n.fs, n.dmsConfig.UserDir); err != nil {
		handleErr(err)
		return
	}
	n.sendReply(msg, CapAnchorResponse{OK: true})
}

// handleProvideCapAnchor anchors the request's tokens as "provide" roots.
func (n *Node) handleProvideCapAnchor(msg actor.Envelope) {
	n.anchorCapTokens(msg, "provide", func(tokens ucan.TokenList) error {
		return n.rootCap.AddRoots(nil, ucan.TokenList{}, tokens, ucan.TokenList{})
	})
}

// handleRequireCapAnchor anchors the request's tokens as "require" roots.
func (n *Node) handleRequireCapAnchor(msg actor.Envelope) {
	n.anchorCapTokens(msg, "require", func(tokens ucan.TokenList) error {
		return n.rootCap.AddRoots(nil, tokens, ucan.TokenList{}, ucan.TokenList{})
	})
}

// handleRevokeCapAnchor anchors the request's tokens as "revoke" roots.
func (n *Node) handleRevokeCapAnchor(msg actor.Envelope) {
	n.anchorCapTokens(msg, "revoke", func(tokens ucan.TokenList) error {
		return n.rootCap.AddRoots(nil, ucan.TokenList{}, ucan.TokenList{}, tokens)
	})
}
// CapBroadcastRequest is the request payload for broadcasting revocation
// tokens from a capability context.
type CapBroadcastRequest struct {
	Context string // Capability context name to broadcast revocations from
}

// CapBroadcastResponse reports the outcome of a revocation broadcast.
type CapBroadcastResponse struct {
	OK          bool
	Error       string
	TokensCount int // Number of revocation tokens broadcast
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package node
import (
"context"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"math/big"
"strconv"
"strings"
"time"
"github.com/libp2p/go-libp2p/core/peer"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/dms/behaviors"
"gitlab.com/nunet/device-management-service/lib/crypto"
"gitlab.com/nunet/device-management-service/lib/did"
"gitlab.com/nunet/device-management-service/observability"
"gitlab.com/nunet/device-management-service/tokenomics"
cardanoClient "gitlab.com/nunet/device-management-service/tokenomics/client/cardano"
ethereumClient "gitlab.com/nunet/device-management-service/tokenomics/client/ethereum"
"gitlab.com/nunet/device-management-service/tokenomics/contracts"
"gitlab.com/nunet/device-management-service/tokenomics/store"
payment_quote "gitlab.com/nunet/device-management-service/tokenomics/store/payment_quote"
"gitlab.com/nunet/device-management-service/tokenomics/store/transaction"
"gitlab.com/nunet/device-management-service/types"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/metric"
)
const (
	// invokeMessageTimeout bounds ordinary behavior invocations between actors.
	invokeMessageTimeout = 20 * time.Second
	// invokeSignRequestTimeout bounds contract-sign invocations; it is much
	// longer than invokeMessageTimeout — presumably to give the remote party
	// time to respond. TODO confirm intent.
	invokeSignRequestTimeout = 2 * time.Minute
	// Supported blockchain identifiers.
	cardanoBlockchain = "CARDANO"
	ethereumBlockchain = "ETHEREUM"
)
// handleContractUsagesCalculate produces the usages and forwards them to
// payment validators. An absent message is treated as an empty request
// (process all contracts); per-result errors are concatenated into the
// top-level Error field of the reply.
func (n *Node) handleContractUsagesCalculate(msg actor.Envelope) {
	defer msg.Discard()
	var req contracts.CollectUsagesAndForwardToPaymentProvidersRequest
	if len(msg.Message) > 0 {
		if err := json.Unmarshal(msg.Message, &req); err != nil {
			n.sendReply(msg, contracts.CollectUsagesAndForwardToPaymentProvidersReponse{
				Error: fmt.Errorf("failed to unmarshal request: %w", err).Error(),
			})
			return
		}
	}
	resp := n.collectUsagesAndForwardToPaymentProviders(req)
	// Aggregate per-result errors (newline separated) into the response.
	var agg strings.Builder
	for _, result := range resp.Results {
		if result.Error != "" {
			agg.WriteString(result.Error)
			agg.WriteString("\n")
		}
	}
	if agg.Len() > 0 {
		resp.Error = agg.String()
	}
	n.sendReply(msg, resp)
}
// handleNewContract is registered on the contract host. It accepts a
// CreateContractRequest and either creates the contract locally (when this
// node is the solution enabler named in the request) or forwards the request
// to the contract host and relays the response, keeping a local copy.
func (n *Node) handleNewContract(msg actor.Envelope) {
	defer msg.Discard()
	solutionEnablerDID := ""
	// NOTE(review): the log event name "handle_contract_propose" looks like a
	// copy-paste from handleContractPropose; left unchanged in case dashboards
	// key on it — confirm and rename to "handle_new_contract" if safe.
	handleErr := func(err error) {
		log.Errorw("handle_contract_propose",
			"labels", []string{string(observability.LabelContract)},
			"error", err, "solutionEnablerDID", solutionEnablerDID)
		n.sendReply(msg, contracts.CreateContractResponse{Error: err.Error()})
	}
	var request contracts.CreateContractRequest
	if err := json.Unmarshal(msg.Message, &request); err != nil {
		// %w (was %s) so the cause is wrapped, consistent with the other
		// error paths in this handler.
		handleErr(fmt.Errorf("unmarshal create contract request: %w", err))
		return
	}
	// Validate required fields
	if err := request.Validate(); err != nil {
		handleErr(fmt.Errorf("invalid contract request: %w", err))
		return
	}
	solutionEnablerDID = request.SolutionEnablerDID.String()
	// Service provider path: forward to contract host, persist local copy, relay response.
	if !request.SolutionEnablerDID.Equal(n.actor.Handle().DID) {
		resp, err := n.forwardContractCreateToHost(request)
		if err != nil {
			handleErr(err)
			return
		}
		n.sendReply(msg, resp)
		return
	}
	creatorOfContract := msg.From
	// Contract host path: existing behaviour
	resp, err := n.createContractOnHost(request, creatorOfContract)
	if err != nil {
		handleErr(err)
		return
	}
	n.sendReply(msg, resp)
}
// createContractOnHost creates a contract with this node acting as contract
// host: it generates a dedicated keypair for the contract actor, validates the
// payment model, spawns and starts the contract actor, persists the contract,
// its usage metadata and the actor's private key, registers billing, and —
// when this node is the solution enabler — asynchronously proposes the
// contract to both participants for signing. The response carries the new
// contract DID and the hex-encoded actor public key.
func (n *Node) createContractOnHost(request contracts.CreateContractRequest, creatorOfContract actor.Handle) (contracts.CreateContractResponse, error) {
	// Each contract gets its own actor identity.
	privKey, pubKey, err := crypto.GenerateKeyPair(crypto.Ed25519)
	if err != nil {
		return contracts.CreateContractResponse{}, fmt.Errorf("failed to generate contract keypair: %w", err)
	}
	// Validate payment details before creating contract
	processor, err := contracts.GetPaymentModelProcessor(request.PaymentDetails.PaymentModel)
	if err != nil {
		return contracts.CreateContractResponse{}, fmt.Errorf("invalid payment model %q: %w", request.PaymentDetails.PaymentModel, err)
	}
	if err := processor.Validate(request.PaymentDetails); err != nil {
		return contracts.CreateContractResponse{}, fmt.Errorf("invalid payment details: %w", err)
	}
	// Create forwardInvoice function to forward invoices from contract actor to payment validator
	forwardInvoice := func(req contracts.ContractUsageRequest) error {
		destination, err := actor.HandleFromDID(request.PaymentValidatorDID.URI)
		if err != nil {
			return fmt.Errorf("failed to get payment validator handle: %w", err)
		}
		envelope, err := n.invokeBehaviour(destination, behaviors.ContractUsageBehavior, req, invokeMessageTimeout)
		if envelope.Message == nil || err != nil {
			return fmt.Errorf("failed to forward invoice to payment validator: %w", err)
		}
		log.Infof("Successfully sent invoice for contract %s to payment validator (payment_model: %s)", req.Contract.ContractDID, req.Contract.PaymentDetails.PaymentModel)
		return nil
	}
	contractActor, err := tokenomics.NewContractActor(n.actor.Handle(), request.PaymentValidatorDID, n.network, request.ContractParticipants, privKey, pubKey, n.contractStore, n.usageStore, forwardInvoice)
	if err != nil {
		return contracts.CreateContractResponse{}, fmt.Errorf("failed to create contract actor: %w", err)
	}
	contractObj := contracts.NewContract(contractActor.ContractDID.URI, request)
	// Determine if this is a Head Contract in a chain and set metadata
	// Head Contract: Provider = Organization (not a compute provider)
	// This can be detected from ContractParticipants structure or explicit flag
	// For now, we'll use a helper function that can be enhanced based on actual use cases
	if isHeadContractFromRequest(request) {
		if contractObj.Metadata == nil {
			contractObj.Metadata = make(map[string]string)
		}
		contractObj.Metadata[contracts.ContractChainRoleMetadataKey] = contracts.ContractChainRoleHead
	}
	if err := n.contractStore.Upsert(contractObj); err != nil {
		return contracts.CreateContractResponse{}, fmt.Errorf("failed to save contract: %w", err)
	}
	// Initialize usage metadata for this contract
	if err := n.usageStore.InitializeContractMetadata(contractActor.ContractDID.URI); err != nil {
		return contracts.CreateContractResponse{}, fmt.Errorf("failed to initialize usage metadata for contract %s: %w", contractActor.ContractDID.URI, err)
	}
	pkBytes, err := crypto.PublicKeyToBytes(pubKey)
	if err != nil {
		return contracts.CreateContractResponse{}, fmt.Errorf("failed to convert public key to bytes: %w", err)
	}
	privKeyBytes, err := crypto.PrivateKeyToBytes(privKey)
	if err != nil {
		return contracts.CreateContractResponse{}, fmt.Errorf("failed to convert private key to bytes: %w", err)
	}
	// Persist the actor's private key so the contract actor can be restarted
	// later (see StartContracts).
	if err := n.contractStore.InsertContractKey(store.ContractKey{
		ContractDID: contractActor.ContractDID.URI,
		Key:         privKeyBytes,
	}); err != nil {
		return contracts.CreateContractResponse{}, fmt.Errorf("failed to save actor private key for contract %s: %w", contractActor.ContractDID.URI, err)
	}
	if err := contractActor.Start(); err != nil {
		return contracts.CreateContractResponse{}, fmt.Errorf("failed to start actor: %w", err)
	}
	// Register with billing scheduler
	// NOTE: This is safe even if the contract is later loaded by StartContracts()
	// because RegisterContract() is idempotent and checks for existing registration
	if n.billingScheduler != nil {
		err = contractActor.RegisterBilling(n.billingScheduler)
		if err != nil {
			log.Warnw("failed to register contract for billing",
				"contract_did", contractActor.ContractDID.URI,
				"error", err)
			// Don't fail contract creation if billing registration fails
		}
	}
	// Store actor reference in map for O(1) lookup
	n.addContractActor(contractActor)
	// if solution enabler, propose to parties
	if request.SolutionEnablerDID.Equal(n.actor.Handle().DID) {
		go func() {
			sigs, err := n.proposeContract(contractActor.ContractDID.URI, creatorOfContract)
			if err != nil {
				// %v, not %w: only fmt.Errorf supports %w; in log.Errorf it
				// rendered as "%!w(...)" noise.
				log.Errorf("failed to propose contract: %v", err)
				return
			}
			contractObj, err := n.contractStore.GetContract(contractActor.ContractDID.URI)
			if err != nil {
				log.Errorf("failed to get contract: %v", err)
				return
			}
			contractObj.Signatures = sigs
			// Both participants signed: promote the contract to accepted.
			if len(contractObj.Signatures) == 2 {
				contractObj.CurrentState = contracts.ContractAccepted
				contractObj.Transitions = []contracts.StateTransition{
					{
						FromState:   contracts.ContractDraft,
						ToState:     contracts.ContractAccepted,
						Timestamp:   time.Now(),
						Event:       contracts.EventAccepted,
						InitiatedBy: n.actor.Handle().DID,
					},
				}
			}
			if err := n.contractStore.Upsert(contractObj); err != nil {
				log.Errorf("failed to update contract with signatures: %v", err)
			}
		}()
	}
	return contracts.CreateContractResponse{
		ContractRequest: request,
		ContractDID:     contractActor.ContractDID.URI,
		PubKey:          hex.EncodeToString(pkBytes),
	}, nil
}
// forwardContractCreateToHost relays a contract-creation request to the
// solution enabler (the contract host), validates the reply, and stores a
// local copy of the resulting contract before returning the host's response.
func (n *Node) forwardContractCreateToHost(request contracts.CreateContractRequest) (contracts.CreateContractResponse, error) {
	var resp contracts.CreateContractResponse
	if request.SolutionEnablerDID.Empty() {
		return resp, errors.New("solution enabler DID is empty")
	}
	hostHandle, err := actor.HandleFromDID(request.SolutionEnablerDID.String())
	if err != nil {
		return resp, fmt.Errorf("failed to resolve contract host handle: %w", err)
	}
	reply, err := n.invokeBehaviour(hostHandle, behaviors.ContractCreateBehavior, request, invokeMessageTimeout)
	switch {
	case err != nil:
		return resp, fmt.Errorf("failed to forward create contract request to host: %w", err)
	case reply.Message == nil:
		return resp, errors.New("contract host returned empty response")
	}
	if err := json.Unmarshal(reply.Message, &resp); err != nil {
		return resp, fmt.Errorf("failed to unmarshal contract host response: %w", err)
	}
	if resp.Error != "" {
		return resp, fmt.Errorf("contract host error: %s", resp.Error)
	}
	// Keep a local record of the contract the host created.
	if err := n.contractStore.Upsert(contracts.NewContract(resp.ContractDID, resp.ContractRequest)); err != nil {
		return resp, fmt.Errorf("failed to save local contract copy: %w", err)
	}
	return resp, nil
}
// this behaviour is registered by service and compute provider
// its used to sign contracts and send back the response to solution enabler.
//
// handleContractPropose decides whether to sign an incoming contract proposal:
//   - if the sender's creator handle matches this node's provider DID AND the
//     provider is a participant (requestor or provider), the contract is
//     signed immediately and persisted;
//   - otherwise the contract must already exist locally in a non-draft state;
//     on first sight the contract is only stored and an error is returned —
//     presumably so signing happens after local approval
//     (see handleContractApprovalLocal). TODO confirm intended flow.
func (n *Node) handleContractPropose(msg actor.Envelope) {
	defer msg.Discard()
	contractID := ""
	// handleErr logs with the contract ID (when known) and replies with the error.
	handleErr := func(err error) {
		log.Errorw("handle_contract_propose",
			"labels", []string{string(observability.LabelContract)},
			"error", err, "contractID", contractID)
		n.sendReply(msg, contracts.ProposeContractResponse{Error: err.Error()})
	}
	var request contracts.ProposeContractRequest
	if err := json.Unmarshal(msg.Message, &request); err != nil {
		handleErr(fmt.Errorf("failed to unmarshal contract: %s", err))
		return
	}
	incomingContract := request.Contract
	contractID = request.Contract.ContractDID
	provider, err := n.rootCap.Trust().GetProvider(n.actor.Security().DID())
	if err != nil {
		handleErr(fmt.Errorf("failed to get provider: %w", err))
		return
	}
	providerDID := provider.DID()
	isCreator := request.CreatorOfContract.DID.Equal(providerDID)
	isRequestor := providerDID.Equal(incomingContract.ContractParticipants.Requestor)
	isProvider := providerDID.Equal(incomingContract.ContractParticipants.Provider)
	// Auto-sign if creator and is a participant
	if isCreator && (isRequestor || isProvider) {
		log.Infof("automatically signing contract for creator/requestor: %s", providerDID.URI)
		sig, err := incomingContract.Sign(provider)
		if err != nil {
			handleErr(fmt.Errorf("failed to sign proposed contract: %w", err))
			return
		}
		err = n.contractStore.Upsert(&incomingContract)
		if err != nil {
			handleErr(fmt.Errorf("failed to save proposed contract in the db: %w", err))
			return
		}
		n.sendReply(msg, contracts.ProposeContractResponse{
			Signature: contracts.Signature{
				DID:        providerDID,
				Signatures: sig,
			},
		})
		return
	}
	// determine service or compute provider
	// if compute provider make sure to check if its approved
	savedContract, err := n.contractStore.GetContract(incomingContract.ContractDID)
	if err != nil {
		// no contract, save it in the local store
		err := n.contractStore.Upsert(&incomingContract)
		if err != nil {
			handleErr(fmt.Errorf("failed to save contract in the db: %w", err))
			return
		}
		// Deliberate error reply: the contract is now stored locally but not
		// yet approved/signed by this node.
		err = errors.New("contract is not signed yet")
		handleErr(err)
		return
	}
	// A locally known contract still in draft state has not been approved.
	if savedContract.CurrentState == contracts.ContractDraft {
		err := errors.New("contract is not signed")
		handleErr(err)
		return
	}
	sig, err := incomingContract.Sign(provider)
	if err != nil {
		handleErr(fmt.Errorf("failed to sign contract: %w", err))
		return
	}
	n.sendReply(msg, contracts.ProposeContractResponse{
		Signature: contracts.Signature{
			DID:        provider.DID(),
			Signatures: sig,
		},
	})
}
// proposeContract sends the stored contract to both participants (provider
// first, then requestor) for signing and returns whatever signatures were
// collected. A participant that fails to respond or returns an error is
// silently skipped; the caller decides what an incomplete signature set means.
func (n *Node) proposeContract(contractDID string, creatorOfContract actor.Handle) ([]contracts.Signature, error) {
	contractObj, err := n.contractStore.GetContract(contractDID)
	if err != nil {
		return nil, fmt.Errorf("failed to find contract %s: %w", contractDID, err)
	}
	providerHandle, err := actor.HandleFromDID(contractObj.ContractParticipants.Provider.URI)
	if err != nil {
		return nil, fmt.Errorf("failed to get provider's DID: %w", err)
	}
	requesterHandle, err := actor.HandleFromDID(contractObj.ContractParticipants.Requestor.URI)
	if err != nil {
		return nil, fmt.Errorf("failed to get requester's DID: %w", err)
	}
	// askForSignature invokes the propose behavior on one participant and
	// extracts the signature from a well-formed, error-free reply.
	askForSignature := func(target actor.Handle) (*contracts.Signature, error) {
		reply, invokeErr := n.invokeBehaviour(target, behaviors.ContractProposeBehavior, contracts.ProposeContractRequest{
			Contract:          *contractObj,
			CreatorOfContract: creatorOfContract,
		}, invokeMessageTimeout)
		if invokeErr == nil && reply.Message != nil {
			var response contracts.ProposeContractResponse
			if decodeErr := json.Unmarshal(reply.Message, &response); decodeErr == nil && response.Error == "" {
				return &response.Signature, nil
			}
		}
		return nil, fmt.Errorf("failed to get back response from %s", target.DID)
	}
	sigs := make([]contracts.Signature, 0)
	for _, participant := range []actor.Handle{providerHandle, requesterHandle} {
		if sig, askErr := askForSignature(participant); askErr == nil {
			sigs = append(sigs, *sig)
		}
	}
	return sigs, nil
}
// compute provider handles approvals
//
// handleContractApprovalLocal approves a previously stored contract on the
// compute provider: it flips the local copy to ContractAccepted, signs the
// contract with this node's provider identity, and submits the signature to
// the contract actor, which is addressed by the contract DID with the
// solution enabler's peer ID as inbox address.
func (n *Node) handleContractApprovalLocal(msg actor.Envelope) {
	defer msg.Discard()
	contractID := ""
	// handleErr logs with the contract ID (when known) and replies with the error.
	handleErr := func(err error) {
		log.Errorw("handle_contract_approval_local",
			"labels", []string{string(observability.LabelContract)},
			"error", err, "contractID", contractID)
		n.sendReply(msg, contracts.ContractApproveLocalResponse{Error: err.Error()})
	}
	var req contracts.ContractApproveLocalRequest
	if err := json.Unmarshal(msg.Message, &req); err != nil {
		handleErr(fmt.Errorf("failed to unmarshal contract approve request: %s", err))
		return
	}
	savedContract, err := n.contractStore.GetContract(req.ContractDID)
	if err != nil {
		handleErr(fmt.Errorf("failed to get contract: %w", err))
		return
	}
	// NOTE(review): local state is set to accepted and persisted before the
	// remote sign call below; if that call fails, the local copy stays
	// accepted — confirm this is intended.
	savedContract.CurrentState = contracts.ContractAccepted
	err = n.contractStore.Upsert(savedContract)
	if err != nil {
		handleErr(fmt.Errorf("failed to update contract: %w", err))
		return
	}
	contractID = savedContract.ContractDID
	// sign the contract and send it to the contract host
	contractDID, err := did.FromString(savedContract.ContractDID)
	if err != nil {
		handleErr(fmt.Errorf("failed to convert contract did: %w", err))
		return
	}
	// The contract actor's public key is recoverable from the contract DID.
	pubKey, err := did.PublicKeyFromDID(contractDID)
	if err != nil {
		handleErr(fmt.Errorf("failed to get public key from contract host did: %w", err))
		return
	}
	pubKeySolutionEnabler, err := did.PublicKeyFromDID(savedContract.SolutionEnablerDID)
	if err != nil {
		handleErr(fmt.Errorf("failed to get public key: %w", err))
		return
	}
	// The solution enabler's peer ID acts as the network inbox address for
	// the contract actor.
	soltionEnablerPeerID, err := peer.IDFromPublicKey(pubKeySolutionEnabler)
	if err != nil {
		handleErr(fmt.Errorf("failed to peer id from public key: %w", err))
		return
	}
	destination, err := actor.HandleFromPublicKeyWithInboxAddress(pubKey, savedContract.ContractDID, soltionEnablerPeerID.String())
	if err != nil {
		handleErr(fmt.Errorf("failed to get get contract host handle: %w", err))
		return
	}
	provider, err := n.rootCap.Trust().GetProvider(n.actor.Security().DID())
	if err != nil {
		handleErr(fmt.Errorf("failed to get provider: %w", err))
		return
	}
	sig, err := savedContract.Sign(provider)
	if err != nil {
		handleErr(fmt.Errorf("failed to sign contract: %w", err))
		return
	}
	signReq := contracts.ContractSignRequest{
		ContractDID: savedContract.ContractDID,
		Signature:   sig,
	}
	// Uses the longer sign-request timeout (invokeSignRequestTimeout).
	reply, err := n.invokeBehaviour(destination, behaviors.ContractSignBehavior, signReq, invokeSignRequestTimeout)
	if err != nil {
		handleErr(fmt.Errorf("failed to get invoke sign contract on contract host: %w", err))
		return
	}
	var signResp contracts.ContractSignResponse
	if err := json.Unmarshal(reply.Message, &signResp); err != nil {
		handleErr(fmt.Errorf("failed to unmarshal contract host response: %w", err))
		return
	}
	if signResp.Error != "" {
		handleErr(fmt.Errorf("error from contract host: %s", signResp.Error))
		return
	}
	n.sendReply(msg, contracts.ContractApproveLocalResponse{
		Success: true,
	})
}
// compute provider can list incoming contracts for approval
//
// handleListIncomingContracts answers the contract-list behavior. A regular
// caller gets the locally stored contracts filtered by the requested role.
// When the caller is this node's root DID, the handler additionally fans out
// to every known solution enabler host and aggregates the results,
// deduplicated by contract DID.
func (n *Node) handleListIncomingContracts(msg actor.Envelope) {
	defer msg.Discard()
	contractID := ""
	handleErr := func(err error) {
		log.Errorw("handle_list_incoming_contracts",
			"labels", []string{string(observability.LabelContract)},
			"error", err, "contractID", contractID)
		n.sendReply(msg, contracts.ContractListIncomingResponse{Error: err.Error()})
	}
	var req contracts.ContractListIncomingRequest
	if err := json.Unmarshal(msg.Message, &req); err != nil {
		handleErr(fmt.Errorf("failed to unmarshal list incoming request: %w", err))
		return
	}
	allContracts, err := n.contractStore.GetAllContracts()
	if err != nil {
		handleErr(fmt.Errorf("failed to get all contracts: %w", err))
		return
	}
	callerDID := msg.From.DID.String()
	rootDID := n.rootCap.DID().String()
	filteredLocal := filterContractsByRole(allContracts, req.Role, callerDID)
	if callerDID == rootDID {
		solutionHosts := uniqueSolutionEnablerDIDs(allContracts)
		if len(solutionHosts) == 0 {
			log.Warnf("no solution hosts found (i.e: no contracts created yet) for caller %s", callerDID)
			handleErr(fmt.Errorf("no solution hosts found to retrieve contracts from for caller %s", callerDID))
			return
		}
		aggregated := make(map[string]*contracts.Contract, len(filteredLocal))
		for _, hostDID := range solutionHosts {
			if hostDID == "" {
				continue
			}
			// if the solution enabler is this node, use local data
			if hostDID == rootDID {
				// Reuse the filter computed above instead of recomputing it
				// with identical arguments.
				for _, c := range filteredLocal {
					aggregated[c.ContractDID] = c
				}
				continue
			}
			handle, err := actor.HandleFromDID(hostDID)
			if err != nil {
				log.Warnf("failed to build handle for host %s: %v", hostDID, err)
				continue
			}
			reply, err := n.invokeBehaviour(handle, behaviors.ContractListBehavior, req, invokeMessageTimeout)
			if err != nil || reply.Message == nil {
				log.Warnf("failed to invoke list incoming on host %s: %v", hostDID, err)
				continue
			}
			var remoteResp contracts.ContractListIncomingResponse
			if err := json.Unmarshal(reply.Message, &remoteResp); err != nil {
				log.Warnf("failed to decode contract host response %s: %v", hostDID, err)
				continue
			}
			if remoteResp.Error != "" {
				log.Warnf("host %s returned error listing incoming contracts: %s", hostDID, remoteResp.Error)
				continue
			}
			for _, c := range remoteResp.Contracts {
				aggregated[c.ContractDID] = c
			}
		}
		contractsSlice := make([]*contracts.Contract, 0, len(aggregated))
		for _, c := range aggregated {
			contractsSlice = append(contractsSlice, c)
		}
		n.sendReply(msg, contracts.ContractListIncomingResponse{
			Contracts: contractsSlice,
		})
		return
	}
	// Contract host invocation: respond with local contracts only
	resp := contracts.ContractListIncomingResponse{
		Contracts: filteredLocal,
	}
	n.sendReply(msg, resp)
}
// handleContractInfo handles requests for contract information.
// This is used by deployment handlers to retrieve Provider and Requestor DIDs
// from the contract host when they're not specified in the ensemble.
func (n *Node) handleContractInfo(msg actor.Envelope) {
	defer msg.Discard()
	fail := func(err error) {
		log.Errorw("handle_contract_info",
			"labels", []string{string(observability.LabelContract)},
			"error", err)
		n.sendReply(msg, behaviors.ContractInfoResponse{
			OK:    false,
			Error: err.Error(),
		})
	}
	var req behaviors.ContractInfoRequest
	if err := json.Unmarshal(msg.Message, &req); err != nil {
		fail(fmt.Errorf("failed to unmarshal contract info request: %w", err))
		return
	}
	if req.ContractDID == "" {
		fail(errors.New("contract_did is required"))
		return
	}
	// Look the contract up locally and reply with its participant DIDs.
	contract, err := n.contractStore.GetContract(req.ContractDID)
	if err != nil {
		fail(fmt.Errorf("contract not found: %w", err))
		return
	}
	n.sendReply(msg, behaviors.ContractInfoResponse{
		OK:        true,
		Provider:  contract.ContractParticipants.Provider.String(),
		Requestor: contract.ContractParticipants.Requestor.String(),
	})
}
// filterContractsByRole returns the subset of contractsList whose participant
// in the given role matches targetDID. An empty targetDID matches everything.
// For an unrecognized role, any of the solution enabler, provider, or
// requestor DIDs may match.
func filterContractsByRole(contractsList []*contracts.Contract, role contracts.ContractListIncomingRole, targetDID string) []*contracts.Contract {
	matches := func(c *contracts.Contract) bool {
		if targetDID == "" {
			return true
		}
		switch role {
		case contracts.ContractRoleProvider:
			return c.ContractParticipants.Provider.String() == targetDID
		case contracts.ContractRoleRequestor:
			return c.ContractParticipants.Requestor.String() == targetDID
		default:
			return c.SolutionEnablerDID.String() == targetDID ||
				c.ContractParticipants.Provider.String() == targetDID ||
				c.ContractParticipants.Requestor.String() == targetDID
		}
	}

	filtered := make([]*contracts.Contract, 0, len(contractsList))
	for _, c := range contractsList {
		if matches(c) {
			filtered = append(filtered, c)
		}
	}
	return filtered
}
// uniqueSolutionEnablerDIDs collects the distinct, non-empty solution enabler
// DIDs found across contractsList.
func uniqueSolutionEnablerDIDs(contractsList []*contracts.Contract) []string {
	seen := make(map[string]struct{}, len(contractsList))
	hosts := make([]string, 0, len(contractsList))
	for _, c := range contractsList {
		host := c.SolutionEnablerDID.String()
		if host == "" {
			continue
		}
		if _, dup := seen[host]; dup {
			continue
		}
		seen[host] = struct{}{}
		hosts = append(hosts, host)
	}
	return hosts
}
// StartContracts loads every persisted contract from the store, recreates its
// contract actor, starts it, and registers it with the billing scheduler.
// Contracts whose key or actor cannot be restored are skipped with a warning
// so one bad contract does not prevent the rest from starting.
func (n *Node) StartContracts() error {
	allContracts, err := n.contractStore.GetAllContracts()
	if err != nil {
		return fmt.Errorf("failed to start contracts: %w", err)
	}
	for _, v := range allContracts {
		v := v // per-iteration copy: forwardInvoice below outlives this iteration
		key, err := n.contractStore.GetContractKey(v.ContractDID)
		if err != nil {
			log.Warnf("failed to get contract %s private key: %v", v.ContractDID, err)
			continue
		}
		privKey, err := crypto.BytesToPrivateKey(key.Key)
		if err != nil {
			log.Warnf("failed to decode contract %s private key: %v", v.ContractDID, err)
			continue
		}
		pubKey := privKey.GetPublic()
		// forwardInvoice forwards invoices from the contract actor to the
		// payment validator referenced by the contract.
		forwardInvoice := func(req contracts.ContractUsageRequest) error {
			destination, err := actor.HandleFromDID(v.PaymentValidatorDID.URI)
			if err != nil {
				return fmt.Errorf("failed to get payment validator handle: %w", err)
			}
			envelope, err := n.invokeBehaviour(destination, behaviors.ContractUsageBehavior, req, invokeMessageTimeout)
			if err != nil {
				return fmt.Errorf("failed to forward invoice to payment validator: %w", err)
			}
			// Separate check: previously a nil err was wrapped with %w when
			// only the reply message was missing.
			if envelope.Message == nil {
				return errors.New("failed to forward invoice to payment validator: empty reply")
			}
			log.Infof("Successfully sent invoice for contract %s to payment validator (payment_model: %s)", req.Contract.ContractDID, req.Contract.PaymentDetails.PaymentModel)
			return nil
		}
		contractActor, err := tokenomics.NewContractActor(n.actor.Handle(), v.PaymentValidatorDID, n.network, v.ContractParticipants, privKey, pubKey, n.contractStore, n.usageStore, forwardInvoice)
		if err != nil {
			log.Warnf("failed to create contract actor for %s: %v", v.ContractDID, err)
			continue
		}
		if err := contractActor.Start(); err != nil {
			log.Warnf("failed to start contract actor for %s: %v", v.ContractDID, err)
			continue
		}
		// Register with the billing scheduler.
		// NOTE: This is idempotent - if a contract was already registered
		// (e.g. created while the node was running), this will be a no-op.
		if n.billingScheduler != nil {
			if err := contractActor.RegisterBilling(n.billingScheduler); err != nil {
				log.Warnw("failed to register contract for billing",
					"contract_did", contractActor.ContractDID.URI,
					"error", err)
			}
		}
		// Store actor reference in map for O(1) lookup.
		n.addContractActor(contractActor)
	}
	return nil
}
// payment validator to accept validation requests
//
// handleContractPaymentValidationRequestFromContractHost verifies on-chain that
// the payment identified by req.UniqueID was actually made. It determines the
// expected amount (a used quote's converted amount takes precedence over the
// stored payment amount), scans recent transfers on the requested blockchain
// (Ethereum ERC20 or Cardano) for req.TxHash, checks the sender address and
// amount, and on success marks the payment as paid.
func (n *Node) handleContractPaymentValidationRequestFromContractHost(msg actor.Envelope) {
	defer msg.Discard()
	contractID := ""
	handleErr := func(err error) {
		log.Errorw("handle_contract_payment_validation_request_from_contract_host",
			"labels", []string{string(observability.LabelContract)},
			"error", err, "contractID", contractID)
		n.sendReply(msg, contracts.ContractPaymentValidationResponse{Error: err.Error()})
	}
	var req contracts.ContractPaymentValidationRequest
	if err := json.Unmarshal(msg.Message, &req); err != nil {
		handleErr(fmt.Errorf("failed to unmarshal payment validation request: %s", err))
		return
	}
	payment, err := n.paymentStore.GetByUniqueID(req.UniqueID)
	if err != nil {
		handleErr(fmt.Errorf("failed to find payment with unique id: %s", req.UniqueID))
		return
	}
	// Determine the expected amount. A used quote pins the converted amount;
	// in every other case fall back to the payment's stored amount.
	var expectedAmount string
	if req.QuoteID != "" {
		// If quote_id is provided, get quote by ID
		quote, err := n.paymentQuoteStore.GetQuote(req.QuoteID)
		if err != nil {
			handleErr(fmt.Errorf("failed to get quote: %w", err))
			return
		}
		if quote.Used {
			expectedAmount = quote.ConvertedAmount
		} else {
			// Fallback to payment amount if quote not used
			expectedAmount = payment.Amount
		}
	} else {
		// Try to find a used quote by unique_id; a lookup failure simply
		// means no conversion quote was involved.
		quote, _ := n.paymentQuoteStore.GetQuoteByUniqueID(req.UniqueID)
		if quote != nil && quote.Used {
			expectedAmount = quote.ConvertedAmount
		} else {
			// Fallback to payment amount (for backward compatibility).
			// BUG FIX: previously a found-but-unused quote left
			// expectedAmount empty, which made verification always fail.
			expectedAmount = payment.Amount
		}
	}
	verified := false
	errorMsg := ""
	contractID = payment.Contract.ContractDID
	switch req.Blockchain {
	case ethereumBlockchain:
		// Locate the ethereum address pair recorded on the contract.
		ethAddr := types.PaymentAddressInfo{}
		foundEthAddr := false
		for _, v := range payment.Contract.PaymentDetails.Addresses {
			if v.Blockchain == ethereumBlockchain {
				ethAddr = v
				foundEthAddr = true
				break
			}
		}
		if !foundEthAddr {
			// BUG FIX: previously wrapped a nil err with %w here.
			handleErr(errors.New("ethereum address was not found in payment addresses"))
			return
		}
		c := ethereumClient.NewClient(
			n.dmsConfig.PaymentProvider.EthereumRPCURL,
			n.dmsConfig.PaymentProvider.EthereumRPCToken,
		)
		blockNum, err := ethereumClient.GetBlockNumber(c)
		if err != nil {
			handleErr(fmt.Errorf("failed to get block number: %w", err))
			return
		}
		// deduct some block numbers: only scan roughly the last 5 days
		blockNum -= 45000
		blockNumHex := fmt.Sprintf("0x%x", blockNum)
		txs, err := ethereumClient.GetERC20Transfers(
			c,
			n.dmsConfig.PaymentProvider.NtxContractAddress,
			ethAddr.ProviderAddr,
			blockNumHex,
			"latest",
		)
		if err != nil {
			handleErr(fmt.Errorf("failed to get erc20 transfer: %w", err))
			return
		}
		for _, tx := range txs {
			if strings.EqualFold(tx.TxHash, req.TxHash) {
				if !strings.EqualFold(tx.From, ethAddr.RequesterAddr) {
					handleErr(fmt.Errorf("requester transaction address %s doesn't match the one in transaction: %s", ethAddr.RequesterAddr, tx.From))
					return
				}
				ok, err := compareDecimals(tx.Amount, expectedAmount)
				if err != nil {
					errorMsg = err.Error() + " tx amount: " + tx.Amount + " expected amount: " + expectedAmount
				}
				if ok {
					verified = true
				} else {
					errorMsg = "not verified: tx amount: " + tx.Amount + " expected amount: " + expectedAmount
				}
				break
			}
		}
	case cardanoBlockchain:
		// Locate the cardano address pair recorded on the contract.
		cardanoAddr := types.PaymentAddressInfo{}
		foundCardanoAddr := false
		for _, v := range payment.Contract.PaymentDetails.Addresses {
			if v.Blockchain == cardanoBlockchain {
				cardanoAddr = v
				foundCardanoAddr = true
				break
			}
		}
		if !foundCardanoAddr {
			// BUG FIX: previously wrapped a nil err with %w here.
			handleErr(errors.New("cardano address was not found in payment addresses"))
			return
		}
		client := cardanoClient.NewClient(
			n.dmsConfig.PaymentProvider.BlockFrostAPIKey,
			n.dmsConfig.PaymentProvider.BlockFrostAPIURL,
		)
		// Asset unit = policy ID + hex-encoded asset name.
		asset := n.dmsConfig.PaymentProvider.CardanoAssetPolicyID + hex.EncodeToString([]byte(n.dmsConfig.PaymentProvider.CardanoAssetName))
		txs, err := client.FindTxsToAddressForAsset(n.ctx, asset, cardanoAddr.ProviderAddr)
		if err != nil {
			handleErr(fmt.Errorf("failed to get cardano transactions: %w", err))
			return
		}
		for _, tx := range txs {
			if strings.EqualFold(tx.TxHash, req.TxHash) {
				foundFrom := false
				for _, v := range tx.FromAddrs {
					if v == cardanoAddr.RequesterAddr {
						foundFrom = true
						break
					}
				}
				if !foundFrom {
					handleErr(fmt.Errorf("requester transaction address not found: %s", cardanoAddr.RequesterAddr))
					return
				}
				ok, err := compareDecimals(tx.Quantity, expectedAmount)
				if err != nil {
					errorMsg = err.Error() + " tx amount: " + tx.Quantity + " expected amount: " + expectedAmount
				}
				if ok {
					verified = true
				} else {
					errorMsg = "not verified: tx amount: " + tx.Quantity + " expected amount: " + expectedAmount
				}
				break
			}
		}
	default:
		handleErr(fmt.Errorf("unsupported blockchain payment info: %s", req.Blockchain))
		return
	}
	// Persist the result: mark paid on success, otherwise report why not.
	resp := contracts.ContractPaymentValidationResponse{}
	if verified {
		payment.Paid = true
		if err := n.paymentStore.Update(payment); err != nil {
			resp.Error = err.Error()
		}
	} else {
		if errorMsg != "" {
			resp.Error = errorMsg
		} else {
			resp.Error = "not verified"
		}
	}
	n.sendReply(msg, resp)
}
// handleConfirmLocalTransaction confirms a locally recorded transaction:
// it validates an optional payment quote (marking it used), asks the payment
// validator to verify the payment on-chain, marks the transaction as paid,
// forwards the paid transaction to the compute provider, and records metrics.
func (n *Node) handleConfirmLocalTransaction(msg actor.Envelope) {
	defer msg.Discard()
	contractID := ""
	handleErr := func(err error) {
		log.Errorw("handle_confirm_local_transaction",
			"labels", []string{string(observability.LabelContract)},
			"error", err, "contractID", contractID)
		n.sendReply(msg, contracts.ContractConfirmLocalTransactionResponse{Error: err.Error()})
	}
	var req contracts.ContractConfirmLocalTransactionRequest
	if err := json.Unmarshal(msg.Message, &req); err != nil {
		handleErr(fmt.Errorf("failed to unmarshal incoming transaction confirm request: %s", err))
		return
	}
	paymentProviderDID, err := n.transactionStore.GetPaymentValidatorDID(req.UniqueID)
	if err != nil {
		handleErr(fmt.Errorf("failed to get payment validator did: %s", err))
		return
	}
	contractID = paymentProviderDID
	// If quote_id is provided, validate the quote and mark it as used so the
	// quoted converted amount becomes binding for this payment.
	if req.QuoteID != "" {
		// Validate quote is still valid (not expired, not used)
		quote, err := n.paymentQuoteStore.ValidateQuote(req.QuoteID)
		if err != nil {
			handleErr(fmt.Errorf("quote validation failed: %w", err))
			return
		}
		if quote.UniqueID != req.UniqueID {
			handleErr(errors.New("quote does not match transaction"))
			return
		}
		// Mark quote as used
		if err := n.paymentQuoteStore.MarkQuoteAsUsed(req.QuoteID); err != nil {
			handleErr(fmt.Errorf("failed to mark quote as used: %w", err))
			return
		}
	}
	paymentValidationReq := contracts.ContractPaymentValidationRequest{
		TxHash:     req.TxHash,
		UniqueID:   req.UniqueID,
		Blockchain: req.Blockchain,
		QuoteID:    req.QuoteID,
	}
	paymentProvider, err := actor.HandleFromDID(paymentProviderDID)
	if err != nil {
		// BUG FIX: error message previously read "hande".
		handleErr(fmt.Errorf("failed to get payment provider handle: %w", err))
		return
	}
	reply, err := n.invokeBehaviour(paymentProvider, behaviors.ContractPaymentValidationRequestBehavior, paymentValidationReq, invokeMessageTimeout)
	if err != nil {
		handleErr(fmt.Errorf("failed to send transaction confirmation to payment provider: %w", err))
		return
	}
	var replyResponse contracts.ContractPaymentValidationResponse
	if err := json.Unmarshal(reply.Message, &replyResponse); err != nil {
		// Best-effort decode (previous behavior), but surface the failure.
		log.Warnf("failed to decode payment validation response: %v", err)
	}
	if replyResponse.Error != "" {
		handleErr(fmt.Errorf("payment validation response from payment provider: %s", replyResponse.Error))
		return
	}
	if _, err := n.transactionStore.MarkAsPaid(req.UniqueID, req.TxHash); err != nil {
		handleErr(fmt.Errorf("failed to mark transaction as paid: %s", err))
		return
	}
	// Forward paid transaction to compute provider
	tx, err := n.transactionStore.GetTransactionByUniqueID(req.UniqueID)
	if err != nil {
		handleErr(fmt.Errorf("failed to get transaction: %w", err))
		return
	}
	contract, err := n.contractStore.GetContract(tx.ContractDID)
	if err != nil {
		handleErr(fmt.Errorf("failed to get contract: %w", err))
		return
	}
	// Send paid transaction to compute provider
	computeProviderTxReq := contracts.TransactionForServiceProviderRequest{
		PaymentValidatorDID: tx.PaymentValidatorDID,
		UniqueID:            tx.UniqueID,
		ContractDID:         tx.ContractDID,
		ToAddress:           tx.ToAddress,
		Amount:              tx.Amount,
		Status:              "paid", // Mark as paid
		TxHash:              req.TxHash,
		Metadata:            tx.Metadata,
	}
	destination, err := actor.HandleFromDID(contract.ContractParticipants.Provider.URI)
	if err != nil {
		handleErr(fmt.Errorf("failed to get destination handle: %w", err))
		return
	}
	_, err = n.invokeBehaviour(
		destination,
		behaviors.ContractTransactionBehavior,
		computeProviderTxReq,
		invokeMessageTimeout,
	)
	if err != nil {
		handleErr(fmt.Errorf("failed to forward paid transaction to compute provider: %w", err))
		return
	}
	// if the transaction is an orchestration fee, add it to the orchestration fee metric
	if feeType, ok := tx.Metadata["fee_type"].(string); ok && feeType == "orchestration" {
		if m := observability.TxPaidFeesAmount; m != nil {
			if amount, err := strconv.ParseFloat(tx.Amount, 64); err == nil {
				m.Add(n.ctx, amount, metric.WithAttributes(
					observability.AttrDID,
					attribute.String("ContractDID", tx.ContractDID),
				))
			}
		}
	}
	// total paid amount metric
	if m := observability.TxPaidAmount; m != nil {
		if amount, err := strconv.ParseFloat(tx.Amount, 64); err == nil {
			m.Add(n.ctx, amount, metric.WithAttributes(
				observability.AttrDID,
				attribute.String("ContractDID", tx.ContractDID),
			))
		}
	}
	log.Infof("successfully forwarded paid transaction %s to compute provider", req.UniqueID)
	n.sendReply(msg, contracts.ContractConfirmLocalTransactionResponse{})
}
// handleListLocalTransactions replies with every transaction known to the
// local transaction store.
func (n *Node) handleListLocalTransactions(msg actor.Envelope) {
	defer msg.Discard()
	replyErr := func(err error) {
		log.Errorw("handle_list_local_transactions",
			"labels", []string{string(observability.LabelContract)},
			"error", err)
		n.sendReply(msg, contracts.ContractListLocalTransactionsResponse{Error: err.Error()})
	}

	txs, err := n.transactionStore.AllTransactions()
	if err != nil {
		replyErr(fmt.Errorf("failed to get local transactions: %s", err))
		return
	}
	n.sendReply(msg, contracts.ContractListLocalTransactionsResponse{
		Transactions: txs,
	})
}
// handlePaymentStatus reports whether the payment identified by the request's
// unique ID has been marked as paid.
func (n *Node) handlePaymentStatus(msg actor.Envelope) {
	defer msg.Discard()
	handleErr := func(err error) {
		// Log for observability, consistent with the other contract handlers
		// (this handler previously replied without logging).
		log.Errorw("handle_payment_status",
			"labels", []string{string(observability.LabelContract)},
			"error", err)
		n.sendReply(msg, contracts.ContractPaymentStatusResponse{Error: err.Error()})
	}
	var req contracts.ContractPaymentStatusRequest
	if err := json.Unmarshal(msg.Message, &req); err != nil {
		handleErr(fmt.Errorf("failed to unmarshal get payment request: %s", err))
		return
	}
	p, err := n.paymentStore.GetByUniqueID(req.UniqueID)
	if err != nil {
		handleErr(fmt.Errorf("failed to get payment: %s", err))
		return
	}
	n.sendReply(msg, contracts.ContractPaymentStatusResponse{
		UniqueID: p.UniqueID,
		Paid:     p.Paid,
	})
}
// handleIncomingTransaction stores a transaction received from a payment
// validator and records transaction-creation metrics.
func (n *Node) handleIncomingTransaction(msg actor.Envelope) {
	defer msg.Discard()
	contractID := ""
	replyErr := func(err error) {
		log.Errorw("handle_incoming_transaction",
			"labels", []string{string(observability.LabelContract)},
			"error", err, "contractID", contractID)
		n.sendReply(msg, contracts.TransactionForServiceProviderResponse{Error: err.Error()})
	}

	var req contracts.TransactionForServiceProviderRequest
	if err := json.Unmarshal(msg.Message, &req); err != nil {
		replyErr(fmt.Errorf("failed to unmarshal incoming transaction request: %s", err))
		return
	}
	contractID = req.ContractDID

	// Persist the transaction; an empty Status defaults to "unpaid" in Upsert.
	tx := transaction.Transaction{
		UniqueID:            req.UniqueID,
		PaymentValidatorDID: req.PaymentValidatorDID,
		ContractDID:         req.ContractDID,
		ToAddress:           req.ToAddress,
		Amount:              req.Amount,
		Status:              req.Status,
		TxHash:              req.TxHash,
		Metadata:            req.Metadata,
		// Conversion metadata, if provided by the sender.
		OriginalAmount:     req.OriginalAmount,  // amount in pricing currency (USDT)
		PricingCurrency:    req.PricingCurrency, // currency of the original amount
		RequiresConversion: req.RequiresConversion,
	}
	if err := n.transactionStore.Upsert(tx); err != nil {
		replyErr(fmt.Errorf("failed to insert transaction into the store: %w", err))
		return
	}

	// Metrics: USD-denominated amount (when priced in USDT), plus the raw amount.
	if req.PricingCurrency == "USDT" {
		if m := observability.TxCreatedUSDAmount; m != nil {
			if amount, err := strconv.ParseFloat(req.OriginalAmount, 64); err == nil {
				m.Add(n.ctx, amount, metric.WithAttributes(
					observability.AttrDID,
					attribute.String("ContractDID", req.ContractDID),
				))
			}
		}
	}
	if m := observability.TxCreatedAmount; m != nil {
		if amount, err := strconv.ParseFloat(req.Amount, 64); err == nil {
			m.Add(n.ctx, amount, metric.WithAttributes(
				observability.AttrDID,
				attribute.String("ContractDID", req.ContractDID),
			))
		}
	}

	n.sendReply(msg, contracts.TransactionForServiceProviderResponse{})
}
// payment provider listens for requests from contract host
// about usages of a contracts. As a payment provider, we should
// contact the service provider for what amount to pay.
func (n *Node) handleIncomingContractUsage(msg actor.Envelope) {
	defer msg.Discard()
	replyErr := func(err error) {
		log.Errorf("handleIncomingContractUsage error: %v", err)
		n.sendReply(msg, contracts.ContractUsageResponse{Error: err.Error()})
	}

	var req contracts.ContractUsageRequest
	if err := json.Unmarshal(msg.Message, &req); err != nil {
		replyErr(fmt.Errorf("failed to unmarshal incoming contract usages request: %s", err))
		return
	}

	log.Infof("handleIncomingContractUsage: payment_model=%s, usages=%d, hasTimeUtilization=%v, hasResourceUtilization=%v, hasPeriodicDetails=%v, contractDID=%s",
		req.Contract.PaymentDetails.PaymentModel, req.Usages,
		req.TimeUtilization != nil, req.ResourceUtilization != nil, req.PeriodicDetails != nil, req.Contract.ContractDID)

	// Resolve the processor that knows how to price this payment model.
	processor, err := contracts.GetPaymentModelProcessor(req.Contract.PaymentDetails.PaymentModel)
	if err != nil {
		replyErr(fmt.Errorf("unsupported payment model: %w", err))
		return
	}

	// Normalize the request into the processor's UsageData shape.
	usageData, err := n.convertRequestToUsageData(&req)
	if err != nil {
		replyErr(fmt.Errorf("failed to convert request to usage data: %w", err))
		return
	}
	if usageData == nil {
		// Edge case: periodic invoice with no deployments — nothing to charge.
		n.sendReply(msg, contracts.ContractUsageResponse{})
		return
	}

	items, err := processor.CalculatePayment(usageData, &req.Contract)
	if err != nil {
		replyErr(fmt.Errorf("failed to calculate payment: %w", err))
		return
	}

	// Persist the payment items and forward them downstream.
	if err := n.paymentProcessor.ProcessPaymentItems(&req.Contract, items, req.UniqueID); err != nil {
		replyErr(fmt.Errorf("failed to process payment items: %w", err))
		return
	}
	n.sendReply(msg, contracts.ContractUsageResponse{})
}
// handleContractChainVerification verifies a two-link contract chain on behalf
// of a Provider: Contract A (Orchestrator ↔ Organization), which must exist in
// the local store, plus Contract B (Organization ↔ Provider), found by
// participant lookup. Only the Provider named in the request may ask for
// verification of a chain involving itself.
func (n *Node) handleContractChainVerification(msg actor.Envelope) {
	defer msg.Discard()
	resp := contracts.ContractChainVerificationResponse{}
	var req contracts.ContractChainVerificationRequest
	if err := json.Unmarshal(msg.Message, &req); err != nil {
		resp.Error = fmt.Sprintf("failed to unmarshal request: %v", err)
		n.sendReply(msg, resp)
		return
	}
	// Verify that the caller is the Provider mentioned in the request
	// This ensures only the actual Provider can verify chains involving itself
	providerDID, err := did.FromString(req.ProviderDID)
	if err != nil {
		resp.Error = fmt.Sprintf("invalid provider DID: %v", err)
		n.sendReply(msg, resp)
		return
	}
	// Security check: Only the Provider mentioned in the request can verify the chain
	if msg.From.DID != providerDID {
		resp.Error = fmt.Sprintf("caller DID (%s) does not match ProviderDID in request (%s)",
			msg.From.DID.String(), req.ProviderDID)
		n.sendReply(msg, resp)
		return
	}
	orchestratorDID, err := did.FromString(req.OrchestratorDID)
	if err != nil {
		resp.Error = fmt.Sprintf("invalid orchestrator DID: %v", err)
		n.sendReply(msg, resp)
		return
	}
	// Step 1: Verify Contract A (Orchestrator ↔ Organization)
	// Use the provided ContractADID to get the contract
	// NOTE(review): the field read here is req.ContractDID, not a dedicated
	// ContractADID field as the comment above suggests — confirm the request
	// type carries only one contract DID.
	contractA, err := n.contractStore.GetContract(req.ContractDID)
	if err != nil {
		resp.Error = fmt.Sprintf("contract A not found: %v", err)
		n.sendReply(msg, resp)
		return
	}
	// Validate Contract A participants match the provided DIDs
	provStr := contractA.ContractParticipants.Provider.String()
	reqStr := contractA.ContractParticipants.Requestor.String()
	orchStr := orchestratorDID.String()
	// Check that Contract A is between Orchestrator and Organization
	// NOTE(review): this condition only rejects when the requestor mismatches
	// the orchestrator AND Contract A's provider equals req.ProviderDID. If the
	// intent is "requestor must be the orchestrator, and the organization must
	// not be the Provider itself", an || would be expected here instead of && —
	// confirm against the chain-verification design before changing.
	if reqStr != orchStr && provStr == req.ProviderDID {
		resp.Error = "contract A participants do not match orchestrator and organization"
		n.sendReply(msg, resp)
		return
	}
	// Ensure Contract A is active
	if contractA.CurrentState != contracts.ContractAccepted && contractA.CurrentState != contracts.ContractActive {
		resp.Error = fmt.Sprintf("contract A is not in active state: %s", contractA.CurrentState)
		n.sendReply(msg, resp)
		return
	}
	// Step 2: Find Contract B (Organization ↔ Provider)
	// The orchestrator specifies Contract A (head contract) which includes the Organization DID.
	// The provider finds Contract B by matching the Organization DID from Contract A.
	// This handles the case where a provider has contracts with multiple organizations:
	// the provider will find the correct Contract B based on which organization is specified
	// in the head contract (Contract A).
	//
	// Example: If Provider has contracts with Org1 and Org2:
	// - Orchestrator specifies Contract A with Org1 → Provider finds Contract B with Org1
	// - Orchestrator specifies Contract A with Org2 → Provider finds Contract B with Org2
	contractB, err := n.contractStore.FindContractByParticipants(contractA.ContractParticipants.Provider, providerDID)
	if err != nil {
		resp.Error = fmt.Sprintf("no active contract found between organization and provider: %v", err)
		n.sendReply(msg, resp)
		return
	}
	// Step 3: Validate both contracts are in acceptable state
	validA := contractA.CurrentState == contracts.ContractAccepted || contractA.CurrentState == contracts.ContractActive
	validB := contractB.CurrentState == contracts.ContractAccepted || contractB.CurrentState == contracts.ContractActive
	if !validA || !validB {
		resp.Error = fmt.Sprintf("contract chain invalid: Contract A state=%s, Contract B state=%s",
			contractA.CurrentState, contractB.CurrentState)
		n.sendReply(msg, resp)
		return
	}
	// Chain is valid: report the organization and both contracts back.
	resp.Valid = true
	resp.OrganizationDID = contractA.ContractParticipants.Provider.String()
	resp.OrchestratorContract = contractA
	resp.ProviderContract = contractB
	n.sendReply(msg, resp)
}
// convertRequestToUsageData converts a ContractUsageRequest into the UsageData
// shape expected by the payment-model processors. It returns (nil, nil) for a
// periodic request with zero deployments, signalling the caller to skip
// payment processing entirely.
func (n *Node) convertRequestToUsageData(req *contracts.ContractUsageRequest) (*contracts.UsageData, error) {
	model := req.Contract.PaymentDetails.PaymentModel
	switch model {
	case contracts.PayPerAllocation, contracts.PayPerDeployment:
		// Both models bill on a simple usage count.
		return &contracts.UsageData{
			ContractDID:  req.Contract.ContractDID,
			PaymentModel: model,
			Data:         req.Usages,
		}, nil
	case contracts.PayPerTimeUtilization:
		if req.TimeUtilization == nil {
			return nil, errors.New("time_utilization is required for pay_per_time_utilization payment model")
		}
		return &contracts.UsageData{
			ContractDID:  req.Contract.ContractDID,
			PaymentModel: model,
			Data:         req.TimeUtilization,
		}, nil
	case contracts.PayPerResourceUtilization:
		if req.ResourceUtilization == nil {
			return nil, errors.New("resource_utilization is required for pay_per_resource_utilization payment model")
		}
		return &contracts.UsageData{
			ContractDID:  req.Contract.ContractDID,
			PaymentModel: model,
			Data:         req.ResourceUtilization,
		}, nil
	case contracts.FixedRental:
		if req.FixedRentalDetails == nil {
			return nil, errors.New("fixed_rental_details is required for fixed_rental payment model")
		}
		return &contracts.UsageData{
			ContractDID:  req.Contract.ContractDID,
			PaymentModel: model,
			Data:         req.FixedRentalDetails,
		}, nil
	case contracts.Periodic:
		if req.PeriodicDetails == nil {
			return nil, errors.New("periodic_details is required for periodic payment model")
		}
		if len(req.PeriodicDetails.Deployments) == 0 {
			// Zero runtime in the period: nothing to bill.
			log.Infow("received periodic invoice request with no deployments (zero runtime), skipping payment processing",
				"contract_did", req.Contract.ContractDID,
				"period_start", req.PeriodicDetails.PeriodStart,
				"period_end", req.PeriodicDetails.PeriodEnd)
			return nil, nil
		}
		return &contracts.UsageData{
			ContractDID:  req.Contract.ContractDID,
			PaymentModel: model,
			Data:         req.PeriodicDetails,
		}, nil
	default:
		return nil, fmt.Errorf("unsupported payment model: %s", model)
	}
}
// isHeadContractFromRequest determines if a contract is a Head Contract in a chain
// by checking the chain-role metadata key on the creation request: it returns
// true only when the request's metadata marks the contract with the head role.
// Contracts without this metadata key are treated as plain P2P contracts.
func isHeadContractFromRequest(request contracts.CreateContractRequest) bool {
	return request.Metadata[contracts.ContractChainRoleMetadataKey] == contracts.ContractChainRoleHead
}
// compareDecimals reports whether decimal string a is greater than OR EQUAL TO
// decimal string b (the previous comment claimed strictly "bigger", but the
// comparison is >=, which is what payment verification relies on: the paid
// amount must be at least the expected amount). Both values are parsed as
// base-10 big.Floats with 256-bit precision; a parse failure returns an error.
func compareDecimals(a, b string) (bool, error) {
	af, _, err := big.ParseFloat(a, 10, 256, big.ToNearestEven)
	if err != nil {
		return false, err
	}
	bf, _, err := big.ParseFloat(b, 10, 256, big.ToNearestEven)
	if err != nil {
		return false, err
	}
	return af.Cmp(bf) >= 0, nil
}
// handleGetPaymentQuote creates a payment quote for a transaction that
// requires currency conversion: it converts the original (pricing-currency)
// amount into the payment currency at the current oracle rate and stores a
// time-limited quote that the payer can later validate and use.
func (n *Node) handleGetPaymentQuote(msg actor.Envelope) {
	defer msg.Discard()
	handleErr := func(err error) {
		log.Errorw("handle_get_payment_quote",
			"labels", []string{string(observability.LabelContract)},
			"error", err)
		n.sendReply(msg, contracts.ContractGetPaymentQuoteResponse{Error: err.Error()})
	}
	var req contracts.ContractGetPaymentQuoteRequest
	if err := json.Unmarshal(msg.Message, &req); err != nil {
		handleErr(fmt.Errorf("failed to unmarshal payment quote request: %w", err))
		return
	}
	tx, err := n.transactionStore.GetTransactionByUniqueID(req.UniqueID)
	if err != nil {
		handleErr(fmt.Errorf("failed to get transaction: %w", err))
		return
	}
	// Quotes only make sense for transactions priced in another currency.
	if !tx.RequiresConversion || tx.PricingCurrency == "" {
		handleErr(errors.New("transaction does not require conversion"))
		return
	}
	// Only one active quote per transaction is allowed at a time.
	existingQuote, err := n.paymentQuoteStore.HasActiveQuote(req.UniqueID)
	if err != nil {
		handleErr(fmt.Errorf("failed to check for existing quote: %w", err))
		return
	}
	if existingQuote != nil {
		handleErr(fmt.Errorf("active quote already exists for this transaction (quote_id: %s). Please cancel the existing quote before creating a new one", existingQuote.QuoteID))
		return
	}
	// The payment currency comes from the transaction's first payout address;
	// default to NTX when no address is present.
	paymentCurrency := "NTX"
	if len(tx.ToAddress) > 0 {
		paymentCurrency = tx.ToAddress[0].Currency
	}
	if n.priceConverter == nil {
		handleErr(errors.New("price converter not configured"))
		return
	}
	// Use the node context so oracle calls are cancelled on node shutdown,
	// consistent with the rest of the file (previously context.Background()).
	ctx := n.ctx
	oracle := n.priceConverter.GetOracle()
	if oracle == nil {
		handleErr(errors.New("price oracle not available"))
		return
	}
	// Perform the real-time conversion and capture the rate used.
	convertedAmount, err := oracle.ConvertAmount(ctx, tx.OriginalAmount, tx.PricingCurrency, paymentCurrency)
	if err != nil {
		handleErr(fmt.Errorf("failed to convert amount: %w", err))
		return
	}
	rate, err := oracle.GetPrice(ctx, tx.PricingCurrency, paymentCurrency)
	if err != nil {
		handleErr(fmt.Errorf("failed to get exchange rate: %w", err))
		return
	}
	// Quote IDs are unique per transaction and creation time.
	quoteID := fmt.Sprintf("quote_%s_%d", req.UniqueID, time.Now().UnixNano())
	// Quote TTL comes from config, defaulting to 2 minutes.
	quoteTTL := 2 * time.Minute
	if n.dmsConfig.CoinMarketCap.QuoteTTL != "" {
		if parsedTTL, err := time.ParseDuration(n.dmsConfig.CoinMarketCap.QuoteTTL); err == nil {
			quoteTTL = parsedTTL
		}
	}
	quote := payment_quote.PaymentQuote{
		QuoteID:         quoteID,
		UniqueID:        req.UniqueID,
		OriginalAmount:  tx.OriginalAmount,
		ConvertedAmount: convertedAmount,
		PricingCurrency: tx.PricingCurrency,
		PaymentCurrency: paymentCurrency,
		ExchangeRate:    fmt.Sprintf("%.8f", rate),
		CreatedAt:       time.Now(),
		ExpiresAt:       time.Now().Add(quoteTTL),
		Used:            false,
	}
	if err := n.paymentQuoteStore.CreateQuote(quote); err != nil {
		handleErr(fmt.Errorf("failed to create quote: %w", err))
		return
	}
	n.sendReply(msg, contracts.ContractGetPaymentQuoteResponse{
		QuoteID:         quoteID,
		OriginalAmount:  tx.OriginalAmount,
		ConvertedAmount: convertedAmount,
		PricingCurrency: tx.PricingCurrency,
		PaymentCurrency: paymentCurrency,
		ExchangeRate:    fmt.Sprintf("%.8f", rate),
		ExpiresAt:       quote.ExpiresAt,
	})
}
// handleValidatePaymentQuote checks that a payment quote is still valid
// (neither expired nor already used) and echoes its details to the caller.
func (n *Node) handleValidatePaymentQuote(msg actor.Envelope) {
	defer msg.Discard()
	replyErr := func(err error) {
		n.sendReply(msg, contracts.ContractValidatePaymentQuoteResponse{
			Valid: false,
			Error: err.Error(),
		})
	}

	var req contracts.ContractValidatePaymentQuoteRequest
	if err := json.Unmarshal(msg.Message, &req); err != nil {
		replyErr(fmt.Errorf("failed to unmarshal validate quote request: %w", err))
		return
	}

	// The quote store performs the expiration and usage checks.
	quote, err := n.paymentQuoteStore.ValidateQuote(req.QuoteID)
	if err != nil {
		replyErr(fmt.Errorf("quote validation failed: %w", err))
		return
	}

	n.sendReply(msg, contracts.ContractValidatePaymentQuoteResponse{
		Valid:           true,
		QuoteID:         quote.QuoteID,
		OriginalAmount:  quote.OriginalAmount,
		ConvertedAmount: quote.ConvertedAmount,
		PricingCurrency: quote.PricingCurrency,
		PaymentCurrency: quote.PaymentCurrency,
		ExchangeRate:    quote.ExchangeRate,
		ExpiresAt:       quote.ExpiresAt,
	})
}
// handleCancelPaymentQuote cancels an unused payment quote by invalidating it
// in the quote store.
func (n *Node) handleCancelPaymentQuote(msg actor.Envelope) {
	defer msg.Discard()
	replyErr := func(err error) {
		n.sendReply(msg, contracts.ContractCancelPaymentQuoteResponse{Error: err.Error()})
	}

	var req contracts.ContractCancelPaymentQuoteRequest
	if err := json.Unmarshal(msg.Message, &req); err != nil {
		replyErr(fmt.Errorf("failed to unmarshal cancel quote request: %w", err))
		return
	}

	quote, err := n.paymentQuoteStore.GetQuote(req.QuoteID)
	if err != nil {
		replyErr(fmt.Errorf("quote not found: %w", err))
		return
	}
	// A quote that has already been used can no longer be cancelled.
	if quote.Used {
		replyErr(fmt.Errorf("quote already used"))
		return
	}

	// Invalidation marks the quote as used, which takes it out of play.
	if err := n.paymentQuoteStore.InvalidateQuote(req.QuoteID); err != nil {
		replyErr(fmt.Errorf("failed to invalidate quote: %w", err))
		return
	}
	n.sendReply(msg, contracts.ContractCancelPaymentQuoteResponse{})
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package node
import (
"context"
"encoding/json"
"errors"
"fmt"
"path/filepath"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/dms/behaviors"
"gitlab.com/nunet/device-management-service/dms/jobs"
jobtypes "gitlab.com/nunet/device-management-service/dms/jobs/types"
"gitlab.com/nunet/device-management-service/dms/orchestrator"
"gitlab.com/nunet/device-management-service/gateway/provider"
"gitlab.com/nunet/device-management-service/gateway/store"
"gitlab.com/nunet/device-management-service/observability"
"gitlab.com/nunet/device-management-service/types"
)
// Deployment behavior data source strategy:
// - Store-based (primary source of truth): handleDeploymentList, handleDeploymentStatus, handleDeploymentManifest
// - In-memory registry (requires active orchestrator): handleDeploymentLogs, handleDeploymentShutdown, handleDeploymentUpdate
// - Store + in-memory: handleNewDeployment (creates in memory, auto-saved to store via status watcher)

// Timing and retry parameters for deployment handling.
const (
	// MinDeploymentTime is the minimum accepted deployment duration
	// (just under one minute).
	MinDeploymentTime = time.Minute - time.Second
	// MinUpdateDeploymentTime is the minimum accepted duration when
	// updating an existing deployment (twice MinDeploymentTime).
	MinUpdateDeploymentTime = 2 * (time.Minute - time.Second) // TODO: tune this
	// allocationStatsRequestTimeout bounds how long allocation statistics
	// requests may take.
	allocationStatsRequestTimeout = 20 * time.Second
	// maxRetries and retryDelay control retry behavior for deployment
	// operations.
	maxRetries = 5
	retryDelay = time.Second
)

var (
	// ErrDeploymentNotFound is returned when the referenced deployment
	// cannot be located.
	ErrDeploymentNotFound = errors.New("deployment not found")
	// ErrorDeploymentNotRunning is returned by operations that require a
	// running deployment.
	// NOTE(review): Go convention would name this ErrDeploymentNotRunning;
	// renaming would break existing callers, so it is kept as-is.
	ErrorDeploymentNotRunning = errors.New("deployment is not running")
)
// handleVerifyEdgeConstraint handles requests to verify an edge constraint.
// The verification itself is not implemented yet: a malformed request gets an
// error reply; otherwise the handler currently does nothing.
func (n *Node) handleVerifyEdgeConstraint(msg actor.Envelope) {
	defer msg.Discard()
	var request orchestrator.VerifyEdgeConstraintRequest
	if err := json.Unmarshal(msg.Message, &request); err != nil {
		log.Warnw("verify_edge_constraint_unmarshal_error",
			"labels", []string{string(observability.LabelDeployment)},
			"error", err)
		n.sendReply(msg, orchestrator.VerifyEdgeConstraintResponse{
			OK:    false,
			Error: err.Error(),
		})
		// BUG FIX: return after replying with an error so the handler does
		// not fall through once the TODO below is implemented.
		return
	}
	// TODO: implement
	// also add to docs (dms/behaviors/README.md, help-caps command and man page)
}
// commitDeployment commits previously reserved resources for an allocation
// that belongs to a pending bid on the given ensemble. It fails if no bid is
// known for the ensemble or if the bid has already expired.
func (n *Node) commitDeployment(
	ensembleID, allocationID string,
	resources types.CommittedResources, ports map[int]int,
) error {
	pending, found := n.getBid(ensembleID)
	if !found {
		return fmt.Errorf("no bid requests for ensemble id: %s", ensembleID)
	}

	n.lock.Lock()
	defer n.lock.Unlock()

	// An expired bid must not be committed.
	if time.Now().After(pending.expire) {
		return fmt.Errorf("bid request for ensemble id: %s has expired", ensembleID)
	}

	err := n.allocator.Commit(context.Background(), allocationID, resources, ports, pending.request.V1.PublicPorts.Dynamic, pending.expire.Unix())
	if err != nil {
		return fmt.Errorf("commit resources for ensemble allocID: %s: %w", allocationID, err)
	}
	return nil
}
// handleCommitDeployment decodes a commit-deployment request and commits the
// previously reserved resources for the named allocation, replying with the
// outcome.
func (n *Node) handleCommitDeployment(msg actor.Envelope) {
	defer msg.Discard()

	replyErr := func(err error) {
		log.Errorw("commit_deployment_error",
			"labels", []string{string(observability.LabelDeployment)},
			"error", err)
		n.sendReply(msg, orchestrator.CommitDeploymentResponse{Error: err.Error()})
	}

	var req orchestrator.CommitDeploymentRequest
	if err := json.Unmarshal(msg.Message, &req); err != nil {
		replyErr(err)
		return
	}

	log.Infow("commit_deployment_started",
		"labels", []string{string(observability.LabelDeployment)},
		"ensembleID", req.EnsembleID)

	if err := n.commitDeployment(req.EnsembleID, req.AllocationName, req.Resources, req.PortMapping); err != nil {
		replyErr(err)
		return
	}

	n.sendReply(msg, orchestrator.CommitDeploymentResponse{OK: true})
}
// NewDeploymentRequest is the payload for a new ensemble deployment request.
type NewDeploymentRequest struct {
	// Ensemble is the configuration of the ensemble to deploy.
	Ensemble jobtypes.EnsembleConfig
}

// NewDeploymentResponse is the reply to a NewDeploymentRequest.
type NewDeploymentResponse struct {
	// Status is "OK" on success, "ERROR" on failure.
	Status string
	// EnsembleID identifies the created deployment (set on success).
	EnsembleID string `json:",omitempty"`
	// Error describes the failure (set when Status is "ERROR").
	Error string `json:",omitempty"`
}
// saveDeployment persists the given orchestrator's deployment state via the
// orchestrator registry. Returns a wrapped error on failure.
//
// Fixes: the parameter previously shadowed the imported orchestrator package;
// the log key was "stats" although the value logged is the status.
func (n *Node) saveDeployment(orch orchestrator.Orchestrator) error {
	err := n.orchestratorRegistry.SaveOrchestrator(orch)
	if err != nil {
		return fmt.Errorf("save deployment: %w", err)
	}
	log.Debugw("deployment_saved", "labels", []string{string(observability.LabelDeployment)},
		"orchestratorID", orch.ID(), "status", orch.Status().String())
	return nil
}
// handleNewDeployment handles a request to deploy a new ensemble.
// Flow: validate the remaining envelope lifetime, decode and validate the
// ensemble config, resolve contract details from each contract host, create
// an orchestrator, ACK the caller, then run the (long-lived) deployment.
// The created orchestrator's status is auto-persisted via a status watcher.
func (n *Node) handleNewDeployment(msg actor.Envelope) {
	defer msg.Discard()
	// handleErr logs and replies with an ERROR-status response.
	handleErr := func(err error) {
		log.Errorw("new_deployment_error",
			"labels", []string{string(observability.LabelDeployment)},
			"error", err)
		n.sendReply(msg, NewDeploymentResponse{Status: "ERROR", Error: err.Error()})
	}
	// Reject requests whose envelope would expire before the minimum
	// deployment time elapses.
	if time.Until(msg.Expiry()) < MinDeploymentTime {
		log.Debugf("deployment time too short")
		handleErr(errors.New("requested deployment time too short"))
		return
	}
	var request NewDeploymentRequest
	if err := json.Unmarshal(msg.Message, &request); err != nil {
		handleErr(fmt.Errorf("unmarshal new deployment request: %s", err))
		return
	}
	if request.Ensemble.V1 == nil {
		handleErr(errors.New("empty ensemble config"))
		return
	}
	if request.Ensemble.Contracts() != nil {
		// Validate all contracts carry a DID before contacting any host.
		for _, contract := range request.Ensemble.Contracts() {
			if contract.DID == "" {
				handleErr(errors.New("contract DID is required"))
				return
			}
		}
		// retrieve contract information from contract host
		for k, contract := range request.Ensemble.Contracts() {
			// Call the contract host node (not contract actor) to get contract info
			// Use HandleFromDID which defaults to "root" inbox for node-level behaviors
			destination, err := actor.HandleFromDID(contract.Host)
			if err != nil {
				handleErr(fmt.Errorf("failed to get contract host handle: %w", err))
				return
			}
			// invoke behavior to retrieve contract information
			reply, err := n.invokeBehaviour(
				destination,
				behaviors.ContractInfoBehavior,
				behaviors.ContractInfoRequest{
					ContractDID: contract.DID,
				},
				invokeMessageTimeout,
			)
			if err != nil {
				handleErr(fmt.Errorf("failed to invoke contract info for %s: %w", contract.DID, err))
				return
			}
			var contractInfoResp behaviors.ContractInfoResponse
			if err := json.Unmarshal(reply.Message, &contractInfoResp); err != nil {
				handleErr(fmt.Errorf("failed to unmarshal contract info response: %w", err))
				return
			}
			if !contractInfoResp.OK {
				handleErr(fmt.Errorf("contract info error: %s", contractInfoResp.Error))
				return
			}
			// Write the resolved provider/requestor back into the config;
			// the loop variable is a copy, so store it by key.
			contract.Provider = contractInfoResp.Provider
			contract.Requestor = contractInfoResp.Requestor
			request.Ensemble.V1.Contracts[k] = contract
		}
	}
	orch, err := n.createOrchestrator(n.ctx, request.Ensemble, request.Ensemble.Contracts())
	if err != nil {
		log.Warnw("orchestrator_creation_failure",
			"labels", []string{string(observability.LabelDeployment)},
			"error", err)
		handleErr(err)
		return
	}
	log.Infow("new_ensemble_deployment_initiated",
		"labels", []string{string(observability.LabelDeployment)},
		"ensembleID", orch.ID())
	// ACK the caller before the long-running Deploy; deployment errors after
	// this point are logged and persisted but not sent back on this envelope.
	n.sendReply(msg, NewDeploymentResponse{
		Status:     "OK",
		EnsembleID: orch.ID(),
	})
	if err := orch.Deploy(msg.Expiry().Add(-orchestrator.MinEnsembleDeploymentTime)); err != nil {
		// Manually save the failed status before stopping to ensure it persists
		if err := n.orchestratorRegistry.SaveOrchestrator(orch); err != nil {
			log.Warnw("failed to save failed orchestrator", "error", err)
		}
		orch.Stop()
		log.Errorw("ensemble_deployment_error",
			"labels", []string{string(observability.LabelDeployment)},
			"ensembleID", orch.ID(),
			"error", err)
		return
	}
	// Orchestrator status is automatically saved to store via status watcher
}
// DeploymentListRequest selects, pages and sorts deployments for listing.
type DeploymentListRequest struct {
	// Existing metadata filter (backward compatible); all key/value pairs
	// must match a deployment's manifest metadata for it to be included.
	Metadata map[string]string `json:"metadata,omitempty"`
	// Pagination
	Limit  int `json:"limit,omitempty"`  // Max number of results (default: no limit, for backward compat)
	Offset int `json:"offset,omitempty"` // Number of results to skip (default: 0)
	// Status filter (for JSON API - parsed from strings in CLI)
	Status []jobtypes.DeploymentStatus `json:"status,omitempty"` // Filter by one or more statuses
	// Date filters (for JSON API)
	CreatedAfter  *time.Time `json:"created_after,omitempty"`  // Filter by CreatedAt >= value
	CreatedBefore *time.Time `json:"created_before,omitempty"` // Filter by CreatedAt <= value
	UpdatedAfter  *time.Time `json:"updated_after,omitempty"`  // Filter by UpdatedAt >= value
	UpdatedBefore *time.Time `json:"updated_before,omitempty"` // Filter by UpdatedAt <= value
	// Sorting
	SortBy string `json:"sort_by,omitempty"` // Field to sort by (e.g., "created_at", "-created_at" for desc)
}

// DeploymentListResponse is the reply to a DeploymentListRequest.
type DeploymentListResponse struct {
	// Enhanced deployment information
	Deployments []DeploymentInfo `json:"deployments"`
	// Pagination metadata
	Total      int  `json:"total"`                 // Total number of deployments matching filters
	HasMore    bool `json:"has_more"`              // Whether there are more results available
	NextOffset int  `json:"next_offset,omitempty"` // Offset for next page (if has_more is true)
}

// DeploymentInfo is a summary of one deployment as returned in listings.
type DeploymentInfo struct {
	OrchestratorID string            `json:"orchestrator_id"`
	Status         string            `json:"status"`
	CreatedAt      time.Time         `json:"created_at"`
	UpdatedAt      time.Time         `json:"updated_at"`
	CompletedAt    *time.Time        `json:"completed_at,omitempty"`
	Metadata       map[string]string `json:"metadata,omitempty"`
}
// handleDeploymentList returns deployments from the store, filtered,
// paginated and sorted per the request (store-based; works without an
// active orchestrator).
func (n *Node) handleDeploymentList(msg actor.Envelope) {
	defer msg.Discard()
	// handleErr logs and replies with an empty list.
	// NOTE(review): the response has no error field, so callers cannot
	// distinguish "no deployments" from "listing failed" — confirm intended.
	handleErr := func(err error) {
		log.Errorw("deployment_list_error",
			"labels", []string{string(observability.LabelDeployment)},
			"error", err)
		n.sendReply(msg, DeploymentListResponse{Deployments: []DeploymentInfo{}})
	}
	var request DeploymentListRequest
	var resp DeploymentListResponse
	if err := json.Unmarshal(msg.Message, &request); err != nil {
		handleErr(fmt.Errorf("error unmarshalling deployment list request: %s", err))
		return
	}
	// Build query from request
	query := orchestrator.DeploymentQuery{
		Limit:  request.Limit,
		Offset: request.Offset,
		SortBy: request.SortBy,
	}
	// Status filter
	if len(request.Status) > 0 {
		query.StatusFilter = request.Status
	}
	// Date filters
	if request.CreatedAfter != nil {
		query.CreatedAfter = request.CreatedAfter
	}
	if request.CreatedBefore != nil {
		query.CreatedBefore = request.CreatedBefore
	}
	if request.UpdatedAfter != nil {
		query.UpdatedAfter = request.UpdatedAfter
	}
	if request.UpdatedBefore != nil {
		query.UpdatedBefore = request.UpdatedBefore
	}
	// Query deployments from store
	deployments, total, err := n.orchestratorRegistry.QueryDeployments(query)
	if err != nil {
		handleErr(fmt.Errorf("failed to query deployments: %w", err))
		return
	}
	// Apply metadata filtering (in-memory, as metadata is in deployment_data)
	// NOTE(review): metadata filtering is applied after the paginated store
	// query, so Total/HasMore reflect pre-filter counts — confirm intended.
	filteredDeployments := make([]DeploymentInfo, 0)
	for _, deployment := range deployments {
		if shouldIncludeDeployment(deployment, request.Metadata) {
			info := DeploymentInfo{
				OrchestratorID: deployment.OrchestratorID,
				Status:         deployment.Status.String(),
				CreatedAt:      deployment.CreatedAt,
				UpdatedAt:      deployment.UpdatedAt,
				CompletedAt:    deployment.CompletedAt,
				Metadata:       deployment.Manifest.Metadata,
			}
			filteredDeployments = append(filteredDeployments, info)
		}
	}
	// Calculate pagination metadata
	resp.Deployments = filteredDeployments
	resp.Total = total
	resp.HasMore = request.Limit > 0 && (request.Offset+len(filteredDeployments) < total)
	if resp.HasMore {
		resp.NextOffset = request.Offset + len(filteredDeployments)
	}
	n.sendReply(msg, resp)
}
// shouldIncludeDeployment reports whether a deployment matches the metadata
// filter: every key/value pair in the filter must be present, with an equal
// value, in the deployment's manifest metadata. An empty filter matches all.
func shouldIncludeDeployment(deployment *jobtypes.OrchestratorView, metadataFilter map[string]string) bool {
	// Ranging over an empty/nil filter falls straight through to true,
	// which also covers the "no filter" case.
	for key, want := range metadataFilter {
		got, ok := deployment.Manifest.Metadata[key]
		if !ok || got != want {
			return false
		}
	}
	return true
}
// DeploymentLogsRequest asks for the logs of one allocation in an ensemble.
type DeploymentLogsRequest struct {
	// EnsembleID identifies the deployment.
	EnsembleID string
	// AllocationName names the allocation whose logs are requested.
	AllocationName string
}

// DeploymentLogsResponse is the reply to a DeploymentLogsRequest.
type DeploymentLogsResponse struct {
	// LogsWrittenTo is the directory the logs were written to.
	LogsWrittenTo string
	// Error describes the failure, if any.
	Error string
}
// handleDeploymentLogs fetches stdout/stderr of one allocation and writes
// them to disk, replying with the directory they were written to. Requires
// an active orchestrator (in-memory registry only).
func (n *Node) handleDeploymentLogs(msg actor.Envelope) {
	defer msg.Discard()
	orchestratorID := ""
	handleErr := func(err error) {
		log.Errorw("deployment_logs_error",
			"labels", []string{string(observability.LabelDeployment)},
			"error", err,
			"orchestratorID", orchestratorID,
		)
		n.sendReply(msg, DeploymentLogsResponse{Error: err.Error()})
	}
	var request DeploymentLogsRequest
	var resp DeploymentLogsResponse
	if err := json.Unmarshal(msg.Message, &request); err != nil {
		handleErr(fmt.Errorf("error unmarshalling deployment logs: %w", err))
		return
	}
	// For logs, we need an active orchestrator (in-memory registry only)
	o, err := n.orchestratorRegistry.GetOrchestrator(request.EnsembleID)
	if err != nil {
		handleErr(fmt.Errorf("failed to get orchestrator: %w", err))
		return
	}
	orchestratorID = o.ID()
	data, err := o.GetAllocationLogs(request.AllocationName)
	if err != nil {
		handleErr(fmt.Errorf("failed to get allocation logs: %w", err))
		return
	}
	allocDir, err := o.WriteAllocationLogs(request.AllocationName, data.Stdout, data.Stderr)
	if err != nil {
		// Fixed typo in error message ("logst" -> "logs").
		handleErr(fmt.Errorf("failed to write allocation logs: %w", err))
		return
	}
	resp.LogsWrittenTo = allocDir
	n.sendReply(msg, resp)
}
// DeploymentStatusRequest asks for the status of a deployment, optionally
// with per-allocation resource usage.
type DeploymentStatusRequest struct {
	ID           string `json:"id"`
	IncludeUsage bool   `json:"include_usage,omitempty"`
}

// DeploymentStatusResponse is the reply to a DeploymentStatusRequest.
// AllocationInfo/AllocationUsage are populated only for running deployments.
type DeploymentStatusResponse struct {
	Status          string                             `json:"status"`
	Error           string                             `json:"error"`
	AllocationInfo  map[string]jobtypes.AllocationInfo `json:"allocation_info,omitempty"`
	AllocationUsage map[string]*types.ExecutorStats    `json:"allocation_usage,omitempty"`
}
// handleDeploymentStatus reports a deployment's status from the store
// (primary source of truth). For running deployments it additionally
// collects allocation info from the live orchestrator and, if requested,
// per-allocation resource usage via the allocation-stats behavior.
//
// Fix: error wrapping now uses %w (was %s) so callers can inspect the chain
// with errors.Is/As; the rendered message is unchanged.
func (n *Node) handleDeploymentStatus(msg actor.Envelope) {
	defer msg.Discard()
	orchestratorID := ""
	handleErr := func(err error) {
		log.Errorw("deployment_status_error",
			"labels", []string{string(observability.LabelDeployment)},
			"error", err,
			"orchestratorID", orchestratorID,
		)
		n.sendReply(msg, DeploymentStatusResponse{Error: err.Error()})
	}
	var request DeploymentStatusRequest
	var resp DeploymentStatusResponse
	if err := json.Unmarshal(msg.Message, &request); err != nil {
		handleErr(fmt.Errorf("error unmarshalling deployment status: %w", err))
		return
	}
	// Read deployment status from store (primary source of truth)
	deployment, err := n.orchestratorRegistry.GetDeployment(request.ID)
	if err != nil {
		handleErr(fmt.Errorf("failed to get deployment: %w", err))
		return
	}
	orchestratorID = deployment.ID
	resp.Status = deployment.Status.String()
	if deployment.Status == jobtypes.DeploymentStatusRunning {
		orch, err := n.orchestratorRegistry.GetOrchestrator(request.ID)
		if err != nil {
			handleErr(fmt.Errorf("failed to get orchestrator: %w", err))
			return
		}
		resp.AllocationInfo = orch.AllocationInfo()
		if request.IncludeUsage {
			manifest := orch.Manifest()
			usage := make(map[string]*types.ExecutorStats, len(manifest.Allocations))
			for allocID, allocManifest := range manifest.Allocations {
				// Ask each allocation actor for its executor stats.
				reply, err := n.invokeBehaviour(
					allocManifest.Handle,
					behaviors.AllocationStatsBehavior,
					behaviors.AllocationStatsRequest{},
					allocationStatsRequestTimeout,
				)
				if err != nil {
					handleErr(fmt.Errorf("failed to invoke allocation stats for %s: %w", allocID, err))
					return
				}
				var statsResp behaviors.AllocationStatsResponse
				if err := json.Unmarshal(reply.Message, &statsResp); err != nil {
					handleErr(fmt.Errorf("failed to decode allocation stats response for %s: %w", allocID, err))
					return
				}
				if !statsResp.OK {
					handleErr(fmt.Errorf("allocation %s stats error: %s", allocID, statsResp.Error))
					return
				}
				usage[allocID] = statsResp.Stats
			}
			resp.AllocationUsage = usage
		}
	}
	n.sendReply(msg, resp)
}
// DeploymentManifestRequest asks for a deployment's ensemble manifest.
type DeploymentManifestRequest struct {
	// ID identifies the deployment.
	ID string
}

// DeploymentManifestResponse is the reply to a DeploymentManifestRequest.
type DeploymentManifestResponse struct {
	Manifest jobs.EnsembleManifest `json:"manifest"`
	Error    string                `json:"error,omitempty"`
}
// handleDeploymentManifest returns a deployment's manifest. For running
// deployments the latest manifest comes from the live orchestrator; for all
// other states the manifest stored in the registry is returned.
//
// Fix: error wrapping now uses %w (was %s) so callers can inspect the chain
// with errors.Is/As; the rendered message is unchanged.
func (n *Node) handleDeploymentManifest(msg actor.Envelope) {
	defer msg.Discard()
	orchestratorID := ""
	handleErr := func(err error) {
		log.Errorw("deployment_manifest_error",
			"labels", []string{string(observability.LabelDeployment)},
			"error", err,
			"orchestratorID", orchestratorID,
		)
		n.sendReply(msg, DeploymentManifestResponse{Error: err.Error()})
	}
	var request DeploymentManifestRequest
	var resp DeploymentManifestResponse
	if err := json.Unmarshal(msg.Message, &request); err != nil {
		handleErr(fmt.Errorf("error unmarshalling deployment manifest: %w", err))
		return
	}
	// Read deployment manifest from store (primary source of truth)
	deployment, err := n.orchestratorRegistry.GetDeployment(request.ID)
	if err != nil {
		handleErr(fmt.Errorf("failed to get deployment: %w", err))
		return
	}
	// if deployment is running, get the latest manifest directly from orchestrator
	if deployment.Status == jobtypes.DeploymentStatusRunning {
		orch, err := n.orchestratorRegistry.GetOrchestrator(request.ID)
		if err != nil {
			handleErr(fmt.Errorf("failed to get orchestrator: %w", err))
			return
		}
		resp.Manifest = orch.Manifest()
		n.sendReply(msg, resp)
		return
	}
	resp.Manifest = deployment.Manifest
	n.sendReply(msg, resp)
}
// handleDeploymentInfo returns a consolidated view of a deployment: status,
// manifest, per-allocation details, and optionally resource usage and logs.
// The store is the primary source of truth; when a live orchestrator exists
// its fresher manifest/allocation info is preferred. Usage and logs are
// best-effort: per-allocation failures are logged and skipped, never fatal.
func (n *Node) handleDeploymentInfo(msg actor.Envelope) {
	defer msg.Discard()
	orchestratorID := ""
	handleErr := func(err error) {
		log.Errorw("deployment_info_error",
			"labels", []string{string(observability.LabelDeployment)},
			"error", err,
			"orchestratorID", orchestratorID,
		)
		n.sendReply(msg, DeploymentInfoResponse{Error: err.Error()})
	}
	var request DeploymentInfoRequest
	var resp DeploymentInfoResponse
	if err := json.Unmarshal(msg.Message, &request); err != nil {
		handleErr(fmt.Errorf("error unmarshalling deployment info request: %s", err))
		return
	}
	if request.ID == "" {
		handleErr(errors.New("deployment ID is required"))
		return
	}
	// Get deployment from store (primary source of truth)
	deployment, err := n.orchestratorRegistry.GetDeployment(request.ID)
	if err != nil {
		handleErr(fmt.Errorf("failed to get deployment: %s", err))
		return
	}
	orchestratorID = deployment.OrchestratorID
	resp.ID = deployment.OrchestratorID
	resp.Status = deployment.Status.String()
	// Try to get orchestrator (may exist even for non-running deployments)
	orch, err := n.orchestratorRegistry.GetOrchestrator(request.ID)
	hasOrchestrator := err == nil
	var manifest jobs.EnsembleManifest
	var allocationInfo map[string]jobtypes.AllocationInfo
	if hasOrchestrator {
		// If orchestrator exists, get latest manifest from it
		manifest = orch.Manifest()
		resp.Manifest = &manifest
		resp.Allocations = make(map[string]AllocationDetails)
		// Get allocation info
		allocationInfo = orch.AllocationInfo()
		for allocID, info := range allocationInfo {
			resp.Allocations[allocID] = AllocationDetails{
				AllocationID:   info.AllocationID,
				Status:         string(info.Status),
				HeartbeatSeq:   info.HeartbeatSeq,
				HasHealthCheck: info.HasHealthCheck,
				Health:         info.Health,
				ResourceLimit:  info.ResourceLimit,
				ResourceUsage:  info.ResourceUsage,
				DNSName:        info.DNSName,
				IP:             info.IP,
				Timestamp:      info.Timestamp,
			}
		}
	} else {
		// If orchestrator doesn't exist, use manifest from store
		manifest = deployment.Manifest
		resp.Manifest = &manifest
		resp.Allocations = make(map[string]AllocationDetails)
		allocationInfo = make(map[string]jobtypes.AllocationInfo)
	}
	// Collect resource usage if requested (works for both running and completed deployments if orchestrator exists)
	if request.IncludeUsage && hasOrchestrator {
		usage := make(map[string]*types.ExecutorStats, len(manifest.Allocations))
		// Build a map from allocation ID to manifest key for matching
		allocIDToManifestKey := make(map[string]string, len(allocationInfo))
		for allocID := range allocationInfo {
			parsedID, err := types.ParseAllocationID(allocID)
			if err != nil {
				log.Warnw("failed to parse allocation ID for usage mapping",
					"allocationID", allocID,
					"error", err)
				continue
			}
			manifestKey := parsedID.ManifestKey()
			allocIDToManifestKey[allocID] = manifestKey
		}
		for manifestKey, allocManifest := range manifest.Allocations {
			// Find the corresponding allocation ID
			// (linear scan; allocation counts are expected to be small)
			var matchingAllocID string
			for allocID, mappedKey := range allocIDToManifestKey {
				if mappedKey == manifestKey {
					matchingAllocID = allocID
					break
				}
			}
			if matchingAllocID == "" {
				log.Warnw("could not find matching allocation ID for manifest key",
					"manifestKey", manifestKey)
				continue
			}
			// Ask the allocation actor for its executor stats.
			reply, err := n.invokeBehaviour(
				allocManifest.Handle,
				behaviors.AllocationStatsBehavior,
				behaviors.AllocationStatsRequest{},
				allocationStatsRequestTimeout,
			)
			if err != nil {
				log.Warnw("failed to get allocation stats",
					"allocation", matchingAllocID,
					"error", err)
				continue
			}
			var statsResp behaviors.AllocationStatsResponse
			if err := json.Unmarshal(reply.Message, &statsResp); err != nil {
				log.Warnw("failed to decode allocation stats",
					"allocation", matchingAllocID,
					"error", err)
				continue
			}
			if statsResp.OK && statsResp.Stats != nil {
				// Update allocation details with stats
				if details, exists := resp.Allocations[matchingAllocID]; exists {
					usage[matchingAllocID] = statsResp.Stats
					details.ExecutorStats = statsResp.Stats
					resp.Allocations[matchingAllocID] = details
				}
			} else {
				log.Warnw("allocation stats error",
					"allocation", matchingAllocID,
					"error", statsResp.Error)
			}
		}
		resp.Usage = usage
	}
	// Collect logs if requested (works for both running and completed deployments if orchestrator exists)
	// NOTE(review): unlike the usage branch above, this one is not guarded by
	// hasOrchestrator, but allocationInfo is empty without an orchestrator so
	// every lookup below misses — confirm whether the guard was intended.
	if request.IncludeLogs {
		// Build a map from allocation ID to manifest key for matching
		manifestKeyToAllocID := make(map[string]string, len(allocationInfo))
		for allocID := range allocationInfo {
			parsedID, err := types.ParseAllocationID(allocID)
			if err != nil {
				log.Warnw("failed to parse allocation ID for logs mapping",
					"allocationID", allocID,
					"error", err)
				continue
			}
			manifestKey := parsedID.ManifestKey()
			manifestKeyToAllocID[manifestKey] = allocID
		}
		// Determine which allocations to get logs for
		allocationsToLog := make(map[string]bool)
		if len(request.AllocationNames) == 0 {
			// Get logs for all allocations in manifest
			for manifestKey := range manifest.Allocations {
				allocationsToLog[manifestKey] = true
			}
		} else {
			// Get logs for specified allocation names (could be config names or manifest keys)
			// Try to match them to manifest keys
			for _, requestedName := range request.AllocationNames {
				found := false
				// First try as manifest key directly
				if _, exists := manifest.Allocations[requestedName]; exists {
					allocationsToLog[requestedName] = true
					found = true
				} else {
					// Try as config name - search through manifest allocations
					for manifestKey, allocManifest := range manifest.Allocations {
						if allocManifest.RedundancyGroup == requestedName {
							allocationsToLog[manifestKey] = true
							found = true
							break
						}
					}
				}
				if !found {
					log.Warnw("requested allocation name not found in manifest",
						"requestedName", requestedName)
				}
			}
		}
		// Get logs for each allocation
		for manifestKey := range allocationsToLog {
			// Find the corresponding allocation ID
			allocID, exists := manifestKeyToAllocID[manifestKey]
			if !exists {
				log.Warnw("could not find matching allocation ID for manifest key",
					"manifestKey", manifestKey)
				continue
			}
			logsData, err := orch.GetAllocationLogs(manifestKey)
			if err != nil {
				log.Warnw("failed to get allocation logs",
					"allocation", allocID,
					"manifestKey", manifestKey,
					"error", err)
				// Add error to allocation details
				if details, exists := resp.Allocations[allocID]; exists {
					details.Logs = &AllocationLogs{
						Error: err.Error(),
					}
					resp.Allocations[allocID] = details
				}
				continue
			}
			// Write logs to directory (reuse existing logic)
			allocDir, err := orch.WriteAllocationLogs(manifestKey, logsData.Stdout, logsData.Stderr)
			if err != nil {
				log.Warnw("failed to write allocation logs",
					"allocation", allocID,
					"manifestKey", manifestKey,
					"error", err)
				if details, exists := resp.Allocations[allocID]; exists {
					details.Logs = &AllocationLogs{
						Error: err.Error(),
					}
					resp.Allocations[allocID] = details
				}
				continue
			}
			// Add logs to allocation details
			if details, exists := resp.Allocations[allocID]; exists {
				details.Logs = &AllocationLogs{
					StdoutPath:    filepath.Join(allocDir, "stdout.log"),
					StderrPath:    filepath.Join(allocDir, "stderr.log"),
					LogsWrittenTo: allocDir,
				}
				resp.Allocations[allocID] = details
			}
		}
	}
	n.sendReply(msg, resp)
}
// DeploymentInfoRequest asks for a consolidated view of one deployment.
type DeploymentInfoRequest struct {
	ID              string   `json:"id"`                         // Deployment ID (required)
	IncludeUsage    bool     `json:"include_usage,omitempty"`    // Include resource usage stats
	IncludeLogs     bool     `json:"include_logs,omitempty"`     // Include logs for all allocations
	AllocationNames []string `json:"allocation_names,omitempty"` // Specific allocations to include logs for (empty = all)
}

// AllocationLogs describes where an allocation's logs were written, or why
// retrieving them failed.
type AllocationLogs struct {
	StdoutPath    string `json:"stdout_path,omitempty"`     // Path to stdout.log file
	StderrPath    string `json:"stderr_path,omitempty"`     // Path to stderr.log file
	LogsWrittenTo string `json:"logs_written_to,omitempty"` // Directory path where logs are written
	Error         string `json:"error,omitempty"`           // Error retrieving logs (if any)
}

// AllocationDetails is the per-allocation section of a DeploymentInfoResponse.
type AllocationDetails struct {
	// From AllocationInfo
	AllocationID   string                           `json:"allocation_id"`
	Status         string                           `json:"status"`
	HeartbeatSeq   int64                            `json:"heartbeat_seq"`
	HasHealthCheck bool                             `json:"has_health_check"`
	Health         string                           `json:"health"`
	ResourceLimit  types.Resources                  `json:"resource_limit"`
	ResourceUsage  jobtypes.AllocationResourceUsage `json:"resource_usage"`
	DNSName        string                           `json:"dns_name"`
	IP             string                           `json:"ip"`
	Timestamp      int64                            `json:"timestamp"`
	// Optional: Resource usage stats (if IncludeUsage=true)
	ExecutorStats *types.ExecutorStats `json:"executor_stats,omitempty"`
	// Optional: Logs (if IncludeLogs=true)
	Logs *AllocationLogs `json:"logs,omitempty"`
}

// DeploymentInfoResponse is the reply to a DeploymentInfoRequest.
type DeploymentInfoResponse struct {
	// Basic deployment information
	ID     string `json:"id"`
	Status string `json:"status"`
	Error  string `json:"error,omitempty"`
	// Full manifest
	Manifest *jobs.EnsembleManifest `json:"manifest,omitempty"`
	// Allocation information
	Allocations map[string]AllocationDetails `json:"allocations,omitempty"`
	// Optional: Resource usage (if IncludeUsage=true)
	Usage map[string]*types.ExecutorStats `json:"usage,omitempty"`
}
// DeploymentShutdownRequest asks to shut down a running deployment.
type DeploymentShutdownRequest struct {
	// ID identifies the deployment.
	ID string
}

// DeploymentShutdownResponse is the reply to a DeploymentShutdownRequest.
type DeploymentShutdownResponse struct {
	OK    bool
	Error string
}
// handleDeploymentShutdown shuts down a running deployment via its live
// orchestrator and persists the resulting status. Requests for deployments
// that are not running are answered with ErrorDeploymentNotRunning.
func (n *Node) handleDeploymentShutdown(msg actor.Envelope) {
	defer msg.Discard()
	orchestratorID := ""
	replyErr := func(err error) {
		log.Errorw("deployment_shutdown_error",
			"labels", []string{string(observability.LabelDeployment)},
			"error", err,
			"orchestratorID", orchestratorID,
		)
		n.sendReply(msg, DeploymentShutdownResponse{Error: err.Error()})
	}

	var req DeploymentShutdownRequest
	if err := json.Unmarshal(msg.Message, &req); err != nil {
		replyErr(err)
		return
	}

	orch, err := n.orchestratorRegistry.GetOrchestrator(req.ID)
	if err != nil {
		replyErr(err)
		return
	}

	if orch.Status() != jobs.DeploymentStatusRunning {
		log.Debugw("deployment_not_running_for_shutdown",
			"labels", []string{string(observability.LabelDeployment)},
			"orchestratorID", req.ID,
			"status", orch.Status())
		// maybe-TODO: if it's still provisioning/committing,
		// we should stop the deployment process anyway
		n.sendReply(msg, DeploymentShutdownResponse{Error: ErrorDeploymentNotRunning.Error()})
		return
	}

	if err := orch.Shutdown(); err != nil {
		replyErr(err)
		return
	}

	// force status update and ignore status watcher
	if err := n.orchestratorRegistry.SaveOrchestrator(orch); err != nil {
		replyErr(err)
		return
	}

	n.sendReply(msg, DeploymentShutdownResponse{OK: true})
}
// handleDeploymentRevert rolls back a partially-provisioned deployment:
// forgets the ensemble's bids, destroys its subnet, and releases or
// uncommits each listed allocation. Per-allocation failures are logged but
// do not abort the revert; the reply is always OK once decoding succeeds.
func (n *Node) handleDeploymentRevert(msg actor.Envelope) {
	defer msg.Discard()
	orchestratorID := ""
	var request orchestrator.DeploymentRevertRequest
	if err := json.Unmarshal(msg.Message, &request); err != nil {
		log.Debugw("revert_deployment_unmarshal_error",
			"labels", []string{string(observability.LabelDeployment)},
			"error", err,
			"orchestratorID", orchestratorID,
		)
		n.sendReply(msg, orchestrator.DeploymentRevertResponse{
			OK:    false,
			Error: fmt.Sprintf("failed to unmarshal revert request: %v", err),
		})
		return
	}
	ensembleID := request.EnsembleID
	// forget bid
	n.lock.Lock()
	delete(n.bids, ensembleID)
	delete(n.answeredBids, ensembleID)
	n.lock.Unlock()
	// Subnet destruction is best-effort: it may never have been created.
	err := n.network.DestroySubnet(request.EnsembleID)
	if err != nil {
		log.Warnf("failed to destroy subnet for ensemble id: %s: %v (it may not have been created or may already been destroyed)", ensembleID, err)
	}
	for _, allocID := range request.AllocsByName {
		// Now the allocID comes pre-constructed from the orchestrator, so we use it directly
		// without calling types.ConstructAllocationID again
		// Here we're considering both the committed and uncommitted resources/allocations/ports from the orchestrator
		// TODO: consider the allocation state and perform the necessary actions eg: uncommit, release, etc.
		// This would need addition of a new behavior. UnCommitAllocationBehavior?
		// https://gitlab.com/nunet/device-management-service/-/issues/961
		// Release known allocations; allocations not in the allocator are
		// assumed to be committed-only and are uncommitted instead.
		if a, _ := n.allocator.GetAllocation(allocID); a != nil {
			if err := n.allocator.Release(context.Background(), allocID); err != nil {
				log.Errorw("revert_deployment_release_failure",
					"labels", []string{string(observability.LabelDeployment)},
					"ensembleID", ensembleID,
					"error", err)
			}
		} else {
			log.Debugf("allocation %s not found in allocator, skipping to uncommit", allocID)
			if err := n.allocator.Uncommit(context.Background(), allocID); err != nil {
				log.Errorw("revert_deployment_uncommit_failure",
					"labels", []string{string(observability.LabelDeployment)},
					"ensembleID", ensembleID,
					"error", err,
				)
			} else {
				log.Debugf("successfully uncommitted allocation %s", allocID)
			}
		}
	}
	log.Infow("deployment_reverted",
		"labels", []string{string(observability.LabelDeployment)},
		"ensembleID", ensembleID)
	// Send success response
	n.sendReply(msg, orchestrator.DeploymentRevertResponse{
		OK: true,
	})
}
// UpdateDeploymentRequest asks to update a running ensemble deployment with
// a new configuration.
type UpdateDeploymentRequest struct {
	// EnsembleID identifies the deployment to update.
	EnsembleID string
	// Ensemble is the new ensemble configuration.
	Ensemble jobtypes.EnsembleConfig
}

// UpdateDeploymentResponse is the reply to an UpdateDeploymentRequest.
type UpdateDeploymentResponse struct {
	OK    bool
	Error string `json:",omitempty"`
}
// handleDeploymentUpdate updates a running deployment with a new ensemble
// configuration via its live orchestrator, then persists the change.
//
// Fixes: fmt.Errorf with no format directives replaced by errors.New
// (staticcheck S1039); error wrapping now uses %w instead of %s so the
// cause chain is preserved (rendered messages unchanged).
func (n *Node) handleDeploymentUpdate(msg actor.Envelope) {
	defer msg.Discard()
	orchestratorID := ""
	handleErr := func(err error) {
		log.Errorw("deployment update error",
			"labels", []string{string(observability.LabelDeployment)},
			"error", err,
			"orchestratorID", orchestratorID,
		)
		n.sendReply(msg, UpdateDeploymentResponse{Error: err.Error()})
	}
	// Reject updates whose envelope would expire before the minimum update
	// time elapses.
	if time.Until(msg.Expiry()) < MinUpdateDeploymentTime {
		handleErr(errors.New("requested deployment update time too short"))
		return
	}
	var request UpdateDeploymentRequest
	if err := json.Unmarshal(msg.Message, &request); err != nil {
		handleErr(fmt.Errorf("unmarshal update deployment request: %w", err))
		return
	}
	orch, err := n.orchestratorRegistry.GetOrchestrator(request.EnsembleID)
	if err != nil {
		handleErr(err)
		return
	}
	if orch.Status() != jobs.DeploymentStatusRunning {
		handleErr(errors.Join(fmt.Errorf("deployment %s is not running(status=%v), cannot update", request.EnsembleID, orch.Status()), ErrorDeploymentNotRunning))
		return
	}
	log.Infof("updating ensemble: %s", orch.ID())
	if err := orch.Update(request.Ensemble, msg.Expiry().Add(-orchestrator.MinEnsembleUpdateTimeout)); err != nil {
		handleErr(fmt.Errorf("error updating ensemble: %w", err))
		return
	}
	n.sendReply(msg, UpdateDeploymentResponse{
		OK: true,
	})
	// update the deployment in db
	if err := n.updateDeployment(orch); err != nil {
		log.Errorf("error saving deployment %s: %s", orch.ID(), err)
	}
}
// updateDeployment persists an updated deployment to the database.
// Currently a no-op stub that always succeeds.
func (n *Node) updateDeployment(_ orchestrator.Orchestrator) error {
	// TODO
	return nil
}
// Deployment pruning and clearing commands

// DeploymentPruneRequest asks to delete finished deployments, either all of
// them (All) or only those created before a cutoff (Before).
type DeploymentPruneRequest struct {
	// Before is a duration (e.g. "12h", "7d") or datetime; deployments
	// created before the resulting cutoff are pruned. Ignored when All is set.
	Before string `json:"before,omitempty"`
	// All prunes every finished deployment regardless of age.
	All bool `json:"all,omitempty"`
}

// DeploymentPruneResponse is the reply to a DeploymentPruneRequest.
type DeploymentPruneResponse struct {
	OK    bool   `json:"ok"`
	Error string `json:"error,omitempty"`
}
// handleDeploymentPrune deletes finished (failed or completed) deployments,
// either all of them (All) or only those created before a cutoff derived
// from the "before" field (a duration like 12h/7d, or a datetime).
//
// Fixes: the terminal-status list is declared once instead of duplicated in
// both branches; the redundant trailing continue in the datetime-parse loop
// is gone; cutoff parsing is extracted into parsePruneCutoff, which also
// fixes the case where a successfully parsed zero-valued datetime was
// rejected with a nil wrapped error.
func (n *Node) handleDeploymentPrune(msg actor.Envelope) {
	defer msg.Discard()
	log.Infow("deployment_prune_started",
		"labels", []string{string(observability.LabelDeployment)},
		"msg", msg)
	orchestratorID := ""
	handleErr := func(err error) {
		log.Errorw("deployment_prune_error",
			"labels", []string{string(observability.LabelDeployment)},
			"error", err,
			"orchestratorID", orchestratorID,
		)
		n.sendReply(msg, DeploymentPruneResponse{Error: err.Error()})
	}
	var request DeploymentPruneRequest
	if err := json.Unmarshal(msg.Message, &request); err != nil {
		handleErr(fmt.Errorf("error unmarshalling deployment prune request: %s", err))
		return
	}
	// Only deployments in a terminal state (greater than Running) are pruned.
	statuses := []jobtypes.DeploymentStatus{
		jobtypes.DeploymentStatusFailed,
		jobtypes.DeploymentStatusCompleted,
	}
	if request.All {
		// delete all deployments whose status is greater than Running
		for _, s := range statuses {
			views, err := n.orchestratorRegistry.GetDeploymentsByStatus(s)
			if err != nil {
				handleErr(fmt.Errorf("failed to list deployments by status %s: %w", s.String(), err))
				return
			}
			for _, v := range views {
				orchestratorID = v.OrchestratorID
				if err := n.orchestratorRegistry.DeleteDeployment(v.OrchestratorID); err != nil {
					handleErr(fmt.Errorf("failed to delete deployment %s: %w", v.OrchestratorID, err))
					return
				}
			}
		}
		log.Infow("deployments_pruned_by_status",
			"labels", []string{string(observability.LabelDeployment)},
			"mode", "all_status_gt_running")
		n.sendReply(msg, DeploymentPruneResponse{OK: true})
		return
	}
	before := strings.TrimSpace(request.Before)
	if before == "" {
		handleErr(errors.New("before must be provided unless --all is used"))
		return
	}
	cutoffTime, err := parsePruneCutoff(before)
	if err != nil {
		handleErr(err)
		return
	}
	// delete terminal deployments created before the cutoff
	for _, s := range statuses {
		views, err := n.orchestratorRegistry.GetDeploymentsByStatus(s)
		if err != nil {
			handleErr(fmt.Errorf("failed to list deployments by status %s: %w", s.String(), err))
			return
		}
		for _, v := range views {
			if v.CreatedAt.Before(cutoffTime) {
				if err := n.orchestratorRegistry.DeleteDeployment(v.OrchestratorID); err != nil {
					handleErr(fmt.Errorf("failed to delete deployment %s: %w", v.OrchestratorID, err))
					return
				}
			}
		}
	}
	log.Infow("deployments_pruned",
		"labels", []string{string(observability.LabelDeployment)},
		"mode", "before")
	n.sendReply(msg, DeploymentPruneResponse{OK: true})
}

// parsePruneCutoff converts a "before" value into an absolute cutoff time.
// Supported forms: Nd (whole days), any Go duration (e.g. 90m, 12h), or a
// datetime in RFC3339, "2006-01-02 15:04:05" or "2006-01-02" layout.
func parsePruneCutoff(before string) (time.Time, error) {
	if strings.HasSuffix(before, "d") {
		// days is not a standard Go duration; handle explicitly
		daysStr := strings.TrimSuffix(before, "d")
		if daysStr == "" {
			return time.Time{}, fmt.Errorf("invalid before duration: %s", before)
		}
		nDays, err := strconv.Atoi(daysStr)
		if err != nil || nDays <= 0 {
			return time.Time{}, fmt.Errorf("invalid before duration days: %s", before)
		}
		return time.Now().AddDate(0, 0, -nDays), nil
	}
	if dur, err := time.ParseDuration(before); err == nil {
		return time.Now().Add(-dur), nil
	}
	// Try datetime formats; remember the last parse error for reporting.
	var parseErr error
	for _, layout := range []string{time.RFC3339, "2006-01-02 15:04:05", "2006-01-02"} {
		t, err := time.Parse(layout, before)
		if err == nil {
			return t, nil
		}
		parseErr = err
	}
	return time.Time{}, fmt.Errorf("invalid before value: %w", parseErr)
}
// DeploymentDeleteRequest asks the node to delete a single deployment,
// identified by its orchestrator ID.
type DeploymentDeleteRequest struct {
	OrchestratorID string `json:"orchestrator_id"`
}

// DeploymentDeleteResponse reports the outcome of a deployment delete
// request; Error is set (and OK is false) when the deletion failed.
type DeploymentDeleteResponse struct {
	OK    bool   `json:"ok"`
	Error string `json:"error,omitempty"`
}
// handleDeploymentDelete removes a single deployment, identified by its
// orchestrator ID, from the orchestrator registry and acknowledges the
// requester.
func (n *Node) handleDeploymentDelete(msg actor.Envelope) {
	defer msg.Discard()
	// Captured by handleErr so failures are logged with the deployment they
	// concern; empty until the request has been decoded.
	orchestratorID := ""
	handleErr := func(err error) {
		log.Errorw("deployment_delete_error",
			"labels", []string{string(observability.LabelDeployment)},
			"error", err,
			"orchestratorID", orchestratorID,
		)
		n.sendReply(msg, DeploymentDeleteResponse{Error: err.Error()})
	}
	var request DeploymentDeleteRequest
	if err := json.Unmarshal(msg.Message, &request); err != nil {
		// Wrap with %w (was %s) so errors.Is/As can still reach the
		// underlying unmarshalling error.
		handleErr(fmt.Errorf("error unmarshalling deployment delete request: %w", err))
		return
	}
	if request.OrchestratorID == "" {
		handleErr(errors.New("orchestrator_id is required"))
		return
	}
	orchestratorID = request.OrchestratorID
	// Delete the specific deployment
	if err := n.orchestratorRegistry.DeleteDeployment(request.OrchestratorID); err != nil {
		handleErr(fmt.Errorf("failed to delete deployment %s: %w", request.OrchestratorID, err))
		return
	}
	log.Infow("deployment_deleted",
		"labels", []string{string(observability.LabelDeployment)},
		"orchestratorID", request.OrchestratorID)
	n.sendReply(msg, DeploymentDeleteResponse{OK: true})
}
// handleBidSigning is for new provisioned dmses to sign bids
//
// It rebuilds the bid under this node's identity (substituting our host ID
// and actor handle), signs it with the provider key resolved from the root
// capability trust context, records the originating bid request keyed by
// ensemble ID, and replies with the signed bid. Any failure is reported
// back in the response's Error field.
func (n *Node) handleBidSigning(msg actor.Envelope) {
	defer msg.Discard()
	var req jobtypes.SignPromiseBidRequest
	err := json.Unmarshal(msg.Message, &req)
	if err != nil {
		n.sendReply(msg, jobtypes.PromiseBidSigningResponse{
			Error: err.Error(),
		})
		return
	}
	// Re-issue the bid: ensemble/node/location are copied from the incoming
	// bid; Peer and Handle are this node's own.
	bid := jobtypes.Bid{
		V1: &jobtypes.BidV1{
			EnsembleID: req.Bid.EnsembleID(),
			NodeID:     req.Bid.NodeID(),
			Peer:       n.hostID,
			Location:   req.Bid.Location(),
			Handle:     n.actor.Handle(),
		},
	}
	// Resolve the signing provider for our DID from the trust context.
	provider, err := n.rootCap.Trust().GetProvider(n.actor.Security().DID())
	if err != nil {
		log.Debugw("provider_retrieval_error",
			"labels", string(observability.LabelDeployment),
			"error", err)
		n.sendReply(msg, jobtypes.PromiseBidSigningResponse{
			Error: err.Error(),
		})
		return
	}
	err = bid.Sign(provider)
	if err != nil {
		log.Debugw("provider_sign_error",
			"labels", string(observability.LabelDeployment),
			"error", err)
		n.sendReply(msg, jobtypes.PromiseBidSigningResponse{
			Error: err.Error(),
		})
		return
	}
	// Remember the request so a later promise-bid round can be validated
	// against it (keyed by ensemble ID).
	n.storeBid(bid.EnsembleID(), req.Nounce, req.BidRequest)
	n.sendReply(msg, jobtypes.PromiseBidSigningResponse{
		Bid: bid,
	})
}
// handlePromiseBid converts a promise bid into a concretely provisioned
// resource. It looks up the bid previously signed for this ensemble, races
// all registered server providers to find a plan matching the bid's
// resource needs, provisions a server with the winning provider, connects
// to it (with retries), records the provisioned resource, and finally asks
// the new instance to sign the bid. The signed bid — or an error — is sent
// back to the requester.
func (n *Node) handlePromiseBid(msg actor.Envelope) {
	defer msg.Discard()
	log.Debug("handling promise bids")
	var req jobtypes.PromiseBidRequest
	err := json.Unmarshal(msg.Message, &req)
	if err != nil {
		n.sendReply(msg, jobtypes.ConvertedPromiseBidResponse{
			Error: err.Error(),
		})
		return
	}
	// check if we already signed this bid
	bid, ok := n.getBid(req.Bid.EnsembleID())
	if !ok {
		n.sendReply(msg, jobtypes.ConvertedPromiseBidResponse{
			Error: fmt.Sprintf("bid with ensemble id %s not found", req.Bid.EnsembleID()),
		})
		return
	}
	allProviders := n.serverProviderRegistry.All()
	log.Debugf("checking %d providers for provisioning", len(allProviders))
	ctx, cancel := context.WithCancel(n.ctx)
	defer cancel()
	// Race the providers: the first goroutine to match a plan wins (guarded
	// by the `found` CAS) and cancels the rest through ctx.
	var (
		wg             sync.WaitGroup
		found          int32
		targetPlan     provider.Plan
		targetProvider provider.Provider
		mu             sync.Mutex
	)
	for _, pp := range allProviders {
		wg.Add(1)
		go func(pp provider.Provider) {
			defer wg.Done()
			// Fast exit if a winner has already been chosen.
			if atomic.LoadInt32(&found) == 1 {
				return
			}
			select {
			case <-ctx.Done():
				return
			default:
			}
			plans, err := pp.ListPlans(ctx)
			if err != nil {
				return
			}
			matchedPlan, err := pp.SelectMatchingPlan(plans, bid.request.V1.Resources)
			if err != nil || matchedPlan == nil {
				return
			}
			// Only the first successful matcher records its result.
			if atomic.CompareAndSwapInt32(&found, 0, 1) {
				mu.Lock()
				targetPlan = *matchedPlan
				targetProvider = pp
				mu.Unlock()
				cancel() // stop others
			}
		}(pp)
	}
	wg.Wait()
	// wg.Wait() establishes the happens-before edge, so reading
	// targetPlan/targetProvider without the mutex below is safe.
	if atomic.LoadInt32(&found) == 0 || targetProvider == nil {
		log.Debug("targetProvider is nil — no matching plan found")
		n.sendReply(msg, jobtypes.ConvertedPromiseBidResponse{
			Error: fmt.Sprintf("no suitable plan found for bid with ensemble: %s", req.Bid.EnsembleID()),
		})
		return
	}
	// TODO: for now image is empty, maybe make it part of the plan object
	server, err := targetProvider.ProvisionServer(n.ctx, targetPlan, targetPlan.Name, "", msg.From.DID.String())
	if err != nil {
		n.sendReply(msg, jobtypes.ConvertedPromiseBidResponse{
			Error: fmt.Sprintf("failed to provision server for bid with ensemble: %s", req.Bid.EnsembleID()),
		})
		return
	}
	log.Debugf("successfully provisioned server %s using provider %s", server.ID, targetProvider.Name())
	// Wait for the freshly provisioned instance to become reachable over p2p.
	connected := false
	for i := 1; i <= maxRetries; i++ {
		err = n.network.Connect(n.ctx, fmt.Sprintf("%s/p2p/%s", server.ListenAddr, server.PeerID))
		if err == nil {
			connected = true
			break
		}
		time.Sleep(retryDelay)
	}
	if !connected {
		// Tear the instance back down so the provisioned server isn't leaked.
		n.sendReply(msg, jobtypes.ConvertedPromiseBidResponse{
			Error: fmt.Sprintf("failed to connect to provisioned resource for bid: %s", req.Bid.EnsembleID()),
		})
		err := targetProvider.DeleteServer(ctx, server.ID)
		if err != nil {
			log.Errorf("failed to delete provisioned instance: %v", err)
		}
		return
	}
	// Best-effort bookkeeping: a persistence failure does not abort the flow.
	err = n.addProvisionedResources(msg.From.DID.String(), targetProvider, server)
	if err != nil {
		log.Errorf("failed to record provisioned resource in store: %v", err)
	}
	destination, err := actor.HandleFromPeerID(server.PeerID)
	if err != nil {
		n.sendReply(msg, jobtypes.ConvertedPromiseBidResponse{
			Error: err.Error(),
		})
		return
	}
	// Ask the new instance itself to sign the bid.
	signReq := jobtypes.SignPromiseBidRequest{
		Bid:        req.Bid,
		BidRequest: bid.request,
		Nounce:     bid.nonce,
	}
	envelope, err := n.invokeBehaviour(destination, behaviors.PromiseBidSigningBehavior, signReq, invokeMessageTimeout)
	if err != nil {
		n.sendReply(msg, jobtypes.ConvertedPromiseBidResponse{
			Error: err.Error(),
		})
		return
	}
	var signedBidPayload jobtypes.PromiseBidSigningResponse
	err = json.Unmarshal(envelope.Message, &signedBidPayload)
	if err != nil {
		n.sendReply(msg, jobtypes.ConvertedPromiseBidResponse{
			Error: err.Error(),
		})
		return
	}
	resp := jobtypes.ConvertedPromiseBidResponse{
		Bid: signedBidPayload.Bid,
	}
	n.sendReply(msg, resp)
}
// addProvisionedResources persists a record of a freshly provisioned server
// in the gateway store, associating it with the orchestrator's DID and the
// provider that created it.
func (n *Node) addProvisionedResources(orchestratorDID string, p provider.Provider, server *provider.Server) error {
	record := &store.ProvisionedResources{
		ProvisionedVMPeerID: server.PeerID,
		Orchestrator:        orchestratorDID,
		ProviderName:        p.Name(),
		Resource:            *server,
		CreatedAt:           time.Now(),
	}
	return n.gatewayStore.Insert(record)
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package node
import (
"encoding/json"
"path/filepath"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/observability"
)
// LoggerConfigRequest carries an optional set of logger settings; only the
// fields that are set are applied by handleLoggerConfig.
type LoggerConfigRequest struct {
	Interval       int    `json:"interval,omitempty"`        // log flush interval
	URL            string `json:"url,omitempty"`             // Elasticsearch endpoint URL
	Level          string `json:"level,omitempty"`           // log level name
	APIKey         string `json:"api_key,omitempty"`         // Elasticsearch API key
	APMURL         string `json:"apm_url,omitempty"`         // APM server URL
	ElasticEnabled *bool  `json:"elastic_enabled,omitempty"` // toggle Elasticsearch logging
}

// LoggerConfigResponse acknowledges a logger configuration request.
// NOTE(review): OK has no json tag (serializes as "OK") while Error does —
// confirm whether consumers rely on this asymmetry before adding a tag.
type LoggerConfigResponse struct {
	Error string `json:"error,omitempty"`
	OK    bool
}
// handleLoggerConfig applies a remote logger reconfiguration request. Every
// field of LoggerConfigRequest is optional; fields that are set are applied
// in order (flush interval, level, Elasticsearch URL, API key, APM URL,
// Elasticsearch enable flag). The first failing step aborts the request and
// replies with the error; earlier steps remain applied.
func (n *Node) handleLoggerConfig(msg actor.Envelope) {
	defer msg.Discard()
	handleErr := func(err error) {
		log.Errorw("logger_config_error",
			"labels", string(observability.LabelNode),
			"error", err)
		n.sendReply(msg, LoggerConfigResponse{Error: err.Error()})
	}
	var (
		req  LoggerConfigRequest
		resp LoggerConfigResponse
	)
	if err := json.Unmarshal(msg.Message, &req); err != nil {
		handleErr(err)
		return
	}
	log.Debugw("logger_config_request_received",
		"labels", string(observability.LabelNode),
		"configRequest", req)
	if req.Interval != 0 {
		if err := observability.SetFlushInterval(req.Interval); err != nil {
			handleErr(err)
			return
		}
		log.Debugw("logger_flush_interval_updated",
			"labels", string(observability.LabelNode),
			"interval", req.Interval)
	}
	if req.Level != "" {
		if err := observability.SetLogLevel(req.Level); err != nil {
			handleErr(err)
			return
		}
		log.Debugw("logger_level_updated",
			"labels", string(observability.LabelNode),
			"level", req.Level)
	}
	if req.URL != "" {
		if err := observability.SetElasticsearchEndpoint(req.URL); err != nil {
			handleErr(err)
			return
		}
		log.Debugw("logger_elasticsearch_endpoint_updated",
			"labels", string(observability.LabelNode),
			"url", req.URL)
	}
	if req.APIKey != "" { // Handle API Key
		if err := observability.SetAPIKey(req.APIKey); err != nil {
			handleErr(err)
			return
		}
		// The key itself is deliberately not logged.
		log.Debugw("logger_api_key_updated",
			"labels", string(observability.LabelNode))
	}
	if req.APMURL != "" { // Handle APM URL
		if err := observability.SetAPMURL(req.APMURL); err != nil {
			handleErr(err)
			return
		}
		log.Debugw("logger_apm_url_updated",
			"labels", string(observability.LabelNode),
			"apmUrl", req.APMURL)
	}
	if req.ElasticEnabled != nil { // Handle Elasticsearch Enabled
		if err := observability.EnableElasticsearchLogging(*req.ElasticEnabled); err != nil {
			handleErr(err)
			return
		}
		log.Debugw("logger_elasticsearch_enabled_flag_updated",
			"labels", string(observability.LabelNode),
			"enabled", *req.ElasticEnabled)
	}
	resp.OK = true
	n.sendReply(msg, resp)
}
// handleFlightrec triggers a flight-recorder capture into the node's log
// directory and acknowledges with an empty PingResponse.
// NOTE(review): the capture's outcome is not surfaced to the caller —
// confirm FlightrecCapture reports failures through some other channel.
func (n *Node) handleFlightrec(msg actor.Envelope) {
	defer msg.Discard()
	observability.FlightrecCapture(filepath.Join(n.dmsConfig.WorkDir, "logs"), "flightrec.trace")
	n.sendReply(msg, PingResponse{})
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package node
import (
"context"
"encoding/json"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/observability"
"gitlab.com/nunet/device-management-service/types"
)
// OnboardRequest carries the desired onboarding configuration.
// NOTE(review): NoGPU and GPUs are not read by handleOnboard — confirm they
// are consumed elsewhere (e.g. by the CLI before the request is sent).
type OnboardRequest struct {
	NoGPU  bool
	GPUs   string
	Config types.OnboardingConfig
}

// OnboardResponse reports the onboarding outcome; on success Config holds
// the effective configuration that was applied.
type OnboardResponse struct {
	OK     bool                   `json:"success"`
	Error  string                 `json:"error,omitempty"`
	Config types.OnboardingConfig `json:"config,omitempty"`
}
// handleOnboard services an onboarding request: it decodes the desired
// onboarding configuration, applies it through the onboarding subsystem,
// and replies with the resulting effective configuration.
func (n *Node) handleOnboard(msg actor.Envelope) {
	defer msg.Discard()

	replyErr := func(err error) {
		log.Errorw("onboard_error",
			"labels", string(observability.LabelNode),
			"error", err)
		n.sendReply(msg, OnboardResponse{Error: err.Error()})
	}

	var req OnboardRequest
	if err := json.Unmarshal(msg.Message, &req); err != nil {
		replyErr(err)
		return
	}

	cfg, err := n.onboarding.Onboard(context.Background(), req.Config)
	if err != nil {
		replyErr(err)
		return
	}

	n.sendReply(msg, OnboardResponse{OK: true, Config: cfg})
}
// OffboardRequest has no fields; offboarding takes no parameters.
type OffboardRequest struct{}

// OffboardResponse reports whether offboarding succeeded.
type OffboardResponse struct {
	Success bool   `json:"success"`
	Error   string `json:"error,omitempty"`
}
// handleOffboard removes this node from the onboarded state and reports the
// outcome to the requester.
func (n *Node) handleOffboard(msg actor.Envelope) {
	defer msg.Discard()

	replyErr := func(err error) {
		log.Errorw("offboard_error",
			"labels", string(observability.LabelNode),
			"error", err)
		n.sendReply(msg, OffboardResponse{Error: err.Error()})
	}

	if err := n.onboarding.Offboard(context.Background()); err != nil {
		replyErr(err)
		return
	}
	n.sendReply(msg, OffboardResponse{Success: true})
}
// OnboardStatusResponse reports whether this node is currently onboarded.
type OnboardStatusResponse struct {
	Onboarded bool `json:"onboarded"`
}
// handleOnboardStatus replies with this node's current onboarding state.
func (n *Node) handleOnboardStatus(msg actor.Envelope) {
	defer msg.Discard()
	n.sendReply(msg, OnboardStatusResponse{Onboarded: n.onboarding.IsOnboarded()})
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package node
import (
"context"
"encoding/json"
"fmt"
"time"
kbucket "github.com/libp2p/go-libp2p-kbucket"
"github.com/libp2p/go-libp2p/core/peer"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/network"
"gitlab.com/nunet/device-management-service/network/libp2p"
"gitlab.com/nunet/device-management-service/types"
)
// defaultPingTimeout bounds how long a single peer ping may take.
const defaultPingTimeout = 20 * time.Second

// PingRequest names the host (peer) to ping.
type PingRequest struct {
	Host string
}

// PingResponse carries the ping outcome: the round-trip time in
// milliseconds, or an error description.
type PingResponse struct {
	Error string
	RTT   int64
}

// ErrHostEmpty is returned when a ping request carries no host.
var ErrHostEmpty = fmt.Errorf("host is empty")
// handlePeerPing pings the host named in the request over the p2p network
// and replies with the round-trip time in milliseconds.
func (n *Node) handlePeerPing(msg actor.Envelope) {
	defer msg.Discard()

	replyErr := func(err error) {
		log.Errorw("peer_ping_error", "error", err)
		n.sendReply(msg, PingResponse{Error: err.Error()})
	}

	var req PingRequest
	if err := json.Unmarshal(msg.Message, &req); err != nil {
		replyErr(types.ErrUnmarshal)
		return
	}
	if req.Host == "" {
		replyErr(fmt.Errorf("ping request: %w", ErrHostEmpty))
		return
	}

	res, err := n.network.Ping(context.Background(), req.Host, defaultPingTimeout)
	if err != nil {
		replyErr(err)
		return
	}

	reply := PingResponse{RTT: res.RTT.Milliseconds()}
	if res.Error != nil {
		reply.Error = res.Error.Error()
	}
	n.sendReply(msg, reply)
}
// PeersListResponse lists the peer IDs currently known to the network layer.
type PeersListResponse struct {
	Peers []peer.ID
}
// handlePeersList replies with the set of peers currently known to the
// network layer.
func (n *Node) handlePeersList(msg actor.Envelope) {
	defer msg.Discard()
	// Build the response directly from the network's peer list; the previous
	// code allocated an empty slice only to immediately overwrite it.
	n.sendReply(msg, PeersListResponse{Peers: n.network.Peers()})
}
// PeerAddrInfoResponse carries this node's peer ID and listen address.
type PeerAddrInfoResponse struct {
	ID      string `json:"id"`
	Address string `json:"listen_addr"`
}
// handlePeerAddrInfo replies with this node's peer ID and listen address,
// as reported by the network layer's statistics.
func (n *Node) handlePeerAddrInfo(msg actor.Envelope) {
	defer msg.Discard()
	st := n.network.Stat()
	n.sendReply(msg, PeerAddrInfoResponse{ID: st.ID, Address: st.ListenAddr})
}
// PeerDHTResponse lists the peers held in the libp2p DHT routing table.
type PeerDHTResponse struct {
	Peers []kbucket.PeerInfo
}
// handlePeerDHT replies with the peers currently in the libp2p DHT routing
// table. It is only meaningful when the network implementation is libp2p.
func (n *Node) handlePeerDHT(msg actor.Envelope) {
	defer msg.Discard()
	// get the underlying libp2p instance and extract the DHT data
	libp2pNet, ok := n.network.(*libp2p.Libp2p)
	if !ok {
		// NOTE(review): no reply is sent on this path — the requester will
		// time out rather than receive an error; confirm this is intended.
		log.Debug("peer_dht_not_libp2p_network")
		return
	}
	resp := PeerDHTResponse{
		Peers: libp2pNet.DHT.RoutingTable().GetPeerInfos(),
	}
	n.sendReply(msg, resp)
}
// PeerConnectRequest carries the multiaddress to dial.
type PeerConnectRequest struct {
	Address string
}

// PeerConnectResponse reports the dial outcome: Status is "CONNECTED" on
// success, otherwise Error describes the failure.
type PeerConnectResponse struct {
	Status string
	Error  string
}
// handlePeerConnect dials the multiaddress given in the request and replies
// with status "CONNECTED" on success or the dial error otherwise.
func (n *Node) handlePeerConnect(msg actor.Envelope) {
	defer msg.Discard()

	replyErr := func(err error) {
		log.Errorw("peer_connect_error", "error", err)
		n.sendReply(msg, PeerConnectResponse{Error: err.Error()})
	}

	var req PeerConnectRequest
	if err := json.Unmarshal(msg.Message, &req); err != nil {
		log.Debugw("peer_connect_unmarshal_error", "error", err)
		replyErr(fmt.Errorf("peer connect: %w", types.ErrUnmarshal))
		return
	}

	if err := n.network.Connect(context.Background(), req.Address); err != nil {
		replyErr(err)
		return
	}
	n.sendReply(msg, PeerConnectResponse{Status: "CONNECTED"})
}
// PeerScoreResponse maps peer ID (string form) to its broadcast score
// snapshot.
type PeerScoreResponse struct {
	Score map[string]*network.PeerScoreSnapshot
}
// handlePeerScore replies with the current broadcast score snapshot for
// every peer, keyed by the peer ID's string form.
func (n *Node) handlePeerScore(msg actor.Envelope) {
	defer msg.Discard()
	scores := make(map[string]*network.PeerScoreSnapshot)
	for id, snap := range n.network.GetBroadcastScore() {
		scores[id.String()] = snap
	}
	n.sendReply(msg, PeerScoreResponse{Score: scores})
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package node
import (
"context"
"fmt"
"os"
"github.com/libp2p/go-libp2p/core/peer"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/internal/config"
"gitlab.com/nunet/device-management-service/lib/did"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils/sys"
)
// HelloResponse identifies this node by its DID.
type HelloResponse struct {
	DID did.DID
}

// PublicStatusResponse is the unauthenticated status reply: a coarse status
// string plus the machine's resources when they could be read.
type PublicStatusResponse struct {
	Status    string
	Resources types.Resources
}

// NetworkInterfaces describes one local network interface.
type NetworkInterfaces struct {
	Name       string   `json:"name"`
	IP         []string `json:"ip"`
	MacAddress string   `json:"mac_address"`
}

// DiscoveryStatus is the full per-peer status report assembled by
// handleStatusDiscoveryBehavior.
type DiscoveryStatus struct {
	Hostname string `json:"hostname"`
	DID      string `json:"did"`
	Network  struct {
		Interfaces []NetworkInterfaces `json:"interfaces"`
		P2P        types.NetworkStats  `json:"p2p"`
	} `json:"network"`
	Onboarded bool `json:"onboarded"`
	Resources struct {
		Total     types.Resources `json:"total"`
		Onboarded types.Resources `json:"onboarded"`
		Allocated types.Resources `json:"allocated"`
		Free      types.Resources `json:"free"`
	} `json:"resources"`
	Config config.Config `json:"config"`
	// Errors collects non-fatal collection failures encountered while
	// building the report.
	Errors []string `json:"errors"`
}

// DiscoveryStatusResponse maps peer ID to its status report.
type DiscoveryStatusResponse map[string]DiscoveryStatus
// publicHelloBehavior handles an unauthenticated hello: it derives the
// sender's peer ID from its DID, marks that peer as having greeted us, and
// then answers like a regular hello.
func (n *Node) publicHelloBehavior(msg actor.Envelope) {
	pubk, err := did.PublicKeyFromDID(msg.From.DID)
	if err != nil {
		log.Debugf("failed to extract public key from DID: %s", err)
		return
	}
	p, err := peer.IDFromPublicKey(pubk)
	if err != nil {
		log.Debugf("failed to extract peer ID from public key: %s", err)
		return
	}
	n.lock.Lock()
	if st, ok := n.peers[p]; ok {
		st.helloIn = true
	} else if n.network.PeerConnected(p) {
		// race with connected notification
		st = &peerState{helloIn: true}
		n.peers[p] = st
	}
	n.lock.Unlock()
	n.handleHello(msg)
}
// handleBroadcastHelloBehavior answers hello messages received via
// broadcast exactly like a directed hello.
func (n *Node) handleBroadcastHelloBehavior(msg actor.Envelope) {
	n.handleHello(msg)
}
// handleHello answers a hello message with this node's DID.
func (n *Node) handleHello(msg actor.Envelope) {
	defer msg.Discard()
	log.Debugf("hello from %s", msg.From.Address.HostID)
	n.sendReply(msg, HelloResponse{DID: n.actor.Security().DID()})
}
// publicStatusBehavior replies with a coarse health status ("OK"/"ERROR")
// and, when readable, the machine's hardware resources.
func (n *Node) publicStatusBehavior(msg actor.Envelope) {
	defer msg.Discard()
	reply := PublicStatusResponse{Status: "OK"}
	if machineResources, err := n.hardware.GetMachineResources(); err != nil {
		reply = PublicStatusResponse{Status: "ERROR"}
	} else {
		reply.Resources = machineResources.Resources
	}
	n.sendReply(msg, reply)
}
// handleStatusDiscoveryBehavior builds a per-peer status report (hostname,
// DID, network interfaces, p2p stats, resource figures, onboarding state
// and current config) and replies with it. Individual collection failures
// are accumulated into the report's Errors list instead of aborting.
func (n *Node) handleStatusDiscoveryBehavior(msg actor.Envelope) {
	defer msg.Discard()
	var err error
	resp := make(DiscoveryStatusResponse, 0)
	// TODO peerIDs for peers under the controller
	// only self peer for now
	for _, peerID := range []string{n.network.GetHostID().String()} {
		peerInfo := DiscoveryStatus{}
		// collectErrors appends a non-nil error to this peer's error list so
		// one failing probe doesn't hide the rest of the report.
		collectErrors := func(msg string, err error) {
			if err != nil {
				peerInfo.Errors = append(peerInfo.Errors, fmt.Sprintf("%s: %v", msg, err))
			}
		}
		peerInfo.Hostname, err = os.Hostname()
		collectErrors("error getting hostname", err)
		peerInfo.DID = n.actor.Security().DID().String()
		ifaceList, err := sys.GetNetInterfaces()
		collectErrors("error getting network interfaces", err)
		for _, iface := range ifaceList {
			netIface := NetworkInterfaces{}
			netIface.Name = iface.Name
			netIface.MacAddress = iface.HardwareAddr.String()
			ip, err := iface.Addrs()
			collectErrors(fmt.Sprintf("error getting addrs for iface %s", iface.Name), err)
			for _, addr := range ip {
				netIface.IP = append(netIface.IP, addr.String())
			}
			peerInfo.Network.Interfaces = append(peerInfo.Network.Interfaces, netIface)
		}
		peerInfo.Network.P2P = n.network.Stat()
		if n.hardware != nil {
			totalResrc, err := n.hardware.GetMachineResources()
			collectErrors("error getting total machine resource", err)
			peerInfo.Resources.Total = totalResrc.Resources
		} else {
			collectErrors("error getting hardware info", fmt.Errorf("node hardware manager not set"))
		}
		if n.onboarding != nil {
			peerInfo.Onboarded = n.onboarding.Config.IsOnboarded
			if peerInfo.Onboarded {
				peerInfo.Resources.Onboarded = n.onboarding.Config.OnboardedResources
				peerInfo.Resources.Allocated, err = n.resourceManager.GetTotalAllocation()
				collectErrors("error getting allocated resources", err)
				freeResources, err := n.resourceManager.GetFreeResources(context.Background())
				collectErrors("error getting free resources", err)
				peerInfo.Resources.Free = freeResources.Resources
			}
		} else {
			collectErrors("error getting onboarding info", fmt.Errorf("node onboarding not set"))
		}
		peerInfo.Config = n.dmsConfig
		resp[peerID] = peerInfo
	}
	n.sendReply(msg, resp)
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package node
import (
"context"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/types"
)
// ResourcesResponse is the common reply for resource queries: the requested
// resource figures plus an OK flag or an error description.
type ResourcesResponse struct {
	OK        bool
	Resources types.Resources
	Error     string `json:"error,omitempty"`
}
// handleAllocatedResources replies with the total resources currently
// allocated on this node.
func (n *Node) handleAllocatedResources(msg actor.Envelope) {
	defer msg.Discard()
	allocated, err := n.resourceManager.GetTotalAllocation()
	if err != nil {
		log.Errorf("Error handling allocated resources: %s", err)
		n.sendReply(msg, ResourcesResponse{Error: err.Error()})
		return
	}
	n.sendReply(msg, ResourcesResponse{OK: true, Resources: allocated})
}
// handleFreeResources replies with the resources currently free on this
// node.
func (n *Node) handleFreeResources(msg actor.Envelope) {
	defer msg.Discard()
	free, err := n.resourceManager.GetFreeResources(context.Background())
	if err != nil {
		log.Errorf("Error handling free resources: %s", err)
		n.sendReply(msg, ResourcesResponse{Error: err.Error()})
		return
	}
	n.sendReply(msg, ResourcesResponse{OK: true, Resources: free.Resources})
}
// handleOnboardedResources replies with the resources this node has
// committed during onboarding.
func (n *Node) handleOnboardedResources(msg actor.Envelope) {
	defer msg.Discard()
	onboarded, err := n.resourceManager.GetOnboardedResources(context.Background())
	if err != nil {
		log.Errorf("Error handling onboarded resources: %s", err)
		n.sendReply(msg, ResourcesResponse{Error: err.Error()})
		return
	}
	n.sendReply(msg, ResourcesResponse{OK: true, Resources: onboarded.Resources})
}
// handleHardwareUsage replies with the machine's current hardware usage.
func (n *Node) handleHardwareUsage(msg actor.Envelope) {
	defer msg.Discard()
	usage, err := n.hardware.GetUsage()
	if err != nil {
		log.Errorf("Error handling hardware usage: %s", err)
		n.sendReply(msg, ResourcesResponse{Error: err.Error()})
		return
	}
	n.sendReply(msg, ResourcesResponse{OK: true, Resources: usage})
}
// handleHardwareSpec replies with the machine's total hardware resources.
func (n *Node) handleHardwareSpec(msg actor.Envelope) {
	defer msg.Discard()
	spec, err := n.hardware.GetMachineResources()
	if err != nil {
		log.Errorf("Error handling hardware spec: %s", err)
		n.sendReply(msg, ResourcesResponse{Error: err.Error()})
		return
	}
	n.sendReply(msg, ResourcesResponse{OK: true, Resources: spec.Resources})
}
// Copyright 2025, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package node
import (
"encoding/json"
"fmt"
"gitlab.com/nunet/device-management-service/actor"
)
// CreateVolumeRequest asks the node to create a named volume secured with
// the supplied client certificate (PEM).
type CreateVolumeRequest struct {
	Name      string `json:"name"`
	ClientPEM string `json:"client_pem"`
}

// CreateVolumeResponse returns the CA data for the new volume, or an error.
type CreateVolumeResponse struct {
	OK     bool
	CAData string `json:"ca_data"`
	Error  string `json:"error,omitempty"`
}
// handleCreateVolume creates a named volume with the supplied client PEM,
// records the requester's DID as the volume owner, and replies with the CA
// data for accessing the volume.
func (n *Node) handleCreateVolume(msg actor.Envelope) {
	defer msg.Discard()

	replyErr := func(err error) {
		log.Errorf("Error creating volume: %s", err)
		n.sendReply(msg, CreateVolumeResponse{Error: err.Error()})
	}

	var req CreateVolumeRequest
	if err := json.Unmarshal(msg.Message, &req); err != nil {
		replyErr(err)
		return
	}

	caData, err := n.volumeController.CreateVolume(req.Name, req.ClientPEM)
	if err != nil {
		replyErr(err)
		return
	}

	// Remember who created the volume; ownership gates deletion later.
	n.lock.Lock()
	n.volumeOwners[req.Name] = msg.From.DID.String()
	n.lock.Unlock()

	n.sendReply(msg, CreateVolumeResponse{OK: true, CAData: caData})
}
// DeleteVolumeRequest names the volume to delete.
type DeleteVolumeRequest struct {
	Name string `json:"name"`
}

// DeleteVolumeResponse acknowledges a volume deletion.
type DeleteVolumeResponse struct {
	OK    bool
	Error string `json:"error,omitempty"`
}
// handleDeleteVolume deletes a named volume. Only the DID that created the
// volume (recorded in volumeOwners by handleCreateVolume) may delete it.
func (n *Node) handleDeleteVolume(msg actor.Envelope) {
	defer msg.Discard()
	handleErr := func(err error) {
		log.Errorf("Error deleting volume: %s", err)
		n.sendReply(msg, DeleteVolumeResponse{Error: err.Error()})
	}
	var request DeleteVolumeRequest
	if err := json.Unmarshal(msg.Message, &request); err != nil {
		handleErr(err)
		return
	}
	// Read the owner map under the lock: handleCreateVolume writes it under
	// n.lock, so the previous unguarded read was a data race.
	n.lock.Lock()
	owner, exists := n.volumeOwners[request.Name]
	n.lock.Unlock()
	if !exists {
		handleErr(fmt.Errorf("volume does not exist or ownership is unknown"))
		return
	}
	if owner != msg.From.DID.String() {
		handleErr(fmt.Errorf("permission denied: only the creator can delete the volume"))
		return
	}
	if err := n.volumeController.DeleteVolume(request.Name); err != nil {
		handleErr(err)
		return
	}
	n.lock.Lock()
	delete(n.volumeOwners, request.Name)
	n.lock.Unlock()
	n.sendReply(msg, DeleteVolumeResponse{OK: true})
}
// StartVolumeRequest names the volume to start.
type StartVolumeRequest struct {
	Name string `json:"name"`
}

// StartVolumeResponse acknowledges a volume start.
type StartVolumeResponse struct {
	OK    bool
	Error string `json:"error,omitempty"`
}
// handleStartVolume starts the named volume and replies with OK on success.
func (n *Node) handleStartVolume(msg actor.Envelope) {
	defer msg.Discard()
	handleErr := func(err error) {
		log.Errorf("Error starting volume: %s", err)
		n.sendReply(msg, StartVolumeResponse{Error: err.Error()})
	}
	var request StartVolumeRequest
	if err := json.Unmarshal(msg.Message, &request); err != nil {
		handleErr(err)
		return
	}
	// Use the shared error helper: the previous inline error path logged
	// "Error deleting volume" (copy-paste from handleDeleteVolume) and left
	// handleErr unused.
	if err := n.volumeController.StartVolume(request.Name); err != nil {
		handleErr(err)
		return
	}
	n.sendReply(msg, StartVolumeResponse{OK: true})
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package node
import (
"fmt"
"os"
"path/filepath"
"strings"
"time"
"github.com/spf13/afero"
"gitlab.com/nunet/device-management-service/lib/crypto"
"gitlab.com/nunet/device-management-service/lib/crypto/keystore"
"gitlab.com/nunet/device-management-service/lib/did"
"gitlab.com/nunet/device-management-service/lib/ucan"
)
// Well-known context names and the on-disk layout for identity material.
const (
	DefaultContextName = "dms"            // default DMS key/capability context
	UserContextName    = "user"           // user-owned context
	KeystoreDir        = "key/"           // keystore subdirectory
	CapstoreDir        = "cap/"           // capability store subdirectory
	DMSPassphraseEnv   = "DMS_PASSPHRASE" // name of the passphrase environment variable
)
// Context-name prefixes that select wallet-backed trust contexts.
const (
	ledger = "ledger"
	eternl = "eternl"
)

// IsLedgerContext checks if the context is a ledger context (i.e. its name
// starts with "ledger").
func IsLedgerContext(context string) bool {
	return len(context) >= len(ledger) && context[:len(ledger)] == ledger
}

// IsEternlContext checks if the context is a eternl context (i.e. its name
// starts with "eternl").
func IsEternlContext(context string) bool {
	return len(context) >= len(eternl) && context[:len(eternl)] == eternl
}
// GetContextKey returns the key part of the context, if it has a prefix.
// "ledger:3" yields "3"; names with zero or more than one ':' are returned
// unchanged (matching the original exactly-two-parts behavior).
func GetContextKey(context string) string {
	if strings.Count(context, ":") != 1 {
		return context
	}
	_, key, _ := strings.Cut(context, ":")
	return key
}
// CreateTrustContextFromKeyStore opens the keystore under
// <keyStorePath>/key/, unlocks the key named contextKey with the given
// passphrase, and builds a DID trust context around the resulting private
// key. It returns the trust context together with the raw private key.
func CreateTrustContextFromKeyStore(
	fs afero.Fs, contextKey,
	passphrase, keyStorePath string,
) (did.TrustContext, crypto.PrivKey, error) {
	keyStoreDir := filepath.Join(keyStorePath, KeystoreDir)
	// support FS cache for E2E tests #1139
	fsCache := os.Getenv("DMS_E2E_CACHE_KEYS") == "1"
	ks, err := keystore.New(fs, keyStoreDir, fsCache)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to open keystore: %w", err)
	}
	ksPrivKey, err := ks.Get(contextKey, passphrase)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to get key from keystore %s: %w", contextKey, err)
	}
	priv, err := ksPrivKey.PrivKey()
	if err != nil {
		return nil, nil, fmt.Errorf("unable to convert key from keystore to private key: %w", err)
	}
	trustCtx, err := did.NewTrustContextWithPrivateKey(priv)
	if err != nil {
		return nil, nil, fmt.Errorf("unable to create trust context: %w", err)
	}
	// Check if this key has a PRISM DID association
	// If so, add PRISM provider to trust context
	// (best-effort: errors here are deliberately ignored — a key without a
	// PRISM DID is perfectly valid)
	prismDIDStr, err := GetPrismDID(fs, keyStorePath, contextKey)
	if err == nil && prismDIDStr != "" {
		prismDID, err := did.FromString(prismDIDStr)
		if err == nil {
			prismProvider, err := did.ProviderFromPRISMPrivateKey(prismDID, priv)
			if err == nil {
				trustCtx.AddProvider(prismProvider)
			}
		}
	}
	return trustCtx, priv, nil
}
// LoadCapabilityContext reads the persisted capability context named `name`
// from <capStorePath>/cap/<name>.cap and materializes it against the given
// trust context.
func LoadCapabilityContext(trustCtx did.TrustContext, fs afero.Fs, name string, capStorePath string) (ucan.CapabilityContext, error) {
	capFile := filepath.Join(capStorePath, CapstoreDir, fmt.Sprintf("%s.cap", name))
	src, err := fs.Open(capFile)
	if err != nil {
		return nil, fmt.Errorf("unable to open capability context file: %w", err)
	}
	defer src.Close()

	loaded, err := ucan.LoadCapabilityContextWithName(name, trustCtx, src)
	if err != nil {
		return nil, fmt.Errorf("unable to load capability context: %w", err)
	}
	return loaded, nil
}
// SaveCapabilityContext persists a capability context to
// <capStorePath>/cap/<name>.cap, first moving any existing file to a
// timestamped backup so a failed write never destroys the previous state.
func SaveCapabilityContext(capCtx ucan.CapabilityContext, fs afero.Fs, capStorePath string) error {
	name := capCtx.Name()
	capStoreDir := filepath.Join(capStorePath, CapstoreDir)
	capCtxFile := filepath.Join(capStoreDir, fmt.Sprintf("%s.cap", name))
	capCtxBackup := filepath.Join(capStoreDir, fmt.Sprintf("%s.cap.%d", name, time.Now().Unix()))
	// ensure the directory exists
	if err := fs.MkdirAll(capStoreDir, os.FileMode(0o700)); err != nil {
		// message fixed: previously read "director"
		return fmt.Errorf("creating capability context directory: %w", err)
	}
	// first take a backup -- move the current context
	if _, err := fs.Stat(capCtxFile); err == nil {
		if err := fs.Rename(capCtxFile, capCtxBackup); err != nil {
			return fmt.Errorf("error backing up current capability context: %w", err)
		}
	}
	// now open for writing
	f, err := fs.Create(capCtxFile)
	if err != nil {
		return fmt.Errorf("error creating new capability context file: %w", err)
	}
	defer f.Close()
	if err := ucan.SaveCapabilityContext(capCtx, f); err != nil {
		return fmt.Errorf("error saving capability context: %w", err)
	}
	return nil
}
// GetTrustContext returns a DID TrustContext for the given logical context
// name. For ledger-backed contexts it now supports:
//
//   - "ledger"         → account index 0
//   - "ledger:<index>" → explicit ledger account index
//   - "ledger:<alias>" → alias resolved via ledger_aliases.json
func GetTrustContext(
	fs afero.Fs,
	context, passphrase, userDir string,
) (did.TrustContext, error) {
	switch {
	case IsLedgerContext(context):
		idx, err := ResolveLedgerIndex(fs, userDir, GetContextKey(context))
		if err != nil {
			return nil, err
		}
		p, err := did.NewLedgerWalletProvider(idx)
		if err != nil {
			return nil, err
		}
		return did.NewTrustContextWithProvider(p), nil
	case IsEternlContext(context):
		p, err := did.NewEternlWalletProvider()
		if err != nil {
			return nil, err
		}
		return did.NewTrustContextWithProvider(p), nil
	default:
		// Fall back to a key held in the local keystore.
		trustCtx, _, err := CreateTrustContextFromKeyStore(
			fs, context, passphrase, userDir,
		)
		return trustCtx, err
	}
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package geolocation
import (
"bufio"
"bytes"
_ "embed"
"fmt"
"math"
"strconv"
"strings"
jtypes "gitlab.com/nunet/device-management-service/dms/jobs/types"
)
// currently only using a GeoNames file for cities with population > 5000
// download it here: https://download.geonames.org/export/dump/cities5000.zip
//
//go:embed cities5000.txt
var cities5000 string

// LightSpeed is the speed of light, used for latency/distance estimates.
const LightSpeed = 299792.458 // in km/s

// LocationProvider resolves a job location (country/city) to a geographic
// coordinate.
type LocationProvider interface {
	Coordinate(loc jtypes.Location) (Coordinate, error)
}
// Coordinate is a geographic point in decimal degrees.
type Coordinate struct {
	lat  float64
	long float64
}

// Empty reports whether the coordinate is the zero value (0, 0).
// NOTE(review): (0, 0) is also a real point in the Gulf of Guinea; callers
// treat the zero value as "unset".
func (c *Coordinate) Empty() bool {
	return *c == (Coordinate{})
}
// GeoLocator is a LocationProvider backed by the embedded GeoNames dataset.
type GeoLocator struct {
	coordinates map[string]map[string]Coordinate // country -> city -> coordinate
}
// NewGeoLocator builds an in-memory country→city→coordinate index from the
// embedded GeoNames cities5000 dataset. Malformed rows are logged and
// skipped rather than failing the whole load.
func NewGeoLocator() (*GeoLocator, error) {
	geo := &GeoLocator{
		coordinates: make(map[string]map[string]Coordinate),
	}
	scanner := bufio.NewScanner(bytes.NewBufferString(cities5000))
	scanner.Buffer(make([]byte, 64*1024), 1024*1024) // increase buffer size for large lines
	for scanner.Scan() {
		fields := strings.SplitN(scanner.Text(), "\t", 20) // limit to 20 fields in each entry
		if len(fields) < 19 {
			continue
		}
		city, country := fields[1], fields[8]
		coord, err := parseCoordinate(fields)
		if err != nil {
			log.Warnw("parsing_coordinates_failure",
				"cityName", city,
				"countryCode", country,
				"error", err)
			continue
		}
		if _, exists := geo.coordinates[country]; !exists {
			geo.coordinates[country] = make(map[string]Coordinate)
		}
		geo.coordinates[country][city] = coord
	}
	if err := scanner.Err(); err != nil {
		return nil, fmt.Errorf("reading cities file: %w", err)
	}
	return geo, nil
}
// Coordinate looks up the coordinate for loc's country/city pair in the
// loaded GeoNames index.
func (geo *GeoLocator) Coordinate(loc jtypes.Location) (Coordinate, error) {
	if loc.Country == "" || loc.City == "" {
		return Coordinate{}, fmt.Errorf("no city in location")
	}
	if cities, ok := geo.coordinates[loc.Country]; ok {
		if coord, found := cities[loc.City]; found {
			return coord, nil
		}
	}
	return Coordinate{}, fmt.Errorf("unknown city")
}
// parseCoordinate extracts latitude (column 4) and longitude (column 5) from
// a GeoNames record; the caller guarantees at least 19 columns.
func parseCoordinate(fields []string) (Coordinate, error) {
	var (
		c   Coordinate
		err error
	)
	if c.lat, err = strconv.ParseFloat(fields[4], 64); err != nil {
		return Coordinate{}, fmt.Errorf("parse latitude: %w", err)
	}
	if c.long, err = strconv.ParseFloat(fields[5], 64); err != nil {
		return Coordinate{}, fmt.Errorf("parse longitude: %w", err)
	}
	return c, nil
}
// ComputeGeodesic returns the great-circle distance in kilometres between p1
// and p2, using a haversine formula to calculate the shortest path.
// An unset (zero-value) point on either side yields a distance of 0.
func ComputeGeodesic(p1, p2 Coordinate) float64 {
	const earthRadius = 6371 // km
	if p1.Empty() || p2.Empty() {
		return 0
	}
	// convert degrees to radians
	lat1 := p1.lat * math.Pi / 180
	lat2 := p2.lat * math.Pi / 180
	dLat := (p2.lat - p1.lat) * math.Pi / 180
	dLon := (p2.long - p1.long) * math.Pi / 180
	// a: square of half the chord length between the points
	a := math.Sin(dLat/2)*math.Sin(dLat/2) +
		math.Cos(lat1)*math.Cos(lat2)*
			math.Sin(dLon/2)*math.Sin(dLon/2)
	// c: angular distance in radians
	c := 2 * math.Atan2(math.Sqrt(a), math.Sqrt(1-a))
	return earthRadius * c
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package geolocation
import (
"net"
jobtypes "gitlab.com/nunet/device-management-service/dms/jobs/types"
types "gitlab.com/nunet/device-management-service/types"
)
// Geolocation describes a host's location at continent/country/city
// granularity.
type Geolocation struct {
	Continent string
	Country   string
	City      string
}

// Empty reports whether no location information is set.
// TODO: Add city check when we have GeoIP database with cities
func (g *Geolocation) Empty() bool {
	return len(g.Continent) == 0 && len(g.Country) == 0
}
// Geolocate resolves ip to a country/continent Location via the supplied
// GeoIP database.
func Geolocate(ip net.IP, geoIP types.GeoIPLocator) (jobtypes.Location, error) {
	rec, err := geoIP.Country(ip)
	if err != nil {
		return jobtypes.Location{}, err
	}
	return jobtypes.Location{
		Country:   rec.Country.IsoCode,
		Continent: rec.Continent.Code,
	}, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package geolocation
import (
"fmt"
jtypes "gitlab.com/nunet/device-management-service/dms/jobs/types"
)
// MockGeoLocator is a mock implementation of LocationProvider for testing
type MockGeoLocator struct {
	coordinates map[string]map[string]Coordinate
}

// NewMockGeoLocator creates a new mock GeoLocator with no locations
// registered.
func NewMockGeoLocator() *MockGeoLocator {
	m := MockGeoLocator{
		coordinates: map[string]map[string]Coordinate{},
	}
	return &m
}
// AddLocation registers a coordinate for the given country/city pair.
func (m *MockGeoLocator) AddLocation(country, city string, lat, long float64) {
	cities, ok := m.coordinates[country]
	if !ok {
		cities = make(map[string]Coordinate)
		m.coordinates[country] = cities
	}
	cities[city] = Coordinate{lat: lat, long: long}
}
// Coordinate returns the registered coordinate for loc, or an error when the
// country/city is missing from the request or unknown to the mock.
func (m *MockGeoLocator) Coordinate(loc jtypes.Location) (Coordinate, error) {
	if loc.Country == "" || loc.City == "" {
		return Coordinate{}, fmt.Errorf("no city in location")
	}
	if cities, ok := m.coordinates[loc.Country]; ok {
		if coord, found := cities[loc.City]; found {
			return coord, nil
		}
	}
	return Coordinate{}, fmt.Errorf("unknown city")
}
// Empty reports whether the mock has no registered locations.
func (m *MockGeoLocator) Empty() bool {
	return len(m.coordinates) == 0
}
// Copyright 2025, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions
// and limitations under the License.
package node
import (
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/spf13/afero"
)
const LedgerAliasFile = "ledger_aliases.json"
// $USERDIR/ledger_aliases.json (new – root of the user dir)
func aliasFilePath(userDir string) string {
return filepath.Join(userDir, LedgerAliasFile)
}
// loadLedgerAliases parses the on-disk alias table.
// A missing file yields an empty map (no error).
func loadLedgerAliases(fs afero.Fs, userDir string) (map[string]int, error) {
	table := make(map[string]int)
	f, err := fs.Open(aliasFilePath(userDir))
	if err != nil {
		if errors.Is(err, os.ErrNotExist) {
			return table, nil
		}
		return nil, fmt.Errorf("open alias table: %w", err)
	}
	defer f.Close()
	if err := json.NewDecoder(f).Decode(&table); err != nil {
		return nil, fmt.Errorf("decode alias table: %w", err)
	}
	return table, nil
}
// saveLedgerAliases writes aliases atomically, backing up the old file first.
// The table is encoded into a temp file in userDir and renamed over the real
// path, so readers never observe a partially written table. Temp files are
// removed on every failure path so errors do not litter the user dir.
func saveLedgerAliases(fs afero.Fs, userDir string, aliases map[string]int) error {
	dir := userDir
	if err := fs.MkdirAll(dir, 0o700); err != nil {
		return fmt.Errorf("create alias dir: %w", err)
	}
	path := aliasFilePath(userDir)
	backup := fmt.Sprintf("%s.%d", path, time.Now().Unix())
	// Backup existing file, if present
	if _, err := fs.Stat(path); err == nil {
		if err := fs.Rename(path, backup); err != nil {
			return fmt.Errorf("backup alias table: %w", err)
		}
	}
	// Write to temp file, then rename
	tmp, err := afero.TempFile(fs, dir, "aliases.*.tmp")
	if err != nil {
		return fmt.Errorf("create temp alias file: %w", err)
	}
	tmpName := tmp.Name()
	enc := json.NewEncoder(tmp)
	enc.SetIndent("", " ")
	if err := enc.Encode(aliases); err != nil {
		tmp.Close()
		_ = fs.Remove(tmpName) // don't leave a partial temp file behind
		return fmt.Errorf("encode alias table: %w", err)
	}
	// Close may surface delayed write errors; never rename a bad file.
	if err := tmp.Close(); err != nil {
		_ = fs.Remove(tmpName)
		return fmt.Errorf("close temp alias file: %w", err)
	}
	// Best-effort: set permissions to 0600. Ignore if FS doesn’t support it.
	_ = fs.Chmod(tmpName, 0o600)
	if err := fs.Rename(tmpName, path); err != nil {
		_ = fs.Remove(tmpName)
		return fmt.Errorf("rename alias file: %w", err)
	}
	return nil
}
// ResolveLedgerIndex turns <key> from “ledger:<key>” into the actual
// account index:
//
//   - "" or "ledger" → 0
//   - all-digits     → parsed integer (must be ≥0)
//   - otherwise      → look up alias in the on-disk table
func ResolveLedgerIndex(fs afero.Fs, userDir, key string) (int, error) {
	trimmed := strings.TrimSpace(key)
	if trimmed == "" || trimmed == "ledger" {
		return 0, nil
	}
	// numeric keys are raw account indices
	if idx, convErr := strconv.Atoi(trimmed); convErr == nil {
		if idx < 0 {
			return 0, fmt.Errorf("ledger index cannot be negative: %d", idx)
		}
		return idx, nil
	}
	// anything else is an alias
	table, err := loadLedgerAliases(fs, userDir)
	if err != nil {
		return 0, err
	}
	if idx, ok := table[trimmed]; ok {
		return idx, nil
	}
	return 0, fmt.Errorf("ledger alias not found: %q", trimmed)
}
// SetLedgerAlias records alias → index in the on-disk alias table. An alias
// must be non-empty, contain no ':' and not be purely numeric (numeric keys
// are interpreted as raw indices by ResolveLedgerIndex), and the index must
// be non-negative.
func SetLedgerAlias(fs afero.Fs, userDir, alias string, index int) error {
	name := strings.TrimSpace(alias)
	if name == "" {
		return fmt.Errorf("alias cannot be empty")
	}
	if strings.Contains(name, ":") {
		return fmt.Errorf("alias cannot contain ':'")
	}
	if _, convErr := strconv.Atoi(name); convErr == nil {
		return fmt.Errorf("alias cannot be purely numeric")
	}
	if index < 0 {
		return fmt.Errorf("ledger index cannot be negative")
	}
	table, err := loadLedgerAliases(fs, userDir)
	if err != nil {
		return err
	}
	table[name] = index
	return saveLedgerAliases(fs, userDir, table)
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package node
import (
"fmt"
"math/rand"
"time"
"gitlab.com/nunet/device-management-service/dms/behaviors"
"gitlab.com/nunet/device-management-service/observability"
pubsub "github.com/libp2p/go-libp2p-pubsub"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/protocol"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/lib/crypto"
"gitlab.com/nunet/device-management-service/lib/did"
"gitlab.com/nunet/device-management-service/network"
)
// includesRootProtocol reports whether rootProto (the root actor messaging
// protocol) appears in the advertised protocol list.
func includesRootProtocol(protocols []protocol.ID) bool {
	for _, proto := range protocols {
		if proto == rootProto {
			return true
		}
	}
	return false
}
// subscribe subscribes the node actor to the given broadcast topics and
// registers the peer lifecycle callbacks with the network layer.
func (n *Node) subscribe(topics ...string) error {
	for _, topic := range topics {
		if err := n.actor.Subscribe(topic, n.setupBroadcast); err != nil {
			return fmt.Errorf("error subscribing to %s: %w", topic, err)
		}
	}
	n.network.SetBroadcastAppScore(n.broadcastScore)
	// NOTE(review): n.peerIdentified is passed twice below — this looks like a
	// copy/paste slip unless Notify intentionally takes two identified-style
	// callbacks; confirm against network.Network.Notify's parameter list.
	if err := n.network.Notify(n.actor.Context(),
		n.peerPreConnected,
		n.peerConnected,
		n.peerDisconnected,
		n.peerIdentified,
		n.peerIdentified,
	); err != nil {
		return fmt.Errorf("setting up peer notifications: %w", err)
	}
	return nil
}
// setupBroadcast configures gossipsub score parameters for a broadcast topic.
func (n *Node) setupBroadcast(topic string) error {
	return n.network.SetupBroadcastTopic(topic, func(t *network.Topic) error {
		return t.SetScoreParams(&pubsub.TopicScoreParams{
			SkipAtomicValidation: true,
			TopicWeight:          1.0,
			// reward mesh residency, capped at 1 point (~1h of presence)
			TimeInMeshWeight:  0.00027, // ~1/3600
			TimeInMeshQuantum: time.Second,
			TimeInMeshCap:     1.0,
			// heavily penalize peers delivering invalid messages
			InvalidMessageDeliveriesWeight: -1000,
			InvalidMessageDeliveriesDecay:  pubsub.ScoreParameterDecay(time.Hour),
		})
	})
}
// broadcastScore computes the application-layer gossipsub score for a peer:
// 5 once hellos have been exchanged in both directions, 1 if the peer merely
// speaks the root protocol, 0 otherwise (including unknown peers).
func (n *Node) broadcastScore(p peer.ID) float64 {
	n.lock.Lock()
	defer n.lock.Unlock()
	state, known := n.peers[p]
	switch {
	case !known:
		return 0
	case state.helloIn && state.helloOut:
		return 5
	case state.hasRoot:
		return 1
	default:
		return 0
	}
}
// peerConnected tracks a new connection for p, creating its state record on
// first sight.
func (n *Node) peerConnected(p peer.ID) {
	logConn.Debugf("peer connected: %s", p)
	n.lock.Lock()
	defer n.lock.Unlock()
	state := n.peers[p]
	if state == nil {
		state = &peerState{}
		n.peers[p] = state
	}
	state.numConnections++
}
// peerPreConnected initializes (replacing any prior record) the state for p
// before the connection completes; if the peer advertises the root protocol
// the hello handshake is started immediately.
func (n *Node) peerPreConnected(p peer.ID, protocols []protocol.ID, numConnections int) {
	logConn.Debugf("peer preconnected: %s %s (%d)", p, protocols, numConnections)
	n.lock.Lock()
	defer n.lock.Unlock()
	state := &peerState{numConnections: numConnections}
	n.peers[p] = state
	if !includesRootProtocol(protocols) {
		return
	}
	state.hasRoot = true
	state.helloPending = true
	state.helloAttempts = 1
	go n.sayHello(p)
}
// peerIdentified updates p's state once identify completes; if the peer
// speaks the root protocol and no hello has succeeded or is in flight, a new
// handshake attempt is started.
func (n *Node) peerIdentified(p peer.ID, protocols []protocol.ID) {
	logConn.Debugf("peer identified: %s %s", p, protocols)
	n.lock.Lock()
	defer n.lock.Unlock()
	state := n.peers[p]
	if state == nil {
		state = &peerState{}
		n.peers[p] = state
	}
	if !includesRootProtocol(protocols) {
		return
	}
	state.hasRoot = true
	if state.helloOut || state.helloPending {
		return
	}
	state.helloPending = true
	state.helloAttempts++
	go n.sayHello(p)
}
// peerDisconnected decrements p's connection count and drops its state once
// no connections remain.
func (n *Node) peerDisconnected(p peer.ID) {
	logConn.Debugf("peer disconnected: %s", p)
	n.lock.Lock()
	defer n.lock.Unlock()
	state, tracked := n.peers[p]
	if !tracked {
		return
	}
	if state.numConnections--; state.numConnections <= 0 {
		delete(n.peers, p)
	}
}
// sayHello performs the outbound hello handshake with peer p: it derives the
// peer's actor handle from its libp2p public key, waits a randomized delay to
// spread handshakes out, then invokes the public hello behavior and records
// the outcome in n.peers. On invoke failure or reply timeout it retries by
// re-spawning itself, up to helloAttempts times.
func (n *Node) sayHello(p peer.ID) {
	pubk, err := p.ExtractPublicKey()
	if err != nil {
		log.Debugf("extract public key: %s", err)
		return
	}
	if !crypto.AllowedKey(int(pubk.Type())) {
		log.Debugf("unexpected key type: %d", pubk.Type())
		return
	}
	actorID, err := crypto.IDFromPublicKey(pubk)
	if err != nil {
		log.Debugf("extract actor ID: %s", err)
		return
	}
	actorDID := did.FromPublicKey(pubk)
	handle := actor.Handle{
		ID:  actorID,
		DID: actorDID,
		Address: actor.Address{
			HostID:       p.String(),
			InboxAddress: "root",
		},
	}
	// jitter the handshake to avoid a thundering herd of hellos on startup
	wait := helloMinDelay + time.Duration(rand.Int63n(int64(helloMaxDelay-helloMinDelay)))
	time.Sleep(wait)
	n.lock.Lock()
	st, ok := n.peers[p]
	if !ok {
		// peer state vanished while we slept; nothing to do
		n.lock.Unlock()
		return
	}
	if !n.network.PeerConnected(p) {
		st.helloPending = false
		n.lock.Unlock()
		return
	}
	n.lock.Unlock()
	msg, err := actor.Message(
		n.actor.Handle(),
		handle,
		behaviors.PublicHelloBehavior,
		nil,
		actor.WithMessageTimeout(helloTimeout),
	)
	if err != nil {
		log.Debugf("construct hello message: %s", err)
		return
	}
	logConn.Debugf("saying hello to %s", handle.Address.HostID)
	replyCh, err := n.actor.Invoke(msg)
	if err != nil {
		// invoke failed: retry (bounded) or give up and clear the pending flag
		n.lock.Lock()
		if st, ok = n.peers[p]; ok {
			if st.helloAttempts < helloAttempts {
				st.helloAttempts++
				go n.sayHello(p)
			} else {
				st.helloPending = false
			}
		}
		n.lock.Unlock()
		logConn.Debugf("invoking hello: %s", err)
		return
	}
	select {
	case reply := <-replyCh:
		reply.Discard()
		n.lock.Lock()
		if st, ok = n.peers[p]; ok {
			st.helloOut = true
			st.helloPending = false
		} else if n.network.PeerConnected(p) {
			// race with connected notification
			st = &peerState{helloOut: true}
			n.peers[p] = st
		}
		n.lock.Unlock()
		log.Debugw("got hello response from", "labels", string(observability.LabelNode), "hostID", handle.Address.HostID)
	case <-time.After(time.Until(msg.Expiry())):
		// no reply before the message expired: retry (bounded) or give up
		n.lock.Lock()
		if st, ok = n.peers[p]; ok {
			if st.helloAttempts < helloAttempts {
				st.helloAttempts++
				go n.sayHello(p)
			} else {
				st.helloPending = false
			}
		}
		n.lock.Unlock()
		logConn.Debugw("hello timeout", "labels", string(observability.LabelNode), "hostID", handle.Address.HostID)
	}
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package node
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"net"
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
lcrypto "github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/spf13/afero"
gatewastore "gitlab.com/nunet/device-management-service/gateway/store"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/metric"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/dms/behaviors"
"gitlab.com/nunet/device-management-service/dms/jobs"
jobtypes "gitlab.com/nunet/device-management-service/dms/jobs/types"
"gitlab.com/nunet/device-management-service/dms/node/geolocation"
"gitlab.com/nunet/device-management-service/dms/onboarding"
"gitlab.com/nunet/device-management-service/dms/orchestrator"
"gitlab.com/nunet/device-management-service/gateway/provider"
"gitlab.com/nunet/device-management-service/internal"
bt "gitlab.com/nunet/device-management-service/internal/background_tasks"
"gitlab.com/nunet/device-management-service/internal/config"
"gitlab.com/nunet/device-management-service/lib/did"
"gitlab.com/nunet/device-management-service/lib/ucan"
"gitlab.com/nunet/device-management-service/network"
"gitlab.com/nunet/device-management-service/observability"
"gitlab.com/nunet/device-management-service/storage"
"gitlab.com/nunet/device-management-service/storage/volume/glusterfs/controller"
"gitlab.com/nunet/device-management-service/tokenomics"
"gitlab.com/nunet/device-management-service/tokenomics/contracts"
"gitlab.com/nunet/device-management-service/tokenomics/contracts/processors"
"gitlab.com/nunet/device-management-service/tokenomics/eventhandler"
"gitlab.com/nunet/device-management-service/tokenomics/pricing"
"gitlab.com/nunet/device-management-service/tokenomics/store"
"gitlab.com/nunet/device-management-service/tokenomics/store/payment"
payment_quote "gitlab.com/nunet/device-management-service/tokenomics/store/payment_quote"
"gitlab.com/nunet/device-management-service/tokenomics/store/transaction"
"gitlab.com/nunet/device-management-service/tokenomics/store/usage"
"gitlab.com/nunet/device-management-service/types"
)
const (
	// hello handshake tuning: randomized initial delay window, per-invoke
	// timeout and the maximum number of attempts per peer
	helloMinDelay = 10 * time.Second
	helloMaxDelay = 20 * time.Second
	helloTimeout  = 3 * time.Second
	helloAttempts = 3

	clearCommitsFrequency    = 60 * time.Second
	ensembleMonitorFrequency = 10 * time.Second
	grantAllocationCapsFreq  = 1 * time.Hour

	// rootProto is the libp2p protocol ID of the root actor inbox
	rootProto = "actor/root/messages/0.0.1"

	// TODO: We should consider a restoration deadline down the line (see code at restoreDeployments)
	// RestoreDeadlineCommitting = 1 * time.Minute
	// RestoreDeadlineProvisioning = 1 * time.Minute
	// RestoreDeadlineRunning = 5 * time.Minute

	bidStateGCInterval             = time.Minute
	provisionedResourcesGCInterval = time.Minute

	// contract event handler config
	eventHandlerWorkers   = 2
	eventHandlerQueueSize = 200
	eventHandlerBaseDelay = 5 * time.Second
	eventHandlerMaxDelay  = 15 * time.Second
)
// TODO issue #1154 - better handle transient allocations
// temporary subnet status handling - 1 = active , 0 = destroyed
var (
	subnetStatusMx sync.Mutex
	subnetStatus   map[string]int
)

// peerState tracks connection count and hello-handshake progress for a
// single peer.
type peerState struct {
	numConnections                  int
	hasRoot                         bool
	helloIn, helloOut, helloPending bool
	helloAttempts                   int
}

// bidState remembers an outstanding bid (and its nonce) until it expires.
type bidState struct {
	expire  time.Time
	nonce   uint64
	request jobtypes.BidRequest
}

// executorMetadata pairs an executor instance with its allocation execution
// type.
type executorMetadata struct {
	executor      types.Executor
	executionType jobs.AllocationExecutor
}

// PortConfig is the port range the node may hand out to allocations.
type PortConfig struct {
	AvailableRangeFrom int
	AvailableRangeTo   int
}
// Node is the structure that holds the node's dependencies: the root actor,
// network, resource management, deployment registry and tokenomics state for
// a running DMS instance.
type Node struct {
	lock    sync.RWMutex
	rootCap ucan.CapabilityContext
	// dms modules
	allocator       Allocator
	actor           actor.Actor
	scheduler       *bt.Scheduler
	network         network.Network
	resourceManager types.ResourceManager
	hardware        types.HardwareManager
	onboarding      *onboarding.Onboarding
	executors       map[string]executorMetadata
	// in-memory state
	hostID       string
	geoIP        types.GeoIPLocator
	hostLocation geolocation.Geolocation
	publicIP     net.IP
	peers        map[peer.ID]*peerState
	bids         map[string]*bidState
	answeredBids map[string][]uint64
	running      atomic.Bool
	// volume controller
	volumeController controller.GlusterControllerInterface
	volumeOwners     map[string]string // mapping volume name with did
	// utils
	orchestratorRegistry orchestrator.Registry
	dmsConfig            config.Config
	fs                   afero.Afero
	ctx                  context.Context
	cancel               func()
	// contract store
	contractStore    *store.Store
	paymentStore     *payment.Store
	usageStore       *usage.Store
	contractActors   map[string]*tokenomics.ContractActor // DID URI -> Actor (optimized lookup)
	billingScheduler *tokenomics.ContractBillingScheduler
	transactionStore *transaction.Store
	// contract event handler
	contractEventHandler *eventhandler.EventHandler
	// registry of server-side providers (gateway)
	serverProviderRegistry *provider.Registry
	// payment processor
	paymentProcessor PaymentProcessor
	gatewayStore     *gatewastore.Store
	// payment quote store and price converter
	paymentQuoteStore *payment_quote.Store
	priceConverter    *pricing.PriceConverter
}
// createActor builds the node's root BasicActor from the given security
// context, rate limiter, host ID and inbox address, supervised by supervisor.
func createActor(
	sctx *actor.BasicSecurityContext,
	limiter actor.RateLimiter,
	hostID, inboxAddress string,
	net network.Network,
	supervisor actor.Handle,
) (*actor.BasicActor, error) {
	handle := actor.Handle{
		ID:  sctx.ID(),
		DID: sctx.DID(),
		Address: actor.Address{
			HostID:       hostID,
			InboxAddress: inboxAddress,
		},
	}
	created, err := actor.New(supervisor, net, sctx, limiter, actor.BasicActorParams{}, handle)
	if err != nil {
		return nil, fmt.Errorf("create actor: %w", err)
	}
	return created, nil
}
// New creates a new node, attaches an actor to the node.
//
// It validates all required dependencies, derives the root actor's security
// context from the root capability's DID, wires up the payment/pricing
// machinery and the contract event handler, and registers all DMS behaviors
// on the root actor. The billing scheduler is created here but only started
// in Node.Start().
func New(cfg config.Config, fs afero.Afero,
	onboarding *onboarding.Onboarding,
	rootCap ucan.CapabilityContext,
	hostID string, net network.Network,
	resourceManager types.ResourceManager,
	scheduler *bt.Scheduler,
	hardware types.HardwareManager,
	geoIP types.GeoIPLocator, hostLocation geolocation.Geolocation,
	portConfig PortConfig, vt *storage.VolumeTracker,
	volumeController controller.GlusterControllerInterface,
	contractStore *store.Store,
	paymentStore *payment.Store,
	usageStore *usage.Store,
	transactionStore *transaction.Store,
	deploymentStore orchestrator.DeploymentStore,
	providerRegistry *provider.Registry,
	gatewayStore *gatewastore.Store,
	paymentQuoteStore *payment_quote.Store,
) (*Node, error) {
	// fail fast on missing required dependencies
	if onboarding == nil {
		return nil, errors.New("onboarding is nil")
	}
	if rootCap == nil {
		return nil, errors.New("root capability context is nil")
	}
	if hostID == "" {
		return nil, errors.New("hostID is empty")
	}
	if net == nil {
		return nil, errors.New("network is nil")
	}
	if resourceManager == nil {
		return nil, errors.New("resource manager is nil")
	}
	if scheduler == nil {
		return nil, errors.New("scheduler is nil")
	}
	if geoIP == nil {
		return nil, errors.New("geoIP is nil")
	}
	if contractStore == nil {
		return nil, errors.New("contract store is nil")
	}
	if paymentStore == nil {
		return nil, errors.New("payment store is nil")
	}
	if usageStore == nil {
		return nil, errors.New("usage store is nil")
	}
	if transactionStore == nil {
		return nil, errors.New("transaction store is nil")
	}
	if deploymentStore == nil {
		return nil, errors.New("deployment store is nil")
	}
	subnetStatus = make(map[string]int)
	// derive the root actor's keys from the root capability's DID
	rootDID := rootCap.DID()
	rootTrust := rootCap.Trust()
	anchor, err := rootTrust.GetAnchor(rootDID)
	if err != nil {
		return nil, fmt.Errorf("get root DID anchor: %w", err)
	}
	pubk := anchor.PublicKey()
	provider, err := rootTrust.GetProvider(rootDID)
	if err != nil {
		return nil, fmt.Errorf("get root DID provider: %w", err)
	}
	privk, err := provider.PrivateKey()
	if err != nil {
		return nil, fmt.Errorf("get root private key: %w", err)
	}
	rootSec, err := actor.NewBasicSecurityContext(pubk, privk, rootCap)
	if err != nil {
		return nil, fmt.Errorf("create security context: %w", err)
	}
	nodeActor, err := createActor(rootSec, actor.NewRateLimiter(actor.DefaultRateLimiterConfig()), hostID, "root", net, actor.Handle{})
	if err != nil {
		return nil, fmt.Errorf("create node actor: %w", err)
	}
	allocator := newAllocator(vt, newPortAllocator(portConfig), resourceManager, hardware, net, fs, cfg.WorkDir, hostID, contractStore)
	ctx, cancel := context.WithCancel(context.Background())
	n := &Node{
		allocator:              allocator,
		hostID:                 hostID,
		network:                net,
		bids:                   make(map[string]*bidState),
		answeredBids:           make(map[string][]uint64),
		peers:                  make(map[peer.ID]*peerState),
		resourceManager:        resourceManager,
		hardware:               hardware,
		actor:                  nodeActor,
		rootCap:                rootCap,
		scheduler:              scheduler,
		onboarding:             onboarding,
		executors:              make(map[string]executorMetadata),
		ctx:                    ctx,
		cancel:                 cancel,
		orchestratorRegistry:   orchestrator.NewRegistry(deploymentStore),
		geoIP:                  geoIP,
		hostLocation:           hostLocation,
		dmsConfig:              cfg,
		fs:                     fs,
		volumeController:       volumeController,
		volumeOwners:           make(map[string]string),
		contractStore:          contractStore,
		paymentStore:           paymentStore,
		usageStore:             usageStore,
		transactionStore:       transactionStore,
		contractActors:         make(map[string]*tokenomics.ContractActor),
		serverProviderRegistry: providerRegistry,
		gatewayStore:           gatewayStore,
	}
	// Create payment processor with invokeBehaviour function
	invokeBehaviourFunc := func(destination actor.Handle, behavior string, req interface{}, timeout time.Duration) (actor.Envelope, error) {
		msg, err := actor.Message(
			nodeActor.Handle(),
			destination,
			behavior,
			req,
			actor.WithMessageExpiry(actor.MakeExpiry(timeout)),
		)
		if err != nil {
			return actor.Envelope{}, fmt.Errorf("failed to create message: %w", err)
		}
		replyCh, err := nodeActor.Invoke(msg)
		if err != nil {
			return actor.Envelope{}, fmt.Errorf("failed to invoke message: %w", err)
		}
		// one-shot timeout: Timer is the idiomatic tool here (Ticker is for
		// repeated firings)
		timer := time.NewTimer(timeout)
		defer timer.Stop()
		select {
		case reply := <-replyCh:
			// NOTE(review): the deferred Discard runs before the caller uses
			// the returned envelope — confirm Discard only releases transport
			// resources and does not invalidate the payload.
			defer reply.Discard()
			return reply, nil
		case <-timer.C:
			return actor.Envelope{}, errors.New("failed to receive reply due to timeout")
		}
	}
	// Initialize price oracle and converter
	var priceConverter *pricing.PriceConverter
	if cfg.CoinMarketCap.APIKey != "" {
		// Parse cache TTL (default to 5m if empty/invalid)
		cacheTTL := 5 * time.Minute
		if cfg.CoinMarketCap.CacheTTL != "" {
			if parsedTTL, err := time.ParseDuration(cfg.CoinMarketCap.CacheTTL); err == nil {
				cacheTTL = parsedTTL
			}
		}
		// Set defaults for baseURL and endpointPath if empty
		baseURL := cfg.CoinMarketCap.BaseURL
		if baseURL == "" {
			baseURL = "https://pro-api.coinmarketcap.com/v1"
		}
		endpointPath := cfg.CoinMarketCap.EndpointPath
		if endpointPath == "" {
			endpointPath = "/tools/price-conversion"
		}
		oracle := pricing.NewCoinMarketCapOracle(cfg.CoinMarketCap.APIKey, baseURL, endpointPath, cacheTTL)
		priceConverter = pricing.NewPriceConverter(oracle)
	}
	n.paymentProcessor = NewPaymentProcessor(paymentStore, net, nodeActor.Handle(), invokeBehaviourFunc)
	n.paymentQuoteStore = paymentQuoteStore
	n.priceConverter = priceConverter
	// set up the flight recorder
	observability.FlightrecInit()
	// setup contract event handler
	n.contractEventHandler = eventhandler.New(ctx, eventHandlerWorkers, eventHandlerQueueSize, eventHandlerBaseDelay, eventHandlerMaxDelay, n.handleContractEvents)
	// Initialize payment model processors
	// This must be called before using GetPaymentModelProcessor
	processors.InitPaymentModelProcessors(usageStore)
	if err := n.initSupportedExecutors(ctx); err != nil {
		cancel()
		return nil, fmt.Errorf("new executor: %w", err)
	}
	dmsBehaviors := n.getDMSBehaviors()
	for behavior, handler := range dmsBehaviors {
		if err := nodeActor.AddBehavior(behavior, handler.fn, handler.opts...); err != nil {
			// release the node context on failure, as every other error path
			// after ctx creation does (previously leaked)
			cancel()
			return nil, fmt.Errorf("adding %s behavior: %w", behavior, err)
		}
	}
	// Create billing scheduler with billing function
	// Note: billingFunc closure references n, so it must be created after n
	billingFunc := func(contractDID did.DID) error {
		return n.executeBillingForContract(contractDID)
	}
	billingScheduler, err := tokenomics.NewContractBillingScheduler(
		contractStore,
		usageStore,
		billingFunc,
	)
	if err != nil {
		cancel()
		return nil, fmt.Errorf("failed to create billing scheduler: %w", err)
	}
	// Set billing scheduler on node
	n.billingScheduler = billingScheduler
	// NOTE: Do NOT start billing scheduler here
	// It will be started in Node.Start() method
	return n, nil
}
// saveDeployments persists every active orchestrator deployment, logging and
// collecting the IDs of those that fail and reporting them as one error.
func (n *Node) saveDeployments() error {
	n.lock.Lock()
	defer n.lock.Unlock()
	var failedIDs []string
	for id, orch := range n.orchestratorRegistry.Orchestrators() {
		err := n.saveDeployment(orch)
		if err == nil {
			continue
		}
		log.Errorw("error saving active deployment",
			"labels", string(observability.LabelDeployment),
			"deploymentID", id,
			"error", err)
		failedIDs = append(failedIDs, id)
	}
	if len(failedIDs) > 0 {
		return fmt.Errorf("save deployments: %v", failedIDs)
	}
	return nil
}
// restoreDeployments reloads persisted deployments from the deployment store
// after a restart: for each restorable, still-valid deployment it recreates
// the orchestrator's child actor from its saved private key, re-registers
// subnet behaviors when needed, and hands the snapshot back to the registry.
// Individual failures are collected and reported as a single error.
func (n *Node) restoreDeployments() error {
	// Get all deployments from store (source of truth)
	allDeployments, err := n.orchestratorRegistry.GetAllDeployments()
	if err != nil {
		return fmt.Errorf("failed to get all deployments: %w", err)
	}
	var failedToRestore []string
	for _, d := range allDeployments {
		// Only restore deployments in restorable states
		if !isRestorableStatus(d.Status) {
			log.Infof("deployment %s has non-restorable status %s skipping", d.OrchestratorID, d.Status.String())
			continue
		}
		// Check if deployment is still valid (not based on time)
		if !isDeploymentStillValid(d) {
			log.Warnf("deployment %s is no longer valid; skipping", d.OrchestratorID)
			continue
		}
		// recreate actor given priv key
		pvkey, err := lcrypto.UnmarshalPrivateKey(d.PrivKey)
		if err != nil {
			log.Errorf("unmarshal orchestrator actor private key for %s: %v", d.OrchestratorID, err)
			failedToRestore = append(failedToRestore, d.OrchestratorID)
			continue
		}
		childActor, err := n.actor.CreateChild(
			d.OrchestratorID,
			n.actor.Handle(),
			actor.WithPrivKey(pvkey),
		)
		if err != nil {
			log.Errorf("restore actor creation error for %s: %v", d.OrchestratorID, err)
			failedToRestore = append(failedToRestore, d.OrchestratorID)
			continue
		}
		// NOTE(review): unlike the other failure paths in this loop, a Start
		// error is not appended to failedToRestore — confirm whether that is
		// intentional.
		if err := childActor.Start(); err != nil {
			log.Errorf("start orchestrator actor for %s: %v", d.OrchestratorID, err)
			continue
		}
		if d.Manifest.Subnet.Join {
			// NOTE(review): this returns immediately, aborting restoration of
			// all remaining deployments, while other errors only skip the
			// current one — confirm this asymmetry is intended.
			if err := n.addOrchestratorBehaviors(childActor, d.OrchestratorID); err != nil {
				return fmt.Errorf("adding behaviors for orch to join subnet: %w", err)
			}
		}
		orchestrator, err := n.
			orchestratorRegistry.
			RestoreDeployment(
				n.ctx,
				n.fs,
				childActor,
				d.OrchestratorID,
				d.Cfg,
				d.Manifest,
				d.Status,
				d.DeploymentSnapshot,
				d.SubnetManifest,
				types.NewDefaultAllocationIDGenerator(),
			)
		if err != nil {
			log.Errorf("restoring deployment %s failed: %v", d.OrchestratorID, err)
			failedToRestore = append(failedToRestore, d.OrchestratorID)
			continue
		}
		log.Infow("restored deployment",
			"labels", string(observability.LabelDeployment),
			"deploymentID", orchestrator.ID())
	}
	if len(failedToRestore) > 0 {
		return fmt.Errorf("failed to restore the following deployment(s): %v", failedToRestore)
	}
	return nil
}
// getDMSBehaviors returns the full registry of node-level behaviors, keyed by
// behavior name. Each entry pairs a handler function with optional
// BehaviorOptions; broadcast behaviors additionally bind a pubsub topic via
// actor.WithBehaviorTopic. The caller is expected to register every entry on
// the node actor.
func (n *Node) getDMSBehaviors() map[string]struct {
	fn   func(actor.Envelope)
	opts []actor.BehaviorOption
} {
	dmsBehaviors := map[string]struct {
		fn   func(actor.Envelope)
		opts []actor.BehaviorOption
	}{
		// hello / status
		behaviors.PublicHelloBehavior: {
			fn: n.publicHelloBehavior,
		},
		behaviors.PublicStatusBehavior: {
			fn: n.publicStatusBehavior,
		},
		behaviors.BroadcastHelloBehavior: {
			fn: n.handleBroadcastHelloBehavior,
			opts: []actor.BehaviorOption{
				actor.WithBehaviorTopic(behaviors.BroadcastHelloTopic),
			},
		},
		// peer management and diagnostics
		behaviors.PeersListBehavior: {
			fn: n.handlePeersList,
		},
		behaviors.PeerAddrInfoBehavior: {
			fn: n.handlePeerAddrInfo,
		},
		behaviors.PeerPingBehavior: {
			fn: n.handlePeerPing,
		},
		behaviors.PeerDHTBehavior: {
			fn: n.handlePeerDHT,
		},
		behaviors.PeerConnectBehavior: {
			fn: n.handlePeerConnect,
		},
		behaviors.PeerScoreBehavior: {
			fn: n.handlePeerScore,
		},
		behaviors.DebugFlightrecBehavior: {
			fn: n.handleFlightrec,
		},
		// onboarding lifecycle
		behaviors.OnboardBehavior: {
			fn: n.handleOnboard,
		},
		behaviors.OffboardBehavior: {
			fn: n.handleOffboard,
		},
		behaviors.OnboardStatusBehavior: {
			fn: n.handleOnboardStatus,
		},
		// deployment lifecycle
		behaviors.NewDeploymentBehavior: {
			fn: n.handleNewDeployment,
		},
		behaviors.DeploymentUpdateBehavior: {
			fn: n.handleDeploymentUpdate,
		},
		behaviors.DeploymentListBehavior: {
			fn: n.handleDeploymentList,
		},
		behaviors.DeploymentLogsBehavior: {
			fn: n.handleDeploymentLogs,
		},
		behaviors.DeploymentStatusBehavior: {
			fn: n.handleDeploymentStatus,
		},
		behaviors.DeploymentManifestBehavior: {
			fn: n.handleDeploymentManifest,
		},
		behaviors.DeploymentInfoBehavior: {
			fn: n.handleDeploymentInfo,
		},
		behaviors.DeploymentShutdownBehavior: {
			fn: n.handleDeploymentShutdown,
		},
		behaviors.DeploymentPruneBehavior: {
			fn: n.handleDeploymentPrune,
		},
		behaviors.DeploymentDeleteBehavior: {
			fn: n.handleDeploymentDelete,
		},
		behaviors.AllocationsListBehavior: {
			fn: n.handleAllocationsList,
		},
		behaviors.VerifyEdgeConstraintBehavior: {
			fn: n.handleVerifyEdgeConstraint,
		},
		// bidding (broadcast topic)
		behaviors.BidRequestBehavior: {
			fn: n.handleBidRequest,
			opts: []actor.BehaviorOption{
				actor.WithBehaviorTopic(behaviors.BidRequestTopic),
			},
		},
		behaviors.DeploymentRevertBehavior: {
			fn: n.handleDeploymentRevert,
		},
		// subnet management (static, non-ensemble-scoped variants)
		behaviors.SubnetCreateBehavior.Static: {
			fn: n.handleSubnetCreate,
		},
		behaviors.SubnetDestroyBehavior.Static: {
			fn: n.handleSubnetDestroy,
		},
		// resource accounting
		behaviors.ResourcesAllocatedBehavior: {
			fn: n.handleAllocatedResources,
		},
		behaviors.ResourcesFreeBehavior: {
			fn: n.handleFreeResources,
		},
		behaviors.ResourcesOnboardedBehavior: {
			fn: n.handleOnboardedResources,
		},
		behaviors.LoggerConfigBehavior: {
			fn: n.handleLoggerConfig,
		},
		behaviors.HardwareUsageBehavior: {
			fn: n.handleHardwareUsage,
		},
		behaviors.HardwareSpecBehavior: {
			fn: n.handleHardwareSpec,
		},
		// capability (UCAN) management
		behaviors.CapListBehavior: {
			fn: n.handleCapList,
		},
		behaviors.ProvideCapAnchorBehavior: {
			fn: n.handleProvideCapAnchor,
		},
		behaviors.RequireCapAnchorBehavior: {
			fn: n.handleRequireCapAnchor,
		},
		behaviors.RevokeCapAnchorBehavior: {
			fn: n.handleRevokeCapAnchor,
		},
		// broadcast revocations share the direct revoke handler
		behaviors.BroadcastRevokeCapBehavior: {
			fn: n.handleRevokeCapAnchor,
			opts: []actor.BehaviorOption{
				actor.WithBehaviorTopic(behaviors.BroadcastRevocationTopic),
			},
		},
		behaviors.AllocationDeploymentBehavior: {
			fn: n.handleAllocationDeployment,
		},
		behaviors.CommitDeploymentBehavior: {
			fn: n.handleCommitDeployment,
		},
		// volumes
		behaviors.VolumeCreateBehavior: {
			fn: n.handleCreateVolume,
		},
		behaviors.VolumeDeleteBehavior: {
			fn: n.handleDeleteVolume,
		},
		behaviors.VolumeStartBehavior: {
			fn: n.handleStartVolume,
		},
		// status discovery (direct and broadcast, same handler)
		behaviors.StatusDiscoveryBehavior: {
			fn: n.handleStatusDiscoveryBehavior,
		},
		behaviors.BroadcastStatusDiscoveryBehavior: {
			fn: n.handleStatusDiscoveryBehavior,
			opts: []actor.BehaviorOption{
				actor.WithBehaviorTopic(behaviors.BroadcastStatusDiscoveryTopic),
			},
		},
		// solution enabler
		behaviors.ContractCreateBehavior: {
			fn: n.handleNewContract,
		},
		behaviors.ContractUsagesCalculateBehavior: {
			fn: n.handleContractUsagesCalculate,
		},
		// listener used by service provider and compute provider
		behaviors.ContractProposeBehavior: {
			fn: n.handleContractPropose,
		},
		// used by compute provider to accept a contract
		behaviors.ContractApproveLocalBehavior: {
			fn: n.handleContractApprovalLocal,
		},
		// used by compute provider to list incoming contracts
		behaviors.ContractListBehavior: {
			fn: n.handleListIncomingContracts,
		},
		// used by payment validator
		behaviors.ContractUsageBehavior: {
			fn: n.handleIncomingContractUsage,
		},
		// used by SP and CP
		behaviors.ContractTransactionBehavior: {
			fn: n.handleIncomingTransaction,
		},
		behaviors.ContractPaymentStatusBehavior: {
			fn: n.handlePaymentStatus,
		},
		// used by payment validator to validate payment
		behaviors.ContractPaymentValidationRequestBehavior: {
			fn: n.handleContractPaymentValidationRequestFromContractHost,
		},
		behaviors.ContractListLocalTransactionsBehavior: {
			fn: n.handleListLocalTransactions,
		},
		behaviors.ContractChainVerificationBehavior: {
			fn: n.handleContractChainVerification,
		},
		behaviors.ContractInfoBehavior: {
			fn: n.handleContractInfo,
		},
		behaviors.ContractConfirmLocalTransactionBehavior: {
			fn: n.handleConfirmLocalTransaction,
		},
		behaviors.ContractGetPaymentQuoteBehavior: {
			fn: n.handleGetPaymentQuote,
		},
		behaviors.ContractValidatePaymentQuoteBehavior: {
			fn: n.handleValidatePaymentQuote,
		},
		behaviors.ContractCancelPaymentQuoteBehavior: {
			fn: n.handleCancelPaymentQuote,
		},
		// gateway
		behaviors.PromiseBidToBidBehavior: {
			fn: n.handlePromiseBid,
		},
		// provisioned server
		behaviors.PromiseBidSigningBehavior: {
			fn: n.handleBidSigning,
		},
	}
	return dmsBehaviors
}
// addOrchestratorBehaviors registers the ensemble-scoped subnet behaviors
// (create, destroy, join) on the node actor and grants the orchestrator's
// child actor the ensemble capability namespace for one hour.
func (n *Node) addOrchestratorBehaviors(actr actor.Actor, ensembleID string) error {
	// register expands an ensemble-scoped behavior name from its template and
	// attaches the handler to the node actor.
	register := func(template string, fn func(actor.Envelope)) error {
		name := fmt.Sprintf(template, ensembleID)
		if err := n.actor.AddBehavior(name, fn); err != nil {
			return fmt.Errorf("adding %s behavior: %w", name, err)
		}
		return nil
	}
	if err := register(behaviors.SubnetCreateBehavior.DynamicTemplate, n.handleSubnetCreate); err != nil {
		return err
	}
	if err := register(behaviors.SubnetDestroyBehavior.DynamicTemplate, n.handleSubnetDestroy); err != nil {
		return err
	}
	if err := register(behaviors.SubnetJoinBehavior.DynamicTemplate, n.handleSubnetJoin); err != nil {
		return err
	}
	grantedCaps := []ucan.Capability{
		ucan.Capability(fmt.Sprintf(behaviors.EnsembleNamespace, ensembleID)),
	}
	if err := n.actor.Security().Grant(actr.Handle().DID, n.actor.Handle().DID, grantedCaps, time.Hour); err != nil {
		return fmt.Errorf("granting subnet caps to self orchestrator: %w", err)
	}
	return nil
}
// gcBidState periodically garbage-collects expired bid state until the node
// context is cancelled. Intended to run as a goroutine.
func (n *Node) gcBidState() {
	tick := time.NewTicker(bidStateGCInterval)
	defer tick.Stop()
	for {
		select {
		case <-n.ctx.Done():
			return
		case <-tick.C:
			n.doGCBidState()
		}
	}
}
// shutdownUnusedProvisionedResources periodically reclaims gateway-provisioned
// VMs that no longer host active allocations, stopping when the node context
// is cancelled. Intended to run as a goroutine on compute gateways.
func (n *Node) shutdownUnusedProvisionedResources() {
	tick := time.NewTicker(provisionedResourcesGCInterval)
	defer tick.Stop()
	for {
		select {
		case <-n.ctx.Done():
			return
		case <-tick.C:
			n.deleteProvionedResource()
		}
	}
}
// deleteProvionedResource sweeps all gateway-provisioned resources and tears
// down any VM that no longer has live allocations. For each record it asks
// the provisioned DMS for its allocation list; if no allocation is pending,
// running, or stopped, the server is deleted via its provider and the record
// removed from the local store. Individual failures are logged and skipped so
// one bad record does not block the rest of the sweep.
func (n *Node) deleteProvionedResource() {
	all, err := n.gatewayStore.All()
	if err != nil {
		log.Errorf("failed to get provisioned resources from store: %v", err)
		return
	}
	for _, v := range all {
		destination, err := actor.HandleFromPeerID(v.ProvisionedVMPeerID)
		if err != nil {
			log.Errorf("failed to get handle of provisioned dms")
			continue
		}
		// Query the remote DMS for its current allocations.
		envelope, err := n.invokeBehaviour(destination, behaviors.AllocationsListBehavior, nil, invokeMessageTimeout)
		if envelope.Message == nil || err != nil {
			log.Errorf("failed to get allocation list from new dms: %v", err)
			continue
		}
		var allocs AllocationsListResponse
		err = json.Unmarshal(envelope.Message, &allocs)
		if err != nil {
			log.Errorf("failed to unmarshal allocation list from new dms: %v", err)
			continue
		}
		// Reclaim the VM only when no allocation is in a live state.
		killVM := true
		for _, alloc := range allocs.Allocations {
			if alloc.Status == "pending" || alloc.Status == "running" || alloc.Status == "stopped" {
				killVM = false
				break
			}
		}
		if killVM {
			serverProvider, err := n.serverProviderRegistry.Get(v.ProviderName)
			if err != nil {
				log.Errorf("failed to get server provider for deleting unused resource: %v", err)
				continue
			}
			err = serverProvider.DeleteServer(context.Background(), v.Resource.ID)
			if err != nil {
				log.Errorf("failed to delete provisioned resource: %v", err)
				continue
			}
			// Only drop the store record once the provider-side delete succeeded.
			err = n.gatewayStore.Delete(v.Resource.ID)
			if err != nil {
				log.Errorf("failed to delete provisioned resource from local store: %v", err)
			}
		}
	}
}
// doGCBidState removes all expired entries from the bid tracking maps,
// holding the node lock for the duration of the sweep.
func (n *Node) doGCBidState() {
	cutoff := time.Now()
	n.lock.Lock()
	defer n.lock.Unlock()
	for id, state := range n.bids {
		if !state.expire.Before(cutoff) {
			continue
		}
		delete(n.bids, id)
		delete(n.answeredBids, id)
	}
}
// geolocate resolves the host's public IP and geolocation, caches both on the
// node (under n.lock), and starts an hourly goroutine that emits the
// node-location metric until the node context is cancelled.
//
// Fix: log.Errorw is a structured (sugared) logger that takes a message plus
// key/value pairs; the previous calls passed printf-style format strings with
// a bare error argument, which produced malformed structured logs. They now
// pass the error under an explicit "error" key.
func (n *Node) geolocate() {
	log.Infow("geolocation_initiated",
		"labels", string(observability.LabelNode),
	)
	ip, err := n.network.HostPublicIP()
	if err != nil {
		log.Errorw("failed to get host public IP", "error", err)
		return
	}
	if ip == nil {
		log.Errorw("host public IP is nil")
		return
	}
	// Publish the IP before geolocating so it is available even if lookup fails.
	n.lock.Lock()
	n.publicIP = ip
	n.lock.Unlock()
	location, err := geolocation.Geolocate(ip, n.geoIP)
	if err != nil {
		log.Errorw("failed to geolocate host", "error", err)
		return
	}
	n.lock.Lock()
	n.hostLocation = geolocation.Geolocation{
		Continent: location.Continent,
		Country:   location.Country,
		City:      location.City,
	}
	n.lock.Unlock()
	log.Infow("geolocation_successful",
		"labels", string(observability.LabelNode),
		"continent", location.Continent,
		"country", location.Country,
		"city", location.City,
	)
	// emitMetric records the (fixed) location with the current onboarded flag.
	emitMetric := func() {
		if m := observability.NodeLocation; m != nil {
			m.Record(n.ctx, 1, metric.WithAttributes(
				observability.AttrDID,
				attribute.String("continent", location.Continent),
				attribute.String("country", location.Country),
				attribute.String("city", location.City),
				attribute.Bool("onboarded", n.onboarding.IsOnboarded()),
			))
		}
	}
	// Re-emit hourly so the metric stays fresh for the node's lifetime.
	go func() {
		ticker := time.NewTicker(time.Hour)
		defer ticker.Stop()
		for {
			select {
			case <-ticker.C:
				emitMetric()
			case <-n.ctx.Done():
				return
			}
		}
	}()
}
// Start brings the node online: it runs the allocator and the node actor,
// subscribes to the broadcast topics, restores persisted deployments in the
// background, starts the billing scheduler when configured, and launches the
// periodic maintenance goroutines (bid-state GC, geolocation, and — for
// compute gateways — provisioned-resource cleanup). If topic subscription
// fails, the actor is stopped again before the error is returned.
func (n *Node) Start() error {
	log.Infow("node_start_initiated",
		"labels", string(observability.LabelNode))
	if err := n.allocator.Run(); err != nil {
		return fmt.Errorf("start node allocator: %w", err)
	}
	if err := n.actor.Start(); err != nil {
		return fmt.Errorf("start node actor: %w", err)
	}
	if err := n.subscribe(
		behaviors.BroadcastHelloTopic,
		behaviors.BidRequestTopic,
		behaviors.BroadcastStatusDiscoveryTopic,
		behaviors.BroadcastRevocationTopic,
	); err != nil {
		// Roll back the actor start; the subscribe error is what matters here.
		_ = n.actor.Stop()
		return err
	}
	// Restoration may involve network round-trips, so do not block startup on it.
	go func() {
		if err := n.restoreDeployments(); err != nil {
			log.Errorw("restoring deployments failed",
				"labels", string(observability.LabelNode),
				"error", err)
		}
	}()
	// Start billing scheduler
	if n.billingScheduler != nil {
		n.billingScheduler.Start()
		log.Infow("billing scheduler started",
			"labels", string(observability.LabelNode))
	}
	n.running.Store(true)
	go n.gcBidState()
	go n.geolocate()
	if n.dmsConfig.General.ComputeGateway {
		go n.shutdownUnusedProvisionedResources()
	}
	log.Infow("node_started_successfully",
		"labels", string(observability.LabelNode))
	return nil
}
// Stop takes the node offline in dependency order: the billing scheduler
// first, then the allocator, then active deployments are persisted, the node
// context is cancelled (terminating the maintenance goroutines), the
// broadcast app score is cleared, and finally the actor is stopped. Only an
// actor-stop failure is returned; earlier failures are logged and shutdown
// continues.
func (n *Node) Stop() error {
	log.Infow("node_stop_initiated",
		"labels", string(observability.LabelNode))
	// Stop billing scheduler first (before stopping other components)
	if n.billingScheduler != nil {
		n.billingScheduler.Stop()
		log.Infow("billing scheduler stopped",
			"labels", string(observability.LabelNode))
	}
	if err := n.allocator.Stop(context.Background()); err != nil {
		log.Errorf("stopping node allocator: %s", err)
	}
	// Persist deployments so they can be restored on the next Start.
	if err := n.saveDeployments(); err != nil {
		log.Errorw("error saving active deployments during node stop",
			"labels", string(observability.LabelDeployment),
			"error", err)
	}
	n.cancel()
	// clear the broadcast app score
	n.network.SetBroadcastAppScore(nil)
	// stop the actor
	if err := n.actor.Stop(); err != nil {
		return fmt.Errorf("stop node actor: %w", err)
	}
	n.running.Store(false)
	log.Infow("node_stopped_successfully",
		"labels", string(observability.LabelNode))
	return nil
}
// ListenForCapabilityContextsUpdates reloads all capability contexts from
// disk whenever a reload signal (SIGUSR1) arrives on internal.ReloadChan,
// then rebuilds the node's security context from the refreshed root
// capability context and installs it on the actor. It blocks until the node
// context is cancelled; individual reload failures are logged and the loop
// keeps listening.
func (n *Node) ListenForCapabilityContextsUpdates() error {
	for {
		select {
		case <-internal.ReloadChan:
			log.Infow("Received SIGUSR1, reloading capability contexts...")
			// Reload the capability context while holding the lock
			capCtx, err := func() (ucan.CapabilityContext, error) {
				n.lock.Lock()
				defer n.lock.Unlock()
				// Reload the DMS capability context
				// Note: Capability contexts are stored in UserDir, not WorkDir
				capCtx, err := LoadCapabilityContext(n.rootCap.Trust(), n.fs, n.rootCap.Name(), n.dmsConfig.General.UserDir)
				if err != nil {
					return nil, fmt.Errorf("failed to reload DMS capability context: %w", err)
				}
				n.rootCap = capCtx
				return capCtx, nil
			}()
			if err != nil {
				log.Errorw("Failed to reload capability context", "error", err)
				continue
			}
			// Create a new security context from the updated rootCap (same logic as in New())
			rootDID := capCtx.DID()
			rootTrust := capCtx.Trust()
			anchor, err := rootTrust.GetAnchor(rootDID)
			if err != nil {
				log.Errorw("Failed to get root DID anchor", "error", err)
				continue
			}
			pubk := anchor.PublicKey()
			provider, err := rootTrust.GetProvider(rootDID)
			if err != nil {
				log.Errorw("Failed to get root DID provider", "error", err)
				continue
			}
			privk, err := provider.PrivateKey()
			if err != nil {
				log.Errorw("Failed to get root private key", "error", err)
				continue
			}
			newSecurity, err := actor.NewBasicSecurityContext(pubk, privk, capCtx)
			if err != nil {
				log.Errorw("Failed to create new security context", "error", err)
				continue
			}
			// Update the actor's security context
			if err := n.actor.UpdateSecurityContext(newSecurity); err != nil {
				log.Errorw("Failed to update actor security context", "error", err)
				continue
			}
			log.Infow("Capability contexts reloaded successfully from disk")
		case <-n.ctx.Done():
			log.Infow("Node context done, stopping reload loop")
			return nil
		}
	}
}
// createEnsembleID derives a unique ensemble identifier by hashing the given
// peer ID together with a freshly generated UUID, returning the hex-encoded
// SHA-256 digest.
func createEnsembleID(peerID string) (string, error) {
	suffix, err := uuid.NewUUID()
	if err != nil {
		return "", fmt.Errorf("failed to generate uuid for allocation inbox: %w", err)
	}
	digest := sha256.Sum256([]byte(peerID + suffix.String()))
	return hex.EncodeToString(digest[:]), nil
}
// createOrchestrator spins up a dedicated child actor for a new ensemble and
// registers an orchestrator for it in the orchestrator registry. If the
// ensemble is configured to join a subnet, the ensemble-scoped subnet
// behaviors are added and the corresponding capabilities granted to the
// child actor. The ensemble config must carry a non-nil V1 section.
func (n *Node) createOrchestrator(
	ctx context.Context,
	ensemble jobtypes.EnsembleConfig,
	contracts map[string]types.ContractConfig,
) (orchestrator.Orchestrator, error) {
	if ensemble.V1 == nil {
		return nil, fmt.Errorf("empty ensemble config")
	}
	// The ensemble ID doubles as the child actor's name/inbox identifier.
	ensembleID, err := createEnsembleID(n.actor.Handle().Address.HostID)
	if err != nil {
		return nil, fmt.Errorf("generate ensemble id: %w", err)
	}
	childActor, err := n.actor.CreateChild(ensembleID, n.actor.Handle())
	if err != nil {
		return nil, fmt.Errorf("create child actor: %w", err)
	}
	log.Infow("deploying ensemble",
		"labels", string(observability.LabelDeployment),
		"ensembleID", ensembleID)
	err = childActor.Start()
	if err != nil {
		return nil, fmt.Errorf("start child actor: %w", err)
	}
	orch, err := n.orchestratorRegistry.NewOrchestrator(
		ctx, n.fs, n.dmsConfig.WorkDir,
		ensembleID, childActor, ensemble,
		types.NewDefaultNodeIDGenerator(),
		types.NewDefaultAllocationIDGenerator(),
		n.contractEventHandler,
		contracts,
	)
	if err != nil {
		return nil, fmt.Errorf("new orchestrator: %w", err)
	}
	// if orchestrator needs to join subnet, add the subnet behaviors under ensemble namespace
	// and grant the caps
	if ensemble.Subnet().Join {
		if err := n.addOrchestratorBehaviors(childActor, ensembleID); err != nil {
			return nil, fmt.Errorf("adding behaviors for orch to join subnet: %w", err)
		}
	}
	return orch, nil
}
// TODO: make send reply a helper func from actor pkg
// sendReply builds a reply envelope for msg carrying payload and sends it via
// the node actor. For broadcast messages the node's own handle is stamped as
// the message source. Failures are logged at debug level and otherwise ignored.
func (n *Node) sendReply(msg actor.Envelope, payload interface{}) {
	opts := make([]actor.MessageOption, 0, 1)
	if msg.IsBroadcast() {
		opts = append(opts, actor.WithMessageSource(n.actor.Handle()))
	}
	reply, err := actor.ReplyTo(msg, payload, opts...)
	if err != nil {
		log.Debugf("creating reply: %s", err)
		return
	}
	if err = n.actor.Send(reply); err != nil {
		log.Debugf("sending reply: %s", err)
	}
}
// ...
// TODO: do we wanna maintain the below code that is only used for e2e tests? Could there be a better way to do this?
// ...
// ResourceManager returns the node's resource manager (exposed for e2e tests).
func (n *Node) ResourceManager() types.ResourceManager {
	return n.resourceManager
}
// Allocator returns the node's allocator (exposed for e2e tests).
func (n *Node) Allocator() Allocator {
	return n.allocator
}
// GetBidRequests returns a snapshot of the bid requests currently tracked by
// the node, taken under the node lock.
func (n *Node) GetBidRequests() []jobs.BidRequest {
	n.lock.Lock()
	defer n.lock.Unlock()
	out := make([]jobs.BidRequest, 0, len(n.bids))
	for _, state := range n.bids {
		out = append(out, state.request)
	}
	return out
}
// addContractActor registers a contract actor under its contract DID URI,
// holding the node lock to guard the contractActors map.
func (n *Node) addContractActor(a *tokenomics.ContractActor) {
	n.lock.Lock()
	defer n.lock.Unlock()
	n.contractActors[a.ContractDID.URI] = a
}
// executeBillingForContract executes billing for a contract. This is called
// by the billing scheduler. Terminated/completed contracts get a final
// pro-rated invoice attempt (for FixedRental/Periodic models) and are then
// unregistered from the scheduler; live contracts get a regular invoice and
// their billing schedule updated.
//
// Fix: the contractActors map is written under n.lock in addContractActor,
// so the lookup here must also hold the lock to avoid a data race on the map.
func (n *Node) executeBillingForContract(contractDID did.DID) error {
	// Get contract to verify it exists
	contract, err := n.contractStore.GetContract(contractDID.URI)
	if err != nil {
		return fmt.Errorf("contract not found: %w", err)
	}
	// Find the contract actor using map lookup (O(1)), under the node lock
	// to synchronize with concurrent writes in addContractActor.
	n.lock.Lock()
	contractActor, exists := n.contractActors[contractDID.URI]
	n.lock.Unlock()
	if !exists {
		return fmt.Errorf("contract actor not found for %s", contractDID.URI)
	}
	// Check if contract should stop billing
	if contract.CurrentState == contracts.ContractTerminated ||
		contract.CurrentState == contracts.ContractCompleted {
		// For FixedRental and Periodic payment models, generate pro-rated final invoice before unregistering
		if contract.PaymentDetails.PaymentModel == contracts.FixedRental ||
			contract.PaymentDetails.PaymentModel == contracts.Periodic {
			// Execute billing to generate pro-rated final invoice for terminated contract
			// This will call handleTerminatedContractInvoice() which generates the pro-rated invoice
			err = contractActor.CheckAndGenerateInvoice()
			if err != nil {
				// Log error but still unregister - the invoice generation may have failed
				// but we don't want to keep retrying for a terminated contract
				log.Warnw("failed to generate final invoice for terminated contract",
					"labels", string(observability.LabelContract),
					"contract_did", contractDID.URI,
					"payment_model", contract.PaymentDetails.PaymentModel,
					"error", err)
			}
		}
		// Unregister from scheduler after attempting to generate final invoice
		n.billingScheduler.UnregisterContract(contractDID)
		return fmt.Errorf("contract %s is %s, stopping billing", contractDID.URI, contract.CurrentState)
	}
	// Execute billing
	err = contractActor.CheckAndGenerateInvoice()
	if err != nil {
		return fmt.Errorf("billing execution failed: %w", err)
	}
	// Update billing schedule after successful invoice
	// This updates the trigger's lastInvoiceAt to use actual invoice time
	if err = n.billingScheduler.UpdateContract(contractDID); err != nil {
		return fmt.Errorf("failed to update billing schedule: %w", err)
	}
	return nil
}
// contractType determines the type of contract for billing purposes.
type contractType int

const (
	contractTypeP2P contractType = iota // Normal P2P contract (not part of chain)
	contractTypeHeadContract // Head Contract in a chain
	contractTypeTailContract // Tail Contract in a chain
)
// detectContractTypeForBilling determines a contract's type for billing using
// the chain-role metadata key (O(1) lookup, no providerDID context needed).
// Contracts without the key — or with nil metadata — are plain P2P contracts.
func (n *Node) detectContractTypeForBilling(contract *contracts.Contract) contractType {
	meta := contract.Metadata
	if meta == nil {
		// No metadata at all means the contract is not part of a chain.
		return contractTypeP2P
	}
	switch meta[contracts.ContractChainRoleMetadataKey] {
	case contracts.ContractChainRoleHead:
		return contractTypeHeadContract
	case contracts.ContractChainRoleTail:
		return contractTypeTailContract
	default:
		return contractTypeP2P
	}
}
// collectUsagesAndForwardToPaymentProviders collects billable usage for one
// contract (req.ContractDID set) or all contracts, per each contract's
// payment model, and forwards the resulting usage requests to the contracts'
// payment validators. Contracts with billing disabled or with payment models
// that do not support manual billing are skipped with an explanatory result
// entry. Per-contract failures are recorded in the response rather than
// aborting the batch.
func (n *Node) collectUsagesAndForwardToPaymentProviders(req contracts.CollectUsagesAndForwardToPaymentProvidersRequest) contracts.CollectUsagesAndForwardToPaymentProvidersReponse {
	resp := contracts.CollectUsagesAndForwardToPaymentProvidersReponse{
		Results: make([]contracts.ContractUsageResult, 0),
	}
	now := time.Now()
	// Get contracts to process
	var contractsToProcess []*contracts.Contract
	var err error
	if req.ContractDID != "" {
		// Process specific contract
		contract, err := n.contractStore.GetContract(req.ContractDID)
		if err != nil {
			resp.Error = fmt.Sprintf("failed to get contract %s: %v", req.ContractDID, err)
			return resp
		}
		contractsToProcess = []*contracts.Contract{contract}
	} else {
		// Process all contracts
		contractsToProcess, err = n.contractStore.GetAllContracts()
		if err != nil {
			resp.Error = fmt.Sprintf("failed to get contracts: %v", err)
			return resp
		}
	}
	// paymentForwardToProviderRequest accumulates what must be forwarded to a
	// payment validator for one contract.
	type paymentForwardToProviderRequest struct {
		Usages              int                                  // For backward compatibility
		TimeUtilization     *contracts.TimeUtilizationUsage      // For pay_per_time_utilization
		ResourceUtilization *contracts.ResourceUtilizationUsage  // For pay_per_resource_utilization
		FixedRentalDetails  *contracts.FixedRentalUsage          // For fixed_rental
		Contract            contracts.Contract
	}
	allPayments := make([]paymentForwardToProviderRequest, 0)
	// Collect usage for each contract based on its payment model
	for _, contract := range contractsToProcess {
		// Skip contracts with billing disabled (Contract A)
		if contract.DisableBilling {
			result := contracts.ContractUsageResult{
				ContractDID:  contract.ContractDID,
				PaymentModel: contract.PaymentDetails.PaymentModel,
				Error:        "billing is disabled for this contract",
			}
			resp.Results = append(resp.Results, result)
			log.Debugw("skipping manual billing (disabled by contract flag)",
				"contract_did", contract.ContractDID)
			continue
		}
		// Get contract-specific last processed timestamp
		lastProcessedAt, _ := n.usageStore.GetLastProcessedAt(contract.ContractDID)
		// Get processor for this payment model
		processor, err := contracts.GetPaymentModelProcessor(contract.PaymentDetails.PaymentModel)
		if err != nil {
			result := contracts.ContractUsageResult{
				ContractDID:  contract.ContractDID,
				PaymentModel: contract.PaymentDetails.PaymentModel,
				Error:        fmt.Sprintf("unsupported payment model: %v", err),
			}
			resp.Results = append(resp.Results, result)
			log.Warnf("unsupported payment model for contract %s: %v", contract.ContractDID, err)
			continue
		}
		// Check if this model supports manual billing
		if !processor.SupportsManualBilling() {
			result := contracts.ContractUsageResult{
				ContractDID:  contract.ContractDID,
				PaymentModel: contract.PaymentDetails.PaymentModel,
				Error:        fmt.Sprintf("%s payment model uses automatic periodic billing and cannot be manually triggered. Invoices are generated automatically by the contract actor.", contract.PaymentDetails.PaymentModel),
			}
			resp.Results = append(resp.Results, result)
			log.Warnf("manual invoice generation attempted for %s contract %s", contract.PaymentDetails.PaymentModel, contract.ContractDID)
			continue
		}
		// Detect contract type using metadata field (no providerDID needed)
		contractType := n.detectContractTypeForBilling(contract)
		var usageData *contracts.UsageData
		switch contractType {
		case contractTypeHeadContract:
			// Head Contract: query events by head_contract_did
			usageData, err = processor.CollectUsage(
				contract.ContractDID,
				lastProcessedAt,
				now,
				"",                  // providerDID (empty for Head Contract - aggregate all)
				contract.ContractDID, // headContractDID (query events where head_contract_did = this DID)
			)
		case contractTypeTailContract, contractTypeP2P:
			// Tail Contract or P2P: query events by contract_did
			usageData, err = processor.CollectUsage(
				contract.ContractDID,
				lastProcessedAt,
				now,
				"", // providerDID (can be specified for per-provider billing)
				"", // headContractDID (empty - query by contract_did)
			)
		default:
			// Fallback: treat as P2P
			usageData, err = processor.CollectUsage(
				contract.ContractDID,
				lastProcessedAt,
				now,
				"",
				"",
			)
		}
		if err != nil {
			result := contracts.ContractUsageResult{
				ContractDID:  contract.ContractDID,
				PaymentModel: contract.PaymentDetails.PaymentModel,
				Error:        fmt.Sprintf("failed to collect usage: %v", err),
			}
			resp.Results = append(resp.Results, result)
			log.Warnf("failed to collect usage for contract %s: %v", contract.ContractDID, err)
			continue
		}
		// Convert usage data to result format
		result := n.convertUsageDataToResult(usageData)
		if result.Error != "" {
			resp.Results = append(resp.Results, result)
			log.Warnf("failed to convert usage data for contract %s: %s", contract.ContractDID, result.Error)
			continue
		}
		// Save contract-specific last processed timestamp
		err = n.usageStore.SaveLastProcessedAt(contract.ContractDID, now)
		if err != nil {
			result.Error = fmt.Sprintf("failed to save last processed timestamp: %v", err)
			resp.Results = append(resp.Results, result)
			log.Warnf("failed to save last processed usage for contract %s: %v", contract.ContractDID, err)
			continue
		}
		// Add result
		resp.Results = append(resp.Results, result)
		// Add to payments if there are usages
		if result.Usages > 0 || result.TimeUtilization != nil || result.ResourceUtilization != nil {
			paymentReq := paymentForwardToProviderRequest{
				Contract: *contract,
				Usages:   result.Usages,
			}
			if result.TimeUtilization != nil {
				paymentReq.TimeUtilization = result.TimeUtilization
			}
			if result.ResourceUtilization != nil {
				paymentReq.ResourceUtilization = result.ResourceUtilization
			}
			allPayments = append(allPayments, paymentReq)
		}
	}
	// Forward each collected usage batch to the contract's payment validator.
	for _, v := range allPayments {
		req := contracts.ContractUsageRequest{
			UniqueID:            uuid.NewString(),
			Contract:            v.Contract,
			Usages:              v.Usages,
			TimeUtilization:     v.TimeUtilization,     // May contain multiple deployments for pay_per_time_utilization
			ResourceUtilization: v.ResourceUtilization, // May contain multiple deployments for pay_per_resource_utilization
		}
		// construct destination address
		destination, err := actor.HandleFromDID(v.Contract.PaymentValidatorDID.URI)
		if err != nil {
			log.Errorf("failed to get handle of payment provider for contract %s: %v", v.Contract.ContractDID, err)
			continue
		}
		envelope, err := n.invokeBehaviour(destination, behaviors.ContractUsageBehavior, req, invokeMessageTimeout)
		if envelope.Message == nil || err != nil {
			log.Errorf("failed to update payment status of contract host for contract %s: %v", v.Contract.ContractDID, err)
			continue
		}
		log.Infof("Successfully sent ContractUsageRequest for contract %s to payment validator", v.Contract.ContractDID)
		resp.TotalUsages++
	}
	return resp
}
// convertUsageDataToResult converts UsageData produced by a payment-model
// processor into a ContractUsageResult, validating the dynamic payload type
// for each payment model and filling Usages for backward compatibility.
func (n *Node) convertUsageDataToResult(usageData *contracts.UsageData) contracts.ContractUsageResult {
	out := contracts.ContractUsageResult{
		ContractDID:  usageData.ContractDID,
		PaymentModel: usageData.PaymentModel,
	}
	switch usageData.PaymentModel {
	case contracts.PayPerAllocation:
		count, ok := usageData.Data.(int)
		if !ok {
			out.Error = "invalid usage data type for pay_per_allocation"
			return out
		}
		out.Usages = count
	case contracts.PayPerDeployment:
		count, ok := usageData.Data.(int)
		if !ok {
			out.Error = "invalid usage data type for pay_per_deployment"
			return out
		}
		out.Usages = count
	case contracts.PayPerTimeUtilization:
		data, ok := usageData.Data.(*contracts.TimeUtilizationUsage)
		if !ok {
			out.Error = "invalid usage data type for pay_per_time_utilization"
			return out
		}
		out.TimeUtilization = data
		out.Usages = len(data.Deployments) // backward compatibility
	case contracts.PayPerResourceUtilization:
		data, ok := usageData.Data.(*contracts.ResourceUtilizationUsage)
		if !ok {
			out.Error = "invalid usage data type for pay_per_resource_utilization"
			return out
		}
		out.ResourceUtilization = data
		out.Usages = len(data.Deployments) // backward compatibility
	default:
		out.Error = fmt.Sprintf("unsupported payment model: %s", usageData.PaymentModel)
	}
	return out
}
// invokeBehaviour sends a request-style message to destination for the given
// behavior and waits up to timeout for the reply. On timeout or failure it
// returns a zero Envelope and an error.
//
// Fix: the timeout previously used time.NewTicker, which is meant for
// repeating events; a one-shot wait should use time.NewTimer. The unused
// pre-declared reply variable was also removed.
func (n *Node) invokeBehaviour(destination actor.Handle, behavior string, payload any, timeout time.Duration) (actor.Envelope, error) {
	msg, err := actor.Message(
		n.actor.Handle(),
		destination,
		behavior,
		payload,
		actor.WithMessageExpiry(actor.MakeExpiry(timeout)),
	)
	if err != nil {
		return actor.Envelope{}, fmt.Errorf("failed to create contract actor message: %w", err)
	}
	replyCh, err := n.actor.Invoke(msg)
	if err != nil {
		return actor.Envelope{}, fmt.Errorf("failed to invoke message: %w", err)
	}
	timer := time.NewTimer(timeout)
	defer timer.Stop()
	select {
	case reply := <-replyCh:
		// NOTE(review): Discard is deferred here, so it fires before the caller
		// reads the envelope — this matches the previous behavior, but confirm
		// that Discard does not invalidate reply.Message for callers.
		defer reply.Discard()
		return reply, nil
	case <-timer.C:
		return actor.Envelope{}, errors.New("failed to receive reply due to timeout")
	}
}
// handleContractEvents forwards a contract event to the remote contract
// actor. It derives the libp2p peer ID from the contract host's DID, builds a
// handle addressing the contract actor's inbox on that peer, and invokes
// ContractEventsBehavior with the JSON-encoded event payload. An error is
// returned if any DID/key derivation step, delivery, or the host-side
// processing fails.
func (n *Node) handleContractEvents(event eventhandler.Event) error {
	hostDID, err := did.FromString(event.ContractHostDID)
	if err != nil {
		return fmt.Errorf("failed to get contracts host did: %w", err)
	}
	pubKey, err := did.PublicKeyFromDID(hostDID)
	if err != nil {
		return fmt.Errorf("failed to get contracts host public key from did: %w", err)
	}
	// The host's public key determines the libp2p peer the actor lives on.
	pid, err := peer.IDFromPublicKey(pubKey)
	if err != nil {
		return fmt.Errorf("failed to get peer id: %w", err)
	}
	// get actor public key
	contractActorDID, err := did.FromString(event.ContractDID)
	if err != nil {
		return fmt.Errorf("failed to get contracts actor did: %w", err)
	}
	pubKeyContractActor, err := did.PublicKeyFromDID(contractActorDID)
	if err != nil {
		return fmt.Errorf("failed to get contracts actor public key from did: %w", err)
	}
	// Address the contract actor's inbox (named by its DID) on the host peer.
	destination, err := actor.HandleFromPublicKeyWithInboxAddress(pubKeyContractActor, event.ContractDID, pid.String())
	if err != nil {
		return fmt.Errorf("failed to get contracts host handle: %w", err)
	}
	bts, err := json.Marshal(event.Payload)
	if err != nil {
		return fmt.Errorf("failed to marshal event object: %w", err)
	}
	req := contracts.ContractEventRequest{
		Payload: bts,
	}
	reply, err := n.invokeBehaviour(destination, behaviors.ContractEventsBehavior, req, invokeMessageTimeout)
	if err != nil {
		return fmt.Errorf("failed to send message to contract host: %w", err)
	}
	var respEnvelope contracts.ContractEventResponse
	err = json.Unmarshal(reply.Message, &respEnvelope)
	if err != nil {
		return fmt.Errorf("failed to unmarshal contract hosts response payload: %w", err)
	}
	// A delivered reply can still carry an application-level error.
	if respEnvelope.Error != "" {
		return fmt.Errorf("failed to process contract event: %s", respEnvelope.Error)
	}
	return nil
}
// isRestorableStatus checks if a deployment status is restorable.
// TODO: This will be implemented later with more sophisticated logic.
// For now, all deployments with status <= Running are considered restorable.
func isRestorableStatus(status jobtypes.DeploymentStatus) bool {
	// TODO: Implement more sophisticated restorable status logic
	// This should consider:
	// - Deployment lifecycle state
	// - Resource allocation status
	// - Compute provider availability
	// - Network connectivity requirements
	// Until that logic is designed, any status up to and including Running
	// (this relies on DeploymentStatus values being ordered by lifecycle)
	// is treated as restorable.
	return status <= jobtypes.DeploymentStatusRunning
}
// isDeploymentStillValid checks if a deployment is still valid for restoration.
// TODO: This will be implemented later with comprehensive validation.
// For now, all deployments are considered valid (always returns true).
func isDeploymentStillValid(_ *jobtypes.OrchestratorView) bool {
	// TODO: Implement comprehensive deployment validation
	// This should check:
	// - Deployment configuration is still valid
	// - Required resources are still available
	// - Network configuration is still accessible
	// - Compute provider is still responsive
	// - Deployment hasn't been explicitly cancelled
	// - Resource quotas haven't been exceeded
	// - Security policies haven't changed
	//
	// Note: Compute providers should also implement similar validation
	// on their side to ensure resources are still available and
	// haven't been reclaimed due to extended downtime
	return true
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
//go:build linux
package node
import (
"context"
job_types "gitlab.com/nunet/device-management-service/dms/jobs/types"
"gitlab.com/nunet/device-management-service/executor/docker"
)
// initSupportedExecutors registers the executors this node supports.
// Docker executor creation is best-effort: when it fails the node simply
// starts without Docker support and no error is reported.
// NOTE(review): the creation error is intentionally swallowed so nodes
// without a working Docker daemon can still run — confirm this is desired.
func (n *Node) initSupportedExecutors(ctx context.Context) error {
	dockerExec, err := docker.NewExecutor(ctx, n.fs, "root")
	if err != nil {
		return nil
	}
	n.executors[string(job_types.ExecutorDocker)] = executorMetadata{
		executor:      dockerExec,
		executionType: job_types.ExecutorDocker,
	}
	return nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package node
import (
"fmt"
"time"
"github.com/google/uuid"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/dms/behaviors"
"gitlab.com/nunet/device-management-service/network"
"gitlab.com/nunet/device-management-service/tokenomics/contracts"
"gitlab.com/nunet/device-management-service/tokenomics/store/payment"
"gitlab.com/nunet/device-management-service/types"
)
// PaymentProcessor defines the interface for processing payment items.
// This interface allows for different implementations and easy testing.
type PaymentProcessor interface {
	// ProcessPaymentItems processes a batch of payment items for a contract.
	// This method:
	// - Generates unique IDs for items that don't have one
	// - Saves each payment to the payment store
	// - Forwards transaction requests to the service provider
	// - Continues processing even if individual items fail
	// Returns an error only if there's a critical failure that prevents processing.
	//
	// baseUniqueID is retained for interface compatibility; the current
	// implementation ignores it (UUIDs are generated upstream).
	ProcessPaymentItems(
		contract *contracts.Contract,
		items []*contracts.PaymentItem,
		baseUniqueID string,
	) error
}
// paymentProcessorImpl implements PaymentProcessor
type paymentProcessorImpl struct {
	// paymentStore persists payment records before they are forwarded.
	paymentStore *payment.Store
	// network is the node's network layer (required non-nil by the constructor).
	network network.Network
	// actor is an actor handle; presumably the local actor identity — the
	// constructor only validates that its ID is non-empty. TODO confirm role.
	actor actor.Handle
	// invokeBehaviour invokes a behavior on a destination actor and waits up
	// to timeout for a reply; injected as a function for testability.
	invokeBehaviour func(destination actor.Handle, behavior string, req interface{}, timeout time.Duration) (actor.Envelope, error)
}
// NewPaymentProcessor creates a new payment processor implementation.
// It panics when any dependency is missing, since a processor without its
// store, network, actor identity or invoke function cannot operate.
func NewPaymentProcessor(
	paymentStore *payment.Store,
	network network.Network,
	actor actor.Handle,
	invokeBehaviour func(destination actor.Handle, behavior string, req interface{}, timeout time.Duration) (actor.Envelope, error),
) PaymentProcessor {
	// Validate dependencies in a fixed order so panic messages stay stable.
	switch {
	case paymentStore == nil:
		panic("payment store cannot be nil")
	case network == nil:
		panic("network cannot be nil")
	case actor.ID.String() == "":
		panic("actor handle cannot be empty")
	case invokeBehaviour == nil:
		panic("invokeBehaviour function cannot be nil")
	}
	return &paymentProcessorImpl{
		paymentStore:    paymentStore,
		network:         network,
		actor:           actor,
		invokeBehaviour: invokeBehaviour,
	}
}
// ProcessPaymentItems implements PaymentProcessor.ProcessPaymentItems.
//
// For each item it ensures a UniqueID is set, merges deployment metadata,
// saves the payment and forwards the transaction to the requestor. Failures
// on individual items are logged and processing continues. After the
// primary items, the orchestration fee (if configured) is calculated,
// saved and forwarded; fee failures never fail the batch because the
// primary transactions have already been processed.
func (pp *paymentProcessorImpl) ProcessPaymentItems(
	contract *contracts.Contract,
	items []*contracts.PaymentItem,
	_ string, // baseUniqueID - no longer used, UUIDs generated in payment processors
) error {
	if contract == nil {
		return fmt.Errorf("contract cannot be nil")
	}
	if len(items) == 0 {
		return nil // No items to process
	}
	for _, item := range items {
		pp.processItem(contract, item)
	}
	pp.processOrchestrationFee(contract, items)
	return nil
}

// processItem prepares a single payment item (UniqueID fallback, metadata
// merge), saves it and forwards its transaction. Errors are only logged so
// the remaining items in the batch can still be processed.
func (pp *paymentProcessorImpl) processItem(contract *contracts.Contract, item *contracts.PaymentItem) {
	// UUID should already be generated by the payment processor in
	// CalculatePayment. If not set, log a warning (should not happen in
	// normal flow) and generate one defensively.
	if item.UniqueID == "" {
		log.Warnf("PaymentItem has empty UniqueID, this should not happen")
		item.UniqueID = uuid.NewString()
	}
	// Build transaction metadata from the item, starting with a copy of its
	// existing metadata so the original map is not shared. Ranging over a
	// nil map is a no-op, so no nil check is needed.
	metadata := make(map[string]interface{})
	for k, v := range item.Metadata {
		metadata[k] = v
	}
	// Add deployment_id from PaymentItem.DeploymentID.
	if item.DeploymentID != "" {
		metadata["deployment_id"] = item.DeploymentID
	}
	// Store metadata back in the item for use in forwardTransaction.
	item.Metadata = metadata
	if err := pp.savePayment(contract, item); err != nil {
		log.Errorf("failed to save payment for item %s: %v", item.UniqueID, err)
		return // skip forwarding when the payment could not be saved
	}
	if err := pp.forwardTransaction(contract, item); err != nil {
		log.Errorf("failed to forward transaction for item %s: %v", item.UniqueID, err)
		// Continue - payment is saved
	}
}

// processOrchestrationFee calculates, saves and forwards the orchestration
// fee for a processed batch. All failures are logged only: the primary
// transactions have already been processed and must not be failed.
func (pp *paymentProcessorImpl) processOrchestrationFee(contract *contracts.Contract, items []*contracts.PaymentItem) {
	if contract.PaymentDetails.OrchestrationFee == nil {
		return // No orchestration fee configured
	}
	feeCalculator := contracts.NewOrchestrationFeeCalculator()
	// Calculate fee based on all payment items.
	feeAmount, err := feeCalculator.CalculateFee(items, contract.PaymentDetails.OrchestrationFee)
	if err != nil {
		log.Errorf("failed to calculate orchestration fee: %v", err)
		return
	}
	if feeAmount == "" {
		return // No fee to generate
	}
	// Generate the orchestration fee item.
	feeItem, err := feeCalculator.GenerateOrchestrationFeeItem(
		items,
		contract,
		items[0].UniqueID,
		feeAmount,
	)
	if err != nil {
		log.Errorf("failed to generate orchestration fee item: %v", err)
		return
	}
	if feeItem == nil {
		return // No fee item generated
	}
	// Save the fee payment (reuses the regular savePayment path).
	if err := pp.savePayment(contract, feeItem); err != nil {
		log.Errorf("failed to save orchestration fee payment: %v", err)
		return
	}
	// Forward using a shallow contract copy so an orchestration-fee
	// recipient address (when configured) does not leak into the original
	// contract's payment addresses.
	tempContract := *contract
	if contract.PaymentDetails.OrchestrationFee.RecipientAddress.RequesterAddr != "" {
		tempContract.PaymentDetails.Addresses = []types.PaymentAddressInfo{
			contract.PaymentDetails.OrchestrationFee.RecipientAddress,
		}
	}
	// This forwards to the original requestor (same as primary transactions).
	if err := pp.forwardTransaction(&tempContract, feeItem); err != nil {
		log.Errorf("failed to forward orchestration fee transaction: %v", err)
		// Continue - payment is saved
	}
}
// savePayment persists a single payment record built from the contract and
// item into the payment store. New records always start out unpaid.
func (pp *paymentProcessorImpl) savePayment(contract *contracts.Contract, item *contracts.PaymentItem) error {
	record := payment.Payment{
		UniqueID: item.UniqueID,
		Contract: *contract,
		Usages:   item.Usages,
		Amount:   item.Amount,
		Paid:     false,
	}
	return pp.paymentStore.Insert(record)
}
// forwardTransaction builds a TransactionForServiceProviderRequest from the
// contract and item and sends it asynchronously to the contract requestor.
// Only resolving the requestor handle can fail; the actual send happens in
// a goroutine whose failures are logged, not returned.
func (pp *paymentProcessorImpl) forwardTransaction(
	contract *contracts.Contract,
	item *contracts.PaymentItem,
) error {
	// Create transaction request.
	txReq := contracts.TransactionForServiceProviderRequest{
		PaymentValidatorDID: contract.PaymentValidatorDID.URI,
		UniqueID:            item.UniqueID,
		ContractDID:         contract.ContractDID,
		ToAddress:           contract.PaymentDetails.Addresses,
		Amount:              item.Amount, // Store original amount (in pricing currency if conversion needed)
		Status:              "unpaid",
		Metadata:            item.Metadata,
	}
	// If pricing_currency is set, store original amount and conversion
	// metadata. NO CONVERSION HAPPENS HERE - conversion only happens at
	// quote time, when the user requests a quote.
	pricingCurrency := contract.PaymentDetails.PricingCurrency
	if pricingCurrency != "" && pricingCurrency != "NTX" {
		txReq.OriginalAmount = item.Amount
		txReq.PricingCurrency = pricingCurrency
		txReq.RequiresConversion = true
	}
	destination, err := actor.HandleFromDID(contract.ContractParticipants.Requestor.URI)
	if err != nil {
		return fmt.Errorf("failed to get requestor handle: %w", err)
	}
	// Fire-and-forget: the caller does not wait for the remote reply.
	go func() {
		reply, invokeErr := pp.invokeBehaviour(destination, behaviors.ContractTransactionBehavior, txReq, invokeMessageTimeout)
		if invokeErr != nil || reply.Message == nil {
			log.Errorf("failed to forward transaction %s: %v", item.UniqueID, invokeErr)
		}
	}()
	return nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package node
import (
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"github.com/spf13/afero"
"gitlab.com/nunet/device-management-service/lib/did"
)
const PrismDIDMapFile = "prism_did_map.json"
// prismDIDMapFilePath returns the path to the PRISM DID mapping file
func prismDIDMapFilePath(userDir string) string {
return filepath.Join(userDir, PrismDIDMapFile)
}
// loadPrismDIDMap loads the PRISM DID to key name mapping from disk.
// A missing file is not an error: it yields an empty map.
func loadPrismDIDMap(fs afero.Fs, userDir string) (map[string]string, error) {
	mapping := make(map[string]string)
	f, err := fs.Open(prismDIDMapFilePath(userDir))
	switch {
	case errors.Is(err, os.ErrNotExist):
		// No mapping stored yet.
		return mapping, nil
	case err != nil:
		return nil, fmt.Errorf("open PRISM DID map: %w", err)
	}
	defer f.Close()
	if err := json.NewDecoder(f).Decode(&mapping); err != nil {
		return nil, fmt.Errorf("decode PRISM DID map: %w", err)
	}
	return mapping, nil
}
// savePrismDIDMap writes the PRISM DID mapping atomically: the JSON is
// written to a temp file first and then renamed over the target. When a
// previous file exists, a best-effort .bak copy is made first (a read
// failure merely skips the backup; a write failure aborts).
func savePrismDIDMap(fs afero.Fs, userDir string, mapping map[string]string) error {
	target := prismDIDMapFilePath(userDir)
	// Back up the existing file, if any.
	if _, statErr := fs.Stat(target); statErr == nil {
		backup := target + ".bak"
		_ = fs.Remove(backup) // ignore error
		if prev, readErr := afero.ReadFile(fs, target); readErr == nil {
			if err := afero.WriteFile(fs, backup, prev, 0o600); err != nil {
				return fmt.Errorf("backup PRISM DID map: %w", err)
			}
		}
	}
	encoded, err := json.MarshalIndent(mapping, "", " ")
	if err != nil {
		return fmt.Errorf("marshal PRISM DID map: %w", err)
	}
	// Write to temp file first, then atomically rename into place.
	tmp := target + ".tmp"
	if err := afero.WriteFile(fs, tmp, encoded, 0o600); err != nil {
		return fmt.Errorf("write PRISM DID map: %w", err)
	}
	if err := fs.Rename(tmp, target); err != nil {
		_ = fs.Remove(tmp) // cleanup
		return fmt.Errorf("rename PRISM DID map: %w", err)
	}
	return nil
}
// SetPrismDID associates a key name with a PRISM DID in the on-disk map.
// The DID must parse and must use the "prism" method.
func SetPrismDID(fs afero.Fs, userDir, keyName, prismDID string) error {
	mapping, err := loadPrismDIDMap(fs, userDir)
	if err != nil {
		return err
	}
	// Validate PRISM DID before persisting anything.
	parsed, err := did.FromString(prismDID)
	if err != nil {
		return fmt.Errorf("invalid PRISM DID: %w", err)
	}
	if method := parsed.Method(); method != "prism" {
		return fmt.Errorf("expected PRISM DID (did:prism:...), got %s", method)
	}
	mapping[keyName] = prismDID
	return savePrismDIDMap(fs, userDir, mapping)
}
// GetPrismDID retrieves the PRISM DID for a key name.
// An unmapped key yields an empty string with no error.
func GetPrismDID(fs afero.Fs, userDir, keyName string) (string, error) {
	mapping, err := loadPrismDIDMap(fs, userDir)
	if err != nil {
		return "", err
	}
	prismDID := mapping[keyName] // zero value "" when absent
	return prismDID, nil
}
// RemovePrismDID removes the PRISM DID association for a key name and
// persists the updated map (the file is rewritten even if the key was
// not present).
func RemovePrismDID(fs afero.Fs, userDir, keyName string) error {
	mapping, err := loadPrismDIDMap(fs, userDir)
	if err != nil {
		return err
	}
	delete(mapping, keyName) // no-op when the key is absent
	return savePrismDIDMap(fs, userDir, mapping)
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package onboarding
import (
"context"
"errors"
"fmt"
"sync"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/observability"
"gitlab.com/nunet/device-management-service/types"
"go.opentelemetry.io/otel/metric"
)
var (
// ErrMachineNotOnboarded is returned when the machine is not onboarded and is expected to be
ErrMachineNotOnboarded = errors.New("machine is not onboarded")
// ErrUnmetCapacity is returned when the onboarded resources doesn't meet the required capacity
ErrUnmetCapacity = errors.New("capacity not met")
// ErrHighUsage is returned when the machine has high usage compared to the onboarded resources
ErrHighUsage = errors.New("machine has high usage")
// ErrOutOfRange is returned when the actual value is out of the expected range
ErrOutOfRange = errors.New("out of range")
)
// validateRange validates the actual value is within the min and max range
func validateRange(actual, minimum, maximum uint64) error {
if actual < minimum || actual > maximum {
return ErrOutOfRange
}
return nil
}
// populateOnboardingConfig fills the dynamic fields of the onboarding
// config — currently the onboarded resources fetched from the resource
// manager; further fields are planned.
func populateOnboardingConfig(ctx context.Context,
	config *types.OnboardingConfig,
	resourceManager types.ResourceManager,
) error {
	onboarded, err := resourceManager.GetOnboardedResources(ctx)
	if err != nil {
		return fmt.Errorf("could not get onboarded resources: %w", err)
	}
	config.OnboardedResources = onboarded.Resources
	// ...
	// load other fields
	// ...
	return nil
}
// Onboarding implements the OnboardingManager interface.
// It caches the onboarding configuration in memory and keeps it in sync
// with the database repository; all access goes through Lock. Because the
// struct embeds a sync.RWMutex it must not be copied — always use *Onboarding.
type Onboarding struct {
	// ResourceManager tracks which resources are assigned to the network.
	ResourceManager types.ResourceManager
	// Hardware reports the machine's total and free resources.
	Hardware types.HardwareManager
	// ConfigRepo is the db repository to store the onboarding config
	ConfigRepo repositories.GenericEntityRepository[types.OnboardingConfig]
	// Config is the cached onboarding configuration
	Config types.OnboardingConfig
	// Lock is the lock to protect the onboarding state
	Lock sync.RWMutex
}

// Ensure Onboarding implements the OnboardingManager interface
var _ types.OnboardingManager = (*Onboarding)(nil)
// New creates a new Onboarding instance.
//
// The persisted onboarding config is loaded from configRepo; a missing
// record is not an error and yields an empty config. If the machine is
// marked as onboarded, the prerequisites are re-validated against the
// current hardware:
//   - ErrHighUsage is logged but keeps the machine onboarded;
//   - any other validation failure (including ErrUnmetCapacity) offboards
//     the machine before returning.
func New(ctx context.Context,
	resourceManager types.ResourceManager,
	hardwareManager types.HardwareManager,
	configRepo repositories.GenericEntityRepository[types.OnboardingConfig],
) (*Onboarding, error) {
	if resourceManager == nil {
		return nil, fmt.Errorf("resource manager is required")
	}
	if hardwareManager == nil {
		return nil, fmt.Errorf("hardware manager is required")
	}
	if configRepo == nil {
		return nil, fmt.Errorf("config repo is required")
	}
	// try loading the onboarding config from the database
	config, err := configRepo.Get(ctx)
	if err != nil {
		if !errors.Is(err, repositories.ErrNotFound) {
			return nil, fmt.Errorf("could not get onboarding config: %w", err)
		}
		// no stored config yet: start from the zero value
		config = types.OnboardingConfig{}
	}
	onboardingManager := &Onboarding{
		ResourceManager: resourceManager,
		Hardware:        hardwareManager,
		ConfigRepo:      configRepo,
		Config:          config,
	}
	// validate the onboarding config
	if onboardingManager.Config.IsOnboarded {
		// populate the onboarded resources
		// NOTE(review): this mutates the local `config` copy only;
		// onboardingManager.Config was copied above and does not see the
		// populated resources — confirm this is intended.
		if err := populateOnboardingConfig(ctx, &config, resourceManager); err != nil {
			return nil, fmt.Errorf("populate onboarding config: %w", err)
		}
		if err := onboardingManager.validatePrerequisites(config); err != nil {
			switch {
			case errors.Is(err, ErrUnmetCapacity):
				// logged here, then falls through to the offboarding below
				log.Errorw("machine onboarded, but capacity not fully met",
					"labels", string(observability.LabelNode),
					"error", err)
			case errors.Is(err, ErrHighUsage):
				// high usage alone does not offboard the machine
				log.Errorw("machine onboarded, but high usage detected",
					"labels", string(observability.LabelNode),
					"error", err)
				return onboardingManager, nil
			default:
				log.Errorw("machine is onboarded but prerequisites are not met",
					"labels", string(observability.LabelNode),
					"error", err)
			}
			// if the machine is onboarded but the prerequisites are not met, offboard the machine
			log.Infow("offboarding the machine because onboarded resources are no longer valid",
				"labels", string(observability.LabelNode))
			// Offboard takes its own lock; New does not hold it here.
			if err := onboardingManager.Offboard(context.Background()); err != nil {
				return nil, fmt.Errorf("offboard the machine: %w", err)
			}
		}
	}
	return onboardingManager, nil
}
// validateCapacity validates the machine capacity for the requested
// onboarding resources. Rules:
//   - CPU cores: at least 1 and at most the machine's core count
//   - RAM: between 10% and 90% of total RAM
//   - Disk: between 1 GiB and 90% of total disk
//   - each requested GPU: VRAM between 10% and 90% of that GPU's VRAM
func validateCapacity(onboardedResources, machineResources types.Resources) error {
	if onboardedResources.CPU.Cores < 1 || onboardedResources.CPU.Cores > machineResources.CPU.Cores {
		return fmt.Errorf("cores must be between %d and %.0f", 1, machineResources.CPU.Cores)
	}
	if err := validateRange(
		onboardedResources.RAM.Size,
		machineResources.RAM.Size/10,   // minimum 10% of total RAM
		machineResources.RAM.Size*9/10, // maximum 90% of total RAM
	); err != nil {
		if errors.Is(err, ErrOutOfRange) {
			return fmt.Errorf("expected RAM to be between %d GiB and %d GiB, got %d GiB",
				machineResources.RAM.SizeInGiB()/10,
				machineResources.RAM.SizeInGiB()*9/10,
				onboardedResources.RAM.SizeInGiB(),
			)
		}
		return fmt.Errorf("validating resource range for RAM: %w", err)
	}
	if err := validateRange(
		onboardedResources.Disk.Size,
		1024*1024*1024,                  // minimum 1 GiB
		machineResources.Disk.Size*9/10, // maximum 90% of total disk
	); err != nil {
		if errors.Is(err, ErrOutOfRange) {
			return fmt.Errorf("expected Disk to be between 1 GiB and %d GiB, got %d GiB",
				machineResources.Disk.SizeInGiB()*9/10,
				onboardedResources.Disk.SizeInGiB(),
			)
		}
		return fmt.Errorf("validating resource range for disk: %w", err)
	}
	for _, gpu := range onboardedResources.GPUs {
		selectedGPU, err := machineResources.GPUs.GetWithIndex(gpu.Index)
		if err != nil {
			// message fixed: was "could not get find gpu" (typo), now
			// consistent with validateUsage
			return fmt.Errorf("could not find gpu: %w", err)
		}
		if err := validateRange(
			gpu.VRAM,
			selectedGPU.VRAM/10,   // minimum 10% of this GPU's VRAM
			selectedGPU.VRAM*9/10, // maximum 90% of this GPU's VRAM
		); err != nil {
			if errors.Is(err, ErrOutOfRange) {
				return fmt.Errorf("expected GPU %d VRAM to be between %d and %d, got %d",
					gpu.Index,
					selectedGPU.VRAMInGB()/10,
					selectedGPU.VRAMInGB()*9/10,
					gpu.VRAMInGB(),
				)
			}
			return fmt.Errorf("validating resource range for GPU %d: %w", gpu.Index, err)
		}
	}
	return nil
}
// validateUsage validates the machine usage for the requested onboarding
// resources: each requested quantity must fit within the system's currently
// free compute, RAM, disk and per-GPU VRAM.
func validateUsage(onboardedResources, systemFreeResources types.Resources) error {
	switch {
	case onboardedResources.CPU.Compute() > systemFreeResources.CPU.Compute():
		return fmt.Errorf("not enough free compute available on the system: %.2f GHz", systemFreeResources.CPU.ComputeInGHz())
	case onboardedResources.RAM.Size > systemFreeResources.RAM.Size:
		return fmt.Errorf("not enough free RAM available on the system: %d GB", systemFreeResources.RAM.SizeInGB())
	case onboardedResources.Disk.Size > systemFreeResources.Disk.Size:
		return fmt.Errorf("not enough free Disk available on the system: %d GB", systemFreeResources.Disk.SizeInGB())
	}
	for _, requested := range onboardedResources.GPUs {
		free, err := systemFreeResources.GPUs.GetWithIndex(requested.Index)
		if err != nil {
			return fmt.Errorf("could not find gpu: %w", err)
		}
		if requested.VRAM > free.VRAM {
			return fmt.Errorf("not enough free VRAM available on GPU %s: %d GB", requested.Model, free.VRAMInGB())
		}
	}
	return nil
}
// validatePrerequisites validates the onboarding prerequisites: the
// requested resources must fit the machine's total capacity (wrapped as
// ErrUnmetCapacity on failure) and the machine's currently free resources
// (wrapped as ErrHighUsage on failure). Hardware and free-resource
// snapshots are logged along the way for observability.
func (o *Onboarding) validatePrerequisites(config types.OnboardingConfig) error {
	machineResources, err := o.Hardware.GetMachineResources()
	if err != nil {
		return fmt.Errorf("could not get machine resources: %w", err)
	}
	// Log the machine's total hardware inventory (block scopes gpuCount).
	{
		gpuCount := len(machineResources.Resources.GPUs)
		log.Infow("machine_hardware_resources",
			"labels", string(observability.LabelNode),
			"cpuCores", machineResources.Resources.CPU.Cores,
			"ramGB", machineResources.Resources.RAM.SizeInGB(),
			"gpuCount", gpuCount,
		)
		for idx, gpu := range machineResources.Resources.GPUs {
			log.Infow("machine_hardware_gpu",
				"labels", string(observability.LabelNode),
				"gpuIndex", gpu.Index,
				"gpuModel", gpu.Model,
				"gpuVramGB", gpu.VRAMInGB(),
				"gpuCores", gpu.Cores,
				"gpuLogIndex", idx, // just to see the loop index
			)
		}
	}
	// Capacity check: requested resources vs. machine totals.
	if err := validateCapacity(config.OnboardedResources, machineResources.Resources); err != nil {
		return fmt.Errorf("%w: %v", ErrUnmetCapacity, err)
	}
	systemFreeResources, err := o.Hardware.GetFreeResources()
	if err != nil {
		return fmt.Errorf("could not get system free resources: %w", err)
	}
	// Log the machine's currently free resources.
	{
		gpuCount := len(systemFreeResources.GPUs)
		log.Infow("machine_free_resources",
			"labels", string(observability.LabelNode),
			"freeCpuCores", systemFreeResources.CPU.Cores,
			"freeRamGB", systemFreeResources.RAM.SizeInGB(),
			"freeGpuCount", gpuCount,
		)
		for idx, gpu := range systemFreeResources.GPUs {
			log.Infow("machine_free_gpu",
				"labels", string(observability.LabelNode),
				"gpuIndex", gpu.Index,
				"gpuModel", gpu.Model,
				"gpuVramGB", gpu.VRAMInGB(),
				"gpuCores", gpu.Cores,
				"gpuLogIndex", idx,
			)
		}
	}
	// Usage check: requested resources vs. what is actually free now.
	if err := validateUsage(config.OnboardedResources, systemFreeResources); err != nil {
		return fmt.Errorf("%w: %v", ErrHighUsage, err)
	}
	return nil
}
// Onboard validates the onboarding params and onboards the machine to the network.
//
// It holds the write lock for the whole operation: prerequisites are
// validated, the resources are registered with the resource manager,
// metrics are recorded, and finally the config is persisted and cached.
func (o *Onboarding) Onboard(ctx context.Context, config types.OnboardingConfig) (types.OnboardingConfig, error) {
	o.Lock.Lock()
	defer o.Lock.Unlock()
	log.Debugw("onboarding machine with config", "labels", string(observability.LabelNode), "config", config)
	if err := o.validatePrerequisites(config); err != nil {
		return types.OnboardingConfig{}, fmt.Errorf("could not validate onboarding prerequisites: %w", err)
	}
	if err := o.ResourceManager.UpdateOnboardedResources(ctx, config.OnboardedResources); err != nil {
		return types.OnboardingConfig{}, fmt.Errorf("could not update onboarded resources: %w", err)
	}
	log.Infow("onboarded_resources_assigned",
		"labels", string(observability.LabelNode),
		"cpuCoresAssigned", config.OnboardedResources.CPU.Cores,
		"ramGBAssigned", config.OnboardedResources.RAM.SizeInGB(),
		// Size is presumably bytes; dividing by 1024*1024 yields MiB — TODO confirm
		"diskMBAssigned", config.OnboardedResources.Disk.Size/(1024.0*1024.0),
		"gpuCountAssigned", len(config.OnboardedResources.GPUs),
	)
	// Metrics are optional: a nil NodeOnboarded means observability is off.
	if m := observability.NodeOnboarded; m != nil {
		m.Add(ctx, 1, metric.WithAttributes(
			observability.AttrDID))
		observability.NodeOnboardedCPU.Record(ctx, float64(config.OnboardedResources.CPU.Cores), metric.WithAttributes(
			observability.AttrDID))
		observability.NodeOnboardedRAM.Record(ctx, int64(config.OnboardedResources.RAM.SizeInGB()), metric.WithAttributes(
			observability.AttrDID))
		observability.NodeOnboardedDisk.Record(ctx, int64(config.OnboardedResources.Disk.Size/(1024.0*1024.0)), metric.WithAttributes(
			observability.AttrDID))
		observability.NodeOnboardedGPU.Record(ctx, int64(len(config.OnboardedResources.GPUs)), metric.WithAttributes(
			observability.AttrDID))
	}
	config.IsOnboarded = true
	// NOTE(review): resources are already registered with the resource
	// manager at this point; a failed save leaves them assigned with no
	// rollback — confirm whether that is acceptable.
	if _, err := o.ConfigRepo.Save(ctx, config); err != nil {
		return types.OnboardingConfig{}, fmt.Errorf("could not save onboarding config: %w", err)
	}
	log.Infow("machine_onboarded_successfully",
		"labels", string(observability.LabelNode))
	// Cache the persisted config and return it to the caller.
	o.Config = config
	return o.Config, nil
}
// Offboard offboards the machine from the network.
//
// It clears the persisted config, resets the cached IsOnboarded flag,
// releases the onboarded resources and zeroes the onboarding metrics.
// Returns ErrMachineNotOnboarded when the machine is not onboarded.
func (o *Onboarding) Offboard(ctx context.Context) error {
	o.Lock.Lock()
	defer o.Lock.Unlock()
	if !o.Config.IsOnboarded {
		return ErrMachineNotOnboarded
	}
	// NOTE(review): the repo is cleared before the resources; a failure in
	// the resource clear below leaves the db cleared but resources still
	// assigned — confirm the ordering is intended.
	err := o.ConfigRepo.Clear(ctx)
	if err != nil {
		return fmt.Errorf("failed to clear onboarding config from db: %w", err)
	}
	o.Config.IsOnboarded = false
	// clear the onboarded resources
	if err := o.ResourceManager.UpdateOnboardedResources(ctx, types.Resources{}); err != nil {
		return fmt.Errorf("could not clear onboarded resources: %w", err)
	}
	// Metrics are optional: a nil NodeOnboarded means observability is off.
	if m := observability.NodeOnboarded; m != nil {
		m.Add(ctx, -1, metric.WithAttributes(
			observability.AttrDID,
		))
		observability.NodeOnboardedCPU.Record(ctx, 0, metric.WithAttributes(
			observability.AttrDID))
		observability.NodeOnboardedRAM.Record(ctx, 0, metric.WithAttributes(
			observability.AttrDID))
		observability.NodeOnboardedDisk.Record(ctx, 0, metric.WithAttributes(
			observability.AttrDID))
		observability.NodeOnboardedGPU.Record(ctx, 0, metric.WithAttributes(
			observability.AttrDID))
	}
	log.Infow("machine_offboarded_successfully",
		"labels", string(observability.LabelNode))
	return nil
}
// IsOnboarded reports whether the machine is currently onboarded,
// reading the cached config under the read lock.
func (o *Onboarding) IsOnboarded() bool {
	o.Lock.RLock()
	onboarded := o.Config.IsOnboarded
	o.Lock.RUnlock()
	return onboarded
}
// Info returns the onboarding configuration: a snapshot of the cached
// config with its dynamic fields freshly populated from the resource
// manager. The read lock is held for the whole call.
func (o *Onboarding) Info(ctx context.Context) (types.OnboardingConfig, error) {
	o.Lock.RLock()
	defer o.Lock.RUnlock()
	snapshot := o.Config
	if err := populateOnboardingConfig(ctx, &snapshot, o.ResourceManager); err != nil {
		return types.OnboardingConfig{}, fmt.Errorf("populate onboarding config: %w", err)
	}
	return snapshot, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package orchestrator
import (
"encoding/json"
"fmt"
"path/filepath"
"time"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/dms/behaviors"
jtypes "gitlab.com/nunet/device-management-service/dms/jobs/types"
"gitlab.com/nunet/device-management-service/observability"
"gitlab.com/nunet/device-management-service/types"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/metric"
)
// handleTaskTermination processes a task-termination notification from an
// allocation: it records the final status in the manifest and, when the
// task ended without error, persists the captured stdout/stderr logs.
func (o *BasicOrchestrator) handleTaskTermination(msg actor.Envelope) {
	// Discard up front; the message body is only read, never replied to.
	msg.Discard()
	var req behaviors.TaskTerminationNotification
	if err := json.Unmarshal(msg.Message, &req); err != nil {
		log.Debugf("unmarshalling task completion request: %s", err)
		return
	}
	log.Infow("task_terminated",
		"labels", []string{string(observability.LabelDeployment)},
		"orchestratorID", o.id,
		"allocationID", req.AllocationID,
		"status", req.Status)
	// Parse the allocation ID to get the manifest key
	allocID, err := types.ParseAllocationID(req.AllocationID)
	if err != nil {
		log.Debugf("failed to parse allocation ID %s: %v", req.AllocationID, err)
		return
	}
	manifestKey := allocID.ManifestKey()
	updateMan := o.Manifest()
	a, ok := updateMan.Allocations[manifestKey]
	if !ok {
		log.Debugf("allocation %s not found on the manifest", req.AllocationID)
		return
	}
	// update allocation status
	a.Status = jtypes.AllocationStatus(req.Status)
	updateMan.Allocations[manifestKey] = a
	o.updateManifest(updateMan)
	// A failed task is logged and stops here: its stdout/stderr are NOT
	// written to disk.
	if req.Error.Err != "" {
		log.Errorf(
			"allocation task %s yielded error: %v",
			req.AllocationID, req.Error,
		)
		return
	}
	allocDir, err := o.WriteAllocationLogs(manifestKey, req.Stdout, req.Stderr)
	if err != nil {
		log.Errorf("failed to write logs for allocation %s: %v", manifestKey, err)
		return
	}
	log.Infow("allocation_logs_saved", "labels", []string{string(observability.LabelDeployment)},
		"manifest", manifestKey, "path", allocDir, "orchestratorID", o.id)
}
// WriteAllocationLogs writes the given stdout/stderr captures under
// <workDir>/deployments/<orchestratorID>/<allocName>/ and returns that
// directory. Empty captures produce no file.
func (o *BasicOrchestrator) WriteAllocationLogs(
	allocName string, stdout, stderr []byte,
) (string, error) {
	allocDir := filepath.Join(o.workDir, "deployments", o.id, allocName)
	if err := o.fs.MkdirAll(allocDir, 0o755); err != nil {
		return "", fmt.Errorf("failed to create allocation directory %s: %w", allocDir, err)
	}
	if len(stdout) > 0 {
		stdoutPath := filepath.Join(allocDir, "stdout.log")
		if err := o.fs.WriteFile(stdoutPath, stdout, 0o644); err != nil {
			return "", fmt.Errorf("failed to write stdout logs to %s: %w", stdoutPath, err)
		}
	}
	if len(stderr) > 0 {
		stderrPath := filepath.Join(allocDir, "stderr.log")
		if err := o.fs.WriteFile(stderrPath, stderr, 0o644); err != nil {
			return "", fmt.Errorf("failed to write stderr logs to %s: %w", stderrPath, err)
		}
	}
	return allocDir, nil
}
// handleAllocationLiveness passively records push heartbeats from
// allocations: sequence number, status, self-reported health and resource
// usage are folded into the cached allocation info, and metrics emitted.
// NOTE: This does NOT affect health decisions - supervisor's pull checks remain authoritative
func (o *BasicOrchestrator) handleAllocationLiveness(msg actor.Envelope) {
	defer msg.Discard()
	var notification jtypes.AllocationLivenessNotification
	if err := json.Unmarshal(msg.Message, &notification); err != nil {
		log.Debugw("unmarshalling_liveness_notification_failed",
			"labels", []string{string(observability.LabelDeployment)},
			"error", err)
		return
	}
	// The allocs map is shared state; hold the lock for the whole handler.
	o.lock.Lock()
	defer o.lock.Unlock()
	// Heartbeats for allocations this orchestrator does not track are dropped.
	if _, ok := o.allocs[notification.AllocationID]; !ok {
		log.Debugw("liveness_notification_for_unknown_allocation",
			"labels", []string{string(observability.LabelDeployment)},
			"allocationID", notification.AllocationID)
		return
	}
	log.Debugw("received_allocation_heartbeat",
		"labels", []string{string(observability.LabelDeployment)},
		"ensembleID", o.id,
		"allocationID", notification.AllocationID,
		"sequence", notification.SequenceNumber,
		"status", notification.Status,
		"healthy", notification.Health.Healthy,
		"check_type", notification.Health.CheckType)
	// Work on a copy of the cached info; written back at the end.
	nInfo := o.allocs[notification.AllocationID]
	nInfo.HeartbeatSeq = notification.SequenceNumber
	// Emit a status metric only on a status transition.
	if nInfo.Status != jtypes.AllocationStatus(notification.Status) {
		if m := observability.AllocationStatus; m != nil {
			m.Record(o.ctx, 1, metric.WithAttributes(
				observability.AttrDID,
				attribute.String("orchestratorID", o.id),
				attribute.String("allocationID", notification.AllocationID),
				attribute.String("status", notification.Status),
			))
		}
	}
	nInfo.Status = jtypes.AllocationStatus(notification.Status)
	if notification.ResourceUsage != nil {
		nInfo.ResourceUsage.CPUUsagePercent = notification.ResourceUsage.CPUUsagePercent
		nInfo.ResourceUsage.MemoryUsedBytes = notification.ResourceUsage.MemoryUsedBytes
		nInfo.ResourceUsage.MemoryLimitBytes = notification.ResourceUsage.MemoryLimitBytes
		nInfo.ResourceUsage.NetworkRxBytes = notification.ResourceUsage.NetworkRxBytes
		nInfo.ResourceUsage.NetworkTxBytes = notification.ResourceUsage.NetworkTxBytes
	}
	// Only allocations with a configured health check update the Health field.
	if o.allocs[notification.AllocationID].HasHealthCheck {
		if notification.Health.Healthy {
			nInfo.Health = "Healthy"
		} else {
			log.Warnw("allocation_self_reported_unhealthy",
				"labels", []string{string(observability.LabelDeployment)},
				"allocationID", notification.AllocationID,
				"message", notification.Health.Message,
				"check_type", notification.Health.CheckType,
				// TODO metric in supervisor for unhealthy
				"note", "supervisor pull checks remain authoritative")
			nInfo.Health = "Unhealthy: " + notification.Health.Message
		}
	}
	nInfo.Timestamp = time.Now().Unix()
	o.allocs[notification.AllocationID] = nInfo
	// Log resource usage if provided
	if notification.ResourceUsage != nil {
		log.Debugw("allocation_resource_usage",
			"labels", []string{string(observability.LabelDeployment)},
			"allocationID", notification.AllocationID,
			"cpu_percent", notification.ResourceUsage.CPUUsagePercent,
			"memory_used_bytes", notification.ResourceUsage.MemoryUsedBytes,
			"memory_limit_bytes", notification.ResourceUsage.MemoryLimitBytes)
	}
	// metrics
	if observability.AllocationHeartbeat != nil {
		u := nInfo.ResourceUsage
		observability.AllocationHeartbeat.Add(o.ctx, 1, metric.WithAttributes(
			observability.AttrDID,
			attribute.String("orchestratorID", o.id),
			attribute.String("allocationID", notification.AllocationID),
			attribute.String("status", notification.Status),
		))
		allocAttrs := metric.WithAttributes(
			observability.AttrDID,
			attribute.String("orchestratorID", o.id),
			attribute.String("allocationID", notification.AllocationID),
		)
		observability.AllocCPUUsage.Record(o.ctx, u.CPUUsagePercent, allocAttrs)
		observability.AllocMemUsed.Record(o.ctx, int64(u.MemoryUsedBytes), allocAttrs)
		observability.AllocMemLimit.Record(o.ctx, int64(u.MemoryLimitBytes), allocAttrs)
		observability.AllocNetRx.Record(o.ctx, int64(u.NetworkRxBytes), allocAttrs)
		observability.AllocNetTx.Record(o.ctx, int64(u.NetworkTxBytes), allocAttrs)
	}
}
// handleAllocationStatusUpdate receives immediate status change notifications
// pushed by allocations and mirrors the new status into the ensemble manifest
// for faster visibility; the supervisor's pull checks remain authoritative and
// will validate/correct the manifest if needed.
func (o *BasicOrchestrator) handleAllocationStatusUpdate(msg actor.Envelope) {
	defer msg.Discard()

	var statusUpdate jtypes.AllocationStatusUpdate
	if err := json.Unmarshal(msg.Message, &statusUpdate); err != nil {
		log.Debugw("unmarshalling_status_update_failed",
			"labels", []string{string(observability.LabelDeployment)},
			"error", err)
		return
	}

	// Record the transition for observability.
	log.Infow("allocation_status_changed",
		"labels", []string{string(observability.LabelDeployment)},
		"ensembleID", o.id,
		"allocationID", statusUpdate.AllocationID,
		"old_status", statusUpdate.OldStatus,
		"new_status", statusUpdate.NewStatus,
		"reason", statusUpdate.Reason)

	parsedID, err := types.ParseAllocationID(statusUpdate.AllocationID)
	if err != nil {
		log.Debugf("failed to parse allocation ID %s: %v", statusUpdate.AllocationID, err)
		return
	}

	key := parsedID.ManifestKey()
	manifest := o.Manifest()
	entry, exists := manifest.Allocations[key]
	if !exists {
		log.Debugf("allocation %s not found on the manifest", statusUpdate.AllocationID)
		return
	}

	// Store the new status back into the manifest.
	entry.Status = jtypes.AllocationStatus(statusUpdate.NewStatus)
	manifest.Allocations[key] = entry
	o.updateManifest(manifest)
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package orchestrator
import (
crand "crypto/rand"
"encoding/json"
"errors"
"fmt"
"math"
"math/big"
"math/rand"
"sync"
"sync/atomic"
"time"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/dms/behaviors"
jtypes "gitlab.com/nunet/device-management-service/dms/jobs/types"
"gitlab.com/nunet/device-management-service/dms/node/geolocation"
"gitlab.com/nunet/device-management-service/executor/docker"
"gitlab.com/nunet/device-management-service/observability"
"gitlab.com/nunet/device-management-service/types"
)
// ErrCandidateNotFound is returned by bid when no candidate deployment could
// be assembled (not enough bids, or no permutation satisfied the edge
// constraints) before the deadline.
var ErrCandidateNotFound = errors.New("candidate not found")

// BidCoordinator handles the bidding process for ensemble deployment
type BidCoordinator struct {
	eid   string                       // ensembleID
	actor actor.Actor                  // actor used to send/broadcast bid requests and receive replies
	geo   geolocation.LocationProvider // resolves bid locations to coordinates for RTT estimation
	nonce uint64                       // bid-request nonce; incremented atomically, see getNonce
}
// NewBidCoordinator creates a new BidCoordinator for ensemble eid using the
// given actor; it also constructs the geolocator used for RTT-constraint
// pre-filtering of bids.
func NewBidCoordinator(
	eid string, actor actor.Actor,
) (*BidCoordinator, error) {
	locator, err := geolocation.NewGeoLocator()
	if err != nil {
		return nil, fmt.Errorf("failed to create geolocator: %w", err)
	}
	coordinator := &BidCoordinator{
		eid:   eid,
		actor: actor,
		geo:   locator,
	}
	return coordinator, nil
}
// bid handles the bid process from beginning to end
//
// It requests bids from the network for every node in the ensemble config,
// filters and collects them, and then searches the space of per-node bid
// assignments (permutations) for one that satisfies the edge constraints.
// If some nodes receive no bids, residual bid requests are retried until
// expiry. The candidates map carries previously accepted bids so that
// peer-pinned nodes excluded from re-bidding keep their bid.
//
// TODO: update deployment status when Generating
func (b *BidCoordinator) bid(cfgReader jtypes.EnsembleCfgReader, candidates map[string]jtypes.Bid, expiry time.Time) (map[string]jtypes.Bid, error) {
	cfg := cfgReader.Read() // read cfg copy
	candidate := make(map[string]jtypes.Bid)
	edgeConstraintCache := make(map[string]bool)
	// 0. check if one of the ensemble nodes have peer specified
	// If bid request to peer specified fails, the entire deployment must fail
	nodeForTargetPeer := make(map[string]string)
	for nodeID, node := range cfg.Nodes() {
		if node.Peer != "" {
			nodeForTargetPeer[node.Peer] = nodeID
		}
	}
	// 1. Create bid requests for nodes
	log.Debugw("creating initial bid request",
		"labels", []string{string(observability.LabelDeployment)},
		"orchestratorID", b.eid,
		"nodes: ", cfg.Nodes())
	bidrq, err := b.makeInitialBidRequest(cfg)
	if err != nil {
		return candidate, fmt.Errorf("creating bid request: %w", err)
	}
	// 2. Collect bids
	log.Debugw("collecting bids",
		"labels", []string{string(observability.LabelDeployment)},
		"orchestratorID", b.eid)
	bidMap := make(map[string][]jtypes.Bid)
	peerExclusion := make(map[string]struct{})
	// do not bid peers excluded from config
	for _, peerID := range cfg.V1.ExcludePeers {
		peerExclusion[peerID] = struct{}{}
		if _, ok := nodeForTargetPeer[peerID]; ok {
			if bid, ok := candidates[nodeForTargetPeer[peerID]]; ok {
				bidMap[nodeForTargetPeer[peerID]] = append(bidMap[nodeForTargetPeer[peerID]], bid)
			}
		}
	}
	// addBid filters and records an incoming bid; it returns true only when
	// the bid is accepted into bidMap. At most one bid per peer is kept
	// (tracked via peerExclusion) and each node keeps at most
	// MaxBidMultiplier bids to bound the permutation space.
	addBid := func(bid jtypes.Bid) bool {
		// if peer is already specified on another node, ignore the bid
		if _, ok := nodeForTargetPeer[bid.Peer()]; ok {
			if nodeForTargetPeer[bid.Peer()] != bid.NodeID() {
				return false
			}
		}
		// check that the peer has not already submitted a bid
		peerID := bid.Peer()
		if _, exclude := peerExclusion[peerID]; exclude {
			log.Debugw("ignoring duplicate bid from peer",
				"labels", []string{string(observability.LabelDeployment)},
				"peerID", peerID)
			return false
		}
		err := bid.Validate()
		if err != nil {
			log.Debugw("failed to validate bid",
				"labels", []string{string(observability.LabelDeployment)},
				"peerID", peerID,
				"error", err)
			return false
		}
		// verify that this is a node in the ensemble
		nodeID := bid.NodeID()
		if _, ok := cfg.Node(nodeID); !ok {
			log.Debugw("ignoring bid for unknown node",
				"labels", []string{string(observability.LabelDeployment)},
				"peerID", peerID,
				"nodeID", nodeID)
			return false
		}
		// verify the location constraints of the node
		loc := bid.Location()
		if !acceptPeerLocation(cfg, nodeID, peerID, loc) {
			log.Debugw("ignoring out-of-location bid",
				"labels", []string{string(observability.LabelDeployment)},
				"peerID", peerID,
				"nodeID", nodeID,
				"location", loc,
			)
			return false
		}
		// don't bloat the permutation space
		if len(bidMap[nodeID]) >= MaxBidMultiplier {
			log.Debugw("node is saturated, ignoring new bid",
				"labels", []string{string(observability.LabelDeployment)},
				"peerID", peerID,
				"nodeID", nodeID)
			return false
		}
		log.Infof("added bid to bidMap from peer %s for %s", peerID, nodeID)
		bidMap[nodeID] = append(bidMap[nodeID], bid)
		peerExclusion[peerID] = struct{}{}
		return true
	}
	// remove bid from bidMap and peerExclusion
	// (used by residual rounds to re-open a node for bidding)
	rmBid := func(bid jtypes.Bid) {
		peerID := bid.Peer()
		delete(peerExclusion, peerID)
		nodeID := bid.NodeID()
		bids := bidMap[nodeID]
		for i, b := range bids {
			if b.Peer() == peerID {
				bidMap[nodeID] = append(bids[:i], bids[i+1:]...)
				break
			}
		}
	}
	bidCh, bidDoneCh, bidExpiryTime, err := b.requestBids(cfg, bidrq, expiry)
	if err != nil {
		return candidate, fmt.Errorf("request bids: %w", err)
	}
	maxBids := MaxBidMultiplier * len(cfg.Nodes())
	b.collectBids(bidCh, bidDoneCh, bidExpiryTime, addBid, maxBids)
	// 3. Create a candidate deployment
	log.Debugw("creating candidate deployments",
		"labels", []string{string(observability.LabelDeployment)},
		"orchestratorID", b.eid)
	var (
		nextCandidate func() (map[string]jtypes.Bid, bool)
		ok            bool
	)
	for time.Now().Before(expiry) {
		nextCandidate, ok = b.makeCandidateDeployments(cfg, bidMap)
		if ok {
			break
		}
		// we don't have bids for some of our nodes, so we don't have a candidate
		// we need to make a residual bid request for the remaining nodes
		// Note: in order to facilitate random selection, the residual bid requests
		// can drop some of the original bids
		log.Debugw("not enough bids for all nodes, making residual request",
			"labels", []string{string(observability.LabelDeployment)},
			"orchestratorID", b.eid)
		bidrq, err := b.makeResidualBidRequest(cfg, bidMap, rmBid)
		if err != nil {
			return candidate, fmt.Errorf("creating residual bid request: %w", err)
		}
		bidCh, bidDoneCh, bidExpiryTime, err := b.requestBids(cfg, bidrq, expiry)
		if err != nil {
			return candidate, fmt.Errorf("collecting residual bids: %w", err)
		}
		maxBids := MaxBidMultiplier * (len(cfg.Nodes()) - len(bidMap))
		b.collectBids(bidCh, bidDoneCh, bidExpiryTime, addBid, maxBids)
	}
	if !ok {
		log.Debugw("failed to create candidate deployments, retrying",
			"labels", []string{string(observability.LabelDeployment)},
			"orchestratorID", b.eid)
		return candidate,
			fmt.Errorf("%w: failed to create candidate deployments - trying again",
				ErrCandidateNotFound)
	}
	for n, bids := range bidMap {
		bidList := make([]string, 0, len(bids))
		for _, bid := range bids {
			bidList = append(bidList, bid.Peer())
		}
		log.Infow("node has bids",
			"labels", []string{string(observability.LabelDeployment)},
			"name", n, "amount", len(bids),
			// TODO remove once arrays supported in log scripts
			"bids", struct{ Peers []string }{bidList})
	}
	// 4. Iterate through the candidates trying to find one that satisfies the
	// edge constraints
	log.Debugf("generating candidate deployment")
	for time.Now().Before(expiry) {
		candidate, ok = nextCandidate()
		if !ok {
			return candidate,
				fmt.Errorf("%w: failed to find candidate that satisfies edge constraints",
					ErrCandidateNotFound)
		}
		log.Debugf("candidate deployment: %+v", candidate)
		if ok := b.verifyEdgeConstraints(cfg, candidate, edgeConstraintCache); !ok {
			log.Debugf("candidate does not satisfy edge constraints")
			continue
		}
		break
	}
	return candidate, nil
}
// requestBids fans out the given bid request: nodes pinned to a specific
// peer are messaged directly, while all remaining nodes are covered by a
// single broadcast on the bid-request topic. A one-shot reply behavior is
// registered that forwards incoming bids to the returned channel until the
// done channel is closed or the bid window elapses.
//
// It returns the bid channel, the done channel (closed by collectBids when
// collection ends), and the bid-window deadline. It fails when the overall
// deployment expiry leaves no room for a full bid window.
func (b *BidCoordinator) requestBids(
	cfg jtypes.EnsembleConfig,
	bidRequest jtypes.EnsembleBidRequest, expiry time.Time,
) (chan jtypes.Bid, chan struct{}, time.Time, error) {
	log.Debugw("requesting_bids", "labels", []string{string(observability.LabelDeployment)}, "request", bidRequest)
	bidExpiryTime := time.Now().Add(BidRequestTimeout)
	if expiry.Before(bidExpiryTime) {
		return nil, nil, time.Time{}, fmt.Errorf("not enough time for deployment: %w", ErrDeploymentFailed)
	}
	bidExpiry := uint64(bidExpiryTime.UnixNano())
	// Split requests into direct peer requests and broadcast requests
	var directRequests []jtypes.BidRequest
	var broadcastRequests []jtypes.BidRequest
	for _, req := range bidRequest.Request {
		if req.V1 == nil {
			continue
		}
		nodeConfig, ok := cfg.Node(req.V1.NodeID)
		if !ok {
			continue
		}
		if nodeConfig.Peer != "" {
			// This node has a specific peer target
			directRequests = append(directRequests, req)
		} else {
			// This node needs broadcast
			broadcastRequests = append(broadcastRequests, req)
		}
	}
	// Send direct peer requests
	// NOTE(review): direct requests are sent before the reply behavior is
	// registered below — confirm a very fast reply cannot be dropped in
	// that window.
	log.Debugf("sending direct peer requests: %+v", directRequests)
	for _, req := range directRequests {
		nodeConfig, _ := cfg.Node(req.V1.NodeID)
		targetedReq := jtypes.EnsembleBidRequest{
			ID:            bidRequest.ID,
			Nonce:         bidRequest.Nonce,
			Request:       []jtypes.BidRequest{req},
			PeerExclusion: bidRequest.PeerExclusion,
		}
		err := b.requestBidPeer(targetedReq, nodeConfig, bidExpiry)
		if err != nil {
			return nil, nil, time.Time{}, fmt.Errorf("requesting bid to targeted peer: %w", err)
		}
	}
	// create reply behavior for this specific ensemble bid request
	bidCh := make(chan jtypes.Bid)
	bidDoneCh := make(chan struct{})
	if err := b.actor.AddBehavior(
		behaviors.BidReplyBehavior,
		func(msg actor.Envelope) {
			defer msg.Discard()
			var bid jtypes.Bid
			if err := json.Unmarshal(msg.Message, &bid); err != nil {
				log.Errorw("failed to unmarshal bid",
					"labels", []string{string(observability.LabelDeployment)},
					"from", msg.From,
					"error", err)
				return
			}
			log.Infow("deployment_bid",
				"labels", []string{string(observability.LabelDeployment)},
				"from", msg.From)
			// hand the bid to the collector, but never block past the
			// bid window or after collection has finished
			timer := time.NewTimer(time.Until(bidExpiryTime))
			defer timer.Stop()
			select {
			case bidCh <- bid:
			case <-timer.C:
			case <-bidDoneCh:
			}
		},
		actor.WithBehaviorExpiry(bidExpiry),
	); err != nil {
		return nil, nil, time.Time{}, fmt.Errorf("adding bid behavior: %w", err)
	}
	// Send broadcast
	log.Debugf("sending broadcast requests: %+v", broadcastRequests)
	if len(broadcastRequests) > 0 {
		broadcastReq := jtypes.EnsembleBidRequest{
			ID:            bidRequest.ID,
			Nonce:         bidRequest.Nonce,
			Request:       broadcastRequests,
			PeerExclusion: bidRequest.PeerExclusion,
		}
		err := b.broadcastBid(broadcastReq, bidExpiry)
		if err != nil {
			return nil, nil, time.Time{}, fmt.Errorf("broadcasting bid request: %w", err)
		}
	}
	return bidCh, bidDoneCh, bidExpiryTime, nil
}
// broadcastBid publishes bidRequest on the bid-request topic so any capable
// provider may reply; replies are routed back via the bid-reply behavior
// until bidExpiry.
func (b *BidCoordinator) broadcastBid(bidRequest jtypes.EnsembleBidRequest, bidExpiry uint64) error {
	broadcast, err := actor.Message(
		b.actor.Handle(),
		actor.Handle{},
		behaviors.BidRequestBehavior,
		bidRequest,
		actor.WithMessageTopic(behaviors.BidRequestTopic),
		actor.WithMessageReplyTo(behaviors.BidReplyBehavior),
		actor.WithMessageExpiry(bidExpiry),
	)
	if err != nil {
		return fmt.Errorf("creating broadcast bid message: %w", err)
	}
	if err = b.actor.Publish(broadcast); err != nil {
		return fmt.Errorf("publishing broadcast bid request: %w", err)
	}
	return nil
}
// requestBidPeer sends a targeted bid request directly to the peer pinned in
// nodeConfig, with replies routed to the bid-reply behavior before bidExpiry.
func (b *BidCoordinator) requestBidPeer(
	targetedReq jtypes.EnsembleBidRequest, nodeConfig jtypes.NodeConfig, bidExpiry uint64,
) error {
	target, err := actor.HandleFromPeerID(nodeConfig.Peer)
	if err != nil {
		return fmt.Errorf("getting handle of selected peer %s: %w", nodeConfig.Peer, err)
	}
	log.Infof("sending direct peer request to %s: %+v", nodeConfig.Peer, targetedReq)
	request, err := actor.Message(
		b.actor.Handle(),
		target,
		behaviors.BidRequestBehavior,
		targetedReq,
		actor.WithMessageReplyTo(behaviors.BidReplyBehavior),
		actor.WithMessageExpiry(bidExpiry),
	)
	if err != nil {
		return fmt.Errorf("creating targeted bid message: %w", err)
	}
	log.Infow("requesting bid from targeted peer",
		"labels", []string{string(observability.LabelDeployment)},
		"peerID", nodeConfig.Peer,
		"orchestratorID", b.eid)
	if err = b.actor.Send(request); err != nil {
		return fmt.Errorf("sending targeted bid request: %w", err)
	}
	return nil
}
// collectBids drains bidCh until the bid window elapses, the channel is
// closed, or maxBids bids have been accepted via addBid. It closes bidDoneCh
// on return so reply handlers still blocked on bidCh are released.
func (b *BidCoordinator) collectBids(
	bidCh chan jtypes.Bid, bidDoneCh chan struct{}, bidExpiryTime time.Time,
	addBid func(jtypes.Bid) bool, maxBids int,
) {
	defer close(bidDoneCh)
	log.Debugf("collecting bids until: %v", bidExpiryTime)
	deadline := time.NewTimer(time.Until(bidExpiryTime))
	defer deadline.Stop()
	accepted := 0
	for {
		select {
		case <-deadline.C:
			// bid window over
			return
		case incoming, open := <-bidCh:
			if !open {
				log.Debugw("bid channel closed",
					"labels", []string{string(observability.LabelDeployment)})
				return
			}
			log.Debugw("received bid",
				"labels", []string{string(observability.LabelDeployment)},
				"ensembleID", incoming.EnsembleID(),
				"peerID", incoming.Peer(),
				"nodeID", incoming.NodeID())
			if err := incoming.Validate(); err != nil {
				log.Warnw("invalid bid",
					"ensembleID", incoming.EnsembleID(),
					"peerID", incoming.Peer(),
					"nodeID", incoming.NodeID(),
					"labels", []string{string(observability.LabelDeployment)},
					"error", err)
				continue
			}
			if incoming.EnsembleID() != b.eid {
				log.Warnw("bid for unexpected ensemble id",
					"labels", []string{string(observability.LabelDeployment)},
					"expectedID", b.eid,
					"gotID", incoming.EnsembleID(),
					"peerID", incoming.Peer())
				continue
			}
			if !addBid(incoming) {
				continue
			}
			accepted++
			if accepted >= maxBids {
				return
			}
		}
	}
}
// makeCandidateDeployments returns a generator of random candidate bid
// assignments (one bid per node), or false when some node has no bids at
// all. It dispatches to an int64- or big.Int-based generator depending on
// the size of the permutation space.
func (b *BidCoordinator) makeCandidateDeployments(
	cfg jtypes.EnsembleConfig, bids map[string][]jtypes.Bid,
) (func() (map[string]jtypes.Bid, bool), bool) {
	// every node needs at least one bid, otherwise no assignment exists
	if len(bids) != len(cfg.Nodes()) {
		return nil, false
	}
	// randomize the per-node bid order to seed the permutation generator
	for _, nodeBids := range bids {
		rand.Shuffle(len(nodeBids), func(a, z int) {
			nodeBids[a], nodeBids[z] = nodeBids[z], nodeBids[a]
		})
	}
	// estimate the permutation space in bits; past 63 bits an int64 index
	// would overflow, so fall back to the big.Int generator
	totalBits := 0
	for _, nodeBids := range bids {
		totalBits += int(math.Ceil(math.Log2(float64(len(nodeBids)))))
	}
	if totalBits > 63 {
		return b.makeCandidateDeploymentBig(cfg, bids)
	}
	return b.makeCandidateDeploymentSmall(cfg, bids)
}
// makeCandidateDeploymentSmall returns a generator of random candidate
// deployments using int64 mixed-radix arithmetic; it is valid while the
// permutation space fits in 63 bits (see makeCandidateDeployments).
func (b *BidCoordinator) makeCandidateDeploymentSmall(
	cfg jtypes.EnsembleConfig, bids map[string][]jtypes.Bid,
) (func() (map[string]jtypes.Bid, bool), bool) {
	// fix the order of permutation
	type permutator struct {
		mod  int64 // product of the bid-list lengths of all preceding nodes
		node string
		bids []jtypes.Bid
	}
	permutators := make([]permutator, 0, len(bids))
	modulus := int64(1)
	for n, blst := range bids {
		permutators = append(permutators, permutator{mod: modulus, node: n, bids: blst})
		modulus *= int64(len(blst))
	}
	// function to get a permutation by index: each node's selection is a
	// digit of the mixed-radix representation of the index
	getPermutation := func(index int64) map[string]jtypes.Bid {
		result := make(map[string]jtypes.Bid)
		for _, permutator := range permutators {
			selection := (index / permutator.mod) % int64(len(permutator.bids))
			result[permutator.node] = permutator.bids[selection]
		}
		return result
	}
	// and return a function that gets a random next permutation
	// note that we cache the constraint results, so potential duplication is ok.
	// also note that the permutation space is large enough so that it's ok to skip
	// some permutations.
	// final note: Obviously we can deterministically generate all permutations in order
	// (and we were doing that initially) but this has the problem that we are not
	// modifying the network structure enough to get meaningful variance in a reasonable
	// time.
	nperm := modulus
	if nperm > int64(MaxPermutations) {
		nperm = int64(MaxPermutations)
	}
	count := int64(0)
	return func() (map[string]jtypes.Bid, bool) {
		for count < nperm {
			count++
			nextPerm := rand.Int63n(nperm)
			perm := getPermutation(nextPerm)
			// cheap geography-based pre-filter before the expensive
			// active verification done by the caller
			if !b.checkPermutationEdgeConstraints(cfg, perm) {
				continue
			}
			return perm, true
		}
		return nil, false
	}, true
}
// makeCandidateDeploymentBig is the big.Int variant of
// makeCandidateDeploymentSmall, used when the permutation space exceeds 63
// bits; random indices come from crypto/rand bytes sized to the modulus.
func (b *BidCoordinator) makeCandidateDeploymentBig(
	cfg jtypes.EnsembleConfig, bids map[string][]jtypes.Bid,
) (func() (map[string]jtypes.Bid, bool), bool) {
	// Note: this is the same as above with bignums
	// fix the order of permutation
	type permutator struct {
		mod  *big.Int // product of the bid-list lengths of all preceding nodes
		node string
		bids []jtypes.Bid
	}
	permutators := make([]permutator, 0, len(bids))
	modulus := big.NewInt(1)
	for n, blst := range bids {
		permutators = append(permutators, permutator{mod: modulus, node: n, bids: blst})
		// a fresh big.Int is allocated so the mod stored above is not mutated
		modulus = new(big.Int).Mul(modulus, big.NewInt(int64(len(blst))))
	}
	// function to get a permutation by index: each node's selection is a
	// digit of the mixed-radix representation of the index
	getPermutation := func(index *big.Int) map[string]jtypes.Bid {
		result := make(map[string]jtypes.Bid)
		for _, permutator := range permutators {
			selection := int(
				new(big.Int).Mod(
					new(big.Int).Div(index, permutator.mod),
					big.NewInt(int64(len(permutator.bids))),
				).Int64(),
			)
			result[permutator.node] = permutator.bids[selection]
		}
		return result
	}
	// and return a function that gets a random next permutation
	// note that we cache the constraint results, so potential duplication is ok.
	// also note that the permutation space is large enough so that it's ok to skip
	// some permutations.
	// final note: Obviously we can deterministically generate all permutations in order
	// (and we were doing that initially) but this has the problem that we are not
	// modifying the network structure enough to get meaningful variance in a reasonable
	// time.
	nperm := MaxPermutations
	count := 0
	bytes := make([]byte, (modulus.BitLen()+7)/8)
	return func() (map[string]jtypes.Bid, bool) {
		for count < nperm {
			count++
			if _, err := crand.Read(bytes); err != nil {
				log.Errorw("random_bytes_read_error",
					"labels", []string{string(observability.LabelDeployment)},
					"error", err)
				return nil, false
			}
			// the drawn index may exceed the modulus; the per-digit Mod in
			// getPermutation keeps selections in range regardless
			nextPerm := new(big.Int).SetBytes(bytes)
			perm := getPermutation(nextPerm)
			if !b.checkPermutationEdgeConstraints(cfg, perm) {
				continue
			}
			return perm, true
		}
		return nil, false
	}, true
}
// checkPermutationEdgeConstraints performs a cheap, geography-based sanity
// check of a candidate: for every RTT-bounded edge, the speed-of-light round
// trip over the geodesic between the two bids' locations must not already
// exceed the bound. Edges whose coordinates cannot be resolved are skipped.
func (b *BidCoordinator) checkPermutationEdgeConstraints(
	cfg jtypes.EnsembleConfig, candidate map[string]jtypes.Bid,
) bool {
	for _, constraint := range cfg.EdgeConstraints() {
		if constraint.RTT == 0 {
			// no RTT bound on this edge
			continue
		}
		srcBid, haveSrc := candidate[constraint.S]
		if !haveSrc {
			log.Errorf("Bid %s not found in candidate", constraint.S)
			return false
		}
		dstBid, haveDst := candidate[constraint.T]
		if !haveDst {
			log.Errorf("Bid %s not found in candidate", constraint.T)
			return false
		}
		srcCoord, err := b.geo.Coordinate(srcBid.Location())
		if err != nil {
			log.Errorf("Failed to get location for bid %s: %v", srcBid.NodeID(), err)
			continue
		}
		dstCoord, err := b.geo.Coordinate(dstBid.Location())
		if err != nil {
			log.Errorf("Failed to get location for bid %s: %v", dstBid.NodeID(), err)
			continue
		}
		// physical lower bound on RTT, in milliseconds
		lowerBoundRTT := geolocation.ComputeGeodesic(srcCoord, dstCoord) / geolocation.LightSpeed * 2 * 1000
		if lowerBoundRTT > float64(constraint.RTT) {
			log.Debugw("edge constraint not satisfied",
				"labels", []string{string(observability.LabelDeployment)},
				"minRTT", lowerBoundRTT,
				"constraint", constraint.RTT,
				"from", constraint.S,
				"to", constraint.T)
			return false
		}
		// TODO: add bandwidth check when that information becomes available
	}
	return true
}
// verifyEdgeConstraints actively verifies every edge constraint of the
// candidate by querying the involved peers in parallel. Results are memoized
// in cache, keyed by "peerS:peerT", so repeated candidates sharing edges are
// not re-measured. Returns true iff every constraint holds.
func (b *BidCoordinator) verifyEdgeConstraints(
	cfg jtypes.EnsembleConfig, candidate map[string]jtypes.Bid, cache map[string]bool,
) bool {
	var mx sync.Mutex
	var wg sync.WaitGroup
	var toVerify []jtypes.EdgeConstraint
	// first pass: consult the cache, collecting misses for verification;
	// any cached rejection fails the candidate immediately
	for _, cst := range cfg.EdgeConstraints() {
		bidS := candidate[cst.S]
		bidT := candidate[cst.T]
		key := bidS.Peer() + ":" + bidT.Peer()
		accept, ok := cache[key]
		if !ok {
			toVerify = append(toVerify, cst)
			continue
		}
		if !accept {
			return false
		}
	}
	if len(toVerify) == 0 {
		return true
	}
	// second pass: verify all cache misses concurrently, one goroutine per
	// edge; mx guards both the cache and the accept accumulator
	accept := true
	wg.Add(len(toVerify))
	for _, cst := range toVerify {
		go func(cst jtypes.EdgeConstraint) {
			defer wg.Done()
			result := b.verifyEdgeConstraint(candidate, cst)
			bidS := candidate[cst.S]
			bidT := candidate[cst.T]
			key := bidS.Peer() + ":" + bidT.Peer()
			mx.Lock()
			cache[key] = result
			accept = accept && result
			mx.Unlock()
		}(cst)
	}
	wg.Wait()
	return accept
}
// VerifyEdgeConstraintRequest asks a peer to verify that the network edge
// S->T satisfies the given RTT/bandwidth bounds for an ensemble deployment.
type VerifyEdgeConstraintRequest struct {
	EnsembleID string // the ensemble identifier
	S, T       string // the peer IDs of the edge S->T
	RTT        uint   // maximum RTT in ms (if > 0)
	BW         uint   // minimum BW in Kbps
}

// VerifyEdgeConstraintResponse is the peer's verdict on a
// VerifyEdgeConstraintRequest; Error carries a reason when the constraint is
// not satisfied.
type VerifyEdgeConstraintResponse struct {
	OK    bool
	Error string
}
// verifyEdgeConstraint asks the peer behind cst.S (via the
// VerifyEdgeConstraintBehavior) to verify the S->T edge against the
// constraint's RTT/BW bounds, waiting up to VerifyEdgeConstraintTimeout for
// the reply. Any messaging or decoding failure counts as "not satisfied".
func (b *BidCoordinator) verifyEdgeConstraint(candidate map[string]jtypes.Bid, cst jtypes.EdgeConstraint) bool {
	bidS := candidate[cst.S]
	bidT := candidate[cst.T]
	key := bidS.Peer() + ":" + bidT.Peer() // edge key, used only for logging
	log.Debugw("verifying edge constraint",
		"labels", []string{string(observability.LabelDeployment)},
		"peerS", bidS.Peer(),
		"peerT", bidT.Peer(),
		"constraint", cst)
	handle := bidS.Handle()
	msg, err := actor.Message(
		b.actor.Handle(),
		handle,
		behaviors.VerifyEdgeConstraintBehavior,
		VerifyEdgeConstraintRequest{
			EnsembleID: b.eid,
			S:          bidS.Peer(),
			T:          bidT.Peer(),
			RTT:        cst.RTT,
			BW:         cst.BW,
		},
		actor.WithMessageTimeout(VerifyEdgeConstraintTimeout),
	)
	if err != nil {
		log.Warnw("creating constraint check message error",
			"labels", []string{string(observability.LabelDeployment)},
			"edgeKey", key,
			"error", err)
		return false
	}
	replyCh, err := b.actor.Invoke(msg)
	if err != nil {
		log.Warnw("invoke constraint check error",
			"labels", []string{string(observability.LabelDeployment)},
			"edgeKey", key,
			"error", err)
		return false
	}
	// wait for the reply, but never longer than the verification timeout
	var reply actor.Envelope
	select {
	case reply = <-replyCh:
	case <-time.After(VerifyEdgeConstraintTimeout):
		return false
	}
	defer reply.Discard()
	var response VerifyEdgeConstraintResponse
	if err := json.Unmarshal(reply.Message, &response); err != nil {
		log.Warnw("unmarshal bid constraint response error",
			"labels", []string{string(observability.LabelDeployment)},
			"edgeKey", key,
			"error", err)
		return false
	}
	if response.Error != "" {
		log.Debugw("verify bid constraint not satisfied",
			"labels", []string{string(observability.LabelDeployment)},
			"edgeKey", key,
			"error", response.Error)
	}
	return response.OK
}
// acceptPeerLocation reports whether peerID at location loc is an acceptable
// host for nodeID: a node pinned to a specific peer accepts only that peer,
// otherwise the bid's location must satisfy the node's location constraint.
func acceptPeerLocation(
	cfg jtypes.EnsembleConfig, nodeID, peerID string, loc jtypes.Location,
) bool {
	nodeCfg, found := cfg.Node(nodeID)
	if !found {
		return false
	}
	// explicit peer placement overrides location constraints
	if pinned := nodeCfg.Peer; pinned != "" {
		return pinned == peerID
	}
	return loc.Satisfies(nodeCfg.Location)
}
// makeInitialBidRequest builds the first bid request, covering every node in
// the ensemble configuration.
func (b *BidCoordinator) makeInitialBidRequest(cfg jtypes.EnsembleConfig) (jtypes.EnsembleBidRequest, error) {
	request, err := b.ensembleConfigToBidRequest(&cfg)
	return request, err
}
// makeResidualBidRequest builds a bid request covering only the nodes that
// still lack bids. To increase variance across retry rounds it randomly
// drops a portion of the bids already collected (removing them from the
// caller's bookkeeping via rmbid), re-opening those nodes for bidding; peers
// whose bids were kept are added to the request's PeerExclusion list so they
// do not bid twice.
//
// Fix: nodes whose bid list has been emptied by rmbid in a previous residual
// round remain keyed in the candidate map with a zero-length slice, and
// rand.Intn panics on a zero argument — such entries are now skipped.
func (b *BidCoordinator) makeResidualBidRequest(
	cfg jtypes.EnsembleConfig,
	candidate map[string][]jtypes.Bid, rmbid func(jtypes.Bid),
) (jtypes.EnsembleBidRequest, error) {
	residualConfig := jtypes.EnsembleConfig{
		V1: &jtypes.EnsembleConfigV1{
			Allocations: make(map[string]jtypes.AllocationConfig),
			Nodes:       make(map[string]jtypes.NodeConfig),
			Contracts:   cfg.Contracts(),
		},
	}
	// randomly drop some of the candidate bids and exclusion
	newCandidates := make(map[string][]jtypes.Bid)
	newExclusion := make(map[string]struct{})
	// keep a random prefix of each node's bids and drop the rest
	for n, bids := range candidate {
		// guard: rand.Intn(0) panics; an entry may have been emptied by
		// rmbid in an earlier residual round
		if len(bids) == 0 {
			continue
		}
		newBids := make([]jtypes.Bid, 0, len(bids))
		desiredSize := int(math.Floor(float64(rand.Intn(len(bids))) / 2))
		for i, bid := range bids {
			if i > desiredSize {
				log.Infof("dropping bid from %s (%s) from candidate ", bid.Peer(), bid.V1.Handle.DID)
				rmbid(bid)
				continue
			}
			log.Infof("keeping bid from %s (%s) for candidate", bid.Peer(), bid.V1.Handle.DID)
			newBids = append(newBids, bid)
			newExclusion[bid.Peer()] = struct{}{}
		}
		if len(newBids) > 0 {
			newCandidates[n] = newBids
		}
	}
	// nodes that still hold kept bids are satisfied; everything else goes
	// into the residual config to be re-bid
	for n, ncfg := range cfg.Nodes() {
		if _, satisfied := newCandidates[n]; satisfied {
			log.Debugw(
				fmt.Sprintf("node %s is in candidate, skipping", n),
				"labels", []string{string(observability.LabelDeployment)},
				"nodeID", n,
			)
			continue
		}
		residualConfig.V1.Nodes[n] = ncfg
	}
	for id, ncfg := range residualConfig.V1.Nodes {
		log.Debugw(
			fmt.Sprintf("still looking for node %s", id),
			"labels", []string{string(observability.LabelDeployment)},
			"nodeID", id,
		)
		// carry over the allocation configs referenced by this node
		for _, a := range ncfg.Allocations {
			residualConfig.V1.Allocations[a] = cfg.V1.Allocations[a]
		}
	}
	result, err := b.ensembleConfigToBidRequest(&residualConfig)
	if err != nil {
		return result, err
	}
	// exclude peers whose bids we kept so they do not submit duplicates
	for p := range newExclusion {
		result.PeerExclusion = append(result.PeerExclusion, p)
	}
	return result, nil
}
// ensembleConfigToBidRequest translates an ensemble configuration into a bid
// request: one BidRequest per node, carrying the node's aggregate resources,
// the executors it needs, its public-port requirements, and the ensemble
// contracts.
//
// Fix: the contract-count message was logged through log.Debugf (a
// printf-style API) with structured key/value arguments; it now uses
// log.Debugw like every other structured log call in this file.
func (b *BidCoordinator) ensembleConfigToBidRequest(config *jtypes.EnsembleConfig) (jtypes.EnsembleBidRequest, error) {
	v1Config := config.V1
	ensembleBidRequest := jtypes.EnsembleBidRequest{
		ID:    b.eid,
		Nonce: b.getNonce(),
	}
	log.Infow("generating bid request",
		"labels", []string{string(observability.LabelDeployment)},
		"orchestratorID", b.eid,
		"nodes", v1Config.Nodes)
	// Log contract information
	if len(v1Config.Contracts) > 0 {
		log.Debugw("including contracts in bid request",
			"labels", []string{string(observability.LabelDeployment)},
			"orchestratorID", b.eid,
			"contractCount", len(v1Config.Contracts))
	}
	nodes := config.Nodes()
	for nodeID, nodeConfig := range nodes {
		bidRequest := jtypes.BidRequest{
			V1: &jtypes.BidRequestV1{
				NodeID:    nodeID,
				Location:  nodeConfig.Location,
				Contracts: v1Config.Contracts,
			},
		}
		var aggregateResources types.Resources
		var executors []jtypes.AllocationExecutor
		var staticPorts []int
		dynamicPortsCount := 0
		for _, allocationName := range nodeConfig.Allocations {
			allocationConfig, ok := v1Config.Allocations[allocationName]
			if !ok {
				continue
			}
			if !containsExecutor(executors, allocationConfig.Executor) {
				executors = append(executors, allocationConfig.Executor)
			}
			if allocationConfig.Executor == jtypes.ExecutorDocker {
				// check if bid includes allocation requiring privileged docker
				dockerCfg, err := docker.DecodeSpec(&allocationConfig.Execution)
				if err != nil {
					return jtypes.EnsembleBidRequest{}, fmt.Errorf("decoding docker spec: %w", err)
				}
				if dockerCfg.Privileged {
					bidRequest.V1.GeneralRequirements.PrivilegedDocker = true
				}
			}
			err := aggregateResources.Add(allocationConfig.Resources)
			if err != nil {
				return jtypes.EnsembleBidRequest{}, err
			}
			// count this allocation's public ports: Public == 0 means
			// "assign dynamically"
			for _, portConfig := range nodeConfig.Ports {
				if portConfig.Allocation == allocationName {
					if portConfig.Public == 0 {
						dynamicPortsCount++
					} else {
						staticPorts = append(staticPorts, portConfig.Public)
					}
				}
			}
		}
		bidRequest.V1.Executors = executors
		bidRequest.V1.Resources = aggregateResources
		bidRequest.V1.PublicPorts.Static = staticPorts
		bidRequest.V1.PublicPorts.Dynamic = dynamicPortsCount
		ensembleBidRequest.Request = append(ensembleBidRequest.Request, bidRequest)
	}
	return ensembleBidRequest, nil
}
// getNonce returns the next value of the coordinator's monotonically
// increasing nonce, used to tag bid requests.
//
// Fix: the previous implementation incremented atomically but then read
// b.nonce non-atomically, which is a data race and could return a value
// reflecting a concurrent increment. atomic.AddUint64 already yields the
// post-increment value, so return it directly.
func (b *BidCoordinator) getNonce() uint64 {
	return atomic.AddUint64(&b.nonce, 1)
}
// containsExecutor reports whether executor is already present in executors.
func containsExecutor(executors []jtypes.AllocationExecutor, executor jtypes.AllocationExecutor) bool {
	for i := range executors {
		if executors[i] == executor {
			return true
		}
	}
	return false
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package orchestrator
import (
"context"
"encoding/json"
"fmt"
"sync"
"time"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/dms/behaviors"
jtypes "gitlab.com/nunet/device-management-service/dms/jobs/types"
"gitlab.com/nunet/device-management-service/observability"
"gitlab.com/nunet/device-management-service/types"
)
// Committer drives the commit of an ensemble deployment: first committing
// resources on the selected nodes, then creating the allocations (see
// commit).
type Committer struct {
	ctx                   context.Context             // context scoping the commit operation
	eid                   string                      // ensemble id
	actor                 actor.Actor                 // actor used to message the selected nodes
	allocationIDGenerator types.AllocationIDGenerator // generates manifest keys for allocations
	nodeIDGenerator       types.NodeIDGenerator       // resolves node names (incl. standby naming)
}
// NewCommitter builds a Committer for ensemble eid that messages nodes
// through act and derives IDs from the supplied generators.
func NewCommitter(ctx context.Context, eid string, act actor.Actor, allocationIDGenerator types.AllocationIDGenerator, nodeIDGenerator types.NodeIDGenerator) *Committer {
	committer := &Committer{
		ctx:                   ctx,
		eid:                   eid,
		actor:                 act,
		allocationIDGenerator: allocationIDGenerator,
		nodeIDGenerator:       nodeIDGenerator,
	}
	return committer
}
// parseStandbyNode parses a node name and returns (isStandby, primaryNode,
// standbyIndex); it delegates to types.ParseNodeName.
func parseStandbyNode(nodeName string) (bool, string, int) {
	isStandby, primaryNode, standbyIndex := types.ParseNodeName(nodeName)
	return isStandby, primaryNode, standbyIndex
}
// processNodeAllocations records, for every allocation hosted on nodeID,
// which node deploys it (into allocationNodes, keyed by manifest key) and
// which port mappings apply to it (into portsByAllocation, keyed by
// allocation name). Standby nodes inherit allocation/port configuration from
// the primary node they shadow.
func (c *Committer) processNodeAllocations(
	cfg jtypes.EnsembleConfig,
	nodeID string,
	isStandby bool,
	primaryNode string,
	allocationNodes map[string]string,
	portsByAllocation map[string][]jtypes.PortConfig,
) {
	// standby nodes read their config from the primary they shadow
	configName := nodeID
	if isStandby {
		configName = primaryNode
	}
	nodeConfig, found := cfg.NodeWithGenerator(configName, c.nodeIDGenerator)
	if !found {
		return
	}
	for _, allocName := range nodeConfig.Allocations {
		// derive the manifest key for this node/allocation pair
		manifestKey, err := c.allocationIDGenerator.GenerateManifestKey(nodeID, allocName)
		if err != nil {
			log.Errorf("failed to generate manifest key for %s.%s: %v", nodeID, allocName, err)
			continue
		}
		// record which node will deploy this allocation
		allocationNodes[manifestKey] = nodeID
		// TODO: optimize the manifest format and how node/alloc data is
		// being passed around. A bit messy at the moment. see #825
		for _, portMap := range nodeConfig.Ports {
			if portMap.Allocation != allocName {
				continue
			}
			portsByAllocation[allocName] = append(portsByAllocation[allocName], portMap)
		}
	}
}
// updateManifestAllocations updates manifest allocations with node and port
// information gathered during commit: for each allocation listed on a node
// manifest it fills in the deploying node, the allocation's actor handle,
// its public->private port map, its standby flag, and its redundancy group.
// Allocations whose manifest key cannot be parsed, or which are absent from
// manifest.Allocations, are skipped with a warning.
func (c *Committer) updateManifestAllocations(
	manifest jtypes.EnsembleManifest,
	allocationNodes map[string]string, // manifest key -> deploying node ID
	portsByAllocation map[string][]jtypes.PortConfig, // config allocation name -> port mappings
	allocations map[string]actor.Handle, // allocation ID string -> actor handle
) {
	for _, nodeManifest := range manifest.Nodes {
		for _, allocName := range nodeManifest.Allocations {
			// Parse the manifest key to get allocation details
			allocID, err := types.ParseManifestKey(allocName, c.eid)
			if err != nil {
				log.Warnf("failed to parse manifest key %s: %v", allocName, err)
				continue
			}
			// build the public->private port map for this allocation
			allocPorts := make(map[int]int)
			if ports, ok := portsByAllocation[allocID.ConfigName()]; ok {
				for _, pc := range ports {
					allocPorts[pc.Public] = pc.Private
				}
			}
			if alloc, ok := manifest.Allocations[allocName]; ok {
				alloc.NodeID = allocationNodes[allocName]
				alloc.Handle = allocations[allocID.String()]
				alloc.Ports = allocPorts
				alloc.IsStandby = nodeManifest.RedundancyRole == jtypes.RoleStandby
				alloc.RedundancyGroup = allocID.ConfigName()
				manifest.Allocations[allocName] = alloc
				log.Infow("adding allocation to manifest", "labels", []string{string(observability.LabelDeployment)},
					// TODO allocationID?
					"allocation", allocName,
					"nodeID", alloc.NodeID,
					"handle", alloc.Handle,
					"isStandby", alloc.IsStandby)
			}
		}
	}
}
// commit works with a two-commit phases:
// - first commit the resources in all the nodes to ensure the deployment is (still)
// feasible.
// - then create all the allocations for provisioning
// - if there are any failures, we need to revert this deployment and start anew
func (c *Committer) commit(
	cfgReader jtypes.EnsembleCfgReader,
	manifestReader jtypes.ManifestReader,
	candidate map[string]jtypes.Bid,
) (jtypes.EnsembleManifest, error) {
	// mx guards ok, manifest.Nodes and allocations, all shared by the
	// per-node goroutines spawned below.
	var mx sync.Mutex
	cfg := cfgReader.Read()
	manifest := manifestReader.Read()
	// Phase 1: commit — reserve resources on every candidate node in parallel.
	var wg1 sync.WaitGroup
	ok := true
	wg1.Add(len(candidate))
	for n, bid := range candidate {
		go func(n string, bid jtypes.Bid) {
			defer wg1.Done()
			err := c.commitDeployment(cfg, n, bid.Handle())
			mx.Lock()
			if err != nil {
				log.Errorw("commit resources error",
					"labels", []string{string(observability.LabelDeployment)},
					"nodeID", n,
					"error", err)
				ok = false
				mx.Unlock()
				return
			}
			log.Infow("committing deployment",
				"nodeID", n,
			)
			// Record the bid's handle in the node manifest. Note the callback
			// parameter n shadows the outer n: it is the *jtypes.NodeManifest.
			err = updateNodeManifest(manifest.Nodes, n, func(n *jtypes.NodeManifest) {
				n.Handle = bid.Handle()
			})
			if err != nil {
				log.Debugw("committing: update node error",
					"labels", []string{string(observability.LabelDeployment)},
					"nodeID", n,
					"error", err)
				ok = false
			}
			mx.Unlock()
		}(n, bid)
	}
	wg1.Wait()
	if !ok {
		return manifest, fmt.Errorf("failed to commit resources: %w", ErrDeploymentFailed)
	}
	// Phase 2: allocate — create the allocations on every node in parallel.
	var wg2 sync.WaitGroup
	allocations := make(map[string]actor.Handle)
	wg2.Add(len(candidate))
	for n, bid := range candidate {
		go func(n string, bid jtypes.Bid) {
			defer wg2.Done()
			allocated, err := c.allocate(cfg, n, bid.Handle())
			mx.Lock()
			if err != nil {
				log.Errorw("allocation error",
					"labels", []string{string(observability.LabelDeployment)},
					"nodeID", n,
					"error", err)
				ok = false
			} else {
				log.Debugw("allocating deployment", "nodeID", n)
				// Merge this node's allocation handles into the shared map.
				for a, h := range allocated {
					allocations[a] = h
				}
			}
			mx.Unlock()
		}(n, bid)
	}
	wg2.Wait()
	if !ok {
		return manifest, fmt.Errorf("failed to allocate resources: %w", ErrDeploymentFailed)
	}
	allocationNodes := make(map[string]string)
	portsByAllocation := make(map[string][]jtypes.PortConfig)
	// There are certain details that are filled during provisioning, e.g. allocation
	// VPN addresses and public port mappings
	for n, bid := range candidate {
		// Extract node role information. Standby nodes carry an encoded name;
		// parseStandbyNode recovers the primary node name and the standby index.
		var role jtypes.RedundancyRole
		var primaryNode string
		var standbyIndex int
		isStandby, parsedPrimary, parsedIndex := parseStandbyNode(n)
		if isStandby {
			role = jtypes.RoleStandby
			primaryNode = parsedPrimary
			standbyIndex = parsedIndex
		} else {
			role = jtypes.RolePrimary
			primaryNode = n
			standbyIndex = 0
		}
		// update manifest node (in place when it already exists)
		if nmf, ok := manifest.Nodes[n]; ok {
			nmf.Peer = bid.Peer()
			nmf.PubAddress = append(nmf.PubAddress, bid.PubAddress())
			nmf.Handle = bid.Handle()
			nmf.Location = bid.Location()
			nmf.RedundancyRole = role
			nmf.PrimaryNode = primaryNode
			nmf.StandbyIndex = standbyIndex
			manifest.Nodes[n] = nmf
			// TODO: remove from here on the dynamic ensemble modification PR
			// use diffs instead, after o.commit
		} else {
			// First sight of this node: build a fresh manifest entry.
			nmf := jtypes.NodeManifest{
				ID:             n,
				Peer:           bid.Peer(),
				PubAddress:     []string{bid.PubAddress()},
				Handle:         bid.Handle(),
				Location:       bid.Location(),
				RedundancyRole: role,
				PrimaryNode:    primaryNode,
				StandbyIndex:   standbyIndex,
				StandbyNodes:   make([]string, 0),
			}
			if role == jtypes.RoleStandby {
				nmf.StandbyNodes = make([]string, 0)
			} else {
				// NOTE(review): standbyIndex is always 0 for primaries in this
				// branch, so this loop appears to be a no-op here — confirm
				// whether standbyIndex can be non-zero for a primary node.
				for i := 0; i < standbyIndex; i++ {
					nmf.StandbyNodes = append(nmf.StandbyNodes, fmt.Sprintf("%s-standby-%d", primaryNode, i+1))
				}
			}
			manifest.Nodes[n] = nmf
			// TODO: manifest partial updates
		}
		// Process node allocations and port mappings
		c.processNodeAllocations(cfg, n, isStandby, primaryNode, allocationNodes, portsByAllocation)
	}
	// Update manifest allocations with node and port information
	c.updateManifestAllocations(manifest, allocationNodes, portsByAllocation, allocations)
	return manifest, nil
}
// CommitDeploymentRequest is the message payload sent to a node actor asking
// it to reserve (commit) resources for a single allocation before provisioning.
type CommitDeploymentRequest struct {
	EnsembleID     string                   // ensemble this allocation belongs to
	AllocationName string                   // full allocation ID produced by the allocation ID generator
	NodeID         string                   // logical node name within the ensemble
	Resources      types.CommittedResources // resources to reserve, tagged with the allocation ID
	PortMapping    map[int]int              // public -> private port mapping for the allocation
}
// CommitDeploymentResponse is the node actor's reply to a CommitDeploymentRequest.
type CommitDeploymentResponse struct {
	OK    bool   // true when the resources were successfully committed
	Error string // failure reason when OK is false
}
// commitDeployment asks the node actor behind h to commit resources for every
// allocation assigned to node n. All allocations are committed in parallel and
// any failures are aggregated into a single chained error.
//
// Standby nodes (parsed by parseStandbyNode) reuse the primary node's
// configuration. Returns nil when the node has no allocations to commit.
func (c *Committer) commitDeployment(cfg jtypes.EnsembleConfig, n string, h actor.Handle) error {
	// Standby nodes share the primary's config; resolve which name to read.
	isStandby, primaryNode, _ := parseStandbyNode(n)
	cfgNode := n
	if isStandby {
		cfgNode = primaryNode
	}
	ncfg, ok := cfg.NodeWithGenerator(cfgNode, c.nodeIDGenerator)
	if !ok {
		return fmt.Errorf("node %s not found", n)
	}
	if len(ncfg.Allocations) == 0 {
		return nil
	}
	// getAllocPortMapping collects the public->private port pairs declared
	// for a single allocation on this node.
	getAllocPortMapping := func(allocName string) map[int]int {
		ports := make(map[int]int)
		for _, pc := range ncfg.Ports {
			if pc.Allocation == allocName {
				ports[pc.Public] = pc.Private
			}
		}
		return ports
	}
	var wg sync.WaitGroup
	errCh := make(chan error, len(ncfg.Allocations))
	// Scale the deadline with the number of allocations being committed.
	aggregatedTimeout := time.Duration(len(ncfg.Allocations)) * CommitDeploymentTimeout
	for _, allocName := range ncfg.Allocations {
		wg.Add(1)
		go func(allocName string) {
			defer wg.Done()
			allocation, ok := cfg.Allocation(allocName)
			if !ok {
				errCh <- fmt.Errorf("allocation %s not found: %w", allocName, ErrDeploymentFailed)
				return
			}
			allocPorts := getAllocPortMapping(allocName)
			// Generate full allocation ID using generator
			fullAllocID, err := c.allocationIDGenerator.GenerateFullAllocationID(c.eid, n, allocName)
			if err != nil {
				errCh <- fmt.Errorf("failed to generate full allocation ID for %s.%s: %w", n, allocName, err)
				return
			}
			msg, err := actor.Message(
				c.actor.Handle(),
				h,
				behaviors.CommitDeploymentBehavior,
				CommitDeploymentRequest{
					EnsembleID:     c.eid,
					AllocationName: fullAllocID,
					NodeID:         n,
					Resources:      types.CommittedResources{Resources: allocation.Resources, AllocationID: fullAllocID},
					PortMapping:    allocPorts,
				},
				actor.WithMessageTimeout(aggregatedTimeout),
			)
			if err != nil {
				errCh <- fmt.Errorf("failed to create commit message for %s: %w", n, err)
				return
			}
			replyCh, err := c.actor.Invoke(msg)
			if err != nil {
				errCh <- fmt.Errorf("failed to invoke commit for %s: %w", n, err)
				return
			}
			// One-shot deadline: a Timer is the right tool here, not a Ticker.
			timer := time.NewTimer(aggregatedTimeout)
			defer timer.Stop()
			var reply actor.Envelope
			select {
			case reply = <-replyCh:
				defer reply.Discard()
			case <-timer.C:
				errCh <- fmt.Errorf("timeout committing for %s: %w", n, ErrDeploymentFailed)
				return
			}
			var response CommitDeploymentResponse
			if err := json.Unmarshal(reply.Message, &response); err != nil {
				errCh <- fmt.Errorf("error unmarshalling commit response for %s: %w", n, err)
				return
			}
			if !response.OK {
				errCh <- fmt.Errorf("error committing for %s: %s: %w", n, response.Error, ErrDeploymentFailed)
			}
		}(allocName)
	}
	wg.Wait()
	close(errCh)
	// Chain every failure into a single error. Every value sent on errCh is
	// non-nil, so no nil check is needed inside the loop.
	var aggErr error
	for err := range errCh {
		if aggErr == nil {
			aggErr = err
		} else {
			aggErr = fmt.Errorf("%w\n%w", aggErr, err)
		}
	}
	return aggErr
}
// allocate requests the creation of all allocations configured for node n on
// the remote actor h, and returns the handles of the created allocations keyed
// by their full allocation ID. Returns (nil, nil) when the node has no
// allocations configured.
func (c *Committer) allocate(cfg jtypes.EnsembleConfig, n string, h actor.Handle) (map[string]actor.Handle, error) {
	allocs := make(map[string]jtypes.AllocationDeploymentConfig)
	// Standby nodes share the primary node's configuration.
	isStandby, primaryNode, _ := parseStandbyNode(n)
	cfgNode := n
	if isStandby {
		cfgNode = primaryNode
	}
	ncfg, ok := cfg.NodeWithGenerator(cfgNode, c.nodeIDGenerator)
	if !ok {
		return nil, fmt.Errorf("node not found for %s", n)
	}
	if len(ncfg.Allocations) == 0 {
		log.Warnf("no allocations found for %s, won't allocate (ensemble: %s)", n, c.eid)
		return nil, nil
	}
	contracts := cfg.Contracts()
	for _, a := range ncfg.Allocations {
		acfg, _ := cfg.Allocation(a)
		// Resolve the provisioning scripts referenced by the allocation.
		provisionScripts := make(map[string][]byte)
		for _, p := range acfg.Provision {
			provisionScripts[p] = cfg.V1.Scripts[p]
		}
		// Generate full allocation ID using generator
		fullAllocID, err := c.allocationIDGenerator.GenerateFullAllocationID(c.eid, n, a)
		if err != nil {
			return nil, fmt.Errorf("failed to generate full allocation ID for %s.%s: %w", n, a, err)
		}
		// BUGFIX: was a stray fmt.Println debug print to stdout.
		log.Debugf("full allocation ID for %s.%s: %s", n, a, fullAllocID)
		allocs[fullAllocID] = jtypes.AllocationDeploymentConfig{
			Type:             acfg.Type,
			Executor:         acfg.Executor,
			Resources:        acfg.Resources,
			Execution:        acfg.Execution,
			ProvisionScripts: provisionScripts,
			Keys:             acfg.Keys,
			Volume:           acfg.Volume,
			Contracts:        contracts,
		}
	}
	// Scale the deadline with the number of allocations requested.
	aggregatedTimeout := time.Duration(len(allocs)) * AllocationDeploymentTimeout
	msg, err := actor.Message(
		c.actor.Handle(),
		h,
		behaviors.AllocationDeploymentBehavior,
		jtypes.AllocationDeploymentRequest{
			EnsembleID:  c.eid,
			NodeID:      n,
			Allocations: allocs,
		},
		actor.WithMessageTimeout(aggregatedTimeout),
	)
	if err != nil {
		return nil, fmt.Errorf("failed to create allocation message for %s: %w", n, err)
	}
	log.Debugf("Invoking allocation for node: %s", n)
	replyCh, err := c.actor.Invoke(msg)
	if err != nil {
		return nil, fmt.Errorf("failed to invoke allocate for %s: %w", n, err)
	}
	var reply actor.Envelope
	select {
	case reply = <-replyCh:
	case <-time.After(aggregatedTimeout):
		// BUGFIX: this branch previously wrapped a nil err with %w; wrap the
		// deployment-failure sentinel instead so errors.Is works for callers.
		return nil, fmt.Errorf("timeout in allocation for %s: %w", n, ErrDeploymentFailed)
	}
	defer reply.Discard()
	var response jtypes.AllocationDeploymentResponse
	if err := json.Unmarshal(reply.Message, &response); err != nil {
		return nil, fmt.Errorf("unmarshalling allocation response: %w", err)
	}
	if !response.OK {
		return nil, fmt.Errorf("allocation for %s failed: %s: %w", n, response.Error, ErrDeploymentFailed)
	}
	// verify that the allocation map has all the allocations
	for a := range allocs {
		if _, ok := response.Allocations[a]; !ok {
			return nil, fmt.Errorf("missing allocation %s for %s: %w", a, n, ErrDeploymentFailed)
		}
	}
	log.Infow("Allocation successful", "nodeID", n)
	return response.Allocations, nil
}
// updateNodeManifest applies fn to the named node's manifest entry and writes
// the result back into m. Returns an error when the node is not present.
func updateNodeManifest(
	m map[string]jtypes.NodeManifest,
	nodeName string, fn func(*jtypes.NodeManifest),
) error {
	node, ok := m[nodeName]
	if !ok {
		return fmt.Errorf("node %s not found", nodeName)
	}
	fn(&node)
	m[nodeName] = node
	return nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package orchestrator
import (
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"github.com/ostafen/clover/v2"
"github.com/ostafen/clover/v2/document"
"github.com/ostafen/clover/v2/query"
jtypes "gitlab.com/nunet/device-management-service/dms/jobs/types"
)
const (
	// deploymentsCollection is the CloverDB collection holding persisted deployments.
	deploymentsCollection = "deployments"
	// Document field names accepted as sort keys by Query (see mapSortField).
	sortFieldCreatedAt = "created_at"
	sortFieldUpdatedAt = "updated_at"
	sortFieldStatus    = "status"
)
// DeploymentQuery defines query parameters for advanced deployment filtering
type DeploymentQuery struct {
	// StatusFilter restricts results to the given statuses; empty means all.
	StatusFilter []jtypes.DeploymentStatus
	// Creation-time bounds, inclusive (compared against the stored Unix timestamp).
	CreatedAfter  *time.Time
	CreatedBefore *time.Time
	// Last-update-time bounds, inclusive.
	UpdatedAfter  *time.Time
	UpdatedBefore *time.Time
	// Limit caps the number of returned results; values <= 0 fall back to an
	// internal default (see Query).
	Limit int
	// Offset skips the first N results after sorting.
	Offset int
	SortBy string // e.g., "created_at", "-created_at"
}
// DeploymentStore defines the interface for persisting orchestrator deployments
type DeploymentStore interface {
	// Upsert saves or updates a deployment
	Upsert(deployment *jtypes.OrchestratorView) error
	// Get retrieves a deployment by ID
	Get(orchestratorID string) (*jtypes.OrchestratorView, error)
	// GetAll retrieves all deployments, optionally filtered by status
	GetAll(statusFilter *jtypes.DeploymentStatus) ([]*jtypes.OrchestratorView, error)
	// Query retrieves deployments with advanced filtering, pagination, and sorting.
	// It returns the matching deployments, the total count of matches before
	// pagination, and an error.
	Query(query DeploymentQuery) ([]*jtypes.OrchestratorView, int, error)
	// Delete removes a deployment by ID
	Delete(orchestratorID string) error
	// Prune removes deployments older than the specified time
	Prune(olderThan time.Time) error
	// Clear removes all deployments
	Clear() error
}
// cloverDeploymentStore implements DeploymentStore using CloverDB with bytes serialization
type cloverDeploymentStore struct {
	db *clover.DB // underlying CloverDB handle; non-nil (checked by the constructor)
}
// NewCloverDeploymentStore creates a new cloverDB-backed deployment store.
// It returns an error when the provided database handle is nil.
func NewCloverDeploymentStore(db *clover.DB) (DeploymentStore, error) {
	if db == nil {
		return nil, errors.New("db is nil")
	}
	store := &cloverDeploymentStore{db: db}
	return store, nil
}
// Upsert inserts or updates a deployment in CloverDB (Upsert behavior)
func (s *cloverDeploymentStore) Upsert(deployment *jtypes.OrchestratorView) error {
	if deployment == nil {
		return errors.New("deployment is nil")
	}
	// Marshal the deployment to bytes; the full view is stored as an opaque
	// blob under "deployment_data", with a few indexed fields alongside.
	bts, err := json.Marshal(deployment)
	if err != nil {
		return fmt.Errorf("failed to marshal deployment: %w", err)
	}
	// Check if deployment exists
	q := query.NewQuery(deploymentsCollection).Where(query.Field("orchestrator_id").Eq(deployment.OrchestratorID))
	existingDoc, err := s.db.FindFirst(q)
	if err != nil {
		return fmt.Errorf("failed to check existing deployment: %w", err)
	}
	now := time.Now()
	if existingDoc != nil {
		// Update the existing document.
		// NOTE(review): created_at is intentionally not set here — presumably
		// clover's Update only touches the provided fields, preserving it; confirm.
		update := document.NewDocument()
		update.Set("orchestrator_id", deployment.OrchestratorID)
		update.Set("status", int(deployment.Status))
		update.Set("updated_at", now.Unix())
		update.Set("deployment_data", bts)
		// Set completion time if status changed to terminal
		if deployment.Status == jtypes.DeploymentStatusCompleted ||
			deployment.Status == jtypes.DeploymentStatusFailed {
			update.Set("completed_at", now.Unix())
		}
		return s.db.Update(q, update.AsMap())
	}
	// Insert a new document
	doc := document.NewDocument()
	doc.Set("orchestrator_id", deployment.OrchestratorID)
	doc.Set("status", int(deployment.Status))
	doc.Set("created_at", now.Unix())
	doc.Set("updated_at", now.Unix())
	doc.Set("deployment_data", bts)
	// Set completion time if it's already in a terminal state
	if deployment.Status == jtypes.DeploymentStatusCompleted ||
		deployment.Status == jtypes.DeploymentStatusFailed {
		doc.Set("completed_at", now.Unix())
	}
	return s.db.Insert(deploymentsCollection, doc)
}
// Get retrieves a deployment by OrchestratorID. It returns a distinct
// not-found error when no document matches the ID.
func (s *cloverDeploymentStore) Get(orchestratorID string) (*jtypes.OrchestratorView, error) {
	if orchestratorID == "" {
		return nil, errors.New("orchestratorID is empty")
	}
	q := query.NewQuery(deploymentsCollection).Where(query.Field("orchestrator_id").Eq(orchestratorID))
	doc, err := s.db.FindFirst(q)
	if err != nil {
		return nil, fmt.Errorf("failed to find deployment by ID: %w", err)
	}
	if doc == nil {
		// BUGFIX: this case previously wrapped a nil error with %w, producing
		// a garbled "%!w(<nil>)" message; report not-found explicitly.
		return nil, fmt.Errorf("deployment %s not found", orchestratorID)
	}
	// The full view is stored as a JSON blob in "deployment_data".
	data := doc.Get("deployment_data")
	deploymentData, ok := data.([]byte)
	if !ok {
		return nil, errors.New("no deployment data available")
	}
	var deployment jtypes.OrchestratorView
	if err := json.Unmarshal(deploymentData, &deployment); err != nil {
		return nil, fmt.Errorf("failed to unmarshal deployment: %w", err)
	}
	return &deployment, nil
}
// GetAll retrieves every stored deployment, optionally restricted to the
// given status.
func (s *cloverDeploymentStore) GetAll(statusFilter *jtypes.DeploymentStatus) ([]*jtypes.OrchestratorView, error) {
	q := query.NewQuery(deploymentsCollection)
	if statusFilter != nil {
		q = q.Where(query.Field("status").Eq(int(*statusFilter)))
	}
	docs, err := s.db.FindAll(q)
	if err != nil {
		return nil, fmt.Errorf("failed to retrieve all deployments: %w", err)
	}
	result := make([]*jtypes.OrchestratorView, 0, len(docs))
	for _, doc := range docs {
		raw, ok := doc.Get("deployment_data").([]byte)
		if !ok {
			return nil, fmt.Errorf("no deployment data available for document")
		}
		view := new(jtypes.OrchestratorView)
		if err := json.Unmarshal(raw, view); err != nil {
			return nil, fmt.Errorf("failed to unmarshal single deployment: %w", err)
		}
		result = append(result, view)
	}
	return result, nil
}
// Delete removes the deployment matching the given OrchestratorID.
func (s *cloverDeploymentStore) Delete(orchestratorID string) error {
	if orchestratorID == "" {
		return errors.New("orchestratorID is empty")
	}
	match := query.Field("orchestrator_id").Eq(orchestratorID)
	return s.db.Delete(query.NewQuery(deploymentsCollection).Where(match))
}
// Prune removes every deployment created strictly before olderThan.
func (s *cloverDeploymentStore) Prune(olderThan time.Time) error {
	cutoff := olderThan.Unix()
	return s.db.Delete(query.NewQuery(deploymentsCollection).Where(query.Field("created_at").Lt(cutoff)))
}
// Clear removes every deployment from the collection.
func (s *cloverDeploymentStore) Clear() error {
	return s.db.Delete(query.NewQuery(deploymentsCollection))
}
// Query retrieves deployments with advanced filtering, pagination, and sorting.
// It returns the requested page of deployments plus the total number of
// documents matching the filter before pagination.
func (s *cloverDeploymentStore) Query(q DeploymentQuery) ([]*jtypes.OrchestratorView, int, error) {
	// All filters are AND-ed together. The combine closure replaces the five
	// identical copies of the combining boilerplate in the previous version.
	var combinedCondition query.Criteria
	hasCondition := false
	combine := func(c query.Criteria) {
		if hasCondition {
			combinedCondition = combinedCondition.And(c)
		} else {
			combinedCondition = c
		}
		hasCondition = true
	}
	// Status filter
	if len(q.StatusFilter) > 0 {
		statusInts := make([]interface{}, len(q.StatusFilter))
		for i, status := range q.StatusFilter {
			statusInts[i] = int(status)
		}
		combine(query.Field("status").In(statusInts...))
	}
	// Date filters (timestamps are stored as Unix seconds; bounds are inclusive).
	if q.CreatedAfter != nil {
		combine(query.Field("created_at").GtEq(q.CreatedAfter.Unix()))
	}
	if q.CreatedBefore != nil {
		combine(query.Field("created_at").LtEq(q.CreatedBefore.Unix()))
	}
	if q.UpdatedAfter != nil {
		combine(query.Field("updated_at").GtEq(q.UpdatedAfter.Unix()))
	}
	if q.UpdatedBefore != nil {
		combine(query.Field("updated_at").LtEq(q.UpdatedBefore.Unix()))
	}
	// Build base query for counting (without sorting/pagination)
	baseQuery := query.NewQuery(deploymentsCollection)
	if hasCondition {
		baseQuery = baseQuery.Where(combinedCondition)
	}
	// Count total using base query
	totalDocs, err := s.db.Count(baseQuery)
	if err != nil {
		return nil, 0, fmt.Errorf("failed to count deployments: %w", err)
	}
	// Build fetch query with sorting and pagination
	fetchQuery := query.NewQuery(deploymentsCollection)
	if hasCondition {
		fetchQuery = fetchQuery.Where(combinedCondition)
	}
	// Apply sorting: a leading "-" requests descending order.
	if q.SortBy != "" {
		sortField := q.SortBy
		direction := 1
		if strings.HasPrefix(sortField, "-") {
			direction = -1
			sortField = sortField[1:]
		}
		// Map field names to DB fields
		dbField := mapSortField(sortField)
		fetchQuery = fetchQuery.Sort(query.SortOption{Field: dbField, Direction: direction})
	} else {
		// Default sort: newest first
		fetchQuery = fetchQuery.Sort(query.SortOption{Field: "created_at", Direction: -1})
	}
	// Apply limit to query (if we have offset, we need offset+limit to get enough results)
	// If no limit specified, we still want to limit to prevent huge results
	effectiveLimit := q.Limit
	if effectiveLimit <= 0 {
		effectiveLimit = 1000 // Default max to prevent unbounded queries
	}
	if q.Offset > 0 {
		// Need to fetch offset+limit documents, then skip the first offset
		fetchQuery = fetchQuery.Limit(q.Offset + effectiveLimit)
	} else {
		fetchQuery = fetchQuery.Limit(effectiveLimit)
	}
	// Execute query
	allDocs, err := s.db.FindAll(fetchQuery)
	if err != nil {
		return nil, 0, fmt.Errorf("failed to query deployments: %w", err)
	}
	// Handle offset manually (slice the results)
	var docs []*document.Document
	if q.Offset > 0 {
		if q.Offset < len(allDocs) {
			docs = allDocs[q.Offset:]
			// Apply limit after offset
			if q.Limit > 0 && len(docs) > q.Limit {
				docs = docs[:q.Limit]
			}
		}
		// else: offset beyond results, docs remains empty
	} else {
		docs = allDocs
	}
	// Unmarshal results; documents with missing or corrupt payloads are
	// skipped rather than failing the whole page.
	deployments := make([]*jtypes.OrchestratorView, 0, len(docs))
	for _, doc := range docs {
		var deployment jtypes.OrchestratorView
		data := doc.Get("deployment_data")
		deploymentData, ok := data.([]byte)
		if !ok {
			continue
		}
		if err := json.Unmarshal(deploymentData, &deployment); err != nil {
			continue
		}
		deployments = append(deployments, &deployment)
	}
	return deployments, totalDocs, nil
}
// mapSortField normalizes a user-supplied sort field name (including the
// camelCase aliases "createdAt"/"updatedAt") to the document field name used
// by the store; unknown names pass through unchanged.
func mapSortField(field string) string {
	switch field {
	case "createdAt":
		return sortFieldCreatedAt
	case "updatedAt":
		return sortFieldUpdatedAt
	case sortFieldCreatedAt, sortFieldUpdatedAt, sortFieldStatus:
		return field
	default:
		return field
	}
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package orchestrator
import (
"fmt"
"maps"
"slices"
jtypes "gitlab.com/nunet/device-management-service/dms/jobs/types"
"gitlab.com/nunet/device-management-service/utils"
)
// newConfigForDeploymentUpdate builds the ensemble config used to deploy an
// update: only allocations that are new in modifiedCfg are kept per node,
// nodes already placed keep their peer, and peers with nothing new to deploy
// are excluded from bidding.
func newConfigForDeploymentUpdate(
	oldCfg, modifiedCfg jtypes.EnsembleConfig,
	existingNodes map[string]string,
) (jtypes.EnsembleConfig, error) {
	updatedNodes := make(map[string]jtypes.NodeConfig)
	updatedAllocations := make(map[string]jtypes.AllocationConfig)
	excludePeers := modifiedCfg.V1.ExcludePeers
	for name, nodeCfg := range modifiedCfg.Nodes() {
		fresh, err := identifyNewAllocations(oldCfg, modifiedCfg, name)
		if err != nil {
			return jtypes.EnsembleConfig{}, err
		}
		// Pin nodes that are already deployed to their current peer; if the
		// node has nothing new, keep its peer out of the bidding entirely.
		if peer, deployed := existingNodes[name]; deployed {
			nodeCfg.Peer = peer
			if len(fresh) == 0 {
				excludePeers = append(excludePeers, peer)
			}
		}
		maps.Copy(updatedAllocations, fresh)
		nodeCfg.Allocations = utils.MapKeysToSlice(fresh)
		// Keep only port mappings that belong to the retained allocations.
		var keptPorts []jtypes.PortConfig
		for _, port := range nodeCfg.Ports {
			if slices.Contains(nodeCfg.Allocations, port.Allocation) {
				keptPorts = append(keptPorts, port)
			}
		}
		nodeCfg.Ports = keptPorts
		updatedNodes[name] = nodeCfg
	}
	return jtypes.EnsembleConfig{
		V1: &jtypes.EnsembleConfigV1{
			EscalationStrategy: modifiedCfg.V1.EscalationStrategy,
			Allocations:        updatedAllocations,
			Nodes:              updatedNodes,
			Scripts:            modifiedCfg.V1.Scripts,
			Keys:               modifiedCfg.V1.Keys,
			Edges:              modifiedCfg.EdgeConstraints(),
			Supervisor:         modifiedCfg.V1.Supervisor,
			ExcludePeers:       excludePeers,
			Subnet:             modifiedCfg.Subnet(),
			Metadata:           modifiedCfg.V1.Metadata,
		},
	}, nil
}
// identifyNewAllocations returns the allocations of node that appear in
// modifiedConfig but not in currentConfig. It errors when a referenced
// allocation has no config entry.
func identifyNewAllocations(
	currentConfig, modifiedConfig jtypes.EnsembleConfig,
	node string,
) (map[string]jtypes.AllocationConfig, error) {
	result := make(map[string]jtypes.AllocationConfig)
	modified, ok := modifiedConfig.Node(node)
	if !ok {
		return result, nil
	}
	var existing []string
	if curr, found := currentConfig.Node(node); found {
		existing = curr.Allocations
	}
	for _, name := range modified.Allocations {
		if slices.Contains(existing, name) {
			continue
		}
		cfg, found := modifiedConfig.Allocation(name)
		if !found {
			return nil, fmt.Errorf("allocation %s not found", name)
		}
		result[name] = cfg
	}
	return result, nil
}
// identifyRemovedAllocations returns the allocations of node that exist in
// currentConfig but are gone from modifiedConfig.
func identifyRemovedAllocations(
	currentConfig, modifiedConfig jtypes.EnsembleConfig,
	node string,
) map[string]jtypes.AllocationConfig {
	removed := make(map[string]jtypes.AllocationConfig)
	modified, ok := modifiedConfig.Node(node)
	if !ok {
		return removed
	}
	current, ok := currentConfig.Node(node)
	if !ok {
		return removed
	}
	for _, name := range current.Allocations {
		if slices.Contains(modified.Allocations, name) {
			continue
		}
		if cfg, found := currentConfig.Allocation(name); found {
			removed[name] = cfg
		}
	}
	return removed
}
// identifyRemovedNodes returns the nodes present in currentConfig that no
// longer exist in modifiedConfig.
func identifyRemovedNodes(
	currentConfig, modifiedConfig jtypes.EnsembleConfig,
) map[string]jtypes.NodeConfig {
	removed := make(map[string]jtypes.NodeConfig)
	for name, cfg := range currentConfig.Nodes() {
		if _, stillPresent := modifiedConfig.Node(name); stillPresent {
			continue
		}
		removed[name] = cfg
	}
	return removed
}
// identifyRelocatedNodes returns the nodes whose Peer differs between
// currentConfig and modifiedConfig. These are treated as remove+add
// (relocation) by the update logic.
func identifyRelocatedNodes(currentConfig, modifiedConfig jtypes.EnsembleConfig) map[string]jtypes.NodeConfig {
	relocated := make(map[string]jtypes.NodeConfig)
	for name, currCfg := range currentConfig.Nodes() {
		modCfg, ok := modifiedConfig.Node(name)
		if ok && modCfg.Peer != currCfg.Peer {
			relocated[name] = currCfg
		}
	}
	return relocated
}
// newConfigForRemovedNodes builds an ensemble config containing only the
// nodes removed (or relocated) between oldCfg and modifiedCfg, together with
// the allocations those nodes carried in oldCfg.
//
// Fix: the second parameter was misspelled "modifieCfg" throughout.
func newConfigForRemovedNodes(
	oldCfg, modifiedCfg jtypes.EnsembleConfig,
) (jtypes.EnsembleConfig, error) {
	// Relocated nodes (peer changed) are handled like removed nodes.
	nodes := identifyRemovedNodes(oldCfg, modifiedCfg)
	maps.Copy(nodes, identifyRelocatedNodes(oldCfg, modifiedCfg))
	nodesCfg := jtypes.EnsembleConfig{
		V1: &jtypes.EnsembleConfigV1{
			Allocations: make(map[string]jtypes.AllocationConfig),
			Nodes:       nodes,
			Scripts:     modifiedCfg.V1.Scripts,
			Keys:        modifiedCfg.V1.Keys,
			Edges:       modifiedCfg.EdgeConstraints(),
			Supervisor:  modifiedCfg.V1.Supervisor,
		},
	}
	// Collect the allocation names carried by the removed/relocated nodes...
	allocationsNames := make([]string, 0)
	for _, node := range nodesCfg.Nodes() {
		allocationsNames = append(allocationsNames, node.Allocations...)
	}
	// ...and resolve each one against the old config.
	allocations := make(map[string]jtypes.AllocationConfig)
	for _, name := range allocationsNames {
		alloc, ok := oldCfg.Allocation(name)
		if !ok {
			return jtypes.EnsembleConfig{},
				fmt.Errorf("new config for nodes: allocation %s not found", name)
		}
		allocations[name] = alloc
	}
	nodesCfg.V1.Allocations = allocations
	return nodesCfg, nil
}
// manifestOnlyForNodes returns a copy of mf restricted to the given nodes and
// the allocations they reference. It errors when a node or a referenced
// allocation is missing from mf.
func manifestOnlyForNodes(mf jtypes.EnsembleManifest, nodes []string,
) (jtypes.EnsembleManifest, error) {
	filtered := jtypes.EnsembleManifest{
		ID:           mf.ID,
		Orchestrator: mf.Orchestrator,
		Nodes:        make(map[string]jtypes.NodeManifest),
		Allocations:  make(map[string]jtypes.AllocationManifest),
		Subnet:       mf.Subnet,
	}
	for _, nodeName := range nodes {
		nmf, ok := mf.Node(nodeName)
		if !ok {
			return jtypes.EnsembleManifest{},
				fmt.Errorf("manifestOnlyForNodes: node %s not found", nodeName)
		}
		filtered.Nodes[nodeName] = nmf
		// Carry over every allocation the node references.
		for _, allocName := range nmf.Allocations {
			alloc, ok := mf.Allocation(allocName)
			if !ok {
				return jtypes.EnsembleManifest{},
					fmt.Errorf("manifestOnlyForNodes: allocation %s not found", allocName)
			}
			filtered.Allocations[allocName] = alloc
		}
	}
	return filtered, nil
}
// validateEnsembleUpdate checks if a given ensemble modification is valid.
// A node with its peer changed is considered as relocated and will be treated as a new node.
//
// Invalid modifications:
// - Removing supervisor
// - Changing node location
//
// Unsupported modifications:
// - Changing node's ports (except when adding for new node's allocations)
// - Adding edge constraints for already deployed nodes, unless at least one endpoint is a relocated node
// - Changing supervisor strategy
func validateEnsembleUpdate(currentConfig, modifiedConfig jtypes.EnsembleConfig) error {
	// 1. Supervisor must not be removed
	if len(currentConfig.V1.Supervisor.Allocations) > 0 &&
		len(modifiedConfig.V1.Supervisor.Allocations) == 0 {
		return fmt.Errorf("invalid modification: removing supervisor is not allowed")
	}
	// 2. Validate existing nodes. Nodes that were removed or relocated
	// (peer changed) are skipped: they are handled as remove+add elsewhere.
	for name, currNode := range currentConfig.Nodes() {
		modNode, ok := modifiedConfig.Node(name)
		if !ok || modNode.Peer != currNode.Peer {
			continue
		}
		if !validateLocationConstraintsUpdate(currNode.Location, modNode.Location) {
			return fmt.Errorf("invalid modification: changing node location for node '%s' is not allowed", name)
		}
		// Track new allocations — ports belonging to them are allowed to appear.
		newAllocs := map[string]bool{}
		for _, alloc := range modNode.Allocations {
			if !slices.Contains(currNode.Allocations, alloc) {
				newAllocs[alloc] = true
			}
		}
		// Validate existing port configurations: every port of a kept
		// allocation must survive unchanged in the modified config.
		for _, currPort := range currNode.Ports {
			if _, isNew := newAllocs[currPort.Allocation]; isNew {
				continue
			}
			// If the allocation is removed, skip validation for the port
			if !slices.Contains(modNode.Allocations, currPort.Allocation) {
				continue
			}
			found := false
			for _, modPort := range modNode.Ports {
				if modPort.Allocation == currPort.Allocation &&
					modPort.Public == currPort.Public &&
					modPort.Private == currPort.Private {
					found = true
					break
				}
			}
			if !found {
				return fmt.Errorf("unsupported modification: changing node's ports for existing allocations on node '%s' is not supported", name)
			}
		}
	}
	// 3. Edge constraints: no new edges between already deployed nodes
	// Allow new edges if at least one endpoint is a relocated node (peer changed)
	existing := currentConfig.Nodes()
	relocated := identifyRelocatedNodes(currentConfig, modifiedConfig)
	for _, edge := range modifiedConfig.V1.Edges {
		_, sExists := existing[edge.S]
		_, tExists := existing[edge.T]
		_, sReloc := relocated[edge.S]
		_, tReloc := relocated[edge.T]
		// Allow if either endpoint is a new node or a relocated node
		if !sExists || !tExists || sReloc || tReloc {
			continue
		}
		// Skip if edge already exists (edges are treated as undirected here:
		// both orientations count as the same edge).
		alreadyExists := false
		for _, currEdge := range currentConfig.V1.Edges {
			if (currEdge.S == edge.S && currEdge.T == edge.T) ||
				(currEdge.S == edge.T && currEdge.T == edge.S) {
				alreadyExists = true
				break
			}
		}
		if !alreadyExists {
			return fmt.Errorf("unsupported modification: adding edge constraints for already deployed nodes '%s' and '%s' is not supported", edge.S, edge.T)
		}
	}
	// 4. Supervisor strategy must remain unchanged
	if currentConfig.V1.Supervisor.Strategy != modifiedConfig.V1.Supervisor.Strategy {
		return fmt.Errorf("unsupported modification: changing supervisor strategy is not supported")
	}
	return nil
}
// validateLocationConstraintsUpdate reports whether the change from current to
// newLoc is valid: an accept list may only broaden (keep covering all
// currently accepted locations), a reject list may only shrink (stay a subset
// of the current reject list), and switching from accept-list mode to
// reject-list mode is not allowed.
func validateLocationConstraintsUpdate(current, newLoc jtypes.LocationConstraints) bool {
	contains := func(list []jtypes.Location, loc jtypes.Location) bool {
		return slices.ContainsFunc(list, func(candidate jtypes.Location) bool {
			return candidate.Equal(loc)
		})
	}
	// subset reports whether every element of sub appears in super.
	subset := func(sub, super []jtypes.Location) bool {
		for _, loc := range sub {
			if !contains(super, loc) {
				return false
			}
		}
		return true
	}
	switch {
	case len(newLoc.Accept) > 0:
		// Accept list can only broaden: newLoc must include all current accepted locations.
		if len(current.Accept) == 0 || !subset(current.Accept, newLoc.Accept) {
			return false
		}
	case len(current.Accept) > 0 && len(newLoc.Reject) > 0:
		// Cannot switch from accept list mode to reject list mode.
		return false
	}
	if len(newLoc.Reject) > 0 {
		// Reject list can only shrink: newLoc must be a subset of current reject list.
		if len(current.Reject) == 0 || !subset(newLoc.Reject, current.Reject) {
			return false
		}
	}
	return true
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package orchestrator
import (
"fmt"
"sort"
"strings"
"sync"
"time"
jtypes "gitlab.com/nunet/device-management-service/dms/jobs/types"
)
// mockDeploymentStore implements DeploymentStore for testing
type mockDeploymentStore struct {
	mu          sync.RWMutex                        // guards deployments
	deployments map[string]*jtypes.OrchestratorView // keyed by OrchestratorID
}
// NewMockDeploymentStore creates a new in-memory mock deployment store for testing.
func NewMockDeploymentStore() DeploymentStore {
	store := &mockDeploymentStore{}
	store.deployments = make(map[string]*jtypes.OrchestratorView)
	return store
}
// Upsert inserts or updates a deployment view, maintaining the
// CreatedAt/UpdatedAt/CompletedAt bookkeeping timestamps. The passed-in
// view is stored directly (mock semantics: no defensive copy).
func (m *mockDeploymentStore) Upsert(deployment *jtypes.OrchestratorView) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	now := time.Now()
	prev, known := m.deployments[deployment.OrchestratorID]
	if !known {
		// First time we see this orchestrator ID.
		deployment.CreatedAt = now
		deployment.UpdatedAt = now
		m.deployments[deployment.OrchestratorID] = deployment
		return nil
	}

	// Preserve the original creation time and stamp the update.
	deployment.CreatedAt = prev.CreatedAt
	deployment.UpdatedAt = now

	terminal := deployment.Status == jtypes.DeploymentStatusCompleted ||
		deployment.Status == jtypes.DeploymentStatusFailed
	if terminal && prev.Status != deployment.Status {
		// Record the moment the deployment transitioned into a terminal state.
		deployment.CompletedAt = &now
	} else {
		deployment.CompletedAt = prev.CompletedAt
	}
	m.deployments[deployment.OrchestratorID] = deployment
	return nil
}
// Get returns the deployment view for orchestratorID, or an error if absent.
func (m *mockDeploymentStore) Get(orchestratorID string) (*jtypes.OrchestratorView, error) {
	m.mu.RLock()
	defer m.mu.RUnlock()
	if deployment, ok := m.deployments[orchestratorID]; ok {
		return deployment, nil
	}
	return nil, fmt.Errorf("deployment not found")
}
// GetAll returns every stored deployment, optionally filtered by status.
// A nil statusFilter returns everything.
func (m *mockDeploymentStore) GetAll(statusFilter *jtypes.DeploymentStatus) ([]*jtypes.OrchestratorView, error) {
	m.mu.RLock()
	defer m.mu.RUnlock()
	var matches []*jtypes.OrchestratorView
	for _, d := range m.deployments {
		if statusFilter != nil && d.Status != *statusFilter {
			continue
		}
		matches = append(matches, d)
	}
	return matches, nil
}
// Delete removes the deployment for orchestratorID, erroring if absent.
func (m *mockDeploymentStore) Delete(orchestratorID string) error {
	m.mu.Lock()
	defer m.mu.Unlock()
	if _, ok := m.deployments[orchestratorID]; !ok {
		return fmt.Errorf("deployment not found")
	}
	delete(m.deployments, orchestratorID)
	return nil
}
// Prune drops every deployment created strictly before olderThan.
func (m *mockDeploymentStore) Prune(olderThan time.Time) error {
	m.mu.Lock()
	defer m.mu.Unlock()
	var stale []string
	for id, d := range m.deployments {
		if d.CreatedAt.Before(olderThan) {
			stale = append(stale, id)
		}
	}
	for _, id := range stale {
		delete(m.deployments, id)
	}
	return nil
}
// Clear removes all stored deployments.
func (m *mockDeploymentStore) Clear() error {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.deployments = map[string]*jtypes.OrchestratorView{}
	return nil
}
// Query returns the deployments matching q. It applies the status and
// created/updated time filters, sorts (q.SortBy with a "-" prefix meaning
// descending; default is newest-first by CreatedAt) and then applies
// offset/limit pagination. The second return value is the total number of
// matches before pagination.
func (m *mockDeploymentStore) Query(q DeploymentQuery) ([]*jtypes.OrchestratorView, int, error) {
	m.mu.RLock()
	defer m.mu.RUnlock()

	// Filter directly from the map; no need for an intermediate full copy.
	filtered := make([]*jtypes.OrchestratorView, 0, len(m.deployments))
	for _, deployment := range m.deployments {
		if deploymentMatchesQuery(deployment, q) {
			filtered = append(filtered, deployment)
		}
	}

	// Count total before pagination.
	total := len(filtered)

	// Apply sorting.
	if q.SortBy != "" {
		descending := strings.HasPrefix(q.SortBy, "-")
		sortField := strings.TrimPrefix(q.SortBy, "-")
		sort.Slice(filtered, func(i, j int) bool {
			var less bool
			switch mapSortField(sortField) {
			case "created_at":
				less = filtered[i].CreatedAt.Before(filtered[j].CreatedAt)
			case "updated_at":
				less = filtered[i].UpdatedAt.Before(filtered[j].UpdatedAt)
			case "status":
				less = int(filtered[i].Status) < int(filtered[j].Status)
			default:
				less = false
			}
			if descending {
				return !less
			}
			return less
		})
	} else {
		// Default sort: newest first.
		sort.Slice(filtered, func(i, j int) bool {
			return filtered[i].CreatedAt.After(filtered[j].CreatedAt)
		})
	}

	// Apply pagination. A start at or past the end yields an empty page
	// while still reporting the correct total. (The previous version
	// checked the bound twice, once with > and once with >=.)
	start := q.Offset
	if start < 0 {
		start = 0
	}
	if start >= len(filtered) {
		return []*jtypes.OrchestratorView{}, total, nil
	}
	end := len(filtered)
	if q.Limit > 0 && start+q.Limit < end {
		end = start + q.Limit
	}
	return filtered[start:end], total, nil
}

// deploymentMatchesQuery reports whether deployment passes q's status and
// creation/update time filters.
func deploymentMatchesQuery(deployment *jtypes.OrchestratorView, q DeploymentQuery) bool {
	if len(q.StatusFilter) > 0 {
		matched := false
		for _, status := range q.StatusFilter {
			if deployment.Status == status {
				matched = true
				break
			}
		}
		if !matched {
			return false
		}
	}
	if q.CreatedAfter != nil && deployment.CreatedAt.Before(*q.CreatedAfter) {
		return false
	}
	if q.CreatedBefore != nil && deployment.CreatedAt.After(*q.CreatedBefore) {
		return false
	}
	if q.UpdatedAfter != nil && deployment.UpdatedAt.Before(*q.UpdatedAfter) {
		return false
	}
	if q.UpdatedBefore != nil && deployment.UpdatedAt.After(*q.UpdatedBefore) {
		return false
	}
	return true
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package orchestrator
import (
"context"
"fmt"
"sync"
"time"
"github.com/spf13/afero"
"gitlab.com/nunet/device-management-service/actor"
jtypes "gitlab.com/nunet/device-management-service/dms/jobs/types"
"gitlab.com/nunet/device-management-service/lib/crypto"
"gitlab.com/nunet/device-management-service/tokenomics/eventhandler"
"gitlab.com/nunet/device-management-service/types"
)
// MockOrchestrator is a minimal in-memory Orchestrator implementation used
// in tests. It records a manifest built from its config and a status, but
// performs no real deployment work.
type MockOrchestrator struct {
	lock               sync.Mutex // guards the mutable fields below
	ctx                context.Context
	cancel             func()
	fs                 afero.Afero
	workDir            string
	actor              actor.Actor
	id                 string
	cfg                jtypes.EnsembleConfig
	manifest           jtypes.EnsembleManifest
	subnetManifest     jtypes.SubnetManifest
	status             jtypes.DeploymentStatus
	deploymentSnapshot jtypes.DeploymentSnapshot
	supervisor         *Supervisor
	nonce              uint64
}
// NewMockOrchestrator creates a mock orchestrator for testing. It derives a
// cancellable child context from ctx so that cancel tears down the
// orchestrator and its supervisor, mirroring NewOrchestrator.
func NewMockOrchestrator(
	ctx context.Context,
	fs afero.Afero,
	workDir string,
	id string,
	oActor actor.Actor,
	cfg jtypes.EnsembleConfig,
) (*MockOrchestrator, error) {
	// Derive the child context once, up front; the previous version set
	// ctx in the struct literal and immediately overwrote it.
	childCtx, cancel := context.WithCancel(ctx)
	mo := &MockOrchestrator{
		actor:              oActor,
		id:                 id,
		cfg:                cfg,
		ctx:                childCtx,
		cancel:             cancel,
		fs:                 fs,
		workDir:            workDir,
		subnetManifest:     jtypes.SubnetManifest{},
		deploymentSnapshot: jtypes.DeploymentSnapshot{},
		nonce:              0,
		// Use the derived context so cancelling the orchestrator also
		// cancels the supervisor (consistent with NewOrchestrator).
		supervisor: NewSupervisor(childCtx, oActor, id),
	}
	return mo, nil
}
// Deploy builds the mock manifest from the stored config; no real
// deployment work is performed.
func (m *MockOrchestrator) Deploy(_ time.Time) error {
	manifest := m.newManifest(m.cfg)
	m.lock.Lock()
	m.manifest = manifest
	m.lock.Unlock()
	return nil
}
// newManifest constructs a mock ensemble manifest: every allocation is
// placed on a synthetic "mock-node" and marked pending, and nodes are
// copied straight from the config.
func (m *MockOrchestrator) newManifest(
	cfg jtypes.EnsembleConfig,
) jtypes.EnsembleManifest {
	out := jtypes.EnsembleManifest{
		ID:           m.id,
		Orchestrator: m.actor.Handle(),
		Metadata:     cfg.V1.Metadata,
		Allocations:  make(map[string]jtypes.AllocationManifest),
		Nodes:        make(map[string]jtypes.NodeManifest),
		Subnet:       cfg.V1.Subnet,
	}
	for allocName, alloc := range cfg.Allocations() {
		out.Allocations[allocName] = jtypes.AllocationManifest{
			ID:          types.NewAllocationID(m.id, "mock-node", allocName).String(),
			Type:        alloc.Type,
			NodeID:      "mock-node",
			DNSName:     alloc.DNSName + ".internal",
			Healthcheck: alloc.HealthCheck,
			Status:      jtypes.AllocationPending,
		}
	}
	for nodeName, node := range cfg.Nodes() {
		out.Nodes[nodeName] = jtypes.NodeManifest{
			ID:          nodeName,
			Allocations: node.Allocations,
			Peer:        node.Peer,
		}
	}
	return out
}
// Shutdown marks the mock deployment as completed.
func (m *MockOrchestrator) Shutdown() error {
	m.SetStatus(jtypes.DeploymentStatusCompleted)
	return nil
}

// Stop marks the mock deployment as completed.
func (m *MockOrchestrator) Stop() {
	m.SetStatus(jtypes.DeploymentStatusCompleted)
}

// SetStatus sets the deployment status; helper for tests.
func (m *MockOrchestrator) SetStatus(status jtypes.DeploymentStatus) {
	m.lock.Lock()
	defer m.lock.Unlock()
	m.status = status
}

// Status returns the current deployment status.
func (m *MockOrchestrator) Status() jtypes.DeploymentStatus {
	m.lock.Lock()
	defer m.lock.Unlock()
	return m.status
}
// Manifest returns a clone of the current mock manifest.
func (m *MockOrchestrator) Manifest() jtypes.EnsembleManifest {
	m.lock.Lock()
	defer m.lock.Unlock()
	return m.manifest.Clone()
}

// SubnetManifest returns the mock subnet manifest.
func (m *MockOrchestrator) SubnetManifest() jtypes.SubnetManifest {
	m.lock.Lock()
	defer m.lock.Unlock()
	return m.subnetManifest
}

// Config returns an empty ensemble config (mock stub; the stored cfg is
// not exposed).
func (m *MockOrchestrator) Config() jtypes.EnsembleConfig {
	return jtypes.EnsembleConfig{}
}

// ID returns the orchestrator's ID.
func (m *MockOrchestrator) ID() string {
	return m.id
}

// ActorPrivateKey returns the private key of the underlying actor.
func (m *MockOrchestrator) ActorPrivateKey() crypto.PrivKey {
	return m.actor.Security().PrivKey()
}

// DeploymentSnapshot returns an empty snapshot (mock stub).
func (m *MockOrchestrator) DeploymentSnapshot() jtypes.DeploymentSnapshot {
	return jtypes.DeploymentSnapshot{}
}

// GetAllocationLogs is a stub returning an empty response.
func (m *MockOrchestrator) GetAllocationLogs(_ string) (AllocationLogsResponse, error) {
	return AllocationLogsResponse{}, nil
}

// WriteAllocationLogs is a stub returning an empty path.
func (m *MockOrchestrator) WriteAllocationLogs(_ string, _, _ []byte) (string, error) {
	return "", nil
}

// Update is a stub that accepts any config update.
func (m *MockOrchestrator) Update(_ jtypes.EnsembleConfig, _ time.Time) error {
	return nil
}

// StatusChannel returns a channel that never receives (mock stub).
func (m *MockOrchestrator) StatusChannel(_ context.Context) <-chan jtypes.DeploymentStatus {
	return make(chan jtypes.DeploymentStatus)
}

// AllocationInfo returns an empty map (mock stub).
func (m *MockOrchestrator) AllocationInfo() map[string]jtypes.AllocationInfo {
	return make(map[string]jtypes.AllocationInfo)
}

// UpdateAllocationStatus is a no-op in the mock.
func (m *MockOrchestrator) UpdateAllocationStatus() {}

// Done returns nil; the mock exposes no completion signal.
func (m *MockOrchestrator) Done() <-chan struct{} {
	return nil
}
// MockOrchestratorRegistry is an in-memory Registry implementation for tests.
type MockOrchestratorRegistry struct {
	lock          sync.RWMutex
	orchestrators map[string]Orchestrator // map of orchestrators
}

// compile-time check that the mock satisfies the Registry interface
var _ Registry = (*MockOrchestratorRegistry)(nil)

// NewMockOrchestratorRegistry creates a new mock orchestrator registry.
func NewMockOrchestratorRegistry() *MockOrchestratorRegistry {
	return &MockOrchestratorRegistry{
		orchestrators: make(map[string]Orchestrator),
	}
}
// NewOrchestrator creates and registers a MockOrchestrator under id.
// It returns ErrOrchestratorExists if the id is already registered.
func (m *MockOrchestratorRegistry) NewOrchestrator(
	ctx context.Context, fs afero.Afero, workDir string,
	id string, actor actor.Actor, cfg jtypes.EnsembleConfig,
	_ types.NodeIDGenerator, _ types.AllocationIDGenerator,
	_ *eventhandler.EventHandler, _ map[string]types.ContractConfig,
) (Orchestrator, error) {
	// Hold the write lock across the existence check and the insert so two
	// concurrent calls with the same id cannot both succeed. The previous
	// check-RUnlock-then-Lock sequence had a TOCTOU race that could
	// silently overwrite a concurrently registered orchestrator.
	m.lock.Lock()
	defer m.lock.Unlock()
	if _, ok := m.orchestrators[id]; ok {
		return nil, ErrOrchestratorExists
	}
	mo, err := NewMockOrchestrator(ctx, fs, workDir, id, actor, cfg)
	if err != nil {
		return nil, fmt.Errorf("failed to create orchestrator: %w", err)
	}
	m.orchestrators[id] = mo
	return mo, nil
}
// RestoreDeployment is a no-op in the mock registry; it returns (nil, nil).
func (m *MockOrchestratorRegistry) RestoreDeployment(
	_ context.Context, _ afero.Afero,
	_ actor.Actor, _ string, _ jtypes.EnsembleConfig,
	_ jtypes.EnsembleManifest, _ jtypes.DeploymentStatus,
	_ jtypes.DeploymentSnapshot,
	_ jtypes.SubnetManifest,
	_ types.AllocationIDGenerator,
) (Orchestrator, error) {
	return nil, nil
}
// Orchestrators returns a shallow copy of the registered orchestrators map.
func (m *MockOrchestratorRegistry) Orchestrators() map[string]Orchestrator {
	m.lock.RLock()
	defer m.lock.RUnlock()
	snapshot := make(map[string]Orchestrator, len(m.orchestrators))
	for key, value := range m.orchestrators {
		snapshot[key] = value
	}
	return snapshot
}
// GetOrchestrator looks up a registered orchestrator by id.
func (m *MockOrchestratorRegistry) GetOrchestrator(id string) (Orchestrator, error) {
	m.lock.RLock()
	defer m.lock.RUnlock()
	o, ok := m.orchestrators[id]
	if !ok {
		return nil, ErrOrchestratorNotFound
	}
	return o, nil
}
// DeleteOrchestrator is a no-op in the mock registry.
func (m *MockOrchestratorRegistry) DeleteOrchestrator(_ string) {}

// Methods for deployment persistence (mock implementations)

// SaveOrchestrator is a persistence no-op.
func (m *MockOrchestratorRegistry) SaveOrchestrator(_ Orchestrator) error {
	return nil
}

// GetAllDeployments returns no deployments (mock stub).
func (m *MockOrchestratorRegistry) GetAllDeployments() ([]*jtypes.OrchestratorView, error) {
	return nil, nil
}

// GetDeploymentsByStatus returns no deployments (mock stub).
func (m *MockOrchestratorRegistry) GetDeploymentsByStatus(_ jtypes.DeploymentStatus) ([]*jtypes.OrchestratorView, error) {
	return nil, nil
}

// QueryDeployments returns no results (mock stub).
func (m *MockOrchestratorRegistry) QueryDeployments(_ DeploymentQuery) ([]*jtypes.OrchestratorView, int, error) {
	return nil, 0, nil
}

// DeleteDeployment is a persistence no-op.
func (m *MockOrchestratorRegistry) DeleteDeployment(_ string) error {
	return nil
}
// GetDeployment converts the registered orchestrator with the given ID into
// an OrchestratorView, serializing its actor private key into the view.
func (m *MockOrchestratorRegistry) GetDeployment(orchestratorID string) (*jtypes.OrchestratorView, error) {
	m.lock.RLock()
	defer m.lock.RUnlock()

	orch, ok := m.orchestrators[orchestratorID]
	if !ok {
		return nil, ErrOrchestratorNotFound
	}

	// Serialize the actor key so the view is self-contained.
	keyBytes, err := crypto.PrivateKeyToBytes(orch.ActorPrivateKey())
	if err != nil {
		return nil, fmt.Errorf("failed to marshal private key: %w", err)
	}

	return &jtypes.OrchestratorView{
		OrchestratorID:     orch.ID(),
		Cfg:                orch.Config(),
		Manifest:           orch.Manifest(),
		SubnetManifest:     orch.SubnetManifest(),
		Status:             orch.Status(),
		DeploymentSnapshot: orch.DeploymentSnapshot(),
		PrivKey:            keyBytes,
	}, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package orchestrator
import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"sync"
"time"
"github.com/libp2p/go-libp2p/core/crypto"
"github.com/spf13/afero"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/metric"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/dms/behaviors"
jtypes "gitlab.com/nunet/device-management-service/dms/jobs/types"
"gitlab.com/nunet/device-management-service/lib/did"
"gitlab.com/nunet/device-management-service/lib/ucan"
"gitlab.com/nunet/device-management-service/observability"
"gitlab.com/nunet/device-management-service/tokenomics/eventhandler"
"gitlab.com/nunet/device-management-service/tokenomics/events"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils"
)
// keep as var instead of consts so that we change the values in tests
var (
	// Per-step timeouts for the deployment pipeline.
	BidRequestTimeout           = 5 * time.Second
	CommitDeploymentTimeout     = 3 * time.Second
	VerifyEdgeConstraintTimeout = 5 * time.Second
	AllocationDeploymentTimeout = 5 * time.Second
	// Setting a big timeout as the user might have to
	// download large execution images
	AllocationStartTimeout    = 5 * time.Minute
	AllocationShutdownTimeout = 5 * time.Second
	MinEnsembleDeploymentTime = 15 * time.Second
	MinEnsembleUpdateTimeout  = 15 * time.Second
	SubnetCreateTimeout       = 2 * time.Minute
	SubnetDestroyTimeout      = 10 * time.Second
	// Bid/search bounds. NOTE(review): exact semantics of
	// MaxBidMultiplier/MaxPermutations are defined by their users
	// elsewhere in this package — confirm before documenting further.
	MaxBidMultiplier = 8
	MaxPermutations  = 1_000_000
	// How often orchestrator capabilities are (re)granted.
	grantOrchestratorCapsFrequency = 5 * time.Minute
)
// Sentinel errors returned by the orchestrator; compare with errors.Is.
var (
	ErrProvisioningFailed   = errors.New("failed to provision the ensemble")
	ErrDeploymentFailed     = errors.New("failed to create deployment")
	ErrOrchestratorExists   = errors.New("orchestrator with ID already exists")
	ErrOrchestratorNotFound = errors.New("orchestrator with ID not found")
)
// Orchestrator is the interface for orchestrating deployments.
type Orchestrator interface {
	// Lifecycle: deploy/update an ensemble before expiry, shut it down
	// gracefully, or stop it immediately.
	Deploy(expiry time.Time) error
	Update(cfg jtypes.EnsembleConfig, expiry time.Time) error
	Shutdown() error
	Stop()
	// Allocation log access.
	GetAllocationLogs(allocationID string) (AllocationLogsResponse, error)
	WriteAllocationLogs(allocationID string, stdout, stderr []byte) (string, error)
	// Status observation: a live channel of status changes and a point read.
	StatusChannel(ctx context.Context) <-chan jtypes.DeploymentStatus
	Status() jtypes.DeploymentStatus
	// State accessors.
	Manifest() jtypes.EnsembleManifest
	SubnetManifest() jtypes.SubnetManifest
	Config() jtypes.EnsembleConfig
	ID() string
	ActorPrivateKey() crypto.PrivKey
	DeploymentSnapshot() jtypes.DeploymentSnapshot
	AllocationInfo() map[string]jtypes.AllocationInfo
	UpdateAllocationStatus()
	Done() <-chan struct{}
}
// BasicOrchestrator is the default Orchestrator implementation. It drives the
// bid/commit/provision pipeline, tracks the resulting manifest and status,
// and notifies status subscribers.
type BasicOrchestrator struct {
	lock               sync.Mutex // guards status, manifest, cfg, allocs, deploymentSnapshot
	ctx                context.Context
	cancel             func()
	fs                 afero.Afero
	workDir            string
	actor              actor.Actor
	id                 string
	cfg                jtypes.EnsembleConfig
	manifest           jtypes.EnsembleManifest
	subnetManifest     jtypes.SubnetManifest
	status             jtypes.DeploymentStatus
	allocs             map[string]jtypes.AllocationInfo
	deploymentSnapshot jtypes.DeploymentSnapshot
	supervisor         *Supervisor
	// ID generators
	nodeIDGenerator       types.NodeIDGenerator
	allocationIDGenerator types.AllocationIDGenerator
	// Status subscribers
	statusSubscribers     map[chan jtypes.DeploymentStatus]struct{}
	statusSubscribersLock sync.RWMutex // guards statusSubscribers
	contractEventHandler  *eventhandler.EventHandler
	contracts             map[string]types.ContractConfig
}

// compile-time check that BasicOrchestrator satisfies Orchestrator
var _ Orchestrator = (*BasicOrchestrator)(nil)
// NewOrchestrator creates a BasicOrchestrator for the given ensemble config.
// It validates the config and both ID generators, builds a fresh subnet
// manifest, derives a cancellable child context, and registers the
// orchestrator's actor behaviors. Returns an error if any of these steps fail.
func NewOrchestrator(
	ctx context.Context,
	fs afero.Afero,
	workDir string,
	id string,
	oActor actor.Actor,
	cfg jtypes.EnsembleConfig,
	nodeIDGenerator types.NodeIDGenerator,
	allocationIDGenerator types.AllocationIDGenerator,
	contractEventHandler *eventhandler.EventHandler,
	contracts map[string]types.ContractConfig,
) (*BasicOrchestrator, error) {
	if err := cfg.Validate(); err != nil {
		return nil, fmt.Errorf("failed to validate ensemble configuration: %w", err)
	}
	// Validate generators at instantiation time
	validator := types.NewDefaultGeneratorValidator()
	if err := validator.ValidateNodeIDGenerator(nodeIDGenerator); err != nil {
		return nil, fmt.Errorf("invalid node ID generator: %w", err)
	}
	if err := validator.ValidateAllocationIDGenerator(allocationIDGenerator); err != nil {
		return nil, fmt.Errorf("invalid allocation ID generator: %w", err)
	}
	subnet, err := newSubnetManifest()
	if err != nil {
		return nil, fmt.Errorf("failed to create subnet manifest: %w", err)
	}
	// Child context lets Stop() cancel everything the orchestrator spawned.
	childCtx, childCancel := context.WithCancel(ctx)
	o := &BasicOrchestrator{
		actor:                 oActor,
		id:                    id,
		cfg:                   cfg,
		ctx:                   childCtx,
		cancel:                childCancel,
		fs:                    fs,
		workDir:               workDir,
		subnetManifest:        subnet,
		allocs:                make(map[string]jtypes.AllocationInfo),
		supervisor:            NewSupervisor(childCtx, oActor, id),
		nodeIDGenerator:       nodeIDGenerator,
		allocationIDGenerator: allocationIDGenerator,
		statusSubscribers:     make(map[chan jtypes.DeploymentStatus]struct{}),
		contractEventHandler:  contractEventHandler,
		contracts:             contracts,
	}
	err = o.RegisterBehaviors()
	if err != nil {
		return nil, fmt.Errorf("failed to register behaviors: %w", err)
	}
	return o, nil
}
// SetStatus sets the deployment status (exported wrapper around setStatus).
func (o *BasicOrchestrator) SetStatus(status jtypes.DeploymentStatus) {
	o.setStatus(status)
}
// setStatus updates the deployment status, notifies subscribers on a
// change, closes all subscriber channels once a terminal state
// (failed/completed) is reached, and records a status metric.
func (o *BasicOrchestrator) setStatus(status jtypes.DeploymentStatus) {
	o.lock.Lock()
	defer o.lock.Unlock()
	log.Infow("orchestrator_status_updated",
		"labels", []string{string(observability.LabelDeployment)},
		"status", status.String(),
		"orchestratorID", o.id)
	oldStatus := o.status
	o.status = status
	if oldStatus != status {
		// Notify all subscribers; reading the set only needs the read
		// lock. Release it before the terminal-state handling below.
		o.statusSubscribersLock.RLock()
		for ch := range o.statusSubscribers {
			select {
			case ch <- status:
			default:
				// Skip if channel is blocked
			}
		}
		o.statusSubscribersLock.RUnlock()
	}
	// If we've reached a terminal state, close all subscriber channels and
	// reset the set. This mutates statusSubscribers, so it must run under
	// the write lock — the previous version did this under a read lock (or,
	// when the status was unchanged, under no subscriber lock at all),
	// which raced with StatusChannel's registration/cleanup.
	if status == jtypes.DeploymentStatusFailed ||
		status == jtypes.DeploymentStatusCompleted {
		o.statusSubscribersLock.Lock()
		for ch := range o.statusSubscribers {
			close(ch)
		}
		o.statusSubscribers = make(map[chan jtypes.DeploymentStatus]struct{})
		o.statusSubscribersLock.Unlock()
	}
	// metrics
	if m := observability.DeploymentStatus; m != nil {
		m.Add(o.ctx, 1, metric.WithAttributes(
			observability.AttrDID,
			attribute.String("orchestratorID", o.id),
			attribute.String("status", status.String()),
		))
	}
}
// StatusChannel returns a buffered channel that immediately receives the
// current status and then each subsequent status change, until ctx is done
// or the deployment reaches a terminal state (setStatus closes the channel
// then). The caller should drain or abandon the channel when ctx ends.
func (o *BasicOrchestrator) StatusChannel(ctx context.Context) <-chan jtypes.DeploymentStatus {
	ch := make(chan jtypes.DeploymentStatus, 1)
	// Send initial status
	select {
	case ch <- o.Status():
	case <-ctx.Done():
		close(ch)
		return ch
	}
	o.statusSubscribersLock.Lock()
	o.statusSubscribers[ch] = struct{}{}
	o.statusSubscribersLock.Unlock()
	// Clean up when context is done
	go func() {
		<-ctx.Done()
		o.statusSubscribersLock.Lock()
		delete(o.statusSubscribers, ch)
		o.statusSubscribersLock.Unlock()
		// Double-close guard: if setStatus already closed the channel on a
		// terminal state, the receive below succeeds immediately (closed
		// channels never block) and we skip closing it again. Only when a
		// receive would block (open, empty channel) do we close it here.
		select {
		case <-ch:
		default:
			close(ch)
		}
	}()
	return ch
}
// Deploy runs the full deployment pipeline for the configured ensemble,
// grants orchestrator capabilities to each allocation, starts the
// supervisor, and emits a DeploymentStart event per contract. If the
// deployment does not reach Running by the time Deploy returns, the status
// is forced to Failed.
func (o *BasicOrchestrator) Deploy(expiry time.Time) error {
	defer func() {
		// Read the status through the accessor so we don't race with
		// concurrent setStatus calls (the previous direct o.status read
		// was unsynchronized).
		if o.Status() != jtypes.DeploymentStatusRunning {
			o.setStatus(jtypes.DeploymentStatusFailed)
		}
	}()
	o.setStatus(jtypes.DeploymentStatusPreparing)
	log.Infow("initializing manifest",
		"labels", []string{string(observability.LabelDeployment)},
		"orchestratorID", o.id)
	// Build the initial manifest and publish it under the lock; other
	// goroutines read o.manifest through Manifest().
	initialManifest := o.newManifest(o.cfg)
	o.lock.Lock()
	o.manifest = initialManifest
	o.lock.Unlock()
	if err := o.deploy(o.cfg, initialManifest, expiry); err != nil {
		return fmt.Errorf("deploying ensemble: %w", err)
	}
	for _, a := range o.Manifest().Allocations {
		err := o.grantOrchestratorCaps(a.Handle.DID)
		if err != nil {
			return fmt.Errorf("failed to grant orchestrator capabilities to allocations: %w", err)
		}
	}
	log.Infow("deployment successful, starting supervisor",
		"labels", []string{string(observability.LabelDeployment)},
		"orchestratorID", o.id)
	// Snapshot the manifest value under the lock before handing it to the
	// supervisor goroutine and the contract events below.
	o.lock.Lock()
	currentManifest := o.manifest
	o.lock.Unlock()
	go o.supervisor.Supervise(jtypes.NewManifestReader(currentManifest))
	for _, v := range o.contracts {
		evt := events.DeploymentStart{
			EventBase:       events.EventBase{Type: events.DeploymentStartEvent},
			DeploymentID:    currentManifest.ID,
			OrchestratorID:  o.id,
			HeadContractDID: v.DID, // treat contrat as if head of contract chain, won't be taken into consideration in billing if contract is p2p
		}
		o.contractEventHandler.Push(eventhandler.Event{
			ContractHostDID: v.Host,
			ContractDID:     v.DID,
			Payload:         evt,
		})
	}
	return nil
}
// NewManifest builds a fresh ensemble manifest from cfg (exported wrapper
// around newManifest; it does not mutate the orchestrator's stored manifest).
func (o *BasicOrchestrator) NewManifest(
	cfg jtypes.EnsembleConfig,
) jtypes.EnsembleManifest {
	return o.newManifest(cfg)
}
// newManifest builds the initial ensemble manifest from cfg: contract
// entries, node entries (including generated standby node IDs for redundant
// nodes), and one pending allocation entry per "node.alloc" manifest key.
// Entries whose IDs cannot be generated are logged and skipped.
func (o *BasicOrchestrator) newManifest(
	cfg jtypes.EnsembleConfig,
) jtypes.EnsembleManifest {
	log.Debugw("creating new manifest",
		"labels", []string{string(observability.LabelDeployment)},
		"config", cfg.V1)
	manifest := jtypes.EnsembleManifest{
		ID:           o.id,
		Orchestrator: o.actor.Handle(),
		Metadata:     cfg.V1.Metadata,
		Allocations:  make(map[string]jtypes.AllocationManifest),
		Nodes:        make(map[string]jtypes.NodeManifest),
		Contracts:    make(map[string]jtypes.ContractManifest),
		Subnet:       cfg.V1.Subnet,
	}
	for name, v := range cfg.Contracts() {
		manifest.Contracts[name] = jtypes.ContractManifest{
			ID:   name,
			DID:  v.DID,
			Host: v.Host,
		}
	}
	for name, node := range cfg.NodesWithGenerator(o.nodeIDGenerator) {
		nodeAllocations := make([]string, 0)
		for _, allocName := range node.Allocations {
			_, ok := cfg.Allocation(allocName)
			if !ok {
				log.Errorw("allocation not found in ensemble config, skipping",
					"labels", []string{string(observability.LabelAllocation)},
					"allocation", allocName)
				continue
			}
			// Generate manifest key using generator
			allocKey, err := o.allocationIDGenerator.GenerateManifestKey(name, allocName)
			if err != nil {
				log.Errorf("failed to generate manifest key for %s.%s: %v", name, allocName, err)
				continue
			}
			nodeAllocations = append(nodeAllocations, allocKey)
		}
		// Generate standby node IDs, one per redundancy level.
		standbyNodes := make([]string, 0)
		if node.Redundancy > 0 {
			for i := 1; i <= node.Redundancy; i++ {
				standbyNodeID, err := o.nodeIDGenerator.GenerateStandbyNodeID(name, i)
				if err != nil {
					log.Errorf("failed to generate standby node ID for %s-%d: %v", name, i, err)
					continue
				}
				standbyNodes = append(standbyNodes, standbyNodeID)
			}
		}
		// Create primary node entry
		nmf := jtypes.NodeManifest{
			ID:           name,
			Allocations:  nodeAllocations,
			Peer:         node.Peer,
			StandbyNodes: standbyNodes,
		}
		manifest.Nodes[name] = nmf
	}
	// Now create allocation entries
	for nodeID, nodeManifest := range manifest.Nodes {
		for _, allocKey := range nodeManifest.Allocations {
			// Manifest keys are "<node>.<alloc>" (see GenerateManifestKey above).
			parts := strings.Split(allocKey, ".")
			if len(parts) != 2 {
				log.Errorf("invalid allocation key format: %s, skipping", allocKey)
				continue
			}
			configAllocName := parts[1]
			alloc, ok := cfg.Allocation(configAllocName)
			if !ok {
				log.Errorf("allocation %s not found in ensemble config, skipping", configAllocName)
				continue
			}
			isStandby := nodeManifest.RedundancyRole == jtypes.RoleStandby
			// Generate full allocation ID using generator
			fullAllocID, err := o.allocationIDGenerator.GenerateFullAllocationID(o.id, nodeID, configAllocName)
			if err != nil {
				log.Errorf("failed to generate full allocation ID for %s.%s: %v", nodeID, configAllocName, err)
				continue
			}
			amf := jtypes.AllocationManifest{
				ID:              fullAllocID,
				Type:            alloc.Type,
				NodeID:          nodeID,
				DNSName:         alloc.DNSName + ".internal",
				Healthcheck:     alloc.HealthCheck,
				Status:          jtypes.AllocationPending,
				Ports:           make(map[int]int),
				RedundancyGroup: configAllocName,
				IsStandby:       isStandby,
			}
			manifest.Allocations[allocKey] = amf
		}
	}
	return manifest
}
// invokeBehaviour sends a signed actor message carrying payload to
// destination's behavior endpoint and waits up to timeout for the reply.
func (o *BasicOrchestrator) invokeBehaviour(destination actor.Handle, behavior string, payload any, timeout time.Duration) (actor.Envelope, error) {
	msg, err := actor.Message(
		o.actor.Handle(),
		destination,
		behavior,
		payload,
		actor.WithMessageExpiry(actor.MakeExpiry(timeout)),
	)
	if err != nil {
		return actor.Envelope{}, fmt.Errorf("failed to create contract actor message: %w", err)
	}
	replyCh, err := o.actor.Invoke(msg)
	if err != nil {
		return actor.Envelope{}, fmt.Errorf("failed to invoke message: %w", err)
	}
	// One-shot timer instead of a ticker: only the first fire matters, and
	// a ticker would keep rescheduling until deferred Stop.
	timer := time.NewTimer(timeout)
	defer timer.Stop()
	select {
	case reply := <-replyCh:
		// NOTE(review): Discard runs when this function returns, i.e.
		// before the caller reads the returned envelope — confirm that
		// Discard only releases resources the caller no longer needs.
		defer reply.Discard()
		return reply, nil
	case <-timer.C:
		return actor.Envelope{}, errors.New("failed to receive reply due to timeout")
	}
}
// TODO (dynamic ensemble PR): documentation on how updates
// and revert handle manifest changes
//
// IMPORTANT: when passing the manifest and config down the stack,
// use the readers (`jobs/types/readers.go`) to guarantee the immutability
// of these objects. (that is not to solve race condition problems but
// to manage the state of the orchestrator in a safer way)
// deploy runs the bid → commit → provision pipeline, retrying the whole
// cycle until it succeeds or expiry passes. On commit/provision failure the
// partially committed nodes are reverted and the loop restarts; on success
// the allocation info table and success metrics are populated. Returns
// ErrDeploymentFailed when expiry is reached without a running deployment.
func (o *BasicOrchestrator) deploy(
	cfg jtypes.EnsembleConfig,
	partialManifest jtypes.EnsembleManifest,
	expiry time.Time,
) error {
	o.deploymentSnapshot.Expiry = expiry
deploy:
	for time.Now().Before(expiry) {
		o.setStatus(jtypes.DeploymentStatusPreparing)
		// delete old state of candidates if any
		for c := range o.deploymentSnapshot.Candidates {
			o.lock.Lock()
			delete(o.deploymentSnapshot.Candidates, c)
			o.lock.Unlock()
		}
		// 1. bid
		bidCoordinator, err := NewBidCoordinator(o.id, o.actor)
		if err != nil {
			return fmt.Errorf("failed to create bidder: %w", err)
		}
		candidateDeployment, err := bidCoordinator.bid(jtypes.NewEnsembleCfgReader(cfg), o.deploymentSnapshot.Candidates, expiry)
		if err != nil {
			if errors.Is(err, ErrCandidateNotFound) {
				log.Warnf("candidate deployment not found, redeploying: %v", err)
				continue deploy
			}
			log.Errorf("failed to bid: %v", err)
			return fmt.Errorf("failed to bid: %w", err)
		}
		// Convert any promise bids into firm bids before committing.
		for key, v := range candidateDeployment {
			if v.V1.PromiseBid {
				// wait for provisioning
				pb := jtypes.PromiseBidRequest{
					Bid: v,
				}
				envelope, err := o.invokeBehaviour(v.V1.Handle, behaviors.PromiseBidToBidBehavior, pb, time.Minute*5)
				if err != nil {
					log.Errorf("failed to convert promise bid: %v", err)
					return fmt.Errorf("failed to convert promise bid: %w", err)
				}
				var newBid jtypes.ConvertedPromiseBidResponse
				err = json.Unmarshal(envelope.Message, &newBid)
				if err != nil {
					log.Errorf("failed to unmarshal new bid: %v", err)
					return fmt.Errorf("failed to unmarshal new bid: %w", err)
				}
				// replace the current bid with the new bid
				candidateDeployment[key] = newBid.Bid
			}
		}
		// 2. Commit the deployment
		o.deploymentSnapshot.Candidates = candidateDeployment
		o.setStatus(jtypes.DeploymentStatusCommitting)
		committer := NewCommitter(o.ctx, o.id, o.actor, o.allocationIDGenerator, o.nodeIDGenerator)
		manifestAfterCommit, err := committer.commit(
			jtypes.NewEnsembleCfgReader(cfg),
			jtypes.NewManifestReader(partialManifest),
			candidateDeployment,
		)
		if err != nil {
			log.Warnw("failed to commit deployment",
				"labels", []string{string(observability.LabelDeployment)},
				"orchestratorID", o.id,
				"error", err)
			// Roll back every node that was already committed, then retry.
			for nodeName, n := range manifestAfterCommit.Nodes {
				o.revertNodeDeployment(cfg, nodeName, n.Handle)
			}
			continue deploy
		}
		o.updateManifest(manifestAfterCommit)
		log.Debugw("manifest after commit",
			"labels", []string{string(observability.LabelDeployment)},
			"manifest", manifestAfterCommit)
		// 3. provision the network and start the allocations
		o.setStatus(jtypes.DeploymentStatusProvisioning)
		provisioner := NewProvisioner(o.ctx, o.cancel, o.actor, o.subnetManifest, o.allocationIDGenerator)
		manifestAfterProvision, err := provisioner.Provision(
			jtypes.NewEnsembleCfgReader(cfg),
			jtypes.NewManifestReader(manifestAfterCommit))
		if err != nil {
			log.Errorw("provisioning failed",
				"labels", []string{string(observability.LabelDeployment)},
				"error", err,
				"orchestratorID", o.id)
			o.revert(cfg, manifestAfterCommit)
			continue deploy
		}
		// Task-only ensembles get a watcher that completes the deployment
		// once every task has terminated.
		go o.monitorOnlyTaskManifest()
		o.updateManifest(manifestAfterProvision)
		log.Infof("deployment successful")
		o.setStatus(jtypes.DeploymentStatusRunning)
		// Populate per-allocation info and tally the assigned resources.
		var allocated types.Resources
		for idx, a := range o.Manifest().Allocations {
			res := o.Config().V1.Allocations[a.ID].Resources
			o.allocs[a.ID] = jtypes.AllocationInfo{
				AllocationID:   a.ID,
				HeartbeatSeq:   0,
				Status:         a.Status,
				HasHealthCheck: len(a.Healthcheck.Exec) != 0 && a.Healthcheck.Type != "",
				ResourceLimit:  res,
				DNSName:        a.DNSName,
				IP:             o.SubnetManifest().IndexRoutingTable[idx],
				ResourceUsage:  jtypes.AllocationResourceUsage{},
				Timestamp:      time.Now().Unix(),
			}
			_ = allocated.RAM.Add(res.RAM)
			_ = allocated.Disk.Add(res.Disk)
			_ = allocated.CPU.Add(res.CPU)
			_ = allocated.GPUs.Add(res.GPUs)
		}
		// metric
		if m := observability.DeploymentSuccess; m != nil {
			m.Add(o.ctx, 1, metric.WithAttributes(
				observability.AttrDID,
				// attribute.Int("allocations", len(o.Manifest().Allocations)),
			))
			if m := observability.DeploySuccessAllocations; m != nil {
				m.Record(o.ctx, int64(len(o.Manifest().Allocations)), metric.WithAttributes(
					observability.AttrDID,
					attribute.String("orchestratorID", o.id),
				))
			}
			if m := observability.DeploySuccessCPUCoresAssigned; m != nil {
				m.Record(o.ctx, float64(allocated.CPU.Cores), metric.WithAttributes(
					observability.AttrDID,
					attribute.String("orchestratorID", o.id),
				))
			}
			if m := observability.DeploySuccessRAMGBAssigned; m != nil {
				m.Record(o.ctx, int64(allocated.RAM.SizeInGB()), metric.WithAttributes(
					observability.AttrDID,
					attribute.String("orchestratorID", o.id),
				))
			}
			if m := observability.DeploySuccessDiskMBAssigned; m != nil {
				m.Record(o.ctx, float64(allocated.Disk.Size/(1024.0*1024.0)), metric.WithAttributes(
					observability.AttrDID,
					attribute.String("orchestratorID", o.id),
				))
			}
			if m := observability.DeploySuccessGPUCountAssigned; m != nil {
				m.Record(o.ctx, int64(len(allocated.GPUs)), metric.WithAttributes(
					observability.AttrDID,
					attribute.String("orchestratorID", o.id),
				))
			}
		}
		return nil
	}
	// we failed to create the deployment in time
	log.Errorw("deployment creation timed out",
		"labels", []string{string(observability.LabelDeployment)},
		"orchestratorID", o.id)
	return ErrDeploymentFailed
}
// Stop stops the orchestrator: it cancels the orchestrator's context and
// stops the underlying actor, logging (but not returning) any actor error.
func (o *BasicOrchestrator) Stop() {
	// TODO
	o.cancel()
	if err := o.actor.Stop(); err != nil {
		log.Warnf("error stopping orchestrator's actor: %s", err)
	}
}
// AllocationLogsRequest asks a node for the logs of one named allocation.
type AllocationLogsRequest struct {
	AllocName string
}

// AllocationLogsResponse carries an allocation's stdout/stderr; Error is
// non-empty when the remote side failed to collect the logs.
type AllocationLogsResponse struct {
	Stdout []byte
	Stderr []byte
	Error  string
}
// GetAllocationLogs fetches the stdout/stderr logs for the named allocation
// by messaging the node that hosts it, waiting up to two minutes for the
// reply. It returns an error if no node hosts the allocation, the message
// exchange fails, or the remote side reports an error.
func (o *BasicOrchestrator) GetAllocationLogs(name string) (AllocationLogsResponse, error) {
	var logsResp AllocationLogsResponse
	// Snapshot the manifest under the lock; the previous direct reads of
	// o.manifest.Nodes and o.manifest.ID raced with updateManifest.
	manifest := o.Manifest()
	var allocNodeHandle actor.Handle
	for _, n := range manifest.Nodes {
		if ok := utils.SliceContains(n.Allocations, name); ok {
			allocNodeHandle = n.Handle
			break
		}
	}
	if allocNodeHandle.Empty() {
		return logsResp,
			fmt.Errorf(
				"node not found for allocation %s of ensemble %s",
				name, o.id,
			)
	}
	msg, err := actor.Message(
		o.actor.Handle(),
		allocNodeHandle,
		fmt.Sprintf(behaviors.AllocationLogsBehavior.DynamicTemplate, manifest.ID),
		AllocationLogsRequest{
			AllocName: name,
		},
		actor.WithMessageExpiry(uint64(time.Now().Add(2*time.Minute).UnixNano())),
	)
	if err != nil {
		return logsResp, fmt.Errorf("creating get logs message: %w", err)
	}
	replyCh, err := o.actor.Invoke(msg)
	if err != nil {
		return logsResp, fmt.Errorf("invoking get logs message: %w", err)
	}
	var reply actor.Envelope
	select {
	case reply = <-replyCh:
	case <-time.After(2 * time.Minute):
		return logsResp, fmt.Errorf("timeout getting logs for %s: %w", name, ErrDeploymentFailed)
	}
	defer reply.Discard()
	if err := json.Unmarshal(reply.Message, &logsResp); err != nil {
		return logsResp, fmt.Errorf("unmarshalling get logs response: %w", err)
	}
	if logsResp.Error != "" {
		return logsResp, fmt.Errorf("replied with error getting logs for %s: %s", name, logsResp.Error)
	}
	return logsResp, nil
}
// Status returns the current deployment status.
func (o *BasicOrchestrator) Status() jtypes.DeploymentStatus {
	o.lock.Lock()
	defer o.lock.Unlock()
	return o.status
}

// Manifest returns a deep clone of the current ensemble manifest.
func (o *BasicOrchestrator) Manifest() jtypes.EnsembleManifest {
	o.lock.Lock()
	defer o.lock.Unlock()
	return o.manifest.Clone()
}

// SubnetManifest returns the orchestrator's subnet manifest.
func (o *BasicOrchestrator) SubnetManifest() jtypes.SubnetManifest {
	o.lock.Lock()
	defer o.lock.Unlock()
	return o.subnetManifest
}
// ManifestNodesPeerIDs returns the peer IDs of all nodes in the manifest.
func (o *BasicOrchestrator) ManifestNodesPeerIDs() []string {
	o.lock.Lock()
	defer o.lock.Unlock()
	// Pre-size capacity only: the previous make([]string, n) + append
	// produced a slice of 2n entries with n leading empty strings.
	ids := make([]string, 0, len(o.manifest.Nodes))
	for _, n := range o.manifest.Nodes {
		ids = append(ids, n.Peer)
	}
	return ids
}
// Config returns a deep copy of the ensemble configuration so callers cannot
// mutate the orchestrator's internal state.
func (o *BasicOrchestrator) Config() jtypes.EnsembleConfig {
	o.lock.Lock()
	cfg := o.cfg.Clone()
	o.lock.Unlock()
	return cfg
}
// ID returns the orchestrator's (ensemble) identifier. The field is set at
// construction and never reassigned in this file, so no locking is taken.
func (o *BasicOrchestrator) ID() string {
	return o.id
}
// ActorPrivateKey returns the private key of the orchestrator's actor
// security context (used e.g. when persisting the orchestrator view).
func (o *BasicOrchestrator) ActorPrivateKey() crypto.PrivKey {
	return o.actor.Security().PrivKey()
}
// DeploymentSnapshot returns the last recorded deployment snapshot, read
// under the orchestrator lock.
func (o *BasicOrchestrator) DeploymentSnapshot() jtypes.DeploymentSnapshot {
	o.lock.Lock()
	snap := o.deploymentSnapshot
	o.lock.Unlock()
	return snap
}
// updateManifest replaces the orchestrator's manifest with a clone of m and
// then refreshes the cached per-allocation statuses.
func (o *BasicOrchestrator) updateManifest(m jtypes.EnsembleManifest) {
	// cloning since the orchestrator original manifest state
	// might inherit map references of partial updates
	cloned := m.Clone()
	o.lock.Lock()
	o.manifest = cloned
	o.lock.Unlock()
	o.UpdateAllocationStatus()
}
// monitorOnlyTaskManifest will be responsible for tearing down
// the orchestrator after all tasks are terminated when
// the ensemble is composed *ONLY* by tasks
func (o *BasicOrchestrator) monitorOnlyTaskManifest() {
	if !isOnlyTaskManifest(o.manifest) {
		return
	}
	ticker := time.NewTicker(monitorOnlyTaskManifestInterval)
	defer ticker.Stop()
	for {
		select {
		case <-o.ctx.Done():
			return
		case <-ticker.C:
			// Check, under the lock, whether every allocation has reached a
			// terminated-task state.
			o.lock.Lock()
			terminated := true
			for name := range o.manifest.Allocations {
				if !o.manifest.IsTerminatedTask(name) {
					terminated = false
					break
				}
			}
			o.lock.Unlock()
			if terminated {
				log.Infof("All tasks are terminated, shutting down orchestrator.")
				o.setStatus(jtypes.DeploymentStatusCompleted)
				o.cancel()
				return
			}
		}
	}
}
// isOnlyTaskManifest reports whether every allocation in the manifest is of
// type task (an empty allocation map also counts as task-only).
func isOnlyTaskManifest(m jtypes.EnsembleManifest) bool {
	onlyTasks := true
	for _, alloc := range m.Allocations {
		if alloc.Type != jtypes.AllocationTypeTask {
			onlyTasks = false
			break
		}
	}
	return onlyTasks
}
// AllocationInfo returns a shallow copy of the per-allocation info map so the
// caller cannot mutate the orchestrator's internal bookkeeping.
func (o *BasicOrchestrator) AllocationInfo() map[string]jtypes.AllocationInfo {
	o.lock.Lock()
	defer o.lock.Unlock()
	out := make(map[string]jtypes.AllocationInfo, len(o.allocs))
	for name, info := range o.allocs {
		out[name] = info
	}
	return out
}
// UpdateAllocationStatus synchronizes the cached per-allocation info map
// (o.allocs) with the statuses recorded in the current manifest. Allocations
// present in the manifest but unknown to the cache are skipped.
func (o *BasicOrchestrator) UpdateAllocationStatus() {
	// Manifest() takes the lock internally, so it must be called before we
	// acquire the lock below.
	manifest := o.Manifest()
	if manifest.Allocations == nil {
		return
	}
	o.lock.Lock()
	defer o.lock.Unlock()
	for _, a := range manifest.Allocations {
		// Single map lookup instead of an existence check followed by a
		// second read of the same key.
		allocInfo, ok := o.allocs[a.ID]
		if !ok {
			continue
		}
		allocInfo.Status = a.Status
		o.allocs[a.ID] = allocInfo
	}
}
// RegisterBehaviors wires the orchestrator's notification handlers (task
// termination, allocation liveness, allocation status updates) into its
// actor, failing on the first behavior that cannot be added.
func (o *BasicOrchestrator) RegisterBehaviors() error {
	handlers := map[string]func(actor.Envelope){
		behaviors.NotifyTaskTerminationBehavior:    o.handleTaskTermination,
		behaviors.NotifyAllocationLivenessBehavior: o.handleAllocationLiveness,
		behaviors.NotifyAllocationStatusBehavior:   o.handleAllocationStatusUpdate,
	}
	for behavior, handler := range handlers {
		if err := o.actor.AddBehavior(behavior, handler); err != nil {
			return fmt.Errorf("add behavior %s to orchestrator actor: %w", behavior, err)
		}
	}
	return nil
}
// grantOrchestratorCaps grants the orchestrator capability namespace to the
// given allocation DID and keeps the grant alive by re-issuing it every
// grantOrchestratorCapsFrequency until the orchestrator context is cancelled.
func (o *BasicOrchestrator) grantOrchestratorCaps(alloc did.DID) error {
	log.Infow("granting alloc capabilities",
		"orchestratorID", o.id,
		"allocationDID", alloc.String(),
	)
	oDID, err := did.FromID(o.actor.Handle().ID)
	if err != nil {
		return fmt.Errorf("failed to parse orchestrator DID: %w", err)
	}
	caps := []ucan.Capability{behaviors.OrchestratorNamespace}
	err = o.actor.Security().Grant(
		alloc,
		oDID,
		caps,
		grantOrchestratorCapsFrequency,
	)
	if err != nil {
		return fmt.Errorf(
			"granting orchestrator caps to alloc %s: %w",
			alloc.String(), err)
	}
	// TODO: create helper func to periodically grant caps as
	// it's being used here and on createAllocations()
	go func() {
		ticker := time.NewTicker(grantOrchestratorCapsFrequency)
		defer ticker.Stop()
		// Loop around the select: the previous version returned after the
		// first tick, so the grant was renewed exactly once and then expired.
		for {
			select {
			case <-o.ctx.Done():
				return
			case <-ticker.C:
				// Re-grant the same capability set with the same issuer DID
				// as the initial grant (previously an empty capability list
				// was granted here, which renewed nothing).
				err := o.actor.Security().Grant(
					alloc,
					oDID,
					caps,
					grantOrchestratorCapsFrequency,
				)
				if err != nil {
					// %v, not %w: the logger formats via Sprintf-style
					// formatting, which does not understand the %w verb.
					log.Errorf(
						"periodic grant orchestrator caps to alloc %s: %v",
						alloc.String(), err)
				}
			}
		}
	}()
	return nil
}
// Done returns a channel that is closed when the orchestrator's context is
// cancelled, i.e. when the orchestrator shuts down.
func (o *BasicOrchestrator) Done() <-chan struct{} {
	return o.ctx.Done()
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package orchestrator
import (
"context"
"encoding/json"
"fmt"
"slices"
"strings"
"sync"
"time"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/dms/behaviors"
jtypes "gitlab.com/nunet/device-management-service/dms/jobs/types"
netutils "gitlab.com/nunet/device-management-service/network/utils"
"gitlab.com/nunet/device-management-service/observability"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/metric"
)
// orchSubnetName is the reserved routing-table/DNS key under which the
// orchestrator itself is registered when it joins the ensemble subnet.
const orchSubnetName = "orchestrator"

// monitorOnlyTaskManifestInterval is how often a task-only ensemble is polled
// for completion; declared as a var (not const) so tests can shorten it.
var monitorOnlyTaskManifestInterval = time.Second * 10
// Provisioner handles the provisioning process for ensemble deployment
type Provisioner struct {
	ctx                   context.Context             // lifetime of provisioning operations
	cancel                context.CancelFunc          // cancels ctx (shared with the owning deployment)
	actor                 actor.Actor                 // actor used to message nodes and allocations
	subnetManifest        jtypes.SubnetManifest       // subnet state being built (IPs, routes, DNS)
	allocationIDGenerator types.AllocationIDGenerator // derives per-node manifest keys for allocations
	lock                  sync.Mutex                  // guards subnetManifest (see addAllocationToSubnet)
}
// NewProvisioner creates a new Provisioner instance
func NewProvisioner(
	ctx context.Context,
	cancel context.CancelFunc,
	actor actor.Actor,
	subnetManifest jtypes.SubnetManifest,
	allocationIDGenerator types.AllocationIDGenerator,
) *Provisioner {
	p := &Provisioner{
		ctx:                   ctx,
		cancel:                cancel,
		actor:                 actor,
		subnetManifest:        subnetManifest,
		allocationIDGenerator: allocationIDGenerator,
	}
	return p
}
// Provision handles the provisioning process and returns the updated manifest
func (p *Provisioner) Provision(
	cfgReader jtypes.EnsembleCfgReader,
	manifestReader jtypes.ManifestReader,
) (jtypes.EnsembleManifest, error) {
	cfg := cfgReader.Read()
	manifest := manifestReader.Read()
	log.Infow("provisioning ensemble manifest",
		"labels", []string{string(observability.LabelDeployment)},
		"orchestratorID", manifest.ID,
	)
	// Step 1: carve out the subnet and wire every allocation into it.
	provisioned, err := p.provisionSubnet(manifest)
	if err != nil {
		return provisioned, fmt.Errorf("provisioning subnet: %w", err)
	}
	// Step 2: start the allocations in dependency order.
	provisioned, err = p.provisionAllocations(cfg, provisioned)
	if err != nil {
		return provisioned, fmt.Errorf("provisioning allocations: %w", err)
	}
	return provisioned, nil
}
// provisionSubnet assigns a subnet IP to every allocation in the manifest,
// requests subnet creation on each participating node (minus those listed in
// skipCreate), optionally joins the orchestrator itself to the subnet, and
// finally configures peers, DNS records and port mappings.
func (p *Provisioner) provisionSubnet(manifest jtypes.EnsembleManifest, skipCreate ...string) (jtypes.EnsembleManifest, error) {
	for allocManifestKey := range manifest.Allocations {
		err := p.addAllocationToSubnet(manifest, allocManifestKey)
		if err != nil {
			return manifest,
				fmt.Errorf("error adding allocation %s to subnet: %w", allocManifestKey, err)
		}
		err = manifest.UpdateAllocation(allocManifestKey, func(alloc *jtypes.AllocationManifest) {
			alloc.PrivAddr = p.subnetManifest.IndexRoutingTable[allocManifestKey]
		})
		if err != nil {
			return manifest, fmt.Errorf("error updating allocation %s: %w", allocManifestKey, err)
		}
	}
	// handles to request subnetcreate
	subCreateHandles := []actor.Handle{}
	// subnet config requests (add peer, dns, port map)
	subReqs := []subnetRequest{}
	for _, nodeManifest := range manifest.Nodes {
		if !slices.Contains(skipCreate, nodeManifest.ID) {
			subCreateHandles = append(subCreateHandles, nodeManifest.Handle)
		}
		// Use the structured logger instead of the stray fmt.Println debug
		// prints the previous version emitted to stdout.
		log.Debugf("allocations in manifest.Nodes: %v", nodeManifest.Allocations)
		for _, allocID := range nodeManifest.Allocations {
			allocManifest, ok := manifest.Allocations[allocID]
			if !ok {
				log.Warnf("provisioning subnet: allocation %s not found in manifest, skipping", allocID)
				continue
			}
			log.Debugf("ip %s allocID %s", p.subnetManifest.IndexRoutingTable[allocID], allocID)
			subReqs = append(subReqs, subnetRequest{
				handle: allocManifest.Handle,
				ip:     p.subnetManifest.IndexRoutingTable[allocID],
				peerID: manifest.Nodes[allocManifest.NodeID].Peer,
				ports:  allocManifest.Ports,
			})
		}
	}
	if manifest.Subnet.Join { // orchestrator should join the subnet
		if _, ok := p.subnetManifest.IndexRoutingTable[orchSubnetName]; !ok {
			ip, err := netutils.GetNextIP(p.subnetManifest.CIDR, p.subnetManifest.UsedIPs)
			if err != nil {
				return manifest, fmt.Errorf("error getting next IP: %w", err)
			}
			// Log after the error check (ip is only valid on success) and use
			// Debugf so the format verb is actually interpolated; the old
			// log.Debug call printed the verb literally.
			log.Debugf("Generated IP %s for orchestrator", ip)
			p.subnetManifest.RoutingTable[ip.String()] = p.actor.Handle().Address.HostID
			p.subnetManifest.IndexRoutingTable[orchSubnetName] = ip.String()
			p.subnetManifest.UsedIPs[ip.String()] = true
			subCreateHandles = append(subCreateHandles, p.actor.Supervisor())
			p.subnetManifest.DNSRecords[orchSubnetName] = p.subnetManifest.IndexRoutingTable[orchSubnetName]
		}
	}
	err := p.createSubnet(manifest.ID, p.subnetManifest.RoutingTable, subCreateHandles)
	if err != nil {
		return manifest, fmt.Errorf("error creating subnet: %w", err)
	}
	// if orchestrator should join subnet, setup with one behavior
	// this doesn't look very good but let's address with #893
	if manifest.Subnet.Join {
		err := p.orchestratorJoinSubnet(manifest.ID, p.subnetManifest.IndexRoutingTable, p.subnetManifest.RoutingTable, p.subnetManifest.DNSRecords)
		if err != nil {
			return manifest, fmt.Errorf("error joining subnet: %w", err)
		}
	}
	// 1.b create and plug IPs
	log.Infow("adding peers to subnet",
		"labels", []string{string(observability.LabelDeployment)},
		"peerCount", len(subReqs),
		"manifestID", manifest.ID,
	)
	for i, req := range subReqs {
		log.Infow("subnet request details",
			"labels", []string{string(observability.LabelDeployment)},
			"requestIndex", i,
			"handle", req.handle,
			"ip", req.ip,
			"peerID", req.peerID,
		)
	}
	err = p.subnetAddPeer(manifest.ID, subReqs)
	if err != nil {
		return manifest, fmt.Errorf("error adding peers to subnet: %w", err)
	}
	// 1.c configure DNS
	err = p.addDNSRecords(manifest.ID, subReqs, p.subnetManifest.DNSRecords)
	if err != nil {
		return manifest, fmt.Errorf("error adding dns records to subnet: %w", err)
	}
	// 1.d configure port mapping
	err = p.mapPorts(manifest.ID, subReqs)
	if err != nil {
		return manifest, fmt.Errorf("error adding port mappings to subnet: %w", err)
	}
	return manifest, nil
}
// provisionAllocations starts every allocation on its target node in
// dependency order (batches produced by orderByDependency) and records the
// resulting status in the returned manifest and in the emitted metrics.
// Standby nodes are skipped entirely; allocations hosted on a standby node
// or flagged IsStandby are started in standby mode.
//
// Fixes over the previous version:
//   - allocStatuses was mutated by the start goroutines without any
//     synchronization (a data race); it is now guarded by a mutex.
//   - the status-propagation loop ran twice (once via direct map writes,
//     once via UpdateAllocation) with a stray second wg.Wait(); the
//     duplicate was removed.
//   - errors from UpdateAllocation were pushed into errCh from a fresh
//     goroutine that could race with close(errCh) and panic with a send on
//     a closed channel; they are now folded into the aggregate error
//     synchronously.
func (p *Provisioner) provisionAllocations(
	cfg jtypes.EnsembleConfig, manifest jtypes.EnsembleManifest,
) (jtypes.EnsembleManifest, error) {
	interim := map[string][]string{} // a map of verteces to edges (their dependencies)
	for allocName, allocCfg := range cfg.Allocations() {
		interim[allocName] = allocCfg.DependsOn
	}
	orderedAllocs, err := orderByDependency(interim)
	if err != nil {
		return manifest, err
	}
	// allocStatuses is written from concurrent goroutines; statusMx guards it.
	var statusMx sync.Mutex
	allocStatuses := make(map[string]jtypes.AllocationStatus)
	recordStatus := func(key string, s jtypes.AllocationStatus) {
		statusMx.Lock()
		allocStatuses[key] = s
		statusMx.Unlock()
	}
	for nodeKey, nodeManifest := range manifest.Nodes {
		if nodeManifest.RedundancyRole == jtypes.RoleStandby {
			continue // skip standby nodes' allocations
		}
		// Start each dependency batch concurrently, then wait before moving
		// on to the next batch.
		for _, allocs := range orderedAllocs {
			var wg sync.WaitGroup
			errCh := make(chan error, len(allocs))
			for _, allocName := range allocs {
				allocKey, err := p.allocationIDGenerator.GenerateManifestKey(nodeKey, allocName)
				if err != nil {
					log.Errorf("failed to generate manifest key for %s.%s: %v", nodeKey, allocName, err)
					continue
				}
				statusMetric := func(status jtypes.AllocationStatus) {
					// metric
					if m := observability.AllocationStatus; m != nil {
						m.Record(p.ctx, 1, metric.WithAttributes(
							observability.AttrDID,
							attribute.String("orchestratorID", manifest.ID),
							attribute.String("allocationID", allocKey),
							attribute.String("status", string(status)),
						))
					}
				}
				// TODO: this is a temporary hack, we need to find better ways to handle this
				rawNodeAllocations := make([]string, 0)
				for _, alloc := range nodeManifest.Allocations {
					parts := strings.Split(alloc, ".")
					if len(parts) != 2 {
						log.Errorf("invalid allocation key format: %s, skipping", alloc)
						continue
					}
					rawNodeAllocations = append(rawNodeAllocations, parts[1])
				}
				if !utils.SliceContains(rawNodeAllocations, allocName) {
					log.Debugf("skipping allocation %s because it's not on node %s", allocName, nodeKey)
					continue
				}
				wg.Add(1)
				go func(allocManifest jtypes.AllocationManifest, allocKey string) {
					defer wg.Done()
					// Determine if this is a standby allocation
					isStandby := allocManifest.IsStandby
					nodeManifest := manifest.Nodes[allocManifest.NodeID]
					if nodeManifest.RedundancyRole == jtypes.RoleStandby {
						isStandby = true
					}
					msg, err := actor.Message(
						p.actor.Handle(),
						allocManifest.Handle,
						behaviors.AllocationStartBehavior,
						behaviors.AllocationStartRequest{
							SubnetIP:    p.subnetManifest.IndexRoutingTable[allocKey],
							GatewayIP:   p.subnetManifest.GatewayIP,
							PortMapping: allocManifest.Ports,
						},
						actor.WithMessageExpiry(actor.MakeExpiry(AllocationStartTimeout)),
					)
					if err != nil {
						errCh <- fmt.Errorf("error creating allocation start message: %w", err)
						return
					}
					replyCh, err := p.actor.Invoke(msg)
					if err != nil {
						errCh <- fmt.Errorf("error invoking allocation start: %w", err)
						return
					}
					ticker := time.NewTicker(AllocationStartTimeout)
					defer ticker.Stop()
					var reply actor.Envelope
					select {
					case reply = <-replyCh:
						defer reply.Discard()
						var response behaviors.AllocationStartResponse
						if err := json.Unmarshal(reply.Message, &response); err != nil {
							errCh <- fmt.Errorf("error unmarshalling allocation start response: %w", err)
							return
						}
						if !response.OK {
							recordStatus(allocKey, jtypes.AllocationFailed)
							statusMetric(jtypes.AllocationFailed)
							errCh <- fmt.Errorf("error starting allocation: %s: %w", response.Error, ErrDeploymentFailed)
							return
						}
					case <-ticker.C:
						errCh <- fmt.Errorf("timeout starting allocation: %w", ErrDeploymentFailed)
						return
					}
					statusMsg := "started"
					status := jtypes.AllocationRunning
					if isStandby {
						statusMsg = "prepared in standby mode"
						status = jtypes.AllocationStandby
					}
					recordStatus(allocKey, status)
					statusMetric(status)
					log.Infow("allocation successfully started on peer",
						"labels", []string{string(observability.LabelDeployment)},
						"orchestratorID", manifest.ID,
						"peerID", nodeManifest.Peer,
						"status", statusMsg,
						"handle", allocManifest.Handle)
				}(manifest.Allocations[allocKey], allocKey)
			}
			wg.Wait()
			close(errCh)
			aggErr := aggregateErrors(errCh)
			// Propagate the recorded statuses into the manifest in a single
			// pass; UpdateAllocation reports unknown allocation keys.
			statusMx.Lock()
			for allocKey, status := range allocStatuses {
				status := status
				if err := manifest.UpdateAllocation(allocKey, func(alloc *jtypes.AllocationManifest) {
					alloc.Status = status
				}); err != nil {
					log.Warnf("error updating allocation status: %s", err)
					if aggErr == nil {
						aggErr = err
					} else {
						aggErr = fmt.Errorf("%w; %w", aggErr, err)
					}
				}
			}
			statusMx.Unlock()
			if aggErr != nil {
				return manifest, aggErr
			}
		}
	}
	return manifest, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package orchestrator
import (
"context"
"fmt"
"sync"
"time"
"github.com/libp2p/go-libp2p/core/crypto"
"github.com/spf13/afero"
"gitlab.com/nunet/device-management-service/actor"
jtypes "gitlab.com/nunet/device-management-service/dms/jobs/types"
"gitlab.com/nunet/device-management-service/observability"
"gitlab.com/nunet/device-management-service/tokenomics/eventhandler"
"gitlab.com/nunet/device-management-service/types"
)
// Registry is an interface which acts as a source of orchestrators.
// Implementations combine an in-memory view of live orchestrators with a
// persistent deployment store.
type Registry interface {
	// NewOrchestrator creates a new orchestrator
	NewOrchestrator(
		ctx context.Context, fs afero.Afero, workDir string,
		id string, actor actor.Actor, cfg jtypes.EnsembleConfig,
		nodeIDGenerator types.NodeIDGenerator, allocationIDGenerator types.AllocationIDGenerator,
		contractEventHandler *eventhandler.EventHandler,
		contracts map[string]types.ContractConfig,
	) (Orchestrator, error)
	// RestoreDeployment restores deployments where the status is either provisioning, committing or running
	RestoreDeployment(
		ctx context.Context, fs afero.Afero,
		actr actor.Actor, id string, cfg jtypes.EnsembleConfig,
		manifest jtypes.EnsembleManifest, status jtypes.DeploymentStatus,
		restoreInfo jtypes.DeploymentSnapshot,
		subnetManifest jtypes.SubnetManifest,
		allocationIDGenerator types.AllocationIDGenerator,
	) (Orchestrator, error)
	// Orchestrators returns a map of all orchestrators
	Orchestrators() map[string]Orchestrator
	// GetOrchestrator returns an orchestrator by ID
	GetOrchestrator(id string) (Orchestrator, error)
	// DeleteOrchestrator deletes an orchestrator by ID
	DeleteOrchestrator(id string)
	// Methods for deployment persistence
	// SaveOrchestrator persists orchestrator state to store
	SaveOrchestrator(orchestrator Orchestrator) error
	// GetAllDeployments retrieves all deployments from store
	GetAllDeployments() ([]*jtypes.OrchestratorView, error)
	// GetDeploymentsByStatus retrieves deployments filtered by status
	GetDeploymentsByStatus(status jtypes.DeploymentStatus) ([]*jtypes.OrchestratorView, error)
	// QueryDeployments retrieves deployments with advanced filtering
	QueryDeployments(query DeploymentQuery) ([]*jtypes.OrchestratorView, int, error)
	// DeleteDeployment removes a specific deployment by orchestrator ID
	DeleteDeployment(orchestratorID string) error
	// GetDeployment retrieves a deployment from store by ID
	GetDeployment(orchestratorID string) (*jtypes.OrchestratorView, error)
}
// basicRegistry the default implementation of Registry.
// It keeps a mutex-guarded in-memory cache of live orchestrators and writes
// through to a persistent DeploymentStore.
type basicRegistry struct {
	lock          sync.RWMutex
	orchestrators map[string]Orchestrator // map of orchestrators (in-memory cache)
	store         DeploymentStore         // persistent store
}

// Compile-time check that basicRegistry satisfies Registry.
var _ Registry = (*basicRegistry)(nil)
// NewRegistry creates a new orchestrator registry
func NewRegistry(store DeploymentStore) Registry {
	r := &basicRegistry{
		store:         store,
		orchestrators: map[string]Orchestrator{},
	}
	return r
}
// NewOrchestrator creates a new orchestrator
// TODO-style: NewOrchestrator calls NewOrchestrator, that is confusing
func (f *basicRegistry) NewOrchestrator(
	ctx context.Context, fs afero.Afero, workDir string,
	id string, actor actor.Actor, cfg jtypes.EnsembleConfig,
	nodeIDGenerator types.NodeIDGenerator, allocationIDGenerator types.AllocationIDGenerator,
	contractEventHandler *eventhandler.EventHandler,
	contracts map[string]types.ContractConfig,
) (Orchestrator, error) {
	// Refuse duplicates: a store Get that succeeds means the ID is taken.
	if _, err := f.store.Get(id); err == nil {
		return nil, ErrOrchestratorExists
	}
	// Delegate to the package-level constructor, which derives a new context.
	orch, err := NewOrchestrator(ctx, fs, workDir, id, actor, cfg, nodeIDGenerator, allocationIDGenerator, contractEventHandler, contracts)
	if err != nil {
		return nil, fmt.Errorf("failed to create orchestrator: %w", err)
	}
	// Seed the manifest with metadata before the first save so the persisted
	// view is complete.
	orch.manifest = orch.newManifest(cfg)
	if err := f.SaveOrchestrator(orch); err != nil {
		return nil, fmt.Errorf("failed to save orchestrator: %w", err)
	}
	f.lock.Lock()
	defer f.lock.Unlock()
	f.orchestrators[id] = orch
	// save deployment on status change, use the orchestrator's context
	f.saveDeploymentOnStatusChange(orch.ctx, orch)
	return orch, nil
}
// saveDeploymentOnStatusChange spawns a goroutine that persists the
// orchestrator whenever its deployment status transitions to a new value,
// stopping when ctx is cancelled.
func (f *basicRegistry) saveDeploymentOnStatusChange(ctx context.Context, o Orchestrator) {
	go func() {
		last := o.Status()
		ch := o.StatusChannel(ctx)
		for {
			select {
			case <-ctx.Done():
				return
			case status := <-ch:
				// Only persist genuine transitions, not repeated notifications
				// of the same status.
				if status == last {
					continue
				}
				last = status
				log.Infow("status changed, saving deployment...",
					"labels", []string{string(observability.LabelDeployment)},
					"orchestratorID", o.ID(),
					"status", o.Status().String(),
				)
				if err := f.SaveOrchestrator(o); err != nil {
					log.Errorw("failed to save orchestrator on status change", "error", err)
				}
			}
		}
	}()
}
// restoreDeployment restores deployments where the status is either provisioning, committing or running
// TODO: restore subnetManifest if necessary
//
// The persisted pieces (config, manifest, status, snapshot, subnet manifest)
// are reassembled into a BasicOrchestrator and the deployment is resumed
// according to the status it was persisted with:
//   - preparing:     re-run deploy from the start
//   - committing:    revert old candidates, then re-run deploy
//   - provisioning:  retry provisioning; on failure revert and re-deploy
//   - running:       re-grant capabilities, optionally re-join the subnet,
//     and rebuild the in-memory allocation info
func (f *basicRegistry) restoreDeployment(
	ctx context.Context,
	fs afero.Afero,
	actr actor.Actor, id string,
	cfg jtypes.EnsembleConfig, manifest jtypes.EnsembleManifest,
	status jtypes.DeploymentStatus, restoreInfo jtypes.DeploymentSnapshot,
	subnetManifest jtypes.SubnetManifest,
	allocationIDGenerator types.AllocationIDGenerator,
) (Orchestrator, error) {
	log.Infow("restoring deployment", "id", id, "status", status)
	// The orchestrator's lifetime is bound to this cancellable context.
	ctx, cancel := context.WithCancel(ctx)
	o := &BasicOrchestrator{
		id:                    id,
		fs:                    fs,
		actor:                 actr,
		cfg:                   cfg,
		status:                status,
		deploymentSnapshot:    restoreInfo,
		supervisor:            NewSupervisor(ctx, actr, id),
		manifest:              manifest,
		subnetManifest:        subnetManifest,
		allocs:                make(map[string]jtypes.AllocationInfo),
		ctx:                   ctx,
		cancel:                cancel,
		statusSubscribers:     make(map[chan jtypes.DeploymentStatus]struct{}),
		allocationIDGenerator: allocationIDGenerator,
	}
	err := o.RegisterBehaviors()
	if err != nil {
		return nil, fmt.Errorf("failed to register orchestrator behaviors: %w", err)
	}
	// TODO: manifest.Empty()
	// An empty ID means no manifest was persisted; start from a fresh one.
	// NOTE(review): this path calls the exported NewManifest while the other
	// restore paths call newManifest — confirm both produce the same result.
	if manifest.ID == "" {
		o.manifest = o.NewManifest(cfg)
	}
	log.Infow("restored deployment", "id", id, "status", status)
	// Preparing: the deployment never got past preparation — deploy anew.
	if o.status == jtypes.DeploymentStatusPreparing {
		log.Debugw("restoring deployment from manifest",
			"labels", []string{string(observability.LabelDeployment)},
			"orchestratorID", id,
		)
		// save deployment on status change, use the orchestrator's context
		f.saveDeploymentOnStatusChange(o.ctx, o)
		err := o.deploy(cfg, manifest, restoreInfo.Expiry)
		if err != nil {
			o.setStatus(jtypes.DeploymentStatusFailed)
			return o, fmt.Errorf("failed to deploy deployment: %w", err)
		}
		go o.monitorOnlyTaskManifest()
		go o.supervisor.Supervise(jtypes.NewManifestReader(o.manifest))
		return o, nil
	}
	// Committing: the previous run died mid-commit — undo what the old
	// candidates may have done, then restart the deployment from scratch.
	if o.status == jtypes.DeploymentStatusCommitting {
		log.Debugw("reverting deployment of old candidates and restarting deployment from the beginning",
			"labels", []string{string(observability.LabelDeployment)},
			"reason", "deployment was in committing state",
			"orchestratorID", id,
		)
		manifest := o.manifest.Clone()
		for nodeID, bid := range restoreInfo.Candidates {
			o.revertNodeDeployment(cfg, nodeID, bid.Handle())
		}
		// save deployment on status change, use the orchestrator's context
		f.saveDeploymentOnStatusChange(o.ctx, o)
		err := o.deploy(cfg, manifest, restoreInfo.Expiry)
		if err != nil {
			o.setStatus(jtypes.DeploymentStatusFailed)
			return o, fmt.Errorf("failed to deploy deployment: %w", err)
		}
		go o.monitorOnlyTaskManifest()
		go o.supervisor.Supervise(jtypes.NewManifestReader(o.manifest))
		return o, nil
	}
	// Provisioning: retry provisioning with a fresh Provisioner; if that
	// fails, revert and fall back to a full re-deploy from a new manifest.
	if o.status == jtypes.DeploymentStatusProvisioning {
		log.Debugw("restoring deployment from manifest",
			"labels", []string{string(observability.LabelDeployment)},
			"orchestratorID", id,
		)
		// Log the manifest content to see what nodes are included
		log.Infow("manifest before provisioning restoration",
			"labels", []string{string(observability.LabelDeployment)},
			"orchestratorID", id,
			"manifestNodes", len(manifest.Nodes),
			"nodeIDs", func() []string {
				var nodeIDs []string
				for nodeID := range manifest.Nodes {
					nodeIDs = append(nodeIDs, nodeID)
				}
				return nodeIDs
			}(),
		)
		log.Infow("starting provisioning restoration with healthy provisioner",
			"labels", []string{string(observability.LabelDeployment)},
			"orchestratorID", id,
		)
		provisioner := NewProvisioner(ctx, cancel, actr, subnetManifest, o.allocationIDGenerator)
		manifestAfterProvision, err := provisioner.Provision(
			jtypes.NewEnsembleCfgReader(cfg),
			jtypes.NewManifestReader(manifest))
		if err != nil {
			log.Errorf("failed to provision network during restoration: %s", err)
			log.Infow("reverting deployment due to failed provisioning during restoration",
				"labels", []string{string(observability.LabelDeployment)},
				"orchestratorID", id,
			)
			o.revert(cfg, manifest)
			log.Infow("reverted deployment due to failed provisioning during restoration",
				"labels", []string{string(observability.LabelDeployment)},
				"orchestratorID", id,
			)
			log.Infow("deploying new manifest after failed provisioning during restoration",
				"labels", []string{string(observability.LabelDeployment)},
				"orchestratorID", id,
			)
			// save deployment on status change, use the orchestrator's context
			f.saveDeploymentOnStatusChange(o.ctx, o)
			err := o.deploy(cfg, o.newManifest(cfg), restoreInfo.Expiry)
			if err != nil {
				o.setStatus(jtypes.DeploymentStatusFailed)
				return o, fmt.Errorf("failed to deploy deployment: %w", err)
			}
			go o.monitorOnlyTaskManifest()
			go o.supervisor.Supervise(jtypes.NewManifestReader(o.manifest))
			return o, nil
		}
		log.Infow("provisioning restoration completed successfully",
			"labels", []string{string(observability.LabelDeployment)},
			"orchestratorID", id,
		)
		// save deployment on status change, use the orchestrator's context
		f.saveDeploymentOnStatusChange(o.ctx, o)
		o.updateManifest(manifestAfterProvision)
		o.setStatus(jtypes.DeploymentStatusRunning)
		go o.monitorOnlyTaskManifest()
		go o.supervisor.Supervise(jtypes.NewManifestReader(o.manifest))
		return o, nil
	}
	// Running: re-grant capabilities, re-join the subnet if the orchestrator
	// is a subnet member, and rebuild the in-memory allocation info cache.
	if o.status == jtypes.DeploymentStatusRunning {
		// grant allocs still running
		for _, a := range o.Manifest().Allocations {
			err := o.grantOrchestratorCaps(a.Handle.DID)
			if err != nil {
				return nil, fmt.Errorf("failed to grant orchestrator capabilities during restoration: %w", err)
			}
		}
		// save deployment on status change, use the orchestrator's context
		f.saveDeploymentOnStatusChange(o.ctx, o)
		if o.manifest.Subnet.Join {
			if _, ok := o.subnetManifest.IndexRoutingTable[orchSubnetName]; ok {
				// On subnet failure: revert and re-deploy from a fresh manifest.
				handleError := func(err error) (Orchestrator, error) {
					log.Errorf("failed to join subnet: %s", err)
					o.revert(cfg, manifest)
					if err := o.deploy(cfg, o.newManifest(cfg), restoreInfo.Expiry); err != nil {
						o.setStatus(jtypes.DeploymentStatusFailed)
						return o, fmt.Errorf("failed to deploy deployment: %w", err)
					}
					return o, nil
				}
				provisioner := NewProvisioner(ctx, cancel, o.actor, o.subnetManifest, o.allocationIDGenerator)
				err := provisioner.createSubnet(o.manifest.ID, o.subnetManifest.RoutingTable, []actor.Handle{o.actor.Supervisor()})
				if err != nil {
					return handleError(err)
				}
				err = provisioner.orchestratorJoinSubnet(
					manifest.ID,
					o.subnetManifest.IndexRoutingTable,
					o.subnetManifest.RoutingTable,
					o.subnetManifest.DNSRecords,
				)
				if err != nil {
					return handleError(err)
				}
			}
		}
		// Rebuild the allocation info cache from the manifest.
		// NOTE(review): idx is the allocation's manifest key; this assumes
		// IndexRoutingTable is keyed the same way — confirm.
		for idx, a := range o.Manifest().Allocations {
			o.allocs[a.ID] = jtypes.AllocationInfo{
				AllocationID:   a.ID,
				HeartbeatSeq:   0,
				Status:         a.Status,
				HasHealthCheck: len(a.Healthcheck.Exec) != 0 && a.Healthcheck.Type != "",
				ResourceLimit:  o.Config().V1.Allocations[a.ID].Resources,
				DNSName:        a.DNSName,
				IP:             o.SubnetManifest().IndexRoutingTable[idx],
				ResourceUsage:  jtypes.AllocationResourceUsage{},
				Timestamp:      time.Now().Unix(),
			}
		}
		go o.monitorOnlyTaskManifest()
		go o.supervisor.Supervise(jtypes.NewManifestReader(o.manifest))
	}
	return o, nil
}
// RestoreDeployment creates an orchestrator and attempts to restore its deployment
func (f *basicRegistry) RestoreDeployment(
	ctx context.Context,
	fs afero.Afero,
	actr actor.Actor, id string, cfg jtypes.EnsembleConfig,
	manifest jtypes.EnsembleManifest, status jtypes.DeploymentStatus, restoreInfo jtypes.DeploymentSnapshot,
	subnetManifest jtypes.SubnetManifest,
	allocationIDGenerator types.AllocationIDGenerator,
) (Orchestrator, error) {
	// Fast-path duplicate check before the (potentially slow) restore.
	f.lock.RLock()
	_, exists := f.orchestrators[id]
	f.lock.RUnlock()
	if exists {
		return nil, ErrOrchestratorExists
	}
	o, err := f.restoreDeployment(ctx, fs, actr, id, cfg, manifest, status, restoreInfo, subnetManifest, allocationIDGenerator)
	if err != nil {
		return nil, fmt.Errorf("failed to restore deployment: %w", err)
	}
	f.lock.Lock()
	defer f.lock.Unlock()
	// Re-check under the write lock: a concurrent RestoreDeployment for the
	// same id could have won the race since the read-locked check above, and
	// silently overwriting its entry would leak the winner's orchestrator.
	if _, ok := f.orchestrators[id]; ok {
		return nil, ErrOrchestratorExists
	}
	f.orchestrators[id] = o
	return o, nil
}
// Orchestrators returns a map of all orchestrators
func (f *basicRegistry) Orchestrators() map[string]Orchestrator {
	f.lock.RLock()
	defer f.lock.RUnlock()
	// Hand out a copy so callers can't mutate the registry's internal map.
	snapshot := make(map[string]Orchestrator, len(f.orchestrators))
	for id, orch := range f.orchestrators {
		snapshot[id] = orch
	}
	return snapshot
}
// GetOrchestrator returns an orchestrator by ID
func (f *basicRegistry) GetOrchestrator(id string) (Orchestrator, error) {
	f.lock.RLock()
	orch, ok := f.orchestrators[id]
	f.lock.RUnlock()
	if !ok {
		return nil, ErrOrchestratorNotFound
	}
	return orch, nil
}
// DeleteOrchestrator deletes an orchestrator by ID
func (f *basicRegistry) DeleteOrchestrator(id string) {
	f.lock.Lock()
	defer f.lock.Unlock()
	// Drop the in-memory entry first, then best-effort delete from the
	// persistent store: a missing record is logged but not treated as fatal
	// since the in-memory entry is already gone (the deployment may have been
	// deleted already or never persisted).
	delete(f.orchestrators, id)
	if err := f.store.Delete(id); err != nil {
		log.Warnf("failed to delete orchestrator from store: %s", err)
	}
	log.Infow("deleted orchestrator from store", "id", id)
}
// SaveOrchestrator persists orchestrator state to store
func (f *basicRegistry) SaveOrchestrator(orchestrator Orchestrator) error {
	// The actor's private key is stored alongside the view so the actor can
	// be reconstructed on restore.
	raw, err := crypto.MarshalPrivateKey(orchestrator.ActorPrivateKey())
	if err != nil {
		return fmt.Errorf("convert priv key to raw: %w", err)
	}
	view := jtypes.OrchestratorView{
		BaseDBModel: types.BaseDBModel{
			ID:        orchestrator.ID(),
			CreatedAt: time.Now(),
		},
		OrchestratorID:     orchestrator.ID(),
		Cfg:                orchestrator.Config(),
		Manifest:           orchestrator.Manifest(),
		SubnetManifest:     orchestrator.SubnetManifest(),
		Status:             orchestrator.Status(),
		DeploymentSnapshot: orchestrator.DeploymentSnapshot(),
		PrivKey:            raw,
	}
	return f.store.Upsert(&view)
}
// GetAllDeployments retrieves all deployments from store
func (f *basicRegistry) GetAllDeployments() ([]*jtypes.OrchestratorView, error) {
	// nil status filter means "all statuses".
	return f.store.GetAll(nil)
}
// GetDeploymentsByStatus retrieves deployments filtered by status.
// Delegates directly to the persistent store.
func (f *basicRegistry) GetDeploymentsByStatus(status jtypes.DeploymentStatus) ([]*jtypes.OrchestratorView, error) {
	return f.store.GetAll(&status)
}
// QueryDeployments retrieves deployments with advanced filtering.
// Delegates directly to the persistent store; the int result mirrors the
// store's Query return (presumably a total count — confirm with the store).
func (f *basicRegistry) QueryDeployments(query DeploymentQuery) ([]*jtypes.OrchestratorView, int, error) {
	return f.store.Query(query)
}
// DeleteDeployment removes a specific deployment by orchestrator ID from the
// persistent store only; the in-memory cache is untouched (see
// DeleteOrchestrator for the combined removal).
func (f *basicRegistry) DeleteDeployment(orchestratorID string) error {
	return f.store.Delete(orchestratorID)
}
// GetDeployment retrieves a deployment from store by ID.
// Delegates directly to the persistent store.
func (f *basicRegistry) GetDeployment(orchestratorID string) (*jtypes.OrchestratorView, error) {
	return f.store.Get(orchestratorID)
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package orchestrator
import (
"encoding/json"
"fmt"
"net"
"strconv"
"strings"
"sync"
"time"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/dms/behaviors"
jtypes "gitlab.com/nunet/device-management-service/dms/jobs/types"
netutils "gitlab.com/nunet/device-management-service/network/utils"
"gitlab.com/nunet/device-management-service/observability"
"gitlab.com/nunet/device-management-service/utils"
)
// recordsModificationType distinguishes whether a DNS-records update should
// add entries or remove them.
type recordsModificationType int

const (
	add recordsModificationType = iota
	remove
)

// orchestratorJoinTimeout bounds how long the orchestrator waits when joining
// the ensemble subnet itself; declared as a var so tests can shorten it.
var orchestratorJoinTimeout = 2 * time.Minute
// subnetRequest bundles everything needed to configure one allocation on the
// ensemble subnet: its actor handle, assigned private IP, hosting node's
// peer ID, and its port mapping.
type subnetRequest struct {
	handle actor.Handle
	ip     string
	peerID string
	ports  map[int]int
}
// SubnetCreateRequest is the payload sent to a node asking it to create (or
// join the creation of) an ensemble subnet.
type SubnetCreateRequest struct {
	SubnetID     string            // ensemble/subnet identifier
	IP           string            // IP assigned to the receiving node
	RoutingTable map[string]string // IP -> peer ID mapping for the subnet
	CIDR         string            // subnet CIDR block
}
// SubnetCreateResponse is the reply to a SubnetCreateRequest; Error is set
// when OK is false.
type SubnetCreateResponse struct {
	OK    bool
	Error string
}
// SubnetJoinRequest is the payload asking a node to join an existing subnet
// with a given IP and the current routing/DNS state.
type SubnetJoinRequest struct {
	SubnetID string
	PeerID   string
	IP       string
	// map of domain_name:ip
	RoutingTable map[string]string
	Records      map[string]string
}
// SubnetJoinResponse is the reply to a SubnetJoinRequest; Error is set when
// OK is false.
type SubnetJoinResponse struct {
	OK    bool
	Error string
}
// newSubnetManifest builds a fresh subnet manifest backed by a randomly
// chosen /24 CIDR within 10.0.0.0 - 10.255.255.255. The .1 address of the
// block is reserved as the gateway and .255 as the broadcast address; both
// are pre-marked as used so they are never handed to allocations.
func newSubnetManifest() (jtypes.SubnetManifest, error) {
	cidr, err := netutils.GetRandomCIDRInRange(
		24,
		net.ParseIP("10.0.0.0"),
		net.ParseIP("10.255.255.255"),
		[]string{},
	)
	if err != nil {
		return jtypes.SubnetManifest{}, fmt.Errorf("error getting random CIDR: %w", err)
	}
	// derive the gateway/broadcast addresses from the first three octets
	octets := strings.Split(strings.Split(cidr, "/")[0], ".")
	prefix := strings.Join(octets[:3], ".")
	gatewayIP := prefix + ".1"
	broadcastIP := prefix + ".255"
	return jtypes.SubnetManifest{
		CIDR:        cidr,
		GatewayIP:   gatewayIP,
		BroadcastIP: broadcastIP,
		UsedIPs: map[string]bool{
			gatewayIP:   true,
			broadcastIP: true,
		},
		RoutingTable:      make(map[string]string),
		IndexRoutingTable: make(map[string]string),
		DNSRecords:        make(map[string]string),
	}, nil
}
// addAllocationToSubnet assigns the next free IP in the subnet CIDR to the
// named allocation and records it in the routing table (ip -> hosting peer),
// the index routing table (allocation name -> ip), the DNS records
// (allocation DNS name -> ip), and the used-IP set. It is a no-op when the
// allocation already has an IP. The provisioner lock is held for the whole
// update.
func (p *Provisioner) addAllocationToSubnet(mf jtypes.EnsembleManifest, allocName string) error {
	p.lock.Lock()
	defer p.lock.Unlock()
	// already assigned: nothing to do
	if _, ok := p.subnetManifest.IndexRoutingTable[allocName]; ok {
		log.Debugf("allocation %s already in subnet", allocName)
		return nil
	}
	ip, err := netutils.GetNextIP(p.subnetManifest.CIDR, p.subnetManifest.UsedIPs)
	if err != nil {
		return fmt.Errorf("error getting next IP: %w", err)
	}
	allocManifest, ok := mf.Allocations[allocName]
	if !ok {
		return fmt.Errorf("allocation %s not found in manifest", allocName)
	}
	// record the assignment in all four views of the subnet state
	p.subnetManifest.RoutingTable[ip.String()] = mf.Nodes[allocManifest.NodeID].Peer
	p.subnetManifest.IndexRoutingTable[allocName] = ip.String()
	p.subnetManifest.DNSRecords[allocManifest.DNSName] = ip.String()
	p.subnetManifest.UsedIPs[ip.String()] = true
	log.Debugf("Added allocation %s with IP %s to subnet", allocName, ip)
	return nil
}
// removeAllocationsFromSubnet deletes the subnet state of each named
// allocation: its used-IP entry, routing-table entry, index-routing-table
// entry, and (when set) its DNS record. Allocations missing from the
// manifest or from the index routing table are skipped with a warning. The
// orchestrator lock is held for the whole update. Always returns nil today;
// the error return leaves room for stricter handling.
func (o *BasicOrchestrator) removeAllocationsFromSubnet(
	mf jtypes.EnsembleManifest, allocsNames []string,
) error {
	o.lock.Lock()
	defer o.lock.Unlock()
	for _, allocName := range allocsNames {
		// Get the allocation from the manifest
		allocManifest, ok := mf.Allocations[allocName]
		if !ok {
			log.Warnf(
				"skipping subnet removal: allocation %s not found in manifest",
				allocName,
			)
			continue
		}
		// Get the IP address for this allocation
		ip, ok := o.subnetManifest.IndexRoutingTable[allocName]
		if !ok {
			log.Warnf(
				"skipping subnet removal: allocation %s not found in index routing table",
				allocName,
			)
			continue
		}
		// Remove the IP from the used IPs map
		delete(o.subnetManifest.UsedIPs, ip)
		// Remove the routing table entry
		delete(o.subnetManifest.RoutingTable, ip)
		// Remove the index routing table entry
		delete(o.subnetManifest.IndexRoutingTable, allocName)
		// Remove any DNS records associated with this allocation
		if allocManifest.DNSName != "" {
			delete(o.subnetManifest.DNSRecords, allocManifest.DNSName)
		}
		log.Debugw("Removed allocation with IP from subnet", "labels", []string{string(observability.LabelDeployment)},
			"allocationID", allocName, "ip", ip)
	}
	return nil
}
// getAllocIP returns the subnet IP assigned to the named allocation and
// whether an assignment exists, reading the index routing table under lock.
func (o *BasicOrchestrator) getAllocIP(allocName string) (string, bool) {
	o.lock.Lock()
	defer o.lock.Unlock()
	addr, found := o.subnetManifest.IndexRoutingTable[allocName]
	return addr, found
}
// TODO: make it as a method of subnetManifest
// to be tackled here: https://gitlab.com/nunet/device-management-service/-/issues/909

// newSubnetRequests assembles one subnetRequest per allocation in the
// manifest, pairing each allocation's actor handle, assigned subnet IP,
// hosting peer ID, and port mappings. It fails if any allocation lacks an
// IP in the index routing table or its node is missing from the manifest.
func (p *Provisioner) newSubnetRequests(mf jtypes.EnsembleManifest) ([]subnetRequest, error) {
	// subnet config requests (add peer, dns, port map)
	requests := make([]subnetRequest, 0, len(mf.Allocations))
	for name, alloc := range mf.Allocations {
		addr, found := p.subnetManifest.IndexRoutingTable[name]
		if !found {
			return nil, fmt.Errorf("ip not found for allocation %s", name)
		}
		node, found := mf.Nodes[alloc.NodeID]
		if !found {
			return nil, fmt.Errorf("node not found for allocation %s", name)
		}
		requests = append(requests, subnetRequest{
			handle: alloc.Handle,
			ip:     addr,
			peerID: node.Peer,
			ports:  alloc.Ports,
		})
	}
	return requests, nil
}
// createSubnet fans a subnet-create message out to every handle in
// subCreateHandles in parallel and waits for all replies. Each peer receives
// the full routing table and the subnet CIDR. Per-peer failures (message
// build, invoke, bad reply, NOK response, or SubnetCreateTimeout) are
// collected on errCh and folded into a single error after all goroutines
// finish.
func (p *Provisioner) createSubnet(
	manifestID string,
	routingTable map[string]string,
	subCreateHandles []actor.Handle,
) error {
	// buffered so every goroutine can report without blocking
	errCh := make(chan error, len(subCreateHandles))
	wg := sync.WaitGroup{}
	for _, handle := range subCreateHandles {
		wg.Add(1)
		// handle is passed as an argument so each goroutine works on its own copy
		go func(h actor.Handle) {
			defer wg.Done()
			msg, err := actor.Message(
				p.actor.Handle(),
				h,
				fmt.Sprintf(behaviors.SubnetCreateBehavior.DynamicTemplate, manifestID),
				SubnetCreateRequest{
					SubnetID:     manifestID,
					RoutingTable: routingTable,
					CIDR:         p.subnetManifest.CIDR,
				},
				actor.WithMessageExpiry(uint64(time.Now().Add(5*time.Second).UnixNano())),
			)
			if err != nil {
				errCh <- fmt.Errorf("error creating subnet message: %w", err)
				return
			}
			replyCh, err := p.actor.Invoke(msg)
			if err != nil {
				errCh <- fmt.Errorf("error invoking subnet message: %w", err)
				return
			}
			var reply actor.Envelope
			select {
			case reply = <-replyCh:
				defer reply.Discard()
				var response SubnetCreateResponse
				if err := json.Unmarshal(reply.Message, &response); err != nil {
					errCh <- fmt.Errorf("error unmarshalling subnet response: %w", err)
					return
				}
				if !response.OK {
					errCh <- fmt.Errorf("error creating subnet: %s: %w", response.Error, ErrDeploymentFailed)
					return
				}
			case <-time.After(SubnetCreateTimeout):
				errCh <- fmt.Errorf("timeout creating subnet: %w", ErrDeploymentFailed)
				return
			}
			log.Infow("subnet successfully created on peer",
				"labels", []string{string(observability.LabelDeployment)},
				"manifestID", manifestID, "handle", h)
		}(handle)
	}
	wg.Wait()
	close(errCh)
	return aggregateErrors(errCh)
}
// TODO: this step should be together with subnet creation, we have
// to refactor the SubnetCreate handle

// subnetAddPeer asks each peer in subReqs, in parallel, to add the
// allocation's (IP, peer ID) pair to the subnet identified by manifestID.
// Per-peer errors (message build, invoke, bad reply, NOK response, or
// timeout) are collected and aggregated after all requests finish.
func (p *Provisioner) subnetAddPeer(manifestID string, subReqs []subnetRequest) error {
	// 1.b create and plug IPs
	wg := sync.WaitGroup{}
	errCh := make(chan error, len(subReqs))
	for _, req := range subReqs {
		wg.Add(1)
		// Pass req explicitly: before Go 1.22 a closure capturing the range
		// variable shares a single mutating variable across iterations, so
		// every goroutine could see the last element. This also matches the
		// pattern already used by createSubnet and
		// modifyRoutingTableAllocations.
		go func(req subnetRequest) {
			defer wg.Done()
			msg, err := actor.Message(
				p.actor.Handle(),
				req.handle,
				behaviors.SubnetAddPeerBehavior,
				behaviors.SubnetAddPeerRequest{
					SubnetID: manifestID,
					IP:       req.ip,
					PeerID:   req.peerID,
				},
				actor.WithMessageExpiry(uint64(time.Now().Add(5*time.Second).UnixNano())),
			)
			if err != nil {
				errCh <- fmt.Errorf("error creating subnet add-peer message: %w", err)
				return
			}
			replyCh, err := p.actor.Invoke(msg)
			if err != nil {
				errCh <- fmt.Errorf("error invoking subnet add-peer message: %w", err)
				return
			}
			var reply actor.Envelope
			select {
			case reply = <-replyCh:
				defer reply.Discard()
				var response behaviors.SubnetAddPeerResponse
				if err := json.Unmarshal(reply.Message, &response); err != nil {
					errCh <- fmt.Errorf("error unmarshalling subnet add-peer response: %w", err)
					return
				}
				if !response.OK {
					errCh <- fmt.Errorf("error adding peer to subnet: %s: %w", response.Error, ErrDeploymentFailed)
					return
				}
			case <-time.After(2 * time.Minute):
				errCh <- fmt.Errorf("timeout adding peer to subnet: %w", ErrDeploymentFailed)
				return
			}
			log.Infow("peer successfully added to subnet on peer",
				"labels", []string{string(observability.LabelDeployment)}, "handle", req.handle)
		}(req)
	}
	wg.Wait()
	close(errCh)
	return aggregateErrors(errCh)
}
// addDNSRecords pushes the given DNS records (domain_name -> ip) to every
// peer in subReqs in parallel for the subnet identified by manifestID.
// Per-peer errors are collected and aggregated after all requests finish.
func (p *Provisioner) addDNSRecords(
	manifestID string,
	subReqs []subnetRequest, dnsRecords map[string]string,
) error {
	wg := sync.WaitGroup{}
	errCh := make(chan error, len(subReqs))
	for _, req := range subReqs {
		wg.Add(1)
		// Pass req explicitly: before Go 1.22 a closure capturing the range
		// variable shares it across iterations (matches createSubnet's pattern).
		go func(req subnetRequest) {
			defer wg.Done()
			msg, err := actor.Message(
				p.actor.Handle(),
				req.handle,
				behaviors.SubnetDNSAddRecordsBehavior,
				behaviors.SubnetDNSAddRecordsRequest{
					SubnetID: manifestID,
					Records:  dnsRecords,
				},
				actor.WithMessageExpiry(uint64(time.Now().Add(5*time.Second).UnixNano())),
			)
			if err != nil {
				errCh <- fmt.Errorf("error creating subnet add-dns-records message: %w", err)
				return
			}
			replyCh, err := p.actor.Invoke(msg)
			if err != nil {
				errCh <- fmt.Errorf("error invoking subnet add-dns-records message: %w", err)
				return
			}
			var reply actor.Envelope
			select {
			case reply = <-replyCh:
				defer reply.Discard()
				var response behaviors.SubnetDNSAddRecordsResponse
				if err := json.Unmarshal(reply.Message, &response); err != nil {
					// was mislabelled "add-peer" (copy-paste from subnetAddPeer)
					errCh <- fmt.Errorf("error unmarshalling subnet add-dns-records response: %w", err)
					return
				}
				if !response.OK {
					errCh <- fmt.Errorf("error sending dns records to peer: %s: %w", response.Error, ErrDeploymentFailed)
					return
				}
			case <-time.After(2 * time.Minute):
				errCh <- fmt.Errorf("timeout sending dns records to subnet: %w", ErrDeploymentFailed)
				return
			}
			log.Infow("DNS records successfully added to subnet on peer", "handle", req.handle)
		}(req)
	}
	wg.Wait()
	close(errCh)
	return aggregateErrors(errCh)
}
// removeDNSRecords asks every peer in subReqs, in parallel, to drop the DNS
// records for the given domain names from the subnet identified by
// manifestID. Per-peer errors are collected and aggregated after all
// requests finish.
func (p *Provisioner) removeDNSRecords(
	manifestID string,
	subReqs []subnetRequest, domainNames []string,
) error {
	wg := sync.WaitGroup{}
	errCh := make(chan error, len(subReqs))
	for _, req := range subReqs {
		wg.Add(1)
		// Pass req explicitly: before Go 1.22 a closure capturing the range
		// variable shares it across iterations (matches createSubnet's pattern).
		go func(req subnetRequest) {
			defer wg.Done()
			msg, err := actor.Message(
				p.actor.Handle(),
				req.handle,
				behaviors.SubnetDNSRemoveRecordsBehavior,
				behaviors.SubnetDNSRemoveRecordsRequest{
					SubnetID:    manifestID,
					DomainNames: domainNames,
				},
				actor.WithMessageExpiry(uint64(time.Now().Add(5*time.Second).UnixNano())),
			)
			if err != nil {
				errCh <- fmt.Errorf("error creating subnet remove-dns-records message: %w", err)
				return
			}
			replyCh, err := p.actor.Invoke(msg)
			if err != nil {
				errCh <- fmt.Errorf("error invoking subnet remove-dns-records message: %w", err)
				return
			}
			var reply actor.Envelope
			select {
			case reply = <-replyCh:
				defer reply.Discard()
				var response behaviors.SubnetDNSRemoveRecordsResponse
				if err := json.Unmarshal(reply.Message, &response); err != nil {
					// was mislabelled "remove-peer" (copy-paste)
					errCh <- fmt.Errorf("error unmarshalling subnet remove-dns-records response: %w", err)
					return
				}
				if !response.OK {
					errCh <- fmt.Errorf("error sending dns records to be removed from peer: %s: %w", response.Error, ErrDeploymentFailed)
					return
				}
			case <-time.After(2 * time.Minute):
				errCh <- fmt.Errorf("timeout removing dns records from subnet: %w", ErrDeploymentFailed)
				return
			}
			log.Infow("DNS records successfully removed from subnet on peer", "handle", req.handle)
		}(req)
	}
	wg.Wait()
	close(errCh)
	return aggregateErrors(errCh)
}
// mapPorts asks each peer, in parallel, to install a TCP port mapping for
// every public port of its allocation, forwarding 0.0.0.0:<port> to the
// allocation's subnet IP on the same port.
//
// NOTE(review): only the keys of req.ports are used; the map value
// (presumably the destination port) is ignored, so source and destination
// ports are always equal — confirm whether DestPort should use the value.
func (p *Provisioner) mapPorts(manifestID string, subReqs []subnetRequest) error {
	wg := sync.WaitGroup{}
	errCh := make(chan error, len(subReqs))
	for _, req := range subReqs {
		for pubPort := range req.ports {
			wg.Add(1)
			// Pass req and pubPort explicitly: before Go 1.22 closures
			// capturing range variables share them across iterations.
			go func(req subnetRequest, pubPort int) {
				defer wg.Done()
				msg, err := actor.Message(
					p.actor.Handle(),
					req.handle,
					behaviors.SubnetMapPortBehavior,
					behaviors.SubnetMapPortRequest{
						SubnetID:   manifestID,
						Protocol:   "TCP", // TODO: add support in AllocationManifest for protocol
						SourceIP:   "0.0.0.0",
						SourcePort: strconv.Itoa(pubPort),
						DestIP:     req.ip,
						DestPort:   strconv.Itoa(pubPort),
					},
					actor.WithMessageExpiry(uint64(time.Now().Add(5*time.Second).UnixNano())),
				)
				if err != nil {
					errCh <- fmt.Errorf("error creating subnet MapPort message: %w", err)
					return
				}
				replyCh, err := p.actor.Invoke(msg)
				if err != nil {
					errCh <- fmt.Errorf("error invoking subnet MapPort message: %w", err)
					return
				}
				var reply actor.Envelope
				select {
				case reply = <-replyCh:
					defer reply.Discard()
					var response behaviors.SubnetMapPortResponse
					if err := json.Unmarshal(reply.Message, &response); err != nil {
						// was mislabelled "add-peer" (copy-paste from subnetAddPeer)
						errCh <- fmt.Errorf("error unmarshalling subnet map-port response: %w", err)
						return
					}
					if !response.OK {
						// was mislabelled "adding peer to subnet" (copy-paste)
						errCh <- fmt.Errorf("error mapping port on subnet: %s: %w", response.Error, ErrDeploymentFailed)
						return
					}
				case <-time.After(2 * time.Minute):
					errCh <- fmt.Errorf("timeout mapping port for subnet: %w", ErrDeploymentFailed)
					return
				}
				log.Info("port mapping successfully added to subnet on peer", req.handle)
			}(req, pubPort)
		}
	}
	wg.Wait()
	close(errCh)
	return aggregateErrors(errCh)
}
// TODO: maybe this should go to the createSubnet method

// orchestratorJoinSubnet makes the orchestrator itself join the subnet by
// messaging its supervisor with the subnet's routing table and DNS records.
// The orchestrator's own IP is looked up under the orchSubnetName key of
// indexRoutingTable. Blocks until the join reply arrives or
// orchestratorJoinTimeout elapses.
func (p *Provisioner) orchestratorJoinSubnet(
	manifestID string,
	indexRoutingTable map[string]string, routingTable map[string]string, dnsRecords map[string]string,
) error {
	behaviorName := fmt.Sprintf(behaviors.SubnetJoinBehavior.DynamicTemplate, manifestID)
	msg, err := actor.Message(
		p.actor.Handle(),
		p.actor.Supervisor(),
		behaviorName,
		SubnetJoinRequest{
			SubnetID:     manifestID,
			IP:           indexRoutingTable[orchSubnetName],
			PeerID:       p.actor.Handle().Address.HostID,
			RoutingTable: routingTable,
			Records:      dnsRecords,
		},
		actor.WithMessageExpiry(uint64(time.Now().Add(5*time.Second).UnixNano())),
	)
	if err != nil {
		return fmt.Errorf("error creating subnet join message: %w", err)
	}
	replyCh, err := p.actor.Invoke(msg)
	if err != nil {
		return fmt.Errorf("error invoking subnet join message: %w", err)
	}
	var reply actor.Envelope
	select {
	case reply = <-replyCh:
		defer reply.Discard()
		var response SubnetJoinResponse
		if err := json.Unmarshal(reply.Message, &response); err != nil {
			return fmt.Errorf("error unmarshalling subnet join response: %w", err)
		}
		if !response.OK {
			return fmt.Errorf("error joining orchestrator to subnet: %s: %w", response.Error, ErrDeploymentFailed)
		}
	case <-time.After(orchestratorJoinTimeout):
		return fmt.Errorf("timeout joining orchestrator to subnet: %w", ErrDeploymentFailed)
	}
	log.Info("orchestrator successfully joined the subnet")
	return nil
}
// updateSubnetAllocations invokes a subnet update behavior on every peer
// within the subnet with updates on:
// 1. routing table
// 2. dns records
//
// Usually used when allocations are added or removed.
func (p *Provisioner) updateSubnetAllocations(
	mf jtypes.EnsembleManifest,
	newDNSRecords, additionsRoutingTable map[string]string,
) error {
	requests, err := p.newSubnetRequests(mf)
	if err != nil {
		return fmt.Errorf("error creating subnet requests: %w", err)
	}
	// 1. extend the peers' routing tables via the accept-peers behavior
	if err := p.acceptPeersSubnetBehavior(mf.ID, requests, additionsRoutingTable); err != nil {
		return fmt.Errorf("error adding peers to subnet: %w", err)
	}
	// 2. push the new DNS records
	if err := p.addDNSRecords(mf.ID, requests, newDNSRecords); err != nil {
		return fmt.Errorf("error adding dns records: %w", err)
	}
	return nil
}
// acceptPeersSubnetBehavior broadcasts routing-table additions to all peers
// in subReqs for the given manifest/subnet ID.
func (p *Provisioner) acceptPeersSubnetBehavior(
	mfID string,
	subReqs []subnetRequest, routingTable map[string]string,
) error {
	return p.modifyRoutingTableAllocations(mfID, add, subReqs, routingTable)
}
// removePeersSubnetBehavior broadcasts routing-table removals to all peers
// in subReqs for the given manifest/subnet ID.
func (p *Provisioner) removePeersSubnetBehavior(
	mfID string,
	subReqs []subnetRequest, routingTable map[string]string,
) error {
	return p.modifyRoutingTableAllocations(mfID, remove, subReqs, routingTable)
}
// modifyRoutingTableAllocations broadcasts a partial routing table to every
// peer in subReqs, either adding entries to or removing entries from their
// subnet routing tables depending on modificationType. Requests run in
// parallel; per-peer errors are aggregated after all replies (or timeouts).
// Add and remove share the same request payload, only the behavior name and
// log/error wording differ.
func (p *Provisioner) modifyRoutingTableAllocations(
	mfID string, modificationType recordsModificationType,
	subReqs []subnetRequest, routingTable map[string]string,
) error {
	// choose the behavior and message wording for this modification type
	var (
		behavior       string
		operationName  string
		errorMsgPrefix string
		timeoutMsg     string
		successMsg     string
	)
	switch modificationType {
	case add:
		behavior = behaviors.SubnetAcceptPeersBehavior
		operationName = "accept-peer"
		errorMsgPrefix = "adding new peers to subnet"
		timeoutMsg = "timeout adding new peers to subnet"
		successMsg = "new peers added to subnet"
	case remove:
		behavior = behaviors.SubnetRemovePeersBehavior
		operationName = "remove-peer"
		errorMsgPrefix = "removing peers from subnet"
		timeoutMsg = "timeout removing peers from subnet"
		successMsg = "peers removed from subnet"
	}
	wg := sync.WaitGroup{}
	errCh := make(chan error, len(subReqs))
	for _, req := range subReqs {
		wg.Add(1)
		// req is passed as an argument so each goroutine has its own copy
		go func(request subnetRequest) {
			defer wg.Done()
			msg, err := actor.Message(
				p.actor.Handle(),
				request.handle,
				behavior,
				behaviors.SubnetAcceptPeersRequest{ // Add/remove payload are equal
					SubnetID:            mfID,
					PartialRoutingTable: routingTable,
				},
				actor.WithMessageExpiry(uint64(time.Now().Add(5*time.Second).UnixNano())),
			)
			if err != nil {
				errCh <- fmt.Errorf("error creating subnet %s message: %w", operationName, err)
				return
			}
			replyCh, err := p.actor.Invoke(msg)
			if err != nil {
				errCh <- fmt.Errorf("error invoking subnet %s message: %w", operationName, err)
				return
			}
			var reply actor.Envelope
			select {
			case reply = <-replyCh:
				defer reply.Discard()
				var response behaviors.SubnetAcceptPeersResponse
				if err := json.Unmarshal(reply.Message, &response); err != nil {
					errCh <- fmt.Errorf("error unmarshalling subnet %s response: %w", operationName, err)
					return
				}
				if !response.OK {
					errCh <- fmt.Errorf("error %s: %s: %w", errorMsgPrefix, response.Error, ErrDeploymentFailed)
					return
				}
			case <-time.After(2 * time.Minute):
				errCh <- fmt.Errorf("%s: %w", timeoutMsg, ErrDeploymentFailed)
				return
			}
			log.Info(successMsg, request.handle)
		}(req)
	}
	wg.Wait()
	close(errCh)
	return aggregateErrors(errCh)
}
// revertSubnetAllocationsUpdate undoes a previous updateSubnetAllocations:
// it removes the given routing-table entries and DNS records from every peer
// in the subnet. It is best-effort — failures are logged, not returned.
func (p *Provisioner) revertSubnetAllocationsUpdate(
	mf jtypes.EnsembleManifest,
	dnsRecords, partialRountingTable map[string]string,
) {
	subReqs, err := p.newSubnetRequests(mf)
	if err != nil {
		// %v, not %w: the %w verb is only meaningful to fmt.Errorf, in a
		// printf-style logger it renders as %!w(...).
		log.Errorf("Reverting subnet allocations update: error creating subnet requests: %v", err)
		// without the requests there are no peers to send the reverts to
		return
	}
	// 1. remove from routing table
	if err := p.removePeersSubnetBehavior(mf.ID, subReqs, partialRountingTable); err != nil {
		log.Errorf("Reverting subnet allocations update: error removing peers from subnet: %v", err)
	}
	// 2. remove DNS records
	if err := p.removeDNSRecords(mf.ID, subReqs, utils.MapKeysToSlice(dnsRecords)); err != nil {
		log.Errorf("Reverting subnet allocations update: error removing dns records: %v", err)
	}
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package orchestrator
import (
"context"
"encoding/json"
"fmt"
"sync"
"time"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/dms/behaviors"
jtypes "gitlab.com/nunet/device-management-service/dms/jobs/types"
"gitlab.com/nunet/device-management-service/observability"
"gitlab.com/nunet/device-management-service/types"
)
const (
	// RegisterHealthCheckTimeout bounds both the message expiry and the
	// reply wait when registering a healthcheck on an allocation.
	RegisterHealthCheckTimeout = 5 * time.Second
)

var (
	// HealthCheckTimeout bounds a single healthcheck round trip.
	HealthCheckTimeout = 5 * time.Second
	// FailureEscalationTimeout bounds the wait for the restart reply when
	// escalating repeated healthcheck failures.
	FailureEscalationTimeout = 2 * time.Minute
)
// Supervisor encapsulates supervision logic: it registers healthchecks on
// the allocations of an ensemble and periodically checks them, escalating
// repeated failures by requesting allocation restarts.
type Supervisor struct {
	id                     string              // supervised ensemble/orchestrator ID
	ctx                    context.Context     // cancels the supervision loop
	actor                  actor.Actor         // actor used to message allocations
	manifest               jtypes.EnsembleManifest
	registeredHealthChecks map[string]struct{} // allocationID -> struct{}
	failures               map[string]int      // allocationID -> failureCount
	escalations            map[string]int      // allocationID -> escalationCount
	// lock guards manifest, registeredHealthChecks, failures and escalations.
	lock sync.Mutex
}
// NewSupervisor creates a new Supervisor instance for the given ensemble ID,
// with empty tracking maps and an empty manifest carrying that ID.
func NewSupervisor(ctx context.Context, actor actor.Actor, id string) *Supervisor {
	s := &Supervisor{
		id:                     id,
		ctx:                    ctx,
		actor:                  actor,
		registeredHealthChecks: make(map[string]struct{}),
		failures:               make(map[string]int),
		escalations:            make(map[string]int),
	}
	s.manifest = jtypes.EnsembleManifest{
		ID:          id,
		Allocations: make(map[string]jtypes.AllocationManifest),
		Nodes:       make(map[string]jtypes.NodeManifest),
	}
	return s
}
// Supervise runs the supervision loop, including registration and periodic healthchecks.
//
// It first registers, in parallel, a healthcheck on every allocation in the
// manifest that declares one, then stores the manifest under lock and enters
// the blocking supervision loop (which runs until s.ctx is cancelled).
func (s *Supervisor) Supervise(manifestReader jtypes.ManifestReader) {
	log.Debugw("supervisor started for orchestrator",
		"labels", string(observability.LabelDeployment),
		"supervisorID", s.id,
		"allocations", s.manifest.Allocations)
	manifest := manifestReader.Read()
	wg := sync.WaitGroup{}
	// Registration Phase – register allocations that have a defined healthcheck.
	for _, allocation := range manifest.Allocations {
		if allocation.Healthcheck.Type == "" {
			continue
		}
		wg.Add(1)
		// allocation is passed as an argument so each goroutine has its own copy
		go func(allocation jtypes.AllocationManifest) {
			defer wg.Done()
			if err := s.registerHealthCheck(allocation, manifest.Orchestrator); err != nil {
				log.Errorf("register healthcheck for allocation: %s", err)
			}
		}(allocation)
	}
	wg.Wait()
	// Update the manifest
	s.lock.Lock()
	s.manifest = manifest
	s.lock.Unlock()
	// Supervision Phase – start the supervision loop
	s.startSupervision()
}
// startSupervision performs periodic health checks on registered allocations.
//
// Every HealthCheckInterval tick it snapshots the registered allocation IDs
// under lock (registerHealthCheck and unregisterHealthCheck mutate the map
// concurrently, so iterating it directly is a map race), then performs the
// checks in parallel and waits for all of them before handling the next
// tick. Returns when s.ctx is cancelled.
func (s *Supervisor) startSupervision() {
	ticker := time.NewTicker(actor.HealthCheckInterval)
	defer ticker.Stop()
	for {
		select {
		case <-s.ctx.Done():
			return
		case <-ticker.C:
			// snapshot the registered IDs so we never range over the live map
			s.lock.Lock()
			registered := make([]string, 0, len(s.registeredHealthChecks))
			for allocationID := range s.registeredHealthChecks {
				registered = append(registered, allocationID)
			}
			s.lock.Unlock()
			wg := sync.WaitGroup{}
			for _, allocationID := range registered {
				// Parse the allocation ID to get the manifest key
				allocID, err := types.ParseAllocationID(allocationID)
				if err != nil {
					log.Warnf("failed to parse allocation ID %s: %v", allocationID, err)
					continue
				}
				manifestKey := allocID.ManifestKey()
				allocation, ok := s.getAllocation(manifestKey)
				if !ok {
					log.Warnf("allocation not found in manifest to supervise: %s", allocationID)
					continue
				}
				if allocation.Healthcheck.Type == "" {
					log.Debugf("allocation does not have a healthcheck: %s", allocationID)
					continue
				}
				wg.Add(1)
				go func(allocation jtypes.AllocationManifest) {
					defer wg.Done()
					if err := s.performHealthCheck(allocation); err != nil {
						log.Errorf("failed to perform healthcheck for allocation %s: %s", allocation.ID, err)
					}
				}(allocation)
			}
			wg.Wait()
		}
	}
}
// registerHealthCheck sends a RegisterHealthcheck request from the
// orchestrator handle to the allocation's actor and, on an OK reply, records
// the allocation ID in registeredHealthChecks (under lock) so the
// supervision loop starts checking it. Returns an error on message-build or
// invoke failure, a NOK reply, or when no reply arrives within
// RegisterHealthCheckTimeout.
func (s *Supervisor) registerHealthCheck(allocation jtypes.AllocationManifest, orchestrator actor.Handle) error {
	expiry := actor.MakeExpiry(RegisterHealthCheckTimeout)
	msg, err := actor.Message(
		orchestrator,
		allocation.Handle,
		behaviors.RegisterHealthcheckBehavior,
		behaviors.RegisterHealthcheckRequest{
			EnsembleID:  s.id,
			HealthCheck: allocation.Healthcheck,
		},
		actor.WithMessageExpiry(expiry),
	)
	if err != nil {
		return fmt.Errorf("create actor message: %w", err)
	}
	replyCh, err := s.actor.Invoke(msg)
	if err != nil {
		return fmt.Errorf("register healthcheck on allocation: %w", err)
	}
	var reply actor.Envelope
	select {
	case reply = <-replyCh:
		defer reply.Discard()
		var response behaviors.RegisterHealthcheckResponse
		if err := json.Unmarshal(reply.Message, &response); err != nil {
			return fmt.Errorf("unmarshalling supervisor reply: %w", err)
		}
		if !response.OK {
			return fmt.Errorf("error registering healthcheck: %s", response.Error)
		}
		// mark the allocation as actively health-checked
		s.lock.Lock()
		s.registeredHealthChecks[allocation.ID] = struct{}{}
		s.lock.Unlock()
		log.Infof("successfully registered healthcheck for allocation: %s", allocation.ID)
		return nil
	case <-time.After(RegisterHealthCheckTimeout):
		return fmt.Errorf("timeout waiting for supervisor reply")
	}
}
// unregisterHealthCheck drops the allocation from the set of health-checked
// allocations; subsequent supervision ticks will skip it.
func (s *Supervisor) unregisterHealthCheck(allocationID string) {
	s.lock.Lock()
	defer s.lock.Unlock()
	delete(s.registeredHealthChecks, allocationID)
}
// performHealthCheck sends one healthcheck message to the allocation's actor
// and updates the failure bookkeeping based on the outcome:
//   - OK reply: the allocation's failure counter is cleared.
//   - NOK reply or reply timeout: the counter is incremented; at 3 or more
//     the failure is escalated (restart request) and, if escalation
//     succeeds, the counter is cleared.
//
// Terminated tasks are skipped before sending and re-checked on timeout.
//
// NOTE(review): s.manifest.IsTerminatedTask is read without holding s.lock
// while Update may replace s.manifest concurrently — confirm this is safe.
func (s *Supervisor) performHealthCheck(allocation jtypes.AllocationManifest) error {
	// Parse the allocation ID to get the manifest key for termination check
	allocID, err := types.ParseAllocationID(allocation.ID)
	if err != nil {
		log.Warnf("failed to parse allocation ID %s: %v", allocation.ID, err)
		return err
	}
	manifestKey := allocID.ManifestKey()
	if s.manifest.IsTerminatedTask(manifestKey) {
		return nil
	}
	expiry := actor.MakeExpiry(HealthCheckTimeout)
	msg, err := actor.Message(
		s.actor.Handle(),
		allocation.Handle,
		actor.HealthCheckBehavior,
		allocation.ID,
		actor.WithMessageExpiry(expiry),
	)
	if err != nil {
		return fmt.Errorf("create supervisor message: %w", err)
	}
	replyCh, err := s.actor.Invoke(msg)
	if err != nil {
		return fmt.Errorf("invoke healthcheck on allocation: %w", err)
	}
	var reply actor.Envelope
	select {
	case reply = <-replyCh:
		defer reply.Discard()
		var resp behaviors.HealthCheckResponse
		if err := json.Unmarshal(reply.Message, &resp); err != nil {
			return fmt.Errorf("unmarshalling supervisor reply: %w", err)
		}
		if !resp.OK {
			log.Errorf("error in healthcheck for allocation %s: %s", allocation.ID, resp.Error)
			// count the failure under lock, then act on the snapshot
			s.lock.Lock()
			s.failures[allocation.ID]++
			failureCount := s.failures[allocation.ID]
			s.lock.Unlock()
			if failureCount >= 3 {
				log.Infof("escalating failure for allocation %s", allocation.ID)
				if err := s.escalateFailure(allocation); err != nil {
					log.Errorf("failed to escalate failure: %s", err)
				} else {
					log.Debug("escalated failure, resetting healthcheck failures counter")
					s.lock.Lock()
					delete(s.failures, allocation.ID)
					s.lock.Unlock()
				}
			}
		} else {
			// healthy: reset the consecutive-failure counter
			log.Infof("successfully healthchecked allocation %s", allocation.ID)
			s.lock.Lock()
			delete(s.failures, allocation.ID)
			s.lock.Unlock()
		}
	case <-time.After(HealthCheckTimeout):
		// the task may have terminated while we were waiting
		if s.manifest.IsTerminatedTask(manifestKey) {
			return nil
		}
		log.Warnf("timeout waiting for supervisor reply for allocation %s", allocation.ID)
		s.lock.Lock()
		s.failures[allocation.ID]++
		v := s.failures[allocation.ID]
		s.lock.Unlock()
		if v >= 3 {
			if err := s.escalateFailure(allocation); err != nil {
				log.Errorf("failed to escalate failure: %s", err)
			} else {
				log.Debug("escalated failure, resetting healthcheck failures counter")
				s.lock.Lock()
				delete(s.failures, allocation.ID)
				s.lock.Unlock()
			}
		}
		return fmt.Errorf("timeout waiting for supervisor reply")
	}
	return nil
}
// escalateFailure handles escalation when an allocation repeatedly fails its
// healthcheck: it asks the allocation's actor to restart the allocation and,
// on an OK reply, bumps the allocation's escalation counter and resets its
// failure counter (under lock). Returns an error on send/invoke failure, a
// NOK reply, or when no reply arrives within FailureEscalationTimeout.
func (s *Supervisor) escalateFailure(allocation jtypes.AllocationManifest) error {
	// TODO we need to decide how to handle repeated failures and also correlated failures
	// from a node.
	// Also, we should not restart at first failure, but wait for a number of
	// consecutive failures.
	// See https://gitlab.com/nunet/device-management-service/-/issues/794
	log.Debugf("escalating failure for allocation %s", allocation.Handle.String())
	log.Debugw("escalating failure for allocation",
		"labels", string(observability.LabelAllocation),
		"allocationHandle", allocation.Handle.String(),
		"supervisorID", s.id)
	expiry := actor.MakeExpiry(5 * time.Second)
	msg, err := actor.Message(
		s.actor.Handle(),
		allocation.Handle,
		behaviors.AllocationRestartBehavior,
		allocation.ID,
		actor.WithMessageExpiry(expiry),
	)
	if err != nil {
		return err
	}
	replyCh, err := s.actor.Invoke(msg)
	if err != nil {
		return err
	}
	select {
	case reply := <-replyCh:
		defer reply.Discard()
		var resp behaviors.AllocationRestartResponse
		if err := json.Unmarshal(reply.Message, &resp); err != nil {
			return fmt.Errorf("unmarshalling supervisor reply: %w", err)
		}
		if !resp.OK {
			return fmt.Errorf("error restarting allocation: %s", resp.Error)
		}
		// record the escalation and give the restarted allocation a clean slate
		s.lock.Lock()
		defer s.lock.Unlock()
		s.escalations[allocation.ID]++
		s.failures[allocation.ID] = 0
		return nil
	case <-time.After(FailureEscalationTimeout):
		return fmt.Errorf("timeout waiting for supervisor reply")
	}
}
// Update updates the supervisor with a new ensemble manifest.
//
// 1. register new healthchecks and start healthchecking
// 2. unregister healthchecks for allocations removed from the manifest
// 3. update manifest
//
// For 1. and 2., the updates will already be reflected on the next ticker.
func (s *Supervisor) Update(manifestReader jtypes.ManifestReader) {
	manifest := manifestReader.Read()
	log.Debug("updating supervisor")
	// 1. Register healthchecks for just the new allocations (those not yet
	// present in the old manifest).
	var wg sync.WaitGroup
	for _, allocation := range manifest.Allocations {
		if allocation.Healthcheck.Type == "" {
			continue
		}
		if _, ok := s.getAllocation(types.AllocationNameFromID(allocation.ID)); !ok {
			wg.Add(1)
			go func(allocation jtypes.AllocationManifest) {
				defer wg.Done()
				if err := s.registerHealthCheck(allocation, manifest.Orchestrator); err != nil {
					log.Errorf("failed to register healthcheck for allocation: %s", err)
				}
			}(allocation)
		}
	}
	wg.Wait()
	// 2. Unregister healthchecks whose allocations are absent from the new
	// manifest. The previous implementation instead unregistered every
	// allocation that was already known, which silently stopped
	// healthchecking allocations that survived the update.
	s.lock.Lock()
	for id := range s.registeredHealthChecks {
		if _, ok := manifest.Allocations[types.AllocationNameFromID(id)]; !ok {
			delete(s.registeredHealthChecks, id)
		}
	}
	// 3. Update the manifest
	s.manifest = manifest
	s.lock.Unlock()
}
// getAllocation looks up an allocation by its manifest key, reading the
// manifest under lock; the boolean reports whether the key exists.
func (s *Supervisor) getAllocation(name string) (jtypes.AllocationManifest, bool) {
	s.lock.Lock()
	defer s.lock.Unlock()
	alloc, exists := s.manifest.Allocations[name]
	return alloc, exists
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package orchestrator
import (
"encoding/json"
"fmt"
"sync"
"time"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/dms/behaviors"
jtypes "gitlab.com/nunet/device-management-service/dms/jobs/types"
"gitlab.com/nunet/device-management-service/observability"
"gitlab.com/nunet/device-management-service/tokenomics/eventhandler"
"gitlab.com/nunet/device-management-service/tokenomics/events"
)
// SubnetDestroyRequest asks a peer to tear down the subnet with the given ID.
type SubnetDestroyRequest struct {
	SubnetID string
}

// SubnetDestroyResponse is the reply to a SubnetDestroyRequest;
// Error is populated when OK is false.
type SubnetDestroyResponse struct {
	OK    bool
	Error string
}

// AllocationStopRequest asks an allocation actor to stop the allocation.
type AllocationStopRequest struct {
	AllocationID string
}

// AllocationStopResponse is the reply to an AllocationStopRequest;
// Error is populated when OK is false.
type AllocationStopResponse struct {
	OK    bool
	Error string
}
// Shutdown tears the deployment down: it destroys the subnet on every node
// (plus the orchestrator's supervisor when the orchestrator joined the
// subnet), stops every allocation in parallel, and — via the deferred
// cleanup — persists allocation statuses, marks the deployment completed,
// cancels the orchestrator context, and emits a DeploymentStop event per
// contract. Returns the aggregated subnet-destroy and allocation-stop
// errors, if any. If the orchestrator is already completed or shutting down
// it returns nil immediately.
func (o *BasicOrchestrator) Shutdown() error {
	allocStatuses := make(map[string]jtypes.AllocationStatus)
	// statusMx guards allocStatuses, which the allocation-stop goroutines
	// below write concurrently (the previous implementation wrote the map
	// without synchronization — a data race that can panic the process).
	var statusMx sync.Mutex
	if o.status == jtypes.DeploymentStatusCompleted || o.status == jtypes.DeploymentStatusShuttingDown {
		log.Error("orchestrator already shutting down or completed")
		return nil
	}
	o.setStatus(jtypes.DeploymentStatusShuttingDown)
	o.lock.Lock()
	log.Infow("orchestrator_shutdown_initiated",
		"labels", []string{string(observability.LabelDeployment)},
		"orchestratorID", o.id)
	defer func() {
		// set statuses on alloc manifest; this runs after both wait groups
		// below, so allocStatuses is no longer being written concurrently
		for allocName, status := range allocStatuses {
			err := o.manifest.UpdateAllocation(allocName, func(alloc *jtypes.AllocationManifest) {
				alloc.Status = status
			})
			if err != nil {
				log.Errorf("failed to update allocation manifest %s status: %v", allocName, err)
			}
		}
		o.lock.Unlock()
		log.Infow("status updated",
			"labels", []string{string(observability.LabelDeployment)},
			"orchestratorID", o.id,
			"status", jtypes.DeploymentStatusCompleted.String(),
		)
		// set orchestrator status
		o.setStatus(jtypes.DeploymentStatusCompleted)
		if o.cancel != nil {
			o.cancel()
		}
		for _, v := range o.contracts {
			evt := events.DeploymentStop{
				EventBase:       events.EventBase{Type: events.DeploymentStopEvent},
				DeploymentID:    o.manifest.ID,
				OrchestratorID:  o.id,
				HeadContractDID: v.DID, // treat contract as if head of contract chain, won't be taken into consideration in billing if contract is p2p
			}
			o.contractEventHandler.Push(eventhandler.Event{
				ContractHostDID: v.Host,
				ContractDID:     v.DID,
				Payload:         evt,
			})
		}
		o.UpdateAllocationStatus()
	}()
	// Destroy the subnet on every node, and on the orchestrator's own
	// supervisor when the orchestrator itself joined the subnet.
	destroyHandles := map[string]actor.Handle{}
	for _, node := range o.manifest.Nodes {
		destroyHandles[node.ID] = node.Handle
	}
	if o.manifest.Subnet.Join {
		destroyHandles["orchestrator"] = o.actor.Supervisor()
	}
	errCh1 := make(chan error, len(destroyHandles))
	wg := sync.WaitGroup{}
	for id, handle := range destroyHandles {
		wg.Add(1)
		go func(h actor.Handle, id string) {
			defer wg.Done()
			msg, err := actor.Message(
				o.actor.Handle(),
				h,
				fmt.Sprintf(behaviors.SubnetDestroyBehavior.DynamicTemplate, o.manifest.ID),
				SubnetDestroyRequest{
					SubnetID: o.manifest.ID,
				},
				actor.WithMessageExpiry(actor.MakeExpiry(5*time.Second)),
			)
			if err != nil {
				log.Errorf("error creating stop message for %s/%s: %s", o.manifest.ID, id, err)
				errCh1 <- err
				return
			}
			// invoke the subnet destroy message
			replyCh, err := o.actor.Invoke(msg)
			if err != nil {
				log.Errorf("error invoking stop message for %s/%s: %s", o.manifest.ID, id, err)
				errCh1 <- err
				return
			}
			var reply actor.Envelope
			// wait for the reply
			select {
			case reply = <-replyCh:
				defer reply.Discard()
				var resp SubnetDestroyResponse
				if err := json.Unmarshal(reply.Message, &resp); err != nil {
					log.Errorf("error unmarshalling subnet destroy response: %v", err)
					errCh1 <- err
					return
				}
				if !resp.OK {
					log.Errorf("failed to destroy subnet %s/%s: %v", o.manifest.ID, id, resp.Error)
					errCh1 <- fmt.Errorf("failed to destroy subnet %s/%s: %v", o.manifest.ID, id, resp.Error)
					return
				}
			case <-time.After(SubnetDestroyTimeout):
				log.Errorf("timeout destroying subnet %s", o.manifest.ID)
				errCh1 <- fmt.Errorf("timeout destroying subnet %s", o.manifest.ID)
				return
			}
			log.Infof("subnet %s destroyed", o.manifest.ID)
		}(handle, id)
	}
	wg.Wait()
	close(errCh1)
	// Stop every allocation in parallel, recording the completed status for
	// the deferred manifest update above.
	errCh2 := make(chan error, len(o.manifest.Allocations))
	wg = sync.WaitGroup{}
	for allocName, alloc := range o.manifest.Allocations {
		wg.Add(1)
		go func(h actor.Handle, allocID string, allocName string) {
			defer wg.Done()
			msg, err := actor.Message(
				o.actor.Handle(),
				h,
				fmt.Sprintf(behaviors.AllocationShutdownBehavior.DynamicTemplate, o.manifest.ID),
				AllocationStopRequest{
					AllocationID: allocID,
				},
				actor.WithMessageExpiry(actor.MakeExpiry(AllocationShutdownTimeout)),
			)
			if err != nil {
				log.Errorf("error creating stop message for alloc: %s: %v", allocID, err)
				errCh2 <- err
				return
			}
			// invoke the stop message
			replyCh, err := o.actor.Invoke(msg)
			if err != nil {
				log.Errorf("error invoking stop message for %s: %v", allocID, err)
				errCh2 <- err
				return
			}
			// wait for the reply
			var reply actor.Envelope
			select {
			case reply = <-replyCh:
				defer reply.Discard()
				var resp AllocationStopResponse
				if err := json.Unmarshal(reply.Message, &resp); err != nil {
					log.Errorf("error unmarshalling stop allocation response: %s", err)
					errCh2 <- err
					return
				}
				if !resp.OK {
					log.Errorf("failed to stop allocation %s", allocID)
					errCh2 <- fmt.Errorf("failed to stop allocation %s", allocID)
					return
				}
			case <-time.After(AllocationShutdownTimeout):
				log.Errorf("timeout stopping allocation %s", allocID)
				errCh2 <- fmt.Errorf("timeout stopping allocation %s", allocID)
				return
			}
			log.Infof("allocation %s stopped", allocID)
			// goroutines from this loop run concurrently; guard the shared map
			statusMx.Lock()
			allocStatuses[allocName] = jtypes.AllocationCompleted
			statusMx.Unlock()
		}(o.manifest.Nodes[alloc.NodeID].Handle, alloc.ID, allocName)
	}
	wg.Wait()
	log.Infow("orchestrator_shutdown_complete",
		"labels", []string{string(observability.LabelDeployment)},
		"orchestratorID", o.id)
	close(errCh2)
	err1 := aggregateErrors(errCh1)
	err2 := aggregateErrors(errCh2)
	if err1 != nil || err2 != nil {
		return fmt.Errorf("errors occurred during shutdown: %w, %w", err1, err2)
	}
	return nil
}
// DeploymentRevertRequest asks a node to revert (tear down) the
// allocations it deployed for an ensemble.
type DeploymentRevertRequest struct {
	EnsembleID string
	// AllocsByName carries full allocation IDs (as produced by the
	// allocation ID generator), not bare allocation names.
	AllocsByName []string
}

// DeploymentRevertResponse is the reply to a DeploymentRevertRequest.
type DeploymentRevertResponse struct {
	OK    bool
	Error string
}
// revertNodeDeployment asks a single node to tear down the allocations it
// holds for this ensemble and, on success, removes the node from the
// orchestrator's manifest. All failures (message creation, invocation,
// bad reply, timeout) are logged and cause an early return without
// touching the manifest — the revert is best-effort.
func (o *BasicOrchestrator) revertNodeDeployment(
	cfg jtypes.EnsembleConfig, n string, h actor.Handle,
) {
	ncfg, ok := cfg.Node(n)
	if !ok {
		log.Warnf("revert node: failed to find node config for %s", n)
		return
	}
	// Instead of sending the allocation names with nodeID prefix,
	// we should send the complete allocation IDs as created by types.ConstructAllocationID
	// to match exactly how they're stored in the allocator
	allocIDs := make([]string, 0, len(ncfg.Allocations))
	for _, allocName := range ncfg.Allocations {
		// Generate full allocation ID using the generator; a failure for one
		// allocation skips it but still reverts the rest.
		allocID, err := o.allocationIDGenerator.GenerateFullAllocationID(o.id, n, allocName)
		if err != nil {
			log.Errorf("failed to generate full allocation ID for %s.%s: %v", n, allocName, err)
			continue
		}
		allocIDs = append(allocIDs, allocID)
	}
	msg, err := actor.Message(
		o.actor.Handle(),
		h,
		behaviors.DeploymentRevertBehavior,
		DeploymentRevertRequest{
			EnsembleID:   o.id,
			AllocsByName: allocIDs,
		},
	)
	if err != nil {
		log.Debugw("revert_message_create_failure",
			"labels", []string{string(observability.LabelDeployment)},
			"nodeID", n,
			"error", err)
		return
	}
	replyCh, err := o.actor.Invoke(msg)
	if err != nil {
		log.Debugw("revert_message_invoke_failure",
			"labels", []string{string(observability.LabelDeployment)},
			"nodeID", n,
			"error", err)
		return
	}
	// Wait for revert response with timeout
	select {
	case reply := <-replyCh:
		defer reply.Discard()
		var response DeploymentRevertResponse
		if err := json.Unmarshal(reply.Message, &response); err != nil {
			log.Errorw("failed to unmarshal revert response",
				"labels", []string{string(observability.LabelDeployment)},
				"nodeID", n,
				"error", err)
			return
		}
		if !response.OK {
			log.Errorw("revert failed on node",
				"labels", []string{string(observability.LabelDeployment)},
				"nodeID", n,
				"error", response.Error)
			return
		}
		log.Debugw("revert completed successfully",
			"labels", []string{string(observability.LabelDeployment)},
			"nodeID", n)
	case <-time.After(30 * time.Second):
		log.Errorw("revert timeout on node",
			"labels", []string{string(observability.LabelDeployment)},
			"nodeID", n)
		return
	}
	// Only reached after a confirmed successful revert: drop the node and
	// its allocations/subnet state from the manifest.
	o.removeNodeFromManifest(n)
	log.Debugw("revert message sent successfully",
		"labels", []string{string(observability.LabelDeployment)},
		"nodeID", n)
}
// revert tears down a (possibly partial) deployment described by mf: it
// reverts the deployment on every node in the manifest and, if the
// orchestrator joined the ensemble subnet, destroys the subnet via the
// supervisor. Failures are logged; the revert is best-effort.
func (o *BasicOrchestrator) revert(cfg jtypes.EnsembleConfig, mf jtypes.EnsembleManifest) {
	log.Infow("reverting manifest",
		"labels", []string{string(observability.LabelDeployment)},
		"orchestratorID", mf.ID)
	// Log the manifest content to see what nodes are being reverted
	log.Infow("manifest nodes being reverted",
		"labels", []string{string(observability.LabelDeployment)},
		"orchestratorID", mf.ID,
		"manifestNodes", len(mf.Nodes),
		"nodeIDs", func() []string {
			var nodeIDs []string
			for nodeID := range mf.Nodes {
				nodeIDs = append(nodeIDs, nodeID)
			}
			return nodeIDs
		}(),
	)
	// Collect all nodes first to avoid modifying the map during iteration:
	// revertNodeDeployment calls removeNodeFromManifest, which mutates
	// mf.Nodes and would otherwise cut the range loop short.
	//
	// BUG FIX: the slice must be allocated with zero length and
	// len(mf.Nodes) capacity. The previous make([]T, len(mf.Nodes))
	// followed by append left len(mf.Nodes) zero-valued entries at the
	// front, so reverts were also attempted for empty node IDs.
	nodesToRevert := make([]struct {
		nodeID string
		handle actor.Handle
	}, 0, len(mf.Nodes))
	for n, nmf := range mf.Nodes {
		nodesToRevert = append(nodesToRevert, struct {
			nodeID string
			handle actor.Handle
		}{n, nmf.Handle})
	}
	for i, node := range nodesToRevert {
		log.Infow("attempting to revert node",
			"labels", []string{string(observability.LabelDeployment)},
			"orchestratorID", mf.ID,
			"nodeID", node.nodeID,
			"handle", node.handle.String(),
			"iteration", i+1,
			"total", len(nodesToRevert),
		)
		o.revertNodeDeployment(cfg, node.nodeID, node.handle)
		log.Infow("completed reverting node",
			"labels", []string{string(observability.LabelDeployment)},
			"orchestratorID", mf.ID,
			"nodeID", node.nodeID,
		)
	}
	// If subnet was being joined, destroy it during revert
	if mf.Subnet.Join {
		// IMPORTANT
		// For the deployment restoration (#1166) to work properly
		// We need to cleanup the subnetManifest state for the orchestrator
		// to trigger the code path that will result in re-creating the subnet on the orchestrator
		// during redeployment (during restoration)
		// without this line, the redeployment will fail and will continually try to redeploy without success
		// because the orchestrator will try to join the subnet where it wasn't created (i.e: subnet doesnt exist error)
		o.revertSubnetManifest()
		log.Infow("destroying subnet during revert",
			"labels", []string{string(observability.LabelDeployment)},
			"orchestratorID", mf.ID,
		)
		// Send subnet destroy request to the supervisor
		msg, err := actor.Message(
			o.actor.Handle(),
			o.actor.Supervisor(),
			fmt.Sprintf(behaviors.SubnetDestroyBehavior.DynamicTemplate, mf.ID),
			SubnetDestroyRequest{
				SubnetID: mf.ID,
			},
		)
		if err != nil {
			log.Errorw("failed to create subnet destroy message",
				"labels", []string{string(observability.LabelDeployment)},
				"orchestratorID", mf.ID,
				"error", err,
			)
			return
		}
		// Send the message and wait for response
		replyCh, err := o.actor.Invoke(msg)
		if err != nil {
			log.Errorw("failed to invoke subnet destroy message during revert",
				"labels", []string{string(observability.LabelDeployment)},
				"orchestratorID", mf.ID,
				"error", err,
			)
			return
		}
		// Wait for subnet destroy response with timeout
		select {
		case reply := <-replyCh:
			defer reply.Discard()
			var response SubnetDestroyResponse
			if err := json.Unmarshal(reply.Message, &response); err != nil {
				log.Errorw("failed to unmarshal subnet destroy response",
					"labels", []string{string(observability.LabelDeployment)},
					"orchestratorID", mf.ID,
					"error", err,
				)
				return
			} else if !response.OK {
				log.Errorw("subnet destroy failed during revert",
					"labels", []string{string(observability.LabelDeployment)},
					"orchestratorID", mf.ID,
					"error", response.Error,
				)
				return
			} else {
				log.Infow("subnet destroy completed successfully during revert",
					"labels", []string{string(observability.LabelDeployment)},
					"orchestratorID", mf.ID,
				)
			}
		case <-time.After(30 * time.Second):
			log.Errorw("subnet destroy timeout during revert",
				"labels", []string{string(observability.LabelDeployment)},
				"orchestratorID", mf.ID,
			)
		}
	}
}
// revertSubnetManifest removes the orchestrator's own entries from the
// subnet manifest. Only the orchestrator's view is touched — the other
// peers keep their subnet state, so the subnet is not re-shaped with new
// IPs and routing tables. Clearing these entries is what re-triggers
// subnet creation during deployment restoration.
func (o *BasicOrchestrator) revertSubnetManifest() {
	// Resolve the orchestrator's IP before its index entry disappears;
	// the routing-table and used-IP entries are keyed by it.
	ip := o.subnetManifest.IndexRoutingTable[orchSubnetName]
	delete(o.subnetManifest.DNSRecords, orchSubnetName)
	delete(o.subnetManifest.IndexRoutingTable, orchSubnetName)
	delete(o.subnetManifest.RoutingTable, ip)
	delete(o.subnetManifest.UsedIPs, ip)
	log.Infow("reverted subnet manifest",
		"labels", []string{string(observability.LabelDeployment)},
		"orchestratorID", o.id,
	)
}
// removeNodeFromManifest deletes a node and all of its allocations from
// the manifest, stripping the corresponding subnet state (used IPs,
// routing entries, DNS records) along the way. It is a no-op when the
// node is unknown.
func (o *BasicOrchestrator) removeNodeFromManifest(name string) {
	log.Infof("removing node %s from manifest", name)
	o.lock.Lock()
	defer o.lock.Unlock()

	node, found := o.manifest.Node(name)
	if !found {
		return
	}

	// Strip each allocation's subnet state. The logic is kept inline here
	// (instead of calling a helper) to avoid nested locking.
	for _, allocName := range node.Allocations {
		amf, exists := o.manifest.Allocations[allocName]
		if !exists {
			log.Warnf(
				"skipping subnet removal: allocation %s not found in manifest",
				allocName,
			)
			continue
		}
		ip, exists := o.subnetManifest.IndexRoutingTable[allocName]
		if !exists {
			log.Warnf(
				"skipping subnet removal: allocation %s not found in index routing table",
				allocName,
			)
			continue
		}
		// Drop the allocation's IP, route, and index entries.
		delete(o.subnetManifest.UsedIPs, ip)
		delete(o.subnetManifest.RoutingTable, ip)
		delete(o.subnetManifest.IndexRoutingTable, allocName)
		// Drop any DNS record registered for this allocation.
		if amf.DNSName != "" {
			delete(o.subnetManifest.DNSRecords, amf.DNSName)
		}
		log.Debugf("Removed allocation %s with IP %s from subnet", allocName, ip)
	}

	// Mark each allocation terminated and then drop it from the manifest.
	// XXX the status is irrelevant once deleted, but it is kept because we
	// will most likely move to retaining removed allocations (with a
	// "removed" status) to preserve a history of the deployment.
	for _, allocName := range node.Allocations {
		amf := o.manifest.Allocations[allocName]
		amf.Status = jtypes.AllocationTerminated
		o.manifest.Allocations[allocName] = amf
		delete(o.manifest.Allocations, allocName)
	}
	delete(o.manifest.Nodes, name)
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package orchestrator
import (
"encoding/json"
"fmt"
"testing"
"time"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/stretchr/testify/require"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/dms/behaviors"
jtypes "gitlab.com/nunet/device-management-service/dms/jobs/types"
"gitlab.com/nunet/device-management-service/lib/crypto"
"gitlab.com/nunet/device-management-service/lib/did"
"gitlab.com/nunet/device-management-service/lib/ucan"
"gitlab.com/nunet/device-management-service/network"
)
// TestDMS bundles the pieces of a simulated DMS node used in
// orchestrator tests: its key pair, network peer, actors, and
// bookkeeping for mocked behavior invocations.
type TestDMS struct {
	priv   crypto.PrivKey
	pub    crypto.PubKey
	peerID peer.ID
	handle actor.Handle
	actor  actor.Actor
	super  actor.Actor // orchestrator actor is the child of node actor. This is the parent actor in the test.
	net    network.Network
	// channels receives a signal per behavior name whenever the
	// corresponding mocked behavior fires, so tests can await calls.
	channels         map[string]chan struct{}
	allocationActors map[string]actor.Actor // Keep allocation actors in memory
}
// MakeProvider builds a TestDMS playing the role of a compute provider:
// a single mock actor with no parent (super is nil).
func MakeProvider(t *testing.T, substrate *network.Substrate) TestDMS {
	t.Helper()
	mockActor, netPeer, handle, priv, pub := actor.NewMockActorForTest(t, actor.Handle{}, substrate)
	return TestDMS{
		priv:             priv,
		pub:              pub,
		peerID:           netPeer.GetHostID(),
		handle:           handle,
		actor:            mockActor,
		super:            nil,
		net:              netPeer,
		channels:         make(map[string]chan struct{}),
		allocationActors: make(map[string]actor.Actor),
	}
}
// MakeOrchestrator builds a TestDMS playing the role of an orchestrator:
// a mock parent (node) actor whose started child actor serves as the
// orchestrator actor. The TestDMS handle is the parent's handle.
func MakeOrchestrator(t *testing.T, substrate *network.Substrate) TestDMS {
	t.Helper()
	parentActor, netPeer, handle, priv, pub := actor.NewMockActorForTest(t, actor.Handle{}, substrate)
	orchActor, err := parentActor.CreateChild("test-orch-child", handle)
	require.NoError(t, err)
	require.NoError(t, orchActor.Start())
	return TestDMS{
		priv:             priv,
		pub:              pub,
		peerID:           netPeer.GetHostID(),
		handle:           handle,
		actor:            orchActor,
		super:            parentActor,
		net:              netPeer,
		channels:         make(map[string]chan struct{}),
		allocationActors: make(map[string]actor.Actor),
	}
}
// MockOrchestratorBehaviors registers, on the supervisor (node) actor,
// OK-replying handlers for the subnet create/join/destroy behaviors of
// the given ensemble. Each handler also signals dms.channels[behavior]
// so tests can await the invocation.
func (dms *TestDMS) MockOrchestratorBehaviors(t *testing.T, ensembleID string) {
	t.Helper()
	// registerOK installs a handler on the supervisor that replies with the
	// given response payload and (asynchronously) signals the behavior's
	// channel. This replaces three verbatim copies of the same wiring.
	registerOK := func(behavior string, response interface{}) {
		dms.channels[behavior] = make(chan struct{}, 1)
		require.NoError(t, dms.super.AddBehavior(behavior, func(msg actor.Envelope) {
			defer func() {
				msg.Discard()
				go func() { dms.channels[msg.Behavior] <- struct{}{} }()
			}()
			reply, err := actor.ReplyTo(msg, response)
			require.NoError(t, err)
			reply.To = msg.From
			reply.From = dms.handle
			require.NoError(t, dms.actor.Send(reply))
		}))
	}
	registerOK(fmt.Sprintf(behaviors.SubnetCreateBehavior.DynamicTemplate, ensembleID), SubnetCreateResponse{OK: true})
	registerOK(fmt.Sprintf(behaviors.SubnetJoinBehavior.DynamicTemplate, ensembleID), SubnetJoinResponse{OK: true})
	// SubnetDestroy is needed for revert operations.
	registerOK(fmt.Sprintf(behaviors.SubnetDestroyBehavior.DynamicTemplate, ensembleID), SubnetDestroyResponse{OK: true})
}
// MockDeploymentBehaviors registers, on the DMS actor, handlers for the
// full deployment flow: bidding, deployment commit, allocation
// deployment (which spawns real child allocation actors), and the subnet
// behaviors. Each handler replies OK and signals dms.channels[behavior]
// so tests can await it. A custom bidBehavior may be supplied; when nil,
// a default responder that signs a bid with the provider's private key
// is used. When an orchestratorActor is given, capabilities are granted
// both ways between it and each created allocation actor, simulating
// what happens in the commit phase.
func (dms *TestDMS) MockDeploymentBehaviors(t *testing.T, ensembleID string, bidBehavior func(msg actor.Envelope), orchestratorActor ...actor.Actor) {
	t.Helper()
	defaultBidBehavior := func(msg actor.Envelope) {
		defer func() {
			msg.Discard()
			go func() { dms.channels[msg.Behavior] <- struct{}{} }()
		}()
		var request jtypes.EnsembleBidRequest
		if err := json.Unmarshal(msg.Message, &request); err != nil {
			t.Fatalf("unmarshal bid request: %s", err)
		}
		// send bid response
		bid := jtypes.Bid{
			V1: &jtypes.BidV1{
				EnsembleID: request.ID,
				NodeID:     request.Request[0].V1.NodeID,
				Peer:       dms.handle.Address.HostID,
				Location:   jtypes.Location{Country: "US"},
				Handle:     dms.handle,
			},
		}
		// sign the bid using the provider's private key
		// Create DID provider for signing
		providerDID := did.NewProvider(dms.actor.Handle().DID, dms.priv)
		// Sign the bid
		require.NoError(t, bid.Sign(providerDID))
		var opt []actor.MessageOption
		// Broadcast messages need an explicit source on the reply.
		if msg.IsBroadcast() {
			opt = append(opt, actor.WithMessageSource(dms.actor.Handle()))
		}
		reply, err := actor.ReplyTo(msg, bid, opt...)
		require.NoError(t, err)
		reply.To = msg.From
		reply.From = dms.handle
		require.NoError(t, dms.actor.Send(reply))
	}
	// Add compute provider behaviors
	// NOTE(review): unlike the other channels below, this one is
	// unbuffered — confirm whether that is intentional.
	dms.channels[behaviors.BidRequestBehavior] = make(chan struct{})
	if bidBehavior == nil {
		bidBehavior = defaultBidBehavior
	}
	require.NoError(t, dms.actor.AddBehavior(behaviors.BidRequestBehavior, bidBehavior, []actor.BehaviorOption{
		actor.WithBehaviorTopic(behaviors.BidRequestTopic),
	}...))
	// Commit phase: always acknowledge the deployment commit.
	dms.channels[behaviors.CommitDeploymentBehavior] = make(chan struct{}, 1)
	require.NoError(t, dms.actor.AddBehavior(behaviors.CommitDeploymentBehavior, func(msg actor.Envelope) {
		defer func() {
			msg.Discard()
			go func() { dms.channels[msg.Behavior] <- struct{}{} }()
		}()
		reply, err := actor.ReplyTo(msg, CommitDeploymentResponse{
			OK: true,
		})
		require.NoError(t, err)
		reply.To = msg.From
		reply.From = dms.handle
		require.NoError(t, dms.actor.Send(reply))
	}))
	// Allocation deployment: create a real child actor per requested
	// allocation, wire its subnet/start behaviors, and reply with the
	// handles of all allocation actors.
	dms.channels[behaviors.AllocationDeploymentBehavior] = make(chan struct{}, 1)
	require.NoError(t, dms.actor.AddBehavior(behaviors.AllocationDeploymentBehavior, func(msg actor.Envelope) {
		defer func() {
			msg.Discard()
			go func() { dms.channels[msg.Behavior] <- struct{}{} }()
		}()
		var request jtypes.AllocationDeploymentRequest
		if err := json.Unmarshal(msg.Message, &request); err != nil {
			t.Fatalf("unmarshal allocation deployment request: %s", err)
		}
		allocs := request.Allocations
		// Create actual allocation actors for each allocation
		allocationHandles := make(map[string]actor.Handle)
		// Create allocation actor for alloc1 if it doesn't exist
		for alloc := range allocs {
			if _, exists := dms.allocationActors[alloc]; !exists {
				allocationActor, err := dms.actor.CreateChild(alloc, dms.actor.Handle())
				require.NoError(t, err)
				// Set up subnet behaviors on the allocation actor
				require.NoError(t, allocationActor.AddBehavior(behaviors.SubnetAddPeerBehavior, func(msg actor.Envelope) {
					defer msg.Discard()
					reply, err := actor.ReplyTo(msg, behaviors.SubnetAddPeerResponse{
						OK: true,
					})
					require.NoError(t, err)
					reply.To = msg.From
					reply.From = allocationActor.Handle()
					require.NoError(t, allocationActor.Send(reply))
				}))
				require.NoError(t, allocationActor.AddBehavior(behaviors.SubnetMapPortBehavior, func(msg actor.Envelope) {
					defer msg.Discard()
					reply, err := actor.ReplyTo(msg, behaviors.SubnetMapPortResponse{
						OK: true,
					})
					require.NoError(t, err)
					reply.To = msg.From
					reply.From = allocationActor.Handle()
					require.NoError(t, allocationActor.Send(reply))
				}))
				require.NoError(t, allocationActor.AddBehavior(behaviors.SubnetDNSAddRecordsBehavior, func(msg actor.Envelope) {
					defer msg.Discard()
					reply, err := actor.ReplyTo(msg, behaviors.SubnetDNSAddRecordsResponse{
						OK: true,
					})
					require.NoError(t, err)
					reply.To = msg.From
					reply.From = allocationActor.Handle()
					require.NoError(t, allocationActor.Send(reply))
				}))
				require.NoError(t, allocationActor.AddBehavior(behaviors.AllocationStartBehavior, func(msg actor.Envelope) {
					defer msg.Discard()
					reply, err := actor.ReplyTo(msg, behaviors.AllocationStartResponse{
						OK: true,
					})
					require.NoError(t, err)
					reply.To = msg.From
					reply.From = allocationActor.Handle()
					require.NoError(t, allocationActor.Send(reply))
				}))
				require.NoError(t, allocationActor.Start())
				// Store the allocation actor in the TestDMS struct
				dms.allocationActors[alloc] = allocationActor
				// Grant capabilities between allocation actor and orchestrator
				// This simulates what happens in the commit phase
				if len(orchestratorActor) > 0 {
					// Grant capabilities from orchestrator to allocation actor (OrchestratorNamespace)
					err = orchestratorActor[0].Security().Grant(
						allocationActor.Handle().DID,
						orchestratorActor[0].Handle().DID,
						[]ucan.Capability{behaviors.OrchestratorNamespace},
						5*time.Minute,
					)
					require.NoError(t, err)
					// Grant capabilities from allocation actor to orchestrator (AllocationNamespace)
					err = allocationActor.Security().Grant(
						orchestratorActor[0].Handle().DID,
						allocationActor.Handle().DID,
						[]ucan.Capability{behaviors.AllocationNamespace},
						5*time.Minute,
					)
					require.NoError(t, err)
				}
			}
			allocationHandles[alloc] = dms.allocationActors[alloc].Handle()
		}
		reply, err := actor.ReplyTo(msg, jtypes.AllocationDeploymentResponse{
			OK:          true,
			Allocations: allocationHandles,
		})
		require.NoError(t, err)
		reply.To = msg.From
		reply.From = dms.handle
		require.NoError(t, dms.actor.Send(reply))
	}))
	// Subnet behaviors on the DMS actor itself.
	dms.channels[fmt.Sprintf(behaviors.SubnetCreateBehavior.DynamicTemplate, ensembleID)] = make(chan struct{}, 1)
	require.NoError(t, dms.actor.AddBehavior(fmt.Sprintf(behaviors.SubnetCreateBehavior.DynamicTemplate, ensembleID), func(msg actor.Envelope) {
		defer func() {
			msg.Discard()
			go func() { dms.channels[msg.Behavior] <- struct{}{} }()
		}()
		reply, err := actor.ReplyTo(msg, SubnetCreateResponse{
			OK: true,
		})
		require.NoError(t, err)
		reply.To = msg.From
		reply.From = dms.handle
		require.NoError(t, dms.actor.Send(reply))
	}))
	dms.channels[behaviors.SubnetAddPeerBehavior] = make(chan struct{}, 1)
	require.NoError(t, dms.actor.AddBehavior(behaviors.SubnetAddPeerBehavior, func(msg actor.Envelope) {
		defer func() {
			msg.Discard()
			go func() { dms.channels[msg.Behavior] <- struct{}{} }()
		}()
		reply, err := actor.ReplyTo(msg, behaviors.SubnetAddPeerResponse{
			OK: true,
		})
		require.NoError(t, err)
		reply.To = msg.From
		reply.From = dms.handle
		require.NoError(t, dms.actor.Send(reply))
	}))
	dms.channels[behaviors.SubnetDNSAddRecordsBehavior] = make(chan struct{}, 1)
	require.NoError(t, dms.actor.AddBehavior(behaviors.SubnetDNSAddRecordsBehavior, func(msg actor.Envelope) {
		defer func() {
			msg.Discard()
			go func() { dms.channels[msg.Behavior] <- struct{}{} }()
		}()
		reply, err := actor.ReplyTo(msg, behaviors.SubnetDNSAddRecordsResponse{
			OK: true,
		})
		require.NoError(t, err)
		reply.To = msg.From
		reply.From = dms.handle
		require.NoError(t, dms.actor.Send(reply))
	}))
	dms.channels[behaviors.SubnetMapPortBehavior] = make(chan struct{}, 1)
	require.NoError(t, dms.actor.AddBehavior(behaviors.SubnetMapPortBehavior, func(msg actor.Envelope) {
		defer func() {
			msg.Discard()
			go func() { dms.channels[msg.Behavior] <- struct{}{} }()
		}()
		reply, err := actor.ReplyTo(msg, behaviors.SubnetMapPortResponse{
			OK: true,
		})
		require.NoError(t, err)
		reply.To = msg.From
		reply.From = dms.handle
		require.NoError(t, dms.actor.Send(reply))
	}))
	// Join the bid-request topic so broadcast bids reach this actor.
	require.NoError(t, dms.actor.Subscribe(behaviors.BidRequestTopic, func(_ string) error {
		return nil
	}))
}
// MockCommittingStateBehaviors sets up the behaviors used when restoring
// a deployment from the committing state: deployment revert, allocation
// shutdown, subnet create/join for the provisioning phase, allocation
// start, and health-check registration. Each handler replies OK and
// signals dms.channels[behavior].
//
// NOTE: the original version registered AllocationStartBehavior twice
// with identical handlers; the redundant registration has been removed.
func (dms *TestDMS) MockCommittingStateBehaviors(t *testing.T, ensembleID string) {
	t.Helper()
	// registerOK installs a handler on the DMS actor that replies with the
	// given response payload and signals the behavior's channel (buffered,
	// size 1, sent synchronously from the handler's defer).
	registerOK := func(behavior string, response interface{}) {
		dms.channels[behavior] = make(chan struct{}, 1)
		require.NoError(t, dms.actor.AddBehavior(behavior, func(msg actor.Envelope) {
			defer func() {
				msg.Discard()
				dms.channels[msg.Behavior] <- struct{}{}
			}()
			reply, err := actor.ReplyTo(msg, response)
			require.NoError(t, err)
			reply.To = msg.From
			reply.From = dms.handle
			require.NoError(t, dms.actor.Send(reply))
		}))
	}
	// Deployment revert (called before redeploying).
	registerOK(behaviors.DeploymentRevertBehavior, DeploymentRevertResponse{OK: true})
	// Allocation shutdown for this ensemble.
	// NOTE(review): the reply type is AllocationRestartResponse, preserved
	// from the original — confirm whether a stop/shutdown response type was
	// intended here.
	registerOK(fmt.Sprintf(behaviors.AllocationShutdownBehavior.DynamicTemplate, ensembleID), behaviors.AllocationRestartResponse{OK: true})
	// Subnet behaviors for the provisioning phase; the dynamic behavior
	// names must match the ensemble ID used in the test.
	registerOK(fmt.Sprintf(behaviors.SubnetCreateBehavior.DynamicTemplate, ensembleID), SubnetCreateResponse{OK: true})
	registerOK(fmt.Sprintf(behaviors.SubnetJoinBehavior.DynamicTemplate, ensembleID), SubnetJoinResponse{OK: true})
	// Allocation start.
	registerOK(behaviors.AllocationStartBehavior, behaviors.AllocationStartResponse{OK: true})
	// Health-check registration.
	registerOK(behaviors.RegisterHealthcheckBehavior, behaviors.RegisterHealthcheckResponse{OK: true})
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package orchestrator
import (
"encoding/json"
"errors"
"fmt"
"strings"
"sync"
"time"
"go.uber.org/multierr"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/dms/behaviors"
jtypes "gitlab.com/nunet/device-management-service/dms/jobs/types"
"gitlab.com/nunet/device-management-service/utils"
)
// Update applies a modified configuration to a running ensemble.
//
// Removed allocations/nodes shall NOT be reverted in case new
// deployments fail.
//
// TODO: if one of the deployment fails, revert all other deployments
//
// TODO: we may have to see how to handle DependsOn, for now we'll ignore it
func (o *BasicOrchestrator) Update(modifiedCfg jtypes.EnsembleConfig, expiry time.Time) error {
	o.setStatus(jtypes.DeploymentStatusUpdating)
	defer o.setStatus(jtypes.DeploymentStatusRunning)

	if time.Now().After(expiry) {
		return fmt.Errorf("update expiry time has already passed")
	}
	if err := modifiedCfg.Validate(); err != nil {
		return fmt.Errorf("invalid ensemble configuration: %w", err)
	}
	if err := validateEnsembleUpdate(o.cfg, modifiedCfg); err != nil {
		return fmt.Errorf("invalid ensemble update: %w", err)
	}
	// Tear down the nodes and allocations that the new config no longer
	// contains.
	if err := o.handleEnsembleRemovals(modifiedCfg); err != nil {
		return fmt.Errorf("handling ensemble removals: %w", err)
	}
	// Deploy the newly added nodes.
	if err := o.handleNewAllocations(modifiedCfg, expiry); err != nil {
		return fmt.Errorf("deploying new nodes: %w", err)
	}
	// Refresh the supervisor with the updated manifest so new allocations
	// are supervised.
	o.supervisor.Update(jtypes.NewManifestReader(o.manifest))
	return nil
}
// handleNewAllocations deploys new allocations to the running ensemble.
//   - It does NOT remove allocations from existing nodes.
//
// It is similar to o.deploy but additionally:
//  1. extends the subnet's routing table and DNS records
//  2. updates the subnets of all other nodes in the ensemble
//
// It _implicitly_ updates o.manifest through o.commit and o.provision.
//
// BUG FIX: the log.Warnf/log.Errorf calls below previously used %w,
// which is only meaningful for error wrapping (fmt.Errorf) and renders
// as "%!w(...)" in plain formatting; they now use %v.
//
// TODO: 6. revert subnet updates
func (o *BasicOrchestrator) handleNewAllocations(
	modifiedCfg jtypes.EnsembleConfig, expiry time.Time,
) error {
	// Snapshot the nodes that already exist so they can be skipped during
	// subnet provisioning and excluded from the new-nodes config.
	existingNodes := make(map[string]string)
	for n, node := range o.manifest.Nodes {
		existingNodes[n] = node.Peer
	}
	newConfig, err := newConfigForDeploymentUpdate(
		o.cfg,
		modifiedCfg,
		existingNodes,
	)
	if err != nil {
		return fmt.Errorf("creating new nodes config: %w", err)
	}
	if len(newConfig.Allocations()) == 0 {
		// Nothing new to deploy.
		return nil
	}
	// addNodesAndAllocsToCfg merges the newly deployed nodes and their
	// allocations into the orchestrator's config under lock.
	addNodesAndAllocsToCfg := func() {
		for name := range newConfig.Nodes() {
			if node, ok := modifiedCfg.Node(name); ok {
				o.lock.Lock()
				o.cfg.AddNodeAndAllocations(name, node, newConfig.Allocations())
				o.lock.Unlock()
			}
		}
	}
	// updateManifest merges a freshly deployed manifest into the current
	// one, preserving allocations already recorded for existing nodes.
	updateManifest := func(manifest jtypes.EnsembleManifest) {
		currentManifest := o.Manifest()
		for n, node := range manifest.Nodes {
			for _, alloc := range node.Allocations {
				currentManifest.Allocations[alloc] = manifest.Allocations[alloc]
			}
			if nmf, ok := currentManifest.Node(n); ok {
				node.Allocations = append(node.Allocations, nmf.Allocations...)
			}
			currentManifest.Nodes[n] = node
		}
		o.updateManifest(currentManifest)
	}
deploy:
	for time.Now().Before(expiry) {
		// 1. bid
		bidC, err := NewBidCoordinator(o.id, o.actor)
		if err != nil {
			return fmt.Errorf("failed to create bidder: %w", err)
		}
		bidC.getNonce()
		candidate, err := bidC.bid(jtypes.NewEnsembleCfgReader(newConfig), o.DeploymentSnapshot().Candidates, expiry)
		if err != nil {
			if errors.Is(err, ErrCandidateNotFound) {
				log.Warnf("candidate deployment not found, redeploying: %v", err)
				continue deploy
			}
			return fmt.Errorf("failed to bid: %w", err)
		}
		// Seed the new manifest with the allocations we already have so the
		// commit step sees the full picture.
		newManifest := o.newManifest(newConfig)
		for alloc, amf := range o.manifest.Allocations {
			newManifest.Allocations[alloc] = amf
		}
		// 2. commit
		committer := NewCommitter(o.ctx, o.id, o.actor, o.allocationIDGenerator, o.nodeIDGenerator)
		updatedManifest, err := committer.commit(
			jtypes.NewEnsembleCfgReader(newConfig),
			jtypes.NewManifestReader(newManifest), candidate)
		if err != nil {
			log.Warnf("committing for new nodes: %v", err)
			continue deploy
		}
		// 3. extend subnet manifest with new nodes
		provisioner := NewProvisioner(o.ctx, o.cancel, o.actor, o.subnetManifest, o.allocationIDGenerator)
		addedDNSRecords := make(map[string]string, len(updatedManifest.Allocations))
		routingTableExtension := make(map[string]string, len(updatedManifest.Allocations))
		for allocName, alloc := range updatedManifest.Allocations {
			err := provisioner.addAllocationToSubnet(updatedManifest, allocName)
			if err != nil {
				log.Warnf(
					"error adding allocation %s to subnet: %v",
					allocName, err)
				// Undo the subnet additions made so far, then retry the
				// whole deploy loop.
				err := o.removeAllocationsFromSubnet(
					updatedManifest,
					utils.MapKeysToSlice(updatedManifest.Allocations),
				)
				if err != nil {
					log.Warnf("error removing allocations from subnet: %v", err)
				}
				continue deploy
			}
			ip, ok := o.getAllocIP(allocName)
			if !ok {
				// NOTE(review): ip is empty here but is still recorded below —
				// confirm whether this case should skip the allocation instead.
				log.Warnf("allocation %s not found in subnet", allocName)
			}
			addedDNSRecords[alloc.DNSName] = ip
			routingTableExtension[ip] = updatedManifest.Nodes[alloc.NodeID].Peer
		}
		// 4. provision subnet (skipping the nodes that already exist)
		skip := utils.MapKeysToSlice(existingNodes)
		mfAfterSubnet, err := provisioner.provisionSubnet(updatedManifest, skip...)
		if err != nil {
			o.revert(newConfig, updatedManifest)
			log.Errorf("provision subnet for new nodes (will revert deployment): %v", err)
			continue deploy
		}
		// 5. update existent nodes with new subnet information
		err = provisioner.updateSubnetAllocations(mfAfterSubnet, addedDNSRecords, routingTableExtension)
		if err != nil {
			o.revert(newConfig, updatedManifest)
			provisioner.revertSubnetAllocationsUpdate(mfAfterSubnet, addedDNSRecords, routingTableExtension)
			log.Warnf("updating subnet allocations for new nodes (will revert deployment): %v", err)
			continue deploy
		}
		// 6. provision allocations
		mfAfterProvisionAllocs, err := provisioner.provisionAllocations(newConfig, mfAfterSubnet)
		if err != nil {
			o.revert(newConfig, updatedManifest)
			provisioner.revertSubnetAllocationsUpdate(mfAfterSubnet, addedDNSRecords, routingTableExtension)
			log.Warnf("provisioning allocations for new nodes (will revert deployment): %v", err)
			continue deploy
		}
		log.Info("updated: new nodes deployed successfully")
		// 7. update config and manifest with added nodes
		updateManifest(mfAfterProvisionAllocs)
		addNodesAndAllocsToCfg()
		o.deploymentSnapshot.Candidates = candidate
		return nil
	}
	// Expiry reached without a successful deployment; give up quietly
	// (matches original behavior: no error on expiry).
	return nil
}
// handleEnsembleRemovals handles both removals of nodes and allocations in
// a best effort basis: failures are collected into a multierror and work
// continues, so a partial failure still tears down as much as possible.
// The aggregated error (possibly nil) is returned to the caller.
func (o *BasicOrchestrator) handleEnsembleRemovals(modifiedCfg jtypes.EnsembleConfig) error {
	var errs error
	log.Infof("removing nodes and allocations from config %+v", modifiedCfg.V1)
	// 1. teardown removed nodes
	removeNodesCfg, err := newConfigForRemovedNodes(
		o.cfg, modifiedCfg,
	)
	if err != nil {
		errs = multierr.Append(errs, err)
	} else if len(removeNodesCfg.Nodes()) > 0 {
		mf, err := manifestOnlyForNodes(o.manifest, utils.MapKeysToSlice(removeNodesCfg.Nodes()))
		if err != nil {
			errs = multierr.Append(errs, err)
		} else {
			o.revert(removeNodesCfg, mf)
			o.removeNodesAllocationsFromCfg(utils.MapKeysToSlice(removeNodesCfg.Nodes()))
		}
	}
	// 2. teardown allocations from existent nodes
	for n := range o.cfg.Nodes() {
		allocs := identifyRemovedAllocations(o.cfg, modifiedCfg, n)
		if len(allocs) == 0 {
			continue
		}
		// Generate manifest keys using the allocation ID generator;
		// key-generation failures are logged and the allocation skipped.
		allocNamesForSubnetRemoval := make(map[string]jtypes.AllocationConfig, len(allocs))
		for alloc, allocCfg := range allocs {
			allocName, err := o.allocationIDGenerator.GenerateManifestKey(n, alloc)
			if err != nil {
				log.Errorf("failed to generate manifest key for %s.%s: %v", n, alloc, err)
				continue
			}
			allocNamesForSubnetRemoval[allocName] = allocCfg
		}
		// Stop each allocation concurrently. Each worker sends at most one
		// error, so a buffer of len(allocs) can never block a sender.
		errCh := make(chan error, len(allocs))
		wg := sync.WaitGroup{}
		for alloc := range allocNamesForSubnetRemoval {
			amf, ok := o.manifest.Allocation(alloc)
			if !ok {
				errCh <- fmt.Errorf("allocation %s not found in manifest", alloc)
				continue
			}
			wg.Add(1)
			go func(h actor.Handle, allocID string) {
				defer wg.Done()
				msg, err := actor.Message(
					o.actor.Handle(),
					h,
					fmt.Sprintf(behaviors.AllocationShutdownBehavior.DynamicTemplate, o.manifest.ID),
					AllocationStopRequest{
						AllocationID: allocID,
					},
					actor.WithMessageExpiry(actor.MakeExpiry(AllocationShutdownTimeout)),
				)
				if err != nil {
					log.Errorf("error creating stop message for alloc: %s: %v", allocID, err)
					errCh <- err
					return
				}
				// invoke the stop message
				replyCh, err := o.actor.Invoke(msg)
				if err != nil {
					log.Errorf("error invoking stop message for %s: %v", allocID, err)
					errCh <- err
					return
				}
				// wait for the reply, bounded by the shutdown timeout
				var reply actor.Envelope
				select {
				case reply = <-replyCh:
					defer reply.Discard()
					var resp AllocationStopResponse
					if err := json.Unmarshal(reply.Message, &resp); err != nil {
						log.Errorf("error unmarshalling stop allocation response: %s", err)
						errCh <- err
						return
					}
					if !resp.OK {
						log.Errorf("failed to stop allocation %s", allocID)
						errCh <- fmt.Errorf("failed to stop allocation %s", allocID)
						return
					}
				case <-time.After(AllocationShutdownTimeout):
					log.Errorf("timeout stopping allocation %s", allocID)
					errCh <- fmt.Errorf("timeout stopping allocation %s", allocID)
					return
				}
				log.Infof("allocation %s stopped", allocID)
			}(o.manifest.Nodes[n].Handle, amf.ID)
		}
		wg.Wait()
		close(errCh)
		err := o.removeAllocationsFromSubnet(o.manifest, utils.MapKeysToSlice(allocNamesForSubnetRemoval))
		if err != nil {
			log.Errorf(
				"removeNodeFromManifest: error removing allocations from subnet: %v",
				err)
		}
		// Drop the removed allocations from manifest and config under the
		// orchestrator lock so concurrent readers see a consistent view.
		func(allocs []string) {
			o.lock.Lock()
			defer o.lock.Unlock()
			for _, alloc := range allocs {
				delete(o.manifest.Allocations, alloc)
				allocName := strings.Split(alloc, ".")[1] // TODO: this is a hack to get the allocation name
				delete(o.cfg.V1.Allocations, allocName)
			}
			nodeAllocs := modifiedCfg.V1.Nodes[n].Allocations
			nmf := o.manifest.Nodes[n]
			nmf.Allocations = nodeAllocs
			o.manifest.Nodes[n] = nmf
			ncfg := o.cfg.V1.Nodes[n]
			ncfg.Allocations = nodeAllocs
			o.cfg.V1.Nodes[n] = ncfg
		}(utils.MapKeysToSlice(allocNamesForSubnetRemoval))
		// The previous `if errCh != nil` guard was vacuous (a made channel is
		// never nil); aggregate the collected errors explicitly instead.
		if aggErr := aggregateErrors(errCh); aggErr != nil {
			errs = multierr.Append(errs, aggErr)
		}
	}
	if errs != nil {
		log.Errorf("error removing allocations for nodes: %v", errs)
	}
	return errs
}
// removeNodesAllocationsFromCfg drops the given nodes, together with their
// allocations, from the orchestrator configuration. Access to the config is
// serialized via the orchestrator lock.
func (o *BasicOrchestrator) removeNodesAllocationsFromCfg(nodes []string) {
	o.lock.Lock()
	defer o.lock.Unlock()
	for i := range nodes {
		o.cfg.RemoveNodeAndAllocations(nodes[i])
	}
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package orchestrator
import (
	"errors"
	"fmt"
)
// orderByDependency returns a list of vertices ordered by their dependencies.
// Representing allocations as vertices in a graph and dependencies as edges,
// it performs a topological sort (Kahn's algorithm) and returns an ordered
// list of lists such that: the outer list is ordered by dependency level, and
// each inner list is unordered — its allocations can be executed in parallel.
// It returns an error when a vertex depends on an unknown vertex or when the
// graph contains a cycle.
func orderByDependency(vertices map[string][]string) ([][]string, error) {
	// dependentMap maps a vertex to the vertices that depend on it;
	// inDegree counts how many unmet dependencies each vertex still has.
	dependentMap := make(map[string][]string, len(vertices))
	inDegree := make(map[string]int, len(vertices))
	for vertexName, deps := range vertices {
		inDegree[vertexName] = len(deps)
		for _, dep := range deps {
			if _, ok := vertices[dep]; !ok {
				return nil, fmt.Errorf("service %s depends on non-existent service %s", vertexName, dep)
			}
			dependentMap[dep] = append(dependentMap[dep], vertexName)
		}
	}
	// Seed the queue with vertices that have no dependencies.
	queue := make([]string, 0, len(vertices))
	for vertex, degree := range inDegree {
		if degree == 0 {
			queue = append(queue, vertex)
		}
	}
	result := [][]string{}
	// BFS: each pass drains exactly one "level" — all vertices whose
	// dependencies have already been scheduled in earlier levels.
	for len(queue) > 0 {
		levelSize := len(queue)
		currentLevel := make([]string, 0, levelSize)
		for i := 0; i < levelSize; i++ {
			service := queue[0]
			queue = queue[1:]
			currentLevel = append(currentLevel, service)
			// Releasing this vertex may unblock its dependents.
			for _, dependent := range dependentMap[service] {
				inDegree[dependent]--
				if inDegree[dependent] == 0 {
					queue = append(queue, dependent)
				}
			}
		}
		result = append(result, currentLevel)
	}
	// Any vertex never scheduled must sit on a cycle.
	total := 0
	for _, level := range result {
		total += len(level)
	}
	if total != len(vertices) {
		return nil, errors.New("cycle detected in dependencies")
	}
	return result, nil
}
// aggregateErrors aggregates multiple errors coming from
// a channel until there are no error msgs anymore
func aggregateErrors(errCh chan error) error {
var aggErr error
for err := range errCh {
if aggErr == nil {
aggErr = err
continue
} else if err != nil {
aggErr = fmt.Errorf("%w\n%w", aggErr, err)
}
}
return aggErr
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package resources
import (
"context"
"fmt"
"sync"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/observability"
"gitlab.com/nunet/device-management-service/types"
)
// ManagerRepos holds all the repositories needed for resource management
type ManagerRepos struct {
	// OnboardedResources persists the single onboarded-resources record
	// describing what this machine offers to the network.
	OnboardedResources repositories.GenericEntityRepository[types.OnboardedResources]
	// ResourceAllocation persists one record per resource allocation.
	ResourceAllocation repositories.GenericRepository[types.ResourceAllocation]
}
// DefaultManager implements the ResourceManager interface
// TODO: Add telemetry for the methods https://gitlab.com/nunet/device-management-service/-/issues/535
type DefaultManager struct {
	// repos provides database persistence for onboarded resources and allocations
	repos ManagerRepos
	// store is the in-memory view of allocations, commitments and the
	// cached onboarded resources
	store *store
	// hardware is the machine's hardware manager; required (NewResourceManager
	// rejects nil)
	hardware types.HardwareManager
	// allocationLock is used to synchronize access to the allocation pool during allocation and deallocation
	// it ensures that resource allocation and deallocation are atomic operations
	allocationLock sync.Mutex
	// committedLock is used to synchronize access to the committed resources pool during committing and releasing
	// it ensures that resource committing and releasing are atomic operations
	committedLock sync.Mutex
}
// NewResourceManager returns a DefaultManager backed by the given
// repositories and hardware manager. The hardware manager is mandatory;
// passing nil yields an error.
func NewResourceManager(repos ManagerRepos, hardware types.HardwareManager) (*DefaultManager, error) {
	if hardware == nil {
		return nil, fmt.Errorf("hardware manager cannot be nil")
	}
	manager := &DefaultManager{
		repos:    repos,
		store:    newStore(),
		hardware: hardware,
	}
	return manager, nil
}
var _ types.ResourceManager = (*DefaultManager)(nil)
// CommitResources commits the resources for an allocation: it validates the
// commitment, verifies no commitment already exists for the allocation ID,
// checks the free capacity, and records the commitment in the store.
// Returns a wrapped ErrResourcesAlreadyCommitted on duplicate commitments.
func (d *DefaultManager) CommitResources(ctx context.Context, commitment types.CommittedResources) error {
	if err := commitment.ValidateBasic(); err != nil {
		return fmt.Errorf("validating commitment: %w", err)
	}
	// Take the commit lock before the duplicate check. Previously the check
	// ran outside the lock, so two concurrent commits for the same ID could
	// both pass it and the later one silently overwrote the earlier.
	d.committedLock.Lock()
	defer d.committedLock.Unlock()
	// Check if resources are already committed for the allocation
	var ok bool
	d.store.withCommittedRLock(func() {
		_, ok = d.store.committedResources[commitment.AllocationID]
	})
	if ok {
		return fmt.Errorf("%w: for allocation %s", ErrResourcesAlreadyCommitted, commitment.AllocationID)
	}
	if err := d.checkCapacity(ctx, commitment.Resources); err != nil {
		return fmt.Errorf("checking capacity: %w", err)
	}
	// update the committed resources in the store
	d.store.withCommittedLock(func() {
		d.store.committedResources[commitment.AllocationID] = &types.CommittedResources{
			Resources:    commitment.Resources,
			AllocationID: commitment.AllocationID,
		}
	})
	return nil
}
// UncommitResources releases the committed resources for an allocation.
// It returns a wrapped ErrResourcesNotCommitted when no commitment exists
// for the given allocation ID.
func (d *DefaultManager) UncommitResources(_ context.Context, allocationID string) error {
	d.committedLock.Lock()
	defer d.committedLock.Unlock()
	// Bail out early when there is nothing to release.
	var committed bool
	d.store.withCommittedLock(func() {
		_, committed = d.store.committedResources[allocationID]
	})
	if !committed {
		return fmt.Errorf("%w: for allocation %s", ErrResourcesNotCommitted, allocationID)
	}
	// Drop the commitment from the in-memory store.
	d.store.withCommittedLock(func() {
		delete(d.store.committedResources, allocationID)
	})
	return nil
}
// IsCommitted reports whether resources are currently committed for the
// given allocation ID. The returned error is always nil.
func (d *DefaultManager) IsCommitted(allocationID string) (bool, error) {
	committed := false
	d.store.withCommittedRLock(func() {
		_, committed = d.store.committedResources[allocationID]
	})
	return committed, nil
}
// AllocateResources promotes a previous commitment for allocationID into an
// actual allocation: it persists the allocation in the database, records it
// in the in-memory store, and consumes (removes) the commitment.
func (d *DefaultManager) AllocateResources(ctx context.Context, allocationID string) error {
	d.allocationLock.Lock()
	defer d.allocationLock.Unlock()
	// The resources must have been committed beforehand.
	var (
		committed  bool
		commitment *types.CommittedResources
	)
	d.store.withCommittedRLock(func() {
		commitment, committed = d.store.committedResources[allocationID]
	})
	if !committed {
		return fmt.Errorf("%w: for allocation %s", ErrResourcesNotCommitted, allocationID)
	}
	// Refuse a double allocation for the same ID.
	var alreadyAllocated bool
	d.store.withAllocationsRLock(func() {
		_, alreadyAllocated = d.store.allocations[commitment.AllocationID]
	})
	if alreadyAllocated {
		return fmt.Errorf("%w: for allocation %s", ErrResourcesAlreadyAllocated, commitment.AllocationID)
	}
	record := types.ResourceAllocation{AllocationID: allocationID, Resources: commitment.Resources}
	if err := d.storeAllocation(ctx, record); err != nil {
		return fmt.Errorf("storing allocations in db: %w", err)
	}
	// The commitment is consumed by the allocation.
	d.store.withCommittedLock(func() {
		delete(d.store.committedResources, allocationID)
	})
	return nil
}
// DeallocateResources removes an active allocation: the record is deleted
// from both the database and the in-memory store. It returns a wrapped
// ErrResourcesNotAllocated when no allocation exists for the given ID.
func (d *DefaultManager) DeallocateResources(ctx context.Context, allocationID string) error {
	d.allocationLock.Lock()
	defer d.allocationLock.Unlock()
	// Nothing to do when the allocation is unknown.
	var allocated bool
	d.store.withAllocationsRLock(func() {
		_, allocated = d.store.allocations[allocationID]
	})
	if !allocated {
		return fmt.Errorf("%w: for allocation %s", ErrResourcesNotAllocated, allocationID)
	}
	if err := d.deleteAllocation(ctx, allocationID); err != nil {
		return fmt.Errorf("deleting allocations from db: %w", err)
	}
	return nil
}
// IsAllocated reports whether resources are currently allocated for the
// given allocation ID. The returned error is always nil.
func (d *DefaultManager) IsAllocated(allocationID string) (bool, error) {
	allocated := false
	d.store.withAllocationsRLock(func() {
		_, allocated = d.store.allocations[allocationID]
	})
	return allocated, nil
}
// GetFreeResources returns the free resources in the allocation pool,
// computed as onboarded minus allocated minus committed. Intermediate values
// are emitted as debug metrics.
func (d *DefaultManager) GetFreeResources(ctx context.Context) (types.FreeResources, error) {
	onboarded, err := d.GetOnboardedResources(ctx)
	if err != nil {
		return types.FreeResources{}, fmt.Errorf("getting onboarded resources: %w", err)
	}
	allocated, err := d.GetTotalAllocation()
	if err != nil {
		return types.FreeResources{}, fmt.Errorf("getting total allocations: %w", err)
	}
	// Sum everything committed but not yet promoted to an allocation.
	// Add errors are deliberately ignored here, as in the original code.
	var committed types.Resources
	d.store.withCommittedRLock(func() {
		for _, c := range d.store.committedResources {
			_ = committed.Add(c.Resources)
		}
	})
	log.Debugw("onboarded_resources",
		"labels", string(observability.LabelMetric),
		"resources", onboarded.Resources)
	log.Debugw("total_allocation",
		"labels", string(observability.LabelMetric),
		"allocation", allocated)
	log.Debugw("committed_resources",
		"labels", string(observability.LabelMetric),
		"committed", committed)
	var free types.FreeResources
	free.Resources = onboarded.Resources
	if err := free.Resources.Subtract(allocated); err != nil {
		return types.FreeResources{}, fmt.Errorf("subtracting total allocation: %w", err)
	}
	if err := free.Resources.Subtract(committed); err != nil {
		return types.FreeResources{}, fmt.Errorf("subtracting committed resources: %w", err)
	}
	log.Debugw("free_resources",
		"labels", string(observability.LabelMetric),
		"free", free.Resources)
	return free, nil
}
// GetTotalAllocation returns the sum of resources held by all active
// allocations. Summation stops at the first Add error, which is returned.
func (d *DefaultManager) GetTotalAllocation() (types.Resources, error) {
	var (
		total  types.Resources
		addErr error
	)
	d.store.withAllocationsRLock(func() {
		for _, a := range d.store.allocations {
			if addErr = total.Add(a.Resources); addErr != nil {
				break
			}
		}
	})
	return total, addErr
}
// GetOnboardedResources returns the onboarded resources of the machine,
// serving from the in-memory cache when populated and falling back to the
// repository (populating the cache) otherwise.
func (d *DefaultManager) GetOnboardedResources(ctx context.Context) (types.OnboardedResources, error) {
	// Fast path: cached value.
	var (
		cached    types.OnboardedResources
		haveCache bool
	)
	d.store.withOnboardedRLock(func() {
		if d.store.onboardedResources != nil {
			cached = *d.store.onboardedResources
			haveCache = true
		}
	})
	if haveCache {
		return cached, nil
	}
	// Slow path: load from the repository, then cache.
	fromDB, err := d.repos.OnboardedResources.Get(ctx)
	if err != nil {
		return types.OnboardedResources{}, fmt.Errorf("failed to get onboarded resources: %w", err)
	}
	_ = d.store.withOnboardedLock(func() error {
		d.store.onboardedResources = &fromDB
		return nil
	})
	return fromDB, nil
}
// UpdateOnboardedResources updates the onboarded resources of the machine in the database
// and refreshes the in-memory cache; the whole operation runs under the
// onboarded lock so readers never observe a partial update.
//
// NOTE(review): the persisted value is `resources` AFTER subtracting the
// current total allocation, i.e. the stored record is the requested amount
// minus what is already allocated — confirm this is the intended semantics.
func (d *DefaultManager) UpdateOnboardedResources(ctx context.Context, resources types.Resources) error {
	if err := d.store.withOnboardedLock(func() error {
		// calculate the new free resources based on the allocations
		totalAllocation, err := d.GetTotalAllocation()
		if err != nil {
			return fmt.Errorf("getting total allocations: %w", err)
		}
		// Check if the allocation is too high: Subtract fails (and we bail)
		// when existing allocations exceed the requested resources.
		if err := resources.Subtract(totalAllocation); err != nil {
			return fmt.Errorf("couldn't subtract allocation: %w. Demand too high", err)
		}
		onboardedResources := types.OnboardedResources{Resources: resources}
		_, err = d.repos.OnboardedResources.Save(ctx, onboardedResources)
		if err != nil {
			return fmt.Errorf("failed to update onboarded resources: %w", err)
		}
		// Keep the cache in sync with what was persisted.
		d.store.onboardedResources = &onboardedResources
		return nil
	}); err != nil {
		return err
	}
	return nil
}
// checkCapacity reports whether the requested resources fit in the current
// free pool, returning types.ErrNoFreeResources when they do not.
func (d *DefaultManager) checkCapacity(ctx context.Context, resources types.Resources) error {
	free, err := d.GetFreeResources(ctx)
	if err != nil {
		return fmt.Errorf("getting free resources: %w", err)
	}
	// A failed subtraction means the request exceeds what is available.
	if err := free.Subtract(resources); err != nil {
		return types.ErrNoFreeResources
	}
	return nil
}
// storeAllocation persists the allocation in the database first and, on
// success, mirrors it into the in-memory store.
func (d *DefaultManager) storeAllocation(ctx context.Context, allocation types.ResourceAllocation) error {
	if _, err := d.repos.ResourceAllocation.Create(ctx, allocation); err != nil {
		return fmt.Errorf("storing allocations in db: %w", err)
	}
	d.store.withAllocationsLock(func() {
		d.store.allocations[allocation.AllocationID] = allocation
	})
	return nil
}
// deleteAllocation deletes the allocation identified by allocationID from the
// database and, on success, from the in-memory store.
func (d *DefaultManager) deleteAllocation(ctx context.Context, allocationID string) error {
	query := d.repos.ResourceAllocation.GetQuery()
	query.Conditions = append(query.Conditions, repositories.EQ("AllocationID", allocationID))
	// Use the caller's context: the original passed context.Background() here,
	// which ignored cancellation and deadlines propagated by the caller.
	allocation, err := d.repos.ResourceAllocation.Find(ctx, query)
	if err != nil {
		return fmt.Errorf("finding allocations in db: %w", err)
	}
	if err := d.repos.ResourceAllocation.Delete(ctx, allocation.ID); err != nil {
		return fmt.Errorf("deleting allocations from db: %w", err)
	}
	d.store.withAllocationsLock(func() {
		delete(d.store.allocations, allocationID)
	})
	return nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package resources
import (
"sync"
"gitlab.com/nunet/device-management-service/types"
)
// locks holds the locks for the resource manager
// allocations: lock for the allocations map
// onboarded: lock for the cached onboarded resources
// committed: lock for the committed resources map
type locks struct {
	allocations sync.RWMutex
	onboarded sync.RWMutex
	committed sync.RWMutex
}
// newLocks returns a locks instance with every mutex in its ready-to-use
// zero state.
func newLocks() *locks {
	return new(locks)
}
// store holds the in-memory resource state of the machine
// onboardedResources: cached copy of the onboarded resources (nil until first load)
// committedResources: resources reserved for allocations not yet active, keyed by allocation ID
// allocations: resources held by active allocations, keyed by allocation ID
// locks: the mutexes guarding the fields above
type store struct {
	onboardedResources *types.OnboardedResources
	committedResources map[string]*types.CommittedResources
	allocations map[string]types.ResourceAllocation
	locks *locks
}
// newStore returns a store with empty allocation and commitment maps and a
// fresh set of locks; the onboarded-resources cache starts out unset.
func newStore() *store {
	s := &store{
		allocations:        make(map[string]types.ResourceAllocation),
		committedResources: make(map[string]*types.CommittedResources),
		locks:              newLocks(),
	}
	return s
}
// withAllocationsLock runs f while holding the allocations write lock.
func (s *store) withAllocationsLock(f func()) {
	s.locks.allocations.Lock()
	defer s.locks.allocations.Unlock()
	f()
}
// withOnboardedLock runs f while holding the onboarded write lock and
// returns f's error.
func (s *store) withOnboardedLock(f func() error) error {
	s.locks.onboarded.Lock()
	defer s.locks.onboarded.Unlock()
	return f()
}
// withCommittedLock runs f while holding the committed write lock.
func (s *store) withCommittedLock(f func()) {
	s.locks.committed.Lock()
	defer s.locks.committed.Unlock()
	f()
}
// withCommittedRLock runs f while holding the committed read lock.
func (s *store) withCommittedRLock(f func()) {
	s.locks.committed.RLock()
	defer s.locks.committed.RUnlock()
	f()
}
// withAllocationsRLock runs f while holding the allocations read lock.
func (s *store) withAllocationsRLock(f func()) {
	s.locks.allocations.RLock()
	defer s.locks.allocations.RUnlock()
	f()
}
// withOnboardedRLock runs f while holding the onboarded read lock.
func (s *store) withOnboardedRLock(f func()) {
	s.locks.onboarded.RLock()
	defer s.locks.onboarded.RUnlock()
	f()
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package dockercompose
import (
"context"
"github.com/compose-spec/compose-go/v2/loader"
"github.com/compose-spec/compose-go/v2/types"
)
// Parse takes the content of a docker-compose.yml file and returns a parsed Project object.
// Parsing and validation are delegated to the compose-go loader.
func Parse(content []byte) (*types.Project, error) {
	details := types.ConfigDetails{
		ConfigFiles: []types.ConfigFile{{Content: content}},
	}
	// NOTE: should we add context to the Parse function and pass it here?
	return loader.LoadWithContext(
		context.Background(),
		details,
		func(opts *loader.Options) { opts.SetProjectName("project", true) },
	)
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package dockercompose
import (
"fmt"
"slices"
"strconv"
"time"
composetypes "github.com/compose-spec/compose-go/v2/types"
jobtypes "gitlab.com/nunet/device-management-service/dms/jobs/types"
"gitlab.com/nunet/device-management-service/dms/translator/types"
nunettypes "gitlab.com/nunet/device-management-service/types"
)
// DockerTranslator implements the translator.Translator interface for Docker Compose files.
// It is stateless; the zero value is ready to use.
type DockerTranslator struct{}
// NewDockerComposeTranslator creates a new DockerTranslator.
func NewDockerComposeTranslator() *DockerTranslator {
	return new(DockerTranslator)
}
// Translate converts a Docker Compose file content into a NuNet EnsembleConfig.
// Unsupported compose features never fail the translation; they are reported
// as warnings on the returned Translation instead.
func (t *DockerTranslator) Translate(input []byte) (*types.Translation, error) {
	project, err := Parse(input)
	if err != nil {
		return nil, fmt.Errorf("failed to parse docker compose file: %w", err)
	}
	collector := &WarningCollector{}
	allocs := make(map[string]jobtypes.AllocationConfig)
	nodeCfgs := make(map[string]jobtypes.NodeConfig)
	// Warn about compose features we cannot represent at the top level.
	if len(project.Networks) > 1 {
		collector.Add("", "networks", "multiple networks are not supported; all services will be joined into a single subnet.")
	}
	checkUnsupportedTopLevelFeatures(project, collector)
	// One allocation plus one node per compose service.
	for _, name := range project.ServiceNames() {
		alloc, node := translateService(project, name, collector)
		allocs[name] = alloc
		nodeCfgs[name] = node
	}
	cfg := &jobtypes.EnsembleConfig{
		V1: &jobtypes.EnsembleConfigV1{
			Allocations: allocs,
			Nodes:       nodeCfgs,
			Subnet: jobtypes.SubnetConfig{
				Join: true, // All services join the same subnet by default
			},
			// Default escalation strategy
			EscalationStrategy: jobtypes.EscalationStrategyRedeploy,
		},
	}
	return &types.Translation{
		Config:   cfg,
		Warnings: collector.Get(),
	}, nil
}
// translateService converts a single Docker Compose service into a NuNet Allocation and Node.
// A missing service yields zero-value configs plus a warning rather than an
// error, keeping the translation best-effort.
func translateService(project *composetypes.Project, serviceName string, w *WarningCollector) (jobtypes.AllocationConfig, jobtypes.NodeConfig) {
	service, err := project.GetService(serviceName)
	if err != nil {
		w.Add("", "services", fmt.Sprintf("service '%s' not found", serviceName))
		return jobtypes.AllocationConfig{}, jobtypes.NodeConfig{}
	}
	// Execution parameters forwarded to the Docker executor; only keys the
	// service actually configures are set.
	executionParams := map[string]any{
		"image": service.Image,
	}
	if service.Command != nil {
		executionParams["command"] = service.Command
	}
	if service.Entrypoint != nil {
		executionParams["entrypoint"] = service.Entrypoint
	}
	if len(service.Environment) > 0 {
		executionParams["environment"] = service.Environment
	}
	if service.WorkingDir != "" {
		executionParams["working_directory"] = service.WorkingDir
	}
	// Set cpu and memory limits if specified in the service.
	// Legacy top-level CPUS/MemLimit are folded into deploy.resources.limits
	// when no explicit limits exist. NOTE(review): this mutates the local
	// `service` value — presumably GetService returns a copy, so the project
	// itself is untouched; confirm against compose-go's GetService contract.
	if service.Deploy == nil {
		service.Deploy = &composetypes.DeployConfig{}
	}
	if service.Deploy.Resources.Limits == nil {
		limits := composetypes.Resource{}
		if service.CPUS > 0 {
			limits.NanoCPUs = composetypes.NanoCPUs(service.CPUS)
		}
		if service.MemLimit > 0 {
			limits.MemoryBytes = service.MemLimit
		}
		service.Deploy.Resources.Limits = &limits
	}
	// Build the allocation and node; sub-translators append their own
	// warnings to w as they go.
	alloc := jobtypes.AllocationConfig{
		Executor: jobtypes.ExecutorDocker,
		Type: jobtypes.AllocationTypeService,
		Resources: translateResources(serviceName, service.Deploy, w),
		Execution: nunettypes.SpecConfig{Type: "docker", Params: executionParams},
		DNSName: service.DomainName,
		Volume: translateVolumes(serviceName, service.Volumes, project.Volumes, w),
		HealthCheck: translateHealthCheck(service.HealthCheck, w),
		DependsOn: service.GetDependencies(),
	}
	node := jobtypes.NodeConfig{
		Allocations: []string{service.Name},
		Ports: translatePorts(service.Ports, service.Name, w),
	}
	// The compose restart policy maps onto failure-recovery strategies at
	// both the allocation and the node level.
	if service.Restart != "" {
		alloc.FailureRecovery, node.FailureRecovery = translateServiceRestart(serviceName, service.Restart, w)
	}
	checkForUnsupportedServiceFeatures(service, w)
	return alloc, node
}
// translateServiceRestart maps a compose 'restart' policy onto NuNet's
// allocation- and node-level failure-recovery strategies. Unknown policies
// produce a warning and empty strategies.
func translateServiceRestart(serviceName, restart string, w *WarningCollector) (jobtypes.AllocationFailureRecovery, jobtypes.NodeFailureRecovery) {
	if restart == composetypes.RestartPolicyNo {
		return jobtypes.AllocationFailureRecoveryStayDown, jobtypes.NodeFailureRecoveryStayDown
	}
	switch restart {
	case composetypes.RestartPolicyAlways, composetypes.RestartPolicyUnlessStopped, composetypes.RestartPolicyOnFailure:
		return "", jobtypes.NodeFailureRecoveryRestart
	}
	w.Add(serviceName, "restart", fmt.Sprintf("unknown restart policy '%s'", restart))
	return "", ""
}
// translateResources derives NuNet resource requirements from a compose
// deploy section. Only 'limits' are honored; 'reservations' only trigger a
// warning. A nil deploy or missing limits yields zero-value resources.
func translateResources(serviceName string, deploy *composetypes.DeployConfig, w *WarningCollector) nunettypes.Resources {
	var res nunettypes.Resources
	if deploy == nil || deploy.Resources.Limits == nil {
		return res
	}
	lim := deploy.Resources.Limits
	if lim.NanoCPUs > 0 {
		res.CPU.Cores = lim.NanoCPUs.Value()
	}
	if lim.MemoryBytes > 0 {
		res.RAM.Size = uint64(lim.MemoryBytes)
	}
	if deploy.Resources.Reservations != nil {
		w.Add(serviceName, "deploy.resources.reservations", "resource reservations are not supported and will be ignored. Using limits instead.")
	}
	return res
}
// translateVolumes converts compose volume mounts into NuNet volume configs.
// Only 'bind' and 'volume' mount types are supported; anything else is
// skipped with a warning. All warnings are now attributed to serviceName
// (the original mixed service-level and top-level attribution for warnings
// raised inside the same per-service loop).
func translateVolumes(serviceName string, volumes []composetypes.ServiceVolumeConfig, projectVolumes composetypes.Volumes, w *WarningCollector) []nunettypes.VolumeConfig {
	nunetVolumes := make([]nunettypes.VolumeConfig, 0)
	for _, vol := range volumes {
		v := nunettypes.VolumeConfig{
			Type:             "local",
			MountDestination: vol.Target,
			ReadOnly:         vol.ReadOnly,
		}
		switch vol.Type {
		case "bind":
			v.Src = vol.Source
		case "volume":
			v.Src = vol.Source
			volume, ok := projectVolumes[vol.Source]
			if !ok {
				// Best-effort: warn but keep the mount with vol.Source as-is.
				w.Add(serviceName, "volumes", fmt.Sprintf("volume '%s' not found", vol.Source))
			}
			if volume.Driver != "" && volume.Driver != "local" {
				w.Add(serviceName, "volumes.volume.driver", "only 'local' volumes are supported.")
				continue
			}
			// A 'device' driver option overrides the source path.
			if dev, ok := volume.DriverOpts["device"]; ok && dev != "" {
				v.Src = dev
			}
			if vol.Volume != nil && vol.Volume.NoCopy {
				w.Add(serviceName, "volumes.volume.nocopy", "'nocopy' is not supported and will be ignored.")
			}
		default:
			w.Add(serviceName, fmt.Sprintf("volumes type '%s'", vol.Type), "only 'bind' and 'volume' types are supported.")
			continue
		}
		nunetVolumes = append(nunetVolumes, v)
	}
	return nunetVolumes
}
// translatePorts converts compose port mappings into NuNet port configs.
// Only TCP mappings (or unspecified protocol, which compose defaults to TCP)
// are kept; others are skipped with a warning. A published value that is not
// a single integer (e.g. a range like "8080-8090") previously became 0
// silently; it now also emits a warning.
func translatePorts(ports []composetypes.ServicePortConfig, serviceName string, w *WarningCollector) []jobtypes.PortConfig {
	nunetPorts := make([]jobtypes.PortConfig, 0)
	for _, port := range ports {
		// Filter unsupported protocols before doing any parsing work.
		if port.Protocol != "tcp" && port.Protocol != "" {
			w.Add(serviceName, fmt.Sprintf("ports.protocol: %s", port.Protocol), "only TCP ports are supported. UDP will be ignored.")
			continue
		}
		publishedPort, err := strconv.Atoi(port.Published)
		if err != nil && port.Published != "" {
			// Atoi returns 0 on failure, matching the original fallback, but
			// the user is now told about it instead of a silent port 0.
			w.Add(serviceName, fmt.Sprintf("ports.published: %s", port.Published), "published port must be a single integer; using 0.")
		}
		nunetPorts = append(nunetPorts, jobtypes.PortConfig{
			Public:     publishedPort,
			Private:    int(port.Target),
			Allocation: serviceName,
		})
	}
	return nunetPorts
}
// translateHealthCheck converts a compose healthcheck into a NuNet manifest.
// A nil, disabled, or empty healthcheck — and the explicit "NONE" test type,
// compose's way of disabling an inherited healthcheck — translate to the zero
// manifest. For CMD/CMD-SHELL tests the leading keyword is stripped; a
// keyword with no command is invalid and now yields no healthcheck (the
// original warned but still produced a manifest with an empty command).
// Only 'interval' is supported; other tuning knobs produce warnings.
func translateHealthCheck(hc *composetypes.HealthCheckConfig, w *WarningCollector) nunettypes.HealthCheckManifest {
	if hc == nil || hc.Disable || len(hc.Test) == 0 {
		return nunettypes.HealthCheckManifest{}
	}
	// ["NONE"] explicitly disables a healthcheck inherited from the image.
	if hc.Test[0] == "NONE" {
		return nunettypes.HealthCheckManifest{}
	}
	cmd := hc.Test
	if slices.Contains([]string{"CMD", "CMD-SHELL"}, hc.Test[0]) {
		if len(hc.Test) < 2 {
			w.Add("", "healthcheck.test", "healthcheck test must be a command if test type is CMD or CMD-SHELL.")
			return nunettypes.HealthCheckManifest{}
		}
		cmd = hc.Test[1:]
	}
	manifest := nunettypes.HealthCheckManifest{
		Type: "command",
		Exec: cmd,
	}
	if hc.Interval != nil {
		manifest.Interval = time.Duration(*hc.Interval)
	}
	if hc.Timeout != nil {
		w.Add("", "healthcheck.timeout", "'timeout' is not supported and will be ignored.")
	}
	if hc.Retries != nil {
		w.Add("", "healthcheck.retries", "'retries' is not supported and will be ignored.")
	}
	if hc.StartPeriod != nil {
		w.Add("", "healthcheck.start_period", "'start_period' is not supported and will be ignored.")
	}
	return manifest
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package dockercompose
import (
"fmt"
"sync"
"github.com/compose-spec/compose-go/v2/types"
)
// WarningCollector safely collects warnings for unsupported features during the translation process.
type WarningCollector struct {
	// warnings holds the formatted warning messages in insertion order
	warnings []string
	// mu guards warnings for concurrent Add/Get calls
	mu sync.Mutex
}
// Add records a new warning. It is safe for concurrent use. An empty
// serviceName attributes the warning to the top level of the compose file.
func (wc *WarningCollector) Add(serviceName, feature, reason string) {
	wc.mu.Lock()
	defer wc.mu.Unlock()
	msg := fmt.Sprintf("Top-level feature '%s' was ignored. Reason: %s", feature, reason)
	if serviceName != "" {
		msg = fmt.Sprintf("Service '%s': Unsupported feature '%s' was ignored. Reason: %s", serviceName, feature, reason)
	}
	wc.warnings = append(wc.warnings, msg)
}
// Get returns a snapshot copy of all collected warnings; callers may modify
// the result without affecting the collector.
func (wc *WarningCollector) Get() []string {
	wc.mu.Lock()
	defer wc.mu.Unlock()
	out := make([]string, len(wc.warnings))
	copy(out, wc.warnings)
	return out
}
// checkUnsupportedTopLevelFeatures checks for unsupported features at the
// root of the Compose file and records a warning for each one present.
func checkUnsupportedTopLevelFeatures(p *types.Project, w *WarningCollector) {
	if len(p.Configs) != 0 {
		w.Add("", "configs", "top-level 'configs' are not supported.")
	}
	if len(p.Secrets) != 0 {
		w.Add("", "secrets", "top-level 'secrets' are not supported.")
	}
}
// checkForUnsupportedServiceFeatures checks for service-level features that cannot
// be translated and records a warning for every unsupported feature present on s.
//
// BUG FIX: previously only the first four features (build, cgroup_parent,
// container_name, devices) were actually checked even though reasons for all
// eighteen were declared; the remaining checks trailed off in a
// "... and so on" comment. All declared features are now checked.
func checkForUnsupportedServiceFeatures(s types.ServiceConfig, w *WarningCollector) {
	// Reasons shown to the user for each unsupported Compose feature.
	unsupported := map[string]string{
		"build":             "NuNet requires pre-built Docker images.",
		"cgroup_parent":     "cgroup configuration is managed by the NuNet DMS.",
		"container_name":    "container naming is managed by the NuNet DMS.",
		"devices":           "direct device access is not supported for security reasons.",
		"dns":               "DNS is managed by the NuNet virtual network.",
		"dns_search":        "DNS is managed by the NuNet virtual network.",
		"external_links":    "linking is handled by the NuNet virtual network.",
		"extra_hosts":       "host entries are managed by the NuNet virtual network.",
		"ipc":               "IPC namespace is not configurable.",
		"mac_address":       "network identity is managed by the NuNet DMS.",
		"privileged":        "privileged mode is not supported for security reasons.",
		"restart":           "restart policy is managed by the allocation's 'failure_recovery' strategy.",
		"security_opt":      "security options are managed by the NuNet DMS.",
		"shm_size":          "shared memory size is not configurable.",
		"stop_grace_period": "shutdown behavior is managed by the NuNet DMS.",
		"sysctls":           "kernel parameters are not configurable for security reasons.",
		"tmpfs":             "in-memory filesystems are not supported.",
		"ulimits":           "resource limits are managed by the NuNet DMS.",
	}
	// Table of presence checks, evaluated in a fixed order so warning output is
	// deterministic. Each entry's feature key must exist in `unsupported`.
	checks := []struct {
		feature string
		present bool
	}{
		{"build", s.Build != nil},
		{"cgroup_parent", s.CgroupParent != ""},
		{"container_name", s.ContainerName != ""},
		{"devices", len(s.Devices) > 0},
		{"dns", len(s.DNS) > 0},
		{"dns_search", len(s.DNSSearch) > 0},
		{"external_links", len(s.ExternalLinks) > 0},
		{"extra_hosts", len(s.ExtraHosts) > 0},
		{"ipc", s.Ipc != ""},
		{"mac_address", s.MacAddress != ""},
		{"privileged", s.Privileged},
		{"restart", s.Restart != ""},
		{"security_opt", len(s.SecurityOpt) > 0},
		{"shm_size", s.ShmSize > 0},
		{"stop_grace_period", s.StopGracePeriod != nil},
		{"sysctls", len(s.Sysctls) > 0},
		{"tmpfs", len(s.Tmpfs) > 0},
		{"ulimits", len(s.Ulimits) > 0},
	}
	for _, c := range checks {
		if c.present {
			w.Add(s.Name, c.feature, unsupported[c.feature])
		}
	}
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package translator
import (
"gitlab.com/nunet/device-management-service/dms/translator/dockercompose"
"gitlab.com/nunet/device-management-service/dms/translator/types"
)
// registry is the process-wide translator registry. It is populated exactly
// once in init below and only read afterwards via Translate.
var registry *Registry

// init wires up all built-in translators at package load time so Translate can
// resolve a spec type without any further setup by callers.
func init() {
	registry = &Registry{
		translators: make(map[SpecType]types.Translator),
	}
	registry.RegisterTranslator(SpecTypeDockerCompose, dockercompose.NewDockerComposeTranslator())
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package translator
import (
"sync"
"gitlab.com/nunet/device-management-service/dms/translator/types"
)
// Registry maps spec types to their translator implementations.
// It is safe for concurrent use.
type Registry struct {
	translators map[SpecType]types.Translator
	mu          sync.RWMutex
}

// RegisterTranslator associates t with specType, replacing any previous entry.
func (r *Registry) RegisterTranslator(specType SpecType, t types.Translator) {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.translators[specType] = t
}

// GetTranslator looks up the translator registered for specType.
// The boolean reports whether one was found.
func (r *Registry) GetTranslator(specType SpecType) (types.Translator, bool) {
	r.mu.RLock()
	defer r.mu.RUnlock()
	translator, ok := r.translators[specType]
	return translator, ok
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package translator
import (
"fmt"
"gitlab.com/nunet/device-management-service/dms/translator/types"
)
// SpecType identifies the format of a deployment specification.
type SpecType string

const (
	// SpecTypeDockerCompose denotes a Docker Compose specification.
	SpecTypeDockerCompose SpecType = "docker-compose"
)

// Translate converts the raw input of the given spec type into a Translation
// using the translator registered for that type in the package registry.
func Translate(specType SpecType, input []byte) (*types.Translation, error) {
	t, ok := registry.GetTranslator(specType)
	if !ok {
		return nil, fmt.Errorf("translator for spec type %s not found", specType)
	}
	return t.Translate(input)
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package docker
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"os"
"slices"
"strings"
"sync"
"time"
dockertypes "github.com/docker/docker/api/types"
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/api/types/filters"
"github.com/docker/docker/api/types/image"
"github.com/docker/docker/api/types/network"
"github.com/docker/docker/client"
"github.com/docker/docker/pkg/jsonmessage"
"github.com/docker/docker/pkg/stdcopy"
v1 "github.com/opencontainers/image-spec/specs-go/v1"
"github.com/pkg/errors"
"gitlab.com/nunet/device-management-service/observability"
"gitlab.com/nunet/device-management-service/types"
"go.uber.org/multierr"
)
// ClientInterface abstracts the Docker operations used by the executor so that
// the concrete Client can be swapped out (e.g. mocked) in tests.
type ClientInterface interface {
	// IsInstalled reports whether the Docker daemon is reachable.
	IsInstalled(ctx context.Context) bool
	// CreateContainer creates a container (optionally pulling the image first
	// when pullImage is true) and returns the new container's ID.
	CreateContainer(
		ctx context.Context,
		config *container.Config,
		hostConfig *container.HostConfig,
		networkingConfig *network.NetworkingConfig,
		imagePullOpts image.PullOptions,
		platform *v1.Platform,
		name string,
		pullImage bool,
	) (string, error)
	// InspectContainer returns detailed information about a container.
	InspectContainer(ctx context.Context, id string) (dockertypes.ContainerJSON, error)
	// FollowLogs tails a container's logs as separate stdout/stderr readers.
	FollowLogs(ctx context.Context, id string) (stdout, stderr io.Reader, err error)
	// StartContainer starts a created container.
	StartContainer(ctx context.Context, containerID string) error
	// WaitContainer waits until the container is no longer running, returning
	// result and error channels.
	WaitContainer(
		ctx context.Context,
		containerID string,
	) (<-chan container.WaitResponse, <-chan error)
	// PauseContainer pauses the container's main process without terminating it.
	PauseContainer(ctx context.Context, containerID string) error
	// ResumeContainer resumes a paused container.
	ResumeContainer(ctx context.Context, containerID string) error
	// StopContainer stops a running container using the provided options.
	StopContainer(
		ctx context.Context,
		containerID string,
		options container.StopOptions,
	) error
	// RemoveContainer removes a container (forced, volumes included — see the
	// Client implementation).
	RemoveContainer(ctx context.Context, containerID string) error
	// RemoveObjectsWithLabel removes all containers and networks carrying
	// label=value.
	RemoveObjectsWithLabel(ctx context.Context, label string, value string) error
	// FindContainer returns the ID of the container whose label matches value.
	FindContainer(ctx context.Context, label string, value string) (string, error)
	// GetImage returns the local image summary for imageName.
	GetImage(ctx context.Context, imageName string) (image.Summary, error)
	// PullImage pulls imageName from a registry and returns its digest.
	PullImage(ctx context.Context, imageName string, imagePullOpts image.PullOptions) (string, error)
	// GetOutputStream streams container logs starting at 'since'; 'follow'
	// keeps the stream open for new output.
	GetOutputStream(
		ctx context.Context,
		containerID string,
		since string,
		follow bool,
	) (io.ReadCloser, error)
	// Exec runs cmd inside the container, returning exit code, stdout and stderr.
	Exec(ctx context.Context, containerID string, cmd []string) (int, string, string, error)
	// CopyToContainer copies content into the container filesystem at dstPath.
	CopyToContainer(ctx context.Context, containerID, dstPath string, content io.Reader, options container.CopyToContainerOptions) error
	// ContainerStats returns a one-shot snapshot of container resource usage.
	ContainerStats(ctx context.Context, containerID string) (types.ExecutorStats, error)
}
// Client wraps the Docker client to provide high-level operations on Docker containers and networks.
type Client struct {
	client *client.Client // Underlying Docker SDK client; all operations delegate to it.
}

// Ensure that Client implements the ClientInterface.
var _ ClientInterface = (*Client)(nil)
// NewDockerClient initializes a new Docker client with environment variables
// and API version negotiation.
func NewDockerClient() (*Client, error) {
	log.Debugw("docker_client_init_started",
		"labels", string(observability.LabelDeployment))
	dockerAPI, err := client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation(), client.WithHostFromEnv())
	if err != nil {
		log.Errorw("docker_client_init_failure",
			"labels", string(observability.LabelDeployment),
			"error", err)
		return nil, err
	}
	log.Debugw("docker_client_init_success",
		"labels", string(observability.LabelDeployment))
	return &Client{client: dockerAPI}, nil
}
// IsInstalled reports whether Docker is installed and reachable by pinging the
// Docker daemon.
func (c *Client) IsInstalled(ctx context.Context) bool {
	log.Debugw("docker_client_is_installed_check_started",
		"labels", string(observability.LabelNode))
	if _, err := c.client.Ping(ctx); err != nil {
		log.Errorw("docker_client_is_installed_failure",
			"labels", string(observability.LabelNode),
			"error", err)
		return false
	}
	log.Debugw("docker_client_is_installed_success",
		"labels", string(observability.LabelNode))
	return true
}
// CreateContainer creates a new Docker container with the specified
// configuration, optionally pulling its image first, and returns the new
// container's ID.
func (c *Client) CreateContainer(
	ctx context.Context,
	config *container.Config,
	hostConfig *container.HostConfig,
	networkingConfig *network.NetworkingConfig,
	imagePullOpts image.PullOptions,
	platform *v1.Platform,
	name string,
	pullImage bool,
) (string, error) {
	log.Infow("docker_create_container_started",
		"labels", string(observability.LabelDeployment),
		"image", config.Image,
		"name", name)
	if pullImage {
		if _, err := c.PullImage(ctx, config.Image, imagePullOpts); err != nil {
			log.Errorw("docker_create_container_failure",
				"labels", string(observability.LabelDeployment),
				"error", err)
			return "", err
		}
	}
	resp, err := c.client.ContainerCreate(ctx, config, hostConfig, networkingConfig, platform, name)
	if err != nil {
		log.Errorw("docker_create_container_failure",
			"labels", string(observability.LabelDeployment),
			"error", err)
		return "", err
	}
	log.Infow("docker_create_container_success",
		"labels", string(observability.LabelDeployment),
		"containerID", resp.ID)
	return resp.ID, nil
}
// InspectContainer returns detailed information about a Docker container.
// Thin pass-through to the Docker SDK's ContainerInspect.
func (c *Client) InspectContainer(ctx context.Context, id string) (dockertypes.ContainerJSON, error) {
	return c.client.ContainerInspect(ctx, id)
}
// FollowLogs tails the logs of a specified container, returning separate
// readers for stdout and stderr.
//
// A background goroutine copies the Docker log stream into the two returned
// pipes; when the stream ends the pipes are closed, carrying any copy error so
// readers see it instead of a bare EOF.
func (c *Client) FollowLogs(ctx context.Context, id string) (stdout, stderr io.Reader, err error) {
	cont, err := c.InspectContainer(ctx, id)
	if err != nil {
		log.Errorw("docker_follow_logs_failure", "error", err)
		return nil, nil, errors.Wrap(err, "failed to get container")
	}
	logOptions := container.LogsOptions{
		ShowStdout: true,
		ShowStderr: true,
		Follow:     true,
	}
	logsReader, err := c.client.ContainerLogs(ctx, cont.ID, logOptions)
	if err != nil {
		log.Errorw("docker_follow_logs_failure", "error", err)
		return nil, nil, errors.Wrap(err, "failed to get container logs")
	}
	stdoutR, stdoutW := io.Pipe()
	stderrR, stderrW := io.Pipe()
	go func() {
		defer logsReader.Close()
		// BUG FIX: use a goroutine-local error variable. The original assigned
		// to the function's named return `err` from this goroutine, which is a
		// data race with the caller and has no effect after FollowLogs returns.
		var copyErr error
		if cont.Config.Tty {
			// If TTY is enabled, there is a single stream; everything goes to stdout.
			_, copyErr = io.Copy(stdoutW, logsReader)
		} else {
			// If TTY is not enabled, use stdcopy.StdCopy to demultiplex the streams.
			_, copyErr = stdcopy.StdCopy(stdoutW, stderrW, logsReader)
		}
		if copyErr != nil {
			log.Errorf("failed to copy logs: %v", copyErr)
		}
		// Close the pipes with the copy error (nil means plain EOF) so readers
		// learn why the stream ended.
		stdoutW.CloseWithError(copyErr)
		stderrW.CloseWithError(copyErr)
	}()
	return stdoutR, stderrR, nil
}
// StartContainer starts a specified Docker container with default start
// options.
func (c *Client) StartContainer(ctx context.Context, containerID string) error {
	log.Infow("docker_start_container",
		"labels", string(observability.LabelDeployment),
		"containerID", containerID)
	return c.client.ContainerStart(ctx, containerID, container.StartOptions{})
}
// WaitContainer waits for a container to stop, returning channels for the
// result and errors. The wait condition is "not running", so the channels fire
// once the container's main process has exited.
func (c *Client) WaitContainer(
	ctx context.Context,
	containerID string,
) (<-chan container.WaitResponse, <-chan error) {
	log.Infow("docker_wait_container", "containerID", containerID)
	return c.client.ContainerWait(ctx, containerID, container.WaitConditionNotRunning)
}
// PauseContainer pauses the main process of the given container without
// terminating it. Thin pass-through to the Docker SDK.
func (c *Client) PauseContainer(ctx context.Context, containerID string) error {
	return c.client.ContainerPause(ctx, containerID)
}
// ResumeContainer resumes the process execution within the container by
// unpausing it. Thin pass-through to the Docker SDK.
func (c *Client) ResumeContainer(ctx context.Context, containerID string) error {
	return c.client.ContainerUnpause(ctx, containerID)
}
// StopContainer stops a running Docker container; the stop timeout and signal
// are carried in options.
func (c *Client) StopContainer(
	ctx context.Context,
	containerID string,
	options container.StopOptions,
) error {
	log.Infow("docker_stop_container_started",
		"labels", string(observability.LabelAllocation),
		"containerID", containerID)
	return c.client.ContainerStop(ctx, containerID, options)
}
// RemoveContainer removes a Docker container. Removal is forced (the container
// is killed if still running) and its anonymous volumes are removed with it.
func (c *Client) RemoveContainer(ctx context.Context, containerID string) error {
	return c.client.ContainerRemove(
		ctx,
		containerID,
		container.RemoveOptions{RemoveVolumes: true, Force: true},
	)
}
// removeContainers removes all containers matching the specified filters.
// Removals run concurrently; the returned error aggregates every failure.
func (c *Client) removeContainers(ctx context.Context, filterz filters.Args) error {
	containers, err := c.client.ContainerList(
		ctx,
		container.ListOptions{All: true, Filters: filterz},
	)
	if err != nil {
		log.Errorw("docker_remove_containers_failure",
			"labels", string(observability.LabelDeployment),
			"error", err)
		return err
	}
	var wg sync.WaitGroup
	// Buffered to len(containers) so every goroutine can send without blocking,
	// which lets us close the channel after wg.Wait without a helper goroutine.
	errCh := make(chan error, len(containers))
	for _, cont := range containers {
		cont := cont
		wg.Add(1)
		go func() {
			defer wg.Done()
			errCh <- c.RemoveContainer(ctx, cont.ID)
		}()
	}
	wg.Wait()
	close(errCh)
	var errs error
	for removeErr := range errCh {
		errs = multierr.Append(errs, removeErr)
	}
	if errs != nil {
		log.Errorw("docker_remove_containers_failure",
			"labels", string(observability.LabelDeployment),
			"error", errs)
	} else {
		log.Infow("docker_remove_containers_success",
			"labels", string(observability.LabelDeployment))
	}
	return errs
}
// removeNetworks removes all networks matching the specified filters.
// Removals run concurrently; the returned error aggregates every failure.
func (c *Client) removeNetworks(ctx context.Context, filterz filters.Args) error {
	networks, err := c.client.NetworkList(ctx, network.ListOptions{Filters: filterz})
	if err != nil {
		log.Errorw("docker_remove_networks_failure",
			"labels", string(observability.LabelDeployment),
			"error", err)
		return err
	}
	var wg sync.WaitGroup
	// Buffered to len(networks) so sends never block; see removeContainers.
	errCh := make(chan error, len(networks))
	for _, n := range networks {
		n := n
		wg.Add(1)
		go func() {
			defer wg.Done()
			errCh <- c.client.NetworkRemove(ctx, n.ID)
		}()
	}
	wg.Wait()
	close(errCh)
	var errs error
	for removeErr := range errCh {
		errs = multierr.Append(errs, removeErr)
	}
	if errs != nil {
		log.Errorw("docker_remove_networks_failure",
			"labels", string(observability.LabelDeployment),
			"error", errs)
	} else {
		log.Infow("docker_remove_networks_success",
			"labels", string(observability.LabelDeployment))
	}
	return errs
}
// RemoveObjectsWithLabel removes all Docker containers and networks carrying
// the given label=value pair. Container and network removal are both attempted
// even if one fails; the returned error combines any failures.
func (c *Client) RemoveObjectsWithLabel(ctx context.Context, label string, value string) error {
	filterz := filters.NewArgs(
		filters.Arg("label", fmt.Sprintf("%s=%s", label, value)),
	)
	containerErr := c.removeContainers(ctx, filterz)
	networkErr := c.removeNetworks(ctx, filterz)
	if containerErr != nil || networkErr != nil {
		log.Errorw("docker_remove_objects_with_label_failure",
			"labels", string(observability.LabelDeployment),
			"containerErr", containerErr,
			"networkErr", networkErr)
		return multierr.Combine(containerErr, networkErr)
	}
	// BUG FIX: previously the success event was logged unconditionally, even
	// when one or both removals had just failed. Log it only on success.
	log.Infow("docker_remove_objects_with_label_success",
		"labels", string(observability.LabelDeployment))
	return nil
}
// GetOutputStream streams the logs for a specified container.
// 'since' restricts output to logs produced after that timestamp, and 'follow'
// keeps the stream open as new logs are produced.
// The caller is responsible for closing the returned io.ReadCloser.
func (c *Client) GetOutputStream(
	ctx context.Context,
	containerID string,
	since string,
	follow bool,
) (io.ReadCloser, error) {
	reader, err := c.client.ContainerLogs(ctx, containerID, container.LogsOptions{
		ShowStdout: true,
		ShowStderr: true,
		Follow:     follow,
		Since:      since,
	})
	if err != nil {
		log.Errorw("docker_get_output_stream_failure", "error", err)
		return nil, errors.Wrap(err, "failed to get container logs")
	}
	log.Infow("docker_get_output_stream_success", "containerID", containerID)
	return reader, nil
}
// FindContainer searches all containers (running or not) for one whose label
// matches the given value, returning its ID if found.
func (c *Client) FindContainer(ctx context.Context, label string, value string) (string, error) {
	all, err := c.client.ContainerList(ctx, container.ListOptions{All: true})
	if err != nil {
		log.Errorw("docker_find_container_failure", "error", err)
		return "", err
	}
	for _, candidate := range all {
		if candidate.Labels[label] != value {
			continue
		}
		log.Infow("docker_find_container_success", "containerID", candidate.ID)
		return candidate.ID, nil
	}
	notFound := fmt.Errorf("container (%s=%s) not found", label, value)
	log.Warnw("docker_container_not_found", "error", notFound)
	return "", notFound
}
// HasImage reports whether imageName exists in the local image store.
func (c *Client) HasImage(ctx context.Context, imageName string) bool {
	// Normalize bare image names to the ":latest" tag.
	if !strings.Contains(imageName, ":") {
		imageName += ":latest"
	}
	_, _, err := c.client.ImageInspectWithRaw(ctx, imageName)
	switch {
	case err == nil:
		return true
	case client.IsErrNotFound(err):
		// Missing image is the expected negative case; no warning.
		return false
	default:
		log.Warnf("Failed to inspect image: %v", err)
		return false
	}
}
// GetImage returns the summary of the local Docker image matching imageName.
// Bare names (no tag) are matched against the ":latest" tag. An error is
// returned when listing fails or no local image carries the requested tag.
func (c *Client) GetImage(ctx context.Context, imageName string) (image.Summary, error) {
	images, err := c.client.ImageList(ctx, image.ListOptions{All: true})
	if err != nil {
		return image.Summary{}, err
	}
	// If imageName does not contain a tag, we need to append ":latest" to the image name.
	if !strings.Contains(imageName, ":") {
		imageName = fmt.Sprintf("%s:latest", imageName)
	}
	// FIX: the loop variable was named `image`, shadowing the imported "image"
	// package inside the loop body; use `img` instead.
	for _, img := range images {
		if slices.Contains(img.RepoTags, imageName) {
			return img, nil
		}
	}
	return image.Summary{}, fmt.Errorf("unable to find image %s", imageName)
}
// PullImage pulls a Docker image from a registry and returns its digest.
// Pull progress is mirrored to stdout (via TeeReader) while the JSON progress
// stream is scanned for the final "Digest: ..." status message.
func (c *Client) PullImage(ctx context.Context, imageName string, imagePullOpts image.PullOptions) (string, error) {
	log.Infow("docker_pull_image_started",
		"labels", string(observability.LabelDeployment),
		"image", imageName)
	out, err := c.client.ImagePull(ctx, imageName, imagePullOpts)
	if err != nil {
		log.Errorw("docker_pull_image_failure",
			"labels", string(observability.LabelDeployment),
			"error", err)
		return "", err
	}
	defer out.Close()
	// Every progress message read is also echoed to the process stdout.
	d := json.NewDecoder(io.TeeReader(out, os.Stdout))
	var message jsonmessage.JSONMessage
	var digest string
	for {
		if err := d.Decode(&message); err != nil {
			// io.EOF marks the normal end of the progress stream.
			if err == io.EOF {
				break
			}
			log.Errorw("docker_pull_image_failure",
				"labels", string(observability.LabelDeployment),
				"error", err)
			return "", err
		}
		// Order matters: Aux-only messages are skipped before the error check.
		if message.Aux != nil {
			continue
		}
		if message.Error != nil {
			log.Errorw("docker_pull_image_failure",
				"labels", string(observability.LabelDeployment),
				"error", message.Error.Message)
			return "", errors.New(message.Error.Message)
		}
		// The daemon reports the image digest as a status line "Digest: sha256:...".
		if strings.HasPrefix(message.Status, "Digest") {
			digest = strings.TrimPrefix(message.Status, "Digest: ")
		}
	}
	log.Infow("docker_pull_image_success",
		"labels", string(observability.LabelDeployment),
		"digest", digest)
	return digest, nil
}
// Exec executes a command inside a running container.
// Returns (exit code, stdout, stderr, and an error if the operation fails).
// On transport/setup errors the returned exit code is 1 — a sentinel, not the
// command's real exit status.
func (c *Client) Exec(ctx context.Context, containerID string, cmd []string) (int, string, string, error) {
	log.Infow("docker_container_exec_started", "container ID", containerID)
	// Create the exec instance with stdout/stderr attached so output can be captured.
	idresp, err := c.client.ContainerExecCreate(ctx, containerID, container.ExecOptions{
		Cmd:          cmd,
		AttachStdout: true,
		AttachStderr: true,
	})
	if err != nil {
		log.Errorw("docker_container_exec_failure", "error", err)
		return 1, "", "", err
	}
	// NOTE(review): there is no explicit ContainerExecStart call — the attach
	// with Detach:false is relied on to start the exec process; confirm against
	// the Docker SDK semantics in use.
	hijconn, err := c.client.ContainerExecAttach(ctx, idresp.ID, container.ExecAttachOptions{
		Detach: false,
		Tty:    false,
	})
	if err != nil {
		log.Errorw("docker_container_exec_failure", "error", err)
		return 1, "", "", err
	}
	defer hijconn.Close()
	// Demultiplex the combined stream into stdout and stderr buffers; this
	// blocks until the exec process closes its output.
	var stdout, stderr bytes.Buffer
	n, err := stdcopy.StdCopy(&stdout, &stderr, hijconn.Reader)
	if err != nil {
		log.Errorw("docker_container_exec_failure", "error", err)
		return 1, "", "", err
	}
	// Inspect after the copy completes to obtain the process exit code.
	execInspect, err := c.client.ContainerExecInspect(ctx, idresp.ID)
	if err != nil {
		log.Errorw("docker_container_exec_failure", "error", err)
		return 1, "", "", err
	}
	log.Debugw("docker_container_exec_inspect",
		"exitCode", execInspect.ExitCode,
		"bytesCopied", n,
		"containerID", containerID)
	return execInspect.ExitCode, stdout.String(), stderr.String(), nil
}
// CopyToContainer copies content to a container's file system.
// dstPath is the path in the container where the content will be copied.
func (c *Client) CopyToContainer(ctx context.Context, containerID, dstPath string, content io.Reader, options container.CopyToContainerOptions) error {
	log.Infow("docker_copy_to_container_started",
		"containerID", containerID,
		"dstPath", dstPath)
	if err := c.client.CopyToContainer(ctx, containerID, dstPath, content, options); err != nil {
		log.Errorw("docker_copy_to_container_failure",
			"containerID", containerID,
			"dstPath", dstPath,
			"error", err)
		return err
	}
	log.Infow("docker_copy_to_container_success",
		"containerID", containerID,
		"dstPath", dstPath)
	return nil
}
// ContainerStats retrieves a one-shot (non-streaming) snapshot of resource
// usage for a container and maps it into types.ExecutorStats: CPU and memory
// percentages, aggregated per-interface network counters, and block I/O totals.
func (c *Client) ContainerStats(ctx context.Context, containerID string) (types.ExecutorStats, error) {
	log.Infow("docker_container_stats_started",
		"containerID", containerID)
	// `false` requests a single sample rather than a continuous stream.
	statsResponse, err := c.client.ContainerStats(ctx, containerID, false)
	if err != nil {
		log.Errorw("docker_container_stats_failure",
			"containerID", containerID,
			"error", err)
		return types.ExecutorStats{}, errors.Wrap(err, "failed to get container stats")
	}
	defer statsResponse.Body.Close()
	// Decode the stats JSON from the response body
	var dockerStats container.StatsResponse
	if err := json.NewDecoder(statsResponse.Body).Decode(&dockerStats); err != nil {
		log.Errorw("docker_container_stats_failure",
			"containerID", containerID,
			"error", err)
		return types.ExecutorStats{}, errors.Wrap(err, "failed to decode container stats")
	}
	// Calculate CPU usage percentage from the delta between this sample and
	// the previous one, scaled by the number of online CPUs (so the result can
	// exceed 100% on multi-core hosts).
	var cpuPercent float64
	cpuDelta := float64(dockerStats.CPUStats.CPUUsage.TotalUsage) - float64(dockerStats.PreCPUStats.CPUUsage.TotalUsage)
	systemDelta := float64(dockerStats.CPUStats.SystemUsage) - float64(dockerStats.PreCPUStats.SystemUsage)
	onlineCPUs := float64(dockerStats.CPUStats.OnlineCPUs)
	if onlineCPUs == 0.0 {
		// Fall back to the per-CPU usage list when OnlineCPUs is unreported.
		onlineCPUs = float64(len(dockerStats.CPUStats.CPUUsage.PercpuUsage))
	}
	if systemDelta > 0.0 && cpuDelta > 0.0 {
		cpuPercent = (cpuDelta / systemDelta) * onlineCPUs * 100.0
	}
	// Calculate memory usage percentage; guarded so a zero limit yields 0%.
	var memoryPercent float64
	if dockerStats.MemoryStats.Limit > 0 {
		memoryPercent = (float64(dockerStats.MemoryStats.Usage) / float64(dockerStats.MemoryStats.Limit)) * 100.0
	}
	// Aggregate network statistics from all interfaces
	var rxBytes, rxPackets, rxErrors, rxDropped uint64
	var txBytes, txPackets, txErrors, txDropped uint64
	for _, network := range dockerStats.Networks {
		rxBytes += network.RxBytes
		rxPackets += network.RxPackets
		rxErrors += network.RxErrors
		rxDropped += network.RxDropped
		txBytes += network.TxBytes
		txPackets += network.TxPackets
		txErrors += network.TxErrors
		txDropped += network.TxDropped
	}
	// Aggregate block I/O statistics; only Read/Write ops are counted, other
	// ops (Sync, Async, Total, ...) are ignored.
	var readBytes, writeBytes uint64
	for _, blkio := range dockerStats.BlkioStats.IoServiceBytesRecursive {
		switch blkio.Op {
		case "Read":
			readBytes += blkio.Value
		case "Write":
			writeBytes += blkio.Value
		}
	}
	// Map everything onto the DMS-level stats type.
	result := types.ExecutorStats{}
	result.CPUUsage.TotalUsage = dockerStats.CPUStats.CPUUsage.TotalUsage
	result.CPUUsage.UsageInKernelmode = dockerStats.CPUStats.CPUUsage.UsageInKernelmode
	result.CPUUsage.UsageInUsermode = dockerStats.CPUStats.CPUUsage.UsageInUsermode
	result.CPUUsage.Percent = cpuPercent
	result.Memory.Usage = dockerStats.MemoryStats.Usage
	result.Memory.MaxUsage = dockerStats.MemoryStats.MaxUsage
	result.Memory.Limit = dockerStats.MemoryStats.Limit
	result.Memory.Percent = memoryPercent
	result.Network.RxBytes = rxBytes
	result.Network.RxPackets = rxPackets
	result.Network.RxErrors = rxErrors
	result.Network.RxDropped = rxDropped
	result.Network.TxBytes = txBytes
	result.Network.TxPackets = txPackets
	result.Network.TxErrors = txErrors
	result.Network.TxDropped = txDropped
	result.BlockIO.ReadBytes = readBytes
	result.BlockIO.WriteBytes = writeBytes
	// Timestamp records when the snapshot was taken, in Unix milliseconds.
	result.Timestamp = time.Now().UnixMilli()
	log.Infow("docker_container_stats_success",
		"containerID", containerID,
		"cpuPercent", cpuPercent,
		"memoryPercent", memoryPercent)
	return result, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package docker
import (
"archive/tar"
"bytes"
"context"
"encoding/base64"
"fmt"
"io"
"path/filepath"
"strconv"
"strings"
"sync/atomic"
"time"
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/api/types/image"
"github.com/docker/docker/api/types/mount"
"github.com/docker/go-connections/nat"
"github.com/pkg/errors"
"github.com/spf13/afero"
"gitlab.com/nunet/device-management-service/observability"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils"
)
// ErrNotInstalled is returned when the Docker daemon is not installed or not
// reachable (see NewExecutor).
var ErrNotInstalled = errors.New("docker is not installed")

const (
	// nanoCPUsPerCore converts CPU cores to Docker's NanoCPUs unit.
	nanoCPUsPerCore = 1e9
	// Labels stamped on objects created by this executor — presumably used to
	// find and clean up owned containers/networks (see FindContainer /
	// RemoveObjectsWithLabel); confirm against the handler code.
	labelExecutorName = "nunet-executor"
	labelJobID = "nunet-jobID"
	labelExecutionID = "nunet-executionID"
	// Polling cadence and deadline for output-stream readiness checks.
	outputStreamCheckTickTime = 100 * time.Millisecond
	outputStreamCheckTimeout = 5 * time.Second
	// Polling cadence and deadline while waiting on a container status change.
	statusWaitTickTime = 100 * time.Millisecond
	statusWaitTimeout = 10 * time.Second
	// initScriptsBaseDir is the host directory prefix for staged init scripts.
	initScriptsBaseDir = "/tmp/nunet/init-scripts-"
	// enableTTY is kept false so container stdout/stderr remain separate
	// demultiplexable streams (a TTY merges them; see Client.FollowLogs).
	enableTTY = false
)
// Executor manages the lifecycle of Docker containers for execution requests.
type Executor struct {
	ID string // This executor's own identifier, stamped onto handlers it creates.
	handlers utils.SyncMap[string, *executionHandler] // Maps execution IDs to their handlers.
	client *Client // Docker client for container management.
	fs afero.Afero // Filesystem abstraction handed to each executionHandler (see Start).
}

// Compile-time check that Executor satisfies the types.Executor interface.
var _ types.Executor = (*Executor)(nil)
// NewExecutor initializes a new Executor instance with a Docker client.
// It fails with ErrNotInstalled when the Docker daemon cannot be reached.
func NewExecutor(ctx context.Context, fs afero.Afero, id string) (*Executor, error) {
	dockerClient, err := NewDockerClient()
	if err != nil {
		return nil, err
	}
	if !dockerClient.IsInstalled(ctx) {
		return nil, ErrNotInstalled
	}
	executor := &Executor{
		ID:     id,
		fs:     fs,
		client: dockerClient,
	}
	return executor, nil
}
// GetID returns the executor's identifier.
func (e *Executor) GetID() string {
	return e.ID
}
// Start begins the execution of a request by starting a Docker container.
//
// Start is restart-aware: if a container for this job/execution already exists
// it is adopted rather than recreated. Otherwise, an existing handler means the
// execution was started before and Start fails; only when neither exists is a
// fresh container created. The registered handler then drives the container in
// a background goroutine.
func (e *Executor) Start(ctx context.Context, request *types.ExecutionRequest) error {
	endSpan := observability.StartSpan(ctx, "docker_executor_start")
	defer endSpan()
	log.Infow("docker_executor_start_begin",
		"labels", []string{string(observability.LabelDeployment)},
		"jobID", request.JobID,
		"executionID", request.ExecutionID)
	// It's possible that this is being called due to a restart. We should check if the
	// container is already running.
	containerID, err := e.FindRunningContainer(ctx, request.JobID, request.ExecutionID)
	if err != nil {
		// Unable to find a running container for this execution, we will instead check for a handler, and
		// failing that will create a new container.
		if handler, ok := e.handlers.Get(request.ExecutionID); ok {
			// A handler already exists: the execution was started previously.
			// Distinguish "still running" from "already finished" for the error.
			if handler.active() {
				log.Errorw("docker_executor_start_failure",
					"labels", []string{string(observability.LabelDeployment)},
					"executionID", request.ExecutionID,
					"error", "execution already started")
				return fmt.Errorf("execution is already started")
			}
			log.Errorw("docker_executor_start_failure",
				"labels", []string{string(observability.LabelDeployment)},
				"executionID", request.ExecutionID,
				"error", "execution completed")
			return fmt.Errorf("execution is already completed")
		}
		// Create a new handler for the execution.
		containerID, err = e.newDockerExecutionContainer(ctx, request, enableTTY)
		if err != nil {
			log.Errorw("docker_executor_start_failure",
				"labels", []string{string(observability.LabelDeployment)},
				"executionID", request.ExecutionID,
				"error", err)
			return fmt.Errorf("failed to create new container: %w", err)
		}
	}
	// Handler owns the container lifecycle from here on; waitCh/activeCh are
	// consumed by Wait/doWait and the handler's run loop.
	handler := &executionHandler{
		client: e.client,
		ID: e.ID,
		fs: e.fs,
		executionID: request.ExecutionID,
		containerID: containerID,
		resultsDir: request.ResultsDir,
		persistLogsDuration: request.PersistLogsDuration,
		waitCh: make(chan bool),
		activeCh: make(chan bool),
		running: &atomic.Bool{},
		TTYEnabled: enableTTY,
	}
	// register the handler for this executionID
	e.handlers.Put(request.ExecutionID, handler)
	// run the container.
	go handler.run(ctx)
	return nil
}
// Pause pauses the container backing the given execution.
// It returns an error if the execution is not found.
func (e *Executor) Pause(
	ctx context.Context,
	executionID string,
) error {
	if handler, ok := e.handlers.Get(executionID); ok {
		return handler.pause(ctx)
	}
	return fmt.Errorf("execution (%s) not found", executionID)
}
// Resume resumes the container backing the given execution.
// It returns an error if the execution is not found.
func (e *Executor) Resume(
	ctx context.Context,
	executionID string,
) error {
	if handler, ok := e.handlers.Get(executionID); ok {
		return handler.resume(ctx)
	}
	return fmt.Errorf("execution (%s) not found", executionID)
}
// Wait initiates a wait for the completion of a specific execution using its
// executionID. It returns two channels: one for the result and one for errors.
// If the executionID is unknown, an error is sent immediately on the error
// channel; otherwise a background goroutine (doWait) relays the outcome. This
// lets callers synchronize Wait with Start and block until the execution
// completes or fails.
func (e *Executor) Wait(
	ctx context.Context,
	executionID string,
) (<-chan *types.ExecutionResult, <-chan error) {
	endSpan := observability.StartSpan(ctx, "docker_executor_wait")
	defer endSpan()
	log.Infow("docker_executor_wait_begin", "executionID", executionID)
	// Both channels are buffered so the immediate-error path never blocks.
	resultCh := make(chan *types.ExecutionResult, 1)
	errCh := make(chan error, 1)
	handler, found := e.handlers.Get(executionID)
	if !found {
		log.Errorw("docker_executor_wait_failure", "executionID", executionID, "error", "execution not found")
		errCh <- fmt.Errorf("execution (%s) not found", executionID)
		return resultCh, errCh
	}
	log.Infow("docker_executor_wait_success", "executionID", executionID)
	go e.doWait(ctx, resultCh, errCh, handler)
	return resultCh, errCh
}
// doWait blocks until the handler signals completion (its waitCh closes/fires)
// or the context is cancelled, then forwards the outcome: the result goes to
// out, a cancellation or nil-result condition goes to errCh. Both channels
// are closed before returning so receivers are always released. A nil result
// after completion indicates a flaw in the executor logic and is reported as
// an error.
func (e *Executor) doWait(
	ctx context.Context,
	out chan *types.ExecutionResult,
	errCh chan error,
	handler *executionHandler,
) {
	log.Infof("executionID %s waiting for execution", handler.executionID)
	defer close(errCh)
	defer close(out)

	select {
	case <-handler.waitCh:
		if handler.result == nil {
			errCh <- fmt.Errorf("execution (%s) result is nil", handler.executionID)
			return
		}
		log.Debugf("executionID %s received results from execution ", handler.executionID)
		out <- handler.result
	case <-ctx.Done():
		errCh <- ctx.Err() // Send the cancellation error to the error channel
	}
}
// Cancel attempts to stop the execution identified by executionID by sending
// a stop signal to its container. It errors when the execution is unknown.
func (e *Executor) Cancel(ctx context.Context, executionID string) error {
	if handler, ok := e.handlers.Get(executionID); ok {
		return handler.kill(ctx)
	}
	return fmt.Errorf("failed to cancel execution (%s). execution not found", executionID)
}
// Remove tears down the container identified by executionID, giving it the
// supplied timeout to stop. It errors when the execution is unknown.
func (e *Executor) Remove(executionID string, timeout time.Duration) error {
	if handler, ok := e.handlers.Get(executionID); ok {
		return handler.destroy(timeout)
	}
	return fmt.Errorf("execution (%s) not found", executionID)
}
// GetLogStream provides a stream of output logs for a specific execution.
// Request fields 'Tail' and 'Follow' control whether to skip past logs
// and whether to keep the stream open for new logs, respectively.
// It returns an error if the execution is not found within the polling window.
func (e *Executor) GetLogStream(
	ctx context.Context,
	request types.LogStreamRequest,
) (io.ReadCloser, error) {
	// It's possible we've recorded the execution as running, but have not yet
	// added the handler to the handler map because we're still waiting for the
	// container to start. Poll for a few seconds to see if it appears.
	//
	// chHandler is buffered (capacity 1) so the poller can deliver its single
	// result without blocking even if the timeout has already fired, and
	// chExit is CLOSED rather than sent on so signalling exit can never block.
	// The previous unbuffered send/send design could deadlock this function
	// (and leak the poller goroutine) when the timeout raced a successful
	// lookup: both sides ended up blocked on sends with no receiver.
	chHandler := make(chan *executionHandler, 1)
	chExit := make(chan struct{})
	go func(ch chan<- *executionHandler, exit <-chan struct{}) {
		// Check the handlers every tick and send the handler down the channel
		// if we find it; stop when told to on the exit channel.
		ticker := time.NewTicker(outputStreamCheckTickTime)
		defer ticker.Stop()
		for {
			select {
			case <-ticker.C:
				if h, found := e.handlers.Get(request.ExecutionID); found {
					ch <- h // buffered: never blocks
					return
				}
			case <-exit:
				return
			}
		}
	}(chHandler, chExit)
	// Either we'll find a handler for the execution (which might have finished
	// starting) or we'll time out and return an error.
	select {
	case handler := <-chHandler:
		return handler.outputStream(ctx, request)
	case <-time.After(outputStreamCheckTimeout):
		close(chExit) // stop the poller; closing never blocks
	}
	return nil, fmt.Errorf("execution (%s) not found", request.ExecutionID)
}
// List returns a slice of ExecutionListItem containing information about current executions.
// This implementation currently returns an empty list and should be updated in the future.
func (e *Executor) List() []types.ExecutionListItem {
	// TODO: list dms containers
	// Returning nil is safe for callers: len() and range both work on a nil slice.
	return nil
}
// GetStatus reports the current status of the execution identified by
// executionID, derived from its recorded result or a live container
// inspection. It errors when the execution is unknown or its state cannot
// be mapped to a status.
func (e *Executor) GetStatus(ctx context.Context, executionID string) (types.ExecutionStatus, error) {
	if handler, ok := e.handlers.Get(executionID); ok {
		return handler.status(ctx)
	}
	return "", fmt.Errorf("execution (%s) not found", executionID)
}
// WaitForStatus polls the execution's status every statusWaitTickTime until it
// reaches the requested status, the timeout elapses, or ctx is cancelled.
// When timeout is nil, statusWaitTimeout is used. It returns an error if the
// execution is not found, a status poll fails, or the deadline passes without
// the target status being observed.
func (e *Executor) WaitForStatus(
	ctx context.Context,
	executionID string,
	status types.ExecutionStatus,
	timeout *time.Duration,
) error {
	waitTimeout := statusWaitTimeout
	if timeout != nil {
		waitTimeout = *timeout
	}
	handler, found := e.handlers.Get(executionID)
	if !found {
		return fmt.Errorf("execution (%s) not found", executionID)
	}
	ticker := time.NewTicker(statusWaitTickTime)
	defer ticker.Stop()
	// Arm the deadline timer ONCE, outside the loop. The previous version
	// called time.After inside the select, which created a fresh timer on
	// every iteration — so whenever the ticker fired first (the normal case,
	// since the tick interval is shorter than the timeout), the timeout timer
	// was discarded and re-armed, and the overall deadline could never fire.
	deadline := time.NewTimer(waitTimeout)
	defer deadline.Stop()
	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-deadline.C:
			return fmt.Errorf("execution (%s) did not reach status %s", executionID, status)
		case <-ticker.C:
			s, err := handler.status(ctx)
			if err != nil {
				return err
			}
			if s == status {
				return nil
			}
		}
	}
}
// Run is a convenience wrapper that starts an execution and blocks until it
// completes, combining Start and Wait in a single call. It returns the
// execution result, or an error if starting fails, waiting fails, or the
// context is cancelled first.
func (e *Executor) Run(
	ctx context.Context, request *types.ExecutionRequest,
) (*types.ExecutionResult, error) {
	if err := e.Start(ctx, request); err != nil {
		return nil, err
	}

	resultCh, errorCh := e.Wait(ctx, request.ExecutionID)
	select {
	case result := <-resultCh:
		return result, nil
	case err := <-errorCh:
		return nil, err
	case <-ctx.Done():
		return nil, ctx.Err()
	}
}
// Cleanup removes every Docker resource owned by this executor — containers,
// plus networks and volumes carrying the executor's label — and then deletes
// the temporary directories that were created to hold init scripts.
// Directory-removal failures are logged but do not fail the cleanup.
func (e *Executor) Cleanup(ctx context.Context) error {
	endSpan := observability.StartSpan(ctx, "docker_executor_cleanup")
	defer endSpan()
	log.Infow("docker_executor_cleanup_begin",
		"labels", []string{string(observability.LabelDeployment)},
		"executorID", e.ID)

	if err := e.client.RemoveObjectsWithLabel(ctx, labelExecutorName, e.ID); err != nil {
		log.Errorw("docker_executor_cleanup_failure",
			"labels", []string{string(observability.LabelDeployment)},
			"executorID", e.ID,
			"error", err)
		return fmt.Errorf("failed to remove containers: %w", err)
	}
	log.Infow("docker_executor_cleanup_success",
		"labels", []string{string(observability.LabelDeployment)},
		"executorID", e.ID)

	// Remove all provision scripts used for mounting.
	dirs, err := afero.Glob(e.fs, initScriptsBaseDir+"*")
	if err != nil {
		return fmt.Errorf("failed to find init script directories: %w", err)
	}
	for _, dir := range dirs {
		if rmErr := e.fs.RemoveAll(dir); rmErr != nil {
			log.Warnf("Failed to remove init script directory %s: %v", dir, rmErr)
			continue
		}
		log.Infof("Removed init script directory: %s", dir)
	}
	return nil
}
// Exec runs a command inside the container backing executionID and returns
// the command's exit code, stdout and stderr.
// It errors when no handler exists for the execution.
func (e *Executor) Exec(ctx context.Context, executionID string, command []string) (int, string, string, error) {
	handler, ok := e.handlers.Get(executionID)
	if !ok {
		return 0, "", "", fmt.Errorf("failed to get execution handler for execution=%q", executionID)
	}
	return e.client.Exec(ctx, handler.containerID, command)
}
// Stats reports resource-usage statistics for the container backing
// executionID. It errors when the execution is unknown or the Docker stats
// query fails.
func (e *Executor) Stats(ctx context.Context, executionID string) (*types.ExecutorStats, error) {
	endSpan := observability.StartSpan(ctx, "docker_executor_stats")
	defer endSpan()
	log.Infow("docker_executor_stats_begin",
		"labels", []string{string(observability.LabelDeployment)},
		"executionID", executionID)

	h, ok := e.handlers.Get(executionID)
	if !ok {
		log.Errorw("docker_executor_stats_failure",
			"labels", []string{string(observability.LabelDeployment)},
			"executionID", executionID,
			"error", "execution not found")
		return nil, fmt.Errorf("execution (%s) not found", executionID)
	}

	containerStats, err := e.client.ContainerStats(ctx, h.containerID)
	if err != nil {
		log.Errorw("docker_executor_stats_failure",
			"labels", []string{string(observability.LabelDeployment)},
			"executionID", executionID,
			"error", err)
		return nil, fmt.Errorf("failed to get container stats: %w", err)
	}
	log.Infow("docker_executor_stats_success",
		"labels", []string{string(observability.LabelDeployment)},
		"executionID", executionID)
	return &containerStats, nil
}
// copyKeysToContainer copies the keys from the request to the
// container, respecting each key's destination path when necessary.
// Each key is shipped as a single-entry tar archive, which is the format
// Docker's CopyToContainer API expects. Any failure aborts the whole copy
// and is returned to the caller (no keys are skipped).
func (e *Executor) copyKeysToContainer(ctx context.Context,
	containerID string, keys []types.AllocationKey, user string,
) error {
	if len(keys) == 0 {
		log.Infof("No keys to copy to container %s", containerID)
		return nil
	}
	log.Infof("Starting to copy %d keys to container %s", len(keys), containerID)
	for i, key := range keys {
		if key.File == "" {
			return fmt.Errorf("key %d has no file", i)
		}
		if key.Type == "" {
			return fmt.Errorf("key %d has no type", i)
		}
		var destination string
		var buf bytes.Buffer
		tarF := tar.NewWriter(&buf)
		switch key.Type {
		case types.KeySSH:
			if user == "" || user == "root" {
				destination = "/root"
			} else {
				destination = "/home/" + user // extraction will fail if user does not exist
			}
			if err := tarF.WriteHeader(&tar.Header{
				// implicitly create .ssh dir
				// intentionally not creating the preceding path to not attempt to add
				// a key to a user that does not exist
				Name: ".ssh/authorized_keys",
				Mode: 0o600,
				Size: int64(len(key.File)),
			}); err != nil {
				// NOTE: a header failure aborts the copy; the previous message
				// claimed the key was being "skipped", which was never the case.
				return fmt.Errorf("unable to write tar header for key %d: %w", i, err)
			}
		case types.KeyGPG:
			if key.Dest == "" {
				return fmt.Errorf("destination required for key type %s", types.KeyGPG)
			}
			// write to root and take care of path creation during extraction
			destination = "/"
			if err := tarF.WriteHeader(&tar.Header{
				// implicitly create the preceding path for the final file
				Name: key.Dest,
				Mode: 0o600,
				Size: int64(len(key.File)),
			}); err != nil {
				return fmt.Errorf("unable to write tar header for key %d: %w", i, err)
			}
		default:
			return fmt.Errorf("unknown key type %q", key.Type)
		}
		if _, err := tarF.Write([]byte(key.File)); err != nil {
			return fmt.Errorf("unable to write key %d to tar: %w", i, err)
		}
		if err := tarF.Close(); err != nil {
			return fmt.Errorf("unable to close tar: %w", err)
		}
		if err := e.client.CopyToContainer(
			ctx,
			containerID,
			destination,
			&buf,
			container.CopyToContainerOptions{},
		); err != nil {
			log.Errorf("Failed to copy key to container %s:%s : %v", containerID, key.Dest, err)
			return fmt.Errorf("failed to copy key to container %s:%s : %w", containerID, key.Dest, err)
		}
	}
	log.Infof("Successfully copied keys to container %s", containerID)
	return nil
}
// newDockerExecutionContainer is an internal method called by Start to set up a new Docker container
// for the job execution. It configures the container based on the provided ExecutionRequest.
// This includes decoding engine specifications, setting up environment variables, mounts and resource
// constraints. It then creates the container but does not start it.
// The method returns the created container's ID and an error if any part of the setup fails.
// Key copying is best-effort: a failure there is logged but does not fail creation.
func (e *Executor) newDockerExecutionContainer(
	ctx context.Context,
	params *types.ExecutionRequest,
	tty bool,
) (string, error) {
	dockerArgs, err := DecodeSpec(params.EngineSpec)
	if err != nil {
		return "", fmt.Errorf("failed to decode docker engine spec: %w", err)
	}
	// expose mapped ports (TCP only); ports that fail to parse are silently skipped
	exposes := nat.PortSet{}
	for _, port := range params.PortsToBind {
		p, err := nat.NewPort("tcp", fmt.Sprintf("%d", port.ExecutorPort))
		if err == nil {
			exposes[p] = struct{}{}
		}
	}
	containerConfig := container.Config{
		Image:      dockerArgs.Image,
		Env:        dockerArgs.Environment,
		Entrypoint: dockerArgs.Entrypoint,
		Cmd:        dockerArgs.Cmd,
		Labels:     e.containerLabels(params.JobID, params.ExecutionID),
		WorkingDir: dockerArgs.WorkingDirectory,
		// Needs to be true for applications such as Jupyter or Gradio to work correctly. See issue #459 for details.
		Tty:          tty,
		ExposedPorts: exposes,
		User:         dockerArgs.User,
	}
	mounts, err := makeContainerMounts(params.Inputs, params.Outputs, params.ResultsDir)
	if err != nil {
		return "", fmt.Errorf("failed to create container mounts: %w", err)
	}
	// Write any provision scripts to a temp dir; empty string means no scripts.
	initScriptsDir, err := prepareInitScripts(e.fs, params.ProvisionScripts, params.ExecutionID)
	if err != nil {
		return "", fmt.Errorf("failed to prepare init scripts: %w", err)
	}
	if initScriptsDir != "" {
		oldEntryPoint := containerConfig.Entrypoint
		oldCmd := containerConfig.Cmd
		// Execute init scripts first, then chain into the original entrypoint/cmd.
		// NOTE(review): joining the original entrypoint/cmd with spaces loses any
		// argument quoting — arguments containing spaces would break. Confirm the
		// provisioned images never need such arguments.
		containerConfig.Entrypoint = []string{"/bin/sh", "-c"}
		containerConfig.Cmd = []string{
			fmt.Sprintf("%s/run_provision_scripts.sh && %s %s",
				initScriptsDir,
				strings.Join(oldEntryPoint, " "),
				strings.Join(oldCmd, " ")),
		}
		// Add a mount for the init scripts so they are visible inside the container.
		mounts = append(mounts, mount.Mount{
			Type:   mount.TypeBind,
			Source: initScriptsDir,
			Target: initScriptsDir,
		})
	}
	log.Infof("Adding %d GPUs to request", len(params.Resources.GPUs))
	hostConfig, err := configureHostConfig(params, dockerArgs, mounts)
	if err != nil {
		return "", fmt.Errorf("failed to configure host config: %w", err)
	}
	// Build registry credentials only when both username and password are present.
	imagePullOpts := image.PullOptions{}
	if dockerArgs.RegistryAuth.Username != "" && dockerArgs.RegistryAuth.Password != "" {
		registryAuth := `{"username":"` + dockerArgs.RegistryAuth.Username +
			`","password":"` + dockerArgs.RegistryAuth.Password + `"}`
		imagePullOpts.RegistryAuth = base64.StdEncoding.EncodeToString([]byte(registryAuth))
	}
	hasImage := e.client.HasImage(ctx, dockerArgs.Image)
	executionContainer, err := e.client.CreateContainer(
		ctx,
		&containerConfig,
		&hostConfig,
		nil,
		imagePullOpts,
		nil,
		params.JobID,
		!hasImage, // only pull if we don't have the image
	)
	if err != nil {
		return "", fmt.Errorf("failed to create container: %w", err)
	}
	if len(params.Keys) > 0 {
		// best-effort: key copy failure does not abort container creation
		if err := e.copyKeysToContainer(ctx, executionContainer, params.Keys, dockerArgs.User); err != nil {
			log.Warnf("failed to copy SSH keys to container: %v", err)
		}
	}
	return executionContainer, nil
}
// prepareInitScripts writes the provision scripts into a per-execution temp
// directory and generates a wrapper shell script (run_provision_scripts.sh)
// that executes them all in deterministic order. It returns the directory
// path, or "" when there are no scripts to install.
func prepareInitScripts(fs afero.Afero, scripts map[string][]byte, id string) (string, error) {
	if len(scripts) == 0 {
		return "", nil
	}
	tempDir := initScriptsBaseDir + id
	if err := fs.MkdirAll(tempDir, 0o700); err != nil {
		return "", fmt.Errorf("failed to create init scripts base directory: %w", err)
	}
	scriptNames := make([]string, 0, len(scripts))
	for name, content := range scripts {
		filename := filepath.Join(tempDir, name)
		if err := fs.WriteFile(filename, content, 0o700); err != nil {
			return "", fmt.Errorf("failed to write init script %s: %w", name, err)
		}
		scriptNames = append(scriptNames, filename)
	}
	// Go map iteration order is random, so the previous version ran the
	// scripts in a different order on every invocation. Sort the names so the
	// wrapper is deterministic (lexicographic). A tiny insertion sort keeps
	// this file's import set unchanged for a handful of entries.
	for i := 1; i < len(scriptNames); i++ {
		for j := i; j > 0 && scriptNames[j] < scriptNames[j-1]; j-- {
			scriptNames[j], scriptNames[j-1] = scriptNames[j-1], scriptNames[j]
		}
	}
	// Create a wrapper script to execute all init scripts in order.
	var wrapper strings.Builder
	wrapper.WriteString("#!/bin/sh\n\n")
	for _, script := range scriptNames {
		wrapper.WriteString(fmt.Sprintf("echo 'Executing %s'\n", filepath.Base(script)))
		wrapper.WriteString(script + "\n")
	}
	wrapperPath := filepath.Join(tempDir, "run_provision_scripts.sh")
	if err := fs.WriteFile(wrapperPath, []byte(wrapper.String()), 0o700); err != nil {
		return "", fmt.Errorf("failed to write wrapper script: %w", err)
	}
	return tempDir, nil
}
// configureHostConfig sets up the host configuration for the container based on the
// GPU vendor and resources requested by the execution. It covers GPU device
// exposure, CPU limits, storage mounts, port bindings, DNS settings and the
// restart policy. It returns an error only when a requested port cannot be parsed.
func configureHostConfig(params *types.ExecutionRequest, dockerArgs EngineSpec, mounts []mount.Mount) (container.HostConfig, error) {
	var hostConfig container.HostConfig
	// set the config for the first selected gpu vendor
	// TODO: support multiple GPUs
	if len(params.Resources.GPUs) >= 1 {
		switch params.Resources.GPUs[0].Vendor {
		case types.GPUVendorNvidia:
			// NVIDIA: request the specific device indices through the runtime.
			deviceIDs := make([]string, len(params.Resources.GPUs))
			for i, gpu := range params.Resources.GPUs {
				deviceIDs[i] = fmt.Sprint(gpu.Index)
			}
			hostConfig = container.HostConfig{
				Resources: container.Resources{
					DeviceRequests: []container.DeviceRequest{
						{
							DeviceIDs:    deviceIDs,
							Capabilities: [][]string{{"gpu"}},
						},
					},
				},
			}
		case types.GPUVendorAMDATI:
			// AMD: expose the kernel fusion driver and DRI devices and join the
			// "video" group, as required for ROCm workloads.
			hostConfig = container.HostConfig{
				Binds: []string{
					"/dev/kfd:/dev/kfd",
					"/dev/dri:/dev/dri",
				},
				Resources: container.Resources{
					Devices: []container.DeviceMapping{
						{
							PathOnHost:        "/dev/kfd",
							PathInContainer:   "/dev/kfd",
							CgroupPermissions: "rwm",
						},
						{
							PathOnHost:        "/dev/dri",
							PathInContainer:   "/dev/dri",
							CgroupPermissions: "rwm",
						},
					},
				},
				GroupAdd: []string{"video"},
			}
		case types.GPUVendorIntel:
			// Intel: bind the whole /dev/dri directory, exposing all Intel GPUs
			// to the container. Simpler than resolving per-device PCI paths and
			// preferable when every GPU may be visible.
			hostConfig = container.HostConfig{
				Binds: []string{
					"/dev/dri:/dev/dri",
				},
				Resources: container.Resources{
					Devices: []container.DeviceMapping{
						{
							PathOnHost:        "/dev/dri",
							PathInContainer:   "/dev/dri",
							CgroupPermissions: "rwm",
						},
					},
				},
			}
		}
	}
	// Set the cpu config
	hostConfig.Resources.NanoCPUs = int64(params.Resources.CPU.Cores * nanoCPUsPerCore)
	hostConfig.Resources.CPUCount = int64(params.Resources.CPU.Cores)
	// setup mounts
	hostConfig.Mounts = mounts
	// configure port bindings. append on a missing map key yields a fresh
	// slice, so the previous duplicated "first binding vs. subsequent binding"
	// branches were unnecessary.
	portMaps := make(map[nat.Port][]nat.PortBinding)
	for _, toBind := range params.PortsToBind {
		natPort, err := nat.NewPort("tcp", strconv.Itoa(toBind.ExecutorPort))
		if err != nil {
			return hostConfig, fmt.Errorf("failed to create port: %w", err)
		}
		portMaps[natPort] = append(portMaps[natPort], nat.PortBinding{
			HostIP:   toBind.IP,
			HostPort: fmt.Sprintf("%d", toBind.HostPort),
		})
	}
	hostConfig.PortBindings = portMaps
	hostConfig.Privileged = dockerArgs.Privileged
	// Configure DNS settings
	hostConfig.DNS = []string{params.GatewayIP, "1.1.1.1"}
	hostConfig.DNSSearch = []string{"internal"}
	hostConfig.DNSOptions = []string{
		"ndots:1", // reduce DNS lookups by setting ndots lower
		"timeout:2",
		"attempts:1",
	}
	// set restart policy
	if dockerArgs.RestartPolicy != "" {
		hostConfig.RestartPolicy = container.RestartPolicy{
			Name: container.RestartPolicyMode(dockerArgs.RestartPolicy),
		}
		// if set to 'on-failure', hardcode maximum retry count to 3
		if dockerArgs.RestartPolicy == "on-failure" {
			hostConfig.RestartPolicy.MaximumRetryCount = 3
		}
	}
	return hostConfig, nil
}
// makeContainerMounts creates the mounts for the container based on the input
// and output volumes provided in the execution request. Inputs whose source
// contains "/" are treated as host-path bind mounts; anything else is assumed
// to be a named Docker volume. Outputs are always writable bind mounts.
// It errors when an output has an empty source, or when outputs are requested
// but no results directory is configured.
func makeContainerMounts(
	inputs []*types.StorageVolumeExecutor,
	outputs []*types.StorageVolumeExecutor,
	resultsDir string,
) ([]mount.Mount, error) {
	// the actual mounts we will give to the container
	// these are paths for both input and output data
	mounts := make([]mount.Mount, 0, len(inputs)+len(outputs))
	for _, input := range inputs {
		mnt := mount.Mount{
			Source:   input.Source,
			Target:   input.Target,
			ReadOnly: input.ReadOnly,
		}
		// if the source contains a "/" (is a path) we assume it's a bind mount
		// otherwise we assume it's a named volume
		if strings.Contains(input.Source, "/") {
			mnt.Type = mount.TypeBind
		} else {
			mnt.Type = mount.TypeVolume // perhaps precede with orch peerID?
		}
		mounts = append(mounts, mnt)
	}
	// the results-dir check is loop-invariant; hoist it out of the loop
	if len(outputs) > 0 && resultsDir == "" {
		return nil, fmt.Errorf("results directory is empty")
	}
	for _, output := range outputs {
		if output.Source == "" {
			return nil, fmt.Errorf("output source is empty")
		}
		mounts = append(mounts, mount.Mount{
			Type:   mount.TypeBind,
			Source: output.Source,
			Target: output.Target,
			// this is an output volume so can be written to
			ReadOnly: false,
		})
	}
	return mounts, nil
}
// containerLabels builds the label set attached to a container created for
// the given job/execution, so the container can later be located and
// garbage-collected by label.
func (e *Executor) containerLabels(jobID string, executionID string) map[string]string {
	labels := make(map[string]string, 3)
	labels[labelExecutorName] = e.ID
	labels[labelJobID] = labelJobValue(e.ID, jobID)
	labels[labelExecutionID] = labelExecutionValue(e.ID, jobID, executionID)
	return labels
}
// labelJobValue returns the value for the job label: "<executorID>_<jobID>".
func labelJobValue(executorID string, jobID string) string {
	return executorID + "_" + jobID
}
// labelExecutionValue returns the value for the execution label:
// "<executorID>_<jobID>_<executionID>".
func labelExecutionValue(executorID string, jobID string, executionID string) string {
	return executorID + "_" + jobID + "_" + executionID
}
// FindRunningContainer looks up the container running the given execution by
// its execution label and returns its container ID, or an error when no
// matching container exists.
func (e *Executor) FindRunningContainer(
	ctx context.Context,
	jobID string,
	executionID string,
) (string, error) {
	return e.client.FindContainer(ctx, labelExecutionID, labelExecutionValue(e.ID, jobID, executionID))
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package docker
import (
"context"
"errors"
"fmt"
"io"
"os"
"path/filepath"
"strconv"
"sync/atomic"
"time"
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/api/types/events"
"github.com/docker/docker/api/types/filters"
"github.com/spf13/afero"
"gitlab.com/nunet/device-management-service/observability"
"gitlab.com/nunet/device-management-service/types"
)
// DestroyTimeout is the default grace period a container is given to stop
// before Docker escalates to a forced kill.
var DestroyTimeout = time.Second * 10
// executionHandler manages the lifecycle and execution of a Docker container for a specific job.
type executionHandler struct {
	// provided by the executor
	ID     string      // owning executor's ID; used to build labels for cleanup
	client *Client     // Docker client for container management.
	fs     afero.Afero // filesystem used for the stdout/stderr log files
	// meta data about the task
	jobID       string
	executionID string
	containerID string
	resultsDir  string // Directory to store execution results.
	// synchronization
	activeCh chan bool    // Closed once the container starts running; log streaming waits on it.
	waitCh   chan bool    // Closed when execution completes or fails; waiters unblock on it.
	running  *atomic.Bool // Indicates if the container is currently running.
	// result of the execution; nil until the run finishes
	result *types.ExecutionResult
	// how long log files are kept after the handler is destroyed
	persistLogsDuration time.Duration
	// log files the container's output is copied into
	stdoutFile afero.File
	stderrFile afero.File
	// TTY setting
	TTYEnabled bool // Indicates if TTY is enabled for the container.
}
// active reports whether the handler's container is currently running,
// as tracked by the atomic running flag set in run.
func (h *executionHandler) active() bool {
	return h.running.Load()
}
// run starts the container and drives its full execution lifecycle: log file
// setup, container start, log streaming, waiting for exit, and constructing
// the final result. waitCh is closed on every exit path so waiters are always
// released, and the running flag is now reset on every exit path as well.
func (h *executionHandler) run(ctx context.Context) {
	endSpan := observability.StartSpan(ctx, "docker_execution_handler_run")
	defer endSpan()
	h.running.Store(true)
	// Clear the running flag on ALL exit paths. The previous version only
	// cleared it at the end of the success path, so any early return (log
	// file failure, start failure, wait error, cancellation, OOM, non-zero
	// exit) left the handler marked active forever.
	defer h.running.Store(false)
	defer close(h.waitCh)
	if err := h.prepareLogFiles(); err != nil {
		err = fmt.Errorf("failed to create execution log files: %v", err)
		log.Errorw("docker_execution_handler_run_failure",
			"labels", string(observability.LabelDeployment),
			"error", err)
		h.result = types.NewFailedExecutionResult(err)
		return
	}
	// monitor kill events to see if the container was externally killed
	eventsCh, errEventsCh := h.client.client.Events(ctx, events.ListOptions{
		Filters: filters.NewArgs(
			filters.Arg("container", h.containerID),
			filters.Arg("event", "kill"),
		),
	})
	if err := h.client.StartContainer(ctx, h.containerID); err != nil {
		h.result = types.NewFailedExecutionResult(fmt.Errorf("failed to start container: %v", err))
		log.Errorw("docker_execution_handler_run_failure",
			"labels", string(observability.LabelDeployment),
			"error", err)
		return
	}
	// signal that the container is up; outputStream waiters unblock here
	close(h.activeCh)
	log.Infow("docker_execution_handler_run_success",
		"labels", string(observability.LabelDeployment),
		"executionID", h.executionID)
	var containerError error
	var containerExitStatusCode int64
	// Start streaming logs before waiting for container exit
	stdoutPipe, stderrPipe, logsErr := h.client.FollowLogs(ctx, h.containerID)
	if logsErr != nil {
		followError := fmt.Errorf("failed to follow container logs: %w", logsErr)
		log.Errorw("docker_execution_handler_follow_logs_failure",
			"error", logsErr)
		h.result = &types.ExecutionResult{
			ExitCode: int(containerExitStatusCode),
			ErrorMsg: followError.Error(),
		}
		return
	}
	// Create channels to signal when log copying is done
	stdoutDone := make(chan bool)
	stderrDone := make(chan bool)
	// Start copying stdout in background
	go func() {
		defer close(stdoutDone)
		defer func() {
			if h.stdoutFile != nil {
				err := h.stdoutFile.Close()
				if err != nil {
					log.Warnf("Error closing stdout file: %v", err)
				}
			}
		}()
		if _, err := io.Copy(h.stdoutFile, stdoutPipe); err != nil {
			log.Warnf("Error copying stdout: %v", err)
		}
	}()
	// Start copying stderr in background
	go func() {
		defer close(stderrDone)
		defer func() {
			if h.stderrFile != nil {
				err := h.stderrFile.Close()
				if err != nil {
					log.Warnf("Error closing stderr file: %v", err)
				}
			}
		}()
		if _, err := io.Copy(h.stderrFile, stderrPipe); err != nil {
			log.Warnf("Error copying stderr: %v", err)
		}
	}()
	// Wait for container exit status while logs are being copied
	statusCh, errCh := h.client.WaitContainer(ctx, h.containerID)
	select {
	case <-ctx.Done():
		// Report the actual cancellation cause. The previous code formatted
		// the (empty struct) value received from Done() instead of ctx.Err().
		h.result = types.NewFailedExecutionResult(fmt.Errorf("execution cancelled: %v", ctx.Err()))
		log.Errorw("docker_execution_handler_run_failure_cancelled", "executionID", h.executionID)
		return
	case err := <-errCh:
		log.Errorf("error while waiting for container: %v\n", err)
		h.result = types.NewFailedExecutionResult(
			fmt.Errorf("failed to wait for container: %v", err),
		)
		return
	case exitStatus := <-statusCh:
		containerExitStatusCode = exitStatus.StatusCode
		containerJSON, err := h.client.InspectContainer(ctx, h.containerID)
		if err != nil {
			h.result = &types.ExecutionResult{
				ExitCode: int(containerExitStatusCode),
				ErrorMsg: err.Error(),
			}
			log.Errorw("docker_execution_handler_inspect_container_failure", "error", err)
			return
		}
		if containerJSON.State.OOMKilled {
			containerError = errors.New("container was killed due to OOM")
			h.result = &types.ExecutionResult{
				ExitCode: int(containerExitStatusCode),
				ErrorMsg: containerError.Error(),
			}
			log.Errorw("docker_execution_handler_container_oom_killed",
				"labels", string(observability.LabelAllocation),
				"executionID", h.executionID)
			return
		}
		if exitStatus.Error != nil {
			containerError = errors.New(exitStatus.Error.Message)
			h.result = &types.ExecutionResult{
				ExitCode: int(containerExitStatusCode),
				ErrorMsg: containerError.Error(),
			}
			log.Errorw("docker_execution_handler_container_error", "executionID", h.executionID)
			return
		}
		if containerExitStatusCode != 0 {
			containerError = errors.New("container either killed or application returned non-zero exit code")
			h.result = &types.ExecutionResult{
				ExitCode: int(containerExitStatusCode),
				ErrorMsg: containerError.Error(),
			}
			log.Warnw("docker_execution_handler_non_zero_exit_code", "executionID", h.executionID)
			return
		}
		// Here we're trying to identify containers terminated by external means but
		// were returned, gracefully, with a 0 exit code.
		if containerExitStatusCode == 0 {
			// Check if there were any kill events (non-blocking)
			select {
			case err := <-errEventsCh:
				if err != nil {
					log.Warnw("docker_execution_handler_events_error", "error", err)
				}
			case <-eventsCh:
				h.result = &types.ExecutionResult{
					ExitCode: int(containerExitStatusCode),
					ErrorMsg: "container was killed by external signal",
					Killed:   true,
				}
				log.Infow("docker_execution_handler_container_killed", "executionID", h.executionID)
				return
			default:
			}
		}
	}
	// Initialize the result with the exit status code
	h.result = types.NewExecutionResult(int(containerExitStatusCode))
	// Wait for log copying to complete
	<-stdoutDone
	<-stderrDone
	// Read the complete logs for the result
	if stdout, err := os.ReadFile(filepath.Join(h.resultsDir, "stdout.log")); err == nil {
		h.result.STDOUT = string(stdout)
	} else {
		log.Errorf("failed to read stdout logs file and retrieve logs: %v", err)
	}
	if stderr, err := os.ReadFile(filepath.Join(h.resultsDir, "stderr.log")); err == nil {
		h.result.STDERR = string(stderr)
	} else {
		log.Errorf("failed to read stderr logs file and retrieve logs: %v", err)
	}
	log.Infow("docker_execution_handler_run_logs_success", "executionID", h.executionID)
}
// pause suspends the container's main process without terminating it
// (docker pause); resume undoes it.
func (h *executionHandler) pause(ctx context.Context) error {
	return h.client.PauseContainer(ctx, h.containerID)
}
// resume unpauses a previously paused container (docker unpause).
func (h *executionHandler) resume(ctx context.Context) error {
	return h.client.ResumeContainer(ctx, h.containerID)
}
// kill gracefully stops the container, giving it DestroyTimeout to exit
// before Docker escalates to SIGKILL.
func (h *executionHandler) kill(ctx context.Context) error {
	endSpan := observability.StartSpan(ctx, "docker_execution_handler_kill")
	defer endSpan()
	// container.StopOptions.Timeout is expressed in SECONDS. The previous
	// code used int(DestroyTimeout), which converts the time.Duration's
	// nanosecond count (10s -> 10,000,000,000) into the seconds field,
	// yielding an absurdly long stop timeout instead of 10 seconds.
	timeout := int(DestroyTimeout.Seconds())
	stopOptions := container.StopOptions{
		Timeout: &timeout,
	}
	err := h.client.StopContainer(ctx, h.containerID, stopOptions)
	if err != nil {
		log.Errorw("docker_execution_handler_kill_failure",
			"labels", string(observability.LabelAllocation),
			"error", err,
			"executionID", h.executionID)
		return err
	}
	log.Infow("docker_execution_handler_kill_success",
		"labels", string(observability.LabelAllocation),
		"executionID", h.executionID)
	return nil
}
// destroy cleans up the container and its associated resources, in order:
// stop the container, remove it, delete the init-script directory, remove
// labelled networks/volumes, then schedule log-file deletion. Any step
// failing aborts the remaining steps and returns the error; log files are
// kept for persistLogsDuration before removal.
func (h *executionHandler) destroy(timeout time.Duration) error {
	// fresh context: destroy may be called after the run context is gone
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()
	endSpan := observability.StartSpan(ctx, "docker_execution_handler_destroy")
	defer endSpan()
	// stop the container
	if err := h.kill(ctx); err != nil {
		log.Errorw("docker_execution_handler_destroy_failure",
			"labels", string(observability.LabelDeployment),
			"error", err,
			"executionID", h.executionID)
		return fmt.Errorf("failed to kill container (%s): %w", h.containerID, err)
	}
	if err := h.client.RemoveContainer(ctx, h.containerID); err != nil {
		log.Errorw("docker_execution_handler_destroy_failure",
			"labels", string(observability.LabelDeployment),
			"error", err,
			"executionID", h.executionID)
		return err
	}
	// remove init scripts
	err := os.RemoveAll(initScriptsBaseDir + h.executionID)
	if err != nil {
		return err
	}
	// Remove related objects like networks or volumes created for this execution.
	err = h.client.RemoveObjectsWithLabel(
		ctx,
		labelExecutionID,
		labelExecutionValue(h.ID, h.jobID, h.executionID),
	)
	if err != nil {
		log.Errorw("docker_execution_handler_destroy_failure",
			"labels", string(observability.LabelDeployment),
			"error", err,
			"executionID", h.executionID)
		return err
	}
	// schedule removal of the stdout/stderr log files after the retention period
	h.handleDeletionOfLogFiles(h.persistLogsDuration)
	log.Infow("docker_execution_handler_destroy_success",
		"labels", string(observability.LabelDeployment),
		"executionID", h.executionID)
	return nil
}
// outputStream returns a reader over the container's log output. With
// request.Tail set, only logs from now on are included; otherwise the full
// history is returned. request.Follow keeps the stream open for new logs.
// The call blocks until the container has actually started (activeCh closes)
// or the context is cancelled.
func (h *executionHandler) outputStream(
	ctx context.Context,
	request types.LogStreamRequest,
) (io.ReadCloser, error) {
	endSpan := observability.StartSpan(ctx, "docker_execution_handler_output_stream")
	defer endSpan()

	// Default to the start of UNIX time so all history is included; when
	// tailing, start from the current moment instead.
	since := "1"
	if request.Tail {
		since = strconv.FormatInt(time.Now().Unix(), 10)
	}

	select {
	case <-h.activeCh: // container is up; safe to attach to its logs
	case <-ctx.Done():
		log.Errorw("docker_execution_handler_output_stream_canceled", "executionID", h.executionID)
		return nil, ctx.Err()
	}

	// Gets the underlying reader, and provides data since the value of the `since` timestamp.
	reader, err := h.client.GetOutputStream(ctx, h.containerID, since, request.Follow)
	if err != nil {
		log.Errorw("docker_execution_handler_output_stream_failure",
			"error", err,
			"executionID", h.executionID)
		return nil, err
	}
	log.Infow("docker_execution_handler_output_stream_success", "executionID", h.executionID)
	return reader, nil
}
// status returns the result of the execution.
//
// If a final result has already been recorded, it is mapped directly to a
// success/failure status. Otherwise the live container state is inspected
// and translated to the corresponding ExecutionStatus.
func (h *executionHandler) status(ctx context.Context) (types.ExecutionStatus, error) {
	// A stored result takes precedence over the live container state.
	if h.result != nil {
		if h.result.ExitCode == types.ExecutionStatusCodeSuccess {
			return types.ExecutionStatusSuccess, nil
		}
		return types.ExecutionStatusFailed, nil
	}
	info, err := h.client.InspectContainer(ctx, h.containerID)
	if err != nil {
		// Wrap with %w (was %v) so callers can unwrap the client error
		// with errors.Is/errors.As.
		return types.ExecutionStatusFailed, fmt.Errorf("failed to get container status: %w", err)
	}
	switch info.State.Status {
	case "created":
		return types.ExecutionStatusPending, nil
	case "running":
		return types.ExecutionStatusRunning, nil
	case "paused":
		return types.ExecutionStatusPaused, nil
	case "exited":
		// Exit code 0 is success; anything else is a failure.
		if info.State.ExitCode == 0 {
			return types.ExecutionStatusSuccess, nil
		}
		return types.ExecutionStatusFailed, nil
	case "dead":
		return types.ExecutionStatusFailed, nil
	default:
		return types.ExecutionStatusFailed, fmt.Errorf("unknown container status: %s", info.State.Status)
	}
}
// prepareLogFiles creates the results directory and opens the stdout log
// file (and, when the TTY is disabled, a separate stderr log file) in
// append mode. On stderr-open failure the stdout file is closed again.
func (h *executionHandler) prepareLogFiles() error {
	log.Debug("preparing log files")
	if err := os.MkdirAll(h.resultsDir, 0o755); err != nil {
		return fmt.Errorf("failed to create results directory: %w", err)
	}

	const openFlags = os.O_CREATE | os.O_APPEND | os.O_WRONLY
	stdoutPath := filepath.Join(h.resultsDir, "stdout.log")

	var err error
	h.stdoutFile, err = h.fs.OpenFile(stdoutPath, openFlags, 0o644)
	if err != nil {
		return fmt.Errorf("failed to open stdout file: %w", err)
	}
	log.Debugf("stdout saved to: %s", stdoutPath)

	// A separate stderr file is only opened when the TTY is disabled.
	if !h.TTYEnabled {
		stderrPath := filepath.Join(h.resultsDir, "stderr.log")
		h.stderrFile, err = h.fs.OpenFile(stderrPath, openFlags, 0o644)
		if err != nil {
			// Don't leak the already-open stdout handle.
			h.stdoutFile.Close()
			return fmt.Errorf("failed to open stderr file: %w", err)
		}
		log.Debugf("stderr saved to: %s", stderrPath)
	}
	return nil
}
// handleDeletionOfLogFiles schedules asynchronous removal of the results
// directory (which holds the log files) once the retention period elapses.
func (h *executionHandler) handleDeletionOfLogFiles(after time.Duration) {
	go func() {
		log.Debugf("deleting log files after %s", after)
		<-time.After(after)
		log.Debug("closing log files")
		if err := h.fs.RemoveAll(h.resultsDir); err != nil {
			log.Errorf("failed to remove log files: %v", err)
		}
	}()
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package docker
import (
"encoding/json"
"fmt"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils/validate"
)
// Keys used to store docker engine parameters in a SpecConfig params map.
const (
	EngineKeyImage            = "image"
	EngineKeyEntrypoint       = "entrypoint"
	EngineKeyCmd              = "cmd"
	EngineKeyEnvironment      = "environment"
	EngineKeyWorkingDirectory = "working_directory"
)

// RegistryAuth holds credentials used to authenticate against an image
// registry when pulling private images.
type RegistryAuth struct {
	Username string `json:"username,omitempty" yaml:"username,omitempty"`
	Password string `json:"password,omitempty" yaml:"password,omitempty"`
}

// EngineSpec contains necessary parameters to execute a docker job.
type EngineSpec struct {
	// Image this should be pullable by docker
	Image string `json:"image,omitempty" yaml:"image,omitempty"`
	// Entrypoint optionally override the default entrypoint
	Entrypoint []string `json:"entrypoint,omitempty" yaml:"entrypoint,omitempty"`
	// Cmd specifies the command to run in the container
	Cmd []string `json:"cmd,omitempty" yaml:"cmd,omitempty"`
	// Environment is a slice of env variables to run the container with
	Environment []string `json:"environment,omitempty" yaml:"environment,omitempty"`
	// WorkingDirectory inside the container
	WorkingDirectory string `json:"working_directory,omitempty" yaml:"working_directory,omitempty"`
	// Privileged indicates whether the container should run with --privileged mode
	Privileged bool `json:"privileged,omitempty" yaml:"privileged,omitempty"`
	// User to run the container as
	User string `json:"user,omitempty" yaml:"user,omitempty"`
	// RegistryAuth to authenticate with registries when pulling private images
	RegistryAuth RegistryAuth `json:"registry_auth,omitempty" yaml:"registry_auth,omitempty"`
	// RestartPolicy for the container
	RestartPolicy string `json:"restart_policy,omitempty" yaml:"restart_policy,omitempty"`
}
// Validate checks that the engine spec contains the minimum required
// configuration; currently only a non-blank image is mandatory.
func (c EngineSpec) Validate() error {
	if !validate.IsBlank(c.Image) {
		return nil
	}
	return fmt.Errorf("invalid docker engine params: image cannot be empty")
}
// DecodeSpec decodes a spec config into a docker engine spec.
// It round-trips the params through JSON into an EngineSpec and validates it.
func DecodeSpec(spec *types.SpecConfig) (EngineSpec, error) {
	if !spec.IsType(string(types.ExecutorTypeDocker)) {
		return EngineSpec{}, fmt.Errorf(
			"invalid docker engine type. expected %s, but received: %s",
			types.ExecutorTypeDocker,
			spec.Type,
		)
	}
	inputParams := spec.Params
	if inputParams == nil {
		return EngineSpec{}, fmt.Errorf("invalid docker engine params: params cannot be nil")
	}
	paramBytes, err := json.Marshal(inputParams)
	if err != nil {
		return EngineSpec{}, fmt.Errorf("failed to encode docker engine params: %w", err)
	}
	// Decode into a value, not a **EngineSpec: unmarshalling the JSON
	// literal "null" into a **T leaves the pointer nil, and the original
	// code would then dereference it and panic.
	var dockerSpec EngineSpec
	if err := json.Unmarshal(paramBytes, &dockerSpec); err != nil {
		return EngineSpec{}, fmt.Errorf("failed to decode docker engine params: %w", err)
	}
	return dockerSpec, dockerSpec.Validate()
}
// EngineBuilder is a struct that is used for constructing an EngineSpec
// object specifically for Docker engines using the Builder pattern.
// It wraps a types.SpecConfig that accumulates parameters via WithParam.
type EngineBuilder struct {
	eb *types.SpecConfig // accumulated spec; returned as-is by Build
}
// NewDockerEngineBuilder initializes a builder for a docker engine spec,
// setting the engine type to the docker executor type and pre-populating
// the image parameter.
func NewDockerEngineBuilder(image string) *EngineBuilder {
	spec := types.NewSpecConfig(string(types.ExecutorTypeDocker))
	spec.WithParam(EngineKeyImage, image)
	return &EngineBuilder{eb: spec}
}
// WithEntrypoint sets the Docker engine entrypoint override.
// It returns the builder so calls can be chained.
func (b *EngineBuilder) WithEntrypoint(e ...string) *EngineBuilder {
	b.eb.WithParam(EngineKeyEntrypoint, e)
	return b
}

// WithCmd sets the Docker engine's command.
// It returns the builder so calls can be chained.
func (b *EngineBuilder) WithCmd(c ...string) *EngineBuilder {
	b.eb.WithParam(EngineKeyCmd, c)
	return b
}

// WithEnvironment sets the Docker engine's environment variables.
// It returns the builder so calls can be chained.
func (b *EngineBuilder) WithEnvironment(e ...string) *EngineBuilder {
	b.eb.WithParam(EngineKeyEnvironment, e)
	return b
}

// WithWorkingDirectory sets the Docker engine's working directory.
// It returns the builder so calls can be chained.
func (b *EngineBuilder) WithWorkingDirectory(w string) *EngineBuilder {
	b.eb.WithParam(EngineKeyWorkingDirectory, w)
	return b
}
// Build method constructs the final SpecConfig object by returning the
// spec configuration accumulated by the preceding With* calls.
func (b *EngineBuilder) Build() *types.SpecConfig {
	return b.eb
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package null
import (
	"context"
	"fmt"
	"io"
	"strings"
	"time"

	"gitlab.com/nunet/device-management-service/types"
)
// Executor is a no-op implementation of the Executor interface.
type Executor struct{}

// NewExecutor creates a new no-op Executor; both arguments are ignored.
func NewExecutor(_ context.Context, _ string) (types.Executor, error) {
	return &Executor{}, nil
}

// Compile-time check that *Executor satisfies types.Executor.
var _ types.Executor = (*Executor)(nil)

// GetID returns an empty string; the null executor has no identity.
func (e *Executor) GetID() string {
	return ""
}
// Exec is a no-op; it reports exit code 0 with empty stdout and stderr.
func (e *Executor) Exec(_ context.Context, _ string, _ []string) (int, string, string, error) {
	return 0, "", "", nil
}

// Start does nothing and returns nil.
func (e *Executor) Start(_ context.Context, _ *types.ExecutionRequest) error {
	return nil
}

// Run returns a nil result and nil error.
func (e *Executor) Run(_ context.Context, _ *types.ExecutionRequest) (*types.ExecutionResult, error) {
	return nil, nil
}
// Wait returns channels that are already closed, so receivers unblock
// immediately with zero values.
func (e *Executor) Wait(_ context.Context, _ string) (<-chan *types.ExecutionResult, <-chan error) {
	results := make(chan *types.ExecutionResult)
	errs := make(chan error)
	close(results)
	close(errs)
	return results, errs
}
// Cancel does nothing and returns nil.
func (e *Executor) Cancel(_ context.Context, _ string) error {
	return nil
}

// Remove does nothing and returns nil.
func (e *Executor) Remove(_ string, _ time.Duration) error {
	return nil
}
// GetLogStream returns an empty, already-exhausted log stream and nil error.
//
// The original returned io.NopCloser(nil); any Read on that value would
// dereference a nil io.Reader and panic. Wrapping an empty strings.Reader
// preserves the "no logs" contract (immediate io.EOF) without the panic.
func (e *Executor) GetLogStream(_ context.Context, _ types.LogStreamRequest) (io.ReadCloser, error) {
	return io.NopCloser(strings.NewReader("")), nil
}
// List returns an empty slice of ExecutionListItem.
func (e *Executor) List() []types.ExecutionListItem {
	return []types.ExecutionListItem{}
}

// Cleanup does nothing and returns nil.
func (e *Executor) Cleanup(_ context.Context) error {
	return nil
}

// GetStatus returns an empty ExecutionStatus and no error.
func (e *Executor) GetStatus(_ context.Context, _ string) (types.ExecutionStatus, error) {
	return "", nil
}

// Pause does nothing and returns nil.
func (e *Executor) Pause(_ context.Context, _ string) error {
	return nil
}

// Resume does nothing and returns nil.
func (e *Executor) Resume(_ context.Context, _ string) error {
	return nil
}

// WaitForStatus does nothing and returns nil, regardless of the requested
// status or timeout.
func (e *Executor) WaitForStatus(_ context.Context, _ string, _ types.ExecutionStatus, _ *time.Duration) error {
	return nil
}

// Stats returns nil stats and an error indicating stats are not available
// for the null executor.
func (e *Executor) Stats(_ context.Context, _ string) (*types.ExecutorStats, error) {
	return nil, fmt.Errorf("stats not available for null executor")
}
package provider
import "fmt"
// Factory is a constructor that builds a Provider from config.
type Factory func(cfg map[string]interface{}) (Provider, error)

// FactoryRegistry stores known provider constructors keyed by provider
// type name, together with the gateway DID made available to factories.
type FactoryRegistry struct {
	GatewayDID string
	factories  map[string]Factory
}
// NewProviderFactoryRegistry creates an empty registry bound to the given
// gateway DID.
func NewProviderFactoryRegistry(gatewayDID string) *FactoryRegistry {
	reg := &FactoryRegistry{GatewayDID: gatewayDID}
	reg.factories = make(map[string]Factory)
	return reg
}
// Register adds a new provider factory under the given type name,
// overwriting any factory already registered under that name.
//
// NOTE(review): unlike Registry below, FactoryRegistry has no internal
// locking, so concurrent Register/Create calls would race on the map —
// confirm registration only happens during single-threaded startup.
func (r *FactoryRegistry) Register(name string, factory Factory) {
	r.factories[name] = factory
}
// Create instantiates a provider by invoking the factory registered
// under the given type name.
func (r *FactoryRegistry) Create(name string, cfg map[string]interface{}) (Provider, error) {
	factory, ok := r.factories[name]
	if !ok {
		return nil, fmt.Errorf("no factory registered for provider type %q", name)
	}
	return factory(cfg)
}
// List lists all known provider factory types. Order is unspecified
// (map iteration order).
func (r *FactoryRegistry) List() []string {
	out := make([]string, 0, len(r.factories))
	for typ := range r.factories {
		out = append(out, typ)
	}
	return out
}
package local
import (
"bytes"
"context"
"encoding/json"
"fmt"
"os"
"os/exec"
"strings"
"time"
logging "github.com/ipfs/go-log/v2"
incus "github.com/lxc/incus/client"
"github.com/lxc/incus/shared/api"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/gateway/provider"
"gitlab.com/nunet/device-management-service/internal/config"
"gitlab.com/nunet/device-management-service/types"
)
var log = logging.Logger("gateway/local/incus")

// IncusProvider provisions VMs on the local machine through the Incus
// daemon.
type IncusProvider struct {
	client        incus.InstanceServer
	dmsBinaryPath string // host path of the DMS binary pushed into guests (from DMS_BINARY_PATH)
	gatewayDID    string // DID anchored as a capability root inside provisioned guests
}
// RegisterFactory registers the local Incus provider constructor under
// the "local-incus" type name. The config map is currently unused.
func RegisterFactory(reg *provider.FactoryRegistry) {
	build := func(_ map[string]interface{}) (provider.Provider, error) {
		// pass cfg if needed
		return NewLocalIncusProvider(reg.GatewayDID)
	}
	reg.Register("local-incus", build)
}
// NewLocalIncusProvider creates a new local Incus provider using the local
// Unix socket. The DMS binary path is read from DMS_BINARY_PATH.
func NewLocalIncusProvider(gatewayDID string) (*IncusProvider, error) {
	client, err := incus.ConnectIncusUnix("", nil)
	if err != nil {
		return nil, fmt.Errorf("failed to connect to local Incus: %w", err)
	}
	return &IncusProvider{
		client:        client,
		dmsBinaryPath: os.Getenv("DMS_BINARY_PATH"),
		gatewayDID:    gatewayDID,
	}, nil
}
// Name returns the provider identifier ("local-incus"), matching the key
// used in RegisterFactory.
func (p *IncusProvider) Name() string {
	return "local-incus"
}
// ListPlans returns a few static plans that represent local resource profiles.
//
// NOTE(review): the description advertises 8 GB RAM but MemoryMB is 4, and
// ProvisionServer launches VMs with limits.memory=4GiB — confirm the
// intended values.
func (p *IncusProvider) ListPlans(_ context.Context) ([]provider.Plan, error) {
	plans := []provider.Plan{
		{
			ID:          "plan1",
			Name:        "VM1",
			Description: "vm with 8 gb ram and 4 cpu",
			CPU:         4,
			MemoryMB:    4,
			DiskGB:      10,
			Region:      "local",
			PriceUSD:    10,
		},
	}
	return plans, nil
}
//nolint:unparam
// runCommand executes an external command under ctx, capturing stdout.
// On failure, the command's stderr is folded into the returned error.
func runCommand(ctx context.Context, name string, args ...string) (string, error) {
	log.Infof("runCommand : %s %s", name, strings.Join(args, " "))

	var stdout, stderr bytes.Buffer
	cmd := exec.CommandContext(ctx, name, args...)
	cmd.Stdout = &stdout
	cmd.Stderr = &stderr

	if err := cmd.Run(); err != nil {
		log.Errorf("runCommand started: error: %v %s", err, stderr.String())
		return "", fmt.Errorf("command %q failed: %w; stderr: %s", name, err, stderr.String())
	}
	return stdout.String(), nil
}
// ProvisionServer creates a new Incus instance (container or VM) based on the plan and image.
//
// High-level flow:
//  1. launch a VM from imageAlias (resource limits are hard-coded below),
//  2. poll `incus list` until the guest reports a non-loopback IPv4 address,
//  3. push the DMS binary into the guest,
//  4. install docker + ssh inside the guest, create a DMS key/capability
//     context, anchor the gateway and orchestrator DIDs, start DMS and
//     onboard the node,
//  5. connect the guest to the gateway, orchestrator, and bootstrap peers,
//  6. query the guest's self record for its peer ID and listen address.
//
// NOTE(review): the plan argument is ignored while CPU/memory/disk limits
// are hard-coded — confirm whether plans should drive these values.
// TODO: proper capabilities for orchestrator instead of root anchoring
func (p *IncusProvider) ProvisionServer(ctx context.Context, _ provider.Plan, name string, imageAlias, orchestratorDID string) (*provider.Server, error) {
	// for incus if there was no image just default to the following image
	if imageAlias == "" {
		imageAlias = "ubuntu-22.04-vm"
	}
	res, err := runCommand(ctx, "incus", "launch", imageAlias, name, "--vm",
		"--config", "limits.cpu=4",
		"--config", "limits.memory=4GiB",
		"--device", "root,size=5GiB")
	if err != nil {
		return nil, fmt.Errorf("failed to launch VM: %w; stderr: %s", err, res)
	}
	// Poll `incus list` every 2s (for up to 2 minutes) until the guest
	// reports a non-loopback IPv4 address.
	var ip string
	timeout := time.After(2 * time.Minute)
	tick := time.NewTicker(2 * time.Second)
	defer tick.Stop()
	for {
		select {
		case <-ctx.Done():
			return nil, ctx.Err()
		case <-timeout:
			return nil, fmt.Errorf("timeout waiting for VM %s to get an IP", name)
		case <-tick.C:
			out, err := exec.CommandContext(ctx, "incus", "list", name, "--format", "json").Output()
			if err != nil {
				// Listing can fail transiently while the VM boots; retry.
				continue
			}
			// Minimal projection of the `incus list --format json` output.
			var vmList []struct {
				State struct {
					Network map[string]struct {
						Addresses []struct {
							Family  string `json:"family"`
							Address string `json:"address"`
						} `json:"addresses"`
					} `json:"network"`
				} `json:"state"`
			}
			if err := json.Unmarshal(out, &vmList); err != nil || len(vmList) == 0 {
				continue
			}
			// Take the first IPv4 ("inet") address that is not loopback.
			found := false
			for _, iface := range vmList[0].State.Network {
				for _, addr := range iface.Addresses {
					if addr.Family == "inet" && addr.Address != "127.0.0.1" {
						ip = addr.Address
						found = true
						break
					}
				}
				if found {
					break
				}
			}
			if ip != "" {
				goto done
			}
		}
	}
done:
	server := &provider.Server{
		ID:     name,
		Name:   name,
		Status: "RUNNING",
		IP:     ip,
	}
	// Give the guest agent a moment to settle before pushing files.
	time.Sleep(5 * time.Second)
	res, err = runCommand(ctx, "incus", "file", "push", p.dmsBinaryPath, name+"/home/ubuntu/dms")
	if err != nil {
		return nil, fmt.Errorf("failed to copy file into VM: %w %s", err, res)
	}
	// Install docker + ssh in the guest, create the DMS key and capability
	// context, anchor the gateway and orchestrator DIDs as roots, start
	// DMS in the background, and onboard the node.
	res, err = runCommand(ctx, "incus", "exec", name, "--", "bash", "-c", `
set -eux
# Disable systemd-resolved and set custom DNS
#systemctl stop systemd-resolved
#systemctl disable systemd-resolved
rm -f /etc/resolv.conf
echo -e "nameserver 1.1.1.1\nnameserver 8.8.8.8" | tee /etc/resolv.conf
apt install -y ca-certificates curl
install -m 0755 -d /etc/apt/keyrings
curl -fsSL https://download.docker.com/linux/ubuntu/gpg -o /etc/apt/keyrings/docker.asc
chmod a+r /etc/apt/keyrings/docker.asc
echo "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.asc] https://download.docker.com/linux/ubuntu $(. /etc/os-release && echo "${UBUNTU_CODENAME:-$VERSION_CODENAME}") stable" | tee /etc/apt/sources.list.d/docker.list > /dev/null
apt update
apt install -y docker-ce docker-ce-cli containerd.io docker-buildx-plugin docker-compose-plugin
apt install -y openssh-server
systemctl enable ssh
systemctl start ssh
# Allow password authentication in sshd_config
sed -i 's/^#\?PasswordAuthentication .*/PasswordAuthentication yes/' /etc/ssh/sshd_config
sed -i 's/^#\?PermitRootLogin .*/PermitRootLogin yes/' /etc/ssh/sshd_config
systemctl restart ssh
# Set a root password
echo "root:root" | chpasswd
DMS_PASSPHRASE=pass /home/ubuntu/dms key new dms
setcap cap_net_admin,cap_sys_admin+ep /home/ubuntu/dms
DMS_PASSPHRASE=pass /home/ubuntu/dms cap new dms
DMS_PASSPHRASE=pass /home/ubuntu/dms cap anchor --context dms --root `+p.gatewayDID+`
DMS_PASSPHRASE=pass /home/ubuntu/dms cap anchor --context dms --root `+orchestratorDID+`
/home/ubuntu/dms config set p2p.listen_address '["/ip4/0.0.0.0/tcp/9001", "/ip4/0.0.0.0/udp/9001/quic-v1"]'
GOLOG_LOG_LEVEL=debug,pubsub=error,observability=error DMS_PASSPHRASE=pass /home/ubuntu/dms run --context dms > logfile.log 2>&1 &
sleep 7
DMS_PASSPHRASE=pass /home/ubuntu/dms actor cmd --context dms /dms/node/onboarding/onboard --no-gpu --ram 3 GB --cpu 2 --disk 2GiB
`)
	if err != nil {
		return nil, fmt.Errorf("failed to install dms and requirements: %w %s", err, res)
	}
	// Give DMS inside the guest time to come up before peer commands.
	time.Sleep(10 * time.Second)
	// connect to gateway, orchestrator and bootstrap peers to give identify a head start before the deployment
	gatewayHandle, err := actor.HandleFromDID(p.gatewayDID)
	if err != nil {
		return nil, fmt.Errorf("failed to parse gateway DID handle: %w", err)
	}
	orchHandle, err := actor.HandleFromDID(orchestratorDID)
	if err != nil {
		return nil, fmt.Errorf("failed to parse orchestrator DID handle: %w", err)
	}
	// Connect to gateway
	res, err = runCommand(ctx, "incus", "exec", name, "--", "bash", "-c", `
set -eux
DMS_PASSPHRASE=pass /home/ubuntu/dms actor cmd --context dms /dms/node/peers/connect --address /p2p/`+gatewayHandle.Address.HostID+`
`)
	if err != nil {
		return nil, fmt.Errorf("failed to execute self with json output: %w %s", err, res)
	}
	// connect to orchestrator
	res, err = runCommand(ctx, "incus", "exec", name, "--", "bash", "-c", `
set -eux
DMS_PASSPHRASE=pass /home/ubuntu/dms actor cmd --context dms /dms/node/peers/connect --address /p2p/`+orchHandle.Address.HostID+`
`)
	if err != nil {
		return nil, fmt.Errorf("failed to execute self with json output: %w %s", err, res)
	}
	// connect to bootstrap nodes
	// NOTE(review): addr[31:] assumes every bootstrap peer entry carries a
	// fixed 31-character prefix before the usable multiaddr — confirm the
	// config format; a shorter entry would panic here.
	for _, addr := range config.DefaultConfig.BootstrapPeers {
		res, err = runCommand(ctx, "incus", "exec", name, "--", "bash", "-c", `
set -eux
DMS_PASSPHRASE=pass /home/ubuntu/dms actor cmd --context dms /dms/node/peers/connect --address `+addr[31:])
		if err != nil {
			return nil, fmt.Errorf("failed to execute self with json output: %w %s", err, res)
		}
	}
	// give identify some time to finish obtaining observed addr
	time.Sleep(60 * time.Second)
	// Fetch the guest's self record to learn its peer ID and listen addresses.
	res, err = runCommand(ctx, "incus", "exec", name, "--", "bash", "-c", `
set -eux
DMS_PASSPHRASE=pass /home/ubuntu/dms actor cmd --context dms /dms/node/peers/self
`)
	if err != nil {
		return nil, fmt.Errorf("failed to execute self : %w %s", err, res)
	}
	var self self
	err = json.Unmarshal([]byte(res), &self)
	if err != nil {
		return nil, fmt.Errorf("failed to unmarshal self payload: %w %s", err, res)
	}
	server.PeerID = self.ID
	// ListenAddr may contain several comma-separated addresses; keep the first.
	ips := strings.Split(self.ListenAddr, ",")
	if len(ips) == 0 {
		return nil, fmt.Errorf("failed to get listen addr: %w %s", err, res)
	}
	server.ListenAddr = ips[0]
	return server, nil
}
// self mirrors the JSON payload returned by the /dms/node/peers/self
// actor command, as consumed by ProvisionServer.
type self struct {
	ID         string `json:"id"`
	ListenAddr string `json:"listen_addr"`
}
// DeleteServer removes an Incus instance by name.
func (p *IncusProvider) DeleteServer(_ context.Context, serverID string) error {
	op, err := p.client.DeleteInstance(serverID)
	if err != nil {
		return fmt.Errorf("failed to delete instance: %w", err)
	}
	// Deletion is asynchronous; block until the operation completes.
	return op.Wait()
}
// RestartServer restarts an Incus instance.
func (p *IncusProvider) RestartServer(_ context.Context, serverID string) error {
	op, err := p.client.UpdateInstanceState(serverID, api.InstanceStatePut{
		Action:  "restart",
		Timeout: -1, // -1 means no state-change timeout
	}, "")
	if err != nil {
		return fmt.Errorf("failed to restart instance: %w", err)
	}
	// Restart is asynchronous; block until the operation completes.
	return op.Wait()
}
// GetServerStatus retrieves the instance's current state plus a small
// metadata map (etag and guest PID).
func (p *IncusProvider) GetServerStatus(_ context.Context, serverID string) (*provider.Server, error) {
	inst, etag, err := p.client.GetInstance(serverID)
	if err != nil {
		return nil, fmt.Errorf("failed to get instance: %w", err)
	}
	state, _, err := p.client.GetInstanceState(serverID)
	if err != nil {
		return nil, fmt.Errorf("failed to get instance state: %w", err)
	}
	meta := map[string]string{
		"etag": etag,
		"pid":  fmt.Sprintf("%d", state.Pid),
	}
	return &provider.Server{
		ID:       inst.Name,
		Name:     inst.Name,
		Status:   state.Status,
		Metadata: meta,
	}, nil
}
// SelectMatchingPlan selects a plan matching target resource requirements.
//
// The provider exposes static local plans, so the supplied plans and
// resources are currently not consulted; the first local plan is returned.
// Unlike the original, the ListPlans error is no longer ignored and an
// empty plan list returns an error instead of panicking on plans[0].
func (p *IncusProvider) SelectMatchingPlan(_ []provider.Plan, _ types.Resources) (*provider.Plan, error) {
	plans, err := p.ListPlans(context.Background())
	if err != nil {
		return nil, fmt.Errorf("failed to list plans: %w", err)
	}
	if len(plans) == 0 {
		return nil, fmt.Errorf("no plans available")
	}
	return &plans[0], nil
}
package provider
import (
"fmt"
"sync"
)
// Registry manages a collection of providers by name.
// All accesses are guarded by an RWMutex, so it is safe for concurrent use.
type Registry struct {
	mu        sync.RWMutex
	providers map[string]Provider // keyed by Provider.Name()
}
// NewProviderRegistry builds a registry pre-populated with the given
// providers, each registered under its own Name().
func NewProviderRegistry(initialProviders ...Provider) *Registry {
	reg := &Registry{
		providers: make(map[string]Provider, len(initialProviders)),
	}
	for _, p := range initialProviders {
		reg.Register(p)
	}
	return reg
}
// Register adds a provider to the registry, overwriting any provider
// previously registered under the same name.
func (r *Registry) Register(p Provider) {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.providers[p.Name()] = p
}
// Get retrieves a provider by name, or an error if none is registered.
func (r *Registry) Get(name string) (Provider, error) {
	r.mu.RLock()
	defer r.mu.RUnlock()
	if p, ok := r.providers[name]; ok {
		return p, nil
	}
	return nil, fmt.Errorf("provider %q not found", name)
}
// List returns the names of all registered providers. Order is
// unspecified (map iteration order).
func (r *Registry) List() []string {
	r.mu.RLock()
	defer r.mu.RUnlock()
	out := make([]string, 0, len(r.providers))
	for name := range r.providers {
		out = append(out, name)
	}
	return out
}
// All returns all registered providers. Order is unspecified (map
// iteration order).
func (r *Registry) All() []Provider {
	r.mu.RLock()
	defer r.mu.RUnlock()
	out := make([]Provider, 0, len(r.providers))
	for _, p := range r.providers {
		out = append(out, p)
	}
	return out
}
package provider
import (
"errors"
"gitlab.com/nunet/device-management-service/types"
)
// Plan represents a VM or server plan offered by a provider.
type Plan struct {
	ID          string
	Name        string
	Description string
	CPU         int // number of CPU cores
	MemoryMB    int // RAM in mebibytes (converted via ConvertMibToBytes)
	DiskGB      int // disk size in gigabytes
	GPUCount    int
	GPUModel    string
	GPUVRAMGB   int // VRAM per GPU in gigabytes
	Region      string
	PriceUSD    float64
}

// Server represents a provisioned server or VM.
type Server struct {
	ID         string
	Name       string
	IP         string
	PlanID     string
	Status     string // e.g., "running", "stopped", "deleted"
	Region     string
	Created    int64 // Unix timestamp
	Metadata   map[string]string
	PeerID     string // peer ID reported by the DMS instance on the server
	ListenAddr string // first listen address reported by the DMS instance
}
// SelectMatchingPlan is a helper that returns the first plan whose
// resources compare as Better than or Equal to the target requirements.
// Plans whose resources cannot be compared are skipped.
func SelectMatchingPlan(plans []Plan, target types.Resources) (*Plan, error) {
	for _, candidate := range plans {
		comparison, err := convertPlanToResources(candidate).Compare(target)
		if err != nil {
			// Incomparable plans are silently skipped.
			continue
		}
		if comparison == types.Better || comparison == types.Equal {
			match := candidate
			return &match, nil
		}
	}
	return nil, errors.New("can't match hardware requirements")
}
// convertPlanToResources maps a Plan's advertised capacity onto the
// generic types.Resources representation. GPUs stays nil when the plan
// advertises no GPUs.
func convertPlanToResources(plan Plan) types.Resources {
	var gpus types.GPUs
	for i := 0; i < plan.GPUCount; i++ {
		gpus = append(gpus, types.GPU{
			Index:  i,
			Vendor: types.ParseGPUVendor(plan.GPUModel),
			Model:  plan.GPUModel,
			VRAM:   types.ConvertGBToBytes(uint64(plan.GPUVRAMGB)),
		})
	}
	return types.Resources{
		CPU:  types.CPU{Cores: float32(plan.CPU)},
		RAM:  types.RAM{Size: types.ConvertMibToBytes(uint64(plan.MemoryMB))},
		Disk: types.Disk{Size: types.ConvertGBToBytes(uint64(plan.DiskGB))},
		GPUs: gpus,
	}
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package store
import (
"encoding/json"
"errors"
"fmt"
"time"
"github.com/ostafen/clover/v2"
"github.com/ostafen/clover/v2/document"
"github.com/ostafen/clover/v2/query"
"gitlab.com/nunet/device-management-service/gateway/provider"
)
// provisionedResourcesCollection is the clover collection used by Store.
const provisionedResourcesCollection = "provisioned_resources"

// ProvisionedResources is one record of a server/VM provisioned through
// a gateway provider.
type ProvisionedResources struct {
	ProviderName        string          // name of the provider that provisioned the resource
	Orchestrator        string          // orchestrator the resource was provisioned for
	ProvisionedVMPeerID string          // peer ID of the DMS instance inside the provisioned VM
	Resource            provider.Server // provider-level server description
	CreatedAt           time.Time
}

// Store persists provisioned resources in a clover database.
type Store struct {
	db *clover.DB
}
// New returns a Store backed by the given clover database.
func New(db *clover.DB) (*Store, error) {
	if db == nil {
		return nil, errors.New("db is nil")
	}
	return &Store{db: db}, nil
}
// Insert persists a provisioned-resources record.
//
// The record is stored twice in the document: as structured fields (via
// NewDocumentOf) and as a raw JSON blob under "data", which is what All()
// reads back.
func (s *Store) Insert(r *ProvisionedResources) error {
	if r == nil {
		return errors.New("resources data is nil")
	}
	bts, err := json.Marshal(r)
	if err != nil {
		// was: "failed to marshal contract" — this store holds
		// provisioned resources, not contracts.
		return fmt.Errorf("failed to marshal provisioned resources: %w", err)
	}
	// Insert a new document keyed by the resource ID.
	doc := document.NewDocumentOf(r)
	doc.Set("id", r.Resource.ID)
	doc.Set("created_at", time.Now().UnixNano())
	doc.Set("data", bts)
	return s.db.Insert(provisionedResourcesCollection, doc)
}
// All retrieves all provisioned resources from the database by decoding
// the raw JSON blob stored under each document's "data" field.
func (s *Store) All() ([]*ProvisionedResources, error) {
	q := query.NewQuery(provisionedResourcesCollection)
	docs, err := s.db.FindAll(q)
	if err != nil {
		return nil, fmt.Errorf("failed to retrieve all provisioned resources: %w", err)
	}
	all := make([]*ProvisionedResources, 0, len(docs))
	for _, doc := range docs {
		// Checked assertion: the original `data.([]byte)` would panic on
		// a document whose "data" field has an unexpected type.
		raw, ok := doc.Get("data").([]byte)
		if !ok {
			return nil, fmt.Errorf("unexpected type %T for data field", doc.Get("data"))
		}
		var res ProvisionedResources
		if err := json.Unmarshal(raw, &res); err != nil {
			return nil, fmt.Errorf("failed to unmarshal single resource: %w", err)
		}
		all = append(all, &res)
	}
	return all, nil
}
// Delete removes a provisioned resource by its Resource.ID.
func (s *Store) Delete(resourceID string) error {
	if resourceID == "" {
		return errors.New("resourceID is empty")
	}
	match := query.NewQuery(provisionedResourcesCollection).
		Where(query.Field("id").Eq(resourceID))
	if err := s.db.Delete(match); err != nil {
		return fmt.Errorf("failed to delete resource with id %s: %w", resourceID, err)
	}
	return nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package backgroundtasks
import (
"slices"
"sort"
"sync"
"time"
)
// Scheduler orchestrates the execution of tasks based on their triggers
// and priority. Concurrency is bounded by the taskSem semaphore.
type Scheduler struct {
	tasks        []*Task        // List of tasks; guarded by mu.
	pollInterval time.Duration  // Interval for periodic checks of task triggers.
	stopChan     chan struct{}  // Channel to signal stopping the scheduler.
	taskSem      chan struct{}  // Semaphore to limit the number of running tasks.
	nextTaskID   int            // Counter for assigning unique IDs to tasks; guarded by mu.
	mu           sync.Mutex     // Mutex protecting tasks, nextTaskID, and execution histories.
	taskWg       sync.WaitGroup // Wait group to wait for all tasks to finish.
}
// NewScheduler creates a new Scheduler with a specified limit on
// concurrently running tasks.
//
// Non-positive arguments fall back to safe defaults: pollInterval becomes
// one second and maxRunningTasks becomes one. (make(chan, n) panics for
// n < 0, and a zero-capacity semaphore would never admit a task because
// the slot is only released after a task finishes.)
func NewScheduler(maxRunningTasks int, pollInterval time.Duration) *Scheduler {
	if pollInterval <= 0 {
		pollInterval = 1 * time.Second
	}
	if maxRunningTasks <= 0 {
		// Mirrors the pollInterval guard above.
		maxRunningTasks = 1
	}
	return &Scheduler{
		tasks:        make([]*Task, 0),
		taskSem:      make(chan struct{}, maxRunningTasks),
		stopChan:     make(chan struct{}),
		pollInterval: pollInterval,
		nextTaskID:   0,
	}
}
// AddTask registers a task with the scheduler: its triggers are reset,
// the task is enabled, and it receives a unique ID. The same task
// pointer is returned for convenience.
func (s *Scheduler) AddTask(task *Task) *Task {
	s.mu.Lock()
	defer s.mu.Unlock()

	now := time.Now().UTC()
	for _, trig := range task.Triggers {
		trig.Reset(now)
	}

	task.Enabled = true
	task.ID = s.nextTaskID
	s.nextTaskID++
	s.tasks = append(s.tasks, task)
	return task
}
// GetTask looks up a task by its ID, reporting whether it was found.
func (s *Scheduler) GetTask(taskID int) (*Task, bool) {
	s.mu.Lock()
	defer s.mu.Unlock()
	for _, t := range s.tasks {
		if t.ID == taskID {
			return t, true
		}
	}
	return nil, false
}
// RemoveTask removes the task with the given ID from the scheduler.
// Removing an unknown ID is a no-op.
func (s *Scheduler) RemoveTask(taskID int) {
	s.mu.Lock()
	defer s.mu.Unlock()
	// slices.DeleteFunc filters in place and zeroes the freed tail.
	s.tasks = slices.DeleteFunc(s.tasks, func(candidate *Task) bool {
		return candidate.ID == taskID
	})
}
// Start begins the scheduler's task execution loop in its own goroutine.
// The loop polls at pollInterval and exits when Stop is called.
func (s *Scheduler) Start() {
	go func() {
		ticker := time.NewTicker(s.pollInterval)
		// Release the ticker when the loop exits; the original never
		// stopped it, leaking the ticker after Stop().
		defer ticker.Stop()
		for {
			select {
			case <-s.stopChan:
				return
			case now := <-ticker.C:
				s.checkAndDispatchTasks(now.UTC())
			}
		}
	}()
}
// checkAndDispatchTasks scans all tasks and dispatches those whose
// triggers are ready at currentTime. Tasks are visited in priority order
// (lower Priority first, ties broken by ID) using a copied slice so the
// scheduler's own list is never reordered.
//
// NOTE(review): s.mu is held for the whole scan, and dispatchTask can
// block on the task semaphore while the lock is held — confirm this
// cannot stall AddTask/RemoveTask/GetTask for long periods.
func (s *Scheduler) checkAndDispatchTasks(currentTime time.Time) {
	s.mu.Lock()
	defer s.mu.Unlock()
	// Sort a copy so s.tasks keeps its insertion order.
	tasksToCheck := make([]*Task, len(s.tasks))
	copy(tasksToCheck, s.tasks)
	sort.SliceStable(tasksToCheck, func(i, j int) bool {
		if tasksToCheck[i].Priority != tasksToCheck[j].Priority {
			return tasksToCheck[i].Priority < tasksToCheck[j].Priority
		}
		return tasksToCheck[i].ID < tasksToCheck[j].ID
	})
	for _, task := range tasksToCheck {
		if !task.Enabled {
			continue
		}
		// A task with several ready triggers is dispatched once per trigger.
		for _, trigger := range task.Triggers {
			if trigger.IsReady(currentTime) {
				s.dispatchTask(task, trigger)
			}
		}
	}
}
// dispatchTask tries to acquire a semaphore slot and, on success, marks
// the trigger as fired and runs the task in a new goroutine. If the
// scheduler is stopped before a slot frees up, the dispatch is abandoned
// and the wait-group entry is released.
func (s *Scheduler) dispatchTask(task *Task, trigger Trigger) {
	s.taskWg.Add(1)
	select {
	case s.taskSem <- struct{}{}:
		// Slot acquired; executeTask releases it when the task finishes.
		trigger.MarkTriggered(time.Now().UTC())
		go s.executeTask(task)
	case <-s.stopChan:
		s.taskWg.Done()
	}
}
// executeTask runs the task function, retrying on error up to
// RetryPolicy.MaxRetries additional times with RetryPolicy.Delay between
// attempts (the delay wait aborts early if the scheduler stops). Every
// attempt — successful or failed — is appended to the task's execution
// history. The semaphore slot and wait-group entry are released on exit.
func (s *Scheduler) executeTask(task *Task) {
	defer func() {
		<-s.taskSem
		s.taskWg.Done()
	}()
	for retries := 0; retries <= task.RetryPolicy.MaxRetries; retries++ {
		execution := Execution{
			StartedAt: time.Now().UTC(),
		}
		err := task.Function(task.Args)
		execution.EndedAt = time.Now().UTC()
		if err != nil {
			// Record the failed attempt under the scheduler lock.
			execution.Error = err.Error()
			s.mu.Lock()
			task.ExecutionHist = append(task.ExecutionHist, execution)
			s.mu.Unlock()
			if retries < task.RetryPolicy.MaxRetries {
				// Wait out the retry delay, but abort on Stop().
				select {
				case <-s.stopChan:
					return
				case <-time.After(task.RetryPolicy.Delay):
				}
			}
		} else {
			// Success: record the attempt and stop retrying.
			s.mu.Lock()
			task.ExecutionHist = append(task.ExecutionHist, execution)
			s.mu.Unlock()
			return
		}
	}
}
// Stop signals the scheduler to stop and blocks until all in-flight
// tasks have finished.
//
// NOTE(review): stopChan is closed unconditionally, so calling Stop a
// second time panics — confirm callers invoke it at most once.
func (s *Scheduler) Stop() {
	close(s.stopChan)
	s.taskWg.Wait()
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package backgroundtasks
import (
"time"
"github.com/robfig/cron/v3"
)
// Trigger abstracts a condition that decides when a scheduled task should
// run. Implementations may mutate internal state inside IsReady, so a
// Trigger is not safe for concurrent use without external synchronization.
type Trigger interface {
	IsReady(currentTime time.Time) bool  // Returns true if the trigger condition is met at currentTime.
	MarkTriggered(triggerTime time.Time) // Records that the trigger fired at triggerTime.
	Reset(currentTime time.Time)         // Resets the trigger state relative to currentTime.
}
// PeriodicTrigger triggers at regular intervals or based on a cron
// expression. When CronExpr is non-empty it takes precedence over Interval.
type PeriodicTrigger struct {
	Interval  time.Duration        // Interval for periodic triggering (used when CronExpr is empty).
	CronExpr  string               // Standard cron expression for triggering.
	Jitter    func() time.Duration // Optional extra delay added to each readiness check.
	startedAt time.Time            // Anchor of the current period; latched on first IsReady call.
}
// IsReady reports whether the trigger should fire at currentTime.
//
// Note that IsReady mutates the trigger: the first call latches startedAt to
// currentTime, so the first firing happens one interval (or one cron period)
// after the first readiness check, not immediately. An optional Jitter
// function pushes the readiness boundary further out by its returned amount.
func (t *PeriodicTrigger) IsReady(currentTime time.Time) bool {
	var jitter time.Duration
	if t.Jitter != nil {
		jitter = t.Jitter()
	}
	if t.startedAt.IsZero() {
		t.startedAt = currentTime
	}
	now := currentTime.UTC()
	// Cron expression takes precedence over the plain interval.
	if t.CronExpr != "" {
		cronExpr, err := cron.ParseStandard(t.CronExpr)
		if err != nil {
			// Unparsable expressions are treated as "never ready".
			return false
		}
		// Next occurrence strictly after the period anchor.
		nextCronTriggerTime := cronExpr.Next(t.startedAt)
		return !nextCronTriggerTime.IsZero() && nextCronTriggerTime.Add(jitter).Before(now)
	}
	// Fixed-interval mode: ready once startedAt + Interval (+ jitter) has passed.
	if t.Interval > 0 {
		if t.startedAt.Add(t.Interval + jitter).Before(now) {
			return true
		}
	}
	return false
}
// MarkTriggered anchors the next period at triggerTime (stored in UTC), so
// the following firing happens one interval/cron period later.
func (t *PeriodicTrigger) MarkTriggered(triggerTime time.Time) {
	t.startedAt = triggerTime.UTC()
}
// Reset restarts the current period from currentTime (stored in UTC).
func (t *PeriodicTrigger) Reset(currentTime time.Time) {
	t.startedAt = currentTime.UTC()
}
// EventTrigger triggers based on an external event signaled through a
// channel; each value received fires the trigger once.
type EventTrigger struct {
	Trigger chan bool // Channel to signal an event.
}
// IsReady does a non-blocking receive on the trigger channel and reports
// whether an event was pending. Consuming the value arms the trigger for
// the next event.
func (t *EventTrigger) IsReady(_ time.Time) bool {
	select {
	case <-t.Trigger:
		return true
	default:
	}
	return false
}
// MarkTriggered is a no-op for EventTrigger; its state lives entirely in
// the externally-managed channel.
func (t *EventTrigger) MarkTriggered(_ time.Time) {}
// Reset is a no-op for EventTrigger; its state lives entirely in the
// externally-managed channel.
func (t *EventTrigger) Reset(_ time.Time) {}
// OneTimeTrigger triggers exactly once after a specified delay; Reset
// re-arms it for another single firing.
type OneTimeTrigger struct {
	After     time.Duration // The delay after which to trigger.
	startedAt time.Time     // Time when the trigger was armed; latched on first IsReady call.
	triggered bool          // True once MarkTriggered has been called; blocks further firings.
}
// IsReady reports whether the delay has elapsed and the trigger has not
// already fired. The first call latches startedAt to currentTime, so the
// delay is measured from the first readiness check.
func (t *OneTimeTrigger) IsReady(currentTime time.Time) bool {
	if t.startedAt.IsZero() {
		t.startedAt = currentTime
	}
	if t.triggered {
		return false
	}
	return currentTime.After(t.startedAt.Add(t.After))
}
// MarkTriggered latches the fired flag so IsReady never reports ready
// again until Reset is called.
func (t *OneTimeTrigger) MarkTriggered(_ time.Time) {
	t.triggered = true
}
// Reset re-arms the trigger: the delay restarts from currentTime (stored
// in UTC) and the fired flag is cleared.
func (t *OneTimeTrigger) Reset(currentTime time.Time) {
	t.startedAt = currentTime.UTC()
	t.triggered = false
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package config
import (
"bytes"
"encoding/json"
"fmt"
"os"
"path/filepath"
"runtime"
"strings"
"sync"
"time"
"github.com/go-playground/validator/v10"
"github.com/go-viper/mapstructure/v2"
"github.com/spf13/afero"
"github.com/spf13/pflag"
"github.com/spf13/viper"
)
var (
	homeDir, _     = os.UserHomeDir() // best-effort; empty string if the lookup fails
	defaultCfgName = "dms_config"     // base name of the config file (no extension)
	defaultCfgExt  = "json"           // config file extension / Viper config type
	validate       = validator.New()  // struct validator applied after unmarshalling
)
// DefaultConfig is the baseline configuration seeded into every Loader as
// Viper defaults; config file, DMS_* environment variables, and CLI flags
// all override these values. Paths are derived from the user's home dir.
var DefaultConfig = Config{
	General: General{
		Env:                    "production",
		UserDir:                fmt.Sprintf("%s/.nunet", homeDir),
		WorkDir:                fmt.Sprintf("%s/nunet", homeDir),
		DataDir:                fmt.Sprintf("%s/nunet/data", homeDir),
		Debug:                  false,
		PortAvailableRangeFrom: 16384,
		PortAvailableRangeTo:   65536,
		StorageCADirectory:     fmt.Sprintf("%s/.nunet/storage_ca_directory", homeDir),
		StorageBricksDir:       fmt.Sprintf("%s/.nunet/storage_bricks_dir", homeDir),
		PaymentProvider: PaymentProvider{
			Mode:               false,
			EthereumRPCURL:     "https://ethereum-sepolia-rpc.publicnode.com",
			NtxContractAddress: "0xB37216b70a745129966E553cF8Ee2C51e1cB359A", // TSTNTX
			EthereumRPCToken:   "",
		},
	},
	Rest: Rest{
		Addr: "127.0.0.1",
		Port: 9999,
	},
	Profiler: Profiler{
		Enabled: true,
		Addr:    "127.0.0.1",
		Port:    6060,
	},
	P2P: P2P{
		ListenAddress: []string{
			"/ip4/0.0.0.0/tcp/9000",
			"/ip4/0.0.0.0/udp/9000/quic-v1",
		},
		BootstrapPeers: []string{
			"/dnsaddr/bootstrap.p2p.nunet.io/p2p/12D3KooWHzew9HTYzywFuvTHGK5Yzoz7qAhMfxagtCvhvjheoBQ3",
			"/dnsaddr/bootstrap.p2p.nunet.io/p2p/12D3KooWJMtMN1mTNRfgMqUygT7eSXamVzc9ihpSjeairm9PebmB",
			"/dnsaddr/bootstrap.p2p.nunet.io/p2p/12D3KooWKjSodxxi7UfRHzuk7eGgUF49MoPUCJvtva9K12TqDDsi",
		},
		Memory:          1024,
		FileDescriptors: 512,
	},
	Observability: Observability{
		Logging: Logging{
			Level: "INFO",
			File:  fmt.Sprintf("%s/nunet/logs/nunet-dms-logs.jsonl", homeDir),
			Rotation: Rotation{
				MaxSizeMB:  100,
				MaxBackups: 3,
				MaxAgeDays: 28,
			},
		},
		Elastic: Elastic{
			URL:                "https://telemetry.nunet.io",
			Index:              "nunet-dms",
			FlushInterval:      5,
			Enabled:            false,
			APIKey:             "",
			InsecureSkipVerify: false,
		},
		OTel: OTel{
			Enabled:  false,
			Endpoint: "",
			// Enabled: true,
			// Endpoint: "localhost:3000",
			Insecure: true,
		},
	},
	APM: APM{
		ServerURL:   "https://apm.telemetry.nunet.io",
		ServiceName: "nunet-dms",
		Environment: "production",
		APIKey:      "",
	},
	Job: Job{
		AllowPrivilegedDocker: false,
	},
}
// Loader encapsulates Viper, the loaded Config, and an abstract filesystem.
type Loader struct {
	v       *viper.Viper // underlying Viper instance (defaults, env, flags, file)
	cfg     *Config      // the currently loaded, validated configuration
	fs      afero.Fs     // filesystem abstraction (swappable for tests)
	cfgMu   sync.RWMutex // guards cfg and the read/write paths through it
	once    sync.Once    // makes Load a one-shot; later calls are no-ops
	cfgFile *string      // value of the --config flag, if bound via BindFlags
}
// Option configures a Loader (functional-options pattern); applied by
// NewLoader before the Loader is initialized.
type Option func(*Loader)
// WithFS swaps the filesystem used by the Loader (defaults to the OS FS).
func WithFS(fs afero.Fs) Option {
	return func(l *Loader) {
		l.fs = fs
	}
}
// WithConfig swaps the initial config pointer (defaults to a copy of
// DefaultConfig).
func WithConfig(cfg *Config) Option {
	return func(l *Loader) {
		l.cfg = cfg
	}
}
// NewLoader creates a Loader with sane defaults (does not read files).
// The Loader starts from a private copy of DefaultConfig; options may swap
// the filesystem or the config pointer before initialization.
func NewLoader(opts ...Option) *Loader {
	seed := DefaultConfig
	loader := &Loader{
		v:   viper.New(),
		fs:  afero.NewOsFs(),
		cfg: &seed,
	}
	for _, apply := range opts {
		apply(loader)
	}
	loader.init()
	return loader
}
// tryReadConfig mirrors v.ReadInConfig() but falls back for custom
// filesystems: when Viper cannot find the file through its own search
// paths, the candidate paths are probed directly through fs and the raw
// bytes are fed to Viper. A missing file is not an error; a file that
// exists but cannot be read now surfaces instead of being silently
// swallowed.
func tryReadConfig(vip *viper.Viper, fs afero.Fs) error {
	err := vip.ReadInConfig()
	if err == nil {
		return nil
	}
	if _, notFound := err.(viper.ConfigFileNotFoundError); !notFound {
		return err // syntax / permission errors, etc.
	}
	name := defaultCfgName + "." + defaultCfgExt // dms_config.json
	for _, cand := range []string{"./" + name, name} { // test & prod cases
		exists, _ := afero.Exists(fs, cand)
		if !exists {
			continue
		}
		raw, readErr := afero.ReadFile(fs, cand)
		if readErr != nil {
			return fmt.Errorf("read config %q: %w", cand, readErr)
		}
		vip.SetConfigFile(cand)
		return vip.ReadConfig(bytes.NewReader(raw))
	}
	return nil
}
// init wires the Loader's Viper instance: filesystem, config name/type,
// the search paths (cwd, $HOME/.nunet, the user config dir, /etc/nunet on
// non-Windows), defaults seeded from l.cfg, and DMS_* env overrides.
func (l *Loader) init() {
	v := l.v
	v.SetFs(l.fs)
	v.SetConfigName(defaultCfgName)
	v.SetConfigType(defaultCfgExt)
	v.AddConfigPath("./")
	// This homeDir intentionally shadows the package-level variable: it
	// re-checks the lookup error instead of trusting the package-level value.
	if homeDir, err := os.UserHomeDir(); err == nil {
		v.AddConfigPath(fmt.Sprintf("%s/.nunet", homeDir)) // $HOME
	}
	if configDir, err := os.UserConfigDir(); err == nil {
		v.AddConfigPath(fmt.Sprintf("%s/nunet", configDir)) // $CONFIG
	}
	if runtime.GOOS != "windows" {
		v.AddConfigPath("/etc/nunet/") // system
	}
	// Seed cfg as defaults (def=true) so file/env/flag values still override.
	_ = l.setConfig(*l.cfg, true)
	v.SetEnvPrefix("DMS")
	v.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
	v.AutomaticEnv()
}
// Public Loader API

// Load reads the config file (if any), overlays env & flags, validates, and
// stores the result. Subsequent calls are cheap no-ops (guarded by sync.Once).
// NOTE(review): if the first Load fails, later calls return a nil error with
// whatever config state the failed load left behind — callers should treat
// the first returned error as authoritative.
func (l *Loader) Load() (*Config, error) {
	var err error
	l.once.Do(func() { err = l.Reload() })
	return l.cfg, err
}
// ConfigFile returns the path of the config file Viper actually used, or
// "" if no file has been read or set yet.
func (l *Loader) ConfigFile() string {
	return l.v.ConfigFileUsed()
}
// Reload forces a fresh read; useful for SIGHUP hot-reload.
// If legacy observability keys were migrated during the read, the migrated
// config is persisted back to disk (backing up the existing file first).
func (l *Loader) Reload() error {
	l.cfgMu.Lock()
	migrated, err := l.readAndUnmarshal()
	l.cfgMu.Unlock()
	if migrated && err == nil {
		if err = l.Write(true); err != nil {
			return fmt.Errorf("persist migrated config: %w", err)
		}
	}
	return err
}
// SetConfig replaces the in-memory config and mirrors it into Viper via
// explicit Set calls (def=false) so the values take override precedence.
// The setConfig error is deliberately discarded to keep the void signature;
// it can only fail if the Config cannot be marshalled to JSON.
func (l *Loader) SetConfig(c Config) {
	l.cfgMu.Lock()
	*l.cfg = c
	l.cfgMu.Unlock()
	_ = l.setConfig(c, false)
}
// setConfig flattens cfg into dotted Viper keys by round-tripping it
// through JSON into a scratch Viper instance, then copies every key into
// l.v — as defaults when def is true, as explicit overrides otherwise.
func (l *Loader) setConfig(cfg Config, def bool) error {
	l.cfgMu.Lock()
	defer l.cfgMu.Unlock()
	tmp := viper.New()
	tmp.SetConfigType("json")
	raw, err := json.Marshal(cfg)
	if err != nil {
		return err
	}
	if err = tmp.ReadConfig(bytes.NewReader(raw)); err != nil {
		return err
	}
	for _, k := range tmp.AllKeys() {
		v := tmp.Get(k)
		if def {
			l.v.SetDefault(k, v)
		} else {
			l.v.Set(k, v)
		}
	}
	return nil
}
// GetConfig returns the currently loaded configuration. The error is
// always nil; the signature is kept for interface compatibility.
// A read lock is sufficient here — the method only reads the pointer,
// so concurrent readers need not serialize against each other.
func (l *Loader) GetConfig() (*Config, error) {
	l.cfgMu.RLock()
	defer l.cfgMu.RUnlock()
	return l.cfg, nil
}
// Set updates a single dotted key (e.g. "rest.port"), validates the
// resulting struct via a strict unmarshal probe, then writes to disk.
// The key is lowercased to match Viper's canonical key form.
func (l *Loader) Set(key string, value interface{}) error {
	key = strings.ToLower(key)
	// Probe on a scratch Viper so a bad key or value never corrupts l.v.
	tmp := viper.New()
	tmp.SetConfigType("json")
	if err := tmp.MergeConfigMap(l.v.AllSettings()); err != nil {
		return err
	}
	tmp.Set(key, value)
	var probe Config
	if err := tmp.UnmarshalExact(&probe); err != nil {
		return err // reject: unknown key or wrong type
	}
	// Probe passed — apply the change to the real Viper and config.
	if err := l.v.MergeConfigMap(map[string]any{key: value}); err != nil {
		return err
	}
	l.cfgMu.Lock()
	*l.cfg = probe
	l.cfgMu.Unlock()
	return l.Write(false)
}
// Write persists the current in-memory config to disk atomically.
// • Creates the file (and parent directories) on first run.
// • Uses a temp-file + rename so it can’t be truncated on crash.
// When backupExisting is true, any existing file is first renamed to
// "<path>.bk.<unix-timestamp>".
func (l *Loader) Write(backupExisting bool) error {
	l.cfgMu.RLock()
	defer l.cfgMu.RUnlock()
	cfgPath := l.v.ConfigFileUsed()
	if cfgPath == "" {
		// First run – default to "./dms_config.json" (same as search path 1)
		cfgPath = fmt.Sprintf("./%s.%s", defaultCfgName, defaultCfgExt)
		l.v.SetConfigFile(cfgPath)
	}
	if backupExisting {
		backupPath := fmt.Sprintf("%s.bk.%d", cfgPath, time.Now().Unix())
		// A missing original is fine (nothing to back up); only real rename
		// failures abort the write.
		if err := l.fs.Rename(cfgPath, backupPath); err != nil && !os.IsNotExist(err) {
			return fmt.Errorf("couldn't make backup of config file ahead of write. aborting write. Error: %w", err)
		}
	}
	// Ensure directory hierarchy exists.
	if err := l.fs.MkdirAll(filepath.Dir(cfgPath), 0o755); err != nil {
		return fmt.Errorf("create config directory: %w", err)
	}
	// Round-trip l.cfg through JSON into l.v so Viper serializes the
	// complete, current struct state (not just previously-set keys).
	raw, err := json.Marshal(l.cfg)
	if err != nil {
		return fmt.Errorf("marshal config: %w", err)
	}
	var m map[string]any
	if err := json.Unmarshal(raw, &m); err != nil {
		return fmt.Errorf("unmarshal map: %w", err)
	}
	if err := l.v.MergeConfigMap(m); err != nil {
		return fmt.Errorf("merge into viper: %w", err)
	}
	tmpPath := strings.TrimSuffix(cfgPath, filepath.Ext(cfgPath)) + ".tmp" + filepath.Ext(cfgPath)
	// Always try to clean up the temp file. If Rename succeeds the file
	// no longer exists, so this is a harmless no-op.
	defer func() { _ = l.fs.Remove(tmpPath) }()
	if err := l.v.WriteConfigAs(tmpPath); err != nil {
		return fmt.Errorf("write temp config: %w", err)
	}
	if err := l.fs.Rename(tmpPath, cfgPath); err != nil {
		return fmt.Errorf("atomic rename failed: %w", err)
	}
	return nil
}
// BindFlags attaches CLI flags to the Loader’s Viper instance. Each flag's
// default is taken from the current Viper value, and each flag is bound so
// that a set flag overrides file and env values.
func (l *Loader) BindFlags(fs *pflag.FlagSet) {
	v := l.v
	// --config is special: record its value, do NOT bind to Viper.
	cfgFile := new(string)
	fs.StringVar(cfgFile, "config", "", "config file (override search paths)")
	l.cfgFile = cfgFile
	// Local descriptor tying a flag name to its dotted Viper key and type.
	type flag struct {
		name, key, short, usage string
		isBool, isInt           bool
	}
	flags := []flag{
		{"rest-addr", "rest.addr", "", "REST API host", false, false},
		{"rest-port", "rest.port", "", "REST API port", false, true},
		{"user-dir", "general.user_dir", "", "user directory", false, false},
		{"work-dir", "general.work_dir", "", "work directory", false, false},
		{"data-dir", "general.data_dir", "", "data directory", false, false},
		{"debug", "general.debug", "", "debug mode", true, false},
		{"profiler-enabled", "profiler.enabled", "", "enable profiler", true, false},
		{"profiler-addr", "profiler.addr", "", "profiler address", false, false},
		{"profiler-port", "profiler.port", "", "profiler port", false, true},
	}
	for _, f := range flags {
		switch {
		case f.isBool:
			fs.BoolP(f.name, f.short, v.GetBool(f.key), f.usage)
		case f.isInt:
			fs.IntP(f.name, f.short, v.GetInt(f.key), f.usage)
		default:
			fs.StringP(f.name, f.short, v.GetString(f.key), f.usage)
		}
		_ = v.BindPFlag(f.key, fs.Lookup(f.name))
	}
}
// readAndUnmarshal reads the config into Viper, strictly unmarshals it into
// l.cfg (ZeroFields clears values left over from a previous load), migrates
// legacy flat observability keys, and validates the result. migrated
// reports whether legacy keys were rewritten and should be persisted.
// Callers must hold l.cfgMu.
func (l *Loader) readAndUnmarshal() (migrated bool, err error) {
	// honour --config flag if supplied
	if l.cfgFile != nil && *l.cfgFile != "" {
		l.v.SetConfigFile(*l.cfgFile)
	}
	if err := tryReadConfig(l.v, l.fs); err != nil {
		return false, fmt.Errorf("read config: %w", err)
	}
	if err := l.v.UnmarshalExact(
		l.cfg,
		func(c *mapstructure.DecoderConfig) { c.ZeroFields = true },
	); err != nil {
		return false, fmt.Errorf("unmarshal config: %w", err)
	}
	migrated = migrateLegacyObservability(l)
	if err := validate.Struct(l.cfg); err != nil {
		return false, fmt.Errorf("validate config: %w", err)
	}
	return migrated, nil
}
// GetValue fetches a value by dotted key (case-insensitive); the second
// return value is false when the key is unset.
func (l *Loader) GetValue(key string) (interface{}, bool) {
	normalized := strings.ToLower(key)
	if l.v.IsSet(normalized) {
		return l.v.Get(normalized), true
	}
	return nil, false
}
// migrateLegacyObservability copies values from the deprecated flat
// observability keys into the new nested structure IF, AND ONLY IF, the
// nested fields have not been set. This lets old and new config files
// work side-by-side and means the flat keys can be removed in a later
// release without breaking users. It returns true when the migrated map
// was successfully merged back into Viper (i.e. the result should be
// persisted).
func migrateLegacyObservability(l *Loader) (migrated bool) {
	oCfg := &l.cfg.Observability
	// Fast path: nothing to do when no legacy key carries a value.
	if oCfg.LogLevel == "" &&
		oCfg.LogFile == "" &&
		oCfg.MaxSize == 0 &&
		oCfg.MaxBackups == 0 &&
		oCfg.MaxAge == 0 &&
		oCfg.ElasticsearchURL == "" &&
		oCfg.ElasticsearchIndex == "" &&
		oCfg.FlushInterval == 0 &&
		!oCfg.ElasticsearchEnabled &&
		oCfg.ElasticsearchAPIKey == "" &&
		!oCfg.InsecureSkipVerify {
		return false
	}
	oMap, ok := l.v.Get("observability").(map[string]any)
	if !ok {
		return false
	}
	// Logging keys: migrate only where the nested field is still unset.
	if oCfg.Logging.Level == "" && oCfg.LogLevel != "" {
		l.v.Set("observability.logging.level", oCfg.LogLevel)
	}
	if oCfg.Logging.File == "" && oCfg.LogFile != "" {
		l.v.Set("observability.logging.file", oCfg.LogFile)
	}
	if oCfg.Logging.Rotation.MaxSizeMB == 0 && oCfg.MaxSize != 0 {
		l.v.Set("observability.logging.rotation.max_size_mb", oCfg.MaxSize)
	}
	if oCfg.Logging.Rotation.MaxBackups == 0 && oCfg.MaxBackups != 0 {
		l.v.Set("observability.logging.rotation.max_backups", oCfg.MaxBackups)
	}
	if oCfg.Logging.Rotation.MaxAgeDays == 0 && oCfg.MaxAge != 0 {
		l.v.Set("observability.logging.rotation.max_age_days", oCfg.MaxAge)
	}
	// NOTE: elastic flat keys are still valid and left untouched - they'll
	// be removed in a later major version once users have migrated
	if oCfg.ElasticsearchURL != "" {
		l.v.Set("observability.elastic.url", oCfg.ElasticsearchURL)
	}
	if oCfg.ElasticsearchIndex != "" {
		l.v.Set("observability.elastic.index", oCfg.ElasticsearchIndex)
	}
	if oCfg.FlushInterval != 0 {
		l.v.Set("observability.elastic.flush_interval", oCfg.FlushInterval)
	}
	if !oCfg.Elastic.Enabled && oCfg.ElasticsearchEnabled {
		l.v.Set("observability.elastic.enabled", oCfg.ElasticsearchEnabled)
	}
	if oCfg.Elastic.APIKey == "" && oCfg.ElasticsearchAPIKey != "" {
		l.v.Set("observability.elastic.api_key", oCfg.ElasticsearchAPIKey)
	}
	if !oCfg.Elastic.InsecureSkipVerify && oCfg.InsecureSkipVerify {
		l.v.Set("observability.elastic.insecure_skip_verify", oCfg.InsecureSkipVerify)
	}
	// Drop the flat rotation/logging keys from the observability map so
	// they disappear from the persisted file.
	delete(oMap, "max_size")
	delete(oMap, "max_backups")
	delete(oMap, "max_age")
	delete(oMap, "log_file")
	delete(oMap, "log_level")
	delete(oMap, "elasticsearch_url")
	delete(oMap, "elasticsearch_index")
	delete(oMap, "flush_interval")
	delete(oMap, "elasticsearch_enabled")
	delete(oMap, "elasticsearch_api_key")
	delete(oMap, "insecure_skip_verify")
	err := l.v.MergeConfigMap(map[string]any{"observability": oMap})
	return err == nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package internal
import (
"os"
"os/signal"
"syscall"
)
// ShutdownChan and ReloadChan deliver OS signals to the main loop; both
// are buffered (capacity 1, per the signal.Notify contract) so a signal
// arriving while the receiver is busy is not dropped.
var (
	ShutdownChan chan os.Signal
	ReloadChan   chan os.Signal
)
// init registers the process signal handlers: SIGINT/SIGTERM feed
// ShutdownChan (graceful shutdown) and SIGUSR1 feeds ReloadChan
// (configuration hot-reload).
func init() {
	ShutdownChan = make(chan os.Signal, 1)
	signal.Notify(ShutdownChan, syscall.SIGINT, syscall.SIGTERM)
	ReloadChan = make(chan os.Signal, 1)
	signal.Notify(ReloadChan, syscall.SIGUSR1)
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package crypto
import (
"bytes"
"crypto/ed25519"
"fmt"
"github.com/fxamacker/cbor/v2"
"github.com/libp2p/go-libp2p/core/crypto/pb"
)
// CardanoPublicKey wraps an Ed25519 public key used for verifying
// Cardano (COSE_Sign1) signed messages.
type CardanoPublicKey struct {
	key *ed25519.PublicKey
}

// Compile-time check that CardanoPublicKey satisfies the PubKey interface.
var _ PubKey = (*CardanoPublicKey)(nil)
// UnmarshalCardanoPublicKey parses a Cardano public key from either raw
// 32-byte Ed25519 key bytes or a CBOR-encoded COSE_Key carrying an
// OKP/Ed25519 key (kty=1, crv=6, x=key bytes).
func UnmarshalCardanoPublicKey(data []byte) (_k PubKey, err error) {
	// Fast path: raw Ed25519 public-key bytes.
	if len(data) == ed25519.PublicKeySize {
		pubKey := ed25519.PublicKey(data)
		return &CardanoPublicKey{key: &pubKey}, nil
	}
	// Try to unmarshal as COSE Key
	var coseKey map[int]interface{}
	if err := cbor.Unmarshal(data, &coseKey); err != nil {
		return nil, fmt.Errorf("invalid cardano public key: %w", err)
	}
	// kty (label 1) must be 1 (OKP). fxamacker/cbor decodes unsigned CBOR
	// integers as uint64, hence the uint64 assertion.
	if kty, ok := coseKey[1].(uint64); !ok || kty != 1 {
		return nil, fmt.Errorf("invalid COSE Key: kty must be 1 (OKP)")
	}
	// crv (label -1) must be 6 (Ed25519). Negative CBOR map labels decode
	// to negative Go ints, so the -1 lookup below is valid.
	if crv, ok := coseKey[-1].(uint64); !ok || crv != 6 {
		return nil, fmt.Errorf("invalid COSE Key: crv must be 6 (Ed25519)")
	}
	// x (label -2) carries the raw public-key bytes.
	x, ok := coseKey[-2].([]byte)
	if !ok {
		return nil, fmt.Errorf("invalid COSE Key: missing x parameter")
	}
	if len(x) != ed25519.PublicKeySize {
		return nil, fmt.Errorf("invalid COSE Key: x parameter length %d", len(x))
	}
	pubKey := ed25519.PublicKey(x)
	return &CardanoPublicKey{key: &pubKey}, nil
}
// Verify checks a COSE_Sign1 signature (the CBOR array
// [protected, unprotected, payload, signature]) over data.
// The Ed25519 signature is verified against the canonical Sig_structure
// ["Signature1", protected, external_aad="", data]. Note that the payload
// embedded in the envelope (index 2) is ignored in favor of the
// caller-supplied data.
func (k *CardanoPublicKey) Verify(data []byte, sigBytes []byte) (success bool, err error) {
	dec := cbor.NewDecoder(bytes.NewReader(sigBytes))
	var coseSign1 []interface{}
	if err := dec.Decode(&coseSign1); err != nil {
		return false, fmt.Errorf("failed to decode COSE_Sign1: %w", err)
	}
	if len(coseSign1) != 4 {
		return false, fmt.Errorf("unexpected COSE_Sign1 length: %d", len(coseSign1))
	}
	protected, ok := coseSign1[0].([]byte)
	if !ok {
		return false, fmt.Errorf("invalid protected headers")
	}
	signature, ok := coseSign1[3].([]byte)
	if !ok {
		return false, fmt.Errorf("invalid signature bytes")
	}
	// Rebuild the Sig_structure that the signer actually signed.
	sigStruct := []interface{}{
		"Signature1",
		protected,
		[]byte{}, // external AAD empty
		data,
	}
	sigStructBytes, err := cbor.Marshal(sigStruct)
	if err != nil {
		return false, fmt.Errorf("failed to marshal SigStructure: %w", err)
	}
	valid := ed25519.Verify(*k.key, sigStructBytes, signature)
	return valid, nil
}
// Raw returns the raw 32-byte Ed25519 public-key bytes; the error is
// always nil (signature kept for the PubKey interface).
func (k *CardanoPublicKey) Raw() (res []byte, err error) {
	return []byte(*k.key), nil
}
// Type reports the custom Cardano key-type codepoint.
func (k *CardanoPublicKey) Type() pb.KeyType {
	return Cardano
}
// Equals reports whether o is the same public key. Same-type keys are
// compared via ed25519's Equal; anything else falls back to the generic
// type+raw-bytes comparison.
func (k *CardanoPublicKey) Equals(o Key) bool {
	if other, ok := o.(*CardanoPublicKey); ok {
		return k.key.Equal(*other.key)
	}
	return basicEquals(k, o)
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package crypto
import (
"crypto/rand"
"errors"
"io"
"golang.org/x/crypto/sha3"
)
// RandomEntropy returns length cryptographically secure random bytes read
// from crypto/rand's Reader. It returns an error if length is negative or
// if the system entropy source fails; the underlying cause is now included
// in the error message instead of being discarded.
func RandomEntropy(length int) ([]byte, error) {
	if length < 0 {
		return nil, errors.New("length must be non-negative")
	}
	buf := make([]byte, length)
	// io.ReadFull guarantees len(buf) bytes were read whenever err is nil,
	// so no separate short-read check is needed.
	if _, err := io.ReadFull(rand.Reader, buf); err != nil {
		return nil, errors.New("failed to read random bytes: " + err.Error())
	}
	return buf, nil
}
// Sha3 returns the SHA3-256 digest of the given byte slices, hashed in
// order as one concatenated stream.
func Sha3(data ...[]byte) ([]byte, error) {
	h := sha3.New256()
	for _, chunk := range data {
		if _, err := h.Write(chunk); err != nil {
			return nil, err
		}
	}
	return h.Sum(nil), nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package crypto
import (
"crypto/subtle"
"fmt"
"github.com/decred/dcrd/dcrec/secp256k1/v4"
"github.com/decred/dcrd/dcrec/secp256k1/v4/ecdsa"
"github.com/libp2p/go-libp2p/core/crypto/pb"
"golang.org/x/crypto/sha3"
)
// ethSignMagic is the Ethereum signed-message prefix; the decimal byte
// length of the message is appended before hashing (see EthPublicKey.Verify).
var ethSignMagic = []byte(
	"\x19Ethereum Signed Message:\n",
)
// EthPublicKey wraps a secp256k1 public key used for verifying
// Ethereum-style signed messages.
type EthPublicKey struct {
	key *secp256k1.PublicKey
}

// Compile-time check that EthPublicKey satisfies the PubKey interface.
var _ PubKey = (*EthPublicKey)(nil)
// UnmarshalEthPublicKey parses a secp256k1 public key from its serialized
// byte form.
func UnmarshalEthPublicKey(data []byte) (_k PubKey, err error) {
	parsed, parseErr := secp256k1.ParsePubKey(data)
	if parseErr != nil {
		return nil, parseErr
	}
	return &EthPublicKey{key: parsed}, nil
}
// Verify checks an Ethereum personal-sign style signature over data: the
// DER-encoded ECDSA signature is verified against
// keccak256("\x19Ethereum Signed Message:\n" + len(data) + data).
func (k *EthPublicKey) Verify(data []byte, sigStr []byte) (success bool, err error) {
	sig, err := ecdsa.ParseDERSignature(sigStr)
	if err != nil {
		return false, err
	}
	hasher := sha3.NewLegacyKeccak256()
	hasher.Write(ethSignMagic)
	// The prefix includes the message length in decimal ASCII.
	fmt.Fprintf(hasher, "%d", len(data))
	hasher.Write(data)
	hash := hasher.Sum(nil)
	return sig.Verify(hash, k.key), nil
}
// Raw returns the 33-byte compressed SEC serialization of the key; the
// error is always nil (signature kept for the PubKey interface).
func (k *EthPublicKey) Raw() (res []byte, err error) {
	return k.key.SerializeCompressed(), nil
}
// Type reports the custom Eth key-type codepoint.
func (k *EthPublicKey) Type() pb.KeyType {
	return Eth
}
// Equals reports whether o is the same public key. Same-type keys compare
// via secp256k1's IsEqual; anything else falls back to the generic
// type+raw-bytes comparison.
func (k *EthPublicKey) Equals(o Key) bool {
	if other, ok := o.(*EthPublicKey); ok {
		return k.key.IsEqual(other.key)
	}
	return basicEquals(k, o)
}
// basicEquals is the generic key comparison: keys are equal when their
// types match and their raw serializations are byte-identical. The byte
// comparison is constant-time to avoid leaking match length via timing.
func basicEquals(k1, k2 Key) bool {
	if k1.Type() != k2.Type() {
		return false
	}
	raw1, err := k1.Raw()
	if err != nil {
		return false
	}
	raw2, err := k2.Raw()
	if err != nil {
		return false
	}
	return subtle.ConstantTimeCompare(raw1, raw2) == 1
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package crypto
import (
"bytes"
"encoding/base32"
"encoding/json"
"fmt"
)
// ID is the encoding of the actor's public key: the marshalled public-key
// bytes, treated as an opaque blob for comparison and serialization.
type ID struct{ PublicKey []byte }
// Equal reports whether both IDs carry identical public-key bytes.
func (id ID) Equal(other ID) bool {
	if len(id.PublicKey) != len(other.PublicKey) {
		return false
	}
	return bytes.Equal(id.PublicKey, other.PublicKey)
}
// Empty reports whether the ID carries no public-key bytes.
func (id ID) Empty() bool {
	return len(id.PublicKey) == 0
}
// IDJSONView is the on-the-wire JSON representation of an ID.
type IDJSONView struct {
	Pub string `json:"pub"` // base32 (standard alphabet) encoding of the public-key bytes
}
// String returns the base32 (standard alphabet) encoding of the public-key
// bytes; IDFromString is its inverse.
func (id ID) String() string {
	return base32.StdEncoding.EncodeToString(id.PublicKey)
}
// IDFromString parses the base32 textual form produced by ID.String.
func IDFromString(s string) (ID, error) {
	raw, err := base32.StdEncoding.DecodeString(s)
	if err != nil {
		return ID{}, fmt.Errorf("decode ID: %w", err)
	}
	return ID{PublicKey: raw}, nil
}
// MarshalJSON encodes the ID as {"pub": "<base32>"} via IDJSONView.
func (id ID) MarshalJSON() ([]byte, error) {
	return json.Marshal(IDJSONView{Pub: id.String()})
}

// Compile-time check that ID implements json.Marshaler.
var _ json.Marshaler = ID{}
// UnmarshalJSON decodes the {"pub": "<base32>"} wire form into the ID.
func (id *ID) UnmarshalJSON(data []byte) error {
	var view IDJSONView
	if err := json.Unmarshal(data, &view); err != nil {
		return fmt.Errorf("unmarshaling ID: %w", err)
	}
	parsed, err := IDFromString(view.Pub)
	if err != nil {
		return fmt.Errorf("unmarshaling ID: %w", err)
	}
	*id = parsed
	return nil
}

// Compile-time check that *ID implements json.Unmarshaler.
var _ json.Unmarshaler = (*ID)(nil)
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package crypto
import (
"crypto/rand"
"fmt"
"github.com/libp2p/go-libp2p/core/crypto"
)
const (
	// Key-type codepoints. Ed25519 and Secp256k1 reuse the libp2p values;
	// Eth and Cardano are custom codepoints defined by this package.
	Ed25519   = crypto.Ed25519
	Secp256k1 = crypto.Secp256k1
	Eth       = 127
	Cardano   = 128
)

// Aliases to the libp2p key interfaces so callers of this package need not
// import go-libp2p directly.
type (
	Key     = crypto.Key
	PrivKey = crypto.PrivKey
	PubKey  = crypto.PubKey
)
// AllowedKey reports whether key type t may be generated locally; only
// Ed25519 and Secp256k1 qualify (Eth and Cardano are verify-only in this
// package).
func AllowedKey(t int) bool {
	switch t {
	case Ed25519, Secp256k1:
		return true
	}
	return false
}
// GenerateKeyPair creates a fresh keypair of the given type using
// crypto/rand. Only Ed25519 and Secp256k1 can be generated; any other
// type (including Eth and Cardano, which this package only verifies)
// yields ErrUnsupportedKeyType.
func GenerateKeyPair(t int) (PrivKey, PubKey, error) {
	switch t {
	case Ed25519:
		return crypto.GenerateEd25519Key(rand.Reader)
	case Secp256k1:
		return crypto.GenerateSecp256k1Key(rand.Reader)
	default:
		return nil, nil, fmt.Errorf("unsupported key type %d: %w", t, ErrUnsupportedKeyType)
	}
}
// PublicKeyToBytes serializes a public key via libp2p's protobuf marshalling.
func PublicKeyToBytes(k PubKey) ([]byte, error) {
	return crypto.MarshalPublicKey(k)
}
// BytesToPublicKey is the inverse of PublicKeyToBytes.
func BytesToPublicKey(data []byte) (PubKey, error) {
	return crypto.UnmarshalPublicKey(data)
}
// PrivateKeyToBytes serializes a private key via libp2p's protobuf marshalling.
func PrivateKeyToBytes(k PrivKey) ([]byte, error) {
	return crypto.MarshalPrivateKey(k)
}
// BytesToPrivateKey is the inverse of PrivateKeyToBytes.
func BytesToPrivateKey(data []byte) (PrivKey, error) {
	return crypto.UnmarshalPrivateKey(data)
}
// IDFromPublicKey derives an actor ID by marshalling the public key into
// its byte form.
func IDFromPublicKey(k PubKey) (ID, error) {
	raw, err := PublicKeyToBytes(k)
	if err != nil {
		return ID{}, fmt.Errorf("id from public key: %w", err)
	}
	return ID{PublicKey: raw}, nil
}
// PublicKeyFromID reconstructs the public key from the marshalled bytes
// carried inside the ID.
func PublicKeyFromID(id ID) (PubKey, error) {
	return BytesToPublicKey(id.PublicKey)
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package keystore
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"encoding/hex"
"encoding/json"
"fmt"
"path/filepath"
"strings"
libp2p_crypto "github.com/libp2p/go-libp2p/core/crypto"
"golang.org/x/crypto/scrypt"
"gitlab.com/nunet/device-management-service/lib/crypto"
)
var (
	// KDF parameters
	nameKDF      = "scrypt"
	scryptKeyLen = 64 // derived key: first 32 bytes AES key, last 32 bytes MAC key
	// Default scrypt parameters (production values)
	defaultScryptN = 1 << 18
	defaultScryptR = 8
	defaultScryptP = 1
	// Current scrypt parameters (can be overridden for testing via
	// SetTestScryptParams; restored with ResetScryptParamsToDefaults)
	scryptN = defaultScryptN
	scryptR = defaultScryptR
	scryptP = defaultScryptP
	// Keystore JSON format version and symmetric cipher identifier
	ksVersion = 3
	ksCipher  = "aes-256-ctr"
)
// SetTestScryptParams allows overriding scrypt parameters for testing
// purposes. This significantly speeds up key generation/derivation but
// reduces security. IMPORTANT: Only use this in test environments; call
// ResetScryptParamsToDefaults afterwards.
func SetTestScryptParams(n, r, p int) {
	scryptN = n
	scryptR = r
	scryptP = p
}
// ResetScryptParamsToDefaults restores scrypt parameters to their original
// production values (the counterpart of SetTestScryptParams).
func ResetScryptParamsToDefaults() {
	scryptN = defaultScryptN
	scryptR = defaultScryptR
	scryptP = defaultScryptP
}
// Key represents a keypair to be stored in a keystore.
type Key struct {
	ID   string // identifier used to name / look up the key in the store
	Data []byte // marshalled private-key bytes (see PrivKey)
}
// NewKey creates a new Key from an identifier and marshalled key bytes.
// The error is always nil; the signature is kept for API compatibility.
func NewKey(id string, data []byte) (*Key, error) {
	k := &Key{ID: id, Data: data}
	return k, nil
}
// PrivKey acts upon a Key whose Data holds a marshalled private key and
// unmarshals it back into a usable crypto.PrivKey.
// The unmarshal error is now wrapped with %w (was %v) so callers can
// inspect the cause with errors.Is/As.
func (key *Key) PrivKey() (crypto.PrivKey, error) {
	priv, err := libp2p_crypto.UnmarshalPrivateKey(key.Data)
	if err != nil {
		return nil, fmt.Errorf("unable to unmarshal private key: %w", err)
	}
	return priv, nil
}
// MarshalToJSON encrypts and marshals a key to a JSON byte array.
//
// Scheme (encrypt-then-MAC): a 64-byte key is derived from the passphrase
// with scrypt; the first 32 bytes encrypt Data with AES-256-CTR and the
// last 32 bytes key the SHA3 MAC over the ciphertext. Salt, IV, KDF
// parameters and MAC are all embedded so UnmarshalKey can reverse the
// process with only the passphrase.
func (key *Key) MarshalToJSON(passphrase string) ([]byte, error) {
	if passphrase == "" {
		return nil, ErrEmptyPassphrase
	}
	salt, err := crypto.RandomEntropy(64)
	if err != nil {
		return nil, err
	}
	dk, err := scrypt.Key([]byte(passphrase), salt, scryptN, scryptR, scryptP, scryptKeyLen)
	if err != nil {
		return nil, err
	}
	iv, err := crypto.RandomEntropy(aes.BlockSize)
	if err != nil {
		return nil, err
	}
	// First 32 derived bytes: AES-256 encryption key.
	enckey := dk[:32]
	aesBlock, err := aes.NewCipher(enckey)
	if err != nil {
		return nil, err
	}
	stream := cipher.NewCTR(aesBlock, iv)
	cipherText := make([]byte, len(key.Data))
	stream.XORKeyStream(cipherText, key.Data)
	// Last 32 derived bytes: MAC key over the ciphertext (encrypt-then-MAC).
	mac, err := crypto.Sha3(dk[32:64], cipherText)
	if err != nil {
		return nil, err
	}
	cipherParamsJSON := cipherparamsJSON{
		IV: hex.EncodeToString(iv),
	}
	// Record the exact KDF parameters used, so decryption works even if
	// the package-level scrypt settings change later.
	sp := ScryptParams{
		N:          scryptN,
		R:          scryptR,
		P:          scryptP,
		DKeyLength: scryptKeyLen,
		Salt:       hex.EncodeToString(salt),
	}
	keyjson := cryptoJSON{
		Cipher:       ksCipher,
		CipherText:   hex.EncodeToString(cipherText),
		CipherParams: cipherParamsJSON,
		KDF:          nameKDF,
		KDFParams:    sp,
		MAC:          hex.EncodeToString(mac),
	}
	encjson := encryptedKeyJSON{
		Crypto:  keyjson,
		ID:      key.ID,
		Version: ksVersion,
	}
	data, err := json.MarshalIndent(&encjson, "", "  ")
	if err != nil {
		return nil, err
	}
	return data, nil
}
// UnmarshalKey decrypts and unmarshals the private key from its encrypted
// JSON form using the supplied passphrase.
//
// The ciphertext is authenticated (encrypt-then-MAC): the stored MAC is
// recomputed as Sha3(dk[32:64] || ciphertext) and compared before any
// decryption happens. KDF parameters come from the (untrusted) file, so
// the derived-key length is validated up front — previously a file with
// DKeyLength < 64 caused an out-of-range slice panic at dk[32:64].
func UnmarshalKey(data []byte, passphrase string) (*Key, error) {
	if passphrase == "" {
		return nil, ErrEmptyPassphrase
	}
	encjson := encryptedKeyJSON{}
	if err := json.Unmarshal(data, &encjson); err != nil {
		return nil, fmt.Errorf("%w: %w", ErrDecodeKey, err)
	}
	if encjson.Version != ksVersion {
		return nil, ErrVersionMismatch
	}
	if encjson.Crypto.Cipher != ksCipher {
		return nil, ErrCipherMismatch
	}
	// Reject malformed KDF params early: dk[:32] and dk[32:64] below
	// require at least 64 derived bytes.
	if encjson.Crypto.KDFParams.DKeyLength < 64 {
		return nil, fmt.Errorf("%w: derived key length must be at least 64", ErrDecodeKey)
	}
	mac, err := hex.DecodeString(encjson.Crypto.MAC)
	if err != nil {
		return nil, fmt.Errorf("%w: mac: %w", ErrDecodeKey, err)
	}
	iv, err := hex.DecodeString(encjson.Crypto.CipherParams.IV)
	if err != nil {
		return nil, fmt.Errorf("%w: cipher params: %w", ErrDecodeKey, err)
	}
	salt, err := hex.DecodeString(encjson.Crypto.KDFParams.Salt)
	if err != nil {
		return nil, fmt.Errorf("%w: salt: %w", ErrDecodeKey, err)
	}
	ciphertext, err := hex.DecodeString(encjson.Crypto.CipherText)
	if err != nil {
		return nil, fmt.Errorf("%w: cipher text: %w", ErrDecodeKey, err)
	}
	dk, err := scrypt.Key([]byte(passphrase), salt, encjson.Crypto.KDFParams.N, encjson.Crypto.KDFParams.R, encjson.Crypto.KDFParams.P, encjson.Crypto.KDFParams.DKeyLength)
	if err != nil {
		return nil, fmt.Errorf("%w: %w", ErrKeyProcessing, err)
	}
	// Authenticate before decrypting (encrypt-then-MAC).
	hash, err := crypto.Sha3(dk[32:64], ciphertext)
	if err != nil {
		return nil, fmt.Errorf("%w: %w", ErrKeyProcessing, err)
	}
	if !bytes.Equal(hash, mac) {
		return nil, ErrMACMismatch
	}
	aesBlock, err := aes.NewCipher(dk[:32])
	if err != nil {
		return nil, fmt.Errorf("%w: %w", ErrDecodeKey, err)
	}
	stream := cipher.NewCTR(aesBlock, iv)
	outputkey := make([]byte, len(ciphertext))
	stream.XORKeyStream(outputkey, ciphertext)
	return &Key{
		ID:   encjson.ID,
		Data: outputkey,
	}, nil
}
// removeFileExtension strips the final extension (as reported by
// filepath.Ext) from filename; names without an extension come back
// unchanged.
func removeFileExtension(filename string) string {
	return strings.TrimSuffix(filename, filepath.Ext(filename))
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package keystore
import (
"bytes"
"encoding/gob"
"fmt"
"os"
"path/filepath"
"slices"
"sync"
"github.com/spf13/afero"
"gitlab.com/nunet/device-management-service/utils"
)
// KeyStore manages a local keystore with lock and unlock functionalities.
type KeyStore interface {
	// Save encrypts data under passphrase and persists it as key id,
	// returning the path of the written file.
	Save(id string, data []byte, passphrase string) (string, error)
	// Get decrypts and returns the key identified by keyID.
	Get(keyID string, passphrase string) (*Key, error)
	// Delete removes the stored key after verifying the passphrase.
	Delete(keyID string, passphrase string) error
	// ListKeys returns the IDs of all stored keys.
	ListKeys() ([]string, error)
	// Exists reports whether a key with the given ID is stored.
	Exists(key string) bool
	// Dir returns the directory the keys are stored in.
	Dir() string
}

// BasicKeyStore handles keypair storage.
// TODO: add cache?
type BasicKeyStore struct {
	fs      afero.Fs // filesystem abstraction (swappable in tests)
	keysDir string   // directory holding the encrypted key files
	mu      sync.RWMutex
	// cache holds already-unmarshalled keys, keyed by key ID.
	cache map[string]*Key
	// fsCache additionally persists unmarshalled keys as .gob files.
	// Insecure, only for tests.
	fsCache bool
}

var _ KeyStore = (*BasicKeyStore)(nil)
// New creates a new BasicKeyStore.
//
// fsCache: keeps unmarshalled keys in the file system. Insecure, only for tests.
func New(fs afero.Fs, keysDir string, fsCache bool) (*BasicKeyStore, error) {
	if keysDir == "" {
		return nil, ErrEmptyKeysDir
	}
	// Ensure the target directory exists and is private to the owner.
	if err := fs.MkdirAll(keysDir, 0o700); err != nil {
		return nil, fmt.Errorf("%w: %w", ErrCreateKeysDir, err)
	}
	ks := &BasicKeyStore{
		fs:      fs,
		keysDir: keysDir,
		cache:   make(map[string]*Key),
		fsCache: fsCache,
	}
	return ks, nil
}
// Save encrypts a key and writes it to a file, returning the path of the
// written file. Any cached copy for the same ID (in memory and, when
// fsCache is enabled, on disk) is invalidated so a subsequent Get re-reads
// the freshly written file.
func (ks *BasicKeyStore) Save(id string, data []byte, passphrase string) (string, error) {
	if passphrase == "" {
		return "", ErrEmptyPassphrase
	}

	// BUGFIX: ks.cache is shared state; Delete guards it with ks.mu but
	// Save previously mutated it unlocked (data race). Hold the lock for
	// the write + cache invalidation, mirroring Delete.
	ks.mu.Lock()
	defer ks.mu.Unlock()

	key := &Key{
		ID:   id,
		Data: data,
	}
	keyDataJSON, err := key.MarshalToJSON(passphrase)
	if err != nil {
		return "", fmt.Errorf("failed to marshal key: %w", err)
	}
	filename, err := utils.WriteToFile(ks.fs, keyDataJSON, filepath.Join(ks.keysDir, key.ID+".json"))
	if err != nil {
		return "", fmt.Errorf("failed to write key to file: %w", err)
	}

	// Invalidate stale cache entries for this ID; the .gob removal is
	// best-effort (the file may not exist).
	delete(ks.cache, key.ID)
	if ks.fsCache {
		_ = ks.fs.Remove(filepath.Join(ks.keysDir, id+".gob"))
	}
	return filename, nil
}
// Get unlocks a key by keyID.
//
// Cached entries (in memory or, with fsCache enabled, a .gob file on disk)
// are returned directly; otherwise the encrypted .json file is read and
// decrypted with the given passphrase, then cached.
//
// NOTE(review): cache hits bypass passphrase verification — confirm this
// is intended behavior.
func (ks *BasicKeyStore) Get(keyID string, passphrase string) (*Key, error) {
	// BUGFIX: ks.cache was read and written here without holding ks.mu,
	// racing with Delete (which locks). Guard the whole lookup/populate
	// path with the store mutex.
	ks.mu.Lock()
	defer ks.mu.Unlock()

	// read cache?
	fsCachePath := filepath.Join(ks.keysDir, keyID+".gob")
	if key, ok := ks.cache[keyID]; ok {
		return key, nil
	}
	if ks.fsCache {
		if b, err := afero.ReadFile(ks.fs, fsCachePath); err == nil {
			key := &Key{}
			if err := gob.NewDecoder(bytes.NewReader(b)).Decode(key); err == nil {
				ks.cache[keyID] = key
				return key, nil
			}
		}
	}

	// read & unmarshall
	bts, err := afero.ReadFile(ks.fs, filepath.Join(ks.keysDir, keyID+".json"))
	if err != nil {
		if os.IsNotExist(err) {
			return nil, ErrKeyNotFound
		}
		return nil, fmt.Errorf("failed to read keystore file: %w", err)
	}
	key, err := UnmarshalKey(bts, passphrase)
	if err != nil {
		return nil, fmt.Errorf("failed to unmarshal keystore file: %w", err)
	}

	// save cache; cache-write failures are non-fatal — the decrypted key
	// is still returned.
	ks.cache[keyID] = key
	if ks.fsCache {
		var buf bytes.Buffer
		if err := gob.NewEncoder(&buf).Encode(key); err != nil {
			return key, nil
		}
		if err := afero.WriteFile(ks.fs, fsCachePath, buf.Bytes(), 0o600); err != nil {
			return key, nil
		}
	}
	// BUGFIX: was `return key, err` — err is provably nil here, but
	// returning it read as if a decode error could leak through.
	return key, nil
}
// Exists returns whether a key is stored
func (ks *BasicKeyStore) Exists(key string) bool {
	ids, err := ks.ListKeys()
	// A listing failure is treated as "not present".
	return err == nil && slices.Contains(ids, key)
}
// Delete removes the file referencing the given key.
//
// The passphrase must successfully decrypt the stored key before the file
// (and any cached copies) are removed.
func (ks *BasicKeyStore) Delete(keyID string, passphrase string) error {
	ks.mu.Lock()
	defer ks.mu.Unlock()

	filePath := filepath.Join(ks.keysDir, keyID+".json")
	bts, err := afero.ReadFile(ks.fs, filePath)
	switch {
	case os.IsNotExist(err):
		return ErrKeyNotFound
	case err != nil:
		return fmt.Errorf("failed to read keystore file: %w", err)
	}

	// Require the correct passphrase before destroying anything.
	if _, err := UnmarshalKey(bts, passphrase); err != nil {
		return fmt.Errorf("invalid passphrase or corrupted key file: %w", err)
	}

	if err := ks.fs.Remove(filePath); err != nil {
		return fmt.Errorf("failed to delete key file: %w", err)
	}

	// Drop cached copies as well; the .gob removal is best-effort.
	delete(ks.cache, keyID)
	if ks.fsCache {
		_ = ks.fs.Remove(filepath.Join(ks.keysDir, keyID+".gob"))
	}
	return nil
}
// ListKeys lists the keys in the keysDir.
//
// Only encrypted key files (*.json) are reported. BUGFIX: previously every
// directory entry was listed, so with fsCache enabled each key's .gob
// cache file produced a duplicate ID in the result (and subdirectories
// were reported as keys too).
func (ks *BasicKeyStore) ListKeys() ([]string, error) {
	dirEntries, err := afero.ReadDir(ks.fs, ks.keysDir)
	if err != nil {
		return nil, fmt.Errorf("failed to read keystore directory: %w", err)
	}
	keys := make([]string, 0, len(dirEntries))
	for _, entry := range dirEntries {
		if entry.IsDir() || filepath.Ext(entry.Name()) != ".json" {
			continue
		}
		// Skip entries that cannot be read.
		if _, err := afero.ReadFile(ks.fs, filepath.Join(ks.keysDir, entry.Name())); err != nil {
			continue
		}
		keys = append(keys, removeFileExtension(entry.Name()))
	}
	return keys, nil
}
// Dir returns the directory the keystore writes its key files to.
func (ks *BasicKeyStore) Dir() string {
	return ks.keysDir
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package did
// GetAnchorFunc resolves a DID of one particular method into an Anchor.
type GetAnchorFunc func(did DID) (Anchor, error)

// anchorMethods maps a DID method name (e.g. "key") to its resolver.
var anchorMethods map[string]GetAnchorFunc

func init() {
	// Built-in resolvers; more can be added via RegisterAnchorMethod.
	anchorMethods = map[string]GetAnchorFunc{
		"key":   makeKeyAnchor,
		"prism": makePrismAnchor,
	}
}
// GetAnchorForDID resolves a DID to an Anchor using the appropriate method resolver
func GetAnchorForDID(did DID) (Anchor, error) {
	resolve, registered := anchorMethods[did.Method()]
	if !registered {
		return nil, ErrNoAnchorMethod
	}
	return resolve(did)
}
// RegisterAnchorMethod registers a new DID method resolver
// This allows extending the system with additional DID methods at runtime
//
// NOTE(review): anchorMethods is written without synchronization; confirm
// registration only happens at start-up, before concurrent lookups.
func RegisterAnchorMethod(method string, resolver GetAnchorFunc) {
	anchorMethods[method] = resolver
}
// makeKeyAnchor builds an Anchor for a did:key DID by extracting the
// public key embedded in its identifier.
func makeKeyAnchor(did DID) (Anchor, error) {
	pubk, err := PublicKeyFromDID(did)
	if err != nil {
		return nil, err
	}
	anchor := NewAnchor(did, pubk)
	return anchor, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package did
import (
"context"
"fmt"
"sync"
"time"
"gitlab.com/nunet/device-management-service/lib/crypto"
)
// anchorEntryTTL is how long a cached anchor stays valid without being
// looked up again.
const anchorEntryTTL = time.Hour

// Signer identifies the mechanism holding the signing key material.
type Signer string

// EternlSigner signs through the Eternl wallet CLI.
const EternlSigner Signer = "ETERNL"

// LedgerSigner signs through a Ledger hardware wallet CLI.
const LedgerSigner Signer = "LEDGER"

// PrivateKeySigner signs with an in-memory private key.
const PrivateKeySigner Signer = "PRIVATE_KEY"
// Anchor is a DID anchor that encapsulates a public key that can be used
// for verification of signatures.
type Anchor interface {
	// DID returns the DID this anchor verifies for.
	DID() DID
	// Verify checks sig over data with the anchored public key.
	Verify(data []byte, sig []byte) error
	// PublicKey returns the anchored public key.
	PublicKey() crypto.PubKey
}

// Provider holds the private key material necessary to sign statements for
// a DID.
type Provider interface {
	// Signer identifies the signing backend (wallet, ledger, raw key).
	Signer() Signer
	// DID returns the DID this provider signs for.
	DID() DID
	// Sign signs data with the provider's key material.
	Sign(data []byte) ([]byte, error)
	// Anchor returns the verification anchor matching this provider.
	Anchor() Anchor
	// PrivateKey returns the raw private key; hardware/wallet-backed
	// providers return an error instead.
	PrivateKey() (crypto.PrivKey, error)
}
// TrustContext holds anchors and providers for DIDs the node deals with.
type TrustContext interface {
	// Anchors lists the DIDs with a cached anchor.
	Anchors() []DID
	// Providers lists the DIDs with a registered provider.
	Providers() []DID
	// GetAnchor returns the anchor for a DID, resolving and caching it
	// on a miss.
	GetAnchor(did DID) (Anchor, error)
	// GetProvider returns the provider registered for a DID.
	GetProvider(did DID) (Provider, error)
	// AddAnchor caches an anchor.
	AddAnchor(anchor Anchor)
	// AddProvider registers a provider.
	AddProvider(provider Provider)
	// Start launches periodic garbage collection of expired anchors.
	Start(gcInterval time.Duration)
	// Stop halts garbage collection.
	Stop()
}

// anchorEntry is a cached anchor together with its expiration deadline.
type anchorEntry struct {
	anchor Anchor
	expire time.Time // refreshed on every cache hit
}
// BasicTrustContext is the default in-memory TrustContext implementation.
type BasicTrustContext struct {
	mx        sync.Mutex // guards all fields below
	anchors   map[DID]*anchorEntry
	providers map[DID]Provider
	stop      func() // cancels the running gc goroutine; nil when stopped
}

var _ TrustContext = (*BasicTrustContext)(nil)
// NewTrustContext returns an empty BasicTrustContext.
func NewTrustContext() TrustContext {
	ctx := &BasicTrustContext{
		anchors:   make(map[DID]*anchorEntry),
		providers: make(map[DID]Provider),
	}
	return ctx
}
// NewTrustContextWithPrivateKey returns a trust context pre-populated with
// a provider derived from privk.
func NewTrustContextWithPrivateKey(privk crypto.PrivKey) (TrustContext, error) {
	provider, err := ProviderFromPrivateKey(privk)
	if err != nil {
		return nil, fmt.Errorf("provide from private key: %w", err)
	}
	return NewTrustContextWithProvider(provider), nil
}

// NewTrustContextWithProvider returns a trust context pre-populated with p.
func NewTrustContextWithProvider(p Provider) TrustContext {
	ctx := NewTrustContext()
	ctx.AddProvider(p)
	return ctx
}
// Anchors returns the DIDs of all currently cached anchors.
func (ctx *BasicTrustContext) Anchors() []DID {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()
	out := make([]DID, 0, len(ctx.anchors))
	for did := range ctx.anchors {
		out = append(out, did)
	}
	return out
}

// Providers returns the DIDs of all registered providers.
func (ctx *BasicTrustContext) Providers() []DID {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()
	out := make([]DID, 0, len(ctx.providers))
	for did := range ctx.providers {
		out = append(out, did)
	}
	return out
}
// GetAnchor returns the anchor for did, consulting the cache first and
// resolving (then caching) on a miss.
func (ctx *BasicTrustContext) GetAnchor(did DID) (Anchor, error) {
	if cached, ok := ctx.getAnchor(did); ok {
		return cached, nil
	}
	resolved, err := GetAnchorForDID(did)
	if err != nil {
		return nil, fmt.Errorf("get anchor for did: %w", err)
	}
	ctx.AddAnchor(resolved)
	return resolved, nil
}

// getAnchor looks up a cached anchor, sliding its expiration forward on a
// hit.
func (ctx *BasicTrustContext) getAnchor(did DID) (Anchor, bool) {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()
	entry, ok := ctx.anchors[did]
	if !ok {
		return nil, false
	}
	entry.expire = time.Now().Add(anchorEntryTTL)
	return entry.anchor, true
}
// GetProvider returns the registered provider for did, or ErrNoProvider.
func (ctx *BasicTrustContext) GetProvider(did DID) (Provider, error) {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()
	if provider, ok := ctx.providers[did]; ok {
		return provider, nil
	}
	return nil, ErrNoProvider
}

// AddAnchor caches anchor under its own DID with a fresh TTL.
func (ctx *BasicTrustContext) AddAnchor(anchor Anchor) {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()
	entry := &anchorEntry{
		anchor: anchor,
		expire: time.Now().Add(anchorEntryTTL),
	}
	ctx.anchors[anchor.DID()] = entry
}

// AddProvider registers provider under its own DID.
func (ctx *BasicTrustContext) AddProvider(provider Provider) {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()
	ctx.providers[provider.DID()] = provider
}
// Start launches the anchor garbage collector, cancelling any collector
// that is already running.
func (ctx *BasicTrustContext) Start(gcInterval time.Duration) {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()
	if ctx.stop != nil {
		ctx.stop()
	}
	gcCtx, cancel := context.WithCancel(context.Background())
	ctx.stop = cancel
	go ctx.gc(gcCtx, gcInterval)
}

// Stop cancels the running garbage collector, if any.
func (ctx *BasicTrustContext) Stop() {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()
	if ctx.stop == nil {
		return
	}
	ctx.stop()
	ctx.stop = nil
}
// gc periodically evicts expired anchors until gcCtx is cancelled.
func (ctx *BasicTrustContext) gc(gcCtx context.Context, gcInterval time.Duration) {
	ticker := time.NewTicker(gcInterval)
	defer ticker.Stop()
	for {
		select {
		case <-gcCtx.Done():
			return
		case <-ticker.C:
			ctx.gcAnchorEntries()
		}
	}
}

// gcAnchorEntries drops every cached anchor whose TTL has elapsed.
func (ctx *BasicTrustContext) gcAnchorEntries() {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()
	now := time.Now()
	for did, entry := range ctx.anchors {
		if entry.expire.Before(now) {
			delete(ctx.anchors, did)
		}
	}
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package did
import (
"fmt"
"strings"
)
// DID is a W3C decentralized identifier, stored as its URI string
// ("did:<method>:<method-specific-id>").
type DID struct {
	URI string `json:"uri,omitempty"`
}

// Equal reports whether two DIDs have the same URI.
func (did DID) Equal(other DID) bool {
	return did.URI == other.URI
}

// Empty reports whether the DID is unset.
func (did DID) Empty() bool {
	return did.URI == ""
}

// String returns the DID URI.
func (did DID) String() string {
	return did.URI
}

// Method returns the DID method (e.g. "key" for "did:key:..."), or "" if
// the URI does not have at least three colon-separated segments.
//
// BUGFIX/generalization: splitting on at most the first two colons means
// method-specific IDs that themselves contain colons (e.g. long-form
// did:prism identifiers) no longer make Method return "".
func (did DID) Method() string {
	parts := strings.SplitN(did.URI, ":", 3)
	if len(parts) == 3 {
		return parts[1]
	}
	return ""
}

// Identifier returns the method-specific identifier of the DID, or "" for
// malformed URIs. Colons inside the identifier are preserved.
func (did DID) Identifier() string {
	parts := strings.SplitN(did.URI, ":", 3)
	if len(parts) == 3 {
		return parts[2]
	}
	return ""
}
// FromString parses s into a DID. A non-empty string must have exactly
// the form "did:<method>:<identifier>" with no empty segment; the empty
// string yields an empty DID without error.
func FromString(s string) (DID, error) {
	if s == "" {
		return DID{}, nil
	}
	parts := strings.Split(s, ":")
	if len(parts) != 3 {
		return DID{}, fmt.Errorf("%w: %s", ErrInvalidDID, s)
	}
	for _, part := range parts {
		if part == "" {
			return DID{}, fmt.Errorf("%w: %s", ErrInvalidDID, s)
		}
	}
	// TODO validate parts according to spec: https://www.w3.org/TR/did-core/
	return DID{URI: s}, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package did
import (
"bytes"
"encoding/hex"
"fmt"
"os/exec"
"regexp"
"gitlab.com/nunet/device-management-service/lib/crypto"
)
// eternlCLI is the name of the Eternl wallet helper binary looked up in
// PATH.
const eternlCLI = "eternl-cli"

// EternlWalletProvider signs via the external Eternl wallet CLI; the
// private key never leaves the wallet.
type EternlWalletProvider struct {
	did  DID
	pubk crypto.PubKey
}

var _ Provider = (*EternlWalletProvider)(nil)
// signDataWithBinary invokes the wallet binary at binaryPath on the given
// hex-encoded data and scrapes the hex public key and signature from its
// combined output.
func signDataWithBinary(binaryPath string, data string) (string, string, error) {
	// add simple data to get a signature/public key
	if data == "" {
		data = "6765745F7075626B6579"
	}

	var combined bytes.Buffer
	cmd := exec.Command(binaryPath, data)
	cmd.Stdout = &combined
	cmd.Stderr = &combined
	if err := cmd.Run(); err != nil {
		return "", "", fmt.Errorf("failed to run binary: %w\nOutput: %s", err, combined.String())
	}
	output := combined.String()

	pubMatch := regexp.MustCompile(`PubKeyRaw:\[([a-fA-F0-9]+)\]`).FindStringSubmatch(output)
	if len(pubMatch) < 2 {
		return "", "", fmt.Errorf("pub value not found in output")
	}
	sigMatch := regexp.MustCompile(`Signature:\[([a-fA-F0-9]+)\]`).FindStringSubmatch(output)
	if len(sigMatch) < 2 {
		return "", "", fmt.Errorf("sig value not found in output")
	}
	return pubMatch[1], sigMatch[1], nil
}
// NewEternlWalletProvider builds a Provider backed by the eternl-cli
// binary, deriving the DID from the wallet's public key.
func NewEternlWalletProvider() (Provider, error) {
	eCli, err := exec.LookPath(eternlCLI)
	if err != nil {
		return nil, fmt.Errorf("can't find %s in PATH: %w", eternlCLI, err)
	}
	// An empty payload makes the helper report the public key.
	pubHex, _, err := signDataWithBinary(eCli, "")
	if err != nil {
		return nil, err
	}
	raw, err := hex.DecodeString(pubHex)
	if err != nil {
		return nil, err
	}
	pubKey, err := crypto.UnmarshalCardanoPublicKey(raw)
	if err != nil {
		return nil, err
	}
	provider := &EternlWalletProvider{
		did:  FromPublicKey(pubKey),
		pubk: pubKey,
	}
	return provider, nil
}
// Signer identifies this provider as Eternl-wallet backed.
func (p *EternlWalletProvider) Signer() Signer {
	return EternlSigner
}

// DID returns the DID derived from the wallet's public key.
func (p *EternlWalletProvider) DID() DID {
	return p.did
}

// Sign hex-encodes data, hands it to the eternl-cli binary and returns
// the decoded signature bytes.
func (p *EternlWalletProvider) Sign(data []byte) ([]byte, error) {
	eCli, err := exec.LookPath(eternlCLI)
	if err != nil {
		return nil, fmt.Errorf("can't find %s in PATH: %w", eternlCLI, err)
	}
	_, sigHex, err := signDataWithBinary(eCli, hex.EncodeToString(data))
	if err != nil {
		return nil, err
	}
	return hex.DecodeString(sigHex)
}

// Anchor returns a verification anchor for the wallet's public key.
func (p *EternlWalletProvider) Anchor() Anchor {
	return NewAnchor(p.did, p.pubk)
}

// PrivateKey always fails: the key lives inside the wallet.
func (p *EternlWalletProvider) PrivateKey() (crypto.PrivKey, error) {
	return nil, fmt.Errorf("eternl private key cannot be exported: %w", ErrHardwareKey)
}
// Original Copyright 2024, Ucan Working Group; Modified Copyright 2024, NuNet;
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
//
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package did
import (
"fmt"
"strings"
libp2p_crypto "github.com/libp2p/go-libp2p/core/crypto"
mb "github.com/multiformats/go-multibase"
varint "github.com/multiformats/go-varint"
"gitlab.com/nunet/device-management-service/lib/crypto"
)
// PublicKeyAnchor is an Anchor backed by an in-memory public key.
type PublicKeyAnchor struct {
	did  DID
	pubk crypto.PubKey
}

var _ Anchor = (*PublicKeyAnchor)(nil)

// PrivateKeyProvider is a Provider backed by an in-memory private key.
type PrivateKeyProvider struct {
	did   DID
	privk crypto.PrivKey
}

var _ Provider = (*PrivateKeyProvider)(nil)

// NewAnchor wraps a DID and its public key into an Anchor.
func NewAnchor(did DID, pubk crypto.PubKey) Anchor {
	return &PublicKeyAnchor{
		did:  did,
		pubk: pubk,
	}
}

// NewProvider wraps a DID and its private key into a Provider.
func NewProvider(did DID, privk crypto.PrivKey) Provider {
	return &PrivateKeyProvider{
		did:   did,
		privk: privk,
	}
}
// DID returns the anchored DID.
func (a *PublicKeyAnchor) DID() DID {
	return a.did
}

// Verify checks sig over data with the anchored public key, returning
// ErrInvalidSignature when the signature does not match.
func (a *PublicKeyAnchor) Verify(data []byte, sig []byte) error {
	valid, err := a.pubk.Verify(data, sig)
	switch {
	case err != nil:
		return err
	case !valid:
		return ErrInvalidSignature
	default:
		return nil
	}
}

// PublicKey returns the anchored public key.
func (a *PublicKeyAnchor) PublicKey() crypto.PubKey {
	return a.pubk
}

// DID returns the provider's DID.
func (p *PrivateKeyProvider) DID() DID {
	return p.did
}

// Sign signs data with the wrapped private key.
func (p *PrivateKeyProvider) Sign(data []byte) ([]byte, error) {
	return p.privk.Sign(data)
}

// PrivateKey exposes the wrapped private key.
func (p *PrivateKeyProvider) PrivateKey() (crypto.PrivKey, error) {
	return p.privk, nil
}

// Signer identifies this provider as raw-private-key backed.
func (p *PrivateKeyProvider) Signer() Signer {
	return PrivateKeySigner
}

// Anchor derives the matching verification anchor from the private key's
// public half.
func (p *PrivateKeyProvider) Anchor() Anchor {
	return NewAnchor(p.did, p.privk.GetPublic())
}
// FromID derives a did:key DID from a key ID.
func FromID(id crypto.ID) (DID, error) {
	pubk, err := crypto.PublicKeyFromID(id)
	if err != nil {
		return DID{}, fmt.Errorf("public key from id: %w", err)
	}
	return FromPublicKey(pubk), nil
}

// FromPublicKey encodes pubk as a did:key DID.
func FromPublicKey(pubk crypto.PubKey) DID {
	return DID{URI: FormatKeyURI(pubk)}
}

// PublicKeyFromDID extracts the public key embedded in a did:key DID.
func PublicKeyFromDID(did DID) (crypto.PubKey, error) {
	if did.Method() != "key" {
		return nil, ErrInvalidDID
	}
	pubk, err := ParseKeyURI(did.URI)
	if err != nil {
		return nil, fmt.Errorf("parsing did key identifier: %w", err)
	}
	return pubk, nil
}

// AnchorFromPublicKey builds an anchor whose DID is derived from pubk.
func AnchorFromPublicKey(pubk crypto.PubKey) (Anchor, error) {
	return NewAnchor(FromPublicKey(pubk), pubk), nil
}

// ProviderFromPrivateKey builds a provider whose DID is derived from the
// public half of privk.
func ProviderFromPrivateKey(privk crypto.PrivKey) (Provider, error) {
	return NewProvider(FromPublicKey(privk.GetPublic()), privk), nil
}
// Note: this code originated in https://github.com/ucan-wg/go-ucan/blob/main/didkey/key.go
// Copyright applies; some superficial modifications by vyzo.
const (
	// Multicodec prefixes identifying the key type inside a did:key
	// identifier.
	multicodecKindEd25519PubKey   uint64 = 0xed
	multicodecKindSecp256k1PubKey uint64 = 0xe7
	multicodecKindEthPubKey       uint64 = 0xef01
	multicodecKindCardanoPubKey   uint64 = 0xef02

	// keyPrefix is the URI scheme prefix of a did:key identifier.
	keyPrefix = "did:key"
)
// FormatKeyURI encodes a public key as a did:key URI: a multicodec key
// type prefix followed by the raw key bytes, multibase(base58btc)
// encoded. It returns "" for unsupported key types or marshalling
// failures.
func FormatKeyURI(pubk crypto.PubKey) string {
	raw, err := pubk.Raw()
	if err != nil {
		return ""
	}

	var t uint64
	switch pubk.Type() {
	case crypto.Ed25519:
		t = multicodecKindEd25519PubKey
	case crypto.Secp256k1:
		t = multicodecKindSecp256k1PubKey
	case crypto.Eth:
		t = multicodecKindEthPubKey
	case crypto.Cardano:
		t = multicodecKindCardanoPubKey
	default:
		// we don't support those yet
		// BUGFIX: log the actual key type; `t` was still zero here, so
		// the message always reported "unsupported key type: 0".
		log.Errorf("unsupported key type: %d", pubk.Type())
		return ""
	}

	// Layout: varint(multicodec type) || raw key bytes.
	size := varint.UvarintSize(t)
	data := make([]byte, size+len(raw))
	n := varint.PutUvarint(data, t)
	copy(data[n:], raw)

	b58BKeyStr, err := mb.Encode(mb.Base58BTC, data)
	if err != nil {
		return ""
	}
	return fmt.Sprintf("%s:%s", keyPrefix, b58BKeyStr)
}
// ParseKeyURI decodes a did:key URI back into a public key, inverting
// FormatKeyURI.
func ParseKeyURI(uri string) (crypto.PubKey, error) {
	// BUGFIX: require the full "did:key:" prefix. Checking only "did:key"
	// accepted URIs of unrelated methods that merely start with that text
	// (e.g. "did:keyfoo:..."), which then fell through to multibase
	// decoding of the whole string.
	if !strings.HasPrefix(uri, keyPrefix+":") {
		return nil, fmt.Errorf("decentralized identifier is not a 'key' type")
	}
	uri = strings.TrimPrefix(uri, keyPrefix+":")

	enc, data, err := mb.Decode(uri)
	if err != nil {
		return nil, fmt.Errorf("decoding multibase: %w", err)
	}
	if enc != mb.Base58BTC {
		return nil, fmt.Errorf("unexpected multibase encoding: %s", mb.EncodingToStr[enc])
	}

	// The payload is varint(multicodec key type) || raw key bytes.
	keyType, n, err := varint.FromUvarint(data)
	if err != nil {
		return nil, err
	}
	switch keyType {
	case multicodecKindEd25519PubKey:
		return libp2p_crypto.UnmarshalEd25519PublicKey(data[n:])
	case multicodecKindSecp256k1PubKey:
		return libp2p_crypto.UnmarshalSecp256k1PublicKey(data[n:])
	case multicodecKindCardanoPubKey:
		return crypto.UnmarshalCardanoPublicKey(data[n:])
	case multicodecKindEthPubKey:
		return crypto.UnmarshalEthPublicKey(data[n:])
	default:
		return nil, ErrInvalidKeyType
	}
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package did
import (
"encoding/hex"
"encoding/json"
"fmt"
"os"
"os/exec"
"github.com/decred/dcrd/dcrec/secp256k1/v4"
"github.com/decred/dcrd/dcrec/secp256k1/v4/ecdsa"
"gitlab.com/nunet/device-management-service/lib/crypto"
)
// ledgerCLI is the name of the Ledger helper binary looked up in PATH.
const ledgerCLI = "ledger-cli"

// LedgerWalletProvider signs via an external Ledger wallet CLI; the
// private key never leaves the device.
type LedgerWalletProvider struct {
	did  DID
	pubk crypto.PubKey
	acct int // ledger account index passed to the CLI via "-a"
}

var _ Provider = (*LedgerWalletProvider)(nil)

// LedgerKeyOutput is the JSON written by the CLI's "key" subcommand.
type LedgerKeyOutput struct {
	Key     string `json:"key"` // hex-encoded public key
	Address string `json:"address"`
}

// LedgerSignOutput is the JSON written by the CLI's "sign" subcommand.
type LedgerSignOutput struct {
	ECDSA LedgerSignECDSAOutput `json:"ecdsa"`
}

// LedgerSignECDSAOutput carries the raw ECDSA signature components.
type LedgerSignECDSAOutput struct {
	V uint   `json:"v"`
	R string `json:"r"` // hex-encoded
	S string `json:"s"` // hex-encoded
}
// NewLedgerWalletProvider builds a Provider for the given ledger account,
// querying the device (via ledger-cli) for its public key.
func NewLedgerWalletProvider(acct int) (Provider, error) {
	tmp, err := getLedgerTmpFile()
	if err != nil {
		return nil, err
	}
	defer os.Remove(tmp)

	var keyOut LedgerKeyOutput
	args := []string{"key", "-o", tmp, "-a", fmt.Sprintf("%d", acct)}
	if err := ledgerExec(tmp, &keyOut, args...); err != nil {
		return nil, fmt.Errorf("error executing ledger cli: %w", err)
	}

	// decode the hex key
	raw, err := hex.DecodeString(keyOut.Key)
	if err != nil {
		return nil, fmt.Errorf("decode ledger key: %w", err)
	}
	pubk, err := crypto.UnmarshalEthPublicKey(raw)
	if err != nil {
		return nil, fmt.Errorf("unmarshal ledger raw key: %w", err)
	}

	provider := &LedgerWalletProvider{
		did:  FromPublicKey(pubk),
		pubk: pubk,
		acct: acct,
	}
	return provider, nil
}
// ledgerExec runs the ledger CLI with args and parses the JSON it wrote to
// the temporary file tmp into output.
//
// output must be a pointer; tmp must also appear in args (via "-o") so the
// CLI writes its result there.
func ledgerExec(tmp string, output interface{}, args ...string) error {
	ledger, err := exec.LookPath(ledgerCLI)
	if err != nil {
		return fmt.Errorf("can't find %s in PATH: %w", ledgerCLI, err)
	}

	cmd := exec.Command(ledger, args...)
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr
	if err := cmd.Run(); err != nil {
		// BUGFIX: this helper runs arbitrary subcommands ("key", "sign"),
		// so don't mislabel every failure as "get ledger key".
		return fmt.Errorf("run ledger cli: %w", err)
	}

	f, err := os.Open(tmp)
	if err != nil {
		return fmt.Errorf("open ledger output: %w", err)
	}
	defer f.Close()

	// Decode into the caller's pointer directly; taking &output decoded
	// through a pointer-to-interface for no benefit.
	if err := json.NewDecoder(f).Decode(output); err != nil {
		return fmt.Errorf("parse ledger output: %w", err)
	}
	return nil
}
func getLedgerTmpFile() (string, error) {
tmp, err := os.CreateTemp("", "ledger.out")
if err != nil {
return "", fmt.Errorf("creating temporary file: %w", err)
}
tmpPath := tmp.Name()
tmp.Close()
return tmpPath, nil
}
// DID returns the DID derived from the ledger's public key.
func (p *LedgerWalletProvider) DID() DID {
	return p.did
}

// Signer identifies this provider as Ledger backed.
func (p *LedgerWalletProvider) Signer() Signer {
	return LedgerSigner
}
// Sign signs data with the ledger device: the payload is hex-encoded,
// handed to `ledger-cli sign`, and the returned (r, s) components are
// reassembled into a serialized ECDSA signature.
//
// Returns an error if the CLI fails, a component is not valid hex, or
// either scalar overflows the secp256k1 group order.
func (p *LedgerWalletProvider) Sign(data []byte) ([]byte, error) {
	tmp, err := getLedgerTmpFile()
	if err != nil {
		return nil, err
	}
	// Best-effort cleanup of the CLI's output file.
	defer os.Remove(tmp)
	dataHex := hex.EncodeToString(data)
	var output LedgerSignOutput
	if err := ledgerExec(
		tmp,
		&output,
		"sign",
		"-o", tmp,
		"-a", fmt.Sprintf("%d", p.acct),
		dataHex,
	); err != nil {
		return nil, fmt.Errorf("error executing ledger cli: %w", err)
	}
	rBytes, err := hex.DecodeString(output.ECDSA.R)
	if err != nil {
		return nil, fmt.Errorf("error decoding signature r: %w", err)
	}
	sBytes, err := hex.DecodeString(output.ECDSA.S)
	if err != nil {
		return nil, fmt.Errorf("error decoding signature s: %w", err)
	}
	// Rebuild the scalars; SetByteSlice reports overflow of the group
	// order, which would indicate a corrupted response.
	r := secp256k1.ModNScalar{}
	s := secp256k1.ModNScalar{}
	if overflow := r.SetByteSlice(rBytes); overflow {
		return nil, fmt.Errorf("signature r overflowed")
	}
	if overflow := s.SetByteSlice(sBytes); overflow {
		return nil, fmt.Errorf("signature s overflowed")
	}
	sig := ecdsa.NewSignature(&r, &s)
	return sig.Serialize(), nil
}
// Anchor returns a verification anchor for the ledger's public key.
func (p *LedgerWalletProvider) Anchor() Anchor {
	return NewAnchor(p.did, p.pubk)
}

// PrivateKey always fails: the key never leaves the ledger device.
func (p *LedgerWalletProvider) PrivateKey() (crypto.PrivKey, error) {
	return nil, fmt.Errorf("ledger private key cannot be exported: %w", ErrHardwareKey)
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package did
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"net/url"
"strings"
"time"
libp2p_crypto "github.com/libp2p/go-libp2p/core/crypto"
"golang.org/x/crypto/ed25519"
"gitlab.com/nunet/device-management-service/lib/crypto"
)
const (
	// JWK "crv" values this resolver understands.
	curveEd25519   = "Ed25519"
	curveSecp256k1 = "secp256k1"
)
// PRISMResolverConfig holds configuration for PRISM DID resolution
type PRISMResolverConfig struct {
	// ResolverURL is the base URL of the PRISM DID resolver
	// Example: "https://prism-agent.example.com"
	ResolverURL string
	// HTTPClient is the HTTP client to use for resolution
	// If nil, a default client with 30s timeout will be used
	HTTPClient *http.Client
	// PreferredVerificationMethod specifies which verification method to prefer
	// when multiple are available. Options: "authentication", "assertionMethod", "capabilityInvocation"
	// If empty, defaults to "authentication"
	PreferredVerificationMethod string
}

// defaultPRISMResolverConfig supplies the fallback values used when the
// corresponding PRISMResolverConfig fields are zero-valued.
var defaultPRISMResolverConfig = PRISMResolverConfig{
	ResolverURL:                 "https://prism-agent.example.com",
	PreferredVerificationMethod: "authentication",
}
// DIDDocument represents a W3C DID Document
type DIDDocument struct { //nolint:revive
	Context              interface{}             `json:"@context"`
	ID                   string                  `json:"id"`
	VerificationMethod   []VerificationMethod    `json:"verificationMethod,omitempty"`
	Authentication       []VerificationMethodRef `json:"authentication,omitempty"`
	AssertionMethod      []VerificationMethodRef `json:"assertionMethod,omitempty"`
	KeyAgreement         []VerificationMethodRef `json:"keyAgreement,omitempty"`
	CapabilityInvocation []VerificationMethodRef `json:"capabilityInvocation,omitempty"`
	CapabilityDelegation []VerificationMethodRef `json:"capabilityDelegation,omitempty"`
	Service              []Service               `json:"service,omitempty"`
}

// VerificationMethod represents a public key in a DID Document
type VerificationMethod struct {
	ID         string `json:"id"`
	Type       string `json:"type"` // e.g. "JsonWebKey2020"
	Controller string `json:"controller"`
	// PublicKeyJWK is kept raw here and parsed on demand into a JWK.
	PublicKeyJWK json.RawMessage `json:"publicKeyJwk,omitempty"`
}
// VerificationMethodRef can be either a string (ID reference) or a VerificationMethod object
type VerificationMethodRef struct {
ID string `json:"id,omitempty"`
Type string `json:"type,omitempty"`
Controller string `json:"controller,omitempty"`
PublicKeyJWK json.RawMessage `json:"publicKeyJwk,omitempty"`
}
// UnmarshalJSON implements custom JSON unmarshaling for VerificationMethodRef
// It handles both string references (just the ID) and full objects
func (v *VerificationMethodRef) UnmarshalJSON(data []byte) error {
// Try to unmarshal as string first
var str string
if err := json.Unmarshal(data, &str); err == nil {
// It's a string reference, just set the ID
v.ID = str
return nil
}
// If not a string, unmarshal as an object
type Alias VerificationMethodRef
var alias Alias
if err := json.Unmarshal(data, &alias); err != nil {
return err
}
*v = VerificationMethodRef(alias)
return nil
}
// Service represents a service endpoint in a DID Document
type Service struct {
	ID              string `json:"id"`
	Type            string `json:"type"`
	ServiceEndpoint string `json:"serviceEndpoint"`
}

// JWK represents a JSON Web Key
type JWK struct {
	Kty string `json:"kty"` // Key type: "EC" or "OKP"
	Crv string `json:"crv"` // Curve: "Ed25519", "secp256k1", "X25519"
	X   string `json:"x"`   // X coordinate (base64url)
	Y   string `json:"y,omitempty"` // Y coordinate (base64url, only for EC keys)
}
// resolvePRISMDID resolves a PRISM DID to its DID document
//
// The document is fetched from the configured resolver's W3C-compliant
// endpoint (/api/dids/{did}); the returned document's id must match the
// requested DID.
func resolvePRISMDID(ctx context.Context, did DID, config PRISMResolverConfig) (*DIDDocument, error) {
	base := config.ResolverURL
	if base == "" {
		base = defaultPRISMResolverConfig.ResolverURL
	}
	// DIDs contain ':' characters, which must be escaped inside a URL
	// path segment.
	resolutionURL := fmt.Sprintf("%s/api/dids/%s", strings.TrimSuffix(base, "/"), url.PathEscape(did.URI))

	client := config.HTTPClient
	if client == nil {
		// Default client with a hard timeout so resolution cannot hang.
		client = &http.Client{Timeout: 30 * time.Second}
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, resolutionURL, nil)
	if err != nil {
		return nil, fmt.Errorf("create request: %w", err)
	}
	req.Header.Set("Accept", "application/did+json,application/json")

	resp, err := client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("resolve DID: %w", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("DID resolution failed with status %d", resp.StatusCode)
	}

	doc := new(DIDDocument)
	if err := json.NewDecoder(resp.Body).Decode(doc); err != nil {
		return nil, fmt.Errorf("parse DID document: %w", err)
	}
	// Guard against a resolver returning a document for a different DID.
	if doc.ID != did.URI {
		return nil, fmt.Errorf("DID document ID mismatch: expected %s, got %s", did.URI, doc.ID)
	}
	return doc, nil
}
// extractPublicKeyFromDIDDocument extracts a public key from a PRISM DID document
//
// It first walks the verification relationship selected by
// config.PreferredVerificationMethod (defaulting to "authentication"),
// then falls back to trying every entry of verificationMethod.
func extractPublicKeyFromDIDDocument(doc *DIDDocument, config PRISMResolverConfig) (crypto.PubKey, error) {
	preferred := config.PreferredVerificationMethod
	if preferred == "" {
		preferred = defaultPRISMResolverConfig.PreferredVerificationMethod
	}

	var refs []VerificationMethodRef
	switch preferred {
	case "assertionMethod":
		refs = doc.AssertionMethod
	case "capabilityInvocation":
		refs = doc.CapabilityInvocation
	default:
		// "authentication" and any unknown preference both map here.
		refs = doc.Authentication
	}

	// resolveRef turns a reference into a concrete verification method:
	// either the ref embeds one (it carries a publicKeyJwk) or it names
	// an entry of doc.VerificationMethod by ID.
	resolveRef := func(ref VerificationMethodRef) *VerificationMethod {
		if ref.PublicKeyJWK != nil {
			return &VerificationMethod{
				ID:           ref.ID,
				Type:         ref.Type,
				Controller:   ref.Controller,
				PublicKeyJWK: ref.PublicKeyJWK,
			}
		}
		if ref.ID == "" {
			return nil
		}
		for i := range doc.VerificationMethod {
			if doc.VerificationMethod[i].ID == ref.ID {
				return &doc.VerificationMethod[i]
			}
		}
		return nil
	}

	for _, ref := range refs {
		vm := resolveRef(ref)
		if vm == nil {
			continue
		}
		if pubk, err := extractPublicKeyFromVerificationMethod(vm); err == nil {
			return pubk, nil
		}
	}

	// Fallback: try all verification methods.
	for i := range doc.VerificationMethod {
		if pubk, err := extractPublicKeyFromVerificationMethod(&doc.VerificationMethod[i]); err == nil {
			return pubk, nil
		}
	}
	return nil, fmt.Errorf("no supported verification method found in DID document")
}
// extractPublicKeyFromVerificationMethod extracts a public key from a verification method
//
// Only JsonWebKey2020 methods carrying a publicKeyJwk are supported.
func extractPublicKeyFromVerificationMethod(vm *VerificationMethod) (crypto.PubKey, error) {
	if vm.Type != "JsonWebKey2020" {
		return nil, fmt.Errorf("unsupported verification method type: %s", vm.Type)
	}
	if vm.PublicKeyJWK == nil {
		return nil, fmt.Errorf("missing publicKeyJwk in verification method")
	}

	var key JWK
	if err := json.Unmarshal(vm.PublicKeyJWK, &key); err != nil {
		return nil, fmt.Errorf("parse JWK: %w", err)
	}

	// Dispatch on the curve named by the JWK.
	switch key.Crv {
	case curveEd25519:
		return extractEd25519Key(key)
	case curveSecp256k1:
		return extractSecp256k1Key(key)
	case "X25519":
		// X25519 is for key agreement, not signing
		// We'll skip it for now as we need signing keys
		return nil, fmt.Errorf("X25519 keys are for key agreement, not signing")
	default:
		return nil, fmt.Errorf("unsupported curve: %s", key.Crv)
	}
}
// extractEd25519Key converts an Ed25519 JWK (kty=OKP) into a libp2p public key.
// The raw 32-byte public key is carried base64url-encoded in the JWK X field.
func extractEd25519Key(jwk JWK) (crypto.PubKey, error) {
	if jwk.Kty != "OKP" {
		return nil, fmt.Errorf("Ed25519 key must have kty=OKP, got %s", jwk.Kty)
	}
	raw, err := base64.RawURLEncoding.DecodeString(jwk.X)
	if err != nil {
		return nil, fmt.Errorf("decode Ed25519 X coordinate: %w", err)
	}
	if len(raw) != ed25519.PublicKeySize {
		return nil, fmt.Errorf("invalid Ed25519 key size: expected %d, got %d", ed25519.PublicKeySize, len(raw))
	}
	// Hand the validated key bytes to libp2p for wrapping.
	return libp2p_crypto.UnmarshalEd25519PublicKey(ed25519.PublicKey(raw))
}
// extractSecp256k1Key converts a secp256k1 JWK (kty=EC) into a libp2p public
// key. The base64url X and Y coordinates are assembled into the 65-byte SEC 1
// uncompressed point format (0x04 || X || Y) expected by libp2p's unmarshaler.
func extractSecp256k1Key(jwk JWK) (crypto.PubKey, error) {
	if jwk.Kty != "EC" {
		return nil, fmt.Errorf("secp256k1 key must have kty=EC, got %s", jwk.Kty)
	}
	if jwk.Y == "" {
		return nil, fmt.Errorf("secp256k1 key missing Y coordinate")
	}
	x, err := base64.RawURLEncoding.DecodeString(jwk.X)
	if err != nil {
		return nil, fmt.Errorf("decode secp256k1 X coordinate: %w", err)
	}
	y, err := base64.RawURLEncoding.DecodeString(jwk.Y)
	if err != nil {
		return nil, fmt.Errorf("decode secp256k1 Y coordinate: %w", err)
	}
	// Each coordinate must be exactly 32 bytes for the uncompressed encoding.
	if len(x) != 32 || len(y) != 32 {
		return nil, fmt.Errorf("invalid secp256k1 key size: X=%d bytes, Y=%d bytes (expected 32 each)", len(x), len(y))
	}
	// Build the uncompressed point: prefix byte, then both coordinates.
	uncompressed := make([]byte, 0, 1+len(x)+len(y))
	uncompressed = append(uncompressed, 0x04)
	uncompressed = append(uncompressed, x...)
	uncompressed = append(uncompressed, y...)
	return libp2p_crypto.UnmarshalSecp256k1PublicKey(uncompressed)
}
// makePrismAnchor creates an Anchor for a PRISM DID by resolving its DID
// document (using the global resolver configuration) and extracting a
// supported public key from it. Resolution is bounded by a 30-second timeout.
func makePrismAnchor(did DID) (Anchor, error) {
	cfg := globalPRISMConfig
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()
	doc, err := resolvePRISMDID(ctx, did, cfg)
	if err != nil {
		return nil, fmt.Errorf("resolve PRISM DID: %w", err)
	}
	pubk, err := extractPublicKeyFromDIDDocument(doc, cfg)
	if err != nil {
		return nil, fmt.Errorf("extract public key: %w", err)
	}
	return NewAnchor(did, pubk), nil
}
// globalPRISMConfig holds the package-wide PRISM resolver configuration used
// when resolving PRISM DIDs (see makePrismAnchor). It defaults to
// defaultPRISMResolverConfig and can be replaced via SetPRISMResolverConfig.
// NOTE(review): read and written without synchronization — confirm callers
// only configure it during startup.
var globalPRISMConfig = defaultPRISMResolverConfig
// SetPRISMResolverConfig updates the global PRISM resolver configuration.
// This allows users to configure the resolver URL and other options.
// NOTE(review): writes an unsynchronized package-level variable; not safe to
// call concurrently with resolution — confirm it is only used at startup.
func SetPRISMResolverConfig(config PRISMResolverConfig) {
	globalPRISMConfig = config
}
// GetPRISMResolverConfig returns the current global PRISM resolver configuration.
// NOTE(review): reads the unsynchronized global — see SetPRISMResolverConfig.
func GetPRISMResolverConfig() PRISMResolverConfig {
	return globalPRISMConfig
}
// ProviderFromPRISMPrivateKey creates a Provider from a PRISM DID and private
// key, allowing PRISM DIDs to be used for signing UCAN tokens.
// Note: the private key is not verified against the DID document at creation
// time; verification happens when the provider is used (e.g. when signing
// tokens).
func ProviderFromPRISMPrivateKey(prismDID DID, privk crypto.PrivKey) (Provider, error) {
	if method := prismDID.Method(); method != "prism" {
		return nil, fmt.Errorf("expected PRISM DID, got %s", method)
	}
	return NewProvider(prismDID, privk), nil
}
// ImportPRISMPrivateKeyFromJWK imports a PRISM private key from JWK format and
// returns the parsed private key. The DID parameter is currently unused.
// NOTE(review): both curve-specific importers below are unimplemented, so this
// presently returns an error for every supported curve.
func ImportPRISMPrivateKeyFromJWK(jwkData []byte, _ DID) (crypto.PrivKey, error) {
	var jwk JWK
	if err := json.Unmarshal(jwkData, &jwk); err != nil {
		return nil, fmt.Errorf("parse JWK: %w", err)
	}
	// Dispatch on the curve declared in the JWK.
	switch jwk.Crv {
	case curveEd25519:
		return importEd25519PrivateKeyFromJWK(jwk)
	case curveSecp256k1:
		return importSecp256k1PrivateKeyFromJWK(jwk)
	default:
		return nil, fmt.Errorf("unsupported curve: %s", jwk.Crv)
	}
}
// importEd25519PrivateKeyFromJWK imports an Ed25519 private key from JWK.
// JWK Ed25519 private keys carry the private material in the 'd' field (with
// 'x' holding the public key); the private key is 32 bytes (seed) or 64 bytes
// (seed + public key). The local JWK struct does not expose 'd' yet, so only
// the kty is validated before reporting the importer as unimplemented.
func importEd25519PrivateKeyFromJWK(jwk JWK) (crypto.PrivKey, error) {
	if kty := jwk.Kty; kty != "OKP" {
		return nil, fmt.Errorf("Ed25519 key must have kty=OKP, got %s", kty)
	}
	return nil, fmt.Errorf("JWK private key import not yet implemented - use raw key format")
}
// importSecp256k1PrivateKeyFromJWK imports a secp256k1 private key from JWK.
// Not yet implemented; callers must supply keys in raw format instead.
func importSecp256k1PrivateKeyFromJWK(_ JWK) (crypto.PrivKey, error) {
	return nil, fmt.Errorf("JWK private key import not yet implemented - use raw key format")
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package did
import (
"crypto/ed25519"
"encoding/hex"
"fmt"
"github.com/decred/dcrd/dcrec/secp256k1/v4"
"gitlab.com/nunet/device-management-service/lib/crypto"
prismpb "gitlab.com/nunet/device-management-service/proto/generated/prism"
"google.golang.org/protobuf/proto"
)
// PRISMCreateDIDOperation represents a PRISM CreateDID operation.
// NOTE(review): not referenced by the builders in this file, which accept the
// key/service/context slices directly — confirm external usage before removal.
type PRISMCreateDIDOperation struct {
	PublicKeys []PRISMPublicKey // keys to register with the DID
	Services   []PRISMService   // service endpoints to publish
	Context    []string         // JSON-LD context strings
}
// PRISMPublicKey represents a public key in PRISM format.
type PRISMPublicKey struct {
	ID    string // key identifier within the DID document, e.g. "master-0"
	Usage string // key usage, e.g. "MASTER_KEY", "AUTHENTICATION_KEY" (see parseKeyUsage)
	Key   []byte // public key bytes (Ed25519: 32 bytes, Secp256k1: 33 bytes compressed)
	Curve string // "Ed25519" or "secp256k1" (optional; auto-detected from key length if empty)
}
// PRISMService represents a service entry in PRISM format.
type PRISMService struct {
	ID              string // service identifier
	Type            string // service type
	ServiceEndpoint string // service endpoint URI
}
// CreateSignedPRISMOperation creates a signed PRISM CreateDID operation using
// protobuf encoding and returns it hex-encoded.
//
// The flow is: build the ProtoCreateDID payload, wrap it in a PrismOperation,
// sign the protobuf-encoded operation bytes with privKey, wrap signature and
// operation in a SignedPrismOperation, then hex-encode its protobuf encoding.
// The public-key parameter is currently unused.
func CreateSignedPRISMOperation(
	privKey crypto.PrivKey,
	_ crypto.PubKey,
	keyID string,
	publicKeys []PRISMPublicKey,
	services []PRISMService,
	context []string,
) (string, error) {
	createDID, err := buildCreateDIDOperation(publicKeys, services, context)
	if err != nil {
		return "", fmt.Errorf("build create DID operation: %w", err)
	}
	op := &prismpb.PrismOperation{
		Operation: &prismpb.PrismOperation_CreateDid{
			CreateDid: createDID,
		},
	}
	opBytes, err := proto.Marshal(op)
	if err != nil {
		return "", fmt.Errorf("marshal prism operation: %w", err)
	}
	// The signature covers the protobuf encoding of the unsigned operation.
	sig, err := privKey.Sign(opBytes)
	if err != nil {
		return "", fmt.Errorf("sign operation: %w", err)
	}
	signedOp := &prismpb.SignedPrismOperation{
		SignedWith: keyID,
		Signature:  sig,
		Operation:  op,
	}
	signedBytes, err := proto.Marshal(signedOp)
	if err != nil {
		return "", fmt.Errorf("marshal signed operation: %w", err)
	}
	return hex.EncodeToString(signedBytes), nil
}
// buildCreateDIDOperation builds a ProtoCreateDID message from the provided
// keys, services, and JSON-LD context strings.
func buildCreateDIDOperation(
	publicKeys []PRISMPublicKey,
	services []PRISMService,
	context []string,
) (*prismpb.ProtoCreateDID, error) {
	// Convert each key to its protobuf representation.
	keys := make([]*prismpb.PublicKey, 0, len(publicKeys))
	for _, pk := range publicKeys {
		converted, err := convertToPRISMPublicKey(pk)
		if err != nil {
			return nil, fmt.Errorf("convert public key %s: %w", pk.ID, err)
		}
		keys = append(keys, converted)
	}
	// Services map field-for-field onto the protobuf Service message.
	svcs := make([]*prismpb.Service, 0, len(services))
	for _, s := range services {
		svcs = append(svcs, &prismpb.Service{
			Id:              s.ID,
			Type:            s.Type,
			ServiceEndpoint: s.ServiceEndpoint,
		})
	}
	return &prismpb.ProtoCreateDID{
		DidData: &prismpb.ProtoCreateDID_DIDCreationData{
			PublicKeys: keys,
			Services:   svcs,
			Context:    context,
		},
	}, nil
}
// convertToPRISMPublicKey converts a PRISMPublicKey to its protobuf form.
// When pk.Curve is empty the curve is auto-detected from the key length
// (32 bytes → Ed25519, 33 bytes → compressed secp256k1).
func convertToPRISMPublicKey(pk PRISMPublicKey) (*prismpb.PublicKey, error) {
	usage, err := parseKeyUsage(pk.Usage)
	if err != nil {
		return nil, fmt.Errorf("parse key usage: %w", err)
	}
	curve := pk.Curve
	if curve == "" {
		// Auto-detect from key size.
		switch len(pk.Key) {
		case ed25519.PublicKeySize:
			curve = curveEd25519
		case 33:
			curve = curveSecp256k1
		default:
			return nil, fmt.Errorf("cannot auto-detect curve: key size %d bytes (expected 32 for Ed25519 or 33 for Secp256k1)", len(pk.Key))
		}
	}
	var ecKeyData *prismpb.ECKeyData
	switch curve {
	case curveEd25519:
		// Ed25519 has no separate x/y coordinates; the whole 32-byte key is
		// carried in the X field and Y is left unset.
		if len(pk.Key) != ed25519.PublicKeySize {
			return nil, fmt.Errorf("Ed25519 key must be 32 bytes, got %d", len(pk.Key))
		}
		ecKeyData = &prismpb.ECKeyData{
			Curve: curveEd25519,
			X:     pk.Key,
			Y:     nil,
		}
	case curveSecp256k1:
		// Secp256k1 requires explicit X and Y coordinates, recovered from the
		// compressed key.
		x, y, err := extractSecp256k1Coordinates(pk.Key)
		if err != nil {
			return nil, fmt.Errorf("extract secp256k1 coordinates: %w", err)
		}
		ecKeyData = &prismpb.ECKeyData{
			Curve: curveSecp256k1,
			X:     x,
			Y:     y,
		}
	default:
		return nil, fmt.Errorf("unsupported curve: %s (supported: Ed25519, secp256k1)", curve)
	}
	return &prismpb.PublicKey{
		Id:    pk.ID,
		Usage: usage,
		KeyData: &prismpb.PublicKey_EcKeyData{
			EcKeyData: ecKeyData,
		},
	}, nil
}
// extractSecp256k1Coordinates extracts the X and Y coordinates from a 33-byte
// compressed secp256k1 public key. It parses the compressed key with the
// standard secp256k1 library, serializes it in the SEC 1 uncompressed format
// (0x04 || X || Y, 65 bytes total), and returns the two 32-byte coordinates
// as freshly allocated slices.
func extractSecp256k1Coordinates(compressedKey []byte) (xBytes, yBytes []byte, err error) {
	// Compressed secp256k1 keys are always exactly 33 bytes.
	if len(compressedKey) != 33 {
		return nil, nil, fmt.Errorf("compressed secp256k1 key must be 33 bytes, got %d", len(compressedKey))
	}
	pubKey, err := secp256k1.ParsePubKey(compressedKey)
	if err != nil {
		return nil, nil, fmt.Errorf("parse compressed public key: %w", err)
	}
	uncompressed := pubKey.SerializeUncompressed()
	if len(uncompressed) != 65 {
		return nil, nil, fmt.Errorf("unexpected uncompressed key size: %d (expected 65)", len(uncompressed))
	}
	// SEC 1 uncompressed points start with the 0x04 marker byte.
	if uncompressed[0] != 0x04 {
		return nil, nil, fmt.Errorf("invalid uncompressed key prefix: expected 0x04, got 0x%02x", uncompressed[0])
	}
	// Copy the coordinates out so callers do not alias the serialized buffer.
	xBytes = append([]byte(nil), uncompressed[1:33]...)
	yBytes = append([]byte(nil), uncompressed[33:65]...)
	return xBytes, yBytes, nil
}
// parseKeyUsage converts a usage string (upper- or lower-case spelling only;
// mixed case is rejected) to the corresponding protobuf KeyUsage enum value.
func parseKeyUsage(usage string) (prismpb.KeyUsage, error) {
	// Small fixed table; this is configuration parsing, not a hot path.
	table := map[string]prismpb.KeyUsage{
		"MASTER_KEY": prismpb.KeyUsage_MASTER_KEY, "master_key": prismpb.KeyUsage_MASTER_KEY,
		"ISSUING_KEY": prismpb.KeyUsage_ISSUING_KEY, "issuing_key": prismpb.KeyUsage_ISSUING_KEY,
		"KEY_AGREEMENT_KEY": prismpb.KeyUsage_KEY_AGREEMENT_KEY, "key_agreement_key": prismpb.KeyUsage_KEY_AGREEMENT_KEY,
		"AUTHENTICATION_KEY": prismpb.KeyUsage_AUTHENTICATION_KEY, "authentication_key": prismpb.KeyUsage_AUTHENTICATION_KEY,
		"REVOCATION_KEY": prismpb.KeyUsage_REVOCATION_KEY, "revocation_key": prismpb.KeyUsage_REVOCATION_KEY,
		"CAPABILITY_INVOCATION_KEY": prismpb.KeyUsage_CAPABILITY_INVOCATION_KEY, "capability_invocation_key": prismpb.KeyUsage_CAPABILITY_INVOCATION_KEY,
		"CAPABILITY_DELEGATION_KEY": prismpb.KeyUsage_CAPABILITY_DELEGATION_KEY, "capability_delegation_key": prismpb.KeyUsage_CAPABILITY_DELEGATION_KEY,
		"VDR_KEY": prismpb.KeyUsage_VDR_KEY, "vdr_key": prismpb.KeyUsage_VDR_KEY,
	}
	if u, ok := table[usage]; ok {
		return u, nil
	}
	return prismpb.KeyUsage_UNKNOWN_KEY, fmt.Errorf("unknown key usage: %s", usage)
}
// CreateSignedPRISMOperationSimple is a convenience wrapper around
// CreateSignedPRISMOperation that registers the provided key pair both as the
// master key (keyID, required by PRISM) and as an authentication key
// ("auth-0"). Master keys do not appear in verificationMethod, so the extra
// authentication key ensures the resulting DID document has a verification
// method.
//
// Supported key types are Ed25519 and Secp256k1; the JSON-LD context is chosen
// to match the key type.
func CreateSignedPRISMOperationSimple(
	privKey crypto.PrivKey,
	pubKey crypto.PubKey,
	keyID string,
) (string, error) {
	pubKeyBytes, err := pubKey.Raw()
	if err != nil {
		return "", fmt.Errorf("get public key bytes: %w", err)
	}
	// Detect the key type, validate the raw key length, and pick the matching
	// curve name and JSON-LD context.
	var curve string
	var context []string
	switch pubKey.Type() {
	case crypto.Ed25519:
		if len(pubKeyBytes) != ed25519.PublicKeySize {
			return "", fmt.Errorf("expected Ed25519 public key (32 bytes), got %d bytes", len(pubKeyBytes))
		}
		curve = curveEd25519
		context = []string{
			"https://www.w3.org/ns/did/v1",
			"https://w3id.org/security/suites/ed25519-2020/v1",
		}
	case crypto.Secp256k1:
		if len(pubKeyBytes) != 33 {
			return "", fmt.Errorf("expected Secp256k1 compressed public key (33 bytes), got %d bytes", len(pubKeyBytes))
		}
		// Use the shared constant (previously a bare "secp256k1" literal) so
		// the curve string stays consistent with convertToPRISMPublicKey.
		curve = curveSecp256k1
		context = []string{
			"https://www.w3.org/ns/did/v1",
			"https://w3id.org/security/suites/secp256k1-2019/v1",
		}
	default:
		return "", fmt.Errorf("unsupported key type: %d (supported: Ed25519, Secp256k1)", pubKey.Type())
	}
	publicKeys := []PRISMPublicKey{
		{
			ID:    keyID,
			Usage: "MASTER_KEY",
			Key:   pubKeyBytes,
			Curve: curve,
		},
		{
			ID:    "auth-0",
			Usage: "AUTHENTICATION_KEY",
			Key:   pubKeyBytes, // reuse the same key for authentication
			Curve: curve,
		},
	}
	return CreateSignedPRISMOperation(privKey, pubKey, keyID, publicKeys, nil, context)
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package env
import (
"os"
"sync"
)
// EnvironmentProvider provides access to environment variables, abstracting
// the OS environment so tests can substitute an in-memory implementation.
type EnvironmentProvider interface {
	// Getenv returns the value of the named variable ("" if unset).
	Getenv(key string) string
	// Setenv assigns value to the named variable.
	Setenv(key, value string) error
}
// OSEnvironment uses the actual OS environment (os.Getenv / os.Setenv).
type OSEnvironment struct{}

// Compile-time check that OSEnvironment satisfies EnvironmentProvider.
var _ EnvironmentProvider = (*OSEnvironment)(nil)
// NewOSEnvironment returns an OSEnvironment backed by the real process
// environment.
func NewOSEnvironment() OSEnvironment {
	return OSEnvironment{}
}
// Getenv returns the value of the OS environment variable ("" if unset).
func (e OSEnvironment) Getenv(key string) string {
	return os.Getenv(key)
}
// Setenv sets the OS environment variable, propagating any error from
// os.Setenv.
func (e OSEnvironment) Setenv(key, value string) error {
	return os.Setenv(key, value)
}
// MockEnvironment provides an in-memory environment for testing.
// Access to the backing map is guarded by mu, so it is safe for concurrent use.
type MockEnvironment struct {
	vars map[string]string // backing store for variables
	mu   sync.RWMutex      // guards vars
}
// NewMockEnvironment returns an empty in-memory environment ready for use.
func NewMockEnvironment() *MockEnvironment {
	return &MockEnvironment{
		vars: make(map[string]string),
	}
}
// Getenv returns the value stored for key, or "" if the key is absent,
// mirroring os.Getenv semantics. Safe for concurrent use.
func (e *MockEnvironment) Getenv(key string) string {
	e.mu.RLock()
	defer e.mu.RUnlock()
	// A missing key yields the map's zero value "", so no explicit
	// presence check is needed.
	return e.vars[key]
}
// Setenv stores value under key in the in-memory environment.
// It always returns nil; the error return exists to satisfy
// EnvironmentProvider.
func (e *MockEnvironment) Setenv(key, value string) error {
	e.mu.Lock()
	defer e.mu.Unlock()
	e.vars[key] = value
	return nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package cpu
import (
"fmt"
"sync"
"time"
"github.com/shirou/gopsutil/v4/cpu"
"gitlab.com/nunet/device-management-service/types"
)
// Package-level cache of the static CPU probe; populated at most once via
// cpuCacheOnce in GetCPU.
var (
	cpuCache     *types.CPU // cached result of getCPU (nil until first successful probe)
	cpuCacheOnce sync.Once  // guards the one-time population of cpuCache/cpuCacheErr
	cpuCacheErr  error      // error captured during the one-time probe, if any
)
const (
	// sampleInterval is how often the Monitor's ticker fires a new sample.
	sampleInterval = 1 * time.Second
	// alpha is the EMA smoothing factor (should be tweaked): higher values
	// weight recent samples more heavily.
	alpha = 0.4
)
// Monitor tracks CPU usage using an exponential moving average to help with
// smoothing out short-term fluctuations.
type Monitor struct {
	mu       sync.RWMutex  // guards avgUsage
	avgUsage float64       // exponentially smoothed CPU usage percentage
	ticker   *time.Ticker  // drives the periodic sampling loop
	done     chan struct{} // closed by Stop to terminate the sampling goroutine
	running  bool          // sampling goroutine active flag; NOTE(review): accessed without mu in Start/Stop — confirm single-goroutine use
}
// NewCPUMonitor creates a Monitor with a zeroed moving average and immediately
// starts its background sampling loop.
func NewCPUMonitor() *Monitor {
	m := &Monitor{
		done: make(chan struct{}), // avgUsage starts at its zero value
	}
	m.Start()
	return m
}
// Start launches the background sampling goroutine if it is not already
// running. Each tick samples system CPU utilization and folds it into the
// exponential moving average.
//
// The running/ticker fields are now mutated under c.mu so Start is safe
// against concurrent callers (previously this was an unguarded data race).
// Note: a Monitor cannot be restarted after Stop — the done channel is closed
// permanently.
func (c *Monitor) Start() {
	c.mu.Lock()
	if c.running {
		c.mu.Unlock()
		return
	}
	c.running = true
	c.ticker = time.NewTicker(sampleInterval)
	c.mu.Unlock()
	go func() {
		for {
			select {
			case <-c.ticker.C:
				// cpu.Percent blocks for sampleInterval while measuring.
				cpuPercent, err := cpu.Percent(sampleInterval, false)
				if err != nil {
					// Best-effort sampling: skip this tick on error.
					continue
				}
				if len(cpuPercent) > 0 {
					c.mu.Lock()
					c.avgUsage = alpha*cpuPercent[0] + (1.0-alpha)*c.avgUsage
					c.mu.Unlock()
				}
			case <-c.done:
				return
			}
		}
	}()
}
// Stop terminates the background sampling goroutine and stops the ticker.
// The running flag is checked and cleared under c.mu, making repeated or
// concurrent Stop calls safe (previously a race could close the done channel
// twice). The Monitor cannot be restarted afterwards.
func (c *Monitor) Stop() {
	c.mu.Lock()
	defer c.mu.Unlock()
	if !c.running {
		return
	}
	c.running = false
	c.ticker.Stop()
	close(c.done)
}
// GetAvgCPUUsage returns the system CPU info with Cores scaled to the
// exponentially smoothed number of cores currently in use.
func (c *Monitor) GetAvgCPUUsage() (types.CPU, error) {
	c.mu.RLock()
	defer c.mu.RUnlock()
	cpuInfo, err := GetCPU()
	if err != nil {
		// Wrap with %w (was %s) so callers can unwrap, matching GetCPU's style.
		return types.CPU{}, fmt.Errorf("get CPU info: %w", err)
	}
	// Convert the smoothed usage percentage into an equivalent core count.
	usedCores := float64(cpuInfo.Cores) * c.avgUsage / 100
	cpuInfo.Cores = float32(usedCores)
	return cpuInfo, nil
}
// GetCPU returns the CPU information for the system. The underlying probe
// runs exactly once; its result — or the error it produced — is cached for
// all subsequent calls.
func GetCPU() (types.CPU, error) {
	cpuCacheOnce.Do(func() {
		info, err := getCPU()
		if err != nil {
			cpuCacheErr = fmt.Errorf("get CPU info: %w", err)
			return
		}
		cpuCache = &info
	})
	switch {
	case cpuCacheErr != nil:
		return types.CPU{}, cpuCacheErr
	case cpuCache == nil:
		// Defensive: neither value nor error was recorded.
		return types.CPU{}, fmt.Errorf("cpu info is nil")
	default:
		return *cpuCache, nil
	}
}
// GetUsage returns the CPU info with Cores scaled to the number of cores in
// use at the current moment, measured over a one-second window.
func GetUsage() (types.CPU, error) {
	cpuUsage, err := cpu.Percent(time.Second, false)
	if err != nil {
		return types.CPU{}, fmt.Errorf("failed to get CPU usage: %w", err)
	}
	// Guard the [0] access below: an empty result would otherwise panic.
	if len(cpuUsage) == 0 {
		return types.CPU{}, fmt.Errorf("no CPU usage data returned")
	}
	cpuInfo, err := GetCPU()
	if err != nil {
		return types.CPU{}, fmt.Errorf("get CPU info: %w", err)
	}
	// Convert the aggregate usage percentage into an equivalent core count.
	usedCores := float64(cpuInfo.Cores) * cpuUsage[0] / 100
	cpuInfo.Cores = float32(usedCores)
	return cpuInfo, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package cpu
import (
"fmt"
"github.com/shirou/gopsutil/v4/cpu"
"gitlab.com/nunet/device-management-service/types"
)
// getCPU probes the system for static CPU information: the logical core count
// and the clock speed of the first core (converted from MHz to Hz).
func getCPU() (types.CPU, error) {
	cores, err := cpu.Info()
	if err != nil {
		return types.CPU{}, fmt.Errorf("failed to get CPU info: %w", err)
	}
	// Guard the cores[0] access below; an empty result would otherwise panic.
	if len(cores) == 0 {
		return types.CPU{}, fmt.Errorf("no CPU info returned")
	}
	// (The previous totalCompute accumulation was dead code and has been removed.)
	return types.CPU{
		Cores:      float32(len(cores)),
		ClockSpeed: cores[0].Mhz * 1000000, // MHz → Hz
	}, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package hardware
import (
"context"
"fmt"
"github.com/shirou/gopsutil/v4/disk"
"gitlab.com/nunet/device-management-service/types"
)
// GetDisk returns the system's total storage capacity, summed across all
// mounted partitions (physical partitions only; pseudo filesystems excluded).
func GetDisk() (types.Disk, error) {
	ctx := context.Background()
	partitions, err := disk.PartitionsWithContext(ctx, false)
	if err != nil {
		return types.Disk{}, fmt.Errorf("failed to get partitions: %w", err)
	}
	var total uint64
	for _, part := range partitions {
		usage, err := disk.UsageWithContext(ctx, part.Mountpoint)
		if err != nil {
			return types.Disk{}, fmt.Errorf("failed to get disk usage: %w", err)
		}
		total += usage.Total
	}
	return types.Disk{Size: total}, nil
}
// GetDiskUsage returns the storage currently in use, summed across all
// mounted partitions. The used byte count is reported in the Size field.
func GetDiskUsage() (types.Disk, error) {
	ctx := context.Background()
	partitions, err := disk.PartitionsWithContext(ctx, false)
	if err != nil {
		return types.Disk{}, fmt.Errorf("failed to get partitions: %w", err)
	}
	var used uint64
	for _, part := range partitions {
		usage, err := disk.UsageWithContext(ctx, part.Mountpoint)
		if err != nil {
			return types.Disk{}, fmt.Errorf("failed to get disk usage: %w", err)
		}
		used += usage.Used
	}
	return types.Disk{Size: used}, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
//go:build linux && (amd64 || 386)
package gpu
import (
"fmt"
goamdsmi "gitlab.com/nunet/device-management-service/lib/hardware/gpu/amdsmi"
"gitlab.com/nunet/device-management-service/types"
)
// amdGPUConnector implements the types.GPUConnector interface for AMD GPUs.
type amdGPUConnector struct {
	processorIndexCache map[string]int             // GPU UUID → index into processorCache
	processorCache      []goamdsmi.ProcessorHandle // cached processor handles in discovery order
}

// Compile-time check that amdGPUConnector satisfies types.GPUConnector.
var _ types.GPUConnector = (*amdGPUConnector)(nil)
// newAMDGPUConnector returns a new AMD GPU connector, initializing the AMD SMI
// library and caching processor handles for every detected GPU.
func newAMDGPUConnector() (types.GPUConnector, error) {
	c := &amdGPUConnector{
		processorIndexCache: make(map[string]int),
	}
	if err := c.initialize(); err != nil {
		return nil, fmt.Errorf("initialize AMD GPU connector: %w", err)
	}
	if err := c.loadProcessorHandles(); err != nil {
		return nil, fmt.Errorf("load processor handles: %w", err)
	}
	return c, nil
}
// initialize initializes the AMD SMI library, reporting both transport errors
// and unsuccessful status codes as errors.
func (a *amdGPUConnector) initialize() error {
	status, err := goamdsmi.Init()
	switch {
	case err != nil:
		return err
	case status.Code != goamdsmi.StatusSuccess:
		return fmt.Errorf("AMD SMI initialization was unsuccessful: %w", status.Error())
	default:
		return nil
	}
}
// loadProcessorHandles enumerates all sockets and caches every GPU processor
// handle, recording each GPU's UUID → cache-index mapping in
// processorIndexCache.
func (a *amdGPUConnector) loadProcessorHandles() error {
	// Retrieve socket handles.
	sockets, ret := goamdsmi.GetSocketHandles()
	if ret.Code != goamdsmi.StatusSuccess {
		return fmt.Errorf("get socket handles: %w", ret.Error())
	}
	for _, socket := range sockets {
		// Retrieve processor handles for the current socket.
		processors, ret := goamdsmi.GetProcessorHandles(socket)
		if ret.Code != goamdsmi.StatusSuccess {
			return fmt.Errorf("get processor handles: %w", ret.Error())
		}
		for _, processor := range processors {
			uuid, ret := goamdsmi.GetGPUUUID(processor)
			if ret.Code != goamdsmi.StatusSuccess {
				return fmt.Errorf("get GPU UUID: %w", ret.Error())
			}
			// Map the UUID to its position in the global processorCache, not
			// the per-socket loop index: with more than one socket the
			// per-socket index would point at the wrong cached handle.
			a.processorIndexCache[uuid] = len(a.processorCache)
			a.processorCache = append(a.processorCache, processor)
		}
	}
	return nil
}
// GetGPUs returns the AMD GPUs in the system, one entry per cached processor
// handle.
func (a *amdGPUConnector) GetGPUs() (types.GPUs, error) {
	// Allocate with zero length and full capacity: the previous
	// make(types.GPUs, n) followed by append prepended n zero-value GPUs.
	gpus := make(types.GPUs, 0, len(a.processorCache))
	for _, processor := range a.processorCache {
		boardInfo, ret := goamdsmi.GetGPUBoardInfo(processor)
		if ret.Code != goamdsmi.StatusSuccess {
			return nil, fmt.Errorf("get board info: %w", ret.Error())
		}
		vRAM, ret := goamdsmi.GetGPUVRAM(processor)
		if ret.Code != goamdsmi.StatusSuccess {
			return nil, fmt.Errorf("get GPU VRAM: %w", ret.Error())
		}
		bdfID, ret := goamdsmi.GetGPUBDFID(processor)
		if ret.Code != goamdsmi.StatusSuccess {
			return nil, fmt.Errorf("get GPU BDFID: %w", ret.Error())
		}
		uuid, ret := goamdsmi.GetGPUUUID(processor)
		if ret.Code != goamdsmi.StatusSuccess {
			return nil, fmt.Errorf("get GPU UUID: %w", ret.Error())
		}
		// Compute-unit count is best-effort: default to 0 when the ASIC info
		// call fails or reports the 0xFFFFFFFF "unavailable" sentinel.
		var cores uint32
		asicInfo, ret := goamdsmi.GetGPUASICInfo(processor)
		if ret.Code == goamdsmi.StatusSuccess && asicInfo.NumComputeUnits != 0xFFFFFFFF {
			cores = asicInfo.NumComputeUnits
		}
		gpus = append(gpus, types.GPU{
			UUID:       uuid,
			Model:      boardInfo.ProductName,
			VRAM:       types.ConvertMibToBytes(uint64(vRAM.Total)),
			Cores:      cores,
			Vendor:     types.GPUVendorAMDATI,
			PCIAddress: bdfIDToPCIAddress(bdfID),
		})
	}
	return gpus, nil
}
// GetGPUUsage returns the VRAM usage in bytes for the GPU with the given UUID.
func (a *amdGPUConnector) GetGPUUsage(uuid string) (uint64, error) {
	if uuid == "" {
		return 0, fmt.Errorf("no UUID provided")
	}
	idx, ok := a.processorIndexCache[uuid]
	if !ok {
		return 0, fmt.Errorf("AMD gpu device with UUID %s not found", uuid)
	}
	vram, ret := goamdsmi.GetGPUVRAM(a.processorCache[idx])
	if ret.Code != goamdsmi.StatusSuccess {
		return 0, fmt.Errorf("get GPU usage: %w", ret.Error())
	}
	return types.ConvertMibToBytes(uint64(vram.Used)), nil
}
// Shutdown shuts down the AMD SMI library and clears the cached processor
// handles.
func (a *amdGPUConnector) Shutdown() error {
	if ret := goamdsmi.Shutdown(); ret.Code != goamdsmi.StatusSuccess {
		return fmt.Errorf("shutdown amdsmi: %w", ret.Error())
	}
	a.processorCache = nil
	a.processorIndexCache = nil
	return nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
//go:build linux && (amd64 || 386)
package amdsmi
/*
#cgo CXXFLAGS: -Iinclude -std=c++11
#cgo CFLAGS: -Iinclude
#cgo LDFLAGS: -ldl
#include <dlfcn.h>
#include <stdio.h>
#include <stdlib.h>
#include "amdsmi.h"
// Define constants for Go to use
const int GO_AMDSMI_PROCESSOR_TYPE_UNKNOWN = AMDSMI_PROCESSOR_TYPE_UNKNOWN;
const int GO_AMDSMI_PROCESSOR_TYPE_AMD_GPU = AMDSMI_PROCESSOR_TYPE_AMD_GPU;
const int GO_AMDSMI_PROCESSOR_TYPE_AMD_CPU = AMDSMI_PROCESSOR_TYPE_AMD_CPU;
const int GO_AMDSMI_PROCESSOR_TYPE_NON_AMD_GPU = AMDSMI_PROCESSOR_TYPE_NON_AMD_GPU;
const int GO_AMDSMI_PROCESSOR_TYPE_NON_AMD_CPU = AMDSMI_PROCESSOR_TYPE_NON_AMD_CPU;
const int GO_AMDSMI_PROCESSOR_TYPE_AMD_CPU_CORE = AMDSMI_PROCESSOR_TYPE_AMD_CPU_CORE;
const int GO_AMDSMI_PROCESSOR_TYPE_AMD_APU = AMDSMI_PROCESSOR_TYPE_AMD_APU;
const int GO_AMDSMI_GPU_UUID_SIZE = AMDSMI_GPU_UUID_SIZE;
// ========================
// AMD SMI Function Pointer Definitions
// ========================
// Macro to define a function pointer type
#define DEFINE_AMDSMI_FUNC_TYPE(ret, name, args) typedef ret (*name##_fp) args
// Define function pointer types for AMD SMI functions
DEFINE_AMDSMI_FUNC_TYPE(amdsmi_status_t, amdsmi_init, (uint64_t flags));
DEFINE_AMDSMI_FUNC_TYPE(amdsmi_status_t, amdsmi_shut_down, (void));
DEFINE_AMDSMI_FUNC_TYPE(amdsmi_status_t, amdsmi_get_socket_handles, (uint32_t *socket_count, amdsmi_socket_handle* socket_handles));
DEFINE_AMDSMI_FUNC_TYPE(amdsmi_status_t, amdsmi_get_socket_info, (amdsmi_socket_handle socket_handle, size_t len, char *name));
DEFINE_AMDSMI_FUNC_TYPE(amdsmi_status_t, amdsmi_get_processor_handles, (amdsmi_socket_handle socket_handle, uint32_t* processor_count, amdsmi_processor_handle* processor_handles));
DEFINE_AMDSMI_FUNC_TYPE(amdsmi_status_t, amdsmi_get_processor_type, (amdsmi_processor_handle processor_handle, processor_type_t* processor_type));
DEFINE_AMDSMI_FUNC_TYPE(amdsmi_status_t, amdsmi_get_gpu_board_info, (amdsmi_processor_handle processor_handle, amdsmi_board_info_t *board_info));
DEFINE_AMDSMI_FUNC_TYPE(amdsmi_status_t, amdsmi_get_gpu_id, (amdsmi_processor_handle processor_handle, uint16_t *id));
DEFINE_AMDSMI_FUNC_TYPE(amdsmi_status_t, amdsmi_get_gpu_device_uuid, (amdsmi_processor_handle processor_handle, unsigned int *uuid_length, char *uuid));
DEFINE_AMDSMI_FUNC_TYPE(amdsmi_status_t, amdsmi_get_gpu_vram_usage, (amdsmi_processor_handle processor_handle, amdsmi_vram_usage_t *vram_info));
DEFINE_AMDSMI_FUNC_TYPE(amdsmi_status_t, amdsmi_get_gpu_bdf_id, (amdsmi_processor_handle processor_handle, uint64_t *bdf_id));
DEFINE_AMDSMI_FUNC_TYPE(amdsmi_status_t, amdsmi_get_gpu_asic_info, (amdsmi_processor_handle processor_handle, amdsmi_asic_info_t *info));
// ========================
// AMD SMI Function Pointers Struct
// ========================
typedef struct {
amdsmi_init_fp amdsmi_init;
amdsmi_shut_down_fp amdsmi_shut_down;
amdsmi_get_socket_handles_fp amdsmi_get_socket_handles;
amdsmi_get_socket_info_fp amdsmi_get_socket_info;
amdsmi_get_processor_handles_fp amdsmi_get_processor_handles;
amdsmi_get_processor_type_fp amdsmi_get_processor_type;
amdsmi_get_gpu_board_info_fp amdsmi_get_gpu_board_info;
amdsmi_get_gpu_id_fp amdsmi_get_gpu_id;
amdsmi_get_gpu_device_uuid_fp amdsmi_get_gpu_device_uuid;
amdsmi_get_gpu_vram_usage_fp amdsmi_get_gpu_vram_usage;
amdsmi_get_gpu_bdf_id_fp amdsmi_get_gpu_bdf_id;
amdsmi_get_gpu_asic_info_fp amdsmi_get_gpu_asic_info;
} amdsmi_functions_t;
// ========================
// Global Variables
// ========================
static void *lib_handle = NULL;
static amdsmi_functions_t amdsmi_funcs;
// ========================
// Helper Macros
// ========================
// Macro to load a symbol and assign it to a struct member.
// On failure it reports the error, unloads the library, clears the ENTIRE
// function-pointer table, and returns 0 from the enclosing function.
// Clearing the whole table matters: without it, symbols that were resolved
// before the failing one would keep pointing into the dlclose'd library,
// and a later call through a call_* wrapper would jump into unmapped code.
#define LOAD_AMDSMI_SYMBOL(func_name) \
amdsmi_funcs.func_name = (func_name##_fp)dlsym(lib_handle, #func_name); \
if (!amdsmi_funcs.func_name) { \
fprintf(stderr, "Error loading symbol %s: %s\n", #func_name, dlerror()); \
dlclose(lib_handle); \
lib_handle = NULL; \
amdsmi_funcs = (amdsmi_functions_t){0}; \
return 0; \
}
// Macro to load a symbol optionally (does not abort if symbol is missing).
// Use for symbols that may not exist in older library versions.
// On a miss the pointer simply stays NULL (callers must check it, e.g. via
// is_gpu_asic_info_available); dlerror() is invoked to clear the pending
// error state so it does not pollute later diagnostics.
#define LOAD_AMDSMI_SYMBOL_OPTIONAL(func_name) \
amdsmi_funcs.func_name = (func_name##_fp)dlsym(lib_handle, #func_name); \
if (!amdsmi_funcs.func_name) { \
dlerror(); \
}
// Macro to define a wrapper function.
// Expands to `ret_type wrapper_name args` which forwards to the resolved
// function pointer when it is set, and returns AMDSMI_STATUS_INVAL when the
// symbol was never loaded (library missing, unloaded, or optional symbol
// absent). `args` is the parenthesized parameter list; the trailing
// varargs are the bare argument names forwarded to the real function.
#define DEFINE_AMDSMI_WRAPPER(ret_type, wrapper_name, amdsmi_func, args, ...) \
ret_type wrapper_name args { \
if (amdsmi_funcs.amdsmi_func) { \
return amdsmi_funcs.amdsmi_func(__VA_ARGS__); \
} \
return AMDSMI_STATUS_INVAL; \
}
// ========================
// Library Management Functions
// ========================
// Load the AMD SMI library and resolve all required symbols.
// Returns 1 on success (or when the library is already loaded), 0 on any
// failure. On failure LOAD_AMDSMI_SYMBOL dlcloses the library and returns
// 0 directly from this function.
int load_amdsmi_library() {
if (lib_handle) {
// Already loaded; idempotent success.
return 1;
}
lib_handle = dlopen("libamd_smi.so", RTLD_LAZY);
if (!lib_handle) {
return 0;
}
// Clear any existing errors
dlerror();
// Load all required symbols; each LOAD_AMDSMI_SYMBOL may `return 0`.
LOAD_AMDSMI_SYMBOL(amdsmi_init)
LOAD_AMDSMI_SYMBOL(amdsmi_shut_down)
LOAD_AMDSMI_SYMBOL(amdsmi_get_socket_handles)
LOAD_AMDSMI_SYMBOL(amdsmi_get_socket_info)
LOAD_AMDSMI_SYMBOL(amdsmi_get_processor_handles)
LOAD_AMDSMI_SYMBOL(amdsmi_get_processor_type)
LOAD_AMDSMI_SYMBOL(amdsmi_get_gpu_board_info)
LOAD_AMDSMI_SYMBOL(amdsmi_get_gpu_id)
LOAD_AMDSMI_SYMBOL(amdsmi_get_gpu_device_uuid)
LOAD_AMDSMI_SYMBOL(amdsmi_get_gpu_vram_usage)
LOAD_AMDSMI_SYMBOL(amdsmi_get_gpu_bdf_id)
// Optional symbols (may not exist in older library versions)
LOAD_AMDSMI_SYMBOL_OPTIONAL(amdsmi_get_gpu_asic_info)
return 1;
}
// Unload the AMD SMI library and reset function pointers
void unload_amdsmi_library() {
if (lib_handle) {
dlclose(lib_handle);
lib_handle = NULL;
// Reset all function pointers to NULL
amdsmi_funcs.amdsmi_init = NULL;
amdsmi_funcs.amdsmi_shut_down = NULL;
amdsmi_funcs.amdsmi_get_socket_handles = NULL;
amdsmi_funcs.amdsmi_get_socket_info = NULL;
amdsmi_funcs.amdsmi_get_processor_handles = NULL;
amdsmi_funcs.amdsmi_get_processor_type = NULL;
amdsmi_funcs.amdsmi_get_gpu_board_info = NULL;
amdsmi_funcs.amdsmi_get_gpu_id = NULL;
amdsmi_funcs.amdsmi_get_gpu_device_uuid = NULL;
amdsmi_funcs.amdsmi_get_gpu_vram_usage = NULL;
amdsmi_funcs.amdsmi_get_gpu_bdf_id = NULL;
amdsmi_funcs.amdsmi_get_gpu_asic_info = NULL;
}
}
// ========================
// Wrapper Functions
// ========================
// Define wrappers for each AMD SMI function. Each call_* wrapper forwards
// to the dynamically resolved symbol, or returns AMDSMI_STATUS_INVAL when
// the symbol is not loaded (see DEFINE_AMDSMI_WRAPPER above). These are the
// only entry points the Go side calls.
DEFINE_AMDSMI_WRAPPER(amdsmi_status_t, call_amdsmi_init, amdsmi_init, (uint64_t flags), flags)
DEFINE_AMDSMI_WRAPPER(amdsmi_status_t, call_amdsmi_shut_down, amdsmi_shut_down, (void))
DEFINE_AMDSMI_WRAPPER(amdsmi_status_t, call_amdsmi_get_socket_handles, amdsmi_get_socket_handles, (uint32_t *socket_count, amdsmi_socket_handle* socket_handles), socket_count, socket_handles)
DEFINE_AMDSMI_WRAPPER(amdsmi_status_t, call_amdsmi_get_socket_info, amdsmi_get_socket_info, (amdsmi_socket_handle socket_handle, size_t len, char *name), socket_handle, len, name)
DEFINE_AMDSMI_WRAPPER(amdsmi_status_t, call_amdsmi_get_processor_handles, amdsmi_get_processor_handles, (amdsmi_socket_handle socket_handle, uint32_t* processor_count, amdsmi_processor_handle* processor_handles), socket_handle, processor_count, processor_handles)
DEFINE_AMDSMI_WRAPPER(amdsmi_status_t, call_amdsmi_get_processor_type, amdsmi_get_processor_type, (amdsmi_processor_handle processor_handle, processor_type_t* processor_type), processor_handle, processor_type)
DEFINE_AMDSMI_WRAPPER(amdsmi_status_t, call_amdsmi_get_gpu_board_info, amdsmi_get_gpu_board_info, (amdsmi_processor_handle processor_handle, amdsmi_board_info_t *board_info), processor_handle, board_info)
DEFINE_AMDSMI_WRAPPER(amdsmi_status_t, call_amdsmi_get_gpu_id, amdsmi_get_gpu_id, (amdsmi_processor_handle processor_handle, uint16_t *id), processor_handle, id)
DEFINE_AMDSMI_WRAPPER(amdsmi_status_t, call_amdsmi_get_gpu_device_uuid, amdsmi_get_gpu_device_uuid, (amdsmi_processor_handle processor_handle, unsigned int *uuid_length, char *uuid), processor_handle, uuid_length, uuid)
DEFINE_AMDSMI_WRAPPER(amdsmi_status_t, call_amdsmi_get_gpu_vram_usage, amdsmi_get_gpu_vram_usage, (amdsmi_processor_handle processor_handle, amdsmi_vram_usage_t *vram_info), processor_handle, vram_info)
DEFINE_AMDSMI_WRAPPER(amdsmi_status_t, call_amdsmi_get_gpu_bdf_id, amdsmi_get_gpu_bdf_id, (amdsmi_processor_handle processor_handle, uint64_t *bdf_id), processor_handle, bdf_id)
DEFINE_AMDSMI_WRAPPER(amdsmi_status_t, call_amdsmi_get_gpu_asic_info, amdsmi_get_gpu_asic_info, (amdsmi_processor_handle processor_handle, amdsmi_asic_info_t *info), processor_handle, info)
// is_gpu_asic_info_available reports (1/0) whether the optional
// amdsmi_get_gpu_asic_info symbol was resolved at library-load time.
int is_gpu_asic_info_available() {
    if (amdsmi_funcs.amdsmi_get_gpu_asic_info) {
        return 1;
    }
    return 0;
}
*/
import "C"
/*
===============================================================================
Adding a New AMD SMI Function to the Cgo Block
===============================================================================
To add a new AMD SMI function, follow these steps:
1. **Define the Function Pointer Type:**
- Use `DEFINE_AMDSMI_FUNC_TYPE` to create a typedef for the new function.
- Example:
`DEFINE_AMDSMI_FUNC_TYPE(amdsmi_status_t, amdsmi_new_function, (int arg1, float arg2));`
2. **Add to the Struct:**
- Add the new function pointer to `amdsmi_functions_t`.
- Example:
`amdsmi_new_function_fp amdsmi_new_function;`
3. **Load the Symbol:**
- In `load_amdsmi_library`, load the new symbol using `LOAD_AMDSMI_SYMBOL`.
- Example:
`LOAD_AMDSMI_SYMBOL(amdsmi_new_function)`
4. **Create a Wrapper Function:**
- Define a wrapper using `DEFINE_AMDSMI_WRAPPER`.
- Example:
`DEFINE_AMDSMI_WRAPPER(amdsmi_status_t, call_amdsmi_new_function, amdsmi_new_function, (int arg1, float arg2), arg1, arg2)`
5. **Rebuild and Test:**
- Rebuild and verify the new function works as expected.
===============================================================================
*/
import (
"fmt"
"unsafe"
)
// ProcessorType is a Go type to represent processor_type_t from C.
type ProcessorType uint32
// Known processor types plus the GPU UUID buffer size, mirroring the
// GO_AMDSMI_* values defined in the cgo preamble. Declared as vars (not
// consts) because the values originate from C.
var (
ProcessorTypeUnknown = ProcessorType(C.GO_AMDSMI_PROCESSOR_TYPE_UNKNOWN)
ProcessorTypeAMDGPU = ProcessorType(C.GO_AMDSMI_PROCESSOR_TYPE_AMD_GPU)
ProcessorTypeAMDCPU = ProcessorType(C.GO_AMDSMI_PROCESSOR_TYPE_AMD_CPU)
ProcessorTypeNonAMDGPU = ProcessorType(C.GO_AMDSMI_PROCESSOR_TYPE_NON_AMD_GPU)
ProcessorTypeNonAMDCPU = ProcessorType(C.GO_AMDSMI_PROCESSOR_TYPE_NON_AMD_CPU)
ProcessorTypeAMDCPUCore = ProcessorType(C.GO_AMDSMI_PROCESSOR_TYPE_AMD_CPU_CORE)
ProcessorTypeAMDAPU = ProcessorType(C.GO_AMDSMI_PROCESSOR_TYPE_AMD_APU)
// GPUUUIDSize is the buffer size expected by amdsmi_get_gpu_device_uuid.
GPUUUIDSize = uint(C.GO_AMDSMI_GPU_UUID_SIZE)
)
// Init initializes the AMD SMI library with GPUs.
// It first dlopens libamd_smi.so and resolves all required symbols; the Go
// error is non-nil only when that load fails. It then calls amdsmi_init
// with flags 0 and reports the C result through the returned Status.
func Init() (Status, error) {
if C.load_amdsmi_library() == 0 {
return Status{}, fmt.Errorf("failed to load AMD SMI library")
}
ret := C.call_amdsmi_init(0)
return Status{
Code: StatusCode(ret),
}, nil
}
// Shutdown shuts down the AMD SMI library.
// It calls amdsmi_shut_down, then dlcloses the shared library (which also
// resets the cached C function pointers). The returned Status reflects the
// amdsmi_shut_down result; the unload itself cannot fail here.
func Shutdown() Status {
ret := C.call_amdsmi_shut_down()
C.unload_amdsmi_library()
return Status{
Code: StatusCode(ret),
}
}
// GetSocketHandles returns the socket handles of the GPUs.
// It uses the usual count-then-fill pattern: the first C call with a nil
// buffer only reports the socket count, the second fills the buffer.
func GetSocketHandles() ([]SocketHandle, Status) {
var socketCount C.uint32_t
ret := C.call_amdsmi_get_socket_handles(&socketCount, nil)
if ret != C.AMDSMI_STATUS_SUCCESS {
return nil, Status{Code: StatusCode(ret), Message: "get socket count"}
}
// No sockets is reported as success with an explanatory message.
if socketCount == 0 {
return nil, Status{Code: StatusSuccess, Message: "no socket found"}
}
sockets := make([]C.amdsmi_socket_handle, socketCount)
// NOTE(review): assumes the socket count does not grow between the two
// calls — TODO confirm against the AMD SMI documentation.
ret = C.call_amdsmi_get_socket_handles(&socketCount, (*C.amdsmi_socket_handle)(unsafe.Pointer(&sockets[0])))
if ret != C.AMDSMI_STATUS_SUCCESS {
return nil, Status{Code: StatusCode(ret), Message: "get socket handles"}
}
// Wrap the raw C handles in the Go-side SocketHandle type.
goSockets := make([]SocketHandle, socketCount)
for i, socket := range sockets {
goSockets[i] = SocketHandle{handle: unsafe.Pointer(socket)}
}
return goSockets, Status{Code: StatusSuccess}
}
// GetSocketName retrieves the socket name for a given socket handle.
// maxLen is the size of the buffer handed to the C library and must be
// positive; a non-positive value yields an invalid-argument Status.
func GetSocketName(socketHandle SocketHandle, maxLen int) (string, Status) {
	// Guard: with maxLen <= 0 the &name[0] access below would panic on an
	// empty slice before the C call is ever made.
	if maxLen <= 0 {
		return "", Status{Code: StatusInval, Message: "socket name max length must be positive"}
	}
	name := make([]C.char, maxLen)
	ret := C.call_amdsmi_get_socket_info(C.amdsmi_socket_handle(socketHandle.handle), C.size_t(maxLen), (*C.char)(unsafe.Pointer(&name[0])))
	if ret != C.AMDSMI_STATUS_SUCCESS {
		return "", Status{Code: StatusCode(ret), Message: "get socket info"}
	}
	// GoString stops at the first NUL written by the C library.
	socketInfo := C.GoString(&name[0])
	return socketInfo, Status{Code: StatusSuccess}
}
// GetProcessorHandles retrieves all processor handles for a given socket.
// It follows the same count-then-fill pattern as GetSocketHandles.
func GetProcessorHandles(socket SocketHandle) ([]ProcessorHandle, Status) {
	var processorCount C.uint32_t
	// First call with a nil buffer retrieves only the count.
	ret := C.call_amdsmi_get_processor_handles(C.amdsmi_socket_handle(socket.handle), &processorCount, nil)
	if ret != C.AMDSMI_STATUS_SUCCESS {
		return nil, Status{Code: StatusCode(ret), Message: "get processor count"}
	}
	// Guard the empty case: without it, &processors[0] below panics on a
	// zero-length slice. This mirrors the zero-count check GetSocketHandles
	// already performs.
	if processorCount == 0 {
		return nil, Status{Code: StatusSuccess, Message: "no processor found"}
	}
	processors := make([]C.amdsmi_processor_handle, processorCount)
	ret = C.call_amdsmi_get_processor_handles(C.amdsmi_socket_handle(socket.handle), &processorCount, (*C.amdsmi_processor_handle)(unsafe.Pointer(&processors[0])))
	if ret != C.AMDSMI_STATUS_SUCCESS {
		return nil, Status{Code: StatusCode(ret), Message: "get processor handles"}
	}
	// Wrap the raw C handles in the Go-side ProcessorHandle type.
	goProcessors := make([]ProcessorHandle, processorCount)
	for i, processor := range processors {
		goProcessors[i] = ProcessorHandle{handle: unsafe.Pointer(processor)}
	}
	return goProcessors, Status{Code: StatusSuccess}
}
// GetProcessorType retrieves the type of a given processor.
// The result maps onto the ProcessorType* vars declared in this package.
func GetProcessorType(processor ProcessorHandle) (ProcessorType, Status) {
var processorType C.processor_type_t
ret := C.call_amdsmi_get_processor_type(C.amdsmi_processor_handle(processor.handle), &processorType)
if ret != C.AMDSMI_STATUS_SUCCESS {
return 0, Status{Code: StatusCode(ret), Message: "get processor type"}
}
return ProcessorType(processorType), Status{Code: StatusSuccess}
}
// GetGPUBoardInfo retrieves the board information for a given GPU processor handle.
// Each string field is converted from a NUL-terminated C char array with
// C.GoString.
func GetGPUBoardInfo(processor ProcessorHandle) (BoardInfo, Status) {
var boardInfo C.amdsmi_board_info_t
ret := C.call_amdsmi_get_gpu_board_info(C.amdsmi_processor_handle(processor.handle), &boardInfo)
if ret != C.AMDSMI_STATUS_SUCCESS {
return BoardInfo{}, Status{Code: StatusCode(ret), Message: "get GPU board info"}
}
goBoardInfo := BoardInfo{
ModelNumber: C.GoString(&boardInfo.model_number[0]),
ProductSerial: C.GoString(&boardInfo.product_serial[0]),
FruID: C.GoString(&boardInfo.fru_id[0]),
ProductName: C.GoString(&boardInfo.product_name[0]),
ManufacturerName: C.GoString(&boardInfo.manufacturer_name[0]),
}
return goBoardInfo, Status{Code: StatusSuccess}
}
// GetGPUID retrieves the GPU ID for a given processor handle.
// The C API yields a uint16; it is widened to uint32 for the Go API.
func GetGPUID(processor ProcessorHandle) (uint32, Status) {
var gpuID C.uint16_t
ret := C.call_amdsmi_get_gpu_id(C.amdsmi_processor_handle(processor.handle), &gpuID)
if ret != C.AMDSMI_STATUS_SUCCESS {
return 0, Status{Code: StatusCode(ret), Message: "get GPU ID"}
}
return uint32(gpuID), Status{Code: StatusSuccess}
}
// GetGPUBDFID retrieves the GPU BDF ID for a given processor handle.
// (BDF = PCI Bus/Device/Function identifier, packed into a uint64 by the
// C library.)
func GetGPUBDFID(processor ProcessorHandle) (uint64, Status) {
var bdfID C.uint64_t
ret := C.call_amdsmi_get_gpu_bdf_id(C.amdsmi_processor_handle(processor.handle), &bdfID)
if ret != C.AMDSMI_STATUS_SUCCESS {
return 0, Status{Code: StatusCode(ret), Message: "get GPU BDF ID"}
}
return uint64(bdfID), Status{Code: StatusSuccess}
}
// GetGPUUUID retrieves the GPU UUID for a given processor handle.
// The buffer is sized with the library-defined GPUUUIDSize constant (from
// GO_AMDSMI_GPU_UUID_SIZE in the cgo preamble) instead of a magic number,
// so it stays in sync with the C headers.
func GetGPUUUID(processor ProcessorHandle) (string, Status) {
	length := C.uint(GPUUUIDSize)
	uuid := make([]C.char, GPUUUIDSize)
	ret := C.call_amdsmi_get_gpu_device_uuid(C.amdsmi_processor_handle(processor.handle), &length, (*C.char)(unsafe.Pointer(&uuid[0])))
	if ret != C.AMDSMI_STATUS_SUCCESS {
		return "", Status{Code: StatusCode(ret), Message: "get GPU UUID"}
	}
	// GoString stops at the NUL terminator written by the C library.
	return C.GoString(&uuid[0]), Status{Code: StatusSuccess}
}
// GetGPUVRAM retrieves the GPU VRAM stats for a given processor handle.
// NOTE(review): vram_total/vram_used are narrowed to uint32 here; the units
// are whatever amdsmi_get_gpu_vram_usage reports (presumably MB — TODO
// confirm; byte counts would overflow uint32 on large cards).
func GetGPUVRAM(processor ProcessorHandle) (VRAM, Status) {
var vramUsage C.amdsmi_vram_usage_t
ret := C.call_amdsmi_get_gpu_vram_usage(C.amdsmi_processor_handle(processor.handle), &vramUsage)
if ret != C.AMDSMI_STATUS_SUCCESS {
return VRAM{}, Status{Code: StatusCode(ret), Message: "get GPU VRAM usage"}
}
goVRAM := VRAM{
Total: uint32(vramUsage.vram_total),
Used: uint32(vramUsage.vram_used),
}
return goVRAM, Status{Code: StatusSuccess}
}
// GetGPUASICInfo retrieves the ASIC information for a given GPU processor handle.
// This includes the number of compute units (GPU cores).
// Returns an error status if the function is not available in the loaded library.
func GetGPUASICInfo(processor ProcessorHandle) (ASICInfo, Status) {
// amdsmi_get_gpu_asic_info is loaded optionally (older libraries lack it),
// so check availability before calling through the wrapper.
if C.is_gpu_asic_info_available() == 0 {
return ASICInfo{}, Status{Code: StatusNotSupported, Message: "amdsmi_get_gpu_asic_info not available in loaded library"}
}
var asicInfo C.amdsmi_asic_info_t
ret := C.call_amdsmi_get_gpu_asic_info(C.amdsmi_processor_handle(processor.handle), &asicInfo)
if ret != C.AMDSMI_STATUS_SUCCESS {
return ASICInfo{}, Status{Code: StatusCode(ret), Message: "get GPU ASIC info"}
}
// Copy the C struct field-by-field into the Go-side ASICInfo.
goASICInfo := ASICInfo{
MarketName: C.GoString(&asicInfo.market_name[0]),
VendorID: uint32(asicInfo.vendor_id),
VendorName: C.GoString(&asicInfo.vendor_name[0]),
SubvendorID: uint32(asicInfo.subvendor_id),
DeviceID: uint64(asicInfo.device_id),
RevID: uint32(asicInfo.rev_id),
ASICSerial: C.GoString(&asicInfo.asic_serial[0]),
OAMID: uint32(asicInfo.oam_id),
NumComputeUnits: uint32(asicInfo.num_of_compute_units),
TargetGfxVersion: uint64(asicInfo.target_graphics_version),
}
return goASICInfo, Status{Code: StatusSuccess}
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
//go:build linux && (amd64 || 386)
package amdsmi
import "fmt"
// StatusCode is a Go type to represent amdsmi_status_t from C.
type StatusCode uint32

const (
	// General success and error codes
	StatusSuccess           StatusCode = 0  // AMDSMI_STATUS_SUCCESS
	StatusInval             StatusCode = 1  // AMDSMI_STATUS_INVAL
	StatusNotSupported      StatusCode = 2  // AMDSMI_STATUS_NOT_SUPPORTED
	StatusNotYetImplemented StatusCode = 3  // AMDSMI_STATUS_NOT_YET_IMPLEMENTED
	StatusFailLoadModule    StatusCode = 4  // AMDSMI_STATUS_FAIL_LOAD_MODULE
	StatusFailLoadSymbol    StatusCode = 5  // AMDSMI_STATUS_FAIL_LOAD_SYMBOL
	StatusDRMError          StatusCode = 6  // AMDSMI_STATUS_DRM_ERROR
	StatusAPIFailed         StatusCode = 7  // AMDSMI_STATUS_API_FAILED
	StatusTimeout           StatusCode = 8  // AMDSMI_STATUS_TIMEOUT
	StatusRetry             StatusCode = 9  // AMDSMI_STATUS_RETRY
	StatusNoPerm            StatusCode = 10 // AMDSMI_STATUS_NO_PERM
	StatusInterrupt         StatusCode = 11 // AMDSMI_STATUS_INTERRUPT
	StatusIO                StatusCode = 12 // AMDSMI_STATUS_IO
	StatusAddressFault      StatusCode = 13 // AMDSMI_STATUS_ADDRESS_FAULT
	StatusFileError         StatusCode = 14 // AMDSMI_STATUS_FILE_ERROR
	StatusOutOfResources    StatusCode = 15 // AMDSMI_STATUS_OUT_OF_RESOURCES
	StatusInternalException StatusCode = 16 // AMDSMI_STATUS_INTERNAL_EXCEPTION
	StatusInputOutOfBounds  StatusCode = 17 // AMDSMI_STATUS_INPUT_OUT_OF_BOUNDS
	StatusInitError         StatusCode = 18 // AMDSMI_STATUS_INIT_ERROR
	StatusRefcountOverflow  StatusCode = 19 // AMDSMI_STATUS_REFCOUNT_OVERFLOW
	// Device related errors (starting from 30)
	StatusBusy            StatusCode = 30 // AMDSMI_STATUS_BUSY
	StatusNotFound        StatusCode = 31 // AMDSMI_STATUS_NOT_FOUND
	StatusNotInit         StatusCode = 32 // AMDSMI_STATUS_NOT_INIT
	StatusNoSlot          StatusCode = 33 // AMDSMI_STATUS_NO_SLOT
	StatusDriverNotLoaded StatusCode = 34 // AMDSMI_STATUS_DRIVER_NOT_LOADED
	// Data and size errors (starting from 40)
	StatusNoData           StatusCode = 40 // AMDSMI_STATUS_NO_DATA
	StatusInsufficientSize StatusCode = 41 // AMDSMI_STATUS_INSUFFICIENT_SIZE
	StatusUnexpectedSize   StatusCode = 42 // AMDSMI_STATUS_UNEXPECTED_SIZE
	StatusUnexpectedData   StatusCode = 43 // AMDSMI_STATUS_UNEXPECTED_DATA
	// General errors with specific values
	StatusMapError     StatusCode = 0xFFFFFFFE // AMDSMI_STATUS_MAP_ERROR
	StatusUnknownError StatusCode = 0xFFFFFFFF // AMDSMI_STATUS_UNKNOWN_ERROR
)

// statusCodeNames maps every known status code to its short printable name.
var statusCodeNames = map[StatusCode]string{
	StatusSuccess:           "Success",
	StatusInval:             "Inval",
	StatusNotSupported:      "NotSupported",
	StatusNotYetImplemented: "NotYetImplemented",
	StatusFailLoadModule:    "FailLoadModule",
	StatusFailLoadSymbol:    "FailLoadSymbol",
	StatusDRMError:          "DRMError",
	StatusAPIFailed:         "APIFailed",
	StatusTimeout:           "Timeout",
	StatusRetry:             "Retry",
	StatusNoPerm:            "NoPerm",
	StatusInterrupt:         "Interrupt",
	StatusIO:                "IO",
	StatusAddressFault:      "AddressFault",
	StatusFileError:         "FileError",
	StatusOutOfResources:    "OutOfResources",
	StatusInternalException: "InternalException",
	StatusInputOutOfBounds:  "InputOutOfBounds",
	StatusInitError:         "InitError",
	StatusRefcountOverflow:  "RefcountOverflow",
	StatusBusy:              "Busy",
	StatusNotFound:          "NotFound",
	StatusNotInit:           "NotInit",
	StatusNoSlot:            "NoSlot",
	StatusDriverNotLoaded:   "DriverNotLoaded",
	StatusNoData:            "NoData",
	StatusInsufficientSize:  "InsufficientSize",
	StatusUnexpectedSize:    "UnexpectedSize",
	StatusUnexpectedData:    "UnexpectedData",
	StatusMapError:          "MapError",
	StatusUnknownError:      "UnknownError",
}

// String returns the string representation of the StatusCode.
// Unknown codes are rendered as "AMDSMIStatusCode(<n>)".
func (code StatusCode) String() string {
	if name, ok := statusCodeNames[code]; ok {
		return name
	}
	return fmt.Sprintf("AMDSMIStatusCode(%d)", code)
}

// Status a wrapper around StatusCode and a message.
type Status struct {
	Code    StatusCode
	Message string
}

// Error returns the error message if the status is not success,
// and nil when the status code is StatusSuccess.
func (s Status) Error() error {
	if s.Code != StatusSuccess {
		return fmt.Errorf("%s %s(%d)", s.Message, s.Code.String(), s.Code)
	}
	return nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
//go:build linux && (amd64 || 386)
package gpu
import (
"fmt"
"sync"
"gitlab.com/nunet/device-management-service/observability"
"gitlab.com/nunet/device-management-service/types"
)
// ErrNotInitialised is returned by gpuManager methods once Shutdown has run
// (or if the manager was never initialised).
var ErrNotInitialised = fmt.Errorf("GPU manager not initialised")
// gpuConnectors is a struct that holds the GPU connectors for different vendors.
// Any field may be nil when the corresponding connector failed to initialise
// (see NewGPUManager).
type gpuConnectors struct {
nvidia types.GPUConnector
amd types.GPUConnector
intel types.GPUConnector
}
// gpuManager implements the types.GPUManager interface.
type gpuManager struct {
initialised bool
gpuIndexCache map[string]int // UUID to gpuCache index map
gpuCache types.GPUs // Cache of GPUs
connectors gpuConnectors
// lock guards gpuCache and gpuIndexCache.
lock sync.RWMutex
}
// Compile-time check that gpuManager satisfies types.GPUManager.
var _ types.GPUManager = &gpuManager{}
// NewGPUManager creates a new GPU manager.
// Connector construction failures are only logged at debug level and leave
// the corresponding connector nil, so downstream code must tolerate nil
// connectors. The manager is marked initialised regardless of how many
// vendor connectors actually came up.
func NewGPUManager() types.GPUManager {
nvidiaConnector, err := newNVIDIAGPUConnector()
if err != nil {
log.Debugw("could not create NVIDIA GPU connector",
"labels", string(observability.LabelNode),
"error", err)
}
amdConnector, err := newAMDGPUConnector()
if err != nil {
log.Debugw("could not create AMD GPU connector",
"labels", string(observability.LabelNode),
"error", err)
}
intelConnector, err := newIntelGPUConnector()
if err != nil {
log.Debugw("could not create Intel GPU connector",
"labels", string(observability.LabelNode),
"error", err)
}
connector := gpuConnectors{
nvidia: nvidiaConnector,
amd: amdConnector,
intel: intelConnector,
}
return &gpuManager{
initialised: true,
gpuIndexCache: make(map[string]int),
connectors: connector,
}
}
// getGPUs is a helper function to get the GPUs from the device managers.
// Nil connectors (vendors that failed to initialise) are skipped; the first
// connector error aborts the collection.
func (g *gpuManager) getGPUs() (types.GPUs, error) {
	var collected []types.GPU
	// gather appends one vendor's GPUs to the collection, skipping nil
	// connectors and wrapping errors with the vendor name.
	gather := func(conn types.GPUConnector, vendor string) error {
		if conn == nil {
			return nil
		}
		found, err := conn.GetGPUs()
		if err != nil {
			return fmt.Errorf("get %s GPUs: %w", vendor, err)
		}
		collected = append(collected, found...)
		return nil
	}
	if err := gather(g.connectors.nvidia, "NVIDIA"); err != nil {
		return nil, err
	}
	if err := gather(g.connectors.amd, "AMD"); err != nil {
		return nil, err
	}
	if err := gather(g.connectors.intel, "Intel"); err != nil {
		return nil, err
	}
	return collected, nil
}
// GetGPUs returns the GPUs in the system.
// The first successful enumeration is cached; later calls serve a copy of
// the cache. It also populates gpuIndexCache (UUID -> cache index) used by
// GetGPUUsage.
func (g *gpuManager) GetGPUs() (types.GPUs, error) {
	if !g.initialised {
		return nil, ErrNotInitialised
	}
	// Fast path: serve from the cache under the read lock. The copy is made
	// while the lock is still held so a concurrent writer cannot race it.
	g.lock.RLock()
	if len(g.gpuCache) > 0 {
		defer g.lock.RUnlock()
		return g.gpuCache.Copy(), nil
	}
	g.lock.RUnlock()

	g.lock.Lock()
	defer g.lock.Unlock()
	// Re-check: another goroutine may have populated the cache between the
	// RUnlock above and acquiring the write lock.
	if len(g.gpuCache) > 0 {
		return g.gpuCache.Copy(), nil
	}
	gpus, err := g.getGPUs()
	if err != nil {
		// BUG FIX: the previous code also called g.lock.Unlock() explicitly
		// here; combined with the deferred Unlock that double-unlocked the
		// mutex and panicked on every enumeration error.
		return nil, fmt.Errorf("get GPUs: %w", err)
	}
	// Assign index to GPUs
	// Note: The index is internal to dms and is not the same as the device index
	gpus = assignIndexToGPUs(gpus)
	// Cache the GPUs and the UUID -> index lookup table.
	g.gpuCache = gpus
	for i, gpu := range gpus {
		g.gpuIndexCache[gpu.UUID] = i
	}
	return g.gpuCache.Copy(), nil
}
// getGPUUsage is a helper function to get the GPU usage based on the vendor.
// Connectors can legitimately be nil (NewGPUManager logs and continues when
// a vendor connector fails to initialise), so each branch guards against a
// nil connector instead of dereferencing it and panicking.
func (g *gpuManager) getGPUUsage(uuid string, vendor types.GPUVendor) (uint64, error) {
	switch vendor {
	case types.GPUVendorNvidia:
		if g.connectors.nvidia == nil {
			return 0, fmt.Errorf("nvidia gpu connector not initialised")
		}
		usage, err := g.connectors.nvidia.GetGPUUsage(uuid)
		if err != nil {
			return 0, fmt.Errorf("get nvidia gpu usage: %w", err)
		}
		return usage, nil
	case types.GPUVendorAMDATI:
		if g.connectors.amd == nil {
			return 0, fmt.Errorf("amd gpu connector not initialised")
		}
		usage, err := g.connectors.amd.GetGPUUsage(uuid)
		if err != nil {
			return 0, fmt.Errorf("get amd gpu usage: %w", err)
		}
		return usage, nil
	case types.GPUVendorIntel:
		if g.connectors.intel == nil {
			return 0, fmt.Errorf("intel gpu connector not initialised")
		}
		usage, err := g.connectors.intel.GetGPUUsage(uuid)
		if err != nil {
			return 0, fmt.Errorf("get intel gpu usage: %w", err)
		}
		return usage, nil
	default:
		return 0, fmt.Errorf("unsupported vendor")
	}
}
// GetGPUUsage returns the GPU usage based on the specified uuid.
// if uuid is empty, it returns the usage of all GPUs.
// NOTE(review): the returned GPUs carry the vendor-reported usage value in
// their VRAM field (see the assignment below), overwriting whatever the
// field held before — confirm callers expect usage there, not capacity.
func (g *gpuManager) GetGPUUsage(uuid ...string) (types.GPUs, error) {
if !g.initialised {
return nil, ErrNotInitialised
}
// Check if there are gpus
// GetGPUs also (re)populates gpuCache and gpuIndexCache, which the
// UUID-lookup branch below relies on.
gpuCache, err := g.GetGPUs()
if err != nil {
return nil, fmt.Errorf("get gpus: %w", err)
}
if len(gpuCache) == 0 {
return nil, nil
}
// Get the GPUs based on the UUID
var gpus []types.GPU
if len(uuid) == 0 {
// copy the GPU cache
// (gpuCache is already a copy returned by GetGPUs, so mutating the
// elements below does not touch the manager's cache.)
gpus = gpuCache
} else {
// Get the GPUs based on the UUID
// Each entry is copied by value out of the live cache under the read
// lock; unknown UUIDs are silently skipped.
for _, u := range uuid {
g.lock.RLock()
if index, ok := g.gpuIndexCache[u]; ok {
gpus = append(gpus, g.gpuCache[index])
}
g.lock.RUnlock()
}
}
if len(gpus) == 0 {
return nil, fmt.Errorf("no GPUs found for the specified parameters")
}
for i, gpu := range gpus {
// Get the GPU usage based on the vendor
usage, err := g.getGPUUsage(gpu.UUID, gpu.Vendor)
if err != nil {
return nil, fmt.Errorf("get gpu usage: %w", err)
}
gpus[i].VRAM = usage
}
return gpus, nil
}
// Shutdown shuts down the GPU manager: it shuts down each vendor connector,
// drops the connectors, and clears the GPU caches. Connector shutdown
// failures are logged but do not abort the remaining shutdown steps.
// TODO: this results in a permanent shutdown of the GPU manager and we don't have a way to restart it yet
func (g *gpuManager) Shutdown() error {
	if !g.initialised {
		return ErrNotInitialised
	}
	g.lock.Lock()
	defer g.lock.Unlock()
	if g.connectors.nvidia != nil {
		if err := g.connectors.nvidia.Shutdown(); err != nil {
			log.Errorw("shutdown nvml",
				"labels", string(observability.LabelNode),
				"error", err)
		}
	}
	if g.connectors.amd != nil {
		if err := g.connectors.amd.Shutdown(); err != nil {
			log.Errorw("shutdown amdsmi",
				"labels", string(observability.LabelNode),
				"error", err)
		}
	}
	if g.connectors.intel != nil {
		if err := g.connectors.intel.Shutdown(); err != nil {
			// Typo fix: this log message previously read "shutdowm xpum".
			log.Errorw("shutdown xpum",
				"labels", string(observability.LabelNode),
				"error", err)
		}
	}
	g.connectors = gpuConnectors{}
	g.initialised = false
	// Clear the cache so no stale device list can be served.
	g.gpuCache = nil
	g.gpuIndexCache = nil
	return nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
//go:build linux && (amd64 || 386)
package gpu
import (
"fmt"
"strconv"
"gitlab.com/nunet/device-management-service/lib/hardware/gpu/xpum"
"gitlab.com/nunet/device-management-service/observability"
"gitlab.com/nunet/device-management-service/types"
)
// intelGPUConnector implements the types.GPUConnector interface for Intel GPUs.
type intelGPUConnector struct {
// deviceCache holds the basic info for every device discovered at
// construction time (populated by loadDevices).
deviceCache []xpum.DeviceBasicInfo
// deviceIndexMap maps a device UUID to its xpum DeviceID (loadDevices
// stores device.DeviceID here — NOT a position in deviceCache).
deviceIndexMap map[string]int32 // UUID to xpum DeviceID map
}
// Compile-time check that intelGPUConnector satisfies types.GPUConnector.
var _ types.GPUConnector = (*intelGPUConnector)(nil)
// newIntelGPUConnector creates a new Intel GPU Connector.
// It initialises the XPUM library and preloads the device list; any
// failure in either step aborts construction.
func newIntelGPUConnector() (types.GPUConnector, error) {
	c := &intelGPUConnector{deviceIndexMap: map[string]int32{}}
	if err := c.initialize(); err != nil {
		return nil, fmt.Errorf("initialize Intel GPU connector: %w", err)
	}
	if err := c.loadDevices(); err != nil {
		return nil, fmt.Errorf("load Intel GPU devices: %w", err)
	}
	return c, nil
}
// initialize loads the Intel XPUM library and initializes the connector.
// Both a Go-level error and a non-OK xpum result code are treated as
// initialisation failures.
func (i *intelGPUConnector) initialize() error {
	ret, err := xpum.Init()
	switch {
	case err != nil:
		return fmt.Errorf("initialize Intel XPUM: %w", err)
	case ret.Code != xpum.ResultOk:
		return fmt.Errorf("intel XPUM initialization was unsuccessful: %w", ret.Error())
	}
	return nil
}
// loadDevices loads the Intel GPU devices.
// It caches each device's basic info and records a UUID -> DeviceID
// mapping. Note that the map value is the xpum DeviceID, not the device's
// position in deviceCache.
func (i *intelGPUConnector) loadDevices() error {
deviceList, ret := xpum.GetDeviceList()
if ret.Code != xpum.ResultOk {
return fmt.Errorf("get Intel GPU device list: %w", ret.Error())
}
for _, device := range deviceList {
i.deviceCache = append(i.deviceCache, device)
i.deviceIndexMap[device.UUID] = device.DeviceID
}
return nil
}
// getTotalVRAM returns the total VRAM in bytes for the device with the
// given deviceID, read from the xpum device property table.
func (i *intelGPUConnector) getTotalVRAM(deviceID int32) (uint64, error) {
	props, ret := xpum.GetDeviceProperties(deviceID)
	if ret.Code != xpum.ResultOk {
		return 0, fmt.Errorf("get properties for device %d: %w", deviceID, ret.Error())
	}
	// Scan for the physical-memory-size property; its value is a decimal
	// string of bytes.
	for idx := range props {
		if props[idx].Name != xpum.DevicePropertyMemoryPhysicalSizeByte {
			continue
		}
		total, err := strconv.ParseUint(props[idx].Value, 10, 64)
		if err != nil {
			return 0, fmt.Errorf("parse total memory for device %d: %w", deviceID, err)
		}
		return total, nil
	}
	return 0, fmt.Errorf("total memory property not found for device %d", deviceID)
}
// GetGPUs returns the Intel GPUs.
// For each cached device it looks up total VRAM (required — an error here
// aborts the whole call) and the EU count (optional — defaults to 0 when
// the property is missing or unparsable).
func (i *intelGPUConnector) GetGPUs() (types.GPUs, error) {
gpus := make(types.GPUs, 0, len(i.deviceCache))
for _, device := range i.deviceCache {
vram, err := i.getTotalVRAM(device.DeviceID)
if err != nil {
return nil, fmt.Errorf("get total VRAM for intel device %d: %w", device.DeviceID, err)
}
// Get EU (Execution Unit) count; default to 0 if unavailable
var cores uint32
deviceProps, ret := xpum.GetDeviceProperties(device.DeviceID)
if ret.Code == xpum.ResultOk {
for _, prop := range deviceProps {
if prop.Name == xpum.DevicePropertyNumberOfEUs {
euCount, err := strconv.ParseUint(prop.Value, 10, 32)
if err == nil {
cores = uint32(euCount)
} else {
// Parse failures are logged but non-fatal; cores stays 0.
log.Debugw("could not parse EU count for Intel GPU",
"labels", string(observability.LabelNode),
"error", err, "deviceID", device.DeviceID)
}
break
}
}
}
gpu := types.GPU{
UUID: device.UUID,
Model: device.DeviceName,
VRAM: vram, // VRAM in bytes
Cores: cores,
Vendor: types.GPUVendorIntel,
PCIAddress: device.PCIBDFAddress,
}
gpus = append(gpus, gpu)
}
return gpus, nil
}
// GetGPUUsage returns the used GPU memory (the xpum StatsMemoryUsed metric)
// for the device with the given UUID.
func (i *intelGPUConnector) GetGPUUsage(uuid string) (uint64, error) {
	// deviceIndexMap maps UUID -> xpum DeviceID (see loadDevices).
	// BUG FIX: the previous code used this value as an index into
	// deviceCache, but it is a device ID, not a slice position — with
	// non-dense or non-zero-based IDs that panics or reads the wrong
	// device. The ID is what GetDeviceStats needs, so use it directly.
	deviceID, ok := i.deviceIndexMap[uuid]
	if !ok {
		return 0, fmt.Errorf("intel device with UUID %s not found", uuid)
	}
	stats, err := xpum.GetDeviceStats(deviceID, 0)
	if err != nil {
		return 0, fmt.Errorf("getting device stats for %d: %w", deviceID, err)
	}
	// Scan the stats groups for the used-memory metric.
	for _, stat := range stats {
		for _, data := range stat.DataList {
			if data.MetricsType == xpum.StatsMemoryUsed {
				return data.Value, nil
			}
		}
	}
	return 0, fmt.Errorf("used memory not found for device %d", deviceID)
}
// Shutdown shuts down the Intel GPU connector and drops the cached
// device data.
func (i *intelGPUConnector) Shutdown() error {
	if ret := xpum.Shutdown(); ret.Code != xpum.ResultOk {
		return fmt.Errorf("shutdown Intel XPUM: %w", ret.Error())
	}
	// Release the caches so stale device data cannot be served later.
	i.deviceCache, i.deviceIndexMap = nil, nil
	return nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
//go:build linux && (amd64 || 386)
package gpu
import (
"errors"
"fmt"
"github.com/NVIDIA/go-nvml/pkg/nvml"
"gitlab.com/nunet/device-management-service/types"
)
// nvidiaGPUConnector implements the types.GPUConnector interface for NVIDIA GPUs.
type nvidiaGPUConnector struct {
// deviceCache holds the NVML device handles discovered by loadDevices.
deviceCache []nvml.Device
// deviceCacheIndexMap maps a device UUID to its index in deviceCache
// (which is also the NVML device index used during enumeration).
deviceCacheIndexMap map[string]int // UUID to deviceCache index map
}
// Compile-time check that nvidiaGPUConnector satisfies types.GPUConnector.
var _ types.GPUConnector = (*nvidiaGPUConnector)(nil)
// newNVIDIAGPUConnector creates a new NVIDIA GPU Connector.
// It initialises NVML and preloads the device handles; any failure in
// either step aborts construction.
func newNVIDIAGPUConnector() (types.GPUConnector, error) {
	c := &nvidiaGPUConnector{deviceCacheIndexMap: map[string]int{}}
	if err := c.initialise(); err != nil {
		return nil, fmt.Errorf("initialize NVIDIA GPU connector: %w", err)
	}
	if err := c.loadDevices(); err != nil {
		return nil, fmt.Errorf("load NVIDIA GPU devices: %w", err)
	}
	return c, nil
}
// getNVIDIADeviceCount returns the number of NVIDIA devices (GPUs).
func getNVIDIADeviceCount() (int, error) {
deviceCount, ret := nvml.DeviceGetCount()
if !errors.Is(ret, nvml.SUCCESS) {
return 0, fmt.Errorf("get device count: %s", nvml.ErrorString(ret))
}
return deviceCount, nil
}
// getNVIDIADeviceHandle returns the handle for the NVIDIA device by its index.
func getNVIDIADeviceHandle(index int) (nvml.Device, error) {
device, ret := nvml.DeviceGetHandleByIndex(index)
if !errors.Is(ret, nvml.SUCCESS) {
return nil, fmt.Errorf("get device handle for device %d: %s", index, nvml.ErrorString(ret))
}
return device, nil
}
// getNVIDIADeviceName returns the name of the NVIDIA device.
func getNVIDIADeviceName(device nvml.Device) (string, error) {
name, ret := device.GetName()
if !errors.Is(ret, nvml.SUCCESS) {
return "", fmt.Errorf("get name for device: %s", nvml.ErrorString(ret))
}
return name, nil
}
// getNVIDIADeviceMemory returns the memory information for the NVIDIA device.
func getNVIDIADeviceMemory(device nvml.Device) (nvml.Memory, error) {
	info, status := device.GetMemoryInfo()
	if errors.Is(status, nvml.SUCCESS) {
		return info, nil
	}
	return nvml.Memory{}, fmt.Errorf("get NVIDIA GPU memory info: %s", nvml.ErrorString(status))
}
// getNVIDIADeviceUUID returns the UUID of the NVIDIA device.
func getNVIDIADeviceUUID(device nvml.Device) (string, error) {
	id, status := device.GetUUID()
	if errors.Is(status, nvml.SUCCESS) {
		return id, nil
	}
	return "", fmt.Errorf("get UUID for device: %s", nvml.ErrorString(status))
}
// getNVIDIACoreCount returns the number of GPU cores (CUDA cores) for the NVIDIA device.
func getNVIDIACoreCount(device nvml.Device) (int, error) {
	total, status := device.GetNumGpuCores()
	if errors.Is(status, nvml.SUCCESS) {
		return total, nil
	}
	return 0, fmt.Errorf("get GPU core count: %s", nvml.ErrorString(status))
}
// getNVIDIAPCIAddress returns the PCI address for the NVIDIA device.
func getNVIDIAPCIAddress(device nvml.Device) (string, error) {
	info, status := device.GetPciInfo()
	if !errors.Is(status, nvml.SUCCESS) {
		return "", fmt.Errorf("get PCI info for device: %s", nvml.ErrorString(status))
	}
	// BusId is a fixed-size NUL-terminated char array; normalize it to the
	// conventional "domain:bus:device.function" form.
	return convertBusID(info.BusId), nil
}
// initialise initialises the NVIDIA Management Library (NVML) for this
// connector; it must succeed before any device query is made.
func (n *nvidiaGPUConnector) initialise() error {
	if status := nvml.Init(); !errors.Is(status, nvml.SUCCESS) {
		return fmt.Errorf("initialize nvml: %s", nvml.ErrorString(status))
	}
	return nil
}
// loadDevices enumerates all NVIDIA devices and fills the device cache and
// the UUID-to-index lookup map.
func (n *nvidiaGPUConnector) loadDevices() error {
	total, err := getNVIDIADeviceCount()
	if err != nil {
		return err
	}
	for idx := 0; idx < total; idx++ {
		dev, err := getNVIDIADeviceHandle(idx)
		if err != nil {
			return err
		}
		id, err := getNVIDIADeviceUUID(dev)
		if err != nil {
			return err
		}
		n.deviceCache = append(n.deviceCache, dev)
		n.deviceCacheIndexMap[id] = idx
	}
	return nil
}
// GetGPUs returns the GPU information for NVIDIA GPUs.
//
// It reads from the handles cached by loadDevices; an empty cache yields
// (nil, nil). A failure to read the core count is logged and reported as 0
// rather than failing the whole call.
func (n *nvidiaGPUConnector) GetGPUs() (types.GPUs, error) {
	if len(n.deviceCache) == 0 {
		return nil, nil
	}
	result := make(types.GPUs, 0, len(n.deviceCache))
	for i := range n.deviceCache {
		dev := n.deviceCache[i]
		model, err := getNVIDIADeviceName(dev)
		if err != nil {
			return nil, err
		}
		mem, err := getNVIDIADeviceMemory(dev)
		if err != nil {
			return nil, err
		}
		pci, err := getNVIDIAPCIAddress(dev)
		if err != nil {
			return nil, err
		}
		id, err := getNVIDIADeviceUUID(dev)
		if err != nil {
			return nil, err
		}
		// Core count is best-effort: default to 0 when unavailable.
		var coreTotal uint32
		if count, err := getNVIDIACoreCount(dev); err != nil {
			log.Debugw("could not get GPU core count, defaulting to 0",
				"error", err, "uuid", id)
		} else {
			coreTotal = uint32(count)
		}
		result = append(result, types.GPU{
			UUID:       id,
			PCIAddress: pci,
			Model:      model,
			VRAM:       mem.Total,
			Cores:      coreTotal,
			Vendor:     types.GPUVendorNvidia,
		})
	}
	return result, nil
}
// GetGPUUsage returns the used VRAM in bytes for the device with the given UUID.
func (n *nvidiaGPUConnector) GetGPUUsage(uuid string) (uint64, error) {
	idx, found := n.deviceCacheIndexMap[uuid]
	if !found {
		return 0, fmt.Errorf("nvidia device with UUID %s not found", uuid)
	}
	mem, err := getNVIDIADeviceMemory(n.deviceCache[idx])
	if err != nil {
		return 0, err
	}
	return mem.Used, nil
}
// Shutdown shuts down the NVIDIA Management Library and drops the cached
// device handles so the connector cannot be reused afterwards.
func (n *nvidiaGPUConnector) Shutdown() error {
	if status := nvml.Shutdown(); !errors.Is(status, nvml.SUCCESS) {
		return fmt.Errorf("shutdown nvml: %s", nvml.ErrorString(status))
	}
	// Clear the caches.
	n.deviceCache = nil
	n.deviceCacheIndexMap = nil
	return nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package gpu
import (
"fmt"
"strings"
"gitlab.com/nunet/device-management-service/types"
)
// assignIndexToGPUs assigns an index to each GPU in the list starting from 0.
// The slice is mutated in place and returned for convenience.
func assignIndexToGPUs(gpus []types.GPU) []types.GPU {
	for idx := 0; idx < len(gpus); idx++ {
		gpus[idx].Index = idx
	}
	return gpus
}
// bdfIDToPCIAddress converts a 64-bit BDFID to a standard PCI address string
// in the form 'domain:bus:device.function'.
//
// Bit layout, taken from the AMD SMI library:
//
//	BDFID = ((DOMAIN & 0xffffffff) << 32) | ((BUS & 0xff) << 8) |
//	        ((DEVICE & 0x1f) << 3) | (FUNCTION & 0x7)
//
//	| Name     | Field   |
//	| -------- | ------- |
//	| Domain   | [64:32] |
//	| Reserved | [31:16] |
//	| Bus      | [15: 8] |
//	| Device   | [ 7: 3] |
//	| Function | [ 2: 0] |
func bdfIDToPCIAddress(bdfID uint64) string {
	var (
		domain   = bdfID >> 32 & 0xFFFFFFFF // bits [63:32]
		bus      = bdfID >> 8 & 0xFF        // bits [15:8]
		device   = bdfID >> 3 & 0x1F        // bits [7:3]
		function = bdfID & 0x7              // bits [2:0]
	)
	// Hex-encode with conventional zero padding:
	// 4 digits for domain, 2 for bus and device, 1 for function.
	return fmt.Sprintf("%04X:%02X:%02X.%X", domain, bus, device, function)
}
// convertBusID converts the fixed-size NUL-padded BusId array into a
// correctly formatted PCI address string.
//
// Some drivers report an 8-hex-digit domain ("00000000:BB:DD.F"); that
// prefix is collapsed to the conventional 4-digit form ("0000:BB:DD.F").
func convertBusID(busID [32]int8) string {
	raw := make([]byte, 0, len(busID))
	for _, c := range busID {
		raw = append(raw, byte(c))
	}
	// Strip the trailing NUL padding of the C char array.
	addr := strings.TrimRight(string(raw), "\x00")
	if strings.HasPrefix(addr, "00000000:") {
		// Trim to the correct format: "0000:XX:YY.Z".
		return "0000" + addr[8:]
	}
	return addr
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
//go:build linux && (amd64 || 386)
package xpum
import "fmt"
// ResultCode is a Go type to represent xpum_result_t.
type ResultCode uint32

// The values below mirror the xpum_result_t enum of the Intel XPUM C API
// one-to-one. The iota ordering is load-bearing: do not reorder or insert
// entries without matching the C header.
const (
	ResultOk                                              ResultCode = iota // XPUM_OK = 0
	ResultGenericError                                                      // XPUM_GENERIC_ERROR = 1
	ResultBufferTooSmall                                                    // XPUM_BUFFER_TOO_SMALL = 2
	ResultDeviceNotFound                                                    // XPUM_RESULT_DEVICE_NOT_FOUND = 3
	ResultTileNotFound                                                      // XPUM_RESULT_TILE_NOT_FOUND = 4
	ResultGroupNotFound                                                     // XPUM_RESULT_GROUP_NOT_FOUND = 5
	ResultPolicyTypeInvalid                                                 // XPUM_RESULT_POLICY_TYPE_INVALID = 6
	ResultPolicyActionTypeInvalid                                           // XPUM_RESULT_POLICY_ACTION_TYPE_INVALID = 7
	ResultPolicyConditionTypeInvalid                                        // XPUM_RESULT_POLICY_CONDITION_TYPE_INVALID = 8
	ResultPolicyTypeActionNotSupport                                        // XPUM_RESULT_POLICY_TYPE_ACTION_NOT_SUPPORT = 9
	ResultPolicyTypeConditionNotSupport                                     // XPUM_RESULT_POLICY_TYPE_CONDITION_NOT_SUPPORT = 10
	ResultPolicyInvalidThreshold                                            // XPUM_RESULT_POLICY_INVALID_THRESHOLD = 11
	ResultPolicyInvalidFrequency                                            // XPUM_RESULT_POLICY_INVALID_FREQUENCY = 12
	ResultPolicyNotExist                                                    // XPUM_RESULT_POLICY_NOT_EXIST = 13
	ResultDiagnosticTaskNotComplete                                         // XPUM_RESULT_DIAGNOSTIC_TASK_NOT_COMPLETE = 14
	ResultDiagnosticTaskNotFound                                            // XPUM_RESULT_DIAGNOSTIC_TASK_NOT_FOUND = 15
	GroupDeviceDuplicated                                                   // XPUM_GROUP_DEVICE_DUPLICATED = 16
	GroupChangeNotAllowed                                                   // XPUM_GROUP_CHANGE_NOT_ALLOWED = 17
	NotInitialized                                                          // XPUM_NOT_INITIALIZED = 18
	DumpRawDataTaskNotExist                                                 // XPUM_DUMP_RAW_DATA_TASK_NOT_EXIST = 19
	DumpRawDataIllegalDumpFilePath                                          // XPUM_DUMP_RAW_DATA_ILLEGAL_DUMP_FILE_PATH = 20
	ResultUnknownAgentConfigKey                                             // XPUM_RESULT_UNKNOWN_AGENT_CONFIG_KEY = 21
	UpdateFirmwareImageFileNotFound                                         // XPUM_UPDATE_FIRMWARE_IMAGE_FILE_NOT_FOUND = 22
	UpdateFirmwareUnsupportedAmc                                            // XPUM_UPDATE_FIRMWARE_UNSUPPORTED_AMC = 23
	UpdateFirmwareUnsupportedAmcSingle                                      // XPUM_UPDATE_FIRMWARE_UNSUPPORTED_AMC_SINGLE = 24
	UpdateFirmwareUnsupportedGfxAll                                         // XPUM_UPDATE_FIRMWARE_UNSUPPORTED_GFX_ALL = 25
	UpdateFirmwareModelInconsistence                                        // XPUM_UPDATE_FIRMWARE_MODEL_INCONSISTENCE = 26
	UpdateFirmwareIgscNotFound                                              // XPUM_UPDATE_FIRMWARE_IGSC_NOT_FOUND = 27
	UpdateFirmwareTaskRunning                                               // XPUM_UPDATE_FIRMWARE_TASK_RUNNING = 28
	UpdateFirmwareInvalidFwImage                                            // XPUM_UPDATE_FIRMWARE_INVALID_FW_IMAGE = 29
	UpdateFirmwareFwImageNotCompatibleWithDevice                            // XPUM_UPDATE_FIRMWARE_FW_IMAGE_NOT_COMPATIBLE_WITH_DEVICE = 30
	ResultDumpMetricsTypeNotSupport                                         // XPUM_RESULT_DUMP_METRICS_TYPE_NOT_SUPPORT = 31
	MetricNotSupported                                                      // XPUM_METRIC_NOT_SUPPORTED = 32
	MetricNotEnabled                                                        // XPUM_METRIC_NOT_ENABLED = 33
	ResultHealthInvalidType                                                 // XPUM_RESULT_HEALTH_INVALID_TYPE = 34
	ResultHealthInvalidConfigType                                           // XPUM_RESULT_HEALTH_INVALID_CONIG_TYPE = 35
	ResultHealthInvalidThreshold                                            // XPUM_RESULT_HEALTH_INVALID_THRESHOLD = 36
	ResultDiagnosticInvalidLevel                                            // XPUM_RESULT_DIAGNOSTIC_INVALID_LEVEL = 37
	ResultDiagnosticInvalidTaskType                                         // XPUM_RESULT_DIAGNOSTIC_INVALID_TASK_TYPE = 38
	ResultAgentSetInvalidValue                                              // XPUM_RESULT_AGENT_SET_INVALID_VALUE = 39
	LevelZeroInitializationError                                            // XPUM_LEVEL_ZERO_INITIALIZATION_ERROR = 40
	UnsupportedSessionID                                                    // XPUM_UNSUPPORTED_SESSIONID = 41
	ResultMemoryEccLibNotSupport                                            // XPUM_RESULT_MEMORY_ECC_LIB_NOT_SUPPORT = 42
	UpdateFirmwareUnsupportedGfxData                                        // XPUM_UPDATE_FIRMWARE_UNSUPPORTED_GFX_DATA = 43
	UpdateFirmwareUnsupportedPsc                                            // XPUM_UPDATE_FIRMWARE_UNSUPPORTED_PSC = 44
	UpdateFirmwareUnsupportedPscIgsc                                        // XPUM_UPDATE_FIRMWARE_UNSUPPORTED_PSC_IGSC = 45
	UpdateFirmwareUnsupportedGfxCodeData                                    // XPUM_UPDATE_FIRMWARE_UNSUPPORTED_GFX_CODE_DATA = 46
	IntervalInvalid                                                         // XPUM_INTERVAL_INVALID = 47
	ResultFileDup                                                           // XPUM_RESULT_FILE_DUP = 48
	ResultInvalidDir                                                        // XPUM_RESULT_INVALID_DIR = 49
	ResultFwMgmtNotInit                                                     // XPUM_RESULT_FW_MGMT_NOT_INIT = 50
	VgpuInvalidLmem                                                         // XPUM_VGPU_INVALID_LMEM = 51
	VgpuInvalidNumVfs                                                       // XPUM_VGPU_INVALID_NUMVFS = 52
	VgpuDirtyPf                                                             // XPUM_VGPU_DIRTY_PF = 53
	VgpuVfUnsupportedOperation                                              // XPUM_VGPU_VF_UNSUPPORTED_OPERATION = 54
	VgpuCreateVfFailed                                                      // XPUM_VGPU_CREATE_VF_FAILED = 55
	VgpuRemoveVfFailed                                                      // XPUM_VGPU_REMOVE_VF_FAILED = 56
	VgpuNoConfigFile                                                        // XPUM_VGPU_NO_CONFIG_FILE = 57
	VgpuSysfsError                                                          // XPUM_VGPU_SYSFS_ERROR = 58
	VgpuUnsupportedDeviceModel                                              // XPUM_VGPU_UNSUPPORTED_DEVICE_MODEL = 59
	ResultResetFail                                                         // XPUM_RESULT_RESET_FAIL = 60
	APIUnsupported                                                          // XPUM_API_UNSUPPORTED = 61
	PrecheckInvalidSinceTime                                                // XPUM_PRECHECK_INVALID_SINCETIME = 62
	PprNotFound                                                             // XPUM_PPR_NOT_FOUND = 63
	UpdateFirmwareGfxDataImageVersionLowerOrEqualToDevice                   // XPUM_UPDATE_FIRMWARE_GFX_DATA_IMAGE_VERSION_LOWER_OR_EQUAL_TO_DEVICE = 64
	ResultUnsupportedDevice                                                 // XPUM_RESULT_UNSUPPORTED_DEVICE = 65
	GroupLimitReached                                                       // XPUM_GROUP_LIMIT_REACHED = 66
)
// String returns the string representation of the ResultCode without the
// 'Result' prefix. Codes outside the known range render as "ResultCode(<n>)".
func (code ResultCode) String() string {
	// Keyed array literal: every entry is indexed by its constant, so the
	// mapping cannot drift if the constant block is extended.
	names := [...]string{
		ResultOk:                             "Ok",
		ResultGenericError:                   "GenericError",
		ResultBufferTooSmall:                 "BufferTooSmall",
		ResultDeviceNotFound:                 "DeviceNotFound",
		ResultTileNotFound:                   "TileNotFound",
		ResultGroupNotFound:                  "GroupNotFound",
		ResultPolicyTypeInvalid:              "PolicyTypeInvalid",
		ResultPolicyActionTypeInvalid:        "PolicyActionTypeInvalid",
		ResultPolicyConditionTypeInvalid:     "PolicyConditionTypeInvalid",
		ResultPolicyTypeActionNotSupport:     "PolicyTypeActionNotSupport",
		ResultPolicyTypeConditionNotSupport:  "PolicyTypeConditionNotSupport",
		ResultPolicyInvalidThreshold:         "PolicyInvalidThreshold",
		ResultPolicyInvalidFrequency:         "PolicyInvalidFrequency",
		ResultPolicyNotExist:                 "PolicyNotExist",
		ResultDiagnosticTaskNotComplete:      "DiagnosticTaskNotComplete",
		ResultDiagnosticTaskNotFound:         "DiagnosticTaskNotFound",
		GroupDeviceDuplicated:                "GroupDeviceDuplicated",
		GroupChangeNotAllowed:                "GroupChangeNotAllowed",
		NotInitialized:                       "NotInitialized",
		DumpRawDataTaskNotExist:              "DumpRawDataTaskNotExist",
		DumpRawDataIllegalDumpFilePath:       "DumpRawDataIllegalDumpFilePath",
		ResultUnknownAgentConfigKey:          "UnknownAgentConfigKey",
		UpdateFirmwareImageFileNotFound:      "UpdateFirmwareImageFileNotFound",
		UpdateFirmwareUnsupportedAmc:         "UpdateFirmwareUnsupportedAmc",
		UpdateFirmwareUnsupportedAmcSingle:   "UpdateFirmwareUnsupportedAmcSingle",
		UpdateFirmwareUnsupportedGfxAll:      "UpdateFirmwareUnsupportedGfxAll",
		UpdateFirmwareModelInconsistence:     "UpdateFirmwareModelInconsistence",
		UpdateFirmwareIgscNotFound:           "UpdateFirmwareIgscNotFound",
		UpdateFirmwareTaskRunning:            "UpdateFirmwareTaskRunning",
		UpdateFirmwareInvalidFwImage:         "UpdateFirmwareInvalidFwImage",
		UpdateFirmwareFwImageNotCompatibleWithDevice: "UpdateFirmwareFwImageNotCompatibleWithDevice",
		ResultDumpMetricsTypeNotSupport:      "DumpMetricsTypeNotSupport",
		MetricNotSupported:                   "MetricNotSupported",
		MetricNotEnabled:                     "MetricNotEnabled",
		ResultHealthInvalidType:              "HealthInvalidType",
		ResultHealthInvalidConfigType:        "HealthInvalidConfigType",
		ResultHealthInvalidThreshold:         "HealthInvalidThreshold",
		ResultDiagnosticInvalidLevel:         "DiagnosticInvalidLevel",
		ResultDiagnosticInvalidTaskType:      "DiagnosticInvalidTaskType",
		ResultAgentSetInvalidValue:           "AgentSetInvalidValue",
		LevelZeroInitializationError:         "LevelZeroInitializationError",
		UnsupportedSessionID:                 "UnsupportedSessionId",
		ResultMemoryEccLibNotSupport:         "MemoryEccLibNotSupport",
		UpdateFirmwareUnsupportedGfxData:     "UpdateFirmwareUnsupportedGfxData",
		UpdateFirmwareUnsupportedPsc:         "UpdateFirmwareUnsupportedPsc",
		UpdateFirmwareUnsupportedPscIgsc:     "UpdateFirmwareUnsupportedPscIgsc",
		UpdateFirmwareUnsupportedGfxCodeData: "UpdateFirmwareUnsupportedGfxCodeData",
		IntervalInvalid:                      "IntervalInvalid",
		ResultFileDup:                        "FileDup",
		ResultInvalidDir:                     "InvalidDir",
		ResultFwMgmtNotInit:                  "FwMgmtNotInit",
		VgpuInvalidLmem:                      "VgpuInvalidLmem",
		VgpuInvalidNumVfs:                    "VgpuInvalidNumVfs",
		VgpuDirtyPf:                          "VgpuDirtyPf",
		VgpuVfUnsupportedOperation:           "VgpuVfUnsupportedOperation",
		VgpuCreateVfFailed:                   "VgpuCreateVfFailed",
		VgpuRemoveVfFailed:                   "VgpuRemoveVfFailed",
		VgpuNoConfigFile:                     "VgpuNoConfigFile",
		VgpuSysfsError:                       "VgpuSysfsError",
		VgpuUnsupportedDeviceModel:           "VgpuUnsupportedDeviceModel",
		ResultResetFail:                      "ResetFail",
		APIUnsupported:                       "ApiUnsupported",
		PrecheckInvalidSinceTime:             "PrecheckInvalidSinceTime",
		PprNotFound:                          "PprNotFound",
		UpdateFirmwareGfxDataImageVersionLowerOrEqualToDevice: "UpdateFirmwareGfxDataImageVersionLowerOrEqualToDevice",
		ResultUnsupportedDevice:              "UnsupportedDevice",
		GroupLimitReached:                    "GroupLimitReached",
	}
	// ResultCode is unsigned, so a single upper-bound check suffices.
	if code <= GroupLimitReached {
		return names[code]
	}
	return fmt.Sprintf("ResultCode(%d)", code)
}
// Result pairs a ResultCode with an optional human-readable message.
type Result struct {
	Code    ResultCode
	Message string
}

// Error converts the Result into a Go error. It returns nil when the code
// is ResultOk, so callers can use it directly in error checks.
func (r Result) Error() error {
	if r.Code != ResultOk {
		return fmt.Errorf("%s %s(%d)", r.Message, r.Code.String(), r.Code)
	}
	return nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
//go:build linux && (amd64 || 386)
package xpum
/*
#cgo CXXFLAGS: -Iinclude -std=c++11
#cgo CFLAGS: -Iinclude
#cgo LDFLAGS: -ldl
#include <dlfcn.h>
#include <stdio.h>
#include <stdlib.h>
#include "xpum_api.h"
#include "xpum_structs.h"
// ========================
// XPUM Function Pointer Definitions
// ========================
// Macro to define a function pointer type
#define DEFINE_XPUM_FUNC_TYPE(ret, name, args) typedef ret (*name##_fp) args
// Define function pointer types for Intel XPUM functions
DEFINE_XPUM_FUNC_TYPE(xpum_result_t, xpumInit, (void));
DEFINE_XPUM_FUNC_TYPE(xpum_result_t, xpumShutdown, (void));
DEFINE_XPUM_FUNC_TYPE(xpum_result_t, xpumGetDeviceList, (xpum_device_basic_info deviceList[], int *count));
DEFINE_XPUM_FUNC_TYPE(xpum_result_t, xpumGetStats, (xpum_device_id_t deviceId, xpum_device_stats_t dataList[], uint32_t *count, uint64_t *begin, uint64_t *end, uint64_t sessionId));
DEFINE_XPUM_FUNC_TYPE(xpum_result_t, xpumGetDeviceProperties, (xpum_device_id_t deviceId, xpum_device_properties_t *pXpumProperties));
// ========================
// XPUM Function Pointers Struct
// ========================
typedef struct {
xpumInit_fp xpumInit;
xpumShutdown_fp xpumShutdown;
xpumGetDeviceList_fp xpumGetDeviceList;
xpumGetStats_fp xpumGetStats;
xpumGetDeviceProperties_fp xpumGetDeviceProperties;
} xpum_functions_t;
// ========================
// Global Variables
// ========================
static void *lib_handle = NULL;
static xpum_functions_t xpum_funcs;
// ========================
// Helper Macros
// ========================
// Macro to load a symbol and assign it to a struct member
#define LOAD_XPUM_SYMBOL(func_name) \
xpum_funcs.func_name = (func_name##_fp)dlsym(lib_handle, #func_name); \
if (!xpum_funcs.func_name) { \
fprintf(stderr, "Error loading symbol %s: %s\n", #func_name, dlerror()); \
dlclose(lib_handle); \
lib_handle = NULL; \
return 0; \
}
// Macro to define a wrapper function
#define DEFINE_XPUM_WRAPPER(ret_type, wrapper_name, xpum_func, args, ...) \
ret_type wrapper_name args { \
if (xpum_funcs.xpum_func) { \
return xpum_funcs.xpum_func(__VA_ARGS__); \
} \
return XPUM_GENERIC_ERROR; \
}
// ========================
// Library Management Functions
// ========================
// Load the XPUM library and resolve all required symbols
int load_xpum_library() {
if (lib_handle) {
fprintf(stderr, "libxpum.so is already loaded.\n");
return 1;
}
lib_handle = dlopen("libxpum.so", RTLD_LAZY);
if (!lib_handle) {
return 0;
}
// Clear any existing errors
dlerror();
// Load all required symbols
LOAD_XPUM_SYMBOL(xpumInit)
LOAD_XPUM_SYMBOL(xpumShutdown)
LOAD_XPUM_SYMBOL(xpumGetDeviceList)
LOAD_XPUM_SYMBOL(xpumGetStats)
LOAD_XPUM_SYMBOL(xpumGetDeviceProperties)
return 1;
}
// Unload the XPUM library and reset function pointers
void unload_xpum_library() {
if (lib_handle) {
dlclose(lib_handle);
lib_handle = NULL;
// Reset all function pointers to NULL
xpum_funcs.xpumInit = NULL;
xpum_funcs.xpumShutdown = NULL;
xpum_funcs.xpumGetDeviceList = NULL;
xpum_funcs.xpumGetStats = NULL;
xpum_funcs.xpumGetDeviceProperties = NULL;
}
}
// ========================
// Wrapper Functions
// ========================
DEFINE_XPUM_WRAPPER(xpum_result_t, call_xpumInit, xpumInit, (void))
DEFINE_XPUM_WRAPPER(xpum_result_t, call_xpumShutdown, xpumShutdown, (void))
DEFINE_XPUM_WRAPPER(xpum_result_t, call_xpumGetDeviceList, xpumGetDeviceList, (xpum_device_basic_info deviceList[], int *count), deviceList, count)
DEFINE_XPUM_WRAPPER(xpum_result_t, call_xpumGetStats, xpumGetStats, (xpum_device_id_t deviceId, xpum_device_stats_t dataList[], uint32_t *count, uint64_t *begin, uint64_t *end, uint64_t sessionId), deviceId, dataList, count, begin, end, sessionId)
DEFINE_XPUM_WRAPPER(xpum_result_t, call_xpumGetDeviceProperties, xpumGetDeviceProperties, (xpum_device_id_t deviceId, xpum_device_properties_t *pXpumProperties), deviceId, pXpumProperties)
*/
import "C"
import (
	"errors"
	"fmt"
	"os"
	"time"
	"unsafe"

	"github.com/avast/retry-go"
)
/*
===============================================================================
Adding a New XPUM Function to the Cgo Block
===============================================================================
To add a new XPUM function, follow these steps:
1. **Define the Function Pointer Type:**
- Use `DEFINE_XPUM_FUNC_TYPE` to create a typedef for the new function.
- Example:
`DEFINE_XPUM_FUNC_TYPE(xpum_result_t, xpum_new_function, (int arg1, float arg2));`
2. **Add to the Struct:**
- Add the new function pointer to `xpum_functions_t`.
- Example:
`xpum_new_function_fp xpum_new_function;`
3. **Load the Symbol:**
- In `load_xpum_library`, load the new symbol using `LOAD_XPUM_SYMBOL`.
- Example:
`LOAD_XPUM_SYMBOL(xpum_new_function)`
4. **Create a Wrapper Function:**
- Define a wrapper using `DEFINE_XPUM_WRAPPER`.
- Example:
`DEFINE_XPUM_WRAPPER(xpum_result_t, call_xpum_new_function, xpum_new_function, (int arg1, float arg2), arg1, arg2)`
5. **Rebuild and Test:**
- Rebuild and verify the new function works as expected.
===============================================================================
*/
// Metric and property identifiers re-exported from the XPUM C API so that
// callers outside this cgo file do not need to reference the C namespace.
const (
	StatsMemoryUsed                      = C.XPUM_STATS_MEMORY_USED
	DevicePropertyMemoryPhysicalSizeByte = C.XPUM_DEVICE_PROPERTY_MEMORY_PHYSICAL_SIZE_BYTE
	DevicePropertyNumberOfEUs            = C.XPUM_DEVICE_PROPERTY_NUMBER_OF_EUS
)
// Init initializes the Intel XPUM library.
//
// It dynamically loads libxpum.so and calls xpumInit. A non-nil error means
// the shared library could not be loaded; otherwise the returned Result
// carries the xpumInit status code.
func Init() (Result, error) {
	// Silence the XPUM library's internal spdlog output.
	// TODO: Set the log level based on the dms log level
	_ = os.Setenv("SPDLOG_LEVEL", "off") // best-effort; safe to ignore
	if C.load_xpum_library() == 0 {
		// errors.New instead of fmt.Errorf: no format verbs (staticcheck S1039).
		return Result{}, errors.New("could not load XPUM library")
	}
	ret := C.call_xpumInit()
	return Result{Code: ResultCode(ret)}, nil
}
// Shutdown shuts down the Intel XPUM library and then unloads the
// dynamically loaded shared library handle.
func Shutdown() Result {
	code := ResultCode(C.call_xpumShutdown())
	C.unload_xpum_library()
	return Result{Code: code}
}
// GetDeviceList retrieves the list of devices and converts them to Go structs.
//
// At most 32 devices are returned; behavior with more devices depends on the
// C API (presumably a buffer-too-small result — verify against XPUM docs).
func GetDeviceList() ([]DeviceBasicInfo, Result) {
	const maxDevices = 32
	// count is in/out: capacity of deviceList on entry, devices filled on return.
	var count C.int = maxDevices
	var deviceList [maxDevices]C.xpum_device_basic_info
	// Call the C wrapper function
	ret := C.call_xpumGetDeviceList((*C.xpum_device_basic_info)(unsafe.Pointer(&deviceList[0])), &count)
	if ret != C.XPUM_OK {
		return nil, Result{Code: ResultCode(ret), Message: fmt.Sprintf("get device list: %v", ret)}
	}
	// Convert C array to Go slice of DeviceBasicInfo.
	// C.GoString copies each NUL-terminated char array into an owned Go string,
	// so the result does not alias C memory.
	goDevices := make([]DeviceBasicInfo, int(count))
	for i := 0; i < int(count); i++ {
		goDevices[i] = DeviceBasicInfo{
			DeviceID:      int32(deviceList[i].deviceId),
			FunctionType:  int32(deviceList[i].functionType),
			UUID:          C.GoString(&deviceList[i].uuid[0]),
			DeviceName:    C.GoString(&deviceList[i].deviceName[0]),
			PCIDeviceID:   C.GoString(&deviceList[i].PCIDeviceId[0]),
			PCIBDFAddress: C.GoString(&deviceList[i].PCIBDFAddress[0]),
			VendorName:    C.GoString(&deviceList[i].VendorName[0]),
			DRMDevice:     C.GoString(&deviceList[i].drmDevice[0]),
		}
	}
	return goDevices, Result{Code: ResultOk}
}
// GetDeviceStats retrieves device statistics for the specified device ID.
//
// sessionID selects the XPUM stats session. The call is retried up to 3
// times, 1 second apart, because shortly after initialization the library
// may report zero metric entries for every device.
//
// NOTE(review): count, begin and end are shared across retry attempts; the C
// call writes the actual entry count back into count, so later attempts pass
// a smaller capacity than maxStats — confirm this is intended.
func GetDeviceStats(deviceID int32, sessionID uint64) ([]DeviceStats, error) {
	const maxStats = 100
	// count is in/out: capacity of statsList on entry, entries filled on return.
	var count C.uint32_t = maxStats
	var begin, end C.uint64_t // time window reported back by the library
	var statsList [maxStats]C.xpum_device_stats_t
	// We're retrying here because the XPUM library may not have initialized properly
	// Only the calls to xpumGetStats returns no data, so we retry this call
	// TODO: Add a check to see if the XPUM library is initialized by calling a c function ( if it exists )
	var goStats []DeviceStats
	if err := retry.Do(
		func() error {
			// Call the C wrapper function
			ret := C.call_xpumGetStats(C.xpum_device_id_t(deviceID), (*C.xpum_device_stats_t)(unsafe.Pointer(&statsList[0])), &count, &begin, &end, C.uint64_t(sessionID))
			if ret != C.XPUM_OK {
				return fmt.Errorf("get device stats: %v", ResultCode(ret).String())
			}
			// Convert C array to Go slice of DeviceStats
			goStats = make([]DeviceStats, int(count))
			// hasData becomes true once any entry carries at least one metric.
			hasData := false
			for i := 0; i < int(count); i++ {
				dataCount := int(statsList[i].count)
				if dataCount != 0 {
					hasData = true
				}
				goDataList := make([]DeviceStatsData, dataCount)
				for j := 0; j < dataCount; j++ {
					goDataList[j] = DeviceStatsData{
						MetricsType: int32(statsList[i].dataList[j].metricsType),
						IsCounter:   bool(statsList[i].dataList[j].isCounter),
						Value:       uint64(statsList[i].dataList[j].value),
						Accumulated: uint64(statsList[i].dataList[j].accumulated),
						Min:         uint64(statsList[i].dataList[j].min),
						Avg:         uint64(statsList[i].dataList[j].avg),
						Max:         uint64(statsList[i].dataList[j].max),
						Scale:       uint32(statsList[i].dataList[j].scale),
					}
				}
				goStats[i] = DeviceStats{
					DeviceID:   int32(statsList[i].deviceId),
					IsTileData: bool(statsList[i].isTileData),
					TileID:     int32(statsList[i].tileId),
					Count:      int32(statsList[i].count),
					DataList:   goDataList,
				}
			}
			// An all-empty result is treated as a transient init problem and retried.
			if !hasData {
				return fmt.Errorf("no data for devices. It could be an xpum init issue. Retrying")
			}
			return nil
		},
		retry.Delay(1*time.Second),
		retry.Attempts(3),
	); err != nil {
		return nil, fmt.Errorf("get device stats: %w", err)
	}
	return goStats, nil
}
// GetDeviceProperties retrieves the static properties of the device with the
// given ID and converts them to Go structs.
func GetDeviceProperties(deviceID int32) ([]DeviceProperty, Result) {
	const xpumMaxNumProperties = 100
	// Prepare the properties struct.
	// propertyLen appears to be in/out: capacity on input, actual property
	// count on output (assumption from the read below — verify in XPUM docs).
	cProperties := C.xpum_device_properties_t{
		deviceId:    C.xpum_device_id_t(deviceID),
		propertyLen: C.int(xpumMaxNumProperties),
	}
	// Call the C wrapper function
	ret := C.call_xpumGetDeviceProperties(C.xpum_device_id_t(deviceID), &cProperties)
	if ret != C.XPUM_OK {
		return nil, Result{Code: ResultCode(ret), Message: fmt.Sprintf("get device properties: %v", ret)}
	}
	numProperties := int(cProperties.propertyLen)
	goProperties := make([]DeviceProperty, numProperties)
	for i := 0; i < numProperties; i++ {
		// Copy the C entry; C.GoString then copies the NUL-terminated value.
		cProp := cProperties.properties[i]
		goProperties[i] = DeviceProperty{
			Name:  int32(cProp.name),
			Value: C.GoString(&cProp.value[0]),
		}
	}
	return goProperties, Result{Code: ResultOk}
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package hardware
import (
"fmt"
"gitlab.com/nunet/device-management-service/lib/hardware/cpu"
"gitlab.com/nunet/device-management-service/lib/hardware/gpu"
"gitlab.com/nunet/device-management-service/observability"
"gitlab.com/nunet/device-management-service/types"
)
// defaultHardwareManager manages the machine's hardware resources.
type defaultHardwareManager struct {
	gpuManager types.GPUManager // GPU discovery and usage reporting across vendors
	cpuMonitor *cpu.Monitor     // CPU usage monitor; created in NewHardwareManager, stopped in Shutdown
}
// NewHardwareManager creates a new instance of defaultHardwareManager with a
// freshly constructed GPU manager and CPU monitor.
func NewHardwareManager() types.HardwareManager {
	manager := &defaultHardwareManager{
		gpuManager: gpu.NewGPUManager(),
		cpuMonitor: cpu.NewCPUMonitor(),
	}
	return manager
}
var _ types.HardwareManager = (*defaultHardwareManager)(nil)

// GetMachineResources returns the total resources of the machine (CPU, RAM,
// disk and GPUs).
//
// NOTE(review): an earlier comment claimed thread-safety, but no
// synchronization is visible in this type — confirm before relying on it.
func (m *defaultHardwareManager) GetMachineResources() (types.MachineResources, error) {
	cpuDetails, err := cpu.GetCPU()
	if err != nil {
		return types.MachineResources{}, fmt.Errorf("get CPU: %w", err)
	}
	ram, err := GetRAM()
	if err != nil {
		return types.MachineResources{}, fmt.Errorf("get RAM: %w", err)
	}
	gpus, err := m.gpuManager.GetGPUs()
	if err != nil {
		return types.MachineResources{}, fmt.Errorf("get GPUs: %w", err)
	}
	disk, err := GetDisk()
	if err != nil {
		return types.MachineResources{}, fmt.Errorf("get Disk: %w", err)
	}
	return types.MachineResources{
		Resources: types.Resources{
			CPU:  cpuDetails,
			RAM:  ram,
			Disk: disk,
			GPUs: gpus,
		},
	}, nil
}
// GetUsage returns the currently consumed resources of the machine, logging
// each component's usage with accounting labels as it is gathered.
func (m *defaultHardwareManager) GetUsage() (types.Resources, error) {
	cpuUsage, err := m.cpuMonitor.GetAvgCPUUsage()
	if err != nil {
		return types.Resources{}, fmt.Errorf("get CPU usage: %w", err)
	}
	log.Debugw("cpu_usage_computed",
		"labels", string(observability.LabelAccounting),
		"usage", cpuUsage)
	ramUsage, err := GetRAMUsage()
	if err != nil {
		return types.Resources{}, fmt.Errorf("get RAM usage: %w", err)
	}
	log.Debugw("ram_usage_computed",
		"labels", string(observability.LabelAccounting),
		"usedMemoryBytes", ramUsage.Size)
	diskUsage, err := GetDiskUsage()
	if err != nil {
		return types.Resources{}, fmt.Errorf("get Disk usage: %w", err)
	}
	log.Debugw("disk_usage_computed",
		"labels", string(observability.LabelAccounting),
		"usedStorageBytes", diskUsage.Size)
	gpuUsage, err := m.gpuManager.GetGPUUsage()
	if err != nil {
		return types.Resources{}, fmt.Errorf("get GPU usage: %w", err)
	}
	// One log entry per GPU.
	for _, g := range gpuUsage {
		log.Debugw("gpu_usage_computed",
			"labels", string(observability.LabelAccounting),
			"gpuUUID", g.UUID,
			"vendor", g.Vendor,
			"usedVRAM", g.VRAM)
	}
	return types.Resources{
		CPU:  cpuUsage,
		RAM:  ramUsage,
		Disk: diskUsage,
		GPUs: gpuUsage,
	}, nil
}
// GetFreeResources returns the machine's total resources minus current usage.
// It fails when usage exceeds the total (nothing is free).
func (m *defaultHardwareManager) GetFreeResources() (types.Resources, error) {
	used, err := m.GetUsage()
	if err != nil {
		return types.Resources{}, fmt.Errorf("get usage: %w", err)
	}
	total, err := m.GetMachineResources()
	if err != nil {
		return types.Resources{}, fmt.Errorf("get machine resources: %w", err)
	}
	log.Debugw("resources_available", "labels", string(observability.LabelNode), "resources", total)
	log.Debugw("resources_used", "labels", string(observability.LabelNode), "resources", used)
	// Subtract mutates total in place, leaving the free amount.
	if err := total.Subtract(used); err != nil {
		return types.Resources{}, fmt.Errorf("no free resources: %w", err)
	}
	return total.Resources, nil
}
// CheckCapacity reports whether the machine currently has enough free
// resources to cover the expected allocation.
func (m *defaultHardwareManager) CheckCapacity(expected types.Resources) (bool, error) {
	free, err := m.GetFreeResources()
	if err != nil {
		return false, fmt.Errorf("get free resources: %w", err)
	}
	// Subtracting on the local copy only probes capacity; nothing is reserved.
	if err = free.Subtract(expected); err != nil {
		return false, fmt.Errorf("no free resources on the machine: %w", err)
	}
	return true, nil
}
// Shutdown releases the hardware manager's resources: the GPU manager first,
// then the CPU usage monitor.
func (m *defaultHardwareManager) Shutdown() error {
	if err := m.gpuManager.Shutdown(); err != nil {
		return fmt.Errorf("shutdown gpu manager: %w", err)
	}
	m.cpuMonitor.Stop()
	return nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package hardware
import (
"fmt"
"gitlab.com/nunet/device-management-service/types"
)
// mockHardwareManager is a test double for types.HardwareManager that
// returns the fixed resource values it was constructed with.
type mockHardwareManager struct {
	machineResources types.MachineResources // returned by GetMachineResources
	freeResources    types.Resources        // returned by GetFreeResources; also consulted by CheckCapacity
	usedResources    types.Resources        // returned by GetUsage
}

// Compile-time check that mockHardwareManager satisfies types.HardwareManager.
var _ types.HardwareManager = (*mockHardwareManager)(nil)
// NewMockHardwareManager builds a HardwareManager test double that returns
// the given canned machine, free and used resource values.
func NewMockHardwareManager(
	machineResources types.MachineResources,
	freeResources types.Resources,
	usedResources types.Resources,
) types.HardwareManager {
	mock := &mockHardwareManager{}
	mock.machineResources = machineResources
	mock.freeResources = freeResources
	mock.usedResources = usedResources
	return mock
}
// GetMachineResources returns the canned machine inventory; never errors.
func (m *mockHardwareManager) GetMachineResources() (types.MachineResources, error) {
	return m.machineResources, nil
}

// GetFreeResources returns the canned free-resource snapshot; never errors.
func (m *mockHardwareManager) GetFreeResources() (types.Resources, error) {
	return m.freeResources, nil
}

// GetUsage returns the canned usage snapshot; never errors.
func (m *mockHardwareManager) GetUsage() (types.Resources, error) {
	return m.usedResources, nil
}
// CheckCapacity reports whether the mock's canned free resources cover the
// requested amount.
//
// Fixed: it previously called Subtract directly on m.freeResources; Subtract
// mutates its receiver (defaultHardwareManager.GetFreeResources relies on
// that), so every capacity check permanently shrank the mock's free
// resources — unlike the real implementation, which subtracts on a fresh
// snapshot. We now subtract on a copy.
// NOTE(review): this is a shallow copy; if types.Resources contains slice or
// map fields that Subtract mutates in place, a deep copy is needed — confirm
// against the types package.
func (m *mockHardwareManager) CheckCapacity(resources types.Resources) (bool, error) {
	free := m.freeResources
	if err := free.Subtract(resources); err != nil {
		return false, fmt.Errorf("%w: %w", types.ErrNoFreeResources, err)
	}
	return true, nil
}
// Shutdown is a no-op for the mock.
func (m *mockHardwareManager) Shutdown() error { return nil }
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package hardware
import (
"fmt"
"github.com/shirou/gopsutil/v4/mem"
"gitlab.com/nunet/device-management-service/types"
)
// GetRAM returns the types.RAM information (total physical memory) for the
// system, as reported by gopsutil.
func GetRAM() (types.RAM, error) {
	v, err := mem.VirtualMemory()
	if err != nil {
		// wrap with %w (was %s) so callers can inspect the cause chain
		return types.RAM{}, fmt.Errorf("failed to get total memory: %w", err)
	}
	return types.RAM{
		Size: v.Total,
	}, nil
}
// GetRAMUsage returns the currently used RAM, as reported by gopsutil.
func GetRAMUsage() (types.RAM, error) {
	v, err := mem.VirtualMemory()
	if err != nil {
		// fixed: message said "total memory" (copy-paste from GetRAM) and
		// used %s, which loses the error chain
		return types.RAM{}, fmt.Errorf("failed to get memory usage: %w", err)
	}
	return types.RAM{
		Size: v.Used,
	}, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package ucan
import (
"strings"
)
// Capability is a hierarchical, slash-separated permission path.
type Capability string

// Root is the universal capability; it implies every other capability.
const Root = Capability("/")

// Implies reports whether c grants at least as much authority as other:
// the two are identical, c is the root capability, or other lies strictly
// beneath c in the slash-separated hierarchy.
func (c Capability) Implies(other Capability) bool {
	switch {
	case c == other, c == Root:
		return true
	default:
		return strings.HasPrefix(string(other), string(c)+"/")
	}
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package ucan
import (
"bytes"
"context"
"crypto/rand"
"encoding/json"
"fmt"
"slices"
"sync"
"time"
"gitlab.com/nunet/device-management-service/lib/crypto"
"gitlab.com/nunet/device-management-service/lib/did"
)
// maxCapabilitySize is the maximum accepted size, in bytes, of a serialized
// capability token list (see Consume).
const maxCapabilitySize = 16384

// SelfSignMode controls whether the Delegate* methods additionally (or
// exclusively) issue tokens rooted at the context's own DID.
type SelfSignMode int

const (
	// Values start at 1 to preserve the historical encoding: the modes
	// previously shared an iota const block with maxCapabilitySize, which
	// silently offset SelfSignNo to 1. The shared block was the defect —
	// splitting it makes the values explicit without changing them.
	// NOTE(review): the zero value of SelfSignMode is therefore not a valid
	// mode; callers must pass one of the named constants.
	SelfSignNo SelfSignMode = iota + 1 // only issue tokens chained on provide anchors
	SelfSignAlso // anchored tokens plus a self-signed token
	SelfSignOnly // just the self-signed token
)
// CapabilityContext exposes the necessary functionalities to manage capabilities
// between different contexts. The work is based on UCAN but we're not
// strictly following its specs.
//
// TODO: explain side-chains
//
// TODO: explain anchor concept
//
// Some concepts:
//
//   - Issuer: the one delegating/granting/invoking capabilities. Responsible for signing the token.
//   - Audience: the resource which the capabilities can be applied upon.
//   - Subject:
//     - the receiver, when delegating/granting capabilities
//     - the invoker, when invoking capabilities
type CapabilityContext interface {
	// Name returns the context name.
	Name() string
	// DID returns the context's controlling DID.
	DID() did.DID
	// Trust returns the context's did trust context.
	Trust() did.TrustContext
	// Consume ingests some or all of the provided capability tokens.
	// It'll only return an error if the payload itself is invalid; individual
	// tokens that do not chain to an accepted anchor are skipped.
	Consume(origin did.DID, capToken []byte) error
	// Discard discards previously consumed capability tokens.
	Discard(capTokens []byte)
	// Require ensures that at least one of the capabilities is delegated from
	// the subject to the audience, with an appropriate anchor.
	// An empty list means that no capabilities are required and is vacuously
	// true.
	//
	// TODO (if necessary): create a RequireAll() since this method is basically a RequireAny()
	Require(anchor did.DID, subject crypto.ID, audience crypto.ID, require []Capability) error
	// RequireBroadcast ensures that at least one of the capabilities is delegated
	// to the subject for the specified broadcast topic.
	RequireBroadcast(origin did.DID, subject crypto.ID, topic string, require []Capability) error
	// Provide prepares the appropriate capability tokens to prove and delegate authority
	// to a subject for an audience.
	// - It delegates invocations to the subject with an audience and invoke capabilities
	// - It delegates the delegate capabilities to the target with audience the subject
	Provide(target did.DID, subject crypto.ID, audience crypto.ID, expire uint64, invoke []Capability, delegate []Capability) ([]byte, error)
	// ProvideBroadcast prepares the appropriate capability tokens to prove authority
	// to broadcast to a topic.
	ProvideBroadcast(subject crypto.ID, topic string, expire uint64, broadcast []Capability) ([]byte, error)
	// AddRoots adds trust anchors.
	//
	// require: regards to side-chains. It'll be used as one of the sources of truth when an entity is claiming having certain capabilities.
	//
	// provide: regards to the capabilities that we can delegate.
	AddRoots(trust []did.DID, require, provide TokenList, revoke TokenList) error
	// ListRoots lists the current trust anchors together with the require,
	// provide and revocation token lists.
	ListRoots() ([]did.DID, TokenList, TokenList, TokenList)
	// RemoveRoots removes the specified trust anchors and tokens.
	RemoveRoots(trust []did.DID, require, provide TokenList, revoke TokenList)
	// Delegate creates the appropriate delegation tokens anchored in our roots.
	Delegate(subject, audience did.DID, topics []string, expire, depth uint64, provide []Capability, selfSign SelfSignMode) (TokenList, error)
	// DelegateInvocation creates the appropriate invocation tokens anchored in anchor.
	DelegateInvocation(target, subject, audience did.DID, expire uint64, provide []Capability, selfSign SelfSignMode) (TokenList, error)
	// DelegateBroadcast creates the appropriate broadcast token anchored in our roots.
	DelegateBroadcast(subject did.DID, topic string, expire uint64, provide []Capability, selfSign SelfSignMode) (TokenList, error)
	// Grant creates the appropriate delegation tokens considering ourselves as the root.
	Grant(action Action, subject, audience did.DID, topic []string, expire, depth uint64, provide []Capability) (TokenList, error)
	// Revoke creates a revocation for the provided token (token=(iss+sub+nonce)).
	Revoke(*Token) (*Token, error)
	// Start starts a token garbage collector goroutine that clears expired tokens.
	Start(gcInterval time.Duration)
	// Stop stops a previously started gc goroutine.
	Stop()
}
// BasicCapabilityContext is the default CapabilityContext implementation.
// All maps are guarded by mx. require/provide are keyed by token issuer;
// tokens is keyed by token subject.
type BasicCapabilityContext struct {
	mx sync.Mutex
	name string
	provider did.Provider // signs tokens on behalf of the context DID
	trust did.TrustContext
	roots map[did.DID]struct{} // our root anchors of trust
	require map[did.DID][]*Token // our acceptance side-roots
	provide map[did.DID][]*Token // root capabilities -> tokens
	tokens map[did.DID][]*Token // our context dependent capabilities; subject -> tokens
	revoke *RevocationSet // revocation tokens
	stop func() // cancels the gc goroutine started by Start, if running
}

// Compile-time check that BasicCapabilityContext satisfies CapabilityContext.
var _ CapabilityContext = (*BasicCapabilityContext)(nil)
// newCapabilityContext constructs a BasicCapabilityContext for ctxDID,
// resolving its signing provider from trust and ingesting the initial
// roots and token lists.
func newCapabilityContext(name string, trust did.TrustContext, ctxDID did.DID, roots []did.DID, require, provide TokenList, revoke TokenList) (*BasicCapabilityContext, error) {
	provider, err := trust.GetProvider(ctxDID)
	if err != nil {
		return nil, fmt.Errorf("new capability context: %w", err)
	}
	ctx := &BasicCapabilityContext{
		name:     name,
		trust:    trust,
		provider: provider,
		roots:    make(map[did.DID]struct{}),
		require:  make(map[did.DID][]*Token),
		provide:  make(map[did.DID][]*Token),
		tokens:   make(map[did.DID][]*Token),
		revoke:   &RevocationSet{revoked: make(map[string]*Token)},
	}
	if err := ctx.AddRoots(roots, require, provide, revoke); err != nil {
		return nil, fmt.Errorf("new capability context: %w", err)
	}
	return ctx, nil
}
// NewCapabilityContext creates a capability context with the default
// context name "dms".
func NewCapabilityContext(trust did.TrustContext, ctxDID did.DID, roots []did.DID, require, provide TokenList, revoke TokenList) (CapabilityContext, error) {
	return newCapabilityContext("dms", trust, ctxDID, roots, require, provide, revoke)
}

// NewCapabilityContextWithName creates a capability context with an explicit
// context name.
func NewCapabilityContextWithName(name string, trust did.TrustContext, ctxDID did.DID, roots []did.DID, require, provide TokenList, revoke TokenList) (CapabilityContext, error) {
	return newCapabilityContext(name, trust, ctxDID, roots, require, provide, revoke)
}
// Name returns the context name.
func (ctx *BasicCapabilityContext) Name() string {
	return ctx.name
}

// DID returns the context's controlling DID (the provider's DID).
func (ctx *BasicCapabilityContext) DID() did.DID {
	return ctx.provider.DID()
}

// Trust returns the context's did trust context.
func (ctx *BasicCapabilityContext) Trust() did.TrustContext {
	return ctx.trust
}
// Start launches the token garbage-collection goroutine, which clears
// expired tokens every gcInterval. It is a no-op if the collector is
// already running; Stop cancels the goroutine.
func (ctx *BasicCapabilityContext) Start(gcInterval time.Duration) {
	// Fixed: the guard previously read `ctx.stop != nil`, which inverted the
	// intent — stop is only ever assigned inside this branch, so the gc
	// goroutine could never start.
	if ctx.stop == nil {
		gcCtx, cancel := context.WithCancel(context.Background())
		go ctx.gc(gcCtx, gcInterval)
		ctx.stop = cancel
	}
}
// Stop cancels the gc goroutine previously started by Start, if any.
// Calling it multiple times is safe (context cancel is idempotent);
// ctx.stop is intentionally left set.
func (ctx *BasicCapabilityContext) Stop() {
	if ctx.stop != nil {
		ctx.stop()
	}
}
// AddRoots registers additional trust anchors and ingests the supplied
// revocation, require and provide token lists. Every token is verified
// against the trust context (and the revocation set) before being consumed;
// the first invalid token aborts with an error. Expired or newly revoked
// tokens are pruned at the end.
func (ctx *BasicCapabilityContext) AddRoots(roots []did.DID, require, provide, revoke TokenList) error {
	ctx.addRoots(roots)
	now := uint64(time.Now().UnixNano())
	// Revocations go first so subsequent verification sees them.
	for _, tok := range revoke.Tokens {
		if tok.Action() != Revoke {
			return fmt.Errorf("verify token: %w", ErrBadToken)
		}
		if err := tok.Verify(ctx.trust, now, ctx.revoke); err != nil {
			return fmt.Errorf("verify token: %w", err)
		}
		ctx.consumeRevokeToken(tok)
	}
	// ingest verifies each token and hands it to the given consumer.
	ingest := func(tokens []*Token, consume func(*Token)) error {
		for _, tok := range tokens {
			if err := tok.Verify(ctx.trust, now, ctx.revoke); err != nil {
				return fmt.Errorf("verify token: %w", err)
			}
			consume(tok)
		}
		return nil
	}
	if err := ingest(require.Tokens, ctx.consumeRequireToken); err != nil {
		return err
	}
	if err := ingest(provide.Tokens, ctx.consumeProvideToken); err != nil {
		return err
	}
	ctx.cleanUpTokens()
	return nil
}
// cleanUpTokens drops every require/provide/subject token that no longer
// verifies (expired, revoked, or otherwise invalid) against the trust
// context. The lock is acquired per map-entry rather than for the whole
// sweep because getRequireTokens/getProvideTokens take the lock themselves.
func (ctx *BasicCapabilityContext) cleanUpTokens() {
	now := time.Now().UnixNano()
	for _, anchor := range ctx.getRequireAnchors() {
		tokenList := ctx.getRequireTokens(anchor)
		// Use slices.DeleteFunc to safely remove invalid tokens
		tokenList = slices.DeleteFunc(tokenList, func(t *Token) bool {
			return t.Verify(ctx.trust, uint64(now), ctx.revoke) != nil
		})
		ctx.mx.Lock()
		ctx.require[anchor] = tokenList
		ctx.mx.Unlock()
	}
	for _, anchor := range ctx.getProvideAnchors() {
		tokenList := ctx.getProvideTokens(anchor)
		// Use slices.DeleteFunc to safely remove invalid tokens
		tokenList = slices.DeleteFunc(tokenList, func(t *Token) bool {
			return t.Verify(ctx.trust, uint64(now), ctx.revoke) != nil
		})
		ctx.mx.Lock()
		ctx.provide[anchor] = tokenList
		ctx.mx.Unlock()
	}
	ctx.mx.Lock()
	for subject, tokenList := range ctx.tokens {
		// Use slices.DeleteFunc to safely remove invalid tokens
		tokenList = slices.DeleteFunc(tokenList, func(t *Token) bool {
			return t.Verify(ctx.trust, uint64(now), ctx.revoke) != nil
		})
		ctx.tokens[subject] = tokenList
	}
	ctx.mx.Unlock()
}
// ListRoots returns the current trust anchors together with snapshots of
// the require, provide and revocation token lists.
func (ctx *BasicCapabilityContext) ListRoots() ([]did.DID, TokenList, TokenList, TokenList) {
	roots := ctx.getRoots()
	var requireTokens []*Token
	for _, anchor := range ctx.getRequireAnchors() {
		requireTokens = append(requireTokens, ctx.getRequireTokens(anchor)...)
	}
	var provideTokens []*Token
	for _, anchor := range ctx.getProvideAnchors() {
		provideTokens = append(provideTokens, ctx.getProvideTokens(anchor)...)
	}
	revokeTokens := ctx.revoke.List()
	return roots,
		TokenList{Tokens: requireTokens},
		TokenList{Tokens: provideTokens},
		TokenList{Tokens: revokeTokens}
}
// RemoveRoots removes the given trust anchors, the matching require/provide
// tokens (matched by issuer + nonce), and the given revocation entries.
func (ctx *BasicCapabilityContext) RemoveRoots(trust []did.DID, require, provide, revoke TokenList) {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()
	for _, root := range trust {
		delete(ctx.roots, root)
	}
	// prune removes the entry matching t (same issuer and nonce) from the
	// issuer-keyed token map, dropping the key when no tokens remain.
	prune := func(m map[did.DID][]*Token, t *Token) {
		tokenList, ok := m[t.Issuer()]
		if !ok {
			return
		}
		tokenList = slices.DeleteFunc(tokenList, func(ot *Token) bool {
			return bytes.Equal(t.Nonce(), ot.Nonce())
		})
		if len(tokenList) > 0 {
			m[t.Issuer()] = tokenList
		} else {
			delete(m, t.Issuer())
		}
	}
	for _, t := range require.Tokens {
		prune(ctx.require, t)
	}
	for _, t := range provide.Tokens {
		prune(ctx.provide, t)
	}
	for _, t := range revoke.Tokens {
		// delete is a no-op for absent keys
		delete(ctx.revoke.revoked, t.RevocationKey())
	}
}
// Grant issues a single token for the given action, signed by this context
// acting as the root of trust. A fresh random nonce is generated per grant.
func (ctx *BasicCapabilityContext) Grant(action Action, subject, audience did.DID, topics []string, expire, depth uint64, provide []Capability) (TokenList, error) {
	nonce := make([]byte, nonceLength)
	if _, err := rand.Read(nonce); err != nil {
		return TokenList{}, fmt.Errorf("nonce: %w", err)
	}
	topicCap := make([]Capability, 0, len(topics))
	for _, topic := range topics {
		topicCap = append(topicCap, Capability(topic))
	}
	granted := &DMSToken{
		Issuer:     ctx.DID(),
		Subject:    subject,
		Audience:   audience,
		Action:     action,
		Topic:      topicCap,
		Capability: provide,
		Nonce:      nonce,
		Expire:     expire,
		Depth:      depth,
	}
	payload, err := granted.SignatureData()
	if err != nil {
		return TokenList{}, fmt.Errorf("grant: %w", err)
	}
	sig, err := ctx.provider.Sign(payload)
	if err != nil {
		return TokenList{}, fmt.Errorf("sign: %w", err)
	}
	granted.Signature = sig
	return TokenList{Tokens: []*Token{{DMS: granted}}}, nil
}
// Revoke creates and signs a revocation token for the given token. Only the
// original issuer (this context) may revoke it; the revocation mirrors the
// original token's issuer, subject, nonce, expiry and capabilities.
func (ctx *BasicCapabilityContext) Revoke(token *Token) (*Token, error) {
	if !ctx.DID().Equal(token.Issuer()) {
		return nil, ErrNotAuthorized
	}
	rev := &DMSToken{
		Action:     Revoke,
		Issuer:     token.Issuer(),
		Subject:    token.Subject(),
		Nonce:      token.Nonce(),
		Expire:     token.Expiry(),
		Capability: token.Capability(),
	}
	payload, err := rev.SignatureData()
	if err != nil {
		return nil, fmt.Errorf("revoke: %w", err)
	}
	sig, err := ctx.provider.Sign(payload)
	if err != nil {
		return nil, fmt.Errorf("sign: %w", err)
	}
	rev.Signature = sig
	return &Token{DMS: rev}, nil
}
// Delegate creates delegation tokens for subject/audience, chained on our
// provide anchors. Depending on selfSign it may also (SelfSignAlso) or
// exclusively (SelfSignOnly) issue a token rooted at our own DID.
// With SelfSignNo it fails with ErrNotAuthorized when no anchored token
// could be produced. An empty provide list is vacuously satisfied.
//
// Fixed: the self-sign failure message said "error granting invocation"
// (copy-paste from DelegateInvocation); this path grants a delegation.
// The goto-based flow was also replaced with structured control flow.
func (ctx *BasicCapabilityContext) Delegate(subject, audience did.DID, topics []string, expire, depth uint64, provide []Capability, selfSign SelfSignMode) (TokenList, error) {
	if len(provide) == 0 {
		return TokenList{}, nil
	}
	topicCap := make([]Capability, 0, len(topics))
	for _, topic := range topics {
		topicCap = append(topicCap, Capability(topic))
	}
	var result []*Token
	if selfSign != SelfSignOnly {
		for _, trustAnchor := range ctx.getProvideAnchors() {
			tokenList := ctx.getProvideTokens(trustAnchor)
			if len(tokenList) == 0 {
				continue
			}
			for _, t := range tokenList {
				var providing []Capability
				definitiveExpire := expire
				if definitiveExpire == 0 {
					// no expiry requested: inherit the parent token's
					definitiveExpire = t.Expire()
				}
				for _, c := range provide {
					if t.Anchor(trustAnchor) && t.AllowDelegation(Delegate, ctx.DID(), audience, topicCap, definitiveExpire, c) {
						providing = append(providing, c)
					}
				}
				if len(providing) == 0 {
					continue
				}
				if len(provide) > len(providing) {
					// token covers only a subset of the requested caps;
					// skip it rather than attempt to widen caps
					continue
				}
				token, err := t.Delegate(ctx.provider, subject, audience, topicCap, definitiveExpire, depth, providing)
				if err != nil {
					log.Debugf("error delegating %s to %s: %s", providing, subject, err)
					continue
				}
				result = append(result, token)
			}
		}
		if selfSign == SelfSignNo {
			if len(result) == 0 {
				return TokenList{}, ErrNotAuthorized
			}
			return TokenList{Tokens: result}, nil
		}
	}
	// SelfSignAlso/SelfSignOnly: additionally issue a token rooted at our DID.
	tokens, err := ctx.Grant(Delegate, subject, audience, topics, expire, depth, provide)
	if err != nil {
		return TokenList{}, fmt.Errorf("error granting delegation: %w", err)
	}
	result = append(result, tokens.Tokens...)
	return TokenList{Tokens: result}, nil
}
// DelegateInvocation creates invocation tokens for subject/audience: first
// from tokens issued about ourselves (anchored in target), then chained on
// our provide anchors. Depending on selfSign it may also (SelfSignAlso) or
// exclusively (SelfSignOnly) self-sign; with SelfSignNo it errors when no
// anchored token could be produced.
func (ctx *BasicCapabilityContext) DelegateInvocation(target, subject, audience did.DID, expire uint64, provide []Capability, selfSign SelfSignMode) (TokenList, error) {
	if len(provide) == 0 {
		return TokenList{}, nil
	}
	// first get tokens we have about ourselves and see if any allows
	// delegation to the subject for the audience
	result := ctx.delegateInvocation(ctx.getSubjectTokens(ctx.DID()), target, subject, audience, expire, provide)
	if selfSign != SelfSignOnly {
		// then issue tokens chained on our provide anchors as appropriate
		for _, trustAnchor := range ctx.getProvideAnchors() {
			chained := ctx.delegateInvocation(ctx.getProvideTokens(trustAnchor), trustAnchor, subject, audience, expire, provide)
			result = append(result, chained...)
		}
		if selfSign == SelfSignNo {
			if len(result) == 0 {
				return TokenList{}, ErrNotAuthorized
			}
			return TokenList{Tokens: result}, nil
		}
	}
	selfTokens, err := ctx.Grant(Invoke, subject, audience, nil, expire, 0, provide)
	if err != nil {
		return TokenList{}, fmt.Errorf("error granting invocation: %w", err)
	}
	result = append(result, selfTokens.Tokens...)
	return TokenList{Tokens: result}, nil
}
// delegateInvocation derives invocation tokens from tokenList for
// subject/audience. A source token is used only when it is anchored in
// anchor and allows delegating every requested capability (no widening).
func (ctx *BasicCapabilityContext) delegateInvocation(tokenList []*Token, anchor, subject, audience did.DID, expire uint64, provide []Capability) []*Token {
	var out []*Token //nolint
	for _, tok := range tokenList {
		var granted []Capability
		for _, c := range provide {
			if tok.Anchor(anchor) && tok.AllowDelegation(Invoke, ctx.DID(), audience, nil, expire, c) {
				granted = append(granted, c)
			}
		}
		if len(granted) == 0 {
			continue
		}
		if len(provide) > len(granted) {
			// attempt to widen caps
			continue
		}
		token, err := tok.DelegateInvocation(ctx.provider, subject, audience, expire, granted)
		if err != nil {
			log.Debugf("error delegating invocation %s to %s: %s", granted, subject, err)
			continue
		}
		out = append(out, token)
	}
	return out
}
// DelegateBroadcast creates broadcast tokens for subject on topic, chained
// on our provide anchors; depending on selfSign it may also (SelfSignAlso)
// or exclusively (SelfSignOnly) self-sign. With SelfSignNo it errors when
// no anchored token could be produced.
func (ctx *BasicCapabilityContext) DelegateBroadcast(subject did.DID, topic string, expire uint64, provide []Capability, selfSign SelfSignMode) (TokenList, error) {
	if len(provide) == 0 {
		return TokenList{}, nil
	}
	var result []*Token
	if selfSign != SelfSignOnly {
		// issue tokens chained on our provide anchors as appropriate
		for _, trustAnchor := range ctx.getProvideAnchors() {
			chained := ctx.delegateBroadcast(ctx.getProvideTokens(trustAnchor), trustAnchor, subject, topic, expire, provide)
			result = append(result, chained...)
		}
		if selfSign == SelfSignNo {
			if len(result) == 0 {
				return TokenList{}, ErrNotAuthorized
			}
			return TokenList{Tokens: result}, nil
		}
	}
	selfTokens, err := ctx.Grant(Broadcast, subject, did.DID{}, []string{topic}, expire, 0, provide)
	if err != nil {
		return TokenList{}, fmt.Errorf("error granting broadcast: %w", err)
	}
	result = append(result, selfTokens.Tokens...)
	return TokenList{Tokens: result}, nil
}
// delegateBroadcast derives broadcast tokens for subject on topic from
// tokenList. A source token is used only when it is anchored in anchor and
// allows delegating every requested capability (no widening).
//
// Fixed: the debug log said "error delegating invocation" — this is the
// broadcast path (copy-paste from delegateInvocation).
func (ctx *BasicCapabilityContext) delegateBroadcast(tokenList []*Token, anchor did.DID, subject did.DID, topic string, expire uint64, provide []Capability) []*Token {
	topicCap := Capability(topic)
	var result []*Token //nolint
	for _, t := range tokenList {
		var providing []Capability
		for _, c := range provide {
			if t.Anchor(anchor) && t.AllowDelegation(Broadcast, ctx.DID(), did.DID{}, []Capability{topicCap}, expire, c) {
				providing = append(providing, c)
			}
		}
		if len(providing) == 0 {
			continue
		}
		if len(provide) > len(providing) {
			// attempt to widen caps
			continue
		}
		token, err := t.DelegateBroadcast(ctx.provider, subject, topicCap, expire, providing)
		if err != nil {
			log.Debugf("error delegating broadcast %s to %s: %s", providing, subject, err)
			continue
		}
		result = append(result, token)
	}
	return result
}
// Consume ingests the serialized token list in capToken. A token is kept
// only when it chains to an anchor we accept (our own DID, the origin, a
// trust root, or a require side-root granting the action) and verifies
// against the trust context; all other tokens are skipped with a log entry.
func (ctx *BasicCapabilityContext) Consume(origin did.DID, capToken []byte) error {
	if len(capToken) > maxCapabilitySize {
		return ErrTooBig
	}
	var tokens TokenList
	if err := json.Unmarshal(capToken, &tokens); err != nil {
		return fmt.Errorf("unmarshaling payload: %w", err)
	}
	rootAnchors := ctx.getRoots()
	requireAnchors := ctx.getRequireAnchors()
	// anchored reports whether the token chains to any anchor we accept.
	anchored := func(t *Token) bool {
		if t.Anchor(ctx.DID()) || t.Anchor(origin) {
			return true
		}
		for _, anchor := range rootAnchors {
			if t.Anchor(anchor) {
				return true
			}
		}
		for _, anchor := range requireAnchors {
			for _, rt := range ctx.getRequireTokens(anchor) {
				if rt.AllowAction(t) {
					return true
				}
			}
		}
		return false
	}
	now := uint64(time.Now().UnixNano())
	for _, t := range tokens.Tokens {
		if !anchored(t) {
			log.Debugf("ignoring token %+v", *t)
			continue
		}
		if err := t.Verify(ctx.trust, now, ctx.revoke); err != nil {
			log.Warnf("failed to verify token issued by %s: %s", t.Issuer(), err)
			continue
		}
		ctx.consumeSubjectToken(t)
	}
	return nil
}
// Discard removes previously consumed capability tokens. A payload that
// fails to unmarshal is silently ignored (best effort).
func (ctx *BasicCapabilityContext) Discard(capTokens []byte) {
	var tokens TokenList
	if err := json.Unmarshal(capTokens, &tokens); err != nil {
		return
	}
	ctx.discardTokens(tokens.Tokens)
}
// consumeAnchorToken inserts t into the token list obtained via getf and
// stores the result via setf, deduplicating by subsumption: if an existing
// token already subsumes t nothing changes, and any existing tokens that t
// subsumes are dropped.
func (ctx *BasicCapabilityContext) consumeAnchorToken(getf func() []*Token, setf func(result []*Token), t *Token) {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()
	existing := getf()
	kept := make([]*Token, 0, len(existing)+1)
	for _, ot := range existing {
		if ot.Subsumes(t) {
			// an already-held token covers the new one; nothing to do
			return
		}
		if !t.Subsumes(ot) {
			kept = append(kept, ot)
		}
	}
	setf(append(kept, t))
}
// consumeRequireToken stores t in the require side-root map, keyed by
// issuer, via the subsumption-deduplicating consumeAnchorToken.
func (ctx *BasicCapabilityContext) consumeRequireToken(t *Token) {
	ctx.consumeAnchorToken(
		func() []*Token { return ctx.require[t.Issuer()] },
		func(result []*Token) {
			ctx.require[t.Issuer()] = result
		},
		t,
	)
}

// consumeProvideToken stores t in the provide map, keyed by issuer, via the
// subsumption-deduplicating consumeAnchorToken.
func (ctx *BasicCapabilityContext) consumeProvideToken(t *Token) {
	ctx.consumeAnchorToken(
		func() []*Token { return ctx.provide[t.Issuer()] },
		func(result []*Token) {
			ctx.provide[t.Issuer()] = result
		},
		t,
	)
}
// consumeSubjectToken appends t to the subject-keyed token map.
// Unlike the anchor maps, no subsumption deduplication is applied here.
func (ctx *BasicCapabilityContext) consumeSubjectToken(t *Token) {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()
	subject := t.Subject()
	tokenList := ctx.tokens[subject]
	tokenList = append(tokenList, t)
	ctx.tokens[subject] = tokenList
}

// consumeRevokeToken records t in the revocation set.
func (ctx *BasicCapabilityContext) consumeRevokeToken(t *Token) {
	ctx.revoke.Revoke(t)
}
// Require verifies that the subject holds at least one of the required
// capabilities for the audience, proven by a token anchored in the explicit
// anchor, one of our trust roots, or a require side-root that permits the
// token's action. An empty require list is rejected as unauthorized.
func (ctx *BasicCapabilityContext) Require(anchor did.DID, subject crypto.ID, audience crypto.ID, require []Capability) error {
	if len(require) == 0 {
		return fmt.Errorf("no capabilities: %w", ErrNotAuthorized)
	}
	subjectDID, err := did.FromID(subject)
	if err != nil {
		return fmt.Errorf("DID for subject: %w", err)
	}
	audienceDID, err := did.FromID(audience)
	if err != nil {
		return fmt.Errorf("DID for audience: %w", err)
	}
	subjectTokens := ctx.getSubjectTokens(subjectDID)
	roots := ctx.getRoots()
	sideRoots := ctx.getRequireAnchors()
	// admits reports whether tok proves capability c via any accepted anchor.
	admits := func(tok *Token, c Capability) bool {
		if tok.Anchor(anchor) && tok.AllowInvocation(subjectDID, audienceDID, c) {
			return true
		}
		for _, root := range roots {
			if tok.Anchor(root) && tok.AllowInvocation(subjectDID, audienceDID, c) {
				return true
			}
		}
		for _, side := range sideRoots {
			for _, rt := range ctx.getRequireTokens(side) {
				if rt.AllowAction(tok) && tok.AllowInvocation(subjectDID, audienceDID, c) {
					return true
				}
			}
		}
		return false
	}
	for _, tok := range subjectTokens {
		for _, c := range require {
			if admits(tok, c) {
				return nil
			}
		}
	}
	return ErrNotAuthorized
}
// RequireBroadcast verifies that the subject holds at least one of the
// required capabilities for broadcasting on topic, proven by a token
// anchored in the explicit anchor, one of our trust roots, or a require
// side-root permitting the token's action.
func (ctx *BasicCapabilityContext) RequireBroadcast(anchor did.DID, subject crypto.ID, topic string, require []Capability) error {
	if len(require) == 0 {
		return fmt.Errorf("no capabilities: %w", ErrNotAuthorized)
	}
	subjectDID, err := did.FromID(subject)
	if err != nil {
		return fmt.Errorf("DID for subject: %w", err)
	}
	subjectTokens := ctx.getSubjectTokens(subjectDID)
	roots := ctx.getRoots()
	sideRoots := ctx.getRequireAnchors()
	topicCap := Capability(topic)
	// admits reports whether tok proves broadcast capability c via any
	// accepted anchor.
	admits := func(tok *Token, c Capability) bool {
		if tok.Anchor(anchor) && tok.AllowBroadcast(subjectDID, topicCap, c) {
			return true
		}
		for _, root := range roots {
			if tok.Anchor(root) && tok.AllowBroadcast(subjectDID, topicCap, c) {
				return true
			}
		}
		for _, side := range sideRoots {
			for _, rt := range ctx.getRequireTokens(side) {
				if rt.AllowAction(tok) && tok.AllowBroadcast(subjectDID, topicCap, c) {
					return true
				}
			}
		}
		return false
	}
	for _, tok := range subjectTokens {
		for _, c := range require {
			if admits(tok, c) {
				return nil
			}
		}
	}
	return ErrNotAuthorized
}
// Provide prepares the serialized token payload proving authority to the
// subject: invocation tokens for the audience (required), plus delegation
// tokens for the target when provide capabilities are given.
func (ctx *BasicCapabilityContext) Provide(target did.DID, subject crypto.ID, audience crypto.ID, expire uint64, invoke []Capability, provide []Capability) ([]byte, error) {
	if len(invoke) == 0 && len(provide) == 0 {
		return nil, fmt.Errorf("no capabilities: %w", ErrNotAuthorized)
	}
	subjectDID, err := did.FromID(subject)
	if err != nil {
		return nil, fmt.Errorf("DID for subject: %w", err)
	}
	audienceDID, err := did.FromID(audience)
	if err != nil {
		return nil, fmt.Errorf("DID for audience: %w", err)
	}
	if len(invoke) == 0 {
		return nil, fmt.Errorf("no invocation capabilities: %w", ErrNotAuthorized)
	}
	invocation, err := ctx.DelegateInvocation(target, subjectDID, audienceDID, expire, invoke, SelfSignAlso)
	if err != nil {
		return nil, fmt.Errorf("cannot provide invocation tokens: %w", err)
	}
	if len(invocation.Tokens) == 0 {
		return nil, fmt.Errorf("cannot provide the necessary invocation tokens: %w", ErrNotAuthorized)
	}
	result := invocation.Tokens
	if len(provide) > 0 {
		delegation, err := ctx.Delegate(target, subjectDID, nil, expire, 1, provide, SelfSignOnly)
		if err != nil {
			return nil, fmt.Errorf("cannot provide delegation tokens: %w", err)
		}
		if len(delegation.Tokens) == 0 {
			return nil, fmt.Errorf("cannot provide the necessary delegation tokens: %w", ErrNotAuthorized)
		}
		result = append(result, delegation.Tokens...)
	}
	data, err := json.Marshal(TokenList{Tokens: result})
	if err != nil {
		return nil, fmt.Errorf("marshaling payload: %w", err)
	}
	return data, nil
}
// ProvideBroadcast prepares the serialized token payload proving the
// subject's authority to broadcast on topic.
func (ctx *BasicCapabilityContext) ProvideBroadcast(subject crypto.ID, topic string, expire uint64, provide []Capability) ([]byte, error) {
	if len(provide) == 0 {
		return nil, fmt.Errorf("no capabilities: %w", ErrNotAuthorized)
	}
	subjectDID, err := did.FromID(subject)
	if err != nil {
		return nil, fmt.Errorf("DID for subject: %w", err)
	}
	broadcast, err := ctx.DelegateBroadcast(subjectDID, topic, expire, provide, SelfSignAlso)
	if err != nil {
		return nil, fmt.Errorf("cannot provide broadcast tokens: %w", err)
	}
	if len(broadcast.Tokens) == 0 {
		return nil, fmt.Errorf("cannot provide the necessary broadcast tokens: %w", ErrNotAuthorized)
	}
	payload, err := json.Marshal(broadcast)
	if err != nil {
		return nil, fmt.Errorf("marshaling payload: %w", err)
	}
	return payload, nil
}
// getRoots returns a snapshot of the current root trust anchors.
func (ctx *BasicCapabilityContext) getRoots() []did.DID {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()
	result := make([]did.DID, 0, len(ctx.roots))
	for anchor := range ctx.roots {
		result = append(result, anchor)
	}
	return result
}

// addRoots registers the given DIDs as root trust anchors (idempotent).
func (ctx *BasicCapabilityContext) addRoots(anchors []did.DID) {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()
	for _, anchor := range anchors {
		ctx.roots[anchor] = struct{}{}
	}
}

// getRequireAnchors returns a snapshot of the require side-root issuers.
func (ctx *BasicCapabilityContext) getRequireAnchors() []did.DID {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()
	result := make([]did.DID, 0, len(ctx.require))
	for anchor := range ctx.require {
		result = append(result, anchor)
	}
	return result
}

// getProvideAnchors returns a snapshot of the provide anchor issuers.
func (ctx *BasicCapabilityContext) getProvideAnchors() []did.DID {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()
	result := make([]did.DID, 0, len(ctx.provide))
	for anchor := range ctx.provide {
		result = append(result, anchor)
	}
	return result
}
// getTokens fetches a token list via getf, drops tokens that have already
// expired, writes the pruned list back via setf, and returns it. Returns
// nil when the list does not exist.
func (ctx *BasicCapabilityContext) getTokens(getf func() ([]*Token, bool), setf func([]*Token)) []*Token {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()
	tokens, ok := getf()
	if !ok {
		return nil
	}
	// filter expired; clone so the stored slice is not mutated in place
	now := uint64(time.Now().UnixNano())
	live := slices.DeleteFunc(slices.Clone(tokens), func(t *Token) bool {
		return t.ExpireBefore(now)
	})
	setf(live)
	return live
}
// getRequireTokens returns the live (non-expired) require tokens for anchor.
func (ctx *BasicCapabilityContext) getRequireTokens(anchor did.DID) []*Token {
	return ctx.getTokens(
		func() ([]*Token, bool) { result, ok := ctx.require[anchor]; return result, ok },
		func(result []*Token) { ctx.require[anchor] = result },
	)
}

// getProvideTokens returns the live (non-expired) provide tokens for anchor.
func (ctx *BasicCapabilityContext) getProvideTokens(anchor did.DID) []*Token {
	return ctx.getTokens(
		func() ([]*Token, bool) { result, ok := ctx.provide[anchor]; return result, ok },
		func(result []*Token) { ctx.provide[anchor] = result },
	)
}

// getSubjectTokens returns the live (non-expired) tokens held for subject.
func (ctx *BasicCapabilityContext) getSubjectTokens(subject did.DID) []*Token {
	return ctx.getTokens(
		func() ([]*Token, bool) { result, ok := ctx.tokens[subject]; return result, ok },
		func(result []*Token) { ctx.tokens[subject] = result },
	)
}
// discardTokens removes each given token from the subject-keyed map,
// matching by issuer and nonce, and drops subjects whose lists become empty.
func (ctx *BasicCapabilityContext) discardTokens(tokens []*Token) {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()
	for _, t := range tokens {
		subject := t.Subject()
		remaining := slices.DeleteFunc(slices.Clone(ctx.tokens[subject]), func(ot *Token) bool {
			return t.Issuer() == ot.Issuer() && bytes.Equal(t.Nonce(), ot.Nonce())
		})
		if len(remaining) > 0 {
			ctx.tokens[subject] = remaining
		} else {
			delete(ctx.tokens, subject)
		}
	}
}
// gc runs the periodic token garbage collection until gcCtx is cancelled.
func (ctx *BasicCapabilityContext) gc(gcCtx context.Context, gcInterval time.Duration) {
	ticker := time.NewTicker(gcInterval)
	defer ticker.Stop()
	for {
		select {
		case <-gcCtx.Done():
			return
		case <-ticker.C:
			ctx.gcTokens()
		}
	}
}
func (ctx *BasicCapabilityContext) gcTokens() {
ctx.mx.Lock()
defer ctx.mx.Unlock()
now := uint64(time.Now().UnixNano())
for anchor, tokens := range ctx.require {
tokens = slices.DeleteFunc(slices.Clone(tokens), func(t *Token) bool {
return t.ExpireBefore(now)
})
if len(tokens) == 0 {
delete(ctx.require, anchor)
} else {
ctx.require[anchor] = tokens
}
}
for anchor, tokens := range ctx.provide {
tokens = slices.DeleteFunc(slices.Clone(tokens), func(t *Token) bool {
return t.ExpireBefore(now)
})
if len(tokens) == 0 {
delete(ctx.provide, anchor)
} else {
ctx.provide[anchor] = tokens
}
}
ctx.revoke.gc(now)
for subject, tokens := range ctx.tokens {
tokens = slices.DeleteFunc(slices.Clone(tokens), func(t *Token) bool {
return t.ExpireBefore(now)
})
if len(tokens) == 0 {
delete(ctx.tokens, subject)
} else {
ctx.tokens[subject] = tokens
}
}
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package ucan
import (
"encoding/json"
"fmt"
"io"
"gitlab.com/nunet/device-management-service/lib/did"
)
// Saver is implemented by capability contexts that can persist themselves
// to a writer.
type Saver interface {
	Save(wr io.Writer) error
}

// CapabilityContextView is the JSON persistence form of a capability
// context: its DID, trust roots and token lists.
type CapabilityContextView struct {
	DID did.DID `json:"did"`
	Roots []did.DID `json:"roots"`
	Require TokenList `json:"require"`
	Provide TokenList `json:"provide"`
	Revoke TokenList `json:"revoke"`
}
// SaveCapabilityContext persists ctx to wr: BasicCapabilityContext is
// serialized directly; other implementations must provide Saver.
func SaveCapabilityContext(ctx CapabilityContext, wr io.Writer) error {
	switch c := ctx.(type) {
	case *BasicCapabilityContext:
		return saveCapabilityContext(c, wr)
	case Saver:
		return c.Save(wr)
	}
	return fmt.Errorf("cannot save context: %w", ErrBadContext)
}
// saveCapabilityContext encodes ctx as a CapabilityContextView JSON
// document on wr.
func saveCapabilityContext(ctx *BasicCapabilityContext, wr io.Writer) error {
	roots, require, provide, revoke := ctx.ListRoots()
	view := &CapabilityContextView{
		DID:     ctx.provider.DID(),
		Roots:   roots,
		Require: require,
		Provide: provide,
		Revoke:  revoke,
	}
	if err := json.NewEncoder(wr).Encode(view); err != nil {
		return fmt.Errorf("encoding capability context view: %w", err)
	}
	return nil
}
// LoadCapabilityContextWithName reconstructs a capability context from the
// JSON view on rd, dropping tokens that have expired since it was saved.
func LoadCapabilityContextWithName(name string, trust did.TrustContext, rd io.Reader) (CapabilityContext, error) {
	var view CapabilityContextView
	if err := json.NewDecoder(rd).Decode(&view); err != nil {
		return nil, fmt.Errorf("decoding capability context view: %w", err)
	}
	// live filters out tokens that have already expired.
	live := func(in TokenList) (out TokenList) {
		for _, t := range in.Tokens {
			if !t.Expired() {
				out.Tokens = append(out.Tokens, t)
			}
		}
		return out
	}
	return NewCapabilityContextWithName(
		name, trust, view.DID, view.Roots,
		live(view.Require), live(view.Provide), live(view.Revoke),
	)
}
// LoadCapabilityContext restores a capability context from rd using the
// default context name "dms".
func LoadCapabilityContext(trust did.TrustContext, rd io.Reader) (CapabilityContext, error) {
	return LoadCapabilityContextWithName("dms", trust, rd)
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package ucan
import (
"crypto/rand"
"encoding/json"
"fmt"
"slices"
"sync"
"time"
"gitlab.com/nunet/device-management-service/lib/did"
)
// Action identifies what a token authorizes its holder to do.
type Action string

const (
	// Invoke authorizes direct invocation of a capability.
	Invoke Action = "invoke"
	// Delegate authorizes re-delegating capabilities to others.
	Delegate Action = "delegate"
	// Broadcast authorizes publishing on a topic.
	Broadcast Action = "broadcast"
	// Revoke marks a token that revokes another token.
	Revoke Action = "revoke"

	nonceLength = 12 // 96 bits
)

// signaturePrefix domain-separates token signatures from other payloads
// signed with the same key.
var signaturePrefix = []byte("dms:token:")
// Token is a tagged union over the supported capability-token envelopes;
// exactly one of the fields is expected to be populated.
type Token struct {
	// DMS tokens
	DMS *DMSToken `json:"dms,omitempty"`
	// UCAN standard (when it is done) envelope for BYO anchors
	UCAN *BYOToken `json:"ucan,omitempty"`
}
// DMSToken is the native DMS capability token. It binds an action and a
// set of capabilities to issuer/subject/audience DIDs, optionally chains
// back to a delegating parent token, and is signed by the issuer over
// SignatureData.
type DMSToken struct {
	Action     Action       `json:"act"`
	Issuer     did.DID      `json:"iss"`
	Subject    did.DID      `json:"sub"`
	Audience   did.DID      `json:"aud"`
	Topic      []Capability `json:"topic,omitempty"`
	Capability []Capability `json:"cap"`
	Nonce      []byte       `json:"nonce"` // random per-token value; part of the revocation key
	Expire     uint64       `json:"exp"`   // expiry, nanoseconds since the Unix epoch
	Depth      uint64       `json:"depth,omitempty"` // max delegation depth; 0 means unlimited
	Chain      *Token       `json:"chain,omitempty"` // delegating parent token, if any
	Signature  []byte       `json:"sig,omitempty"`
}
// BYOToken is a placeholder for standard UCAN envelopes (bring-your-own
// trust anchors).
type BYOToken struct {
	// TODO followup
}

// TokenList is a serializable collection of tokens.
type TokenList struct {
	Tokens []*Token `json:"tok,omitempty"`
}
// RevocationSet tracks revocation tokens keyed by their RevocationKey,
// guarded by an RWMutex for concurrent use.
// NOTE(review): Revoke writes to the revoked map, so it must be allocated
// by whatever constructs a RevocationSet — confirm against the
// constructor (not visible here).
type RevocationSet struct {
	lk      sync.RWMutex
	revoked map[string]*Token
}
// Revoked reports whether a token with the given revocation key has been
// revoked.
func (r *RevocationSet) Revoked(key string) bool {
	r.lk.RLock()
	defer r.lk.RUnlock()
	_, found := r.revoked[key]
	return found
}
// Revoke records t in the set under its revocation key.
// NOTE(review): panics if the revoked map was never allocated — confirm
// the constructor initializes it.
func (r *RevocationSet) Revoke(t *Token) {
	r.lk.Lock()
	defer r.lk.Unlock()
	r.revoked[t.RevocationKey()] = t
}
// List returns the currently-known revocation tokens, skipping any whose
// own lifetime has already ended.
func (r *RevocationSet) List() []*Token {
	r.lk.RLock()
	defer r.lk.RUnlock()
	deadline := uint64(time.Now().UnixNano())
	live := make([]*Token, 0, len(r.revoked))
	for _, tok := range r.revoked {
		if !tok.ExpireBefore(deadline) {
			live = append(live, tok)
		}
	}
	return live
}
// gc removes revocation tokens that expired before now (nanoseconds since
// the Unix epoch); expired revocations no longer need tracking.
func (r *RevocationSet) gc(now uint64) {
	r.lk.Lock()
	defer r.lk.Unlock()
	for key, tok := range r.revoked {
		if tok.ExpireBefore(now) {
			delete(r.revoked, key)
		}
	}
}
// RevocationKey returns the key under which this token is tracked in a
// RevocationSet; unsupported envelopes yield "".
func (t *Token) RevocationKey() string {
	if t.DMS != nil {
		return t.DMS.RevocationKey()
	}
	// TODO UCAN envelopes for BYO trust; followup
	return ""
}
// RevocationKey derives the token's revocation-set key from its issuer,
// subject and nonce.
func (t *DMSToken) RevocationKey() string {
	return fmt.Sprintf("%s#%s#%s", t.Issuer, t.Subject, string(t.Nonce))
}

// Revoked reports whether this token appears in the given revocation set.
func (t *DMSToken) Revoked(revoke *RevocationSet) bool {
	return revoke.Revoked(t.RevocationKey())
}
// SignatureData returns the bytes that are signed/verified for this
// token; unsupported envelopes yield ErrBadToken.
func (t *Token) SignatureData() ([]byte, error) {
	if t.DMS != nil {
		return t.DMS.SignatureData()
	}
	// TODO UCAN envelopes for BYO trust; followup
	return nil, ErrBadToken
}
// SignatureData serializes the token with its signature cleared and
// prepends the domain-separation prefix, producing the exact bytes that
// are signed and verified.
func (t *DMSToken) SignatureData() ([]byte, error) {
	unsigned := *t
	unsigned.Signature = nil
	payload, err := json.Marshal(&unsigned)
	if err != nil {
		return nil, fmt.Errorf("signature data: %w", err)
	}
	buf := make([]byte, 0, len(signaturePrefix)+len(payload))
	buf = append(buf, signaturePrefix...)
	buf = append(buf, payload...)
	return buf, nil
}
// Issuer returns the token's issuer DID; unsupported envelopes yield the
// zero DID.
func (t *Token) Issuer() did.DID {
	if t.DMS != nil {
		return t.DMS.Issuer
	}
	// TODO UCAN envelopes for BYO trust; followup
	return did.DID{}
}
// Subject returns the token's subject DID; unsupported envelopes yield
// the zero DID.
func (t *Token) Subject() did.DID {
	if t.DMS != nil {
		return t.DMS.Subject
	}
	// TODO UCAN envelopes for BYO trust; followup
	return did.DID{}
}
// Audience returns the token's audience DID; unsupported envelopes yield
// the zero DID.
func (t *Token) Audience() did.DID {
	if t.DMS != nil {
		return t.DMS.Audience
	}
	// TODO UCAN envelopes for BYO trust; followup
	return did.DID{}
}
// Capability returns the capabilities granted by the token; unsupported
// envelopes yield nil.
func (t *Token) Capability() []Capability {
	if t.DMS != nil {
		return t.DMS.Capability
	}
	// TODO UCAN envelopes for BYO trust; followup
	return nil
}
// Topic returns the broadcast topics covered by the token; unsupported
// envelopes yield nil.
func (t *Token) Topic() []Capability {
	if t.DMS != nil {
		return t.DMS.Topic
	}
	// TODO UCAN envelopes for BYO trust; followup
	return nil
}
// Expire returns the token's own expiration timestamp (nanoseconds since
// the Unix epoch); unsupported envelopes yield 0, i.e. expired right
// after the Unix big bang.
func (t *Token) Expire() uint64 {
	if t.DMS != nil {
		return t.DMS.Expire
	}
	// TODO UCAN envelopes for BYO trust; followup
	return 0
}
// Nonce returns the token's random nonce; unsupported envelopes yield
// nil.
func (t *Token) Nonce() []byte {
	if t.DMS != nil {
		return t.DMS.Nonce
	}
	// TODO UCAN envelopes for BYO trust; followup
	return nil
}
// Action returns the token's action; unsupported envelopes yield the
// empty action.
func (t *Token) Action() Action {
	if t.DMS != nil {
		return t.DMS.Action
	}
	// TODO UCAN envelopes for BYO trust; followup
	return Action("")
}
// Verify checks the token's validity at time now (nanoseconds since the
// Unix epoch) — expiry, revocation, delegation chain and signature —
// starting at chain depth 0.
func (t *Token) Verify(trust did.TrustContext, now uint64, revoke *RevocationSet) error {
	return t.verify(trust, now, 0, revoke)
}
// verify dispatches validation to the underlying envelope; tokens without
// a supported envelope fail with ErrBadToken.
func (t *Token) verify(trust did.TrustContext, now, depth uint64, revoke *RevocationSet) error {
	if t.DMS != nil {
		return t.DMS.verify(trust, now, depth, revoke)
	}
	// TODO UCAN envelopes for BYO trust; followup
	return ErrBadToken
}
// verify checks a DMS token at time now (nanoseconds since the Unix
// epoch): expiry, delegation depth, revocation, chain consistency and
// capability attenuation, and finally the issuer's signature.
func (t *DMSToken) verify(trust did.TrustContext, now, depth uint64, revoke *RevocationSet) error {
	if t.ExpireBefore(now) {
		return ErrCapabilityExpired
	}
	// Revocation tokens are self-contained: only the issuer signature is
	// checked; chain/depth/revocation checks do not apply.
	if t.Action == Revoke {
		anchor, err := trust.GetAnchor(t.Issuer)
		if err != nil {
			return fmt.Errorf("verify: anchor: %w", err)
		}
		data, err := t.SignatureData()
		if err != nil {
			return fmt.Errorf("verify: signature data: %w", err)
		}
		if err := anchor.Verify(data, t.Signature); err != nil {
			return fmt.Errorf("verify: signature: %w", err)
		}
		return nil
	}
	if t.Depth > 0 && depth > t.Depth {
		return fmt.Errorf("max token depth exceeded: %w", ErrNotAuthorized)
	}
	if t.Revoked(revoke) {
		return fmt.Errorf("verify: token has been revoked: %w", ErrNotAuthorized)
	}
	if t.Chain != nil {
		if t.Chain.Action() != Delegate {
			return fmt.Errorf("verify: chain does not allow delegation: %w", ErrNotAuthorized)
		}
		// The delegating chain must outlive this token.
		if t.Chain.ExpireBefore(t.Expire) {
			return ErrCapabilityExpired
		}
		if err := t.Chain.verify(trust, now, depth+1, revoke); err != nil {
			return err
		}
		if !t.Issuer.Equal(t.Chain.Subject()) {
			return fmt.Errorf("verify: issuer/chain subject mismatch: %w", ErrNotAuthorized)
		}
		// Every capability claimed here must be delegable by the chain;
		// needCapability shrinks as capabilities are accounted for.
		needCapability := slices.Clone(t.Capability)
		for _, c := range t.Capability {
			if t.Chain.allowDelegation(t.Issuer, t.Audience, t.Topic, t.Expire, c) {
				needCapability = slices.DeleteFunc(needCapability, func(oc Capability) bool {
					return c == oc
				})
				if len(needCapability) == 0 {
					break
				}
			}
		}
		if len(needCapability) > 0 {
			return fmt.Errorf("verify: capabilities are not allowed by the chain: %w", ErrNotAuthorized)
		}
	}
	// Finally verify the issuer's signature over the canonical payload.
	anchor, err := trust.GetAnchor(t.Issuer)
	if err != nil {
		return fmt.Errorf("verify: anchor: %w", err)
	}
	data, err := t.SignatureData()
	if err != nil {
		return fmt.Errorf("verify: signature data: %w", err)
	}
	if err := anchor.Verify(data, t.Signature); err != nil {
		return fmt.Errorf("verify: signature: %w", err)
	}
	return nil
}
// AllowAction reports whether this token authorizes issuing the other
// token ot; unsupported envelopes yield false.
func (t *Token) AllowAction(ot *Token) bool {
	if t.DMS != nil {
		return t.DMS.AllowAction(ot)
	}
	// TODO UCAN envelopes for BYO trust; followup
	return false
}
// AllowAction reports whether this delegation token authorizes issuing
// the other token ot: t must be a delegation that outlives ot, ot must be
// anchored at t's subject within t's depth limit, audiences must agree,
// and every capability and topic of ot must be implied by one of t's.
func (t *DMSToken) AllowAction(ot *Token) bool {
	if t.Action != Delegate {
		return false
	}
	if t.ExpireBefore(ot.Expire()) {
		return false
	}
	if !ot.Anchor(t.Subject) {
		return false
	}
	if t.Depth > 0 {
		depth, ok := ot.AnchorDepth(t.Subject)
		if ok && depth > t.Depth {
			return false
		}
	}
	// An empty audience on t means "any audience".
	if !t.Audience.Empty() && !t.Audience.Equal(ot.Audience()) {
		return false
	}
	// Every capability of ot must be implied by some capability of t.
	for _, oc := range ot.Capability() {
		allow := false
		for _, c := range t.Capability {
			if c.Implies(oc) {
				allow = true
				break
			}
		}
		if !allow {
			return false
		}
	}
	// Likewise for topics.
	for _, otherTopic := range ot.Topic() {
		allow := false
		for _, topic := range t.Topic {
			if topic.Implies(otherTopic) {
				allow = true
				break
			}
		}
		if !allow {
			return false
		}
	}
	return true
}
// Size returns the length in bytes of the token's signature payload.
// Errors from SignatureData are deliberately ignored; unsupported
// envelopes yield 0.
func (t *Token) Size() int {
	data, _ := t.SignatureData()
	return len(data)
}
// Subsumes reports whether this token is at least as powerful as ot;
// unsupported envelopes yield false.
func (t *Token) Subsumes(ot *Token) bool {
	if t.DMS != nil {
		return t.DMS.Subsumes(ot)
	}
	// TODO UCAN envelopes for BYO trust; followup
	return false
}
// Subsumes reports whether t covers ot: same issuer/subject/audience, a
// lifetime at least as long, a matching (recursively subsuming) chain,
// and capabilities that imply all of ot's.
func (t *DMSToken) Subsumes(ot *Token) bool {
	if t.Issuer.Equal(ot.Issuer()) &&
		t.Subject.Equal(ot.Subject()) &&
		t.Audience.Equal(ot.Audience()) &&
		t.Expire >= ot.Expire() {
		// Recursively check if chains subsume each other
		// If both have chains, the chains must subsume each other
		// If only one has a chain, they don't subsume each other
		hasTChain := t.Chain != nil
		hasOtChain := ot.DMS != nil && ot.DMS.Chain != nil
		if hasTChain != hasOtChain {
			// One has a chain, the other doesn't - they don't subsume each other
			return false
		}
		if hasTChain && hasOtChain {
			// Both have chains - recursively check if t's chain subsumes ot's chain
			if !t.Chain.Subsumes(ot.DMS.Chain) {
				return false
			}
		}
		// Every capability of ot must be implied by some capability of t.
	loop:
		for _, oc := range ot.Capability() {
			for _, c := range t.Capability {
				if c.Implies(oc) {
					continue loop
				}
			}
			return false
		}
		return true
	}
	return false
}
// AllowInvocation reports whether the token authorizes invoking
// capability c on subject for audience; unsupported envelopes yield
// false.
func (t *Token) AllowInvocation(subject, audience did.DID, c Capability) bool {
	if t.DMS != nil {
		return t.DMS.AllowInvocation(subject, audience, c)
	}
	// TODO UCAN envelopes for BYO trust; followup
	return false
}
// AllowInvocation reports whether this invocation token permits calling
// capability c on subject for audience. An empty token audience matches
// any audience.
func (t *DMSToken) AllowInvocation(subject, audience did.DID, c Capability) bool {
	switch {
	case t.Action != Invoke:
		return false
	case !t.Subject.Equal(subject):
		return false
	case !t.Audience.Empty() && !t.Audience.Equal(audience):
		return false
	}
	return slices.ContainsFunc(t.Capability, func(granted Capability) bool {
		return granted.Implies(c)
	})
}
// AllowBroadcast reports whether the token authorizes broadcasting
// capability c on topic for subject; unsupported envelopes yield false.
func (t *Token) AllowBroadcast(subject did.DID, topic Capability, c Capability) bool {
	if t.DMS != nil {
		return t.DMS.AllowBroadcast(subject, topic, c)
	}
	// TODO UCAN envelopes for BYO trust; followup
	return false
}
// AllowBroadcast reports whether this broadcast token permits publishing
// capability c on topic for subject. Broadcast tokens must have an empty
// audience.
func (t *DMSToken) AllowBroadcast(subject did.DID, topic Capability, c Capability) bool {
	switch {
	case t.Action != Broadcast:
		return false
	case !t.Subject.Equal(subject):
		return false
	case !t.Audience.Empty():
		return false
	}
	if !slices.ContainsFunc(t.Topic, func(allowed Capability) bool {
		return allowed.Implies(topic)
	}) {
		return false
	}
	return slices.ContainsFunc(t.Capability, func(allowed Capability) bool {
		return allowed.Implies(c)
	})
}
// AllowDelegation reports whether the token authorizes delegating
// capability c for the given action and parameters; unsupported
// envelopes yield false.
func (t *Token) AllowDelegation(action Action, issuer, audience did.DID, topics []Capability, expire uint64, c Capability) bool {
	if t.DMS != nil {
		return t.DMS.AllowDelegation(action, issuer, audience, topics, expire, c)
	}
	// TODO UCAN envelopes for BYO trust; followup
	return false
}
// AllowDelegation reports whether this token can delegate capability c
// for the given action. Delegating a further delegation requires two
// levels of remaining depth (a single level would be a dead-end
// certificate); everything else requires one.
func (t *DMSToken) AllowDelegation(action Action, issuer, audience did.DID, topics []Capability, expire uint64, c Capability) bool {
	requiredDepth := uint64(1)
	if action == Delegate {
		requiredDepth = 2
	}
	if !t.verifyDepth(requiredDepth) {
		return false
	}
	return t.allowDelegation(issuer, audience, topics, expire, c)
}
// allowDelegation dispatches the delegation check to the underlying
// envelope; unsupported envelopes yield false.
func (t *Token) allowDelegation(issuer, audience did.DID, topics []Capability, expire uint64, c Capability) bool {
	if t.DMS != nil {
		return t.DMS.allowDelegation(issuer, audience, topics, expire, c)
	}
	// TODO UCAN envelopes for BYO trust; followup
	return false
}
// allowDelegation reports whether this delegation token allows issuer to
// further grant capability c until expire, for the given audience and
// topics. The token must be a delegation, outlive the grant, name issuer
// as its subject, and imply every requested topic and the capability.
func (t *DMSToken) allowDelegation(issuer, audience did.DID, topics []Capability, expire uint64, c Capability) bool {
	switch {
	case t.Action != Delegate:
		return false
	case t.ExpireBefore(expire):
		return false
	case !t.Subject.Equal(issuer):
		return false
	case !t.Audience.Empty() && !t.Audience.Equal(audience):
		return false
	}
	for _, requested := range topics {
		if !slices.ContainsFunc(t.Topic, func(mine Capability) bool {
			return mine.Implies(requested)
		}) {
			return false
		}
	}
	return slices.ContainsFunc(t.Capability, func(mine Capability) bool {
		return mine.Implies(c)
	})
}
// verifyDepth dispatches the depth check to the underlying envelope;
// unsupported envelopes yield false.
func (t *Token) verifyDepth(depth uint64) bool {
	if t.DMS != nil {
		return t.DMS.verifyDepth(depth)
	}
	// TODO UCAN envelopes for BYO trust; followup
	return false
}
// verifyDepth reports whether the requested delegation depth fits within
// this token's depth limit and, recursively, within its chain's; a Depth
// of 0 means unlimited at this link.
func (t *DMSToken) verifyDepth(depth uint64) bool {
	if t.Depth > 0 && depth > t.Depth {
		return false
	}
	if t.Chain == nil {
		return true
	}
	return t.Chain.verifyDepth(depth + 1)
}
// Delegate derives a new delegation token from this one, signed by
// provider and bound to subject/audience, topics, expiry and depth.
// Unsupported envelopes yield ErrBadToken.
func (t *Token) Delegate(provider did.Provider, subject, audience did.DID, topics []Capability, expire, depth uint64, c []Capability) (*Token, error) {
	switch {
	case t.DMS != nil:
		result, err := t.DMS.Delegate(provider, subject, audience, topics, expire, depth, c)
		if err != nil {
			// Fixed: this previously wrapped the error as "delegate
			// invocation", copied from DelegateInvocation.
			return nil, fmt.Errorf("delegate: %w", err)
		}
		return &Token{DMS: result}, nil
	case t.UCAN != nil:
		// TODO UCAN envelopes for BYO trust; followup
		fallthrough
	default:
		return nil, ErrBadToken
	}
}
// Delegate derives a new delegation token from this one; see delegate for
// the depth and signing rules.
func (t *DMSToken) Delegate(provider did.Provider, subject, audience did.DID, topics []Capability, expire, depth uint64, c []Capability) (*DMSToken, error) {
	return t.delegate(Delegate, provider, subject, audience, topics, expire, depth, c)
}
// delegate creates and signs a child token of t with the given action and
// parameters, chaining back to t. Only delegation tokens may delegate;
// delegating a further delegation needs two levels of remaining depth.
func (t *DMSToken) delegate(action Action, provider did.Provider, subject, audience did.DID, topics []Capability, expire, depth uint64, c []Capability) (*DMSToken, error) {
	if t.Action != Delegate {
		return nil, ErrNotAuthorized
	}
	if action == Delegate {
		if !t.verifyDepth(2) {
			// certificate would be dead end with 1
			return nil, ErrNotAuthorized
		}
	} else {
		if !t.verifyDepth(1) {
			return nil, ErrNotAuthorized
		}
	}
	// Fresh random nonce makes the child's revocation key unique.
	nonce := make([]byte, nonceLength)
	_, err := rand.Read(nonce)
	if err != nil {
		return nil, fmt.Errorf("nonce: %w", err)
	}
	result := &DMSToken{
		Action:     action,
		Issuer:     provider.DID(),
		Subject:    subject,
		Audience:   audience,
		Topic:      topics,
		Capability: c,
		Nonce:      nonce,
		Expire:     expire,
		Depth:      depth,
		Chain:      &Token{DMS: t},
	}
	// Sign the canonical payload (token minus signature, with prefix).
	data, err := result.SignatureData()
	if err != nil {
		return nil, fmt.Errorf("delegate: %w", err)
	}
	sig, err := provider.Sign(data)
	if err != nil {
		return nil, fmt.Errorf("sign: %w", err)
	}
	result.Signature = sig
	return result, nil
}
// DelegateInvocation derives an invocation token from this one, signed by
// provider and bound to subject/audience until expire; unsupported
// envelopes yield ErrBadToken.
func (t *Token) DelegateInvocation(provider did.Provider, subject, audience did.DID, expire uint64, c []Capability) (*Token, error) {
	if t.DMS == nil {
		// TODO UCAN envelopes for BYO trust; followup
		return nil, ErrBadToken
	}
	dms, err := t.DMS.DelegateInvocation(provider, subject, audience, expire, c)
	if err != nil {
		return nil, fmt.Errorf("delegate invocation: %w", err)
	}
	return &Token{DMS: dms}, nil
}
// DelegateInvocation derives an invocation token (no topics, unlimited
// depth irrelevant for invocations) from this delegation token.
func (t *DMSToken) DelegateInvocation(provider did.Provider, subject, audience did.DID, expire uint64, c []Capability) (*DMSToken, error) {
	return t.delegate(Invoke, provider, subject, audience, nil, expire, 0, c)
}
// DelegateBroadcast derives a broadcast token from this one, signed by
// provider and scoped to the given topic; unsupported envelopes yield
// ErrBadToken.
func (t *Token) DelegateBroadcast(provider did.Provider, subject did.DID, topic Capability, expire uint64, c []Capability) (*Token, error) {
	switch {
	case t.DMS != nil:
		result, err := t.DMS.DelegateBroadcast(provider, subject, topic, expire, c)
		if err != nil {
			// Fixed: this previously wrapped the error as "delegate
			// invocation", copied from DelegateInvocation.
			return nil, fmt.Errorf("delegate broadcast: %w", err)
		}
		return &Token{DMS: result}, nil
	case t.UCAN != nil:
		// TODO UCAN envelopes for BYO trust; followup
		fallthrough
	default:
		return nil, ErrBadToken
	}
}
// DelegateBroadcast derives a broadcast token for a single topic with an
// empty audience (broadcast tokens must not name an audience).
func (t *DMSToken) DelegateBroadcast(provider did.Provider, subject did.DID, topic Capability, expire uint64, c []Capability) (*DMSToken, error) {
	return t.delegate(Broadcast, provider, subject, did.DID{}, []Capability{topic}, expire, 0, c)
}
// Anchor reports whether anchor appears as an issuer anywhere in the
// token's chain; unsupported envelopes yield false.
func (t *Token) Anchor(anchor did.DID) bool {
	if t.DMS != nil {
		return t.DMS.Anchor(anchor)
	}
	// TODO UCAN envelopes for BYO trust; followup
	return false
}
// Anchor reports whether anchor issued this token or any token in its
// chain.
func (t *DMSToken) Anchor(anchor did.DID) bool {
	switch {
	case t.Issuer.Equal(anchor):
		return true
	case t.Chain != nil:
		return t.Chain.Anchor(anchor)
	default:
		return false
	}
}
// AnchorDepth returns how deep in the chain anchor appears as an issuer
// and whether it appears at all; unsupported envelopes yield (0, false).
func (t *Token) AnchorDepth(anchor did.DID) (uint64, bool) {
	if t.DMS != nil {
		return t.DMS.AnchorDepth(anchor)
	}
	// TODO UCAN envelopes for BYO trust; followup
	return 0, false
}
// AnchorDepth returns the number of delegation links between anchor and
// this token (0 when anchor issued this token directly) and whether
// anchor appears in the chain at all. A deeper occurrence in the chain
// overrides a direct match.
func (t *DMSToken) AnchorDepth(anchor did.DID) (depth uint64, have bool) {
	if t.Issuer.Equal(anchor) {
		have = true
		depth = 0
	}
	if t.Chain != nil {
		if chainDepth, chainHave := t.Chain.AnchorDepth(anchor); chainHave {
			have = true
			depth = chainDepth + 1
		}
	}
	return depth, have
}
// Expiry returns the token's own expiration timestamp (nanoseconds since
// the Unix epoch); it is an alias of Expire and yields 0 for unsupported
// envelopes. Previously this duplicated Expire's dispatch switch; it now
// delegates to keep the two in sync.
func (t *Token) Expiry() uint64 {
	return t.Expire()
}
// Expired reports whether the token (or any link in its chain) has
// expired as of the current wall-clock time.
func (t *Token) Expired() bool {
	return t.ExpireBefore(uint64(time.Now().UnixNano()))
}
// ExpireBefore reports whether the token expires strictly before
// deadline (nanoseconds since the Unix epoch); unsupported envelopes are
// always considered expired.
func (t *Token) ExpireBefore(deadline uint64) bool {
	if t.DMS != nil {
		return t.DMS.ExpireBefore(deadline)
	}
	// TODO UCAN envelopes for BYO trust; followup
	return true
}
// ExpireBefore reports whether this token or any link in its delegation
// chain expires strictly before deadline.
func (t *DMSToken) ExpireBefore(deadline uint64) bool {
	if deadline > t.Expire {
		return true
	}
	if t.Chain == nil {
		return false
	}
	return t.Chain.ExpireBefore(deadline)
}
// SelfSigned reports whether the root of the token's chain was issued by
// origin; unsupported envelopes yield false.
func (t *Token) SelfSigned(origin did.DID) bool {
	if t.DMS != nil {
		return t.DMS.SelfSigned(origin)
	}
	// TODO UCAN envelopes for BYO trust; followup
	return false
}
// SelfSigned walks to the root of the delegation chain and reports
// whether that root token was issued by origin.
func (t *DMSToken) SelfSigned(origin did.DID) bool {
	if t.Chain == nil {
		return t.Issuer.Equal(origin)
	}
	return t.Chain.SelfSigned(origin)
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package main
import (
"gitlab.com/nunet/device-management-service/cmd"
)
// @title Device Management Service
// @version 0.4.185
// @description A dashboard application for computing providers.
// @termsOfService https://nunet.io/tos
// @contact.name Support
// @contact.url https://devexchange.nunet.io/
// @contact.email support@nunet.io
// @license.name Apache 2.0
// @license.url http://www.apache.org/licenses/LICENSE-2.0.html
// @host localhost:9999
//
// @Schemes http
//
// @BasePath /api/v1

// main is the DMS entrypoint; the @-annotations above feed swag's
// OpenAPI generation.
func main() {
	cmd.Execute()
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package libp2p
import (
"bytes"
"context"
"errors"
"fmt"
"math/rand"
"strings"
"sync"
"time"
dht_pb "github.com/libp2p/go-libp2p-kad-dht/pb"
"github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/peerstore"
"github.com/libp2p/go-libp2p/core/protocol"
msgio "github.com/libp2p/go-msgio"
"github.com/libp2p/go-msgio/protoio" //nolint:staticcheck
multiaddr "github.com/multiformats/go-multiaddr"
"google.golang.org/protobuf/proto"
dmscrypto "gitlab.com/nunet/device-management-service/lib/crypto"
"gitlab.com/nunet/device-management-service/observability"
commonproto "gitlab.com/nunet/device-management-service/proto/generated/v1/common"
"gitlab.com/nunet/device-management-service/types"
)
// kadv1 is the Kademlia protocol version suffix appended to the
// configured DHT prefix.
const kadv1 = "/kad/1.0.0"

var (
	// ErrInvalidKeyNamespace is returned when a DHT key lacks the expected
	// custom namespace prefix.
	ErrInvalidKeyNamespace = errors.New("invalid key namespace")
	// ErrValidateEnvelopeByPbkey is returned when an advertisement
	// envelope cannot be verified with its embedded public key.
	ErrValidateEnvelopeByPbkey = errors.New("failed to verify public key")
)
// connectToBootstrapNodes dials all configured bootstrap peers
// concurrently, waiting up to 10 seconds overall; individual failures are
// logged but never returned (the function always returns nil).
func (l *Libp2p) connectToBootstrapNodes(ctx context.Context) error {
	// bootstrap all nodes at the same time.
	if len(l.config.BootstrapPeers) > 0 {
		var wg sync.WaitGroup
		// One shared timeout bounds all dials together.
		connectCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
		defer cancel()
		for _, addr := range l.config.BootstrapPeers {
			wg.Add(1)
			go func(peerAddr multiaddr.Multiaddr) {
				defer wg.Done()
				addrInfo, err := peer.AddrInfoFromP2pAddr(peerAddr)
				if err != nil {
					log.Errorf("failed to convert multi addr to addr info %v - %v", peerAddr, err)
					return
				}
				if err := l.Host.Connect(connectCtx, *addrInfo); err != nil {
					log.Errorf("failed to connect to bootstrap node %s - %v", addrInfo.ID.String(), err)
				} else {
					log.Infof("connected to Bootstrap Node %s", addrInfo.ID.String())
				}
			}(addr)
		}
		wg.Wait()
	}
	return nil
}
// bootstrapDHT kicks off the Kademlia DHT bootstrap process and records
// the outcome in the observability log.
func (l *Libp2p) bootstrapDHT(ctx context.Context) error {
	endSpan := observability.StartSpan(ctx, "libp2p_bootstrap")
	defer endSpan()
	err := l.DHT.Bootstrap(ctx)
	if err != nil {
		log.Errorw("libp2p_bootstrap_failure",
			"labels", string(observability.LabelNode),
			"error", err)
		return err
	}
	log.Infow("libp2p_bootstrap_success",
		"labels", string(observability.LabelNode))
	return nil
}
// startRandomWalk starts a background process that crawls the dht by resolving random keys.
// The goroutine runs until ctx is cancelled, alternating crawl bursts
// (bounded at 20 hops via the depth counter) with randomized cooldowns
// whose base interval grows up to 4 hours.
func (l *Libp2p) startRandomWalk(ctx context.Context) {
	go func() {
		log.Debug("starting random walk for dht peer discovery")
		// A simple mechanism to improve our bootstrap and peer discovery:
		// 1. initiate a background, never ending, random walk which tries to resolve
		// random keys in the dht and by extension discovers other peers.
		interval := 5 * time.Minute
		delayOnError := 10 * time.Second
		time.Sleep(interval) // wait for dht ready
		dhtProto := protocol.ID(l.config.DHTPrefix + kadv1)
		sender := newDHTMessageSender(l.Host, dhtProto)
		messenger, err := dht_pb.NewProtocolMessenger(sender)
		if err != nil {
			log.Errorf("random walk: creating protocol messenger: %s", err)
			return
		}
		var depth int
		var key string
		for {
			select {
			case <-ctx.Done():
				log.Debugf("random walk: context done, stopping random walk")
				return
			default:
				// Pick a random target key and start from one of its
				// closest known peers.
				randomPeerID, err := l.DHT.RoutingTable().GenRandPeerID(0)
				if err != nil {
					log.Debugf("random walk: failed to generate random peer ID: %s", err)
					continue
				}
				key = randomPeerID.String()
				log.Debugf("random walk: crawling from %s", key)
				peers, err := l.DHT.GetClosestPeers(ctx, key)
				if err != nil {
					// Exponential backoff (x1.25) capped at 5 minutes.
					log.Debugf("random walk: failed to get closest peers with key=%s - error: %s", randomPeerID.String(), err)
					time.Sleep(delayOnError)
					delayOnError = time.Duration(float64(delayOnError) * 1.25)
					if delayOnError > 5*time.Minute {
						delayOnError = 5 * time.Minute
					}
					continue
				}
				delayOnError = 10 * time.Second
				if len(peers) == 0 {
					continue
				}
				peerID := peers[rand.Intn(len(peers))] //nolint:gosec
				if peerID == l.Host.ID() {
					log.Debugf("random walk: skipping self")
					continue
				}
				log.Debugf("random walk: starting random walk from %s", peerID)
				peerAddrInfo, err := l.resolvePeerAddress(ctx, peerID)
				if err != nil {
					log.Debugf("random walk: failed to resolve address for peer %s - %v", peerID, err)
					continue
				}
				var peerInfos []*peer.AddrInfo
				selected := &peerAddrInfo
				// crawl hops from peer to peer, up to 20 hops per burst.
			crawl:
				log.Debugf("random walk: crawling %s", selected.ID)
				if err := l.Host.Connect(ctx, *selected); err != nil {
					log.Debugf("random walk: failed to connect to peer %s: %s", peerID, err)
					depth++
					continue
				}
				peerInfos, err = messenger.GetClosestPeers(ctx, selected.ID, randomPeerID)
				if err != nil {
					log.Debugf("random walk: failed to get closest peers from %s: %s", selected.ID, err)
					depth++
					continue
				}
				if len(peerInfos) == 0 {
					if depth < 20 {
						depth++
						continue
					}
					goto cooldown
				}
				selected = peerInfos[rand.Intn(len(peerInfos))] //nolint:gosec
				if selected.ID == l.Host.ID() {
					log.Debugf("random walk: skipping self")
					depth++
					continue
				}
				if depth < 20 {
					// Re-randomize the target key and keep hopping.
					randomPeerID, err = l.DHT.RoutingTable().GenRandPeerID(0)
					if err != nil {
						log.Debugf("random walk: failed to generate random peer ID: %s", err)
						goto cooldown
					}
					depth++
					goto crawl
				}
				// cooldown: sleep a randomized delay in
				// [interval/2, 3*interval/2), then grow interval by 1.5x
				// capped at 4 hours.
			cooldown:
				depth = 0
				minDelay := interval / 2
				maxDelay := (3 * interval) / 2
				delay := minDelay + time.Duration(rand.Int63n(int64(maxDelay-minDelay))) //nolint:gosec
				log.Debugf("random walk: cooling down for %s", delay)
				select {
				case <-time.After(delay):
				case <-ctx.Done():
					return
				}
				interval = interval * 3 / 2
				if interval > 4*time.Hour {
					interval = 4 * time.Hour
				}
			}
		}
	}()
}
// dhtValidator validates records written into the DHT under the custom
// namespace by checking the advertisement envelope's signature.
type dhtValidator struct {
	PS              peerstore.Peerstore
	customNamespace string
}
// Validate validates an item placed into the dht.
// Empty values (deletions) pass; otherwise the key must carry the custom
// namespace and the advertisement envelope's Ed25519 signature must
// verify against its embedded public key.
func (d dhtValidator) Validate(key string, value []byte) error {
	endSpan := observability.StartSpan("libp2p_dht_validate")
	defer endSpan()
	// empty value is considered deleting an item from the dht
	if len(value) == 0 {
		log.Infow("libp2p_dht_validate_success",
			"labels", string(observability.LabelNode),
			"key", key)
		return nil
	}
	if !strings.HasPrefix(key, d.customNamespace) {
		log.Errorw("libp2p_dht_validate_failure",
			"labels", string(observability.LabelNode),
			"key", key,
			"error", "invalid key namespace")
		return ErrInvalidKeyNamespace
	}
	// verify signature
	var envelope commonproto.Advertisement
	err := proto.Unmarshal(value, &envelope)
	if err != nil {
		log.Errorw("libp2p_dht_validate_failure",
			"labels", string(observability.LabelNode),
			"key", key,
			"error", fmt.Sprintf("failed to unmarshal envelope: %v", err))
		return fmt.Errorf("%w envelope: %w", types.ErrUnmarshal, err)
	}
	pubKey, err := crypto.UnmarshalEd25519PublicKey(envelope.PublicKey)
	if err != nil {
		log.Errorw("libp2p_dht_validate_failure",
			"labels", string(observability.LabelNode),
			"key", key,
			"error", fmt.Sprintf("failed to unmarshal public key: %v", err))
		return fmt.Errorf("%w: %w", dmscrypto.ErrUnmarshalPublicKey, err)
	}
	// Rebuild the signed payload: peer ID, timestamp, data, public key.
	concatenatedBytes := bytes.Join([][]byte{
		[]byte(envelope.PeerId),
		// NOTE(review): byte(envelope.Timestamp) keeps only the low 8 bits
		// of the timestamp — this only works if signers build the payload
		// the same way; confirm against the advertisement signing code.
		{byte(envelope.Timestamp)},
		envelope.Data,
		envelope.PublicKey,
	}, nil)
	ok, err := pubKey.Verify(concatenatedBytes, envelope.Signature)
	if err != nil {
		log.Errorw("libp2p_dht_validate_failure",
			"labels", string(observability.LabelNode),
			"key", key,
			"error", fmt.Sprintf("failed to verify envelope: %v", err))
		return fmt.Errorf("%w envelope: %w", ErrValidateEnvelopeByPbkey, err)
	}
	if !ok {
		log.Errorw("libp2p_dht_validate_failure",
			"labels", string(observability.LabelNode),
			"key", key,
			"error", "public key didn't sign the payload")
		return fmt.Errorf("%w, public key didn't sign payload", ErrValidateEnvelopeByPbkey)
	}
	log.Infow("libp2p_dht_validate_success",
		"labels", string(observability.LabelNode),
		"key", key)
	return nil
}
func (dhtValidator) Select(_ string, _ [][]byte) (int, error) { return 0, nil }
// dhtMessenger sends Kademlia protobuf messages over fresh libp2p streams
// on a fixed protocol ID.
type dhtMessenger struct {
	host  host.Host
	proto protocol.ID
}
// newDHTMessageSender wraps a host and protocol ID as a
// dht_pb.MessageSender for use with ProtocolMessenger.
func newDHTMessageSender(h host.Host, proto protocol.ID) dht_pb.MessageSender {
	return &dhtMessenger{host: h, proto: proto}
}
// SendRequest opens a stream to p, writes a delimited DHT message, and
// reads back a single varint-delimited reply. The stream is reset on any
// mid-exchange error so the peer does not keep a half-open exchange.
//
// Fixed: the reply buffer was named `bytes`, shadowing the imported
// standard-library bytes package within this function.
func (m *dhtMessenger) SendRequest(ctx context.Context, p peer.ID, msg *dht_pb.Message) (*dht_pb.Message, error) {
	s, err := m.host.NewStream(ctx, p, m.proto)
	if err != nil {
		return nil, fmt.Errorf("open stream: %w", err)
	}
	defer s.Close()
	wr := protoio.NewDelimitedWriter(s)
	if err := wr.WriteMsg(msg); err != nil {
		_ = s.Reset()
		return nil, fmt.Errorf("write message: %w", err)
	}
	r := msgio.NewVarintReaderSize(s, network.MessageSizeMax)
	raw, err := r.ReadMsg()
	if err != nil {
		_ = s.Reset()
		return nil, fmt.Errorf("read message: %w", err)
	}
	defer r.ReleaseMsg(raw)
	reply := new(dht_pb.Message)
	if err := proto.Unmarshal(raw, reply); err != nil {
		_ = s.Reset()
		return nil, fmt.Errorf("unmarshal message: %w", err)
	}
	return reply, nil
}
// SendMessage opens a stream to p, writes a single delimited DHT message
// and half-closes the write side; no reply is awaited.
func (m *dhtMessenger) SendMessage(ctx context.Context, p peer.ID, msg *dht_pb.Message) error {
	stream, err := m.host.NewStream(ctx, p, m.proto)
	if err != nil {
		return fmt.Errorf("open stream: %w", err)
	}
	defer stream.Close()
	if err := protoio.NewDelimitedWriter(stream).WriteMsg(msg); err != nil {
		_ = stream.Reset()
		return fmt.Errorf("write message: %w", err)
	}
	return stream.CloseWrite()
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package libp2p
import (
"context"
"fmt"
"math/rand"
"time"
"github.com/libp2p/go-libp2p/core/discovery"
"github.com/libp2p/go-libp2p/core/peer"
dutil "github.com/libp2p/go-libp2p/p2p/discovery/util"
"gitlab.com/nunet/device-management-service/observability"
)
// discoverDialPeers discovers peers using rendezvous point, filters out
// self and address-less entries, and dials a subset of them.
func (l *Libp2p) discoverDialPeers(ctx context.Context) error {
	endSpan := observability.StartSpan(ctx, "libp2p_peer_discover")
	defer endSpan()
	foundPeers, err := l.findPeersFromRendezvousDiscovery(ctx)
	if err != nil {
		log.Errorw("libp2p_peer_discover_failure", "error", err)
		return err
	}
	if len(foundPeers) > 0 {
		l.discoveredPeers = foundPeers
		log.Debugw("libp2p_peer_discover_success", "foundPeers", len(foundPeers))
	} else {
		// NOTE(review): when discovery finds nothing, the previously
		// discovered peer set is filtered and re-dialed below — confirm
		// this reuse is intended.
		log.Warnw("No peers found during discovery")
	}
	// filter out peers with no listening addresses and self host
	filterSpec := NoAddrIDFilter{ID: l.Host.ID()}
	l.discoveredPeers = PeerPassFilter(l.discoveredPeers, filterSpec)
	l.dialPeers(ctx)
	return nil
}
// advertiseForRendezvousDiscovery announces this node on the DHT under
// the configured rendezvous string so other peers can discover it.
//
// Fixed: the context parameter was named `context`, shadowing the
// standard-library context package inside this function.
func (l *Libp2p) advertiseForRendezvousDiscovery(ctx context.Context) error {
	_, err := l.discovery.Advertise(ctx, l.config.Rendezvous)
	return err
}
// findPeersFromRendezvousDiscovery queries the rendezvous point for other
// peers, bounded by the configured discovery limit.
func (l *Libp2p) findPeersFromRendezvousDiscovery(ctx context.Context) ([]peer.AddrInfo, error) {
	endSpan := observability.StartSpan(ctx, "libp2p_find_peers")
	defer endSpan()
	found, err := dutil.FindPeers(
		ctx,
		l.discovery,
		l.config.Rendezvous,
		discovery.Limit(l.config.PeerCountDiscoveryLimit),
	)
	if err != nil {
		log.Errorw("libp2p_find_peers_failure", "error", err)
		return nil, fmt.Errorf("failed to discover peers: %w", err)
	}
	log.Debugw("libp2p_peers_from_rendezvous", "peersCount", len(found))
	return found, nil
}
// dialPeers dials up to 16 of the discovered peers concurrently. When
// more are known, a random subset is chosen by shuffling; each dial gets
// its own 10-second timeout and failures are only logged.
func (l *Libp2p) dialPeers(ctx context.Context) {
	maxPeers := 16
	peersToConnect := l.discoveredPeers
	if len(peersToConnect) > maxPeers {
		//nolint:gosec
		r := rand.New(rand.NewSource(time.Now().UnixNano()))
		r.Shuffle(len(peersToConnect), func(i, j int) {
			peersToConnect[i], peersToConnect[j] = peersToConnect[j],
				peersToConnect[i]
		})
		// Take only the first maxPeers
		peersToConnect = peersToConnect[:maxPeers]
	}
	for _, p := range peersToConnect {
		if p.ID == l.Host.ID() {
			continue
		}
		if !l.PeerConnected(p.ID) {
			go func(p peer.AddrInfo) {
				dialCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
				defer cancel()
				if err := l.Host.Connect(dialCtx, p); err != nil {
					log.Debugf("couldn't establish connection with: %s - error: %v", p.ID, err)
					return
				}
				log.Debugf("connected with: %s", p.ID)
			}(p)
		}
	}
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package libp2p
import (
"fmt"
"net"
"strings"
"github.com/miekg/dns"
)
// resolveDNS answers the questions in query from the static records map.
// Only A-record questions are supported; other question types get an empty
// success answer. A name missing from the map — or a stored value that is not
// a valid IP address — yields SERVFAIL so the client falls through to the
// next nameserver.
func resolveDNS(query *dns.Msg, records map[string]string) *dns.Msg {
	// Create a response message
	m := new(dns.Msg)
	m.SetReply(query)
	for _, question := range query.Question {
		if question.Qtype != dns.TypeA {
			// We only support A records
			m.SetRcode(query, dns.RcodeSuccess)
			m.Answer = []dns.RR{}
			continue
		}
		// map keys are stored without the trailing dot of a FQDN
		ip, ok := records[strings.TrimSuffix(question.Name, ".")]
		if !ok {
			// Not found in our map, set answer to SERVFAIL for the request to fall through to the next nameserver
			m.SetRcode(query, dns.RcodeServerFailure)
			continue
		}
		parsed := net.ParseIP(ip)
		if parsed == nil {
			// The stored record is not a valid IP address; treat it like a
			// miss instead of returning an A record with a nil address.
			m.SetRcode(query, dns.RcodeServerFailure)
			continue
		}
		// Found record, add A record to the answer section
		a := &dns.A{
			Hdr: dns.RR_Header{Name: question.Name, Rrtype: dns.TypeA, Class: dns.ClassINET},
			A:   parsed,
		}
		m.Answer = append(m.Answer, a)
	}
	return m
}
// handleDNSQuery decodes a raw DNS packet, resolves it against the static
// records map, and returns the encoded response packet.
func handleDNSQuery(packet []byte, records map[string]string) ([]byte, error) {
	// Parse the UDP packet into a DNS message
	msg := new(dns.Msg)
	if err := msg.Unpack(packet); err != nil {
		return nil, fmt.Errorf("failed to decode DNS message: %w", err)
	}
	// Resolve the DNS query
	response := resolveDNS(msg, records)
	log.Debugw("DNS query resolved successfully", "response", response)
	// Encode the response message into a UDP packet
	out, err := response.Pack()
	if err != nil {
		return nil, fmt.Errorf("failed to encode DNS response: %w", err)
	}
	return out, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package libp2p
import (
"github.com/libp2p/go-libp2p/core/connmgr"
"github.com/libp2p/go-libp2p/core/control"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/multiformats/go-multiaddr"
mafilt "github.com/whyrusleeping/multiaddr-filter"
)
// defaultServerFilters lists multiaddr /ipcidr masks for special-purpose
// address blocks (RFC 1918 private ranges, link-local, CGNAT, benchmarking,
// documentation ranges, and their IPv6 equivalents). They are used to build
// deny filters for server-mode nodes, which should neither dial nor announce
// such addresses.
var defaultServerFilters = []string{
	"/ip4/10.0.0.0/ipcidr/8",
	"/ip4/100.64.0.0/ipcidr/10",
	"/ip4/169.254.0.0/ipcidr/16",
	"/ip4/172.16.0.0/ipcidr/12",
	"/ip4/192.0.0.0/ipcidr/24",
	"/ip4/192.0.2.0/ipcidr/24",
	"/ip4/192.168.0.0/ipcidr/16",
	"/ip4/198.18.0.0/ipcidr/15",
	"/ip4/198.51.100.0/ipcidr/24",
	"/ip4/203.0.113.0/ipcidr/24",
	"/ip4/240.0.0.0/ipcidr/4",
	"/ip6/100::/ipcidr/64",
	"/ip6/2001:2::/ipcidr/48",
	"/ip6/2001:db8::/ipcidr/32",
	"/ip6/fc00::/ipcidr/7",
	"/ip6/fe80::/ipcidr/10",
}
// PeerFilter is an interface for filtering peers
// satisfaction of filter criteria allows the peer to pass
type PeerFilter interface {
	// satisfies reports whether p passes the filter.
	satisfies(p peer.AddrInfo) bool
}

// NoAddrIDFilter filters out peers with no listening addresses
// and a peer with a specific ID
type NoAddrIDFilter struct {
	ID peer.ID // the peer ID to exclude (typically our own)
}

// satisfies reports whether p has at least one listen address and is not the
// peer identified by f.ID.
func (f NoAddrIDFilter) satisfies(p peer.AddrInfo) bool {
	return len(p.Addrs) > 0 && p.ID != f.ID
}
// PeerPassFilter returns the subset of peers that satisfy the given filter.
func PeerPassFilter(peers []peer.AddrInfo, pf PeerFilter) []peer.AddrInfo {
	var passed []peer.AddrInfo
	for _, candidate := range peers {
		if !pf.satisfies(candidate) {
			continue
		}
		passed = append(passed, candidate)
	}
	return passed
}
// filtersConnectionGater is a connmgr.ConnectionGater backed by a set of
// multiaddr filters: addresses matching a deny filter are refused for both
// outbound dials and inbound connections.
type filtersConnectionGater multiaddr.Filters

// compile-time interface satisfaction check
var _ connmgr.ConnectionGater = (*filtersConnectionGater)(nil)

// InterceptAddrDial blocks outbound dials to addresses matched by the deny filters.
func (f *filtersConnectionGater) InterceptAddrDial(_ peer.ID, addr multiaddr.Multiaddr) (allow bool) {
	return !(*multiaddr.Filters)(f).AddrBlocked(addr)
}

// InterceptPeerDial allows dialing any peer; filtering is done per address.
func (f *filtersConnectionGater) InterceptPeerDial(_ peer.ID) (allow bool) {
	return true
}

// InterceptAccept blocks inbound connections whose remote address is filtered.
func (f *filtersConnectionGater) InterceptAccept(connAddr network.ConnMultiaddrs) (allow bool) {
	return !(*multiaddr.Filters)(f).AddrBlocked(connAddr.RemoteMultiaddr())
}

// InterceptSecured re-checks the remote address after the security handshake.
func (f *filtersConnectionGater) InterceptSecured(_ network.Direction, _ peer.ID, connAddr network.ConnMultiaddrs) (allow bool) {
	return !(*multiaddr.Filters)(f).AddrBlocked(connAddr.RemoteMultiaddr())
}

// InterceptUpgraded performs no additional filtering once the connection is upgraded.
func (f *filtersConnectionGater) InterceptUpgraded(_ network.Conn) (allow bool, reason control.DisconnectReason) {
	return true, 0
}
// makeAddrsFactory builds the AddrsFactory used by libp2p to decide which
// addresses the host announces: announce replaces the detected addresses when
// non-empty, appendAnnouce entries are added on top (duplicates of announce
// are skipped), and noAnnounce entries (exact multiaddrs or /ipcidr masks)
// are stripped from the final result.
// NOTE(review): returns nil on any malformed multiaddr input — callers must
// treat a nil factory as a configuration error; confirm callers check this.
func makeAddrsFactory(announce []string, appendAnnouce []string, noAnnounce []string) func([]multiaddr.Multiaddr) []multiaddr.Multiaddr {
	var err error                     // To assign to the slice in the for loop
	existing := make(map[string]bool) // To avoid duplicates
	annAddrs := make([]multiaddr.Multiaddr, len(announce))
	for i, addr := range announce {
		annAddrs[i], err = multiaddr.NewMultiaddr(addr)
		if err != nil {
			return nil
		}
		existing[addr] = true
	}
	appendAnnAddrs := make([]multiaddr.Multiaddr, 0)
	for _, addr := range appendAnnouce {
		if existing[addr] {
			// skip AppendAnnounce that is on the Announce list already
			continue
		}
		appendAddr, err := multiaddr.NewMultiaddr(addr)
		if err != nil {
			return nil
		}
		appendAnnAddrs = append(appendAnnAddrs, appendAddr)
	}
	filters := multiaddr.NewFilters()
	noAnnAddrs := map[string]bool{}
	for _, addr := range noAnnounce {
		// a /ipcidr mask becomes a deny filter; anything else must parse as
		// an exact multiaddr and is matched byte-for-byte below
		f, err := mafilt.NewMask(addr)
		if err == nil {
			filters.AddFilter(*f, multiaddr.ActionDeny)
			continue
		}
		maddr, err := multiaddr.NewMultiaddr(addr)
		if err != nil {
			return nil
		}
		noAnnAddrs[string(maddr.Bytes())] = true
	}
	// the returned closure runs on every address-announcement decision
	return func(allAddrs []multiaddr.Multiaddr) []multiaddr.Multiaddr {
		var addrs []multiaddr.Multiaddr
		if len(annAddrs) > 0 {
			addrs = annAddrs
		} else {
			addrs = allAddrs
		}
		addrs = append(addrs, appendAnnAddrs...)
		var out []multiaddr.Multiaddr
		for _, maddr := range addrs {
			// check for exact matches
			ok := noAnnAddrs[string(maddr.Bytes())]
			// check for /ipcidr matches
			if !ok && !filters.AddrBlocked(maddr) {
				out = append(out, maddr)
			}
		}
		return out
	}
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package libp2p
import (
"errors"
"sync"
"github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/protocol"
"gitlab.com/nunet/device-management-service/types"
)
// StreamHandler is a function type that processes data from a stream.
type StreamHandler func(stream network.Stream)

// HandlerRegistry manages the registration of stream handlers for different protocols.
type HandlerRegistry struct {
	host          host.Host
	handlers      map[protocol.ID]StreamHandler                     // raw stream handlers, keyed by protocol
	bytesHandlers map[protocol.ID]func(data []byte, peerId peer.ID) // byte-callback handlers, keyed by protocol
	mu            sync.RWMutex                                      // guards both handler maps
}
// NewHandlerRegistry creates a new handler registry instance bound to the
// given libp2p host, with empty handler maps ready for registration.
func NewHandlerRegistry(h host.Host) *HandlerRegistry {
	registry := new(HandlerRegistry)
	registry.host = h
	registry.handlers = make(map[protocol.ID]StreamHandler)
	registry.bytesHandlers = make(map[protocol.ID]func(data []byte, peerId peer.ID))
	return registry
}
// RegisterHandlerWithStreamCallback registers a stream handler for a specific
// protocol. It returns ErrStreamRegistered when a handler already exists for
// that protocol.
func (r *HandlerRegistry) RegisterHandlerWithStreamCallback(messageType types.MessageType, handler StreamHandler) error {
	r.mu.Lock()
	defer r.mu.Unlock()
	protoID := protocol.ID(messageType)
	if _, exists := r.handlers[protoID]; exists {
		return ErrStreamRegistered
	}
	r.handlers[protoID] = handler
	r.host.SetStreamHandler(protoID, network.StreamHandler(handler))
	return nil
}
// RegisterHandlerWithBytesCallback registers a stream handler for a specific protocol and sends the bytes back to callback.
// s is the stream-level handler installed on the host; handler receives the
// decoded payload bytes together with the sending peer's ID.
func (r *HandlerRegistry) RegisterHandlerWithBytesCallback(
	messageType types.MessageType,
	s StreamHandler, handler func(data []byte, peerId peer.ID),
) error {
	r.mu.Lock()
	defer r.mu.Unlock()
	protoID := protocol.ID(messageType)
	if _, ok := r.bytesHandlers[protoID]; ok {
		// Return the shared sentinel (instead of a fresh errors.New) so callers
		// can detect duplicates with errors.Is, consistent with
		// RegisterHandlerWithStreamCallback.
		return ErrStreamRegistered
	}
	r.bytesHandlers[protoID] = handler
	r.host.SetStreamHandler(protoID, network.StreamHandler(s))
	return nil
}
// SendMessageToLocalHandler given the message type it sends data to the local handler found.
// It is a silent no-op when no bytes handler is registered for the type.
// The handler runs on its own goroutine, so this never blocks the caller.
func (r *HandlerRegistry) SendMessageToLocalHandler(messageType types.MessageType, data []byte, peerID peer.ID) {
	r.mu.RLock()
	defer r.mu.RUnlock()
	protoID := protocol.ID(messageType)
	h, ok := r.bytesHandlers[protoID]
	if !ok {
		return
	}
	// we need this goroutine to avoid blocking the caller goroutine
	go h(data, peerID)
}
// UnregisterHandler unregisters a stream handler for a specific protocol.
// Removing a protocol that was never registered is a harmless no-op.
func (r *HandlerRegistry) UnregisterHandler(messageType types.MessageType) {
	protoID := protocol.ID(messageType)
	r.mu.Lock()
	defer r.mu.Unlock()
	delete(r.handlers, protoID)
	delete(r.bytesHandlers, protoID)
	r.host.RemoveStreamHandler(protoID)
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package libp2p
import (
"context"
"fmt"
"net"
"strconv"
"strings"
"time"
"github.com/quic-go/quic-go"
"github.com/libp2p/go-libp2p"
dht "github.com/libp2p/go-libp2p-kad-dht"
pubsub "github.com/libp2p/go-libp2p-pubsub"
"github.com/libp2p/go-libp2p/core/event"
"github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/protocol"
"github.com/libp2p/go-libp2p/core/routing"
"github.com/libp2p/go-libp2p/p2p/host/autorelay"
"github.com/libp2p/go-libp2p/p2p/host/peerstore/pstoremem"
rcmgr "github.com/libp2p/go-libp2p/p2p/host/resource-manager"
"github.com/libp2p/go-libp2p/p2p/net/connmgr"
"github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay"
"github.com/libp2p/go-libp2p/p2p/protocol/holepunch"
"github.com/libp2p/go-libp2p/p2p/security/noise"
libp2ptls "github.com/libp2p/go-libp2p/p2p/security/tls"
"github.com/libp2p/go-libp2p/p2p/transport/quicreuse"
"github.com/libp2p/go-libp2p/p2p/transport/tcp"
ws "github.com/libp2p/go-libp2p/p2p/transport/websocket"
webtransport "github.com/libp2p/go-libp2p/p2p/transport/webtransport"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
mafilt "github.com/whyrusleeping/multiaddr-filter"
libp2pquic "github.com/libp2p/go-libp2p/p2p/transport/quic"
"gitlab.com/nunet/device-management-service/observability"
"gitlab.com/nunet/device-management-service/types"
)
// NewHost returns a new libp2p host with dht and other related settings.
// It wires up the connection manager, resource manager, address filters,
// QUIC transport reuse over a shared UDP socket, relay/hole-punching options
// and gossipsub with peer scoring. On success it returns the host, its DHT,
// the pubsub instance, the raw UDP listener and the raw QUIC transport.
func NewHost(ctx context.Context, config *types.Libp2pConfig, appScore func(p peer.ID) float64, scoreInspect pubsub.ExtendedPeerScoreInspectFn) (host.Host, *dht.IpfsDHT, *pubsub.PubSub, *net.UDPConn, *RawQUICTransport, error) {
	newPeer := make(chan peer.AddrInfo)
	var idht *dht.IpfsDHT
	connmgr, err := connmgr.NewConnManager(
		100,
		400,
		connmgr.WithGracePeriod(time.Duration(config.GracePeriodMs)*time.Millisecond),
	)
	if err != nil {
		return nil, nil, nil, nil, nil, err
	}
	filter := ma.NewFilters()
	for _, s := range defaultServerFilters {
		f, err := mafilt.NewMask(s)
		if err != nil {
			log.Errorw("incorrectly formatted address filter in config",
				"labels", string(observability.LabelNode),
				"filter", s,
				"error", err,
			)
			// skip the malformed mask: previously *f was dereferenced below
			// even on error, which would panic on a nil pointer
			continue
		}
		filter.AddFilter(*f, ma.ActionDeny)
	}
	ps, err := pstoremem.NewPeerstore()
	if err != nil {
		return nil, nil, nil, nil, nil, err
	}
	var libp2pOpts []libp2p.Option
	dhtOpts := []dht.Option{
		dht.ProtocolPrefix(protocol.ID(config.DHTPrefix)),
		dht.NamespacedValidator(strings.ReplaceAll(config.CustomNamespace, "/", ""), dhtValidator{PS: ps, customNamespace: config.CustomNamespace}),
		dht.Mode(dht.ModeAutoServer),
	}
	// set up the resource manager
	mem := int64(config.Memory)
	if mem > 0 {
		mem = 1024 * 1024 * mem
	} else {
		mem = 1024 * 1024 * 1024 // 1GB
	}
	fds := config.FileDescriptors
	if fds == 0 {
		fds = 512
	}
	limits := rcmgr.DefaultLimits
	limits.SystemBaseLimit.ConnsInbound = 512
	limits.SystemBaseLimit.ConnsOutbound = 512
	limits.SystemBaseLimit.Conns = 1024
	limits.SystemBaseLimit.StreamsInbound = 8192
	limits.SystemBaseLimit.StreamsOutbound = 8192
	limits.SystemBaseLimit.Streams = 16384
	scaled := limits.Scale(mem, fds)
	log.Debugw("libp2p_limits",
		"labels", string(observability.LabelNode),
		"limits", limits,
	)
	mgr, err := rcmgr.NewResourceManager(rcmgr.NewFixedLimiter(scaled))
	if err != nil {
		return nil, nil, nil, nil, nil, err
	}
	// get quic port from listen address
	quicPort := 0
	hasQUICPort := false
	for _, addr := range config.ListenAddress {
		maddr := ma.StringCast(addr)
		maddrComps := ma.Split(maddr)
		for _, comp := range maddrComps {
			if comp.Protocol().Code == ma.P_UDP {
				quicPort, err = strconv.Atoi(comp.Value())
				if err != nil {
					log.Errorf("failed to parse QUIC port from address %s: %v", addr, err)
					return nil, nil, nil, nil, nil, fmt.Errorf("failed to parse QUIC port from address %s: %w", addr, err)
				}
				log.Debugf("QUIC port found in address %s: %d", addr, quicPort)
				hasQUICPort = true
				break
			}
		}
	}
	// quic port must be set
	if !hasQUICPort {
		log.Errorf("QUIC port not found in listen addresses")
		return nil, nil, nil, nil, nil, fmt.Errorf("QUIC port not found in listen addresses")
	}
	udpConn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: quicPort})
	if err != nil {
		return nil, nil, nil, nil, nil, fmt.Errorf("failed to create UDP listener: %w", err)
	}
	rqtr := NewRawQUICTransport(udpConn)
	// newReuse lends our shared UDP socket to libp2p's QUIC connection manager
	// so raw QUIC and libp2p traffic can coexist on one port.
	newReuse := func(statelessResetKey quic.StatelessResetKey, tokenGeneratorKey quic.TokenGeneratorKey) (*quicreuse.ConnManager, error) {
		reuseConnM, err := quicreuse.NewConnManager(
			statelessResetKey,
			tokenGeneratorKey,
		)
		if err != nil {
			return nil, fmt.Errorf("failed to create reuse: %w", err)
		}
		trDone, err := reuseConnM.LendTransport("udp4", rqtr, udpConn)
		if err != nil {
			return nil, fmt.Errorf("failed to add transport to reuse: %w", err)
		}
		go func() {
			// wait for the connection manager to be done to close the raw quic transport
			<-trDone
			log.Info("closing raw quic transport")
			rqtr.Close()
		}()
		return reuseConnM, nil
	}
	libp2pOpts = append(libp2pOpts,
		libp2p.ListenAddrStrings(config.ListenAddress...),
		libp2p.ResourceManager(mgr),
		libp2p.Identity(config.PrivateKey),
		libp2p.Routing(func(h host.Host) (routing.PeerRouting, error) {
			idht, err = dht.New(ctx, h, dhtOpts...)
			return idht, err
		}),
		libp2p.Peerstore(ps),
		libp2p.Security(libp2ptls.ID, libp2ptls.New),
		libp2p.Security(noise.ID, noise.New),
		libp2p.ChainOptions(
			libp2p.Transport(tcp.NewTCPTransport),
			libp2p.Transport(libp2pquic.NewTransport),
			libp2p.Transport(webtransport.New),
			libp2p.Transport(ws.New),
		),
		libp2p.ConnectionManager(connmgr),
		libp2p.EnableNATService(),
		libp2p.EnableAutoNATv2(),
		libp2p.EnableRelay(),
		libp2p.EnableRelayService(
			relay.WithLimit(&relay.RelayLimit{
				Duration: 10 * time.Minute,
				Data:     1 << 22, // 4 MiB
			}),
		),
		libp2p.EnableAutoRelayWithPeerSource(
			// feed up to num freshly identified public peers to autorelay
			func(ctx context.Context, num int) <-chan peer.AddrInfo {
				r := make(chan peer.AddrInfo)
				go func() {
					defer close(r)
					for i := 0; i < num; i++ {
						select {
						case p := <-newPeer:
							select {
							case r <- p:
							case <-ctx.Done():
								return
							}
						case <-ctx.Done():
							return
						}
					}
				}()
				return r
			},
			autorelay.WithBootDelay(30*time.Second),
			autorelay.WithBackoff(30*time.Second),
			autorelay.WithMinCandidates(3),
			autorelay.WithMaxCandidates(6),
			autorelay.WithNumRelays(2),
		),
		libp2p.EnableHolePunching(holepunch.WithAddrFilter(&quicAddrFilter{})),
		// libp2p.EnableHolePunching(),
		libp2p.QUICReuse(newReuse),
	)
	if config.Server {
		libp2pOpts = append(libp2pOpts, libp2p.AddrsFactory(makeAddrsFactory([]string{}, []string{}, defaultServerFilters)))
		libp2pOpts = append(libp2pOpts, libp2p.ConnectionGater((*filtersConnectionGater)(filter)))
	}
	host, err := libp2p.New(libp2pOpts...)
	if err != nil {
		return nil, nil, nil, nil, nil, err
	}
	go watchForNewPeers(ctx, host, newPeer)
	optsPS := []pubsub.Option{
		pubsub.WithFloodPublish(true),
		pubsub.WithMessageSigning(true),
		pubsub.WithPeerScore(
			&pubsub.PeerScoreParams{
				SkipAtomicValidation: true,
				Topics:               make(map[string]*pubsub.TopicScoreParams),
				TopicScoreCap:        10,
				AppSpecificScore:     appScore,
				AppSpecificWeight:    1,
				DecayInterval:        time.Hour,
				DecayToZero:          0.001,
				RetainScore:          6 * time.Hour,
			},
			&pubsub.PeerScoreThresholds{
				GossipThreshold:             -500,
				PublishThreshold:            -1000,
				GraylistThreshold:           -2500,
				AcceptPXThreshold:           0, // TODO for public mainnet we should limit to botostrappers and set them up without a mesh
				OpportunisticGraftThreshold: 2.5,
			},
		),
		pubsub.WithPeerExchange(true),
		pubsub.WithPeerScoreInspect(scoreInspect, time.Second),
		pubsub.WithStrictSignatureVerification(true),
	}
	if config.GossipMaxMessageSize > 0 {
		optsPS = append(optsPS, pubsub.WithMaxMessageSize(config.GossipMaxMessageSize))
	}
	gossip, err := pubsub.NewGossipSub(ctx, host, optsPS...)
	// gossip, err := pubsub.NewGossipSubWithRouter(ctx, host, pubsub.DefaultGossipSubRouter(host), optsPS...)
	if err != nil {
		return nil, nil, nil, nil, nil, err
	}
	return host, idht, gossip, udpConn, rqtr, nil
}
// watchForNewPeers subscribes to peer-identification events on the host's
// event bus and, for each completed identification, forwards the peer's
// address info to the newPeer channel via handleNewPeers. It runs until ctx
// is cancelled or the subscription fails.
func watchForNewPeers(ctx context.Context, host host.Host, newPeer chan peer.AddrInfo) {
	sub, err := host.EventBus().Subscribe([]interface{}{
		&event.EvtPeerIdentificationCompleted{},
	})
	if err != nil {
		log.Errorw("failed to subscribe to peer identification events",
			"labels", string(observability.LabelNode),
			"error", err,
		)
		return
	}
	defer sub.Close()
	for ctx.Err() == nil {
		var ev any
		select {
		case <-ctx.Done():
			return
		case ev = <-sub.Out():
		}
		if ev, ok := ev.(event.EvtPeerIdentificationCompleted); ok {
			identPeer := peer.AddrInfo{
				ID:    ev.Peer,
				Addrs: ev.ListenAddrs,
			}
			// forwarded on its own goroutine so a slow/unread newPeer channel
			// does not stall this event loop
			go handleNewPeers(ctx, identPeer, newPeer)
		}
	}
}
// handleNewPeers forwards a freshly identified peer to the newPeer channel,
// keeping only its public addresses; peers without any public address are
// dropped. The send is guarded by ctx so the goroutine cannot leak: the
// original code only checked ctx before the (unbuffered, possibly never-read)
// send, which could block forever once autorelay stopped consuming peers.
func handleNewPeers(ctx context.Context, identifiedPeer peer.AddrInfo, newPeer chan peer.AddrInfo) {
	var publicAddrs []ma.Multiaddr
	for _, addr := range identifiedPeer.Addrs {
		if manet.IsPublicAddr(addr) {
			publicAddrs = append(publicAddrs, addr)
		}
	}
	if len(publicAddrs) == 0 {
		return
	}
	select {
	case newPeer <- peer.AddrInfo{ID: identifiedPeer.ID, Addrs: publicAddrs}:
	case <-ctx.Done():
	}
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package libp2p
import (
"bufio"
"bytes"
"context"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/rsa"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"net/netip"
"slices"
"strings"
"sync"
"time"
"github.com/quic-go/quic-go/http3"
"github.com/quic-go/quic-go"
crypto "gitlab.com/nunet/device-management-service/lib/crypto"
ic "github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/network"
"gitlab.com/nunet/device-management-service/utils/sys"
cid "github.com/ipfs/go-cid"
dht "github.com/libp2p/go-libp2p-kad-dht"
pubsub "github.com/libp2p/go-libp2p-pubsub"
libp2pdiscovery "github.com/libp2p/go-libp2p/core/discovery"
"github.com/libp2p/go-libp2p/core/event"
"github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/peerstore"
"github.com/libp2p/go-libp2p/core/protocol"
drouting "github.com/libp2p/go-libp2p/p2p/discovery/routing"
"github.com/libp2p/go-libp2p/p2p/protocol/ping"
multiaddr "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
multihash "github.com/multiformats/go-multihash"
msmux "github.com/multiformats/go-multistream"
"github.com/spf13/afero"
"google.golang.org/protobuf/proto"
bt "gitlab.com/nunet/device-management-service/internal/background_tasks"
"gitlab.com/nunet/device-management-service/internal/config"
"gitlab.com/nunet/device-management-service/lib/did"
"gitlab.com/nunet/device-management-service/observability"
commonproto "gitlab.com/nunet/device-management-service/proto/generated/v1/common"
"gitlab.com/nunet/device-management-service/types"
)
const (
	// MB is one mebibyte in bytes.
	MB = 1024 * 1024
	// maxMessageLengthMB caps stream messages at this many MB.
	maxMessageLengthMB = 1
	// Re-exported pubsub validation outcomes so callers need not import pubsub.
	ValidationAccept = pubsub.ValidationAccept
	ValidationReject = pubsub.ValidationReject
	ValidationIgnore = pubsub.ValidationIgnore
	// readTimeout bounds how long we wait for data on an inbound stream.
	readTimeout = 30 * time.Second
	// sendSemaphoreLimit bounds the number of concurrent asynchronous sends.
	sendSemaphoreLimit = 4096
)
// Type aliases for commonly used libp2p/pubsub types so that consumers of
// this package do not need to import the upstream packages directly.
type (
	PeerID            = peer.ID
	ProtocolID        = protocol.ID
	Topic             = pubsub.Topic
	PubSub            = pubsub.PubSub
	ValidationResult  = pubsub.ValidationResult
	Validator         func([]byte, interface{}) (ValidationResult, interface{})
	PeerScoreSnapshot = pubsub.PeerScoreSnapshot
)
// Libp2p contains the configuration for a Libp2p instance.
//
// TODO-suggestion: maybe we should call it something else like Libp2pPeer,
// Libp2pHost or just Peer (callers would use libp2p.Peer...)
type Libp2p struct {
	Host   host.Host
	DHT    *dht.IpfsDHT
	PS     peerstore.Peerstore
	pubsub *PubSub

	// lifecycle context created in Init and cancelled on shutdown
	ctx    context.Context
	cancel func()

	mx sync.Mutex

	pubsubAppScore func(peer.ID) float64
	pubsubScore    map[peer.ID]*PeerScoreSnapshot

	// pubsub topic bookkeeping — presumably guarded by topicMux; confirm at use sites
	topicMux          sync.RWMutex
	pubsubTopics      map[string]*Topic
	topicValidators   map[string]map[uint64]Validator
	topicSubscription map[string]map[uint64]*pubsub.Subscription
	nextTopicSubID    uint64

	// send backpressure semaphore
	sendSemaphore chan struct{}

	// a list of peers discovered by discovery
	discoveredPeers []peer.AddrInfo
	discovery       libp2pdiscovery.Discovery

	// channel to signal when a public IP address has been confirmed
	observeAddrMx    sync.RWMutex
	observedAddrCond sync.Cond
	observedAddr     multiaddr.Multiaddr

	// services
	pingService *ping.PingService

	// tasks
	discoveryTask           *bt.Task
	advertiseRendezvousTask *bt.Task

	handlerRegistry *HandlerRegistry

	config *types.Libp2pConfig

	// dependencies (db, filesystem...)
	fs afero.Fs

	subnetsmx sync.Mutex
	subnets   map[string]*subnet

	isHTTPServerRegistered int32

	// for ip proxying in subnets
	ipproxy          *http3.Server
	ipproxyCtx       context.Context
	ipproxyCtxCancel func()
	ipproxyConns     map[string]*quic.Conn
	ipproxyConnsMx   sync.Mutex

	// shared UDP listener and raw QUIC transport created by NewHost
	udpln  *net.UDPConn
	rawqtr *RawQUICTransport

	NetIfaceFactory NetInterfaceFactory // Injected factory for creating NetInterface (for testing/mocking)
}
// This results in a cyclic dependency error
// var _ dmsNetwork.Network = (*Libp2p)(nil)
// TODO: remove this once we move the network types and interfaces to the types package
// New creates a libp2p instance.
//
// When config.NetIfaceFactory is provided it is wrapped so its result is
// safely checked against sys.NetInterface (the previous code used an
// unchecked type assertion that would panic on a misbehaving factory);
// otherwise a default TUN-interface factory is used.
//
// TODO-Suggestion: move types.Libp2pConfig to here for better readability.
// Unless there is a reason to keep within types.
func New(config *types.Libp2pConfig, fs afero.Fs) (*Libp2p, error) {
	if config == nil {
		return nil, errors.New("config is nil")
	}
	if config.Scheduler == nil {
		return nil, errors.New("scheduler is nil")
	}
	var netIfaceFactory NetInterfaceFactory
	if config.NetIfaceFactory != nil {
		netIfaceFactory = func(name string) (sys.NetInterface, error) {
			iface, err := config.NetIfaceFactory(name)
			if err != nil {
				return nil, err
			}
			netIface, ok := iface.(sys.NetInterface)
			if !ok {
				return nil, fmt.Errorf("net interface factory returned %T, which is not a sys.NetInterface", iface)
			}
			return netIface, nil
		}
	} else {
		// the redundant `else if config.NetIfaceFactory == nil` was collapsed:
		// this branch is exactly the nil case
		netIfaceFactory = func(name string) (sys.NetInterface, error) {
			return sys.NewTunTapInterface(name, sys.NetTunMode, false)
		}
	}
	l := &Libp2p{
		config:            config,
		discoveredPeers:   make([]peer.AddrInfo, 0),
		pubsubTopics:      make(map[string]*pubsub.Topic),
		topicSubscription: make(map[string]map[uint64]*pubsub.Subscription),
		topicValidators:   make(map[string]map[uint64]Validator),
		sendSemaphore:     make(chan struct{}, sendSemaphoreLimit),
		fs:                fs,
		subnets:           make(map[string]*subnet),
		NetIfaceFactory:   netIfaceFactory, // from config or nil
	}
	l.observedAddrCond = *sync.NewCond(&l.observeAddrMx)
	return l, nil
}
// Init initializes a libp2p host with its dependencies: it builds the host,
// DHT and pubsub via NewHost, wires them into the receiver, derives the
// node's DID from its public key, and initializes observability.
// On any failure the lifecycle context is cancelled and the host is closed,
// so a failed Init does not leak a running host (the original code skipped
// this cleanup on the post-NewHost error paths).
func (l *Libp2p) Init(cfg *config.Config) error {
	ctx, cancel := context.WithCancel(context.Background())
	host, dht, pubsub, udpConn, rqtr, err := NewHost(ctx, l.config, l.broadcastAppScore, l.broadcastScoreInspect)
	if err != nil {
		cancel()
		log.Error(err)
		return err
	}
	l.ctx = ctx
	l.cancel = cancel
	l.Host = host
	l.DHT = dht
	l.PS = host.Peerstore()
	l.discovery = drouting.NewRoutingDiscovery(dht)
	l.pubsub = pubsub
	l.handlerRegistry = NewHandlerRegistry(host)
	l.udpln = udpConn
	l.rawqtr = rqtr
	l.rawqtr.network = l
	// Extract the public key from the private key
	publicKey := l.config.PrivateKey.GetPublic()
	// Derive the DID from the public key
	didInstance := did.FromPublicKey(publicKey)
	if didInstance.Empty() {
		cancel()
		_ = host.Close()
		return errors.New("failed to derive a valid DID from public key")
	}
	// Initialize the observability package with the host and DID
	if err := observability.Initialize(l.Host, didInstance, cfg); err != nil {
		cancel()
		_ = host.Close()
		return fmt.Errorf("failed to initialize observability: %w", err)
	}
	return nil
}
// Start performs network bootstrapping, peer discovery and protocols handling.
// It registers stream handlers, connects to bootstrap nodes, bootstraps the
// DHT, starts a random walk and address watchers, kicks off a delayed
// one-shot rendezvous advertisement + discovery goroutine, and schedules the
// periodic discovery/advertisement background tasks.
func (l *Libp2p) Start() error {
	if l.Host == nil {
		return ErrHostNotInitialized
	}
	// set stream handlers
	l.registerStreamHandlers()
	// connect to bootstrap nodes
	err := l.connectToBootstrapNodes(l.ctx)
	if err != nil {
		log.Errorw("libp2p_bootstrap_failure", "labels", string(observability.LabelNode), "error", err)
		return err
	}
	err = l.bootstrapDHT(l.ctx)
	if err != nil {
		log.Errorw("libp2p_bootstrap_failure", "labels", string(observability.LabelNode), "error", err)
		return err
	}
	// Start random walk
	l.startRandomWalk(l.ctx)
	// watch for local address change
	go l.watchForAddrsChange(l.ctx)
	// discover
	go func() {
		// wait for dht bootstrap
		time.Sleep(1 * time.Minute)
		// advertise rendezvous discovery
		// NOTE(review): this goroutine reassigns the outer err variable after
		// Start has returned — consider a goroutine-local err to avoid a race.
		err = l.advertiseForRendezvousDiscovery(l.ctx)
		if err != nil {
			log.Warnw("libp2p_advertise_rendezvous_failure", "labels", string(observability.LabelNode), "error", err)
		} else {
			log.Debugw("libp2p_advertise_rendezvous_success", "labels", string(observability.LabelNode))
		}
		err = l.discoverDialPeers(l.ctx)
		if err != nil {
			log.Warnw("libp2p_peer_discover_failure", "labels", string(observability.LabelNode), "error", err)
		} else {
			log.Debugw("libp2p_peer_discover_success", "labels", string(observability.LabelNode), "foundPeers", len(l.discoveredPeers))
		}
	}()
	// register periodic peer discovery task
	discoveryTask := &bt.Task{
		Name:        "Peer Discovery",
		Description: "Periodic task to discover new peers every 15 minutes",
		Function: func(_ interface{}) error {
			return l.discoverDialPeers(l.ctx)
		},
		Triggers: []bt.Trigger{&bt.PeriodicTrigger{Interval: 15 * time.Minute}},
	}
	l.discoveryTask = l.config.Scheduler.AddTask(discoveryTask)
	// register rendezvous advertisement task
	advertiseRendezvousTask := &bt.Task{
		Name:        "Rendezvous advertisement",
		Description: "Periodic task to advertise a rendezvous point every 6 hours",
		Function: func(_ interface{}) error {
			return l.advertiseForRendezvousDiscovery(l.ctx)
		},
		Triggers: []bt.Trigger{&bt.PeriodicTrigger{Interval: 6 * time.Hour}},
	}
	l.advertiseRendezvousTask = l.config.Scheduler.AddTask(advertiseRendezvousTask)
	l.config.Scheduler.Start()
	go l.watchForObservedAddr()
	return nil
}
// RegisterStreamMessageHandler registers a stream handler for a specific
// protocol. An empty message type is rejected up front.
func (l *Libp2p) RegisterStreamMessageHandler(messageType types.MessageType, handler StreamHandler) error {
	if messageType == "" {
		return errors.New("message type is empty")
	}
	err := l.handlerRegistry.RegisterHandlerWithStreamCallback(messageType, handler)
	if err != nil {
		return fmt.Errorf("failed to register handler %s: %w", messageType, err)
	}
	return nil
}
// RegisterBytesMessageHandler registers a stream handler for a specific
// protocol that delivers the decoded message bytes to handler. The actual
// stream reading is done by handleReadBytesFromStream.
func (l *Libp2p) RegisterBytesMessageHandler(messageType types.MessageType, handler func(data []byte, peerId peer.ID)) error {
	if messageType == "" {
		return errors.New("message type is empty")
	}
	err := l.handlerRegistry.RegisterHandlerWithBytesCallback(messageType, l.handleReadBytesFromStream, handler)
	if err != nil {
		return fmt.Errorf("failed to register handler %s: %w", messageType, err)
	}
	return nil
}
// HandleMessage registers a byte-level handler for the given message type.
// It is a convenience wrapper around RegisterBytesMessageHandler.
func (l *Libp2p) HandleMessage(messageType string, handler func(data []byte, peerId peer.ID)) error {
	mt := types.MessageType(messageType)
	return l.RegisterBytesMessageHandler(mt, handler)
}
// handleReadBytesFromStream reads one length-prefixed message from the stream
// and hands the payload to the bytes handler registered for the stream's
// protocol. Messages carry an 8-byte little-endian length prefix and are
// capped at maxMessageLengthMB. The stream is reset on any error and closed
// exactly once (the original had both a defer Close and an explicit Close).
func (l *Libp2p) handleReadBytesFromStream(s network.Stream) {
	l.handlerRegistry.mu.RLock()
	callback, ok := l.handlerRegistry.bytesHandlers[s.Protocol()]
	l.handlerRegistry.mu.RUnlock()
	if !ok {
		_ = s.Reset()
		return
	}
	if err := s.SetReadDeadline(time.Now().Add(readTimeout)); err != nil {
		_ = s.Reset()
		log.Warnf("error setting read deadline: %s", err)
		return
	}
	defer s.Close()
	c := bufio.NewReader(s)
	// read exactly 8 bytes of length prefix; io.ReadFull guards against the
	// short reads a single bufio.Reader.Read call is allowed to return
	msgLengthBuffer := make([]byte, 8)
	if _, err := io.ReadFull(c, msgLengthBuffer); err != nil {
		log.Debugf("error reading message length: %s", err)
		_ = s.Reset()
		return
	}
	lengthPrefix := binary.LittleEndian.Uint64(msgLengthBuffer)
	// check if the message length is greater than max allowed
	if lengthPrefix > maxMessageLengthMB*MB {
		_ = s.Reset()
		log.Warnf("message length exceeds maximum: %d", lengthPrefix)
		return
	}
	// create a buffer with the size of the message and then read until it's full
	buf := make([]byte, lengthPrefix)
	if _, err := io.ReadFull(c, buf); err != nil {
		log.Debugf("error reading message: %s", err)
		_ = s.Reset()
		return
	}
	callback(buf, s.Conn().RemotePeer())
}
// UnregisterMessageHandler unregisters a stream handler for a specific protocol.
func (l *Libp2p) UnregisterMessageHandler(messageType string) {
	mt := types.MessageType(messageType)
	l.handlerRegistry.UnregisterHandler(mt)
}
// SendMessage asynchronously sends a message to a peer.
// Local deliveries (hostID == our own peer ID) are short-circuited to the
// registered local handler. Remote sends acquire a slot on the backpressure
// semaphore and then run on their own goroutine; if no slot frees up before
// the expiry-derived deadline, the context error is returned instead.
func (l *Libp2p) SendMessage(ctx context.Context, hostID string, msg types.MessageEnvelope, expiry time.Time) error {
	pid, err := peer.Decode(hostID)
	if err != nil {
		return fmt.Errorf("send: invalid peer ID: %w", err)
	}
	// we are delivering a message to ourself
	// we should use the handler to send the message to the handler directly which has been previously registered.
	if pid == l.Host.ID() {
		l.handlerRegistry.SendMessageToLocalHandler(msg.Type, msg.Data, pid)
		return nil
	}
	// a past expiry produces an already-expired context, so the send fails fast
	ctx, cancel := context.WithTimeout(ctx, time.Until(expiry))
	select {
	case l.sendSemaphore <- struct{}{}:
		go func() {
			defer cancel()
			// release the semaphore slot when the send finishes
			defer func() { <-l.sendSemaphore }()
			l.sendMessage(ctx, pid, msg, expiry, nil)
		}()
		return nil
	case <-ctx.Done():
		cancel()
		return ctx.Err()
	}
}
// SendMessageSync synchronously sends a message to a peer, blocking until the
// send completes, fails, or the expiry deadline passes.
func (l *Libp2p) SendMessageSync(ctx context.Context, hostID string, msg types.MessageEnvelope, expiry time.Time) error {
	pid, err := peer.Decode(hostID)
	if err != nil {
		return fmt.Errorf("send: invalid peer ID: %w", err)
	}
	// deliver locally when the target is ourselves
	if pid == l.Host.ID() {
		l.handlerRegistry.SendMessageToLocalHandler(msg.Type, msg.Data, pid)
		return nil
	}
	sendCtx, cancel := context.WithTimeout(ctx, time.Until(expiry))
	defer cancel()
	result := make(chan error, 1)
	l.sendMessage(sendCtx, pid, msg, expiry, result)
	return <-result
}
// newStream opens a stream to an already-connected peer and negotiates the
// given protocol on it manually.
// workaround for https://github.com/libp2p/go-libp2p/issues/2983
func (l *Libp2p) newStream(ctx context.Context, pid peer.ID, proto protocol.ID) (network.Stream, error) {
	stream, err := l.Host.Network().NewStream(network.WithNoDial(ctx, "already dialed"), pid)
	if err != nil {
		return nil, err
	}
	negotiated, err := msmux.SelectOneOf([]protocol.ID{proto}, stream)
	if err == nil {
		err = stream.SetProtocol(negotiated)
	}
	if err != nil {
		_ = stream.Reset()
		return nil, err
	}
	return stream, nil
}
// sendMessage delivers msg to pid over a fresh stream, connecting first if
// necessary. The wire format is an 8-byte little-endian length prefix followed
// by the payload. If result is non-nil the final error (nil on success) is sent
// on it exactly once.
func (l *Libp2p) sendMessage(ctx context.Context, pid peer.ID, msg types.MessageEnvelope, expiry time.Time, result chan error) {
	var err error
	defer func() {
		if result != nil {
			result <- err
		}
	}()

	// Establish a connection before attempting to open a stream.
	if !l.PeerConnected(pid) {
		var ai peer.AddrInfo
		ai, err = l.resolvePeerAddress(ctx, pid)
		if err != nil {
			log.Warnf("send: error resolving addresses for peer %s: %s", pid, err)
			return
		}
		if err = l.Host.Connect(ctx, ai); err != nil {
			log.Warnf("send: failed to connect to peer %s: %s", pid, err)
			return
		}
	}

	requestBufferSize := 8 + len(msg.Data)
	if requestBufferSize > maxMessageLengthMB*MB {
		log.Warnf("send: message size %d is greater than limit %d bytes", requestBufferSize, maxMessageLengthMB*MB)
		err = errors.New("message too large")
		return
	}

	ctx = network.WithAllowLimitedConn(ctx, "send message")
	stream, err := l.newStream(ctx, pid, protocol.ID(msg.Type))
	if err != nil {
		log.Warnf("send: failed to open stream to peer %s: %s", pid, err)
		return
	}
	defer stream.Close()

	if err = stream.SetWriteDeadline(expiry); err != nil {
		_ = stream.Reset()
		log.Warnf("send: failed to set write deadline to peer %s: %s", pid, err)
		return
	}

	requestPayloadWithLength := make([]byte, requestBufferSize)
	binary.LittleEndian.PutUint64(requestPayloadWithLength, uint64(len(msg.Data)))
	copy(requestPayloadWithLength[8:], msg.Data)

	if _, err = stream.Write(requestPayloadWithLength); err != nil {
		_ = stream.Reset()
		log.Warnf("send: failed to send message to peer %s: %s", pid, err)
		// BUGFIX: previously execution fell through to CloseWrite here, which
		// both masked the write error and operated on an already-reset stream.
		return
	}
	if err = stream.CloseWrite(); err != nil {
		_ = stream.Reset()
		log.Warnf("send: failed to flush output to peer %s: %s", pid, err)
		return
	}

	log.Debugf("send %d bytes to peer %s", len(requestPayloadWithLength), pid)
}
// OpenStream dials the given multiaddress and opens a stream for the requested
// message type, returning the stream for the caller to manage.
func (l *Libp2p) OpenStream(ctx context.Context, addr string, messageType types.MessageType) (network.Stream, error) {
	maddr, err := multiaddr.NewMultiaddr(addr)
	if err != nil {
		return nil, fmt.Errorf("invalid multiaddress: %w", err)
	}

	info, err := peer.AddrInfoFromP2pAddr(maddr)
	if err != nil {
		return nil, fmt.Errorf("could not resolve peer info: %w", err)
	}

	if err := l.Host.Connect(ctx, *info); err != nil {
		return nil, fmt.Errorf("failed to connect to peer: %w", err)
	}

	s, err := l.Host.NewStream(ctx, info.ID, protocol.ID(messageType))
	if err != nil {
		return nil, fmt.Errorf("failed to open stream: %w", err)
	}
	return s, nil
}
// GetMultiaddr returns this host's addresses as full p2p multiaddrs
// (each including the /p2p/<peer-id> suffix).
func (l *Libp2p) GetMultiaddr() ([]multiaddr.Multiaddr, error) {
	info := peer.AddrInfo{ID: l.Host.ID(), Addrs: l.Host.Addrs()}
	return peer.AddrInfoToP2pAddrs(&info)
}
// Stop performs a cleanup of any resources used in this package.
// Shutdown proceeds through every step even when earlier ones fail; the
// collected error messages are joined into a single error.
func (l *Libp2p) Stop() error {
	var problems []string

	// Stop background goroutines first.
	if l.cancel != nil {
		l.cancel()
	}

	// Unschedule periodic tasks, if a scheduler was configured.
	if l.config != nil && l.config.Scheduler != nil {
		if l.discoveryTask != nil {
			l.config.Scheduler.RemoveTask(l.discoveryTask.ID)
		}
		if l.advertiseRendezvousTask != nil {
			l.config.Scheduler.RemoveTask(l.advertiseRendezvousTask.ID)
		}
	}

	if l.DHT != nil {
		if err := l.DHT.Close(); err != nil {
			problems = append(problems, err.Error())
		}
	}

	if l.Host != nil {
		if err := l.Host.Close(); err != nil {
			problems = append(problems, err.Error())
		}
	}

	// Tear down every subnet (ranging over a nil map is a no-op).
	for subnetID := range l.subnets {
		if err := l.DestroySubnet(subnetID); err != nil {
			problems = append(problems, err.Error())
		}
	}

	// Close the raw UDP listener; an "already closed" error is expected and ignored.
	if l.udpln != nil {
		if err := l.udpln.Close(); err != nil && !strings.Contains(err.Error(), "use of closed network connection") {
			problems = append(problems, err.Error())
		}
	}

	if len(problems) > 0 {
		return errors.New(strings.Join(problems, "; "))
	}
	return nil
}
// Stat returns the status about the libp2p network: the host ID and a
// comma-separated list of its listen addresses.
func (l *Libp2p) Stat() types.NetworkStats {
	addrs := l.Host.Addrs()
	listen := make([]string, 0, len(addrs))
	for _, a := range addrs {
		listen = append(listen, a.String())
	}
	return types.NetworkStats{
		ID:         l.Host.ID().String(),
		ListenAddr: strings.Join(listen, ", "),
	}
}
// Peers returns a list of peers from the peer store.
// Note: this is every known peer, not just currently-connected ones.
func (l *Libp2p) Peers() []peer.ID {
	return l.PS.Peers()
}
// Connect dials a peer by its multiaddress and returns an error if the
// address is invalid, unresolvable, or the connection attempt fails.
func (l *Libp2p) Connect(ctx context.Context, peerMultiAddr string) error {
	if peerMultiAddr == "" {
		return fmt.Errorf("peer multiaddress is empty")
	}

	log.Infow("Creating multiaddress from peerMultiAddr", "labels", string(observability.LabelNode),
		"addr", peerMultiAddr)
	maddr, err := multiaddr.NewMultiaddr(peerMultiAddr)
	if err != nil {
		log.Errorw("Invalid multiaddress", "labels", string(observability.LabelNode), "error", err)
		return fmt.Errorf("invalid multiaddress: %w", err)
	}

	log.Infow("Resolving peer info from multiaddress", "labels", string(observability.LabelNode))
	info, err := peer.AddrInfoFromP2pAddr(maddr)
	if err != nil {
		log.Infow("Could not resolve peer info", "labels", string(observability.LabelNode), "error", err)
		return fmt.Errorf("could not resolve peer info: %w", err)
	}

	log.Infow("Connecting to peer", "labels", string(observability.LabelNode), "addr", peerMultiAddr)
	if err := l.Host.Connect(ctx, *info); err != nil {
		log.Errorw("Failed to connect to peer", "labels", string(observability.LabelNode),
			"addr", peerMultiAddr, "error", err)
		return fmt.Errorf("failed to connect to peer %s: %w", peerMultiAddr, err)
	}
	return nil
}
// GetPeerIP returns the first IPv4/IPv6 address component found in the
// peerstore entries for p, or the empty string when none is recorded.
func (l *Libp2p) GetPeerIP(p PeerID) string {
	for _, addr := range l.Host.Peerstore().Addrs(p) {
		parts := strings.Split(addr.String(), "/")
		for i, part := range parts {
			// BUGFIX: guard i+1 so a malformed address ending in "ip4"/"ip6"
			// with no following component cannot cause an index-out-of-range panic.
			if (part == "ip4" || part == "ip6") && i+1 < len(parts) {
				return parts[i+1]
			}
		}
	}
	return ""
}
// watchForObservedAddr listens for identify events and records this host's
// publicly observed address once three distinct identify exchanges agree on
// the same public IP. It broadcasts on observedAddrCond when done.
// Note: the observeAddrMx write lock is held for the whole loop, matching the
// original behavior (waiters in Cond.Wait release it while waiting).
func (l *Libp2p) watchForObservedAddr() {
	log.Infof("watching for observed address")
	l.observeAddrMx.Lock()
	defer l.observeAddrMx.Unlock()

	sub, err := l.Host.EventBus().Subscribe(new(event.EvtPeerIdentificationCompleted))
	if err != nil {
		// BUGFIX: %w is only meaningful inside fmt.Errorf; use %v for logging.
		log.Errorf("could not subscribe to event: %v", err)
		return
	}
	defer sub.Close()

	// Count observations per IP. Only this goroutine touches the map, so the
	// previously-used local mutex was unnecessary and has been removed.
	addrCount := make(map[string]int)
	for e := range sub.Out() {
		// Renamed from "event" to avoid shadowing the event package.
		evt := e.(event.EvtPeerIdentificationCompleted)
		if evt.ObservedAddr.String() == "" {
			continue
		}

		// Only trust observations reported by publicly reachable peers.
		isPeerPublic := slices.ContainsFunc(evt.ListenAddrs, manet.IsPublicAddr)
		if !isPeerPublic {
			continue
		}
		if !manet.IsPublicAddr(evt.ObservedAddr) {
			continue
		}

		// Skip relayed (circuit) addresses — they are not our real public address.
		addrStr := evt.ObservedAddr.String()
		if strings.Contains(addrStr, "p2p-circuit") {
			continue
		}

		ip, err := manet.ToIP(evt.ObservedAddr)
		if err != nil {
			continue
		}

		addrCount[ip.String()]++
		count := addrCount[ip.String()]
		log.Infof("got public ip: %s (seen %d times)", ip.String(), count)

		// Three independent confirmations: accept the address and wake waiters.
		if count >= 3 {
			l.observedAddr = evt.ObservedAddr
			l.observedAddrCond.Broadcast()
			return
		}
	}
}
// GetHostID returns the host ID (this node's libp2p peer ID).
func (l *Libp2p) GetHostID() PeerID {
	return l.Host.ID()
}
// GetPeerPubKey returns the public key for the given peerID from the
// peerstore; may be nil if the key is unknown.
func (l *Libp2p) GetPeerPubKey(peerID PeerID) crypto.PubKey {
	return l.Host.Peerstore().PubKey(peerID)
}
// Ping the remote address. The remote address is the encoded peer id which will be decoded and used here.
//
// TODO (Return error once): something that was confusing me when using this method is that the error is
// returned twice if any. Once as a field of PingResult and one as a return value.
func (l *Libp2p) Ping(ctx context.Context, peerIDAddress string, timeout time.Duration) (types.PingResult, error) {
	// Refuse to dial ourselves.
	if peerIDAddress == l.Host.ID().String() {
		selfErr := errors.New("can't ping self")
		return types.PingResult{Success: false, Error: selfErr}, selfErr
	}

	pingCtx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()

	remotePeer, err := peer.Decode(peerIDAddress)
	if err != nil {
		return types.PingResult{}, err
	}

	// Establish a connection first if one does not already exist.
	if l.Host.Network().Connectedness(remotePeer) != network.Connected {
		if connErr := l.Connect(pingCtx, fmt.Sprintf("/p2p/%s", peerIDAddress)); connErr != nil {
			return types.PingResult{Success: false, Error: connErr}, connErr
		}
	}

	select {
	case res := <-ping.Ping(pingCtx, l.Host, remotePeer):
		if res.Error != nil {
			log.Errorf("failed to ping peer %s: %v", peerIDAddress, res.Error)
			return types.PingResult{Success: false, RTT: res.RTT, Error: res.Error}, res.Error
		}
		return types.PingResult{RTT: res.RTT, Success: true}, nil
	case <-pingCtx.Done():
		return types.PingResult{Error: pingCtx.Err()}, pingCtx.Err()
	}
}
// ResolveAddress resolves a peer ID into a list of full p2p multiaddress
// strings of the form "<addr>/p2p/<id>".
func (l *Libp2p) ResolveAddress(ctx context.Context, id string) ([]string, error) {
	info, err := l.resolveAddress(ctx, id)
	if err != nil {
		return nil, err
	}

	out := make([]string, 0, len(info.Addrs))
	for _, a := range info.Addrs {
		out = append(out, fmt.Sprintf("%s/p2p/%s", a, id))
	}
	return out, nil
}
// resolveAddress decodes the textual peer ID and delegates to resolvePeerAddress.
func (l *Libp2p) resolveAddress(ctx context.Context, id string) (peer.AddrInfo, error) {
	pid, decodeErr := peer.Decode(id)
	if decodeErr != nil {
		return peer.AddrInfo{}, fmt.Errorf("failed to resolve invalid peer: %w", decodeErr)
	}
	return l.resolvePeerAddress(ctx, pid)
}
// resolvePeerAddress finds addresses for pid: our own addresses for self,
// the peerstore for connected peers, and finally a bounded DHT lookup.
func (l *Libp2p) resolvePeerAddress(ctx context.Context, pid peer.ID) (peer.AddrInfo, error) {
	// Self-resolution uses our advertised multiaddrs.
	if pid == l.Host.ID() {
		addrs, err := l.GetMultiaddr()
		if err != nil {
			return peer.AddrInfo{}, fmt.Errorf("failed to resolve self: %w", err)
		}
		return peer.AddrInfo{ID: pid, Addrs: addrs}, nil
	}

	// For connected peers the peerstore already has what we need.
	if l.PeerConnected(pid) {
		return peer.AddrInfo{ID: pid, Addrs: l.Host.Peerstore().Addrs(pid)}, nil
	}

	// Fall back to a DHT lookup, bounded to 3 seconds.
	lookupCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
	defer cancel()
	info, err := l.DHT.FindPeer(lookupCtx, pid)
	if err != nil {
		return peer.AddrInfo{}, fmt.Errorf("failed to resolve address for peer %s: %w", pid, err)
	}
	return info, nil
}
// Query returns all the advertisements in the network related to a key.
// Providers for the key's CID are discovered via the DHT, so peers we are not
// connected to can also be retrieved.
func (l *Libp2p) Query(ctx context.Context, key string) ([]*commonproto.Advertisement, error) {
	if key == "" {
		return nil, errors.New("advertisement key is empty")
	}

	customCID, err := createCIDFromKey(key)
	if err != nil {
		return nil, fmt.Errorf("failed to create cid for key %s: %w", key, err)
	}

	providers, err := l.DHT.FindProviders(ctx, customCID)
	if err != nil {
		return nil, fmt.Errorf("failed to find providers for key %s: %w", key, err)
	}

	advertisements := make([]*commonproto.Advertisement, 0)
	// TODO: use go routines to get the values in parallel.
	for _, provider := range providers {
		// Providers whose value cannot be fetched are skipped silently.
		payload, getErr := l.DHT.GetValue(ctx, l.getCustomNamespace(key, provider.ID.String()))
		if getErr != nil {
			continue
		}
		var ad commonproto.Advertisement
		if err := proto.Unmarshal(payload, &ad); err != nil {
			return nil, fmt.Errorf("failed to unmarshal advertisement payload: %w", err)
		}
		advertisements = append(advertisements, &ad)
	}
	return advertisements, nil
}
// Advertise given data and a key pushes the data to the dht.
// The payload is wrapped in a signed Advertisement envelope, stored under a
// namespaced DHT key, and this node is registered as a provider for the
// key's CID.
func (l *Libp2p) Advertise(ctx context.Context, key string, data []byte) error {
	if key == "" {
		return errors.New("advertisement key is empty")
	}
	pubKeyBytes, err := l.getPublicKey()
	if err != nil {
		return fmt.Errorf("failed to get public key: %w", err)
	}
	envelope := &commonproto.Advertisement{
		PeerId:    l.Host.ID().String(),
		Timestamp: time.Now().Unix(),
		Data:      data,
		PublicKey: pubKeyBytes,
	}
	// Payload covered by the signature: peer ID, timestamp, data, public key.
	// NOTE(review): byte(envelope.Timestamp) truncates the int64 Unix timestamp
	// to a single byte, so the signature only covers the low 8 bits of the
	// timestamp. Verifiers elsewhere must use the identical construction;
	// confirm before changing this encoding.
	concatenatedBytes := bytes.Join([][]byte{
		[]byte(envelope.PeerId),
		{byte(envelope.Timestamp)},
		envelope.Data,
		pubKeyBytes,
	}, nil)
	sig, err := l.sign(concatenatedBytes)
	if err != nil {
		return fmt.Errorf("failed to sign advertisement envelope content: %w", err)
	}
	envelope.Signature = sig
	envelopeBytes, err := proto.Marshal(envelope)
	if err != nil {
		return fmt.Errorf("failed to marshal advertise envelope: %w", err)
	}
	customCID, err := createCIDFromKey(key)
	if err != nil {
		return fmt.Errorf("failed to create cid for key %s: %w", key, err)
	}
	// Store the signed envelope under our namespaced key...
	err = l.DHT.PutValue(ctx, l.getCustomNamespace(key, l.DHT.PeerID().String()), envelopeBytes)
	if err != nil {
		return fmt.Errorf("failed to put key %s into the dht: %w", key, err)
	}
	// ...and announce this node as a provider for the key's CID.
	err = l.DHT.Provide(ctx, customCID, true)
	if err != nil {
		return fmt.Errorf("failed to provide key %s into the dht: %w", key, err)
	}
	log.Infof("advertised key: %s", key)
	return nil
}
// Unadvertise removes previously advertised data for key from the DHT by
// overwriting the namespaced entry with a nil value.
func (l *Libp2p) Unadvertise(ctx context.Context, key string) error {
	namespacedKey := l.getCustomNamespace(key, l.DHT.PeerID().String())
	if err := l.DHT.PutValue(ctx, namespacedKey, nil); err != nil {
		return fmt.Errorf("failed to remove key %s from the DHT: %w", key, err)
	}
	return nil
}
// Publish publishes data to a topic.
// The requirements are that only one topic handler should exist per topic.
func (l *Libp2p) Publish(ctx context.Context, topic string, data []byte) error {
	handler, err := l.getOrJoinTopicHandler(topic)
	if err != nil {
		return fmt.Errorf("failed to publish: %w", err)
	}
	if err := handler.Publish(ctx, data); err != nil {
		return fmt.Errorf("failed to publish to topic %s: %w", topic, err)
	}
	return nil
}
// Subscribe subscribes to a topic and invokes handler for every received
// message. An optional validator is registered for the topic. It returns a
// subscription ID usable with Unsubscribe.
func (l *Libp2p) Subscribe(ctx context.Context, topic string, handler func(data []byte), validator Validator) (uint64, error) {
	topicHandler, err := l.getOrJoinTopicHandler(topic)
	if err != nil {
		return 0, fmt.Errorf("failed to subscribe to topic: %w", err)
	}

	sub, err := topicHandler.Subscribe()
	if err != nil {
		return 0, fmt.Errorf("failed to subscribe to topic %s: %w", topic, err)
	}

	l.topicMux.Lock()
	// BUGFIX: unlock via defer — the error return inside the validator branch
	// previously returned while still holding topicMux, deadlocking all later
	// pubsub operations.
	defer l.topicMux.Unlock()

	subID := l.nextTopicSubID
	l.nextTopicSubID++

	topicMap, ok := l.topicSubscription[topic]
	if !ok {
		topicMap = make(map[uint64]*pubsub.Subscription)
		l.topicSubscription[topic] = topicMap
	}

	if validator != nil {
		validatorMap, ok := l.topicValidators[topic]
		if !ok {
			// First validator for this topic: register the dispatching validator.
			if err := l.pubsub.RegisterTopicValidator(topic, l.validate); err != nil {
				sub.Cancel()
				return 0, fmt.Errorf("failed to register topic validator: %w", err)
			}
			validatorMap = make(map[uint64]Validator)
			l.topicValidators[topic] = validatorMap
		}
		validatorMap[subID] = validator
	}

	topicMap[subID] = sub

	go func() {
		for {
			msg, err := sub.Next(ctx)
			if err != nil {
				// BUGFIX: Next fails permanently once the subscription is
				// cancelled or ctx is done; "continue" here previously spun
				// in a hot loop forever. Exit instead.
				return
			}
			handler(msg.Data)
		}
	}()

	return subID, nil
}
// validate dispatches an incoming pubsub message to every validator registered
// for its topic, returning the first non-accept result.
func (l *Libp2p) validate(_ context.Context, _ peer.ID, msg *pubsub.Message) ValidationResult {
	// BUGFIX: snapshot the validators while holding the read lock. The previous
	// code iterated the live map after releasing topicMux, racing with
	// Unsubscribe which mutates the same map under the write lock.
	l.topicMux.RLock()
	validatorMap, ok := l.topicValidators[msg.GetTopic()]
	var validators []Validator
	if ok {
		validators = make([]Validator, 0, len(validatorMap))
		for _, v := range validatorMap {
			validators = append(validators, v)
		}
	}
	l.topicMux.RUnlock()

	if !ok {
		return ValidationAccept
	}
	for _, validator := range validators {
		result, _ := validator(msg.Data, msg.ValidatorData)
		if result != ValidationAccept {
			return result
		}
	}
	return ValidationAccept
}
// SetupBroadcastTopic runs setup against the already-joined topic handler for
// topic, returning ErrNotSubscribed when the topic has not been joined.
func (l *Libp2p) SetupBroadcastTopic(topic string, setup func(*Topic) error) error {
	// BUGFIX: pubsubTopics is written under topicMux elsewhere (Subscribe/
	// Unsubscribe); read it under the lock too to avoid a data race.
	l.topicMux.RLock()
	t, ok := l.pubsubTopics[topic]
	l.topicMux.RUnlock()
	if !ok {
		return fmt.Errorf("%w: %s", ErrNotSubscribed, topic)
	}
	// Invoke setup outside the lock; it only touches the topic handle.
	return setup(t)
}
// SetBroadcastAppScore installs the application-level peer scoring function
// used by the gossipsub scorer. Safe for concurrent use.
func (l *Libp2p) SetBroadcastAppScore(f func(peer.ID) float64) {
	l.mx.Lock()
	defer l.mx.Unlock()
	l.pubsubAppScore = f
}
// broadcastAppScore returns the application score for p, or 0 when no scoring
// function has been installed via SetBroadcastAppScore.
func (l *Libp2p) broadcastAppScore(p peer.ID) float64 {
	scoreFn := func(peer.ID) float64 { return 0 }

	l.mx.Lock()
	if l.pubsubAppScore != nil {
		scoreFn = l.pubsubAppScore
	}
	l.mx.Unlock()

	return scoreFn(p)
}
// GetBroadcastScore returns the most recent gossipsub peer score snapshot.
// NOTE(review): the map is returned without copying, so callers share it with
// broadcastScoreInspect; treat it as read-only.
func (l *Libp2p) GetBroadcastScore() map[peer.ID]*PeerScoreSnapshot {
	l.mx.Lock()
	defer l.mx.Unlock()
	return l.pubsubScore
}
// broadcastScoreInspect stores the latest peer score snapshot; installed as
// the gossipsub score-inspect callback.
func (l *Libp2p) broadcastScoreInspect(score map[peer.ID]*PeerScoreSnapshot) {
	l.mx.Lock()
	defer l.mx.Unlock()
	l.pubsubScore = score
}
// watchForAddrsChange re-bootstraps the node whenever the host's local
// addresses change, until ctx is cancelled.
func (l *Libp2p) watchForAddrsChange(ctx context.Context) {
	sub, err := l.Host.EventBus().Subscribe(&event.EvtLocalAddressesUpdated{})
	if err != nil {
		log.Errorf("failed to subscribe to event bus: %v", err)
		return
	}
	// BUGFIX: the subscription was never closed, leaking it for the process
	// lifetime once ctx was cancelled.
	defer sub.Close()

	for {
		select {
		case <-ctx.Done():
			return
		case <-sub.Out():
			log.Debug("network address changed. trying to be bootstrap again.")
			// NOTE(review): uses l.ctx (node lifetime) rather than the loop's
			// ctx for bootstrapping, as before — confirm this is intentional.
			if err = l.connectToBootstrapNodes(l.ctx); err != nil {
				log.Errorf("failed to start network: %v", err)
			}
		}
	}
}
// Notify wires caller-supplied callbacks to peer lifecycle events.
//
// preconnected is invoked synchronously for every peer already connected (or
// limited) at call time, with its known protocols and connection count.
// Afterwards a goroutine forwards event-bus notifications until ctx is done:
// connected/disconnected on connectedness changes, identified when identify
// completes, and updated when a peer's protocol set gains entries.
func (l *Libp2p) Notify(
	ctx context.Context,
	preconnected func(peer.ID, []protocol.ID, int),
	connected, disconnected func(peer.ID),
	identified, updated func(peer.ID, []protocol.ID),
) error {
	sub, err := l.Host.EventBus().Subscribe([]interface{}{
		&event.EvtPeerConnectednessChanged{},
		&event.EvtPeerIdentificationCompleted{},
		&event.EvtPeerProtocolsUpdated{},
	})
	if err != nil {
		return fmt.Errorf("failed to subscribe to event bus: %w", err)
	}

	// Report peers that connected before this call, so the caller does not
	// miss them. Limited (relayed) connections count as connected.
	for _, p := range l.Host.Network().Peers() {
		switch l.Host.Network().Connectedness(p) {
		case network.Limited:
			fallthrough
		case network.Connected:
			protos, _ := l.Host.Peerstore().GetProtocols(p)
			preconnected(p, protos, len(l.Host.Network().ConnsToPeer(p)))
		}
	}

	// Forward subsequent events until ctx is cancelled; the subscription is
	// closed when the goroutine exits.
	go func() {
		defer sub.Close()
		for ctx.Err() == nil {
			var ev any
			select {
			case <-ctx.Done():
				return
			case ev = <-sub.Out():
				switch evt := ev.(type) {
				case event.EvtPeerConnectednessChanged:
					switch evt.Connectedness {
					case network.Limited:
						fallthrough
					case network.Connected:
						connected(evt.Peer)
					case network.NotConnected:
						disconnected(evt.Peer)
					}
				case event.EvtPeerIdentificationCompleted:
					identified(evt.Peer, evt.Protocols)
				case event.EvtPeerProtocolsUpdated:
					// Only newly added protocols are reported.
					updated(evt.Peer, evt.Added)
				}
			}
		}
	}()
	return nil
}
// PeerConnected reports whether we have a usable connection to p.
// Both full and limited (e.g. relayed) connectivity count as connected.
func (l *Libp2p) PeerConnected(p PeerID) bool {
	state := l.Host.Network().Connectedness(p)
	return state == network.Limited || state == network.Connected
}
// getOrJoinTopicHandler returns the cached handler for topic, joining the
// topic and caching the handler on first use. Both Publish and Subscribe rely
// on this so each topic is joined exactly once.
func (l *Libp2p) getOrJoinTopicHandler(topic string) (*pubsub.Topic, error) {
	l.topicMux.Lock()
	defer l.topicMux.Unlock()

	if handler, ok := l.pubsubTopics[topic]; ok {
		return handler, nil
	}

	joined, err := l.pubsub.Join(topic)
	if err != nil {
		return nil, fmt.Errorf("failed to join topic %s: %w", topic, err)
	}
	l.pubsubTopics[topic] = joined
	return joined, nil
}
// Unsubscribe cancels the subscription identified by subID on topic.
// When the last subscription for the topic is removed, the topic handler is
// closed and dropped from the cache.
func (l *Libp2p) Unsubscribe(topic string, subID uint64) error {
	l.topicMux.Lock()
	defer l.topicMux.Unlock()

	topicHandler, joined := l.pubsubTopics[topic]
	if !joined {
		return fmt.Errorf("not subscribed to topic: %s", topic)
	}

	// Drop this subscription's validator, if any.
	if validators, ok := l.topicValidators[topic]; ok {
		delete(validators, subID)
	}

	// Cancel and remove the subscription itself.
	subs, ok := l.topicSubscription[topic]
	if ok {
		if sub, found := subs[subID]; found {
			sub.Cancel()
			delete(subs, subID)
		}
	}

	// Last subscriber gone: tear down the topic handler.
	if len(subs) == 0 {
		delete(l.pubsubTopics, topic)
		if err := topicHandler.Close(); err != nil {
			return fmt.Errorf("failed to close topic handler: %w", err)
		}
	}
	return nil
}
// HostPublicIP returns this host's public IP. In dev/test environments the
// listening IP is used directly; otherwise the method blocks until the
// identify protocol has confirmed an observed public address.
func (l *Libp2p) HostPublicIP() (net.IP, error) {
	env := l.config.Env
	if env == "dev" || env == "test" {
		log.Infow("host public ip: using listening IP since in dev or test environment")
		return l.listeningIP()
	}

	log.Infow("checking observed public IP...")
	observed, err := l.waitForObservedAddr(l.ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to retrieve observed addr: %w", err)
	}
	log.Infow("obtained observed public IP", "addr", observed.String())
	return manet.ToIP(observed)
}
// listeningIP picks a usable IP from the host's listen addresses, falling
// back to 127.0.0.1 when nothing better is found.
//
// NOTE(review): the ">4 addresses in non-production" branch looks like a
// heuristic for multi-interface dev setups that prefers the first non-loopback
// private address — confirm the intent; the threshold of 4 is undocumented.
func (l *Libp2p) listeningIP() (net.IP, error) {
	var privIP net.IP
	var hasPrivIP bool
	if len(l.Host.Addrs()) > 4 && l.config.Env != "production" {
		// Dev/test with many interfaces: return the first private, non-loopback address.
		for _, addr := range l.Host.Addrs() {
			if manet.IsPrivateAddr(addr) && !strings.Contains(addr.String(), "/ip4/127.0.0.1/udp") {
				ip, err := manet.ToIP(addr)
				if err != nil {
					return nil, fmt.Errorf("failed to convert multiaddr to IP: %w", err)
				}
				return ip, nil
			}
		}
	} else {
		// Prefer a public address; remember the last private one as a fallback.
		for _, addr := range l.Host.Addrs() {
			if manet.IsPublicAddr(addr) {
				ip, err := manet.ToIP(addr)
				if err != nil {
					return nil, fmt.Errorf("failed to convert multiaddr to IP: %w", err)
				}
				return ip, nil
			} else if manet.IsPrivateAddr(addr) {
				ip, err := manet.ToIP(addr)
				if err != nil {
					return nil, fmt.Errorf("failed to convert multiaddr to IP: %w", err)
				}
				privIP = ip
				hasPrivIP = true
			}
		}
		if hasPrivIP {
			return privIP, nil
		}
	}
	// Nothing found: loopback as a last resort.
	return net.ParseIP("127.0.0.1"), nil
}
// waitForObservedAddr waits for the node to confirm its public IP address.
// Returns the observed multiaddress, or an error when ctx expires first.
func (l *Libp2p) waitForObservedAddr(ctx context.Context) (multiaddr.Multiaddr, error) {
	// Fast path: the address is already known.
	l.observeAddrMx.RLock()
	if l.observedAddr != nil {
		addr := l.observedAddr
		l.observeAddrMx.RUnlock()
		return addr, nil
	}
	l.observeAddrMx.RUnlock()

	// BUGFIX: sync.Cond has no context support, so the waiter goroutine used
	// to stay blocked in Wait forever when ctx was cancelled (goroutine leak).
	// Broadcast on cancellation so the wait loop re-checks ctx.Err and exits.
	stop := context.AfterFunc(ctx, func() {
		l.observeAddrMx.Lock()
		defer l.observeAddrMx.Unlock()
		l.observedAddrCond.Broadcast()
	})
	defer stop()

	done := make(chan struct{})
	var addr multiaddr.Multiaddr
	var err error
	go func() {
		l.observeAddrMx.Lock()
		defer l.observeAddrMx.Unlock()
		for l.observedAddr == nil && ctx.Err() == nil {
			l.observedAddrCond.Wait()
		}
		if l.observedAddr != nil {
			addr = l.observedAddr
		} else {
			err = ctx.Err()
		}
		close(done)
	}()

	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	case <-done:
		return addr, err
	}
}
// registerStreamHandlers installs built-in protocol handlers; currently only
// the libp2p ping service.
func (l *Libp2p) registerStreamHandlers() {
	l.pingService = ping.NewPingService(l.Host)
}
// sign signs data with this host's private key from the peerstore.
func (l *Libp2p) sign(data []byte) ([]byte, error) {
	key := l.Host.Peerstore().PrivKey(l.Host.ID())
	if key == nil {
		return nil, errors.New("private key not found for the host")
	}

	sig, err := key.Sign(data)
	if err != nil {
		return nil, fmt.Errorf("failed to sign data: %w", err)
	}
	return sig, nil
}
// getPublicKey returns the raw bytes of this host's public key, derived from
// the private key stored in the peerstore.
func (l *Libp2p) getPublicKey() ([]byte, error) {
	key := l.Host.Peerstore().PrivKey(l.Host.ID())
	if key == nil {
		return nil, errors.New("private key not found for the host")
	}
	return key.GetPublic().Raw()
}
// RawQUICConnect dials a raw QUIC connection to target. Outside dev/test
// environments only publicly routable cached addresses are considered.
func (l *Libp2p) RawQUICConnect(target peer.ID, serverName string) (*quic.Conn, *netip.AddrPort, error) {
	onlyPublic := !(l.config.Env == "dev" || l.config.Env == "test")
	return l.rawQUICConnect(target, serverName, onlyPublic)
}
// RawQUICConnectLocal dials a raw QUIC connection to target, allowing
// private/local addresses (used for same-network peers).
func (l *Libp2p) RawQUICConnectLocal(target peer.ID, serverName string) (*quic.Conn, *netip.AddrPort, error) {
	return l.rawQUICConnect(target, serverName, false)
}
// rawQUICConnect establishes a raw (non-libp2p) QUIC connection to target on
// the same UDP 4-tuple as an existing libp2p connection, hole punching via
// relay/DCUtR when necessary. When onlyPublicAddress is set, only publicly
// routable cached addresses are considered in the first step.
func (l *Libp2p) rawQUICConnect(target peer.ID, serverName string, onlyPublicAddress bool) (*quic.Conn, *netip.AddrPort, error) {
	// waitForConnection polls until a connection to targetPeer exists or the
	// timeout elapses.
	// BUGFIX: the previous select{case <-time.After(timeout): ... default: for{...}}
	// always entered the default branch immediately and looped forever, so the
	// timeout was never enforced. It also checked `target` instead of the
	// closure parameter.
	waitForConnection := func(timeout time.Duration, targetPeer peer.ID) error {
		deadline := time.Now().Add(timeout)
		for time.Now().Before(deadline) {
			if l.Host.Network().Connectedness(targetPeer) == network.Connected {
				return nil
			}
			time.Sleep(50 * time.Millisecond)
		}
		return fmt.Errorf("timed out waiting for connection to peer %s", targetPeer)
	}

	// 1st step: check if we have cached QUIC addresses in peerstore.
	// This should be checked BEFORE trying DHT lookup.
	var udpAddr *net.UDPAddr
	rawQuicAddrs := getRawQUICAddrs(l.Host.Peerstore().Addrs(target))
	for _, a := range rawQuicAddrs {
		// BUGFIX: replaced stray fmt.Println debug output with structured logging.
		log.Debugf("found address %s for target %s", a, target)
		if onlyPublicAddress {
			if manet.IsPublicAddr(a) && isQUICAddr(a) {
				addr, err := quicAddrToNetAddr(a)
				if err != nil {
					return nil, nil, fmt.Errorf("failed to convert multiaddr to net.UDPAddr: %w", err)
				}
				udpAddr = addr
			}
		} else {
			if isQUICAddr(a) {
				addr, err := quicAddrToNetAddr(a)
				if err != nil {
					return nil, nil, fmt.Errorf("failed to convert multiaddr to net.UDPAddr: %w", err)
				}
				udpAddr = addr
				break
			}
		}
	}
	if udpAddr != nil {
		addr := udpAddr.AddrPort()
		conn, err := dialSubnetQUICLayer(l, l.rawqtr.Transport, udpAddr, serverName)
		return conn, &addr, err
	}

	// 2nd step: try DHT lookup if no cached addresses found.
	var ai peer.AddrInfo
	var err error
	if ai, err = l.DHT.FindPeer(context.Background(), target); err != nil {
		log.Debugf("DHT FindPeer failed for %s: %v", target, err)
		// Check if peer is already connected - use existing connection info.
		if l.PeerConnected(target) {
			log.Debugf("Peer %s is already connected, using existing connection", target)
			addrs := l.Host.Peerstore().Addrs(target)
			if len(addrs) > 0 {
				ai = peer.AddrInfo{ID: target, Addrs: addrs}
				log.Debugf("Using existing connection addresses for %s", target)
			} else {
				return nil, nil, fmt.Errorf("peer is connected but no addresses in peerstore: %w", err)
			}
		} else {
			// Try to use any cached addresses from peerstore even if DHT failed.
			cachedAddrs := l.Host.Peerstore().Addrs(target)
			if len(cachedAddrs) > 0 {
				log.Debugf("DHT failed but found cached addresses for %s, attempting connection", target)
				ai = peer.AddrInfo{ID: target, Addrs: cachedAddrs}
			} else {
				return nil, nil, fmt.Errorf("failed to find peer: %w", err)
			}
		}
	} else {
		log.Debugf("DHT FindPeer succeeded for %s", target)
	}

	if err := l.Host.Connect(context.Background(), ai); err != nil {
		if ai.ID == "" {
			return nil, nil, fmt.Errorf("failed to find peer via DHT and no cached addresses available: %w", err)
		}
		return nil, nil, fmt.Errorf("failed to connect to peer: %w", err)
	}

	// As soon as the relayed peer accepts the connection via the relay,
	// it tries to establish a direction connection back to us using the DCUtR protocol.
	// We wait for this connection to be established.
	err = waitForConnection(2*time.Minute, target)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to establish direct connection to peer %s: %w", target, err)
	}

	// Now that we have a direct connection to the target, we can dial another
	// QUIC connection on the same 4-tupe. This works since QUIC demultiplexes connections
	// based on their connection ID.
	var directAddr *net.UDPAddr
	log.Infof("dialQUIC: connections to target %q : %+v", target.String(), l.Host.Network().ConnsToPeer(target))
	for _, c := range l.Host.Network().ConnsToPeer(target) {
		if a := c.RemoteMultiaddr(); isQUICAddr(a) {
			directAddr, err = quicAddrToNetAddr(a)
			if err != nil {
				return nil, nil, fmt.Errorf("failed to convert multiaddr to net.UDPAddr: %w", err)
			}
			log.Infof("dialQUIC: found QUIC address: %s", directAddr.String())
			break
		}
	}

	// Due to https://github.com/libp2p/go-libp2p/issues/3101, we can't rely on the Connectedness connection state,
	// as it doesn't distinguish between direct and connections via an unlimited relay.
	// Poll for up to 5 seconds for a direct QUIC connection to appear.
	start := time.Now()
	ticker := time.NewTicker(25 * time.Millisecond)
	defer ticker.Stop()
connectLoop:
	for now := range ticker.C {
		if now.Sub(start) > 5*time.Second {
			break
		}
		log.Infof("dialQUIC: connections to target %q : %+v", target.String(), l.Host.Network().ConnsToPeer(target))
		for _, c := range l.Host.Network().ConnsToPeer(target) {
			if a := c.RemoteMultiaddr(); isQUICAddr(a) {
				directAddr, err = quicAddrToNetAddr(a)
				if err != nil {
					return nil, nil, fmt.Errorf("failed to convert multiaddr to net.UDPAddr: %w", err)
				}
				log.Infof("dialQUIC: found QUIC address: %s", directAddr.String())
				break connectLoop
			}
		}
	}
	if directAddr == nil {
		return nil, nil, fmt.Errorf("failed to find a direct QUIC address for peer %s after hole punching", target)
	}

	log.Debugf("dialing QUIC address: %s", directAddr)
	log.Debugf("found hole punched connection, addr: %s:%d", directAddr.IP.String(), directAddr.Port)

	addr := directAddr.AddrPort()
	conn, err := dialSubnetQUICLayer(l, l.rawqtr.Transport, directAddr, serverName)
	return conn, &addr, err
}
// dialSubnetQUICLayer dials a raw QUIC connection to addr over the shared
// transport, authenticating both sides with self-signed certificates derived
// from the node's ed25519 libp2p key. The peer's certificate is validated by
// makeVerifySubnetPeerCertificateFn against subnet membership.
func dialSubnetQUICLayer(l *Libp2p, tr *quic.Transport, addr *net.UDPAddr, servName string) (*quic.Conn, error) {
	priv, err := l.config.PrivateKey.Raw()
	if err != nil {
		return nil, fmt.Errorf("failed to get private key: %w", err)
	}
	// Certificate CN/SAN is "<servName>.nunet.internal"; the verifier parses
	// the subnet ID back out of the first label.
	cert, err := generateSelfSignedCert(ed25519.PrivateKey(priv), []string{fmt.Sprintf("%s.nunet.internal", servName)})
	if err != nil {
		return nil, fmt.Errorf("failed to generate self signed certificate: %w", err)
	}
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()
	log.Debugf("dialing QUIC address: %s", addr)
	conn, err := tr.Dial(
		ctx,
		addr,
		&tls.Config{
			// Chain verification is skipped; VerifyPeerCertificate enforces
			// subnet membership instead.
			InsecureSkipVerify:       true,
			VerifyPeerCertificate:    makeVerifySubnetPeerCertificateFn(l),
			PreferServerCipherSuites: true,
			ClientAuth:               tls.RequireAndVerifyClientCert,
			Certificates:             []tls.Certificate{*cert},
			NextProtos:               []string{"raw"},
			ServerName:               fmt.Sprintf("%s.nunet.internal", servName),
		},
		&quic.Config{
			EnableDatagrams:      true,
			MaxIdleTimeout:       2 * 60 * time.Second, // Increase idle timeout
			HandshakeIdleTimeout: 2 * 60 * time.Second, // Explicit handshake timeout
			KeepAlivePeriod:      2 * 60 * time.Second, // Send keep-alive packets
			MaxIncomingStreams:   1000,
		},
	)
	if err != nil {
		return nil, fmt.Errorf("failed to dial QUIC address: %w", err)
	}
	return conn, nil
}
// getCustomNamespace builds the DHT key "<namespace>-<key>-<peerID>" used to
// namespace advertisement values per peer.
func (l *Libp2p) getCustomNamespace(key, peerID string) string {
	return l.config.CustomNamespace + "-" + key + "-" + peerID
}
// createCIDFromKey deterministically derives a CIDv1 (raw codec) from key by
// multihash-encoding its SHA-256 digest.
func createCIDFromKey(key string) (cid.Cid, error) {
	digest := sha256.Sum256([]byte(key))
	encoded, err := multihash.Encode(digest[:], multihash.SHA2_256)
	if err != nil {
		return cid.Cid{}, err
	}
	return cid.NewCidV1(cid.Raw, encoded), nil
}
// makeVerifySubnetPeerCertificateFn returns a TLS VerifyPeerCertificate hook
// that accepts a peer's self-signed certificate only when its server name
// matches a known subnet and the peer ID derived from the certificate's public
// key is a member of one of this node's subnets.
func makeVerifySubnetPeerCertificateFn(l *Libp2p) func([][]byte, [][]*x509.Certificate) error {
	return func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
		if len(l.subnets) == 0 {
			return fmt.Errorf("peer not a member of any subnet, invalidating cert")
		}
		cert, err := x509.ParseCertificate(rawCerts[0])
		if err != nil {
			return fmt.Errorf("failed to parse certificate: %v", err)
		}

		// Check the validity window.
		now := time.Now()
		if now.Before(cert.NotBefore) || now.After(cert.NotAfter) {
			return fmt.Errorf("certificate is expired or not yet valid")
		}

		// CommonName is "<subnetID>.nunet.internal"; the first label is the subnet ID.
		subnetID := strings.Split(cert.Subject.CommonName, ".")[0]

		// BUGFIX: this check was inverted — it returned an error exactly when
		// the subnet was known AND a DNS SAN matched the common name (the valid
		// case). Reject instead when either condition fails.
		if _, ok := l.subnets[subnetID]; !ok || !slices.Contains(cert.DNSNames, cert.Subject.CommonName) {
			return fmt.Errorf("either server name does not match certificate or peer not a member of provided subnets")
		}

		var pubKey crypto.PubKey
		switch algoType := strings.ToLower(cert.PublicKeyAlgorithm.String()); algoType {
		case "ecdsa":
			// BUGFIX: x509.ParseCertificate stores *ecdsa.PublicKey; asserting
			// the value type always failed, breaking all ECDSA certs.
			key, ok := cert.PublicKey.(*ecdsa.PublicKey)
			if !ok {
				return fmt.Errorf("failed to cast public key to ecdsa type=%T", cert.PublicKey)
			}
			pubkey, err := key.ECDH()
			if err != nil {
				return fmt.Errorf("failed to get ecdh public key: %v", err)
			}
			// NOTE(review): ECDH().Bytes() yields the raw point encoding —
			// confirm UnmarshalECDSAPublicKey accepts that format.
			pubKey, err = ic.UnmarshalECDSAPublicKey(pubkey.Bytes())
			if err != nil {
				return fmt.Errorf("failed to unmarshal ecdsa public key: %v", err)
			}
		case "rsa":
			// BUGFIX: x509.ParseCertificate stores *rsa.PublicKey; asserting
			// the value type always failed, breaking all RSA certs.
			key, ok := cert.PublicKey.(*rsa.PublicKey)
			if !ok {
				return fmt.Errorf("failed to cast public key to rsa, type=%T", cert.PublicKey)
			}
			rawBytes, err := x509.MarshalPKIXPublicKey(key)
			if err != nil {
				return fmt.Errorf("failed to marshal PKIX public key: %v", err)
			}
			pubKey, err = ic.UnmarshalRsaPublicKey(rawBytes)
			if err != nil {
				return fmt.Errorf("failed to unmarshal rsa public key: %v", err)
			}
		case "ed25519":
			// ed25519.PublicKey is a value ([]byte) type, so this assertion is correct as-is.
			key, ok := cert.PublicKey.(ed25519.PublicKey)
			if !ok {
				return fmt.Errorf("failed to cast public key to ed25519, type=%T", cert.PublicKey)
			}
			pubkey, err := ic.UnmarshalEd25519PublicKey([]byte(key))
			if err != nil {
				return fmt.Errorf("failed to unmarshal ed25519 public key: %v", err)
			}
			pubKey = pubkey
		default:
			return fmt.Errorf("unsupported public key type: %T, %s", cert.PublicKey, cert.PublicKeyAlgorithm.String())
		}

		peerID, err := peer.IDFromPublicKey(pubKey)
		if err != nil {
			return fmt.Errorf("failed to get peer id from public key: %v", err)
		}

		// Accept when the derived peer ID is in any subnet's routing table.
		for _, subnet := range l.subnets {
			if _, ok := subnet.info.rtable.All()[peerID]; ok {
				// BUGFIX: %S is not a valid fmt verb; use %s.
				log.Debugf("peer is a member of subnet, allowing raw quic connection (peerID=%s, subnet=%s)", peerID.String(), subnet.info.id)
				return nil
			}
		}
		return fmt.Errorf("peer not a member of any subnet, invalidating cert")
	}
}
// getRawQUICAddrs filters multiaddrs down to QUIC addresses that can be
// converted to a net.UDPAddr; unconvertible entries are logged and skipped.
func getRawQUICAddrs(multiaddrs []multiaddr.Multiaddr) []multiaddr.Multiaddr {
	out := make([]multiaddr.Multiaddr, 0)
	for _, candidate := range multiaddrs {
		if !isQUICAddr(candidate) {
			continue
		}
		if _, err := quicAddrToNetAddr(candidate); err != nil {
			log.Errorf("failed to convert multiaddr to net.UDPAddr: %v", err)
			continue
		}
		out = append(out, candidate)
	}
	return out
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package libp2p
import (
"context"
"crypto/ed25519"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"fmt"
"math/big"
"net"
"slices"
"time"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/p2p/transport/quicreuse"
ma "github.com/multiformats/go-multiaddr"
mafmt "github.com/multiformats/go-multiaddr-fmt"
manet "github.com/multiformats/go-multiaddr/net"
"github.com/quic-go/quic-go"
)
// RawQUICTransport wraps a *quic.Transport so the same UDP socket can carry
// both libp2p QUIC traffic and "raw" (subnet IP-proxy) QUIC connections.
type RawQUICTransport struct {
	*quic.Transport
	// listener diverts connections that negotiated the "raw" ALPN protocol.
	listener *interceptingListener
	// network is the owning Libp2p instance; supplies key material for the
	// self-signed certificate and subnet membership checks.
	network *Libp2p
	// listenerReady is closed once Listen has installed the intercepting
	// listener, unblocking goroutines that wait for it.
	listenerReady chan struct{}
}
// interceptingListener wraps a quicreuse.QUICListener: connections whose
// negotiated ALPN protocol should be intercepted are pushed onto acceptQueue
// instead of being returned from Accept.
type interceptingListener struct {
	// intercept lists the ALPN protocol names to divert (e.g. "raw").
	intercept []string
	// acceptQueue receives intercepted connections; drained by the IP proxy.
	acceptQueue chan *quic.Conn
	quicreuse.QUICListener
}
// NewRawQUICTransport builds a RawQUICTransport on top of the given UDP socket.
func NewRawQUICTransport(udpConn *net.UDPConn) *RawQUICTransport {
	t := &RawQUICTransport{
		Transport:     &quic.Transport{Conn: udpConn},
		listenerReady: make(chan struct{}),
	}
	return t
}
// Listen starts a QUIC listener whose TLS configuration is chosen per client
// hello:
//   - clients offering the "raw" ALPN protocol get a self-signed certificate
//     derived from the node's private key, are required to present a client
//     certificate, and are verified against subnet membership;
//   - all other clients fall through to libp2p's own tls.Config.
//
// The returned listener is wrapped so "raw" connections are diverted to an
// internal queue, and listenerReady is closed to unblock waiters.
func (t *RawQUICTransport) Listen(tlsConf *tls.Config, conf *quic.Config) (quicreuse.QUICListener, error) {
	wrappedConf := &tls.Config{
		GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) {
			if slices.Contains(info.SupportedProtos, "raw") {
				priv, err := t.network.config.PrivateKey.Raw()
				if err != nil {
					return nil, fmt.Errorf("failed to get priv key to generate tls cert: %w", err)
				}
				var cert *tls.Certificate
				cert, err = generateSelfSignedCert(ed25519.PrivateKey(priv), []string{info.ServerName})
				if err != nil {
					return nil, fmt.Errorf("failed to generate self signed cert: %w", err)
				}
				return &tls.Config{
					// Any client cert is accepted here; actual peer identity is
					// checked by the VerifyPeerCertificate callback below.
					ClientAuth:            tls.RequireAnyClientCert,
					Certificates:          []tls.Certificate{*cert},
					NextProtos:            []string{"raw"},
					InsecureSkipVerify:    false,
					VerifyPeerCertificate: makeVerifySubnetPeerCertificateFn(t.network),
				}, nil
			}
			// use libp2p's tls.Config
			if tlsConf.GetConfigForClient != nil {
				return tlsConf.GetConfigForClient(info)
			}
			return tlsConf, nil
		},
	}
	ln, err := t.Transport.Listen(wrappedConf, conf)
	if err != nil {
		return nil, err
	}
	t.listener = newInterceptingListener(ln, []string{"raw"})
	close(t.listenerReady)
	return t.listener, nil
}
// newInterceptingListener wraps ln so that connections negotiating one of the
// intercept ALPN protocols are queued instead of surfaced from Accept.
func newInterceptingListener(ln quicreuse.QUICListener, intercept []string) *interceptingListener {
	il := &interceptingListener{
		QUICListener: ln,
		intercept:    intercept,
		acceptQueue:  make(chan *quic.Conn, 32),
	}
	return il
}
// Accept returns the next non-intercepted connection. Connections whose
// negotiated ALPN protocol is in the intercept list are pushed onto
// acceptQueue (for the IP proxy) and never surface to the caller.
//
// Previously the "raw" protocol name was hard-coded here even though the
// intercept list is plumbed through the constructor; the list is now honored
// (behavior is unchanged for the only current caller, which passes ["raw"]).
// The goto-based retry is also replaced with a plain loop.
func (l *interceptingListener) Accept(ctx context.Context) (*quic.Conn, error) {
	for {
		conn, err := l.QUICListener.Accept(ctx)
		if err != nil {
			return nil, err
		}
		if slices.Contains(l.intercept, conn.ConnectionState().TLS.NegotiatedProtocol) {
			log.Debugf("intercepting a raw connection from: %s", conn.RemoteAddr())
			l.acceptQueue <- conn
			continue
		}
		log.Debugf("accepting a non-raw connection from: %s", conn.RemoteAddr())
		return conn, nil
	}
}
// quicAddrFilter restricts address candidates to IPv4 QUIC-v1 multiaddrs.
type quicAddrFilter struct{}
// filterQUICIPv4 keeps only multiaddrs whose first component is IPv4 and which
// have the QUIC-v1 shape. The peer ID argument is unused.
func (f *quicAddrFilter) filterQUICIPv4(_ peer.ID, maddrs []ma.Multiaddr) []ma.Multiaddr {
	keep := func(addr ma.Multiaddr) bool {
		head, _ := ma.SplitFirst(addr)
		if head == nil || head.Protocol().Code != ma.P_IP4 {
			return false
		}
		return isQUICAddr(addr)
	}
	return ma.FilterAddrs(maddrs, keep)
}
// FilterRemote restricts a remote peer's candidate addresses to IPv4 QUIC-v1.
func (f *quicAddrFilter) FilterRemote(remoteID peer.ID, maddrs []ma.Multiaddr) []ma.Multiaddr {
	filtered := f.filterQUICIPv4(remoteID, maddrs)
	return filtered
}
// FilterLocal restricts our local candidate addresses to IPv4 QUIC-v1.
func (f *quicAddrFilter) FilterLocal(remoteID peer.ID, maddrs []ma.Multiaddr) []ma.Multiaddr {
	filtered := f.filterQUICIPv4(remoteID, maddrs)
	return filtered
}
func generateSelfSignedCert(priv ed25519.PrivateKey, dnsNames []string) (*tls.Certificate, error) {
// Generate a new ed25519 key pair
pub := priv.Public().(ed25519.PublicKey)
// Create a certificate template
template := &x509.Certificate{
SerialNumber: big.NewInt(1),
NotBefore: time.Now().Add(-time.Hour), // Valid from 1 hour ago
NotAfter: time.Now().Add(24 * time.Hour * 365), // Valid for 1 year
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
DNSNames: dnsNames,
}
// Create the certificate
certDER, err := x509.CreateCertificate(rand.Reader, template, template, pub, priv)
if err != nil {
return nil, err
}
// Create the tls.Certificate
cert := &tls.Certificate{
Certificate: [][]byte{certDER},
PrivateKey: priv,
}
return cert, nil
}
// isQUICAddr reports whether a has the IP/UDP/QUIC-v1 shape.
func isQUICAddr(a ma.Multiaddr) bool {
	pattern := mafmt.And(mafmt.IP, mafmt.Base(ma.P_UDP), mafmt.Base(ma.P_QUIC_V1))
	return pattern.Matches(a)
}
// quicAddrToNetAddr converts the IP/UDP head of a QUIC-v1 multiaddr into a
// *net.UDPAddr. It fails when the multiaddr contains no QUIC-v1 component or
// does not resolve to a UDP address.
func quicAddrToNetAddr(a ma.Multiaddr) (*net.UDPAddr, error) {
	first, _ := ma.SplitFunc(a, func(c ma.Component) bool { return c.Protocol().Code == ma.P_QUIC_V1 })
	if first == nil {
		return nil, fmt.Errorf("no QUIC address found in multiaddr")
	}
	netAddr, err := manet.ToNetAddr(first)
	if err != nil {
		return nil, fmt.Errorf("failed to convert multiaddr to net.Addr: %w", err)
	}
	// Guard the type assertion: a non-UDP net.Addr would previously have
	// caused a panic here.
	udpAddr, ok := netAddr.(*net.UDPAddr)
	if !ok {
		return nil, fmt.Errorf("expected *net.UDPAddr, got %T", netAddr)
	}
	return udpAddr, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package libp2p
import (
"slices"
"sync"
"github.com/libp2p/go-libp2p/core/peer"
)
// SubnetRoutingTable maps subnet peers to the overlay IP addresses assigned
// to them, with a reverse index for IP-to-peer lookups. Implementations must
// be safe for concurrent use.
type SubnetRoutingTable interface {
	// Add records addr as belonging to peerID.
	Add(peerID peer.ID, addr string)
	// Remove deletes ip from peerID's address list.
	Remove(peerID peer.ID, ip string)
	// Get returns the addresses recorded for peerID.
	Get(peerID peer.ID) ([]string, bool)
	// RemoveByIP removes addr and its peer association from the table.
	RemoveByIP(addr string)
	// GetByIP returns the peer that owns addr.
	GetByIP(addr string) (peer.ID, bool)
	// All returns the full peer-to-addresses mapping.
	All() map[peer.ID][]string
	// Clear empties the table.
	Clear()
}
// rtable is the mutex-guarded default SubnetRoutingTable implementation.
type rtable struct {
	mx sync.RWMutex
	// idx maps a peer to all overlay IPs assigned to it.
	idx map[peer.ID][]string
	// revIdx maps an overlay IP back to its owning peer.
	revIdx map[string]peer.ID
}
// NewRoutingTable returns an empty, concurrency-safe SubnetRoutingTable.
func NewRoutingTable() SubnetRoutingTable {
	rt := &rtable{
		idx:    map[peer.ID][]string{},
		revIdx: map[string]peer.ID{},
	}
	return rt
}
// Add records addr as belonging to peerID. Duplicate addresses for the same
// peer are ignored; the reverse index maps addr back to peerID.
func (rt *rtable) Add(peerID peer.ID, addr string) {
	rt.mx.Lock()
	defer rt.mx.Unlock()
	addrs := rt.idx[peerID]
	// Skip if the IP already exists for this peer.
	if slices.Contains(addrs, addr) {
		return
	}
	// append on a nil slice creates the entry, so no explicit make is needed.
	rt.idx[peerID] = append(addrs, addr)
	rt.revIdx[addr] = peerID
}
// Remove deletes ip from peerID's address list, dropping the peer entirely
// when its last address is removed.
//
// The reverse-index entry is cleared only when it actually points at peerID:
// previously it was deleted unconditionally, so calling Remove with a peer
// that did not own ip could clobber another peer's reverse mapping.
func (rt *rtable) Remove(peerID peer.ID, ip string) {
	rt.mx.Lock()
	defer rt.mx.Unlock()
	addrs, ok := rt.idx[peerID]
	if !ok {
		return
	}
	remaining := slices.DeleteFunc(addrs, func(addr string) bool {
		return addr == ip
	})
	if len(remaining) == 0 {
		delete(rt.idx, peerID)
	} else {
		rt.idx[peerID] = remaining
	}
	if owner, ok := rt.revIdx[ip]; ok && owner == peerID {
		delete(rt.revIdx, ip)
	}
}
// Get returns the addresses recorded for peerID and whether the peer exists.
//
// A copy of the slice is returned: handing out the internal slice let callers
// read it concurrently with later mutations under rt.mx (a data race).
func (rt *rtable) Get(peerID peer.ID) ([]string, bool) {
	rt.mx.RLock()
	defer rt.mx.RUnlock()
	addrs, ok := rt.idx[peerID]
	if !ok {
		return nil, false
	}
	return slices.Clone(addrs), true
}
// RemoveByIP removes addr from the table.
//
// Previously the owning peer's entire forward entry was deleted, which
// orphaned the peer's other addresses in the reverse index (they kept mapping
// to a peer that no longer existed in idx). Now only addr itself is detached,
// and the peer entry is dropped only when it becomes empty.
func (rt *rtable) RemoveByIP(addr string) {
	rt.mx.Lock()
	defer rt.mx.Unlock()
	peerID, ok := rt.revIdx[addr]
	if !ok {
		return
	}
	remaining := slices.DeleteFunc(rt.idx[peerID], func(a string) bool {
		return a == addr
	})
	if len(remaining) == 0 {
		delete(rt.idx, peerID)
	} else {
		rt.idx[peerID] = remaining
	}
	delete(rt.revIdx, addr)
}
// GetByIP returns the peer that owns addr, if any.
func (rt *rtable) GetByIP(addr string) (peer.ID, bool) {
	rt.mx.RLock()
	defer rt.mx.RUnlock()
	owner, found := rt.revIdx[addr]
	return owner, found
}
// All returns a snapshot of the full peer-to-addresses mapping.
//
// Both the map and the address slices are copied: the previous shallow copy
// aliased the internal slices, so callers could race with concurrent Add/Remove.
func (rt *rtable) All() map[peer.ID][]string {
	rt.mx.RLock()
	defer rt.mx.RUnlock()
	snapshot := make(map[peer.ID][]string, len(rt.idx))
	for id, addrs := range rt.idx {
		snapshot[id] = slices.Clone(addrs)
	}
	return snapshot
}
// Clear empties both the forward and reverse indices.
func (rt *rtable) Clear() {
	rt.mx.Lock()
	defer rt.mx.Unlock()
	rt.idx = map[peer.ID][]string{}
	rt.revIdx = map[string]peer.ID{}
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package libp2p
import (
"context"
"errors"
"fmt"
"io/fs"
"math"
"math/rand"
"net"
"net/http"
"net/netip"
"strings"
"sync"
"sync/atomic"
"time"
connectip "github.com/quic-go/connect-ip-go"
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/http3"
"github.com/yosida95/uritemplate/v3"
"gitlab.com/nunet/device-management-service/observability"
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
peer "github.com/libp2p/go-libp2p/core/peer"
"gitlab.com/nunet/device-management-service/utils/sys"
)
const (
	// IfaceMTU is the MTU configured on subnet TUN interfaces.
	IfaceMTU = 1420
	// MaxPacketSize bounds packets read from a TUN device (2x the MTU for headroom).
	MaxPacketSize = 2 * 1420 // Consistent packet size limit
)

// NetInterfaceFactory creates (or opens) a network interface by name.
type NetInterfaceFactory func(name string) (sys.NetInterface, error)
// subnet holds the per-overlay-network state: routing table, local TUN
// interfaces, static DNS records, open connect-ip tunnels and port mappings.
type subnet struct {
	ctx     context.Context
	network *Libp2p
	// info carries the subnet identity, its routing table and CIDR range.
	info struct {
		id     string
		rtable SubnetRoutingTable
		cidr   *net.IPNet
	}
	// mx guards ifaces (it is also held around routing decisions in Route).
	mx sync.Mutex
	// ifaces maps a local overlay IP to its TUN device and the context/cancel
	// pair controlling that device's packet read loop.
	ifaces map[string]struct {
		tun    sys.NetInterface
		ctx    context.Context
		cancel context.CancelFunc
	}
	// TODO: add some map to store HTTP/3 tunnel connections
	// dnsmx guards dnsRecords.
	dnsmx sync.RWMutex
	// dnsRecords maps hostnames to overlay IPs for the built-in DNS responder.
	dnsRecords map[string]string
	// proxiedConns tracks open connect-ip tunnels, keyed by overlay IP.
	proxiedConns struct {
		mx    sync.Mutex
		conns map[string]*connectip.Conn // key: IP string
	}
	// portMappingMx guards portMapping.
	portMappingMx sync.Mutex
	// portMapping maps a source port to its forwarding destination.
	portMapping map[string]*struct {
		destPort string
		destIP   string
		srcIP    string
	}
	// locks and packetQueues are per-IP helpers; they are initialized here but
	// not referenced in this file — presumably used elsewhere (TODO confirm).
	locks        map[string]*sync.Mutex
	packetQueues map[string]chan []byte
	ifaceFactory NetInterfaceFactory
}
// newSubnet builds an empty subnet bound to l, with every internal map
// initialized and a placeholder CIDR.
func newSubnet(ctx context.Context, l *Libp2p, factory NetInterfaceFactory) *subnet {
	s := &subnet{
		ctx:          ctx,
		network:      l,
		dnsRecords:   map[string]string{},
		locks:        make(map[string]*sync.Mutex),
		packetQueues: make(map[string]chan []byte),
		ifaceFactory: factory,
	}
	s.info.rtable = NewRoutingTable()
	s.info.cidr = &net.IPNet{IP: net.IPv4zero, Mask: net.CIDRMask(0, 32)} // TODO: replace
	s.ifaces = make(map[string]struct {
		tun    sys.NetInterface
		ctx    context.Context
		cancel context.CancelFunc
	})
	s.proxiedConns.conns = map[string]*connectip.Conn{}
	s.portMapping = map[string]*struct {
		destPort string
		destIP   string
		srcIP    string
	}{}
	return s
}
// CreateSubnet registers a new subnet with the given CIDR and initial routing
// table (overlay IP -> encoded peer ID). The shared IP proxy is started when
// the first subnet is created.
//
// If starting the proxy fails, the isHTTPServerRegistered flag is rolled back:
// previously a failure left the flag set with no proxy running, so no later
// CreateSubnet could ever retry.
func (l *Libp2p) CreateSubnet(ctx context.Context, subnetID string, cidr string, routingTable map[string]string) error {
	l.subnetsmx.Lock()
	defer l.subnetsmx.Unlock()
	if _, ok := l.subnets[subnetID]; ok {
		return fmt.Errorf("subnet with ID %s already exists", subnetID)
	}
	s := newSubnet(ctx, l, l.NetIfaceFactory)
	s.info.id = subnetID
	_, CIDR, err := net.ParseCIDR(cidr)
	if err != nil {
		return err
	}
	s.mx.Lock()
	s.info.cidr = CIDR
	for ip, encodedPeer := range routingTable {
		peerID, err := peer.Decode(encodedPeer)
		if err != nil {
			s.mx.Unlock()
			return fmt.Errorf("failed to decode peer ID %s: %w", encodedPeer, err)
		}
		s.info.rtable.Add(peerID, ip)
	}
	s.mx.Unlock()
	if atomic.CompareAndSwapInt32(&l.isHTTPServerRegistered, 0, 1) {
		if err := l.startIPProxy(); err != nil {
			// Roll back so the proxy start can be retried on the next call.
			atomic.StoreInt32(&l.isHTTPServerRegistered, 0)
			return err
		}
	}
	l.subnets[subnetID] = s
	return nil
}
// DestroySubnet tears down the subnet identified by subnetID: it cancels and
// deletes all local TUN interfaces, closes proxied connect-ip tunnels, clears
// DNS records and the routing table, unmaps forwarded ports, and stops the
// shared IP proxy when this was the last subnet.
func (l *Libp2p) DestroySubnet(subnetID string) error {
	l.subnetsmx.Lock()
	defer l.subnetsmx.Unlock()
	s, ok := l.subnets[subnetID]
	if !ok {
		return fmt.Errorf("subnet with ID %s does not exist", subnetID)
	}
	// Stop the read loops and remove every TUN device.
	// NOTE(review): ifaces is iterated here before s.mx is taken below —
	// assumes no concurrent AddSubnetPeer on this subnet; confirm.
	for ip := range s.ifaces {
		s.ifaces[ip].cancel()
		s.ifaces[ip].tun.Close()
		_ = s.ifaces[ip].tun.Down()
		_ = s.ifaces[ip].tun.Delete()
	}
	s.mx.Lock()
	s.ifaces = make(map[string]struct {
		tun    sys.NetInterface
		ctx    context.Context
		cancel context.CancelFunc
	})
	s.mx.Unlock()
	// Clean up proxied connections
	s.proxiedConns.mx.Lock()
	for ip, conn := range s.proxiedConns.conns {
		log.Debugf("closing proxied connection for %s during subnet destruction", ip)
		conn.Close()
	}
	s.proxiedConns.conns = make(map[string]*connectip.Conn)
	s.proxiedConns.mx.Unlock()
	s.dnsmx.Lock()
	s.dnsRecords = make(map[string]string)
	s.dnsmx.Unlock()
	s.info.rtable.Clear()
	// Create snapshot of port mappings to avoid holding lock during cleanup
	s.portMappingMx.Lock()
	mappingsSnapshot := make(map[string]*struct {
		destPort string
		destIP   string
		srcIP    string
	})
	for k, v := range s.portMapping {
		mappingsSnapshot[k] = v
	}
	s.portMappingMx.Unlock()
	// Now iterate over snapshot
	for sourcePort, mapping := range mappingsSnapshot {
		err := l.UnmapPort(subnetID, "tcp", mapping.srcIP, sourcePort, mapping.destIP, mapping.destPort)
		if err != nil {
			log.Errorf("failed to unmap port %s: %v", sourcePort, err)
		}
	}
	// This is the last subnet (the map still contains it, hence == 1), so the
	// shared IP proxy can be stopped before the delete below.
	if len(l.subnets) == 1 {
		if atomic.CompareAndSwapInt32(&l.isHTTPServerRegistered, 1, 0) {
			l.stopIPProxy()
		}
	}
	delete(l.subnets, subnetID)
	log.Debugf("subnet %s destroyed", subnetID)
	return nil
}
// AddSubnetPeer records (peerID, ip) in the subnet routing table and creates a
// local TUN interface bound to ip (a /24 prefix is assumed), brings it up and
// starts its packet read loop.
//
// TODO: This method isn't doing what its name implies.
// This is basically creating the subnet, not adding a peer to the subnet.
// Move all this business logic to the CreateSubnet.
func (l *Libp2p) AddSubnetPeer(subnetID, peerID, ip string) error {
	l.subnetsmx.Lock()
	defer l.subnetsmx.Unlock()
	s, ok := l.subnets[subnetID]
	if !ok {
		return fmt.Errorf("subnet with ID %s does not exist", subnetID)
	}
	ipAddr := net.ParseIP(ip)
	if ipAddr == nil {
		return fmt.Errorf("invalid IP address %s", ip)
	}
	peerIDObj, err := peer.Decode(peerID)
	if err != nil {
		return fmt.Errorf("failed to decode peer ID %s: %w", peerID, err)
	}
	s.info.rtable.Add(peerIDObj, ip)
	// Pick a "dmsN" TUN name that does not collide with existing interfaces.
	ifaces, err := sys.GetNetInterfaces()
	if err != nil {
		return err
	}
	takenNames := make([]string, 0)
	for _, iface := range ifaces {
		takenNames = append(takenNames, iface.Name)
	}
	log.Debugf("finding proper iface name for TUN interface (taken_names=%s)", takenNames)
	name, err := generateUniqueName(takenNames)
	if err != nil {
		return fmt.Errorf("failed to generate unique name for TUN interface: %w", err)
	}
	log.Debugf("Creating TUN interface with name: %s", name)
	// NOTE(review): a /24 prefix is hard-coded regardless of the subnet's
	// configured CIDR — confirm this is intended.
	address := fmt.Sprintf("%s/24", ipAddr.String())
	iface, err := s.ifaceFactory(name)
	if err != nil {
		return fmt.Errorf("failed to create tun interface: %w", err)
	}
	err = iface.SetAddress(address)
	if err != nil {
		return fmt.Errorf("failed to set address on tun interface: %w", err)
	}
	err = iface.SetMTU(IfaceMTU)
	if err != nil {
		return fmt.Errorf("failed to set MTU on tun interface: %w", err)
	}
	if err := iface.Up(); err != nil {
		return fmt.Errorf("failed to bring up tun interface: %w", err)
	}
	// The read loop lives until either the subnet context or this
	// per-interface context is cancelled.
	ctx, cancel := context.WithCancel(s.ctx)
	s.mx.Lock()
	s.ifaces[ipAddr.String()] = struct {
		tun    sys.NetInterface
		ctx    context.Context
		cancel context.CancelFunc
	}{
		tun:    iface,
		ctx:    ctx,
		cancel: cancel,
	}
	s.mx.Unlock()
	go s.readPackets(ctx, iface)
	return nil
}
// RemoveSubnetPeers removes every (ip, peerID) pair in partialRoutinTable from
// the subnet, stopping at the first error.
func (l *Libp2p) RemoveSubnetPeers(subnetID string, partialRoutinTable map[string]string) error {
	for ip, peerID := range partialRoutinTable {
		if err := l.removeSubnetPeer(subnetID, peerID, ip); err != nil {
			return err
		}
	}
	return nil
}
// removeSubnetPeer removes (peerID, ip) from the subnet: if a local TUN device
// exists for ip it is cancelled, brought down and deleted, and the routing
// table entry is dropped. Removing an ip the peer does not own is a no-op.
//
// The subnets map is now read under subnetsmx, matching every other accessor;
// previously this read raced with CreateSubnet/DestroySubnet.
func (l *Libp2p) removeSubnetPeer(subnetID, peerID, ip string) error {
	l.subnetsmx.Lock()
	s, ok := l.subnets[subnetID]
	l.subnetsmx.Unlock()
	if !ok {
		return fmt.Errorf("subnet with ID %s does not exist", subnetID)
	}
	peerIDObj, err := peer.Decode(peerID)
	if err != nil {
		return fmt.Errorf("failed to decode peer ID %s: %w", peerID, err)
	}
	ips, ok := s.info.rtable.Get(peerIDObj)
	if !ok {
		return fmt.Errorf("peer with ID %s is not in the subnet", peerID)
	}
	found := false
	for _, known := range ips {
		if known == ip {
			found = true
			break
		}
	}
	if !found {
		return nil
	}
	s.mx.Lock()
	if iface, ok := s.ifaces[ip]; ok {
		iface.cancel()
		if err := iface.tun.Down(); err != nil {
			log.Errorf("failed to bring down tun device: %v (subnet=%s, ip=%s)", err, s.info.id, ip)
		}
		if err := iface.tun.Delete(); err != nil {
			log.Errorf("failed to delete tun device: %v (subnet=%s, ip=%s)", err, s.info.id, ip)
		}
		delete(s.ifaces, ip)
	}
	s.mx.Unlock()
	s.info.rtable.Remove(peerIDObj, ip)
	return nil
}
// AcceptSubnetPeers records every (ip, peerID) pair from partialRoutingTable
// in the subnet's routing table, stopping at the first error.
func (l *Libp2p) AcceptSubnetPeers(subnetID string, partialRoutingTable map[string]string) error {
	for ip, peerID := range partialRoutingTable {
		if err := l.acceptSubnetPeer(subnetID, peerID, ip); err != nil {
			return err
		}
	}
	return nil
}
// acceptSubnetPeer validates ip and peerID and records them in the subnet's
// routing table.
//
// The subnets map is now read under subnetsmx, matching every other accessor;
// previously this read raced with CreateSubnet/DestroySubnet.
func (l *Libp2p) acceptSubnetPeer(subnetID, peerID, ip string) error {
	l.subnetsmx.Lock()
	s, ok := l.subnets[subnetID]
	l.subnetsmx.Unlock()
	if !ok {
		return fmt.Errorf("subnet with ID %s does not exist", subnetID)
	}
	ipAddr := net.ParseIP(ip)
	if ipAddr == nil {
		return fmt.Errorf("invalid IP address %s", ip)
	}
	peerIDObj, err := peer.Decode(peerID)
	if err != nil {
		return fmt.Errorf("failed to decode peer ID %s: %w", peerID, err)
	}
	// TODO: remove this or check if the IP exists.
	// if _, ok := s.info.rtable.Get(peerIDObj); ok {
	// 	return nil
	// }
	s.info.rtable.Add(peerIDObj, ip)
	return nil
}
// AddSubnetDNSRecords merges the given hostname-to-IP records into the
// subnet's static DNS table, overwriting existing entries.
func (l *Libp2p) AddSubnetDNSRecords(subnetID string, records map[string]string) error {
	l.subnetsmx.Lock()
	defer l.subnetsmx.Unlock()
	s, ok := l.subnets[subnetID]
	if !ok {
		return fmt.Errorf("subnet with ID %s does not exist", subnetID)
	}
	s.dnsmx.Lock()
	defer s.dnsmx.Unlock()
	for hostname, ip := range records {
		s.dnsRecords[hostname] = ip
	}
	return nil
}
// RemoveSubnetDNSRecord drops the static DNS record for name from the subnet.
func (l *Libp2p) RemoveSubnetDNSRecord(subnetID, name string) error {
	l.subnetsmx.Lock()
	defer l.subnetsmx.Unlock()
	s, ok := l.subnets[subnetID]
	if !ok {
		return fmt.Errorf("subnet with ID %s does not exist", subnetID)
	}
	s.dnsmx.Lock()
	defer s.dnsmx.Unlock()
	delete(s.dnsRecords, name)
	return nil
}
// startIPProxy starts the shared connect-ip (HTTP/3 over raw QUIC) proxy that
// serves /vpn tunnel requests for all subnets, plus the accept loop that feeds
// intercepted raw QUIC connections into the HTTP/3 server.
//
// Fixes over the previous version:
//   - the serve goroutine wrote into the enclosing function's err variable
//     (a data race); it now uses a goroutine-local error.
//   - the handler logged via log.Error with printf verbs; now log.Errorf.
//   - l.ipproxyConns is initialized before the accept loop starts, so the
//     loop can never write into a nil map.
func (l *Libp2p) startIPProxy() error {
	p := connectip.Proxy{}
	hostIP, err := l.HostPublicIP()
	if err != nil {
		return fmt.Errorf("failed to get host public IP: %w", err)
	}
	// Wait for the raw QUIC listener to exist before reading its address.
	if l.rawqtr.listener == nil {
		<-l.rawqtr.listenerReady
	}
	actualPort := l.rawqtr.listener.Addr().(*net.UDPAddr).Port
	template := uritemplate.MustNew(fmt.Sprintf("http://%s:%d/vpn", hostIP, actualPort))
	mux := http.NewServeMux()
	mux.HandleFunc("/vpn", func(w http.ResponseWriter, r *http.Request) {
		// get subnet id from the query
		subnetID := r.URL.Query().Get("subnetID")
		if subnetID == "" {
			log.Debug("received bad http proxy request, no subnetID was provided")
			w.WriteHeader(http.StatusBadRequest)
			return
		}
		// get src ip from query params
		srcIP := r.URL.Query().Get("srcIP")
		if srcIP == "" || !IsIPv4(srcIP) {
			log.Debug("received bad http proxy request, no srcIP was provided")
			w.WriteHeader(http.StatusBadRequest)
			return
		}
		log.Debugf("received http proxy request for subnet %s from %s", subnetID, srcIP)
		l.subnetsmx.Lock()
		// retrieve subnet
		subnet, ok := l.subnets[subnetID]
		l.subnetsmx.Unlock()
		if !ok {
			log.Debugf("subnet with ID %s does not exist", subnetID)
			w.WriteHeader(http.StatusNotFound)
			return
		}
		addr := netip.MustParseAddr(srcIP)
		route := netip.MustParsePrefix(subnet.info.cidr.String())
		// XXX bad hack - recreating template with correct host so it never fails (dial ip could change if behind nat)
		template = uritemplate.MustNew(fmt.Sprintf("http://%s/vpn", r.Host))
		req, err := connectip.ParseRequest(r, template)
		if err != nil {
			log.Errorf("failed to parse request: %v", err)
			w.WriteHeader(http.StatusBadRequest)
			return
		}
		conn, err := p.Proxy(w, req)
		if err != nil {
			log.Errorf("failed to proxy connection: %v", err)
			w.WriteHeader(http.StatusInternalServerError)
			return
		}
		// ATOMIC check and store for incoming connections
		subnet.proxiedConns.mx.Lock()
		if old, ok := subnet.proxiedConns.conns[srcIP]; ok {
			old.Close()
		}
		subnet.proxiedConns.conns[srcIP] = conn
		subnet.proxiedConns.mx.Unlock()
		// Double-check: is this still the connection in the map? Another
		// request for the same srcIP may have replaced it in the meantime.
		subnet.proxiedConns.mx.Lock()
		if subnet.proxiedConns.conns[srcIP] != conn {
			subnet.proxiedConns.mx.Unlock()
			log.Debugf("connection from %s lost race, closing", srcIP)
			log.Errorf("closing connection for %s", srcIP)
			conn.Close()
			return
		}
		subnet.proxiedConns.mx.Unlock()
		log.Debugf("connection from %s stored in subnet", srcIP)
		if err := l.handleIPProxyConn(subnet, conn, addr, route, 0); err != nil {
			log.Errorf("failed to handle connection: %v", err)
		}
	})
	ctx, cancel := context.WithCancel(context.Background())
	s := &http3.Server{
		Handler:         mux,
		EnableDatagrams: true,
	}
	// Initialize the connection map BEFORE the accept loop starts, so the
	// goroutine below never writes into a nil map.
	l.ipproxyConnsMx.Lock()
	l.ipproxyConns = make(map[string]*quic.Conn)
	l.ipproxyConnsMx.Unlock()
	go func() {
		if l.rawqtr.listener == nil {
			<-l.rawqtr.listenerReady
		}
		for {
			select {
			case <-ctx.Done():
				log.Debug("ip proxy context done, shutting down")
				return
			case <-l.ctx.Done():
				log.Debug("libp2p context done, shutting down")
				return
			case rawQUICConn := <-l.rawqtr.listener.acceptQueue:
				l.ipproxyConnsMx.Lock()
				l.ipproxyConns[rawQUICConn.RemoteAddr().String()] = rawQUICConn
				l.ipproxyConnsMx.Unlock()
				go func() {
					log.Debug("serve http3 connection on raw quic connection")
					// Goroutine-local error; do not write the outer err.
					if serveErr := s.ServeQUICConn(rawQUICConn); serveErr != nil {
						log.Errorf("failed to serve http3 connection on raw quic connection: %v", serveErr)
						if closeErr := rawQUICConn.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "server is down"); closeErr != nil {
							log.Errorf("failed to close raw quic connection: %v", closeErr)
						}
					}
				}()
			}
		}
	}()
	l.ipproxyCtx = ctx
	l.ipproxyCtxCancel = cancel
	l.ipproxy = s
	log.Info("started ip proxy for all subnets")
	return nil
}
// handleIPProxyConn finishes the server side of an accepted connect-ip tunnel:
// it assigns the client its overlay address, advertises the subnet route, and
// starts background goroutines that pump packets from the tunnel into the
// matching local TUN device. It returns after setup; pumping continues until
// the subnet context is cancelled or a read fails.
func (l *Libp2p) handleIPProxyConn(
	snet *subnet,
	conn *connectip.Conn,
	addr netip.Addr,
	route netip.Prefix,
	ipProtocol uint8,
) error {
	log.Debugf(
		"handling ip proxy conn (subnet=%s, addr=%s, route=%s, ipProtocol=%d)",
		snet.info.id, addr.String(), route.String(), ipProtocol,
	)
	// Setup (address assignment + route advertisement) is bounded to 5s.
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	if err := conn.AssignAddresses(ctx, []netip.Prefix{netip.PrefixFrom(addr, addr.BitLen())}); err != nil {
		return fmt.Errorf("failed to assign addresses: %w", err)
	}
	if err := conn.AdvertiseRoute(ctx, []connectip.IPRoute{
		{StartIP: route.Addr(), EndIP: LastIP(route), IPProtocol: ipProtocol},
	}); err != nil {
		return fmt.Errorf("failed to advertise route: %w", err)
	}
	errChan := make(chan error, 2)
	// Pump loop: tunnel -> local TUN device.
	go func() {
		for {
			select {
			case <-snet.ctx.Done():
				snet.cleanupConn(addr.String(), conn)
				return
			default:
				b := make([]byte, 2000)
				n, err := conn.ReadPacket(b)
				if err != nil {
					errChan <- fmt.Errorf("failed to read from connection: %w", err)
					snet.cleanupConn(addr.String(), conn)
					return
				}
				log.Debugf("read %d bytes from connection", n)
				// 1. retrieve dest ip
				// NOTE(review): bytes 16-19 are the destination field of a
				// 20-byte IPv4 header only; assumes IPv4 with no options —
				// TODO confirm.
				destIP := net.IPv4(b[16], b[17], b[18], b[19]).String()
				_, ok := snet.info.rtable.GetByIP(destIP)
				if !ok {
					log.Debugf("unrecognized destination ip %s, no peerID found for ip, not a subnet member", destIP)
					continue
				}
				// 2. fetch the respective tun dev
				snet.mx.Lock()
				if iface, ok := snet.ifaces[destIP]; ok {
					log.Debugf("writing packet to tun device %s", iface.tun.Name())
					// 3. write to tun dev
					if _, err := iface.tun.Write(b[:n]); err != nil {
						log.Errorf("failed to write to tun device: %v (subnet=%s, destIP=%s)", err, snet.info.id, destIP)
					}
				} else {
					log.Debugf("unrecognized destination ip %s, no tun device found for ip", destIP)
				}
				snet.mx.Unlock()
			}
		}
	}()
	// Watcher: logs the first pump error, or cleans up on subnet shutdown.
	go func() {
		select {
		case err := <-errChan:
			log.Errorf("failed to handle connection: %v", err)
		case <-snet.ctx.Done():
			snet.cleanupConn(addr.String(), conn)
		}
	}()
	return nil
}
// stopIPProxy cancels the proxy accept loop, closes the HTTP/3 server and all
// tracked raw QUIC connections, then resets the proxy state.
//
// The connection map is now iterated and reset under ipproxyConnsMx: the
// accept loop mutates it concurrently, so the previous unlocked iteration was
// a data race.
func (l *Libp2p) stopIPProxy() {
	log.Infow("stopping ip proxy")
	l.ipproxyCtxCancel()
	l.ipproxy.Close()
	l.ipproxyConnsMx.Lock()
	for _, conn := range l.ipproxyConns {
		if err := conn.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "shutting down"); err != nil {
			log.Errorf("failed to close ipproxy connection: %v", err)
		}
	}
	l.ipproxyConns = make(map[string]*quic.Conn)
	l.ipproxyConnsMx.Unlock()
	l.ipproxy = nil
}
// handleDNSQueries answers a DNS query read from the TUN device using the
// subnet's static records and writes a synthesized UDP/IPv4 response back to
// the same interface with source and destination swapped.
func (s *subnet) handleDNSQueries(iface sys.NetInterface, packet []byte, packetlen int) error {
	// NOTE(review): offset 28 assumes a 20-byte IPv4 header (no options) plus
	// an 8-byte UDP header — confirm callers only pass such packets.
	s.dnsmx.RLock()
	payload, err := handleDNSQuery(packet[28:packetlen], s.dnsRecords)
	s.dnsmx.RUnlock()
	if err != nil {
		return err
	}
	srcPort, destPort, srcIP, destIP, err := s.parseIPPacket(packet)
	if err != nil {
		return err
	}
	// Response IP header: src/dst swapped relative to the query.
	ipLayer := &layers.IPv4{
		Version:  4,
		TTL:      64,
		SrcIP:    net.ParseIP(destIP),
		DstIP:    net.ParseIP(srcIP),
		Protocol: layers.IPProtocolUDP,
	}
	udpLayer := &layers.UDP{
		SrcPort: layers.UDPPort(destPort),
		DstPort: layers.UDPPort(srcPort),
	}
	// Set the UDP checksum
	err = udpLayer.SetNetworkLayerForChecksum(ipLayer)
	if err != nil {
		return err
	}
	// Create the packet
	buffer := gopacket.NewSerializeBuffer()
	opts := gopacket.SerializeOptions{
		ComputeChecksums: true,
		FixLengths:       true,
	}
	err = gopacket.SerializeLayers(buffer, opts,
		ipLayer,
		udpLayer,
		gopacket.Payload(payload),
	)
	if err != nil {
		return err
	}
	// Best-effort write; a lost response is retried by the client resolver.
	_, _ = iface.Write(buffer.Bytes())
	return nil
}
// Route delivers a packet read from iface: if the destination IP has a local
// TUN device the packet is written there directly, otherwise it is tunneled to
// the owning remote peer via the connect-ip proxy. Packets whose source or
// destination is not a subnet member are dropped.
func (s *subnet) Route(iface sys.NetInterface, srcIP, destIP string, packet []byte, plen int) {
	log.Debugf("routing packet (subnet=%s, dstIP=%s, packet_len=%d)", s.info.id, destIP, plen)
	s.mx.Lock()
	// Drop packets for destinations that are not subnet members.
	peerID, ok := s.info.rtable.GetByIP(destIP)
	if !ok {
		log.Debugf("unrecognized destination ip on subnetID: %s, dstIP: %s, not a subnet member", s.info.id, destIP)
		s.mx.Unlock()
		return
	}
	// Drop packets from sources that are not subnet members either.
	_, ok = s.info.rtable.GetByIP(srcIP)
	if !ok {
		log.Debugf("unrecognized source ip on subnetID: %s, srcIP: %s, not a subnet member", s.info.id, srcIP)
		s.mx.Unlock()
		return
	}
	// Local delivery: the destination has a TUN device on this host.
	if _, ok := s.ifaces[destIP]; ok {
		log.Debugf("found destination ip in tuns table (local) on subnetID: %s, dstIP: %s, writing packet of length %d", s.info.id, destIP, plen)
		// if so, write to the tun
		n, err := s.ifaces[destIP].tun.Write(packet[:plen])
		if err != nil {
			log.Errorf("failed to write to tun device: %v (subnet=%s, destIP=%s, bytes_written=%d)", err, s.info.id, destIP, n)
		} else {
			log.Debugf("successfully wrote %d bytes to tun device (subnet=%s, destIP=%s)", n, s.info.id, destIP)
		}
		s.mx.Unlock()
		return
	}
	s.mx.Unlock()
	// Remote delivery: tunnel the packet to the peer that owns destIP.
	log.Debugf(
		"found destination ip in routing table (subnet=%s, destIP=%s, peerID=%s)",
		s.info.id, destIP, peerID.String(),
	)
	if err := s.proxyPacket(s.ctx, iface, peerID, srcIP, destIP, packet, plen); err != nil {
		log.Errorf("failed to proxy packet: %v", err)
	}
}
// proxyPacket forwards a packet destined for a remote subnet member over a
// connect-ip tunnel, dialing a new tunnel on demand. For a freshly dialed
// tunnel it also starts a background loop copying return traffic from the
// tunnel into the local TUN device for srcIP. Any ICMP bytes produced by the
// write (e.g. fragmentation needed) are written back to iface.
func (s *subnet) proxyPacket(
	ctx context.Context,
	iface sys.NetInterface,
	dst peer.ID,
	srcIP,
	destIP string,
	packet []byte,
	plen int,
) error {
	s.proxiedConns.mx.Lock()
	conn, ok := s.proxiedConns.conns[destIP]
	if !ok {
		log.Debugf("no connection for %s, establishing one", destIP)
		// No connection: establish one synchronously
		// NOTE(review): the dial happens while proxiedConns.mx is held,
		// blocking other senders — confirm this serialization is intended.
		newConn, quicConn, err := s.dialIPProxy(ctx, dst, srcIP)
		if err != nil {
			s.proxiedConns.mx.Unlock()
			return fmt.Errorf("failed to establish connection to %s: %w", destIP, err)
		}
		if old, ok := s.proxiedConns.conns[destIP]; ok {
			if old != newConn { // Only close if it's a different connection
				log.Debugf("closing old connection for %s", destIP)
				old.Close()
			}
		}
		s.proxiedConns.conns[destIP] = newConn
		s.proxiedConns.mx.Unlock()
		conn = newConn
		// Start a goroutine to read from the new connection and write to the TUN device
		go func(ipconn *connectip.Conn, destIP, srcIP string, quicConn *quic.Conn) {
			for {
				select {
				case <-ctx.Done():
					log.Debugf("context done, abandoning read loop... (subnet=%s)", s.info.id)
					return
				case <-s.ctx.Done():
					log.Debugf("context done, abandoning read loop... (subnet=%s)", s.info.id)
					return
				case <-quicConn.Context().Done():
					log.Debugf("quic connection for %s closed, closing connection", destIP)
					s.cleanupConn(destIP, ipconn)
					return
				default:
					b := make([]byte, 2000)
					n, err := ipconn.ReadPacket(b)
					if err != nil {
						log.Errorf("failed to read from outgoing connection: %v (subnet=%s, dst=%s)", err, s.info.id, dst.String())
						s.cleanupConn(destIP, ipconn)
						return
					}
					// Write to the appropriate TUN device
					s.mx.Lock()
					iface, ok := s.ifaces[srcIP]
					s.mx.Unlock()
					if ok {
						if _, err := iface.tun.Write(b[:n]); err != nil {
							log.Errorf("failed to write to TUN from outgoing connection: %v (subnet=%s, dst=%s)", err, s.info.id, dst.String())
						}
					} else {
						log.Debugf("no TUN device for destIP %s (subnet=%s)", destIP, s.info.id)
					}
				}
			}
		}(newConn, destIP, srcIP, quicConn)
	} else {
		log.Debugf("found connection for %s, writing packet", destIP)
		s.proxiedConns.mx.Unlock()
	}
	// Now write to the connection
	icmp, err := conn.WritePacket(packet[:plen])
	if err != nil {
		log.Errorf("failed to write packet to connection: %s (subnet=%s, dst=%s)", err, s.info.id, dst.String())
		s.cleanupConn(destIP, conn)
		return fmt.Errorf("failed to write packet to connection: %w", err)
	}
	if len(icmp) > 0 {
		_, err := iface.Write(icmp)
		if err != nil {
			log.Errorf("failed to write ICMP packet to tun device: %s (subnet=%s, dst=%s)", err, s.info.id, dst.String())
			return fmt.Errorf("failed to write ICMP packet to tun device: %w", err)
		}
	}
	return nil
}
// readPackets continuously reads packets from the TUN device and dispatches
// them: DNS queries (destination port 53) are answered locally, everything
// else is routed to the destination subnet member. The loop exits when either
// context is cancelled or the device is closed.
//
// Fixes over the previous version: the zero-length warning used log.Warnw
// (key/value API) with a printf format string, and the DNS Debugf formatted
// the int ports with %s (printing %!s(int=...)).
func (s *subnet) readPackets(ctx context.Context, iface sys.NetInterface) {
	for {
		select {
		case <-ctx.Done():
			log.Debugf("context done, abandoning read loop... (subnet=%s)", s.info.id)
			return
		case <-s.ctx.Done():
			log.Debugf("context done, abandoning read loop... (subnet=%s)", s.info.id)
			return
		default:
			packet := make([]byte, MaxPacketSize)
			// Read in a packet from the tun device.
			plen, err := iface.Read(packet)
			if errors.Is(err, fs.ErrClosed) {
				time.Sleep(1 * time.Second)
				log.Debugf("tun device closed, abandoning read loop... (err=%s, subnet=%s)", err, s.info.id)
				return
			} else if err != nil {
				log.Debugf("(error): failed to read packet from tun device: %s (subnet=%s)", err, s.info.id)
				continue
			}
			if plen == 0 {
				log.Warnf("(error): received zero-length packet from tun device (subnet=%s, iface=%s)", s.info.id, iface.Name())
				continue
			}
			if plen > MaxPacketSize {
				// Defensive only: Read cannot return more than len(packet).
				log.Debugf("received packet with length %d, truncating to %d", plen, MaxPacketSize)
				plen = MaxPacketSize
			}
			srcPort, destPort, srcIP, destIP, err := s.parseIPPacket(packet)
			if err != nil {
				log.Debugf("(error): failed to parse IP packet: %s", err)
				continue
			}
			log.Debugw(
				"read packet from tun device",
				"labels", string(observability.LabelNode),
				"tun", iface.Name(),
				"subnet", s.info.id,
				"destIP", destIP,
				"destPort", destPort,
				"srcIP", srcIP,
				"srcPort", srcPort,
			)
			// Only DNS queries (dest port 53) are handled locally; everything
			// else is routed.
			if destPort == 53 {
				log.Debugf(
					"handling DNS query (tun=%s, subnet=%s, destIP=%s, destPort=%d, srcIP=%s, srcPort=%d)",
					iface.Name(),
					s.info.id,
					destIP,
					destPort,
					srcIP,
					srcPort,
				)
				if err := s.handleDNSQueries(iface, packet, plen); err != nil {
					log.Errorf("failed to handle DNS query: %s", err)
				}
			} else {
				s.Route(iface, srcIP, destIP, packet, plen)
			}
		}
	}
}
// parseIPPacket extracts source/destination IPs and, when the payload is UDP
// or TCP, source/destination ports from a raw IPv4 packet. Ports remain zero
// for other protocols.
func (s *subnet) parseIPPacket(rawPacket []byte) (srcPort int, destPort int, srcIP string, destIP string, err error) {
	pkt := gopacket.NewPacket(rawPacket, layers.LayerTypeIPv4, gopacket.Default)
	if errLayer := pkt.ErrorLayer(); errLayer != nil {
		return 0, 0, "", "", fmt.Errorf("failed to decode packet: %s", errLayer)
	}
	if layer := pkt.Layer(layers.LayerTypeIPv4); layer != nil {
		if ip, ok := layer.(*layers.IPv4); ok {
			srcIP, destIP = ip.SrcIP.String(), ip.DstIP.String()
		}
	}
	if layer := pkt.Layer(layers.LayerTypeUDP); layer != nil {
		if udp, ok := layer.(*layers.UDP); ok {
			srcPort, destPort = int(udp.SrcPort), int(udp.DstPort)
		}
	}
	if layer := pkt.Layer(layers.LayerTypeTCP); layer != nil {
		if tcp, ok := layer.(*layers.TCP); ok {
			srcPort, destPort = int(tcp.SrcPort), int(tcp.DstPort)
		}
	}
	return srcPort, destPort, srcIP, destIP, nil
}
// dialIPProxy dials target's IP proxy over a raw QUIC connection and opens a
// connect-ip tunnel on it, returning both the tunnel and the underlying QUIC
// connection.
//
// On a non-OK HTTP status the freshly opened tunnel is now closed (it was
// previously leaked), and the unread response body — an io.ReadCloser that
// was formatted with %s — is no longer logged.
func (s *subnet) dialIPProxy(
	ctx context.Context,
	target peer.ID,
	srcIP string,
) (*connectip.Conn, *quic.Conn, error) {
	conn, proxyAddr, err := s.network.RawQUICConnect(target, s.info.id)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to dial raw QUIC connection: %w", err)
	}
	tr := &http3.Transport{EnableDatagrams: true}
	hconn := tr.NewClientConn(conn)
	template := uritemplate.MustNew(fmt.Sprintf(
		"https://%s:%d/vpn?subnetID=%s&srcIP=%s",
		proxyAddr.Addr().Unmap().String(),
		proxyAddr.Port(),
		s.info.id,
		srcIP,
	))
	ipconn, rsp, err := connectip.Dial(ctx, hconn, template)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to dial connect-ip connection: %w", err)
	}
	if rsp.StatusCode != http.StatusOK {
		_ = ipconn.Close()
		log.Errorf("unexpected status code: %d", rsp.StatusCode)
		return nil, nil, fmt.Errorf("unexpected status code: %d", rsp.StatusCode)
	}
	log.Debugf("connected to IP Proxy for target %s on %s", target, proxyAddr)
	return ipconn, conn, nil
}
// PeersAddresses returns the set of "ip:port" endpoints for every peer in
// the subnet routing table, resolving addresses via the network when the
// peerstore has none cached. The returned map is used as a set (all values
// are true).
func (s *subnet) PeersAddresses() map[string]bool {
	addresses := make(map[string]bool)
	s.mx.Lock()
	rtable := s.info.rtable
	s.mx.Unlock()
	for peerID := range rtable.All() {
		addrs := s.network.Host.Peerstore().Addrs(peerID)
		if len(addrs) == 0 {
			pinfo, err := s.network.resolvePeerAddress(context.TODO(), peerID)
			if err != nil {
				log.Errorf("failed to resolve peer address: %s", err)
				continue
			}
			addrs = pinfo.Addrs
		}
		for _, addr := range addrs {
			// A dialable multiaddr looks like /ip4/<ip>/<transport>/<port>/...
			// Fix: guard the split length — the original indexed parts[2] and
			// parts[4] unconditionally and panicked on shorter multiaddrs
			// (e.g. /unix or relay addresses).
			parts := strings.Split(addr.String(), "/")
			if len(parts) < 5 {
				log.Debugf("skipping multiaddr with unexpected shape: %s", addr)
				continue
			}
			ip := parts[2]
			port := parts[4]
			addresses[fmt.Sprintf("%s:%s", ip, port)] = true
		}
	}
	return addresses
}
// stringSliceContains reports whether word occurs in slice.
func stringSliceContains(slice []string, word string) bool {
	for i := range slice {
		if slice[i] == word {
			return true
		}
	}
	return false
}
// generateUniqueName picks a random name of the form "dmsN" that is not in
// takenList. The candidate pool starts at [0, 30) and widens by 30 after 30
// failed attempts; after more than 100 failed attempts it gives up.
func generateUniqueName(takenList []string) (string, error) {
	pool := 30
	for attempts := 1; ; attempts++ {
		candidate := fmt.Sprintf("dms%d", rand.Intn(pool)) //nolint:gosec
		if !stringSliceContains(takenList, candidate) {
			return candidate, nil
		}
		if attempts > 30 {
			pool += 30
		}
		if attempts > 100 {
			return "", fmt.Errorf("failed to generate unique name")
		}
	}
}
func LastIP(prefix netip.Prefix) netip.Addr {
addr := prefix.Addr()
bytes := addr.AsSlice()
hostBits := len(bytes)*8 - prefix.Bits()
for i := len(bytes) - 1; i >= 0; i-- {
setBits := math.Min(8, float64(hostBits))
if setBits <= 0 {
break
}
bytes[i] |= byte(0xff >> (8 - int(setBits)))
hostBits -= 8
}
if addr.Is4() {
return netip.AddrFrom4([4]byte(bytes[:4]))
}
return netip.AddrFrom16([16]byte(bytes))
}
// IsIPv4 reports whether ip parses as an IPv4 address. Unparseable input
// returns false (ParseIP yields nil, and a nil IP's To4 is nil).
func IsIPv4(ip string) bool {
	parsed := net.ParseIP(ip)
	return parsed.To4() != nil
}
// cleanupConn removes and closes the proxied connection registered for ip,
// but only if it is still the same connection the caller observed; a newer
// connection that has since replaced it is left untouched.
func (s *subnet) cleanupConn(ip string, conn *connectip.Conn) {
	s.proxiedConns.mx.Lock()
	defer s.proxiedConns.mx.Unlock()
	if current, ok := s.proxiedConns.conns[ip]; ok && current == conn {
		log.Debugf("cleanupConn: closing and removing connection for %s", ip)
		delete(s.proxiedConns.conns, ip)
		conn.Close()
	}
}
// NetTunFactory creates a TUN (layer-3) interface with the given name.
// The trailing false is the persistence flag passed to sys.NewTunTapInterface.
func NetTunFactory(name string) (sys.NetInterface, error) {
	return sys.NewTunTapInterface(name, sys.NetTunMode, false)
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
//go:build linux
// +build linux
package libp2p
import (
"fmt"
"net"
"gitlab.com/nunet/device-management-service/utils/sys"
)
// MapPort maps sourceIP:sourcePort to destIP:destPort inside the given
// subnet by installing the required forward/NAT rules, and records the
// mapping so UnmapPort can remove it later.
//
// Fixes: the original released the port-mapping mutex between the
// "already mapped" check and the store, so two concurrent calls could both
// pass the check and install duplicate rules — the mutex is now held for
// the whole operation. The localhost warning also had a format/argument
// mismatch (2 verbs, 3 args), which go vet flags.
func (l *Libp2p) MapPort(subnetID, protocol, sourceIP, sourcePort, destIP, destPort string) error {
	s, ok := l.subnets[subnetID]
	if !ok {
		return fmt.Errorf("subnet with ID %s does not exist", subnetID)
	}
	s.portMappingMx.Lock()
	defer s.portMappingMx.Unlock()
	if _, ok := s.portMapping[sourcePort]; ok {
		return fmt.Errorf("port %s is already mapped", sourcePort)
	}
	// TODO: check if any rules for the port already exists
	err := sys.AddForwardRule(protocol, destIP, destPort)
	if err != nil {
		return err
	}
	loIface, err := sys.GetNetInterfaceByFlags(net.FlagLoopback)
	if err != nil {
		// Non-fatal: the mapping still works from remote hosts, just not via
		// localhost on this machine.
		log.Errorf("failed to get loopback interface: %v", err)
		log.Warnf("port %s will not be mapped to localhost %s:%s", sourcePort, destIP, destPort)
	} else {
		err = sys.AddOutputNatRule(protocol, destIP, destPort, loIface.Name)
		if err != nil {
			return err
		}
	}
	// Record the mapping; field order must match the map's value type.
	s.portMapping[sourcePort] = &struct {
		destPort string
		destIP   string
		srcIP    string
	}{
		destPort: destPort,
		destIP:   destIP,
		srcIP:    sourceIP,
	}
	return nil
}
// UnmapPort removes a port mapping previously installed by MapPort: it
// validates that the recorded mapping matches the given parameters, tears
// down the DNAT/forward/output-NAT/masquerade rules, and finally deletes
// the stored mapping.
func (l *Libp2p) UnmapPort(subnetID, protocol, sourceIP, sourcePort, destIP, destPort string) error {
	s, ok := l.subnets[subnetID]
	if !ok {
		return fmt.Errorf("subnet with ID %s does not exist", subnetID)
	}
	// Get and validate mapping (with lock)
	s.portMappingMx.Lock()
	mapping, ok := s.portMapping[sourcePort]
	if !ok {
		s.portMappingMx.Unlock()
		return fmt.Errorf("port %s is not mapped", sourcePort)
	}
	if mapping.destIP != destIP || mapping.destPort != destPort || mapping.srcIP != sourceIP {
		s.portMappingMx.Unlock()
		return fmt.Errorf("port %s is not mapped to %s:%s", sourcePort, destIP, destPort)
	}
	s.portMappingMx.Unlock()
	// NOTE(review): the lock is released here, so a concurrent UnmapPort for
	// the same port can pass validation twice; and if any deletion below
	// fails, the mapping stays recorded while some rules are already gone —
	// confirm this partial-failure behavior is intended.
	err := sys.DelDNATRule(protocol, sourcePort, destIP, destPort)
	if err != nil {
		return err
	}
	err = sys.DelForwardRule(protocol, destIP, destPort)
	if err != nil {
		return err
	}
	loIface, err := sys.GetNetInterfaceByFlags(net.FlagLoopback)
	if err != nil {
		// Best-effort: without the loopback interface we cannot remove the
		// localhost NAT rule, but the rest of the teardown proceeds.
		log.Errorf("failed to get loopback interface: %v", err)
		log.Warnf("Unable to delete localhost OutputNat rule for %s:%s", destIP, destPort)
	} else {
		err = sys.DelOutputNatRule(protocol, destIP, destPort, loIface.Name)
		if err != nil {
			return err
		}
	}
	err = sys.DelMasqueradeRule()
	if err != nil {
		return err
	}
	// Delete mapping (with lock)
	s.portMappingMx.Lock()
	delete(s.portMapping, sourcePort)
	s.portMappingMx.Unlock()
	log.Infof("port %s unmapped successfully", sourcePort)
	return nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package libp2p
import (
"fmt"
"testing"
"time"
crypto "github.com/libp2p/go-libp2p/core/crypto"
multiaddr "github.com/multiformats/go-multiaddr"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
backgroundtasks "gitlab.com/nunet/device-management-service/internal/background_tasks"
"gitlab.com/nunet/device-management-service/internal/config"
"gitlab.com/nunet/device-management-service/types"
)
// createPeer builds, initializes and returns a test libp2p peer listening on
// the given TCP and QUIC ports. An optional interface factory can be
// injected, and a cleanup is registered to stop the peer when the test ends.
func createPeer(t *testing.T, port, quicPort int, bootstrapPeers []multiaddr.Multiaddr, factory ...NetInterfaceFactory) *Libp2p { //nolint
	cfg := setupPeerConfig(t, port, quicPort, bootstrapPeers)
	if len(factory) > 0 && factory[0] != nil {
		cfg.NetIfaceFactory = func(name string) (interface{}, error) {
			return factory[0](name)
		}
	}
	p, err := New(cfg, afero.NewMemMapFs())
	require.NoError(t, err)
	require.NoError(t, p.Init(&config.Config{}))
	// Stop the peer after the test and give the OS a moment to release
	// sockets and interfaces.
	t.Cleanup(func() {
		if p == nil {
			return
		}
		if stopErr := p.Stop(); stopErr != nil {
			t.Logf("Error stopping peer: %v", stopErr)
		}
		time.Sleep(100 * time.Millisecond)
	})
	return p
}
// setupPeerConfig returns a test Libp2pConfig with a fresh Ed25519 identity
// listening on the given TCP and QUIC (UDP) ports.
func setupPeerConfig(t *testing.T, libp2pPort, quicPort int, bootstrapPeers []multiaddr.Multiaddr) *types.Libp2pConfig {
	priv, _, err := crypto.GenerateKeyPair(crypto.Ed25519, 256)
	assert.NoError(t, err)
	// Listen on both TCP and QUIC-v1; giving each test peer its own ports
	// avoids QUIC collisions between peers created in the same test.
	// (An earlier comment here claimed "TCP addresses only", but a QUIC
	// address is listed below as well.)
	listenAddresses := []string{
		fmt.Sprintf("/ip4/0.0.0.0/tcp/%d", libp2pPort),
		fmt.Sprintf("/ip4/0.0.0.0/udp/%d/quic-v1", quicPort),
	}
	return &types.Libp2pConfig{
		Env:                     "test",
		PrivateKey:              priv,
		BootstrapPeers:          bootstrapPeers,
		Rendezvous:              "nunet-randevouz",
		Server:                  false,
		Scheduler:               backgroundtasks.NewScheduler(10, time.Second),
		DHTPrefix:               "/nunet",
		CustomNamespace:         "/nunet-dht-1/",
		ListenAddress:           listenAddresses,
		PeerCountDiscoveryLimit: 40,
	}
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package network
import (
"context"
"errors"
"fmt"
"net"
"time"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/spf13/afero"
"gitlab.com/nunet/device-management-service/internal/config"
"gitlab.com/nunet/device-management-service/lib/crypto"
"gitlab.com/nunet/device-management-service/network/libp2p"
commonproto "gitlab.com/nunet/device-management-service/proto/generated/v1/common"
"gitlab.com/nunet/device-management-service/types"
)
// Aliases re-exported from the libp2p backend so consumers of the network
// package do not need to import it directly.
type (
	PeerID            = libp2p.PeerID
	ProtocolID        = libp2p.ProtocolID
	Topic             = libp2p.Topic
	Validator         = libp2p.Validator
	ValidationResult  = libp2p.ValidationResult
	PeerScoreSnapshot = libp2p.PeerScoreSnapshot
)
// Pubsub message-validation outcomes, re-exported from the libp2p backend.
const (
	ValidationAccept = libp2p.ValidationAccept
	ValidationReject = libp2p.ValidationReject
	ValidationIgnore = libp2p.ValidationIgnore
)
// Messenger defines the interface for sending messages to a peer identified
// by its host ID.
type Messenger interface {
	// SendMessage asynchronously sends a message to the given peer.
	SendMessage(ctx context.Context, hostID string, msg types.MessageEnvelope, expiry time.Time) error

	// SendMessageSync synchronously sends a message to the given peer.
	// This method blocks until the message has been sent.
	SendMessageSync(ctx context.Context, hostID string, msg types.MessageEnvelope, expiry time.Time) error
}
// Network is the full abstraction over a peer-to-peer transport: messaging,
// discovery (DHT advertisements), pubsub, subnets, port mapping and DNS.
type Network interface {
	// Messenger embedded interface
	Messenger

	// Init initializes the network
	Init(*config.Config) error

	// Start starts the network
	Start() error

	// Stat returns the network information
	Stat() types.NetworkStats

	// Peers returns a list of peers from the peer store
	Peers() []peer.ID

	// Ping pings the given address and returns the PingResult
	Ping(ctx context.Context, address string, timeout time.Duration) (types.PingResult, error)

	// GetHostID returns the host ID
	GetHostID() PeerID

	// GetPeerPubKey returns the public key for the given peerID
	GetPeerPubKey(peerID PeerID) crypto.PubKey

	// HandleMessage is responsible for registering a message type and its handler.
	HandleMessage(messageType string, handler func(data []byte, peerId peer.ID)) error

	// UnregisterMessageHandler unregisters a stream handler for a specific protocol.
	UnregisterMessageHandler(messageType string)

	// ResolveAddress given an id it returns the address of the peer.
	// In libp2p, id represents the peerID and the response is the addrinfo
	ResolveAddress(ctx context.Context, id string) ([]string, error)

	// Advertise advertises the given data with the given adId
	// such as advertising device capabilities on the DHT
	Advertise(ctx context.Context, key string, data []byte) error

	// Unadvertise stops advertising data corresponding to the given adId
	Unadvertise(ctx context.Context, key string) error

	// Query returns the network advertisement
	Query(ctx context.Context, key string) ([]*commonproto.Advertisement, error)

	// Publish publishes the given data to the given topic if the network
	// type allows publish/subscribe functionality such as gossipsub or nats
	Publish(ctx context.Context, topic string, data []byte) error

	// Subscribe subscribes to the given topic and calls the handler function
	// if the network type allows it similar to Publish()
	Subscribe(ctx context.Context, topic string, handler func(data []byte), validator libp2p.Validator) (uint64, error)

	// Unsubscribe from a topic
	Unsubscribe(topic string, subID uint64) error

	// SetupBroadcastTopic allows the application to configure pubsub topic directly
	SetupBroadcastTopic(topic string, setup func(*Topic) error) error

	// SetBroadcastAppScore allows the application to configure application level
	// scoring for pubsub
	SetBroadcastAppScore(func(PeerID) float64)

	// GetBroadcastScore returns the latest broadcast score snapshot
	GetBroadcastScore() map[PeerID]*PeerScoreSnapshot

	// Notify allows the application to receive notifications about peer connections
	// and disconnections
	Notify(ctx context.Context, preconnected func(PeerID, []ProtocolID, int), connected, disconnected func(PeerID), identified, updated func(PeerID, []ProtocolID)) error

	// Connect connects to the given peer address
	Connect(ctx context.Context, addr string) error

	// PeerConnected returns true if the peer is currently connected
	PeerConnected(p PeerID) bool

	// Stop stops the network including any existing advertisements and subscriptions
	Stop() error

	// GetPeerIP returns the ipv4 or v6 of a peer
	GetPeerIP(p PeerID) string

	// HostPublicIP returns the public IP of the host
	HostPublicIP() (net.IP, error)

	// CreateSubnet creates a subnet with the given subnetID and CIDR
	CreateSubnet(ctx context.Context, subnetID, cidr string, routingTable map[string]string) error

	// DestroySubnet removes a subnet
	DestroySubnet(subnetID string) error

	// AddSubnetPeer adds a peer to the subnet
	AddSubnetPeer(subnetID, peerID, ip string) error

	// RemoveSubnetPeers removes peers from the subnet
	//
	// partialRoutingTable: ip -> peerID
	RemoveSubnetPeers(subnetID string, partialRoutingTable map[string]string) error

	// AcceptSubnetPeers accepts peers to the subnet
	//
	// partialRoutingTable: ip -> peerID
	AcceptSubnetPeers(subnetID string, partialRoutingTable map[string]string) error

	// MapPort maps a sourceIp:sourcePort to destIP:destPort
	MapPort(subnetID, protocol, sourceIP, sourcePort, destIP, destPort string) error

	// UnmapPort removes a previous port map
	UnmapPort(subnetID, protocol, sourceIP, sourcePort, destIP, destPort string) error

	// AddSubnetDNSRecords adds dns records to our local resolver
	AddSubnetDNSRecords(subnetID string, records map[string]string) error

	// RemoveSubnetDNSRecord removes a dns record from our local resolver
	RemoveSubnetDNSRecord(subnetID, name string) error
}
// NewNetwork returns a new Network implementation selected by the
// configuration's Type field.
func NewNetwork(netConfig *types.NetworkConfig, fs afero.Fs) (Network, error) {
	// TODO: probable additional params to receive: DB, FileSystem
	if netConfig == nil {
		return nil, errors.New("network configuration is nil")
	}
	switch netConfig.Type {
	case types.Libp2pNetwork:
		return libp2p.New(&netConfig.Libp2pConfig, fs)
	case types.VirtualNetwork:
		// in-memory network, for tests only
		return NewMemoryNetHost()
	case types.NATSNetwork:
		return nil, errors.New("not implemented")
	default:
		return nil, fmt.Errorf("unsupported network type: %s", netConfig.Type)
	}
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package utils
import (
"fmt"
"math/rand/v2"
"net"
"strconv"
"strings"
)
// GetNextIP returns the first host IP in the IPv4 CIDR range that is not
// present in usedIPs.
//
// Fix: the first-host arithmetic and nextIP assume 4-byte addresses; an
// IPv6 CIDR used to be silently mis-indexed, so it is now rejected.
func GetNextIP(cidr string, usedIPs map[string]bool) (net.IP, error) {
	cidrParts := strings.Split(cidr, "/")
	if len(cidrParts) != 2 {
		return nil, fmt.Errorf("invalid CIDR %s", cidr)
	}
	mask, err := strconv.Atoi(cidrParts[1])
	if err != nil {
		return nil, err
	}
	_, ipnet, err := net.ParseCIDR(cidr)
	if err != nil {
		return nil, err
	}
	if ipnet.IP.To4() == nil {
		return nil, fmt.Errorf("only IPv4 CIDRs are supported: %s", cidr)
	}
	networkMask := ipnet.IP.Mask(ipnet.Mask).To4()
	// First candidate is the network address + 1 (skip the network address).
	firstHostIP := net.IP{networkMask[0], networkMask[1], networkMask[2], networkMask[3] + byte(1)}
	for ip := firstHostIP; ipnet.Contains(ip); ip = nextIP(ip, mask) {
		if ip == nil {
			break
		}
		if !usedIPs[ip.String()] {
			return ip, nil
		}
	}
	return nil, fmt.Errorf("no available IPs in CIDR %s", cidr)
}
// nextIP returns the next available IP in the network
func nextIP(ip net.IP, netmask int) net.IP {
if ip4 := ip.To4(); ip4 != nil {
if netmask == 0 && ip4[0] == 255 && ip4[1] == 255 && ip4[2] == 255 && ip4[3] == 254 {
return nil // no more IPs for this network
}
if netmask == 0 && ip4[1] == 255 && ip4[2] == 255 && ip4[3] == 254 {
ip4[0]++
ip4[1] = 0
ip4[2] = 0
ip4[3] = 0
}
if netmask == 8 && ip[1] == 255 && ip4[2] == 255 && ip4[3] == 254 {
return nil // no more IPs for this network
}
if netmask == 8 && ip[1] < 255 && ip4[2] == 255 && ip4[3] == 254 {
ip4[1]++
ip4[2] = 0
ip4[3] = 0
}
if netmask == 16 && ip4[2] == 255 && ip4[3] == 254 {
return nil // no more IPs for this network
}
if (netmask == 16 || netmask == 8) && ip[2] < 255 && ip4[3] == 254 {
ip4[2]++
ip4[3] = 0
}
if netmask == 24 && ip4[3] == 254 {
return nil // no more IPs for this network
}
ip4[3]++
return ip4
}
return nil
}
// GetRandomCIDR returns a random CIDR with the given mask
// and not in the blacklist.
// If the blacklist is empty, it will return a random CIDR with the given mask.
// This function supports mask 0, 8, 16, 24.
// If you need more elaborate masks to get more subnets (i.e: 0<mask<32)
// refactor this to use bitwise operations on the IP.
//
// Fixes: the attempt bound used 2^mask, which is XOR in Go (e.g. 26 for
// mask 24) rather than the intended power of two; and a CIDR parse failure
// returned ("", nil), silently swallowing the error.
func GetRandomCIDR(mask int, blacklist []string) (string, error) {
	var cidr string
	var breakCounter int
	maxAttempts := 255
	if mask > 0 {
		maxAttempts = 1 << uint(mask)
	}
	for {
		if breakCounter > maxAttempts {
			return fmt.Sprintf("%s/%d", "0.0.0.0", mask), fmt.Errorf("could not find a CIDR after %d attempts", breakCounter)
		}
		ip := net.IP{byte(rand.IntN(256)), byte(rand.IntN(256)), byte(rand.IntN(256)), byte(rand.IntN(256))}
		candidate := fmt.Sprintf("%s/%d", ip, mask)
		if !isOnBlacklist(candidate, blacklist) {
			cidr = candidate
			break
		}
		breakCounter++
	}
	_, ipnet, err := net.ParseCIDR(cidr)
	if err != nil {
		return "", err
	}
	// Return the normalized network form (host bits zeroed).
	return ipnet.String(), nil
}
// isOnBlacklist returns true if the given CIDR's network address falls
// inside any CIDR on the blacklist. Unparseable blacklist entries are
// skipped; an unparseable candidate yields false.
//
// Fix: the original returned false as soon as it hit one malformed
// blacklist entry, which silently skipped every remaining entry; it also
// re-parsed the candidate once per blacklist entry.
func isOnBlacklist(cidr string, blacklist []string) bool {
	_, subnet, err := net.ParseCIDR(cidr)
	if err != nil {
		return false // cannot evaluate an invalid candidate
	}
	for _, blacklistedCIDR := range blacklist {
		_, blacklistedSubnet, err := net.ParseCIDR(blacklistedCIDR)
		if err != nil {
			continue // ignore malformed blacklist entries, keep checking the rest
		}
		if blacklistedSubnet.Contains(subnet.IP) {
			return true
		}
	}
	return false
}
// GetRandomCIDRInRange returns a random CIDR with the given mask whose
// network octets are drawn (octet-wise) from [start, end] and which is not
// blacklisted.
//
// Fixes (same as GetRandomCIDR): 2^mask is XOR in Go, not a power of two;
// and a parse failure used to return ("", nil), swallowing the error.
func GetRandomCIDRInRange(mask int, start, end net.IP, blacklist []string) (string, error) {
	var cidr string
	var breakCounter int
	networkBitsIndex := mask / 8
	maxAttempts := 255
	if mask > 0 {
		maxAttempts = 1 << uint(mask)
	}
	for {
		if breakCounter > maxAttempts {
			return fmt.Sprintf("%s/%d", "0.0.0.0", mask), fmt.Errorf("could not find a CIDR after %d attempts", breakCounter)
		}
		ip := net.IP{byte(0), byte(0), byte(0), byte(0)}
		for i := 0; i < networkBitsIndex; i++ {
			ip[i] = byte(randRange(int(start.To4()[i]), int(end.To4()[i])))
		}
		candidate := fmt.Sprintf("%s/%d", ip, mask)
		if !isOnBlacklist(candidate, blacklist) {
			cidr = candidate
			break
		}
		breakCounter++
	}
	_, ipnet, err := net.ParseCIDR(cidr)
	if err != nil {
		return "", err
	}
	return ipnet.String(), nil
}
func randRange(minimum, maximum int) int {
return rand.IntN(maximum-minimum+1) + minimum
}
// IsFreePort checks if a given port is free to use by trying to listen on it.
func IsFreePort(port int) bool {
addr := net.JoinHostPort("", strconv.Itoa(port))
ln, err := net.Listen("tcp", addr)
if err != nil {
// error listening, probably in use
return false
}
_ = ln.Close()
return true
}
// GetRandomAvailablePort returns a random available port on the system,
// probing random candidates in [10000, 65535] up to 100 times. The range is
// broader than the standard ephemeral range (49152-65535) to increase the
// chance of finding a free port.
func GetRandomAvailablePort() (int, error) {
	const (
		minPort  = 10000
		maxPort  = 65535
		maxTries = 100
	)
	for try := 0; try < maxTries; try++ {
		if candidate := randRange(minPort, maxPort); IsFreePort(candidate) {
			return candidate, nil
		}
	}
	return 0, fmt.Errorf("could not find an available port after %d attempts", maxTries)
}
// GetMultipleAvailablePorts returns numPorts distinct random available
// ports on the system.
//
// Fix: the original could return the same port more than once, since a
// port found free is not actually reserved between probes; duplicates are
// now filtered out, with a bounded number of retries.
func GetMultipleAvailablePorts(numPorts int) ([]int, error) {
	ports := make([]int, 0, numPorts)
	seen := make(map[int]bool, numPorts)
	maxAttempts := numPorts*100 + 100
	for attempts := 0; len(ports) < numPorts; attempts++ {
		if attempts > maxAttempts {
			return nil, fmt.Errorf("could not find %d distinct available ports", numPorts)
		}
		port, err := GetRandomAvailablePort()
		if err != nil {
			return nil, err
		}
		if seen[port] {
			continue
		}
		seen[port] = true
		ports = append(ports, port)
	}
	return ports, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package network
import (
"context"
"errors"
"fmt"
"net"
"sync"
"time"
"github.com/libp2p/go-libp2p/core/peer"
"gitlab.com/nunet/device-management-service/internal/config"
"gitlab.com/nunet/device-management-service/lib/crypto"
common "gitlab.com/nunet/device-management-service/proto/generated/v1/common"
"gitlab.com/nunet/device-management-service/types"
)
// Substrate (shared “world”) where each host has its own view of
// the network. Actions are usually done as: hostAlice gets
// hostBob handler for message XXX, and call the handler.
//
// The only global state is the dht for simplicity purposes.
// And the globalPeers which is used only for connecting new hosts.
//
// Improvements:
// - Make use of real pub-priv key pair and real peerIDs
// - Some methods should be more realistic: HostPublicIP(),
// ResolvePeerAddress(), GetPeerIP()...
// - Notify() implementation may be necessary
type Substrate struct {
	mx          sync.RWMutex                 // guards dht and globalPeers
	dht         map[string]map[string][]byte // key -> peerID -> value
	globalPeers map[string]*MemoryHost       // used only for connecting new hosts
}
// MemoryHost — implements Network, delegates to Substrate
type MemoryHost struct {
	pid       peer.ID    // this host's identity
	substrate *Substrate // shared world this host belongs to

	// local state, guarded by mx
	mx          sync.RWMutex
	peers       map[string]*MemoryHost             // peerID string -> directly connected host
	msgHandlers map[string]func([]byte, peer.ID)   // message type -> handler
	subs        map[string]map[uint64]func([]byte) // topic -> subID -> callback
	score       map[string]*PeerScoreSnapshot      // peerID string -> broadcast score
	nextSubID   uint64                             // last issued subscription ID
}
var _ Network = (*MemoryHost)(nil)
// NewSubstrate creates an empty shared world with no peers and an empty DHT.
func NewSubstrate() *Substrate {
	s := &Substrate{}
	s.dht = make(map[string]map[string][]byte)
	s.globalPeers = make(map[string]*MemoryHost)
	return s
}
// AddWiredPeer adds and returns a host to the substrate connected
// to all existent peers (shorthand for AddPeer with forceConnection=true).
func (substrate *Substrate) AddWiredPeer(id peer.ID) Network {
	return substrate.AddPeer(id, true)
}
// AddPeer returns a Network implementation bound to this substrate.
// If forceConnection is true, the new peer is wired to every existing peer.
func (substrate *Substrate) AddPeer(id peer.ID, forceConnection bool) Network {
	newHost := &MemoryHost{
		pid:         id,
		substrate:   substrate,
		peers:       make(map[string]*MemoryHost),
		msgHandlers: make(map[string]func([]byte, peer.ID)),
		subs:        make(map[string]map[uint64]func([]byte)),
		score:       make(map[string]*PeerScoreSnapshot),
		nextSubID:   0,
	}
	substrate.mx.Lock()
	defer substrate.mx.Unlock()
	if forceConnection {
		// Wire the newcomer to every peer already present in the world.
		for existingID, existing := range substrate.globalPeers {
			if existing != nil && existingID != id.String() {
				connectPeers(newHost, existing)
			}
		}
	} else { //nolint
		// TODO: implement connecting to x random peers when forceConnection is false
	}
	substrate.globalPeers[id.String()] = newHost
	return newHost
}
// connectPeers establishes a bidirectional connection between two peers.
// Each side's lock is taken and released separately (never nested).
func connectPeers(bob, alice *MemoryHost) {
	register := func(host, other *MemoryHost) {
		host.mx.Lock()
		defer host.mx.Unlock()
		host.peers[other.pid.String()] = other
	}
	register(bob, alice)
	register(alice, bob)
}
// NewMemoryNetHost is a substrate wrapper that returns a host which is the
// only participant of a fresh network.
//
// Useful for tests that don't need connections between peers but
// only a single instance of Network.
func NewMemoryNetHost() (Network, error) {
	_, pubKey, err := crypto.GenerateKeyPair(crypto.Ed25519)
	if err != nil {
		return nil, err
	}
	pid, err := peer.IDFromPublicKey(pubKey)
	if err != nil {
		return nil, err
	}
	return NewSubstrate().AddPeer(pid, true), nil
}
// Messenger

// SendMessage delivers env synchronously to the handler registered for
// env.Type on the destination host, which may be this host itself.
//
// Fix: the self-delivery path read h.msgHandlers without holding the lock,
// racing with HandleMessage/UnregisterMessageHandler; the lookup is now
// done under RLock and the handler is invoked after unlocking (so a
// handler may safely send messages itself).
func (h *MemoryHost) SendMessage(
	_ context.Context,
	hostID string,
	env types.MessageEnvelope,
	_ time.Time,
) error {
	if hostID == "" {
		return errors.New("virtual: empty hostID")
	}
	if h.pid.String() == hostID {
		h.mx.RLock()
		handler, ok := h.msgHandlers[string(env.Type)]
		h.mx.RUnlock()
		if !ok {
			return fmt.Errorf("virtual: no handler for msgType: %q", string(env.Type))
		}
		handler(env.Data, h.pid)
		return nil
	}
	h.mx.RLock()
	targetHost, ok := h.peers[hostID]
	h.mx.RUnlock()
	if !ok {
		return fmt.Errorf("virtual: peer %s not conneced", hostID)
	}
	targetHost.mx.RLock()
	handler, ok := targetHost.msgHandlers[string(env.Type)]
	targetHost.mx.RUnlock()
	if !ok {
		return fmt.Errorf("virtual: no handler for msgType: %q", string(env.Type))
	}
	handler(env.Data, h.pid)
	return nil
}
// SendMessageSync synchronously sends a message to the given peer. The
// in-memory SendMessage already delivers synchronously, so this delegates.
func (h *MemoryHost) SendMessageSync(
	ctx context.Context,
	host string,
	env types.MessageEnvelope,
	exp time.Time,
) error {
	return h.SendMessage(ctx, host, env, exp)
}
// Handlers

// HandleMessage registers handler for the given message type, replacing any
// previously registered handler for that type. It always returns nil.
func (h *MemoryHost) HandleMessage(msgType string, handler func([]byte, peer.ID)) error {
	h.mx.Lock()
	defer h.mx.Unlock()
	h.msgHandlers[msgType] = handler
	return nil
}
// UnregisterMessageHandler removes the handler registered for msgType, if any.
func (h *MemoryHost) UnregisterMessageHandler(msgType string) {
	h.mx.Lock()
	defer h.mx.Unlock()
	delete(h.msgHandlers, msgType)
}
// PubSub

// Publish delivers data asynchronously (one goroutine per callback) to every
// subscriber of topic on all connected peers.
//
// Fix: the original ranged over h.peers and each destination's subs map
// without holding the respective locks, racing with connectPeers/Stop and
// Subscribe/Unsubscribe. Both are now snapshotted under their locks before
// the callbacks are dispatched.
func (h *MemoryHost) Publish(_ context.Context, topic string, data []byte) error {
	h.mx.RLock()
	targets := make([]*MemoryHost, 0, len(h.peers))
	for _, p := range h.peers {
		targets = append(targets, p)
	}
	h.mx.RUnlock()
	for _, destPeer := range targets {
		destPeer.mx.RLock()
		cbs := make([]func([]byte), 0, len(destPeer.subs[topic]))
		for _, cb := range destPeer.subs[topic] {
			cbs = append(cbs, cb)
		}
		destPeer.mx.RUnlock()
		for _, cb := range cbs {
			go cb(data)
		}
	}
	return nil
}
// Subscribe registers cb for topic and returns a subscription ID usable with
// Unsubscribe. The validator is ignored by the in-memory network.
func (h *MemoryHost) Subscribe(
	_ context.Context,
	topic string,
	cb func([]byte),
	_ Validator,
) (uint64, error) {
	h.mx.Lock()
	defer h.mx.Unlock()
	h.nextSubID++
	topicSubs, ok := h.subs[topic]
	if !ok {
		topicSubs = make(map[uint64]func([]byte))
		h.subs[topic] = topicSubs
	}
	topicSubs[h.nextSubID] = cb
	return h.nextSubID, nil
}
// Unsubscribe removes subscription id from topic, dropping the topic entry
// entirely once its last subscriber is gone. It always returns nil.
func (h *MemoryHost) Unsubscribe(topic string, id uint64) error {
	h.mx.Lock()
	defer h.mx.Unlock()
	topicSubs, ok := h.subs[topic]
	if !ok {
		return nil
	}
	delete(topicSubs, id)
	if len(topicSubs) == 0 {
		delete(h.subs, topic)
	}
	return nil
}
// DHT lookups

// Advertise stores d under key k in the shared DHT, indexed by this host's
// peer ID.
func (h *MemoryHost) Advertise(_ context.Context, k string, d []byte) error {
	h.substrate.mx.Lock()
	defer h.substrate.mx.Unlock()
	entries, ok := h.substrate.dht[k]
	if !ok {
		entries = make(map[string][]byte)
		h.substrate.dht[k] = entries
	}
	entries[h.pid.String()] = d
	return nil
}
// Unadvertise removes this host's entry for key k from the shared DHT,
// dropping the key entirely when no advertisers remain.
func (h *MemoryHost) Unadvertise(_ context.Context, k string) error {
	h.substrate.mx.Lock()
	defer h.substrate.mx.Unlock()
	entries, ok := h.substrate.dht[k]
	if !ok {
		return nil
	}
	delete(entries, h.pid.String())
	if len(entries) == 0 {
		delete(h.substrate.dht, k)
	}
	return nil
}
// Query returns one Advertisement per peer that advertised key k, or nil if
// the key is unknown. Timestamps are synthesized at query time.
func (h *MemoryHost) Query(_ context.Context, k string) ([]*common.Advertisement, error) {
	h.substrate.mx.RLock()
	defer h.substrate.mx.RUnlock()
	entries, found := h.substrate.dht[k]
	if !found {
		return nil, nil
	}
	result := make([]*common.Advertisement, 0, len(entries))
	for advertiser, payload := range entries {
		ad := &common.Advertisement{
			PeerId:    advertiser,
			Timestamp: time.Now().UnixNano(),
			Data:      payload,
		}
		result = append(result, ad)
	}
	return result, nil
}
// Ping reports success iff id is a directly connected peer, returning a
// fixed 1ms RTT for connected peers.
func (h *MemoryHost) Ping(_ context.Context, id string, _ time.Duration) (types.PingResult, error) {
	h.mx.RLock()
	defer h.mx.RUnlock()
	if _, connected := h.peers[id]; !connected {
		return types.PingResult{Success: false}, nil
	}
	return types.PingResult{Success: true, RTT: time.Millisecond}, nil
}
// GetHostID returns this host's peer ID.
func (h *MemoryHost) GetHostID() peer.ID { return h.pid }
// GetPeerPubKey returns a public key for the given peer.
// NOTE(review): a brand-new random key pair is generated on every call, so
// repeated calls for the same peer return different keys — see the TODO.
func (h *MemoryHost) GetPeerPubKey(_ peer.ID) crypto.PubKey {
	// TODO: I think memoryHost should make use of real pub-pvkey-peerIDs
	_, pubKey, _ := crypto.GenerateKeyPair(crypto.Ed25519)
	return pubKey
}
// Stop removes this host from the substrate and from every connected peer's
// peer list.
//
// Fix: the original held h.mx while locking each peer's mutex, so two
// connected hosts calling Stop concurrently could deadlock (opposite lock
// order). The peer list is now snapshotted under h.mx and the other hosts'
// locks are taken only after h.mx has been released.
func (h *MemoryHost) Stop() error {
	h.substrate.mx.Lock()
	delete(h.substrate.globalPeers, h.pid.String())
	h.substrate.mx.Unlock()
	h.mx.Lock()
	others := make([]*MemoryHost, 0, len(h.peers))
	for _, anotherHost := range h.peers {
		others = append(others, anotherHost)
	}
	h.mx.Unlock()
	for _, anotherHost := range others {
		anotherHost.mx.Lock()
		delete(anotherHost.peers, h.pid.String())
		anotherHost.mx.Unlock()
	}
	return nil
}
// Stat returns basic stats for the in-memory host.
//
// Fix: the listen address now uses h.pid.String() (the canonical encoded
// form) instead of string(h.pid), which dumped the ID's raw multihash bytes
// into the string; this also matches how the ID field is rendered.
func (h *MemoryHost) Stat() types.NetworkStats {
	return types.NetworkStats{ID: h.pid.String(), ListenAddr: "virtual://" + h.pid.String()}
}
// Peers returns the decoded IDs of all directly connected peers, skipping
// any map key that fails to decode as a peer ID.
func (h *MemoryHost) Peers() []peer.ID {
	h.mx.RLock()
	defer h.mx.RUnlock()
	out := make([]peer.ID, 0, len(h.peers))
	for encoded := range h.peers {
		decoded, err := peer.Decode(encoded)
		if err != nil {
			continue // skip invalid peer IDs
		}
		out = append(out, decoded)
	}
	return out
}
// Connect wires this host to the existing peer identified by peerID.
//
// Fix: the original called substrate.AddPeer, which created a brand-new
// empty host and overwrote the real peer's entry in globalPeers (discarding
// its handlers and connections) instead of connecting to it; it also
// shadowed the peer package with a local variable named peer.
func (h *MemoryHost) Connect(_ context.Context, peerID string) error {
	pid, err := peer.Decode(peerID)
	if err != nil {
		return errors.New("virtual: invalid peer ID")
	}
	h.substrate.mx.RLock()
	target, ok := h.substrate.globalPeers[pid.String()]
	h.substrate.mx.RUnlock()
	if !ok || target == nil {
		return fmt.Errorf("virtual: unknown peer %s", pid)
	}
	connectPeers(h, target)
	return nil
}
// PeerConnected reports whether p is a directly connected peer.
func (h *MemoryHost) PeerConnected(p peer.ID) bool {
	h.mx.RLock()
	defer h.mx.RUnlock()
	_, connected := h.peers[p.String()]
	return connected
}
// Misc stubs — the in-memory network has no real transport, so these are
// no-ops returning zero values.

// Init is a no-op for the in-memory network.
func (*MemoryHost) Init(*config.Config) error { return nil }

// Start is a no-op for the in-memory network.
func (*MemoryHost) Start() error { return nil }

// ResolveAddress is unimplemented and returns no addresses.
func (*MemoryHost) ResolveAddress(context.Context, string) ([]string, error) { return nil, nil }

// SetupBroadcastTopic is a no-op for the in-memory network.
func (*MemoryHost) SetupBroadcastTopic(string, func(*Topic) error) error { return nil }

// SetBroadcastAppScore is a no-op for the in-memory network.
func (*MemoryHost) SetBroadcastAppScore(func(peer.ID) float64) {}
// GetBroadcastScore returns a copy of the per-peer score snapshots, skipping
// entries whose key fails to decode as a peer ID. Copying avoids races on
// the internal map.
func (h *MemoryHost) GetBroadcastScore() map[peer.ID]*PeerScoreSnapshot {
	h.mx.RLock()
	defer h.mx.RUnlock()
	out := make(map[peer.ID]*PeerScoreSnapshot, len(h.score))
	for encoded, snapshot := range h.score {
		decoded, err := peer.Decode(encoded)
		if err != nil {
			continue
		}
		out[decoded] = snapshot
	}
	return out
}
// Notify is a no-op: the in-memory network does not emit connection events.
func (*MemoryHost) Notify(context.Context,
	func(peer.ID, []ProtocolID, int),
	func(peer.ID),
	func(peer.ID),
	func(peer.ID, []ProtocolID),
	func(peer.ID, []ProtocolID),
) error {
	return nil
}

// The remaining methods are stubs: the in-memory network has no real
// transport, subnets, port mappings, or DNS, so they all succeed and return
// zero values.
func (*MemoryHost) GetPeerIP(peer.ID) string { return "" }
func (*MemoryHost) HostPublicIP() (net.IP, error) { return nil, nil }
func (*MemoryHost) CreateSubnet(context.Context, string, string, map[string]string) error { return nil }
func (*MemoryHost) DestroySubnet(string) error { return nil }
func (*MemoryHost) AddSubnetPeer(string, string, string) error { return nil }
func (*MemoryHost) RemoveSubnetPeers(string, map[string]string) error { return nil }
func (*MemoryHost) AcceptSubnetPeers(string, map[string]string) error { return nil }
func (*MemoryHost) MapPort(string, string, string, string, string, string) error { return nil }
func (*MemoryHost) UnmapPort(string, string, string, string, string, string) error { return nil }
func (*MemoryHost) AddSubnetDNSRecords(string, map[string]string) error { return nil }
func (*MemoryHost) RemoveSubnetDNSRecord(string, string) error { return nil }
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package observability
import (
"os"
"reflect"
"runtime/debug"
"strings"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
// LogLabel is the type used for routing decisions based on label names.
type LogLabel string

// These constants define the labels we can attach to log entries.
// Labels without an entry in labelRoutingMap (e.g. LabelDefault) fall back
// to default routing.
// TODO desc for each label
const (
	LabelDefault LogLabel = "default"
	// TODO merge with LabelContract?
	LabelAccounting LogLabel = "accounting"
	LabelMetric     LogLabel = "metric"
	LabelDeployment LogLabel = "deployment"
	LabelAllocation LogLabel = "allocation"
	LabelNode       LogLabel = "node"
	LabelContract   LogLabel = "contract"
	// TODO unused
	LabelUser LogLabel = "user"
)

// EnvCIRun is the environment variable name that marks a CI run.
const EnvCIRun = "DMS_CI_RUN"
// LabelRoutingConfig defines optional routing rules per label.
// The zero value means: send to Elasticsearch using the default index.
type LabelRoutingConfig struct {
	// If SkipES is true, logs with this label will not be sent to Elasticsearch.
	SkipES bool
	// If ESIndex is non-empty, logs with this label will be routed to that ES index
	// instead of the default/logs index.
	ESIndex string
}
// labelRoutingMap is our in-memory map from label → routing configuration.
// Every label listed here currently ships to ES (SkipES is false throughout)
// but into its own dedicated index; labels absent from this map fall through
// to the syncer's default index.
var labelRoutingMap = map[LogLabel]LabelRoutingConfig{
	LabelContract: {
		SkipES:  false,
		ESIndex: "contract-index",
	},
	LabelAccounting: {
		SkipES:  false,
		ESIndex: "accounting-index",
	},
	LabelMetric: {
		SkipES:  false,
		ESIndex: "metric-index",
	},
	LabelDeployment: {
		SkipES:  false,
		ESIndex: "deployment-index",
	},
	LabelAllocation: {
		SkipES:  false,
		ESIndex: "allocation-index",
	},
	LabelNode: {
		SkipES:  false,
		ESIndex: "node-index",
	},
	LabelUser: {
		SkipES:  false,
		ESIndex: "user-index",
	},
}
// labelInjectionCore ensures "labels" is set and sets "es_skip"/"es_index".
// It wraps another zapcore.Core and rewrites each entry's fields in Write
// before delegating; level filtering uses its own enabler, not the wrapped
// core's.
type labelInjectionCore struct {
	// next is the downstream core that receives the augmented entries.
	next zapcore.Core
	// levelEnabler decides which entries this core accepts.
	levelEnabler zapcore.LevelEnabler
}
// GetLabelRoutingConfig resolves ES routing for a set of labels: skipES is
// true when any label opts out of Elasticsearch, and esIndex is the last
// non-empty index override found among the labels (labels without an entry in
// labelRoutingMap are ignored).
func GetLabelRoutingConfig(labels []string) (skipES bool, esIndex string) {
	for _, name := range labels {
		cfg, ok := labelRoutingMap[LogLabel(name)]
		if !ok {
			continue
		}
		skipES = skipES || cfg.SkipES
		if idx := cfg.ESIndex; idx != "" {
			esIndex = idx
		}
	}
	return skipES, esIndex
}
// newLabelInjectionCore wraps next so that every entry written through the
// returned core carries a merged "labels" field plus ES routing hints.
func newLabelInjectionCore(next zapcore.Core, enabler zapcore.LevelEnabler) zapcore.Core {
	core := &labelInjectionCore{next: next, levelEnabler: enabler}
	return core
}
// Enabled reports whether entries at the given level should be logged,
// consulting this core's own level enabler rather than the wrapped core's.
func (l *labelInjectionCore) Enabled(level zapcore.Level) bool {
	return l.levelEnabler.Enabled(level)
}
// With returns a child core whose wrapped core has the given fields attached;
// the level enabler is shared with the parent.
func (l *labelInjectionCore) With(fields []zapcore.Field) zapcore.Core {
	child := *l
	child.next = l.next.With(fields)
	return &child
}
// Check registers this core on the checked entry when the entry's level is
// enabled; otherwise the entry passes through untouched.
func (l *labelInjectionCore) Check(ent zapcore.Entry, ce *zapcore.CheckedEntry) *zapcore.CheckedEntry {
	if !l.Enabled(ent.Level) {
		return ce
	}
	return ce.AddCore(ent, l)
}
// Write merges or sets a single "labels" field, then determines "es_skip" and "es_index".
//
// All "labels" fields on the entry are collapsed into one deduplicated
// zap.Strings("labels", ...) field (defaulting to ["default"]), ES routing
// hints are appended for the downstream Elasticsearch syncer, CI runs are
// tagged via EnvCIRun, span/stack-trace fields are attached by gatherFields,
// and the augmented entry is forwarded to the wrapped core.
func (l *labelInjectionCore) Write(ent zapcore.Entry, fields []zapcore.Field) error {
	mergedLabels := make([]string, 0)
	uniqueLabels := make(map[string]bool)
	// We'll store all non-"labels" fields in a new slice and handle "labels" separately.
	finalFields := make([]zapcore.Field, 0, len(fields))
	for _, f := range fields {
		if f.Key == "labels" {
			// Deduplicate while preserving first-seen order; empty strings dropped.
			labels := extractLabels([]zapcore.Field{f})
			for _, lbl := range labels {
				if lbl == "" {
					continue
				}
				if !uniqueLabels[lbl] {
					uniqueLabels[lbl] = true
					mergedLabels = append(mergedLabels, lbl)
				}
			}
		} else {
			finalFields = append(finalFields, f)
		}
	}
	// Default to ["default"] if no labels found.
	if len(mergedLabels) == 0 {
		mergedLabels = []string{string(LabelDefault)}
	}
	finalFields = append(finalFields, zap.Strings("labels", mergedLabels))
	// Use the routing logic to decide skipES, overrideIndex.
	skipES, overrideIndex := GetLabelRoutingConfig(mergedLabels)
	if skipES {
		finalFields = append(finalFields, zap.Bool("es_skip", true))
	}
	if overrideIndex != "" {
		finalFields = append(finalFields, zap.String("es_index", overrideIndex))
	}
	// mark CI logs
	if os.Getenv(EnvCIRun) != "" {
		finalFields = append(finalFields, zap.Bool("ci", true))
	}
	// Attach span id / stack trace context (WARN and ERROR get their own span).
	finalFields = l.gatherFields(ent, fields, finalFields)
	// NOTE(review): rootTransaction is defined elsewhere in this package —
	// presumably the process-wide APM root transaction; confirm.
	finalFields = append(finalFields, zap.String("transaction.id", rootTransaction.TraceContext().Trace.String()))
	return l.next.Write(ent, finalFields)
}
// gatherFields enriches finalFields with tracing context. WARN/ERROR entries
// get a dedicated error span (created via StartSpan, defined elsewhere in
// this package) plus a structured stack trace; all other entries are bound to
// the most recent active span, falling back to latestSpanID.
//
// ent is the entry being written, fields are the entry's original fields, and
// the augmented finalFields slice is returned.
func (l *labelInjectionCore) gatherFields(
	ent zapcore.Entry, fields []zapcore.Field, finalFields []zapcore.Field,
) []zapcore.Field {
	// catch WARN and ERROR and create spans
	switch {
	case ent.Level == zapcore.WarnLevel || ent.Level == zapcore.ErrorLevel:
		// extract error — the last "error" field wins
		var errMsg string
		for _, f := range fields {
			if f.Key != "error" {
				continue
			}
			if f.String != "" {
				errMsg = " " + f.String
			} else if err, ok := f.Interface.(error); ok {
				errMsg = " " + err.Error()
			}
		}
		errSnippet := errMsg
		// optionally attach an unstructured log msg
		// NOTE(review): logMsg is captured only when the entry has no fields at
		// all — confirm this is intended rather than, say, errMsg == "".
		var logMsg string
		if fields == nil {
			logMsg = ent.Message[0:min(len(ent.Message), 300)]
		}
		if len(errSnippet) == 0 {
			errSnippet = " " + logMsg
		}
		// Truncate the snippet that becomes part of the span name.
		if len(errSnippet) > 30 {
			errSnippet = errSnippet[:30] + "..."
		}
		// keep in sync with the call path — number of stack-trace lines to drop
		// between debug.Stack() and the original log call site
		skipFrames := 13
		// get a stack trace
		sTrace := strings.Split(string(debug.Stack()), "\n")[skipFrames:]
		// Drop the trailing empty line left by the final newline, if present.
		if len(sTrace) > 0 && sTrace[len(sTrace)-1] == "" {
			sTrace = sTrace[:len(sTrace)-1]
		}
		// create span
		end := StartSpan(strings.ToUpper(ent.Level.String())+errSnippet,
			"error", errMsg,
			"logMsg", logMsg,
			// format stack for Kibana
			"stack_trace", strings.Join(sTrace, "\n ------ "),
		)
		// Capture the id of the span just opened (it is the most recent one).
		var spanID string
		if len(activeSpans) > 0 {
			spanID = activeSpans[len(activeSpans)-1].TraceContext().Span.String()
		} else {
			spanID = "no-active-span"
		}
		end()
		// bind this log msg to this err span, add the stack trace
		// sTraceStr, _ := json.Marshal(sTrace)
		for i, v := range sTrace {
			sTrace[i] = strings.Replace(v, "\t", " ", 1)
		}
		finalFields = append(finalFields,
			zap.String("span.id", spanID),
			// TODO remove _ nesting, fblog fixed arrays
			// zap.Any("stack_trace", sTrace))
			zap.Dict("stack_trace", zap.Strings("_", sTrace)))
	case len(activeSpans) > 0:
		// bind non-err logs to traces
		latestSpan := activeSpans[len(activeSpans)-1]
		finalFields = append(finalFields, zap.String("span.id", latestSpan.TraceContext().Span.String()))
	case latestSpanID != "":
		// Fall back to the id of the most recently recorded span.
		finalFields = append(finalFields, zap.String("span.id", latestSpanID))
	}
	return finalFields
}
// Sync flushes the wrapped core (zapcore.Core contract).
func (l *labelInjectionCore) Sync() error {
	return l.next.Sync()
}
// extractLabels returns the label values carried by the first field whose key
// is "labels". It understands a plain string field, a []string / string /
// []interface{} payload, and — as a last resort — any slice reachable through
// reflection. nil is returned when no "labels" field exists or its payload
// cannot be interpreted.
func extractLabels(fields []zapcore.Field) []string {
	for _, field := range fields {
		if field.Key != "labels" {
			continue
		}
		// CASE 1: a single string field carries one label.
		if field.Type == zapcore.StringType {
			return []string{field.String}
		}
		// CASE 2: known concrete payload types.
		switch payload := field.Interface.(type) {
		case []string:
			return payload
		case string:
			// A lone label passed as an interface value.
			return []string{payload}
		case []interface{}:
			// An array of interfaces — keep only the string elements.
			var collected []string
			for _, item := range payload {
				if str, ok := item.(string); ok {
					collected = append(collected, str)
				}
			}
			return collected
		}
		// CASE 3: reflection-backed fields (zapcore.ReflectType).
		if field.Type == zapcore.ReflectType {
			value := reflect.ValueOf(field.Interface)
			if value.Kind() == reflect.Slice {
				var collected []string
				for idx := 0; idx < value.Len(); idx++ {
					if str, ok := value.Index(idx).Interface().(string); ok {
						collected = append(collected, str)
					}
				}
				return collected
			}
		}
	}
	// If we found no "labels" field or couldn't parse it, return nil
	return nil
}
package observability
import (
"context"
"time"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc"
"go.opentelemetry.io/otel/metric"
sdkmetric "go.opentelemetry.io/otel/sdk/metric"
"go.opentelemetry.io/otel/sdk/resource"
semconv "go.opentelemetry.io/otel/semconv/v1.17.0"
)
// metrics
//
// Package-level instruments, created by nodeMetrics/deploymentMetrics/
// transactionMetrics below and recorded from elsewhere in DMS.
var (
	// Node onboarding / advertised capacity.
	NodeOnboarded    metric.Int64UpDownCounter
	NodeOnboardedCPU metric.Float64Gauge
	NodeOnboardedRAM metric.Int64Gauge
	NodeOnboardedDisk metric.Int64Gauge
	NodeOnboardedGPU metric.Int64Gauge
	NodeLocation     metric.Int64Gauge
	// Bidding.
	BidReceived metric.Int64Counter
	BidAccepted metric.Int64Counter
	// Deployments.
	DeploymentSuccess        metric.Int64Counter
	DeploySuccessAllocations metric.Int64Gauge
	DeploySuccessCPUCoresAssigned metric.Float64Gauge
	DeploySuccessRAMGBAssigned metric.Int64Gauge
	DeploySuccessDiskMBAssigned metric.Float64Gauge
	DeploySuccessGPUCountAssigned metric.Int64Gauge
	DeploymentStatus metric.Int64Counter
	// Allocations.
	AllocationHeartbeat metric.Int64Counter
	AllocationStatus    metric.Int64Gauge
	AllocCPUUsage       metric.Float64Gauge
	AllocMemUsed        metric.Int64Gauge
	AllocMemLimit       metric.Int64Gauge
	AllocNetRx          metric.Int64Gauge
	AllocNetTx          metric.Int64Gauge
	// Payments.
	TxPaidAmount       metric.Float64Counter
	TxPaidFeesAmount   metric.Float64Counter
	TxCreatedAmount    metric.Float64Counter
	TxCreatedUSDAmount metric.Float64Counter
)
// initMetrics wires up the OTLP/gRPC metric exporter and registers all metric
// instrument groups. It is a no-op (returning nil) when OTel export is
// disabled in config or no endpoint is set; any instrument-creation failure
// aborts initialization with the underlying error.
func initMetrics(ctx context.Context) error {
	if !ObservabilityCfg.OTel.Enabled {
		log.Info("otel_metrics_disabled", "msg", "OTel metrics export disabled in config")
		return nil
	}
	if ObservabilityCfg.OTel.Endpoint == "" {
		log.Warn("otel_metrics_skipped", "msg", "OTel endpoint not configured")
		return nil
	}
	// Build exporter options from config
	opts := []otlpmetricgrpc.Option{
		otlpmetricgrpc.WithEndpoint(ObservabilityCfg.OTel.Endpoint),
	}
	if ObservabilityCfg.OTel.Insecure {
		opts = append(opts, otlpmetricgrpc.WithInsecure())
	}
	exporter, err := otlpmetricgrpc.New(ctx, opts...)
	if err != nil {
		return err
	}
	// Fix: the resource.New error was previously discarded with `_`; surface it
	// instead of silently proceeding with a possibly incomplete resource.
	res, err := resource.New(ctx,
		resource.WithAttributes(
			semconv.ServiceName("dms"),
		),
	)
	if err != nil {
		return err
	}
	// This sends metrics to the collector every 3 seconds
	mp := sdkmetric.NewMeterProvider(
		sdkmetric.WithResource(res),
		sdkmetric.WithReader(sdkmetric.NewPeriodicReader(exporter, sdkmetric.WithInterval(3*time.Second))),
	)
	otel.SetMeterProvider(mp)
	// Register each instrument group.
	if err := systemMetrics(ctx); err != nil {
		return err
	}
	if err := nodeMetrics(ctx); err != nil {
		return err
	}
	if err := deploymentMetrics(ctx); err != nil {
		return err
	}
	if err := transactionMetrics(ctx); err != nil {
		return err
	}
	return nil
}
// systemMetrics registers observable gauges for host-level telemetry (CPU,
// RAM, disk, uptime, load, network counters) and one callback that samples
// them via collectSystemMetrics on every export cycle. Each observation is
// tagged with the node DID.
func systemMetrics(_ context.Context) error {
	meter := otel.Meter("system")
	// Local constructors keep the instrument definitions compact.
	newFloat := func(name, desc, unit string) (metric.Float64ObservableGauge, error) {
		return meter.Float64ObservableGauge(name,
			metric.WithDescription(desc), metric.WithUnit(unit))
	}
	newInt := func(name, desc, unit string) (metric.Int64ObservableGauge, error) {
		return meter.Int64ObservableGauge(name,
			metric.WithDescription(desc), metric.WithUnit(unit))
	}
	cpu, err := newFloat("dms.sys.cpu.total.norm", "CPU usage as a percentage (0.0 to 1.0)", "%")
	if err != nil {
		return err
	}
	ramUsed, err := newInt("dms.sys.memory.actual.used", "RAM usage in bytes", "By")
	if err != nil {
		return err
	}
	ramTotal, err := newInt("dms.sys.memory.total", "Total RAM in bytes", "By")
	if err != nil {
		return err
	}
	diskUsed, err := newInt("dms.sys.filesystem.used", "Disk usage in bytes", "By")
	if err != nil {
		return err
	}
	diskTotal, err := newInt("dms.sys.filesystem.total", "Total disk space in bytes", "By")
	if err != nil {
		return err
	}
	uptime, err := newFloat("dms.sys.uptime", "System uptime in seconds", "s")
	if err != nil {
		return err
	}
	load15, err := newFloat("dms.sys.load.15", "15-minute load average", "%")
	if err != nil {
		return err
	}
	networkIn, err := newInt("dms.sys.network.in", "Network bytes received", "By")
	if err != nil {
		return err
	}
	networkOut, err := newInt("dms.sys.network.out", "Network bytes sent", "By")
	if err != nil {
		return err
	}
	attrs := metric.WithAttributes(AttrDID)
	// Sample everything in one callback; type assertions guard against
	// missing or differently-typed entries in the sample map.
	_, err = meter.RegisterCallback(func(_ context.Context, o metric.Observer) error {
		sample := collectSystemMetrics()
		if v, ok := sample["cpuUsage"].(float64); ok {
			// Normalize percent (0-100) to a 0.0-1.0 ratio.
			o.ObserveFloat64(cpu, v/100.0, attrs)
		}
		if v, ok := sample["ramUsed"].(uint64); ok {
			o.ObserveInt64(ramUsed, int64(v), attrs)
		}
		if v, ok := sample["ramTotal"].(uint64); ok {
			o.ObserveInt64(ramTotal, int64(v), attrs)
		}
		if v, ok := sample["diskUsed"].(uint64); ok {
			o.ObserveInt64(diskUsed, int64(v), attrs)
		}
		if v, ok := sample["diskTotal"].(uint64); ok {
			o.ObserveInt64(diskTotal, int64(v), attrs)
		}
		if v, ok := sample["uptime"].(float64); ok {
			o.ObserveFloat64(uptime, v, attrs)
		}
		if v, ok := sample["load15"].(float64); ok {
			o.ObserveFloat64(load15, v, attrs)
		}
		if v, ok := sample["rxBytes"].(uint64); ok {
			o.ObserveInt64(networkIn, int64(v), attrs)
		}
		if v, ok := sample["txBytes"].(uint64); ok {
			o.ObserveInt64(networkOut, int64(v), attrs)
		}
		return nil
	}, cpu, ramUsed, ramTotal, diskUsed, diskTotal, uptime, load15, networkIn, networkOut)
	return err
}
// nodeMetrics creates the node onboarding/capacity instruments and stores
// them in the corresponding package-level variables.
func nodeMetrics(_ context.Context) error {
	meter := otel.Meter("node")
	// Shorthand for the Int64Gauge instruments, which dominate this group.
	gauge := func(name, desc, unit string) (metric.Int64Gauge, error) {
		return meter.Int64Gauge(name,
			metric.WithDescription(desc), metric.WithUnit(unit))
	}
	var err error
	NodeOnboarded, err = meter.Int64UpDownCounter("dms.node.onboarded",
		metric.WithDescription("Total number of onboarded nodes"),
		metric.WithUnit("{node}"),
	)
	if err != nil {
		return err
	}
	NodeOnboardedCPU, err = meter.Float64Gauge("dms.node.onboarded.cpu",
		metric.WithDescription("CPU cores assigned by the node"),
		metric.WithUnit("{core}"),
	)
	if err != nil {
		return err
	}
	NodeOnboardedRAM, err = gauge("dms.node.onboarded.memory", "RAM assigned by the node", "GBy")
	if err != nil {
		return err
	}
	NodeOnboardedDisk, err = gauge("dms.node.onboarded.disk", "Disk space assigned by the node", "MBy")
	if err != nil {
		return err
	}
	NodeOnboardedGPU, err = gauge("dms.node.onboarded.gpu", "Number of GPUs assigned by the node", "{gpu}")
	if err != nil {
		return err
	}
	// Recorded with continent/country/city (and onboarded) attributes by callers.
	NodeLocation, err = gauge("dms.node.location", "Node geolocation", "{location}")
	if err != nil {
		return err
	}
	return nil
}
// deploymentMetrics creates the bid, deployment, and allocation instruments
// and stores them in the corresponding package-level variables. Callers attach
// attributes such as orchestratorID / allocationID / status when recording.
func deploymentMetrics(_ context.Context) error {
	meter := otel.Meter("deployment")
	// Local constructors for the three instrument kinds used in this group.
	counter := func(name, desc, unit string) (metric.Int64Counter, error) {
		return meter.Int64Counter(name,
			metric.WithDescription(desc), metric.WithUnit(unit))
	}
	gaugeI := func(name, desc, unit string) (metric.Int64Gauge, error) {
		return meter.Int64Gauge(name,
			metric.WithDescription(desc), metric.WithUnit(unit))
	}
	gaugeF := func(name, desc, unit string) (metric.Float64Gauge, error) {
		return meter.Float64Gauge(name,
			metric.WithDescription(desc), metric.WithUnit(unit))
	}
	var err error
	// Bids.
	BidReceived, err = counter("dms.bid.received", "Bid requests received from orchestrators", "{bid}")
	if err != nil {
		return err
	}
	BidAccepted, err = counter("dms.bid.accepted", "Bid responses sent (bids accepted by node)", "{bid}")
	if err != nil {
		return err
	}
	// Deployments.
	DeploymentStatus, err = counter("dms.deployment.status", "Deployment status change", "{deployment}")
	if err != nil {
		return err
	}
	DeploymentSuccess, err = counter("dms.deployment.success", "Deployment successful", "{deployment}")
	if err != nil {
		return err
	}
	DeploySuccessAllocations, err = gaugeI("dms.deployment.success.allocations", "Number of allocations in successful deployment", "{allocation}")
	if err != nil {
		return err
	}
	DeploySuccessCPUCoresAssigned, err = gaugeF("dms.deployment.success.cpu.assigned", "CPU cores assigned in successful deployment", "{core}")
	if err != nil {
		return err
	}
	DeploySuccessRAMGBAssigned, err = gaugeI("dms.deployment.success.memory.assigned", "RAM gigabytes assigned in successful deployment", "GBy")
	if err != nil {
		return err
	}
	DeploySuccessDiskMBAssigned, err = gaugeF("dms.deployment.success.disk.assigned", "Disk space assigned in successful deployment", "MBy")
	if err != nil {
		return err
	}
	DeploySuccessGPUCountAssigned, err = gaugeI("dms.deployment.success.gpu.assigned", "Number of GPUs assigned in successful deployment", "{gpu}")
	if err != nil {
		return err
	}
	// Allocations.
	AllocationHeartbeat, err = counter("dms.allocation.heartbeat", "Periodic deployment heartbeat", "{deployment}")
	if err != nil {
		return err
	}
	AllocationStatus, err = gaugeI("dms.allocation.status", "Allocation status change", "{allocation}")
	if err != nil {
		return err
	}
	AllocCPUUsage, err = gaugeF("dms.allocation.cpu.usage", "Allocation CPU usage percent", "%")
	if err != nil {
		return err
	}
	AllocMemUsed, err = gaugeI("dms.allocation.memory.used", "Allocation memory used bytes", "By")
	if err != nil {
		return err
	}
	AllocMemLimit, err = gaugeI("dms.allocation.memory.limit", "Allocation memory limit bytes", "By")
	if err != nil {
		return err
	}
	AllocNetRx, err = gaugeI("dms.allocation.network.rx", "Allocation network bytes received", "By")
	if err != nil {
		return err
	}
	AllocNetTx, err = gaugeI("dms.allocation.network.tx", "Allocation network bytes sent", "By")
	if err != nil {
		return err
	}
	return nil
}
// transactionMetrics creates the payment counters and stores them in the
// corresponding package-level variables. Callers attach a ContractDID
// attribute when recording.
func transactionMetrics(_ context.Context) error {
	meter := otel.Meter("transaction")
	newCounter := func(name, desc, unit string) (metric.Float64Counter, error) {
		return meter.Float64Counter(name,
			metric.WithDescription(desc), metric.WithUnit(unit))
	}
	var err error
	TxPaidAmount, err = newCounter("dms.transaction.paid.amount", "Total amount of paid transactions", "{NTX}")
	if err != nil {
		return err
	}
	TxPaidFeesAmount, err = newCounter("dms.transaction.paid.fees.amount", "Total amount of paid fees", "{NTX}")
	if err != nil {
		return err
	}
	TxCreatedAmount, err = newCounter("dms.transaction.created.amount", "Total amount of created transactions", "{NTX}")
	if err != nil {
		return err
	}
	TxCreatedUSDAmount, err = newCounter("dms.transaction.created.usd.amount", "Total amount of created transactions in USD", "USD")
	if err != nil {
		return err
	}
	return nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package observability
import (
	context "context"
	"crypto/tls"
	"encoding/json"
	"errors"
	"fmt"
	"net"
	"net/http"
	"os"
	"sync"
	"sync/atomic"
	"time"

	logging "github.com/ipfs/go-log/v2"
	"github.com/libp2p/go-libp2p/core/event"
	"github.com/libp2p/go-libp2p/core/host"
	"github.com/natefinch/lumberjack"
	"github.com/olivere/elastic/v7"
	"go.opentelemetry.io/otel/attribute"
	"go.uber.org/zap"
	"go.uber.org/zap/zapcore"

	"gitlab.com/nunet/device-management-service/internal/config"
	"gitlab.com/nunet/device-management-service/lib/did"
)
// timestampKey is the encoder key used for entry timestamps in every core.
const timestampKey = "timestamp"

// Global variables for observability control
var (
	// EventBus carries custom events emitted from log entries.
	EventBus           event.Bus
	customEventEmitter event.Emitter
	// noOpMode disables all observability work (set for unit tests).
	noOpMode bool
	// Active configuration; replaced from cfg during Initialize.
	ObservabilityCfg config.Observability = config.DefaultConfig.Observability
	ApmCfg           config.APM           = config.DefaultConfig.APM
	// mutex guards the mutable globals in this block during (re)initialization.
	mutex        sync.RWMutex
	combinedCore zapcore.Core
	// esSyncerInstance is the currently active buffered ES syncer, if any.
	esSyncerInstance *bufferedElasticsearchSyncer
	// atomicLevel allows the log level to be changed at runtime.
	atomicLevel zap.AtomicLevel = zap.NewAtomicLevel()
	log                         = logging.Logger("observability")
	// didID / AttrDID cache the node identity set by Initialize.
	didID   did.DID
	AttrDID attribute.KeyValue
)
// CustomEvent represents a custom event structure
// published on the event bus for log-derived notifications.
type CustomEvent struct {
	// Name identifies the event type.
	Name string
	// Timestamp records when the event was produced.
	Timestamp time.Time
	// Data carries arbitrary event payload.
	Data map[string]interface{}
}
// esDisabledFlag records whether Elasticsearch log shipping has been disabled
// at runtime (e.g. after an auth failure). Accessed atomically so it can be
// flipped from logging hot paths without taking the package mutex.
var esDisabledFlag int32 // 0 => false, 1 => true

// disableES turns off Elasticsearch log shipping for this process.
func disableES() {
	atomic.StoreInt32(&esDisabledFlag, 1)
}

// isESDisabled reports whether Elasticsearch log shipping has been disabled.
func isESDisabled() bool {
	return atomic.LoadInt32(&esDisabledFlag) == 1
}
// Initialize sets up the logger, metrics, tracing, and event bus
//
// It copies the observability/APM sections of cfg into package globals, then
// wires the event bus, logger, APM tracing, and OTel metrics for the given
// libp2p host and node DID. In no-op mode it returns immediately after
// storing the config. Logger and metrics failures are non-fatal (warn only);
// an event-bus failure aborts.
func Initialize(host host.Host, did did.DID, cfg *config.Config) error {
	// NOTE(review): the `did` parameter shadows the imported did package —
	// works here because only methods on the value are used, but worth renaming.
	mutex.Lock()
	ObservabilityCfg = cfg.Observability
	ApmCfg = cfg.APM
	mutex.Unlock()
	if IsNoOpMode() {
		return nil
	}
	// Cache the node identity for log/metric attribution.
	didID = did
	AttrDID = attribute.String("did", did.String())
	// Initialize the event bus
	if err := initEventBus(host); err != nil {
		return err
	}
	// Initialize the logger with configurations
	if err := initLogger(ObservabilityCfg); err != nil {
		// Non-fatal: we log a warning and proceed
		log.Warn("Failed to initialize logger", zap.Error(err))
	}
	// Initialize Elastic APM tracing only if the APM URL is provided
	if ApmCfg.ServerURL != "" {
		initTracing(ApmCfg)
	} else {
		log.Warn("APM Server URL not provided, tracing will be disabled")
	}
	// Metrics failures are also non-fatal.
	if err := initMetrics(context.TODO()); err != nil {
		log.Warn("Failed to initialize metrics", zap.Error(err))
	}
	return nil
}
// OverrideLoggerForTesting reconfigures the logger for unit tests
// by switching the whole observability package into no-op mode.
// It never returns an error; the error result is kept for interface stability.
func OverrideLoggerForTesting() error {
	// Set observability to no-op mode for unit tests
	SetNoOpMode(true)
	return nil
}
// initLogger configures the global logger with console, file, Elasticsearch logging, and event emission
//
// The final core is a Tee of console + file + event-emitter (+ ES when enabled
// and reachable) wrapped in the label-injection core, and is installed as the
// primary core of the go-log logging library. Safe to call again to
// reconfigure. ES failures disable ES logging but do not fail initialization.
func initLogger(observabilityConfig config.Observability) error {
	// Acquire the lock only briefly
	mutex.Lock()
	localNoOp := noOpMode
	mutex.Unlock()
	// If we're in no-op mode, do nothing and return
	if localNoOp {
		return nil
	}
	// 1. Check if DMS_OBSERVE_LEVEL and GOLOG_LOG_LEVEL is set. If so, let that override any config-based level
	if envLogLevel := os.Getenv("DMS_OBSERVE_LEVEL"); envLogLevel != "" {
		observabilityConfig.Logging.Level = envLogLevel
	} else if envLogLevel := os.Getenv("GOLOG_LOG_LEVEL"); envLogLevel != "" {
		observabilityConfig.Logging.Level = envLogLevel
	}
	// 2. Parse the final log level string
	logLevel, err := parseLogLevel(observabilityConfig.Logging.Level)
	if err != nil {
		return fmt.Errorf("invalid log level: %w", err)
	}
	atomicLevel.SetLevel(logLevel)
	// Close existing ES syncer if present
	// NOTE(review): esSyncerInstance is read/written here without the package
	// mutex, unlike combinedCore below — confirm reinitialization is
	// single-threaded or guard this too.
	if esSyncerInstance != nil {
		esSyncerInstance.Close()
		esSyncerInstance = nil
	}
	// Create console/file cores
	consoleCore := createConsoleCore(atomicLevel)
	fileCore := createFileCore(observabilityConfig, atomicLevel)
	var esCore zapcore.Core
	if observabilityConfig.Elastic.Enabled && !isESDisabled() {
		esCore, err = createElasticsearchCore(observabilityConfig, atomicLevel)
		if err != nil {
			// ES is best-effort: disable it for the rest of the process lifetime.
			log.Errorw("elasticsearch_failed", "error", err)
			disableES()
			esCore = nil
		}
	}
	// Create the event emitter core
	eventCore := newEventEmitterCore(atomicLevel)
	// Attach DID field
	didField := zap.String("did", didID.String())
	consoleCore = consoleCore.With([]zapcore.Field{didField})
	fileCore = fileCore.With([]zapcore.Field{didField})
	if esCore != nil {
		esCore = esCore.With([]zapcore.Field{didField})
	}
	eventCore = eventCore.With([]zapcore.Field{didField})
	// Combine the cores into a Tee
	cores := []zapcore.Core{consoleCore, fileCore, eventCore}
	if esCore != nil {
		cores = append(cores, esCore)
	}
	baseTee := zapcore.NewTee(cores...)
	// Wrap the combined tee with our label injection core
	labelInjectedCore := newLabelInjectionCore(baseTee, atomicLevel)
	// Lock again to replace global references
	mutex.Lock()
	defer mutex.Unlock()
	combinedCore = labelInjectedCore
	logging.SetPrimaryCore(combinedCore)
	return nil
}
// parseLogLevel parses a string such as "debug" or "warn" into a
// zapcore.Level, returning the zero level alongside any parse error.
func parseLogLevel(levelStr string) (zapcore.Level, error) {
	var lvl zapcore.Level
	if err := lvl.UnmarshalText([]byte(levelStr)); err != nil {
		return 0, err
	}
	return lvl, nil
}
// utcTimeEncoder encodes entry timestamps as ISO8601 in UTC, so log output is
// timezone-independent regardless of the host's local time.
func utcTimeEncoder(t time.Time, enc zapcore.PrimitiveArrayEncoder) {
	// Use zapcore.ISO8601TimeEncoder on the UTC time.
	zapcore.ISO8601TimeEncoder(t.UTC(), enc)
}
// createConsoleCore creates a console logging core that writes colorized,
// UTC-timestamped entries to stdout.
func createConsoleCore(levelEnabler zapcore.LevelEnabler) zapcore.Core {
	cfg := zap.NewProductionEncoderConfig()
	cfg.TimeKey = timestampKey
	cfg.EncodeTime = utcTimeEncoder
	cfg.EncodeLevel = zapcore.CapitalColorLevelEncoder
	return zapcore.NewCore(
		zapcore.NewConsoleEncoder(cfg),
		zapcore.AddSync(os.Stdout),
		levelEnabler,
	)
}
// createFileCore creates a file logging core that writes JSON entries to a
// size/age-rotated, compressed log file managed by lumberjack.
func createFileCore(observabilityConfig config.Observability, levelEnabler zapcore.LevelEnabler) zapcore.Core {
	cfg := zap.NewProductionEncoderConfig()
	cfg.TimeKey = timestampKey
	cfg.EncodeTime = utcTimeEncoder
	cfg.EncodeLevel = zapcore.CapitalLevelEncoder
	rotating := &lumberjack.Logger{
		Filename:   observabilityConfig.Logging.File,
		MaxSize:    observabilityConfig.Logging.Rotation.MaxSizeMB,   // in MB
		MaxBackups: observabilityConfig.Logging.Rotation.MaxBackups,  // number of backups
		MaxAge:     observabilityConfig.Logging.Rotation.MaxAgeDays,  // in days
		Compress:   true,
	}
	return zapcore.NewCore(
		zapcore.NewJSONEncoder(cfg),
		zapcore.AddSync(rotating),
		levelEnabler,
	)
}
// createElasticsearchCore creates an Elasticsearch logging core with "preflight" fallback.
// It validates the Elastic configuration (URL, index, and API key are all
// required), builds the buffered write syncer — which includes a reachability
// preflight — and wraps it in a JSON-encoding core.
func createElasticsearchCore(observabilityConfig config.Observability, levelEnabler zapcore.LevelEnabler) (zapcore.Core, error) {
	// Basic validations.
	// Idiom fix: fmt.Errorf with no format verbs → errors.New.
	if observabilityConfig.Elastic.URL == "" {
		return nil, errors.New("elasticsearch URL is not configured")
	}
	if observabilityConfig.Elastic.Index == "" {
		return nil, errors.New("elasticsearch index is not configured")
	}
	if observabilityConfig.Elastic.APIKey == "" {
		return nil, errors.New("elasticsearch API key is not configured")
	}
	// Attempt to build the WriteSyncer
	esWS, err := newElasticsearchWriteSyncer(
		observabilityConfig.Elastic.URL,
		observabilityConfig.Elastic.Index,
		time.Duration(observabilityConfig.Elastic.FlushInterval)*time.Second,
		observabilityConfig.Elastic.APIKey,
		observabilityConfig.Elastic.InsecureSkipVerify,
	)
	if err != nil {
		return nil, err
	}
	// Same JSON encoding as the file core.
	encoderConfig := zap.NewProductionEncoderConfig()
	encoderConfig.TimeKey = timestampKey
	encoderConfig.EncodeTime = utcTimeEncoder
	encoderConfig.EncodeLevel = zapcore.CapitalLevelEncoder
	esEncoder := zapcore.NewJSONEncoder(encoderConfig)
	return zapcore.NewCore(esEncoder, esWS, levelEnabler), nil
}
// newElasticsearchWriteSyncer creates a WriteSyncer for Elasticsearch with buffering.
//
// It first runs a short preflight against the cluster, then builds an HTTP
// client with tight timeouts, constructs the elastic client (sniffing and
// health checks disabled since we target a single known URL), and finally
// starts the buffered syncer that flushes every flushInterval.
func newElasticsearchWriteSyncer(
	url string,
	index string,
	flushInterval time.Duration,
	apiKey string,
	insecureSkipVerify bool,
) (zapcore.WriteSyncer, error) {
	// 1) Short preflight check to ensure ES is reachable
	if err := preflightCheckES(url, apiKey, insecureSkipVerify); err != nil {
		return nil, fmt.Errorf("ES preflight: %w", err)
	}
	// 2) Build the actual transport + client
	dialer := &net.Dialer{
		Timeout: 3 * time.Second,
	}
	tlsConfig := &tls.Config{InsecureSkipVerify: insecureSkipVerify}
	httpClient := &http.Client{
		Transport: &http.Transport{
			DialContext:         dialer.DialContext,
			TLSClientConfig:     tlsConfig,
			DisableKeepAlives:   false,
			MaxIdleConns:        10,
			IdleConnTimeout:     30 * time.Second,
			TLSHandshakeTimeout: 3 * time.Second,
		},
		Timeout: 5 * time.Second,
	}
	clientOptions := []elastic.ClientOptionFunc{
		elastic.SetURL(url),
		elastic.SetHttpClient(httpClient),
		elastic.SetSniff(false),
		elastic.SetHealthcheck(false),
	}
	if apiKey != "" {
		clientOptions = append(clientOptions, elastic.SetHeaders(http.Header{
			"Authorization": []string{"ApiKey " + apiKey},
		}))
	}
	client, err := elastic.NewClient(clientOptions...)
	if err != nil {
		// Fix: wrap with %w (was %v) so callers can unwrap with errors.Is/As.
		return nil, fmt.Errorf("failed to create elastic client: %w", err)
	}
	esSyncer := newBufferedElasticsearchSyncer(client, index, flushInterval)
	// NOTE(review): this global is written without the package mutex; confirm
	// initLogger is the only caller and runs single-threaded.
	esSyncerInstance = esSyncer
	return esSyncer, nil
}
// preflightCheckES does a quick GET /_cluster/health to ensure ES is reachable.
// It returns an error on network failure, invalid credentials (401), or any
// non-2xx status, using an ephemeral client with tight timeouts so a down
// cluster fails fast.
func preflightCheckES(url, apiKey string, insecureSkip bool) error {
	preflightURL := url + "/_cluster/health"
	// Security fix: never log the raw API key (it was previously written to
	// the logs verbatim); record only whether one was supplied.
	log.Infow("Preflight: Attempting to initialize Elasticsearch",
		"url", url,
		"apiKeySet", apiKey != "",
	)
	testCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()
	req, err := http.NewRequestWithContext(testCtx, http.MethodGet, preflightURL, nil)
	if err != nil {
		return fmt.Errorf("preflight request creation failed: %w", err)
	}
	req.Header.Set("Authorization", "ApiKey "+apiKey)
	// Dedicated transport so the preflight never shares/retains connections.
	ephemeralTransport := &http.Transport{
		DialContext: (&net.Dialer{
			Timeout: 2 * time.Second,
		}).DialContext,
		TLSClientConfig:     &tls.Config{InsecureSkipVerify: insecureSkip},
		DisableKeepAlives:   true,
		MaxIdleConns:        2,
		IdleConnTimeout:     2 * time.Second,
		TLSHandshakeTimeout: 2 * time.Second,
	}
	ephemeralClient := &http.Client{
		Transport: ephemeralTransport,
		Timeout:   5 * time.Second,
	}
	resp, err := ephemeralClient.Do(req)
	if err != nil {
		log.Warnw("Elasticsearch preflight request failed", "err", err, "url", url)
		return err
	}
	defer resp.Body.Close()
	if resp.StatusCode == http.StatusUnauthorized {
		log.Warnw("Elasticsearch preflight returned 401 Unauthorized", "url", url)
		return fmt.Errorf("invalid credentials (401 unauthorized)")
	}
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		log.Warnw("Elasticsearch preflight returned unexpected status",
			"statusCode", resp.StatusCode,
			"url", url,
		)
		return fmt.Errorf("got unexpected HTTP %d during ES preflight", resp.StatusCode)
	}
	log.Infow("Elasticsearch preflight succeeded", "url", url, "statusCode", resp.StatusCode)
	return nil
}
// bufferedElasticsearchSyncer implements zapcore.WriteSyncer to send logs to ES with buffering.
// Entries accumulate in an in-memory buffer (capped at maxBufferSize) and are
// shipped in bulk by a background goroutine every flushInterval.
type bufferedElasticsearchSyncer struct {
	// client is the elastic client used for bulk indexing.
	client *elastic.Client
	// index is the default ES index for entries without an "es_index" override.
	index string
	// ctx/cancelFunc control the lifetime of the background flush loop.
	ctx        context.Context
	cancelFunc context.CancelFunc
	// buffer holds pending JSON-encoded log lines; guarded by bufferMutex.
	buffer      []string
	bufferMutex sync.Mutex
	// flushInterval is the period of the background flush ticker.
	flushInterval time.Duration
	// lastErrorTime/errorCount/warnLogged track failure state for throttling.
	lastErrorTime time.Time
	errorCount    int
	warnLogged    bool
	// maxBufferSize caps buffer length; writes beyond it are rejected.
	maxBufferSize int
	// wg is used to wait for the flush goroutine to exit.
	wg sync.WaitGroup
}
// newBufferedElasticsearchSyncer builds a syncer for the given client/index
// and starts its background flush goroutine (flushing every flushInterval
// until the syncer is closed).
func newBufferedElasticsearchSyncer(client *elastic.Client, index string, flushInterval time.Duration) *bufferedElasticsearchSyncer {
	ctx, cancel := context.WithCancel(context.Background())
	s := &bufferedElasticsearchSyncer{
		client:        client,
		index:         index,
		ctx:           ctx,
		cancelFunc:    cancel,
		buffer:        []string{},
		flushInterval: flushInterval,
		maxBufferSize: 1000,
	}
	s.wg.Add(1)
	go s.start()
	return s
}
// start runs the background flush loop: it flushes the buffer every
// flushInterval until the syncer's context is cancelled, then signals the
// wait group so callers can join the goroutine.
func (b *bufferedElasticsearchSyncer) start() {
	ticker := time.NewTicker(b.flushInterval)
	defer ticker.Stop()
	defer b.wg.Done()
	for {
		select {
		case <-ticker.C:
			b.Flush()
		case <-b.ctx.Done():
			return
		}
	}
}
// Write buffers a single JSON-encoded log line for later bulk shipment,
// rejecting the write with an error once the buffer is full (warnLogged is
// latched so the condition is only flagged once).
func (b *bufferedElasticsearchSyncer) Write(p []byte) (n int, err error) {
	b.bufferMutex.Lock()
	defer b.bufferMutex.Unlock()
	if len(b.buffer) < b.maxBufferSize {
		b.buffer = append(b.buffer, string(p))
		return len(p), nil
	}
	if !b.warnLogged {
		b.warnLogged = true
	}
	return 0, fmt.Errorf("log buffer is full")
}
// Sync flushes the buffer to Elasticsearch.
// It always reports success: delivery is best-effort and failures are
// handled (throttled / ES disabled) inside Flush itself.
func (b *bufferedElasticsearchSyncer) Sync() error {
	b.Flush()
	return nil
}
// Flush sends the buffered log entries to Elasticsearch with advanced routing
// but never sends the same log to both the override index and the default index.
//
// Records may opt out with "es_skip": true or redirect with "es_index": "<name>".
// Buffered entries are dropped even on failure (logging is best-effort);
// repeated failures, or a 401, disable ES entirely via disableES.
func (b *bufferedElasticsearchSyncer) Flush() {
	// Lock only the buffer mutex; it also guards the error-throttling fields,
	// which are touched exclusively from Flush.
	b.bufferMutex.Lock()
	defer b.bufferMutex.Unlock()
	if len(b.buffer) == 0 {
		return
	}
	if b.client == nil {
		if !b.warnLogged {
			b.warnLogged = true
		}
		return
	}
	// Take ownership of the pending entries.
	bufferCopy := b.buffer
	b.buffer = nil
	bulkRequest := b.client.Bulk()
	for _, logEntry := range bufferCopy {
		// Parse the JSON to check "es_skip" and "es_index"
		var record map[string]interface{}
		if err := json.Unmarshal([]byte(logEntry), &record); err != nil {
			// If parsing fails, skip to avoid sending malformed JSON to ES.
			continue
		}
		// Skip if "es_skip" == true
		if skipVal, ok := record["es_skip"].(bool); ok && skipVal {
			continue
		}
		// Route to the override index when set, otherwise to the default index.
		targetIndex := b.index
		if overrideIndex, ok := record["es_index"].(string); ok && overrideIndex != "" {
			targetIndex = overrideIndex
		}
		bulkRequest = bulkRequest.Add(elastic.NewBulkIndexRequest().Index(targetIndex).Doc(record))
	}
	// Nothing to send if every entry was skipped or malformed; an empty bulk
	// request errors, which would wrongly count toward disabling ES.
	if bulkRequest.NumberOfActions() == 0 {
		return
	}
	// Use a fresh context for the request rather than deriving from b.ctx:
	// Close() cancels b.ctx before the final Flush, and a request on a
	// cancelled context would silently drop the remaining entries.
	flushCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()
	if _, err := bulkRequest.Do(flushCtx); err != nil {
		// If it's a 401, disable ES immediately (atomic, no global lock)
		if esErr, ok := err.(*elastic.Error); ok && esErr.Status == 401 {
			disableES()
			return
		}
		now := time.Now()
		// Throttle repeated warnings: restart the window after 5 minutes.
		if b.errorCount == 0 || now.Sub(b.lastErrorTime) > 5*time.Minute {
			b.lastErrorTime = now
			b.errorCount = 1
		} else {
			b.errorCount++
		}
		// If it fails too many times in a short window, disable ES
		if b.errorCount >= 3 {
			disableES()
		}
		return
	}
	b.errorCount = 0
}
// Close stops the background flush loop, waits for it to exit, then performs
// one final synchronous Flush so entries buffered at shutdown are attempted.
func (b *bufferedElasticsearchSyncer) Close() {
	b.cancelFunc()
	b.wg.Wait()
	b.Flush()
}
// setFlushInterval restarts the background flush loop with a new period.
// Non-positive intervals are rejected up front: time.NewTicker (used by
// start) panics on a non-positive duration.
func (b *bufferedElasticsearchSyncer) setFlushInterval(interval time.Duration) {
	if interval <= 0 {
		return
	}
	// Stop the old loop and wait for it before swapping the context, so only
	// one flush goroutine ever runs at a time.
	b.cancelFunc()
	b.wg.Wait()
	b.ctx, b.cancelFunc = context.WithCancel(context.Background())
	b.flushInterval = interval
	b.wg.Add(1)
	go b.start()
}
// SetElasticsearchEndpoint updates the Elasticsearch URL and reinitializes the logger.
func SetElasticsearchEndpoint(url string) error {
	mutex.Lock()
	ObservabilityCfg.Elastic.URL = url
	mutex.Unlock()
	// Rebuild the logger outside the lock with the updated configuration.
	if err := initLogger(ObservabilityCfg); err != nil {
		return fmt.Errorf("failed to reinitialize logger after updating ES endpoint: %v", err)
	}
	return nil
}
// initEventBus initializes the global event bus from the libp2p host and
// creates the emitter used to publish CustomEvent values.
func initEventBus(host host.Host) error {
	EventBus = host.EventBus()
	emitter, err := EventBus.Emitter(new(CustomEvent))
	if err != nil {
		return fmt.Errorf("failed to create custom event emitter: %w", err)
	}
	customEventEmitter = emitter
	return nil
}
// newEventEmitterCore creates a zapcore.Core that emits log entries to the event bus
func newEventEmitterCore(levelEnabler zapcore.LevelEnabler) zapcore.Core {
return &eventEmitterCore{
LevelEnabler: levelEnabler,
}
}
// eventEmitterCore is a zapcore.Core that emits log entries to the event bus
type eventEmitterCore struct {
	zapcore.LevelEnabler // embedded level filter consulted by Check
	// fields holds structured fields accumulated via With; attached to each event.
	fields []zapcore.Field
}
func (e *eventEmitterCore) With(fields []zapcore.Field) zapcore.Core {
return &eventEmitterCore{
LevelEnabler: e.LevelEnabler,
fields: append(e.fields, fields...),
}
}
func (e *eventEmitterCore) Check(entry zapcore.Entry, ce *zapcore.CheckedEntry) *zapcore.CheckedEntry {
if e.LevelEnabler.Enabled(entry.Level) {
return ce.AddCore(entry, e)
}
return ce
}
func (e *eventEmitterCore) Write(entry zapcore.Entry, fields []zapcore.Field) error {
if IsNoOpMode() {
return nil
}
eventData := map[string]interface{}{
"level": entry.Level.String(),
"timestamp": entry.Time.UTC().Format(time.RFC3339),
"message": entry.Message,
"logger": entry.LoggerName,
"caller": entry.Caller.String(),
}
for _, field := range fields {
eventData[field.Key] = field.Interface
}
customEvent := CustomEvent{
Name: "log_event",
Timestamp: entry.Time,
Data: eventData,
}
// nil-guard: only emit if the bus is initialised
if customEventEmitter != nil {
_ = customEventEmitter.Emit(customEvent)
}
return nil
}
// Sync implements zapcore.Core. Events are emitted synchronously from Write,
// so there is nothing buffered to flush here.
func (e *eventEmitterCore) Sync() error {
	return nil
}
// SetLogLevel sets the global log level for all collectors.
// The level string is validated before any state is mutated.
func SetLogLevel(level string) error {
	mutex.Lock()
	defer mutex.Unlock()
	parsed, err := parseLogLevel(level)
	if err != nil {
		return fmt.Errorf("invalid log level: %w", err)
	}
	ObservabilityCfg.Logging.Level = level
	atomicLevel.SetLevel(parsed)
	return nil
}
// SetFlushInterval sets the flush interval for Elasticsearch logging dynamically.
// Non-positive values are rejected: they would eventually reach
// time.NewTicker in the syncer's flush loop, which panics on them.
func SetFlushInterval(seconds int) error {
	if seconds <= 0 {
		return fmt.Errorf("flush interval must be positive, got %d", seconds)
	}
	mutex.Lock()
	ObservabilityCfg.Elastic.FlushInterval = seconds
	// Snapshot the syncer under the lock; restart it outside the lock.
	localEsSyncer := esSyncerInstance
	mutex.Unlock()
	if localEsSyncer != nil {
		localEsSyncer.setFlushInterval(time.Duration(seconds) * time.Second)
	}
	return nil
}
// EmitCustomEvent allows developers to emit custom events with variadic key-value pairs
func EmitCustomEvent(eventName string, keyValues ...interface{}) error {
if len(keyValues)%2 != 0 {
return fmt.Errorf("keyValues must be in key-value pairs")
}
eventData := make(map[string]interface{})
for i := 0; i < len(keyValues); i += 2 {
key, ok := keyValues[i].(string)
if !ok {
return fmt.Errorf("key must be a string")
}
eventData[key] = keyValues[i+1]
}
customEvent := CustomEvent{
Name: eventName,
Timestamp: time.Now(),
Data: eventData,
}
if err := customEventEmitter.Emit(customEvent); err != nil {
log.Debug("Failed to emit custom event", zap.Error(err))
}
return nil
}
// Shutdown cleans up resources: closes the custom event emitter, tears down
// the Elasticsearch syncer (flushing its buffer), then stops the tracer.
func Shutdown() {
	mutex.Lock()
	if customEventEmitter != nil {
		customEventEmitter.Close()
	}
	if syncer := esSyncerInstance; syncer != nil {
		esSyncerInstance = nil
		syncer.Close()
	}
	mutex.Unlock()
	// Shutdown the tracer
	ShutdownTracer()
}
// ShutdownTracer wraps the Shutdown function from tracing.go,
// exposing tracer teardown to callers of this package.
func ShutdownTracer() {
	shutdownTracer()
}
// SetNoOpMode enables or disables the no-op mode for observability.
func SetNoOpMode(enabled bool) {
mutex.Lock()
noOpMode = enabled
if noOpMode {
// If in no-op mode, set the log level to a very high threshold
atomicLevel.SetLevel(zapcore.Level(100))
} else {
logLevel, err := parseLogLevel(ObservabilityCfg.Logging.Level)
if err != nil {
logLevel = zapcore.InfoLevel
}
atomicLevel.SetLevel(logLevel)
}
mutex.Unlock()
}
// IsNoOpMode returns whether observability is in no-op mode.
// It takes the read lock, so it is safe to call concurrently with SetNoOpMode.
func IsNoOpMode() bool {
	mutex.RLock()
	defer mutex.RUnlock()
	return noOpMode
}
// SetAPIKey updates the API key for both Elasticsearch and APM, then
// reinitializes the logger (and the tracer if an APM server is configured).
func SetAPIKey(apiKey string) error {
	mutex.Lock()
	ObservabilityCfg.Elastic.APIKey = apiKey
	ApmCfg.APIKey = apiKey
	// Snapshot both configs while holding the lock; reading the globals
	// after Unlock would race with concurrent writers such as SetAPMURL.
	obsCfg := ObservabilityCfg
	apmCfg := ApmCfg
	mutex.Unlock()
	// Reinit logger outside the lock
	if err := initLogger(obsCfg); err != nil {
		return fmt.Errorf("failed to reinitialize logger: %v", err)
	}
	if apmCfg.ServerURL != "" {
		initTracing(apmCfg)
	}
	return nil
}
// SetAPMURL updates the APM server URL and reinitializes the APM tracer.
// An empty URL shuts the tracer down instead. No-op mode leaves everything
// untouched.
func SetAPMURL(url string) error {
	mutex.Lock()
	if noOpMode {
		mutex.Unlock()
		return nil
	}
	ApmCfg.ServerURL = url
	// Snapshot under the lock; reading ApmCfg after Unlock would race with
	// concurrent writers such as SetAPIKey.
	apmCfg := ApmCfg
	mutex.Unlock()
	if apmCfg.ServerURL != "" {
		initTracing(apmCfg)
	} else {
		ShutdownTracer()
	}
	return nil
}
// EnableElasticsearchLogging enables or disables Elasticsearch logging dynamically.
func EnableElasticsearchLogging(enabled bool) error {
	mutex.Lock()
	ObservabilityCfg.Elastic.Enabled = enabled
	mutex.Unlock()
	// Rebuild the logger so the new setting takes effect.
	if err := initLogger(ObservabilityCfg); err != nil {
		return fmt.Errorf("failed to reinitialize logger: %v", err)
	}
	return nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
// tracing.go
package observability
import (
"context"
"net/url"
"os"
"path/filepath"
"runtime/trace"
"slices"
"strconv"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/shirou/gopsutil/v4/cpu"
"github.com/shirou/gopsutil/v4/disk"
"github.com/shirou/gopsutil/v4/host"
"github.com/shirou/gopsutil/v4/load"
"github.com/shirou/gopsutil/v4/mem"
"github.com/shirou/gopsutil/v4/net"
"gitlab.com/nunet/device-management-service/internal/config"
"gitlab.com/nunet/device-management-service/types"
"go.elastic.co/apm/module/apmhttp/v2"
"go.elastic.co/apm/transport"
"go.elastic.co/apm/v2"
)
const (
	// EnvFlightrecSec triggers a flight recorder to start recording a specified number of seconds, which later can be
	// saved into a trace file.
	// An unset or non-positive value leaves the flight recorder disabled (see FlightrecInit).
	EnvFlightrecSec = "DMS_FLIGHTREC_SEC"
)
var (
	// tracingNoOpMode disables span creation when APM setup failed or no-op mode is on.
	tracingNoOpMode bool
	// tracerMutex guards the tracer globals (tracingNoOpMode, currentTracer).
	tracerMutex sync.Mutex
	// currentTracer is the active APM tracer; nil when tracing is disabled.
	currentTracer *apm.Tracer
	// rootTransaction is the top-level trace.
	rootTransaction *apm.Transaction
	// rootSpan is the top-level span.
	rootSpan *apm.Span
	// activeSpans is last span created, used to bind log msgs.
	activeSpans []*apm.Span
	// latestSpanID of the last span created, used in case of no active spans.
	latestSpanID string
	// internal tracing: runtime/trace flight recorder state, each step one-shot.
	flightrec            *trace.FlightRecorder
	flightrecHandleOnce  sync.Once
	flightrecCaptureOnce sync.Once
)
// initTracing initializes or reinitializes the Elastic APM tracer.
// On any failure it leaves tracing in no-op mode rather than erroring out.
func initTracing(apmConfig config.APM) {
	tracerMutex.Lock()
	defer tracerMutex.Unlock()
	if IsNoOpMode() {
		tracingNoOpMode = true
		return
	}
	// Tear down any previously configured tracer before replacing it.
	if currentTracer != nil {
		currentTracer.Close()
		currentTracer = nil
	}
	// All three settings are mandatory for a working APM connection.
	if apmConfig.ServerURL == "" || apmConfig.ServiceName == "" || apmConfig.Environment == "" {
		log.Warn("APM configurations are incomplete, tracing will be disabled")
		tracingNoOpMode = true
		return
	}
	// Build the HTTP transport that ships events to the APM server.
	httpTransport, err := transport.NewHTTPTransport()
	if err != nil {
		log.Warnf("Failed to create APM transport: %v", err)
		tracingNoOpMode = true
		return
	}
	parsedURL, err := url.Parse(apmConfig.ServerURL)
	if err != nil {
		log.Warnf("Failed to parse APM server URL: %v", err)
		tracingNoOpMode = true
		return
	}
	httpTransport.SetServerURL(parsedURL)
	// Secret token takes precedence over the API key when both are set.
	switch {
	case apmConfig.SecretToken != "":
		httpTransport.SetSecretToken(apmConfig.SecretToken)
	case apmConfig.APIKey != "":
		httpTransport.SetAPIKey(apmConfig.APIKey)
	}
	newTracer, err := apm.NewTracerOptions(apm.TracerOptions{
		ServiceName:    apmConfig.ServiceName,
		ServiceVersion: "1.0.0",
		Transport:      httpTransport,
	})
	if err != nil {
		log.Warnf("Failed to initialize APM tracer: %v", err)
		tracingNoOpMode = true
		return
	}
	newTracer.SetMetricsInterval(10 * time.Second)
	apm.SetDefaultTracer(newTracer)
	currentTracer = newTracer
	tracingNoOpMode = false
	initRootTrace(newTracer)
}
// initRootTrace starts the long-lived root transaction and span under which
// all DMS spans nest, optionally continuing a trace supplied via environment.
func initRootTrace(tracer *apm.Tracer) {
	// Compose a distinctive name from the node DID plus optional node name.
	rootName := "DMS-" + didID.String()
	if nodeName := os.Getenv("ELASTIC_APM_SERVICE_NODE_NAME"); nodeName != "" {
		rootName = rootName + "-" + nodeName
	}
	// Honour an externally provided trace context, if any.
	var txOpts apm.TransactionOptions
	if traceparent := os.Getenv("ELASTIC_APM_TRACEPARENT"); traceparent != "" {
		traceCtx, _ := apmhttp.ParseTraceparentHeader(traceparent)
		traceCtx.State, _ = apmhttp.ParseTracestateHeader(os.Getenv("ELASTIC_APM_TRACESTATE"))
		txOpts.TraceContext = traceCtx
	}
	rootTransaction = tracer.StartTransactionOptions(rootName, "background-job", txOpts)
	rootTransaction.Context.SetLabel("did", didID.String())
	rootSpan, _ = apm.StartSpan(apm.ContextWithTransaction(context.Background(), rootTransaction), "root", "custom")
}
// collectSystemMetrics samples host-level stats (CPU, RAM, disk, uptime,
// load, network) into a flat map; probes that fail are simply omitted.
func collectSystemMetrics() map[string]interface{} {
	out := map[string]interface{}{}
	// CPU usage (aggregate percentage).
	if pct, err := cpu.Percent(0, false); err == nil && len(pct) > 0 {
		out["cpuUsage"] = pct[0]
	}
	// RAM usage.
	if vm, err := mem.VirtualMemory(); err == nil {
		out["ramUsed"] = vm.Used
		out["ramTotal"] = vm.Total
	}
	// Disk usage summed across mounted partitions.
	if parts, err := disk.Partitions(false); err == nil {
		var usedSum, totalSum uint64
		for _, p := range parts {
			usage, err := disk.Usage(p.Mountpoint)
			if err != nil {
				continue
			}
			usedSum += usage.Used
			totalSum += usage.Total
		}
		out["diskUsed"] = usedSum
		out["diskTotal"] = totalSum
	}
	// Host uptime in seconds.
	if up, err := host.Uptime(); err == nil {
		out["uptime"] = float64(up)
	}
	// 15-minute load average.
	if la, err := load.Avg(); err == nil {
		out["load15"] = la.Load15
	}
	// Aggregate network RX/TX byte counters.
	if counters, err := net.IOCounters(false); err == nil && len(counters) > 0 {
		out["rxBytes"] = counters[0].BytesRecv
		out["txBytes"] = counters[0].BytesSent
	}
	return out
}
// StartSpan is a unified entry point to start instrumentation.
//
// Usage patterns:
//   - StartSpan(operationName string, keyValues ...interface{})
//   - StartSpan(ctx context.Context, operationName string, keyValues ...interface{})
//   - StartSpan(c *gin.Context, operationName string, keyValues ...interface{})
//
// Logic:
//  1. If we find an existing span in ctx (e.g., from apmgin), start a nested span.
//  2. If no existing span is found, nest under the root DMS span.
//
// It returns a closure that ends the span; on invalid arguments it logs an
// error and returns a no-op closure.
func StartSpan(args ...interface{}) func() {
	noop := func() {}
	if len(args) == 0 {
		log.Error("StartSpan called without arguments")
		return noop
	}
	// The *gin.Context case must precede context.Context: *gin.Context also
	// satisfies context.Context, and type-switch cases match in order.
	switch first := args[0].(type) {
	case string:
		// No context provided: sanitize the operation name and use background.
		sanitized := strings.ReplaceAll(strings.ReplaceAll(first, ": /", ": "), "/", "_")
		return startSpan(context.Background(), sanitized, args[1:]...)
	case *gin.Context:
		if len(args) < 2 {
			log.Error("StartSpan called with *gin.Context but without operation name")
			return noop
		}
		name, ok := args[1].(string)
		if !ok {
			log.Error("Operation name must be a string when called with *gin.Context")
			return noop
		}
		return startSpan(first.Request.Context(), name, args[2:]...)
	case context.Context:
		if len(args) < 2 {
			log.Error("StartSpan called with context but without operation name")
			return noop
		}
		name, ok := args[1].(string)
		if !ok {
			log.Error("Operation name must be a string when called with context.Context")
			return noop
		}
		return startSpan(first, name, args[2:]...)
	default:
		log.Error("Unsupported first argument type for StartSpan")
		return noop
	}
}
// startSpan starts a "custom" APM span named operationName, nested under the
// span found in ctx or, failing that, under the global root span. keyValues
// are attached as span labels in (string key, value) pairs; odd trailing
// values are ignored. The returned closure ends the span and removes it from
// the active-span stack.
func startSpan(ctx context.Context, operationName string, keyValues ...interface{}) func() {
	tracerMutex.Lock()
	noOp := tracingNoOpMode
	tracer := currentTracer
	tracerMutex.Unlock()
	if IsNoOpMode() || noOp || tracer == nil {
		return func() {}
	}
	// get tx from ctx or fall back to root tx
	parent := apm.SpanFromContext(ctx)
	if parent == nil {
		ctx = apm.ContextWithSpan(ctx, rootSpan)
	}
	// create a new span
	span, _ := apm.StartSpan(ctx, operationName, "custom")
	if span.Dropped() {
		// TODO causes an inf loop
		// log.Warn("Span dropped: " + operationName)
		return func() {}
	}
	// Use the new span's own trace context for bookkeeping and logging:
	// when no parent was found in ctx, `parent` is nil and dereferencing it
	// (as the previous code did in the Debugw calls) would panic.
	traceCtx := span.TraceContext()
	// Mutate the shared active-span stack under the tracer lock; these
	// globals are reached from multiple goroutines.
	tracerMutex.Lock()
	activeSpans = append(activeSpans, span)
	latestSpanID = traceCtx.Span.String()
	tracerMutex.Unlock()
	// set up labels
	span.Context.SetLabel("did", didID.String())
	for i := 0; i+1 < len(keyValues); i += 2 {
		if key, ok := keyValues[i].(string); ok {
			span.Context.SetLabel(key, keyValues[i+1])
		}
	}
	startTime := time.Now()
	log.Debugw("Operation started inside existing transaction",
		"operation", operationName,
		"trace.id", traceCtx.Trace.String(),
		"transaction.id", traceCtx.Span.String())
	return func() {
		duration := time.Since(startTime)
		log.Debugw("Operation ended",
			"operation", operationName,
			"duration", duration,
			"trace.id", traceCtx.Trace.String(),
			"transaction.id", traceCtx.Span.String())
		span.End()
		// remove from active stack under the same lock that guards appends
		tracerMutex.Lock()
		activeSpans = slices.DeleteFunc(activeSpans, func(s *apm.Span) bool {
			return s == span
		})
		tracerMutex.Unlock()
	}
}
// shutdownTracer closes the current tracer after ending the root span and
// transaction and flushing any pending APM events. Safe to call repeatedly.
func shutdownTracer() {
	tracerMutex.Lock()
	defer tracerMutex.Unlock()
	if currentTracer == nil {
		return
	}
	rootSpan.End()
	rootTransaction.End()
	abort := make(chan struct{})
	currentTracer.Flush(abort)
	close(abort)
	currentTracer.Close()
	currentTracer = nil
}
// FlightrecInit sets up the flight recorder based on env vars.
// It is a one-shot no-op unless DMS_FLIGHTREC_SEC holds a positive integer.
func FlightrecInit() {
	flightrecHandleOnce.Do(func() {
		seconds, _ := strconv.Atoi(os.Getenv(EnvFlightrecSec))
		if seconds <= 0 {
			return
		}
		flightrec = trace.NewFlightRecorder(trace.FlightRecorderConfig{
			MinAge:   time.Duration(seconds) * time.Second,
			MaxBytes: 5 * types.MB,
		})
		if err := flightrec.Start(); err != nil {
			log.Errorw("flightrec_start", "error", err)
		}
	})
}
// FlightrecCapture captures a flight recorder snapshot into path/file.
// It is a one-shot operation and a no-op when the recorder was never started.
func FlightrecCapture(path, file string) {
	// once.Do ensures that the provided function is executed only once.
	flightrecCaptureOnce.Do(func() {
		target := filepath.Join(path, file)
		// Guard: FlightrecInit may never have created the recorder (env var
		// unset or non-positive); WriteTo on a nil recorder would panic.
		if flightrec == nil {
			log.Errorw("flightrec_capture_skipped", "file", target, "error", "flight recorder not initialized")
			return
		}
		f, err := os.Create(target)
		if err != nil {
			// Log the intended path, not f.Name(): f is nil on error and
			// calling Name() on it would panic.
			log.Errorw("opening_flightrec", "file", target, "error", err)
			return
		}
		defer f.Close() // ignore error
		// WriteTo writes the flight recorder data to the provided io.Writer.
		log.Infow("flightrec_capture", "file", f.Name())
		if _, err := flightrec.WriteTo(f); err != nil {
			log.Errorw("writing_flightrec", "file", f.Name(), "error", err)
			return
		}
		// Stop the flight recorder after the snapshot has been taken.
		flightrec.Stop()
		log.Infow("flightrec_captured", "file", f.Name())
	})
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package basiccontroller
import (
	"context"
	"errors"
	"os"
	"path/filepath"

	"github.com/spf13/afero"

	"gitlab.com/nunet/device-management-service/db/repositories"
	"gitlab.com/nunet/device-management-service/observability"
	"gitlab.com/nunet/device-management-service/storage"
	"gitlab.com/nunet/device-management-service/types"
	"gitlab.com/nunet/device-management-service/utils"
)
// BasicVolumeController is the default implementation of the VolumeController.
// It persists storage volumes information using the StorageVolume.
type BasicVolumeController struct {
	// repo is the repository for storage volume operations
	repo repositories.GenericRepository[types.StorageVolume]
	// basePath is the base path where volumes are stored under
	basePath string
	// file system to act upon; exported so callers (e.g. the s3 storage
	// provider) can share the exact same filesystem instance
	FS afero.Fs
}
// NewDefaultVolumeController returns a new instance of BasicVolumeController.
//
// The base path is normalized to end with a path separator, so that volume
// paths built by concatenating basePath with a directory name cannot
// silently fuse the base directory name with the volume name
// (resolves TODO-BugFix [path]).
func NewDefaultVolumeController(repo repositories.GenericRepository[types.StorageVolume], volBasePath string, fs afero.Fs) (*BasicVolumeController, error) {
	endSpan := observability.StartSpan(TraceVolumeControllerInitDuration)
	defer endSpan()
	// Guarantee a trailing separator on non-empty base paths.
	if volBasePath != "" && volBasePath[len(volBasePath)-1] != '/' {
		volBasePath += "/"
	}
	vc := &BasicVolumeController{
		repo:     repo,
		basePath: volBasePath,
		FS:       fs,
	}
	log.Infow(LogVolumeControllerInitSuccess, LogKeyBasePath, volBasePath)
	return vc, nil
}
// CreateVolume creates a new storage volume given a source (S3, IPFS, job, etc). The
// creation of a storage volume effectively creates an empty directory in the local filesystem
// and writes a record in the database.
//
// The directory name follows the format: `<volSource> + "-" + <name>`
// where `name` is random. The path is built with filepath.Join so a base
// path lacking a trailing separator still yields a correct volume path.
//
// TODO-maybe [withName]: allow callers to specify custom name for path
func (vc *BasicVolumeController) CreateVolume(volSource storage.VolumeSource, opts ...storage.CreateVolOpt) (types.StorageVolume, error) {
	endSpan := observability.StartSpan(TraceVolumeCreateDuration)
	defer endSpan()
	vol := types.StorageVolume{
		Private:        false,
		ReadOnly:       false,
		EncryptionType: types.EncryptionTypeNull,
	}
	// Apply caller-supplied options (privacy, etc.) before persisting.
	for _, opt := range opts {
		opt(&vol)
	}
	randomStr, err := utils.RandomString(16)
	if err != nil {
		log.Errorw(LogVolumeCreateFailure, LogKeyError, err)
		return types.StorageVolume{}, err
	}
	// filepath.Join normalizes separators, fixing the case where basePath
	// has no trailing slash (raw concatenation fused the two names).
	vol.Path = filepath.Join(vc.basePath, string(volSource)+"-"+randomStr)
	if err := vc.FS.Mkdir(vol.Path, os.ModePerm); err != nil {
		log.Errorw(LogVolumeCreateFailure, LogKeyPath, vol.Path, LogKeyError, err)
		return types.StorageVolume{}, err
	}
	createdVol, err := vc.repo.Create(context.TODO(), vol)
	if err != nil {
		log.Errorw(LogVolumeCreateFailure, LogKeyPath, vol.Path, LogKeyError, err)
		return types.StorageVolume{}, err
	}
	log.Infow(LogVolumeCreateSuccess, LogKeyVolumeID, createdVol.ID, LogKeyPath, vol.Path)
	return createdVol, nil
}
// LockVolume makes the volume read-only, not only changing the field value but also changing file permissions.
// It should be used after all necessary data has been written.
// It optionally can also set the CID and mark the volume as private.
//
// TODO-maybe [CID]: maybe calculate CID of every volume in case WithCID opt is not provided
func (vc *BasicVolumeController) LockVolume(pathToVol string, opts ...storage.LockVolOpt) error {
	endSpan := observability.StartSpan(TraceVolumeLockDuration)
	defer endSpan()
	// Look up the volume record by its path.
	q := vc.repo.GetQuery()
	q.Conditions = append(q.Conditions, repositories.EQ("Path", pathToVol))
	vol, err := vc.repo.Find(context.TODO(), q)
	if err != nil {
		log.Errorw(LogVolumeLockFailure, LogKeyPath, pathToVol, LogKeyError, err)
		return err
	}
	// Apply caller options (CID, privacy) before flipping the read-only flag.
	for _, applyOpt := range opts {
		applyOpt(&vol)
	}
	vol.ReadOnly = true
	updatedVol, err := vc.repo.Update(context.TODO(), vol.ID, vol)
	if err != nil {
		log.Errorw(LogVolumeLockFailure, LogKeyVolumeID, vol.ID, LogKeyError, err)
		return err
	}
	// Mirror the read-only state on disk (owner read permission only).
	if err := vc.FS.Chmod(updatedVol.Path, 0o400); err != nil {
		log.Errorw(LogVolumeLockFailure, LogKeyPath, updatedVol.Path, LogKeyError, err)
		return err
	}
	log.Infow(LogVolumeLockSuccess, LogKeyVolumeID, updatedVol.ID, LogKeyPath, updatedVol.Path)
	return nil
}
// WithPrivate designates a given volume as private. It can be used both
// when creating or locking a volume; the type parameter selects which
// option type (CreateVolOpt or LockVolOpt) is produced.
func WithPrivate[T storage.CreateVolOpt | storage.LockVolOpt]() T {
	return func(v *types.StorageVolume) {
		v.Private = true
	}
}
// WithCID sets the CID of a given volume if already calculated.
// The CID is stored as-is on the volume record.
//
// TODO [validate]: check if CID provided is valid
func WithCID(cid string) storage.LockVolOpt {
	return func(v *types.StorageVolume) {
		v.CID = cid
	}
}
// DeleteVolume deletes a given storage volume record from the database and removes the corresponding directory.
// Identifier is either a CID or a path of a volume.
//
// Note [CID]: if we start to type CID as cid.CID, we may have to use generics here
// as in `[T string | cid.CID]`
func (vc *BasicVolumeController) DeleteVolume(identifier string, idType storage.IDType) error {
	endSpan := observability.StartSpan(TraceVolumeDeleteDuration)
	defer endSpan()
	query := vc.repo.GetQuery()
	switch idType {
	case storage.IDTypePath:
		query.Conditions = append(query.Conditions, repositories.EQ("Path", identifier))
	case storage.IDTypeCID:
		query.Conditions = append(query.Conditions, repositories.EQ("CID", identifier))
	default:
		log.Errorw(LogVolumeDeleteFailure, LogKeyIdentifier, identifier, LogKeyError, ErrMsgInvalidIdentifier)
		return ErrInvalidIdentifier
	}
	vol, err := vc.repo.Find(context.TODO(), query)
	if err != nil {
		// errors.Is instead of == so wrapped not-found errors are recognized.
		if errors.Is(err, repositories.ErrNotFound) {
			log.Errorw(LogVolumeDeleteFailure, LogKeyIdentifier, identifier, LogKeyError, ErrMsgVolumeNotFound)
			return repositories.ErrNotFound
		}
		log.Errorw(LogVolumeDeleteFailure, LogKeyIdentifier, identifier, LogKeyError, err)
		return err
	}
	// Remove the directory first, then the database record.
	if err := vc.FS.RemoveAll(vol.Path); err != nil {
		log.Errorw(LogVolumeDeleteFailure, LogKeyPath, vol.Path, LogKeyError, err)
		return err
	}
	if err := vc.repo.Delete(context.TODO(), vol.ID); err != nil {
		log.Errorw(LogVolumeDeleteFailure, LogKeyVolumeID, vol.ID, LogKeyError, err)
		return err
	}
	log.Infow(LogVolumeDeleteSuccess, LogKeyVolumeID, vol.ID, LogKeyPath, vol.Path)
	return nil
}
// ListVolumes returns a list of all storage volumes stored on the database
//
// TODO [filter]: maybe add opts to filter results by certain values
func (vc *BasicVolumeController) ListVolumes() ([]types.StorageVolume, error) {
	endSpan := observability.StartSpan(TraceVolumeListDuration)
	defer endSpan()
	vols, err := vc.repo.FindAll(context.TODO(), vc.repo.GetQuery())
	if err != nil {
		log.Errorw(LogVolumeListFailure, LogKeyError, err)
		return nil, err
	}
	log.Infow(LogVolumeListSuccess, LogKeyVolumeCount, len(vols))
	return vols, nil
}
// GetSize returns the size of a volume, looked up by path or CID.
// TODO-minor: identify which measurement type will be used
func (vc *BasicVolumeController) GetSize(identifier string, idType storage.IDType) (int64, error) {
	endSpan := observability.StartSpan(TraceVolumeGetSizeDuration)
	defer endSpan()
	q := vc.repo.GetQuery()
	switch idType {
	case storage.IDTypePath:
		q.Conditions = append(q.Conditions, repositories.EQ("Path", identifier))
	case storage.IDTypeCID:
		q.Conditions = append(q.Conditions, repositories.EQ("CID", identifier))
	default:
		log.Errorw(LogVolumeGetSizeFailure, LogKeyIdentifier, identifier, LogKeyError, ErrMsgInvalidIdentifier)
		return 0, ErrInvalidIdentifier
	}
	vol, err := vc.repo.Find(context.TODO(), q)
	if err != nil {
		log.Errorw(LogVolumeGetSizeFailure, LogKeyIdentifier, identifier, LogKeyError, err)
		return 0, err
	}
	// Walk the volume directory on the controller's filesystem.
	dirSize, err := utils.GetDirectorySize(vc.FS, vol.Path)
	if err != nil {
		log.Errorw(LogVolumeGetSizeFailure, LogKeyPath, vol.Path, LogKeyError, err)
		return 0, err
	}
	log.Infow(LogVolumeGetSizeSuccess, LogKeyVolumeID, vol.ID, LogKeySize, dirSize)
	return dirSize, nil
}
// EncryptVolume encrypts a given volume.
// Not yet implemented: it logs the attempt and always returns ErrNotImplemented.
func (vc *BasicVolumeController) EncryptVolume(path string, _ types.Encryptor, _ types.EncryptionType) error {
	endSpan := observability.StartSpan(TraceVolumeEncryptDuration)
	defer endSpan()
	log.Errorw(LogVolumeEncryptNotImplemented, LogKeyPath, path)
	return ErrNotImplemented
}
// DecryptVolume decrypts a given volume.
// Not yet implemented: it logs the attempt and always returns ErrNotImplemented.
func (vc *BasicVolumeController) DecryptVolume(path string, _ types.Decryptor, _ types.EncryptionType) error {
	endSpan := observability.StartSpan(TraceVolumeDecryptDuration)
	defer endSpan()
	log.Errorw(LogVolumeDecryptNotImplemented, LogKeyPath, path)
	return ErrNotImplemented
}
// Compile-time check that BasicVolumeController satisfies storage.VolumeController.
var _ storage.VolumeController = (*BasicVolumeController)(nil)
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package basiccontroller
import (
"context"
"fmt"
"github.com/spf13/afero"
cloverDB "gitlab.com/nunet/device-management-service/db/clover"
"gitlab.com/nunet/device-management-service/types"
)
// VolumeControllerTestKit bundles a ready-to-use volume controller together
// with the in-memory filesystem and the seed volumes it was created with.
type VolumeControllerTestKit struct {
	BasicVolController *BasicVolumeController          // controller under test
	Fs                 afero.Fs                        // in-memory FS shared with the controller
	Volumes            map[string]*types.StorageVolume // seed volumes persisted at setup
}
// SetupVolumeControllerTestKit sets up a volume controller with 0-n volumes given a base path.
// If volumes are inputed, directories will be created and volumes will be stored in the database.
//
// Note: the s3 tests also call this helper, so it must be fully
// self-contained (in-memory DB and filesystem, no external state).
func SetupVolumeControllerTestKit(basePath string, volumes map[string]*types.StorageVolume) (*VolumeControllerTestKit, error) {
	// In-memory clover DB holding the storage_volume collection.
	memDB, err := cloverDB.NewMemDB([]string{"storage_volume"})
	if err != nil {
		return nil, fmt.Errorf("failed to create in-memory mock database: %w", err)
	}
	// In-memory filesystem rooted at basePath.
	memFs := afero.NewMemMapFs()
	if err := memFs.MkdirAll(basePath, 0o755); err != nil {
		return nil, fmt.Errorf("failed to create base path: %w", err)
	}
	volRepo := cloverDB.NewGenericRepository[types.StorageVolume](memDB)
	controller, err := NewDefaultVolumeController(volRepo, basePath, memFs)
	if err != nil {
		return nil, fmt.Errorf("failed to create volume controller: %w", err)
	}
	// Seed the requested volumes: a directory plus a DB record for each.
	for _, vol := range volumes {
		if err := memFs.MkdirAll(vol.Path, 0o755); err != nil {
			return nil, fmt.Errorf("failed to create volume dir: %w", err)
		}
		if _, err := volRepo.Create(context.Background(), *vol); err != nil {
			return nil, fmt.Errorf("failed to create volume record: %w", err)
		}
	}
	return &VolumeControllerTestKit{
		BasicVolController: controller,
		Fs:                 memFs,
		Volumes:            volumes,
	}, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package s3
import (
"context"
"fmt"
"os"
"path/filepath"
"strings"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/spf13/afero"
"gitlab.com/nunet/device-management-service/observability"
"gitlab.com/nunet/device-management-service/storage"
basicController "gitlab.com/nunet/device-management-service/storage/basic_controller"
"gitlab.com/nunet/device-management-service/types"
)
// Download fetches files from a given S3 bucket. The key may be a directory ending
// with `/` or have a wildcard (`*`) so it handles normal S3 folders but it does
// not handle x-directory.
//
// Warning: the implementation should rely on the FS provided by the volume controller,
// be careful if managing files with `os` (the volume controller might be
// using an in-memory one)
//
// All errors are wrapped with %w so callers can inspect the underlying
// cause with errors.Is / errors.As; the messages are unchanged.
func (s *Storage) Download(ctx context.Context, sourceSpecs *types.SpecConfig) (types.StorageVolume, error) {
	endSpan := observability.StartSpan(ctx, "s3_download")
	defer endSpan()
	source, err := DecodeInputSpec(sourceSpecs)
	if err != nil {
		log.Errorw("s3_download_failure", "error", err)
		return types.StorageVolume{}, err
	}
	storageVol, err := s.volController.CreateVolume(storage.VolumeSourceS3)
	if err != nil {
		err = fmt.Errorf("failed to create storage volume: %w", err)
		log.Errorw("s3_volume_create_failure", "error", err)
		return types.StorageVolume{}, err
	}
	resolvedObjects, err := resolveStorageKey(ctx, s.Client, &source)
	if err != nil {
		err = fmt.Errorf("failed to resolve storage key: %w", err)
		log.Errorw("s3_resolve_key_failure", "error", err)
		return types.StorageVolume{}, err
	}
	for _, resolvedObject := range resolvedObjects {
		if err := s.downloadObject(ctx, &source, resolvedObject, storageVol.Path); err != nil {
			err = fmt.Errorf("failed to download s3 object: %w", err)
			log.Errorw("s3_download_object_failure", "error", err)
			return types.StorageVolume{}, err
		}
	}
	// after data is filled within the volume, we have to lock it
	if err := s.volController.LockVolume(storageVol.Path); err != nil {
		err = fmt.Errorf("failed to lock storage volume: %w", err)
		log.Errorw("s3_volume_lock_failure", "error", err)
		return types.StorageVolume{}, err
	}
	log.Infow("s3_download_success", "volumeID", storageVol.ID, "path", storageVol.Path)
	return storageVol, nil
}
// downloadObject materializes a single resolved S3 object below volPath.
// Directory objects only get their local directory created; file objects are
// streamed through the SDK download manager into the controller's file system.
func (s *Storage) downloadObject(ctx context.Context, source *InputSource, object s3Object, volPath string) error {
	endSpan := observability.StartSpan(ctx, "s3_download_object")
	defer endSpan()

	outputPath := filepath.Join(volPath, *object.key)

	// use the same file system instance used by the Volume Controller
	var fs afero.Fs
	if basicVolController, ok := s.volController.(*basicController.BasicVolumeController); ok {
		fs = basicVolController.FS
	}
	// BUGFIX: fs was used unconditionally; with a controller that is not a
	// BasicVolumeController it stayed nil and the first call panicked.
	if fs == nil {
		return fmt.Errorf("unsupported volume controller: no file system available")
	}

	if object.isDir {
		// if object is a directory, we don't need to download it (just create the dir)
		if err := fs.MkdirAll(outputPath, 0o755); err != nil {
			log.Errorw("s3_create_directory_failure", "path", outputPath, "error", fmt.Errorf("failed to create directory: %v", err))
			return fmt.Errorf("failed to create directory: %v", err)
		}
		return nil
	}

	// BUGFIX: for file objects the PARENT directory must be created. The previous
	// code ran MkdirAll(outputPath) itself, creating a directory exactly where the
	// file belongs and making the OpenFile below fail.
	if err := fs.MkdirAll(filepath.Dir(outputPath), 0o755); err != nil {
		log.Errorw("s3_create_directory_failure", "path", outputPath, "error", fmt.Errorf("failed to create directory: %v", err))
		return fmt.Errorf("failed to create directory: %v", err)
	}

	outputFile, err := fs.OpenFile(outputPath, os.O_RDWR|os.O_CREATE, 0o755)
	if err != nil {
		log.Errorw("s3_open_file_failure", "path", outputPath, "error", err)
		return err
	}
	defer outputFile.Close()

	log.Debugw("Downloading s3 object", "objectKey", *object.key, "outputPath", outputPath)
	_, err = s.downloader.Download(ctx, outputFile, &s3.GetObjectInput{
		Bucket:  aws.String(source.Bucket),
		Key:     object.key,
		IfMatch: object.eTag, // guard against the object changing since resolution
	})
	if err != nil {
		log.Errorw("s3_download_failure", "objectKey", *object.key, "error", fmt.Errorf("failed to download file: %w", err))
		return fmt.Errorf("failed to download file: %w", err)
	}

	log.Infow("s3_download_object_success", "objectKey", *object.key)
	return nil
}
// resolveStorageKey returns a list of s3 objects within a bucket according to the key provided.
func resolveStorageKey(ctx context.Context, client *s3.Client, source *InputSource) ([]s3Object, error) {
	switch key := source.Key; {
	case key == "":
		err := fmt.Errorf("key is required")
		log.Errorw("s3_resolve_key_failure", "error", err)
		return nil, err
	case strings.HasSuffix(key, "/") || strings.Contains(key, "*"):
		// a trailing slash or wildcard denotes multiple objects
		return resolveObjectsWithPrefix(ctx, client, source)
	default:
		// a plain key denotes a single object
		return resolveSingleObject(ctx, client, source)
	}
}
// resolveSingleObject resolves a plain key (no trailing slash, no wildcard) to
// a single S3 object via a HeadObject request.
func resolveSingleObject(ctx context.Context, client *s3.Client, source *InputSource) ([]s3Object, error) {
	key := sanitizeKey(source.Key)
	headObjectInput := &s3.HeadObjectInput{
		Bucket: aws.String(source.Bucket),
		Key:    aws.String(key),
	}

	headObjectOut, err := client.HeadObject(ctx, headObjectInput)
	if err != nil {
		log.Errorw("s3_head_object_failure", "key", key, "error", err)
		return []s3Object{}, fmt.Errorf("failed to retrieve object metadata: %v", err)
	}

	// BUGFIX: ContentType is a *string in the SDK response and may be nil;
	// the previous code dereferenced it unconditionally and could panic.
	if headObjectOut.ContentType != nil && strings.HasPrefix(*headObjectOut.ContentType, "application/x-directory") {
		err := fmt.Errorf("x-directory is not yet handled")
		log.Errorw("s3_directory_handling_failure", "key", key, "error", err)
		return []s3Object{}, err
	}

	// ContentLength is likewise a *int64; default to zero when absent.
	var size int64
	if headObjectOut.ContentLength != nil {
		size = *headObjectOut.ContentLength
	}

	return []s3Object{
		{
			key:  aws.String(source.Key),
			eTag: headObjectOut.ETag,
			size: size,
		},
	}, nil
}
// resolveObjectsWithPrefix lists every object sharing the (sanitized) key prefix,
// paginating through the whole result set. Keys ending in "/" are flagged as
// directories so the downloader only creates them locally.
func resolveObjectsWithPrefix(ctx context.Context, client *s3.Client, source *InputSource) ([]s3Object, error) {
	key := sanitizeKey(source.Key)

	// List objects with the given prefix
	listObjectsInput := &s3.ListObjectsV2Input{
		Bucket: aws.String(source.Bucket),
		Prefix: aws.String(key),
	}

	var objects []s3Object
	paginator := s3.NewListObjectsV2Paginator(client, listObjectsInput)
	for paginator.HasMorePages() {
		page, err := paginator.NextPage(ctx)
		if err != nil {
			log.Errorw("s3_list_objects_failure", "error", fmt.Errorf("failed to list objects: %v", err))
			return nil, fmt.Errorf("failed to list objects: %v", err)
		}
		for _, obj := range page.Contents {
			// BUGFIX: Key and Size are pointers in the SDK types; guard before
			// dereferencing (the old code did aws.String(*obj.Key) / *obj.Size).
			if obj.Key == nil {
				continue
			}
			var size int64
			if obj.Size != nil {
				size = *obj.Size
			}
			objects = append(objects, s3Object{
				key:   obj.Key,
				size:  size,
				isDir: strings.HasSuffix(*obj.Key, "/"),
			})
		}
	}

	log.Infow("s3_resolve_objects_with_prefix_success", "objectCount", len(objects))
	return objects, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
// s3/aws_config.go
package s3
import (
"context"
"strings"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"gitlab.com/nunet/device-management-service/observability"
)
// GetAWSDefaultConfig returns the default AWS config based on environment variables,
// shared configuration and shared credentials files.
func GetAWSDefaultConfig() (aws.Config, error) {
	endSpan := observability.StartSpan("get_aws_default_config")
	defer endSpan()

	// No extra load options: rely entirely on the SDK's default resolution chain.
	cfg, err := config.LoadDefaultConfig(context.Background())
	if err != nil {
		log.Errorw("get_aws_default_config_failure", "error", err)
		return aws.Config{}, err
	}

	log.Infow("get_aws_default_config_success")
	return cfg, nil
}
// hasValidCredentials checks if the provided AWS config has valid credentials.
func hasValidCredentials(config aws.Config) bool {
	endSpan := observability.StartSpan("has_valid_credentials")
	defer endSpan()

	creds, err := config.Credentials.Retrieve(context.Background())
	switch {
	case err != nil:
		log.Errorw("has_valid_credentials_failure", "error", err)
		return false
	case !creds.HasKeys():
		// retrieval succeeded but no access/secret key pair is present
		log.Errorw("has_valid_credentials_failure_no_keys")
		return false
	}

	log.Infow("has_valid_credentials_success")
	return true
}
// sanitizeKey removes trailing spaces and wildcards
func sanitizeKey(key string) string {
	endSpan := observability.StartSpan("sanitize_key")
	defer endSpan()

	trimmed := strings.TrimSpace(key)
	trimmed = strings.TrimSuffix(trimmed, "*")
	log.Infow("sanitize_key_success", "sanitizedKey", trimmed)
	return trimmed
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
// s3/client.go
package s3
import (
"context"
"fmt"
"github.com/aws/aws-sdk-go-v2/aws"
s3Manager "github.com/aws/aws-sdk-go-v2/feature/s3/manager"
"github.com/aws/aws-sdk-go-v2/service/s3"
"gitlab.com/nunet/device-management-service/observability"
"gitlab.com/nunet/device-management-service/storage"
"gitlab.com/nunet/device-management-service/types"
)
// Storage is an S3-backed storage provider. It embeds the AWS SDK S3 client
// and uses a VolumeController to create/lock the local volumes it fills.
type Storage struct {
	*s3.Client
	// volController manages the local volumes data is downloaded into / uploaded from.
	volController storage.VolumeController
	// downloader/uploader are the SDK transfer managers built from the embedded client.
	downloader *s3Manager.Downloader
	uploader   *s3Manager.Uploader
}
// s3Object is a resolved S3 listing entry as used by the downloader.
type s3Object struct {
	// key is the full object key within the bucket.
	key *string
	// eTag, when set, is passed as If-Match so the download fails if the object changed.
	eTag *string
	// size is the object size in bytes.
	size int64
	// isDir marks keys ending in "/" (folder placeholders) that are only created locally.
	isDir bool
}
// NewClient creates a new S3Storage which includes a S3-SDK client.
// It depends on a VolumeController to manage the volumes being acted upon.
// It fails fast when the supplied AWS config has no usable credentials.
func NewClient(config aws.Config, volController storage.VolumeController) (*Storage, error) {
	endSpan := observability.StartSpan("new_client")
	defer endSpan()

	if !hasValidCredentials(config) {
		err := fmt.Errorf("invalid credentials")
		log.Errorw("new_client_invalid_credentials", "error", err)
		return nil, err
	}

	s3Client := s3.NewFromConfig(config)
	// Named fields for clarity, and a local name that does not shadow the
	// imported `storage` package (the previous local was called `storage`).
	st := &Storage{
		Client:        s3Client,
		volController: volController,
		downloader:    s3Manager.NewDownloader(s3Client),
		uploader:      s3Manager.NewUploader(s3Client),
	}
	log.Infow("new_client_success")
	return st, nil
}
// Size calculates the size of a given object in S3 (in bytes) via HeadObject.
func (s *Storage) Size(ctx context.Context, source *types.SpecConfig) (uint64, error) {
	endSpan := observability.StartSpan(ctx, "s3_size")
	defer endSpan()

	inputSource, err := DecodeInputSpec(source)
	if err != nil {
		log.Errorw("s3_size_decode_input_spec_failure", "error", err)
		return 0, fmt.Errorf("failed to decode input spec: %v", err)
	}

	input := &s3.HeadObjectInput{
		Bucket: aws.String(inputSource.Bucket),
		Key:    aws.String(inputSource.Key),
	}
	output, err := s.HeadObject(ctx, input)
	if err != nil {
		log.Errorw("s3_size_head_object_failure", "error", err)
		return 0, fmt.Errorf("failed to get object size: %v", err)
	}

	// BUGFIX: ContentLength is a *int64 and may be nil; dereferencing blindly
	// could panic. Also reject negatives before the uint64 conversion.
	if output.ContentLength == nil || *output.ContentLength < 0 {
		err := fmt.Errorf("failed to get object size: content length unavailable")
		log.Errorw("s3_size_head_object_failure", "error", err)
		return 0, err
	}

	log.Infow("s3_size_success", "bucket", inputSource.Bucket, "key", inputSource.Key, "size", *output.ContentLength)
	return uint64(*output.ContentLength), nil
}
// Compile time interface check
// var _ storage.StorageProvider = (*S3Storage)(nil)
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
// s3/input_source.go
package s3
import (
"fmt"
"github.com/fatih/structs"
"github.com/go-viper/mapstructure/v2"
"gitlab.com/nunet/device-management-service/observability"
"gitlab.com/nunet/device-management-service/types"
)
// InputSource describes an S3 location used as a storage source or target.
type InputSource struct {
	// Bucket is the S3 bucket name (required).
	Bucket string
	// Key is the object key; may end in "/" or contain "*" to denote multiple objects.
	Key string
	// Filter is an additional selection expression.
	// NOTE(review): not referenced by the code visible here — confirm its consumer.
	Filter string
	// Region is the AWS region of the bucket.
	Region string
	// Endpoint optionally overrides the S3 endpoint (e.g. for S3-compatible stores).
	Endpoint string
}
// Validate checks that the input source carries the minimum required fields.
func (s InputSource) Validate() error {
	if s.Bucket != "" {
		return nil
	}
	err := fmt.Errorf("invalid s3 storage params: bucket cannot be empty")
	log.Errorw("s3_input_source_validation_failure", "error", err)
	return err
}
// ToMap converts the input source into a generic string-keyed map, suitable
// for embedding in a types.SpecConfig Params field.
func (s InputSource) ToMap() map[string]interface{} {
	return structs.Map(s)
}
// DecodeInputSpec turns a generic SpecConfig into a typed S3 InputSource,
// checking the spec type, decoding its params and validating the result.
func DecodeInputSpec(spec *types.SpecConfig) (InputSource, error) {
	endSpan := observability.StartSpan("decode_input_spec")
	defer endSpan()

	if !spec.IsType(types.StorageProviderS3) {
		err := fmt.Errorf("invalid storage source type. Expected %s but received %s", types.StorageProviderS3, spec.Type)
		log.Errorw("decode_input_spec_invalid_type_failure", "error", err)
		return InputSource{}, err
	}
	if spec.Params == nil {
		err := fmt.Errorf("invalid storage input source params. cannot be nil")
		log.Errorw("decode_input_spec_nil_params_failure", "error", err)
		return InputSource{}, err
	}

	var src InputSource
	if err := mapstructure.Decode(spec.Params, &src); err != nil {
		log.Errorw("decode_input_spec_decode_failure", "error", err)
		return src, err
	}
	if err := src.Validate(); err != nil {
		log.Errorw("decode_input_spec_validation_failure", "error", err)
		return src, err
	}

	log.Infow("decode_input_spec_success", "bucket", src.Bucket)
	return src, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
// s3/upload.go
package s3
import (
"context"
"fmt"
"os"
"path/filepath"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/spf13/afero"
"gitlab.com/nunet/device-management-service/observability"
basicController "gitlab.com/nunet/device-management-service/storage/basic_controller"
"gitlab.com/nunet/device-management-service/types"
)
// Upload uploads all files (recursively) from a local volume to an S3 bucket.
// It handles directories.
//
// Warning: the implementation should rely on the FS provided by the volume controller,
// be careful if managing files with `os` (the volume controller might be
// using an in-memory one)
func (s *Storage) Upload(ctx context.Context, vol types.StorageVolume, destinationSpecs *types.SpecConfig) error {
	endSpan := observability.StartSpan(ctx, "s3_upload")
	defer endSpan()

	target, err := DecodeInputSpec(destinationSpecs)
	if err != nil {
		log.Errorw("s3_upload_decode_spec_failure", "error", err)
		return fmt.Errorf("failed to decode input spec: %v", err)
	}

	sanitizedKey := sanitizeKey(target.Key)

	// set file system to act upon based on the volume controller implementation
	var fs afero.Fs
	if basicVolController, ok := s.volController.(*basicController.BasicVolumeController); ok {
		fs = basicVolController.FS
	}
	// BUGFIX: fs was used unconditionally; with a controller that is not a
	// BasicVolumeController it stayed nil and afero.Walk panicked.
	if fs == nil {
		err := fmt.Errorf("unsupported volume controller: no file system available")
		log.Errorw("s3_upload_failure", "error", err)
		return err
	}

	log.Debugw("Uploading files", "sourcePath", vol.Path, "bucket", target.Bucket, "key", sanitizedKey)
	err = afero.Walk(fs, vol.Path, func(filePath string, info os.FileInfo, err error) error {
		if err != nil {
			log.Errorw("s3_upload_walk_failure", "error", err)
			return err
		}
		// Skip directories
		if info.IsDir() {
			return nil
		}

		relPath, err := filepath.Rel(vol.Path, filePath)
		if err != nil {
			log.Errorw("s3_upload_relative_path_failure", "error", err)
			return fmt.Errorf("failed to get relative path: %v", err)
		}

		// Construct the S3 key by joining the sanitized key and the relative path.
		// NOTE(review): filepath.Join uses the OS separator; on Windows this would
		// put `\` into S3 keys — confirm target platforms.
		s3Key := filepath.Join(sanitizedKey, relPath)

		file, err := fs.Open(filePath)
		if err != nil {
			log.Errorw("s3_upload_open_file_failure", "filePath", filePath, "error", err)
			return fmt.Errorf("failed to open file: %v", err)
		}
		defer file.Close()

		log.Debugw("Uploading file to S3", "filePath", filePath, "bucket", target.Bucket, "key", s3Key)
		_, err = s.uploader.Upload(ctx, &s3.PutObjectInput{
			Bucket: aws.String(target.Bucket),
			Key:    aws.String(s3Key),
			Body:   file,
		})
		if err != nil {
			log.Errorw("s3_upload_put_object_failure", "filePath", filePath, "error", err)
			return fmt.Errorf("failed to upload file to S3: %v", err)
		}
		log.Infow("s3_upload_file_success", "filePath", filePath, "bucket", target.Bucket, "key", s3Key)
		return nil
	})
	if err != nil {
		log.Errorw("s3_upload_failure", "error", err)
		return fmt.Errorf("upload failed. It's possible that some files were uploaded; Error: %v", err)
	}

	log.Infow("s3_upload_success", "sourcePath", vol.Path, "bucket", target.Bucket)
	return nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package controller
import (
"crypto/sha256"
"crypto/x509"
"encoding/hex"
"encoding/pem"
"errors"
"fmt"
"os"
"path/filepath"
"strings"
logging "github.com/ipfs/go-log/v2"
dmscrypto "gitlab.com/nunet/device-management-service/lib/crypto"
"gitlab.com/nunet/device-management-service/utils/sys"
)
// Package-level logger.
// NOTE(review): the logger is named "actor" but this is the storage `controller`
// package — looks like a copy-paste from the actor package; confirm the intended name.
var log = logging.Logger("actor")

// Compile-time check that GlusterController satisfies GlusterControllerInterface.
var _ GlusterControllerInterface = (*GlusterController)(nil)
// GlusterControllerInterface defines the contract for a GlusterFS controller.
type GlusterControllerInterface interface {
	// CreateVolume creates a volume backed by a freshly named brick, registers
	// clientPEM, and returns the client CA bundle (PEM) needed to connect over TLS.
	CreateVolume(volName, clientPEM string) (string, error)
	// StartVolume force-starts the named volume.
	StartVolume(volName string) error
	// DeleteVolume stops and then deletes the named volume.
	DeleteVolume(volName string) error
	// CheckServer returns an error if the glusterd daemon is unresponsive.
	CheckServer() error
	// IsServerWorking reports whether CheckServer succeeds.
	IsServerWorking() bool
}
// fuseModel is the kernel module required for mounting FUSE file systems.
const fuseModel = "fuse"
// GlusterController is responsible for managing GlusterFS volumes.
type GlusterController struct {
	// glusterfsServerHostname is the host used when composing brick addresses (host:path).
	glusterfsServerHostname string
	// bricksDir is the directory under which per-volume brick directories are created.
	bricksDir string
	// caAuthority is the root directory holding glusterfs_nodes/, clients/ and the generated CA bundles.
	caAuthority string
}
// NewGlusterController creates a new instance of GlusterController.
// caDir must be non-empty and its glusterfs_nodes/ subdirectory must contain at
// least one server certificate (.pem). Loading the fuse module is best-effort.
func NewGlusterController(glusterfsServerHostname, bricksDir, caDir string) (*GlusterController, error) {
	if caDir == "" {
		return nil, errors.New("glusterfs CA directory is empty")
	}

	// Best effort: a missing fuse module is only a warning at construction time.
	if !isModuleLoaded(fuseModel) {
		if err := loadModule(fuseModel); err != nil {
			log.Warnf("failed to load fuse kernel module: %v", err)
		}
	}

	controller := &GlusterController{
		glusterfsServerHostname: glusterfsServerHostname,
		bricksDir:               bricksDir,
		caAuthority:             caDir,
	}
	controller.ensureDirectories()

	// check if the glusterfs_nodes contains a list of server certificates
	empty, err := isPEMDirectoryEmpty(filepath.Join(controller.caAuthority, "glusterfs_nodes"))
	if err != nil {
		return nil, fmt.Errorf("failed to check glusterfs_nodes for server certificates: %w", err)
	}
	if empty {
		return nil, errors.New("glusterfs_nodes must contain server certificates")
	}
	return controller, nil
}
// ensureDirectories creates the CA subdirectories and the bricks directory if
// missing. Failures are logged, not fatal: later operations that need the
// directories surface their own errors.
func (gc *GlusterController) ensureDirectories() {
	dirs := []string{
		filepath.Join(gc.caAuthority, "glusterfs_nodes"),
		filepath.Join(gc.caAuthority, "clients"),
		gc.bricksDir,
	}
	for _, dir := range dirs {
		if err := os.MkdirAll(dir, 0o755); err != nil {
			// BUGFIX: %w is only meaningful inside fmt.Errorf; with the logger's
			// Errorf it rendered as "%!w(...)". Use %v for plain formatting.
			log.Errorf("failed to create directory %s: %v", dir, err)
		}
	}
}
// createCACerts concatenates every .pem file found in the given folders into a
// single CA bundle named `output` inside the CA authority directory.
func (gc *GlusterController) createCACerts(folders []string, output string) error {
	caFilePath := filepath.Join(gc.caAuthority, output)

	var caContent strings.Builder
	for _, folder := range folders {
		files, err := os.ReadDir(folder)
		if err != nil {
			return fmt.Errorf("failed to read directory %s: %w", folder, err)
		}
		for _, file := range files {
			if filepath.Ext(file.Name()) != ".pem" {
				continue
			}
			filePath := filepath.Join(folder, file.Name())
			content, err := os.ReadFile(filePath)
			if err != nil {
				return fmt.Errorf("failed to read file %s: %w", filePath, err)
			}
			caContent.Write(content)
			// Robustness: guarantee a newline between concatenated PEM files so a
			// file missing its trailing newline cannot corrupt the next block.
			if len(content) > 0 && content[len(content)-1] != '\n' {
				caContent.WriteByte('\n')
			}
		}
	}

	if err := os.WriteFile(caFilePath, []byte(caContent.String()), 0o644); err != nil {
		return fmt.Errorf("failed to write CA file %s: %w", caFilePath, err)
	}
	return nil
}
// generateGlusterFSServerCA concatenates all .pem files in glusterfs_nodes and clients folders into a single file "glusterfs.ca"
func (gc *GlusterController) generateGlusterFSServerCA() error {
	return gc.createCACerts([]string{
		filepath.Join(gc.caAuthority, "glusterfs_nodes"),
		filepath.Join(gc.caAuthority, "clients"),
	}, "glusterfs.ca")
}
// generateGlusterFSClientCA concatenates all .pem files in glusterfs_nodes folders into a single file "glusterfs-client.ca"
func (gc *GlusterController) generateGlusterFSClientCA() error {
	nodesDir := filepath.Join(gc.caAuthority, "glusterfs_nodes")
	return gc.createCACerts([]string{nodesDir}, "glusterfs-client.ca")
}
// enableTLS enables tls for the volume.
func (gc *GlusterController) enableTLS(volName string) error {
	// Both the server and client SSL options must be switched on.
	for _, opt := range []string{"server.ssl", "client.ssl"} {
		cmd := []string{"volume", "set", volName, opt, "on"}
		out, err := sys.ExecCommand("gluster", cmd...).CombinedOutput()
		if err != nil {
			return fmt.Errorf("failed to set TLS option %v for volume %s: %v, output: %s", cmd, volName, err, string(out))
		}
	}
	return nil
}
// CreateVolume creates a new GlusterFS volume backed by a randomly named brick,
// registers the client certificate, regenerates both CA bundles, enables TLS,
// and returns the client CA bundle (PEM) callers need in order to connect.
func (gc *GlusterController) CreateVolume(volName string, clientPem string) (string, error) {
	// Validate and persist the client certificate before touching gluster.
	if err := validatePEM([]byte(clientPem)); err != nil {
		return "", fmt.Errorf("failed to validate pem: %w", err)
	}
	if _, err := gc.saveHashedContent(clientPem); err != nil {
		return "", fmt.Errorf("failed to save client certificate: %w", err)
	}

	// create a random brick name
	randomBytes, err := dmscrypto.RandomEntropy(20)
	if err != nil {
		return "", fmt.Errorf("failed to create random brick name: %w", err)
	}
	generatedBrickName, err := dmscrypto.Sha3(randomBytes)
	if err != nil {
		return "", fmt.Errorf("failed to generate brick hash: %w", err)
	}

	brick := fmt.Sprintf("%s:%s", gc.glusterfsServerHostname, filepath.Join(gc.bricksDir, hex.EncodeToString(generatedBrickName)))
	// force create
	args := []string{"volume", "create", volName, brick, "force"}
	output, err := sys.ExecCommand("gluster", args...).CombinedOutput()
	if err != nil {
		// %w (instead of %v) so callers can unwrap the exec error
		return "", fmt.Errorf("failed to create volume %s: %w, output: %s", volName, err, string(output))
	}

	// refresh the CA (server bundle includes node + client certs)
	if err := gc.generateGlusterFSServerCA(); err != nil {
		return "", fmt.Errorf("failed to reload glusterfs server ca: %w", err)
	}
	// refresh client CA (node certs only)
	if err := gc.generateGlusterFSClientCA(); err != nil {
		return "", fmt.Errorf("failed to reload glusterfs client ca: %w", err)
	}
	if err := gc.enableTLS(volName); err != nil {
		return "", fmt.Errorf("failed to enable TLS for volume %s: %w", volName, err)
	}

	clientCert, err := os.ReadFile(filepath.Join(gc.caAuthority, "glusterfs-client.ca"))
	if err != nil {
		return "", fmt.Errorf("failed to read the client ca file: %w", err)
	}
	return string(clientCert), nil
}
// StartVolume starts a given GlusterFS volume.
func (gc *GlusterController) StartVolume(volName string) error {
	out, err := sys.ExecCommand("gluster", "volume", "start", volName, "force").CombinedOutput()
	if err != nil {
		return fmt.Errorf("failed to start volume %s: %v, output: %s", volName, err, string(out))
	}
	return nil
}
// DeleteVolume stops and deletes the specified GlusterFS volume.
func (gc *GlusterController) DeleteVolume(volName string) error {
	// Stop first; gluster refuses to delete a running volume.
	if out, err := sys.ExecCommand("gluster", "volume", "stop", volName, "--mode=script").CombinedOutput(); err != nil {
		return fmt.Errorf("failed to stop volume %s: %v, output: %s", volName, err, string(out))
	}
	// Delete the volume
	if out, err := sys.ExecCommand("gluster", "volume", "delete", volName, "--mode=script").CombinedOutput(); err != nil {
		return fmt.Errorf("failed to delete volume %s: %v, output: %s", volName, err, string(out))
	}
	return nil
}
// CheckServer executes a simple GlusterFS command ("gluster pool list") to verify
// that the glusterd daemon is running and responsive.
func (gc *GlusterController) CheckServer() error {
	out, err := sys.ExecCommand("gluster", "pool", "list").CombinedOutput()
	if err == nil {
		return nil
	}
	return fmt.Errorf("glusterfs server check failed: %w, output: %s", err, out)
}
// IsServerWorking returns true if CheckServer does not report any errors.
// It is a convenience boolean wrapper; use CheckServer directly when the
// underlying error is needed.
func (gc *GlusterController) IsServerWorking() bool {
	return gc.CheckServer() == nil
}
// saveHashedContent stores content under clients/<sha256(content)>.pem and
// returns the path of the written file. Hashing makes the filename stable and
// collision-free for identical certificates.
func (gc *GlusterController) saveHashedContent(content string) (string, error) {
	digest := sha256.Sum256([]byte(content))
	target := filepath.Join(gc.caAuthority, "clients", hex.EncodeToString(digest[:])+".pem")
	if err := os.WriteFile(target, []byte(content), 0o644); err != nil {
		return "", fmt.Errorf("failed to write file %s: %w", target, err)
	}
	return target, nil
}
// isModuleLoaded checks if a given kernel module is loaded
func isModuleLoaded(module string) bool {
	out, err := sys.ExecCommand("lsmod").CombinedOutput()
	return err == nil && strings.Contains(string(out), module)
}
// loadModule attempts to load the given kernel module via modprobe.
func loadModule(module string) error {
	_, err := sys.ExecCommand("modprobe", module).CombinedOutput()
	return err
}
// validatePEM checks if a given file is a valid PEM-encoded certificate or key
func validatePEM(data []byte) error {
for len(data) > 0 {
var block *pem.Block
block, data = pem.Decode(data)
if block == nil {
return errors.New("invalid PEM")
}
switch block.Type {
case "CERTIFICATE":
if _, err := x509.ParseCertificate(block.Bytes); err != nil {
return fmt.Errorf("invalid certificate: %w", err)
}
case "PRIVATE KEY", "RSA PRIVATE KEY", "EC PRIVATE KEY":
if _, err := x509.ParsePKCS8PrivateKey(block.Bytes); err != nil {
if _, errRSA := x509.ParsePKCS1PrivateKey(block.Bytes); errRSA != nil {
if _, errEC := x509.ParseECPrivateKey(block.Bytes); errEC != nil {
return errors.New("invalid certificate")
}
}
}
}
}
return nil
}
func isPEMDirectoryEmpty(dirPath string) (bool, error) {
entries, err := os.ReadDir(dirPath)
if err != nil {
return false, err
}
for _, entry := range entries {
if !entry.IsDir() && filepath.Ext(entry.Name()) == ".pem" {
return false, nil
}
}
return true, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package glusterfs
import (
"bufio"
"context"
"fmt"
"strings"
"sync"
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/api/types/mount"
"github.com/docker/docker/client"
"gitlab.com/nunet/device-management-service/storage"
"gitlab.com/nunet/device-management-service/types"
)
// GlusterFS holds the configuration needed to mount a GlusterFS volume.
// GlusterFS holds the configuration needed to mount a GlusterFS volume.
type GlusterFS struct {
	// servers are the GlusterFS hosts, joined with "," for the client container.
	servers []string
	// name is the GlusterFS volume name.
	name string
	// mu serializes Mount/Unmount operations on this mounter.
	mu sync.Mutex
	// tracker records which target paths are mounted and by which container.
	tracker *storage.VolumeTracker
	// allocationID identifies the owning allocation; also part of the container name.
	allocationID string
	// clientPrivateKey/clientPEM/clientCA enable TLS when clientPrivateKey is non-empty.
	clientPrivateKey string
	clientPEM        string
	clientCA         string
}

// Compile-time check that GlusterFS satisfies types.Mounter.
var _ types.Mounter = (*GlusterFS)(nil)
// New creates a new GlusterFS mounter with the provided configuration.
// At least one server and a non-empty volume name are required.
func New(t *storage.VolumeTracker, servers []string, name string, clientPrivateKey, clientPEM, clientCA, allocationID string) (*GlusterFS, error) {
	switch {
	case len(servers) == 0:
		return nil, fmt.Errorf("no GlusterFS servers provided")
	case name == "":
		return nil, fmt.Errorf("no volume provided")
	}

	mounter := &GlusterFS{
		allocationID:     allocationID,
		servers:          servers,
		name:             name,
		tracker:          t,
		clientPrivateKey: clientPrivateKey,
		clientPEM:        clientPEM,
		clientCA:         clientCA,
	}
	return mounter, nil
}
// Mount mounts the GlusterFS volume to the provided targetPath.
// Additional mount options can be passed in the options map (currently unused).
func (g *GlusterFS) Mount(targetPath string, _ map[string]string) error {
	g.mu.Lock()
	defer g.mu.Unlock()

	// Validate the input before consulting the tracker; previously the empty-path
	// check ran after the tracker lookup, so an empty path could surface the
	// wrong error.
	if targetPath == "" {
		return fmt.Errorf("target path cannot be empty")
	}
	if g.tracker.IsMounted(targetPath) {
		return fmt.Errorf("%s is already mounted", targetPath)
	}
	if err := g.runGlusterfsClient(targetPath); err != nil {
		return fmt.Errorf("failed to run glusterfs client: %w", err)
	}
	return nil
}
// runGlusterfsClient launches the "nunet-glusterfs-client" Docker container that
// performs the actual FUSE mount of g.name onto targetPath, then follows the
// container logs until the mount outcome is reported.
//
// The container runs privileged on the host network, with targetPath
// bind-mounted (rshared) at /mounted so the in-container mount propagates back
// to the host. TLS material is passed via environment variables only when a
// client private key was configured.
func (g *GlusterFS) runGlusterfsClient(targetPath string) error {
	cli, err := client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation(), client.WithHostFromEnv())
	if err != nil {
		return fmt.Errorf("failed to create Docker Client: %w", err)
	}
	// /dev is needed inside the container for the fuse device.
	binds := []string{
		"/dev/:/dev/",
	}
	hostConfig := &container.HostConfig{
		Binds:        binds,
		Privileged:   true,
		NetworkMode:  "host",
		CgroupnsMode: "host",
		Mounts: []mount.Mount{
			{
				Type:     mount.TypeBind,
				Source:   targetPath,
				Target:   "/mounted",
				ReadOnly: false,
				// rshared propagation lets the mount created inside the
				// container appear on the host side of the bind.
				BindOptions: &mount.BindOptions{
					Propagation: mount.PropagationRShared,
				},
				Consistency: mount.ConsistencyDefault,
			},
		},
	}
	envs := []string{
		"GLUSTER_VOLUME=" + g.name,
		"GLUSTER_HOST=" + strings.Join(g.servers, ","),
		"MOUNT_PATH=mounted",
	}
	// if we are supplying tls certs then its a secure connection
	if g.clientPrivateKey != "" {
		clientAuth := []string{
			"GLUSTERFS_PEM=" + g.clientPEM,
			"GLUSTERFS_KEY=" + g.clientPrivateKey,
			"GLUSTERFS_CA=" + g.clientCA,
		}
		envs = append(envs, clientAuth...)
	}
	containerConfig := &container.Config{
		Env:   envs,
		Image: "nunet-glusterfs-client",
	}
	// Container name is derived from allocation + volume so a given pair is unique.
	mountingContainerName := fmt.Sprintf("%s_%s", g.allocationID, g.name)
	resp, err := cli.ContainerCreate(context.Background(), containerConfig, hostConfig, nil, nil, mountingContainerName)
	if err != nil {
		return fmt.Errorf("failed to create container: %w", err)
	}
	if err := cli.ContainerStart(context.Background(), resp.ID, container.StartOptions{}); err != nil {
		return fmt.Errorf("failed to start glusterfs client container: %w", err)
	}
	// Follow the container logs and decide success/failure from marker lines
	// emitted by the client image.
	logReader, err := cli.ContainerLogs(context.Background(), resp.ID, container.LogsOptions{
		ShowStdout: true,
		ShowStderr: true,
		Follow:     true,
	})
	if err != nil {
		return fmt.Errorf("failed to read container logs: %w", err)
	}
	defer logReader.Close()
	logScanner := bufio.NewScanner(logReader)
	for logScanner.Scan() {
		logLine := logScanner.Text()
		if strings.Contains(logLine, "mounted successfully at") {
			fmt.Println(logLine)
			// Record the mount only once the client confirms it.
			g.tracker.TrackMount(targetPath, g.allocationID, resp.ID)
			return nil
		}
		if strings.Contains(logLine, "failed mounting glusterfs volume") {
			return fmt.Errorf("failed to mount volume: %s", g.name)
		}
	}
	if err := logScanner.Err(); err != nil {
		return fmt.Errorf("error reading logs: %w", err)
	}
	// NOTE(review): if the log stream ends without either marker line, this
	// returns nil while the mount was never tracked — confirm whether that
	// should instead be an error. There is also no timeout on the follow loop.
	return nil
}
// unmountAndKillContainer runs `umount /mounted/` inside the client container,
// waits for the exec output to confirm the unmount (or that nothing was
// mounted), then kills and force-removes the container.
func (g *GlusterFS) unmountAndKillContainer(containerID string) error {
	cli, err := client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation(), client.WithHostFromEnv())
	if err != nil {
		return fmt.Errorf("failed to create Docker Client: %w", err)
	}
	execConfig := container.ExecOptions{
		Cmd:          []string{"umount", "/mounted/"},
		AttachStdout: true,
		AttachStderr: true,
	}
	execResp, err := cli.ContainerExecCreate(context.Background(), containerID, execConfig)
	if err != nil {
		return fmt.Errorf("failed to create exec instance: %w", err)
	}
	execAttachResp, err := cli.ContainerExecAttach(context.Background(), execResp.ID, container.ExecAttachOptions{})
	if err != nil {
		return fmt.Errorf("failed to attach to exec instance: %w", err)
	}
	defer execAttachResp.Close()
	// Scan the exec output for the markers that indicate the umount finished.
	// NOTE(review): if neither marker appears the loop simply drains the stream
	// and proceeds to kill the container anyway — confirm that is intended.
	logScanner := bufio.NewScanner(execAttachResp.Reader)
	for logScanner.Scan() {
		logLine := logScanner.Text()
		fmt.Println(logLine)
		if strings.Contains(logLine, "success") || strings.Contains(logLine, "not mounted") {
			break
		}
	}
	if err := logScanner.Err(); err != nil {
		return fmt.Errorf("error reading exec output: %w", err)
	}
	// The client container blocks on the mount, so it must be killed, then
	// force-removed to free the name for future mounts.
	if err := cli.ContainerKill(context.Background(), containerID, "SIGKILL"); err != nil {
		return fmt.Errorf("failed to kill container %s: %w", containerID, err)
	}
	if err := cli.ContainerRemove(context.Background(), containerID, container.RemoveOptions{Force: true}); err != nil {
		return fmt.Errorf("failed to remove container %s: %w", containerID, err)
	}
	return nil
}
// Unmount unmounts the GlusterFS volume from the provided targetPath.
// Unmounting a path that is not tracked as mounted is a no-op (a warning is
// logged), so Unmount is safe to call during cleanup paths.
func (g *GlusterFS) Unmount(targetPath string) error {
	g.mu.Lock()
	defer g.mu.Unlock()

	// Validate the input before consulting the tracker.
	if targetPath == "" {
		return fmt.Errorf("target path cannot be empty")
	}
	if !g.tracker.IsMounted(targetPath) {
		log.Warnf("target path %s is not mounted", targetPath)
		// no need to unmount if it's not mounted
		return nil
	}

	info, err := g.tracker.GetMountInfo(targetPath)
	if err != nil {
		// BUGFIX: this error was silently swallowed (returned nil), which
		// reported success while leaving the mount and its container behind.
		return fmt.Errorf("failed to get mount info for %s: %w", targetPath, err)
	}
	if err := g.unmountAndKillContainer(info.ContainerID); err != nil {
		return fmt.Errorf("failed to unmount volume and kill container: %w", err)
	}
	g.tracker.UntrackMount(targetPath)
	return nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package localfs
import (
"fmt"
"gitlab.com/nunet/device-management-service/types"
)
// LocalFS is a trivial Mounter for paths that already live on the local
// file system: Mount only validates input and Unmount is a no-op.
type LocalFS struct {
	// src is the local source path this instance represents.
	src string
}

// Compile-time check that LocalFS satisfies types.Mounter.
var _ types.Mounter = (*LocalFS)(nil)
// New creates a new LocalFS storage instance using the provided path.
func New(src string) (*LocalFS, error) {
	fs := &LocalFS{src: src}
	return fs, nil
}
// Mount for LocalFS might perform a bind mount or simply check that the path exists.
func (l *LocalFS) Mount(targetPath string, _ map[string]string) error {
	if targetPath != "" {
		return nil
	}
	return fmt.Errorf("target path cannot be empty")
}
// Unmount is a no-op for LocalFS: nothing was mounted, so there is nothing to release.
func (l *LocalFS) Unmount(_ string) error {
	return nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package volume
import (
"fmt"
"gitlab.com/nunet/device-management-service/storage"
"gitlab.com/nunet/device-management-service/storage/volume/glusterfs"
"gitlab.com/nunet/device-management-service/storage/volume/localfs"
"gitlab.com/nunet/device-management-service/types"
)
// New creates a volume implementation based on the provided configuration.
// Supported types are "glusterfs" and "local"; anything else is an error.
func New(t *storage.VolumeTracker, sc types.VolumeConfig, allocationID string) (types.Mounter, error) {
	switch sc.Type {
	case "local":
		return localfs.New(sc.Src)
	case "glusterfs":
		return glusterfs.New(t, sc.Servers, sc.Name, sc.ClientPrivateKey, sc.ClientPEM, sc.ClientCA, allocationID)
	}
	return nil, fmt.Errorf("unsupported storage type: %s", sc.Type)
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package storage
import (
"errors"
"sync"
)
// VolumeTracker is a concurrency-safe registry of active volume mounts,
// keyed by their target path.
type VolumeTracker struct {
	mu     sync.RWMutex
	mounts map[string]TrackedVolume
}

// TrackedVolume records which allocation and container own a mount.
type TrackedVolume struct {
	AllocationID string
	ContainerID  string
}

// NewVolumeTracker returns an empty, ready-to-use tracker.
func NewVolumeTracker() *VolumeTracker {
	return &VolumeTracker{mounts: make(map[string]TrackedVolume)}
}

// TrackMount records targetPath as mounted on behalf of the given
// allocation and container, overwriting any previous entry for the path.
func (v *VolumeTracker) TrackMount(targetPath, allocationID, containerID string) {
	v.mu.Lock()
	defer v.mu.Unlock()
	v.mounts[targetPath] = TrackedVolume{
		AllocationID: allocationID,
		ContainerID:  containerID,
	}
}

// UntrackMount forgets the mount at targetPath; unknown paths are a no-op.
func (v *VolumeTracker) UntrackMount(targetPath string) {
	v.mu.Lock()
	defer v.mu.Unlock()
	delete(v.mounts, targetPath)
}

// IsMounted reports whether targetPath is currently tracked as mounted.
func (v *VolumeTracker) IsMounted(targetPath string) bool {
	v.mu.RLock()
	defer v.mu.RUnlock()
	_, mounted := v.mounts[targetPath]
	return mounted
}

// GetMountInfo returns the tracked entry for targetPath, or an error when
// the path is not tracked.
func (v *VolumeTracker) GetMountInfo(targetPath string) (TrackedVolume, error) {
	v.mu.RLock()
	defer v.mu.RUnlock()
	if info, ok := v.mounts[targetPath]; ok {
		return info, nil
	}
	return TrackedVolume{}, errors.New("failed to find volume")
}
package tokenomics
import (
"errors"
"fmt"
"sync"
"time"
bt "gitlab.com/nunet/device-management-service/internal/background_tasks"
"gitlab.com/nunet/device-management-service/lib/did"
"gitlab.com/nunet/device-management-service/observability"
"gitlab.com/nunet/device-management-service/tokenomics/contracts"
contractstore "gitlab.com/nunet/device-management-service/tokenomics/store"
"gitlab.com/nunet/device-management-service/tokenomics/store/usage"
)
// BillingTaskArgs contains the arguments for a billing task.
type BillingTaskArgs struct {
	ContractDID    did.DID                         // contract being billed
	ContractStore  *contractstore.Store            // source of contract state
	UsageStore     *usage.Store                    // source of usage/last-invoice timestamps
	ExecuteBilling func(contractDID did.DID) error // callback that performs the actual billing
}
// ContractBillingScheduler manages all contract billing tasks using a
// centralized background-task scheduler instead of one goroutine per contract.
type ContractBillingScheduler struct {
	scheduler     *bt.Scheduler
	contractStore *contractstore.Store
	usageStore    *usage.Store
	tasks         map[string]*bt.Task // contract DID URI -> task
	mu            sync.RWMutex        // guards tasks
	billingFunc   func(contractDID did.DID) error // Function to execute billing
}
// NewContractBillingScheduler creates a new billing scheduler.
// All three dependencies are mandatory and are validated before any
// resources are allocated (the original allocated the scheduler first,
// wasting work when validation failed).
func NewContractBillingScheduler(
	contractStore *contractstore.Store,
	usageStore *usage.Store,
	billingFunc func(contractDID did.DID) error,
) (*ContractBillingScheduler, error) {
	if contractStore == nil {
		return nil, errors.New("contract store is required")
	}
	if usageStore == nil {
		return nil, errors.New("usage store is required")
	}
	if billingFunc == nil {
		return nil, errors.New("billing function is required")
	}
	// Create scheduler with reasonable limits:
	// - Max 10 concurrent billing tasks (adjust based on needs)
	// - Poll every 30 seconds to check for ready tasks
	scheduler := bt.NewScheduler(10, 30*time.Second)
	return &ContractBillingScheduler{
		scheduler:     scheduler,
		contractStore: contractStore,
		usageStore:    usageStore,
		tasks:         make(map[string]*bt.Task),
		billingFunc:   billingFunc,
	}, nil
}
// Start starts the underlying background-task scheduler's polling loop.
func (cbs *ContractBillingScheduler) Start() {
	cbs.scheduler.Start()
}
// Stop stops the billing scheduler and waits for all tasks to complete.
func (cbs *ContractBillingScheduler) Stop() {
	cbs.scheduler.Stop()
}
// RegisterContract registers a contract for automatic billing.
// This method is idempotent - calling it multiple times for the same
// contract is safe. Contracts whose payment model does not support
// automatic billing are silently skipped (nil error).
func (cbs *ContractBillingScheduler) RegisterContract(contractDID did.DID) error {
	cbs.mu.Lock()
	defer cbs.mu.Unlock()
	contractURI := contractDID.URI
	// Check if already registered (idempotent check).
	// This prevents duplicate registration if called from both
	// createContractOnHost and StartContracts.
	if _, exists := cbs.tasks[contractURI]; exists {
		return nil // Already registered, no-op
	}
	// Get contract to check payment model
	contract, err := cbs.contractStore.GetContract(contractURI)
	if err != nil {
		return fmt.Errorf("failed to get contract: %w", err)
	}
	// Check if payment model supports automatic billing
	processor, err := contracts.GetPaymentModelProcessor(contract.PaymentDetails.PaymentModel)
	if err != nil {
		return fmt.Errorf("failed to get payment processor: %w", err)
	}
	if !processor.SupportsAutomaticBilling() {
		// Not applicable for this payment model
		return nil
	}
	// Calculate billing cycle and check interval
	billingCycle := calculateBillingCycle(contract)
	checkInterval := calculateCheckInterval(billingCycle)
	// Get last invoice time; fall back to the contract start date when the
	// store has no record (first billing) or returns a zero time.
	lastInvoiceAt, err := cbs.usageStore.GetLastProcessedAt(contractURI)
	if err != nil {
		// If not found, use contract start date
		lastInvoiceAt = contract.Duration.StartDate
	}
	if lastInvoiceAt.IsZero() {
		lastInvoiceAt = contract.Duration.StartDate
	}
	// Create billing cycle trigger - calculates exact next invoice time
	trigger := NewBillingCycleTrigger(billingCycle, lastInvoiceAt, checkInterval)
	// Create task arguments
	taskArgs := &BillingTaskArgs{
		ContractDID:    contractDID,
		ContractStore:  cbs.contractStore,
		UsageStore:     cbs.usageStore,
		ExecuteBilling: cbs.billingFunc,
	}
	// Create task - inlined function for simplicity
	task := &bt.Task{
		Name:        fmt.Sprintf("billing-%s", contractURI),
		Description: fmt.Sprintf("Automatic billing for contract %s (%s)", contractURI, contract.PaymentDetails.PaymentModel),
		Triggers:    []bt.Trigger{trigger},
		Function: func(args interface{}) error {
			// The scheduler passes task.Args ([]interface{}) directly, so we need to extract the first element
			argsSlice, ok := args.([]interface{})
			if !ok || len(argsSlice) == 0 {
				return fmt.Errorf("invalid billing task args")
			}
			billingArgs, ok := argsSlice[0].(*BillingTaskArgs)
			if !ok {
				return fmt.Errorf("invalid billing task args type")
			}
			return billingArgs.ExecuteBilling(billingArgs.ContractDID)
		},
		Args:    []interface{}{taskArgs},
		Enabled: true,
	}
	// Register task with scheduler and remember it for later updates/removal.
	cbs.scheduler.AddTask(task)
	cbs.tasks[contractURI] = task
	log.Infow("registered contract for automatic billing",
		"labels", string(observability.LabelContract),
		"contract_did", contractURI,
		"payment_model", contract.PaymentDetails.PaymentModel,
		"billing_cycle", billingCycle,
		"check_interval", checkInterval)
	return nil
}
// UnregisterContract removes a contract from automatic billing.
// Unknown contracts are a no-op.
func (cbs *ContractBillingScheduler) UnregisterContract(contractDID did.DID) {
	cbs.mu.Lock()
	defer cbs.mu.Unlock()
	contractURI := contractDID.URI
	task, exists := cbs.tasks[contractURI]
	if !exists {
		return
	}
	cbs.scheduler.RemoveTask(task.ID)
	delete(cbs.tasks, contractURI)
	log.Infow("unregistered contract from automatic billing",
		"labels", string(observability.LabelContract),
		"contract_did", contractURI)
}
// GetTask returns the scheduler task registered for a contract, if any
// (used for observability).
func (cbs *ContractBillingScheduler) GetTask(contractDID did.DID) (*bt.Task, bool) {
	cbs.mu.RLock()
	defer cbs.mu.RUnlock()
	t, ok := cbs.tasks[contractDID.URI]
	return t, ok
}
// UpdateContract updates billing schedule after successful invoice.
// Updates the trigger's lastInvoiceAt to use actual invoice time.
// Unknown contracts are a no-op (nil error).
func (cbs *ContractBillingScheduler) UpdateContract(contractDID did.DID) error {
	cbs.mu.Lock()
	defer cbs.mu.Unlock()
	contractURI := contractDID.URI
	task, exists := cbs.tasks[contractURI]
	if !exists {
		return nil
	}
	// Get updated lastInvoiceAt from store.
	// Note: This should be called after SaveLastProcessedAt in contract_host.go
	lastInvoiceAt, err := cbs.usageStore.GetLastProcessedAt(contractURI)
	if err != nil {
		// If we can't get the time from store, log warning but don't fail.
		// The trigger will still work, just might not be perfectly accurate.
		log.Warnw("failed to get last invoice time for trigger update, using current time",
			"labels", string(observability.LabelContract),
			"contract_did", contractURI,
			"error", err)
		// Use current time as fallback
		lastInvoiceAt = time.Now().UTC()
	}
	// Update the first BillingCycleTrigger found on the task; other trigger
	// types are left untouched.
	for _, trigger := range task.Triggers {
		if billingTrigger, ok := trigger.(*BillingCycleTrigger); ok {
			billingTrigger.UpdateLastInvoiceAt(lastInvoiceAt)
			log.Debugw("updated billing trigger after invoice",
				"labels", string(observability.LabelContract),
				"contract_did", contractURI,
				"last_invoice_at", lastInvoiceAt,
				"next_check", billingTrigger.nextCheck)
			break
		}
	}
	return nil
}
// calculateBillingCycle derives the full billing cycle from the contract's
// PaymentPeriod multiplied by PaymentPeriodCount.
func calculateBillingCycle(contract *contracts.Contract) time.Duration {
	count := contract.PaymentDetails.PaymentPeriodCount
	if count <= 0 {
		// Guard against zero/negative counts; bill at least one period.
		count = 1
	}
	period, err := parsePaymentPeriod(contract.PaymentDetails.PaymentPeriod)
	if err != nil {
		// Unknown period string: fall back to hourly billing.
		period = time.Hour
	}
	return time.Duration(count) * period
}
// calculateCheckInterval calculates the dynamic checker interval based on billing period
// Formula: clamp(billingCycle / 10, 30s, min(billingCycle / 2, 24h))
// This ensures:
// - Short periods (minutes): frequent checks (30s minimum)
// - Long periods (days/weeks/months): efficient checks (24h maximum)
func calculateCheckInterval(billingCycle time.Duration) time.Duration {
// Use 1/10 of billing cycle as base for reasonable granularity
// This means we'll check 10 times per billing cycle (good for short periods)
checkInterval := billingCycle / 10
// Apply minimum bound: 30 seconds (for fast testing with 1-minute periods)
minInterval := 30 * time.Second
if checkInterval < minInterval {
checkInterval = minInterval
}
// Apply maximum bound: Use 1/2 of period, but cap at 24 hours
// This prevents excessive checks for very long periods
// For 1-month period: 1/2 month = 15 days → capped at 24 hours
// For 1-week period: 1/2 week = 3.5 days → capped at 24 hours
// For 1-day period: 1/2 day = 12 hours → use 12 hours
maxInterval := billingCycle / 2
absoluteMaxInterval := 24 * time.Hour
if maxInterval > absoluteMaxInterval {
maxInterval = absoluteMaxInterval
}
if checkInterval > maxInterval {
checkInterval = maxInterval
}
return checkInterval
}
// parsePaymentPeriod converts a payment period string to a time.Duration.
// An unrecognized period yields a zero duration and an error.
func parsePaymentPeriod(period string) (time.Duration, error) {
	durations := map[string]time.Duration{
		contracts.PaymentPeriodMinute: time.Minute,
		contracts.PaymentPeriodHour:   time.Hour,
		contracts.PaymentPeriodDay:    24 * time.Hour,
		contracts.PaymentPeriodWeek:   7 * 24 * time.Hour,
		// Approximate: 30 days (could be enhanced to handle exact calendar months)
		contracts.PaymentPeriodMonth: 30 * 24 * time.Hour,
	}
	d, ok := durations[period]
	if !ok {
		return 0, fmt.Errorf("invalid payment_period: %s", period)
	}
	return d, nil
}
package tokenomics
import (
"time"
)
// BillingCycleTrigger calculates the next invoice time based on billing cycle
// and ensures we check frequently enough to catch period boundaries.
type BillingCycleTrigger struct {
BillingCycle time.Duration // Full billing cycle (e.g., 1 hour, 1 day)
LastInvoiceAt time.Time // Last time an invoice was generated
CheckInterval time.Duration // Minimum check frequency (e.g., 30s)
nextCheck time.Time // Next time to check for invoice
startedAt time.Time // When trigger was initialized
}
// NewBillingCycleTrigger creates a new billing cycle trigger
func NewBillingCycleTrigger(
billingCycle time.Duration,
lastInvoiceAt time.Time,
checkInterval time.Duration,
) *BillingCycleTrigger {
now := time.Now().UTC()
trigger := &BillingCycleTrigger{
BillingCycle: billingCycle,
LastInvoiceAt: lastInvoiceAt,
CheckInterval: checkInterval,
startedAt: now,
}
trigger.nextCheck = trigger.calculateNextCheck(now)
return trigger
}
// IsReady checks if it's time to check for invoice generation
func (t *BillingCycleTrigger) IsReady(currentTime time.Time) bool {
now := currentTime.UTC()
return now.After(t.nextCheck) || now.Equal(t.nextCheck)
}
// MarkTriggered updates the trigger state after execution
func (t *BillingCycleTrigger) MarkTriggered(triggerTime time.Time) {
t.startedAt = triggerTime.UTC()
// Recalculate next check time
t.nextCheck = t.calculateNextCheck(triggerTime.UTC())
}
// Reset resets the trigger to a new start time
func (t *BillingCycleTrigger) Reset(currentTime time.Time) {
t.startedAt = currentTime.UTC()
t.nextCheck = t.calculateNextCheck(currentTime.UTC())
}
// UpdateLastInvoiceAt updates the last invoice time (called after successful billing)
func (t *BillingCycleTrigger) UpdateLastInvoiceAt(invoiceTime time.Time) {
t.LastInvoiceAt = invoiceTime.UTC()
// Use invoiceTime (or current time if invoiceTime is in the future) to calculate next check
now := time.Now().UTC()
if invoiceTime.After(now) {
// If invoice time is in the future (shouldn't happen), use current time
t.nextCheck = t.calculateNextCheck(now)
} else {
// Use invoice time to calculate next check for accurate scheduling
t.nextCheck = t.calculateNextCheck(invoiceTime.UTC())
}
}
// calculateNextCheck determines when to check next for invoice generation
func (t *BillingCycleTrigger) calculateNextCheck(now time.Time) time.Time {
if t.LastInvoiceAt.IsZero() {
// First invoice - check after one billing cycle from start
nextInvoice := t.startedAt.Add(t.BillingCycle)
// But ensure we check at least every CheckInterval
minCheck := now.Add(t.CheckInterval)
if nextInvoice.Before(minCheck) {
return minCheck
}
return nextInvoice
}
// Calculate next period boundary
elapsed := now.Sub(t.LastInvoiceAt)
periodsElapsed := elapsed / t.BillingCycle
nextPeriodBoundary := t.LastInvoiceAt.Add((periodsElapsed + 1) * t.BillingCycle)
// Ensure we check at least every CheckInterval
minNextCheck := now.Add(t.CheckInterval)
if nextPeriodBoundary.Before(minNextCheck) {
return minNextCheck
}
return nextPeriodBoundary
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package cardano
import (
"context"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/url"
"os"
"strconv"
"sync"
"time"
)
const (
	defaultPageCount = 100              // page size requested from the Blockfrost listing endpoints
	maxWorkers       = 6                // concurrent UTxO-lookup workers
	httpTimeout      = 30 * time.Second // overall HTTP request timeout
)
// BFClient is a minimal Blockfrost REST API client.
type BFClient struct {
	baseURL string // API base URL, prepended to every request path
	apiKey  string // sent as the project_id header on every request
	client  *http.Client
}

// AssetTx is one entry of the asset-transactions listing.
type AssetTx struct {
	TxHash string `json:"tx_hash"`
	Block  string `json:"block"`
}

// TxUtxos is the UTxO breakdown of a single transaction.
type TxUtxos struct {
	Hash    string        `json:"hash"`
	Inputs  []TxUtxoEntry `json:"inputs"`
	Outputs []TxUtxoEntry `json:"outputs"`
}

// TxUtxoEntry is a single input or output of a transaction.
type TxUtxoEntry struct {
	Address     string       `json:"address"`
	Amount      []AmountItem `json:"amount"`
	OutputIndex *int         `json:"output_index,omitempty"`
}

// AmountItem is an asset quantity as returned by Blockfrost.
type AmountItem struct {
	Unit     string `json:"unit"`
	Quantity string `json:"quantity"`
}

// Match describes a transaction output that paid the searched asset to the
// destination address.
type Match struct {
	TxHash      string
	BlockHash   string
	OutputIndex *int
	FromAddrs   []string // distinct input (sender) addresses of the tx
	ToAddress   string
	Quantity    string
	Unit        string
}
// NewClient builds a Blockfrost API client for the given endpoint,
// authenticating every request with apiKey. The HTTP client carries both
// dial/TLS timeouts and an overall request timeout.
func NewClient(apiKey, endpoint string) *BFClient {
	transport := &http.Transport{
		DialContext: (&net.Dialer{
			Timeout:   10 * time.Second,
			KeepAlive: 30 * time.Second,
		}).DialContext,
		TLSHandshakeTimeout: 10 * time.Second,
	}
	return &BFClient{
		baseURL: endpoint,
		apiKey:  apiKey,
		client: &http.Client{
			Timeout:   httpTimeout,
			Transport: transport,
		},
	}
}
// doRequest performs an authenticated HTTP request against the Blockfrost
// API and returns the raw response. params, when non-nil, are appended as
// the query string. A status >= 400 is converted into an error (including
// the decoded Blockfrost error payload when present) and the body is closed;
// on success the caller owns the response body and must close it.
func (b *BFClient) doRequest(ctx context.Context, method, path string, params url.Values) (*http.Response, error) {
	rel := path
	if params != nil {
		rel = rel + "?" + params.Encode()
	}
	u := b.baseURL + rel
	req, err := http.NewRequestWithContext(ctx, method, u, nil)
	if err != nil {
		return nil, err
	}
	// Blockfrost authenticates via the project_id header.
	req.Header.Set("project_id", b.apiKey)
	req.Header.Set("Accept", "application/json")
	resp, err := b.client.Do(req)
	if err != nil {
		return nil, err
	}
	if resp.StatusCode >= 400 {
		var bodyRepr struct {
			Error string `json:"error"`
			Msg   string `json:"message"`
		}
		// Best-effort decode of the error payload; decode failures are
		// deliberately ignored and leave bodyRepr empty.
		_ = json.NewDecoder(resp.Body).Decode(&bodyRepr)
		resp.Body.Close()
		return nil, fmt.Errorf("blockfrost error: status=%d url=%s err=%v", resp.StatusCode, u, bodyRepr)
	}
	return resp, nil
}
// ListAssetTxs streams every transaction involving the given asset, paging
// through the Blockfrost listing endpoint. Items arrive on the returned
// channel; at most one error is delivered on the error channel. Both
// channels are closed when the producer goroutine finishes.
func (b *BFClient) ListAssetTxs(ctx context.Context, asset string) (<-chan AssetTx, <-chan error) {
	out := make(chan AssetTx)
	errCh := make(chan error, 1)
	go func() {
		defer close(out)
		defer close(errCh)
		page := 1
		for {
			select {
			case <-ctx.Done():
				errCh <- ctx.Err()
				return
			default:
			}
			params := url.Values{}
			params.Set("page", strconv.Itoa(page))
			params.Set("count", strconv.Itoa(defaultPageCount))
			path := fmt.Sprintf("/assets/%s/transactions", url.PathEscape(asset))
			resp, err := b.doRequest(ctx, http.MethodGet, path, params)
			if err != nil {
				errCh <- fmt.Errorf("assets transactions request failed: %w", err)
				return
			}
			body, err := io.ReadAll(resp.Body)
			// BUG FIX: close the body right away. The original used `defer`
			// inside this loop, which kept every page's body (and its
			// connection) open until the goroutine exited.
			resp.Body.Close()
			if err != nil {
				errCh <- fmt.Errorf("failed to read body: %w", err)
				return
			}
			var items []AssetTx
			if err := json.Unmarshal(body, &items); err != nil {
				errCh <- fmt.Errorf("decode asset txs page %d failed: %w", page, err)
				return
			}
			if len(items) == 0 {
				return
			}
			for _, it := range items {
				select {
				case <-ctx.Done():
					errCh <- ctx.Err()
					return
				case out <- it:
				}
			}
			if len(items) < defaultPageCount {
				// Short page: no further results.
				return
			}
			page++
			time.Sleep(120 * time.Millisecond) // be gentle with the API rate limit
		}
	}()
	return out, errCh
}
// GetTxUtxos fetches the input/output UTxO breakdown for a transaction.
func (b *BFClient) GetTxUtxos(ctx context.Context, txHash string) (*TxUtxos, error) {
	endpoint := fmt.Sprintf("/txs/%s/utxos", url.PathEscape(txHash))
	resp, err := b.doRequest(ctx, http.MethodGet, endpoint, nil)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	utxos := new(TxUtxos)
	if err := json.NewDecoder(resp.Body).Decode(utxos); err != nil {
		return nil, err
	}
	return utxos, nil
}
// FindTxsToAddressForAsset scans all transactions that moved the given asset
// and returns those with an output paying the asset to destAddress. Per-tx
// UTxO lookups are fanned out across maxWorkers workers.
//
// Fixes over the original:
//   - the collect loop disabled neither channel after close, so a closed
//     errCh made the select busy-spin on nil receives;
//   - worker result sends and feeder job sends are now guarded by ctx so
//     goroutines cannot block forever after an early return (leak).
func (b *BFClient) FindTxsToAddressForAsset(ctx context.Context, asset, destAddress string) ([]Match, error) {
	assetTxsCh, errCh := b.ListAssetTxs(ctx, asset)
	type job struct{ assetTx AssetTx }
	jobs := make(chan job)
	results := make(chan Match)
	var wg sync.WaitGroup
	for i := 0; i < maxWorkers; i++ {
		wg.Add(1)
		go func(workerID int) {
			defer wg.Done()
			for j := range jobs {
				utxos, err := b.GetTxUtxos(ctx, j.assetTx.TxHash)
				if err != nil {
					if ctx.Err() != nil {
						return
					}
					fmt.Fprintf(os.Stderr, "warning: worker %d: failed to fetch utxos for tx %s: %v\n", workerID, j.assetTx.TxHash, err)
					time.Sleep(300 * time.Millisecond)
					continue
				}
				// Distinct sender addresses (tx inputs).
				fromSet := make(map[string]struct{})
				for _, inp := range utxos.Inputs {
					fromSet[inp.Address] = struct{}{}
				}
				fromAddrs := make([]string, 0, len(fromSet))
				for a := range fromSet {
					fromAddrs = append(fromAddrs, a)
				}
				// Emit one match per destination output carrying the asset.
				for _, out := range utxos.Outputs {
					if out.Address != destAddress {
						continue
					}
					for _, amt := range out.Amount {
						if amt.Unit != asset {
							continue
						}
						m := Match{
							TxHash:      j.assetTx.TxHash,
							BlockHash:   j.assetTx.Block,
							OutputIndex: out.OutputIndex,
							FromAddrs:   fromAddrs,
							ToAddress:   out.Address,
							Quantity:    amt.Quantity,
							Unit:        amt.Unit,
						}
						// Guarded send: the collector may already have
						// returned on error or cancellation.
						select {
						case results <- m:
						case <-ctx.Done():
							return
						}
					}
				}
				time.Sleep(80 * time.Millisecond) // simple rate limiting
			}
		}(i)
	}
	var feedErr error
	doneFeed := make(chan struct{})
	go func() {
		defer close(doneFeed)
		defer close(jobs)
		for tx := range assetTxsCh {
			// Guarded send: if all workers exited on cancellation, an
			// unguarded jobs<- would leak this goroutine.
			select {
			case jobs <- job{assetTx: tx}:
			case <-ctx.Done():
				feedErr = ctx.Err()
				return
			}
		}
	}()
	go func() {
		<-doneFeed
		wg.Wait()
		close(results)
	}()
	var matches []Match
	for results != nil || errCh != nil {
		select {
		case e, ok := <-errCh:
			if !ok {
				// Disable this case once closed; otherwise the select
				// would spin on always-ready nil receives.
				errCh = nil
				continue
			}
			if e != nil {
				return nil, e
			}
		case m, ok := <-results:
			if !ok {
				results = nil
				continue
			}
			matches = append(matches, m)
		case <-ctx.Done():
			return nil, ctx.Err()
		}
	}
	if feedErr != nil {
		return nil, feedErr
	}
	return matches, nil
}
package ethereum
import (
"encoding/json"
)
// GetBlockNumber returns the number of the most recent block.
func GetBlockNumber(c Caller) (uint64, error) {
	resp, err := c.Call("eth_blockNumber", nil)
	if err != nil {
		return 0, err
	}
	if resp.Error != nil {
		return 0, resp.Error
	}
	// The node returns the height as a 0x-prefixed hex string.
	var blockHex string
	if err := json.Unmarshal(resp.Result, &blockHex); err != nil {
		return 0, err
	}
	return hexToBigInt(blockHex).Uint64(), nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package ethereum
import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"time"
)
// Caller abstracts the JSON-RPC transport so helpers can be used with
// alternative implementations (e.g. fakes in tests).
type Caller interface {
	Call(method string, params []interface{}) (*RPCResponse, error)
}

// Client represents a JSON-RPC Ethereum client
type Client struct {
	URL    string // RPC endpoint URL
	Token  string // optional bearer token; empty means no Authorization header
	Client *http.Client
}

// RPCRequest represents a JSON-RPC request
type RPCRequest struct {
	JSONRPC string        `json:"jsonrpc"`
	Method  string        `json:"method"`
	Params  []interface{} `json:"params"`
	ID      int           `json:"id"`
}

// RPCResponse represents a JSON-RPC response
type RPCResponse struct {
	JSONRPC string          `json:"jsonrpc"`
	ID      int             `json:"id"`
	Result  json.RawMessage `json:"result"` // left raw so callers decode per-method
	Error   *RPCError       `json:"error,omitempty"`
}

// RPCError represents a JSON-RPC error
type RPCError struct {
	Code    int    `json:"code"`
	Message string `json:"message"`
}

// Error implements the error interface for RPCError.
func (e *RPCError) Error() string {
	return fmt.Sprintf("RPC error %d: %s", e.Code, e.Message)
}
// NewClient creates a new Ethereum JSON-RPC client.
// The underlying HTTP client carries a 30-second timeout so a hung RPC
// endpoint cannot block callers indefinitely (the original used a zero
// timeout, i.e. wait forever; the cardano client in this codebase already
// uses a 30s timeout).
func NewClient(url, token string) *Client {
	return &Client{
		URL:   url,
		Token: token,
		Client: &http.Client{
			Timeout: 30 * time.Second,
		},
	}
}
// Call sends a JSON-RPC request and returns the decoded response envelope.
// RPC-level errors are reported in the response's Error field, not as a
// Go error.
func (c *Client) Call(method string, params []interface{}) (*RPCResponse, error) {
	payload, err := json.Marshal(RPCRequest{
		JSONRPC: "2.0",
		Method:  method,
		Params:  params,
		ID:      1,
	})
	if err != nil {
		return nil, err
	}
	req, err := http.NewRequest("POST", c.URL, bytes.NewReader(payload))
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/json")
	if c.Token != "" {
		req.Header.Set("Authorization", "Bearer "+c.Token)
	}
	resp, err := c.Client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	raw, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, err
	}
	rpcResp := new(RPCResponse)
	if err := json.Unmarshal(raw, rpcResp); err != nil {
		return nil, err
	}
	return rpcResp, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package ethereum
import (
"encoding/json"
"fmt"
"math/big"
"strings"
)
// hexToAddress extracts the 20-byte address (the low-order 40 hex digits)
// from a 32-byte log topic and returns it 0x-prefixed.
// Inputs shorter than 40 characters are returned unchanged instead of
// panicking on the slice below (the original had no length guard).
func hexToAddress(topic string) string {
	if len(topic) < 40 {
		return topic
	}
	return "0x" + topic[len(topic)-40:]
}
// hexToBigInt parses a (optionally 0x-prefixed) hex string into a big.Int.
// On malformed input it returns zero; the original ignored SetString's ok
// result, whose value is documented as undefined on failure.
func hexToBigInt(hexStr string) *big.Int {
	n, ok := new(big.Int).SetString(strings.TrimPrefix(hexStr, "0x"), 16)
	if !ok {
		return new(big.Int)
	}
	return n
}
// ERC20Tx is one decoded ERC-20 Transfer event.
type ERC20Tx struct {
	From   string // sender address, 0x-prefixed
	To     string // recipient address, 0x-prefixed
	Amount string // decimal amount rendered by convertToDecimals
	TxHash string // hash of the transaction containing the event
}
// GetERC20Transfers returns ERC-20 Transfer events sent to toAddress by the
// given token contract within [fromBlock, toBlock] (block tags as accepted
// by eth_getLogs). Amounts are rendered with 6 decimals. Malformed log
// entries are skipped, leaving zero-value entries in the result slice
// (matching the original's indexing behavior).
func GetERC20Transfers(c Caller, tokenAddress, toAddress string, fromBlock, toBlock string) ([]ERC20Tx, error) {
	// keccak256("Transfer(address,address,uint256)")
	transferTopic := "0xddf252ad1be2c89b69c2b068fc378daa952ba7f163c4a11628f55a4df523b3ef"
	// Topic filters are 32-byte words: left-pad the 20-byte address with
	// 12 zero bytes (24 hex characters).
	paddedTo := "0x" + strings.Repeat("0", 24) + strings.ToLower(strings.TrimPrefix(toAddress, "0x"))
	params := []interface{}{
		map[string]interface{}{
			"fromBlock": fromBlock,
			"toBlock":   toBlock,
			"address":   tokenAddress,
			"topics":    []interface{}{transferTopic, nil, paddedTo},
		},
	}
	resp, err := c.Call("eth_getLogs", params)
	if err != nil {
		return nil, err
	}
	if resp.Error != nil {
		// BUG FIX: the original returned `err` here, which is always nil at
		// this point, silently swallowing RPC errors.
		return nil, resp.Error
	}
	var logs []map[string]interface{}
	if err := json.Unmarshal(resp.Result, &logs); err != nil {
		return nil, err
	}
	txs := make([]ERC20Tx, len(logs))
	for i, l := range logs {
		// A Transfer log carries 3 topics (signature, from, to); the
		// original only checked for 0 and then indexed topics[1]/[2],
		// panicking on anomalous logs.
		topics, ok := l["topics"].([]interface{})
		if !ok || len(topics) < 3 {
			continue
		}
		var txHash string
		if h, ok := l["transactionHash"].(string); ok {
			txHash = h
		}
		fromTopic, okFrom := topics[1].(string)
		toTopic, okTo := topics[2].(string)
		data, okData := l["data"].(string)
		if !okFrom || !okTo || !okData {
			continue
		}
		txs[i] = ERC20Tx{
			From:   hexToAddress(fromTopic),
			To:     hexToAddress(toTopic),
			Amount: convertToDecimals(hexToBigInt(data), 6),
			TxHash: txHash,
		}
	}
	return txs, nil
}
// convertToDecimals converts a big.Int token amount into a string with decimals
func convertToDecimals(amount *big.Int, decimals int) string {
dec := big.NewInt(0).Exp(big.NewInt(10), big.NewInt(int64(decimals)), nil) // 10^decimals
intPart := new(big.Int).Div(amount, dec)
fracPart := new(big.Int).Mod(amount, dec)
fracStr := fmt.Sprintf("%0*d", decimals, fracPart)
return fmt.Sprintf("%s.%s", intPart.String(), fracStr)
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package tokenomics
import (
"context"
"encoding/json"
"errors"
"fmt"
"strconv"
"time"
"github.com/google/uuid"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/dms/behaviors"
"gitlab.com/nunet/device-management-service/lib/crypto"
"gitlab.com/nunet/device-management-service/lib/did"
"gitlab.com/nunet/device-management-service/lib/ucan"
"gitlab.com/nunet/device-management-service/network"
"gitlab.com/nunet/device-management-service/observability"
"gitlab.com/nunet/device-management-service/tokenomics/contracts"
"gitlab.com/nunet/device-management-service/tokenomics/contracts/processors"
"gitlab.com/nunet/device-management-service/tokenomics/events"
contractstore "gitlab.com/nunet/device-management-service/tokenomics/store"
"gitlab.com/nunet/device-management-service/tokenomics/store/usage"
)
const (
	// contractActorCapsLifespan is how long capabilities granted by the
	// contract actor remain valid (7 days).
	contractActorCapsLifespan = 7 * 24 * time.Hour
	// contractPeriodicChecker is how often the actor re-checks whether
	// the contract has passed its end date.
	contractPeriodicChecker = 5 * time.Minute
)
// FixedRentalBillingCheckerInterval controls how often the contract actor checks
// whether a Fixed Rental invoice should be generated.
// In production this is set to 15 minutes to balance timeliness and load.
var FixedRentalBillingCheckerInterval = 15 * time.Minute

// PeriodicBillingCheckerInterval controls how often the contract actor checks
// whether a Periodic invoice should be generated.
// (A previous comment claimed this was 1 minute "for testing"; the actual
// value is 15 minutes.)
var PeriodicBillingCheckerInterval = 15 * time.Minute

// Sentinel errors for fixed rental and periodic invoice calculation.
var (
	// ErrFullPeriodElapsed signals that a full billing period has passed,
	// so regular (non-pro-rated) billing applies.
	ErrFullPeriodElapsed = errors.New("full billing period has elapsed, use regular billing instead of pro-rated")
	// ErrPeriodNotElapsed signals that it is too early to invoice.
	ErrPeriodNotElapsed = errors.New("billing period has not elapsed yet, no invoice needed")
	// ErrNoDeployments signals that nothing was deployed during the period.
	ErrNoDeployments = errors.New("no deployments active during billing period, skipping invoice")
)
// ContractActor is the actor representation of a single contract. It embeds
// a BasicActor for messaging and carries the stores, participants, and DIDs
// the contract interacts with.
type ContractActor struct {
	*actor.BasicActor
	ContractDID        did.DID // DID derived from the contract's key pair
	SolutionEnablerDID did.DID // DID of the solution enabler that created the contract
	forwardInvoice     func(contracts.ContractUsageRequest) error // Function to forward invoice using solution enabler's actor
	ctx                context.Context    // lifecycle context for the periodic end-date checker
	cancel             context.CancelFunc // stops the periodic end-date checker
	contractStore      *contractstore.Store
	participants       contracts.ContractParticipants
	PaymentProviderDID did.DID
	usageStore         *usage.Store
}
// NewContractActor constructs the actor representing a single contract: it
// derives the contract DID from the supplied key pair, builds a capability
// context anchored on the solution enabler, wires up behaviors, and grants
// the participants and the payment validator their capabilities.
func NewContractActor(
	solutionEnabler actor.Handle,
	paymentValidator did.DID,
	net network.Network,
	participants contracts.ContractParticipants,
	privKey crypto.PrivKey, pubKey crypto.PubKey,
	contractStore *contractstore.Store,
	usageStore *usage.Store,
	forwardInvoice func(contracts.ContractUsageRequest) error, // Function to forward invoice using solution enabler's actor
) (*ContractActor, error) {
	provider, err := did.ProviderFromPrivateKey(privKey)
	if err != nil {
		return nil, err
	}
	trustCtx := did.NewTrustContext()
	trustCtx.AddProvider(provider)
	contractKeyDID := did.FromPublicKey(pubKey)
	contractCap, err := ucan.NewCapabilityContext(trustCtx, contractKeyDID, []did.DID{solutionEnabler.DID}, ucan.TokenList{}, ucan.TokenList{}, ucan.TokenList{})
	if err != nil {
		return nil, fmt.Errorf("contract actor capability context: %w", err)
	}
	securityCtx, err := actor.NewBasicSecurityContext(pubKey, privKey, contractCap)
	if err != nil {
		return nil, err
	}
	self := actor.Handle{
		ID:  securityCtx.ID(),
		DID: securityCtx.DID(),
		Address: actor.Address{
			HostID:       net.GetHostID().String(),
			InboxAddress: contractKeyDID.URI,
		},
	}
	// NOTE: the original named this variable "actor", shadowing the actor
	// package for the rest of the function.
	basicActor, err := actor.New(self, net, securityCtx, actor.NewRateLimiter(actor.DefaultRateLimiterConfig()), actor.BasicActorParams{}, self)
	if err != nil {
		return nil, fmt.Errorf("new contract actor: %w", err)
	}
	actorCtx, cancel := context.WithCancel(context.Background())
	contractActor := ContractActor{
		BasicActor:         basicActor,
		ContractDID:        contractKeyDID,
		SolutionEnablerDID: solutionEnabler.DID,
		forwardInvoice:     forwardInvoice,
		ctx:                actorCtx,
		cancel:             cancel,
		contractStore:      contractStore,
		participants:       participants,
		usageStore:         usageStore,
		PaymentProviderDID: paymentValidator,
	}
	// Release the context on every failure path below; the original leaked
	// `cancel` when setup failed (go vet lostcancel).
	if err := contractActor.setupBehaviorsAndCapabilities(); err != nil {
		cancel()
		return nil, fmt.Errorf("setting up behaviors and capabilities: %w", err)
	}
	if err := contractActor.SetupParticipantsCapabilities(participants); err != nil {
		cancel()
		return nil, fmt.Errorf("failed to setup participant capabilities: %w", err)
	}
	if err := contractActor.setupPaymentValidatorBehaviorAndCapabilities(paymentValidator); err != nil {
		cancel()
		return nil, fmt.Errorf("failed to setup payment validator capabilities: %w", err)
	}
	return &contractActor, nil
}
// Start starts the embedded actor and launches a background goroutine that
// periodically marks the contract completed once its end date passes. The
// goroutine exits when c.ctx is cancelled or the contract completes.
func (c *ContractActor) Start() error {
	if err := c.BasicActor.Start(); err != nil {
		return fmt.Errorf("failed to start contract actor: %w", err)
	}
	go func() {
		ticker := time.NewTicker(contractPeriodicChecker)
		defer ticker.Stop()
		for {
			// Wait on both channels so cancellation is honored immediately;
			// the original only polled ctx.Done on each tick, delaying
			// shutdown by up to contractPeriodicChecker.
			select {
			case <-c.ctx.Done():
				return
			case <-ticker.C:
			}
			// Refresh contract from store to get latest state
			contract, err := c.contractStore.GetContract(c.ContractDID.URI)
			if err != nil {
				log.Errorw("contract not found while checking its status",
					"labels", string(observability.LabelContract),
					"contract_did", c.ContractDID.URI,
					"error", err)
				continue
			}
			if time.Now().After(contract.Duration.EndDate) {
				// if we reach duration then we mark it as completed
				contract.CurrentState = contracts.ContractCompleted
				if err := c.contractStore.Upsert(contract); err != nil {
					log.Errorw("failed to update contract with status completed",
						"labels", string(observability.LabelContract),
						"contract_did", c.ContractDID.URI)
				}
				log.Infow("contract has reached its end date",
					"labels", string(observability.LabelContract),
					"contract_did", c.ContractDID.URI,
					"end_date", contract.Duration.EndDate)
				return // Contract completed, stop checking
			}
		}
	}()
	// Billing is handled by the centralized scheduler; registration happens
	// via RegisterBilling() called by the Node (no per-actor billing loop).
	log.Infow("contract actor started",
		"labels", string(observability.LabelContract),
		"contract_did", c.ContractDID.URI,
	)
	return nil
}
// setupBehaviorsAndCapabilities registers every contract behavior handler on
// the underlying actor and grants the solution enabler (the contract creator)
// the capabilities it needs to interact with this contract.
func (c *ContractActor) setupBehaviorsAndCapabilities() error {
	for name, h := range c.getContractBehaviors() {
		if err := c.BasicActor.AddBehavior(name, h.fn, h.opts...); err != nil {
			return fmt.Errorf("adding %s behavior: %w", name, err)
		}
	}
	// Grant all capabilities to contract creator
	creatorCaps := []ucan.Capability{
		behaviors.TokenomicNamespace,
		behaviors.ContractStatusBehavior,
		behaviors.ContractProposeBehavior,
	}
	if err := c.Security().Grant(c.SolutionEnablerDID, c.ContractDID, creatorCaps, contractActorCapsLifespan); err != nil {
		return fmt.Errorf("failed to grant capabilities: %w", err)
	}
	return nil
}
// setupPaymentValidatorBehaviorAndCapabilities grants the payment validator
// the capability to invoke the payment-validation behavior on this contract.
func (c *ContractActor) setupPaymentValidatorBehaviorAndCapabilities(paymentValidatorDID did.DID) error {
	validatorCaps := []ucan.Capability{behaviors.ContractPaymentValidateBehavior}
	if err := c.Security().Grant(paymentValidatorDID, c.ContractDID, validatorCaps, contractActorCapsLifespan); err != nil {
		return fmt.Errorf("failed to grant capabilities to Provider: %w", err)
	}
	return nil
}
// SetupParticipantsCapabilities grants both contract participants (provider
// and requestor) the identical set of capabilities required to operate on
// this contract. Each grant lasts contractActorCapsLifespan.
//
// Returns an error naming the participant whose grant failed.
func (c *ContractActor) SetupParticipantsCapabilities(participants contracts.ContractParticipants) error {
	// Both parties receive the same capability set; keeping it in one slice
	// prevents the two grants from drifting apart over time.
	participantCaps := []ucan.Capability{
		behaviors.ContractEventsBehavior,
		behaviors.ContractTerminationBehavior,
		behaviors.ContractCompleteBehavior,
		behaviors.ContractStatusBehavior,
		behaviors.ContractSettleBehavior,
		behaviors.ContractValidationBehavior,
		behaviors.ContractSignBehavior,
		behaviors.ContractPaymentValidateBehavior,
	}
	grants := []struct {
		role string
		did  did.DID
	}{
		{"Provider", participants.Provider},
		{"Requestor", participants.Requestor},
	}
	for _, g := range grants {
		if err := c.Security().Grant(g.did, c.ContractDID, participantCaps, contractActorCapsLifespan); err != nil {
			return fmt.Errorf("failed to grant capabilities to %s: %w", g.role, err)
		}
	}
	return nil
}
// getContractBehaviors returns a map of behavior names to their handler
// functions and optional behavior options.
func (c *ContractActor) getContractBehaviors() map[string]struct {
	fn   func(actor.Envelope)
	opts []actor.BehaviorOption
} {
	// Alias keeps the literal below readable; it is identical to the
	// anonymous struct in the return type.
	type handler = struct {
		fn   func(actor.Envelope)
		opts []actor.BehaviorOption
	}
	return map[string]handler{
		behaviors.ContractTerminationBehavior:     {fn: c.handleContractTermination},
		behaviors.ContractCompleteBehavior:        {fn: c.handleCompleteContract},
		behaviors.ContractSettleBehavior:          {fn: c.handleSettleContract},
		behaviors.ContractStatusBehavior:          {fn: c.handleContractState},
		behaviors.ContractValidationBehavior:      {fn: c.handleContractValidation},
		behaviors.ContractSignBehavior:            {fn: c.handleContractSignByParticipants},
		behaviors.ContractPaymentValidateBehavior: {fn: c.handlePaymentValidate},
		behaviors.ContractEventsBehavior:          {fn: c.handleContractEvents},
	}
}
// handleContractEvents records an incoming usage event against this contract.
// The event payload may optionally carry "type", "provider_did" and
// "head_contract_did" JSON fields, which are extracted for indexing; when no
// provider DID is present the message sender's DID is used instead.
func (c *ContractActor) handleContractEvents(msg actor.Envelope) {
	defer msg.Discard()
	var resp contracts.ContractEventResponse
	var req contracts.ContractEventRequest
	if err := json.Unmarshal(msg.Message, &req); err != nil {
		resp.Error = err.Error()
		c.sendReply(msg, resp)
		return
	}
	// Best-effort extraction of metadata from the event payload; a payload
	// that is not a JSON object simply yields empty values.
	var (
		eventType       events.EventType
		providerDID     string
		headContractDID string
		payload         map[string]interface{}
	)
	if json.Unmarshal(req.Payload, &payload) == nil {
		if s, ok := payload["type"].(string); ok {
			eventType = events.EventType(s)
		}
		if s, ok := payload["provider_did"].(string); ok {
			providerDID = s
		}
		if s, ok := payload["head_contract_did"].(string); ok {
			headContractDID = s
		}
	}
	// Fall back to the message sender for backwards compatibility.
	if providerDID == "" {
		providerDID = msg.From.DID.String()
	}
	usageEvent := usage.Usage{
		ContractDID:     c.ContractDID.URI, // Tail Contract DID (this contract)
		HeadContractDID: headContractDID,   // Head Contract DID (if part of chain)
		ProviderDID:     providerDID,
		EventType:       eventType, // Optional - extracted from JSON if present
		Data:            req.Payload,
	}
	if err := c.usageStore.AddUsageEvent(usageEvent); err != nil {
		resp.Error = err.Error()
	}
	c.sendReply(msg, resp)
}
// handlePaymentValidate marks the stored contract as paid and replies with
// an empty response, or with an error message when the load/store fails.
func (c *ContractActor) handlePaymentValidate(msg actor.Envelope) {
	defer msg.Discard()
	var resp contracts.PaymentValidateResponse
	contract, err := c.contractStore.GetContract(c.ContractDID.URI)
	if err == nil {
		contract.Paid = true
		err = c.contractStore.Upsert(contract)
	}
	if err != nil {
		resp.Error = err.Error()
	}
	c.sendReply(msg, resp)
}
// handleContractSignByParticipants records a participant's signature on the
// contract. Only the provider or requestor may sign, each at most once; once
// both signatures are present the contract transitions to Accepted.
func (c *ContractActor) handleContractSignByParticipants(msg actor.Envelope) {
	defer msg.Discard()
	var resp contracts.ContractSignResponse
	var req contracts.ContractSignRequest
	if err := json.Unmarshal(msg.Message, &req); err != nil {
		resp.Error = err.Error()
		c.sendReply(msg, resp)
		return
	}
	contract, err := c.contractStore.GetContract(c.ContractDID.URI)
	if err != nil {
		resp.Error = err.Error()
		c.sendReply(msg, resp)
		return
	}
	// Reject a second signature from the same DID.
	for _, sig := range contract.Signatures {
		if sig.DID == msg.From.DID {
			resp.Error = "contract already signed"
			c.sendReply(msg, resp)
			return
		}
	}
	// Capability grants already restrict who can reach this behavior, but
	// double-check the sender is one of the two contract participants.
	if contract.ContractParticipants.Provider != msg.From.DID && contract.ContractParticipants.Requestor != msg.From.DID {
		resp.Error = "not allowed to sign this contract"
		c.sendReply(msg, resp)
		return
	}
	contract.Signatures = append(contract.Signatures, contracts.Signature{
		DID:        msg.From.DID,
		Signatures: req.Signature,
	})
	// With both signatures present the contract becomes accepted; no extra
	// checks are needed since only participants can invoke this behavior.
	// NOTE(review): Transitions is replaced here rather than appended to —
	// presumably signing always starts from Draft; confirm no earlier
	// transitions can exist at this point.
	if len(contract.Signatures) == 2 {
		contract.CurrentState = contracts.ContractAccepted
		contract.Transitions = []contracts.StateTransition{
			{
				FromState: contracts.ContractDraft,
				ToState:   contracts.ContractAccepted,
				Event:     contracts.EventAccepted,
				Timestamp: time.Now(),
			},
		}
	}
	if err := c.contractStore.Upsert(contract); err != nil {
		resp.Error = err.Error()
		c.sendReply(msg, resp)
		return
	}
	resp.Contract = *contract
	c.sendReply(msg, resp)
}
// handleContractState replies with the current stored contract, or with an
// error message when the contract cannot be loaded.
func (c *ContractActor) handleContractState(msg actor.Envelope) {
	defer msg.Discard()
	var resp contracts.ContractStatusResponse
	if contract, err := c.contractStore.GetContract(c.ContractDID.URI); err != nil {
		resp.Error = err.Error()
	} else {
		resp.Contract = *contract
	}
	c.sendReply(msg, resp)
}
// sendReply sends a reply to the given message envelope with the provided
// payload. For broadcast messages the reply is tagged with this actor's
// handle as the source. Failures are logged rather than propagated, since
// behavior handlers have no caller to report errors to.
func (c *ContractActor) sendReply(msg actor.Envelope, payload interface{}) {
	var opt []actor.MessageOption
	if msg.IsBroadcast() {
		opt = append(opt, actor.WithMessageSource(c.Handle()))
	}
	reply, err := actor.ReplyTo(msg, payload, opt...)
	if err != nil {
		// Previously swallowed silently; log so dropped replies are visible.
		log.Errorw("failed to construct reply",
			"labels", string(observability.LabelContract),
			"contract_did", c.ContractDID.URI,
			"error", err)
		return
	}
	if err := c.Send(reply); err != nil {
		log.Errorw("failed to send reply",
			"labels", string(observability.LabelContract),
			"contract_did", c.ContractDID.URI,
			"error", err)
	}
}
// handleContractTermination transitions the contract to Terminated, provided
// its termination option allows it, and records the transition together with
// the DID that initiated it.
func (c *ContractActor) handleContractTermination(msg actor.Envelope) {
	defer msg.Discard()
	var resp contracts.ContractTerminationResponse
	stored, err := c.contractStore.GetContract(c.ContractDID.URI)
	if err != nil {
		resp.Error = err.Error()
		c.sendReply(msg, resp)
		return
	}
	if !stored.TerminationOption.Allowed {
		resp.Error = "contract is not allowed to be terminated"
		c.sendReply(msg, resp)
		return
	}
	// The new transition starts from the most recent recorded state
	// (zero value when no transitions exist yet).
	var fromState contracts.ContractState
	if n := len(stored.Transitions); n > 0 {
		fromState = stored.Transitions[n-1].ToState
	}
	stored.CurrentState = contracts.ContractTerminated
	stored.Transitions = append(stored.Transitions, contracts.StateTransition{
		FromState:   fromState,
		ToState:     contracts.ContractTerminated,
		Timestamp:   time.Now(),
		Event:       contracts.EventTerminate,
		InitiatedBy: msg.From.DID,
	})
	if err := c.contractStore.Upsert(stored); err != nil {
		resp.Error = err.Error()
	}
	c.sendReply(msg, resp)
}
// handleCompleteContract transitions the contract to Completed and records
// the transition together with the DID that initiated it.
func (c *ContractActor) handleCompleteContract(msg actor.Envelope) {
	defer msg.Discard()
	var resp contracts.ContractCompletionResponse
	stored, err := c.contractStore.GetContract(c.ContractDID.URI)
	if err != nil {
		resp.Error = err.Error()
		c.sendReply(msg, resp)
		return
	}
	// The new transition starts from the most recent recorded state
	// (zero value when no transitions exist yet).
	var fromState contracts.ContractState
	if n := len(stored.Transitions); n > 0 {
		fromState = stored.Transitions[n-1].ToState
	}
	stored.CurrentState = contracts.ContractCompleted
	stored.Transitions = append(stored.Transitions, contracts.StateTransition{
		FromState:   fromState,
		ToState:     contracts.ContractCompleted,
		Timestamp:   time.Now(),
		Event:       contracts.EventComplete,
		InitiatedBy: msg.From.DID,
	})
	if err := c.contractStore.Upsert(stored); err != nil {
		resp.Error = err.Error()
	}
	c.sendReply(msg, resp)
}
// handleSettleContract transitions the contract to Settled and records the
// transition together with the DID that initiated it.
func (c *ContractActor) handleSettleContract(msg actor.Envelope) {
	defer msg.Discard()
	var resp contracts.ContractSettleResponse
	stored, err := c.contractStore.GetContract(c.ContractDID.URI)
	if err != nil {
		resp.Error = err.Error()
		c.sendReply(msg, resp)
		return
	}
	// The new transition starts from the most recent recorded state
	// (zero value when no transitions exist yet).
	var fromState contracts.ContractState
	if n := len(stored.Transitions); n > 0 {
		fromState = stored.Transitions[n-1].ToState
	}
	stored.CurrentState = contracts.ContractSettled
	stored.Transitions = append(stored.Transitions, contracts.StateTransition{
		FromState:   fromState,
		ToState:     contracts.ContractSettled,
		Timestamp:   time.Now(),
		Event:       contracts.EventSettle,
		InitiatedBy: msg.From.DID,
	})
	if err := c.contractStore.Upsert(stored); err != nil {
		resp.Error = err.Error()
	}
	c.sendReply(msg, resp)
}
// handleContractValidation reports whether the contract is currently valid
// (in the Accepted or Active state) along with its current state string.
func (c *ContractActor) handleContractValidation(msg actor.Envelope) {
	defer msg.Discard()
	var resp contracts.ContractValidateResponse
	contract, err := c.contractStore.GetContract(c.ContractDID.URI)
	if err != nil {
		resp.Error = err.Error()
		c.sendReply(msg, resp)
		return
	}
	state := contract.CurrentState
	resp.CurrentStatus = string(state)
	resp.Valid = state == contracts.ContractAccepted || state == contracts.ContractActive
	c.sendReply(msg, resp)
}
// sendFixedRentalInvoice forwards a fixed-rental invoice to the payment
// validator via the solution enabler's actor (the contract host node actor,
// which holds the required capabilities) and, on success, records `now` as
// the last processed timestamp.
func (c *ContractActor) sendFixedRentalInvoice(
	contract *contracts.Contract,
	fixedRentalUsage *contracts.FixedRentalUsage,
	now time.Time,
) {
	if c.forwardInvoice == nil {
		log.Errorw("forwardInvoice function not set for contract actor",
			"labels", string(observability.LabelContract),
			"contract_did", c.ContractDID.URI)
		return
	}
	req := contracts.ContractUsageRequest{
		UniqueID:           uuid.NewString(),
		Contract:           *contract,
		Usages:             fixedRentalUsage.PeriodsInvoiced,
		FixedRentalDetails: fixedRentalUsage,
	}
	if err := c.forwardInvoice(req); err != nil {
		log.Errorw("failed to forward fixed rental invoice to payment validator",
			"labels", string(observability.LabelContract),
			"contract_did", c.ContractDID.URI,
			"error", err)
		return
	}
	// Record the invoice time only after a successful send; if the payment
	// validator's processing later fails, the next cycle will catch it.
	if err := c.usageStore.SaveLastProcessedAt(c.ContractDID.URI, now); err != nil {
		log.Errorw("failed to save last processed timestamp after fixed rental invoice",
			"labels", string(observability.LabelContract),
			"contract_did", c.ContractDID.URI,
			"error", err)
		return
	}
	log.Infow("fixed rental invoice generated and sent successfully",
		"labels", string(observability.LabelContract),
		"contract_did", c.ContractDID.URI,
		"periods_invoiced", fixedRentalUsage.PeriodsInvoiced,
		"amount", fixedRentalUsage.Amount)
}
// checkAndGenerateInvoice is the unified invoice checking and generation logic.
// It is invoked by the centralized billing scheduler via CheckAndGenerateInvoice.
// Returns true if the billing routine should stop (contract terminated/completed,
// billing disabled, or the payment model does not support automatic billing),
// false otherwise (including transient errors, which are retried next cycle).
func (c *ContractActor) checkAndGenerateInvoice() bool {
	// Get current contract state
	contract, err := c.contractStore.GetContract(c.ContractDID.URI)
	if err != nil {
		log.Errorw("failed to get contract for automatic billing",
			"labels", string(observability.LabelContract),
			"contract_did", c.ContractDID.URI,
			"error", err)
		return false // Continue checking (might be transient error)
	}
	// Skip all billing if explicitly disabled (Contract A)
	// Organization contracts (Contract A) will be billed manually by the organization
	// outside the contract system
	if contract.DisableBilling {
		log.Debugw("skipping billing (disabled by contract flag)",
			"labels", string(observability.LabelContract),
			"contract_did", c.ContractDID.URI)
		return true // Stop billing routine
	}
	// Get processor for this payment model
	processor, err := contracts.GetPaymentModelProcessor(contract.PaymentDetails.PaymentModel)
	if err != nil {
		log.Errorw("failed to get payment model processor",
			"labels", string(observability.LabelContract),
			"contract_did", c.ContractDID.URI,
			"error", err)
		return true // Stop routine if processor not found
	}
	// Defensive check: Only process contracts that support automatic billing
	if !processor.SupportsAutomaticBilling() {
		return true // Stop routine if payment model doesn't support automatic billing
	}
	// Check if contract is terminated - generate final (possibly pro-rated) invoice
	if contract.CurrentState == contracts.ContractTerminated {
		return c.handleTerminatedContractInvoice(contract, processor)
	}
	// Check if contract is completed or expired
	if contract.CurrentState == contracts.ContractCompleted ||
		time.Now().After(contract.Duration.EndDate) {
		return true // Stop billing routine
	}
	// Get last invoice timestamp
	lastInvoiceAt, err := c.usageStore.GetLastProcessedAt(c.ContractDID.URI)
	if err != nil {
		log.Errorw("failed to get last processed timestamp for automatic billing",
			"labels", string(observability.LabelContract),
			"contract_did", c.ContractDID.URI,
			"error", err)
		return false // Continue checking (might be transient error)
	}
	now := time.Now()
	// Initialize lastInvoiceAt if zero (first invoice)
	// For contracts that start mid-period, first invoice should cover partial period
	// by using contract start date as baseline.
	// For automatic-only models (FixedRental, Periodic), we save the initialization
	// to prevent infinite re-initialization loop. For models that support manual billing,
	// we don't save to avoid interfering with manual collection which should start from zero.
	unixEpoch := time.Unix(0, 0)
	if lastInvoiceAt.IsZero() || lastInvoiceAt.Equal(unixEpoch) {
		// For first invoice calculation, use contract start date to ensure partial period coverage
		// This ensures no usage is lost if contract starts mid-period
		lastInvoiceAt = contract.Duration.StartDate
		// Save initialization timestamp for all automatic billing models
		// This prevents infinite re-initialization loop.
		// Manual collection can still work correctly - it will use this saved timestamp
		// when querying, which is correct because it represents the last time automatic
		// billing processed usage. Manual collection queries from lastProcessedAt to now,
		// so if it's the first manual collection, it will still capture all usage from
		// contract start date to now.
		if err := c.usageStore.SaveLastProcessedAt(c.ContractDID.URI, lastInvoiceAt); err != nil {
			log.Errorw("failed to save initial timestamp for automatic billing",
				"labels", string(observability.LabelContract),
				"contract_did", c.ContractDID.URI,
				"error", err)
		}
		// Return false to continue checking - not enough time has passed yet for first invoice
		return false
	}
	// Use processor to check and generate invoice
	usageData, err := processor.CheckAndGenerateInvoice(contract, lastInvoiceAt, now)
	if err != nil {
		// Check for specific error types
		if errors.Is(err, processors.ErrPeriodNotElapsed) {
			// Not enough time has passed yet, check again next time
			return false // Continue checking
		}
		if errors.Is(err, processors.ErrNoDeployments) {
			// No deployments during period - skip invoice and update timestamp
			// This is specific to Periodic model
			if contract.PaymentDetails.PaymentModel == contracts.Periodic {
				// Update lastInvoiceAt to skip this period
				periodDuration, err := parsePaymentPeriod(contract.PaymentDetails.PaymentPeriod)
				if err == nil {
					// A billing cycle is PaymentPeriodCount payment periods;
					// treat a non-positive count as one period.
					paymentPeriodCount := contract.PaymentDetails.PaymentPeriodCount
					if paymentPeriodCount <= 0 {
						paymentPeriodCount = 1
					}
					billingCycleDuration := periodDuration * time.Duration(paymentPeriodCount)
					nextPeriodStart := lastInvoiceAt.Add(billingCycleDuration)
					if err := c.usageStore.SaveLastProcessedAt(c.ContractDID.URI, nextPeriodStart); err != nil {
						log.Errorw("failed to update last processed timestamp after skipping period",
							"labels", string(observability.LabelContract),
							"contract_did", c.ContractDID.URI,
							"error", err)
					}
				}
			}
			return false // Continue checking for next period
		}
		// Other error occurred - log but continue checking (might be transient error)
		log.Errorw("failed to check and generate invoice",
			"labels", string(observability.LabelContract),
			"contract_did", c.ContractDID.URI,
			"error", err)
		return false // Continue checking (might be transient error)
	}
	if usageData != nil {
		// Invoice should be generated - send it
		c.sendInvoice(contract, usageData, now)
		return false // Continue billing routine
	}
	return false
}
// CheckAndGenerateInvoice is the public entry point for checking and
// generating invoices, called by the centralized billing scheduler.
// It returns a non-nil error when the contract is terminated or completed,
// signaling the scheduler to unregister this contract.
func (c *ContractActor) CheckAndGenerateInvoice() error {
	if shouldStop := c.checkAndGenerateInvoice(); shouldStop {
		// Contract terminated/completed - scheduler will detect this and unregister.
		// errors.New is the idiomatic form for a static error message
		// (fmt.Errorf adds nothing without format verbs).
		return errors.New("contract terminated or completed")
	}
	return nil
}
// RegisterBilling registers this contract with the centralized billing
// scheduler. It is called by the Node after creating the actor; the
// scheduler then drives periodic invoicing via CheckAndGenerateInvoice.
func (c *ContractActor) RegisterBilling(scheduler *ContractBillingScheduler) error {
	return scheduler.RegisterContract(c.ContractDID)
}
// UnregisterBilling removes this contract from the centralized billing
// scheduler so it no longer receives periodic invoice checks.
func (c *ContractActor) UnregisterBilling(scheduler *ContractBillingScheduler) {
	scheduler.UnregisterContract(c.ContractDID)
}
// handleTerminatedContractInvoice handles final invoice generation for
// terminated contracts. It tries a regular invoice for the span from the
// last invoice to the termination time; if the billing period has not fully
// elapsed, it falls back to a pro-rated invoice for Periodic and FixedRental
// models. Always returns true so the billing routine stops afterwards.
func (c *ContractActor) handleTerminatedContractInvoice(contract *contracts.Contract, processor contracts.PaymentModelProcessor) bool {
	lastInvoiceAt, err := c.usageStore.GetLastProcessedAt(c.ContractDID.URI)
	if err != nil {
		log.Errorw("failed to get last processed timestamp for terminated contract final invoice",
			"labels", string(observability.LabelContract),
			"contract_did", c.ContractDID.URI,
			"error", err)
		return true // Stop billing routine
	}
	if lastInvoiceAt.IsZero() {
		// No previous invoice, nothing to invoice
		return true // Stop billing routine
	}
	// Get termination timestamp from contract transitions
	// This ensures we pro-rate based on actual termination time, not when billing runs
	var terminationTime time.Time
	if contract.Transitions != nil {
		// Walk transitions newest-first and take the latest Terminated entry.
		for i := len(contract.Transitions) - 1; i >= 0; i-- {
			if contract.Transitions[i].ToState == contracts.ContractTerminated {
				terminationTime = contract.Transitions[i].Timestamp
				break
			}
		}
	}
	// Use termination time if available, otherwise use current time as fallback
	endTime := terminationTime
	if terminationTime.IsZero() {
		endTime = time.Now()
	}
	elapsed := endTime.Sub(lastInvoiceAt)
	if elapsed <= 0 {
		// No elapsed time, nothing to invoice
		return true // Stop billing routine
	}
	// Try to generate regular invoice first
	usageData, err := processor.CheckAndGenerateInvoice(contract, lastInvoiceAt, endTime)
	if err != nil {
		// If period not elapsed, try pro-rated invoice for periodic and fixed rental contracts
		if errors.Is(err, processors.ErrPeriodNotElapsed) {
			var proRatedUsageData *contracts.UsageData
			var proRateErr error
			var paymentModel string
			switch contract.PaymentDetails.PaymentModel {
			case contracts.Periodic:
				// Pro-rate based on actual deployment time within the partial period
				if periodicProcessor, ok := processor.(*processors.PeriodicProcessor); ok {
					proRatedUsageData, proRateErr = periodicProcessor.GenerateProRatedInvoice(contract, lastInvoiceAt, endTime)
					paymentModel = "periodic"
				}
			case contracts.FixedRental:
				// Pro-rate based on elapsed time ratio to billing cycle
				if fixedRentalProcessor, ok := processor.(*processors.FixedRentalProcessor); ok {
					proRatedUsageData, proRateErr = fixedRentalProcessor.GenerateProRatedInvoice(contract, lastInvoiceAt, endTime)
					paymentModel = "fixed_rental"
				}
			}
			if proRateErr != nil {
				// Pro-rating also failed
				log.Warnw("could not generate pro-rated invoice for terminated contract",
					"labels", string(observability.LabelContract),
					"contract_did", c.ContractDID.URI,
					"payment_model", paymentModel,
					"error", proRateErr)
				return true // Stop billing routine
			}
			if proRatedUsageData != nil {
				// Pro-rated invoice generated - send it
				// NOTE(review): the type assertions below are unchecked and
				// would panic on an unexpected Data type — presumably the
				// processors guarantee the concrete type; confirm.
				var amount string
				switch contract.PaymentDetails.PaymentModel {
				case contracts.Periodic:
					amount = proRatedUsageData.Data.(*contracts.PeriodicUsage).Amount
				case contracts.FixedRental:
					amount = proRatedUsageData.Data.(*contracts.FixedRentalUsage).Amount
				}
				log.Infow("generated pro-rated invoice for terminated contract",
					"labels", string(observability.LabelContract),
					"contract_did", c.ContractDID.URI,
					"payment_model", paymentModel,
					"amount", amount,
					"termination_time", terminationTime,
					"elapsed_since_last_invoice", elapsed)
				c.sendInvoice(contract, proRatedUsageData, endTime)
				return true // Stop billing routine after final invoice
			}
		}
		// For other errors or payment models, just log and stop
		log.Warnw("could not generate final invoice for terminated contract",
			"labels", string(observability.LabelContract),
			"contract_did", c.ContractDID.URI,
			"error", err)
		return true // Stop billing routine
	}
	if usageData != nil {
		// Final invoice generated - send it
		c.sendInvoice(contract, usageData, endTime)
	}
	return true // Stop billing routine after final invoice
}
// sendInvoice dispatches the invoice in the format appropriate for the
// contract's payment model; unsupported models are logged as errors.
func (c *ContractActor) sendInvoice(contract *contracts.Contract, usageData *contracts.UsageData, now time.Time) {
	model := contract.PaymentDetails.PaymentModel
	switch model {
	case contracts.FixedRental:
		c.sendFixedRentalInvoiceFromUsageData(contract, usageData, now)
	case contracts.Periodic:
		c.sendPeriodicInvoiceFromUsageData(contract, usageData, now)
	case contracts.PayPerAllocation, contracts.PayPerDeployment, contracts.PayPerTimeUtilization, contracts.PayPerResourceUtilization:
		// These models share the generic ContractUsageRequest format and are
		// forwarded via forwardInvoice.
		c.sendGenericInvoiceFromUsageData(contract, usageData, now)
	default:
		log.Errorw("unsupported payment model for automatic billing",
			"labels", string(observability.LabelContract),
			"contract_did", c.ContractDID.URI,
			"payment_model", model)
	}
}
// sendGenericInvoiceFromUsageData sends invoices for the PayPerAllocation,
// PayPerDeployment, PayPerTimeUtilization and PayPerResourceUtilization
// models by converting UsageData into a ContractUsageRequest and forwarding
// it via forwardInvoice. On success the last processed timestamp is updated.
func (c *ContractActor) sendGenericInvoiceFromUsageData(contract *contracts.Contract, usageData *contracts.UsageData, now time.Time) {
	req := contracts.ContractUsageRequest{
		UniqueID: uuid.NewString(),
		Contract: *contract,
	}
	// Convert UsageData to ContractUsageRequest format
	switch usageData.PaymentModel {
	case contracts.PayPerAllocation, contracts.PayPerDeployment:
		// Both count-based models carry a plain usage count; handled in one
		// case so the two identical bodies cannot drift apart.
		usageCount, ok := usageData.Data.(int)
		if !ok {
			log.Errorw("invalid usage data type for count-based payment model",
				"labels", string(observability.LabelContract),
				"contract_did", c.ContractDID.URI,
				"payment_model", usageData.PaymentModel)
			return
		}
		req.Usages = usageCount
	case contracts.PayPerTimeUtilization:
		timeUtil, ok := usageData.Data.(*contracts.TimeUtilizationUsage)
		if !ok {
			log.Errorw("invalid usage data type for pay_per_time_utilization",
				"labels", string(observability.LabelContract),
				"contract_did", c.ContractDID.URI)
			return
		}
		req.TimeUtilization = timeUtil
		req.Usages = len(timeUtil.Deployments) // For backward compatibility
	case contracts.PayPerResourceUtilization:
		resourceUtil, ok := usageData.Data.(*contracts.ResourceUtilizationUsage)
		if !ok {
			log.Errorw("invalid usage data type for pay_per_resource_utilization",
				"labels", string(observability.LabelContract),
				"contract_did", c.ContractDID.URI)
			return
		}
		req.ResourceUtilization = resourceUtil
		req.Usages = len(resourceUtil.Deployments) // For backward compatibility
	default:
		log.Errorw("unsupported payment model for generic invoice",
			"labels", string(observability.LabelContract),
			"contract_did", c.ContractDID.URI,
			"payment_model", usageData.PaymentModel)
		return
	}
	// Forward invoice using solution enabler's actor (contract host node actor)
	if c.forwardInvoice == nil {
		log.Errorw("forwardInvoice function not set for contract actor",
			"labels", string(observability.LabelContract),
			"contract_did", c.ContractDID.URI)
		return
	}
	if err := c.forwardInvoice(req); err != nil {
		log.Errorw("failed to forward invoice to payment validator",
			"labels", string(observability.LabelContract),
			"contract_did", c.ContractDID.URI,
			"payment_model", usageData.PaymentModel,
			"error", err)
		return
	}
	// Update last processed timestamp after successful send
	if err := c.usageStore.SaveLastProcessedAt(c.ContractDID.URI, now); err != nil {
		log.Errorw("failed to save last processed timestamp after invoice",
			"labels", string(observability.LabelContract),
			"contract_did", c.ContractDID.URI,
			"error", err)
		return
	}
	log.Infow("invoice generated and sent successfully",
		"labels", string(observability.LabelContract),
		"contract_did", c.ContractDID.URI,
		"payment_model", usageData.PaymentModel,
		"usages", req.Usages)
}
// sendFixedRentalInvoiceFromUsageData unwraps the FixedRentalUsage payload
// from UsageData and delegates to sendFixedRentalInvoice.
func (c *ContractActor) sendFixedRentalInvoiceFromUsageData(contract *contracts.Contract, usageData *contracts.UsageData, now time.Time) {
	details, ok := usageData.Data.(*contracts.FixedRentalUsage)
	if !ok {
		log.Errorw("invalid usage data type for fixed rental",
			"labels", string(observability.LabelContract),
			"contract_did", c.ContractDID.URI)
		return
	}
	c.sendFixedRentalInvoice(contract, details, now)
}
// sendPeriodicInvoiceFromUsageData unwraps the PeriodicUsage payload from
// UsageData and delegates to sendPeriodicInvoice.
func (c *ContractActor) sendPeriodicInvoiceFromUsageData(contract *contracts.Contract, usageData *contracts.UsageData, now time.Time) {
	details, ok := usageData.Data.(*contracts.PeriodicUsage)
	if !ok {
		log.Errorw("invalid usage data type for periodic",
			"labels", string(observability.LabelContract),
			"contract_did", c.ContractDID.URI)
		return
	}
	c.sendPeriodicInvoice(contract, details, now)
}
// sendPeriodicInvoice sends periodic invoice(s) - one per deployment (Edge
// Case 5). Each deployment's amount is fee-per-time-unit times its runtime
// converted to the contract's time unit. After the loop, the last processed
// timestamp is advanced to `now`.
func (c *ContractActor) sendPeriodicInvoice(
	contract *contracts.Contract,
	periodicUsage *contracts.PeriodicUsage,
	now time.Time,
) {
	pd := contract.PaymentDetails
	// Parse fee per time unit
	feePerUnit, err := strconv.ParseFloat(pd.FeePerTimeUnit, 64)
	if err != nil {
		log.Errorw("failed to parse fee_per_time_unit for periodic invoice",
			"labels", string(observability.LabelContract),
			"contract_did", c.ContractDID.URI,
			"error", err)
		return
	}
	// Edge Case 5: Generate one invoice per deployment for the period
	for _, deployment := range periodicUsage.Deployments {
		// Convert the deployment's runtime (seconds) into the contract's
		// billing time unit before applying the fee.
		var timeInUnit float64
		deploymentTimeSec := deployment.TotalUtilizationSec
		switch pd.TimeUnit {
		case contracts.TimeUnitSecond:
			timeInUnit = deploymentTimeSec
		case contracts.TimeUnitMinute:
			timeInUnit = deploymentTimeSec / 60.0
		case contracts.TimeUnitHour:
			timeInUnit = deploymentTimeSec / 3600.0
		default:
			log.Errorw("unsupported time_unit for periodic invoice",
				"labels", string(observability.LabelContract),
				"contract_did", c.ContractDID.URI,
				"time_unit", pd.TimeUnit)
			continue // Skip this deployment
		}
		deploymentAmount := feePerUnit * timeInUnit
		// Edge Case 4: Use deployment stop time as period end
		// The deployment stop time is already accounted for in the runtime calculation
		// via CalculateDeploymentTimeUtilizationByContract, so we use periodicUsage.PeriodEnd
		// which is set correctly based on whether deployment stopped during the period
		periodEnd := periodicUsage.PeriodEnd
		// Generate unique ID for this deployment's invoice
		// Note: UUID will be generated in PeriodicProcessor.CalculatePayment,
		// this is just for the request identifier
		uniqueID := uuid.NewString()
		// Create PeriodicUsage for this single deployment
		deploymentPeriodicUsage := &contracts.PeriodicUsage{
			PeriodStart:     periodicUsage.PeriodStart,
			PeriodEnd:       periodEnd, // Use deployment stop time if applicable
			LastInvoiceAt:   periodicUsage.LastInvoiceAt,
			Deployments:     []contracts.DeploymentTimeUtilization{deployment}, // Single deployment
			TotalTimeSec:    deploymentTimeSec,
			Amount:          fmt.Sprintf("%.8f", deploymentAmount),
			PeriodsInvoiced: periodicUsage.PeriodsInvoiced,
		}
		// Create ContractUsageRequest for this deployment
		req := contracts.ContractUsageRequest{
			UniqueID:        uniqueID,
			Contract:        *contract,
			PeriodicDetails: deploymentPeriodicUsage,
		}
		log.Infow("generating periodic invoice for deployment",
			"labels", string(observability.LabelContract),
			"contract_did", c.ContractDID.URI,
			"deployment_id", deployment.DeploymentID,
			"period_start", periodicUsage.PeriodStart,
			"period_end", periodEnd,
			"amount", deploymentPeriodicUsage.Amount,
			"runtime_sec", deploymentTimeSec)
		// Forward invoice using solution enabler's actor (contract host node actor)
		// This ensures the message is sent from the solution enabler's actor handle, which has the required capabilities
		if c.forwardInvoice == nil {
			log.Errorw("forwardInvoice function not set for contract actor",
				"labels", string(observability.LabelContract),
				"contract_did", c.ContractDID.URI)
			continue // Continue with other deployments
		}
		// Send invoice using forwardInvoice function
		if err := c.forwardInvoice(req); err != nil {
			log.Errorw("failed to send periodic invoice for deployment",
				"labels", string(observability.LabelContract),
				"contract_did", c.ContractDID.URI,
				"deployment_id", deployment.DeploymentID,
				"error", err)
			// Continue with other deployments even if one fails
			continue
		}
	}
	// Update last invoice timestamp after all deployment invoices are sent.
	// NOTE(review): this runs even when some (or all) forwards above failed,
	// so a failed deployment invoice is not retried next period — confirm
	// this is intentional.
	if err := c.usageStore.SaveLastProcessedAt(c.ContractDID.URI, now); err != nil {
		log.Errorw("failed to update last processed timestamp after sending periodic invoices",
			"labels", string(observability.LabelContract),
			"contract_did", c.ContractDID.URI,
			"error", err)
		// Don't return - invoices were sent successfully
	} else {
		log.Infow("periodic invoices generated and sent successfully",
			"labels", string(observability.LabelContract),
			"contract_did", c.ContractDID.URI,
			"deployments", len(periodicUsage.Deployments))
	}
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package contracts
import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"time"
"gitlab.com/nunet/device-management-service/actor"
"gitlab.com/nunet/device-management-service/dms/jobs"
"gitlab.com/nunet/device-management-service/lib/did"
"gitlab.com/nunet/device-management-service/tokenomics/store/transaction"
"gitlab.com/nunet/device-management-service/types"
)
// CreateContractRequest is the payload used to request creation of a new
// contract between a provider and a requestor. See Validate for the
// required-field rules and NewContract for how it becomes a Contract.
type CreateContractRequest struct {
Metadata map[string]string `json:"metadata"` // Free-form key/value metadata; copied onto the contract by NewContract
SolutionEnablerDID did.DID `json:"solution_enabler_did"` // Required (checked by Validate)
PaymentValidatorDID did.DID `json:"payment_validator_did"` // Required (checked by Validate)
ResourceConfiguration types.Resources `json:"resource_configuration"`
TerminationOption *TerminationOption `json:"termination_option"`
Penalties []PenaltyClause `json:"penalties"`
PaymentDetails PaymentDetails `json:"payment_details"` // Payment model and periodicity; model is required
ContractTerms interface{} `json:"contract_terms"` // Opaque agreement terms, stored as-is
ContractParticipants ContractParticipants `json:"contract_participants"` // Provider and requestor DIDs; both required
Duration DurationDetails `json:"duration"` // Start/end dates; end must be strictly after start
DisableBilling bool `json:"disable_billing,omitempty"` // If true, disables all billing (automatic and manual)
}
// ContractPaymentStatusRequest asks for the payment status of a single
// transaction, identified by its unique ID.
type ContractPaymentStatusRequest struct {
UniqueID string `json:"unique_id"`
}

// ContractPaymentStatusResponse reports whether the identified transaction
// has been paid. Error is empty on success.
type ContractPaymentStatusResponse struct {
UniqueID string `json:"unique_id"`
Paid bool `json:"paid"`
Error string `json:"error"`
}

// CollectUsagesAndForwardToPaymentProvidersRequest triggers usage collection
// and forwarding to payment providers for one contract, or for every
// contract when ContractDID is empty.
type CollectUsagesAndForwardToPaymentProvidersRequest struct {
ContractDID string `json:"contract_did,omitempty"` // If empty, processes all contracts
}
// AllocationTimeUtilization represents time utilization for a single allocation
type AllocationTimeUtilization struct {
AllocationID string `json:"allocation_id"`
Duration time.Duration `json:"duration"` // Total time the allocation ran
StartTime time.Time `json:"start_time"`
EndTime time.Time `json:"end_time,omitempty"` // Empty if still running
}

// DeploymentTimeUtilization represents time utilization for a deployment
// (the aggregate of its allocations).
type DeploymentTimeUtilization struct {
DeploymentID string `json:"deployment_id"`
Allocations []AllocationTimeUtilization `json:"allocations"`
TotalUtilizationSec float64 `json:"total_utilization_sec"` // Total seconds across all allocations
}

// TimeUtilizationUsage represents usage data for pay_per_time_utilization model
type TimeUtilizationUsage struct {
Deployments []DeploymentTimeUtilization `json:"deployments"`
}
// AllocationResourceUtilization represents resource utilization for a single allocation
type AllocationResourceUtilization struct {
AllocationID string `json:"allocation_id"`
Resources types.Resources `json:"resources"` // CPU cores, RAM GB, Disk GB, GPU count
Duration time.Duration `json:"duration"` // How long allocation ran
StartTime time.Time `json:"start_time"`
EndTime time.Time `json:"end_time,omitempty"` // Empty if still running
// Calculated per-resource costs, as decimal strings (for invoice details).
CPUCost string `json:"cpu_cost,omitempty"`
RAMCost string `json:"ram_cost,omitempty"`
DiskCost string `json:"disk_cost,omitempty"`
GPUCost string `json:"gpu_cost,omitempty"`
TotalCost string `json:"total_cost,omitempty"`
}

// DeploymentResourceUtilization tracks all allocations in a deployment
// together with the deployment-level totals.
type DeploymentResourceUtilization struct {
DeploymentID string `json:"deployment_id"`
Allocations []AllocationResourceUtilization `json:"allocations"`
TotalUtilizationSec float64 `json:"total_utilization_sec"`
TotalCost string `json:"total_cost,omitempty"`
}

// ResourceUtilizationUsage represents resource utilization data for the
// pay_per_resource_utilization model.
type ResourceUtilizationUsage struct {
Deployments []DeploymentResourceUtilization `json:"deployments"`
}
// FixedRentalUsage represents usage data for fixed_rental payment model
type FixedRentalUsage struct {
PeriodsInvoiced int `json:"periods_invoiced"` // Number of full periods invoiced
PeriodStart time.Time `json:"period_start"` // Start of the first period in this invoice
PeriodEnd time.Time `json:"period_end"` // End of the last period in this invoice
Amount string `json:"amount"` // Total amount for this invoice (decimal string)
LastInvoiceAt time.Time `json:"last_invoice_at"` // Timestamp of last invoice (before this one)
}

// PeriodicUsage represents usage data for periodic payment model.
// Unlike FixedRentalUsage it also carries per-deployment runtime data so
// invoices can be split per deployment.
type PeriodicUsage struct {
PeriodStart time.Time `json:"period_start"` // Start of billing period
PeriodEnd time.Time `json:"period_end"` // End of billing period
LastInvoiceAt time.Time `json:"last_invoice_at"` // Timestamp of last invoice
Deployments []DeploymentTimeUtilization `json:"deployments"` // Deployment runtime data
TotalTimeSec float64 `json:"total_time_sec"` // Total deployment time in seconds
Amount string `json:"amount"` // Calculated amount (decimal string)
PeriodsInvoiced int `json:"periods_invoiced"` // Number of periods covered
}
// ContractUsageResult is the per-contract outcome of a usage-collection run.
// Exactly one of the model-specific detail fields is expected to be set,
// matching PaymentModel.
type ContractUsageResult struct {
ContractDID string `json:"contract_did"`
PaymentModel PaymentModel `json:"payment_model"`
Usages int `json:"usages"` // For backward compatibility
Error string `json:"error,omitempty"`
TimeUtilization *TimeUtilizationUsage `json:"time_utilization,omitempty"` // For pay_per_time_utilization
ResourceUtilization *ResourceUtilizationUsage `json:"resource_utilization,omitempty"` // For pay_per_resource_utilization
FixedRentalDetails *FixedRentalUsage `json:"fixed_rental_details,omitempty"` // For fixed_rental
PeriodicDetails *PeriodicUsage `json:"periodic_details,omitempty"` // For periodic
}

// CollectUsagesAndForwardToPaymentProvidersReponse aggregates the results of
// a usage-collection run across contracts.
// NOTE(review): the name misspells "Response"; it is an exported identifier,
// so it is kept as-is for backward compatibility with existing callers.
type CollectUsagesAndForwardToPaymentProvidersReponse struct {
Error string `json:"error"`
TotalUsages int `json:"total_usages"`
Results []ContractUsageResult `json:"results,omitempty"` // Per-contract results
}
// ContractListLocalTransactionsRequest lists locally stored transactions;
// it carries no parameters.
type ContractListLocalTransactionsRequest struct{}

// ContractListLocalTransactionsResponse returns the locally stored
// transactions, or an error.
type ContractListLocalTransactionsResponse struct {
Error string `json:"error"`
Transactions []*transaction.Transaction `json:"transactions"`
}

// ContractConfirmLocalTransactionRequest confirms a local transaction with
// its on-chain hash.
type ContractConfirmLocalTransactionRequest struct {
UniqueID string `json:"unique_id"`
TxHash string `json:"tx_hash"`
Blockchain string `json:"blockchain"`
QuoteID string `json:"quote_id,omitempty"` // Optional: quote ID for price conversion
}

// ContractConfirmLocalTransactionResponse reports the confirmation outcome.
type ContractConfirmLocalTransactionResponse struct {
Error string `json:"error"`
}
// TransactionForServiceProviderRequest records a transaction destined for a
// service provider. Conversion-related fields only describe the required
// conversion; no conversion is performed by this request.
type TransactionForServiceProviderRequest struct {
UniqueID string `json:"unique_id"`
PaymentValidatorDID string `json:"payment_validator_did"`
ContractDID string `json:"contract_did"`
ToAddress []types.PaymentAddressInfo `json:"to_address"`
Metadata map[string]interface{} `json:"metadata,omitempty"`
Amount string `json:"amount"` // Original amount (in pricing currency if conversion needed)
Status string `json:"status,omitempty"` // optional status, defaults to "unpaid" if empty
TxHash string `json:"tx_hash,omitempty"` // optional transaction hash
// New fields for price conversion tracking (NO conversion happens here)
OriginalAmount string `json:"original_amount,omitempty"` // Amount in pricing currency (USDT)
PricingCurrency string `json:"pricing_currency,omitempty"` // Currency of original amount (e.g., "USDT")
RequiresConversion bool `json:"requires_conversion,omitempty"` // True if conversion is needed
}

// TransactionForServiceProviderResponse reports the recording outcome.
type TransactionForServiceProviderResponse struct {
Error string `json:"error"`
}
// ContractUsageRequest forwards usage data for invoicing. One of the
// model-specific detail fields is expected to be set, matching the
// contract's payment model.
type ContractUsageRequest struct {
UniqueID string `json:"unique_id"`
Contract Contract `json:"contract"`
Usages int `json:"usages"` // For backward compatibility
TimeUtilization *TimeUtilizationUsage `json:"time_utilization,omitempty"` // For pay_per_time_utilization
ResourceUtilization *ResourceUtilizationUsage `json:"resource_utilization,omitempty"` // For pay_per_resource_utilization
FixedRentalDetails *FixedRentalUsage `json:"fixed_rental_details,omitempty"` // For fixed_rental
PeriodicDetails *PeriodicUsage `json:"periodic_details,omitempty"` // For periodic
}

// ContractUsageResponse reports the outcome of a usage submission.
type ContractUsageResponse struct {
Error string `json:"error"`
}

// ContractEventRequest carries an opaque serialized contract event.
type ContractEventRequest struct {
Payload []byte `json:"payload"`
}

// ContractEventResponse reports the outcome of delivering an event.
type ContractEventResponse struct {
Error string `json:"error"`
}
// ContractPaymentValidationRequest asks for validation of an on-chain
// payment for a transaction.
type ContractPaymentValidationRequest struct {
TxHash string `json:"tx_hash"`
UniqueID string `json:"unique_id"`
Blockchain string `json:"blockchain"`
QuoteID string `json:"quote_id,omitempty"` // Optional: quote ID for price conversion
}

// ContractPaymentValidationResponse reports the validation outcome.
type ContractPaymentValidationResponse struct {
Error string `json:"error"`
}
// ContractGetPaymentQuoteRequest requests a payment quote for a transaction
type ContractGetPaymentQuoteRequest struct {
UniqueID string `json:"unique_id"` // Transaction unique_id
}

// ContractGetPaymentQuoteResponse returns a payment quote. Amounts are
// decimal strings; Error is set when no quote could be produced.
type ContractGetPaymentQuoteResponse struct {
QuoteID string `json:"quote_id,omitempty"` // Unique quote identifier
OriginalAmount string `json:"original_amount,omitempty"` // Amount in pricing currency (USDT)
ConvertedAmount string `json:"converted_amount,omitempty"` // Amount in payment currency (NTX)
PricingCurrency string `json:"pricing_currency,omitempty"` // Original currency (e.g., "USDT")
PaymentCurrency string `json:"payment_currency,omitempty"` // Payment currency (e.g., "NTX")
ExchangeRate string `json:"exchange_rate,omitempty"` // Exchange rate used
ExpiresAt time.Time `json:"expires_at,omitempty"` // Quote expiration timestamp
Error string `json:"error,omitempty"`
}

// ContractValidatePaymentQuoteRequest validates a payment quote before payment
type ContractValidatePaymentQuoteRequest struct {
QuoteID string `json:"quote_id"` // Quote to validate
}

// ContractValidatePaymentQuoteResponse returns validation result, echoing
// the quote's details when it is still valid.
type ContractValidatePaymentQuoteResponse struct {
Valid bool `json:"valid"` // Whether quote is valid
QuoteID string `json:"quote_id,omitempty"` // Quote identifier
OriginalAmount string `json:"original_amount,omitempty"` // Amount in pricing currency
ConvertedAmount string `json:"converted_amount,omitempty"` // Amount in payment currency
PricingCurrency string `json:"pricing_currency,omitempty"` // Original currency
PaymentCurrency string `json:"payment_currency,omitempty"` // Payment currency
ExchangeRate string `json:"exchange_rate,omitempty"` // Exchange rate used
ExpiresAt time.Time `json:"expires_at,omitempty"` // Quote expiration timestamp
Error string `json:"error,omitempty"` // Error if invalid
}

// ContractCancelPaymentQuoteRequest cancels/invalidates a payment quote
type ContractCancelPaymentQuoteRequest struct {
QuoteID string `json:"quote_id"` // Quote to cancel
}

// ContractCancelPaymentQuoteResponse returns cancellation result
type ContractCancelPaymentQuoteResponse struct {
Error string `json:"error,omitempty"`
}
// PaymentValidateRequest asks for payment validation of a contract.
type PaymentValidateRequest struct {
ContractDID string `json:"contract_did"`
}

// PaymentValidateResponse reports the validation outcome.
type PaymentValidateResponse struct {
Error string `json:"error"`
}

// ContractListIncomingRole filters incoming contracts by the local node's
// role in them.
type ContractListIncomingRole string

// Roles usable in ContractListIncomingRequest.
const (
ContractRoleProvider ContractListIncomingRole = "provider"
ContractRoleRequestor ContractListIncomingRole = "requestor"
)

// ContractListIncomingRequest lists incoming contracts, optionally filtered
// by role; an empty Role presumably means no filter — confirm with handler.
type ContractListIncomingRequest struct {
Role ContractListIncomingRole `json:"role,omitempty"`
}

// ContractListIncomingResponse returns the matching contracts.
type ContractListIncomingResponse struct {
Contracts []*Contract `json:"contracts"`
Error string `json:"error,omitempty"`
}

// ContractApproveLocalRequest approves a locally stored contract.
type ContractApproveLocalRequest struct {
ContractDID string `json:"contract_did"`
}

// ContractApproveLocalResponse reports the approval outcome.
type ContractApproveLocalResponse struct {
Success bool `json:"success"`
Error string `json:"error,omitempty"`
}
// ContractVerificationResponse represents the response structure for contract verification
type ContractVerificationResponse struct {
Valid bool `json:"valid"`
Error string `json:"error,omitempty"`
}

// ContractStatusRequest asks for the current state of a contract.
type ContractStatusRequest struct {
ContractDID string `json:"contract_did"`
}

// ContractStatusResponse returns the contract, or an error.
type ContractStatusResponse struct {
Error string `json:"error"`
Contract Contract `json:"contract"`
}

// CreateContractResponse returns the newly created contract's identifiers
// together with the originating request.
type CreateContractResponse struct {
ContractRequest CreateContractRequest `json:"contract_request"`
ContractDID string `json:"contract_did"`
PubKey string `json:"pub_key"`
Error string `json:"error"`
}
// ProposeContractRequest submits a contract proposal to a counterparty.
type ProposeContractRequest struct {
Contract Contract `json:"contract"`
CreatorOfContract actor.Handle `json:"creator_of_contract"`
}

// ProposeContractResponse returns the counterparty's signature over the
// proposed contract, or an error.
type ProposeContractResponse struct {
Signature Signature `json:"signature"`
Error string `json:"error"`
}

// ContractTerminationRequest asks to terminate the identified contract.
type ContractTerminationRequest struct {
ContractDID string `json:"contract_did"`
}

// ContractTerminationResponse reports the termination outcome.
type ContractTerminationResponse struct {
Error string `json:"error"`
}

// ContractCompletionRequest asks to mark the identified contract completed.
type ContractCompletionRequest struct {
ContractDID string `json:"contract_did"`
}

// ContractCompletionResponse reports the completion outcome.
type ContractCompletionResponse struct {
Error string `json:"error"`
}

// ContractSettleRequest asks to settle the identified contract.
type ContractSettleRequest struct {
ContractDID string `json:"contract_did"`
}

// ContractSettleResponse reports the settlement outcome.
type ContractSettleResponse struct {
Error string `json:"error"`
}

// ContractValidateRequest asks for validation of the identified contract.
type ContractValidateRequest struct {
ContractDID string `json:"contract_did"`
}

// ContractValidateResponse reports validity plus the contract's current state.
type ContractValidateResponse struct {
Valid bool `json:"valid"`
CurrentStatus string `json:"current_status"`
Error string `json:"error"`
}
// ContractChainVerificationRequest requests verification of a contract chain
type ContractChainVerificationRequest struct {
SolutionEnablerDID string `json:"solution_enabler_did"` // Contract host DID
ContractDID string `json:"contract_did"` // Contract DID (Orch ↔ Org)
OrganizationDID string `json:"organization_did"` // Organization DID (from Contract A)
OrchestratorDID string `json:"orchestrator_did"` // Orchestrator DID
ProviderDID string `json:"provider_did"` // Provider DID
}

// ContractChainVerificationResponse contains the chain verification result
type ContractChainVerificationResponse struct {
Valid bool `json:"valid"`
OrganizationDID string `json:"organization_did,omitempty"`
OrchestratorContract *Contract `json:"orchestrator_contract,omitempty"` // Contract A
ProviderContract *Contract `json:"provider_contract,omitempty"` // Contract B
Error string `json:"error,omitempty"`
}

// ContractSignRequest attaches a signature to the identified contract.
type ContractSignRequest struct {
ContractDID string `json:"contract_did"`
Signature []byte `json:"signature"`
}

// ContractSignResponse returns the contract after signing, or an error.
type ContractSignResponse struct {
Error string `json:"error"`
Contract Contract `json:"contract"`
}
// ContractState represents the possible states of a contract
type ContractState string

// Contract lifecycle states.
const (
ContractDraft ContractState = "DRAFT"
ContractProposed ContractState = "PROPOSED"
ContractAccepted ContractState = "ACCEPTED" // TODO add accepted status
ContractActive ContractState = "ACTIVE" // Not expired and not canceled
ContractPaused ContractState = "PAUSED" // in case of dispute
ContractUpdate ContractState = "UPDATED"
ContractTerminated ContractState = "TERMINATED"
ContractCompleted ContractState = "COMPLETED"
ContractSettled ContractState = "SETTLED"
)

// ContractEvent represents events that can trigger state transitions
type ContractEvent string

// Events accepted by the contract state machine.
const (
EventPropose ContractEvent = "PROPOSE"
EventAccepted ContractEvent = "ACCEPTED"
EventActivate ContractEvent = "ACTIVATE"
EventDispute ContractEvent = "DISPUTE"
EventUpdate ContractEvent = "UPDATED"
EventTerminate ContractEvent = "TERMINATE" // TODO: termination
EventComplete ContractEvent = "COMPLETE"
EventSettle ContractEvent = "SETTLE"
)
// StateTransitionError represents an invalid state transition
type StateTransitionError struct {
Current ContractState // State the contract was in when the event arrived
Event ContractEvent // Event that could not be applied
Message string // Human-readable description returned by Error
}

// Error implements the error interface by returning Message verbatim;
// Current and Event are carried for callers that inspect the error value.
func (e *StateTransitionError) Error() string {
return e.Message
}

// ContractStateTransition represents a valid state transition
// (FromState --Event--> ToState).
type ContractStateTransition struct {
FromState ContractState
Event ContractEvent
ToState ContractState
}
// Contract chain role constants for metadata. A contract's role in a chain
// is stored under ContractChainRoleMetadataKey in Contract.Metadata.
const (
ContractChainRoleMetadataKey = "contract_chain_role"
ContractChainRoleHead = "head"
ContractChainRoleTail = "tail"
)

// StateTransition represents a historical state change recorded on a
// contract (see Contract.Transitions).
type StateTransition struct {
FromState ContractState `json:"from_state"`
ToState ContractState `json:"to_state"`
Event ContractEvent `json:"event"`
Timestamp time.Time `json:"timestamp"`
InitiatedBy did.DID `json:"initiated_by"`
}
// Contract represents the contract details between nodes. It is created in
// the DRAFT state by NewContract and moves through ContractState values via
// recorded Transitions.
type Contract struct {
ContractDID string `json:"contract_did"`
SolutionEnablerDID did.DID `json:"solution_enabler_did"`
PaymentValidatorDID did.DID `json:"payment_validator_did"`
ResourceConfiguration types.Resources `json:"resource_configuration"`
TerminationOption *TerminationOption `json:"termination_option,omitempty"`
Penalties []PenaltyClause `json:"penalties"`
Duration *DurationDetails `json:"duration,omitempty"`
ContractParticipants ContractParticipants `json:"participants"`
PaymentDetails PaymentDetails `json:"payment_details"` // Zero value: zero value of payments.Payment struct
Paid bool `json:"paid"`
Signatures []Signature `json:"signatures"` // One entry per signing party
Settled bool `json:"settled"` // Example default: false
Verification jobs.Status `json:"verification"` // Zero value: zero value of jobs.Status
ContractProof []byte `json:"contract_proof"` // Proof bytes; semantics not visible here — confirm with producer
CurrentState ContractState `json:"current_state"` // state tracking
ContractTerms interface{} `json:"contract_terms"` // To store contract agreement terms
TerminationStarted time.Time `json:"termination_started"` // Zero until termination begins
Transitions []StateTransition `json:"transitions"` // Historical state changes
DisableBilling bool `json:"disable_billing,omitempty"` // If true, disables all billing (automatic and manual)
Metadata map[string]string `json:"metadata,omitempty"` // Contract metadata (e.g., contract_chain_role)
}
// Sign serializes the contract to JSON and signs the resulting bytes with
// the supplied key provider, returning the raw signature bytes.
//
// Note: the contract is marshaled as-is (including any existing Signatures),
// so both parties must agree on the exact contract state being signed.
func (c *Contract) Sign(key did.Provider) ([]byte, error) {
	cBytes, err := json.Marshal(c)
	if err != nil {
		return nil, fmt.Errorf("failed to get contract data: %w", err)
	}
	sig, err := key.Sign(cBytes)
	if err != nil {
		// Wrap the underlying error instead of discarding it so callers can
		// see why signing failed.
		return nil, fmt.Errorf("unable to sign the contract: %w", err)
	}
	return sig, nil
}
// ContractParticipants identifies the two parties to a contract by DID.
type ContractParticipants struct {
Provider did.DID `json:"provider"`
Requestor did.DID `json:"requestor"`
}

// TerminationOption specifies termination rules for long-running jobs.
type TerminationOption struct {
Allowed bool `json:"allowed"`
NoticePeriod time.Duration `json:"notice_period"` // e.g., "30 days"
}
// DurationDetails defines the duration for hire.
type DurationDetails struct {
StartDate time.Time `json:"start_date"`
EndDate time.Time `json:"end_date"`
}

// PenaltyClause couples a human-readable condition with a numeric penalty.
type PenaltyClause struct {
Condition string `json:"condition"` // e.g., "Uptime < 99.9%"
Penalty float64 `json:"penalty"` // Penalty value; units/currency not specified here — confirm with callers
}

// Signature represents a digital signature on the contract
type Signature struct {
DID did.DID `json:"did"` // The DID of the signer
// NOTE(review): the field name is plural but it holds a single signature
// and serializes as "signature"; renaming the exported field would break
// callers, so it is kept as-is.
Signatures []byte `json:"signature"` // The actual signature bytes
}
// Validate checks that all required fields of the CreateContractRequest are
// present and consistent, returning a descriptive error for the first
// violation found, or nil when the request is valid.
//
// Checks performed:
//   - solution enabler and payment validator DIDs are set
//   - both contract participants (provider, requestor) are set
//   - a payment model is specified
//   - duration has a start and end date, with end strictly after start
//   - for fixed_rental and periodic models, payment_period is one of the
//     supported units and payment_period_count is positive
func (req *CreateContractRequest) Validate() error {
	if req.SolutionEnablerDID.Empty() {
		return fmt.Errorf("solution_enabler_did is required")
	}
	if req.PaymentValidatorDID.Empty() {
		return fmt.Errorf("payment_validator_did is required")
	}
	if req.ContractParticipants.Provider.Empty() {
		return fmt.Errorf("contract_participants.provider is required")
	}
	if req.ContractParticipants.Requestor.Empty() {
		return fmt.Errorf("contract_participants.requestor is required")
	}
	if req.PaymentDetails.PaymentModel == "" {
		return fmt.Errorf("payment_details.payment_model is required")
	}
	if req.Duration.StartDate.IsZero() {
		return fmt.Errorf("duration.start_date is required")
	}
	if req.Duration.EndDate.IsZero() {
		return fmt.Errorf("duration.end_date is required")
	}
	if !req.Duration.EndDate.After(req.Duration.StartDate) {
		return fmt.Errorf("duration.end_date must be after duration.start_date")
	}
	// Periodicity settings are only mandatory for the recurring models.
	if req.PaymentDetails.PaymentModel == FixedRental || req.PaymentDetails.PaymentModel == Periodic {
		if req.PaymentDetails.PaymentPeriod == "" {
			return fmt.Errorf("payment_details.payment_period is required for payment model %s", req.PaymentDetails.PaymentModel)
		}
		// A switch over the fixed set of period constants avoids allocating a
		// lookup map on every call.
		switch req.PaymentDetails.PaymentPeriod {
		case PaymentPeriodMinute, PaymentPeriodHour, PaymentPeriodDay, PaymentPeriodWeek, PaymentPeriodMonth:
			// valid period
		default:
			return fmt.Errorf("payment_details.payment_period must be one of: minute, hour, day, week, month, got: %s", req.PaymentDetails.PaymentPeriod)
		}
		if req.PaymentDetails.PaymentPeriodCount <= 0 {
			return fmt.Errorf("payment_details.payment_period_count must be a positive integer for payment model %s, got: %d", req.PaymentDetails.PaymentModel, req.PaymentDetails.PaymentPeriodCount)
		}
	}
	return nil
}
// GenerateContractID derives a deterministic identifier for a contract
// request: the request is JSON-encoded and the hex-encoded SHA-256 digest of
// those bytes is returned, so identical requests map to the same ID.
// An error is returned only when JSON encoding fails.
func GenerateContractID(req CreateContractRequest) (string, error) {
	encoded, err := json.Marshal(req)
	if err != nil {
		return "", err
	}
	digest := sha256.Sum256(encoded)
	return hex.EncodeToString(digest[:]), nil
}
// SetPeriodicityDefaults fills in missing periodicity values on pd in place:
// PaymentPeriodCount defaults to 1 and PaymentPeriod to "hour". This
// guarantees every contract has a billing cadence configured for automatic
// billing.
func SetPeriodicityDefaults(pd *PaymentDetails) {
	if pd.PaymentPeriodCount < 1 {
		pd.PaymentPeriodCount = 1
	}
	if pd.PaymentPeriod == "" {
		pd.PaymentPeriod = PaymentPeriodHour
	}
}
// NewContract builds a Contract in the DRAFT state from a creation request.
// Periodicity defaults are applied to the request's payment details first,
// and the request metadata is copied into a fresh map so the contract never
// aliases the caller's map.
func NewContract(contractDID string, req CreateContractRequest) *Contract {
	SetPeriodicityDefaults(&req.PaymentDetails)

	// Defensive copy; ranging over a nil map is a no-op, so the result is
	// always a non-nil (possibly empty) map.
	metadata := make(map[string]string, len(req.Metadata))
	for key, value := range req.Metadata {
		metadata[key] = value
	}

	return &Contract{
		ContractDID:           contractDID,
		SolutionEnablerDID:    req.SolutionEnablerDID,
		PaymentValidatorDID:   req.PaymentValidatorDID,
		ResourceConfiguration: req.ResourceConfiguration,
		TerminationOption:     req.TerminationOption,
		Penalties:             req.Penalties,
		Duration:              &req.Duration,
		ContractParticipants:  req.ContractParticipants,
		PaymentDetails:        req.PaymentDetails,
		ContractTerms:         req.ContractTerms,
		CurrentState:          ContractDraft,
		Transitions:           []StateTransition{},
		DisableBilling:        req.DisableBilling,
		Metadata:              metadata,
	}
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package contracts
import (
"fmt"
"strconv"
)
// OrchestrationFeeCalculator calculates orchestration fees for payment items.
// It is stateless (empty struct), so instances are safe to share.
type OrchestrationFeeCalculator struct{}
// NewOrchestrationFeeCalculator creates a new fee calculator. Because the
// calculator carries no state, a single instance may be reused freely.
func NewOrchestrationFeeCalculator() *OrchestrationFeeCalculator {
	return new(OrchestrationFeeCalculator)
}
// CalculateFee calculates the orchestration fee for a batch of payment items.
//
// The fee is FixedAmount + (sum of item amounts × Percentage / 100), with
// both components read from config as decimal strings. It returns the fee
// formatted to 8 decimal places, or an empty string when no fee applies
// (nil config, or both components zero). An error is returned when any
// configured value or item amount fails to parse as a float.
func (c *OrchestrationFeeCalculator) CalculateFee(
	paymentItems []*PaymentItem,
	config *OrchestrationFeeConfig,
) (string, error) {
	if config == nil {
		return "", nil // No orchestration fee configured
	}

	// parseComponent treats "" and "0" as zero without parsing.
	parseComponent := func(raw, label string) (float64, error) {
		if raw == "" || raw == "0" {
			return 0, nil
		}
		v, err := strconv.ParseFloat(raw, 64)
		if err != nil {
			return 0, fmt.Errorf("invalid %s: %w", label, err)
		}
		return v, nil
	}

	fixedFee, err := parseComponent(config.FixedAmount, "fixed_amount")
	if err != nil {
		return "", err
	}
	percentage, err := parseComponent(config.Percentage, "percentage")
	if err != nil {
		return "", err
	}
	if fixedFee == 0 && percentage == 0 {
		return "", nil // Nothing to charge
	}

	// Sum the batch total so the percentage component can be applied.
	var batchTotal float64
	for _, item := range paymentItems {
		amount, err := strconv.ParseFloat(item.Amount, 64)
		if err != nil {
			return "", fmt.Errorf("invalid payment item amount: %w", err)
		}
		batchTotal += amount
	}

	// Fixed Fee + (Total Amount × Percentage / 100), formatted with 8
	// decimal places to match the formatAmount helper used elsewhere.
	total := fixedFee + batchTotal*percentage/100.0
	return fmt.Sprintf("%.8f", total), nil
}
// GenerateOrchestrationFeeItem creates a PaymentItem representing the
// orchestration fee for a batch. The item's UniqueID is derived from
// baseUniqueID with an "-orchestration-fee" suffix, and its metadata records
// the unique IDs and amounts of the items it covers. Returns (nil, nil) when
// feeAmount is empty, meaning no fee applies.
func (c *OrchestrationFeeCalculator) GenerateOrchestrationFeeItem(
	paymentItems []*PaymentItem,
	_ *Contract, // Contract parameter kept for API consistency but not used
	baseUniqueID string,
	feeAmount string,
) (*PaymentItem, error) {
	if feeAmount == "" {
		return nil, nil // No fee to generate
	}

	// Capture the covered items' IDs and amounts for the metadata audit trail.
	count := len(paymentItems)
	ids := make([]string, count)
	amounts := make([]string, count)
	for i := range paymentItems {
		ids[i] = paymentItems[i].UniqueID
		amounts[i] = paymentItems[i].Amount
	}

	return &PaymentItem{
		UniqueID:           fmt.Sprintf("%s-orchestration-fee", baseUniqueID),
		DeploymentID:       "", // Empty - applies to entire batch
		Amount:             feeAmount,
		Usages:             count, // Number of payment items in batch
		IsOrchestrationFee: true,
		Metadata: map[string]interface{}{
			"original_unique_ids": ids,
			"original_amounts":    amounts,
			"fee_type":            "orchestration",
			"payment_item_count":  count,
		},
	}, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package contracts
import (
"fmt"
"sync"
"time"
)
// UsageData represents collected usage data for one contract. Data holds the
// model-specific payload and must match PaymentModel.
type UsageData struct {
ContractDID string
PaymentModel PaymentModel
Data interface{} // Model-specific data (TimeUtilizationUsage, ResourceUtilizationUsage, etc.)
}

// PaymentItem represents a single payment to process.
// NOTE(review): only the conversion-related fields carry json tags; the
// first five would serialize under their Go names if this struct is ever
// JSON-encoded — confirm whether that asymmetry is intentional.
type PaymentItem struct {
UniqueID string
DeploymentID string // Empty for non-deployment models
Amount string // Final amount in payment currency (NTX after conversion)
Usages int
Metadata map[string]interface{} // Model-specific metadata
// Fields for price conversion tracking (optional)
OriginalAmount string `json:"original_amount,omitempty"` // Amount in pricing currency (USDT)
PricingCurrency string `json:"pricing_currency,omitempty"` // Currency of original amount
ExchangeRate string `json:"exchange_rate,omitempty"` // Rate used for conversion
ConversionTimestamp time.Time `json:"conversion_timestamp,omitempty"` // When conversion occurred
IsOrchestrationFee bool `json:"is_orchestration_fee,omitempty"` // Indicates if this is an orchestration fee transaction
}
// PaymentModelProcessor defines the clear, shared interface for all payment model processors.
// Each processor is a self-contained strategy with direct store access.
// All payment models must implement this interface.
// Processors are registered via RegisterPaymentModelProcessor and looked up
// by PaymentModel through GetPaymentModelProcessor.
type PaymentModelProcessor interface {
// CollectUsage collects usage data for manual billing.
// This method is called when a user manually triggers invoice generation.
// Processors have full control over how they query and process events from the store.
// Returns UsageData containing model-specific usage information.
// providerDID is optional - if provided, filters events by provider for per-node billing.
// headContractDID is optional - if provided, queries events by Head Contract DID instead of Tail Contract DID
CollectUsage(
contractDID string,
lastProcessedAt time.Time,
now time.Time,
providerDID string, // Optional: if provided, filters events by provider
headContractDID string, // Optional: if provided, queries by Head Contract DID
) (*UsageData, error)
// CalculatePayment calculates payment items from usage data.
// This method converts UsageData into PaymentItems that can be processed.
// Each PaymentItem represents a single payment transaction (one per deployment for deployment-based models).
// Returns a slice of PaymentItems, each with calculated amount and metadata.
CalculatePayment(
usageData *UsageData,
contract *Contract,
) ([]*PaymentItem, error)
// Validate validates payment model configuration.
// This method checks that PaymentDetails contains all required fields for this payment model.
// Returns an error if validation fails, nil if valid.
Validate(paymentDetails PaymentDetails) error
// SupportsManualBilling indicates whether this payment model supports manual invoice generation.
// Models like FixedRental and Periodic return false (automatic billing only).
// Models like PayPerAllocation, PayPerDeployment, etc. return true.
SupportsManualBilling() bool
// SupportsAutomaticBilling indicates whether this payment model supports automatic periodic billing.
// Models like FixedRental and Periodic return true.
// Other models return false.
SupportsAutomaticBilling() bool
// CheckAndGenerateInvoice checks if an invoice should be generated and collects usage data.
// This method is called periodically by the contract actor for automatic billing models.
// It checks if enough time has elapsed since the last invoice and collects usage if so.
// Returns UsageData if invoice should be generated, nil if not yet time, error on failure.
// Only called for models where SupportsAutomaticBilling() returns true.
CheckAndGenerateInvoice(
contract *Contract,
lastInvoiceAt time.Time,
now time.Time,
) (*UsageData, error)
}
// PaymentModelRegistry manages payment model processors.
// All processors must be registered before use.
// Access to processors is guarded by mu, so registration and lookup are
// safe from multiple goroutines.
type PaymentModelRegistry struct {
processors map[PaymentModel]PaymentModelProcessor
mu sync.RWMutex
}

// globalRegistry is the process-wide registry backing the package-level
// Register/Get helpers below.
var globalRegistry = &PaymentModelRegistry{
processors: make(map[PaymentModel]PaymentModelProcessor),
}
// RegisterPaymentModelProcessor adds a processor to the global registry,
// replacing any processor previously registered for the same model.
// Intended to be called during application initialization; a nil processor
// is a programming error and triggers a panic.
func RegisterPaymentModelProcessor(model PaymentModel, processor PaymentModelProcessor) {
	if processor == nil {
		panic(fmt.Sprintf("cannot register nil processor for payment model: %s", model))
	}
	reg := globalRegistry
	reg.mu.Lock()
	defer reg.mu.Unlock()
	reg.processors[model] = processor
}
// GetPaymentModelProcessor looks up the processor registered for the given
// payment model in the global registry. It returns an error when no
// processor has been registered for that model.
func GetPaymentModelProcessor(model PaymentModel) (PaymentModelProcessor, error) {
	globalRegistry.mu.RLock()
	defer globalRegistry.mu.RUnlock()
	if p, ok := globalRegistry.processors[model]; ok {
		return p, nil
	}
	return nil, fmt.Errorf("no processor registered for payment model: %s", model)
}
// MustGetPaymentModelProcessor is the panicking variant of
// GetPaymentModelProcessor, intended for initialization paths where a
// missing processor is unrecoverable.
func MustGetPaymentModelProcessor(model PaymentModel) PaymentModelProcessor {
	p, err := GetPaymentModelProcessor(model)
	if err != nil {
		panic(err)
	}
	return p
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package contracts
import (
"fmt"
"strconv"
"time"
"gitlab.com/nunet/device-management-service/types"
)
// PaymentModel identifies the billing strategy applied to a contract.
type PaymentModel string

// Supported payment models.
const (
	PayPerAllocation          PaymentModel = "pay_per_allocation"
	PayPerDeployment          PaymentModel = "pay_per_deployment"
	PayPerTimeUtilization     PaymentModel = "pay_per_time_utilization"
	PayPerResourceUtilization PaymentModel = "pay_per_resource_utilization"
	FixedRental               PaymentModel = "fixed_rental"
	Periodic                  PaymentModel = "periodic"
)

// Supported settlement methods (see PaymentType below).
const (
	FiatMethod       PaymentType = "fiat"
	BlockchainMethod PaymentType = "blockchain"
)

// Accepted values for PaymentDetails.PaymentPeriod.
const (
	PaymentPeriodMinute string = "minute"
	PaymentPeriodHour   string = "hour"
	PaymentPeriodDay    string = "day"
	PaymentPeriodWeek   string = "week"
	PaymentPeriodMonth  string = "month"
)

// Accepted values for PaymentDetails.TimeUnit and ResourceTimeUnit.
const (
	TimeUnitSecond string = "second"
	TimeUnitMinute string = "minute"
	TimeUnitHour   string = "hour"
)

// PaymentType identifies the settlement method for a payment.
type PaymentType string
// PaymentDetails describes how a contract is billed: the selected payment
// model, the model-specific fee fields, and where payments should be sent.
// Only the fee fields belonging to the selected PaymentModel are meaningful.
type PaymentDetails struct {
	PaymentType PaymentType `json:"payment_type"`
	Timestamp   time.Time   `json:"timestamp"`
	// payment model selecting which of the fee fields below apply
	PaymentModel PaymentModel `json:"payment_model"`
	// pay per deployment payment model
	FeePerDeployment string `json:"fee_per_deployment,omitempty"`
	// pay per allocation payment model
	// NOTE(review): unlike the sibling fee fields this tag has no omitempty —
	// confirm whether that asymmetry is intentional before changing it.
	FeePerAllocation string `json:"fee_per_allocation"`
	// pay per time utilization payment model
	FeePerTimeUnit string `json:"fee_per_time_unit,omitempty"` // e.g., "0.01" per second
	TimeUnit       string `json:"time_unit,omitempty"`         // "second", "minute", "hour"
	// pay per resource utilization payment model
	FeePerCPUCorePerTimeUnit string `json:"fee_per_cpu_core_per_time_unit,omitempty"` // e.g., "0.10" per core per hour
	FeePerRAMGBPerTimeUnit   string `json:"fee_per_ram_gb_per_time_unit,omitempty"`   // e.g., "0.05" per GB per hour
	FeePerDiskGBPerTimeUnit  string `json:"fee_per_disk_gb_per_time_unit,omitempty"`  // e.g., "0.01" per GB per hour
	FeePerGPUPerTimeUnit     string `json:"fee_per_gpu_per_time_unit,omitempty"`      // e.g., "5.00" per GPU per hour (optional)
	ResourceTimeUnit         string `json:"resource_time_unit,omitempty"`             // "second", "minute", "hour"
	// fixed rental payment model
	FixedRentalAmount  string `json:"fixed_rental_amount,omitempty"`  // e.g., "20.00"
	PaymentPeriod      string `json:"payment_period,omitempty"`       // "minute", "hour", "day", "week", "month"
	PaymentPeriodCount int    `json:"payment_period_count,omitempty"` // Number of periods to wait before invoicing (default: 1). Invoice amount is fixedRentalAmount, invoiced every paymentPeriodCount periods
	// destination addresses for payments
	Addresses       []types.PaymentAddressInfo `json:"addresses"`
	PricingCurrency string                     `json:"pricing_currency,omitempty"`
	// Orchestration fee configuration (optional)
	OrchestrationFee *OrchestrationFeeConfig `json:"orchestration_fee,omitempty"`
}
// OrchestrationFeeConfig represents the configuration for orchestration fees.
// A fee can be a fixed amount, a percentage of the batch total, or both;
// see ValidateOrchestrationFee for the accepted value ranges.
type OrchestrationFeeConfig struct {
	// Fixed fee amount (e.g., "1.50" for $1.50).
	// If empty or "0", no fixed fee is charged.
	FixedAmount string `json:"fixed_amount,omitempty"`
	// Percentage fee (e.g., "2.5" for 2.5%).
	// If empty or "0", no percentage fee is charged.
	// Percentage is applied to the total amount of all payment items in the batch.
	Percentage string `json:"percentage,omitempty"`
	// Recipient address for orchestration fee payments.
	// If empty, uses the contract's default addresses.
	// The transaction is still forwarded to the original requestor, but uses this address for payment details.
	// NOTE(review): omitempty has no effect on a non-pointer struct field, so
	// this is always serialized — confirm whether *types.PaymentAddressInfo was intended.
	RecipientAddress types.PaymentAddressInfo `json:"recipient_address,omitempty"`
}
// ValidateOrchestrationFee validates the orchestration fee configuration.
// The config is optional (nil is valid); when present, at least one of
// fixed_amount / percentage must be set, fixed_amount must be a
// non-negative number, and percentage must be a number in [0, 100].
// Fix: previously a negative fixed_amount (e.g. "-1.50") passed validation
// because only parseability was checked.
func (pd *PaymentDetails) ValidateOrchestrationFee() error {
	if pd.OrchestrationFee == nil {
		return nil // Optional field
	}
	fee := pd.OrchestrationFee
	// At least one fee component must be specified.
	if (fee.FixedAmount == "" || fee.FixedAmount == "0") &&
		(fee.Percentage == "" || fee.Percentage == "0") {
		return fmt.Errorf("orchestration_fee must have at least fixed_amount or percentage")
	}
	// Validate fixed amount if provided: must parse and must not be negative.
	if fee.FixedAmount != "" && fee.FixedAmount != "0" {
		amount, err := strconv.ParseFloat(fee.FixedAmount, 64)
		if err != nil {
			return fmt.Errorf("invalid orchestration_fee.fixed_amount: %w", err)
		}
		if amount < 0 {
			return fmt.Errorf("orchestration_fee.fixed_amount must not be negative")
		}
	}
	// Validate percentage if provided: must parse and lie within [0, 100].
	if fee.Percentage != "" && fee.Percentage != "0" {
		pct, err := strconv.ParseFloat(fee.Percentage, 64)
		if err != nil {
			return fmt.Errorf("invalid orchestration_fee.percentage: %w", err)
		}
		if pct < 0 || pct > 100 {
			return fmt.Errorf("orchestration_fee.percentage must be between 0 and 100")
		}
	}
	return nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package processors
import (
"fmt"
"time"
"github.com/google/uuid"
"gitlab.com/nunet/device-management-service/tokenomics/contracts"
"gitlab.com/nunet/device-management-service/tokenomics/store/usage"
"gitlab.com/nunet/device-management-service/utils/convert"
)
// FixedRentalProcessor implements PaymentModelProcessor for fixed_rental model.
// It invoices the fixed rental amount once per elapsed billing cycle
// (see CheckAndGenerateInvoice) instead of deriving usage from events.
type FixedRentalProcessor struct {
	store *usage.Store // usage event store; must be non-nil (see constructor)
}
// NewFixedRentalProcessor returns a FixedRentalProcessor backed by store.
// A nil store is a programming error and causes a panic.
func NewFixedRentalProcessor(store *usage.Store) *FixedRentalProcessor {
	if store != nil {
		return &FixedRentalProcessor{store: store}
	}
	panic("usage store cannot be nil")
}
// CollectUsage implements PaymentModelProcessor.CollectUsage.
// Fixed rental is billed automatically per period, so on-demand usage
// collection is always rejected; all parameters are ignored.
func (p *FixedRentalProcessor) CollectUsage(
	contractDID string,
	lastProcessedAt time.Time,
	now time.Time,
	providerDID string,
	headContractDID string,
) (*contracts.UsageData, error) {
	return nil, fmt.Errorf("fixed_rental does not support manual billing")
}
// CalculatePayment implements PaymentModelProcessor.CalculatePayment.
// It turns a FixedRentalUsage record into a single payment item carrying
// the pre-computed rental amount and the billing window in its metadata.
func (p *FixedRentalProcessor) CalculatePayment(
	usageData *contracts.UsageData,
	_ *contracts.Contract,
) ([]*contracts.PaymentItem, error) {
	rental, ok := usageData.Data.(*contracts.FixedRentalUsage)
	if !ok {
		return nil, fmt.Errorf("invalid usage data type")
	}
	item := &contracts.PaymentItem{
		UniqueID:     uuid.NewString(), // fresh UUID per payment item
		DeploymentID: "",               // fixed rental is not deployment-based
		Amount:       rental.Amount,
		Usages:       1,
		Metadata: map[string]interface{}{
			"periods_invoiced": rental.PeriodsInvoiced,
			"period_start":     rental.PeriodStart.Format(time.RFC3339),
			"period_end":       rental.PeriodEnd.Format(time.RFC3339),
			"last_invoice_at":  rental.LastInvoiceAt.Format(time.RFC3339),
		},
	}
	return []*contracts.PaymentItem{item}, nil
}
// Validate implements PaymentModelProcessor.Validate.
// It requires a parseable, strictly positive fixed_rental_amount, a known
// payment_period, and a positive payment_period_count.
// Fix: previously a zero or negative fixed_rental_amount passed validation
// because only parseability was checked.
func (p *FixedRentalProcessor) Validate(paymentDetails contracts.PaymentDetails) error {
	if paymentDetails.FixedRentalAmount == "" {
		return fmt.Errorf("fixed_rental_amount is required")
	}
	if paymentDetails.PaymentPeriod == "" {
		return fmt.Errorf("payment_period is required")
	}
	amount, err := convert.StringToFloat64(paymentDetails.FixedRentalAmount, "fixed_rental_amount")
	if err != nil {
		return err
	}
	// A zero or negative rental amount would produce useless (or refund-like)
	// invoices, so reject it up front.
	if amount <= 0 {
		return fmt.Errorf("fixed_rental_amount must be positive, got: %s", paymentDetails.FixedRentalAmount)
	}
	if _, err := convert.ParsePaymentPeriod(paymentDetails.PaymentPeriod); err != nil {
		return err
	}
	if paymentDetails.PaymentPeriodCount <= 0 {
		return fmt.Errorf("payment_period_count must be a positive integer, got: %d", paymentDetails.PaymentPeriodCount)
	}
	return nil
}
// SupportsManualBilling implements PaymentModelProcessor.SupportsManualBilling.
// Fixed rental is invoiced automatically per billing period, so on-demand
// billing is not available (CollectUsage returns an error accordingly).
func (p *FixedRentalProcessor) SupportsManualBilling() bool {
	return false
}
// SupportsAutomaticBilling implements PaymentModelProcessor.SupportsAutomaticBilling.
// Fixed rental contracts are billed periodically via CheckAndGenerateInvoice.
func (p *FixedRentalProcessor) SupportsAutomaticBilling() bool {
	return true
}
// CheckAndGenerateInvoice implements PaymentModelProcessor.CheckAndGenerateInvoice.
// It returns ErrPeriodNotElapsed until at least one full billing cycle
// (payment_period × payment_period_count) has passed since lastInvoiceAt;
// otherwise it produces a FixedRentalUsage charging the full rental amount
// once per elapsed cycle.
func (p *FixedRentalProcessor) CheckAndGenerateInvoice(
	contract *contracts.Contract,
	lastInvoiceAt time.Time,
	now time.Time,
) (*contracts.UsageData, error) {
	details := contract.PaymentDetails
	periodDuration, err := convert.ParsePaymentPeriod(details.PaymentPeriod)
	if err != nil {
		return nil, fmt.Errorf("invalid payment_period: %w", err)
	}
	cycles, periodStart, periodEnd := convert.CalculateElapsedPeriods(
		lastInvoiceAt,
		now,
		periodDuration,
		details.PaymentPeriodCount,
	)
	if cycles < 1 {
		return nil, ErrPeriodNotElapsed
	}
	rentalAmount, err := convert.StringToFloat64(details.FixedRentalAmount, "fixed_rental_amount")
	if err != nil {
		return nil, err
	}
	// Each billing cycle spans PaymentPeriodCount payment periods (treated
	// as 1 when unset) and is invoiced at the full fixed rental amount.
	periodsPerCycle := details.PaymentPeriodCount
	if periodsPerCycle <= 0 {
		periodsPerCycle = 1
	}
	rentalUsage := &contracts.FixedRentalUsage{
		PeriodsInvoiced: cycles * periodsPerCycle,
		PeriodStart:     periodStart,
		PeriodEnd:       periodEnd,
		Amount:          formatAmount(rentalAmount * float64(cycles)),
		LastInvoiceAt:   lastInvoiceAt,
	}
	return &contracts.UsageData{
		ContractDID:  contract.ContractDID,
		PaymentModel: contracts.FixedRental,
		Data:         rentalUsage,
	}, nil
}
// GenerateProRatedInvoice generates a pro-rated invoice for a terminated
// contract based on the partial period elapsed (lastInvoiceAt → now).
// Unlike CheckAndGenerateInvoice, it does not require a full billing cycle
// to have elapsed; the charge is the fraction of one cycle's rental amount
// corresponding to the elapsed time. PeriodsInvoiced is set to 0 to mark
// the result as a partial period.
func (p *FixedRentalProcessor) GenerateProRatedInvoice(
	contract *contracts.Contract,
	lastInvoiceAt time.Time,
	now time.Time,
) (*contracts.UsageData, error) {
	details := contract.PaymentDetails
	periodDuration, err := convert.ParsePaymentPeriod(details.PaymentPeriod)
	if err != nil {
		return nil, fmt.Errorf("invalid payment_period: %w", err)
	}
	elapsed := now.Sub(lastInvoiceAt)
	if elapsed <= 0 {
		return nil, fmt.Errorf("no elapsed time for pro-rated invoice")
	}
	// One billing cycle spans PaymentPeriodCount payment periods (treated
	// as 1 when unset); the fixed rental amount pays for exactly one cycle.
	cyclePeriods := details.PaymentPeriodCount
	if cyclePeriods <= 0 {
		cyclePeriods = 1
	}
	cycleDuration := periodDuration * time.Duration(cyclePeriods)
	rentalAmount, err := convert.StringToFloat64(details.FixedRentalAmount, "fixed_rental_amount")
	if err != nil {
		return nil, err
	}
	// Charge the fraction of a billing cycle that actually elapsed.
	proRatedRatio := float64(elapsed) / float64(cycleDuration)
	return &contracts.UsageData{
		ContractDID:  contract.ContractDID,
		PaymentModel: contracts.FixedRental,
		Data: &contracts.FixedRentalUsage{
			PeriodsInvoiced: 0, // zero marks a pro-rated partial period
			PeriodStart:     lastInvoiceAt,
			PeriodEnd:       now,
			Amount:          formatAmount(proRatedRatio * rentalAmount),
			LastInvoiceAt:   lastInvoiceAt,
		},
	}, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package processors
import (
"encoding/json"
"fmt"
"time"
"gitlab.com/nunet/device-management-service/tokenomics/events"
"gitlab.com/nunet/device-management-service/tokenomics/store/usage"
)
// allocationWindow represents a time window for an allocation,
// reconstructed from allocation start/end usage events.
type allocationWindow struct {
	allocationID string    // allocation this window belongs to
	deploymentID string    // deployment the allocation is part of
	startTime    time.Time // start of the window
	endTime      time.Time // end of the window (see isComplete)
	isComplete   bool      // true when an end for the window was observed
}
// deploymentWindow represents a time window for a deployment,
// reconstructed from deployment start/end usage events.
type deploymentWindow struct {
	deploymentID string    // deployment this window belongs to
	startTime    time.Time // start of the window
	endTime      time.Time // end of the window (see isComplete)
	isComplete   bool      // true when an end for the window was observed
}
// processAllocationEndEvent processes end events (CompleteAllocationEvent,
// StopAllocationEvent) and returns the allocationID if found. Returns an
// empty string and false when the event has no timestamp, is of another
// type, or its payload cannot be decoded.
func processAllocationEndEvent(evt *usage.Usage) (allocationID string, ok bool) {
	if evt.Timestamp.IsZero() {
		return "", false
	}
	var id string
	switch evt.EventType {
	case events.CompleteAllocationEvent:
		var payload events.CompleteAllocation
		if err := json.Unmarshal(evt.Data, &payload); err != nil {
			return "", false
		}
		id = payload.AllocationID
	case events.StopAllocationEvent:
		var payload events.StopAllocation
		if err := json.Unmarshal(evt.Data, &payload); err != nil {
			return "", false
		}
		id = payload.AllocationID
	default:
		return "", false
	}
	return id, true
}
// formatAmount renders amount as a fixed-point decimal string with eight
// fractional digits (e.g. 1.5 -> "1.50000000"), the precision used for
// invoice amounts throughout this package.
func formatAmount(amount float64) string {
	const eightDecimals = "%.8f"
	return fmt.Sprintf(eightDecimals, amount)
}
// FormatAmount renders amount as a fixed-point decimal string with eight
// fractional digits (e.g. 1.5 -> "1.50000000").
// This is the public version for use in other packages.
func FormatAmount(amount float64) string {
	const eightDecimals = "%.8f"
	return fmt.Sprintf(eightDecimals, amount)
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package processors
import (
"gitlab.com/nunet/device-management-service/tokenomics/contracts"
"gitlab.com/nunet/device-management-service/tokenomics/store/usage"
)
// InitPaymentModelProcessors initializes and registers all payment model processors.
// This should be called during application startup.
// A nil store is a programming error and causes a panic.
func InitPaymentModelProcessors(store *usage.Store) {
	if store == nil {
		panic("usage store cannot be nil")
	}
	register := func(model contracts.PaymentModel, processor contracts.PaymentModelProcessor) {
		contracts.RegisterPaymentModelProcessor(model, processor)
	}
	register(contracts.PayPerAllocation, NewPayPerAllocationProcessor(store))
	register(contracts.PayPerDeployment, NewPayPerDeploymentProcessor(store))
	register(contracts.PayPerTimeUtilization, NewPayPerTimeUtilizationProcessor(store))
	register(contracts.PayPerResourceUtilization, NewPayPerResourceUtilizationProcessor(store))
	register(contracts.FixedRental, NewFixedRentalProcessor(store))
	register(contracts.Periodic, NewPeriodicProcessor(store))
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package processors
import (
"encoding/json"
"fmt"
"time"
"github.com/google/uuid"
"gitlab.com/nunet/device-management-service/tokenomics/contracts"
"gitlab.com/nunet/device-management-service/tokenomics/events"
"gitlab.com/nunet/device-management-service/tokenomics/store/usage"
"gitlab.com/nunet/device-management-service/utils/convert"
)
// PayPerAllocationProcessor implements PaymentModelProcessor for pay_per_allocation model.
// This processor counts allocations and calculates payment based on a fixed fee per allocation.

// Compile-time assertion that the processor satisfies the interface.
var _ contracts.PaymentModelProcessor = (*PayPerAllocationProcessor)(nil)

type PayPerAllocationProcessor struct {
	store *usage.Store // usage event store; must be non-nil (see constructor)
}
// NewPayPerAllocationProcessor returns a PayPerAllocationProcessor backed
// by store. A nil store is a programming error and causes a panic.
func NewPayPerAllocationProcessor(store *usage.Store) *PayPerAllocationProcessor {
	if store != nil {
		return &PayPerAllocationProcessor{store: store}
	}
	panic("usage store cannot be nil")
}
// CollectUsage implements PaymentModelProcessor.CollectUsage.
// It counts allocations started in (lastProcessedAt, now]. When
// headContractDID is non-empty, start events are queried by Head Contract
// DID and unique allocation IDs are counted; otherwise the count is taken
// by Tail Contract DID (backward-compatible path).
// Cleanup: the outer `var err error` was shadowed by `:=` inside the
// head-contract branch (a classic dropped-error hazard); errors are now
// scoped per branch.
func (p *PayPerAllocationProcessor) CollectUsage(
	contractDID string,
	lastProcessedAt time.Time,
	now time.Time,
	_ string, // providerDID - unused in this processor
	headContractDID string,
) (*contracts.UsageData, error) {
	var usageCount int
	if headContractDID != "" {
		// Head Contract path: query raw events and count unique allocations.
		filters := usage.EventFilters{
			HeadContractDID: headContractDID,
			EventTypes:      []events.EventType{events.StartAllocationEvent},
			StartTime:       lastProcessedAt,
			EndTime:         now,
		}
		usageEvents, err := p.store.QueryEvents(filters)
		if err != nil {
			return nil, fmt.Errorf("failed to query events by head contract: %w", err)
		}
		// Malformed events and events without an allocation ID are skipped.
		allocationSet := make(map[string]bool)
		for _, evt := range usageEvents {
			var evtData events.StartAllocation
			if err := json.Unmarshal(evt.Data, &evtData); err != nil {
				continue
			}
			if evtData.AllocationID != "" {
				allocationSet[evtData.AllocationID] = true
			}
		}
		usageCount = len(allocationSet)
	} else {
		// Tail Contract path: delegate counting to the store.
		count, err := p.store.CountAllocationsByContractDID(contractDID, lastProcessedAt, now)
		if err != nil {
			return nil, fmt.Errorf("failed to count allocations: %w", err)
		}
		usageCount = count
	}
	return &contracts.UsageData{
		ContractDID:  contractDID,
		PaymentModel: contracts.PayPerAllocation,
		Data:         usageCount, // Simple count for this model
	}, nil
}
// CalculatePayment implements PaymentModelProcessor.CalculatePayment.
// It multiplies fee_per_allocation by the collected allocation count and
// emits a single aggregated payment item (or none when the count is zero).
// Fix: the StringToFloat64 label said "fees_per_allocation", which does not
// match the schema field / JSON tag "fee_per_allocation".
func (p *PayPerAllocationProcessor) CalculatePayment(
	usageData *contracts.UsageData,
	contract *contracts.Contract,
) ([]*contracts.PaymentItem, error) {
	usageCount, ok := usageData.Data.(int)
	if !ok {
		return nil, fmt.Errorf("invalid usage data type for pay_per_allocation")
	}
	if usageCount == 0 {
		// Nothing to bill for this interval.
		return []*contracts.PaymentItem{}, nil
	}
	pd := contract.PaymentDetails
	feePerAllocation, err := convert.StringToFloat64(pd.FeePerAllocation, "fee_per_allocation")
	if err != nil {
		return nil, err
	}
	totalAmount := feePerAllocation * float64(usageCount)
	items := []*contracts.PaymentItem{
		{
			UniqueID:     uuid.NewString(), // fresh UUID per payment item
			DeploymentID: "",               // aggregated item, not deployment-based
			Amount:       formatAmount(totalAmount),
			Usages:       usageCount,
			Metadata: map[string]interface{}{
				"allocation_count":     usageCount,
				"fee_per_allocation":   feePerAllocation,
				"total_amount":         totalAmount,
				"payment_model":        contracts.PayPerAllocation,
				"payment_period":       pd.PaymentPeriod,
				"payment_period_count": pd.PaymentPeriodCount,
			},
		},
	}
	return items, nil
}
// Validate implements PaymentModelProcessor.Validate.
// fee_per_allocation is required and must be numeric; payment_period and
// payment_period_count are optional for this model but validated when set.
// Fixes: error messages referenced the non-existent field name
// "fees_per_allocation" (schema tag is "fee_per_allocation"), and the
// PaymentPeriodCount message claimed "positive" although zero is accepted
// by the `< 0` check.
func (p *PayPerAllocationProcessor) Validate(paymentDetails contracts.PaymentDetails) error {
	if paymentDetails.FeePerAllocation == "" {
		return fmt.Errorf("fee_per_allocation is required")
	}
	if _, err := convert.StringToFloat64(paymentDetails.FeePerAllocation, "fee_per_allocation"); err != nil {
		return err
	}
	// Validate payment_period if provided (optional for this model)
	if paymentDetails.PaymentPeriod != "" {
		if _, err := convert.ParsePaymentPeriod(paymentDetails.PaymentPeriod); err != nil {
			return err
		}
	}
	// Zero is allowed (falls back to the default count); only negatives are invalid.
	if paymentDetails.PaymentPeriodCount < 0 {
		return fmt.Errorf("payment_period_count must not be negative, got: %d", paymentDetails.PaymentPeriodCount)
	}
	return nil
}
// SupportsManualBilling implements PaymentModelProcessor.SupportsManualBilling.
// Pay-per-allocation usage can be collected on demand via CollectUsage.
func (p *PayPerAllocationProcessor) SupportsManualBilling() bool {
	return true
}
// SupportsAutomaticBilling implements PaymentModelProcessor.SupportsAutomaticBilling.
// Periodic invoicing is driven by CheckAndGenerateInvoice.
func (p *PayPerAllocationProcessor) SupportsAutomaticBilling() bool {
	return true
}
// CheckAndGenerateInvoice implements PaymentModelProcessor.CheckAndGenerateInvoice.
// It returns ErrPeriodNotElapsed until at least one billing cycle has
// passed since lastInvoiceAt, then collects allocation usage for the
// elapsed window. Head Contracts are queried by their own DID; Tail/P2P
// contracts are queried by contract DID (empty headContractDID).
func (p *PayPerAllocationProcessor) CheckAndGenerateInvoice(
	contract *contracts.Contract,
	lastInvoiceAt time.Time,
	now time.Time,
) (*contracts.UsageData, error) {
	details := contract.PaymentDetails
	periodDuration, err := convert.ParsePaymentPeriod(details.PaymentPeriod)
	if err != nil {
		return nil, fmt.Errorf("invalid payment_period: %w", err)
	}
	cycles, _, _ := convert.CalculateElapsedPeriods(
		lastInvoiceAt,
		now,
		periodDuration,
		details.PaymentPeriodCount,
	)
	if cycles < 1 {
		return nil, ErrPeriodNotElapsed
	}
	// Reading a nil Metadata map is safe and yields the zero value.
	headContractDID := ""
	if role, ok := contract.Metadata[contracts.ContractChainRoleMetadataKey]; ok && role == contracts.ContractChainRoleHead {
		headContractDID = contract.ContractDID
	}
	return p.CollectUsage(contract.ContractDID, lastInvoiceAt, now, "", headContractDID)
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package processors
import (
"encoding/json"
"fmt"
"time"
"github.com/google/uuid"
"gitlab.com/nunet/device-management-service/tokenomics/contracts"
"gitlab.com/nunet/device-management-service/tokenomics/events"
"gitlab.com/nunet/device-management-service/tokenomics/store/usage"
"gitlab.com/nunet/device-management-service/utils/convert"
)
// PayPerDeploymentProcessor implements PaymentModelProcessor for pay_per_deployment model.
// This processor counts deployments and calculates payment based on a fixed fee per deployment.

// Compile-time assertion that the processor satisfies the interface.
var _ contracts.PaymentModelProcessor = (*PayPerDeploymentProcessor)(nil)

type PayPerDeploymentProcessor struct {
	store *usage.Store // usage event store; must be non-nil (see constructor)
}
// NewPayPerDeploymentProcessor returns a PayPerDeploymentProcessor backed
// by store. A nil store is a programming error and causes a panic.
func NewPayPerDeploymentProcessor(store *usage.Store) *PayPerDeploymentProcessor {
	if store != nil {
		return &PayPerDeploymentProcessor{store: store}
	}
	panic("usage store cannot be nil")
}
// CollectUsage implements PaymentModelProcessor.CollectUsage.
// It counts deployments started in (lastProcessedAt, now]. When
// headContractDID is non-empty, start events are queried by Head Contract
// DID and unique deployment IDs are counted; otherwise the count is taken
// by Tail Contract DID (backward-compatible path).
// Cleanup: the outer `var err error` was shadowed by `:=` inside the
// head-contract branch (a classic dropped-error hazard); errors are now
// scoped per branch.
func (p *PayPerDeploymentProcessor) CollectUsage(
	contractDID string,
	lastProcessedAt time.Time,
	now time.Time,
	_ string, // providerDID - unused in this processor
	headContractDID string,
) (*contracts.UsageData, error) {
	var usageCount int
	if headContractDID != "" {
		// Head Contract path: query raw events and count unique deployments.
		filters := usage.EventFilters{
			HeadContractDID: headContractDID,
			EventTypes:      []events.EventType{events.DeploymentStartEvent},
			StartTime:       lastProcessedAt,
			EndTime:         now,
		}
		usageEvents, err := p.store.QueryEvents(filters)
		if err != nil {
			return nil, fmt.Errorf("failed to query events by head contract: %w", err)
		}
		// Malformed events and events without a deployment ID are skipped.
		deploymentSet := make(map[string]bool)
		for _, evt := range usageEvents {
			var evtData events.DeploymentStart
			if err := json.Unmarshal(evt.Data, &evtData); err != nil {
				continue
			}
			if evtData.DeploymentID != "" {
				deploymentSet[evtData.DeploymentID] = true
			}
		}
		usageCount = len(deploymentSet)
	} else {
		// Tail Contract path: delegate counting to the store.
		count, err := p.store.CountDeploymentsByContract(contractDID, lastProcessedAt, now)
		if err != nil {
			return nil, fmt.Errorf("failed to count deployments: %w", err)
		}
		usageCount = count
	}
	return &contracts.UsageData{
		ContractDID:  contractDID,
		PaymentModel: contracts.PayPerDeployment,
		Data:         usageCount, // Simple count for this model
	}, nil
}
// CalculatePayment implements PaymentModelProcessor.CalculatePayment.
// It multiplies fee_per_deployment by the collected deployment count and
// emits a single aggregated payment item (or none when the count is zero).
func (p *PayPerDeploymentProcessor) CalculatePayment(
	usageData *contracts.UsageData,
	contract *contracts.Contract,
) ([]*contracts.PaymentItem, error) {
	count, ok := usageData.Data.(int)
	if !ok {
		return nil, fmt.Errorf("invalid usage data type for pay_per_deployment")
	}
	if count == 0 {
		// Nothing to bill for this interval.
		return []*contracts.PaymentItem{}, nil
	}
	details := contract.PaymentDetails
	fee, err := convert.StringToFloat64(details.FeePerDeployment, "fee_per_deployment")
	if err != nil {
		return nil, err
	}
	total := fee * float64(count)
	item := &contracts.PaymentItem{
		UniqueID:     uuid.NewString(), // fresh UUID per payment item
		DeploymentID: "",               // aggregated item spanning all deployments
		Amount:       formatAmount(total),
		Usages:       count,
		Metadata: map[string]interface{}{
			"deployment_count":     count,
			"fee_per_deployment":   fee,
			"total_amount":         total,
			"payment_model":        contracts.PayPerDeployment,
			"payment_period":       details.PaymentPeriod,
			"payment_period_count": details.PaymentPeriodCount,
		},
	}
	return []*contracts.PaymentItem{item}, nil
}
// Validate implements PaymentModelProcessor.Validate.
// fee_per_deployment is required and must be numeric; payment_period and
// payment_period_count are optional for this model but validated when set.
// Fix: the PaymentPeriodCount message claimed "positive" although zero is
// accepted by the `< 0` check.
func (p *PayPerDeploymentProcessor) Validate(paymentDetails contracts.PaymentDetails) error {
	if paymentDetails.FeePerDeployment == "" {
		return fmt.Errorf("fee_per_deployment is required")
	}
	if _, err := convert.StringToFloat64(paymentDetails.FeePerDeployment, "fee_per_deployment"); err != nil {
		return err
	}
	// Validate payment_period if provided (optional for this model)
	if paymentDetails.PaymentPeriod != "" {
		if _, err := convert.ParsePaymentPeriod(paymentDetails.PaymentPeriod); err != nil {
			return err
		}
	}
	// Zero is allowed (falls back to the default count); only negatives are invalid.
	if paymentDetails.PaymentPeriodCount < 0 {
		return fmt.Errorf("payment_period_count must not be negative, got: %d", paymentDetails.PaymentPeriodCount)
	}
	return nil
}
// SupportsManualBilling implements PaymentModelProcessor.SupportsManualBilling.
// Pay-per-deployment usage can be collected on demand via CollectUsage.
func (p *PayPerDeploymentProcessor) SupportsManualBilling() bool {
	return true
}
// SupportsAutomaticBilling implements PaymentModelProcessor.SupportsAutomaticBilling.
// Periodic invoicing is driven by CheckAndGenerateInvoice.
func (p *PayPerDeploymentProcessor) SupportsAutomaticBilling() bool {
	return true
}
// CheckAndGenerateInvoice implements PaymentModelProcessor.CheckAndGenerateInvoice.
// It returns ErrPeriodNotElapsed until at least one billing cycle has
// passed since lastInvoiceAt, then collects deployment usage for the
// elapsed window. Head Contracts are queried by their own DID; Tail/P2P
// contracts are queried by contract DID (empty headContractDID).
func (p *PayPerDeploymentProcessor) CheckAndGenerateInvoice(
	contract *contracts.Contract,
	lastInvoiceAt time.Time,
	now time.Time,
) (*contracts.UsageData, error) {
	details := contract.PaymentDetails
	periodDuration, err := convert.ParsePaymentPeriod(details.PaymentPeriod)
	if err != nil {
		return nil, fmt.Errorf("invalid payment_period: %w", err)
	}
	cycles, _, _ := convert.CalculateElapsedPeriods(
		lastInvoiceAt,
		now,
		periodDuration,
		details.PaymentPeriodCount,
	)
	if cycles < 1 {
		return nil, ErrPeriodNotElapsed
	}
	// Reading a nil Metadata map is safe and yields the zero value.
	headContractDID := ""
	if role, ok := contract.Metadata[contracts.ContractChainRoleMetadataKey]; ok && role == contracts.ContractChainRoleHead {
		headContractDID = contract.ContractDID
	}
	return p.CollectUsage(contract.ContractDID, lastInvoiceAt, now, "", headContractDID)
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package processors
import (
"encoding/json"
"fmt"
"time"
"github.com/google/uuid"
"gitlab.com/nunet/device-management-service/tokenomics/contracts"
"gitlab.com/nunet/device-management-service/tokenomics/events"
"gitlab.com/nunet/device-management-service/tokenomics/store/usage"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils/convert"
)
// PayPerResourceUtilizationProcessor implements PaymentModelProcessor for pay_per_resource_utilization model.
// This processor collects resource utilization and calculates payment based on resources × time.

// Compile-time assertion that the processor satisfies the interface.
var _ contracts.PaymentModelProcessor = (*PayPerResourceUtilizationProcessor)(nil)

type PayPerResourceUtilizationProcessor struct {
	store *usage.Store // usage event store; must be non-nil (see constructor)
}
// NewPayPerResourceUtilizationProcessor returns a processor backed by
// store. A nil store is a programming error and causes a panic.
func NewPayPerResourceUtilizationProcessor(store *usage.Store) *PayPerResourceUtilizationProcessor {
	if store != nil {
		return &PayPerResourceUtilizationProcessor{store: store}
	}
	panic("usage store cannot be nil")
}
// CollectUsage implements PaymentModelProcessor.CollectUsage.
// It reconstructs per-allocation time windows from start/end events in
// (lastProcessedAt, now], attaches each allocation's resources (taken from
// create events as a fallback), and aggregates per-deployment utilization.
// When headContractDID is non-empty events are queried by Head Contract
// DID; otherwise by the contract DID itself.
func (p *PayPerResourceUtilizationProcessor) CollectUsage(
	contractDID string,
	lastProcessedAt time.Time,
	now time.Time,
	_ string, // providerDID - unused in this processor
	headContractDID string,
) (*contracts.UsageData, error) {
	startEvents, endEvents, err := p.store.QueryAllocationEvents(contractDID, lastProcessedAt, now, headContractDID)
	if err != nil {
		return nil, fmt.Errorf("failed to query allocation events: %w", err)
	}
	createEvents, err := p.store.QueryCreateAllocationEvents(contractDID)
	if err != nil {
		return nil, fmt.Errorf("failed to query create allocation events: %w", err)
	}
	// Map allocation ID -> resources from create events; used as a fallback
	// when building windows. Malformed events are skipped.
	resourcesByAllocation := make(map[string]types.Resources)
	for _, evt := range createEvents {
		var created events.CreateAllocation
		if err := json.Unmarshal(evt.Data, &created); err != nil {
			continue
		}
		resourcesByAllocation[created.AllocationID] = created.Resources
	}
	windows := p.buildAllocationWindowsWithResources(startEvents, endEvents, resourcesByAllocation)
	deployments := p.calculateDeploymentResourceUtilization(windows, lastProcessedAt, now)
	return &contracts.UsageData{
		ContractDID:  contractDID,
		PaymentModel: contracts.PayPerResourceUtilization,
		Data:         &contracts.ResourceUtilizationUsage{Deployments: deployments},
	}, nil
}
// CalculatePayment implements PaymentModelProcessor.CalculatePayment.
//
// It prices each deployment by summing, per allocation, the cost of CPU
// cores, RAM, disk, and (optionally) GPUs, each multiplied by the
// allocation's runtime expressed in the contract's resource time unit. One
// payment item is emitted per deployment, with a per-allocation breakdown
// attached as metadata.
func (p *PayPerResourceUtilizationProcessor) CalculatePayment(
	usageData *contracts.UsageData,
	contract *contracts.Contract,
) ([]*contracts.PaymentItem, error) {
	resourceUtil, ok := usageData.Data.(*contracts.ResourceUtilizationUsage)
	if !ok {
		return nil, fmt.Errorf("invalid usage data type")
	}
	pd := contract.PaymentDetails
	feePerCPUCore, err := convert.StringToFloat64(pd.FeePerCPUCorePerTimeUnit, "fee_per_cpu_core_per_time_unit")
	if err != nil {
		return nil, err
	}
	feePerRAMGB, err := convert.StringToFloat64(pd.FeePerRAMGBPerTimeUnit, "fee_per_ram_gb_per_time_unit")
	if err != nil {
		return nil, err
	}
	feePerDiskGB, err := convert.StringToFloat64(pd.FeePerDiskGBPerTimeUnit, "fee_per_disk_gb_per_time_unit")
	if err != nil {
		return nil, err
	}
	// GPU pricing is optional; an absent fee means GPUs are not billed.
	var feePerGPU float64
	if pd.FeePerGPUPerTimeUnit != "" {
		feePerGPU, err = convert.StringToFloat64(pd.FeePerGPUPerTimeUnit, "fee_per_gpu_per_time_unit")
		if err != nil {
			return nil, err
		}
	}
	items := make([]*contracts.PaymentItem, 0)
	const bytesInGB = 1024 * 1024 * 1024
	for _, deployment := range resourceUtil.Deployments {
		var deploymentTotalCost float64
		// Price each allocation and record its per-resource cost breakdown.
		for key, allocation := range deployment.Allocations {
			// Convert the allocation's duration to the contract time unit.
			timeInUnit, err := convert.DurationToUnit(allocation.Duration, pd.ResourceTimeUnit)
			if err != nil {
				return nil, fmt.Errorf("failed to convert duration: %w", err)
			}
			// Note: RAM.Size and Disk.Size are in bytes; convert to GB (binary) before pricing.
			cpuCost := float64(allocation.Resources.CPU.Cores) * feePerCPUCore * timeInUnit
			ramCost := float64(allocation.Resources.RAM.Size) / float64(bytesInGB) * feePerRAMGB * timeInUnit
			diskCost := float64(allocation.Resources.Disk.Size) / float64(bytesInGB) * feePerDiskGB * timeInUnit
			var gpuCost float64
			if len(allocation.Resources.GPUs) > 0 && feePerGPU > 0 {
				gpuCost = float64(len(allocation.Resources.GPUs)) * feePerGPU * timeInUnit
			}
			// Compute the allocation total once and reuse it for both the
			// formatted field and the deployment aggregate (previously the
			// same sum was computed twice).
			allocationCost := cpuCost + ramCost + diskCost + gpuCost
			allocation.CPUCost = formatAmount(cpuCost)
			allocation.RAMCost = formatAmount(ramCost)
			allocation.DiskCost = formatAmount(diskCost)
			allocation.GPUCost = formatAmount(gpuCost)
			allocation.TotalCost = formatAmount(allocationCost)
			// Allocations is indexed by key; write the updated copy back.
			deployment.Allocations[key] = allocation
			deploymentTotalCost += allocationCost
		}
		// Generate unique UUID for this payment item.
		uniqueID := uuid.NewString()
		// Build enriched metadata describing the deployment's utilization.
		metadata := map[string]interface{}{
			"total_utilization_sec": deployment.TotalUtilizationSec,
			"allocation_count":      len(deployment.Allocations),
		}
		// Attach per-allocation details (resources, timing, and costs).
		allocations := make([]map[string]interface{}, 0, len(deployment.Allocations))
		for _, alloc := range deployment.Allocations {
			allocData := map[string]interface{}{
				"allocation_id": alloc.AllocationID,
				"duration_sec":  alloc.Duration.Seconds(),
				"start_time":    alloc.StartTime.Format(time.RFC3339),
				"resources": map[string]interface{}{
					"cpu_cores": alloc.Resources.CPU.Cores,
					"ram_gb":    float64(alloc.Resources.RAM.Size) / float64(bytesInGB),
					"disk_gb":   float64(alloc.Resources.Disk.Size) / float64(bytesInGB),
					"gpu_count": len(alloc.Resources.GPUs),
				},
			}
			// A zero EndTime means the allocation was still running.
			if !alloc.EndTime.IsZero() {
				allocData["end_time"] = alloc.EndTime.Format(time.RFC3339)
			}
			// Costs are added only when they were populated above.
			if alloc.CPUCost != "" {
				allocData["cpu_cost"] = alloc.CPUCost
			}
			if alloc.RAMCost != "" {
				allocData["ram_cost"] = alloc.RAMCost
			}
			if alloc.DiskCost != "" {
				allocData["disk_cost"] = alloc.DiskCost
			}
			if alloc.GPUCost != "" {
				allocData["gpu_cost"] = alloc.GPUCost
			}
			if alloc.TotalCost != "" {
				allocData["total_cost"] = alloc.TotalCost
			}
			allocations = append(allocations, allocData)
		}
		metadata["allocations"] = allocations
		items = append(items, &contracts.PaymentItem{
			UniqueID:     uniqueID, // Generated UUID
			DeploymentID: deployment.DeploymentID,
			Amount:       formatAmount(deploymentTotalCost),
			Usages:       1,
			Metadata:     metadata,
		})
	}
	return items, nil
}
// Validate implements PaymentModelProcessor.Validate.
//
// Required fields: fee_per_cpu_core_per_time_unit, fee_per_ram_gb_per_time_unit,
// fee_per_disk_gb_per_time_unit, resource_time_unit. Optional: the GPU fee,
// payment_period, and payment_period_count. All fee strings must parse as
// float64.
func (p *PayPerResourceUtilizationProcessor) Validate(paymentDetails contracts.PaymentDetails) error {
	if paymentDetails.FeePerCPUCorePerTimeUnit == "" {
		return fmt.Errorf("fee_per_cpu_core_per_time_unit is required")
	}
	if paymentDetails.FeePerRAMGBPerTimeUnit == "" {
		return fmt.Errorf("fee_per_ram_gb_per_time_unit is required")
	}
	if paymentDetails.FeePerDiskGBPerTimeUnit == "" {
		return fmt.Errorf("fee_per_disk_gb_per_time_unit is required")
	}
	if paymentDetails.ResourceTimeUnit == "" {
		return fmt.Errorf("resource_time_unit is required")
	}
	if _, err := convert.StringToFloat64(paymentDetails.FeePerCPUCorePerTimeUnit, "fee_per_cpu_core_per_time_unit"); err != nil {
		return err
	}
	if _, err := convert.StringToFloat64(paymentDetails.FeePerRAMGBPerTimeUnit, "fee_per_ram_gb_per_time_unit"); err != nil {
		return err
	}
	if _, err := convert.StringToFloat64(paymentDetails.FeePerDiskGBPerTimeUnit, "fee_per_disk_gb_per_time_unit"); err != nil {
		return err
	}
	if paymentDetails.FeePerGPUPerTimeUnit != "" {
		if _, err := convert.StringToFloat64(paymentDetails.FeePerGPUPerTimeUnit, "fee_per_gpu_per_time_unit"); err != nil {
			return err
		}
	}
	// Validate payment_period if provided (optional for this model)
	if paymentDetails.PaymentPeriod != "" {
		if _, err := convert.ParsePaymentPeriod(paymentDetails.PaymentPeriod); err != nil {
			return err
		}
	}
	// Validate payment_period_count if provided (optional for this model).
	// Zero is accepted (meaning "not set"); only negative values are rejected,
	// so the message says "non-negative" (the old "positive" wording
	// contradicted the check, which lets 0 through).
	if paymentDetails.PaymentPeriodCount < 0 {
		return fmt.Errorf("payment_period_count must be a non-negative integer, got: %d", paymentDetails.PaymentPeriodCount)
	}
	return nil
}
// SupportsManualBilling implements PaymentModelProcessor.SupportsManualBilling.
// Resource-utilization contracts can be billed on demand via CollectUsage.
func (p *PayPerResourceUtilizationProcessor) SupportsManualBilling() bool {
	return true
}
// SupportsAutomaticBilling implements PaymentModelProcessor.SupportsAutomaticBilling.
// Periodic automatic invoicing is handled by CheckAndGenerateInvoice.
func (p *PayPerResourceUtilizationProcessor) SupportsAutomaticBilling() bool {
	return true
}
// CheckAndGenerateInvoice implements PaymentModelProcessor.CheckAndGenerateInvoice.
// It returns ErrPeriodNotElapsed until at least one full billing cycle has
// passed since lastInvoiceAt, then collects resource-utilization usage for
// the elapsed window.
func (p *PayPerResourceUtilizationProcessor) CheckAndGenerateInvoice(
	contract *contracts.Contract,
	lastInvoiceAt time.Time,
	now time.Time,
) (*contracts.UsageData, error) {
	pd := contract.PaymentDetails
	periodDuration, err := convert.ParsePaymentPeriod(pd.PaymentPeriod)
	if err != nil {
		return nil, fmt.Errorf("invalid payment_period: %w", err)
	}
	elapsedCycles, _, _ := convert.CalculateElapsedPeriods(
		lastInvoiceAt,
		now,
		periodDuration,
		pd.PaymentPeriodCount,
	)
	if elapsedCycles < 1 {
		return nil, ErrPeriodNotElapsed
	}
	// Head contracts are queried by their own DID through the head-contract
	// index; every other chain role falls back to a plain contract-DID query.
	var headContractDID string
	if contract.Metadata != nil {
		if role, ok := contract.Metadata[contracts.ContractChainRoleMetadataKey]; ok && role == contracts.ContractChainRoleHead {
			headContractDID = contract.ContractDID
		}
	}
	return p.CollectUsage(contract.ContractDID, lastInvoiceAt, now, "", headContractDID)
}
// buildAllocationWindowsWithResources builds allocation windows (start/end
// pairs) from events, attaching the resources billed for each allocation.
// Resources come from the start event when present, otherwise from the
// allocation's create event (allocationResources).
func (p *PayPerResourceUtilizationProcessor) buildAllocationWindowsWithResources(
	startEvents, endEvents []*usage.Usage,
	allocationResources map[string]types.Resources,
) map[string]*allocationWindowWithResources {
	windows := make(map[string]*allocationWindowWithResources)
	// Process start events; events with zero timestamps or undecodable
	// payloads are skipped (best effort).
	for _, evt := range startEvents {
		if evt.Timestamp.IsZero() {
			continue
		}
		var data events.StartAllocation
		if err := json.Unmarshal(evt.Data, &data); err != nil {
			continue
		}
		// Get resources from StartAllocationEvent (primary source)
		resources := data.Resources
		// Fallback: a zero CPU+RAM spec is treated as "missing" and replaced
		// with the resources recorded at allocation creation, if any.
		if resources.CPU.Cores == 0 && resources.RAM.Size == 0 {
			if createRes, ok := allocationResources[data.AllocationID]; ok {
				resources = createRes
			}
		}
		windows[data.AllocationID] = &allocationWindowWithResources{
			allocationWindow: allocationWindow{
				allocationID: data.AllocationID,
				deploymentID: data.DeploymentID,
				startTime:    evt.Timestamp,
			},
			resources: resources,
		}
	}
	// Close windows using end events; an end event with no matching start
	// window is dropped.
	for _, evt := range endEvents {
		allocationID, ok := processAllocationEndEvent(evt)
		if !ok {
			continue
		}
		window := windows[allocationID]
		if window != nil {
			window.endTime = evt.Timestamp
			window.isComplete = true
		}
	}
	return windows
}
// calculateDeploymentResourceUtilization groups allocation windows by
// deployment, computing each allocation's effective duration within the query
// window via usage.CalculateEffectiveTime and accumulating per-deployment
// totals.
func (p *PayPerResourceUtilizationProcessor) calculateDeploymentResourceUtilization(
	windows map[string]*allocationWindowWithResources,
	queryStart, queryEnd time.Time,
) []contracts.DeploymentResourceUtilization {
	byDeployment := make(map[string]*contracts.DeploymentResourceUtilization)
	for _, w := range windows {
		start, end, ok := usage.CalculateEffectiveTime(
			w.startTime, w.endTime, w.isComplete, queryStart, queryEnd,
		)
		if !ok {
			continue
		}
		dur := end.Sub(start)
		dep, exists := byDeployment[w.deploymentID]
		if !exists {
			dep = &contracts.DeploymentResourceUtilization{
				DeploymentID: w.deploymentID,
				Allocations:  make([]contracts.AllocationResourceUtilization, 0),
			}
			byDeployment[w.deploymentID] = dep
		}
		record := contracts.AllocationResourceUtilization{
			AllocationID: w.allocationID,
			Resources:    w.resources,
			Duration:     dur,
			StartTime:    w.startTime, // actual start time, not the clamped one
		}
		if w.isComplete {
			record.EndTime = w.endTime
		}
		dep.Allocations = append(dep.Allocations, record)
		dep.TotalUtilizationSec += dur.Seconds()
	}
	result := make([]contracts.DeploymentResourceUtilization, 0, len(byDeployment))
	for _, dep := range byDeployment {
		result = append(result, *dep)
	}
	return result
}
// allocationWindowWithResources pairs an allocation's start/end window with
// the resources attributed to it, for resource-based billing.
type allocationWindowWithResources struct {
	allocationWindow
	resources types.Resources // resources billed for this allocation
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package processors
import (
"encoding/json"
"fmt"
"time"
"github.com/google/uuid"
"gitlab.com/nunet/device-management-service/tokenomics/contracts"
"gitlab.com/nunet/device-management-service/tokenomics/events"
"gitlab.com/nunet/device-management-service/tokenomics/store/usage"
"gitlab.com/nunet/device-management-service/utils/convert"
)
// PayPerTimeUtilizationProcessor implements PaymentModelProcessor for pay_per_time_utilization model.
// This processor collects allocation time utilization and calculates payment based on time duration.
var _ contracts.PaymentModelProcessor = (*PayPerTimeUtilizationProcessor)(nil)

type PayPerTimeUtilizationProcessor struct {
	store *usage.Store // event store queried for allocation events
}
// NewPayPerTimeUtilizationProcessor constructs a processor backed by the
// given usage store. It panics when store is nil: the processor cannot
// operate without one, and a nil store is a programmer error.
func NewPayPerTimeUtilizationProcessor(store *usage.Store) *PayPerTimeUtilizationProcessor {
	if store == nil {
		panic("usage store cannot be nil")
	}
	return &PayPerTimeUtilizationProcessor{store: store}
}
// CollectUsage implements PaymentModelProcessor.CollectUsage.
// It queries allocation start/end events between lastProcessedAt and now,
// builds per-allocation windows, and aggregates time utilization per
// deployment. A non-empty headContractDID makes the store query by
// head-contract DID instead of the contract's own DID.
func (p *PayPerTimeUtilizationProcessor) CollectUsage(
	contractDID string,
	lastProcessedAt time.Time,
	now time.Time,
	_ string, // providerDID - unused in this processor
	headContractDID string, // query key for head contracts; empty otherwise
) (*contracts.UsageData, error) {
	startEvents, endEvents, err := p.store.QueryAllocationEvents(contractDID, lastProcessedAt, now, headContractDID)
	if err != nil {
		return nil, fmt.Errorf("failed to query allocation events: %w", err)
	}
	deployments := p.calculateDeploymentTimeUtilization(
		p.buildAllocationWindows(startEvents, endEvents),
		lastProcessedAt,
		now,
	)
	return &contracts.UsageData{
		ContractDID:  contractDID,
		PaymentModel: contracts.PayPerTimeUtilization,
		Data:         &contracts.TimeUtilizationUsage{Deployments: deployments},
	}, nil
}
// CalculatePayment implements PaymentModelProcessor.CalculatePayment.
// It charges each deployment its total utilization time, converted to the
// contract's time unit and multiplied by the per-unit fee, and attaches
// per-allocation timing details as metadata on the payment item.
func (p *PayPerTimeUtilizationProcessor) CalculatePayment(
	usageData *contracts.UsageData,
	contract *contracts.Contract,
) ([]*contracts.PaymentItem, error) {
	timeUtil, ok := usageData.Data.(*contracts.TimeUtilizationUsage)
	if !ok {
		return nil, fmt.Errorf("invalid usage data type")
	}
	pd := contract.PaymentDetails
	feePerUnit, err := convert.StringToFloat64(pd.FeePerTimeUnit, "fee_per_time_unit")
	if err != nil {
		return nil, err
	}
	items := make([]*contracts.PaymentItem, 0)
	for _, dep := range timeUtil.Deployments {
		timeInUnit, err := convert.SecondsToUnit(dep.TotalUtilizationSec, pd.TimeUnit)
		if err != nil {
			return nil, fmt.Errorf("failed to convert time: %w", err)
		}
		// Collect per-allocation timing and pricing context for the metadata.
		allocDetails := make([]map[string]interface{}, 0, len(dep.Allocations))
		for _, a := range dep.Allocations {
			detail := map[string]interface{}{
				"allocation_id":        a.AllocationID,
				"duration_sec":         a.Duration.Seconds(),
				"start_time":           a.StartTime.Format(time.RFC3339),
				"payment_model":        contracts.PayPerTimeUtilization,
				"payment_period":       pd.PaymentPeriod,
				"payment_period_count": pd.PaymentPeriodCount,
				"fee_per_time_unit":    pd.FeePerTimeUnit,
				"time_unit":            pd.TimeUnit,
			}
			// A zero EndTime means the allocation was still running.
			if !a.EndTime.IsZero() {
				detail["end_time"] = a.EndTime.Format(time.RFC3339)
			}
			allocDetails = append(allocDetails, detail)
		}
		items = append(items, &contracts.PaymentItem{
			UniqueID:     uuid.NewString(), // fresh UUID per payment item
			DeploymentID: dep.DeploymentID,
			Amount:       formatAmount(feePerUnit * timeInUnit),
			Usages:       1,
			Metadata: map[string]interface{}{
				"total_utilization_sec": dep.TotalUtilizationSec,
				"allocation_count":      len(dep.Allocations),
				"allocations":           allocDetails,
			},
		})
	}
	return items, nil
}
// Validate implements PaymentModelProcessor.Validate.
//
// Required fields: fee_per_time_unit and time_unit; payment_period and
// payment_period_count are optional for this model. The fee string must
// parse as float64.
func (p *PayPerTimeUtilizationProcessor) Validate(paymentDetails contracts.PaymentDetails) error {
	if paymentDetails.FeePerTimeUnit == "" {
		return fmt.Errorf("fee_per_time_unit is required")
	}
	if paymentDetails.TimeUnit == "" {
		return fmt.Errorf("time_unit is required")
	}
	if _, err := convert.StringToFloat64(paymentDetails.FeePerTimeUnit, "fee_per_time_unit"); err != nil {
		return err
	}
	// Validate payment_period if provided (optional for this model)
	if paymentDetails.PaymentPeriod != "" {
		if _, err := convert.ParsePaymentPeriod(paymentDetails.PaymentPeriod); err != nil {
			return err
		}
	}
	// Validate payment_period_count if provided (optional for this model).
	// Zero is accepted (meaning "not set"); only negative values are rejected,
	// so the message says "non-negative" (the old "positive" wording
	// contradicted the check, which lets 0 through).
	if paymentDetails.PaymentPeriodCount < 0 {
		return fmt.Errorf("payment_period_count must be a non-negative integer, got: %d", paymentDetails.PaymentPeriodCount)
	}
	return nil
}
// SupportsManualBilling implements PaymentModelProcessor.SupportsManualBilling.
// Time-utilization contracts can be billed on demand via CollectUsage.
func (p *PayPerTimeUtilizationProcessor) SupportsManualBilling() bool {
	return true
}
// SupportsAutomaticBilling implements PaymentModelProcessor.SupportsAutomaticBilling.
// Periodic automatic invoicing is handled by CheckAndGenerateInvoice.
func (p *PayPerTimeUtilizationProcessor) SupportsAutomaticBilling() bool {
	return true
}
// CheckAndGenerateInvoice implements PaymentModelProcessor.CheckAndGenerateInvoice.
// It returns ErrPeriodNotElapsed until at least one full billing cycle has
// passed since lastInvoiceAt, then collects time-utilization usage for the
// elapsed window.
func (p *PayPerTimeUtilizationProcessor) CheckAndGenerateInvoice(
	contract *contracts.Contract,
	lastInvoiceAt time.Time,
	now time.Time,
) (*contracts.UsageData, error) {
	pd := contract.PaymentDetails
	periodDuration, err := convert.ParsePaymentPeriod(pd.PaymentPeriod)
	if err != nil {
		return nil, fmt.Errorf("invalid payment_period: %w", err)
	}
	elapsedCycles, _, _ := convert.CalculateElapsedPeriods(
		lastInvoiceAt,
		now,
		periodDuration,
		pd.PaymentPeriodCount,
	)
	if elapsedCycles < 1 {
		return nil, ErrPeriodNotElapsed
	}
	// Head contracts are queried by their own DID through the head-contract
	// index; every other chain role falls back to a plain contract-DID query.
	var headContractDID string
	if contract.Metadata != nil {
		if role, ok := contract.Metadata[contracts.ContractChainRoleMetadataKey]; ok && role == contracts.ContractChainRoleHead {
			headContractDID = contract.ContractDID
		}
	}
	return p.CollectUsage(contract.ContractDID, lastInvoiceAt, now, "", headContractDID)
}
// buildAllocationWindows pairs allocation start events with their end events,
// producing one window per allocation ID. Start events with a zero timestamp
// or an undecodable payload are skipped; end events without a matching start
// window are dropped.
func (p *PayPerTimeUtilizationProcessor) buildAllocationWindows(
	startEvents, endEvents []*usage.Usage,
) map[string]*allocationWindow {
	windows := make(map[string]*allocationWindow)
	for _, evt := range startEvents {
		if evt.Timestamp.IsZero() {
			continue
		}
		var start events.StartAllocation
		if json.Unmarshal(evt.Data, &start) != nil {
			continue
		}
		windows[start.AllocationID] = &allocationWindow{
			allocationID: start.AllocationID,
			deploymentID: start.DeploymentID,
			startTime:    evt.Timestamp,
		}
	}
	for _, evt := range endEvents {
		if allocationID, ok := processAllocationEndEvent(evt); ok {
			if w, found := windows[allocationID]; found {
				w.endTime = evt.Timestamp
				w.isComplete = true
			}
		}
	}
	return windows
}
// calculateDeploymentTimeUtilization groups allocation windows by deployment,
// computing each allocation's effective duration within the query window via
// usage.CalculateEffectiveTime and accumulating per-deployment totals.
func (p *PayPerTimeUtilizationProcessor) calculateDeploymentTimeUtilization(
	windows map[string]*allocationWindow,
	queryStart, queryEnd time.Time,
) []contracts.DeploymentTimeUtilization {
	byDeployment := make(map[string]*contracts.DeploymentTimeUtilization)
	for _, w := range windows {
		start, end, ok := usage.CalculateEffectiveTime(
			w.startTime, w.endTime, w.isComplete, queryStart, queryEnd,
		)
		if !ok {
			continue
		}
		dur := end.Sub(start)
		dep, exists := byDeployment[w.deploymentID]
		if !exists {
			dep = &contracts.DeploymentTimeUtilization{
				DeploymentID: w.deploymentID,
				Allocations:  make([]contracts.AllocationTimeUtilization, 0),
			}
			byDeployment[w.deploymentID] = dep
		}
		record := contracts.AllocationTimeUtilization{
			AllocationID: w.allocationID,
			Duration:     dur,
			StartTime:    w.startTime, // actual start time, not the clamped one
		}
		if w.isComplete {
			record.EndTime = w.endTime
		}
		dep.Allocations = append(dep.Allocations, record)
		dep.TotalUtilizationSec += dur.Seconds()
	}
	result := make([]contracts.DeploymentTimeUtilization, 0, len(byDeployment))
	for _, dep := range byDeployment {
		result = append(result, *dep)
	}
	return result
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package processors
import (
"encoding/json"
"fmt"
"time"
"github.com/google/uuid"
"gitlab.com/nunet/device-management-service/tokenomics/contracts"
"gitlab.com/nunet/device-management-service/tokenomics/events"
"gitlab.com/nunet/device-management-service/tokenomics/store/usage"
"gitlab.com/nunet/device-management-service/utils/convert"
)
// PeriodicProcessor implements PaymentModelProcessor for periodic model.
// It bills a fixed per-time-unit fee automatically each payment period based
// on deployment runtime; manual billing is not supported.
var _ contracts.PaymentModelProcessor = (*PeriodicProcessor)(nil)

type PeriodicProcessor struct {
	store *usage.Store // event store queried for deployment events
}
// NewPeriodicProcessor constructs a processor backed by the given usage
// store. It panics when store is nil: the processor cannot operate without
// one, and a nil store is a programmer error.
func NewPeriodicProcessor(store *usage.Store) *PeriodicProcessor {
	if store == nil {
		panic("usage store cannot be nil")
	}
	return &PeriodicProcessor{store: store}
}
// CollectUsage implements PaymentModelProcessor.CollectUsage.
// The periodic model bills only on elapsed payment periods (see
// CheckAndGenerateInvoice), so on-demand usage collection is rejected.
func (p *PeriodicProcessor) CollectUsage(
	_ string,
	_ time.Time,
	_ time.Time,
	_ string, // providerDID (not used for periodic)
	_ string, // headContractDID (not used for periodic)
) (*contracts.UsageData, error) {
	return nil, fmt.Errorf("periodic does not support manual billing")
}
// CalculatePayment implements PaymentModelProcessor.CalculatePayment.
//
// It converts each deployment's total utilization time to the contract's
// time unit, multiplies by the per-unit fee, and emits one payment item per
// deployment with the billing period and allocation details as metadata.
func (p *PeriodicProcessor) CalculatePayment(
	usageData *contracts.UsageData,
	contract *contracts.Contract,
) ([]*contracts.PaymentItem, error) {
	periodicUsage, ok := usageData.Data.(*contracts.PeriodicUsage)
	if !ok {
		return nil, fmt.Errorf("invalid usage data type")
	}
	pd := contract.PaymentDetails
	feePerUnit, err := convert.StringToFloat64(pd.FeePerTimeUnit, "fee_per_time_unit")
	if err != nil {
		return nil, err
	}
	items := make([]*contracts.PaymentItem, 0)
	for _, deployment := range periodicUsage.Deployments {
		// Convert deployment time to time unit
		timeInUnit, err := convert.SecondsToUnit(deployment.TotalUtilizationSec, pd.TimeUnit)
		if err != nil {
			return nil, fmt.Errorf("failed to convert time: %w", err)
		}
		amount := feePerUnit * timeInUnit
		// Generate unique UUID for this payment item
		uniqueID := uuid.NewString()
		// Build enriched metadata: billing period bounds, pricing inputs, and
		// the computed amount, for auditability of the invoice.
		metadata := map[string]interface{}{
			"total_utilization_sec": deployment.TotalUtilizationSec,
			"period_start":          periodicUsage.PeriodStart.Format(time.RFC3339),
			"period_end":            periodicUsage.PeriodEnd.Format(time.RFC3339),
			"periods_invoiced":      periodicUsage.PeriodsInvoiced,
			"payment_model":         contracts.Periodic,
			"payment_period":        pd.PaymentPeriod,
			"payment_period_count":  pd.PaymentPeriodCount,
			"fee_per_time_unit":     pd.FeePerTimeUnit,
			"time_unit":             pd.TimeUnit,
			"deployment_id":         deployment.DeploymentID,
			"amount":                formatAmount(amount),
		}
		// Add allocation details if available (periodic usage may track time
		// at deployment level only, leaving Allocations empty).
		if len(deployment.Allocations) > 0 {
			metadata["allocation_count"] = len(deployment.Allocations)
			allocations := make([]map[string]interface{}, 0, len(deployment.Allocations))
			for _, alloc := range deployment.Allocations {
				allocData := map[string]interface{}{
					"allocation_id": alloc.AllocationID,
					"duration_sec":  alloc.Duration.Seconds(),
					"start_time":    alloc.StartTime.Format(time.RFC3339),
				}
				// A zero EndTime means the allocation was still running.
				if !alloc.EndTime.IsZero() {
					allocData["end_time"] = alloc.EndTime.Format(time.RFC3339)
				}
				allocations = append(allocations, allocData)
			}
			metadata["allocations"] = allocations
		}
		items = append(items, &contracts.PaymentItem{
			UniqueID:     uniqueID, // Generated UUID
			DeploymentID: deployment.DeploymentID,
			Amount:       formatAmount(amount),
			Usages:       1,
			Metadata:     metadata,
		})
	}
	return items, nil
}
// Validate implements PaymentModelProcessor.Validate.
// The periodic model requires fee_per_time_unit, time_unit, payment_period,
// and a strictly positive payment_period_count; the fee must parse as
// float64 and the period must be a recognized payment-period string.
func (p *PeriodicProcessor) Validate(paymentDetails contracts.PaymentDetails) error {
	switch {
	case paymentDetails.FeePerTimeUnit == "":
		return fmt.Errorf("fee_per_time_unit is required")
	case paymentDetails.TimeUnit == "":
		return fmt.Errorf("time_unit is required")
	case paymentDetails.PaymentPeriod == "":
		return fmt.Errorf("payment_period is required")
	}
	if _, err := convert.StringToFloat64(paymentDetails.FeePerTimeUnit, "fee_per_time_unit"); err != nil {
		return err
	}
	if _, err := convert.ParsePaymentPeriod(paymentDetails.PaymentPeriod); err != nil {
		return err
	}
	if paymentDetails.PaymentPeriodCount <= 0 {
		return fmt.Errorf("payment_period_count must be a positive integer, got: %d", paymentDetails.PaymentPeriodCount)
	}
	return nil
}
// SupportsManualBilling implements PaymentModelProcessor.SupportsManualBilling.
// Periodic contracts are invoiced only by elapsed billing periods; see the
// rejecting CollectUsage implementation.
func (p *PeriodicProcessor) SupportsManualBilling() bool {
	return false
}
// SupportsAutomaticBilling implements PaymentModelProcessor.SupportsAutomaticBilling.
// Automatic invoicing per billing cycle is handled by CheckAndGenerateInvoice.
func (p *PeriodicProcessor) SupportsAutomaticBilling() bool {
	return true
}
// CheckAndGenerateInvoice implements PaymentModelProcessor.CheckAndGenerateInvoice.
//
// It bills a periodic contract for the billing cycles elapsed since
// lastInvoiceAt, charging for deployment runtime observed within the
// period. Returns ErrPeriodNotElapsed when no full cycle has passed and
// ErrNoDeployments when the period contains no billable deployment time.
func (p *PeriodicProcessor) CheckAndGenerateInvoice(
	contract *contracts.Contract,
	lastInvoiceAt time.Time,
	now time.Time,
) (*contracts.UsageData, error) {
	pd := contract.PaymentDetails
	// Parse payment period
	periodDuration, err := convert.ParsePaymentPeriod(pd.PaymentPeriod)
	if err != nil {
		return nil, fmt.Errorf("invalid payment_period: %w", err)
	}
	// Calculate elapsed periods
	billingCyclesElapsed, periodStart, periodEnd := convert.CalculateElapsedPeriods(
		lastInvoiceAt,
		now,
		periodDuration,
		pd.PaymentPeriodCount,
	)
	if billingCyclesElapsed < 1 {
		return nil, ErrPeriodNotElapsed
	}
	// Detect contract type from metadata to determine query strategy
	headContractDID := ""
	if contract.Metadata != nil {
		if role, ok := contract.Metadata[contracts.ContractChainRoleMetadataKey]; ok {
			if role == contracts.ContractChainRoleHead {
				// Head Contract: query by head_contract_did = contract's DID
				headContractDID = contract.ContractDID
			}
			// For Tail Contract or P2P, headContractDID remains empty (query by contract_did)
		}
	}
	// Query deployment start and stop events
	startEvents, stopEvents, err := p.store.QueryDeploymentEvents(contract.ContractDID, periodStart, periodEnd, headContractDID)
	if err != nil {
		return nil, fmt.Errorf("failed to query deployment events: %w", err)
	}
	// Build deployment windows (processor-specific logic)
	windows := p.buildDeploymentWindows(startEvents, stopEvents)
	// Calculate deployment time utilization; "now" bounds still-running deployments.
	deployments, totalTimeSec := p.calculateDeploymentTimeUtilization(windows, periodStart, periodEnd, now)
	// Edge Case: No deployments during period - skip invoice
	if len(deployments) == 0 {
		return nil, ErrNoDeployments
	}
	// If totalTimeSec is zero or negative, skip invoice
	if totalTimeSec <= 0 {
		return nil, ErrNoDeployments
	}
	// Parse fee per time unit
	feePerUnit, err := convert.StringToFloat64(pd.FeePerTimeUnit, "fee_per_time_unit")
	if err != nil {
		return nil, err
	}
	// Convert total time to the specified time unit
	timeInUnit, err := convert.SecondsToUnit(totalTimeSec, pd.TimeUnit)
	if err != nil {
		return nil, fmt.Errorf("failed to convert time: %w", err)
	}
	// Calculate total amount (sum across all deployments for this period)
	// Note: Each deployment will get its own invoice, but this provides the combined total
	totalAmount := feePerUnit * timeInUnit
	// periodsToInvoice = cycles × periods-per-cycle; the fallback covers an
	// unset PaymentPeriodCount (Validate normally guarantees it is > 0).
	periodsToInvoice := billingCyclesElapsed * pd.PaymentPeriodCount
	if pd.PaymentPeriodCount <= 0 {
		periodsToInvoice = billingCyclesElapsed
	}
	return &contracts.UsageData{
		ContractDID:  contract.ContractDID,
		PaymentModel: contracts.Periodic,
		Data: &contracts.PeriodicUsage{
			PeriodStart:     periodStart,
			PeriodEnd:       periodEnd,
			LastInvoiceAt:   lastInvoiceAt,
			Deployments:     deployments,
			TotalTimeSec:    totalTimeSec,
			Amount:          formatAmount(totalAmount),
			PeriodsInvoiced: periodsToInvoice,
		},
	}, nil
}
// GenerateProRatedInvoice generates a pro-rated invoice for a terminated
// contract based on actual deployment time within the partial period (from
// lastInvoiceAt to now). Unlike CheckAndGenerateInvoice, this method does
// not require a full billing period to have elapsed; PeriodsInvoiced is set
// to 0 to mark the invoice as pro-rated. Returns ErrNoDeployments when the
// partial period contains no billable deployment time.
func (p *PeriodicProcessor) GenerateProRatedInvoice(
	contract *contracts.Contract,
	lastInvoiceAt time.Time,
	now time.Time,
) (*contracts.UsageData, error) {
	pd := contract.PaymentDetails
	// For pro-rating, we invoice for the actual deployment time from lastInvoiceAt to now
	periodStart := lastInvoiceAt
	periodEnd := now
	// Detect contract type from metadata to determine query strategy
	headContractDID := ""
	if contract.Metadata != nil {
		if role, ok := contract.Metadata[contracts.ContractChainRoleMetadataKey]; ok {
			if role == contracts.ContractChainRoleHead {
				// Head Contract: query by head_contract_did = contract's DID
				headContractDID = contract.ContractDID
			}
			// For Tail Contract or P2P, headContractDID remains empty (query by contract_did)
		}
	}
	// Query deployment start and stop events in the partial period
	startEvents, stopEvents, err := p.store.QueryDeploymentEvents(contract.ContractDID, periodStart, periodEnd, headContractDID)
	if err != nil {
		return nil, fmt.Errorf("failed to query deployment events: %w", err)
	}
	// Build deployment windows (processor-specific logic)
	windows := p.buildDeploymentWindows(startEvents, stopEvents)
	// Calculate deployment time utilization based on actual runtime
	deployments, totalTimeSec := p.calculateDeploymentTimeUtilization(windows, periodStart, periodEnd, now)
	// Edge Case: No deployments during partial period - skip invoice
	if len(deployments) == 0 {
		return nil, ErrNoDeployments
	}
	// If totalTimeSec is zero or negative, skip invoice
	if totalTimeSec <= 0 {
		return nil, ErrNoDeployments
	}
	// Parse fee per time unit
	feePerUnit, err := convert.StringToFloat64(pd.FeePerTimeUnit, "fee_per_time_unit")
	if err != nil {
		return nil, err
	}
	// Convert total time to the specified time unit
	timeInUnit, err := convert.SecondsToUnit(totalTimeSec, pd.TimeUnit)
	if err != nil {
		return nil, fmt.Errorf("failed to convert time: %w", err)
	}
	// Calculate pro-rated amount based on actual deployment time.
	// This is the key difference from CheckAndGenerateInvoice: we charge for
	// actual runtime, not a full period.
	totalAmount := feePerUnit * timeInUnit
	return &contracts.UsageData{
		ContractDID:  contract.ContractDID,
		PaymentModel: contracts.Periodic,
		Data: &contracts.PeriodicUsage{
			PeriodStart:     periodStart,
			PeriodEnd:       periodEnd,
			LastInvoiceAt:   lastInvoiceAt,
			Deployments:     deployments,
			TotalTimeSec:    totalTimeSec,
			Amount:          formatAmount(totalAmount),
			PeriodsInvoiced: 0, // Pro-rated, not a full period
		},
	}, nil
}
// buildDeploymentWindows is processor-specific logic for building deployment
// windows from start/stop events. Events of the wrong type, with a zero
// timestamp, or with undecodable payloads are skipped (best effort).
//
// NOTE(review): only the first start event seen per deployment opens a
// window, and any stop event closes it - multiple start/stop cycles for the
// same deployment within one period collapse into a single window. Confirm
// this is the intended accounting.
func (p *PeriodicProcessor) buildDeploymentWindows(
	startEvents, stopEvents []*usage.Usage,
) map[string]*deploymentWindow {
	windows := make(map[string]*deploymentWindow)
	// Process start events
	for _, evt := range startEvents {
		if evt.EventType != events.DeploymentStartEvent {
			continue
		}
		eventTime := evt.Timestamp
		if eventTime.IsZero() {
			continue
		}
		var data events.DeploymentStart
		if err := json.Unmarshal(evt.Data, &data); err != nil {
			continue
		}
		// First start event wins; later start events for the same deployment are ignored.
		if windows[data.DeploymentID] == nil {
			windows[data.DeploymentID] = &deploymentWindow{
				deploymentID: data.DeploymentID,
				startTime:    eventTime,
				isComplete:   false,
			}
		}
	}
	// Process stop events
	for _, evt := range stopEvents {
		if evt.EventType != events.DeploymentStopEvent {
			continue
		}
		eventTime := evt.Timestamp
		if eventTime.IsZero() {
			continue
		}
		var data events.DeploymentStop
		if err := json.Unmarshal(evt.Data, &data); err != nil {
			continue
		}
		// A stop event without a matching start window is dropped.
		window := windows[data.DeploymentID]
		if window != nil {
			window.endTime = eventTime
			window.isComplete = true
		}
	}
	return windows
}
// calculateDeploymentTimeUtilization converts deployment windows into
// per-deployment utilization records, bounding each window by the query
// range via usage.CalculateEffectiveTime, and returns the records together
// with the summed utilization in seconds.
func (p *PeriodicProcessor) calculateDeploymentTimeUtilization(
	windows map[string]*deploymentWindow,
	queryStart, queryEnd, now time.Time,
) ([]contracts.DeploymentTimeUtilization, float64) {
	result := make([]contracts.DeploymentTimeUtilization, 0)
	var total float64
	for _, w := range windows {
		// Open-ended (still running) windows are measured up to "now"
		// rather than the period end.
		upperBound := queryEnd
		if !w.isComplete {
			upperBound = now
		}
		start, end, ok := usage.CalculateEffectiveTime(
			w.startTime,
			w.endTime,
			w.isComplete,
			queryStart,
			upperBound,
		)
		if !ok {
			continue
		}
		// Drop windows with no positive runtime in the effective range.
		seconds := end.Sub(start).Seconds()
		if seconds <= 0 {
			continue
		}
		result = append(result, contracts.DeploymentTimeUtilization{
			DeploymentID:        w.deploymentID,
			Allocations:         []contracts.AllocationTimeUtilization{}, // Empty - tracking at deployment level
			TotalUtilizationSec: seconds,
		})
		total += seconds
	}
	return result, total
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package eventhandler
import (
"context"
"log"
"math"
"sync"
"time"
)
// defaultMaxRetries is substituted when an Event is pushed with MaxRetries == 0.
const defaultMaxRetries = 10

// Event represents the data needed to be sent to contract actor.
type Event struct {
	ContractHostDID string      // DID of the host running the contract actor
	ContractDID     string      // DID of the target contract
	Payload         interface{} // opaque payload handed to the HandlerFunc
	MaxRetries      int         // max retry attempts; 0 selects defaultMaxRetries
	// attempts counts how many times the handler has already failed for this
	// event; managed internally by the worker loop (private).
	attempts int
}

// HandlerFunc defines the function signature for processing events.
type HandlerFunc func(event Event) error

// EventHandler manages event processing with concurrency and retries.
type EventHandler struct {
	events    chan Event      // buffered queue of pending events
	wg        sync.WaitGroup  // tracks running worker goroutines
	workers   int             // number of concurrent workers
	handler   HandlerFunc     // user-supplied event processor
	baseDelay time.Duration   // initial retry backoff delay
	maxDelay  time.Duration   // cap for the exponential backoff
	ctx       context.Context // cancels workers and pending retries
}
// New creates a new EventHandler and immediately starts its workers.
//   - ctx: context to cancel processing
//   - workers: number of concurrent workers
//   - queueSize: buffer size for the event channel
//   - baseDelay: initial retry delay
//   - maxDelay: maximum retry delay (cap for backoff)
//   - handler: the function to process events
func New(ctx context.Context, workers, queueSize int, baseDelay, maxDelay time.Duration, handler HandlerFunc) *EventHandler {
	handlerLoop := &EventHandler{
		ctx:       ctx,
		handler:   handler,
		workers:   workers,
		baseDelay: baseDelay,
		maxDelay:  maxDelay,
		events:    make(chan Event, queueSize),
	}
	handlerLoop.start()
	return handlerLoop
}
// start launches the configured number of worker goroutines.
func (eh *EventHandler) start() {
	for id := 0; id < eh.workers; id++ {
		eh.wg.Add(1)
		go eh.worker(id)
	}
}
// worker is the loop run by each worker goroutine. It pulls events off the
// queue and invokes the handler; on failure it schedules a retry with
// exponential backoff (up to MaxRetries attempts) by re-pushing the event
// from a separate timer goroutine, so the worker stays free to process other
// events. It exits when the context is cancelled or the channel is closed.
func (eh *EventHandler) worker(id int) {
	defer eh.wg.Done()
	for {
		select {
		case <-eh.ctx.Done():
			log.Printf("[Worker %d] Stopping (context done)", id)
			return
		case event, ok := <-eh.events:
			if !ok {
				log.Printf("[Worker %d] Stopping (events channel closed)", id)
				return
			}
			err := eh.handler(event)
			if err != nil {
				log.Printf("[Worker %d] Error: %v (attempt %d/%d)", id, err, event.attempts+1, event.MaxRetries)
				// Retry with backoff; the timer goroutine is cancelled with ctx.
				// NOTE(review): once attempts reaches MaxRetries the event is
				// dropped after logging only — confirm callers accept
				// best-effort delivery.
				if event.attempts < event.MaxRetries {
					event.attempts++
					delay := eh.backoff(event.attempts)
					go func(e Event, d time.Duration) {
						select {
						case <-time.After(d):
							eh.Push(e)
						case <-eh.ctx.Done():
							return
						}
					}(event, delay)
				}
			}
		}
	}
}
// Push enqueues an event for processing. A zero MaxRetries is replaced by
// defaultMaxRetries. If the handler's context is cancelled before the event
// can be queued, the event is dropped.
func (eh *EventHandler) Push(event Event) {
	if event.MaxRetries == 0 {
		event.MaxRetries = defaultMaxRetries
	}
	select {
	case <-eh.ctx.Done():
		// shutting down: drop the event
	case eh.events <- event:
		// queued successfully
	}
}
// backoff returns the exponential backoff delay for the given 1-based
// attempt number, capped at maxDelay.
func (eh *EventHandler) backoff(attempt int) time.Duration {
	raw := math.Pow(2, float64(attempt-1)) * float64(eh.baseDelay)
	if raw > float64(eh.maxDelay) {
		return eh.maxDelay
	}
	return time.Duration(raw)
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package pricing
import (
"context"
"fmt"
"time"
"gitlab.com/nunet/device-management-service/tokenomics/contracts"
)
// PriceConverter handles currency conversion for payment items, delegating
// exchange-rate lookups to a pluggable PriceOracle.
type PriceConverter struct {
	oracle PriceOracle // rate source used by ConvertPaymentItem
}
// NewPriceConverter creates a new price converter backed by the given oracle.
func NewPriceConverter(oracle PriceOracle) *PriceConverter {
	converter := &PriceConverter{oracle: oracle}
	return converter
}
// GetOracle returns the underlying price oracle.
func (c *PriceConverter) GetOracle() PriceOracle { return c.oracle }
// ConvertPaymentItem converts a payment item's amount from the contract's
// pricing currency to its payment currency, recording conversion metadata
// on the item.
//
// This feature is OPTIONAL: when pricing_currency is unset (or already
// "NTX"), the item is left untouched, preserving the behavior of existing
// contracts. Conversions other than stable-asset -> NTX are also no-ops.
//
// Returns an error only when the oracle fails to convert the amount or to
// report the exchange rate.
func (c *PriceConverter) ConvertPaymentItem(
	ctx context.Context,
	item *contracts.PaymentItem,
	contract *contracts.Contract,
) error {
	// Backward compatibility: no pricing currency means no conversion.
	pricingCurrency := contract.PaymentDetails.PricingCurrency
	if pricingCurrency == "" || pricingCurrency == "NTX" { //nolint:goconst
		return nil
	}
	// Determine payment currency from the first payment address; default NTX.
	paymentCurrency := "NTX"
	if len(contract.PaymentDetails.Addresses) > 0 {
		paymentCurrency = contract.PaymentDetails.Addresses[0].Currency
	}
	// Identical currencies need no conversion.
	if pricingCurrency == paymentCurrency {
		return nil
	}
	// Only stable-asset -> NTX conversions are supported.
	if paymentCurrency != "NTX" {
		return nil
	}
	originalAmount := item.Amount
	convertedAmount, err := c.oracle.ConvertAmount(ctx, originalAmount, pricingCurrency, paymentCurrency)
	if err != nil {
		return fmt.Errorf("failed to convert amount: %w", err)
	}
	// Fetch the rate once more for metadata (may be served from the oracle's cache).
	rate, err := c.oracle.GetPrice(ctx, pricingCurrency, paymentCurrency)
	if err != nil {
		return fmt.Errorf("failed to get exchange rate: %w", err)
	}
	// Use a single timestamp so ConversionTimestamp and the metadata's
	// conversion_time agree (previously two separate time.Now() calls could
	// produce different values).
	now := time.Now()
	item.OriginalAmount = originalAmount
	item.PricingCurrency = pricingCurrency
	item.Amount = convertedAmount
	item.ExchangeRate = formatAmount(rate)
	item.ConversionTimestamp = now
	if item.Metadata == nil {
		item.Metadata = make(map[string]interface{})
	}
	item.Metadata["price_conversion"] = map[string]interface{}{
		"original_amount":  originalAmount,
		"pricing_currency": pricingCurrency,
		"converted_amount": convertedAmount,
		"payment_currency": paymentCurrency,
		"exchange_rate":    formatAmount(rate),
		"conversion_time":  now.Format(time.RFC3339),
	}
	return nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package pricing
import (
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"net/url"
	"strings"
	"sync"
	"time"

	"gitlab.com/nunet/device-management-service/utils/convert"
)
const (
	USDT = "USDT"
	NTX  = "NTX"
)

// allowedConversions whitelists supported conversion pairs as
// fromCurrency -> toCurrency.
var allowedConversions = map[string]string{
	USDT: NTX, // USDT to NTX is the only supported conversion
}

// PriceOracle defines the interface for fetching cryptocurrency prices
type PriceOracle interface {
	// GetPrice fetches the current exchange rate for a trading pair
	// Returns price as a float64 (e.g., 0.05 means 1 NTX = 0.05 USDT)
	GetPrice(ctx context.Context, fromCurrency, toCurrency string) (float64, error)
	// ConvertAmount converts an amount from one currency to another
	ConvertAmount(ctx context.Context, amount string, fromCurrency, toCurrency string) (string, error)
}
// CoinMarketCapOracle implements PriceOracle using the CoinMarketCap API,
// with an in-memory TTL cache and stale-value fallback on API failure.
type CoinMarketCapOracle struct {
	apiKey       string       // CoinMarketCap API key (sent as X-CMC_PRO_API_KEY)
	baseURL      string       // API base URL (e.g., "https://pro-api.coinmarketcap.com/v2")
	endpointPath string       // API endpoint path (e.g., "/tools/price-conversion")
	httpClient   *http.Client // HTTP client with a request timeout configured
	cache        *PriceCache  // cached exchange rates
}

// PriceCache stores cached exchange rates, keyed by "FROM/TO" pair.
type PriceCache struct {
	rates map[string]CachedRate
	ttl   time.Duration // entries older than ttl are treated as expired by Get
	mu    sync.RWMutex  // guards rates
}

// CachedRate is a single cached exchange rate with its fetch time.
type CachedRate struct {
	Rate      float64
	Timestamp time.Time
}
// NewCoinMarketCapOracle creates a new CoinMarketCap oracle instance.
// baseURL is the API endpoint base URL; if empty it defaults to
// "https://pro-api.coinmarketcap.com/v2" (the v2 API, whose array-shaped
// response format fetchPriceFromAPI expects).
// endpointPath is the API endpoint path (e.g., "/tools/price-conversion");
// if empty it defaults to "/tools/price-conversion".
// cacheTTL controls how long fetched rates are considered fresh by the cache.
func NewCoinMarketCapOracle(apiKey string, baseURL string, endpointPath string, cacheTTL time.Duration) *CoinMarketCapOracle {
	// Use default base URL if not specified
	if baseURL == "" {
		baseURL = "https://pro-api.coinmarketcap.com/v2"
	}
	// Use default endpoint path if not specified
	if endpointPath == "" {
		endpointPath = "/tools/price-conversion"
	}
	return &CoinMarketCapOracle{
		apiKey:       apiKey,
		baseURL:      baseURL,
		endpointPath: endpointPath,
		httpClient: &http.Client{
			Timeout: 5 * time.Second,
		},
		cache: &PriceCache{
			rates: make(map[string]CachedRate),
			ttl:   cacheTTL,
		},
	}
}
// GetPrice implements PriceOracle.GetPrice. It serves from the cache when
// fresh, otherwise queries the API; on API failure it falls back to a stale
// cached rate if one exists.
func (o *CoinMarketCapOracle) GetPrice(ctx context.Context, fromCurrency, toCurrency string) (float64, error) {
	from := normalizeCurrency(fromCurrency)
	to := normalizeCurrency(toCurrency)
	key := fmt.Sprintf("%s/%s", from, to)

	// Fresh cache hit wins.
	if cached, ok := o.cache.Get(key); ok {
		return cached, nil
	}

	fetched, fetchErr := o.fetchPriceFromAPI(ctx, from, to)
	if fetchErr != nil {
		// Degrade gracefully: serve a stale rate if we ever had one.
		if stale, ok := o.cache.GetStale(key); ok {
			return stale, nil
		}
		return 0, fmt.Errorf("failed to fetch price: %w", fetchErr)
	}

	o.cache.Set(key, fetched)
	return fetched, nil
}
// ConvertAmount implements PriceOracle.ConvertAmount. Only pairs listed in
// allowedConversions are supported; the amount is parsed, multiplied by the
// current exchange rate, and formatted with 8 decimal places.
func (o *CoinMarketCapOracle) ConvertAmount(
	ctx context.Context,
	amount string,
	fromCurrency, toCurrency string,
) (string, error) {
	target, supported := allowedConversions[fromCurrency]
	if !supported || target != toCurrency {
		return "", fmt.Errorf("unsupported conversion: %s to %s", fromCurrency, toCurrency)
	}
	value, err := convert.StringToFloat64(amount, "amount")
	if err != nil {
		return "", fmt.Errorf("invalid amount: %w", err)
	}
	rate, err := o.GetPrice(ctx, fromCurrency, toCurrency)
	if err != nil {
		return "", err
	}
	return formatAmount(value * rate), nil
}
// fetchPriceFromAPI fetches the fromCurrency/toCurrency rate from the
// CoinMarketCap v2 price-conversion endpoint, using the configured baseURL
// and endpointPath. It returns the price of 1 unit of fromCurrency
// expressed in toCurrency.
func (o *CoinMarketCapOracle) fetchPriceFromAPI(ctx context.Context, fromCurrency, toCurrency string) (float64, error) {
	// Escape the symbols so unexpected characters cannot mangle the query string.
	apiURL := fmt.Sprintf("%s%s?amount=1&symbol=%s&convert=%s",
		o.baseURL, o.endpointPath, url.QueryEscape(fromCurrency), url.QueryEscape(toCurrency))
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, apiURL, nil)
	if err != nil {
		return 0, err
	}
	req.Header.Set("X-CMC_PRO_API_KEY", o.apiKey)
	req.Header.Set("Accept", "application/json")
	resp, err := o.httpClient.Do(req)
	if err != nil {
		return 0, err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return 0, fmt.Errorf("API returned status %d", resp.StatusCode)
	}
	// Parse v2 API response - data is an array in v2
	var apiResponse struct {
		Data []struct {
			Amount float64 `json:"amount"`
			Symbol string  `json:"symbol"`
			Quote  map[string]struct {
				Price       float64 `json:"price"`
				LastUpdated string  `json:"last_updated,omitempty"`
			} `json:"quote"`
		} `json:"data"`
		Status struct {
			ErrorCode    int    `json:"error_code"`
			ErrorMessage string `json:"error_message,omitempty"`
		} `json:"status"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&apiResponse); err != nil {
		return 0, fmt.Errorf("failed to decode API response: %w", err)
	}
	// The status object reports application-level errors even on HTTP 200.
	if apiResponse.Status.ErrorCode != 0 {
		return 0, fmt.Errorf("API error: %s (code: %d)", apiResponse.Status.ErrorMessage, apiResponse.Status.ErrorCode)
	}
	// In v2, data is an array, so we take the first element
	if len(apiResponse.Data) == 0 {
		return 0, fmt.Errorf("no data in API response")
	}
	quote, ok := apiResponse.Data[0].Quote[toCurrency]
	if !ok {
		return 0, fmt.Errorf("currency %s not found in response", toCurrency)
	}
	return quote.Price, nil
}
// normalizeCurrency normalizes a currency symbol for API calls by
// upper-casing it. The previous per-symbol map contained only identity
// mappings for already upper-cased keys, so it was dead code; this reduction
// is behaviorally identical.
func normalizeCurrency(currency string) string {
	return strings.ToUpper(currency)
}
// formatAmount renders a float64 amount as a fixed-point string with
// exactly 8 decimal places.
func formatAmount(amount float64) string {
	const pattern = "%.8f"
	return fmt.Sprintf(pattern, amount)
}
// Get returns the cached rate for key, but only while the entry is no older
// than the cache's ttl.
func (c *PriceCache) Get(key string) (float64, bool) {
	c.mu.RLock()
	defer c.mu.RUnlock()
	entry, present := c.rates[key]
	switch {
	case !present:
		return 0, false
	case time.Since(entry.Timestamp) > c.ttl:
		return 0, false // entry expired
	default:
		return entry.Rate, true
	}
}
// GetStale returns the cached rate for key regardless of age; used as a
// fallback when a fresh API fetch fails.
func (c *PriceCache) GetStale(key string) (float64, bool) {
	c.mu.RLock()
	defer c.mu.RUnlock()
	if entry, present := c.rates[key]; present {
		return entry.Rate, true
	}
	return 0, false
}
// Set stores rate under key, stamping it with the current time.
func (c *PriceCache) Set(key string, rate float64) {
	entry := CachedRate{Rate: rate, Timestamp: time.Now()}
	c.mu.Lock()
	c.rates[key] = entry
	c.mu.Unlock()
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package store
import (
"encoding/json"
"errors"
"fmt"
"time"
"github.com/ostafen/clover/v2"
"github.com/ostafen/clover/v2/document"
"github.com/ostafen/clover/v2/query"
"gitlab.com/nunet/device-management-service/lib/did"
"gitlab.com/nunet/device-management-service/tokenomics/contracts"
"gitlab.com/nunet/device-management-service/types"
)
const (
	contractsCollection     = "contracts"      // CloverDB collection of contract documents
	contractsKeysCollection = "contracts_keys" // CloverDB collection of contract key records
)

// ContractKey pairs a contract DID with its key material.
type ContractKey struct {
	ContractDID string `json:"contract_did"`
	Key         []byte `json:"key"`
}

// Store is a CloverDB-backed persistence layer for contracts and their keys.
type Store struct {
	db *clover.DB
}
// New creates a contract store backed by the given CloverDB handle.
func New(db *clover.DB) (*Store, error) {
	if db == nil {
		return nil, errors.New("db is nil")
	}
	store := &Store{db: db}
	return store, nil
}
// Upsert inserts a contract document, or updates it in place when one with
// the same contract_did already exists. The full contract is stored as a
// JSON blob in the contract_data field alongside queryable metadata fields.
func (s *Store) Upsert(contract *contracts.Contract) error {
	if contract == nil {
		return errors.New("contract is nil")
	}
	bts, err := json.Marshal(contract)
	if err != nil {
		return fmt.Errorf("failed to marshal contract: %w", err)
	}
	q := query.NewQuery(contractsCollection).Where(query.Field("contract_did").Eq(contract.ContractDID))
	existingDoc, err := s.db.FindFirst(q)
	if err != nil {
		return fmt.Errorf("failed to check existing contract: %w", err)
	}
	if existingDoc != nil {
		// Update the existing document
		update := document.NewDocument()
		update.Set("contract_did", contract.ContractDID)
		update.Set("updated_at", time.Now().UnixNano())
		update.Set("contract_data", bts)
		return s.db.Update(q, update.AsMap())
	}
	// Insert a new document
	doc := document.NewDocumentOf(contract)
	doc.Set("contract_did", contract.ContractDID)
	doc.Set("created_at", time.Now().UnixNano())
	doc.Set("contract_data", bts)
	return s.db.Insert(contractsCollection, doc)
}
// GetContract retrieves a contract by its ContractDID. It returns an error
// when the DID is empty, the lookup fails, or no matching document exists.
func (s *Store) GetContract(contractDID string) (*contracts.Contract, error) {
	if contractDID == "" {
		return nil, errors.New("contractDID is empty")
	}
	q := query.NewQuery(contractsCollection).Where(query.Field("contract_did").Eq(contractDID))
	doc, err := s.db.FindFirst(q)
	if err != nil {
		return nil, fmt.Errorf("failed to find contract by ID: %w", err)
	}
	// A nil doc with a nil error was previously wrapped with %w, producing a
	// misleading "%!w(<nil>)" message; report the not-found case explicitly.
	if doc == nil {
		return nil, fmt.Errorf("contract not found: %s", contractDID)
	}
	contractData, ok := doc.Get("contract_data").([]byte)
	if !ok {
		return nil, errors.New("no contract data available")
	}
	var contract contracts.Contract
	if err := json.Unmarshal(contractData, &contract); err != nil {
		return nil, fmt.Errorf("failed to unmarshal contract: %w", err)
	}
	return &contract, nil
}
// InsertContractKey persists a contract key record in the keys collection.
func (s *Store) InsertContractKey(c ContractKey) error {
	payload, err := json.Marshal(c)
	if err != nil {
		return fmt.Errorf("failed to marshal contract key: %w", err)
	}
	record := document.NewDocumentOf(c)
	record.Set("contract_did", c.ContractDID)
	record.Set("created_at", time.Now().UnixNano())
	record.Set("key_data", payload)
	return s.db.Insert(contractsKeysCollection, record)
}
// GetContractKey retrieves the key record for the given contract DID.
func (s *Store) GetContractKey(contractDID string) (*ContractKey, error) {
	q := query.NewQuery(contractsKeysCollection).Where(query.Field("contract_did").Eq(contractDID))
	doc, err := s.db.FindFirst(q)
	if err != nil {
		return nil, fmt.Errorf("failed to find contract by did: %w", err)
	}
	// Previously doc == nil was folded into the error branch and wrapped a
	// nil error; report the not-found case explicitly.
	if doc == nil {
		return nil, fmt.Errorf("contract key not found for did: %s", contractDID)
	}
	// Guard the type assertion; a malformed document must not panic the caller.
	keyData, ok := doc.Get("key_data").([]byte)
	if !ok {
		return nil, errors.New("no contract key data available")
	}
	var key ContractKey
	if err := json.Unmarshal(keyData, &key); err != nil {
		return nil, fmt.Errorf("failed to unmarshal contract: %w", err)
	}
	return &key, nil
}
// GetAllContracts retrieves all contracts from the database. A document
// without a valid contract_data payload is reported as an error instead of
// panicking (the previous unchecked type assertion could panic).
func (s *Store) GetAllContracts() ([]*contracts.Contract, error) {
	q := query.NewQuery(contractsCollection)
	docs, err := s.db.FindAll(q)
	if err != nil {
		return nil, fmt.Errorf("failed to retrieve all contracts: %w", err)
	}
	allContracts := make([]*contracts.Contract, 0, len(docs))
	for _, doc := range docs {
		data, ok := doc.Get("contract_data").([]byte)
		if !ok {
			return nil, errors.New("no contract data available")
		}
		var contract contracts.Contract
		if err := json.Unmarshal(data, &contract); err != nil {
			return nil, fmt.Errorf("failed to unmarshal single contract: %w", err)
		}
		allContracts = append(allContracts, &contract)
	}
	return allContracts, nil
}
// FindContractByParticipants finds the active contract between two parties,
// matching participants in either role (order-independent). Exactly one
// active (ACCEPTED or ACTIVE) contract must exist between the two parties;
// zero or multiple matches are errors.
func (s *Store) FindContractByParticipants(participant1, participant2 did.DID) (*contracts.Contract, error) {
	q := query.NewQuery(contractsCollection)
	docs, err := s.db.FindAll(q)
	if err != nil {
		return nil, fmt.Errorf("failed to retrieve contracts: %w", err)
	}
	participant1Str := participant1.String()
	participant2Str := participant2.String()
	var matches []*contracts.Contract
	for _, doc := range docs {
		contractData, ok := doc.Get("contract_data").([]byte)
		if !ok {
			continue // skip malformed documents
		}
		var contract contracts.Contract
		if err := json.Unmarshal(contractData, &contract); err != nil {
			continue // skip undecodable documents
		}
		provStr := contract.ContractParticipants.Provider.String()
		reqStr := contract.ContractParticipants.Requestor.String()
		// Both participants must appear, in either role.
		matchesParticipant1 := provStr == participant1Str || reqStr == participant1Str
		matchesParticipant2 := provStr == participant2Str || reqStr == participant2Str
		if !matchesParticipant1 || !matchesParticipant2 {
			continue
		}
		// Only include active contracts (ACCEPTED or ACTIVE state).
		if contract.CurrentState == contracts.ContractAccepted || contract.CurrentState == contracts.ContractActive {
			matches = append(matches, &contract)
		}
	}
	switch len(matches) {
	case 0:
		// errors.New for constant messages (was fmt.Errorf with no verbs).
		return nil, errors.New("no active contract found between participants")
	case 1:
		return matches[0], nil
	default:
		return nil, errors.New("multiple active contracts found between participants (violates single contract rule)")
	}
}
// FindContractsByParticipant returns all active (ACCEPTED or ACTIVE)
// contracts in which the given DID appears as provider or requestor.
func (s *Store) FindContractsByParticipant(participant did.DID) ([]*contracts.Contract, error) {
	q := query.NewQuery(contractsCollection)
	docs, err := s.db.FindAll(q)
	if err != nil {
		return nil, fmt.Errorf("failed to retrieve contracts: %w", err)
	}
	target := participant.String()
	var matches []*contracts.Contract
	for _, doc := range docs {
		raw, ok := doc.Get("contract_data").([]byte)
		if !ok {
			continue
		}
		var contract contracts.Contract
		if err := json.Unmarshal(raw, &contract); err != nil {
			continue
		}
		involved := contract.ContractParticipants.Provider.String() == target ||
			contract.ContractParticipants.Requestor.String() == target
		active := contract.CurrentState == contracts.ContractAccepted ||
			contract.CurrentState == contracts.ContractActive
		if involved && active {
			matches = append(matches, &contract)
		}
	}
	return matches, nil
}
// FindTailContract finds the single Tail Contract matching the given Head
// Contract config and compute provider, returned in ContractConfig form for
// use in the jobs package (the name matches its TailContractFinder
// interface). Zero or multiple tail contracts are errors.
func (s *Store) FindTailContract(
	headContractConfig types.ContractConfig,
	computeProviderDID string,
) (*types.ContractConfig, error) {
	tailContracts, err := s.findTailContracts(headContractConfig, computeProviderDID)
	if err != nil {
		return nil, err
	}
	// errors.New for constant messages (was fmt.Errorf with no verbs).
	if len(tailContracts) == 0 {
		return nil, errors.New("no tail contracts found")
	}
	if len(tailContracts) > 1 {
		return nil, errors.New("multiple tail contracts found")
	}
	tail := tailContracts[0]
	return &types.ContractConfig{
		DID:      tail.ContractDID,
		Host:     tail.SolutionEnablerDID.String(),
		Provider: tail.ContractParticipants.Provider.String(),
	}, nil
}
// findTailContracts returns every active contract whose requestor is the
// organization named in the head contract config and whose provider is the
// given compute provider. Internal helper for FindTailContract.
func (s *Store) findTailContracts(
	headContractConfig types.ContractConfig,
	computeProviderDID string,
) ([]*contracts.Contract, error) {
	if headContractConfig.Provider == "" {
		return nil, fmt.Errorf("head contract config is missing provider DID")
	}
	orgDID, err := did.FromString(headContractConfig.Provider)
	if err != nil {
		return nil, fmt.Errorf("invalid organization DID in head contract config: %w", err)
	}
	providerDID, err := did.FromString(computeProviderDID)
	if err != nil {
		return nil, fmt.Errorf("invalid compute provider DID: %w", err)
	}
	all, err := s.GetAllContracts()
	if err != nil {
		return nil, fmt.Errorf("failed to get all contracts: %w", err)
	}
	var tails []*contracts.Contract
	for _, candidate := range all {
		// The head contract itself (if stored locally) is never a tail.
		if candidate.ContractDID == headContractConfig.DID {
			continue
		}
		// A tail contract has the organization as requestor, this node as
		// provider, and is currently active.
		if !candidate.ContractParticipants.Requestor.Equal(orgDID) {
			continue
		}
		if !candidate.ContractParticipants.Provider.Equal(providerDID) {
			continue
		}
		if candidate.CurrentState != contracts.ContractAccepted &&
			candidate.CurrentState != contracts.ContractActive {
			continue
		}
		tails = append(tails, candidate)
	}
	return tails, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package payment
import (
"encoding/json"
"errors"
"fmt"
"time"
"github.com/ostafen/clover/v2"
"github.com/ostafen/clover/v2/document"
"github.com/ostafen/clover/v2/query"
"gitlab.com/nunet/device-management-service/tokenomics/contracts"
)
const (
	paymentsCollection = "contracts_payments" // CloverDB collection of payment documents
)

// Payment records a single payment against a contract.
type Payment struct {
	UniqueID string             `json:"unique_id"` // unique payment identifier
	Contract contracts.Contract `json:"contract"`  // the contract this payment belongs to
	Usages   int                `json:"usages"`    // usage count associated with this payment
	Amount   string             `json:"amount"`    // payment amount as a decimal string
	Paid     bool               `json:"paid"`      // whether the payment has been settled
}

// Store is a CloverDB-backed persistence layer for payments.
type Store struct {
	db *clover.DB
}
// New creates a payment store backed by the given CloverDB handle.
func New(db *clover.DB) (*Store, error) {
	if db == nil {
		return nil, errors.New("db is nil")
	}
	return &Store{db: db}, nil
}
// Insert stores a payment document keyed by unique_id. When a payment with
// the same unique_id already exists, the call is a no-op (idempotent insert).
func (s *Store) Insert(p Payment) error {
	payload, err := json.Marshal(p)
	if err != nil {
		return fmt.Errorf("failed to marshal payment: %w", err)
	}
	q := query.NewQuery(paymentsCollection).Where(query.Field("unique_id").Eq(p.UniqueID))
	existing, err := s.db.FindFirst(q)
	if err != nil {
		return fmt.Errorf("failed to check existing payment: %w", err)
	}
	if existing != nil {
		// Duplicate unique_id: nothing to do.
		return nil
	}
	record := document.NewDocumentOf(p)
	record.Set("unique_id", p.UniqueID)
	record.Set("contract_did", p.Contract.ContractDID)
	record.Set("created_at", time.Now().UnixNano())
	record.Set("payment_data", payload)
	return s.db.Insert(paymentsCollection, record)
}
// AllPayments retrieves all payments from the database. A document without
// a valid payment_data payload is reported as an error instead of panicking
// (the previous unchecked type assertion could panic).
func (s *Store) AllPayments() ([]*Payment, error) {
	q := query.NewQuery(paymentsCollection)
	docs, err := s.db.FindAll(q)
	if err != nil {
		return nil, fmt.Errorf("failed to retrieve all payments: %w", err)
	}
	allPayments := make([]*Payment, 0, len(docs))
	for _, doc := range docs {
		data, ok := doc.Get("payment_data").([]byte)
		if !ok {
			return nil, errors.New("invalid data format for payment document")
		}
		var payment Payment
		if err := json.Unmarshal(data, &payment); err != nil {
			return nil, fmt.Errorf("failed to unmarshal single payment: %w", err)
		}
		allPayments = append(allPayments, &payment)
	}
	return allPayments, nil
}
// GetByUniqueID retrieves a payment by its unique_id.
func (s *Store) GetByUniqueID(uniqueID string) (*Payment, error) {
	q := query.NewQuery(paymentsCollection).Where(query.Field("unique_id").Eq(uniqueID))
	doc, err := s.db.FindFirst(q)
	switch {
	case err != nil:
		return nil, fmt.Errorf("failed to retrieve payment with unique_id %s: %w", uniqueID, err)
	case doc == nil:
		return nil, fmt.Errorf("payment with unique_id %s not found", uniqueID)
	}
	raw, ok := doc.Get("payment_data").([]byte)
	if !ok {
		return nil, fmt.Errorf("invalid data format for payment with unique_id %s", uniqueID)
	}
	payment := new(Payment)
	if err := json.Unmarshal(raw, payment); err != nil {
		return nil, fmt.Errorf("failed to unmarshal payment with unique_id %s: %w", uniqueID, err)
	}
	return payment, nil
}
// Update replaces an existing payment document identified by unique_id.
// Returns an error when the payment is nil, the unique_id is empty, or no
// matching document exists.
func (s *Store) Update(p *Payment) error {
	if p == nil {
		return errors.New("payment is nil")
	}
	if p.UniqueID == "" {
		return errors.New("unique_id is required for update")
	}
	q := query.NewQuery(paymentsCollection).Where(query.Field("unique_id").Eq(p.UniqueID))
	doc, err := s.db.FindFirst(q)
	if err != nil {
		return fmt.Errorf("failed to query existing payment: %w", err)
	}
	if doc == nil {
		return fmt.Errorf("payment with unique_id %s not found", p.UniqueID)
	}
	bts, err := json.Marshal(p)
	if err != nil {
		return fmt.Errorf("failed to marshal payment: %w", err)
	}
	err = s.db.UpdateById(paymentsCollection, doc.ObjectId(), func(d *document.Document) *document.Document {
		d.Set("payment_data", bts)
		// UnixNano for consistency with created_at (was Unix seconds).
		d.Set("updated_at", time.Now().UnixNano())
		return d
	})
	if err != nil {
		return fmt.Errorf("failed to update payment with unique_id %s: %w", p.UniqueID, err)
	}
	return nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package paymentquote
import (
"encoding/json"
"errors"
"fmt"
"time"
"github.com/ostafen/clover/v2"
"github.com/ostafen/clover/v2/document"
"github.com/ostafen/clover/v2/query"
)
const (
	paymentQuotesCollection = "payment_quotes"
	defaultQuoteTTL         = 2 * time.Minute // Quotes expire after 2 minutes
)

// PaymentQuote is a short-lived, single-use currency-conversion quote tied
// to a transaction via UniqueID.
type PaymentQuote struct {
	QuoteID         string    `json:"quote_id"`          // Unique quote identifier
	UniqueID        string    `json:"unique_id"`         // Links to transaction
	OriginalAmount  string    `json:"original_amount"`   // Amount in pricing currency (USDT)
	ConvertedAmount string    `json:"converted_amount"`  // Amount in payment currency (NTX)
	PricingCurrency string    `json:"pricing_currency"`  // Original currency (e.g., "USDT")
	PaymentCurrency string    `json:"payment_currency"`  // Payment currency (e.g., "NTX")
	ExchangeRate    string    `json:"exchange_rate"`     // Rate used for conversion
	CreatedAt       time.Time `json:"created_at"`        // Quote creation timestamp
	ExpiresAt       time.Time `json:"expires_at"`        // Quote expiration timestamp
	Used            bool      `json:"used"`              // Whether quote has been used
	UsedAt          time.Time `json:"used_at,omitempty"` // When quote was used (if used)
}

// Store is a CloverDB-backed persistence layer for payment quotes.
type Store struct {
	db *clover.DB
}
// New creates a payment-quote store, creating the backing collection when it
// does not yet exist.
func New(db *clover.DB) (*Store, error) {
	if db == nil {
		return nil, errors.New("db is nil")
	}
	exists, err := db.HasCollection(paymentQuotesCollection)
	if err != nil {
		return nil, fmt.Errorf("failed to check collection: %w", err)
	}
	if !exists {
		if err := db.CreateCollection(paymentQuotesCollection); err != nil {
			return nil, fmt.Errorf("failed to create collection: %w", err)
		}
	}
	return &Store{db: db}, nil
}
// CreateQuote persists a new payment quote. QuoteID and UniqueID are required.
func (s *Store) CreateQuote(quote PaymentQuote) error {
	switch {
	case quote.QuoteID == "":
		return errors.New("quote_id is required")
	case quote.UniqueID == "":
		return errors.New("unique_id is required")
	}
	payload, err := json.Marshal(quote)
	if err != nil {
		return fmt.Errorf("failed to marshal quote: %w", err)
	}
	record := document.NewDocumentOf(quote)
	record.Set("quote_id", quote.QuoteID)
	record.Set("unique_id", quote.UniqueID)
	record.Set("created_at", time.Now().UnixNano())
	record.Set("quote_data", payload)
	return s.db.Insert(paymentQuotesCollection, record)
}
// GetQuote retrieves a quote by quote_id (does not check expiration - use ValidateQuote for that)
func (s *Store) GetQuote(quoteID string) (*PaymentQuote, error) {
q := query.NewQuery(paymentQuotesCollection).Where(query.Field("quote_id").Eq(quoteID))
doc, err := s.db.FindFirst(q)
if err != nil {
return nil, fmt.Errorf("failed to find quote: %w", err)
}
if doc == nil {
return nil, fmt.Errorf("quote not found: %s", quoteID)
}
data := doc.Get("quote_data")
var quote PaymentQuote
if err := json.Unmarshal(data.([]byte), "e); err != nil {
return nil, fmt.Errorf("failed to unmarshal quote: %w", err)
}
return "e, nil
}
// ValidateQuote returns the quote only when it exists, has not been used,
// and has not expired.
func (s *Store) ValidateQuote(quoteID string) (*PaymentQuote, error) {
	quote, err := s.GetQuote(quoteID)
	if err != nil {
		return nil, err
	}
	switch {
	case quote.Used:
		return nil, fmt.Errorf("quote already used: %s", quoteID)
	case time.Now().After(quote.ExpiresAt):
		return nil, fmt.Errorf("quote expired: %s", quoteID)
	}
	return quote, nil
}
// MarkQuoteAsUsed marks a quote as used
func (s *Store) MarkQuoteAsUsed(quoteID string) error {
q := query.NewQuery(paymentQuotesCollection).Where(query.Field("quote_id").Eq(quoteID))
doc, err := s.db.FindFirst(q)
if err != nil {
return fmt.Errorf("failed to find quote: %w", err)
}
if doc == nil {
return fmt.Errorf("quote not found: %s", quoteID)
}
data := doc.Get("quote_data")
var quote PaymentQuote
if err := json.Unmarshal(data.([]byte), "e); err != nil {
return fmt.Errorf("failed to unmarshal quote: %w", err)
}
if quote.Used {
return fmt.Errorf("quote already used: %s", quoteID)
}
quote.Used = true
quote.UsedAt = time.Now()
bts, err := json.Marshal(quote)
if err != nil {
return fmt.Errorf("failed to marshal quote: %w", err)
}
update := map[string]interface{}{
"quote_data": bts,
}
return s.db.Update(q, update)
}
// GetQuoteByUniqueID retrieves the most recent unused quote for a transaction
func (s *Store) GetQuoteByUniqueID(uniqueID string) (*PaymentQuote, error) {
q := query.NewQuery(paymentQuotesCollection).
Where(query.Field("unique_id").Eq(uniqueID)).
Sort(query.SortOption{Field: "created_at", Direction: -1})
docs, err := s.db.FindAll(q)
if err != nil {
return nil, fmt.Errorf("failed to find quotes: %w", err)
}
// Find first unused quote (expiration check done in ValidateQuote)
for _, doc := range docs {
data := doc.Get("quote_data")
var quote PaymentQuote
if err := json.Unmarshal(data.([]byte), "e); err != nil {
continue
}
if !quote.Used {
return "e, nil
}
}
return nil, fmt.Errorf("no unused quote found for unique_id: %s", uniqueID)
}
// HasActiveQuote checks if there's an active (unused and not expired) quote for a transaction
func (s *Store) HasActiveQuote(uniqueID string) (*PaymentQuote, error) {
q := query.NewQuery(paymentQuotesCollection).
Where(query.Field("unique_id").Eq(uniqueID)).
Sort(query.SortOption{Field: "created_at", Direction: -1})
docs, err := s.db.FindAll(q)
if err != nil {
return nil, fmt.Errorf("failed to find quotes: %w", err)
}
// Find first active quote (not used and not expired)
now := time.Now()
for _, doc := range docs {
data := doc.Get("quote_data")
var quote PaymentQuote
if err := json.Unmarshal(data.([]byte), "e); err != nil {
continue
}
// Check if quote is unused and not expired
if !quote.Used && now.Before(quote.ExpiresAt) {
return "e, nil
}
}
return nil, nil // No active quote found (not an error)
}
// InvalidateQuote explicitly invalidates a quote (e.g., when user cancels payment)
// This marks the quote as used without actually using it for payment
func (s *Store) InvalidateQuote(quoteID string) error {
	// Invalidation is implemented by consuming the quote: once marked
	// used it can never pass ValidateQuote again. A dedicated
	// "invalidated" flag could be added later if callers need to
	// distinguish the two states.
	return s.MarkQuoteAsUsed(quoteID)
}
// CleanupExpiredQuotes removes expired quotes older than the retention period
// This should be called periodically (e.g., daily) to clean up old data
func (s *Store) CleanupExpiredQuotes(retentionPeriod time.Duration) error {
	cutoff := time.Now().Add(-retentionPeriod).UnixNano()
	stale, err := s.db.FindAll(
		query.NewQuery(paymentQuotesCollection).
			Where(query.Field("created_at").Lt(cutoff)))
	if err != nil {
		return fmt.Errorf("failed to find expired quotes: %w", err)
	}
	// Deletion is best-effort: a failure on one document must not stop
	// the sweep over the remaining ones.
	for _, rec := range stale {
		_ = s.db.DeleteById(paymentQuotesCollection, rec.ObjectId())
	}
	return nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package transaction
import (
"encoding/json"
"errors"
"fmt"
"time"
"github.com/ostafen/clover/v2"
"github.com/ostafen/clover/v2/document"
"github.com/ostafen/clover/v2/query"
"gitlab.com/nunet/device-management-service/types"
)
const (
transactionsCollection = "service_provider_transactions"
)
// Transaction struct
//
// Transaction is the persisted record of a service-provider payment. It is
// stored JSON-encoded in the "transaction_data" field of a clover document
// (see Upsert), alongside indexed copies of unique_id and contract_did.
type Transaction struct {
	UniqueID            string                     `json:"unique_id"`             // primary lookup key used by all store queries
	PaymentValidatorDID string                     `json:"payment_validator_did"` // returned by MarkAsPaid / GetPaymentValidatorDID
	ContractDID         string                     `json:"contract_did"`
	ToAddress           []types.PaymentAddressInfo `json:"to_address"`
	Amount              string                     `json:"amount"`
	Status              string                     `json:"status"` // defaults to "unpaid" on insert; set to "paid" by MarkAsPaid
	TxHash              string                     `json:"tx_hash"`
	Metadata            map[string]interface{}     `json:"metadata,omitempty"`
	// New fields for price conversion
	OriginalAmount     string `json:"original_amount,omitempty"`     // Amount in pricing currency (USDT)
	PricingCurrency    string `json:"pricing_currency,omitempty"`    // Currency of original amount (e.g., "USDT")
	RequiresConversion bool   `json:"requires_conversion,omitempty"` // True if conversion is needed
}
// Store persists service-provider transactions in a clover database.
type Store struct {
	db *clover.DB // underlying document database; must be non-nil (enforced by New)
}
// New creates a new transaction store
// It fails when the supplied database handle is nil.
func New(db *clover.DB) (*Store, error) {
	if db == nil {
		return nil, errors.New("db is nil")
	}
	store := &Store{db: db}
	return store, nil
}
// Upsert inserts a transaction if it doesn't already exist
// (matched on unique_id); existing records are left untouched.
// An empty Status defaults to "unpaid" before persisting.
func (s *Store) Upsert(t Transaction) error {
	if t.Status == "" {
		t.Status = "unpaid"
	}
	payload, err := json.Marshal(t)
	if err != nil {
		return fmt.Errorf("failed to marshal transaction: %w", err)
	}
	qry := query.NewQuery(transactionsCollection).Where(query.Field("unique_id").Eq(t.UniqueID))
	existing, err := s.db.FindFirst(qry)
	if err != nil {
		return fmt.Errorf("failed to check existing transaction: %w", err)
	}
	if existing != nil {
		// Already stored: keep the original record.
		return nil
	}
	rec := document.NewDocumentOf(t)
	rec.Set("unique_id", t.UniqueID)
	rec.Set("contract_did", t.ContractDID)
	rec.Set("created_at", time.Now().UnixNano())
	rec.Set("transaction_data", payload)
	return s.db.Insert(transactionsCollection, rec)
}
// AllTransactions retrieves all transactions from the database
func (s *Store) AllTransactions() ([]*Transaction, error) {
	docs, err := s.db.FindAll(query.NewQuery(transactionsCollection))
	if err != nil {
		return nil, fmt.Errorf("failed to retrieve all transactions: %w", err)
	}
	result := make([]*Transaction, 0, len(docs))
	for _, rec := range docs {
		tx := new(Transaction)
		raw := rec.Get("transaction_data")
		if err := json.Unmarshal(raw.([]byte), tx); err != nil {
			return nil, fmt.Errorf("failed to unmarshal single transaction: %w", err)
		}
		result = append(result, tx)
	}
	return result, nil
}
// MarkAsPaid updates a transaction's status to "paid" given its unique ID
// it returns also the payment provider did
func (s *Store) MarkAsPaid(uniqueID string, txHash string) (string, error) {
	qry := query.NewQuery(transactionsCollection).Where(query.Field("unique_id").Eq(uniqueID))
	rec, err := s.db.FindFirst(qry)
	if err != nil {
		return "", fmt.Errorf("failed to find transaction: %w", err)
	}
	if rec == nil {
		return "", fmt.Errorf("transaction not found with unique_id: %s", uniqueID)
	}
	var tx Transaction
	if err := json.Unmarshal(rec.Get("transaction_data").([]byte), &tx); err != nil {
		return "", fmt.Errorf("failed to unmarshal transaction: %w", err)
	}
	// Record the payment outcome and persist the re-serialized blob.
	tx.Status = "paid"
	tx.TxHash = txHash
	raw, err := json.Marshal(tx)
	if err != nil {
		return "", fmt.Errorf("failed to marshal updated transaction: %w", err)
	}
	update := map[string]interface{}{"transaction_data": raw}
	return tx.PaymentValidatorDID, s.db.Update(qry, update)
}
// GetPaymentValidatorDID returns the payment validator DID recorded on the
// transaction identified by uniqueID.
func (s *Store) GetPaymentValidatorDID(uniqueID string) (string, error) {
	qry := query.NewQuery(transactionsCollection).Where(query.Field("unique_id").Eq(uniqueID))
	rec, err := s.db.FindFirst(qry)
	if err != nil {
		return "", fmt.Errorf("failed to find transaction: %w", err)
	}
	if rec == nil {
		return "", fmt.Errorf("transaction not found with unique_id: %s", uniqueID)
	}
	var tx Transaction
	if err := json.Unmarshal(rec.Get("transaction_data").([]byte), &tx); err != nil {
		return "", fmt.Errorf("failed to unmarshal transaction: %w", err)
	}
	return tx.PaymentValidatorDID, nil
}
// GetTransactionByUniqueID retrieves a transaction by its unique ID
func (s *Store) GetTransactionByUniqueID(uniqueID string) (*Transaction, error) {
	qry := query.NewQuery(transactionsCollection).Where(query.Field("unique_id").Eq(uniqueID))
	rec, err := s.db.FindFirst(qry)
	if err != nil {
		return nil, fmt.Errorf("failed to find transaction: %w", err)
	}
	if rec == nil {
		return nil, fmt.Errorf("transaction not found with unique_id: %s", uniqueID)
	}
	tx := new(Transaction)
	if err := json.Unmarshal(rec.Get("transaction_data").([]byte), tx); err != nil {
		return nil, fmt.Errorf("failed to unmarshal transaction: %w", err)
	}
	return tx, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package usage
import (
"encoding/json"
"errors"
"fmt"
"time"
"github.com/ostafen/clover/v2"
"github.com/ostafen/clover/v2/document"
"github.com/ostafen/clover/v2/query"
"gitlab.com/nunet/device-management-service/tokenomics/events"
)
const (
contractsUsageCollection = "contracts_usage"
lastProcessedAtCollection = "usage_metadata"
lastProcessedAtKeyPrefix = "last_processed_at"
)
// Usage is a single usage event as persisted in the contracts_usage
// collection. Data holds the raw event JSON; the other fields are indexed
// copies stored alongside it for querying (see AddUsageEvent).
type Usage struct {
	ContractDID     string           `json:"contract_did"`               // Tail Contract DID (existing)
	HeadContractDID string           `json:"head_contract_did,omitempty"` // Head Contract DID (new)
	ProviderDID     string           `json:"provider_did,omitempty"`     // Provider DID for per-node billing
	EventType       events.EventType `json:"event_type,omitempty"`       // For indexing - extracted from JSON if not provided
	Data            []byte           `json:"data"`                       // Raw JSON bytes
	Timestamp       time.Time        `json:"timestamp,omitempty"`        // Event timestamp (populated from created_at in query paths)
}
// Store persists contract usage events and per-contract processing
// metadata in a clover database.
type Store struct {
	db *clover.DB // underlying document database; must be non-nil (enforced by New)
}
// EventFilters defines filters for querying events.
// Zero-valued fields are ignored; all set fields are AND-combined
// (see QueryEvents).
type EventFilters struct {
	ContractDID     string // Tail Contract DID
	HeadContractDID string // Head Contract DID (new)
	EventTypes      []events.EventType
	StartTime       time.Time // inclusive lower bound on created_at
	EndTime         time.Time // inclusive upper bound on created_at
}
// New creates a usage-event store backed by the given clover database.
// It fails when the supplied handle is nil.
func New(db *clover.DB) (*Store, error) {
	if db == nil {
		return nil, errors.New("db is nil")
	}
	store := &Store{db: db}
	return store, nil
}
// AddUsageEvent stores one usage event, indexing contract/head-contract/
// provider DIDs and the event type alongside the raw JSON payload.
func (s *Store) AddUsageEvent(u Usage) error {
	if u.ContractDID == "" {
		return errors.New("contractDID is empty")
	}
	rec := document.NewDocument()
	rec.Set("contract_did", u.ContractDID)
	rec.Set("head_contract_did", u.HeadContractDID) // empty for non-chain contracts
	rec.Set("provider_did", u.ProviderDID)          // enables per-provider filtering
	rec.Set("created_at", time.Now().UnixNano())
	rec.Set("usage_data", u.Data)

	// Determine the event type for indexing: prefer the explicit field,
	// otherwise peek into the raw JSON payload.
	evType := u.EventType
	if evType == "" && len(u.Data) > 0 {
		var base events.EventBase
		if json.Unmarshal(u.Data, &base) == nil {
			evType = base.Type
		}
	}
	if evType != "" {
		rec.Set("event_type", string(evType))
	}

	if _, err := s.db.InsertOne(contractsUsageCollection, rec); err != nil {
		return fmt.Errorf("failed to insert usage event: %w", err)
	}
	return nil
}
// GetEventsByContract returns every usage event indexed under the given
// (tail) contract DID. Missing document fields are left at zero values.
func (s *Store) GetEventsByContract(contractDID string) ([]*Usage, error) {
	qry := query.NewQuery(contractsUsageCollection).Where(query.Field("contract_did").Eq(contractDID))
	docs, err := s.db.FindAll(qry)
	if err != nil {
		return nil, fmt.Errorf("failed to retrieve usages for contract %s: %w", contractDID, err)
	}
	out := make([]*Usage, 0, len(docs))
	for _, rec := range docs {
		item := &Usage{}
		if v, ok := rec.Get("contract_did").(string); ok {
			item.ContractDID = v
		}
		if v, ok := rec.Get("head_contract_did").(string); ok {
			item.HeadContractDID = v
		}
		if v, ok := rec.Get("usage_data").([]byte); ok {
			item.Data = v
		}
		if v, ok := rec.Get("event_type").(string); ok {
			item.EventType = events.EventType(v)
		}
		out = append(out, item)
	}
	return out, nil
}
// GetAllEvents retrieves all events from DB.
func (s *Store) GetAllEvents() ([]*Usage, error) {
q := query.NewQuery(contractsUsageCollection)
docs, err := s.db.FindAll(q)
if err != nil {
return nil, fmt.Errorf("failed to retrieve all usages: %w", err)
}
allUsages := make([]*Usage, 0)
for _, doc := range docs {
var currentUsage Usage
data := doc.Get("usage_data")
currentUsage.Data = data.([]byte)
currentUsage.ContractDID = doc.Get("contract_did").(string)
if hcDid, ok := doc.Get("head_contract_did").(string); ok {
currentUsage.HeadContractDID = hcDid
}
if eventTypeStr, ok := doc.Get("event_type").(string); ok {
currentUsage.EventType = events.EventType(eventTypeStr)
}
allUsages = append(allUsages, ¤tUsage)
}
return allUsages, nil
}
// GetEventsByDateRange retrieves all events created within the given date range.
// Both bounds are inclusive and compared against the created_at index.
func (s *Store) GetEventsByDateRange(start, end time.Time) ([]*Usage, error) {
	criteria := query.Field("created_at").GtEq(start.UnixNano()).
		And(query.Field("created_at").LtEq(end.UnixNano()))
	docs, err := s.db.FindAll(query.NewQuery(contractsUsageCollection).Where(criteria))
	if err != nil {
		return nil, fmt.Errorf("failed to retrieve usages by date range: %w", err)
	}
	out := make([]*Usage, 0, len(docs))
	for _, rec := range docs {
		item := &Usage{}
		if v, ok := rec.Get("contract_did").(string); ok {
			item.ContractDID = v
		}
		if v, ok := rec.Get("head_contract_did").(string); ok {
			item.HeadContractDID = v
		}
		if v, ok := rec.Get("usage_data").([]byte); ok {
			item.Data = v
		}
		if v, ok := rec.Get("event_type").(string); ok {
			item.EventType = events.EventType(v)
		}
		out = append(out, item)
	}
	return out, nil
}
// QueryEvents - Generic query with filters (event_type-based filtering)
//
// All set filter fields are AND-combined; zero-valued fields are ignored.
// Matching documents are hydrated into Usage values with checked type
// assertions, and Timestamp is populated from the created_at index.
func (s *Store) QueryEvents(filters EventFilters) ([]*Usage, error) {
	q := query.NewQuery(contractsUsageCollection)

	// Collect individual criteria, then AND them together. This mirrors
	// QueryEventsByProvider and replaces the repetitive hasCondition
	// bookkeeping of the original implementation.
	var conditions []query.Criteria
	if filters.ContractDID != "" {
		conditions = append(conditions, query.Field("contract_did").Eq(filters.ContractDID))
	}
	if filters.HeadContractDID != "" {
		conditions = append(conditions, query.Field("head_contract_did").Eq(filters.HeadContractDID))
	}
	if len(filters.EventTypes) > 0 {
		typeStrs := make([]interface{}, len(filters.EventTypes))
		for i, et := range filters.EventTypes {
			typeStrs[i] = string(et)
		}
		conditions = append(conditions, query.Field("event_type").In(typeStrs...))
	}
	if !filters.StartTime.IsZero() {
		conditions = append(conditions, query.Field("created_at").GtEq(filters.StartTime.UnixNano()))
	}
	if !filters.EndTime.IsZero() {
		conditions = append(conditions, query.Field("created_at").LtEq(filters.EndTime.UnixNano()))
	}
	if len(conditions) > 0 {
		combined := conditions[0]
		for _, cond := range conditions[1:] {
			combined = combined.And(cond)
		}
		q = q.Where(combined)
	}

	docs, err := s.db.FindAll(q)
	if err != nil {
		return nil, fmt.Errorf("failed to query events: %w", err)
	}
	usages := make([]*Usage, 0, len(docs))
	for _, doc := range docs {
		var u Usage
		if cdid, ok := doc.Get("contract_did").(string); ok {
			u.ContractDID = cdid
		}
		if hcDid, ok := doc.Get("head_contract_did").(string); ok {
			u.HeadContractDID = hcDid
		}
		if data, ok := doc.Get("usage_data").([]byte); ok {
			u.Data = data
		}
		if eventTypeStr, ok := doc.Get("event_type").(string); ok {
			u.EventType = events.EventType(eventTypeStr)
		}
		// Extract timestamp from created_at
		if timestampNano, ok := doc.Get("created_at").(int64); ok {
			u.Timestamp = time.Unix(0, timestampNano)
		}
		usages = append(usages, &u)
	}
	return usages, nil
}
// QueryEventsByProvider queries events filtered by contract and provider.
//
// contractDID is always applied; providerDID is applied only when non-empty.
// From filters, only EventTypes/StartTime/EndTime are honored; its DID
// fields are ignored here. (The original dead store
// `filters.ContractDID = contractDID`, which was never read, is removed.)
func (s *Store) QueryEventsByProvider(contractDID, providerDID string, filters EventFilters) ([]*Usage, error) {
	q := query.NewQuery(contractsUsageCollection)
	// Build conditions
	var conditions []query.Criteria
	conditions = append(conditions, query.Field("contract_did").Eq(contractDID))
	if providerDID != "" {
		conditions = append(conditions, query.Field("provider_did").Eq(providerDID))
	}
	// Apply other filters
	if len(filters.EventTypes) > 0 {
		typeStrs := make([]interface{}, len(filters.EventTypes))
		for i, et := range filters.EventTypes {
			typeStrs[i] = string(et)
		}
		conditions = append(conditions, query.Field("event_type").In(typeStrs...))
	}
	if !filters.StartTime.IsZero() {
		conditions = append(conditions, query.Field("created_at").GtEq(filters.StartTime.UnixNano()))
	}
	if !filters.EndTime.IsZero() {
		conditions = append(conditions, query.Field("created_at").LtEq(filters.EndTime.UnixNano()))
	}
	// Combine all conditions with AND
	combinedCondition := conditions[0]
	for _, cond := range conditions[1:] {
		combinedCondition = combinedCondition.And(cond)
	}
	q = q.Where(combinedCondition)
	docs, err := s.db.FindAll(q)
	if err != nil {
		return nil, fmt.Errorf("failed to query events: %w", err)
	}
	usages := make([]*Usage, 0, len(docs))
	for _, doc := range docs {
		var u Usage
		if cdid, ok := doc.Get("contract_did").(string); ok {
			u.ContractDID = cdid
		}
		if hcDid, ok := doc.Get("head_contract_did").(string); ok {
			u.HeadContractDID = hcDid
		}
		if pdid, ok := doc.Get("provider_did").(string); ok {
			u.ProviderDID = pdid
		}
		if data, ok := doc.Get("usage_data").([]byte); ok {
			u.Data = data
		}
		if eventTypeStr, ok := doc.Get("event_type").(string); ok {
			u.EventType = events.EventType(eventTypeStr)
		}
		// Extract timestamp from created_at
		if timestampNano, ok := doc.Get("created_at").(int64); ok {
			u.Timestamp = time.Unix(0, timestampNano)
		}
		usages = append(usages, &u)
	}
	return usages, nil
}
// GetEventsByHeadContract retrieves events for Head Contract (head contract in chain)
func (s *Store) GetEventsByHeadContract(headContractDID string) ([]*Usage, error) {
	qry := query.NewQuery(contractsUsageCollection).
		Where(query.Field("head_contract_did").Eq(headContractDID))
	docs, err := s.db.FindAll(qry)
	if err != nil {
		return nil, fmt.Errorf("failed to retrieve usages for head contract %s: %w", headContractDID, err)
	}
	out := make([]*Usage, 0, len(docs))
	for _, rec := range docs {
		item := &Usage{}
		if v, ok := rec.Get("contract_did").(string); ok {
			item.ContractDID = v
		}
		if v, ok := rec.Get("head_contract_did").(string); ok {
			item.HeadContractDID = v
		}
		if v, ok := rec.Get("usage_data").([]byte); ok {
			item.Data = v
		}
		if v, ok := rec.Get("event_type").(string); ok {
			item.EventType = events.EventType(v)
		}
		out = append(out, item)
	}
	return out, nil
}
// QueryAllocationEvents queries allocation start/end events with smart time bounds.
// This helper eliminates duplication across calculation methods.
// It looks back 1 year to catch allocations that started before query period but are still running.
// If headContractDID is provided, queries by Head Contract DID; otherwise queries by contractDID.
func (s *Store) QueryAllocationEvents(
	contractDID string,
	queryStart, queryEnd time.Time,
	headContractDID string, // Optional: if provided, queries by Head Contract DID
) ([]*Usage, []*Usage, error) {
	lookback := queryStart.AddDate(-1, 0, 0)

	// Build filters for a set of event types, targeting the head
	// contract when one is given and the tail contract otherwise.
	mkFilters := func(types ...events.EventType) EventFilters {
		f := EventFilters{
			EventTypes: types,
			StartTime:  lookback,
			EndTime:    queryEnd,
		}
		if headContractDID != "" {
			f.HeadContractDID = headContractDID
		} else {
			f.ContractDID = contractDID
		}
		return f
	}

	startEvents, err := s.QueryEvents(mkFilters(events.StartAllocationEvent))
	if err != nil {
		return nil, nil, fmt.Errorf("failed to query start events: %w", err)
	}
	endEvents, err := s.QueryEvents(mkFilters(events.CompleteAllocationEvent, events.StopAllocationEvent))
	if err != nil {
		return nil, nil, fmt.Errorf("failed to query end events: %w", err)
	}
	return startEvents, endEvents, nil
}
// QueryDeploymentEvents queries deployment start/stop events with smart time bounds.
// This helper eliminates duplication for deployment-based calculations.
// It looks back 1 year to catch deployments that started before query period but are still running.
// If headContractDID is provided, queries by Head Contract DID; otherwise queries by contractDID.
func (s *Store) QueryDeploymentEvents(
	contractDID string,
	queryStart, queryEnd time.Time,
	headContractDID string, // Optional: if provided, queries by Head Contract DID
) ([]*Usage, []*Usage, error) {
	lookback := queryStart.AddDate(-1, 0, 0)

	// Build filters for a set of event types, targeting the head
	// contract when one is given and the tail contract otherwise.
	mkFilters := func(types ...events.EventType) EventFilters {
		f := EventFilters{
			EventTypes: types,
			StartTime:  lookback,
			EndTime:    queryEnd,
		}
		if headContractDID != "" {
			f.HeadContractDID = headContractDID
		} else {
			f.ContractDID = contractDID
		}
		return f
	}

	startEvents, err := s.QueryEvents(mkFilters(events.DeploymentStartEvent))
	if err != nil {
		return nil, nil, fmt.Errorf("failed to query deployment start events: %w", err)
	}
	stopEvents, err := s.QueryEvents(mkFilters(events.DeploymentStopEvent))
	if err != nil {
		return nil, nil, fmt.Errorf("failed to query deployment stop events: %w", err)
	}
	return startEvents, stopEvents, nil
}
// QueryCreateAllocationEvents queries create allocation events (for resource fallback).
// This is used when StartAllocationEvent doesn't contain resources.
// No time restriction - need all for resource fallback lookup.
func (s *Store) QueryCreateAllocationEvents(
	contractDID string,
) ([]*Usage, error) {
	// Deliberately unbounded in time: every create event may be needed
	// for the resource fallback lookup.
	filters := EventFilters{
		ContractDID: contractDID,
		EventTypes:  []events.EventType{events.CreateAllocationEvent},
	}
	return s.QueryEvents(filters)
}
// CalculateEffectiveTime calculates effective start/end time for a window within query period.
// This is a simple utility function, not an abstraction.
// It handles the common pattern of:
// - If window started before query period, use query start as effective start
// - If window ended, use window end time (if after query start)
// - If window still running, use query end as effective end
// Returns effectiveStart, effectiveEnd, and valid flag (false if window should be excluded).
func CalculateEffectiveTime(
windowStart, windowEnd time.Time,
isComplete bool,
queryStart, queryEnd time.Time,
) (effectiveStart, effectiveEnd time.Time, valid bool) {
// If window started before query period, use query start
if windowStart.Before(queryStart) {
effectiveStart = queryStart
} else {
effectiveStart = windowStart
}
// Determine effective end time
if isComplete {
// Window ended - check if it ended after query start
if windowEnd.After(queryStart) {
effectiveEnd = windowEnd
} else {
// Window ended before query period, exclude it
return time.Time{}, time.Time{}, false
}
} else {
// Window still running - use query end
effectiveEnd = queryEnd
}
// Validate
if !effectiveStart.Before(effectiveEnd) {
return time.Time{}, time.Time{}, false
}
return effectiveStart, effectiveEnd, true
}
// CountAllocationsByContract retrieves all events within a given time range
// and returns a map of contractDID -> allocation count (based on START_ALLOCATION_EVENT).
// This is the backward-compatible version that returns counts for all contracts.
func (s *Store) CountAllocationsByContract(start, end time.Time) (map[string]int, error) {
	// Filter by event_type at the DB level first, then unmarshal the
	// payloads to count distinct allocation_ids per contract.
	evts, err := s.QueryEvents(EventFilters{
		EventTypes: []events.EventType{events.StartAllocationEvent},
		StartTime:  start,
		EndTime:    end,
	})
	if err != nil {
		return nil, fmt.Errorf("failed to query events: %w", err)
	}
	// contract DID -> set of allocation IDs seen in its start events
	seen := make(map[string]map[string]bool)
	for _, evt := range evts {
		var payload events.StartAllocation
		if json.Unmarshal(evt.Data, &payload) != nil || payload.AllocationID == "" {
			continue
		}
		if seen[evt.ContractDID] == nil {
			seen[evt.ContractDID] = make(map[string]bool)
		}
		seen[evt.ContractDID][payload.AllocationID] = true
	}
	counts := make(map[string]int, len(seen))
	for did, ids := range seen {
		counts[did] = len(ids)
	}
	return counts, nil
}
// CountAllocationsByContractDID retrieves events within a given time range
// for a specific contract DID and returns the count of unique allocations based on START_ALLOCATION_EVENT.
func (s *Store) CountAllocationsByContractDID(contractDID string, start, end time.Time) (int, error) {
	evts, err := s.QueryEvents(EventFilters{
		ContractDID: contractDID,
		EventTypes:  []events.EventType{events.StartAllocationEvent},
		StartTime:   start,
		EndTime:     end,
	})
	if err != nil {
		return 0, fmt.Errorf("failed to query events: %w", err)
	}
	// Count distinct allocation_ids from the event payloads; undecodable
	// or ID-less events are skipped.
	unique := make(map[string]bool)
	for _, evt := range evts {
		var payload events.StartAllocation
		if json.Unmarshal(evt.Data, &payload) != nil {
			continue
		}
		if payload.AllocationID != "" {
			unique[payload.AllocationID] = true
		}
	}
	return len(unique), nil
}
// CountDeploymentsByContract retrieves events within a given time range
// and returns the count of unique deployments based on DEPLOYMENT_START_EVENT.
func (s *Store) CountDeploymentsByContract(contractDID string, start, end time.Time) (int, error) {
	evts, err := s.QueryEvents(EventFilters{
		ContractDID: contractDID,
		EventTypes:  []events.EventType{events.DeploymentStartEvent},
		StartTime:   start,
		EndTime:     end,
	})
	if err != nil {
		return 0, fmt.Errorf("failed to query events: %w", err)
	}
	// Count distinct deployment_ids from the event payloads; undecodable
	// or ID-less events are skipped.
	unique := make(map[string]bool)
	for _, evt := range evts {
		var payload events.DeploymentStart
		if json.Unmarshal(evt.Data, &payload) != nil {
			continue
		}
		if payload.DeploymentID != "" {
			unique[payload.DeploymentID] = true
		}
	}
	return len(unique), nil
}
// SaveLastProcessedAt stores the last processed timestamp (Unix seconds) for a specific contract.
// If contractDID is empty, it stores a global timestamp.
// The metadata collection is created lazily on first use.
func (s *Store) SaveLastProcessedAt(contractDID string, t time.Time) error {
	exists, err := s.db.HasCollection(lastProcessedAtCollection)
	if err != nil {
		return fmt.Errorf("failed to check collection: %w", err)
	}
	if !exists {
		if err := s.db.CreateCollection(lastProcessedAtCollection); err != nil {
			return fmt.Errorf("failed to create metadata collection: %w", err)
		}
	}
	// Per-contract key; bare prefix means the global watermark.
	key := lastProcessedAtKeyPrefix
	if contractDID != "" {
		key = fmt.Sprintf("%s:%s", lastProcessedAtKeyPrefix, contractDID)
	}
	found, err := s.db.FindAll(query.NewQuery(lastProcessedAtCollection).Where(query.Field("key").Eq(key)))
	if err != nil {
		return fmt.Errorf("failed to query metadata: %w", err)
	}
	if len(found) == 0 {
		// First write for this key: create the metadata document.
		rec := document.NewDocument()
		rec.Set("key", key)
		if contractDID != "" {
			rec.Set("contract_did", contractDID)
		}
		rec.Set(lastProcessedAtKeyPrefix, t.Unix())
		if _, err := s.db.InsertOne(lastProcessedAtCollection, rec); err != nil {
			return fmt.Errorf("failed to insert last processed at: %w", err)
		}
		return nil
	}
	// Overwrite the timestamp on the existing document.
	rec := found[0]
	rec.Set(lastProcessedAtKeyPrefix, t.Unix())
	if err := s.db.ReplaceById(lastProcessedAtCollection, rec.ObjectId(), rec); err != nil {
		return fmt.Errorf("failed to update last processed at: %w", err)
	}
	return nil
}
// GetLastProcessedAt retrieves the last processed timestamp for a specific contract.
// If contractDID is empty, it retrieves the global timestamp.
// If no record exists (or the collection is missing), it returns Unix(0).
func (s *Store) GetLastProcessedAt(contractDID string) (time.Time, error) {
	exists, err := s.db.HasCollection(lastProcessedAtCollection)
	if err != nil {
		return time.Time{}, fmt.Errorf("failed to check metadata collection: %w", err)
	}
	if !exists {
		return time.Unix(0, 0), nil
	}
	// Per-contract key; bare prefix means the global watermark.
	key := lastProcessedAtKeyPrefix
	if contractDID != "" {
		key = fmt.Sprintf("%s:%s", lastProcessedAtKeyPrefix, contractDID)
	}
	found, err := s.db.FindAll(query.NewQuery(lastProcessedAtCollection).Where(query.Field("key").Eq(key)))
	if err != nil {
		return time.Time{}, fmt.Errorf("failed to query metadata: %w", err)
	}
	if len(found) > 0 {
		if ts, ok := found[0].Get(lastProcessedAtKeyPrefix).(int64); ok {
			return time.Unix(ts, 0), nil
		}
	}
	return time.Unix(0, 0), nil
}
// InitializeContractMetadata initializes the usage metadata for a new contract.
// This creates a metadata entry with the contract-specific key and sets the initial
// last processed timestamp to Unix(0). Existing entries are never overwritten.
func (s *Store) InitializeContractMetadata(contractDID string) error {
	if contractDID == "" {
		return errors.New("contractDID cannot be empty")
	}
	exists, err := s.db.HasCollection(lastProcessedAtCollection)
	if err != nil {
		return fmt.Errorf("failed to check collection: %w", err)
	}
	if !exists {
		if err := s.db.CreateCollection(lastProcessedAtCollection); err != nil {
			return fmt.Errorf("failed to create metadata collection: %w", err)
		}
	}
	key := fmt.Sprintf("%s:%s", lastProcessedAtKeyPrefix, contractDID)
	found, err := s.db.FindAll(query.NewQuery(lastProcessedAtCollection).Where(query.Field("key").Eq(key)))
	if err != nil {
		return fmt.Errorf("failed to query metadata: %w", err)
	}
	if len(found) > 0 {
		// Never reset an existing watermark.
		return nil
	}
	rec := document.NewDocument()
	rec.Set("key", key)
	rec.Set("contract_did", contractDID)
	rec.Set(lastProcessedAtKeyPrefix, time.Unix(0, 0).Unix())
	if _, err := s.db.InsertOne(lastProcessedAtCollection, rec); err != nil {
		return fmt.Errorf("failed to initialize contract metadata: %w", err)
	}
	return nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package types
import (
"fmt"
"strconv"
"strings"
)
// ConstructAllocationID joins an ensemble ID and allocation name into the
// canonical "<ensembleID>_<allocName>" allocation ID.
func ConstructAllocationID(ensembleID, allocName string) string {
	return fmt.Sprintf("%s_%s", ensembleID, allocName)
}
// AllocationNameFromID returns the allocation-name portion of an allocation
// ID: everything after the last underscore. An ID without an underscore is
// returned unchanged.
func AllocationNameFromID(id string) string {
	sep := strings.LastIndex(id, "_")
	return id[sep+1:]
}
// EnsembleIDFromAllocationID returns the ensemble portion of an allocation
// ID: everything before the last underscore. An ID without an underscore is
// returned unchanged.
func EnsembleIDFromAllocationID(id string) string {
	sep := strings.LastIndex(id, "_")
	if sep < 0 {
		return id
	}
	return id[:sep]
}
// AllocationIdentifier represents a structured allocation identifier
// that can be used consistently across the orchestrator
type AllocationIdentifier struct {
EnsembleID string
NodeID string
AllocationName string
IsStandby bool
StandbyIndex int
}
// String returns the full allocation ID for actor handles and
// cross-ensemble references.
// Format: ensembleID_nodeID.allocName or ensembleID_nodeID-standby-N.allocName
func (aid AllocationIdentifier) String() string {
	node := aid.NodeID
	// A standby whose NodeID is still the primary's name needs the
	// "-standby-N" suffix added; parsed IDs already carry it.
	if aid.IsStandby && !strings.Contains(node, "-standby-") {
		node = fmt.Sprintf("%s-standby-%d", node, aid.StandbyIndex)
	}
	return fmt.Sprintf("%s_%s.%s", aid.EnsembleID, node, aid.AllocationName)
}
// ManifestKey returns the key used in manifest.Allocations map.
// Format: nodeID.allocName or nodeID-standby-N.allocName
func (aid AllocationIdentifier) ManifestKey() string {
	node := aid.NodeID
	// Append the standby suffix only when NodeID does not already
	// contain it (parsed identifiers come pre-suffixed).
	if aid.IsStandby && !strings.Contains(node, "-standby-") {
		node = fmt.Sprintf("%s-standby-%d", node, aid.StandbyIndex)
	}
	return node + "." + aid.AllocationName
}
// ConfigName returns the base allocation name as written in the
// ensemble configuration. It is the same string for a primary
// allocation and all of its standbys.
func (aid AllocationIdentifier) ConfigName() string {
	return aid.AllocationName
}
// PrimaryNodeID returns the primary node ID, removing the "-standby-N"
// suffix when the NodeID carries one (e.g. identifiers produced by
// ParseAllocationID). The original implementation returned aid.NodeID
// unconditionally from both branches, contradicting its own contract.
func (aid AllocationIdentifier) PrimaryNodeID() string {
	if idx := strings.Index(aid.NodeID, "-standby-"); idx >= 0 {
		return aid.NodeID[:idx]
	}
	return aid.NodeID
}
// NewAllocationID creates a new AllocationIdentifier for a primary
// allocation. IsStandby and StandbyIndex are left at their zero values.
func NewAllocationID(ensembleID, nodeID, allocName string) AllocationIdentifier {
	return AllocationIdentifier{
		EnsembleID:     ensembleID,
		NodeID:         nodeID,
		AllocationName: allocName,
	}
}
// NewStandbyAllocationID creates a new AllocationIdentifier for a
// standby allocation. NodeID stays the primary node's ID; the
// "-standby-N" decoration is applied by String() and ManifestKey().
func NewStandbyAllocationID(ensembleID, primaryNodeID, allocName string, standbyIndex int) AllocationIdentifier {
	id := NewAllocationID(ensembleID, primaryNodeID, allocName)
	id.IsStandby = true
	id.StandbyIndex = standbyIndex
	return id
}
// ParseAllocationID parses a full allocation ID string into an AllocationIdentifier.
// Format: ensembleID_nodeID.allocName or ensembleID_nodeID-standby-N.allocName
func ParseAllocationID(id string) (AllocationIdentifier, error) {
	// The ensemble ID is everything before the first underscore.
	ensemble, remainder, found := strings.Cut(id, "_")
	if !found {
		return AllocationIdentifier{}, fmt.Errorf("invalid allocation ID format: %s", id)
	}
	switch {
	case strings.Contains(remainder, "-standby-"):
		// Standby form: nodeID-standby-N.allocName. Exactly one
		// "-standby-" marker is allowed.
		pieces := strings.Split(remainder, "-standby-")
		if len(pieces) != 2 {
			return AllocationIdentifier{}, fmt.Errorf("invalid standby allocation ID format: %s", id)
		}
		tail := strings.Split(pieces[1], ".")
		if len(tail) != 2 {
			return AllocationIdentifier{}, fmt.Errorf("invalid standby allocation ID format: %s", id)
		}
		// The segment between the marker and the dot must be an integer.
		index, err := strconv.Atoi(tail[0])
		if err != nil {
			return AllocationIdentifier{}, fmt.Errorf("invalid standby index in allocation ID: %s", id)
		}
		return AllocationIdentifier{
			EnsembleID:     ensemble,
			NodeID:         fmt.Sprintf("%s-standby-%d", pieces[0], index),
			AllocationName: tail[1],
			IsStandby:      true,
			StandbyIndex:   index,
		}, nil
	case strings.Contains(remainder, "standby"):
		// A bare "standby" without the dashed marker is malformed.
		return AllocationIdentifier{}, fmt.Errorf("invalid standby allocation ID format: %s", id)
	}
	// Primary form: nodeID.allocName (exactly one dot).
	nodeAlloc := strings.Split(remainder, ".")
	if len(nodeAlloc) != 2 {
		return AllocationIdentifier{}, fmt.Errorf("invalid allocation ID format: %s", id)
	}
	return AllocationIdentifier{
		EnsembleID:     ensemble,
		NodeID:         nodeAlloc[0],
		AllocationName: nodeAlloc[1],
	}, nil
}
// ParseManifestKey parses a manifest key into an AllocationIdentifier.
// Format: nodeID.allocName or nodeID-standby-N.allocName
// Note: the ensembleID is not part of the key and must be supplied.
func ParseManifestKey(manifestKey, ensembleID string) (AllocationIdentifier, error) {
	if strings.Contains(manifestKey, "-standby-") {
		// Standby form: nodeID-standby-N.allocName, with exactly one marker.
		pieces := strings.Split(manifestKey, "-standby-")
		if len(pieces) != 2 {
			return AllocationIdentifier{}, fmt.Errorf("invalid standby manifest key format: %s", manifestKey)
		}
		tail := strings.Split(pieces[1], ".")
		if len(tail) != 2 {
			return AllocationIdentifier{}, fmt.Errorf("invalid standby manifest key format: %s", manifestKey)
		}
		index, err := strconv.Atoi(tail[0])
		if err != nil {
			return AllocationIdentifier{}, fmt.Errorf("invalid standby index in manifest key: %s", manifestKey)
		}
		return AllocationIdentifier{
			EnsembleID:     ensembleID,
			NodeID:         fmt.Sprintf("%s-standby-%d", pieces[0], index),
			AllocationName: tail[1],
			IsStandby:      true,
			StandbyIndex:   index,
		}, nil
	}
	// Primary form: nodeID.allocName (exactly one dot).
	parts := strings.Split(manifestKey, ".")
	if len(parts) != 2 {
		return AllocationIdentifier{}, fmt.Errorf("invalid manifest key format: %s", manifestKey)
	}
	return AllocationIdentifier{
		EnsembleID:     ensembleID,
		NodeID:         parts[0],
		AllocationName: parts[1],
	}, nil
}
// ParseNodeName parses a node name to determine whether it refers to a
// standby node. Returns (isStandby, primaryNodeID, standbyIndex); a
// name whose suffix is not a valid integer is treated as a plain node.
func ParseNodeName(nodeName string) (bool, string, int) {
	base, suffix, found := strings.Cut(nodeName, "-standby-")
	if found {
		if index, err := strconv.Atoi(suffix); err == nil {
			return true, base, index
		}
	}
	return false, nodeName, 0
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package types
import (
"fmt"
"reflect"
"slices"
"github.com/hashicorp/go-version"
)
// Connectivity represents the network configuration
type Connectivity struct {
	// Ports that must be open for the job.
	Ports []int `json:"ports" description:"Ports that need to be open for the job to run"`
	// VPN indicates whether a VPN is required.
	VPN bool `json:"vpn" description:"Whether VPN is required"`
}
// implementing Comparable and Calculable interfaces
var (
	_ Comparable[Connectivity] = (*Connectivity)(nil)
	_ Calculable[Connectivity] = (*Connectivity)(nil)
)
// Compare reports how this network offering (left) relates to a
// requirement (other/right): Equal for identical configs, Better when
// the offered ports strictly contain the required ones and VPN is
// offered, None otherwise.
func (c *Connectivity) Compare(other Connectivity) (Comparison, error) {
	if reflect.DeepEqual(*c, other) {
		return Equal, nil
	}
	// The original condition (c.VPN && other.VPN || c.VPN && !other.VPN)
	// reduces to c.VPN alone.
	// NOTE(review): other.VPN is effectively ignored here — a requirement
	// without VPN still demands c.VPN to rank Better; confirm intended.
	if c.VPN && IsStrictlyContainedInt(c.Ports, other.Ports) {
		return Better, nil
	}
	return None, nil
}
// Add merges another Connectivity into this one: the port list becomes
// the union (preserving existing order) and VPN is sticky — once either
// side needs it, the result needs it. Never returns an error.
func (c *Connectivity) Add(other Connectivity) error {
	seen := make(map[int]struct{}, len(c.Ports))
	for _, p := range c.Ports {
		seen[p] = struct{}{}
	}
	for _, p := range other.Ports {
		if _, dup := seen[p]; dup {
			continue
		}
		c.Ports = append(c.Ports, p)
	}
	c.VPN = c.VPN || other.VPN
	return nil
}
// Subtract removes another Connectivity from this one: ports present in
// other are dropped, and a subtracted VPN requirement clears the VPN
// flag. An empty result sets Ports to nil. Never returns an error.
func (c *Connectivity) Subtract(other Connectivity) error {
	if other.VPN {
		c.VPN = false
	}
	kept := c.Ports[:0] // filter in place over the existing backing array
	for _, p := range c.Ports {
		if slices.Contains(other.Ports, p) {
			continue
		}
		kept = append(kept, p)
	}
	if len(kept) == 0 {
		c.Ports = nil
		return nil
	}
	c.Ports = kept
	return nil
}
// TimeInformation represents the time constraints
type TimeInformation struct {
	// Units of MaxTime: "seconds", "minutes", "hours" or "days"
	// (see TotalTime; any other value is treated as-is).
	Units string `json:"units" description:"Time units"`
	// MaxTime is the maximum run time, expressed in Units.
	MaxTime int `json:"max_time" description:"Maximum time that job should run"`
	// Preference weight; not used by Compare in this file.
	Preference int `json:"preference" description:"Time preference"`
}
// implementing Comparable and Calculable interfaces
var (
	_ Comparable[TimeInformation] = (*TimeInformation)(nil)
	_ Calculable[TimeInformation] = (*TimeInformation)(nil)
)
// Add accumulates another TimeInformation into this one. MaxTime is the
// only additive field today; extend here if more time fields (MinTime,
// AvgTime, ...) are introduced. Never returns an error.
func (t *TimeInformation) Add(other TimeInformation) error {
	t.MaxTime = t.MaxTime + other.MaxTime
	return nil
}
// Subtract removes another TimeInformation from this one. MaxTime is
// the only field affected today; extend here if more time fields are
// introduced. Never returns an error.
func (t *TimeInformation) Subtract(other TimeInformation) error {
	t.MaxTime = t.MaxTime - other.MaxTime
	return nil
}
// TotalTime converts MaxTime to seconds according to the declared
// Units. Unrecognized units fall back to returning MaxTime unchanged.
func (t *TimeInformation) TotalTime() int {
	switch t.Units {
	case "days":
		return t.MaxTime * 86400
	case "hours":
		return t.MaxTime * 3600
	case "minutes":
		return t.MaxTime * 60
	case "seconds":
		return t.MaxTime
	}
	return t.MaxTime
}
// Compare ranks two TimeInformation values by their total time in
// seconds: a larger total on the left is Better, a smaller one Worse.
// Never returns a non-nil error.
func (t *TimeInformation) Compare(other TimeInformation) (Comparison, error) {
	// Compare the dereferenced value: the original called
	// reflect.DeepEqual(t, other) with *TimeInformation vs
	// TimeInformation, which is always false (dead fast path).
	if reflect.DeepEqual(*t, other) {
		return Equal, nil
	}
	ownTotalTime := t.TotalTime()
	otherTotalTime := other.TotalTime()
	switch {
	case ownTotalTime == otherTotalTime:
		return Equal, nil
	case ownTotalTime < otherTotalTime:
		return Worse, nil
	default:
		return Better, nil
	}
}
// PriceInformation represents the pricing information
type PriceInformation struct {
	// Currency used for pricing; Compare only ranks prices in the same currency.
	Currency string `json:"currency" description:"Currency used for pricing"`
	CurrencyPerHour int `json:"currency_per_hour" description:"Price charged per hour"`
	TotalPerJob int `json:"total_per_job" description:"Maximum total price or budget of the job"`
	// Preference weight; checked by Equal but not by Compare's ranking.
	Preference int `json:"preference" description:"Pricing preference"`
}
// implementing Comparable interface
var _ Comparable[PriceInformation] = (*PriceInformation)(nil)
// Compare ranks two prices: lower budget/hourly rate on the left is
// Better. Prices in different currencies are incomparable (None).
func (p *PriceInformation) Compare(other PriceInformation) (Comparison, error) {
	// Compare the dereferenced value: the original called
	// reflect.DeepEqual(p, other) with pointer vs value, which can never
	// be true, making the fast path dead code.
	if reflect.DeepEqual(*p, other) {
		return Equal, nil
	}
	if p.Currency != other.Currency {
		return None, nil
	}
	if p.TotalPerJob == other.TotalPerJob {
		switch {
		case p.CurrencyPerHour == other.CurrencyPerHour:
			return Equal, nil
		case p.CurrencyPerHour < other.CurrencyPerHour:
			return Better, nil
		default:
			return Worse, nil
		}
	}
	// A lower total is Better only if the hourly rate is not higher too.
	if p.TotalPerJob < other.TotalPerJob && p.CurrencyPerHour <= other.CurrencyPerHour {
		return Better, nil
	}
	return Worse, nil
}
// Equal reports whether every field of p matches price exactly,
// including the Preference weight.
func (p *PriceInformation) Equal(price PriceInformation) bool {
	return *p == price
}
// Library represents the libraries
type Library struct {
	Name string `json:"name" description:"Name of the library"`
	// Constraint is a version-constraint operator string (e.g. "=", ">=")
	// combined with Version when checking another library against this one.
	Constraint string `json:"constraint" description:"Constraint of the library"`
	Version string `json:"version" description:"Version of the library"`
}
// implementing Comparable interface
var _ Comparable[Library] = (*Library)(nil)
// Compare checks this library's version (left) against another library's
// constraint (right). Differently named libraries are incomparable
// (None); a satisfied "=" constraint is Equal, any other satisfied
// constraint is Better, and an unsatisfied one is Worse.
func (lib *Library) Compare(other Library) (Comparison, error) {
	// Check the names before doing any version parsing so that entries
	// for unrelated libraries never fail on unparsable version strings
	// (the original parsed first and could error out for mismatched names).
	if lib.Name != other.Name {
		return None, nil
	}
	ownVersion, err := version.NewVersion(lib.Version)
	if err != nil {
		return None, fmt.Errorf("error parsing version: %v", err)
	}
	constraints, err := version.NewConstraint(other.Constraint + " " + other.Version)
	if err != nil {
		return None, fmt.Errorf("error parsing constraint: %v", err)
	}
	// An exact-match constraint that is satisfied means equal versions.
	if other.Constraint == "=" && constraints.Check(ownVersion) {
		return Equal, nil
	}
	// Any other satisfied constraint means our version suffices.
	if constraints.Check(ownVersion) {
		return Better, nil
	}
	return Worse, nil
}
// Equal reports whether both libraries have identical name, constraint
// and version.
func (lib *Library) Equal(library Library) bool {
	return *lib == library
}
// Locality represents the locality
type Locality struct {
	Kind string `json:"kind" description:"Kind of the region (geographic, nunet-defined, etc)"`
	Name string `json:"name" description:"Name of the region"`
}
// implementing Comparable interface
var _ Comparable[Locality] = (*Locality)(nil)
// Compare relates two localities: different kinds are incomparable
// (None), same kind and name is Equal, same kind with a different name
// is Worse. Never returns a non-nil error.
func (loc *Locality) Compare(other Locality) (Comparison, error) {
	switch {
	case loc.Kind != other.Kind:
		return None, nil
	case loc.Name == other.Name:
		return Equal, nil
	default:
		return Worse, nil
	}
}
// Equal reports whether both localities have identical kind and name.
func (loc *Locality) Equal(locality Locality) bool {
	return *loc == locality
}
// KYC represents the KYC data
type KYC struct {
	Type string `json:"type" description:"Type of KYC"`
	Data string `json:"data" description:"Data required for KYC"`
}
// implementing Comparable interface
var _ Comparable[KYC] = (*KYC)(nil)
// Compare relates two KYC entries: identical entries are Equal, any
// other pair is incomparable (None). Never returns a non-nil error.
func (k *KYC) Compare(other KYC) (Comparison, error) {
	if *k == other {
		return Equal, nil
	}
	return None, nil
}
// Equal reports whether both KYC entries have identical type and data.
func (k *KYC) Equal(kyc KYC) bool {
	return *k == kyc
}
// JobType represents the type of the job
type JobType string
// Recognized job type values.
const (
	Batch JobType = "batch"
	SingleRun JobType = "single_run"
	Recurring JobType = "recurring"
	LongRunning JobType = "long_running"
)
// implementing Comparable and Calculable interface
var _ Comparable[JobType] = (*JobType)(nil)
// Compare relates two job types: identical values are Equal, anything
// else is incomparable (None). Never returns a non-nil error.
func (j JobType) Compare(other JobType) (Comparison, error) {
	if j == other {
		return Equal, nil
	}
	return None, nil
}
// JobTypes a slice of JobType
// Add/Subtract treat it as a set for incoming entries, though existing
// duplicates within the slice itself are preserved.
type JobTypes []JobType
// implementing Comparable and Calculable interfaces
var (
	_ Comparable[JobTypes] = (*JobTypes)(nil)
	_ Calculable[JobTypes] = (*JobTypes)(nil)
)
// Add appends job types from other that are not already present.
// Existing entries (including any duplicates already in *j) are kept
// untouched. Never returns an error.
func (j *JobTypes) Add(other JobTypes) error {
	merged := (*j)[:0] // rebuild over the existing backing array
	seen := make(map[JobType]struct{}, len(*j))
	// Keep every current entry while recording what we have.
	for _, jt := range *j {
		merged = append(merged, jt)
		seen[jt] = struct{}{}
	}
	// Append only job types not seen yet.
	for _, jt := range other {
		if _, dup := seen[jt]; !dup {
			merged = append(merged, jt)
			seen[jt] = struct{}{}
		}
	}
	*j = merged[:len(merged):len(merged)] // clamp capacity to length
	return nil
}
// Subtract removes every job type present in other from this slice,
// filtering in place. Never returns an error.
func (j *JobTypes) Subtract(other JobTypes) error {
	drop := make(map[JobType]struct{}, len(other))
	for _, jt := range other {
		drop[jt] = struct{}{}
	}
	kept := (*j)[:0] // filter in place over the existing backing array
	for _, jt := range *j {
		if _, remove := drop[jt]; !remove {
			kept = append(kept, jt)
		}
	}
	*j = kept[:len(kept):len(kept)] // clamp capacity to length
	return nil
}
// Compare relates available job types (left) to required ones (right)
// via set containment: Equal when identical, Better when the left side
// covers everything required, Worse for the reverse, None otherwise.
func (j *JobTypes) Compare(other JobTypes) (Comparison, error) {
	// The helpers operate on untyped slices, so convert first.
	left := ConvertTypedSliceToUntypedSlice(*j)
	right := ConvertTypedSliceToUntypedSlice(other)
	if !IsSameShallowType(left, right) {
		// Different element types cannot be compared.
		return None, nil
	}
	if reflect.DeepEqual(left, right) {
		// Available capabilities exactly match required ones.
		return Equal, nil
	}
	if IsStrictlyContained(left, right) {
		// Available capabilities include all required ones.
		return Better, nil
	}
	if IsStrictlyContained(right, left) {
		// Required capabilities exceed what is available.
		return Worse, nil
	}
	// TODO: this comparator does not take into account options when several
	// job types are available and several are required in the same data
	// structure; this is why the test fails.
	return None, nil
}
// Contains reports whether jobType is present in the slice.
func (j *JobTypes) Contains(jobType JobType) bool {
	return slices.Contains(*j, jobType)
}
// Libraries is a collection of Library entries; Add/Subtract treat it
// as a set keyed by exact struct equality.
type Libraries []Library
// implementing Comparable and Calculable interfaces
var (
	_ Comparable[Libraries] = (*Libraries)(nil)
	_ Calculable[Libraries] = (*Libraries)(nil)
)
// Add appends libraries from other that are not already present
// (exact struct equality). Never returns an error.
func (l *Libraries) Add(other Libraries) error {
	seen := make(map[Library]struct{}, len(*l))
	for _, lib := range *l {
		seen[lib] = struct{}{}
	}
	for _, candidate := range other {
		if _, dup := seen[candidate]; dup {
			continue
		}
		*l = append(*l, candidate)
		seen[candidate] = struct{}{}
	}
	return nil
}
// Subtract removes every library present in other (exact struct
// equality), filtering in place. Never returns an error.
func (l *Libraries) Subtract(other Libraries) error {
	drop := make(map[Library]struct{}, len(other))
	for _, lib := range other {
		drop[lib] = struct{}{}
	}
	kept := (*l)[:0] // filter in place, reusing the backing array
	for _, lib := range *l {
		if _, remove := drop[lib]; !remove {
			kept = append(kept, lib)
		}
	}
	*l = kept
	return nil
}
// Compare relates this set of libraries (left/own) to a required set
// (right/other) by pairwise comparison and best-match selection.
// Result: Worse if any requirement has only Worse matches, Equal if all
// best matches are Equal, Better otherwise.
func (l *Libraries) Compare(other Libraries) (Comparison, error) {
	// interimComparison1[i][j] is the comparison of own library j against
	// required library i. (The comments below mention GPUs because this
	// matching scheme was reused from the GPU comparator.)
	interimComparison1 := make([][]Comparison, 0)
	for _, otherLibrary := range other {
		var interimComparison2 []Comparison
		for _, ownLibrary := range *l {
			c, err := ownLibrary.Compare(otherLibrary)
			if err != nil {
				return None, fmt.Errorf("error comparing library: %v", err)
			}
			interimComparison2 = append(interimComparison2, c)
		}
		// this matrix structure will hold the comparison results for each GPU on the right
		// with each GPU on the left in the order they are in the slices
		// first dimension represents left GPUs
		// second dimension represents right GPUs
		interimComparison1 = append(interimComparison1, interimComparison2)
	}
	// we can now implement a logic to figure out if each required GPU on the left has a matching GPU on the right
	finalComparison := make([]Comparison, 0)
	for i := 0; i < len(interimComparison1); i++ {
		// we need to find the best match for each GPU on the right
		// NOTE(review): this bound check compares a row's width against the
		// row index, and removeIndex mutates the matrix while it is being
		// indexed by i — confirm against returnBestMatch/removeIndex
		// semantics whether this is the intended matching behavior.
		if len(interimComparison1[i]) < i {
			break
		}
		c := interimComparison1[i]
		bestMatch, index := returnBestMatch(c)
		finalComparison = append(finalComparison, bestMatch)
		interimComparison1 = removeIndex(interimComparison1, index)
	}
	if slices.Contains(finalComparison, Worse) {
		return Worse, nil
	}
	// SliceContainsOneValue presumably means "every element equals the
	// given value" — TODO confirm against its definition.
	if SliceContainsOneValue(finalComparison, Equal) {
		return Equal, nil
	}
	return Better, nil
}
// Contains reports whether an entry equal to library exists in the slice.
func (l *Libraries) Contains(library Library) bool {
	return slices.ContainsFunc(*l, func(lib Library) bool {
		return lib.Equal(library)
	})
}
// Localities is a collection of Locality entries; Add/Subtract treat it
// as a set keyed by exact struct equality.
type Localities []Locality
// implementing Comparable and Calculable interfaces
var (
	_ Comparable[Localities] = (*Localities)(nil)
	_ Calculable[Localities] = (*Localities)(nil)
)
// Compare relates owned localities (left) to required ones (right).
// Every required locality gets a verdict: None when no locality of the
// same kind exists, otherwise the pairwise Compare result (when several
// own localities share a kind, the last one wins, as before). The
// verdicts fold to Worse if any is Worse, Equal if all are Equal, and
// Better otherwise.
func (l *Localities) Compare(other Localities) (Comparison, error) {
	finalComparison := make([]Comparison, 0, len(other))
	for _, required := range other {
		// Default to None so a verdict exists even when the slice
		// dimensions do not match up.
		verdict := None
		for _, own := range *l {
			if own.Kind != required.Kind {
				continue
			}
			c, err := own.Compare(required)
			if err != nil {
				return None, fmt.Errorf("error comparing locality: %v", err)
			}
			verdict = c
		}
		finalComparison = append(finalComparison, verdict)
	}
	if slices.Contains(finalComparison, Worse) {
		return Worse, nil
	}
	if SliceContainsOneValue(finalComparison, Equal) {
		return Equal, nil
	}
	return Better, nil
}
// Add appends localities from other that are not already present
// (exact struct equality). Never returns an error.
func (l *Localities) Add(other Localities) error {
	seen := make(map[Locality]struct{}, len(*l))
	for _, loc := range *l {
		seen[loc] = struct{}{}
	}
	for _, candidate := range other {
		if _, dup := seen[candidate]; dup {
			continue
		}
		*l = append(*l, candidate)
		seen[candidate] = struct{}{}
	}
	return nil
}
// Subtract removes every locality present in other (exact struct
// equality), filtering in place. Never returns an error.
func (l *Localities) Subtract(other Localities) error {
	drop := make(map[Locality]struct{}, len(other))
	for _, loc := range other {
		drop[loc] = struct{}{}
	}
	kept := (*l)[:0] // filter in place, reusing the backing array
	for _, loc := range *l {
		if _, remove := drop[loc]; !remove {
			kept = append(kept, loc)
		}
	}
	*l = kept
	return nil
}
// Contains reports whether an entry equal to locality exists in the slice.
func (l *Localities) Contains(locality Locality) bool {
	for i := range *l {
		if (*l)[i].Equal(locality) {
			return true
		}
	}
	return false
}
// KYCs is a collection of KYC entries; Add/Subtract treat it as a set
// keyed by exact struct equality.
type KYCs []KYC
// implementing Comparable and Calculable interfaces
var (
	_ Comparable[KYCs] = (*KYCs)(nil)
	_ Calculable[KYCs] = (*KYCs)(nil)
)
// Compare relates owned KYC entries (left) to required ones (right):
// Equal for identical slices, Better when we hold KYC data but none is
// required, otherwise the first decisive pairwise result, or None.
func (k *KYCs) Compare(other KYCs) (Comparison, error) {
	if reflect.DeepEqual(*k, other) {
		return Equal, nil
	}
	if len(other) == 0 && len(*k) != 0 {
		// We hold KYC data while none is required.
		return Better, nil
	}
	for _, own := range *k {
		for _, required := range other {
			verdict, err := own.Compare(required)
			if err != nil {
				return None, fmt.Errorf("error comparing KYC: %v", err)
			}
			if verdict != None {
				return verdict, nil
			}
		}
	}
	return None, nil
}
// Add appends KYC entries from other that are not already present
// (exact struct equality). Never returns an error.
func (k *KYCs) Add(other KYCs) error {
	seen := make(map[KYC]struct{}, len(*k))
	for _, entry := range *k {
		seen[entry] = struct{}{}
	}
	for _, candidate := range other {
		if _, dup := seen[candidate]; dup {
			continue
		}
		*k = append(*k, candidate)
		seen[candidate] = struct{}{}
	}
	return nil
}
// Subtract removes every KYC entry present in other (exact struct
// equality), filtering in place. Never returns an error.
func (k *KYCs) Subtract(other KYCs) error {
	drop := make(map[KYC]struct{}, len(other))
	for _, entry := range other {
		drop[entry] = struct{}{}
	}
	kept := (*k)[:0] // filter in place, reusing the backing array
	for _, entry := range *k {
		if _, remove := drop[entry]; !remove {
			kept = append(kept, entry)
		}
	}
	*k = kept
	return nil
}
// Contains reports whether an entry equal to kyc exists in the slice.
func (k *KYCs) Contains(kyc KYC) bool {
	for i := range *k {
		if (*k)[i].Equal(kyc) {
			return true
		}
	}
	return false
}
// PricesInformation is a collection of PriceInformation entries;
// Add/Subtract treat it as a set keyed by exact struct equality.
type PricesInformation []PriceInformation
// implementing Comparable and Calculable interfaces
var (
	_ Comparable[PricesInformation] = (*PricesInformation)(nil)
	_ Calculable[PricesInformation] = (*PricesInformation)(nil)
)
// Add appends price entries from other that are not already present
// (exact struct equality). Never returns an error.
func (ps *PricesInformation) Add(other PricesInformation) error {
	seen := make(map[PriceInformation]struct{}, len(*ps))
	for _, price := range *ps {
		seen[price] = struct{}{}
	}
	for _, candidate := range other {
		if _, dup := seen[candidate]; dup {
			continue
		}
		*ps = append(*ps, candidate)
		seen[candidate] = struct{}{}
	}
	return nil
}
// Subtract removes every price entry present in other (exact struct
// equality), filtering in place. Never returns an error.
func (ps *PricesInformation) Subtract(other PricesInformation) error {
	if len(other) == 0 {
		return nil // nothing to remove
	}
	drop := make(map[PriceInformation]struct{}, len(other))
	for _, price := range other {
		drop[price] = struct{}{}
	}
	kept := (*ps)[:0] // filter in place, reusing the backing array
	for _, price := range *ps {
		if _, remove := drop[price]; !remove {
			kept = append(kept, price)
		}
	}
	*ps = kept[:len(kept):len(kept)] // clamp capacity to length
	return nil
}
// Compare relates owned price entries (left) to required ones (right):
// Equal for identical slices, otherwise the first decisive pairwise
// comparison, or None when nothing is comparable.
func (ps *PricesInformation) Compare(other PricesInformation) (Comparison, error) {
	if reflect.DeepEqual(*ps, other) {
		return Equal, nil
	}
	for _, own := range *ps {
		for _, required := range other {
			verdict, err := own.Compare(required)
			if err != nil {
				return None, fmt.Errorf("error comparing price: %v", err)
			}
			if verdict != None {
				return verdict, nil
			}
		}
	}
	return None, nil
}
// Contains reports whether an entry equal to price exists in the slice.
func (ps *PricesInformation) Contains(price PriceInformation) bool {
	return slices.ContainsFunc(*ps, func(p PriceInformation) bool {
		return p.Equal(price)
	})
}
// HardwareCapability represents the hardware capability of the machine
type HardwareCapability struct {
	Executors Executors `json:"executor" description:"Executor type required for the job (docker, vm, wasm, or others)"`
	JobTypes JobTypes `json:"type" description:"Details about type of the job (One time, batch, recurring, long running). Refer to dms.jobs package for jobType data model"`
	Resources Resources `json:"resources" description:"Resources required for the job"`
	Libraries Libraries `json:"libraries" description:"Libraries required for the job"`
	Localities Localities `json:"locality" description:"Preferred localities of the machine for execution"`
	Connectivity Connectivity `json:"connectivity" description:"Network configuration required"`
	Price PricesInformation `json:"price" description:"Pricing information"`
	Time TimeInformation `json:"time" description:"Time constraints"`
	// KYCs carries KYC attestations. Note: no json tag, so it marshals
	// under the field name "KYCs" — confirm whether that is intended.
	KYCs KYCs
}
// implementing Comparable and Calculable interfaces
var (
	_ Comparable[HardwareCapability] = (*HardwareCapability)(nil)
	_ Calculable[HardwareCapability] = (*HardwareCapability)(nil)
)
// Compare compares two HardwareCapability objects field by field (in a
// fixed order, so the first failing field determines the error) and
// folds the per-field verdicts into one Comparison via
// ComplexComparison.Result.
func (c *HardwareCapability) Compare(other HardwareCapability) (Comparison, error) {
	type step struct {
		field string
		run   func() (Comparison, error)
	}
	steps := []step{
		{"Executors", func() (Comparison, error) { return c.Executors.Compare(other.Executors) }},
		{"JobTypes", func() (Comparison, error) { return c.JobTypes.Compare(other.JobTypes) }},
		{"Resources", func() (Comparison, error) { return c.Resources.Compare(other.Resources) }},
		{"Libraries", func() (Comparison, error) { return c.Libraries.Compare(other.Libraries) }},
		{"Localities", func() (Comparison, error) { return c.Localities.Compare(other.Localities) }},
		{"Connectivity", func() (Comparison, error) { return c.Connectivity.Compare(other.Connectivity) }},
		{"Price", func() (Comparison, error) { return c.Price.Compare(other.Price) }},
		{"Time", func() (Comparison, error) { return c.Time.Compare(other.Time) }},
		{"KYCs", func() (Comparison, error) { return c.KYCs.Compare(other.KYCs) }},
	}
	outcome := make(ComplexComparison, len(steps))
	for i := range steps {
		verdict, err := steps[i].run()
		if err != nil {
			return None, fmt.Errorf("error comparing %s: %v", steps[i].field, err)
		}
		outcome[steps[i].field] = verdict
	}
	return outcome.Result(), nil
}
// Add adds the resources of the given HardwareCapability to the current
// HardwareCapability by delegating to each field's Add implementation.
// It stops at the first failing field.
func (c *HardwareCapability) Add(other HardwareCapability) error {
	// Executors
	if err := c.Executors.Add(other.Executors); err != nil {
		// Include the underlying cause (the original message dropped it).
		return fmt.Errorf("error adding Executors: %v", err)
	}
	// JobTypes
	if err := c.JobTypes.Add(other.JobTypes); err != nil {
		return fmt.Errorf("error adding JobTypes: %v", err)
	}
	// Resources — propagated unwrapped, as before, so callers matching
	// on resource errors keep working.
	if err := c.Resources.Add(other.Resources); err != nil {
		return err
	}
	// Libraries
	if err := c.Libraries.Add(other.Libraries); err != nil {
		return fmt.Errorf("error adding Libraries: %v", err)
	}
	// Localities
	if err := c.Localities.Add(other.Localities); err != nil {
		return fmt.Errorf("error adding Localities: %v", err)
	}
	// Connectivity
	if err := c.Connectivity.Add(other.Connectivity); err != nil {
		return fmt.Errorf("error adding Connectivity: %v", err)
	}
	// Price
	if err := c.Price.Add(other.Price); err != nil {
		return fmt.Errorf("error adding Price: %v", err)
	}
	// Time
	if err := c.Time.Add(other.Time); err != nil {
		return fmt.Errorf("error adding Time: %v", err)
	}
	// KYCs
	if err := c.KYCs.Add(other.KYCs); err != nil {
		return fmt.Errorf("error adding KYCs: %v", err)
	}
	return nil
}
// Subtract subtracts the resources of the given HardwareCapability from
// the current HardwareCapability by delegating to each field's Subtract
// implementation. It stops at the first failing field.
func (c *HardwareCapability) Subtract(hwCap HardwareCapability) error {
	// Executors
	if err := c.Executors.Subtract(hwCap.Executors); err != nil {
		return fmt.Errorf("error subtracting Executors: %v", err)
	}
	// JobTypes (message said "comparing" before — fixed for accuracy)
	if err := c.JobTypes.Subtract(hwCap.JobTypes); err != nil {
		return fmt.Errorf("error subtracting JobTypes: %v", err)
	}
	// Resources
	if err := c.Resources.Subtract(hwCap.Resources); err != nil {
		return fmt.Errorf("error subtracting Resources: %v", err)
	}
	// Libraries
	if err := c.Libraries.Subtract(hwCap.Libraries); err != nil {
		return fmt.Errorf("error subtracting Libraries: %v", err)
	}
	// Localities
	if err := c.Localities.Subtract(hwCap.Localities); err != nil {
		return fmt.Errorf("error subtracting Localities: %v", err)
	}
	// Connectivity
	if err := c.Connectivity.Subtract(hwCap.Connectivity); err != nil {
		return fmt.Errorf("error subtracting Connectivity: %v", err)
	}
	// Price
	if err := c.Price.Subtract(hwCap.Price); err != nil {
		return fmt.Errorf("error subtracting Price: %v", err)
	}
	// Time
	if err := c.Time.Subtract(hwCap.Time); err != nil {
		return fmt.Errorf("error subtracting Time: %v", err)
	}
	// KYCs
	if err := c.KYCs.Subtract(hwCap.KYCs); err != nil {
		return fmt.Errorf("error subtracting KYCs: %v", err)
	}
	return nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package types
import "reflect"
// Comparable public Comparable interface to be enforced on types that can be compared
type Comparable[T any] interface {
	// Compare relates the receiver (left) to other (right) as
	// Better/Worse/Equal/None.
	Compare(other T) (Comparison, error)
}
// Calculable is implemented by capability types whose values can be
// accumulated and released with Add/Subtract.
type Calculable[T any] interface {
	Add(other T) error
	Subtract(other T) error
}
// PreferenceString labels how strongly a preference binds.
type PreferenceString string
const (
	// Hard and Soft are the recognized preference strengths; their
	// enforcement semantics live with the consumers of Preference
	// (not visible in this file).
	Hard PreferenceString = "Hard"
	Soft PreferenceString = "Soft"
)
// Preference describes how values of a named type should be compared.
type Preference struct {
	// TypeName names the type this preference applies to.
	TypeName string
	// Strength is Hard or Soft.
	Strength PreferenceString
	// DefaultComparatorOverride, when set, replaces the default comparator.
	DefaultComparatorOverride Comparator
}
// Comparator compares l against r, optionally guided by preferences,
// and returns the resulting Comparison.
type Comparator func(l, r any, preference ...Preference) Comparison
// ComplexCompare helper function to return a complex comparison of two complex types
// this uses reflection, could become a performance bottleneck
// with generics, we wouldn't need this function and could use the ComplexCompare method directly
//
// The result maps each field name of l to the Comparison of that field
// against the corresponding field of r; fields that cannot be compared
// are reported as None. Non-struct inputs yield an empty map.
func ComplexCompare(l, r any) ComplexComparison {
	// Complex comparison is a comparison of two complex types
	// which have nested fields that need to be considered together
	// before a final comparison for the whole complex type can be made.
	complexComparison := make(map[string]Comparison)
	val1 := reflect.ValueOf(l)
	val2 := reflect.ValueOf(r)
	// handle pointers
	if val1.Kind() == reflect.Ptr {
		val1 = val1.Elem()
	}
	if val2.Kind() == reflect.Ptr {
		val2 = val2.Elem()
	}
	// ensure we're working with struct types
	if val1.Kind() != reflect.Struct || val2.Kind() != reflect.Struct {
		// Question: should return error?
		return complexComparison
	}
	// Guard against mismatched struct types: indexing val2 with val1's
	// field layout would otherwise read unrelated fields or panic.
	if val1.Type() != val2.Type() {
		return complexComparison
	}
	for i := range val1.NumField() {
		fieldName := val1.Type().Field(i).Name
		field1 := val1.Field(i)
		field2 := val2.Field(i)
		// compare primitive types directly
		switch field1.Kind() {
		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
			complexComparison[fieldName] = NumericComparator(field1.Int(), field2.Int())
			continue
		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
			complexComparison[fieldName] = NumericComparator(field1.Uint(), field2.Uint())
			continue
		case reflect.Float32, reflect.Float64:
			complexComparison[fieldName] = NumericComparator(field1.Float(), field2.Float())
			continue
		case reflect.String:
			complexComparison[fieldName] = LiteralComparator(field1.String(), field2.String())
			continue
		}
		// for struct fields, try to use their Compare method if available
		if field1.Kind() == reflect.Struct {
			// Compare methods in this package use pointer receivers, so
			// take the address (or copy into an addressable value).
			var field1Ptr reflect.Value
			if field1.CanAddr() {
				field1Ptr = field1.Addr()
			} else {
				ptr := reflect.New(field1.Type())
				ptr.Elem().Set(field1)
				field1Ptr = ptr
			}
			compareMethod := field1Ptr.MethodByName("Compare")
			// Only call methods with the expected Compare(T) signature.
			if compareMethod.IsValid() &&
				compareMethod.Type().NumIn() == 1 &&
				compareMethod.Type().In(0) == field1.Type() {
				result := compareMethod.Call([]reflect.Value{field2})
				if len(result) > 0 {
					comp, ok := result[0].Interface().(Comparison)
					if ok {
						complexComparison[fieldName] = comp
						continue
					}
				}
			}
		}
		// default to None if no comparison could be made
		complexComparison[fieldName] = None
	}
	return complexComparison
}
// NumericComparator compares numeric capability values where the left
// side is the machine's capability and the right side the requirement:
// a larger left value is Better, a smaller one Worse.
func NumericComparator[T float64 | float32 | int | int32 | int64 | uint64](l, r T, _ ...Preference) Comparison {
	if l == r {
		return Equal
	}
	if l < r {
		return Worse
	}
	return Better
}
// LiteralComparator compares two literal (string-like) capability values.
// The left argument represents machine capabilities and the right argument
// represents required capabilities; literals can only match (Equal) or not
// match (None) — there is no ordering between distinct values.
func LiteralComparator[T ~string](l, r T, _ ...Preference) Comparison {
	if l != r {
		return None
	}
	return Equal
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package types
// Comparison is a type for comparison results
type Comparison string

const (
	// Worse means left object is 'worse' than right object
	Worse Comparison = "Worse"
	// Better means left object is 'better' than right object
	Better Comparison = "Better"
	// Equal means objects on the left and right are 'equally good'
	Equal Comparison = "Equal"
	// None means comparison could not be performed
	None Comparison = "None"
)

// And returns the result of AND operation of two Comparison values
// it respects the following table of truth:
// | AND    | Better | Worse  | Equal  | None   |
// | ------ | ------ |--------|--------|--------|
// | Better | Better | Worse  | Better | None   |
// | Worse  | Worse  | Worse  | Worse  | None   |
// | Equal  | Better | Worse  | Equal  | None   |
// | None   | None   | None   | None   | None   |
//
// The table is a precedence lattice: None absorbs everything, then Worse,
// then Better, then Equal. Implementing it as precedence checks fixes a
// defect where Worse.And(None) returned Worse (contradicting the table and
// making And non-commutative, so ComplexComparison.Result depended on map
// iteration order when both Worse and None were present).
func (c Comparison) And(cmp Comparison) Comparison {
	switch {
	case c == None || cmp == None:
		return None
	case c == Worse || cmp == Worse:
		return Worse
	case c == Better || cmp == Better:
		return Better
	default:
		return Equal
	}
}

// ComplexComparison is a map of string to Comparison
type ComplexComparison map[string]Comparison

// Result returns the result of AND operation of all Comparison values in the ComplexComparison.
// The identity element is Equal, so an empty ComplexComparison yields Equal.
func (c *ComplexComparison) Result() Comparison {
	result := Equal
	for _, comparison := range *c {
		result = result.And(comparison)
	}
	return result
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package types
import (
"fmt"
"strings"
)
// DefaultAllocationIDGenerator is the default implementation of AllocationIDGenerator.
// It is stateless, so a single instance may be shared freely.
type DefaultAllocationIDGenerator struct{}

// NewDefaultAllocationIDGenerator creates a new DefaultAllocationIDGenerator
func NewDefaultAllocationIDGenerator() *DefaultAllocationIDGenerator {
	return new(DefaultAllocationIDGenerator)
}

// GenerateManifestKey generates the key for manifest.Allocations map.
// Format: nodeID.allocName (the generator doesn't care whether nodeID is a
// standby node or not).
func (g *DefaultAllocationIDGenerator) GenerateManifestKey(nodeID, allocName string) (string, error) {
	if nodeID == "" || allocName == "" {
		return "", fmt.Errorf("nodeID and allocName cannot be empty")
	}
	return nodeID + "." + allocName, nil
}

// GenerateFullAllocationID generates the full allocation ID for actor handles.
// Format: ensembleID_nodeID.allocName (the generator doesn't care whether
// nodeID is a standby node or not).
func (g *DefaultAllocationIDGenerator) GenerateFullAllocationID(ensembleID, nodeID, allocName string) (string, error) {
	if ensembleID == "" || nodeID == "" || allocName == "" {
		return "", fmt.Errorf("ensembleID, nodeID, and allocName cannot be empty")
	}
	return ensembleID + "_" + nodeID + "." + allocName, nil
}

// ValidateManifestKey validates a manifest key format (nodeID.allocName).
func (g *DefaultAllocationIDGenerator) ValidateManifestKey(manifestKey string) error {
	if manifestKey == "" {
		return fmt.Errorf("manifest key cannot be empty")
	}
	// exactly one '.' must separate nodeID from allocName
	if strings.Count(manifestKey, ".") != 1 {
		return fmt.Errorf("invalid manifest key format: %s (expected nodeID.allocName)", manifestKey)
	}
	nodeID, allocName, _ := strings.Cut(manifestKey, ".")
	if nodeID == "" || allocName == "" {
		return fmt.Errorf("invalid manifest key format: %s (nodeID and allocName cannot be empty)", manifestKey)
	}
	return nil
}

// ValidateFullAllocationID validates a full allocation ID format
// (ensembleID_nodeID.allocName).
func (g *DefaultAllocationIDGenerator) ValidateFullAllocationID(allocID string) error {
	if allocID == "" {
		return fmt.Errorf("allocation ID cannot be empty")
	}
	// split on the first '_' only; node IDs may legitimately contain dashes
	ensembleID, rest, found := strings.Cut(allocID, "_")
	if !found {
		return fmt.Errorf("invalid allocation ID format: %s (expected ensembleID_nodeID.allocName)", allocID)
	}
	if ensembleID == "" {
		return fmt.Errorf("invalid allocation ID format: %s (ensemble ID cannot be empty)", allocID)
	}
	// the remainder must itself be a valid manifest key
	if err := g.ValidateManifestKey(rest); err != nil {
		return fmt.Errorf("invalid allocation ID format: %s (%v)", allocID, err)
	}
	return nil
}
// DefaultNodeIDGenerator is the default implementation of NodeIDGenerator.
// It is stateless, so a single instance may be shared freely.
type DefaultNodeIDGenerator struct{}

// NewDefaultNodeIDGenerator creates a new DefaultNodeIDGenerator
func NewDefaultNodeIDGenerator() *DefaultNodeIDGenerator {
	return new(DefaultNodeIDGenerator)
}

// GenerateNodeID generates a node ID from a base name.
// The default generator uses the base name verbatim.
func (g *DefaultNodeIDGenerator) GenerateNodeID(baseName string) (string, error) {
	if baseName == "" {
		return "", fmt.Errorf("base name cannot be empty")
	}
	return baseName, nil
}

// GenerateStandbyNodeID generates a standby node ID from a primary node ID
// and a 1-based standby index, in the form "<primary>-standby-<index>".
func (g *DefaultNodeIDGenerator) GenerateStandbyNodeID(primaryNodeID string, standbyIndex int) (string, error) {
	switch {
	case primaryNodeID == "":
		return "", fmt.Errorf("primary node ID cannot be empty")
	case standbyIndex < 1:
		return "", fmt.Errorf("standby index must be >= 1, got %d", standbyIndex)
	}
	return fmt.Sprintf("%s-standby-%d", primaryNodeID, standbyIndex), nil
}

// ValidateNodeID validates that a node ID is properly formatted.
// The default generator only rejects empty IDs.
func (g *DefaultNodeIDGenerator) ValidateNodeID(nodeID string) error {
	if nodeID == "" {
		return fmt.Errorf("node ID cannot be empty")
	}
	return nil
}
// ParseNodeID extracts components from a node ID (if needed).
// It delegates to ParseNodeName and reports whether the ID is a standby
// node, the primary node ID, and the standby index; the default
// implementation never returns an error.
func (g *DefaultNodeIDGenerator) ParseNodeID(nodeID string) (bool, string, int, error) {
	standby, primary, index := ParseNodeName(nodeID)
	return standby, primary, index, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package types
import (
"time"
)
// ExecutionStatus is the status of an execution
type ExecutionStatus string

// Execution lifecycle states (as returned by Executor.GetStatus).
const (
	ExecutionStatusPending ExecutionStatus = "pending"
	ExecutionStatusRunning ExecutionStatus = "running"
	ExecutionStatusPaused  ExecutionStatus = "paused"
	ExecutionStatusFailed  ExecutionStatus = "failed"
	ExecutionStatusSuccess ExecutionStatus = "success"
)
// ExecutionRequest is the request object for executing a job.
// Pointer and slice fields are shared with the caller, not deep-copied.
type ExecutionRequest struct {
	JobID               string                   // ID of the job to execute
	ExecutionID         string                   // ID of the execution
	EngineSpec          *SpecConfig              // Engine spec for the execution
	Resources           *Resources               // Resources for the execution
	Inputs              []*StorageVolumeExecutor // Input volumes for the execution
	Outputs             []*StorageVolumeExecutor // Output volumes for the results
	ResultsDir          string                   // Directory to store the results
	PersistLogsDuration time.Duration            // Duration to persist logs on disk
	ProvisionScripts    map[string][]byte        // (named) Scripts to run when initiating the execution
	Keys                []AllocationKey          // (named) SSH public keys relevant to the allocation
	PortsToBind         []PortsToBind            // List of ports to bind
	GatewayIP           string                   // Gateway IP to use as dns resolver
}
// PortsToBind describes a single host-to-executor port mapping.
type PortsToBind struct {
	IP           string // IP address to bind on
	HostPort     int    // port on the host side of the mapping
	ExecutorPort int    // port inside the executor (container/VM)
}
// ExecutionListItem is the result of the current executions.
type ExecutionListItem struct {
	ExecutionID string // ID of the execution
	Running     bool   // whether the execution is currently running
}
// ExecutionResult is the result of an execution
// TODO: wrap ErrorMsg + Killed in one struct
type ExecutionResult struct {
	STDOUT   string `json:"stdout"`    // STDOUT of the execution
	STDERR   string `json:"stderr"`    // STDERR of the execution
	ExitCode int    `json:"exit_code"` // Exit code of the execution
	ErrorMsg string `json:"error_msg"` // Error message if the execution failed
	Killed   bool   `json:"killed"`    // Executor/Application externally killed
}

// NewExecutionResult creates a new ExecutionResult object with empty output
// streams and the given exit code.
func NewExecutionResult(code int) *ExecutionResult {
	// STDOUT/STDERR are left at their zero value ("")
	return &ExecutionResult{ExitCode: code}
}

// NewFailedExecutionResult creates a new ExecutionResult object for a failed
// execution. It sets the error message from the provided error and the exit
// code to -1; output streams are left empty.
func NewFailedExecutionResult(err error) *ExecutionResult {
	return &ExecutionResult{
		ExitCode: -1,
		ErrorMsg: err.Error(),
	}
}
// LogStreamRequest is the request object for streaming logs from an execution
type LogStreamRequest struct {
	JobID       string // ID of the job
	ExecutionID string // ID of the execution
	Tail        bool   // when true, exclude historical log data (see Executor.GetLogStream)
	Follow      bool   // when true, keep streaming as new log data is produced
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package types
import (
"context"
"io"
"reflect"
"time"
)
// Executor serves as an interface for running jobs on a specific backend, such as a Docker daemon, firecracker, etc.
// It provides a comprehensive set of methods to initiate, monitor, terminate, and retrieve output streams for executions.
type Executor interface {
	// GetID returns the unique identifier for the executor.
	GetID() string

	// Start initiates an execution for the given ExecutionRequest.
	// It returns an error if the execution already exists and is in a started or terminal state.
	// Implementations may also return other errors based on resource limitations or internal faults.
	Start(ctx context.Context, request *ExecutionRequest) error

	// Run initiates and waits for the completion of an execution for the given ExecutionRequest.
	// It returns a ExecutionResult and an error if any part of the operation fails.
	// Specifically, it will return an error if the execution already exists and is in a started or terminal state.
	Run(ctx context.Context, request *ExecutionRequest) (*ExecutionResult, error)

	// Pause attempts to pause an ongoing execution identified by its executionID.
	// Returns an error if the execution does not exist or is already in a terminal state.
	Pause(ctx context.Context, executionID string) error

	// Resume attempts to resume a paused execution identified by its executionID.
	// Returns an error if the execution does not exist, is not paused, or is already in a terminal state.
	Resume(ctx context.Context, executionID string) error

	// Wait monitors the completion of an execution identified by its executionID.
	// It returns two channels:
	// 1. A channel that emits the execution result once the task is complete.
	// 2. An error channel that relays any issues encountered, such as when the
	//    execution is non-existent or has already concluded.
	Wait(ctx context.Context, executionID string) (<-chan *ExecutionResult, <-chan error)

	// Cancel attempts to cancel an ongoing execution identified by its executionID.
	// Returns an error if the execution does not exist or is already in a terminal state.
	Cancel(ctx context.Context, executionID string) error

	// Remove removes an execution identified by its executionID.
	// Returns an error if the execution does not exist.
	Remove(executionID string, timeout time.Duration) error

	// Cleanup removes all resources associated with the executor.
	// This includes stopping and removing all running containers or VMs and deleting their resources.
	Cleanup(ctx context.Context) error

	// GetLogStream provides a stream of output for an ongoing or completed execution identified by its executionID.
	// The 'Tail' flag indicates whether to exclude historical data or not.
	// The 'Follow' flag indicates whether the stream should continue to send data as it is produced.
	// Returns an io.ReadCloser to read the output stream and an error if the operation fails.
	// Specifically, it will return an error if the execution does not exist.
	GetLogStream(ctx context.Context, request LogStreamRequest) (io.ReadCloser, error)

	// List returns a slice of ExecutionListItem containing information about current executions.
	// This includes the execution ID and whether it's currently running.
	List() []ExecutionListItem

	// GetStatus returns the status of the execution identified by its executionID.
	// It returns the status of the execution and an error if the execution is not found or status is unknown.
	GetStatus(ctx context.Context, executionID string) (ExecutionStatus, error)

	// WaitForStatus waits for the execution to reach a specific status.
	// It returns an error if the execution is not found or the status is unknown.
	// NOTE(review): the semantics of a nil timeout are not defined here — confirm with implementations.
	WaitForStatus(ctx context.Context, executionID string, status ExecutionStatus, timeout *time.Duration) error

	// Exec executes a command in a container and returns the exit code, output, and an error if the operation fails.
	Exec(ctx context.Context, containerID string, command []string) (int, string, string, error)

	// Stats returns the resource usage stats for a container. errors if the execution is not found or stats cannot be retrieved.
	Stats(ctx context.Context, executionID string) (*ExecutorStats, error)
}
// ExecutorType is the type of the executor
type ExecutorType string

const (
	// Supported executor backends.
	ExecutorTypeDocker      ExecutorType = "docker"
	ExecutorTypeFirecracker ExecutorType = "firecracker"
	ExecutorTypeWasm        ExecutorType = "wasm"
	// ExecutionStatusCodeSuccess is the process exit code that indicates success.
	// NOTE(review): this untyped constant is unrelated to executor types;
	// consider relocating it next to ExecutionStatus.
	ExecutionStatusCodeSuccess = 0
)
// implementing Comparable interface (compile-time conformance check)
var _ Comparable[ExecutorType] = (*ExecutorType)(nil)

// Compare compares two ExecutorType objects.
// Executor types are literals, so the result is Equal when they match and
// None otherwise (see LiteralComparator).
func (e ExecutorType) Compare(other ExecutorType) (Comparison, error) {
	return LiteralComparator(string(e), string(other)), nil
}

// String returns the string representation of the ExecutorType
func (e ExecutorType) String() string {
	return string(e)
}
// Executors is a list of Executor objects
type Executors []ExecutorType

// implementing Comparable and Calculable interface (compile-time conformance checks)
var (
	_ Comparable[Executors] = (*Executors)(nil)
	_ Calculable[Executors] = (*Executors)(nil)
)
// ExecutorStats represents resource usage stats for an executor.
// It is a point-in-time snapshot; Timestamp records when the sample was taken.
type ExecutorStats struct {
	// CPU usage
	CPUUsage struct {
		TotalUsage        uint64  `json:"total_usage"`         // total CPU time consumed in nanoseconds
		UsageInKernelmode uint64  `json:"usage_in_kernelmode"` // time spent in kernel mode in nanoseconds
		UsageInUsermode   uint64  `json:"usage_in_usermode"`   // time spent in user mode in nanoseconds
		Percent           float64 `json:"percent"`             // usage percentage
	} `json:"cpu_usage"`

	// memory usage
	Memory struct {
		Usage    uint64  `json:"usage"`     // memory usage in bytes
		MaxUsage uint64  `json:"max_usage"` // max memory usage in bytes
		Limit    uint64  `json:"limit"`     // memory limit in bytes
		Percent  float64 `json:"percent"`   // memory usage percentage
	} `json:"memory"`

	// network counters
	Network struct {
		RxBytes   uint64 `json:"rx_bytes"`   // total bytes received
		RxPackets uint64 `json:"rx_packets"` // total packets received
		RxErrors  uint64 `json:"rx_errors"`  // total receive errors
		RxDropped uint64 `json:"rx_dropped"` // total receive packets dropped
		TxBytes   uint64 `json:"tx_bytes"`   // total bytes transmitted
		TxPackets uint64 `json:"tx_packets"` // total packets transmitted
		TxErrors  uint64 `json:"tx_errors"`  // total transmit errors
		TxDropped uint64 `json:"tx_dropped"` // total transmit packets dropped
	} `json:"network"`

	// block I/O
	BlockIO struct {
		ReadBytes  uint64 `json:"read_bytes"`  // total bytes read from block devices
		WriteBytes uint64 `json:"write_bytes"` // total bytes written to block devices
	} `json:"block_io"`

	// timestamp when stats were collected (Unix milliseconds)
	Timestamp int64 `json:"timestamp"`
}
// Add appends every executor type from other to the receiver's list.
// No deduplication is performed.
func (e *Executors) Add(other Executors) error {
	*e = append(*e, other...)
	return nil
}
// Subtract removes every executor type present in other from the receiver.
// All occurrences of a removed type are dropped; the remaining elements keep
// their original order.
func (e *Executors) Subtract(other Executors) error {
	if len(other) == 0 {
		return nil
	}
	// build the set of types to drop
	drop := make(map[ExecutorType]struct{})
	for _, executorType := range other {
		drop[executorType] = struct{}{}
	}
	// filter in place, reusing the receiver's backing array
	kept := (*e)[:0]
	for _, executorType := range *e {
		if _, skip := drop[executorType]; skip {
			continue
		}
		kept = append(kept, executorType)
	}
	// clip capacity so later appends cannot clobber the removed tail
	*e = kept[:len(kept):len(kept)]
	return nil
}
// Contains reports whether executorType is present in the list of Executors.
func (e *Executors) Contains(executorType ExecutorType) bool {
	for _, candidate := range *e {
		if candidate == executorType {
			return true
		}
	}
	return false
}
// Compare compares two Executors objects.
// The receiver represents machine capabilities and the argument the required
// capabilities; set containment of one side in the other decides the result.
func (e *Executors) Compare(other Executors) (Comparison, error) {
	if reflect.DeepEqual(*e, other) {
		return Equal, nil
	}
	// box both sides into []interface{} for the generic set helpers
	left := make([]interface{}, 0, len(*e))
	for _, executorType := range *e {
		left = append(left, executorType)
	}
	right := make([]interface{}, 0, len(other))
	for _, executorType := range other {
		right = append(right, executorType)
	}
	if !IsSameShallowType(left, right) {
		return None, nil
	}
	switch {
	case reflect.DeepEqual(left, right):
		// available capabilities equal the required capabilities
		return Equal, nil
	case IsStrictlyContained(left, right):
		// machine capabilities cover everything that is required
		return Better, nil
	case IsStrictlyContained(right, left):
		// requirements exceed what the machine offers
		// (the Equal case is already handled above)
		return Worse, nil
	default:
		return None, nil
	}
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package types
import (
"fmt"
)
// DefaultGeneratorValidator is the default implementation of GeneratorValidator.
// It is stateless, so a single instance may be shared freely.
type DefaultGeneratorValidator struct{}

// NewDefaultGeneratorValidator creates a new DefaultGeneratorValidator
func NewDefaultGeneratorValidator() *DefaultGeneratorValidator {
	return &DefaultGeneratorValidator{}
}
// ValidateAllocationIDGenerator validates that the generator won't cause
// conflicts when given the same allocation name but different node types
// (primary vs standby). It drives the generator through a fixed set of
// primary/standby scenarios and verifies that the generated keys and IDs are
// unique and pass the generator's own validators.
func (v *DefaultGeneratorValidator) ValidateAllocationIDGenerator(generator AllocationIDGenerator) error {
	cases := []struct {
		name        string
		ensembleID  string
		primaryNode string
		standbyNode string
		allocName   string
	}{
		{
			name:        "basic primary and standby",
			ensembleID:  "test-ensemble",
			primaryNode: "node1",
			standbyNode: "node1-standby-1",
			allocName:   "alloc1",
		},
		{
			name:        "complex node names",
			ensembleID:  "ensemble-123",
			primaryNode: "web-server-01",
			standbyNode: "web-server-01-standby-2",
			allocName:   "nginx-service",
		},
		{
			name:        "edge case with numbers",
			ensembleID:  "e1",
			primaryNode: "n1",
			standbyNode: "n1-standby-1",
			allocName:   "a1",
		},
	}
	for _, tt := range cases {
		// the primary and its standby must never produce the same manifest key
		primaryKey, err := generator.GenerateManifestKey(tt.primaryNode, tt.allocName)
		if err != nil {
			return fmt.Errorf("failed to generate primary manifest key for test case '%s': %w", tt.name, err)
		}
		standbyKey, err := generator.GenerateManifestKey(tt.standbyNode, tt.allocName)
		if err != nil {
			return fmt.Errorf("failed to generate standby manifest key for test case '%s': %w", tt.name, err)
		}
		if primaryKey == standbyKey {
			return fmt.Errorf("manifest key conflict in test case '%s': primary key '%s' == standby key '%s'",
				tt.name, primaryKey, standbyKey)
		}

		// full allocation IDs must be unique as well
		primaryID, err := generator.GenerateFullAllocationID(tt.ensembleID, tt.primaryNode, tt.allocName)
		if err != nil {
			return fmt.Errorf("failed to generate primary full allocation ID for test case '%s': %w", tt.name, err)
		}
		standbyID, err := generator.GenerateFullAllocationID(tt.ensembleID, tt.standbyNode, tt.allocName)
		if err != nil {
			return fmt.Errorf("failed to generate standby full allocation ID for test case '%s': %w", tt.name, err)
		}
		if primaryID == standbyID {
			return fmt.Errorf("full allocation ID conflict in test case '%s': primary ID '%s' == standby ID '%s'",
				tt.name, primaryID, standbyID)
		}

		// everything generated must round-trip through the generator's validators
		if err := generator.ValidateManifestKey(primaryKey); err != nil {
			return fmt.Errorf("invalid primary manifest key in test case '%s': %w", tt.name, err)
		}
		if err := generator.ValidateManifestKey(standbyKey); err != nil {
			return fmt.Errorf("invalid standby manifest key in test case '%s': %w", tt.name, err)
		}
		if err := generator.ValidateFullAllocationID(primaryID); err != nil {
			return fmt.Errorf("invalid primary full allocation ID in test case '%s': %w", tt.name, err)
		}
		if err := generator.ValidateFullAllocationID(standbyID); err != nil {
			return fmt.Errorf("invalid standby full allocation ID in test case '%s': %w", tt.name, err)
		}
	}
	return nil
}
// ValidateNodeIDGenerator validates that the generator won't cause conflicts.
// It checks that primary and standby node IDs pass validation and that
// ParseNodeID classifies and decomposes them correctly.
func (v *DefaultGeneratorValidator) ValidateNodeIDGenerator(generator NodeIDGenerator) error {
	cases := []struct {
		name        string
		primaryNode string
		standbyNode string
	}{
		{
			name:        "basic primary and standby",
			primaryNode: "node1",
			standbyNode: "node1-standby-1",
		},
		{
			name:        "complex node names",
			primaryNode: "web-server-01",
			standbyNode: "web-server-01-standby-2",
		},
	}
	for _, tt := range cases {
		// both node IDs must pass validation
		if err := generator.ValidateNodeID(tt.primaryNode); err != nil {
			return fmt.Errorf("invalid primary node ID in test case '%s': %w", tt.name, err)
		}
		if err := generator.ValidateNodeID(tt.standbyNode); err != nil {
			return fmt.Errorf("invalid standby node ID in test case '%s': %w", tt.name, err)
		}

		// a primary node must not be classified as standby
		isStandby, _, _, err := generator.ParseNodeID(tt.primaryNode)
		if err != nil {
			return fmt.Errorf("failed to parse primary node ID in test case '%s': %w", tt.name, err)
		}
		if isStandby {
			return fmt.Errorf("primary node ID incorrectly identified as standby in test case '%s'", tt.name)
		}

		// a standby node must decompose into its primary plus a positive index
		isStandby, parsedPrimary, standbyIndex, err := generator.ParseNodeID(tt.standbyNode)
		if err != nil {
			return fmt.Errorf("failed to parse standby node ID in test case '%s': %w", tt.name, err)
		}
		if !isStandby {
			return fmt.Errorf("standby node ID incorrectly identified as primary in test case '%s'", tt.name)
		}
		if parsedPrimary != tt.primaryNode {
			return fmt.Errorf("incorrect primary node ID parsed from standby node in test case '%s': expected '%s', got '%s'",
				tt.name, tt.primaryNode, parsedPrimary)
		}
		if standbyIndex < 1 {
			return fmt.Errorf("invalid standby index in test case '%s': %d", tt.name, standbyIndex)
		}
	}
	return nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package types
import (
"fmt"
"slices"
"strings"
)
// byte units (binary prefixes; each constant is 1024x the previous one)
const (
	_  = iota             // skip iota == 0 so KB starts at 1 << 10
	KB = 1 << (10 * iota) // kibibyte-style kilobyte: 1 << 10
	MB                    // megabyte: 1 << 20
	GB                    // gigabyte: 1 << 30
	TB                    // terabyte: 1 << 40
	PB                    // petabyte: 1 << 50
)
// HardwareManager defines the interface for managing machine resources.
type HardwareManager interface {
	// GetMachineResources returns the machine's resources.
	GetMachineResources() (MachineResources, error)
	// GetUsage returns the resources currently in use.
	GetUsage() (Resources, error)
	// GetFreeResources returns the resources not currently in use.
	GetFreeResources() (Resources, error)
	// CheckCapacity reports whether the given resources can be accommodated.
	CheckCapacity(resources Resources) (bool, error)
	// Shutdown releases the manager's underlying resources.
	Shutdown() error
}
// GPUManager defines the interface for managing GPU resources.
type GPUManager interface {
	// GetGPUs returns the GPUs known to the manager.
	GetGPUs() (GPUs, error)
	// GetGPUUsage returns usage information, optionally filtered by GPU UUIDs.
	GetGPUUsage(uuid ...string) (GPUs, error)
	// Shutdown releases the manager's underlying resources.
	Shutdown() error
}

// GPUConnector vendor-specific adapter that interacts with actual GPU hardware by using vendor-specific libraries
type GPUConnector interface {
	// GetGPUs enumerates the GPUs visible to the vendor library.
	GetGPUs() (GPUs, error)
	// GetGPUUsage returns a usage figure for the GPU with the given UUID.
	// NOTE(review): the unit of the returned uint64 is not stated here —
	// confirm with implementations (bytes of VRAM vs percent).
	GetGPUUsage(uuid string) (uint64, error)
	// Shutdown releases the connector's underlying resources.
	Shutdown() error
}
// GPUVendor identifies the maker of a GPU device.
type GPUVendor string

// Known GPU vendors (see ParseGPUVendor for mapping raw vendor strings).
const (
	GPUVendorNvidia  GPUVendor = "NVIDIA"
	GPUVendorAMDATI  GPUVendor = "AMD/ATI"
	GPUVendorIntel   GPUVendor = "Intel"
	GPUVendorUnknown GPUVendor = "Unknown"
	GPUVendorNone    GPUVendor = "None"
)
// implementing Comparable interface (compile-time conformance check)
var _ Comparable[GPUVendor] = (*GPUVendor)(nil)

// Compare compares two GPU vendors. Vendors are literals, so the result is
// Equal when they match and None otherwise — there is no ordering.
func (g GPUVendor) Compare(other GPUVendor) (Comparison, error) {
	if g != other {
		return None, nil
	}
	return Equal, nil
}
// ParseGPUVendor parses the GPU vendor string and returns the corresponding
// GPUVendor enum. Matching is case-insensitive and substring-based; anything
// unrecognized maps to GPUVendorUnknown.
func ParseGPUVendor(vendor string) GPUVendor {
	// normalize once instead of calling strings.ToUpper for every case
	v := strings.ToUpper(vendor)
	switch {
	case strings.Contains(v, "NVIDIA"):
		return GPUVendorNvidia
	case strings.Contains(v, "AMD"), strings.Contains(v, "ATI"):
		return GPUVendorAMDATI
	case strings.Contains(v, "INTEL"):
		return GPUVendorIntel
	default:
		return GPUVendorUnknown
	}
}
// GPU represents the GPU information
type GPU struct {
	// Index is the self-reported index of the device in the system
	Index int `json:"index" description:"GPU index in the system"`

	// Vendor is the maker of the GPU, e.g. NVidia, AMD, Intel
	Vendor GPUVendor `json:"vendor" description:"GPU vendor, e.g., NVidia, AMD, Intel"`

	// PCIAddress is the PCI address of the device, in the format AAAA:BB:CC.C
	// Used to discover the correct device rendering cards
	PCIAddress string `json:"pci_address" description:"PCI address of the device, in the format AAAA:BB:CC.C"`

	// Model represents the GPU model name, e.g., "Tesla T4", "A100"
	Model string `json:"model" description:"GPU model, e.g., Tesla T4, A100"`

	// VRAM is the total amount of VRAM on the device, in bytes
	// (see VRAMInGB for the gigabyte conversion)
	VRAM uint64 `json:"vram" description:"Total amount of VRAM on the device"`

	// Cores is the number of compute cores on the GPU
	// (CUDA cores for NVIDIA, Compute Units for AMD, 0 if unknown)
	Cores uint32 `json:"cores" description:"Number of compute cores (CUDA cores / Compute Units)"`

	// UUID is the unique identifier of the device
	UUID string `json:"uuid" description:"Unique identifier of the device"`
}
// implementing Comparable and Calculable interfaces (compile-time checks)
var (
	_ Comparable[GPU] = (*GPU)(nil)
	_ Calculable[GPU] = (*GPU)(nil)
)

// Compare compares the GPU with the other GPU.
// Only VRAM is considered: more VRAM on the receiver is Better, less is
// Worse, and the same amount is Equal. This is deliberately simple — ranking
// GPU types properly would require hardware-specific knowledge (e.g.
// benchmarking data from an external source).
func (g *GPU) Compare(other GPU) (Comparison, error) {
	if g.VRAM > other.VRAM {
		return Better, nil
	}
	if g.VRAM < other.VRAM {
		return Worse, nil
	}
	return Equal, nil
}
// Add accumulates the other GPU's VRAM into the current GPU.
// Only the VRAM field is affected.
func (g *GPU) Add(other GPU) error {
	g.VRAM += other.VRAM
	return nil
}

// Subtract removes the other GPU's VRAM from the current GPU.
// It fails instead of underflowing when the other GPU reports more VRAM.
func (g *GPU) Subtract(other GPU) error {
	if other.VRAM > g.VRAM {
		return fmt.Errorf("total VRAM: underflow, cannot subtract %v from %v", other.VRAM, g.VRAM)
	}
	g.VRAM -= other.VRAM
	return nil
}
// Equal checks if the two GPUs are equal.
// All identifying and capacity fields are compared except UUID.
func (g *GPU) Equal(other GPU) bool {
	return g.Index == other.Index &&
		g.Vendor == other.Vendor &&
		g.PCIAddress == other.PCIAddress &&
		g.Model == other.Model &&
		g.VRAM == other.VRAM &&
		g.Cores == other.Cores
}
// VRAMInGB returns the VRAM in gigabytes
func (g *GPU) VRAMInGB() uint64 {
	// VRAM is stored in bytes; ConvertBytesToGB performs the unit conversion
	return ConvertBytesToGB(g.VRAM)
}
// GPUs is a collection of GPU devices.
type GPUs []GPU

// implementing Comparable and Calculable interfaces (compile-time checks)
var (
	_ Calculable[GPUs] = (*GPUs)(nil)
	_ Comparable[GPUs] = (*GPUs)(nil)
)
// Copy returns a safe, independent copy of the GPUs slice so callers can
// mutate the result without affecting the receiver.
func (gpus GPUs) Copy() GPUs {
	duplicate := make(GPUs, len(gpus))
	copy(duplicate, gpus)
	return duplicate
}
// Compare compares two GPUs collections. The receiver represents the
// machine's GPUs ("left"); other represents the required GPUs ("right").
// Each required GPU is matched against the available ones and the per-match
// results are folded into a single Comparison: any Worse match wins, then
// all-Equal, otherwise Better.
func (gpus GPUs) Compare(other GPUs) (Comparison, error) {
	interimComparison1 := make([][]Comparison, 0)
	for _, otherGPU := range other {
		var interimComparison2 []Comparison
		for _, ownGPU := range gpus {
			c, err := ownGPU.Compare(otherGPU)
			if err != nil {
				return None, fmt.Errorf("error comparing GPU: %v", err)
			}
			interimComparison2 = append(interimComparison2, c)
		}
		// matrix of comparison results:
		// rows (first dimension) follow `other`, i.e. the required GPUs;
		// columns (second dimension) follow the receiver's own GPUs.
		// NOTE(review): an earlier comment claimed the first dimension was the
		// left-hand GPUs, but the outer loop ranges over `other`.
		interimComparison1 = append(interimComparison1, interimComparison2)
	}
	// pick the best available GPU for each required GPU, consuming matches
	var finalComparison []Comparison
	for i := 0; i < len(interimComparison1); i++ {
		// NOTE(review): each row has len(gpus) entries, so comparing a row's
		// length against the row index i only triggers once removeIndex has
		// shrunk the matrix — confirm this guard matches the intended
		// "stop when no GPUs remain" logic.
		if len(interimComparison1[i]) < i {
			break
		}
		c := interimComparison1[i]
		bestMatch, index := returnBestMatch(c)
		finalComparison = append(finalComparison, bestMatch)
		// NOTE(review): removeIndex removes a row of the matrix while the loop
		// is still iterating over it, and `index` comes from a column (own-GPU)
		// search — verify the dimensions line up with the intended
		// "consume the matched GPU" behavior.
		interimComparison1 = removeIndex(interimComparison1, index)
	}
	if slices.Contains(finalComparison, Worse) {
		return Worse, nil
	}
	if SliceContainsOneValue(finalComparison, Equal) {
		return Equal, nil
	}
	return Better, nil
}
// Add merges other into the receiver, matching GPUs by Index: each own GPU
// whose index also appears in other has that GPU's resources added in place.
// GPUs present only in other are ignored.
// TODO: I think this logic needs to change
// 1. if other gpu is in own gpus, add the total vram
// 2. if other gpu is not in own gpus, append it to own gpus
// (assumes indices identify the same device in both slices, which may not hold)
func (gpus GPUs) Add(other GPUs) error {
	byIndex := make(map[int]GPU, len(other))
	for _, o := range other {
		byIndex[o.Index] = o
	}
	for i := range gpus {
		o, ok := byIndex[gpus[i].Index]
		if !ok {
			continue
		}
		if err := gpus[i].Add(o); err != nil {
			return fmt.Errorf("failed to add GPU %s: %w", gpus[i].Model, err)
		}
	}
	return nil
}
// Subtract removes other's resources from the receiver, matching GPUs by
// Index; entries of other with no matching index are ignored.
// (assumes indices identify the same device in both slices, which may not hold)
func (gpus GPUs) Subtract(other GPUs) error {
	byIndex := make(map[int]GPU, len(other))
	for _, o := range other {
		byIndex[o.Index] = o
	}
	for i := range gpus {
		o, ok := byIndex[gpus[i].Index]
		if !ok {
			continue
		}
		if err := gpus[i].Subtract(o); err != nil {
			return fmt.Errorf("failed to subtract GPU %s: %w", gpus[i].Model, err)
		}
	}
	return nil
}
// Equal reports whether the two slices contain the same multiset of GPUs,
// regardless of order: every own GPU must pair with a distinct equal GPU in
// other, and the lengths must match.
func (gpus GPUs) Equal(other GPUs) bool {
	if len(gpus) != len(other) {
		return false
	}
	matched := make([]bool, len(other))
outer:
	for _, g := range gpus {
		for j := range other {
			if !matched[j] && g.Equal(other[j]) {
				matched[j] = true
				continue outer
			}
		}
		// no unmatched counterpart found for this GPU
		return false
	}
	return true
}
// MaxFreeVRAMGPU returns the GPU with the most free VRAM, or an error when
// the slice is empty. If every GPU has zero VRAM, the zero GPU is returned
// (matching the original behavior).
func (gpus GPUs) MaxFreeVRAMGPU() (GPU, error) {
	if len(gpus) == 0 {
		return GPU{}, fmt.Errorf("no GPUs found")
	}
	best := GPU{}
	for _, candidate := range gpus {
		if candidate.VRAM > best.VRAM {
			best = candidate
		}
	}
	return best, nil
}
// GetWithIndex returns the first GPU whose Index equals the given index, or
// an error when no such GPU exists.
func (gpus GPUs) GetWithIndex(index int) (GPU, error) {
	for i := range gpus {
		if gpus[i].Index == index {
			return gpus[i], nil
		}
	}
	return GPU{}, fmt.Errorf("GPU with index %d not found", index)
}
// CPU represents the CPU information of a machine.
type CPU struct {
	// ClockSpeed represents the CPU clock speed in Hz
	ClockSpeed float64 `json:"clock_speed,omitempty" yaml:"clock_speed,omitempty" description:"CPU clock speed in Hz"`
	// Cores represents the number of physical CPU cores.
	// Float so that fractional cores (shares of a core) can be accounted.
	Cores float32 `json:"cores" yaml:"cores" description:"Number of physical CPU cores"`
	// TODO: capture the below fields if required
	// Model represents the CPU model, e.g., "Intel Core i7-9700K", "AMD Ryzen 9 5900X"
	Model string `json:"model,omitempty" yaml:"model,omitempty" description:"CPU model, e.g., Intel Core i7-9700K, AMD Ryzen 9 5900X"`
	// Vendor represents the CPU manufacturer, e.g., "Intel", "AMD"
	Vendor string `json:"vendor,omitempty" yaml:"vendor,omitempty" description:"CPU manufacturer, e.g., Intel, AMD"`
	// Threads represents the number of logical CPU threads (including hyperthreading)
	Threads int `json:"threads,omitempty" yaml:"threads,omitempty" description:"Number of logical CPU threads (including hyperthreading)"`
	// Architecture represents the CPU architecture, e.g., "x86", "x86_64", "arm64"
	Architecture string `json:"architecture,omitempty" yaml:"architecture,omitempty" description:"CPU architecture, e.g., x86, x86_64, arm64"`
	// CacheSize is the CPU cache size in bytes
	CacheSize uint64 `json:"cache_size,omitempty" yaml:"cache_size,omitempty" description:"CPU cache size in bytes"`
}

// Compile-time assertions that CPU implements the Calculable and Comparable interfaces.
var (
	_ Calculable[CPU] = (*CPU)(nil)
	_ Comparable[CPU] = (*CPU)(nil)
)
// Compare orders two CPUs by total compute power (cores × clock speed), but
// only when their architectures match; CPUs of different architectures have
// no defined ordering and yield None.
//
// This is deliberately simple: more cores, or an equal core count and
// frequency, is acceptable — nothing less. Ranking concrete CPU models would
// require encoding hardware-specific knowledge (e.g. external benchmark data).
func (c *CPU) Compare(other CPU) (Comparison, error) {
	ownCompute := float64(c.Cores) * c.ClockSpeed
	otherCompute := float64(other.Cores) * other.ClockSpeed
	perf := NumericComparator(ownCompute, otherCompute)
	if LiteralComparator(c.Architecture, other.Architecture) != Equal {
		return None, nil
	}
	return perf, nil
}
// Add accumulates the other CPU's core count into the receiver, rounded to
// two decimal places. Other fields are left untouched.
func (c *CPU) Add(other CPU) error {
	total := c.Cores + other.Cores
	c.Cores = round(total, 2)
	return nil
}
// Subtract removes the other CPU's core count from the receiver (rounded to
// two decimal places), failing with an underflow error if other has more
// cores than the receiver.
func (c *CPU) Subtract(other CPU) error {
	if other.Cores > c.Cores {
		return fmt.Errorf("core: underflow, cannot subtract %v from %v", other.Cores, c.Cores)
	}
	remaining := c.Cores - other.Cores
	c.Cores = round(remaining, 2)
	return nil
}
// Compute returns the CPU's total compute power (cores × clock speed) in Hz.
func (c *CPU) Compute() float64 {
	cores := float64(c.Cores)
	return cores * c.ClockSpeed
}
// ComputeInGHz returns the CPU's total compute power expressed in GHz.
func (c *CPU) ComputeInGHz() float64 {
	totalHz := c.Compute()
	return ConvertHzToGHz(totalHz)
}
// ClockSpeedInGHz returns the per-core clock speed expressed in GHz.
func (c *CPU) ClockSpeedInGHz() float64 {
	hz := c.ClockSpeed
	return ConvertHzToGHz(hz)
}
// RAM represents the RAM information of a machine.
type RAM struct {
	// Size of the RAM in bytes
	Size uint64 `json:"size" yaml:"size" description:"Size of the RAM in bytes"`
	// TODO: capture the below fields if required
	// ClockSpeed of the RAM in Hz
	ClockSpeed uint64 `json:"clock_speed,omitempty" yaml:"clock_speed,omitempty" description:"Clock speed of the RAM in Hz"`
	// Type represents the RAM type, e.g., "DDR4", "DDR5", "LPDDR4"
	Type string `json:"type,omitempty" yaml:"type,omitempty" description:"RAM type, e.g., DDR4, DDR5, LPDDR4"`
}

// Compile-time assertions that RAM implements the Calculable and Comparable interfaces.
var (
	_ Calculable[RAM] = (*RAM)(nil)
	_ Comparable[RAM] = (*RAM)(nil)
)
// Compare orders two RAM specs by capacity. Only Size participates in the
// result; ClockSpeed is informational. (The previous implementation also
// built a ComplexComparison map and computed a ClockSpeed comparison, but
// discarded both — that dead work is removed here.)
func (r *RAM) Compare(other RAM) (Comparison, error) {
	return NumericComparator(r.Size, other.Size), nil
}
// Add accumulates the other RAM's size into the receiver.
func (r *RAM) Add(other RAM) error {
	r.Size = r.Size + other.Size
	return nil
}
// Subtract removes the other RAM's size from the receiver, failing with an
// underflow error if other is larger.
func (r *RAM) Subtract(other RAM) error {
	if other.Size > r.Size {
		return fmt.Errorf("size: underflow, cannot subtract %v from %v", other.Size, r.Size)
	}
	r.Size -= other.Size
	return nil
}
// SizeInGB returns the RAM size converted from bytes to decimal gigabytes.
func (r *RAM) SizeInGB() uint64 {
	sizeBytes := r.Size
	return ConvertBytesToGB(sizeBytes)
}
// SizeInGiB returns the RAM size converted from bytes to binary gibibytes.
func (r *RAM) SizeInGiB() uint64 {
	sizeBytes := r.Size
	return ConvertBytesToGiB(sizeBytes)
}
// Disk represents the disk information of a machine.
type Disk struct {
	// Size of the disk in bytes
	Size uint64 `json:"size" yaml:"size" description:"Size of the disk in bytes"`
	// TODO: capture the below fields if required
	// Model represents the disk model, e.g., "Samsung 970 EVO Plus", "Western Digital Blue SN550"
	Model string `json:"model,omitempty" yaml:"model,omitempty" description:"Disk model, e.g., Samsung 970 EVO Plus, Western Digital Blue SN550"`
	// Vendor represents the disk manufacturer, e.g., "Samsung", "Western Digital"
	Vendor string `json:"vendor,omitempty" yaml:"vendor,omitempty" description:"Disk manufacturer, e.g., Samsung, Western Digital"`
	// Type represents the disk type, e.g., "SSD", "HDD", "NVMe"
	Type string `json:"type,omitempty" yaml:"type,omitempty" description:"Disk type, e.g., SSD, HDD, NVMe"`
	// Interface represents the disk interface, e.g., "SATA", "PCIe", "M.2"
	Interface string `json:"interface,omitempty" yaml:"interface,omitempty" description:"Disk interface, e.g., SATA, PCIe, M.2"`
	// ReadSpeed in bytes per second
	ReadSpeed uint64 `json:"read_speed,omitempty" yaml:"read_speed,omitempty" description:"Read speed in bytes per second"`
	// WriteSpeed in bytes per second
	WriteSpeed uint64 `json:"write_speed,omitempty" yaml:"write_speed,omitempty" description:"Write speed in bytes per second"`
}

// Compile-time assertions that Disk implements the Calculable and Comparable interfaces.
var (
	_ Calculable[Disk] = (*Disk)(nil)
	_ Comparable[Disk] = (*Disk)(nil)
)
// Compare orders two disks by capacity; only Size participates in the result.
// (The previous implementation allocated a ComplexComparison map just to
// return its single entry — the needless allocation is removed here.)
func (d *Disk) Compare(other Disk) (Comparison, error) {
	return NumericComparator(d.Size, other.Size), nil
}
// Add accumulates the other disk's size into the receiver.
func (d *Disk) Add(other Disk) error {
	d.Size = d.Size + other.Size
	return nil
}
// Subtract removes the other disk's size from the receiver, failing with an
// underflow error if other is larger.
func (d *Disk) Subtract(other Disk) error {
	if other.Size > d.Size {
		return fmt.Errorf("size: underflow, cannot subtract %v from %v", other.Size, d.Size)
	}
	d.Size -= other.Size
	return nil
}
// SizeInGB returns the disk size converted from bytes to decimal gigabytes.
func (d *Disk) SizeInGB() uint64 {
	sizeBytes := d.Size
	return ConvertBytesToGB(sizeBytes)
}
// SizeInGiB returns the disk size converted from bytes to binary gibibytes.
func (d *Disk) SizeInGiB() uint64 {
	sizeBytes := d.Size
	return ConvertBytesToGiB(sizeBytes)
}
// NetworkInfo represents the network information of a machine.
// TODO: not yet used, but can be used to capture the network information
type NetworkInfo struct {
	// Bandwidth in bits per second (b/s)
	Bandwidth uint64
	// NetworkType represents the network type, e.g., "Ethernet", "Wi-Fi", "Cellular"
	NetworkType string
}
// TODO use units and convert the base only?
// ConvertBytesToGB converts bytes to decimal gigabytes (1 GB = 10^9 bytes),
// truncating any remainder.
func ConvertBytesToGB(bytes uint64) uint64 {
	const bytesPerGB = 1_000_000_000
	return bytes / bytesPerGB
}
// ConvertBytesToGiB converts bytes to binary gibibytes (1 GiB = 2^30 bytes),
// truncating any remainder.
func ConvertBytesToGiB(bytes uint64) uint64 {
	const bytesPerGiB = 1 << 30 // 1073741824
	return bytes / bytesPerGiB
}
// ConvertGBToBytes converts decimal gigabytes to bytes (1 GB = 10^9 bytes).
func ConvertGBToBytes(gb uint64) uint64 {
	const bytesPerGB = 1_000_000_000
	return gb * bytesPerGB
}
// ConvertGiBToBytes converts binary gibibytes to bytes (1 GiB = 2^30 bytes).
func ConvertGiBToBytes(gib uint64) uint64 {
	// shift by 30 is identical to multiplying by 1073741824 (mod 2^64)
	return gib << 30
}
// ConvertMibToBytes converts binary mebibytes to bytes (1 MiB = 2^20 bytes).
// TODO should be MB?
func ConvertMibToBytes(mib uint64) uint64 {
	// shift by 20 is identical to multiplying by 1024*1024 (mod 2^64)
	return mib << 20
}
// ConvertHzToGHz converts hertz to gigahertz (1 GHz = 10^9 Hz).
func ConvertHzToGHz(hz float64) float64 {
	const hzPerGHz = 1e9
	return hz / hzPerGHz
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package types
import (
"fmt"
"io"
"net/http"
"time"
)
// Supported health check flavours.
const (
	HealthCheckTypeCommand = "command" // run a command via a caller-supplied executor
	HealthCheckTypeHTTP    = "http"    // probe an HTTP endpoint and compare the response body
)

// HealthCheckResponse describes the response expected from a health check probe.
type HealthCheckResponse struct {
	Type  string `json:"type" yaml:"type"` // TODO: examples of types
	Value string `json:"value" yaml:"value"` // exact body the probe must return to pass
}

// HealthCheckManifest declares how (and how often) a workload is health-checked.
type HealthCheckManifest struct {
	Type     string              `json:"type" yaml:"type"`                             // type of healthcheck (command, http)
	Exec     []string            `json:"exec" yaml:"exec"`                             // command to execute
	Endpoint string              `json:"endpoint,omitempty" yaml:"endpoint,omitempty"` // endpoint to check
	Response HealthCheckResponse `json:"response,omitempty" yaml:"response,omitempty"` // expected response
	Interval time.Duration       `json:"interval" yaml:"interval"`                     // interval between healthchecks
}
// httpHealthCheckTimeout bounds each HTTP health probe. The previous
// implementation used http.Get (the default client), which has no timeout,
// so a stuck endpoint could hang the health check forever.
const httpHealthCheckTimeout = 30 * time.Second

// NewHealthCheck builds a runnable health check from a manifest.
//
// For command checks the returned closure delegates to fn with the manifest.
// For HTTP checks it GETs mf.Endpoint and fails unless the full response body
// equals mf.Response.Value. An unknown mf.Type yields an error immediately.
func NewHealthCheck(mf HealthCheckManifest, fn func(HealthCheckManifest) error) (func() error, error) {
	switch mf.Type {
	case HealthCheckTypeCommand:
		return func() error {
			return fn(mf)
		}, nil
	case HealthCheckTypeHTTP:
		// Dedicated client so the probe cannot block indefinitely.
		client := &http.Client{Timeout: httpHealthCheckTimeout}
		return func() error {
			res, err := client.Get(mf.Endpoint)
			if err != nil {
				return fmt.Errorf("health check request failed: %w", err)
			}
			defer res.Body.Close()
			body, err := io.ReadAll(res.Body)
			if err != nil {
				return fmt.Errorf("failed to read health check request output: %w", err)
			}
			// NOTE(review): only the body is compared; the status code is
			// intentionally left unchecked to preserve existing behavior.
			if string(body) != mf.Response.Value {
				return fmt.Errorf("unexpected health check request output: %s", string(body))
			}
			return nil
		}, nil
	default:
		return nil, fmt.Errorf("unknown healthcheck type: %q", mf.Type)
	}
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package types
// AllocationKeyType enumerates the kinds of keys that can be uploaded to an allocation.
type AllocationKeyType string

// Supported allocation key types.
const (
	KeySSH AllocationKeyType = "ssh"
	KeyGPG AllocationKeyType = "gpg"
)

// AllocationKey is a key specification to be uploaded on the allocation, e.g. ssh, gpg
type AllocationKey struct {
	Type AllocationKeyType `json:"type"`
	File string            `json:"file"` // source path to file
	Dest string            `json:"dest"` // destination path
}
// Equal reports whether the key type matches the given raw string.
func (t AllocationKeyType) Equal(other string) bool {
	return other == string(t)
}
// String returns the raw string form of the key type.
func (t AllocationKeyType) String() string {
	raw := string(t)
	return raw
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package types
import (
"time"
)
// Network type identifiers.
const (
	NetP2P = "p2p"
)

// NetworkSpec defines the network specification.
// NOTE(review): currently an unconstrained interface — presumably a placeholder; confirm intended shape.
type NetworkSpec interface{}

// NetConfig is a stub. Please expand it or completely change it based on requirements.
type NetConfig struct {
	NetworkSpec SpecConfig `json:"network_spec"` // Network specification
}
// GetNetworkConfig exposes the embedded network SpecConfig, satisfying the
// Config interface.
func (nc *NetConfig) GetNetworkConfig() *SpecConfig {
	spec := &nc.NetworkSpec
	return spec
}
// NetworkStats should contain all network info the user is interested in.
// For now there's only peerID and listening address, but reachability, local
// and remote addr etc. can be added when necessary.
type NetworkStats struct {
	ID         string `json:"id"`          // peer ID
	ListenAddr string `json:"listen_addr"` // address the node listens on
}

// MessageInfo is a stub. Please expand it or completely change it based on requirements.
type MessageInfo struct {
	Info string `json:"info"` // Message information
}

// PingResult carries the outcome of a single network ping.
type PingResult struct {
	RTT     time.Duration // measured round-trip time
	Success bool          // whether the ping completed
	Error   error         // failure cause, if any
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package types
import (
"reflect"
)
// ConvertNumericToFloat64 converts any built-in numeric value to float64.
// The boolean result is false when n is not one of the listed numeric types
// (named types with numeric underlying kinds are NOT matched). Very large
// 64-bit integers may lose precision in the conversion.
func ConvertNumericToFloat64(n any) (float64, bool) {
	switch v := n.(type) {
	case float64:
		return v, true
	case float32:
		return float64(v), true
	case int, int8, int16, int32, int64:
		// multi-type case leaves v as any, so reflect extracts the value
		return float64(reflect.ValueOf(n).Int()), true
	case uint, uint8, uint16, uint32, uint64:
		return float64(reflect.ValueOf(n).Uint()), true
	default:
		return 0, false
	}
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package types
import (
"context"
"errors"
"fmt"
)
// ErrNoFreeResources signals that a request cannot be satisfied by the
// machine's currently free resources.
var ErrNoFreeResources = errors.New("no free resources")

// Resources represents the resources of the machine.
type Resources struct {
	CPU  CPU  `json:"cpu" yaml:"cpu"`
	GPUs GPUs `json:"gpus,omitempty" yaml:"gpus,omitempty"`
	RAM  RAM  `json:"ram" yaml:"ram"`
	Disk Disk `json:"disk" yaml:"disk"`
}

// Compile-time assertions that Resources implements the Calculable and Comparable interfaces.
var (
	_ Calculable[Resources] = (*Resources)(nil)
	_ Comparable[Resources] = (*Resources)(nil)
)
// Compare compares two Resources objects by comparing each component
// (CPU, RAM, Disk, GPUs) and combining the verdicts with Comparison.And.
func (r *Resources) Compare(other Resources) (Comparison, error) {
	cpu, err := r.CPU.Compare(other.CPU)
	if err != nil {
		return None, fmt.Errorf("error comparing CPU: %v", err)
	}
	ram, err := r.RAM.Compare(other.RAM)
	if err != nil {
		return None, fmt.Errorf("error comparing RAM: %v", err)
	}
	disk, err := r.Disk.Compare(other.Disk)
	if err != nil {
		return None, fmt.Errorf("error comparing Disk: %v", err)
	}
	gpu, err := r.GPUs.Compare(other.GPUs)
	if err != nil {
		return None, fmt.Errorf("error comparing GPUs: %v", err)
	}
	combined := cpu.And(ram).And(disk).And(gpu)
	return combined, nil
}
// Equal reports whether the two resource sets match on the quantities that
// matter for accounting: RAM size, CPU cores, disk size, and the GPU multiset.
func (r *Resources) Equal(other Resources) bool {
	return r.RAM.Size == other.RAM.Size &&
		r.CPU.Cores == other.CPU.Cores &&
		r.Disk.Size == other.Disk.Size &&
		r.GPUs.Equal(other.GPUs)
}
// Add accumulates other's resources into the receiver, component by
// component, stopping at the first component that fails.
func (r *Resources) Add(other Resources) error {
	steps := []struct {
		name string
		run  func() error
	}{
		{"CPU", func() error { return r.CPU.Add(other.CPU) }},
		{"RAM", func() error { return r.RAM.Add(other.RAM) }},
		{"Disk", func() error { return r.Disk.Add(other.Disk) }},
		{"GPUs", func() error { return r.GPUs.Add(other.GPUs) }},
	}
	for _, s := range steps {
		if err := s.run(); err != nil {
			return fmt.Errorf("error adding %s: %v", s.name, err)
		}
	}
	return nil
}
// Subtract removes other's resources from the receiver, component by
// component, stopping at the first component that fails (e.g. on underflow).
func (r *Resources) Subtract(other Resources) error {
	steps := []struct {
		name string
		run  func() error
	}{
		{"CPU", func() error { return r.CPU.Subtract(other.CPU) }},
		{"RAM", func() error { return r.RAM.Subtract(other.RAM) }},
		{"Disk", func() error { return r.Disk.Subtract(other.Disk) }},
		{"GPUs", func() error { return r.GPUs.Subtract(other.GPUs) }},
	}
	for _, s := range steps {
		if err := s.run(); err != nil {
			return fmt.Errorf("error subtracting %s: %v", s.name, err)
		}
	}
	return nil
}
// MachineResources represents the total resources of the machine.
type MachineResources struct {
	BaseDBModel
	Resources
}

// FreeResources represents the free resources of the machine.
type FreeResources struct {
	BaseDBModel
	Resources
}

// CommittedResources represents the committed resources of the machine.
type CommittedResources struct {
	BaseDBModel
	Resources
	AllocationID string `json:"allocationID"` // allocation these resources are committed to
}
// ValidateBasic performs stateless sanity checks on the committed-resources
// record; currently it only requires a non-empty allocation ID.
// TODO: validate the embedded Resources as well.
func (c *CommittedResources) ValidateBasic() error {
	if c.AllocationID == "" {
		// errors.New instead of fmt.Errorf: the message has no format verbs.
		return errors.New("allocation ID is required")
	}
	return nil
}
// OnboardedResources represents the onboarded resources of the machine.
type OnboardedResources struct {
	BaseDBModel
	Resources
}

// ResourceAllocation represents the allocation of resources for a job.
type ResourceAllocation struct {
	BaseDBModel
	AllocationID string `json:"allocation_id"` // allocation this record belongs to
	Resources
}
// ResourceManager defines the methods to manage the resources of the machine:
// committing/uncommitting resources for allocations and querying totals.
type ResourceManager interface {
	// CommitResources commits the resources required by the allocation.
	// TODO: explicitly receive the allocation ID as a parameter instead of implicitly through the struct
	CommitResources(context.Context, CommittedResources) error
	// UncommitResources releases the resources that were committed for the allocation ID
	UncommitResources(context.Context, string) error
	// IsCommitted returns true if resources are committed for the allocation ID
	IsCommitted(string) (bool, error)
	// AllocateResources allocates the resources required by the allocation ID
	AllocateResources(context.Context, string) error
	// DeallocateResources deallocates the resources held by the allocation ID
	DeallocateResources(context.Context, string) error
	// IsAllocated returns true if resources are allocated for the allocation ID
	IsAllocated(allocationID string) (bool, error)
	// GetTotalAllocation returns the total resources currently allocated
	GetTotalAllocation() (Resources, error)
	// GetFreeResources returns the free resources in the pool
	GetFreeResources(ctx context.Context) (FreeResources, error)
	// GetOnboardedResources returns the onboarded resources of the machine
	GetOnboardedResources(context.Context) (OnboardedResources, error)
	// UpdateOnboardedResources updates the onboarded resources of the machine in the database
	UpdateOnboardedResources(context.Context, Resources) error
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package types
import (
"errors"
"strings"
"gitlab.com/nunet/device-management-service/utils/validate"
)
// SpecConfig represents a configuration for a spec.
// A SpecConfig can be used to define an engine spec, a storage volume, etc.
type SpecConfig struct {
	// Type of the spec (e.g. docker, firecracker, storage, etc.)
	Type string `json:"type" yaml:"type"`
	// Params holds the spec's free-form key/value parameters
	Params map[string]interface{} `json:"params,omitempty" yaml:"params,omitempty"`
}

// Config is implemented by configurations that expose a network SpecConfig.
type Config interface {
	GetNetworkConfig() *SpecConfig
}
// NewSpecConfig creates a SpecConfig of the given type with an empty,
// non-nil parameter map.
func NewSpecConfig(t string) *SpecConfig {
	cfg := &SpecConfig{Type: t}
	cfg.Params = map[string]interface{}{}
	return cfg
}
// WithParam sets a key/value pair in the spec params (initialising the map
// if needed) and returns the receiver for chaining.
func (s *SpecConfig) WithParam(key string, value interface{}) *SpecConfig {
	if s.Params == nil {
		s.Params = map[string]interface{}{}
	}
	s.Params[key] = value
	return s
}
// Normalize brings the spec config into canonical form: the type is
// whitespace-trimmed and a nil parameter map is replaced with an empty one
// (so nil and empty are indistinguishable to callers). Safe on a nil receiver.
func (s *SpecConfig) Normalize() {
	if s == nil {
		return
	}
	s.Type = strings.TrimSpace(s.Type)
	if len(s.Params) != 0 {
		return
	}
	s.Params = make(map[string]interface{})
}
// Validate checks that the spec config is non-nil and carries a non-blank type.
func (s *SpecConfig) Validate() error {
	switch {
	case s == nil:
		return errors.New("nil spec config")
	case validate.IsBlank(s.Type):
		return errors.New("missing spec type")
	}
	return nil
}
// IsType reports whether the spec config is of the given type, ignoring case
// and surrounding whitespace in t. A nil receiver matches nothing.
func (s *SpecConfig) IsType(t string) bool {
	if s == nil {
		return false
	}
	return strings.EqualFold(s.Type, strings.TrimSpace(t))
}
// IsEmpty reports whether the spec config carries no information: it is nil,
// or has a blank type and no params.
func (s *SpecConfig) IsEmpty() bool {
	if s == nil {
		return true
	}
	return validate.IsBlank(s.Type) && len(s.Params) == 0
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package types
import (
"fmt"
)
// TestNodeIDGenerator produces predictable node IDs for testing.
type TestNodeIDGenerator struct {
	// counter increments on every GenerateNodeID call, making IDs deterministic
	counter int
}
// NewTestNodeIDGenerator creates a TestNodeIDGenerator whose counter starts
// at zero (the field's zero value).
func NewTestNodeIDGenerator() *TestNodeIDGenerator {
	return &TestNodeIDGenerator{}
}
// GenerateNodeID returns a predictable ID of the form "<base>-test-<n>",
// incrementing the internal counter on every call.
func (g *TestNodeIDGenerator) GenerateNodeID(baseName string) (string, error) {
	g.counter++
	id := fmt.Sprintf("%s-test-%d", baseName, g.counter)
	return id, nil
}
// GenerateStandbyNodeID derives a standby node ID ("<primary>-standby-<i>")
// from a primary node ID and a 1-based standby index.
func (g *TestNodeIDGenerator) GenerateStandbyNodeID(primaryNodeID string, standbyIndex int) (string, error) {
	switch {
	case primaryNodeID == "":
		return "", fmt.Errorf("primary node ID cannot be empty")
	case standbyIndex < 1:
		return "", fmt.Errorf("standby index must be >= 1, got %d", standbyIndex)
	}
	return fmt.Sprintf("%s-standby-%d", primaryNodeID, standbyIndex), nil
}
// ValidateNodeID applies the minimal validation used in tests: the ID must
// be non-empty.
func (g *TestNodeIDGenerator) ValidateNodeID(nodeID string) error {
	if len(nodeID) == 0 {
		return fmt.Errorf("node ID cannot be empty")
	}
	return nil
}
// ParseNodeID splits a node ID into its components via ParseNodeName:
// whether it is a standby, the primary node ID, and the standby index.
func (g *TestNodeIDGenerator) ParseNodeID(nodeID string) (bool, string, int, error) {
	standby, primary, idx := ParseNodeName(nodeID)
	return standby, primary, idx, nil
}
// TestAllocationIDGenerator produces deterministic allocation IDs for testing.
type TestAllocationIDGenerator struct{}
// NewTestAllocationIDGenerator creates a new TestAllocationIDGenerator.
func NewTestAllocationIDGenerator() *TestAllocationIDGenerator {
	return new(TestAllocationIDGenerator)
}
// GenerateManifestKey builds the key for the manifest.Allocations map in the
// form "<nodeID>.<allocName>".
func (g *TestAllocationIDGenerator) GenerateManifestKey(nodeID, allocName string) (string, error) {
	return nodeID + "." + allocName, nil
}
// GenerateFullAllocationID builds the full allocation ID used for actor
// handles, in the form "<ensembleID>_<nodeID>.<allocName>".
func (g *TestAllocationIDGenerator) GenerateFullAllocationID(ensembleID, nodeID, allocName string) (string, error) {
	return ensembleID + "_" + nodeID + "." + allocName, nil
}
// ValidateManifestKey checks that the manifest key is non-empty.
func (g *TestAllocationIDGenerator) ValidateManifestKey(manifestKey string) error {
	if len(manifestKey) == 0 {
		return fmt.Errorf("manifest key cannot be empty")
	}
	return nil
}
// ValidateFullAllocationID checks that the full allocation ID is non-empty.
func (g *TestAllocationIDGenerator) ValidateFullAllocationID(allocID string) error {
	if len(allocID) == 0 {
		return fmt.Errorf("allocation ID cannot be empty")
	}
	return nil
}
// FailingAllocationIDGenerator deliberately produces colliding IDs so tests
// can exercise conflict-handling error paths.
type FailingAllocationIDGenerator struct{}
// NewFailingAllocationIDGenerator creates a new FailingAllocationIDGenerator.
func NewFailingAllocationIDGenerator() *FailingAllocationIDGenerator {
	return new(FailingAllocationIDGenerator)
}
// GenerateManifestKey always returns the same key, regardless of input, to
// provoke manifest-key conflicts in tests.
func (g *FailingAllocationIDGenerator) GenerateManifestKey(_, _ string) (string, error) {
	const conflictKey = "conflict.key"
	return conflictKey, nil
}
// GenerateFullAllocationID always returns the same ID, regardless of input,
// to provoke allocation-ID conflicts in tests.
func (g *FailingAllocationIDGenerator) GenerateFullAllocationID(_, _, _ string) (string, error) {
	const conflictID = "ensemble_conflict.key"
	return conflictID, nil
}
// ValidateManifestKey checks that the manifest key is non-empty.
func (g *FailingAllocationIDGenerator) ValidateManifestKey(manifestKey string) error {
	if len(manifestKey) == 0 {
		return fmt.Errorf("manifest key cannot be empty")
	}
	return nil
}
// ValidateFullAllocationID checks that the full allocation ID is non-empty.
func (g *FailingAllocationIDGenerator) ValidateFullAllocationID(allocID string) error {
	if len(allocID) == 0 {
		return fmt.Errorf("allocation ID cannot be empty")
	}
	return nil
}
// TestGeneratorValidator exercises generator-validator behavior in tests.
type TestGeneratorValidator struct{}
// NewTestGeneratorValidator creates a new TestGeneratorValidator.
func NewTestGeneratorValidator() *TestGeneratorValidator {
	return new(TestGeneratorValidator)
}
// ValidateAllocationIDGenerator performs a minimal conflict check: a primary
// node and its first standby must produce distinct manifest keys for the
// same allocation name.
func (v *TestGeneratorValidator) ValidateAllocationIDGenerator(generator AllocationIDGenerator) error {
	pKey, err := generator.GenerateManifestKey("node1", "alloc1")
	if err != nil {
		return fmt.Errorf("failed to generate primary key: %w", err)
	}
	sKey, err := generator.GenerateManifestKey("node1-standby-1", "alloc1")
	if err != nil {
		return fmt.Errorf("failed to generate standby key: %w", err)
	}
	if pKey == sKey {
		return fmt.Errorf("conflict detected: primary key '%s' == standby key '%s'", pKey, sKey)
	}
	return nil
}
// ValidateNodeIDGenerator checks that the generator accepts representative
// primary and standby node IDs.
func (v *TestGeneratorValidator) ValidateNodeIDGenerator(generator NodeIDGenerator) error {
	checks := []struct {
		id, label string
	}{
		{"node1", "primary"},
		{"node1-standby-1", "standby"},
	}
	for _, c := range checks {
		if err := generator.ValidateNodeID(c.id); err != nil {
			return fmt.Errorf("failed to validate %s node ID: %w", c.label, err)
		}
	}
	return nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package types
import (
"net"
"time"
"github.com/oschwald/geoip2-golang"
)
// BaseDBModel is a base model for all entities. It'll be mainly used for
// database records.
// NOTE(review): the timestamps are presumably maintained by the persistence
// layer — confirm against the repository implementation.
type BaseDBModel struct {
	ID        string
	CreatedAt time.Time `json:",omitempty,omitzero"`
	UpdatedAt time.Time `json:",omitempty,omitzero"`
	DeletedAt time.Time `json:",omitempty,omitzero"`
}
// GetID returns the entity's primary identifier.
func (m BaseDBModel) GetID() string {
	id := m.ID
	return id
}
// GeoIPLocator returns geolocation info about an IP address.
type GeoIPLocator interface {
	// Country resolves country-level data for ipAddress.
	Country(ipAddress net.IP) (*geoip2.Country, error)
	// City resolves city-level data for ipAddress.
	City(ipAddress net.IP) (*geoip2.City, error)
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package types
import (
"math"
"reflect"
"slices"
)
// IsStrictlyContained reports whether every element of rightSlice appears in
// leftSlice. An empty rightSlice yields false — preserved from the original
// formulation, where the result only became true after at least one loop
// iteration.
func IsStrictlyContained(leftSlice, rightSlice []interface{}) bool {
	if len(rightSlice) == 0 {
		return false
	}
	for _, want := range rightSlice {
		if !slices.Contains(leftSlice, want) {
			return false
		}
	}
	return true
}
// IsSameShallowType reports whether a and b have the same dynamic type
// (shallow: element types of containers are not inspected).
func IsSameShallowType(a, b interface{}) bool {
	return reflect.TypeOf(a) == reflect.TypeOf(b)
}
// round rounds value to the given number of decimal places (half away from
// zero, via math.Round).
func round[T float32 | float64](value T, places int) T {
	scale := math.Pow(10, float64(places))
	return T(math.Round(float64(value)*scale) / scale)
}
// SliceContainsOneValue reports whether every element of slice equals value.
// An empty slice vacuously returns true.
func SliceContainsOneValue(slice []Comparison, value Comparison) bool {
	for i := range slice {
		if slice[i] != value {
			return false
		}
	}
	return true
}
// returnBestMatch scans a dimension of comparison results and picks the
// best one: the first Equal wins outright; otherwise the first Better;
// otherwise the first Worse; otherwise (None, -1).
func returnBestMatch(dimension []Comparison) (Comparison, int) {
	firstBetter, firstWorse := -1, -1
	for i, cmp := range dimension {
		switch cmp {
		case Equal:
			// Nothing beats Equal, so short-circuit.
			return cmp, i
		case Better:
			if firstBetter == -1 {
				firstBetter = i
			}
		case Worse:
			if firstWorse == -1 {
				firstWorse = i
			}
		}
	}
	if firstBetter != -1 {
		return Better, firstBetter
	}
	if firstWorse != -1 {
		return Worse, firstWorse
	}
	return None, -1
}
// removeIndex deletes the element at index from every sub-slice, in place
// (the backing arrays are reused). Sub-slices for which index is out of
// bounds are left unmodified.
func removeIndex(slice [][]Comparison, index int) [][]Comparison {
	for i := range slice {
		row := slice[i]
		if index < 0 || index >= len(row) {
			// Out of bounds for this row: leave it untouched.
			continue
		}
		// Shift the tail left by one and drop the final slot — exactly what
		// append(row[:index], row[index+1:]...) does.
		copy(row[index:], row[index+1:])
		slice[i] = row[:len(row)-1]
	}
	return slice
}
// ConvertTypedSliceToUntypedSlice boxes every element of a typed slice into
// a []interface{}. It returns nil when the argument is not a slice.
func ConvertTypedSliceToUntypedSlice(typedSlice interface{}) []interface{} {
	rv := reflect.ValueOf(typedSlice)
	if rv.Kind() != reflect.Slice {
		return nil
	}
	boxed := make([]interface{}, rv.Len())
	for i := range boxed {
		boxed[i] = rv.Index(i).Interface()
	}
	return boxed
}
// IsStrictlyContainedInt reports whether every element of rightSlice is
// present in leftSlice. An empty rightSlice yields false — at least one
// element must be compared — matching the original semantics.
func IsStrictlyContainedInt(leftSlice, rightSlice []int) bool {
	matchedAny := false
	for _, n := range rightSlice {
		if !slices.Contains(leftSlice, n) {
			return false
		}
		matchedAny = true
	}
	return matchedAny
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package convert
import (
"fmt"
"strconv"
)
// StringToFloat64 parses s as a float64. fieldName is interpolated into the
// error so callers get a message naming the offending field. An empty
// string is rejected explicitly.
func StringToFloat64(s string, fieldName string) (float64, error) {
	if s == "" {
		return 0, fmt.Errorf("%s is required and cannot be empty", fieldName)
	}
	parsed, err := strconv.ParseFloat(s, 64)
	if err != nil {
		return 0, fmt.Errorf("invalid %s: %w", fieldName, err)
	}
	return parsed, nil
}
// StringToInt parses s as an int. fieldName is interpolated into the error
// so callers get a message naming the offending field. An empty string is
// rejected explicitly.
func StringToInt(s string, fieldName string) (int, error) {
	if s == "" {
		return 0, fmt.Errorf("%s is required and cannot be empty", fieldName)
	}
	parsed, err := strconv.Atoi(s)
	if err != nil {
		return 0, fmt.Errorf("invalid %s: %w", fieldName, err)
	}
	return parsed, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package convert
import (
	"fmt"
	"math"
	"reflect"
	"regexp"
	"strconv"
	"strings"

	"github.com/dustin/go-humanize"
)
const (
	// BytesPerGB is the number of bytes in 1 GB (SI standard: 1 GB = 10^9
	// bytes). Declared as 1e9, so the untyped constant is floating-point.
	BytesPerGB = 1e9
)
// ToPositiveFloat64 converts various numeric types (floats, signed and
// unsigned integers, and numeric strings) to float64 and validates that the
// result is a positive, finite number. fieldName is used in error messages.
func ToPositiveFloat64(value any, fieldName string) (float64, error) {
	var result float64
	switch v := value.(type) {
	case float64:
		result = v
	case float32:
		result = float64(v)
	case int, int8, int16, int32, int64:
		result = float64(reflect.ValueOf(v).Int())
	case uint, uint8, uint16, uint32, uint64:
		result = float64(reflect.ValueOf(v).Uint())
	case string:
		v = strings.TrimSpace(v)
		parsed, err := strconv.ParseFloat(v, 64)
		if err != nil {
			return 0, fmt.Errorf("%s must be a valid number", fieldName)
		}
		result = parsed
	default:
		return 0, fmt.Errorf("%s must be a valid number", fieldName)
	}
	// Fix: ParseFloat accepts "NaN" and "Inf", and NaN <= 0 is false, so
	// without this check NaN (and ±Inf) would slip through as "positive".
	if math.IsNaN(result) || math.IsInf(result, 0) {
		return 0, fmt.Errorf("%s must be a valid number", fieldName)
	}
	if result <= 0 {
		return 0, fmt.Errorf("%s must be positive", fieldName)
	}
	return result, nil
}
// numberWithUnitRegex matches an optionally signed decimal number (optional
// fractional part, optional e/E exponent), optional whitespace, and an
// optional alphabetic unit suffix.
var numberWithUnitRegex = regexp.MustCompile(`^([+-]?\d*\.?\d+(?:[eE][+-]?\d+)?)\s*([a-zA-Z]+)?$`)

// extractValueAndUnit renders value via %v, trims surrounding whitespace,
// and splits the text into its numeric part and optional unit suffix. It
// returns an error when the text does not look like a number.
func extractValueAndUnit(value any) (float64, string, error) {
	text := strings.TrimSpace(fmt.Sprintf("%v", value))
	parts := numberWithUnitRegex.FindStringSubmatch(text)
	if parts == nil {
		// No numeric component at all.
		return 0, "", fmt.Errorf("invalid numeric value: %v", value)
	}
	num, err := strconv.ParseFloat(parts[1], 64)
	if err != nil {
		return 0, "", fmt.Errorf("invalid numeric value: %v", value)
	}
	// parts[2] is "" when no unit suffix was present.
	return num, parts[2], nil
}
// ParseBytesWithDefaultUnit converts a value to bytes using
// humanize.ParseBytes. When the value carries no unit suffix, defaultUnit
// is appended before parsing.
// Example: ParseBytesWithDefaultUnit("100", "GiB") -> 107374182400.
func ParseBytesWithDefaultUnit(value any, defaultUnit string) (uint64, error) {
	num, unit, err := extractValueAndUnit(value)
	if err != nil {
		return 0, err
	}
	if unit == "" {
		unit = defaultUnit
	}
	// %.10f avoids scientific notation, which humanize cannot parse.
	return humanize.ParseBytes(fmt.Sprintf("%.10f%s", num, unit))
}
// ParseSIWithDefaultUnit converts a value to SI units using humanize.ParseSI.
// If the value has no unit suffix, appends the default unit before parsing.
// Example: ParseSIWithDefaultUnit("2.4", "GHz") -> 2400000000 (2.4 GHz in Hz)
func ParseSIWithDefaultUnit(value any, defaultUnit string) (float64, error) {
	num, unit, err := extractValueAndUnit(value)
	if err != nil {
		return 0, err
	}
	// Use the extracted unit or default
	if unit == "" {
		unit = defaultUnit
	} else if defaultUnit != "" {
		// Accept the explicit unit only when it is plausibly compatible with
		// defaultUnit. The three-way condition admits:
		//   1. either unit being a case-insensitive suffix of the other
		//      (e.g. "Hz" vs "GHz" — base unit vs prefixed unit), or
		//   2. same-length units differing only in their first byte, i.e.
		//      a different SI prefix on the same base (e.g. "MB" vs "GB").
		// NOTE(review): unit[1:] slices bytes, not runes — this assumes
		// single-byte ASCII prefixes; confirm multi-byte prefixes such as
		// "µ" are never expected here.
		if !strings.HasSuffix(strings.ToLower(unit), strings.ToLower(defaultUnit)) &&
			!strings.HasSuffix(strings.ToLower(defaultUnit), strings.ToLower(unit)) &&
			(len(unit) != len(defaultUnit) || !strings.EqualFold(unit[1:], defaultUnit[1:])) {
			return 0, fmt.Errorf("invalid unit %q", unit)
		}
	}
	// Format the value with the unit for humanize.ParseSI; %.10f avoids
	// scientific notation, which ParseSI does not accept.
	strVal := fmt.Sprintf("%.10f%s", num, unit)
	parsed, _, err := humanize.ParseSI(strVal)
	return parsed, err
}
// ToBytesFormat renders an unsigned byte count as a human-readable SI
// string (e.g. "1.0 GB"). The value is stringified via %v and must parse
// as a base-10 uint64.
func ToBytesFormat(value any) (string, error) {
	raw, err := strconv.ParseUint(fmt.Sprintf("%v", value), 10, 64)
	if err != nil {
		return "", err
	}
	return humanize.BigBytes(humanize.BigByte.SetUint64(raw)), nil
}
// ToSIFormatWithUnit renders a numeric value with an SI prefix and the
// given unit (e.g. 2400000000, "Hz" -> "2.4 GHz"). The value is
// stringified via %v and must parse as a float64.
func ToSIFormatWithUnit(value any, unit string) (string, error) {
	parsed, err := strconv.ParseFloat(fmt.Sprintf("%v", value), 64)
	if err != nil {
		return "", err
	}
	return humanize.SI(parsed, unit), nil
}
// BytesToGB converts bytes to gigabytes (GB) as float64 for precision.
// Uses SI standard: 1 GB = 10^9 bytes = 1,000,000,000 bytes.
// Supports unsigned/signed integers, floats, and base-10 numeric strings;
// negative inputs and unsupported types are rejected.
// Example: BytesToGB(5000000000) -> 5.0
func BytesToGB(bytes any) (float64, error) {
	var bytesValue uint64
	switch v := bytes.(type) {
	case uint, uint8, uint16, uint32, uint64:
		// All unsigned widths collapse to one arm via reflection.
		bytesValue = reflect.ValueOf(v).Uint()
	case int, int8, int16, int32, int64:
		signed := reflect.ValueOf(v).Int()
		if signed < 0 {
			return 0, fmt.Errorf("bytes cannot be negative")
		}
		bytesValue = uint64(signed)
	case float64:
		if v < 0 {
			return 0, fmt.Errorf("bytes cannot be negative")
		}
		// Fractional bytes are truncated, as in the original conversion.
		bytesValue = uint64(v)
	case float32:
		if v < 0 {
			return 0, fmt.Errorf("bytes cannot be negative")
		}
		bytesValue = uint64(v)
	case string:
		parsed, err := strconv.ParseUint(v, 10, 64)
		if err != nil {
			return 0, fmt.Errorf("invalid bytes value: %w", err)
		}
		bytesValue = parsed
	default:
		return 0, fmt.Errorf("unsupported type for bytes conversion: %T", bytes)
	}
	return float64(bytesValue) / BytesPerGB, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package convert
import (
"fmt"
"time"
"gitlab.com/nunet/device-management-service/tokenomics/contracts"
)
// ParsePaymentPeriod converts a payment period identifier (one of the
// contracts.PaymentPeriod* constants) to its duration. Weeks are 7 days;
// a month is approximated as 30 days. Unknown identifiers yield an error.
func ParsePaymentPeriod(period string) (time.Duration, error) {
	switch period {
	case contracts.PaymentPeriodMinute:
		return time.Minute, nil
	case contracts.PaymentPeriodHour:
		return time.Hour, nil
	case contracts.PaymentPeriodDay:
		return 24 * time.Hour, nil
	case contracts.PaymentPeriodWeek:
		return 7 * 24 * time.Hour, nil
	case contracts.PaymentPeriodMonth:
		return 30 * 24 * time.Hour, nil // Approximate
	default:
		return 0, fmt.Errorf("unsupported payment period: %s", period)
	}
}
// CalculateElapsedPeriods calculates how many complete billing cycles have
// elapsed between lastInvoiceAt and now, where one billing cycle is
// periodCount consecutive periods of periodDuration each.
// It returns (0, zero, zero) when less than one full cycle has passed;
// otherwise it returns the cycle count together with the
// [periodStart, periodEnd) window covered by the cycles to invoice.
func CalculateElapsedPeriods(
	lastInvoiceAt time.Time,
	now time.Time,
	periodDuration time.Duration,
	periodCount int,
) (billingCyclesElapsed int, periodStart, periodEnd time.Time) {
	// Guard against a non-positive cycle length.
	if periodCount <= 0 {
		periodCount = 1
	}
	elapsed := now.Sub(lastInvoiceAt)
	periodsElapsed := int(elapsed / periodDuration)
	billingCyclesElapsed = periodsElapsed / periodCount
	if billingCyclesElapsed < 1 {
		return 0, time.Time{}, time.Time{}
	}
	// NOTE(review): Truncate rounds relative to the zero time (absolute
	// wall-clock boundaries), not relative to lastInvoiceAt; the Add below
	// compensates when truncation moved the start backwards. Confirm this
	// boundary-snapping is the intended billing semantics.
	periodStart = lastInvoiceAt.Truncate(periodDuration)
	if periodStart.Before(lastInvoiceAt) {
		periodStart = periodStart.Add(periodDuration)
	}
	periodsToInvoice := billingCyclesElapsed * periodCount
	periodEnd = periodStart.Add(periodDuration * time.Duration(periodsToInvoice))
	return billingCyclesElapsed, periodStart, periodEnd
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package convert
import (
"fmt"
"time"
)
// DurationToUnit converts a duration to the specified time unit
func DurationToUnit(duration time.Duration, unit string) (float64, error) {
switch unit {
case "second":
return duration.Seconds(), nil
case "minute":
return duration.Minutes(), nil
case "hour":
return duration.Hours(), nil
default:
return 0, fmt.Errorf("unsupported time unit: %s", unit)
}
}
// SecondsToUnit converts a duration expressed in seconds (float64) to the
// requested time unit, delegating unit handling to DurationToUnit.
func SecondsToUnit(seconds float64, unit string) (float64, error) {
	return DurationToUnit(time.Duration(seconds*float64(time.Second)), unit)
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package utils
import (
"errors"
"fmt"
"os"
"path/filepath"
"runtime"
"github.com/spf13/afero"
)
// GetDirectorySize walks path on the given filesystem and returns the sum
// of the sizes of all non-directory entries, in bytes.
func GetDirectorySize(fs afero.Fs, path string) (int64, error) {
	var total int64
	walkErr := afero.Walk(fs, path, func(_ string, info os.FileInfo, err error) error {
		if err != nil {
			return err
		}
		if info.IsDir() {
			return nil
		}
		total += info.Size()
		return nil
	})
	if walkErr != nil {
		return 0, fmt.Errorf("failed to calculate volume size: %w", walkErr)
	}
	return total, nil
}
// WriteToFile writes data to filePath on the given filesystem, creating
// parent directories as needed, and returns the path written on success.
//
// Fix: the Close error is now checked instead of being discarded by a
// deferred call — on a write path a failed Close can mean the data never
// reached storage.
func WriteToFile(fs afero.Fs, data []byte, filePath string) (string, error) {
	if err := fs.MkdirAll(filepath.Dir(filePath), os.ModePerm); err != nil {
		return "", fmt.Errorf("failed to open path: %w", err)
	}
	file, err := fs.Create(filePath)
	if err != nil {
		return "", fmt.Errorf("failed to create path: %w", err)
	}
	n, err := file.Write(data)
	if err != nil {
		file.Close() // best-effort cleanup; the write error takes precedence
		return "", fmt.Errorf("failed to write data to path: %w", err)
	}
	if n != len(data) {
		file.Close()
		return "", errors.New("failed to write the size of data to file")
	}
	if err := file.Close(); err != nil {
		return "", fmt.Errorf("failed to close file: %w", err)
	}
	return filePath, nil
}
// FileExists reports whether filename exists on fs and is not a directory.
//
// Fix: the original only handled os.IsNotExist; for any other Stat error
// (e.g. permission denied) info was nil and info.IsDir() panicked. Any
// Stat error is now treated as "does not exist".
func FileExists(fs afero.Fs, filename string) bool {
	info, err := fs.Stat(filename)
	if err != nil {
		return false
	}
	return !info.IsDir()
}
// CreateDirIfNotExists ensures path exists as a directory, creating parent
// directories as needed.
//
// Fix: the Stat-then-create pattern was racy and silently returned nil on
// non-NotExist Stat errors (including when path exists as a regular file).
// MkdirAll alone is idempotent for an existing directory and errors when
// path exists as a non-directory.
func CreateDirIfNotExists(fs afero.Afero, path string) error {
	if err := fs.MkdirAll(path, 0o777); err != nil {
		return fmt.Errorf("failed to create directory: %w", err)
	}
	return nil
}
// CurrentFileDirectory returns the directory containing this source file,
// derived from runtime caller information; it returns "" when that
// information is unavailable.
func CurrentFileDirectory() string {
	if _, file, _, ok := runtime.Caller(0); ok {
		return filepath.Dir(file)
	}
	return ""
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package utils
import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"net/url"
"path"
"time"
"go.elastic.co/apm/module/apmhttp"
)
// HTTPClient is a small JSON-over-HTTP client bound to a base URL and an
// API version prefix that is joined onto every request path.
type HTTPClient struct {
	BaseURL    string
	APIVersion string
	Client     *http.Client
}
// NewHTTPClient creates a new HTTPClient with APM instrumentation.
//
// Fix: a dedicated *http.Client is constructed (with the 80s timeout that
// MakeRequest previously assigned per call) instead of wrapping the shared
// http.DefaultClient, so per-client settings never touch process-global
// state.
func NewHTTPClient(baseURL, version string) *HTTPClient {
	return &HTTPClient{
		BaseURL:    baseURL,
		APIVersion: version,
		Client:     apmhttp.WrapClient(&http.Client{Timeout: 80 * time.Second}),
	}
}
// MakeRequest performs an HTTP request with the given method, path, and body.
// It returns the response body, status code, and an error if any.
// The request is bounded by an 80-second timeout layered on top of ctx.
//
// Fix: the original mutated c.Client.Timeout on every call, which is a data
// race when MakeRequest runs concurrently; a request-scoped context timeout
// achieves the same bound without touching shared state. Errors are now
// wrapped with %w, and the local no longer shadows package net/url.
func (c *HTTPClient) MakeRequest(ctx context.Context, method, relativePath string, body []byte) ([]byte, int, error) {
	endpoint, err := url.Parse(c.BaseURL)
	if err != nil {
		return nil, 0, fmt.Errorf("failed to parse base URL: %w", err)
	}
	endpoint.Path = path.Join(c.APIVersion, relativePath)

	// Bound the whole request (headers through body read) per call.
	ctx, cancel := context.WithTimeout(ctx, 80*time.Second)
	defer cancel()

	req, err := http.NewRequestWithContext(ctx, method, endpoint.String(), bytes.NewBuffer(body))
	if err != nil {
		return nil, 0, fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json; charset=utf-8")
	req.Header.Set("Accept", "application/json")

	resp, err := c.Client.Do(req)
	if err != nil {
		return nil, 0, fmt.Errorf("request failed: %w", err)
	}
	defer resp.Body.Close()

	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, 0, fmt.Errorf("failed to read response body: %w", err)
	}
	return respBody, resp.StatusCode, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package utils
import (
"bufio"
"errors"
"fmt"
"io"
"os"
"os/signal"
"strings"
"syscall"
"golang.org/x/term"
)
const (
	// reonboardPrompt is shown when the machine appears to be already
	// onboarded and the user must confirm reonboarding.
	reonboardPrompt = "Looks like your machine is already onboarded. Proceed with reonboarding?"
	// passphrasePrompt and confirmPassphrasePrompt label the hidden
	// passphrase input fields.
	passphrasePrompt        = "Passphrase"
	confirmPassphrasePrompt = "Please confirm your passphrase"
)

// ErrOperationCancelled reports that the user aborted an interactive flow.
var ErrOperationCancelled = errors.New("operation cancelled by user")
// PromptReonboard wraps PromptYesNo with the reonboarding prompt and
// returns an error when the user declines or the prompt itself fails.
func PromptReonboard(r io.Reader, w io.Writer) error {
	accepted, err := PromptYesNo(r, w, reonboardPrompt)
	if err != nil {
		return fmt.Errorf("could not confirm reonboarding: %w", err)
	}
	if !accepted {
		return fmt.Errorf("reonboarding aborted by user")
	}
	return nil
}
// PromptForPassphrase interactively reads a passphrase from stdin with echo
// disabled, optionally asking for a matching confirmation, allowing up to
// three attempts. It aborts early with an error if SIGINT or SIGTERM is
// received while waiting for input.
//
// NOTE(review): when a signal wins the select below, the reader goroutine
// remains blocked in term.ReadPassword and leaks until process exit, and
// signal.Stop is never called to unregister sigChan — confirm this is
// acceptable for a CLI-lifetime call.
func PromptForPassphrase(confirm bool) (string, error) {
	maxTries := 3
	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM)
	done := make(chan bool)
	var passphrase string
	var err error
	// Start a goroutine to handle passphrase input. Closing done publishes
	// passphrase/err to the select below (happens-before via channel close).
	go func() {
		defer close(done)
		var bytePassphrase, byteConfirmation []byte
		for range maxTries {
			fmt.Printf("%s: ", passphrasePrompt)
			bytePassphrase, err = term.ReadPassword(int(os.Stdin.Fd()))
			if err != nil {
				err = fmt.Errorf("failed to read passphrase: %w", err)
				return
			}
			fmt.Println("") // new line after passphrase input
			if confirm {
				fmt.Printf("%s: ", confirmPassphrasePrompt)
				byteConfirmation, err = term.ReadPassword(int(os.Stdin.Fd()))
				if err != nil {
					err = fmt.Errorf("failed to read passphrase confirmation: %w", err)
					return
				}
				fmt.Println("") // new line after passphrase confirm input
				if string(bytePassphrase) != string(byteConfirmation) {
					err = fmt.Errorf("passphrases do not match")
				}
			}
			if err == nil {
				passphrase = string(bytePassphrase)
				return
			}
			// Mismatch: report and retry until attempts are exhausted.
			fmt.Println(err)
			fmt.Println("")
		}
		err = fmt.Errorf("user failed to input passphrase")
	}()
	// Wait for either the passphrase input to complete or an interrupt signal
	select {
	case <-done:
		return passphrase, err
	case <-sigChan:
		return "", errors.New("sigterm received")
	}
}
// PromptYesNo loops on confirmation from user until valid answer
func PromptYesNo(in io.Reader, out io.Writer, prompt string) (bool, error) {
reader := bufio.NewReader(in)
for {
fmt.Fprintf(out, "%s (y/N): ", prompt)
response, err := reader.ReadString('\n')
if err != nil {
return false, fmt.Errorf("read response string failed: %w", err)
}
response = strings.ToLower(strings.TrimSpace(response))
switch response {
case "y", "yes":
return true, nil
case "n", "no":
return false, nil
}
}
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package utils
import (
"fmt"
"strings"
"sync"
)
// A SyncMap is a concurrency-safe sync.Map that uses strongly-typed
// method signatures to ensure the types of its stored data are known.
type SyncMap[K comparable, V any] struct {
sync.Map
}
// SyncMapFromMap converts a standard Go map to a concurrency-safe SyncMap.
func SyncMapFromMap[K comparable, V any](m map[K]V) *SyncMap[K, V] {
ret := &SyncMap[K, V]{}
for k, v := range m {
ret.Put(k, v)
}
return ret
}
// Get retrieves the value associated with the given key from the map.
// It returns the value and a boolean indicating whether the key was found.
func (m *SyncMap[K, V]) Get(key K) (V, bool) {
value, ok := m.Load(key)
if !ok {
var empty V
return empty, false
}
return value.(V), true
}
// Put inserts or updates a key-value pair in the map.
func (m *SyncMap[K, V]) Put(key K, value V) {
m.Store(key, value)
}
// Iter iterates over each key-value pair in the map, executing the provided function on each pair.
// The iteration stops if the provided function returns false.
func (m *SyncMap[K, V]) Iter(ranger func(key K, value V) bool) {
m.Range(func(key, value any) bool {
k := key.(K)
v := value.(V)
return ranger(k, v)
})
}
// Keys returns a slice containing all the keys present in the map.
func (m *SyncMap[K, V]) Keys() []K {
keys := make([]K, 0)
m.Iter(func(key K, _ V) bool {
keys = append(keys, key)
return true
})
return keys
}
// String provides a string representation of the map, listing all key-value pairs.
func (m *SyncMap[K, V]) String() string {
// Use a strings.Builder for efficient string concatenation.
var sb strings.Builder
sb.WriteString("{")
first := true
m.Range(func(key, value any) bool {
if !first {
sb.WriteString(" ")
}
first = false
// Properly format the key-value pair
sb.WriteString(fmt.Sprintf("%v=%v", key, value))
return true
})
sb.WriteString("}")
return sb.String()
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package sys
import (
"os/exec"
"syscall"
"golang.org/x/sys/unix"
)
// Commander represents an executable command whose combined stdout/stderr
// can be collected.
type Commander interface {
	CombinedOutput() ([]byte, error)
}

// ExecFunc defines the function signature for command execution, allowing
// the concrete command constructor to be swapped out.
type ExecFunc func(name string, args ...string) Commander
// ExecCommand is the function used to run external commands. It defaults to
// exec.Command with the CAP_SYS_ADMIN ambient capability added to the
// child's SysProcAttr. Being a package-level variable, it can be replaced
// (e.g. in tests) with a stub implementation.
var ExecCommand ExecFunc = func(name string, args ...string) Commander {
	cmd := exec.Command(name, args...)
	cmd.SysProcAttr = &syscall.SysProcAttr{
		AmbientCaps: []uintptr{
			unix.CAP_SYS_ADMIN,
		},
	}
	return cmd
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package sys
import (
"net"
"strings"
"github.com/songgao/water"
)
// TunTapMode selects which kind of virtual network device to create;
// NewTunTapInterface maps NetTunMode to water.TUN and NetTapMode to
// water.TAP.
type TunTapMode int

const (
	NetTunMode TunTapMode = iota
	NetTapMode
	// NuNetIptablesChain is the name of the custom iptables chain managed
	// by this package on the filter and nat tables.
	NuNetIptablesChain = "NUNET"
)
// NetInterface defines the interface for network interfaces (TUN/TAP):
// packet-level I/O plus lifecycle (Up/Down/Close/Delete) and configuration
// (address, MTU) operations.
type NetInterface interface {
	Name() string
	Write([]byte) (int, error)
	Read([]byte) (int, error)
	Up() error
	Down() error
	Close() error
	Delete() error
	SetAddress(string) error
	SetMTU(int) error
}
// netiface adapts a *water.Interface to the NetInterface abstraction.
type netiface struct {
	iface *water.Interface
}

// Name returns the OS-level name of the underlying interface.
func (n *netiface) Name() string {
	return n.iface.Name()
}

// Read reads one packet from the interface into packet.
func (n *netiface) Read(packet []byte) (int, error) {
	return n.iface.Read(packet)
}

// Write writes one packet to the interface.
func (n *netiface) Write(packet []byte) (int, error) {
	return n.iface.Write(packet)
}

// Close closes the underlying device handle.
func (n *netiface) Close() error {
	return n.iface.Close()
}
// GetNetInterfaces gets the list of the host's network interfaces.
func GetNetInterfaces() ([]net.Interface, error) {
	return net.Interfaces()
}
// GetNetInterfaceByName gets the network interface with the given name.
func GetNetInterfaceByName(name string) (*net.Interface, error) {
	return net.InterfaceByName(name)
}
// GetNetInterfaceByFlags returns the first network interface whose flags
// intersect the given flag set, or (nil, nil) when none matches.
func GetNetInterfaceByFlags(flag net.Flags) (*net.Interface, error) {
	ifaces, err := GetNetInterfaces()
	if err != nil {
		return nil, err
	}
	for i := range ifaces {
		if ifaces[i].Flags&flag != 0 {
			return &ifaces[i], nil
		}
	}
	return nil, nil
}
// GetUsedAddresses returns the string form of every address assigned to a
// local interface, skipping any address containing ':' (i.e. IPv6).
func GetUsedAddresses() ([]string, error) {
	ifaces, err := GetNetInterfaces()
	if err != nil {
		return nil, err
	}
	var used []string
	for _, iface := range ifaces {
		addrs, err := iface.Addrs()
		if err != nil {
			return nil, err
		}
		for _, addr := range addrs {
			text := addr.String()
			if strings.Contains(text, ":") {
				continue
			}
			used = append(used, text)
		}
	}
	return used, nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
//go:build linux
package sys
import (
"fmt"
"net"
"os"
"os/exec"
"syscall"
"github.com/songgao/water"
"github.com/vishvananda/netlink"
"golang.org/x/sys/unix"
)
// NewTunTapInterface creates a new TUN or TAP device with the given name.
// When persist is true the device is configured to outlive this process.
func NewTunTapInterface(name string, mode TunTapMode, persist bool) (NetInterface, error) {
	deviceType := water.TAP
	if mode == NetTunMode {
		deviceType = water.TUN
	}
	config := water.Config{DeviceType: deviceType}
	config.Name = name
	// Setting false is the zero value, so an unconditional assignment is
	// equivalent to the original persist-only branch.
	config.PlatformSpecificParams.Persist = persist
	iface, err := water.New(config)
	if err != nil {
		return nil, fmt.Errorf("error creating interface: %w", err)
	}
	return &netiface{iface: iface}, nil
}
// Up brings the network interface up via netlink (LinkSetUp).
func (n *netiface) Up() error {
	link, err := netlink.LinkByName(n.iface.Name())
	if err != nil {
		return fmt.Errorf("error getting network interface by name: %w", err)
	}
	if err := netlink.LinkSetUp(link); err != nil {
		return fmt.Errorf("error setting network interface up: %w", err)
	}
	return nil
}
// Down brings the network interface down via netlink (LinkSetDown).
func (n *netiface) Down() error {
	link, err := netlink.LinkByName(n.iface.Name())
	if err != nil {
		return fmt.Errorf("error getting network interface by name: %w", err)
	}
	if err := netlink.LinkSetDown(link); err != nil {
		return fmt.Errorf("error setting network interface down: %w", err)
	}
	return nil
}
// Delete removes the network interface from the system via netlink (LinkDel).
func (n *netiface) Delete() error {
	link, err := netlink.LinkByName(n.iface.Name())
	if err != nil {
		return fmt.Errorf("error getting network interface by name: %w", err)
	}
	if err := netlink.LinkDel(link); err != nil {
		return fmt.Errorf("error deleting network interface: %w", err)
	}
	return nil
}
// SetMTU sets the MTU of the network interface via netlink (LinkSetMTU).
func (n *netiface) SetMTU(mtu int) error {
	link, err := netlink.LinkByName(n.iface.Name())
	if err != nil {
		return fmt.Errorf("error getting network interface by name: %w", err)
	}
	if err := netlink.LinkSetMTU(link, mtu); err != nil {
		return fmt.Errorf("error setting network interface mtu: %w", err)
	}
	return nil
}
// SetAddress assigns an address (CIDR notation, e.g. "10.0.0.1/24") to the
// network interface via netlink (AddrAdd). The address is added, not
// replaced; existing addresses remain.
func (n *netiface) SetAddress(address string) error {
	addr, err := netlink.ParseAddr(address)
	if err != nil {
		return fmt.Errorf("error parsing address: %w", err)
	}
	link, err := netlink.LinkByName(n.iface.Name())
	if err != nil {
		return fmt.Errorf("error getting network interface by name: %w", err)
	}
	if err := netlink.AddrAdd(link, addr); err != nil {
		return fmt.Errorf("error setting network interface address: %w", err)
	}
	return nil
}
// AddRouteRule adds an ip route rule to the network interface. src, dst and
// gw are optional: empty strings leave the corresponding route attribute
// unset. dst, when given, must be in CIDR notation.
//
// Fix: the original called netlink.RouteGet(destNet.IP) unconditionally and
// panicked with a nil pointer dereference whenever dst was empty; it also
// dereferenced r.Dst, which is nil for default routes.
func (n *netiface) AddRouteRule(src, dst, gw string) error {
	link, err := netlink.LinkByName(n.iface.Name())
	if err != nil {
		return fmt.Errorf("error getting network interface by name: %w", err)
	}
	var gwIP net.IP
	if gw != "" {
		if err := gwIP.UnmarshalText([]byte(gw)); err != nil {
			return fmt.Errorf("error parsing gw: %w", err)
		}
	}
	var destNet *net.IPNet
	if dst != "" {
		_, destNet, err = net.ParseCIDR(dst)
		if err != nil {
			return fmt.Errorf("error parsing dest net: %w", err)
		}
	}
	var srcIP net.IP
	if src != "" {
		if err := srcIP.UnmarshalText([]byte(src)); err != nil {
			return fmt.Errorf("error parsing src: %w", err)
		}
	}
	// The duplicate-route check only makes sense when a destination was
	// supplied; destNet is nil otherwise.
	if destNet != nil {
		routes, err := netlink.RouteGet(destNet.IP)
		if err != nil {
			return fmt.Errorf("error getting routes: %w", err)
		}
		for _, r := range routes {
			// r.Dst is nil for default routes; guard before dereferencing.
			if r.Dst != nil && r.Dst.IP.Equal(destNet.IP) && r.LinkIndex == link.Attrs().Index {
				// route rule already exists for interface
				return nil
			}
		}
	}
	err = netlink.RouteAdd(&netlink.Route{
		LinkIndex: link.Attrs().Index,
		Src:       srcIP,
		Dst:       destNet,
		Gw:        gwIP,
		Priority:  3000,
	})
	if err != nil {
		if os.IsExist(err) || err.Error() == "file exists" {
			// route already exists
			// we could instead use RouteReplace if we want
			// to force deletion and adding of route
			return nil
		}
		return fmt.Errorf("error adding route: %w", err)
	}
	return nil
}
// DelRoute deletes a route (CIDR notation) from the network interface.
func (n *netiface) DelRoute(route string) error {
	netRoute, err := netlink.ParseIPNet(route)
	if err != nil {
		return fmt.Errorf("error parsing route: %w", err)
	}
	link, err := netlink.LinkByName(n.iface.Name())
	if err != nil {
		return fmt.Errorf("error getting network interface by name: %w", err)
	}
	return netlink.RouteDel(&netlink.Route{
		LinkIndex: link.Attrs().Index,
		Dst:       netRoute,
		Priority:  3000,
	})
}
// ForwardingEnabled reports whether IPv4 forwarding is enabled by reading
// /proc/sys/net/ipv4/ip_forward.
//
// Fix: the original compared the file content against the exact two-byte
// string "1\n" and reported false for "1" without a trailing newline; the
// leading byte alone is sufficient (the kernel writes "0" or "1"). The
// read error is also wrapped with %w instead of %v.
func ForwardingEnabled() (bool, error) {
	data, err := os.ReadFile("/proc/sys/net/ipv4/ip_forward")
	if err != nil {
		return false, fmt.Errorf("error reading ip_forward: %w", err)
	}
	return len(data) > 0 && data[0] == '1', nil
}
// CreateNuNetChain creates the custom NUNET iptables chain on the filter
// and nat tables, skipping any table where the chain already exists.
func CreateNuNetChain() error {
	for _, table := range []string{"filter", "nat"} {
		if chainExist(table, NuNetIptablesChain) {
			continue
		}
		out, err := iptables([]string{
			"-t", table, "-N", NuNetIptablesChain,
		})
		if err != nil {
			return fmt.Errorf("error creating chain: %w, output: %s", err, out)
		}
	}
	return nil
}
// FlushNuNetChain flushes the custom NUNET iptables chain on both the
// filter and nat tables; a chain that does not exist is skipped by
// flushChain rather than treated as an error.
func FlushNuNetChain() error {
	if err := flushChain("filter", NuNetIptablesChain); err != nil {
		return fmt.Errorf("error flushing filter chain: %w", err)
	}
	if err := flushChain("nat", NuNetIptablesChain); err != nil {
		return fmt.Errorf("error flushing nat chain: %w", err)
	}
	return nil
}
// chainExist reports whether chain exists in the given iptables table by
// attempting to list it: iptables -L exits non-zero for a missing chain.
func chainExist(table, chain string) bool {
	_, err := iptables([]string{"-t", table, "-L", chain})
	return err == nil
}
// flushChain removes all rules from the given chain on the given table.
// A chain that does not exist is treated as already flushed.
//
// Fix: the failure message previously read "error creating chain", a
// copy-paste from CreateNuNetChain; it now describes the flush operation.
func flushChain(table, chain string) error {
	if !chainExist(table, chain) {
		return nil
	}
	out, err := iptables([]string{"-t", table, "-F", chain})
	if err != nil {
		return fmt.Errorf("error flushing chain: %w, output: %s", err, out)
	}
	return nil
}
// AddJumpRules inserts jump rules into the FORWARD (filter), OUTPUT (nat),
// and PREROUTING (nat) chains that divert traffic to the NuNet chain.
// Rules that already exist are left untouched.
func AddJumpRules() error {
	targets := []struct{ chain, table string }{
		{"FORWARD", "filter"},
		{"OUTPUT", "nat"},
		{"PREROUTING", "nat"},
	}
	for _, tg := range targets {
		rule := []string{tg.chain, "-t", tg.table, "-j", NuNetIptablesChain}
		if iptRuleExist(rule...) {
			continue
		}
		if err := iptAppendRule(rule...); err != nil {
			return fmt.Errorf("error adding jump rule: %w", err)
		}
	}
	return nil
}
// DelJumpRules removes the jump rules to the NuNet chain from the FORWARD
// (filter), OUTPUT (nat), and PREROUTING (nat) chains, if present.
//
// Fix: all three failure messages previously read "error adding jump rule"
// even though this function deletes rules; they now say "deleting".
func DelJumpRules() error {
	targets := []struct{ chain, table string }{
		{"FORWARD", "filter"},
		{"OUTPUT", "nat"},
		{"PREROUTING", "nat"},
	}
	for _, tg := range targets {
		rule := []string{tg.chain, "-t", tg.table, "-j", NuNetIptablesChain}
		if !iptRuleExist(rule...) {
			continue
		}
		if err := iptDeleteRule(rule...); err != nil {
			return fmt.Errorf("error deleting jump rule: %w", err)
		}
	}
	return nil
}
// AddDNATRule appends a DNAT rule to the NuNet nat chain redirecting
// traffic arriving on sourcePort (for the given protocol) to
// destIP:destPort. Existing rules are not duplicated.
func AddDNATRule(protocol, sourcePort, destIP, destPort string) error {
	rule := []string{
		NuNetIptablesChain, "-t", "nat", "-p", protocol,
		"--dport", sourcePort, "-j", "DNAT",
		"--to-destination", destIP + ":" + destPort,
	}
	if iptRuleExist(rule...) {
		return nil
	}
	if err := iptAppendRule(rule...); err != nil {
		return fmt.Errorf(
			"error adding DNAT rule (you might have to upgrade your iptables. See DMS readme): %w",
			err,
		)
	}
	return nil
}
// DelDNATRule deletes the DNAT rule for protocol/sourcePort ->
// destIP:destPort from the NuNet nat chain, if it exists.
func DelDNATRule(protocol, sourcePort, destIP, destPort string) error {
	rule := []string{
		NuNetIptablesChain, "-t", "nat", "-p", protocol,
		"--dport", sourcePort, "-j", "DNAT",
		"--to-destination", destIP + ":" + destPort,
	}
	if !iptRuleExist(rule...) {
		return nil
	}
	if err := iptDeleteRule(rule...); err != nil {
		return fmt.Errorf("error deleting DNAT rule: %w", err)
	}
	return nil
}
// AddForwardRule accepts forwarded traffic for destIP:destPort (on the
// given protocol) in the NuNet filter chain. Existing rules are not
// duplicated.
func AddForwardRule(protocol, destIP, destPort string) error {
	rule := []string{
		NuNetIptablesChain, "-t", "filter",
		"-p", protocol, "-d", destIP,
		"--dport", destPort, "-j", "ACCEPT",
	}
	if iptRuleExist(rule...) {
		return nil
	}
	if err := iptAppendRule(rule...); err != nil {
		return fmt.Errorf("error adding forward rule: %w", err)
	}
	return nil
}
// DelForwardRule deletes the ACCEPT rule for destIP:destPort (on the given
// protocol) from the NuNet filter chain, if it exists.
func DelForwardRule(protocol, destIP, destPort string) error {
	rule := []string{
		NuNetIptablesChain, "-t", "filter",
		"-p", protocol, "-d", destIP,
		"--dport", destPort, "-j", "ACCEPT",
	}
	if !iptRuleExist(rule...) {
		return nil
	}
	if err := iptDeleteRule(rule...); err != nil {
		return fmt.Errorf("error deleting forward rule: %w", err)
	}
	return nil
}
// AddOutputNatRule adds a DNAT rule (nat table, NuNet chain) that redirects
// locally originating packets not covered by the PREROUTING chain.
// ifaceName restricts the rule to the specified interface (normally
// loopback). Equivalent to:
//
//	iptables -t nat -A OUTPUT -p tcp --dport 7224 -o lo -j DNAT --to-destination 10.49.64.3:7224
//
// Fix: the failure message previously read "error adding forward rule", a
// copy-paste from AddForwardRule; it now names the output nat rule.
func AddOutputNatRule(protocol, destIP, destPort, ifaceName string) error {
	rule := []string{
		NuNetIptablesChain, "-t", "nat", "-p", protocol,
		"--dport", destPort, "-o", ifaceName,
		"-j", "DNAT", "--to-destination", destIP + ":" + destPort,
	}
	if !iptRuleExist(rule...) {
		if err := iptAppendRule(rule...); err != nil {
			return fmt.Errorf("error adding output nat rule: %w", err)
		}
	}
	return nil
}
// DelOutputNatRule deletes the OUTPUT-chain DNAT rule added by
// AddOutputNatRule, if it exists.
//
// Fix: the failure message previously read "error adding froward rule"
// (typo, and the wrong verb for a delete); it now describes the deletion.
func DelOutputNatRule(protocol, destIP, destPort, ifaceName string) error {
	rule := []string{
		NuNetIptablesChain, "-t", "nat", "-p", protocol,
		"--dport", destPort, "-o", ifaceName,
		"-j", "DNAT", "--to-destination", destIP + ":" + destPort,
	}
	if iptRuleExist(rule...) {
		if err := iptDeleteRule(rule...); err != nil {
			return fmt.Errorf("error deleting output nat rule: %w", err)
		}
	}
	return nil
}
// AddForwardIntRule accepts forwarding between the inInt and outInt
// interfaces in the NuNet filter chain. Existing rules are not duplicated.
func AddForwardIntRule(inInt, outInt string) error {
	rule := []string{
		NuNetIptablesChain, "-t", "filter",
		"-i", inInt, "-o", outInt, "-j", "ACCEPT",
	}
	if iptRuleExist(rule...) {
		return nil
	}
	if err := iptAppendRule(rule...); err != nil {
		return fmt.Errorf("error adding interface forward rule: %w", err)
	}
	return nil
}
// DelForwardIntRule deletes the interface-to-interface ACCEPT rule added by
// AddForwardIntRule, if it exists.
func DelForwardIntRule(inInt, outInt string) error {
	rule := []string{
		NuNetIptablesChain, "-t", "filter",
		"-i", inInt, "-o", outInt, "-j", "ACCEPT",
	}
	if !iptRuleExist(rule...) {
		return nil
	}
	if err := iptDeleteRule(rule...); err != nil {
		return fmt.Errorf("error deleting interface forward rule: %w", err)
	}
	return nil
}
// AddMasqueradeRule ensures a MASQUERADE rule exists in the nat
// POSTROUTING chain.
func AddMasqueradeRule() error {
	rule := []string{"POSTROUTING", "-t", "nat", "-j", "MASQUERADE"}
	if iptRuleExist(rule...) {
		return nil
	}
	if err := iptAppendRule(rule...); err != nil {
		return fmt.Errorf("error adding rule: %w", err)
	}
	return nil
}
// DelMasqueradeRule deletes the MASQUERADE rule from the nat POSTROUTING
// chain, if it exists.
func DelMasqueradeRule() error {
	rule := []string{"POSTROUTING", "-t", "nat", "-j", "MASQUERADE"}
	if !iptRuleExist(rule...) {
		return nil
	}
	if err := iptDeleteRule(rule...); err != nil {
		return fmt.Errorf("error deleting masquerade rule: %w", err)
	}
	return nil
}
// AddRelEstRule ensures the given filter-table chain accepts packets whose
// conntrack state is RELATED or ESTABLISHED.
func AddRelEstRule(chain string) error {
	rule := []string{
		chain, "-t", "filter",
		"-m", "conntrack", "--ctstate",
		"RELATED,ESTABLISHED", "-j", "ACCEPT",
	}
	if iptRuleExist(rule...) {
		return nil
	}
	if err := iptAppendRule(rule...); err != nil {
		return fmt.Errorf("error adding related,established rule: %w", err)
	}
	return nil
}
// DelRelEstRule deletes the RELATED,ESTABLISHED ACCEPT rule from the given
// filter-table chain, if it exists.
func DelRelEstRule(chain string) error {
	rule := []string{
		chain, "-t", "filter",
		"-m", "conntrack", "--ctstate",
		"RELATED,ESTABLISHED", "-j", "ACCEPT",
	}
	if !iptRuleExist(rule...) {
		return nil
	}
	if err := iptDeleteRule(rule...); err != nil {
		return fmt.Errorf("error deleting related,established rule: %w", err)
	}
	return nil
}
// iptDeleteRule removes the given rule specification from iptables ("-D").
func iptDeleteRule(rule ...string) error {
	args := append([]string{"-D"}, rule...)
	out, err := iptables(args)
	if err != nil {
		return fmt.Errorf("error deleting rule: %w, output: %s", err, out)
	}
	return nil
}
// iptAppendRule appends the given rule specification to iptables ("-A").
func iptAppendRule(rule ...string) error {
	args := append([]string{"-A"}, rule...)
	out, err := iptables(args)
	if err != nil {
		return fmt.Errorf("error appending rule: %w, output: %s", err, out)
	}
	return nil
}
// iptRuleExist reports whether the given rule specification already exists,
// using the iptables check flag ("-C").
func iptRuleExist(rule ...string) bool {
	args := append([]string{"-C"}, rule...)
	_, err := iptables(args)
	return err == nil
}
// iptables runs the iptables binary with the given arguments, granting the
// child process the CAP_NET_ADMIN ambient capability so it can modify
// netfilter state. It returns the combined stdout/stderr output.
func iptables(args []string) (string, error) {
	cmd := exec.Command("iptables", args...)
	cmd.SysProcAttr = &syscall.SysProcAttr{
		AmbientCaps: []uintptr{unix.CAP_NET_ADMIN},
	}
	out, err := cmd.CombinedOutput()
	if err != nil {
		return string(out), fmt.Errorf("failed to execute command: iptables %q: %w", args, err)
	}
	return string(out), nil
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package utils
import (
"archive/tar"
"crypto/rand"
"fmt"
"io"
"math/big"
"os"
"path/filepath"
"reflect"
"strings"
"github.com/spf13/afero"
"golang.org/x/exp/slices"
)
// RandomString returns a cryptographically random string of length n drawn
// from the ASCII letters a-z and A-Z.
func RandomString(n int) (string, error) {
	const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
	var b strings.Builder
	b.Grow(n)
	for i := 0; i < n; i++ {
		// crypto/rand gives a uniform index into the charset; idx avoids
		// shadowing the length parameter.
		idx, err := rand.Int(rand.Reader, big.NewInt(int64(len(charset))))
		if err != nil {
			return "", err
		}
		b.WriteByte(charset[idx.Int64()])
	}
	return b.String(), nil
}
// SliceContains reports whether str is an element of s.
func SliceContains(s []string, str string) bool {
	return slices.Contains(s, str)
}
// IsStrictlyContained reports whether every element of rightSlice appears
// in leftSlice. An empty rightSlice yields false, preserving the historical
// behavior of this helper.
func IsStrictlyContained(leftSlice, rightSlice []interface{}) bool {
	if len(rightSlice) == 0 {
		return false
	}
	for _, elem := range rightSlice {
		if !slices.Contains(leftSlice, elem) {
			return false
		}
	}
	return true
}
// IsStrictlyContainedInt reports whether every element of rightSlice
// appears in leftSlice. An empty rightSlice yields false, preserving the
// historical behavior of this helper.
func IsStrictlyContainedInt(leftSlice, rightSlice []int) bool {
	if len(rightSlice) == 0 {
		return false
	}
	for _, elem := range rightSlice {
		if !slices.Contains(leftSlice, elem) {
			return false
		}
	}
	return true
}
// SanitizeArchivePath joins archive entry t under directory d and rejects
// results that escape d ("G305: Zip Slip vulnerability").
//
// Fix: the previous bare HasPrefix check accepted sibling directories that
// merely share a name prefix (e.g. d="/tmp/foo" accepted "/tmp/foobar").
// The result must now be d itself or live strictly below d (prefix
// followed by a path separator).
func SanitizeArchivePath(d, t string) (v string, err error) {
	v = filepath.Join(d, t)
	cleanDir := filepath.Clean(d)
	if v == cleanDir || strings.HasPrefix(v, cleanDir+string(os.PathSeparator)) {
		return v, nil
	}
	return "", fmt.Errorf("%s: %s", "content filepath is tainted", t)
}
// CreateTarArchive writes a tar archive of sourceDir to targetFile on the
// given afero filesystem. Entry names are made relative to sourceDir's
// parent, so the archive's top-level folder is sourceDir's base name.
func CreateTarArchive(afs afero.Afero, sourceDir, targetFile string) error {
	tarFile, err := afs.Create(targetFile)
	if err != nil {
		return fmt.Errorf("failed to create tar file: %w", err)
	}
	defer tarFile.Close()
	tw := tar.NewWriter(tarFile)
	defer tw.Close()
	// Walk through the source directory
	return afs.Walk(sourceDir, func(path string, info os.FileInfo, err error) error {
		if err != nil {
			return err
		}
		// Get the relative path (relative to sourceDir's parent, so the base
		// directory name is kept as the entry-name prefix)
		relPath, err := filepath.Rel(filepath.Dir(sourceDir), path)
		if err != nil {
			return fmt.Errorf("failed to get relative path: %w", err)
		}
		// Skip the source directory itself
		if path == sourceDir {
			return nil
		}
		// Create a tar header; the empty second argument means no symlink
		// target is recorded — NOTE(review): symlinked entries would be
		// archived with an empty linkname; confirm inputs contain none
		header, err := tar.FileInfoHeader(info, "")
		if err != nil {
			return fmt.Errorf("failed to create tar header: %w", err)
		}
		// Set the name to the relative path
		header.Name = relPath
		// Write the header
		if err := tw.WriteHeader(header); err != nil {
			return fmt.Errorf("failed to write tar header: %w", err)
		}
		// If it's a regular file, write the content
		if info.Mode().IsRegular() {
			file, err := afs.Open(path)
			if err != nil {
				return fmt.Errorf("failed to open file: %w", err)
			}
			// defer fires when this per-file closure returns, so each file is
			// closed after its content is copied
			defer file.Close()
			// copy in chunks to avoid decompression bomb
			for {
				_, err := io.CopyN(tw, file, 1024)
				if err != nil {
					if err == io.EOF {
						break
					}
					return fmt.Errorf("failed to copy file content to tar: %w", err)
				}
			}
		}
		return nil
	})
}
// IsSameShallowType reports whether a and b have identical dynamic types
// (a shallow reflect.TypeOf comparison; element types of containers are
// not inspected beyond what type identity implies).
func IsSameShallowType(a, b interface{}) bool {
	return reflect.TypeOf(a) == reflect.TypeOf(b)
}
// ConvertTypedSliceToUntypedSlice converts a slice of any element type into
// a []interface{} via reflection. It returns nil when typedSlice is not a
// slice.
func ConvertTypedSliceToUntypedSlice(typedSlice interface{}) []interface{} {
	v := reflect.ValueOf(typedSlice)
	if v.Kind() != reflect.Slice {
		return nil
	}
	out := make([]interface{}, v.Len())
	for i := range out {
		out[i] = v.Index(i).Interface()
	}
	return out
}
// MapKeysToSlice returns the keys of m as a slice, in unspecified order
// (map iteration order is randomized).
func MapKeysToSlice[R comparable, T any](m map[R]T) []R {
	out := make([]R, 0, len(m))
	for key := range m {
		out = append(out, key)
	}
	return out
}
// Copyright 2024, Nunet
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and limitations under the License.
package validate
import (
"regexp"
"strings"
)
// IsBlank reports whether s is empty or consists solely of whitespace.
func IsBlank(s string) bool {
	return strings.TrimSpace(s) == ""
}
// IsNotBlank reports whether s contains at least one non-whitespace
// character.
func IsNotBlank(s string) bool {
	return strings.TrimSpace(s) != ""
}
// IsLiteral reports whether the dynamic type of s is string.
func IsLiteral(s interface{}) bool {
	switch s.(type) {
	case string:
		return true
	default:
		return false
	}
}
// dnsNameRegex validates DNS names label by label:
//   - each label starts and ends with a letter or digit
//   - labels may contain hyphens internally and are 1-63 characters long
var dnsNameRegex = regexp.MustCompile(`^[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?(?:\.[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)*$`)

// IsDNSNameValid performs DNS name validation:
//   - starts with a letter or number
//   - contains only letters, numbers, dots, and hyphens
//   - each label starts/ends with a letter or number and is 1-63 chars
//   - total length between 1 and 253 characters
//
// Fix: the documented 253-character total-length limit (RFC 1035) cannot be
// expressed by the per-label regex and was previously not enforced; it is
// now checked explicitly.
func IsDNSNameValid(name string) bool {
	if len(name) == 0 || len(name) > 253 {
		return false
	}
	return dnsNameRegex.MatchString(name)
}
// dockerImageRegex matches references of the form [registry[:port]/]name[:tag]:
// an optional registry host (with optional port), one or more lowercase name
// segments using '.', '_' or '-' separators, and an optional tag.
var dockerImageRegex = regexp.MustCompile(`^(?:(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]*[a-zA-Z0-9])?(?:\.[a-zA-Z0-9](?:[a-zA-Z0-9-]*[a-zA-Z0-9])?)*(?::\d+)?)/)?(?:[a-z0-9]+(?:(?:[._-][a-z0-9]+)*)/)*[a-z0-9]+(?:(?:[._-][a-z0-9]+)*)?(?::[a-zA-Z0-9]+(?:(?:[._-][a-zA-Z0-9]+)*)?)?$`)

// IsDockerImageValid reports whether image conforms to the
// [registry/]name[:tag] docker image reference format checked by
// dockerImageRegex.
func IsDockerImageValid(image string) bool {
	ok := dockerImageRegex.MatchString(image)
	return ok
}
// envVarKeyRegex matches POSIX-style environment variable names.
var envVarKeyRegex = regexp.MustCompile("^[a-zA-Z_][a-zA-Z0-9_]*$")

// IsEnvVarKeyValid reports whether key is a valid environment variable
// name: a leading letter or underscore followed by any mix of letters,
// digits, and underscores.
func IsEnvVarKeyValid(key string) bool {
	return envVarKeyRegex.MatchString(key)
}