Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
Defined2014 committed Dec 24, 2024
1 parent f995dea commit 83dec84
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 21 deletions.
16 changes: 9 additions & 7 deletions src/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ const (
type query struct {
firstWord string
Query string
Delimiter string
Line int
tp int
}
Expand Down Expand Up @@ -626,7 +627,7 @@ func (t *tester) concurrentExecute(querys []query, wg *sync.WaitGroup, errOccure
return
}

err := tt.stmtExecute(query.Query)
err := tt.stmtExecute(query)
if err != nil && len(t.expectedErrs) > 0 {
for _, tStr := range t.expectedErrs {
if strings.Contains(err.Error(), tStr) {
Expand Down Expand Up @@ -668,7 +669,7 @@ func (t *tester) loadQueries() ([]query, error) {
if len(buffer) != 0 {
return nil, errors.Errorf("Has remained message(%s) before COMMANDS", buffer)
}
q, err := ParseQuery(query{Query: s, Line: i + 1}, t.delimiter)
q, err := ParseQuery(query{Query: s, Line: i + 1, Delimiter: t.delimiter})
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -711,7 +712,7 @@ func (t *tester) loadQueries() ([]query, error) {

queryStr := buffer[:idx+len(t.delimiter)]
buffer = buffer[idx+len(t.delimiter):]
q, err := ParseQuery(query{Query: strings.TrimSpace(queryStr), Line: i + 1}, t.delimiter)
q, err := ParseQuery(query{Query: strings.TrimSpace(queryStr), Line: i + 1, Delimiter: t.delimiter})
if err != nil {
return nil, err
}
Expand All @@ -731,12 +732,13 @@ func (t *tester) loadQueries() ([]query, error) {
return queries, nil
}

func (t *tester) stmtExecute(query string) (err error) {
func (t *tester) stmtExecute(query query) (err error) {
if t.enableQueryLog {
t.buf.WriteString(query)
t.buf.WriteString(query.Query)
t.buf.WriteString("\n")
}
return t.executeStmt(query)

return t.executeStmt(strings.TrimSuffix(query.Query, query.Delimiter))
}

// checkExpectedError check if error was expected
Expand Down Expand Up @@ -834,7 +836,7 @@ func (t *tester) execute(query query) error {
}

offset := t.buf.Len()
err := t.stmtExecute(query.Query)
err := t.stmtExecute(query)

err = t.checkExpectedError(query, err)
if err != nil {
Expand Down
7 changes: 3 additions & 4 deletions src/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,12 +128,11 @@ const (

// ParseQuery parses an array of string into an array of query object.
// Note: a query statement may reside in several lines.
func ParseQuery(rs query, delimiter string) (*query, error) {
func ParseQuery(rs query) (*query, error) {
realS := rs.Query
s := rs.Query
q := query{}
q := query{Delimiter: rs.Delimiter, Line: rs.Line}
q.tp = Q_UNKNOWN
q.Line = rs.Line
// a valid query's length should be at least 3.
if len(s) < 3 {
return nil, nil
Expand All @@ -160,7 +159,7 @@ func ParseQuery(rs query, delimiter string) (*query, error) {
if s[i] == '(' || s[i] == ' ' || s[i] == '\n' {
break
}
if i+len(delimiter) <= len(s) && s[i:i+len(delimiter)] == delimiter {
if i+len(rs.Delimiter) <= len(s) && s[i:i+len(rs.Delimiter)] == rs.Delimiter {
break
}
}
Expand Down
21 changes: 11 additions & 10 deletions src/query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,15 @@ func assertEqual(t *testing.T, a interface{}, b interface{}, message string) {
func TestParseQueryies(t *testing.T) {
sql := "select * from t;"

if q, err := ParseQuery(query{Query: sql, Line: 1}, ";"); err == nil {
if q, err := ParseQuery(query{Query: sql, Line: 1, Delimiter: ";"}); err == nil {
assertEqual(t, q.tp, Q_QUERY, fmt.Sprintf("Expected: %d, got: %d", Q_QUERY, q.tp))
assertEqual(t, q.Query, sql, fmt.Sprintf("Expected: %s, got: %s", sql, q.Query))
} else {
t.Fatalf("error is not nil. %v", err)
}

sql = "--sorted_result select * from t;"
if q, err := ParseQuery(query{Query: sql, Line: 1}, ";"); err == nil {
if q, err := ParseQuery(query{Query: sql, Line: 1, Delimiter: ";"}); err == nil {
assertEqual(t, q.tp, Q_SORTED_RESULT, "sorted_result")
assertEqual(t, q.Query, "select * from t;", fmt.Sprintf("Expected: '%s', got '%s'", "select * from t;", q.Query))
} else {
Expand All @@ -52,11 +52,11 @@ func TestParseQueryies(t *testing.T) {

// invalid comment command style
sql = "--abc select * from t;"
_, err := ParseQuery(query{Query: sql, Line: 1}, ";")
_, err := ParseQuery(query{Query: sql, Line: 1, Delimiter: ";"})
assertEqual(t, err, ErrInvalidCommand, fmt.Sprintf("Expected: %v, got %v", ErrInvalidCommand, err))

sql = "--let $foo=`SELECT 1`"
if q, err := ParseQuery(query{Query: sql, Line: 1}, ";"); err == nil {
if q, err := ParseQuery(query{Query: sql, Line: 1, Delimiter: ";"}); err == nil {
assertEqual(t, q.tp, Q_LET, fmt.Sprintf("Expected: %d, got: %d", Q_LET, q.tp))
}
}
Expand All @@ -76,22 +76,22 @@ func TestLoadQueries(t *testing.T) {
{
input: "delimiter |\n do something; select something; |\n delimiter ; \nselect 1;",
queries: []query{
{Query: "do something; select something; |", tp: Q_QUERY},
{Query: "select 1;", tp: Q_QUERY},
{Query: "do something; select something; |", tp: Q_QUERY, Delimiter: "|"},
{Query: "select 1;", tp: Q_QUERY, Delimiter: ";"},
},
},
{
input: "delimiter |\ndrop procedure if exists scopel\ncreate procedure scope(a int, b float)\nbegin\ndeclare b int;\ndeclare c float;\nbegin\ndeclare c int;\nend;\nend |\ndrop procedure scope|\ndelimiter ;\n",
queries: []query{
{Query: "drop procedure if exists scopel\ncreate procedure scope(a int, b float)\nbegin\ndeclare b int;\ndeclare c float;\nbegin\ndeclare c int;\nend;\nend |", tp: Q_QUERY},
{Query: "drop procedure scope|", tp: Q_QUERY},
{Query: "drop procedure if exists scopel\ncreate procedure scope(a int, b float)\nbegin\ndeclare b int;\ndeclare c float;\nbegin\ndeclare c int;\nend;\nend |", tp: Q_QUERY, Delimiter: "|"},
{Query: "drop procedure scope|", tp: Q_QUERY, Delimiter: "|"},
},
},
{
input: "--error 1054\nselect 1;",
queries: []query{
{Query: " 1054", tp: Q_ERROR},
{Query: "select 1;", tp: Q_QUERY},
{Query: " 1054", tp: Q_ERROR, Delimiter: ";"},
{Query: "select 1;", tp: Q_QUERY, Delimiter: ";"},
},
},
}
Expand All @@ -111,6 +111,7 @@ func TestLoadQueries(t *testing.T) {
for i, query := range testCase.queries {
assert.Equal(t, queries[i].Query, query.Query)
assert.Equal(t, queries[i].tp, query.tp)
assert.Equal(t, queries[i].Delimiter, query.Delimiter)
}
}
}

0 comments on commit 83dec84

Please sign in to comment.