Skip to content

Commit

Permalink
Merge pull request #820 from dolthub/zachmu/dolt-revert
Browse files Browse the repository at this point in the history
Unskip tests for dolt_revert
  • Loading branch information
zachmu authored Oct 8, 2024
2 parents 94536b3 + 4437932 commit bb1ad6f
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 71 deletions.
80 changes: 9 additions & 71 deletions testing/go/enginetest/doltgres_engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,79 +148,15 @@ func TestSingleScript(t *testing.T) {

var scripts = []queries.ScriptTest{
{
Name: "primary key table: basic cases",
Name: "dolt_revert() detects not null violation (issue #4527)",
SetUpScript: []string{
"create table t1 (n int primary key, de varchar(20));",
"call dolt_add('.')",
"insert into t1 values (1, 'Eins'), (2, 'Zwei'), (3, 'Drei');",
"call dolt_commit('-am', 'inserting into t1', '--date', '2022-08-06T12:00:01');",
"SET @Commit1 = (select hashof('HEAD'));",

"alter table t1 add column fr varchar(20);",
"insert into t1 values (4, 'Vier', 'Quatre');",
"call dolt_commit('-am', 'adding column and inserting data in t1', '--date', '2022-08-06T12:00:02');",
"SET @Commit2 = (select hashof('HEAD'));",

"update t1 set fr='Un' where n=1;",
"update t1 set fr='Deux' where n=2;",
"call dolt_commit('-am', 'updating data in t1', '--date', '2022-08-06T12:00:03');",
"SET @Commit3 = (select hashof('HEAD'));",

"update t1 set de=concat(de, ', meine herren') where n>1;",
"call dolt_commit('-am', 'be polite when you address a gentleman', '--date', '2022-08-06T12:00:04');",
"SET @Commit4 = (select hashof('HEAD'));",

"delete from t1 where n=2;",
"call dolt_commit('-am', 'we don''t need the number 2', '--date', '2022-08-06T12:00:05');",
"SET @Commit5 = (select hashof('HEAD'));",
"create table test2 (pk int primary key, c0 int)",
"alter table test2 modify c0 int not null",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "select count(*) from Dolt_History_t1;",
Expected: []sql.Row{{18}},
},
{
Query: "select n, de, fr from dolt_history_T1 where commit_hash = @Commit1;",
Expected: []sql.Row{{1, "Eins", nil}, {2, "Zwei", nil}, {3, "Drei", nil}},
},
{
Query: "select de, fr from dolt_history_T1 where commit_hash = @Commit1;",
Expected: []sql.Row{{"Eins", nil}, {"Zwei", nil}, {"Drei", nil}},
},
{
Query: "select n, de, fr from dolt_history_T1 where commit_hash = @Commit2;",
Expected: []sql.Row{{1, "Eins", nil}, {2, "Zwei", nil}, {3, "Drei", nil}, {4, "Vier", "Quatre"}},
},
{
Query: "select n, de, fr from dolt_history_T1 where commit_hash = @Commit3;",
Expected: []sql.Row{{1, "Eins", "Un"}, {2, "Zwei", "Deux"}, {3, "Drei", nil}, {4, "Vier", "Quatre"}},
},
{
Query: "select n, de, fr from dolt_history_T1 where commit_hash = @Commit4;",
Expected: []sql.Row{
{1, "Eins", "Un"},
{2, "Zwei, meine herren", "Deux"},
{3, "Drei, meine herren", nil},
{4, "Vier, meine herren", "Quatre"},
},
},
{
Query: "select n, de, fr from dolt_history_T1 where commit_hash = @Commit5;",
Expected: []sql.Row{
{1, "Eins", "Un"},
{3, "Drei, meine herren", nil},
{4, "Vier, meine herren", "Quatre"},
},
},
{
Query: "select de, fr, commit_hash=@commit1, commit_hash=@commit2, commit_hash=@commit3, commit_hash=@commit4" +
" from dolt_history_T1 where n=2 order by commit_date",
Expected: []sql.Row{
{"Zwei", nil, true, false, false, false},
{"Zwei", nil, false, true, false, false},
{"Zwei", "Deux", false, false, true, false},
{"Zwei, meine herren", "Deux", false, false, false, true},
},
Query: "call dolt_revert('head~1');",
ExpectedErrStr: "revert currently does not handle constraint violations",
},
},
},
Expand Down Expand Up @@ -1147,8 +1083,10 @@ func TestDoltRebasePrepared(t *testing.T) {
}

func TestDoltRevert(t *testing.T) {
t.Skip()
h := newDoltgresServerHarness(t)
h := newDoltgresServerHarness(t).WithSkippedQueries([]string{
"dolt_revert() respects dolt_ignore", // ERROR: INSERT: non-Doltgres type found in destination: text
"dolt_revert() automatically resolves some conflicts", // panic: interface conversion: sql.Type is types.VarCharType, not types.StringType
})
denginetest.RunDoltRevertTests(t, h)
}

Expand Down
85 changes: 85 additions & 0 deletions testing/go/enginetest/query_converter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,96 @@ func transformAST(query string) ([]string, bool) {
return transformSet(stmt)
case *sqlparser.Select:
return transformSelect(stmt)
case *sqlparser.AlterTable:
return transformAlterTable(stmt)
}

return nil, false
}

