Skip to content

Commit

Permalink
[shim] Implement multi-task state
Browse files Browse the repository at this point in the history
The shim is now able to process multiple tasks at a time.
The global state for the existing "legacy" API is emulated.
From the API caller's point of view, the shim works exactly
the same as before.

Part-of: #1780
  • Loading branch information
un-def committed Dec 10, 2024
1 parent 7be0fb0 commit 3367301
Show file tree
Hide file tree
Showing 10 changed files with 424 additions and 120 deletions.
10 changes: 0 additions & 10 deletions runner/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,6 @@ These are nonexhaustive lists of external dependencies (executables, libraries)

### `dstack-shim`

#### Libraries

* libc
* ...

#### Executables

* `mount`
Expand All @@ -104,11 +99,6 @@ Debian/Ubuntu packages: `mount` (`mount`, `umount`), `util-linux` (`mountpoint`,

### `dstack-runner`

#### Libraries

* libc
* ...

#### Executables

* ...
2 changes: 1 addition & 1 deletion runner/internal/runner/api/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func (ds DummyRunner) GetState() (shim.RunnerStatus, shim.JobResult) {
return ds.State, ds.JobResult
}

func (ds DummyRunner) Run(context.Context, shim.TaskConfig) error {
func (ds DummyRunner) Run(context.Context, shim.Task) error {
return nil
}

Expand Down
23 changes: 20 additions & 3 deletions runner/internal/shim/api/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,31 @@ func (s *ShimServer) SubmitPostHandler(w http.ResponseWriter, r *http.Request) (
return nil, &api.Error{Status: http.StatusConflict}
}

var body TaskConfigBody
var body SubmitBody
if err := api.DecodeJSONBody(w, r, &body, true); err != nil {
log.Println("Failed to decode submit body", "err", err)
return nil, err
}

go func(taskConfig shim.TaskConfig) {
err := s.runner.Run(context.Background(), taskConfig)
go func(body SubmitBody) {
cfg := shim.TaskConfig{
ID: shim.LegacyTaskID,
Name: body.ContainerName,
RegistryUsername: body.Username,
RegistryPassword: body.Password,
ImageName: body.ImageName,
ContainerUser: body.ContainerUser,
Privileged: body.Privileged,
ShmSize: body.ShmSize,
PublicKeys: body.PublicKeys,
SshUser: body.SshUser,
SshKey: body.SshKey,
Volumes: body.Volumes,
VolumeMounts: body.VolumeMounts,
InstanceMounts: body.InstanceMounts,
}
task := shim.NewTask(cfg)
err := s.runner.Run(context.Background(), task)
if err != nil {
fmt.Printf("failed Run %v\n", err)
}
Expand Down
16 changes: 15 additions & 1 deletion runner/internal/shim/api/schemas.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,21 @@ package api

import "github.com/dstackai/dstack/runner/internal/shim"

type TaskConfigBody = shim.TaskConfig
// SubmitBody is the JSON request body decoded by SubmitPostHandler, which is
// registered for POST /api/submit (part of the legacy single-task API).
// The handler copies these fields into a shim.TaskConfig with the fixed ID
// shim.LegacyTaskID before starting the task.
type SubmitBody struct {
// Username and Password are copied into TaskConfig.RegistryUsername and
// TaskConfig.RegistryPassword respectively.
Username string `json:"username"`
Password string `json:"password"`
ImageName string `json:"image_name"`
Privileged bool `json:"privileged"`
// ContainerName is copied into TaskConfig.Name.
ContainerName string `json:"container_name"`
ContainerUser string `json:"container_user"`
// NOTE(review): the unit of ShmSize is not visible here (presumably bytes) — confirm.
ShmSize int64 `json:"shm_size"`
PublicKeys []string `json:"public_keys"`
SshUser string `json:"ssh_user"`
SshKey string `json:"ssh_key"`
// VolumeMounts is carried under the JSON key "mounts", not "volume_mounts".
VolumeMounts []shim.VolumeMountPoint `json:"mounts"`
Volumes []shim.VolumeInfo `json:"volumes"`
InstanceMounts []shim.InstanceMountPoint `json:"instance_mounts"`
}

type StopBody struct {
Force bool `json:"force"`
Expand Down
9 changes: 7 additions & 2 deletions runner/internal/shim/api/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
)

type TaskRunner interface {
Run(context.Context, shim.TaskConfig) error
Run(context.Context, shim.Task) error
GetState() (shim.RunnerStatus, shim.JobResult)
Stop(bool)
}
Expand All @@ -36,8 +36,13 @@ func NewShimServer(address string, runner TaskRunner, version string) *ShimServe

version: version,
}
mux.HandleFunc("/api/submit", api.JSONResponseHandler("POST", s.SubmitPostHandler))
// The healthcheck endpoint should stay backward compatible, as it is used for negotiation
mux.HandleFunc("/api/healthcheck", api.JSONResponseHandler("GET", s.HealthcheckGetHandler))
// The following endpoints constitute a so-called legacy API, where shim has one global state
// and is able to process only one task at a time
// NOTE: as of 2024-12-10, there is _only_ the legacy API, but the "legacy" label is used to
// distinguish the "old" API from the upcoming new one
mux.HandleFunc("/api/submit", api.JSONResponseHandler("POST", s.SubmitPostHandler))
mux.HandleFunc("/api/pull", api.JSONResponseHandler("GET", s.PullGetHandler))
mux.HandleFunc("/api/stop", api.JSONResponseHandler("POST", s.StopPostHandler))
return s
Expand Down
Loading

0 comments on commit 3367301

Please sign in to comment.