From 1606ad32a7ff412ed541ddc99be68af9d5234354 Mon Sep 17 00:00:00 2001 From: Mike Zupper Date: Wed, 9 Oct 2024 08:43:36 -0400 Subject: [PATCH] Livepeer.Cloud SPE - Proposal #2 - Enable Single Orchestrator AI Job Tests --- cmd/livepeer/livepeer.go | 4 +- cmd/livepeer/starter/starter.go | 18 +- common/util.go | 3 - core/livepeernode.go | 2 + core/os.go | 35 +++- discovery/discovery_test.go | 2 +- discovery/wh_discovery.go | 28 +-- server/ai_session.go | 47 +++-- server/handlers.go | 60 ++++++ server/mediaserver.go | 4 +- .../orchestrator_ai_capabilities_manager.go | 190 ++++++++++++++++++ server/webserver.go | 2 + 12 files changed, 349 insertions(+), 46 deletions(-) create mode 100644 server/orchestrator_ai_capabilities_manager.go diff --git a/cmd/livepeer/livepeer.go b/cmd/livepeer/livepeer.go index 49506cb946..e51a1d31f0 100755 --- a/cmd/livepeer/livepeer.go +++ b/cmd/livepeer/livepeer.go @@ -139,7 +139,9 @@ func parseLivepeerConfig() starter.LivepeerConfig { cfg.IgnoreMaxPriceIfNeeded = flag.Bool("ignoreMaxPriceIfNeeded", *cfg.IgnoreMaxPriceIfNeeded, "Set to true to allow exceeding max price condition if there is no O that meets this requirement") cfg.MinPerfScore = flag.Float64("minPerfScore", *cfg.MinPerfScore, "The minimum orchestrator's performance score a broadcaster is willing to accept") cfg.DiscoveryTimeout = flag.Duration("discoveryTimeout", *cfg.DiscoveryTimeout, "Time to wait for orchestrators to return info to be included in transcoding sessions for manifest (default = 500ms)") - + cfg.AISessionTimeout = flag.Duration("aiSessionTimeout", *cfg.AISessionTimeout, "The length of time (in seconds) that an AI Session will be cached (default = 600s)") + cfg.WebhookRefreshInterval = flag.Duration("webhookRefreshInterval", *cfg.WebhookRefreshInterval, "The length of time (in seconds) that an Orchestrator Webhook Discovery Request will be cached (default = 60s)") + cfg.AITesterGateway = flag.Bool("aiTesterGateway", *cfg.AITesterGateway, "Set to true to allow the gateway to run in \"tester\" mode. This will bypass caching of AI session selectors.") // Transcoding: cfg.Orchestrator = flag.Bool("orchestrator", *cfg.Orchestrator, "Set to true to be an orchestrator") cfg.Transcoder = flag.Bool("transcoder", *cfg.Transcoder, "Set to true to be a transcoder") diff --git a/cmd/livepeer/starter/starter.go b/cmd/livepeer/starter/starter.go index 72fd638de4..cb436ee0ec 100755 --- a/cmd/livepeer/starter/starter.go +++ b/cmd/livepeer/starter/starter.go @@ -157,10 +157,13 @@ type LivepeerConfig struct { FVfailGsKey *string AuthWebhookURL *string OrchWebhookURL *string + WebhookRefreshInterval *time.Duration OrchBlacklist *string OrchMinLivepeerVersion *string TestOrchAvail *bool AIRunnerImage *string + AISessionTimeout *time.Duration + AITesterGateway *bool } // DefaultLivepeerConfig creates LivepeerConfig exactly the same as when no flags are passed to the livepeer process. @@ -204,6 +207,9 @@ func DefaultLivepeerConfig() LivepeerConfig { defaultAIModels := "" defaultAIModelsDir := "" defaultAIRunnerImage := "livepeer/ai-runner:latest" + defaultAISessionTimeout := 10 * time.Minute + defaultWebhookRefreshInterval := 1 * time.Minute + defaultAITesterGateway := false // Onchain: defaultEthAcctAddr := "" @@ -304,6 +310,8 @@ func DefaultLivepeerConfig() LivepeerConfig { AIModels: &defaultAIModels, AIModelsDir: &defaultAIModelsDir, AIRunnerImage: &defaultAIRunnerImage, + AISessionTimeout: &defaultAISessionTimeout, + AITesterGateway: &defaultAITesterGateway, // Onchain: EthAcctAddr: &defaultEthAcctAddr, @@ -357,8 +365,9 @@ func DefaultLivepeerConfig() LivepeerConfig { FVfailGsKey: &defaultFVfailGsKey, // API - AuthWebhookURL: &defaultAuthWebhookURL, - OrchWebhookURL: &defaultOrchWebhookURL, + AuthWebhookURL: &defaultAuthWebhookURL, + OrchWebhookURL: &defaultOrchWebhookURL, + WebhookRefreshInterval: &defaultWebhookRefreshInterval, // Versioning constraints OrchMinLivepeerVersion: &defaultMinLivepeerVersion, @@ -1042,6 +1051,8 @@ func StartLivepeer(ctx context.Context, cfg LivepeerConfig) { server.BroadcastCfg.SetCapabilityMaxPrice(cap, p.ModelID, autoCapPrice) } } + n.AITesterGateway = *cfg.AITesterGateway + n.AISessionTimeout = *cfg.AISessionTimeout } if n.NodeType == core.RedeemerNode { @@ -1399,7 +1410,8 @@ func StartLivepeer(ctx context.Context, cfg LivepeerConfig) { glog.Exit("Error setting orch webhook URL ", err) } glog.Info("Using orchestrator webhook URL ", whurl) - n.OrchestratorPool = discovery.NewWebhookPool(bcast, whurl, *cfg.DiscoveryTimeout) + glog.Info("Using orchestrator webhook refresh interval ", *cfg.WebhookRefreshInterval) + n.OrchestratorPool = discovery.NewWebhookPool(bcast, whurl, *cfg.DiscoveryTimeout, *cfg.WebhookRefreshInterval) } else if len(orchURLs) > 0 { n.OrchestratorPool = discovery.NewOrchestratorPool(bcast, orchURLs, common.Score_Trusted, orchBlacklist, *cfg.DiscoveryTimeout) } diff --git a/common/util.go b/common/util.go index dd6e73b872..3bb09cdf0f 100644 --- a/common/util.go +++ b/common/util.go @@ -44,9 +44,6 @@ var SegUploadTimeoutMultiplier = 0.5 // MinSegmentUploadTimeout defines the minimum timeout enforced for uploading a segment to orchestrators var MinSegmentUploadTimeout = 2 * time.Second -// WebhookDiscoveryRefreshInterval defines for long the Webhook Discovery values should be cached -var WebhookDiscoveryRefreshInterval = 1 * time.Minute - // Max Segment Duration var MaxDuration = (5 * time.Minute) diff --git a/core/livepeernode.go b/core/livepeernode.go index 4ef1fbcfd8..d9191623e9 100644 --- a/core/livepeernode.go +++ b/core/livepeernode.go @@ -120,6 +120,8 @@ type LivepeerNode struct { // AI worker public fields AIWorker AI AIWorkerManager *RemoteAIWorkerManager + AISessionTimeout time.Duration + AITesterGateway bool // Transcoder public fields SegmentChans map[ManifestID]SegmentChan diff --git a/core/os.go b/core/os.go index 9a3978b82a..3fea72bf43 100644 --- a/core/os.go +++ b/core/os.go @@ -8,6 +8,7 @@ import ( "crypto/tls" "fmt" "net/http" + "os" "time" "github.com/livepeer/go-livepeer/clog" @@ -20,9 +21,35 @@ func DownloadData(ctx context.Context, uri string) ([]byte, error) { return downloadDataHTTP(ctx, uri) } -var httpc = &http.Client{ - Transport: &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}, - Timeout: common.HTTPTimeout / 2, +var osHttpClient = getHTTPClient() + +// getHTTPClient creates an HTTP client with a timeout based on an environment variable or defaults to common.HTTPTimeout/2 +func getHTTPClient() *http.Client { + // Get the timeout value from the environment variable + timeoutStr := os.Getenv("LIVEPEER_OS_HTTP_TIMEOUT") + + // Define a default timeout value as common.HTTPTimeout / 2 + defaultTimeout := common.HTTPTimeout / 2 + + var timeout time.Duration + var err error + + // If the environment variable is set, attempt to parse it + if timeoutStr != "" { + timeout, err = time.ParseDuration(timeoutStr) + if err != nil { + timeout = defaultTimeout + } + } else { + // If the environment variable is not set, use the default timeout + timeout = defaultTimeout + } + + // Return the HTTP client with the calculated timeout + return &http.Client{ + Transport: &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}, + Timeout: timeout, + } } func FromNetOsInfo(os *net.OSInfo) *drivers.OSInfo { @@ -76,7 +103,7 @@ func ToNetS3Info(storage *drivers.S3OSInfo) *net.S3OSInfo { func downloadDataHTTP(ctx context.Context, uri string) ([]byte, error) { clog.V(common.VERBOSE).Infof(ctx, "Downloading uri=%s", uri) started := time.Now() - resp, err := httpc.Get(uri) + resp, err := osHttpClient.Get(uri) if err != nil { clog.Errorf(ctx, "Error getting HTTP uri=%s err=%q", uri, err) return nil, err diff --git a/discovery/discovery_test.go b/discovery/discovery_test.go index 8ffc2fe013..bdfe63a52d 100644 --- a/discovery/discovery_test.go +++ b/discovery/discovery_test.go @@ -1120,7 +1120,7 @@ func TestNewWHOrchestratorPoolCache(t *testing.T) { // assert created webhook pool is correct length whURL, _ := url.ParseRequestURI("https://livepeer.live/api/orchestrator") - whpool := NewWebhookPool(nil, whURL, 500*time.Millisecond) + whpool := NewWebhookPool(nil, whURL, 500*time.Millisecond, 1*time.Minute) assert.Equal(3, whpool.Size()) // assert that list is not refreshed if lastRequest is less than 1 min ago and hash is the same diff --git a/discovery/wh_discovery.go b/discovery/wh_discovery.go index 7e552dfcce..ce295d8cea 100644 --- a/discovery/wh_discovery.go +++ b/discovery/wh_discovery.go @@ -21,21 +21,23 @@ type webhookResponse struct { } type webhookPool struct { - pool *orchestratorPool - callback *url.URL - responseHash ethcommon.Hash - lastRequest time.Time - mu *sync.RWMutex - bcast common.Broadcaster - discoveryTimeout time.Duration + pool *orchestratorPool + callback *url.URL + responseHash ethcommon.Hash + lastRequest time.Time + mu *sync.RWMutex + bcast common.Broadcaster + discoveryTimeout time.Duration + webhookRefreshInterval time.Duration } -func NewWebhookPool(bcast common.Broadcaster, callback *url.URL, discoveryTimeout time.Duration) *webhookPool { +func NewWebhookPool(bcast common.Broadcaster, callback *url.URL, discoveryTimeout time.Duration, webhookRefreshInterval time.Duration) *webhookPool { p := &webhookPool{ - callback: callback, - mu: &sync.RWMutex{}, - bcast: bcast, - discoveryTimeout: discoveryTimeout, + callback: callback, + mu: &sync.RWMutex{}, + bcast: bcast, + discoveryTimeout: discoveryTimeout, + webhookRefreshInterval: webhookRefreshInterval, } go p.getInfos() return p @@ -48,7 +50,7 @@ func (w *webhookPool) getInfos() ([]common.OrchestratorLocalInfo, error) { w.mu.RUnlock() // retrive addrs from cache if time since lastRequest is less than the refresh interval - if time.Since(lastReq) < common.WebhookDiscoveryRefreshInterval { + if time.Since(lastReq) < w.webhookRefreshInterval { return pool.GetInfos(), nil } diff --git a/server/ai_session.go b/server/ai_session.go index 63cb8134cc..d12f6fe7ec 100644 --- a/server/ai_session.go +++ b/server/ai_session.go @@ -341,18 +341,20 @@ func (sel *AISessionSelector) getSessions(ctx context.Context) ([]*BroadcastSess } type AISessionManager struct { - node *core.LivepeerNode - selectors map[string]*AISessionSelector - mu sync.Mutex - ttl time.Duration + node *core.LivepeerNode + selectors map[string]*AISessionSelector + mu sync.Mutex + ttl time.Duration + testerGatewayEnabled bool } -func NewAISessionManager(node *core.LivepeerNode, ttl time.Duration) *AISessionManager { +func NewAISessionManager(node *core.LivepeerNode, ttl time.Duration, testerGatewayEnabled bool) *AISessionManager { return &AISessionManager{ - node: node, - selectors: make(map[string]*AISessionSelector), - mu: sync.Mutex{}, - ttl: ttl, + node: node, + selectors: make(map[string]*AISessionSelector), + mu: sync.Mutex{}, + ttl: ttl, + testerGatewayEnabled: testerGatewayEnabled, } } @@ -400,18 +402,27 @@ func (c *AISessionManager) getSelector(ctx context.Context, cap core.Capability, c.mu.Lock() defer c.mu.Unlock() - cacheKey := strconv.Itoa(int(cap)) + "_" + modelID - sel, ok := c.selectors[cacheKey] - if !ok { - // Create the selector - var err error - sel, err = NewAISessionSelector(cap, modelID, c.node, c.ttl) + if c.testerGatewayEnabled { + sel, err := NewAISessionSelector(cap, modelID, c.node, c.ttl) if err != nil { return nil, err } + return sel, nil + } else { + cacheKey := strconv.Itoa(int(cap)) + "_" + modelID + sel, ok := c.selectors[cacheKey] + if !ok { + // Create the selector + var err error + sel, err = NewAISessionSelector(cap, modelID, c.node, c.ttl) + if err != nil { + return nil, err + } + + c.selectors[cacheKey] = sel + } - c.selectors[cacheKey] = sel - } + return sel, nil - return sel, nil + } } diff --git a/server/handlers.go b/server/handlers.go index 3f4673b6fe..45bb1266cf 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -296,6 +296,31 @@ func getAvailableTranscodingOptionsHandler() http.Handler { }) } +func (s *LivepeerServer) getOrchestratorAICapabilitiesHandler() http.Handler { + // Define the data fetch function + fetchFunc := func() (*OrchestratorAICapabilitiesManager, error) { + networkCapsMgr, err := buildOrchestratorAICapabilitiesManager(s.LivepeerNode) + if err != nil { + return nil, fmt.Errorf(`failed to fetch orch AI capabilities: %v`, err.Error()) + } + return networkCapsMgr, nil + } + + // Initialize the cache with a TTL + cacheTTL := 120 * time.Second + networkCapsMgrCache := NewCache(cacheTTL, fetchFunc) + + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Get data from cache (or fetch if stale) + networkCapsMgr, err := networkCapsMgrCache.GetCache() + if err != nil { + respond500(w, err.Error()) + return + } + respondJson(w, networkCapsMgr) + }) +} + // Rounds func currentRoundHandler(client eth.LivepeerEthClient) http.Handler { return mustHaveClient(client, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -1634,3 +1659,38 @@ func mustHaveDb(db interface{}, h http.Handler) http.Handler { h.ServeHTTP(w, r) }) } + +// Cache struct to hold the Cacheable output and expiration time +type Cache[T any] struct { + Data T + Expiration time.Time + TTL time.Duration + fetchFunc func() (T, error) // Function to fetch data when cache is stale +} + +// NewCache initializes the cache with a TTL and fetch function +func NewCache[T any](ttl time.Duration, fetchFunc func() (T, error)) *Cache[T] { + return &Cache[T]{ + TTL: ttl, + fetchFunc: fetchFunc, + } +} + +// GetCache returns the cached string if it's still valid; otherwise, it fetches new data +func (c *Cache[T]) GetCache() (T, error) { + var zeroValue T + // Check if cache is still valid + if time.Now().Before(c.Expiration) { + return c.Data, nil + } + // Cache is stale; fetch new data + newData, err := c.fetchFunc() + if err != nil { + return zeroValue, err + } + + // Update cache with the new data and set expiration + c.Data = newData + c.Expiration = time.Now().Add(c.TTL) + return c.Data, nil +} diff --git a/server/mediaserver.go b/server/mediaserver.go index 5bd7b5a2f9..2db16c9f1c 100644 --- a/server/mediaserver.go +++ b/server/mediaserver.go @@ -62,8 +62,6 @@ const StreamKeyBytes = 6 const SegLen = 2 * time.Second const BroadcastRetry = 15 * time.Second -const AISessionManagerTTL = 10 * time.Minute - var BroadcastJobVideoProfiles = []ffmpeg.VideoProfile{ffmpeg.P240p30fps4x3, ffmpeg.P360p30fps16x9} var AuthWebhookURL *url.URL @@ -188,7 +186,7 @@ func NewLivepeerServer(rtmpAddr string, lpNode *core.LivepeerNode, httpIngest bo rtmpConnections: make(map[core.ManifestID]*rtmpConnection), internalManifests: make(map[core.ManifestID]core.ManifestID), recordingsAuthResponses: cache.New(time.Hour, 2*time.Hour), - AISessionManager: NewAISessionManager(lpNode, AISessionManagerTTL), + AISessionManager: NewAISessionManager(lpNode, lpNode.AISessionTimeout, lpNode.AITesterGateway), } if lpNode.NodeType == core.BroadcasterNode && httpIngest { opts.HttpMux.HandleFunc("/live/", ls.HandlePush) diff --git a/server/orchestrator_ai_capabilities_manager.go b/server/orchestrator_ai_capabilities_manager.go new file mode 100644 index 0000000000..1511ee10d6 --- /dev/null +++ b/server/orchestrator_ai_capabilities_manager.go @@ -0,0 +1,190 @@ +package server + +import ( + "context" + "errors" + "fmt" + + "github.com/ethereum/go-ethereum/common/hexutil" + "github.com/golang/glog" + "github.com/livepeer/go-livepeer/common" + "github.com/livepeer/go-livepeer/core" +) + +// ModelStatus represents the status information for each model under an orchestrator. +type ModelStatus struct { + Cold int `json:"Cold"` + Warm int `json:"Warm"` +} + +// ModelInfo represents information about a model in a pipeline. +type ModelInfo struct { + Name string `json:"name"` + Status ModelStatus `json:"status"` +} + +// PipelineInfo represents information about a pipeline. +type PipelineInfo struct { + Type string `json:"type"` + Models []ModelInfo `json:"models"` +} + +// OrchestratorInfo represents the information for each orchestrator. +type OrchestratorInfo struct { + Address string `json:"address"` + Pipelines []PipelineInfo `json:"pipelines"` +} + +// OrchestratorAICapabilitiesManager represents the full structure with better type safety. +type OrchestratorAICapabilitiesManager struct { + Orchestrators []OrchestratorInfo `json:"orchestrators"` +} + +// NewOrchestratorAICapabilitiesManager creates and initializes a new OrchestratorAICapabilitiesManager structure. +func NewOrchestratorAICapabilitiesManager() *OrchestratorAICapabilitiesManager { + return &OrchestratorAICapabilitiesManager{ + Orchestrators: []OrchestratorInfo{}, + } +} + +// AddOrchestrator adds a new orchestrator if it doesn't exist. +func (ncm *OrchestratorAICapabilitiesManager) AddOrchestrator(orchAddress string) { + if ncm.getOrchestrator(orchAddress) == nil { + ncm.Orchestrators = append(ncm.Orchestrators, OrchestratorInfo{ + Address: orchAddress, + Pipelines: []PipelineInfo{}, + }) + } +} + +// getOrchestrator retrieves an orchestrator by address. +func (ncm *OrchestratorAICapabilitiesManager) getOrchestrator(orchAddress string) *OrchestratorInfo { + for i := range ncm.Orchestrators { + if ncm.Orchestrators[i].Address == orchAddress { + return &ncm.Orchestrators[i] + } + } + return nil +} + +// AddOrchestratorPipeline adds or updates a pipeline entry for an orchestrator. +func (ncm *OrchestratorAICapabilitiesManager) AddOrchestratorPipeline(orchAddress string, pipelineType string) { + orch := ncm.getOrchestrator(orchAddress) + if orch == nil { + ncm.AddOrchestrator(orchAddress) + orch = ncm.getOrchestrator(orchAddress) + } + + if orch.getPipeline(pipelineType) == nil { + orch.Pipelines = append(orch.Pipelines, PipelineInfo{ + Type: pipelineType, + Models: []ModelInfo{}, + }) + } +} + +// getPipeline retrieves a pipeline by type. +func (orch *OrchestratorInfo) getPipeline(pipelineType string) *PipelineInfo { + for i := range orch.Pipelines { + if orch.Pipelines[i].Type == pipelineType { + return &orch.Pipelines[i] + } + } + return nil +} + +// AddOrchestratorPipelineModel adds or updates a model in an orchestrator's pipeline. +func (ncm *OrchestratorAICapabilitiesManager) AddOrchestratorPipelineModel(orchAddress, pipelineType, modelName string, warm bool) error { + orch := ncm.getOrchestrator(orchAddress) + if orch == nil { + return errors.New("orchestrator not found") + } + pipeline := orch.getPipeline(pipelineType) + if pipeline == nil { + orch.Pipelines = append(orch.Pipelines, PipelineInfo{ + Type: pipelineType, + Models: []ModelInfo{}, + }) + pipeline = orch.getPipeline(pipelineType) + } + + model := pipeline.getModel(modelName) + if model == nil { + pipeline.Models = append(pipeline.Models, ModelInfo{ + Name: modelName, + Status: ModelStatus{}, + }) + model = pipeline.getModel(modelName) + } + + if warm { + model.Status.Warm++ + } else { + model.Status.Cold++ + } + return nil +} + +// getModel retrieves a model by name. +func (pipeline *PipelineInfo) getModel(modelName string) *ModelInfo { + for i := range pipeline.Models { + if pipeline.Models[i].Name == modelName { + return &pipeline.Models[i] + } + } + return nil +} + +// PrintJSONData is a utility function to print the current JSONData. +func (ncm *OrchestratorAICapabilitiesManager) PrintJSONData() { + fmt.Printf("Orchestrators: %+v\n", ncm.Orchestrators) +} + +func buildOrchestratorAICapabilitiesManager(livepeerNode *core.LivepeerNode) (*OrchestratorAICapabilitiesManager, error) { + caps := core.NewCapabilities(core.DefaultCapabilities(), nil) + caps.SetPerCapabilityConstraints(nil) + ods, err := livepeerNode.OrchestratorPool.GetOrchestrators(context.Background(), 100, newSuspender(), caps, common.ScoreAtLeast(0)) + caps.SetMinVersionConstraint(livepeerNode.Capabilities.MinVersionConstraint()) + if err != nil { + return nil, err + } + + networkCapsMgr := NewOrchestratorAICapabilitiesManager() + remoteInfos := ods.GetRemoteInfos() + + glog.V(common.SHORT).Infof("getting network capabilities for %d orchestrators", len(remoteInfos)) + + for idx, orch_info := range remoteInfos { + glog.V(common.DEBUG).Infof("getting capabilities for orchestrator %d %v", idx, orch_info.Transcoder) + + // Ensure the orch has the proper on-chain TicketParams to ensure ethAddress was configured. + tp := orch_info.TicketParams + var orchAddr string + if tp != nil { + ethAddress := tp.Recipient + orchAddr = hexutil.Encode(ethAddress) + } else { + orchAddr = "0x0000000000000000000000000000000000000000" + } + + // Parse the capabilities and capacities. + if orch_info.GetCapabilities() != nil { + for capability, constraints := range orch_info.Capabilities.Constraints.PerCapability { + capName, err := core.CapabilityToName(core.Capability(int(capability))) + if err != nil { + continue + } + networkCapsMgr.AddOrchestratorPipeline(orchAddr, capName) + + models := constraints.GetModels() + for model, constraint := range models { + err := networkCapsMgr.AddOrchestratorPipelineModel(orchAddr, capName, model, constraint.GetWarm()) + if err != nil { + glog.V(common.DEBUG).Infof(" error adding model to orch %v", err) + } + } + } + } + } + return networkCapsMgr, nil +} diff --git a/server/webserver.go b/server/webserver.go index cf1578b317..758baa275b 100644 --- a/server/webserver.go +++ b/server/webserver.go @@ -114,5 +114,7 @@ func (s *LivepeerServer) cliWebServerHandlers(bindAddr string) *http.ServeMux { mux.Handle("/metrics", monitor.Exporter) } + //AI Handlers + mux.Handle("/getOrchestratorAICapabilities", s.getOrchestratorAICapabilitiesHandler()) return mux }