package api
import (
"fmt"
"os"
"time"
"github.com/gin-contrib/cors"
"github.com/gin-gonic/gin"
ginSwagger "github.com/swaggo/gin-swagger"
"github.com/swaggo/gin-swagger/swaggerFiles"
"gitlab.com/nunet/device-management-service/dms/onboarding"
"gitlab.com/nunet/device-management-service/dms/resources"
"gitlab.com/nunet/device-management-service/network/libp2p"
"gitlab.com/nunet/device-management-service/telemetry/logger"
)
// RESTServerConfig holds the dependencies and listen settings used to
// construct a RESTServer.
type RESTServerConfig struct {
	P2P        *libp2p.Libp2p         // libp2p node used by the peer endpoints; may be nil until onboarding completes
	Onboarding *onboarding.Onboarding // onboarding service backing the /onboarding endpoints
	Logger     *logger.Logger         // application logger
	Resource   resources.Manager      // system resource manager (capacity queries)
	MidW       []gin.HandlerFunc      // extra middleware installed ahead of the CORS middleware
	Port       uint32                 // TCP port to listen on
	Addr       string                 // bind address (host/IP) to listen on
}
// RESTServer represents a HTTP server.
type RESTServer struct {
	router *gin.Engine       // gin engine carrying all registered routes
	config *RESTServerConfig // shared configuration and service dependencies
}
// NewRESTServer is a constructor function for RESTServer.
// It wires the configured middleware into a fresh gin router and
// returns a pointer to RESTServer.
func NewRESTServer(config *RESTServerConfig) *RESTServer {
	server := &RESTServer{
		config: config,
		router: setupRouter(config.MidW),
	}
	return server
}
// setupRouter builds a gin engine running the given middleware chain with
// a CORS middleware appended at the end.
//
// The chain is copied into a fresh slice before appending: appending
// directly to the caller's `mid` could overwrite an element of the
// caller's backing array if it has spare capacity (slice aliasing).
func setupRouter(mid []gin.HandlerFunc) *gin.Engine {
	chain := make([]gin.HandlerFunc, 0, len(mid)+1)
	chain = append(chain, mid...)
	chain = append(chain, cors.New(getCustomCorsConfig()))

	router := gin.Default()
	router.Use(chain...)
	return router
}
// InitializeRoutes registers every endpoint route under the /api/v1 prefix
// and mounts the swagger documentation handler on the router.
func (rs *RESTServer) InitializeRoutes() {
	v1 := rs.router.Group("/api/v1")

	onboardingGrp := v1.Group("/onboarding")
	onboardingGrp.GET("/provisioned", rs.ProvisionedCapacity)
	onboardingGrp.GET("/address/new", rs.CreatePaymentAddress)
	onboardingGrp.GET("/status", rs.Status)
	onboardingGrp.GET("/info", rs.Info)
	onboardingGrp.POST("/onboard", rs.Onboard)
	onboardingGrp.POST("/resource-config", rs.ResourceConfig)
	onboardingGrp.DELETE("/offboard", rs.Offboard)

	deviceGrp := v1.Group("/device")
	deviceGrp.GET("/status", rs.DeviceStatus)
	deviceGrp.POST("/status", rs.UpdateDeviceStatus)

	vmGrp := v1.Group("/vm")
	vmGrp.POST("/start-default", rs.StartDefault)
	vmGrp.POST("/start-custom", rs.StartCustom)

	peersGrp := v1.Group("/peers")
	peersGrp.GET("", rs.ListPeers)
	peersGrp.GET("/self", rs.SelfPeerInfo)
	// Peer-inspection endpoints exposed for DEBUGGING ONLY.
	if _, debugMode := os.LookupEnv("NUNET_DEBUG"); debugMode {
		peersGrp.GET("/ping", rs.PingPeer)
		peersGrp.GET("/dht", rs.KnownPeers)
	}

	// Swagger API documentation
	rs.router.GET("/swagger/*any", ginSwagger.WrapHandler(swaggerFiles.Handler))
}
// Run starts the HTTP server, listening on the configured address and port.
// It blocks until the server stops and returns any listen/serve error.
func (rs *RESTServer) Run() error {
	addr := fmt.Sprintf("%s:%d", rs.config.Addr, rs.config.Port)
	return rs.router.Run(addr)
}
// getCustomCorsConfig returns the default CORS configuration with the
// allowed origins restricted to the two local frontend ports.
func getCustomCorsConfig() cors.Config {
	cfg := defaultConfig()
	// FIXME: This is a security concern.
	cfg.AllowOrigins = []string{"http://localhost:9991", "http://localhost:9992"}
	return cfg
}
// defaultConfig returns a generic default configuration mapped to localhost:
// all common HTTP methods, a small allow-list of headers, credentials
// disabled, and a 12-hour preflight cache.
func defaultConfig() cors.Config {
	methods := []string{"GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"}
	headers := []string{"Access-Control-Allow-Origin", "Origin", "Content-Length", "Content-Type"}
	return cors.Config{
		AllowMethods:     methods,
		AllowHeaders:     headers,
		AllowCredentials: false,
		MaxAge:           12 * time.Hour,
	}
}
package api
import (
"fmt"
"net/http"
"time"
"github.com/gin-gonic/gin"
"github.com/libp2p/go-libp2p/core/peer"
)
// DEBUG
// PingPeer pings the peer identified by the "peerID" query parameter and
// reports success and round-trip time. Rejects an empty ID and the node's
// own peer ID.
func (rs RESTServer) PingPeer(c *gin.Context) {
	ctx := c.Request.Context()

	id := c.Query("peerID")
	switch {
	case id == "":
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "peerID not provided"})
		return
	case id == rs.config.P2P.Host.ID().String():
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "peerID can not be self peerID"})
		return
	}

	// decode only for validation
	target, err := peer.Decode(id)
	if err != nil {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid string ID: could not decode string ID to peer ID"})
		return
	}

	res, err := rs.config.P2P.Ping(ctx, target.String(), 5*time.Second)
	if err != nil {
		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("could not ping peer %s: %v", id, err)})
		return
	}
	c.JSON(http.StatusOK, gin.H{"message": fmt.Sprintf("ping peer %s, success=%t, RTT=%d", id, res.Success, res.RTT)})
}
package api
import (
"net/http"
"github.com/gin-gonic/gin"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
)
// DeviceStatusHandler godoc
//
//	@Summary		Retrieve device status
//	@Description	Retrieve device status whether paused/offline (unable to receive job deployments) or online
//	@Tags			device
//	@Produce		json
//	@Success		200	{object}	object
//	@Failure		500	{object}	object	"host node has not yet been initialized"
//	@Failure		500	{object}	object	"could not retrieve data from peer"
//	@Failure		500	{object}	object	"failed to type assert peer data for peer ID"
//	@Router			/device/status [get]
func (rs RESTServer) DeviceStatus(c *gin.Context) {
	// TODO: handle this after refactor
	// status, err := libp2p.DeviceStatus()
	// if err != nil {
	// 	c.AbortWithStatusJSON(500, gin.H{"error": "could not retrieve device status"})
	// 	return
	// }
	// c.JSON(200, gin.H{"online": status})

	// Placeholder until the refactor above lands.
	c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "device status not implemented"})
}
// UpdateDeviceStatus godoc
//
//	@Summary		Update device status between online/offline
//	@Description	Update device status to online (able to receive jobs) or offline (unable to receive jobs).
//	@Tags			device
//	@Produce		json
//	@Failure		400	{object}	object	"empty content data"
//	@Failure		400	{object}	object	"invalid payload data"
//	@Failure		500	{object}	object	"host node has not yet been initialized"
//	@Failure		500	{object}	object	"could not retrieve data from self peer"
//	@Failure		500	{object}	object	"failed to type assert peer data for peer ID"
//	@Failure		500	{object}	object	"Failed to retrieve libp2p info from database"
//	@Failure		500	{object}	object	"Failed to update libp2p info in database"
//	@Failure		500	{object}	object	"failed to put peer data into peerstore"
//	@Success		200	{object}	object	"Device status successfully changed to online"
//	@Success		200	{object}	object	"Device status successfully changed to offline"
//	@Success		200	{object}	object	"no change in device status"
//	@Router			/device/status [post]
func (rs RESTServer) UpdateDeviceStatus(c *gin.Context) {
	span := trace.SpanFromContext(c.Request.Context())
	span.SetAttributes(attribute.String("URL", "/device/status"))

	// Payload: {"is_available": bool} — online when true, offline when false.
	var status struct {
		IsAvailable bool `json:"is_available"`
	}
	if c.Request.ContentLength == 0 {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "empty content data"})
		return
	}
	err := c.ShouldBindJSON(&status)
	if err != nil {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid payload data"})
		return
	}

	// TODO: handle this after refactor
	// err = libp2p.ChangeDeviceStatus(status.IsAvailable)
	// if err != nil {
	// 	c.AbortWithStatusJSON(500, gin.H{"error": err.Error()})
	// 	return
	// }

	// BUG FIX: the original fell through after this abort and appended a
	// second 200 JSON body to the already-aborted 500 response. Return here
	// until the TODO above is implemented.
	c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "change device status not implemented"})
	return //nolint // success path below is kept for when the refactor lands

	// Success path (currently unreachable):
	// var msg string
	// if status.IsAvailable {
	// 	msg = "Device status successfully changed to online"
	// } else {
	// 	msg = "Device status successfully changed to offline"
	// }
	// c.JSON(http.StatusOK, gin.H{"message": msg})
}
package api
import (
	"errors"
	"net/http"
	"strconv"

	"github.com/gin-gonic/gin"

	"gitlab.com/nunet/device-management-service/dms/onboarding"
	"gitlab.com/nunet/device-management-service/types"
)
// // OnboardingHandler is a controller for /onboarding endpoint functionalities
// type OnboardingHandler struct {
// service *onboarding.Onboarding
// }
// // NewOnboardingHandler is a constructor for OnboardingHandler
// func NewOnboardingHandler(s *onboarding.Onboarding) OnboardingHandler {
// return OnboardingHandler{service: s}
// }
// ProvisionedCapacity godoc
//
//	@Summary		Returns provisioned capacity on host.
//	@Description	Get total memory capacity in MB and CPU capacity in MHz.
//	@Tags			onboarding
//	@Produce		json
//	@Success		200	{object}	types.Provisioned
//	@Router			/onboarding/provisioned [get]
func (rs *RESTServer) ProvisionedCapacity(c *gin.Context) {
	provisioned, err := rs.config.Resource.SystemSpecs().GetProvisionedResources()
	if err != nil {
		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	c.JSON(http.StatusOK, provisioned)
}
// CreatePaymentAddress godoc
//
//	@Summary		Create a new payment address.
//	@Description	Create a payment address from public key. Return payment address and private key.
//	@Tags			onboarding
//	@Produce		json
//	@Success		200	{object}	types.BlockchainAddressPrivKey
//	@Router			/onboarding/address/new [get]
func (rs RESTServer) CreatePaymentAddress(c *gin.Context) {
	// "blockchain" query parameter selects the chain; cardano by default.
	blockchain := c.DefaultQuery("blockchain", "cardano")
	keyPair, err := onboarding.CreatePaymentAddress(blockchain)
	if err != nil {
		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	c.JSON(http.StatusOK, keyPair)
}
// Onboard godoc
//
//	@Summary		Runs the onboarding process.
//	@Description	Onboard runs onboarding script given the amount of resources to onboard.
//	@Tags			onboarding
//	@Produce		json
//	@Param			capacity	body		types.CapacityForNunet	true	"Capacity for NuNet"
//	@Success		201			{object}	types.OnboardingConfig
//	@Failure		400			{object}	object	"invalid request data"
//	@Failure		500			{object}	object	"could not check if config directory exists"
//	@Failure		500			{object}	object	"config directory does not exist"
//	@Failure		500			{object}	object	"could not validate payment address"
//	@Failure		500			{object}	object	"could not validate capacity data"
//	@Failure		500			{object}	object	"cardano node requires 10000MB of RAM and 6000MHz CPU"
//	@Failure		500			{object}	object	"invalid channel data, channel does not exist"
//	@Failure		500			{object}	object	"unable to create available resources table"
//	@Failure		500			{object}	object	"unable to update available resources table"
//	@Failure		500			{object}	object	"could not calculate free resources and update database"
//	@Failure		500			{object}	object	"could not register and run new node"
//	@Router			/onboarding/onboard [post]
func (rs *RESTServer) Onboard(c *gin.Context) {
	// Defaults applied before binding: onboard as an available server node.
	capacity := types.CapacityForNunet{
		ServerMode:  true,
		IsAvailable: true,
	}
	if err := c.BindJSON(&capacity); err != nil {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid request data"})
		return
	}

	oConf, p2p, err := rs.config.Onboarding.Onboard(c.Request.Context(), capacity)
	if err != nil {
		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}

	// Retain the p2p node created during onboarding for the peer endpoints.
	rs.config.P2P = p2p
	c.JSON(http.StatusCreated, oConf)
}
// Offboard godoc
//
// @Summary Runs the offboarding process.
// @Description Offboard runs offboarding process to remove the machine from the NuNet network.
// @Tags onboarding
// @Produce json
// @Success 200 {string} string "device successfully offboarded"
// @Param force query string false "force offboarding"
// @Failure 400 {object} object "invalid query data"
// @Failure 500 {object} object "could not retrieve onboard status"
// @Failure 500 {object} object "machine is not onboarded"
// @Failure 500 {object} object "unable to shutdown node"
// @Failure 500 {object} object "unable to delete available resources on database"
// @Failure 500 {object} object "could not remove payment address"
// @Router /onboarding/offboard [delete]
func (rs RESTServer) Offboard(c *gin.Context) {
	// "force" defaults to false and must parse as a boolean.
	query := c.DefaultQuery("force", "false")
	force, err := strconv.ParseBool(query)
	if err != nil {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid query data"})
		return
	}
	err = rs.config.Onboarding.Offboard(c.Request.Context(), force)
	if err != nil {
		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	c.JSON(http.StatusOK, gin.H{"message": "device successfully offboarded"})
}
// Status godoc
//
//	@Summary	Returns whether device is onboarded or not.
//	@Tags		onboarding
//	@Produce	json
//	@Success	200	{object}	object	"onboarded status"
//	@Router		/onboarding/status [get]
func (rs RESTServer) Status(c *gin.Context) {
	onboarded, err := rs.config.Onboarding.IsOnboarded(c.Request.Context())
	if err != nil {
		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	c.JSON(http.StatusOK, gin.H{"onboarded": onboarded})
}
// Info godoc
//
//	@Summary	Returns additional information about onboarded device.
//	@Tags		onboarding
//	@Produce	json
//	@Success	200	{object}	types.OnboardingConfig
//	@Router		/onboarding/info [get]
func (rs RESTServer) Info(c *gin.Context) {
	deviceInfo, err := rs.config.Onboarding.Info(c.Request.Context())
	if err != nil {
		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	c.JSON(http.StatusOK, gin.H{"info": deviceInfo})
}
// ResourceConfig godoc
//
//	@Summary	changes the amount of resources of onboarded device .
//	@Tags		onboarding
//	@Produce	json
//	@Success	200	{object}	types.OnboardingConfig
//	@Router		/onboarding/resource-config [post]
func (rs RESTServer) ResourceConfig(c *gin.Context) {
	if c.Request.ContentLength == 0 {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "request body is empty"})
		return
	}
	var capacity types.CapacityForNunet
	if err := c.BindJSON(&capacity); err != nil {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid request data"})
		return
	}

	oConfig, err := rs.config.Onboarding.ResourceConfig(c.Request.Context(), capacity)
	if err != nil {
		// errors.Is (not ==) so a wrapped ErrMachineNotOnboarded still maps
		// to 409 Conflict; anything else is treated as a bad request.
		if errors.Is(err, onboarding.ErrMachineNotOnboarded) {
			c.AbortWithStatusJSON(http.StatusConflict, gin.H{"error": err.Error()})
		} else {
			c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
		}
		return
	}
	c.JSON(http.StatusOK, oConfig)
}
package api
import (
"net/http"
"github.com/gin-gonic/gin"
)
// ListPeers godoc
//
//	@Summary		Return list of peers currently connected to
//	@Description	Gets a list of peers the libp2p node can see within the network and return a list of peers
//	@Tags			p2p
//	@Produce		json
//	@Failure		500	{object}	object	"no peers yet"
//	@Failure		500	{object}	object	"host node hasn't yet been initialized"
//	@Success		200	{object}	object	"list of peers"
//	@Router			/peers [get]
func (rs RESTServer) ListPeers(c *gin.Context) {
	node := rs.config.P2P
	if node == nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "host node hasn't yet been initialized"})
		return
	}

	visible := node.VisiblePeers()
	if len(visible) == 0 {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "no peers yet"})
		return
	}
	c.JSON(http.StatusOK, visible)
}
// KnownPeers godoc
//
//	@Summary		Return list of peers which have sent a dht update
//	@Description	Gets a list of peers the libp2p node has received a dht update from
//	@Tags			p2p
//	@Produce		json
//	@Success		200	{object}	object	"List of peers"
//	@Failure		404	{object}	object	"No peers found"
//	@Failure		500	{object}	object	"Host Node hasn't yet been initialized"
//	@Router			/peers/dht [get]
func (rs RESTServer) KnownPeers(c *gin.Context) {
	node := rs.config.P2P
	if node == nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "host node hasn't yet been initialized"})
		return
	}

	known, err := node.KnownPeers()
	switch {
	case err != nil:
		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
	case len(known) == 0:
		c.JSON(http.StatusNotFound, gin.H{"message": "no peers found"})
	default:
		c.JSON(http.StatusOK, known)
	}
}
// SelfPeerInfo godoc
//
//	@Summary		Return self peer info
//	@Description	Gets self peer info of libp2p node
//	@Tags			p2p
//	@Produce		json
//	@Success		200	{object}	object	"Self Peer Info"
//	@Failure		500	{object}	object	"host node hasn't yet been initialized"
//	@Router			/peers/self [get]
func (rs RESTServer) SelfPeerInfo(c *gin.Context) {
	node := rs.config.P2P
	if node == nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "host node hasn't yet been initialized"})
		return
	}
	c.JSON(http.StatusOK, node.Stat())
}
// DumpDHT godoc
//
//	@Summary		Return a dump of the dht
//	@Description	Returns entire DHT content
//	@Tags			p2p
//	@Produce		json
//	@Success		200	{object}	object	"List of DHT peers"
//	@Failure		500	{object}	object	"host node hasn't yet been initialized"
//	@Failure		500	{object}	object	"no content in DHT"
//	@Router			/peers/dht/dump [get]
func (rs RESTServer) DumpDHT(c *gin.Context) {
	node := rs.config.P2P
	if node == nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "host node hasn't yet been initialized"})
		return
	}

	table, err := node.DumpDHTRoutingTable()
	switch {
	case err != nil:
		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
	case len(table) == 0:
		c.JSON(http.StatusInternalServerError, gin.H{"message": "empty DHT"})
	default:
		c.JSON(http.StatusOK, table)
	}
}
package api
import (
"net/http"
"github.com/gin-gonic/gin"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"gitlab.com/nunet/device-management-service/executor/firecracker"
"gitlab.com/nunet/device-management-service/types"
)
// CustomVM is the request body for /vm/start-custom: a fully user-specified
// firecracker VM configuration.
type CustomVM struct {
	KernelImagePath string `json:"kernel_image_path"` // path to the kernel image on the host
	FilesystemPath  string `json:"filesystem_path"`   // path to the root filesystem image
	VCPUCount       int32  `json:"vcpu_count"`        // number of virtual CPUs
	MemSizeMib      int    `json:"mem_size_mib"`      // memory size in MiB
	TapDevice       string `json:"tap_device"`        // host tap device name — NOTE(review): not used by StartCustom as written; confirm
}
// DefaultVM is the request body for /vm/start-default: kernel and rootfs
// paths plus identity fields, with VM sizing left to defaults.
type DefaultVM struct {
	KernelImagePath string `json:"kernel_image_path"` // path to the kernel image on the host
	FilesystemPath  string `json:"filesystem_path"`   // path to the root filesystem image
	PublicKey       string `json:"public_key"`        // NOTE(review): not read by StartDefault as written; confirm intended use
	NodeID          string `json:"node_id"`           // NOTE(review): not read by StartDefault as written; confirm intended use
}
// StartCustom godoc
//
//	@Summary		Start a VM with custom configuration.
//	@Description	This endpoint is an abstraction of all primitive endpoints. When invokend, it calls all primitive endpoints in a sequence.
//	@Tags			vm
//	@Produce		json
//	@Param			body	body		firecracker.CustomVM	true	"body"
//	@Success		200		{object}	string					"VM started successfully."
//	@Failure		400		{object}	string					"invalid request body"
//	@Failure		500		{object}	string					"could not create database table"
//	@Failure		500		{object}	string					"could not initialize virtual machine"
//	@Failure		500		{object}	string					"failed to configure drives"
//	@Failure		500		{object}	string					"failed to configure machine config"
//	@Failure		500		{object}	string					"failed to configure network-interfaces"
//	@Failure		500		{object}	string					"failed to setup MMDS"
//	@Failure		500		{object}	string					"failed to pass MMDS message"
//	@Failure		500		{object}	string					"unable to start virtual machine"
//	@Router			/vm/start-custom [post]
func (rs RESTServer) StartCustom(c *gin.Context) {
	// Derive the request context once and use it everywhere below
	// (the original re-derived c.Request.Context() for each call).
	reqCtx := c.Request.Context()
	span := trace.SpanFromContext(reqCtx)
	span.SetAttributes(attribute.String("URL", "/vm/start-custom"))

	var body CustomVM
	err := c.BindJSON(&body)
	if err != nil {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid request body"})
		return
	}

	// Build the firecracker engine spec from the user-supplied kernel/rootfs.
	fe := firecracker.NewFirecrackerEngineBuilder(body.FilesystemPath).
		WithKernelImage(body.KernelImagePath).
		Build()
	fer := &types.ExecutionRequest{
		JobID:       "test_job",
		ExecutionID: "test_execution",
		EngineSpec:  fe,
		Resources: &types.ExecutionResources{
			CPU:    types.CPU{Cores: uint32(body.VCPUCount)},
			Memory: types.RAM{Size: int64(body.MemSizeMib)},
		},
	}

	fc, err := firecracker.NewExecutor(reqCtx, "manual-start-custom")
	if err != nil {
		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	err = fc.Start(reqCtx, fer)
	if err != nil {
		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	c.JSON(http.StatusOK, gin.H{"message": "VM started successfully"})
}
// StartDefault godoc
//
//	@Summary		Start a VM with default configuration.
//	@Description	Kernel file and filesystem file needs to be passed in body. This endpoint is an abstraction of all primitive endpoints.
//	@Tags			vm
//	@Produce		json
//	@Param			body	body		firecracker.DefaultVM	true	"body"
//	@Success		200		{object}	string					"VM started successfully."
//	@Failure		400		{object}	string					"invalid request body"
//	@Failure		500		{object}	string					"could not initialize virtual machine"
//	@Failure		500		{object}	string					"failed to confiugre boot source"
//	@Failure		500		{object}	string					"failed to configure drives"
//	@Failure		500		{object}	string					"failed to configure machineConfig"
//	@Failure		500		{object}	string					"failed to configure network-interfaces"
//	@Failure		500		{object}	string					"failed to setup MMDS"
//	@Failure		500		{object}	string					"failed to pass MMDS message"
//	@Failure		500		{object}	string					"unable to start virtual machine"
//	@Router			/vm/start-default [post]
func (rs RESTServer) StartDefault(c *gin.Context) {
	// Derive the request context once and use it everywhere below
	// (the original re-derived c.Request.Context() for each call).
	reqCtx := c.Request.Context()
	span := trace.SpanFromContext(reqCtx)
	span.SetAttributes(attribute.String("URL", "/vm/start-default"))

	var body DefaultVM
	err := c.BindJSON(&body)
	if err != nil {
		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid request body"})
		return
	}

	fe := firecracker.NewFirecrackerEngineBuilder(body.FilesystemPath).
		WithKernelImage(body.KernelImagePath).
		Build()
	// Default sizing: 1 vCPU, 1024 MiB RAM.
	fer := &types.ExecutionRequest{
		JobID:       "test_job",
		ExecutionID: "test_execution",
		EngineSpec:  fe,
		Resources: &types.ExecutionResources{
			CPU:    types.CPU{Cores: 1},
			Memory: types.RAM{Size: 1024},
		},
	}

	fc, err := firecracker.NewExecutor(reqCtx, "manual-start-default")
	if err != nil {
		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	err = fc.Start(reqCtx, fer)
	if err != nil {
		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
		return
	}
	c.JSON(http.StatusOK, gin.H{"message": "VM started successfully"})
}
package cmd
import (
"fmt"
"math"
"regexp"
"strconv"
)
// amdGPU exposes metrics for a single AMD GPU by parsing the text output
// of the rocm-smi tool.
type amdGPU struct {
	index int // GPU index as reported by rocm-smi (the N in "GPU[N]")
}
// name returns the card series string reported by
// `rocm-smi --showproductname` for this GPU, or "" when the command
// fails or the GPU index is absent from the output.
func (a *amdGPU) name() string {
	rocmOutput, err := runShellCmd("rocm-smi --showproductname")
	if err != nil {
		return ""
	}
	re := regexp.MustCompile(fmt.Sprintf(`GPU\[%d\]\s+: Card series:\s+(.+)`, a.index))
	if m := re.FindStringSubmatch(rocmOutput); len(m) > 1 {
		return m[1]
	}
	return ""
}
// utilizationRate returns the GPU use percentage reported by
// `rocm-smi --showuse`, or 0 when the command fails or the value
// cannot be found/parsed.
func (a *amdGPU) utilizationRate() int64 {
	rocmOutput, err := runShellCmd("rocm-smi --showuse")
	if err != nil {
		return 0
	}
	re := regexp.MustCompile(fmt.Sprintf(`GPU\[%d\]\s+: GPU use \(%%\): (\d+)`, a.index))
	m := re.FindStringSubmatch(rocmOutput)
	if len(m) <= 1 {
		return 0
	}
	utilization, err := strconv.ParseInt(m[1], 10, 32)
	if err != nil {
		return 0
	}
	return utilization
}
// memory returns VRAM usage for this GPU parsed from
// `rocm-smi --showmeminfo vram`. On any failure (command error, missing
// match) it returns the zero memoryInfo.
func (a *amdGPU) memory() memoryInfo {
	patternTotal := fmt.Sprintf(`GPU\[%d\]\s+: vram Total Memory \(B\): (\d+)`, a.index)
	reTotal := regexp.MustCompile(patternTotal)
	patternUsed := fmt.Sprintf(`GPU\[%d\]\s+: vram Total Used Memory \(B\): (\d+)`, a.index)
	reUsed := regexp.MustCompile(patternUsed)
	rocmOutput, err := runShellCmd("rocm-smi --showmeminfo vram")
	if err != nil {
		return memoryInfo{}
	}
	matchTotal := reTotal.FindStringSubmatch(rocmOutput)
	matchUsed := reUsed.FindStringSubmatch(rocmOutput)
	if len(matchTotal) > 1 && len(matchUsed) > 1 {
		total, err := strconv.ParseInt(matchTotal[1], 10, 64)
		if err != nil {
			total = 0
		}
		used, err := strconv.ParseInt(matchUsed[1], 10, 64)
		if err != nil {
			used = 0
		}
		free := total - used
		if free < 0 {
			// BUG FIX: if total failed to parse (reset to 0) while used
			// parsed, or rocm-smi reported inconsistent values, free went
			// negative and uint64(free) underflowed to a huge number.
			free = 0
		}
		return memoryInfo{
			used:  uint64(used),
			free:  uint64(free),
			total: uint64(total),
		}
	}
	return memoryInfo{}
}
// temperature returns the edge-sensor temperature in degrees Celsius from
// `rocm-smi --showtemp`, or 0 when the command fails or the value cannot
// be found/parsed.
func (a *amdGPU) temperature() float64 {
	rocmOutput, err := runShellCmd("rocm-smi --showtemp")
	if err != nil {
		return 0
	}
	re := regexp.MustCompile(fmt.Sprintf(`GPU\[%d\]\s+: Temperature \(Sensor edge\) \(C\): ([\d\.]+)`, a.index))
	m := re.FindStringSubmatch(rocmOutput)
	if len(m) <= 1 {
		return 0
	}
	temperature, err := strconv.ParseFloat(m[1], 64)
	if err != nil {
		return 0
	}
	return temperature
}
// powerUsage returns the average graphics package power draw in watts
// (rounded to the nearest integer) from `rocm-smi --showpower`, or 0 when
// the command fails or the value cannot be found/parsed.
func (a *amdGPU) powerUsage() uint32 {
	rocmOutput, err := runShellCmd("rocm-smi --showpower")
	if err != nil {
		return 0
	}
	re := regexp.MustCompile(fmt.Sprintf(`GPU\[%d\]\s+: Average Graphics Package Power \(W\): ([\d\.]+)`, a.index))
	m := re.FindStringSubmatch(rocmOutput)
	if len(m) <= 1 {
		return 0
	}
	watts, err := strconv.ParseFloat(m[1], 64)
	if err != nil {
		return 0
	}
	return uint32(math.Round(watts))
}
package cmd
import (
"os"
"github.com/spf13/cobra"
)
// autocompleteCmd represents the command to generate shell autocompletion scripts
func newAutoCompleteCmd() *cobra.Command {
return &cobra.Command{
Use: "autocomplete [shell]",
Short: "Generate autocomplete script for your shell",
Long: `Generate an autocomplete script for the nunet CLI.
This command supports Bash and Zsh shells.`,
DisableFlagsInUseLine: true,
Hidden: true,
ValidArgs: []string{"bash", "zsh"},
Args: cobra.MatchAll(cobra.ExactArgs(1), cobra.OnlyValidArgs),
Run: func(cmd *cobra.Command, args []string) {
switch args[0] {
case "bash":
_ = cmd.Root().GenBashCompletion(os.Stdout)
case "zsh":
_ = cmd.Root().GenZshCompletion(os.Stdout)
}
},
}
}
package cmd
import (
"encoding/json"
"fmt"
"github.com/buger/jsonparser"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils"
)
// NewCapacityCmd is a constructor for `capacity` command
func newCapacityCmd(client *utils.HTTPClient) *cobra.Command {
fnAvailable := "available"
fnOnboarded := "onboarded"
fnFull := "full"
cmd := &cobra.Command{
Use: "capacity",
Short: "Display capacity of device resources",
Long: `Retrieve capacity of the machine, onboarded or available amount of resources`,
RunE: func(cmd *cobra.Command, _ []string) error {
table := setupTable(cmd.OutOrStdout())
defer table.Render()
resBody, resCode, err := client.MakeRequest("GET", "/onboarding/info", nil)
if err != nil {
return fmt.Errorf("could not make request: %w", err)
}
if resCode != 200 {
return fmt.Errorf("request failed with status code: %d", resCode)
}
info, _, _, err := jsonparser.Get(resBody, "info")
if err != nil {
return fmt.Errorf("could not get info value by key: %w", err)
}
var config *types.OnboardingConfig
if err := json.Unmarshal(info, &config); err != nil {
return fmt.Errorf("could not unmarshal response body: %w", err)
}
full, _ := cmd.Flags().GetBool(fnFull)
available, _ := cmd.Flags().GetBool(fnAvailable)
onboarded, _ := cmd.Flags().GetBool(fnOnboarded)
if full {
table.Append(formatCapacityData("Full", &config.TotalResources.Resources))
}
if onboarded {
table.Append(formatCapacityData("Onboarded", &config.OnboardedResources.Resources))
}
if available {
free := &types.Resources{
RAM: config.TotalResources.RAM - config.OnboardedResources.RAM,
Disk: config.TotalResources.Disk - config.OnboardedResources.Disk,
}
table.Append(formatCapacityData("Available", free))
}
return nil
},
}
cmd.Flags().BoolP(fnFull, "f", false, "display device capacity")
cmd.Flags().BoolP(fnAvailable, "a", false, "display amount of resources still available for onboarding")
cmd.Flags().BoolP(fnOnboarded, "o", false, "display amount of onboarded resources")
cmd.MarkFlagsOneRequired(fnAvailable, fnFull, fnOnboarded)
return cmd
}
// formatCapacityData builds one table row: the resource label followed by
// RAM, CPU and core count rendered as strings.
func formatCapacityData(resourceType string, resources *types.Resources) []string {
	row := make([]string, 0, 4)
	row = append(row, resourceType)
	row = append(row, fmt.Sprintf("%d", resources.RAM))
	row = append(row, fmt.Sprintf("%f", resources.CPU))
	row = append(row, fmt.Sprintf("%d", resources.NumCores))
	return row
}
package cmd
import (
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/utils"
)
// NewDeviceCmd is a constructor for `device` parent command
func newDeviceCmd(client *utils.HTTPClient) *cobra.Command {
cmd := &cobra.Command{
Use: "device",
Short: "device related operations",
Long: `manage onboarded device`,
Run: func(cmd *cobra.Command, _ []string) {
err := cmd.Help()
if err != nil {
cmd.Println(err)
}
},
}
cmd.AddCommand(newDeviceStatusCmd(client))
cmd.AddCommand(newDeviceSetCmd(client))
return cmd
}
package cmd
import (
"encoding/json"
"fmt"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/utils"
)
// NewDeviceSetCmd is a constructor for `device set` subcommand
func newDeviceSetCmd(client *utils.HTTPClient) *cobra.Command {
validArgs := []string{"online", "offline"}
return &cobra.Command{
Use: "set",
Short: "Set device online status",
ValidArgs: validArgs,
Args: cobra.MatchAll(cobra.ExactArgs(1), cobra.OnlyValidArgs),
Long: ``,
RunE: func(cmd *cobra.Command, args []string) error {
onboarded, err := checkOnboarded(client)
if err != nil {
return fmt.Errorf("could not check onboard status: %w", err)
}
if !onboarded {
return fmt.Errorf("machine is not onboarded")
}
var statusJSON []byte
if args[0] == "online" {
statusJSON = []byte(`{"is_available": true}`)
} else if args[0] == "offline" {
statusJSON = []byte(`{"is_available": false}`)
}
resBody, resCode, err := client.MakeRequest("POST", "/device/status", statusJSON)
if err != nil {
return fmt.Errorf("could not make request: %w", err)
}
if resCode != 201 {
return fmt.Errorf("request failed with status code: %d", resCode)
}
var response map[string]any
if err := json.Unmarshal(resBody, &response); err != nil {
return fmt.Errorf("could not unmarshal response body: %w", err)
}
msg, ok := response["message"]
if ok {
fmt.Fprintln(cmd.OutOrStdout(), msg)
}
return nil
},
}
}
package cmd
import (
"fmt"
"github.com/buger/jsonparser"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/utils"
)
// NewDeviceStatusCmd is a constructor for `device status` subcommand
func newDeviceStatusCmd(client *utils.HTTPClient) *cobra.Command {
return &cobra.Command{
Use: "status",
Short: "Display current device status",
Args: cobra.NoArgs,
Long: ``,
RunE: func(cmd *cobra.Command, _ []string) error {
onboarded, err := checkOnboarded(client)
if err != nil {
return fmt.Errorf("could not check onboard status: %w", err)
}
if !onboarded {
return fmt.Errorf("machine is not onboarded")
}
resBody, resCode, err := client.MakeRequest("GET", "/device/status", nil)
if err != nil {
return fmt.Errorf("unable to make request: %w", err)
}
if resCode != 200 {
return fmt.Errorf("request failed with status code: %d", resCode)
}
online, err := jsonparser.GetBoolean(resBody, "online")
if err != nil {
return fmt.Errorf("failed to get device status from json response: %w", err)
}
if online {
fmt.Fprintln(cmd.OutOrStdout(), "Status: online")
} else {
fmt.Fprintln(cmd.OutOrStdout(), "Status: offline")
}
return nil
},
}
}
package cmd
import (
"github.com/spf13/afero"
"github.com/spf13/cobra"
)
// newGPUCmd is a constructor for the `gpu` parent command. It only prints
// help itself; the work is done by the capacity/status/onboard subcommands.
func newGPUCmd(afs afero.Afero) *cobra.Command {
	gpuCmd := &cobra.Command{
		Use:   "gpu",
		Short: "GPU-related operations",
		Long:  ``,
		Run: func(cmd *cobra.Command, _ []string) {
			if err := cmd.Help(); err != nil {
				cmd.Println(err)
			}
		},
	}
	gpuCmd.AddCommand(newGPUCapacityCmd(), newGPUStatusCmd(), newGPUOnboardCmd(afs))
	return gpuCmd
}
package cmd
import (
"context"
"fmt"
"io"
"os"
"os/signal"
"syscall"
"github.com/docker/cli/opts"
docker_types "github.com/docker/docker/api/types"
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/client"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/dms/resources"
"gitlab.com/nunet/device-management-service/types"
)
// ContainerOptions set parameters for running a Docker container (NVIDIA/AMD/Intel)
type ContainerOptions struct {
UseGPUs bool
Devices []string
Groups []string
Image string
Command []string
Entrypoint []string
}
// newGPUCapacityCmd builds the `gpu capacity` subcommand, which runs vendor
// test containers to verify CUDA (NVIDIA) and/or ROCm-HIP (AMD) availability.
// Failures are printed and terminate the process with exit code 1.
func newGPUCapacityCmd() *cobra.Command {
	fnCuda := "cuda-tensor"
	fnRocm := "rocm-hip"
	cmd := &cobra.Command{
		Use:   "capacity",
		Short: "Check availability of NVIDIA/AMD/Intel GPUs",
		Long:  ``,
		Run: func(cmd *cobra.Command, _ []string) {
			cuda, _ := cmd.Flags().GetBool(fnCuda)
			rocm, _ := cmd.Flags().GetBool(fnRocm)
			if !cuda && !rocm {
				fmt.Println(`Error: no flags specified`)
				if err := cmd.Help(); err != nil {
					cmd.Println(err)
				}
				os.Exit(1)
			}
			vendors, err := resources.ManagerInstance.SystemSpecs().GetGPUVendors()
			if err != nil {
				fmt.Println("Error detecting GPU vendors:", err)
				os.Exit(1)
			}
			hasAMD := containsVendor(vendors, types.GPUVendorAMDATI)
			hasNVIDIA := containsVendor(vendors, types.GPUVendorNvidia)
			if !hasAMD && !hasNVIDIA {
				fmt.Println("No AMD or NVIDIA GPU(s) detected...")
				os.Exit(1)
			}
			ctx := context.Background()
			if cuda {
				if !hasNVIDIA {
					fmt.Println("No NVIDIA GPU(s) detected...")
					os.Exit(1)
				}
				runCapacityCheck(ctx, "CUDA", ContainerOptions{
					UseGPUs:    true,
					Image:      "registry.gitlab.com/nunet/ml-on-gpu/ml-on-gpu-service/develop/pytorch",
					Command:    []string{"python", "check-cuda-and-tensor-cores-availability.py"},
					Entrypoint: []string{""},
				})
			}
			if rocm {
				if !hasAMD {
					fmt.Println("No AMD GPU(s) detected...")
					os.Exit(1)
				}
				runCapacityCheck(ctx, "ROCm-HIP", ContainerOptions{
					Devices:    []string{"/dev/kfd", "/dev/dri"},
					Groups:     []string{"video", "render"},
					Image:      "registry.gitlab.com/nunet/ml-on-gpu/ml-on-gpu-service/develop/pytorch-amd",
					Command:    []string{"python", "check-rocm-and-hip-availability.py"},
					Entrypoint: []string{""},
				})
			}
		},
	}
	cmd.Flags().BoolP(fnCuda, "c", false, "check CUDA Tensor")
	cmd.Flags().BoolP(fnRocm, "r", false, "check ROCM-HIP")
	return cmd
}

// runCapacityCheck pulls (if absent) and runs a single vendor test container,
// printing errors and exiting the process on any failure. label names the
// vendor check (e.g. "CUDA") and appears only in error messages.
func runCapacityCheck(ctx context.Context, label string, options ContainerOptions) {
	cli, err := client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation())
	if err != nil {
		fmt.Println("Error creating Docker client:", err)
		os.Exit(1)
	}
	images, err := cli.ImageList(ctx, docker_types.ImageListOptions{})
	if err != nil {
		fmt.Println("Error listing Docker images:", err)
		os.Exit(1)
	}
	if !imageExists(images, options.Image) {
		if err := pullImage(ctx, cli, options.Image); err != nil {
			fmt.Printf("Error pulling %s image: %v\n", label, err)
			os.Exit(1)
		}
	}
	if err := runDockerContainer(ctx, cli, options); err != nil {
		fmt.Printf("Error running %s container: %v\n", label, err)
		os.Exit(1)
	}
}
// runDockerContainer creates, starts, and attaches to a one-shot container
// described by options, streaming its combined output to stdout and waiting
// for it to exit. The container is removed when the function returns.
func runDockerContainer(ctx context.Context, cli *client.Client, options ContainerOptions) error {
	if options.Image == "" {
		return fmt.Errorf("image name cannot be empty")
	}
	config := &container.Config{
		Image:      options.Image,
		Entrypoint: options.Entrypoint,
		Cmd:        options.Command,
		Tty:        true,
	}
	hostConfig := &container.HostConfig{}
	if options.UseGPUs {
		// Request all GPUs through Docker's device-request mechanism.
		gpuOpts := opts.GpuOpts{}
		if err := gpuOpts.Set("all"); err != nil {
			return fmt.Errorf("failed setting GPU opts: %v", err)
		}
		hostConfig.DeviceRequests = gpuOpts.Value()
	}
	// Map raw host devices (e.g. /dev/kfd, /dev/dri for AMD) into the container.
	for _, device := range options.Devices {
		hostConfig.Devices = append(hostConfig.Devices, container.DeviceMapping{
			PathOnHost:        device,
			PathInContainer:   device,
			CgroupPermissions: "rwm",
		})
	}
	hostConfig.GroupAdd = options.Groups
	resp, err := cli.ContainerCreate(ctx, config, hostConfig, nil, nil, "")
	if err != nil {
		return fmt.Errorf("cannot create container: %v", err)
	}
	// Best-effort cleanup; removal failure is reported but not returned.
	defer func() {
		if err := cli.ContainerRemove(ctx, resp.ID, docker_types.ContainerRemoveOptions{}); err != nil {
			fmt.Printf("WARNING: could not remove container: %v\n", err)
		}
	}()
	if err := cli.ContainerStart(ctx, resp.ID, docker_types.ContainerStartOptions{}); err != nil {
		return fmt.Errorf("cannot start container: %v", err)
	}
	out, err := cli.ContainerAttach(ctx, resp.ID, docker_types.ContainerAttachOptions{
		Stream: true,
		Stdout: true,
		Stderr: true,
	})
	if err != nil {
		return fmt.Errorf("failed attaching container: %v", err)
	}
	// Release the hijacked attach connection; it was previously never closed.
	defer out.Close()
	// With Tty set, stdout and stderr arrive interleaved on a single stream;
	// this blocks until the container closes its end.
	_, err = io.Copy(os.Stdout, out.Reader)
	if err != nil {
		return fmt.Errorf("failed to copy container output: %w", err)
	}
	waitCh, errCh := cli.ContainerWait(ctx, resp.ID, container.WaitConditionNotRunning)
	select {
	case waitResult := <-waitCh:
		if waitResult.Error != nil {
			return fmt.Errorf("container exit error: %s", waitResult.Error.Message)
		}
	case err := <-errCh:
		return fmt.Errorf("error waiting for container: %v", err)
	}
	return nil
}
// imageExists reports whether imageName matches any repo tag of the given
// image summaries.
func imageExists(images []docker_types.ImageSummary, imageName string) bool {
	for i := range images {
		tags := images[i].RepoTags
		for j := range tags {
			if tags[j] == imageName {
				return true
			}
		}
	}
	return false
}
// pullImage pulls imageName through the Docker daemon, streaming progress to
// stdout. SIGINT/SIGTERM cancels the pull. The signal watcher is deregistered
// and its goroutine released when the pull finishes (previously both leaked
// for the remainder of the process).
func pullImage(ctx context.Context, cli *client.Client, imageName string) error {
	ctxCancel, cancel := context.WithCancel(ctx)
	defer cancel()
	out, err := cli.ImagePull(ctxCancel, imageName, docker_types.ImagePullOptions{})
	if err != nil {
		return fmt.Errorf("unable to pull image %s: %v", imageName, err)
	}
	defer out.Close()
	// Cancel the pull on interrupt; the done channel unblocks the watcher
	// goroutine once the pull completes so it does not linger forever.
	interrupt := make(chan os.Signal, 1)
	signal.Notify(interrupt, os.Interrupt, syscall.SIGTERM)
	defer signal.Stop(interrupt)
	done := make(chan struct{})
	defer close(done)
	go func() {
		select {
		case <-interrupt:
			fmt.Println("signal: interrupt")
			cancel()
		case <-done:
		}
	}()
	fmt.Printf("Pulling image: %s\nThis may take some time...\n", imageName)
	if _, err := io.Copy(os.Stdout, out); err != nil {
		return fmt.Errorf("failed to copy image pull to stdout: %w", err)
	}
	return nil
}
package cmd
import (
"fmt"
"io"
"os/exec"
"strings"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/dms/resources"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils"
)
// Paths to the maintenance scripts executed during GPU onboarding.
const (
	containerPath    = "maint-scripts/install_container_runtime"
	amdDriverPath    = "maint-scripts/install_amd_drivers"
	nvidiaDriverPath = "maint-scripts/install_nvidia_drivers"
)
// newGPUOnboardCmd builds the `gpu onboard` subcommand, which installs GPU
// drivers and a container runtime appropriate for the detected environment
// (WSL, mining OS, or a regular host with NVIDIA/AMD GPUs).
func newGPUOnboardCmd(afs afero.Afero) *cobra.Command {
	return &cobra.Command{
		Use:   "onboard",
		Short: "Install GPU drivers and Container Runtime",
		Long:  ``,
		RunE: func(cmd *cobra.Command, _ []string) error {
			wsl, err := utils.CheckWSL(afs)
			if err != nil {
				return fmt.Errorf("could not check WSL: %w", err)
			}
			mining, err := checkMiningOS(afs)
			if err != nil {
				return fmt.Errorf("couldn't check if Mining OS: %w", err)
			}
			// TODO: Define API endpoint for this
			vendors, err := resources.ManagerInstance.SystemSpecs().GetGPUVendors()
			if err != nil {
				return fmt.Errorf("couldn't check presence of a GPU: %w", err)
			}
			hasAMD := containsVendor(vendors, types.GPUVendorAMDATI)
			hasNVIDIA := containsVendor(vendors, types.GPUVendorNvidia)
			hasIntel := containsVendor(vendors, types.GPUVendorIntel)
			if !hasAMD && !hasNVIDIA && !hasIntel {
				return fmt.Errorf("no NVIDIA/AMD/Intel GPU(s) detected")
			}
			switch {
			case wsl:
				// Fprintln terminates the message with a newline so the
				// following prompt starts on its own line (Fprintf without
				// a trailing \n ran the two together).
				fmt.Fprintln(cmd.OutOrStdout(), "You are running on Windows Subsystem for Linux (WSL). AMD GPUs are not supported.")
				if !hasNVIDIA {
					return fmt.Errorf("no NVIDIA GPU(s) detected")
				}
				if err := promptContainer(cmd.InOrStdin(), cmd.OutOrStdout(), containerPath); err != nil {
					return fmt.Errorf("couldn't install container runtime: %w", err)
				}
			case mining:
				fmt.Fprintln(cmd.OutOrStdout(), "You are likely running a Mining OS. Skipping driver installation...")
				if err := promptContainer(cmd.InOrStdin(), cmd.OutOrStdout(), containerPath); err != nil {
					return fmt.Errorf("couldn't install container runtime: %w", err)
				}
			default:
				if hasNVIDIA {
					nvidiaGPUs, err := resources.ManagerInstance.SystemSpecs().GetGPUs(types.GPUVendorNvidia)
					if err != nil {
						return fmt.Errorf("couldn't fetch Nvidia info: %w", err)
					}
					printGPUs(nvidiaGPUs)
					if err := promptContainer(cmd.InOrStdin(), cmd.OutOrStdout(), containerPath); err != nil {
						return fmt.Errorf("couldn't install container runtime: %w", err)
					}
					if err := promptDriverInstallation(cmd.InOrStdin(), cmd.OutOrStdout(), types.GPUVendorNvidia, nvidiaDriverPath); err != nil {
						return fmt.Errorf("couldn't install Nvidia driver: %w", err)
					}
				}
				if hasAMD {
					AMDGPUs, err := resources.ManagerInstance.SystemSpecs().GetGPUs(types.GPUVendorAMDATI)
					if err != nil {
						return fmt.Errorf("couldn't fetch AMD driver info: %w", err)
					}
					printGPUs(AMDGPUs)
					if err := promptDriverInstallation(cmd.InOrStdin(), cmd.OutOrStdout(), types.GPUVendorAMDATI, amdDriverPath); err != nil {
						return fmt.Errorf("couldn't install AMD driver: %w", err)
					}
				}
			}
			return nil
		},
	}
}
// containsVendor reports whether target appears among the GPU vendors
// detected on the system.
func containsVendor(vendors []types.GPUVendor, target types.GPUVendor) bool {
	found := false
	for i := 0; i < len(vendors) && !found; i++ {
		found = vendors[i] == target
	}
	return found
}
// runScript runs the bash script at scriptPath and, on success, prints its
// combined stdout/stderr output.
func runScript(scriptPath string) error {
	output, err := exec.Command("/bin/bash", scriptPath).CombinedOutput()
	if err != nil {
		return fmt.Errorf("script failed with error: %w", err)
	}
	fmt.Printf("%s\n", output)
	return nil
}
// promptContainer asks the user to confirm Container Runtime installation
// and, on a yes answer, executes the installer script at containerPath.
func promptContainer(in io.Reader, out io.Writer, containerPath string) error {
	proceed, err := utils.PromptYesNo(in, out, "Do you want to proceed with Container Runtime installation? (y/N)")
	if err != nil {
		return fmt.Errorf("could not read answer from prompt: %w", err)
	}
	if !proceed {
		return nil
	}
	if err := runScript(containerPath); err != nil {
		return fmt.Errorf("cannot run container runtime installation script: %w", err)
	}
	return nil
}
// promptDriverInstallation asks the user to confirm driver installation for
// the given vendor and, on a yes answer, executes the script at scriptPath.
func promptDriverInstallation(in io.Reader, out io.Writer, vendor types.GPUVendor, scriptPath string) error {
	question := fmt.Sprintf("Do you want to proceed with %s driver installation? (y/N)", vendor)
	proceed, err := utils.PromptYesNo(in, out, question)
	if err != nil {
		return fmt.Errorf("could not read answer from prompt: %w", err)
	}
	if !proceed {
		return nil
	}
	if err := runScript(scriptPath); err != nil {
		return fmt.Errorf("cannot run driver installation script: %w", err)
	}
	return nil
}
// printGPUs prints the vendor followed by the model of each detected GPU,
// one per line. It is a no-op for an empty slice. The vendor is taken from
// the first element (all entries are fetched per-vendor by the callers).
func printGPUs(gpus []types.GPU) {
	if len(gpus) == 0 {
		return
	}
	vendor := string(gpus[0].Vendor)
	// Terminate the header with a newline; previously the first model was
	// printed on the same line as the header.
	fmt.Printf("Available %s GPU(s):\n", vendor)
	for _, gpu := range gpus {
		fmt.Printf("- %s\n", gpu.Model)
	}
}
// checkMiningOS reports whether the host appears to run a mining OS by
// scanning /etc/os-release for distro markers commonly used by mining
// distributions.
func checkMiningOS(afs afero.Afero) (bool, error) {
	const osFile = "/etc/os-release"
	contents, err := afs.ReadFile(osFile)
	if err != nil {
		return false, fmt.Errorf("cannot read file %s: %w", osFile, err)
	}
	release := string(contents)
	markers := []string{"Hive", "Rave", "PiMP", "Minerstat", "SimpleMining", "NH", "Miner", "SM", "MMP"}
	for _, marker := range markers {
		if strings.Contains(release, marker) {
			return true, nil
		}
	}
	return false, nil
}
package cmd
import (
"fmt"
"os"
"os/exec"
"os/signal"
"regexp"
"syscall"
"time"
"github.com/NVIDIA/go-nvml/pkg/nvml"
"github.com/dustin/go-humanize"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/dms/resources"
"gitlab.com/nunet/device-management-service/types"
)
// newGPUStatusCmd builds the `gpu status` subcommand, which prints a
// refreshing dashboard of utilization, memory, temperature, and power for
// all detected NVIDIA, AMD, and Intel GPUs until interrupted with CTRL+C.
func newGPUStatusCmd() *cobra.Command {
	return &cobra.Command{
		Use:   "status",
		Short: "Check GPU status in real time",
		Long:  ``,
		Run: func(_ *cobra.Command, _ []string) {
			vendors, err := resources.ManagerInstance.SystemSpecs().GetGPUVendors()
			if err != nil {
				fmt.Println("Error trying to detect GPU(s):", err)
				return
			}
			hasAMD := containsVendor(vendors, types.GPUVendorAMDATI)
			hasNVIDIA := containsVendor(vendors, types.GPUVendorNvidia)
			hasIntel := containsVendor(vendors, types.GPUVendorIntel)
			if hasNVIDIA || hasAMD || hasIntel {
				if hasNVIDIA {
					// NVML initialization
					// Init failure is reported but not fatal; subsequent NVML
					// calls will fail and their counts fall back to zero.
					retNVML := nvml.Init()
					if retNVML != nvml.SUCCESS {
						fmt.Println("Failed to initialize NVML:", nvml.ErrorString(retNVML))
					}
					// Shut NVML down when the command exits (function scope).
					defer func() {
						retNVML := nvml.Shutdown()
						if retNVML != nvml.SUCCESS {
							fmt.Println("Failed to shutdown NVML:", nvml.ErrorString(retNVML))
						}
					}()
				}
				// NOTE(review): DeviceGetCount is called even when NVML was
				// never initialized (no NVIDIA GPU), which prints a spurious
				// error on AMD/Intel-only machines — confirm if intended.
				countNVML, retNVML := nvml.DeviceGetCount()
				if retNVML != nvml.SUCCESS {
					fmt.Println("Failed to count NVIDIA GPU devices:", nvml.ErrorString(retNVML))
					countNVML = 0
				}
				countROCM, err := getCountAMD()
				if err != nil {
					fmt.Println("Failed to count AMD GPU devices:", err)
					countROCM = 0
				}
				countIntel, err := getCountIntel()
				if err != nil {
					fmt.Println("Failed to count Intel GPU devices:", err)
					countIntel = 0
				}
				// Initialize GPU slices
				// NVIDIA indices are 0-based (NVML); AMD and Intel handles
				// here are created 1-based — presumably matching the CLI
				// tools' numbering; verify against rocm-smi/xpu-smi output.
				nvidiaGPUs := make([]nvidiaGPU, countNVML)
				for i := 0; i < countNVML; i++ {
					nvidiaGPUs[i] = nvidiaGPU{index: i}
				}
				amdGPUs := make([]amdGPU, countROCM)
				for i := 0; i < countROCM; i++ {
					amdGPUs[i] = amdGPU{index: (i + 1)}
				}
				intelGPUs := make([]intelGPU, countIntel)
				for i := 0; i < countIntel; i++ {
					intelGPUs[i] = intelGPU{index: (i + 1)}
				}
				// Define channel for receiving interrupt signal and closing the real-time loop
				interrupt := make(chan os.Signal, 1)
				signal.Notify(interrupt, os.Interrupt, syscall.SIGTERM)
				exit := make(chan struct{})
				go func() {
					<-interrupt
					close(exit)
				}()
				// Refresh loop: redraw the full dashboard every ~2 seconds
				// until an interrupt closes the exit channel.
				for {
					select {
					case <-exit:
						fmt.Println("signal: interrupt")
						return
					default:
						// Clear screen (not reliable, maybe implement something ncurses-like for future)
						fmt.Print("\033[H\033[2J")
						fmt.Println("========== NuNet GPU Status ==========")
						fmt.Println("========== GPU Utilization ==========")
						for _, n := range nvidiaGPUs {
							fmt.Printf("%d %s: %d%%\n", n.index, n.name(), n.utilizationRate())
						}
						for _, a := range amdGPUs {
							fmt.Printf("%d AMD %s: %d%%\n", a.index, a.name(), a.utilizationRate())
						}
						for _, i := range intelGPUs {
							fmt.Printf("%d %s: %d%%\n", i.index, i.name(), i.utilizationRate())
						}
						fmt.Println("========== Memory Capacity ==========")
						for _, n := range nvidiaGPUs {
							fmt.Printf("%d %s: %s\n", n.index, n.name(), humanize.IBytes(n.memory().total))
						}
						for _, a := range amdGPUs {
							fmt.Printf("%d AMD %s: %s\n", a.index, a.name(), humanize.IBytes(a.memory().total))
						}
						for _, i := range intelGPUs {
							fmt.Printf("%d %s: %s\n", i.index, i.name(), humanize.IBytes(i.memory().total))
						}
						fmt.Println("========== Memory Used ==========")
						for _, n := range nvidiaGPUs {
							fmt.Printf("%d %s: %s\n", n.index, n.name(), humanize.IBytes(n.memory().used))
						}
						for _, a := range amdGPUs {
							fmt.Printf("%d AMD %s: %s\n", a.index, a.name(), humanize.IBytes(a.memory().used))
						}
						for _, i := range intelGPUs {
							fmt.Printf("%d %s: %s\n", i.index, i.name(), humanize.IBytes(i.memory().used))
						}
						fmt.Println("========== Memory Free ==========")
						for _, n := range nvidiaGPUs {
							fmt.Printf("%d %s: %s\n", n.index, n.name(), humanize.IBytes(n.memory().free))
						}
						for _, a := range amdGPUs {
							fmt.Printf("%d AMD %s: %s\n", a.index, a.name(), humanize.IBytes(a.memory().free))
						}
						for _, i := range intelGPUs {
							fmt.Printf("%d %s: %s\n", i.index, i.name(), humanize.IBytes(i.memory().free))
						}
						fmt.Println("========== Temperature ==========")
						// Intel temperature is not displayed; only NVIDIA and
						// AMD expose it here.
						for _, n := range nvidiaGPUs {
							fmt.Printf("%d %s: %.0f°C\n", n.index, n.name(), n.temperature())
						}
						for _, a := range amdGPUs {
							fmt.Printf("%d AMD %s: %.0f°C\n", a.index, a.name(), a.temperature())
						}
						fmt.Println("========== Power Usage ==========")
						for _, n := range nvidiaGPUs {
							fmt.Printf("%d %s: %dW\n", n.index, n.name(), n.powerUsage())
						}
						for _, a := range amdGPUs {
							fmt.Printf("%d AMD %s: %dW\n", a.index, a.name(), a.powerUsage())
						}
						for _, i := range intelGPUs {
							fmt.Printf("%d %s: %dW\n", i.index, i.name(), i.powerUsage())
						}
						fmt.Println("")
						fmt.Println("Press CTRL+C to exit...")
						fmt.Println("Refreshing status in a few seconds...")
						time.Sleep(2 * time.Second)
					}
				}
			} else {
				fmt.Println("No AMD, NVIDIA or Intel GPU(s) detected...")
				return
			}
		},
	}
}
// runShellCmd executes command via `sh -c` and returns its combined
// stdout/stderr output as a string.
func runShellCmd(command string) (string, error) {
	combined, err := exec.Command("sh", "-c", command).CombinedOutput()
	if err != nil {
		return "", fmt.Errorf("unable to get combined output from command %s: %v", command, err)
	}
	return string(combined), nil
}
// getCountAMD returns the number of AMD GPUs reported by `rocm-smi --showid`.
// Each detected GPU appears as "GPU[<n>]" in the tool's output.
func getCountAMD() (int, error) {
	rocmOutput, err := runShellCmd("rocm-smi --showid")
	if err != nil {
		return 0, fmt.Errorf("cannot run shell command: %v", err)
	}
	re := regexp.MustCompile(`GPU\[(\d+)\]`)
	// Count matches directly; the previous code copied the IDs into a slice
	// only to take its length.
	return len(re.FindAllStringSubmatch(rocmOutput, -1)), nil
}
// getCountIntel returns the number of discrete Intel GPUs reported by
// `xpu-smi health -l`, counting "| Device ID | <n> |" rows in its output.
func getCountIntel() (int, error) {
	raw, err := exec.Command("xpu-smi", "health", "-l").CombinedOutput()
	if err != nil {
		return 0, fmt.Errorf("xpu-smi not installed, initialized, or configured: %s", err)
	}
	deviceIDRegex := regexp.MustCompile(`(?i)\| Device ID\s+\|\s+(\d+)\s+\|`)
	return len(deviceIDRegex.FindAllStringSubmatch(string(raw), -1)), nil
}
package cmd
import (
"encoding/json"
"fmt"
"io"
"github.com/buger/jsonparser"
"github.com/olekukonko/tablewriter"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils"
)
// newInfoCmd builds the `info` command, which fetches the onboarding
// configuration from the DMS API and renders it as a table.
func newInfoCmd(client *utils.HTTPClient) *cobra.Command {
	return &cobra.Command{
		Use:   "info",
		Short: "Display information about onboarded device",
		Long:  "Display onboarding config of onboarded device on Nunet Device Management Service",
		RunE: func(cmd *cobra.Command, _ []string) error {
			onboarded, err := checkOnboarded(client)
			if err != nil {
				return fmt.Errorf("could not check onboard status: %w", err)
			}
			if !onboarded {
				return fmt.Errorf("machine is not onboarded")
			}
			resBody, resCode, err := client.MakeRequest("GET", "/onboarding/info", nil)
			if err != nil {
				return fmt.Errorf("could not make request: %w", err)
			}
			if resCode != 200 {
				return fmt.Errorf("request failed with status code: %d", resCode)
			}
			// Only the "info" subtree of the response carries the config.
			rawInfo, _, _, err := jsonparser.Get(resBody, "info")
			if err != nil {
				return fmt.Errorf("unable to parse JSON from key: %w", err)
			}
			config := types.OnboardingConfig{}
			if err := json.Unmarshal(rawInfo, &config); err != nil {
				return fmt.Errorf("could not unmarshal response body: %w", err)
			}
			displayInTable(cmd.OutOrStdout(), &config)
			return nil
		},
	}
}
// displayInTable renders the onboarding config as a two-column
// ("Info"/"Value") table on w.
func displayInTable(w io.Writer, config *types.OnboardingConfig) {
	rows := [][]string{
		{"Name", config.Name},
		{"Update Timestamp", fmt.Sprintf("%d", config.UpdateTimestamp)},
		{"Memory Max", fmt.Sprintf("%d", config.TotalResources.RAM)},
		{"Total Core", fmt.Sprintf("%d", config.TotalResources.NumCores)},
		{"CPU Max", fmt.Sprintf("%.2f", config.TotalResources.CPU)},
		{"Reserved CPU", fmt.Sprintf("%.2f", config.OnboardedResources.CPU)},
		{"Reserved Memory", fmt.Sprintf("%d", config.OnboardedResources.RAM)},
		{"Network", config.Network},
		{"Public Key", config.PublicKey},
		{"Node ID", config.NodeID},
		{"Allow Cardano", fmt.Sprintf("%t", config.AllowCardano)},
		{"Dashboard", config.Dashboard},
		{"NTX Price Per Minute", fmt.Sprintf("%f", config.NTXPricePerMinute)},
	}
	table := tablewriter.NewWriter(w)
	table.SetHeader([]string{"Info", "Value"})
	for _, row := range rows {
		table.Append(row)
	}
	table.Render()
}
package cmd
import (
"fmt"
"math"
"regexp"
"strconv"
"strings"
)
// intelGPU queries a single Intel GPU through the xpu-smi CLI.
type intelGPU struct {
	index int // device index passed to xpu-smi's -d flag
}
// name returns the device name of the Intel GPU as reported by
// `xpu-smi discovery`, or "" if it cannot be determined.
func (i *intelGPU) name() string {
	xpuOutput, err := runShellCmd(fmt.Sprintf("xpu-smi discovery -d %d", i.index))
	if err != nil {
		return ""
	}
	match := regexp.MustCompile(`Device Name:\s+(.+)`).FindStringSubmatch(xpuOutput)
	if len(match) < 2 {
		return ""
	}
	return strings.TrimSpace(match[1])
}
// utilizationRate returns the Intel GPU's utilization percentage as reported
// by `xpu-smi stats`, or 0 on any error.
func (i *intelGPU) utilizationRate() int64 {
	// A plain raw-string literal with "%" — the Sprintf-with-no-args wrapper
	// existed only to turn "%%" into "%" and trips `go vet`.
	pattern := `GPU Utilization \(%\)\s+\|\s+(\d+)`
	re := regexp.MustCompile(pattern)
	xpuOutput, err := runShellCmd(fmt.Sprintf("xpu-smi stats -d %d", i.index))
	if err != nil {
		return 0
	}
	match := re.FindStringSubmatch(xpuOutput)
	if len(match) > 1 {
		utilization, err := strconv.ParseInt(match[1], 10, 32)
		if err != nil {
			return 0
		}
		return utilization
	}
	return 0
}
// memory returns the Intel GPU's memory information converted to bytes, so
// callers can format it with byte-based helpers (humanize.IBytes) exactly as
// the NVML path does; previously the raw MiB figures were returned and thus
// displayed ~2^20 times too small. Returns the zero value on any error.
func (i *intelGPU) memory() memoryInfo {
	const bytesPerMiB = 1 << 20
	reTotal := regexp.MustCompile(`Memory Physical Size:\s+([^\s]+)\s+MiB`)
	reUsed := regexp.MustCompile(`GPU Memory Used \(MiB\)\s+\|\s+(\d+)`)
	// NOTE(review): both values are scraped from `xpu-smi discovery` output,
	// but "GPU Memory Used (MiB)" looks like a `xpu-smi stats` field —
	// confirm it actually appears in discovery output.
	xpuOutput, err := runShellCmd(fmt.Sprintf("xpu-smi discovery -d %d", i.index))
	if err != nil {
		return memoryInfo{}
	}
	matchTotal := reTotal.FindStringSubmatch(xpuOutput)
	matchUsed := reUsed.FindStringSubmatch(xpuOutput)
	if len(matchTotal) > 1 && len(matchUsed) > 1 {
		total, err := strconv.ParseFloat(matchTotal[1], 64)
		if err != nil {
			total = 0
		}
		used, err := strconv.ParseFloat(matchUsed[1], 64)
		if err != nil {
			used = 0
		}
		return memoryInfo{
			used:  uint64(used * bytesPerMiB),
			free:  uint64((total - used) * bytesPerMiB),
			total: uint64(total * bytesPerMiB),
		}
	}
	return memoryInfo{}
}
// powerUsage returns the Intel GPU's power draw in whole watts (rounded) as
// reported by `xpu-smi stats`, or 0 on any error.
func (i *intelGPU) powerUsage() uint32 {
	xpuOutput, err := runShellCmd(fmt.Sprintf("xpu-smi stats -d %d", i.index))
	if err != nil {
		return 0
	}
	match := regexp.MustCompile(`GPU Power \(W\)\s+\|\s+([\d\.]+)`).FindStringSubmatch(xpuOutput)
	if len(match) < 2 {
		return 0
	}
	watts, err := strconv.ParseFloat(match[1], 64)
	if err != nil {
		return 0
	}
	return uint32(math.Round(watts))
}
//go:build linux
package cmd
import (
"fmt"
"path/filepath"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/cmd/backend"
)
const (
	logDir    = "/tmp/nunet-log"    // directory where collected logs and the tarball are written
	dmsUnit   = "nunet-dms.service" // systemd unit whose journal entries are collected
	tarGzName = "nunet-log.tar.gz"  // name of the output tarball inside logDir
)
// NewLogCmd is a constructor for `log` command
// It iterates the systemd journal entries for the nunet-dms unit, writes each
// entry to a file under <logDir>/dms-log, packs the files into a tar.gz, and
// removes the intermediate directory, printing the tarball path on success.
func newLogCmd(afs afero.Afero, logger backend.Logger) *cobra.Command {
	return &cobra.Command{
		Use:   "log",
		Short: "Gather all logs into a tarball. COMMAND MUST RUN AS ROOT WITH SUDO",
		RunE: func(cmd *cobra.Command, _ []string) error {
			dmsLogDir := filepath.Join(logDir, "dms-log")
			fmt.Fprintln(cmd.OutOrStdout(), "Collecting logs...")
			err := afs.MkdirAll(dmsLogDir, 0o777)
			if err != nil {
				return fmt.Errorf("cannot create dms-log directory: %w", err)
			}
			if logger == nil {
				return fmt.Errorf("logger service is not initialised")
			}
			defer logger.Close()
			// filter by service unit name
			match := fmt.Sprintf("_SYSTEMD_UNIT=%s", dmsUnit)
			err = logger.AddMatch(match)
			if err != nil {
				return fmt.Errorf("cannot add unit match: %w", err)
			}
			// counter tracks entries successfully iterated so we can fail when
			// the journal had nothing at all for the unit.
			var counter int
			for {
				count, err := logger.Next()
				if err != nil {
					// NOTE(review): a persistently failing Next() makes this
					// loop spin forever — consider breaking after N errors.
					fmt.Fprintf(cmd.OutOrStderr(), "Error reading next logger entry: %v\n", err)
					continue
				}
				// A zero count signals the end of the journal.
				if count == 0 {
					break
				}
				entry, err := logger.GetEntry()
				if err != nil {
					fmt.Fprintf(cmd.OutOrStderr(), "Error getting logger entry %d: %v\n", count, err)
					continue
				}
				// NOTE(review): when MESSAGE is missing, msg's zero value is
				// still written below — confirm this is intended.
				msg, ok := entry.Fields["MESSAGE"]
				if !ok {
					fmt.Fprintf(cmd.OutOrStderr(), "Error: no message field in entry %d\n", count)
				}
				logData := fmt.Sprintf("%d: %s\n", entry.RealtimeTimestamp, msg)
				logFilePath := filepath.Join(dmsLogDir, fmt.Sprintf("dms_log.%d", count))
				err = appendToFile(afs, logFilePath, logData)
				if err != nil {
					// Write failures are reported but do not abort collection.
					fmt.Fprintf(cmd.OutOrStderr(), "Error writing log file for boot %d: %v\n", count, err)
				}
				counter++
			}
			if counter == 0 {
				return fmt.Errorf("no log entries for %s", dmsUnit)
			}
			tarGzFile := filepath.Join(logDir, tarGzName)
			err = createTar(afs, tarGzFile, dmsLogDir)
			if err != nil {
				return fmt.Errorf("cannot create tar.gz file: %w", err)
			}
			// Remove the per-entry files once archived.
			err = afs.RemoveAll(dmsLogDir)
			if err != nil {
				return fmt.Errorf("remove dms-log directory failed: %w", err)
			}
			// Print the tarball path so callers can locate the archive.
			fmt.Fprintln(cmd.OutOrStdout(), tarGzFile)
			return nil
		},
	}
}
package cmd
import (
"fmt"
"github.com/NVIDIA/go-nvml/pkg/nvml"
)
const (
	// sensorNVML is the NVML temperature sensor ID queried by temperature()
	// (iota == 0 — presumably the on-die GPU sensor; confirm against NVML's
	// TemperatureSensors constants).
	sensorNVML nvml.TemperatureSensors = iota
)

// nvidiaGPU queries a single NVIDIA GPU through NVML.
type nvidiaGPU struct {
	index int // NVML device index passed to DeviceGetHandleByIndex
}
// getDevice resolves the NVML device handle for this GPU's index.
func (n *nvidiaGPU) getDevice() (nvml.Device, error) {
	device, ret := nvml.DeviceGetHandleByIndex(n.index)
	if ret == nvml.SUCCESS {
		return device, nil
	}
	return nil, fmt.Errorf("failed to get device (index %d) handle: %s", n.index, nvml.ErrorString(ret))
}
// name returns the GPU's product name, or "" if the NVML lookup fails.
func (n *nvidiaGPU) name() string {
	device, err := n.getDevice()
	if err != nil {
		return ""
	}
	if deviceName, ret := device.GetName(); ret == nvml.SUCCESS {
		return deviceName
	}
	return ""
}
// utilizationRate returns the GPU's utilization percentage, or 0 on error.
func (n *nvidiaGPU) utilizationRate() uint32 {
	device, err := n.getDevice()
	if err != nil {
		return 0
	}
	if rates, ret := device.GetUtilizationRates(); ret == nvml.SUCCESS {
		return rates.Gpu
	}
	return 0
}
// memory returns the GPU's memory usage in bytes as reported by NVML, or the
// zero value on error.
func (n *nvidiaGPU) memory() memoryInfo {
	device, err := n.getDevice()
	if err != nil {
		return memoryInfo{}
	}
	info, ret := device.GetMemoryInfo()
	if ret != nvml.SUCCESS {
		return memoryInfo{}
	}
	return memoryInfo{
		used:  info.Used,
		free:  info.Free,
		total: info.Total,
	}
}
// temperature returns the GPU temperature from the sensorNVML sensor in
// degrees, or 0 on error.
func (n *nvidiaGPU) temperature() float64 {
	device, err := n.getDevice()
	if err != nil {
		return 0
	}
	if temp, ret := device.GetTemperature(sensorNVML); ret == nvml.SUCCESS {
		return float64(temp)
	}
	return 0
}
// powerUsage returns the GPU's power draw as reported by NVML, or 0 on error.
func (n *nvidiaGPU) powerUsage() uint32 {
	device, err := n.getDevice()
	if err != nil {
		return 0
	}
	if power, ret := device.GetPowerUsage(); ret == nvml.SUCCESS {
		return power
	}
	return 0
}
package cmd
import (
"fmt"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/utils"
)
// newOffboardCmd is a constructor for the `offboard` command, which removes
// the machine from NuNet after an interactive confirmation. The --force flag
// is forwarded to the API as a query parameter.
func newOffboardCmd(client *utils.HTTPClient) *cobra.Command {
	fnForce := "force"
	cmd := &cobra.Command{
		Use:   "offboard",
		Short: "Offboard the device from NuNet",
		Long:  ``,
		RunE: func(cmd *cobra.Command, _ []string) error {
			onboarded, err := checkOnboarded(client)
			if err != nil {
				return fmt.Errorf("could not check onboard status: %w", err)
			}
			if !onboarded {
				return fmt.Errorf("machine is not onboarded")
			}
			// Write the warning to the command's output stream (not raw
			// os.Stdout) so it is consistent with the rest of the CLI and
			// visible under cobra's output redirection.
			fmt.Fprintln(cmd.OutOrStdout(), "Warning: Offboarding will remove all your data and you will not be able to onboard again with the same identity")
			answer, err := utils.PromptYesNo(cmd.InOrStdin(), cmd.OutOrStdout(), "Are you sure you want to offboard?")
			if err != nil {
				return fmt.Errorf("unable to read response: %w", err)
			}
			if !answer {
				return nil
			}
			force, _ := cmd.Flags().GetBool(fnForce)
			query := fmt.Sprintf("?force=%t", force)
			resBody, resCode, err := client.MakeRequest("DELETE", "/onboarding/offboard"+query, nil)
			if err != nil {
				return fmt.Errorf("could not make request: %w", err)
			}
			if resCode != 200 {
				return fmt.Errorf("request failed with status code: %d", resCode)
			}
			// TODO:what to do with the response body?
			fmt.Fprintf(cmd.OutOrStdout(), "%s\n", resBody)
			return nil
		},
	}
	cmd.Flags().BoolP(fnForce, "f", false, "force offboarding")
	return cmd
}
package cmd
import (
"fmt"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/dms/resources"
"gitlab.com/nunet/device-management-service/executor/docker"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils"
)
// imagesNVIDIA lists the ML container images pulled for NVIDIA GPUs.
var imagesNVIDIA = []string{
	"registry.gitlab.com/nunet/ml-on-gpu/ml-on-gpu-service/develop/tensorflow",
	"registry.gitlab.com/nunet/ml-on-gpu/ml-on-gpu-service/develop/pytorch",
}

// imagesAMD lists the ML container images pulled for AMD GPUs.
var imagesAMD = []string{
	"registry.gitlab.com/nunet/ml-on-gpu/ml-on-gpu-service/develop/tensorflow-amd",
	"registry.gitlab.com/nunet/ml-on-gpu/ml-on-gpu-service/develop/pytorch-amd",
}
// newOnboardMLCmd builds the `onboard-ml` command, which pre-pulls the ML
// container images matching the GPU vendors present on the machine.
func newOnboardMLCmd(afs afero.Afero, dockerClient *docker.Client) *cobra.Command {
	return &cobra.Command{
		Use:   "onboard-ml",
		Short: "Setup for Machine Learning with GPU",
		Long:  ``,
		RunE: func(cmd *cobra.Command, _ []string) error {
			wsl, err := utils.CheckWSL(afs)
			if err != nil {
				return fmt.Errorf("could not check WSL: %w", err)
			}
			vendors, err := resources.ManagerInstance.SystemSpecs().GetGPUVendors()
			if err != nil {
				return fmt.Errorf("could not fetch GPU vendors: %w", err)
			}
			// check for GPU vendors
			hasAMD := containsVendor(vendors, types.GPUVendorAMDATI)
			hasNVIDIA := containsVendor(vendors, types.GPUVendorNvidia)
			hasIntel := containsVendor(vendors, types.GPUVendorIntel)
			if !hasAMD && !hasNVIDIA && !hasIntel {
				return fmt.Errorf("no NVIDIA/AMD/Intel GPU(s) detected")
			}
			if wsl {
				fmt.Fprintf(cmd.OutOrStdout(), "You are running on Windows Subsystem for Linux (WSL)\nMake sure that NVIDIA drivers are set up correctly\n\nWARNING: AMD GPUs are not supported on WSL!\n")
			}
			// pullAll pulls each image and reports its digest, labeling the
			// output with the GPU vendor name. Replaces two identical loops.
			pullAll := func(label string, images []string) error {
				for _, image := range images {
					digest, err := dockerClient.PullImage(cmd.Context(), image)
					if err != nil {
						return fmt.Errorf("could not pull image %s: %w", image, err)
					}
					fmt.Fprintf(cmd.OutOrStdout(), "Pulled %s image %s with digest %s\n", label, image, digest)
				}
				return nil
			}
			if hasNVIDIA {
				if err := pullAll("Nvidia", imagesNVIDIA); err != nil {
					return err
				}
			}
			if hasAMD {
				if err := pullAll("AMD", imagesAMD); err != nil {
					return err
				}
			}
			return nil
		},
	}
}
package cmd
import (
"encoding/json"
"fmt"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils"
)
// newOnboardCmd is a constructor for the `onboard` command, which reserves
// local resources and registers the machine with NuNet via the DMS API.
func newOnboardCmd(client *utils.HTTPClient) *cobra.Command {
	fnCPU := "cpu"
	fnMemory := "memory"
	fnChannel := "nunet-channel"
	fnAddr := "address"
	fnPlugin := "plugin"
	fnNTXPrice := "ntx-price"
	fnLocal := "local-enable"
	fnCardano := "cardano"
	fnUnavailable := "unavailable"
	cmd := &cobra.Command{
		Use:   "onboard",
		Short: "Onboard current machine to NuNet",
		RunE: func(cmd *cobra.Command, _ []string) error {
			fmt.Fprintln(cmd.OutOrStdout(), "Checking onboard status...")
			onboarded, err := checkOnboarded(client)
			if err != nil {
				return fmt.Errorf("could not check onboard status: %w", err)
			}
			if onboarded {
				// Already onboarded: confirm before overwriting the existing
				// identity.
				if err := promptReonboard(cmd.InOrStdin(), cmd.OutOrStdout()); err != nil {
					return err
				}
			}
			// Flag lookups only fail on programmer error (unknown name or
			// wrong type). The previous code passed these errors to
			// fmt.Println with a %w verb — which Println does not interpret,
			// printing a literal "%w" — and then continued with zero values.
			// Treat them as fatal instead.
			memory, err := cmd.Flags().GetUint64(fnMemory)
			if err != nil {
				return fmt.Errorf("couldn't get 'memory' flag: %w", err)
			}
			cpu, err := cmd.Flags().GetInt64(fnCPU)
			if err != nil {
				return fmt.Errorf("couldn't get 'cpu' flag: %w", err)
			}
			channel, err := cmd.Flags().GetString(fnChannel)
			if err != nil {
				return fmt.Errorf("couldn't get 'channel' flag: %w", err)
			}
			addr, err := cmd.Flags().GetString(fnAddr)
			if err != nil {
				return fmt.Errorf("couldn't get 'addr' flag: %w", err)
			}
			ntx, err := cmd.Flags().GetFloat64(fnNTXPrice)
			if err != nil {
				return fmt.Errorf("couldn't get 'ntx' flag: %w", err)
			}
			local, err := cmd.Flags().GetBool(fnLocal)
			if err != nil {
				return fmt.Errorf("couldn't get 'local' flag: %w", err)
			}
			unavailable, err := cmd.Flags().GetBool(fnUnavailable)
			if err != nil {
				return fmt.Errorf("couldn't get 'unavailable' flag: %w", err)
			}
			cardano, err := cmd.Flags().GetBool(fnCardano)
			if err != nil {
				return fmt.Errorf("couldn't get 'cardano' flag: %w", err)
			}
			reserved := types.CapacityForNunet{
				Memory:            memory,
				CPU:               cpu,
				Channel:           channel,
				PaymentAddress:    addr,
				NTXPricePerMinute: ntx,
				Cardano:           cardano,
				ServerMode:        local,
				// NOTE(review): passing `unavailable` straight through looks
				// inverted for a field named IsAvailable — confirm intent.
				IsAvailable: unavailable, // TODO: Update this
			}
			onboardJSON, err := json.Marshal(reserved)
			if err != nil {
				return fmt.Errorf("unable to marshal JSON data: %w", err)
			}
			resBody, resCode, err := client.MakeRequest("POST", "/onboarding/onboard", onboardJSON)
			if err != nil {
				return fmt.Errorf("could not make request: %w", err)
			}
			if resCode != 201 {
				return fmt.Errorf("request failed with status code: %d", resCode)
			}
			// TODO: what to do with the response body?
			fmt.Fprintf(cmd.OutOrStdout(), "%s\n", resBody)
			fmt.Fprintln(cmd.OutOrStdout(), "Successfully onboarded!")
			return nil
		},
	}
	cmd.Flags().Uint64P(fnMemory, "m", 0, "set value for memory usage")
	cmd.Flags().Int64P(fnCPU, "c", 0, "set value for CPU usage")
	cmd.Flags().StringP(fnChannel, "n", "", "set channel")
	cmd.Flags().StringP(fnAddr, "a", "", "set wallet address")
	cmd.Flags().Float64P(fnNTXPrice, "x", 0, "price in NTX per minute for onboarded compute resource")
	cmd.Flags().StringP(fnPlugin, "p", "", "set plugin")
	cmd.Flags().BoolP(fnUnavailable, "u", false, "unavailable for job deployment (default: false)")
	cmd.Flags().BoolP(fnLocal, "l", true, "set server mode (enable for local)")
	cmd.Flags().BoolP(fnCardano, "C", false, "set Cardano wallet")
	cmd.MarkFlagsRequiredTogether("memory", "cpu", "nunet-channel", "address")
	return cmd
}
package cmd
import (
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/utils"
)
// newPeerCmd builds the `peer` parent command. It only groups the
// peer-related subcommands and prints help when invoked directly.
func newPeerCmd(client *utils.HTTPClient) *cobra.Command {
	peerCmd := &cobra.Command{
		Use:   "peer",
		Short: "Peer-related operations",
		Long:  ``,
		Run: func(cmd *cobra.Command, _ []string) {
			_ = cmd.Help()
		},
	}
	for _, sub := range []*cobra.Command{
		newPeerListCmd(client),
		newPeerSelfCmd(client),
	} {
		peerCmd.AddCommand(sub)
	}
	return peerCmd
}
package cmd
import (
"fmt"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/utils"
)
// newPeerListCmd is a constructor for the `peer list` subcommand.
// It prints the bootstrap peers (unless --dht is set) followed by the DHT
// peers known to the running DMS. The machine must already be onboarded.
func newPeerListCmd(client *utils.HTTPClient) *cobra.Command {
	const dhtFlag = "dht"
	listCmd := &cobra.Command{
		Use:   "list",
		Short: "Display list of peers in the network",
		Long:  ``,
		RunE: func(cmd *cobra.Command, _ []string) error {
			out := cmd.OutOrStdout()
			onboarded, err := checkOnboarded(client)
			if err != nil {
				return fmt.Errorf("could not check onboard status: %w", err)
			}
			if !onboarded {
				return fmt.Errorf("machine is not onboarded")
			}
			dhtOnly, _ := cmd.Flags().GetBool(dhtFlag)
			if !dhtOnly {
				bootPeers, err := getBootstrapPeers(out, client)
				if err != nil {
					return fmt.Errorf("could not fetch bootstrap peers: %w", err)
				}
				fmt.Fprintf(out, "Bootstrap peers (%d)\n", len(bootPeers))
				for _, peer := range bootPeers {
					fmt.Fprintf(out, "%s\n", peer)
				}
				fmt.Fprintf(out, "\n")
			}
			dhtPeers, err := getDHTPeers(client)
			if err != nil {
				return fmt.Errorf("could not fetch DHT peers: %w", err)
			}
			fmt.Fprintf(out, "DHT peers (%d)\n", len(dhtPeers))
			for _, peer := range dhtPeers {
				fmt.Fprintf(out, "%s\n", peer)
			}
			return nil
		},
	}
	listCmd.Flags().BoolP(dhtFlag, "d", false, "list only DHT peers")
	return listCmd
}
package cmd
import (
"encoding/json"
"fmt"
"github.com/buger/jsonparser"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/utils"
)
// newPeerSelfCmd is a constructor for the `peer self` subcommand.
// It queries the running DMS for the local peer's ID and listening addresses.
func newPeerSelfCmd(client *utils.HTTPClient) *cobra.Command {
	selfCmd := &cobra.Command{
		Use:   "self",
		Short: "Display self peer info",
		Long:  ``,
		RunE: func(cmd *cobra.Command, _ []string) error {
			onboarded, err := checkOnboarded(client)
			if err != nil {
				return fmt.Errorf("could not check onboard status: %w", err)
			}
			if !onboarded {
				return fmt.Errorf("machine is not onboarded")
			}
			resBody, resCode, err := client.MakeRequest("GET", "/peers/self", nil)
			if err != nil {
				return fmt.Errorf("unable to make internal request: %w", err)
			}
			if resCode != 200 {
				return fmt.Errorf("request failed with status code: %d", resCode)
			}
			response := map[string]any{}
			if err := json.Unmarshal(resBody, &response); err != nil {
				return fmt.Errorf("could not unmarshal response body: %w", err)
			}
			selfID, ok := response["id"]
			if !ok {
				return fmt.Errorf("no self peer ID returned")
			}
			// listen_addr is printed verbatim as raw JSON rather than decoded.
			listenAddrs, _, _, err := jsonparser.Get(resBody, "listen_addr")
			if err != nil {
				return fmt.Errorf("failed to get addresses field: %w", err)
			}
			out := cmd.OutOrStdout()
			fmt.Fprintln(out, "Host ID:", selfID)
			fmt.Fprintln(out, "Listening Addresses:", string(listenAddrs))
			return nil
		},
	}
	return selfCmd
}
package cmd
import (
"encoding/json"
"fmt"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils"
)
// newResourceConfigCmd is a constructor for the `resource-config` command.
// It posts updated memory / CPU / NTX-price values for an already-onboarded
// device to the onboarding API.
func newResourceConfigCmd(client *utils.HTTPClient) *cobra.Command {
	const (
		flagMemory   = "memory"
		flagCPU      = "cpu"
		flagNTXPrice = "ntx-price"
	)
	cfgCmd := &cobra.Command{
		Use:   "resource-config",
		Short: "Update configuration of onboarded device",
		RunE: func(cmd *cobra.Command, _ []string) error {
			onboarded, err := checkOnboarded(client)
			if err != nil {
				return fmt.Errorf("could not check onboard status: %w", err)
			}
			if !onboarded {
				return fmt.Errorf("machine is not onboarded")
			}
			// Lookup errors are ignored: these flags are registered below, so
			// a failure here would be a programming error.
			memory, _ := cmd.Flags().GetUint64(flagMemory)
			cpu, _ := cmd.Flags().GetInt64(flagCPU)
			ntx, _ := cmd.Flags().GetFloat64(flagNTXPrice)
			payload, err := json.Marshal(types.CapacityForNunet{
				Memory:            memory,
				CPU:               cpu,
				NTXPricePerMinute: ntx,
			})
			if err != nil {
				return fmt.Errorf("unable to marshal JSON data: %w", err)
			}
			resBody, resCode, err := client.MakeRequest("POST", "/onboarding/resource-config", payload)
			if err != nil {
				return fmt.Errorf("unable to make request: %w", err)
			}
			if resCode != 200 {
				return fmt.Errorf("request failed with status code: %d", resCode)
			}
			out := cmd.OutOrStdout()
			fmt.Fprintln(out, "Resources updated successfully!")
			fmt.Fprintln(out, string(resBody))
			return nil
		},
	}
	cfgCmd.Flags().Uint64P(flagMemory, "m", 0, "set amount of memory")
	cfgCmd.Flags().Int64P(flagCPU, "c", 0, "set amount of CPU")
	cfgCmd.Flags().Float64P(flagNTXPrice, "x", 0, "Set NTX Price per minute for compute resources you are updating")
	cfgCmd.MarkFlagsRequiredTogether("cpu", "memory")
	return cfgCmd
}
package cmd
import (
"fmt"
"github.com/coreos/go-systemd/sdjournal"
"github.com/spf13/afero"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/api/docs"
"gitlab.com/nunet/device-management-service/cmd/backend"
"gitlab.com/nunet/device-management-service/executor/docker"
"gitlab.com/nunet/device-management-service/internal/config"
"gitlab.com/nunet/device-management-service/utils"
)
// newRootCmd assembles the top-level `nunet` command and attaches every
// subcommand of the CLI to it. Running it without a subcommand prints help.
func newRootCmd(client *utils.HTTPClient, afs afero.Afero, dockerExec *docker.Client, logger backend.Logger) *cobra.Command {
	root := &cobra.Command{
		Use:     "nunet",
		Short:   "NuNet Device Management Service",
		Long:    `The Device Management Service (DMS) Command Line Interface (CLI)`,
		Version: docs.SwaggerInfo.Version,
		CompletionOptions: cobra.CompletionOptions{
			DisableDefaultCmd: false,
			HiddenDefaultCmd:  true,
		},
		SilenceErrors: true,
		SilenceUsage:  true,
		Run: func(cmd *cobra.Command, _ []string) {
			_ = cmd.Help()
		},
	}
	subcommands := []*cobra.Command{
		newGPUCmd(afs),
		newOffboardCmd(client),
		newOnboardMLCmd(afs, dockerExec),
		newRunCmd(),
		newPeerCmd(client),
		newOnboardCmd(client),
		newInfoCmd(client),
		newDeviceCmd(client),
		newCapacityCmd(client),
		newResourceConfigCmd(client),
		newLogCmd(afs, logger),
		newWalletCmd(client),
		newVersionCmd(),
		newAutoCompleteCmd(),
	}
	for _, sub := range subcommands {
		root.AddCommand(sub)
	}
	return root
}
// Execute is a wrapper for the cobra.Command method of the same name.
// It wires up the filesystem, HTTP client, docker client, and journal reader
// needed by the CLI, relying on cobra.CheckErr to abort on any setup or run
// failure.
func Execute() {
	afs := afero.Afero{Fs: afero.NewOsFs()}
	client := utils.NewHTTPClient(
		fmt.Sprintf("http://%s:%d", config.GetConfig().Addr, config.GetConfig().Port),
		"/api/v1",
	)
	dockerClient, err := docker.NewDockerClient()
	if err != nil {
		cobra.CheckErr(fmt.Errorf("couldn't instantiate new docker client; Error: %w", err))
	}
	journal, err := sdjournal.NewJournal()
	if err != nil {
		cobra.CheckErr(fmt.Errorf("failed to get new sdjournal; Error: %w", err))
	}
	cobra.CheckErr(newRootCmd(client, afs, dockerClient, journal).Execute())
}
package cmd
import (
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/dms"
)
// newRunCmd builds the `run` command, which starts the DMS daemon itself.
func newRunCmd() *cobra.Command {
	runCmd := &cobra.Command{
		Use:   "run",
		Short: "Start the Device Management Service",
		Long:  `The Device Management Service (DMS) is a system application for computing and service providers. It handles networking and device management.`,
		Run: func(_ *cobra.Command, _ []string) {
			dms.Run()
		},
	}
	return runCmd
}
package cmd
import (
"archive/tar"
"compress/gzip"
"errors"
"fmt"
"io"
"os"
"strings"
"github.com/buger/jsonparser"
"github.com/olekukonko/tablewriter"
"github.com/spf13/afero"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils"
)
// checkOnboarded asks the onboarding status endpoint whether this machine has
// already been onboarded, returning the boolean reported by the API.
func checkOnboarded(client *utils.HTTPClient) (bool, error) {
	body, code, err := client.MakeRequest("GET", "/onboarding/status", nil)
	if err != nil {
		return false, fmt.Errorf("unable to make request: %w", err)
	}
	if code != 200 {
		return false, fmt.Errorf("request failed with status code: %d", code)
	}
	status, err := jsonparser.GetBoolean(body, "onboarded")
	if err != nil {
		return false, fmt.Errorf("could not get onboard status: %w", err)
	}
	return status, nil
}
// promptReonboard wraps utils.PromptYesNo with a custom prompt and returns an
// error when the user declines reonboarding (or when the prompt itself fails).
func promptReonboard(r io.Reader, w io.Writer) error {
	const prompt = "Looks like your machine is already onboarded. Proceed with reonboarding?"
	confirmed, err := utils.PromptYesNo(r, w, prompt)
	switch {
	case err != nil:
		return fmt.Errorf("could not confirm reonboarding: %w", err)
	case !confirmed:
		return fmt.Errorf("reonboarding aborted by user")
	default:
		return nil
	}
}
// getDHTPeers fetches the API to retrieve info about DHT peers. An empty peer
// list is reported as an error.
func getDHTPeers(client *utils.HTTPClient) ([]string, error) {
	body, code, err := client.MakeRequest("GET", "/peers/dht", nil)
	if err != nil {
		return nil, fmt.Errorf("cannot make request: %w", err)
	}
	if code != 200 {
		return nil, fmt.Errorf("request failed with status code: %d", code)
	}
	// A "message" field in the response carries a server-side error text.
	if msg, msgErr := jsonparser.GetString(body, "message"); msgErr == nil {
		return nil, errors.New(msg)
	}
	var peers []string
	_, err = jsonparser.ArrayEach(body, func(value []byte, _ jsonparser.ValueType, _ int, _ error) {
		peers = append(peers, string(value))
	})
	if err != nil {
		return nil, fmt.Errorf("cannot iterate over DHT peer list: %w", err)
	}
	if len(peers) == 0 {
		return nil, fmt.Errorf("no DHT peers available")
	}
	return peers, nil
}
// getBootstrapPeers fetches the API to retrieve data about bootstrap peers.
// Malformed entries (missing "ID") are reported on w and returned as an error
// instead of terminating the process: the previous implementation called
// os.Exit(1) from inside the iteration callback, which skipped deferred
// cleanup in callers and made the failure impossible to handle or test.
func getBootstrapPeers(w io.Writer, client *utils.HTTPClient) ([]string, error) {
	resBody, resCode, err := client.MakeRequest("GET", "/peers", nil)
	if err != nil {
		return nil, fmt.Errorf("unable to make request: %w", err)
	}
	if resCode != 200 {
		return nil, fmt.Errorf("request failed with status code: %d", resCode)
	}
	// A "message" field in the response carries a server-side error text.
	msg, err := jsonparser.GetString(resBody, "message")
	if err == nil {
		return nil, errors.New(msg)
	}
	var (
		bootSlice []string
		idErr     error
	)
	if _, err = jsonparser.ArrayEach(resBody, func(value []byte, _ jsonparser.ValueType, _ int, _ error) {
		id, err := jsonparser.GetString(value, "ID")
		if err != nil {
			// Remember the first failure; remaining entries are still scanned
			// but the call as a whole fails below.
			if idErr == nil {
				idErr = fmt.Errorf("error getting bootstrap peer ID string: %w", err)
			}
			return
		}
		bootSlice = append(bootSlice, id)
	}); err != nil {
		return nil, fmt.Errorf("cannot iterate over bootstrap peer list: %w", err)
	}
	if idErr != nil {
		fmt.Fprintln(w, idErr)
		return nil, idErr
	}
	if len(bootSlice) == 0 {
		return nil, fmt.Errorf("no bootstrap peers available")
	}
	return bootSlice, nil
}
// printWallet renders a types.BlockchainAddressPrivKey in a YAML-like format
// for readability, omitting every field that is empty.
func printWallet(w io.Writer, pair *types.BlockchainAddressPrivKey) {
	fields := []struct {
		label string
		value string
	}{
		{"address", pair.Address},
		{"private_key", pair.PrivateKey},
		{"mnemonic", pair.Mnemonic},
	}
	for _, f := range fields {
		if f.value != "" {
			fmt.Fprintf(w, "%s: %s\n", f.label, f.value)
		}
	}
}
// setupTable builds a table writer with the standard resource columns used by
// the CLI, merging repeated values in the first column and keeping header
// capitalization as written.
func setupTable(w io.Writer) *tablewriter.Table {
	t := tablewriter.NewWriter(w)
	t.SetHeader([]string{"Resources", "Memory", "CPU", "Cores"})
	t.SetAutoMergeCellsByColumnIndex([]int{0})
	t.SetAutoFormatHeaders(false)
	return t
}
// appendToFile opens filename (creating it with 0644 if necessary) and appends
// the given string data to it.
func appendToFile(afs afero.Afero, filename, data string) error {
	// nolint:gofumpt
	file, err := afs.OpenFile(filename, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o644)
	if err != nil {
		return fmt.Errorf("open %s file failed: %w", filename, err)
	}
	defer file.Close()
	if _, err := file.WriteString(data); err != nil {
		return fmt.Errorf("write string data to file %s failed: %w", filename, err)
	}
	return nil
}
// createTar walks sourceDir and writes its contents to tarGzPath as a
// gzip-compressed tar archive.
//
// Notes:
//   - The output file itself is skipped during the walk, so the archive never
//     tries to include itself when tarGzPath lives inside sourceDir.
//   - Entry names are paths with the sourceDir prefix trimmed; the root entry
//     ("" or "/") is dropped.
//   - Only regular files have their contents copied; everything else gets a
//     header-only entry. Symlink targets are not preserved — FileInfoHeader is
//     called with info.Name() as the link argument.
func createTar(afs afero.Afero, tarGzPath string, sourceDir string) error {
	tarGzFile, err := afs.Create(tarGzPath)
	if err != nil {
		return fmt.Errorf("create %s file failed: %w", tarGzPath, err)
	}
	// Deferred closes run in reverse order: tar writer first, then gzip, then
	// the file — each layer is flushed before the layer beneath it closes.
	defer tarGzFile.Close()
	gzWriter := gzip.NewWriter(tarGzFile)
	defer gzWriter.Close()
	tarWriter := tar.NewWriter(gzWriter)
	defer tarWriter.Close()
	return afs.Walk(sourceDir, func(path string, info os.FileInfo, err error) error {
		if err != nil {
			return err
		}
		// Never archive the archive itself.
		if path == tarGzPath {
			return nil
		}
		header, err := tar.FileInfoHeader(info, info.Name())
		if err != nil {
			return err
		}
		// Store entries relative to the source directory.
		header.Name = strings.TrimPrefix(path, sourceDir)
		if header.Name == "" || header.Name == "/" {
			return nil
		}
		err = tarWriter.WriteHeader(header)
		if err != nil {
			return err
		}
		// Only regular files carry data.
		if info.Mode().IsRegular() {
			data, err := afs.ReadFile(path)
			if err != nil {
				return err
			}
			_, err = tarWriter.Write(data)
			if err != nil {
				return err
			}
		}
		return nil
	})
}
package cmd
import (
"fmt"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/api/docs"
)
// newVersionCmd builds the `version` command, which prints the DMS version
// taken from the generated swagger docs.
func newVersionCmd() *cobra.Command {
	versionCmd := &cobra.Command{
		Use:   "version",
		Short: "Display the Nunet DMS version",
		Long:  `This command prints the version of the Nunet Device Management Service.`,
		Run: func(_ *cobra.Command, _ []string) {
			fmt.Printf("Nunet Device Management Service Version: %s\n", docs.SwaggerInfo.Version)
		},
	}
	return versionCmd
}
package cmd
import (
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/utils"
)
// newWalletCmd is a constructor for the `wallet` parent command. It only
// groups the wallet subcommands and prints help when run directly.
func newWalletCmd(client *utils.HTTPClient) *cobra.Command {
	walletCmd := &cobra.Command{
		Use:   "wallet",
		Short: "Wallet Management",
		Run: func(cmd *cobra.Command, _ []string) {
			_ = cmd.Help()
		},
	}
	walletCmd.AddCommand(newWalletNewCmd(client))
	return walletCmd
}
package cmd
import (
"encoding/json"
"fmt"
"github.com/spf13/cobra"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils"
)
// newWalletNewCmd is a constructor for the `wallet new` command. Exactly one
// of --ethereum / --cardano must be provided; they are mutually exclusive.
func newWalletNewCmd(client *utils.HTTPClient) *cobra.Command {
	const (
		flagEthereum = "ethereum"
		flagCardano  = "cardano"
	)
	newCmd := &cobra.Command{
		Use:   "new",
		Short: "Create new wallet",
		RunE: func(cmd *cobra.Command, _ []string) error {
			eth, _ := cmd.Flags().GetBool(flagEthereum)
			// Only the ethereum flag alters the request; otherwise the API is
			// called without a blockchain query parameter.
			query := ""
			if eth {
				query = "?blockchain=ethereum"
			}
			resBody, resCode, err := client.MakeRequest("GET", "/onboarding/address/new"+query, nil)
			if err != nil {
				return fmt.Errorf("could not make request: %w", err)
			}
			if resCode != 200 {
				return fmt.Errorf("request failed with status code: %d", resCode)
			}
			var pair *types.BlockchainAddressPrivKey
			if err := json.Unmarshal(resBody, &pair); err != nil {
				return fmt.Errorf("could not unmarshal response body: %w", err)
			}
			printWallet(cmd.OutOrStdout(), pair)
			return nil
		},
	}
	newCmd.Flags().BoolP(flagEthereum, "e", false, "create Ethereum wallet")
	newCmd.Flags().BoolP(flagCardano, "c", false, "create Cardano wallet")
	newCmd.MarkFlagsOneRequired(flagEthereum, flagCardano)
	newCmd.MarkFlagsMutuallyExclusive(flagEthereum, flagCardano)
	return newCmd
}
package clover
import (
"fmt"
clover "github.com/ostafen/clover/v2"
)
// NewDB initializes and sets up the clover database using bbolt under the hood.
// Additionally, it automatically creates collections for the necessary types.
// Collections that already exist (e.g. when reopening an existing database
// file) are left untouched instead of failing the whole setup, and the handle
// is closed on any error so the underlying file lock is released.
func NewDB(path string, collections []string) (*clover.DB, error) {
	db, err := clover.Open(path)
	if err != nil {
		return nil, fmt.Errorf("failed to connect to database: %w", err)
	}
	for _, collection := range collections {
		exists, err := db.HasCollection(collection)
		if err != nil {
			_ = db.Close()
			return nil, fmt.Errorf("failed to check collection %s: %w", collection, err)
		}
		if exists {
			// CreateCollection errors on existing collections, which would
			// break every reopen of a previously initialized database.
			continue
		}
		if err := db.CreateCollection(collection); err != nil {
			_ = db.Close()
			return nil, fmt.Errorf("failed to create collection %s: %w", collection, err)
		}
	}
	return db, nil
}
package clover
import (
clover "github.com/ostafen/clover/v2"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// DeploymentRequestFlatClover is a Clover-backed implementation of the
// repositories.DeploymentRequestFlat interface.
type DeploymentRequestFlatClover struct {
	repositories.GenericRepository[types.DeploymentRequestFlat]
}

// NewDeploymentRequestFlat returns a Clover-based repository for
// DeploymentRequestFlat entities stored in db.
func NewDeploymentRequestFlat(db *clover.DB) repositories.DeploymentRequestFlat {
	repo := NewGenericRepository[types.DeploymentRequestFlat](db)
	return &DeploymentRequestFlatClover{GenericRepository: repo}
}
package clover
import (
"github.com/ostafen/clover/v2"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// RequestTrackerClover is a Clover-backed implementation of the
// repositories.RequestTracker interface.
type RequestTrackerClover struct {
	repositories.GenericRepository[types.RequestTracker]
}

// NewRequestTracker returns a Clover-based repository for RequestTracker
// entities stored in db.
func NewRequestTracker(db *clover.DB) repositories.RequestTracker {
	repo := NewGenericRepository[types.RequestTracker](db)
	return &RequestTrackerClover{GenericRepository: repo}
}
package clover
import (
"github.com/ostafen/clover/v2"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// VirtualMachineClover is a Clover-backed implementation of the
// repositories.VirtualMachine interface.
type VirtualMachineClover struct {
	repositories.GenericRepository[types.VirtualMachine]
}

// NewVirtualMachine returns a Clover-based repository for VirtualMachine
// entities stored in db.
func NewVirtualMachine(db *clover.DB) repositories.VirtualMachine {
	repo := NewGenericRepository[types.VirtualMachine](db)
	return &VirtualMachineClover{GenericRepository: repo}
}
package clover
import (
"context"
"fmt"
"reflect"
"time"
"github.com/iancoleman/strcase"
clover "github.com/ostafen/clover/v2"
clover_d "github.com/ostafen/clover/v2/document"
clover_q "github.com/ostafen/clover/v2/query"
"gitlab.com/nunet/device-management-service/db/repositories"
)
// GenericEntityRepositoryClover is a generic single-entity repository backed
// by Clover. Embed it in single-entity model repositories to inherit the
// basic database operations.
type GenericEntityRepositoryClover[T repositories.ModelType] struct {
	db         *clover.DB // Clover database handle
	collection string     // name of the backing collection
}

// NewGenericEntityRepository creates a new GenericEntityRepositoryClover.
// The collection name is derived from the snake_cased name of model type T.
func NewGenericEntityRepository[T repositories.ModelType](
	db *clover.DB,
) repositories.GenericEntityRepository[T] {
	var zero T
	name := strcase.ToSnake(reflect.TypeOf(zero).Name())
	return &GenericEntityRepositoryClover[T]{
		db:         db,
		collection: name,
	}
}
// GetQuery returns a clean Query instance for building queries.
func (repo *GenericEntityRepositoryClover[T]) GetQuery() repositories.Query[T] {
	var q repositories.Query[T]
	return q
}

// query starts a fresh Clover query scoped to this repository's collection.
func (repo *GenericEntityRepositoryClover[T]) query() *clover_q.Query {
	return clover_q.NewQuery(repo.collection)
}
// Save creates or updates the record in the repository and returns the stored
// model. Each call inserts a new document stamped with CreatedAt; Get reads
// the newest one, so earlier documents remain available via History.
func (repo *GenericEntityRepositoryClover[T]) Save(_ context.Context, data T) (T, error) {
	doc := toCloverDoc(data)
	doc.Set("CreatedAt", time.Now())
	if _, err := repo.db.InsertOne(repo.collection, doc); err != nil {
		return data, handleDBError(err)
	}
	model, err := toModel[T](doc, true)
	if err != nil {
		return model, handleDBError(fmt.Errorf("%v: %v", repositories.ErrParsingModel, err))
	}
	return model, nil
}
// Get retrieves the most recently saved version of the record (documents are
// ordered by their CreatedAt stamp, newest first).
func (repo *GenericEntityRepositoryClover[T]) Get(_ context.Context) (T, error) {
	var model T
	newestFirst := clover_q.SortOption{
		Field:     "CreatedAt",
		Direction: -1,
	}
	doc, err := repo.db.FindFirst(repo.query().Sort(newestFirst))
	if err != nil || doc == nil {
		return model, handleDBError(err)
	}
	if model, err = toModel[T](doc, true); err != nil {
		return model, fmt.Errorf("failed to convert document to model: %v", err)
	}
	return model, nil
}
// Clear removes the record together with its full history from the repository.
func (repo *GenericEntityRepositoryClover[T]) Clear(_ context.Context) error {
	return repo.db.Delete(repo.query())
}

// History retrieves previous versions of the record from the repository,
// filtered, sorted, and limited according to query.
func (repo *GenericEntityRepositoryClover[T]) History(_ context.Context, query repositories.Query[T]) ([]T, error) {
	var models []T
	q := applyConditions(repo.query(), query)
	err := repo.db.ForEach(q, func(doc *clover_d.Document) bool {
		var record T
		// Stop iterating on the first document that fails to unmarshal.
		if doc.Unmarshal(&record) != nil {
			return false
		}
		models = append(models, record)
		return true
	})
	return models, handleDBError(err)
}
package clover
import (
"context"
"fmt"
"reflect"
"time"
"github.com/iancoleman/strcase"
clover "github.com/ostafen/clover/v2"
clover_d "github.com/ostafen/clover/v2/document"
clover_q "github.com/ostafen/clover/v2/query"
"gitlab.com/nunet/device-management-service/db/repositories"
)
const (
	pkField        = "_id"       // Clover's document primary-key field
	deletedAtField = "DeletedAt" // soft-delete timestamp field
)

// GenericRepositoryClover is a generic repository implementation using Clover.
// Embed it in model repositories to inherit basic CRUD operations with
// soft-delete support.
type GenericRepositoryClover[T repositories.ModelType] struct {
	db         *clover.DB // Clover database handle
	collection string     // name of the backing collection
}

// NewGenericRepository creates a new GenericRepositoryClover whose collection
// name is derived from the snake_cased name of model type T.
func NewGenericRepository[T repositories.ModelType](
	db *clover.DB,
) repositories.GenericRepository[T] {
	var zero T
	return &GenericRepositoryClover[T]{
		db:         db,
		collection: strcase.ToSnake(reflect.TypeOf(zero).Name()),
	}
}
// GetQuery returns a clean Query instance for building queries.
func (repo *GenericRepositoryClover[T]) GetQuery() repositories.Query[T] {
	var q repositories.Query[T]
	return q
}

// query starts a Clover query over the repository's collection. Unless
// includeDeleted is set, soft-deleted documents (DeletedAt later than the
// Unix epoch) are filtered out.
func (repo *GenericRepositoryClover[T]) query(includeDeleted bool) *clover_q.Query {
	q := clover_q.NewQuery(repo.collection)
	if includeDeleted {
		return q
	}
	return q.Where(clover_q.Field(deletedAtField).LtEq(time.Unix(0, 0)))
}
// queryWithID narrows the base query to the document whose primary key (_id)
// equals id. Non-string ids are rendered with fmt.Sprint instead of the
// previous bare id.(string) assertion, which panicked for any other type.
func (repo *GenericRepositoryClover[T]) queryWithID(
	id interface{},
	includeDeleted bool,
) *clover_q.Query {
	key, ok := id.(string)
	if !ok {
		key = fmt.Sprint(id)
	}
	return repo.query(includeDeleted).Where(clover_q.Field(pkField).Eq(key))
}
// Create adds a new record to the repository, stamping it with CreatedAt, and
// returns the created model.
func (repo *GenericRepositoryClover[T]) Create(_ context.Context, data T) (T, error) {
	doc := toCloverDoc(data)
	doc.Set("CreatedAt", time.Now())
	if _, err := repo.db.InsertOne(repo.collection, doc); err != nil {
		return data, handleDBError(err)
	}
	model, err := toModel[T](doc, false)
	if err != nil {
		return data, handleDBError(fmt.Errorf("%v: %v", repositories.ErrParsingModel, err))
	}
	return model, nil
}
// Get retrieves a non-deleted record by its identifier.
func (repo *GenericRepositoryClover[T]) Get(_ context.Context, id interface{}) (T, error) {
	var model T
	doc, err := repo.db.FindFirst(repo.queryWithID(id, false))
	if err != nil || doc == nil {
		return model, handleDBError(err)
	}
	if model, err = toModel[T](doc, false); err != nil {
		return model, handleDBError(fmt.Errorf("%v: %v", repositories.ErrParsingModel, err))
	}
	return model, nil
}
// Update modifies a record by its identifier, stamping it with UpdatedAt, and
// returns the record as re-read from the database.
func (repo *GenericRepositoryClover[T]) Update(
	ctx context.Context,
	id interface{},
	data T,
) (T, error) {
	changes := toCloverDoc(data).AsMap()
	changes["UpdatedAt"] = time.Now()
	if err := repo.db.Update(repo.queryWithID(id, false), changes); err != nil {
		return data, handleDBError(err)
	}
	data, err := repo.Get(ctx, id)
	return data, handleDBError(err)
}
// Delete soft-deletes a record by its identifier: it sets the DeletedAt
// timestamp, which query(false) then filters out, while the document itself
// stays in the collection.
func (repo *GenericRepositoryClover[T]) Delete(_ context.Context, id interface{}) error {
	err := repo.db.Update(
		repo.queryWithID(id, false),
		// Use the deletedAtField constant rather than a duplicate literal.
		map[string]interface{}{deletedAtField: time.Now()},
	)
	// Normalize the driver error like every other repository method does.
	return handleDBError(err)
}
// Find retrieves a single non-deleted record matching the query. A missing
// document is surfaced as clover.ErrDocumentNotExist (normalized through
// handleDBError) rather than as a zero-value model with a nil error.
func (repo *GenericRepositoryClover[T]) Find(
	_ context.Context,
	query repositories.Query[T],
) (T, error) {
	var model T
	doc, err := repo.db.FindFirst(applyConditions(repo.query(false), query))
	if err != nil {
		return model, handleDBError(err)
	}
	if doc == nil {
		return model, handleDBError(clover.ErrDocumentNotExist)
	}
	if model, err = toModel[T](doc, false); err != nil {
		return model, fmt.Errorf("failed to convert document to model: %v", err)
	}
	return model, nil
}
// FindAll retrieves all non-deleted records matching the query. Iteration
// stops at the first document that cannot be parsed into the model type; that
// parsing error is returned together with the records collected so far.
func (repo *GenericRepositoryClover[T]) FindAll(
	_ context.Context,
	query repositories.Query[T],
) ([]T, error) {
	var (
		models   []T
		parseErr error
	)
	q := applyConditions(repo.query(false), query)
	iterErr := repo.db.ForEach(q, func(doc *clover_d.Document) bool {
		record, err := toModel[T](doc, false)
		if err != nil {
			parseErr = handleDBError(fmt.Errorf("%v: %v", repositories.ErrParsingModel, err))
			return false
		}
		models = append(models, record)
		return true
	})
	if iterErr != nil {
		return models, handleDBError(iterErr)
	}
	if parseErr != nil {
		return models, parseErr
	}
	return models, nil
}
// applyConditions applies conditions, sorting, limiting, and offsetting to a
// Clover database query. It takes a Clover query (q) and a generic query
// (repositories.Query) as input, dynamically builds the WHERE clause from the
// explicit conditions and from non-zero fields of query.Instance, then applies
// sort, limit, and offset. The modified Clover query is returned.
func applyConditions[T repositories.ModelType](
	q *clover_q.Query,
	query repositories.Query[T],
) *clover_q.Query {
	// Apply conditions specified in the query.
	for _, condition := range query.Conditions {
		// Map the struct field name to its JSON tag name when one is declared.
		condition.Field = fieldJSONTag[T](condition.Field)
		switch condition.Operator {
		case "=":
			q = q.Where(clover_q.Field(condition.Field).Eq(condition.Value))
		case ">":
			q = q.Where(clover_q.Field(condition.Field).Gt(condition.Value))
		case ">=":
			q = q.Where(clover_q.Field(condition.Field).GtEq(condition.Value))
		case "<":
			q = q.Where(clover_q.Field(condition.Field).Lt(condition.Value))
		case "<=":
			q = q.Where(clover_q.Field(condition.Field).LtEq(condition.Value))
		case "!=":
			q = q.Where(clover_q.Field(condition.Field).Neq(condition.Value))
		case "IN":
			// IN and LIKE are silently skipped when the value has the wrong type.
			if values, ok := condition.Value.([]interface{}); ok {
				q = q.Where(clover_q.Field(condition.Field).In(values...))
			}
		case "LIKE":
			if value, ok := condition.Value.(string); ok {
				q = q.Where(clover_q.Field(condition.Field).Like(value))
			}
		}
	}
	// Apply equality conditions for every non-zero field of the example instance.
	if !repositories.IsEmptyValue(query.Instance) {
		exampleType := reflect.TypeOf(query.Instance)
		exampleValue := reflect.ValueOf(query.Instance)
		for i := 0; i < exampleType.NumField(); i++ {
			fieldName := exampleType.Field(i).Name
			fieldName = fieldJSONTag[T](fieldName)
			fieldValue := exampleValue.Field(i).Interface()
			if !repositories.IsEmptyValue(fieldValue) {
				q = q.Where(clover_q.Field(fieldName).Eq(fieldValue))
			}
		}
	}
	// Apply sorting if specified in the query; a leading '-' means descending.
	if query.SortBy != "" {
		dir := 1
		field := query.SortBy
		if field[0] == '-' {
			dir = -1
			field = field[1:]
		}
		// Translate the field name to its JSON tag for both directions.
		// (Previously this mapping was only applied to descending sorts.)
		field = fieldJSONTag[T](field)
		q = q.Sort(clover_q.SortOption{Field: field, Direction: dir})
	}
	// Apply limit if specified in the query.
	if query.Limit > 0 {
		q = q.Limit(query.Limit)
	}
	// Apply offset if specified in the query. This previously called Limit,
	// which overwrote the limit instead of skipping the first Offset records.
	if query.Offset > 0 {
		q = q.Skip(query.Offset)
	}
	return q
}
package clover
import (
"github.com/ostafen/clover/v2"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// PeerInfoClover is a Clover implementation of the PeerInfo interface.
type PeerInfoClover struct {
repositories.GenericRepository[types.PeerInfo]
}
// NewPeerInfo creates a new instance of PeerInfoClover.
// It initializes and returns a Clover-based repository for PeerInfo entities.
func NewPeerInfo(db *clover.DB) repositories.PeerInfo {
return &PeerInfoClover{NewGenericRepository[types.PeerInfo](db)}
}
// MachineClover is a Clover implementation of the Machine interface.
type MachineClover struct {
repositories.GenericRepository[types.Machine]
}
// NewMachine creates a new instance of MachineClover.
// It initializes and returns a Clover-based repository for Machine entities.
func NewMachine(db *clover.DB) repositories.Machine {
return &MachineClover{NewGenericRepository[types.Machine](db)}
}
// FreeResourcesClover is a Clover implementation of the FreeResources interface.
type FreeResourcesClover struct {
repositories.GenericEntityRepository[types.FreeResources]
}
// NewFreeResources creates a new instance of FreeResourcesClover.
// It initializes and returns a Clover-based repository for FreeResources entity.
func NewFreeResources(db *clover.DB) repositories.FreeResources {
return &FreeResourcesClover{NewGenericEntityRepository[types.FreeResources](db)}
}
// AvailableResourcesClover is a Clover implementation of the AvailableResources interface.
type AvailableResourcesClover struct {
repositories.GenericEntityRepository[types.AvailableResources]
}
// AvailableResourcesRepositoryClover is a Clover implementation of the AvailableResourcesRepository interface.
type AvailableResourcesRepositoryClover struct {
repositories.GenericEntityRepository[types.AvailableResources]
}
// NewAvailableResources creates a new instance of AvailableResourcesClover.
// It initializes and returns a Clover-based repository for AvailableResources entity.
func NewAvailableResources(db *clover.DB) repositories.AvailableResources {
return &AvailableResourcesClover{
NewGenericEntityRepository[types.AvailableResources](db),
}
}
// ServicesClover is a Clover implementation of the Services interface.
type ServicesClover struct {
repositories.GenericRepository[types.Services]
}
// NewServices creates a new instance of ServicesClover.
// It initializes and returns a Clover-based repository for Services entities.
func NewServices(db *clover.DB) repositories.Services {
return &ServicesClover{NewGenericRepository[types.Services](db)}
}
// ServiceResourceRequirementsClover is a Clover implementation of the ServiceResourceRequirements interface.
type ServiceResourceRequirementsClover struct {
repositories.GenericRepository[types.ServiceResourceRequirements]
}
// NewServiceResourceRequirements creates a new instance of ServiceResourceRequirementsClover.
// It initializes and returns a Clover-based repository for ServiceResourceRequirements entities.
func NewServiceResourceRequirements(
db *clover.DB,
) repositories.ServiceResourceRequirements {
return &ServiceResourceRequirementsClover{
NewGenericRepository[types.ServiceResourceRequirements](db),
}
}
// Libp2pInfoClover is a Clover implementation of the Libp2pInfo interface.
type Libp2pInfoClover struct {
repositories.GenericEntityRepository[types.Libp2pInfo]
}
// NewLibp2pInfo creates a new instance of Libp2pInfoClover.
// It initializes and returns a Clover-based repository for Libp2pInfo entity.
func NewLibp2pInfo(db *clover.DB) repositories.Libp2pInfo {
return &Libp2pInfoClover{NewGenericEntityRepository[types.Libp2pInfo](db)}
}
// MachineUUIDClover backs the MachineUUID single-entity repository interface
// with Clover storage.
type MachineUUIDClover struct {
    repositories.GenericEntityRepository[types.MachineUUID]
}

// NewMachineUUID wires a MachineUUIDClover to the given Clover database handle.
func NewMachineUUID(db *clover.DB) repositories.MachineUUID {
    repo := NewGenericEntityRepository[types.MachineUUID](db)
    return &MachineUUIDClover{repo}
}
// ConnectionClover backs the Connection repository interface with Clover
// storage.
type ConnectionClover struct {
    repositories.GenericRepository[types.Connection]
}

// NewConnection wires a ConnectionClover to the given Clover database handle.
func NewConnection(db *clover.DB) repositories.Connection {
    repo := NewGenericRepository[types.Connection](db)
    return &ConnectionClover{repo}
}
// ElasticTokenClover backs the ElasticToken repository interface with Clover
// storage.
type ElasticTokenClover struct {
    repositories.GenericRepository[types.ElasticToken]
}

// NewElasticToken wires an ElasticTokenClover to the given Clover database
// handle.
func NewElasticToken(db *clover.DB) repositories.ElasticToken {
    repo := NewGenericRepository[types.ElasticToken](db)
    return &ElasticTokenClover{repo}
}
package clover
import (
"github.com/ostafen/clover/v2"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/db/repositories"
)
// MachineResourcesRepositoryClover backs the MachineResources single-entity
// repository interface with Clover storage.
type MachineResourcesRepositoryClover struct {
    repositories.GenericEntityRepository[types.MachineResources]
}

// NewMachineResourcesRepository wires a MachineResourcesRepositoryClover to
// the given Clover database handle.
func NewMachineResourcesRepository(db *clover.DB) repositories.MachineResources {
    repo := NewGenericEntityRepository[types.MachineResources](db)
    return &MachineResourcesRepositoryClover{repo}
}
// FreeResourcesRepositoryClover backs the FreeResources single-entity
// repository interface with Clover storage.
type FreeResourcesRepositoryClover struct {
    repositories.GenericEntityRepository[types.FreeResources]
}

// NewFreeResourcesRepository wires a FreeResourcesRepositoryClover to the
// given Clover database handle.
func NewFreeResourcesRepository(db *clover.DB) repositories.FreeResources {
    repo := NewGenericEntityRepository[types.FreeResources](db)
    return &FreeResourcesRepositoryClover{repo}
}
// OnboardedResourcesRepositoryClover backs the OnboardedResources
// single-entity repository interface with Clover storage.
type OnboardedResourcesRepositoryClover struct {
    repositories.GenericEntityRepository[types.OnboardedResources]
}

// NewOnboardedResourcesRepository wires an OnboardedResourcesRepositoryClover
// to the given Clover database handle.
func NewOnboardedResourcesRepository(db *clover.DB) repositories.OnboardedResources {
    repo := NewGenericEntityRepository[types.OnboardedResources](db)
    return &OnboardedResourcesRepositoryClover{repo}
}
// RequiredResourcesRepositoryClover backs the RequiredResources repository
// interface with Clover storage.
type RequiredResourcesRepositoryClover struct {
    repositories.GenericRepository[types.RequiredResources]
}

// NewRequiredResourcesRepository wires a RequiredResourcesRepositoryClover to
// the given Clover database handle.
func NewRequiredResourcesRepository(db *clover.DB) repositories.RequiredResources {
    repo := NewGenericRepository[types.RequiredResources](db)
    return &RequiredResourcesRepositoryClover{repo}
}
package clover
import (
clover "github.com/ostafen/clover/v2"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// StorageVolumeClover backs the StorageVolume repository interface with
// Clover storage.
type StorageVolumeClover struct {
    repositories.GenericRepository[types.StorageVolume]
}

// NewStorageVolume wires a StorageVolumeClover to the given Clover database
// handle.
func NewStorageVolume(db *clover.DB) repositories.StorageVolume {
    repo := NewGenericRepository[types.StorageVolume](db)
    return &StorageVolumeClover{repo}
}
package clover
import (
"encoding/json"
"errors"
"reflect"
"strings"
"github.com/ostafen/clover/v2"
clover_d "github.com/ostafen/clover/v2/document"
"gitlab.com/nunet/device-management-service/db/repositories"
)
// handleDBError maps Clover driver errors onto the repositories error set so
// callers never have to depend on the storage backend's error values.
// A nil error passes through unchanged; unrecognized errors are joined with
// repositories.ErrDatabase so both remain matchable via errors.Is.
func handleDBError(err error) error {
    if err == nil {
        return nil
    }
    switch err {
    case clover.ErrDocumentNotExist:
        return repositories.ErrNotFound
    case clover.ErrDuplicateKey:
        return repositories.ErrInvalidData
    case repositories.ErrParsingModel:
        // Already a repository-level error; pass it through untouched.
        return err
    }
    return errors.Join(repositories.ErrDatabase, err)
}
// toCloverDoc converts a model value into a Clover document by round-tripping
// it through JSON, so document keys follow the model's json tags.
// Marshalling failures are swallowed: the function falls back to an empty
// document instead of returning an error.
func toCloverDoc[T repositories.ModelType](data T) *clover_d.Document {
    raw, err := json.Marshal(data)
    if err != nil {
        return clover_d.NewDocument()
    }
    fields := map[string]interface{}{}
    if err := json.Unmarshal(raw, &fields); err != nil {
        return clover_d.NewDocument()
    }
    return clover_d.NewDocumentOf(fields)
}
// toModel decodes a Clover document into a model value of type T.
// For collection repositories (isEntityRepo == false) the document's ObjectId
// is copied into the model's "ID" field; entity repositories are skipped
// because their models might not carry an ID field at all.
func toModel[T repositories.ModelType](doc *clover_d.Document, isEntityRepo bool) (T, error) {
    var model T
    if err := doc.Unmarshal(&model); err != nil {
        return model, err
    }
    if isEntityRepo {
        return model, nil
    }
    return repositories.UpdateField(model, "ID", doc.ObjectId())
}
// fieldJSONTag returns the JSON key used for the given struct field of T,
// i.e. the first segment of its `json` tag, falling back to the Go field
// name when the field or the tag is absent.
func fieldJSONTag[T repositories.ModelType](field string) string {
    structField, ok := reflect.TypeOf(*new(T)).FieldByName(field)
    if !ok {
        return field
    }
    tag, ok := structField.Tag.Lookup("json")
    if !ok {
        return field
    }
    return strings.Split(tag, ",")[0]
}
package repositories
import (
"context"
)
// QueryCondition is a struct representing a single query condition
// (field, comparison operator, expected value). Build instances via the
// helper constructors EQ, GT, GTE, LT, LTE, IN and LIKE.
type QueryCondition struct {
    Field    string      // Field specifies the database or struct field to which the condition applies.
    Operator string      // Operator defines the comparison operator (e.g., "=", ">", "<").
    Value    interface{} // Value is the expected value for the given field.
}
// ModelType is the constraint used by the generic repositories; it accepts
// any model type. Declared as `any` (identical to interface{}) since the
// package already relies on Go 1.18+ generics.
type ModelType any
// Query is a struct that wraps both the instance of type T and additional query parameters.
// It is used to construct queries with conditions, sorting, limiting, and offsetting.
// Conditions and the non-zero fields of Instance are combined (query-by-example).
type Query[T any] struct {
    Instance   T                // Instance is an optional object of type T used to build conditions from its non-zero fields.
    Conditions []QueryCondition // Conditions represent the conditions applied to the query.
    SortBy     string           // SortBy specifies the field by which the query results should be sorted; a leading '-' requests descending order.
    Limit      int              // Limit specifies the maximum number of results to return (0 = no limit).
    Offset     int              // Offset specifies the number of results to skip before starting to return data.
}
// GenericRepository is an interface defining basic CRUD operations and standard querying methods.
// Implementations translate backend-specific errors into this package's
// sentinel errors (ErrNotFound, ErrInvalidData, ErrDatabase).
type GenericRepository[T ModelType] interface {
    // Create adds a new record to the repository.
    Create(ctx context.Context, data T) (T, error)
    // Get retrieves a record by its identifier.
    Get(ctx context.Context, id interface{}) (T, error)
    // Update modifies a record by its identifier.
    Update(ctx context.Context, id interface{}, data T) (T, error)
    // Delete removes a record by its identifier.
    Delete(ctx context.Context, id interface{}) error
    // Find retrieves a single record based on a query.
    Find(ctx context.Context, query Query[T]) (T, error)
    // FindAll retrieves multiple records based on a query.
    FindAll(ctx context.Context, query Query[T]) ([]T, error)
    // GetQuery returns an empty query instance for the repository's type.
    GetQuery() Query[T]
}
// EQ builds an equality ("=") condition on the given field.
func EQ(field string, value interface{}) QueryCondition {
    return QueryCondition{
        Field:    field,
        Operator: "=",
        Value:    value,
    }
}
// GT builds a greater-than (">") condition on the given field.
func GT(field string, value interface{}) QueryCondition {
    return QueryCondition{
        Field:    field,
        Operator: ">",
        Value:    value,
    }
}
// GTE builds a greater-than-or-equal (">=") condition on the given field.
func GTE(field string, value interface{}) QueryCondition {
    return QueryCondition{
        Field:    field,
        Operator: ">=",
        Value:    value,
    }
}
// LT builds a less-than ("<") condition on the given field.
func LT(field string, value interface{}) QueryCondition {
    return QueryCondition{
        Field:    field,
        Operator: "<",
        Value:    value,
    }
}
// LTE builds a less-than-or-equal ("<=") condition on the given field.
func LTE(field string, value interface{}) QueryCondition {
    return QueryCondition{
        Field:    field,
        Operator: "<=",
        Value:    value,
    }
}
// IN builds an "IN" membership condition over the given slice of values.
func IN(field string, values []interface{}) QueryCondition {
    return QueryCondition{
        Field:    field,
        Operator: "IN",
        Value:    values,
    }
}
// LIKE builds a "LIKE" pattern-match condition on the given field.
func LIKE(field, pattern string) QueryCondition {
    return QueryCondition{
        Field:    field,
        Operator: "LIKE",
        Value:    pattern,
    }
}
package gorm
import (
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// DeploymentRequestFlatGORM backs the DeploymentRequestFlat repository
// interface with GORM storage.
type DeploymentRequestFlatGORM struct {
    repositories.GenericRepository[types.DeploymentRequestFlat]
}

// NewDeploymentRequestFlat wires a DeploymentRequestFlatGORM to the given
// GORM database handle.
func NewDeploymentRequestFlat(db *gorm.DB) repositories.DeploymentRequestFlat {
    repo := NewGenericRepository[types.DeploymentRequestFlat](db)
    return &DeploymentRequestFlatGORM{repo}
}
package gorm
import (
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// RequestTrackerGORM backs the RequestTracker repository interface with GORM
// storage.
type RequestTrackerGORM struct {
    repositories.GenericRepository[types.RequestTracker]
}

// NewRequestTracker wires a RequestTrackerGORM to the given GORM database
// handle.
func NewRequestTracker(db *gorm.DB) repositories.RequestTracker {
    repo := NewGenericRepository[types.RequestTracker](db)
    return &RequestTrackerGORM{repo}
}
package gorm
import (
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// VirtualMachineGORM backs the VirtualMachine repository interface with GORM
// storage.
type VirtualMachineGORM struct {
    repositories.GenericRepository[types.VirtualMachine]
}

// NewVirtualMachine wires a VirtualMachineGORM to the given GORM database
// handle.
func NewVirtualMachine(db *gorm.DB) repositories.VirtualMachine {
    repo := NewGenericRepository[types.VirtualMachine](db)
    return &VirtualMachineGORM{repo}
}
package gorm
import (
"context"
"fmt"
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
)
const (
    // createdAtField is the model field used to order an entity's saved rows,
    // so Get can return the most recently stored record.
    createdAtField = "CreatedAt"
)
// GenericEntityRepositoryGORM is a generic single entity repository implementation using GORM as an ORM.
// It is intended to be embedded in single entity model repositories to provide basic database operations.
// Each Save inserts a new row; Get reads back the newest one, so older rows
// serve as the entity's history.
type GenericEntityRepositoryGORM[T repositories.ModelType] struct {
    db *gorm.DB // db is the GORM database instance.
}
// NewGenericEntityRepository builds a GORM-backed single-entity repository
// for model type T on top of the provided database handle.
func NewGenericEntityRepository[T repositories.ModelType](
    db *gorm.DB,
) repositories.GenericEntityRepository[T] {
    repo := GenericEntityRepositoryGORM[T]{db: db}
    return &repo
}
// GetQuery returns a zero-valued Query for callers to fill in.
func (repo *GenericEntityRepositoryGORM[T]) GetQuery() repositories.Query[T] {
    var q repositories.Query[T]
    return q
}
// Save stores the entity and returns the persisted value.
// NOTE(review): this always inserts a new row (gorm Create); the previous
// rows remain and are exposed through History — confirm this is intended.
func (repo *GenericEntityRepositoryGORM[T]) Save(ctx context.Context, data T) (T, error) {
    if err := repo.db.WithContext(ctx).Create(&data).Error; err != nil {
        return data, handleDBError(err)
    }
    return data, nil
}
// Get returns the most recently saved record for the entity, ordered by
// CreatedAt descending. It yields repositories.ErrNotFound when no record
// has been saved yet.
func (repo *GenericEntityRepositoryGORM[T]) Get(ctx context.Context) (T, error) {
    query := repo.GetQuery()
    query.SortBy = fmt.Sprintf("-%s", createdAtField)
    var latest T
    tx := applyConditions(repo.db.WithContext(ctx), query)
    err := tx.First(&latest).Error
    return latest, handleDBError(err)
}
// Clear deletes the entity's current row together with all historical rows.
func (repo *GenericEntityRepositoryGORM[T]) Clear(ctx context.Context) error {
    tx := repo.db.WithContext(ctx)
    return tx.Delete(new(T), "id IS NOT NULL").Error
}
// History returns the entity's previously saved rows, filtered, sorted and
// limited according to the given query.
func (repo *GenericEntityRepositoryGORM[T]) History(
    ctx context.Context,
    query repositories.Query[T],
) ([]T, error) {
    var rows []T
    tx := applyConditions(repo.db.WithContext(ctx).Model(new(T)), query)
    err := tx.Find(&rows).Error
    return rows, handleDBError(err)
}
package gorm
import (
"context"
"fmt"
"reflect"
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
)
// GenericRepositoryGORM is a generic repository implementation using GORM as an ORM.
// It is intended to be embedded in model repositories to provide basic database operations.
type GenericRepositoryGORM[T repositories.ModelType] struct {
    db *gorm.DB // db is the GORM database instance.
}
// NewGenericRepository builds a GORM-backed repository for model type T on
// top of the provided database handle.
func NewGenericRepository[T repositories.ModelType](db *gorm.DB) repositories.GenericRepository[T] {
    repo := GenericRepositoryGORM[T]{db: db}
    return &repo
}
// GetQuery returns a zero-valued Query for callers to fill in.
func (repo *GenericRepositoryGORM[T]) GetQuery() repositories.Query[T] {
    var q repositories.Query[T]
    return q
}
// Create inserts a new record and returns the persisted data (with any
// database-populated fields filled in by GORM).
func (repo *GenericRepositoryGORM[T]) Create(ctx context.Context, data T) (T, error) {
    if err := repo.db.WithContext(ctx).Create(&data).Error; err != nil {
        return data, handleDBError(err)
    }
    return data, nil
}
// Get retrieves a record by its primary key ("id" column).
// It returns repositories.ErrNotFound when no row matches.
// The original had a redundant error branch that returned the exact same
// expression as the fall-through path; collapsed to a single return.
func (repo *GenericRepositoryGORM[T]) Get(ctx context.Context, id interface{}) (T, error) {
    var result T
    err := repo.db.WithContext(ctx).First(&result, "id = ?", id).Error
    return result, handleDBError(err)
}
// Update applies data to the record with the given identifier and returns
// the input value. NOTE(review): gorm's Updates with a struct argument skips
// zero-value fields — confirm callers expect partial updates.
func (repo *GenericRepositoryGORM[T]) Update(ctx context.Context, id interface{}, data T) (T, error) {
    tx := repo.db.WithContext(ctx).Model(new(T)).Where("id = ?", id)
    err := tx.Updates(data).Error
    return data, handleDBError(err)
}
// Delete removes a record by its identifier.
// The error is routed through handleDBError for consistency with every other
// method of this repository, so callers can match the repositories sentinel
// errors instead of raw GORM errors. (The original returned the raw error.)
func (repo *GenericRepositoryGORM[T]) Delete(ctx context.Context, id interface{}) error {
    err := repo.db.WithContext(ctx).Delete(new(T), "id = ?", id).Error
    return handleDBError(err)
}
// Find returns the first record matching the query, or
// repositories.ErrNotFound when nothing matches.
func (repo *GenericRepositoryGORM[T]) Find(
    ctx context.Context,
    query repositories.Query[T],
) (T, error) {
    var match T
    tx := applyConditions(repo.db.WithContext(ctx).Model(new(T)), query)
    err := tx.First(&match).Error
    return match, handleDBError(err)
}
// FindAll returns every record matching the query.
func (repo *GenericRepositoryGORM[T]) FindAll(
    ctx context.Context,
    query repositories.Query[T],
) ([]T, error) {
    var matches []T
    tx := applyConditions(repo.db.WithContext(ctx).Model(new(T)), query)
    err := tx.Find(&matches).Error
    return matches, handleDBError(err)
}
// applyConditions applies conditions, sorting, limiting, and offsetting to a GORM database query.
// It takes a GORM database instance (db) and a generic query (repositories.Query) as input.
// The WHERE clause is built from the explicit Conditions plus every non-zero
// field of query.Instance (query-by-example); column names are resolved via
// the database's naming strategy. The modified GORM instance is returned.
func applyConditions[T any](db *gorm.DB, query repositories.Query[T]) *gorm.DB {
    // Retrieve the table name using the GORM naming strategy.
    tableName := db.NamingStrategy.TableName(reflect.TypeOf(*new(T)).Name())
    // Apply conditions specified in the query.
    for _, condition := range query.Conditions {
        columnName := db.NamingStrategy.ColumnName(tableName, condition.Field)
        db = db.Where(
            fmt.Sprintf("%s %s ?", columnName, condition.Operator),
            condition.Value,
        )
    }
    // Apply equality conditions for each non-zero field of the query instance.
    if !repositories.IsEmptyValue(query.Instance) {
        exampleType := reflect.TypeOf(query.Instance)
        exampleValue := reflect.ValueOf(query.Instance)
        for i := 0; i < exampleType.NumField(); i++ {
            fieldName := exampleType.Field(i).Name
            fieldValue := exampleValue.Field(i).Interface()
            if !repositories.IsEmptyValue(fieldValue) {
                columnName := db.NamingStrategy.ColumnName(tableName, fieldName)
                db = db.Where(fmt.Sprintf("%s = ?", columnName), fieldValue)
            }
        }
    }
    // Apply sorting if specified; a leading '-' selects descending order.
    if query.SortBy != "" {
        dir := "ASC"
        if query.SortBy[0] == '-' {
            query.SortBy = query.SortBy[1:]
            dir = "DESC"
        }
        columnName := db.NamingStrategy.ColumnName(tableName, query.SortBy)
        db = db.Order(fmt.Sprintf("%s.%s %s", tableName, columnName, dir))
    }
    // Apply limit if specified in the query.
    if query.Limit > 0 {
        db = db.Limit(query.Limit)
    }
    // Apply offset if specified in the query.
    // BUG FIX: the original called db.Limit(query.Offset) here, which silently
    // overwrote the limit and never skipped any rows.
    if query.Offset > 0 {
        db = db.Offset(query.Offset)
    }
    return db
}
package gorm
import (
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// PeerInfoGORM backs the PeerInfo repository interface with GORM storage.
type PeerInfoGORM struct {
    repositories.GenericRepository[types.PeerInfo]
}

// NewPeerInfo wires a PeerInfoGORM to the given GORM database handle.
func NewPeerInfo(db *gorm.DB) repositories.PeerInfo {
    repo := NewGenericRepository[types.PeerInfo](db)
    return &PeerInfoGORM{repo}
}
// MachineGORM backs the Machine repository interface with GORM storage.
type MachineGORM struct {
    repositories.GenericRepository[types.Machine]
}

// NewMachine wires a MachineGORM to the given GORM database handle.
func NewMachine(db *gorm.DB) repositories.Machine {
    repo := NewGenericRepository[types.Machine](db)
    return &MachineGORM{repo}
}
// AvailableResourcesGORM backs the AvailableResources single-entity
// repository interface with GORM storage.
type AvailableResourcesGORM struct {
    repositories.GenericEntityRepository[types.AvailableResources]
}

// NewAvailableResources wires an AvailableResourcesGORM to the given GORM
// database handle.
func NewAvailableResources(db *gorm.DB) repositories.AvailableResources {
    repo := NewGenericEntityRepository[types.AvailableResources](db)
    return &AvailableResourcesGORM{repo}
}
// ServicesGORM backs the Services repository interface with GORM storage.
type ServicesGORM struct {
    repositories.GenericRepository[types.Services]
}

// NewServices wires a ServicesGORM to the given GORM database handle.
func NewServices(db *gorm.DB) repositories.Services {
    repo := NewGenericRepository[types.Services](db)
    return &ServicesGORM{repo}
}
// ServiceResourceRequirementsGORM backs the ServiceResourceRequirements
// repository interface with GORM storage.
type ServiceResourceRequirementsGORM struct {
    repositories.GenericRepository[types.ServiceResourceRequirements]
}

// NewServiceResourceRequirements wires a ServiceResourceRequirementsGORM to
// the given GORM database handle.
func NewServiceResourceRequirements(
    db *gorm.DB,
) repositories.ServiceResourceRequirements {
    repo := NewGenericRepository[types.ServiceResourceRequirements](db)
    return &ServiceResourceRequirementsGORM{repo}
}
// Libp2pInfoGORM backs the Libp2pInfo single-entity repository interface
// with GORM storage.
type Libp2pInfoGORM struct {
    repositories.GenericEntityRepository[types.Libp2pInfo]
}

// NewLibp2pInfo wires a Libp2pInfoGORM to the given GORM database handle.
func NewLibp2pInfo(db *gorm.DB) repositories.Libp2pInfo {
    repo := NewGenericEntityRepository[types.Libp2pInfo](db)
    return &Libp2pInfoGORM{repo}
}
// MachineUUIDGORM backs the MachineUUID single-entity repository interface
// with GORM storage.
type MachineUUIDGORM struct {
    repositories.GenericEntityRepository[types.MachineUUID]
}

// NewMachineUUID wires a MachineUUIDGORM to the given GORM database handle.
func NewMachineUUID(db *gorm.DB) repositories.MachineUUID {
    repo := NewGenericEntityRepository[types.MachineUUID](db)
    return &MachineUUIDGORM{repo}
}
// ConnectionGORM backs the Connection repository interface with GORM storage.
type ConnectionGORM struct {
    repositories.GenericRepository[types.Connection]
}

// NewConnection wires a ConnectionGORM to the given GORM database handle.
func NewConnection(db *gorm.DB) repositories.Connection {
    repo := NewGenericRepository[types.Connection](db)
    return &ConnectionGORM{repo}
}
// ElasticTokenGORM backs the ElasticToken repository interface with GORM
// storage.
type ElasticTokenGORM struct {
    repositories.GenericRepository[types.ElasticToken]
}

// NewElasticToken wires an ElasticTokenGORM to the given GORM database
// handle.
func NewElasticToken(db *gorm.DB) repositories.ElasticToken {
    repo := NewGenericRepository[types.ElasticToken](db)
    return &ElasticTokenGORM{repo}
}
package gorm
import (
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// OnboardingParamsGORM is a GORM implementation of the OnboardingParams interface.
type OnboardingParamsGORM struct {
    repositories.GenericEntityRepository[types.OnboardingConfig]
}

// NewOnboardingParams creates a new instance of OnboardingParamsGORM.
// It initializes and returns a GORM-based repository for the OnboardingConfig entity.
func NewOnboardingParams(db *gorm.DB) repositories.OnboardingParams {
    return &OnboardingParamsGORM{
        NewGenericEntityRepository[types.OnboardingConfig](db),
    }
}
package gorm
import (
"gitlab.com/nunet/device-management-service/types"
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
)
// MachineResourcesGORM backs the MachineResources single-entity repository
// interface with GORM storage.
type MachineResourcesGORM struct {
    repositories.GenericEntityRepository[types.MachineResources]
}

// NewMachineResources wires a MachineResourcesGORM to the given GORM
// database handle.
func NewMachineResources(db *gorm.DB) repositories.MachineResources {
    repo := NewGenericEntityRepository[types.MachineResources](db)
    return &MachineResourcesGORM{repo}
}
// FreeResourcesGORM backs the FreeResources single-entity repository
// interface with GORM storage.
type FreeResourcesGORM struct {
    repositories.GenericEntityRepository[types.FreeResources]
}

// NewFreeResources wires a FreeResourcesGORM to the given GORM database
// handle.
func NewFreeResources(db *gorm.DB) repositories.FreeResources {
    repo := NewGenericEntityRepository[types.FreeResources](db)
    return &FreeResourcesGORM{repo}
}
// OnboardedResourcesRepositoryGORM backs the OnboardedResources
// single-entity repository interface with GORM storage.
type OnboardedResourcesRepositoryGORM struct {
    repositories.GenericEntityRepository[types.OnboardedResources]
}

// NewOnboardedResources wires an OnboardedResourcesRepositoryGORM to the
// given GORM database handle.
func NewOnboardedResources(db *gorm.DB) repositories.OnboardedResources {
    repo := NewGenericEntityRepository[types.OnboardedResources](db)
    return &OnboardedResourcesRepositoryGORM{repo}
}
// RequiredResourcesRepositoryGORM backs the RequiredResources repository
// interface with GORM storage.
type RequiredResourcesRepositoryGORM struct {
    repositories.GenericRepository[types.RequiredResources]
}

// NewRequiredResources wires a RequiredResourcesRepositoryGORM to the given
// GORM database handle.
func NewRequiredResources(db *gorm.DB) repositories.RequiredResources {
    repo := NewGenericRepository[types.RequiredResources](db)
    return &RequiredResourcesRepositoryGORM{repo}
}
package gorm
import (
"errors"
"gorm.io/gorm"
"gitlab.com/nunet/device-management-service/db/repositories"
)
// handleDBError maps GORM driver errors onto the repositories error set so
// callers never have to depend on GORM's error values. A nil error passes
// through unchanged; unrecognized errors are joined with
// repositories.ErrDatabase so both remain matchable via errors.Is.
func handleDBError(err error) error {
    if err == nil {
        return nil
    }
    switch err {
    case gorm.ErrRecordNotFound:
        return repositories.ErrNotFound
    case gorm.ErrInvalidData, gorm.ErrInvalidField, gorm.ErrInvalidValue:
        return repositories.ErrInvalidData
    case repositories.ErrParsingModel:
        // Already a repository-level error; pass it through untouched.
        return err
    }
    return errors.Join(repositories.ErrDatabase, err)
}
package repositories
import (
"fmt"
"reflect"
)
// UpdateField is a generic function that updates a field of a struct or a pointer to a struct.
// The function uses reflection to dynamically set the named field; the new
// value is converted to the field's type when the types are convertible.
// For pointer inputs the pointed-to struct is mutated in place; for value
// inputs a modified copy is returned. Errors are returned (never panics) for
// non-struct inputs, unknown or unexported fields, nil values, and
// inconvertible types.
func UpdateField[T interface{}](input T, fieldName string, newValue interface{}) (T, error) {
    // Use reflection to get the struct's field.
    val := reflect.ValueOf(input)
    if val.Kind() == reflect.Ptr {
        // If input is a pointer, get the underlying element.
        val = val.Elem()
    } else {
        // If input is not a pointer, take the address of our local copy so
        // the field becomes settable.
        val = reflect.ValueOf(&input).Elem()
    }
    // Check if the input is a struct.
    if val.Kind() != reflect.Struct {
        return input, fmt.Errorf("not a struct: %T", input)
    }
    // Get the field by name.
    field := val.FieldByName(fieldName)
    if !field.IsValid() {
        return input, fmt.Errorf("field not found: %v", fieldName)
    }
    // Check if the field is settable (exported).
    if !field.CanSet() {
        return input, fmt.Errorf("field not settable: %v", fieldName)
    }
    // BUG FIX: the original called reflect.TypeOf(newValue).ConvertibleTo
    // without a nil guard, which panics for a nil newValue.
    if newValue == nil {
        return input, fmt.Errorf("nil value for field: %v", fieldName)
    }
    // Check if types are compatible. (The original printed the two types in
    // the opposite order of its "->" arrow; fixed to value-type -> field-type.)
    if !reflect.TypeOf(newValue).ConvertibleTo(field.Type()) {
        return input, fmt.Errorf(
            "incompatible conversion: %v -> %v; value: %v",
            reflect.TypeOf(newValue), field.Type(), newValue,
        )
    }
    // Convert the new value to the field type and set it.
    field.Set(reflect.ValueOf(newValue).Convert(field.Type()))
    return input, nil
}
// IsEmptyValue reports whether value is nil or holds the zero value of its
// dynamic type. Pointers are dereferenced one level, so a non-nil pointer to
// a zero-value struct is also considered empty.
func IsEmptyValue(value interface{}) bool {
    if value == nil {
        return true
    }
    rv := reflect.ValueOf(value)
    if rv.Kind() == reflect.Ptr {
        rv = rv.Elem()
    }
    return rv.IsZero()
}
package actor
import (
"context"
"encoding/json"
"errors"
"fmt"
"time"
bt "gitlab.com/nunet/device-management-service/internal/background_tasks"
"gitlab.com/nunet/device-management-service/network"
"gitlab.com/nunet/device-management-service/types"
)
// BasicActor is the default Actor implementation: it dispatches incoming
// envelopes to registered behaviors, sends signed messages over the network,
// and periodically heartbeats its parent (when one is registered).
type BasicActor struct {
    dispatch  *Dispatch        // reaction kernel routing envelopes to behaviors
    scheduler *bt.Scheduler    // background task scheduler (runs the heartbeat task)
    registry  *registry        // relationship registry; used here for parent lookup
    network   network.Network  // transport used to resolve peers and send messages
    security  SecurityContext  // nonces and capability provisioning for outgoing messages
    params    BasicActorParams // runtime tunables (heartbeat interval/jitter)
    self      Handle           // this actor's identity and address

    addedTaskID int // scheduler ID of the heartbeat task; removed again in Stop
}
// BasicActorParams carries runtime configuration for a BasicActor.
type BasicActorParams struct {
    // Heartbeat controls how often, and with how much random jitter, the
    // actor sends heartbeats to its parent.
    Heartbeat struct {
        Interval time.Duration
        Jitter   float64
    }
}

// Compile-time assertion that BasicActor satisfies the Actor interface.
var _ Actor = (*BasicActor)(nil)
// New creates a new basic actor.
func New(dispatch *Dispatch, scheduler *bt.Scheduler, net network.Network, security *BasicSecurityContext, params BasicActorParams, self Handle) (*BasicActor, error) {
if dispatch == nil {
return nil, errors.New("dispatch is nil")
}
if scheduler == nil {
return nil, errors.New("scheduler is nil")
}
if net == nil {
return nil, errors.New("network is nil")
}
if security == nil {
return nil, errors.New("security is nil")
}
actor := &BasicActor{
dispatch: dispatch,
scheduler: scheduler,
registry: ®istry{},
network: net,
security: security,
params: params,
self: self,
}
return actor, nil
}
// Start brings the actor online: it subscribes the actor's inbox topic on
// the network, installs the heartbeat behavior, schedules a periodic
// heartbeat task when a parent is registered, and finally starts the
// dispatch and scheduler goroutines.
func (a *BasicActor) Start() error {
    // Subscribe to this actor's inbox topic; incoming payloads are decoded
    // and routed by handleMessage.
    if err := a.network.HandleMessage(
        fmt.Sprintf("actor/%s/messages/0.0.1", a.self.Address.InboxAddress),
        a.handleMessage,
    ); err != nil {
        return fmt.Errorf("starting actor: %s: %w", a.self.ID, err)
    }
    // Heartbeat
    err := a.dispatch.AddBehavior(heartbeatBehavior, a.handleHeartbeat)
    if err != nil {
        return fmt.Errorf("failed to add heartbeat behaviour: %w", err)
    }
    if parent, ok := a.registry.GetParent(a.self); ok {
        // Periodic heartbeat to the parent; the jitter callback spreads
        // heartbeats out so many actors don't fire in lockstep.
        task := &bt.Task{
            Name:        "actor heartbeat",
            Description: "send heartbeat to parent actor",
            Triggers: []bt.Trigger{
                &bt.PeriodicTriggerWithJitter{
                    Interval: a.params.Heartbeat.Interval,
                    Jitter: func() time.Duration {
                        return jitter(
                            a.params.Heartbeat.Interval,
                            a.params.Heartbeat.Jitter,
                        )
                    },
                },
            },
            Function: func(_ interface{}) error {
                return a.sendHeartbeat(*parent)
            },
        }
        // Remember the task ID so Stop can remove the task later.
        addedTask := a.scheduler.AddTask(task)
        a.addedTaskID = addedTask.ID
    }
    // and start the internal goroutines
    a.dispatch.Start()
    a.scheduler.Start()
    return nil
}
// handleMessage is the network inbox callback: it decodes the raw payload
// into an Envelope and forwards it to the dispatch kernel, silently dropping
// payloads that fail to decode or that are addressed to a different actor.
func (a *BasicActor) handleMessage(data []byte) {
    var env Envelope
    if err := json.Unmarshal(data, &env); err != nil {
        // TODO log debug
        return
    }
    if !a.self.ID.Equals(env.To.ID) {
        // TODO log warn
        return
    }
    _ = a.dispatch.Receive(env)
}
// handleHeartbeat processes heartbeat envelopes from child/supervised
// actors. Currently a no-op placeholder.
// Note: we don't need a capability token for heartbeats, we verify
// that the origin is one of our children or supervised actors.
func (a *BasicActor) handleHeartbeat(_ Envelope) {
    // TODO check that the heartbeat origin is one of our children or supervised
    // actors.
    // TODO handle heartbeat statistics
}
// sendHeartbeat builds a heartbeat envelope addressed to the parent actor
// and sends it through the normal Send path.
func (a *BasicActor) sendHeartbeat(parent Handle) error {
    hb, err := Message(
        a.self,
        parent,
        heartbeatBehavior,
        HeartbeatMessage{},
    )
    if err != nil {
        return fmt.Errorf("constructing heartbeat message: %w", err)
    }
    return a.Send(hb)
}
// Context returns the dispatch kernel's context.
func (a *BasicActor) Context() context.Context {
    return a.dispatch.Context()
}

// Handle returns this actor's own handle (identity and address).
func (a *BasicActor) Handle() Handle {
    return a.self
}

// Security returns the actor's security context.
func (a *BasicActor) Security() SecurityContext {
    return a.security
}
// AddBehavior registers a continuation for the named behavior with the
// dispatch kernel.
func (a *BasicActor) AddBehavior(behavior string, continuation Behavior, opt ...BehaviorOption) error {
    return a.dispatch.AddBehavior(behavior, continuation, opt...)
}

// RemoveBehavior unregisters the named behavior from the dispatch kernel.
func (a *BasicActor) RemoveBehavior(behavior string) {
    a.dispatch.RemoveBehavior(behavior)
}
// Receive accepts an envelope addressed to this actor and hands it to the
// dispatch kernel; envelopes for any other actor are rejected with
// ErrInvalidMessage.
func (a *BasicActor) Receive(msg Envelope) error {
    if a.self.ID.Equals(msg.To.ID) {
        return a.dispatch.Receive(msg)
    }
    return fmt.Errorf("bad receiver: %w", ErrInvalidMessage)
}
// Send delivers msg to its destination. Messages addressed to this actor are
// short-circuited to Receive; for remote destinations the envelope is given
// a nonce and implicit capability when unsigned, the destination host is
// resolved, and the JSON-encoded envelope is sent over the network.
func (a *BasicActor) Send(msg Envelope) error {
    // Local delivery needs no signing or network round-trip.
    if msg.To.ID.Equals(a.self.ID) {
        return a.Receive(msg)
    }
    if msg.Signature == nil {
        // Unsigned envelope: stamp a fresh nonce (unless one is set) and
        // attach the implicit capability for the target behavior.
        if msg.Nonce == 0 {
            msg.Nonce = a.security.Nonce()
        }
        if err := a.security.Provide(&msg); err != nil {
            return fmt.Errorf("providing implicit behavior capability for %s: %w", msg.Behavior, err)
        }
    }
    addrs, err := a.network.ResolveAddress(
        a.Context(),
        msg.To.Address.HostID,
    )
    if err != nil {
        return fmt.Errorf("resolving address for %s: %w", msg.To.ID, err)
    }
    data, err := json.Marshal(msg)
    if err != nil {
        return fmt.Errorf("marshaling message: %w", err)
    }
    // The message type embeds the destination inbox address, mirroring the
    // topic this actor registers for itself in Start.
    err = a.network.SendMessage(
        a.Context(),
        addrs,
        types.MessageEnvelope{
            Type: types.MessageType(
                fmt.Sprintf("actor/%s/messages/0.0.1", msg.To.Address.InboxAddress),
            ),
            Data: data,
        })
    if err != nil {
        return fmt.Errorf("sending message to %s: %w", msg.To.ID, err)
    }
    return nil
}
// Invoke sends a message and returns a channel that yields the reply.
// A one-off reply behavior is registered before sending; if the send
// fails, the reply behavior is removed again so it does not leak.
func (a *BasicActor) Invoke(msg Envelope, opt ...BehaviorOption) (<-chan Envelope, error) {
// Generate a unique reply-to inbox if the caller did not provide one.
if msg.Options.ReplyTo == "" {
msg.Options.ReplyTo = fmt.Sprintf("/dms/actor/replyto/%d", a.security.Nonce())
}
// Buffered by 1 so the reply continuation never blocks the dispatcher.
result := make(chan Envelope, 1)
// Prepend the expiry option so an explicit caller option can override it.
opt = append([]BehaviorOption{WithBehaviorExpiry(msg.Options.Expire)}, opt...)
if err := a.dispatch.AddBehavior(
msg.Options.ReplyTo,
func(reply Envelope) {
result <- reply
close(result)
},
opt...,
); err != nil {
return nil, fmt.Errorf("adding reply behavior: %w", err)
}
if err := a.Send(msg); err != nil {
// Roll back the reply behavior on send failure.
a.dispatch.RemoveBehavior(msg.Options.ReplyTo)
return nil, fmt.Errorf("sending message: %w", err)
}
return result, nil
}
// Stop shuts down the actor: it closes the dispatch (cancelling its
// context and all worker goroutines) and removes the actor's scheduled
// task. It always returns nil.
func (a *BasicActor) Stop() error {
a.dispatch.close()
a.scheduler.RemoveTask(a.addedTaskID)
return nil
}
package actor
import (
"context"
"fmt"
"sync"
"time"
)
var (
// DefaultDispatchGCInterval is how often expired behaviors are purged.
DefaultDispatchGCInterval = 120 * time.Second
// DefaultDispatchWorkers is the default number of recv worker goroutines.
DefaultDispatchWorkers = 1
)
// DispatchLimiter implements a (potentially) stateful resource access limiter
// This is necessary to combat spam attacks and ensure that our system does not
// become overloaded with too many goroutines.
type DispatchLimiter interface {
Acquire(msg Envelope) error
Release(msg Envelope)
}
// NoDispatchLimiter is the null limiter, that does not rate limit
type NoDispatchLimiter struct{}
// Dispatch provides a reaction kernel with multithreaded dispatch and oneshot
// continuations.
type Dispatch struct {
ctx context.Context
close func()
sctx SecurityContext
mx sync.Mutex // guards behaviors and started
q chan Envelope // incoming message queue
vq chan Envelope // verified message queue
behaviors map[string]*BehaviorState
started bool
options DispatchOptions
}
// DispatchOptions holds tunables applied via DispatchOption functions.
type DispatchOptions struct {
Limiter DispatchLimiter
GCInterval time.Duration
Workers int
}
// BehaviorState pairs a registered continuation with its options
// (capabilities, expiry, one-shot flag).
type BehaviorState struct {
cont Behavior
opt BehaviorOptions
}
// DispatchOption mutates DispatchOptions at construction time.
type DispatchOption func(o *DispatchOptions)
// WithDispatchWorkers sets the number of recv worker goroutines.
func WithDispatchWorkers(count int) DispatchOption {
return func(o *DispatchOptions) {
o.Workers = count
}
}
// WithDispatchGCInterval sets the interval at which expired behaviors
// are garbage collected.
func WithDispatchGCInterval(dt time.Duration) DispatchOption {
return func(o *DispatchOptions) {
o.GCInterval = dt
}
}
// WithDispatchLimiter sets the rate limiter used before invoking behaviors.
func WithDispatchLimiter(limiter DispatchLimiter) DispatchOption {
return func(o *DispatchOptions) {
o.Limiter = limiter
}
}
// Acquire on the null limiter always succeeds.
func (l NoDispatchLimiter) Acquire(_ Envelope) error { return nil }
// Release on the null limiter is a no-op.
func (l NoDispatchLimiter) Release(_ Envelope) {}
// NewDispatch creates a dispatch kernel bound to the given security
// context. Defaults (single worker, null limiter, default GC interval)
// are applied first and may be overridden by the supplied options.
// The dispatch is not started; call Start.
func NewDispatch(sctx SecurityContext, opt ...DispatchOption) *Dispatch {
	ctx, cancel := context.WithCancel(context.Background())
	d := &Dispatch{
		ctx:       ctx,
		close:     cancel,
		sctx:      sctx,
		q:         make(chan Envelope),
		vq:        make(chan Envelope),
		behaviors: map[string]*BehaviorState{},
		options: DispatchOptions{
			Limiter:    NoDispatchLimiter{},
			GCInterval: DefaultDispatchGCInterval,
			Workers:    DefaultDispatchWorkers,
		},
	}
	for _, apply := range opt {
		apply(&d.options)
	}
	return d
}
// Start launches the recv workers, the dispatcher and the GC loop.
// It is idempotent: subsequent calls are no-ops.
func (k *Dispatch) Start() {
	k.mx.Lock()
	defer k.mx.Unlock()
	if k.started {
		return
	}
	k.started = true
	for w := 0; w < k.options.Workers; w++ {
		go k.recv()
	}
	go k.dispatch()
	go k.gc()
}
// AddBehavior registers (or replaces) the continuation for a behavior.
// By default the behavior requires a capability equal to its own name;
// options may override capabilities, expiry, or one-shot semantics.
func (k *Dispatch) AddBehavior(behavior string, continuation Behavior, opt ...BehaviorOption) error {
	state := &BehaviorState{
		cont: continuation,
		opt: BehaviorOptions{
			Capability: []Capability{Capability(behavior)},
		},
	}
	for _, apply := range opt {
		if err := apply(&state.opt); err != nil {
			return fmt.Errorf("adding behavior: %w", err)
		}
	}
	k.mx.Lock()
	k.behaviors[behavior] = state
	k.mx.Unlock()
	return nil
}
// RemoveBehavior unregisters the named behavior; removing a behavior that
// was never added is a no-op.
func (k *Dispatch) RemoveBehavior(behavior string) {
k.mx.Lock()
defer k.mx.Unlock()
delete(k.behaviors, behavior)
}
// Receive enqueues an envelope for verification and dispatch. It blocks
// until a worker accepts the message or the dispatch context is cancelled,
// in which case the context error is returned.
func (k *Dispatch) Receive(msg Envelope) error {
select {
case k.q <- msg:
return nil
case <-k.ctx.Done():
return k.ctx.Err()
}
}
// Context returns the dispatch's context, cancelled on close.
func (k *Dispatch) Context() context.Context {
return k.ctx
}
// recv is a worker loop: it pulls envelopes off the incoming queue,
// verifies them against the security context, and forwards verified
// envelopes to the dispatch queue. It exits when the dispatch context
// is cancelled.
func (k *Dispatch) recv() {
	for {
		select {
		case msg := <-k.q:
			if err := k.sctx.Verify(msg); err != nil {
				// Drop the bad message but keep this worker alive.
				// The previous code returned here, which permanently
				// killed the worker on the first failed verification.
				// TODO log debug
				fmt.Println("verify failed", err.Error())
				continue
			}
			k.vq <- msg
		case <-k.ctx.Done():
			return
		}
	}
}
// dispatch is the main reaction loop: it pulls verified envelopes,
// looks up the registered behavior, enforces expiry and capability
// requirements, applies rate limiting, and runs the continuation in
// its own goroutine. The mutex is held only around behavior-map access,
// never across the limiter call or the continuation.
func (k *Dispatch) dispatch() {
for {
select {
case msg := <-k.vq:
k.mx.Lock()
b, ok := k.behaviors[msg.Behavior]
if !ok {
// Unknown behavior: silently drop.
// TODO log debug
k.mx.Unlock()
continue
}
if b.Expired(time.Now()) {
// Lazy expiry: remove on access; gc() also prunes periodically.
delete(k.behaviors, msg.Behavior)
// TODO log debug
k.mx.Unlock()
continue
}
if err := k.sctx.Require(msg, b.opt.Capability...); err != nil {
// Message lacks the required capabilities: drop.
// TODO log warn
k.mx.Unlock()
continue
}
if b.opt.OneShot {
// One-shot behaviors are deregistered before invocation so a
// second message cannot trigger them again.
delete(k.behaviors, msg.Behavior)
}
k.mx.Unlock()
if err := k.options.Limiter.Acquire(msg); err != nil {
// Rate limited: drop. NOTE(review): a one-shot behavior removed
// above is lost if the limiter rejects here — confirm intended.
// TODO log warn
continue
}
go func() {
defer k.options.Limiter.Release(msg)
b.cont(msg)
}()
case <-k.ctx.Done():
return
}
}
}
// gc periodically sweeps the behavior table and removes entries whose
// expiry has passed. It stops when the dispatch context is cancelled.
func (k *Dispatch) gc() {
	ticker := time.NewTicker(k.options.GCInterval)
	defer ticker.Stop()
	for {
		select {
		case <-k.ctx.Done():
			return
		case <-ticker.C:
			now := time.Now()
			k.mx.Lock()
			for name, state := range k.behaviors {
				if state.Expired(now) {
					delete(k.behaviors, name)
				}
			}
			k.mx.Unlock()
		}
	}
}
// Expired reports whether the behavior's expiry (nanoseconds since epoch)
// has passed. A zero Expire means the behavior never expires.
func (b *BehaviorState) Expired(now time.Time) bool {
if b.opt.Expire > 0 {
return uint64(now.UnixNano()) > b.opt.Expire
}
return false
}
// WithBehaviorExpiry sets an absolute expiry (UnixNano) for the behavior.
func WithBehaviorExpiry(expire uint64) BehaviorOption {
return func(opt *BehaviorOptions) error {
opt.Expire = expire
return nil
}
}
// WithBehaviorCapability replaces the default capability set (which is the
// behavior's own name) with the given capabilities.
func WithBehaviorCapability(cap ...Capability) BehaviorOption {
return func(opt *BehaviorOptions) error {
opt.Capability = cap
return nil
}
}
package actor
import (
"bytes"
)
// Equals reports whether two actor IDs carry the same public key bytes.
func (id ID) Equals(other ID) bool {
return bytes.Equal(id.PublicKey, other.PublicKey)
}
// Equals reports whether two DIDs carry the same public key bytes.
func (did DID) Equals(other DID) bool {
return bytes.Equal(did.PublicKey, other.PublicKey)
}
package actor
import (
"crypto/rand"
"fmt"
"github.com/libp2p/go-libp2p/core/crypto"
)
// KeyTypeEd25519 is the only key type currently supported by GenerateKeyPair.
const KeyTypeEd25519 = crypto.Ed25519
// Aliases re-export the libp2p crypto key types so callers of this package
// do not need to import go-libp2p directly.
type Key = crypto.Key
type PrivKey = crypto.PrivKey
type PubKey = crypto.PubKey
// GenerateKeyPair creates a fresh key pair of the given type. Only
// Ed25519 is supported; any other type yields ErrUnsupportedKeyType.
func GenerateKeyPair(t int) (PrivKey, PubKey, error) {
	if t == KeyTypeEd25519 {
		return crypto.GenerateEd25519Key(rand.Reader)
	}
	return nil, nil, fmt.Errorf("unsupported key type %d: %w", t, ErrUnsupportedKeyType)
}
// PublicKeyToBytes serializes a public key using the libp2p wire format.
func PublicKeyToBytes(k PubKey) ([]byte, error) {
return crypto.MarshalPublicKey(k)
}
// BytesToPublicKey parses a public key from the libp2p wire format.
func BytesToPublicKey(data []byte) (PubKey, error) {
return crypto.UnmarshalPublicKey(data)
}
// PrivateKeyToBytes serializes a private key using the libp2p wire format.
func PrivateKeyToBytes(k PrivKey) ([]byte, error) {
return crypto.MarshalPrivateKey(k)
}
// BytesToPrivateKey parses a private key from the libp2p wire format.
func BytesToPrivateKey(data []byte) (PrivKey, error) {
return crypto.UnmarshalPrivateKey(data)
}
// IDFromPublicKey derives an actor ID from a public key by embedding the
// key's marshaled bytes.
func IDFromPublicKey(k PubKey) (ID, error) {
	raw, err := PublicKeyToBytes(k)
	if err != nil {
		return ID{}, fmt.Errorf("id from public key: %w", err)
	}
	return ID{PublicKey: raw}, nil
}

// PublicKeyFromID recovers the public key embedded in an actor ID.
func PublicKeyFromID(id ID) (PubKey, error) {
	return BytesToPublicKey(id.PublicKey)
}
package actor

// BasicDispatchLimiter is a placeholder for a stateful dispatch limiter.
// It is intended to satisfy the DispatchLimiter interface, whose methods
// are Acquire and Release; the original only defined Reserve, so the type
// never actually satisfied the interface. Acquire is added here and
// Reserve retained for backward compatibility.
type BasicDispatchLimiter struct {
	// TODO
}

// Acquire implements the DispatchLimiter interface by delegating to Reserve.
func (l *BasicDispatchLimiter) Acquire(msg Envelope) error {
	return l.Reserve(msg)
}

// Reserve attempts to reserve capacity for dispatching msg.
// Deprecated: use Acquire instead.
func (l *BasicDispatchLimiter) Reserve(_ Envelope) error {
	// TODO we can leave this for follow up
	return ErrTODO
}

// Release returns the capacity reserved for msg.
func (l *BasicDispatchLimiter) Release(_ Envelope) {
	// TODO we can leave this for follow up
}
package actor
import (
"encoding/json"
"fmt"
"time"
)
const (
// heartbeatBehavior is the well-known behavior actors use for heartbeats.
heartbeatBehavior = "/dms/actor/heartbeat"
// defaultMessageTimeout is the implicit expiry applied by Message.
defaultMessageTimeout = 30 * time.Second
)
// signaturePrefix is prepended to envelope bytes before signing, to
// domain-separate DMS signatures from other uses of the same key.
var signaturePrefix = []byte("dms:")
// HeartbeatMessage is the (empty) payload of a heartbeat.
type HeartbeatMessage struct{}
// Message constructs a new message envelope and applies the options
// Message constructs an envelope from src to dest for the given behavior,
// JSON-marshaling the payload and applying the options in order. A default
// expiry of defaultMessageTimeout from now is set before options run, so
// options such as WithMessageTimeout can override it.
func Message(src Handle, dest Handle, behavior string, payload interface{}, opt ...MessageOption) (Envelope, error) {
	body, err := json.Marshal(payload)
	if err != nil {
		return Envelope{}, fmt.Errorf("marshaling payload: %w", err)
	}
	env := Envelope{
		From:     src,
		To:       dest,
		Behavior: behavior,
		Message:  body,
		Options: EnvelopeOptions{
			Expire: uint64(time.Now().Add(defaultMessageTimeout).UnixNano()),
		},
	}
	for _, apply := range opt {
		if err := apply(&env); err != nil {
			return Envelope{}, fmt.Errorf("setting message option: %w", err)
		}
	}
	return env, nil
}
// WithMessageSignature signs the envelope with the given security context,
// delegating the listed capabilities.
//
// NOTE: this option must be passed last, otherwise the signature will be invalidated by further modifications.
//
// NOTE: Signing is implicit in Send.
func WithMessageSignature(sctx SecurityContext, cap ...Capability) MessageOption {
return func(msg *Envelope) error {
// The signing context must belong to the message sender.
if !msg.From.ID.Equals(sctx.ID()) {
return ErrInvalidSecurityContext
}
msg.Nonce = sctx.Nonce()
return sctx.Provide(msg, cap...)
}
}
// WithMessageTimeout sets the message expiration from a relative timeout.
//
// NOTE: messages created with Message have an implicit timeout of defaultMessageTimeout.
func WithMessageTimeout(timeo time.Duration) MessageOption {
return func(msg *Envelope) error {
msg.Options.Expire = uint64(time.Now().Add(timeo).UnixNano())
return nil
}
}
// WithMessageExpiry sets an absolute message expiry (UnixNano).
//
// NOTE: messages created with Message have an implicit timeout of defaultMessageTimeout.
func WithMessageExpiry(expiry uint64) MessageOption {
return func(msg *Envelope) error {
msg.Options.Expire = expiry
return nil
}
}
// WithMessageReplyTo sets the message's reply-to behavior.
//
// NOTE: ReplyTo is set implicitly in Invoke, and the appropriate capability
// tokens are delegated by Provide.
func WithMessageReplyTo(replyto string) MessageOption {
return func(msg *Envelope) error {
msg.Options.ReplyTo = replyto
return nil
}
}
// signatureData returns the canonical bytes that are signed/verified for
// this envelope: the domain-separation prefix followed by the JSON encoding
// of the envelope with its Signature field cleared.
func (msg *Envelope) signatureData() ([]byte, error) {
	unsigned := *msg
	unsigned.Signature = nil
	encoded, err := json.Marshal(&unsigned)
	if err != nil {
		return nil, fmt.Errorf("signature data: %w", err)
	}
	out := make([]byte, 0, len(signaturePrefix)+len(encoded))
	out = append(out, signaturePrefix...)
	out = append(out, encoded...)
	return out, nil
}

// Expired reports whether the envelope's expiry (UnixNano) has passed.
// A zero expiry means the envelope never expires.
func (msg *Envelope) Expired() bool {
	return msg.Options.Expire > 0 && uint64(time.Now().UnixNano()) > msg.Options.Expire
}
package actor
import (
"errors"
"sync"
)
// Info describes a registered actor: its handle, its parent, and the
// actors it supervises.
type Info struct {
Addr *Handle
Parent *Handle
Children []Handle
}
// Registry tracks actors by inbox address and their supervision links.
type Registry interface {
Actors() map[string]Info
Add(a Handle, parent Handle, children []Handle) error
Get(a Handle) (Info, bool)
SetParent(a Handle, parent Handle) error
GetParent(a Handle) (*Handle, bool)
}
// registry is the mutex-guarded in-memory Registry implementation,
// keyed by the actor's inbox address.
type registry struct {
mx sync.Mutex
actors map[string]Info
}
// NewRegistry returns an empty in-memory Registry.
func NewRegistry() Registry {
return &registry{
actors: make(map[string]Info),
}
}
// Actors returns a snapshot copy of the registered actors so callers can
// iterate without holding the registry lock.
func (r *registry) Actors() map[string]Info {
	r.mx.Lock()
	defer r.mx.Unlock()
	snapshot := make(map[string]Info, len(r.actors))
	for addr, info := range r.actors {
		snapshot[addr] = info
	}
	return snapshot
}
// Add registers an actor under its inbox address, with its parent and
// children. Adding an already-registered actor is an error. A nil children
// slice is normalized to an empty slice.
func (r *registry) Add(a Handle, parent Handle, children []Handle) error {
	r.mx.Lock()
	defer r.mx.Unlock()
	key := a.Address.InboxAddress
	if _, exists := r.actors[key]; exists {
		return errors.New("actor already exists")
	}
	kids := children
	if kids == nil {
		kids = []Handle{}
	}
	r.actors[key] = Info{
		Addr:     &a,
		Parent:   &parent,
		Children: kids,
	}
	return nil
}
// Get returns the Info for the actor, if registered.
func (r *registry) Get(a Handle) (Info, bool) {
	r.mx.Lock()
	defer r.mx.Unlock()
	entry, found := r.actors[a.Address.InboxAddress]
	return entry, found
}

// SetParent replaces the parent handle of a registered actor; it is an
// error if the actor is unknown.
func (r *registry) SetParent(a Handle, parent Handle) error {
	r.mx.Lock()
	defer r.mx.Unlock()
	key := a.Address.InboxAddress
	entry, found := r.actors[key]
	if !found {
		return errors.New("actor not found")
	}
	entry.Parent = &parent
	r.actors[key] = entry
	return nil
}

// GetParent returns the parent handle of a registered actor, or false if
// the actor is unknown.
func (r *registry) GetParent(a Handle) (*Handle, bool) {
	r.mx.Lock()
	defer r.mx.Unlock()
	entry, found := r.actors[a.Address.InboxAddress]
	if !found {
		return nil, false
	}
	return entry.Parent, true
}
package actor
import (
"fmt"
"sync"
"time"
)
// BasicSecurityContext is a SecurityContext backed by a single key pair
// and a monotonically increasing nonce.
type BasicSecurityContext struct {
id ID
did DID
privk PrivKey
mx sync.Mutex // guards nonce
nonce uint64
}
// Compile-time interface satisfaction check.
var _ SecurityContext = (*BasicSecurityContext)(nil)
// NewBasicSecurityContext creates a security context from a key pair and
// DID. The actor ID is derived from the public key, and the nonce is
// seeded from the current time so restarts do not reuse old nonces.
func NewBasicSecurityContext(pubk PubKey, privk PrivKey, did DID) (*BasicSecurityContext, error) {
	id, err := IDFromPublicKey(pubk)
	if err != nil {
		return nil, fmt.Errorf("creating security context: %w", err)
	}
	return &BasicSecurityContext{
		id:    id,
		did:   did,
		privk: privk,
		nonce: uint64(time.Now().UnixNano()),
	}, nil
}
// ID returns the actor ID derived from the context's public key.
func (s *BasicSecurityContext) ID() ID {
return s.id
}
// DID returns the context's decentralized identifier.
func (s *BasicSecurityContext) DID() DID {
return s.did
}
// Nonce returns the next nonce, incrementing the internal counter
// under the lock so concurrent callers get distinct values.
func (s *BasicSecurityContext) Nonce() uint64 {
s.mx.Lock()
defer s.mx.Unlock()
nonce := s.nonce
s.nonce++
return nonce
}
// Require checks that the envelope carries the required capabilities.
// Currently a stub that always succeeds.
func (s *BasicSecurityContext) Require(_ Envelope, _ ...Capability) error {
//
// TODO check capability tokens for required capabilities
// we do nothing for now, will be implemented in follow up
return nil
}
// Provide attaches capability tokens to the envelope and signs it.
// Capability provisioning is currently a stub; only signing happens.
func (s *BasicSecurityContext) Provide(msg *Envelope, _ ...Capability) error {
// TODO provide capability tokes for the required capabilities
// we do nothing for now, will be implemented in follow up
return s.Sign(msg)
}
// Verify checks the envelope: it must not be expired, and its signature
// must verify against the sender's public key (recovered from msg.From.ID)
// over the canonical signatureData bytes.
func (s *BasicSecurityContext) Verify(msg Envelope) error {
if msg.Expired() {
return ErrMessageExpired
}
pubk, err := PublicKeyFromID(msg.From.ID)
if err != nil {
return fmt.Errorf("public key from id: %w", err)
}
data, err := msg.signatureData()
if err != nil {
return fmt.Errorf("signature data: %w", err)
}
ok, err := pubk.Verify(data, msg.Signature)
if err != nil {
return fmt.Errorf("verify message signature: %w", err)
}
if !ok {
return ErrSignatureVerification
}
return nil
}
// Sign computes the envelope's canonical signature bytes, signs them with
// the context's private key, and stores the signature on the envelope.
func (s *BasicSecurityContext) Sign(msg *Envelope) error {
	payload, err := msg.signatureData()
	if err != nil {
		return fmt.Errorf("signature data: %w", err)
	}
	signature, err := s.privk.Sign(payload)
	if err != nil {
		return fmt.Errorf("signing message: %w", err)
	}
	msg.Signature = signature
	return nil
}
package actor
import (
"encoding/base32"
"encoding/json"
"fmt"
)
// IDJSONView is the JSON wire shape of an ID (base32 string).
type IDJSONView struct {
ID string
}
// DIDJSONView is the JSON wire shape of a DID (base32 string).
type DIDJSONView struct {
DID string
}
// String renders the ID's public key bytes as standard base32.
func (id ID) String() string {
return base32.StdEncoding.EncodeToString(id.PublicKey)
}
// MarshalJSON encodes the ID via IDJSONView.
func (id ID) MarshalJSON() ([]byte, error) {
return json.Marshal(IDJSONView{ID: id.String()})
}
var _ json.Marshaler = ID{}
// IDFromString parses a base32-encoded ID string (the inverse of ID.String).
func IDFromString(s string) (ID, error) {
	raw, err := base32.StdEncoding.DecodeString(s)
	if err != nil {
		return ID{}, fmt.Errorf("decode ID: %w", err)
	}
	return ID{PublicKey: raw}, nil
}

// UnmarshalJSON decodes an ID from its IDJSONView wire shape.
func (id *ID) UnmarshalJSON(data []byte) error {
	var view IDJSONView
	if err := json.Unmarshal(data, &view); err != nil {
		return fmt.Errorf("unmarshaling ID: %w", err)
	}
	parsed, err := IDFromString(view.ID)
	if err != nil {
		return fmt.Errorf("unmarshaling ID: %w", err)
	}
	*id = parsed
	return nil
}

var _ json.Unmarshaler = (*ID)(nil)
// String renders the DID's public key bytes as standard base32.
func (did DID) String() string {
return base32.StdEncoding.EncodeToString(did.PublicKey)
}
// MarshalJSON encodes the DID via DIDJSONView.
func (did DID) MarshalJSON() ([]byte, error) {
return json.Marshal(DIDJSONView{DID: did.String()})
}
var _ json.Marshaler = DID{}
// DIDFromString parses a base32-encoded DID string (the inverse of DID.String).
func DIDFromString(s string) (DID, error) {
	raw, err := base32.StdEncoding.DecodeString(s)
	if err != nil {
		return DID{}, fmt.Errorf("decode DID: %w", err)
	}
	return DID{PublicKey: raw}, nil
}

// UnmarshalJSON decodes a DID from its DIDJSONView wire shape.
func (did *DID) UnmarshalJSON(data []byte) error {
	var view DIDJSONView
	if err := json.Unmarshal(data, &view); err != nil {
		return fmt.Errorf("unmarshaling DID: %w", err)
	}
	parsed, err := DIDFromString(view.DID)
	if err != nil {
		return fmt.Errorf("unmarshaling DID: %w", err)
	}
	*did = parsed
	return nil
}

var _ json.Unmarshaler = (*DID)(nil)
// String renders the address as "hostID:inboxAddress".
func (a Address) String() string {
return a.HostID + ":" + a.InboxAddress
}
// AddressFromString parses an address string. Not yet implemented.
func AddressFromString(_ string) (Address, error) {
// TODO
return Address{}, ErrTODO
}
// String renders the handle as "id[did]@address".
func (a Handle) String() string {
return fmt.Sprintf("%s[%s]@%s", a.ID, a.DID, a.Address)
}
// HandleFromString parses a handle string. Not yet implemented.
func HandleFromString(_ string) (Handle, error) {
// TODO
return Handle{}, ErrTODO
}
package actor
import (
"time"
)
func jitter(interval time.Duration, pct float64) time.Duration {
return time.Duration(float64(interval) * pct)
}
package jobs
import (
"context"
"errors"
"fmt"
"reflect"
"github.com/google/uuid"
"gitlab.com/nunet/device-management-service/dms"
"gitlab.com/nunet/device-management-service/dms/resources"
"gitlab.com/nunet/device-management-service/executor"
"gitlab.com/nunet/device-management-service/executor/docker"
"gitlab.com/nunet/device-management-service/executor/firecracker"
"gitlab.com/nunet/device-management-service/types"
)
// Status holds the status of an allocation.
type Status struct {
JobResources types.ExecutionResources
Status AllocationStatus
}
// AllocationDetails encapsulates the dependencies to the constructor.
type AllocationDetails struct {
Job Job
NodeID string
SourceID string
}
// AllocationStatus is a representation of the execution status
type AllocationStatus string
// Allocation lifecycle states.
const (
pending AllocationStatus = "pending"
running AllocationStatus = "running"
stopped AllocationStatus = "stopped"
)
// Allocation represents an allocation: a job bound to a node, with an
// actor for messaging, an executor for the workload, and a resource
// manager for capacity accounting.
type Allocation struct {
ID string
Job Job
status AllocationStatus
NodeID string
SourceID string
executionID string
actor *dms.BasicActor
executor executor.Executor
resourceManager resources.Manager
}
// NewAllocation creates a new allocation given the actor. It generates
// fresh UUIDs for the allocation and its execution, and starts in the
// pending state. The resource manager must be non-nil.
func NewAllocation(actor *dms.BasicActor, details AllocationDetails, resourceManager resources.Manager) (*Allocation, error) {
	if resourceManager == nil {
		return nil, errors.New("resource manager is nil")
	}
	allocID, err := uuid.NewUUID()
	if err != nil {
		return nil, fmt.Errorf("failed to generate uuid for allocation: %w", err)
	}
	execID, err := uuid.NewUUID()
	if err != nil {
		return nil, fmt.Errorf("failed to create executor id: %w", err)
	}
	alloc := &Allocation{
		ID:              allocID.String(),
		Job:             details.Job,
		NodeID:          details.NodeID,
		SourceID:        details.SourceID,
		actor:           actor,
		executionID:     execID.String(),
		resourceManager: resourceManager,
		status:          pending,
	}
	return alloc, nil
}
// Run creates the executor based on the execution engine configuration.
// It refreshes free resources, verifies the job fits, lazily creates the
// executor, starts the execution, and marks the allocation running.
func (a *Allocation) Run(ctx context.Context) error {
freeResources, err := a.resourceManager.UpdateFreeResources(ctx)
if err != nil {
return fmt.Errorf("failed to get free resources: %w", err)
}
if !availableResources(a.Job.Resources, freeResources) {
return fmt.Errorf("no available resources for job %s", a.Job.ID)
}
// if executor is nil create it
if a.executor == nil {
err = a.createExecutor(ctx, a.Job.Execution)
if err != nil {
return fmt.Errorf("failed to create executor: %w", err)
}
}
// NOTE(review): if a.Job.Execution.Type matches no known executor,
// a.executor can still be nil here and this call will panic — confirm
// createExecutor guards against unknown types.
err = a.executor.Start(ctx, &types.ExecutionRequest{
JobID: a.Job.ID,
ExecutionID: a.executionID,
EngineSpec: &a.Job.Execution,
Resources: &a.Job.Resources,
// TODO add the following
Inputs: []*types.StorageVolumeExecutor{},
Outputs: []*types.StorageVolumeExecutor{},
ResultsDir: "",
})
if err != nil {
return fmt.Errorf("failed to start executor: %w", err)
}
// Refresh accounting now that the execution holds resources.
_, err = a.resourceManager.UpdateFreeResources(ctx)
if err != nil {
return fmt.Errorf("failed to update resources after running allocation's executor: %w", err)
}
a.status = running
return nil
}
// Stop stops the running executor. It cancels the current execution,
// marks the allocation stopped, and refreshes the resource accounting.
// Calling Stop on an allocation that is not running is an error.
func (a *Allocation) Stop(ctx context.Context) error {
	if a.status != running {
		return errors.New("allocation is not running")
	}
	err := a.executor.Cancel(ctx, a.executionID)
	if err != nil {
		return fmt.Errorf("failed to stop execution: %w", err)
	}
	a.status = stopped
	_, err = a.resourceManager.UpdateFreeResources(ctx)
	if err != nil {
		// Fixed typo in the error message: "stoping" -> "stopping".
		return fmt.Errorf("failed to update resources after stopping allocation's executor: %w", err)
	}
	return nil
}
// Status returns information about the allocated/usage of resources and execution status of workload.
func (a *Allocation) Status(_ context.Context) Status {
return Status{
JobResources: a.Job.Resources,
Status: a.status,
}
}
// StartActor starts the actor of the allocation.
func (a *Allocation) StartActor() error {
err := a.actor.Start()
if err != nil {
return fmt.Errorf("failed to start allocation actor: %w", err)
}
return nil
}
// ProcessMessages processes actor messages. It blocks, draining the
// actor's message channel until it is closed, dispatching each message
// by its type name.
func (a *Allocation) ProcessMessages() {
for msg := range a.actor.Messages() {
a.dispatchMethod(msg.Type, msg.Data)
}
}
// SendMessage sends a message through the actor.
func (a *Allocation) SendMessage(ctx context.Context, destination *dms.ActorAddrInfo, m *dms.Message) error {
return a.actor.SendMessage(ctx, destination, m)
}
// dispatchMethod routes a message to a handler method named
// "Handle<methodName>" via reflection: first on the Allocation itself,
// then on the actor. If neither has the method, the message is silently
// dropped.
// NOTE(review): reflect.Value.Call panics if the handler's arity or
// argument types do not match args — handlers must match exactly.
func (a *Allocation) dispatchMethod(methodName string, args ...any) {
handlerMethod := fmt.Sprintf("Handle%s", methodName)
arguments := make([]reflect.Value, 0)
for _, v := range args {
arguments = append(arguments, reflect.ValueOf(v))
}
method := reflect.ValueOf(a).MethodByName(handlerMethod)
if method.IsValid() {
method.Call(arguments)
return
}
// check if actor has the method
actorMethod := reflect.ValueOf(a.actor).MethodByName(handlerMethod)
if actorMethod.IsValid() {
actorMethod.Call(arguments)
}
}
// createExecutor instantiates the executor matching the execution engine
// type in conf. Previously an unknown type fell through and returned nil
// while leaving a.executor nil, which caused a nil-pointer panic later in
// Run; it is now reported as an error.
func (a *Allocation) createExecutor(ctx context.Context, conf types.SpecConfig) error {
	switch conf.Type {
	case types.ExecutorTypeFirecracker:
		executor, err := firecracker.NewExecutor(ctx, a.executionID)
		if err != nil {
			return fmt.Errorf("firecracker executor: %w", err)
		}
		a.executor = executor
	case types.ExecutorTypeDocker:
		executor, err := docker.NewExecutor(ctx, a.executionID)
		if err != nil {
			return fmt.Errorf("docker executor: %w", err)
		}
		a.executor = executor
	default:
		return fmt.Errorf("unsupported executor type: %v", conf.Type)
	}
	return nil
}
// TODO: ExecutionResources and FreeResources should be compatible
// availableResources reports whether the free resources can accommodate
// the job's requested RAM, disk, and CPU.
// NOTE(review): CPU uses a strict ">" while RAM/disk use ">=" — a job
// requiring exactly the free CPU is rejected; confirm this is intended.
func availableResources(jobResources types.ExecutionResources, fr types.FreeResources) bool {
return fr.RAM >= uint64(jobResources.Memory.Size) && fr.Disk >= jobResources.Disk.Size && fr.CPU > float64(jobResources.CPU.Cores)
}
package parser
import (
"gitlab.com/nunet/device-management-service/dms/jobs"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/nunet"
)
// registry is the package-level parser registry, populated in init.
var registry Registry[jobs.JobSpec]
// init wires up the built-in parsers. Currently only the NuNet spec
// format is registered.
func init() {
registry = &RegistryImpl[jobs.JobSpec]{
parsers: make(map[SpecType]Parser[jobs.JobSpec]),
}
// Register Nunet parser.
nunetParser := NewParser[jobs.JobSpec](
nunet.NewNuNetTransformer(),
nunet.NewNuNetValidator(),
)
registry.RegisterParser(specTypeNuNet, nunetParser)
}
package nunet
import (
"fmt"
"strconv"
"strings"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/transform"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/tree"
)
// NewNuNetTransformer creates the transformer pipeline for the NuNet
// configuration format. Transformations run in three ordered passes:
// first map-to-slice conversions of the collection sections, then
// per-element normalization, then engine/remote sub-config shaping.
func NewNuNetTransformer() transform.Transformer {
return transform.NewTransformer(
[]map[tree.Path]transform.TransformerFunc{
{
"jobs": TransformJobs,
"jobs.**.children": TransformJobs,
"jobs.**.volumes": TransformVolumes,
"jobs.**.networks": TransformNetworks,
},
{
"jobs.**.volumes.[]": TransformVolume,
"jobs.**.networks.[]": TransformNetwork,
"jobs.**.libraries.[]": TransformLibrary,
},
{
"jobs.**.execution": TransformExecution,
"jobs.**.volumes.[].remote": TransformVolumeRemote,
},
},
)
}
// TransformJobs transforms the jobs map to a slice and assigns the keys
// to the "name" field. Nil input passes through unchanged.
func TransformJobs(_ *map[string]interface{}, data any, _ tree.Path) (any, error) {
	if data == nil {
		return nil, nil
	}
	section, ok := data.(map[string]any)
	if !ok {
		return nil, fmt.Errorf("invalid jobs configuration: %v", data)
	}
	return transform.MapToSlice(section)
}

// TransformVolumes transforms the volumes map to a slice and assigns the
// keys to the "name" field. Nil input passes through unchanged.
func TransformVolumes(_ *map[string]interface{}, data any, _ tree.Path) (any, error) {
	if data == nil {
		return nil, nil
	}
	section, ok := data.(map[string]any)
	if !ok {
		return nil, fmt.Errorf("invalid volumes configuration: %v", data)
	}
	return transform.MapToSlice(section)
}

// TransformNetworks transforms the networks map to a slice and assigns
// the keys to the "name" field. Nil input passes through unchanged.
func TransformNetworks(_ *map[string]interface{}, data any, _ tree.Path) (any, error) {
	if data == nil {
		return nil, nil
	}
	section, ok := data.(map[string]any)
	if !ok {
		return nil, fmt.Errorf("invalid networks configuration: %v", data)
	}
	return transform.MapToSlice(section)
}
// TransformExecution transforms the engine configuration from a flat map
// to SpecConfig format: the "type" key is kept at the top level and every
// other key is moved under "params".
func TransformExecution(_ *map[string]interface{}, data any, _ tree.Path) (any, error) {
	if data == nil {
		return nil, nil
	}
	engine, ok := data.(map[string]any)
	if !ok {
		return nil, fmt.Errorf("invalid engine configuration: %v", data)
	}
	params := make(map[string]any, len(engine))
	for key, val := range engine {
		if key == "type" {
			continue
		}
		params[key] = val
	}
	return map[string]any{
		"type":   engine["type"],
		"params": params,
	}, nil
}
// TransformVolume transforms the volume configuration and handles inheritance.
// The volume configuration can be a string in the format "name:mountpoint" or a map.
// If the volume is defined in the parent volumes, the configurations are merged
// (parent values overwrite the child's keys of the same name).
func TransformVolume(root *map[string]interface{}, data any, path tree.Path) (any, error) {
	var config map[string]any
	// If the data is a string, split it into name and mountpoint.
	switch v := data.(type) {
	case string:
		mapping := strings.Split(v, ":")
		if len(mapping) != 2 {
			return nil, fmt.Errorf("invalid volume configuration: %v", data)
		}
		config = map[string]any{
			"name":       mapping[0],
			"mountpoint": mapping[1],
		}
	case map[string]any:
		config = v
	default:
		return nil, fmt.Errorf("invalid volume configuration: %v", data)
	}
	// Collect all potential parent paths where the volume could be defined.
	parentPaths := []tree.Path{}
	pathParts := path.Parts()
	for i, part := range pathParts {
		if part == "children" {
			parentPaths = append(parentPaths, tree.NewPath(pathParts[:i]...))
		}
	}
	// Merge the volume configuration with the parent configurations.
	for _, parent := range parentPaths {
		// Check if the volume exists in the parent. A missing "volumes"
		// section is expected for many parents, so the error is ignored.
		// (Removed a leftover debug fmt.Println that wrote to stdout here.)
		c, err := transform.GetConfigAtPath(*root, parent.Next("volumes"))
		if err != nil {
			continue
		}
		volumes, _ := transform.ToAnySlice(c)
		for _, v := range volumes {
			if volume, ok := v.(map[string]any); ok && volume["name"] == config["name"] {
				// Merge the configurations
				for k, v := range volume {
					config[k] = v
				}
			}
		}
	}
	return config, nil
}
// TransformVolumeRemote shapes a volume's "remote" section into
// {"type": ..., "params": {...}}. If an explicit "params" key exists it is
// used as-is; otherwise all keys except "type" are gathered into params.
// The previous version used an unchecked type assertion on "params" and
// panicked when it was not a map; it now returns an error.
func TransformVolumeRemote(_ *map[string]interface{}, data any, _ tree.Path) (any, error) {
	if data == nil {
		return nil, nil
	}
	config, ok := data.(map[string]any)
	if !ok {
		return nil, fmt.Errorf("invalid volume configuration: %v", data)
	}
	remoteConfig := map[string]any{}
	remoteConfig["type"] = config["type"]
	if params, ok := config["params"]; ok {
		paramsMap, ok := params.(map[string]any)
		if !ok {
			return nil, fmt.Errorf("invalid volume configuration: %v", data)
		}
		remoteConfig["params"] = paramsMap
		return remoteConfig, nil
	}
	params := map[string]any{}
	for k, v := range config {
		if k != "type" {
			params[k] = v
		}
	}
	remoteConfig["params"] = params
	return remoteConfig, nil
}
// TransformNetwork transforms the network configuration: the "ports" list
// (strings "host:container" or "proto:host:container", bare ints, or maps
// with host_port/container_port/protocol) is normalized into "port_map",
// a list of {protocol, host_port, container_port} maps. Protocol defaults
// to "tcp"; unparsable numbers silently become 0.
func TransformNetwork(_ *map[string]interface{}, data any, _ tree.Path) (any, error) {
if data == nil {
return nil, nil
}
config, ok := data.(map[string]any)
if !ok {
return nil, fmt.Errorf("invalid network configuration: %v", data)
}
ports, _ := transform.ToAnySlice(config["ports"])
portMap := []map[string]any{}
for _, port := range ports {
protocol, host, container := "tcp", 0, 0
switch v := port.(type) {
case string:
parts := strings.Split(v, ":")
// "host:container" or a single "port" (host == container).
// NOTE(review): strings with more than 3 parts fall through with
// host/container left 0 — confirm whether that should be an error.
if len(parts) <= 2 {
host, _ = strconv.Atoi(parts[0])
container, _ = strconv.Atoi(parts[len(parts)-1])
} else if len(parts) == 3 {
protocol = parts[0]
host, _ = strconv.Atoi(parts[1])
container, _ = strconv.Atoi(parts[len(parts)-1])
}
case int:
host = v
container = v
case map[string]any:
switch h := v["host_port"].(type) {
case int:
host = h
case string:
host, _ = strconv.Atoi(h)
}
switch c := v["container_port"].(type) {
case int:
container = c
case string:
container, _ = strconv.Atoi(c)
}
if p, ok := v["protocol"].(string); ok {
protocol = p
}
}
portMap = append(portMap, map[string]any{
"protocol": protocol,
"host_port": host,
"container_port": container,
})
}
config["port_map"] = portMap
delete(config, "ports")
return config, nil
}
// TransformLibrary transforms the library configuration to a map.
// The library configuration can be a string in the format "name:version"
// (version optional, defaulting to "") or an already-shaped map.
func TransformLibrary(_ *map[string]interface{}, data any, _ tree.Path) (any, error) {
	if data == nil {
		return nil, nil
	}
	switch lib := data.(type) {
	case map[string]any:
		return lib, nil
	case string:
		// More than one ':' is malformed (e.g. "a:b:c").
		if strings.Count(lib, ":") > 1 {
			return nil, fmt.Errorf("invalid library configuration: %v", data)
		}
		name, version, _ := strings.Cut(lib, ":")
		return map[string]any{
			"name":    name,
			"version": version,
		}, nil
	default:
		return nil, fmt.Errorf("invalid library configuration: %v", data)
	}
}
package nunet
import (
"fmt"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/tree"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/validate"
)
// NewNuNetValidator creates a new validator for the NuNet configuration.
// The root spec is checked for a non-empty jobs list, and every job
// (including nested children) must carry either children or an execution.
func NewNuNetValidator() validate.Validator {
return validate.NewValidator(
map[tree.Path]validate.ValidatorFunc{
"": ValidateSpec,
"jobs.[]": ValidateJob,
"jobs.**.children.[]": ValidateJob,
},
)
}
// ValidateSpec checks the root configuration for consistency: the spec
// must be a map containing a non-empty "jobs" list. The previous version
// used an unchecked type assertion and panicked when "jobs" existed but
// was not a slice; a validator should report that as an error instead.
func ValidateSpec(_ *map[string]any, data any, _ tree.Path) error {
	spec, ok := data.(map[string]any)
	if !ok {
		return fmt.Errorf("invalid spec configuration: %v", data)
	}
	// A missing, wrongly-typed, or empty jobs list is all the same failure.
	jobs, ok := spec["jobs"].([]any)
	if !ok || len(jobs) == 0 {
		return fmt.Errorf("jobs list is required")
	}
	return nil
}
// ValidateJob checks the job configuration: a job must have either a
// non-empty "children" list or an "execution" section. The previous
// version used an unchecked type assertion and panicked when "children"
// existed but was not a slice; that now counts as "no children".
func ValidateJob(_ *map[string]any, data any, _ tree.Path) error {
	job, ok := data.(map[string]any)
	if !ok {
		return fmt.Errorf("invalid job configuration: %v", data)
	}
	children, ok := job["children"].([]any)
	if !ok || len(children) == 0 {
		if job["execution"] == nil {
			return fmt.Errorf("job must have either children or an execution")
		}
	}
	return nil
}
package parser
import (
"fmt"
"gitlab.com/nunet/device-management-service/dms/jobs"
)
// Parse resolves the parser registered for specType and parses data with
// it. An unregistered spec type is an error.
func Parse(specType SpecType, data []byte) (jobs.JobSpec, error) {
	p, exists := registry.GetParser(specType)
	if !exists {
		return jobs.JobSpec{}, fmt.Errorf("parser for spec type %s not found", specType)
	}
	return p.Parse(data)
}
package parser
import (
"encoding/json"
"fmt"
"github.com/mitchellh/mapstructure"
yaml "gopkg.in/yaml.v3"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/transform"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/validate"
)
// SpecType identifies a job specification dialect.
type SpecType string
// Known spec dialects. Only nunet is registered at present (see init).
const (
specTypeNuNet SpecType = "nunet"
specTypeNomad SpecType = "nomad"
specTypeK8s SpecType = "k8s"
)
// DefaultTagName is the struct tag consulted when decoding into the
// target type.
const DefaultTagName = "json"
// Parser parses raw bytes into a typed job specification.
type Parser[T any] interface {
Parse(data []byte) (T, error)
}
// Impl combines a transformer (shape normalization) and a validator
// (consistency checks) into a Parser.
type Impl[T any] struct {
validator validate.Validator
transformer transform.Transformer
}
// NewParser builds a Parser from a transformer and validator pair.
func NewParser[T any](transformer transform.Transformer, validator validate.Validator) Parser[T] {
return Impl[T]{
transformer: transformer,
validator: validator,
}
}
// Parse decodes data (YAML first, JSON as fallback), applies the
// transformer pipeline, validates the result, and decodes it into T via
// mapstructure using DefaultTagName tags. Errors are now wrapped with %w
// so callers can unwrap the cause (they were previously flattened by %v),
// and the final return is an explicit nil error.
func (p Impl[T]) Parse(data []byte) (T, error) {
	var rawConfig map[string]any
	var config T
	// Try to unmarshal as YAML first
	err := yaml.Unmarshal(data, &rawConfig)
	if err != nil {
		// If YAML fails, try JSON
		err = json.Unmarshal(data, &rawConfig)
		if err != nil {
			return config, fmt.Errorf("failed to parse config: %w", err)
		}
	}
	// Apply transformers. Note: the transformer mutates rawConfig in
	// place as well as returning the transformed tree.
	transformed, err := p.transformer.Transform(&rawConfig)
	if err != nil {
		return config, fmt.Errorf("failed to transform config: %w", err)
	}
	// Validate the transformed configuration
	if err = p.validator.Validate(&rawConfig); err != nil {
		return config, err
	}
	// Create a new mapstructure decoder
	decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
		Result:  &config,
		TagName: DefaultTagName,
	})
	if err != nil {
		return config, fmt.Errorf("failed to create decoder: %w", err)
	}
	// Decode the transformed configuration
	err = decoder.Decode(transformed)
	if err != nil {
		return config, fmt.Errorf("failed to decode config: %w", err)
	}
	return config, nil
}
package parser
import (
"sync"
)
// Registry maps spec types to their parsers.
type Registry[T any] interface {
GetParser(specType SpecType) (Parser[T], bool)
RegisterParser(specType SpecType, p Parser[T])
}
// RegistryImpl is a concurrency-safe Registry backed by a map.
type RegistryImpl[T any] struct {
parsers map[SpecType]Parser[T]
mu sync.RWMutex
}
// RegisterParser registers (or replaces) the parser for a spec type.
func (r *RegistryImpl[T]) RegisterParser(specType SpecType, p Parser[T]) {
r.mu.Lock()
defer r.mu.Unlock()
r.parsers[specType] = p
}
// GetParser returns the parser registered for a spec type, if any.
func (r *RegistryImpl[T]) GetParser(specType SpecType) (Parser[T], bool) {
r.mu.RLock()
defer r.mu.RUnlock()
p, exists := r.parsers[specType]
return p, exists
}
package transform
import (
"fmt"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/tree"
)
// TransformerFunc is a function that transforms a part of the configuration.
// It receives the root configuration, the data to transform and the current
// path in the tree, and returns the transformed data or an error.
type TransformerFunc func(*map[string]interface{}, interface{}, tree.Path) (any, error)
// Transformer is a configuration transformer: it rewrites a raw configuration
// map and returns the transformed (and normalized) result.
type Transformer interface {
	Transform(*map[string]interface{}) (interface{}, error)
}
// TransformerImpl is the implementation of the Transformer interface.
// It applies a sequence of path-pattern -> TransformerFunc maps in order.
type TransformerImpl struct {
	transformers []map[tree.Path]TransformerFunc // passes applied in slice order
}
// NewTransformer creates a new Transformer with the given transformer maps.
// Each map in the slice is applied as a separate pass over the tree.
func NewTransformer(transformers []map[tree.Path]TransformerFunc) Transformer {
	return TransformerImpl{
		transformers: transformers,
	}
}
// Transform applies every transformer pass to the configuration in order
// and returns the normalized result. It stops at the first error.
func (t TransformerImpl) Transform(rawConfig *map[string]interface{}) (interface{}, error) {
	current := any(*rawConfig)
	for _, pass := range t.transformers {
		next, err := t.transform(rawConfig, current, tree.NewPath(), pass)
		if err != nil {
			return nil, err
		}
		current = next
	}
	return Normalize(current), nil
}
// transform recursively applies the transformers to the configuration.
// For each node it first applies every transformer whose path pattern matches
// the current path, then descends into map values and slice elements,
// extending the path with the map key or "[i]" index respectively.
// Nodes that are neither maps nor slices are returned unchanged.
func (t TransformerImpl) transform(root *map[string]interface{}, data any, path tree.Path, transformers map[tree.Path]TransformerFunc) (interface{}, error) {
	var err error
	// Apply transformers that match the current path.
	// NOTE: map iteration order is random, so when multiple transformers
	// match the same path their relative order is not deterministic.
	for pattern, transformer := range transformers {
		if path.Matches(pattern) {
			data, err = transformer(root, data, path)
			if err != nil {
				return nil, err
			}
		}
	}
	// Recursively apply transformers to children.
	if result, ok := data.(map[string]interface{}); ok {
		for key, value := range result {
			next := path.Next(key)
			result[key], err = t.transform(root, value, next, transformers)
			if err != nil {
				return nil, err
			}
		}
		return result, nil
	} else if result, err := ToAnySlice(data); err == nil {
		// ToAnySlice succeeds for any slice type; non-slices fall through
		// to the final return. The err here shadows the outer err on purpose.
		for i, value := range result {
			next := path.Next(fmt.Sprintf("[%d]", i))
			result[i], err = t.transform(root, value, next, transformers)
			if err != nil {
				return nil, err
			}
		}
		return result, nil
	}
	return data, nil
}
package transform
import (
"fmt"
"reflect"
"sort"
"strconv"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/tree"
)
// MapToSlice flattens a map of entries into a slice. When an entry's value
// is a map, the original key is stored under its "name" field; nil values
// become empty maps first so they can carry the name. A nil input yields a
// nil slice. The output order follows Go's (random) map iteration order.
func MapToSlice(data map[string]any) ([]any, error) {
	if data == nil {
		return nil, nil
	}
	out := make([]any, 0, len(data))
	for name, entry := range data {
		if entry == nil {
			entry = map[string]any{}
		}
		if m, ok := entry.(map[string]any); ok {
			m["name"] = name
		}
		out = append(out, entry)
	}
	return out, nil
}
// GetConfigAtPath retrieves the part of the configuration at the given path.
// Map nodes are addressed by key; slice nodes by an "[i]" path segment.
// It returns an error for malformed indices, out-of-range indices, or when
// a non-container node is addressed further (previously malformed or
// out-of-range indices would panic).
func GetConfigAtPath(config map[string]interface{}, path tree.Path) (any, error) {
	current := any(config)
	for _, key := range path.Parts() {
		switch v := current.(type) {
		case map[string]any:
			current = v[key]
		case []any, []map[string]any:
			// Slice segments look like "[3]": require the brackets before
			// slicing the key, otherwise key[1:len(key)-1] could panic.
			if len(key) < 3 || key[0] != '[' || key[len(key)-1] != ']' {
				return nil, fmt.Errorf("invalid index: %v", key)
			}
			i, err := strconv.Atoi(key[1 : len(key)-1])
			if err != nil {
				return nil, fmt.Errorf("invalid index: %v", key)
			}
			switch v := v.(type) {
			case []any:
				if i < 0 || i >= len(v) {
					return nil, fmt.Errorf("index out of range: %d", i)
				}
				current = v[i]
			case []map[string]any:
				if i < 0 || i >= len(v) {
					return nil, fmt.Errorf("index out of range: %d", i)
				}
				current = v[i]
			}
		default:
			return nil, fmt.Errorf("invalid data type: %v", current)
		}
	}
	return current, nil
}
// ToAnySlice converts a slice of any element type to a []any via reflection.
// It returns an error when the input is not a slice.
func ToAnySlice(slice any) ([]any, error) {
	v := reflect.ValueOf(slice)
	if v.Kind() != reflect.Slice {
		return nil, fmt.Errorf("input is not a slice. type: %T", slice)
	}
	out := make([]any, v.Len())
	for i := range out {
		out[i] = v.Index(i).Interface()
	}
	return out, nil
}
// normalizeMap recursively converts maps to have interface{} values and
// slices to []interface{}, sorting slices by their formatted representation
// so that equivalent structures normalize to the same shape.
func normalizeMap(m interface{}) interface{} {
	v := reflect.ValueOf(m)
	switch v.Kind() {
	case reflect.Map:
		// Rebuild the map with an interface{} value type, normalizing values.
		newMap := reflect.MakeMap(reflect.MapOf(v.Type().Key(), reflect.TypeOf((*interface{})(nil)).Elem()))
		for _, key := range v.MapKeys() {
			newValue := normalizeMap(v.MapIndex(key).Interface())
			if newValue == nil {
				// reflect.ValueOf(nil) is the zero Value, and SetMapIndex
				// with a zero Value DELETES the key. Store a typed zero
				// interface value instead so nil entries are preserved.
				newMap.SetMapIndex(key, reflect.Zero(newMap.Type().Elem()))
				continue
			}
			newMap.SetMapIndex(key, reflect.ValueOf(newValue))
		}
		return newMap.Interface()
	case reflect.Slice:
		// Normalize each element, then sort for a deterministic order.
		newSlice := make([]interface{}, v.Len())
		for i := 0; i < v.Len(); i++ {
			newSlice[i] = normalizeMap(v.Index(i).Interface())
		}
		sort.Slice(newSlice, func(i, j int) bool {
			return fmt.Sprint(newSlice[i]) < fmt.Sprint(newSlice[j])
		})
		return newSlice
	default:
		// Scalars and other kinds are returned unchanged.
		return m
	}
}
// Normalize is the exported entry point for normalization: it canonicalizes
// nested maps and slices into map[string]interface{} / []interface{} form.
func Normalize(m any) interface{} {
	normalized := normalizeMap(m)
	return normalized
}
package tree
import (
"strings"
)
// Tokens used in configuration tree path patterns.
const (
	configPathSeparator        = "."  // separates path segments
	configPathMatchAny         = "*"  // matches exactly one segment
	configPathMatchAnyMultiple = "**" // matches any number of trailing segments
	configPathList             = "[]" // matches one "[i]" list-index segment
)
// Path is a custom type for representing dot-separated paths in the
// configuration tree, e.g. "services.[0].name".
type Path string
// NewPath joins the given segments with the path separator into a Path.
func NewPath(path ...string) Path {
	joined := strings.Join(path, configPathSeparator)
	return Path(joined)
}
// Parts splits the path into its separator-delimited segments.
func (p Path) Parts() []string {
	raw := string(p)
	return strings.Split(raw, configPathSeparator)
}
// Parent returns the path with its last segment removed, or the empty
// path when there is at most one segment.
func (p Path) Parent() Path {
	segments := p.Parts()
	if len(segments) <= 1 {
		return ""
	}
	return NewPath(segments[:len(segments)-1]...)
}
// Next extends the path with an additional segment. An empty segment
// returns the path unchanged; extending an empty path returns just the
// new segment (no leading separator).
func (p Path) Next(path string) Path {
	switch {
	case path == "":
		return p
	case p == "":
		return Path(path)
	default:
		return Path(string(p) + configPathSeparator + path)
	}
}
// Last returns the final segment of the path, or "" when there are none.
func (p Path) Last() string {
	segments := p.Parts()
	if len(segments) == 0 {
		return ""
	}
	return segments[len(segments)-1]
}
// Matches reports whether the path matches the given pattern, which may
// contain "*", "**" and "[]" wildcard segments.
func (p Path) Matches(pattern Path) bool {
	return matchParts(p.Parts(), pattern.Parts())
}
// String returns the string representation of the path.
func (p Path) String() string {
	raw := string(p)
	return raw
}
// matchParts reports whether pathParts match patternParts.
// Supported pattern segments:
//   - "**" matches any number of remaining segments (tried recursively);
//   - "*"  matches exactly one segment;
//   - "[]" matches a single list-index segment such as "[0]";
//   - anything else must match the path segment literally.
func matchParts(pathParts, patternParts []string) bool {
	// A pattern longer than the path can never match.
	if len(pathParts) < len(patternParts) {
		return false
	}
	for i, part := range patternParts {
		switch part {
		case configPathMatchAnyMultiple:
			// A trailing "**" matches everything that remains.
			if i == len(patternParts)-1 {
				return true
			}
			// Otherwise try to match the rest of the pattern against every
			// possible suffix of the path.
			for j := i; j < len(pathParts); j++ {
				if matchParts(pathParts[j:], patternParts[i+1:]) {
					return true
				}
			}
		case configPathList:
			// The path segment must be a bracketed index like "[3]".
			// Guard the length first: indexing an empty segment would panic.
			if len(pathParts[i]) < 2 || pathParts[i][0] != '[' || pathParts[i][len(pathParts[i])-1] != ']' {
				return false
			}
		default:
			// A literal segment must match exactly unless the pattern is "*".
			if part != configPathMatchAny && part != pathParts[i] {
				return false
			}
		}
		// The pattern is exhausted but the path continues: no match.
		if i == len(patternParts)-1 && i < len(pathParts)-1 {
			return false
		}
	}
	return true
}
package validate
import (
"fmt"
"gitlab.com/nunet/device-management-service/dms/jobs/parser/tree"
)
// ValidatorFunc is a function that validates a part of the configuration.
// It takes the root configuration, the data to validate and the current
// path in the tree, and returns an error describing the first violation.
type ValidatorFunc func(*map[string]any, any, tree.Path) error
// Validator is a configuration validator.
// Implementations walk the configuration tree and apply registered
// validation functions to nodes at matching paths.
type Validator interface {
	Validate(*map[string]any) error
}
// ValidatorImpl is the implementation of the Validator interface.
// It maps path patterns to the functions that validate matching nodes.
type ValidatorImpl struct {
	validators map[tree.Path]ValidatorFunc // path pattern -> validation function
}
// NewValidator creates a new Validator with the given validator map.
func NewValidator(validators map[tree.Path]ValidatorFunc) Validator {
	return ValidatorImpl{
		validators: validators,
	}
}
// Validate walks the configuration tree from the root and applies every
// registered validator to the nodes whose paths match.
func (v ValidatorImpl) Validate(rawConfig *map[string]any) error {
	root := any(*rawConfig)
	return v.validate(rawConfig, root, tree.NewPath(), v.validators)
}
// validate recursively applies the validators to the configuration.
// At each node it first runs every validator whose pattern matches the
// current path, then descends into map values and slice elements,
// extending the path with the map key or "[i]" index respectively.
func (v ValidatorImpl) validate(root *map[string]interface{}, data any, path tree.Path, validators map[tree.Path]ValidatorFunc) error {
	// Run validators registered for this path.
	for pattern, check := range validators {
		if !path.Matches(pattern) {
			continue
		}
		if err := check(root, data, path); err != nil {
			return err
		}
	}
	// Descend into child nodes.
	switch node := data.(type) {
	case map[string]interface{}:
		for key, child := range node {
			if err := v.validate(root, child, path.Next(key), validators); err != nil {
				return err
			}
		}
	case []interface{}:
		for i, child := range node {
			if err := v.validate(root, child, path.Next(fmt.Sprintf("[%d]", i)), validators); err != nil {
				return err
			}
		}
	}
	return nil
}
package node
import (
"context"
"errors"
"fmt"
"reflect"
"sync"
"time"
"gitlab.com/nunet/device-management-service/dms"
"gitlab.com/nunet/device-management-service/dms/jobs"
"gitlab.com/nunet/device-management-service/dms/resources"
"gitlab.com/nunet/device-management-service/network"
"gitlab.com/nunet/device-management-service/types"
)
// TODO: remove after resource manager MR is merged
// benchmarker abstracts capability benchmarking so it can be swapped or mocked.
type benchmarker interface {
	// Benchmark measures and returns the node's capability.
	Benchmark(ctx context.Context) (*types.Capability, error)
}
// Node is the structure that holds the node's dependencies.
type Node struct {
	ID           string            // unique identifier of this node
	actor        *dms.BasicActor   // root actor attached to this node
	network      network.Network   // underlying network implementation
	actorFactory *dms.ActorFactory // factory used to spawn actors for this node
	// TODO: fix when resource manager is merged to develop
	resourceManager resources.Manager
	benchmark       benchmarker
	allocations     map[string]*jobs.Allocation // active allocations keyed by allocation ID
	mu              sync.RWMutex                // guards allocations
}
// New creates a new node, attaches an actor to the node.
func New(_ context.Context, id string, net network.Network, benchmark benchmarker, resourceManager resources.Manager) (*Node, error) {
if id == "" {
return nil, errors.New("id is nil")
}
if net == nil {
return nil, errors.New("network is nil")
}
if benchmark == nil {
return nil, errors.New("benchmarker is nil")
}
if resourceManager == nil {
return nil, errors.New("resource manager is nil")
}
actorFactory := dms.NewActorFactory(id, net, &dms.ActorParams{
HeartbeatInterval: time.Second * 5,
HeartbeatCheckInterval: time.Second * 8,
})
n := &Node{
ID: id,
actorFactory: actorFactory,
network: net,
allocations: make(map[string]*jobs.Allocation),
benchmark: benchmark,
resourceManager: resourceManager,
}
err := n.createNodeActor()
if err != nil {
return nil, fmt.Errorf("failed to create node actor: %w", err)
}
return n, nil
}
// GetAllocation returns the allocation registered under the given id,
// or an error when no such allocation exists.
func (n *Node) GetAllocation(id string) (*jobs.Allocation, error) {
	n.mu.RLock()
	defer n.mu.RUnlock()
	if alloc, ok := n.allocations[id]; ok {
		return alloc, nil
	}
	return nil, errors.New("allocation not found")
}
// BenchmarkCapability benchmarks the node and returns its measured capability.
func (n *Node) BenchmarkCapability(ctx context.Context) (*types.Capability, error) {
	return n.benchmark.Benchmark(ctx)
}
// CreateAllocation creates an allocation for the given job, backed by a
// freshly created child actor, and registers it on the node.
func (n *Node) CreateAllocation(_ context.Context, job jobs.Job) (*jobs.Allocation, error) {
	allocationActor, err := n.actor.CreateActor()
	if err != nil {
		return nil, fmt.Errorf("failed to create allocation actor: %w", err)
	}
	allocation, err := jobs.NewAllocation(allocationActor, jobs.AllocationDetails{Job: job, NodeID: n.ID, SourceID: ""}, n.resourceManager)
	if err != nil {
		// Distinct message from the actor-creation failure above (the
		// original duplicated "failed to create allocation actor").
		return nil, fmt.Errorf("failed to create allocation: %w", err)
	}
	n.mu.Lock()
	n.allocations[allocation.ID] = allocation
	n.mu.Unlock()
	return allocation, nil
}
// ProcessMessages consumes the node actor's message queue and dispatches
// each message to a matching Handle<Type> method via dispatchMethod.
// It blocks until the actor's message channel is closed.
func (n *Node) ProcessMessages() {
	for msg := range n.actor.Messages() {
		n.dispatchMethod(msg.Type, msg.Data)
	}
}
// SendMessage sends a message to the destination actor through this
// node's root actor.
func (n *Node) SendMessage(ctx context.Context, destination *dms.ActorAddrInfo, m *dms.Message) error {
	return n.actor.SendMessage(ctx, destination, m)
}
// HandleHello is invoked when a message of type `Hello` arrives on the
// message queue. Handler methods must be named Handle<MessageType> for the
// reflection-based dispatcher to find them.
func (n *Node) HandleHello(payload []byte) {
	sender := string(payload)
	fmt.Println("hello from: ", sender)
}
// dispatchMethod routes a message to a handler method using reflection.
// It looks for a method named Handle<methodName> first on the node itself
// and, failing that, on the node's actor; messages with no matching
// handler are dropped silently.
func (n *Node) dispatchMethod(methodName string, args ...any) {
	handlerMethod := fmt.Sprintf("Handle%s", methodName)
	// Convert the raw arguments to reflect values for Call.
	arguments := make([]reflect.Value, 0)
	for _, v := range args {
		arguments = append(arguments, reflect.ValueOf(v))
	}
	method := reflect.ValueOf(n).MethodByName(handlerMethod)
	if method.IsValid() {
		method.Call(arguments)
		return
	}
	// check if actor has the method
	actorMethod := reflect.ValueOf(n.actor).MethodByName(handlerMethod)
	if actorMethod.IsValid() {
		actorMethod.Call(arguments)
	}
}
// createNodeActor creates and starts the root actor for this node instance.
func (n *Node) createNodeActor() error {
	rootActor, err := n.actorFactory.NewActor()
	if err != nil {
		return fmt.Errorf("failed to create actor: %w", err)
	}
	n.actor = rootActor
	if err = n.actor.Start(); err != nil {
		return fmt.Errorf("failed to start node actor: %w", err)
	}
	return nil
}
package matching
import (
"gitlab.com/nunet/device-management-service/types"
)
// CPUComparator compares two types.CPU values.
// Left represents machine capabilities, right the required capabilities.
// Architectures must be equal; otherwise the machine side is Worse. With
// equal architectures the result is the numeric comparison of
// cores * clock speed.
func CPUComparator(l, r interface{}, _ ...Preference) types.Comparison {
	lCPU, lok := l.(types.CPU)
	rCPU, rok := r.(types.CPU)
	if !lok || !rok {
		return types.Error
	}
	switch LiteralComparator(lCPU.Architecture, rCPU.Architecture) {
	case types.Error:
		return types.Error
	case types.Equal:
		// Compare total raw compute: cores multiplied by clock speed.
		return NumericComparator(
			int64(lCPU.Cores)*lCPU.ClockSpeedHz,
			int64(rCPU.Cores)*rCPU.ClockSpeedHz,
		)
	default:
		// Different architectures are treated as Worse.
		return types.Worse
	}
	// NOTE: deliberately simple; ranking specific CPU models would require
	// encoding hardware-specific knowledge (e.g. external benchmark data).
}
package matching
import (
"gitlab.com/nunet/device-management-service/types"
)
// DiskComparator compares two types.Disk values field by field.
// Disk types must match exactly; given equal types, the result is the
// comparison of the Size fields (more or equal size is acceptable).
func DiskComparator(l, r interface{}, _ ...Preference) types.Comparison {
	if _, ok := l.(types.Disk); !ok {
		return types.Error
	}
	if _, ok := r.(types.Disk); !ok {
		return types.Error
	}
	fields := ReturnComplexComparison(l, r)
	switch fields["Type"] {
	case types.Error:
		return types.Error
	case types.Equal:
		return fields["Size"]
	default:
		// Mismatched disk types count as Worse.
		return types.Worse
	}
}
package matching
import (
"gitlab.com/nunet/device-management-service/types"
)
// ExecutionResourcesComparator compares two types.ExecutionResources values.
// All four fields (CPU, Memory, Disk, GPUs) must be Equal or Better for the
// overall result to be Equal or Better; any Error field makes the result
// Error, and any Worse field makes it Worse.
func ExecutionResourcesComparator(l, r interface{}, _ ...Preference) types.Comparison {
	if _, ok := l.(types.ExecutionResources); !ok {
		return types.Error
	}
	if _, ok := r.(types.ExecutionResources); !ok {
		return types.Error
	}
	fields := ReturnComplexComparison(l, r)
	keys := []string{"CPU", "Memory", "Disk", "GPUs"}
	// Error dominates everything.
	for _, k := range keys {
		if fields[k] == types.Error {
			return types.Error
		}
	}
	// Any Worse field makes the whole comparison Worse.
	for _, k := range keys {
		if fields[k] == types.Worse {
			return types.Worse
		}
	}
	// All Equal means Equal; otherwise at least one field is Better.
	for _, k := range keys {
		if fields[k] != types.Equal {
			return types.Better
		}
	}
	return types.Equal
}
package matching
import (
"gitlab.com/nunet/device-management-service/types"
)
// ExecutorComparator compares two types.Executor values.
// Left represents machine capabilities, right the required capabilities.
// The type has a single field, so the comparison is delegated to the
// generic Compare of the ExecutorType fields.
func ExecutorComparator(lraw, rraw interface{}, _ ...Preference) types.Comparison {
	lExec, lok := lraw.(types.Executor)
	rExec, rok := rraw.(types.Executor)
	if !lok || !rok {
		return types.Error
	}
	return Compare(lExec.ExecutorType, rExec.ExecutorType)
}
package matching
import (
"reflect"
"gitlab.com/nunet/device-management-service/types"
)
// ExecutorTypeComparator reports Equal when both values are the same
// types.ExecutorType and Error otherwise (there is no ordering).
func ExecutorTypeComparator(l, r interface{}, _ ...Preference) types.Comparison {
	if _, ok := l.(types.ExecutorType); !ok {
		return types.Error
	}
	if _, ok := r.(types.ExecutorType); !ok {
		return types.Error
	}
	if reflect.DeepEqual(l, r) {
		return types.Equal
	}
	return types.Error
}
package matching
import (
"reflect"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils"
)
// ExecutorsComparator compares two types.Executors collections.
// Left represents machine capabilities, right the required capabilities:
// Equal when the collections match, Better when the machine strictly
// contains the requirements, Worse when the requirements strictly contain
// the machine, and Error otherwise.
func ExecutorsComparator(lraw, rraw interface{}, _ ...Preference) types.Comparison {
	ll, lok := lraw.(types.Executors)
	rr, rok := rraw.(types.Executors)
	if !lok || !rok {
		return types.Error
	}
	// Flatten both collections into []interface{} for the set helpers.
	l := make([]interface{}, 0)
	for _, v := range ll {
		l = append(l, v)
	}
	r := make([]interface{}, 0)
	for _, v := range rr {
		r = append(r, v)
	}
	result := types.Error // default when no case below applies
	if !utils.IsSameShallowType(l, r) {
		result = types.Error
	}
	switch {
	case reflect.DeepEqual(l, r):
		result = types.Equal
	case utils.IsStrictlyContained(l, r):
		result = types.Better
	case utils.IsStrictlyContained(r, l):
		result = types.Worse
	}
	return result
}
package matching
import (
"reflect"
"gitlab.com/nunet/device-management-service/types"
)
// GPUVendorComparator reports Equal when both values are the same
// types.GPUVendor and Error otherwise.
// NOTE: externally defined vendor preferences are not supported yet; when
// they are, preference handling belongs at the resource-matching level
// (some workloads strictly require a specific vendor).
func GPUVendorComparator(l, r interface{}, _ ...Preference) types.Comparison {
	if _, ok := l.(types.GPUVendor); !ok {
		return types.Error
	}
	if _, ok := r.(types.GPUVendor); !ok {
		return types.Error
	}
	if reflect.DeepEqual(l, r) {
		return types.Equal
	}
	return types.Error
}
package matching
import (
"gitlab.com/nunet/device-management-service/types"
"golang.org/x/exp/slices"
)
// GPUsComparator compares two slices of types.GPU.
// Left represents machine capabilities, right the required capabilities.
// Since the slices are unordered, each required GPU is matched against the
// best-comparing machine GPU; the per-GPU results are then folded:
// any Error wins, then any Worse, then all-Equal, otherwise Better.
func GPUsComparator(lraw, rraw interface{}, _ ...Preference) types.Comparison {
	// validate input type
	_, lrawok := lraw.([]types.GPU)
	_, rrawok := rraw.([]types.GPU)
	if !lrawok || !rrawok {
		return types.Error
	}
	l := lraw.([]types.GPU)
	r := rraw.([]types.GPU)
	// interimComparison1[i][j] holds Compare(l[j], r[i]): one row per
	// required (right) GPU, one column per machine (left) GPU.
	interimComparison1 := make([][]types.Comparison, 0)
	for _, rGPU := range r {
		var interimComparison2 []types.Comparison
		for _, lGPU := range l {
			interimComparison2 = append(interimComparison2, Compare(lGPU, rGPU))
		}
		interimComparison1 = append(interimComparison1, interimComparison2)
	}
	// Pick the best available match for each required GPU, removing the
	// consumed row as we go.
	// NOTE(review): the `len(...) < i` guard and removing ROWS (not the
	// matched left-GPU columns) while iterating by index look suspicious —
	// verify that a single machine GPU cannot satisfy multiple requirements.
	var finalComparison []types.Comparison
	for i := 0; i < len(interimComparison1); i++ {
		if len(interimComparison1[i]) < i {
			break
		}
		c := interimComparison1[i]
		bestMatch, index := returnBestMatch(c)
		finalComparison = append(finalComparison, bestMatch)
		interimComparison1 = removeIndex(interimComparison1, index)
	}
	// Fold the per-GPU results: Error > Worse > Equal > Better.
	if slices.Contains(finalComparison, types.Error) {
		return types.Error
	}
	if slices.Contains(finalComparison, types.Worse) {
		return types.Worse
	}
	if SliceContainsOneValue(finalComparison, types.Equal) {
		return types.Equal
	}
	return types.Better
}
package matching
import (
"gitlab.com/nunet/device-management-service/types"
)
// GpuComparator compares two types.GPU values by their TotalVRAM field.
// The result mirrors the TotalVRAM comparison directly; anything else
// falls back to Error.
func GpuComparator(l, r interface{}, _ ...Preference) types.Comparison {
	if _, ok := l.(types.GPU); !ok {
		return types.Error
	}
	if _, ok := r.(types.GPU); !ok {
		return types.Error
	}
	fields := ReturnComplexComparison(l, r)
	switch fields["TotalVRAM"] {
	case types.Error:
		return types.Error
	case types.Worse:
		return types.Worse
	case types.Better:
		return types.Better
	case types.Equal:
		return types.Equal
	}
	// NOTE: a VRAM-only comparison; richer GPU ranking would need
	// hardware-specific knowledge (e.g. external benchmark data).
	return types.Error
}
package matching
import (
"reflect"
"gitlab.com/nunet/device-management-service/types"
)
// JobTypeComparator reports Equal when both values are the same
// types.JobType and Error otherwise (there is no ordering).
func JobTypeComparator(l, r interface{}, _ ...Preference) types.Comparison {
	if _, ok := l.(types.JobType); !ok {
		return types.Error
	}
	if _, ok := r.(types.JobType); !ok {
		return types.Error
	}
	if reflect.DeepEqual(l, r) {
		return types.Equal
	}
	return types.Error
}
package matching
import (
"reflect"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils"
)
// JobTypesComparator compares two types.JobTypes collections.
// Left represents machine capabilities, right the required capabilities:
// Equal for identical sets, Better when the machine strictly contains the
// requirements, Worse for the reverse, and Error otherwise.
// TODO: this comparator does not take into account options when several job
// types are available and several are required in the same data structure;
// this is why the test fails.
func JobTypesComparator(lraw, rraw interface{}, _ ...Preference) types.Comparison {
	if _, ok := lraw.(types.JobTypes); !ok {
		return types.Error
	}
	if _, ok := rraw.(types.JobTypes); !ok {
		return types.Error
	}
	// The interfaces wrap slices; convert them for the generic set helpers.
	l := utils.ConvertTypedSliceToUntypedSlice(lraw)
	r := utils.ConvertTypedSliceToUntypedSlice(rraw)
	result := types.Error // default when no case below applies
	if !utils.IsSameShallowType(l, r) {
		result = types.Error
	}
	switch {
	case reflect.DeepEqual(l, r):
		result = types.Equal
	case utils.IsStrictlyContained(l, r):
		result = types.Better
	case utils.IsStrictlyContained(r, l):
		result = types.Worse
	}
	return result
}
package matching
import (
"gitlab.com/nunet/device-management-service/types"
"golang.org/x/exp/slices"
)
// LibrariesComparator compares two slices of types.Library.
// Left represents machine capabilities, right the required capabilities.
// Since the slices are unordered, each required library is matched against
// the best-comparing machine library; the per-library results are then
// folded: any Error wins, then any Worse, then all-Equal, otherwise Better.
func LibrariesComparator(lraw, rraw interface{}, _ ...Preference) types.Comparison {
	// validate input type
	_, lrawok := lraw.([]types.Library)
	_, rrawok := rraw.([]types.Library)
	if !lrawok || !rrawok {
		return types.Error
	}
	l := lraw.([]types.Library)
	r := rraw.([]types.Library)
	// interimComparison1[i][j] holds Compare(l[j], r[i]): one row per
	// required (right) library, one column per machine (left) library.
	interimComparison1 := make([][]types.Comparison, 0)
	for _, rLibrary := range r {
		var interimComparison2 []types.Comparison
		for _, lLibrary := range l {
			interimComparison2 = append(interimComparison2, Compare(lLibrary, rLibrary))
		}
		interimComparison1 = append(interimComparison1, interimComparison2)
	}
	// Pick the best available match for each required library, removing the
	// consumed row as we go.
	// NOTE(review): the `len(...) < i` guard and row removal while iterating
	// by index mirror GPUsComparator and look equally suspicious — verify.
	finalComparison := make([]types.Comparison, 0)
	for i := 0; i < len(interimComparison1); i++ {
		if len(interimComparison1[i]) < i {
			break
		}
		c := interimComparison1[i]
		bestMatch, index := returnBestMatch(c)
		finalComparison = append(finalComparison, bestMatch)
		interimComparison1 = removeIndex(interimComparison1, index)
	}
	// Fold the per-library results: Error > Worse > Equal > Better.
	if slices.Contains(finalComparison, types.Error) {
		return types.Error
	}
	if slices.Contains(finalComparison, types.Worse) {
		return types.Worse
	}
	if SliceContainsOneValue(finalComparison, types.Equal) {
		return types.Equal
	}
	return types.Better
}
package matching
import (
"github.com/hashicorp/go-version"
"gitlab.com/nunet/device-management-service/types"
)
// LibraryComparator compares a machine library (left) against a required
// library (right). Names must match; the machine's version is then checked
// against the requirement's "<constraint> <version>" expression:
// Equal for an exact "=" match, Better when the constraint is satisfied,
// Worse when it is not, and Error for name mismatches or unparsable input.
func LibraryComparator(lraw, rraw interface{}, _ ...Preference) types.Comparison {
	lLib, lok := lraw.(types.Library)
	rLib, rok := rraw.(types.Library)
	if !lok || !rok {
		return types.Error
	}
	// The machine-side version must parse.
	lVersion, err := version.NewVersion(lLib.Version)
	if err != nil {
		return types.Error
	}
	// The requirement is expressed as a constraint, e.g. ">= 1.2.3".
	constraints, err := version.NewConstraint(rLib.Constraint + " " + rLib.Version)
	if err != nil {
		return types.Error
	}
	// Different library names never compare.
	if lLib.Name != rLib.Name {
		return types.Error
	}
	satisfied := constraints.Check(lVersion)
	if rLib.Constraint == "=" && satisfied {
		return types.Equal
	}
	if satisfied {
		return types.Better
	}
	return types.Worse
}
package matching
import (
"gitlab.com/nunet/device-management-service/types"
)
// LiteralComparator compares two literal (string) values, which can only
// be Equal or not; any non-string input or mismatch yields Error.
func LiteralComparator(l, r interface{}, _ ...Preference) types.Comparison {
	lStr, lok := l.(string)
	rStr, rok := r.(string)
	if !lok || !rok {
		return types.Error
	}
	if lStr == rStr {
		return types.Equal
	}
	return types.Error
}
package matching
import (
// "reflect"
"gitlab.com/nunet/device-management-service/types"
"golang.org/x/exp/slices"
)
// LocalitiesComparator compares two slices of types.Locality.
// Left represents machine capabilities, right the required capabilities.
// For every required locality the machine localities of the same Kind are
// compared (defaulting to Error when no Kind matches); the per-requirement
// results are then folded: any Error wins, then any Worse, then all-Equal,
// otherwise Better.
func LocalitiesComparator(lraw interface{}, rraw interface{}, _ ...Preference) types.Comparison {
	lLocs, lok := lraw.([]types.Locality)
	rLocs, rok := rraw.([]types.Locality)
	if !lok || !rok {
		return types.Error
	}
	var finalComparison []types.Comparison
	for _, required := range rLocs {
		// Error is the fallback when the machine has no locality of this
		// Kind; a later match of the same Kind overwrites an earlier one.
		outcome := types.Error
		for _, available := range lLocs {
			if available.Kind == required.Kind {
				outcome = Compare(available, required)
			}
		}
		finalComparison = append(finalComparison, outcome)
	}
	if slices.Contains(finalComparison, types.Error) {
		return types.Error
	}
	if slices.Contains(finalComparison, types.Worse) {
		return types.Worse
	}
	if SliceContainsOneValue(finalComparison, types.Equal) {
		return types.Equal
	}
	return types.Better
}
package matching
import (
"gitlab.com/nunet/device-management-service/types"
)
// LocalityComparator compares a machine locality (left) with a required
// locality (right): Equal when Kind and Name both match, Worse when only
// the Kind matches, Error otherwise.
func LocalityComparator(lraw interface{}, rraw interface{}, _ ...Preference) types.Comparison {
	lLoc, lok := lraw.(types.Locality)
	rLoc, rok := rraw.(types.Locality)
	if !lok || !rok {
		return types.Error
	}
	if lLoc.Kind != rLoc.Kind {
		return types.Error
	}
	if lLoc.Name == rLoc.Name {
		return types.Equal
	}
	return types.Worse
}
package matching
import (
"gitlab.com/nunet/device-management-service/types"
)
// MemoryComparator compares two types.RAM values.
// A smaller Size is immediately Worse; otherwise the result is the
// comparison of the ClockSpeedHz fields (more or equal size and speed is
// acceptable, but nothing less).
func MemoryComparator(l, r interface{}, _ ...Preference) types.Comparison {
	if _, ok := l.(types.RAM); !ok {
		return types.Error
	}
	if _, ok := r.(types.RAM); !ok {
		return types.Error
	}
	fields := ReturnComplexComparison(l, r)
	switch fields["Size"] {
	case types.Error:
		return types.Error
	case types.Worse:
		return types.Worse
	default:
		return fields["ClockSpeedHz"]
	}
}
package matching
import (
"reflect"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils/validate"
)
// NumericComparator compares two values of any numeric type.
// Left represents machine capabilities, right the required capabilities:
// Equal when the values match, Worse when left < right, Better when
// left > right, and Error when either input is not numeric.
func NumericComparator(lraw, rraw interface{}, _ ...Preference) types.Comparison {
	l, lnumeric := validate.ConvertNumericToFloat64(lraw)
	r, rnumeric := validate.ConvertNumericToFloat64(rraw)
	if !lnumeric || !rnumeric {
		// Previously this only set a default that the switch below could
		// overwrite: two failed conversions both yield 0, which compared
		// Equal. Return immediately so non-numeric input is always Error.
		return types.Error
	}
	switch {
	case l == r:
		return types.Equal
	case l < r:
		return types.Worse
	case l > r:
		return types.Better
	default:
		// Unordered values (e.g. NaN) remain an Error, as before.
		return types.Error
	}
}
package matching
import (
"fmt"
"gitlab.com/nunet/device-management-service/types"
)
// CapabilityComparator compares two capabilities by ANDing the
// comparisons of their fields. It respects the following table of truth:
//
//	| AND    | Better | Worse  | Equal  | Error  |
//	| ------ | ------ |--------|--------|--------|
//	| Better | Better | Worse  | Better | Error  |
//	| Worse  | Worse  | Worse  | Worse  | Error  |
//	| Equal  | Better | Worse  | Equal  | Error  |
//	| Error  | Error  | Error  | Error  | Error  |
//
// Each field pair is compared via Compare, and the results are folded:
// Executors AND JobTypes AND Resources AND Libraries AND Localities AND
// Storage AND Connectivity AND Price AND Time AND KYCs.
func CapabilityComparator(l, r interface{}, _ ...Preference) types.Comparison {
	lcap, lok := l.(types.Capability)
	rcap, rok := r.(types.Capability)
	if !lok || !rok {
		// NOTE(review): debug print kept to preserve behavior exactly;
		// consider removing it (together with the fmt import) later.
		fmt.Println(lok, rok)
		return types.Error
	}
	// Fold the per-field comparisons with And, starting from Executors.
	result := Compare(lcap.Executors, rcap.Executors)
	pairs := [][2]interface{}{
		{lcap.JobTypes, rcap.JobTypes},
		{lcap.Resources, rcap.Resources},
		{lcap.Libraries, rcap.Libraries},
		{lcap.Localities, rcap.Localities},
		{lcap.Storage, rcap.Storage},
		{lcap.Connectivity, rcap.Connectivity},
		{lcap.Price, rcap.Price},
		{lcap.Time, rcap.Time},
		{lcap.KYCs, rcap.KYCs},
	}
	for _, pair := range pairs {
		result = result.And(Compare(pair[0], pair[1]))
	}
	return result
}
package matching
import (
"reflect"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils/validate"
)
// Comparator is the signature shared by all type-specific comparators.
// Left represents machine capabilities, right represents required
// capabilities.
type Comparator func(l, r interface{}, preference ...Preference) types.Comparison

// comparators is the comparator registry, built exactly once at package
// initialization. Previously the map was rebuilt on every Compare call
// (the old TODO asked for exactly this fix).
var comparators = initComparatorMap()

// Compare dispatches l and r to the comparator registered for l's type.
// Numeric values go to the "Numeric" comparator; slices of known custom
// types are routed to their plural comparator names. Unknown or nil
// inputs yield types.Error.
func Compare(l, r interface{}, _ ...Preference) types.Comparison {
	// Guard against nil: reflect.TypeOf(nil).Name() below would panic.
	if l == nil {
		return types.Error
	}
	// Numeric values all share one comparator.
	if _, numeric := validate.ConvertNumericToFloat64(l); numeric {
		comparator := comparators["Numeric"]
		if comparator == nil {
			return types.Error
		}
		return comparator(l, r)
	}
	typeName := reflect.TypeOf(l).Name()
	// Slices of custom types have an empty reflect name; map each
	// supported element type to its registered (plural) comparator key.
	switch l.(type) {
	case []types.GPU:
		typeName = "GPUs"
	case []types.Library:
		typeName = "Libraries"
	case []types.Locality:
		typeName = "Localities"
	case []types.Storage:
		typeName = "Storages"
	case []types.PriceInformation:
		typeName = "PricesInformation"
	case []types.KYC:
		typeName = "KYCs"
	}
	// Select the comparator based on the (possibly remapped) type name.
	comparator := comparators[typeName]
	if comparator == nil {
		return types.Error
	}
	return comparator(l, r)
}
// ComparatorMap maps a type name to the Comparator that handles it.
type ComparatorMap map[string]Comparator

// initComparatorMap builds the registry of all defined comparators,
// keyed by the type name Compare derives for each value.
func initComparatorMap() ComparatorMap {
	return ComparatorMap{
		"Numeric":            NumericComparator,
		"Capability":         CapabilityComparator,
		"string":             LiteralComparator,
		"Executors":          ExecutorsComparator,
		"ExecutorType":       ExecutorTypeComparator,
		"JobType":            JobTypeComparator,
		"JobTypes":           JobTypesComparator,
		"GPUVendor":          GPUVendorComparator,
		"GPUs":               GPUsComparator,
		"GPU":                GpuComparator,
		"Executor":           ExecutorComparator,
		"ExecutionResources": ExecutionResourcesComparator,
		"CPU":                CPUComparator,
		"RAM":                MemoryComparator,
		"Disk":               DiskComparator,
		"Library":            LibraryComparator,
		"Libraries":          LibrariesComparator,
		"Locality":           LocalityComparator,
		"Localities":         LocalitiesComparator,
		"Storage":            StorageComparator,
		"Storages":           StoragesComparator,
		"Connectivity":       ConnectivityComparator,
		"PriceInformation":   PriceInformationComparator,
		"PricesInformation":  PricesInformationComparator,
		"TimeInformation":    TimeInformationComparator,
		"KYC":                KYCComparator,
		"KYCs":               KYCsComparator,
	}
}
// Preference describes how a particular capability type should be
// weighted during matching.
type Preference struct {
	// TypeName is the capability type name this preference applies to.
	TypeName string
	// Strength marks the preference as a hard requirement or a soft wish.
	Strength PreferenceString
	// DefaultComparatorOverride, when set, replaces the comparator
	// registered for TypeName.
	DefaultComparatorOverride Comparator
}

// PreferenceString is the strength of a Preference.
type PreferenceString string

const (
	// Hard preferences must be satisfied for a match.
	Hard PreferenceString = "Hard"
	// Soft preferences are desirable but not mandatory.
	Soft PreferenceString = "Soft"
)
package matching
import (
"reflect"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils"
)
// ConnectivityComparator compares two Connectivity values.
// Left is the machine's capability, right the job's requirement.
// Equal for identical values; Better when the machine exposes all
// required ports and satisfies the VPN requirement; otherwise Worse.
func ConnectivityComparator(lraw interface{}, rraw interface{}, _ ...Preference) types.Comparison {
	l, lok := lraw.(types.Connectivity)
	r, rok := rraw.(types.Connectivity)
	if !lok || !rok {
		return types.Error
	}
	if reflect.DeepEqual(l, r) {
		// Available capabilities exactly equal the required ones.
		return types.Equal
	}
	// BUG FIX: the VPN check used to be (l.VPN && r.VPN || l.VPN && !r.VPN),
	// which collapses to just l.VPN and wrongly demanded a VPN on the
	// machine even when the job did not ask for one. The requirement is
	// satisfied when the machine has a VPN or none is required.
	if utils.IsStrictlyContainedInt(l.Ports, r.Ports) && (l.VPN || !r.VPN) {
		return types.Better
	}
	return types.Worse
}
package matching
import (
"reflect"
"gitlab.com/nunet/device-management-service/types"
)
// KYCComparator compares two KYC values.
// Left is the machine's capability, right the job's requirement.
// Only exact (deep) equality counts as a match; anything else,
// including non-KYC inputs, is Error.
func KYCComparator(lraw interface{}, rraw interface{}, _ ...Preference) types.Comparison {
	l, lok := lraw.(types.KYC)
	r, rok := rraw.(types.KYC)
	if !lok || !rok {
		return types.Error
	}
	if !reflect.DeepEqual(l, r) {
		return types.Error
	}
	return types.Equal
}
package matching
import (
"reflect"
"gitlab.com/nunet/device-management-service/types"
)
// KYCsComparator compares two slices of KYC values.
// Left is the machine's capability, right the job's requirement.
// Deep equality is Equal; offering KYCs when none are required is
// Better; otherwise any single matching pair yields Equal, and no
// match at all yields Error.
func KYCsComparator(lraw interface{}, rraw interface{}, _ ...Preference) types.Comparison {
	l, lok := lraw.([]types.KYC)
	r, rok := rraw.([]types.KYC)
	if !lok || !rok {
		return types.Error
	}
	if reflect.DeepEqual(l, r) {
		return types.Equal
	}
	if len(r) == 0 && len(l) != 0 {
		// Nothing required, but the machine offers KYCs anyway.
		return types.Better
	}
	// A single matching pair satisfies the requirement.
	for _, have := range l {
		for _, want := range r {
			if Compare(have, want) == types.Equal {
				return types.Equal
			}
		}
	}
	return types.Error
}
package matching
import (
"reflect"
"gitlab.com/nunet/device-management-service/types"
)
// PriceInformationComparator compares two PriceInformation values.
// Left is the machine's offer, right the job's budget. Prices in
// different currencies are Error. Within one currency, a cheaper
// total-per-job (with hourly rate no higher) is Better; equal totals
// are decided by the hourly rate; a more expensive offer is Worse.
func PriceInformationComparator(lraw interface{}, rraw interface{}, _ ...Preference) types.Comparison {
	l, lok := lraw.(types.PriceInformation)
	r, rok := rraw.(types.PriceInformation)
	if !lok || !rok {
		return types.Error
	}
	if reflect.DeepEqual(l, r) {
		return types.Equal
	}
	if l.Currency != r.Currency {
		// Different currencies are not comparable.
		return types.Error
	}
	switch {
	case l.TotalPerJob > r.TotalPerJob:
		return types.Worse
	case l.TotalPerJob < r.TotalPerJob:
		// Cheaper job total wins if the hourly rate is not higher.
		if l.CurrencyPerHour <= r.CurrencyPerHour {
			return types.Better
		}
		return types.Worse
	case l.CurrencyPerHour < r.CurrencyPerHour:
		return types.Better
	case l.CurrencyPerHour > r.CurrencyPerHour:
		return types.Worse
	default:
		return types.Equal
	}
}
package matching
import (
"reflect"
"gitlab.com/nunet/device-management-service/types"
)
// PricesInformationComparator compares two slices of PriceInformation.
// Left is the machine's offers, right the job's budgets. Deep equality
// is Equal; otherwise the first offer/budget pair that compares
// without error decides the outcome, and Error is returned when no
// pair is comparable.
func PricesInformationComparator(lraw interface{}, rraw interface{}, _ ...Preference) types.Comparison {
	l, lok := lraw.([]types.PriceInformation)
	r, rok := rraw.([]types.PriceInformation)
	if !lok || !rok {
		return types.Error
	}
	if reflect.DeepEqual(l, r) {
		return types.Equal
	}
	for _, offered := range l {
		for _, wanted := range r {
			if c := Compare(offered, wanted); c != types.Error {
				return c
			}
		}
	}
	return types.Error
}
package matching
import "gitlab.com/nunet/device-management-service/types"
// StorageComparator compares two Storage values.
// Left is the machine's capability, right the job's requirement.
// Capacity is Size*Amount. Same storage type: bigger capacity is
// Better, equal is Equal, less is Worse. SSD offered against an HDD
// requirement is Better when capacity suffices, Worse otherwise; HDD
// offered against an SSD requirement is always Worse. Anything else
// (unknown type combinations, non-Storage inputs) is Error.
func StorageComparator(lraw interface{}, rraw interface{}, _ ...Preference) types.Comparison {
	l, lok := lraw.(types.Storage)
	r, rok := rraw.(types.Storage)
	if !lok || !rok {
		return types.Error
	}
	lCapacity := l.Size * l.Amount
	// BUG FIX: the required capacity was previously computed as
	// r.Size*l.Amount in the same-type branch, mixing the two sides'
	// amounts (the cross-type branch below already did it correctly).
	rCapacity := r.Size * r.Amount
	if l.Type == r.Type {
		switch {
		case lCapacity == rCapacity:
			return types.Equal
		case lCapacity > rCapacity:
			return types.Better
		default:
			return types.Worse
		}
	}
	if l.Type == types.SSD_STORAGE_TYPE && r.Type == types.HDD_STORAGE_TYPE {
		// SSD can stand in for HDD when capacity is sufficient
		// (the original's == and > branches both returned Better).
		if lCapacity >= rCapacity {
			return types.Better
		}
		return types.Worse
	}
	if l.Type == types.HDD_STORAGE_TYPE && r.Type == types.SSD_STORAGE_TYPE {
		return types.Worse
	}
	return types.Error
}
package matching
import (
"gitlab.com/nunet/device-management-service/types"
)
// StoragesComparator compares the aggregate storage the machine offers
// (left) against what a job requires (right). Capacity is summed per
// storage type as Size*Amount. SSD may stand in for missing HDD, but
// not the other way around.
func StoragesComparator(lraw interface{}, rraw interface{}, _ ...Preference) types.Comparison {
	l, lok := lraw.([]types.Storage)
	r, rok := rraw.([]types.Storage)
	if !lok || !rok {
		return types.Error
	}
	// Previously the totals lived in nested maps whose "amount" entries
	// were initialized but never read; plain integers are clearer and
	// avoid the dead state. Behavior is unchanged.
	var requestedSSD, requestedHDD int
	for _, s := range r {
		switch s.Type {
		case "ssd":
			requestedSSD += s.Size * s.Amount
		case "hdd":
			requestedHDD += s.Size * s.Amount
		}
	}
	var availableSSD, availableHDD int
	for _, s := range l {
		switch s.Type {
		case "ssd":
			availableSSD += s.Size * s.Amount
		case "hdd":
			availableHDD += s.Size * s.Amount
		}
	}
	switch {
	case availableHDD == 0 && requestedHDD > 0:
		// No HDD on offer: SSD may cover the combined requirement.
		if availableSSD >= requestedSSD+requestedHDD {
			return types.Better
		}
		return types.Worse
	case availableSSD < requestedSSD:
		return types.Worse
	case availableSSD == requestedSSD && availableHDD == requestedHDD:
		return types.Equal
	case availableSSD >= requestedSSD && availableHDD >= requestedHDD:
		return types.Better
	}
	return types.Error
}
package matching
import (
"reflect"
"gitlab.com/nunet/device-management-service/types"
)
// TimeInformationComparator compares two TimeInformation values.
// Left is the machine's capability, right the job's requirement.
// Both sides are normalized to seconds via totalTime; a longer
// available time is Better, shorter is Worse.
func TimeInformationComparator(lraw interface{}, rraw interface{}, _ ...Preference) types.Comparison {
	l, lok := lraw.(types.TimeInformation)
	r, rok := rraw.(types.TimeInformation)
	if !lok || !rok {
		return types.Error
	}
	if reflect.DeepEqual(l, r) {
		return types.Equal
	}
	lSeconds, rSeconds := totalTime(l), totalTime(r)
	switch {
	case lSeconds < rSeconds:
		return types.Worse
	case lSeconds > rSeconds:
		return types.Better
	default:
		return types.Equal
	}
}
// totalTime converts a TimeInformation value to a duration in seconds.
// Unrecognized units are treated as seconds, matching "seconds".
func totalTime(timeInfo types.TimeInformation) int {
	secondsPerUnit := 1
	switch timeInfo.Units {
	case "minutes":
		secondsPerUnit = 60
	case "hours":
		secondsPerUnit = 60 * 60
	case "days":
		secondsPerUnit = 60 * 60 * 24
	}
	return timeInfo.MaxTime * secondsPerUnit
}
package matching
import (
"reflect"
"gitlab.com/nunet/device-management-service/types"
)
// ReturnComplexComparison compares two values of the same struct type
// field by field, returning a map from field name to the Comparison of
// that field (each computed via Compare). It is a helper for
// comparators of nested types whose inner fields must be weighed
// together before a verdict on the whole type (e.g. RAM's Size and
// ClockSpeedHz in MemoryComparator).
//
// NOTE(review): l and r are assumed to be non-pointer structs of the
// same type with the same field set; a non-struct input or mismatched
// field counts would panic inside reflect. Callers validate the types
// beforehand — confirm before reusing this elsewhere.
func ReturnComplexComparison(l, r interface{}) types.ComplexComparison {
	vl := reflect.ValueOf(l)
	vr := reflect.ValueOf(r)
	complexComparison := make(types.ComplexComparison)
	for i := 0; i < vl.NumField(); i++ {
		// Key the result by field name; Compare dispatches on the
		// field value's dynamic type.
		innerTypeName := vl.Type().Field(i).Name
		valueL := vl.Field(i).Interface()
		valueR := vr.Field(i).Interface()
		complexComparison[innerTypeName] = Compare(valueL, valueR)
	}
	return complexComparison
}
// removeIndex drops the element at index from every sub-slice of slice.
// Sub-slices for which index is out of range are left untouched.
// Note: append reuses each sub-slice's backing array, so the input is
// modified in place.
func removeIndex(slice [][]types.Comparison, index int) [][]types.Comparison {
	for i, row := range slice {
		if index >= 0 && index < len(row) {
			slice[i] = append(row[:index], row[index+1:]...)
		}
	}
	return slice
}
// returnBestMatch scans dimension for the most desirable Comparison —
// preferring Equal, then Better, then Worse, then Error — and returns
// that value together with the index of its first occurrence.
// An empty dimension yields (types.Error, -1).
func returnBestMatch(dimension []types.Comparison) (types.Comparison, int) {
	// Scan in descending order of desirability; the first hit at each
	// level wins, mirroring the original's four sequential loops.
	priority := []types.Comparison{types.Equal, types.Better, types.Worse, types.Error}
	for _, want := range priority {
		for i, got := range dimension {
			if got == want {
				return got, i
			}
		}
	}
	return types.Error, -1
}
// SliceContainsOneValue reports whether every element of slice equals
// value. An empty slice vacuously satisfies the condition.
func SliceContainsOneValue(slice []types.Comparison, value types.Comparison) bool {
	for i := range slice {
		if slice[i] != value {
			return false
		}
	}
	return true
}
package resources
import (
"github.com/uptrace/opentelemetry-go-extra/otelzap"
"gitlab.com/nunet/device-management-service/db"
gormRepo "gitlab.com/nunet/device-management-service/db/repositories/gorm"
"gitlab.com/nunet/device-management-service/telemetry/logger"
)
var (
	// zlog is the OpenTelemetry-aware zap logger for the resources
	// package, set in init().
	zlog *otelzap.Logger
	// ManagerInstance is the package-global ResourceManager, built in
	// init(); see the TODO below about moving its initialization to dms.
	ManagerInstance Manager
)
// TODO: This needs to be initialized in `dms` package and removed from here
// https://gitlab.com/nunet/device-management-service/-/issues/536
// it is being initialized in `dms` package now but there is usage in executor
// in executor/docker/executor.go:262:25 in function newDockerExecutionContainer
// which heavily depends on this var and any attempt to fix it will involve
// too many changes. Once that code moves to allocations, this can be removed.

// init wires the package logger and constructs the global
// ManagerInstance from gorm-backed repositories, all sharing the
// db.DB connection.
// NOTE(review): this assumes db.DB is already initialized when this
// package is imported — confirm the import order in the main binary.
func init() {
	zlog = logger.OtelZapLogger("resources")
	// Every repository is backed by the same database handle.
	repos := ManagerRepos{
		FreeResources:      gormRepo.NewFreeResources(db.DB),
		OnboardedResources: gormRepo.NewOnboardedResources(db.DB),
		RequiredResources:  gormRepo.NewRequiredResources(db.DB),
		VirtualMachine:     gormRepo.NewVirtualMachine(db.DB),
		Services:           gormRepo.NewServices(db.DB),
	}
	ManagerInstance = NewResourceManager(repos)
}
package resources
import (
"context"
"fmt"
"gitlab.com/nunet/device-management-service/db/repositories"
"gitlab.com/nunet/device-management-service/types"
)
// SystemSpecs is an interface that defines the methods to get the
// system specifications of the machine.
type SystemSpecs interface {
	// GetSpecInfo returns the detailed specifications of the machine.
	GetSpecInfo() (types.SpecInfo, error)
	// GetGPUVendors returns the GPU vendors of the machine.
	GetGPUVendors() ([]types.GPUVendor, error)
	// GetGPUs returns the GPUs of the machine for the given vendors.
	// If no vendors are provided, it returns the information of all GPUs.
	GetGPUs(vendors ...types.GPUVendor) ([]types.GPU, error)
	// GetTotalMemory returns the total memory of the machine in MB.
	GetTotalMemory() (uint64, error)
	// GetTotalStorage returns the total storage of the machine in MB.
	// NOTE(review): the linux implementation sums disk.Usage.Total,
	// which is in bytes — confirm the intended unit with callers.
	GetTotalStorage() (uint64, error)
	// GetCPUInfo returns the CPU information of the machine.
	GetCPUInfo() (types.CPUInfo, error)
	// GetProvisionedResources returns the total resources of the machine.
	GetProvisionedResources() (types.Resources, error)
}
// Manager is an interface that defines the methods to manage the
// resources of the machine.
type Manager interface {
	// UpdateFreeResources calculates the machine's free resources,
	// persists them to the database, and returns them.
	UpdateFreeResources(context.Context) (types.FreeResources, error)
	// GetOnboardedResources returns the onboarded resources of the machine.
	GetOnboardedResources(context.Context) (types.OnboardedResources, error)
	// GetRequiredResources returns the resources required by the jobs
	// currently running on the machine.
	GetRequiredResources(context.Context) (types.Resources, error)
	// UpdateOnboardedResources updates the onboarded resources of the
	// machine in the database.
	UpdateOnboardedResources(context.Context, types.OnboardedResources) error
	// SystemSpecs returns the SystemSpecs instance.
	SystemSpecs() SystemSpecs
	// UsageMonitor returns the UsageMonitor instance.
	UsageMonitor() UsageMonitor
	// ... other methods
}
// DefaultManager implements the Manager interface.
// TODO: do we want to have an in-memory cache for the resources instead of querying the DB every time?
// TODO: Add telemetry for the methods https://gitlab.com/nunet/device-management-service/-/issues/535
type DefaultManager struct {
	// usageMonitor tracks the resources consumed on the machine.
	usageMonitor UsageMonitor
	// systemSpecs probes the machine's hardware capabilities.
	systemSpecs SystemSpecs
	// repos holds the persistence layer for resource records.
	repos ManagerRepos
}

// ManagerRepos holds all the repositories needed for resource management.
type ManagerRepos struct {
	// FreeResources stores the machine's currently free resources.
	FreeResources repositories.FreeResources
	// OnboardedResources stores what the operator has onboarded.
	OnboardedResources repositories.OnboardedResources
	// RequiredResources stores per-job resource requirements.
	RequiredResources repositories.RequiredResources
	// VirtualMachine stores VM records used by the usage monitor.
	VirtualMachine repositories.VirtualMachine
	// Services stores service records used by the usage monitor.
	Services repositories.Services
}
// NewResourceManager returns a DefaultManager wired with the given
// repositories and a freshly constructed SystemSpecs implementation.
func NewResourceManager(repos ManagerRepos) *DefaultManager {
	specs := newSystemSpecs()
	// The usage monitor shares the spec prober and the repositories it
	// needs to account for running VMs, services, and job requirements.
	monitor := newUsageMonitor(
		specs,
		repos.VirtualMachine,
		repos.Services,
		repos.RequiredResources,
	)
	return &DefaultManager{
		usageMonitor: monitor,
		systemSpecs:  specs,
		repos:        repos,
	}
}

// Compile-time check that DefaultManager satisfies Manager.
var _ Manager = (*DefaultManager)(nil)
// UpdateFreeResources recomputes the machine's free resources
// (onboarded resources minus current usage), persists the result to
// the database, and returns it.
func (d DefaultManager) UpdateFreeResources(ctx context.Context) (types.FreeResources, error) {
	usage, err := d.usageMonitor.GetUsage(ctx)
	if err != nil {
		return types.FreeResources{}, fmt.Errorf("getting usage: %w", err)
	}
	onboarded, err := d.GetOnboardedResources(ctx)
	if err != nil {
		return types.FreeResources{}, fmt.Errorf("getting total resources: %w", err)
	}
	remaining, err := onboarded.Subtract(usage)
	if err != nil {
		return types.FreeResources{}, fmt.Errorf("calculating free resources: %w", err)
	}
	free := types.FreeResources{Resources: remaining}
	if err := d.updateDBFreeResources(ctx, free); err != nil {
		return types.FreeResources{}, fmt.Errorf("updating free resources in db: %w", err)
	}
	return free, nil
}
// GetOnboardedResources reads the machine's onboarded resources from
// the repository.
func (d DefaultManager) GetOnboardedResources(ctx context.Context) (types.OnboardedResources, error) {
	resources, err := d.repos.OnboardedResources.Get(ctx)
	if err != nil {
		return types.OnboardedResources{}, fmt.Errorf("failed to get onboarded resources: %w", err)
	}
	return resources, nil
}
// GetRequiredResources returns the sum of the resources required by
// all jobs currently registered on the machine.
func (d DefaultManager) GetRequiredResources(ctx context.Context) (types.Resources, error) {
	jobRequirements, err := d.repos.RequiredResources.FindAll(ctx, d.repos.RequiredResources.GetQuery())
	if err != nil {
		// Wrap with %w (was "- %v") so callers can use errors.Is/As,
		// matching the error style of the rest of this file.
		return types.Resources{}, fmt.Errorf("unable to get resource requirements from db: %w", err)
	}
	// Accumulate the per-job requirements into a single total.
	var totalRequiredResources types.Resources
	for _, req := range jobRequirements {
		totalRequiredResources = totalRequiredResources.Add(req.Resources)
	}
	return totalRequiredResources, nil
}
// UpdateOnboardedResources persists the given onboarded resources for
// the machine.
func (d DefaultManager) UpdateOnboardedResources(ctx context.Context, resources types.OnboardedResources) error {
	if _, err := d.repos.OnboardedResources.Save(ctx, resources); err != nil {
		return fmt.Errorf("failed to update onboarded resources: %w", err)
	}
	return nil
}
// SystemSpecs returns the SystemSpecs instance, exposing the hardware
// prober used by this manager.
func (d DefaultManager) SystemSpecs() SystemSpecs {
	return d.systemSpecs
}

// UsageMonitor returns the UsageMonitor instance, exposing the usage
// tracker used by this manager.
func (d DefaultManager) UsageMonitor() UsageMonitor {
	return d.usageMonitor
}
// updateDBFreeResources persists the given free-resources snapshot to
// the database.
func (d DefaultManager) updateDBFreeResources(ctx context.Context, freeResources types.FreeResources) error {
	if _, err := d.repos.FreeResources.Save(ctx, freeResources); err != nil {
		return fmt.Errorf("updating free resources: %w", err)
	}
	return nil
}
package resources
import (
"context"
"errors"
"fmt"
"os/exec"
"regexp"
"strconv"
"strings"
"gitlab.com/nunet/device-management-service/types"
"github.com/NVIDIA/go-nvml/pkg/nvml"
"github.com/jaypipes/ghw"
"github.com/shirou/gopsutil/v4/cpu"
"github.com/shirou/gopsutil/v4/disk"
"github.com/shirou/gopsutil/v4/mem"
)
// gpuMetadata carries per-GPU data gathered from the PCI bus that the
// vendor-specific probes pair with their own output.
type gpuMetadata struct {
	// PCIAddress is the GPU's address on the PCI bus.
	PCIAddress string
}

// linuxSystemSpecs implements the SystemSpecs interface for Linux systems.
type linuxSystemSpecs struct{}

// newSystemSpecs returns a new instance of linuxSystemSpecs.
func newSystemSpecs() *linuxSystemSpecs {
	return &linuxSystemSpecs{}
}

// Compile-time check that linuxSystemSpecs satisfies SystemSpecs.
var _ SystemSpecs = (*linuxSystemSpecs)(nil)
// GetSpecInfo returns the detailed specifications of the system.
// TODO: implement the method
// https://gitlab.com/nunet/device-management-service/-/issues/537
// NOTE(review): currently panics unconditionally — callers must not
// invoke this until it is implemented.
func (l linuxSystemSpecs) GetSpecInfo() (types.SpecInfo, error) {
	// TODO implement me
	panic("implement me")
}
// GetGPUVendors inspects the machine's graphics cards via ghw and
// returns the vendor of every card whose PCI class looks like a GPU
// (display / VGA compatible / 3D / 2D controller). Vendors that are
// not NVIDIA, AMD, or Intel are reported as GPUVendorUnknown.
func (l linuxSystemSpecs) GetGPUVendors() ([]types.GPUVendor, error) {
	gpu, err := ghw.GPU()
	if err != nil {
		return nil, err
	}
	var vendors []types.GPUVendor
	for _, card := range gpu.GraphicsCards {
		// Skip cards that lack the PCI information we need.
		if card.DeviceInfo == nil || card.DeviceInfo.Class == nil {
			continue
		}
		className := strings.ToLower(card.DeviceInfo.Class.Name)
		isGPUClass := strings.Contains(className, "display controller") ||
			strings.Contains(className, "vga compatible controller") ||
			strings.Contains(className, "3d controller") ||
			strings.Contains(className, "2d controller")
		if !isGPUClass {
			continue
		}
		vendor := card.DeviceInfo.Vendor
		if vendor == nil {
			continue
		}
		vendorName := strings.ToLower(vendor.Name)
		switch {
		case strings.Contains(vendorName, "nvidia"):
			vendors = append(vendors, types.GPUVendorNvidia)
		case strings.Contains(vendorName, "amd"):
			vendors = append(vendors, types.GPUVendorAMDATI)
		case strings.Contains(vendorName, "intel"):
			vendors = append(vendors, types.GPUVendorIntel)
		default:
			vendors = append(vendors, types.GPUVendorUnknown)
		}
	}
	return vendors, nil
}
// GetGPUs returns the GPUs based on the specified vendors. If no
// vendors are provided, it probes all supported vendors (Intel,
// NVIDIA, AMD). Probe failures for a vendor are logged rather than
// returned, so one missing toolchain does not hide the other GPUs.
func (l linuxSystemSpecs) GetGPUs(vendors ...types.GPUVendor) ([]types.GPU, error) {
	metadataByVendor, err := l.fetchGPUMetadata()
	if err != nil {
		return nil, fmt.Errorf("failed to fetch GPU metadata: %w", err)
	}
	var gpus []types.GPU
	// collect runs a vendor-specific probe and appends its results,
	// logging (instead of failing) when a vendor yields nothing.
	collect := func(probe func(metadata []gpuMetadata) ([]types.GPU, error), vendor types.GPUVendor) {
		metadata, found := metadataByVendor[vendor]
		if !found {
			zlog.Sugar().Infof("No %s GPUs found", vendor)
			return
		}
		list, err := probe(metadata)
		if err != nil {
			zlog.Sugar().Warnf("Failed to retrieve %s GPU information: %v", vendor, err)
			return
		}
		gpus = append(gpus, list...)
	}
	if len(vendors) == 0 {
		// No specific vendor requested: probe every supported one.
		collect(l.getIntelGPUInfo, types.GPUVendorIntel)
		collect(l.getNVIDIAGPUInfo, types.GPUVendorNvidia)
		collect(l.getAMDGPUInfo, types.GPUVendorAMDATI)
		return assignIndexToGPUs(gpus), nil
	}
	for _, vendor := range vendors {
		switch vendor {
		case types.GPUVendorIntel:
			collect(l.getIntelGPUInfo, vendor)
		case types.GPUVendorNvidia:
			collect(l.getNVIDIAGPUInfo, vendor)
		case types.GPUVendorAMDATI:
			collect(l.getAMDGPUInfo, vendor)
		default:
			return nil, fmt.Errorf("unsupported GPU vendor: %v", vendor)
		}
	}
	// The assigned index is internal to dms; it is not the device index.
	return assignIndexToGPUs(gpus), nil
}
// GetTotalMemory returns the total physical memory of the system in MB.
func (l linuxSystemSpecs) GetTotalMemory() (uint64, error) {
	v, err := mem.VirtualMemory()
	if err != nil {
		// Wrap with %w (was %s) so callers can unwrap the cause.
		return 0, fmt.Errorf("failed to get total memory: %w", err)
	}
	// gopsutil reports bytes; convert to MB.
	ramInMB := v.Total / 1024 / 1024
	return ramInMB, nil
}
// GetTotalStorage sums the total capacity of every mounted physical
// partition on the system.
// NOTE(review): disk.Usage.Total is reported in bytes; the interface
// comment says MB — confirm which unit callers expect.
func (l linuxSystemSpecs) GetTotalStorage() (uint64, error) {
	ctx := context.Background()
	partitions, err := disk.PartitionsWithContext(ctx, false)
	if err != nil {
		return 0, fmt.Errorf("failed to get partitions: %w", err)
	}
	var total uint64
	for _, partition := range partitions {
		usage, err := disk.UsageWithContext(ctx, partition.Mountpoint)
		if err != nil {
			return 0, fmt.Errorf("failed to get disk usage: %w", err)
		}
		total += usage.Total
	}
	return total, nil
}
// GetCPUInfo returns the CPU information for the system: the summed
// MHz across cores (Compute), the core count, and the first core's
// clock rate as MHzPerCore.
func (l linuxSystemSpecs) GetCPUInfo() (types.CPUInfo, error) {
	cores, err := cpu.Info()
	if err != nil {
		// %w (was %s) so callers can unwrap the underlying cause.
		return types.CPUInfo{}, fmt.Errorf("failed to get CPU info: %w", err)
	}
	// BUG FIX: cores[0].Mhz below panicked when cpu.Info() returned an
	// empty slice without an error.
	if len(cores) == 0 {
		return types.CPUInfo{}, fmt.Errorf("no CPU information available")
	}
	var totalCompute float64
	for i := range cores {
		totalCompute += cores[i].Mhz
	}
	return types.CPUInfo{
		Compute:    totalCompute,
		NumCores:   uint64(len(cores)),
		MHzPerCore: cores[0].Mhz,
	}, nil
}
// GetProvisionedResources aggregates CPU, memory, GPU, and storage
// information into the machine's total provisioned resources.
func (l linuxSystemSpecs) GetProvisionedResources() (types.Resources, error) {
	cpuInfo, err := l.GetCPUInfo()
	if err != nil {
		// All wraps use %w (were %s) so callers can unwrap the cause.
		return types.Resources{}, fmt.Errorf("failed to get CPU info: %w", err)
	}
	totalMemory, err := l.GetTotalMemory()
	if err != nil {
		return types.Resources{}, fmt.Errorf("failed to get total memory: %w", err)
	}
	gpus, err := l.GetGPUs()
	if err != nil {
		return types.Resources{}, fmt.Errorf("failed to get GPU info: %w", err)
	}
	totalDisk, err := l.GetTotalStorage()
	if err != nil {
		return types.Resources{}, fmt.Errorf("failed to get total storage: %w", err)
	}
	return types.Resources{
		CPU:  cpuInfo.Compute,
		RAM:  totalMemory,
		Disk: totalDisk,
		GPU:  gpus,
	}, nil
}
// getAMDGPUInfo queries rocm-smi for AMD GPU names and VRAM statistics
// and pairs each reported GPU with its PCI address from metadata.
func (l linuxSystemSpecs) getAMDGPUInfo(metadata []gpuMetadata) ([]types.GPU, error) {
	cmd := exec.Command("rocm-smi", "--showid", "--showproductname", "--showmeminfo", "vram")
	output, err := cmd.CombinedOutput()
	if err != nil {
		return nil, fmt.Errorf("AMD ROCm not installed, initialized, or configured (reboot recommended for newly installed AMD GPU Drivers): %w", err)
	}
	outputStr := string(output)
	// One match per GPU for each of: card series, total VRAM, used VRAM.
	gpuNameRegex := regexp.MustCompile(`GPU\[\d+\]\s+: Card Series:\s+([^\n]+)`)
	totalRegex := regexp.MustCompile(`GPU\[\d+\]\s+: VRAM Total Memory \(B\):\s+(\d+)`)
	usedRegex := regexp.MustCompile(`GPU\[\d+\]\s+: VRAM Total Used Memory \(B\):\s+(\d+)`)
	gpuNameMatches := gpuNameRegex.FindAllStringSubmatch(outputStr, -1)
	totalMatches := totalRegex.FindAllStringSubmatch(outputStr, -1)
	usedMatches := usedRegex.FindAllStringSubmatch(outputStr, -1)
	if len(gpuNameMatches) == 0 || len(totalMatches) == 0 || len(usedMatches) == 0 {
		return nil, fmt.Errorf("failed to find AMD GPU information or vram information in the output")
	}
	if len(gpuNameMatches) != len(totalMatches) || len(totalMatches) != len(usedMatches) {
		return nil, fmt.Errorf("inconsistent AMD GPU information detected")
	}
	// BUG FIX: without this check, metadata[i] below indexed out of
	// range whenever rocm-smi reported more GPUs than fetchGPUMetadata
	// discovered (the NVIDIA path already guards against this).
	if len(gpuNameMatches) != len(metadata) {
		return nil, fmt.Errorf("failed to find AMD GPU metadata for all GPUs")
	}
	gpuInfos := make([]types.GPU, 0, len(gpuNameMatches))
	for i := range gpuNameMatches {
		gpuName := gpuNameMatches[i][1]
		totalMemoryBytes, err := strconv.ParseInt(totalMatches[i][1], 10, 64)
		if err != nil {
			return nil, fmt.Errorf("failed to parse total amdgpu vram: %w", err)
		}
		usedMemoryBytes, err := strconv.ParseInt(usedMatches[i][1], 10, 64)
		if err != nil {
			return nil, fmt.Errorf("failed to parse used amdgpu vram: %w", err)
		}
		// rocm-smi reports bytes; convert to MiB.
		totalMemoryMiB := totalMemoryBytes / 1024 / 1024
		usedMemoryMiB := usedMemoryBytes / 1024 / 1024
		gpuInfos = append(gpuInfos, types.GPU{
			PCIAddress: metadata[i].PCIAddress,
			Model:      gpuName,
			TotalVRAM:  uint64(totalMemoryMiB),
			UsedVRAM:   uint64(usedMemoryMiB),
			FreeVRAM:   uint64(totalMemoryMiB - usedMemoryMiB),
			Vendor:     types.GPUVendorAMDATI,
		})
	}
	return gpuInfos, nil
}
// getNVIDIAGPUInfo returns the GPU information for NVIDIA GPUs via
// NVML, pairing each device with its PCI address from metadata.
// NOTE(review): errors.Is is used on nvml.Return values — this relies
// on go-nvml's Return type implementing the error interface; confirm
// against the pinned go-nvml version.
func (l linuxSystemSpecs) getNVIDIAGPUInfo(metadata []gpuMetadata) ([]types.GPU, error) {
	// Initialize NVML; it is shut down again via the deferred call below.
	ret := nvml.Init()
	if !errors.Is(ret, nvml.SUCCESS) {
		return nil, fmt.Errorf("NVIDIA Management Library not installed, initialized or configured (reboot recommended for newly installed NVIDIA GPU drivers): %s", nvml.ErrorString(ret))
	}
	defer func() {
		// Best-effort shutdown; the error is deliberately ignored.
		_ = nvml.Shutdown()
	}()
	// Get the number of GPU devices
	deviceCount, ret := nvml.DeviceGetCount()
	if !errors.Is(ret, nvml.SUCCESS) {
		return nil, fmt.Errorf("failed to get device count: %s", nvml.ErrorString(ret))
	}
	// Every NVML device must have matching PCI metadata; otherwise we
	// cannot attribute PCI addresses and bail out.
	if deviceCount != len(metadata) {
		return nil, fmt.Errorf("failed to find NVIDIA GPU information for all GPUs")
	}
	var gpus []types.GPU
	// Iterate over each device
	for i := 0; i < deviceCount; i++ {
		// Get the device handle
		device, ret := nvml.DeviceGetHandleByIndex(i)
		if !errors.Is(ret, nvml.SUCCESS) {
			return nil, fmt.Errorf("failed to get device handle for device %d: %s", i, nvml.ErrorString(ret))
		}
		// Get the device name
		name, ret := device.GetName()
		if !errors.Is(ret, nvml.SUCCESS) {
			return nil, fmt.Errorf("failed to get name for device %d: %s", i, nvml.ErrorString(ret))
		}
		// Get the memory info
		memory, ret := device.GetMemoryInfo()
		if !errors.Is(ret, nvml.SUCCESS) {
			return nil, fmt.Errorf("failed to get nvidiagpu vram info for device %d: %s", i, nvml.ErrorString(ret))
		}
		// NVML reports VRAM in bytes; convert to MiB.
		gpu := types.GPU{
			PCIAddress: metadata[i].PCIAddress,
			Name:       name,
			Model:      name,
			TotalVRAM:  memory.Total / 1024 / 1024,
			UsedVRAM:   memory.Used / 1024 / 1024,
			FreeVRAM:   memory.Free / 1024 / 1024,
			Vendor:     types.GPUVendorNvidia,
		}
		gpus = append(gpus, gpu)
	}
	return gpus, nil
}
// getIntelGPUInfo returns the GPU information for Intel GPUs.
//
// It shells out to xpu-smi: `health -l` to enumerate discrete devices,
// `discovery -d <id>` for the model name and total VRAM, and `stats -d <id>`
// for used VRAM. The i-th discovered device is paired with the i-th metadata
// entry for its PCI address, so the two counts must match.
func (l linuxSystemSpecs) getIntelGPUInfo(metadata []gpuMetadata) ([]types.GPU, error) {
	// Determine the number of discrete Intel GPUs.
	cmd := exec.Command("xpu-smi", "health", "-l")
	output, err := cmd.CombinedOutput()
	if err != nil {
		return nil, fmt.Errorf("xpu-smi not installed, initialized, or configured: %s", err)
	}
	outputStr := string(output)
	// Use regex to find all instances of Device ID.
	deviceIDRegex := regexp.MustCompile(`(?i)\| Device ID\s+\|\s+(\d+)\s+\|`)
	deviceIDMatches := deviceIDRegex.FindAllStringSubmatch(outputStr, -1)
	if len(deviceIDMatches) == 0 {
		return nil, fmt.Errorf("failed to find any Intel GPUs")
	}
	if len(deviceIDMatches) != len(metadata) {
		return nil, fmt.Errorf("failed to find Intel GPU information for all GPUs")
	}
	// Compile the per-device patterns once, outside the loop (previously they
	// were recompiled on every iteration).
	nameRegex := regexp.MustCompile(`(?i)Device Name:\s+([^\n|]+)`)
	totalMemRegex := regexp.MustCompile(`(?i)Memory Physical Size:\s+([^\s]+)\s+MiB`)
	usedMemRegex := regexp.MustCompile(`(?i)GPU Memory Used \(MiB\)\s+\|\s+(\d+)\s+\|`)
	gpuInfos := make([]types.GPU, 0, len(deviceIDMatches))
	for i, match := range deviceIDMatches {
		deviceID := match[1]
		// Model name and total memory come from `xpu-smi discovery`.
		cmd = exec.Command("xpu-smi", "discovery", "-d", deviceID)
		output, err = cmd.CombinedOutput()
		if err != nil {
			return nil, fmt.Errorf("failed to get discovery info for Intel GPU %s: %s", deviceID, err)
		}
		outputStr = string(output)
		nameMatch := nameRegex.FindStringSubmatch(outputStr)
		totalMemMatch := totalMemRegex.FindStringSubmatch(outputStr)
		if nameMatch == nil || totalMemMatch == nil {
			return nil, fmt.Errorf("failed to parse discovery info for Intel GPU %s", deviceID)
		}
		gpuName := strings.TrimSpace(nameMatch[1])
		totalMemoryMiB, err := strconv.ParseFloat(totalMemMatch[1], 64)
		if err != nil {
			return nil, fmt.Errorf("failed to parse total memory for Intel GPU %s: %s", deviceID, err)
		}
		// Used memory comes from `xpu-smi stats`.
		cmd = exec.Command("xpu-smi", "stats", "-d", deviceID)
		output, err = cmd.CombinedOutput()
		if err != nil {
			return nil, fmt.Errorf("failed to get stats for Intel GPU %s: %s", deviceID, err)
		}
		outputStr = string(output)
		usedMemMatch := usedMemRegex.FindStringSubmatch(outputStr)
		if usedMemMatch == nil {
			return nil, fmt.Errorf("failed to parse used memory for Intel GPU %s", deviceID)
		}
		usedMemoryMiB, err := strconv.ParseFloat(usedMemMatch[1], 64)
		if err != nil {
			return nil, fmt.Errorf("failed to parse used memory for Intel GPU %s: %s", deviceID, err)
		}
		freeMemoryMiB := totalMemoryMiB - usedMemoryMiB
		gpuInfos = append(gpuInfos, types.GPU{
			PCIAddress: metadata[i].PCIAddress,
			Model:      gpuName,
			TotalVRAM:  uint64(totalMemoryMiB),
			UsedVRAM:   uint64(usedMemoryMiB),
			FreeVRAM:   uint64(freeMemoryMiB),
			Vendor:     types.GPUVendorIntel,
		})
	}
	return gpuInfos, nil
}
// fetchGPUMetadata fetches the GPU metadata for the system using `ghw.GPU()`,
// grouping the PCI address of each graphics card by its parsed vendor.
// TODO: Use one single library to fetch GPU information or improve the match criteria
// https://gitlab.com/nunet/device-management-service/-/issues/548
// TODO: write tests by mocking the gpu snapshot
// https://gitlab.com/nunet/device-management-service/-/issues/534
func (l linuxSystemSpecs) fetchGPUMetadata() (map[types.GPUVendor][]gpuMetadata, error) {
	snapshot, err := ghw.GPU()
	if err != nil {
		return nil, err
	}
	byVendor := make(map[types.GPUVendor][]gpuMetadata)
	for _, card := range snapshot.GraphicsCards {
		// Cards without PCI device info cannot be classified by vendor.
		if card.DeviceInfo == nil {
			continue
		}
		vendor := types.ParseGPUVendor(card.DeviceInfo.Vendor.Name)
		byVendor[vendor] = append(byVendor[vendor], gpuMetadata{PCIAddress: card.Address})
	}
	return byVendor, nil
}
// assignIndexToGPUs assigns an index to each GPU in the list starting from 0.
func assignIndexToGPUs(gpus []types.GPU) []types.GPU {
	for idx := 0; idx < len(gpus); idx++ {
		gpus[idx].Index = idx
	}
	return gpus
}
package resources
import (
"context"
"fmt"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/db/repositories"
)
// UsageMonitor defines the methods to monitor the system usage
type UsageMonitor interface {
	// GetUsage returns the resources used by the machine
	// (the implementation below sums running-VM and running-container usage).
	GetUsage(context.Context) (types.Resources, error)
}
// defaultUsageMonitor implements the UsageMonitor interface
type defaultUsageMonitor struct {
	systemSpecs           SystemSpecs                    // provider of hardware info (CPU specs, etc.)
	vmRepo                repositories.VirtualMachine    // repository of VM records
	serviceRepo           repositories.Services          // repository of container/service records
	requiredResourcesRepo repositories.RequiredResources // per-job declared resource requirements
}
// newUsageMonitor creates a new defaultUsageMonitor wired to the given system
// spec provider and repositories.
func newUsageMonitor(
	systemSpecs SystemSpecs,
	vmRepo repositories.VirtualMachine,
	serviceRepo repositories.Services,
	requiredResourcesRepo repositories.RequiredResources,
) *defaultUsageMonitor {
	monitor := defaultUsageMonitor{
		systemSpecs:           systemSpecs,
		vmRepo:                vmRepo,
		serviceRepo:           serviceRepo,
		requiredResourcesRepo: requiredResourcesRepo,
	}
	return &monitor
}
// Compile-time check that defaultUsageMonitor satisfies UsageMonitor.
var _ UsageMonitor = (*defaultUsageMonitor)(nil)
// GetUsage returns the resources used by the machine
func (um *defaultUsageMonitor) GetUsage(ctx context.Context) (types.Resources, error) {
	var zero types.Resources
	cpuInfo, err := um.systemSpecs.GetCPUInfo()
	if err != nil {
		return zero, fmt.Errorf("getting CPU info: %w", err)
	}
	fromVMs, err := um.getVMUsage(ctx, cpuInfo)
	if err != nil {
		return zero, fmt.Errorf("getting VM usage: %w", err)
	}
	fromContainers, err := um.getContainerUsage(ctx)
	if err != nil {
		return zero, fmt.Errorf("getting container usage: %w", err)
	}
	// Total machine usage is the sum of VM usage and container usage.
	return fromVMs.Add(fromContainers), nil
}
// getVMUsage returns the total usage of all running VMs.
// CPU usage is reported in MHz (total vCPU count times per-core clock).
func (um *defaultUsageMonitor) getVMUsage(ctx context.Context, cpuInfo types.CPUInfo) (types.Resources, error) {
	query := um.vmRepo.GetQuery()
	query.Conditions = append(query.Conditions, repositories.EQ("State", "running"))
	runningVMs, err := um.vmRepo.FindAll(ctx, query)
	if err != nil {
		return types.Resources{}, fmt.Errorf("unable to get running VMs: %w", err)
	}
	var usage types.Resources
	if len(runningVMs) == 0 {
		return usage, nil
	}
	// TODO: disk usage
	var vcpuTotal, memMibTotal uint
	for _, vm := range runningVMs {
		vcpuTotal += vm.VCPUCount
		memMibTotal += vm.MemSizeMib
	}
	usage.RAM = uint64(memMibTotal)
	usage.CPU = float64(vcpuTotal) * cpuInfo.MHzPerCore // CPU in MHz
	return usage, nil
}
// getContainerUsage returns the total usage of all running containers, based
// on the resource requirements declared for each running service's job.
func (um *defaultUsageMonitor) getContainerUsage(ctx context.Context) (types.Resources, error) {
	query := um.serviceRepo.GetQuery()
	query.Conditions = append(query.Conditions, repositories.EQ("JobStatus", "running"))
	running, err := um.serviceRepo.FindAll(ctx, query)
	if err != nil {
		return types.Resources{}, fmt.Errorf("unable to get running containers: %w", err)
	}
	var usage types.Resources
	if len(running) == 0 {
		return usage, nil
	}
	requirements, err := um.getResourceRequirements(ctx)
	if err != nil {
		return types.Resources{}, fmt.Errorf("unable to get resource requirements: %w", err)
	}
	// TODO: disk usage
	for _, svc := range running {
		req := requirements[svc.ResourceRequirements]
		usage.CPU += req.CPU
		usage.RAM += req.RAM
	}
	return usage, nil
}
// getResourceRequirements returns the resource requirements of all jobs,
// keyed by job ID.
func (um *defaultUsageMonitor) getResourceRequirements(ctx context.Context) (map[int]types.RequiredResources, error) {
	all, err := um.requiredResourcesRepo.FindAll(ctx, um.requiredResourcesRepo.GetQuery())
	if err != nil {
		return nil, fmt.Errorf("unable to query resource requirements: %w", err)
	}
	byJob := make(map[int]types.RequiredResources, len(all))
	for _, req := range all {
		byJob[req.JobID] = req
	}
	return byJob, nil
}
package docker
import (
"bufio"
"context"
"encoding/json"
"fmt"
"io"
"os"
"strings"
"sync"
"time"
"github.com/docker/docker/api/types"
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/api/types/filters"
"github.com/docker/docker/api/types/network"
"github.com/docker/docker/client"
"github.com/docker/docker/pkg/jsonmessage"
"github.com/docker/docker/pkg/stdcopy"
v1 "github.com/opencontainers/image-spec/specs-go/v1"
"github.com/pkg/errors"
"go.uber.org/multierr"
)
// Client wraps the Docker client to provide high-level operations on Docker containers and networks.
type Client struct {
	client *client.Client // Embed the Docker client. All methods delegate to this SDK client.
}
// NewDockerClient initializes a new Docker client from the environment
// (DOCKER_HOST etc.) with API version negotiation enabled.
func NewDockerClient() (*Client, error) {
	sdkClient, err := client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation())
	if err != nil {
		return nil, err
	}
	return &Client{client: sdkClient}, nil
}
// IsInstalled reports whether Docker is installed and reachable by pinging the
// Docker daemon.
func (c *Client) IsInstalled(ctx context.Context) bool {
	if _, err := c.client.Ping(ctx); err != nil {
		return false
	}
	return true
}
// CreateContainer creates a new Docker container with the specified
// configuration. The image is pulled first; the container is created but not
// started. It returns the new container's ID.
func (c *Client) CreateContainer(
	ctx context.Context,
	config *container.Config,
	hostConfig *container.HostConfig,
	networkingConfig *network.NetworkingConfig,
	platform *v1.Platform,
	name string,
) (string, error) {
	if _, err := c.PullImage(ctx, config.Image); err != nil {
		return "", err
	}
	created, err := c.client.ContainerCreate(
		ctx,
		config,
		hostConfig,
		networkingConfig,
		platform,
		name,
	)
	if err != nil {
		return "", err
	}
	return created.ID, nil
}
// InspectContainer returns detailed information about a Docker container.
func (c *Client) InspectContainer(ctx context.Context, id string) (types.ContainerJSON, error) {
	info, err := c.client.ContainerInspect(ctx, id)
	return info, err
}
// FollowLogs tails the logs of a specified container, returning separate
// readers for stdout and stderr. A background goroutine demultiplexes the
// Docker log stream into the two pipes; the readers are closed when the
// stream ends or the context is cancelled.
func (c *Client) FollowLogs(ctx context.Context, id string) (stdout, stderr io.Reader, err error) {
	cont, err := c.InspectContainer(ctx, id)
	if err != nil {
		return nil, nil, errors.Wrap(err, "failed to get container")
	}
	logOptions := types.ContainerLogsOptions{
		ShowStdout: true,
		ShowStderr: true,
		Follow:     true,
	}
	logsReader, err := c.client.ContainerLogs(ctx, cont.ID, logOptions)
	if err != nil {
		return nil, nil, errors.Wrap(err, "failed to get container logs")
	}
	stdoutReader, stdoutWriter := io.Pipe()
	stderrReader, stderrWriter := io.Pipe()
	go func() {
		stdoutBuffer := bufio.NewWriter(stdoutWriter)
		stderrBuffer := bufio.NewWriter(stderrWriter)
		defer func() {
			logsReader.Close()
			stdoutBuffer.Flush()
			stdoutWriter.Close()
			stderrBuffer.Flush()
			stderrWriter.Close()
		}()
		// Use a goroutine-local error: the previous code assigned to the named
		// return value `err` after FollowLogs had already returned, which is a
		// racy write to the (no longer observed) return variable.
		_, copyErr := stdcopy.StdCopy(stdoutBuffer, stderrBuffer, logsReader)
		if copyErr != nil && !errors.Is(copyErr, context.Canceled) {
			zlog.Sugar().Warnf("context closed while getting logs: %v\n", copyErr)
		}
	}()
	return stdoutReader, stderrReader, nil
}
// StartContainer starts a specified Docker container.
func (c *Client) StartContainer(ctx context.Context, containerID string) error {
	opts := types.ContainerStartOptions{}
	return c.client.ContainerStart(ctx, containerID, opts)
}
// WaitContainer waits for a container to leave the running state, returning
// channels that deliver the wait result and any error.
func (c *Client) WaitContainer(
	ctx context.Context,
	containerID string,
) (<-chan container.ContainerWaitOKBody, <-chan error) {
	statusCh, errCh := c.client.ContainerWait(ctx, containerID, container.WaitConditionNotRunning)
	return statusCh, errCh
}
// StopContainer stops a running Docker container, allowing it up to the given
// timeout to shut down gracefully before it is killed.
func (c *Client) StopContainer(
	ctx context.Context,
	containerID string,
	timeout time.Duration,
) error {
	grace := timeout
	return c.client.ContainerStop(ctx, containerID, &grace)
}
// RemoveContainer removes a Docker container, forcing removal and deleting
// its associated anonymous volumes.
func (c *Client) RemoveContainer(ctx context.Context, containerID string) error {
	opts := types.ContainerRemoveOptions{RemoveVolumes: true, Force: true}
	return c.client.ContainerRemove(ctx, containerID, opts)
}
// removeContainers concurrently removes all containers matching the specified
// filters and returns the combined removal errors (nil when all succeed).
func (c *Client) removeContainers(ctx context.Context, filterz filters.Args) error {
	matched, err := c.client.ContainerList(
		ctx,
		types.ContainerListOptions{All: true, Filters: filterz},
	)
	if err != nil {
		return err
	}
	var wg sync.WaitGroup
	errCh := make(chan error, len(matched))
	for _, cont := range matched {
		wg.Add(1)
		// One goroutine per container; errCh is buffered so none of them block.
		go func(cont types.Container) {
			defer wg.Done()
			errCh <- c.RemoveContainer(ctx, cont.ID)
		}(cont)
	}
	// Close the channel once every removal has reported, so the drain below terminates.
	go func() {
		wg.Wait()
		close(errCh)
	}()
	var combined error
	for removeErr := range errCh {
		combined = multierr.Append(combined, removeErr)
	}
	return combined
}
// removeNetworks concurrently removes all networks matching the specified
// filters and returns the combined removal errors (nil when all succeed).
func (c *Client) removeNetworks(ctx context.Context, filterz filters.Args) error {
	matched, err := c.client.NetworkList(ctx, types.NetworkListOptions{Filters: filterz})
	if err != nil {
		return err
	}
	var wg sync.WaitGroup
	errCh := make(chan error, len(matched))
	for _, nw := range matched {
		wg.Add(1)
		// One goroutine per network; errCh is buffered so none of them block.
		go func(nw types.NetworkResource) {
			defer wg.Done()
			errCh <- c.client.NetworkRemove(ctx, nw.ID)
		}(nw)
	}
	// Close the channel once every removal has reported, so the drain below terminates.
	go func() {
		wg.Wait()
		close(errCh)
	}()
	var combined error
	for removeErr := range errCh {
		combined = multierr.Append(combined, removeErr)
	}
	return combined
}
// RemoveObjectsWithLabel removes all Docker containers and networks carrying
// the label pair label=value.
func (c *Client) RemoveObjectsWithLabel(ctx context.Context, label string, value string) error {
	filterz := filters.NewArgs(
		filters.Arg("label", fmt.Sprintf("%s=%s", label, value)),
	)
	return multierr.Combine(
		c.removeContainers(ctx, filterz),
		c.removeNetworks(ctx, filterz),
	)
}
// GetOutputStream streams the logs for a specified container.
// The 'since' parameter specifies the timestamp from which to start streaming
// logs; 'follow' indicates whether to keep streaming as new logs are produced.
// The container must be running. Returns an io.ReadCloser for the log stream.
func (c *Client) GetOutputStream(
	ctx context.Context,
	containerID string,
	since string,
	follow bool,
) (io.ReadCloser, error) {
	inspected, err := c.InspectContainer(ctx, containerID)
	if err != nil {
		return nil, errors.Wrap(err, "failed to get container")
	}
	if !inspected.State.Running {
		return nil, fmt.Errorf("cannot get logs for a container that is not running")
	}
	reader, err := c.client.ContainerLogs(ctx, containerID, types.ContainerLogsOptions{
		ShowStdout: true,
		ShowStderr: true,
		Follow:     follow,
		Since:      since,
	})
	if err != nil {
		return nil, errors.Wrap(err, "failed to get container logs")
	}
	return reader, nil
}
// FindContainer searches all containers (running or not) for one whose labels
// contain label=value, returning its ID if found.
func (c *Client) FindContainer(ctx context.Context, label string, value string) (string, error) {
	all, err := c.client.ContainerList(ctx, types.ContainerListOptions{All: true})
	if err != nil {
		return "", err
	}
	for i := range all {
		if all[i].Labels[label] == value {
			return all[i].ID, nil
		}
	}
	return "", fmt.Errorf("unable to find container for %s=%s", label, value)
}
// PullImage pulls a Docker image from a registry and returns its digest.
// Pull progress is mirrored to stdout while the JSON message stream is decoded.
func (c *Client) PullImage(ctx context.Context, imageName string) (string, error) {
	out, err := c.client.ImagePull(ctx, imageName, types.ImagePullOptions{})
	if err != nil {
		zlog.Sugar().Errorf("unable to pull image: %v", err)
		return "", err
	}
	defer out.Close()
	// TeeReader echoes the raw progress stream to stdout as we decode it.
	d := json.NewDecoder(io.TeeReader(out, os.Stdout))
	var message jsonmessage.JSONMessage
	var digest string
	for {
		if err := d.Decode(&message); err != nil {
			if err == io.EOF {
				break
			}
			// Log message fixed: was the ungrammatical "unable pull image".
			zlog.Sugar().Errorf("unable to pull image: %v", err)
			return "", err
		}
		if message.Aux != nil {
			continue
		}
		if message.Error != nil {
			zlog.Sugar().Errorf("unable to pull image: %v", message.Error.Message)
			return "", errors.New(message.Error.Message)
		}
		// The digest arrives as a status line of the form "Digest: sha256:...".
		if strings.HasPrefix(message.Status, "Digest") {
			digest = strings.TrimPrefix(message.Status, "Digest: ")
		}
	}
	return digest, nil
}
package docker
import (
"context"
"fmt"
"io"
"os"
"sync/atomic"
"time"
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/api/types/mount"
"gitlab.com/nunet/device-management-service/dms/resources"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils"
)
const (
	// Labels applied to every container created by the executor so its
	// resources can later be found and cleaned up.
	labelExecutorName = "nunet-executor"
	labelJobID        = "nunet-jobID"
	labelExecutionID  = "nunet-executionID"
	// Polling cadence and overall timeout used by GetLogStream while waiting
	// for an execution handler to appear in the handler map.
	outputStreamCheckTickTime = 100 * time.Millisecond
	outputStreamCheckTimeout  = 5 * time.Second
)
// Executor manages the lifecycle of Docker containers for execution requests.
type Executor struct {
	ID       string                                   // identity of this executor; stamped on every container label it creates
	handlers utils.SyncMap[string, *executionHandler] // Maps execution IDs to their handlers.
	client   *Client                                  // Docker client for container management.
}
// NewExecutor initializes a new Executor instance with a Docker client,
// verifying first that the Docker daemon is reachable.
func NewExecutor(ctx context.Context, id string) (*Executor, error) {
	dockerClient, err := NewDockerClient()
	if err != nil {
		return nil, err
	}
	if !dockerClient.IsInstalled(ctx) {
		return nil, fmt.Errorf("docker is not installed")
	}
	executor := &Executor{ID: id, client: dockerClient}
	return executor, nil
}
// Start begins the execution of a request by starting a Docker container.
func (e *Executor) Start(ctx context.Context, request *types.ExecutionRequest) error {
	zlog.Sugar().
		Infof("Starting execution for job %s, execution %s", request.JobID, request.ExecutionID)
	// A restart may leave a container already running for this execution, so
	// first try to reuse one.
	containerID, err := e.FindRunningContainer(ctx, request.JobID, request.ExecutionID)
	if err != nil {
		// No running container found. If a handler already exists, the
		// execution was started (or finished) earlier; otherwise create a
		// fresh container.
		if existing, ok := e.handlers.Get(request.ExecutionID); ok {
			if existing.active() {
				return fmt.Errorf("execution is already started")
			}
			return fmt.Errorf("execution is already completed")
		}
		containerID, err = e.newDockerExecutionContainer(ctx, request)
		if err != nil {
			return fmt.Errorf("failed to create new container: %w", err)
		}
	}
	handler := &executionHandler{
		client:      e.client,
		ID:          e.ID,
		executionID: request.ExecutionID,
		containerID: containerID,
		resultsDir:  request.ResultsDir,
		waitCh:      make(chan bool),
		activeCh:    make(chan bool),
		running:     &atomic.Bool{},
		TTYEnabled:  true,
	}
	// Register the handler for this executionID, then drive the container
	// lifecycle asynchronously.
	e.handlers.Put(request.ExecutionID, handler)
	go handler.run(ctx)
	return nil
}
// Wait initiates a wait for the completion of a specific execution using its
// executionID. It returns two channels: one for the result and one for any
// error. If the executionID is unknown an error is sent immediately;
// otherwise a goroutine (doWait) performs the asynchronous wait. Callers
// should select on both channels, which lets them synchronize Wait with Start
// and observe either the execution result or a failure to obtain it.
func (e *Executor) Wait(
	ctx context.Context,
	executionID string,
) (<-chan *types.ExecutionResult, <-chan error) {
	resultCh := make(chan *types.ExecutionResult, 1)
	errCh := make(chan error, 1)
	handler, found := e.handlers.Get(executionID)
	if !found {
		errCh <- fmt.Errorf("execution (%s) not found", executionID)
		return resultCh, errCh
	}
	go e.doWait(ctx, resultCh, errCh, handler)
	return resultCh, errCh
}
// doWait is a helper that actively waits for an execution to finish. It
// listens on the handler's wait channel for the completion signal and then
// forwards the result to out. If the context is cancelled first, the
// cancellation error is sent to errCh. A nil result after completion is
// reported as an error, since it indicates a flaw in the executor logic.
func (e *Executor) doWait(
	ctx context.Context,
	out chan *types.ExecutionResult,
	errCh chan error,
	handler *executionHandler,
) {
	zlog.Sugar().Infof("executionID %s waiting for execution", handler.executionID)
	defer close(out)
	defer close(errCh)
	select {
	case <-ctx.Done():
		// Cancelled before the execution completed.
		errCh <- ctx.Err()
	case <-handler.waitCh:
		if handler.result == nil {
			errCh <- fmt.Errorf("execution (%s) result is nil", handler.executionID)
			return
		}
		zlog.Sugar().
			Infof("executionID %s received results from execution", handler.executionID)
		out <- handler.result
	}
}
// Cancel tries to cancel a specific execution by its executionID.
// It returns an error if the execution is not found.
func (e *Executor) Cancel(ctx context.Context, executionID string) error {
	if handler, found := e.handlers.Get(executionID); found {
		return handler.kill(ctx)
	}
	return fmt.Errorf("failed to cancel execution (%s). execution not found", executionID)
}
// GetLogStream provides a stream of output logs for a specific execution.
// The request controls whether past logs are included and whether the stream
// stays open for new logs. It returns an error if the execution is not found
// within the polling timeout.
//
// The execution may be recorded as running before its handler is added to the
// handler map (the container is still starting), so we poll the map for up to
// outputStreamCheckTimeout.
//
// Fix over the previous version: both channels were unbuffered and both sides
// used blocking sends, so a handler found at the same moment the caller timed
// out would deadlock the caller on `chExit <-` while the poller was stuck on
// `ch <-` (and leak the goroutine). The handler channel is now buffered, the
// exit signal is a close (observable after the poller returns), and the
// poller's send selects against exit.
func (e *Executor) GetLogStream(
	ctx context.Context,
	request types.LogStreamRequest,
) (io.ReadCloser, error) {
	chHandler := make(chan *executionHandler, 1) // buffered: the poller's send never blocks forever
	chExit := make(chan struct{})
	defer close(chExit) // always release the polling goroutine, on success or timeout
	go func() {
		// Check the handlers every tick and send the handler down the channel
		// if found; otherwise give up when told to exit.
		ticker := time.NewTicker(outputStreamCheckTickTime)
		defer ticker.Stop()
		for {
			select {
			case <-ticker.C:
				if h, found := e.handlers.Get(request.ExecutionID); found {
					select {
					case chHandler <- h:
					case <-chExit:
					}
					return
				}
			case <-chExit:
				return
			}
		}
	}()
	// Either a handler shows up (the container may have finished starting) or
	// we time out and report the execution as not found.
	select {
	case handler := <-chHandler:
		return handler.outputStream(ctx, request)
	case <-time.After(outputStreamCheckTimeout):
	}
	return nil, fmt.Errorf("execution (%s) not found", request.ExecutionID)
}
// Run initiates and waits for the completion of an execution in one call.
// It is a convenience wrapper that internally calls Start and Wait, returning
// the execution result, or an error if starting or waiting fails or the
// context is cancelled.
func (e *Executor) Run(
	ctx context.Context,
	request *types.ExecutionRequest,
) (*types.ExecutionResult, error) {
	if err := e.Start(ctx, request); err != nil {
		return nil, err
	}
	resCh, errCh := e.Wait(ctx, request.ExecutionID)
	select {
	case err := <-errCh:
		return nil, err
	case res := <-resCh:
		return res, nil
	case <-ctx.Done():
		return nil, ctx.Err()
	}
}
// Cleanup removes all Docker resources associated with this executor:
// containers (with their volumes) and networks carrying the executor's label.
func (e *Executor) Cleanup(ctx context.Context) error {
	if err := e.client.RemoveObjectsWithLabel(ctx, labelExecutorName, e.ID); err != nil {
		return fmt.Errorf("failed to remove containers: %w", err)
	}
	zlog.Info("Cleaned up all Docker resources")
	return nil
}
// newDockerExecutionContainer is an internal method called by Start to set up a new Docker container
// for the job execution. It configures the container based on the provided ExecutionRequest:
// decoding engine specifications, setting up environment variables, mounts and resource
// constraints. It creates the container but does not start it, returning the
// new container ID. The image pull is performed inside CreateContainer; the
// previous explicit PullImage call here duplicated that work, pulling the
// image twice per execution, and has been removed.
func (e *Executor) newDockerExecutionContainer(
	ctx context.Context,
	params *types.ExecutionRequest,
) (string, error) {
	dockerArgs, err := DecodeSpec(params.EngineSpec)
	if err != nil {
		return "", fmt.Errorf("failed to decode docker engine spec: %w", err)
	}
	// TODO: Move this code block to the allocator in future
	// Select the GPU with the highest available free VRAM and choose the GPU vendor for container's host config
	gpus, err := resources.ManagerInstance.SystemSpecs().GetGPUs()
	if err != nil {
		return "", fmt.Errorf("failed to get GPU info: %w", err)
	}
	maxFreeVRAMGpu, err := types.GPUList(gpus).GetGPUWithHighestFreeVRAM()
	if err != nil {
		return "", fmt.Errorf("failed to get GPU with highest free VRAM: %w", err)
	}
	// Essential for multi-vendor GPU nodes. For example,
	// if a machine has an 8 GB NVIDIA and a 16 GB Intel GPU, the latter should be used first.
	// Even for machines with a single GPU, this is important as integrated GPUs would also be commonly detected.
	chosenGPUVendor := maxFreeVRAMGpu.Vendor
	containerConfig := container.Config{
		Image:      dockerArgs.Image,
		Tty:        true, // Needs to be true for applications such as Jupyter or Gradio to work correctly. See issue #459 for details.
		Env:        dockerArgs.Environment,
		Entrypoint: dockerArgs.Entrypoint,
		Cmd:        dockerArgs.Cmd,
		Labels:     e.containerLabels(params.JobID, params.ExecutionID),
		WorkingDir: dockerArgs.WorkingDirectory,
	}
	mounts, err := makeContainerMounts(params.Inputs, params.Outputs, params.ResultsDir)
	if err != nil {
		return "", fmt.Errorf("failed to create container mounts: %w", err)
	}
	zlog.Sugar().Infof("Adding %d GPUs to request", len(params.Resources.GPUs))
	hostConfig := configureHostConfig(chosenGPUVendor, params, mounts)
	// CreateContainer pulls the image itself before creating the container.
	executionContainer, err := e.client.CreateContainer(
		ctx,
		&containerConfig,
		&hostConfig,
		nil,
		nil,
		labelExecutionValue(e.ID, params.JobID, params.ExecutionID),
	)
	if err != nil {
		return "", fmt.Errorf("failed to create container: %w", err)
	}
	return executionContainer, nil
}
// configureHostConfig sets up the host configuration for the container based
// on the GPU vendor and the resources requested by the execution. Every
// variant shares the same CPU limits; GPU vendors additionally attach their
// device mappings (NVIDIA via device requests, AMD and Intel via /dev binds).
func configureHostConfig(vendor types.GPUVendor, params *types.ExecutionRequest, mounts []mount.Mount) container.HostConfig {
	// CPU limits are identical regardless of GPU vendor.
	res := container.Resources{
		NanoCPUs: params.Resources.CPU.ClockSpeedHz,
		CPUCount: int64(params.Resources.CPU.Cores),
	}
	hostConfig := container.HostConfig{Mounts: mounts}
	switch vendor {
	case types.GPUVendorNvidia:
		// NVIDIA GPUs are requested by index through the container runtime.
		deviceIDs := make([]string, len(params.Resources.GPUs))
		for i, gpu := range params.Resources.GPUs {
			deviceIDs[i] = fmt.Sprint(gpu.Index)
		}
		res.DeviceRequests = []container.DeviceRequest{
			{
				DeviceIDs:    deviceIDs,
				Capabilities: [][]string{{"gpu"}},
			},
		}
	case types.GPUVendorAMDATI:
		// AMD requires /dev/kfd and /dev/dri plus membership in the video group.
		hostConfig.Binds = []string{
			"/dev/kfd:/dev/kfd",
			"/dev/dri:/dev/dri",
		}
		hostConfig.GroupAdd = []string{"video"}
		res.Devices = []container.DeviceMapping{
			{
				PathOnHost:        "/dev/kfd",
				PathInContainer:   "/dev/kfd",
				CgroupPermissions: "rwm",
			},
			{
				PathOnHost:        "/dev/dri",
				PathInContainer:   "/dev/dri",
				CgroupPermissions: "rwm",
			},
		}
	case types.GPUVendorIntel:
		// The entire /dev/dri directory is bound, exposing all Intel GPUs to
		// the container; granular per-device control is not required here.
		hostConfig.Binds = []string{
			"/dev/dri:/dev/dri",
		}
		res.Devices = []container.DeviceMapping{
			{
				PathOnHost:        "/dev/dri",
				PathInContainer:   "/dev/dri",
				CgroupPermissions: "rwm",
			},
		}
	}
	hostConfig.Resources = res
	return hostConfig
}
// makeContainerMounts creates the bind mounts for the container from the
// input and output volumes of the execution request. The results directory is
// created if it does not already exist. It returns the mount list or an error
// if any volume is invalid.
//
// Bug fix: the input-volume condition was inverted — bind volumes were
// rejected with "unsupported storage volume type" while every non-bind type
// was silently bind-mounted. Bind is the supported type; it is now accepted
// and all other types are rejected.
func makeContainerMounts(
	inputs []*types.StorageVolumeExecutor,
	outputs []*types.StorageVolumeExecutor,
	resultsDir string,
) ([]mount.Mount, error) {
	// the actual mounts we will give to the container:
	// paths for both input and output data
	mounts := make([]mount.Mount, 0, len(inputs)+len(outputs))
	for _, input := range inputs {
		if input.Type != types.StorageVolumeTypeBind {
			return nil, fmt.Errorf("unsupported storage volume type: %s", input.Type)
		}
		mounts = append(mounts, mount.Mount{
			Type:     mount.TypeBind,
			Source:   input.Source,
			Target:   input.Target,
			ReadOnly: input.ReadOnly,
		})
	}
	for _, output := range outputs {
		if output.Source == "" {
			return nil, fmt.Errorf("output source is empty")
		}
		if resultsDir == "" {
			return nil, fmt.Errorf("results directory is empty")
		}
		if err := os.MkdirAll(resultsDir, os.ModePerm); err != nil {
			return nil, fmt.Errorf("failed to create results directory: %w", err)
		}
		mounts = append(mounts, mount.Mount{
			Type:   mount.TypeBind,
			Source: output.Source,
			Target: output.Target,
			// this is an output volume so can be written to
			ReadOnly: false,
		})
	}
	return mounts, nil
}
// containerLabels returns the labels to be applied to the container for the
// given job and execution: the executor's ID plus derived job/execution values.
func (e *Executor) containerLabels(jobID string, executionID string) map[string]string {
	labels := make(map[string]string, 3)
	labels[labelExecutorName] = e.ID
	labels[labelJobID] = labelJobValue(e.ID, jobID)
	labels[labelExecutionID] = labelExecutionValue(e.ID, jobID, executionID)
	return labels
}
// labelJobValue returns the value for the job label: "<executorID>_<jobID>".
func labelJobValue(executorID string, jobID string) string {
	return executorID + "_" + jobID
}
// labelExecutionValue returns the value for the execution label:
// "<executorID>_<jobID>_<executionID>".
func labelExecutionValue(executorID string, jobID string, executionID string) string {
	return executorID + "_" + jobID + "_" + executionID
}
// FindRunningContainer finds the container that is running the execution with
// the given ID, returning the container ID if found or an error otherwise.
func (e *Executor) FindRunningContainer(
	ctx context.Context,
	jobID string,
	executionID string,
) (string, error) {
	wanted := labelExecutionValue(e.ID, jobID, executionID)
	return e.client.FindContainer(ctx, labelExecutionID, wanted)
}
package docker
import (
"bufio"
"context"
"errors"
"fmt"
"io"
"strconv"
"sync/atomic"
"time"
"gitlab.com/nunet/device-management-service/types"
)
const DestroyTimeout = time.Second * 10
// executionHandler manages the lifecycle and execution of a Docker container for a specific job.
type executionHandler struct {
	// provided by the executor
	ID     string
	client *Client // Docker client for container management.
	// meta data about the task
	jobID       string
	executionID string
	containerID string
	resultsDir  string // Directory to store execution results.
	// synchronization
	activeCh chan bool    // Closed by run() once the container is running; receivers block until then.
	waitCh   chan bool    // Closed by run() when execution completes or fails.
	running  *atomic.Bool // Indicates if the container is currently running.
	// result of the execution; populated by run() before waitCh is closed.
	result *types.ExecutionResult
	// TTY setting
	TTYEnabled bool // Indicates if TTY is enabled for the container.
}
// active checks if the execution handler's container is running.
// Safe for concurrent use: the flag is an atomic bool.
func (h *executionHandler) active() bool {
	return h.running.Load()
}
// run starts the container and handles its execution lifecycle: start the
// container, wait for it to exit, then capture its logs. The outcome is
// always stored in h.result before h.waitCh is closed, and the container is
// destroyed on the way out.
func (h *executionHandler) run(ctx context.Context) {
	h.running.Store(true)
	defer func() {
		if err := h.destroy(DestroyTimeout); err != nil {
			zlog.Sugar().Warnf("failed to destroy container: %v\n", err)
		}
		h.running.Store(false)
		close(h.waitCh) // Signal waiters that the execution has finished.
	}()

	if err := h.client.StartContainer(ctx, h.containerID); err != nil {
		h.result = types.NewFailedExecutionResult(fmt.Errorf("failed to start container: %v", err))
		return
	}
	close(h.activeCh) // Indicate that the container has started.

	var containerError error
	var containerExitStatusCode int64

	// Wait for the container to finish or for an execution error.
	statusCh, errCh := h.client.WaitContainer(ctx, h.containerID)
	select {
	case <-ctx.Done():
		// Report the actual cancellation cause via ctx.Err(). The previous
		// code formatted the zero struct{} received from Done(), which
		// rendered as "{}" and hid the reason.
		h.result = types.NewFailedExecutionResult(fmt.Errorf("execution cancelled: %v", ctx.Err()))
		return
	case err := <-errCh:
		zlog.Sugar().Errorf("error while waiting for container: %v\n", err)
		h.result = types.NewFailedExecutionResult(
			fmt.Errorf("failed to wait for container: %v", err),
		)
		return
	case exitStatus := <-statusCh:
		containerExitStatusCode = exitStatus.StatusCode
		containerJSON, err := h.client.InspectContainer(ctx, h.containerID)
		if err != nil {
			h.result = &types.ExecutionResult{
				ExitCode: int(containerExitStatusCode),
				ErrorMsg: err.Error(),
			}
			return
		}
		// An OOM kill is reported explicitly; the exit code alone does not convey it.
		if containerJSON.ContainerJSONBase.State.OOMKilled {
			containerError = errors.New("container was killed due to OOM")
			h.result = &types.ExecutionResult{
				ExitCode: int(containerExitStatusCode),
				ErrorMsg: containerError.Error(),
			}
			return
		}
		if exitStatus.Error != nil {
			containerError = errors.New(exitStatus.Error.Message)
		}
	}

	// Follow container logs to capture stdout and stderr.
	stdoutPipe, stderrPipe, logsErr := h.client.FollowLogs(ctx, h.containerID)
	if logsErr != nil {
		followError := fmt.Errorf("failed to follow container logs: %w", logsErr)
		// Preserve the container error (if any) alongside the logs error.
		if containerError != nil {
			h.result = &types.ExecutionResult{
				ExitCode: int(containerExitStatusCode),
				ErrorMsg: fmt.Sprintf(
					"container error: '%s'. logs error: '%s'",
					containerError,
					followError,
				),
			}
		} else {
			h.result = &types.ExecutionResult{
				ExitCode: int(containerExitStatusCode),
				ErrorMsg: followError.Error(),
			}
		}
		return
	}

	// Initialize the result with the exit status code.
	h.result = types.NewExecutionResult(int(containerExitStatusCode))

	// Capture the logs based on the TTY setting.
	if h.TTYEnabled {
		// TTY combines stdout and stderr, read from stdoutPipe only.
		h.result.STDOUT, _ = bufio.NewReader(stdoutPipe).ReadString('\x00') // EOF delimiter
	} else {
		// Read from stdout and stderr separately.
		h.result.STDOUT, _ = bufio.NewReader(stdoutPipe).ReadString('\x00') // EOF delimiter
		h.result.STDERR, _ = bufio.NewReader(stderrPipe).ReadString('\x00')
	}
}
// kill sends a stop signal to the container, allowing it up to
// DestroyTimeout to terminate gracefully.
func (h *executionHandler) kill(ctx context.Context) error {
	return h.client.StopContainer(ctx, h.containerID, DestroyTimeout)
}
// destroy tears down the container and every object labelled for this
// execution, all bounded by the given timeout.
func (h *executionHandler) destroy(timeout time.Duration) error {
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	// Stop the container first so removal does not race a running process.
	err := h.kill(ctx)
	if err != nil {
		return fmt.Errorf("failed to kill container (%s): %w", h.containerID, err)
	}

	err = h.client.RemoveContainer(ctx, h.containerID)
	if err != nil {
		return err
	}

	// Drop related objects (networks, volumes, ...) carrying this execution's label.
	return h.client.RemoveObjectsWithLabel(
		ctx,
		labelExecutionID,
		labelExecutionValue(h.ID, h.jobID, h.executionID),
	)
}
// outputStream returns a reader over the container's log stream. It blocks
// until the container has started (or the context is cancelled) before
// requesting logs from the client.
func (h *executionHandler) outputStream(
	ctx context.Context,
	request types.LogStreamRequest,
) (io.ReadCloser, error) {
	// "1" is one second into UNIX time, i.e. effectively all logs.
	since := "1"
	if request.Tail {
		// Tail mode: only logs produced from this moment on.
		since = strconv.FormatInt(time.Now().Unix(), 10)
	}

	// Block until the container is active; activeCh is closed once it starts.
	select {
	case <-h.activeCh:
	case <-ctx.Done():
		return nil, ctx.Err()
	}

	// The underlying reader provides data since the `since` timestamp.
	return h.client.GetOutputStream(ctx, h.containerID, since, request.Follow)
}
package docker
import (
"gitlab.com/nunet/device-management-service/telemetry/logger"
)
// zlog is the package-level logger for the docker executor.
// Initialized directly rather than in init(): the assignment has no ordering
// requirements, and this avoids an init() with side effects.
var zlog = logger.New("docker.executor")
package docker
import (
"encoding/json"
"fmt"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils/validate"
)
const (
EngineKeyImage = "image"
EngineKeyEntrypoint = "entrypoint"
EngineKeyCmd = "cmd"
EngineKeyEnvironment = "environment"
EngineKeyWorkingDirectory = "working_directory"
)
// EngineSpec contains necessary parameters to execute a docker job.
// It is produced from a generic types.SpecConfig by DecodeSpec and
// constructed fluently via EngineBuilder.
type EngineSpec struct {
	// Image this should be pullable by docker
	Image string `json:"image,omitempty"`
	// Entrypoint optionally override the default entrypoint
	Entrypoint []string `json:"entrypoint,omitempty"`
	// Cmd specifies the command to run in the container
	Cmd []string `json:"cmd,omitempty"`
	// Environment is a slice of env vars ("KEY=VALUE") to run the container with
	Environment []string `json:"environment,omitempty"`
	// WorkingDirectory inside the container
	WorkingDirectory string `json:"working_directory,omitempty"`
}
// Validate checks that the docker engine spec carries its one mandatory
// field: a non-blank image reference.
func (c EngineSpec) Validate() error {
	if !validate.IsBlank(c.Image) {
		return nil
	}
	return fmt.Errorf("invalid docker engine params: image cannot be empty")
}
// DecodeSpec decodes a spec config into a docker engine spec.
// It round-trips the generic params through JSON into an EngineSpec struct
// and validates the result.
func DecodeSpec(spec *types.SpecConfig) (EngineSpec, error) {
	if !spec.IsType(types.ExecutorTypeDocker) {
		return EngineSpec{}, fmt.Errorf(
			"invalid docker engine type. expected %s, but received: %s",
			types.ExecutorTypeDocker,
			spec.Type,
		)
	}

	inputParams := spec.Params
	if inputParams == nil {
		return EngineSpec{}, fmt.Errorf("invalid docker engine params: params cannot be nil")
	}

	paramBytes, err := json.Marshal(inputParams)
	if err != nil {
		return EngineSpec{}, fmt.Errorf("failed to encode docker engine params: %w", err)
	}

	// Unmarshal into a value, not a *EngineSpec: a payload of JSON null would
	// leave a pointer nil and panic when dereferenced on return.
	var dockerSpec EngineSpec
	if err := json.Unmarshal(paramBytes, &dockerSpec); err != nil {
		return EngineSpec{}, fmt.Errorf("failed to decode docker engine params: %w", err)
	}

	return dockerSpec, dockerSpec.Validate()
}
// EngineBuilder constructs a docker engine SpecConfig using the Builder
// pattern. It wraps a *types.SpecConfig and accumulates parameters via the
// With* methods; Build returns the finished config.
type EngineBuilder struct {
	eb *types.SpecConfig
}
// NewDockerEngineBuilder initializes a new EngineBuilder for a docker
// engine spec, seeded with the given image.
func NewDockerEngineBuilder(image string) *EngineBuilder {
	spec := types.NewSpecConfig(types.ExecutorTypeDocker)
	spec.WithParam(EngineKeyImage, image)
	return &EngineBuilder{eb: spec}
}
// WithEntrypoint overrides the image's default entrypoint.
// It returns the builder to allow chaining.
func (b *EngineBuilder) WithEntrypoint(entrypoint ...string) *EngineBuilder {
	b.eb.WithParam(EngineKeyEntrypoint, entrypoint)
	return b
}
// WithCmd sets the command to run in the container.
// It returns the builder to allow chaining.
func (b *EngineBuilder) WithCmd(cmd ...string) *EngineBuilder {
	b.eb.WithParam(EngineKeyCmd, cmd)
	return b
}
// WithEnvironment sets the environment variables for the container.
// It returns the builder to allow chaining.
func (b *EngineBuilder) WithEnvironment(env ...string) *EngineBuilder {
	b.eb.WithParam(EngineKeyEnvironment, env)
	return b
}
// WithWorkingDirectory sets the working directory inside the container.
// It returns the builder to allow chaining.
func (b *EngineBuilder) WithWorkingDirectory(dir string) *EngineBuilder {
	b.eb.WithParam(EngineKeyWorkingDirectory, dir)
	return b
}
// Build returns the final SpecConfig object accumulated by the builder's
// With* methods.
func (b *EngineBuilder) Build() *types.SpecConfig {
	return b.eb
}
package firecracker
import (
"context"
"fmt"
"os"
"os/exec"
"syscall"
"time"
"github.com/firecracker-microvm/firecracker-go-sdk"
)
const pidCheckTickTime = 100 * time.Millisecond
// Client wraps the Firecracker SDK to provide high-level operations on Firecracker VMs.
// It holds no state of its own; all methods operate on the machine handles passed in.
type Client struct{}
// NewFirecrackerClient constructs a new Firecracker client.
// The error return is reserved for future setup steps; it is currently always nil.
func NewFirecrackerClient() (*Client, error) {
	return &Client{}, nil
}
// IsInstalled reports whether a `firecracker` binary can be found on the
// host's PATH.
//
// NOTE(review): a PATH lookup only proves a binary exists, not that it is
// executable in this environment; there might be a better way to check.
func (c *Client) IsInstalled() bool {
	if _, err := exec.LookPath("firecracker"); err != nil {
		return false
	}
	return true
}
// CreateVM creates a new Firecracker VM with the specified configuration.
// The machine is constructed but not started; call StartVM to boot it.
func (c *Client) CreateVM(
	ctx context.Context,
	cfg firecracker.Config,
) (*firecracker.Machine, error) {
	// Run the firecracker binary against the configured API socket.
	cmd := firecracker.VMCommandBuilder{}.
		WithSocketPath(cfg.SocketPath).
		Build(ctx)

	return firecracker.NewMachine(ctx, cfg, firecracker.WithProcessRunner(cmd))
}
// StartVM boots the given Firecracker VM.
func (c *Client) StartVM(ctx context.Context, m *firecracker.Machine) error {
	return m.Start(ctx)
}
// ShutdownVM requests a graceful shutdown of the Firecracker VM.
func (c *Client) ShutdownVM(ctx context.Context, m *firecracker.Machine) error {
	return m.Shutdown(ctx)
}
// DestroyVM destroys the Firecracker VM.
//
// It first requests a graceful shutdown, then polls the process PID every
// pidCheckTickTime; if the process is still alive after timeout it is
// force-killed with SIGKILL. The VM's API socket file is removed on return.
func (c *Client) DestroyVM(
	ctx context.Context,
	m *firecracker.Machine,
	timeout time.Duration,
) error {
	err := c.ShutdownVM(ctx, m)
	if err != nil {
		return fmt.Errorf("failed to shutdown vm: %w", err)
	}

	// PID lookup errors are deliberately ignored: a non-positive pid below
	// simply means there is no process left to wait for.
	pid, _ := m.PID()
	defer os.Remove(m.Cfg.SocketPath)

	// If the process is not running, return early.
	if pid <= 0 {
		return nil
	}

	// Poll the PID every pidCheckTickTime; report false on `done` if the
	// process is still running when the timeout elapses.
	done := make(chan bool, 1)
	go func() {
		ticker := time.NewTicker(pidCheckTickTime)
		defer ticker.Stop()
		to := time.NewTimer(timeout)
		defer to.Stop()
		for {
			select {
			case <-to.C:
				done <- false
				return
			case <-ticker.C:
				if pid, _ := m.PID(); pid <= 0 {
					done <- true
					return
				}
			}
		}
	}()

	// Wait for the check to finish.
	killed := <-done
	if !killed {
		// The shutdown request timed out, kill the process with SIGKILL.
		// Wrap with %w (not %v) so callers can unwrap the underlying errno.
		err := syscall.Kill(pid, syscall.SIGKILL)
		if err != nil {
			return fmt.Errorf("failed to kill process: %w", err)
		}
	}
	return nil
}
// FindVM finds a Firecracker VM by its socket path.
// It verifies the socket file exists, binds a machine handle to it, and then
// confirms the VM is actually running by querying the Firecracker API.
func (c *Client) FindVM(ctx context.Context, socketPath string) (*firecracker.Machine, error) {
	// Check if the socket file exists.
	if _, err := os.Stat(socketPath); err != nil {
		return nil, fmt.Errorf("VM with socket path %v not found", socketPath)
	}

	// Create a new Firecracker machine instance bound to the existing socket.
	cmd := firecracker.VMCommandBuilder{}.WithSocketPath(socketPath).Build(ctx)
	machine, err := firecracker.NewMachine(
		ctx,
		firecracker.Config{SocketPath: socketPath},
		firecracker.WithProcessRunner(cmd),
	)
	if err != nil {
		// Wrap with %w so callers can inspect the underlying SDK error.
		return nil, fmt.Errorf("failed to create machine with socket %s: %w", socketPath, err)
	}

	// Check if the VM is running by getting its instance info.
	info, err := machine.DescribeInstanceInfo(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to get instance info for socket %s: %w", socketPath, err)
	}
	if *info.State != "Running" {
		return nil, fmt.Errorf(
			"VM with socket %s is not running, current state: %s",
			socketPath,
			*info.State,
		)
	}
	return machine, nil
}
package firecracker
import (
"context"
"fmt"
"io"
"os"
"sync"
"sync/atomic"
"time"
"github.com/firecracker-microvm/firecracker-go-sdk"
fcModels "github.com/firecracker-microvm/firecracker-go-sdk/client/models"
"go.uber.org/multierr"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils"
)
const (
socketDir = "/tmp"
)
// Executor manages the lifecycle of Firecracker VMs for execution requests.
type Executor struct {
	ID       string
	handlers utils.SyncMap[string, *executionHandler] // Maps execution IDs to their handlers.
	client   *Client                                  // Firecracker client for VM management.
}
// NewExecutor initializes a new executor for Firecracker VMs.
// It fails when the firecracker binary is not available on the host.
func NewExecutor(_ context.Context, id string) (*Executor, error) {
	client, err := NewFirecrackerClient()
	if err != nil {
		return nil, err
	}
	if !client.IsInstalled() {
		return nil, fmt.Errorf("firecracker is not installed")
	}
	return &Executor{ID: id, client: client}, nil
}
// Start begins the execution of a request by starting a new Firecracker VM.
// If a VM for this execution is already running (e.g. after a restart) it is
// adopted; otherwise a new VM is created — unless a handler is already
// registered, which means the execution was started or has completed.
func (e *Executor) Start(ctx context.Context, request *types.ExecutionRequest) error {
	zlog.Sugar().
		Infof("Starting execution for job %s, execution %s", request.JobID, request.ExecutionID)
	// It's possible that this is being called due to a restart. We should check if the
	// VM is already running.
	machine, err := e.FindRunningVM(ctx, request.JobID, request.ExecutionID)
	if err != nil {
		// Unable to find a running VM for this execution, we will instead check for a handler, and
		// failing that will create a new VM.
		if handler, ok := e.handlers.Get(request.ExecutionID); ok {
			if handler.active() {
				return fmt.Errorf("execution is already started")
			}
			return fmt.Errorf("execution is already completed")
		}
		// Create a new handler for the execution.
		machine, err = e.newFirecrackerExecutionVM(ctx, request)
		if err != nil {
			return fmt.Errorf("failed to create new firecracker VM: %w", err)
		}
	}
	handler := &executionHandler{
		client:      e.client,
		ID:          e.ID,
		executionID: request.ExecutionID,
		machine:     machine,
		resultsDir:  request.ResultsDir,
		waitCh:      make(chan bool),
		activeCh:    make(chan bool),
		running:     &atomic.Bool{},
	}
	// register the handler for this executionID
	e.handlers.Put(request.ExecutionID, handler)
	// run the VM lifecycle in the background; Wait observes its completion.
	go handler.run(ctx)
	return nil
}
// Wait initiates a wait for the completion of a specific execution using its
// executionID. It returns a result channel and an error channel (each with a
// buffer of one). If the executionID is unknown an error is sent immediately;
// otherwise a goroutine (doWait) waits asynchronously and delivers either the
// execution result or an error. This lets callers synchronize Wait with calls
// to Start and block until the execution completes.
func (e *Executor) Wait(
	ctx context.Context,
	executionID string,
) (<-chan *types.ExecutionResult, <-chan error) {
	resultCh := make(chan *types.ExecutionResult, 1)
	errCh := make(chan error, 1)

	handler, found := e.handlers.Get(executionID)
	if !found {
		errCh <- fmt.Errorf("execution (%s) not found", executionID)
		return resultCh, errCh
	}

	go e.doWait(ctx, resultCh, errCh, handler)
	return resultCh, errCh
}
// doWait is a helper function that actively waits for an execution to finish. It
// listens on the executionHandler's wait channel for completion signals. Once the
// signal is received, the result is sent to the provided output channel. If there's
// a cancellation request (context is done) before completion, an error is relayed to
// the error channel. If the execution result is nil, an error suggests a potential
// flaw in the executor logic.
func (e *Executor) doWait(
	ctx context.Context,
	out chan *types.ExecutionResult,
	errCh chan error,
	handler *executionHandler,
) {
	zlog.Sugar().Infof("executionID %s waiting for execution", handler.executionID)
	// Both channels are buffered (capacity 1, see Wait), so the sends below
	// never block even if the caller has stopped listening.
	defer close(out)
	defer close(errCh)
	select {
	case <-ctx.Done():
		errCh <- ctx.Err() // Send the cancellation error to the error channel
		return
	case <-handler.waitCh:
		// waitCh is closed by run() after h.result has been populated.
		if handler.result != nil {
			zlog.Sugar().
				Infof("executionID %s received results from execution", handler.executionID)
			out <- handler.result
		} else {
			errCh <- fmt.Errorf("execution (%s) result is nil", handler.executionID)
		}
	}
}
// Cancel stops the execution identified by executionID by shutting down its
// VM. It returns an error when no such execution is registered.
func (e *Executor) Cancel(ctx context.Context, executionID string) error {
	handler, ok := e.handlers.Get(executionID)
	if !ok {
		return fmt.Errorf("failed to cancel execution (%s). execution not found", executionID)
	}
	return handler.kill(ctx)
}
// Run initiates and waits for the completion of an execution in one call.
// This method serves as a higher-level convenience function that
// internally calls Start and Wait methods.
// It returns the result of the execution or an error if either starting
// or waiting fails, or if the context is canceled.
func (e *Executor) Run(
	ctx context.Context,
	request *types.ExecutionRequest,
) (*types.ExecutionResult, error) {
	if err := e.Start(ctx, request); err != nil {
		return nil, err
	}
	resCh, errCh := e.Wait(ctx, request.ExecutionID)
	// Whichever arrives first wins: cancellation, a result, or a wait error.
	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	case out := <-resCh:
		return out, nil
	case err := <-errCh:
		return nil, err
	}
}
// GetLogStream is not implemented for Firecracker.
// It is defined to satisfy the Executor interface.
// This method will return an error if called.
func (e *Executor) GetLogStream(_ context.Context, _ types.LogStreamRequest) (io.ReadCloser, error) {
	return nil, fmt.Errorf("GetLogStream is not implemented for Firecracker")
}
// Cleanup destroys every VM this executor started, in parallel, and returns
// the destruction errors aggregated into one.
func (e *Executor) Cleanup(_ context.Context) error {
	var wg sync.WaitGroup
	errCh := make(chan error, len(e.handlers.Keys()))

	// Destroy each handler's VM concurrently; each goroutine reports exactly
	// one value on errCh (possibly nil).
	e.handlers.Iter(func(_ string, handler *executionHandler) bool {
		wg.Add(1)
		go func(h *executionHandler) {
			defer wg.Done()
			errCh <- h.destroy(time.Second * 10)
		}(handler)
		return true
	})

	// Close errCh once every destroy goroutine has reported, ending the
	// collection loop below.
	go func() {
		wg.Wait()
		close(errCh)
	}()

	var errs error
	for err := range errCh {
		errs = multierr.Append(errs, err)
	}
	zlog.Info("Cleaned up all firecracker resources")
	return errs
}
// newFirecrackerExecutionVM is an internal method called by Start to set up a new Firecracker VM
// for the job execution. It configures the VM based on the provided ExecutionRequest.
// This includes decoding engine specifications, setting up mounts and resource constraints.
// It then creates the VM but does not start it. The method returns a firecracker.Machine instance
// and an error if any part of the setup fails.
func (e *Executor) newFirecrackerExecutionVM(
	ctx context.Context,
	params *types.ExecutionRequest,
) (*firecracker.Machine, error) {
	fcArgs, err := DecodeSpec(params.EngineSpec)
	if err != nil {
		return nil, fmt.Errorf("failed to decode firecracker engine spec: %w", err)
	}
	// Map the execution request onto the SDK's machine configuration.
	fcConfig := firecracker.Config{
		VMID:            params.ExecutionID,
		SocketPath:      e.generateSocketPath(params.JobID, params.ExecutionID),
		KernelImagePath: fcArgs.KernelImage,
		InitrdPath:      fcArgs.Initrd,
		KernelArgs:      fcArgs.KernelArgs,
		MachineCfg: fcModels.MachineConfiguration{
			VcpuCount:  firecracker.Int64(int64(params.Resources.CPU.Cores)),
			MemSizeMib: firecracker.Int64(params.Resources.Memory.Size),
		},
	}
	// Attach the root filesystem plus any input/output volumes as drives.
	mounts, err := makeVMMounts(
		fcArgs.RootFileSystem,
		params.Inputs,
		params.Outputs,
		params.ResultsDir,
	)
	if err != nil {
		return nil, fmt.Errorf("failed to create VM mounts: %w", err)
	}
	fcConfig.Drives = mounts
	machine, err := e.client.CreateVM(ctx, fcConfig)
	if err != nil {
		return nil, fmt.Errorf("failed to create VM: %w", err)
	}
	// e.client.VMPassMMDs(ctx, machine, fcArgs.MMDSMessage)
	return machine, nil
}
// makeVMMounts creates the drive list for the VM from the input and output
// volumes in the execution request, rooted at rootFileSystem. It also creates
// the results directory if it does not exist. It returns the drives and an
// error if any part of the process fails.
func makeVMMounts(
	rootFileSystem string,
	inputs []*types.StorageVolumeExecutor,
	outputs []*types.StorageVolumeExecutor,
	resultsDir string,
) ([]fcModels.Drive, error) {
	var drives []fcModels.Drive
	// DrivesBuilder uses value semantics: AddDrive returns a NEW builder, so
	// its result must be reassigned — previously it was discarded, silently
	// dropping every input/output drive.
	drivesBuilder := firecracker.NewDrivesBuilder(rootFileSystem)
	for _, input := range inputs {
		drivesBuilder = drivesBuilder.AddDrive(input.Source, input.ReadOnly)
	}
	for _, output := range outputs {
		if output.Source == "" {
			return drives, fmt.Errorf("output source is empty")
		}
		if resultsDir == "" {
			return drives, fmt.Errorf("results directory is empty")
		}
		// MkdirAll is idempotent, so repeating it per output is harmless.
		if err := os.MkdirAll(resultsDir, os.ModePerm); err != nil {
			return drives, fmt.Errorf("failed to create results directory: %w", err)
		}
		// Output volumes must be writable.
		drivesBuilder = drivesBuilder.AddDrive(output.Source, false)
	}
	drives = drivesBuilder.Build()
	return drives, nil
}
// FindRunningVM finds the VM that is running the execution with the given ID.
// It returns the Machine instance if found, or an error if the VM is not found.
func (e *Executor) FindRunningVM(
	ctx context.Context,
	jobID string,
	executionID string,
) (*firecracker.Machine, error) {
	// VMs are identified by their API socket path, which encodes the job identifiers.
	return e.client.FindVM(ctx, e.generateSocketPath(jobID, executionID))
}
// generateSocketPath derives the VM's API socket path from the executor,
// job, and execution identifiers: "<socketDir>/<ID>_<jobID>_<executionID>.sock".
func (e *Executor) generateSocketPath(jobID string, executionID string) string {
	return socketDir + "/" + e.ID + "_" + jobID + "_" + executionID + ".sock"
}
package firecracker
import (
"context"
"fmt"
"sync/atomic"
"time"
"github.com/firecracker-microvm/firecracker-go-sdk"
"gitlab.com/nunet/device-management-service/types"
)
// executionHandler is a struct that holds the necessary information to manage the execution of a firecracker VM.
type executionHandler struct {
	// provided by the executor
	ID     string
	client *Client
	// meta data about the task
	JobID       string
	executionID string
	machine     *firecracker.Machine
	resultsDir  string // Directory to store execution results.
	// synchronization
	activeCh chan bool    // Closed by run() once the VM starts running.
	waitCh   chan bool    // Closed by run() when execution completes or fails.
	running  *atomic.Bool // Indicates if the VM is currently running.
	// result of the execution; populated by run() before waitCh is closed.
	result *types.ExecutionResult
}
// active returns true if the firecracker VM is running.
// Safe for concurrent use: the flag is an atomic bool.
func (h *executionHandler) active() bool {
	return h.running.Load()
}
// run boots the firecracker VM and blocks until it exits, recording the
// outcome in h.result before closing h.waitCh. The VM is destroyed on the
// way out regardless of how the execution ended.
func (h *executionHandler) run(ctx context.Context) {
	h.running.Store(true)
	defer func() {
		destroyTimeout := time.Second * 10
		// This handler manages a VM, not a container; the old message said
		// "container" and was misleading in logs.
		if err := h.destroy(destroyTimeout); err != nil {
			zlog.Sugar().Warnf("failed to destroy VM: %v\n", err)
		}
		h.running.Store(false)
		close(h.waitCh)
	}()

	// start the VM
	zlog.Sugar().Info("starting firecracker execution")
	if err := h.client.StartVM(ctx, h.machine); err != nil {
		h.result = types.NewFailedExecutionResult(fmt.Errorf("failed to start VM: %v", err))
		return
	}
	close(h.activeCh) // Indicate that the VM has started.

	if err := h.machine.Wait(ctx); err != nil {
		// Distinguish caller cancellation from a genuine wait failure.
		if ctx.Err() != nil {
			h.result = types.NewFailedExecutionResult(
				fmt.Errorf("context closed while waiting on VM: %v", err),
			)
			return
		}
		h.result = types.NewFailedExecutionResult(fmt.Errorf("failed to wait on VM: %v", err))
		return
	}
	h.result = types.NewExecutionResult(types.ExecutionStatusCodeSuccess)
}
// kill requests a graceful shutdown of the firecracker VM.
func (h *executionHandler) kill(ctx context.Context) error {
	return h.client.ShutdownVM(ctx, h.machine)
}
// destroy stops the firecracker VM and removes its resources.
// The same timeout bounds both the overall context and DestroyVM's
// force-kill wait.
func (h *executionHandler) destroy(timeout time.Duration) error {
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()
	return h.client.DestroyVM(ctx, h.machine, timeout)
}
package firecracker
import (
"gitlab.com/nunet/device-management-service/telemetry/logger"
)
// zlog is the package-level logger for the firecracker executor.
// Initialized directly rather than in init(): the assignment has no ordering
// requirements, and this avoids an init() with side effects.
var zlog = logger.New("executor.firecracker")
package firecracker
import (
"encoding/json"
"fmt"
"gitlab.com/nunet/device-management-service/types"
"gitlab.com/nunet/device-management-service/utils/validate"
)
const (
EngineKeyKernelImage = "kernel_image"
EngineKeyKernelArgs = "kernel_args"
EngineKeyRootFileSystem = "root_file_system"
EngineKeyMMDSMessage = "mmds_message"
)
// EngineSpec contains necessary parameters to execute a firecracker job.
// It is produced from a generic types.SpecConfig by DecodeSpec and
// constructed fluently via EngineBuilder.
type EngineSpec struct {
	// KernelImage is the path to the kernel image file.
	KernelImage string `json:"kernel_image,omitempty"`
	// Initrd is the path to the initial ramdisk file.
	Initrd string `json:"initrd_path,omitempty"`
	// KernelArgs is the kernel command line arguments.
	KernelArgs string `json:"kernel_args,omitempty"`
	// RootFileSystem is the path to the root file system.
	RootFileSystem string `json:"root_file_system,omitempty"`
	// MMDSMessage is the MMDS message to be sent to the Firecracker VM.
	MMDSMessage string `json:"mmds_message,omitempty"`
}
// Validate checks that the firecracker engine spec carries its mandatory
// fields: a root filesystem and a kernel image.
func (c EngineSpec) Validate() error {
	switch {
	case validate.IsBlank(c.RootFileSystem):
		return fmt.Errorf("invalid firecracker engine params: root_file_system cannot be empty")
	case validate.IsBlank(c.KernelImage):
		return fmt.Errorf("invalid firecracker engine params: kernel_image cannot be empty")
	default:
		return nil
	}
}
// DecodeSpec decodes a spec config into a firecracker engine spec.
// It round-trips the generic params through JSON into an EngineSpec struct
// and validates the result.
func DecodeSpec(spec *types.SpecConfig) (EngineSpec, error) {
	if !spec.IsType(types.ExecutorTypeFirecracker) {
		return EngineSpec{}, fmt.Errorf(
			"invalid firecracker engine type. expected %s, but received: %s",
			types.ExecutorTypeFirecracker,
			spec.Type,
		)
	}

	inputParams := spec.Params
	if inputParams == nil {
		return EngineSpec{}, fmt.Errorf("invalid firecracker engine params: params cannot be nil")
	}

	paramBytes, err := json.Marshal(inputParams)
	if err != nil {
		return EngineSpec{}, fmt.Errorf("failed to encode firecracker engine params: %w", err)
	}

	// Unmarshal into a value, not a *EngineSpec: a payload of JSON null would
	// leave a pointer nil and panic when dereferenced on return.
	var firecrackerSpec EngineSpec
	if err := json.Unmarshal(paramBytes, &firecrackerSpec); err != nil {
		return EngineSpec{}, fmt.Errorf("failed to decode firecracker engine params: %w", err)
	}

	return firecrackerSpec, firecrackerSpec.Validate()
}
// EngineBuilder constructs a firecracker engine SpecConfig using the Builder
// pattern. It wraps a *types.SpecConfig and accumulates parameters via the
// With* methods; Build returns the finished config.
type EngineBuilder struct {
	eb *types.SpecConfig
}
// NewFirecrackerEngineBuilder initializes a new EngineBuilder for a
// firecracker engine spec, seeded with the given root file system path.
func NewFirecrackerEngineBuilder(rootFileSystem string) *EngineBuilder {
	spec := types.NewSpecConfig(types.ExecutorTypeFirecracker)
	spec.WithParam(EngineKeyRootFileSystem, rootFileSystem)
	return &EngineBuilder{eb: spec}
}
// WithRootFileSystem sets the path of the VM's root file system.
// It returns the builder to allow chaining.
func (b *EngineBuilder) WithRootFileSystem(rootFS string) *EngineBuilder {
	b.eb.WithParam(EngineKeyRootFileSystem, rootFS)
	return b
}
// WithKernelImage sets the path of the VM's kernel image.
// It returns the builder to allow chaining.
func (b *EngineBuilder) WithKernelImage(image string) *EngineBuilder {
	b.eb.WithParam(EngineKeyKernelImage, image)
	return b
}
// WithKernelArgs sets the VM's kernel command line arguments.
// It returns the builder to allow chaining.
func (b *EngineBuilder) WithKernelArgs(args string) *EngineBuilder {
	b.eb.WithParam(EngineKeyKernelArgs, args)
	return b
}
// WithMMDSMessage sets the MMDS message delivered to the Firecracker VM.
// It returns the builder to allow chaining.
func (b *EngineBuilder) WithMMDSMessage(msg string) *EngineBuilder {
	b.eb.WithParam(EngineKeyMMDSMessage, msg)
	return b
}
// Build returns the final SpecConfig object accumulated by the builder's
// With* methods.
func (b *EngineBuilder) Build() *types.SpecConfig {
	return b.eb
}
package backgroundtasks
import (
"github.com/uptrace/opentelemetry-go-extra/otelzap"
"gitlab.com/nunet/device-management-service/telemetry/logger"
)
var zlog *otelzap.Logger
func init() {
zlog = logger.OtelZapLogger("background_tasks")
}
package backgroundtasks
import (
"sort"
"sync"
"time"
)
// Scheduler orchestrates the execution of tasks based on their triggers and priority.
type Scheduler struct {
	tasks           map[int]*Task // Map of tasks by their ID; guarded by mu.
	runningTasks    map[int]bool  // Map to keep track of running tasks; guarded by mu.
	ticker          *time.Ticker  // Ticker for periodic checks of task triggers.
	stopChan        chan struct{} // Channel to signal stopping the scheduler.
	maxRunningTasks int           // Maximum number of tasks that can run concurrently.
	lastTaskID      int           // Counter for assigning unique IDs to tasks.
	mu              sync.Mutex    // Mutex to protect access to task maps.
}
// NewScheduler creates a new Scheduler with a specified limit on running
// tasks. The scheduler checks task triggers once per second after Start is
// called.
func NewScheduler(maxRunningTasks int) *Scheduler {
	// lastTaskID starts at its zero value, so the first task gets ID 0.
	return &Scheduler{
		tasks:           map[int]*Task{},
		runningTasks:    map[int]bool{},
		ticker:          time.NewTicker(time.Second),
		stopChan:        make(chan struct{}),
		maxRunningTasks: maxRunningTasks,
	}
}
// AddTask registers a task with the scheduler: it assigns a fresh ID,
// enables the task, and resets all of its triggers. The task is returned
// for convenience.
func (s *Scheduler) AddTask(task *Task) *Task {
	s.mu.Lock()
	defer s.mu.Unlock()

	task.ID = s.lastTaskID
	s.lastTaskID++
	task.Enabled = true
	for _, trigger := range task.Triggers {
		trigger.Reset()
	}
	s.tasks[task.ID] = task
	return task
}
// RemoveTask removes a task from the scheduler.
// Removing an unknown taskID is a no-op.
func (s *Scheduler) RemoveTask(taskID int) {
	s.mu.Lock()
	defer s.mu.Unlock()
	delete(s.tasks, taskID)
}
// Start begins the scheduler's task execution loop.
// The loop runs in its own goroutine and exits when Stop is called.
func (s *Scheduler) Start() {
	go func() {
		for {
			select {
			case <-s.stopChan:
				return
			case <-s.ticker.C:
				// Evaluate triggers and launch ready tasks once per tick.
				s.runTasks()
			}
		}
	}()
}
// runningTasksCount returns how many tasks are currently marked as running.
func (s *Scheduler) runningTasksCount() int {
	s.mu.Lock()
	defer s.mu.Unlock()

	var count int
	for _, isRunning := range s.runningTasks {
		if isRunning {
			count++
		}
	}
	return count
}
// runTasks checks triggers and launches ready tasks in priority order.
//
// All accesses to s.tasks and s.runningTasks happen under s.mu: the previous
// version read and wrote both maps without the lock while AddTask/RemoveTask/
// runTask mutate them under it — a data race. Trigger evaluation and task
// launches happen outside the lock so arbitrary trigger code never blocks
// other scheduler operations.
func (s *Scheduler) runTasks() {
	// Snapshot the current task set under the lock.
	s.mu.Lock()
	sortedTasks := make([]*Task, 0, len(s.tasks))
	for _, task := range s.tasks {
		sortedTasks = append(sortedTasks, task)
	}
	s.mu.Unlock()

	// Highest priority first.
	sort.Slice(sortedTasks, func(i, j int) bool {
		return sortedTasks[i].Priority > sortedTasks[j].Priority
	})

	for _, task := range sortedTasks {
		s.mu.Lock()
		running := s.runningTasks[task.ID]
		s.mu.Unlock()
		if !task.Enabled || running {
			continue
		}
		// A task with no triggers can never fire again; drop it.
		if len(task.Triggers) == 0 {
			s.RemoveTask(task.ID)
			continue
		}
		for _, trigger := range task.Triggers {
			if trigger.IsReady() && s.runningTasksCount() < s.maxRunningTasks {
				s.mu.Lock()
				s.runningTasks[task.ID] = true
				s.mu.Unlock()
				go s.runTask(task.ID)
				trigger.Reset()
				break
			}
		}
	}
}
// Stop signals the scheduler loop to exit and releases the ticker.
// Without ticker.Stop the ticker kept firing (and holding resources)
// after the loop had already returned.
func (s *Scheduler) Stop() {
	s.ticker.Stop()
	close(s.stopChan)
}
// runTask executes a task, applying its retry policy, and appends an
// Execution record to the task's history when it finishes.
func (s *Scheduler) runTask(taskID int) {
	// Always clear the running flag, whatever path this function exits by.
	defer func() {
		s.mu.Lock()
		defer s.mu.Unlock()
		s.runningTasks[taskID] = false
	}()
	// Read the task under the lock; it may have been removed concurrently,
	// in which case the previous unchecked map read yielded a nil *Task and
	// the field accesses below would panic.
	s.mu.Lock()
	task, ok := s.tasks[taskID]
	s.mu.Unlock()
	if !ok || task == nil {
		return
	}
	execution := Execution{StartedAt: time.Now()}
	// Deferred so the final Status/EndedAt values are recorded regardless
	// of which branch completes the run.
	defer func() {
		s.mu.Lock()
		task.ExecutionHist = append(task.ExecutionHist, execution)
		s.tasks[taskID] = task
		s.mu.Unlock()
	}()
	// One initial attempt plus MaxRetries retries.
	for i := 0; i < task.RetryPolicy.MaxRetries+1; i++ {
		err := runTaskWithRetry(task.Function, task.Args, task.RetryPolicy.Delay)
		if err == nil {
			execution.Status = "SUCCESS"
			execution.EndedAt = time.Now()
			return
		}
		execution.Error = err.Error()
	}
	execution.Status = "FAILED"
	execution.EndedAt = time.Now()
}
// runTaskWithRetry attempts to execute a task with a retry policy.
func runTaskWithRetry(
fn func(args interface{}) error,
args []interface{},
delay time.Duration,
) error {
err := fn(args)
if err != nil {
time.Sleep(delay)
return err
}
return nil
}
package backgroundtasks
import (
"time"
"github.com/robfig/cron/v3"
)
// Trigger interface defines a method to check if a trigger condition is met.
type Trigger interface {
	IsReady() bool // Returns true if the trigger condition is met.
	Reset()        // Resets the trigger state (reference point for the next check).
}

// PeriodicTrigger triggers at regular intervals or based on a cron expression.
type PeriodicTrigger struct {
	Interval      time.Duration // Interval for periodic triggering.
	CronExpr      string        // Cron expression for triggering (standard 5-field syntax).
	lastTriggered time.Time     // Last time the trigger was activated; set by Reset.
}
// IsReady checks if the trigger should activate, first by elapsed interval
// and then by cron expression.
func (t *PeriodicTrigger) IsReady() bool {
	// Only apply the interval rule when an interval is configured. Without
	// this guard a cron-only trigger (Interval == 0) fired on every
	// scheduler tick, because lastTriggered.Add(0) is always in the past.
	if t.Interval > 0 && t.lastTriggered.Add(t.Interval).Before(time.Now()) {
		return true
	}
	// Trigger based on cron expression.
	if t.CronExpr != "" {
		cronExpr, err := cron.ParseStandard(t.CronExpr)
		if err != nil {
			zlog.Sugar().Errorf("Error parsing CronExpr: %v", err)
			return false
		}
		nextCronTriggerTime := cronExpr.Next(t.lastTriggered)
		return nextCronTriggerTime.Before(time.Now())
	}
	return false
}
// Reset updates the last triggered time to the current time, restarting the
// interval/cron window from now.
func (t *PeriodicTrigger) Reset() {
	t.lastTriggered = time.Now()
}
// PeriodicTriggerWithJitter triggers at regular intervals or based on a cron
// expression, with a caller-supplied jitter added to the interval.
// (The previous comment was copy-pasted from PeriodicTrigger.)
type PeriodicTriggerWithJitter struct {
	Interval      time.Duration // Interval for periodic triggering.
	CronExpr      string        // Cron expression for triggering.
	lastTriggered time.Time     // Last time the trigger was activated; set by Reset.
	Jitter        func() time.Duration // Returns the jitter added to Interval on each check.
}
// IsReady checks if the trigger should activate, first by the jittered
// interval and then by cron expression.
func (t *PeriodicTriggerWithJitter) IsReady() bool {
	// Tolerate a nil Jitter func so a partially initialized trigger cannot
	// panic the scheduler loop.
	var jitter time.Duration
	if t.Jitter != nil {
		jitter = t.Jitter()
	}
	// Only apply the interval rule when an interval is configured; a
	// cron-only trigger (Interval == 0) otherwise fired on every tick.
	if t.Interval > 0 && t.lastTriggered.Add(t.Interval+jitter).Before(time.Now()) {
		return true
	}
	// Trigger based on cron expression.
	if t.CronExpr != "" {
		cronExpr, err := cron.ParseStandard(t.CronExpr)
		if err != nil {
			zlog.Sugar().Errorf("Error parsing CronExpr: %v", err)
			return false
		}
		nextCronTriggerTime := cronExpr.Next(t.lastTriggered)
		return nextCronTriggerTime.Before(time.Now())
	}
	return false
}
// Reset updates the last triggered time to the current time, restarting the
// interval/cron window from now.
func (t *PeriodicTriggerWithJitter) Reset() {
	t.lastTriggered = time.Now()
}
// EventTrigger fires when an external event is signaled through a channel.
type EventTrigger struct {
	Trigger chan bool // Channel to signal an event.
}

// IsReady reports, without blocking, whether an event is pending on the
// channel. A pending signal is consumed by the check.
func (t *EventTrigger) IsReady() bool {
	select {
	case <-t.Trigger:
		return true
	default:
	}
	return false
}

// Reset is a no-op: the channel's producer owns the trigger state.
func (t *EventTrigger) Reset() {}
// OneTimeTrigger triggers once after a specified delay.
type OneTimeTrigger struct {
Delay time.Duration // The delay after which to trigger.
registeredAt time.Time // Time when the trigger was set.
}
// Reset sets the trigger registration time to the current time.
func (t *OneTimeTrigger) Reset() {
t.registeredAt = time.Now()
}
// IsReady checks if the current time has passed the delay period.
func (t *OneTimeTrigger) IsReady() bool {
return t.registeredAt.Add(t.Delay).Before(time.Now())
}
package libp2p
import (
"bytes"
"context"
"errors"
"fmt"
"strings"
"sync"
"github.com/gin-gonic/gin"
"github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/peerstore"
"github.com/multiformats/go-multiaddr"
commonproto "gitlab.com/nunet/device-management-service/proto/generated/v1/common"
"google.golang.org/protobuf/proto"
)
// Bootstrap prepares the DHT for bootstrapping and connects the host to the
// given bootstrap peers. All peers are dialed concurrently; individual dial
// failures are logged but do not fail the bootstrap.
func (l *Libp2p) Bootstrap(ctx context.Context, bootstrapPeers []multiaddr.Multiaddr) error {
	if err := l.DHT.Bootstrap(ctx); err != nil {
		return fmt.Errorf("failed to prepare this node for bootstraping: %w", err)
	}
	if len(bootstrapPeers) == 0 {
		return nil
	}
	// Dial every bootstrap peer in parallel and wait for all attempts.
	var wg sync.WaitGroup
	for _, addr := range bootstrapPeers {
		wg.Add(1)
		go func(peerAddr multiaddr.Multiaddr) {
			defer wg.Done()
			info, convErr := peer.AddrInfoFromP2pAddr(peerAddr)
			if convErr != nil {
				zlog.Sugar().Errorf("failed to convert multi addr to addr info %v - %v", peerAddr, convErr)
				return
			}
			if dialErr := l.Host.Connect(ctx, *info); dialErr != nil {
				zlog.Sugar().Errorf("failed to connect to bootstrap node %s - %v", info.ID.String(), dialErr)
				return
			}
			zlog.Sugar().Infof("connected to Bootstrap Node %s", info.ID.String())
		}(addr)
	}
	wg.Wait()
	return nil
}
// dhtValidator validates records written into the DHT under the custom
// namespace.
type dhtValidator struct {
	PS              peerstore.Peerstore
	customNamespace string
}

// Validate validates an item placed into the dht.
//
// A record is accepted when it is empty (deletion), its key carries the
// expected namespace prefix, and the envelope signature verifies against
// the public key embedded in the advertisement itself.
func (d dhtValidator) Validate(key string, value []byte) error {
	// empty value is considered deleting an item from the dht
	if len(value) == 0 {
		return nil
	}
	if !strings.HasPrefix(key, d.customNamespace) {
		return errors.New("invalid key namespace")
	}
	// verify signature
	var envelope commonproto.Advertisement
	err := proto.Unmarshal(value, &envelope)
	if err != nil {
		return fmt.Errorf("failed to unmarshal envelope: %w", err)
	}
	pubKey, err := crypto.UnmarshalSecp256k1PublicKey(envelope.PublicKey)
	if err != nil {
		return fmt.Errorf("failed to unmarshal public key: %w", err)
	}
	// Rebuild the exact byte sequence that Advertise signed. Any change here
	// must be mirrored in Advertise or signatures stop verifying.
	// NOTE(review): {byte(envelope.Timestamp)} truncates the int64 Unix
	// timestamp to its low byte, so the timestamp is not fully covered by
	// the signature — confirm whether that is intended.
	concatenatedBytes := bytes.Join([][]byte{
		[]byte(envelope.PeerId),
		{byte(envelope.Timestamp)},
		envelope.Data,
		envelope.PublicKey,
	}, nil)
	ok, err := pubKey.Verify(concatenatedBytes, envelope.Signature)
	if err != nil {
		return fmt.Errorf("failed to verify envelope: %w", err)
	}
	if !ok {
		return errors.New("failed to verify envelope, public key didn't sign payload")
	}
	return nil
}

// Select always picks the first record; no ordering between candidate
// records is defined for this namespace.
func (dhtValidator) Select(_ string, _ [][]byte) (int, error) { return 0, nil }
// TODO remove the below when network package is fully implemented

// UpdateKadDHT is a stub; it only logs that it was called.
func (l *Libp2p) UpdateKadDHT() {
	zlog.Warn("UpdateKadDHT: Stub")
}

// ListKadDHTPeers is a stub; it logs and always returns (nil, nil).
func (l *Libp2p) ListKadDHTPeers(_ context.Context, _ *gin.Context) ([]string, error) {
	zlog.Warn("ListKadDHTPeers: Stub")
	return nil, nil
}
package libp2p
import (
"context"
"fmt"
"os"
"github.com/libp2p/go-libp2p/core/discovery"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
dutil "github.com/libp2p/go-libp2p/p2p/discovery/util"
)
// DiscoverDialPeers discovers peers via the rendezvous point, filters out
// unusable entries (no listen address, or our own host) and dials the rest.
func (l *Libp2p) DiscoverDialPeers(ctx context.Context) error {
	foundPeers, err := l.findPeersFromRendezvousDiscovery(ctx)
	if err != nil {
		return err
	}
	// Keep the previous peer list when discovery comes back empty.
	if len(foundPeers) > 0 {
		l.discoveredPeers = foundPeers
	}
	// Filter out peers with no listening addresses and the self host.
	l.discoveredPeers = PeerPassFilter(l.discoveredPeers, NoAddrIDFilter{ID: l.Host.ID()})
	l.dialPeers(ctx)
	return nil
}
// advertiseForRendezvousDiscovery advertises this node in the DHT under the
// configured rendezvous string so other peers can discover it.
// (Parameter renamed from `context`, which shadowed the context package.)
func (l *Libp2p) advertiseForRendezvousDiscovery(ctx context.Context) error {
	_, err := l.discovery.Advertise(ctx, l.config.Rendezvous)
	return err
}
// findPeersFromRendezvousDiscovery queries the discovery service for peers
// advertised under the configured rendezvous string, bounded by the
// configured peer-count limit.
func (l *Libp2p) findPeersFromRendezvousDiscovery(ctx context.Context) ([]peer.AddrInfo, error) {
	foundPeers, err := dutil.FindPeers(
		ctx,
		l.discovery,
		l.config.Rendezvous,
		discovery.Limit(l.config.PeerCountDiscoveryLimit),
	)
	if err != nil {
		return nil, fmt.Errorf("failed to discover peers: %w", err)
	}
	return foundPeers, nil
}
// dialPeers dials every discovered peer we are not already connected to.
// Dial failures are best-effort and only logged in verbose debug mode.
func (l *Libp2p) dialPeers(ctx context.Context) {
	// Look up the debug flag once: it cannot change mid-loop and the
	// original queried the environment up to twice per peer.
	_, debugMode := os.LookupEnv("NUNET_DEBUG_VERBOSE")
	for _, p := range l.discoveredPeers {
		if p.ID == l.Host.ID() {
			continue
		}
		if l.Host.Network().Connectedness(p.ID) != network.Connected {
			_, err := l.Host.Network().DialPeer(ctx, p.ID)
			if err != nil {
				if debugMode {
					zlog.Sugar().Debugf("couldn't establish connection with: %s - error: %v", p.ID.String(), err)
				}
				continue
			}
			if debugMode {
				zlog.Sugar().Debugf("connected with: %s", p.ID.String())
			}
		}
	}
}
package libp2p
import (
"github.com/libp2p/go-libp2p/core/connmgr"
"github.com/libp2p/go-libp2p/core/control"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/multiformats/go-multiaddr"
mafilt "github.com/whyrusleeping/multiaddr-filter"
)
// defaultServerFilters lists private/reserved IPv4 and IPv6 CIDR ranges that
// a publicly reachable server should neither announce nor accept dials from.
var defaultServerFilters = []string{
	"/ip4/10.0.0.0/ipcidr/8",
	"/ip4/100.64.0.0/ipcidr/10",
	"/ip4/169.254.0.0/ipcidr/16",
	"/ip4/172.16.0.0/ipcidr/12",
	"/ip4/192.0.0.0/ipcidr/24",
	"/ip4/192.0.2.0/ipcidr/24",
	"/ip4/192.168.0.0/ipcidr/16",
	"/ip4/198.18.0.0/ipcidr/15",
	"/ip4/198.51.100.0/ipcidr/24",
	"/ip4/203.0.113.0/ipcidr/24",
	"/ip4/240.0.0.0/ipcidr/4",
	"/ip6/100::/ipcidr/64",
	"/ip6/2001:2::/ipcidr/48",
	"/ip6/2001:db8::/ipcidr/32",
	"/ip6/fc00::/ipcidr/7",
	"/ip6/fe80::/ipcidr/10",
}

// PeerFilter is an interface for filtering peers
// satisfaction of filter criteria allows the peer to pass
type PeerFilter interface {
	satisfies(p peer.AddrInfo) bool
}

// NoAddrIDFilter filters out peers with no listening addresses
// and a peer with a specific ID
type NoAddrIDFilter struct {
	ID peer.ID // the peer ID to exclude (typically the local host itself)
}

// satisfies passes peers that have at least one address and are not f.ID.
func (f NoAddrIDFilter) satisfies(p peer.AddrInfo) bool {
	return len(p.Addrs) > 0 && p.ID != f.ID
}
// PeerPassFilter returns the subset of peers that satisfy the given filter.
func PeerPassFilter(peers []peer.AddrInfo, pf PeerFilter) []peer.AddrInfo {
	var passed []peer.AddrInfo
	for _, candidate := range peers {
		if pf.satisfies(candidate) {
			passed = append(passed, candidate)
		}
	}
	return passed
}
// filtersConnectionGater adapts a multiaddr.Filters deny-list into a libp2p
// connection gater: any address matching a filter is refused at dial and
// accept time.
type filtersConnectionGater multiaddr.Filters

// Compile-time check that the gater satisfies connmgr.ConnectionGater.
var _ connmgr.ConnectionGater = (*filtersConnectionGater)(nil)

// InterceptAddrDial blocks outbound dials to filtered addresses.
func (f *filtersConnectionGater) InterceptAddrDial(_ peer.ID, addr multiaddr.Multiaddr) (allow bool) {
	return !(*multiaddr.Filters)(f).AddrBlocked(addr)
}

// InterceptPeerDial allows all peers; filtering happens per-address.
func (f *filtersConnectionGater) InterceptPeerDial(_ peer.ID) (allow bool) {
	return true
}

// InterceptAccept blocks inbound connections from filtered remote addresses.
func (f *filtersConnectionGater) InterceptAccept(connAddr network.ConnMultiaddrs) (allow bool) {
	return !(*multiaddr.Filters)(f).AddrBlocked(connAddr.RemoteMultiaddr())
}

// InterceptSecured re-checks the remote address after the security handshake.
func (f *filtersConnectionGater) InterceptSecured(_ network.Direction, _ peer.ID, connAddr network.ConnMultiaddrs) (allow bool) {
	return !(*multiaddr.Filters)(f).AddrBlocked(connAddr.RemoteMultiaddr())
}

// InterceptUpgraded allows every fully upgraded connection.
func (f *filtersConnectionGater) InterceptUpgraded(_ network.Conn) (allow bool, reason control.DisconnectReason) {
	return true, 0
}
// makeAddrsFactory builds a libp2p address factory from three lists:
// announce (replaces the host's own addresses when non-empty), appendAnnouce
// (always added, duplicates of announce skipped) and noAnnounce (suppressed,
// either as exact addresses or /ipcidr masks).
//
// NOTE(review): on any malformed address the function returns a nil factory
// instead of an error — callers must tolerate nil or pre-validate input.
func makeAddrsFactory(announce []string, appendAnnouce []string, noAnnounce []string) func([]multiaddr.Multiaddr) []multiaddr.Multiaddr {
	var err error                     // To assign to the slice in the for loop
	existing := make(map[string]bool) // To avoid duplicates
	annAddrs := make([]multiaddr.Multiaddr, len(announce))
	for i, addr := range announce {
		annAddrs[i], err = multiaddr.NewMultiaddr(addr)
		if err != nil {
			return nil
		}
		existing[addr] = true
	}
	appendAnnAddrs := make([]multiaddr.Multiaddr, 0)
	for _, addr := range appendAnnouce {
		if existing[addr] {
			// skip AppendAnnounce that is on the Announce list already
			continue
		}
		appendAddr, err := multiaddr.NewMultiaddr(addr)
		if err != nil {
			return nil
		}
		appendAnnAddrs = append(appendAnnAddrs, appendAddr)
	}
	// noAnnounce entries that parse as masks become CIDR filters; the rest
	// are stored for exact-match suppression.
	filters := multiaddr.NewFilters()
	noAnnAddrs := map[string]bool{}
	for _, addr := range noAnnounce {
		f, err := mafilt.NewMask(addr)
		if err == nil {
			filters.AddFilter(*f, multiaddr.ActionDeny)
			continue
		}
		maddr, err := multiaddr.NewMultiaddr(addr)
		if err != nil {
			return nil
		}
		noAnnAddrs[string(maddr.Bytes())] = true
	}
	return func(allAddrs []multiaddr.Multiaddr) []multiaddr.Multiaddr {
		var addrs []multiaddr.Multiaddr
		// A non-empty announce list fully replaces the detected addresses.
		if len(annAddrs) > 0 {
			addrs = annAddrs
		} else {
			addrs = allAddrs
		}
		addrs = append(addrs, appendAnnAddrs...)
		var out []multiaddr.Multiaddr
		for _, maddr := range addrs {
			// check for exact matches
			ok := noAnnAddrs[string(maddr.Bytes())]
			// check for /ipcidr matches
			if !ok && !filters.AddrBlocked(maddr) {
				out = append(out, maddr)
			}
		}
		return out
	}
}
package libp2p
import (
"errors"
"sync"
"github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/protocol"
"gitlab.com/nunet/device-management-service/types"
)
// StreamHandler is a function type that processes data from a stream.
type StreamHandler func(stream network.Stream)

// HandlerRegistry manages the registration of stream handlers for different protocols.
// All maps are guarded by mu.
type HandlerRegistry struct {
	host          host.Host                        // host on which stream handlers are installed
	handlers      map[protocol.ID]StreamHandler    // raw stream handlers keyed by protocol
	bytesHandlers map[protocol.ID]func(data []byte) // byte-payload callbacks keyed by protocol
	mu            sync.RWMutex
}
// NewHandlerRegistry creates a handler registry bound to the given host.
func NewHandlerRegistry(host host.Host) *HandlerRegistry {
	registry := &HandlerRegistry{
		host:          host,
		handlers:      map[protocol.ID]StreamHandler{},
		bytesHandlers: map[protocol.ID]func(data []byte){},
	}
	return registry
}
// RegisterHandlerWithStreamCallback registers a stream handler for a specific
// protocol. It fails if a handler is already registered for that protocol.
func (r *HandlerRegistry) RegisterHandlerWithStreamCallback(messageType types.MessageType, handler StreamHandler) error {
	r.mu.Lock()
	defer r.mu.Unlock()
	protoID := protocol.ID(messageType)
	if _, exists := r.handlers[protoID]; exists {
		return errors.New("stream with this protocol is already registered")
	}
	r.handlers[protoID] = handler
	r.host.SetStreamHandler(protoID, network.StreamHandler(handler))
	return nil
}
// RegisterHandlerWithBytesCallback registers a stream handler s for a
// specific protocol, remembering handler as the byte-payload callback that s
// is expected to feed. It fails if the protocol already has a callback.
func (r *HandlerRegistry) RegisterHandlerWithBytesCallback(messageType types.MessageType, s StreamHandler, handler func(data []byte)) error {
	r.mu.Lock()
	defer r.mu.Unlock()
	protoID := protocol.ID(messageType)
	if _, exists := r.bytesHandlers[protoID]; exists {
		return errors.New("stream with this protocol is already registered")
	}
	r.bytesHandlers[protoID] = handler
	r.host.SetStreamHandler(protoID, network.StreamHandler(s))
	return nil
}
// SendMessageToLocalHandler delivers data to the locally registered bytes
// handler for the given message type. Unknown message types are ignored.
func (r *HandlerRegistry) SendMessageToLocalHandler(messageType types.MessageType, data []byte) {
	// Read-only map lookup: a read lock suffices on the RWMutex; the
	// original took the exclusive write lock for no reason.
	r.mu.RLock()
	defer r.mu.RUnlock()
	h, ok := r.bytesHandlers[protocol.ID(messageType)]
	if !ok {
		return
	}
	// we need this goroutine to avoid blocking the caller goroutine
	go h(data)
}
package libp2p
import (
"context"
"fmt"
"strings"
"time"
"github.com/libp2p/go-libp2p"
dht "github.com/libp2p/go-libp2p-kad-dht"
pubsub "github.com/libp2p/go-libp2p-pubsub"
"github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/pnet"
"github.com/libp2p/go-libp2p/core/protocol"
"github.com/libp2p/go-libp2p/core/routing"
"github.com/libp2p/go-libp2p/p2p/host/autorelay"
"github.com/libp2p/go-libp2p/p2p/host/peerstore/pstoremem"
"github.com/libp2p/go-libp2p/p2p/net/connmgr"
"github.com/libp2p/go-libp2p/p2p/protocol/circuitv2/relay"
"github.com/libp2p/go-libp2p/p2p/security/noise"
libp2ptls "github.com/libp2p/go-libp2p/p2p/security/tls"
quic "github.com/libp2p/go-libp2p/p2p/transport/quic"
"github.com/libp2p/go-libp2p/p2p/transport/tcp"
ws "github.com/libp2p/go-libp2p/p2p/transport/websocket"
webtransport "github.com/libp2p/go-libp2p/p2p/transport/webtransport"
"github.com/multiformats/go-multiaddr"
"github.com/spf13/afero"
mafilt "github.com/whyrusleeping/multiaddr-filter"
"gitlab.com/nunet/device-management-service/types"
)
// NewHost returns a new libp2p host with DHT, gossipsub and related settings
// applied from config. fs is used to load the swarm key when a private
// network is configured.
func NewHost(ctx context.Context, config *types.Libp2pConfig, fs afero.Fs) (host.Host, *dht.IpfsDHT, *pubsub.PubSub, error) {
	var idht *dht.IpfsDHT
	// Renamed from `connmgr`, which shadowed the imported connmgr package.
	cmgr, err := connmgr.NewConnManager(
		100,
		400,
		connmgr.WithGracePeriod(time.Duration(config.GracePeriodMs)*time.Millisecond),
	)
	if err != nil {
		return nil, nil, nil, err
	}
	filter := multiaddr.NewFilters()
	for _, s := range defaultServerFilters {
		f, err := mafilt.NewMask(s)
		if err != nil {
			zlog.Sugar().Errorf("incorrectly formatted address filter in config: %s - %v", s, err)
			// Skip the malformed entry: dereferencing the nil mask below
			// would panic.
			continue
		}
		filter.AddFilter(*f, multiaddr.ActionDeny)
	}
	ps, err := pstoremem.NewPeerstore()
	if err != nil {
		return nil, nil, nil, err
	}
	var libp2pOpts []libp2p.Option
	baseOpts := []dht.Option{
		dht.ProtocolPrefix(protocol.ID(config.DHTPrefix)),
		// NOTE(review): dhtValidator is built without customNamespace, so its
		// prefix check always passes — confirm whether that is intended.
		dht.NamespacedValidator(strings.ReplaceAll(config.CustomNamespace, "/", ""), dhtValidator{PS: ps}),
		dht.Mode(dht.ModeServer),
	}
	if config.PrivateNetwork.WithSwarmKey {
		psk, err := configureSwarmKey(fs)
		if err != nil {
			return nil, nil, nil, fmt.Errorf("failed to configure swarm key: %v", err)
		}
		libp2pOpts = append(libp2pOpts, libp2p.PrivateNetwork(psk))
		// guarantee that outer connection will be refused
		pnet.ForcePrivateNetwork = true
	} else {
		// enable quic (it does not work with pnet enabled)
		libp2pOpts = append(libp2pOpts, libp2p.Transport(quic.NewTransport))
		libp2pOpts = append(libp2pOpts, libp2p.Transport(webtransport.New))
		// for some reason, ForcePrivateNetwork was equal to true even without being set to true
		pnet.ForcePrivateNetwork = false
	}
	libp2pOpts = append(libp2pOpts, libp2p.ListenAddrStrings(config.ListenAddress...),
		libp2p.Identity(config.PrivateKey),
		libp2p.Routing(func(h host.Host) (routing.PeerRouting, error) {
			idht, err = dht.New(ctx, h, baseOpts...)
			return idht, err
		}),
		libp2p.Peerstore(ps),
		libp2p.Security(libp2ptls.ID, libp2ptls.New),
		libp2p.Security(noise.ID, noise.New),
		// Do not use DefaulTransports as we can not enable Quic when pnet
		libp2p.Transport(tcp.NewTCPTransport),
		libp2p.Transport(ws.New),
		libp2p.EnableNATService(),
		libp2p.ConnectionManager(cmgr),
		libp2p.EnableRelay(),
		libp2p.EnableHolePunching(),
		libp2p.EnableRelayService(
			relay.WithResources(
				relay.Resources{
					MaxReservations:        256,
					MaxCircuits:            32,
					BufferSize:             4096,
					MaxReservationsPerPeer: 8,
					MaxReservationsPerIP:   16,
				},
			),
			relay.WithLimit(&relay.RelayLimit{
				Duration: 5 * time.Minute,
				Data:     1 << 21, // 2 MiB
			}),
		),
		// Candidate relays are fed through the package-level newPeer channel.
		libp2p.EnableAutoRelayWithPeerSource(
			func(ctx context.Context, num int) <-chan peer.AddrInfo {
				r := make(chan peer.AddrInfo)
				go func() {
					defer close(r)
					for i := 0; i < num; i++ {
						select {
						case p := <-newPeer:
							select {
							case r <- p:
							case <-ctx.Done():
								return
							}
						case <-ctx.Done():
							return
						}
					}
				}()
				return r
			},
			autorelay.WithBootDelay(time.Minute),
			autorelay.WithBackoff(30*time.Second),
			autorelay.WithMinCandidates(2),
			autorelay.WithMaxCandidates(3),
			autorelay.WithNumRelays(2),
		),
	)
	if config.Server {
		// Servers hide private ranges from announcements and gate dials.
		libp2pOpts = append(libp2pOpts, libp2p.AddrsFactory(makeAddrsFactory([]string{}, []string{}, defaultServerFilters)))
		libp2pOpts = append(libp2pOpts, libp2p.ConnectionGater((*filtersConnectionGater)(filter)))
	} else {
		libp2pOpts = append(libp2pOpts, libp2p.NATPortMap())
	}
	host, err := libp2p.New(libp2pOpts...)
	if err != nil {
		return nil, nil, nil, err
	}
	optsPS := []pubsub.Option{pubsub.WithMessageSigning(true), pubsub.WithMaxMessageSize(config.GossipMaxMessageSize)}
	gossip, err := pubsub.NewGossipSub(ctx, host, optsPS...)
	// gossip, err := pubsub.NewGossipSubWithRouter(ctx, host, pubsub.DefaultGossipSubRouter(host), optsPS...)
	if err != nil {
		return nil, nil, nil, err
	}
	return host, idht, gossip, nil
}
package libp2p
import (
"github.com/libp2p/go-libp2p/core/peer"
"github.com/uptrace/opentelemetry-go-extra/otelzap"
"gitlab.com/nunet/device-management-service/telemetry/logger"
)
// TODO: pass the logger to the constructor and remove from here
var (
	// zlog is the package-level logger, wired up in init.
	zlog *otelzap.Logger
	// newPeer feeds candidate relay peers to the autorelay peer source in NewHost.
	newPeer = make(chan peer.AddrInfo)
)

// init configures the package logger for the network.libp2p subsystem.
func init() {
	zlog = logger.OtelZapLogger("network.libp2p")
}
package libp2p
import (
"bufio"
"bytes"
"context"
"crypto/sha256"
"encoding/binary"
"errors"
"fmt"
"io"
"strings"
"sync"
"time"
"github.com/ipfs/go-cid"
kbucket "github.com/libp2p/go-libp2p-kbucket"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/p2p/protocol/ping"
"gitlab.com/nunet/device-management-service/types"
pubsub "github.com/libp2p/go-libp2p-pubsub"
"github.com/multiformats/go-multiaddr"
"github.com/multiformats/go-multihash"
"github.com/spf13/afero"
"google.golang.org/protobuf/proto"
dht "github.com/libp2p/go-libp2p-kad-dht"
libp2pdiscovery "github.com/libp2p/go-libp2p/core/discovery"
"github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peerstore"
"github.com/libp2p/go-libp2p/core/protocol"
drouting "github.com/libp2p/go-libp2p/p2p/discovery/routing"
bt "gitlab.com/nunet/device-management-service/internal/background_tasks"
commonproto "gitlab.com/nunet/device-management-service/proto/generated/v1/common"
)
const (
	// MB is one mebibyte in bytes.
	MB = 1024 * 1024
	// MaxMessageLengthMB caps the size of a single stream message, in MB.
	MaxMessageLengthMB = 10
)
// Libp2p contains the configuration for a Libp2p instance.
//
// TODO-suggestion: maybe we should call it something else like Libp2pPeer,
// Libp2pHost or just Peer (callers would use libp2p.Peer...)
type Libp2p struct {
	Host host.Host     // the underlying libp2p host
	DHT  *dht.IpfsDHT  // Kademlia DHT used for routing, discovery and records
	PS   peerstore.Peerstore
	pubsub            *pubsub.PubSub
	pubsubTopics      map[string]*pubsub.Topic        // joined topics keyed by name
	topicSubscription map[string]*pubsub.Subscription // active subscriptions keyed by topic
	topicMux          sync.RWMutex                    // guards the two maps above
	// a list of peers discovered by discovery
	discoveredPeers []peer.AddrInfo
	discovery       libp2pdiscovery.Discovery
	// services
	pingService *ping.PingService
	// tasks
	discoveryTask   *bt.Task // periodic peer-discovery task registered in Start
	handlerRegistry *HandlerRegistry
	config          *types.Libp2pConfig
	// dependencies (db, filesystem...)
	fs afero.Fs
}
// New creates a libp2p instance.
// It only validates the configuration and allocates internal state; no I/O
// happens until Init and Start are called.
//
// TODO-Suggestion: move types.Libp2pConfig to here for better readability.
// Unless there is a reason to keep within types.
func New(config *types.Libp2pConfig, fs afero.Fs) (*Libp2p, error) {
	switch {
	case config == nil:
		return nil, errors.New("config is nil")
	case config.Scheduler == nil:
		return nil, errors.New("scheduler is nil")
	}
	instance := &Libp2p{
		config:            config,
		discoveredPeers:   make([]peer.AddrInfo, 0),
		pubsubTopics:      map[string]*pubsub.Topic{},
		topicSubscription: map[string]*pubsub.Subscription{},
		fs:                fs,
	}
	return instance, nil
}
// Init initializes a libp2p host with its dependencies (DHT, peerstore,
// rendezvous discovery, pubsub and the stream handler registry).
// (Locals renamed: the previous `context`, `host`, `dht` and `pubsub`
// identifiers shadowed the imported packages of the same names.)
func (l *Libp2p) Init(ctx context.Context) error {
	newHost, newDHT, gossip, err := NewHost(ctx, l.config, l.fs)
	if err != nil {
		zlog.Sugar().Error(err)
		return err
	}
	l.Host = newHost
	l.DHT = newDHT
	l.PS = newHost.Peerstore()
	l.discovery = drouting.NewRoutingDiscovery(newDHT)
	l.pubsub = gossip
	l.handlerRegistry = NewHandlerRegistry(newHost)
	return nil
}
// Start performs network bootstrapping, peer discovery and protocols handling.
// Bootstrap failure is fatal; rendezvous advertisement and initial discovery
// failures are logged but tolerated. A periodic discovery task is registered
// with the scheduler before it is started.
// (Parameter renamed from `context`, which shadowed the context package.)
func (l *Libp2p) Start(ctx context.Context) error {
	// set stream handlers
	l.registerStreamHandlers()
	err := l.Bootstrap(ctx, l.config.BootstrapPeers)
	if err != nil {
		zlog.Sugar().Errorf("failed to start network: %v", err)
		return err
	}
	// advertise randevouz discovery
	err = l.advertiseForRendezvousDiscovery(ctx)
	if err != nil {
		// TODO: the error might be misleading as a peer can normally work well if an error
		// is returned here (e.g.: the error is yielded in tests even though all tests pass).
		zlog.Sugar().Errorf("failed to start network with randevouz discovery: %v", err)
	}
	// discover
	err = l.DiscoverDialPeers(ctx)
	if err != nil {
		zlog.Sugar().Errorf("failed to discover peers: %v", err)
	}
	// register periodic peer discovery task
	discoveryTask := &bt.Task{
		Name:        "Peer Discovery",
		Description: "Periodic task to discover new peers every 15 minutes",
		Function: func(_ interface{}) error {
			return l.DiscoverDialPeers(ctx)
		},
		Triggers: []bt.Trigger{&bt.PeriodicTrigger{Interval: 15 * time.Minute}},
	}
	l.discoveryTask = l.config.Scheduler.AddTask(discoveryTask)
	l.config.Scheduler.Start()
	return nil
}
// RegisterStreamMessageHandler registers a stream handler for a specific
// protocol identified by messageType, which must be non-empty.
func (l *Libp2p) RegisterStreamMessageHandler(messageType types.MessageType, handler StreamHandler) error {
	if messageType == "" {
		return errors.New("message type is empty")
	}
	err := l.handlerRegistry.RegisterHandlerWithStreamCallback(messageType, handler)
	if err != nil {
		return fmt.Errorf("failed to register handler %s: %w", messageType, err)
	}
	return nil
}
// RegisterBytesMessageHandler registers handler as the byte-payload callback
// for messageType, using the shared length-prefixed stream reader to decode
// incoming frames. messageType must be non-empty.
func (l *Libp2p) RegisterBytesMessageHandler(messageType types.MessageType, handler func(data []byte)) error {
	if messageType == "" {
		return errors.New("message type is empty")
	}
	err := l.handlerRegistry.RegisterHandlerWithBytesCallback(messageType, l.handleReadBytesFromStream, handler)
	if err != nil {
		return fmt.Errorf("failed to register handler %s: %w", messageType, err)
	}
	return nil
}
// HandleMessage registers handler as the bytes callback for messageType.
// It is a convenience wrapper around RegisterBytesMessageHandler that accepts
// the message type as a plain string.
func (l *Libp2p) HandleMessage(messageType string, handler func(data []byte)) error {
	return l.RegisterBytesMessageHandler(types.MessageType(messageType), handler)
}
// handleReadBytesFromStream reads one length-prefixed message from the stream
// and hands the payload to the registered bytes callback for the stream's
// protocol. The stream is always closed before returning.
func (l *Libp2p) handleReadBytesFromStream(s network.Stream) {
	callback, ok := l.handlerRegistry.bytesHandlers[s.Protocol()]
	if !ok {
		s.Close()
		return
	}
	c := bufio.NewReader(s)
	defer s.Close()
	// Read exactly 8 bytes of little-endian length prefix. The previous
	// c.Read could legally return fewer bytes (short read) and then decode
	// a garbage length.
	msgLengthBuffer := make([]byte, 8)
	if _, err := io.ReadFull(c, msgLengthBuffer); err != nil {
		return
	}
	lengthPrefix := binary.LittleEndian.Uint64(msgLengthBuffer)
	// Bound the allocation: the length comes from an untrusted remote peer,
	// and an unchecked value let a peer trigger an arbitrarily large
	// allocation. This is what MaxMessageLengthMB exists for.
	if lengthPrefix > MaxMessageLengthMB*MB {
		return
	}
	// create a buffer with the size of the message and then read until its full
	buf := make([]byte, lengthPrefix)
	if _, err := io.ReadFull(c, buf); err != nil {
		return
	}
	callback(buf)
}
// SendMessage sends msg to every address in addrs concurrently and returns
// the combined error of all failed sends, or nil if all succeeded.
func (l *Libp2p) SendMessage(ctx context.Context, addrs []string, msg types.MessageEnvelope) error {
	var wg sync.WaitGroup
	errCh := make(chan error, len(addrs))
	for _, addr := range addrs {
		wg.Add(1)
		go func(addr string) {
			defer wg.Done()
			if err := l.sendMessage(ctx, addr, msg); err != nil {
				errCh <- err
			}
		}(addr)
	}
	wg.Wait()
	close(errCh)
	// errors.Join (returns nil for zero errors) keeps each failure
	// inspectable via errors.Is/As instead of flattening them into one
	// formatted string as the previous fmt.Errorf chain did.
	var errs []error
	for err := range errCh {
		errs = append(errs, err)
	}
	return errors.Join(errs...)
}
// OpenStream dials the peer at the given multiaddress and opens a stream
// using messageType as the protocol ID. The caller owns the returned stream
// and is responsible for closing it.
func (l *Libp2p) OpenStream(ctx context.Context, addr string, messageType types.MessageType) (network.Stream, error) {
	targetAddr, err := multiaddr.NewMultiaddr(addr)
	if err != nil {
		return nil, fmt.Errorf("invalid multiaddress: %w", err)
	}
	target, err := peer.AddrInfoFromP2pAddr(targetAddr)
	if err != nil {
		return nil, fmt.Errorf("could not resolve peer info: %w", err)
	}
	if err = l.Host.Connect(ctx, *target); err != nil {
		return nil, fmt.Errorf("failed to connect to peer: %w", err)
	}
	s, err := l.Host.NewStream(ctx, target.ID, protocol.ID(messageType))
	if err != nil {
		return nil, fmt.Errorf("failed to open stream: %w", err)
	}
	return s, nil
}
// GetMultiaddr returns the local host's full p2p multiaddresses
// (listen addresses combined with the peer ID).
func (l *Libp2p) GetMultiaddr() ([]multiaddr.Multiaddr, error) {
	info := peer.AddrInfo{ID: l.Host.ID(), Addrs: l.Host.Addrs()}
	return peer.AddrInfoToP2pAddrs(&info)
}
// Stop performs a cleanup of any resources used in this package: it removes
// the periodic discovery task and closes the DHT and the host, collecting
// all errors instead of stopping at the first.
func (l *Libp2p) Stop() error {
	var errorMessages []string
	// Start may never have run, in which case there is no discovery task to
	// remove; the unconditional field access was a nil-pointer panic.
	if l.discoveryTask != nil {
		l.config.Scheduler.RemoveTask(l.discoveryTask.ID)
	}
	if err := l.DHT.Close(); err != nil {
		errorMessages = append(errorMessages, err.Error())
	}
	if err := l.Host.Close(); err != nil {
		errorMessages = append(errorMessages, err.Error())
	}
	if len(errorMessages) > 0 {
		return errors.New(strings.Join(errorMessages, "; "))
	}
	return nil
}
// Stat returns status information about the libp2p host: its peer ID and a
// comma-separated list of its listen addresses.
func (l *Libp2p) Stat() types.NetworkStats {
	hostAddrs := l.Host.Addrs()
	listen := make([]string, 0, len(hostAddrs))
	for _, a := range hostAddrs {
		listen = append(listen, a.String())
	}
	return types.NetworkStats{
		ID:         l.Host.ID().String(),
		ListenAddr: strings.Join(listen, ", "),
	}
}
// Ping the remote address. The remote address is the encoded peer id which will be decoded and used here.
// A single ping round-trip is awaited, bounded by timeout.
//
// TODO (Return error once): something that was confusing me when using this method is that the error is
// returned twice if any. Once as a field of PingResult and one as a return value.
func (l *Libp2p) Ping(ctx context.Context, peerIDAddress string, timeout time.Duration) (types.PingResult, error) {
	// avoid dial to self attempt
	if peerIDAddress == l.Host.ID().String() {
		err := errors.New("can't ping self")
		return types.PingResult{Success: false, Error: err}, err
	}
	pingCtx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()
	remotePeer, err := peer.Decode(peerIDAddress)
	if err != nil {
		return types.PingResult{}, err
	}
	// ping.Ping streams results; we take only the first one.
	pingChan := ping.Ping(pingCtx, l.Host, remotePeer)
	select {
	case res := <-pingChan:
		if res.Error != nil {
			zlog.Sugar().Errorf("failed to ping peer %s: %v", peerIDAddress, res.Error)
			return types.PingResult{
				Success: false,
				RTT:     res.RTT,
				Error:   res.Error,
			}, res.Error
		}
		return types.PingResult{
			RTT:     res.RTT,
			Success: true,
		}, nil
	case <-pingCtx.Done():
		// Timeout or caller cancellation before any ping result arrived.
		return types.PingResult{
			Error: pingCtx.Err(),
		}, pingCtx.Err()
	}
}
// ResolveAddress resolves a peer ID string to its p2p multiaddresses,
// returned as strings. The local host is answered directly; remote peers are
// looked up in the DHT with a 3-second timeout.
func (l *Libp2p) ResolveAddress(ctx context.Context, id string) ([]string, error) {
	pid, err := peer.Decode(id)
	if err != nil {
		return nil, fmt.Errorf("failed to resolve invalid peer: %w", err)
	}
	// Asking about ourselves: answer from the host, no DHT lookup needed.
	if l.Host.ID().String() == id {
		selfAddrs, err := l.GetMultiaddr()
		if err != nil {
			return nil, fmt.Errorf("failed to resolve self: %w", err)
		}
		out := make([]string, 0, len(selfAddrs))
		for _, a := range selfAddrs {
			out = append(out, a.String())
		}
		return out, nil
	}
	lookupCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
	defer cancel()
	found, err := l.DHT.FindPeer(lookupCtx, pid)
	if err != nil {
		return nil, fmt.Errorf("failed to resolve address %s: %w", id, err)
	}
	info := peer.AddrInfo{ID: found.ID, Addrs: found.Addrs}
	p2pAddrs, err := peer.AddrInfoToP2pAddrs(&info)
	if err != nil {
		return nil, fmt.Errorf("failed to convert to p2p address: %w", err)
	}
	out := make([]string, 0, len(p2pAddrs))
	for _, a := range p2pAddrs {
		out = append(out, a.String())
	}
	return out, nil
}
// Query return all the advertisements in the network related to a key.
// The network is queried to find providers for the given key, and peers which we aren't connected to can be retrieved.
// Providers whose stored value cannot be fetched are silently skipped;
// a value that fetches but fails to unmarshal aborts the whole query.
func (l *Libp2p) Query(ctx context.Context, key string) ([]*commonproto.Advertisement, error) {
	if key == "" {
		return nil, errors.New("advertisement key is empty")
	}
	// The key is deterministically mapped to a CID so providers can be found.
	customCID, err := createCIDFromKey(key)
	if err != nil {
		return nil, fmt.Errorf("failed to create cid for key %s: %w", key, err)
	}
	addrInfo, err := l.DHT.FindProviders(ctx, customCID)
	if err != nil {
		return nil, fmt.Errorf("failed to find providers for key %s: %w", key, err)
	}
	advertisements := make([]*commonproto.Advertisement, 0)
	for _, v := range addrInfo {
		// TODO: use go routines to get the values in parallel.
		// Each provider stores its advertisement under a per-peer namespace.
		bytesAdvertisement, err := l.DHT.GetValue(ctx, l.getCustomNamespace(key, v.ID.String()))
		if err != nil {
			continue
		}
		var ad commonproto.Advertisement
		if err := proto.Unmarshal(bytesAdvertisement, &ad); err != nil {
			return nil, fmt.Errorf("failed to unmarshal advertisement payload: %w", err)
		}
		advertisements = append(advertisements, &ad)
	}
	return advertisements, nil
}
// Advertise given data and a key pushes the data to the dht.
// The data is wrapped in a signed Advertisement envelope, stored under a
// per-peer namespaced key, and the peer registers itself as a provider of
// the key's CID so Query can find it.
func (l *Libp2p) Advertise(ctx context.Context, key string, data []byte) error {
	if key == "" {
		return errors.New("advertisement key is empty")
	}
	pubKeyBytes, err := l.getPublicKey()
	if err != nil {
		return fmt.Errorf("failed to get public key: %w", err)
	}
	envelope := &commonproto.Advertisement{
		PeerId:    l.Host.ID().String(),
		Timestamp: time.Now().Unix(),
		Data:      data,
		PublicKey: pubKeyBytes,
	}
	// Byte sequence that gets signed; dhtValidator.Validate rebuilds the
	// exact same layout, so the two must stay in lockstep.
	// NOTE(review): {byte(envelope.Timestamp)} truncates the int64 Unix
	// timestamp to its low byte, leaving most of the timestamp outside the
	// signature — confirm whether that is intended.
	concatenatedBytes := bytes.Join([][]byte{
		[]byte(envelope.PeerId),
		{byte(envelope.Timestamp)},
		envelope.Data,
		pubKeyBytes,
	}, nil)
	sig, err := l.sign(concatenatedBytes)
	if err != nil {
		return fmt.Errorf("failed to sign advertisement envelope content: %w", err)
	}
	envelope.Signature = sig
	envelopeBytes, err := proto.Marshal(envelope)
	if err != nil {
		return fmt.Errorf("failed to marshal advertise envelope: %w", err)
	}
	customCID, err := createCIDFromKey(key)
	if err != nil {
		return fmt.Errorf("failed to create cid for key %s: %w", key, err)
	}
	err = l.DHT.PutValue(ctx, l.getCustomNamespace(key, l.DHT.PeerID().String()), envelopeBytes)
	if err != nil {
		return fmt.Errorf("failed to put key %s into the dht: %w", key, err)
	}
	// Register as a provider of the key's CID so other peers can discover us.
	err = l.DHT.Provide(ctx, customCID, true)
	if err != nil {
		return fmt.Errorf("failed to provide key %s into the dht: %w", key, err)
	}
	return nil
}
// Unadvertise removes the data from the dht.
func (l *Libp2p) Unadvertise(ctx context.Context, key string) error {
	// Overwriting the namespaced value with nil removes our advertisement.
	namespace := l.getCustomNamespace(key, l.DHT.PeerID().String())
	if err := l.DHT.PutValue(ctx, namespace, nil); err != nil {
		return fmt.Errorf("failed to remove key %s from the DHT: %w", key, err)
	}
	return nil
}
// Publish publishes data to a topic.
// The requirements are that only one topic handler should exist per topic.
func (l *Libp2p) Publish(ctx context.Context, topic string, data []byte) error {
	th, err := l.getOrJoinTopicHandler(topic)
	if err != nil {
		return fmt.Errorf("failed to publish: %w", err)
	}
	if err := th.Publish(ctx, data); err != nil {
		return fmt.Errorf("failed to publish to topic %s: %w", topic, err)
	}
	return nil
}
// Subscribe subscribes to a topic and sends the messages to the handler.
// The reader goroutine runs until the context is done or the subscription
// is cancelled (see Unsubscribe).
func (l *Libp2p) Subscribe(ctx context.Context, topic string, handler func(data []byte)) error {
	topicHandler, err := l.getOrJoinTopicHandler(topic)
	if err != nil {
		return fmt.Errorf("failed to subscribe to topic: %w", err)
	}

	sub, err := topicHandler.Subscribe()
	if err != nil {
		return fmt.Errorf("failed to subscribe to topic %s: %w", topic, err)
	}

	l.topicMux.Lock()
	l.topicSubscription[topic] = sub
	l.topicMux.Unlock()

	// Pump messages to the handler until the subscription terminates.
	go func() {
		for {
			msg, err := sub.Next(ctx)
			if err != nil {
				// Next fails permanently once ctx is done or the subscription
				// is cancelled; the previous `continue` busy-looped forever on
				// that terminal error and leaked a spinning goroutine.
				return
			}
			handler(msg.Data)
		}
	}()
	return nil
}
// sendMessage delivers a message envelope to the peer at the given multiaddress.
// Messages addressed to ourselves are dispatched directly to the locally
// registered handler instead of opening a stream.
func (l *Libp2p) sendMessage(ctx context.Context, addr string, msg types.MessageEnvelope) error {
	peerAddr, err := multiaddr.NewMultiaddr(addr)
	if err != nil {
		// %w (was %v) so callers can unwrap with errors.Is/As.
		return fmt.Errorf("invalid multiaddr %s: %w", addr, err)
	}
	peerInfo, err := peer.AddrInfoFromP2pAddr(peerAddr)
	if err != nil {
		return fmt.Errorf("failed to get peer info %s: %w", addr, err)
	}

	// We are delivering a message to ourselves: use the handler which has been
	// previously registered, no network round-trip needed.
	if peerInfo.ID.String() == l.Host.ID().String() {
		l.handlerRegistry.SendMessageToLocalHandler(msg.Type, msg.Data)
		return nil
	}

	// Reject oversized payloads before dialing: no point connecting when the
	// message can never be sent.
	requestBufferSize := 8 + len(msg.Data)
	if requestBufferSize > MaxMessageLengthMB*MB {
		return fmt.Errorf("message size %d is greater than limit %d bytes", requestBufferSize, MaxMessageLengthMB*MB)
	}

	if err := l.Host.Connect(ctx, *peerInfo); err != nil {
		return fmt.Errorf("failed to connect to peer %v: %w", peerInfo.ID, err)
	}
	stream, err := l.Host.NewStream(ctx, peerInfo.ID, protocol.ID(msg.Type))
	if err != nil {
		return fmt.Errorf("failed to open stream to peer %v: %w", peerInfo.ID, err)
	}
	defer stream.Close()

	// Length-prefixed framing: 8-byte little-endian payload length, then payload.
	requestPayloadWithLength := make([]byte, requestBufferSize)
	binary.LittleEndian.PutUint64(requestPayloadWithLength, uint64(len(msg.Data)))
	copy(requestPayloadWithLength[8:], msg.Data)

	if _, err := stream.Write(requestPayloadWithLength); err != nil {
		return fmt.Errorf("failed to send message to peer %v: %w", peerInfo.ID, err)
	}
	return nil
}
// getOrJoinTopicHandler gets the topic handler, it will be created if it doesn't exist.
// for publishing and subscribing its needed therefore its implemented in this function.
func (l *Libp2p) getOrJoinTopicHandler(topic string) (*pubsub.Topic, error) {
	l.topicMux.Lock()
	defer l.topicMux.Unlock()

	if th, ok := l.pubsubTopics[topic]; ok {
		return th, nil
	}
	th, err := l.pubsub.Join(topic)
	if err != nil {
		return nil, fmt.Errorf("failed to join topic %s: %w", topic, err)
	}
	l.pubsubTopics[topic] = th
	return th, nil
}
// Unsubscribe cancels the subscription to a topic
func (l *Libp2p) Unsubscribe(topic string) error {
	l.topicMux.Lock()
	defer l.topicMux.Unlock()

	th, subscribed := l.pubsubTopics[topic]
	if !subscribed {
		return fmt.Errorf("not subscribed to topic: %s", topic)
	}

	// Cancel and drop the subscription reader, if one was registered.
	if sub, ok := l.topicSubscription[topic]; ok {
		sub.Cancel()
		delete(l.topicSubscription, topic)
	}

	if err := th.Close(); err != nil {
		return fmt.Errorf("failed to close topic handler: %w", err)
	}
	delete(l.pubsubTopics, topic)
	return nil
}
// VisiblePeers returns the peers discovered so far.
// NOTE(review): this returns the internal slice without copying, so callers
// share the live backing array — confirm callers treat it as read-only and
// that discovery does not mutate it concurrently.
func (l *Libp2p) VisiblePeers() []peer.AddrInfo {
	return l.discoveredPeers
}
// KnownPeers returns every peer recorded in the host's peerstore, including
// the addresses the peerstore currently has for each of them.
func (l *Libp2p) KnownPeers() ([]peer.AddrInfo, error) {
	ids := l.Host.Peerstore().Peers()
	peers := make([]peer.AddrInfo, 0, len(ids))
	for _, p := range ids {
		// PeerInfo fills in the known addresses; the previous implementation
		// returned entries with an always-empty Addrs slice.
		peers = append(peers, l.Host.Peerstore().PeerInfo(p))
	}
	return peers, nil
}
// DumpDHTRoutingTable returns the peers currently held in the local Kademlia
// DHT routing table. The error return is always nil; it is kept for interface
// compatibility.
func (l *Libp2p) DumpDHTRoutingTable() ([]kbucket.PeerInfo, error) {
	rt := l.DHT.RoutingTable()
	return rt.GetPeerInfos(), nil
}
// registerStreamHandlers wires protocol stream handlers onto the host.
// Currently only the ipfs ping protocol ("/ipfs/ping/1.0.0") is registered.
func (l *Libp2p) registerStreamHandlers() {
	l.pingService = ping.NewPingService(l.Host)
	l.Host.SetStreamHandler(protocol.ID("/ipfs/ping/1.0.0"), l.pingService.PingHandler)
}
// sign signs data with the host's private key taken from the peerstore.
func (l *Libp2p) sign(data []byte) ([]byte, error) {
	key := l.Host.Peerstore().PrivKey(l.Host.ID())
	if key == nil {
		return nil, errors.New("private key not found for the host")
	}
	sig, err := key.Sign(data)
	if err != nil {
		return nil, fmt.Errorf("failed to sign data: %w", err)
	}
	return sig, nil
}
// getPublicKey returns the raw bytes of the host's public key, derived from
// the private key stored in the peerstore.
func (l *Libp2p) getPublicKey() ([]byte, error) {
	key := l.Host.Peerstore().PrivKey(l.Host.ID())
	if key == nil {
		return nil, errors.New("private key not found for the host")
	}
	return key.GetPublic().Raw()
}
// getCustomNamespace builds the DHT namespace string for a key/peer pair,
// prefixed with the configured custom namespace:
// "<CustomNamespace>-<key>-<peerID>".
func (l *Libp2p) getCustomNamespace(key, peerID string) string {
	return l.config.CustomNamespace + "-" + key + "-" + peerID
}
// createCIDFromKey derives a deterministic CIDv1 (raw codec) from a key by
// hashing it with SHA-256 and wrapping the digest as a multihash.
func createCIDFromKey(key string) (cid.Cid, error) {
	digest := sha256.Sum256([]byte(key))
	encoded, err := multihash.Encode(digest[:], multihash.SHA2_256)
	if err != nil {
		return cid.Cid{}, err
	}
	return cid.NewCidV1(cid.Raw, encoded), nil
}
// CleanupPeer is a stub retained for API compatibility; it only logs a
// warning and performs no cleanup.
func CleanupPeer(_ peer.ID) error {
	zlog.Warn("CleanupPeer: Stub")
	return nil
}
// PingPeer is a stub retained for API compatibility; it only logs a warning
// and always returns (false, nil).
func PingPeer(_ context.Context, _ peer.ID) (bool, *ping.Result) {
	zlog.Warn("PingPeer: Stub")
	return false, nil
}
// DumpKademliaDHT is a stub retained for API compatibility; it only logs a
// warning and always returns (nil, nil).
func DumpKademliaDHT(_ context.Context) ([]types.PeerData, error) {
	zlog.Warn("DumpKademliaDHT: Stub")
	return nil, nil
}
// OldPingPeer is a stub retained for API compatibility; it only logs a
// warning and always returns (false, nil).
func OldPingPeer(_ context.Context, _ peer.ID) (bool, *types.PingResult) {
	zlog.Warn("OldPingPeer: Stub")
	return false, nil
}
package libp2p
import (
"bytes"
"encoding/base64"
"fmt"
"os"
"path/filepath"
"github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/pnet"
"github.com/spf13/afero"
)
/*
** Swarm key **
By default, the swarm key shall be stored in a file named `swarm.key`
using the following path-based codec:
`/key/swarm/psk/1.0.0/<base_encoding>/<256_bits_key>`
`<base_encoding>` is either bin, base16 or base64.
*/
// TODO-pnet-1: we shouldn't handle configuration paths here, a general configuration path
// should be provided by /internal/config.go

// getBasePath returns the nunet configuration directory (`$HOME/.nunet`).
// The afero.Fs parameter is currently unused (see TODO-pnet-1).
func getBasePath(_ afero.Fs) (string, error) {
	home, err := os.UserHomeDir()
	if err != nil {
		return "", fmt.Errorf("error getting home directory: %w", err)
	}
	return filepath.Join(home, ".nunet"), nil
}
// configureSwarmKey try to read the swarm key from `<config_path>/swarm.key` file.
// If a swarm key is not found, generate a new one.
//
// TODO-ask: should we continue to generate a new swarm key if one is not found?
// Or we should enforce the user to use some cmd/API rpc to generate a new one?
func configureSwarmKey(fs afero.Fs) (pnet.PSK, error) {
	// Prefer an existing key on disk; fall back to generating a fresh one.
	if psk, err := getSwarmKey(fs); err == nil {
		return psk, nil
	}
	psk, err := generateSwarmKey(fs)
	if err != nil {
		return nil, fmt.Errorf("failed to generate new swarm key: %w", err)
	}
	return psk, nil
}
// getSwarmKey reads and decodes the swarm key from `<config_path>/swarm.key`.
func getSwarmKey(fs afero.Fs) (pnet.PSK, error) {
	homeDir, err := getBasePath(fs)
	if err != nil {
		return nil, fmt.Errorf("failed to get base file path: %w", err)
	}
	swarmkey, err := afero.ReadFile(fs, filepath.Join(homeDir, "swarm.key"))
	if err != nil {
		return nil, fmt.Errorf("failed to read swarm key file: %w", err)
	}
	// %w (was %s) so callers can unwrap the decode error with errors.Is/As.
	psk, err := pnet.DecodeV1PSK(bytes.NewReader(swarmkey))
	if err != nil {
		return nil, fmt.Errorf("failed to configure private network: %w", err)
	}
	// TODO-ask: should we return psk fingerprint?
	return psk, nil
}
// generateSwarmKey generates a new swarm key, storing it within
// `<nunet_config_dir>/swarm.key`.
//
// NOTE(review): the PSK material is a base64-encoded, protobuf-marshalled
// secp256k1 private key rather than 32 raw random bytes — confirm
// pnet.DecodeV1PSK accepts this length/format.
func generateSwarmKey(fs afero.Fs) (pnet.PSK, error) {
	priv, _, err := crypto.GenerateKeyPair(crypto.Secp256k1, 256)
	if err != nil {
		return nil, err
	}
	privBytes, err := crypto.MarshalPrivateKey(priv)
	if err != nil {
		return nil, err
	}
	encodedKey := base64.StdEncoding.EncodeToString(privBytes)

	// Path-based codec header expected by pnet (see the file comment above).
	swarmKeyWithCodec := fmt.Sprintf("/key/swarm/psk/1.0.0/\n/base64/\n%s\n", encodedKey)

	// TODO-pnet-1
	nunetDir, err := getBasePath(fs)
	if err != nil {
		return nil, err
	}
	swarmKeyPath := filepath.Join(nunetDir, "swarm.key")
	if err := afero.WriteFile(fs, swarmKeyPath, []byte(swarmKeyWithCodec), 0o600); err != nil {
		return nil, fmt.Errorf("error writing swarm key to file: %w", err)
	}

	// %w (was %s) so the decode error remains unwrappable.
	psk, err := pnet.DecodeV1PSK(bytes.NewReader([]byte(swarmKeyWithCodec)))
	if err != nil {
		return nil, fmt.Errorf("failed to decode generated swarm key: %w", err)
	}

	zlog.Sugar().Infof("A new Swarm key was generated and written to %s\n"+
		"IMPORTANT: If you'd like to create the swarm key using a cryptography algorithm "+
		"of your choice, just modify the swarm.key file with your own key.\n"+
		"The content of `swarm.key` should look like: `/key/swarm/psk/1.0.0/<base_encoding>/<your_key>`\n"+
		"where `<base_encoding>` is either `bin`, `base16`, or `base64`.\n",
		swarmKeyPath,
	)
	return psk, nil
}
package network
import (
"context"
"errors"
"fmt"
"time"
"github.com/spf13/afero"
commonproto "gitlab.com/nunet/device-management-service/proto/generated/v1/common"
"gitlab.com/nunet/device-management-service/network/libp2p"
"gitlab.com/nunet/device-management-service/types"
)
// Messenger defines the interface for sending messages.
type Messenger interface {
	// SendMessage sends a message to the given addresses.
	SendMessage(ctx context.Context, addrs []string, msg types.MessageEnvelope) error
}
// Network is the abstraction over a peer-to-peer transport (libp2p, NATS, ...)
// combining messaging, discovery, advertisement and pub/sub capabilities.
type Network interface {
	// Messenger embedded interface
	Messenger

	// Init initializes the network
	Init(context.Context) error

	// Start starts the network
	Start(context context.Context) error

	// Stat returns the network information
	Stat() types.NetworkStats

	// Ping pings the given address and returns the PingResult
	Ping(ctx context.Context, address string, timeout time.Duration) (types.PingResult, error)

	// HandleMessage is responsible for registering a message type and its handler.
	HandleMessage(messageType string, handler func(data []byte)) error

	// ResolveAddress given an id it returns the address of the peer.
	// In libp2p, id represents the peerID and the response is the addrinfo
	ResolveAddress(ctx context.Context, id string) ([]string, error)

	// Advertise advertises the given data with the given adId
	// such as advertising device capabilities on the DHT
	Advertise(ctx context.Context, key string, data []byte) error

	// Unadvertise stops advertising data corresponding to the given adId
	Unadvertise(ctx context.Context, key string) error

	// Query returns the network advertisement
	Query(ctx context.Context, key string) ([]*commonproto.Advertisement, error)

	// Publish publishes the given data to the given topic if the network
	// type allows publish/subscribe functionality such as gossipsub or nats
	Publish(ctx context.Context, topic string, data []byte) error

	// Subscribe subscribes to the given topic and calls the handler function
	// if the network type allows it similar to Publish()
	Subscribe(ctx context.Context, topic string, handler func(data []byte)) error

	// Unsubscribe from a topic
	Unsubscribe(topic string) error

	// Stop stops the network including any existing advertisements and subscriptions
	Stop() error
}
// NewNetwork returns a new network given the configuration.
func NewNetwork(netConfig *types.NetworkConfig, fs afero.Fs) (Network, error) {
	// TODO: probable additional params to receive: DB, FileSystem
	if netConfig == nil {
		return nil, errors.New("network configuration is nil")
	}
	switch netConfig.Type {
	case types.Libp2pNetwork:
		ln, err := libp2p.New(&netConfig.Libp2pConfig, fs)
		if err != nil {
			// Return a literal nil on failure: returning the typed nil
			// concrete pointer boxed into the Network interface would make
			// the result compare non-nil for callers ("nil interface trap").
			return nil, err
		}
		return ln, nil
	case types.NATSNetwork:
		return nil, errors.New("not implemented")
	default:
		return nil, fmt.Errorf("unsupported network type: %s", netConfig.Type)
	}
}
package basiccontroller
import (
	"context"
	"errors"
	"fmt"
	"os"
	"path/filepath"

	"github.com/spf13/afero"

	"gitlab.com/nunet/device-management-service/db/repositories"
	"gitlab.com/nunet/device-management-service/storage"
	"gitlab.com/nunet/device-management-service/types"
	"gitlab.com/nunet/device-management-service/utils"
)
// BasicVolumeController is the default implementation of the VolumeController.
// It persists storage volumes information using the StorageVolume.
type BasicVolumeController struct {
	// repo is the repository for storage volume operations
	repo repositories.StorageVolume

	// basePath is the base path where volumes are stored under
	basePath string

	// FS is the file system to act upon (tests inject an in-memory afero fs)
	FS afero.Fs
}
// NewDefaultVolumeController returns a new instance of BasicVolumeController
//
// TODO-BugFix [path]: volBasePath might not end with `/`, causing errors when calling methods.
// We need to validate it using the `path` library or just verifying the string.
func NewDefaultVolumeController(repo repositories.StorageVolume, volBasePath string, fs afero.Fs) (*BasicVolumeController, error) {
	// The error return is reserved for future validation (see TODO above);
	// construction cannot currently fail.
	return &BasicVolumeController{
		repo:     repo,
		basePath: volBasePath,
		FS:       fs,
	}, nil
}
// CreateVolume creates a new storage volume given a source (S3, IPFS, job, etc). The
// creation of a storage volume effectively creates an empty directory in the local filesystem
// and writes a record in the database.
//
// The directory name follows the format: `<volSource> + "-" + <name>`
// where `name` is random.
//
// TODO-maybe [withName]: allow callers to specify custom name for path
func (vc *BasicVolumeController) CreateVolume(volSource storage.VolumeSource, opts ...storage.CreateVolOpt) (types.StorageVolume, error) {
	vol := types.StorageVolume{
		Private:        false,
		ReadOnly:       false,
		EncryptionType: types.EncryptionTypeNull,
	}
	for _, opt := range opts {
		opt(&vol)
	}

	randomStr, err := utils.RandomString(16)
	if err != nil {
		return types.StorageVolume{}, fmt.Errorf("failed to create random string: %w", err)
	}

	// filepath.Join inserts the separator regardless of whether basePath ends
	// with one, fixing the TODO-BugFix [path] noted on the constructor (the
	// previous raw concatenation produced "<base><name>" without a separator).
	vol.Path = filepath.Join(vc.basePath, string(volSource)+"-"+randomStr)
	if err := vc.FS.Mkdir(vol.Path, os.ModePerm); err != nil {
		return types.StorageVolume{}, fmt.Errorf("failed to create storage volume: %w", err)
	}

	createdVol, err := vc.repo.Create(context.Background(), vol)
	if err != nil {
		return types.StorageVolume{}, fmt.Errorf("failed to create storage volume in repository: %w", err)
	}
	return createdVol, nil
}
// LockVolume makes the volume read-only, not only changing the field value but also changing file permissions.
// It should be used after all necessary data has been written.
// It optionally can also set the CID and mark the volume as private.
//
// TODO-maybe [CID]: maybe calculate CID of every volume in case WithCID opt is not provided
func (vc *BasicVolumeController) LockVolume(pathToVol string, opts ...storage.LockVolOpt) error {
	q := vc.repo.GetQuery()
	q.Conditions = append(q.Conditions, repositories.EQ("Path", pathToVol))

	volume, err := vc.repo.Find(context.Background(), q)
	if err != nil {
		return fmt.Errorf("failed to find storage volume with path %s - Error: %w", pathToVol, err)
	}

	for _, applyOpt := range opts {
		applyOpt(&volume)
	}
	volume.ReadOnly = true

	// Persist the updated record before touching file permissions.
	updatedVol, err := vc.repo.Update(context.Background(), volume.ID, volume)
	if err != nil {
		return fmt.Errorf("failed to update storage volume with path %s - Error: %w", pathToVol, err)
	}

	if err := vc.FS.Chmod(updatedVol.Path, 0o400); err != nil {
		return fmt.Errorf("failed to make storage volume read-only (path: %s): %w", updatedVol.Path, err)
	}
	return nil
}
// WithPrivate designates a given volume as private. It can be used both
// when creating or locking a volume.
// The type parameter lets the same option satisfy either option signature.
func WithPrivate[T storage.CreateVolOpt | storage.LockVolOpt]() T {
	return func(v *types.StorageVolume) {
		v.Private = true
	}
}
// WithCID sets the CID of a given volume if already calculated
//
// TODO [validate]: check if CID provided is valid
func WithCID(cid string) storage.LockVolOpt {
	return func(v *types.StorageVolume) {
		v.CID = cid
	}
}
// DeleteVolume deletes a given storage volume record from the database.
// Identifier is either a CID or a path of a volume. Therefore, records for both
// will be deleted.
//
// Note [CID]: if we start to type CID as cid.CID, we may have to use generics here
// as in `[T string | cid.CID]`
func (vc *BasicVolumeController) DeleteVolume(identifier string, idType storage.IDType) error {
	query := vc.repo.GetQuery()
	switch idType {
	case storage.IDTypePath:
		query.Conditions = append(query.Conditions, repositories.EQ("Path", identifier))
	case storage.IDTypeCID:
		query.Conditions = append(query.Conditions, repositories.EQ("CID", identifier))
	default:
		return fmt.Errorf("identifier type not supported")
	}

	vol, err := vc.repo.Find(context.Background(), query)
	if err != nil {
		// errors.Is (was ==) so wrapped repository errors still match the sentinel.
		if errors.Is(err, repositories.ErrNotFound) {
			return fmt.Errorf("volume not found: %w", err)
		}
		return fmt.Errorf("failed to find volume: %w", err)
	}

	if err := vc.repo.Delete(context.Background(), vol.ID); err != nil {
		return fmt.Errorf("failed to delete volume: %w", err)
	}
	return nil
}
// ListVolumes returns a list of all storage volumes stored on the database
//
// TODO [filter]: maybe add opts to filter results by certain values
func (vc *BasicVolumeController) ListVolumes() ([]types.StorageVolume, error) {
	vols, err := vc.repo.FindAll(context.Background(), vc.repo.GetQuery())
	if err != nil {
		return nil, fmt.Errorf("failed to list volumes: %w", err)
	}
	return vols, nil
}
// GetSize returns the size of a volume
// TODO-minor: identify which measurement type will be used
func (vc *BasicVolumeController) GetSize(identifier string, idType storage.IDType) (int64, error) {
	// Map the identifier type to the record field it matches against.
	var field string
	switch idType {
	case storage.IDTypePath:
		field = "Path"
	case storage.IDTypeCID:
		field = "CID"
	default:
		return 0, fmt.Errorf("unsupported ID type: %d", idType)
	}

	q := vc.repo.GetQuery()
	q.Conditions = append(q.Conditions, repositories.EQ(field, identifier))
	volume, err := vc.repo.Find(context.Background(), q)
	if err != nil {
		return 0, fmt.Errorf("failed to find volume: %w", err)
	}

	dirSize, err := utils.GetDirectorySize(vc.FS, volume.Path)
	if err != nil {
		return 0, fmt.Errorf("failed to get directory size: %w", err)
	}
	return dirSize, nil
}
// EncryptVolume encrypts a given volume.
// Not implemented yet; always returns an error.
func (vc *BasicVolumeController) EncryptVolume(_ string, _ types.Encryptor, _ types.EncryptionType) error {
	return fmt.Errorf("not implemented")
}
// DecryptVolume decrypts a given volume.
// Not implemented yet; always returns an error.
func (vc *BasicVolumeController) DecryptVolume(_ string, _ types.Decryptor, _ types.EncryptionType) error {
	return fmt.Errorf("not implemented")
}
// Compile-time check that BasicVolumeController satisfies storage.VolumeController.
var _ storage.VolumeController = (*BasicVolumeController)(nil)
package basiccontroller
import (
"context"
"fmt"
"os"
"testing"
clover "github.com/ostafen/clover/v2"
"github.com/spf13/afero"
rclover "gitlab.com/nunet/device-management-service/db/repositories/clover"
"gitlab.com/nunet/device-management-service/types"
)
// VolControllerTestSuiteHelper bundles everything a volume-controller test
// needs: the controller under test, its file system, the backing clover DB,
// the seeded volumes, and the temp dir holding the DB files.
type VolControllerTestSuiteHelper struct {
	BasicVolController *BasicVolumeController
	Fs                 afero.Fs
	DB                 *clover.DB
	Volumes            map[string]*types.StorageVolume
	TempDBDir          string
}
// SetupVolControllerTestSuite sets up a volume controller with 0-n volumes given a base path.
// If volumes are provided, directories will be created and volumes will be stored in the database.
// All resources are released through t.Cleanup.
func SetupVolControllerTestSuite(t *testing.T, basePath string, volumes map[string]*types.StorageVolume) (*VolControllerTestSuiteHelper, error) {
	tempDir, err := os.MkdirTemp("", "clover-test-*")
	if err != nil {
		return nil, fmt.Errorf("failed to create temp directory: %w", err)
	}
	db, err := rclover.NewDB(tempDir, []string{"storage_volume"})
	if err != nil {
		os.RemoveAll(tempDir)
		return nil, fmt.Errorf("failed to open clover db: %w", err)
	}

	// cleanup tears down the db and temp dir on any subsequent error path.
	// os.RemoveAll is required (some paths previously used os.Remove, which
	// fails on non-empty directories — tempDir already holds the clover files).
	cleanup := func() {
		db.Close()
		os.RemoveAll(tempDir)
	}

	fs := afero.NewMemMapFs()
	if err := fs.MkdirAll(basePath, 0o755); err != nil {
		cleanup()
		return nil, fmt.Errorf("failed to create base path: %w", err)
	}

	repo := rclover.NewStorageVolume(db)
	vc, err := NewDefaultVolumeController(repo, basePath, fs)
	if err != nil {
		cleanup()
		return nil, fmt.Errorf("failed to create volume controller: %w", err)
	}

	for _, vol := range volumes {
		// Create the volume's root directory on the in-memory fs.
		if err := fs.MkdirAll(vol.Path, 0o755); err != nil {
			cleanup()
			return nil, fmt.Errorf("failed to create volume dir: %w", err)
		}
		// Persist the volume record in the db.
		if _, err := repo.Create(context.Background(), *vol); err != nil {
			cleanup()
			return nil, fmt.Errorf("failed to create volume record: %w", err)
		}
	}

	helper := &VolControllerTestSuiteHelper{vc, fs, db, volumes, tempDir}
	t.Cleanup(func() {
		TearDownVolControllerTestSuite(helper)
	})
	return helper, nil
}
// TearDownVolControllerTestSuite cleans up resources created during setup.
// It tolerates a nil helper and partially-initialized fields so it can be
// called from any failure point.
func TearDownVolControllerTestSuite(helper *VolControllerTestSuiteHelper) {
	if helper == nil {
		return
	}
	if helper.DB != nil {
		helper.DB.Close()
	}
	if helper.TempDBDir != "" {
		os.RemoveAll(helper.TempDBDir)
	}
}
package s3
import (
"context"
"fmt"
"os"
"path/filepath"
"strings"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/spf13/afero"
"gitlab.com/nunet/device-management-service/storage"
basiccontroller "gitlab.com/nunet/device-management-service/storage/basic_controller"
"gitlab.com/nunet/device-management-service/types"
)
// Download fetch files from a given S3 bucket. The key may be a directory ending
// with `/` or have a wildcard (`*`) so it handles normal S3 folders but it does
// not handle x-directory.
//
// Warning: the implementation should rely on the FS provided by the volume controller,
// be careful if managing files with `os` (the volume controller might be
// using an in-memory one)
func (s *Storage) Download(ctx context.Context, sourceSpecs *types.SpecConfig) (types.StorageVolume, error) {
	source, err := DecodeInputSpec(sourceSpecs)
	if err != nil {
		return types.StorageVolume{}, err
	}

	storageVol, err := s.volController.CreateVolume(storage.VolumeSourceS3)
	if err != nil {
		// %w (was %v) throughout so callers can unwrap with errors.Is/As.
		return types.StorageVolume{}, fmt.Errorf("failed to create storage volume: %w", err)
	}

	resolvedObjects, err := resolveStorageKey(ctx, s.Client, &source)
	if err != nil {
		return types.StorageVolume{}, fmt.Errorf("failed to resolve storage key: %w", err)
	}
	for _, resolvedObject := range resolvedObjects {
		if err := s.downloadObject(ctx, &source, resolvedObject, storageVol.Path); err != nil {
			return types.StorageVolume{}, fmt.Errorf("failed to download s3 object: %w", err)
		}
	}

	// After data is filled within the volume, lock it (read-only).
	if err := s.volController.LockVolume(storageVol.Path); err != nil {
		return types.StorageVolume{}, fmt.Errorf("failed to lock storage volume: %w", err)
	}
	return storageVol, nil
}
// downloadObject fetches a single resolved S3 object into volPath using the
// file system owned by the volume controller. Directory placeholders only
// create the directory; regular objects are streamed to a file.
func (s *Storage) downloadObject(ctx context.Context, source *InputSource, object s3Object, volPath string) error {
	outputPath := filepath.Join(volPath, *object.key)

	// Use the same file system instance used by the Volume Controller.
	var fs afero.Fs
	if basicVolController, ok := s.volController.(*basiccontroller.BasicVolumeController); ok {
		fs = basicVolController.FS
	}
	if fs == nil {
		// Previously a nil fs was dereferenced when the controller was not a
		// BasicVolumeController, panicking instead of returning an error.
		return fmt.Errorf("unsupported volume controller: no file system available")
	}

	if object.isDir {
		// Directory objects only need the directory itself, no download.
		if err := fs.MkdirAll(outputPath, 0o755); err != nil {
			return fmt.Errorf("failed to create directory: %w", err)
		}
		return nil
	}

	// Create the PARENT directory, not outputPath itself: the previous code
	// created a directory at the file's own path and then tried to open that
	// same path as a regular file.
	if err := fs.MkdirAll(filepath.Dir(outputPath), 0o755); err != nil {
		return fmt.Errorf("failed to create directory: %w", err)
	}

	outputFile, err := fs.OpenFile(outputPath, os.O_RDWR|os.O_CREATE, 0o755)
	if err != nil {
		return err
	}
	defer outputFile.Close()

	zlog.Sugar().Debugf("Downloading s3 object %s to %s", *object.key, outputPath)
	_, err = s.downloader.Download(ctx, outputFile, &s3.GetObjectInput{
		Bucket:  aws.String(source.Bucket),
		Key:     object.key,
		IfMatch: object.eTag,
	})
	if err != nil {
		return fmt.Errorf("failed to download file: %w", err)
	}
	return nil
}
// resolveStorageKey returns a list of s3 objects within a bucket accordingly to the key provided.
func resolveStorageKey(ctx context.Context, client *s3.Client, source *InputSource) ([]s3Object, error) {
	key := source.Key
	switch {
	case key == "":
		return nil, fmt.Errorf("key is required")
	case strings.HasSuffix(key, "/") || strings.Contains(key, "*"):
		// Trailing slash or wildcard: the key addresses multiple objects.
		return resolveObjectsWithPrefix(ctx, client, source)
	default:
		// Otherwise the key represents a single object.
		return resolveSingleObject(ctx, client, source)
	}
}
// resolveSingleObject fetches the metadata of one object and returns it as a
// single-element slice. x-directory content types are rejected (unsupported).
func resolveSingleObject(ctx context.Context, client *s3.Client, source *InputSource) ([]s3Object, error) {
	key := sanitizeKey(source.Key)
	headObjectInput := &s3.HeadObjectInput{
		Bucket: aws.String(source.Bucket),
		Key:    aws.String(key),
	}
	headObjectOut, err := client.HeadObject(ctx, headObjectInput)
	if err != nil {
		return []s3Object{}, fmt.Errorf("failed to retrieve object metadata: %v", err)
	}
	// TODO-minor: validate checksum if provided

	// ContentType and ContentLength are optional pointers in the SDK response;
	// guard before dereferencing (previously a nil pointer panicked here).
	if headObjectOut.ContentType != nil && strings.HasPrefix(*headObjectOut.ContentType, "application/x-directory") {
		return []s3Object{}, fmt.Errorf("x-directory is not yet handled")
	}
	var size int64
	if headObjectOut.ContentLength != nil {
		size = *headObjectOut.ContentLength
	}
	return []s3Object{
		{
			key:  aws.String(source.Key),
			eTag: headObjectOut.ETag,
			size: size,
		},
	}, nil
}
// resolveObjectsWithPrefix lists all objects whose keys start with the
// sanitized source key, paging through the full result set.
func resolveObjectsWithPrefix(ctx context.Context, client *s3.Client, source *InputSource) ([]s3Object, error) {
	key := sanitizeKey(source.Key)
	listObjectsInput := &s3.ListObjectsV2Input{
		Bucket: aws.String(source.Bucket),
		Prefix: aws.String(key),
	}

	var objects []s3Object
	paginator := s3.NewListObjectsV2Paginator(client, listObjectsInput)
	for paginator.HasMorePages() {
		page, err := paginator.NextPage(ctx)
		if err != nil {
			return nil, fmt.Errorf("failed to list objects: %v", err)
		}
		for _, obj := range page.Contents {
			// Size is an optional pointer in the SDK response; guard before
			// dereferencing (previously a nil pointer would panic here).
			var size int64
			if obj.Size != nil {
				size = *obj.Size
			}
			objects = append(objects, s3Object{
				key:   aws.String(*obj.Key),
				size:  size,
				isDir: strings.HasSuffix(*obj.Key, "/"),
			})
		}
	}
	return objects, nil
}
package s3
import (
"context"
"strings"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
)
// GetAWSDefaultConfig returns the default AWS config based on environment variables,
// shared configuration and shared credentials files.
// No load options are currently applied.
func GetAWSDefaultConfig() (aws.Config, error) {
	var optFns []func(*config.LoadOptions) error
	return config.LoadDefaultConfig(context.Background(), optFns...)
}
// hasValidCredentials reports whether the given AWS config can produce a
// credential set containing both an access key and a secret key.
func hasValidCredentials(cfg aws.Config) bool {
	creds, err := cfg.Credentials.Retrieve(context.Background())
	if err != nil {
		return false
	}
	return creds.HasKeys()
}
// sanitizeKey trims surrounding whitespace from an S3 key and drops a single
// trailing `*` wildcard, if present.
func sanitizeKey(key string) string {
	trimmed := strings.TrimSpace(key)
	return strings.TrimSuffix(trimmed, "*")
}
package s3
import (
"github.com/uptrace/opentelemetry-go-extra/otelzap"
"gitlab.com/nunet/device-management-service/telemetry/logger"
)
// zlog is the package-level logger for the s3 storage provider.
var zlog *otelzap.Logger

// init wires the package logger to the telemetry logger under the "s3" name.
func init() {
	zlog = logger.OtelZapLogger("s3")
}
package s3
import (
"context"
"fmt"
"github.com/aws/aws-sdk-go-v2/aws"
s3Manager "github.com/aws/aws-sdk-go-v2/feature/s3/manager"
"github.com/aws/aws-sdk-go-v2/service/s3"
"gitlab.com/nunet/device-management-service/storage"
"gitlab.com/nunet/device-management-service/types"
)
// Storage is the S3 storage provider. It embeds the AWS SDK client and
// delegates local volume management to a VolumeController.
type Storage struct {
	*s3.Client
	volController storage.VolumeController
	downloader    *s3Manager.Downloader
	uploader      *s3Manager.Uploader
}

// s3Object describes a single resolved object (or folder placeholder) in a bucket.
type s3Object struct {
	key  *string
	eTag *string
	size int64
	// isDir is true when the key ends with "/" (S3 folder placeholder).
	isDir bool
}
// NewClient creates a new S3Storage which includes a S3-SDK client.
// It depends on a VolumeController to manage the volumes being acted upon.
func NewClient(config aws.Config, volController storage.VolumeController) (*Storage, error) {
	if !hasValidCredentials(config) {
		return nil, fmt.Errorf("invalid credentials")
	}
	client := s3.NewFromConfig(config)
	return &Storage{
		Client:        client,
		volController: volController,
		downloader:    s3Manager.NewDownloader(client),
		uploader:      s3Manager.NewUploader(client),
	}, nil
}
// Size returns the size in bytes of the object identified by the given spec.
func (s *Storage) Size(ctx context.Context, source *types.SpecConfig) (uint64, error) {
	inputSource, err := DecodeInputSpec(source)
	if err != nil {
		// %w (was %v) so callers can unwrap with errors.Is/As.
		return 0, fmt.Errorf("failed to decode input spec: %w", err)
	}
	input := &s3.HeadObjectInput{
		Bucket: aws.String(inputSource.Bucket),
		Key:    aws.String(inputSource.Key),
	}
	output, err := s.HeadObject(ctx, input)
	if err != nil {
		return 0, fmt.Errorf("failed to get object size: %w", err)
	}
	// ContentLength is an optional pointer in the SDK response; guard before
	// dereferencing (previously a nil pointer panicked here).
	if output.ContentLength == nil {
		return 0, fmt.Errorf("object size missing from response")
	}
	return uint64(*output.ContentLength), nil
}
// Compile time interface check
// var _ storage.StorageProvider = (*Storage)(nil)
package s3
import (
"fmt"
"github.com/fatih/structs"
"github.com/mitchellh/mapstructure"
"gitlab.com/nunet/device-management-service/types"
)
// InputSource holds the parameters identifying a set of S3 objects:
// a bucket, a key (possibly a prefix or wildcard), and optional filter,
// region and endpoint settings.
type InputSource struct {
	Bucket   string
	Key      string
	Filter   string
	Region   string
	Endpoint string
}
// Validate checks the input source parameters; currently the only requirement
// is a non-empty bucket name.
func (s InputSource) Validate() error {
	if s.Bucket != "" {
		return nil
	}
	return fmt.Errorf("invalid s3 storage params: bucket cannot be empty")
}
// ToMap converts the InputSource into a map keyed by field name.
func (s InputSource) ToMap() map[string]interface{} {
	return structs.Map(s)
}
// DecodeInputSpec converts a generic SpecConfig into an S3 InputSource,
// validating both the spec type and the decoded parameters.
func DecodeInputSpec(spec *types.SpecConfig) (InputSource, error) {
	if !spec.IsType(types.StorageProviderS3) {
		return InputSource{}, fmt.Errorf("invalid storage source type. Expected %s but received %s", types.StorageProviderS3, spec.Type)
	}
	if spec.Params == nil {
		return InputSource{}, fmt.Errorf("invalid storage input source params. cannot be nil")
	}
	var source InputSource
	if err := mapstructure.Decode(spec.Params, &source); err != nil {
		return source, err
	}
	return source, source.Validate()
}
package s3
import (
"context"
"fmt"
"os"
"path/filepath"
"github.com/spf13/afero"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/s3"
basiccontroller "gitlab.com/nunet/device-management-service/storage/basic_controller"
"gitlab.com/nunet/device-management-service/types"
)
// Upload uploads all files (recursively) from a local volume to an S3 bucket.
// It handles directories.
//
// Warning: the implementation should rely on the FS provided by the volume controller,
// be careful if managing files with `os` (the volume controller might be
// using an in-memory one)
func (s *Storage) Upload(ctx context.Context, vol types.StorageVolume, destinationSpecs *types.SpecConfig) error {
	target, err := DecodeInputSpec(destinationSpecs)
	if err != nil {
		// %w (was %v) throughout so callers can unwrap with errors.Is/As.
		return fmt.Errorf("failed to decode input spec: %w", err)
	}
	sanitizedKey := sanitizeKey(target.Key)

	// Set file system to act upon based on the volume controller implementation.
	var fs afero.Fs
	if basicVolController, ok := s.volController.(*basiccontroller.BasicVolumeController); ok {
		fs = basicVolController.FS
	}
	if fs == nil {
		// Previously a nil fs was handed to afero.Walk when the controller was
		// not a BasicVolumeController, causing a nil-pointer panic.
		return fmt.Errorf("unsupported volume controller: no file system available")
	}

	zlog.Sugar().Debugf("Uploading files from %s to s3://%s/%s", vol.Path, target.Bucket, sanitizedKey)
	err = afero.Walk(fs, vol.Path, func(filePath string, info os.FileInfo, err error) error {
		if err != nil {
			return err
		}
		// Skip directories; only regular files are uploaded.
		if info.IsDir() {
			return nil
		}
		relPath, err := filepath.Rel(vol.Path, filePath)
		if err != nil {
			return fmt.Errorf("failed to get relative path: %w", err)
		}
		// Construct the S3 key by joining the sanitized key and the relative path.
		s3Key := filepath.Join(sanitizedKey, relPath)

		file, err := fs.Open(filePath)
		if err != nil {
			return fmt.Errorf("failed to open file: %w", err)
		}
		defer file.Close()

		zlog.Sugar().Debugf("Uploading %s to s3://%s/%s", filePath, target.Bucket, s3Key)
		_, err = s.uploader.Upload(ctx, &s3.PutObjectInput{
			Bucket: aws.String(target.Bucket),
			Key:    aws.String(s3Key),
			Body:   file,
		})
		if err != nil {
			return fmt.Errorf("failed to upload file to S3: %w", err)
		}
		return nil
	})
	if err != nil {
		return fmt.Errorf("upload failed. It's possible that some files were uploaded; Error: %w", err)
	}
	return nil
}
package types
import (
"errors"
)
// note: this data type may be moved to dms or jobs package in the future

// CapabilityAdder merges another Capability into the receiver.
type CapabilityAdder interface {
	Add(Capability) error
}

// CapabilitySubtractor removes another Capability's resources from the receiver.
type CapabilitySubtractor interface {
	Subtract(Capability) error
}

// CapabilityAddSubtractor combines both capability mutation operations.
type CapabilityAddSubtractor interface {
	CapabilityAdder
	CapabilitySubtractor
}

// Capability describes what a machine offers (or a job requires): executors,
// job types, hardware resources, libraries, locality, storage, connectivity,
// pricing, time constraints and KYC records.
type Capability struct {
	Executors    Executors          `json:"executor" description:"Executor type required for the job (docker, vm, wasm, or others)"`
	JobTypes     JobTypes           `json:"type" description:"Details about type of the job (One time, batch, recurring, long running). Refer to dms.jobs package for jobType data model"`
	Resources    ExecutionResources `json:"resources" description:"Resources required for the job"`
	Libraries    []Library          `json:"libraries" description:"Libraries required for the job"`
	Localities   []Locality         `json:"locality" description:"Preferred localities of the machine for execution"`
	Storage      []Storage          `json:"storage" description:"Preferred storage options that the machine should have"`
	Connectivity Connectivity       `json:"connectivity" description:"Network configuration required"`
	Price        []PriceInformation `json:"price" description:"Pricing information"`
	Time         TimeInformation    `json:"time" description:"Time constraints"`
	KYCs         []KYC
}

// Compile-time check that *Capability implements CapabilityAddSubtractor.
var _ CapabilityAddSubtractor = &Capability{}
// Connectivity describes the network requirements of a job.
type Connectivity struct {
	Ports []int `json:"ports" description:"Ports that need to be open for the job to run"`
	VPN   bool  `json:"vpn" description:"Whether VPN is required"`
}

// PriceInformation describes how a job is priced.
type PriceInformation struct {
	Currency        string `json:"currency" description:"Currency used for pricing"`
	CurrencyPerHour int    `json:"currency_per_hour" description:"Price charged per hour"`
	TotalPerJob     int    `json:"total_per_job" description:"Maximum total price or budget of the job"`
	Preference      int    `json:"preference" description:"Pricing preference"`
}

// TimeInformation describes the time constraints of a job.
type TimeInformation struct {
	Units      string `json:"units" description:"Time units"`
	MaxTime    int    `json:"max_time" description:"Maximum time that job should run"`
	Preference int    `json:"preference" description:"Time preference"`
}

// Library identifies a software library together with a version constraint.
type Library struct {
	Name       string `json:"name" description:"Name of the library"`
	Constraint string `json:"constraint" description:"Constraint of the library"`
	Version    string `json:"version" description:"Version of the library"`
}

// Locality names a region of a given kind (geographic, nunet-defined, etc).
type Locality struct {
	Kind string `json:"kind" description:"Kind of the region (geographic, nunet-defined, etc)"`
	Name string `json:"name" description:"Name of the region"`
}

// Storage describes a storage requirement or offering.
type Storage struct {
	Type   StorageType `json:"type" description:"Type of storage"`
	Size   int         `json:"size" description:"Size of storage"`
	Amount int         `json:"amount" description:"Amount of storage"`
}

// StorageType enumerates the supported kinds of storage media.
type StorageType string

const (
	//nolint
	SSD_STORAGE_TYPE StorageType = "ssd"
	//nolint
	HDD_STORAGE_TYPE StorageType = "hdd"
)

// KYC holds a know-your-customer record.
type KYC struct {
	Type string `json:"type" description:"Type of KYC"`
	Data string `json:"data" description:"Data required for KYC"`
}
// JobTypes is a list of JobType values.
type JobTypes []JobType

// JobType labels how a job is scheduled/run.
type JobType string

const (
	BATCH       JobType = "batch"
	SINGLERUN   JobType = "single_run"
	RECURRING   JobType = "recurring"
	LONGRUNNING JobType = "long_running"
)

// Named slice types so collection helpers (Contains, ...) can be defined on them.
type (
	Executors         []Executor
	Libraries         []Library
	Localities        []Locality
	KYCs              []KYC
	Storages          []Storage
	PricesInformation []PriceInformation
)
// Equal reports whether both libraries share the same name, constraint and version.
func (lib *Library) Equal(library Library) bool {
	return lib.Name == library.Name &&
		lib.Constraint == library.Constraint &&
		lib.Version == library.Version
}
// Equal reports whether both localities share the same kind and name.
func (loc *Locality) Equal(locality Locality) bool {
	return loc.Kind == locality.Kind && loc.Name == locality.Name
}
// Equal reports whether both executors share the same ExecutorType.
func (e *Executor) Equal(executor Executor) bool {
	return e.ExecutorType == executor.ExecutorType
}
// Equal reports whether both KYC records have identical type and data.
func (k *KYC) Equal(kyc KYC) bool {
	return k.Type == kyc.Type && k.Data == kyc.Data
}
// Equal reports whether every pricing field matches.
func (p *PriceInformation) Equal(price PriceInformation) bool {
	return p.Currency == price.Currency &&
		p.CurrencyPerHour == price.CurrencyPerHour &&
		p.TotalPerJob == price.TotalPerJob &&
		p.Preference == price.Preference
}
// Contains reports whether the list holds an executor equal to the given one.
func (es Executors) Contains(executor Executor) bool {
	for i := range es {
		if es[i].Equal(executor) {
			return true
		}
	}
	return false
}
// Contains reports whether jobType is present in the list.
func (j JobTypes) Contains(jobType JobType) bool {
	// Use a distinct loop variable; the original shadowed the receiver.
	for _, t := range j {
		if t == jobType {
			return true
		}
	}
	return false
}
// Contains reports whether the list holds a library equal to the given one.
func (l Libraries) Contains(library Library) bool {
	for i := range l {
		if l[i].Equal(library) {
			return true
		}
	}
	return false
}
// Contains reports whether the list holds a locality equal to the given one.
func (l Localities) Contains(locality Locality) bool {
	for i := range l {
		if l[i].Equal(locality) {
			return true
		}
	}
	return false
}
// Contains reports whether the list holds an entry of the same storage type.
// Note: only Type is compared — Size and Amount are ignored.
func (s Storages) Contains(storage Storage) bool {
	// Use a distinct loop variable; the original shadowed the receiver.
	for i := range s {
		if s[i].Type == storage.Type {
			return true
		}
	}
	return false
}
// Contains reports whether the list holds a KYC record equal to the given one.
func (k KYCs) Contains(kyc KYC) bool {
	// Use a distinct loop variable; the original shadowed the receiver.
	for i := range k {
		if k[i].Equal(kyc) {
			return true
		}
	}
	return false
}
// Contains reports whether the list holds a price entry equal to the given one.
func (ps PricesInformation) Contains(price PriceInformation) bool {
	for i := range ps {
		if ps[i].Equal(price) {
			return true
		}
	}
	return false
}
// Add merges the given Capability into the receiver: set-like fields
// (executors, job types, libraries, localities, ports, prices, KYCs) are
// unioned, numeric resources (CPU, memory, disk, max time) are summed, and
// GPUs are appended without deduplication. It always returns nil.
func (c *Capability) Add(cap Capability) error {
	// Executors: union.
	for _, executor := range cap.Executors {
		if !c.Executors.Contains(executor) {
			c.Executors = append(c.Executors, executor)
		}
	}
	// JobTypes: union.
	for _, jobType := range cap.JobTypes {
		if !c.JobTypes.Contains(jobType) {
			c.JobTypes = append(c.JobTypes, jobType)
		}
	}
	// Resources: CPU figures are only summed when the architectures match.
	if c.Resources.CPU.Architecture == cap.Resources.CPU.Architecture {
		c.Resources.CPU.Cores += cap.Resources.CPU.Cores
		c.Resources.CPU.ClockSpeedHz += cap.Resources.CPU.ClockSpeedHz
	}
	c.Resources.Memory.ClockSpeedHz += cap.Resources.Memory.ClockSpeedHz // does it make sense?
	c.Resources.Memory.Size += cap.Resources.Memory.Size
	// Disk sizes are only summed when the disk types match.
	if c.Resources.Disk.Type == cap.Resources.Disk.Type {
		c.Resources.Disk.Size += cap.Resources.Disk.Size
	}
	c.Resources.GPUs = append(c.Resources.GPUs, cap.Resources.GPUs...)
	// Libraries: union.
	var myLibraries Libraries = c.Libraries
	for _, library := range cap.Libraries {
		if !myLibraries.Contains(library) {
			c.Libraries = append(c.Libraries, library)
		}
	}
	// Localities: union.
	var myLocalities Localities = c.Localities
	for _, locality := range cap.Localities {
		if !myLocalities.Contains(locality) {
			c.Localities = append(c.Localities, locality)
		}
	}
	// Storage: new types are appended; for existing types Size/Amount are
	// added onto every entry of that type.
	var myStorages Storages = c.Storage
	for _, storage := range cap.Storage {
		if !myStorages.Contains(storage) {
			c.Storage = append(c.Storage, storage)
		} else {
			for i, s := range c.Storage {
				if s.Type == storage.Type {
					c.Storage[i].Size += storage.Size
					c.Storage[i].Amount += storage.Amount
				}
			}
		}
	}
	// Connectivity: VPN is sticky once either side requires it; ports are unioned.
	if cap.Connectivity.VPN {
		c.Connectivity.VPN = true
	}
	for _, port := range cap.Connectivity.Ports {
		if !sliceContainsInt(c.Connectivity.Ports, port) {
			c.Connectivity.Ports = append(c.Connectivity.Ports, port)
		}
	}
	// Price: union.
	var myPrice PricesInformation = c.Price
	for _, price := range cap.Price {
		if !myPrice.Contains(price) {
			c.Price = append(c.Price, price)
		}
	}
	// Time: maximum run times are summed.
	c.Time.MaxTime += cap.Time.MaxTime
	// KYCs: union.
	var myKYCs KYCs = c.KYCs
	for _, kyc := range cap.KYCs {
		if !myKYCs.Contains(kyc) {
			c.KYCs = append(c.KYCs, kyc)
		}
	}
	return nil
}
// Subtract subtracts the resources of the given Capability from the current
// Capability. It returns an error (without mutating the receiver) when CPU,
// memory or disk would go negative. Fields with no subtract semantics
// (Executors, JobTypes, Libraries, Localities, Price, KYCs) are left untouched.
func (c *Capability) Subtract(cap Capability) error {
	// Executors / JobTypes: no Subtract operation.
	// Resources: refuse subtraction that would underflow.
	if c.Resources.CPU.Cores < cap.Resources.CPU.Cores || c.Resources.CPU.ClockSpeedHz < cap.Resources.CPU.ClockSpeedHz {
		return errors.New("cpu resources are not enough")
	}
	if c.Resources.Memory.Size < cap.Resources.Memory.Size {
		return errors.New("memory resources are not enough")
	}
	if c.Resources.Disk.Size < cap.Resources.Disk.Size {
		return errors.New("disk resources are not enough")
	}
	if c.Resources.CPU.Architecture == cap.Resources.CPU.Architecture {
		c.Resources.CPU.Cores -= cap.Resources.CPU.Cores
		c.Resources.CPU.Threads -= cap.Resources.CPU.Threads
		c.Resources.CPU.ClockSpeedHz -= cap.Resources.CPU.ClockSpeedHz
	}
	c.Resources.Memory.ClockSpeedHz -= cap.Resources.Memory.ClockSpeedHz // does it make sense?
	c.Resources.Memory.Size -= cap.Resources.Memory.Size
	if c.Resources.Disk.Type == cap.Resources.Disk.Type {
		c.Resources.Disk.Size -= cap.Resources.Disk.Size
	}
	// Remove each requested GPU from the current Capability. At most one
	// match is removed per requested GPU, and we break immediately after a
	// removal: the old code kept ranging over the slice it had just mutated,
	// which walks stale indices.
	for _, gpu := range cap.Resources.GPUs {
		for i := range c.Resources.GPUs {
			//nolint
			if c.Resources.GPUs[i].Equal(&gpu) {
				c.Resources.GPUs = append(c.Resources.GPUs[:i], c.Resources.GPUs[i+1:]...)
				break
			}
		}
	}
	// Libraries / Localities: no Subtract operation.
	// Storage: an exact (type, size, amount) match removes the entry; a
	// same-type entry is reduced. Only the first entry of a matching type is
	// handled, then we stop — continuing after a removal would walk shifted
	// indices (same bug as above).
	for _, storage := range cap.Storage {
		for i := range c.Storage {
			if c.Storage[i].Type != storage.Type {
				continue
			}
			if c.Storage[i].Amount == storage.Amount && c.Storage[i].Size == storage.Size {
				c.Storage = append(c.Storage[:i], c.Storage[i+1:]...)
			} else {
				c.Storage[i].Size -= storage.Size
				c.Storage[i].Amount -= storage.Amount
			}
			break
		}
	}
	// Connectivity: drop each requested port once.
	for _, port := range cap.Connectivity.Ports {
		for i, cport := range c.Connectivity.Ports {
			if cport == port {
				c.Connectivity.Ports = append(c.Connectivity.Ports[:i], c.Connectivity.Ports[i+1:]...)
				break
			}
		}
	}
	// Price: no Subtract operation.
	// Time
	c.Time.MaxTime -= cap.Time.MaxTime
	// KYCs: no Subtract operation.
	return nil
}
// sliceContainsInt reports whether val occurs in s.
func sliceContainsInt(s []int, val int) bool {
	for i := range s {
		if s[i] == val {
			return true
		}
	}
	return false
}
package types
// Comparison expresses how one object ranks against another.
type Comparison string

const (
	Worse  Comparison = "Worse"  // left object is 'worse' than right object
	Better Comparison = "Better" // left object is 'better' than right object
	Equal  Comparison = "Equal"  // objects on the left and right are 'equally good'
	Error  Comparison = "Error"  // error in comparison or objects incomparable
)

// TODO: Consider comments in this thread: https://gitlab.com/nunet/device-management-service/-/merge_requests/356#note_1997854443
// TODO: Consider comments in this thread: https://gitlab.com/nunet/device-management-service/-/merge_requests/356#note_1997875361
// 'left' means 'this object' and 'right' means 'the supplied other object';
// it makes sense when using the type in functions like Compare(left, right)
// this type is unused but still reserved in case we will need it in the future
// and still used in some tests that are left from previous versions of the package;
// TODO: remove / update after the package is finished and refactor
type ComplexComparison map[string]Comparison

// And returns the result of AND operation of two Comparison values
// it respects the following table of truth:
// | AND    | Better | Worse  | Equal  | Error  |
// | ------ | ------ |--------|--------|--------|
// | Better | Better | Worse  | Better | Error  |
// | Worse  | Worse  | Worse  | Worse  | Error  |
// | Equal  | Better | Worse  | Equal  | Error  |
// | Error  | Error  | Error  | Error  | Error  |
func (c Comparison) And(cmp Comparison) Comparison {
	// Error is absorbing in either operand.
	if c == Error || cmp == Error {
		return Error
	}
	// Identical operands AND to themselves.
	if c == cmp {
		return c
	}
	switch c {
	case Equal:
		switch cmp {
		case Better:
			return Better
		case Worse:
			return Worse
		default:
			return Error
		}
	case Better:
		switch cmp {
		case Worse:
			return Worse
		case Equal:
			return Better
		default:
			return Error
		}
	case Worse:
		// cmp cannot be Error here (handled at the top), so the result is
		// always Worse. The previous `if cmp == Error` branch was unreachable
		// dead code and has been removed.
		return Worse
	default:
		return Error
	}
}
package types
// Executor identifies the execution engine used to run a job.
type Executor struct {
	ExecutorType ExecutorType `json:"executor_type"`
}

// ExecutorType names an execution engine (e.g. "docker").
type ExecutorType string

const (
	ExecutorTypeDocker      = "docker"
	ExecutorTypeFirecracker = "firecracker"
	ExecutorTypeWasm        = "wasm"
	// ExecutionStatusCodeSuccess is the exit code of a successful execution.
	ExecutionStatusCodeSuccess = 0
)

// ExecutionRequest is the request object for executing a job
type ExecutionRequest struct {
	JobID       string                   // ID of the job to execute
	ExecutionID string                   // ID of the execution
	EngineSpec  *SpecConfig              // Engine spec for the execution
	Resources   *ExecutionResources      // Resources for the execution
	Inputs      []*StorageVolumeExecutor // Input volumes for the execution
	Outputs     []*StorageVolumeExecutor // Output volumes for the results
	ResultsDir  string                   // Directory to store the results
}
// ExecutionResult is the result of an execution
type ExecutionResult struct {
	STDOUT   string `json:"stdout"`    // STDOUT of the execution
	STDERR   string `json:"stderr"`    // STDERR of the execution
	ExitCode int    `json:"exit_code"` // Exit code of the execution
	ErrorMsg string `json:"error_msg"` // Error message if the execution failed
}

// NewExecutionResult creates an ExecutionResult with the given exit code and
// empty output streams.
func NewExecutionResult(code int) *ExecutionResult {
	result := ExecutionResult{ExitCode: code}
	return &result
}

// NewFailedExecutionResult creates an ExecutionResult for a failed execution.
// The error message is taken from the provided error and the exit code is
// forced to -1.
func NewFailedExecutionResult(err error) *ExecutionResult {
	result := ExecutionResult{ExitCode: -1}
	result.ErrorMsg = err.Error()
	return &result
}
// LogStreamRequest is the request object for streaming logs from an execution
type LogStreamRequest struct {
	JobID       string // ID of the job
	ExecutionID string // ID of the execution
	Tail        bool   // Tail the logs
	Follow      bool   // Follow the logs
}
package types
const (
	// NetP2P labels the p2p network implementation.
	NetP2P = "p2p"
)

// NetworkSpec is a stub. Please expand based on requirements.
type NetworkSpec struct{}

// NetConfig is a stub. Please expand it or completely change it based on requirements.
type NetConfig struct {
	NetworkSpec SpecConfig `json:"network_spec"` // Network specification
}

// GetNetworkConfig returns a pointer to the embedded network SpecConfig.
func (nc *NetConfig) GetNetworkConfig() *SpecConfig {
	return &nc.NetworkSpec
}

// NetworkStats should contain all network info the user is interested in.
// for now there's only peerID and listening address but reachability, local and remote addr etc...
// can be added when necessary.
type NetworkStats struct {
	ID         string `json:"id"`
	ListenAddr string `json:"listen_addr"`
}

// MessageInfo is a stub. Please expand it or completely change it based on requirements.
type MessageInfo struct {
	Info string `json:"info"` // Message information
}
package types
import (
"fmt"
"strings"
)
// GPUVendor identifies a GPU manufacturer.
type GPUVendor string

const (
	GPUVendorNvidia  GPUVendor = "NVIDIA"
	GPUVendorAMDATI  GPUVendor = "AMD/ATI"
	GPUVendorIntel   GPUVendor = "Intel"
	GPUVendorUnknown GPUVendor = "Unknown"
	None             GPUVendor = "None"
)

// ParseGPUVendor maps a raw vendor string to a GPUVendor. Matching is
// case-sensitive substring search; anything mentioning neither NVIDIA,
// AMD/ATI nor Intel maps to GPUVendorUnknown.
func ParseGPUVendor(vendor string) GPUVendor {
	if strings.Contains(vendor, "NVIDIA") {
		return GPUVendorNvidia
	}
	if strings.Contains(vendor, "AMD") || strings.Contains(vendor, "ATI") {
		return GPUVendorAMDATI
	}
	if strings.Contains(vendor, "Intel") {
		return GPUVendorIntel
	}
	return GPUVendorUnknown
}
// GPU describes a single GPU device discovered on the machine.
type GPU struct {
	// Index is the self-reported index of the device in the system
	Index int
	// Name is the model name of the GPU e.g. Tesla T4
	Name string
	// Vendor is the maker of the GPU, e.g. NVidia, AMD, Intel
	Vendor GPUVendor
	// PCIAddress is the PCI address of the device, in the format AAAA:BB:CC.C
	// Used to discover the correct device rendering cards
	PCIAddress string
	// Model of the GPU, e.g. A100
	Model string `json:"model" description:"GPU model, ex A100"`
	// TotalVRAM is the total amount of VRAM on the device
	TotalVRAM uint64
	// UsedVRAM is the amount of VRAM currently in use
	UsedVRAM uint64
	// FreeVRAM is the amount of VRAM currently free
	FreeVRAM uint64
	// Gorm fields
	// Team, is this the right way to do this? What is the best practice we're following?
	ResourceID uint `gorm:"foreignKey:ID"`
}
// Equal reports whether the two GPUs match on index, vendor, PCI address,
// model and all VRAM figures. Note: Name is not part of the comparison.
func (g *GPU) Equal(gpu *GPU) bool {
	return g.Index == gpu.Index &&
		g.Vendor == gpu.Vendor &&
		g.PCIAddress == gpu.PCIAddress &&
		g.Model == gpu.Model &&
		g.TotalVRAM == gpu.TotalVRAM &&
		g.UsedVRAM == gpu.UsedVRAM &&
		g.FreeVRAM == gpu.FreeVRAM
}
// GPUList is a collection of GPUs attached to the machine.
type GPUList []GPU

// GetGPUWithHighestFreeVRAM Determine the GPU vendor with the highest free VRAM: NVIDIA, AMD, or Intel.
// Useful for selecting the best GPU if multiple vendors are available,
// especially in multi-GPU systems or mining rigs.
func (gpus GPUList) GetGPUWithHighestFreeVRAM() (GPU, error) {
	// No GPUs detected: report Vendor None so CPU-only containers can still
	// be launched.
	if len(gpus) == 0 {
		return GPU{Vendor: None}, nil
	}
	var best GPU
	var bestFree uint64
	for i := range gpus {
		if gpus[i].FreeVRAM > bestFree {
			bestFree = gpus[i].FreeVRAM
			best = gpus[i]
		}
	}
	// Note: if every GPU reports zero FreeVRAM the zero-value GPU is returned
	// (same as the original implementation).
	return best, nil
}
// negativeValueError is a type struct used to return a custom error for negative values in resources subtraction
type negativeValueError struct {
	resource string // name of the resource that would underflow, e.g. "CPU"
	r1       any    // left operand of the subtraction
	r2       any    // right operand of the subtraction
}

// Error returns the error message.
// The operands are formatted with %v: r1/r2 are `any` and hold float64 for
// CPU, so the previous %d verb produced broken output like "%!d(float64=1.5)".
func (e *negativeValueError) Error() string {
	return fmt.Sprintf("Error: %s subtraction results in negative values. (%v - %v)", e.resource, e.r1, e.r2)
}
// ResourceOps defines the operations on resources
// TODO: Check how to handle GPU resources
type ResourceOps interface {
	// Add returns the sum of the resources
	Add(r Resources) Resources
	// Subtract returns the difference of the resources
	Subtract(r Resources) (Resources, error)
}

// Resources represents the resources of the machine
type Resources struct {
	CPU      float64 // CPU capacity figure — NOTE(review): unit not defined here, confirm against producers
	NumCores int     // number of CPU cores
	GPU      []GPU   `gorm:"foreignKey:ResourceID"`
	RAM      uint64  // RAM amount — presumably bytes; TODO confirm against callers
	Disk     uint64  // disk amount — presumably bytes; TODO confirm against callers
}
// Add returns the sum of the resources.
// Only CPU, RAM and Disk are summed; NumCores and GPU are not aggregated.
func (r Resources) Add(r2 Resources) Resources {
	// TODO: GPU addition
	var sum Resources
	sum.CPU = r.CPU + r2.CPU
	sum.RAM = r.RAM + r2.RAM
	sum.Disk = r.Disk + r2.Disk
	return sum
}
// Subtract returns the difference of the resources, or a negativeValueError
// when any of CPU, RAM or Disk would go below zero.
// Team, why do we need to return a negative value error?
// Can't we just return the negative value indicating that the machine is overused?
// We can then handle the scenario in the calling function by checking if the result is negative if required.
func (r Resources) Subtract(r2 Resources) (Resources, error) {
	switch {
	case r.CPU < r2.CPU:
		return Resources{}, &negativeValueError{resource: "CPU", r1: r.CPU, r2: r2.CPU}
	case r.RAM < r2.RAM:
		return Resources{}, &negativeValueError{resource: "RAM", r1: r.RAM, r2: r2.RAM}
	case r.Disk < r2.Disk:
		return Resources{}, &negativeValueError{resource: "Disk", r1: r.Disk, r2: r2.Disk}
	}
	// TODO: GPU subtraction
	diff := Resources{
		CPU:  r.CPU - r2.CPU,
		RAM:  r.RAM - r2.RAM,
		Disk: r.Disk - r2.Disk,
	}
	return diff, nil
}

// Compile-time check that Resources implements ResourceOps.
var _ ResourceOps = (*Resources)(nil)
// MachineResources represents the total resources of the machine
type MachineResources struct {
	BaseDBModel
	Resources
}

// FreeResources represents the free resources of the machine
type FreeResources struct {
	BaseDBModel
	Resources
}

// OnboardedResources represents the onboarded resources of the machine
type OnboardedResources struct {
	BaseDBModel
	Resources
}

// RequiredResources represents the resources required by the jobs running on the machine
// TODO: this is a replacement for ServiceResourceRequirements. Check with the team on this.
type RequiredResources struct {
	BaseDBModel
	JobID int
	Resources
}

// CPUInfo represents the CPU information of the machine
type CPUInfo struct {
	NumCores   uint64  // number of CPU cores
	MHzPerCore float64 // per-core clock speed in MHz
	Compute    float64 // aggregate compute figure — NOTE(review): derivation not shown here, confirm with producers
}

// SpecInfo represents the machine specifications
// TODO: Finalise the fields required in this struct
// https://gitlab.com/nunet/device-management-service/-/issues/533
type SpecInfo struct {
	CPUs    []CPU
	GPUs    []GPU
	RAMs    []RAM
	Disks   []Disk
	Network NetworkInfo
}
// CPU represents the CPU information
type CPU struct {
	// Model represents the CPU model, e.g., "Intel Core i7-9700K", "AMD Ryzen 9 5900X"
	Model string
	// Vendor represents the CPU manufacturer, e.g., "Intel", "AMD"
	Vendor string
	// ClockSpeedHz represents the CPU clock speed in Hz
	ClockSpeedHz int64
	// Cores represents the number of physical CPU cores
	Cores uint32
	// Threads represents the number of logical CPU threads (including hyperthreading)
	Threads int
	// Architecture represents the CPU architecture, e.g., "x86", "x86_64", "arm64"
	Architecture string
	// CacheSize is the cache size in bytes
	CacheSize uint64
}

// RAM represents the RAM information
type RAM struct {
	// Size in bytes
	Size int64
	// ClockSpeedHz is the clock speed in Hz
	ClockSpeedHz uint64
	// Type represents the RAM type, e.g., "DDR4", "DDR5", "LPDDR4"
	Type string
}

// Disk represents the disk information
type Disk struct {
	// Model represents the disk model, e.g., "Samsung 970 EVO Plus", "Western Digital Blue SN550"
	// TODO: may be removed as Disk models will be usually irrelevant, right?
	Model string
	// Vendor represents the disk manufacturer, e.g., "Samsung", "Western Digital"
	// TODO: may be removed as Disk vendors will be usually irrelevant, right?
	Vendor string
	// Size in bytes
	Size uint64
	// Type represents the disk type, e.g., "SSD", "HDD", "NVMe"
	Type string
	// Interface represents the disk interface, e.g., "SATA", "PCIe", "M.2"
	Interface string
	// ReadSpeed is the read speed in bytes per second
	// TODO: may be removed as it may be too specific for our case
	ReadSpeed uint64
	// WriteSpeed is the write speed in bytes per second
	// TODO: may be removed as it may be too specific for our case
	WriteSpeed uint64
}

// NetworkInfo represents the network information
type NetworkInfo struct {
	// Bandwidth in bits per second (b/s)
	Bandwidth uint64
	// NetworkType represents the network type, e.g., "Ethernet", "Wi-Fi", "Cellular"
	NetworkType string
}

// ExecutionResources represents the resources required to execute a task
type ExecutionResources struct {
	// CPU configuration
	CPU CPU `json:"cpu,omitempty" description:"CPU configuration"`
	// Memory configuration
	Memory RAM `json:"memory,omitempty" description:"Memory configuration"`
	// Disk configuration
	Disk Disk `json:"disk,omitempty" description:"Disk configuration"`
	// GPU configuration
	GPUs []GPU `json:"gpus,omitempty" description:"GPU configuration"`
}
package types
import (
"errors"
"strings"
"gitlab.com/nunet/device-management-service/utils/validate"
)
// SpecConfig represents a configuration for a spec
// A SpecConfig can be used to define an engine spec, a storage volume, etc.
type SpecConfig struct {
	// Type of the spec (e.g. docker, firecracker, storage, etc.)
	Type string `json:"type"`
	// Params of the spec
	Params map[string]interface{} `json:"params,omitempty"`
}

// Config exposes access to a network SpecConfig.
type Config interface {
	GetNetworkConfig() *SpecConfig
}
// NewSpecConfig creates a new SpecConfig with the given type and an
// initialized (empty) Params map.
func NewSpecConfig(t string) *SpecConfig {
	spec := SpecConfig{Type: t}
	spec.Params = map[string]interface{}{}
	return &spec
}
// WithParam adds a new key-value pair to the spec params and returns the
// receiver so calls can be chained. A nil Params map is lazily created.
func (s *SpecConfig) WithParam(key string, value interface{}) *SpecConfig {
	if s.Params == nil {
		s.Params = map[string]interface{}{}
	}
	s.Params[key] = value
	return s
}
// Normalize ensures that the spec config is in a valid state: the type is
// trimmed and Params is never nil. Calling on a nil receiver is a no-op.
func (s *SpecConfig) Normalize() {
	if s == nil {
		return
	}
	s.Type = strings.TrimSpace(s.Type)
	// Treat an empty map and a nil map identically.
	if len(s.Params) == 0 {
		s.Params = map[string]interface{}{}
	}
}
// Validate checks if the spec config is valid: it must be non-nil and carry
// a non-blank type.
func (s *SpecConfig) Validate() error {
	switch {
	case s == nil:
		return errors.New("nil spec config")
	case validate.IsBlank(s.Type):
		return errors.New("missing spec type")
	default:
		return nil
	}
}
// IsType returns true if the current SpecConfig is of the given type.
// The comparison is case-insensitive and t is trimmed first.
func (s *SpecConfig) IsType(t string) bool {
	if s == nil {
		return false
	}
	return strings.EqualFold(s.Type, strings.TrimSpace(t))
}
// IsEmpty returns true if the spec config is nil, or has a blank type and no params.
func (s *SpecConfig) IsEmpty() bool {
	if s == nil {
		return true
	}
	return validate.IsBlank(s.Type) && len(s.Params) == 0
}
package types
import (
"log"
"os"
)
// CollectorConfig holds the endpoint configuration of a single telemetry collector.
type CollectorConfig struct {
	CollectorType     string // collector kind, e.g. "OPENTELEMETRY" or "LOG"
	CollectorEndpoint string // endpoint for this collector
}

// TelemetryConfig aggregates the telemetry settings loaded from the environment.
type TelemetryConfig struct {
	ServiceName        string                     // value of SERVICE_NAME
	GlobalEndpoint     string                     // value of COLLECTOR_ENDPOINT
	ObservabilityLevel int                        // parsed OBSERVABILITY_LEVEL (see ObservabilityLevel constants)
	CollectorConfigs   map[string]CollectorConfig // per-collector endpoints, keyed by collector type
}
// LoadConfigFromEnv builds a TelemetryConfig from environment variables.
// It reads SERVICE_NAME, COLLECTOR_ENDPOINT, OBSERVABILITY_LEVEL and the
// per-collector COLLECTOR_<TYPE>_ENDPOINT variables. The error result is
// always nil (kept for interface stability).
func LoadConfigFromEnv() (*TelemetryConfig, error) {
	rawLevel := os.Getenv("OBSERVABILITY_LEVEL")
	observability := parseObservabilityLevel(rawLevel)

	// Collector-specific configs follow the COLLECTOR_<TYPE>_ENDPOINT convention.
	perCollector := map[string]CollectorConfig{}
	for _, kind := range []string{"OPENTELEMETRY", "LOG"} {
		if endpoint := os.Getenv("COLLECTOR_" + kind + "_ENDPOINT"); endpoint != "" {
			perCollector[kind] = CollectorConfig{
				CollectorType:     kind,
				CollectorEndpoint: endpoint,
			}
		}
	}

	cfg := &TelemetryConfig{
		ServiceName:        os.Getenv("SERVICE_NAME"),
		GlobalEndpoint:     os.Getenv("COLLECTOR_ENDPOINT"),
		ObservabilityLevel: observability,
		CollectorConfigs:   perCollector,
	}
	// Debug: log the loaded environment variables.
	log.Printf("Loaded environment variables: SERVICE_NAME=%s, COLLECTOR_ENDPOINT=%s, OBSERVABILITY_LEVEL=%s", cfg.ServiceName, cfg.GlobalEndpoint, rawLevel)
	return cfg, nil
}
// parseObservabilityLevel converts a level name ("TRACE".."FATAL") to its
// numeric value; unrecognized input is logged and defaults to INFO.
func parseObservabilityLevel(levelStr string) int {
	var level ObservabilityLevel
	switch levelStr {
	case "TRACE":
		level = TRACE
	case "DEBUG":
		level = DEBUG
	case "INFO":
		level = INFO
	case "WARN":
		level = WARN
	case "ERROR":
		level = ERROR
	case "FATAL":
		level = FATAL
	default:
		log.Printf("Invalid OBSERVABILITY_LEVEL: %s, defaulting to INFO", levelStr)
		level = INFO
	}
	return int(level)
}

// ObservabilityLevel defines levels of observability.
type ObservabilityLevel int

// Constants representing levels of observability.
const (
	TRACE ObservabilityLevel = 1
	DEBUG ObservabilityLevel = 2
	INFO  ObservabilityLevel = 3
	WARN  ObservabilityLevel = 4
	ERROR ObservabilityLevel = 5
	FATAL ObservabilityLevel = 6
)
package types
import (
"time"
"github.com/google/uuid"
"gorm.io/gorm"
)
// BaseDBModel is a base model for all entities. It'll be mainly used for database
// records.
type BaseDBModel struct {
	ID        string `gorm:"type:uuid"` // UUID string, assigned in BeforeCreate
	CreatedAt time.Time
	UpdatedAt time.Time
	DeletedAt gorm.DeletedAt `gorm:"index"` // soft-delete marker
}

// GetID returns the ID of the entity.
func (m BaseDBModel) GetID() string {
	return m.ID
}

// BeforeCreate sets the ID and CreatedAt fields before creating a new entity.
// This is a GORM hook and should not be called directly.
// We can move this to generic repository create methods.
func (m *BaseDBModel) BeforeCreate(_ *gorm.DB) error {
	m.ID = uuid.NewString()
	m.CreatedAt = time.Now()
	return nil
}

// BeforeUpdate sets the UpdatedAt field before updating an entity.
// This is a GORM hook and should not be called directly.
// We can move this to generic repository create methods.
func (m *BaseDBModel) BeforeUpdate(_ *gorm.DB) error {
	m.UpdatedAt = time.Now()
	return nil
}
package utils
import (
"bytes"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"time"
"github.com/cosmos/btcutil/bech32"
"github.com/ethereum/go-ethereum/common"
"github.com/fivebinaries/go-cardano-serialization/address"
"gitlab.com/nunet/device-management-service/db"
"gitlab.com/nunet/device-management-service/types"
)
// KoiosEndpoint type for Koios rest api endpoints
type KoiosEndpoint string

const (
	// KoiosMainnet - mainnet Koios rest api endpoint
	KoiosMainnet KoiosEndpoint = "api.koios.rest"
	// KoiosPreProd - testnet preprod Koios rest api endpoint
	KoiosPreProd KoiosEndpoint = "preprod.koios.rest"
)

// UTXOs describes a single unspent-transaction-output entry as returned by Koios.
type UTXOs struct {
	TxHash  string `json:"tx_hash"`
	IsSpent bool   `json:"is_spent"`
}

// TxHashResp is a transaction summary returned to API clients.
type TxHashResp struct {
	TxHash          string `json:"tx_hash"`
	TransactionType string `json:"transaction_type"`
	DateTime        string `json:"date_time"`
}

// ClaimCardanoTokenBody is the request body for claiming a reward.
type ClaimCardanoTokenBody struct {
	ComputeProviderAddress string `json:"compute_provider_address"`
	TxHash                 string `json:"tx_hash"`
}

// RewardRespToCPD carries the datum/signature material the compute provider
// daemon needs to claim a reward.
type RewardRespToCPD struct {
	ServiceProviderAddr string `json:"service_provider_addr"`
	ComputeProviderAddr string `json:"compute_provider_addr"`
	RewardType          string `json:"reward_type,omitempty"`
	SignatureDatum      string `json:"signature_datum,omitempty"`
	MessageHashDatum    string `json:"message_hash_datum,omitempty"`
	Datum               string `json:"datum,omitempty"`
	SignatureAction     string `json:"signature_action,omitempty"`
	MessageHashAction   string `json:"message_hash_action,omitempty"`
	Action              string `json:"action,omitempty"`
}

// UpdateTxStatusBody is the request body for refreshing transaction statuses.
type UpdateTxStatusBody struct {
	Address string `json:"address,omitempty"`
}
// GetJobTxHashes returns transaction hashes of deployed jobs, optionally
// purging records of a given transaction type first.
//
// size limits how many "done" transactions are included (0 means no limit);
// clean must be one of "done", "refund", "withdraw" or "" (no cleanup).
func GetJobTxHashes(size int, clean string) ([]TxHashResp, error) {
	if clean != "done" && clean != "refund" && clean != "withdraw" && clean != "" {
		return nil, fmt.Errorf("invalid clean_tx parameter")
	}
	// Only delete when a clean target was actually provided; running the
	// delete with an empty clean would remove rows whose transaction_type is
	// the empty string.
	if clean != "" {
		if err := db.DB.Where("transaction_type = ?", clean).Delete(&types.Services{}).Error; err != nil {
			// %v, not %w: zap's SugaredLogger.Errorf does plain fmt
			// formatting and does not support the %w wrapping verb.
			zlog.Sugar().Errorf("%v", err)
		}
	}
	resp := make([]TxHashResp, 0)
	services := make([]types.Services, 0)
	var err error
	if size == 0 {
		err = db.DB.
			Where("tx_hash IS NOT NULL").
			Where("log_url LIKE ?", "%log.nunet.io%").
			Where("transaction_type is NOT NULL").
			Find(&services).Error
		if err != nil {
			zlog.Sugar().Errorf("%v", err)
			return nil, fmt.Errorf("no job deployed to request reward for: %w", err)
		}
	} else {
		services, err = getLimitedTransactions(size)
		if err != nil {
			zlog.Sugar().Errorf("%v", err)
			return nil, fmt.Errorf("could not get limited transactions: %w", err)
		}
	}
	for _, service := range services {
		resp = append(resp, TxHashResp{
			TxHash:          service.TxHash,
			TransactionType: service.TransactionType,
			DateTime:        service.CreatedAt.String(),
		})
	}
	return resp, nil
}
// RequestReward looks up the service identified by the claim's tx hash and
// returns the datum/signature material the compute provider needs to claim
// its reward. It fails when the tx hash is unknown or the job is still running.
func RequestReward(claim ClaimCardanoTokenBody) (*RewardRespToCPD, error) {
	// At some point, management dashboard should send container ID to identify
	// against which container we are requesting reward
	service := types.Services{
		TxHash: claim.TxHash,
	}
	// SELECTs the first record; first record which is not marked as delete
	err := db.DB.Where("tx_hash = ?", claim.TxHash).Find(&service).Error
	if err != nil {
		zlog.Sugar().Errorln(err)
		return nil, fmt.Errorf("unknown tx hash: %w", err)
	}
	zlog.Sugar().Infof("service found from txHash: %+v", service)
	if service.JobStatus == "running" {
		return nil, fmt.Errorf("job is still running")
		// c.JSON(503, gin.H{"error": "the job is still running"})
	}
	// Copy the stored on-chain artefacts into the response.
	reward := RewardRespToCPD{
		ServiceProviderAddr: service.ServiceProviderAddr,
		ComputeProviderAddr: service.ComputeProviderAddr,
		RewardType:          service.TransactionType,
		SignatureDatum:      service.SignatureDatum,
		MessageHashDatum:    service.MessageHashDatum,
		Datum:               service.Datum,
		SignatureAction:     service.SignatureAction,
		MessageHashAction:   service.MessageHashAction,
		Action:              service.Action,
	}
	return &reward, nil
}
// SendStatus records the outcome of a withdraw transaction: on "success" the
// matching service row is marked as done. The incoming status string is
// returned unchanged.
func SendStatus(status types.BlockchainTxStatus) string {
	if status.TransactionStatus == "success" {
		zlog.Sugar().Infof("withdraw transaction successful - updating DB")
		// Partial deletion of entry
		var service types.Services
		err := db.DB.Where("tx_hash = ?", status.TxHash).Find(&service).Error
		if err != nil {
			zlog.Sugar().Errorln(err)
		}
		// NOTE(review): the Find error above does not abort the update, and
		// the Save error is discarded — confirm this best-effort behavior is
		// intentional.
		service.TransactionType = "done"
		db.DB.Save(&service)
	}
	return status.TransactionStatus
}
// UpdateStatus refreshes the transaction status of eligible services by
// comparing DB records against the smart contract's UTXOs (fetched from the
// Koios preprod endpoint). Only rows older than five minutes that are not
// already marked done (and have a non-empty transaction type) are considered.
func UpdateStatus(body UpdateTxStatusBody) error {
	utxoHashes, err := GetUTXOsOfSmartContract(body.Address, KoiosPreProd)
	if err != nil {
		zlog.Sugar().Errorln(err)
		return fmt.Errorf("failed to fetch UTXOs from Blockchain: %w", err)
	}
	// Give recent transactions time to settle before re-checking them.
	fiveMinAgo := time.Now().Add(-5 * time.Minute)
	var services []types.Services
	err = db.DB.
		Where("tx_hash IS NOT NULL").
		Where("log_url LIKE ?", "%log.nunet.io%").
		Where("transaction_type IS NOT NULL").
		Where("deleted_at IS NULL").
		Where("created_at <= ?", fiveMinAgo).
		Not("transaction_type = ?", "done").
		Not("transaction_type = ?", "").
		Find(&services).Error
	if err != nil {
		zlog.Sugar().Errorln(err)
		return fmt.Errorf("no job deployed to request reward for: %w", err)
	}
	err = UpdateTransactionStatus(services, utxoHashes)
	if err != nil {
		zlog.Sugar().Errorln(err)
		return fmt.Errorf("failed to update transaction status")
	}
	return nil
}
// getLimitedTransactions returns all pending transactions (non-done,
// non-empty transaction type) plus at most sizeDone of the most recent
// "done" transactions, appended at the end.
func getLimitedTransactions(sizeDone int) ([]types.Services, error) {
	var doneServices []types.Services
	var services []types.Services
	// Most recent "done" transactions, capped at sizeDone.
	err := db.DB.
		Where("tx_hash IS NOT NULL").
		Where("log_url LIKE ?", "%log.nunet.io%").
		Where("transaction_type = ?", "done").
		Order("created_at DESC").
		Limit(sizeDone).
		Find(&doneServices).Error
	if err != nil {
		return []types.Services{}, err
	}
	// All remaining (pending) transactions, uncapped.
	err = db.DB.
		Where("tx_hash IS NOT NULL").
		Where("log_url LIKE ?", "%log.nunet.io%").
		Where("transaction_type IS NOT NULL").
		Not("transaction_type = ?", "done").
		Not("transaction_type = ?", "").
		Find(&services).Error
	if err != nil {
		return []types.Services{}, err
	}
	services = append(services, doneServices...)
	return services, nil
}
// isValidCardano reports via *valid whether addr parses as a Cardano address.
// A panic raised inside the address library leaves *valid untouched (the
// caller initializes it to false).
func isValidCardano(addr string, valid *bool) {
	defer func() {
		if recover() != nil {
			*valid = false
		}
	}()
	if _, parseErr := address.NewAddress(addr); parseErr == nil {
		*valid = true
	}
}
// ValidateAddress checks if the wallet address is a valid ethereum/cardano
// address. Ethereum (hex) addresses are explicitly rejected; anything that is
// not a parseable Cardano address is also rejected.
func ValidateAddress(addr string) error {
	if common.IsHexAddress(addr) {
		return errors.New("ethereum wallet address not allowed")
	}
	var cardanoOK bool
	isValidCardano(addr, &cardanoOK)
	if !cardanoOK {
		return errors.New("invalid cardano wallet address")
	}
	return nil
}
// GetAddressPaymentCredential extracts the payment-credential portion of a
// bech32-encoded Cardano address: hex characters [2:58] of the decoded bytes
// (skipping the one-byte header, taking the 28-byte credential hash).
func GetAddressPaymentCredential(addr string) (string, error) {
	_, data, err := bech32.Decode(addr, 1023)
	if err != nil {
		return "", fmt.Errorf("decoding bech32 failed: %w", err)
	}
	converted, err := bech32.ConvertBits(data, 5, 8, false)
	if err != nil {
		// Fix: this step converts bit groups; the original reused the
		// "decoding bech32 failed" message, hiding which step failed.
		return "", fmt.Errorf("converting bech32 bits failed: %w", err)
	}
	encoded := hex.EncodeToString(converted)
	// Fix: guard the [2:58] slice — a short payload previously panicked.
	if len(encoded) < 58 {
		return "", fmt.Errorf("decoded address too short: %d hex chars", len(encoded))
	}
	return encoded[2:58], nil
}
// GetTxReceiver returns the receiver of a transaction from the transaction
// hash: the second field of the inline datum attached to the second output,
// as reported by the Koios tx_info endpoint.
func GetTxReceiver(txHash string, endpoint KoiosEndpoint) (string, error) {
	type Request struct {
		TxHashes []string `json:"_tx_hashes"`
	}
	reqBody, _ := json.Marshal(Request{TxHashes: []string{txHash}})
	resp, err := http.Post(
		fmt.Sprintf("https://%s/api/v1/tx_info", endpoint),
		"application/json",
		bytes.NewBuffer(reqBody))
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()
	res := []struct {
		Outputs []struct {
			InlineDatum struct {
				Value struct {
					Fields []struct {
						Bytes string `json:"bytes"`
					} `json:"fields"`
				} `json:"value"`
			} `json:"inline_datum"`
		} `json:"outputs"`
	}{}
	jsonDecoder := json.NewDecoder(resp.Body)
	if err := jsonDecoder.Decode(&res); err != nil && err != io.EOF {
		return "", err
	}
	// Fix: the receiver lives at Outputs[1]...Fields[1]; the original only
	// rejected empty slices, so a single-output tx or a one-field datum
	// caused an index-out-of-range panic.
	if len(res) == 0 || len(res[0].Outputs) < 2 || len(res[0].Outputs[1].InlineDatum.Value.Fields) < 2 {
		return "", fmt.Errorf("unable to find receiver")
	}
	return res[0].Outputs[1].InlineDatum.Value.Fields[1].Bytes, nil
}
// GetTxConfirmations returns the number of confirmations of a transaction
// from the transaction hash, as reported by the Koios tx_status endpoint.
func GetTxConfirmations(txHash string, endpoint KoiosEndpoint) (int, error) {
	type Request struct {
		TxHashes []string `json:"_tx_hashes"`
	}
	reqBody, _ := json.Marshal(Request{TxHashes: []string{txHash}})
	resp, err := http.Post(
		fmt.Sprintf("https://%s/api/v1/tx_status", endpoint),
		"application/json",
		bytes.NewBuffer(reqBody))
	if err != nil {
		return 0, err
	}
	defer resp.Body.Close()
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return 0, err
	}
	var res []struct {
		TxHash        string `json:"tx_hash"`
		Confirmations int    `json:"num_confirmations"`
	}
	if err := json.Unmarshal(body, &res); err != nil {
		return 0, err
	}
	// Fix: the original indexed res[len(res)-1] unconditionally and panicked
	// when Koios returned an empty array (e.g. unknown tx hash).
	if len(res) == 0 {
		return 0, fmt.Errorf("no status returned for tx %s", txHash)
	}
	return res[len(res)-1].Confirmations, nil
}
// WaitForTxConfirmation polls every 5 seconds until txHash has at least the
// requested number of confirmations, or until timeout elapses.
func WaitForTxConfirmation(confirmations int, timeout time.Duration, txHash string, endpoint KoiosEndpoint) error {
	ticker := time.NewTicker(5 * time.Second)
	defer ticker.Stop()
	// Fix: create the deadline channel once. The original called
	// time.After(timeout) inside the select, allocating a fresh timer on
	// every tick — so with timeout > 5s the timeout branch could never fire.
	deadline := time.After(timeout)
	for {
		select {
		case <-ticker.C:
			conf, err := GetTxConfirmations(txHash, endpoint)
			if err != nil {
				return err
			}
			if conf >= confirmations {
				return nil
			}
		case <-deadline:
			return errors.New("timeout")
		}
	}
}
// GetUTXOsOfSmartContract fetch all utxos of smart contract and return list of tx_hash
func GetUTXOsOfSmartContract(address string, endpoint KoiosEndpoint) ([]string, error) {
	type Request struct {
		Address  []string `json:"_addresses"`
		Extended bool     `json:"_extended"`
	}
	payload, _ := json.Marshal(Request{Address: []string{address}, Extended: true})
	resp, err := http.Post(
		fmt.Sprintf("https://%s/api/v1/address_utxos", endpoint),
		"application/json",
		bytes.NewBuffer(payload),
	)
	if err != nil {
		return nil, fmt.Errorf("error making POST request: %v", err)
	}
	defer resp.Body.Close()
	var utxos []UTXOs
	if decodeErr := json.NewDecoder(resp.Body).Decode(&utxos); decodeErr != nil && decodeErr != io.EOF {
		return nil, decodeErr
	}
	hashes := make([]string, 0, len(utxos))
	for _, u := range utxos {
		hashes = append(hashes, u.TxHash)
	}
	return hashes, nil
}
// UpdateTransactionStatus updates the status of claimed transactions in the
// local DB: any tracked service whose tx is no longer among the contract's
// UTXOs has its transaction_type advanced to the matching terminal state.
func UpdateTransactionStatus(services []types.Services, utxoHashes []string) error {
	for _, service := range services {
		if SliceContains(utxoHashes, service.TxHash) {
			continue // still on-chain at the contract; nothing to do
		}
		switch service.TransactionType {
		case "withdraw":
			service.TransactionType = transactionWithdrawnStatus
		case "refund":
			service.TransactionType = transactionRefundedStatus
		case "distribute-50", "distribute-75":
			// Fix: Go switch cases do not fall through. The original listed
			// "distribute-50" as its own empty case, so those rows were
			// never marked as distributed.
			service.TransactionType = transactionDistributedStatus
		}
		s := service
		if err := db.DB.Save(&s).Error; err != nil {
			return err
		}
	}
	return nil
}
package utils
import (
"fmt"
"os"
"github.com/spf13/afero"
)
// GetDirectorySize walks path on fs and returns the cumulative size in bytes
// of every regular file beneath it.
func GetDirectorySize(fs afero.Fs, path string) (int64, error) {
	var total int64
	walkErr := afero.Walk(fs, path, func(_ string, info os.FileInfo, err error) error {
		if err != nil {
			return err
		}
		if info.IsDir() {
			return nil
		}
		total += info.Size()
		return nil
	})
	if walkErr != nil {
		return 0, fmt.Errorf("failed to calculate volume size: %w", walkErr)
	}
	return total, nil
}
package utils
import (
"bytes"
"fmt"
"io"
"net/http"
"net/url"
"path"
)
// HTTPClient bundles an http.Client with a base URL and an API version prefix
// used to build request paths.
type HTTPClient struct {
	BaseURL    string
	APIVersion string
	Client     *http.Client
}

// NewHTTPClient returns an HTTPClient targeting baseURL with the given API
// version prefix, transported over http.DefaultClient.
func NewHTTPClient(baseURL, version string) *HTTPClient {
	return &HTTPClient{
		BaseURL:    baseURL,
		APIVersion: version,
		Client:     http.DefaultClient,
	}
}

// MakeRequest performs an HTTP request with the given method, path, and body.
// It returns the response body, status code, and an error if any.
func (c *HTTPClient) MakeRequest(method, relativePath string, body []byte) ([]byte, int, error) {
	endpoint, err := url.Parse(c.BaseURL)
	if err != nil {
		return nil, 0, fmt.Errorf("failed to parse base URL: %v", err)
	}
	// NOTE(review): this replaces any path component already present in
	// BaseURL — confirm callers always pass a host-only base URL.
	endpoint.Path = path.Join(c.APIVersion, relativePath)
	req, err := http.NewRequest(method, endpoint.String(), bytes.NewBuffer(body))
	if err != nil {
		return nil, 0, fmt.Errorf("failed to create request: %v", err)
	}
	req.Header.Set("Content-Type", "application/json; charset=utf-8")
	req.Header.Set("Accept", "application/json")
	resp, err := c.Client.Do(req)
	if err != nil {
		return nil, 0, fmt.Errorf("request failed: %v", err)
	}
	defer resp.Body.Close()
	respBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, 0, fmt.Errorf("failed to read response body: %v", err)
	}
	return respBody, resp.StatusCode, nil
}
package utils
import (
"github.com/uptrace/opentelemetry-go-extra/otelzap"
"gitlab.com/nunet/device-management-service/telemetry/logger"
)
// zlog is the package-wide OpenTelemetry-aware zap logger, initialized in init().
var zlog *otelzap.Logger

// Terminal transaction states written to the local Services table once a
// claim transaction is no longer present among the contract's UTXOs
// (see UpdateTransactionStatus).
const transactionWithdrawnStatus = "withdrawn"
const transactionRefundedStatus = "refunded"
const transactionDistributedStatus = "distributed"

func init() {
	zlog = logger.OtelZapLogger("utils")
}
package utils
import (
"io"
"sync"
"time"
)
type IOProgress struct {
n float64
size float64
started time.Time
estimated time.Time
err error
}
type Reader struct {
reader io.Reader
lock sync.RWMutex
Progress IOProgress
}
type Writer struct {
writer io.Writer
lock sync.RWMutex
Progress IOProgress
}
func ReaderWithProgress(r io.Reader, size int64) *Reader {
return &Reader{
reader: r,
Progress: IOProgress{started: time.Now(), size: float64(size)},
}
}
func WriterWithProgress(w io.Writer, size int64) *Writer {
return &Writer{
writer: w,
Progress: IOProgress{started: time.Now(), size: float64(size)},
}
}
func (r *Reader) Read(p []byte) (n int, err error) {
n, err = r.reader.Read(p)
r.lock.Lock()
r.Progress.n += float64(n)
r.Progress.err = err
r.lock.Unlock()
return n, err
}
func (w *Writer) Write(p []byte) (n int, err error) {
n, err = w.writer.Write(p)
w.lock.Lock()
w.Progress.n += float64(n)
w.Progress.err = err
w.lock.Unlock()
return n, err
}
func (p IOProgress) Size() float64 {
return p.size
}
func (p IOProgress) N() float64 {
return p.n
}
func (p IOProgress) Complete() bool {
if p.err == io.EOF {
return true
}
if p.size == -1 {
return false
}
return p.n >= p.size
}
// Percent calculates the percentage complete.
func (p IOProgress) Percent() float64 {
if p.n == 0 {
return 0
}
if p.n >= p.size {
return 100
}
return 100.0 / (p.size / p.n)
}
func (p IOProgress) Remaining() time.Duration {
if p.estimated.IsZero() {
return time.Until(p.Estimated())
}
return time.Until(p.estimated)
}
func (p IOProgress) Estimated() time.Time {
ratio := p.n / p.size
past := float64(time.Since(p.started))
if p.n > 0.0 {
total := time.Duration(past / ratio)
p.estimated = p.started.Add(total)
}
return p.estimated
}
package utils
import (
"fmt"
"strings"
"sync"
)
// A SyncMap is a concurrency-safe sync.Map that uses strongly-typed
// method signatures to ensure the types of its stored data are known.
type SyncMap[K comparable, V any] struct {
	sync.Map
}

// SyncMapFromMap converts a standard Go map to a concurrency-safe SyncMap.
func SyncMapFromMap[K comparable, V any](m map[K]V) *SyncMap[K, V] {
	ret := &SyncMap[K, V]{}
	for k, v := range m {
		ret.Put(k, v)
	}
	return ret
}

// Get retrieves the value associated with the given key from the map.
// It returns the value and a boolean indicating whether the key was found.
func (m *SyncMap[K, V]) Get(key K) (V, bool) {
	value, ok := m.Load(key)
	if !ok {
		var empty V
		return empty, false
	}
	return value.(V), true
}

// Put inserts or updates a key-value pair in the map.
func (m *SyncMap[K, V]) Put(key K, value V) {
	m.Store(key, value)
}

// Iter iterates over each key-value pair in the map, executing the provided function on each pair.
// The iteration stops if the provided function returns false.
func (m *SyncMap[K, V]) Iter(ranger func(key K, value V) bool) {
	m.Range(func(key, value any) bool {
		return ranger(key.(K), value.(V))
	})
}

// Keys returns a slice containing all the keys present in the map.
func (m *SyncMap[K, V]) Keys() []K {
	var keys []K
	m.Iter(func(key K, _ V) bool {
		keys = append(keys, key)
		return true
	})
	return keys
}

// String provides a string representation of the map, listing all key-value pairs.
func (m *SyncMap[K, V]) String() string {
	// Use a strings.Builder for efficient string concatenation.
	var sb strings.Builder
	sb.WriteByte('{')
	m.Range(func(key, value any) bool {
		// Fix: the original used %s=%s, which mangles non-string keys and
		// values (e.g. "%!s(int=2)"); %v formats every type correctly.
		fmt.Fprintf(&sb, "%v=%v", key, value)
		return true
	})
	sb.WriteByte('}')
	return sb.String()
}
package utils
import (
"archive/tar"
"bufio"
"compress/gzip"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"log"
"math/big"
"net/http"
"os"
"path/filepath"
"reflect"
"strings"
"time"
"github.com/google/uuid"
"github.com/spf13/afero"
"gitlab.com/nunet/device-management-service/db"
"gitlab.com/nunet/device-management-service/types"
"golang.org/x/exp/slices"
)
const (
	// KernelFileURL is the remote location of the VM kernel image
	// (presumably a Firecracker vmlinux, per the "/fc/" path — confirm).
	KernelFileURL = "https://d.nunet.io/fc/vmlinux"
	// KernelFilePath is where the downloaded kernel image is stored locally.
	KernelFilePath = "/etc/nunet/vmlinux"
	// FilesystemURL is the remote location of the Ubuntu 20.04 rootfs image.
	FilesystemURL = "https://d.nunet.io/fc/nunet-fc-ubuntu-20.04-0.ext4"
	// FilesystemPath is where the downloaded rootfs image is stored locally.
	FilesystemPath = "/etc/nunet/nunet-fc-ubuntu-20.04-0.ext4"
)
// DownloadFile downloads a file from a url and saves it to a filepath.
func DownloadFile(url string, filepath string) (err error) {
	// Fix: the original passed Println-style arguments to Infof with a
	// format string containing no verbs, producing garbage output.
	zlog.Sugar().Infof("Downloading file '%s' from '%s'", filepath, url)
	req, err := http.NewRequest(http.MethodGet, url, nil)
	if err != nil {
		return err
	}
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	// Fix: without this check an error page (404 etc.) was silently saved
	// as the downloaded file.
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("unexpected HTTP status %q downloading %s", resp.Status, url)
	}
	// Create the destination only after the request succeeds, so a failed
	// request no longer leaves a truncated/empty file behind.
	file, err := os.Create(filepath)
	if err != nil {
		return err
	}
	defer file.Close()
	if _, err = io.Copy(file, resp.Body); err != nil {
		return err
	}
	log.Println("Finished downloading file '", filepath, "'")
	return nil
}
// ReadHTTPString GET request to http endpoint and return response as string
func ReadHTTPString(url string) (string, error) {
	request, err := http.NewRequest(http.MethodGet, url, nil)
	if err != nil {
		return "", err
	}
	response, err := http.DefaultClient.Do(request)
	if err != nil {
		return "", err
	}
	defer response.Body.Close()
	payload, readErr := io.ReadAll(response.Body)
	if readErr != nil {
		return "", readErr
	}
	return string(payload), nil
}
// RandomString generates a random string of length n
func RandomString(n int) (string, error) {
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
sb := strings.Builder{}
sb.Grow(n)
for i := 0; i < n; i++ {
n, err := rand.Int(rand.Reader, big.NewInt(int64(len(charset))))
if err != nil {
return "", err
}
sb.WriteByte(charset[n.Int64()])
}
return sb.String(), nil
}
// GenerateMachineUUID generates a machine uuid
func GenerateMachineUUID() (string, error) {
var machine types.MachineUUID
machineUUID, err := uuid.NewDCEGroup()
if err != nil {
return "", err
}
machine.UUID = machineUUID.String()
return machine.UUID, nil
}
// GetMachineUUID returns the machine uuid from the DB
func GetMachineUUID() string {
	var machine types.MachineUUID
	// A candidate UUID is generated up front; a generation error is logged
	// but not fatal, in which case an empty UUID is used below.
	uuid, err := GenerateMachineUUID()
	if err != nil {
		zlog.Sugar().Errorf("could not generate machine uuid: %v", err)
	}
	machine.UUID = uuid
	// NOTE(review): FirstOrCreate is queried with the freshly generated UUID
	// already populated — presumably intended to return the persisted UUID on
	// subsequent calls; verify this does not insert a new row per call.
	result := db.DB.FirstOrCreate(&machine)
	if result.Error != nil {
		zlog.Sugar().Errorf("could not find or create machine uuid record in DB: %v", result.Error)
	}
	return machine.UUID
}
// SliceContains reports whether str occurs anywhere in s.
func SliceContains(s []string, str string) bool {
	for i := range s {
		if s[i] == str {
			return true
		}
	}
	return false
}
// DeleteFile deletes a file, with or without a backup
func DeleteFile(path string, backup bool) (err error) {
if backup {
err = os.Rename(path, fmt.Sprintf("%s.bk.%d", path, time.Now().Unix()))
} else {
err = os.Remove(path)
}
return
}
// ReadyForElastic checks if the device is ready to send logs to elastic:
// both a node ID and a channel name must be present in the stored token.
func ReadyForElastic() bool {
	var token types.ElasticToken
	db.DB.Find(&token)
	return token.NodeID != "" && token.ChannelName != ""
}
// PromptYesNo loops on confirmation from user until valid answer
func PromptYesNo(in io.Reader, out io.Writer, prompt string) (bool, error) {
reader := bufio.NewReader(in)
for {
fmt.Fprintf(out, "%s (y/N): ", prompt)
response, err := reader.ReadString('\n')
if err != nil {
return false, fmt.Errorf("read response string failed: %w", err)
}
response = strings.ToLower(strings.TrimSpace(response))
if response == "y" || response == "yes" {
return true, nil
} else if response == "n" || response == "no" {
return false, nil
}
}
}
// CreateDirectoryIfNotExists creates a directory if it does not exist
func CreateDirectoryIfNotExists(path string) error {
if _, err := os.Stat(path); os.IsNotExist(err) {
err := os.MkdirAll(path, 0o755)
if err != nil {
return err
}
}
return nil
}
// CalculateSHA256Checksum calculates the SHA256 checksum of a file
func CalculateSHA256Checksum(filePath string) (string, error) {
// Open the file for reading
file, err := os.Open(filePath)
if err != nil {
return "", err
}
defer file.Close()
// Create a new SHA-256 hash
hash := sha256.New()
// Copy the file's contents into the hash object
if _, err := io.Copy(hash, file); err != nil {
return "", err
}
// Calculate the checksum and return it as a hexadecimal string
checksum := hex.EncodeToString(hash.Sum(nil))
return checksum, nil
}
// put checksum in file
func CreateCheckSumFile(filePath string, checksum string) (string, error) {
sha256FilePath := fmt.Sprintf("%s.sha256.txt", filePath)
sha256File, err := os.Create(sha256FilePath)
if err != nil {
return "", fmt.Errorf("unable to create SHA-256 checksum file: %v", err)
}
defer sha256File.Close()
_, err = sha256File.WriteString(checksum)
if err != nil {
return "", fmt.Errorf("unable to write to SHA-256 checksum file: %v", err)
}
return sha256FilePath, nil
}
// SanitizeArchivePath Sanitize archive file pathing from "G305: Zip Slip vulnerability"
func SanitizeArchivePath(d, t string) (v string, err error) {
v = filepath.Join(d, t)
if strings.HasPrefix(v, filepath.Clean(d)) {
return v, nil
}
return "", fmt.Errorf("%s: %s", "content filepath is tainted", t)
}
// ExtractTarGzToPath extracts a tar.gz file to a specified path, creating the
// target directory when needed. Entry names are sanitized against Zip Slip.
func ExtractTarGzToPath(tarGzFilePath, extractedPath string) error {
	// Ensure the target directory exists; create it if it doesn't.
	if err := os.MkdirAll(extractedPath, os.ModePerm); err != nil {
		return fmt.Errorf("error creating target directory: %v", err)
	}
	tarGzFile, err := os.Open(tarGzFilePath)
	if err != nil {
		return fmt.Errorf("error opening tar.gz file: %v", err)
	}
	defer tarGzFile.Close()
	gzipReader, err := gzip.NewReader(tarGzFile)
	if err != nil {
		return fmt.Errorf("error creating gzip reader: %v", err)
	}
	defer gzipReader.Close()
	tarReader := tar.NewReader(gzipReader)
	for {
		header, err := tarReader.Next()
		if err == io.EOF {
			break
		}
		if err != nil {
			return fmt.Errorf("error reading tar header: %v", err)
		}
		// Join the entry name onto the target directory, rejecting escapes.
		fullTargetPath, err := SanitizeArchivePath(extractedPath, header.Name)
		if err != nil {
			return fmt.Errorf("failed to santize path %w", err)
		}
		if header.FileInfo().IsDir() {
			// Create the directory and any parent directories as needed.
			if err := os.MkdirAll(fullTargetPath, os.ModePerm); err != nil {
				return fmt.Errorf("error creating directory: %v", err)
			}
			continue
		}
		// Fix: extract each file in a helper so its Close runs per entry.
		// The original deferred Close inside the loop, keeping every
		// extracted file open (and unflushed on some systems) until the
		// whole function returned.
		if err := extractTarRegularFile(tarReader, fullTargetPath); err != nil {
			return err
		}
	}
	return nil
}

// extractTarRegularFile writes one regular-file entry from the tar stream to
// target, creating parent directories as needed.
func extractTarRegularFile(tarReader *tar.Reader, target string) error {
	if err := os.MkdirAll(filepath.Dir(target), os.ModePerm); err != nil {
		return fmt.Errorf("error creating directory: %v", err)
	}
	newFile, err := os.Create(target)
	if err != nil {
		return fmt.Errorf("error creating file: %v", err)
	}
	defer newFile.Close()
	// Copy in bounded 1 KiB chunks (mitigates G110 decompression bombs).
	for {
		if _, err := io.CopyN(newFile, tarReader, 1024); err != nil {
			if err == io.EOF {
				return nil
			}
			return err
		}
	}
}
// CheckWSL check if running in WSL by scanning /proc/version for the strings
// "Microsoft" or "WSL".
func CheckWSL(afs afero.Afero) (bool, error) {
	f, err := afs.Open("/proc/version")
	if err != nil {
		return false, err
	}
	defer f.Close()
	sc := bufio.NewScanner(f)
	for sc.Scan() {
		text := sc.Text()
		if strings.Contains(text, "Microsoft") || strings.Contains(text, "WSL") {
			return true, nil
		}
	}
	// sc.Err() is nil on clean EOF, so this covers both outcomes.
	return false, sc.Err()
}
// SaveServiceInfo updates service info into SP's DMS for claim Reward by SP user
func SaveServiceInfo(cpService types.Services) error {
	var existing types.Services
	if err := db.DB.Model(&types.Services{}).Where("tx_hash = ?", cpService.TxHash).Find(&existing).Error; err != nil {
		return fmt.Errorf("unable to find service on SP side: %v", err)
	}
	// Carry over the SP-side primary key and creation time so Updates
	// patches the existing row rather than diverging from it.
	cpService.ID = existing.ID
	cpService.CreatedAt = existing.CreatedAt
	if res := db.DB.Model(&types.Services{}).Where("tx_hash = ?", cpService.TxHash).Updates(&cpService); res.Error != nil {
		return fmt.Errorf("unable to update service info on SP side: %v", res.Error.Error())
	}
	return nil
}
// RandomBool returns a cryptographically random boolean.
func RandomBool() (bool, error) {
	bit, err := rand.Int(rand.Reader, big.NewInt(2))
	if err != nil {
		return false, err
	}
	// Return true if the number is 1, otherwise false
	return bit.Int64() == 1, nil
}
// IsExecutorType reports whether v's dynamic type is types.ExecutorType.
func IsExecutorType(v interface{}) bool {
	_, matches := v.(types.ExecutorType)
	return matches
}

// IsGPUVendor reports whether v's dynamic type is types.GPUVendor.
func IsGPUVendor(v interface{}) bool {
	_, matches := v.(types.GPUVendor)
	return matches
}

// IsJobType reports whether v's dynamic type is types.JobType.
func IsJobType(v interface{}) bool {
	_, matches := v.(types.JobType)
	return matches
}

// IsJobTypes reports whether v's dynamic type is types.JobTypes.
func IsJobTypes(v interface{}) bool {
	_, matches := v.(types.JobTypes)
	return matches
}

// IsExecutor reports whether v's dynamic type is types.Executor.
func IsExecutor(v interface{}) bool {
	_, matches := v.(types.Executor)
	return matches
}
// IsStrictlyContained checks if all elements of rightSlice are contained in leftSlice.
// An empty rightSlice yields false (preserving the original semantics).
func IsStrictlyContained(leftSlice, rightSlice []interface{}) bool {
	if len(rightSlice) == 0 {
		return false
	}
	for _, element := range rightSlice {
		if !slices.Contains(leftSlice, element) {
			return false
		}
	}
	return true
}
// IsStrictlyContainedInt checks if all elements of rightSlice are contained in leftSlice.
// An empty rightSlice yields false (preserving the original semantics).
func IsStrictlyContainedInt(leftSlice, rightSlice []int) bool {
	if len(rightSlice) == 0 {
		return false
	}
	for _, element := range rightSlice {
		if !slices.Contains(leftSlice, element) {
			return false
		}
	}
	return true
}
// NoIntersectionSlices reports whether slice1 and slice2 share no common
// element. An empty slice1 yields false, preserving the original's
// default-result behavior for that case.
func NoIntersectionSlices(slice1, slice2 []interface{}) bool {
	// Fix: the original overwrote its result on every iteration, so the
	// return value reflected only the LAST element of slice1 — an
	// intersection earlier in the slice was silently ignored.
	for _, element := range slice1 {
		if slices.Contains(slice2, element) {
			return false
		}
	}
	return len(slice1) > 0
}
// IntersectionSlices returns the elements present in both slices, without
// duplicates, in the order they appear in slice2.
func IntersectionSlices(slice1, slice2 []interface{}) []interface{} {
	seen := make(map[interface{}]bool, len(slice1))
	for _, item := range slice1 {
		seen[item] = true
	}
	common := []interface{}{}
	for _, item := range slice2 {
		if seen[item] {
			common = append(common, item)
			// Drop the key so duplicates in slice2 appear only once.
			delete(seen, item)
		}
	}
	return common
}
// IsSameShallowType reports whether a and b have the same dynamic type
// (no deep/structural comparison).
func IsSameShallowType(a, b interface{}) bool {
	return reflect.TypeOf(a) == reflect.TypeOf(b)
}
// ConvertTypedSliceToUntypedSlice copies the elements of any slice value into
// a []interface{}. Non-slice input yields nil.
func ConvertTypedSliceToUntypedSlice(typedSlice interface{}) []interface{} {
	v := reflect.ValueOf(typedSlice)
	if v.Kind() != reflect.Slice {
		return nil
	}
	out := make([]interface{}, v.Len())
	for i := range out {
		out[i] = v.Index(i).Interface()
	}
	return out
}
package validate
import (
"reflect"
)
// ConvertNumericToFloat64 converts any built-in numeric value to float64.
// The second result is false when n is not a numeric type.
func ConvertNumericToFloat64(n any) (float64, bool) {
	switch value := n.(type) {
	case int, int8, int16, int32, int64:
		return float64(reflect.ValueOf(value).Int()), true
	case uint, uint8, uint16, uint32, uint64:
		return float64(reflect.ValueOf(value).Uint()), true
	case float32:
		return float64(value), true
	case float64:
		return value, true
	}
	return 0, false
}
package validate
import (
"strings"
)
// IsBlank reports whether s is empty or consists solely of whitespace.
func IsBlank(s string) bool {
	return strings.TrimSpace(s) == ""
}
// IsNotBlank reports whether s contains at least one non-whitespace character.
func IsNotBlank(s string) bool {
	return !IsBlank(s)
}
// IsLiteral reports whether s holds a string value.
func IsLiteral(s interface{}) bool {
	_, isString := s.(string)
	return isString
}