package api
import (
"fmt"
"os"
"time"
"github.com/gin-contrib/cors"
"github.com/gin-gonic/gin"
ginSwagger "github.com/swaggo/gin-swagger"
"github.com/swaggo/gin-swagger/swaggerFiles"
"gitlab.com/nunet/device-management-service/dms/onboarding"
"gitlab.com/nunet/device-management-service/dms/resources"
"gitlab.com/nunet/device-management-service/network/libp2p"
"gitlab.com/nunet/device-management-service/telemetry/logger"
)
// RESTServerConfig bundles the dependencies and listen settings handed to
// NewRESTServer.
type RESTServerConfig struct {
// P2P is the libp2p node used by the /peers handlers. The Onboard handler
// overwrites it with the node returned by the onboarding process.
P2P *libp2p.Libp2p
// Onboarding backs the /onboarding handlers.
Onboarding *onboarding.Onboarding
Logger *logger.Logger
// Resource is queried by ProvisionedCapacity for the system specs.
Resource resources.Manager
// MidW holds extra gin middlewares installed on the router; the CORS
// middleware is appended after them by setupRouter.
MidW []gin.HandlerFunc
// Port and Addr are combined as "Addr:Port" to form the listen address in Run.
Port uint32
Addr string
}
// RESTServer represents a HTTP server.
// It pairs the gin engine that serves requests with the configuration the
// request handlers read their dependencies from.
type RESTServer struct {
// router is the gin engine built by setupRouter.
router *gin.Engine
// config is shared by pointer with the caller of NewRESTServer.
config *RESTServerConfig
}
// NewRESTServer builds a RESTServer around config.
//
// The gin router is created right away with config.MidW (plus the CORS
// middleware) installed. Routes are not registered here; call
// InitializeRoutes for that. The config pointer is stored as-is, not copied.
func NewRESTServer(config *RESTServerConfig) *RESTServer {
	server := new(RESTServer)
	server.router = setupRouter(config.MidW)
	server.config = config
	return server
}
// setupRouter creates a gin engine via gin.Default and installs the given
// middlewares followed by the CORS middleware built from
// getCustomCorsConfig.
//
// The middlewares are copied into a fresh slice before the CORS handler is
// appended: appending directly to mid could write into the caller's backing
// array when that slice has spare capacity.
func setupRouter(mid []gin.HandlerFunc) *gin.Engine {
	handlers := make([]gin.HandlerFunc, 0, len(mid)+1)
	handlers = append(handlers, mid...)
	handlers = append(handlers, cors.New(getCustomCorsConfig()))

	router := gin.Default()
	router.Use(handlers...)
	return router
}
// InitializeRoutes registers every REST endpoint under /api/v1 on the
// server's router, plus the Swagger UI handler under /swagger.
//
// The /peers/ping and /peers/dht routes are registered only when the
// NUNET_DEBUG environment variable is present (its value is not inspected).
// The DumpDHT handler is currently not routed.
func (rs *RESTServer) InitializeRoutes() {
	apiV1 := rs.router.Group("/api/v1")

	// Onboarding lifecycle and resource configuration.
	onboardingRoutes := apiV1.Group("/onboarding")
	onboardingRoutes.GET("/provisioned", rs.ProvisionedCapacity)
	onboardingRoutes.GET("/address/new", rs.CreatePaymentAddress)
	onboardingRoutes.GET("/status", rs.Status)
	onboardingRoutes.GET("/info", rs.Info)
	onboardingRoutes.POST("/onboard", rs.Onboard)
	onboardingRoutes.POST("/resource-config", rs.ResourceConfig)
	onboardingRoutes.DELETE("/offboard", rs.Offboard)

	// Device availability.
	deviceRoutes := apiV1.Group("/device")
	deviceRoutes.GET("/status", rs.DeviceStatus)
	deviceRoutes.POST("/status", rs.UpdateDeviceStatus)

	// Manual VM start-up.
	vmRoutes := apiV1.Group("/vm")
	vmRoutes.POST("/start-default", rs.StartDefault)
	vmRoutes.POST("/start-custom", rs.StartCustom)

	// Peer inspection.
	peerRoutes := apiV1.Group("/peers")
	peerRoutes.GET("", rs.ListPeers)
	peerRoutes.GET("/self", rs.SelfPeerInfo)
	// DEBUGGING ONLY
	if _, debugEnabled := os.LookupEnv("NUNET_DEBUG"); debugEnabled {
		peerRoutes.GET("/ping", rs.PingPeer)
		peerRoutes.GET("/dht", rs.KnownPeers)
	}

	// Swagger API documentation
	rs.router.GET("/swagger/*any", ginSwagger.WrapHandler(swaggerFiles.Handler))
}
// Run starts the server on the "Addr:Port" address taken from the config and
// returns whatever error the gin engine's Run reports.
func (rs *RESTServer) Run() error {
	listenAddr := fmt.Sprintf("%s:%d", rs.config.Addr, rs.config.Port)
	return rs.router.Run(listenAddr)
}
// getCustomCorsConfig returns defaultConfig with the allowed origins set to
// the two hard-coded localhost origins (ports 9991 and 9992).
func getCustomCorsConfig() cors.Config {
	corsCfg := defaultConfig()

	// FIXME: This is a security concern.
	allowedOrigins := []string{"http://localhost:9991", "http://localhost:9992"}
	corsCfg.AllowOrigins = allowedOrigins

	return corsCfg
}
// defaultConfig returns the base CORS configuration: the common HTTP methods,
// a fixed header allow-list, credentials disabled and a 12 hour max age.
// It sets no allowed origins; callers are expected to fill those in.
func defaultConfig() cors.Config {
	var base cors.Config
	base.AllowMethods = []string{"GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"}
	base.AllowHeaders = []string{"Access-Control-Allow-Origin", "Origin", "Content-Length", "Content-Type"}
	base.AllowCredentials = false
	base.MaxAge = 12 * time.Hour
	return base
}
package api
import (
"fmt"
"net/http"
"time"
"github.com/gin-gonic/gin"
"github.com/libp2p/go-libp2p/core/peer"
)
// PingPeer pings the peer named by the "peerID" query parameter and reports
// the outcome. DEBUG only: the route is registered solely when NUNET_DEBUG
// is set.
//
// Responses:
//   - 400 when peerID is missing, equals this node's own ID, or cannot be
//     decoded as a peer ID;
//   - 500 when the host node has not been initialized or the ping fails;
//   - 200 with a message carrying the success flag and RTT otherwise.
func (rs RESTServer) PingPeer(c *gin.Context) {
	// Guard against a nil node, like the other peer handlers do; without it
	// the Host lookup below panics before onboarding has set up libp2p.
	if rs.config.P2P == nil {
		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "host node hasn't yet been initialized"})
		return
	}

	reqCtx := c.Request.Context()

	id := c.Query("peerID")
	if id == "" {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "peerID not provided"})
		return
	}
	if id == rs.config.P2P.Host.ID().String() {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "peerID can not be self peerID"})
		return
	}

	// decode only for validation
	target, err := peer.Decode(id)
	if err != nil {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid string ID: could not decode string ID to peer ID"})
		return
	}

	// The ping is given a fixed five second timeout.
	res, err := rs.config.P2P.Ping(reqCtx, target.String(), time.Second*5)
	if err != nil {
		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("could not ping peer %s: %v", id, err)})
		return
	}

	c.JSON(http.StatusOK, gin.H{"message": fmt.Sprintf("ping peer %s, success=%t, RTT=%d", id, res.Success, res.RTT)})
}
package api
import (
"net/http"
"github.com/gin-gonic/gin"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
)
// DeviceStatus godoc
//
// DeviceStatus is currently a stub: it always aborts the request with a 500
// "device status not implemented" error.
//
// @Summary Retrieve device status
// @Description Retrieve device status whether paused/offline (unable to receive job deployments) or online
// @Tags device
// @Produce json
// @Success 200 {object} object
// @Failure 500 {object} object "host node has not yet been initialized"
// @Failure 500 {object} object "could not retrieve data from peer"
// @Failure 500 {object} object "failed to type assert peer data for peer ID"
// @Router /device/status [get]
func (rs RESTServer) DeviceStatus(c *gin.Context) {
	// TODO: handle this after refactor. The intended flow is:
	//   status, err := libp2p.DeviceStatus()
	//   on error -> 500 {"error": "could not retrieve device status"}
	//   otherwise -> 200 {"online": status}
	notImplemented := gin.H{"error": "device status not implemented"}
	c.AbortWithStatusJSON(http.StatusInternalServerError, notImplemented)
}
// UpdateDeviceStatus godoc
//
// UpdateDeviceStatus validates the request payload and then reports that the
// operation is not implemented (500). The status change itself is pending a
// refactor; once it is wired in, the handler should answer 200 with
// "Device status successfully changed to online" or
// "Device status successfully changed to offline" depending on is_available.
//
// @Summary Update device status between online/offline
// @Description Update device status to online (able to receive jobs) or offline (unable to receive jobs).
// @Tags device
// @Produce json
// @Failure 400 {object} object "empty content data"
// @Failure 400 {object} object "invalid payload data"
// @Failure 500 {object} object "host node has not yet been initialized"
// @Failure 500 {object} object "could not retrieve data from self peer"
// @Failure 500 {object} object "failed to type assert peer data for peer ID"
// @Failure 500 {object} object "Failed to retrieve libp2p info from database"
// @Failure 500 {object} object "Failed to update libp2p info in database"
// @Failure 500 {object} object "failed to put peer data into peerstore"
// @Success 200 {object} object "Device status successfully changed to online"
// @Success 200 {object} object "Device status successfully changed to offline"
// @Success 200 {object} object "no change in device status"
// @Router /device/status [post]
func (rs RESTServer) UpdateDeviceStatus(c *gin.Context) {
	span := trace.SpanFromContext(c.Request.Context())
	span.SetAttributes(attribute.String("URL", "/device/status"))

	var status struct {
		IsAvailable bool `json:"is_available"`
	}

	if c.Request.ContentLength == 0 {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "empty content data"})
		return
	}

	if err := c.ShouldBindJSON(&status); err != nil {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid payload data"})
		return
	}

	// TODO: handle this after refactor
	// err = libp2p.ChangeDeviceStatus(status.IsAvailable)
	// if err != nil {
	// 	c.AbortWithStatusJSON(500, gin.H{"error": err.Error()})
	// 	return
	// }
	//
	// Until then the request ends here. Previously execution fell through
	// and wrote a second, 200 "successfully changed" JSON body onto the same
	// response after the 500 had already been sent.
	c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "change device status not implemented"})
}
package api
import (
"net/http"
"strconv"
"github.com/gin-gonic/gin"
"gitlab.com/nunet/device-management-service/dms/onboarding"
"gitlab.com/nunet/device-management-service/types"
)
// // OnboardingHandler is a controller for /onboarding endpoint functionalities
// type OnboardingHandler struct {
// service *onboarding.Onboarding
// }
// // NewOnboardingHandler is a constructor for OnboardingHandler
// func NewOnboardingHandler(s *onboarding.Onboarding) OnboardingHandler {
// return OnboardingHandler{service: s}
// }
// ProvisionedCapacity godoc
//
// ProvisionedCapacity answers with the provisioned resources reported by the
// configured resource manager's system specs, or a 500 carrying the error
// text when they cannot be obtained.
//
// @Summary Returns provisioned capacity on host.
// @Description Get total memory capacity in MB and CPU capacity in MHz.
// @Tags onboarding
// @Produce json
// @Success 200 {object} types.Provisioned
// @Router /onboarding/provisioned [get]
func (rs *RESTServer) ProvisionedCapacity(c *gin.Context) {
	specs := rs.config.Resource.SystemSpecs()

	provisioned, err := specs.GetProvisionedResources()
	if err != nil {
		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	c.JSON(http.StatusOK, provisioned)
}
// CreatePaymentAddress godoc
//
// CreatePaymentAddress creates a payment address for the blockchain named by
// the "blockchain" query parameter, defaulting to "cardano" when the
// parameter is absent. Failures are reported as 500 with the error text.
//
// @Summary Create a new payment address.
// @Description Create a payment address from public key. Return payment address and private key.
// @Tags onboarding
// @Produce json
// @Success 200 {object} types.BlockchainAddressPrivKey
// @Router /onboarding/address/new [get]
func (rs RESTServer) CreatePaymentAddress(c *gin.Context) {
	blockchain := c.DefaultQuery("blockchain", "cardano")

	keyPair, err := onboarding.CreatePaymentAddress(blockchain)
	if err == nil {
		c.JSON(http.StatusOK, keyPair)
		return
	}

	c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
// Onboard godoc
//
// Onboard decodes the requested capacity from the JSON body (ServerMode and
// IsAvailable default to true when the body omits them), runs the onboarding
// process and, on success, stores the returned libp2p node in the server
// config before answering 201 with the onboarding config.
//
// @Summary Runs the onboarding process.
// @Description Onboard runs onboarding script given the amount of resources to onboard.
// @Tags onboarding
// @Produce json
// @Param capacity body types.CapacityForNunet true "Capacity for NuNet"
// @Success 201 {object} types.OnboardingConfig
// @Failure 400 {object} object "invalid request data"
// @Failure 500 {object} object "could not check if config directory exists"
// @Failure 500 {object} object "config directory does not exist"
// @Failure 500 {object} object "could not validate payment address"
// @Failure 500 {object} object "could not validate capacity data"
// @Failure 500 {object} object "cardano node requires 10000MB of RAM and 6000MHz CPU"
// @Failure 500 {object} object "invalid channel data, channel does not exist"
// @Failure 500 {object} object "unable to create available resources table"
// @Failure 500 {object} object "unable to update available resources table"
// @Failure 500 {object} object "could not calculate free resources and update database"
// @Failure 500 {object} object "could not register and run new node"
// @Router /onboarding/onboard [post]
func (rs *RESTServer) Onboard(c *gin.Context) {
	// Pre-populate the defaults; fields present in the body override them.
	request := types.CapacityForNunet{ServerMode: true, IsAvailable: true}

	bindErr := c.BindJSON(&request)
	if bindErr != nil {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid request data"})
		return
	}

	onboardCfg, node, err := rs.config.Onboarding.Onboard(c.Request.Context(), request)
	if err != nil {
		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	// Later peer requests are served by the freshly created node.
	rs.config.P2P = node

	c.JSON(http.StatusCreated, onboardCfg)
}
// Offboard godoc
//
// Offboard runs the offboarding process. The optional "force" query parameter
// is parsed with strconv.ParseBool and defaults to false; an unparsable value
// yields 400. Offboarding failures yield 500 with the error text.
//
// The route is registered with the DELETE method, so the @Router annotation
// below declares [delete] (it previously said [post]).
//
// @Summary Runs the offboarding process.
// @Description Offboard runs offboarding process to remove the machine from the NuNet network.
// @Tags onboarding
// @Produce json
// @Success 200 {string} string "device successfully offboarded"
// @Param force query string false "force offboarding"
// @Failure 400 {object} object "invalid query data"
// @Failure 500 {object} object "could not retrieve onboard status"
// @Failure 500 {object} object "machine is not onboarded"
// @Failure 500 {object} object "unable to shutdown node"
// @Failure 500 {object} object "unable to delete available resources on database"
// @Failure 500 {object} object "could not remove payment address"
// @Router /onboarding/offboard [delete]
func (rs RESTServer) Offboard(c *gin.Context) {
	query := c.DefaultQuery("force", "false")
	force, err := strconv.ParseBool(query)
	if err != nil {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid query data"})
		return
	}

	if err := rs.config.Onboarding.Offboard(c.Request.Context(), force); err != nil {
		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	c.JSON(http.StatusOK, gin.H{"message": "device successfully offboarded"})
}
// Status godoc
//
// Status answers 200 with {"onboarded": <bool>} as reported by the onboarding
// service, or 500 with the error text when the check itself fails.
//
// @Summary Returns whether device is onboarded or not.
// @Tags onboarding
// @Produce json
// @Success 200 {boolean}
// @Router /onboarding/status [get]
func (rs RESTServer) Status(c *gin.Context) {
	isOnboarded, err := rs.config.Onboarding.IsOnboarded(c.Request.Context())
	if err == nil {
		c.JSON(http.StatusOK, gin.H{"onboarded": isOnboarded})
		return
	}

	c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
// Info godoc
//
// Info answers 200 with the onboarding information wrapped under the "info"
// key, or 500 with the error text when it cannot be retrieved.
//
// @Summary Returns additional information about onboarded device.
// @Tags onboarding
// @Produce json
// @Success 200 {object} types.OnboardingConfig
// @Router /onboarding/info [get]
func (rs RESTServer) Info(c *gin.Context) {
	onboardingInfo, err := rs.config.Onboarding.Info(c.Request.Context())
	if err == nil {
		c.JSON(http.StatusOK, gin.H{"info": onboardingInfo})
		return
	}

	c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
// ResourceConfig godoc
//
// ResourceConfig changes the onboarded resources to the capacity given in the
// JSON body. An empty or undecodable body yields 400. When the onboarding
// service reports ErrMachineNotOnboarded the answer is 409; any other service
// error is answered with 400 and the error text.
//
// @Summary changes the amount of resources of onboarded device .
// @Tags onboarding
// @Produce json
// @Success 200 {object} types.OnboardingConfig
// @Router /onboarding/resource-config [post]
func (rs RESTServer) ResourceConfig(c *gin.Context) {
	if c.Request.ContentLength == 0 {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "request body is empty"})
		return
	}

	var requested types.CapacityForNunet
	if bindErr := c.BindJSON(&requested); bindErr != nil {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid request data"})
		return
	}

	updatedCfg, err := rs.config.Onboarding.ResourceConfig(c.Request.Context(), requested)
	if err != nil {
		// Only the "not onboarded" sentinel maps to a conflict.
		code := http.StatusBadRequest
		if err == onboarding.ErrMachineNotOnboarded {
			code = http.StatusConflict
		}
		c.AbortWithStatusJSON(code, gin.H{"error": err.Error()})
		return
	}

	c.JSON(http.StatusOK, updatedCfg)
}
package api
import (
"net/http"
"github.com/gin-gonic/gin"
)
// ListPeers godoc
//
// ListPeers answers 200 with the node's visible peers. It answers 500 when
// the libp2p node is not set in the config or when the peer list is empty.
//
// @Summary Return list of peers currently connected to
// @Description Gets a list of peers the libp2p node can see within the network and return a list of peers
// @Tags p2p
// @Produce json
// @Failure 500 {object} object "no peers yet"
// @Failure 500 {object} object "host node hasn't yet been initialized"
// @Success 200 {object} object "list of peers"
// @Router /peers [get]
func (rs RESTServer) ListPeers(c *gin.Context) {
	node := rs.config.P2P
	if node == nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "host node hasn't yet been initialized"})
		return
	}

	visible := node.VisiblePeers()
	if len(visible) > 0 {
		c.JSON(http.StatusOK, visible)
		return
	}

	c.JSON(http.StatusInternalServerError, gin.H{"error": "no peers yet"})
}
// KnownPeers godoc
//
// KnownPeers answers 200 with the peers returned by the node's KnownPeers
// call, 404 when that list is empty, and 500 when the node is not set in the
// config or the lookup fails. The route is only registered in debug mode.
//
// @Summary Return list of peers which have sent a dht update
// @Description Gets a list of peers the libp2p node has received a dht update from
// @Tags p2p
// @Produce json
// @Success 200 {object} object "List of peers"
// @Failure 404 {object} object "No peers found"
// @Failure 500 {object} object "Host Node hasn't yet been initialized"
// @Router /peers/dht [get]
func (rs RESTServer) KnownPeers(c *gin.Context) {
	node := rs.config.P2P
	if node == nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "host node hasn't yet been initialized"})
		return
	}

	known, err := node.KnownPeers()
	switch {
	case err != nil:
		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
	case len(known) == 0:
		c.JSON(http.StatusNotFound, gin.H{"message": "no peers found"})
	default:
		c.JSON(http.StatusOK, known)
	}
}
// SelfPeerInfo godoc
//
// SelfPeerInfo answers 200 with the result of the node's Stat call, or 500
// when the libp2p node is not set in the config.
//
// @Summary Return self peer info
// @Description Gets self peer info of libp2p node
// @Tags p2p
// @Produce json
// @Success 200 {object} object "Self Peer Info"
// @Failure 500 {object} object "host node hasn't yet been initialized"
// @Router /peers/self [get]
func (rs RESTServer) SelfPeerInfo(c *gin.Context) {
	node := rs.config.P2P
	if node == nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "host node hasn't yet been initialized"})
		return
	}

	c.JSON(http.StatusOK, node.Stat())
}
// DumpDHT godoc
//
// DumpDHT answers 200 with the node's DHT routing table dump. It answers 500
// when the node is not set in the config, when the dump fails, or when the
// dump is empty (in which case the body uses the "message" key).
//
// @Summary Return a dump of the dht
// @Description Returns entire DHT content
// @Tags p2p
// @Produce json
// @Success 200 {object} object "List of DHT peers"
// @Failure 500 {object} object "host node hasn't yet been initialized"
// @Failure 500 {object} object "no content in DHT"
// @Router /peers/dht/dump [get]
func (rs RESTServer) DumpDHT(c *gin.Context) {
	node := rs.config.P2P
	if node == nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "host node hasn't yet been initialized"})
		return
	}

	routingTable, err := node.DumpDHTRoutingTable()
	switch {
	case err != nil:
		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
	case len(routingTable) == 0:
		c.JSON(http.StatusInternalServerError, gin.H{"message": "empty DHT"})
	default:
		c.JSON(http.StatusOK, routingTable)
	}
}
package api
import (
"net/http"
"github.com/gin-gonic/gin"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"gitlab.com/nunet/device-management-service/executor/firecracker"
"gitlab.com/nunet/device-management-service/types"
)
// CustomVM is the JSON request body decoded by StartCustom.
type CustomVM struct {
// KernelImagePath and FilesystemPath are handed to the firecracker engine builder.
KernelImagePath string `json:"kernel_image_path"`
FilesystemPath string `json:"filesystem_path"`
// VCPUCount becomes the CPU core count of the execution request.
VCPUCount int32 `json:"vcpu_count"`
// MemSizeMib becomes the RAM size of the execution request.
MemSizeMib int `json:"mem_size_mib"`
// TapDevice is decoded but not read by StartCustom in this file.
TapDevice string `json:"tap_device"`
}
// DefaultVM is the JSON request body decoded by StartDefault.
type DefaultVM struct {
// KernelImagePath and FilesystemPath are handed to the firecracker engine builder.
KernelImagePath string `json:"kernel_image_path"`
FilesystemPath string `json:"filesystem_path"`
// PublicKey and NodeID are decoded but not read by StartDefault in this file.
PublicKey string `json:"public_key"`
NodeID string `json:"node_id"`
}
// StartCustom godoc
//
// StartCustom decodes a CustomVM body, builds a firecracker engine spec from
// the filesystem and kernel image paths, and starts it through a new
// firecracker executor with the requested vCPU count and memory size.
//
// Negative vcpu_count or mem_size_mib values are rejected with 400: converted
// unchecked, a negative vCPU count would wrap around to a huge uint32.
//
// @Summary Start a VM with custom configuration.
// @Description This endpoint is an abstraction of all primitive endpoints. When invokend, it calls all primitive endpoints in a sequence.
// @Tags vm
// @Produce json
// @Param body body firecracker.CustomVM true "body"
// @Success 200 {object} string "VM started successfully."
// @Failure 400 {object} string "invalid request body"
// @Failure 500 {object} string "could not create database table"
// @Failure 500 {object} string "could not initialize virtual machine"
// @Failure 500 {object} string "failed to configure drives"
// @Failure 500 {object} string "failed to configure machine config"
// @Failure 500 {object} string "failed to configure network-interfaces"
// @Failure 500 {object} string "failed to setup MMDS"
// @Failure 500 {object} string "failed to pass MMDS message"
// @Failure 500 {object} string "unable to start virtual machine"
// @Router /vm/start-custom [post]
func (rs RESTServer) StartCustom(c *gin.Context) {
	reqCtx := c.Request.Context()
	span := trace.SpanFromContext(reqCtx)
	span.SetAttributes(attribute.String("URL", "/vm/start-custom"))

	var body CustomVM
	if err := c.BindJSON(&body); err != nil {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid request body"})
		return
	}

	// Refuse sizes that cannot be represented after the unsigned conversion.
	if body.VCPUCount < 0 || body.MemSizeMib < 0 {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid request body: vcpu_count and mem_size_mib must not be negative"})
		return
	}

	fe := firecracker.NewFirecrackerEngineBuilder(body.FilesystemPath).
		WithKernelImage(body.KernelImagePath).
		Build()

	// Job and execution IDs are fixed placeholders for manual starts.
	fer := &types.ExecutionRequest{
		JobID:       "test_job",
		ExecutionID: "test_execution",
		EngineSpec:  fe,
		Resources: &types.ExecutionResources{
			CPU:    types.CPU{Cores: uint32(body.VCPUCount)},
			Memory: types.RAM{Size: int64(body.MemSizeMib)},
		},
	}

	fc, err := firecracker.NewExecutor(reqCtx, "manual-start-custom")
	if err != nil {
		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	if err := fc.Start(reqCtx, fer); err != nil {
		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	c.JSON(http.StatusOK, gin.H{"message": "VM started successfully"})
}
// StartDefault godoc
//
// StartDefault decodes a DefaultVM body, builds a firecracker engine spec
// from the filesystem and kernel image paths, and starts it through a new
// firecracker executor with a fixed allocation of 1 CPU core and a RAM size
// of 1024.
//
// @Summary Start a VM with default configuration.
// @Description Kernel file and filesystem file needs to be passed in body. This endpoint is an abstraction of all primitive endpoints.
// @Tags vm
// @Produce json
// @Param body body firecracker.DefaultVM true "body"
// @Success 200 {object} string "VM started successfully."
// @Failure 400 {object} string "invalid request body"
// @Failure 500 {object} string "could not initialize virtual machine"
// @Failure 500 {object} string "failed to confiugre boot source"
// @Failure 500 {object} string "failed to configure drives"
// @Failure 500 {object} string "failed to configure machineConfig"
// @Failure 500 {object} string "failed to configure network-interfaces"
// @Failure 500 {object} string "failed to setup MMDS"
// @Failure 500 {object} string "failed to pass MMDS message"
// @Failure 500 {object} string "unable to start virtual machine"
// @Router /vm/start-default [post]
func (rs RESTServer) StartDefault(c *gin.Context) {
	reqCtx := c.Request.Context()
	trace.SpanFromContext(reqCtx).SetAttributes(attribute.String("URL", "/vm/start-default"))

	var payload DefaultVM
	if bindErr := c.BindJSON(&payload); bindErr != nil {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid request body"})
		return
	}

	builder := firecracker.NewFirecrackerEngineBuilder(payload.FilesystemPath)
	engineSpec := builder.WithKernelImage(payload.KernelImagePath).Build()

	// Job and execution IDs are fixed placeholders for manual starts.
	resources := &types.ExecutionResources{
		CPU:    types.CPU{Cores: 1},
		Memory: types.RAM{Size: 1024},
	}
	request := &types.ExecutionRequest{
		JobID:       "test_job",
		ExecutionID: "test_execution",
		EngineSpec:  engineSpec,
		Resources:   resources,
	}

	executor, err := firecracker.NewExecutor(reqCtx, "manual-start-default")
	if err != nil {
		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	if startErr := executor.Start(reqCtx, request); startErr != nil {
		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": startErr.Error()})
		return
	}

	c.JSON(http.StatusOK, gin.H{"message": "VM started successfully"})
}
package cmd
import (
"fmt"
"math"
"regexp"
"strconv"
)
// amdGPU identifies one AMD GPU whose properties are read by parsing the
// output of rocm-smi commands.
type amdGPU struct {
// index is the N matched against the "GPU[N]" prefix in rocm-smi output lines.
index int
}
// name returns the "Card series" value that `rocm-smi --showproductname`
// prints for this GPU index. It returns "" when the command fails or no
// matching line is found.
func (a *amdGPU) name() string {
	seriesRe := regexp.MustCompile(fmt.Sprintf(`GPU\[%d\]\s+: Card series:\s+(.+)`, a.index))

	output, err := runShellCmd("rocm-smi --showproductname")
	if err != nil {
		return ""
	}

	groups := seriesRe.FindStringSubmatch(output)
	if len(groups) < 2 {
		return ""
	}
	return groups[1]
}
// utilizationRate returns the "GPU use (%)" figure that `rocm-smi --showuse`
// prints for this GPU index. It returns 0 when the command fails, no
// matching line is found, or the number does not parse as a 32-bit integer.
func (a *amdGPU) utilizationRate() int64 {
	useRe := regexp.MustCompile(fmt.Sprintf(`GPU\[%d\]\s+: GPU use \(%%\): (\d+)`, a.index))

	output, err := runShellCmd("rocm-smi --showuse")
	if err != nil {
		return 0
	}

	groups := useRe.FindStringSubmatch(output)
	if len(groups) < 2 {
		return 0
	}

	percent, parseErr := strconv.ParseInt(groups[1], 10, 32)
	if parseErr != nil {
		return 0
	}
	return percent
}
// memory returns the VRAM totals, in bytes, that `rocm-smi --showmeminfo vram`
// prints for this GPU index.
//
// It returns a zero memoryInfo when the command fails or when either the
// total or the used line is missing. A figure that fails to parse is treated
// as 0. The free amount is total minus used, clamped at 0 so that an
// inconsistent reading (used greater than total, e.g. after a failed parse of
// the total) cannot wrap around to a huge unsigned value.
func (a *amdGPU) memory() memoryInfo {
	patternTotal := fmt.Sprintf(`GPU\[%d\]\s+: vram Total Memory \(B\): (\d+)`, a.index)
	reTotal := regexp.MustCompile(patternTotal)

	patternUsed := fmt.Sprintf(`GPU\[%d\]\s+: vram Total Used Memory \(B\): (\d+)`, a.index)
	reUsed := regexp.MustCompile(patternUsed)

	rocmOutput, err := runShellCmd("rocm-smi --showmeminfo vram")
	if err != nil {
		return memoryInfo{}
	}

	matchTotal := reTotal.FindStringSubmatch(rocmOutput)
	matchUsed := reUsed.FindStringSubmatch(rocmOutput)
	if len(matchTotal) < 2 || len(matchUsed) < 2 {
		return memoryInfo{}
	}

	total, err := strconv.ParseInt(matchTotal[1], 10, 64)
	if err != nil {
		total = 0
	}

	used, err := strconv.ParseInt(matchUsed[1], 10, 64)
	if err != nil {
		used = 0
	}

	// Both values come from \d+ so they are never negative; only the
	// difference can be.
	free := total - used
	if free < 0 {
		free = 0
	}

	return memoryInfo{
		used:  uint64(used),
		free:  uint64(free),
		total: uint64(total),
	}
}
// temperature returns the "Temperature (Sensor edge) (C)" reading that
// `rocm-smi --showtemp` prints for this GPU index. It returns 0 when the
// command fails, no matching line is found, or the value does not parse.
func (a *amdGPU) temperature() float64 {
	tempRe := regexp.MustCompile(fmt.Sprintf(`GPU\[%d\]\s+: Temperature \(Sensor edge\) \(C\): ([\d\.]+)`, a.index))

	output, err := runShellCmd("rocm-smi --showtemp")
	if err != nil {
		return 0
	}

	groups := tempRe.FindStringSubmatch(output)
	if len(groups) < 2 {
		return 0
	}

	celsius, parseErr := strconv.ParseFloat(groups[1], 64)
	if parseErr != nil {
		return 0
	}
	return celsius
}
// powerUsage returns the "Average Graphics Package Power (W)" reading that
// `rocm-smi --showpower` prints for this GPU index, rounded to the nearest
// integer. It returns 0 when the command fails, no matching line is found,
// or the value does not parse.
func (a *amdGPU) powerUsage() uint32 {
	powerRe := regexp.MustCompile(fmt.Sprintf(`GPU\[%d\]\s+: Average Graphics Package Power \(W\): ([\d\.]+)`, a.index))

	output, err := runShellCmd("rocm-smi --showpower")
	if err != nil {
		return 0
	}

	groups := powerRe.FindStringSubmatch(output)
	if len(groups) < 2 {
		return 0
	}

	watts, parseErr := strconv.ParseFloat(groups[1], 64)
	if parseErr != nil {
		return 0
	}
	return uint32(math.Round(watts))
}
package cmd
import (
"os"
"github.com/spf13/cobra"
)
// newAutoCompleteCmd builds the hidden `autocomplete [shell]` command, which
// writes a completion script for the root command to standard output.
// Exactly one argument is required and it must be "bash" or "zsh". Errors
// from the generators are deliberately discarded.
func newAutoCompleteCmd() *cobra.Command {
	generate := func(cmd *cobra.Command, args []string) {
		shell := args[0]
		if shell == "bash" {
			_ = cmd.Root().GenBashCompletion(os.Stdout)
		} else if shell == "zsh" {
			_ = cmd.Root().GenZshCompletion(os.Stdout)
		}
	}

	autoComplete := &cobra.Command{
		Use:   "autocomplete [shell]",
		Short: "Generate autocomplete script for your shell",
		Long: `Generate an autocomplete script for the nunet CLI.
This command supports Bash and Zsh shells.`,
		DisableFlagsInUseLine: true,
		Hidden:                true,
		ValidArgs:             []string{"bash", "zsh"},
		Args:                  cobra.MatchAll(cobra.ExactArgs(1), cobra.OnlyValidArgs),
		Run:                   generate,
	}
	return autoComplete
}
package cmd
import (
"encoding/json"
"fmt"
"github.com/buger/jsonparser"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils"
)
// newCapacityCmd is a constructor for the `capacity` command.
//
// The command fetches /onboarding/info through client, decodes the value
// under the "info" key into a types.OnboardingConfig and appends one table
// row per requested view: --full (total resources), --onboarded (onboarded
// resources) and --available (total minus onboarded). At least one of the
// three flags is required.
//
// A missing or JSON-null "info" value is reported as an error instead of
// being dereferenced. The table is rendered when the command returns, also
// on error paths.
func newCapacityCmd(client *utils.HTTPClient) *cobra.Command {
	fnAvailable := "available"
	fnOnboarded := "onboarded"
	fnFull := "full"

	cmd := &cobra.Command{
		Use:   "capacity",
		Short: "Display capacity of device resources",
		Long:  `Retrieve capacity of the machine, onboarded or available amount of resources`,
		RunE: func(cmd *cobra.Command, _ []string) error {
			table := setupTable(cmd.OutOrStdout())
			defer table.Render()

			resBody, resCode, err := client.MakeRequest("GET", "/onboarding/info", nil)
			if err != nil {
				return fmt.Errorf("could not make request: %w", err)
			}
			if resCode != 200 {
				return fmt.Errorf("request failed with status code: %d", resCode)
			}

			info, _, _, err := jsonparser.Get(resBody, "info")
			if err != nil {
				return fmt.Errorf("could not get info value by key: %w", err)
			}

			var config *types.OnboardingConfig
			if err := json.Unmarshal(info, &config); err != nil {
				return fmt.Errorf("could not unmarshal response body: %w", err)
			}
			// Unmarshalling a JSON null leaves the pointer nil.
			if config == nil {
				return fmt.Errorf("onboarding info is empty")
			}

			full, _ := cmd.Flags().GetBool(fnFull)
			available, _ := cmd.Flags().GetBool(fnAvailable)
			onboarded, _ := cmd.Flags().GetBool(fnOnboarded)

			if full {
				table.Append(formatCapacityData("Full", &config.TotalResources.Resources))
			}
			if onboarded {
				table.Append(formatCapacityData("Onboarded", &config.OnboardedResources.Resources))
			}
			if available {
				// NOTE(review): only RAM and Disk are computed here, while the
				// row formatter prints RAM, CPU and NumCores; CPU and cores
				// therefore show as zero for this row. Confirm whether intended.
				free := &types.Resources{
					RAM:  config.TotalResources.RAM - config.OnboardedResources.RAM,
					Disk: config.TotalResources.Disk - config.OnboardedResources.Disk,
				}
				table.Append(formatCapacityData("Available", free))
			}

			return nil
		},
	}

	cmd.Flags().BoolP(fnFull, "f", false, "display device capacity")
	cmd.Flags().BoolP(fnAvailable, "a", false, "display amount of resources still available for onboarding")
	cmd.Flags().BoolP(fnOnboarded, "o", false, "display amount of onboarded resources")
	cmd.MarkFlagsOneRequired(fnAvailable, fnFull, fnOnboarded)

	return cmd
}
// formatCapacityData turns a resource set into one table row of four cells:
// the given label, then RAM (%d), CPU (%f) and NumCores (%d).
func formatCapacityData(resourceType string, resources *types.Resources) []string {
	ram := fmt.Sprintf("%d", resources.RAM)
	cpu := fmt.Sprintf("%f", resources.CPU)
	cores := fmt.Sprintf("%d", resources.NumCores)

	return []string{resourceType, ram, cpu, cores}
}
package cmd
import (
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/utils"
)
// newDeviceCmd is a constructor for the `device` parent command. Run on its
// own it prints the help text; the real work lives in the `status` and `set`
// subcommands, both of which share the given HTTP client.
func newDeviceCmd(client *utils.HTTPClient) *cobra.Command {
	showHelp := func(cmd *cobra.Command, _ []string) {
		if helpErr := cmd.Help(); helpErr != nil {
			cmd.Println(helpErr)
		}
	}

	device := &cobra.Command{
		Use:   "device",
		Short: "device related operations",
		Long:  `manage onboarded device`,
		Run:   showHelp,
	}

	device.AddCommand(newDeviceStatusCmd(client))
	device.AddCommand(newDeviceSetCmd(client))

	return device
}
package cmd
import (
"encoding/json"
"fmt"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/utils"
)
// newDeviceSetCmd is a constructor for the `device set` subcommand.
//
// The command takes exactly one argument, "online" or "offline", checks that
// the machine is onboarded and POSTs the matching is_available flag to
// /device/status. The "message" field of the response, if present, is
// printed.
//
// The endpoint answers a successful change with 200, so that status is
// accepted; 201 stays accepted for backward compatibility with the previous
// check.
func newDeviceSetCmd(client *utils.HTTPClient) *cobra.Command {
	validArgs := []string{"online", "offline"}

	return &cobra.Command{
		Use:       "set",
		Short:     "Set device online status",
		ValidArgs: validArgs,
		Args:      cobra.MatchAll(cobra.ExactArgs(1), cobra.OnlyValidArgs),
		Long:      ``,
		RunE: func(cmd *cobra.Command, args []string) error {
			onboarded, err := checkOnboarded(client)
			if err != nil {
				return fmt.Errorf("could not check onboard status: %w", err)
			}
			if !onboarded {
				return fmt.Errorf("machine is not onboarded")
			}

			// Args validation guarantees one of the two cases matches.
			var statusJSON []byte
			switch args[0] {
			case "online":
				statusJSON = []byte(`{"is_available": true}`)
			case "offline":
				statusJSON = []byte(`{"is_available": false}`)
			}

			resBody, resCode, err := client.MakeRequest("POST", "/device/status", statusJSON)
			if err != nil {
				return fmt.Errorf("could not make request: %w", err)
			}
			if resCode != 200 && resCode != 201 {
				return fmt.Errorf("request failed with status code: %d", resCode)
			}

			var response map[string]any
			if err := json.Unmarshal(resBody, &response); err != nil {
				return fmt.Errorf("could not unmarshal response body: %w", err)
			}

			if msg, ok := response["message"]; ok {
				fmt.Fprintln(cmd.OutOrStdout(), msg)
			}

			return nil
		},
	}
}
package cmd
import (
"fmt"
"github.com/buger/jsonparser"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/utils"
)
// newDeviceStatusCmd is a constructor for the `device status` subcommand.
//
// The command takes no arguments. It checks that the machine is onboarded,
// GETs /device/status and prints "Status: online" or "Status: offline"
// according to the boolean "online" field of the response.
func newDeviceStatusCmd(client *utils.HTTPClient) *cobra.Command {
	run := func(cmd *cobra.Command, _ []string) error {
		isOnboarded, err := checkOnboarded(client)
		switch {
		case err != nil:
			return fmt.Errorf("could not check onboard status: %w", err)
		case !isOnboarded:
			return fmt.Errorf("machine is not onboarded")
		}

		body, code, err := client.MakeRequest("GET", "/device/status", nil)
		if err != nil {
			return fmt.Errorf("unable to make request: %w", err)
		}
		if code != 200 {
			return fmt.Errorf("request failed with status code: %d", code)
		}

		isOnline, err := jsonparser.GetBoolean(body, "online")
		if err != nil {
			return fmt.Errorf("failed to get device status from json response: %w", err)
		}

		line := "Status: offline"
		if isOnline {
			line = "Status: online"
		}
		fmt.Fprintln(cmd.OutOrStdout(), line)

		return nil
	}

	return &cobra.Command{
		Use:   "status",
		Short: "Display current device status",
		Args:  cobra.NoArgs,
		Long:  ``,
		RunE:  run,
	}
}
package cmd
import (
"github.com/spf13/afero"
"github.com/spf13/cobra"
)
// newGPUCmd builds the `gpu` parent command. Run on its own it prints the
// help text; it groups the `capacity`, `status` and `onboard` subcommands,
// the last of which receives the given filesystem.
func newGPUCmd(afs afero.Afero) *cobra.Command {
	showHelp := func(cmd *cobra.Command, _ []string) {
		if helpErr := cmd.Help(); helpErr != nil {
			cmd.Println(helpErr)
		}
	}

	gpu := &cobra.Command{
		Use:   "gpu",
		Short: "GPU-related operations",
		Long:  ``,
		Run:   showHelp,
	}

	gpu.AddCommand(newGPUCapacityCmd())
	gpu.AddCommand(newGPUStatusCmd())
	gpu.AddCommand(newGPUOnboardCmd(afs))

	return gpu
}
package cmd
import (
"context"
"fmt"
"io"
"os"
"os/signal"
"syscall"
"github.com/docker/cli/opts"
docker_types "github.com/docker/docker/api/types"
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/client"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/dms/resources"
"gitlab.com/nunet/device-management-service/types"
)
// ContainerOptions set parameters for running a Docker container (NVIDIA/AMD/Intel)
type ContainerOptions struct {
// UseGPUs is set to true by the CUDA capacity check.
UseGPUs bool
// Devices lists device paths; the ROCm capacity check passes /dev/kfd and /dev/dri.
Devices []string
// Groups lists group names; the ROCm capacity check passes video and render.
Groups []string
// Image is the container image reference; it is pulled when not present locally.
Image string
// Command and Entrypoint are the container command line and entrypoint.
// NOTE(review): how they are applied is decided by runDockerContainer, which is not shown here.
Command []string
Entrypoint []string
}
// newGPUCapacityCmd builds the `gpu capacity` subcommand.
//
// The command requires at least one of the --cuda-tensor / --rocm-hip flags.
// For each selected flag it verifies that a GPU of the matching vendor was
// detected, makes sure the corresponding check image is available locally
// (pulling it otherwise) and runs it in a Docker container whose output is
// streamed to stdout. Every failure is printed and terminates the process
// with exit code 1.
func newGPUCapacityCmd() *cobra.Command {
// Flag names, shared between the flag definitions below and the lookups in Run.
fnCuda := "cuda-tensor"
fnRocm := "rocm-hip"
cmd := &cobra.Command{
Use: "capacity",
Short: "Check availability of NVIDIA/AMD/Intel GPUs",
Long: ``,
Run: func(cmd *cobra.Command, _ []string) {
// Lookup errors are ignored: both flags are registered on this command below.
cuda, _ := cmd.Flags().GetBool(fnCuda)
rocm, _ := cmd.Flags().GetBool(fnRocm)
// Nothing to check without a flag: show usage and fail.
if !cuda && !rocm {
fmt.Println(`Error: no flags specified`)
err := cmd.Help()
if err != nil {
cmd.Println(err)
}
os.Exit(1)
}
vendors, err := resources.ManagerInstance.SystemSpecs().GetGPUVendors()
if err != nil {
fmt.Println("Error detecting GPU vendors:", err)
os.Exit(1)
}
// Only AMD and NVIDIA have a capacity check; Intel is not handled here.
hasAMD := containsVendor(vendors, types.GPUVendorAMDATI)
hasNVIDIA := containsVendor(vendors, types.GPUVendorNvidia)
if !hasAMD && !hasNVIDIA {
fmt.Println("No AMD or NVIDIA GPU(s) detected...")
os.Exit(1)
}
ctx := context.Background()
// CUDA check: runs the PyTorch image with all GPUs requested.
if cuda {
if !hasNVIDIA {
fmt.Println("No NVIDIA GPU(s) detected...")
os.Exit(1)
}
cudaOpts := ContainerOptions{
UseGPUs: true,
Image: "registry.gitlab.com/nunet/ml-on-gpu/ml-on-gpu-service/develop/pytorch",
Command: []string{"python", "check-cuda-and-tensor-cores-availability.py"},
Entrypoint: []string{""},
}
cli, err := client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation())
if err != nil {
fmt.Println("Error creating Docker client:", err)
os.Exit(1)
}
images, err := cli.ImageList(ctx, docker_types.ImageListOptions{})
if err != nil {
fmt.Println("Error listing Docker images:", err)
os.Exit(1)
}
// Pull only when the image is not already present locally.
if !imageExists(images, cudaOpts.Image) {
err := pullImage(ctx, cli, cudaOpts.Image)
if err != nil {
fmt.Println("Error pulling CUDA image:", err)
os.Exit(1)
}
}
err = runDockerContainer(ctx, cli, cudaOpts)
if err != nil {
fmt.Println("Error running CUDA container:", err)
os.Exit(1)
}
}
// ROCm check: runs the AMD PyTorch image with the /dev/kfd and /dev/dri
// devices mapped in and the "video"/"render" groups added.
if rocm {
if !hasAMD {
fmt.Println("No AMD GPU(s) detected...")
os.Exit(1)
}
rocmOpts := ContainerOptions{
Devices: []string{"/dev/kfd", "/dev/dri"},
Groups: []string{"video", "render"},
Image: "registry.gitlab.com/nunet/ml-on-gpu/ml-on-gpu-service/develop/pytorch-amd",
Command: []string{"python", "check-rocm-and-hip-availability.py"},
Entrypoint: []string{""},
}
// NOTE(review): a second Docker client is created here even when the CUDA
// branch already created one; neither client is closed explicitly.
cli, err := client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation())
if err != nil {
fmt.Println("Error creating Docker client:", err)
os.Exit(1)
}
images, err := cli.ImageList(ctx, docker_types.ImageListOptions{})
if err != nil {
fmt.Println("Error listing images:", err)
os.Exit(1)
}
// Pull only when the image is not already present locally.
if !imageExists(images, rocmOpts.Image) {
err := pullImage(ctx, cli, rocmOpts.Image)
if err != nil {
fmt.Println("Error pulling ROCm-HIP image:", err)
os.Exit(1)
}
}
err = runDockerContainer(ctx, cli, rocmOpts)
if err != nil {
fmt.Println("Error running ROCm-HIP container:", err)
os.Exit(1)
}
}
},
}
cmd.Flags().BoolP(fnCuda, "c", false, "check CUDA Tensor")
cmd.Flags().BoolP(fnRocm, "r", false, "check ROCM-HIP")
return cmd
}
// runDockerContainer creates a container from options, streams its output to
// stdout until the stream ends, waits for the container to stop and finally
// removes it.
//
// options.Image must be non-empty. When options.UseGPUs is set, all GPUs are
// requested; every entry of options.Devices is mapped to the same path inside
// the container with "rwm" permissions and options.Groups are added to the
// container. An error is returned if any Docker API call fails or if the wait
// result carries an error.
func runDockerContainer(ctx context.Context, cli *client.Client, options ContainerOptions) error {
	if options.Image == "" {
		return fmt.Errorf("image name cannot be empty")
	}

	config := &container.Config{
		Image:      options.Image,
		Entrypoint: options.Entrypoint,
		Cmd:        options.Command,
		Tty:        true,
	}

	hostConfig := &container.HostConfig{}
	if options.UseGPUs {
		gpuOpts := opts.GpuOpts{}
		if err := gpuOpts.Set("all"); err != nil {
			return fmt.Errorf("failed setting GPU opts: %w", err)
		}
		hostConfig.DeviceRequests = gpuOpts.Value()
	}
	for _, device := range options.Devices {
		hostConfig.Devices = append(hostConfig.Devices, container.DeviceMapping{
			PathOnHost:        device,
			PathInContainer:   device,
			CgroupPermissions: "rwm",
		})
	}
	hostConfig.GroupAdd = options.Groups

	resp, err := cli.ContainerCreate(ctx, config, hostConfig, nil, nil, "")
	if err != nil {
		return fmt.Errorf("cannot create container: %w", err)
	}
	// Best-effort cleanup: a failed removal is only reported, not returned.
	defer func() {
		if err := cli.ContainerRemove(ctx, resp.ID, docker_types.ContainerRemoveOptions{}); err != nil {
			fmt.Printf("WARNING: could not remove container: %v\n", err)
		}
	}()

	// Attach before starting so that output produced right after start-up is
	// not missed and attaching cannot race with a container that exits quickly.
	out, err := cli.ContainerAttach(ctx, resp.ID, docker_types.ContainerAttachOptions{
		Stream: true,
		Stdout: true,
		Stderr: true,
	})
	if err != nil {
		return fmt.Errorf("failed attaching container: %w", err)
	}
	// Release the hijacked connection once streaming is done.
	defer out.Close()

	if err := cli.ContainerStart(ctx, resp.ID, docker_types.ContainerStartOptions{}); err != nil {
		return fmt.Errorf("cannot start container: %w", err)
	}

	// With Tty enabled the stream is copied to stdout as-is.
	if _, err := io.Copy(os.Stdout, out.Reader); err != nil {
		return fmt.Errorf("failed to copy container output: %w", err)
	}

	waitCh, errCh := cli.ContainerWait(ctx, resp.ID, container.WaitConditionNotRunning)
	select {
	case waitResult := <-waitCh:
		if waitResult.Error != nil {
			return fmt.Errorf("container exit error: %s", waitResult.Error.Message)
		}
	case err := <-errCh:
		return fmt.Errorf("error waiting for container: %w", err)
	}
	return nil
}
// imageExists reports whether imageName is among the repo tags of images.
//
// Repo tags reported by Docker always include a tag (for example
// "repo/name:latest"), so a reference given without a tag is compared as
// "<imageName>:latest". A reference that already contains a ':' in its last
// path component (a tag, or the ':' of a digest) is compared unchanged.
func imageExists(images []docker_types.ImageSummary, imageName string) bool {
	want := imageName

	// Look for ':' only after the last '/', so a registry port such as
	// "host:5000/name" is not mistaken for a tag.
	tagged := false
	for i := len(want) - 1; i >= 0 && want[i] != '/'; i-- {
		if want[i] == ':' {
			tagged = true
			break
		}
	}
	if !tagged {
		want += ":latest"
	}

	for _, image := range images {
		for _, tag := range image.RepoTags {
			if tag == want {
				return true
			}
		}
	}
	return false
}
// pullImage pulls imageName through cli and streams the pull progress to
// stdout. The pull is cancelled when ctx is done or when the process receives
// an interrupt or SIGTERM. It returns an error if the pull cannot be started
// or if copying the progress stream fails.
func pullImage(ctx context.Context, cli *client.Client, imageName string) error {
	ctxCancel, cancel := context.WithCancel(ctx)
	defer cancel()

	out, err := cli.ImagePull(ctxCancel, imageName, docker_types.ImagePullOptions{})
	if err != nil {
		return fmt.Errorf("unable to pull image %s: %w", imageName, err)
	}
	defer out.Close()

	// define interrupt to stop image pull
	interrupt := make(chan os.Signal, 1)
	signal.Notify(interrupt, os.Interrupt, syscall.SIGTERM)
	// Undo the registration on return so later signals are not swallowed.
	defer signal.Stop(interrupt)

	go func() {
		// The second case lets the goroutine exit once the pull is over
		// (the deferred cancel fires) instead of blocking forever.
		select {
		case <-interrupt:
			fmt.Println("signal: interrupt")
			cancel()
		case <-ctxCancel.Done():
		}
	}()

	fmt.Printf("Pulling image: %s\nThis may take some time...\n", imageName)
	if _, err := io.Copy(os.Stdout, out); err != nil {
		return fmt.Errorf("failed to copy image pull to stdout: %w", err)
	}
	return nil
}
package cmd
import (
"fmt"
"io"
"os/exec"
"strings"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/dms/resources"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils"
)
// Paths of the maintenance scripts used by the `gpu onboard` command.
// They are handed to runScript, which runs them with /bin/bash; being
// relative, they are presumably resolved against the working directory of
// the process — TODO confirm where the CLI is expected to be launched from.
const (
// containerPath is the script passed to promptContainer to install the container runtime.
containerPath = "maint-scripts/install_container_runtime"
// amdDriverPath is the script passed to promptDriverInstallation for AMD GPUs.
amdDriverPath = "maint-scripts/install_amd_drivers"
// nvidiaDriverPath is the script passed to promptDriverInstallation for NVIDIA GPUs.
nvidiaDriverPath = "maint-scripts/install_nvidia_drivers"
)
// newGPUOnboardCmd builds the `gpu onboard` subcommand, which interactively
// installs the container runtime and, where applicable, GPU drivers.
//
// The flow depends on the host:
//   - on WSL only the container runtime is offered, and only when an NVIDIA
//     GPU is present;
//   - on a mining OS driver installation is skipped and only the container
//     runtime is offered;
//   - otherwise the detected NVIDIA GPUs are listed and the container runtime
//     and NVIDIA driver are offered, then the detected AMD GPUs are listed
//     and the AMD driver is offered.
func newGPUOnboardCmd(afs afero.Afero) *cobra.Command {
	run := func(cmd *cobra.Command, _ []string) error {
		in, out := cmd.InOrStdin(), cmd.OutOrStdout()

		isWSL, err := utils.CheckWSL(afs)
		if err != nil {
			return fmt.Errorf("could not check WSL: %w", err)
		}
		isMiningOS, err := checkMiningOS(afs)
		if err != nil {
			return fmt.Errorf("couldn't check if Mining OS: %w", err)
		}

		// TODO: Define API endpoint for this
		specs := resources.ManagerInstance.SystemSpecs()
		detected, err := specs.GetGPUVendors()
		if err != nil {
			return fmt.Errorf("couldn't check presence of a GPU: %w", err)
		}
		amdFound := containsVendor(detected, types.GPUVendorAMDATI)
		nvidiaFound := containsVendor(detected, types.GPUVendorNvidia)
		intelFound := containsVendor(detected, types.GPUVendorIntel)
		if !(amdFound || nvidiaFound || intelFound) {
			return fmt.Errorf("no NVIDIA/AMD/Intel GPU(s) detected")
		}

		if isWSL {
			fmt.Fprintf(out, "You are running on Windows Subsystem for Linux (WSL). AMD GPUs are not supported.")
			if !nvidiaFound {
				return fmt.Errorf("no NVIDIA GPU(s) detected")
			}
			if err := promptContainer(in, out, containerPath); err != nil {
				return fmt.Errorf("couldn't install container runtime: %w", err)
			}
			return nil
		}

		if isMiningOS {
			fmt.Fprintf(out, "You are likely running a Mining OS. Skipping driver installation...")
			if err := promptContainer(in, out, containerPath); err != nil {
				return fmt.Errorf("couldn't install container runtime: %w", err)
			}
			return nil
		}

		// Regular Linux host: handle NVIDIA first, then AMD.
		if nvidiaFound {
			cards, err := specs.GetGPUs(types.GPUVendorNvidia)
			if err != nil {
				return fmt.Errorf("couldn't fetch Nvidia info: %w", err)
			}
			printGPUs(cards)
			if err := promptContainer(in, out, containerPath); err != nil {
				return fmt.Errorf("couldn't install container runtime: %w", err)
			}
			if err := promptDriverInstallation(in, out, types.GPUVendorNvidia, nvidiaDriverPath); err != nil {
				return fmt.Errorf("couldn't install Nvidia driver: %w", err)
			}
		}
		if amdFound {
			cards, err := specs.GetGPUs(types.GPUVendorAMDATI)
			if err != nil {
				return fmt.Errorf("couldn't fetch AMD driver info: %w", err)
			}
			printGPUs(cards)
			if err := promptDriverInstallation(in, out, types.GPUVendorAMDATI, amdDriverPath); err != nil {
				return fmt.Errorf("couldn't install AMD driver: %w", err)
			}
		}
		return nil
	}

	return &cobra.Command{
		Use:   "onboard",
		Short: "Install GPU drivers and Container Runtime",
		Long:  ``,
		RunE:  run,
	}
}
// containsVendor reports whether target is one of the GPU vendors detected
// on the system, as listed in vendors.
func containsVendor(vendors []types.GPUVendor, target types.GPUVendor) bool {
	for idx := 0; idx < len(vendors); idx++ {
		if vendors[idx] != target {
			continue
		}
		return true
	}
	return false
}
// runScript executes the bash script at scriptPath with /bin/bash and prints
// its combined stdout/stderr to stdout.
//
// When the script fails, whatever it printed is still shown (so the reason
// for the failure is visible) and the error is returned wrapped.
func runScript(scriptPath string) error {
	output, err := exec.Command("/bin/bash", scriptPath).CombinedOutput()
	if err != nil {
		// Do not lose the script's diagnostics on failure.
		if len(output) > 0 {
			fmt.Printf("%s\n", output)
		}
		return fmt.Errorf("script failed with error: %w", err)
	}
	fmt.Printf("%s\n", output)
	return nil
}
// promptContainer asks the user (reading from in, writing to out) whether
// the container runtime should be installed and, on confirmation, runs the
// installation script located at containerPath. Declining is not an error.
func promptContainer(in io.Reader, out io.Writer, containerPath string) error {
	confirmed, err := utils.PromptYesNo(in, out, "Do you want to proceed with Container Runtime installation? (y/N)")
	if err != nil {
		return fmt.Errorf("could not read answer from prompt: %w", err)
	}
	if !confirmed {
		return nil
	}
	if runErr := runScript(containerPath); runErr != nil {
		return fmt.Errorf("cannot run container runtime installation script: %w", runErr)
	}
	return nil
}
// promptDriverInstallation asks the user (reading from in, writing to out)
// whether the driver for vendor should be installed and, on confirmation,
// runs the installation script at scriptPath. vendor is only used to build
// the prompt text. Declining is not an error.
func promptDriverInstallation(in io.Reader, out io.Writer, vendor types.GPUVendor, scriptPath string) error {
	confirmed, err := utils.PromptYesNo(
		in,
		out,
		fmt.Sprintf("Do you want to proceed with %s driver installation? (y/N)", vendor),
	)
	if err != nil {
		return fmt.Errorf("could not read answer from prompt: %w", err)
	}
	if !confirmed {
		return nil
	}
	if runErr := runScript(scriptPath); runErr != nil {
		return fmt.Errorf("cannot run driver installation script: %w", runErr)
	}
	return nil
}
// printGPUs prints the detected GPUs to stdout: a header naming the vendor
// (taken from the first element) followed by one "- <model>" line per GPU.
// Nothing is printed for an empty slice.
func printGPUs(gpus []types.GPU) {
	if len(gpus) == 0 {
		return
	}
	vendor := string(gpus[0].Vendor)
	// The header ends with a newline so the first entry starts on its own line.
	fmt.Printf("Available %s GPU(s):\n", vendor)
	for _, gpu := range gpus {
		fmt.Printf("- %s\n", gpu.Model)
	}
}
// checkMiningOS reports whether the host looks like a mining OS. It reads
// /etc/os-release through afs and returns true as soon as the file content
// contains one of a fixed list of mining-distribution markers. An error is
// returned only when the file cannot be read.
func checkMiningOS(afs afero.Afero) (bool, error) {
	const osFile = "/etc/os-release"

	contents, err := afs.ReadFile(osFile)
	if err != nil {
		return false, fmt.Errorf("cannot read file %s: %w", osFile, err)
	}

	release := string(contents)
	// Plain, case-sensitive substring match against the whole file.
	for _, marker := range []string{"Hive", "Rave", "PiMP", "Minerstat", "SimpleMining", "NH", "Miner", "SM", "MMP"} {
		if strings.Contains(release, marker) {
			return true, nil
		}
	}
	return false, nil
}
package cmd
import (
"fmt"
"os"
"os/exec"
"os/signal"
"regexp"
"syscall"
"time"
"github.com/NVIDIA/go-nvml/pkg/nvml"
"github.com/dustin/go-humanize"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/dms/resources"
"gitlab.com/nunet/device-management-service/types"
)
// newGPUStatusCmd builds the `gpu status` subcommand, which prints a
// periodically refreshed overview (utilization, memory, temperature, power)
// of the detected NVIDIA, AMD and Intel GPUs until the process receives an
// interrupt or SIGTERM.
//
// Each vendor's tooling is only queried when a GPU of that vendor was
// detected; NVML in particular is only used after a successful nvml.Init.
// A failure to count the devices of one vendor is reported and that vendor
// is shown with zero devices.
func newGPUStatusCmd() *cobra.Command {
	return &cobra.Command{
		Use:   "status",
		Short: "Check GPU status in real time",
		Long:  ``,
		Run: func(_ *cobra.Command, _ []string) {
			vendors, err := resources.ManagerInstance.SystemSpecs().GetGPUVendors()
			if err != nil {
				fmt.Println("Error trying to detect GPU(s):", err)
				return
			}
			hasAMD := containsVendor(vendors, types.GPUVendorAMDATI)
			hasNVIDIA := containsVendor(vendors, types.GPUVendorNvidia)
			hasIntel := containsVendor(vendors, types.GPUVendorIntel)
			if !hasNVIDIA && !hasAMD && !hasIntel {
				fmt.Println("No AMD, NVIDIA or Intel GPU(s) detected...")
				return
			}

			var countNVML, countROCM, countIntel int

			if hasNVIDIA {
				// NVML initialization; NVML must not be queried unless this succeeds.
				if initRet := nvml.Init(); initRet != nvml.SUCCESS {
					fmt.Println("Failed to initialize NVML:", nvml.ErrorString(initRet))
				} else {
					defer func() {
						if shutdownRet := nvml.Shutdown(); shutdownRet != nvml.SUCCESS {
							fmt.Println("Failed to shutdown NVML:", nvml.ErrorString(shutdownRet))
						}
					}()
					count, countRet := nvml.DeviceGetCount()
					if countRet != nvml.SUCCESS {
						fmt.Println("Failed to count NVIDIA GPU devices:", nvml.ErrorString(countRet))
					} else {
						countNVML = count
					}
				}
			}
			if hasAMD {
				count, err := getCountAMD()
				if err != nil {
					fmt.Println("Failed to count AMD GPU devices:", err)
				} else {
					countROCM = count
				}
			}
			if hasIntel {
				count, err := getCountIntel()
				if err != nil {
					fmt.Println("Failed to count Intel GPU devices:", err)
				} else {
					countIntel = count
				}
			}

			// Initialize GPU slices
			nvidiaGPUs := make([]nvidiaGPU, countNVML)
			for i := 0; i < countNVML; i++ {
				nvidiaGPUs[i] = nvidiaGPU{index: i}
			}
			amdGPUs := make([]amdGPU, countROCM)
			for i := 0; i < countROCM; i++ {
				amdGPUs[i] = amdGPU{index: (i + 1)}
			}
			// NOTE(review): Intel devices are numbered from 1 here, mirroring AMD;
			// confirm this matches the device IDs expected by `xpu-smi -d`.
			intelGPUs := make([]intelGPU, countIntel)
			for i := 0; i < countIntel; i++ {
				intelGPUs[i] = intelGPU{index: (i + 1)}
			}

			// Channel receiving the interrupt signal that ends the real-time loop.
			interrupt := make(chan os.Signal, 1)
			signal.Notify(interrupt, os.Interrupt, syscall.SIGTERM)
			defer signal.Stop(interrupt)

			for {
				// Clear screen (not reliable, maybe implement something ncurses-like for future)
				fmt.Print("\033[H\033[2J")
				fmt.Println("========== NuNet GPU Status ==========")

				fmt.Println("========== GPU Utilization ==========")
				for _, n := range nvidiaGPUs {
					fmt.Printf("%d %s: %d%%\n", n.index, n.name(), n.utilizationRate())
				}
				for _, a := range amdGPUs {
					fmt.Printf("%d AMD %s: %d%%\n", a.index, a.name(), a.utilizationRate())
				}
				for _, i := range intelGPUs {
					fmt.Printf("%d %s: %d%%\n", i.index, i.name(), i.utilizationRate())
				}

				fmt.Println("========== Memory Capacity ==========")
				for _, n := range nvidiaGPUs {
					fmt.Printf("%d %s: %s\n", n.index, n.name(), humanize.IBytes(n.memory().total))
				}
				for _, a := range amdGPUs {
					fmt.Printf("%d AMD %s: %s\n", a.index, a.name(), humanize.IBytes(a.memory().total))
				}
				for _, i := range intelGPUs {
					fmt.Printf("%d %s: %s\n", i.index, i.name(), humanize.IBytes(i.memory().total))
				}

				fmt.Println("========== Memory Used ==========")
				for _, n := range nvidiaGPUs {
					fmt.Printf("%d %s: %s\n", n.index, n.name(), humanize.IBytes(n.memory().used))
				}
				for _, a := range amdGPUs {
					fmt.Printf("%d AMD %s: %s\n", a.index, a.name(), humanize.IBytes(a.memory().used))
				}
				for _, i := range intelGPUs {
					fmt.Printf("%d %s: %s\n", i.index, i.name(), humanize.IBytes(i.memory().used))
				}

				fmt.Println("========== Memory Free ==========")
				for _, n := range nvidiaGPUs {
					fmt.Printf("%d %s: %s\n", n.index, n.name(), humanize.IBytes(n.memory().free))
				}
				for _, a := range amdGPUs {
					fmt.Printf("%d AMD %s: %s\n", a.index, a.name(), humanize.IBytes(a.memory().free))
				}
				for _, i := range intelGPUs {
					fmt.Printf("%d %s: %s\n", i.index, i.name(), humanize.IBytes(i.memory().free))
				}

				// No temperature line is printed for Intel GPUs.
				fmt.Println("========== Temperature ==========")
				for _, n := range nvidiaGPUs {
					fmt.Printf("%d %s: %.0f°C\n", n.index, n.name(), n.temperature())
				}
				for _, a := range amdGPUs {
					fmt.Printf("%d AMD %s: %.0f°C\n", a.index, a.name(), a.temperature())
				}

				fmt.Println("========== Power Usage ==========")
				for _, n := range nvidiaGPUs {
					fmt.Printf("%d %s: %dW\n", n.index, n.name(), n.powerUsage())
				}
				for _, a := range amdGPUs {
					fmt.Printf("%d AMD %s: %dW\n", a.index, a.name(), a.powerUsage())
				}
				for _, i := range intelGPUs {
					fmt.Printf("%d %s: %dW\n", i.index, i.name(), i.powerUsage())
				}

				fmt.Println("")
				fmt.Println("Press CTRL+C to exit...")
				fmt.Println("Refreshing status in a few seconds...")

				// Wait for the next refresh, but react to an interrupt right away
				// instead of only after the full pause has elapsed.
				select {
				case <-interrupt:
					fmt.Println("signal: interrupt")
					return
				case <-time.After(2 * time.Second):
				}
			}
		},
	}
}
// runShellCmd runs command through `sh -c` and returns its combined
// stdout/stderr as a string. On failure it returns an empty string and an
// error that wraps the underlying one, so callers can inspect it with
// errors.Is / errors.As (for example to read the exit code).
func runShellCmd(command string) (string, error) {
	cmd := exec.Command("sh", "-c", command)
	output, err := cmd.CombinedOutput()
	if err != nil {
		return "", fmt.Errorf("unable to get combined output from command %s: %w", command, err)
	}
	return string(output), nil
}
// getCountAMD returns the number of AMD GPUs reported by `rocm-smi --showid`.
//
// The output is scanned for "GPU[<n>]" prefixes and the number of distinct
// indices is returned, so a GPU is counted once even if several output lines
// refer to it. An error is returned when the command cannot be run.
func getCountAMD() (int, error) {
	rocmOutput, err := runShellCmd("rocm-smi --showid")
	if err != nil {
		return 0, fmt.Errorf("cannot run shell command: %w", err)
	}

	re := regexp.MustCompile(`GPU\[(\d+)\]`)
	seen := make(map[string]struct{})
	for _, match := range re.FindAllStringSubmatch(rocmOutput, -1) {
		seen[match[1]] = struct{}{}
	}
	return len(seen), nil
}
// getCountIntel returns the number of Intel GPUs listed by
// `xpu-smi health -l`, counted as the number of "Device ID" rows in its
// output. When the command cannot be run the count is 0 and the returned
// error wraps the underlying one.
func getCountIntel() (int, error) {
	cmd := exec.Command("xpu-smi", "health", "-l")
	output, err := cmd.CombinedOutput()
	if err != nil {
		return 0, fmt.Errorf("xpu-smi not installed, initialized, or configured: %w", err)
	}

	// Use regex to find all instances of Device ID
	deviceIDRegex := regexp.MustCompile(`(?i)\| Device ID\s+\|\s+(\d+)\s+\|`)
	deviceIDMatches := deviceIDRegex.FindAllStringSubmatch(string(output), -1)
	return len(deviceIDMatches), nil
}
package cmd
import (
"encoding/json"
"fmt"
"io"
"github.com/buger/jsonparser"
"github.com/olekukonko/tablewriter"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils"
)
// newInfoCmd builds the `info` command. It requires the machine to be
// onboarded, fetches the onboarding information from the DMS API, decodes
// the "info" object of the response and renders it as a table.
func newInfoCmd(client *utils.HTTPClient) *cobra.Command {
	runInfo := func(cmd *cobra.Command, _ []string) error {
		isOnboarded, err := checkOnboarded(client)
		switch {
		case err != nil:
			return fmt.Errorf("could not check onboard status: %w", err)
		case !isOnboarded:
			return fmt.Errorf("machine is not onboarded")
		}

		body, status, err := client.MakeRequest("GET", "/onboarding/info", nil)
		if err != nil {
			return fmt.Errorf("could not make request: %w", err)
		}
		if status != 200 {
			return fmt.Errorf("request failed with status code: %d", status)
		}

		// Only the "info" object of the response is of interest.
		rawInfo, _, _, err := jsonparser.Get(body, "info")
		if err != nil {
			return fmt.Errorf("unable to parse JSON from key: %w", err)
		}

		onboardingCfg := new(types.OnboardingConfig)
		if err := json.Unmarshal(rawInfo, onboardingCfg); err != nil {
			return fmt.Errorf("could not unmarshal response body: %w", err)
		}

		displayInTable(cmd.OutOrStdout(), onboardingCfg)
		return nil
	}

	return &cobra.Command{
		Use:   "info",
		Short: "Display information about onboarded device",
		Long:  "Display onboarding config of onboarded device on Nunet Device Management Service",
		RunE:  runInfo,
	}
}
// displayInTable writes the onboarding configuration to w as a two-column
// table ("Info", "Value"), one row per displayed field.
func displayInTable(w io.Writer, config *types.OnboardingConfig) {
	// Rows in display order; non-string values are formatted up front.
	rows := [][]string{
		{"Name", config.Name},
		{"Update Timestamp", fmt.Sprintf("%d", config.UpdateTimestamp)},
		{"Memory Max", fmt.Sprintf("%d", config.TotalResources.RAM)},
		{"Total Core", fmt.Sprintf("%d", config.TotalResources.NumCores)},
		{"CPU Max", fmt.Sprintf("%.2f", config.TotalResources.CPU)},
		{"Reserved CPU", fmt.Sprintf("%.2f", config.OnboardedResources.CPU)},
		{"Reserved Memory", fmt.Sprintf("%d", config.OnboardedResources.RAM)},
		{"Network", config.Network},
		{"Public Key", config.PublicKey},
		{"Node ID", config.NodeID},
		{"Allow Cardano", fmt.Sprintf("%t", config.AllowCardano)},
		{"Dashboard", config.Dashboard},
		{"NTX Price Per Minute", fmt.Sprintf("%f", config.NTXPricePerMinute)},
	}

	tbl := tablewriter.NewWriter(w)
	tbl.SetHeader([]string{"Info", "Value"})
	for _, row := range rows {
		tbl.Append(row)
	}
	tbl.Render()
}
package cmd
import (
"fmt"
"math"
"regexp"
"strconv"
"strings"
)
// intelGPU identifies one Intel GPU whose properties are read by running
// xpu-smi commands.
type intelGPU struct {
// index is the device number passed to xpu-smi via its -d flag.
index int
}
// name returns the device name reported by `xpu-smi discovery` for this GPU,
// with surrounding whitespace trimmed. It returns an empty string when the
// command fails or its output has no "Device Name:" line.
func (i *intelGPU) name() string {
	discovery, err := runShellCmd(fmt.Sprintf("xpu-smi discovery -d %d", i.index))
	if err != nil {
		return ""
	}

	fields := regexp.MustCompile(`Device Name:\s+(.+)`).FindStringSubmatch(discovery)
	if len(fields) <= 1 {
		return ""
	}
	return strings.TrimSpace(fields[1])
}
// utilizationRate returns the GPU utilization percentage parsed from the
// "GPU Utilization (%)" row of `xpu-smi stats` for this GPU. It returns 0
// when the command fails, the row is missing or the value cannot be parsed.
func (i *intelGPU) utilizationRate() int64 {
	// Plain raw-string pattern; no formatting is needed to build it.
	re := regexp.MustCompile(`GPU Utilization \(%\)\s+\|\s+(\d+)`)

	xpuOutput, err := runShellCmd(fmt.Sprintf("xpu-smi stats -d %d", i.index))
	if err != nil {
		return 0
	}

	match := re.FindStringSubmatch(xpuOutput)
	if len(match) < 2 {
		return 0
	}
	utilization, err := strconv.ParseInt(match[1], 10, 32)
	if err != nil {
		return 0
	}
	return utilization
}
// memory returns the memory information of the Intel GPU, in bytes.
//
// The total comes from the "Memory Physical Size" line of
// `xpu-smi discovery`; the used amount comes from the "GPU Memory Used (MiB)"
// row, which is looked up in the `xpu-smi stats` output first and, failing
// that, in the discovery output. xpu-smi reports both values in MiB, so they
// are converted to bytes to match how callers format them. A zero memoryInfo
// is returned when the discovery command fails or either value is missing.
func (i *intelGPU) memory() memoryInfo {
	const bytesPerMiB = 1024 * 1024

	reTotal := regexp.MustCompile(`Memory Physical Size:\s+([^\s]+)\s+MiB`)
	reUsed := regexp.MustCompile(`GPU Memory Used \(MiB\)\s+\|\s+(\d+)`)

	discovery, err := runShellCmd(fmt.Sprintf("xpu-smi discovery -d %d", i.index))
	if err != nil {
		return memoryInfo{}
	}
	matchTotal := reTotal.FindStringSubmatch(discovery)

	// The used-memory row has the table layout of the stats output; keep the
	// discovery output as a fallback in case it is reported there.
	var matchUsed []string
	if stats, err := runShellCmd(fmt.Sprintf("xpu-smi stats -d %d", i.index)); err == nil {
		matchUsed = reUsed.FindStringSubmatch(stats)
	}
	if len(matchUsed) < 2 {
		matchUsed = reUsed.FindStringSubmatch(discovery)
	}

	if len(matchTotal) < 2 || len(matchUsed) < 2 {
		return memoryInfo{}
	}

	totalMiB, err := strconv.ParseFloat(matchTotal[1], 64)
	if err != nil || totalMiB < 0 {
		totalMiB = 0
	}
	usedMiB, err := strconv.ParseFloat(matchUsed[1], 64)
	if err != nil || usedMiB < 0 {
		usedMiB = 0
	}

	total := uint64(totalMiB * bytesPerMiB)
	used := uint64(usedMiB * bytesPerMiB)
	// Guard against an unsigned underflow if used is reported above total.
	var free uint64
	if total > used {
		free = total - used
	}

	return memoryInfo{
		used:  used,
		free:  free,
		total: total,
	}
}
// powerUsage returns the value of the "GPU Power (W)" row of `xpu-smi stats`
// for this GPU, rounded to the nearest integer. It returns 0 when the
// command fails, the row is missing or the value cannot be parsed.
func (i *intelGPU) powerUsage() uint32 {
	stats, err := runShellCmd(fmt.Sprintf("xpu-smi stats -d %d", i.index))
	if err != nil {
		return 0
	}

	fields := regexp.MustCompile(`GPU Power \(W\)\s+\|\s+([\d\.]+)`).FindStringSubmatch(stats)
	if len(fields) <= 1 {
		return 0
	}

	watts, parseErr := strconv.ParseFloat(fields[1], 64)
	if parseErr != nil {
		return 0
	}
	return uint32(math.Round(watts))
}
//go:build linux
package cmd
import (
"fmt"
"path/filepath"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/cmd/backend"
)
// Locations and names used by the `log` command.
const (
// logDir is the base directory: the temporary "dms-log" directory and the
// resulting tarball are both created inside it.
logDir = "/tmp/nunet-log"
// dmsUnit is the systemd unit whose log entries are collected
// (used in the "_SYSTEMD_UNIT=" match).
dmsUnit = "nunet-dms.service"
// tarGzName is the file name of the produced archive inside logDir.
tarGzName = "nunet-log.tar.gz"
)
// newLogCmd builds the `log` command, which collects the log entries of the
// DMS systemd unit into files under logDir/dms-log, packs that directory
// into a tar.gz archive inside logDir, removes the directory and prints the
// archive path.
//
// Reading is best-effort: an entry that cannot be fetched or written is
// reported on the error stream and skipped. If advancing the logger fails,
// reading stops and whatever was collected so far is archived. The command
// fails when logger is nil, when no entry was collected, or when the
// directory/archive operations fail.
func newLogCmd(afs afero.Afero, logger backend.Logger) *cobra.Command {
	return &cobra.Command{
		Use:   "log",
		Short: "Gather all logs into a tarball. COMMAND MUST RUN AS ROOT WITH SUDO",
		RunE: func(cmd *cobra.Command, _ []string) error {
			dmsLogDir := filepath.Join(logDir, "dms-log")
			fmt.Fprintln(cmd.OutOrStdout(), "Collecting logs...")

			err := afs.MkdirAll(dmsLogDir, 0o777)
			if err != nil {
				return fmt.Errorf("cannot create dms-log directory: %w", err)
			}

			if logger == nil {
				return fmt.Errorf("logger service is not initialised")
			}
			defer logger.Close()

			// filter by service unit name
			match := fmt.Sprintf("_SYSTEMD_UNIT=%s", dmsUnit)
			err = logger.AddMatch(match)
			if err != nil {
				return fmt.Errorf("cannot add unit match: %w", err)
			}

			var counter int
			for {
				count, err := logger.Next()
				if err != nil {
					// Retrying here could spin forever if the error persists,
					// so stop reading and archive what was collected so far.
					fmt.Fprintf(cmd.OutOrStderr(), "Error reading next logger entry: %v\n", err)
					break
				}
				if count == 0 {
					break
				}

				entry, err := logger.GetEntry()
				if err != nil {
					fmt.Fprintf(cmd.OutOrStderr(), "Error getting logger entry %d: %v\n", count, err)
					continue
				}

				// A missing MESSAGE field is reported; the entry is still
				// written, with an empty message.
				msg, ok := entry.Fields["MESSAGE"]
				if !ok {
					fmt.Fprintf(cmd.OutOrStderr(), "Error: no message field in entry %d\n", count)
				}

				logData := fmt.Sprintf("%d: %s\n", entry.RealtimeTimestamp, msg)
				logFilePath := filepath.Join(dmsLogDir, fmt.Sprintf("dms_log.%d", count))
				err = appendToFile(afs, logFilePath, logData)
				if err != nil {
					fmt.Fprintf(cmd.OutOrStderr(), "Error writing log file for boot %d: %v\n", count, err)
				}
				counter++
			}

			if counter == 0 {
				return fmt.Errorf("no log entries for %s", dmsUnit)
			}

			tarGzFile := filepath.Join(logDir, tarGzName)
			err = createTar(afs, tarGzFile, dmsLogDir)
			if err != nil {
				return fmt.Errorf("cannot create tar.gz file: %w", err)
			}

			err = afs.RemoveAll(dmsLogDir)
			if err != nil {
				return fmt.Errorf("remove dms-log directory failed: %w", err)
			}

			fmt.Fprintln(cmd.OutOrStdout(), tarGzFile)
			return nil
		},
	}
}
package cmd
import (
"fmt"
"github.com/NVIDIA/go-nvml/pkg/nvml"
)
const (
// sensorNVML is the temperature sensor passed to device.GetTemperature.
// As the first iota in this block its value is 0.
// NOTE(review): presumably intended to be NVML's GPU-die sensor — confirm
// against the nvml.TemperatureSensors constants.
sensorNVML nvml.TemperatureSensors = iota
)
// nvidiaGPU identifies one NVIDIA GPU whose properties are read through NVML.
type nvidiaGPU struct {
// index is the device index passed to nvml.DeviceGetHandleByIndex.
index int
}
// getDevice resolves the NVML device handle for this GPU's index. It returns
// an error describing the NVML return code when the lookup fails.
func (n *nvidiaGPU) getDevice() (nvml.Device, error) {
	handle, status := nvml.DeviceGetHandleByIndex(n.index)
	if status == nvml.SUCCESS {
		return handle, nil
	}
	return nil, fmt.Errorf("failed to get device (index %d) handle: %s", n.index, nvml.ErrorString(status))
}
// name returns the device name reported by NVML, or an empty string when the
// device handle or the name cannot be obtained.
func (n *nvidiaGPU) name() string {
	handle, err := n.getDevice()
	if err != nil {
		return ""
	}
	if label, status := handle.GetName(); status == nvml.SUCCESS {
		return label
	}
	return ""
}
// utilizationRate returns the Gpu field of the NVML utilization rates, or 0
// when the device handle or the rates cannot be obtained.
func (n *nvidiaGPU) utilizationRate() uint32 {
	handle, err := n.getDevice()
	if err != nil {
		return 0
	}
	if rates, status := handle.GetUtilizationRates(); status == nvml.SUCCESS {
		return rates.Gpu
	}
	return 0
}
// memory returns the used, free and total memory values reported by NVML
// for this GPU, or a zero memoryInfo when the device handle or the memory
// information cannot be obtained.
func (n *nvidiaGPU) memory() memoryInfo {
	handle, err := n.getDevice()
	if err != nil {
		return memoryInfo{}
	}

	stats, status := handle.GetMemoryInfo()
	if status != nvml.SUCCESS {
		return memoryInfo{}
	}

	return memoryInfo{
		used:  stats.Used,
		free:  stats.Free,
		total: stats.Total,
	}
}
// temperature returns the reading of the sensorNVML sensor as a float64, or
// 0 when the device handle or the reading cannot be obtained.
func (n *nvidiaGPU) temperature() float64 {
	handle, err := n.getDevice()
	if err != nil {
		return 0
	}
	if reading, status := handle.GetTemperature(sensorNVML); status == nvml.SUCCESS {
		return float64(reading)
	}
	return 0
}
// powerUsage returns the power usage value reported by NVML unchanged, or 0
// when the device handle or the value cannot be obtained.
// NOTE(review): the caller prints this value with a "W" suffix — confirm the
// unit NVML reports here.
func (n *nvidiaGPU) powerUsage() uint32 {
	handle, err := n.getDevice()
	if err != nil {
		return 0
	}
	if draw, status := handle.GetPowerUsage(); status == nvml.SUCCESS {
		return draw
	}
	return 0
}
package cmd
import (
"fmt"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/utils"
)
// newOffboardCmd builds the `offboard` command. It requires the machine to
// be onboarded, warns the user, asks for confirmation and, when confirmed,
// sends the offboard request to the DMS API (passing the --force flag as the
// "force" query parameter) and prints the response body. Declining the
// confirmation ends the command without error.
func newOffboardCmd(client *utils.HTTPClient) *cobra.Command {
	fnForce := "force"

	cmd := &cobra.Command{
		Use:   "offboard",
		Short: "Offboard the device from NuNet",
		Long:  ``,
		RunE: func(cmd *cobra.Command, _ []string) error {
			onboarded, err := checkOnboarded(client)
			if err != nil {
				return fmt.Errorf("could not check onboard status: %w", err)
			}
			if !onboarded {
				return fmt.Errorf("machine is not onboarded")
			}

			// Write the warning to the command's output like the prompt and
			// the result, so redirected output captures all of them.
			fmt.Fprintln(cmd.OutOrStdout(), "Warning: Offboarding will remove all your data and you will not be able to onboard again with the same identity")
			answer, err := utils.PromptYesNo(cmd.InOrStdin(), cmd.OutOrStdout(), "Are you sure you want to offboard?")
			if err != nil {
				return fmt.Errorf("unable to read response: %w", err)
			}
			if !answer {
				return nil
			}

			// The lookup error is ignored: the flag is registered below.
			force, _ := cmd.Flags().GetBool(fnForce)
			query := fmt.Sprintf("?force=%t", force)

			resBody, resCode, err := client.MakeRequest("DELETE", "/onboarding/offboard"+query, nil)
			if err != nil {
				return fmt.Errorf("could not make request: %w", err)
			}
			if resCode != 200 {
				return fmt.Errorf("request failed with status code: %d", resCode)
			}

			// TODO:what to do with the response body?
			fmt.Fprintf(cmd.OutOrStdout(), "%s\n", resBody)
			return nil
		},
	}

	cmd.Flags().BoolP(fnForce, "f", false, "force offboarding")
	return cmd
}
package cmd
import (
"fmt"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/dms/resources"
"gitlab.com/nunet/device-management-service/executor/docker"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils"
)
// imagesNVIDIA lists the images pulled by `onboard-ml` when an NVIDIA GPU is detected.
var imagesNVIDIA = []string{
"registry.gitlab.com/nunet/ml-on-gpu/ml-on-gpu-service/develop/tensorflow",
"registry.gitlab.com/nunet/ml-on-gpu/ml-on-gpu-service/develop/pytorch",
}
// imagesAMD lists the images pulled by `onboard-ml` when an AMD GPU is detected.
var imagesAMD = []string{
"registry.gitlab.com/nunet/ml-on-gpu/ml-on-gpu-service/develop/tensorflow-amd",
"registry.gitlab.com/nunet/ml-on-gpu/ml-on-gpu-service/develop/pytorch-amd",
}
// newOnboardMLCmd builds the `onboard-ml` command, which prepares the machine
// for GPU machine-learning jobs by pulling the images that match the detected
// GPU vendors (imagesNVIDIA and/or imagesAMD). It fails when no NVIDIA, AMD
// or Intel GPU is detected or when an image cannot be pulled; on WSL an
// additional notice is printed first.
func newOnboardMLCmd(afs afero.Afero, dockerClient *docker.Client) *cobra.Command {
	run := func(cmd *cobra.Command, _ []string) error {
		isWSL, err := utils.CheckWSL(afs)
		if err != nil {
			return fmt.Errorf("could not check WSL: %w", err)
		}

		detected, err := resources.ManagerInstance.SystemSpecs().GetGPUVendors()
		if err != nil {
			return fmt.Errorf("could not fetch GPU vendors: %w", err)
		}

		// check for GPU vendors
		amdFound := containsVendor(detected, types.GPUVendorAMDATI)
		nvidiaFound := containsVendor(detected, types.GPUVendorNvidia)
		intelFound := containsVendor(detected, types.GPUVendorIntel)
		if !(amdFound || nvidiaFound || intelFound) {
			return fmt.Errorf("no NVIDIA/AMD/Intel GPU(s) detected")
		}

		out := cmd.OutOrStdout()
		if isWSL {
			fmt.Fprintf(out, "You are running on Windows Subsystem for Linux (WSL)\nMake sure that NVIDIA drivers are set up correctly\n\nWARNING: AMD GPUs are not supported on WSL!\n")
		}

		if nvidiaFound {
			for idx := range imagesNVIDIA {
				ref := imagesNVIDIA[idx]
				digest, pullErr := dockerClient.PullImage(cmd.Context(), ref)
				if pullErr != nil {
					return fmt.Errorf("could not pull image %s: %w", ref, pullErr)
				}
				fmt.Fprintf(out, "Pulled Nvidia image %s with digest %s\n", ref, digest)
			}
		}

		if amdFound {
			for idx := range imagesAMD {
				ref := imagesAMD[idx]
				digest, pullErr := dockerClient.PullImage(cmd.Context(), ref)
				if pullErr != nil {
					return fmt.Errorf("could not pull image %s: %w", ref, pullErr)
				}
				fmt.Fprintf(out, "Pulled AMD image %s with digest %s\n", ref, digest)
			}
		}

		return nil
	}

	return &cobra.Command{
		Use:   "onboard-ml",
		Short: "Setup for Machine Learning with GPU",
		Long:  ``,
		RunE:  run,
	}
}
package cmd
import (
"encoding/json"
"fmt"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils"
)
// newOnboardCmd builds the `onboard` command. It checks the current onboard
// status (asking for confirmation through promptReonboard when the machine
// is already onboarded), collects the capacity settings from the flags and
// posts them to the DMS API. Any status code other than 201 is an error.
//
// The memory, cpu, nunet-channel and address flags must be given together.
// A flag value that cannot be read makes the command fail.
func newOnboardCmd(client *utils.HTTPClient) *cobra.Command {
	fnCPU := "cpu"
	fnMemory := "memory"
	fnChannel := "nunet-channel"
	fnAddr := "address"
	fnPlugin := "plugin"
	fnNTXPrice := "ntx-price"
	fnLocal := "local-enable"
	fnCardano := "cardano"
	fnUnavailable := "unavailable"

	cmd := &cobra.Command{
		Use:   "onboard",
		Short: "Onboard current machine to NuNet",
		RunE: func(cmd *cobra.Command, _ []string) error {
			fmt.Fprintln(cmd.OutOrStdout(), "Checking onboard status...")

			onboarded, err := checkOnboarded(client)
			if err != nil {
				return fmt.Errorf("could not check onboard status: %w", err)
			}
			if onboarded {
				err := promptReonboard(cmd.InOrStdin(), cmd.OutOrStdout())
				if err != nil {
					return err
				}
			}

			flags := cmd.Flags()
			memory, err := flags.GetUint64(fnMemory)
			if err != nil {
				return fmt.Errorf("couldn't get 'memory' flag: %w", err)
			}
			cpu, err := flags.GetInt64(fnCPU)
			if err != nil {
				return fmt.Errorf("couldn't get 'cpu' flag: %w", err)
			}
			channel, err := flags.GetString(fnChannel)
			if err != nil {
				return fmt.Errorf("couldn't get 'channel' flag: %w", err)
			}
			addr, err := flags.GetString(fnAddr)
			if err != nil {
				return fmt.Errorf("couldn't get 'addr' flag: %w", err)
			}
			ntx, err := flags.GetFloat64(fnNTXPrice)
			if err != nil {
				return fmt.Errorf("couldn't get 'ntx' flag: %w", err)
			}
			local, err := flags.GetBool(fnLocal)
			if err != nil {
				return fmt.Errorf("couldn't get 'local' flag: %w", err)
			}
			unavailable, err := flags.GetBool(fnUnavailable)
			if err != nil {
				return fmt.Errorf("couldn't get 'unavailable' flag: %w", err)
			}
			cardano, err := flags.GetBool(fnCardano)
			if err != nil {
				return fmt.Errorf("couldn't get 'cardano' flag: %w", err)
			}

			reserved := types.CapacityForNunet{
				Memory:            memory,
				CPU:               cpu,
				Channel:           channel,
				PaymentAddress:    addr,
				NTXPricePerMinute: ntx,
				Cardano:           cardano,
				ServerMode:        local,
				// NOTE(review): the "unavailable" flag is sent as IsAvailable
				// without negation — confirm the intended meaning.
				IsAvailable: unavailable, // TODO: Update this
			}

			onboardJSON, err := json.Marshal(reserved)
			if err != nil {
				return fmt.Errorf("unable to marshal JSON data: %w", err)
			}

			resBody, resCode, err := client.MakeRequest("POST", "/onboarding/onboard", onboardJSON)
			if err != nil {
				return fmt.Errorf("could not make request: %w", err)
			}
			if resCode != 201 {
				return fmt.Errorf("request failed with status code: %d", resCode)
			}

			// TODO: what to do with the response body?
			fmt.Fprintf(cmd.OutOrStdout(), "%s\n", resBody)
			fmt.Fprintln(cmd.OutOrStdout(), "Successfully onboarded!")
			return nil
		},
	}

	cmd.Flags().Uint64P(fnMemory, "m", 0, "set value for memory usage")
	cmd.Flags().Int64P(fnCPU, "c", 0, "set value for CPU usage")
	cmd.Flags().StringP(fnChannel, "n", "", "set channel")
	cmd.Flags().StringP(fnAddr, "a", "", "set wallet address")
	cmd.Flags().Float64P(fnNTXPrice, "x", 0, "price in NTX per minute for onboarded compute resource")
	// The plugin flag is registered but its value is not read by this command.
	cmd.Flags().StringP(fnPlugin, "p", "", "set plugin")
	cmd.Flags().BoolP(fnUnavailable, "u", false, "unavailable for job deployment (default: false)")
	cmd.Flags().BoolP(fnLocal, "l", true, "set server mode (enable for local)")
	cmd.Flags().BoolP(fnCardano, "C", false, "set Cardano wallet")
	cmd.MarkFlagsRequiredTogether(fnMemory, fnCPU, fnChannel, fnAddr)

	return cmd
}
package cmd
import (
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/utils"
)
// newPeerCmd builds the `peer` parent command and registers the `list` and
// `self` subcommands on it. Run without a subcommand, it prints its help text.
func newPeerCmd(client *utils.HTTPClient) *cobra.Command {
	showHelp := func(c *cobra.Command, _ []string) {
		// The result of printing help is deliberately discarded.
		_ = c.Help()
	}

	peer := &cobra.Command{
		Use:   "peer",
		Short: "Peer-related operations",
		Long:  ``,
		Run:   showHelp,
	}
	peer.AddCommand(
		newPeerListCmd(client),
		newPeerSelfCmd(client),
	)
	return peer
}
package cmd
import (
"fmt"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/utils"
)
// newPeerListCmd builds the `peer list` subcommand.
//
// The command first verifies that the machine is onboarded. Unless the
// --dht/-d flag is set it prints the bootstrap peers; in every case it then
// prints the DHT peers.
func newPeerListCmd(client *utils.HTTPClient) *cobra.Command {
	const flagDHT = "dht"

	listPeers := func(cmd *cobra.Command, _ []string) error {
		out := cmd.OutOrStdout()

		isOnboarded, err := checkOnboarded(client)
		switch {
		case err != nil:
			return fmt.Errorf("could not check onboard status: %w", err)
		case !isOnboarded:
			return fmt.Errorf("machine is not onboarded")
		}

		// The flag lookup error is discarded; the flag is registered below.
		if onlyDHT, _ := cmd.Flags().GetBool(flagDHT); !onlyDHT {
			bootstrap, err := getBootstrapPeers(out, client)
			if err != nil {
				return fmt.Errorf("could not fetch bootstrap peers: %w", err)
			}
			fmt.Fprintf(out, "Bootstrap peers (%d)\n", len(bootstrap))
			for i := range bootstrap {
				fmt.Fprintf(out, "%s\n", bootstrap[i])
			}
			fmt.Fprintf(out, "\n")
		}

		dhtPeers, err := getDHTPeers(client)
		if err != nil {
			return fmt.Errorf("could not fetch DHT peers: %w", err)
		}
		fmt.Fprintf(out, "DHT peers (%d)\n", len(dhtPeers))
		for i := range dhtPeers {
			fmt.Fprintf(out, "%s\n", dhtPeers[i])
		}
		return nil
	}

	list := &cobra.Command{
		Use:   "list",
		Short: "Display list of peers in the network",
		Long:  ``,
		RunE:  listPeers,
	}
	list.Flags().BoolP(flagDHT, "d", false, "list only DHT peers")
	return list
}
package cmd
import (
"encoding/json"
"fmt"
"github.com/buger/jsonparser"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/utils"
)
// newPeerSelfCmd builds the `peer self` subcommand.
//
// After confirming the machine is onboarded, the command requests
// GET /peers/self and prints the "id" field of the response together with the
// raw JSON value stored under "listen_addr".
func newPeerSelfCmd(client *utils.HTTPClient) *cobra.Command {
	showSelf := func(cmd *cobra.Command, _ []string) error {
		isOnboarded, err := checkOnboarded(client)
		switch {
		case err != nil:
			return fmt.Errorf("could not check onboard status: %w", err)
		case !isOnboarded:
			return fmt.Errorf("machine is not onboarded")
		}

		body, status, err := client.MakeRequest("GET", "/peers/self", nil)
		if err != nil {
			return fmt.Errorf("unable to make internal request: %w", err)
		}
		if status != 200 {
			return fmt.Errorf("request failed with status code: %d", status)
		}

		var payload map[string]any
		if err := json.Unmarshal(body, &payload); err != nil {
			return fmt.Errorf("could not unmarshal response body: %w", err)
		}
		hostID, found := payload["id"]
		if !found {
			return fmt.Errorf("no self peer ID returned")
		}

		// The addresses are printed exactly as they appear in the response.
		listenAddrs, _, _, err := jsonparser.Get(body, "listen_addr")
		if err != nil {
			return fmt.Errorf("failed to get addresses field: %w", err)
		}

		out := cmd.OutOrStdout()
		fmt.Fprintln(out, "Host ID:", hostID)
		fmt.Fprintln(out, "Listening Addresses:", string(listenAddrs))
		return nil
	}

	return &cobra.Command{
		Use:   "self",
		Short: "Display self peer info",
		Long:  ``,
		RunE:  showSelf,
	}
}
package cmd
import (
"encoding/json"
"fmt"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils"
)
// newResourceConfigCmd builds the `resource-config` command.
//
// The command requires an onboarded machine. It sends the memory, cpu and
// ntx-price flag values as a types.CapacityForNunet JSON document to
// POST /onboarding/resource-config and prints the response body on success.
// The cpu and memory flags must be supplied together.
func newResourceConfigCmd(client *utils.HTTPClient) *cobra.Command {
	const (
		flagMemory   = "memory"
		flagCPU      = "cpu"
		flagNTXPrice = "ntx-price"
	)

	update := func(cmd *cobra.Command, _ []string) error {
		isOnboarded, err := checkOnboarded(client)
		switch {
		case err != nil:
			return fmt.Errorf("could not check onboard status: %w", err)
		case !isOnboarded:
			return fmt.Errorf("machine is not onboarded")
		}

		// Flag lookup errors are discarded; all three flags are registered below.
		flags := cmd.Flags()
		memory, _ := flags.GetUint64(flagMemory)
		cpu, _ := flags.GetInt64(flagCPU)
		ntx, _ := flags.GetFloat64(flagNTXPrice)

		payload, err := json.Marshal(types.CapacityForNunet{
			Memory:            memory,
			CPU:               cpu,
			NTXPricePerMinute: ntx,
		})
		if err != nil {
			return fmt.Errorf("unable to marshal JSON data: %w", err)
		}

		body, status, err := client.MakeRequest("POST", "/onboarding/resource-config", payload)
		if err != nil {
			return fmt.Errorf("unable to make request: %w", err)
		}
		if status != 200 {
			return fmt.Errorf("request failed with status code: %d", status)
		}

		out := cmd.OutOrStdout()
		fmt.Fprintln(out, "Resources updated successfully!")
		fmt.Fprintln(out, string(body))
		return nil
	}

	rc := &cobra.Command{
		Use:   "resource-config",
		Short: "Update configuration of onboarded device",
		RunE:  update,
	}

	fs := rc.Flags()
	fs.Uint64P(flagMemory, "m", 0, "set amount of memory")
	fs.Int64P(flagCPU, "c", 0, "set amount of CPU")
	fs.Float64P(flagNTXPrice, "x", 0, "Set NTX Price per minute for compute resources you are updating")
	rc.MarkFlagsRequiredTogether(flagCPU, flagMemory)
	return rc
}
package cmd
import (
"fmt"
"github.com/coreos/go-systemd/sdjournal"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/api/docs"
"gitlab.com/nunet/device-management-service/cmd/backend"
"gitlab.com/nunet/device-management-service/executor/docker"
"gitlab.com/nunet/device-management-service/internal/config"
"gitlab.com/nunet/device-management-service/utils"
)
// newRootCmd assembles the top-level `nunet` command and registers every
// subcommand of the CLI on it. The dependencies received here are handed to
// the subcommand constructors that need them. The command sets SilenceErrors
// and SilenceUsage, and prints its help when run without a subcommand.
func newRootCmd(client *utils.HTTPClient, afs afero.Afero, dockerExec *docker.Client, logger backend.Logger) *cobra.Command {
	root := &cobra.Command{
		Use:     "nunet",
		Short:   "NuNet Device Management Service",
		Long:    `The Device Management Service (DMS) Command Line Interface (CLI)`,
		Version: docs.SwaggerInfo.Version,
		CompletionOptions: cobra.CompletionOptions{
			DisableDefaultCmd: false,
			HiddenDefaultCmd:  true,
		},
		SilenceErrors: true,
		SilenceUsage:  true,
		Run: func(c *cobra.Command, _ []string) {
			// The result of printing help is deliberately discarded.
			_ = c.Help()
		},
	}

	root.AddCommand(
		newGPUCmd(afs),
		newOffboardCmd(client),
		newOnboardMLCmd(afs, dockerExec),
		newRunCmd(),
		newPeerCmd(client),
		newOnboardCmd(client),
		newInfoCmd(client),
		newDeviceCmd(client),
		newCapacityCmd(client),
		newResourceConfigCmd(client),
		newLogCmd(afs, logger),
		newWalletCmd(client),
		newVersionCmd(),
		newAutoCompleteCmd(),
	)
	return root
}
// Execute wires up the CLI dependencies (OS filesystem, HTTP client pointed at
// the configured address and port, Docker client, systemd journal) and runs the
// root command. Any failure is passed to cobra.CheckErr.
func Execute() {
	osFS := afero.Afero{Fs: afero.NewOsFs()}

	baseURL := fmt.Sprintf("http://%s:%d", config.GetConfig().Addr, config.GetConfig().Port)
	httpClient := utils.NewHTTPClient(baseURL, "/api/v1")

	dockerExec, dockerErr := docker.NewDockerClient()
	if dockerErr != nil {
		cobra.CheckErr(fmt.Errorf("couldn't instantiate new docker client; Error: %w", dockerErr))
	}

	journalReader, journalErr := sdjournal.NewJournal()
	if journalErr != nil {
		cobra.CheckErr(fmt.Errorf("failed to get new sdjournal; Error: %w", journalErr))
	}

	root := newRootCmd(httpClient, osFS, dockerExec, journalReader)
	cobra.CheckErr(root.Execute())
}
package cmd
import (
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/dms"
)
// newRunCmd builds the `run` command, whose only action is to call dms.Run.
func newRunCmd() *cobra.Command {
	start := func(_ *cobra.Command, _ []string) {
		dms.Run()
	}

	run := cobra.Command{
		Use:   "run",
		Short: "Start the Device Management Service",
		Long:  `The Device Management Service (DMS) is a system application for computing and service providers. It handles networking and device management.`,
		Run:   start,
	}
	return &run
}
package cmd
import (
"archive/tar"
"compress/gzip"
"errors"
"fmt"
"io"
"os"
"strings"
"github.com/buger/jsonparser"
"github.com/olekukonko/tablewriter"
"github.com/spf13/afero"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils"
)
// checkOnboarded requests GET /onboarding/status and returns the boolean
// stored under the "onboarded" key of the response. A request failure, a
// non-200 status code or a failure to read that key is reported as an error.
func checkOnboarded(client *utils.HTTPClient) (bool, error) {
	body, status, reqErr := client.MakeRequest("GET", "/onboarding/status", nil)
	switch {
	case reqErr != nil:
		return false, fmt.Errorf("unable to make request: %w", reqErr)
	case status != 200:
		return false, fmt.Errorf("request failed with status code: %d", status)
	}

	isOnboarded, parseErr := jsonparser.GetBoolean(body, "onboarded")
	if parseErr != nil {
		return false, fmt.Errorf("could not get onboard status: %w", parseErr)
	}
	return isOnboarded, nil
}
// promptReonboard asks the user, through utils.PromptYesNo, whether an already
// onboarded machine should be onboarded again. It returns nil only when the
// user confirms; a prompt failure or a negative answer yields an error.
func promptReonboard(r io.Reader, w io.Writer) error {
	const question = "Looks like your machine is already onboarded. Proceed with reonboarding?"

	accepted, promptErr := utils.PromptYesNo(r, w, question)
	switch {
	case promptErr != nil:
		return fmt.Errorf("could not confirm reonboarding: %w", promptErr)
	case !accepted:
		return fmt.Errorf("reonboarding aborted by user")
	default:
		return nil
	}
}
// getDHTPeers requests GET /peers/dht and returns the elements of the JSON
// array in the response as strings.
//
// If the response carries a string "message" field, its content is returned as
// the error. An empty peer list is also reported as an error.
func getDHTPeers(client *utils.HTTPClient) ([]string, error) {
	body, status, reqErr := client.MakeRequest("GET", "/peers/dht", nil)
	switch {
	case reqErr != nil:
		return nil, fmt.Errorf("cannot make request: %w", reqErr)
	case status != 200:
		return nil, fmt.Errorf("request failed with status code: %d", status)
	}

	// A readable "message" field means the API answered with a message
	// instead of a peer list.
	if msg, msgErr := jsonparser.GetString(body, "message"); msgErr == nil {
		return nil, errors.New(msg)
	}

	var peers []string
	collect := func(value []byte, _ jsonparser.ValueType, _ int, _ error) {
		peers = append(peers, string(value))
	}
	if _, iterErr := jsonparser.ArrayEach(body, collect); iterErr != nil {
		return nil, fmt.Errorf("cannot iterate over DHT peer list: %w", iterErr)
	}

	if len(peers) == 0 {
		return nil, fmt.Errorf("no DHT peers available")
	}
	return peers, nil
}
// getBootstrapPeers requests GET /peers and returns the "ID" field of every
// element of the JSON array in the response.
//
// If the response carries a string "message" field, its content is returned as
// the error. An element without a readable "ID" string, or an empty peer list,
// is reported as an error as well.
//
// The writer parameter is kept for signature compatibility. Failures are now
// returned to the caller instead of being printed before terminating the
// process with os.Exit, so callers stay in control of error reporting.
func getBootstrapPeers(_ io.Writer, client *utils.HTTPClient) ([]string, error) {
	resBody, resCode, err := client.MakeRequest("GET", "/peers", nil)
	if err != nil {
		return nil, fmt.Errorf("unable to make request: %w", err)
	}
	if resCode != 200 {
		return nil, fmt.Errorf("request failed with status code: %d", resCode)
	}

	// A readable "message" field means the API answered with a message
	// instead of a peer list.
	if msg, msgErr := jsonparser.GetString(resBody, "message"); msgErr == nil {
		return nil, errors.New(msg)
	}

	var (
		bootSlice []string
		idErr     error // first failure seen while extracting peer IDs
	)
	_, err = jsonparser.ArrayEach(resBody, func(value []byte, _ jsonparser.ValueType, _ int, _ error) {
		// The callback cannot abort the iteration, so once an ID has
		// failed to parse the remaining elements are simply skipped.
		if idErr != nil {
			return
		}
		id, getErr := jsonparser.GetString(value, "ID")
		if getErr != nil {
			idErr = getErr
			return
		}
		bootSlice = append(bootSlice, id)
	})
	if err != nil {
		return nil, fmt.Errorf("cannot iterate over bootstrap peer list: %w", err)
	}
	if idErr != nil {
		return nil, fmt.Errorf("cannot get bootstrap peer ID string: %w", idErr)
	}

	if len(bootSlice) == 0 {
		return nil, fmt.Errorf("no bootstrap peers available")
	}
	return bootSlice, nil
}
// printWallet writes the non-empty fields of pair to w, one "key: value" line
// per field, in the order address, private_key, mnemonic.
func printWallet(w io.Writer, pair *types.BlockchainAddressPrivKey) {
	if address := pair.Address; address != "" {
		fmt.Fprintf(w, "address: %s\n", address)
	}
	if privateKey := pair.PrivateKey; privateKey != "" {
		fmt.Fprintf(w, "private_key: %s\n", privateKey)
	}
	if mnemonic := pair.Mnemonic; mnemonic != "" {
		fmt.Fprintf(w, "mnemonic: %s\n", mnemonic)
	}
}
// setupTable creates a tablewriter table writing to w with the columns
// Resources, Memory, CPU and Cores. Cell merging is enabled for the first
// column and automatic header formatting is turned off.
func setupTable(w io.Writer) *tablewriter.Table {
	tbl := tablewriter.NewWriter(w)
	tbl.SetHeader([]string{"Resources", "Memory", "CPU", "Cores"})
	tbl.SetAutoMergeCellsByColumnIndex([]int{0})
	tbl.SetAutoFormatHeaders(false)
	return tbl
}
// appendToFile appends data to filename on the given filesystem. The file is
// opened write-only in append mode and created with mode 0o644 if it does not
// exist.
func appendToFile(afs afero.Afero, filename, data string) error {
	const (
		openFlags = os.O_APPEND | os.O_CREATE | os.O_WRONLY
		fileMode  = 0o644
	)

	file, openErr := afs.OpenFile(filename, openFlags, fileMode)
	if openErr != nil {
		return fmt.Errorf("open %s file failed: %w", filename, openErr)
	}
	defer file.Close()

	if _, writeErr := file.WriteString(data); writeErr != nil {
		return fmt.Errorf("write string data to file %s failed: %w", filename, writeErr)
	}
	return nil
}
// createTar writes a gzip-compressed tar archive of sourceDir to tarGzPath on
// the given filesystem.
//
// Entry names are the walked paths with the sourceDir prefix removed; the
// source directory itself and the archive file (when it lives inside
// sourceDir) are not added. Only regular files have their content written.
//
// The tar writer, gzip writer and output file are closed in that order when
// the function returns. A close failure is reported to the caller when no
// earlier error occurred, because the final archive bytes are only flushed on
// Close and a silent failure there would leave a truncated archive.
func createTar(afs afero.Afero, tarGzPath string, sourceDir string) (err error) {
	tarGzFile, err := afs.Create(tarGzPath)
	if err != nil {
		return fmt.Errorf("create %s file failed: %w", tarGzPath, err)
	}

	// closeKeepingFirst closes c and records the failure in the named result
	// unless an earlier error is already being returned.
	closeKeepingFirst := func(c io.Closer, what string) {
		if closeErr := c.Close(); closeErr != nil && err == nil {
			err = fmt.Errorf("close %s failed: %w", what, closeErr)
		}
	}
	defer closeKeepingFirst(tarGzFile, tarGzPath+" file")

	gzWriter := gzip.NewWriter(tarGzFile)
	defer closeKeepingFirst(gzWriter, "gzip writer")

	tarWriter := tar.NewWriter(gzWriter)
	defer closeKeepingFirst(tarWriter, "tar writer")

	return afs.Walk(sourceDir, func(path string, info os.FileInfo, walkErr error) error {
		if walkErr != nil {
			return walkErr
		}
		// Never add the archive to itself.
		if path == tarGzPath {
			return nil
		}

		header, err := tar.FileInfoHeader(info, info.Name())
		if err != nil {
			return err
		}
		header.Name = strings.TrimPrefix(path, sourceDir)
		// Skip the root of the walk, which has no name left after trimming.
		if header.Name == "" || header.Name == "/" {
			return nil
		}
		if err := tarWriter.WriteHeader(header); err != nil {
			return err
		}

		if !info.Mode().IsRegular() {
			return nil
		}
		data, err := afs.ReadFile(path)
		if err != nil {
			return err
		}
		_, err = tarWriter.Write(data)
		return err
	})
}
package cmd
import (
"fmt"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/api/docs"
)
// newVersionCmd builds the `version` command, which prints
// docs.SwaggerInfo.Version to standard output.
func newVersionCmd() *cobra.Command {
	printVersion := func(_ *cobra.Command, _ []string) {
		fmt.Printf("Nunet Device Management Service Version: %s\n", docs.SwaggerInfo.Version)
	}

	version := cobra.Command{
		Use:   "version",
		Short: "Display the Nunet DMS version",
		Long:  `This command prints the version of the Nunet Device Management Service.`,
		Run:   printVersion,
	}
	return &version
}
package cmd
import (
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/utils"
)
// newWalletCmd builds the `wallet` parent command and registers the `new`
// subcommand on it. Run without a subcommand, it prints its help text.
func newWalletCmd(client *utils.HTTPClient) *cobra.Command {
	showHelp := func(c *cobra.Command, _ []string) {
		// The result of printing help is deliberately discarded.
		_ = c.Help()
	}

	wallet := &cobra.Command{
		Use:   "wallet",
		Short: "Wallet Management",
		Run:   showHelp,
	}
	wallet.AddCommand(newWalletNewCmd(client))
	return wallet
}
package cmd
import (
"encoding/json"
"fmt"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils"
)
// newWalletNewCmd builds the `wallet new` subcommand.
//
// Exactly one of --ethereum/-e and --cardano/-c must be given. The command
// requests GET /onboarding/address/new, adding the "?blockchain=ethereum"
// query when the ethereum flag is set, and prints the returned wallet with
// printWallet.
func newWalletNewCmd(client *utils.HTTPClient) *cobra.Command {
	fnEth := "ethereum"
	fnCardano := "cardano"

	cmd := &cobra.Command{
		Use:   "new",
		Short: "Create new wallet",
		RunE: func(cmd *cobra.Command, _ []string) error {
			// The flag lookup error is discarded; the flag is registered below.
			eth, _ := cmd.Flags().GetBool(fnEth)

			var query string
			if eth {
				query = "?blockchain=ethereum"
			}

			resBody, resCode, err := client.MakeRequest("GET", "/onboarding/address/new"+query, nil)
			if err != nil {
				return fmt.Errorf("could not make request: %w", err)
			}
			if resCode != 200 {
				return fmt.Errorf("request failed with status code: %d", resCode)
			}

			// Decode into a value rather than a nil pointer: a JSON `null`
			// body would otherwise leave the pointer nil and make
			// printWallet dereference it.
			var pair types.BlockchainAddressPrivKey
			if err := json.Unmarshal(resBody, &pair); err != nil {
				return fmt.Errorf("could not unmarshal response body: %w", err)
			}

			printWallet(cmd.OutOrStdout(), &pair)
			return nil
		},
	}

	cmd.Flags().BoolP(fnEth, "e", false, "create Ethereum wallet")
	cmd.Flags().BoolP(fnCardano, "c", false, "create Cardano wallet")
	cmd.MarkFlagsOneRequired(fnEth, fnCardano)
	cmd.MarkFlagsMutuallyExclusive(fnEth, fnCardano)
	return cmd
}
package clover
import (
"fmt"
clover "github.com/ostafen/clover/v2"
)
// NewDB opens the Clover database stored at path and makes sure every
// collection named in collections exists.
//
// Collections that are already present are left untouched, so the function can
// be called again on an existing database. If checking or creating a
// collection fails, the database handle is closed before the error is
// returned so it does not leak.
func NewDB(path string, collections []string) (*clover.DB, error) {
	db, err := clover.Open(path)
	if err != nil {
		return nil, fmt.Errorf("failed to connect to database: %w", err)
	}

	for _, collection := range collections {
		exists, err := db.HasCollection(collection)
		if err != nil {
			_ = db.Close() // the lookup error is the one worth reporting
			return nil, fmt.Errorf("failed to check collection %s: %w", collection, err)
		}
		if exists {
			continue
		}
		if err := db.CreateCollection(collection); err != nil {
			_ = db.Close() // the creation error is the one worth reporting
			return nil, fmt.Errorf("failed to create collection %s: %w", collection, err)
		}
	}
	return db, nil
}
package clover
import (
clover "github.com/ostafen/clover/v2"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// DeploymentRequestFlatClover is the Clover-backed repository for
// types.DeploymentRequestFlat. All operations come from the embedded generic
// repository.
type DeploymentRequestFlatClover struct {
	repositories.GenericRepository[types.DeploymentRequestFlat]
}

// NewDeploymentRequestFlat returns a DeploymentRequestFlatClover whose
// embedded generic repository is bound to db.
func NewDeploymentRequestFlat(db *clover.DB) repositories.DeploymentRequestFlat {
	generic := NewGenericRepository[types.DeploymentRequestFlat](db)
	return &DeploymentRequestFlatClover{GenericRepository: generic}
}
package clover
import (
"github.com/ostafen/clover/v2"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// RequestTrackerClover is the Clover-backed repository for
// types.RequestTracker. All operations come from the embedded generic
// repository.
type RequestTrackerClover struct {
	repositories.GenericRepository[types.RequestTracker]
}

// NewRequestTracker returns a RequestTrackerClover whose embedded generic
// repository is bound to db.
func NewRequestTracker(db *clover.DB) repositories.RequestTracker {
	generic := NewGenericRepository[types.RequestTracker](db)
	return &RequestTrackerClover{GenericRepository: generic}
}
package clover
import (
"github.com/ostafen/clover/v2"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// VirtualMachineClover is the Clover-backed repository for
// types.VirtualMachine. All operations come from the embedded generic
// repository.
type VirtualMachineClover struct {
	repositories.GenericRepository[types.VirtualMachine]
}

// NewVirtualMachine returns a VirtualMachineClover whose embedded generic
// repository is bound to db.
func NewVirtualMachine(db *clover.DB) repositories.VirtualMachine {
	generic := NewGenericRepository[types.VirtualMachine](db)
	return &VirtualMachineClover{GenericRepository: generic}
}
package clover
import (
"context"
"fmt"
"reflect"
"time"
"github.com/iancoleman/strcase"
clover "github.com/ostafen/clover/v2"
clover_d "github.com/ostafen/clover/v2/document"
clover_q "github.com/ostafen/clover/v2/query"
"gitlab.com/nunet/device-management-service/db/repositories"
)
// GenericEntityRepositoryClover is a generic single-entity repository backed by Clover.
// Every Save inserts a new document stamped with CreatedAt, Get returns the most
// recently created one, and History iterates over the stored documents.
// It is intended to be embedded in single-entity model repositories.
type GenericEntityRepositoryClover[T repositories.ModelType] struct {
db *clover.DB // db is the Clover database instance.
collection string // collection is the name of the collection in the database.
}
// NewGenericEntityRepository returns a GenericEntityRepositoryClover bound to
// db. The collection name is the snake_case form of T's type name.
func NewGenericEntityRepository[T repositories.ModelType](
	db *clover.DB,
) repositories.GenericEntityRepository[T] {
	var zero T
	typeName := reflect.TypeOf(zero).Name()

	repo := GenericEntityRepositoryClover[T]{
		db:         db,
		collection: strcase.ToSnake(typeName),
	}
	return &repo
}
// GetQuery returns an empty repositories.Query for T that callers can fill in
// before passing it to History.
func (repo *GenericEntityRepositoryClover[T]) GetQuery() repositories.Query[T] {
	var empty repositories.Query[T]
	return empty
}
// query returns a fresh Clover query over the repository's collection.
func (repo *GenericEntityRepositoryClover[T]) query() *clover_q.Query {
	name := repo.collection
	return clover_q.NewQuery(name)
}
// Save inserts data as a new document stamped with the current time in
// CreatedAt and returns the model converted back from that document.
// On an insert failure the input data is returned together with the error.
func (repo *GenericEntityRepositoryClover[T]) Save(_ context.Context, data T) (T, error) {
	document := toCloverDoc(data)
	document.Set("CreatedAt", time.Now())

	if _, insertErr := repo.db.InsertOne(repo.collection, document); insertErr != nil {
		return data, handleDBError(insertErr)
	}

	saved, parseErr := toModel[T](document, true)
	if parseErr != nil {
		return saved, handleDBError(fmt.Errorf("%v: %v", repositories.ErrParsingModel, parseErr))
	}
	return saved, nil
}
// Get returns the document with the latest CreatedAt value, converted to T.
// When the lookup fails or finds no document, the zero value of T is returned
// together with whatever handleDBError yields for the lookup error.
func (repo *GenericEntityRepositoryClover[T]) Get(_ context.Context) (T, error) {
	var zero T

	newestFirst := clover_q.SortOption{Field: "CreatedAt", Direction: -1}
	doc, findErr := repo.db.FindFirst(repo.query().Sort(newestFirst))
	if doc == nil || findErr != nil {
		return zero, handleDBError(findErr)
	}

	latest, convErr := toModel[T](doc, true)
	if convErr != nil {
		return latest, fmt.Errorf("failed to convert document to model: %v", convErr)
	}
	return latest, nil
}
// Clear deletes every document of the repository's collection, which removes
// the record together with its history.
func (repo *GenericEntityRepositoryClover[T]) Clear(_ context.Context) error {
	everything := repo.query()
	return repo.db.Delete(everything)
}
// History returns the stored versions of the record that match query.
//
// Iteration stops at the first document that cannot be unmarshalled into T.
// That failure is returned (wrapped in repositories.ErrParsingModel, like
// FindAll does) together with the models collected so far, instead of being
// dropped and presenting a truncated history as a success.
func (repo *GenericEntityRepositoryClover[T]) History(_ context.Context, query repositories.Query[T]) ([]T, error) {
	var (
		models   []T
		parseErr error
	)

	q := applyConditions(repo.query(), query)
	err := repo.db.ForEach(q, func(doc *clover_d.Document) bool {
		var model T
		if unmarshalErr := doc.Unmarshal(&model); unmarshalErr != nil {
			parseErr = handleDBError(fmt.Errorf("%v: %v", repositories.ErrParsingModel, unmarshalErr))
			return false // stop iterating; the error is reported below
		}
		models = append(models, model)
		return true
	})
	if err != nil {
		return models, handleDBError(err)
	}
	if parseErr != nil {
		return models, parseErr
	}
	return models, nil
}
package clover
import (
"context"
"fmt"
"reflect"
"time"
"github.com/iancoleman/strcase"
clover "github.com/ostafen/clover/v2"
clover_d "github.com/ostafen/clover/v2/document"
clover_q "github.com/ostafen/clover/v2/query"
"gitlab.com/nunet/device-management-service/db/repositories"
)
// Document field names used by GenericRepositoryClover when building queries.
const (
// pkField is the document field matched against the identifier in queryWithID.
pkField = "_id"
// deletedAtField is the field the repository filters on to hide soft-deleted records.
deletedAtField = "DeletedAt"
)
// GenericRepositoryClover is a generic repository implementation using Clover.
// It provides create, read, update, soft-delete and query operations for the
// model type T and is intended to be embedded in model repositories.
type GenericRepositoryClover[T repositories.ModelType] struct {
db *clover.DB // db is the Clover database instance.
collection string // collection is the name of the collection in the database.
}
// NewGenericRepository returns a GenericRepositoryClover bound to db. The
// collection name is the snake_case form of T's type name.
func NewGenericRepository[T repositories.ModelType](
	db *clover.DB,
) repositories.GenericRepository[T] {
	var zero T
	typeName := reflect.TypeOf(zero).Name()

	repo := GenericRepositoryClover[T]{
		db:         db,
		collection: strcase.ToSnake(typeName),
	}
	return &repo
}
// GetQuery returns an empty repositories.Query for T that callers can fill in
// before passing it to Find or FindAll.
func (repo *GenericRepositoryClover[T]) GetQuery() repositories.Query[T] {
	var empty repositories.Query[T]
	return empty
}
// query returns a fresh Clover query over the repository's collection. Unless
// includeDeleted is true, the query only matches documents whose DeletedAt
// field is not later than the Unix epoch, which hides soft-deleted records.
func (repo *GenericRepositoryClover[T]) query(includeDeleted bool) *clover_q.Query {
	base := clover_q.NewQuery(repo.collection)
	if includeDeleted {
		return base
	}

	notDeleted := clover_q.Field(deletedAtField).LtEq(time.Unix(0, 0))
	return base.Where(notDeleted)
}
// queryWithID returns the repository query (see query for includeDeleted)
// restricted to the document whose primary key field equals id.
//
// A string id is compared as a string, as before. Any other id type is
// compared as given instead of triggering a failed type assertion, so an
// unexpected identifier type yields a query rather than a panic.
func (repo *GenericRepositoryClover[T]) queryWithID(
	id interface{},
	includeDeleted bool,
) *clover_q.Query {
	base := repo.query(includeDeleted)
	if key, ok := id.(string); ok {
		return base.Where(clover_q.Field(pkField).Eq(key))
	}
	return base.Where(clover_q.Field(pkField).Eq(id))
}
// Create inserts data as a new document stamped with the current time in
// CreatedAt and returns the model converted back from that document.
// On any failure the input data is returned together with the error.
func (repo *GenericRepositoryClover[T]) Create(_ context.Context, data T) (T, error) {
	document := toCloverDoc(data)
	document.Set("CreatedAt", time.Now())

	if _, insertErr := repo.db.InsertOne(repo.collection, document); insertErr != nil {
		return data, handleDBError(insertErr)
	}

	created, parseErr := toModel[T](document, false)
	if parseErr != nil {
		return data, handleDBError(fmt.Errorf("%v: %v", repositories.ErrParsingModel, parseErr))
	}
	return created, nil
}
// Get returns the non-deleted record whose identifier equals id.
//
// A lookup failure is returned through handleDBError. When no document
// matches, clover.ErrDocumentNotExist is reported through handleDBError, the
// same way Find does, instead of handing back a zero-value model as if the
// record had been found.
func (repo *GenericRepositoryClover[T]) Get(_ context.Context, id interface{}) (T, error) {
	var model T

	doc, err := repo.db.FindFirst(repo.queryWithID(id, false))
	if err != nil {
		return model, handleDBError(err)
	}
	if doc == nil {
		return model, handleDBError(clover.ErrDocumentNotExist)
	}

	model, err = toModel[T](doc, false)
	if err != nil {
		return model, handleDBError(fmt.Errorf("%v: %v", repositories.ErrParsingModel, err))
	}
	return model, nil
}
// Update writes the fields of data, plus an UpdatedAt timestamp, onto the
// non-deleted record identified by id and returns the record as read back
// with Get. On an update failure the input data is returned with the error.
func (repo *GenericRepositoryClover[T]) Update(
	ctx context.Context,
	id interface{},
	data T,
) (T, error) {
	fields := toCloverDoc(data).AsMap()
	fields["UpdatedAt"] = time.Now()

	if updateErr := repo.db.Update(repo.queryWithID(id, false), fields); updateErr != nil {
		return data, handleDBError(updateErr)
	}

	refreshed, getErr := repo.Get(ctx, id)
	return refreshed, handleDBError(getErr)
}
// Delete soft-deletes the record identified by id by setting its DeletedAt
// field to the current time; the document itself stays in the collection.
// Errors are passed through handleDBError like in the other methods of this
// repository.
func (repo *GenericRepositoryClover[T]) Delete(_ context.Context, id interface{}) error {
	err := repo.db.Update(
		repo.queryWithID(id, false),
		map[string]interface{}{deletedAtField: time.Now()},
	)
	return handleDBError(err)
}
// Find returns the first non-deleted record matching query. When nothing
// matches, clover.ErrDocumentNotExist is reported through handleDBError.
func (repo *GenericRepositoryClover[T]) Find(
	_ context.Context,
	query repositories.Query[T],
) (T, error) {
	var zero T

	doc, findErr := repo.db.FindFirst(applyConditions(repo.query(false), query))
	switch {
	case findErr != nil:
		return zero, handleDBError(findErr)
	case doc == nil:
		return zero, handleDBError(clover.ErrDocumentNotExist)
	}

	found, convErr := toModel[T](doc, false)
	if convErr != nil {
		return found, fmt.Errorf("failed to convert document to model: %v", convErr)
	}
	return found, nil
}
// FindAll returns every non-deleted record matching query. Iteration stops at
// the first document that cannot be converted to T; that error is returned
// together with the records collected up to that point.
func (repo *GenericRepositoryClover[T]) FindAll(
	_ context.Context,
	query repositories.Query[T],
) ([]T, error) {
	var (
		results  []T
		parseErr error
	)

	collect := func(doc *clover_d.Document) bool {
		item, convErr := toModel[T](doc, false)
		if convErr != nil {
			parseErr = handleDBError(fmt.Errorf("%v: %v", repositories.ErrParsingModel, convErr))
			return false
		}
		results = append(results, item)
		return true
	}

	if iterErr := repo.db.ForEach(applyConditions(repo.query(false), query), collect); iterErr != nil {
		return results, handleDBError(iterErr)
	}
	return results, parseErr
}
// applyConditions extends the Clover query q with the filters, sorting, limit
// and offset described by the generic repositories.Query and returns the
// resulting query.
//
//   - Each entry of query.Conditions becomes a Where clause. Supported
//     operators are =, >, >=, <, <=, !=, IN (value must be []interface{}) and
//     LIKE (value must be a string); other operators and mismatched value
//     types are ignored.
//   - Every non-empty field of query.Instance becomes an equality filter.
//   - query.SortBy names the sort field; a leading '-' selects descending
//     order.
//   - query.Limit and query.Offset are applied when greater than zero.
//
// Field names are translated with fieldJSONTag for conditions, instance
// fields and the sort field in both directions.
func applyConditions[T repositories.ModelType](
	q *clover_q.Query,
	query repositories.Query[T],
) *clover_q.Query {
	// Apply conditions specified in the query.
	for _, condition := range query.Conditions {
		// Change the field name to the JSON tag name if specified in the struct.
		field := clover_q.Field(fieldJSONTag[T](condition.Field))
		switch condition.Operator {
		case "=":
			q = q.Where(field.Eq(condition.Value))
		case ">":
			q = q.Where(field.Gt(condition.Value))
		case ">=":
			q = q.Where(field.GtEq(condition.Value))
		case "<":
			q = q.Where(field.Lt(condition.Value))
		case "<=":
			q = q.Where(field.LtEq(condition.Value))
		case "!=":
			q = q.Where(field.Neq(condition.Value))
		case "IN":
			if values, ok := condition.Value.([]interface{}); ok {
				q = q.Where(field.In(values...))
			}
		case "LIKE":
			if value, ok := condition.Value.(string); ok {
				q = q.Where(field.Like(value))
			}
		}
	}

	// Apply conditions based on non-zero values in the query instance.
	if !repositories.IsEmptyValue(query.Instance) {
		exampleType := reflect.TypeOf(query.Instance)
		exampleValue := reflect.ValueOf(query.Instance)
		for i := 0; i < exampleType.NumField(); i++ {
			fieldName := fieldJSONTag[T](exampleType.Field(i).Name)
			fieldValue := exampleValue.Field(i).Interface()
			if !repositories.IsEmptyValue(fieldValue) {
				q = q.Where(clover_q.Field(fieldName).Eq(fieldValue))
			}
		}
	}

	// Apply sorting if specified in the query. The JSON-tag translation is
	// applied for ascending sorts too, not only for "-field".
	if query.SortBy != "" {
		sortField, dir := query.SortBy, 1
		if sortField[0] == '-' {
			dir = -1
			sortField = sortField[1:]
		}
		q = q.Sort(clover_q.SortOption{Field: fieldJSONTag[T](sortField), Direction: dir})
	}

	// Apply limit if specified in the query.
	if query.Limit > 0 {
		q = q.Limit(query.Limit)
	}

	// Apply offset if specified in the query. Skip is the offset operation;
	// calling Limit here would overwrite the limit set above and skip nothing.
	if query.Offset > 0 {
		q = q.Skip(query.Offset)
	}

	return q
}
package clover
import (
"github.com/ostafen/clover/v2"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// PeerInfoClover is the Clover-backed repository for types.PeerInfo. All
// operations come from the embedded generic repository.
type PeerInfoClover struct {
	repositories.GenericRepository[types.PeerInfo]
}

// NewPeerInfo returns a PeerInfoClover whose embedded generic repository is
// bound to db.
func NewPeerInfo(db *clover.DB) repositories.PeerInfo {
	generic := NewGenericRepository[types.PeerInfo](db)
	return &PeerInfoClover{GenericRepository: generic}
}
// MachineClover is the Clover-backed repository for types.Machine. All
// operations come from the embedded generic repository.
type MachineClover struct {
	repositories.GenericRepository[types.Machine]
}

// NewMachine returns a MachineClover whose embedded generic repository is
// bound to db.
func NewMachine(db *clover.DB) repositories.Machine {
	generic := NewGenericRepository[types.Machine](db)
	return &MachineClover{GenericRepository: generic}
}
// FreeResourcesClover is the Clover-backed single-entity repository for
// types.FreeResources. All operations come from the embedded generic entity
// repository.
type FreeResourcesClover struct {
	repositories.GenericEntityRepository[types.FreeResources]
}

// NewFreeResources returns a FreeResourcesClover whose embedded generic entity
// repository is bound to db.
func NewFreeResources(db *clover.DB) repositories.FreeResources {
	generic := NewGenericEntityRepository[types.FreeResources](db)
	return &FreeResourcesClover{GenericEntityRepository: generic}
}
// AvailableResourcesClover is the Clover-backed single-entity repository for
// types.AvailableResources. All operations come from the embedded generic
// entity repository.
type AvailableResourcesClover struct {
	repositories.GenericEntityRepository[types.AvailableResources]
}

// AvailableResourcesRepositoryClover embeds the same generic entity repository
// for types.AvailableResources as AvailableResourcesClover.
type AvailableResourcesRepositoryClover struct {
	repositories.GenericEntityRepository[types.AvailableResources]
}

// NewAvailableResources returns an AvailableResourcesClover whose embedded
// generic entity repository is bound to db.
func NewAvailableResources(db *clover.DB) repositories.AvailableResources {
	generic := NewGenericEntityRepository[types.AvailableResources](db)
	return &AvailableResourcesClover{GenericEntityRepository: generic}
}
// ServicesClover is the Clover-backed repository for types.Services. All
// operations come from the embedded generic repository.
type ServicesClover struct {
	repositories.GenericRepository[types.Services]
}

// NewServices returns a ServicesClover whose embedded generic repository is
// bound to db.
func NewServices(db *clover.DB) repositories.Services {
	generic := NewGenericRepository[types.Services](db)
	return &ServicesClover{GenericRepository: generic}
}
// ServiceResourceRequirementsClover is the Clover-backed repository for
// types.ServiceResourceRequirements. All operations come from the embedded
// generic repository.
type ServiceResourceRequirementsClover struct {
	repositories.GenericRepository[types.ServiceResourceRequirements]
}

// NewServiceResourceRequirements returns a ServiceResourceRequirementsClover
// whose embedded generic repository is bound to db.
func NewServiceResourceRequirements(
	db *clover.DB,
) repositories.ServiceResourceRequirements {
	generic := NewGenericRepository[types.ServiceResourceRequirements](db)
	return &ServiceResourceRequirementsClover{GenericRepository: generic}
}
// Libp2pInfoClover is the Clover-backed single-entity repository for
// types.Libp2pInfo. All operations come from the embedded generic entity
// repository.
type Libp2pInfoClover struct {
	repositories.GenericEntityRepository[types.Libp2pInfo]
}

// NewLibp2pInfo returns a Libp2pInfoClover whose embedded generic entity
// repository is bound to db.
func NewLibp2pInfo(db *clover.DB) repositories.Libp2pInfo {
	generic := NewGenericEntityRepository[types.Libp2pInfo](db)
	return &Libp2pInfoClover{GenericEntityRepository: generic}
}
// MachineUUIDClover is the Clover-backed single-entity repository for
// types.MachineUUID. All operations come from the embedded generic entity
// repository.
type MachineUUIDClover struct {
	repositories.GenericEntityRepository[types.MachineUUID]
}

// NewMachineUUID returns a MachineUUIDClover whose embedded generic entity
// repository is bound to db.
func NewMachineUUID(db *clover.DB) repositories.MachineUUID {
	generic := NewGenericEntityRepository[types.MachineUUID](db)
	return &MachineUUIDClover{GenericEntityRepository: generic}
}
// ConnectionClover is the Clover-backed repository for types.Connection
// records. It embeds the generic repository, so that repository's methods are
// promoted.
type ConnectionClover struct {
	repositories.GenericRepository[types.Connection]
}

// NewConnection builds a ConnectionClover on top of the given Clover database
// and returns it as a repositories.Connection.
func NewConnection(db *clover.DB) repositories.Connection {
	generic := NewGenericRepository[types.Connection](db)
	return &ConnectionClover{GenericRepository: generic}
}
// ElasticTokenClover is the Clover-backed repository for types.ElasticToken
// records. It embeds the generic repository, so that repository's methods are
// promoted.
type ElasticTokenClover struct {
	repositories.GenericRepository[types.ElasticToken]
}

// NewElasticToken builds an ElasticTokenClover on top of the given Clover
// database and returns it as a repositories.ElasticToken.
func NewElasticToken(db *clover.DB) repositories.ElasticToken {
	generic := NewGenericRepository[types.ElasticToken](db)
	return &ElasticTokenClover{GenericRepository: generic}
}
package clover
import (
"github.com/ostafen/clover/v2"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/db/repositories"
)
// MachineResourcesRepositoryClover is the Clover-backed single-entity
// repository for types.MachineResources. It embeds the generic entity
// repository, so that repository's methods are promoted.
type MachineResourcesRepositoryClover struct {
	repositories.GenericEntityRepository[types.MachineResources]
}

// NewMachineResourcesRepository builds a MachineResourcesRepositoryClover on
// top of the given Clover database and returns it as a
// repositories.MachineResources.
func NewMachineResourcesRepository(db *clover.DB) repositories.MachineResources {
	entityRepo := NewGenericEntityRepository[types.MachineResources](db)
	return &MachineResourcesRepositoryClover{GenericEntityRepository: entityRepo}
}
// FreeResourcesRepositoryClover is the Clover-backed single-entity repository
// for types.FreeResources. It embeds the generic entity repository, so that
// repository's methods are promoted.
type FreeResourcesRepositoryClover struct {
	repositories.GenericEntityRepository[types.FreeResources]
}

// NewFreeResourcesRepository builds a FreeResourcesRepositoryClover on top of
// the given Clover database and returns it as a repositories.FreeResources.
func NewFreeResourcesRepository(db *clover.DB) repositories.FreeResources {
	entityRepo := NewGenericEntityRepository[types.FreeResources](db)
	return &FreeResourcesRepositoryClover{GenericEntityRepository: entityRepo}
}
// OnboardedResourcesRepositoryClover is the Clover-backed single-entity
// repository for types.OnboardedResources. It embeds the generic entity
// repository, so that repository's methods are promoted.
type OnboardedResourcesRepositoryClover struct {
	repositories.GenericEntityRepository[types.OnboardedResources]
}

// NewOnboardedResourcesRepository builds an OnboardedResourcesRepositoryClover
// on top of the given Clover database and returns it as a
// repositories.OnboardedResources.
func NewOnboardedResourcesRepository(db *clover.DB) repositories.OnboardedResources {
	entityRepo := NewGenericEntityRepository[types.OnboardedResources](db)
	return &OnboardedResourcesRepositoryClover{GenericEntityRepository: entityRepo}
}
// RequiredResourcesRepositoryClover is the Clover-backed repository for
// types.RequiredResources records. It embeds the generic repository, so that
// repository's methods are promoted.
type RequiredResourcesRepositoryClover struct {
	repositories.GenericRepository[types.RequiredResources]
}

// NewRequiredResourcesRepository builds a RequiredResourcesRepositoryClover on
// top of the given Clover database and returns it as a
// repositories.RequiredResources.
func NewRequiredResourcesRepository(db *clover.DB) repositories.RequiredResources {
	generic := NewGenericRepository[types.RequiredResources](db)
	return &RequiredResourcesRepositoryClover{GenericRepository: generic}
}
package clover
import (
clover "github.com/ostafen/clover/v2"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// StorageVolumeClover is the Clover-backed repository for types.StorageVolume
// records. It embeds the generic repository, so that repository's methods are
// promoted.
type StorageVolumeClover struct {
	repositories.GenericRepository[types.StorageVolume]
}

// NewStorageVolume builds a StorageVolumeClover on top of the given Clover
// database and returns it as a repositories.StorageVolume.
func NewStorageVolume(db *clover.DB) repositories.StorageVolume {
	generic := NewGenericRepository[types.StorageVolume](db)
	return &StorageVolumeClover{GenericRepository: generic}
}
package clover
import (
"encoding/json"
"errors"
"reflect"
"strings"
"github.com/ostafen/clover/v2"
clover_d "github.com/ostafen/clover/v2/document"
"gitlab.com/nunet/device-management-service/db/repositories"
)
// handleDBError translates an error coming out of Clover into the
// repository-level errors exposed by the repositories package.
//
// Mapping:
//   - nil                        -> nil
//   - clover.ErrDocumentNotExist -> repositories.ErrNotFound
//   - clover.ErrDuplicateKey     -> repositories.ErrInvalidData
//   - repositories.ErrParsingModel is passed through unchanged
//   - anything else is joined with repositories.ErrDatabase
//
// Matching uses errors.Is so that wrapped sentinel errors are recognized too.
func handleDBError(err error) error {
	switch {
	case err == nil:
		return nil
	case errors.Is(err, clover.ErrDocumentNotExist):
		return repositories.ErrNotFound
	case errors.Is(err, clover.ErrDuplicateKey):
		return repositories.ErrInvalidData
	case errors.Is(err, repositories.ErrParsingModel):
		return err
	default:
		return errors.Join(repositories.ErrDatabase, err)
	}
}
// toCloverDoc converts a model value into a Clover document by round-tripping
// it through JSON: the value is marshaled, decoded into a generic map, and the
// map is used to populate the document.
//
// If either JSON step fails, an empty document is returned and the error is
// not reported to the caller.
func toCloverDoc[T repositories.ModelType](data T) *clover_d.Document {
	encoded, marshalErr := json.Marshal(data)
	if marshalErr != nil {
		return clover_d.NewDocument()
	}

	fields := map[string]interface{}{}
	if unmarshalErr := json.Unmarshal(encoded, &fields); unmarshalErr != nil {
		return clover_d.NewDocument()
	}

	return clover_d.NewDocumentOf(fields)
}
// toModel decodes a Clover document into a model of type T.
//
// For regular (non-entity) repositories the model's "ID" field is then set to
// the document's object ID. Entity repositories skip that step because their
// models might not have an ID field at all.
func toModel[T repositories.ModelType](doc *clover_d.Document, isEntityRepo bool) (T, error) {
	var decoded T
	if err := doc.Unmarshal(&decoded); err != nil {
		return decoded, err
	}

	if isEntityRepo {
		return decoded, nil
	}

	// UpdateField returns the (possibly updated) model together with the
	// error, which is exactly what this function hands back.
	return repositories.UpdateField(decoded, "ID", doc.ObjectId())
}
// fieldJSONTag returns the JSON name of the struct field called field on
// model type T.
//
// The name is taken from the first comma-separated part of the field's `json`
// tag. The Go field name is returned unchanged when:
//   - T (after dereferencing pointers) is not a struct,
//   - the struct has no field with that name,
//   - the field has no `json` tag, or
//   - the tag carries only options (e.g. `json:",omitempty"`), in which case
//     encoding/json itself falls back to the Go field name.
func fieldJSONTag[T repositories.ModelType](field string) string {
	typ := reflect.TypeOf(*new(T))
	// Look through pointer model types; FieldByName panics on non-structs.
	for typ != nil && typ.Kind() == reflect.Ptr {
		typ = typ.Elem()
	}
	if typ == nil || typ.Kind() != reflect.Struct {
		return field
	}

	structField, found := typ.FieldByName(field)
	if !found {
		return field
	}
	tag, tagged := structField.Tag.Lookup("json")
	if !tagged {
		return field
	}
	if name, _, _ := strings.Cut(tag, ","); name != "" {
		return name
	}
	return field
}
package repositories
import (
"context"
)
// QueryCondition is a struct representing a single query condition of the
// form "<Field> <Operator> <Value>".
// The helper constructors in this file (EQ, GT, GTE, LT, LTE, IN, LIKE)
// produce the operators "=", ">", ">=", "<", "<=", "IN" and "LIKE".
type QueryCondition struct {
Field string // Field specifies the database or struct field to which the condition applies.
Operator string // Operator defines the comparison operator (e.g., "=", ">", "<").
Value interface{} // Value is the expected value for the given field.
}
// ModelType is the type constraint used for repository model types.
// It is the empty interface, so it places no restriction on the model type.
type ModelType interface{}
// Query is a struct that wraps both the instance of type T and additional query parameters.
// It is used to construct queries with conditions, sorting, limiting, and offsetting.
// The zero value is an empty query with no conditions, sorting, limit or offset.
type Query[T any] struct {
Instance T // Instance is an optional object of type T used to build conditions from its fields.
Conditions []QueryCondition // Conditions represent the conditions applied to the query.
SortBy string // SortBy specifies the field by which the query results should be sorted.
Limit int // Limit specifies the maximum number of results to return.
Offset int // Offset specifies the number of results to skip before starting to return data.
}
// GenericRepository is an interface defining basic CRUD operations and standard querying methods.
// T is the model type stored in the repository.
type GenericRepository[T ModelType] interface {
// Create adds a new record to the repository.
// It returns a value of type T along with any error.
Create(ctx context.Context, data T) (T, error)
// Get retrieves a record by its identifier.
Get(ctx context.Context, id interface{}) (T, error)
// Update modifies a record by its identifier.
// It returns a value of type T along with any error.
Update(ctx context.Context, id interface{}, data T) (T, error)
// Delete removes a record by its identifier.
Delete(ctx context.Context, id interface{}) error
// Find retrieves a single record based on a query.
Find(ctx context.Context, query Query[T]) (T, error)
// FindAll retrieves multiple records based on a query.
FindAll(ctx context.Context, query Query[T]) ([]T, error)
// GetQuery returns an empty query instance for the repository's type.
GetQuery() Query[T]
}
// EQ builds an equality condition: the returned QueryCondition carries the
// given field and value with the "=" operator.
func EQ(field string, value interface{}) QueryCondition {
	var cond QueryCondition
	cond.Field = field
	cond.Operator = "="
	cond.Value = value
	return cond
}
// GT builds a greater-than condition: the returned QueryCondition carries the
// given field and value with the ">" operator.
func GT(field string, value interface{}) QueryCondition {
	var cond QueryCondition
	cond.Field = field
	cond.Operator = ">"
	cond.Value = value
	return cond
}
// GTE builds a greater-than-or-equal condition: the returned QueryCondition
// carries the given field and value with the ">=" operator.
func GTE(field string, value interface{}) QueryCondition {
	var cond QueryCondition
	cond.Field = field
	cond.Operator = ">="
	cond.Value = value
	return cond
}
// LT builds a less-than condition: the returned QueryCondition carries the
// given field and value with the "<" operator.
func LT(field string, value interface{}) QueryCondition {
	var cond QueryCondition
	cond.Field = field
	cond.Operator = "<"
	cond.Value = value
	return cond
}
// LTE builds a less-than-or-equal condition: the returned QueryCondition
// carries the given field and value with the "<=" operator.
func LTE(field string, value interface{}) QueryCondition {
	var cond QueryCondition
	cond.Field = field
	cond.Operator = "<="
	cond.Value = value
	return cond
}
// IN builds a membership condition: the returned QueryCondition carries the
// given field with the "IN" operator and the slice of candidate values.
func IN(field string, values []interface{}) QueryCondition {
	var cond QueryCondition
	cond.Field = field
	cond.Operator = "IN"
	cond.Value = values
	return cond
}
// LIKE builds a pattern-matching condition: the returned QueryCondition
// carries the given field with the "LIKE" operator and the pattern as value.
func LIKE(field, pattern string) QueryCondition {
	var cond QueryCondition
	cond.Field = field
	cond.Operator = "LIKE"
	cond.Value = pattern
	return cond
}
package gorm
import (
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// DeploymentRequestFlatGORM is the GORM-backed repository for
// types.DeploymentRequestFlat records. It embeds the generic repository, so
// that repository's methods are promoted.
type DeploymentRequestFlatGORM struct {
	repositories.GenericRepository[types.DeploymentRequestFlat]
}

// NewDeploymentRequestFlat builds a DeploymentRequestFlatGORM on top of the
// given GORM database and returns it as a repositories.DeploymentRequestFlat.
func NewDeploymentRequestFlat(db *gorm.DB) repositories.DeploymentRequestFlat {
	generic := NewGenericRepository[types.DeploymentRequestFlat](db)
	return &DeploymentRequestFlatGORM{GenericRepository: generic}
}
package gorm
import (
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// RequestTrackerGORM is the GORM-backed repository for types.RequestTracker
// records. It embeds the generic repository, so that repository's methods are
// promoted.
type RequestTrackerGORM struct {
	repositories.GenericRepository[types.RequestTracker]
}

// NewRequestTracker builds a RequestTrackerGORM on top of the given GORM
// database and returns it as a repositories.RequestTracker.
func NewRequestTracker(db *gorm.DB) repositories.RequestTracker {
	generic := NewGenericRepository[types.RequestTracker](db)
	return &RequestTrackerGORM{GenericRepository: generic}
}
package gorm
import (
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// VirtualMachineGORM is the GORM-backed repository for types.VirtualMachine
// records. It embeds the generic repository, so that repository's methods are
// promoted.
type VirtualMachineGORM struct {
	repositories.GenericRepository[types.VirtualMachine]
}

// NewVirtualMachine builds a VirtualMachineGORM on top of the given GORM
// database and returns it as a repositories.VirtualMachine.
func NewVirtualMachine(db *gorm.DB) repositories.VirtualMachine {
	generic := NewGenericRepository[types.VirtualMachine](db)
	return &VirtualMachineGORM{GenericRepository: generic}
}
package gorm
import (
"context"
"fmt"
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
)
const (
// createdAtField is the model field name that GenericEntityRepositoryGORM.Get
// sorts by (in descending order) to pick the record to return.
createdAtField = "CreatedAt"
)
// GenericEntityRepositoryGORM is the GORM implementation of a generic
// single-entity repository. It is meant to be embedded in the repositories of
// single-entity models to give them their basic database operations.
type GenericEntityRepositoryGORM[T repositories.ModelType] struct {
	// db is the GORM database handle every operation runs against.
	db *gorm.DB
}

// NewGenericEntityRepository returns a GenericEntityRepositoryGORM bound to
// the provided GORM database, exposed through the
// repositories.GenericEntityRepository interface.
func NewGenericEntityRepository[T repositories.ModelType](
	db *gorm.DB,
) repositories.GenericEntityRepository[T] {
	repo := new(GenericEntityRepositoryGORM[T])
	repo.db = db
	return repo
}
// GetQuery returns a zero-valued Query for T, ready to be filled in by the
// caller.
func (repo *GenericEntityRepositoryGORM[T]) GetQuery() repositories.Query[T] {
	var empty repositories.Query[T]
	return empty
}
// Save stores data by issuing a GORM Create and returns the value as it is
// after that call, together with the translated database error (if any).
func (repo *GenericEntityRepositoryGORM[T]) Save(ctx context.Context, data T) (T, error) {
	tx := repo.db.WithContext(ctx).Create(&data)
	return data, handleDBError(tx.Error)
}
// Get returns the first record found when sorting by the CreatedAt field in
// descending order, i.e. the most recently created one. Database errors are
// translated through handleDBError.
func (repo *GenericEntityRepositoryGORM[T]) Get(ctx context.Context) (T, error) {
	// A leading "-" asks applyConditions for descending order.
	newestFirst := repo.GetQuery()
	newestFirst.SortBy = fmt.Sprintf("-%s", createdAtField)

	var latest T
	tx := applyConditions(repo.db.WithContext(ctx), newestFirst).First(&latest)
	return latest, handleDBError(tx.Error)
}
// Clear removes the record together with its history from the repository by
// deleting every row of T whose id is not NULL.
//
// The error is translated through handleDBError, like in the other methods of
// this repository, so callers see repository-level errors rather than raw
// GORM ones. The original error stays reachable through errors.Is/errors.As.
func (repo *GenericEntityRepositoryGORM[T]) Clear(ctx context.Context) error {
	err := repo.db.WithContext(ctx).Delete(new(T), "id IS NOT NULL").Error
	return handleDBError(err)
}
// History returns all stored records of T that match the given query
// (conditions, sorting, limit and offset are applied by applyConditions).
// Database errors are translated through handleDBError.
func (repo *GenericEntityRepositoryGORM[T]) History(
	ctx context.Context,
	query repositories.Query[T],
) ([]T, error) {
	scoped := applyConditions(repo.db.WithContext(ctx).Model(new(T)), query)

	var records []T
	findErr := scoped.Find(&records).Error
	return records, handleDBError(findErr)
}
package gorm
import (
"context"
"fmt"
"reflect"
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
)
// GenericRepositoryGORM is the GORM implementation of a generic repository.
// It is meant to be embedded in model repositories to give them their basic
// database operations.
type GenericRepositoryGORM[T repositories.ModelType] struct {
	// db is the GORM database handle every operation runs against.
	db *gorm.DB
}

// NewGenericRepository returns a GenericRepositoryGORM bound to the provided
// GORM database, exposed through the repositories.GenericRepository interface.
func NewGenericRepository[T repositories.ModelType](db *gorm.DB) repositories.GenericRepository[T] {
	repo := new(GenericRepositoryGORM[T])
	repo.db = db
	return repo
}
// GetQuery returns a zero-valued Query for T, ready to be filled in by the
// caller.
func (repo *GenericRepositoryGORM[T]) GetQuery() repositories.Query[T] {
	var empty repositories.Query[T]
	return empty
}
// Create inserts data through GORM and returns the value as it is after the
// insert, together with the translated database error (if any).
func (repo *GenericRepositoryGORM[T]) Create(ctx context.Context, data T) (T, error) {
	tx := repo.db.WithContext(ctx).Create(&data)
	return data, handleDBError(tx.Error)
}
// Get looks up the first record whose id column equals id. The database error
// is translated through handleDBError; on failure the returned T is whatever
// GORM left in the zero-initialized result.
func (repo *GenericRepositoryGORM[T]) Get(ctx context.Context, id interface{}) (T, error) {
	var record T
	lookup := repo.db.WithContext(ctx).First(&record, "id = ?", id)
	return record, handleDBError(lookup.Error)
}
// Update applies data to the record whose id column equals id using GORM's
// Updates. The data argument is returned as passed in (it is not re-read from
// the database), together with the translated database error.
func (repo *GenericRepositoryGORM[T]) Update(ctx context.Context, id interface{}, data T) (T, error) {
	target := repo.db.WithContext(ctx).Model(new(T)).Where("id = ?", id)
	outcome := target.Updates(data)
	return data, handleDBError(outcome.Error)
}
// Delete removes the record whose id column equals id.
//
// The error is translated through handleDBError, like in the other methods of
// this repository, so callers see repository-level errors rather than raw
// GORM ones. The original error stays reachable through errors.Is/errors.As.
func (repo *GenericRepositoryGORM[T]) Delete(ctx context.Context, id interface{}) error {
	err := repo.db.WithContext(ctx).Delete(new(T), "id = ?", id).Error
	return handleDBError(err)
}
// Find returns the first record of T matching the query (conditions, sorting,
// limit and offset are applied by applyConditions). Database errors are
// translated through handleDBError.
func (repo *GenericRepositoryGORM[T]) Find(
	ctx context.Context,
	query repositories.Query[T],
) (T, error) {
	scoped := applyConditions(repo.db.WithContext(ctx).Model(new(T)), query)

	var record T
	firstErr := scoped.First(&record).Error
	return record, handleDBError(firstErr)
}
// FindAll returns every record of T matching the query (conditions, sorting,
// limit and offset are applied by applyConditions). Database errors are
// translated through handleDBError.
func (repo *GenericRepositoryGORM[T]) FindAll(
	ctx context.Context,
	query repositories.Query[T],
) ([]T, error) {
	scoped := applyConditions(repo.db.WithContext(ctx).Model(new(T)), query)

	var records []T
	findErr := scoped.Find(&records).Error
	return records, handleDBError(findErr)
}
// applyConditions applies conditions, sorting, limiting, and offsetting to a
// GORM database query and returns the modified GORM handle.
//
// Steps, in order:
//  1. every entry of query.Conditions becomes a "<column> <operator> ?" WHERE
//     clause;
//  2. every non-empty field of query.Instance becomes a "<column> = ?" WHERE
//     clause (fields that cannot be read through reflection, such as
//     unexported ones, are skipped);
//  3. query.SortBy adds an ORDER BY; a leading "-" selects descending order;
//  4. query.Limit and query.Offset are applied when greater than zero.
//
// Column and table names are derived with the GORM naming strategy from the
// Go field and type names.
func applyConditions[T any](db *gorm.DB, query repositories.Query[T]) *gorm.DB {
	// Retrieve the table name using the GORM naming strategy.
	tableName := db.NamingStrategy.TableName(reflect.TypeOf(*new(T)).Name())

	// Apply conditions specified in the query.
	for _, condition := range query.Conditions {
		columnName := db.NamingStrategy.ColumnName(tableName, condition.Field)
		db = db.Where(
			fmt.Sprintf("%s %s ?", columnName, condition.Operator),
			condition.Value,
		)
	}

	// Apply conditions based on non-zero values in the query instance.
	if !repositories.IsEmptyValue(query.Instance) {
		exampleType := reflect.TypeOf(query.Instance)
		exampleValue := reflect.ValueOf(query.Instance)
		for i := 0; i < exampleType.NumField(); i++ {
			field := exampleValue.Field(i)
			// Interface() panics on unexported fields; they cannot be used
			// as filters, so skip them.
			if !field.CanInterface() {
				continue
			}
			fieldValue := field.Interface()
			if repositories.IsEmptyValue(fieldValue) {
				continue
			}
			columnName := db.NamingStrategy.ColumnName(tableName, exampleType.Field(i).Name)
			db = db.Where(fmt.Sprintf("%s = ?", columnName), fieldValue)
		}
	}

	// Apply sorting if specified in the query.
	if query.SortBy != "" {
		dir := "ASC"
		if query.SortBy[0] == '-' {
			query.SortBy = query.SortBy[1:]
			dir = "DESC"
		}
		columnName := db.NamingStrategy.ColumnName(tableName, query.SortBy)
		db = db.Order(fmt.Sprintf("%s.%s %s", tableName, columnName, dir))
	}

	// Apply limit if specified in the query.
	if query.Limit > 0 {
		db = db.Limit(query.Limit)
	}

	// Apply offset if specified in the query. This must be Offset, not Limit:
	// calling Limit here would overwrite the limit set above and never skip
	// any rows.
	if query.Offset > 0 {
		db = db.Offset(query.Offset)
	}

	return db
}
package gorm
import (
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// PeerInfoGORM is the GORM-backed repository for types.PeerInfo records.
// It embeds the generic repository, so that repository's methods are promoted.
type PeerInfoGORM struct {
	repositories.GenericRepository[types.PeerInfo]
}

// NewPeerInfo builds a PeerInfoGORM on top of the given GORM database and
// returns it as a repositories.PeerInfo.
func NewPeerInfo(db *gorm.DB) repositories.PeerInfo {
	generic := NewGenericRepository[types.PeerInfo](db)
	return &PeerInfoGORM{GenericRepository: generic}
}
// MachineGORM is the GORM-backed repository for types.Machine records.
// It embeds the generic repository, so that repository's methods are promoted.
type MachineGORM struct {
	repositories.GenericRepository[types.Machine]
}

// NewMachine builds a MachineGORM on top of the given GORM database and
// returns it as a repositories.Machine.
func NewMachine(db *gorm.DB) repositories.Machine {
	generic := NewGenericRepository[types.Machine](db)
	return &MachineGORM{GenericRepository: generic}
}
// AvailableResourcesGORM is the GORM-backed single-entity repository for
// types.AvailableResources. It embeds the generic entity repository, so that
// repository's methods are promoted.
type AvailableResourcesGORM struct {
	repositories.GenericEntityRepository[types.AvailableResources]
}

// NewAvailableResources builds an AvailableResourcesGORM on top of the given
// GORM database and returns it as a repositories.AvailableResources.
func NewAvailableResources(db *gorm.DB) repositories.AvailableResources {
	entityRepo := NewGenericEntityRepository[types.AvailableResources](db)
	return &AvailableResourcesGORM{GenericEntityRepository: entityRepo}
}
// ServicesGORM is the GORM-backed repository for types.Services records.
// It embeds the generic repository, so that repository's methods are promoted.
type ServicesGORM struct {
	repositories.GenericRepository[types.Services]
}

// NewServices builds a ServicesGORM on top of the given GORM database and
// returns it as a repositories.Services.
func NewServices(db *gorm.DB) repositories.Services {
	generic := NewGenericRepository[types.Services](db)
	return &ServicesGORM{GenericRepository: generic}
}
// ServiceResourceRequirementsGORM is the GORM-backed repository for
// types.ServiceResourceRequirements records. It embeds the generic repository,
// so that repository's methods are promoted.
type ServiceResourceRequirementsGORM struct {
	repositories.GenericRepository[types.ServiceResourceRequirements]
}

// NewServiceResourceRequirements builds a ServiceResourceRequirementsGORM on
// top of the given GORM database and returns it as a
// repositories.ServiceResourceRequirements.
func NewServiceResourceRequirements(
	db *gorm.DB,
) repositories.ServiceResourceRequirements {
	generic := NewGenericRepository[types.ServiceResourceRequirements](db)
	return &ServiceResourceRequirementsGORM{GenericRepository: generic}
}
// Libp2pInfoGORM is the GORM-backed single-entity repository for
// types.Libp2pInfo. It embeds the generic entity repository, so that
// repository's methods are promoted.
type Libp2pInfoGORM struct {
	repositories.GenericEntityRepository[types.Libp2pInfo]
}

// NewLibp2pInfo builds a Libp2pInfoGORM on top of the given GORM database and
// returns it as a repositories.Libp2pInfo.
func NewLibp2pInfo(db *gorm.DB) repositories.Libp2pInfo {
	entityRepo := NewGenericEntityRepository[types.Libp2pInfo](db)
	return &Libp2pInfoGORM{GenericEntityRepository: entityRepo}
}
// MachineUUIDGORM is the GORM-backed single-entity repository for
// types.MachineUUID. It embeds the generic entity repository, so that
// repository's methods are promoted.
type MachineUUIDGORM struct {
	repositories.GenericEntityRepository[types.MachineUUID]
}

// NewMachineUUID builds a MachineUUIDGORM on top of the given GORM database
// and returns it as a repositories.MachineUUID.
func NewMachineUUID(db *gorm.DB) repositories.MachineUUID {
	entityRepo := NewGenericEntityRepository[types.MachineUUID](db)
	return &MachineUUIDGORM{GenericEntityRepository: entityRepo}
}
// ConnectionGORM is the GORM-backed repository for types.Connection records.
// It embeds the generic repository, so that repository's methods are promoted.
type ConnectionGORM struct {
	repositories.GenericRepository[types.Connection]
}

// NewConnection builds a ConnectionGORM on top of the given GORM database and
// returns it as a repositories.Connection.
func NewConnection(db *gorm.DB) repositories.Connection {
	generic := NewGenericRepository[types.Connection](db)
	return &ConnectionGORM{GenericRepository: generic}
}
// ElasticTokenGORM is the GORM-backed repository for types.ElasticToken
// records. It embeds the generic repository, so that repository's methods are
// promoted.
type ElasticTokenGORM struct {
	repositories.GenericRepository[types.ElasticToken]
}

// NewElasticToken builds an ElasticTokenGORM on top of the given GORM database
// and returns it as a repositories.ElasticToken.
func NewElasticToken(db *gorm.DB) repositories.ElasticToken {
	generic := NewGenericRepository[types.ElasticToken](db)
	return &ElasticTokenGORM{GenericRepository: generic}
}
package gorm
import (
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// OnboardingParamsGORM is the GORM-backed single-entity repository for
// types.OnboardingConfig. It embeds the generic entity repository, so that
// repository's methods are promoted.
type OnboardingParamsGORM struct {
	repositories.GenericEntityRepository[types.OnboardingConfig]
}

// NewOnboardingParams builds an OnboardingParamsGORM on top of the given GORM
// database and returns it as a repositories.OnboardingParams.
func NewOnboardingParams(db *gorm.DB) repositories.OnboardingParams {
	entityRepo := NewGenericEntityRepository[types.OnboardingConfig](db)
	return &OnboardingParamsGORM{GenericEntityRepository: entityRepo}
}
package gorm
import (
"gitlab.com/nunet/device-management-service/types"
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
)
// MachineResourcesGORM is the GORM-backed single-entity repository for
// types.MachineResources. It embeds the generic entity repository, so that
// repository's methods are promoted.
type MachineResourcesGORM struct {
	repositories.GenericEntityRepository[types.MachineResources]
}

// NewMachineResources builds a MachineResourcesGORM on top of the given GORM
// database and returns it as a repositories.MachineResources.
func NewMachineResources(db *gorm.DB) repositories.MachineResources {
	entityRepo := NewGenericEntityRepository[types.MachineResources](db)
	return &MachineResourcesGORM{GenericEntityRepository: entityRepo}
}
// FreeResourcesGORM is the GORM-backed single-entity repository for
// types.FreeResources. It embeds the generic entity repository, so that
// repository's methods are promoted.
type FreeResourcesGORM struct {
	repositories.GenericEntityRepository[types.FreeResources]
}

// NewFreeResources builds a FreeResourcesGORM on top of the given GORM
// database and returns it as a repositories.FreeResources.
func NewFreeResources(db *gorm.DB) repositories.FreeResources {
	entityRepo := NewGenericEntityRepository[types.FreeResources](db)
	return &FreeResourcesGORM{GenericEntityRepository: entityRepo}
}
// OnboardedResourcesRepositoryGORM is the GORM-backed single-entity repository
// for types.OnboardedResources. It embeds the generic entity repository, so
// that repository's methods are promoted.
type OnboardedResourcesRepositoryGORM struct {
	repositories.GenericEntityRepository[types.OnboardedResources]
}

// NewOnboardedResources builds an OnboardedResourcesRepositoryGORM on top of
// the given GORM database and returns it as a repositories.OnboardedResources.
func NewOnboardedResources(db *gorm.DB) repositories.OnboardedResources {
	entityRepo := NewGenericEntityRepository[types.OnboardedResources](db)
	return &OnboardedResourcesRepositoryGORM{GenericEntityRepository: entityRepo}
}
// RequiredResourcesRepositoryGORM is the GORM-backed repository for
// types.RequiredResources records. It embeds the generic repository, so that
// repository's methods are promoted.
type RequiredResourcesRepositoryGORM struct {
	repositories.GenericRepository[types.RequiredResources]
}

// NewRequiredResources builds a RequiredResourcesRepositoryGORM on top of the
// given GORM database and returns it as a repositories.RequiredResources.
func NewRequiredResources(db *gorm.DB) repositories.RequiredResources {
	generic := NewGenericRepository[types.RequiredResources](db)
	return &RequiredResourcesRepositoryGORM{GenericRepository: generic}
}
package gorm
import (
"errors"
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
)
// handleDBError is a utility function that translates GORM database errors
// into the custom errors of the repositories package.
//
// Mapping:
//   - nil                      -> nil
//   - gorm.ErrRecordNotFound   -> repositories.ErrNotFound
//   - gorm.ErrInvalidData, gorm.ErrInvalidField, gorm.ErrInvalidValue
//     -> repositories.ErrInvalidData
//   - repositories.ErrParsingModel is passed through unchanged
//   - anything else is joined with repositories.ErrDatabase
//
// Matching uses errors.Is so that wrapped sentinel errors are recognized too.
func handleDBError(err error) error {
	switch {
	case err == nil:
		return nil
	case errors.Is(err, gorm.ErrRecordNotFound):
		return repositories.ErrNotFound
	case errors.Is(err, gorm.ErrInvalidData),
		errors.Is(err, gorm.ErrInvalidField),
		errors.Is(err, gorm.ErrInvalidValue):
		return repositories.ErrInvalidData
	case errors.Is(err, repositories.ErrParsingModel):
		return err
	default:
		return errors.Join(repositories.ErrDatabase, err)
	}
}
package repositories
import (
"fmt"
"reflect"
)
// UpdateField is a generic function that sets one field of a struct, or of the
// struct a pointer points to, using reflection.
//
// Parameters:
//   - input: a struct value or a pointer to a struct. For a struct value the
//     change is made on a copy, which is returned; for a pointer the pointed-to
//     struct is modified in place and the same pointer is returned.
//   - fieldName: the Go name of the field to set.
//   - newValue: the value to store. It is converted to the field's type when
//     Go allows the conversion. A nil newValue stores the zero value in fields
//     whose kind can hold nil (pointer, interface, map, slice, chan, func,
//     unsafe pointer).
//
// An error is returned (and input is handed back unmodified) when input is not
// a struct or pointer to struct, the field does not exist, the field cannot be
// set (e.g. it is unexported), or newValue cannot be converted to the field's
// type.
func UpdateField[T interface{}](input T, fieldName string, newValue interface{}) (T, error) {
	// Use reflection to get the struct's field
	val := reflect.ValueOf(input)
	if val.Kind() == reflect.Ptr {
		// If input is a pointer, get the underlying element
		val = val.Elem()
	} else {
		// If input is not a pointer, ensure it's addressable
		val = reflect.ValueOf(&input).Elem()
	}

	// Check if the input is a struct
	if val.Kind() != reflect.Struct {
		return input, fmt.Errorf("not a struct: %T", input)
	}

	// Get the field by name
	field := val.FieldByName(fieldName)
	if !field.IsValid() {
		return input, fmt.Errorf("field not found: %v", fieldName)
	}

	// Check if the field is settable
	if !field.CanSet() {
		return input, fmt.Errorf("field not settable: %v", fieldName)
	}

	// reflect.TypeOf(nil) is a nil Type, so nil has to be handled before any
	// method is called on the new value's type.
	if newValue == nil {
		switch field.Kind() {
		case reflect.Ptr, reflect.Interface, reflect.Map, reflect.Slice,
			reflect.Chan, reflect.Func, reflect.UnsafePointer:
			field.Set(reflect.Zero(field.Type()))
			return input, nil
		default:
			return input, fmt.Errorf("incompatible conversion: nil -> %v", field.Type())
		}
	}

	// Check if types are compatible
	newType := reflect.TypeOf(newValue)
	if !newType.ConvertibleTo(field.Type()) {
		return input, fmt.Errorf(
			"incompatible conversion: %v -> %v; value: %v",
			newType, field.Type(), newValue,
		)
	}

	// Convert the new value to the field type and store it.
	field.Set(reflect.ValueOf(newValue).Convert(field.Type()))

	return input, nil
}
// IsEmptyValue reports whether value is "empty", using reflection.
//
// It returns true when value is nil, when it is a nil pointer, or when it
// (or, for a non-nil pointer, the value it points to) is the zero value of its
// type. A struct is therefore empty when all of its fields hold their zero
// values.
func IsEmptyValue(value interface{}) bool {
	// An untyped nil has no reflect.Value to inspect.
	if value == nil {
		return true
	}

	val := reflect.ValueOf(value)

	// If the value is a pointer, look at what it points to. A typed nil
	// pointer must be handled first: Elem() would yield an invalid Value, on
	// which IsZero panics.
	if val.Kind() == reflect.Ptr {
		if val.IsNil() {
			return true
		}
		val = val.Elem()
	}

	return val.IsZero()
}
package actor
import (
"context"
"encoding/json"
"errors"
"fmt"
"sync"
bt "gitlab.com/nunet/device-management-service/internal/background_tasks"
"gitlab.com/nunet/device-management-service/network"
"gitlab.com/nunet/device-management-service/types"
)
// BasicActor is an Actor implementation (see the interface assertion in this
// file) that delegates message handling to a Dispatch and uses a
// network.Network for transport.
type BasicActor struct {
// dispatch receives incoming envelopes, holds the registered behaviors and
// provides the actor's context.
dispatch *Dispatch
// scheduler is started together with the dispatch in Start.
scheduler *bt.Scheduler
registry *registry
// network is used to register the message handler, resolve addresses, send
// and publish messages, and subscribe to topics.
network network.Network
// security supplies nonces, provides capabilities for outgoing messages and
// verifies incoming broadcasts.
security SecurityContext
params BasicActorParams
// self is the handle returned by Handle; its ID is compared against the
// recipient of messages.
self Handle
// mx is held by Subscribe while it reads and writes subscriptions.
mx sync.Mutex
// subscriptions maps a topic to the subscription ID returned by the network.
subscriptions map[string]uint64
}
// BasicActorParams holds the construction parameters of a BasicActor.
// It currently has no fields.
type BasicActorParams struct{}
// Compile-time assertion that *BasicActor implements the Actor interface.
var _ Actor = (*BasicActor)(nil)
// New creates a new basic actor.
func New(dispatch *Dispatch, scheduler *bt.Scheduler, net network.Network, security *BasicSecurityContext, params BasicActorParams, self Handle) (*BasicActor, error) {
if dispatch == nil {
return nil, errors.New("dispatch is nil")
}
if scheduler == nil {
return nil, errors.New("scheduler is nil")
}
if net == nil {
return nil, errors.New("network is nil")
}
if security == nil {
return nil, errors.New("security is nil")
}
actor := &BasicActor{
dispatch: dispatch,
scheduler: scheduler,
registry: ®istry{},
network: net,
security: security,
params: params,
self: self,
subscriptions: make(map[string]uint64),
}
return actor, nil
}
// Start registers the actor's message handler with the network, under a
// message type derived from the actor's inbox address, and then starts the
// dispatch and the scheduler. If the handler cannot be registered, an error is
// returned and nothing is started.
func (a *BasicActor) Start() error {
	messageType := fmt.Sprintf("actor/%s/messages/0.0.1", a.self.Address.InboxAddress)
	err := a.network.HandleMessage(messageType, a.handleMessage)
	if err != nil {
		return fmt.Errorf("starting actor: %s: %w", a.self.ID, err)
	}

	// The network side is wired up; bring up the internal goroutines.
	a.dispatch.Start()
	a.scheduler.Start()

	return nil
}
// handleMessage is the network callback for direct messages. It decodes data
// as a JSON Envelope, drops (and logs) payloads that cannot be decoded or that
// are addressed to a different actor ID, and hands everything else to the
// dispatch.
func (a *BasicActor) handleMessage(data []byte) {
	var envelope Envelope
	err := json.Unmarshal(data, &envelope)
	if err != nil {
		log.Debugf("error unmarshaling message: %s", err)
		return
	}

	if addressedToUs := a.self.ID.Equal(envelope.To.ID); !addressedToUs {
		log.Warnf("message is not for ourselves: %s %s", a.self.ID, envelope.To.ID)
		return
	}

	// The callback has no way to report a failure, so the dispatch result is
	// discarded.
	_ = a.dispatch.Receive(envelope)
}
// Context returns the context of the actor's dispatch.
func (a *BasicActor) Context() context.Context {
	ctx := a.dispatch.Context()
	return ctx
}
// Handle returns the handle the actor was created with.
func (a *BasicActor) Handle() Handle {
	own := a.self
	return own
}
// Security returns the security context the actor was created with.
func (a *BasicActor) Security() SecurityContext {
	sec := a.security
	return sec
}
// AddBehavior registers continuation for the named behavior on the actor's
// dispatch, forwarding any options, and returns the dispatch's result.
func (a *BasicActor) AddBehavior(behavior string, continuation Behavior, opt ...BehaviorOption) error {
	err := a.dispatch.AddBehavior(behavior, continuation, opt...)
	return err
}
// RemoveBehavior removes the named behavior from the actor's dispatch.
func (a *BasicActor) RemoveBehavior(behavior string) {
	dispatch := a.dispatch
	dispatch.RemoveBehavior(behavior)
}
// Receive hands msg to the dispatch when it is addressed to this actor's ID or
// is a broadcast. Any other message is rejected with an error wrapping
// ErrInvalidMessage.
func (a *BasicActor) Receive(msg Envelope) error {
	addressedToUs := a.self.ID.Equal(msg.To.ID)
	if !addressedToUs && !msg.IsBroadcast() {
		return fmt.Errorf("bad receiver: %w", ErrInvalidMessage)
	}

	return a.dispatch.Receive(msg)
}
// Send delivers msg to its recipient.
//
// A message addressed to this actor's own ID is passed straight to Receive.
// Otherwise, an unsigned message first gets a nonce (if it has none) and is
// handed to the security context, which is asked to provide the invoked
// behavior capability and, when a reply-to is set, to delegate the reply
// capability. The recipient's host is then resolved, the envelope is
// JSON-encoded and sent over the network under a message type derived from
// the recipient's inbox address.
func (a *BasicActor) Send(msg Envelope) error {
	if msg.To.ID.Equal(a.self.ID) {
		return a.Receive(msg)
	}

	if msg.Signature == nil {
		if msg.Nonce == 0 {
			msg.Nonce = a.security.Nonce()
		}

		var delegated []Capability
		if replyTo := msg.Options.ReplyTo; replyTo != "" {
			delegated = []Capability{Capability(replyTo)}
		}
		invoked := []Capability{Capability(msg.Behavior)}

		if err := a.security.Provide(&msg, invoked, delegated); err != nil {
			return fmt.Errorf("providing behavior capability for %s: %w", msg.Behavior, err)
		}
	}

	destinations, err := a.network.ResolveAddress(a.Context(), msg.To.Address.HostID)
	if err != nil {
		return fmt.Errorf("resolving address for %s: %w", msg.To.ID, err)
	}

	payload, err := json.Marshal(msg)
	if err != nil {
		return fmt.Errorf("marshaling message: %w", err)
	}

	messageType := fmt.Sprintf("actor/%s/messages/0.0.1", msg.To.Address.InboxAddress)
	wire := types.MessageEnvelope{
		Type: types.MessageType(messageType),
		Data: payload,
	}
	if err := a.network.SendMessage(a.Context(), destinations, wire); err != nil {
		return fmt.Errorf("sending message to %s: %w", msg.To.ID, err)
	}

	return nil
}
// Invoke sends msg and returns a channel on which the reply is delivered.
//
// If msg has no reply-to behavior, one is generated from a fresh nonce. A
// one-shot behavior with the message's expiry is registered under that name
// (caller options are applied after these defaults); when it fires it pushes
// the reply onto the channel and closes it. If sending fails, the reply
// behavior is removed again and an error is returned.
func (a *BasicActor) Invoke(msg Envelope, opt ...BehaviorOption) (<-chan Envelope, error) {
	if msg.Options.ReplyTo == "" {
		msg.Options.ReplyTo = fmt.Sprintf("/dms/actor/replyto/%d", a.security.Nonce())
	}

	// Buffered so that the reply behavior never blocks on delivery.
	replies := make(chan Envelope, 1)
	deliver := func(reply Envelope) {
		replies <- reply
		close(replies)
	}

	behaviorOpts := make([]BehaviorOption, 0, len(opt)+2)
	behaviorOpts = append(behaviorOpts,
		WithBehaviorExpiry(msg.Options.Expire),
		WithBehaviorOneShot(true),
	)
	behaviorOpts = append(behaviorOpts, opt...)

	err := a.dispatch.AddBehavior(msg.Options.ReplyTo, deliver, behaviorOpts...)
	if err != nil {
		return nil, fmt.Errorf("adding reply behavior: %w", err)
	}

	err = a.Send(msg)
	if err != nil {
		a.dispatch.RemoveBehavior(msg.Options.ReplyTo)
		return nil, fmt.Errorf("sending message: %w", err)
	}

	return replies, nil
}
// Publish broadcasts msg on its topic.
//
// Only broadcast envelopes are accepted; anything else yields
// ErrInvalidMessage. Unsigned envelopes get a nonce (if missing) and
// broadcast capability tokens before being JSON-encoded and published.
func (a *BasicActor) Publish(msg Envelope) error {
	if !msg.IsBroadcast() {
		return ErrInvalidMessage
	}

	if msg.Signature == nil {
		if msg.Nonce == 0 {
			msg.Nonce = a.security.Nonce()
		}
		caps := []Capability{Capability(msg.Behavior)}
		err := a.security.ProvideBroadcast(&msg, msg.Options.Topic, caps)
		if err != nil {
			return fmt.Errorf("providing behavior capability for %s: %w", msg.Behavior, err)
		}
	}

	payload, err := json.Marshal(msg)
	if err != nil {
		return fmt.Errorf("marshaling message: %w", err)
	}

	err = a.network.Publish(a.Context(), msg.Options.Topic, payload)
	if err != nil {
		return fmt.Errorf("publishing message: %w", err)
	}
	return nil
}
// Subscribe registers the actor on a broadcast topic. Subscribing to a topic
// that is already subscribed is a no-op. The subscription ID returned by the
// network is recorded so that Stop can unsubscribe later.
func (a *BasicActor) Subscribe(topic string) error {
	a.mx.Lock()
	defer a.mx.Unlock()

	if _, subscribed := a.subscriptions[topic]; subscribed {
		return nil
	}

	// Bind the topic so the validator can check it against the envelope.
	validator := func(data []byte, validatorData interface{}) (network.ValidationResult, interface{}) {
		return a.validateBroadcast(topic, data, validatorData)
	}

	subID, err := a.network.Subscribe(a.Context(), topic, a.handleBroadcast, validator)
	if err != nil {
		return fmt.Errorf("subscribe: %w", err)
	}

	a.subscriptions[topic] = subID
	return nil
}
// validateBroadcast is the pubsub validator for a subscribed topic.
//
// When validatorData is already set it must be an Envelope, in which case the
// message is accepted without re-validation; any other value is rejected.
// Otherwise the payload is decoded and must be a broadcast envelope for the
// given topic, not expired (expired messages are ignored rather than
// rejected), and must pass signature verification.
func (a *BasicActor) validateBroadcast(topic string, data []byte, validatorData interface{}) (network.ValidationResult, interface{}) {
	if validatorData != nil {
		if _, isEnvelope := validatorData.(Envelope); isEnvelope {
			// we have already validated the message, just short-circuit
			return network.ValidationAccept, validatorData
		}
		log.Warnf("bogus pubsub validation data: %v", validatorData)
		return network.ValidationReject, nil
	}

	var msg Envelope
	if err := json.Unmarshal(data, &msg); err != nil {
		return network.ValidationReject, nil
	}

	switch {
	case !msg.IsBroadcast(), msg.Options.Topic != topic:
		return network.ValidationReject, nil
	case msg.Expired():
		return network.ValidationIgnore, nil
	}

	if err := a.security.Verify(msg); err != nil {
		return network.ValidationReject, nil
	}
	return network.ValidationAccept, msg
}
// handleBroadcast decodes a raw broadcast payload into an envelope and feeds
// it to Receive. Failures are logged and otherwise dropped.
func (a *BasicActor) handleBroadcast(data []byte) {
	var msg Envelope

	err := json.Unmarshal(data, &msg)
	if err != nil {
		log.Debugf("error unmarshaling broadcast message: %s", err)
		return
	}

	if err = a.Receive(msg); err != nil {
		log.Warnf("error receiving broadcast message: %s", err)
	}
}
// Stop shuts down the dispatcher and unsubscribes from every topic the actor
// subscribed to. Unsubscribe failures are logged and do not abort the
// shutdown; Stop always returns nil.
func (a *BasicActor) Stop() error {
	a.dispatch.close()

	// Subscribe mutates a.subscriptions under a.mx, so the map must be read
	// under the same lock to avoid a concurrent map read/write.
	a.mx.Lock()
	defer a.mx.Unlock()

	for topic, subID := range a.subscriptions {
		if err := a.network.Unsubscribe(topic, subID); err != nil {
			log.Debugf("error unsubscribing from %s: %s", topic, err)
		}
	}
	return nil
}
package actor
import (
"context"
"fmt"
"sync"
"time"
)
// Defaults applied by NewDispatch when no overriding option is given.
var (
	// DefaultDispatchGCInterval is the default period between sweeps of
	// expired behaviors.
	DefaultDispatchGCInterval = 2 * time.Minute
	// DefaultDispatchWorkers is the default number of receive workers.
	DefaultDispatchWorkers = 1
)
// DispatchLimiter implements a (potentially) stateful resource access limiter.
// This is necessary to combat spam attacks and ensure that our system does not
// become overloaded with too many goroutines.
//
// The dispatcher calls Acquire before running a behavior for a message and
// Release once that behavior has returned; a non-nil error from Acquire causes
// the message to be dropped.
type DispatchLimiter interface {
Acquire(msg Envelope) error
Release(msg Envelope)
}
// NoDispatchLimiter is the null limiter: it never rate limits.
// It is the limiter NewDispatch installs by default.
type NoDispatchLimiter struct{}
// Dispatch provides a reaction kernel with multithreaded dispatch and oneshot
// continuations.
type Dispatch struct {
// ctx is cancelled by close and stops all dispatcher goroutines.
ctx context.Context
// close is the cancel function paired with ctx.
close func()
// sctx verifies messages and checks their capabilities.
sctx SecurityContext
// mx guards behaviors and started.
mx sync.Mutex
q chan Envelope // incoming message queue
vq chan Envelope // verified message queue
// behaviors maps a behavior name to its registered state.
behaviors map[string]*BehaviorState
// started records whether Start has already launched the goroutines.
started bool
options DispatchOptions
}
// DispatchOptions holds the tunable settings of a Dispatch.
type DispatchOptions struct {
// Limiter gates the execution of behaviors.
Limiter DispatchLimiter
// GCInterval is the period between sweeps of expired behaviors.
GCInterval time.Duration
// Workers is the number of receive goroutines launched by Start.
Workers int
}
// BehaviorState pairs a registered behavior continuation with its options.
type BehaviorState struct {
// cont is the continuation invoked for a matching message.
cont Behavior
// opt holds the behavior's options (capabilities, expiry, one-shot, topic).
opt BehaviorOptions
}
// DispatchOption mutates DispatchOptions; options are applied by NewDispatch.
type DispatchOption func(o *DispatchOptions)
// WithDispatchWorkers sets the number of receive workers.
func WithDispatchWorkers(count int) DispatchOption {
	setWorkers := func(opts *DispatchOptions) {
		opts.Workers = count
	}
	return setWorkers
}
// WithDispatchGCInterval sets the period between sweeps of expired behaviors.
func WithDispatchGCInterval(dt time.Duration) DispatchOption {
	setInterval := func(opts *DispatchOptions) {
		opts.GCInterval = dt
	}
	return setInterval
}
// WithDispatchLimiter sets the limiter that gates behavior execution.
func WithDispatchLimiter(limiter DispatchLimiter) DispatchOption {
	setLimiter := func(opts *DispatchOptions) {
		opts.Limiter = limiter
	}
	return setLimiter
}
// Acquire always succeeds: the null limiter admits every message.
func (l NoDispatchLimiter) Acquire(_ Envelope) error {
	return nil
}
// Release is a no-op: the null limiter holds no state.
func (l NoDispatchLimiter) Release(_ Envelope) {
}
// NewDispatch creates a dispatcher bound to the given security context.
// It starts from the package defaults (GC interval, worker count, null
// limiter) and then applies the supplied options in order. The dispatcher
// does not process messages until Start is called.
func NewDispatch(sctx SecurityContext, opt ...DispatchOption) *Dispatch {
	ctx, cancel := context.WithCancel(context.Background())

	d := &Dispatch{
		ctx:       ctx,
		close:     cancel,
		sctx:      sctx,
		q:         make(chan Envelope),
		vq:        make(chan Envelope),
		behaviors: make(map[string]*BehaviorState),
	}
	d.options.GCInterval = DefaultDispatchGCInterval
	d.options.Workers = DefaultDispatchWorkers
	d.options.Limiter = NoDispatchLimiter{}

	for _, apply := range opt {
		apply(&d.options)
	}
	return d
}
// Start launches the receive workers, the dispatch loop and the garbage
// collector. Calling Start more than once has no further effect.
func (k *Dispatch) Start() {
	k.mx.Lock()
	defer k.mx.Unlock()

	if k.started {
		return
	}

	for worker := 0; worker < k.options.Workers; worker++ {
		go k.recv()
	}
	go k.dispatch()
	go k.gc()

	k.started = true
}
// AddBehavior registers a continuation under the given behavior name,
// replacing any previous registration. By default the behavior requires a
// capability equal to its own name; options may override this and set
// further properties. An option error aborts the registration.
func (k *Dispatch) AddBehavior(behavior string, continuation Behavior, opt ...BehaviorOption) error {
	state := &BehaviorState{cont: continuation}
	state.opt.Capability = []Capability{Capability(behavior)}

	for _, apply := range opt {
		err := apply(&state.opt)
		if err != nil {
			return fmt.Errorf("adding behavior: %w", err)
		}
	}

	k.mx.Lock()
	k.behaviors[behavior] = state
	k.mx.Unlock()
	return nil
}
// RemoveBehavior unregisters the named behavior; removing an unknown name is
// a no-op.
func (k *Dispatch) RemoveBehavior(behavior string) {
	k.mx.Lock()
	delete(k.behaviors, behavior)
	k.mx.Unlock()
}
// Receive enqueues a message for verification and dispatch. It blocks until
// a worker accepts the message or the dispatcher is closed, in which case
// the context error is returned.
func (k *Dispatch) Receive(msg Envelope) error {
	closed := k.ctx.Done()
	select {
	case <-closed:
		return k.ctx.Err()
	case k.q <- msg:
		return nil
	}
}
// Context returns the dispatcher's context, which is cancelled when the
// dispatcher is closed.
func (k *Dispatch) Context() context.Context {
	ctx := k.ctx
	return ctx
}
// recv is the receive worker loop: it takes messages from the incoming
// queue, verifies them with the security context and forwards the valid ones
// to the verified queue. It runs until the dispatcher context is cancelled.
func (k *Dispatch) recv() {
	for {
		select {
		case msg := <-k.q:
			if err := k.sctx.Verify(msg); err != nil {
				log.Debugf("failed to verify message from %s: %s", msg.From.ID, err)
				// Drop the bad message but keep the worker alive; returning
				// here would let a single invalid message permanently stop
				// this worker.
				continue
			}
			// Forward, but do not block forever if the dispatcher is
			// shutting down and nobody drains the verified queue.
			select {
			case k.vq <- msg:
			case <-k.ctx.Done():
				return
			}
		case <-k.ctx.Done():
			return
		}
	}
}
// dispatch is the main reaction loop. For every verified message it looks up
// the behavior named by the message, checks expiry and capabilities, and then
// runs the behavior's continuation in its own goroutine, subject to the
// limiter. It runs until the dispatcher context is cancelled.
func (k *Dispatch) dispatch() {
for {
select {
case msg := <-k.vq:
// The behaviors map is only touched while holding mx; every exit
// path below releases the lock before logging or continuing.
k.mx.Lock()
b, ok := k.behaviors[msg.Behavior]
if !ok {
k.mx.Unlock()
log.Debugf("unknown behavior %s", msg.Behavior)
continue
}
// Expired behaviors are removed eagerly instead of waiting for gc.
if b.Expired(time.Now()) {
delete(k.behaviors, msg.Behavior)
k.mx.Unlock()
log.Debugf("expired behavior %s", msg.Behavior)
continue
}
// Broadcasts are checked against the behavior's topic as well as
// its capabilities; direct messages only against the capabilities.
if msg.IsBroadcast() {
if err := k.sctx.RequireBroadcast(msg, b.opt.Topic, b.opt.Capability); err != nil {
k.mx.Unlock()
log.Warnf("broadcast message from %s does not have the required capability %s: %s", msg.From, b.opt.Capability, err)
continue
}
} else if err := k.sctx.Require(msg, b.opt.Capability); err != nil {
k.mx.Unlock()
log.Warnf("message from %s does not have the required capability %s: %s", msg.From, b.opt.Capability, err)
continue
}
// One-shot behaviors are unregistered before they run.
if b.opt.OneShot {
delete(k.behaviors, msg.Behavior)
}
k.mx.Unlock()
// The limiter may refuse the message; its capability tokens are
// discarded in that case.
if err := k.options.Limiter.Acquire(msg); err != nil {
k.sctx.Discard(msg)
log.Debugf("limiter rejected message from %s: %s", msg.From.ID, err)
continue
}
// Give the behavior a way to discard the message's capabilities.
msg.Discard = func() {
k.sctx.Discard(msg)
}
// Run the continuation concurrently; the limiter slot is released
// when it returns.
go func() {
defer k.options.Limiter.Release(msg)
b.cont(msg)
}()
case <-k.ctx.Done():
return
}
}
}
// gc periodically removes expired behaviors, once per GCInterval, until the
// dispatcher context is cancelled.
func (k *Dispatch) gc() {
	ticker := time.NewTicker(k.options.GCInterval)
	defer ticker.Stop()

	// sweep drops every behavior that has expired as of now.
	sweep := func() {
		k.mx.Lock()
		defer k.mx.Unlock()

		now := time.Now()
		for name, state := range k.behaviors {
			if state.Expired(now) {
				delete(k.behaviors, name)
			}
		}
	}

	for {
		select {
		case <-k.ctx.Done():
			return
		case <-ticker.C:
			sweep()
		}
	}
}
// Expired reports whether the behavior's expiry lies before now. The expiry
// is compared against now in Unix nanoseconds; an expiry of zero means the
// behavior never expires.
func (b *BehaviorState) Expired(now time.Time) bool {
	deadline := b.opt.Expire
	return deadline > 0 && uint64(now.UnixNano()) > deadline
}
// WithBehaviorExpiry sets the behavior's expiry value.
func WithBehaviorExpiry(expire uint64) BehaviorOption {
	setExpiry := func(o *BehaviorOptions) error {
		o.Expire = expire
		return nil
	}
	return setExpiry
}
// WithBehaviorCapability replaces the capabilities required to invoke the
// behavior.
func WithBehaviorCapability(require ...Capability) BehaviorOption {
	setCapability := func(o *BehaviorOptions) error {
		o.Capability = require
		return nil
	}
	return setCapability
}
// WithBehaviorOneShot marks the behavior as one-shot; the dispatcher removes
// one-shot behaviors before running them.
func WithBehaviorOneShot(oneShot bool) BehaviorOption {
	setOneShot := func(o *BehaviorOptions) error {
		o.OneShot = oneShot
		return nil
	}
	return setOneShot
}
// WithBehaviorTopic sets the broadcast topic the behavior is bound to.
func WithBehaviorTopic(topic string) BehaviorOption {
	setTopic := func(o *BehaviorOptions) error {
		o.Topic = topic
		return nil
	}
	return setTopic
}
package actor
import (
"fmt"
)
// Empty reports whether the handle's ID, DID and address are all empty.
func (h *Handle) Empty() bool {
	if !h.ID.Empty() {
		return false
	}
	if !h.DID.Empty() {
		return false
	}
	return h.Address.Empty()
}
// String renders the handle in the form "ID[DID]@Address".
func (h *Handle) String() string {
	const layout = "%s[%s]@%s"
	return fmt.Sprintf(layout, h.ID, h.DID, h.Address)
}
// HandleFromString is not implemented yet; it always returns an empty handle
// and ErrTODO.
func HandleFromString(_ string) (Handle, error) {
	// TODO
	var none Handle
	return none, ErrTODO
}
// Empty reports whether both the host ID and the inbox address are unset.
func (a *Address) Empty() bool {
	if a.HostID != "" {
		return false
	}
	return a.InboxAddress == ""
}
// String renders the address as "HostID:InboxAddress".
func (a *Address) String() string {
	host, inbox := a.HostID, a.InboxAddress
	return host + ":" + inbox
}
// AddressFromString is not implemented yet; it always returns an empty
// address and ErrTODO.
func AddressFromString(_ string) (Address, error) {
	// TODO
	var none Address
	return none, ErrTODO
}
package actor
// BasicDispatchLimiter is a placeholder for a stateful dispatch limiter; it
// is not implemented yet.
//
// NOTE(review): the method below is named Reserve, while the DispatchLimiter
// interface declares Acquire — confirm which name is intended before using
// this type as a DispatchLimiter.
type BasicDispatchLimiter struct {
// TODO we can leave this for follow up
}
// Reserve is not implemented yet and always returns ErrTODO.
func (l *BasicDispatchLimiter) Reserve(_ Envelope) error {
// TODO we can leave this for follow up
return ErrTODO
}
// Release is not implemented yet and does nothing.
func (l *BasicDispatchLimiter) Release(_ Envelope) {
// TODO we can leave this for follow up
}
package actor
import (
"encoding/json"
"fmt"
"time"
)
const (
	// heartbeatBehavior is the behavior name used for heartbeat messages.
	heartbeatBehavior = "/dms/actor/heartbeat"
	// defaultMessageTimeout is the expiry applied by Message when no timeout
	// option is given.
	defaultMessageTimeout = 30 * time.Second
)

// signaturePrefix is prepended to the encoded envelope before signing.
var signaturePrefix = []byte("dms:msg:")

// HeartbeatMessage is the (empty) payload type for heartbeats.
type HeartbeatMessage struct{}
// Message builds an envelope from src to dest for the given behavior, with
// payload JSON-encoded as the message body. The envelope expires after
// defaultMessageTimeout unless an option overrides it. Options are applied
// in order; the first failing option aborts construction.
func Message(src Handle, dest Handle, behavior string, payload interface{}, opt ...MessageOption) (Envelope, error) {
	body, err := json.Marshal(payload)
	if err != nil {
		return Envelope{}, fmt.Errorf("marshaling payload: %w", err)
	}

	deadline := time.Now().Add(defaultMessageTimeout)
	envelope := Envelope{
		From:     src,
		To:       dest,
		Behavior: behavior,
		Message:  body,
		Options: EnvelopeOptions{
			Expire: uint64(deadline.UnixNano()),
		},
		Discard: func() {},
	}

	for _, apply := range opt {
		if err := apply(&envelope); err != nil {
			return Envelope{}, fmt.Errorf("setting message option: %w", err)
		}
	}
	return envelope, nil
}
// WithMessageSignature assigns a fresh nonce, attaches the capability tokens
// supplied by the security context and signs the envelope. It fails with
// ErrInvalidSecurityContext when the envelope's sender is not the security
// context's own ID.
//
// NOTE: This option must be passed last, otherwise the signature will be
// invalidated by further modifications.
//
// NOTE: Signing is implicit in Send.
func WithMessageSignature(sctx SecurityContext, cap []Capability, delegate []Capability) MessageOption {
	return func(msg *Envelope) error {
		if senderIsContext := msg.From.ID.Equal(sctx.ID()); !senderIsContext {
			return ErrInvalidSecurityContext
		}

		msg.Nonce = sctx.Nonce()
		err := sctx.Provide(msg, cap, delegate)
		return err
	}
}
// WithMessageTimeout sets the message expiration relative to the moment the
// option is applied.
//
// NOTE: messages created with Message have an implicit timeout of
// defaultMessageTimeout.
func WithMessageTimeout(timeo time.Duration) MessageOption {
	return func(msg *Envelope) error {
		deadline := time.Now().Add(timeo)
		msg.Options.Expire = uint64(deadline.UnixNano())
		return nil
	}
}
// WithMessageExpiry sets the message expiry to the given absolute value.
//
// NOTE: messages created with Message have an implicit timeout of
// defaultMessageTimeout.
func WithMessageExpiry(expiry uint64) MessageOption {
	setExpiry := func(msg *Envelope) error {
		msg.Options.Expire = expiry
		return nil
	}
	return setExpiry
}
// WithMessageReplyTo sets the behavior to which replies should be sent.
//
// NOTE: ReplyTo is set implicitly in Invoke, and the appropriate capability
// tokens are delegated by Provide.
func WithMessageReplyTo(replyto string) MessageOption {
	setReplyTo := func(msg *Envelope) error {
		msg.Options.ReplyTo = replyto
		return nil
	}
	return setReplyTo
}
// WithMessageTopic sets the broadcast topic of the message.
func WithMessageTopic(topic string) MessageOption {
	setTopic := func(msg *Envelope) error {
		msg.Options.Topic = topic
		return nil
	}
	return setTopic
}
// SignatureData returns the bytes that are signed for this envelope: the
// signature prefix followed by the JSON encoding of the envelope with its
// Signature field cleared. The receiver itself is not modified.
func (msg *Envelope) SignatureData() ([]byte, error) {
	unsigned := *msg
	unsigned.Signature = nil

	encoded, err := json.Marshal(&unsigned)
	if err != nil {
		return nil, fmt.Errorf("signature data: %w", err)
	}

	out := make([]byte, 0, len(signaturePrefix)+len(encoded))
	out = append(out, signaturePrefix...)
	out = append(out, encoded...)
	return out, nil
}
// Expired reports whether the current time, in Unix nanoseconds, is past the
// envelope's expiry.
func (msg *Envelope) Expired() bool {
	now := uint64(time.Now().UnixNano())
	return now > msg.Options.Expire
}
// Expiry converts the envelope's expiration (Unix nanoseconds) to a
// time.Time.
func (msg *Envelope) Expiry() time.Time {
	const nanosPerSecond = uint64(time.Second)

	expire := msg.Options.Expire
	whole, frac := expire/nanosPerSecond, expire%nanosPerSecond
	return time.Unix(int64(whole), int64(frac)) //nolint
}
// IsBroadcast reports whether the envelope is a broadcast: it has no
// destination handle and carries a topic.
func (msg *Envelope) IsBroadcast() bool {
	if !msg.To.Empty() {
		return false
	}
	return msg.Options.Topic != ""
}
package actor
import (
"errors"
"sync"
)
// Info is the registry record for one actor.
type Info struct {
// Addr is the actor's own handle.
Addr *Handle
// Parent is the handle of the actor's parent.
Parent *Handle
// Children lists the handles of the actor's children.
Children []Handle
}
// Registry tracks actors together with their parent/child relationships.
type Registry interface {
// Actors returns the registered actors.
Actors() map[string]Info
// Add registers actor a with the given parent and children.
Add(a Handle, parent Handle, children []Handle) error
// Get returns the record for a and whether it exists.
Get(a Handle) (Info, bool)
// SetParent replaces the parent of a.
SetParent(a Handle, parent Handle) error
// GetParent returns the parent of a and whether a is registered.
GetParent(a Handle) (*Handle, bool)
}
// registry is the mutex-guarded in-memory implementation of Registry.
type registry struct {
// mx guards actors.
mx sync.Mutex
// actors is keyed by the actor's inbox address.
actors map[string]Info
}
func NewRegistry() Registry {
return ®istry{
actors: make(map[string]Info),
}
}
// Actors returns a copy of the registry's actor map, so callers can iterate
// or modify the result without affecting the registry.
func (r *registry) Actors() map[string]Info {
	r.mx.Lock()
	defer r.mx.Unlock()

	snapshot := make(map[string]Info, len(r.actors))
	for inbox, info := range r.actors {
		snapshot[inbox] = info
	}
	return snapshot
}
// Add registers actor a under its inbox address. It fails if an actor with
// the same inbox address is already registered. A nil children slice is
// stored as an empty slice.
func (r *registry) Add(a Handle, parent Handle, children []Handle) error {
	r.mx.Lock()
	defer r.mx.Unlock()

	key := a.Address.InboxAddress
	if _, exists := r.actors[key]; exists {
		return errors.New("actor already exists")
	}

	if children == nil {
		children = []Handle{}
	}

	entry := Info{Addr: &a, Parent: &parent, Children: children}
	r.actors[key] = entry
	return nil
}
// Get looks up the record for a by its inbox address and reports whether it
// was found.
func (r *registry) Get(a Handle) (Info, bool) {
	key := a.Address.InboxAddress

	r.mx.Lock()
	defer r.mx.Unlock()

	entry, found := r.actors[key]
	return entry, found
}
// SetParent replaces the parent recorded for actor a. It fails if a is not
// registered.
func (r *registry) SetParent(a Handle, parent Handle) error {
	r.mx.Lock()
	defer r.mx.Unlock()

	key := a.Address.InboxAddress
	entry, found := r.actors[key]
	if !found {
		return errors.New("actor not found")
	}

	// Info is stored by value, so the updated copy must be written back.
	entry.Parent = &parent
	r.actors[key] = entry
	return nil
}
// GetParent returns the parent recorded for actor a, or (nil, false) when a
// is not registered.
func (r *registry) GetParent(a Handle) (*Handle, bool) {
	r.mx.Lock()
	defer r.mx.Unlock()

	if entry, found := r.actors[a.Address.InboxAddress]; found {
		return entry.Parent, true
	}
	return nil, false
}
package actor
import (
"fmt"
"sync"
"time"
"gitlab.com/nunet/device-management-service/lib/crypto"
"gitlab.com/nunet/device-management-service/lib/ucan"
)
// BasicSecurityContext signs and verifies envelopes with a key pair and
// manages their capability tokens through a UCAN capability context.
type BasicSecurityContext struct {
// id is derived from the public key at construction time.
id ID
// privk signs outgoing envelopes.
privk crypto.PrivKey
// cap provides, consumes and checks capability tokens.
cap ucan.CapabilityContext
// mx guards nonce.
mx sync.Mutex
// nonce is the next nonce to hand out.
nonce uint64
}
// Compile-time check that BasicSecurityContext implements SecurityContext.
var _ SecurityContext = (*BasicSecurityContext)(nil)
// NewBasicSecurityContext builds a security context for the given key pair
// and capability context. The context ID is derived from the public key and
// the nonce counter is seeded from the current time in nanoseconds.
func NewBasicSecurityContext(pubk crypto.PubKey, privk crypto.PrivKey, cap ucan.CapabilityContext) (*BasicSecurityContext, error) {
	seed := uint64(time.Now().UnixNano())

	id, err := crypto.IDFromPublicKey(pubk)
	if err != nil {
		return nil, fmt.Errorf("creating security context: %w", err)
	}

	sctx := &BasicSecurityContext{
		privk: privk,
		cap:   cap,
		nonce: seed,
	}
	sctx.id = id
	return sctx, nil
}
// ID returns the identifier derived from the context's public key.
func (s *BasicSecurityContext) ID() ID {
	id := s.id
	return id
}
// DID returns the DID of the underlying capability context.
func (s *BasicSecurityContext) DID() DID {
	did := s.cap.DID()
	return did
}
// Nonce returns the current nonce and advances the counter, so successive
// calls yield increasing values.
func (s *BasicSecurityContext) Nonce() uint64 {
	s.mx.Lock()
	defer s.mx.Unlock()

	s.nonce++
	return s.nonce - 1
}
// Require checks that msg carries the capabilities needed to invoke cap.
// Messages this context sends to itself pass unconditionally. Otherwise the
// envelope's tokens are consumed first and then checked; if the check fails
// the consumed tokens are discarded again.
func (s *BasicSecurityContext) Require(msg Envelope, cap []Capability) error {
	// if we are sending to self, nothing to do, signature is already verified
	selfAddressed := s.id.Equal(msg.From.ID) && s.id.Equal(msg.To.ID)
	if selfAddressed {
		return nil
	}

	// first consume the capability tokens in the envelope
	err := s.cap.Consume(msg.From.DID, msg.Capability)
	if err != nil {
		return fmt.Errorf("consuming capabilities: %w", err)
	}

	// check if any of the requested invocation capabilities are delegated
	err = s.cap.Require(s.DID(), msg.From.ID, s.id, cap)
	if err != nil {
		s.cap.Discard(msg.Capability)
		return fmt.Errorf("requiring capabilities: %w", err)
	}
	return nil
}
// Provide attaches capability tokens for the invoke and delegate
// capabilities to msg and signs it. Messages this context sends to itself
// are only signed.
func (s *BasicSecurityContext) Provide(msg *Envelope, invoke []Capability, delegate []Capability) error {
	selfAddressed := s.id.Equal(msg.From.ID) && s.id.Equal(msg.To.ID)
	if !selfAddressed {
		tokens, err := s.cap.Provide(msg.To.DID, s.id, msg.To.ID, msg.Options.Expire, invoke, delegate)
		if err != nil {
			return fmt.Errorf("providing capabilities: %w", err)
		}
		msg.Capability = tokens
	}

	// Sign last so the signature covers the attached tokens.
	return s.Sign(msg)
}
// RequireBroadcast checks that a broadcast msg on topic carries the given
// broadcast capabilities. Non-broadcast messages and topic mismatches are
// rejected with ErrInvalidMessage. The envelope's tokens are consumed first
// and discarded again if the capability check fails.
func (s *BasicSecurityContext) RequireBroadcast(msg Envelope, topic string, broadcast []Capability) error {
	switch {
	case !msg.IsBroadcast():
		return fmt.Errorf("not a broadcast message: %w", ErrInvalidMessage)
	case topic != msg.Options.Topic:
		return fmt.Errorf("broadcast topic mismatch: %w", ErrInvalidMessage)
	}

	// first consume the capability tokens in the envelope
	err := s.cap.Consume(msg.From.DID, msg.Capability)
	if err != nil {
		return fmt.Errorf("consuming capabilities: %w", err)
	}

	// check if any of the requested invocation capabilities are delegated
	err = s.cap.RequireBroadcast(s.DID(), msg.From.ID, topic, broadcast)
	if err != nil {
		s.cap.Discard(msg.Capability)
		return fmt.Errorf("requiring capabilities: %w", err)
	}
	return nil
}
// ProvideBroadcast attaches broadcast capability tokens for topic to msg and
// signs it. Non-broadcast messages and topic mismatches are rejected with
// ErrInvalidMessage.
func (s *BasicSecurityContext) ProvideBroadcast(msg *Envelope, topic string, broadcast []Capability) error {
	switch {
	case !msg.IsBroadcast():
		return fmt.Errorf("not a broadcast message: %w", ErrInvalidMessage)
	case topic != msg.Options.Topic:
		return fmt.Errorf("broadcast topic mismatch: %w", ErrInvalidMessage)
	}

	tokens, err := s.cap.ProvideBroadcast(msg.From.ID, topic, msg.Options.Expire, broadcast)
	if err != nil {
		return fmt.Errorf("providing capabilities: %w", err)
	}
	msg.Capability = tokens

	return s.Sign(msg)
}
// Verify checks that msg has not expired and that its signature matches the
// public key derived from the sender's ID. It returns ErrMessageExpired or
// ErrSignatureVerification for those two failures, and a wrapped error for
// anything else.
func (s *BasicSecurityContext) Verify(msg Envelope) error {
	if msg.Expired() {
		return ErrMessageExpired
	}

	senderKey, err := crypto.PublicKeyFromID(msg.From.ID)
	if err != nil {
		return fmt.Errorf("public key from id: %w", err)
	}

	signed, err := msg.SignatureData()
	if err != nil {
		return fmt.Errorf("signature data: %w", err)
	}

	valid, err := senderKey.Verify(signed, msg.Signature)
	switch {
	case err != nil:
		return fmt.Errorf("verify message signature: %w", err)
	case !valid:
		return ErrSignatureVerification
	}
	return nil
}
// Sign signs msg with the context's private key and stores the signature in
// the envelope. It returns ErrBadSender when the envelope's sender is not
// this context.
func (s *BasicSecurityContext) Sign(msg *Envelope) error {
	if fromSelf := s.id.Equal(msg.From.ID); !fromSelf {
		return ErrBadSender
	}

	toSign, err := msg.SignatureData()
	if err != nil {
		return fmt.Errorf("signature data: %w", err)
	}

	signature, err := s.privk.Sign(toSign)
	if err != nil {
		return fmt.Errorf("signing message: %w", err)
	}

	msg.Signature = signature
	return nil
}
// Discard drops the capability tokens carried by msg from the capability
// context.
func (s *BasicSecurityContext) Discard(msg Envelope) {
	tokens := msg.Capability
	s.cap.Discard(tokens)
}
package jobs
import (
"context"
"errors"
"fmt"
"reflect"
"github.com/google/uuid"
"gitlab.com/nunet/device-management-service/dms"
"gitlab.com/nunet/device-management-service/dms/resources"
"gitlab.com/nunet/device-management-service/executor"
"gitlab.com/nunet/device-management-service/executor/docker"
"gitlab.com/nunet/device-management-service/executor/firecracker"
"gitlab.com/nunet/device-management-service/types"
)
// Status holds the status of an allocation.
type Status struct {
// JobResources are the resources declared by the allocation's job.
JobResources types.ExecutionResources
// Status is the allocation's current execution status.
Status AllocationStatus
}
// AllocationDetails encapsulates the dependencies of the Allocation
// constructor; NewAllocation copies each field onto the new allocation.
type AllocationDetails struct {
Job Job
NodeID string
SourceID string
}
// AllocationStatus is a representation of the execution status.
type AllocationStatus string

const (
	// pending is the status of a freshly created allocation.
	pending AllocationStatus = "pending"
	// running is set once the executor has been started.
	running AllocationStatus = "running"
	// stopped is set once the execution has been cancelled.
	stopped AllocationStatus = "stopped"
)
// Allocation represents an allocation: a job bound to an actor, an executor
// and a resource manager.
type Allocation struct {
ID string
Job Job
// status is the current execution status; it starts as pending.
status AllocationStatus
NodeID string
SourceID string
// executionID identifies this allocation's execution to the executor.
executionID string
actor *dms.BasicActor
// executor is created lazily by Run from the job's execution spec.
executor executor.Executor
resourceManager resources.Manager
}
// NewAllocation creates a pending allocation for the given actor and job
// details. It generates fresh UUIDs for the allocation and its execution and
// fails when the resource manager is nil or a UUID cannot be generated.
func NewAllocation(actor *dms.BasicActor, details AllocationDetails, resourceManager resources.Manager) (*Allocation, error) {
	if resourceManager == nil {
		return nil, errors.New("resource manager is nil")
	}

	allocationID, err := uuid.NewUUID()
	if err != nil {
		return nil, fmt.Errorf("failed to generate uuid for allocation: %w", err)
	}

	executionID, err := uuid.NewUUID()
	if err != nil {
		return nil, fmt.Errorf("failed to create executor id: %w", err)
	}

	allocation := &Allocation{
		ID:              allocationID.String(),
		Job:             details.Job,
		status:          pending,
		NodeID:          details.NodeID,
		SourceID:        details.SourceID,
		executionID:     executionID.String(),
		actor:           actor,
		resourceManager: resourceManager,
	}
	return allocation, nil
}
// Run starts the allocation's job. It refreshes the free resources and
// checks that the job fits, creates the executor from the job's execution
// spec if none exists yet, starts the execution, refreshes the free
// resources again and finally marks the allocation as running.
func (a *Allocation) Run(ctx context.Context) error {
	free, err := a.resourceManager.UpdateFreeResources(ctx)
	if err != nil {
		return fmt.Errorf("failed to get free resources: %w", err)
	}
	if !availableResources(a.Job.Resources, free) {
		return fmt.Errorf("no available resources for job %s", a.Job.ID)
	}

	// if executor is nil create it
	if a.executor == nil {
		if err := a.createExecutor(ctx, a.Job.Execution); err != nil {
			return fmt.Errorf("failed to create executor: %w", err)
		}
	}

	request := &types.ExecutionRequest{
		JobID:       a.Job.ID,
		ExecutionID: a.executionID,
		EngineSpec:  &a.Job.Execution,
		Resources:   &a.Job.Resources,
		// TODO add the following
		Inputs:     []*types.StorageVolumeExecutor{},
		Outputs:    []*types.StorageVolumeExecutor{},
		ResultsDir: "",
	}
	if err := a.executor.Start(ctx, request); err != nil {
		return fmt.Errorf("failed to start executor: %w", err)
	}

	if _, err := a.resourceManager.UpdateFreeResources(ctx); err != nil {
		return fmt.Errorf("failed to update resources after running allocation's executor: %w", err)
	}

	a.status = running
	return nil
}
// Stop cancels the running execution and marks the allocation as stopped.
// It fails if the allocation is not running. The free resources are
// refreshed afterwards; a failure there is reported even though the
// allocation has already been marked stopped.
func (a *Allocation) Stop(ctx context.Context) error {
	if a.status != running {
		return errors.New("allocation is not running")
	}

	if err := a.executor.Cancel(ctx, a.executionID); err != nil {
		return fmt.Errorf("failed to stop execution: %w", err)
	}
	a.status = stopped

	if _, err := a.resourceManager.UpdateFreeResources(ctx); err != nil {
		return fmt.Errorf("failed to update resources after stoping allocation's executor: %w", err)
	}
	return nil
}
// Status returns the job's declared resources together with the allocation's
// current execution status.
func (a *Allocation) Status(_ context.Context) Status {
	var current Status
	current.JobResources = a.Job.Resources
	current.Status = a.status
	return current
}
// StartActor starts the allocation's actor.
func (a *Allocation) StartActor() error {
	if err := a.actor.Start(); err != nil {
		return fmt.Errorf("failed to start allocation actor: %w", err)
	}
	return nil
}
// ProcessMessages dispatches every message received from the actor to the
// handler method named after the message type. It returns once the actor's
// message channel is closed.
func (a *Allocation) ProcessMessages() {
	inbox := a.actor.Messages()
	for msg := range inbox {
		a.dispatchMethod(msg.Type, msg.Data)
	}
}
// SendMessage forwards a message to the destination through the
// allocation's actor.
func (a *Allocation) SendMessage(ctx context.Context, destination *dms.ActorAddrInfo, m *dms.Message) error {
	err := a.actor.SendMessage(ctx, destination, m)
	return err
}
// dispatchMethod invokes the method "Handle<methodName>" via reflection,
// passing args as its arguments. The allocation itself is searched first and
// then its actor; if neither has such a method the call is silently dropped.
func (a *Allocation) dispatchMethod(methodName string, args ...any) {
	handlerName := fmt.Sprintf("Handle%s", methodName)

	callArgs := make([]reflect.Value, len(args))
	for i, arg := range args {
		callArgs[i] = reflect.ValueOf(arg)
	}

	// The allocation takes precedence over the actor.
	for _, receiver := range []any{a, a.actor} {
		handler := reflect.ValueOf(receiver).MethodByName(handlerName)
		if handler.IsValid() {
			handler.Call(callArgs)
			return
		}
	}
}
// createExecutor instantiates the executor matching the execution engine
// type in conf and stores it on the allocation.
//
// It returns an error when the executor cannot be created or when the engine
// type is not supported. Previously an unsupported type returned nil and
// left a.executor unset, which made Run dereference a nil executor.
func (a *Allocation) createExecutor(ctx context.Context, conf types.SpecConfig) error {
	switch conf.Type {
	case types.ExecutorTypeFirecracker:
		exec, err := firecracker.NewExecutor(ctx, a.executionID)
		if err != nil {
			return fmt.Errorf("firecracker executor: %w", err)
		}
		a.executor = exec
	case types.ExecutorTypeDocker:
		exec, err := docker.NewExecutor(ctx, a.executionID)
		if err != nil {
			return fmt.Errorf("docker executor: %w", err)
		}
		a.executor = exec
	default:
		return fmt.Errorf("unsupported executor type: %v", conf.Type)
	}
	return nil
}
// availableResources reports whether the free resources cover the job's
// requirements: RAM and disk must be at least the requested size, while the
// free CPU must be strictly greater than the requested cores.
//
// TODO: ExecutionResources and FreeResources should be compatible
func availableResources(jobResources types.ExecutionResources, fr types.FreeResources) bool {
	enoughRAM := fr.RAM >= uint64(jobResources.Memory.Size)
	enoughDisk := fr.Disk >= jobResources.Disk.Size
	enoughCPU := fr.CPU > float64(jobResources.CPU.Cores)
	return enoughRAM && enoughDisk && enoughCPU
}
package parser
import (
"gitlab.com/nunet/device-management-service/dms/jobs"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/nunet"
)
// registry holds the job spec parsers known to this package.
var registry Registry[jobs.JobSpec]

// init creates the package registry and registers the built-in NuNet parser.
func init() {
	parsers := make(map[SpecType]Parser[jobs.JobSpec])
	registry = &RegistryImpl[jobs.JobSpec]{parsers: parsers}

	// Register Nunet parser.
	transformer := nunet.NewNuNetTransformer()
	validator := nunet.NewNuNetValidator()
	registry.RegisterParser(specTypeNuNet, NewParser[jobs.JobSpec](transformer, validator))
}
package nunet
import (
"fmt"
"strconv"
"strings"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/transform"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/tree"
)
// NewNuNetTransformer creates the transformer for the NuNet configuration.
// Its rules are grouped into three stages, passed to the transformer in
// order: map-to-slice conversions, per-element normalization, and nested
// configuration blocks.
func NewNuNetTransformer() transform.Transformer {
	type stage = map[tree.Path]transform.TransformerFunc

	collections := stage{
		"jobs":             TransformJobs,
		"jobs.**.children": TransformJobs,
		"jobs.**.volumes":  TransformVolumes,
		"jobs.**.networks": TransformNetworks,
	}
	elements := stage{
		"jobs.**.volumes.[]":   TransformVolume,
		"jobs.**.networks.[]":  TransformNetwork,
		"jobs.**.libraries.[]": TransformLibrary,
	}
	nested := stage{
		"jobs.**.execution":         TransformExecution,
		"jobs.**.volumes.[].remote": TransformVolumeRemote,
	}

	return transform.NewTransformer([]stage{collections, elements, nested})
}
// TransformJobs transforms the jobs map to a slice and assigns the keys to
// the "name" field. Nil data is passed through; any other non-map value is
// an error.
func TransformJobs(_ *map[string]interface{}, data any, _ tree.Path) (any, error) {
	if data == nil {
		return nil, nil
	}
	if byName, isMap := data.(map[string]any); isMap {
		return transform.MapToSlice(byName)
	}
	return nil, fmt.Errorf("invalid jobs configuration: %v", data)
}
// TransformVolumes transforms the volumes map to a slice and assigns the
// keys to the "name" field. Nil data is passed through; any other non-map
// value is an error.
func TransformVolumes(_ *map[string]interface{}, data any, _ tree.Path) (any, error) {
	if data == nil {
		return nil, nil
	}
	if byName, isMap := data.(map[string]any); isMap {
		return transform.MapToSlice(byName)
	}
	return nil, fmt.Errorf("invalid volumes configuration: %v", data)
}
// TransformNetworks transforms the networks map to a slice and assigns the
// keys to the "name" field. Nil data is passed through; any other non-map
// value is an error.
func TransformNetworks(_ *map[string]interface{}, data any, _ tree.Path) (any, error) {
	if data == nil {
		return nil, nil
	}
	if byName, isMap := data.(map[string]any); isMap {
		return transform.MapToSlice(byName)
	}
	return nil, fmt.Errorf("invalid networks configuration: %v", data)
}
// TransformExecution transforms the engine configuration from a flat map to
// the SpecConfig format: the "type" key is kept at the top level and every
// other key is moved under "params". Nil data is passed through.
func TransformExecution(_ *map[string]interface{}, data any, _ tree.Path) (any, error) {
	if data == nil {
		return nil, nil
	}

	engine, isMap := data.(map[string]any)
	if !isMap {
		return nil, fmt.Errorf("invalid engine configuration: %v", data)
	}

	params := make(map[string]any, len(engine))
	for key, value := range engine {
		if key == "type" {
			continue
		}
		params[key] = value
	}

	return map[string]any{
		"type":   engine["type"],
		"params": params,
	}, nil
}
// TransformVolume transforms the volume configuration and handles inheritance.
// The volume configuration can be a string in the format "name:mountpoint" or a map.
// If the volume is defined in the parent volumes, the configurations are merged.
//
// When merging, keys present in a parent's volume definition overwrite the
// same keys in this volume's configuration. Parents whose volumes cannot be
// read are reported on standard output and skipped.
func TransformVolume(root *map[string]interface{}, data any, path tree.Path) (any, error) {
var config map[string]any
// If the data is a string, split it into name and mountpoint.
switch v := data.(type) {
case string:
mapping := strings.Split(v, ":")
if len(mapping) != 2 {
return nil, fmt.Errorf("invalid volume configuration: %v", data)
}
config = map[string]any{
"name": mapping[0],
"mountpoint": mapping[1],
}
case map[string]any:
config = v
default:
return nil, fmt.Errorf("invalid volume configuration: %v", data)
}
// Collect all potential parent paths where the volume could be defined.
// Every "children" segment in the path marks an enclosing job; the path up
// to (but excluding) that segment is the enclosing job's path.
parentPaths := []tree.Path{}
pathParts := path.Parts()
for i, part := range pathParts {
if part == "children" {
parentPaths = append(parentPaths, tree.NewPath(pathParts[:i]...))
}
}
// Merge the volume configuration with the parent configurations.
for _, parent := range parentPaths {
// Check if the volume exists in the parent
c, err := transform.GetConfigAtPath(*root, parent.Next("volumes"))
if err != nil {
fmt.Println("error: ", err)
continue
}
// A conversion failure is ignored here; the loop below then has nothing
// to iterate over.
volumes, _ := transform.ToAnySlice(c)
for _, v := range volumes {
if volume, ok := v.(map[string]any); ok && volume["name"] == config["name"] {
// Merge the configurations
for k, v := range volume {
config[k] = v
}
}
}
}
return config, nil
}
// TransformVolumeRemote normalizes a volume's remote configuration to the
// {"type": ..., "params": {...}} shape.
//
// If the configuration already has a "params" key it is used as is and must
// be a map; otherwise every key except "type" is moved under "params". Nil
// data is passed through. A non-map configuration or a non-map "params"
// value yields an error (the latter used to panic on the type assertion).
func TransformVolumeRemote(_ *map[string]interface{}, data any, _ tree.Path) (any, error) {
	if data == nil {
		return nil, nil
	}
	config, ok := data.(map[string]any)
	if !ok {
		return nil, fmt.Errorf("invalid volume configuration: %v", data)
	}

	remoteConfig := map[string]any{
		"type": config["type"],
	}

	// Explicit params take precedence over the flat form.
	if raw, found := config["params"]; found {
		params, isMap := raw.(map[string]any)
		if !isMap {
			return nil, fmt.Errorf("invalid volume remote params: %v", raw)
		}
		remoteConfig["params"] = params
		return remoteConfig, nil
	}

	params := make(map[string]any, len(config))
	for k, v := range config {
		if k != "type" {
			params[k] = v
		}
	}
	remoteConfig["params"] = params
	return remoteConfig, nil
}
// TransformNetwork transforms the network configuration by replacing its
// "ports" list with a normalized "port_map" list of
// {"protocol", "host_port", "container_port"} entries. The configuration map
// is modified in place and returned. Nil data is passed through.
//
// Each port entry may be:
//   - a string "port", "host:container" or "protocol:host:container";
//   - an int, used for both the host and the container port;
//   - a map with "host_port", "container_port" (int or string) and an
//     optional "protocol".
//
// The protocol defaults to "tcp". Values that cannot be parsed, strings with
// more than three parts, and entries of any other type leave the
// corresponding ports at 0 rather than producing an error.
func TransformNetwork(_ *map[string]interface{}, data any, _ tree.Path) (any, error) {
if data == nil {
return nil, nil
}
config, ok := data.(map[string]any)
if !ok {
return nil, fmt.Errorf("invalid network configuration: %v", data)
}
// A missing or non-list "ports" value yields an empty port map.
ports, _ := transform.ToAnySlice(config["ports"])
portMap := []map[string]any{}
for _, port := range ports {
protocol, host, container := "tcp", 0, 0
switch v := port.(type) {
case string:
parts := strings.Split(v, ":")
if len(parts) <= 2 {
// "port" maps to itself; "host:container" maps first to last.
host, _ = strconv.Atoi(parts[0])
container, _ = strconv.Atoi(parts[len(parts)-1])
} else if len(parts) == 3 {
protocol = parts[0]
host, _ = strconv.Atoi(parts[1])
container, _ = strconv.Atoi(parts[len(parts)-1])
}
case int:
host = v
container = v
case map[string]any:
switch h := v["host_port"].(type) {
case int:
host = h
case string:
host, _ = strconv.Atoi(h)
}
switch c := v["container_port"].(type) {
case int:
container = c
case string:
container, _ = strconv.Atoi(c)
}
if p, ok := v["protocol"].(string); ok {
protocol = p
}
}
portMap = append(portMap, map[string]any{
"protocol": protocol,
"host_port": host,
"container_port": container,
})
}
config["port_map"] = portMap
delete(config, "ports")
return config, nil
}
// TransformLibrary transforms the library configuration to a map.
// The library configuration can be a string in the format "name:version"
// (the version may be omitted, giving an empty version) or a map, which is
// returned unchanged. Nil data is passed through.
func TransformLibrary(_ *map[string]interface{}, data any, _ tree.Path) (any, error) {
	if data == nil {
		return nil, nil
	}

	if asMap, isMap := data.(map[string]any); isMap {
		return asMap, nil
	}

	spec, isString := data.(string)
	if !isString {
		return nil, fmt.Errorf("invalid library configuration: %v", data)
	}

	var name, version string
	switch fields := strings.Split(spec, ":"); len(fields) {
	case 1:
		name = fields[0]
	case 2:
		name, version = fields[0], fields[1]
	default:
		return nil, fmt.Errorf("invalid library configuration: %v", data)
	}

	return map[string]any{
		"name":    name,
		"version": version,
	}, nil
}
package nunet
import (
"fmt"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/tree"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/validate"
)
// NewNuNetValidator creates a new validator for the NuNet configuration. The
// root is checked with ValidateSpec, and every job and child job with
// ValidateJob.
func NewNuNetValidator() validate.Validator {
	rules := map[tree.Path]validate.ValidatorFunc{
		"":                    ValidateSpec,
		"jobs.[]":             ValidateJob,
		"jobs.**.children.[]": ValidateJob,
	}
	return validate.NewValidator(rules)
}
// ValidateSpec checks the root configuration for consistency: it must be a
// map with a non-empty "jobs" list.
//
// A "jobs" value that is not a list is reported as an error; the previous
// unchecked type assertion panicked on such input.
func ValidateSpec(_ *map[string]any, data any, _ tree.Path) error {
	spec, ok := data.(map[string]any)
	if !ok {
		return fmt.Errorf("invalid spec configuration: %v", data)
	}

	// Check if the jobs list is present and not empty.
	rawJobs := spec["jobs"]
	if rawJobs == nil {
		return fmt.Errorf("jobs list is required")
	}
	jobs, ok := rawJobs.([]any)
	if !ok {
		return fmt.Errorf("invalid jobs configuration: %v", rawJobs)
	}
	if len(jobs) == 0 {
		return fmt.Errorf("jobs list is required")
	}
	return nil
}
// ValidateJob checks the job configuration: a job must be a map and must
// have either a non-empty "children" list or an "execution".
//
// A "children" value that is not a list is reported as an error; the
// previous unchecked type assertion panicked on such input.
func ValidateJob(_ *map[string]any, data any, _ tree.Path) error {
	job, ok := data.(map[string]any)
	if !ok {
		return fmt.Errorf("invalid job configuration: %v", data)
	}

	var children []any
	if rawChildren := job["children"]; rawChildren != nil {
		list, isList := rawChildren.([]any)
		if !isList {
			return fmt.Errorf("invalid children configuration: %v", rawChildren)
		}
		children = list
	}

	// Check if the job has either children or an execution.
	if len(children) == 0 && job["execution"] == nil {
		return fmt.Errorf("job must have either children or an execution")
	}
	return nil
}
package parser
import (
"fmt"
"gitlab.com/nunet/device-management-service/dms/jobs"
)
// Parse decodes data into a job spec using the parser registered for
// specType. It fails when no parser is registered for that type or when the
// parser itself reports an error.
func Parse(specType SpecType, data []byte) (jobs.JobSpec, error) {
	parser, exists := registry.GetParser(specType)
	if !exists {
		return jobs.JobSpec{}, fmt.Errorf("parser for spec type %s not found", specType)
	}
	return parser.Parse(data)
}
package parser
import (
"encoding/json"
"fmt"
"github.com/mitchellh/mapstructure"
yaml "gopkg.in/yaml.v3"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/transform"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/validate"
)
// SpecType identifies the format of a job specification; it is the key under
// which parsers are registered and looked up.
type SpecType string
// Spec type identifiers declared by this package.
const (
specTypeNuNet SpecType = "nunet"
specTypeNomad SpecType = "nomad"
specTypeK8s SpecType = "k8s"
)
// DefaultTagName is the struct tag name given to the mapstructure decoder
// when decoding a transformed configuration.
const DefaultTagName = "json"
// Parser turns raw specification bytes into a value of type T.
type Parser[T any] interface {
// Parse decodes data into a T, or returns an error.
Parse(data []byte) (T, error)
}
// Impl is the Parser implementation that combines a transformer and a
// validator.
type Impl[T any] struct {
// validator checks the configuration during Parse.
validator validate.Validator
// transformer rewrites the raw configuration before it is decoded.
transformer transform.Transformer
}
// NewParser returns a Parser for T that uses the given transformer and
// validator.
func NewParser[T any](transformer transform.Transformer, validator validate.Validator) Parser[T] {
	var p Impl[T]
	p.validator = validator
	p.transformer = transformer
	return p
}
// Parse decodes data into a value of type T.
//
// The input is unmarshalled as YAML and, if that fails, as JSON. The raw map
// is then passed through the transformer, validated, and finally decoded into
// T with mapstructure using DefaultTagName as the struct tag.
//
// Underlying errors are wrapped with %w so callers can inspect them with
// errors.Is / errors.As. Validation errors are returned unwrapped. On error
// the returned T is the zero value.
func (p Impl[T]) Parse(data []byte) (T, error) {
	var rawConfig map[string]any
	var config T

	// Try to unmarshal as YAML first, falling back to JSON.
	if err := yaml.Unmarshal(data, &rawConfig); err != nil {
		if err = json.Unmarshal(data, &rawConfig); err != nil {
			return config, fmt.Errorf("failed to parse config: %w", err)
		}
	}

	// Apply transformers.
	transformed, err := p.transformer.Transform(&rawConfig)
	if err != nil {
		return config, fmt.Errorf("failed to transform config: %w", err)
	}

	// Validate the configuration.
	// NOTE(review): the validator receives rawConfig rather than the value
	// returned by Transform — confirm this is intended.
	if err = p.validator.Validate(&rawConfig); err != nil {
		return config, err
	}

	// Create a new mapstructure decoder targeting config.
	decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
		Result:  &config,
		TagName: DefaultTagName,
	})
	if err != nil {
		return config, fmt.Errorf("failed to create decoder: %w", err)
	}

	// Decode the transformed configuration.
	if err = decoder.Decode(transformed); err != nil {
		return config, fmt.Errorf("failed to decode config: %w", err)
	}
	return config, nil
}
package parser
import (
"sync"
)
// Registry stores parsers keyed by spec type.
type Registry[T any] interface {
// GetParser returns the parser registered for specType and whether one exists.
GetParser(specType SpecType) (Parser[T], bool)
// RegisterParser associates p with specType.
RegisterParser(specType SpecType, p Parser[T])
}
// RegistryImpl is a Registry backed by a map guarded by a read/write mutex.
type RegistryImpl[T any] struct {
// parsers maps a spec type to its parser.
parsers map[SpecType]Parser[T]
// mu guards parsers.
mu sync.RWMutex
}
// RegisterParser associates p with specType, replacing any parser previously
// registered under the same type. The underlying map is created on first use,
// so a zero-value RegistryImpl can be registered into without panicking.
func (r *RegistryImpl[T]) RegisterParser(specType SpecType, p Parser[T]) {
	r.mu.Lock()
	defer r.mu.Unlock()
	if r.parsers == nil {
		// Writing to a nil map panics; initialise lazily under the lock.
		r.parsers = make(map[SpecType]Parser[T])
	}
	r.parsers[specType] = p
}
// GetParser looks up the parser registered for specType under a read lock.
// The boolean result reports whether a parser was found.
func (r *RegistryImpl[T]) GetParser(specType SpecType) (Parser[T], bool) {
	r.mu.RLock()
	found, ok := r.parsers[specType]
	r.mu.RUnlock()
	return found, ok
}
package transform
import (
"fmt"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/tree"
)
// TransformerFunc is a function that transforms a part of the configuration.
// It modifies the data to conform to the expected structure and returns the transformed data.
// Its arguments are, in order: a pointer to the root configuration, the data
// to transform, and the current path in the tree.
type TransformerFunc func(*map[string]interface{}, interface{}, tree.Path) (any, error)
// Transformer is a configuration transformer.
type Transformer interface {
// Transform takes a pointer to the root configuration and returns the
// transformed configuration, or an error.
Transform(*map[string]interface{}) (interface{}, error)
}
// TransformerImpl is the implementation of the Transformer interface.
type TransformerImpl struct {
// transformers is an ordered list of passes; each pass maps path patterns
// to the function applied at matching paths.
transformers []map[tree.Path]TransformerFunc
}
// NewTransformer returns a Transformer that applies the given passes in
// slice order.
func NewTransformer(transformers []map[tree.Path]TransformerFunc) Transformer {
	var impl TransformerImpl
	impl.transformers = transformers
	return impl
}
// Transform runs every transformer pass over the configuration, feeding the
// output of one pass into the next, and returns the normalized result.
// The first error from any pass aborts the run.
func (t TransformerImpl) Transform(rawConfig *map[string]interface{}) (interface{}, error) {
	var current any = *rawConfig
	for _, pass := range t.transformers {
		next, err := t.transform(rawConfig, current, tree.NewPath(), pass)
		if err != nil {
			return nil, err
		}
		current = next
	}
	return Normalize(current), nil
}
// transform applies every transformer whose pattern matches path to data and
// then descends into the result: map values are visited by key, slice
// elements by "[i]" index. Children are replaced in place with their
// transformed values.
func (t TransformerImpl) transform(root *map[string]interface{}, data any, path tree.Path, transformers map[tree.Path]TransformerFunc) (interface{}, error) {
	var err error

	// Run the transformers registered for this path.
	for pattern, apply := range transformers {
		if !path.Matches(pattern) {
			continue
		}
		if data, err = apply(root, data, path); err != nil {
			return nil, err
		}
	}

	// Descend into map values.
	if object, isMap := data.(map[string]interface{}); isMap {
		for key, child := range object {
			object[key], err = t.transform(root, child, path.Next(key), transformers)
			if err != nil {
				return nil, err
			}
		}
		return object, nil
	}

	// Descend into slice elements; anything that is not a slice is a leaf.
	list, sliceErr := ToAnySlice(data)
	if sliceErr != nil {
		return data, nil
	}
	for idx, child := range list {
		list[idx], err = t.transform(root, child, path.Next(fmt.Sprintf("[%d]", idx)), transformers)
		if err != nil {
			return nil, err
		}
	}
	return list, nil
}
package transform
import (
"fmt"
"reflect"
"sort"
"strconv"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/tree"
)
// MapToSlice converts a map of maps to a slice and assigns each key to the
// "name" field of its entry.
//
// Entries are emitted in ascending key order so the result is deterministic
// (Go map iteration order is random). A nil value is replaced by a new map
// that holds only the "name" field; values that are not map[string]any are
// appended unchanged. Map values are modified in place. A nil input yields a
// nil slice; the returned error is always nil.
func MapToSlice(data map[string]any) ([]any, error) {
	if data == nil {
		return nil, nil
	}
	keys := make([]string, 0, len(data))
	for k := range data {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	result := make([]any, 0, len(keys))
	for _, k := range keys {
		v := data[k]
		if v == nil {
			v = map[string]any{}
		}
		if e, ok := v.(map[string]any); ok {
			e["name"] = k
		}
		result = append(result, v)
	}
	return result, nil
}
// GetConfigAtPath retrieves the part of the configuration found at path.
//
// Map segments are looked up by key (a missing key yields nil). List segments
// must have the form "[i]" with 0 <= i < len(list). The empty (root) path
// returns the configuration itself. An error is returned for malformed or
// out-of-range indices, and when the path continues through a value that is
// neither a map nor a list.
func GetConfigAtPath(config map[string]interface{}, path tree.Path) (any, error) {
	current := any(config)
	// The root path addresses the whole configuration.
	if path == "" {
		return current, nil
	}
	for _, key := range path.Parts() {
		switch v := current.(type) {
		case map[string]any:
			current = v[key]
		case []any:
			i, err := parseListIndex(key, len(v))
			if err != nil {
				return nil, err
			}
			current = v[i]
		case []map[string]any:
			i, err := parseListIndex(key, len(v))
			if err != nil {
				return nil, err
			}
			current = v[i]
		default:
			return nil, fmt.Errorf("invalid data type: %v", current)
		}
	}
	return current, nil
}

// parseListIndex extracts i from a "[i]" path segment and checks that it is a
// valid index into a list of the given length.
func parseListIndex(key string, length int) (int, error) {
	if len(key) < 2 || key[0] != '[' || key[len(key)-1] != ']' {
		return 0, fmt.Errorf("invalid index: %v", key)
	}
	i, err := strconv.Atoi(key[1 : len(key)-1])
	if err != nil || i < 0 || i >= length {
		return 0, fmt.Errorf("invalid index: %v", key)
	}
	return i, nil
}
// ToAnySlice copies the elements of an arbitrary slice value into a new
// []any. It returns an error when the argument is not a slice.
func ToAnySlice(slice any) ([]any, error) {
	rv := reflect.ValueOf(slice)
	if kind := rv.Kind(); kind != reflect.Slice {
		return nil, fmt.Errorf("input is not a slice. type: %T", slice)
	}
	n := rv.Len()
	converted := make([]any, 0, n)
	for idx := 0; idx < n; idx++ {
		converted = append(converted, rv.Index(idx).Interface())
	}
	return converted, nil
}
// normalizeMap recursively rebuilds maps and slices into generic containers.
//
// Maps become maps with the same key type and interface{} values; an entry
// whose normalized value is nil is omitted, because reflect's SetMapIndex
// treats a zero Value as a delete. Slices become []interface{} and are
// reordered by the fmt.Sprint representation of their elements. Any other
// value is returned unchanged.
func normalizeMap(m interface{}) interface{} {
	rv := reflect.ValueOf(m)

	if rv.Kind() == reflect.Map {
		anyType := reflect.TypeOf((*interface{})(nil)).Elem()
		rebuilt := reflect.MakeMap(reflect.MapOf(rv.Type().Key(), anyType))
		for it := rv.MapRange(); it.Next(); {
			normalized := normalizeMap(it.Value().Interface())
			rebuilt.SetMapIndex(it.Key(), reflect.ValueOf(normalized))
		}
		return rebuilt.Interface()
	}

	if rv.Kind() == reflect.Slice {
		count := rv.Len()
		items := make([]interface{}, 0, count)
		for idx := 0; idx < count; idx++ {
			items = append(items, normalizeMap(rv.Index(idx).Interface()))
		}
		// Order elements by their printed form.
		sort.Slice(items, func(a, b int) bool {
			return fmt.Sprint(items[a]) < fmt.Sprint(items[b])
		})
		return items
	}

	// Scalars and everything else pass through untouched.
	return m
}
// Normalize is the exported entry point for normalization; it delegates to
// normalizeMap.
func Normalize(m any) interface{} {
	normalized := normalizeMap(m)
	return normalized
}
package tree
import (
"strings"
)
// Tokens used when building and matching configuration paths.
const (
// configPathSeparator joins the parts of a path.
configPathSeparator = "."
// configPathMatchAny is the pattern part that matches any single path part.
configPathMatchAny = "*"
// configPathMatchAnyMultiple is the pattern part that may span several path parts.
configPathMatchAnyMultiple = "**"
// configPathList is the pattern part that matches a part enclosed in "[" and "]".
configPathList = "[]"
)
// Path is a custom type for representing paths in the configuration.
// Parts are joined with configPathSeparator; the empty string is the path
// produced by NewPath with no arguments.
type Path string
// NewPath builds a Path by joining the given parts with the path separator.
// With no arguments it returns the empty path.
func NewPath(path ...string) Path {
	joined := strings.Join(path, configPathSeparator)
	return Path(joined)
}
// Parts splits the path on the path separator. Because strings.Split never
// returns an empty slice for a non-empty separator, the empty path yields a
// single empty part.
func (p Path) Parts() []string {
	raw := string(p)
	return strings.Split(raw, configPathSeparator)
}
// Parent returns the path with its last part removed, or the empty path when
// the path has a single part.
func (p Path) Parent() Path {
	segments := p.Parts()
	if len(segments) <= 1 {
		return ""
	}
	head := segments[:len(segments)-1]
	return Path(strings.Join(head, configPathSeparator))
}
// Next returns the path extended by one more part. An empty part leaves the
// path unchanged; extending the empty path yields just the new part.
func (p Path) Next(path string) Path {
	switch {
	case path == "":
		return p
	case p == "":
		return Path(path)
	default:
		return Path(string(p) + configPathSeparator + path)
	}
}
// Last returns the final part of the path.
func (p Path) Last() string {
	segments := p.Parts()
	if n := len(segments); n > 0 {
		return segments[n-1]
	}
	return ""
}
// Matches reports whether the path matches pattern, comparing the two part
// by part via matchParts.
func (p Path) Matches(pattern Path) bool {
	return matchParts(p.Parts(), pattern.Parts())
}
// String returns the path as a plain string.
func (p Path) String() string {
	s := string(p)
	return s
}
// matchParts checks if the path parts match the pattern parts.
//
// Pattern parts are compared positionally with the path parts:
//   - "**" as the last pattern part matches the rest of the path; elsewhere
//     the remaining pattern is tried against every suffix of the path
//     starting at the current position.
//   - "[]" matches a part enclosed in "[" and "]".
//   - "*" matches any single part; any other part must be equal.
//
// A pattern longer than the path never matches, and neither does a pattern
// that is exhausted before the path is.
func matchParts(pathParts, patternParts []string) bool {
	// If the pattern is longer than the path, it can't match.
	if len(pathParts) < len(patternParts) {
		return false
	}
	for i, part := range patternParts {
		isLast := i == len(patternParts)-1
		switch part {
		case configPathMatchAnyMultiple:
			// If it is the last part of the pattern, it matches.
			if isLast {
				return true
			}
			// Otherwise, try to match the rest of the pattern against each
			// remaining suffix of the path.
			for j := i; j < len(pathParts); j++ {
				if matchParts(pathParts[j:], patternParts[i+1:]) {
					return true
				}
			}
			// No suffix matched. Continuing positionally would only repeat
			// the j == i+1 attempt, so the result is a definite mismatch.
			return false
		case configPathList:
			// Check that the path part is enclosed by []. The length guard
			// keeps an empty part from causing an index-out-of-range panic.
			seg := pathParts[i]
			if len(seg) < 2 || seg[0] != '[' || seg[len(seg)-1] != ']' {
				return false
			}
		default:
			// A literal part must be equal unless the pattern part is "*".
			if part != configPathMatchAny && part != pathParts[i] {
				return false
			}
		}
		// If it is the last part of the pattern and the path is longer, it doesn't match.
		if isLast && i < len(pathParts)-1 {
			return false
		}
	}
	return true
}
package validate
import (
"fmt"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/tree"
)
// ValidatorFunc is a function that validates a part of the configuration.
// Its arguments are, in order: a pointer to the root configuration, the data
// to validate, and the current path in the tree. A non-nil error reports a
// validation failure.
type ValidatorFunc func(*map[string]any, any, tree.Path) error
// Validator is a configuration validator.
type Validator interface {
// Validate checks the configuration referenced by the pointer and returns
// an error if it is invalid.
Validate(*map[string]any) error
}
// ValidatorImpl is the implementation of the Validator interface.
type ValidatorImpl struct {
// validators maps a path pattern to the function applied at matching paths.
validators map[tree.Path]ValidatorFunc
}
// NewValidator returns a Validator that applies the given pattern-to-function
// map.
func NewValidator(validators map[tree.Path]ValidatorFunc) Validator {
	var impl ValidatorImpl
	impl.validators = validators
	return impl
}
// Validate walks the whole configuration starting at the root path and
// returns the first validation error encountered, if any.
func (v ValidatorImpl) Validate(rawConfig *map[string]any) error {
	var rootData any = *rawConfig
	rootPath := tree.NewPath()
	return v.validate(rawConfig, rootData, rootPath, v.validators)
}
// validate runs every validator whose pattern matches path against data and
// then descends into map values (by key) and []interface{} elements (by
// "[i]" index). The first error stops the walk.
func (v ValidatorImpl) validate(root *map[string]interface{}, data any, path tree.Path, validators map[tree.Path]ValidatorFunc) error {
	// Run the validators registered for this path.
	for pattern, check := range validators {
		if !path.Matches(pattern) {
			continue
		}
		if err := check(root, data, path); err != nil {
			return err
		}
	}

	// Descend into children.
	if object, isMap := data.(map[string]interface{}); isMap {
		for key, child := range object {
			if err := v.validate(root, child, path.Next(key), validators); err != nil {
				return err
			}
		}
		return nil
	}
	if list, isList := data.([]interface{}); isList {
		for idx, child := range list {
			if err := v.validate(root, child, path.Next(fmt.Sprintf("[%d]", idx)), validators); err != nil {
				return err
			}
		}
	}
	return nil
}
package node
import (
"context"
"errors"
"fmt"
"reflect"
"sync"
"time"
"gitlab.com/nunet/device-management-service/dms"
"gitlab.com/nunet/device-management-service/dms/jobs"
"gitlab.com/nunet/device-management-service/dms/resources"
"gitlab.com/nunet/device-management-service/network"
"gitlab.com/nunet/device-management-service/types"
)
// benchmarker is the dependency the node uses to benchmark its capability.
// TODO: remove after resource manager MR is merged
type benchmarker interface {
// Benchmark returns a capability, or an error.
Benchmark(ctx context.Context) (*types.Capability, error)
}
// Node is the structure that holds the node's dependencies.
type Node struct {
// ID is the identifier the node was created with.
ID string
// actor is the node's root actor, set by createNodeActor.
actor *dms.BasicActor
network network.Network
// actorFactory creates the root actor.
actorFactory *dms.ActorFactory
// TODO: fix when resource manager is merged to develop
resourceManager resources.Manager
benchmark benchmarker
// allocations holds the allocations created by this node, keyed by allocation ID.
allocations map[string]*jobs.Allocation
// mu guards allocations.
mu sync.RWMutex
}
// New creates a node and attaches a started root actor to it.
// It returns an error when id is empty, when any dependency is nil, or when
// the root actor cannot be created or started. The context is unused.
func New(_ context.Context, id string, net network.Network, benchmark benchmarker, resourceManager resources.Manager) (*Node, error) {
	// Reject missing dependencies up front.
	switch {
	case id == "":
		return nil, errors.New("id is nil")
	case net == nil:
		return nil, errors.New("network is nil")
	case benchmark == nil:
		return nil, errors.New("benchmarker is nil")
	case resourceManager == nil:
		return nil, errors.New("resource manager is nil")
	}

	params := &dms.ActorParams{
		HeartbeatInterval:      time.Second * 5,
		HeartbeatCheckInterval: time.Second * 8,
	}
	node := &Node{
		ID:              id,
		actorFactory:    dms.NewActorFactory(id, net, params),
		network:         net,
		allocations:     make(map[string]*jobs.Allocation),
		benchmark:       benchmark,
		resourceManager: resourceManager,
	}

	if err := node.createNodeActor(); err != nil {
		return nil, fmt.Errorf("failed to create node actor: %w", err)
	}
	return node, nil
}
// GetAllocation looks up an allocation by id under a read lock and returns an
// error when no allocation with that id exists.
func (n *Node) GetAllocation(id string) (*jobs.Allocation, error) {
	n.mu.RLock()
	defer n.mu.RUnlock()

	if found, exists := n.allocations[id]; exists {
		return found, nil
	}
	return nil, errors.New("allocation not found")
}
// BenchmarkCapability runs the node's benchmarker and returns its result.
func (n *Node) BenchmarkCapability(ctx context.Context) (*types.Capability, error) {
	capability, err := n.benchmark.Benchmark(ctx)
	return capability, err
}
// CreateAllocation creates an allocation for job and registers it on the node.
//
// A child actor is created from the node's root actor and handed to
// jobs.NewAllocation together with the job, the node ID and an empty source
// ID. The allocation is stored under its ID while holding the write lock.
// The context is unused.
func (n *Node) CreateAllocation(_ context.Context, job jobs.Job) (*jobs.Allocation, error) {
	allocationActor, err := n.actor.CreateActor()
	if err != nil {
		return nil, fmt.Errorf("failed to create allocation actor: %w", err)
	}
	allocation, err := jobs.NewAllocation(allocationActor, jobs.AllocationDetails{Job: job, NodeID: n.ID, SourceID: ""}, n.resourceManager)
	if err != nil {
		// This message used to repeat the actor-creation text, which made the
		// two failure modes indistinguishable.
		return nil, fmt.Errorf("failed to create allocation: %w", err)
	}
	n.mu.Lock()
	n.allocations[allocation.ID] = allocation
	n.mu.Unlock()
	return allocation, nil
}
// ProcessMessages drains the root actor's message stream, dispatching each
// message to the handler named after its type. It returns when the range
// over the stream ends.
func (n *Node) ProcessMessages() {
	inbox := n.actor.Messages()
	for message := range inbox {
		n.dispatchMethod(message.Type, message.Data)
	}
}
// SendMessage forwards m to destination through the node's root actor.
func (n *Node) SendMessage(ctx context.Context, destination *dms.ActorAddrInfo, m *dms.Message) error {
	err := n.actor.SendMessage(ctx, destination, m)
	return err
}
// HandleHello handles messages of type `Hello` by printing the payload.
// Handlers are located by name: dispatchMethod prefixes the message type
// with "Handle", so every handler must follow that naming scheme.
func (n *Node) HandleHello(payload []byte) {
	sender := string(payload)
	fmt.Println("hello from: ", sender)
}
// dispatchMethod invokes the handler named "Handle<methodName>" with args.
//
// The handler is looked up on the node first and, if the node has no such
// method, on the node's actor. The call is made only when the arguments are
// compatible with the handler's signature; an incompatible message is dropped
// instead of letting reflect.Value.Call panic.
func (n *Node) dispatchMethod(methodName string, args ...any) {
	handlerMethod := fmt.Sprintf("Handle%s", methodName)
	arguments := make([]reflect.Value, 0, len(args))
	for _, v := range args {
		arguments = append(arguments, reflect.ValueOf(v))
	}

	if method := reflect.ValueOf(n).MethodByName(handlerMethod); method.IsValid() {
		callHandlerIfCompatible(method, arguments)
		return
	}

	// check if actor has the method
	if actorMethod := reflect.ValueOf(n.actor).MethodByName(handlerMethod); actorMethod.IsValid() {
		callHandlerIfCompatible(actorMethod, arguments)
	}
}

// callHandlerIfCompatible calls method with args when the argument count and
// types fit its signature, and reports whether the call was made. A nil
// argument is passed as the zero value of a nillable parameter type.
func callHandlerIfCompatible(method reflect.Value, args []reflect.Value) bool {
	mt := method.Type()
	numIn := mt.NumIn()
	variadic := mt.IsVariadic()
	if variadic {
		if len(args) < numIn-1 {
			return false
		}
	} else if len(args) != numIn {
		return false
	}

	call := make([]reflect.Value, len(args))
	for i, arg := range args {
		want := reflect.Type(nil)
		if variadic && i >= numIn-1 {
			// Trailing arguments must fit the element type of the variadic slice.
			want = mt.In(numIn - 1).Elem()
		} else {
			want = mt.In(i)
		}
		if !arg.IsValid() {
			// reflect.ValueOf(nil) is the zero Value, which Call rejects.
			switch want.Kind() {
			case reflect.Interface, reflect.Ptr, reflect.Map, reflect.Slice, reflect.Chan, reflect.Func:
				call[i] = reflect.Zero(want)
				continue
			}
			return false
		}
		if !arg.Type().AssignableTo(want) {
			return false
		}
		call[i] = arg
	}
	method.Call(call)
	return true
}
// createNodeActor creates the root actor of this node instance through the
// actor factory, stores it on the node and starts it.
func (n *Node) createNodeActor() error {
	root, err := n.actorFactory.NewActor()
	if err != nil {
		return fmt.Errorf("failed to create actor: %w", err)
	}
	n.actor = root

	if startErr := n.actor.Start(); startErr != nil {
		return fmt.Errorf("failed to start node actor: %w", startErr)
	}
	return nil
}
package matching
import (
"gitlab.com/nunet/device-management-service/types"
)
// CPUComparator compares two types.CPU values.
//
// Both arguments must be types.CPU, otherwise types.Error is returned. The
// architectures are compared with LiteralComparator: an Error result is
// propagated, any other non-Equal result yields types.Worse. When the
// architectures are equal the result is the NumericComparator comparison of
// cores multiplied by clock speed.
//
// This is a deliberately simple comparison; ranking specific CPU models
// (e.g. from benchmark data) would need hardware knowledge not encoded here.
func CPUComparator(l, r interface{}, _ ...Preference) types.Comparison {
	left, leftOK := l.(types.CPU)
	right, rightOK := r.(types.CPU)
	if !(leftOK && rightOK) {
		return types.Error
	}

	leftPerf := int64(left.Cores) * left.ClockSpeedHz
	rightPerf := int64(right.Cores) * right.ClockSpeedHz
	perf := NumericComparator(leftPerf, rightPerf)

	switch LiteralComparator(left.Architecture, right.Architecture) {
	case types.Error:
		return types.Error
	case types.Equal:
		return perf
	default:
		return types.Worse
	}
}
package matching
import (
"gitlab.com/nunet/device-management-service/types"
)
// DiskComparator compares two types.Disk values.
//
// Both arguments must be types.Disk, otherwise types.Error is returned. The
// per-field results come from ReturnComplexComparison: an Error on "Type" is
// propagated, any other non-Equal "Type" yields types.Worse, and for equal
// types the "Size" comparison is returned.
func DiskComparator(l, r interface{}, _ ...Preference) types.Comparison {
	if _, ok := l.(types.Disk); !ok {
		return types.Error
	}
	if _, ok := r.(types.Disk); !ok {
		return types.Error
	}

	fields := ReturnComplexComparison(l, r)
	switch fields["Type"] {
	case types.Error:
		return types.Error
	case types.Equal:
		return fields["Size"]
	default:
		return types.Worse
	}
}
package matching
import (
"gitlab.com/nunet/device-management-service/types"
)
// ExecutionResourcesComparator compares two types.ExecutionResources values
// over the fields CPU, Memory, Disk and GPUs.
//
// Both arguments must be types.ExecutionResources, otherwise types.Error is
// returned. An Error in any field yields Error; otherwise a Worse in any
// field yields Worse; otherwise the result is Equal when every field is
// Equal, and Better in all remaining cases.
func ExecutionResourcesComparator(l, r interface{}, _ ...Preference) types.Comparison {
	if _, ok := l.(types.ExecutionResources); !ok {
		return types.Error
	}
	if _, ok := r.(types.ExecutionResources); !ok {
		return types.Error
	}

	fields := ReturnComplexComparison(l, r)
	names := [...]string{"CPU", "Memory", "Disk", "GPUs"}

	for _, name := range names {
		if fields[name] == types.Error {
			return types.Error
		}
	}
	for _, name := range names {
		if fields[name] == types.Worse {
			return types.Worse
		}
	}
	for _, name := range names {
		if fields[name] != types.Equal {
			// No field is Error or Worse, and at least one is not Equal.
			return types.Better
		}
	}
	return types.Equal
}
package matching
import (
"gitlab.com/nunet/device-management-service/types"
)
// ExecutorComparator compares two types.Executor values.
//
// The left argument represents machine capabilities and the right one the
// required capabilities. Both must be types.Executor, otherwise types.Error
// is returned. The comparison is delegated to Compare on the ExecutorType
// fields of the two values.
func ExecutorComparator(lraw, rraw interface{}, _ ...Preference) types.Comparison {
	available, availableOK := lraw.(types.Executor)
	required, requiredOK := rraw.(types.Executor)
	if !availableOK || !requiredOK {
		return types.Error
	}
	return Compare(available.ExecutorType, required.ExecutorType)
}
package matching
import (
"reflect"
"gitlab.com/nunet/device-management-service/types"
)
// ExecutorTypeComparator compares two types.ExecutorType values.
// It returns types.Equal when both arguments are types.ExecutorType and are
// deeply equal, and types.Error otherwise.
func ExecutorTypeComparator(l, r interface{}, _ ...Preference) types.Comparison {
	if _, ok := l.(types.ExecutorType); !ok {
		return types.Error
	}
	if _, ok := r.(types.ExecutorType); !ok {
		return types.Error
	}
	if !reflect.DeepEqual(l, r) {
		return types.Error
	}
	return types.Equal
}
package matching
import (
"reflect"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils"
)
// ExecutorsComparator compares two types.Executors values.
//
// The left argument represents machine capabilities and the right one the
// required capabilities. Both must be types.Executors, otherwise types.Error
// is returned. The result is Equal when the two lists are deeply equal,
// Better when utils.IsStrictlyContained(left, right) holds, Worse when
// utils.IsStrictlyContained(right, left) holds, and Error otherwise.
func ExecutorsComparator(lraw, rraw interface{}, _ ...Preference) types.Comparison {
	available, availableOK := lraw.(types.Executors)
	required, requiredOK := rraw.(types.Executors)
	if !(availableOK && requiredOK) {
		return types.Error
	}

	// Copy both lists into untyped slices for the generic helpers.
	l := []interface{}{}
	for _, executor := range available {
		l = append(l, executor)
	}
	r := []interface{}{}
	for _, executor := range required {
		r = append(r, executor)
	}

	verdict := types.Error // error is the default value
	if !utils.IsSameShallowType(l, r) {
		verdict = types.Error
	}

	if reflect.DeepEqual(l, r) {
		// available capabilities are equal to the required ones
		verdict = types.Equal
	} else if utils.IsStrictlyContained(l, r) {
		// machine capabilities contain all the required capabilities
		verdict = types.Better
	} else if utils.IsStrictlyContained(r, l) {
		// required capabilities contain all the machine capabilities,
		// i.e. available capabilities are worse than required
		verdict = types.Worse
	}
	return verdict
}
package matching
import (
"reflect"
"gitlab.com/nunet/device-management-service/types"
)
// GPUVendorComparator compares two types.GPUVendor values.
// It returns types.Equal when both arguments are types.GPUVendor and are
// deeply equal, and types.Error otherwise.
//
// The comparison only tells whether the vendor is the same. There is no
// mechanism yet for user-defined vendor preferences; some compute may prefer
// or strictly depend on a specific vendor, which would have to be solved
// generically at the resource-matching level.
func GPUVendorComparator(l, r interface{}, _ ...Preference) types.Comparison {
	if _, ok := l.(types.GPUVendor); !ok {
		return types.Error
	}
	if _, ok := r.(types.GPUVendor); !ok {
		return types.Error
	}
	if !reflect.DeepEqual(l, r) {
		return types.Error
	}
	return types.Equal
}
package matching
import (
"gitlab.com/nunet/device-management-service/types"
"golang.org/x/exp/slices"
)
// GPUsComparator compares two []types.GPU slices.
// The left argument represents machine capabilities and the right one the
// required capabilities. It returns types.Error when either argument is not
// a []types.GPU.
func GPUsComparator(lraw, rraw interface{}, _ ...Preference) types.Comparison {
// comparator for GPUs type which is just a slice of GPU types:
// left represent machine capabilities;
// right represent required capabilities;
// we need to check if for each GPU on the right there exists a matching GPU on the left...
// (since given slices are not ordered...)
// validate input type
_, lrawok := lraw.([]types.GPU)
_, rrawok := rraw.([]types.GPU)
if !lrawok || !rrawok {
return types.Error
}
l := lraw.([]types.GPU)
r := rraw.([]types.GPU)
interimComparison1 := make([][]types.Comparison, 0)
for _, rGPU := range r {
var interimComparison2 []types.Comparison
for _, lGPU := range l {
interimComparison2 = append(interimComparison2, Compare(lGPU, rGPU))
}
// this matrix holds the comparison results in slice order:
// one row per required (right) GPU,
// one column per available (left) GPU
interimComparison1 = append(interimComparison1, interimComparison2)
}
// pick a result for each row of the matrix
var finalComparison []types.Comparison
for i := 0; i < len(interimComparison1); i++ {
// stop once the current row has fewer than i entries
if len(interimComparison1[i]) < i {
break
}
c := interimComparison1[i]
bestMatch, index := returnBestMatch(c)
finalComparison = append(finalComparison, bestMatch)
// NOTE(review): the matrix is reassigned while being iterated; presumably
// removeIndex drops the matched left GPU so it is not reused — confirm
// against removeIndex.
interimComparison1 = removeIndex(interimComparison1, index)
}
// aggregate: any Error wins, then any Worse, then Equal if
// SliceContainsOneValue reports so, otherwise Better
if slices.Contains(finalComparison, types.Error) {
return types.Error
}
if slices.Contains(finalComparison, types.Worse) {
return types.Worse
}
if SliceContainsOneValue(finalComparison, types.Equal) {
return types.Equal
}
return types.Better
}
package matching
import (
"gitlab.com/nunet/device-management-service/types"
)
// GpuComparator compares two types.GPU values.
//
// Both arguments must be types.GPU, otherwise types.Error is returned. The
// result is the "TotalVRAM" entry of ReturnComplexComparison when that entry
// is Error, Worse, Better or Equal; any other value yields types.Error.
//
// Preferences are not consulted; ranking specific GPU models (e.g. from
// benchmark data) would need hardware knowledge that is not encoded here.
func GpuComparator(l, r interface{}, _ ...Preference) types.Comparison {
	if _, ok := l.(types.GPU); !ok {
		return types.Error
	}
	if _, ok := r.(types.GPU); !ok {
		return types.Error
	}

	fields := ReturnComplexComparison(l, r)
	switch vram := fields["TotalVRAM"]; vram {
	case types.Error, types.Worse, types.Better, types.Equal:
		return vram
	}
	return types.Error // error is the default value
}
package matching
import (
"reflect"
"gitlab.com/nunet/device-management-service/types"
)
// JobTypeComparator compares two types.JobType values.
// It returns types.Equal when both arguments are types.JobType and are
// deeply equal, and types.Error otherwise.
func JobTypeComparator(l, r interface{}, _ ...Preference) types.Comparison {
	if _, ok := l.(types.JobType); !ok {
		return types.Error
	}
	if _, ok := r.(types.JobType); !ok {
		return types.Error
	}
	if !reflect.DeepEqual(l, r) {
		return types.Error
	}
	return types.Equal
}
package matching
import (
"reflect"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils"
)
// JobTypesComparator compares two types.JobTypes values.
//
// The left argument represents machine capabilities and the right one the
// required capabilities. Both must be types.JobTypes, otherwise types.Error
// is returned. The result is Equal when the two lists are deeply equal,
// Better when utils.IsStrictlyContained(left, right) holds, Worse when
// utils.IsStrictlyContained(right, left) holds, and Error otherwise.
//
// TODO: this comparator does not take into account options when several job
// types are available and several job types are required in the same data
// structure; this is why the test fails.
func JobTypesComparator(lraw, rraw interface{}, _ ...Preference) types.Comparison {
	if _, ok := lraw.(types.JobTypes); !ok {
		return types.Error
	}
	if _, ok := rraw.(types.JobTypes); !ok {
		return types.Error
	}

	// Both inputs are slices; convert them to untyped slices for the helpers.
	l := utils.ConvertTypedSliceToUntypedSlice(lraw)
	r := utils.ConvertTypedSliceToUntypedSlice(rraw)

	verdict := types.Error // error is the default value
	if !utils.IsSameShallowType(l, r) {
		verdict = types.Error
	}

	if reflect.DeepEqual(l, r) {
		// available capabilities are equal to the required ones
		verdict = types.Equal
	} else if utils.IsStrictlyContained(l, r) {
		// machine capabilities contain all the required capabilities
		verdict = types.Better
	} else if utils.IsStrictlyContained(r, l) {
		// required capabilities contain all the machine capabilities,
		// i.e. available capabilities are worse than required
		verdict = types.Worse
	}
	return verdict
}
package matching
import (
"gitlab.com/nunet/device-management-service/types"
"golang.org/x/exp/slices"
)
// LibrariesComparator compares two []types.Library slices, which may have
// different lengths. The left argument represents machine capabilities and
// the right one the required capabilities. It returns types.Error when
// either argument is not a []types.Library.
func LibrariesComparator(lraw, rraw interface{}, _ ...Preference) types.Comparison {
// comparator for Libraries slices (of different lengths) of Library types:
// left represent machine capabilities;
// right represent required capabilities;
// validate input type
_, lrawok := lraw.([]types.Library)
_, rrawok := rraw.([]types.Library)
if !lrawok || !rrawok {
return types.Error
}
l := lraw.([]types.Library)
r := rraw.([]types.Library)
interimComparison1 := make([][]types.Comparison, 0)
for _, rLibrary := range r {
var interimComparison2 []types.Comparison
for _, lLibrary := range l {
interimComparison2 = append(interimComparison2, Compare(lLibrary, rLibrary))
}
// this matrix holds the comparison results in slice order:
// one row per required (right) library,
// one column per available (left) library
interimComparison1 = append(interimComparison1, interimComparison2)
}
// pick a result for each row of the matrix
finalComparison := make([]types.Comparison, 0)
for i := 0; i < len(interimComparison1); i++ {
// stop once the current row has fewer than i entries
if len(interimComparison1[i]) < i {
break
}
c := interimComparison1[i]
bestMatch, index := returnBestMatch(c)
finalComparison = append(finalComparison, bestMatch)
// NOTE(review): the matrix is reassigned while being iterated; presumably
// removeIndex drops the matched left library so it is not reused — confirm
// against removeIndex.
interimComparison1 = removeIndex(interimComparison1, index)
}
// aggregate: any Error wins, then any Worse, then Equal if
// SliceContainsOneValue reports so, otherwise Better
if slices.Contains(finalComparison, types.Error) {
return types.Error
}
if slices.Contains(finalComparison, types.Worse) {
return types.Worse
}
if SliceContainsOneValue(finalComparison, types.Equal) {
return types.Equal
}
return types.Better
}
package matching
import (
"github.com/hashicorp/go-version"
"gitlab.com/nunet/device-management-service/types"
)
// LibraryComparator compares a single available library (left) against a
// required library (right).
//
// It returns types.Error when either argument is not a types.Library, when
// the left version cannot be parsed, when the right constraint/version pair
// cannot be parsed, or when the library names differ. Otherwise it returns
// Worse when the left version does not satisfy the constraint, Equal when it
// does and the constraint operator is "=", and Better when it satisfies any
// other constraint.
func LibraryComparator(lraw, rraw interface{}, _ ...Preference) types.Comparison {
	available, availableOK := lraw.(types.Library)
	required, requiredOK := rraw.(types.Library)
	if !availableOK || !requiredOK {
		return types.Error
	}

	// The available version must be a valid version string.
	have, err := version.NewVersion(available.Version)
	if err != nil {
		return types.Error
	}
	// The requirement must form a valid constraint.
	wanted, err := version.NewConstraint(required.Constraint + " " + required.Version)
	if err != nil {
		return types.Error
	}
	// Libraries with different names cannot be compared.
	if available.Name != required.Name {
		return types.Error
	}

	if !wanted.Check(have) {
		return types.Worse
	}
	if required.Constraint == "=" {
		return types.Equal
	}
	return types.Better
}
package matching
import (
"gitlab.com/nunet/device-management-service/types"
)
// LiteralComparator compares two string values, which can only be equal or
// not equal. It returns types.Equal when both arguments are strings with the
// same content and types.Error in every other case.
func LiteralComparator(l, r interface{}, _ ...Preference) types.Comparison {
	left, leftOK := l.(string)
	right, rightOK := r.(string)
	if !leftOK || !rightOK {
		return types.Error
	}
	if left == right {
		return types.Equal
	}
	return types.Error // error is the default value
}
package matching
import (
// "reflect"
"gitlab.com/nunet/device-management-service/types"
"golang.org/x/exp/slices"
)
// LocalitiesComparator is a simplified comparator for slices of
// types.Locality (there is no dedicated Localities type).
//
// The left argument represents machine capabilities and the right one the
// required capabilities. Both must be []types.Locality, otherwise
// types.Error is returned. Each required locality is compared against the
// available localities of the same Kind; the result of the last such
// comparison is kept, and a required locality with no counterpart counts as
// Error. The per-locality results are aggregated: any Error wins, then any
// Worse, then Equal if SliceContainsOneValue reports so, otherwise Better.
func LocalitiesComparator(lraw interface{}, rraw interface{}, _ ...Preference) types.Comparison {
	available, availableOK := lraw.([]types.Locality)
	required, requiredOK := rraw.([]types.Locality)
	if !availableOK || !requiredOK {
		return types.Error
	}

	var results []types.Comparison
	for _, want := range required {
		// Error unless an available locality of the same kind is found.
		outcome := types.Error
		for _, have := range available {
			if have.Kind == want.Kind {
				outcome = Compare(have, want)
			}
		}
		results = append(results, outcome)
	}

	switch {
	case slices.Contains(results, types.Error):
		return types.Error
	case slices.Contains(results, types.Worse):
		return types.Worse
	case SliceContainsOneValue(results, types.Equal):
		return types.Equal
	}
	return types.Better
}
package matching
import (
"gitlab.com/nunet/device-management-service/types"
)
// LocalityComparator is a simplified (placeholder) comparator for a single
// types.Locality.
//
// lraw is the machine's locality and rraw the required one. Localities of
// different Kind cannot be compared and yield types.Error; for the same Kind
// the result is types.Equal when the Names match and types.Worse when they do
// not. Non-Locality arguments also yield types.Error.
func LocalityComparator(lraw interface{}, rraw interface{}, _ ...Preference) types.Comparison {
	have, haveOK := lraw.(types.Locality)
	want, wantOK := rraw.(types.Locality)
	if !haveOK || !wantOK {
		return types.Error
	}

	switch {
	case have.Kind != want.Kind:
		return types.Error
	case have.Name == want.Name:
		return types.Equal
	default:
		return types.Worse
	}
}
package matching
import (
"gitlab.com/nunet/device-management-service/types"
)
// MemoryComparator compares two types.RAM values field by field.
//
// The comparison is deliberately simple: a Size result of Error or Worse is
// returned as is; in every other case (equal or larger Size) the result of the
// ClockSpeedHz comparison decides. Arguments that are not types.RAM yield
// types.Error.
func MemoryComparator(l, r interface{}, _ ...Preference) types.Comparison {
	if _, ok := l.(types.RAM); !ok {
		return types.Error
	}
	if _, ok := r.(types.RAM); !ok {
		return types.Error
	}

	fields := ReturnComplexComparison(l, r)
	if size := fields["Size"]; size == types.Error || size == types.Worse {
		return size
	}
	return fields["ClockSpeedHz"]
}
package matching
import (
"reflect"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils/validate"
)
// NumericComparator compares two numeric values.
//
// lraw is the machine's capability and rraw the required capability. Both are
// converted with validate.ConvertNumericToFloat64; if either conversion fails
// the function returns types.Error immediately instead of comparing the
// values produced by the failed conversion.
//
// Result:
//   - types.Equal  when the available value equals the required value
//   - types.Worse  when the available value is less than the required value
//   - types.Better when the available value is greater than the required value
//   - types.Error  when an input is not numeric or the values are unordered
func NumericComparator(lraw, rraw interface{}, _ ...Preference) types.Comparison {
	l, lnumeric := validate.ConvertNumericToFloat64(lraw)
	r, rnumeric := validate.ConvertNumericToFloat64(rraw)
	if !lnumeric || !rnumeric {
		return types.Error
	}

	switch {
	case reflect.DeepEqual(l, r):
		return types.Equal
	case l < r:
		return types.Worse
	case l > r:
		return types.Better
	default:
		// none of ==, <, > held; presumably only possible for NaN operands
		return types.Error
	}
}
package matching
import (
"fmt"
"gitlab.com/nunet/device-management-service/types"
)
// CapabilityComparator compares two types.Capability values by comparing each
// of their fields with Compare and combining the field results with
// Comparison.And, which follows this truth table:
//
// | AND    | Better | Worse  | Equal  | Error  |
// | ------ | ------ |--------|--------|--------|
// | Better | Better | Worse  | Better | Error  |
// | Worse  | Worse  | Worse  | Worse  | Error  |
// | Equal  | Better | Worse  | Equal  | Error  |
// | Error  | Error  | Error  | Error  | Error  |
//
// Fields are combined in this order: Executors, JobTypes, Resources,
// Libraries, Localities, Storage, Connectivity, Price, Time, KYCs.
//
// If either argument is not a types.Capability the two assertion outcomes are
// printed to stdout and types.Error is returned.
func CapabilityComparator(l, r interface{}, _ ...Preference) types.Comparison {
	lcap, lok := l.(types.Capability)
	rcap, rok := r.(types.Capability)
	if !lok || !rok {
		fmt.Println(lok, rok)
		return types.Error
	}

	// left/right value of every compared field, in combination order
	fieldPairs := [][2]interface{}{
		{lcap.Executors, rcap.Executors},
		{lcap.JobTypes, rcap.JobTypes},
		{lcap.Resources, rcap.Resources},
		{lcap.Libraries, rcap.Libraries},
		{lcap.Localities, rcap.Localities},
		{lcap.Storage, rcap.Storage},
		{lcap.Connectivity, rcap.Connectivity},
		{lcap.Price, rcap.Price},
		{lcap.Time, rcap.Time},
		{lcap.KYCs, rcap.KYCs},
	}

	result := Compare(fieldPairs[0][0], fieldPairs[0][1])
	for _, pair := range fieldPairs[1:] {
		result = result.And(Compare(pair[0], pair[1]))
	}
	return result
}
package matching
import (
"reflect"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils/validate"
)
// Comparator is the signature shared by all type-specific comparison functions.
// l is the left operand and r the right operand; the comparators in this
// package treat l as the machine's capabilities and r as the required
// capabilities. Preferences may be passed as optional trailing arguments.
// Compare is the generic entry point that selects a Comparator for simple
// types (i.e. which are not nested in the map[string]interface{} structure).
type Comparator func(l, r interface{}, preference ...Preference) types.Comparison
// Compare is the generic entry point for comparing two values of any supported
// type. It selects a comparator based on the dynamic type of l and returns
// that comparator's result for (l, r).
//
// Dispatch rules:
//   - a nil l cannot be classified and yields types.Error (reflect.TypeOf
//     returns nil for a nil interface, so asking for its name would panic);
//   - values accepted by validate.ConvertNumericToFloat64 use the "Numeric"
//     comparator;
//   - slices of known element types are mapped to their plural comparator
//     name ("GPUs", "Libraries", ...), each of which must be listed explicitly;
//   - everything else is looked up by the name of its type.
//
// types.Error is returned when no comparator is registered for the type.
// Preferences are accepted for signature compatibility but are not forwarded.
func Compare(l, r interface{}, _ ...Preference) types.Comparison {
	if l == nil {
		return types.Error
	}

	// TODO: it would be better to pass a pointer as this is a global structure
	comparatorMap := initComparatorMap()

	typeName := "Numeric"
	if _, numeric := validate.ConvertNumericToFloat64(l); !numeric {
		typeName = reflect.TypeOf(l).Name()

		// unnamed slice types have an empty name, so map the known element
		// types to the comparator registered for the whole slice
		switch l.(type) {
		case []types.GPU:
			typeName = "GPUs"
		case []types.Library:
			typeName = "Libraries"
		case []types.Locality:
			typeName = "Localities"
		case []types.Storage:
			typeName = "Storages"
		case []types.PriceInformation:
			typeName = "PricesInformation"
		case []types.KYC:
			typeName = "KYCs"
		}
	}

	comparator := comparatorMap[typeName]
	if comparator == nil {
		return types.Error
	}
	return comparator(l, r)
}
// ComparatorMap associates a type name with the Comparator responsible for
// values of that type; Compare uses these names as lookup keys.
type ComparatorMap map[string]Comparator
// initComparatorMap builds a fresh ComparatorMap containing every comparator
// defined in this package, keyed by the type name used for lookup, so that the
// set of comparators can be passed around and searched.
func initComparatorMap() ComparatorMap {
	return ComparatorMap{
		"Numeric":            NumericComparator,
		"Capability":         CapabilityComparator,
		"string":             LiteralComparator,
		"Executors":          ExecutorsComparator,
		"ExecutorType":       ExecutorTypeComparator,
		"JobType":            JobTypeComparator,
		"JobTypes":           JobTypesComparator,
		"GPUVendor":          GPUVendorComparator,
		"GPUs":               GPUsComparator,
		"GPU":                GpuComparator,
		"Executor":           ExecutorComparator,
		"ExecutionResources": ExecutionResourcesComparator,
		"CPU":                CPUComparator,
		"RAM":                MemoryComparator,
		"Disk":               DiskComparator,
		"Library":            LibraryComparator,
		"Libraries":          LibrariesComparator,
		"Locality":           LocalityComparator,
		"Localities":         LocalitiesComparator,
		"Storage":            StorageComparator,
		"Storages":           StoragesComparator,
		"Connectivity":       ConnectivityComparator,
		"PriceInformation":   PriceInformationComparator,
		"PricesInformation":  PricesInformationComparator,
		"TimeInformation":    TimeInformationComparator,
		"KYC":                KYCComparator,
		"KYCs":               KYCsComparator,
	}
}
// Preference describes how a comparison for a given type should be treated.
// Every Comparator accepts preferences as optional trailing arguments.
type Preference struct {
// TypeName presumably names the type this preference applies to
// (NOTE(review): confirm against callers).
TypeName string
// Strength is either Hard or Soft.
Strength PreferenceString
// DefaultComparatorOverride presumably replaces the default comparator
// for TypeName when set (NOTE(review): confirm against callers).
DefaultComparatorOverride Comparator
}
// PreferenceString is the strength of a Preference.
type PreferenceString string
// Allowed PreferenceString values.
const (
Hard PreferenceString = "Hard"
Soft PreferenceString = "Soft"
)
package matching
import (
"reflect"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils"
)
// ConnectivityComparator is a simplified (placeholder) comparator for
// types.Connectivity.
//
// lraw is the machine's connectivity and rraw the required one.
//   - types.Equal  when both values are deeply equal
//   - types.Better when utils.IsStrictlyContainedInt(l.Ports, r.Ports) holds
//     and the machine offers a VPN
//   - types.Worse  in every other case
//   - types.Error  when an argument is not a types.Connectivity
//
// NOTE(review): a machine without VPN is never Better, even if no VPN is
// required - confirm this is intended.
func ConnectivityComparator(lraw interface{}, rraw interface{}, _ ...Preference) types.Comparison {
	have, haveOK := lraw.(types.Connectivity)
	want, wantOK := rraw.(types.Connectivity)
	if !haveOK || !wantOK {
		return types.Error
	}

	if reflect.DeepEqual(have, want) {
		return types.Equal
	}

	// (l.VPN && r.VPN || l.VPN && !r.VPN) is just l.VPN
	if utils.IsStrictlyContainedInt(have.Ports, want.Ports) && have.VPN {
		return types.Better
	}
	return types.Worse
}
package matching
import (
"reflect"
"gitlab.com/nunet/device-management-service/types"
)
// KYCComparator is a simplified (placeholder) comparator for a single
// types.KYC.
//
// lraw is the machine's KYC and rraw the required one. The result is
// types.Equal when both are deeply equal and types.Error otherwise, including
// when an argument is not a types.KYC.
func KYCComparator(lraw interface{}, rraw interface{}, _ ...Preference) types.Comparison {
	have, haveOK := lraw.(types.KYC)
	want, wantOK := rraw.(types.KYC)

	if haveOK && wantOK && reflect.DeepEqual(have, want) {
		return types.Equal
	}
	return types.Error
}
package matching
import (
"reflect"
"gitlab.com/nunet/device-management-service/types"
)
// KYCsComparator is a simplified (placeholder) comparator for []types.KYC.
//
// lraw holds the machine's KYCs and rraw the required ones.
//   - types.Equal  when both slices are deeply equal, or when at least one
//     machine KYC compares Equal to at least one required KYC
//   - types.Better when nothing is required but the machine offers KYCs
//   - types.Error  otherwise, or when an argument is not a []types.KYC
func KYCsComparator(lraw interface{}, rraw interface{}, _ ...Preference) types.Comparison {
	have, haveOK := lraw.([]types.KYC)
	want, wantOK := rraw.([]types.KYC)
	if !haveOK || !wantOK {
		return types.Error
	}

	if reflect.DeepEqual(have, want) {
		return types.Equal
	}
	if len(want) == 0 && len(have) != 0 {
		return types.Better
	}

	// a single matching pair is enough
	for _, offered := range have {
		for _, demanded := range want {
			if Compare(offered, demanded) == types.Equal {
				return types.Equal
			}
		}
	}
	return types.Error
}
package matching
import (
"reflect"
"gitlab.com/nunet/device-management-service/types"
)
// PriceInformationComparator is a simplified (placeholder) comparator for
// types.PriceInformation.
//
// lraw is the machine's price and rraw the required price. Prices in
// different currencies cannot be compared and yield types.Error. Within one
// currency:
//   - same TotalPerJob: Equal / Better / Worse when CurrencyPerHour is
//     equal / lower / higher on the machine side
//   - lower TotalPerJob: Better when CurrencyPerHour is not higher, else Worse
//   - higher TotalPerJob: Worse
//
// Deeply equal values are Equal; non-PriceInformation arguments are Error.
func PriceInformationComparator(lraw interface{}, rraw interface{}, _ ...Preference) types.Comparison {
	offer, offerOK := lraw.(types.PriceInformation)
	ask, askOK := rraw.(types.PriceInformation)
	if !offerOK || !askOK {
		return types.Error
	}

	if reflect.DeepEqual(offer, ask) {
		return types.Equal
	}
	if offer.Currency != ask.Currency {
		return types.Error
	}

	switch {
	case offer.TotalPerJob == ask.TotalPerJob:
		switch {
		case offer.CurrencyPerHour == ask.CurrencyPerHour:
			return types.Equal
		case offer.CurrencyPerHour < ask.CurrencyPerHour:
			return types.Better
		default:
			return types.Worse
		}
	case offer.TotalPerJob < ask.TotalPerJob:
		if offer.CurrencyPerHour <= ask.CurrencyPerHour {
			return types.Better
		}
		return types.Worse
	default:
		return types.Worse
	}
}
package matching
import (
"reflect"
"gitlab.com/nunet/device-management-service/types"
)
// PricesInformationComparator is a simplified (placeholder) comparator for
// []types.PriceInformation.
//
// lraw holds the machine's prices and rraw the required ones. Deeply equal
// slices are types.Equal. Otherwise every machine price is compared with every
// required price and the first result that is not types.Error is returned;
// if there is none (or an argument has the wrong type) the result is
// types.Error.
func PricesInformationComparator(lraw interface{}, rraw interface{}, _ ...Preference) types.Comparison {
	offers, offersOK := lraw.([]types.PriceInformation)
	asks, asksOK := rraw.([]types.PriceInformation)
	if !offersOK || !asksOK {
		return types.Error
	}

	if reflect.DeepEqual(offers, asks) {
		return types.Equal
	}

	for _, offer := range offers {
		for _, ask := range asks {
			outcome := Compare(offer, ask)
			if outcome != types.Error {
				return outcome
			}
		}
	}
	return types.Error
}
package matching
import "gitlab.com/nunet/device-management-service/types"
// StorageComparator is a simplified (placeholder) comparator for a single
// types.Storage.
//
// lraw is the machine's storage and rraw the required storage. Capacity is
// Size*Amount of the respective side (the required capacity uses the required
// Amount, not the machine's).
//   - same Type: Equal / Better / Worse when the machine capacity is
//     equal / larger / smaller
//   - SSD offered, HDD required: Better when the machine capacity is at least
//     the required one, otherwise Worse
//   - HDD offered, SSD required: Worse
//   - any other combination, or a non-Storage argument: Error
func StorageComparator(lraw interface{}, rraw interface{}, _ ...Preference) types.Comparison {
	l, lok := lraw.(types.Storage)
	r, rok := rraw.(types.Storage)
	if !lok || !rok {
		return types.Error
	}

	available := l.Size * l.Amount
	required := r.Size * r.Amount

	switch {
	case l.Type == r.Type:
		switch {
		case available == required:
			return types.Equal
		case available > required:
			return types.Better
		default:
			return types.Worse
		}
	case l.Type == types.SSD_STORAGE_TYPE && r.Type == types.HDD_STORAGE_TYPE:
		// SSD is an upgrade over HDD as long as the capacity suffices
		if available >= required {
			return types.Better
		}
		return types.Worse
	case l.Type == types.HDD_STORAGE_TYPE && r.Type == types.SSD_STORAGE_TYPE:
		return types.Worse
	default:
		return types.Error
	}
}
package matching
import (
"gitlab.com/nunet/device-management-service/types"
)
// StoragesComparator is a simplified (placeholder) comparator for
// []types.Storage.
//
// lraw holds the machine's storage devices and rraw the required ones. For
// each side the capacity (Size*Amount) is summed separately for entries of
// type "ssd" and "hdd"; entries of any other type are ignored.
//
//   - HDD required but none available: Better when the available SSD covers
//     the required SSD and HDD together, otherwise Worse
//   - available SSD below required SSD: Worse
//   - both totals match exactly: Equal
//   - both totals at least as large as required: Better
//   - anything else, or an argument of the wrong type: Error
//
// NOTE(review): some HDD available but less than required (with enough SSD)
// ends up as Error rather than Worse - confirm this is intended.
func StoragesComparator(lraw interface{}, rraw interface{}, _ ...Preference) types.Comparison {
	available, availableOK := lraw.([]types.Storage)
	required, requiredOK := rraw.([]types.Storage)
	if !availableOK || !requiredOK {
		return types.Error
	}

	// sumByType returns the total ssd and hdd capacity of the given devices
	sumByType := func(devices []types.Storage) (ssd, hdd int) {
		for _, device := range devices {
			switch device.Type {
			case "ssd":
				ssd += device.Size * device.Amount
			case "hdd":
				hdd += device.Size * device.Amount
			}
		}
		return ssd, hdd
	}

	wantSSD, wantHDD := sumByType(required)
	haveSSD, haveHDD := sumByType(available)

	switch {
	case haveHDD == 0 && wantHDD > 0:
		// no HDD at all: SSD may stand in for it if it covers both demands
		if haveSSD >= wantSSD+wantHDD {
			return types.Better
		}
		return types.Worse
	case haveSSD < wantSSD:
		return types.Worse
	case haveSSD == wantSSD && haveHDD == wantHDD:
		return types.Equal
	case haveSSD >= wantSSD && haveHDD >= wantHDD:
		return types.Better
	default:
		return types.Error
	}
}
package matching
import (
"reflect"
"gitlab.com/nunet/device-management-service/types"
)
// TimeInformationComparator is a simplified (placeholder) comparator for
// types.TimeInformation.
//
// lraw is the machine's time information and rraw the required one. Both are
// normalised with totalTime and compared: Equal when the totals match (or the
// values are deeply equal), Worse when the machine offers less, Better when it
// offers more. Arguments of the wrong type yield types.Error.
func TimeInformationComparator(lraw interface{}, rraw interface{}, _ ...Preference) types.Comparison {
	have, haveOK := lraw.(types.TimeInformation)
	want, wantOK := rraw.(types.TimeInformation)
	if !haveOK || !wantOK {
		return types.Error
	}

	if reflect.DeepEqual(have, want) {
		return types.Equal
	}

	haveTotal, wantTotal := totalTime(have), totalTime(want)
	switch {
	case haveTotal == wantTotal:
		return types.Equal
	case haveTotal < wantTotal:
		return types.Worse
	default:
		return types.Better
	}
}
// totalTime converts timeInfo.MaxTime to seconds according to timeInfo.Units
// ("seconds", "minutes", "hours" or "days"). Any other unit is treated like
// seconds, i.e. MaxTime is returned unchanged.
func totalTime(timeInfo types.TimeInformation) int {
	secondsPerUnit := 1
	switch timeInfo.Units {
	case "minutes":
		secondsPerUnit = 60
	case "hours":
		secondsPerUnit = 60 * 60
	case "days":
		secondsPerUnit = 60 * 60 * 24
	}
	return timeInfo.MaxTime * secondsPerUnit
}
package matching
import (
"reflect"
"gitlab.com/nunet/device-management-service/types"
)
// ReturnComplexComparison compares two complex (struct) values field by field.
//
// It is a helper for type comparators whose nested fields have to be
// considered together before a verdict for the whole type can be reached. The
// result maps each field name of l to the outcome of Compare for that field
// in l and r.
//
// The fields are walked via reflection using the field count of l; r is
// indexed with the same positions, so it presumably has to be the same struct
// type as l.
func ReturnComplexComparison(l, r interface{}) types.ComplexComparison {
	left := reflect.ValueOf(l)
	right := reflect.ValueOf(r)

	perField := make(types.ComplexComparison)
	for i, fieldCount := 0, left.NumField(); i < fieldCount; i++ {
		fieldName := left.Type().Field(i).Name
		perField[fieldName] = Compare(left.Field(i).Interface(), right.Field(i).Interface())
	}
	return perField
}
// removeIndex removes the element at position index from every sub-slice of
// slice and returns slice. Sub-slices for which index is out of bounds are
// left untouched. The removal happens in place, so the sub-slices' backing
// arrays are modified.
func removeIndex(slice [][]types.Comparison, index int) [][]types.Comparison {
	for row := range slice {
		column := slice[row]
		if index >= 0 && index < len(column) {
			slice[row] = append(column[:index], column[index+1:]...)
		}
	}
	return slice
}
// returnBestMatch picks the most desirable comparison from dimension and
// returns it together with its index. The first Equal wins (the most
// efficient match); failing that the first Better, then the first Worse, then
// the first Error. When dimension contains none of these (e.g. it is empty)
// the result is types.Error with index -1.
func returnBestMatch(dimension []types.Comparison) (types.Comparison, int) {
	// candidates in descending order of desirability
	ranking := []types.Comparison{types.Equal, types.Better, types.Worse, types.Error}

	for _, wanted := range ranking {
		for position, candidate := range dimension {
			if candidate == wanted {
				return candidate, position
			}
		}
	}
	return types.Error, -1
}
// SliceContainsOneValue reports whether every element of slice equals value.
// An empty slice therefore yields true.
func SliceContainsOneValue(slice []types.Comparison, value types.Comparison) bool {
	for i := range slice {
		if slice[i] != value {
			return false
		}
	}
	return true
}
package resources
import (
"github.com/uptrace/opentelemetry-go-extra/otelzap"
"gitlab.com/nunet/device-management-service/db"
gormRepo "gitlab.com/nunet/device-management-service/db/repositories/gorm"
"gitlab.com/nunet/device-management-service/telemetry/logger"
)
var (
// zlog is the logger for the resources package; it is assigned in init().
zlog *otelzap.Logger
// ManagerInstance is the package-level resource Manager; it is created in
// init() from repositories backed by db.DB.
ManagerInstance Manager
)
// init sets up the package logger and the package-level ManagerInstance using
// gorm repositories backed by db.DB.
//
// TODO: This needs to be initialized in `dms` package and removed from here
// https://gitlab.com/nunet/device-management-service/-/issues/536
// It is being initialized in `dms` package now, but executor code
// (executor/docker/executor.go:262:25, function newDockerExecutionContainer)
// still depends heavily on this variable and fixing that would involve too
// many changes. Once that code moves to allocations, this can be removed.
func init() {
	zlog = logger.OtelZapLogger("resources")

	ManagerInstance = NewResourceManager(ManagerRepos{
		FreeResources:      gormRepo.NewFreeResources(db.DB),
		OnboardedResources: gormRepo.NewOnboardedResources(db.DB),
		RequiredResources:  gormRepo.NewRequiredResources(db.DB),
		VirtualMachine:     gormRepo.NewVirtualMachine(db.DB),
		Services:           gormRepo.NewServices(db.DB),
	})
}
package resources
import (
"context"
"fmt"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// SystemSpecs describes how to query the hardware specifications of the
// machine the service is running on.
type SystemSpecs interface {
// GetSpecInfo returns the detailed specifications of the machine
GetSpecInfo() (types.SpecInfo, error)
// GetGPUVendors returns the GPU vendors of the machine
GetGPUVendors() ([]types.GPUVendor, error)
// GetGPUs returns the GPUs of the machine for the given vendors.
// If no vendors are provided, it returns the information of all the GPUs.
GetGPUs(vendors ...types.GPUVendor) ([]types.GPU, error)
// GetTotalMemory returns the total memory of the machine in MB
GetTotalMemory() (uint64, error)
// GetTotalStorage returns the total storage of the machine in MB
// NOTE(review): the Linux implementation in this file sums usage.Total
// without any unit conversion - confirm the unit actually returned.
GetTotalStorage() (uint64, error)
// GetCPUInfo returns the CPU information of the machine
GetCPUInfo() (types.CPUInfo, error)
// GetProvisionedResources returns the total resources (CPU, RAM, disk and
// GPUs) of the machine
GetProvisionedResources() (types.Resources, error)
}
// Manager defines the operations for managing the resources of the machine:
// querying and updating onboarded, required and free resources, and accessing
// the underlying SystemSpecs and UsageMonitor.
type Manager interface {
// UpdateFreeResources calculates the free resources of the machine,
// stores them in the database and returns them
UpdateFreeResources(context.Context) (types.FreeResources, error)
// GetOnboardedResources returns the onboarded resources of the machine
GetOnboardedResources(context.Context) (types.OnboardedResources, error)
// GetRequiredResources returns the resources required by the jobs running on the machine
GetRequiredResources(context.Context) (types.Resources, error)
// UpdateOnboardedResources updates the onboarded resources of the machine in the database
UpdateOnboardedResources(context.Context, types.OnboardedResources) error
// SystemSpecs returns the SystemSpecs instance
SystemSpecs() SystemSpecs
// UsageMonitor returns the UsageMonitor instance
UsageMonitor() UsageMonitor
// ... other methods
}
// DefaultManager implements the Manager interface on top of database
// repositories; instances are created with NewResourceManager.
// TODO: do we want to have an in-memory cache for the resources instead of querying the DB every time?
// TODO: Add telemetry for the methods https://gitlab.com/nunet/device-management-service/-/issues/535
type DefaultManager struct {
// usageMonitor supplies the current resource usage (see UpdateFreeResources)
usageMonitor UsageMonitor
// systemSpecs gives access to the machine's hardware specifications
systemSpecs SystemSpecs
// repos holds the repositories used to load and persist resources
repos ManagerRepos
}
// ManagerRepos holds all the repositories needed for resource management.
// DefaultManager uses FreeResources, OnboardedResources and RequiredResources
// directly; VirtualMachine, Services and RequiredResources are handed to the
// usage monitor by NewResourceManager.
type ManagerRepos struct {
FreeResources repositories.FreeResources
OnboardedResources repositories.OnboardedResources
RequiredResources repositories.RequiredResources
VirtualMachine repositories.VirtualMachine
Services repositories.Services
}
// NewResourceManager creates a DefaultManager that uses the given
// repositories. The same system specs instance is shared between the manager
// and the usage monitor it creates.
func NewResourceManager(repos ManagerRepos) *DefaultManager {
	specs := newSystemSpecs()
	monitor := newUsageMonitor(
		specs,
		repos.VirtualMachine,
		repos.Services,
		repos.RequiredResources,
	)

	manager := &DefaultManager{
		usageMonitor: monitor,
		systemSpecs:  specs,
		repos:        repos,
	}
	return manager
}
// Compile-time check that *DefaultManager satisfies the Manager interface.
var _ Manager = (*DefaultManager)(nil)
// UpdateFreeResources computes the machine's free resources as onboarded
// resources minus current usage, persists the result in the database and
// returns it. On any failure an empty types.FreeResources and a wrapped error
// are returned.
func (d DefaultManager) UpdateFreeResources(ctx context.Context) (types.FreeResources, error) {
	var none types.FreeResources

	usage, err := d.usageMonitor.GetUsage(ctx)
	if err != nil {
		return none, fmt.Errorf("getting usage: %w", err)
	}

	onboarded, err := d.GetOnboardedResources(ctx)
	if err != nil {
		return none, fmt.Errorf("getting total resources: %w", err)
	}

	remaining, err := onboarded.Subtract(usage)
	if err != nil {
		return none, fmt.Errorf("calculating free resources: %w", err)
	}

	err = d.updateDBFreeResources(ctx, types.FreeResources{Resources: remaining})
	if err != nil {
		return none, fmt.Errorf("updating free resources in db: %w", err)
	}

	return types.FreeResources{Resources: remaining}, nil
}
// GetOnboardedResources loads the machine's onboarded resources from the
// OnboardedResources repository. A repository failure is wrapped and returned
// together with an empty value.
func (d DefaultManager) GetOnboardedResources(ctx context.Context) (types.OnboardedResources, error) {
	onboarded, err := d.repos.OnboardedResources.Get(ctx)
	if err == nil {
		return onboarded, nil
	}
	return types.OnboardedResources{}, fmt.Errorf("failed to get onboarded resources: %w", err)
}
// GetRequiredResources returns the resources required by the jobs running on
// the machine: it loads every entry from the RequiredResources repository and
// adds their Resources together.
//
// A repository failure is returned wrapped (with %w, like the other methods of
// DefaultManager) so that callers can inspect the cause with errors.Is/As.
func (d DefaultManager) GetRequiredResources(ctx context.Context) (types.Resources, error) {
	jobRequirements, err := d.repos.RequiredResources.FindAll(ctx, d.repos.RequiredResources.GetQuery())
	if err != nil {
		return types.Resources{}, fmt.Errorf("unable to get resource requirements from db - %w", err)
	}

	var totalRequiredResources types.Resources
	for _, req := range jobRequirements {
		totalRequiredResources = totalRequiredResources.Add(req.Resources)
	}
	return totalRequiredResources, nil
}
// UpdateOnboardedResources stores the given onboarded resources through the
// OnboardedResources repository. The saved value returned by the repository
// is discarded; a failure is wrapped and returned.
func (d DefaultManager) UpdateOnboardedResources(ctx context.Context, resources types.OnboardedResources) error {
	if _, err := d.repos.OnboardedResources.Save(ctx, resources); err != nil {
		return fmt.Errorf("failed to update onboarded resources: %w", err)
	}
	return nil
}
// SystemSpecs returns the SystemSpecs implementation held by the manager
// (the one created in NewResourceManager).
func (d DefaultManager) SystemSpecs() SystemSpecs {
return d.systemSpecs
}
// UsageMonitor returns the UsageMonitor implementation held by the manager
// (the one created in NewResourceManager).
func (d DefaultManager) UsageMonitor() UsageMonitor {
return d.usageMonitor
}
// updateDBFreeResources persists freeResources through the FreeResources
// repository. The saved value returned by the repository is discarded; a
// failure is wrapped and returned.
func (d DefaultManager) updateDBFreeResources(ctx context.Context, freeResources types.FreeResources) error {
	if _, err := d.repos.FreeResources.Save(ctx, freeResources); err != nil {
		return fmt.Errorf("updating free resources: %w", err)
	}
	return nil
}
package resources
import (
"context"
"errors"
"fmt"
"os/exec"
"regexp"
"strconv"
"strings"
"gitlab.com/nunet/device-management-service/types"
"github.com/NVIDIA/go-nvml/pkg/nvml"
"github.com/jaypipes/ghw"
"github.com/shirou/gopsutil/v4/cpu"
"github.com/shirou/gopsutil/v4/disk"
"github.com/shirou/gopsutil/v4/mem"
)
// gpuMetadata carries per-GPU metadata that is looked up per vendor in GetGPUs
// and passed to the vendor-specific GPU info functions.
type gpuMetadata struct {
// PCIAddress is copied into the PCIAddress field of the resulting types.GPU
PCIAddress string
}
// linuxSystemSpecs implements the SystemSpecs interface for Linux systems.
// It carries no state.
type linuxSystemSpecs struct{}
// newSystemSpecs creates the (stateless) Linux implementation of SystemSpecs.
func newSystemSpecs() *linuxSystemSpecs {
	return new(linuxSystemSpecs)
}
// Compile-time check that *linuxSystemSpecs satisfies the SystemSpecs interface.
var _ SystemSpecs = (*linuxSystemSpecs)(nil)
// GetSpecInfo returns the detailed specifications of the system.
//
// It is not implemented yet and always returns an empty types.SpecInfo with
// an error. An error is returned instead of panicking so that a call to this
// method cannot bring down the whole service.
//
// TODO: implement the method
// https://gitlab.com/nunet/device-management-service/-/issues/537
func (l linuxSystemSpecs) GetSpecInfo() (types.SpecInfo, error) {
	return types.SpecInfo{}, errors.New("system spec info is not implemented")
}
// GetGPUVendors returns the vendor of every graphics card reported by
// ghw.GPU(), one entry per card (duplicates are not removed).
//
// A card is only considered when its PCI class name contains one of "display
// controller", "vga compatible controller", "3d controller" or "2d
// controller" (case-insensitive) and vendor information is present. Vendor
// names containing "nvidia", "amd" or "intel" (case-insensitive, checked in
// that order) map to the corresponding types.GPUVendor; anything else is
// types.GPUVendorUnknown.
func (l linuxSystemSpecs) GetGPUVendors() ([]types.GPUVendor, error) {
	gpuInfo, err := ghw.GPU()
	if err != nil {
		return nil, err
	}

	var found []types.GPUVendor
	for _, card := range gpuInfo.GraphicsCards {
		device := card.DeviceInfo
		if device == nil || device.Class == nil {
			continue
		}

		class := strings.ToLower(device.Class.Name)
		isGraphics := strings.Contains(class, "display controller") ||
			strings.Contains(class, "vga compatible controller") ||
			strings.Contains(class, "3d controller") ||
			strings.Contains(class, "2d controller")
		if !isGraphics || device.Vendor == nil {
			continue
		}

		vendorName := strings.ToLower(device.Vendor.Name)
		switch {
		case strings.Contains(vendorName, "nvidia"):
			found = append(found, types.GPUVendorNvidia)
		case strings.Contains(vendorName, "amd"):
			found = append(found, types.GPUVendorAMDATI)
		case strings.Contains(vendorName, "intel"):
			found = append(found, types.GPUVendorIntel)
		default:
			found = append(found, types.GPUVendorUnknown)
		}
	}
	return found, nil
}
// GetGPUs returns the GPUs of the requested vendors. Without any vendor it
// queries Intel, NVIDIA and AMD GPUs, in that order.
//
// Lookups are best effort per vendor: a vendor without metadata is logged at
// info level and skipped, and a failing vendor query is logged as a warning
// and skipped. Only a metadata fetch failure or an unsupported vendor makes
// the call fail.
//
// The returned GPUs are passed through assignIndexToGPUs. Note: that index is
// internal to dms and is not the same as the device index.
func (l linuxSystemSpecs) GetGPUs(vendors ...types.GPUVendor) ([]types.GPU, error) {
	gpuMetadataMap, err := l.fetchGPUMetadata()
	if err != nil {
		return nil, fmt.Errorf("failed to fetch GPU metadata: %w", err)
	}

	requested := vendors
	if len(requested) == 0 {
		// no specific vendor requested: query every supported vendor
		requested = []types.GPUVendor{
			types.GPUVendorIntel,
			types.GPUVendorNvidia,
			types.GPUVendorAMDATI,
		}
	}

	var gpus []types.GPU
	for _, vendor := range requested {
		var fetch func(metadata []gpuMetadata) ([]types.GPU, error)
		switch vendor {
		case types.GPUVendorIntel:
			fetch = l.getIntelGPUInfo
		case types.GPUVendorNvidia:
			fetch = l.getNVIDIAGPUInfo
		case types.GPUVendorAMDATI:
			fetch = l.getAMDGPUInfo
		default:
			return nil, fmt.Errorf("unsupported GPU vendor: %v", vendor)
		}

		vendorMetadata, known := gpuMetadataMap[vendor]
		if !known {
			zlog.Sugar().Infof("No %s GPUs found", vendor)
			continue
		}

		vendorGPUs, err := fetch(vendorMetadata)
		if err != nil {
			zlog.Sugar().Warnf("Failed to retrieve %s GPU information: %v", vendor, err)
			continue
		}
		gpus = append(gpus, vendorGPUs...)
	}

	return assignIndexToGPUs(gpus), nil
}
// GetTotalMemory returns the total memory reported by mem.VirtualMemory(),
// converted to whole units of 1024*1024 (the reported total is divided by 1024
// twice, rounding down).
//
// A failure of the underlying query is wrapped with %w so that callers can
// inspect the cause.
func (l linuxSystemSpecs) GetTotalMemory() (uint64, error) {
	v, err := mem.VirtualMemory()
	if err != nil {
		return 0, fmt.Errorf("failed to get total memory: %w", err)
	}
	return v.Total / 1024 / 1024, nil
}
// GetTotalStorage returns the summed total size of all partitions reported by
// disk.PartitionsWithContext (called with all=false). A failure to list the
// partitions or to query the usage of any single mount point fails the call.
//
// NOTE(review): the per-partition totals are summed without unit conversion,
// whereas the SystemSpecs interface comment promises MB - confirm which unit
// callers expect.
func (l linuxSystemSpecs) GetTotalStorage() (uint64, error) {
	ctx := context.Background()

	mounted, err := disk.PartitionsWithContext(ctx, false)
	if err != nil {
		return 0, fmt.Errorf("failed to get partitions: %w", err)
	}

	var sum uint64
	for _, partition := range mounted {
		stat, err := disk.UsageWithContext(ctx, partition.Mountpoint)
		if err != nil {
			return 0, fmt.Errorf("failed to get disk usage: %w", err)
		}
		sum += stat.Total
	}
	return sum, nil
}
// GetCPUInfo returns the CPU information for the system, based on the entries
// reported by cpu.Info():
//   - Compute is the sum of the Mhz values of all entries
//   - NumCores is the number of entries
//   - MHzPerCore is the Mhz value of the first entry
//
// An error is returned when the query fails or reports no entries at all
// (instead of panicking on the missing first entry).
func (l linuxSystemSpecs) GetCPUInfo() (types.CPUInfo, error) {
	cores, err := cpu.Info()
	if err != nil {
		return types.CPUInfo{}, fmt.Errorf("failed to get CPU info: %w", err)
	}
	if len(cores) == 0 {
		return types.CPUInfo{}, errors.New("failed to get CPU info: no CPU reported")
	}

	var totalCompute float64
	for i := range cores {
		totalCompute += cores[i].Mhz
	}

	return types.CPUInfo{
		Compute:    totalCompute,
		NumCores:   uint64(len(cores)),
		MHzPerCore: cores[0].Mhz,
	}, nil
}
// GetProvisionedResources returns the total resources available on the
// system: CPU is the Compute value from GetCPUInfo, RAM comes from
// GetTotalMemory, Disk from GetTotalStorage and GPU from GetGPUs (all vendors).
//
// The first failing query aborts the call; its error is wrapped with %w so
// that callers can inspect the cause.
func (l linuxSystemSpecs) GetProvisionedResources() (types.Resources, error) {
	cpuInfo, err := l.GetCPUInfo()
	if err != nil {
		return types.Resources{}, fmt.Errorf("failed to get CPU info: %w", err)
	}

	totalMemory, err := l.GetTotalMemory()
	if err != nil {
		return types.Resources{}, fmt.Errorf("failed to get total memory: %w", err)
	}

	gpus, err := l.GetGPUs()
	if err != nil {
		return types.Resources{}, fmt.Errorf("failed to get GPU info: %w", err)
	}

	totalDisk, err := l.GetTotalStorage()
	if err != nil {
		return types.Resources{}, fmt.Errorf("failed to get total storage: %w", err)
	}

	return types.Resources{
		CPU:  cpuInfo.Compute,
		RAM:  totalMemory,
		Disk: totalDisk,
		GPU:  gpus,
	}, nil
}
// getAMDGPUInfo returns the GPU information for AMD GPUs.
//
// It runs `rocm-smi --showid --showproductname --showmeminfo vram` and parses
// the "Card Series", "VRAM Total Memory (B)" and "VRAM Total Used Memory (B)"
// lines of its output. The byte values are converted to MiB (integer
// division) and the free VRAM is the difference of total and used.
//
// metadata supplies the PCI address of the GPU at the same position in the
// rocm-smi output; when there are fewer metadata entries than reported GPUs
// the PCI address of the extra GPUs is left empty instead of indexing past
// the end of the slice.
//
// An error is returned when rocm-smi cannot be run, when no GPU or VRAM
// information is found, when the numbers of parsed names/totals/used values
// differ, or when a VRAM value cannot be parsed.
func (l linuxSystemSpecs) getAMDGPUInfo(metadata []gpuMetadata) ([]types.GPU, error) {
	cmd := exec.Command("rocm-smi", "--showid", "--showproductname", "--showmeminfo", "vram")
	output, err := cmd.CombinedOutput()
	if err != nil {
		return nil, fmt.Errorf("AMD ROCm not installed, initialized, or configured (reboot recommended for newly installed AMD GPU Drivers): %w", err)
	}
	outputStr := string(output)

	gpuNameRegex := regexp.MustCompile(`GPU\[\d+\]\s+: Card Series:\s+([^\n]+)`)
	totalRegex := regexp.MustCompile(`GPU\[\d+\]\s+: VRAM Total Memory \(B\):\s+(\d+)`)
	usedRegex := regexp.MustCompile(`GPU\[\d+\]\s+: VRAM Total Used Memory \(B\):\s+(\d+)`)

	gpuNameMatches := gpuNameRegex.FindAllStringSubmatch(outputStr, -1)
	totalMatches := totalRegex.FindAllStringSubmatch(outputStr, -1)
	usedMatches := usedRegex.FindAllStringSubmatch(outputStr, -1)

	if len(gpuNameMatches) == 0 || len(totalMatches) == 0 || len(usedMatches) == 0 {
		return nil, fmt.Errorf("failed to find AMD GPU information or vram information in the output")
	}
	if len(gpuNameMatches) != len(totalMatches) || len(totalMatches) != len(usedMatches) {
		return nil, fmt.Errorf("inconsistent AMD GPU information detected")
	}

	gpuInfos := make([]types.GPU, 0, len(gpuNameMatches))
	for i := range gpuNameMatches {
		totalMemoryBytes, err := strconv.ParseInt(totalMatches[i][1], 10, 64)
		if err != nil {
			return nil, fmt.Errorf("failed to parse total amdgpu vram: %w", err)
		}
		usedMemoryBytes, err := strconv.ParseInt(usedMatches[i][1], 10, 64)
		if err != nil {
			return nil, fmt.Errorf("failed to parse used amdgpu vram: %w", err)
		}

		totalMemoryMiB := totalMemoryBytes / 1024 / 1024
		usedMemoryMiB := usedMemoryBytes / 1024 / 1024
		freeMemoryMiB := totalMemoryMiB - usedMemoryMiB

		// rocm-smi may report more GPUs than we have metadata for
		var pciAddress string
		if i < len(metadata) {
			pciAddress = metadata[i].PCIAddress
		}

		gpuInfos = append(gpuInfos, types.GPU{
			PCIAddress: pciAddress,
			Model:      gpuNameMatches[i][1],
			TotalVRAM:  uint64(totalMemoryMiB),
			UsedVRAM:   uint64(usedMemoryMiB),
			FreeVRAM:   uint64(freeMemoryMiB),
			Vendor:     types.GPUVendorAMDATI,
		})
	}
	return gpuInfos, nil
}
// getNVIDIAGPUInfo returns the GPU information for NVIDIA GPUs.
//
// NVML is initialised for the duration of the call and shut down on return.
// The number of devices NVML reports must equal len(metadata); device i is
// paired with metadata[i] for its PCI address. VRAM figures are converted
// from bytes to MiB.
func (l linuxSystemSpecs) getNVIDIAGPUInfo(metadata []gpuMetadata) ([]types.GPU, error) {
	const bytesPerMiB = 1024 * 1024

	if initRet := nvml.Init(); !errors.Is(initRet, nvml.SUCCESS) {
		return nil, fmt.Errorf("NVIDIA Management Library not installed, initialized or configured (reboot recommended for newly installed NVIDIA GPU drivers): %s", nvml.ErrorString(initRet))
	}
	// The shutdown status is deliberately ignored: nothing useful can be done with it here.
	defer func() {
		_ = nvml.Shutdown()
	}()

	total, countRet := nvml.DeviceGetCount()
	if !errors.Is(countRet, nvml.SUCCESS) {
		return nil, fmt.Errorf("failed to get device count: %s", nvml.ErrorString(countRet))
	}
	if total != len(metadata) {
		return nil, fmt.Errorf("failed to find NVIDIA GPU information for all GPUs")
	}

	var found []types.GPU
	for idx := 0; idx < total; idx++ {
		handle, status := nvml.DeviceGetHandleByIndex(idx)
		if !errors.Is(status, nvml.SUCCESS) {
			return nil, fmt.Errorf("failed to get device handle for device %d: %s", idx, nvml.ErrorString(status))
		}

		deviceName, status := handle.GetName()
		if !errors.Is(status, nvml.SUCCESS) {
			return nil, fmt.Errorf("failed to get name for device %d: %s", idx, nvml.ErrorString(status))
		}

		vram, status := handle.GetMemoryInfo()
		if !errors.Is(status, nvml.SUCCESS) {
			return nil, fmt.Errorf("failed to get nvidiagpu vram info for device %d: %s", idx, nvml.ErrorString(status))
		}

		found = append(found, types.GPU{
			PCIAddress: metadata[idx].PCIAddress,
			Name:       deviceName,
			Model:      deviceName,
			TotalVRAM:  vram.Total / bytesPerMiB,
			UsedVRAM:   vram.Used / bytesPerMiB,
			FreeVRAM:   vram.Free / bytesPerMiB,
			Vendor:     types.GPUVendorNvidia,
		})
	}
	return found, nil
}
// getIntelGPUInfo returns the GPU information for Intel GPUs.
//
// The device IDs are discovered with `xpu-smi health -l`; the number of IDs
// must equal len(metadata), and the i-th ID is paired with metadata[i] for its
// PCI address. For every device, `xpu-smi discovery -d <id>` supplies the
// device name and physical memory size (MiB) and `xpu-smi stats -d <id>`
// supplies the used memory (MiB).
//
// An error is returned when any xpu-smi invocation fails or when its output
// cannot be parsed.
func (l linuxSystemSpecs) getIntelGPUInfo(metadata []gpuMetadata) ([]types.GPU, error) {
	// Determine the number of discrete Intel GPUs
	healthOutput, err := exec.Command("xpu-smi", "health", "-l").CombinedOutput()
	if err != nil {
		return nil, fmt.Errorf("xpu-smi not installed, initialized, or configured: %w", err)
	}

	// Use regex to find all instances of Device ID
	deviceIDRegex := regexp.MustCompile(`(?i)\| Device ID\s+\|\s+(\d+)\s+\|`)
	deviceIDMatches := deviceIDRegex.FindAllStringSubmatch(string(healthOutput), -1)
	if len(deviceIDMatches) == 0 {
		return nil, fmt.Errorf("failed to find any Intel GPUs")
	}
	if len(deviceIDMatches) != len(metadata) {
		return nil, fmt.Errorf("failed to find Intel GPU information for all GPUs")
	}

	// The patterns do not depend on the device, so compile them once instead
	// of once per loop iteration.
	nameRegex := regexp.MustCompile(`(?i)Device Name:\s+([^\n|]+)`)
	totalMemRegex := regexp.MustCompile(`(?i)Memory Physical Size:\s+([^\s]+)\s+MiB`)
	usedMemRegex := regexp.MustCompile(`(?i)GPU Memory Used \(MiB\)\s+\|\s+(\d+)\s+\|`)

	gpuInfos := make([]types.GPU, 0, len(deviceIDMatches))
	for i, match := range deviceIDMatches {
		deviceID := match[1]

		// Get GPU details using xpu-smi discovery
		discoveryOutput, err := exec.Command("xpu-smi", "discovery", "-d", deviceID).CombinedOutput()
		if err != nil {
			return nil, fmt.Errorf("failed to get discovery info for Intel GPU %s: %w", deviceID, err)
		}
		discoveryStr := string(discoveryOutput)

		nameMatch := nameRegex.FindStringSubmatch(discoveryStr)
		totalMemMatch := totalMemRegex.FindStringSubmatch(discoveryStr)
		if nameMatch == nil || totalMemMatch == nil {
			return nil, fmt.Errorf("failed to parse discovery info for Intel GPU %s", deviceID)
		}

		gpuName := strings.TrimSpace(nameMatch[1])
		totalMemoryMiB, err := strconv.ParseFloat(totalMemMatch[1], 64)
		if err != nil {
			return nil, fmt.Errorf("failed to parse total memory for Intel GPU %s: %w", deviceID, err)
		}

		// Get used memory using xpu-smi stats
		statsOutput, err := exec.Command("xpu-smi", "stats", "-d", deviceID).CombinedOutput()
		if err != nil {
			return nil, fmt.Errorf("failed to get stats for Intel GPU %s: %w", deviceID, err)
		}

		usedMemMatch := usedMemRegex.FindStringSubmatch(string(statsOutput))
		if usedMemMatch == nil {
			return nil, fmt.Errorf("failed to parse used memory for Intel GPU %s", deviceID)
		}
		usedMemoryMiB, err := strconv.ParseFloat(usedMemMatch[1], 64)
		if err != nil {
			return nil, fmt.Errorf("failed to parse used memory for Intel GPU %s: %w", deviceID, err)
		}

		// Clamp at zero: converting a negative float to uint64 yields an
		// implementation-dependent value.
		freeMemoryMiB := totalMemoryMiB - usedMemoryMiB
		if freeMemoryMiB < 0 {
			freeMemoryMiB = 0
		}

		gpuInfos = append(gpuInfos, types.GPU{
			PCIAddress: metadata[i].PCIAddress,
			Model:      gpuName,
			TotalVRAM:  uint64(totalMemoryMiB),
			UsedVRAM:   uint64(usedMemoryMiB),
			FreeVRAM:   uint64(freeMemoryMiB),
			Vendor:     types.GPUVendorIntel,
		})
	}
	return gpuInfos, nil
}
// fetchGPUMetadata groups the PCI addresses of the graphics cards reported by
// `ghw.GPU()` by GPU vendor. Cards without device information are skipped.
// TODO: Use one single library to fetch GPU information or improve the match criteria
// https://gitlab.com/nunet/device-management-service/-/issues/548
// TODO: write tests by mocking the gpu snapshot
// https://gitlab.com/nunet/device-management-service/-/issues/534
func (l linuxSystemSpecs) fetchGPUMetadata() (map[types.GPUVendor][]gpuMetadata, error) {
	snapshot, err := ghw.GPU()
	if err != nil {
		return nil, err
	}

	byVendor := make(map[types.GPUVendor][]gpuMetadata)
	for _, card := range snapshot.GraphicsCards {
		if card.DeviceInfo != nil {
			key := types.ParseGPUVendor(card.DeviceInfo.Vendor.Name)
			entry := gpuMetadata{PCIAddress: card.Address}
			byVendor[key] = append(byVendor[key], entry)
		}
	}
	return byVendor, nil
}
// assignIndexToGPUs numbers the GPUs in place: element i gets Index i, starting
// at 0. The same slice is returned for convenience.
func assignIndexToGPUs(gpus []types.GPU) []types.GPU {
	for pos := 0; pos < len(gpus); pos++ {
		entry := &gpus[pos]
		entry.Index = pos
	}
	return gpus
}
package resources
import (
"context"
"fmt"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/db/repositories"
)
// UsageMonitor defines the methods to monitor the system usage.
// Implementations report how much of the machine's resources is currently
// in use.
type UsageMonitor interface {
// GetUsage returns the resources used by the machine, or an error when the
// usage cannot be determined.
GetUsage(context.Context) (types.Resources, error)
}
// defaultUsageMonitor implements the UsageMonitor interface.
// It derives usage from the records of running VMs and running services
// held in the repositories rather than from live measurements.
type defaultUsageMonitor struct {
// systemSpecs supplies the CPU information used to express VM vCPUs in MHz.
systemSpecs SystemSpecs
// vmRepo is queried for VMs whose State is "running".
vmRepo repositories.VirtualMachine
// serviceRepo is queried for services whose JobStatus is "running".
serviceRepo repositories.Services
// requiredResourcesRepo holds the resource requirements looked up for each running service.
requiredResourcesRepo repositories.RequiredResources
}
// newUsageMonitor builds a defaultUsageMonitor from the system-specs provider
// and the three repositories it reads usage data from.
func newUsageMonitor(
	systemSpecs SystemSpecs,
	vmRepo repositories.VirtualMachine,
	serviceRepo repositories.Services,
	requiredResourcesRepo repositories.RequiredResources,
) *defaultUsageMonitor {
	monitor := new(defaultUsageMonitor)
	monitor.systemSpecs = systemSpecs
	monitor.vmRepo = vmRepo
	monitor.serviceRepo = serviceRepo
	monitor.requiredResourcesRepo = requiredResourcesRepo
	return monitor
}
// Compile-time check that *defaultUsageMonitor satisfies UsageMonitor.
var _ UsageMonitor = (*defaultUsageMonitor)(nil)
// GetUsage returns the resources used by the machine: the usage attributed to
// running VMs added to the usage attributed to running containers.
// Any failure while gathering either part is wrapped and returned.
func (um *defaultUsageMonitor) GetUsage(ctx context.Context) (types.Resources, error) {
	var empty types.Resources

	cpuDetails, err := um.systemSpecs.GetCPUInfo()
	if err != nil {
		return empty, fmt.Errorf("getting CPU info: %w", err)
	}

	fromVMs, err := um.getVMUsage(ctx, cpuDetails)
	if err != nil {
		return empty, fmt.Errorf("getting VM usage: %w", err)
	}

	fromContainers, err := um.getContainerUsage(ctx)
	if err != nil {
		return empty, fmt.Errorf("getting container usage: %w", err)
	}

	combined := fromVMs.Add(fromContainers)
	return combined, nil
}
// getVMUsage returns the total usage of all VMs whose State is "running".
// RAM is the sum of the VMs' MemSizeMib; CPU is the summed vCPU count
// multiplied by cpuInfo.MHzPerCore. With no running VMs the zero value is
// returned.
func (um *defaultUsageMonitor) getVMUsage(ctx context.Context, cpuInfo types.CPUInfo) (types.Resources, error) {
	runningQuery := um.vmRepo.GetQuery()
	runningQuery.Conditions = append(runningQuery.Conditions, repositories.EQ("State", "running"))

	running, err := um.vmRepo.FindAll(ctx, runningQuery)
	if err != nil {
		return types.Resources{}, fmt.Errorf("unable to get running VMs: %w", err)
	}

	var usage types.Resources
	if len(running) == 0 {
		return usage, nil
	}

	// TODO: disk usage
	var (
		vcpuSum uint
		memSum  uint
	)
	for i := range running {
		vcpuSum += running[i].VCPUCount
		memSum += running[i].MemSizeMib
	}

	usage.RAM = uint64(memSum)
	usage.CPU = float64(vcpuSum) * cpuInfo.MHzPerCore // CPU in MHz
	return usage, nil
}
// getContainerUsage returns the total usage of all services whose JobStatus is
// "running", computed by summing the CPU and RAM of the resource requirements
// recorded for each of them. With no running services the zero value is
// returned and the requirements are not loaded.
func (um *defaultUsageMonitor) getContainerUsage(ctx context.Context) (types.Resources, error) {
	runningQuery := um.serviceRepo.GetQuery()
	runningQuery.Conditions = append(runningQuery.Conditions, repositories.EQ("JobStatus", "running"))

	running, err := um.serviceRepo.FindAll(ctx, runningQuery)
	if err != nil {
		return types.Resources{}, fmt.Errorf("unable to get running containers: %w", err)
	}

	var usage types.Resources
	if len(running) == 0 {
		return usage, nil
	}

	byID, err := um.getResourceRequirements(ctx)
	if err != nil {
		return types.Resources{}, fmt.Errorf("unable to get resource requirements: %w", err)
	}

	// TODO: disk usage
	// NOTE(review): the lookup key is service.ResourceRequirements while the map
	// is keyed by JobID - confirm these refer to the same identifier. A missing
	// key contributes the zero value.
	for _, svc := range running {
		need := byID[svc.ResourceRequirements]
		usage.CPU += need.CPU
		usage.RAM += need.RAM
	}
	return usage, nil
}
// getResourceRequirements loads every stored resource requirement and indexes
// the records by their JobID. When several records share a JobID the last one
// returned by the repository wins.
func (um *defaultUsageMonitor) getResourceRequirements(ctx context.Context) (map[int]types.RequiredResources, error) {
	allQuery := um.requiredResourcesRepo.GetQuery()
	records, err := um.requiredResourcesRepo.FindAll(ctx, allQuery)
	if err != nil {
		return nil, fmt.Errorf("unable to query resource requirements: %w", err)
	}

	indexed := make(map[int]types.RequiredResources, len(records))
	for i := range records {
		indexed[records[i].JobID] = records[i]
	}
	return indexed, nil
}
package docker
import (
"bufio"
"context"
"encoding/json"
"fmt"
"io"
"os"
"strings"
"sync"
"time"
"github.com/docker/docker/api/types"
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/api/types/filters"
"github.com/docker/docker/api/types/network"
"github.com/docker/docker/client"
"github.com/docker/docker/pkg/jsonmessage"
"github.com/docker/docker/pkg/stdcopy"
v1 "github.com/opencontainers/image-spec/specs-go/v1"
"github.com/pkg/errors"
"go.uber.org/multierr"
)
// Client wraps the Docker client to provide high-level operations on Docker containers and networks.
type Client struct {
client *client.Client // Underlying Docker SDK client, held as a named (not embedded) field.
}
// NewDockerClient builds a Client whose underlying Docker SDK client is
// configured from the environment and negotiates the API version.
func NewDockerClient() (*Client, error) {
	sdk, err := client.NewClientWithOpts(
		client.FromEnv,
		client.WithAPIVersionNegotiation(),
	)
	if err != nil {
		return nil, err
	}

	wrapped := &Client{client: sdk}
	return wrapped, nil
}
// IsInstalled reports whether the Docker daemon answers a ping. Any ping
// error is treated as "not installed".
func (c *Client) IsInstalled(ctx context.Context) bool {
	if _, err := c.client.Ping(ctx); err != nil {
		return false
	}
	return true
}
// CreateContainer pulls config.Image and then creates (but does not start) a
// container with the given configuration. It returns the ID of the created
// container, or the pull/create error.
func (c *Client) CreateContainer(
	ctx context.Context,
	config *container.Config,
	hostConfig *container.HostConfig,
	networkingConfig *network.NetworkingConfig,
	platform *v1.Platform,
	name string,
) (string, error) {
	// The image must be available locally before the container can be created.
	if _, err := c.PullImage(ctx, config.Image); err != nil {
		return "", err
	}

	created, err := c.client.ContainerCreate(ctx, config, hostConfig, networkingConfig, platform, name)
	if err != nil {
		return "", err
	}
	return created.ID, nil
}
// InspectContainer asks the Docker daemon for the detailed description of the
// container identified by id and passes the answer through unchanged.
func (c *Client) InspectContainer(ctx context.Context, id string) (types.ContainerJSON, error) {
	details, err := c.client.ContainerInspect(ctx, id)
	return details, err
}
// FollowLogs tails the logs of a specified container, returning separate readers for stdout and stderr.
//
// The container is inspected first so that an unknown id fails early. The
// multiplexed log stream is then demultiplexed by a background goroutine into
// two pipes; both pipe writers and the log stream are closed when copying
// ends. Callers must consume both readers, since writes to a pipe block until
// they are read.
func (c *Client) FollowLogs(ctx context.Context, id string) (stdout, stderr io.Reader, err error) {
	cont, err := c.InspectContainer(ctx, id)
	if err != nil {
		return nil, nil, errors.Wrap(err, "failed to get container")
	}

	logOptions := types.ContainerLogsOptions{
		ShowStdout: true,
		ShowStderr: true,
		Follow:     true,
	}

	logsReader, err := c.client.ContainerLogs(ctx, cont.ID, logOptions)
	if err != nil {
		return nil, nil, errors.Wrap(err, "failed to get container logs")
	}

	stdoutReader, stdoutWriter := io.Pipe()
	stderrReader, stderrWriter := io.Pipe()

	go func() {
		stdoutBuffer := bufio.NewWriter(stdoutWriter)
		stderrBuffer := bufio.NewWriter(stderrWriter)
		defer func() {
			logsReader.Close()
			stdoutBuffer.Flush()
			stdoutWriter.Close()
			stderrBuffer.Flush()
			stderrWriter.Close()
		}()

		// Keep the copy error local to the goroutine: assigning to the named
		// result `err` here would race with the function's own return, which
		// has usually already happened by the time the copy finishes.
		_, copyErr := stdcopy.StdCopy(stdoutBuffer, stderrBuffer, logsReader)
		if copyErr != nil && !errors.Is(copyErr, context.Canceled) {
			zlog.Sugar().Warnf("context closed while getting logs: %v\n", copyErr)
		}
	}()

	return stdoutReader, stderrReader, nil
}
// StartContainer asks the Docker daemon to start the container identified by
// containerID, using default start options.
func (c *Client) StartContainer(ctx context.Context, containerID string) error {
	var defaults types.ContainerStartOptions
	return c.client.ContainerStart(ctx, containerID, defaults)
}
// WaitContainer waits until the container is no longer running. It hands back
// the result channel and the error channel produced by the Docker client.
func (c *Client) WaitContainer(
	ctx context.Context,
	containerID string,
) (<-chan container.ContainerWaitOKBody, <-chan error) {
	bodyCh, errCh := c.client.ContainerWait(ctx, containerID, container.WaitConditionNotRunning)
	return bodyCh, errCh
}
// StopContainer asks the Docker daemon to stop the container identified by
// containerID, passing timeout on to the daemon's stop call.
func (c *Client) StopContainer(
	ctx context.Context,
	containerID string,
	timeout time.Duration,
) error {
	grace := &timeout
	return c.client.ContainerStop(ctx, containerID, grace)
}
// RemoveContainer removes a Docker container. Removal is always forced and
// the volumes associated with the container are removed as well.
func (c *Client) RemoveContainer(ctx context.Context, containerID string) error {
	opts := types.ContainerRemoveOptions{
		RemoveVolumes: true,
		Force:         true,
	}
	return c.client.ContainerRemove(ctx, containerID, opts)
}
// removeContainers removes every container (running or not) that matches the
// given filters. Removals run concurrently, one goroutine per container, and
// all removal errors are combined into the returned error.
func (c *Client) removeContainers(ctx context.Context, filterz filters.Args) error {
	listOpts := types.ContainerListOptions{All: true, Filters: filterz}
	matched, err := c.client.ContainerList(ctx, listOpts)
	if err != nil {
		return err
	}

	// One buffer slot per container, so no worker ever blocks on its send.
	outcomes := make(chan error, len(matched))

	var pending sync.WaitGroup
	pending.Add(len(matched))
	for i := range matched {
		go func(id string) {
			defer pending.Done()
			outcomes <- c.RemoveContainer(ctx, id)
		}(matched[i].ID)
	}
	pending.Wait()
	close(outcomes)

	var combined error
	for removeErr := range outcomes {
		combined = multierr.Append(combined, removeErr)
	}
	return combined
}
// removeNetworks removes every network that matches the given filters.
// Removals run concurrently, one goroutine per network, and all removal
// errors are combined into the returned error.
func (c *Client) removeNetworks(ctx context.Context, filterz filters.Args) error {
	listOpts := types.NetworkListOptions{Filters: filterz}
	matched, err := c.client.NetworkList(ctx, listOpts)
	if err != nil {
		return err
	}

	// One buffer slot per network, so no worker ever blocks on its send.
	outcomes := make(chan error, len(matched))

	var pending sync.WaitGroup
	pending.Add(len(matched))
	for i := range matched {
		go func(id string) {
			defer pending.Done()
			outcomes <- c.client.NetworkRemove(ctx, id)
		}(matched[i].ID)
	}
	pending.Wait()
	close(outcomes)

	var combined error
	for removeErr := range outcomes {
		combined = multierr.Append(combined, removeErr)
	}
	return combined
}
// RemoveObjectsWithLabel removes all Docker containers and then all Docker
// networks carrying the label `label=value`. Networks are still attempted when
// container removal fails; both errors are combined in the result.
func (c *Client) RemoveObjectsWithLabel(ctx context.Context, label string, value string) error {
	selector := fmt.Sprintf("%s=%s", label, value)
	byLabel := filters.NewArgs(filters.Arg("label", selector))

	errContainers := c.removeContainers(ctx, byLabel)
	errNetworks := c.removeNetworks(ctx, byLabel)

	return multierr.Combine(errContainers, errNetworks)
}
// GetOutputStream opens the combined stdout/stderr log stream of a container.
// 'since' is the timestamp from which logs are returned and 'follow' keeps the
// stream open for logs produced later. The container must exist and be
// running, otherwise an error is returned. The caller owns the returned
// io.ReadCloser.
func (c *Client) GetOutputStream(
	ctx context.Context,
	containerID string,
	since string,
	follow bool,
) (io.ReadCloser, error) {
	info, err := c.InspectContainer(ctx, containerID)
	switch {
	case err != nil:
		return nil, errors.Wrap(err, "failed to get container")
	case !info.State.Running:
		return nil, fmt.Errorf("cannot get logs for a container that is not running")
	}

	stream, err := c.client.ContainerLogs(ctx, containerID, types.ContainerLogsOptions{
		ShowStdout: true,
		ShowStderr: true,
		Follow:     follow,
		Since:      since,
	})
	if err != nil {
		return nil, errors.Wrap(err, "failed to get container logs")
	}
	return stream, nil
}
// FindContainer lists all containers (running or not) and returns the ID of
// the first one whose label `label` equals `value`. An error is returned when
// listing fails or no container matches.
func (c *Client) FindContainer(ctx context.Context, label string, value string) (string, error) {
	listOpts := types.ContainerListOptions{All: true}
	all, err := c.client.ContainerList(ctx, listOpts)
	if err != nil {
		return "", err
	}

	for i := range all {
		if all[i].Labels[label] != value {
			continue
		}
		return all[i].ID, nil
	}
	return "", fmt.Errorf("unable to find container for %s=%s", label, value)
}
// PullImage pulls a Docker image from a registry.
//
// The pull progress stream is echoed to os.Stdout while it is decoded message
// by message. The returned string is the digest announced by the registry in
// a "Digest: ..." status line, or empty when no such line was seen. An error
// is returned when the pull cannot be started, when the stream cannot be
// decoded, or when the daemon reports an error in the stream.
func (c *Client) PullImage(ctx context.Context, imageName string) (string, error) {
	out, err := c.client.ImagePull(ctx, imageName, types.ImagePullOptions{})
	if err != nil {
		zlog.Sugar().Errorf("unable to pull image: %v", err)
		return "", err
	}
	defer out.Close()

	d := json.NewDecoder(io.TeeReader(out, os.Stdout))

	var digest string
	for {
		// Decode into a fresh message every time: the JSON decoder leaves
		// fields that are absent from the input untouched, so reusing one
		// value would carry Aux/Error/Status over from earlier messages.
		var message jsonmessage.JSONMessage
		if err := d.Decode(&message); err != nil {
			if err == io.EOF {
				break
			}
			zlog.Sugar().Errorf("unable pull image: %v", err)
			return "", err
		}

		if message.Aux != nil {
			continue
		}

		if message.Error != nil {
			zlog.Sugar().Errorf("unable pull image: %v", message.Error.Message)
			return "", errors.New(message.Error.Message)
		}

		if strings.HasPrefix(message.Status, "Digest") {
			digest = strings.TrimPrefix(message.Status, "Digest: ")
		}
	}
	return digest, nil
}
package docker
import (
"context"
"fmt"
"io"
"os"
"sync/atomic"
"time"
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/api/types/mount"
"gitlab.com/nunet/device-management-service/dms/resources"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils"
)
const (
// labelExecutorName is the label key whose value is the executor ID; Cleanup
// removes the Docker objects carrying it.
labelExecutorName = "nunet-executor"
// labelJobID and labelExecutionID are presumably the label keys for the job
// and execution IDs of a container - they are not used in this part of the
// file; verify against containerLabels.
labelJobID = "nunet-jobID"
labelExecutionID = "nunet-executionID"
// outputStreamCheckTickTime is how often GetLogStream looks for the handler
// of the requested execution.
outputStreamCheckTickTime = 100 * time.Millisecond
// outputStreamCheckTimeout is how long GetLogStream keeps looking before it
// reports the execution as not found.
outputStreamCheckTimeout = 5 * time.Second
)
// Executor manages the lifecycle of Docker containers for execution requests.
type Executor struct {
// ID identifies this executor; Cleanup uses it as the value of the
// labelExecutorName label when removing Docker objects.
ID string
handlers utils.SyncMap[string, *executionHandler] // Maps execution IDs to their handlers.
client *Client // Docker client for container management.
}
// NewExecutor creates an Executor with the given id, backed by a fresh Docker
// client. It fails when the client cannot be created or when the Docker
// daemon does not answer a ping.
func NewExecutor(ctx context.Context, id string) (*Executor, error) {
	docker, err := NewDockerClient()
	switch {
	case err != nil:
		return nil, err
	case !docker.IsInstalled(ctx):
		return nil, fmt.Errorf("docker is not installed")
	}

	executor := &Executor{ID: id, client: docker}
	return executor, nil
}
// Start begins the execution of a request by running a Docker container in
// the background.
//
// If a running container already exists for the job/execution (for example
// after a restart) it is reused. Otherwise a handler already registered for
// the execution ID causes an error ("already started" or "already
// completed"), and if there is none a new container is created. In both
// successful cases a new handler is registered under the execution ID and its
// run loop is started in a goroutine.
func (e *Executor) Start(ctx context.Context, request *types.ExecutionRequest) error {
	zlog.Sugar().
		Infof("Starting execution for job %s, execution %s", request.JobID, request.ExecutionID)

	containerID, err := e.FindRunningContainer(ctx, request.JobID, request.ExecutionID)
	if err != nil {
		// No running container: an existing handler means the execution was
		// already seen, otherwise a fresh container is needed.
		existing, known := e.handlers.Get(request.ExecutionID)
		switch {
		case known && existing.active():
			return fmt.Errorf("execution is already started")
		case known:
			return fmt.Errorf("execution is already completed")
		}

		containerID, err = e.newDockerExecutionContainer(ctx, request)
		if err != nil {
			return fmt.Errorf("failed to create new container: %w", err)
		}
	}

	execHandler := &executionHandler{
		client:      e.client,
		ID:          e.ID,
		executionID: request.ExecutionID,
		containerID: containerID,
		resultsDir:  request.ResultsDir,
		waitCh:      make(chan bool),
		activeCh:    make(chan bool),
		running:     &atomic.Bool{},
		TTYEnabled:  true,
	}

	// Register before running so that Wait/Cancel/GetLogStream can find it.
	e.handlers.Put(request.ExecutionID, execHandler)

	go execHandler.run(ctx)
	return nil
}
// Wait returns two channels, each with a buffer of one, on which the outcome
// of the execution identified by executionID is delivered: one for the result
// and one for an error.
//
// When no handler is registered for executionID, a "not found" error is put on
// the error channel immediately and neither channel is closed. Otherwise
// doWait is started in a goroutine to deliver the outcome. This lets callers
// pair Wait with an earlier call to Start.
func (e *Executor) Wait(
	ctx context.Context,
	executionID string,
) (<-chan *types.ExecutionResult, <-chan error) {
	handler, known := e.handlers.Get(executionID)

	results := make(chan *types.ExecutionResult, 1)
	failures := make(chan error, 1)

	if known {
		go e.doWait(ctx, results, failures, handler)
	} else {
		failures <- fmt.Errorf("execution (%s) not found", executionID)
	}
	return results, failures
}
// doWait blocks until either the handler signals completion on its wait
// channel or ctx is done, then reports exactly one outcome:
//   - the handler's result on out, when it completed with a non-nil result;
//   - an error on errCh, when it completed with a nil result (which points to
//     a flaw in the executor logic) or when ctx was cancelled first.
//
// Both channels are closed before doWait returns, errCh first.
func (e *Executor) doWait(
	ctx context.Context,
	out chan *types.ExecutionResult,
	errCh chan error,
	handler *executionHandler,
) {
	zlog.Sugar().Infof("executionID %s waiting for execution", handler.executionID)

	defer func() {
		close(errCh)
		close(out)
	}()

	select {
	case <-handler.waitCh:
		if handler.result == nil {
			errCh <- fmt.Errorf("execution (%s) result is nil", handler.executionID)
			return
		}
		zlog.Sugar().
			Infof("executionID %s received results from execution", handler.executionID)
		out <- handler.result
	case <-ctx.Done():
		// Relay the cancellation reason to the caller.
		errCh <- ctx.Err()
	}
}
// Cancel kills the execution registered under executionID and returns the
// outcome of that kill. An error is returned when no such execution is
// registered.
func (e *Executor) Cancel(ctx context.Context, executionID string) error {
	if handler, known := e.handlers.Get(executionID); known {
		return handler.kill(ctx)
	}
	return fmt.Errorf("failed to cancel execution (%s). execution not found", executionID)
}
// GetLogStream provides a stream of output logs for a specific execution.
// The request is handed to the execution's handler, which opens the stream.
//
// It's possible we've recorded the execution as running, but have not yet
// added the handler to the handler map because we're still waiting for the
// container to start. The handler map is therefore checked immediately and
// then every outputStreamCheckTickTime until outputStreamCheckTimeout has
// elapsed.
//
// It returns an error if the execution is not found within that time, or
// ctx.Err() if ctx is done while waiting.
func (e *Executor) GetLogStream(
	ctx context.Context,
	request types.LogStreamRequest,
) (io.ReadCloser, error) {
	// Polling happens in the calling goroutine. The previous helper goroutine
	// could block forever on its unbuffered send when the timeout fired at
	// the same moment, deadlocking the caller on the exit channel.
	ticker := time.NewTicker(outputStreamCheckTickTime)
	defer ticker.Stop()
	deadline := time.NewTimer(outputStreamCheckTimeout)
	defer deadline.Stop()

	for {
		if handler, found := e.handlers.Get(request.ExecutionID); found {
			return handler.outputStream(ctx, request)
		}

		select {
		case <-ticker.C:
			// Look again on the next iteration.
		case <-deadline.C:
			return nil, fmt.Errorf("execution (%s) not found", request.ExecutionID)
		case <-ctx.Done():
			return nil, ctx.Err()
		}
	}
}
// Run initiates and waits for the completion of an execution in one call.
// This method serves as a higher-level convenience function that
// internally calls Start and Wait methods.
// It returns the result of the execution or an error if either starting
// or waiting fails, or if the context is canceled.
func (e *Executor) Run(
	ctx context.Context,
	request *types.ExecutionRequest,
) (*types.ExecutionResult, error) {
	if err := e.Start(ctx, request); err != nil {
		return nil, err
	}

	resCh, errCh := e.Wait(ctx, request.ExecutionID)

	// Both channels are closed once the wait finishes, and only one of them
	// carries a value. A receive on the closed, empty one succeeds with a zero
	// value, so the `ok` flag is checked and the other channel consulted;
	// otherwise Run could return (nil, nil).
	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	case out, ok := <-resCh:
		if ok {
			return out, nil
		}
		if err, ok := <-errCh; ok {
			return nil, err
		}
	case err, ok := <-errCh:
		if ok {
			return nil, err
		}
		if out, ok := <-resCh; ok {
			return out, nil
		}
	}
	return nil, fmt.Errorf("execution (%s) finished without a result or an error", request.ExecutionID)
}
// Cleanup removes the Docker containers and networks labelled with this
// executor's ID (label key labelExecutorName). Any removal error is wrapped
// and returned.
func (e *Executor) Cleanup(ctx context.Context) error {
	if err := e.client.RemoveObjectsWithLabel(ctx, labelExecutorName, e.ID); err != nil {
		return fmt.Errorf("failed to remove containers: %w", err)
	}

	zlog.Info("Cleaned up all Docker resources")
	return nil
}
// newDockerExecutionContainer is an internal method called by Start to set up a new Docker container
// for the job execution. It decodes the engine spec, picks the vendor of the GPU with the most free
// VRAM, builds the container and host configuration (environment, mounts, resource constraints),
// pulls the image and creates the container without starting it.
// It returns the ID of the created container, or an error if any part of the setup fails.
func (e *Executor) newDockerExecutionContainer(
	ctx context.Context,
	params *types.ExecutionRequest,
) (string, error) {
	spec, err := DecodeSpec(params.EngineSpec)
	if err != nil {
		return "", fmt.Errorf("failed to decode docker engine spec: %w", err)
	}

	// TODO: Move the GPU selection below to the allocator in future
	// Select the GPU with the highest available free VRAM and choose the GPU vendor for container's host config
	availableGPUs, err := resources.ManagerInstance.SystemSpecs().GetGPUs()
	if err != nil {
		return "", fmt.Errorf("failed to get GPU info: %w", err)
	}
	bestGPU, err := types.GPUList(availableGPUs).GetGPUWithHighestFreeVRAM()
	if err != nil {
		return "", fmt.Errorf("failed to get GPU with highest free VRAM: %w", err)
	}

	// Essential for multi-vendor GPU nodes. For example,
	// if a machine has an 8 GB NVIDIA and a 16 GB Intel GPU, the latter should be used first.
	// Even for machines with a single GPU, this is important as integrated GPUs would also be commonly detected.
	vendor := bestGPU.Vendor

	cfg := container.Config{
		Image:      spec.Image,
		Tty:        true, // Needs to be true for applications such as Jupyter or Gradio to work correctly. See issue #459 for details.
		Env:        spec.Environment,
		Entrypoint: spec.Entrypoint,
		Cmd:        spec.Cmd,
		Labels:     e.containerLabels(params.JobID, params.ExecutionID),
		WorkingDir: spec.WorkingDirectory,
	}

	volumeMounts, err := makeContainerMounts(params.Inputs, params.Outputs, params.ResultsDir)
	if err != nil {
		return "", fmt.Errorf("failed to create container mounts: %w", err)
	}

	zlog.Sugar().Infof("Adding %d GPUs to request", len(params.Resources.GPUs))

	hostCfg := configureHostConfig(vendor, params, volumeMounts)

	// Pull up front so that a pull failure is reported as such; CreateContainer
	// pulls the image again before creating the container.
	if _, err = e.client.PullImage(ctx, spec.Image); err != nil {
		return "", fmt.Errorf("failed to pull docker image: %w", err)
	}

	containerName := labelExecutionValue(e.ID, params.JobID, params.ExecutionID)
	createdID, err := e.client.CreateContainer(ctx, &cfg, &hostCfg, nil, nil, containerName)
	if err != nil {
		return "", fmt.Errorf("failed to create container: %w", err)
	}
	return createdID, nil
}
// configureHostConfig builds the container host configuration for an
// execution. Every configuration carries the given mounts plus the NanoCPUs
// and CPUCount values taken from params.Resources.CPU; the GPU vendor then
// decides which device settings are added on top:
//   - NVIDIA: a "gpu" device request listing the Index of each requested GPU;
//   - AMD: /dev/kfd and /dev/dri bound and mapped, plus the "video" group;
//   - Intel: /dev/dri bound and mapped;
//   - anything else: no GPU settings (CPU-only).
func configureHostConfig(vendor types.GPUVendor, params *types.ExecutionRequest, mounts []mount.Mount) container.HostConfig {
	// passthrough maps a host device node to the same path inside the container.
	passthrough := func(path string) container.DeviceMapping {
		return container.DeviceMapping{
			PathOnHost:        path,
			PathInContainer:   path,
			CgroupPermissions: "rwm",
		}
	}

	// Settings shared by every vendor.
	hostConfig := container.HostConfig{
		Mounts: mounts,
		Resources: container.Resources{
			NanoCPUs: params.Resources.CPU.ClockSpeedHz,
			CPUCount: int64(params.Resources.CPU.Cores),
		},
	}

	switch vendor {
	case types.GPUVendorNvidia:
		ids := make([]string, len(params.Resources.GPUs))
		for pos, requested := range params.Resources.GPUs {
			ids[pos] = fmt.Sprint(requested.Index)
		}
		hostConfig.Resources.DeviceRequests = []container.DeviceRequest{
			{
				DeviceIDs:    ids,
				Capabilities: [][]string{{"gpu"}},
			},
		}

	case types.GPUVendorAMDATI:
		hostConfig.Binds = []string{
			"/dev/kfd:/dev/kfd",
			"/dev/dri:/dev/dri",
		}
		hostConfig.Resources.Devices = []container.DeviceMapping{
			passthrough("/dev/kfd"),
			passthrough("/dev/dri"),
		}
		hostConfig.GroupAdd = []string{"video"}

	// Updated the device handling for Intel GPUs.
	// Previously, specific device paths were determined using PCI addresses and symlinks.
	// Now, the approach has been simplified by directly binding the entire /dev/dri directory.
	// This change exposes all Intel GPUs to the container, which may be preferable for
	// environments with multiple Intel GPUs. It reduces complexity as granular control
	// is not required if all GPUs need to be accessible.
	case types.GPUVendorIntel:
		hostConfig.Binds = []string{
			"/dev/dri:/dev/dri",
		}
		hostConfig.Resources.Devices = []container.DeviceMapping{
			passthrough("/dev/dri"),
		}
	}

	return hostConfig
}
// makeContainerMounts creates the mounts for the container based on the input and output
// volumes provided in the execution request.
//
// Inputs must be bind volumes (types.StorageVolumeTypeBind); each one becomes a bind mount
// that keeps the volume's read-only flag. Any other input type is rejected with an
// "unsupported storage volume type" error.
//
// Each output becomes a writable bind mount. An output needs a non-empty source and a
// non-empty resultsDir; resultsDir is created if it does not exist. When there are no
// outputs, resultsDir is neither validated nor created.
//
// It returns the list of mounts (empty but non-nil when there is nothing to mount), or an
// error if any part of the process fails.
func makeContainerMounts(
	inputs []*types.StorageVolumeExecutor,
	outputs []*types.StorageVolumeExecutor,
	resultsDir string,
) ([]mount.Mount, error) {
	// the actual mounts we will give to the container
	// these are paths for both input and output data
	mounts := make([]mount.Mount, 0, len(inputs)+len(outputs))

	for _, input := range inputs {
		// Only bind volumes can be expressed as bind mounts. The previous check was
		// inverted: it rejected bind volumes and mounted every other type.
		if input.Type != types.StorageVolumeTypeBind {
			return nil, fmt.Errorf("unsupported storage volume type: %s", input.Type)
		}
		mounts = append(mounts, mount.Mount{
			Type:     mount.TypeBind,
			Source:   input.Source,
			Target:   input.Target,
			ReadOnly: input.ReadOnly,
		})
	}

	for _, output := range outputs {
		if output.Source == "" {
			return nil, fmt.Errorf("output source is empty")
		}
		if resultsDir == "" {
			return nil, fmt.Errorf("results directory is empty")
		}
		if err := os.MkdirAll(resultsDir, os.ModePerm); err != nil {
			return nil, fmt.Errorf("failed to create results directory: %w", err)
		}
		mounts = append(mounts, mount.Mount{
			Type:   mount.TypeBind,
			Source: output.Source,
			Target: output.Target,
			// this is an output volume so can be written to
			ReadOnly: false,
		})
	}

	return mounts, nil
}
// containerLabels builds the label set attached to a container for the given
// job and execution: the executor ID, the job label value and the execution
// label value.
func (e *Executor) containerLabels(jobID string, executionID string) map[string]string {
	labels := make(map[string]string, 3)
	labels[labelExecutorName] = e.ID
	labels[labelJobID] = labelJobValue(e.ID, jobID)
	labels[labelExecutionID] = labelExecutionValue(e.ID, jobID, executionID)
	return labels
}
// labelJobValue builds the job label value: the executor ID and the job ID
// joined by an underscore.
func labelJobValue(executorID string, jobID string) string {
	const layout = "%s_%s"
	value := fmt.Sprintf(layout, executorID, jobID)
	return value
}
// labelExecutionValue builds the execution label value: the executor ID, job
// ID and execution ID joined by underscores.
func labelExecutionValue(executorID string, jobID string, executionID string) string {
	const layout = "%s_%s_%s"
	value := fmt.Sprintf(layout, executorID, jobID, executionID)
	return value
}
// FindRunningContainer looks up the container for the given job and execution
// by its execution label. It returns whatever the Docker client's
// FindContainer returns for that label: the container ID, or an error.
func (e *Executor) FindRunningContainer(
	ctx context.Context,
	jobID string,
	executionID string,
) (string, error) {
	return e.client.FindContainer(
		ctx,
		labelExecutionID,
		labelExecutionValue(e.ID, jobID, executionID),
	)
}
package docker
import (
"bufio"
"context"
"errors"
"fmt"
"io"
"strconv"
"sync/atomic"
"time"
"gitlab.com/nunet/device-management-service/types"
)
// DestroyTimeout is the time budget used when tearing down a container: run passes it to
// destroy as the cleanup deadline, and kill passes it to StopContainer.
const DestroyTimeout = time.Second * 10
// executionHandler manages the lifecycle and execution of a Docker container for a specific job.
// One handler tracks a single container (containerID); run records the outcome in result.
type executionHandler struct {
// provided by the executor
ID string // Executor ID; combined with jobID and executionID to build the execution label.
client *Client // Docker client for container management.
// meta data about the task
jobID string
executionID string
containerID string // ID of the container that run starts, waits on and destroys.
resultsDir string // Directory to store execution results.
// synchronization
activeCh chan bool // Closed by run once the container has started; outputStream waits on it.
waitCh chan bool // Closed when run returns, i.e. when execution completes or fails.
running *atomic.Bool // Indicates if the container is currently running.
// result of the execution
result *types.ExecutionResult
// TTY setting
TTYEnabled bool // Indicates if TTY is enabled for the container.
}
// active reports whether run is currently in progress for this handler.
// The flag is set at the start of run and cleared when run returns.
func (h *executionHandler) active() bool {
return h.running.Load()
}
// run starts the container and drives it to completion, recording the outcome in h.result.
//
// The handler is marked running for the duration of the call. When run returns, for any
// reason, the container is destroyed (bounded by DestroyTimeout), the running flag is
// cleared and waitCh is closed so that waiters can read h.result. activeCh is closed only
// after the container has started successfully.
//
// Outcomes recorded in h.result:
//   - start failure, wait failure or context cancellation: a failed execution result;
//   - inspect failure or OOM kill: the exit code plus an error message;
//   - log-follow failure: the exit code plus the log error (and the container error, if any);
//   - otherwise: the exit code with the captured STDOUT/STDERR.
func (h *executionHandler) run(ctx context.Context) {
	h.running.Store(true)
	defer func() {
		if err := h.destroy(DestroyTimeout); err != nil {
			zlog.Sugar().Warnf("failed to destroy container: %v\n", err)
		}
		h.running.Store(false)
		close(h.waitCh)
	}()

	if err := h.client.StartContainer(ctx, h.containerID); err != nil {
		h.result = types.NewFailedExecutionResult(fmt.Errorf("failed to start container: %v", err))
		return
	}
	close(h.activeCh) // Indicate that the container has started.

	var containerError error
	var containerExitStatusCode int64

	// Wait for the container to finish or for an execution error.
	statusCh, errCh := h.client.WaitContainer(ctx, h.containerID)
	select {
	case <-ctx.Done():
		// The Done channel only yields an empty struct; ctx.Err() carries the actual
		// reason (cancelled or deadline exceeded).
		h.result = types.NewFailedExecutionResult(fmt.Errorf("execution cancelled: %v", ctx.Err()))
		return
	case err := <-errCh:
		zlog.Sugar().Errorf("error while waiting for container: %v\n", err)
		h.result = types.NewFailedExecutionResult(
			fmt.Errorf("failed to wait for container: %v", err),
		)
		return
	case exitStatus := <-statusCh:
		containerExitStatusCode = exitStatus.StatusCode
		containerJSON, err := h.client.InspectContainer(ctx, h.containerID)
		if err != nil {
			h.result = &types.ExecutionResult{
				ExitCode: int(containerExitStatusCode),
				ErrorMsg: err.Error(),
			}
			return
		}
		if containerJSON.ContainerJSONBase.State.OOMKilled {
			containerError = errors.New("container was killed due to OOM")
			h.result = &types.ExecutionResult{
				ExitCode: int(containerExitStatusCode),
				ErrorMsg: containerError.Error(),
			}
			return
		}
		if exitStatus.Error != nil {
			containerError = errors.New(exitStatus.Error.Message)
		}
	}

	// Follow container logs to capture stdout and stderr.
	stdoutPipe, stderrPipe, logsErr := h.client.FollowLogs(ctx, h.containerID)
	if logsErr != nil {
		followError := fmt.Errorf("failed to follow container logs: %w", logsErr)
		if containerError != nil {
			h.result = &types.ExecutionResult{
				ExitCode: int(containerExitStatusCode),
				ErrorMsg: fmt.Sprintf(
					"container error: '%s'. logs error: '%s'",
					containerError,
					followError,
				),
			}
		} else {
			h.result = &types.ExecutionResult{
				ExitCode: int(containerExitStatusCode),
				ErrorMsg: followError.Error(),
			}
		}
		return
	}

	// Initialize the result with the exit status code.
	h.result = types.NewExecutionResult(int(containerExitStatusCode))

	// Capture the logs based on the TTY setting. ReadString reads up to the first NUL
	// byte or the end of the stream; its error (io.EOF at end of stream) is deliberately
	// ignored and whatever was read is kept.
	if h.TTYEnabled {
		// TTY combines stdout and stderr, read from stdoutPipe only.
		h.result.STDOUT, _ = bufio.NewReader(stdoutPipe).ReadString('\x00') // EOF delimiter
	} else {
		// Read from stdout and stderr separately.
		h.result.STDOUT, _ = bufio.NewReader(stdoutPipe).ReadString('\x00') // EOF delimiter
		h.result.STDERR, _ = bufio.NewReader(stderrPipe).ReadString('\x00')
	}
}
// kill asks the Docker client to stop the handler's container, passing DestroyTimeout
// as the stop timeout. It returns the client's error, if any.
func (h *executionHandler) kill(ctx context.Context) error {
return h.client.StopContainer(ctx, h.containerID, DestroyTimeout)
}
// destroy tears down the container and everything labelled for this execution.
//
// All steps share one background context bounded by timeout. The container is
// stopped, then removed, and finally any remaining objects carrying this
// execution's label are removed. The first failing step aborts the sequence
// and its error is returned.
func (h *executionHandler) destroy(timeout time.Duration) error {
	cleanupCtx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	// stop the container
	killErr := h.kill(cleanupCtx)
	if killErr != nil {
		return fmt.Errorf("failed to kill container (%s): %w", h.containerID, killErr)
	}

	if removeErr := h.client.RemoveContainer(cleanupCtx, h.containerID); removeErr != nil {
		return removeErr
	}

	// Remove related objects like networks or volumes created for this execution.
	executionLabel := labelExecutionValue(h.ID, h.jobID, h.executionID)
	return h.client.RemoveObjectsWithLabel(cleanupCtx, labelExecutionID, executionLabel)
}
// outputStream returns a reader over the container's output.
//
// When request.Tail is set, only output produced after the moment of the call
// is requested; otherwise the "since" value is "1" so that all logs are
// returned. The call blocks until the container has started (activeCh closed)
// or ctx is done, in which case ctx.Err() is returned. request.Follow is
// passed through to the Docker client.
func (h *executionHandler) outputStream(
	ctx context.Context,
	request types.LogStreamRequest,
) (io.ReadCloser, error) {
	// "1" is effectively the start of UNIX time, i.e. all logs.
	sinceTimestamp := "1"
	if request.Tail {
		now := time.Now().Unix()
		sinceTimestamp = strconv.FormatInt(now, 10)
	}

	// Ensure the container is active before attempting to stream logs.
	select {
	case <-h.activeCh:
	case <-ctx.Done():
		return nil, ctx.Err()
	}

	// Gets the underlying reader, and provides data since the `since` timestamp.
	return h.client.GetOutputStream(ctx, h.containerID, sinceTimestamp, request.Follow)
}
package docker
import (
"gitlab.com/nunet/device-management-service/telemetry/logger"
)
// zlog is the package-level logger for the docker executor.
var zlog *logger.Logger
// init creates the package logger under the name "docker.executor".
func init() {
zlog = logger.New("docker.executor")
}
package docker
import (
"encoding/json"
"fmt"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils/validate"
)
// Parameter keys for a docker engine spec. The values match the JSON field names of
// EngineSpec; the builder methods below pass them to SpecConfig.WithParam.
const (
EngineKeyImage = "image"
EngineKeyEntrypoint = "entrypoint"
EngineKeyCmd = "cmd"
EngineKeyEnvironment = "environment"
EngineKeyWorkingDirectory = "working_directory"
)
// EngineSpec contains necessary parameters to execute a docker job.
// Only Image is required by Validate; every other field is optional.
type EngineSpec struct {
// Image this should be pullable by docker
Image string `json:"image,omitempty"`
// Entrypoint optionally override the default entrypoint
Entrypoint []string `json:"entrypoint,omitempty"`
// Cmd specifies the command to run in the container
Cmd []string `json:"cmd,omitempty"`
// Environment is a slice of env to run the container with
Environment []string `json:"environment,omitempty"`
// WorkingDirectory inside the container
WorkingDirectory string `json:"working_directory,omitempty"`
}
// Validate reports an error when the spec cannot be used to run a container.
// The only requirement is a non-blank Image.
func (c EngineSpec) Validate() error {
	if imageMissing := validate.IsBlank(c.Image); imageMissing {
		return fmt.Errorf("invalid docker engine params: image cannot be empty")
	}

	return nil
}
// DecodeSpec decodes a spec config into a docker engine spec.
//
// The spec must be of type types.ExecutorTypeDocker and carry non-nil params. The params
// are round-tripped through JSON into an EngineSpec, which is then validated.
//
// It returns the decoded spec together with the validation result, so on a validation
// failure the returned spec is still populated but must not be used. All other failures
// return a zero EngineSpec.
func DecodeSpec(spec *types.SpecConfig) (EngineSpec, error) {
	if !spec.IsType(types.ExecutorTypeDocker) {
		return EngineSpec{}, fmt.Errorf(
			"invalid docker engine type. expected %s, but received: %s",
			types.ExecutorTypeDocker,
			spec.Type,
		)
	}

	inputParams := spec.Params
	if inputParams == nil {
		return EngineSpec{}, fmt.Errorf("invalid docker engine params: params cannot be nil")
	}

	paramBytes, err := json.Marshal(inputParams)
	if err != nil {
		return EngineSpec{}, fmt.Errorf("failed to encode docker engine params: %w", err)
	}

	// Decode into a value rather than a nil pointer: params that encode to JSON null
	// would leave a pointer nil, and dereferencing it would panic. With a value the
	// spec stays zero and Validate reports the missing image instead.
	var dockerSpec EngineSpec
	if err := json.Unmarshal(paramBytes, &dockerSpec); err != nil {
		return EngineSpec{}, fmt.Errorf("failed to decode docker engine params: %w", err)
	}

	return dockerSpec, dockerSpec.Validate()
}
// EngineBuilder is a struct that is used for constructing a SpecConfig
// specifically for Docker engines using the Builder pattern.
// It wraps a *types.SpecConfig; each With* method sets one param on it and
// Build returns it.
type EngineBuilder struct {
eb *types.SpecConfig // spec config under construction
}
// NewDockerEngineBuilder returns a builder for a Docker engine spec.
// The underlying spec config is created with types.ExecutorTypeDocker and its
// image param is set to the given image.
func NewDockerEngineBuilder(image string) *EngineBuilder {
	builder := &EngineBuilder{
		eb: types.NewSpecConfig(types.ExecutorTypeDocker),
	}
	builder.eb.WithParam(EngineKeyImage, image)
	return builder
}
// WithEntrypoint is a builder method that sets the Docker engine entrypoint.
// The variadic arguments are stored as a single slice under EngineKeyEntrypoint.
// It returns the same EngineBuilder for further chaining of builder methods.
func (b *EngineBuilder) WithEntrypoint(e ...string) *EngineBuilder {
b.eb.WithParam(EngineKeyEntrypoint, e)
return b
}
// WithCmd is a builder method that sets the Docker engine's Command.
// The variadic arguments are stored as a single slice under EngineKeyCmd.
// It returns the same EngineBuilder for further chaining of builder methods.
func (b *EngineBuilder) WithCmd(c ...string) *EngineBuilder {
b.eb.WithParam(EngineKeyCmd, c)
return b
}
// WithEnvironment is a builder method that sets the Docker engine's environment variables.
// The variadic arguments are stored as a single slice under EngineKeyEnvironment.
// It returns the same EngineBuilder for further chaining of builder methods.
func (b *EngineBuilder) WithEnvironment(e ...string) *EngineBuilder {
b.eb.WithParam(EngineKeyEnvironment, e)
return b
}
// WithWorkingDirectory is a builder method that sets the Docker engine's working directory.
// The value is stored under EngineKeyWorkingDirectory.
// It returns the same EngineBuilder for further chaining of builder methods.
func (b *EngineBuilder) WithWorkingDirectory(w string) *EngineBuilder {
b.eb.WithParam(EngineKeyWorkingDirectory, w)
return b
}
// Build returns the SpecConfig assembled so far. The builder's own pointer is returned,
// not a copy, so later With* calls on the builder also affect the returned value.
func (b *EngineBuilder) Build() *types.SpecConfig {
return b.eb
}
package firecracker
import (
"context"
"fmt"
"os"
"syscall"
"time"
"github.com/firecracker-microvm/firecracker-go-sdk"
)
// pidCheckTickTime is how often DestroyVM re-checks whether the Firecracker process
// is still alive after a shutdown request.
const pidCheckTickTime = 100 * time.Millisecond
// Client wraps the Firecracker SDK to provide high-level operations on Firecracker VMs.
// It holds no state of its own; every method operates on the machine or path passed to it.
type Client struct{}
// NewFirecrackerClient returns a new, empty Client.
// The returned error is always nil in the current implementation.
func NewFirecrackerClient() (*Client, error) {
return &Client{}, nil
}
// IsInstalled reports whether the Firecracker binary can be run on the host.
// It runs the binary with "--version" under a two-second timeout and returns
// true only when the command succeeds and prints something.
func (c *Client) IsInstalled(ctx context.Context) bool {
	probeCtx, cancel := context.WithTimeout(ctx, 2*time.Second)
	defer cancel()

	versionCmd := firecracker.VMCommandBuilder{}.
		WithArgs([]string{"--version"}).
		Build(probeCtx)

	output, err := versionCmd.Output()
	if err != nil {
		return false
	}
	if !versionCmd.ProcessState.Success() {
		return false
	}
	return len(output) > 0
}
// CreateVM constructs a Firecracker machine for the given configuration.
// The process runner is a Firecracker command bound to cfg.SocketPath. The
// machine is only constructed here, not started; see StartVM.
func (c *Client) CreateVM(
	ctx context.Context,
	cfg firecracker.Config,
) (*firecracker.Machine, error) {
	runner := firecracker.VMCommandBuilder{}.WithSocketPath(cfg.SocketPath).Build(ctx)
	return firecracker.NewMachine(ctx, cfg, firecracker.WithProcessRunner(runner))
}
// StartVM starts the Firecracker VM by calling Start on the machine and returns its error.
func (c *Client) StartVM(ctx context.Context, m *firecracker.Machine) error {
return m.Start(ctx)
}
// ShutdownVM shuts down the Firecracker VM by calling Shutdown on the machine and
// returns its error. It does not wait for the process to exit; DestroyVM does that.
func (c *Client) ShutdownVM(ctx context.Context, m *firecracker.Machine) error {
return m.Shutdown(ctx)
}
// DestroyVM shuts the VM down and makes sure its process is gone.
//
// A shutdown request is sent first; if that fails the error is returned and
// nothing else happens. Otherwise the socket file is removed when the function
// returns. If the machine reports no PID the VM is considered stopped already.
// If it does, the PID is polled every pidCheckTickTime until it disappears;
// should it still be present after timeout, the process is sent SIGKILL.
func (c *Client) DestroyVM(
	ctx context.Context,
	m *firecracker.Machine,
	timeout time.Duration,
) error {
	if shutdownErr := c.ShutdownVM(ctx, m); shutdownErr != nil {
		return fmt.Errorf("failed to shutdown vm: %w", shutdownErr)
	}

	// The PID lookup error is ignored: a non-positive PID is treated as "not running".
	pid, _ := m.PID()
	defer os.Remove(m.Cfg.SocketPath)

	// If the process is not running, return early.
	if pid <= 0 {
		return nil
	}

	poll := time.NewTicker(pidCheckTickTime)
	defer poll.Stop()
	deadline := time.NewTimer(timeout)
	defer deadline.Stop()

	for {
		select {
		case <-deadline.C:
			// The shutdown request timed out, kill the process with SIGKILL.
			if killErr := syscall.Kill(pid, syscall.SIGKILL); killErr != nil {
				return fmt.Errorf("failed to kill process: %v", killErr)
			}
			return nil
		case <-poll.C:
			if current, _ := m.PID(); current <= 0 {
				// The process exited on its own.
				return nil
			}
		}
	}
}
// FindVM finds a running Firecracker VM by its socket path.
//
// The socket file must exist. A machine handle bound to that socket is then created and
// the Firecracker API is asked for the instance info. The machine is returned only when
// the reported state is "Running"; a missing socket, an API failure, a missing state or
// any other state yields an error.
func (c *Client) FindVM(ctx context.Context, socketPath string) (*firecracker.Machine, error) {
	// Check if the socket file exists.
	if _, err := os.Stat(socketPath); err != nil {
		return nil, fmt.Errorf("VM with socket path %v not found", socketPath)
	}

	// Create a new Firecracker machine instance.
	cmd := firecracker.VMCommandBuilder{}.WithSocketPath(socketPath).Build(ctx)
	machine, err := firecracker.NewMachine(
		ctx,
		firecracker.Config{SocketPath: socketPath},
		firecracker.WithProcessRunner(cmd),
	)
	if err != nil {
		return nil, fmt.Errorf("failed to create machine with socket %s: %w", socketPath, err)
	}

	// Check if the VM is running by getting its instance info.
	info, err := machine.DescribeInstanceInfo(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to get instance info for socket %s: %w", socketPath, err)
	}

	// State is a pointer in the API model; guard against a response without it so a
	// malformed reply produces an error instead of a nil-pointer panic.
	if info.State == nil {
		return nil, fmt.Errorf("VM with socket %s did not report a state", socketPath)
	}
	if *info.State != "Running" {
		return nil, fmt.Errorf(
			"VM with socket %s is not running, current state: %s",
			socketPath,
			*info.State,
		)
	}

	return machine, nil
}
package firecracker
import (
"context"
"fmt"
"io"
"os"
"sync"
"sync/atomic"
"time"
"github.com/firecracker-microvm/firecracker-go-sdk"
fcModels "github.com/firecracker-microvm/firecracker-go-sdk/client/models"
"go.uber.org/multierr"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils"
)
// socketDir is the directory in which generateSocketPath places the VM socket files.
const (
socketDir = "/tmp"
)
// Executor manages the lifecycle of Firecracker VMs for execution requests.
type Executor struct {
ID string // Executor ID; part of every socket path generated for this executor.
handlers utils.SyncMap[string, *executionHandler] // Maps execution IDs to their handlers.
client *Client // Firecracker client for VM management.
}
// NewExecutor creates a Firecracker executor with the given ID.
// It fails when the client cannot be created or when the Firecracker binary
// is not available on the host.
func NewExecutor(ctx context.Context, id string) (*Executor, error) {
	client, err := NewFirecrackerClient()
	if err != nil {
		return nil, err
	}

	if installed := client.IsInstalled(ctx); !installed {
		return nil, fmt.Errorf("firecracker is not installed")
	}

	return &Executor{ID: id, client: client}, nil
}
// Start begins the execution of a request in a Firecracker VM.
//
// If a running VM already exists for the job/execution pair (for example after
// a restart) it is reused. Otherwise, an existing handler for the execution ID
// means the execution was already started or already completed, and an error
// is returned; with no handler, a new VM is created. A fresh handler is then
// registered under the execution ID and run in its own goroutine.
func (e *Executor) Start(ctx context.Context, request *types.ExecutionRequest) error {
	zlog.Sugar().
		Infof("Starting execution for job %s, execution %s", request.JobID, request.ExecutionID)

	vm, findErr := e.FindRunningVM(ctx, request.JobID, request.ExecutionID)
	if findErr != nil {
		// No running VM for this execution: refuse duplicates, otherwise create one.
		if existing, known := e.handlers.Get(request.ExecutionID); known {
			if existing.active() {
				return fmt.Errorf("execution is already started")
			}
			return fmt.Errorf("execution is already completed")
		}

		created, createErr := e.newFirecrackerExecutionVM(ctx, request)
		if createErr != nil {
			return fmt.Errorf("failed to create new firecracker VM: %w", createErr)
		}
		vm = created
	}

	newHandler := &executionHandler{
		ID:          e.ID,
		client:      e.client,
		executionID: request.ExecutionID,
		machine:     vm,
		resultsDir:  request.ResultsDir,
		activeCh:    make(chan bool),
		waitCh:      make(chan bool),
		running:     &atomic.Bool{},
	}

	// register the handler for this executionID, then run the VM.
	e.handlers.Put(request.ExecutionID, newHandler)
	go newHandler.run(ctx)

	return nil
}
// Wait returns two channels on which the outcome of the execution identified
// by executionID is delivered: one for the result, one for an error. Both are
// buffered with capacity one.
//
// If no handler is registered for executionID, a "not found" error is placed
// on the error channel right away; in that case neither channel is closed.
// Otherwise doWait is started in a goroutine; it delivers the outcome on one
// of the channels and then closes both.
func (e *Executor) Wait(
	ctx context.Context,
	executionID string,
) (<-chan *types.ExecutionResult, <-chan error) {
	resultCh := make(chan *types.ExecutionResult, 1)
	errCh := make(chan error, 1)

	handler, found := e.handlers.Get(executionID)
	if found {
		go e.doWait(ctx, resultCh, errCh, handler)
		return resultCh, errCh
	}

	errCh <- fmt.Errorf("execution (%s) not found", executionID)
	return resultCh, errCh
}
// doWait blocks until the handler's execution finishes or ctx is done, then
// reports the outcome on exactly one of the two channels.
//
// If ctx is done first, ctx.Err() is sent on errCh. If the handler's wait
// channel is closed first, the handler's result is sent on out; a nil result
// is reported on errCh instead, since it points to a flaw in the executor
// logic. Both channels are closed before returning (errCh first, then out).
func (e *Executor) doWait(
	ctx context.Context,
	out chan *types.ExecutionResult,
	errCh chan error,
	handler *executionHandler,
) {
	zlog.Sugar().Infof("executionID %s waiting for execution", handler.executionID)

	defer func() {
		close(errCh)
		close(out)
	}()

	select {
	case <-handler.waitCh:
		if handler.result == nil {
			errCh <- fmt.Errorf("execution (%s) result is nil", handler.executionID)
			return
		}
		zlog.Sugar().
			Infof("executionID %s received results from execution", handler.executionID)
		out <- handler.result
	case <-ctx.Done():
		// Send the cancellation error to the error channel
		errCh <- ctx.Err()
	}
}
// Cancel stops the execution identified by executionID by killing its handler.
// It returns an error when no handler is registered for that ID, or whatever
// error the kill produces.
func (e *Executor) Cancel(ctx context.Context, executionID string) error {
	if handler, found := e.handlers.Get(executionID); found {
		return handler.kill(ctx)
	}

	return fmt.Errorf("failed to cancel execution (%s). execution not found", executionID)
}
// Run initiates and waits for the completion of an execution in one call.
// This method serves as a higher-level convenience function that
// internally calls Start and Wait methods.
// It returns the result of the execution or an error if either starting
// or waiting fails, or if the context is canceled.
func (e *Executor) Run(
	ctx context.Context,
	request *types.ExecutionRequest,
) (*types.ExecutionResult, error) {
	if err := e.Start(ctx, request); err != nil {
		return nil, err
	}

	resCh, errCh := e.Wait(ctx, request.ExecutionID)

	// doWait delivers the outcome on one channel and then closes both. A closed, empty
	// channel is always ready to receive, so a single select could pick it and return
	// (nil, nil). Check the ok flag and keep selecting until the channel that actually
	// carries the outcome is read; a nil channel is never selected.
	for resCh != nil || errCh != nil {
		select {
		case <-ctx.Done():
			return nil, ctx.Err()
		case out, ok := <-resCh:
			if ok {
				return out, nil
			}
			resCh = nil
		case err, ok := <-errCh:
			if ok {
				return nil, err
			}
			errCh = nil
		}
	}

	// Both channels were closed without delivering anything.
	return nil, fmt.Errorf("execution (%s) finished without a result", request.ExecutionID)
}
// GetLogStream is not implemented for Firecracker.
// It is defined to satisfy the Executor interface.
// Both arguments are ignored; the method always returns a nil reader and an error.
func (e *Executor) GetLogStream(_ context.Context, _ types.LogStreamRequest) (io.ReadCloser, error) {
return nil, fmt.Errorf("GetLogStream is not implemented for Firecracker")
}
// Cleanup destroys the VM of every registered handler.
//
// Each handler is destroyed in its own goroutine with a ten-second timeout.
// The individual errors are collected and returned combined; the result is
// nil when every destroy succeeded. The context argument is not used.
func (e *Executor) Cleanup(_ context.Context) error {
	var wg sync.WaitGroup
	results := make(chan error, len(e.handlers.Keys()))

	e.handlers.Iter(func(_ string, handler *executionHandler) bool {
		wg.Add(1)
		go func() {
			defer wg.Done()
			results <- handler.destroy(time.Second * 10)
		}()
		return true
	})

	// Close the results channel once every destroy has reported, so the
	// collection loop below terminates.
	go func() {
		wg.Wait()
		close(results)
	}()

	var combined error
	for destroyErr := range results {
		combined = multierr.Append(combined, destroyErr)
	}

	zlog.Info("Cleaned up all firecracker resources")
	return combined
}
// newFirecrackerExecutionVM prepares, but does not start, a Firecracker VM for
// the given execution request. Start calls it when no running VM exists.
//
// The engine spec is decoded for the kernel image, initrd, kernel args and
// root file system. CPU cores and memory size come from params.Resources, the
// socket path is derived from the job and execution IDs, and the drives are
// built from the root file system plus the request's inputs and outputs. The
// decoded MMDS message is currently not passed to the VM.
func (e *Executor) newFirecrackerExecutionVM(
	ctx context.Context,
	params *types.ExecutionRequest,
) (*firecracker.Machine, error) {
	spec, err := DecodeSpec(params.EngineSpec)
	if err != nil {
		return nil, fmt.Errorf("failed to decode firecracker engine spec: %w", err)
	}

	machineCfg := fcModels.MachineConfiguration{
		VcpuCount:  firecracker.Int64(int64(params.Resources.CPU.Cores)),
		MemSizeMib: firecracker.Int64(params.Resources.Memory.Size),
	}
	vmConfig := firecracker.Config{
		VMID:            params.ExecutionID,
		SocketPath:      e.generateSocketPath(params.JobID, params.ExecutionID),
		KernelImagePath: spec.KernelImage,
		InitrdPath:      spec.Initrd,
		KernelArgs:      spec.KernelArgs,
		MachineCfg:      machineCfg,
	}

	drives, err := makeVMMounts(spec.RootFileSystem, params.Inputs, params.Outputs, params.ResultsDir)
	if err != nil {
		return nil, fmt.Errorf("failed to create VM mounts: %w", err)
	}
	vmConfig.Drives = drives

	vm, err := e.client.CreateVM(ctx, vmConfig)
	if err != nil {
		return nil, fmt.Errorf("failed to create VM: %w", err)
	}

	// e.client.VMPassMMDs(ctx, machine, fcArgs.MMDSMessage)
	return vm, nil
}
// makeVMMounts creates the drives for the VM based on the root file system and the input
// and output volumes provided in the execution request.
//
// Every input is added as a drive that keeps the volume's read-only flag. Every output
// is added as a writable drive; an output needs a non-empty source and a non-empty
// resultsDir, and resultsDir is created if it does not exist. When there are no outputs,
// resultsDir is neither validated nor created.
//
// It returns the list of drives, or nil and an error if any part of the process fails.
func makeVMMounts(
	rootFileSystem string,
	inputs []*types.StorageVolumeExecutor,
	outputs []*types.StorageVolumeExecutor,
	resultsDir string,
) ([]fcModels.Drive, error) {
	drivesBuilder := firecracker.NewDrivesBuilder(rootFileSystem)

	// AddDrive has a value receiver and returns the updated builder, so its result
	// must be assigned back; otherwise the added drive is silently dropped.
	for _, input := range inputs {
		drivesBuilder = drivesBuilder.AddDrive(input.Source, input.ReadOnly)
	}

	for _, output := range outputs {
		if output.Source == "" {
			return nil, fmt.Errorf("output source is empty")
		}
		if resultsDir == "" {
			return nil, fmt.Errorf("results directory is empty")
		}
		if err := os.MkdirAll(resultsDir, os.ModePerm); err != nil {
			return nil, fmt.Errorf("failed to create results directory: %w", err)
		}
		// this is an output volume so it must be writable
		drivesBuilder = drivesBuilder.AddDrive(output.Source, false)
	}

	return drivesBuilder.Build(), nil
}
// FindRunningVM finds the VM that is running the execution with the given IDs.
// The VM is looked up by the socket path generated from the job and execution IDs.
// It returns the Machine instance if found, or an error if the VM is not found.
func (e *Executor) FindRunningVM(
ctx context.Context,
jobID string,
executionID string,
) (*firecracker.Machine, error) {
return e.client.FindVM(ctx, e.generateSocketPath(jobID, executionID))
}
// generateSocketPath generates a socket path based on the job identifiers.
// The result has the form "<socketDir>/<executorID>_<jobID>_<executionID>.sock".
func (e *Executor) generateSocketPath(jobID string, executionID string) string {
return fmt.Sprintf("%s/%s_%s_%s.sock", socketDir, e.ID, jobID, executionID)
}
package firecracker
import (
"context"
"fmt"
"sync/atomic"
"time"
"github.com/firecracker-microvm/firecracker-go-sdk"
"gitlab.com/nunet/device-management-service/types"
)
// executionHandler is a struct that holds the necessary information to manage the execution of a firecracker VM.
// One handler tracks a single machine; run records the outcome in result.
type executionHandler struct {
// provided by the executor
ID string
client *Client // Firecracker client used to start, shut down and destroy the VM.
// meta data about the task
JobID string
executionID string
machine *firecracker.Machine // The VM that run starts and waits on.
resultsDir string
// synchronization
activeCh chan bool // Closed by run once the VM has started.
waitCh chan bool // Closed when run returns, i.e. when execution completes or fails.
running *atomic.Bool // Indicates if the VM is currently running.
// result of the execution
result *types.ExecutionResult
}
// active returns true while run is in progress for this handler.
// The flag is set at the start of run and cleared when run returns.
func (h *executionHandler) active() bool {
return h.running.Load()
}
// run starts the firecracker VM, waits for it to finish and records the
// outcome in h.result.
//
// The handler is marked running for the duration of the call. When run
// returns, the VM is destroyed (ten-second timeout), the running flag is
// cleared and waitCh is closed. activeCh is closed only after the VM has
// started successfully. A start or wait failure yields a failed execution
// result; a clean wait yields a success result.
func (h *executionHandler) run(ctx context.Context) {
	h.running.Store(true)
	defer func() {
		const destroyTimeout = 10 * time.Second
		if destroyErr := h.destroy(destroyTimeout); destroyErr != nil {
			zlog.Sugar().Warnf("failed to destroy container: %v\n", destroyErr)
		}
		h.running.Store(false)
		close(h.waitCh)
	}()

	// start the VM
	zlog.Sugar().Info("starting firecracker execution")
	if startErr := h.client.StartVM(ctx, h.machine); startErr != nil {
		h.result = types.NewFailedExecutionResult(fmt.Errorf("failed to start VM: %v", startErr))
		return
	}
	close(h.activeCh) // Indicate that the VM has started.

	waitErr := h.machine.Wait(ctx)
	switch {
	case waitErr == nil:
		h.result = types.NewExecutionResult(types.ExecutionStatusCodeSuccess)
	case ctx.Err() != nil:
		// The wait failed because the context was cancelled or timed out.
		h.result = types.NewFailedExecutionResult(
			fmt.Errorf("context closed while waiting on VM: %v", waitErr),
		)
	default:
		h.result = types.NewFailedExecutionResult(fmt.Errorf("failed to wait on VM: %v", waitErr))
	}
}
// kill stops the firecracker VM by asking the client to shut the machine down.
// It returns the client's error, if any.
func (h *executionHandler) kill(ctx context.Context) error {
return h.client.ShutdownVM(ctx, h.machine)
}
// destroy stops the firecracker VM and removes its resources.
// It calls DestroyVM with a background context bounded by timeout; the same timeout is
// also passed to DestroyVM as the grace period before the process is killed.
func (h *executionHandler) destroy(timeout time.Duration) error {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
return h.client.DestroyVM(ctx, h.machine, timeout)
}
package firecracker
import (
"gitlab.com/nunet/device-management-service/telemetry/logger"
)
// zlog is the package-level logger for the firecracker executor.
var zlog *logger.Logger
// init creates the package logger under the name "executor.firecracker".
func init() {
zlog = logger.New("executor.firecracker")
}
package firecracker
import (
"encoding/json"
"fmt"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils/validate"
)
// Parameter keys for a firecracker engine spec. The values match the JSON field names
// of EngineSpec; the builder methods below pass them to SpecConfig.WithParam.
// There is no key constant for the initrd field ("initrd_path").
const (
EngineKeyKernelImage = "kernel_image"
EngineKeyKernelArgs = "kernel_args"
EngineKeyRootFileSystem = "root_file_system"
EngineKeyMMDSMessage = "mmds_message"
)
// EngineSpec contains necessary parameters to execute a firecracker job.
// Validate requires RootFileSystem and KernelImage; the other fields are optional.
type EngineSpec struct {
// KernelImage is the path to the kernel image file.
KernelImage string `json:"kernel_image,omitempty"`
// Initrd is the path to the initial ramdisk file (JSON key "initrd_path").
Initrd string `json:"initrd_path,omitempty"`
// KernelArgs is the kernel command line arguments.
KernelArgs string `json:"kernel_args,omitempty"`
// RootFileSystem is the path to the root file system.
RootFileSystem string `json:"root_file_system,omitempty"`
// MMDSMessage is the MMDS message to be sent to the Firecracker VM.
MMDSMessage string `json:"mmds_message,omitempty"`
}
// Validate reports an error when the spec cannot be used to boot a VM.
// The root file system is checked first, then the kernel image; both must be
// non-blank.
func (c EngineSpec) Validate() error {
	switch {
	case validate.IsBlank(c.RootFileSystem):
		return fmt.Errorf("invalid firecracker engine params: root_file_system cannot be empty")
	case validate.IsBlank(c.KernelImage):
		return fmt.Errorf("invalid firecracker engine params: kernel_image cannot be empty")
	}

	return nil
}
// DecodeSpec decodes a spec config into a firecracker engine spec.
//
// The spec must be of type types.ExecutorTypeFirecracker and carry non-nil params. The
// params are round-tripped through JSON into an EngineSpec, which is then validated.
//
// It returns the decoded spec together with the validation result, so on a validation
// failure the returned spec is still populated but must not be used. All other failures
// return a zero EngineSpec.
func DecodeSpec(spec *types.SpecConfig) (EngineSpec, error) {
	if !spec.IsType(types.ExecutorTypeFirecracker) {
		return EngineSpec{}, fmt.Errorf(
			"invalid firecracker engine type. expected %s, but received: %s",
			types.ExecutorTypeFirecracker,
			spec.Type,
		)
	}

	inputParams := spec.Params
	if inputParams == nil {
		return EngineSpec{}, fmt.Errorf("invalid firecracker engine params: params cannot be nil")
	}

	paramBytes, err := json.Marshal(inputParams)
	if err != nil {
		return EngineSpec{}, fmt.Errorf("failed to encode firecracker engine params: %w", err)
	}

	// Decode into a value rather than a nil pointer: params that encode to JSON null
	// would leave a pointer nil, and dereferencing it would panic. With a value the
	// spec stays zero and Validate reports the missing fields instead.
	var firecrackerSpec EngineSpec
	if err := json.Unmarshal(paramBytes, &firecrackerSpec); err != nil {
		return EngineSpec{}, fmt.Errorf("failed to decode firecracker engine params: %w", err)
	}

	return firecrackerSpec, firecrackerSpec.Validate()
}
// EngineBuilder is a struct that is used for constructing a SpecConfig
// specifically for Firecracker engines using the Builder pattern.
// It wraps a *types.SpecConfig; each With* method sets one param on it and
// Build returns it.
type EngineBuilder struct {
eb *types.SpecConfig // spec config under construction
}
// NewFirecrackerEngineBuilder returns a builder for a Firecracker engine spec.
// The underlying spec config is created with types.ExecutorTypeFirecracker and
// its root file system param is set to the given path.
func NewFirecrackerEngineBuilder(rootFileSystem string) *EngineBuilder {
	builder := &EngineBuilder{
		eb: types.NewSpecConfig(types.ExecutorTypeFirecracker),
	}
	builder.eb.WithParam(EngineKeyRootFileSystem, rootFileSystem)
	return builder
}
// WithRootFileSystem is a builder method that sets the Firecracker engine root file system.
// The value is stored under EngineKeyRootFileSystem, the same key the constructor sets.
// It returns the same EngineBuilder for further chaining of builder methods.
func (b *EngineBuilder) WithRootFileSystem(e string) *EngineBuilder {
b.eb.WithParam(EngineKeyRootFileSystem, e)
return b
}
// WithKernelImage is a builder method that sets the Firecracker engine kernel image.
// The value is stored under EngineKeyKernelImage.
// It returns the same EngineBuilder for further chaining of builder methods.
func (b *EngineBuilder) WithKernelImage(e string) *EngineBuilder {
b.eb.WithParam(EngineKeyKernelImage, e)
return b
}
// WithKernelArgs is a builder method that sets the Firecracker engine kernel arguments.
// The value is stored under EngineKeyKernelArgs.
// It returns the same EngineBuilder for further chaining of builder methods.
func (b *EngineBuilder) WithKernelArgs(e string) *EngineBuilder {
b.eb.WithParam(EngineKeyKernelArgs, e)
return b
}
// WithMMDSMessage is a builder method that sets the Firecracker engine MMDS message.
// It returns the FirecrackerEngineBuilder for further chaining of builder methods.
func (b *EngineBuilder) WithMMDSMessage(e string) *EngineBuilder {
b.eb.WithParam(EngineKeyMMDSMessage, e)
return b
}
// Build method constructs the final SpecConfig object by calling the embedded EngineBuilder's Build method.
func (b *EngineBuilder) Build() *types.SpecConfig {
return b.eb
}
package backgroundtasks
import (
"github.com/uptrace/opentelemetry-go-extra/otelzap"
"gitlab.com/nunet/device-management-service/telemetry/logger"
)
// zlog is the package-level logger of the backgroundtasks package; the
// triggers use it to report cron expression parsing errors.
var zlog *otelzap.Logger
// init assigns zlog the logger returned by logger.OtelZapLogger for the
// "background_tasks" name.
func init() {
zlog = logger.OtelZapLogger("background_tasks")
}
package backgroundtasks
import (
"sort"
"sync"
"time"
)
// Scheduler orchestrates the execution of tasks based on their triggers and priority.
// Once Start is called, every tick of ticker runs a pass over the registered
// tasks until stopChan is closed by Stop.
type Scheduler struct {
tasks map[int]*Task // Map of tasks by their ID.
runningTasks map[int]bool // Map to keep track of running tasks.
ticker *time.Ticker // Ticker for periodic checks of task triggers.
stopChan chan struct{} // Channel to signal stopping the scheduler.
maxRunningTasks int // Maximum number of tasks that can run concurrently.
lastTaskID int // Counter for assigning unique IDs to tasks.
mu sync.Mutex // Mutex to protect access to task maps.
}
// NewScheduler builds a Scheduler that allows at most maxRunningTasks tasks to
// run at the same time. The scheduler checks triggers once per second after
// Start has been called; task IDs start at zero.
func NewScheduler(maxRunningTasks int) *Scheduler {
	s := new(Scheduler)
	s.tasks = map[int]*Task{}
	s.runningTasks = map[int]bool{}
	s.ticker = time.NewTicker(time.Second)
	s.stopChan = make(chan struct{})
	s.maxRunningTasks = maxRunningTasks
	return s
}
// AddTask registers task with the scheduler. The task receives the next free
// ID, is marked enabled and has all of its triggers reset before it is stored.
// The same task pointer is returned for convenience.
func (s *Scheduler) AddTask(task *Task) *Task {
	s.mu.Lock()
	defer s.mu.Unlock()

	task.ID, task.Enabled = s.lastTaskID, true
	for _, trig := range task.Triggers {
		trig.Reset()
	}

	s.tasks[task.ID] = task
	s.lastTaskID++
	return task
}
// RemoveTask unregisters the task with the given ID. Removing an unknown ID is
// a no-op.
func (s *Scheduler) RemoveTask(taskID int) {
	s.mu.Lock()
	delete(s.tasks, taskID)
	s.mu.Unlock()
}
// Start launches the scheduling loop in its own goroutine. The loop runs one
// scheduling pass per ticker tick and exits once the stop channel is closed.
func (s *Scheduler) Start() {
	loop := func() {
		for {
			select {
			case <-s.ticker.C:
				s.runTasks()
			case <-s.stopChan:
				return
			}
		}
	}
	go loop()
}
// runningTasksCount reports how many tasks are currently flagged as running.
// It takes the scheduler mutex, so it must not be called with the mutex held.
func (s *Scheduler) runningTasksCount() int {
	s.mu.Lock()
	defer s.mu.Unlock()

	var running int
	for id := range s.runningTasks {
		if s.runningTasks[id] {
			running++
		}
	}
	return running
}
// runTasks performs one scheduling pass: it walks the registered tasks in
// descending priority order and starts every enabled, idle task that has a
// ready trigger, as long as fewer than maxRunningTasks tasks are running.
// Tasks without any trigger are removed from the scheduler.
//
// All accesses to the task maps happen under s.mu because AddTask, RemoveTask
// and the runTask goroutines touch them concurrently. The mutex is never held
// while calling into triggers or RemoveTask.
func (s *Scheduler) runTasks() {
	// Snapshot the tasks under the lock so map iteration cannot race with
	// concurrent registration or removal.
	s.mu.Lock()
	sortedTasks := make([]*Task, 0, len(s.tasks))
	for _, task := range s.tasks {
		sortedTasks = append(sortedTasks, task)
	}
	s.mu.Unlock()

	// Highest priority first.
	sort.Slice(sortedTasks, func(i, j int) bool {
		return sortedTasks[i].Priority > sortedTasks[j].Priority
	})

	// claimSlot atomically checks the concurrency limit and, when there is
	// room, marks the task as running. Doing both under one lock prevents the
	// limit from being exceeded by a check-then-act race.
	claimSlot := func(taskID int) bool {
		s.mu.Lock()
		defer s.mu.Unlock()

		running := 0
		for _, isRunning := range s.runningTasks {
			if isRunning {
				running++
			}
		}
		if running >= s.maxRunningTasks {
			return false
		}
		s.runningTasks[taskID] = true
		return true
	}

	for _, task := range sortedTasks {
		s.mu.Lock()
		skip := !task.Enabled || s.runningTasks[task.ID]
		s.mu.Unlock()
		if skip {
			continue
		}

		if len(task.Triggers) == 0 {
			// A task that can never fire is dropped; RemoveTask locks itself.
			s.RemoveTask(task.ID)
			continue
		}

		for _, trigger := range task.Triggers {
			if !trigger.IsReady() || !claimSlot(task.ID) {
				continue
			}
			go s.runTask(task.ID)
			trigger.Reset()
			break
		}
	}
}
// Stop signals the scheduling loop to exit and stops the ticker so its
// resources are released. Tasks that are already running are not interrupted.
// Calling Stop more than once is safe; only the first call has an effect.
func (s *Scheduler) Stop() {
	s.mu.Lock()
	defer s.mu.Unlock()

	select {
	case <-s.stopChan:
		// Already stopped; closing again would panic.
		return
	default:
	}

	s.ticker.Stop()
	close(s.stopChan)
}
// runTask executes the task with the given ID, retrying according to its
// retry policy (MaxRetries+1 attempts in total), and appends the outcome to
// the task's execution history. The running flag of the task is always cleared
// when the function returns.
//
// The task is looked up under the scheduler mutex; if it was removed before
// the goroutine got to run, nothing is executed. A task removed while it is
// running is not re-registered when it finishes.
func (s *Scheduler) runTask(taskID int) {
	// Registered first, so it runs last: the task counts as running until its
	// execution record has been stored.
	defer func() {
		s.mu.Lock()
		defer s.mu.Unlock()
		s.runningTasks[taskID] = false
	}()

	s.mu.Lock()
	task, ok := s.tasks[taskID]
	s.mu.Unlock()
	if !ok || task == nil {
		return
	}

	execution := Execution{StartedAt: time.Now()}
	defer func() {
		s.mu.Lock()
		defer s.mu.Unlock()
		task.ExecutionHist = append(task.ExecutionHist, execution)
	}()

	for i := 0; i < task.RetryPolicy.MaxRetries+1; i++ {
		err := runTaskWithRetry(task.Function, task.Args, task.RetryPolicy.Delay)
		if err == nil {
			execution.Status = "SUCCESS"
			execution.EndedAt = time.Now()
			return
		}
		// Keep the most recent failure for the execution record.
		execution.Error = err.Error()
	}

	execution.Status = "FAILED"
	execution.EndedAt = time.Now()
}
// runTaskWithRetry performs a single attempt of fn, handing it the whole args
// slice as one argument. When the attempt fails it sleeps for delay before
// returning the error, which spaces out the caller's next attempt.
func runTaskWithRetry(
	fn func(args interface{}) error,
	args []interface{},
	delay time.Duration,
) error {
	if err := fn(args); err != nil {
		time.Sleep(delay)
		return err
	}
	return nil
}
package backgroundtasks
import (
"time"
"github.com/robfig/cron/v3"
)
// Trigger interface defines a method to check if a trigger condition is met.
// The scheduler polls IsReady on each pass and calls Reset when a task is
// added and after a task has been started because of the trigger.
type Trigger interface {
IsReady() bool // Returns true if the trigger condition is met.
Reset() // Resets the trigger state.
}
// PeriodicTrigger triggers at regular intervals and/or based on a standard
// cron expression. When both Interval and CronExpr are set, the trigger is
// ready as soon as either condition is met.
type PeriodicTrigger struct {
	Interval      time.Duration // Interval for periodic triggering.
	CronExpr      string        // Cron expression for triggering.
	lastTriggered time.Time     // Last time the trigger was activated.
}

// IsReady reports whether the trigger should fire now.
//
// The interval condition holds once Interval has elapsed since the last
// Reset. It is skipped for a cron-only trigger (CronExpr set, Interval not
// positive); otherwise the zero interval would always have elapsed and the
// cron schedule would never be consulted. The cron condition holds once the
// first scheduled time after the last Reset has passed. An unparsable cron
// expression is logged and treated as not ready.
func (t *PeriodicTrigger) IsReady() bool {
	now := time.Now()

	// Trigger based on interval.
	useInterval := t.Interval > 0 || t.CronExpr == ""
	if useInterval && t.lastTriggered.Add(t.Interval).Before(now) {
		return true
	}

	// Trigger based on cron expression.
	if t.CronExpr == "" {
		return false
	}
	schedule, err := cron.ParseStandard(t.CronExpr)
	if err != nil {
		zlog.Sugar().Errorf("Error parsing CronExpr: %v", err)
		return false
	}
	return schedule.Next(t.lastTriggered).Before(now)
}

// Reset updates the last triggered time to the current time.
func (t *PeriodicTrigger) Reset() {
	t.lastTriggered = time.Now()
}
// PeriodicTriggerWithJitter triggers at regular intervals and/or based on a
// standard cron expression, like PeriodicTrigger, but extends the interval by
// the duration returned from Jitter on every readiness check.
type PeriodicTriggerWithJitter struct {
	Interval      time.Duration // Interval for periodic triggering.
	CronExpr      string        // Cron expression for triggering.
	lastTriggered time.Time     // Last time the trigger was activated.
	Jitter        func() time.Duration // Extra delay added to Interval; nil means no jitter.
}

// IsReady reports whether the trigger should fire now.
//
// The interval condition holds once Interval plus the current jitter has
// elapsed since the last Reset. It is skipped for a cron-only trigger
// (CronExpr set, Interval not positive) so that the cron schedule is actually
// consulted. The cron condition holds once the first scheduled time after the
// last Reset has passed. An unparsable cron expression is logged and treated
// as not ready.
func (t *PeriodicTriggerWithJitter) IsReady() bool {
	now := time.Now()

	// Trigger based on interval.
	if t.Interval > 0 || t.CronExpr == "" {
		wait := t.Interval
		if t.Jitter != nil { // a nil Jitter would panic when called
			wait += t.Jitter()
		}
		if t.lastTriggered.Add(wait).Before(now) {
			return true
		}
	}

	// Trigger based on cron expression.
	if t.CronExpr == "" {
		return false
	}
	schedule, err := cron.ParseStandard(t.CronExpr)
	if err != nil {
		zlog.Sugar().Errorf("Error parsing CronExpr: %v", err)
		return false
	}
	return schedule.Next(t.lastTriggered).Before(now)
}

// Reset updates the last triggered time to the current time.
func (t *PeriodicTriggerWithJitter) Reset() {
	t.lastTriggered = time.Now()
}
// EventTrigger fires when an external party signals an event on its channel.
type EventTrigger struct {
	Trigger chan bool // Channel to signal an event.
}

// IsReady performs a non-blocking receive on the trigger channel and reports
// whether a value was available. A pending signal is consumed by the check.
func (t *EventTrigger) IsReady() bool {
	select {
	case <-t.Trigger:
		return true
	default:
	}
	return false
}

// Reset is a no-op: the state of an EventTrigger lives in its channel, which
// is managed by whoever sends the events.
func (t *EventTrigger) Reset() {}
// OneTimeTrigger becomes ready once a fixed delay has passed since it was
// last reset.
type OneTimeTrigger struct {
	Delay        time.Duration // The delay after which to trigger.
	registeredAt time.Time     // Time when the trigger was set.
}

// Reset restarts the delay by recording the current time as the registration
// time.
func (t *OneTimeTrigger) Reset() {
	t.registeredAt = time.Now()
}

// IsReady reports whether the current time is past the registration time plus
// the configured delay.
func (t *OneTimeTrigger) IsReady() bool {
	deadline := t.registeredAt.Add(t.Delay)
	return time.Now().After(deadline)
}
package did
// GetAnchorFunc resolves a DID into the Anchor used to verify its signatures.
type GetAnchorFunc func(did DID) (Anchor, error)
// anchorMethods maps a DID method name to the function that builds anchors
// for DIDs of that method.
var anchorMethods map[string]GetAnchorFunc
// init registers the built-in anchor methods; only "key" is supported here.
func init() {
anchorMethods = map[string]GetAnchorFunc{
"key": makeKeyAnchor,
}
}
// GetAnchorForDID builds the Anchor for did using the function registered for
// the DID's method. It returns ErrNoAnchorMethod when no function is
// registered for that method.
func GetAnchorForDID(did DID) (Anchor, error) {
	if makeAnchor, known := anchorMethods[did.Method()]; known {
		return makeAnchor(did)
	}
	return nil, ErrNoAnchorMethod
}
// makeKeyAnchor is the anchor constructor for the "key" method: it extracts
// the public key embedded in did and wraps both in an Anchor.
func makeKeyAnchor(did DID) (Anchor, error) {
	key, err := PublicKeyFromDID(did)
	if err == nil {
		return NewAnchor(did, key), nil
	}
	return nil, err
}
package did
import (
"context"
"fmt"
"sync"
"time"
"gitlab.com/nunet/device-management-service/lib/crypto"
)
// anchorEntryTTL is how long a cached anchor entry stays valid after it was
// added or last looked up; expired entries are dropped by the gc pass.
const anchorEntryTTL = time.Hour
// Anchor is a DID anchor that encapsulates a public key that can be used
// for verification of signatures.
type Anchor interface {
// DID returns the DID this anchor belongs to.
DID() DID
// Verify checks sig against data; a nil error means the signature is valid.
Verify(data []byte, sig []byte) error
// PublicKey returns the public key held by the anchor.
PublicKey() crypto.PubKey
}
// Provider holds the private key material necessary to sign statements for
// a DID.
type Provider interface {
// DID returns the DID this provider signs for.
DID() DID
// Sign returns a signature over data.
Sign(data []byte) ([]byte, error)
// Anchor returns the verification anchor that corresponds to this provider.
Anchor() Anchor
// PrivateKey returns the private key held by the provider.
PrivateKey() crypto.PrivKey
}
// TrustContext is a registry of DID anchors and providers. The method
// descriptions below reflect the BasicTrustContext implementation.
type TrustContext interface {
// Anchors lists the DIDs that currently have an anchor registered.
Anchors() []DID
// Providers lists the DIDs that have a provider registered.
Providers() []DID
// GetAnchor returns the anchor for did, resolving and caching it if it is
// not registered yet.
GetAnchor(did DID) (Anchor, error)
// GetProvider returns the provider registered for did.
GetProvider(did DID) (Provider, error)
// AddAnchor registers (or refreshes) an anchor.
AddAnchor(anchor Anchor)
// AddProvider registers a provider.
AddProvider(provider Provider)
// Start launches background garbage collection of expired anchors, running
// every gcInterval.
Start(gcInterval time.Duration)
// Stop halts the background garbage collection started by Start.
Stop()
}
// anchorEntry is a cached anchor together with its expiration time.
type anchorEntry struct {
anchor Anchor
// expire is the instant after which the gc pass removes the entry.
expire time.Time
}
// BasicTrustContext is the in-memory TrustContext implementation. Anchors are
// cached with an expiration time; providers are kept until replaced.
type BasicTrustContext struct {
// mx guards anchors, providers and stop.
mx sync.Mutex
anchors map[DID]*anchorEntry
providers map[DID]Provider
// stop cancels the running gc goroutine; nil when gc is not running.
stop func()
}
// Compile-time check that BasicTrustContext implements TrustContext.
var _ TrustContext = (*BasicTrustContext)(nil)
// NewTrustContext returns an empty BasicTrustContext. Garbage collection of
// anchors is not running until Start is called.
func NewTrustContext() TrustContext {
	tc := new(BasicTrustContext)
	tc.anchors = map[DID]*anchorEntry{}
	tc.providers = map[DID]Provider{}
	return tc
}
// Anchors returns the DIDs of all cached anchors, in no particular order.
func (ctx *BasicTrustContext) Anchors() []DID {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	dids := make([]DID, 0, len(ctx.anchors))
	for id := range ctx.anchors {
		dids = append(dids, id)
	}
	return dids
}
// Providers returns the DIDs of all registered providers, in no particular
// order.
func (ctx *BasicTrustContext) Providers() []DID {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	dids := make([]DID, 0, len(ctx.providers))
	for id := range ctx.providers {
		dids = append(dids, id)
	}
	return dids
}
// GetAnchor returns the anchor for did. A cached anchor is returned directly;
// otherwise the anchor is resolved through GetAnchorForDID, added to the cache
// and returned.
func (ctx *BasicTrustContext) GetAnchor(did DID) (Anchor, error) {
	if cached, found := ctx.getAnchor(did); found {
		return cached, nil
	}

	resolved, err := GetAnchorForDID(did)
	if err != nil {
		return nil, fmt.Errorf("get anchor for did: %w", err)
	}

	ctx.AddAnchor(resolved)
	return resolved, nil
}
// getAnchor looks up the cached anchor for did. A hit extends the entry's
// lifetime by anchorEntryTTL from now.
func (ctx *BasicTrustContext) getAnchor(did DID) (Anchor, bool) {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	entry, found := ctx.anchors[did]
	if !found {
		return nil, false
	}

	entry.expire = time.Now().Add(anchorEntryTTL)
	return entry.anchor, true
}
// GetProvider returns the provider registered for did, or ErrNoProvider when
// none is registered.
func (ctx *BasicTrustContext) GetProvider(did DID) (Provider, error) {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	if p, found := ctx.providers[did]; found {
		return p, nil
	}
	return nil, ErrNoProvider
}
// AddAnchor caches anchor under its DID with a fresh anchorEntryTTL lifetime,
// replacing any entry already stored for that DID.
func (ctx *BasicTrustContext) AddAnchor(anchor Anchor) {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	entry := new(anchorEntry)
	entry.anchor = anchor
	entry.expire = time.Now().Add(anchorEntryTTL)
	ctx.anchors[anchor.DID()] = entry
}
// AddProvider registers provider under its DID, replacing any provider already
// stored for that DID.
func (ctx *BasicTrustContext) AddProvider(provider Provider) {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	id := provider.DID()
	ctx.providers[id] = provider
}
// Start launches the anchor garbage collector, which runs every gcInterval.
// If a collector is already running it is cancelled and replaced by the new
// one.
func (ctx *BasicTrustContext) Start(gcInterval time.Duration) {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	if previous := ctx.stop; previous != nil {
		previous()
	}

	gcCtx, cancel := context.WithCancel(context.Background())
	ctx.stop = cancel
	go ctx.gc(gcCtx, gcInterval)
}
// Stop cancels the running garbage collector, if any, and clears the stored
// cancel function so that Start can be called again later.
func (ctx *BasicTrustContext) Stop() {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	if ctx.stop == nil {
		return
	}
	ctx.stop()
	ctx.stop = nil
}
// gc is the garbage collector loop: on every gcInterval tick it removes
// expired anchor entries, and it returns once gcCtx is cancelled.
func (ctx *BasicTrustContext) gc(gcCtx context.Context, gcInterval time.Duration) {
	tick := time.NewTicker(gcInterval)
	defer tick.Stop()

	for {
		select {
		case <-gcCtx.Done():
			return
		case <-tick.C:
			ctx.gcAnchorEntries()
		}
	}
}
// gcAnchorEntries deletes every cached anchor whose expiration time lies
// before the current time.
func (ctx *BasicTrustContext) gcAnchorEntries() {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	cutoff := time.Now()
	for id, entry := range ctx.anchors {
		if cutoff.After(entry.expire) {
			delete(ctx.anchors, id)
		}
	}
}
package did
import (
"strings"
)
// DID is a decentralized identifier held as its URI string, which has the
// three colon-separated parts "did:<method>:<identifier>".
type DID struct {
	URI string `json:"uri,omitempty"`
}

// Equal reports whether both DIDs carry the same URI.
func (did DID) Equal(other DID) bool {
	return other.URI == did.URI
}

// Empty reports whether the DID has no URI.
func (did DID) Empty() bool {
	return len(did.URI) == 0
}

// String returns the URI of the DID.
func (did DID) String() string {
	return did.URI
}

// Method returns the second part of the URI (for example "key"), or the empty
// string when the URI does not consist of exactly three parts.
func (did DID) Method() string {
	if segments := strings.Split(did.URI, ":"); len(segments) == 3 {
		return segments[1]
	}
	return ""
}

// Identifier returns the third part of the URI, or the empty string when the
// URI does not consist of exactly three parts.
func (did DID) Identifier() string {
	if segments := strings.Split(did.URI, ":"); len(segments) == 3 {
		return segments[2]
	}
	return ""
}
// FromString parses s into a DID. The empty string yields the empty DID.
// Any other input must consist of exactly three non-empty, colon-separated
// parts ("did:<method>:<identifier>"); otherwise ErrInvalidDID is returned.
func FromString(s string) (DID, error) {
	if s == "" {
		return DID{}, nil
	}

	parts := strings.Split(s, ":")
	if len(parts) != 3 {
		return DID{}, ErrInvalidDID
	}
	for _, part := range parts {
		if part == "" {
			return DID{}, ErrInvalidDID
		}
	}
	// TODO validate parts according to spec: https://www.w3.org/TR/did-core/

	return DID{URI: s}, nil
}
package did
import (
"fmt"
"strings"
libp2p_crypto "github.com/libp2p/go-libp2p/core/crypto"
mb "github.com/multiformats/go-multibase"
varint "github.com/multiformats/go-varint"
"gitlab.com/nunet/device-management-service/lib/crypto"
)
// PublicKeyAnchor is an Anchor backed by a public key.
type PublicKeyAnchor struct {
did DID
pubk crypto.PubKey
}
// Compile-time check that PublicKeyAnchor implements Anchor.
var _ Anchor = (*PublicKeyAnchor)(nil)
// PrivateKeyProvider is a Provider backed by a private key.
type PrivateKeyProvider struct {
did DID
privk crypto.PrivKey
}
// Compile-time check that PrivateKeyProvider implements Provider.
var _ Provider = (*PrivateKeyProvider)(nil)
// NewAnchor returns an Anchor that verifies signatures for did with pubk.
func NewAnchor(did DID, pubk crypto.PubKey) Anchor {
	anchor := new(PublicKeyAnchor)
	anchor.did = did
	anchor.pubk = pubk
	return anchor
}
// NewProvider returns a Provider that signs for did with privk.
func NewProvider(did DID, privk crypto.PrivKey) Provider {
	provider := new(PrivateKeyProvider)
	provider.did = did
	provider.privk = privk
	return provider
}
// DID returns the DID this anchor was created for.
func (a *PublicKeyAnchor) DID() DID {
	return a.did
}

// Verify checks sig over data with the anchor's public key. It returns the
// key's own error if verification could not be performed, ErrInvalidSignature
// if the signature does not match, and nil otherwise.
func (a *PublicKeyAnchor) Verify(data []byte, sig []byte) error {
	valid, err := a.pubk.Verify(data, sig)
	switch {
	case err != nil:
		return err
	case !valid:
		return ErrInvalidSignature
	default:
		return nil
	}
}

// PublicKey returns the public key held by the anchor.
func (a *PublicKeyAnchor) PublicKey() crypto.PubKey {
	return a.pubk
}
// DID returns the DID this provider was created for.
func (p *PrivateKeyProvider) DID() DID {
	return p.did
}

// Sign signs data with the provider's private key.
func (p *PrivateKeyProvider) Sign(data []byte) ([]byte, error) {
	signature, err := p.privk.Sign(data)
	return signature, err
}

// PrivateKey returns the private key held by the provider.
func (p *PrivateKeyProvider) PrivateKey() crypto.PrivKey {
	return p.privk
}

// Anchor returns a new Anchor for the provider's DID, built from the public
// half of its private key.
func (p *PrivateKeyProvider) Anchor() Anchor {
	public := p.privk.GetPublic()
	return NewAnchor(p.did, public)
}
// FromID derives the DID that corresponds to id by extracting the public key
// from the ID and formatting it as a key DID.
func FromID(id crypto.ID) (DID, error) {
	key, err := crypto.PublicKeyFromID(id)
	if err == nil {
		return FromPublicKey(key), nil
	}
	return DID{}, fmt.Errorf("public key from id: %w", err)
}
// FromPublicKey returns the DID whose URI is the key URI of pubk as produced
// by FormatKeyURI.
func FromPublicKey(pubk crypto.PubKey) DID {
	return DID{URI: FormatKeyURI(pubk)}
}
// PublicKeyFromDID extracts the public key embedded in a DID of method "key".
// DIDs of any other method are rejected with ErrInvalidDID.
func PublicKeyFromDID(did DID) (crypto.PubKey, error) {
	if method := did.Method(); method != "key" {
		return nil, ErrInvalidDID
	}

	key, err := ParseKeyURI(did.URI)
	if err == nil {
		return key, nil
	}
	return nil, fmt.Errorf("parsing did key identifier: %w", err)
}
// AnchorFromPublicKey builds the Anchor for the key DID derived from pubk.
// The returned error is always nil.
func AnchorFromPublicKey(pubk crypto.PubKey) (Anchor, error) {
	return NewAnchor(FromPublicKey(pubk), pubk), nil
}
// ProviderFromPrivateKey builds the Provider for the key DID derived from the
// public half of privk. The returned error is always nil.
func ProviderFromPrivateKey(privk crypto.PrivKey) (Provider, error) {
	id := FromPublicKey(privk.GetPublic())
	return NewProvider(id, privk), nil
}
// Note: this code originated in https://github.com/ucan-wg/go-ucan/blob/main/didkey/key.go
// Copyright applies; some superficial modifications by vyzo.
const (
// multicodecKindEd25519PubKey is the varint-encoded code that prefixes the
// raw key bytes inside a key DID and identifies an Ed25519 public key.
multicodecKindEd25519PubKey uint64 = 0xed
// keyPrefix is the scheme and method part of a key DID URI.
keyPrefix = "did:key"
)
// FormatKeyURI renders pubk as a key DID URI: the raw key bytes are prefixed
// with the Ed25519 multicodec code, encoded as base58btc multibase and
// appended to "did:key:". The empty string is returned when the raw key bytes
// cannot be obtained or the encoding fails.
func FormatKeyURI(pubk crypto.PubKey) string {
	raw, err := pubk.Raw()
	if err != nil {
		return ""
	}

	// TODO other supported key types (secp?)
	codec := multicodecKindEd25519PubKey

	payload := make([]byte, varint.UvarintSize(codec)+len(raw))
	offset := varint.PutUvarint(payload, codec)
	copy(payload[offset:], raw)

	encoded, err := mb.Encode(mb.Base58BTC, payload)
	if err != nil {
		return ""
	}
	return keyPrefix + ":" + encoded
}
// ParseKeyURI extracts the public key from a key DID URI of the form
// "did:key:<multibase>". The multibase part must be base58btc encoded and must
// start with the Ed25519 multicodec code.
//
// It returns an error when the URI does not start with "did:key:", when the
// multibase payload cannot be decoded or uses another encoding, when the
// multicodec varint is malformed, or when the key bytes are invalid.
// ErrInvalidKeyType is returned for any key type other than Ed25519.
func ParseKeyURI(uri string) (crypto.PubKey, error) {
	// Include the separator in the check so that look-alike methods such as
	// "did:keyring:..." are rejected here instead of reaching the decoder
	// untrimmed.
	const prefix = keyPrefix + ":"
	if !strings.HasPrefix(uri, prefix) {
		return nil, fmt.Errorf("decentralized identifier is not a 'key' type")
	}

	enc, data, err := mb.Decode(strings.TrimPrefix(uri, prefix))
	if err != nil {
		return nil, fmt.Errorf("decoding multibase: %w", err)
	}
	if enc != mb.Base58BTC {
		return nil, fmt.Errorf("unexpected multibase encoding: %s", mb.EncodingToStr[enc])
	}

	// The payload is <varint key type><raw key bytes>.
	keyType, n, err := varint.FromUvarint(data)
	if err != nil {
		return nil, err
	}

	switch keyType {
	case multicodecKindEd25519PubKey:
		pubk, err := libp2p_crypto.UnmarshalEd25519PublicKey(data[n:])
		if err != nil {
			return nil, err
		}
		return pubk, nil
	default:
		return nil, ErrInvalidKeyType
	}
}
package did
import (
"io"
)
// SaveTrustContext is meant to serialize a TrustContext to a writer. It is
// not implemented yet and always returns ErrTODO.
func SaveTrustContext(_ TrustContext, _ io.Writer) (int, error) {
// TODO follow up
return 0, ErrTODO
}
// LoadTrustContext is meant to deserialize a TrustContext from a reader. It
// is not implemented yet and always returns ErrTODO.
func LoadTrustContext(_ io.Reader) (TrustContext, error) {
// TODO follow up
return nil, ErrTODO
}
package ucan
import (
"strings"
)
// Capability is a slash-separated, hierarchical capability name such as
// "/dms/deployment".
type Capability string

// Root is the capability that implies every other capability.
const Root = Capability("/")

// Implies reports whether holding c is sufficient for other. That is the case
// when c is Root, when both are identical, or when other lies below c in the
// hierarchy, i.e. other starts with c followed by a slash.
func (c Capability) Implies(other Capability) bool {
	switch {
	case c == other, c == Root:
		return true
	default:
		return strings.HasPrefix(string(other), string(c)+"/")
	}
}
package ucan
import (
"bytes"
"context"
"crypto/rand"
"encoding/json"
"fmt"
"slices"
"sync"
"time"
"gitlab.com/nunet/device-management-service/lib/crypto"
"gitlab.com/nunet/device-management-service/lib/did"
)
const (
// maxCapabilitySize is the largest serialized token payload, in bytes, that
// Consume accepts.
maxCapabilitySize = 16384
)
// CapabilityContext manages capability tokens for one controlling DID: it
// ingests tokens presented by others, checks required capabilities against
// them, and issues delegation, invocation and broadcast tokens.
type CapabilityContext interface {
// DID returns the context's controlling DID
DID() did.DID
// Trust returns the context's did trust context
Trust() did.TrustContext
// Consume ingests the provided capability tokens
Consume(origin did.DID, cap []byte) error
// Discard discards previously consumed capability tokens
Discard(cap []byte)
// Require ensures that at least one of the capabilities is delegated from
// the subject to the audience, with an appropriate anchor
// An empty list will mean that no capabilities are required and is vacuously
// true.
// NOTE(review): BasicCapabilityContext.Require currently returns
// ErrNotAuthorized for an empty list, which contradicts the sentence above;
// confirm which behavior is intended.
Require(anchor did.DID, subject crypto.ID, audience crypto.ID, require []Capability) error
// RequireBroadcast ensures that at least one of the capabilities is delegated
// to the subject for the specified broadcast topics
RequireBroadcast(origin did.DID, subject crypto.ID, topic string, require []Capability) error
// Provide prepares the appropriate capability tokens to prove and delegate authority
// to a subject for an audience.
// - It delegates invocations to the subject with an audience and invoke capabilities
// - It delegates the delegate capabilities to the target with audience the subject
Provide(target did.DID, subject crypto.ID, audience crypto.ID, expire uint64, invoke []Capability, delegate []Capability) ([]byte, error)
// ProvideBroadcast prepares the appropriate capability tokens to prove authority
// to broadcast to a topic
ProvideBroadcast(subject crypto.ID, topic string, expire uint64, broadcast []Capability) ([]byte, error)
// AddRoots adds trust anchors and/or capabilities derived from our anchors
AddRoots(trust []did.DID, require, provide TokenList) error
// Delegate creates the appropriate delegation tokens anchored in our roots
Delegate(subject, audience did.DID, topics []string, expire uint64, cap []Capability, selfSignOnly bool) (TokenList, error)
// DelegateInvocation creates the appropriate invocation tokens anchored in anchor
DelegateInvocation(anchor, subject, audience did.DID, expire uint64, provide []Capability) (TokenList, error)
// DelegateBroadcast creates the appropriate broadcast token anchored in our roots
DelegateBroadcast(subject did.DID, topic string, expire uint64, provide []Capability) (TokenList, error)
// Grant creates the appropriate delegation tokens considering ourselves as the root
Grant(action Action, subject, audience did.DID, topic []string, expire uint64, provide []Capability) (TokenList, error)
// Start starts a token garbage collector goroutine that clears expired tokens
Start(gcInterval time.Duration)
// Stop stops a previously started gc goroutine
Stop()
}
// BasicCapabilityContext is the in-memory CapabilityContext implementation.
type BasicCapabilityContext struct {
// mx guards the maps below.
mx sync.Mutex
// provider signs tokens on behalf of the context's DID.
provider did.Provider
trust did.TrustContext
roots map[did.DID]struct{} // our root anchors of trust
require map[did.DID][]*Token // our acceptance side-roots
provide map[did.DID][]*Token // root capabilities -> tokens
tokens map[did.DID][]*Token // our context dependent capabilities; subject -> tokens
// stop cancels the gc goroutine launched by Start.
stop func()
}
// Compile-time check that BasicCapabilityContext implements CapabilityContext.
var _ CapabilityContext = (*BasicCapabilityContext)(nil)
// NewCapabilityContext creates a capability context controlled by ctxDID.
// The signing provider for ctxDID is taken from trust, and the given roots
// and require/provide tokens are installed through AddRoots. It fails when no
// provider is registered for ctxDID or when a token cannot be verified.
func NewCapabilityContext(trust did.TrustContext, ctxDID did.DID, roots []did.DID, require, provide TokenList) (CapabilityContext, error) {
	provider, err := trust.GetProvider(ctxDID)
	if err != nil {
		return nil, fmt.Errorf("new capability context: %w", err)
	}

	capCtx := &BasicCapabilityContext{
		provider: provider,
		trust:    trust,
		roots:    map[did.DID]struct{}{},
		require:  map[did.DID][]*Token{},
		provide:  map[did.DID][]*Token{},
		tokens:   map[did.DID][]*Token{},
	}

	if err := capCtx.AddRoots(roots, require, provide); err != nil {
		return nil, fmt.Errorf("new capability context: %w", err)
	}
	return capCtx, nil
}
// DID returns the DID of the context's signing provider, i.e. the DID that
// controls this context.
func (ctx *BasicCapabilityContext) DID() did.DID {
return ctx.provider.DID()
}
// Trust returns the DID trust context this capability context was created
// with.
func (ctx *BasicCapabilityContext) Trust() did.TrustContext {
return ctx.trust
}
// Start launches the token garbage collector goroutine, which runs every
// gcInterval. If a collector has already been started (ctx.stop is set), the
// call is a no-op.
func (ctx *BasicCapabilityContext) Start(gcInterval time.Duration) {
	if ctx.stop != nil {
		// Already started.
		return
	}

	gcCtx, cancel := context.WithCancel(context.Background())
	ctx.stop = cancel
	go ctx.gc(gcCtx, gcInterval)
}
// Stop cancels the garbage collector goroutine launched by Start, if any, and
// clears the stored cancel function so the context no longer appears started.
// Calling Stop without a running collector is a no-op.
func (ctx *BasicCapabilityContext) Stop() {
	if ctx.stop == nil {
		return
	}
	ctx.stop()
	ctx.stop = nil
}
// AddRoots registers additional root anchors and installs the given require
// and provide tokens. Every token is verified against the trust context at
// the current time before it is stored; the first verification failure aborts
// the call, leaving the roots and any previously processed tokens in place.
func (ctx *BasicCapabilityContext) AddRoots(roots []did.DID, require, provide TokenList) error {
	ctx.addRoots(roots)

	now := uint64(time.Now().UnixNano())

	// ingest verifies each token and hands it to store.
	ingest := func(tokens []*Token, store func(*Token)) error {
		for _, tok := range tokens {
			if err := tok.Verify(ctx.trust, now); err != nil {
				return fmt.Errorf("verify token: %w", err)
			}
			store(tok)
		}
		return nil
	}

	if err := ingest(require.Tokens, ctx.consumeRequireToken); err != nil {
		return err
	}
	return ingest(provide.Tokens, ctx.consumeProvideToken)
}
// Grant issues a single self-signed token with the context's DID as issuer.
// The token carries the given action, subject, audience, topics, capabilities
// and expiration together with a fresh random nonce, and is signed by the
// context's provider.
func (ctx *BasicCapabilityContext) Grant(action Action, subject, audience did.DID, topic []string, expire uint64, provide []Capability) (TokenList, error) {
	nonce := make([]byte, nonceLength)
	if _, err := rand.Read(nonce); err != nil {
		return TokenList{}, fmt.Errorf("nonce: %w", err)
	}

	token := &DMSToken{
		Issuer:     ctx.DID(),
		Subject:    subject,
		Audience:   audience,
		Action:     action,
		Topic:      topic,
		Capability: provide,
		Nonce:      nonce,
		Expire:     expire,
	}

	payload, err := token.SignatureData()
	if err != nil {
		return TokenList{}, fmt.Errorf("grant: %w", err)
	}

	signature, err := ctx.provider.Sign(payload)
	if err != nil {
		return TokenList{}, fmt.Errorf("sign: %w", err)
	}
	token.Signature = signature

	return TokenList{Tokens: []*Token{{DMS: token}}}, nil
}
// Delegate creates delegation tokens that hand the provide capabilities to
// subject for audience, restricted to topics and expiring at expire.
//
// Unless selfSignOnly is set, a chained token is produced for every provide
// token of ours that is anchored in one of our provide anchors and allows
// delegating at least one of the requested capabilities; tokens that fail to
// delegate are logged and skipped. A self-signed token covering all requested
// capabilities, with the same topics, is always appended.
//
// An empty provide list yields an empty TokenList and no error.
func (ctx *BasicCapabilityContext) Delegate(subject, audience did.DID, topics []string, expire uint64, provide []Capability, selfSignOnly bool) (TokenList, error) {
	if len(provide) == 0 {
		return TokenList{}, nil
	}

	var result []*Token

	if !selfSignOnly {
		for _, trustAnchor := range ctx.getProvideAnchors() {
			for _, t := range ctx.getProvideTokens(trustAnchor) {
				// Collect the capabilities this token lets us delegate.
				var providing []Capability
				for _, c := range provide {
					if t.Anchor(trustAnchor) && t.AllowDelegation(ctx.DID(), audience, topics, expire, c) {
						providing = append(providing, c)
					}
				}
				if len(providing) == 0 {
					continue
				}

				token, err := t.Delegate(ctx.provider, subject, audience, topics, expire, providing)
				if err != nil {
					log.Debugf("error delegating %s to %s: %s", providing, subject, err)
					continue
				}
				result = append(result, token)
			}
		}
	}

	// Self-sign as well; the topics must be carried over so the self-signed
	// token has the same scope as the chained ones.
	tokens, err := ctx.Grant(Delegate, subject, audience, topics, expire, provide)
	if err != nil {
		return TokenList{}, fmt.Errorf("error granting delegation: %w", err)
	}
	result = append(result, tokens.Tokens...)

	return TokenList{Tokens: result}, nil
}
// DelegateInvocation creates invocation tokens that let subject invoke the
// provide capabilities on audience until expire. Tokens are derived, in this
// order, from the tokens we hold about ourselves (anchored in target), from
// our provide tokens (anchored in their respective provide anchor), and from
// a self-signed grant. An empty provide list yields an empty TokenList.
func (ctx *BasicCapabilityContext) DelegateInvocation(target, subject, audience did.DID, expire uint64, provide []Capability) (TokenList, error) {
	if len(provide) == 0 {
		return TokenList{}, nil
	}

	var issued []*Token

	// Tokens about ourselves that allow delegation to the subject for the
	// audience.
	ownTokens := ctx.getSubjectTokens(ctx.DID())
	issued = append(issued, ctx.delegateInvocation(ownTokens, target, subject, audience, expire, provide)...)

	// Tokens chained on our provide anchors.
	for _, provideAnchor := range ctx.getProvideAnchors() {
		anchored := ctx.getProvideTokens(provideAnchor)
		issued = append(issued, ctx.delegateInvocation(anchored, provideAnchor, subject, audience, expire, provide)...)
	}

	// Self-signed token.
	selfSigned, err := ctx.Grant(Invoke, subject, audience, nil, expire, provide)
	if err != nil {
		return TokenList{}, fmt.Errorf("error granting invocation: %w", err)
	}
	issued = append(issued, selfSigned.Tokens...)

	return TokenList{Tokens: issued}, nil
}
// delegateInvocation derives invocation tokens from tokenList. For every
// token it collects the requested capabilities that the token, anchored in
// anchor, allows us to delegate to audience, and issues one invocation token
// for them. Tokens that allow nothing are skipped; delegation failures are
// logged and skipped.
func (ctx *BasicCapabilityContext) delegateInvocation(tokenList []*Token, anchor, subject, audience did.DID, expire uint64, provide []Capability) []*Token {
	var issued []*Token //nolint
	for _, source := range tokenList {
		var allowed []Capability
		for _, c := range provide {
			if source.Anchor(anchor) && source.AllowDelegation(ctx.DID(), audience, nil, expire, c) {
				allowed = append(allowed, c)
			}
		}
		if len(allowed) == 0 {
			continue
		}

		derived, err := source.DelegateInvocation(ctx.provider, subject, audience, expire, allowed)
		if err != nil {
			log.Debugf("error delegating invocation %s to %s: %s", allowed, subject, err)
			continue
		}
		issued = append(issued, derived)
	}
	return issued
}
// DelegateBroadcast creates broadcast tokens that let subject broadcast the
// provide capabilities on topic until expire. Tokens are derived from our
// provide tokens (anchored in their respective provide anchor) and a
// self-signed grant is appended. An empty provide list yields an empty
// TokenList.
func (ctx *BasicCapabilityContext) DelegateBroadcast(subject did.DID, topic string, expire uint64, provide []Capability) (TokenList, error) {
	if len(provide) == 0 {
		return TokenList{}, nil
	}

	var issued []*Token

	// Tokens chained on our provide anchors.
	for _, provideAnchor := range ctx.getProvideAnchors() {
		anchored := ctx.getProvideTokens(provideAnchor)
		issued = append(issued, ctx.delegateBroadcast(anchored, provideAnchor, subject, topic, expire, provide)...)
	}

	// Self-signed token.
	selfSigned, err := ctx.Grant(Broadcast, subject, did.DID{}, []string{topic}, expire, provide)
	if err != nil {
		return TokenList{}, fmt.Errorf("error granting broadcast: %w", err)
	}
	issued = append(issued, selfSigned.Tokens...)

	return TokenList{Tokens: issued}, nil
}
// delegateBroadcast derives broadcast tokens from tokenList. For every token
// it collects the requested capabilities that the token, anchored in anchor,
// allows us to delegate for topic, and issues one broadcast token for them.
// Tokens that allow nothing are skipped; delegation failures are logged and
// skipped.
func (ctx *BasicCapabilityContext) delegateBroadcast(tokenList []*Token, anchor did.DID, subject did.DID, topic string, expire uint64, provide []Capability) []*Token {
	var issued []*Token //nolint
	for _, source := range tokenList {
		var allowed []Capability
		for _, c := range provide {
			if source.Anchor(anchor) && source.AllowDelegation(ctx.DID(), did.DID{}, []string{topic}, expire, c) {
				allowed = append(allowed, c)
			}
		}
		if len(allowed) == 0 {
			continue
		}

		derived, err := source.DelegateBroadcast(ctx.provider, subject, topic, expire, allowed)
		if err != nil {
			log.Debugf("error delegating invocation %s to %s: %s", allowed, subject, err)
			continue
		}
		issued = append(issued, derived)
	}
	return issued
}
// Consume ingests the serialized token list in data on behalf of origin.
// Payloads larger than maxCapabilitySize are rejected with ErrTooBig and
// undecodable payloads with an error. A token is stored under its subject only
// if it is anchored in our own DID, in origin, or in one of our roots, or if
// one of our require tokens allows it, and only if it verifies against the
// trust context at the current time. Tokens failing either test are skipped
// (verification failures are logged); they do not make the call fail.
func (ctx *BasicCapabilityContext) Consume(origin did.DID, data []byte) error {
	if len(data) > maxCapabilitySize {
		return ErrTooBig
	}

	var payload TokenList
	if err := json.Unmarshal(data, &payload); err != nil {
		return fmt.Errorf("unmarshaling payload: %w", err)
	}

	rootAnchors := ctx.getRoots()
	requireAnchors := ctx.getRequireAnchors()
	now := uint64(time.Now().UnixNano())

	// acceptable reports whether the token is rooted in something we trust.
	acceptable := func(t *Token) bool {
		if t.Anchor(ctx.DID()) || t.Anchor(origin) {
			return true
		}
		for _, root := range rootAnchors {
			if t.Anchor(root) {
				return true
			}
		}
		for _, reqAnchor := range requireAnchors {
			for _, rt := range ctx.getRequireTokens(reqAnchor) {
				if rt.AllowAction(t) {
					return true
				}
			}
		}
		return false
	}

	for _, t := range payload.Tokens {
		if !acceptable(t) {
			continue
		}
		if err := t.Verify(ctx.trust, now); err != nil {
			log.Warnf("failed to verify token issued by %s: %s", t.Issuer(), err)
			continue
		}
		ctx.consumeSubjectToken(t)
	}

	return nil
}
// Discard removes the tokens contained in the serialized token list data from
// the context. Payloads that cannot be decoded are silently ignored.
func (ctx *BasicCapabilityContext) Discard(data []byte) {
	var payload TokenList
	if json.Unmarshal(data, &payload) == nil {
		ctx.discardTokens(payload.Tokens)
	}
}
// consumeAnchorToken merges t into the token list obtained from getf and
// stores the merged list through setf, all under the context mutex. If an
// existing token already subsumes t, nothing changes; existing tokens that t
// subsumes are dropped in favour of t.
func (ctx *BasicCapabilityContext) consumeAnchorToken(getf func() []*Token, setf func(result []*Token), t *Token) {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	current := getf()
	merged := make([]*Token, 0, len(current)+1)
	for _, existing := range current {
		switch {
		case existing.Subsumes(t):
			// t adds nothing new.
			return
		case t.Subsumes(existing):
			// existing is made redundant by t.
			continue
		}
		merged = append(merged, existing)
	}

	setf(append(merged, t))
}
// consumeRequireToken merges t into the require tokens kept for its issuer.
func (ctx *BasicCapabilityContext) consumeRequireToken(t *Token) {
	load := func() []*Token {
		return ctx.require[t.Issuer()]
	}
	store := func(result []*Token) {
		ctx.require[t.Issuer()] = result
	}
	ctx.consumeAnchorToken(load, store, t)
}
// consumeProvideToken merges t into the provide tokens kept for its issuer.
func (ctx *BasicCapabilityContext) consumeProvideToken(t *Token) {
	load := func() []*Token {
		return ctx.provide[t.Issuer()]
	}
	store := func(result []*Token) {
		ctx.provide[t.Issuer()] = result
	}
	ctx.consumeAnchorToken(load, store, t)
}
// consumeSubjectToken appends t to the tokens stored for its subject.
func (ctx *BasicCapabilityContext) consumeSubjectToken(t *Token) {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	subject := t.Subject()
	ctx.tokens[subject] = append(ctx.tokens[subject], t)
}
// Require checks that subject may invoke at least one of the capabilities in
// cap on audience. The check succeeds as soon as one of the tokens stored for
// the subject allows the invocation and is either anchored in anchor,
// anchored in one of our roots, or allowed by one of our require tokens.
//
// It returns ErrNotAuthorized (wrapped for an empty cap list) when no token
// qualifies, and an error when a DID cannot be derived from subject or
// audience.
func (ctx *BasicCapabilityContext) Require(anchor did.DID, subject crypto.ID, audience crypto.ID, cap []Capability) error {
	if len(cap) == 0 {
		return fmt.Errorf("no capabilities: %w", ErrNotAuthorized)
	}

	subjectDID, err := did.FromID(subject)
	if err != nil {
		return fmt.Errorf("DID for subject: %w", err)
	}

	audienceDID, err := did.FromID(audience)
	if err != nil {
		return fmt.Errorf("DID for audience: %w", err)
	}

	subjectTokens := ctx.getSubjectTokens(subjectDID)
	rootAnchors := ctx.getRoots()
	requireAnchors := ctx.getRequireAnchors()

	for _, t := range subjectTokens {
		for _, c := range cap {
			allows := func() bool {
				return t.AllowInvocation(subjectDID, audienceDID, c)
			}

			if t.Anchor(anchor) && allows() {
				return nil
			}

			for _, root := range rootAnchors {
				if t.Anchor(root) && allows() {
					return nil
				}
			}

			for _, reqAnchor := range requireAnchors {
				for _, rt := range ctx.getRequireTokens(reqAnchor) {
					if rt.AllowAction(t) && allows() {
						return nil
					}
				}
			}
		}
	}

	return ErrNotAuthorized
}
// RequireBroadcast checks that subject may broadcast at least one of the
// capabilities in require on topic. The check succeeds as soon as one of the
// tokens stored for the subject allows the broadcast and is either anchored in
// anchor, anchored in one of our roots, or allowed by one of our require
// tokens.
//
// It returns ErrNotAuthorized (wrapped for an empty require list) when no
// token qualifies, and an error when no DID can be derived from subject.
func (ctx *BasicCapabilityContext) RequireBroadcast(anchor did.DID, subject crypto.ID, topic string, require []Capability) error {
	if len(require) == 0 {
		return fmt.Errorf("no capabilities: %w", ErrNotAuthorized)
	}

	subjectDID, err := did.FromID(subject)
	if err != nil {
		return fmt.Errorf("DID for subject: %w", err)
	}

	subjectTokens := ctx.getSubjectTokens(subjectDID)
	rootAnchors := ctx.getRoots()
	requireAnchors := ctx.getRequireAnchors()

	for _, t := range subjectTokens {
		for _, c := range require {
			allows := func() bool {
				return t.AllowBroadcast(subjectDID, topic, c)
			}

			if t.Anchor(anchor) && allows() {
				return nil
			}

			for _, root := range rootAnchors {
				if t.Anchor(root) && allows() {
					return nil
				}
			}

			for _, reqAnchor := range requireAnchors {
				for _, rt := range ctx.getRequireTokens(reqAnchor) {
					if rt.AllowAction(t) && allows() {
						return nil
					}
				}
			}
		}
	}

	return ErrNotAuthorized
}
// Provide builds the serialized token list that proves and delegates
// authority for one interaction: invocation tokens letting subject invoke the
// invoke capabilities on audience and, when provide is non-empty, self-signed
// delegation tokens handing the provide capabilities to target with subject
// as audience.
//
// The invoke list must not be empty. ErrNotAuthorized (wrapped) is returned
// when capabilities are missing or no tokens could be produced.
func (ctx *BasicCapabilityContext) Provide(target did.DID, subject crypto.ID, audience crypto.ID, expire uint64, invoke []Capability, provide []Capability) ([]byte, error) {
	if len(invoke) == 0 && len(provide) == 0 {
		return nil, fmt.Errorf("no capabilities: %w", ErrNotAuthorized)
	}

	subjectDID, err := did.FromID(subject)
	if err != nil {
		return nil, fmt.Errorf("DID for subject: %w", err)
	}

	audienceDID, err := did.FromID(audience)
	if err != nil {
		return nil, fmt.Errorf("DID for audience: %w", err)
	}

	if len(invoke) == 0 {
		return nil, fmt.Errorf("no invocation capabilities: %w", ErrNotAuthorized)
	}

	invocation, err := ctx.DelegateInvocation(target, subjectDID, audienceDID, expire, invoke)
	if err != nil {
		return nil, fmt.Errorf("cannot provide invocation tokens: %w", err)
	}
	if len(invocation.Tokens) == 0 {
		return nil, fmt.Errorf("cannot provide the necessary invocation tokens: %w", ErrNotAuthorized)
	}

	var collected []*Token
	collected = append(collected, invocation.Tokens...)

	if len(provide) != 0 {
		delegation, err := ctx.Delegate(target, subjectDID, nil, expire, provide, true)
		if err != nil {
			return nil, fmt.Errorf("cannot provide delegation tokens: %w", err)
		}
		if len(delegation.Tokens) == 0 {
			return nil, fmt.Errorf("cannot provide the necessary delegation tokens: %w", ErrNotAuthorized)
		}
		collected = append(collected, delegation.Tokens...)
	}

	encoded, err := json.Marshal(TokenList{Tokens: collected})
	if err != nil {
		return nil, fmt.Errorf("marshaling payload: %w", err)
	}
	return encoded, nil
}
// ProvideBroadcast returns the JSON encoding of the broadcast tokens that
// DelegateBroadcast produces for subject on topic.
//
// An error wrapping ErrNotAuthorized is returned when provide is empty or
// when no broadcast tokens could be produced.
func (ctx *BasicCapabilityContext) ProvideBroadcast(subject crypto.ID, topic string, expire uint64, provide []Capability) ([]byte, error) {
	if len(provide) == 0 {
		return nil, fmt.Errorf("no capabilities: %w", ErrNotAuthorized)
	}

	subjectDID, err := did.FromID(subject)
	if err != nil {
		return nil, fmt.Errorf("DID for subject: %w", err)
	}

	tokens, err := ctx.DelegateBroadcast(subjectDID, topic, expire, provide)
	if err != nil {
		return nil, fmt.Errorf("cannot provide broadcast tokens: %w", err)
	}
	if len(tokens.Tokens) == 0 {
		return nil, fmt.Errorf("cannot provide the necessary broadcast tokens: %w", ErrNotAuthorized)
	}

	encoded, err := json.Marshal(tokens)
	if err != nil {
		return nil, fmt.Errorf("marshaling payload: %w", err)
	}
	return encoded, nil
}
// getRoots returns a snapshot of the keys of ctx.roots, taken while holding
// ctx.mx. The order of the returned slice follows map iteration and is
// therefore unspecified.
func (ctx *BasicCapabilityContext) getRoots() []did.DID {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	anchors := make([]did.DID, 0, len(ctx.roots))
	for root := range ctx.roots {
		anchors = append(anchors, root)
	}
	return anchors
}
// addRoots records every DID in anchors as a key of ctx.roots while holding
// ctx.mx. Anchors that are already present are left as they are.
func (ctx *BasicCapabilityContext) addRoots(anchors []did.DID) {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	for i := range anchors {
		ctx.roots[anchors[i]] = struct{}{}
	}
}
// getRequireAnchors returns a snapshot of the keys of ctx.require, taken
// while holding ctx.mx. The order of the returned slice is unspecified.
func (ctx *BasicCapabilityContext) getRequireAnchors() []did.DID {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	anchors := make([]did.DID, 0, len(ctx.require))
	for key := range ctx.require {
		anchors = append(anchors, key)
	}
	return anchors
}
// getProvideAnchors returns a snapshot of the keys of ctx.provide, taken
// while holding ctx.mx. The order of the returned slice is unspecified.
func (ctx *BasicCapabilityContext) getProvideAnchors() []did.DID {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	anchors := make([]did.DID, 0, len(ctx.provide))
	for key := range ctx.provide {
		anchors = append(anchors, key)
	}
	return anchors
}
// getTokens loads a token list through getf, drops the tokens that have
// expired as of now, writes the filtered list back through setf and returns
// it. Both callbacks run while ctx.mx is held. When getf reports that no
// list exists, nil is returned and setf is not called.
func (ctx *BasicCapabilityContext) getTokens(getf func() ([]*Token, bool), setf func([]*Token)) []*Token {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	stored, found := getf()
	if !found {
		return nil
	}

	// Filter a copy so the slice handed out by getf is never modified in place.
	deadline := uint64(time.Now().UnixNano())
	live := slices.Clone(stored)
	live = slices.DeleteFunc(live, func(tok *Token) bool {
		return tok.ExpireBefore(deadline)
	})

	setf(live)
	return live
}
// getRequireTokens returns the unexpired tokens stored in ctx.require for
// anchor, pruning expired ones from the map entry via getTokens.
func (ctx *BasicCapabilityContext) getRequireTokens(anchor did.DID) []*Token {
	load := func() ([]*Token, bool) {
		tokens, found := ctx.require[anchor]
		return tokens, found
	}
	store := func(tokens []*Token) {
		ctx.require[anchor] = tokens
	}
	return ctx.getTokens(load, store)
}
// getProvideTokens returns the unexpired tokens stored in ctx.provide for
// anchor, pruning expired ones from the map entry via getTokens.
func (ctx *BasicCapabilityContext) getProvideTokens(anchor did.DID) []*Token {
	load := func() ([]*Token, bool) {
		tokens, found := ctx.provide[anchor]
		return tokens, found
	}
	store := func(tokens []*Token) {
		ctx.provide[anchor] = tokens
	}
	return ctx.getTokens(load, store)
}
// getSubjectTokens returns the unexpired tokens stored in ctx.tokens for
// subject, pruning expired ones from the map entry via getTokens.
func (ctx *BasicCapabilityContext) getSubjectTokens(subject did.DID) []*Token {
	load := func() ([]*Token, bool) {
		tokens, found := ctx.tokens[subject]
		return tokens, found
	}
	store := func(tokens []*Token) {
		ctx.tokens[subject] = tokens
	}
	return ctx.getTokens(load, store)
}
// discardTokens removes each of the given tokens from ctx.tokens while
// holding ctx.mx. A stored token matches when it has the same issuer and
// the same nonce as the token being discarded. Subjects left without any
// token are deleted from the map.
func (ctx *BasicCapabilityContext) discardTokens(tokens []*Token) {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	for _, discarded := range tokens {
		subject := discarded.Subject()

		remaining := slices.Clone(ctx.tokens[subject])
		remaining = slices.DeleteFunc(remaining, func(held *Token) bool {
			return discarded.Issuer() == held.Issuer() && bytes.Equal(discarded.Nonce(), held.Nonce())
		})

		if len(remaining) > 0 {
			ctx.tokens[subject] = remaining
			continue
		}
		delete(ctx.tokens, subject)
	}
}
// gc runs gcTokens once per gcInterval until gcCtx is done. It blocks for
// its whole lifetime, and the ticker is stopped on return.
func (ctx *BasicCapabilityContext) gc(gcCtx context.Context, gcInterval time.Duration) {
	tick := time.NewTicker(gcInterval)
	defer tick.Stop()

	for {
		select {
		case <-gcCtx.Done():
			return
		case <-tick.C:
			ctx.gcTokens()
		}
	}
}
// gcTokens drops expired tokens from ctx.require, ctx.provide and ctx.tokens
// while holding ctx.mx. Map entries whose token list becomes empty are
// deleted; the others are replaced by a filtered copy.
func (ctx *BasicCapabilityContext) gcTokens() {
	ctx.mx.Lock()
	defer ctx.mx.Unlock()

	deadline := uint64(time.Now().UnixNano())

	// live returns a filtered copy of list holding only unexpired tokens.
	live := func(list []*Token) []*Token {
		return slices.DeleteFunc(slices.Clone(list), func(tok *Token) bool {
			return tok.ExpireBefore(deadline)
		})
	}

	for anchor, list := range ctx.require {
		if kept := live(list); len(kept) > 0 {
			ctx.require[anchor] = kept
		} else {
			delete(ctx.require, anchor)
		}
	}

	for anchor, list := range ctx.provide {
		if kept := live(list); len(kept) > 0 {
			ctx.provide[anchor] = kept
		} else {
			delete(ctx.provide, anchor)
		}
	}

	for subject, list := range ctx.tokens {
		if kept := live(list); len(kept) > 0 {
			ctx.tokens[subject] = kept
		} else {
			delete(ctx.tokens, subject)
		}
	}
}
package ucan
import (
"io"
"gitlab.com/nunet/device-management-service/lib/did"
)
// SaveCapabilityContext is not implemented yet: it ignores both arguments
// and always returns 0 and ErrTODO.
func SaveCapabilityContext(_ CapabilityContext, _ io.Writer) (int, error) {
// TODO
return 0, ErrTODO
}
// LoadCapabilityContext is not implemented yet: it ignores both arguments
// and always returns a nil context and ErrTODO.
func LoadCapabilityContext(_ io.Reader, _ did.TrustContext) (CapabilityContext, error) {
// TODO
return nil, ErrTODO
}
package ucan
import (
"crypto/rand"
"encoding/json"
"fmt"
"slices"
"time"
"gitlab.com/nunet/device-management-service/lib/did"
)
// Action names the kind of operation a token authorizes.
type Action string
// Token actions and the size of the random nonce minted with each token.
const (
// Invoke is the action required by AllowInvocation.
Invoke Action = "invoke"
// Delegate is the action required to delegate further or to allow other tokens.
Delegate Action = "delegate"
// Broadcast is the action required by AllowBroadcast.
Broadcast Action = "broadcast"
// Revoke Action = "revoke" // TODO
nonceLength = 12 // 96 bits
)
// signaturePrefix is prepended to the JSON encoding of a DMS token to form
// the data that gets signed and verified.
var signaturePrefix = []byte("dms:token:")
// Token is an envelope holding one of the supported token flavours. The
// methods on Token dispatch to DMS when it is set; the UCAN flavour is not
// implemented yet and behaves like an empty token.
type Token struct {
// DMS tokens
DMS *DMSToken `json:"dms,omitempty"`
// UCAN standard (when it is done) envelope for BYO anchors
UCAN *BYOToken `json:"ucan,omitempty"`
}
// DMSToken is the native DMS capability token. Its JSON encoding with the
// Signature field cleared, prefixed by signaturePrefix, is what gets signed,
// so the field set and tags are part of the signed format.
type DMSToken struct {
// Issuer is the DID whose anchor must verify Signature.
Issuer did.DID `json:"iss"`
// Subject is the DID the token is about.
Subject did.DID `json:"sub"`
// Audience restricts who the token applies to; an empty audience matches anyone.
Audience did.DID `json:"aud"`
// Action is the kind of operation the token authorizes.
Action Action `json:"act"`
// Topic lists the topics the token covers.
Topic []string `json:"topic,omitempty"`
// Capability lists the granted capabilities.
Capability []Capability `json:"cap"`
// Nonce is filled with nonceLength random bytes when the token is minted.
Nonce []byte `json:"nonce"`
// Expire is the expiration time, compared against UnixNano timestamps.
Expire uint64 `json:"exp"`
// Chain is the token this one was delegated from, if any.
Chain *Token `json:"chain,omitempty"`
// Signature is the issuer's signature over SignatureData.
Signature []byte `json:"sig,omitempty"`
}
// BYOToken is a placeholder for the UCAN envelope; it has no fields yet.
type BYOToken struct {
// TODO followup
}
// TokenList is the JSON container for a set of tokens.
type TokenList struct {
Tokens []*Token `json:"tok,omitempty"`
}
// SignatureData returns the bytes that are signed for this token. Only DMS
// tokens are supported; anything else yields ErrBadToken.
func (t *Token) SignatureData() ([]byte, error) {
	if t.DMS != nil {
		return t.DMS.SignatureData()
	}
	// TODO UCAN envelopes for BYO trust; followup
	return nil, ErrBadToken
}
// SignatureData returns signaturePrefix followed by the JSON encoding of a
// copy of the token whose Signature field has been cleared. The receiver
// itself is not modified.
func (t *DMSToken) SignatureData() ([]byte, error) {
	unsigned := *t
	unsigned.Signature = nil

	payload, err := json.Marshal(&unsigned)
	if err != nil {
		return nil, fmt.Errorf("signature data: %w", err)
	}

	out := make([]byte, 0, len(signaturePrefix)+len(payload))
	out = append(out, signaturePrefix...)
	out = append(out, payload...)
	return out, nil
}
// Issuer returns the issuer of the DMS token, or the zero DID when the
// envelope holds no DMS token.
func (t *Token) Issuer() did.DID {
	if t.DMS != nil {
		return t.DMS.Issuer
	}
	// TODO UCAN envelopes for BYO trust; followup
	return did.DID{}
}
// Subject returns the subject of the DMS token, or the zero DID when the
// envelope holds no DMS token.
func (t *Token) Subject() did.DID {
	if t.DMS != nil {
		return t.DMS.Subject
	}
	// TODO UCAN envelopes for BYO trust; followup
	return did.DID{}
}
// Audience returns the audience of the DMS token, or the zero DID when the
// envelope holds no DMS token.
func (t *Token) Audience() did.DID {
	if t.DMS != nil {
		return t.DMS.Audience
	}
	// TODO UCAN envelopes for BYO trust; followup
	return did.DID{}
}
// Capability returns the capabilities of the DMS token, or nil when the
// envelope holds no DMS token.
func (t *Token) Capability() []Capability {
	if t.DMS != nil {
		return t.DMS.Capability
	}
	// TODO UCAN envelopes for BYO trust; followup
	return nil
}
// Topic returns the topics of the DMS token, or nil when the envelope holds
// no DMS token.
func (t *Token) Topic() []string {
	if t.DMS != nil {
		return t.DMS.Topic
	}
	// TODO UCAN envelopes for BYO trust; followup
	return nil
}
// Expire returns the expiration time of the DMS token, or 0 when the
// envelope holds no DMS token.
func (t *Token) Expire() uint64 {
	if t.DMS != nil {
		return t.DMS.Expire
	}
	// TODO UCAN envelopes for BYO trust; followup
	return 0
}
// Nonce returns the nonce of the DMS token, or nil when the envelope holds
// no DMS token.
func (t *Token) Nonce() []byte {
	if t.DMS != nil {
		return t.DMS.Nonce
	}
	// TODO UCAN envelopes for BYO trust; followup
	return nil
}
// Action returns the action of the DMS token, or the empty Action when the
// envelope holds no DMS token.
func (t *Token) Action() Action {
	if t.DMS != nil {
		return t.DMS.Action
	}
	// TODO UCAN envelopes for BYO trust; followup
	return Action("")
}
// Verify checks the DMS token against trust at time now. Envelopes without
// a DMS token are rejected with ErrBadToken.
func (t *Token) Verify(trust did.TrustContext, now uint64) error {
	if t.DMS != nil {
		return t.DMS.Verify(trust, now)
	}
	// TODO UCAN envelopes for BYO trust; followup
	return ErrBadToken
}
// Verify checks that the token is valid at time now.
//
// The token, including its chain, must not expire before now. When the token
// has a chain, the chain token must:
//   - carry the Delegate action,
//   - not expire before this token does,
//   - verify on its own,
//   - have this token's issuer as its subject,
//   - cover every topic of this token, and
//   - allow delegation of every capability of this token.
//
// Finally the signature is verified with the anchor that trust returns for
// the issuer.
//
// It returns ErrCapabilityExpired or ErrNotAuthorized for the respective
// failures, and a wrapped error when the anchor lookup, the signature data
// or the signature check fails.
func (t *DMSToken) Verify(trust did.TrustContext, now uint64) error {
	if t.ExpireBefore(now) {
		return ErrCapabilityExpired
	}

	if t.Chain != nil {
		if t.Chain.Action() != Delegate {
			return ErrNotAuthorized
		}

		if t.Chain.ExpireBefore(t.Expire) {
			return ErrCapabilityExpired
		}

		if err := t.Chain.Verify(trust, now); err != nil {
			return err
		}

		if !t.Issuer.Equal(t.Chain.Subject()) {
			return ErrNotAuthorized
		}

		for _, topic := range t.Topic {
			if !slices.Contains(t.Chain.Topic(), topic) {
				return ErrNotAuthorized
			}
		}

		// Every capability must be delegable by the chain. The list is walked
		// directly rather than a working copy that is compacted during the
		// walk: compacting the slice being ranged over shifts its elements,
		// so some capabilities would never be checked and a valid chain
		// could be rejected.
		for _, c := range t.Capability {
			if !t.Chain.AllowDelegation(t.Issuer, t.Audience, t.Topic, t.Expire, c) {
				return ErrNotAuthorized
			}
		}
	}

	anchor, err := trust.GetAnchor(t.Issuer)
	if err != nil {
		return fmt.Errorf("verify: anchor: %w", err)
	}

	data, err := t.SignatureData()
	if err != nil {
		return fmt.Errorf("verify: signature data: %w", err)
	}

	if err := anchor.Verify(data, t.Signature); err != nil {
		return fmt.Errorf("verify: signature: %w", err)
	}

	return nil
}
// AllowAction reports whether the DMS token allows the action described by
// ot. Envelopes without a DMS token allow nothing.
func (t *Token) AllowAction(ot *Token) bool {
	if t.DMS != nil {
		return t.DMS.AllowAction(ot)
	}
	// TODO UCAN envelopes for BYO trust; followup
	return false
}
// AllowAction reports whether this delegation token allows ot.
//
// The receiver must be a Delegate token that does not expire before ot,
// ot must be anchored at the receiver's subject, and a non-empty audience
// must equal ot's audience. For broadcast tokens every topic of ot must be
// listed in the receiver, and every capability of ot must be implied by one
// of the receiver's capabilities.
func (t *DMSToken) AllowAction(ot *Token) bool {
	switch {
	case t.Action != Delegate:
		return false
	case t.ExpireBefore(ot.Expire()):
		return false
	case !ot.Anchor(t.Subject):
		return false
	case !t.Audience.Empty() && !t.Audience.Equal(ot.Audience()):
		return false
	}

	if ot.Action() == Broadcast {
		for _, topic := range ot.Topic() {
			if !slices.Contains(t.Topic, topic) {
				return false
			}
		}
	}

	for _, wanted := range ot.Capability() {
		covered := slices.ContainsFunc(t.Capability, func(granted Capability) bool {
			return granted.Implies(wanted)
		})
		if !covered {
			return false
		}
	}

	return true
}
// Size returns the length in bytes of the token's signature data. The error
// from SignatureData is deliberately ignored: on failure the data is nil and
// the reported size is 0.
func (t *Token) Size() int {
	payload, _ := t.SignatureData()
	return len(payload)
}
// Subsumes reports whether the DMS token subsumes ot. Envelopes without a
// DMS token subsume nothing.
func (t *Token) Subsumes(ot *Token) bool {
	if t.DMS != nil {
		return t.DMS.Subsumes(ot)
	}
	// TODO UCAN envelopes for BYO trust; followup
	return false
}
// Subsumes reports whether this token makes ot redundant: both tokens share
// issuer, subject and audience, this token expires strictly later, and every
// capability of ot is implied by one of this token's capabilities.
//
// NOTE(review): the expiry comparison is strict, so a token never subsumes
// one with the same expiry (including an identical copy) — confirm this is
// intended.
func (t *DMSToken) Subsumes(ot *Token) bool {
	sameParties := t.Issuer.Equal(ot.Issuer()) &&
		t.Subject.Equal(ot.Subject()) &&
		t.Audience.Equal(ot.Audience())
	if !sameParties || t.Expire <= ot.Expire() {
		return false
	}

	for _, wanted := range ot.Capability() {
		implied := slices.ContainsFunc(t.Capability, func(granted Capability) bool {
			return granted.Implies(wanted)
		})
		if !implied {
			return false
		}
	}
	return true
}
// AllowInvocation reports whether the DMS token allows audience to invoke
// capability c on subject. Envelopes without a DMS token allow nothing.
func (t *Token) AllowInvocation(subject, audience did.DID, c Capability) bool {
	if t.DMS != nil {
		return t.DMS.AllowInvocation(subject, audience, c)
	}
	// TODO UCAN envelopes for BYO trust; followup
	return false
}
// AllowInvocation reports whether this token authorizes invoking c.
//
// The token must be an Invoke token for subject, a non-empty audience must
// equal audience, and at least one granted capability must imply c.
func (t *DMSToken) AllowInvocation(subject, audience did.DID, c Capability) bool {
	switch {
	case t.Action != Invoke:
		return false
	case !t.Subject.Equal(subject):
		return false
	case !t.Audience.Empty() && !t.Audience.Equal(audience):
		return false
	}

	return slices.ContainsFunc(t.Capability, func(granted Capability) bool {
		return granted.Implies(c)
	})
}
// AllowBroadcast reports whether the DMS token allows subject to broadcast
// capability c on topic. Envelopes without a DMS token allow nothing.
func (t *Token) AllowBroadcast(subject did.DID, topic string, c Capability) bool {
	if t.DMS != nil {
		return t.DMS.AllowBroadcast(subject, topic, c)
	}
	// TODO UCAN envelopes for BYO trust; followup
	return false
}
// AllowBroadcast reports whether this token authorizes broadcasting c.
//
// The token must be a Broadcast token for subject with an empty audience,
// it must list topic, and at least one granted capability must imply c.
func (t *DMSToken) AllowBroadcast(subject did.DID, topic string, c Capability) bool {
	switch {
	case t.Action != Broadcast:
		return false
	case !t.Subject.Equal(subject):
		return false
	case !t.Audience.Empty():
		return false
	case !slices.Contains(t.Topic, topic):
		return false
	}

	return slices.ContainsFunc(t.Capability, func(granted Capability) bool {
		return granted.Implies(c)
	})
}
// AllowDelegation reports whether the DMS token allows issuer to delegate
// capability c with the given audience, topics and expiry. Envelopes
// without a DMS token allow nothing.
func (t *Token) AllowDelegation(issuer, audience did.DID, topics []string, expire uint64, c Capability) bool {
	if t.DMS != nil {
		return t.DMS.AllowDelegation(issuer, audience, topics, expire, c)
	}
	// TODO UCAN envelopes for BYO trust; followup
	return false
}
// AllowDelegation reports whether this token authorizes issuer to delegate c.
//
// The token must be a Delegate token that does not expire before expire, its
// subject must equal issuer, a non-empty audience must equal audience, every
// requested topic must be listed in the token, and at least one granted
// capability must imply c.
func (t *DMSToken) AllowDelegation(issuer, audience did.DID, topics []string, expire uint64, c Capability) bool {
	switch {
	case t.Action != Delegate:
		return false
	case t.ExpireBefore(expire):
		return false
	case !t.Subject.Equal(issuer):
		return false
	case !t.Audience.Empty() && !t.Audience.Equal(audience):
		return false
	}

	for _, wanted := range topics {
		if !slices.Contains(t.Topic, wanted) {
			return false
		}
	}

	return slices.ContainsFunc(t.Capability, func(granted Capability) bool {
		return granted.Implies(c)
	})
}
// Delegate mints a new Delegate token chained to this one, signed by
// provider, and returns it wrapped in a Token envelope. Envelopes without a
// DMS token yield ErrBadToken.
//
// Failures are wrapped with the "delegate" context so they can be told
// apart from errors raised by DelegateInvocation.
func (t *Token) Delegate(provider did.Provider, subject, audience did.DID, topics []string, expire uint64, c []Capability) (*Token, error) {
	switch {
	case t.DMS != nil:
		result, err := t.DMS.Delegate(provider, subject, audience, topics, expire, c)
		if err != nil {
			return nil, fmt.Errorf("delegate: %w", err)
		}
		return &Token{DMS: result}, nil
	case t.UCAN != nil:
		// TODO UCAN envelopes for BYO trust; followup
		fallthrough
	default:
		return nil, ErrBadToken
	}
}
// Delegate mints a token with the Delegate action chained to t; see delegate.
func (t *DMSToken) Delegate(provider did.Provider, subject, audience did.DID, topics []string, expire uint64, c []Capability) (*DMSToken, error) {
return t.delegate(Delegate, provider, subject, audience, topics, expire, c)
}
// delegate mints a new token with the given action that is chained to t and
// issued and signed by provider.
//
// Only Delegate tokens may be used as the parent; otherwise ErrNotAuthorized
// is returned. The new token gets a fresh random nonce of nonceLength bytes.
// Errors from the random source, from building the signature data and from
// signing are returned wrapped.
func (t *DMSToken) delegate(action Action, provider did.Provider, subject, audience did.DID, topics []string, expire uint64, c []Capability) (*DMSToken, error) {
	if t.Action != Delegate {
		return nil, ErrNotAuthorized
	}

	nonce := make([]byte, nonceLength)
	if _, err := rand.Read(nonce); err != nil {
		return nil, fmt.Errorf("nonce: %w", err)
	}

	minted := &DMSToken{
		Issuer:     provider.DID(),
		Subject:    subject,
		Audience:   audience,
		Action:     action,
		Topic:      topics,
		Capability: c,
		Nonce:      nonce,
		Expire:     expire,
		Chain:      &Token{DMS: t},
	}

	// The signature covers the token with its Signature field still empty.
	payload, err := minted.SignatureData()
	if err != nil {
		return nil, fmt.Errorf("delegate: %w", err)
	}

	signature, err := provider.Sign(payload)
	if err != nil {
		return nil, fmt.Errorf("sign: %w", err)
	}
	minted.Signature = signature

	return minted, nil
}
// DelegateInvocation mints a new Invoke token chained to this one, signed by
// provider, and returns it wrapped in a Token envelope. Envelopes without a
// DMS token yield ErrBadToken.
func (t *Token) DelegateInvocation(provider did.Provider, subject, audience did.DID, expire uint64, c []Capability) (*Token, error) {
	if t.DMS == nil {
		// TODO UCAN envelopes for BYO trust; followup
		return nil, ErrBadToken
	}

	minted, err := t.DMS.DelegateInvocation(provider, subject, audience, expire, c)
	if err != nil {
		return nil, fmt.Errorf("delegate invocation: %w", err)
	}
	return &Token{DMS: minted}, nil
}
// DelegateInvocation mints a token with the Invoke action and no topics,
// chained to t; see delegate.
func (t *DMSToken) DelegateInvocation(provider did.Provider, subject, audience did.DID, expire uint64, c []Capability) (*DMSToken, error) {
return t.delegate(Invoke, provider, subject, audience, nil, expire, c)
}
// DelegateBroadcast mints a new Broadcast token for topic chained to this
// one, signed by provider, and returns it wrapped in a Token envelope.
// Envelopes without a DMS token yield ErrBadToken.
//
// Failures are wrapped with the "delegate broadcast" context so they can be
// told apart from errors raised by DelegateInvocation.
func (t *Token) DelegateBroadcast(provider did.Provider, subject did.DID, topic string, expire uint64, c []Capability) (*Token, error) {
	switch {
	case t.DMS != nil:
		result, err := t.DMS.DelegateBroadcast(provider, subject, topic, expire, c)
		if err != nil {
			return nil, fmt.Errorf("delegate broadcast: %w", err)
		}
		return &Token{DMS: result}, nil
	case t.UCAN != nil:
		// TODO UCAN envelopes for BYO trust; followup
		fallthrough
	default:
		return nil, ErrBadToken
	}
}
// DelegateBroadcast mints a token with the Broadcast action for the single
// given topic and an empty audience, chained to t; see delegate.
func (t *DMSToken) DelegateBroadcast(provider did.Provider, subject did.DID, topic string, expire uint64, c []Capability) (*DMSToken, error) {
return t.delegate(Broadcast, provider, subject, did.DID{}, []string{topic}, expire, c)
}
// Anchor reports whether anchor issued the DMS token or any token in its
// chain. Envelopes without a DMS token are anchored nowhere.
func (t *Token) Anchor(anchor did.DID) bool {
	if t.DMS != nil {
		return t.DMS.Anchor(anchor)
	}
	// TODO UCAN envelopes for BYO trust; followup
	return false
}
// Anchor reports whether anchor is the issuer of this token or of any token
// further up its chain.
func (t *DMSToken) Anchor(anchor did.DID) bool {
	if t.Issuer.Equal(anchor) {
		return true
	}
	return t.Chain != nil && t.Chain.Anchor(anchor)
}
// Expired reports whether the token expires before the current time,
// expressed in Unix nanoseconds.
func (t *Token) Expired() bool {
	now := uint64(time.Now().UnixNano())
	return t.ExpireBefore(now)
}
// ExpireBefore reports whether the DMS token expires before deadline.
// Envelopes without a DMS token are always treated as expired.
func (t *Token) ExpireBefore(deadline uint64) bool {
	if t.DMS != nil {
		return t.DMS.ExpireBefore(deadline)
	}
	// TODO UCAN envelopes for BYO trust; followup
	return true
}
// ExpireBefore reports whether this token, or any token in its chain, has an
// expiry earlier than deadline.
func (t *DMSToken) ExpireBefore(deadline uint64) bool {
	if deadline > t.Expire {
		return true
	}
	return t.Chain != nil && t.Chain.ExpireBefore(deadline)
}
// SelfSigned reports whether the root of the DMS token's chain was issued by
// origin. Envelopes without a DMS token are never self-signed.
func (t *Token) SelfSigned(origin did.DID) bool {
	if t.DMS != nil {
		return t.DMS.SelfSigned(origin)
	}
	// TODO UCAN envelopes for BYO trust; followup
	return false
}
// SelfSigned walks to the root of the chain and reports whether that root
// token was issued by origin.
func (t *DMSToken) SelfSigned(origin did.DID) bool {
	if t.Chain == nil {
		return t.Issuer.Equal(origin)
	}
	return t.Chain.SelfSigned(origin)
}
package libp2p
import (
"bytes"
"context"
"errors"
"fmt"
"strings"
"sync"
"github.com/gin-gonic/gin"
"github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/peerstore"
"github.com/multiformats/go-multiaddr"
commonproto "gitlab.com/nunet/device-management-service/proto/generated/v1/common"
"google.golang.org/protobuf/proto"
)
// Bootstrap bootstraps the DHT and then connects to every address in
// bootstrapPeers concurrently, waiting for all attempts to finish.
//
// Only a failure of the DHT bootstrap itself is returned as an error.
// Addresses that cannot be parsed or peers that cannot be reached are
// logged and otherwise ignored.
func (l *Libp2p) Bootstrap(ctx context.Context, bootstrapPeers []multiaddr.Multiaddr) error {
	if err := l.DHT.Bootstrap(ctx); err != nil {
		return fmt.Errorf("failed to prepare this node for bootstraping: %w", err)
	}

	if len(bootstrapPeers) == 0 {
		return nil
	}

	// connect dials a single bootstrap peer and logs the outcome.
	connect := func(peerAddr multiaddr.Multiaddr) {
		info, err := peer.AddrInfoFromP2pAddr(peerAddr)
		if err != nil {
			zlog.Sugar().Errorf("failed to convert multi addr to addr info %v - %v", peerAddr, err)
			return
		}
		if err := l.Host.Connect(ctx, *info); err != nil {
			zlog.Sugar().Errorf("failed to connect to bootstrap node %s - %v", info.ID.String(), err)
			return
		}
		zlog.Sugar().Infof("connected to Bootstrap Node %s", info.ID.String())
	}

	// bootstrap all nodes at the same time.
	var wg sync.WaitGroup
	wg.Add(len(bootstrapPeers))
	for _, addr := range bootstrapPeers {
		go func(peerAddr multiaddr.Multiaddr) {
			defer wg.Done()
			connect(peerAddr)
		}(addr)
	}
	wg.Wait()

	return nil
}
// dhtValidator validates records stored in the DHT.
type dhtValidator struct {
// PS is the peerstore handed in at construction; Validate does not read it.
PS peerstore.Peerstore
// customNamespace is the prefix every validated key must start with.
customNamespace string
}
// Validate validates an item placed into the dht.
//
// An empty value is accepted without further checks, since it represents a
// deletion. Otherwise the key must start with the validator's namespace and
// the value must be an Advertisement whose signature verifies against the
// secp256k1 public key it carries. The signed payload is the concatenation
// of the peer ID, a single byte derived from the timestamp, the data and the
// public key.
func (d dhtValidator) Validate(key string, value []byte) error {
	// empty value is considered deleting an item from the dht
	if len(value) == 0 {
		return nil
	}

	if !strings.HasPrefix(key, d.customNamespace) {
		return errors.New("invalid key namespace")
	}

	var envelope commonproto.Advertisement
	if err := proto.Unmarshal(value, &envelope); err != nil {
		return fmt.Errorf("failed to unmarshal envelope: %w", err)
	}

	pubKey, err := crypto.UnmarshalSecp256k1PublicKey(envelope.PublicKey)
	if err != nil {
		return fmt.Errorf("failed to unmarshal public key: %w", err)
	}

	// Only the low byte of the timestamp takes part in the signed payload.
	timestampByte := []byte{byte(envelope.Timestamp)}
	parts := [][]byte{
		[]byte(envelope.PeerId),
		timestampByte,
		envelope.Data,
		envelope.PublicKey,
	}
	signedPayload := bytes.Join(parts, nil)

	valid, err := pubKey.Verify(signedPayload, envelope.Signature)
	switch {
	case err != nil:
		return fmt.Errorf("failed to verify envelope: %w", err)
	case !valid:
		return errors.New("failed to verify envelope, public key didn't sign payload")
	}

	return nil
}
// Select always picks the first candidate value (index 0) and never fails.
func (dhtValidator) Select(_ string, _ [][]byte) (int, error) { return 0, nil }
// TODO remove the below when network package is fully implemented
// UpdateKadDHT is a stub: it only logs a warning.
func (l *Libp2p) UpdateKadDHT() {
zlog.Warn("UpdateKadDHT: Stub")
}
// ListKadDHTPeers is a stub: it logs a warning, ignores its arguments and
// returns a nil list with a nil error.
func (l *Libp2p) ListKadDHTPeers(_ context.Context, _ *gin.Context) ([]string, error) {
zlog.Warn("ListKadDHTPeers: Stub")
return nil, nil
}
package libp2p
import (
"context"
"fmt"
"os"
"github.com/libp2p/go-libp2p/core/discovery"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
dutil "github.com/libp2p/go-libp2p/p2p/discovery/util"
)
// DiscoverDialPeers discovers peers through the rendezvous point and dials
// them.
//
// A non-empty discovery result replaces the stored list of discovered peers;
// an empty result keeps the previous list. The list is then filtered to drop
// peers without addresses and the local host before dialing. Only a
// discovery failure is returned as an error.
func (l *Libp2p) DiscoverDialPeers(ctx context.Context) error {
	found, err := l.findPeersFromRendezvousDiscovery(ctx)
	if err != nil {
		return err
	}
	if len(found) > 0 {
		l.discoveredPeers = found
	}

	// filter out peers with no listening addresses and self host
	l.discoveredPeers = PeerPassFilter(l.discoveredPeers, NoAddrIDFilter{ID: l.Host.ID()})

	l.dialPeers(ctx)
	return nil
}
// advertiseForRendezvousDiscovery advertises this node under the configured
// rendezvous point through the discovery service. The value returned by
// Advertise alongside the error is discarded.
func (l *Libp2p) advertiseForRendezvousDiscovery(context context.Context) error {
	if _, err := l.discovery.Advertise(context, l.config.Rendezvous); err != nil {
		return err
	}
	return nil
}
// findPeersFromRendezvousDiscovery looks up peers advertised under the
// configured rendezvous point, limited to the configured
// PeerCountDiscoveryLimit. Lookup failures are returned wrapped.
func (l *Libp2p) findPeersFromRendezvousDiscovery(ctx context.Context) ([]peer.AddrInfo, error) {
	limit := discovery.Limit(l.config.PeerCountDiscoveryLimit)

	found, err := dutil.FindPeers(ctx, l.discovery, l.config.Rendezvous, limit)
	if err != nil {
		return nil, fmt.Errorf("failed to discover peers: %w", err)
	}
	return found, nil
}
// dialPeers dials every discovered peer that is neither the local host nor
// already connected. Dial failures are not returned; outcomes are logged at
// debug level only when the NUNET_DEBUG_VERBOSE environment variable is set.
func (l *Libp2p) dialPeers(ctx context.Context) {
	// verbose reports whether NUNET_DEBUG_VERBOSE is present in the environment.
	verbose := func() bool {
		_, set := os.LookupEnv("NUNET_DEBUG_VERBOSE")
		return set
	}

	for _, candidate := range l.discoveredPeers {
		if candidate.ID == l.Host.ID() {
			continue
		}
		if l.Host.Network().Connectedness(candidate.ID) == network.Connected {
			continue
		}

		if _, err := l.Host.Network().DialPeer(ctx, candidate.ID); err != nil {
			if verbose() {
				zlog.Sugar().Debugf("couldn't establish connection with: %s - error: %v", candidate.ID.String(), err)
			}
			continue
		}

		if verbose() {
			zlog.Sugar().Debugf("connected with: %s", candidate.ID.String())
		}
	}
}
package libp2p
import (
"github.com/libp2p/go-libp2p/core/connmgr"
"github.com/libp2p/go-libp2p/core/control"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/multiformats/go-multiaddr"
mafilt "github.com/whyrusleeping/multiaddr-filter"
)
// defaultServerFilters lists the IPv4 and IPv6 CIDR ranges, in multiaddr
// mask form, that NewHost turns into deny filters and, in server mode, also
// excludes from the announced addresses.
var defaultServerFilters = []string{
"/ip4/10.0.0.0/ipcidr/8",
"/ip4/100.64.0.0/ipcidr/10",
"/ip4/169.254.0.0/ipcidr/16",
"/ip4/172.16.0.0/ipcidr/12",
"/ip4/192.0.0.0/ipcidr/24",
"/ip4/192.0.2.0/ipcidr/24",
"/ip4/192.168.0.0/ipcidr/16",
"/ip4/198.18.0.0/ipcidr/15",
"/ip4/198.51.100.0/ipcidr/24",
"/ip4/203.0.113.0/ipcidr/24",
"/ip4/240.0.0.0/ipcidr/4",
"/ip6/100::/ipcidr/64",
"/ip6/2001:2::/ipcidr/48",
"/ip6/2001:db8::/ipcidr/32",
"/ip6/fc00::/ipcidr/7",
"/ip6/fe80::/ipcidr/10",
}
// PeerFilter is an interface for filtering peers.
// A peer passes the filter when satisfies returns true for it.
type PeerFilter interface {
satisfies(p peer.AddrInfo) bool
}
// NoAddrIDFilter filters out peers with no listening addresses
// as well as the peer whose ID equals the ID field.
type NoAddrIDFilter struct {
// ID is the peer ID to reject.
ID peer.ID
}
// satisfies reports whether p has at least one address and is not the peer
// excluded by the filter.
func (f NoAddrIDFilter) satisfies(p peer.AddrInfo) bool {
	if len(p.Addrs) == 0 {
		return false
	}
	return p.ID != f.ID
}
// PeerPassFilter returns the peers that satisfy pf, in their original order.
// The result is nil when no peer passes.
func PeerPassFilter(peers []peer.AddrInfo, pf PeerFilter) []peer.AddrInfo {
	var passed []peer.AddrInfo
	for i := range peers {
		if !pf.satisfies(peers[i]) {
			continue
		}
		passed = append(passed, peers[i])
	}
	return passed
}
// filtersConnectionGater is a connection gater backed by a set of multiaddr
// filters.
type filtersConnectionGater multiaddr.Filters
// Compile-time check that the gater implements connmgr.ConnectionGater.
var _ connmgr.ConnectionGater = (*filtersConnectionGater)(nil)
// InterceptAddrDial allows dialing addr unless the filters block it.
func (f *filtersConnectionGater) InterceptAddrDial(_ peer.ID, addr multiaddr.Multiaddr) (allow bool) {
	filters := (*multiaddr.Filters)(f)
	blocked := filters.AddrBlocked(addr)
	return !blocked
}
// InterceptPeerDial allows dialing every peer; filtering is done per address.
func (f *filtersConnectionGater) InterceptPeerDial(_ peer.ID) (allow bool) {
return true
}
// InterceptAccept allows an inbound connection unless its remote address is
// blocked by the filters.
func (f *filtersConnectionGater) InterceptAccept(connAddr network.ConnMultiaddrs) (allow bool) {
	filters := (*multiaddr.Filters)(f)
	blocked := filters.AddrBlocked(connAddr.RemoteMultiaddr())
	return !blocked
}
// InterceptSecured allows a secured connection unless its remote address is
// blocked by the filters. Direction and peer ID are not considered.
func (f *filtersConnectionGater) InterceptSecured(_ network.Direction, _ peer.ID, connAddr network.ConnMultiaddrs) (allow bool) {
	filters := (*multiaddr.Filters)(f)
	blocked := filters.AddrBlocked(connAddr.RemoteMultiaddr())
	return !blocked
}
// InterceptUpgraded allows every upgraded connection and reports a zero
// disconnect reason.
func (f *filtersConnectionGater) InterceptUpgraded(_ network.Conn) (allow bool, reason control.DisconnectReason) {
return true, 0
}
// makeAddrsFactory builds the function that decides which addresses a host
// announces.
//
//   - announce: when non-empty, these addresses replace the host's own list.
//   - appendAnnouce: addresses always added to the list; entries already
//     present in announce are skipped.
//   - noAnnounce: addresses never announced. Each entry is either a CIDR mask
//     (turned into a deny filter) or an exact multiaddr.
//
// It returns nil when any of the given strings cannot be parsed.
//
// The returned function never modifies the slice it is given: the candidate
// list is assembled in a fresh slice, so appending the extra addresses cannot
// write into spare capacity of the caller's backing array.
func makeAddrsFactory(announce []string, appendAnnouce []string, noAnnounce []string) func([]multiaddr.Multiaddr) []multiaddr.Multiaddr {
	var err error                     // To assign to the slice in the for loop
	existing := make(map[string]bool) // To avoid duplicates

	annAddrs := make([]multiaddr.Multiaddr, len(announce))
	for i, addr := range announce {
		annAddrs[i], err = multiaddr.NewMultiaddr(addr)
		if err != nil {
			return nil
		}
		existing[addr] = true
	}

	appendAnnAddrs := make([]multiaddr.Multiaddr, 0)
	for _, addr := range appendAnnouce {
		if existing[addr] {
			// skip AppendAnnounce that is on the Announce list already
			continue
		}
		appendAddr, err := multiaddr.NewMultiaddr(addr)
		if err != nil {
			return nil
		}
		appendAnnAddrs = append(appendAnnAddrs, appendAddr)
	}

	filters := multiaddr.NewFilters()
	noAnnAddrs := map[string]bool{}
	for _, addr := range noAnnounce {
		// Entries that parse as a mask become CIDR deny filters; everything
		// else must be a plain multiaddr matched exactly.
		f, err := mafilt.NewMask(addr)
		if err == nil {
			filters.AddFilter(*f, multiaddr.ActionDeny)
			continue
		}
		maddr, err := multiaddr.NewMultiaddr(addr)
		if err != nil {
			return nil
		}
		noAnnAddrs[string(maddr.Bytes())] = true
	}

	return func(allAddrs []multiaddr.Multiaddr) []multiaddr.Multiaddr {
		base := allAddrs
		if len(annAddrs) > 0 {
			base = annAddrs
		}

		// Copy before appending so neither allAddrs nor annAddrs is aliased.
		addrs := make([]multiaddr.Multiaddr, 0, len(base)+len(appendAnnAddrs))
		addrs = append(addrs, base...)
		addrs = append(addrs, appendAnnAddrs...)

		var out []multiaddr.Multiaddr
		for _, maddr := range addrs {
			// check for exact matches
			ok := noAnnAddrs[string(maddr.Bytes())]
			// check for /ipcidr matches
			if !ok && !filters.AddrBlocked(maddr) {
				out = append(out, maddr)
			}
		}
		return out
	}
}
package libp2p
import (
"errors"
"sync"
"github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/protocol"
"gitlab.com/nunet/device-management-service/types"
)
// StreamHandler is a function type that processes a network stream. It has
// the same shape as network.StreamHandler and is converted to it on
// registration.
type StreamHandler func(stream network.Stream)
// HandlerRegistry manages the registration of stream handlers for different protocols.
type HandlerRegistry struct {
// host is the libp2p host the stream handlers are installed on.
host host.Host
// handlers holds the stream callbacks registered per protocol.
handlers map[protocol.ID]StreamHandler
// bytesHandlers holds the byte callbacks registered per protocol.
bytesHandlers map[protocol.ID]func(data []byte)
// mu guards handlers and bytesHandlers.
mu sync.RWMutex
}
// NewHandlerRegistry creates a handler registry bound to host, with both
// handler maps initialised and empty.
func NewHandlerRegistry(host host.Host) *HandlerRegistry {
	registry := &HandlerRegistry{host: host}
	registry.handlers = make(map[protocol.ID]StreamHandler)
	registry.bytesHandlers = make(map[protocol.ID]func(data []byte))
	return registry
}
// RegisterHandlerWithStreamCallback registers handler as the stream handler
// for the protocol named by messageType and installs it on the host.
//
// It returns an error when a stream callback is already registered for that
// protocol; byte callbacks are not consulted.
func (r *HandlerRegistry) RegisterHandlerWithStreamCallback(messageType types.MessageType, handler StreamHandler) error {
	protoID := protocol.ID(messageType)

	r.mu.Lock()
	defer r.mu.Unlock()

	if _, registered := r.handlers[protoID]; registered {
		return errors.New("stream with this protocol is already registered")
	}

	r.handlers[protoID] = handler
	r.host.SetStreamHandler(protoID, network.StreamHandler(handler))
	return nil
}
// RegisterHandlerWithBytesCallback installs s as the stream handler for the
// protocol named by messageType and records handler as the byte callback
// that SendMessageToLocalHandler dispatches to.
//
// It returns an error when a byte callback is already registered for that
// protocol; stream callbacks are not consulted.
func (r *HandlerRegistry) RegisterHandlerWithBytesCallback(messageType types.MessageType, s StreamHandler, handler func(data []byte)) error {
	protoID := protocol.ID(messageType)

	r.mu.Lock()
	defer r.mu.Unlock()

	if _, registered := r.bytesHandlers[protoID]; registered {
		return errors.New("stream with this protocol is already registered")
	}

	r.bytesHandlers[protoID] = handler
	r.host.SetStreamHandler(protoID, network.StreamHandler(s))
	return nil
}
// SendMessageToLocalHandler passes data to the byte callback registered for
// messageType, if there is one. The callback runs in its own goroutine so
// the caller is never blocked by it. Unknown message types are ignored.
//
// The lookup only reads the registry, so it takes the read lock; concurrent
// senders no longer serialise behind the exclusive lock used by the
// registration methods.
func (r *HandlerRegistry) SendMessageToLocalHandler(messageType types.MessageType, data []byte) {
	protoID := protocol.ID(messageType)

	r.mu.RLock()
	h, ok := r.bytesHandlers[protoID]
	r.mu.RUnlock()
	if !ok {
		return
	}

	// we need this goroutine to avoid blocking the caller goroutine
	go h(data)
}
package libp2p
import (
"context"
"strings"
"time"
"github.com/libp2p/go-libp2p"
dht "github.com/libp2p/go-libp2p-kad-dht"
pubsub "github.com/libp2p/go-libp2p-pubsub"
"github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/protocol"
"github.com/libp2p/go-libp2p/core/routing"
"github.com/libp2p/go-libp2p/p2p/host/autorelay"
"github.com/libp2p/go-libp2p/p2p/host/peerstore/pstoremem"
"github.com/libp2p/go-libp2p/p2p/net/connmgr"
"github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay"
"github.com/libp2p/go-libp2p/p2p/security/noise"
libp2ptls "github.com/libp2p/go-libp2p/p2p/security/tls"
quic "github.com/libp2p/go-libp2p/p2p/transport/quic"
"github.com/libp2p/go-libp2p/p2p/transport/tcp"
ws "github.com/libp2p/go-libp2p/p2p/transport/websocket"
webtransport "github.com/libp2p/go-libp2p/p2p/transport/webtransport"
"github.com/multiformats/go-multiaddr"
mafilt "github.com/whyrusleeping/multiaddr-filter"
"gitlab.com/nunet/device-management-service/types"
)
// NewHost returns a new libp2p host with dht and other related settings.
//
// It builds a connection manager, an in-memory peerstore and a DHT in server
// mode, enables TCP, QUIC, WebTransport and WebSocket transports, TLS and
// Noise security, NAT service, relay, hole punching and auto-relay, and
// finally creates a GossipSub instance on top of the host. In server mode
// the default server filters are applied both to the announced addresses and
// as a connection gater; otherwise NAT port mapping is enabled.
//
// It returns the host, the DHT created during host construction and the
// pubsub instance, or the first error encountered.
func NewHost(ctx context.Context, config *types.Libp2pConfig) (host.Host, *dht.IpfsDHT, *pubsub.PubSub, error) {
	var idht *dht.IpfsDHT

	connManager, err := connmgr.NewConnManager(
		100,
		400,
		connmgr.WithGracePeriod(time.Duration(config.GracePeriodMs)*time.Millisecond),
	)
	if err != nil {
		return nil, nil, nil, err
	}

	filter := multiaddr.NewFilters()
	for _, s := range defaultServerFilters {
		f, err := mafilt.NewMask(s)
		if err != nil {
			zlog.Sugar().Errorf("incorrectly formatted address filter in config: %s - %v", s, err)
			// f is nil when parsing fails; skip the entry instead of
			// dereferencing it.
			continue
		}
		filter.AddFilter(*f, multiaddr.ActionDeny)
	}

	ps, err := pstoremem.NewPeerstore()
	if err != nil {
		return nil, nil, nil, err
	}

	var libp2pOpts []libp2p.Option
	// NOTE(review): the validator is built without customNamespace, so its
	// key-prefix check matches every key — confirm this is intended.
	baseOpts := []dht.Option{
		dht.ProtocolPrefix(protocol.ID(config.DHTPrefix)),
		dht.NamespacedValidator(strings.ReplaceAll(config.CustomNamespace, "/", ""), dhtValidator{PS: ps}),
		dht.Mode(dht.ModeServer),
	}

	libp2pOpts = append(libp2pOpts, libp2p.ListenAddrStrings(config.ListenAddress...),
		libp2p.Identity(config.PrivateKey),
		libp2p.Routing(func(h host.Host) (routing.PeerRouting, error) {
			// The DHT is created here and handed back to the caller through idht.
			idht, err = dht.New(ctx, h, baseOpts...)
			return idht, err
		}),
		libp2p.Peerstore(ps),
		libp2p.Security(libp2ptls.ID, libp2ptls.New),
		libp2p.Security(noise.ID, noise.New),
		// libp2p.NoListenAddrs,
		libp2p.ChainOptions(
			libp2p.Transport(tcp.NewTCPTransport),
			libp2p.Transport(quic.NewTransport),
			libp2p.Transport(webtransport.New),
			libp2p.Transport(ws.New),
		),
		libp2p.EnableNATService(),
		libp2p.ConnectionManager(connManager),
		libp2p.EnableRelay(),
		libp2p.EnableHolePunching(),
		libp2p.EnableRelayService(
			relay.WithResources(
				relay.Resources{
					MaxReservations:        256,
					MaxCircuits:            32,
					BufferSize:             4096,
					MaxReservationsPerPeer: 8,
					MaxReservationsPerIP:   16,
				},
			),
			relay.WithLimit(&relay.RelayLimit{
				Duration: 5 * time.Minute,
				Data:     1 << 21, // 2 MiB
			}),
		),
		libp2p.EnableAutoRelayWithPeerSource(
			// Forward up to num peers from the package-level newPeer channel,
			// stopping early when the context is cancelled.
			func(ctx context.Context, num int) <-chan peer.AddrInfo {
				r := make(chan peer.AddrInfo)
				go func() {
					defer close(r)
					for i := 0; i < num; i++ {
						select {
						case p := <-newPeer:
							select {
							case r <- p:
							case <-ctx.Done():
								return
							}
						case <-ctx.Done():
							return
						}
					}
				}()
				return r
			},
			autorelay.WithBootDelay(time.Minute),
			autorelay.WithBackoff(30*time.Second),
			autorelay.WithMinCandidates(2),
			autorelay.WithMaxCandidates(3),
			autorelay.WithNumRelays(2),
		),
	)

	if config.Server {
		libp2pOpts = append(libp2pOpts, libp2p.AddrsFactory(makeAddrsFactory([]string{}, []string{}, defaultServerFilters)))
		libp2pOpts = append(libp2pOpts, libp2p.ConnectionGater((*filtersConnectionGater)(filter)))
	} else {
		libp2pOpts = append(libp2pOpts, libp2p.NATPortMap())
	}

	p2pHost, err := libp2p.New(libp2pOpts...)
	if err != nil {
		return nil, nil, nil, err
	}

	optsPS := []pubsub.Option{
		pubsub.WithFloodPublish(true),
		pubsub.WithMessageSigning(true),
		pubsub.WithMaxMessageSize(config.GossipMaxMessageSize),
	}
	gossip, err := pubsub.NewGossipSub(ctx, p2pHost, optsPS...)
	// gossip, err := pubsub.NewGossipSubWithRouter(ctx, host, pubsub.DefaultGossipSubRouter(host), optsPS...)
	if err != nil {
		return nil, nil, nil, err
	}

	return p2pHost, idht, gossip, nil
}
package libp2p
import (
"github.com/libp2p/go-libp2p/core/peer"
"github.com/uptrace/opentelemetry-go-extra/otelzap"
"gitlab.com/nunet/device-management-service/telemetry/logger"
)
// TODO: pass the logger to the constructor and remove from here
var (
// zlog is the package logger, assigned in init.
zlog *otelzap.Logger
// newPeer is an unbuffered channel of peers; NewHost reads from it to feed
// the auto-relay peer source.
newPeer = make(chan peer.AddrInfo)
)
// init creates the package logger under the "network.libp2p" name.
func init() {
zlog = logger.OtelZapLogger("network.libp2p")
}
package libp2p
import (
"bufio"
"bytes"
"context"
"crypto/sha256"
"encoding/binary"
"errors"
"fmt"
"io"
"strings"
"sync"
"time"
"github.com/ipfs/go-cid"
kbucket "github.com/libp2p/go-libp2p-kbucket"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/p2p/protocol/ping"
"gitlab.com/nunet/device-management-service/types"
pubsub "github.com/libp2p/go-libp2p-pubsub"
"github.com/multiformats/go-multiaddr"
"github.com/multiformats/go-multihash"
"github.com/spf13/afero"
"google.golang.org/protobuf/proto"
dht "github.com/libp2p/go-libp2p-kad-dht"
libp2pdiscovery "github.com/libp2p/go-libp2p/core/discovery"
"github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peerstore"
"github.com/libp2p/go-libp2p/core/protocol"
drouting "github.com/libp2p/go-libp2p/p2p/discovery/routing"
bt "gitlab.com/nunet/device-management-service/internal/background_tasks"
commonproto "gitlab.com/nunet/device-management-service/proto/generated/v1/common"
)
const (
// MB is the number of bytes in one mebibyte.
MB = 1024 * 1024
// MaxMessageLengthMB is the limit, in MB, that sendMessage applies to an
// outgoing framed message (length prefix plus payload).
MaxMessageLengthMB = 10
// ValidationAccept, ValidationReject and ValidationIgnore alias the
// corresponding pubsub validation results.
ValidationAccept = pubsub.ValidationAccept
ValidationReject = pubsub.ValidationReject
ValidationIgnore = pubsub.ValidationIgnore
)
type (
// ValidationResult is an alias of pubsub.ValidationResult.
ValidationResult = pubsub.ValidationResult
// Validator inspects a message payload together with the validator data
// produced so far and returns a verdict plus the validator data to carry on.
Validator func([]byte, interface{}) (ValidationResult, interface{})
)
// Libp2p contains the configuration for a Libp2p instance.
//
// A value is created with New, wired to a host with Init and put on the
// network with Start.
//
// TODO-suggestion: maybe we should call it something else like Libp2pPeer,
// Libp2pHost or just Peer (callers would use libp2p.Peer...)
type Libp2p struct {
// Host, DHT and PS are populated by Init.
Host host.Host
DHT *dht.IpfsDHT
PS peerstore.Peerstore
// pubsub is the pubsub instance returned by NewHost.
pubsub *pubsub.PubSub
// pubsubTopics holds the joined topic handlers keyed by topic name.
pubsubTopics map[string]*pubsub.Topic
// topicValidators and topicSubscription are keyed by topic name and then
// by the subscription ID handed out by Subscribe.
topicValidators map[string]map[uint64]Validator
topicSubscription map[string]map[uint64]*pubsub.Subscription
// nextTopicSubID is the ID the next call to Subscribe will return.
nextTopicSubID uint64
// topicMux is locked by the pubsub bookkeeping methods
// (Subscribe, Unsubscribe, validate, getOrJoinTopicHandler).
topicMux sync.RWMutex
// a list of peers discovered by discovery
discoveredPeers []peer.AddrInfo
discovery libp2pdiscovery.Discovery
// services
pingService *ping.PingService
// tasks
// discoveryTask is the periodic peer discovery task registered by Start.
discoveryTask *bt.Task
handlerRegistry *HandlerRegistry
config *types.Libp2pConfig
// dependencies (db, filesystem...)
fs afero.Fs
}
// New builds a Libp2p value from the given configuration and filesystem.
//
// It returns an error when config or config.Scheduler is nil. The returned
// instance has its bookkeeping maps allocated; Init must be called before the
// host can be used.
//
// TODO-Suggestion: move types.Libp2pConfig to here for better readability.
// Unless there is a reason to keep within types.
func New(config *types.Libp2pConfig, fs afero.Fs) (*Libp2p, error) {
	switch {
	case config == nil:
		return nil, errors.New("config is nil")
	case config.Scheduler == nil:
		return nil, errors.New("scheduler is nil")
	}

	instance := &Libp2p{config: config, fs: fs}
	instance.discoveredPeers = make([]peer.AddrInfo, 0)
	instance.pubsubTopics = make(map[string]*pubsub.Topic)
	instance.topicSubscription = make(map[string]map[uint64]*pubsub.Subscription)
	instance.topicValidators = make(map[string]map[uint64]Validator)
	return instance, nil
}
// Init creates the underlying host through NewHost and stores the host, DHT,
// peerstore, routing discovery, pubsub instance and handler registry on l.
// The error from NewHost is logged and returned unchanged.
func (l *Libp2p) Init(context context.Context) error {
	h, kad, ps, err := NewHost(context, l.config)
	if err != nil {
		zlog.Sugar().Error(err)
		return err
	}

	l.Host, l.DHT, l.pubsub = h, kad, ps
	l.PS = h.Peerstore()
	l.discovery = drouting.NewRoutingDiscovery(kad)
	l.handlerRegistry = NewHandlerRegistry(h)
	return nil
}
// Start registers the stream handlers, bootstraps against the configured
// peers, advertises for rendezvous discovery, runs one discovery round and
// finally schedules a periodic discovery task on the configured scheduler.
//
// Only a bootstrap failure aborts Start; advertise and discovery failures are
// logged and otherwise ignored.
func (l *Libp2p) Start(context context.Context) error {
	l.registerStreamHandlers()

	// A failed bootstrap is fatal for Start.
	if err := l.Bootstrap(context, l.config.BootstrapPeers); err != nil {
		zlog.Sugar().Errorf("failed to start network: %v", err)
		return err
	}

	// TODO: the error might be misleading as a peer can normally work well if an error
	// is returned here (e.g.: the error is yielded in tests even though all tests pass).
	if err := l.advertiseForRendezvousDiscovery(context); err != nil {
		zlog.Sugar().Errorf("failed to start network with randevouz discovery: %v", err)
	}

	// First discovery round, done synchronously.
	if err := l.DiscoverDialPeers(context); err != nil {
		zlog.Sugar().Errorf("failed to discover peers: %v", err)
	}

	// Subsequent rounds are driven by the scheduler; they reuse the context
	// that was passed to Start.
	rediscover := func(_ interface{}) error {
		return l.DiscoverDialPeers(context)
	}
	l.discoveryTask = l.config.Scheduler.AddTask(&bt.Task{
		Name:        "Peer Discovery",
		Description: "Periodic task to discover new peers every 15 minutes",
		Function:    rediscover,
		Triggers:    []bt.Trigger{&bt.PeriodicTrigger{Interval: 15 * time.Minute}},
	})
	l.config.Scheduler.Start()

	return nil
}
// RegisterStreamMessageHandler registers handler in the handler registry as
// the stream callback for messageType. An empty messageType is rejected.
func (l *Libp2p) RegisterStreamMessageHandler(messageType types.MessageType, handler StreamHandler) error {
	if messageType == "" {
		return errors.New("message type is empty")
	}

	err := l.handlerRegistry.RegisterHandlerWithStreamCallback(messageType, handler)
	if err == nil {
		return nil
	}
	return fmt.Errorf("failed to register handler %s: %w", messageType, err)
}
// RegisterBytesMessageHandler registers handler in the handler registry as
// the bytes callback for messageType, with handleReadBytesFromStream as the
// stream reader that feeds it. An empty messageType is rejected.
func (l *Libp2p) RegisterBytesMessageHandler(messageType types.MessageType, handler func(data []byte)) error {
	if messageType == "" {
		return errors.New("message type is empty")
	}

	err := l.handlerRegistry.RegisterHandlerWithBytesCallback(messageType, l.handleReadBytesFromStream, handler)
	if err == nil {
		return nil
	}
	return fmt.Errorf("failed to register handler %s: %w", messageType, err)
}
// HandleMessage converts messageType to a types.MessageType and delegates to
// RegisterBytesMessageHandler, so handler receives the raw bytes of each
// message of that type.
func (l *Libp2p) HandleMessage(messageType string, handler func(data []byte)) error {
return l.RegisterBytesMessageHandler(types.MessageType(messageType), handler)
}
// handleReadBytesFromStream reads one length-prefixed message from s and
// passes its payload to the bytes handler registered for the stream protocol.
//
// Wire format (mirrors sendMessage): an 8-byte little-endian payload length
// followed by the payload. The stream is closed on every exit path. Streams
// for which no bytes handler is registered are closed without being read.
//
// Read errors are dropped silently, as the function has no way to report them
// to the remote side.
func (l *Libp2p) handleReadBytesFromStream(s network.Stream) {
	defer s.Close()

	callback, ok := l.handlerRegistry.bytesHandlers[s.Protocol()]
	if !ok {
		return
	}

	reader := bufio.NewReader(s)

	// The header must be read in full: a plain Read may legitimately return
	// fewer than 8 bytes, which would yield a garbage length.
	var header [8]byte
	if _, err := io.ReadFull(reader, header[:]); err != nil {
		return
	}

	// The length is supplied by the remote peer, so it must be bounded before
	// it is used as an allocation size. sendMessage never produces frames
	// larger than MaxMessageLengthMB*MB.
	payloadLen := binary.LittleEndian.Uint64(header[:])
	if payloadLen > MaxMessageLengthMB*MB {
		zlog.Sugar().Warnf("dropping message on protocol %s: announced size %d exceeds limit of %d bytes",
			s.Protocol(), payloadLen, MaxMessageLengthMB*MB)
		return
	}

	payload := make([]byte, payloadLen)
	if _, err := io.ReadFull(reader, payload); err != nil {
		return
	}

	callback(payload)
}
// SendMessage delivers msg to every address in addrs concurrently and waits
// for all deliveries to finish.
//
// It returns nil when every delivery succeeded. Otherwise the individual
// errors are folded into a single error, joined with "; " in the order in
// which the deliveries reported them.
func (l *Libp2p) SendMessage(ctx context.Context, addrs []string, msg types.MessageEnvelope) error {
	// Buffered for the worst case so no sender ever blocks.
	failures := make(chan error, len(addrs))

	var pending sync.WaitGroup
	for i := range addrs {
		pending.Add(1)
		go func(target string) {
			defer pending.Done()
			if sendErr := l.sendMessage(ctx, target, msg); sendErr != nil {
				failures <- sendErr
			}
		}(addrs[i])
	}
	pending.Wait()
	close(failures)

	var combined error
	for failure := range failures {
		if combined != nil {
			combined = fmt.Errorf("%v; %v", combined, failure)
			continue
		}
		combined = failure
	}
	return combined
}
// OpenStream parses addr as a p2p multiaddress, connects to that peer and
// opens a stream for the protocol named by messageType. The caller owns the
// returned stream.
func (l *Libp2p) OpenStream(ctx context.Context, addr string, messageType types.MessageType) (network.Stream, error) {
	parsed, err := multiaddr.NewMultiaddr(addr)
	if err != nil {
		return nil, fmt.Errorf("invalid multiaddress: %w", err)
	}

	target, err := peer.AddrInfoFromP2pAddr(parsed)
	if err != nil {
		return nil, fmt.Errorf("could not resolve peer info: %w", err)
	}

	if err = l.Host.Connect(ctx, *target); err != nil {
		return nil, fmt.Errorf("failed to connect to peer: %w", err)
	}

	opened, err := l.Host.NewStream(ctx, target.ID, protocol.ID(messageType))
	if err == nil {
		return opened, nil
	}
	return nil, fmt.Errorf("failed to open stream: %w", err)
}
// GetMultiaddr returns the host's addresses converted to p2p multiaddresses
// via peer.AddrInfoToP2pAddrs, using the host's own ID.
func (l *Libp2p) GetMultiaddr() ([]multiaddr.Multiaddr, error) {
	return peer.AddrInfoToP2pAddrs(&peer.AddrInfo{
		ID:    l.Host.ID(),
		Addrs: l.Host.Addrs(),
	})
}
// Stop removes the periodic discovery task from the scheduler and closes the
// DHT and the host.
//
// Stop is safe to call on an instance whose Init or Start did not complete:
// components that were never created are skipped instead of being
// dereferenced. Errors from closing the DHT and the host are collected and
// returned as a single error joined with "; ".
func (l *Libp2p) Stop() error {
	var errorMessages []string

	// discoveryTask is only set at the end of Start; Start may have failed
	// earlier (e.g. during bootstrap) or never been called.
	if l.discoveryTask != nil {
		l.config.Scheduler.RemoveTask(l.discoveryTask.ID)
	}

	// DHT and Host are only set by Init.
	if l.DHT != nil {
		if err := l.DHT.Close(); err != nil {
			errorMessages = append(errorMessages, err.Error())
		}
	}
	if l.Host != nil {
		if err := l.Host.Close(); err != nil {
			errorMessages = append(errorMessages, err.Error())
		}
	}

	if len(errorMessages) > 0 {
		return errors.New(strings.Join(errorMessages, "; "))
	}
	return nil
}
// Stat reports the host's peer ID and its addresses, the latter rendered as
// a single ", "-separated string.
func (l *Libp2p) Stat() types.NetworkStats {
	addrs := l.Host.Addrs()
	rendered := make([]string, len(addrs))
	for i := range addrs {
		rendered[i] = addrs[i].String()
	}

	var stats types.NetworkStats
	stats.ID = l.Host.ID().String()
	stats.ListenAddr = strings.Join(rendered, ", ")
	return stats
}
// Ping pings the peer whose encoded peer ID is peerIDAddress and waits at
// most timeout for the first result.
//
// On failure the error is reported both in PingResult.Error and as the
// returned error. Failures are: pinging the local host, an undecodable peer
// ID, a ping error, and expiry of the timeout or of ctx.
//
// TODO (Return error once): something that was confusing me when using this method is that the error is
// returned twice if any. Once as a field of PingResult and one as a return value.
func (l *Libp2p) Ping(ctx context.Context, peerIDAddress string, timeout time.Duration) (types.PingResult, error) {
	// avoid dial to self attempt
	if peerIDAddress == l.Host.ID().String() {
		err := errors.New("can't ping self")
		return types.PingResult{Success: false, Error: err}, err
	}

	remotePeer, err := peer.Decode(peerIDAddress)
	if err != nil {
		// Report the error in the result too, like every other failure path.
		return types.PingResult{Success: false, Error: err}, err
	}

	pingCtx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()

	pingChan := ping.Ping(pingCtx, l.Host, remotePeer)
	select {
	case res, ok := <-pingChan:
		if !ok {
			// The ping channel is closed once its context is done. Without
			// this check a timeout could be read as a zero-valued Result and
			// reported as a successful ping with a zero RTT.
			closedErr := pingCtx.Err()
			if closedErr == nil {
				closedErr = errors.New("ping channel closed before a result was received")
			}
			return types.PingResult{Error: closedErr}, closedErr
		}
		if res.Error != nil {
			zlog.Sugar().Errorf("failed to ping peer %s: %v", peerIDAddress, res.Error)
			return types.PingResult{
				Success: false,
				RTT:     res.RTT,
				Error:   res.Error,
			}, res.Error
		}
		return types.PingResult{
			RTT:     res.RTT,
			Success: true,
		}, nil
	case <-pingCtx.Done():
		return types.PingResult{
			Error: pingCtx.Err(),
		}, pingCtx.Err()
	}
}
// ResolveAddress returns the p2p multiaddresses, as strings, of the peer with
// the given encoded peer ID.
//
// The local host is resolved from its own addresses. Any other peer is looked
// up in the DHT with a 3 second timeout.
func (l *Libp2p) ResolveAddress(ctx context.Context, id string) ([]string, error) {
	pid, err := peer.Decode(id)
	if err != nil {
		return nil, fmt.Errorf("failed to resolve invalid peer: %w", err)
	}

	stringify := func(addrs []multiaddr.Multiaddr) []string {
		out := make([]string, len(addrs))
		for i := range addrs {
			out[i] = addrs[i].String()
		}
		return out
	}

	// resolve ourself
	if id == l.Host.ID().String() {
		own, ownErr := l.GetMultiaddr()
		if ownErr != nil {
			return nil, fmt.Errorf("failed to resolve self: %w", ownErr)
		}
		return stringify(own), nil
	}

	lookupCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
	defer cancel()

	found, err := l.DHT.FindPeer(lookupCtx, pid)
	if err != nil {
		return nil, fmt.Errorf("failed to resolve address %s: %w", id, err)
	}

	remote, err := peer.AddrInfoToP2pAddrs(&peer.AddrInfo{ID: found.ID, Addrs: found.Addrs})
	if err != nil {
		return nil, fmt.Errorf("failed to convert to p2p address: %w", err)
	}
	return stringify(remote), nil
}
// Query returns the advertisements published under key.
//
// The DHT is asked for the providers of the CID derived from key; for each
// provider the advertisement stored under the provider-specific namespace is
// fetched and decoded. Providers whose value cannot be fetched are skipped,
// whereas a payload that fails to decode aborts the whole query.
func (l *Libp2p) Query(ctx context.Context, key string) ([]*commonproto.Advertisement, error) {
	if key == "" {
		return nil, errors.New("advertisement key is empty")
	}

	keyCID, err := createCIDFromKey(key)
	if err != nil {
		return nil, fmt.Errorf("failed to create cid for key %s: %w", key, err)
	}

	providers, err := l.DHT.FindProviders(ctx, keyCID)
	if err != nil {
		return nil, fmt.Errorf("failed to find providers for key %s: %w", key, err)
	}

	found := make([]*commonproto.Advertisement, 0)
	// TODO: use go routines to get the values in parallel.
	for i := range providers {
		namespace := l.getCustomNamespace(key, providers[i].ID.String())
		raw, getErr := l.DHT.GetValue(ctx, namespace)
		if getErr != nil {
			continue
		}

		ad := new(commonproto.Advertisement)
		if decodeErr := proto.Unmarshal(raw, ad); decodeErr != nil {
			return nil, fmt.Errorf("failed to unmarshal advertisement payload: %w", decodeErr)
		}
		found = append(found, ad)
	}

	return found, nil
}
// Advertise publishes data under key in the DHT.
//
// The data is wrapped in a signed Advertisement envelope (peer ID, unix
// timestamp, data, public key), stored under this peer's namespace for key,
// and the peer is then announced as a provider of the CID derived from key.
func (l *Libp2p) Advertise(ctx context.Context, key string, data []byte) error {
	if key == "" {
		return errors.New("advertisement key is empty")
	}

	pubKeyBytes, err := l.getPublicKey()
	if err != nil {
		return fmt.Errorf("failed to get public key: %w", err)
	}

	envelope := &commonproto.Advertisement{
		PeerId:    l.Host.ID().String(),
		Timestamp: time.Now().Unix(),
		Data:      data,
		PublicKey: pubKeyBytes,
	}

	// NOTE(review): byte(envelope.Timestamp) keeps only the lowest 8 bits of
	// the timestamp in the signed payload. This looks unintended, but any
	// verifier must build the payload the same way, so it is left as is here.
	signedParts := [][]byte{
		[]byte(envelope.PeerId),
		{byte(envelope.Timestamp)},
		envelope.Data,
		pubKeyBytes,
	}
	signature, err := l.sign(bytes.Join(signedParts, nil))
	if err != nil {
		return fmt.Errorf("failed to sign advertisement envelope content: %w", err)
	}
	envelope.Signature = signature

	encoded, err := proto.Marshal(envelope)
	if err != nil {
		return fmt.Errorf("failed to marshal advertise envelope: %w", err)
	}

	keyCID, err := createCIDFromKey(key)
	if err != nil {
		return fmt.Errorf("failed to create cid for key %s: %w", key, err)
	}

	namespace := l.getCustomNamespace(key, l.DHT.PeerID().String())
	if err = l.DHT.PutValue(ctx, namespace, encoded); err != nil {
		return fmt.Errorf("failed to put key %s into the dht: %w", key, err)
	}

	if err = l.DHT.Provide(ctx, keyCID, true); err != nil {
		return fmt.Errorf("failed to provide key %s into the dht: %w", key, err)
	}

	return nil
}
// Unadvertise overwrites this peer's advertisement for key with a nil value
// in the DHT.
func (l *Libp2p) Unadvertise(ctx context.Context, key string) error {
	namespace := l.getCustomNamespace(key, l.DHT.PeerID().String())
	if err := l.DHT.PutValue(ctx, namespace, nil); err != nil {
		return fmt.Errorf("failed to remove key %s from the DHT: %w", key, err)
	}
	return nil
}
// Publish sends data to topic, joining the topic first if this instance has
// not joined it yet.
// The requirements are that only one topic handler should exist per topic.
func (l *Libp2p) Publish(ctx context.Context, topic string, data []byte) error {
	joined, err := l.getOrJoinTopicHandler(topic)
	if err != nil {
		return fmt.Errorf("failed to publish: %w", err)
	}

	if pubErr := joined.Publish(ctx, data); pubErr != nil {
		return fmt.Errorf("failed to publish to topic %s: %w", topic, pubErr)
	}
	return nil
}
// Subscribe subscribes to topic and calls handler with the payload of every
// message received, from a dedicated goroutine.
//
// When validator is non-nil it is stored for this subscription, and the first
// validator of a topic also registers l.validate as the topic validator.
//
// The returned ID identifies the subscription for Unsubscribe. The delivery
// goroutine ends when ctx is done or when the subscription is cancelled
// (e.g. through Unsubscribe).
func (l *Libp2p) Subscribe(ctx context.Context, topic string, handler func(data []byte), validator Validator) (uint64, error) {
	topicHandler, err := l.getOrJoinTopicHandler(topic)
	if err != nil {
		return 0, fmt.Errorf("failed to subscribe to topic: %w", err)
	}

	sub, err := topicHandler.Subscribe()
	if err != nil {
		return 0, fmt.Errorf("failed to subscribe to topic %s: %w", topic, err)
	}

	// The bookkeeping runs in its own function so the deferred Unlock covers
	// every exit path; previously the validator-registration error path
	// returned with topicMux still held, blocking all later pubsub calls.
	subID, err := func() (uint64, error) {
		l.topicMux.Lock()
		defer l.topicMux.Unlock()

		id := l.nextTopicSubID
		l.nextTopicSubID++

		topicMap, ok := l.topicSubscription[topic]
		if !ok {
			topicMap = make(map[uint64]*pubsub.Subscription)
			l.topicSubscription[topic] = topicMap
		}

		if validator != nil {
			validatorMap, ok := l.topicValidators[topic]
			if !ok {
				if regErr := l.pubsub.RegisterTopicValidator(topic, l.validate); regErr != nil {
					return 0, fmt.Errorf("failed to register topic validator: %w", regErr)
				}
				validatorMap = make(map[uint64]Validator)
				l.topicValidators[topic] = validatorMap
			}
			validatorMap[id] = validator
		}

		topicMap[id] = sub
		return id, nil
	}()
	if err != nil {
		sub.Cancel()
		return 0, err
	}

	go func() {
		for {
			msg, nextErr := sub.Next(ctx)
			if nextErr != nil {
				// Next fails once ctx is done or the subscription is
				// cancelled. Both are permanent, so retrying would spin
				// forever; stop delivering instead.
				return
			}
			handler(msg.Data)
		}
	}()

	return subID, nil
}
// validate is the topic validator registered with pubsub by Subscribe. It
// runs every validator stored for the message's topic and returns the first
// result that is not ValidationAccept; the validator data returned by each
// accepted validator is stored on the message and handed to the next one.
// A topic without validators is accepted.
//
// The validators are copied while holding the lock, because Subscribe and
// Unsubscribe mutate the underlying map and iterating it unlocked would be a
// concurrent map read and write. The validators themselves run without the
// lock held.
func (l *Libp2p) validate(_ context.Context, _ peer.ID, msg *pubsub.Message) ValidationResult {
	l.topicMux.RLock()
	registered := l.topicValidators[msg.GetTopic()]
	snapshot := make([]Validator, 0, len(registered))
	for _, validator := range registered {
		snapshot = append(snapshot, validator)
	}
	l.topicMux.RUnlock()

	for _, validator := range snapshot {
		result, validatorData := validator(msg.Data, msg.ValidatorData)
		if result != ValidationAccept {
			return result
		}
		msg.ValidatorData = validatorData
	}

	return ValidationAccept
}
// sendMessage delivers msg to the peer at addr, which must be a p2p
// multiaddress.
//
// A message addressed to the local host is handed to the locally registered
// handler without touching the network. For remote peers the payload is
// written to a new stream of protocol msg.Type as an 8-byte little-endian
// length followed by the data; frames larger than MaxMessageLengthMB*MB are
// rejected.
func (l *Libp2p) sendMessage(ctx context.Context, addr string, msg types.MessageEnvelope) error {
	peerAddr, err := multiaddr.NewMultiaddr(addr)
	if err != nil {
		return fmt.Errorf("invalid multiaddr %s: %w", addr, err)
	}

	peerInfo, err := peer.AddrInfoFromP2pAddr(peerAddr)
	if err != nil {
		return fmt.Errorf("failed to get peer info %s: %w", addr, err)
	}

	// we are delivering a message to ourself
	// we should use the handler to send the message to the handler directly which has been previously registered.
	if peerInfo.ID.String() == l.Host.ID().String() {
		l.handlerRegistry.SendMessageToLocalHandler(msg.Type, msg.Data)
		return nil
	}

	// Check the size before dialing so an oversized message does not cost a
	// connection and a stream that would be thrown away immediately.
	requestBufferSize := 8 + len(msg.Data)
	if requestBufferSize > MaxMessageLengthMB*MB {
		return fmt.Errorf("message size %d is greater than limit %d bytes", requestBufferSize, MaxMessageLengthMB*MB)
	}

	if err := l.Host.Connect(ctx, *peerInfo); err != nil {
		return fmt.Errorf("failed to connect to peer %v: %w", peerInfo.ID, err)
	}

	stream, err := l.Host.NewStream(ctx, peerInfo.ID, protocol.ID(msg.Type))
	if err != nil {
		return fmt.Errorf("failed to open stream to peer %v: %w", peerInfo.ID, err)
	}
	defer stream.Close()

	// Frame: [8-byte little-endian payload length][payload].
	requestPayloadWithLength := make([]byte, requestBufferSize)
	binary.LittleEndian.PutUint64(requestPayloadWithLength, uint64(len(msg.Data)))
	copy(requestPayloadWithLength[8:], msg.Data)

	if _, err = stream.Write(requestPayloadWithLength); err != nil {
		return fmt.Errorf("failed to send message to peer %v: %w", peerInfo.ID, err)
	}

	return nil
}
// getOrJoinTopicHandler returns the handler for topic, joining the topic and
// caching the handler on first use. Both publishing and subscribing need the
// handler, hence the shared helper. The lookup and the join happen under
// topicMux.
func (l *Libp2p) getOrJoinTopicHandler(topic string) (*pubsub.Topic, error) {
	l.topicMux.Lock()
	defer l.topicMux.Unlock()

	if cached, found := l.pubsubTopics[topic]; found {
		return cached, nil
	}

	joined, err := l.pubsub.Join(topic)
	if err != nil {
		return nil, fmt.Errorf("failed to join topic %s: %w", topic, err)
	}
	l.pubsubTopics[topic] = joined
	return joined, nil
}
// Unsubscribe cancels the subscription subID on topic and drops its
// validator. When no subscription remains for the topic, the topic handler is
// forgotten and closed. It fails when the topic was never joined.
func (l *Libp2p) Unsubscribe(topic string, subID uint64) error {
	l.topicMux.Lock()
	defer l.topicMux.Unlock()

	topicHandler, joined := l.pubsubTopics[topic]
	if !joined {
		return fmt.Errorf("not subscribed to topic: %s", topic)
	}

	// Deleting from a missing (nil) inner map is a no-op.
	delete(l.topicValidators[topic], subID)

	// delete subscription handler and subscription
	remaining := l.topicSubscription[topic]
	if sub, found := remaining[subID]; found {
		sub.Cancel()
		delete(remaining, subID)
	}

	if len(remaining) != 0 {
		return nil
	}

	delete(l.pubsubTopics, topic)
	if err := topicHandler.Close(); err != nil {
		return fmt.Errorf("failed to close topic handler: %w", err)
	}
	return nil
}
// VisiblePeers returns the discoveredPeers slice itself (not a copy).
func (l *Libp2p) VisiblePeers() []peer.AddrInfo {
return l.discoveredPeers
}
// KnownPeers returns one AddrInfo per peer ID in the host's peerstore. Only
// the ID field is populated. The returned error is always nil.
func (l *Libp2p) KnownPeers() ([]peer.AddrInfo, error) {
	ids := l.Host.Peerstore().Peers()
	known := make([]peer.AddrInfo, len(ids))
	for i, id := range ids {
		known[i].ID = id
	}
	return known, nil
}
// DumpDHTRoutingTable returns the peer infos held in the DHT routing table.
// The returned error is always nil.
func (l *Libp2p) DumpDHTRoutingTable() ([]kbucket.PeerInfo, error) {
	return l.DHT.RoutingTable().GetPeerInfos(), nil
}
// registerStreamHandlers creates the ping service for the host and installs
// its handler for the "/ipfs/ping/1.0.0" protocol.
// NOTE(review): ping.NewPingService may already install this handler itself,
// which would make the explicit SetStreamHandler redundant — confirm.
func (l *Libp2p) registerStreamHandlers() {
l.pingService = ping.NewPingService(l.Host)
l.Host.SetStreamHandler(protocol.ID("/ipfs/ping/1.0.0"), l.pingService.PingHandler)
}
// sign signs data with the host's private key taken from the peerstore.
// It fails when the peerstore holds no private key for the host.
func (l *Libp2p) sign(data []byte) ([]byte, error) {
	key := l.Host.Peerstore().PrivKey(l.Host.ID())
	if key == nil {
		return nil, errors.New("private key not found for the host")
	}

	sig, err := key.Sign(data)
	if err == nil {
		return sig, nil
	}
	return nil, fmt.Errorf("failed to sign data: %w", err)
}
// getPublicKey returns the raw bytes of the public key derived from the
// host's private key in the peerstore. It fails when no private key is found.
func (l *Libp2p) getPublicKey() ([]byte, error) {
	key := l.Host.Peerstore().PrivKey(l.Host.ID())
	if key == nil {
		return nil, errors.New("private key not found for the host")
	}
	return key.GetPublic().Raw()
}
// getCustomNamespace builds the DHT key used for a peer's advertisement:
// "<config.CustomNamespace>-<key>-<peerID>".
func (l *Libp2p) getCustomNamespace(key, peerID string) string {
return fmt.Sprintf("%s-%s-%s", l.config.CustomNamespace, key, peerID)
}
// createCIDFromKey derives a CIDv1 (raw codec) from the SHA-256 digest of
// key, encoded as a SHA2-256 multihash.
func createCIDFromKey(key string) (cid.Cid, error) {
	digest := sha256.Sum256([]byte(key))

	encoded, err := multihash.Encode(digest[:], multihash.SHA2_256)
	if err == nil {
		return cid.NewCidV1(cid.Raw, encoded), nil
	}
	return cid.Cid{}, err
}
// CleanupPeer is a stub: it logs a warning and returns nil.
func CleanupPeer(_ peer.ID) error {
zlog.Warn("CleanupPeer: Stub")
return nil
}
// PingPeer is a stub: it logs a warning and returns (false, nil).
func PingPeer(_ context.Context, _ peer.ID) (bool, *ping.Result) {
zlog.Warn("PingPeer: Stub")
return false, nil
}
// DumpKademliaDHT is a stub: it logs a warning and returns (nil, nil).
func DumpKademliaDHT(_ context.Context) ([]types.PeerData, error) {
zlog.Warn("DumpKademliaDHT: Stub")
return nil, nil
}
// OldPingPeer is a stub: it logs a warning and returns (false, nil).
func OldPingPeer(_ context.Context, _ peer.ID) (bool, *types.PingResult) {
zlog.Warn("OldPingPeer: Stub")
return false, nil
}
package network
import (
"context"
"errors"
"fmt"
"time"
"github.com/spf13/afero"
commonproto "gitlab.com/nunet/device-management-service/proto/generated/v1/common"
"gitlab.com/nunet/device-management-service/network/libp2p"
"gitlab.com/nunet/device-management-service/types"
)
// Aliases of the libp2p validation types, re-exported from this package.
type (
Validator = libp2p.Validator
ValidationResult = libp2p.ValidationResult
)
// Aliases of the libp2p validation result constants.
const (
ValidationAccept = libp2p.ValidationAccept
ValidationReject = libp2p.ValidationReject
ValidationIgnore = libp2p.ValidationIgnore
)
// Messenger defines the interface for sending messages.
type Messenger interface {
// SendMessage sends msg to each of the given addresses.
SendMessage(ctx context.Context, addrs []string, msg types.MessageEnvelope) error
}
// Network is the interface a network backend exposes: lifecycle, direct
// messaging, advertisements and publish/subscribe.
type Network interface {
// Messenger embedded interface
Messenger
// Init initializes the network
Init(context.Context) error
// Start starts the network
Start(context context.Context) error
// Stat returns the network information
Stat() types.NetworkStats
// Ping pings the given address and returns the PingResult
Ping(ctx context.Context, address string, timeout time.Duration) (types.PingResult, error)
// HandleMessage is responsible for registering a message type and its handler.
HandleMessage(messageType string, handler func(data []byte)) error
// ResolveAddress returns the addresses of the peer with the given id.
// In libp2p, id is the peer ID and the result is the peer's multiaddresses as strings.
ResolveAddress(ctx context.Context, id string) ([]string, error)
// Advertise advertises the given data under the given key,
// such as advertising device capabilities on the DHT
Advertise(ctx context.Context, key string, data []byte) error
// Unadvertise stops advertising the data corresponding to the given key
Unadvertise(ctx context.Context, key string) error
// Query returns the advertisements published under the given key
Query(ctx context.Context, key string) ([]*commonproto.Advertisement, error)
// Publish publishes the given data to the given topic if the network
// type allows publish/subscribe functionality such as gossipsub or nats
Publish(ctx context.Context, topic string, data []byte) error
// Subscribe subscribes to the given topic and calls the handler function
// if the network type allows it similar to Publish(). The returned ID
// identifies the subscription for Unsubscribe.
Subscribe(ctx context.Context, topic string, handler func(data []byte), validator libp2p.Validator) (uint64, error)
// Unsubscribe cancels the subscription subID on the given topic
Unsubscribe(topic string, subID uint64) error
// Stop stops the network including any existing advertisements and subscriptions
Stop() error
}
// NewNetwork returns a new network for the type selected in netConfig.
//
// Only types.Libp2pNetwork is implemented. types.NATSNetwork returns a
// "not implemented" error and any other type an "unsupported network type"
// error. A nil netConfig is rejected. On error the returned Network is a
// true nil interface value.
func NewNetwork(netConfig *types.NetworkConfig, fs afero.Fs) (Network, error) {
	// TODO: probable additional params to receive: DB, FileSystem
	if netConfig == nil {
		return nil, errors.New("network configuration is nil")
	}

	switch netConfig.Type {
	case types.Libp2pNetwork:
		ln, err := libp2p.New(&netConfig.Libp2pConfig, fs)
		if err != nil {
			// Returning ln here would wrap a nil *libp2p.Libp2p in a non-nil
			// Network interface, defeating `if net != nil` checks in callers.
			return nil, err
		}
		return ln, nil
	case types.NATSNetwork:
		return nil, errors.New("not implemented")
	default:
		return nil, fmt.Errorf("unsupported network type: %s", netConfig.Type)
	}
}
package basiccontroller
import (
"context"
"fmt"
"os"
"github.com/spf13/afero"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/storage"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils"
)
// BasicVolumeController is the default implementation of the VolumeController.
// It persists storage volume information through the StorageVolume repository
// and creates the volume directories on the configured filesystem.
type BasicVolumeController struct {
// repo is the repository for storage volume operations
repo repositories.StorageVolume
// basePath is the base path where volumes are stored under; CreateVolume
// prepends it verbatim to the generated directory name
basePath string
// file system to act upon
FS afero.Fs
}
// NewDefaultVolumeController returns a new instance of BasicVolumeController
// that stores volume records in repo and volume directories under
// volBasePath on fs.
//
// A non-empty volBasePath that does not end in a path separator gets a "/"
// appended, because CreateVolume builds volume paths by plain concatenation
// (without this, "/data/volumes" would yield "/data/volumess3-<random>").
// An empty volBasePath is left untouched. The returned error is always nil.
func NewDefaultVolumeController(repo repositories.StorageVolume, volBasePath string, fs afero.Fs) (*BasicVolumeController, error) {
	ctx, cancel := st.SpanContext(context.Background(), "controller", "volume_controller_init_duration", "opentelemetry", "log")
	defer cancel()

	if n := len(volBasePath); n > 0 {
		if last := volBasePath[n-1]; last != '/' && last != os.PathSeparator {
			volBasePath += "/"
		}
	}

	vc := &BasicVolumeController{
		repo:     repo,
		basePath: volBasePath,
		FS:       fs,
	}

	st.Info(ctx, "volume_controller_init_success", nil)
	return vc, nil
}
// CreateVolume creates a new storage volume given a source (S3, IPFS, job, etc). The
// creation of a storage volume effectively creates an empty directory in the local filesystem
// and writes a record in the database.
//
// The directory name follows the format: `<volSource> + "-" + <name>`
// where `name` is a random 16-character string, placed under the
// controller's base path.
//
// The volume starts as non-private, writable and unencrypted; opts are
// applied on top of these defaults. If the database record cannot be written,
// the directory that was just created is removed again so no untracked
// directory is left behind.
//
// TODO-maybe [withName]: allow callers to specify custom name for path
func (vc *BasicVolumeController) CreateVolume(volSource storage.VolumeSource, opts ...storage.CreateVolOpt) (types.StorageVolume, error) {
	ctx, cancel := st.SpanContext(context.Background(), "controller", "volume_create_duration", "opentelemetry", "log")
	defer cancel()

	vol := types.StorageVolume{
		Private:        false,
		ReadOnly:       false,
		EncryptionType: types.EncryptionTypeNull,
	}

	for _, opt := range opts {
		opt(&vol)
	}

	randomStr, err := utils.RandomString(16)
	if err != nil {
		return types.StorageVolume{}, fmt.Errorf("failed to create random string: %w", err)
	}

	vol.Path = vc.basePath + string(volSource) + "-" + randomStr
	ctx = context.WithValue(ctx, pathKey, vol.Path)

	if err := vc.FS.Mkdir(vol.Path, os.ModePerm); err != nil {
		ctx = context.WithValue(ctx, errorKey, err.Error())
		st.Error(ctx, "volume_create_failure", nil)
		return types.StorageVolume{}, fmt.Errorf("failed to create storage volume: %w", err)
	}

	createdVol, err := vc.repo.Create(ctx, vol)
	if err != nil {
		ctx = context.WithValue(ctx, errorKey, err.Error())
		st.Error(ctx, "volume_create_failure", nil)

		// Roll back the directory: without a record it would be unreachable
		// through the controller. Remove (not RemoveAll) is enough because
		// the directory was created empty a moment ago.
		if rmErr := vc.FS.Remove(vol.Path); rmErr != nil {
			return types.StorageVolume{}, fmt.Errorf(
				"failed to create storage volume in repository: %w (removing directory %s also failed: %v)",
				err, vol.Path, rmErr)
		}
		return types.StorageVolume{}, fmt.Errorf("failed to create storage volume in repository: %w", err)
	}

	ctx = context.WithValue(ctx, volumeIDKey, createdVol.ID)
	st.Info(ctx, "volume_create_success", nil)
	return createdVol, nil
}
// LockVolume marks the volume stored at pathToVol as read-only: the record is
// updated first (after applying opts, which may set the CID or the private
// flag), then the permissions of the volume path are changed to 0o400.
// It should be used after all necessary data has been written.
//
// TODO-maybe [CID]: maybe calculate CID of every volume in case WithCID opt is not provided
func (vc *BasicVolumeController) LockVolume(pathToVol string, opts ...storage.LockVolOpt) error {
	ctx, cancel := st.SpanContext(context.Background(), "controller", "volume_lock_duration", "opentelemetry", "log")
	defer cancel()
	ctx = context.WithValue(ctx, pathKey, pathToVol)

	// reportFailure emits the failure event with the error text attached.
	reportFailure := func(cause error) {
		st.Error(context.WithValue(ctx, errorKey, cause.Error()), "volume_lock_failure", nil)
	}

	byPath := vc.repo.GetQuery()
	byPath.Conditions = append(byPath.Conditions, repositories.EQ("Path", pathToVol))

	vol, err := vc.repo.Find(ctx, byPath)
	if err != nil {
		reportFailure(err)
		return fmt.Errorf("failed to find storage volume with path %s - Error: %w", pathToVol, err)
	}

	for i := range opts {
		opts[i](&vol)
	}
	vol.ReadOnly = true

	updatedVol, err := vc.repo.Update(ctx, vol.ID, vol)
	if err != nil {
		reportFailure(err)
		return fmt.Errorf("failed to update storage volume with path %s - Error: %w", pathToVol, err)
	}

	// change file permissions
	if chmodErr := vc.FS.Chmod(updatedVol.Path, 0o400); chmodErr != nil {
		reportFailure(chmodErr)
		return fmt.Errorf("failed to make storage volume read-only (path: %s): %w", updatedVol.Path, chmodErr)
	}

	st.Info(ctx, "volume_lock_success", nil)
	return nil
}
// WithPrivate designates a given volume as private. It can be used both
// when creating or locking a volume: the type parameter selects whether the
// returned option is a storage.CreateVolOpt or a storage.LockVolOpt.
func WithPrivate[T storage.CreateVolOpt | storage.LockVolOpt]() T {
return func(v *types.StorageVolume) {
v.Private = true
}
}
// WithCID returns a lock option that sets the CID of a volume to the given,
// already calculated, value. The value is stored as is.
//
// TODO [validate]: check if CID provided is valid
func WithCID(cid string) storage.LockVolOpt {
return func(v *types.StorageVolume) {
v.CID = cid
}
}
// DeleteVolume deletes the database record of a storage volume. The volume is
// looked up by path or by CID, depending on idType; any other idType is
// rejected.
//
// Note [CID]: if we start to type CID as cid.CID, we may have to use generics here
// as in `[T string | cid.CID]`
func (vc *BasicVolumeController) DeleteVolume(identifier string, idType storage.IDType) error {
	ctx, cancel := st.SpanContext(context.Background(), "controller", "volume_delete_duration", "opentelemetry", "log")
	defer cancel()
	ctx = context.WithValue(ctx, identifierKey, identifier)
	ctx = context.WithValue(ctx, idTypeKey, idType)

	// reportFailure emits the failure event with the given reason attached.
	reportFailure := func(reason string) {
		st.Error(context.WithValue(ctx, errorKey, reason), "volume_delete_failure", nil)
	}

	lookup := vc.repo.GetQuery()
	switch idType {
	case storage.IDTypePath:
		lookup.Conditions = append(lookup.Conditions, repositories.EQ("Path", identifier))
	case storage.IDTypeCID:
		lookup.Conditions = append(lookup.Conditions, repositories.EQ("CID", identifier))
	default:
		reportFailure("identifier type not supported")
		return fmt.Errorf("identifier type not supported")
	}

	vol, err := vc.repo.Find(ctx, lookup)
	if err != nil {
		reportFailure(err.Error())
		// NOTE(review): a direct comparison misses a wrapped ErrNotFound;
		// errors.Is would be more robust once "errors" is imported.
		if err == repositories.ErrNotFound {
			return fmt.Errorf("volume not found: %w", err)
		}
		return fmt.Errorf("failed to find volume: %w", err)
	}

	if delErr := vc.repo.Delete(ctx, vol.ID); delErr != nil {
		reportFailure(delErr.Error())
		return fmt.Errorf("failed to delete volume: %w", delErr)
	}

	st.Info(ctx, "volume_delete_success", nil)
	return nil
}
// ListVolumes returns every storage volume record found in the repository.
//
// TODO [filter]: maybe add opts to filter results by certain values
func (vc *BasicVolumeController) ListVolumes() ([]types.StorageVolume, error) {
	ctx, cancel := st.SpanContext(context.Background(), "controller", "volume_list_duration", "opentelemetry", "log")
	defer cancel()

	all, err := vc.repo.FindAll(ctx, vc.repo.GetQuery())
	if err == nil {
		st.Info(context.WithValue(ctx, volumeCountKey, len(all)), "volume_list_success", nil)
		return all, nil
	}

	st.Error(context.WithValue(ctx, errorKey, err.Error()), "volume_list_failure", nil)
	return nil, fmt.Errorf("failed to list volumes: %w", err)
}
// GetSize returns the size of a volume, as computed by
// utils.GetDirectorySize for the volume's path. The volume is looked up by
// path or by CID, depending on idType; any other idType is rejected.
//
// TODO-minor: identify which measurement type will be used
func (vc *BasicVolumeController) GetSize(identifier string, idType storage.IDType) (int64, error) {
	ctx, cancel := st.SpanContext(context.Background(), "controller", "volume_get_size_duration", "opentelemetry", "log")
	defer cancel()
	ctx = context.WithValue(ctx, identifierKey, identifier)
	ctx = context.WithValue(ctx, idTypeKey, idType)

	query := vc.repo.GetQuery()
	switch idType {
	case storage.IDTypePath:
		query.Conditions = append(query.Conditions, repositories.EQ("Path", identifier))
	case storage.IDTypeCID:
		query.Conditions = append(query.Conditions, repositories.EQ("CID", identifier))
	default:
		ctx = context.WithValue(ctx, errorKey, fmt.Sprintf("unsupported ID type: %d", idType))
		st.Error(ctx, "volume_get_size_failure", nil)
		return 0, fmt.Errorf("unsupported ID type: %d", idType)
	}

	vol, err := vc.repo.Find(ctx, query)
	if err != nil {
		ctx = context.WithValue(ctx, errorKey, fmt.Sprintf("failed to find volume: %v", err))
		st.Error(ctx, "volume_get_size_failure", nil)
		return 0, fmt.Errorf("failed to find volume: %w", err)
	}

	size, err := utils.GetDirectorySize(vc.FS, vol.Path)
	if err != nil {
		ctx = context.WithValue(ctx, errorKey, fmt.Sprintf("failed to get directory size: %v", err))
		st.Error(ctx, "volume_get_size_failure", nil)
		return 0, fmt.Errorf("failed to get directory size: %w", err)
	}

	// Emit the success event exactly once (it used to be sent twice).
	ctx = context.WithValue(ctx, sizeKey, size)
	st.Info(ctx, "volume_get_size_success", nil)

	return size, nil
}
// EncryptVolume is meant to encrypt a given volume but is not implemented
// yet: it records a "volume_encrypt_not_implemented" event for path and
// always returns a "not implemented" error.
func (vc *BasicVolumeController) EncryptVolume(path string, _ types.Encryptor, _ types.EncryptionType) error {
ctx, cancel := st.SpanContext(context.Background(), "controller", "volume_encrypt_duration", "opentelemetry", "log")
defer cancel()
ctx = context.WithValue(ctx, pathKey, path)
st.Error(ctx, "volume_encrypt_not_implemented", nil)
return fmt.Errorf("not implemented")
}
// DecryptVolume is meant to decrypt a given volume but is not implemented
// yet: it records a "volume_decrypt_not_implemented" event for path and
// always returns a "not implemented" error.
func (vc *BasicVolumeController) DecryptVolume(path string, _ types.Decryptor, _ types.EncryptionType) error {
ctx, cancel := st.SpanContext(context.Background(), "controller", "volume_decrypt_duration", "opentelemetry", "log")
defer cancel()
ctx = context.WithValue(ctx, pathKey, path)
st.Error(ctx, "volume_decrypt_not_implemented", nil)
return fmt.Errorf("not implemented")
}
// TODO-minor: compile-time check for interface implementation
var _ storage.VolumeController = (*BasicVolumeController)(nil)
package basiccontroller
import (
"context"
"fmt"
"os"
"testing"
clover "github.com/ostafen/clover/v2"
"github.com/spf13/afero"
rclover "gitlab.com/nunet/device-management-service/db/repositories/clover"
"gitlab.com/nunet/device-management-service/telemetry"
"gitlab.com/nunet/device-management-service/types"
)
// VolControllerTestSuiteHelper bundles everything SetupVolControllerTestSuite
// creates so tests can use it and TearDownVolControllerTestSuite can release it.
type VolControllerTestSuiteHelper struct {
// BasicVolController is the controller under test.
BasicVolController *BasicVolumeController
// Fs is the file system handed to the controller.
Fs afero.Fs
// DB is the clover database backing the volume repository; closed on teardown.
DB *clover.DB
// Volumes are the volumes passed to the setup function.
Volumes map[string]*types.StorageVolume
// TempDBDir is the temporary directory holding the database; removed on teardown.
TempDBDir string
}
// SetupVolControllerTestSuite sets up a volume controller with 0-n volumes given a base path.
// If volumes are provided, their directories are created on the test file
// system and a record for each one is stored in the database.
//
// The database lives in a fresh temporary directory. On success a cleanup is
// registered on t that calls TearDownVolControllerTestSuite; on failure every
// resource created so far is released before the error is returned.
func SetupVolControllerTestSuite(t *testing.T, basePath string, volumes map[string]*types.StorageVolume) (*VolControllerTestSuiteHelper, error) {
	// Initialize telemetry in test mode, replacing the global st
	// It's initiated here too, besides on basic_controller_test.go, because
	// s3 tests depend on basicController (which in turn depends on telemetry instantiation).
	// S3 are calling this SetupVolControllerTestSuite, so it's one way to initialize telemetry
	// for basic controller
	st = telemetry.NewTelemetry(nil, nil, true)

	tempDir, err := os.MkdirTemp("", "clover-test-*")
	if err != nil {
		return nil, fmt.Errorf("failed to create temp directory: %w", err)
	}

	db, err := rclover.NewDB(tempDir, []string{"storage_volume"})
	if err != nil {
		os.RemoveAll(tempDir)
		return nil, fmt.Errorf("failed to open clover db: %w", err)
	}

	// discard releases the database and its directory on failure paths.
	// RemoveAll is required: once the DB has been opened the directory is
	// not empty, so a plain os.Remove would fail and leak it.
	discard := func() {
		db.Close()
		os.RemoveAll(tempDir)
	}

	fs := afero.NewMemMapFs()
	if err = fs.MkdirAll(basePath, 0o755); err != nil {
		discard()
		return nil, fmt.Errorf("failed to create base path: %w", err)
	}

	repo := rclover.NewStorageVolume(db)
	vc, err := NewDefaultVolumeController(repo, basePath, fs)
	if err != nil {
		discard()
		return nil, fmt.Errorf("failed to create volume controller: %w", err)
	}

	for _, vol := range volumes {
		// create root volume dir
		if err = fs.MkdirAll(vol.Path, 0o755); err != nil {
			discard()
			return nil, fmt.Errorf("failed to create volume dir: %w", err)
		}

		// create volume record in db
		if _, err = repo.Create(context.Background(), *vol); err != nil {
			discard()
			return nil, fmt.Errorf("failed to create volume record: %w", err)
		}
	}

	helper := &VolControllerTestSuiteHelper{vc, fs, db, volumes, tempDir}
	t.Cleanup(func() {
		TearDownVolControllerTestSuite(helper)
	})
	return helper, nil
}
// TearDownVolControllerTestSuite releases what the setup function created:
// it closes the database when one is set and removes the temporary database
// directory when its path is non-empty.
func TearDownVolControllerTestSuite(helper *VolControllerTestSuiteHelper) {
	if db := helper.DB; db != nil {
		db.Close()
	}
	if dir := helper.TempDBDir; dir != "" {
		os.RemoveAll(dir)
	}
}
package s3
import (
"context"
"fmt"
"os"
"path/filepath"
"strings"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/spf13/afero"
"gitlab.com/nunet/device-management-service/storage"
basiccontroller "gitlab.com/nunet/device-management-service/storage/basic_controller"
"gitlab.com/nunet/device-management-service/types"
)
// Download fetches files from a given S3 bucket. The key may be a directory ending
// with `/` or have a wildcard (`*`) so it handles normal S3 folders but it does
// not handle x-directory.
//
// The flow is: decode the source spec, create a new volume through the volume
// controller, resolve the key into the list of objects, download each object
// into the volume path, and finally lock the volume. On any failure a zero
// types.StorageVolume is returned together with an error that wraps the
// underlying cause (usable with errors.Is / errors.As).
//
// NOTE(review): a volume created before a later step fails is not removed
// here — confirm whether the caller or the volume controller cleans it up.
//
// Warning: the implementation should rely on the FS provided by the volume controller,
// be careful if managing files with `os` (the volume controller might be
// using an in-memory one)
func (s *Storage) Download(ctx context.Context, sourceSpecs *types.SpecConfig) (types.StorageVolume, error) {
	ctx, cancel := st.SpanContext(ctx, "s3", "s3_download_duration", "opentelemetry", "log")
	defer cancel()

	source, err := DecodeInputSpec(sourceSpecs)
	if err != nil {
		ctx = context.WithValue(ctx, errorKey, err.Error())
		st.Error(ctx, "s3_download_failure", nil)
		return types.StorageVolume{}, err
	}

	storageVol, err := s.volController.CreateVolume(storage.VolumeSourceS3)
	if err != nil {
		ctx = context.WithValue(ctx, errorKey, err.Error())
		st.Error(ctx, "s3_volume_create_failure", nil)
		return types.StorageVolume{}, fmt.Errorf("failed to create storage volume: %w", err)
	}

	resolvedObjects, err := resolveStorageKey(ctx, s.Client, &source)
	if err != nil {
		ctx = context.WithValue(ctx, errorKey, err.Error())
		st.Error(ctx, "s3_resolve_key_failure", nil)
		return types.StorageVolume{}, fmt.Errorf("failed to resolve storage key: %w", err)
	}

	for _, resolvedObject := range resolvedObjects {
		if err := s.downloadObject(ctx, &source, resolvedObject, storageVol.Path); err != nil {
			ctx = context.WithValue(ctx, errorKey, err.Error())
			st.Error(ctx, "s3_download_object_failure", nil)
			return types.StorageVolume{}, fmt.Errorf("failed to download s3 object: %w", err)
		}
	}

	// after data is filled within the volume, we have to lock it
	if err := s.volController.LockVolume(storageVol.Path); err != nil {
		ctx = context.WithValue(ctx, errorKey, err.Error())
		st.Error(ctx, "s3_volume_lock_failure", nil)
		return types.StorageVolume{}, fmt.Errorf("failed to lock storage volume: %w", err)
	}

	st.Info(ctx, "s3_download_success", nil)
	return storageVol, nil
}
// downloadObject materializes a single resolved S3 object under volPath.
//
// Directory objects only result in the directory being created. For file
// objects the parent directory is created and the object body is downloaded
// into the file at filepath.Join(volPath, key), using the object's ETag as
// the IfMatch precondition.
//
// The file system is taken from the volume controller; only
// *basiccontroller.BasicVolumeController is supported, any other controller
// yields an error instead of a nil-interface panic.
func (s *Storage) downloadObject(ctx context.Context, source *InputSource, object s3Object, volPath string) error {
	outputPath := filepath.Join(volPath, *object.key)

	// use the same file system instance used by the Volume Controller
	var fs afero.Fs
	if basicVolController, ok := s.volController.(*basiccontroller.BasicVolumeController); ok {
		fs = basicVolController.FS
	}
	if fs == nil {
		err := fmt.Errorf("failed to create directory: volume controller does not expose a file system")
		ctx = context.WithValue(ctx, errorKey, err.Error())
		st.Error(ctx, "s3_create_directory_failure", nil)
		return err
	}

	// For a file only its parent directory must exist. Creating a directory
	// at outputPath itself would put a directory where the file is about to
	// be opened.
	dirPath := outputPath
	if !object.isDir {
		dirPath = filepath.Dir(outputPath)
	}
	if err := fs.MkdirAll(dirPath, 0o755); err != nil {
		ctx = context.WithValue(ctx, errorKey, err.Error())
		st.Error(ctx, "s3_create_directory_failure", nil)
		return fmt.Errorf("failed to create directory: %w", err)
	}

	if object.isDir {
		// if object is a directory, we don't need to download it (just create the dir)
		return nil
	}

	outputFile, err := fs.OpenFile(outputPath, os.O_RDWR|os.O_CREATE, 0o755)
	if err != nil {
		ctx = context.WithValue(ctx, errorKey, err.Error())
		st.Error(ctx, "s3_open_file_failure", nil)
		return err
	}
	defer outputFile.Close()

	zlog.Sugar().Debugf("Downloading s3 object %s to %s", *object.key, outputPath)
	_, err = s.downloader.Download(ctx, outputFile, &s3.GetObjectInput{
		Bucket:  aws.String(source.Bucket),
		Key:     object.key,
		IfMatch: object.eTag,
	})
	if err != nil {
		ctx = context.WithValue(ctx, errorKey, err.Error())
		st.Error(ctx, "s3_download_failure", nil)
		return fmt.Errorf("failed to download file: %w", err)
	}

	st.Info(ctx, "s3_download_object_success", nil)
	return nil
}
// resolveStorageKey expands source.Key into the S3 objects it designates.
// An empty key is an error. A key ending in "/" or containing "*" is treated
// as a prefix that may match several objects; any other key is resolved as
// one single object.
func resolveStorageKey(ctx context.Context, client *s3.Client, source *InputSource) ([]s3Object, error) {
	if source.Key == "" {
		err := fmt.Errorf("key is required")
		st.Error(context.WithValue(ctx, errorKey, err.Error()), "s3_resolve_key_failure", nil)
		return nil, err
	}

	isPrefix := strings.HasSuffix(source.Key, "/") || strings.Contains(source.Key, "*")
	if isPrefix {
		// key represents multiple objects
		return resolveObjectsWithPrefix(ctx, client, source)
	}

	// key represents a single object
	return resolveSingleObject(ctx, client, source)
}
// resolveSingleObject looks up the metadata of the single object designated by
// source.Key (after sanitizeKey) with a HeadObject call and returns it as a
// one-element list.
//
// Objects whose content type starts with "application/x-directory" are
// rejected because they are not handled yet. Optional response fields are
// read through the nil-safe aws.ToString / aws.ToInt64 helpers, so a missing
// content type or content length yields "" / 0 instead of a nil dereference.
// The returned key is the sanitized key, i.e. the very key the metadata and
// ETag were fetched for.
func resolveSingleObject(ctx context.Context, client *s3.Client, source *InputSource) ([]s3Object, error) {
	key := sanitizeKey(source.Key)

	headObjectOut, err := client.HeadObject(ctx, &s3.HeadObjectInput{
		Bucket: aws.String(source.Bucket),
		Key:    aws.String(key),
	})
	if err != nil {
		ctx = context.WithValue(ctx, errorKey, err.Error())
		st.Error(ctx, "s3_head_object_failure", nil)
		return []s3Object{}, fmt.Errorf("failed to retrieve object metadata: %w", err)
	}

	if strings.HasPrefix(aws.ToString(headObjectOut.ContentType), "application/x-directory") {
		err := fmt.Errorf("x-directory is not yet handled")
		ctx = context.WithValue(ctx, errorKey, err.Error())
		st.Error(ctx, "s3_directory_handling_failure", nil)
		return []s3Object{}, err
	}

	return []s3Object{
		{
			key:  aws.String(key),
			eTag: headObjectOut.ETag,
			size: aws.ToInt64(headObjectOut.ContentLength),
		},
	}, nil
}
// resolveObjectsWithPrefix lists every object in source.Bucket whose key
// starts with the sanitized source.Key, following ListObjectsV2 pagination.
//
// An entry whose key ends in "/" is flagged as a directory. No ETag is
// recorded for listed objects. Keys and sizes are read through the nil-safe
// aws.ToString / aws.ToInt64 helpers so an entry with a missing field cannot
// cause a nil dereference.
func resolveObjectsWithPrefix(ctx context.Context, client *s3.Client, source *InputSource) ([]s3Object, error) {
	key := sanitizeKey(source.Key)

	// List objects with the given prefix
	listObjectsInput := &s3.ListObjectsV2Input{
		Bucket: aws.String(source.Bucket),
		Prefix: aws.String(key),
	}

	var objects []s3Object
	paginator := s3.NewListObjectsV2Paginator(client, listObjectsInput)
	for paginator.HasMorePages() {
		page, err := paginator.NextPage(ctx)
		if err != nil {
			ctx = context.WithValue(ctx, errorKey, err.Error())
			st.Error(ctx, "s3_list_objects_failure", nil)
			return nil, fmt.Errorf("failed to list objects: %w", err)
		}

		for _, obj := range page.Contents {
			objKey := aws.ToString(obj.Key)
			objects = append(objects, s3Object{
				key:   aws.String(objKey),
				size:  aws.ToInt64(obj.Size),
				isDir: strings.HasSuffix(objKey, "/"),
			})
		}
	}

	st.Info(ctx, "s3_resolve_objects_with_prefix_success", nil)
	return objects, nil
}
package s3
import (
"context"
"strings"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
)
// GetAWSDefaultConfig loads the AWS SDK default configuration (environment
// variables, shared configuration and shared credentials files) without any
// extra load options. On failure it returns a zero aws.Config and the error.
func GetAWSDefaultConfig() (aws.Config, error) {
	spanCtx, endSpan := st.SpanContext(context.Background(), "s3", "get_aws_default_config_duration", "opentelemetry", "log")
	defer endSpan()

	awsCfg, loadErr := config.LoadDefaultConfig(spanCtx)
	if loadErr == nil {
		st.Info(spanCtx, "get_aws_default_config_success", nil)
		return awsCfg, nil
	}

	st.Error(context.WithValue(spanCtx, errorKey, loadErr.Error()), "get_aws_default_config_failure", nil)
	return aws.Config{}, loadErr
}
// hasValidCredentials reports whether credentials can be retrieved from the
// given AWS config and contain keys.
//
// It returns false when the config has no credentials provider (e.g. a zero
// aws.Config), when retrieval fails, or when the retrieved credentials carry
// no keys. Every outcome is recorded through telemetry.
func hasValidCredentials(config aws.Config) bool {
	ctx, cancel := st.SpanContext(context.Background(), "s3", "has_valid_credentials_duration", "opentelemetry", "log")
	defer cancel()

	// A zero-value config has a nil provider; calling Retrieve on it would
	// panic, so treat it as "no valid credentials".
	if config.Credentials == nil {
		ctx = context.WithValue(ctx, errorKey, "no credentials provider configured")
		st.Error(ctx, "has_valid_credentials_failure", nil)
		return false
	}

	credentials, err := config.Credentials.Retrieve(ctx)
	if err != nil {
		ctx = context.WithValue(ctx, errorKey, err.Error())
		st.Error(ctx, "has_valid_credentials_failure", nil)
		return false
	}

	if !credentials.HasKeys() {
		st.Error(ctx, "has_valid_credentials_failure_no_keys", nil)
		return false
	}

	st.Info(ctx, "has_valid_credentials_success", nil)
	return true
}
// sanitizeKey trims leading and trailing whitespace from key and then drops
// one trailing wildcard ("*") if present. The result is recorded in a
// "sanitize_key_success" telemetry event and returned.
func sanitizeKey(key string) string {
	spanCtx, endSpan := st.SpanContext(context.Background(), "s3", "sanitize_key_duration", "opentelemetry", "log")
	defer endSpan()

	trimmed := strings.TrimSpace(key)
	sanitizedKey := strings.TrimSuffix(trimmed, "*")

	st.Info(context.WithValue(spanCtx, sanitizedKeyContext, sanitizedKey), "sanitize_key_success", nil)
	return sanitizedKey
}
package s3
import (
"github.com/uptrace/opentelemetry-go-extra/otelzap"
"gitlab.com/nunet/device-management-service/telemetry"
"gitlab.com/nunet/device-management-service/telemetry/logger"
)
// Package-level observability handles for the s3 package.
var (
// zlog is the package logger; it is assigned in init.
zlog *otelzap.Logger
// st is obtained from telemetry.GetTelemetry() when package variables are
// initialized.
// NOTE(review): if the global telemetry instance has not been created by
// then, this stays nil for the lifetime of the package — confirm the
// initialization order or that Telemetry methods tolerate a nil receiver.
st = telemetry.GetTelemetry()
)
// contextKey is a dedicated string type for the context.WithValue keys used
// by this package, so they cannot collide with plain string keys or with keys
// defined by other packages.
type contextKey string
// Keys under which values are attached to a context before it is handed to
// telemetry.
const (
pathKey contextKey = "path"
SourceSpecsKey contextKey = "sourceSpecs"
errorKey contextKey = "error"
OutputPathKey contextKey = "outputPath"
bucketKey contextKey = "bucket"
S3KeyKey contextKey = "key"
ContentLength contextKey = "content_length"
FilePathKey contextKey = "file_path"
VolumePathKey contextKey = "volume_path"
sanitizedKeyContext contextKey = "sanitized_key"
)
// init assigns the package logger, registered under the name "s3".
func init() {
zlog = logger.OtelZapLogger("s3")
}
package s3
import (
"context"
"fmt"
"github.com/aws/aws-sdk-go-v2/aws"
s3Manager "github.com/aws/aws-sdk-go-v2/feature/s3/manager"
"github.com/aws/aws-sdk-go-v2/service/s3"
"gitlab.com/nunet/device-management-service/storage"
"gitlab.com/nunet/device-management-service/types"
)
// Storage is an S3-backed storage provider. It embeds the S3 SDK client (so
// the client's methods are available directly on Storage) and pairs it with
// a volume controller and the SDK's transfer managers.
type Storage struct {
*s3.Client
// volController manages the local volumes that are downloaded into or uploaded from.
volController storage.VolumeController
// downloader and uploader are the s3 manager helpers built on the embedded client.
downloader *s3Manager.Downloader
uploader *s3Manager.Uploader
}
// s3Object describes one object resolved from a bucket key.
type s3Object struct {
// key is the object's key within the bucket.
key *string
// eTag, when set, is sent as the IfMatch precondition when downloading.
eTag *string
// size is the size reported by S3 for the object.
size int64
// isDir marks entries that only stand for a directory (key ends in "/").
isDir bool
}
// NewClient builds a Storage around a fresh S3 SDK client created from config,
// together with a downloader and an uploader bound to that client.
// volController is the VolumeController used to manage the volumes being
// acted upon. An "invalid credentials" error is returned when the config does
// not pass hasValidCredentials.
func NewClient(config aws.Config, volController storage.VolumeController) (*Storage, error) {
	spanCtx, endSpan := st.SpanContext(context.Background(), "s3", "new_client_duration", "opentelemetry", "log")
	defer endSpan()

	if credsOK := hasValidCredentials(config); !credsOK {
		credErr := fmt.Errorf("invalid credentials")
		st.Error(context.WithValue(spanCtx, errorKey, credErr.Error()), "new_client_invalid_credentials", nil)
		return nil, credErr
	}

	sdkClient := s3.NewFromConfig(config)
	client := &Storage{
		Client:        sdkClient,
		volController: volController,
		downloader:    s3Manager.NewDownloader(sdkClient),
		uploader:      s3Manager.NewUploader(sdkClient),
	}

	st.Info(spanCtx, "new_client_success", nil)
	return client, nil
}
// Size returns the content length of the single S3 object described by source,
// as reported by a HeadObject call on the decoded bucket and key.
//
// A response without a content length is reported as size 0 rather than
// causing a nil dereference; a negative length is rejected instead of being
// wrapped around by the conversion to uint64. Returned errors wrap the
// underlying cause.
func (s *Storage) Size(ctx context.Context, source *types.SpecConfig) (uint64, error) {
	ctx, cancel := st.SpanContext(ctx, "s3", "s3_size_duration", "opentelemetry", "log")
	defer cancel()

	inputSource, err := DecodeInputSpec(source)
	if err != nil {
		ctx = context.WithValue(ctx, errorKey, err.Error())
		st.Error(ctx, "s3_size_decode_input_spec_failure", nil)
		return 0, fmt.Errorf("failed to decode input spec: %w", err)
	}

	output, err := s.HeadObject(ctx, &s3.HeadObjectInput{
		Bucket: aws.String(inputSource.Bucket),
		Key:    aws.String(inputSource.Key),
	})
	if err != nil {
		ctx = context.WithValue(ctx, errorKey, err.Error())
		st.Error(ctx, "s3_size_head_object_failure", nil)
		return 0, fmt.Errorf("failed to get object size: %w", err)
	}

	length := aws.ToInt64(output.ContentLength)
	if length < 0 {
		err := fmt.Errorf("failed to get object size: invalid content length %d", length)
		ctx = context.WithValue(ctx, errorKey, err.Error())
		st.Error(ctx, "s3_size_head_object_failure", nil)
		return 0, err
	}

	st.Info(ctx, "s3_size_success", nil)
	return uint64(length), nil
}
// Compile time interface check
// var _ storage.StorageProvider = (*S3Storage)(nil)
package s3
import (
"context"
"fmt"
"github.com/fatih/structs"
"github.com/mitchellh/mapstructure"
"gitlab.com/nunet/device-management-service/types"
)
// InputSource holds the S3 parameters decoded from a storage spec.
// Only Bucket is checked by Validate.
type InputSource struct {
// Bucket is the bucket name; must not be empty.
Bucket string
// Key is the object key; it may end in "/" or contain "*" to designate several objects.
Key string
Filter string
Region string
Endpoint string
}
// Validate checks that the input source names a bucket. An empty Bucket is
// recorded as a telemetry error and reported to the caller; everything else
// is accepted.
func (s InputSource) Validate() error {
	if s.Bucket != "" {
		return nil
	}

	st.Error(context.Background(), "s3_input_source_validation_failure", nil)
	return fmt.Errorf("invalid s3 storage params: bucket cannot be empty")
}
// ToMap converts the input source into a map via structs.Map.
func (s InputSource) ToMap() map[string]interface{} {
return structs.Map(s)
}
// DecodeInputSpec converts a generic storage spec into an S3 InputSource.
//
// It fails when spec is nil, when the spec is not of type
// types.StorageProviderS3, when its params are nil, when the params cannot be
// decoded into an InputSource, or when the decoded value does not pass
// Validate. On every failure a zero InputSource is returned together with the
// error, so callers never receive a half-populated value.
func DecodeInputSpec(spec *types.SpecConfig) (InputSource, error) {
	ctx, cancel := st.SpanContext(context.Background(), "s3", "decode_input_spec_duration", "opentelemetry", "log")
	defer cancel()

	// Guard against a nil spec: the checks below dereference it.
	if spec == nil {
		err := fmt.Errorf("invalid storage input source spec. cannot be nil")
		ctx = context.WithValue(ctx, errorKey, err.Error())
		st.Error(ctx, "decode_input_spec_nil_params_failure", nil)
		return InputSource{}, err
	}

	if !spec.IsType(types.StorageProviderS3) {
		err := fmt.Errorf("invalid storage source type. Expected %s but received %s", types.StorageProviderS3, spec.Type)
		ctx = context.WithValue(ctx, errorKey, err.Error())
		st.Error(ctx, "decode_input_spec_invalid_type_failure", nil)
		return InputSource{}, err
	}

	if spec.Params == nil {
		err := fmt.Errorf("invalid storage input source params. cannot be nil")
		ctx = context.WithValue(ctx, errorKey, err.Error())
		st.Error(ctx, "decode_input_spec_nil_params_failure", nil)
		return InputSource{}, err
	}

	var c InputSource
	if err := mapstructure.Decode(spec.Params, &c); err != nil {
		ctx = context.WithValue(ctx, errorKey, err.Error())
		st.Error(ctx, "decode_input_spec_decode_failure", nil)
		return InputSource{}, err
	}

	if err := c.Validate(); err != nil {
		ctx = context.WithValue(ctx, errorKey, err.Error())
		st.Error(ctx, "decode_input_spec_validation_failure", nil)
		return InputSource{}, err
	}

	st.Info(ctx, "decode_input_spec_success", nil)
	return c, nil
}
package s3
import (
"context"
"fmt"
"os"
"path/filepath"
"github.com/spf13/afero"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/s3"
basiccontroller "gitlab.com/nunet/device-management-service/storage/basic_controller"
"gitlab.com/nunet/device-management-service/types"
)
// Upload uploads all files (recursively) from a local volume to an S3 bucket.
// It handles directories.
//
// Every regular file found under vol.Path is uploaded to the bucket named in
// destinationSpecs. Its object key is the sanitized destination key joined
// with the file's path relative to vol.Path, always using "/" as separator.
// The walk stops at the first failure, so some files may already have been
// uploaded when an error is returned. Returned errors wrap the underlying
// cause.
//
// Only a *basiccontroller.BasicVolumeController is supported as the source of
// the file system; any other volume controller yields an error.
//
// Warning: the implementation should rely on the FS provided by the volume controller,
// be careful if managing files with `os` (the volume controller might be
// using an in-memory one)
func (s *Storage) Upload(ctx context.Context, vol types.StorageVolume, destinationSpecs *types.SpecConfig) error {
	ctx, cancel := st.SpanContext(ctx, "s3", "s3_upload_duration", "opentelemetry", "log")
	defer cancel()

	target, err := DecodeInputSpec(destinationSpecs)
	if err != nil {
		ctx = context.WithValue(ctx, errorKey, err.Error())
		st.Error(ctx, "s3_upload_decode_spec_failure", nil)
		return fmt.Errorf("failed to decode input spec: %w", err)
	}

	sanitizedKey := sanitizeKey(target.Key)

	// set file system to act upon based on the volume controller implementation
	var fs afero.Fs
	if basicVolController, ok := s.volController.(*basiccontroller.BasicVolumeController); ok {
		fs = basicVolController.FS
	}
	if fs == nil {
		// Walking a nil file system would panic; fail explicitly instead.
		err := fmt.Errorf("upload failed: volume controller does not expose a file system")
		ctx = context.WithValue(ctx, errorKey, err.Error())
		st.Error(ctx, "s3_upload_failure", nil)
		return err
	}

	zlog.Sugar().Debugf("Uploading files from %s to s3://%s/%s", vol.Path, target.Bucket, sanitizedKey)
	err = afero.Walk(fs, vol.Path, func(filePath string, info os.FileInfo, walkErr error) error {
		if walkErr != nil {
			st.Error(context.WithValue(ctx, errorKey, walkErr.Error()), "s3_upload_walk_failure", nil)
			return walkErr
		}

		// Skip directories
		if info.IsDir() {
			return nil
		}

		relPath, err := filepath.Rel(vol.Path, filePath)
		if err != nil {
			st.Error(context.WithValue(ctx, errorKey, err.Error()), "s3_upload_relative_path_failure", nil)
			return fmt.Errorf("failed to get relative path: %w", err)
		}

		// Construct the S3 key by joining the sanitized key and the relative
		// path. S3 keys use "/" regardless of the host OS separator.
		s3Key := filepath.ToSlash(filepath.Join(sanitizedKey, relPath))

		file, err := fs.Open(filePath)
		if err != nil {
			st.Error(context.WithValue(ctx, errorKey, err.Error()), "s3_upload_open_file_failure", nil)
			return fmt.Errorf("failed to open file: %w", err)
		}
		defer file.Close()

		// Per-file context: derive from the span context instead of
		// reassigning it, so values do not pile up from one file to the next.
		fileCtx := context.WithValue(ctx, FilePathKey, filePath)
		fileCtx = context.WithValue(fileCtx, S3KeyKey, s3Key)

		zlog.Sugar().Debugf("Uploading %s to s3://%s/%s", filePath, target.Bucket, s3Key)
		_, err = s.uploader.Upload(fileCtx, &s3.PutObjectInput{
			Bucket: aws.String(target.Bucket),
			Key:    aws.String(s3Key),
			Body:   file,
		})
		if err != nil {
			st.Error(context.WithValue(fileCtx, errorKey, err.Error()), "s3_upload_put_object_failure", nil)
			return fmt.Errorf("failed to upload file to S3: %w", err)
		}

		st.Info(fileCtx, "s3_upload_file_success", nil)
		return nil
	})
	if err != nil {
		ctx = context.WithValue(ctx, errorKey, err.Error())
		st.Error(ctx, "s3_upload_failure", nil)
		return fmt.Errorf("upload failed. It's possible that some files were uploaded; Error: %w", err)
	}

	st.Info(ctx, "s3_upload_success", nil)
	return nil
}
package telemetry
import (
"os"
"sync"
"time"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"gitlab.com/nunet/device-management-service/internal/config"
"gitlab.com/nunet/device-management-service/types"
)
// Package-level state of the global telemetry system.
var (
// once guards the one-time initialization in InitGlobalTelemetry.
once sync.Once
// instance is the global Telemetry; nil until InitGlobalTelemetry has set it.
instance *Telemetry
// logLevel is the observability level parsed from configuration (INFO when parsing fails).
logLevel types.ObservabilityLevel
// zapLogger is the logger built by initZapLogger and installed as zap's global logger.
zapLogger *zap.Logger
)
// InitGlobalTelemetry initializes the global telemetry instance with configuration loaded from the configuration package.
//
// The initialization runs at most once per process; later calls are no-ops.
// It builds the zap logger and installs it globally, reads the telemetry
// configuration, creates the global instance with the "log" and
// "opentelemetry" collectors, initializes each collector (a collector that
// fails to initialize is logged and skipped, not fatal) and starts the
// periodic flush.
//
// If the zap logger cannot be built, the error is returned to the caller
// instead of panicking and nothing else is set up. Only the call that
// actually runs the initialization can report this error; subsequent calls
// return nil.
func InitGlobalTelemetry() error {
	var initError error
	once.Do(func() {
		// Initialize Zap logger
		zapLogger, initError = initZapLogger()
		if initError != nil {
			// Surface the failure through the return value; the signature
			// promises an error, so the caller decides how to react.
			return
		}
		zap.ReplaceGlobals(zapLogger)

		telemetryConfig := config.GetConfig().Telemetry

		logLevel = types.INFO // Default level
		if level, err := types.ParseObservabilityLevel(telemetryConfig.ObservabilityLevel); err == nil {
			logLevel = level
		} else {
			zap.L().Warn("Invalid observability level, defaulting to INFO", zap.Error(err))
		}

		instance = &Telemetry{
			config: &types.TelemetryConfig{
				ServiceName:        telemetryConfig.ServiceName,
				GlobalEndpoint:     telemetryConfig.GlobalEndpoint,
				ObservabilityLevel: telemetryConfig.ObservabilityLevel, // Assign the string value
				TelemetryMode:      telemetryConfig.TelemetryMode,
			},
		}

		opentelemetryCollector := NewOpenTelemetryCollector(instance.config, zap.L())
		logCollector := NewLogCollector(instance.config, zap.L())

		instance.collectors = map[string]Collector{
			logCollector.GetName():           logCollector,
			opentelemetryCollector.GetName(): opentelemetryCollector,
		}

		for _, collector := range instance.collectors {
			if err := collector.Initialize(); err != nil {
				zap.L().Error("Failed to initialize collector", zap.Error(err))
			}
		}

		// Start periodic flush after initializing collectors
		StartPeriodicFlush(5 * time.Minute)
	})
	return initError
}
// initZapLogger builds the zap logger used by telemetry. When the NUNET_DEBUG
// environment variable is set (to any value) or the configuration enables
// General.Debug, a development logger with colored capital level names is
// built; otherwise a production logger is returned.
func initZapLogger() (*zap.Logger, error) {
	_, envDebug := os.LookupEnv("NUNET_DEBUG")
	if !envDebug && !config.GetConfig().General.Debug {
		return zap.NewProduction()
	}

	devConfig := zap.NewDevelopmentConfig()
	devConfig.EncoderConfig.EncodeLevel = zapcore.CapitalColorLevelEncoder
	return devConfig.Build()
}
// StartPeriodicFlush starts a goroutine that periodically flushes telemetry data.
//
// Every interval the global telemetry instance is flushed; ticks that occur
// while no global instance exists are skipped. A non-positive interval
// disables the periodic flush (time.NewTicker would panic on it) and is
// reported with a warning. The goroutine runs for the rest of the process
// lifetime; there is no way to stop it.
func StartPeriodicFlush(interval time.Duration) {
	if interval <= 0 {
		zap.L().Warn("Periodic telemetry flush disabled: non-positive interval", zap.Duration("interval", interval))
		return
	}

	ticker := time.NewTicker(interval)
	go func() {
		for range ticker.C {
			if instance == nil {
				// Nothing to flush yet.
				continue
			}
			zap.L().Info("Periodic flush started for telemetry")
			instance.Flush()
		}
	}()
}
package telemetry
import (
"context"
"gitlab.com/nunet/device-management-service/types"
"go.uber.org/zap"
)
// LogCollector is a Collector that writes telemetry events to a zap logger.
type LogCollector struct {
config *types.TelemetryConfig
logger *zap.Logger
}
// NewLogCollector returns a LogCollector that keeps the given telemetry
// configuration and logs through logger.
func NewLogCollector(config *types.TelemetryConfig, logger *zap.Logger) *LogCollector {
return &LogCollector{
config: config,
logger: logger,
}
}
// Initialize only logs that the collector is ready; it always returns nil.
func (c *LogCollector) Initialize() error {
c.logger.Info("LogCollector initialized.")
return nil
}
// HandleEvent writes the event to the logger at the zap level matching the
// event's observability level. TRACE is logged at debug level (there is no
// dedicated trace level here) and unknown levels are logged at info level.
// Context, message, level name and payload are attached as structured
// fields. It always returns nil.
//
// Note: FATAL events go through zap's Fatal, which terminates the process
// after writing the entry.
func (c *LogCollector) HandleEvent(event types.Event) error {
	fields := []zap.Field{
		zap.Any("context", event.Context),
		zap.String("message", event.Message),
		zap.String("level", event.Level.String()),
		zap.Any("payload", event.Payload),
	}

	switch event.Level {
	case types.TRACE, types.DEBUG:
		c.logger.Debug(event.Message, fields...)
	case types.WARN:
		c.logger.Warn(event.Message, fields...)
	case types.ERROR:
		c.logger.Error(event.Message, fields...)
	case types.FATAL:
		c.logger.Fatal(event.Message, fields...)
	default:
		// INFO and any unrecognized level.
		c.logger.Info(event.Message, fields...)
	}
	return nil
}
// Flush syncs the underlying logger and returns whatever error Sync reports.
func (c *LogCollector) Flush() error {
	return c.logger.Sync()
}
// Shutdown flushes the collector; there is nothing else to release.
func (c *LogCollector) Shutdown() error {
return c.Flush()
}
// GetName returns the collector's name, "log".
func (c *LogCollector) GetName() string {
return "log"
}
// SpanContext exists to satisfy the Collector interface. The span name is
// ignored: the context is returned unchanged with a cancel function that does
// nothing.
func (c *LogCollector) SpanContext(ctx context.Context, _ string) (context.Context, context.CancelFunc) {
// LogCollector does not support tracing, so just return the original context and a no-op cancel function
return ctx, func() {}
}
// Compile-time check that *LogCollector implements the Collector interface.
var _ Collector = (*LogCollector)(nil)
package telemetry
import (
"context"
"sync"
"gitlab.com/nunet/device-management-service/types"
)
// MockCollector is a mock implementation of the Collector interface.
// It records the events and spans it receives so tests can inspect them.
type MockCollector struct {
// mu guards the fields below.
mu sync.Mutex
// events holds every event passed to HandleEvent, in arrival order.
events []types.Event
// traces holds one entry per SpanContext call.
traces []MockTrace
// initialized is set by Initialize.
initialized bool
// name is the value returned by GetName.
name string
}
// MockTrace records one span started through MockCollector.SpanContext:
// the span name, the derived context and its cancel function.
type MockTrace struct {
SpanName string
Context context.Context
CancelFunc context.CancelFunc
}
// NewMockCollector creates a new instance of MockCollector with the given
// name and empty (non-nil) event and trace lists.
func NewMockCollector(name string) *MockCollector {
return &MockCollector{
events: []types.Event{},
traces: []MockTrace{},
name: name,
}
}
// Initialize is a mock implementation of the Collector interface's Initialize method.
// It marks the collector as initialized and always returns nil.
func (m *MockCollector) Initialize() error {
m.mu.Lock()
defer m.mu.Unlock()
m.initialized = true
return nil
}
// SpanContext is a mock implementation of the Collector interface's SpanContext method.
// It derives a cancelable context from ctx, records it together with the span
// name and cancel function, and returns the derived context and its cancel
// function.
func (m *MockCollector) SpanContext(ctx context.Context, spanName string) (context.Context, context.CancelFunc) {
m.mu.Lock()
defer m.mu.Unlock()
mockCtx, cancel := context.WithCancel(ctx)
m.traces = append(m.traces, MockTrace{
SpanName: spanName,
Context: mockCtx,
CancelFunc: cancel,
})
return mockCtx, cancel
}
// HandleEvent is a mock implementation of the Collector interface's HandleEvent method.
// It appends the event to the recorded list and always returns nil.
func (m *MockCollector) HandleEvent(event types.Event) error {
m.mu.Lock()
defer m.mu.Unlock()
m.events = append(m.events, event)
return nil
}
// Flush is a mock implementation of the Collector interface's Flush method.
// It does nothing besides taking the lock and always returns nil.
func (m *MockCollector) Flush() error {
m.mu.Lock()
defer m.mu.Unlock()
return nil
}
// Shutdown is a mock implementation of the Collector interface's Shutdown method.
// It does nothing besides taking the lock and always returns nil.
func (m *MockCollector) Shutdown() error {
m.mu.Lock()
defer m.mu.Unlock()
return nil
}
// GetName returns the name the mock collector was created with.
func (m *MockCollector) GetName() string {
return m.name
}
// GetTraces returns a snapshot of the recorded traces.
//
// The returned slice is a copy taken under the lock, so the caller can read
// it while the collector keeps recording without sharing the internal slice.
// The result is never nil.
func (m *MockCollector) GetTraces() []MockTrace {
	m.mu.Lock()
	defer m.mu.Unlock()

	snapshot := make([]MockTrace, len(m.traces))
	copy(snapshot, m.traces)
	return snapshot
}
// GetEvents returns a snapshot of the recorded events.
//
// The returned slice is a copy taken under the lock, so the caller can read
// it while the collector keeps recording without sharing the internal slice.
// The result is never nil.
func (m *MockCollector) GetEvents() []types.Event {
	m.mu.Lock()
	defer m.mu.Unlock()

	snapshot := make([]types.Event, len(m.events))
	copy(snapshot, m.events)
	return snapshot
}
// Reset clears all recorded events and traces by replacing both lists with
// fresh empty ones. The initialized flag and the name are left untouched.
func (m *MockCollector) Reset() {
m.mu.Lock()
defer m.mu.Unlock()
m.events = []types.Event{}
m.traces = []MockTrace{}
}
// AssertInitialized reports whether Initialize has been called on the mock
// collector. Despite the name it does not fail anything; it only returns the flag.
func (m *MockCollector) AssertInitialized() bool {
m.mu.Lock()
defer m.mu.Unlock()
return m.initialized
}
// MockTelemetry is a mock implementation of the Telemetry system.
// It embeds Telemetry and routes spans and events to MockCollector instances.
type MockTelemetry struct {
Telemetry
// mu guards collectors.
mu sync.Mutex
// collectors maps collector name to mock collector. This field shadows the
// collectors field of the embedded Telemetry.
collectors map[string]*MockCollector
}
// NewMockTelemetry creates a new instance of MockTelemetry that mimics the Telemetry struct.
// The embedded Telemetry receives config and an empty collector map; the mock
// collector map starts empty as well.
func NewMockTelemetry(config *types.TelemetryConfig) *MockTelemetry {
return &MockTelemetry{
Telemetry: Telemetry{
config: config,
collectors: make(map[string]Collector),
},
collectors: make(map[string]*MockCollector),
}
}
// AddCollector adds a mock collector to the telemetry system, keyed by the
// collector's name. A collector registered earlier under the same name is replaced.
func (m *MockTelemetry) AddCollector(collector *MockCollector) {
m.mu.Lock()
defer m.mu.Unlock()
m.collectors[collector.GetName()] = collector
}
// SpanContext simulates starting a trace with the given collectors.
//
// For every name in collectors that matches a registered mock collector, a
// span named span is started on that collector; each span's context becomes
// the parent of the next one. Unknown collector names are ignored. The second
// parameter is unused. The returned cancel function invokes the cancel
// function of every span that was started.
//
// The collector map is read under the mutex, matching AddCollector which
// writes it under the same mutex.
func (m *MockTelemetry) SpanContext(ctx context.Context, _ string, span string, collectors ...string) (context.Context, context.CancelFunc) {
	m.mu.Lock()
	defer m.mu.Unlock()

	var cancelFuncs []context.CancelFunc
	for _, collectorName := range collectors {
		collector, ok := m.collectors[collectorName]
		if !ok {
			continue
		}
		mockCtx, cancel := collector.SpanContext(ctx, span)
		cancelFuncs = append(cancelFuncs, cancel)
		ctx = mockCtx
	}

	cancel := func() {
		for _, cancelFunc := range cancelFuncs {
			cancelFunc()
		}
	}
	return ctx, cancel
}
// Trace records a TRACE-level event with the given message and payload in
// all added collectors.
func (m *MockTelemetry) Trace(ctx context.Context, message string, payload map[string]interface{}) {
m.logEvent(ctx, types.TRACE, message, payload)
}
// Debug records a DEBUG-level event with the given message and payload in
// all added collectors.
func (m *MockTelemetry) Debug(ctx context.Context, message string, payload map[string]interface{}) {
m.logEvent(ctx, types.DEBUG, message, payload)
}
// Info records an INFO-level event with the given message and payload in
// all added collectors.
func (m *MockTelemetry) Info(ctx context.Context, message string, payload map[string]interface{}) {
m.logEvent(ctx, types.INFO, message, payload)
}
// Warn records a WARN-level event with the given message and payload in
// all added collectors.
func (m *MockTelemetry) Warn(ctx context.Context, message string, payload map[string]interface{}) {
m.logEvent(ctx, types.WARN, message, payload)
}
// Error records an ERROR-level event with the given message and payload in
// all added collectors.
func (m *MockTelemetry) Error(ctx context.Context, message string, payload map[string]interface{}) {
m.logEvent(ctx, types.ERROR, message, payload)
}
// Fatal records a FATAL-level event with the given message and payload in
// all added collectors. The event is only recorded; this method itself does
// nothing else.
func (m *MockTelemetry) Fatal(ctx context.Context, message string, payload map[string]interface{}) {
m.logEvent(ctx, types.FATAL, message, payload)
}
// logEvent builds an event from the arguments and hands it to every
// registered mock collector.
//
// The collector map is iterated under the mutex, matching AddCollector which
// writes it under the same mutex.
func (m *MockTelemetry) logEvent(ctx context.Context, level types.ObservabilityLevel, message string, payload map[string]interface{}) {
	event := types.Event{
		Context: ctx,
		Level:   level,
		Message: message,
		Payload: payload,
	}

	m.mu.Lock()
	defer m.mu.Unlock()
	for _, collector := range m.collectors {
		_ = collector.HandleEvent(event) // HandleEvent error is intentionally ignored
	}
}
// GetCollector returns the mock collector registered under name, or nil when
// no collector with that name has been added.
func (m *MockTelemetry) GetCollector(name string) *MockCollector {
m.mu.Lock()
defer m.mu.Unlock()
return m.collectors[name]
}
// Reset clears all recorded data in all collectors. The collectors themselves
// stay registered.
func (m *MockTelemetry) Reset() {
m.mu.Lock()
defer m.mu.Unlock()
for _, collector := range m.collectors {
collector.Reset()
}
}
// Flush is a mock implementation of the Telemetry system's Flush method.
// It flushes every registered mock collector.
//
// The collector map is iterated under the mutex, matching AddCollector which
// writes it under the same mutex.
func (m *MockTelemetry) Flush() {
	m.mu.Lock()
	defer m.mu.Unlock()
	for _, collector := range m.collectors {
		_ = collector.Flush() // Flush error is intentionally ignored
	}
}
// Shutdown is a mock implementation of the Telemetry system's Shutdown method.
// It flushes and then shuts down every registered mock collector.
//
// The collector map is iterated under the mutex, matching AddCollector which
// writes it under the same mutex. Flush is called before the lock is taken
// because it accesses the same map and the mutex is not reentrant.
func (m *MockTelemetry) Shutdown() {
	m.Flush()

	m.mu.Lock()
	defer m.mu.Unlock()
	for _, collector := range m.collectors {
		_ = collector.Shutdown() // Shutdown error is intentionally ignored
	}
}
package telemetry
import (
	"context"
	"fmt"

	"gitlab.com/nunet/device-management-service/types"
	"go.opentelemetry.io/otel"
	"go.opentelemetry.io/otel/attribute"
	"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp"
	"go.opentelemetry.io/otel/sdk/resource"
	sdktrace "go.opentelemetry.io/otel/sdk/trace"
	semconv "go.opentelemetry.io/otel/semconv/v1.4.0"
	"go.uber.org/zap"
)
// OpenTelemetryCollector is a Collector that turns telemetry events into
// OpenTelemetry spans.
type OpenTelemetryCollector struct {
config *types.TelemetryConfig
logger *zap.Logger
// tracerProvider is created by Initialize; it is nil before Initialize
// succeeds.
tracerProvider *sdktrace.TracerProvider
}
// NewOpenTelemetryCollector returns a collector holding the given
// configuration and logger. No tracer provider is created here; call
// Initialize for that.
func NewOpenTelemetryCollector(config *types.TelemetryConfig, logger *zap.Logger) *OpenTelemetryCollector {
return &OpenTelemetryCollector{
config: config,
logger: logger,
}
}
// Initialize creates an OTLP/HTTP trace exporter for the configured global
// endpoint (insecure transport), builds a resource carrying the configured
// service name, and installs a batching tracer provider both on the collector
// and as the global OpenTelemetry tracer provider. Exporter or resource
// creation failures are logged and returned.
func (c *OpenTelemetryCollector) Initialize() error {
	endpoint := c.config.GlobalEndpoint
	c.logger.Info("Initializing OpenTelemetry HTTP trace exporter",
		zap.String("endpoint", endpoint),
	)

	exporter, exporterErr := otlptracehttp.New(context.Background(),
		otlptracehttp.WithEndpoint(endpoint),
		otlptracehttp.WithInsecure(),
	)
	if exporterErr != nil {
		c.logger.Error("Failed to create HTTP trace exporter", zap.Error(exporterErr))
		return exporterErr
	}

	serviceName := semconv.ServiceNameKey.String(c.config.ServiceName)
	serviceResource, resourceErr := resource.New(context.Background(),
		resource.WithAttributes(serviceName),
	)
	if resourceErr != nil {
		c.logger.Error("Failed to create resource", zap.Error(resourceErr))
		return resourceErr
	}

	provider := sdktrace.NewTracerProvider(
		sdktrace.WithBatcher(exporter),
		sdktrace.WithResource(serviceResource),
	)
	c.tracerProvider = provider
	otel.SetTracerProvider(provider)

	c.logger.Info("OpenTelemetryCollector initialized.")
	return nil
}
// HandleEvent exports the event as a short-lived OpenTelemetry span named
// after the event message. Message, level and every payload entry are set as
// span attributes.
//
// Payload values of type string, bool, int, int64 and float64 keep their type
// as attributes; any other value is attached in its default string form
// instead of triggering a failed type assertion. The tracer name is read from
// the event context (tracerNameKey) and defaults to "otel-tracer" when the
// context is nil or carries no name. When the collector has no tracer
// provider (Initialize was not run or failed) the event is dropped with a
// warning. The method always returns nil.
func (c *OpenTelemetryCollector) HandleEvent(event types.Event) error {
	if c.tracerProvider == nil {
		c.logger.Warn("TracerProvider is nil, dropping event",
			zap.String("message", event.Message),
			zap.String("level", event.Level.String()),
		)
		return nil
	}

	fields := make([]attribute.KeyValue, 0, len(event.Payload)+2)
	fields = append(fields,
		attribute.String("message", event.Message),
		attribute.String("level", event.Level.String()),
	)
	for key, value := range event.Payload {
		switch v := value.(type) {
		case string:
			fields = append(fields, attribute.String(key, v))
		case bool:
			fields = append(fields, attribute.Bool(key, v))
		case int:
			fields = append(fields, attribute.Int(key, v))
		case int64:
			fields = append(fields, attribute.Int64(key, v))
		case float64:
			fields = append(fields, attribute.Float64(key, v))
		default:
			fields = append(fields, attribute.String(key, fmt.Sprint(v)))
		}
	}

	c.logger.Info("Handling event",
		zap.String("message", event.Message),
		zap.String("level", event.Level.String()),
		zap.Any("context", event.Context),
		zap.Any("payload", event.Payload),
	)

	// Fetch tracer name from context, or default to "otel-tracer"
	tracerName := "otel-tracer"
	if event.Context != nil {
		if name, ok := event.Context.Value(tracerNameKey).(string); ok {
			tracerName = name
		}
	}

	tracer := c.tracerProvider.Tracer(tracerName)
	_, span := tracer.Start(context.Background(), event.Message)
	span.SetAttributes(fields...)
	span.End()

	c.logger.Info("Event sent to OpenTelemetry",
		zap.String("message", event.Message),
		zap.String("level", event.Level.String()),
		zap.Any("context", event.Context),
		zap.Any("payload", event.Payload),
	)
	return nil
}
// Flush forces the tracer provider to export whatever it has buffered.
// With no tracer provider set it logs a warning and reports success.
// A flush error is logged and returned.
func (c *OpenTelemetryCollector) Flush() error {
	provider := c.tracerProvider
	if provider == nil {
		c.logger.Warn("TracerProvider is nil, skipping flush")
		return nil
	}

	c.logger.Info("Flushing tracer provider")
	err := provider.ForceFlush(context.Background())
	if err != nil {
		c.logger.Error("Error flushing tracer provider", zap.Error(err))
		return err
	}

	c.logger.Info("Collector flushed successfully")
	return nil
}
// Shutdown shuts the tracer provider down.
// With no tracer provider set it logs a warning and reports success.
// A shutdown error is logged and returned.
func (c *OpenTelemetryCollector) Shutdown() error {
	provider := c.tracerProvider
	if provider == nil {
		c.logger.Warn("TracerProvider is nil, skipping shutdown")
		return nil
	}

	c.logger.Info("Shutting down tracer provider")
	err := provider.Shutdown(context.Background())
	if err != nil {
		c.logger.Error("Error shutting down tracer provider", zap.Error(err))
		return err
	}

	c.logger.Info("Collector shutdown successfully")
	return nil
}
// GetName returns the fixed name of this collector, "opentelemetry".
func (c *OpenTelemetryCollector) GetName() string {
	const collectorName = "opentelemetry"
	return collectorName
}
// SpanContext starts a span named span as a child of ctx and returns the
// derived context together with a function that ends the span.
//
// The tracer name is taken from ctx (tracerNameKey) and defaults to the
// collector name. When the collector has no tracer provider (Initialize was
// not run), no span is started: a warning is logged and ctx is returned
// unchanged with a no-op cancel function, in line with how Flush and Shutdown
// treat a missing provider.
func (c *OpenTelemetryCollector) SpanContext(ctx context.Context, span string) (context.Context, context.CancelFunc) {
	if c.tracerProvider == nil {
		c.logger.Warn("TracerProvider is nil, skipping span creation")
		return ctx, func() {}
	}

	tracerName, ok := ctx.Value(tracerNameKey).(string)
	if !ok {
		tracerName = c.GetName()
	}

	spanCtx, s := c.tracerProvider.Tracer(tracerName).Start(ctx, span)
	cancel := func() {
		s.End()
	}
	return spanCtx, cancel
}

// Compile-time check to ensure OpenTelemetryCollector implements the Collector interface
var _ Collector = (*OpenTelemetryCollector)(nil)
package telemetry
import (
"context"
"runtime"
"go.uber.org/zap"
"gitlab.com/nunet/device-management-service/api/docs"
"gitlab.com/nunet/device-management-service/types"
)
// Telemetry dispatches telemetry events and spans to a set of named collectors.
type Telemetry struct {
// config is the telemetry configuration; logEvent reads TelemetryMode from it.
config *types.TelemetryConfig
// collectors maps a collector name to the collector registered under it.
collectors map[string]Collector
// testMode, when true, makes the exported methods return without doing work.
testMode bool
}
// Define a custom type for context keys to avoid conflicts
type contextKey string
const (
// collectorsKey holds the []string of collector names selected via SpanContext.
collectorsKey contextKey = "collectors"
// tracerNameKey holds the tracer name (a string) chosen in SpanContext.
tracerNameKey contextKey = "tracerName"
// versionKey holds the version value that logEvent attaches to each event context.
versionKey contextKey = "version"
)
// GetTelemetry returns the package-level Telemetry instance.
// NOTE(review): `instance` is declared outside this view; presumably it is nil
// until some initialization code assigns it — confirm before dereferencing.
func GetTelemetry() *Telemetry {
return instance
}
// NewTelemetry creates a Telemetry instance.
// In test mode the returned value only has testMode set; config and collectors
// are ignored and every telemetry operation becomes a no-op. Otherwise the
// given config and collectors are stored as-is.
func NewTelemetry(config *types.TelemetryConfig, collectors map[string]Collector, testMode bool) *Telemetry {
	t := &Telemetry{testMode: testMode}
	if !testMode {
		t.config = config
		t.collectors = collectors
	}
	return t
}
// SpanContext opens a span on each of the named collectors and returns the
// resulting context plus a function that ends all of those spans.
//
// An empty tracerName or span defaults to the name of the calling function
// ("unknown_function" when the caller cannot be resolved). The tracer name is
// stored in the context under tracerNameKey. The collector selection is stored
// under collectorsKey only when at least one collector is named: an empty
// selection must not be recorded, because logEvent treats a recorded list as
// "send to these collectors only" and would otherwise deliver to nobody.
// Names that match no registered collector are ignored.
//
// In test mode ctx is returned unchanged with a no-op cancel function.
func (t *Telemetry) SpanContext(ctx context.Context, tracerName string, span string, collectors ...string) (context.Context, context.CancelFunc) {
	if t.testMode {
		return ctx, func() {}
	}

	// Fetch caller info; FuncForPC may return nil for an unknown pc.
	functionName := "unknown_function"
	if pc, _, _, ok := runtime.Caller(1); ok {
		if fn := runtime.FuncForPC(pc); fn != nil {
			functionName = fn.Name()
		}
	}

	// Use caller info as default tracer and span names if not provided
	if tracerName == "" {
		tracerName = functionName
	}
	if span == "" {
		span = functionName
	}

	if len(collectors) > 0 {
		ctx = context.WithValue(ctx, collectorsKey, collectors)
	}
	ctx = context.WithValue(ctx, tracerNameKey, tracerName)

	cancelFuncs := make([]context.CancelFunc, 0, len(collectors))
	for _, name := range collectors {
		c, ok := t.collectors[name]
		if !ok {
			continue
		}
		var cancelFunc context.CancelFunc
		ctx, cancelFunc = c.SpanContext(ctx, span)
		cancelFuncs = append(cancelFuncs, cancelFunc)
	}

	cancel := func() {
		for _, cancelFunc := range cancelFuncs {
			cancelFunc()
		}
	}
	return ctx, cancel
}
// Trace emits message and payload at TRACE level; it is a no-op in test mode.
func (t *Telemetry) Trace(ctx context.Context, message string, payload map[string]interface{}) {
	if !t.testMode {
		t.logEvent(ctx, types.TRACE, message, payload)
	}
}

// Debug emits message and payload at DEBUG level; it is a no-op in test mode.
func (t *Telemetry) Debug(ctx context.Context, message string, payload map[string]interface{}) {
	if !t.testMode {
		t.logEvent(ctx, types.DEBUG, message, payload)
	}
}

// Info emits message and payload at INFO level; it is a no-op in test mode.
func (t *Telemetry) Info(ctx context.Context, message string, payload map[string]interface{}) {
	if !t.testMode {
		t.logEvent(ctx, types.INFO, message, payload)
	}
}

// Warn emits message and payload at WARN level; it is a no-op in test mode.
func (t *Telemetry) Warn(ctx context.Context, message string, payload map[string]interface{}) {
	if !t.testMode {
		t.logEvent(ctx, types.WARN, message, payload)
	}
}

// Error emits message and payload at ERROR level; it is a no-op in test mode.
func (t *Telemetry) Error(ctx context.Context, message string, payload map[string]interface{}) {
	if !t.testMode {
		t.logEvent(ctx, types.ERROR, message, payload)
	}
}

// Fatal emits message and payload at FATAL level; it is a no-op in test mode.
func (t *Telemetry) Fatal(ctx context.Context, message string, payload map[string]interface{}) {
	if !t.testMode {
		t.logEvent(ctx, types.FATAL, message, payload)
	}
}
// logEvent builds an Event from the arguments and hands it to the collectors.
//
// Nothing is emitted when telemetry is disabled (or no config is present) or
// when level is below the package log level. The version is attached to the
// event context under versionKey. If the context names one or more collectors
// (collectorsKey), only those registered under these names receive the event;
// an absent or empty selection means "all collectors". Collector errors are
// logged through the global zap logger and do not stop delivery to the
// remaining collectors.
func (t *Telemetry) logEvent(ctx context.Context, level types.ObservabilityLevel, message string, payload map[string]interface{}) {
	// Check if telemetry is enabled. A missing config cannot be consulted,
	// so it is treated like disabled telemetry rather than dereferenced.
	if t.config == nil || t.config.TelemetryMode == "disabled" {
		return
	}

	// Only log events that are at or above the configured log level
	if level < logLevel {
		return
	}

	// Add the version to the context
	ctx = context.WithValue(ctx, versionKey, docs.SwaggerInfo.Version)

	event := types.Event{
		Context: ctx,
		Level:   level,
		Message: message,
		Payload: payload,
	}

	// Check for specific collectors in context. An empty list carries no
	// selection and must fall through to the default fan-out below.
	if names, ok := ctx.Value(collectorsKey).([]string); ok && len(names) > 0 {
		for _, name := range names {
			c, found := t.collectors[name]
			if !found {
				continue
			}
			if err := c.HandleEvent(event); err != nil {
				zap.L().Error("Failed to handle event", zap.Error(err))
			}
		}
		return
	}

	// Forward to all collectors by default
	for _, collector := range t.collectors {
		if err := collector.HandleEvent(event); err != nil {
			zap.L().Error("Failed to handle event", zap.Error(err))
		}
	}
}
// Flush asks every registered collector to flush. Failures are logged through
// the global zap logger and do not interrupt the loop. No-op in test mode.
func (t *Telemetry) Flush() {
	if t.testMode {
		return
	}
	for _, c := range t.collectors {
		err := c.Flush()
		if err == nil {
			continue
		}
		zap.L().Error("Failed to flush collector", zap.Error(err))
	}
}
// Shutdown flushes all collectors and then shuts each of them down. Failures
// are logged through the global zap logger and do not interrupt the loop.
// No-op in test mode.
func (t *Telemetry) Shutdown() {
	if t.testMode {
		return
	}
	t.Flush()
	for _, c := range t.collectors {
		err := c.Shutdown()
		if err == nil {
			continue
		}
		zap.L().Error("Failed to shut down collector", zap.Error(err))
	}
}
package types
import (
"errors"
)
// note: this data type may be moved to dms or jobs package in the future

// CapabilityAdder is implemented by types that can merge another Capability
// into themselves.
type CapabilityAdder interface {
Add(Capability) error
}
// CapabilitySubtractor is implemented by types that can remove another
// Capability from themselves.
type CapabilitySubtractor interface {
Subtract(Capability) error
}
// CapabilityAddSubtractor combines CapabilityAdder and CapabilitySubtractor.
type CapabilityAddSubtractor interface {
CapabilityAdder
CapabilitySubtractor
}
// Capability groups executors, job types, resources, libraries, localities,
// storage, connectivity, pricing, time limits and KYC entries. Two
// capabilities can be combined with Add and reduced with Subtract.
type Capability struct {
Executors Executors `json:"executor" description:"Executor type required for the job (docker, vm, wasm, or others)"`
JobTypes JobTypes `json:"type" description:"Details about type of the job (One time, batch, recurring, long running). Refer to dms.jobs package for jobType data model"`
Resources ExecutionResources `json:"resources" description:"Resources required for the job"`
Libraries []Library `json:"libraries" description:"Libraries required for the job"`
Localities []Locality `json:"locality" description:"Preferred localities of the machine for execution"`
Storage []Storage `json:"storage" description:"Preferred storage options that the machine should have"`
Connectivity Connectivity `json:"connectivity" description:"Network configuration required"`
Price []PriceInformation `json:"price" description:"Pricing information"`
Time TimeInformation `json:"time" description:"Time constraints"`
// KYCs has no struct tag, unlike the other fields.
KYCs []KYC
}
// Compile-time check that *Capability implements CapabilityAddSubtractor.
var _ CapabilityAddSubtractor = &Capability{}
// Connectivity holds the port list and VPN flag of a Capability.
type Connectivity struct {
Ports []int `json:"ports" description:"Ports that need to be open for the job to run"`
VPN bool `json:"vpn" description:"Whether VPN is required"`
}
// PriceInformation is one pricing entry of a Capability.
type PriceInformation struct {
Currency string `json:"currency" description:"Currency used for pricing"`
CurrencyPerHour int `json:"currency_per_hour" description:"Price charged per hour"`
TotalPerJob int `json:"total_per_job" description:"Maximum total price or budget of the job"`
Preference int `json:"preference" description:"Pricing preference"`
}
// TimeInformation holds the time constraints of a Capability.
type TimeInformation struct {
Units string `json:"units" description:"Time units"`
MaxTime int `json:"max_time" description:"Maximum time that job should run"`
Preference int `json:"preference" description:"Time preference"`
}
// Library identifies a library by name, constraint and version.
type Library struct {
Name string `json:"name" description:"Name of the library"`
Constraint string `json:"constraint" description:"Constraint of the library"`
Version string `json:"version" description:"Version of the library"`
}
// Locality identifies a region by kind and name.
type Locality struct {
Kind string `json:"kind" description:"Kind of the region (geographic, nunet-defined, etc)"`
Name string `json:"name" description:"Name of the region"`
}
// Storage is one storage entry of a Capability.
// NOTE(review): the units of Size and Amount are not visible here — confirm.
type Storage struct {
Type StorageType `json:"type" description:"Type of storage"`
Size int `json:"size" description:"Size of storage"`
Amount int `json:"amount" description:"Amount of storage"`
}
// StorageType names a kind of storage.
type StorageType string
const (
// SSD_STORAGE_TYPE is the "ssd" storage type.
//nolint
SSD_STORAGE_TYPE StorageType = "ssd"
// HDD_STORAGE_TYPE is the "hdd" storage type.
//nolint
HDD_STORAGE_TYPE StorageType = "hdd"
)
// KYC is one know-your-customer entry of a Capability.
type KYC struct {
Type string `json:"type" description:"Type of KYC"`
Data string `json:"data" description:"Data required for KYC"`
}
// JobTypes is a list of job types.
type JobTypes []JobType
// JobType names a kind of job.
type JobType string
// Known job types.
const (
BATCH JobType = "batch"
SINGLERUN JobType = "single_run"
RECURRING JobType = "recurring"
LONGRUNNING JobType = "long_running"
)
// Named slice types for the Capability element types.
type (
Executors []Executor
Libraries []Library
Localities []Locality
KYCs []KYC
Storages []Storage
PricesInformation []PriceInformation
)
// Equal reports whether both libraries have the same name, constraint and version.
func (lib *Library) Equal(library Library) bool {
	return lib.Name == library.Name &&
		lib.Constraint == library.Constraint &&
		lib.Version == library.Version
}

// Equal reports whether both localities have the same kind and name.
func (loc *Locality) Equal(locality Locality) bool {
	return loc.Kind == locality.Kind && loc.Name == locality.Name
}

// Equal reports whether both executors have the same executor type.
func (e *Executor) Equal(executor Executor) bool {
	if e.ExecutorType != executor.ExecutorType {
		return false
	}
	return true
}

// Equal reports whether both KYC entries have the same type and data.
func (k *KYC) Equal(kyc KYC) bool {
	return k.Type == kyc.Type && k.Data == kyc.Data
}

// Equal reports whether both price entries agree on currency, hourly price,
// total per job and preference.
func (p *PriceInformation) Equal(price PriceInformation) bool {
	return p.Currency == price.Currency &&
		p.CurrencyPerHour == price.CurrencyPerHour &&
		p.TotalPerJob == price.TotalPerJob &&
		p.Preference == price.Preference
}
// Contains reports whether es holds an executor equal to executor.
func (es Executors) Contains(executor Executor) bool {
	for i := range es {
		if es[i].Equal(executor) {
			return true
		}
	}
	return false
}

// Contains reports whether j holds jobType.
func (j JobTypes) Contains(jobType JobType) bool {
	for i := range j {
		if j[i] == jobType {
			return true
		}
	}
	return false
}

// Contains reports whether l holds a library equal to library.
func (l Libraries) Contains(library Library) bool {
	for i := range l {
		if l[i].Equal(library) {
			return true
		}
	}
	return false
}

// Contains reports whether l holds a locality equal to locality.
func (l Localities) Contains(locality Locality) bool {
	for i := range l {
		if l[i].Equal(locality) {
			return true
		}
	}
	return false
}

// Contains reports whether s holds an entry with the same storage type as
// storage; size and amount are not compared.
func (s Storages) Contains(storage Storage) bool {
	for i := range s {
		if s[i].Type == storage.Type {
			return true
		}
	}
	return false
}

// Contains reports whether k holds a KYC entry equal to kyc.
func (k KYCs) Contains(kyc KYC) bool {
	for i := range k {
		if k[i].Equal(kyc) {
			return true
		}
	}
	return false
}

// Contains reports whether ps holds a price entry equal to price.
func (ps PricesInformation) Contains(price PriceInformation) bool {
	for i := range ps {
		if ps[i].Equal(price) {
			return true
		}
	}
	return false
}
// Add merges cap into the receiver and always returns nil.
//
//   - Executors, job types, libraries, localities, prices, KYC entries and
//     ports are unioned. Membership is checked against the receiver's current
//     contents, so a value repeated inside cap is added only once.
//   - CPU cores, threads and clock speed are summed only when both sides have
//     the same CPU architecture. Threads are included so that a later
//     Subtract of the same capability restores the original value.
//   - Memory size and clock speed are always summed; disk size is summed only
//     when the disk types match; GPUs are appended.
//   - A storage entry whose type already exists in the receiver is added onto
//     every existing entry of that type; otherwise it is appended.
//   - VPN becomes true if cap requires it; MaxTime is summed.
func (c *Capability) Add(cap Capability) error {
	// Executors
	for _, executor := range cap.Executors {
		if !c.Executors.Contains(executor) {
			c.Executors = append(c.Executors, executor)
		}
	}

	// JobTypes
	for _, jobType := range cap.JobTypes {
		if !c.JobTypes.Contains(jobType) {
			c.JobTypes = append(c.JobTypes, jobType)
		}
	}

	// Resources
	if c.Resources.CPU.Architecture == cap.Resources.CPU.Architecture {
		c.Resources.CPU.Cores += cap.Resources.CPU.Cores
		c.Resources.CPU.Threads += cap.Resources.CPU.Threads
		c.Resources.CPU.ClockSpeedHz += cap.Resources.CPU.ClockSpeedHz
	}
	c.Resources.Memory.ClockSpeedHz += cap.Resources.Memory.ClockSpeedHz // does it make sense?
	c.Resources.Memory.Size += cap.Resources.Memory.Size
	if c.Resources.Disk.Type == cap.Resources.Disk.Type {
		c.Resources.Disk.Size += cap.Resources.Disk.Size
	}
	c.Resources.GPUs = append(c.Resources.GPUs, cap.Resources.GPUs...)

	// Libraries
	for _, library := range cap.Libraries {
		if !Libraries(c.Libraries).Contains(library) {
			c.Libraries = append(c.Libraries, library)
		}
	}

	// Localities
	for _, locality := range cap.Localities {
		if !Localities(c.Localities).Contains(locality) {
			c.Localities = append(c.Localities, locality)
		}
	}

	// Storage
	for _, storage := range cap.Storage {
		merged := false
		for i := range c.Storage {
			if c.Storage[i].Type == storage.Type {
				c.Storage[i].Size += storage.Size
				c.Storage[i].Amount += storage.Amount
				merged = true
			}
		}
		if !merged {
			c.Storage = append(c.Storage, storage)
		}
	}

	// Connectivity
	if cap.Connectivity.VPN {
		c.Connectivity.VPN = true
	}
	for _, port := range cap.Connectivity.Ports {
		if !sliceContainsInt(c.Connectivity.Ports, port) {
			c.Connectivity.Ports = append(c.Connectivity.Ports, port)
		}
	}

	// Price
	for _, price := range cap.Price {
		if !PricesInformation(c.Price).Contains(price) {
			c.Price = append(c.Price, price)
		}
	}

	// Time
	c.Time.MaxTime += cap.Time.MaxTime

	// KYCs
	for _, kyc := range cap.KYCs {
		if !KYCs(c.KYCs).Contains(kyc) {
			c.KYCs = append(c.KYCs, kyc)
		}
	}

	return nil
}
// Subtract removes cap from the receiver.
//
// It returns an error, without modifying the receiver, when the receiver has
// fewer CPU cores, a lower CPU clock speed, less memory or less disk than cap.
// Otherwise:
//
//   - CPU cores, threads and clock speed are reduced only when the CPU
//     architectures match; memory size and clock speed are always reduced;
//     disk size is reduced only when the disk types match.
//   - Every GPU of the receiver that is Equal to a GPU in cap is removed.
//   - A storage entry identical in type, size and amount to one in cap is
//     removed; an entry that matches by type only has size and amount reduced.
//   - Ports listed in cap are removed.
//   - MaxTime is reduced.
//
// Executors, job types, libraries, localities, prices and KYC entries are left
// untouched. Removals filter each slice in a single pass instead of deleting
// elements while ranging over the same slice, which skipped elements and could
// slice out of range when several entries matched.
func (c *Capability) Subtract(cap Capability) error {
	// Resources
	if c.Resources.CPU.Cores < cap.Resources.CPU.Cores || c.Resources.CPU.ClockSpeedHz < cap.Resources.CPU.ClockSpeedHz {
		return errors.New("cpu resources are not enough")
	}
	if c.Resources.Memory.Size < cap.Resources.Memory.Size {
		return errors.New("memory resources are not enough")
	}
	if c.Resources.Disk.Size < cap.Resources.Disk.Size {
		return errors.New("disk resources are not enough")
	}

	if c.Resources.CPU.Architecture == cap.Resources.CPU.Architecture {
		c.Resources.CPU.Cores -= cap.Resources.CPU.Cores
		c.Resources.CPU.Threads -= cap.Resources.CPU.Threads
		c.Resources.CPU.ClockSpeedHz -= cap.Resources.CPU.ClockSpeedHz
	}
	// NOTE(review): Memory.ClockSpeedHz is unsigned and not range-checked
	// above, so this wraps around if cap's value is larger — confirm intent.
	c.Resources.Memory.ClockSpeedHz -= cap.Resources.Memory.ClockSpeedHz // does it make sense?
	c.Resources.Memory.Size -= cap.Resources.Memory.Size
	if c.Resources.Disk.Type == cap.Resources.Disk.Type {
		c.Resources.Disk.Size -= cap.Resources.Disk.Size
	}

	// GPUs: keep only those that do not equal a GPU of cap.
	for i := range cap.Resources.GPUs {
		gpu := cap.Resources.GPUs[i]
		kept := c.Resources.GPUs[:0]
		for j := range c.Resources.GPUs {
			if !c.Resources.GPUs[j].Equal(&gpu) {
				kept = append(kept, c.Resources.GPUs[j])
			}
		}
		c.Resources.GPUs = kept
	}

	// Storage
	for _, storage := range cap.Storage {
		kept := c.Storage[:0]
		for _, mine := range c.Storage {
			switch {
			case mine.Type != storage.Type:
				kept = append(kept, mine)
			case mine.Amount == storage.Amount && mine.Size == storage.Size:
				// Exact match: the entry is consumed entirely.
			default:
				mine.Size -= storage.Size
				mine.Amount -= storage.Amount
				kept = append(kept, mine)
			}
		}
		c.Storage = kept
	}

	// Connectivity
	for _, port := range cap.Connectivity.Ports {
		kept := c.Connectivity.Ports[:0]
		for _, mine := range c.Connectivity.Ports {
			if mine != port {
				kept = append(kept, mine)
			}
		}
		c.Connectivity.Ports = kept
	}

	// Time
	c.Time.MaxTime -= cap.Time.MaxTime

	return nil
}
// sliceContainsInt reports whether val occurs in s.
func sliceContainsInt(s []int, val int) bool {
	for i := 0; i < len(s); i++ {
		if s[i] == val {
			return true
		}
	}
	return false
}
package types
// Comparison is the outcome of comparing a left object with a right object.
type Comparison string

const (
	Worse  Comparison = "Worse"  // left object is 'worse' than right object
	Better Comparison = "Better" // left object is 'better' than right object
	Equal  Comparison = "Equal"  // objects on the left and right are 'equally good'
	Error  Comparison = "Error"  // error in comparison or objects incomparable
)

// TODO: Consider comments in this thread: https://gitlab.com/nunet/device-management-service/-/merge_requests/356#note_1997854443
// TODO: Consider comments in this thread: https://gitlab.com/nunet/device-management-service/-/merge_requests/356#note_1997875361

// 'left' means 'this object' and 'right' means 'the supplied other object';
// it makes sense when using the type in functions like Compare(left, right)

// ComplexComparison maps a name to a Comparison result.
// this type is unused but still reserved in case we will need it in the future
// and still used in some tests that are left from previous versions of the package;
// TODO: remove / update after the package is finished and refactor
type ComplexComparison map[string]Comparison

// And returns the result of AND operation of two Comparison values
// it respects the following table of truth:
// |   AND  | Better |  Worse |  Equal |  Error |
// | ------ | ------ |--------|--------|--------|
// | Better | Better |  Worse | Better |  Error |
// | Worse  |  Worse |  Worse |  Worse |  Error |
// | Equal  | Better |  Worse |  Equal |  Error |
// | Error  |  Error |  Error |  Error |  Error |
//
// Any operand that is not one of Better, Worse or Equal (Error itself or an
// unknown value) makes the result Error.
func (c Comparison) And(cmp Comparison) Comparison {
	comparable := func(v Comparison) bool {
		return v == Better || v == Worse || v == Equal
	}
	if !comparable(c) || !comparable(cmp) {
		return Error
	}

	// Among valid operands Worse dominates Better, which dominates Equal.
	switch {
	case c == Worse || cmp == Worse:
		return Worse
	case c == Better || cmp == Better:
		return Better
	default:
		return Equal
	}
}
package types
// Executor identifies an executor by its type.
type Executor struct {
ExecutorType ExecutorType `json:"executor_type"`
}
// ExecutorType is the string name of an executor kind.
type ExecutorType string
const (
// Executor kind names.
// NOTE(review): these are untyped string constants, not ExecutorType
// values — confirm whether that is intended.
ExecutorTypeDocker = "docker"
ExecutorTypeFirecracker = "firecracker"
ExecutorTypeWasm = "wasm"
// ExecutionStatusCodeSuccess is the status code 0 (presumably the exit code
// of a successful execution — verify against callers).
ExecutionStatusCodeSuccess = 0
)
// ExecutionRequest is the request object for executing a job.
// It identifies the job and the execution and carries the engine spec, the
// resources, the input and output volumes and the results directory.
type ExecutionRequest struct {
JobID string // ID of the job to execute
ExecutionID string // ID of the execution
EngineSpec *SpecConfig // Engine spec for the execution
Resources *ExecutionResources // Resources for the execution
Inputs []*StorageVolumeExecutor // Input volumes for the execution
Outputs []*StorageVolumeExecutor // Output volumes for the results
ResultsDir string // Directory to store the results
}
// ExecutionResult is the result of an execution
type ExecutionResult struct {
	STDOUT   string `json:"stdout"`    // STDOUT of the execution
	STDERR   string `json:"stderr"`    // STDERR of the execution
	ExitCode int    `json:"exit_code"` // Exit code of the execution
	ErrorMsg string `json:"error_msg"` // Error message if the execution failed
}

// NewExecutionResult creates a new ExecutionResult with the given exit code
// and empty output and error message.
func NewExecutionResult(code int) *ExecutionResult {
	return &ExecutionResult{
		STDOUT:   "",
		STDERR:   "",
		ExitCode: code,
	}
}

// NewFailedExecutionResult creates a new ExecutionResult for a failed execution.
// The exit code is set to -1 and the error message is taken from err.
// A nil err yields an empty error message instead of a nil dereference.
func NewFailedExecutionResult(err error) *ExecutionResult {
	result := &ExecutionResult{
		STDOUT:   "",
		STDERR:   "",
		ExitCode: -1,
	}
	if err != nil {
		result.ErrorMsg = err.Error()
	}
	return result
}
// LogStreamRequest is the request object for streaming logs from an execution.
// It identifies the job and the execution and carries the Tail and Follow flags.
type LogStreamRequest struct {
JobID string // ID of the job
ExecutionID string // ID of the execution
Tail bool // Tail the logs
Follow bool // Follow the logs
}
package types
const (
// NetP2P is the "p2p" network type name.
NetP2P = "p2p"
)
// NetworkSpec is a stub. Please expand based on requirements.
type NetworkSpec struct{}
// NetConfig is a stub. Please expand it or completely change it based on requirements.
type NetConfig struct {
NetworkSpec SpecConfig `json:"network_spec"` // Network specification
}
// GetNetworkConfig returns a pointer to the embedded network spec, so changes
// made through the pointer are visible in nc.
func (nc *NetConfig) GetNetworkConfig() *SpecConfig {
return &nc.NetworkSpec
}
// NetworkStats should contain all network info the user is interested in.
// for now there's only peerID and listening address but reachability, local and remote addr etc...
// can be added when necessary.
type NetworkStats struct {
ID string `json:"id"`
ListenAddr string `json:"listen_addr"`
}
// MessageInfo is a stub. Please expand it or completely change it based on requirements.
type MessageInfo struct {
Info string `json:"info"` // Message information
}
package types
import (
"fmt"
"strings"
)
// GPUVendor names the maker of a GPU.
type GPUVendor string

// Known GPU vendors. None marks the absence of a GPU.
const (
	GPUVendorNvidia  GPUVendor = "NVIDIA"
	GPUVendorAMDATI  GPUVendor = "AMD/ATI"
	GPUVendorIntel   GPUVendor = "Intel"
	GPUVendorUnknown GPUVendor = "Unknown"
	None             GPUVendor = "None"
)

// ParseGPUVendor maps a vendor description to a GPUVendor by case-sensitive
// substring search. "NVIDIA" is checked first, then "AMD" or "ATI", then
// "Intel"; anything else yields GPUVendorUnknown.
func ParseGPUVendor(vendor string) GPUVendor {
	if strings.Contains(vendor, "NVIDIA") {
		return GPUVendorNvidia
	}
	if strings.Contains(vendor, "AMD") || strings.Contains(vendor, "ATI") {
		return GPUVendorAMDATI
	}
	if strings.Contains(vendor, "Intel") {
		return GPUVendorIntel
	}
	return GPUVendorUnknown
}
// GPU describes a single GPU device and its VRAM figures.
// NOTE(review): the unit of the VRAM fields is not visible here — confirm.
type GPU struct {
// Index is the self-reported index of the device in the system
Index int
// Name is the model name of the GPU e.g. Tesla T4
Name string
// Vendor is the maker of the GPU, e.g. NVidia, AMD, Intel
Vendor GPUVendor
// PCIAddress is the PCI address of the device, in the format AAAA:BB:CC.C
// Used to discover the correct device rendering cards
PCIAddress string
// Model of the GPU, e.g. A100
Model string `json:"model" description:"GPU model, ex A100"`
// TotalVRAM is the total amount of VRAM on the device
TotalVRAM uint64
// UsedVRAM is the amount of VRAM currently in use
UsedVRAM uint64
// FreeVRAM is the amount of VRAM currently free
FreeVRAM uint64
// Gorm fields
// Team, is this the right way to do this? What is the best practice we're following?
ResourceID uint `gorm:"foreignKey:ID"`
}
// Equal reports whether g and gpu agree on model, total/used/free VRAM,
// index, vendor and PCI address. Name and ResourceID are not compared.
func (g *GPU) Equal(gpu *GPU) bool {
	return g.Model == gpu.Model &&
		g.TotalVRAM == gpu.TotalVRAM &&
		g.UsedVRAM == gpu.UsedVRAM &&
		g.FreeVRAM == gpu.FreeVRAM &&
		g.Index == gpu.Index &&
		g.Vendor == gpu.Vendor &&
		g.PCIAddress == gpu.PCIAddress
}
// GPUList is a list of GPUs.
type GPUList []GPU

// GetGPUWithHighestFreeVRAM returns the GPU with the most free VRAM.
// Useful for selecting the best GPU if multiple vendors are available,
// especially in multi-GPU systems or mining rigs.
//
// For an empty list it returns a GPU whose Vendor is None, which is useful for
// launching CPU-only containers. On ties the earliest GPU in the list wins;
// in particular, when every GPU reports zero free VRAM the first GPU is
// returned rather than an empty GPU value. The error is always nil.
func (gpus GPUList) GetGPUWithHighestFreeVRAM() (GPU, error) {
	if len(gpus) == 0 {
		return GPU{Vendor: None}, nil
	}

	best := gpus[0]
	for _, gpu := range gpus[1:] {
		if gpu.FreeVRAM > best.FreeVRAM {
			best = gpu
		}
	}
	return best, nil
}
// negativeValueError is the error returned when subtracting resources would
// produce a negative value. r1 and r2 are the left and right operands of the
// failed subtraction; they are stored as `any` because the resource fields
// have different numeric types.
type negativeValueError struct {
	resource string
	r1       any
	r2       any
}

// Error returns the error message.
// The operands are formatted with %v: %d would misprint floating-point
// operands such as the CPU value as "%!d(float64=...)".
func (e *negativeValueError) Error() string {
	return fmt.Sprintf("Error: %s subtraction results in negative values. (%v - %v)", e.resource, e.r1, e.r2)
}
// ResourceOps defines the operations on resources
// TODO: Check how to handle GPU resources
type ResourceOps interface {
// Add returns the sum of the resources
Add(r Resources) Resources
// Subtract returns the difference of the resources, or an error when the
// difference cannot be represented
Subtract(r Resources) (Resources, error)
}
// Resources represents the resources of the machine
// NOTE(review): the units of CPU, RAM and Disk are not visible here — confirm.
type Resources struct {
// CPU is a floating-point CPU quantity.
CPU float64
// NumCores is the number of cores.
NumCores int
// GPU lists the GPUs; gorm links them through GPU.ResourceID.
GPU []GPU `gorm:"foreignKey:ResourceID"`
// RAM is the amount of memory.
RAM uint64
// Disk is the amount of disk space.
Disk uint64
}
// Add returns the sum of the resources.
// CPU, NumCores, RAM and Disk are added field by field; neither operand is
// modified. NumCores is part of the sum so the core count is no longer reset
// to zero by an addition.
// TODO: GPU addition — the GPU list of the result is left empty for now.
func (r Resources) Add(r2 Resources) Resources {
	return Resources{
		CPU:      r.CPU + r2.CPU,
		NumCores: r.NumCores + r2.NumCores,
		RAM:      r.RAM + r2.RAM,
		Disk:     r.Disk + r2.Disk,
	}
}
// Subtract returns the difference of the resources (r - r2).
//
// If any of CPU, NumCores, RAM or Disk of r is smaller than the matching field
// of r2, an empty Resources and a *negativeValueError naming that resource are
// returned. NumCores is part of the difference so the core count is no longer
// reset to zero by a subtraction. Neither operand is modified.
// TODO: GPU subtraction — the GPU list of the result is left empty for now.
func (r Resources) Subtract(r2 Resources) (Resources, error) {
	// Check if the subtraction results in negative values
	// Team, why do we need to return a negative value error?
	// Can't we just return the negative value indicating that the machine is overused?
	// We can then handle the scenario in the calling function by checking if the result is negative if required.
	// (RAM and Disk are unsigned, so a negative result cannot be represented
	// for them and would wrap around instead.)
	if r.CPU < r2.CPU {
		return Resources{}, &negativeValueError{resource: "CPU", r1: r.CPU, r2: r2.CPU}
	}
	if r.NumCores < r2.NumCores {
		return Resources{}, &negativeValueError{resource: "NumCores", r1: r.NumCores, r2: r2.NumCores}
	}
	if r.RAM < r2.RAM {
		return Resources{}, &negativeValueError{resource: "RAM", r1: r.RAM, r2: r2.RAM}
	}
	if r.Disk < r2.Disk {
		return Resources{}, &negativeValueError{resource: "Disk", r1: r.Disk, r2: r2.Disk}
	}

	return Resources{
		CPU:      r.CPU - r2.CPU,
		NumCores: r.NumCores - r2.NumCores,
		RAM:      r.RAM - r2.RAM,
		Disk:     r.Disk - r2.Disk,
	}, nil
}
// Compile-time check that *Resources implements ResourceOps.
var _ ResourceOps = (*Resources)(nil)
// MachineResources represents the total resources of the machine
// It embeds BaseDBModel and Resources.
type MachineResources struct {
BaseDBModel
Resources
}
// FreeResources represents the free resources of the machine
// It embeds BaseDBModel and Resources.
type FreeResources struct {
BaseDBModel
Resources
}
// OnboardedResources represents the onboarded resources of the machine
// It embeds BaseDBModel and Resources.
type OnboardedResources struct {
BaseDBModel
Resources
}
// RequiredResources represents the resources required by the jobs running on the machine
// It embeds BaseDBModel and Resources and adds the ID of the job.
// TODO: this is a replacement for ServiceResourceRequirements. Check with the team on this.
type RequiredResources struct {
BaseDBModel
JobID int
Resources
}
// CPUInfo represents the CPU information of the machine
type CPUInfo struct {
// NumCores is the number of cores.
NumCores uint64
// MHzPerCore is presumably the per-core clock speed in MHz (from the name; verify).
MHzPerCore float64
// NOTE(review): how Compute is derived is not visible here — TODO confirm.
Compute float64
}
// SpecInfo represents the machine specifications
// It lists the CPUs, GPUs, RAM modules and disks and holds the network information.
// TODO: Finalise the fields required in this struct
// https://gitlab.com/nunet/device-management-service/-/issues/533
type SpecInfo struct {
CPUs []CPU
GPUs []GPU
RAMs []RAM
Disks []Disk
Network NetworkInfo
}
// CPU represents the CPU information
type CPU struct {
// Model represents the CPU model, e.g., "Intel Core i7-9700K", "AMD Ryzen 9 5900X"
Model string
// Vendor represents the CPU manufacturer, e.g., "Intel", "AMD"
Vendor string
// ClockSpeedHz represents the CPU clock speed in Hz
ClockSpeedHz int64
// Cores represents the number of physical CPU cores
Cores uint32
// Threads represents the number of logical CPU threads (including hyperthreading)
Threads int
// Architecture represents the CPU architecture, e.g., "x86", "x86_64", "arm64"
Architecture string
// Cache size in bytes
CacheSize uint64
}
// RAM represents the RAM information
// Note that Size is a signed integer while ClockSpeedHz is unsigned.
type RAM struct {
// Size in bytes
Size int64
// Clock speed in Hz
ClockSpeedHz uint64
// Type represents the RAM type, e.g., "DDR4", "DDR5", "LPDDR4"
Type string
}
// Disk represents the disk information
type Disk struct {
// Model represents the disk model, e.g., "Samsung 970 EVO Plus", "Western Digital Blue SN550"
// TODO: may be removed as Disk models will be usually irrelevant, right?
Model string
// Vendor represents the disk manufacturer, e.g., "Samsung", "Western Digital"
// TODO: may be removed as Disk vendors will be usually irrelevant, right?
Vendor string
// Size in bytes
Size uint64
// Type represents the disk type, e.g., "SSD", "HDD", "NVMe"
Type string
// Interface represents the disk interface, e.g., "SATA", "PCIe", "M.2"
Interface string
// Read speed in bytes per second
// TODO: may be removed as it may be too specific for our case
ReadSpeed uint64
// Write speed in bytes per second
// TODO: may be removed as it may be too specific for our case
WriteSpeed uint64
}
// NetworkInfo represents the network information
type NetworkInfo struct {
// Bandwidth in bits per second (b/s)
Bandwidth uint64
// NetworkType represents the network type, e.g., "Ethernet", "Wi-Fi", "Cellular"
NetworkType string
}
// ExecutionResources represents the resources required to execute a task
// It holds one CPU, one RAM and one Disk description plus a list of GPUs.
type ExecutionResources struct {
// CPU configuration
CPU CPU `json:"cpu,omitempty" description:"CPU configuration"`
// Memory configuration
Memory RAM `json:"memory,omitempty" description:"Memory configuration"`
// Disk configuration
Disk Disk `json:"disk,omitempty" description:"Disk configuration"`
// GPU configuration
GPUs []GPU `json:"gpus,omitempty" description:"GPU configuration"`
}
package types
import (
"errors"
"strings"
"gitlab.com/nunet/device-management-service/utils/validate"
)
// SpecConfig represents a configuration for a spec
// A SpecConfig can be used to define an engine spec, a storage volume, etc.
type SpecConfig struct {
// Type of the spec (e.g. docker, firecracker, storage, etc.)
Type string `json:"type"`
// Params of the spec
// The key is omitted from the JSON output when the map is empty.
Params map[string]interface{} `json:"params,omitempty"`
}
// Config is implemented by configuration types that can provide a network
// spec config.
type Config interface {
GetNetworkConfig() *SpecConfig
}
// NewSpecConfig returns a SpecConfig of type t with an empty, non-nil
// parameter map.
func NewSpecConfig(t string) *SpecConfig {
	spec := new(SpecConfig)
	spec.Type = t
	spec.Params = map[string]interface{}{}
	return spec
}
// WithParam stores value under key in the spec params, creating the map when
// it is nil, and returns the receiver so calls can be chained.
func (s *SpecConfig) WithParam(key string, value interface{}) *SpecConfig {
	params := s.Params
	if params == nil {
		params = make(map[string]interface{})
		s.Params = params
	}
	params[key] = value
	return s
}
// Normalize brings the spec config into a canonical form: surrounding
// whitespace is trimmed from Type, and a nil or empty Params is replaced by a
// fresh empty map so both are treated the same. A nil receiver is left alone.
func (s *SpecConfig) Normalize() {
	if s != nil {
		s.Type = strings.TrimSpace(s.Type)
		if len(s.Params) == 0 {
			s.Params = map[string]interface{}{}
		}
	}
}
// Validate returns an error when the receiver is nil or its Type is blank,
// and nil otherwise.
func (s *SpecConfig) Validate() error {
	switch {
	case s == nil:
		return errors.New("nil spec config")
	case validate.IsBlank(s.Type):
		return errors.New("missing spec type")
	}
	return nil
}
// IsType reports whether the spec's Type equals t, ignoring case and any
// whitespace surrounding t. A nil receiver is never of any type.
func (s *SpecConfig) IsType(t string) bool {
	return s != nil && strings.EqualFold(s.Type, strings.TrimSpace(t))
}
// IsEmpty reports whether the spec config is nil, or has both a blank Type
// and no params.
func (s *SpecConfig) IsEmpty() bool {
	if s == nil {
		return true
	}
	return validate.IsBlank(s.Type) && len(s.Params) == 0
}
package types
import (
"context"
)
// TelemetryConfig holds the configuration for the telemetry system.
type TelemetryConfig struct {
// ServiceName is the name of the service.
ServiceName string
// GlobalEndpoint is the endpoint address.
GlobalEndpoint string
// ObservabilityLevel is the level as a string (presumably one of the names
// accepted by ParseObservabilityLevel — verify against the loader).
ObservabilityLevel string
// CollectorConfigs maps a string key to a collector configuration
// (presumably keyed by collector name — TODO confirm).
CollectorConfigs map[string]CollectorConfig
// TelemetryMode is the telemetry mode as a string.
TelemetryMode string
}
// CollectorConfig holds the configuration for individual collectors.
type CollectorConfig struct {
CollectorType string
CollectorEndpoint string
}
// Event represents a telemetry event with its details.
type Event struct {
// Context is the context associated with the event.
Context context.Context
// Level is the observability level of the event.
Level ObservabilityLevel
// Message is the event message.
Message string
// Payload carries additional key/value data of arbitrary type.
Payload map[string]interface{}
}
// ObservabilityLevel defines the levels of observability.
type ObservabilityLevel int

// Observability levels in ascending order of severity.
const (
	TRACE ObservabilityLevel = 1
	DEBUG ObservabilityLevel = 2
	INFO  ObservabilityLevel = 3
	WARN  ObservabilityLevel = 4
	ERROR ObservabilityLevel = 5
	FATAL ObservabilityLevel = 6
)

// ParseObservabilityLevel converts the upper-case name of a level ("TRACE",
// "DEBUG", "INFO", "WARN", "ERROR", "FATAL") to its ObservabilityLevel.
// Matching is exact and case-sensitive; any other input yields INFO. The
// returned error is always nil.
func ParseObservabilityLevel(levelStr string) (ObservabilityLevel, error) {
	for level := TRACE; level <= FATAL; level++ {
		if level.String() == levelStr {
			return level, nil
		}
	}
	return INFO, nil
}

// String returns the upper-case name of the level, or "UNKNOWN" for a value
// outside the defined range.
func (level ObservabilityLevel) String() string {
	names := [...]string{
		TRACE: "TRACE",
		DEBUG: "DEBUG",
		INFO:  "INFO",
		WARN:  "WARN",
		ERROR: "ERROR",
		FATAL: "FATAL",
	}
	if level < TRACE || level > FATAL {
		return "UNKNOWN"
	}
	return names[level]
}
package types
import (
"time"
"github.com/google/uuid"
"gorm.io/gorm"
)
// BaseDBModel is a base model for all entities. It'll be mainly used for database
// records.
type BaseDBModel struct {
// ID is the record identifier; BeforeCreate fills it with a new UUID string.
ID string `gorm:"type:uuid"`
// CreatedAt is set by BeforeCreate.
CreatedAt time.Time
// UpdatedAt is set by BeforeUpdate.
UpdatedAt time.Time
// DeletedAt is gorm's deletion marker; the column is indexed.
DeletedAt gorm.DeletedAt `gorm:"index"`
}
// GetID returns the ID of the entity.
func (m BaseDBModel) GetID() string {
return m.ID
}
// BeforeCreate assigns a freshly generated UUID to ID and the current time to
// CreatedAt, overwriting whatever was there, and always returns nil.
// This is a GORM hook and should not be called directly.
// We can move this to generic repository create methods.
func (m *BaseDBModel) BeforeCreate(_ *gorm.DB) error {
	id, now := uuid.NewString(), time.Now()
	m.ID, m.CreatedAt = id, now
	return nil
}

// BeforeUpdate assigns the current time to UpdatedAt and always returns nil.
// This is a GORM hook and should not be called directly.
// We can move this to generic repository update methods.
func (m *BaseDBModel) BeforeUpdate(_ *gorm.DB) error {
	now := time.Now()
	m.UpdatedAt = now
	return nil
}
package utils
import (
"bytes"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"time"
"github.com/cosmos/btcutil/bech32"
"github.com/ethereum/go-ethereum/common"
"github.com/fivebinaries/go-cardano-serialization/address"
"gitlab.com/nunet/device-management-service/db"
"gitlab.com/nunet/device-management-service/types"
)
// KoiosEndpoint type for Koios rest api endpoints
// The values below are host names without a URL scheme.
type KoiosEndpoint string
const (
// KoiosMainnet - mainnet Koios rest api endpoint
KoiosMainnet KoiosEndpoint = "api.koios.rest"
// KoiosPreProd - testnet preprod Koios rest api endpoint
KoiosPreProd KoiosEndpoint = "preprod.koios.rest"
)
// UTXOs is a single entry of the JSON array decoded from the Koios
// address_utxos response in GetUTXOsOfSmartContract.
type UTXOs struct {
	TxHash  string `json:"tx_hash"`
	IsSpent bool   `json:"is_spent"`
}
// TxHashResp is one element of the list returned by GetJobTxHashes. It
// carries the hash, type and creation time (as a string) of a stored service
// transaction.
type TxHashResp struct {
	TxHash          string `json:"tx_hash"`
	TransactionType string `json:"transaction_type"`
	DateTime        string `json:"date_time"`
}
// ClaimCardanoTokenBody is the input of RequestReward. Only TxHash is read by
// RequestReward, to look up the matching service record.
type ClaimCardanoTokenBody struct {
	ComputeProviderAddress string `json:"compute_provider_address"`
	TxHash                 string `json:"tx_hash"`
}
// RewardRespToCPD is the result of RequestReward. Every field is copied from
// the stored service record found for the requested transaction hash; all
// fields except the two addresses are omitted from JSON when empty.
type RewardRespToCPD struct {
	ServiceProviderAddr string `json:"service_provider_addr"`
	ComputeProviderAddr string `json:"compute_provider_addr"`
	RewardType          string `json:"reward_type,omitempty"` // the service's TransactionType
	SignatureDatum      string `json:"signature_datum,omitempty"`
	MessageHashDatum    string `json:"message_hash_datum,omitempty"`
	Datum               string `json:"datum,omitempty"`
	SignatureAction     string `json:"signature_action,omitempty"`
	MessageHashAction   string `json:"message_hash_action,omitempty"`
	Action              string `json:"action,omitempty"`
}
// UpdateTxStatusBody is the input of UpdateStatus. Address is the address
// whose UTXOs are fetched through GetUTXOsOfSmartContract.
type UpdateTxStatusBody struct {
	Address string `json:"address,omitempty"`
}
// GetJobTxHashes lists the stored service transactions as TxHashResp values.
//
// clean must be one of "done", "refund", "withdraw" or "". Before listing,
// every service whose transaction_type equals clean is deleted; a failure of
// that cleanup is only logged.
// NOTE(review): an empty clean still issues the delete with
// transaction_type = "" — confirm this is intended.
//
// When size is 0 every service with a tx hash, a log.nunet.io log URL and a
// non-NULL transaction type is returned; otherwise the selection is delegated
// to getLimitedTransactions(size). The returned slice is never nil on success.
func GetJobTxHashes(size int, clean string) ([]TxHashResp, error) {
	switch clean {
	case "done", "refund", "withdraw", "":
	default:
		return nil, fmt.Errorf("invalid clean_tx parameter")
	}

	err := db.DB.Where("transaction_type = ?", clean).Delete(&types.Services{}).Error
	if err != nil {
		// %w is only meaningful to fmt.Errorf; loggers must use %v.
		zlog.Sugar().Errorf("%v", err)
	}

	var services []types.Services
	if size == 0 {
		err = db.DB.
			Where("tx_hash IS NOT NULL").
			Where("log_url LIKE ?", "%log.nunet.io%").
			Where("transaction_type is NOT NULL").
			Find(&services).Error
		if err != nil {
			zlog.Sugar().Errorf("%v", err)
			return nil, fmt.Errorf("no job deployed to request reward for: %w", err)
		}
	} else {
		services, err = getLimitedTransactions(size)
		if err != nil {
			zlog.Sugar().Errorf("%v", err)
			return nil, fmt.Errorf("could not get limited transactions: %w", err)
		}
	}

	resp := make([]TxHashResp, 0, len(services))
	for _, service := range services {
		resp = append(resp, TxHashResp{
			TxHash:          service.TxHash,
			TransactionType: service.TransactionType,
			DateTime:        service.CreatedAt.String(),
		})
	}
	return resp, nil
}
// RequestReward looks up the service stored for claim.TxHash and returns the
// reward data (addresses, datum, action and their signatures) recorded on it.
//
// It returns an error when the lookup fails, when no service is stored for
// the hash, or when the service's job status is still "running".
func RequestReward(claim ClaimCardanoTokenBody) (*RewardRespToCPD, error) {
	// At some point, management dashboard should send container ID to identify
	// against which container we are requesting reward
	service := types.Services{
		TxHash: claim.TxHash,
	}

	result := db.DB.Where("tx_hash = ?", claim.TxHash).Find(&service)
	if result.Error != nil {
		zlog.Sugar().Errorln(result.Error)
		return nil, fmt.Errorf("unknown tx hash: %w", result.Error)
	}
	// Find does not report an error when nothing matches, so a missing
	// record has to be detected through the affected row count.
	if result.RowsAffected == 0 {
		return nil, fmt.Errorf("unknown tx hash: %s", claim.TxHash)
	}
	zlog.Sugar().Infof("service found from txHash: %+v", service)

	if service.JobStatus == "running" {
		return nil, fmt.Errorf("job is still running")
	}

	reward := RewardRespToCPD{
		ServiceProviderAddr: service.ServiceProviderAddr,
		ComputeProviderAddr: service.ComputeProviderAddr,
		RewardType:          service.TransactionType,
		SignatureDatum:      service.SignatureDatum,
		MessageHashDatum:    service.MessageHashDatum,
		Datum:               service.Datum,
		SignatureAction:     service.SignatureAction,
		MessageHashAction:   service.MessageHashAction,
		Action:              service.Action,
	}
	return &reward, nil
}
// SendStatus records the outcome of a blockchain transaction and returns
// status.TransactionStatus unchanged.
//
// When the status is "success", the service stored for status.TxHash gets its
// transaction type set to "done". Database failures are logged, not returned.
// If the lookup fails or finds nothing, no record is written.
func SendStatus(status types.BlockchainTxStatus) string {
	if status.TransactionStatus != "success" {
		return status.TransactionStatus
	}
	zlog.Sugar().Infof("withdraw transaction successful - updating DB")

	var service types.Services
	result := db.DB.Where("tx_hash = ?", status.TxHash).Find(&service)
	if result.Error != nil {
		zlog.Sugar().Errorln(result.Error)
		return status.TransactionStatus
	}
	// Saving a zero-valued service would write a blank record marked as
	// done, so bail out when the hash is unknown.
	if result.RowsAffected == 0 {
		zlog.Sugar().Errorf("no service found for tx hash %q", status.TxHash)
		return status.TransactionStatus
	}

	service.TransactionType = "done"
	if err := db.DB.Save(&service).Error; err != nil {
		zlog.Sugar().Errorln(err)
	}
	return status.TransactionStatus
}
// UpdateStatus reconciles locally stored transactions with the UTXOs found
// at body.Address on the Koios preprod endpoint.
//
// It loads every non-deleted service that has a tx hash, a log.nunet.io log
// URL, a transaction type other than "" or "done", and that was created at
// least five minutes ago, then hands them to UpdateTransactionStatus together
// with the fetched UTXO hashes. Each failing step is logged and returned as a
// wrapped error.
func UpdateStatus(body UpdateTxStatusBody) error {
	utxoHashes, err := GetUTXOsOfSmartContract(body.Address, KoiosPreProd)
	if err != nil {
		zlog.Sugar().Errorln(err)
		return fmt.Errorf("failed to fetch UTXOs from Blockchain: %w", err)
	}

	fiveMinAgo := time.Now().Add(-5 * time.Minute)
	var services []types.Services
	err = db.DB.
		Where("tx_hash IS NOT NULL").
		Where("log_url LIKE ?", "%log.nunet.io%").
		Where("transaction_type IS NOT NULL").
		Where("deleted_at IS NULL").
		Where("created_at <= ?", fiveMinAgo).
		Not("transaction_type = ?", "done").
		Not("transaction_type = ?", "").
		Find(&services).Error
	if err != nil {
		zlog.Sugar().Errorln(err)
		return fmt.Errorf("no job deployed to request reward for: %w", err)
	}

	if err := UpdateTransactionStatus(services, utxoHashes); err != nil {
		zlog.Sugar().Errorln(err)
		// Keep the cause in the chain so callers can inspect it.
		return fmt.Errorf("failed to update transaction status: %w", err)
	}
	return nil
}
// getLimitedTransactions returns every stored transaction that is neither
// "done" nor untyped, followed by at most sizeDone of the most recently
// created "done" transactions. Only services with a tx hash and a
// log.nunet.io log URL are considered. On a query error an empty slice and
// the error are returned.
func getLimitedTransactions(sizeDone int) ([]types.Services, error) {
	var finished []types.Services
	finishedQuery := db.DB.
		Where("tx_hash IS NOT NULL").
		Where("log_url LIKE ?", "%log.nunet.io%").
		Where("transaction_type = ?", "done").
		Order("created_at DESC").
		Limit(sizeDone)
	if err := finishedQuery.Find(&finished).Error; err != nil {
		return []types.Services{}, err
	}

	var pending []types.Services
	pendingQuery := db.DB.
		Where("tx_hash IS NOT NULL").
		Where("log_url LIKE ?", "%log.nunet.io%").
		Where("transaction_type IS NOT NULL").
		Not("transaction_type = ?", "done").
		Not("transaction_type = ?", "")
	if err := pendingQuery.Find(&pending).Error; err != nil {
		return []types.Services{}, err
	}

	return append(pending, finished...), nil
}
// isValidCardano sets *valid to true when addr parses as a Cardano address.
// A panic raised while parsing is recovered and reported as *valid = false;
// on an ordinary parse error *valid is left untouched.
func isValidCardano(addr string, valid *bool) {
	defer func() {
		if recover() != nil {
			*valid = false
		}
	}()

	_, parseErr := address.NewAddress(addr)
	if parseErr != nil {
		return
	}
	*valid = true
}
// ValidateAddress accepts only valid Cardano wallet addresses. Ethereum-style
// hex addresses are rejected explicitly, and anything that does not parse as
// a Cardano address is reported as invalid.
func ValidateAddress(addr string) error {
	if common.IsHexAddress(addr) {
		return errors.New("ethereum wallet address not allowed")
	}

	var isCardano bool
	isValidCardano(addr, &isCardano)
	if !isCardano {
		return errors.New("invalid cardano wallet address")
	}
	return nil
}
// GetAddressPaymentCredential decodes the bech32 address addr and returns,
// hex encoded, the 28 bytes that follow the first byte of its payload
// (hex characters 2 through 57).
//
// An error is returned when addr is not valid bech32 or when the decoded
// payload is too short to contain those bytes.
func GetAddressPaymentCredential(addr string) (string, error) {
	// Start and end offsets, in hex characters, of the returned slice.
	const credentialStart, credentialEnd = 2, 58

	_, data, err := bech32.Decode(addr, 1023)
	if err != nil {
		return "", fmt.Errorf("decoding bech32 failed: %w", err)
	}

	converted, err := bech32.ConvertBits(data, 5, 8, false)
	if err != nil {
		return "", fmt.Errorf("decoding bech32 failed: %w", err)
	}

	encoded := hex.EncodeToString(converted)
	// Slicing a short payload would panic; report it as an error instead.
	if len(encoded) < credentialEnd {
		return "", fmt.Errorf("address payload too short: got %d hex characters, need %d", len(encoded), credentialEnd)
	}
	return encoded[credentialStart:credentialEnd], nil
}
// GetTxReceiver returns the receiver of a transaction, looked up by its hash
// through the Koios tx_info endpoint.
//
// The receiver is read from the second field of the inline datum attached to
// the second output of the first returned transaction. An error is returned
// when the request or decoding fails, or when the response does not contain
// that output or field.
func GetTxReceiver(txHash string, endpoint KoiosEndpoint) (string, error) {
	type Request struct {
		TxHashes []string `json:"_tx_hashes"`
	}
	reqBody, err := json.Marshal(Request{TxHashes: []string{txHash}})
	if err != nil {
		return "", err
	}

	resp, err := http.Post(
		fmt.Sprintf("https://%s/api/v1/tx_info", endpoint),
		"application/json",
		bytes.NewBuffer(reqBody))
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	res := []struct {
		Outputs []struct {
			InlineDatum struct {
				Value struct {
					Fields []struct {
						Bytes string `json:"bytes"`
					} `json:"fields"`
				} `json:"value"`
			} `json:"inline_datum"`
		} `json:"outputs"`
	}{}
	if err := json.NewDecoder(resp.Body).Decode(&res); err != nil && err != io.EOF {
		return "", err
	}

	// Index 1 is read from both slices, so each needs at least two
	// elements; checking only for emptiness would still allow a panic.
	if len(res) == 0 || len(res[0].Outputs) < 2 {
		return "", fmt.Errorf("unable to find receiver")
	}
	fields := res[0].Outputs[1].InlineDatum.Value.Fields
	if len(fields) < 2 {
		return "", fmt.Errorf("unable to find receiver")
	}
	return fields[1].Bytes, nil
}
// GetTxConfirmations returns the number of confirmations of a transaction,
// looked up by its hash through the Koios tx_status endpoint.
//
// The value is taken from the last entry of the response. An error is
// returned when the request or decoding fails, or when the response contains
// no entries.
func GetTxConfirmations(txHash string, endpoint KoiosEndpoint) (int, error) {
	type Request struct {
		TxHashes []string `json:"_tx_hashes"`
	}
	reqBody, err := json.Marshal(Request{TxHashes: []string{txHash}})
	if err != nil {
		return 0, err
	}

	resp, err := http.Post(
		fmt.Sprintf("https://%s/api/v1/tx_status", endpoint),
		"application/json",
		bytes.NewBuffer(reqBody))
	if err != nil {
		return 0, err
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return 0, err
	}

	var res []struct {
		TxHash        string `json:"tx_hash"`
		Confirmations int    `json:"num_confirmations"`
	}
	if err := json.Unmarshal(body, &res); err != nil {
		return 0, err
	}

	// An empty array would make the index below panic.
	if len(res) == 0 {
		return 0, fmt.Errorf("no status returned for transaction %s", txHash)
	}
	return res[len(res)-1].Confirmations, nil
}
// WaitForTxConfirmation polls GetTxConfirmations every five seconds until the
// transaction has at least the requested number of confirmations.
//
// It returns nil once that is reached, the polling error if a lookup fails,
// or a "timeout" error when timeout elapses first. The timeout is measured
// from the moment the function is called.
func WaitForTxConfirmation(confirmations int, timeout time.Duration, txHash string, endpoint KoiosEndpoint) error {
	ticker := time.NewTicker(5 * time.Second)
	defer ticker.Stop()

	// The deadline must be created once, outside the loop: a fresh
	// time.After in every select would restart the countdown on each tick
	// and never fire for timeouts longer than the polling interval.
	deadline := time.NewTimer(timeout)
	defer deadline.Stop()

	for {
		select {
		case <-ticker.C:
			conf, err := GetTxConfirmations(txHash, endpoint)
			if err != nil {
				return err
			}
			if conf >= confirmations {
				return nil
			}
		case <-deadline.C:
			return errors.New("timeout")
		}
	}
}
// GetUTXOsOfSmartContract queries the Koios address_utxos endpoint for the
// given address (with the extended flag set) and returns the tx_hash of every
// UTXO in the response. The returned slice is non-nil on success.
func GetUTXOsOfSmartContract(address string, endpoint KoiosEndpoint) ([]string, error) {
	type Request struct {
		Address  []string `json:"_addresses"`
		Extended bool     `json:"_extended"`
	}

	payload, _ := json.Marshal(Request{Address: []string{address}, Extended: true})
	target := fmt.Sprintf("https://%s/api/v1/address_utxos", endpoint)
	resp, err := http.Post(target, "application/json", bytes.NewBuffer(payload))
	if err != nil {
		return nil, fmt.Errorf("error making POST request: %v", err)
	}
	defer resp.Body.Close()

	var decoded []UTXOs
	decodeErr := json.NewDecoder(resp.Body).Decode(&decoded)
	if decodeErr != nil && decodeErr != io.EOF {
		return nil, decodeErr
	}

	hashes := make([]string, 0)
	for i := range decoded {
		hashes = append(hashes, decoded[i].TxHash)
	}
	return hashes, nil
}
// UpdateTransactionStatus updates the status of claimed transactions in the
// local DB.
//
// A service whose TxHash no longer appears in utxoHashes has its transaction
// type moved to the corresponding final status ("withdraw" to withdrawn,
// "refund" to refunded, "distribute-50" and "distribute-75" to distributed)
// and is saved. Other transaction types are saved unchanged. The first save
// error stops the loop and is returned.
func UpdateTransactionStatus(services []types.Services, utxoHashes []string) error {
	for _, service := range services {
		if SliceContains(utxoHashes, service.TxHash) {
			continue
		}

		switch service.TransactionType {
		case "withdraw":
			service.TransactionType = transactionWithdrawnStatus
		case "refund":
			service.TransactionType = transactionRefundedStatus
		// Go cases do not fall through, so both distribute variants have
		// to be listed on the same case to share the body.
		case "distribute-50", "distribute-75":
			service.TransactionType = transactionDistributedStatus
		}

		s := service
		if err := db.DB.Save(&s).Error; err != nil {
			return err
		}
	}
	return nil
}
package utils
import (
"fmt"
"os"
"github.com/spf13/afero"
)
// GetDirectorySize walks path on the given filesystem and returns the summed
// size in bytes of every non-directory entry found. Any error met during the
// walk aborts it and is returned wrapped, with a size of 0.
func GetDirectorySize(fs afero.Fs, path string) (int64, error) {
	var total int64

	visit := func(_ string, entry os.FileInfo, walkErr error) error {
		if walkErr != nil {
			return walkErr
		}
		if entry.IsDir() {
			return nil
		}
		total += entry.Size()
		return nil
	}

	if err := afero.Walk(fs, path, visit); err != nil {
		return 0, fmt.Errorf("failed to calculate volume size: %w", err)
	}
	return total, nil
}
package utils
import (
"bytes"
"fmt"
"io"
"net/http"
"net/url"
"path"
)
// HTTPClient is a small JSON-oriented HTTP client bound to a base URL and an
// API version prefix.
type HTTPClient struct {
	// BaseURL is parsed on every request; its path is replaced by the
	// joined APIVersion and relative path.
	BaseURL string
	// APIVersion is the path prefix joined in front of every relative path.
	APIVersion string
	// Client performs the requests.
	Client *http.Client
}

// NewHTTPClient returns an HTTPClient for baseURL and version that uses
// http.DefaultClient.
func NewHTTPClient(baseURL, version string) *HTTPClient {
	return &HTTPClient{
		BaseURL:    baseURL,
		APIVersion: version,
		Client:     http.DefaultClient,
	}
}

// MakeRequest performs an HTTP request with the given method, path, and body.
// The request is sent as JSON (Content-Type and Accept headers are set) to
// BaseURL with its path set to APIVersion joined with relativePath.
//
// It returns the response body and status code. On any failure (invalid base
// URL, request construction, transport, reading the body) it returns a nil
// body, status 0 and an error wrapping the cause.
func (c *HTTPClient) MakeRequest(method, relativePath string, body []byte) ([]byte, int, error) {
	// Named "endpoint" rather than "url" so the net/url package is not shadowed.
	endpoint, err := url.Parse(c.BaseURL)
	if err != nil {
		return nil, 0, fmt.Errorf("failed to parse base URL: %w", err)
	}
	endpoint.Path = path.Join(c.APIVersion, relativePath)

	req, err := http.NewRequest(method, endpoint.String(), bytes.NewBuffer(body))
	if err != nil {
		return nil, 0, fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json; charset=utf-8")
	req.Header.Set("Accept", "application/json")

	resp, err := c.Client.Do(req)
	if err != nil {
		return nil, 0, fmt.Errorf("request failed: %w", err)
	}
	defer resp.Body.Close()

	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, 0, fmt.Errorf("failed to read response body: %w", err)
	}
	return respBody, resp.StatusCode, nil
}
package utils
import (
"github.com/uptrace/opentelemetry-go-extra/otelzap"
"gitlab.com/nunet/device-management-service/telemetry/logger"
)
// zlog is the package-level logger; it is assigned in init.
var zlog *otelzap.Logger

// Final transaction types written by UpdateTransactionStatus.
const transactionWithdrawnStatus = "withdrawn"
const transactionRefundedStatus = "refunded"
const transactionDistributedStatus = "distributed"

// init creates the package logger under the name "utils".
func init() {
	zlog = logger.OtelZapLogger("utils")
}
package utils
import (
"io"
"sync"
"time"
)
// IOProgress is a snapshot of how far a tracked read or write has advanced.
type IOProgress struct {
	n         float64   // bytes transferred so far
	size      float64   // expected total; -1 is treated as unknown by Complete
	started   time.Time // when tracking began
	estimated time.Time // completion estimate, zero when none is stored
	err       error     // error returned by the most recent Read/Write
}

// Reader wraps an io.Reader and records progress for every Read.
type Reader struct {
	reader   io.Reader
	lock     sync.RWMutex
	Progress IOProgress
}

// Writer wraps an io.Writer and records progress for every Write.
type Writer struct {
	writer   io.Writer
	lock     sync.RWMutex
	Progress IOProgress
}

// ReaderWithProgress returns a Reader over r that expects size bytes in
// total; tracking starts now.
func ReaderWithProgress(r io.Reader, size int64) *Reader {
	tracked := &Reader{reader: r}
	tracked.Progress = IOProgress{started: time.Now(), size: float64(size)}
	return tracked
}

// WriterWithProgress returns a Writer over w that expects size bytes in
// total; tracking starts now.
func WriterWithProgress(w io.Writer, size int64) *Writer {
	tracked := &Writer{writer: w}
	tracked.Progress = IOProgress{started: time.Now(), size: float64(size)}
	return tracked
}

// Read forwards to the wrapped reader and then records, under the lock, the
// number of bytes read and the error of this call.
func (r *Reader) Read(p []byte) (n int, err error) {
	n, err = r.reader.Read(p)

	r.lock.Lock()
	defer r.lock.Unlock()
	r.Progress.n += float64(n)
	r.Progress.err = err
	return n, err
}

// Write forwards to the wrapped writer and then records, under the lock, the
// number of bytes written and the error of this call.
func (w *Writer) Write(p []byte) (n int, err error) {
	n, err = w.writer.Write(p)

	w.lock.Lock()
	defer w.lock.Unlock()
	w.Progress.n += float64(n)
	w.Progress.err = err
	return n, err
}

// Size returns the expected total number of bytes.
func (p IOProgress) Size() float64 { return p.size }

// N returns the number of bytes transferred so far.
func (p IOProgress) N() float64 { return p.n }

// Complete reports whether the transfer has finished: the last operation
// returned io.EOF, or the size is known (not -1) and has been reached.
func (p IOProgress) Complete() bool {
	switch {
	case p.err == io.EOF:
		return true
	case p.size == -1:
		return false
	}
	return p.n >= p.size
}

// Percent returns the completed share in the range 0 to 100. It is 0 while
// nothing has been transferred and 100 once n has reached size.
func (p IOProgress) Percent() float64 {
	switch {
	case p.n == 0:
		return 0
	case p.n >= p.size:
		return 100
	default:
		return 100.0 / (p.size / p.n)
	}
}

// Remaining returns the time left until the estimated completion, using the
// stored estimate when there is one and a fresh one from Estimated otherwise.
func (p IOProgress) Remaining() time.Duration {
	eta := p.estimated
	if eta.IsZero() {
		eta = p.Estimated()
	}
	return time.Until(eta)
}

// Estimated extrapolates the completion time from the elapsed time and the
// transferred fraction. While nothing has been transferred it returns the
// stored estimate, which may be the zero time.
func (p IOProgress) Estimated() time.Time {
	if p.n > 0.0 {
		elapsed := float64(time.Since(p.started))
		fraction := p.n / p.size
		return p.started.Add(time.Duration(elapsed / fraction))
	}
	return p.estimated
}
package utils
import (
"fmt"
"strings"
"sync"
)
// A SyncMap is a concurrency-safe sync.Map that uses strongly-typed
// method signatures to ensure the types of its stored data are known.
type SyncMap[K comparable, V any] struct {
	sync.Map
}

// SyncMapFromMap converts a standard Go map to a concurrency-safe SyncMap.
func SyncMapFromMap[K comparable, V any](m map[K]V) *SyncMap[K, V] {
	ret := &SyncMap[K, V]{}
	for k, v := range m {
		ret.Put(k, v)
	}
	return ret
}

// Get retrieves the value associated with the given key from the map.
// It returns the value and a boolean indicating whether the key was found.
func (m *SyncMap[K, V]) Get(key K) (V, bool) {
	value, ok := m.Load(key)
	if !ok {
		var empty V
		return empty, false
	}
	// A nil interface value is stored as an untyped nil; the two-result
	// assertion yields V's zero value for it instead of panicking.
	typed, _ := value.(V)
	return typed, true
}

// Put inserts or updates a key-value pair in the map.
func (m *SyncMap[K, V]) Put(key K, value V) {
	m.Store(key, value)
}

// Iter iterates over each key-value pair in the map, executing the provided function on each pair.
// The iteration stops if the provided function returns false.
func (m *SyncMap[K, V]) Iter(ranger func(key K, value V) bool) {
	m.Range(func(key, value any) bool {
		// Two-result assertions so stored nil interface values do not panic.
		k, _ := key.(K)
		v, _ := value.(V)
		return ranger(k, v)
	})
}

// Keys returns a slice containing all the keys present in the map.
// The result is nil when the map is empty.
func (m *SyncMap[K, V]) Keys() []K {
	var keys []K
	m.Iter(func(key K, _ V) bool {
		keys = append(keys, key)
		return true
	})
	return keys
}

// String provides a string representation of the map, listing all key-value
// pairs as "{k1=v1, k2=v2}". Pairs appear in the order sync.Map.Range yields
// them.
func (m *SyncMap[K, V]) String() string {
	var sb strings.Builder
	sb.WriteByte('{')
	first := true
	m.Range(func(key, value any) bool {
		if !first {
			sb.WriteString(", ")
		}
		first = false
		// %v formats any key/value type; %s only suits strings.
		fmt.Fprintf(&sb, "%v=%v", key, value)
		return true
	})
	sb.WriteByte('}')
	return sb.String()
}
package utils
import (
"archive/tar"
"bufio"
"compress/gzip"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"log"
"math/big"
"net/http"
"os"
"path/filepath"
"reflect"
"strings"
"time"
"github.com/google/uuid"
"github.com/spf13/afero"
"gitlab.com/nunet/device-management-service/db"
"gitlab.com/nunet/device-management-service/types"
"golang.org/x/exp/slices"
)
// Download locations and local paths of the vmlinux kernel image and the
// Ubuntu 20.04 ext4 filesystem image.
// NOTE(review): the "fc" path segment suggests Firecracker VM assets — confirm.
const (
	// KernelFileURL is the URL the kernel image is served from.
	KernelFileURL = "https://d.nunet.io/fc/vmlinux"
	// KernelFilePath is the local path of the kernel image.
	KernelFilePath = "/etc/nunet/vmlinux"
	// FilesystemURL is the URL the filesystem image is served from.
	FilesystemURL = "https://d.nunet.io/fc/nunet-fc-ubuntu-20.04-0.ext4"
	// FilesystemPath is the local path of the filesystem image.
	FilesystemPath = "/etc/nunet/nunet-fc-ubuntu-20.04-0.ext4"
)
// DownloadFile downloads the resource at url with a GET request and stores
// its body at filepath, creating or truncating that file.
//
// The destination is only created once a 2xx response has been received, so
// a failed request leaves an existing file untouched. An error is returned
// when the request fails, the status is not 2xx, or the file cannot be
// created, written or closed.
func DownloadFile(url string, filepath string) (err error) {
	zlog.Sugar().Infof("Downloading file '%s' from '%s'", filepath, url)

	req, err := http.NewRequest(http.MethodGet, url, nil)
	if err != nil {
		return err
	}
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	// Do not persist an error page as if it were the requested file.
	if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
		return fmt.Errorf("downloading %s: unexpected status %s", url, resp.Status)
	}

	file, err := os.Create(filepath)
	if err != nil {
		return err
	}
	defer func() {
		// A failed close can mean lost data, so surface it unless an
		// earlier error is already being returned.
		if closeErr := file.Close(); closeErr != nil && err == nil {
			err = closeErr
		}
	}()

	if _, err = io.Copy(file, resp.Body); err != nil {
		return err
	}

	log.Println("Finished downloading file '", filepath, "'")
	return nil
}
// ReadHTTPString issues a GET request to url with the default HTTP client
// and returns the whole response body as a string.
func ReadHTTPString(url string) (string, error) {
	resp, err := http.Get(url)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	payload, readErr := io.ReadAll(resp.Body)
	if readErr != nil {
		return "", readErr
	}
	return string(payload), nil
}
// RandomString generates a random string of length n made of ASCII letters,
// drawing every character from crypto/rand.
//
// It returns an error when n is negative or when the random source fails.
func RandomString(n int) (string, error) {
	const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"

	// strings.Builder.Grow panics on a negative count.
	if n < 0 {
		return "", fmt.Errorf("invalid random string length %d", n)
	}

	// The exclusive upper bound is loop invariant; build it once.
	limit := big.NewInt(int64(len(charset)))

	var sb strings.Builder
	sb.Grow(n)
	for i := 0; i < n; i++ {
		idx, err := rand.Int(rand.Reader, limit)
		if err != nil {
			return "", err
		}
		sb.WriteByte(charset[idx.Int64()])
	}
	return sb.String(), nil
}
// GenerateMachineUUID generates a machine uuid
func GenerateMachineUUID() (string, error) {
var machine types.MachineUUID
machineUUID, err := uuid.NewDCEGroup()
if err != nil {
return "", err
}
machine.UUID = machineUUID.String()
return machine.UUID, nil
}
// GetMachineUUID returns the machine uuid from the DB
func GetMachineUUID() string {
var machine types.MachineUUID
uuid, err := GenerateMachineUUID()
if err != nil {
zlog.Sugar().Errorf("could not generate machine uuid: %v", err)
}
machine.UUID = uuid
result := db.DB.FirstOrCreate(&machine)
if result.Error != nil {
zlog.Sugar().Errorf("could not find or create machine uuid record in DB: %v", result.Error)
}
return machine.UUID
}
// SliceContains reports whether str is an element of s.
func SliceContains(s []string, str string) bool {
	for i := 0; i < len(s); i++ {
		if s[i] != str {
			continue
		}
		return true
	}
	return false
}
// DeleteFile removes the file at path. With backup set, the file is instead
// renamed to "<path>.bk.<unix timestamp>".
func DeleteFile(path string, backup bool) (err error) {
	if !backup {
		return os.Remove(path)
	}
	backupPath := fmt.Sprintf("%s.bk.%d", path, time.Now().Unix())
	return os.Rename(path, backupPath)
}
// ReadyForElastic reports whether the elastic token loaded from the DB has
// both a node ID and a channel name. The result of the query itself is not
// checked.
func ReadyForElastic() bool {
	var token types.ElasticToken
	db.DB.Find(&token)
	if token.NodeID == "" {
		return false
	}
	return token.ChannelName != ""
}
// PromptYesNo writes "<prompt> (y/N): " to out and reads a line from in,
// repeating until the trimmed, lower-cased answer is y/yes (true) or n/no
// (false). A read error ends the loop and is returned wrapped.
func PromptYesNo(in io.Reader, out io.Writer, prompt string) (bool, error) {
	lines := bufio.NewReader(in)
	for {
		fmt.Fprintf(out, "%s (y/N): ", prompt)

		raw, err := lines.ReadString('\n')
		if err != nil {
			return false, fmt.Errorf("read response string failed: %w", err)
		}

		switch strings.ToLower(strings.TrimSpace(raw)) {
		case "y", "yes":
			return true, nil
		case "n", "no":
			return false, nil
		}
	}
}
// CreateDirectoryIfNotExists creates path (and any missing parents) with
// mode 0755 when stat reports that it does not exist. Any other stat outcome,
// including other stat errors, results in nil.
func CreateDirectoryIfNotExists(path string) error {
	_, statErr := os.Stat(path)
	if !os.IsNotExist(statErr) {
		return nil
	}
	return os.MkdirAll(path, 0o755)
}
// CalculateSHA256Checksum streams the file at filePath through SHA-256 and
// returns the digest as a lower-case hexadecimal string.
func CalculateSHA256Checksum(filePath string) (string, error) {
	src, err := os.Open(filePath)
	if err != nil {
		return "", err
	}
	defer src.Close()

	digest := sha256.New()
	_, err = io.Copy(digest, src)
	if err != nil {
		return "", err
	}

	return hex.EncodeToString(digest.Sum(nil)), nil
}
// CreateCheckSumFile writes checksum to "<filePath>.sha256.txt", creating or
// truncating that file, and returns the path of the checksum file.
func CreateCheckSumFile(filePath string, checksum string) (string, error) {
	target := fmt.Sprintf("%s.sha256.txt", filePath)

	out, createErr := os.Create(target)
	if createErr != nil {
		return "", fmt.Errorf("unable to create SHA-256 checksum file: %v", createErr)
	}
	defer out.Close()

	if _, writeErr := out.WriteString(checksum); writeErr != nil {
		return "", fmt.Errorf("unable to write to SHA-256 checksum file: %v", writeErr)
	}
	return target, nil
}
// SanitizeArchivePath guards archive extraction against path traversal
// ("G305: Zip Slip vulnerability"). It joins the archive entry name t onto
// the destination directory d and returns the joined path only when it lies
// inside d (or is d itself); otherwise it returns an error.
func SanitizeArchivePath(d, t string) (v string, err error) {
	v = filepath.Join(d, t)

	// A plain string-prefix test is not enough: "/tmp/foo" is a prefix of
	// "/tmp/foobar/x". Compare on path components via the relative path.
	rel, relErr := filepath.Rel(filepath.Clean(d), v)
	escapes := relErr != nil ||
		rel == ".." ||
		strings.HasPrefix(rel, ".."+string(filepath.Separator))
	if !escapes {
		return v, nil
	}

	return "", fmt.Errorf("%s: %s", "content filepath is tainted", t)
}
// ExtractTarGzToPath extracts the gzip-compressed tar archive at
// tarGzFilePath into extractedPath, creating that directory when needed.
//
// Every entry name is checked with SanitizeArchivePath before anything is
// written. Directory entries are created as directories; every other entry is
// written as a file. Extraction stops at the first error.
func ExtractTarGzToPath(tarGzFilePath, extractedPath string) error {
	// Ensure the target directory exists; create it if it doesn't.
	if err := os.MkdirAll(extractedPath, os.ModePerm); err != nil {
		return fmt.Errorf("error creating target directory: %v", err)
	}

	tarGzFile, err := os.Open(tarGzFilePath)
	if err != nil {
		return fmt.Errorf("error opening tar.gz file: %v", err)
	}
	defer tarGzFile.Close()

	gzipReader, err := gzip.NewReader(tarGzFile)
	if err != nil {
		return fmt.Errorf("error creating gzip reader: %v", err)
	}
	defer gzipReader.Close()

	tarReader := tar.NewReader(gzipReader)
	for {
		header, err := tarReader.Next()
		if err == io.EOF {
			break
		}
		if err != nil {
			return fmt.Errorf("error reading tar header: %v", err)
		}

		fullTargetPath, err := SanitizeArchivePath(extractedPath, header.Name)
		if err != nil {
			return fmt.Errorf("failed to santize path %w", err)
		}

		if header.FileInfo().IsDir() {
			if err := os.MkdirAll(fullTargetPath, os.ModePerm); err != nil {
				return fmt.Errorf("error creating directory: %v", err)
			}
			continue
		}

		// Make sure the directory leading to the file exists.
		if err := os.MkdirAll(filepath.Dir(fullTargetPath), os.ModePerm); err != nil {
			return fmt.Errorf("error creating directory: %v", err)
		}
		// The file is written in a helper so that it is closed as soon as
		// the entry is done rather than when the whole archive is.
		if err := extractTarEntry(tarReader, fullTargetPath); err != nil {
			return err
		}
	}
	return nil
}

// extractTarEntry copies the current archive entry from src into a newly
// created file at target and closes that file before returning.
func extractTarEntry(src io.Reader, target string) (err error) {
	newFile, err := os.Create(target)
	if err != nil {
		return fmt.Errorf("error creating file: %v", err)
	}
	defer func() {
		if closeErr := newFile.Close(); closeErr != nil && err == nil {
			err = fmt.Errorf("error closing file: %v", closeErr)
		}
	}()

	// Copy in fixed-size chunks until the entry is exhausted.
	for {
		if _, copyErr := io.CopyN(newFile, src, 1024); copyErr != nil {
			if copyErr == io.EOF {
				return nil
			}
			return copyErr
		}
	}
}
// CheckWSL reports whether /proc/version, read through afs, contains the
// text "Microsoft" or "WSL" on any line. Errors opening or scanning the file
// are returned.
func CheckWSL(afs afero.Afero) (bool, error) {
	versionFile, err := afs.Open("/proc/version")
	if err != nil {
		return false, err
	}
	defer versionFile.Close()

	sc := bufio.NewScanner(versionFile)
	for sc.Scan() {
		if text := sc.Text(); strings.Contains(text, "Microsoft") || strings.Contains(text, "WSL") {
			return true, nil
		}
	}
	if scanErr := sc.Err(); scanErr != nil {
		return false, scanErr
	}
	return false, nil
}
// SaveServiceInfo updates the service stored on the SP's DMS under
// cpService.TxHash with the contents of cpService, so the SP user can claim
// the reward. The ID and CreatedAt of the stored record are carried over onto
// cpService before the update.
func SaveServiceInfo(cpService types.Services) error {
	var stored types.Services
	lookup := db.DB.Model(&types.Services{}).Where("tx_hash = ?", cpService.TxHash).Find(&stored)
	if lookup.Error != nil {
		return fmt.Errorf("unable to find service on SP side: %v", lookup.Error)
	}

	cpService.ID, cpService.CreatedAt = stored.ID, stored.CreatedAt

	update := db.DB.Model(&types.Services{}).Where("tx_hash = ?", cpService.TxHash).Updates(&cpService)
	if update.Error != nil {
		return fmt.Errorf("unable to update service info on SP side: %v", update.Error.Error())
	}
	return nil
}
// RandomBool returns a random boolean drawn from crypto/rand, or false and
// the error when the random source fails.
func RandomBool() (bool, error) {
	// The draw is uniform over {0, 1}; a non-zero draw means true.
	draw, err := rand.Int(rand.Reader, big.NewInt(2))
	if err != nil {
		return false, err
	}
	return draw.Sign() != 0, nil
}
// IsExecutorType reports whether v holds a types.ExecutorType value.
func IsExecutorType(v interface{}) bool {
	switch v.(type) {
	case types.ExecutorType:
		return true
	default:
		return false
	}
}
// IsGPUVendor reports whether v holds a types.GPUVendor value.
func IsGPUVendor(v interface{}) bool {
	switch v.(type) {
	case types.GPUVendor:
		return true
	default:
		return false
	}
}
// IsJobType reports whether v holds a types.JobType value.
func IsJobType(v interface{}) bool {
	switch v.(type) {
	case types.JobType:
		return true
	default:
		return false
	}
}
// IsJobTypes reports whether v holds a types.JobTypes value.
func IsJobTypes(v interface{}) bool {
	switch v.(type) {
	case types.JobTypes:
		return true
	default:
		return false
	}
}
// IsExecutor reports whether v holds a types.Executor value.
func IsExecutor(v interface{}) bool {
	switch v.(type) {
	case types.Executor:
		return true
	default:
		return false
	}
}
// IsStrictlyContained reports whether every element of rightSlice is present
// in leftSlice. An empty rightSlice yields false.
func IsStrictlyContained(leftSlice, rightSlice []interface{}) bool {
	if len(rightSlice) == 0 {
		return false
	}
	for _, candidate := range rightSlice {
		if !slices.Contains(leftSlice, candidate) {
			return false
		}
	}
	return true
}
// IsStrictlyContainedInt reports whether every element of rightSlice is
// present in leftSlice. An empty rightSlice yields false.
func IsStrictlyContainedInt(leftSlice, rightSlice []int) bool {
	if len(rightSlice) == 0 {
		return false
	}
	for _, candidate := range rightSlice {
		if !slices.Contains(leftSlice, candidate) {
			return false
		}
	}
	return true
}
// NoIntersectionSlices reports whether slice1 and slice2 have no element in
// common. It returns false as soon as any element of slice1 is found in
// slice2. An empty slice1 yields false, matching the "default is false"
// convention of the neighbouring containment helpers.
func NoIntersectionSlices(slice1, slice2 []interface{}) bool {
	for _, element := range slice1 {
		// One shared element is enough; later elements must not be able
		// to overwrite this verdict.
		if slices.Contains(slice2, element) {
			return false
		}
	}
	return len(slice1) > 0
}
// IntersectionSlices returns the elements of slice2 that also occur in
// slice1, in slice2 order and without duplicates. The result is an empty,
// non-nil slice when nothing is shared.
func IntersectionSlices(slice1, slice2 []interface{}) []interface{} {
	// Set of the elements of slice1 that have not been matched yet.
	pending := make(map[interface{}]struct{}, len(slice1))
	for _, item := range slice1 {
		pending[item] = struct{}{}
	}

	common := []interface{}{}
	for _, item := range slice2 {
		if _, found := pending[item]; !found {
			continue
		}
		common = append(common, item)
		// Dropping the element keeps repeated values out of the result.
		delete(pending, item)
	}
	return common
}
// IsSameShallowType reports whether a and b have the same dynamic type, as
// given by reflect.TypeOf.
func IsSameShallowType(a, b interface{}) bool {
	return reflect.TypeOf(a) == reflect.TypeOf(b)
}
// ConvertTypedSliceToUntypedSlice copies the elements of a slice of any
// element type into a new []interface{}. It returns nil when typedSlice is
// not a slice.
func ConvertTypedSliceToUntypedSlice(typedSlice interface{}) []interface{} {
	source := reflect.ValueOf(typedSlice)
	if source.Kind() != reflect.Slice {
		return nil
	}

	count := source.Len()
	untyped := make([]interface{}, 0, count)
	for idx := 0; idx < count; idx++ {
		untyped = append(untyped, source.Index(idx).Interface())
	}
	return untyped
}
package validate
import (
"reflect"
)
// ConvertNumericToFloat64 converts a value of any built-in integer or
// floating-point type to float64. The boolean result is false, and the
// number 0, when n is of any other type.
func ConvertNumericToFloat64(n any) (float64, bool) {
	switch num := n.(type) {
	case float64:
		return num, true
	case float32:
		return float64(num), true
	case int, int8, int16, int32, int64:
		signed := reflect.ValueOf(num).Int()
		return float64(signed), true
	case uint, uint8, uint16, uint32, uint64:
		unsigned := reflect.ValueOf(num).Uint()
		return float64(unsigned), true
	}
	return 0, false
}
package validate
import (
"strings"
)
// IsBlank reports whether s is empty or consists solely of whitespace.
func IsBlank(s string) bool {
	trimmed := strings.TrimSpace(s)
	return trimmed == ""
}
// IsNotBlank reports whether s contains at least one non-whitespace
// character.
func IsNotBlank(s string) bool {
	return strings.TrimSpace(s) != ""
}
// IsLiteral reports whether s holds a value of type string.
func IsLiteral(s interface{}) bool {
	_, isString := s.(string)
	return isString
}