Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support enable_compare_result and failpoint #130

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
272 changes: 257 additions & 15 deletions src/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ import (
"database/sql"
"flag"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"regexp"
Expand All @@ -36,6 +38,7 @@ import (
var (
host string
port string
statusPort string
user string
passwd string
logLevel string
Expand All @@ -52,6 +55,7 @@ var (
func init() {
flag.StringVar(&host, "host", "127.0.0.1", "The host of the TiDB/MySQL server.")
flag.StringVar(&port, "port", "4000", "The listen port of TiDB/MySQL server.")
flag.StringVar(&statusPort, "status", "10080", "The status port of TiDB server.")
flag.StringVar(&user, "user", "root", "The user for connecting to the database.")
flag.StringVar(&passwd, "passwd", "", "The password for the user.")
flag.StringVar(&logLevel, "log-level", "error", "The log level of mysql-tester: info, warn, error, debug.")
Expand Down Expand Up @@ -129,6 +133,11 @@ type tester struct {
// enable query info, like rowsAffected, lastMessage etc.
enableInfo bool

// Compare the query result with failpoint enabled.
// Fail the test if there is a difference.
enableCompareResult bool
compareCfg map[compareCfgTp]*compareCfgVal

// check expected error, use --error before the statement
// see http://dev.mysql.com/doc/mysqltest/2.0/en/writing-tests-expecting-errors.html
expectedErrs []string
Expand Down Expand Up @@ -398,6 +407,15 @@ func (t *tester) Run() error {
t.enableInfo = true
case Q_DISABLE_INFO:
t.enableInfo = false
case Q_ENABLE_COMPARE_RESULT:
t.enableCompareResult = true
err := handleCompareResultCommand(t, s)
if err != nil {
t.addFailure(&testSuite, &err, testCnt)
return err
}
case Q_DISABLE_COMPARE_RESULT:
t.enableCompareResult = false
case Q_BEGIN_CONCURRENT:
// mysql-tester enhancement
concurrentQueue = make([]query, 0)
Expand Down Expand Up @@ -704,7 +722,7 @@ func (t *tester) checkExpectedError(q query, err error) error {
}
return errors.Errorf("Statement succeeded, expected error(s) '%s'", strings.Join(t.expectedErrs, ","))
}
if err != nil && len(t.expectedErrs) == 0 {
if len(t.expectedErrs) == 0 {
return err
}
// Parse the error to get the mysql error code
Expand Down Expand Up @@ -787,10 +805,6 @@ func (t *tester) execute(query query) error {
// clear expected errors after we execute the first query
t.expectedErrs = nil

if err != nil {
return errors.Trace(errors.Errorf("run \"%v\" at line %d err %v", query.Query, query.Line, err))
}

if !record {
// check test result now
gotBuf := t.buf.Bytes()[offset:]
Expand Down Expand Up @@ -861,6 +875,40 @@ type byteRow struct {
data [][]byte
}

func (b byteRow) Equal(other byteRow) bool {
if b.data == nil && other.data == nil {
return true
}
if b.data == nil || other.data == nil {
return false
}
if len(b.data) != len(other.data) {
return false
}
for i := 0; i < len(b.data); i++ {
if len(b.data[i]) != len(other.data[i]) {
return false
}
for j := 0; j < len(b.data[i]); j++ {
if b.data[i][j] != other.data[i][j] {
return false
}
}
}
return true
}

func (b byteRow) String() string {
var sb strings.Builder
for i, d := range b.data {
if i > 0 {
sb.WriteString("\n")
}
sb.WriteString(string(d))
}
return sb.String()
}

type byteRows struct {
cols []string
data []byteRow
Expand Down Expand Up @@ -899,6 +947,46 @@ func (rows *byteRows) Swap(i, j int) {
rows.data[i], rows.data[j] = rows.data[j], rows.data[i]
}

func (rows *byteRows) Equal(other *byteRows) bool {
if rows == nil && other == nil {
return true
}
if rows == nil || other == nil {
return false
}
if len(rows.cols) != len(other.cols) || len(rows.data) != len(other.data) {
return false
}
for i := 0; i < len(rows.cols); i++ {
if rows.cols[i] != other.cols[i] {
return false
}
}
for i := 0; i < len(rows.data); i++ {
if !rows.data[i].Equal(other.data[i]) {
return false
}
}
return true
}

func (rows *byteRows) String() string {
var sb strings.Builder
for i, r := range rows.cols {
if i > 0 {
sb.WriteString("\n")
}
sb.WriteString(r)
}
for i, r := range rows.data {
if i > 0 {
sb.WriteString("\n")
}
sb.WriteString(r.String())
}
return sb.String()
}

func dumpToByteRows(rows *sql.Rows) (*byteRows, error) {
cols, err := rows.Columns()
if err != nil {
Expand Down Expand Up @@ -934,24 +1022,70 @@ func dumpToByteRows(rows *sql.Rows) (*byteRows, error) {

func (t *tester) executeStmt(query string) error {
log.Debugf("executeStmt: %s", query)
raw, err := t.curr.conn.QueryContext(context.Background(), query)
if err != nil {
return errors.Trace(err)
}

rows, err := dumpToByteRows(raw)
if err != nil {
return errors.Trace(err)
var rows *byteRows
if t.enableCompareResult && isQueryReadOnly(query) {
var firstRows *byteRows
var firstTag string
for tp, val := range t.compareCfg {
var cleanup func() error
switch tp {
case compareCfgTpFailpoint:
err := changeFailpoint(val.failpointPath, val.failpointValue, true)
if err != nil {
return errors.Trace(err)
}
cleanup = func() error {
return changeFailpoint(val.failpointPath, "", false)
}
case compareCfgTpThis:
// do nothing.
default:
return errors.Errorf("unknown compare option %d", tp)
}
raw, err := t.curr.conn.QueryContext(context.Background(), query)
if err != nil {
return errors.Trace(err)
}
rows, err = dumpToByteRows(raw)
if err != nil {
return errors.Trace(err)
}
if firstRows == nil {
firstRows = rows
firstTag = convertCompareCfgToString(tp, val)
} else {
if !firstRows.Equal(rows) {
curTag := convertCompareCfgToString(tp, val)
return errors.Trace(errors.Errorf("compare results \"%v\" failed, [\"%s\"] has result %s, but [\"%s\"] has result %s",
query, firstTag, firstRows.String(), curTag, rows.String()))
}
}
if cleanup != nil {
err := cleanup()
if err != nil {
return errors.Trace(err)
}
}
}
} else {
raw, err := t.curr.conn.QueryContext(context.Background(), query)
if err != nil {
return errors.Trace(err)
}
rows, err = dumpToByteRows(raw)
if err != nil {
return errors.Trace(err)
}
}

if t.enableResultLog && (len(rows.cols) > 0 || len(rows.data) > 0) {
if err = t.writeQueryResult(rows); err != nil {
if err := t.writeQueryResult(rows); err != nil {
return errors.Trace(err)
}
}

if t.enableInfo {
err = t.curr.conn.Raw(func(driverConn any) error {
err := t.curr.conn.Raw(func(driverConn any) error {
rowsAffected := driverConn.(*mysql.MysqlConn).RowsAffected()
lastMessage := driverConn.(*mysql.MysqlConn).LastMessage()
t.buf.WriteString(fmt.Sprintf("affected rows: %d\n", rowsAffected))
Expand Down Expand Up @@ -982,6 +1116,45 @@ func (t *tester) executeStmt(query string) error {
return nil
}

func changeFailpoint(path, val string, enable bool) (err error) {
url := fmt.Sprintf("http://%s:%s/fail/%s", host, statusPort, path)
var req *http.Request
var op string
if enable {
body := bytes.NewBuffer([]byte(val))
req, err = http.NewRequest(http.MethodPut, url, body)
if err != nil {
return err
}
op = "enable"
} else {
req, err = http.NewRequest(http.MethodDelete, url, nil)
if err != nil {
return err
}
op = "disable"
}

res, err := http.DefaultClient.Do(req)
if err != nil {
return err
}
if res.StatusCode != http.StatusOK && res.StatusCode != http.StatusNoContent {
resBody, err := io.ReadAll(res.Body)
if err != nil {
return err
}
return errors.Errorf("cannot %s failpoint: %s, status code: %d, body: %s", op, url, res.StatusCode, string(resBody))
}
return res.Body.Close()
}

func isQueryReadOnly(query string) bool {
trimmed := strings.TrimSpace(query)
lowered := strings.ToLower(trimmed)
return strings.HasPrefix(lowered, "select")
}

func (t *tester) executeStmtString(query string) (string, error) {
var result string
err := t.mdb.QueryRow(query).Scan(&result)
Expand Down Expand Up @@ -1266,3 +1439,72 @@ func main() {
println("Great, All tests passed")
}
}

type compareCfgTp int8

const (
compareCfgTpThis compareCfgTp = iota
compareCfgTpFailpoint
compareCfgTpUnknown
)

type compareCfgVal struct {
failpointPath string
failpointValue string
}

func convertCompareCfgToString(tp compareCfgTp, v *compareCfgVal) string {
var sb strings.Builder
switch tp {
case compareCfgTpThis:
sb.WriteString("this")
case compareCfgTpFailpoint:
sb.WriteString("failpoint")
case compareCfgTpUnknown:
sb.WriteString("unknown")
}
sb.WriteString(":")
if v != nil {
sb.WriteString(v.failpointPath)
sb.WriteString("=")
sb.WriteString(v.failpointValue)
}
return sb.String()
}

func handleCompareResultCommand(t *tester, query string) error {
t.compareCfg = make(map[compareCfgTp]*compareCfgVal)
for _, field := range strings.Fields(query) {
tp, v := parseCompareResultItem(field)
if tp == compareCfgTpUnknown {
return errors.Errorf("Unknown compare result type in %s", query)
}
t.compareCfg[tp] = v
}
if len(t.compareCfg) < 2 {
return errors.Errorf("Compare result needs at least two arguments: %s", query)
}
return nil
}

// Parse the compare result item, which is in the format of `type:key=value`.
func parseCompareResultItem(item string) (compareCfgTp, *compareCfgVal) {
if item == "this" {
return compareCfgTpThis, nil
}
s := strings.Split(item, ":")
if len(s) != 2 {
return compareCfgTpUnknown, nil
}
tp := s[0]
if tp == "failpoint" {
s = strings.Split(s[1], "=")
if len(s) != 2 {
return compareCfgTpUnknown, nil
}
key := s[0]
value := s[1]
return compareCfgTpFailpoint, &compareCfgVal{failpointPath: key, failpointValue: value}
}
return compareCfgTpUnknown, nil
}
2 changes: 2 additions & 0 deletions src/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ const (
Q_COMMENT /* Comments, ignored. */
Q_COMMENT_WITH_COMMAND
Q_EMPTY_LINE
Q_ENABLE_COMPARE_RESULT
Q_DISABLE_COMPARE_RESULT
)

// ParseQueries parses an array of string into an array of query object.
Expand Down
2 changes: 2 additions & 0 deletions src/type.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ var commandMap = map[string]int{
"single_query": Q_SINGLE_QUERY,
"begin_concurrent": Q_BEGIN_CONCURRENT,
"end_concurrent": Q_END_CONCURRENT,
"enable_compare_result": Q_ENABLE_COMPARE_RESULT,
"disable_compare_result": Q_DISABLE_COMPARE_RESULT,
}

func findType(cmdName string) int {
Expand Down
Loading