Skip to content

Commit

Permalink
connect: add initial custom evaler support
Browse files Browse the repository at this point in the history
Add support for customizing the way user input is processed.
The `--evaler` option is added to specify the Lua code, which
accepts `cmd` var containing user input for current connection.
The evaler Lua code can be loaded from file.
Autocompletion is disabled if custom evaler is set.
  • Loading branch information
psergee committed Nov 17, 2024
1 parent f87a566 commit 3d702c4
Show file tree
Hide file tree
Showing 11 changed files with 184 additions and 23 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.

### Added

- `tt connect`: add new `--evaler` option to support for customizing
the way user input is processed.

### Changed

### Fixed
Expand Down
5 changes: 5 additions & 0 deletions cli/cmd/connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ var (
connectSslCiphers string
connectInteractive bool
connectBinary bool
connectEvaler string
)

// NewConnectCmd creates connect command.
Expand Down Expand Up @@ -102,6 +103,9 @@ func NewConnectCmd() *cobra.Command {
false, `enter interactive mode after executing 'FILE'`)
connectCmd.Flags().BoolVarP(&connectBinary, "binary", "",
false, `connect to instance using binary port`)
connectCmd.Flags().StringVar(&connectEvaler, "evaler", "",
`use the provided Lua expression as an interpreter for user's input of the connection`)
connectCmd.Flags().MarkHidden("evaler")

return connectCmd
}
Expand Down Expand Up @@ -197,6 +201,7 @@ func internalConnectModule(cmdCtx *cmdcontext.CmdCtx, args []string) error {
SslCiphers: connectSslCiphers,
Interactive: connectInteractive,
Binary: connectBinary,
Evaler: connectEvaler,
}

var ok bool
Expand Down
8 changes: 0 additions & 8 deletions cli/codegen/generate_code.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,6 @@ var luaCodeFiles = []generateLuaCodeOpts{
"checkSyntax": "cli/running/lua/check.lua",
},
},
{
PackageName: "connect",
FileName: "cli/connect/lua_code_gen.go",
VariablesMap: map[string]string{
"evalFuncBody": "cli/connect/lua/eval_func_body.lua",
"getSuggestionsFuncBody": "cli/connect/lua/get_suggestions_func_body.lua",
},
},
{
PackageName: "connector",
FileName: "cli/connector/lua_code_gen.go",
Expand Down
9 changes: 8 additions & 1 deletion cli/connect/connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"path"
"syscall"

"github.com/tarantool/tt/cli/connect/internal/luabody"
"github.com/tarantool/tt/cli/connector"
"github.com/tarantool/tt/cli/formatter"
"golang.org/x/crypto/ssh/terminal"
Expand Down Expand Up @@ -40,6 +41,8 @@ type ConnectCtx struct {
ConnectTarget string
// Binary port is used
Binary bool
// Evaler lua expression.
Evaler string
}

const (
Expand Down Expand Up @@ -114,7 +117,11 @@ func Eval(connectCtx ConnectCtx, connOpts connector.ConnectOpts, args []string)
}

// Execution of the command.
response, err := conn.Eval(evalFuncBody, evalArgs, connector.RequestOpts{})
evalBody, err := luabody.GetEvalFuncBody(connectCtx.Evaler)
if err != nil {
return nil, err
}
response, err := conn.Eval(evalBody, evalArgs, connector.RequestOpts{})
if err != nil {
return nil, err
}
Expand Down
30 changes: 23 additions & 7 deletions cli/connect/console.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"gopkg.in/yaml.v2"

"github.com/tarantool/go-prompt"
"github.com/tarantool/tt/cli/connect/internal/luabody"
"github.com/tarantool/tt/cli/connector"
"github.com/tarantool/tt/cli/formatter"
)
Expand Down Expand Up @@ -121,10 +122,13 @@ func NewConsole(connOpts connector.ConnectOpts, connectCtx ConnectCtx, title str
}

// Initialize user commands executor.
console.executor = getExecutor(console)
console.executor, err = getExecutor(console, connectCtx)
if err != nil {
return nil, fmt.Errorf("failed to init prompt: %s", err)
}

// Initialize commands completer.
console.completer = getCompleter(console)
console.completer = getCompleter(console, connectCtx)

