diff --git a/selenoid.go b/selenoid.go index 8e1f3aa3..80d28dde 100644 --- a/selenoid.go +++ b/selenoid.go @@ -442,7 +442,7 @@ func getSessionTimeout(sessionTimeout string, maxTimeout time.Duration, defaultT if sessionTimeout != "" { st, err := time.ParseDuration(sessionTimeout) if err != nil { - return 0, fmt.Errorf("Invalid sessionTimeout capability: %v", err) + return 0, fmt.Errorf("invalid sessionTimeout capability: %v", err) } if st <= maxTimeout { return st, nil @@ -470,6 +470,8 @@ func generateRandomFileName(extension string) string { return "selenoid" + hex.EncodeToString(randBytes) + extension } +const vendorPrefix = "aerokube" + func proxy(w http.ResponseWriter, r *http.Request) { done := make(chan func()) go func() { @@ -486,6 +488,15 @@ func proxy(w http.ResponseWriter, r *http.Request) { id := fragments[2] sess, ok := sessions.Get(id) if ok { + if len(fragments) >= 4 && fragments[3] == vendorPrefix { + newFragments := []string{"", fragments[4], id} + if len(fragments) >= 5 { + newFragments = append(newFragments, fragments[5:]...) + } + r.URL.Host = (&request{r}).localaddr() + r.URL.Path = path.Clean(strings.Join(newFragments, slash)) + return + } sess.Lock.Lock() defer sess.Lock.Unlock() select { diff --git a/selenoid_test.go b/selenoid_test.go index d15c1e5b..a541cacf 100644 --- a/selenoid_test.go +++ b/selenoid_test.go @@ -811,6 +811,18 @@ func TestServeAndDeleteLogFile(t *testing.T) { } func TestFileDownload(t *testing.T) { + testFileDownload(t, func(sessionId string) string { + return fmt.Sprintf("/download/%s/testfile", sessionId) + }) +} + +func TestFileDownloadProtocolExtension(t *testing.T) { + testFileDownload(t, func(sessionId string) string { + return fmt.Sprintf("/wd/hub/session/%s/aerokube/download/testfile", sessionId) + }) +} + +func testFileDownload(t *testing.T, path func(string) string) { manager = &HTTPTest{Handler: Selenium()} resp, err := http.Post(With(srv.URL).Path("/wd/hub/session"), "", bytes.NewReader([]byte("{}"))) @@ -819,7 +831,7 @@ func TestFileDownload(t *testing.T) { var sess map[string]string AssertThat(t, resp, AllOf{Code{http.StatusOK}, IsJson{&sess}}) - rsp, err := http.Get(With(srv.URL).Path(fmt.Sprintf("/download/%s/testfile", sess["sessionId"]))) + rsp, err := http.Get(With(srv.URL).Path(path(sess["sessionId"]))) AssertThat(t, err, Is{nil}) AssertThat(t, rsp, Code{http.StatusOK}) data, err := io.ReadAll(rsp.Body) @@ -837,6 +849,18 @@ func TestFileDownloadMissingSession(t *testing.T) { } func TestClipboard(t *testing.T) { + testClipboard(t, func(sessionId string) string { + return fmt.Sprintf("/clipboard/%s", sessionId) + }) +} + +func TestClipboardProtocolExtension(t *testing.T) { + testClipboard(t, func(sessionId string) string { + return fmt.Sprintf("/wd/hub/session/%s/aerokube/clipboard", sessionId) + }) +} + +func testClipboard(t *testing.T, path func(string) string) { manager = &HTTPTest{Handler: Selenium()} resp, err := http.Post(With(srv.URL).Path("/wd/hub/session"), "", bytes.NewReader([]byte("{}"))) @@ -845,14 +869,14 @@ func TestClipboard(t *testing.T) { var sess map[string]string AssertThat(t, resp, AllOf{Code{http.StatusOK}, IsJson{&sess}}) - rsp, err := http.Get(With(srv.URL).Path(fmt.Sprintf("/clipboard/%s", sess["sessionId"]))) + rsp, err := http.Get(With(srv.URL).Path(path(sess["sessionId"]))) AssertThat(t, err, Is{nil}) AssertThat(t, rsp, Code{http.StatusOK}) data, err := io.ReadAll(rsp.Body) AssertThat(t, err, Is{nil}) AssertThat(t, string(data), EqualTo{"test-clipboard-value"}) - rsp, err = http.Post(With(srv.URL).Path(fmt.Sprintf("/clipboard/%s", sess["sessionId"])), "text/plain", bytes.NewReader([]byte("any-data"))) + rsp, err = http.Post(With(srv.URL).Path(path(sess["sessionId"])), "text/plain", bytes.NewReader([]byte("any-data"))) AssertThat(t, err, Is{nil}) AssertThat(t, rsp, Code{http.StatusOK})