Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CASSGO-36 could map scan get nil instead of zero value for null value #1834

Open
wants to merge 3 commits into
base: trunk
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 185 additions & 0 deletions cassandra_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import (
"context"
"errors"
"fmt"
"github.com/stretchr/testify/require"
"io"
"math"
"math/big"
Expand Down Expand Up @@ -943,6 +944,190 @@ func TestMapScan(t *testing.T) {
assertEqual(t, "address", "10.0.0.1", row["address"])
}

type expec struct {
Id int
Col_ascii interface{}
Col_bigint interface{}
Col_blob interface{}
Col_boolean interface{}
Col_date interface{}
Col_decimal interface{}
Col_double interface{}
Col_duration interface{}
Col_float interface{}
Col_inet interface{}
Col_int interface{}
Col_smallint interface{}
Col_text interface{}
Col_time interface{}
Col_timestamp interface{}
Col_timeuuid interface{}
Col_tinyint interface{}
Col_uuid interface{}
Col_varchar interface{}
Col_varint interface{}
}

func TestMapScanWithNullbleValue(t *testing.T) {
timeUUID := TimeUUID()
date := time.Date(2009, time.November, 10, 0, 0, 0, 0, time.UTC)
timestamp := time.Time{}.Add(time.Duration(200))

testCases := []struct {
name string
query string
keys []string
values []interface{}
expectations expec
id int64
}{
{
name: "with values",
query: `INSERT INTO gocql_test.scan_map_with_nullable_value_table
(Id, Col_ascii, Col_bigint, Col_blob, Col_boolean, Col_date, Col_decimal, Col_double,
Col_duration, Col_float, Col_inet, Col_int, Col_smallint, Col_text, Col_time, Col_timestamp,
Col_timeuuid, Col_tinyint, Col_uuid, Col_varchar, Col_varint)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
keys: []string{"Id", "Col_ascii", "Col_bigint", "Col_blob", "Col_boolean", "Col_date", "Col_decimal", "Col_double", "Col_duration", "Col_float", "Col_inet", "Col_int", "Col_smallint", "Col_text", "Col_time", "Col_timestamp", "Col_timeuuid", "Col_tinyint", "Col_uuid", "Col_varchar", "Col_varint"},
values: []interface{}{1, "test_ascii", int64(123456789), []byte{0x01, 0x02, 0x03}, true,
date, *inf.NewDec(12345, 0), 123.45, Duration{
Months: 250,
Days: 500,
Nanoseconds: 300010001,
}, float32(3.14), "127.0.0.1",
123, int16(1000), "test_text", time.Duration(200), timestamp, timeUUID,
int8(5), timeUUID, "test_varchar", *big.NewInt(99999)},
expectations: expec{
Id: 1,
Col_ascii: "test_ascii",
Col_bigint: int64(123456789),
Col_blob: []byte{0x01, 0x02, 0x03},
Col_boolean: true,
Col_date: date,
Col_decimal: *inf.NewDec(12345, 0),
Col_double: 123.45,
Col_duration: Duration{
Months: 250,
Days: 500,
Nanoseconds: 300010001,
},
Col_float: float32(3.14),
Col_inet: "127.0.0.1",
Col_int: 123,
Col_smallint: int16(1000),
Col_text: "test_text",
Col_time: time.Duration(200),
Col_timestamp: timestamp,
Col_timeuuid: timeUUID,
Col_tinyint: int8(5),
Col_uuid: timeUUID,
Col_varchar: "test_varchar",
Col_varint: *big.NewInt(99999),
},
id: 1,
},

{
name: "without values",
query: `INSERT INTO gocql_test.scan_map_with_nullable_value_table (Id) VALUES (?)`,
keys: []string{"Id", "Col_ascii", "Col_bigint", "Col_blob", "Col_boolean", "Col_date", "Col_decimal", "Col_double", "Col_duration", "Col_float", "Col_inet", "Col_int", "Col_smallint", "Col_text", "Col_time", "Col_timestamp", "Col_timeuuid", "Col_tinyint", "Col_uuid", "Col_varchar", "Col_varint"},
values: []interface{}{2},
expectations: expec{
Id: 2,
Col_ascii: nil,
Col_bigint: nil,
Col_blob: nil,
Col_boolean: nil,
Col_date: nil,
Col_decimal: nil,
Col_double: nil,
Col_duration: nil,
Col_float: nil,
Col_inet: nil,
Col_int: nil,
Col_smallint: nil,
Col_text: nil,
Col_time: nil,
Col_timestamp: nil,
Col_timeuuid: nil,
Col_tinyint: nil,
Col_uuid: nil,
Col_varchar: nil,
Col_varint: nil,
},
id: 2,
},
}
session := createSession(t)
defer session.Close()

createTableQuery := `
CREATE TABLE IF NOT EXISTS gocql_test.scan_map_with_nullable_value_table (
Id INT PRIMARY KEY,
Col_ascii ASCII,
Col_bigint BIGINT,
Col_blob BLOB,
Col_boolean BOOLEAN,
Col_date DATE,
Col_decimal DECIMAL,
Col_double DOUBLE,
Col_duration DURATION,
Col_float FLOAT,
Col_inet INET,
Col_int INT,
Col_smallint SMALLINT,
Col_text TEXT,
Col_time TIME,
Col_timestamp TIMESTAMP,
Col_timeuuid TIMEUUID,
Col_tinyint TINYINT,
Col_uuid UUID,
Col_varchar VARCHAR,
Col_varint VARINT
);
`

err := session.Query(createTableQuery).Exec()
if err != nil {
t.Fatal("Failed to create the table:", err)
}

t.Log("Table created successfully!")

for _, testCase := range testCases {
testCase := testCase
t.Run(testCase.name, func(t *testing.T) {
err = session.Query(testCase.query, testCase.values...).Exec()
if err != nil {
t.Fatal("Failed to execute query:", err)
}

iter := session.Query(`SELECT * FROM gocql_test.scan_map_with_nullable_value_table WHERE Id = ? LIMIT 1`, testCase.id).Iter()
row := make(map[string]interface{})

if !iter.MapScanWithNullableValues(row) {
t.Fatal("select:", iter.Close())
}

v := reflect.ValueOf(testCase.expectations)
for _, key := range testCase.keys {
if testCase.id == 1 {
col := row[strings.ToLower(key)]
if !reflect.ValueOf(col).Elem().IsZero() {
got := reflect.ValueOf(col).Elem().Interface()

require.Equal(t, v.FieldByName(key).Interface(), got, key)
}
} else {
if key != "Id" && !reflect.ValueOf(row[strings.ToLower(key)]).IsZero() {
t.Fatalf("Failed on:%v,\nExpected %v to be %v,\n Got: %v", key, key, v.FieldByName(key).Interface(), row[strings.ToLower(key)])
}
}
}
})
}
}

func TestSliceMap(t *testing.T) {
session := createSession(t)
defer session.Close()
Expand Down
8 changes: 8 additions & 0 deletions common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,14 @@ func assertEqual(t *testing.T, description string, expected, actual interface{})

func assertDeepEqual(t *testing.T, description string, expected, actual interface{}) {
t.Helper()
rv1 := reflect.ValueOf(expected)
rv2 := reflect.ValueOf(actual)
if rv1.Kind() == reflect.Ptr {
expected = rv1.Elem().Interface()
}
if rv2.Kind() == reflect.Ptr {
actual = rv2.Elem().Interface()
}
if !reflect.DeepEqual(expected, actual) {
t.Fatalf("expected %s to be (%+v) but was (%+v) instead", description, expected, actual)
}
Expand Down
82 changes: 80 additions & 2 deletions helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,66 @@ func goType(t TypeInfo) (reflect.Type, error) {
}
}

func nullableGoType(t TypeInfo) (reflect.Type, error) {
switch t.Type() {
case TypeVarchar, TypeAscii, TypeInet, TypeText:
return reflect.TypeOf(new(string)), nil
case TypeBigInt, TypeCounter:
return reflect.TypeOf(new(int64)), nil
case TypeTime:
return reflect.TypeOf(new(time.Duration)), nil
case TypeTimestamp:
return reflect.TypeOf(new(time.Time)), nil
case TypeBlob:
return reflect.TypeOf(new([]byte)), nil
case TypeBoolean:
return reflect.TypeOf(new(bool)), nil
case TypeFloat:
return reflect.TypeOf(new(float32)), nil
case TypeDouble:
return reflect.TypeOf(new(float64)), nil
case TypeInt:
return reflect.TypeOf(new(int)), nil
case TypeSmallInt:
return reflect.TypeOf(new(int16)), nil
case TypeTinyInt:
return reflect.TypeOf(new(int8)), nil
case TypeDecimal:
return reflect.TypeOf(*new(*inf.Dec)), nil
case TypeUUID, TypeTimeUUID:
return reflect.TypeOf(new(UUID)), nil
case TypeList, TypeSet:
elemType, err := nullableGoType(t.(CollectionType).Elem)
if err != nil {
return nil, err
}
return reflect.SliceOf(elemType), nil
case TypeMap:
keyType, err := nullableGoType(t.(CollectionType).Key)
if err != nil {
return nil, err
}
valueType, err := nullableGoType(t.(CollectionType).Elem)
if err != nil {
return nil, err
}
return reflect.MapOf(keyType, valueType), nil
case TypeVarint:
return reflect.TypeOf(*new(*big.Int)), nil
case TypeTuple:
tuple := t.(TupleTypeInfo)
return reflect.TypeOf(make([]interface{}, len(tuple.Elems))), nil
case TypeUDT:
return reflect.TypeOf(make(map[string]interface{})), nil
case TypeDate:
return reflect.TypeOf(new(time.Time)), nil
case TypeDuration:
return reflect.TypeOf(new(Duration)), nil
default:
return nil, fmt.Errorf("cannot create Go type for unknown CQL type %s", t)
}
}

func dereference(i interface{}) interface{} {
return reflect.Indirect(reflect.ValueOf(i)).Interface()
}
Expand Down Expand Up @@ -323,6 +383,8 @@ func TupleColumnName(c string, n int) string {
}

func (iter *Iter) RowData() (RowData, error) {
var err error
var val interface{}
if iter.err != nil {
return RowData{}, iter.err
}
Expand All @@ -332,7 +394,12 @@ func (iter *Iter) RowData() (RowData, error) {

for _, column := range iter.Columns() {
if c, ok := column.TypeInfo.(TupleTypeInfo); !ok {
val, err := column.TypeInfo.NewWithError()
if !iter.isNullableScan {
val, err = column.TypeInfo.NewWithError()
} else {
val, err = column.TypeInfo.NewWithNullable()
}

if err != nil {
return RowData{}, err
}
Expand All @@ -342,10 +409,11 @@ func (iter *Iter) RowData() (RowData, error) {
for i, elem := range c.Elems {
columns = append(columns, TupleColumnName(column.Name, i))
val, err := elem.NewWithError()

if err != nil {
return RowData{}, err
}
values = append(values, val)
values = append(values, &val)
}
}
}
Expand Down Expand Up @@ -451,6 +519,16 @@ func (iter *Iter) MapScan(m map[string]interface{}) bool {
return false
}

// MapScanWithNullableValues takes a map[string]interface{} and populates it with a row
// that is returned from cassandra.
//
// Each call to MapScanWithNullableValues() must be called with a new map object.
func (iter *Iter) MapScanWithNullableValues(m map[string]interface{}) bool {
iter.setNullableScan(true)
scan := iter.MapScan(m)
return scan
}

func copyBytes(p []byte) []byte {
b := make([]byte, len(p))
copy(b, p)
Expand Down
Loading