func transformAlterTable(stmt *sqlparser.AlterTable) ([]string, bool) {
var outputStmts []string
for _, statement := range stmt.Statements {
converted, ok := convertDdlStatement(statement)
if !ok {
return nil, false
}
outputStmts = append(outputStmts, converted...)
}
return outputStmts, true
}

func convertDdlStatement(statement *sqlparser.DDL) ([]string, bool) {
switch statement.Action {
case "alter":
switch statement.ColumnAction {
case "modify":
if len(statement.TableSpec.Columns) != 1 {
return nil, false
}

stmts := make([]string, 0)

col := statement.TableSpec.Columns[0]
tableName, err := tree.NewUnresolvedObjectName(1, [3]string{statement.Table.Name.String(), "", ""}, 0)
if err != nil {
panic(err)
}

newType := convertTypeDef(col.Type)
alter := tree.AlterTable{
Table: tableName,
Cmds: []tree.AlterTableCmd{
&tree.AlterTableAlterColumnType{
Column: tree.Name(col.Name.String()),
ToType: newType,
},
},
}

ctx := formatNodeWithUnqualifiedTableNames(&alter)
stmts = append(stmts, ctx.String())

// constraints
if col.Type.NotNull {
alter.Cmds = []tree.AlterTableCmd{
&tree.AlterTableSetNotNull{
Column: tree.Name(col.Name.String()),
},
}
ctx = formatNodeWithUnqualifiedTableNames(&alter)
stmts = append(stmts, ctx.String())
} else {
alter.Cmds = []tree.AlterTableCmd{
&tree.AlterTableDropNotNull{
Column: tree.Name(col.Name.String()),
},
}
ctx = formatNodeWithUnqualifiedTableNames(&alter)
stmts = append(stmts, ctx.String())
}

// rename
if statement.Column.String() != col.Name.String() {
alter.Cmds = []tree.AlterTableCmd{
&tree.AlterTableRenameColumn{
Column: tree.Name(statement.Column.String()),
NewName: tree.Name(col.Name.String()),
},
}
ctx = formatNodeWithUnqualifiedTableNames(&alter)
stmts = append(stmts, ctx.String())
}

return stmts, true
default:
return nil, false
}
default:
return nil, false
}
}

// transformSelect converts a MySQL SELECT statement to a postgres-compatible SELECT statement.
// This is a very broad surface area, so we do this very selectively
func transformSelect(stmt *sqlparser.Select) ([]string, bool) {
Expand Down

0 comments on commit bb1ad6f

Please sign in to comment.