diff --git a/ray-operator/controllers/ray/common/job.go b/ray-operator/controllers/ray/common/job.go index 41d18728d1..dc4dbe74a3 100644 --- a/ray-operator/controllers/ray/common/job.go +++ b/ray-operator/controllers/ray/common/job.go @@ -3,6 +3,7 @@ package common import ( "encoding/json" "fmt" + "strconv" "strings" semver "github.com/Masterminds/semver/v3" @@ -33,15 +34,6 @@ func getRuntimeEnvJson(rayJobInstance *rayv1.RayJob) (string, error) { return "", nil } -// GetBaseRayJobCommand returns the first part of the Ray Job command up to and including the address, e.g. "ray job submit --address http://..." -func GetBaseRayJobCommand(address string) []string { - // add http:// if needed - if !strings.HasPrefix(address, "http://") { - address = "http://" + address - } - return []string{"ray", "job", "submit", "--address", address} -} - // GetMetadataJson returns the JSON string of the metadata for the Ray job. func GetMetadataJson(metadata map[string]string, rayVersion string) (string, error) { // Check that the Ray version is at least 2.6.0. @@ -73,14 +65,34 @@ func GetK8sJobCommand(rayJobInstance *rayv1.RayJob) ([]string, error) { entrypointNumGpus := rayJobInstance.Spec.EntrypointNumGpus entrypointResources := rayJobInstance.Spec.EntrypointResources - k8sJobCommand := GetBaseRayJobCommand(address) + // add http:// if needed + if !strings.HasPrefix(address, "http://") { + address = "http://" + address + } + + // `ray job submit` alone doesn't handle duplicated submission gracefully. See https://github.com/ray-project/kuberay/issues/2154. + // In order to deal with that, we use `ray job status` first to check if the jobId has been submitted. + // If the jobId has been submitted, we use `ray job logs` to follow the logs. + // Otherwise, we submit the job normally with `ray job submit`. The full shell command looks like this: + // if ray job status --address http://$RAY_ADDRESS $RAY_JOB_SUBMISSION_ID >/dev/null 2>&1 ; + // then ray job logs --address http://RAY_ADDRESS --follow $RAY_JOB_SUBMISSION_ID ; + // else ray job submit --address http://RAY_ADDRESS --submission-id $RAY_JOB_SUBMISSION_ID -- ... ; + // fi + jobStatusCommand := []string{"ray", "job", "status", "--address", address, jobId, ">/dev/null", "2>&1"} + jobFollowCommand := []string{"ray", "job", "logs", "--address", address, "--follow", jobId} + jobSubmitCommand := []string{"ray", "job", "submit", "--address", address} + k8sJobCommand := append([]string{"if"}, jobStatusCommand...) + k8sJobCommand = append(k8sJobCommand, ";", "then") + k8sJobCommand = append(k8sJobCommand, jobFollowCommand...) + k8sJobCommand = append(k8sJobCommand, ";", "else") + k8sJobCommand = append(k8sJobCommand, jobSubmitCommand...) runtimeEnvJson, err := getRuntimeEnvJson(rayJobInstance) if err != nil { return nil, err } if len(runtimeEnvJson) > 0 { - k8sJobCommand = append(k8sJobCommand, "--runtime-env-json", runtimeEnvJson) + k8sJobCommand = append(k8sJobCommand, "--runtime-env-json", strconv.Quote(runtimeEnvJson)) } if len(metadata) > 0 { @@ -88,7 +100,7 @@ func GetK8sJobCommand(rayJobInstance *rayv1.RayJob) ([]string, error) { if err != nil { return nil, err } - k8sJobCommand = append(k8sJobCommand, "--metadata-json", metadataJson) + k8sJobCommand = append(k8sJobCommand, "--metadata-json", strconv.Quote(metadataJson)) } if len(jobId) > 0 { @@ -104,7 +116,7 @@ func GetK8sJobCommand(rayJobInstance *rayv1.RayJob) ([]string, error) { } if len(entrypointResources) > 0 { - k8sJobCommand = append(k8sJobCommand, "--entrypoint-resources", entrypointResources) + k8sJobCommand = append(k8sJobCommand, "--entrypoint-resources", strconv.Quote(entrypointResources)) } // "--" is used to separate the entrypoint from the Ray Job CLI command and its arguments. @@ -116,6 +128,8 @@ func GetK8sJobCommand(rayJobInstance *rayv1.RayJob) ([]string, error) { } k8sJobCommand = append(k8sJobCommand, commandSlice...) + k8sJobCommand = append(k8sJobCommand, ";", "fi") + return k8sJobCommand, nil } diff --git a/ray-operator/controllers/ray/common/job_test.go b/ray-operator/controllers/ray/common/job_test.go index ba325bb874..da41eaeda4 100644 --- a/ray-operator/controllers/ray/common/job_test.go +++ b/ray-operator/controllers/ray/common/job_test.go @@ -2,6 +2,7 @@ package common import ( "encoding/json" + "strconv" "testing" "github.com/stretchr/testify/assert" @@ -62,12 +63,6 @@ pip: ["python-multipart==0.0.6"] assert.Equal(t, expectedMap, actualMap) } -func TestGetBaseRayJobCommand(t *testing.T) { - expected := []string{"ray", "job", "submit", "--address", "http://127.0.0.1:8265"} - command := GetBaseRayJobCommand(testRayJob.Status.DashboardURL) - assert.Equal(t, expected, command) -} - func TestGetMetadataJson(t *testing.T) { expected := `{"testKey":"testValue"}` metadataJson, err := GetMetadataJson(testRayJob.Spec.Metadata, testRayJob.Spec.RayClusterSpec.RayVersion) @@ -77,15 +72,21 @@ func TestGetMetadataJson(t *testing.T) { func TestGetK8sJobCommand(t *testing.T) { expected := []string{ + "if", + "ray", "job", "status", "--address", "http://127.0.0.1:8265", "testJobId", ">/dev/null", "2>&1", + ";", "then", + "ray", "job", "logs", "--address", "http://127.0.0.1:8265", "--follow", "testJobId", + ";", "else", "ray", "job", "submit", "--address", "http://127.0.0.1:8265", - "--runtime-env-json", `{"test":"test"}`, - "--metadata-json", `{"testKey":"testValue"}`, + "--runtime-env-json", strconv.Quote(`{"test":"test"}`), + "--metadata-json", strconv.Quote(`{"testKey":"testValue"}`), "--submission-id", "testJobId", "--entrypoint-num-cpus", "1.000000", "--entrypoint-num-gpus", "0.500000", - "--entrypoint-resources", `{"Custom_1": 1, "Custom_2": 5.5}`, + "--entrypoint-resources", strconv.Quote(`{"Custom_1": 1, "Custom_2": 5.5}`), "--", "echo", "hello", + ";", "fi", } command, err := GetK8sJobCommand(testRayJob) assert.NoError(t, err) @@ -113,12 +114,18 @@ pip: ["python-multipart==0.0.6"] }, } expected := []string{ + "if", + "ray", "job", "status", "--address", "http://127.0.0.1:8265", "testJobId", ">/dev/null", "2>&1", + ";", "then", + "ray", "job", "logs", "--address", "http://127.0.0.1:8265", "--follow", "testJobId", + ";", "else", "ray", "job", "submit", "--address", "http://127.0.0.1:8265", - "--runtime-env-json", `{"working_dir":"https://github.com/ray-project/serve_config_examples/archive/b393e77bbd6aba0881e3d94c05f968f05a387b96.zip","pip":["python-multipart==0.0.6"]}`, - "--metadata-json", `{"testKey":"testValue"}`, + "--runtime-env-json", strconv.Quote(`{"working_dir":"https://github.com/ray-project/serve_config_examples/archive/b393e77bbd6aba0881e3d94c05f968f05a387b96.zip","pip":["python-multipart==0.0.6"]}`), + "--metadata-json", strconv.Quote(`{"testKey":"testValue"}`), "--submission-id", "testJobId", "--", "echo", "hello", + ";", "fi", } command, err := GetK8sJobCommand(rayJobWithYAML) assert.NoError(t, err) @@ -127,11 +134,17 @@ pip: ["python-multipart==0.0.6"] assert.Equal(t, len(expected), len(command)) for i := 0; i < len(expected); i++ { + // For non-JSON elements, compare them directly. + assert.Equal(t, expected[i], command[i]) if expected[i] == "--runtime-env-json" { // Decode the JSON string from the next element. var expectedMap, actualMap map[string]interface{} - err1 := json.Unmarshal([]byte(expected[i+1]), &expectedMap) - err2 := json.Unmarshal([]byte(command[i+1]), &actualMap) + unquoteExpected, err1 := strconv.Unquote(expected[i+1]) + assert.NoError(t, err1) + unquotedCommand, err2 := strconv.Unquote(command[i+1]) + assert.NoError(t, err2) + err1 = json.Unmarshal([]byte(unquoteExpected), &expectedMap) + err2 = json.Unmarshal([]byte(unquotedCommand), &actualMap) // If there's an error decoding either JSON string, it's an error in the test. assert.NoError(t, err1) @@ -142,9 +155,6 @@ pip: ["python-multipart==0.0.6"] // Skip the next element because we've just checked it. i++ - } else { - // For non-JSON elements, compare them directly. - assert.Equal(t, expected[i], command[i]) } } } diff --git a/ray-operator/controllers/ray/rayjob_controller.go b/ray-operator/controllers/ray/rayjob_controller.go index ccf041f06f..0cb3ebd4e9 100644 --- a/ray-operator/controllers/ray/rayjob_controller.go +++ b/ray-operator/controllers/ray/rayjob_controller.go @@ -476,7 +476,8 @@ func (r *RayJobReconciler) getSubmitterTemplate(ctx context.Context, rayJobInsta if err != nil { return corev1.PodTemplateSpec{}, err } - submitterTemplate.Spec.Containers[utils.RayContainerIndex].Command = k8sJobCommand + submitterTemplate.Spec.Containers[utils.RayContainerIndex].Command = []string{"/bin/sh"} + submitterTemplate.Spec.Containers[utils.RayContainerIndex].Args = []string{"-c", strings.Join(k8sJobCommand, " ")} logger.Info("No command is specified in the user-provided template. Default command is used", "command", k8sJobCommand) } else { logger.Info("User-provided command is used", "command", submitterTemplate.Spec.Containers[utils.RayContainerIndex].Command) diff --git a/ray-operator/controllers/ray/rayjob_controller_unit_test.go b/ray-operator/controllers/ray/rayjob_controller_unit_test.go index ed767259e8..433cdcd438 100644 --- a/ray-operator/controllers/ray/rayjob_controller_unit_test.go +++ b/ray-operator/controllers/ray/rayjob_controller_unit_test.go @@ -166,12 +166,14 @@ func TestGetSubmitterTemplate(t *testing.T) { rayJobInstanceWithTemplate.Spec.SubmitterPodTemplate.Spec.Containers[utils.RayContainerIndex].Command = []string{} submitterTemplate, err = r.getSubmitterTemplate(ctx, rayJobInstanceWithTemplate, nil) assert.NoError(t, err) - assert.Equal(t, []string{"ray", "job", "submit", "--address", "http://test-url", "--submission-id", "test-job-id", "--", "echo", "hello", "world"}, submitterTemplate.Spec.Containers[utils.RayContainerIndex].Command) + assert.Equal(t, []string{"/bin/sh"}, submitterTemplate.Spec.Containers[utils.RayContainerIndex].Command) + assert.Equal(t, []string{"-c", "if ray job status --address http://test-url test-job-id >/dev/null 2>&1 ; then ray job logs --address http://test-url --follow test-job-id ; else ray job submit --address http://test-url --submission-id test-job-id -- echo hello world ; fi"}, submitterTemplate.Spec.Containers[utils.RayContainerIndex].Args) // Test 3: User did not provide template, should use the image of the Ray Head submitterTemplate, err = r.getSubmitterTemplate(ctx, rayJobInstanceWithoutTemplate, rayClusterInstance) assert.NoError(t, err) - assert.Equal(t, []string{"ray", "job", "submit", "--address", "http://test-url", "--submission-id", "test-job-id", "--", "echo", "hello", "world"}, submitterTemplate.Spec.Containers[utils.RayContainerIndex].Command) + assert.Equal(t, []string{"/bin/sh"}, submitterTemplate.Spec.Containers[utils.RayContainerIndex].Command) + assert.Equal(t, []string{"-c", "if ray job status --address http://test-url test-job-id >/dev/null 2>&1 ; then ray job logs --address http://test-url --follow test-job-id ; else ray job submit --address http://test-url --submission-id test-job-id -- echo hello world ; fi"}, submitterTemplate.Spec.Containers[utils.RayContainerIndex].Args) assert.Equal(t, "rayproject/ray:custom-version", submitterTemplate.Spec.Containers[utils.RayContainerIndex].Image) // Test 4: Check default PYTHONUNBUFFERED setting