// Initialize syntax checkers.
luaValidator := NewLuaValidator()
Expand Down Expand Up @@ -182,8 +186,14 @@ func (console *Console) Close() {
}

// getExecutor returns command executor.
func getExecutor(console *Console) func(string) {
func getExecutor(console *Console, connectCtx ConnectCtx) (func(string), error) {
commandsExecutor := newCmdExecutor()

evalBody, err := luabody.GetEvalFuncBody(connectCtx.Evaler)
if err != nil {
return nil, err
}

executor := func(in string) {
if console.input == "" {
if commandsExecutor.Execute(console, in) {
Expand Down Expand Up @@ -237,7 +247,7 @@ func getExecutor(console *Console) func(string) {
}

var data string
if _, err := console.conn.Eval(evalFuncBody, args, opts); err != nil {
if _, err := console.conn.Eval(evalBody, args, opts); err != nil {
if err == io.EOF {
// We need to call 'console.Close()' here because in some cases (e.g 'os.exit()')
// it won't be called from 'defer console.Close' in 'connect.runConsole()'.
Expand Down Expand Up @@ -266,10 +276,16 @@ func getExecutor(console *Console) func(string) {
console.livePrefixEnabled = false
}

return executor
return executor, nil
}

func getCompleter(console *Console) prompt.Completer {
func getCompleter(console *Console, connectCtx ConnectCtx) prompt.Completer {
if len(connectCtx.Evaler) != 0 {
return func(prompt.Document) []prompt.Suggest {
return nil
}
}

completer := func(in prompt.Document) []prompt.Suggest {
if len(in.Text) == 0 {
return nil
Expand All @@ -295,7 +311,7 @@ func getCompleter(console *Console) prompt.Completer {
ResData: &suggestionsTexts,
}

if _, err := console.conn.Eval(getSuggestionsFuncBody, args, opts); err != nil {
if _, err := console.conn.Eval(luabody.GetSuggestionsFuncBody(), args, opts); err != nil {
return nil
}

Expand Down
39 changes: 39 additions & 0 deletions cli/connect/internal/luabody/eval.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package luabody

import (
_ "embed"
"fmt"
"os"
"strings"

"github.com/tarantool/cartridge-cli/cli/templates"
)

//go:embed eval_func_body.lua
var evalFuncBody string

//go:embed get_suggestions_func_body.lua
var getSuggestionsFuncBody string

// GetEvalFuncBody returns lua code of eval func.
func GetEvalFuncBody(evaler string) (string, error) {
mapping := map[string]string{}
if len(evaler) != 0 {
if strings.HasPrefix(evaler, "@") {
evalerFileBytes, err := os.ReadFile(strings.TrimPrefix(evaler, "@"))
if err != nil {
return "", fmt.Errorf("failed to read the evaler file: %s", err)
}
mapping["evaler"] = string(evalerFileBytes)
} else {
mapping["evaler"] = evaler
}
}

return templates.GetTemplatedStr(&evalFuncBody, mapping)
}

// GetEvalFuncBody returns lua code for completions.
func GetSuggestionsFuncBody() string {
return getSuggestionsFuncBody
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,19 @@ if is_command(cmd) or is_sql_language == true then
return require('console').eval(cmd)
end

{{ if .evaler }}
local function fun()
{{ .evaler }}
end
{{ else }}
local fun, errmsg = loadstring("return "..cmd)
if not fun then
fun, errmsg = loadstring(cmd)
end
if not fun then
return yaml.encode({box.NULL})
end
{{ end }}

local function table_pack(...)
return {n = select('#', ...), ...}
Expand All @@ -44,7 +50,7 @@ if not ret[1] then
if err == nil then
err = box.NULL
end
return yaml.encode({{error = err}})
return yaml.encode({ {error = err} })
end
if ret.n == 1 then
return "---\n...\n"
Expand Down
7 changes: 6 additions & 1 deletion cli/connect/language.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (

"gopkg.in/yaml.v2"

"github.com/tarantool/tt/cli/connect/internal/luabody"
"github.com/tarantool/tt/cli/connector"
)

Expand Down Expand Up @@ -60,7 +61,11 @@ func ChangeLanguage(evaler connector.Evaler, lang Language) error {
}

languageCmd := setLanguagePrefix + " " + lang.String()
response, err := evaler.Eval(evalFuncBody,
evalBody, err := luabody.GetEvalFuncBody("")
if err != nil {
return err
}
response, err := evaler.Eval(evalBody,
[]interface{}{languageCmd},
connector.RequestOpts{},
)
Expand Down
13 changes: 8 additions & 5 deletions cli/connect/language_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/tarantool/cartridge-cli/cli/templates"
. "github.com/tarantool/tt/cli/connect"
"github.com/tarantool/tt/cli/connector"
)
Expand Down Expand Up @@ -83,11 +85,12 @@ func (evaler *inputEvaler) Eval(fun string,
}

func TestChangeLanguage_requestInputs(t *testing.T) {
rawFun, err := os.ReadFile("./lua/eval_func_body.lua")
if err != nil {
t.Fatal("Failed to read lua file:", err)
}
expectedFun := string(rawFun)
rawFun, err := os.ReadFile("./internal/luabody/eval_func_body.lua")
require.NoError(t, err, "Failed to read lua file")
rawFunStr := string(rawFun)
expectedFun, err := templates.GetTemplatedStr(&rawFunStr, map[string]string{})
require.NoError(t, err, "Failed to render eval func body template")

expectedOpts := connector.RequestOpts{}
cases := []struct {
lang Language
Expand Down
85 changes: 85 additions & 0 deletions test/integration/connect/test_connect.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import itertools
import os
import platform
import re
Expand Down Expand Up @@ -3088,3 +3089,87 @@ def test_set_delimiter(
finally:
# Stop the Instance.
stop_app(tt_cmd, tmpdir, "test_app")


def test_custom_evaler(tt_cmd, tmpdir_with_cfg):
skip_if_tuple_format_supported(tt_cmd, tmpdir_with_cfg)

tmpdir = tmpdir_with_cfg
test_app_path = os.path.join(os.path.dirname(__file__), "test_output_format_app",
"test_app.lua")
copy_data(tmpdir, [test_app_path])
start_app(tt_cmd, tmpdir, "test_app")

expected_output = '''+---------+
| name |
+---------+
| William |
+---------+
'''
try:
file = wait_file(os.path.join(tmpdir, "test_app"), "ready", [])
assert file != ""

with open(os.path.join(tmpdir, "evaler.lua"), 'w') as f:
f.write("return box.execute(cmd)")

uris = ["localhost:3013", "tcp://localhost:3013"]
evalers = ["return box.execute(cmd)", os.path.join("@" + tmpdir, "evaler.lua")]
for uri, evaler in itertools.product(uris, evalers):
ret, output = try_execute_on_instance(
tt_cmd,
tmpdir,
uri,
stdin="SELECT name FROM customers WHERE id = 4",
opts={"--evaler": evaler,
"-x": "table"},
)

assert ret
assert output == expected_output

finally:
stop_app(tt_cmd, tmpdir, "test_app")


def test_custom_evaler_errors(tt_cmd, tmpdir_with_cfg):
tmpdir = tmpdir_with_cfg
test_app_path = os.path.join(os.path.dirname(__file__), "test_output_format_app",
"test_app.lua")
copy_data(tmpdir, [test_app_path])
start_app(tt_cmd, tmpdir, "test_app")

try:
file = wait_file(os.path.join(tmpdir, "test_app"), "ready", [])
assert file != ""

with open(os.path.join(tmpdir, "evaler.lua"), 'w') as f:
f.write("return box.execute(cmd)")

uris = ["localhost:3013", "tcp://localhost:3013"]
for uri in uris:
ret, output = try_execute_on_instance(
tt_cmd,
tmpdir,
uri,
stdin="SELECT name FROM customers WHERE id = 4",
opts={"--evaler": "return box.execute(cmd"},
)

assert not ret
assert "Failed to execute command" in output

ret, output = try_execute_on_instance(
tt_cmd,
tmpdir,
"localhost:3013",
stdin="SELECT name FROM customers WHERE id = 4",
opts={"--evaler": "@missing_file"},
)

assert not ret
assert "failed to read the evaler file: open missing_file: " \
"no such file or directory" in output

finally:
stop_app(tt_cmd, tmpdir, "test_app")

0 comments on commit 3d702c4

Please sign in to comment.