diff --git a/include/hello.inc b/include/hello.inc new file mode 100644 index 0000000..3b2bbc4 --- /dev/null +++ b/include/hello.inc @@ -0,0 +1 @@ +--echo Hello from the included file diff --git a/r/source.result b/r/source.result new file mode 100644 index 0000000..911f76d --- /dev/null +++ b/r/source.result @@ -0,0 +1,3 @@ +first line +Hello from the included file +last line diff --git a/src/main.go b/src/main.go index 9ddaf45..97fcd4b 100644 --- a/src/main.go +++ b/src/main.go @@ -72,10 +72,15 @@ const ( type query struct { firstWord string Query string + File string Line int tp int } +func (q *query) location() string { + return fmt.Sprintf("%s:%d", q.File, q.Line) +} + type Conn struct { // DB might be a shared one by multiple Conn, if the connection information are the same. mdb *sql.DB @@ -325,7 +330,7 @@ func (t *tester) addSuccess(testSuite *XUnitTestSuite, startTime *time.Time, cnt func (t *tester) Run() error { t.preProcess() defer t.postProcess() - queries, err := t.loadQueries() + queries, err := t.loadQueries(t.testFileName()) if err != nil { err = errors.Trace(err) t.addFailure(&testSuite, &err, 0) @@ -338,17 +343,33 @@ func (t *tester) Run() error { return err } - var s string defer func() { if t.resultFD != nil { t.resultFD.Close() } }() - testCnt := 0 startTime := time.Now() + testCnt, err := t.runQueries(queries) + if err != nil { + return err + } + + fmt.Printf("%s: ok! %d test cases passed, take time %v s\n", t.testFileName(), testCnt, time.Since(startTime).Seconds()) + + if xmlPath != "" { + t.addSuccess(&testSuite, &startTime, testCnt) + } + + return t.flushResult() +} + +func (t *tester) runQueries(queries []query) (int, error) { + testCnt := 0 var concurrentQueue []query var concurrentSize int + var s string + var err error for _, q := range queries { s = q.Query switch q.tp { @@ -379,7 +400,7 @@ func (t *tester) Run() error { if err != nil { err = errors.Annotate(err, "Atoi failed") t.addFailure(&testSuite, &err, testCnt) - return err + return testCnt, err } } case Q_END_CONCURRENT: @@ -387,7 +408,7 @@ func (t *tester) Run() error { if err = t.concurrentRun(concurrentQueue, concurrentSize); err != nil { err = errors.Annotate(err, fmt.Sprintf("concurrent test failed in %v", t.name)) t.addFailure(&testSuite, &err, testCnt) - return err + return testCnt, err } t.expectedErrs = nil case Q_ERROR: @@ -406,7 +427,7 @@ func (t *tester) Run() error { } else if err = t.execute(q); err != nil { err = errors.Annotate(err, fmt.Sprintf("sql:%v", q.Query)) t.addFailure(&testSuite, &err, testCnt) - return err + return testCnt, err } testCnt++ @@ -426,7 +447,7 @@ func (t *tester) Run() error { if err != nil { err = errors.Annotate(err, fmt.Sprintf("Could not parse column in --replace_column: sql:%v", q.Query)) t.addFailure(&testSuite, &err, testCnt) - return err + return testCnt, err } t.replaceColumn = append(t.replaceColumn, ReplaceColumn{col: colNr, replace: []byte(cols[i+1])}) @@ -473,7 +494,7 @@ func (t *tester) Run() error { r, err := t.executeStmtString(s) if err != nil { log.WithFields(log.Fields{ - "query": s, "line": q.Line}, + "query": s, "line": q.location()}, ).Error("failed to perform let query") return "" } @@ -484,27 +505,59 @@ func (t *tester) Run() error { case Q_REMOVE_FILE: err = os.Remove(strings.TrimSpace(q.Query)) if err != nil { - return errors.Annotate(err, "failed to remove file") + return testCnt, errors.Annotate(err, "failed to remove file") } case Q_REPLACE_REGEX: t.replaceRegex = nil regex, err := ParseReplaceRegex(q.Query) if err != nil { - return errors.Annotate(err, fmt.Sprintf("Could not parse regex in --replace_regex: line: %d sql:%v", q.Line, q.Query)) + return testCnt, errors.Annotate( + err, fmt.Sprintf("Could not parse regex in --replace_regex: line: %s sql:%v", + q.location(), q.Query)) } t.replaceRegex = regex - default: - log.WithFields(log.Fields{"command": q.firstWord, "arguments": q.Query, "line": q.Line}).Warn("command not implemented") - } - } + case Q_SOURCE: + fileName := strings.TrimSpace(q.Query) + cwd, err := os.Getwd() + if err != nil { + return testCnt, err + } - fmt.Printf("%s: ok! %d test cases passed, take time %v s\n", t.testFileName(), testCnt, time.Since(startTime).Seconds()) + // For security, don't allow to include files from other locations + fullpath, err := filepath.Abs(fileName) + if err != nil { + return testCnt, err + } + if !strings.HasPrefix(fullpath, cwd) { + return testCnt, errors.Errorf("included file %s is not prefixed with %s", fullpath, cwd) + } - if xmlPath != "" { - t.addSuccess(&testSuite, &startTime, testCnt) - } + // Make sure we have a useful error message if the file can't be found or isn't a regular file + s, err := os.Stat(fileName) + if err != nil { + return testCnt, errors.Annotate(err, + fmt.Sprintf("file sourced with --source doesn't exist: line %s, file: %s", + q.location(), fileName)) + } + if !s.Mode().IsRegular() { + return testCnt, errors.Errorf("file sourced with --source isn't a regular file: line %s, file: %s", + q.location(), fileName) + } - return t.flushResult() + // Process the queries in the file + includedQueries, err := t.loadQueries(fileName) + if err != nil { + return testCnt, errors.Annotate(err, fmt.Sprintf("error loading queries from %s", fileName)) + } + _, err = t.runQueries(includedQueries) + if err != nil { + return testCnt, err + } + default: + log.WithFields(log.Fields{"command": q.firstWord, "arguments": q.Query, "line": q.location()}).Warn("command not implemented") + } + } + return testCnt, nil } func (t *tester) concurrentRun(concurrentQueue []query, concurrentSize int) error { @@ -606,8 +659,8 @@ func (t *tester) concurrentExecute(querys []query, wg *sync.WaitGroup, errOccure } } -func (t *tester) loadQueries() ([]query, error) { - data, err := os.ReadFile(t.testFileName()) +func (t *tester) loadQueries(fileName string) ([]query, error) { + data, err := os.ReadFile(fileName) if err != nil { return nil, err } @@ -623,7 +676,11 @@ func (t *tester) loadQueries() ([]query, error) { newStmt = true continue } else if strings.HasPrefix(s, "--") { - queries = append(queries, query{Query: s, Line: i + 1}) + queries = append(queries, query{ + Query: s, + Line: i + 1, + File: fileName, + }) newStmt = true continue } else if len(s) == 0 { @@ -631,10 +688,18 @@ func (t *tester) loadQueries() ([]query, error) { } if newStmt { - queries = append(queries, query{Query: s, Line: i + 1}) + queries = append(queries, query{ + Query: s, + Line: i + 1, + File: fileName, + }) } else { lastQuery := queries[len(queries)-1] - lastQuery = query{Query: fmt.Sprintf("%s\n%s", lastQuery.Query, s), Line: lastQuery.Line} + lastQuery = query{ + Query: fmt.Sprintf("%s\n%s", lastQuery.Query, s), + Line: lastQuery.Line, + File: fileName, + } queries[len(queries)-1] = lastQuery } @@ -668,8 +733,8 @@ func (t *tester) checkExpectedError(q query, err error) error { } } if !checkErr { - log.Warnf("%s:%d query succeeded, but expected error(s)! (expected errors: %s) (query: %s)", - t.name, q.Line, strings.Join(t.expectedErrs, ","), q.Query) + log.Warnf("%s query succeeded, but expected error(s)! (expected errors: %s) (query: %s)", + q.location(), strings.Join(t.expectedErrs, ","), q.Query) return nil } return errors.Errorf("Statement succeeded, expected error(s) '%s'", strings.Join(t.expectedErrs, ",")) @@ -684,7 +749,7 @@ func (t *tester) checkExpectedError(q query, err error) error { errNo = int(innerErr.Number) } if errNo == 0 { - log.Warnf("%s:%d Could not parse mysql error: %s", t.name, q.Line, err.Error()) + log.Warnf("%s Could not parse mysql error: %s", q.location(), err.Error()) return err } for _, s := range t.expectedErrs { @@ -696,9 +761,9 @@ func (t *tester) checkExpectedError(q query, err error) error { checkErrNo = i } else { if len(t.expectedErrs) > 1 { - log.Warnf("%s:%d Unknown named error %s in --error %s", t.name, q.Line, s, strings.Join(t.expectedErrs, ",")) + log.Warnf("%s Unknown named error %s in --error %s", q.location(), s, strings.Join(t.expectedErrs, ",")) } else { - log.Warnf("%s:%d Unknown named --error %s", t.name, q.Line, s) + log.Warnf("%s Unknown named --error %s", q.location(), s) } continue } @@ -726,11 +791,11 @@ func (t *tester) checkExpectedError(q query, err error) error { } } if len(t.expectedErrs) > 1 { - log.Warnf("%s:%d query failed with non expected error(s)! (%s not in %s) (err: %s) (query: %s)", - t.name, q.Line, gotErrCode, strings.Join(t.expectedErrs, ","), err.Error(), q.Query) + log.Warnf("%s query failed with non expected error(s)! (%s not in %s) (err: %s) (query: %s)", + q.location(), gotErrCode, strings.Join(t.expectedErrs, ","), err.Error(), q.Query) } else { - log.Warnf("%s:%d query failed with non expected error(s)! (%s != %s) (err: %s) (query: %s)", - t.name, q.Line, gotErrCode, t.expectedErrs[0], err.Error(), q.Query) + log.Warnf("%s query failed with non expected error(s)! (%s != %s) (err: %s) (query: %s)", + q.location(), gotErrCode, t.expectedErrs[0], err.Error(), q.Query) } errStr := err.Error() for _, reg := range t.replaceRegex { diff --git a/src/query.go b/src/query.go index 6a128d8..f9bb41b 100644 --- a/src/query.go +++ b/src/query.go @@ -136,6 +136,7 @@ func ParseQueries(qs ...query) ([]query, error) { q := query{} q.tp = Q_UNKNOWN q.Line = rs.Line + q.File = rs.File // a valid query's length should be at least 3. if len(s) < 3 { continue diff --git a/t/source.test b/t/source.test new file mode 100644 index 0000000..a204f48 --- /dev/null +++ b/t/source.test @@ -0,0 +1,3 @@ +--echo first line +--source include/hello.inc +--echo last line