Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TestUpdate and TestDelete #963

Merged
merged 8 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 43 additions & 2 deletions server/ast/limit.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@
package ast

import (
vitess "github.com/dolthub/vitess/go/vt/sqlparser"
"fmt"

pgexprs "github.com/dolthub/doltgresql/server/expression"
vitess "github.com/dolthub/vitess/go/vt/sqlparser"

"github.com/dolthub/doltgresql/postgres/parser/sem/tree"
pgexprs "github.com/dolthub/doltgresql/server/expression"
)

// nodeLimit handles *tree.Limit nodes.
Expand All @@ -43,11 +44,31 @@ func nodeLimit(ctx *Context, node *tree.Limit) (*vitess.Limit, error) {
// We need to remove the hard dependency, but for now, we'll just convert our literals to a vitess.SQLVal.
if injectedExpr, ok := count.(vitess.InjectedExpr); ok {
if literal, ok := injectedExpr.Expression.(*pgexprs.Literal); ok {
l := literal.Value()
limitValue, err := int64ValueForLimit(l)
if err != nil {
return nil, err
}

if limitValue < 0 {
return nil, fmt.Errorf("LIMIT must be greater than or equal to 0")
}

count = literal.ToVitessLiteral()
}
}
if injectedExpr, ok := offset.(vitess.InjectedExpr); ok {
if literal, ok := injectedExpr.Expression.(*pgexprs.Literal); ok {
o := literal.Value()
offsetVal, err := int64ValueForLimit(o)
if err != nil {
return nil, err
}

if offsetVal < 0 {
return nil, fmt.Errorf("OFFSET must be greater than or equal to 0")
}

offset = literal.ToVitessLiteral()
}
}
Expand All @@ -56,3 +77,23 @@ func nodeLimit(ctx *Context, node *tree.Limit) (*vitess.Limit, error) {
Rowcount: count,
}, nil
}

