diff --git a/client/go/internal/cli/cmd/query.go b/client/go/internal/cli/cmd/query.go index 5fa225777f02..95d8799c9b66 100644 --- a/client/go/internal/cli/cmd/query.go +++ b/client/go/internal/cli/cmd/query.go @@ -5,11 +5,13 @@ package cmd import ( + "bytes" "encoding/json" "fmt" "io" "net/http" "net/url" + "os" "strings" "time" @@ -22,14 +24,17 @@ import ( "github.com/vespa-engine/vespa/client/go/internal/vespa" ) +type queryOptions struct { + printCurl bool + queryTimeoutSecs int + waitSecs int + format string + postFile string + headers []string +} + func newQueryCmd(cli *CLI) *cobra.Command { - var ( - printCurl bool - queryTimeoutSecs int - waitSecs int - format string - headers []string - ) + opts := queryOptions{} cmd := &cobra.Command{ Use: "query query-parameters", Short: "Issue a query to Vespa", @@ -43,32 +48,45 @@ can be set by the syntax [parameter-name]=[value].`, // TODO: Support referencing a query json file DisableAutoGenTag: true, SilenceUsage: true, - Args: cobra.MinimumNArgs(1), + Args: cobra.MinimumNArgs(0), RunE: func(cmd *cobra.Command, args []string) error { - waiter := cli.waiter(time.Duration(waitSecs)*time.Second, cmd) - return query(cli, args, queryTimeoutSecs, printCurl, format, headers, waiter) + if len(args) == 0 && opts.postFile == "" { + return fmt.Errorf("requires at least 1 arg") + } + waiter := cli.waiter(time.Duration(opts.waitSecs)*time.Second, cmd) + return query(cli, args, &opts, waiter) }, } - cmd.Flags().BoolVarP(&printCurl, "verbose", "v", false, "Print the equivalent curl command for the query") - cmd.Flags().StringVarP(&format, "format", "", "human", "Output format. Must be 'human' (human-readable) or 'plain' (no formatting)") - cmd.Flags().StringSliceVarP(&headers, "header", "", nil, "Add a header to the HTTP request, on the format 'Header: Value'. This can be specified multiple times") - cmd.Flags().IntVarP(&queryTimeoutSecs, "timeout", "T", 10, "Timeout for the query in seconds") - cli.bindWaitFlag(cmd, 0, &waitSecs) + cmd.Flags().BoolVarP(&opts.printCurl, "verbose", "v", false, "Print the equivalent curl command for the query") + cmd.Flags().StringVarP(&opts.postFile, "file", "", "", "Read query parameters from the given JSON file and send a POST request, with overrides from arguments") + cmd.Flags().StringVarP(&opts.format, "format", "", "human", "Output format. Must be 'human' (human-readable) or 'plain' (no formatting)") + cmd.Flags().StringSliceVarP(&opts.headers, "header", "", nil, "Add a header to the HTTP request, on the format 'Header: Value'. This can be specified multiple times") + cmd.Flags().IntVarP(&opts.queryTimeoutSecs, "timeout", "T", 10, "Timeout for the query in seconds") + cli.bindWaitFlag(cmd, 0, &opts.waitSecs) return cmd } -func printCurl(stderr io.Writer, url string, service *vespa.Service) error { - cmd, err := curl.RawArgs(url) +func printCurl(stderr io.Writer, req *http.Request, postFile string, service *vespa.Service) error { + cmd, err := curl.RawArgs(req.URL.String()) if err != nil { return err } + cmd.Method = req.Method + if postFile != "" { + cmd.WithBodyFile(postFile) + } + for k, vl := range req.Header { + for _, v := range vl { + cmd.Header(k, v) + } + } cmd.Certificate = service.TLSOptions.CertificateFile cmd.PrivateKey = service.TLSOptions.PrivateKeyFile _, err = io.WriteString(stderr, cmd.String()+"\n") return err } -func query(cli *CLI, arguments []string, timeoutSecs int, curl bool, format string, headers []string, waiter *Waiter) error { +func query(cli *CLI, arguments []string, opts *queryOptions, waiter *Waiter) error { target, err := cli.target(targetOptions{}) if err != nil { return err @@ -77,12 +95,12 @@ func query(cli *CLI, arguments []string, timeoutSecs int, curl bool, format stri if err != nil { return err } - switch format { + switch opts.format { case "plain", "human": default: - return fmt.Errorf("invalid format: %s", format) + return fmt.Errorf("invalid format: %s", opts.format) } - url, _ := url.Parse(service.BaseURL + "/search/") + url, _ := url.Parse(strings.TrimSuffix(service.BaseURL, "/") + "/search/") urlQuery := url.Query() for i := range len(arguments) { key, value := splitArg(arguments[i]) @@ -91,31 +109,44 @@ func query(cli *CLI, arguments []string, timeoutSecs int, curl bool, format stri queryTimeout := urlQuery.Get("timeout") if queryTimeout == "" { // No timeout set by user, use the timeout option - queryTimeout = fmt.Sprintf("%ds", timeoutSecs) + queryTimeout = fmt.Sprintf("%ds", opts.queryTimeoutSecs) urlQuery.Set("timeout", queryTimeout) } - url.RawQuery = urlQuery.Encode() deadline, err := time.ParseDuration(queryTimeout) if err != nil { return fmt.Errorf("invalid query timeout: %w", err) } - if curl { - if err := printCurl(cli.Stderr, url.String(), service); err != nil { - return err - } - } - header, err := httputil.ParseHeader(headers) + header, err := httputil.ParseHeader(opts.headers) if err != nil { return err } - response, err := service.Do(&http.Request{Header: header, URL: url}, deadline+time.Second) // Slightly longer than query timeout + hReq := &http.Request{Header: header, URL: url} + if opts.postFile != "" { + json, err := getJsonFrom(opts.postFile, urlQuery) + if err != nil { + return fmt.Errorf("bad JSON in postFile '%s': %w", opts.postFile, err) + } + header.Set("Content-Type", "application/json") + hReq.Method = "POST" + hReq.Body = io.NopCloser(bytes.NewBuffer(bytes.Clone(json))) + if err != nil { + return fmt.Errorf("bad postFile '%s': %w", opts.postFile, err) + } + } + url.RawQuery = urlQuery.Encode() + if opts.printCurl { + if err := printCurl(cli.Stderr, hReq, opts.postFile, service); err != nil { + return err + } + } + response, err := service.Do(hReq, deadline+time.Second) // Slightly longer than query timeout if err != nil { return fmt.Errorf("request failed: %w", err) } defer response.Body.Close() if response.StatusCode == 200 { - if err := printResponse(response.Body, response.Header.Get("Content-Type"), format, cli); err != nil { + if err := printResponse(response.Body, response.Header.Get("Content-Type"), opts.format, cli); err != nil { return err } } else if response.StatusCode/100 == 4 { @@ -207,3 +238,32 @@ func splitArg(argument string) (string, string) { } return parts[0], parts[1] } + +func getJsonFrom(fn string, query url.Values) ([]byte, error) { + parsed := make(map[string]any) + f, err := os.Open(fn) + if err != nil { + return nil, err + } + body, err := io.ReadAll(f) + if err != nil { + return nil, err + } + err = json.Unmarshal(body, &parsed) + if err != nil { + return nil, err + } + for k, vl := range query { + if len(vl) == 1 { + parsed[k] = vl[0] + } else { + parsed[k] = vl + } + query.Del(k) + } + b, err := json.Marshal(parsed) + if err != nil { + return nil, err + } + return b, nil +} diff --git a/client/go/internal/cli/cmd/query_test.go b/client/go/internal/cli/cmd/query_test.go index f5b113b6acb7..ef78335f726f 100644 --- a/client/go/internal/cli/cmd/query_test.go +++ b/client/go/internal/cli/cmd/query_test.go @@ -6,10 +6,13 @@ package cmd import ( "net/http" + "os" + "path/filepath" "strconv" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/vespa-engine/vespa/client/go/internal/mock" ) @@ -134,6 +137,51 @@ data: { assertStreamingQuery(t, bodyWithError, bodyWithError, "--format=plain") } +func TestQueryPostFile(t *testing.T) { + mockResponse := `{"query":"result"}"` + client := &mock.HTTPClient{ReadBody: true} + client.NextResponseString(200, mockResponse) + cli, stdout, _ := newTestCLI(t) + cli.httpClient = client + + tmpFileName := filepath.Join(t.TempDir(), "tq1.json") + jsonQuery := []byte(`{"yql": "some yql here"}`) + require.Nil(t, os.WriteFile(tmpFileName, jsonQuery, 0644)) + + assert.Nil(t, cli.Run("-t", "http://127.0.0.1:8080", "query", "--file", tmpFileName)) + assert.Equal(t, mockResponse+"\n", stdout.String()) + assert.Equal(t, `{"timeout":"10s","yql":"some yql here"}`, string(client.LastBody)) + assert.Equal(t, []string{"application/json"}, client.LastRequest.Header.Values("Content-Type")) + assert.Equal(t, "POST", client.LastRequest.Method) + assert.Equal(t, "http://127.0.0.1:8080/search/", client.LastRequest.URL.String()) +} + +func TestQueryPostFileWithArgs(t *testing.T) { + mockResponse := `{"query":"result"}"` + client := &mock.HTTPClient{ReadBody: true} + client.NextResponseString(200, mockResponse) + cli, _, _ := newTestCLI(t) + cli.httpClient = client + + tmpFileName := filepath.Join(t.TempDir(), "tq2.json") + jsonQuery := []byte(`{"yql": "some yql here"}`) + require.Nil(t, os.WriteFile(tmpFileName, jsonQuery, 0644)) + + assert.Nil(t, cli.Run( + "-t", "http://foo.bar:1234/", + "query", + "--file", tmpFileName, + "yql=foo bar", + "tracelevel=3", + "dispatch.docsumRetryLimit=42")) + assert.Equal(t, + `{"dispatch.docsumRetryLimit":"42","timeout":"10s","tracelevel":"3","yql":"foo bar"}`, + string(client.LastBody)) + assert.Equal(t, []string{"application/json"}, client.LastRequest.Header.Values("Content-Type")) + assert.Equal(t, "POST", client.LastRequest.Method) + assert.Equal(t, "http://foo.bar:1234/search/", client.LastRequest.URL.String()) +} + func assertStreamingQuery(t *testing.T, expectedOutput, body string, args ...string) { t.Helper() client := &mock.HTTPClient{}