From da20daf573a3d3f399d4ace7a35d0515aba883be Mon Sep 17 00:00:00 2001 From: HR Wu <5631010+heiruwu@users.noreply.github.com> Date: Mon, 16 Dec 2024 17:53:05 +0800 Subject: [PATCH] feat(ray): utilize redis as deployment config cache (#722) Because - avoid constant disk IO ops This commit - utilize redis as deployment config cache - write config to persistent disk only on program exit --- cmd/main/main.go | 6 +- cmd/worker/main.go | 6 +- pkg/mock/ray_mock.gen.go | 119 +++++++++++++++++++++++++++++++++++---- pkg/ray/const.go | 9 +-- pkg/ray/ray.go | 64 +++++++++++++-------- 5 files changed, 160 insertions(+), 44 deletions(-) diff --git a/cmd/main/main.go b/cmd/main/main.go index feb6a249..f3c0874e 100644 --- a/cmd/main/main.go +++ b/cmd/main/main.go @@ -160,9 +160,6 @@ func main() { publicGrpcS := grpc.NewServer(grpcServerOpts...) reflection.Register(publicGrpcS) - rayService := ray.NewRay() - defer rayService.Close() - mgmtPublicServiceClient, mgmtPublicServiceClientConn := external.InitMgmtPublicServiceClient(ctx) defer mgmtPublicServiceClientConn.Close() @@ -175,6 +172,9 @@ func main() { redisClient := redis.NewClient(&config.Config.Cache.Redis.RedisOptions) defer redisClient.Close() + rayService := ray.NewRay(redisClient) + defer rayService.Close() + temporalTracingInterceptor, err := opentelemetry.NewTracingInterceptor(opentelemetry.TracerOptions{ Tracer: otel.Tracer("temporal-tracer"), TextMapPropagator: propagator, diff --git a/cmd/worker/main.go b/cmd/worker/main.go index 3ba1b953..5b6fa900 100644 --- a/cmd/worker/main.go +++ b/cmd/worker/main.go @@ -109,12 +109,12 @@ func main() { db := database.GetSharedConnection() defer database.Close(db) - rayService := ray.NewRay() - defer rayService.Close() - redisClient := redis.NewClient(&config.Config.Cache.Redis.RedisOptions) defer redisClient.Close() + rayService := ray.NewRay(redisClient) + defer rayService.Close() + temporalTracingInterceptor, err := opentelemetry.NewTracingInterceptor(opentelemetry.TracerOptions{ Tracer: otel.Tracer("temporal-tracer"), TextMapPropagator: b3.New(b3.WithInjectEncoding(b3.B3MultipleHeader)), diff --git a/pkg/mock/ray_mock.gen.go b/pkg/mock/ray_mock.gen.go index b7ee024f..f7bb39e1 100644 --- a/pkg/mock/ray_mock.gen.go +++ b/pkg/mock/ray_mock.gen.go @@ -13,6 +13,7 @@ import ( "github.com/instill-ai/model-backend/pkg/ray/rayserver" commonpb "github.com/instill-ai/protogen-go/common/task/v1alpha" modelpb "github.com/instill-ai/protogen-go/model/model/v1alpha" + "github.com/redis/go-redis/v9" ) // RayMock implements ray.Ray @@ -26,8 +27,8 @@ type RayMock struct { beforeCloseCounter uint64 CloseMock mRayMockClose - funcInit func() - inspectFuncInit func() + funcInit func(rc *redis.Client) + inspectFuncInit func(rc *redis.Client) afterInitCounter uint64 beforeInitCounter uint64 InitMock mRayMockInit @@ -68,6 +69,7 @@ func NewRayMock(t minimock.Tester) *RayMock { m.CloseMock = mRayMockClose{mock: m} m.InitMock = mRayMockInit{mock: m} + m.InitMock.callArgs = []*RayMockInitParams{} m.IsRayServerReadyMock = mRayMockIsRayServerReady{mock: m} m.IsRayServerReadyMock.callArgs = []*RayMockIsRayServerReadyParams{} @@ -263,16 +265,31 @@ type mRayMockInit struct { defaultExpectation *RayMockInitExpectation expectations []*RayMockInitExpectation + callArgs []*RayMockInitParams + mutex sync.RWMutex + expectedInvocations uint64 } // RayMockInitExpectation specifies expectation struct of the Ray.Init type RayMockInitExpectation struct { - mock *RayMock + mock *RayMock + params *RayMockInitParams + paramPtrs *RayMockInitParamPtrs Counter uint64 } +// RayMockInitParams contains parameters of the Ray.Init +type RayMockInitParams struct { + rc *redis.Client +} + +// RayMockInitParamPtrs contains pointers to parameters of the Ray.Init +type RayMockInitParamPtrs struct { + rc **redis.Client +} + // Marks this method to be optional. The default behavior of any method with Return() is '1 or more', meaning // the test will fail minimock's automatic final call check if the mocked method was not called at least once. // Optional() makes method check to work in '0 or more' mode. @@ -284,7 +301,7 @@ func (mmInit *mRayMockInit) Optional() *mRayMockInit { } // Expect sets up expected params for Ray.Init -func (mmInit *mRayMockInit) Expect() *mRayMockInit { +func (mmInit *mRayMockInit) Expect(rc *redis.Client) *mRayMockInit { if mmInit.mock.funcInit != nil { mmInit.mock.t.Fatalf("RayMock.Init mock is already set by Set") } @@ -293,11 +310,44 @@ func (mmInit *mRayMockInit) Expect() *mRayMockInit { mmInit.defaultExpectation = &RayMockInitExpectation{} } + if mmInit.defaultExpectation.paramPtrs != nil { + mmInit.mock.t.Fatalf("RayMock.Init mock is already set by ExpectParams functions") + } + + mmInit.defaultExpectation.params = &RayMockInitParams{rc} + for _, e := range mmInit.expectations { + if minimock.Equal(e.params, mmInit.defaultExpectation.params) { + mmInit.mock.t.Fatalf("Expectation set by When has same params: %#v", *mmInit.defaultExpectation.params) + } + } + + return mmInit +} + +// ExpectRcParam1 sets up expected param rc for Ray.Init +func (mmInit *mRayMockInit) ExpectRcParam1(rc *redis.Client) *mRayMockInit { + if mmInit.mock.funcInit != nil { + mmInit.mock.t.Fatalf("RayMock.Init mock is already set by Set") + } + + if mmInit.defaultExpectation == nil { + mmInit.defaultExpectation = &RayMockInitExpectation{} + } + + if mmInit.defaultExpectation.params != nil { + mmInit.mock.t.Fatalf("RayMock.Init mock is already set by Expect") + } + + if mmInit.defaultExpectation.paramPtrs == nil { + mmInit.defaultExpectation.paramPtrs = &RayMockInitParamPtrs{} + } + mmInit.defaultExpectation.paramPtrs.rc = &rc + return mmInit } // Inspect accepts an inspector function that has same arguments as the Ray.Init -func (mmInit *mRayMockInit) Inspect(f func()) *mRayMockInit { +func (mmInit *mRayMockInit) Inspect(f func(rc *redis.Client)) *mRayMockInit { if mmInit.mock.inspectFuncInit != nil { mmInit.mock.t.Fatalf("Inspect function is already set for RayMock.Init") } @@ -321,7 +371,7 @@ func (mmInit *mRayMockInit) Return() *RayMock { } // Set uses given function f to mock the Ray.Init method -func (mmInit *mRayMockInit) Set(f func()) *RayMock { +func (mmInit *mRayMockInit) Set(f func(rc *redis.Client)) *RayMock { if mmInit.defaultExpectation != nil { mmInit.mock.t.Fatalf("Default expectation is already set for the Ray.Init method") } @@ -355,25 +405,53 @@ func (mmInit *mRayMockInit) invocationsDone() bool { } // Init implements ray.Ray -func (mmInit *RayMock) Init() { +func (mmInit *RayMock) Init(rc *redis.Client) { mm_atomic.AddUint64(&mmInit.beforeInitCounter, 1) defer mm_atomic.AddUint64(&mmInit.afterInitCounter, 1) if mmInit.inspectFuncInit != nil { - mmInit.inspectFuncInit() + mmInit.inspectFuncInit(rc) + } + + mm_params := RayMockInitParams{rc} + + // Record call args + mmInit.InitMock.mutex.Lock() + mmInit.InitMock.callArgs = append(mmInit.InitMock.callArgs, &mm_params) + mmInit.InitMock.mutex.Unlock() + + for _, e := range mmInit.InitMock.expectations { + if minimock.Equal(*e.params, mm_params) { + mm_atomic.AddUint64(&e.Counter, 1) + return + } } if mmInit.InitMock.defaultExpectation != nil { mm_atomic.AddUint64(&mmInit.InitMock.defaultExpectation.Counter, 1) + mm_want := mmInit.InitMock.defaultExpectation.params + mm_want_ptrs := mmInit.InitMock.defaultExpectation.paramPtrs + + mm_got := RayMockInitParams{rc} + + if mm_want_ptrs != nil { + + if mm_want_ptrs.rc != nil && !minimock.Equal(*mm_want_ptrs.rc, mm_got.rc) { + mmInit.t.Errorf("RayMock.Init got unexpected parameter rc, want: %#v, got: %#v%s\n", *mm_want_ptrs.rc, mm_got.rc, minimock.Diff(*mm_want_ptrs.rc, mm_got.rc)) + } + + } else if mm_want != nil && !minimock.Equal(*mm_want, mm_got) { + mmInit.t.Errorf("RayMock.Init got unexpected parameters, want: %#v, got: %#v%s\n", *mm_want, mm_got, minimock.Diff(*mm_want, mm_got)) + } return } if mmInit.funcInit != nil { - mmInit.funcInit() + mmInit.funcInit(rc) return } - mmInit.t.Fatalf("Unexpected call to RayMock.Init.") + mmInit.t.Fatalf("Unexpected call to RayMock.Init. %v", rc) } @@ -387,6 +465,19 @@ func (mmInit *RayMock) InitBeforeCounter() uint64 { return mm_atomic.LoadUint64(&mmInit.beforeInitCounter) } +// Calls returns a list of arguments used in each call to RayMock.Init. +// The list is in the same order as the calls were made (i.e. recent calls have a higher index) +func (mmInit *mRayMockInit) Calls() []*RayMockInitParams { + mmInit.mutex.RLock() + + argCopy := make([]*RayMockInitParams, len(mmInit.callArgs)) + copy(argCopy, mmInit.callArgs) + + mmInit.mutex.RUnlock() + + return argCopy +} + // MinimockInitDone returns true if the count of the Init invocations corresponds // the number of defined expectations func (m *RayMock) MinimockInitDone() bool { @@ -408,14 +499,18 @@ func (m *RayMock) MinimockInitDone() bool { func (m *RayMock) MinimockInitInspect() { for _, e := range m.InitMock.expectations { if mm_atomic.LoadUint64(&e.Counter) < 1 { - m.t.Error("Expected call to RayMock.Init") + m.t.Errorf("Expected call to RayMock.Init with params: %#v", *e.params) } } afterInitCounter := mm_atomic.LoadUint64(&m.afterInitCounter) // if default expectation was set then invocations count should be greater than zero if m.InitMock.defaultExpectation != nil && afterInitCounter < 1 { - m.t.Error("Expected call to RayMock.Init") + if m.InitMock.defaultExpectation.params == nil { + m.t.Error("Expected call to RayMock.Init") + } else { + m.t.Errorf("Expected call to RayMock.Init with params: %#v", *m.InitMock.defaultExpectation.params) + } } // if func was set then invocations count should be greater than zero if m.funcInit != nil && afterInitCounter < 1 { diff --git a/pkg/ray/const.go b/pkg/ray/const.go index e68666a2..cba587ee 100644 --- a/pkg/ray/const.go +++ b/pkg/ray/const.go @@ -282,6 +282,10 @@ var SupportedAcceleratorTypeMemory = map[string]int{ } const ( + // Ray redis key + RayDeploymentKey = "model_deployment_config" + + // Ray deployment env variables EnvIsTestModel = "RAY_IS_TEST_MODEL" EnvMemory = "RAY_MEMORY" EnvTotalVRAM = "RAY_TOTAL_VRAM" @@ -291,8 +295,5 @@ const ( EnvNumOfCPUs = "RAY_NUM_OF_CPUS" EnvNumOfMinReplicas = "RAY_NUM_OF_MIN_REPLICAS" EnvNumOfMaxReplicas = "RAY_NUM_OF_MAX_REPLICAS" -) - -const ( - DummyModelPrefix = "dummy-" + DummyModelPrefix = "dummy-" ) diff --git a/pkg/ray/ray.go b/pkg/ray/ray.go index 85fc9537..f2ad309f 100644 --- a/pkg/ray/ray.go +++ b/pkg/ray/ray.go @@ -23,6 +23,7 @@ import ( "github.com/instill-ai/model-backend/config" "github.com/instill-ai/model-backend/pkg/constant" "github.com/instill-ai/model-backend/pkg/ray/rayserver" + "github.com/redis/go-redis/v9" commonpb "github.com/instill-ai/protogen-go/common/task/v1alpha" modelpb "github.com/instill-ai/protogen-go/model/model/v1alpha" @@ -38,7 +39,7 @@ type Ray interface { // standard IsRayServerReady(ctx context.Context) bool UpdateContainerizedModel(ctx context.Context, modelName string, userID string, imageName string, version string, hardware string, action Action, scalingConfig []string, numOfGPU string) error - Init() + Init(rc *redis.Client) Close() } @@ -46,6 +47,7 @@ type ray struct { rayClient rayserver.RayServiceClient rayServeClient rayserver.RayServeAPIServiceClient rayHTTPClient *http.Client + redisClient *redis.Client connection *grpc.ClientConn configFilePath string configChan chan ApplicationWithAction @@ -55,15 +57,16 @@ type ray struct { var once sync.Once var rayService *ray -func NewRay() Ray { +func NewRay(rc *redis.Client) Ray { once.Do(func() { rayService = &ray{} - rayService.Init() + rayService.Init(rc) }) return rayService } -func (r *ray) Init() { +func (r *ray) Init(rc *redis.Client) { + ctx := context.Background() // Connect to gRPC server conn, err := grpc.NewClient( config.Config.RayServer.GrpcURI, @@ -78,6 +81,8 @@ func (r *ray) Init() { log.Fatalf("Couldn't connect to endpoint %s: %v", config.Config.RayServer.GrpcURI, err) } + r.redisClient = rc + // Create client from gRPC server connection r.connection = conn r.rayClient = rayserver.NewRayServiceClient(conn) @@ -87,28 +92,19 @@ func (r *ray) Init() { r.doneChan = make(chan error, 10000) r.configFilePath = path.Join(config.Config.RayServer.ModelStore, "deploy.yaml") - var modelDeploymentConfig ModelDeploymentConfig isCorrupted := false currentConfigFile, err := os.ReadFile(r.configFilePath) if err != nil { isCorrupted = true } - err = yaml.Unmarshal(currentConfigFile, &modelDeploymentConfig) - if err != nil { - isCorrupted = true - } - if _, err := os.Stat(r.configFilePath); os.IsNotExist(err) || isCorrupted { - initDeployConfig := ModelDeploymentConfig{ - Applications: []Application{}, - } - initConfigData, err := yaml.Marshal(&initDeployConfig) - if err != nil { - fmt.Printf("error while Marshaling deployment config: %v\n", err) - } - if err := os.WriteFile(r.configFilePath, initConfigData, 0666); err != nil { - fmt.Printf("error creating deployment config: %v\n", err) - } + if _, err := os.Stat(r.configFilePath); !os.IsNotExist(err) && !isCorrupted { + r.redisClient.Set( + ctx, + RayDeploymentKey, + currentConfigFile, + 0, + ) } // avoid race condition with file writing @@ -381,7 +377,10 @@ func (r *ray) sync() { var modelDeploymentConfig ModelDeploymentConfig - currentConfigFile, err := os.ReadFile(r.configFilePath) + currentConfigFile, err := r.redisClient.Get( + ctx, + RayDeploymentKey, + ).Bytes() if err != nil { logger.Error(fmt.Sprintf("error while reading deployment config: %v", err)) } @@ -414,7 +413,12 @@ func (r *ray) sync() { logger.Error(fmt.Sprintf("error while Marshaling YAML deployment config: %v", err)) } - if err := os.WriteFile(r.configFilePath, modelDeploymentConfigData, 0666); err != nil { + if err := r.redisClient.Set( + ctx, + RayDeploymentKey, + modelDeploymentConfigData, + 0, + ).Err(); err != nil { logger.Error(fmt.Sprintf("error creating deployment config: %v", err)) } @@ -449,6 +453,22 @@ func (r *ray) sync() { } func (r *ray) Close() { + ctx := context.Background() + + logger, _ := custom_logger.GetZapLogger(ctx) + + currentConfigFile, err := r.redisClient.Get( + ctx, + RayDeploymentKey, + ).Bytes() + if err != nil { + logger.Error(fmt.Sprintf("error while reading deployment config: %v", err)) + } + + if err := os.WriteFile(r.configFilePath, currentConfigFile, 0666); err != nil { + logger.Error(fmt.Sprintf("error creating deployment config: %v", err)) + } + if r.connection != nil { r.connection.Close() }