[Refactor] Add a util function IsAutoscalingEnabled and refactor validations of RayJob deletion policy (#2775)

Signed-off-by: kaihsun <[email protected]>
kevin85421 authored Jan 20, 2025
1 parent f191a75 commit da6b356
Showing 8 changed files with 98 additions and 40 deletions.
@@ -51,7 +51,7 @@ func (v *VolcanoBatchScheduler) Name() string {
func (v *VolcanoBatchScheduler) DoBatchSchedulingOnSubmission(ctx context.Context, app *rayv1.RayCluster) error {
var minMember int32
var totalResource corev1.ResourceList
- if app.Spec.EnableInTreeAutoscaling == nil || !*app.Spec.EnableInTreeAutoscaling {
+ if !utils.IsAutoscalingEnabled(app) {
minMember = utils.CalculateDesiredReplicas(ctx, app) + 1
totalResource = utils.CalculateDesiredResources(app)
} else {
6 changes: 3 additions & 3 deletions ray-operator/controllers/ray/common/pod.go
@@ -173,7 +173,7 @@ func DefaultHeadPodTemplate(ctx context.Context, instance rayv1.RayCluster, head
initTemplateAnnotations(instance, &podTemplate)

// if in-tree autoscaling is enabled, then autoscaler container should be injected into head pod.
- if instance.Spec.EnableInTreeAutoscaling != nil && *instance.Spec.EnableInTreeAutoscaling {
+ if utils.IsAutoscalingEnabled(&instance) {
// The default autoscaler is not compatible with Kubernetes. As a result, we disable
// the monitor process by default and inject a KubeRay autoscaler side container into the head pod.
headSpec.RayStartParams["no-monitor"] = "true"
@@ -380,7 +380,7 @@ func initLivenessAndReadinessProbe(rayContainer *corev1.Container, rayNodeType r
}

// BuildPod a pod config
- func BuildPod(ctx context.Context, podTemplateSpec corev1.PodTemplateSpec, rayNodeType rayv1.RayNodeType, rayStartParams map[string]string, headPort string, enableRayAutoscaler *bool, creatorCRDType utils.CRDType, fqdnRayIP string) (aPod corev1.Pod) {
+ func BuildPod(ctx context.Context, podTemplateSpec corev1.PodTemplateSpec, rayNodeType rayv1.RayNodeType, rayStartParams map[string]string, headPort string, enableRayAutoscaler bool, creatorCRDType utils.CRDType, fqdnRayIP string) (aPod corev1.Pod) {
log := ctrl.LoggerFrom(ctx)

// For Worker Pod: Traffic readiness is determined by the readiness probe.
@@ -405,7 +405,7 @@ func BuildPod(ctx context.Context, podTemplateSpec corev1.PodTemplateSpec, rayNo

// Add /dev/shm volumeMount for the object store to avoid performance degradation.
addEmptyDir(ctx, &pod.Spec.Containers[utils.RayContainerIndex], &pod, SharedMemoryVolumeName, SharedMemoryVolumeMountPath, corev1.StorageMediumMemory)
- if rayNodeType == rayv1.HeadNode && enableRayAutoscaler != nil && *enableRayAutoscaler {
+ if rayNodeType == rayv1.HeadNode && enableRayAutoscaler {
// The Ray autoscaler writes logs which are read by the Ray head.
// We need a shared log volume to enable this information flow.
// Specifically, this is required for the event-logging functionality
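The BuildPod signature change above narrows enableRayAutoscaler from *bool to bool: callers now resolve the optional CRD flag once (via utils.IsAutoscalingEnabled), so the function body no longer needs a nil check before dereferencing. A minimal, self-contained sketch of the same narrowing, using hypothetical function names rather than the real KubeRay API:

package main

import "fmt"

// Before: the callee had to handle the tri-state pointer (nil / false / true) itself.
func needsAutoscalerSidecarOld(enable *bool) bool {
	return enable != nil && *enable
}

// After: the caller resolves the flag once and passes a plain bool.
func needsAutoscalerSidecarNew(enable bool) bool {
	return enable
}

func main() {
	var unset *bool
	fmt.Println(needsAutoscalerSidecarOld(unset)) // false: flag was never set
	fmt.Println(needsAutoscalerSidecarNew(true))  // true: the caller already decided
}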
20 changes: 10 additions & 10 deletions ray-operator/controllers/ray/common/pod_test.go
@@ -576,7 +576,7 @@ func TestBuildPod(t *testing.T) {
// Test head pod
podName := strings.ToLower(cluster.Name + utils.DashSymbol + string(rayv1.HeadNode) + utils.DashSymbol + utils.FormatInt32(0))
podTemplateSpec := DefaultHeadPodTemplate(ctx, *cluster, cluster.Spec.HeadGroupSpec, podName, "6379")
- pod := BuildPod(ctx, podTemplateSpec, rayv1.HeadNode, cluster.Spec.HeadGroupSpec.RayStartParams, "6379", nil, utils.GetCRDType(""), "")
+ pod := BuildPod(ctx, podTemplateSpec, rayv1.HeadNode, cluster.Spec.HeadGroupSpec.RayStartParams, "6379", false, utils.GetCRDType(""), "")

// Check environment variables
rayContainer := pod.Spec.Containers[utils.RayContainerIndex]
@@ -631,7 +631,7 @@ func TestBuildPod(t *testing.T) {
podName = cluster.Name + utils.DashSymbol + string(rayv1.WorkerNode) + utils.DashSymbol + worker.GroupName + utils.DashSymbol + utils.FormatInt32(0)
fqdnRayIP := utils.GenerateFQDNServiceName(ctx, *cluster, cluster.Namespace)
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379")
- pod = BuildPod(ctx, podTemplateSpec, rayv1.WorkerNode, worker.RayStartParams, "6379", nil, utils.GetCRDType(""), fqdnRayIP)
+ pod = BuildPod(ctx, podTemplateSpec, rayv1.WorkerNode, worker.RayStartParams, "6379", false, utils.GetCRDType(""), fqdnRayIP)

// Check resources
rayContainer = pod.Spec.Containers[utils.RayContainerIndex]
@@ -694,7 +694,7 @@ func TestBuildPod_WithNoCPULimits(t *testing.T) {
// Test head pod
podName := strings.ToLower(cluster.Name + utils.DashSymbol + string(rayv1.HeadNode) + utils.DashSymbol + utils.FormatInt32(0))
podTemplateSpec := DefaultHeadPodTemplate(ctx, *cluster, cluster.Spec.HeadGroupSpec, podName, "6379")
- pod := BuildPod(ctx, podTemplateSpec, rayv1.HeadNode, cluster.Spec.HeadGroupSpec.RayStartParams, "6379", nil, utils.GetCRDType(""), "")
+ pod := BuildPod(ctx, podTemplateSpec, rayv1.HeadNode, cluster.Spec.HeadGroupSpec.RayStartParams, "6379", false, utils.GetCRDType(""), "")
expectedCommandArg := splitAndSort("ulimit -n 65536; ray start --head --block --dashboard-agent-listen-port=52365 --memory=1073741824 --num-cpus=2 --metrics-export-port=8080 --dashboard-host=0.0.0.0")
actualCommandArg := splitAndSort(pod.Spec.Containers[0].Args[0])
if !reflect.DeepEqual(expectedCommandArg, actualCommandArg) {
@@ -706,7 +706,7 @@ func TestBuildPod_WithNoCPULimits(t *testing.T) {
podName = cluster.Name + utils.DashSymbol + string(rayv1.WorkerNode) + utils.DashSymbol + worker.GroupName + utils.DashSymbol + utils.FormatInt32(0)
fqdnRayIP := utils.GenerateFQDNServiceName(ctx, *cluster, cluster.Namespace)
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379")
- pod = BuildPod(ctx, podTemplateSpec, rayv1.WorkerNode, worker.RayStartParams, "6379", nil, utils.GetCRDType(""), fqdnRayIP)
+ pod = BuildPod(ctx, podTemplateSpec, rayv1.WorkerNode, worker.RayStartParams, "6379", false, utils.GetCRDType(""), fqdnRayIP)
expectedCommandArg = splitAndSort("ulimit -n 65536; ray start --block --dashboard-agent-listen-port=52365 --memory=1073741824 --num-cpus=2 --num-gpus=3 --address=raycluster-sample-head-svc.default.svc.cluster.local:6379 --port=6379 --metrics-export-port=8080")
actualCommandArg = splitAndSort(pod.Spec.Containers[0].Args[0])
if !reflect.DeepEqual(expectedCommandArg, actualCommandArg) {
@@ -730,7 +730,7 @@ func TestBuildPod_WithOverwriteCommand(t *testing.T) {

podName := strings.ToLower(cluster.Name + utils.DashSymbol + string(rayv1.HeadNode) + utils.DashSymbol + utils.FormatInt32(0))
podTemplateSpec := DefaultHeadPodTemplate(ctx, *cluster, cluster.Spec.HeadGroupSpec, podName, "6379")
- headPod := BuildPod(ctx, podTemplateSpec, rayv1.HeadNode, cluster.Spec.HeadGroupSpec.RayStartParams, "6379", nil, utils.GetCRDType(""), "")
+ headPod := BuildPod(ctx, podTemplateSpec, rayv1.HeadNode, cluster.Spec.HeadGroupSpec.RayStartParams, "6379", false, utils.GetCRDType(""), "")
headContainer := headPod.Spec.Containers[utils.RayContainerIndex]
assert.Equal(t, headContainer.Command, []string{"I am head"})
assert.Equal(t, headContainer.Args, []string{"I am head again"})
@@ -739,7 +739,7 @@ func TestBuildPod_WithOverwriteCommand(t *testing.T) {
podName = cluster.Name + utils.DashSymbol + string(rayv1.WorkerNode) + utils.DashSymbol + worker.GroupName + utils.DashSymbol + utils.FormatInt32(0)
fqdnRayIP := utils.GenerateFQDNServiceName(ctx, *cluster, cluster.Namespace)
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379")
- workerPod := BuildPod(ctx, podTemplateSpec, rayv1.WorkerNode, worker.RayStartParams, "6379", nil, utils.GetCRDType(""), fqdnRayIP)
+ workerPod := BuildPod(ctx, podTemplateSpec, rayv1.WorkerNode, worker.RayStartParams, "6379", false, utils.GetCRDType(""), fqdnRayIP)
workerContainer := workerPod.Spec.Containers[utils.RayContainerIndex]
assert.Equal(t, workerContainer.Command, []string{"I am worker"})
assert.Equal(t, workerContainer.Args, []string{"I am worker again"})
@@ -751,7 +751,7 @@ func TestBuildPod_WithAutoscalerEnabled(t *testing.T) {
cluster.Spec.EnableInTreeAutoscaling = &trueFlag
podName := strings.ToLower(cluster.Name + utils.DashSymbol + string(rayv1.HeadNode) + utils.DashSymbol + utils.FormatInt32(0))
podTemplateSpec := DefaultHeadPodTemplate(ctx, *cluster, cluster.Spec.HeadGroupSpec, podName, "6379")
- pod := BuildPod(ctx, podTemplateSpec, rayv1.HeadNode, cluster.Spec.HeadGroupSpec.RayStartParams, "6379", &trueFlag, utils.GetCRDType(""), "")
+ pod := BuildPod(ctx, podTemplateSpec, rayv1.HeadNode, cluster.Spec.HeadGroupSpec.RayStartParams, "6379", true, utils.GetCRDType(""), "")

actualResult := pod.Labels[utils.RayClusterLabelKey]
expectedResult := cluster.Name
@@ -808,7 +808,7 @@ func TestBuildPod_WithCreatedByRayService(t *testing.T) {
cluster.Spec.EnableInTreeAutoscaling = &trueFlag
podName := strings.ToLower(cluster.Name + utils.DashSymbol + string(rayv1.HeadNode) + utils.DashSymbol + utils.FormatInt32(0))
podTemplateSpec := DefaultHeadPodTemplate(ctx, *cluster, cluster.Spec.HeadGroupSpec, podName, "6379")
- pod := BuildPod(ctx, podTemplateSpec, rayv1.HeadNode, cluster.Spec.HeadGroupSpec.RayStartParams, "6379", &trueFlag, utils.RayServiceCRD, "")
+ pod := BuildPod(ctx, podTemplateSpec, rayv1.HeadNode, cluster.Spec.HeadGroupSpec.RayStartParams, "6379", true, utils.RayServiceCRD, "")

val, ok := pod.Labels[utils.RayClusterServingServiceLabelKey]
assert.True(t, ok, "Expected serve label is not present")
@@ -819,7 +819,7 @@ func TestBuildPod_WithCreatedByRayService(t *testing.T) {
podName = cluster.Name + utils.DashSymbol + string(rayv1.WorkerNode) + utils.DashSymbol + worker.GroupName + utils.DashSymbol + utils.FormatInt32(0)
fqdnRayIP := utils.GenerateFQDNServiceName(ctx, *cluster, cluster.Namespace)
podTemplateSpec = DefaultWorkerPodTemplate(ctx, *cluster, worker, podName, fqdnRayIP, "6379")
- pod = BuildPod(ctx, podTemplateSpec, rayv1.WorkerNode, worker.RayStartParams, "6379", nil, utils.RayServiceCRD, fqdnRayIP)
+ pod = BuildPod(ctx, podTemplateSpec, rayv1.WorkerNode, worker.RayStartParams, "6379", false, utils.RayServiceCRD, fqdnRayIP)

val, ok = pod.Labels[utils.RayClusterServingServiceLabelKey]
assert.True(t, ok, "Expected serve label is not present")
@@ -891,7 +891,7 @@ func TestBuildPodWithAutoscalerOptions(t *testing.T) {
SecurityContext: &customSecurityContext,
}
podTemplateSpec := DefaultHeadPodTemplate(ctx, *cluster, cluster.Spec.HeadGroupSpec, podName, "6379")
- pod := BuildPod(ctx, podTemplateSpec, rayv1.HeadNode, cluster.Spec.HeadGroupSpec.RayStartParams, "6379", &trueFlag, utils.GetCRDType(""), "")
+ pod := BuildPod(ctx, podTemplateSpec, rayv1.HeadNode, cluster.Spec.HeadGroupSpec.RayStartParams, "6379", true, utils.GetCRDType(""), "")
expectedContainer := *autoscalerContainer.DeepCopy()
expectedContainer.Image = customAutoscalerImage
expectedContainer.ImagePullPolicy = customPullPolicy
15 changes: 7 additions & 8 deletions ray-operator/controllers/ray/raycluster_controller.go
@@ -276,8 +276,7 @@ func validateRayClusterSpec(instance *rayv1.RayCluster) error {
}
}

- enableInTreeAutoscaling := (instance.Spec.EnableInTreeAutoscaling != nil) && (*instance.Spec.EnableInTreeAutoscaling)
- if enableInTreeAutoscaling {
+ if utils.IsAutoscalingEnabled(instance) {
for _, workerGroup := range instance.Spec.WorkerGroupSpecs {
if workerGroup.Suspend != nil && *workerGroup.Suspend {
// TODO (rueian): This can be supported in future Ray. We should check the RayVersion once we know the version.
@@ -943,7 +942,7 @@ func (r *RayClusterReconciler) reconcilePods(ctx context.Context, instance *rayv
// diff < 0 indicates the need to delete some Pods to match the desired number of replicas. However,
// randomly deleting Pods is certainly not ideal. So, if autoscaling is enabled for the cluster, we
// will disable random Pod deletion, making Autoscaler the sole decision-maker for Pod deletions.
- enableInTreeAutoscaling := (instance.Spec.EnableInTreeAutoscaling != nil) && (*instance.Spec.EnableInTreeAutoscaling)
+ enableInTreeAutoscaling := utils.IsAutoscalingEnabled(instance)

// TODO (kevin85421): `enableRandomPodDelete` is a feature flag for KubeRay v0.6.0. If users want to use
// the old behavior, they can set the environment variable `ENABLE_RANDOM_POD_DELETE` to `true`. When the
@@ -1174,7 +1173,7 @@ func (r *RayClusterReconciler) buildHeadPod(ctx context.Context, instance rayv1.
fqdnRayIP := utils.GenerateFQDNServiceName(ctx, instance, instance.Namespace) // Fully Qualified Domain Name
// The Ray head port used by workers to connect to the cluster (GCS server port for Ray >= 1.11.0, Redis port for older Ray.)
headPort := common.GetHeadPort(instance.Spec.HeadGroupSpec.RayStartParams)
- autoscalingEnabled := instance.Spec.EnableInTreeAutoscaling
+ autoscalingEnabled := utils.IsAutoscalingEnabled(&instance)
podConf := common.DefaultHeadPodTemplate(ctx, instance, instance.Spec.HeadGroupSpec, podName, headPort)
if len(r.headSidecarContainers) > 0 {
podConf.Spec.Containers = append(podConf.Spec.Containers, r.headSidecarContainers...)
@@ -1202,7 +1201,7 @@ func (r *RayClusterReconciler) buildWorkerPod(ctx context.Context, instance rayv

// The Ray head port used by workers to connect to the cluster (GCS server port for Ray >= 1.11.0, Redis port for older Ray.)
headPort := common.GetHeadPort(instance.Spec.HeadGroupSpec.RayStartParams)
- autoscalingEnabled := instance.Spec.EnableInTreeAutoscaling
+ autoscalingEnabled := utils.IsAutoscalingEnabled(&instance)
podTemplateSpec := common.DefaultWorkerPodTemplate(ctx, instance, worker, podName, fqdnRayIP, headPort)
if len(r.workerSidecarContainers) > 0 {
podTemplateSpec.Spec.Containers = append(podTemplateSpec.Spec.Containers, r.workerSidecarContainers...)
@@ -1580,7 +1579,7 @@ func (r *RayClusterReconciler) updateHeadInfo(ctx context.Context, instance *ray

func (r *RayClusterReconciler) reconcileAutoscalerServiceAccount(ctx context.Context, instance *rayv1.RayCluster) error {
logger := ctrl.LoggerFrom(ctx)
- if instance.Spec.EnableInTreeAutoscaling == nil || !*instance.Spec.EnableInTreeAutoscaling {
+ if !utils.IsAutoscalingEnabled(instance) {
return nil
}

@@ -1637,7 +1636,7 @@ func (r *RayClusterReconciler) reconcileAutoscalerServiceAccount(ctx context.Con

func (r *RayClusterReconciler) reconcileAutoscalerRole(ctx context.Context, instance *rayv1.RayCluster) error {
logger := ctrl.LoggerFrom(ctx)
- if instance.Spec.EnableInTreeAutoscaling == nil || !*instance.Spec.EnableInTreeAutoscaling {
+ if !utils.IsAutoscalingEnabled(instance) {
return nil
}

@@ -1679,7 +1678,7 @@ func (r *RayClusterReconciler) reconcileAutoscalerRoleBinding(ctx context.Context, instance *rayv1.RayCluster) error {

func (r *RayClusterReconciler) reconcileAutoscalerRoleBinding(ctx context.Context, instance *rayv1.RayCluster) error {
logger := ctrl.LoggerFrom(ctx)
- if instance.Spec.EnableInTreeAutoscaling == nil || !*instance.Spec.EnableInTreeAutoscaling {
+ if !utils.IsAutoscalingEnabled(instance) {
return nil
}

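The reconcilePods hunk above only swaps in the shared helper, but the surrounding comment spells out the policy: when in-tree autoscaling is enabled, random Pod deletion is disabled so the Autoscaler is the sole decision-maker, unless the ENABLE_RANDOM_POD_DELETE feature flag re-enables the old behavior. The following is a sketch of that decision as the comment describes it, not the operator's actual implementation; the function name is hypothetical:

package main

import "fmt"

// shouldRandomlyDeletePods mirrors the policy stated in the reconcilePods comment.
func shouldRandomlyDeletePods(autoscalingEnabled, enableRandomPodDelete bool) bool {
	if autoscalingEnabled {
		// The Autoscaler owns scale-down; random deletion only happens if the escape hatch is set.
		return enableRandomPodDelete
	}
	// Without autoscaling, the reconciler may delete arbitrary Pods to reach the target replica count.
	return true
}

func main() {
	fmt.Println(shouldRandomlyDeletePods(true, false))  // false: Autoscaler decides
	fmt.Println(shouldRandomlyDeletePods(true, true))   // true: ENABLE_RANDOM_POD_DELETE overrides
	fmt.Println(shouldRandomlyDeletePods(false, false)) // true: normal scale-down path
}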
41 changes: 24 additions & 17 deletions ray-operator/controllers/ray/rayjob_controller.go
@@ -885,10 +885,12 @@ func validateRayJobSpec(rayJob *rayv1.RayJob) error {
if rayJob.Spec.Suspend && !rayJob.Spec.ShutdownAfterJobFinishes {
return fmt.Errorf("a RayJob with shutdownAfterJobFinishes set to false is not allowed to be suspended")
}
- if rayJob.Spec.Suspend && len(rayJob.Spec.ClusterSelector) != 0 {
+
+ isClusterSelectorMode := len(rayJob.Spec.ClusterSelector) != 0
+ if rayJob.Spec.Suspend && isClusterSelectorMode {
return fmt.Errorf("the ClusterSelector mode doesn't support the suspend operation")
}
- if rayJob.Spec.RayClusterSpec == nil && len(rayJob.Spec.ClusterSelector) == 0 {
+ if rayJob.Spec.RayClusterSpec == nil && !isClusterSelectorMode {
return fmt.Errorf("one of RayClusterSpec or ClusterSelector must be set")
}
// Validate whether RuntimeEnvYAML is a valid YAML string. Note that this only checks its validity
@@ -905,21 +907,26 @@ func validateRayJobSpec(rayJob *rayv1.RayJob) error {
if !features.Enabled(features.RayJobDeletionPolicy) && rayJob.Spec.DeletionPolicy != nil {
return fmt.Errorf("RayJobDeletionPolicy feature gate must be enabled to use the DeletionPolicy feature")
}
- if rayJob.Spec.ClusterSelector != nil &&
- rayJob.Spec.DeletionPolicy != nil && *rayJob.Spec.DeletionPolicy == rayv1.DeleteClusterDeletionPolicy {
- return fmt.Errorf("the ClusterSelector mode doesn't support DeletionPolicy=DeleteCluster")
- }
- if rayJob.Spec.ClusterSelector != nil &&
- rayJob.Spec.DeletionPolicy != nil && *rayJob.Spec.DeletionPolicy == rayv1.DeleteWorkersDeletionPolicy {
- return fmt.Errorf("the ClusterSelector mode doesn't support DeletionPolicy=DeleteWorkers")
- }
- if rayJob.Spec.DeletionPolicy != nil && *rayJob.Spec.DeletionPolicy == rayv1.DeleteWorkersDeletionPolicy &&
- rayJob.Spec.RayClusterSpec.EnableInTreeAutoscaling != nil && *rayJob.Spec.RayClusterSpec.EnableInTreeAutoscaling {
- // TODO (rueian): This can be supported in future Ray. We should check the RayVersion once we know the version.
- return fmt.Errorf("DeletionPolicy=DeleteWorkers currently does not support RayClusterSpec.EnableInTreeAutoscaling")
- }
- if rayJob.Spec.ShutdownAfterJobFinishes && rayJob.Spec.DeletionPolicy != nil && *rayJob.Spec.DeletionPolicy == rayv1.DeleteNoneDeletionPolicy {
- return fmt.Errorf("shutdownAfterJobFinshes is set to 'true' while deletion policy is 'DeleteNone'")
+
+ if rayJob.Spec.DeletionPolicy != nil {
+ policy := *rayJob.Spec.DeletionPolicy
+ if isClusterSelectorMode {
+ switch policy {
+ case rayv1.DeleteClusterDeletionPolicy:
+ return fmt.Errorf("the ClusterSelector mode doesn't support DeletionPolicy=DeleteCluster")
+ case rayv1.DeleteWorkersDeletionPolicy:
+ return fmt.Errorf("the ClusterSelector mode doesn't support DeletionPolicy=DeleteWorkers")
+ }
+ }
+
+ if policy == rayv1.DeleteWorkersDeletionPolicy && utils.IsAutoscalingEnabled(rayJob) {
+ // TODO (rueian): This can be supported in a future Ray version. We should check the RayVersion once we know it.
+ return fmt.Errorf("DeletionPolicy=DeleteWorkers currently does not support RayCluster with autoscaling enabled")
+ }
+
+ if rayJob.Spec.ShutdownAfterJobFinishes && policy == rayv1.DeleteNoneDeletionPolicy {
+ return fmt.Errorf("shutdownAfterJobFinshes is set to 'true' while deletion policy is 'DeleteNone'")
}
+ }
return nil
}
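To make the consolidated DeletionPolicy branches above concrete, here is a sketch of two RayJob specs the refactored validateRayJobSpec now rejects. It follows the style of the existing unit tests and assumes it lives in the same package with the RayJobDeletionPolicy feature gate enabled, as the surrounding tests do; the ClusterSelector key and value are only illustrative. This sketch is not part of the commit:

func TestRayJobDeletionPolicyValidationSketch(t *testing.T) {
	deleteWorkers := rayv1.DeleteWorkersDeletionPolicy

	// ClusterSelector mode cannot be combined with DeleteWorkers (or DeleteCluster).
	err := validateRayJobSpec(&rayv1.RayJob{
		Spec: rayv1.RayJobSpec{
			ClusterSelector: map[string]string{"ray.io/cluster": "existing-cluster"},
			DeletionPolicy:  &deleteWorkers,
		},
	})
	assert.ErrorContains(t, err, "the ClusterSelector mode doesn't support DeletionPolicy=DeleteWorkers")

	// DeleteWorkers cannot be combined with an autoscaling-enabled RayClusterSpec,
	// now detected through utils.IsAutoscalingEnabled(rayJob).
	err = validateRayJobSpec(&rayv1.RayJob{
		Spec: rayv1.RayJobSpec{
			RayClusterSpec: &rayv1.RayClusterSpec{
				EnableInTreeAutoscaling: ptr.To[bool](true),
			},
			DeletionPolicy: &deleteWorkers,
		},
	})
	assert.ErrorContains(t, err, "DeletionPolicy=DeleteWorkers currently does not support RayCluster with autoscaling enabled")
}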
@@ -403,7 +403,7 @@ func TestValidateRayJobSpec(t *testing.T) {
},
},
})
assert.ErrorContains(t, err, "DeletionPolicy=DeleteWorkers currently does not support RayClusterSpec.EnableInTreeAutoscaling")
assert.ErrorContains(t, err, "DeletionPolicy=DeleteWorkers currently does not support RayCluster with autoscaling enabled")

err = validateRayJobSpec(&rayv1.RayJob{
Spec: rayv1.RayJobSpec{
13 changes: 13 additions & 0 deletions ray-operator/controllers/ray/utils/util.go
@@ -620,3 +620,16 @@ func ManagedByExternalController(controllerName *string) *string {
}
return nil
}

func IsAutoscalingEnabled[T *rayv1.RayCluster | *rayv1.RayJob | *rayv1.RayService](obj T) bool {
switch obj := (interface{})(obj).(type) {
case *rayv1.RayCluster:
return obj.Spec.EnableInTreeAutoscaling != nil && *obj.Spec.EnableInTreeAutoscaling
case *rayv1.RayJob:
return obj.Spec.RayClusterSpec != nil && obj.Spec.RayClusterSpec.EnableInTreeAutoscaling != nil && *obj.Spec.RayClusterSpec.EnableInTreeAutoscaling
case *rayv1.RayService:
return obj.Spec.RayClusterSpec.EnableInTreeAutoscaling != nil && *obj.Spec.RayClusterSpec.EnableInTreeAutoscaling
default:
panic(fmt.Sprintf("unsupported type: %T", obj))
}
}
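The new helper relies on a Go 1.18+ generics idiom: the type-set union in the constraint limits callers to the three CRD pointer types at compile time, while the body still needs a runtime type switch, which is why the value is first converted to an interface. A minimal, self-contained sketch of the same pattern with stand-in types rather than the rayv1 API:

package main

import "fmt"

// Stand-ins for the real CRD types, just to show the shape of the pattern.
type Cluster struct{ Enabled *bool }
type Job struct{ Cluster *Cluster }

// The union constraint rejects unsupported types at compile time; the type switch
// on any(obj) picks the concrete branch at run time.
func isEnabled[T *Cluster | *Job](obj T) bool {
	switch o := any(obj).(type) {
	case *Cluster:
		return o.Enabled != nil && *o.Enabled
	case *Job:
		return o.Cluster != nil && o.Cluster.Enabled != nil && *o.Cluster.Enabled
	default:
		// Unreachable given the constraint; kept in the same defensive style as the helper above.
		panic(fmt.Sprintf("unsupported type: %T", obj))
	}
}

func main() {
	enabled := true
	fmt.Println(isEnabled(&Cluster{Enabled: &enabled})) // true
	fmt.Println(isEnabled(&Job{}))                      // false: nested spec unset
}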
39 changes: 39 additions & 0 deletions ray-operator/controllers/ray/utils/util_test.go
@@ -713,3 +713,42 @@ func TestErrRayClusterReplicaFailureReason(t *testing.T) {
assert.Equal(t, RayClusterReplicaFailureReason(errors.Join(ErrFailedCreateWorkerPod, errors.New("other error"))), "FailedCreateWorkerPod")
assert.Equal(t, RayClusterReplicaFailureReason(errors.New("other error")), "")
}

func TestIsAutoscalingEnabled(t *testing.T) {
// Test: RayCluster
cluster := &rayv1.RayCluster{}
assert.False(t, IsAutoscalingEnabled(cluster))

cluster = &rayv1.RayCluster{
Spec: rayv1.RayClusterSpec{
EnableInTreeAutoscaling: ptr.To[bool](true),
},
}
assert.True(t, IsAutoscalingEnabled(cluster))

// Test: RayJob
job := &rayv1.RayJob{}
assert.False(t, IsAutoscalingEnabled(job))

job = &rayv1.RayJob{
Spec: rayv1.RayJobSpec{
RayClusterSpec: &rayv1.RayClusterSpec{
EnableInTreeAutoscaling: ptr.To[bool](true),
},
},
}
assert.True(t, IsAutoscalingEnabled(job))

// Test: RayService
service := &rayv1.RayService{}
assert.False(t, IsAutoscalingEnabled(service))

service = &rayv1.RayService{
Spec: rayv1.RayServiceSpec{
RayClusterSpec: rayv1.RayClusterSpec{
EnableInTreeAutoscaling: ptr.To[bool](true),
},
},
}
assert.True(t, IsAutoscalingEnabled(service))
}