// int64ValueForLimit converts a literal value to an int64
func int64ValueForLimit(l any) (int64, error) {
var limitValue int64
switch l := l.(type) {
case int:
limitValue = int64(l)
case int32:
limitValue = int64(l)
case int64:
limitValue = l
case float64:
limitValue = int64(l)
case float32:
limitValue = int64(l)
default:
return 0, fmt.Errorf("unsupported limit/offset value type %T", l)
}
return limitValue, nil
}
2 changes: 1 addition & 1 deletion server/types/date.go
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ func (b DateType) ToArrayType() DoltgresArrayType {

// Type implements the DoltgresType interface.
func (b DateType) Type() query.Type {
return sqltypes.Text
return sqltypes.Date
}

// ValueType implements the DoltgresType interface.
Expand Down
141 changes: 100 additions & 41 deletions testing/go/enginetest/doltgres_engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,35 +171,23 @@ func TestSingleScript(t *testing.T) {

var scripts = []queries.ScriptTest{
{
Name: "Insert throws primary key violations",
Name: "people tables",
SetUpScript: []string{
"CREATE TABLE t (pk int PRIMARY key);",
"CREATE TABLE t2 (pk1 int, pk2 int, PRIMARY KEY (pk1, pk2));",
"CREATE TABLE `people` ( `dob` date NOT NULL," +
" `first_name` varchar(20) NOT NULL," +
" `last_name` varchar(20) NOT NULL," +
" `middle_name` varchar(20) NOT NULL," +
" `height_inches` bigint NOT NULL," +
" `gender` bigint NOT NULL," +
" PRIMARY KEY (`dob`,`first_name`,`last_name`,`middle_name`) )",
`insert into people values ('1970-12-1'::date, 'jon', 'smith', 'a', 72, 0)`,
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "INSERT INTO t VALUES (1), (2);",
Expected: []sql.Row{{types.NewOkResult(2)}},
},
{
Query: "INSERT into t VALUES (1);",
ExpectedErr: sql.ErrPrimaryKeyViolation,
},
{
Query: "SELECT * from t;",
Expected: []sql.Row{{1}, {2}},
},
{
Query: "INSERT into t2 VALUES (1, 1), (2, 2);",
Expected: []sql.Row{{types.NewOkResult(2)}},
},
{
Query: "INSERT into t2 VALUES (1, 1);",
ExpectedErr: sql.ErrPrimaryKeyViolation,
},
{
Query: "SELECT * from t2;",
Expected: []sql.Row{{1, 1}, {2, 2}},
Query: "select * from people order by dob",
Expected: []sql.Row{
{"1970-12-01", "jon", "smith", "a", int64(72), int64(0)},
},
},
},
},
Expand Down Expand Up @@ -422,38 +410,109 @@ func TestReplaceIntoErrors(t *testing.T) {
}

func TestUpdate(t *testing.T) {
t.Skip()
h := newDoltgresServerHarness(t)
h := newDoltgresServerHarness(t).WithSkippedQueries([]string{
"UPDATE mytable SET s = _binary 'updated' WHERE i = 3;", // _binary not supported
"UPDATE mytable SET s = 'updated' ORDER BY i LIMIT 1 OFFSET 1;", // offset not supported (limit isn't selected in vanilla postgres but is in the cockroach grammar)
// TODO: Postgres supports update joins but with a very different syntax, and some join types are not supported
`UPDATE one_pk INNER JOIN two_pk on one_pk.pk = two_pk.pk1 SET two_pk.c1 = two_pk.c1 + 1`,
"UPDATE mytable INNER JOIN one_pk ON mytable.i = one_pk.c5 SET mytable.i = mytable.i * 10",
`UPDATE one_pk INNER JOIN two_pk on one_pk.pk = two_pk.pk1 SET two_pk.c1 = two_pk.c1 + 1 WHERE one_pk.c5 < 10`,
`UPDATE one_pk INNER JOIN two_pk on one_pk.pk = two_pk.pk1 INNER JOIN othertable on othertable.i2 = two_pk.pk2 SET one_pk.c1 = one_pk.c1 + 1`,
`UPDATE one_pk INNER JOIN (SELECT * FROM two_pk order by pk1, pk2) as t2 on one_pk.pk = t2.pk1 SET one_pk.c1 = t2.c1 + 1 where one_pk.pk < 1`,
`UPDATE one_pk INNER JOIN two_pk on one_pk.pk = two_pk.pk1 SET one_pk.c1 = one_pk.c1 + 1`,
`UPDATE one_pk INNER JOIN two_pk on one_pk.pk = two_pk.pk1 SET one_pk.c1 = one_pk.c1 + 1, one_pk.c2 = one_pk.c2 + 1 ORDER BY one_pk.pk`,
`UPDATE one_pk INNER JOIN two_pk on one_pk.pk = two_pk.pk1 SET one_pk.c1 = one_pk.c1 + 1, two_pk.c1 = two_pk.c2 + 1`,
`update mytable h join mytable on h.i = mytable.i and h.s <> mytable.s set h.i = mytable.i+1;`,
`UPDATE othertable CROSS JOIN tabletest set othertable.i2 = othertable.i2 * 10`, // cross join
`UPDATE tabletest cross join tabletest as t2 set tabletest.i = tabletest.i * 10`, // cross join
`UPDATE othertable cross join tabletest set tabletest.i = tabletest.i * 10`, // cross join
`UPDATE one_pk INNER JOIN two_pk on one_pk.pk = two_pk.pk1 INNER JOIN two_pk a1 on one_pk.pk = two_pk.pk2 SET two_pk.c1 = two_pk.c1 + 1`, // cross join
`UPDATE othertable INNER JOIN tabletest on othertable.i2=3 and tabletest.i=3 SET othertable.s2 = 'fourth'`, // cross join
`UPDATE tabletest cross join tabletest as t2 set t2.i = t2.i * 10`, // cross join
`UPDATE othertable LEFT JOIN tabletest on othertable.i2=3 and tabletest.i=3 SET othertable.s2 = 'fourth'`, // left join
`UPDATE othertable LEFT JOIN tabletest on othertable.i2=3 and tabletest.i=3 SET tabletest.s = 'fourth row', tabletest.i = tabletest.i + 1`, // left join
`UPDATE othertable LEFT JOIN tabletest t3 on othertable.i2=3 and t3.i=3 SET t3.s = 'fourth row', t3.i = t3.i + 1`, // left join
`UPDATE othertable LEFT JOIN tabletest on othertable.i2=3 and tabletest.i=3 LEFT JOIN one_pk on othertable.i2 = one_pk.pk SET one_pk.c1 = one_pk.c1 + 1`, // left join
`UPDATE othertable LEFT JOIN tabletest on othertable.i2=3 and tabletest.i=3 LEFT JOIN one_pk on othertable.i2 = one_pk.pk SET one_pk.c1 = one_pk.c1 + 1 where one_pk.pk > 4`, // left join
`UPDATE othertable LEFT JOIN tabletest on othertable.i2=3 and tabletest.i=3 LEFT JOIN one_pk on othertable.i2 = 1 and one_pk.pk = 1 SET one_pk.c1 = one_pk.c1 + 1`, // left join
`UPDATE othertable RIGHT JOIN tabletest on othertable.i2=3 and tabletest.i=3 SET othertable.s2 = 'fourth'`, // right join
`UPDATE othertable RIGHT JOIN tabletest on othertable.i2=3 and tabletest.i=3 SET othertable.i2 = othertable.i2 + 1`, // right join
`UPDATE othertable LEFT JOIN tabletest on othertable.i2=tabletest.i RIGHT JOIN one_pk on othertable.i2 = 1 and one_pk.pk = 1 SET tabletest.s = 'updated';`, // right join
`UPDATE IGNORE one_pk INNER JOIN two_pk on one_pk.pk = two_pk.pk1 SET two_pk.c1 = two_pk.c1 + 1`,
`UPDATE IGNORE one_pk JOIN one_pk one_pk2 on one_pk.pk = one_pk2.pk SET one_pk.pk = 10`,
`with t (n) as (select (1) from dual) UPDATE mytable set s = concat('updated ', i) where i in (select n from t)`, // with not supported
`with recursive t (n) as (select (1) from dual union all select n + 1 from t where n < 2) UPDATE mytable set s = concat('updated ', i) where i in (select n from t)`,
})
defer h.Close()
enginetest.TestUpdate(t, h)
}

func TestUpdateIgnore(t *testing.T) {
t.Skip()
h := newDoltgresServerHarness(t)
defer h.Close()
enginetest.TestUpdateIgnore(t, h)
}

func TestUpdateErrors(t *testing.T) {
t.Skip()
h := newDoltgresServerHarness(t)
h := newDoltgresServerHarness(t).WithSkippedQueries([]string{
`UPDATE keyless INNER JOIN one_pk on keyless.c0 = one_pk.pk SET keyless.c0 = keyless.c0 + 1`,
"try updating string that is too long", // works but error message doesn't match
"UPDATE mytable SET s = 'hi' LIMIT -1;", // unsupported syntax
})
defer h.Close()
enginetest.TestUpdateErrors(t, h)
}

func TestDeleteFrom(t *testing.T) {
t.Skip()
h := newDoltgresServerHarness(t)
h := newDoltgresServerHarness(t).WithSkippedQueries([]string{
"DELETE FROM mytable ORDER BY i DESC LIMIT 1 OFFSET 1;", // offset is unsupported syntax
"DELETE FROM mytable WHERE (i,s) = (1, 'first row');", // type error, needs investigation
"with t (n) as (select (1) from dual) delete from mytable where i in (select n from t)",
"with recursive t (n) as (select (1) from dual union all select n + 1 from t where n < 2) delete from mytable where i in (select n from t)",
})
defer h.Close()
enginetest.TestDelete(t, h)

// We've inlined part of engineTest.TestDeleteFrom here because that method tests many queries for join deletions
// that would be tedious to write out as skips
h.Setup(setup.MydbData, setup.MytableData, setup.TabletestData)
t.Run("Delete from single table", func(t *testing.T) {
for _, tt := range queries.DeleteTests {
enginetest.RunWriteQueryTest(t, h, tt)
}
})
}

func TestDeleteFromErrors(t *testing.T) {
t.Skip()
h := newDoltgresServerHarness(t)
defer h.Close()
enginetest.TestDeleteErrors(t, h)

// These tests are overspecified to mysql-specific errors and include some syntax we don't support, so we redefine
// the subset we're interested in checking here
h.Setup(setup.MydbData, setup.MytableData, setup.TabletestData)
deleteScripts := []queries.ScriptTest{
{
Name: "DELETE FROM error cases",
Assertions: []queries.ScriptTestAssertion{
{
Query: "DELETE FROM invalidtable WHERE x < 1;",
ExpectedErrStr: "table not found: invalidtable",
},
{
Query: "DELETE FROM mytable WHERE z = 'dne';",
ExpectedErrStr: "column \"z\" could not be found in any table in scope",
},
{
Query: "DELETE FROM mytable LIMIT -1;",
ExpectedErrStr: "LIMIT must be greater than or equal to 0",
},
{
Query: "DELETE mytable WHERE i = 1;",
ExpectedErrStr: "syntax error",
},
{
Query: "DELETE FROM (SELECT * FROM mytable) mytable WHERE i = 1;",
ExpectedErrStr: "syntax error",
},
},
},
}
for _, tt := range deleteScripts {
enginetest.TestScript(t, h, tt)
}
}

func TestSpatialDelete(t *testing.T) {
Expand Down
8 changes: 6 additions & 2 deletions testing/go/enginetest/doltgres_harness_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -543,8 +543,12 @@ func (d *DoltgresHarness) EvaluateExpectedError(t *testing.T, expected string, e
// EvaluateExpectedErrorKind is a harness extension that gives us more control over matching expected errors. We don't
// have access to the error kind object eny longer, so we have to see if the error we get matches its pattern
func (d *DoltgresHarness) EvaluateExpectedErrorKind(t *testing.T, expected *errors.Kind, actualErr error) {
pattern := strings.ReplaceAll(expected.Message, "%s", "\\w+")
pattern = strings.ReplaceAll(pattern, "%q", "\"\\w+\"")
pattern := strings.ReplaceAll(expected.Message, "*", "\\*")
pattern = strings.ReplaceAll(pattern, "(", "\\(")
pattern = strings.ReplaceAll(pattern, ")", "\\)")
pattern = strings.ReplaceAll(pattern, "%s", ".+")
pattern = strings.ReplaceAll(pattern, "%q", "\".+\"")
pattern = strings.ReplaceAll(pattern, "%v", ".+?")
regex, regexErr := regexp.Compile(pattern)
require.NoError(t, regexErr)

Expand Down
Loading
Loading