Skip to content

Commit

Permalink
Add support for the source command
Browse files Browse the repository at this point in the history
  • Loading branch information
dveeden committed Jul 26, 2024
1 parent aa83826 commit 4a2376e
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 33 deletions.
1 change: 1 addition & 0 deletions include/hello.inc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
--echo Hello from the included file
3 changes: 3 additions & 0 deletions r/source.result
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
first line
Hello from the included file
last line
131 changes: 98 additions & 33 deletions src/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,15 @@ const (
type query struct {
firstWord string
Query string
File string
Line int
tp int
}

func (q *query) location() string {
return fmt.Sprintf("%s:%d", q.File, q.Line)
}

type Conn struct {
// DB might be a shared one by multiple Conn, if the connection information are the same.
mdb *sql.DB
Expand Down Expand Up @@ -325,7 +330,7 @@ func (t *tester) addSuccess(testSuite *XUnitTestSuite, startTime *time.Time, cnt
func (t *tester) Run() error {
t.preProcess()
defer t.postProcess()
queries, err := t.loadQueries()
queries, err := t.loadQueries(t.testFileName())
if err != nil {
err = errors.Trace(err)
t.addFailure(&testSuite, &err, 0)
Expand All @@ -338,17 +343,33 @@ func (t *tester) Run() error {
return err
}

var s string
defer func() {
if t.resultFD != nil {
t.resultFD.Close()
}
}()

testCnt := 0
startTime := time.Now()
testCnt, err := t.runQueries(queries)
if err != nil {
return err
}

fmt.Printf("%s: ok! %d test cases passed, take time %v s\n", t.testFileName(), testCnt, time.Since(startTime).Seconds())

if xmlPath != "" {
t.addSuccess(&testSuite, &startTime, testCnt)
}

return t.flushResult()
}

func (t *tester) runQueries(queries []query) (int, error) {
testCnt := 0
var concurrentQueue []query
var concurrentSize int
var s string
var err error
for _, q := range queries {
s = q.Query
switch q.tp {
Expand Down Expand Up @@ -379,15 +400,15 @@ func (t *tester) Run() error {
if err != nil {
err = errors.Annotate(err, "Atoi failed")
t.addFailure(&testSuite, &err, testCnt)
return err
return testCnt, err
}
}
case Q_END_CONCURRENT:
t.enableConcurrent = false
if err = t.concurrentRun(concurrentQueue, concurrentSize); err != nil {
err = errors.Annotate(err, fmt.Sprintf("concurrent test failed in %v", t.name))
t.addFailure(&testSuite, &err, testCnt)
return err
return testCnt, err
}
t.expectedErrs = nil
case Q_ERROR:
Expand All @@ -406,7 +427,7 @@ func (t *tester) Run() error {
} else if err = t.execute(q); err != nil {
err = errors.Annotate(err, fmt.Sprintf("sql:%v", q.Query))
t.addFailure(&testSuite, &err, testCnt)
return err
return testCnt, err
}

testCnt++
Expand All @@ -426,7 +447,7 @@ func (t *tester) Run() error {
if err != nil {
err = errors.Annotate(err, fmt.Sprintf("Could not parse column in --replace_column: sql:%v", q.Query))
t.addFailure(&testSuite, &err, testCnt)
return err
return testCnt, err
}

t.replaceColumn = append(t.replaceColumn, ReplaceColumn{col: colNr, replace: []byte(cols[i+1])})
Expand Down Expand Up @@ -473,7 +494,7 @@ func (t *tester) Run() error {
r, err := t.executeStmtString(s)
if err != nil {
log.WithFields(log.Fields{
"query": s, "line": q.Line},
"query": s, "line": q.location()},
).Error("failed to perform let query")
return ""
}
Expand All @@ -484,27 +505,59 @@ func (t *tester) Run() error {
case Q_REMOVE_FILE:
err = os.Remove(strings.TrimSpace(q.Query))
if err != nil {
return errors.Annotate(err, "failed to remove file")
return testCnt, errors.Annotate(err, "failed to remove file")
}
case Q_REPLACE_REGEX:
t.replaceRegex = nil
regex, err := ParseReplaceRegex(q.Query)
if err != nil {
return errors.Annotate(err, fmt.Sprintf("Could not parse regex in --replace_regex: line: %d sql:%v", q.Line, q.Query))
return testCnt, errors.Annotate(
err, fmt.Sprintf("Could not parse regex in --replace_regex: line: %s sql:%v",
q.location(), q.Query))
}
t.replaceRegex = regex
default:
log.WithFields(log.Fields{"command": q.firstWord, "arguments": q.Query, "line": q.Line}).Warn("command not implemented")
}
}
case Q_SOURCE:
fileName := strings.TrimSpace(q.Query)
cwd, err := os.Getwd()
if err != nil {
return testCnt, err
}

fmt.Printf("%s: ok! %d test cases passed, take time %v s\n", t.testFileName(), testCnt, time.Since(startTime).Seconds())
// For security, don't allow to include files from other locations
fullpath, err := filepath.Abs(fileName)
if err != nil {
return testCnt, err
}
if !strings.HasPrefix(fullpath, cwd) {
return testCnt, errors.Errorf("included file %s is not prefixed with %s", fullpath, cwd)
}

if xmlPath != "" {
t.addSuccess(&testSuite, &startTime, testCnt)
}
// Make sure we have a useful error message if the file can't be found or isn't a regular file
s, err := os.Stat(fileName)
if err != nil {
return testCnt, errors.Annotate(err,
fmt.Sprintf("file sourced with --source doesn't exist: line %s, file: %s",
q.location(), fileName))
}
if !s.Mode().IsRegular() {
return testCnt, errors.Errorf("file sourced with --source isn't a regular file: line %s, file: %s",
q.location(), fileName)
}

return t.flushResult()
// Process the queries in the file
includedQueries, err := t.loadQueries(fileName)
if err != nil {
return testCnt, errors.Annotate(err, fmt.Sprintf("error loading queries from %s", fileName))
}
_, err = t.runQueries(includedQueries)
if err != nil {
return testCnt, err
}
default:
log.WithFields(log.Fields{"command": q.firstWord, "arguments": q.Query, "line": q.location()}).Warn("command not implemented")
}
}
return testCnt, nil
}

func (t *tester) concurrentRun(concurrentQueue []query, concurrentSize int) error {
Expand Down Expand Up @@ -606,8 +659,8 @@ func (t *tester) concurrentExecute(querys []query, wg *sync.WaitGroup, errOccure
}
}

func (t *tester) loadQueries() ([]query, error) {
data, err := os.ReadFile(t.testFileName())
func (t *tester) loadQueries(fileName string) ([]query, error) {
data, err := os.ReadFile(fileName)
if err != nil {
return nil, err
}
Expand All @@ -623,18 +676,30 @@ func (t *tester) loadQueries() ([]query, error) {
newStmt = true
continue
} else if strings.HasPrefix(s, "--") {
queries = append(queries, query{Query: s, Line: i + 1})
queries = append(queries, query{
Query: s,
Line: i + 1,
File: fileName,
})
newStmt = true
continue
} else if len(s) == 0 {
continue
}

if newStmt {
queries = append(queries, query{Query: s, Line: i + 1})
queries = append(queries, query{
Query: s,
Line: i + 1,
File: fileName,
})
} else {
lastQuery := queries[len(queries)-1]
lastQuery = query{Query: fmt.Sprintf("%s\n%s", lastQuery.Query, s), Line: lastQuery.Line}
lastQuery = query{
Query: fmt.Sprintf("%s\n%s", lastQuery.Query, s),
Line: lastQuery.Line,
File: fileName,
}
queries[len(queries)-1] = lastQuery
}

Expand Down Expand Up @@ -668,8 +733,8 @@ func (t *tester) checkExpectedError(q query, err error) error {
}
}
if !checkErr {
log.Warnf("%s:%d query succeeded, but expected error(s)! (expected errors: %s) (query: %s)",
t.name, q.Line, strings.Join(t.expectedErrs, ","), q.Query)
log.Warnf("%s query succeeded, but expected error(s)! (expected errors: %s) (query: %s)",
q.location(), strings.Join(t.expectedErrs, ","), q.Query)
return nil
}
return errors.Errorf("Statement succeeded, expected error(s) '%s'", strings.Join(t.expectedErrs, ","))
Expand All @@ -684,7 +749,7 @@ func (t *tester) checkExpectedError(q query, err error) error {
errNo = int(innerErr.Number)
}
if errNo == 0 {
log.Warnf("%s:%d Could not parse mysql error: %s", t.name, q.Line, err.Error())
log.Warnf("%s Could not parse mysql error: %s", q.location(), err.Error())
return err
}
for _, s := range t.expectedErrs {
Expand All @@ -696,9 +761,9 @@ func (t *tester) checkExpectedError(q query, err error) error {
checkErrNo = i
} else {
if len(t.expectedErrs) > 1 {
log.Warnf("%s:%d Unknown named error %s in --error %s", t.name, q.Line, s, strings.Join(t.expectedErrs, ","))
log.Warnf("%s Unknown named error %s in --error %s", q.location(), s, strings.Join(t.expectedErrs, ","))
} else {
log.Warnf("%s:%d Unknown named --error %s", t.name, q.Line, s)
log.Warnf("%s Unknown named --error %s", q.location(), s)
}
continue
}
Expand Down Expand Up @@ -726,11 +791,11 @@ func (t *tester) checkExpectedError(q query, err error) error {
}
}
if len(t.expectedErrs) > 1 {
log.Warnf("%s:%d query failed with non expected error(s)! (%s not in %s) (err: %s) (query: %s)",
t.name, q.Line, gotErrCode, strings.Join(t.expectedErrs, ","), err.Error(), q.Query)
log.Warnf("%s query failed with non expected error(s)! (%s not in %s) (err: %s) (query: %s)",
q.location(), gotErrCode, strings.Join(t.expectedErrs, ","), err.Error(), q.Query)
} else {
log.Warnf("%s:%d query failed with non expected error(s)! (%s != %s) (err: %s) (query: %s)",
t.name, q.Line, gotErrCode, t.expectedErrs[0], err.Error(), q.Query)
log.Warnf("%s query failed with non expected error(s)! (%s != %s) (err: %s) (query: %s)",
q.location(), gotErrCode, t.expectedErrs[0], err.Error(), q.Query)
}
errStr := err.Error()
for _, reg := range t.replaceRegex {
Expand Down
1 change: 1 addition & 0 deletions src/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ func ParseQueries(qs ...query) ([]query, error) {
q := query{}
q.tp = Q_UNKNOWN
q.Line = rs.Line
q.File = rs.File
// a valid query's length should be at least 3.
if len(s) < 3 {
continue
Expand Down
3 changes: 3 additions & 0 deletions t/source.test
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
--echo first line
--source include/hello.inc
--echo last line

0 comments on commit 4a2376e

Please sign in to comment.