diff --git a/json.go b/json.go index e6d021e..7947678 100644 --- a/json.go +++ b/json.go @@ -457,7 +457,7 @@ type JSONArrayExpression struct { equalsValue interface{} } -// Contains checks if the column[keys] contains the value given. The keys parameter is only supported for MySQL. +// Contains checks if column[keys] contains the value given. The keys parameter is only supported for MySQL and SQLite. func (json *JSONArrayExpression) Contains(value interface{}, keys ...string) *JSONArrayExpression { json.contains = true json.equalsValue = value @@ -465,7 +465,7 @@ func (json *JSONArrayExpression) Contains(value interface{}, keys ...string) *JS return json } -// In checks if columns[keys] is in the array value given. This method is only supported for MySQL. +// In checks if columns[keys] is in the array value given. This method is only supported for MySQL and SQLite. func (json *JSONArrayExpression) In(value interface{}, keys ...string) *JSONArrayExpression { json.in = true json.keys = keys @@ -521,6 +521,34 @@ func (json *JSONArrayExpression) Build(builder clause.Builder) { builder.AddVar(stmt, jsonQueryJoin(json.keys)) } builder.WriteString(") > 0") + case json.in: + builder.WriteString("CASE WHEN json_type(") + builder.WriteQuoted(json.column) + if len(json.keys) > 0 { + builder.WriteByte(',') + builder.AddVar(stmt, jsonQueryJoin(json.keys)) + } + builder.WriteString(") = 'array' THEN NOT EXISTS(SELECT 1 FROM json_each(") + builder.WriteQuoted(json.column) + if len(json.keys) > 0 { + builder.WriteByte(',') + builder.AddVar(stmt, jsonQueryJoin(json.keys)) + } + builder.WriteString(") WHERE value NOT IN ") + builder.AddVar(stmt, json.equalsValue) + builder.WriteString(") ELSE ") + if len(json.keys) > 0 { + builder.WriteString("json_extract(") + } + builder.WriteQuoted(json.column) + if len(json.keys) > 0 { + builder.WriteByte(',') + builder.AddVar(stmt, jsonQueryJoin(json.keys)) + builder.WriteByte(')') + } + builder.WriteString(" IN ") + builder.AddVar(stmt, json.equalsValue) + builder.WriteString(" END") } case "postgres": switch { diff --git a/json_test.go b/json_test.go index 598dc65..a5badb1 100644 --- a/json_test.go +++ b/json_test.go @@ -514,16 +514,14 @@ func TestJSONArrayQuery(t *testing.T) { } AssertEqual(t, len(retMultiple), 1) - if SupportedDriver("mysql") { - if err := DB.Where(datatypes.JSONArrayQuery("config").In([]string{"c", "a"})).Find(&retMultiple).Error; err != nil { - t.Fatalf("failed to find params with json value, got error %v", err) - } - AssertEqual(t, len(retMultiple), 1) + if err := DB.Where(datatypes.JSONArrayQuery("config").In([]string{"c", "a"})).Find(&retMultiple).Error; err != nil { + t.Fatalf("failed to find params with json value, got error %v", err) + } + AssertEqual(t, len(retMultiple), 1) - if err := DB.Where(datatypes.JSONArrayQuery("config").In([]string{"c", "d"}, "test")).Find(&retMultiple).Error; err != nil { - t.Fatalf("failed to find params with json value and keys, got error %v", err) - } - AssertEqual(t, len(retMultiple), 1) + if err := DB.Where(datatypes.JSONArrayQuery("config").In([]string{"c", "d"}, "test")).Find(&retMultiple).Error; err != nil { + t.Fatalf("failed to find params with json value and keys, got error %v", err) } + AssertEqual(t, len(retMultiple), 1) } }