diff --git a/src/main.go b/src/main.go index b49d0fd..4c4b002 100644 --- a/src/main.go +++ b/src/main.go @@ -19,6 +19,8 @@ import ( "database/sql" "flag" "fmt" + "io" + "net/http" "os" "path/filepath" "regexp" @@ -36,6 +38,7 @@ import ( var ( host string port string + statusPort string user string passwd string logLevel string @@ -52,6 +55,7 @@ var ( func init() { flag.StringVar(&host, "host", "127.0.0.1", "The host of the TiDB/MySQL server.") flag.StringVar(&port, "port", "4000", "The listen port of TiDB/MySQL server.") + flag.StringVar(&statusPort, "status", "10080", "The status port of TiDB server.") flag.StringVar(&user, "user", "root", "The user for connecting to the database.") flag.StringVar(&passwd, "passwd", "", "The password for the user.") flag.StringVar(&logLevel, "log-level", "error", "The log level of mysql-tester: info, warn, error, debug.") @@ -129,6 +133,11 @@ type tester struct { // enable query info, like rowsAffected, lastMessage etc. enableInfo bool + // Compare the query result with failpoint enabled. + // Fail the test if there is a difference. + enableCompareResult bool + compareCfg map[compareCfgTp]*compareCfgVal + // check expected error, use --error before the statement // see http://dev.mysql.com/doc/mysqltest/2.0/en/writing-tests-expecting-errors.html expectedErrs []string @@ -398,6 +407,15 @@ func (t *tester) Run() error { t.enableInfo = true case Q_DISABLE_INFO: t.enableInfo = false + case Q_ENABLE_COMPARE_RESULT: + t.enableCompareResult = true + err := handleCompareResultCommand(t, s) + if err != nil { + t.addFailure(&testSuite, &err, testCnt) + return err + } + case Q_DISABLE_COMPARE_RESULT: + t.enableCompareResult = false case Q_BEGIN_CONCURRENT: // mysql-tester enhancement concurrentQueue = make([]query, 0) @@ -704,7 +722,7 @@ func (t *tester) checkExpectedError(q query, err error) error { } return errors.Errorf("Statement succeeded, expected error(s) '%s'", strings.Join(t.expectedErrs, ",")) } - if err != nil && len(t.expectedErrs) == 0 { + if len(t.expectedErrs) == 0 { return err } // Parse the error to get the mysql error code @@ -787,10 +805,6 @@ func (t *tester) execute(query query) error { // clear expected errors after we execute the first query t.expectedErrs = nil - if err != nil { - return errors.Trace(errors.Errorf("run \"%v\" at line %d err %v", query.Query, query.Line, err)) - } - if !record { // check test result now gotBuf := t.buf.Bytes()[offset:] @@ -861,6 +875,40 @@ type byteRow struct { data [][]byte } +func (b byteRow) Equal(other byteRow) bool { + if b.data == nil && other.data == nil { + return true + } + if b.data == nil || other.data == nil { + return false + } + if len(b.data) != len(other.data) { + return false + } + for i := 0; i < len(b.data); i++ { + if len(b.data[i]) != len(other.data[i]) { + return false + } + for j := 0; j < len(b.data[i]); j++ { + if b.data[i][j] != other.data[i][j] { + return false + } + } + } + return true +} + +func (b byteRow) String() string { + var sb strings.Builder + for i, d := range b.data { + if i > 0 { + sb.WriteString("\n") + } + sb.WriteString(string(d)) + } + return sb.String() +} + type byteRows struct { cols []string data []byteRow @@ -899,6 +947,46 @@ func (rows *byteRows) Swap(i, j int) { rows.data[i], rows.data[j] = rows.data[j], rows.data[i] } +func (rows *byteRows) Equal(other *byteRows) bool { + if rows == nil && other == nil { + return true + } + if rows == nil || other == nil { + return false + } + if len(rows.cols) != len(other.cols) || len(rows.data) != len(other.data) { + return false + } + for i := 0; i < len(rows.cols); i++ { + if rows.cols[i] != other.cols[i] { + return false + } + } + for i := 0; i < len(rows.data); i++ { + if !rows.data[i].Equal(other.data[i]) { + return false + } + } + return true +} + +func (rows *byteRows) String() string { + var sb strings.Builder + for i, r := range rows.cols { + if i > 0 { + sb.WriteString("\n") + } + sb.WriteString(r) + } + for i, r := range rows.data { + if i > 0 { + sb.WriteString("\n") + } + sb.WriteString(r.String()) + } + return sb.String() +} + func dumpToByteRows(rows *sql.Rows) (*byteRows, error) { cols, err := rows.Columns() if err != nil { @@ -934,24 +1022,70 @@ func dumpToByteRows(rows *sql.Rows) (*byteRows, error) { func (t *tester) executeStmt(query string) error { log.Debugf("executeStmt: %s", query) - raw, err := t.curr.conn.QueryContext(context.Background(), query) - if err != nil { - return errors.Trace(err) - } - - rows, err := dumpToByteRows(raw) - if err != nil { - return errors.Trace(err) + var rows *byteRows + if t.enableCompareResult && isQueryReadOnly(query) { + var firstRows *byteRows + var firstTag string + for tp, val := range t.compareCfg { + var cleanup func() error + switch tp { + case compareCfgTpFailpoint: + err := changeFailpoint(val.failpointPath, val.failpointValue, true) + if err != nil { + return errors.Trace(err) + } + cleanup = func() error { + return changeFailpoint(val.failpointPath, "", false) + } + case compareCfgTpThis: + // do nothing. + default: + return errors.Errorf("unknown compare option %d", tp) + } + raw, err := t.curr.conn.QueryContext(context.Background(), query) + if err != nil { + return errors.Trace(err) + } + rows, err = dumpToByteRows(raw) + if err != nil { + return errors.Trace(err) + } + if firstRows == nil { + firstRows = rows + firstTag = convertCompareCfgToString(tp, val) + } else { + if !firstRows.Equal(rows) { + curTag := convertCompareCfgToString(tp, val) + return errors.Trace(errors.Errorf("compare results \"%v\" failed, [\"%s\"] has result %s, but [\"%s\"] has result %s", + query, firstTag, firstRows.String(), curTag, rows.String())) + } + } + if cleanup != nil { + err := cleanup() + if err != nil { + return errors.Trace(err) + } + } + } + } else { + raw, err := t.curr.conn.QueryContext(context.Background(), query) + if err != nil { + return errors.Trace(err) + } + rows, err = dumpToByteRows(raw) + if err != nil { + return errors.Trace(err) + } } if t.enableResultLog && (len(rows.cols) > 0 || len(rows.data) > 0) { - if err = t.writeQueryResult(rows); err != nil { + if err := t.writeQueryResult(rows); err != nil { return errors.Trace(err) } } if t.enableInfo { - err = t.curr.conn.Raw(func(driverConn any) error { + err := t.curr.conn.Raw(func(driverConn any) error { rowsAffected := driverConn.(*mysql.MysqlConn).RowsAffected() lastMessage := driverConn.(*mysql.MysqlConn).LastMessage() t.buf.WriteString(fmt.Sprintf("affected rows: %d\n", rowsAffected)) @@ -982,6 +1116,45 @@ func (t *tester) executeStmt(query string) error { return nil } +func changeFailpoint(path, val string, enable bool) (err error) { + url := fmt.Sprintf("http://%s:%s/fail/%s", host, statusPort, path) + var req *http.Request + var op string + if enable { + body := bytes.NewBuffer([]byte(val)) + req, err = http.NewRequest(http.MethodPut, url, body) + if err != nil { + return err + } + op = "enable" + } else { + req, err = http.NewRequest(http.MethodDelete, url, nil) + if err != nil { + return err + } + op = "disable" + } + + res, err := http.DefaultClient.Do(req) + if err != nil { + return err + } + if res.StatusCode != http.StatusOK && res.StatusCode != http.StatusNoContent { + resBody, err := io.ReadAll(res.Body) + if err != nil { + return err + } + return errors.Errorf("cannot %s failpoint: %s, status code: %d, body: %s", op, url, res.StatusCode, string(resBody)) + } + return res.Body.Close() +} + +func isQueryReadOnly(query string) bool { + trimmed := strings.TrimSpace(query) + lowered := strings.ToLower(trimmed) + return strings.HasPrefix(lowered, "select") +} + func (t *tester) executeStmtString(query string) (string, error) { var result string err := t.mdb.QueryRow(query).Scan(&result) @@ -1266,3 +1439,72 @@ func main() { println("Great, All tests passed") } } + +type compareCfgTp int8 + +const ( + compareCfgTpThis compareCfgTp = iota + compareCfgTpFailpoint + compareCfgTpUnknown +) + +type compareCfgVal struct { + failpointPath string + failpointValue string +} + +func convertCompareCfgToString(tp compareCfgTp, v *compareCfgVal) string { + var sb strings.Builder + switch tp { + case compareCfgTpThis: + sb.WriteString("this") + case compareCfgTpFailpoint: + sb.WriteString("failpoint") + case compareCfgTpUnknown: + sb.WriteString("unknown") + } + sb.WriteString(":") + if v != nil { + sb.WriteString(v.failpointPath) + sb.WriteString("=") + sb.WriteString(v.failpointValue) + } + return sb.String() +} + +func handleCompareResultCommand(t *tester, query string) error { + t.compareCfg = make(map[compareCfgTp]*compareCfgVal) + for _, field := range strings.Fields(query) { + tp, v := parseCompareResultItem(field) + if tp == compareCfgTpUnknown { + return errors.Errorf("Unknown compare result type in %s", query) + } + t.compareCfg[tp] = v + } + if len(t.compareCfg) < 2 { + return errors.Errorf("Compare result needs at least two arguments: %s", query) + } + return nil +} + +// Parse the compare result item, which is in the format of `type:key=value`. +func parseCompareResultItem(item string) (compareCfgTp, *compareCfgVal) { + if item == "this" { + return compareCfgTpThis, nil + } + s := strings.Split(item, ":") + if len(s) != 2 { + return compareCfgTpUnknown, nil + } + tp := s[0] + if tp == "failpoint" { + s = strings.Split(s[1], "=") + if len(s) != 2 { + return compareCfgTpUnknown, nil + } + key := s[0] + value := s[1] + return compareCfgTpFailpoint, &compareCfgVal{failpointPath: key, failpointValue: value} + } + return compareCfgTpUnknown, nil +} diff --git a/src/query.go b/src/query.go index 6a128d8..c74a5f3 100644 --- a/src/query.go +++ b/src/query.go @@ -124,6 +124,8 @@ const ( Q_COMMENT /* Comments, ignored. */ Q_COMMENT_WITH_COMMAND Q_EMPTY_LINE + Q_ENABLE_COMPARE_RESULT + Q_DISABLE_COMPARE_RESULT ) // ParseQueries parses an array of string into an array of query object. diff --git a/src/type.go b/src/type.go index 50ea5a6..31dbdbd 100644 --- a/src/type.go +++ b/src/type.go @@ -114,6 +114,8 @@ var commandMap = map[string]int{ "single_query": Q_SINGLE_QUERY, "begin_concurrent": Q_BEGIN_CONCURRENT, "end_concurrent": Q_END_CONCURRENT, + "enable_compare_result": Q_ENABLE_COMPARE_RESULT, + "disable_compare_result": Q_DISABLE_COMPARE_RESULT, } func findType(cmdName string) int {