Skip to content

Commit

Permalink
Merge pull request #8 from go-goyave/bug/mismatching-data-type-sql-error
Browse files Browse the repository at this point in the history
Add type-safety in operators
  • Loading branch information
System-Glitch authored Apr 12, 2023
2 parents 1402f6e + c104a7b commit 60c7424
Show file tree
Hide file tree
Showing 12 changed files with 2,002 additions and 234 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
go: [1.17, 1.18, 1.19]
go: ["1.17", "1.18", "1.19", "1.20"]
steps:
- uses: actions/checkout@v3
- uses: actions/setup-go@v3
Expand All @@ -23,7 +23,7 @@ jobs:
- name: Run tests
run: |
go test -v -race -coverprofile=coverage.txt -covermode=atomic -coverpkg=./... ./...
- if: ${{ matrix.go == 1.19 }}
- if: ${{ matrix.go == 1.20 }}
uses: shogo82148/actions-goveralls@v1
with:
path-to-profile: coverage.txt
Expand All @@ -36,5 +36,5 @@ jobs:
- name: Run lint
uses: golangci/golangci-lint-action@v3
with:
version: v1.50
version: v1.52
args: --timeout 5m
134 changes: 129 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -208,11 +208,11 @@ Internally, `goyave.dev/filter` uses [Goyave's `Paginator`](https://goyave.dev/g
Sometimes you need to work with a "virtual" column that is not stored in your database, but is computed using an SQL expression. A dynamic status depending on a date for example. In order to support the features of this library properly, you will have to add the expression to your model using the `computed` struct tag:

```go
type MyModel struct{
type MyModel struct {
ID uint
// ...
StartDate time.Time
Status string `gorm:"->;-:migration" computed:"CASE WHEN ~~~ct~~~.start_date < NOW() THEN 'pending' ELSE 'started' END"`
Status string `gorm:"->;-:migration" computed:"CASE WHEN ~~~ct~~~.start_date < NOW() THEN 'pending' ELSE 'started' END"`
}
```

Expand All @@ -232,10 +232,25 @@ type MyModelWithStatus struct{
}
```

When using JSON columns, you can support filters on nested fields inside that JSON column using a computed column:

```go
// This example is compatible with PostgreSQL.
// JSON processing may be different if you are using another database engine.
type MyModel struct {
ID uint
JSONColumn datatypes.JSON
SomeJSONField null.Int `gorm:"->;-:migration" computed:"(~~~ct~~~.json_column->>'fieldName')::int"`
}
```

It is important to make sure your JSON expression returns a value that has a type that matches the struct field to avoid DB errors. Database engines usually only return text types from JSON. If your field is a number, you'll have to cast it or you will get database errors when filtering on this field.

## Security

- Inputs are escaped to prevent SQL injections.
- Fields are pre-processed and clients cannot request fields that don't exist. This prevents database errors. If a non-existing field is required, it is simply ignored. The same goes for sorts and joins. It is not possible to request a relation that doesn't exist.
- Type-safety: in the same field pre-processing, the broad type of the field is checked against the database type (based on the model definition). This prevents database errors if the input cannot be converted to the column's type.
- Foreign keys are always selected in joins to ensure associations can be assigned to parent model.
- **Be careful** with bidirectional relations (for example an article is written by a user, and a user can have many articles). If you enabled both your models to preload these relations, the client can request them with an infinite depth (`Articles.User.Articles.User...`). To prevent this, it is advised to use **the relation blacklist** or **IsFinal** on the deepest requestable models. See the settings section for more details.

Expand All @@ -251,6 +266,33 @@ type MyModelWithStatus struct{
- Don't use `gorm.Model` and add the necessary fields manually. You get better control over json struct tags this way.
- Use pointers for nullable relations and nullable fields that implement `sql.Scanner` (such as `null.Time`).

### Type-safety

For non-native types that don't implement the `driver.Valuer` interface, you should always use the `filterType` struct tag. This struct tag enforces the field's recognized broad type for the type-safety conversion. It is also recommended to always add this tag when working with arrays. This tag is effective for the filter and search features.

Available broad types are:
- `text` / `text[]`
- `enum` / `enum[]`: use this with custom enum types to prevent "invalid input value" or "invalid operator" errors
- `bool` / `bool[]`
- `int8` / `int8[]`, `int16` / `int16[]`, `int32` / `int32[]`, `int64` / `int64[]`
- `uint` / `uint[]`, `uint16` / `uint16[]`, `uint32` / `uint32[]`, `uint64` / `uint64[]`
- `float32` / `float32[]`, `float64` / `float64[]`
- `time` / `time[]`
- `-`: unsupported data type. Fields tagged with `-` will be ignored in filters and search: no condition will be added to the `WHERE` clause.

If not provided, the type will be determined from GORM's data type. If GORM's data type is a custom type that is not directly supported by this library, the type will fall back to `-` (unsupported) and the field will be ignored in the filters.

If the type is supported but the user input cannot be used with the requested column, the built-in operators will generate a `FALSE` condition.

**Example**
```go
type MyModel struct{
ID uint
// ...
StartDate null.Time `filterType:"time"`
}
```

### Static conditions

If you want to add static conditions (not automatically defined by the library), it is advised to group them like so:
Expand Down Expand Up @@ -279,13 +321,95 @@ import (
// ...

filter.Operators["$cont"] = &filter.Operator{
Function: func(tx *gorm.DB, filter *filter.Filter, column string, dataType schema.DataType) *gorm.DB {
Function: func(tx *gorm.DB, f *filter.Filter, column string, dataType filter.DataType) *gorm.DB {
if dataType != filter.DataTypeString {
return tx.Where("FALSE")
}
query := column + " LIKE ?"
value := "%" + sqlutil.EscapeLike(filter.Args[0]) + "%"
return filter.Where(tx, query, value)
value := "%" + sqlutil.EscapeLike(f.Args[0]) + "%"
return f.Where(tx, query, value)
},
RequiredArguments: 1,
}

filter.Operators["$eq"] = &filter.Operator{
Function: func(tx *gorm.DB, f *filter.Filter, column string, dataType filter.DataType) *gorm.DB {
if dataType.IsArray() {
return tx.Where("FALSE")
}
arg, ok := filter.ConvertToSafeType(f.Args[0], dataType)
if !ok {
return tx.Where("FALSE")
}
query := fmt.Sprintf("%s = ?", column, op)
return f.Where(tx, query, arg)
},
RequiredArguments: 1,
}
```

#### Array operators

Some database engines such as PostgreSQL provide operators for array operations (`@>`, `&&`, ...). You may encounter issue implementing these operators in your project because of GORM converting slices into records (`("a", "b")` instead of `{"a", "b"}`).

To fix this issue, you will have to implement your own variant of `ConvertArgsToSafeType` so it returns a **pointer** to a slice with a concrete type instead of `[]interface{}`. By sending a pointer to GORM, it won't try to render the slice itself and pass it directly to the underlying driver, which usually knows how to handle slices for the native types.

**Example** (using generics with go 1.18+):
```go
type argType interface {
string | int64 | uint64 | float64 | bool
}

func init() {
filter.Operators["$arrayin"] = &filter.Operator{
Function: func (tx *gorm.DB, f *filter.Filter, column string, dataType filter.DataType) *gorm.DB {
if !dataType.IsArray() {
return tx.Where("FALSE")
}

if dataType == filter.DataTypeEnumArray {
column = fmt.Sprintf("CAST(%s as TEXT[])", column)
}

query := fmt.Sprintf("%s @> ?", column)
switch dataType {
case filter.DataTypeTextArray, filter.DataTypeEnumArray, filter.DataTypeTimeArray:
return bindArrayArg[string](tx, query, f, dataType)
case filter.DataTypeFloat32Array, filter.DataTypeFloat64Array:
return bindArrayArg[float64](tx, query, f, dataType)
case filter.DataTypeUint8Array, filter.DataTypeUint16Array, filter.DataTypeUint32Array, filter.DataTypeUint64Array:
return bindArrayArg[uint64](tx, query, f, dataType)
case filter.DataTypeInt8Array, filter.DataTypeInt16Array, filter.DataTypeInt32Array, filter.DataTypeInt64Array:
return bindArrayArg[int64](tx, query, f, dataType)
}

// If you need to handle DataTypeBoolArray, use pgtype.BoolArray
return tx.Where("FALSE")
},
RequiredArguments: 1,
}
}

func bindArrayArg[T argType](tx *gorm.DB, query string, f *filter.Filter, dataType filter.DataType) *gorm.DB {
args, ok := convertArgsToSafeTypeArray[T](f.Args, dataType)
if !ok {
return tx.Where("FALSE")
}
return f.Where(tx, query, args)
}

func convertArgsToSafeTypeArray[T argType](args []string, dataType filter.DataType) (*[]T, bool) {
result := make([]T, 0, len(args))
for _, arg := range args {
a, ok := filter.ConvertToSafeType(arg, dataType)
if !ok {
return nil, false
}
result = append(result, a.(T))
}

return &result, true
}
```

### Manual joins
Expand Down
12 changes: 11 additions & 1 deletion filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,12 @@ func (f *Filter) Scope(settings *Settings, sch *schema.Schema) (func(*gorm.DB) *
return nil, nil
}

dataType := getDataType(field)

joinScope := func(tx *gorm.DB) *gorm.DB {
if dataType == DataTypeUnsupported {
return tx
}
if joinName != "" {
if err := tx.Statement.Parse(tx.Statement.Model); err != nil {
tx.AddError(err)
Expand All @@ -39,14 +44,19 @@ func (f *Filter) Scope(settings *Settings, sch *schema.Schema) (func(*gorm.DB) *
computed := field.StructField.Tag.Get("computed")

conditionScope := func(tx *gorm.DB) *gorm.DB {
if dataType == DataTypeUnsupported {
return tx
}

table := tx.Statement.Quote(tableFromJoinName(s.Table, joinName))
var fieldExpr string
if computed != "" {
fieldExpr = fmt.Sprintf("(%s)", strings.ReplaceAll(computed, clause.CurrentTable, table))
} else {
fieldExpr = table + "." + tx.Statement.Quote(field.DBName)
}
return f.Operator.Function(tx, f, fieldExpr, field.DataType)

return f.Operator.Function(tx, f, fieldExpr, dataType)
}

return joinScope, conditionScope
Expand Down
61 changes: 58 additions & 3 deletions filter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func TestFilterScope(t *testing.T) {
schema := &schema.Schema{
DBNames: []string{"name"},
FieldsByDBName: map[string]*schema.Field{
"name": {Name: "Name", DBName: "name"},
"name": {Name: "Name", DBName: "name", GORMDataType: schema.String},
},
Table: "test_scope_models",
}
Expand Down Expand Up @@ -84,7 +84,7 @@ func TestFilterScopeBlacklisted(t *testing.T) {
schema := &schema.Schema{
DBNames: []string{"name"},
FieldsByDBName: map[string]*schema.Field{
"name": {Name: "Name"},
"name": {Name: "Name", GORMDataType: schema.String},
},
}

Expand Down Expand Up @@ -112,6 +112,18 @@ type FilterTestModel struct {
ID uint
}

type FilterTestRelationUnsupported struct {
Name string `filterType:"-"`
ID uint
ParentID uint
}

type FilterTestModelUnsupported struct {
Relation *FilterTestRelationUnsupported `gorm:"foreignKey:ParentID"`
Name string
ID uint
}

func TestFilterScopeWithJoin(t *testing.T) {
db := openDryRunDB(t)
filter := &Filter{Field: "Relation.name", Args: []string{"val1"}, Operator: Operators["$eq"]}
Expand Down Expand Up @@ -346,7 +358,7 @@ func TestFilterScopeWithJoinDontDuplicate(t *testing.T) {
Expression: clause.Where{
Exprs: []clause.Expression{
clause.Expr{SQL: "`Relation`.`name` = ?", Vars: []interface{}{"val1"}},
clause.Expr{SQL: "`Relation`.`id` > ?", Vars: []interface{}{"0"}},
clause.Expr{SQL: "`Relation`.`id` > ?", Vars: []interface{}{uint64(0)}},
},
},
},
Expand Down Expand Up @@ -608,3 +620,46 @@ func TestFilterScopeComputedRelation(t *testing.T) {
}
assert.Equal(t, expected, db.Statement.Clauses)
}

func TestFilterScopeWithUnsupportedDataType(t *testing.T) {
db := openDryRunDB(t)
filter := &Filter{Field: "name", Args: []string{"val1"}, Operator: Operators["$eq"]}
schema := &schema.Schema{
DBNames: []string{"name"},
FieldsByDBName: map[string]*schema.Field{
"name": {Name: "Name", DBName: "name", GORMDataType: "custom", DataType: "CHARACTER VARYING(255)"},
},
Table: "test_scope_models",
}

results := []map[string]interface{}{}
db = db.Scopes(filter.Scope(&Settings{}, schema)).Find(results)
expected := map[string]clause.Clause{}
assert.Equal(t, expected, db.Statement.Clauses)
}

func TestFilterScopeWithJoinedUnsupportedDataType(t *testing.T) {
db := openDryRunDB(t)
filter := &Filter{Field: "Relation.name", Args: []string{"val1"}, Operator: Operators["$eq"]}

results := []*FilterTestModelUnsupported{}
schema, err := parseModel(db, &results)
if !assert.Nil(t, err) {
return
}

db.DryRun = true
db = db.Model(&results).Scopes(filter.Scope(&Settings{}, schema)).Find(&results)
expected := map[string]clause.Clause{
"FROM": {
Name: "FROM",
Expression: clause.From{},
},
"SELECT": {
Name: "SELECT",
Expression: clause.Select{},
},
}
assert.Equal(t, expected, db.Statement.Clauses)
assert.Nil(t, db.Error)
}
4 changes: 2 additions & 2 deletions join.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ func join(tx *gorm.DB, joinName string, sch *schema.Schema) *gorm.DB {
Table: clause.Table{Name: sch.Table, Alias: relation.Name},
ON: clause.Where{Exprs: exprs},
}
if !joinExists(tx.Statement, j) && !findStatementJoin(tx.Statement, relation, &j) {
if !joinExists(tx.Statement, j) && !findStatementJoin(tx.Statement, &j) {
joins = append(joins, j)
}
}
Expand Down Expand Up @@ -201,7 +201,7 @@ func joinExists(stmt *gorm.Statement, join clause.Join) bool {
// Removes this information from the join afterwards to avoid Gorm reprocessing it.
// This is used to avoid duplicate joins that produce ambiguous column names and to
// support computed columns.
func findStatementJoin(stmt *gorm.Statement, relation *schema.Relationship, join *clause.Join) bool {
func findStatementJoin(stmt *gorm.Statement, join *clause.Join) bool {
for _, j := range stmt.Joins {
if j.Name == join.Table.Alias {
return true
Expand Down
Loading

0 comments on commit 60c7424

Please sign in to comment.