diff --git a/changes/20241107160700.feature b/changes/20241107160700.feature new file mode 100644 index 0000000000..aefd8bccf5 --- /dev/null +++ b/changes/20241107160700.feature @@ -0,0 +1 @@ +:sparkles: `[safecast]` Introduced utilities to perform casting safely and protect against [CWE-190](https://cwe.mitre.org/data/definitions/190.html) diff --git a/utils/field/fields_test.go b/utils/field/fields_test.go index 54d7947b18..805875a490 100644 --- a/utils/field/fields_test.go +++ b/utils/field/fields_test.go @@ -10,6 +10,8 @@ import ( "github.com/go-faker/faker/v4" "github.com/stretchr/testify/assert" + + "github.com/ARM-software/golang-utils/utils/safecast" ) func TestOptionalField(t *testing.T) { @@ -22,7 +24,7 @@ func TestOptionalField(t *testing.T) { }{ { fieldType: "Int", - value: time.Now().Second(), + value: safecast.ToInt(time.Now().Second()), defaultValue: 76, setFunction: func(a any) any { return ToOptionalInt(a.(int)) @@ -37,8 +39,8 @@ func TestOptionalField(t *testing.T) { }, { fieldType: "UInt", - value: uint(time.Now().Second()), //nolint:gosec // time is positive and uint has more bits than int so no overflow - defaultValue: uint(76), + value: safecast.ToUint(time.Now().Second()), + defaultValue: safecast.ToUint(76), setFunction: func(a any) any { return ToOptionalUint(a.(uint)) }, @@ -52,8 +54,8 @@ func TestOptionalField(t *testing.T) { }, { fieldType: "Int32", - value: int32(time.Now().Second()), //nolint:gosec // this should be okay until 2038 - defaultValue: int32(97894), + value: safecast.ToInt32(time.Now().Second()), + defaultValue: safecast.ToInt32(97894), setFunction: func(a any) any { return ToOptionalInt32(a.(int32)) }, @@ -67,8 +69,8 @@ func TestOptionalField(t *testing.T) { }, { fieldType: "UInt32", - value: uint32(time.Now().Second()), //nolint:gosec // this should be okay until 2038 - defaultValue: uint32(97894), + value: safecast.ToUint32(time.Now().Second()), + defaultValue: safecast.ToUint32(97894), setFunction: func(a any) any { return ToOptionalUint32(a.(uint32)) }, @@ -83,7 +85,7 @@ func TestOptionalField(t *testing.T) { { fieldType: "Int64", value: time.Now().Unix(), - defaultValue: int64(97894), + defaultValue: safecast.ToInt64(97894), setFunction: func(a any) any { return ToOptionalInt64(a.(int64)) }, @@ -97,8 +99,8 @@ func TestOptionalField(t *testing.T) { }, { fieldType: "UInt64", - value: uint64(time.Now().Unix()), //nolint:gosec // time is positive and uint64 has more bits than int64 so no overflow - defaultValue: uint64(97894), + value: safecast.ToUint64(time.Now().Unix()), + defaultValue: safecast.ToUint64(97894), setFunction: func(a any) any { return ToOptionalUint64(a.(uint64)) }, diff --git a/utils/filesystem/zip.go b/utils/filesystem/zip.go index 18de71a894..2ce20359cb 100644 --- a/utils/filesystem/zip.go +++ b/utils/filesystem/zip.go @@ -18,6 +18,7 @@ import ( "github.com/ARM-software/golang-utils/utils/collection" "github.com/ARM-software/golang-utils/utils/commonerrors" "github.com/ARM-software/golang-utils/utils/parallelisation" + "github.com/ARM-software/golang-utils/utils/safecast" "github.com/ARM-software/golang-utils/utils/safeio" ) @@ -356,16 +357,16 @@ func (fs *VFS) unzip(ctx context.Context, source string, destination string, lim fileCounter.Inc() fileList = append(fileList, filePath) } - totalSizeOnDisk.Add(uint64(fileSizeOnDisk)) //nolint:gosec // file size is positive and uint64 has more bits than int64 so no overflow + totalSizeOnDisk.Add(safecast.ToUint64(fileSizeOnDisk)) } } else { - totalSizeOnDisk.Add(uint64(fileSizeOnDisk)) //nolint:gosec // file size is positive and uint64 has more bits than int64 so no overflow + totalSizeOnDisk.Add(safecast.ToUint64(fileSizeOnDisk)) } if limits.Apply() && totalSizeOnDisk.Load() > limits.GetMaxTotalSize() { return fileList, fileCounter.Load(), totalSizeOnDisk.Load(), fmt.Errorf("%w: more than %v B of disk space was used while unzipping %v (%v B used already)", commonerrors.ErrTooLarge, limits.GetMaxTotalSize(), source, totalSizeOnDisk.Load()) } - if filecount := fileCounter.Load(); limits.Apply() && filecount <= math.MaxInt64 && int64(filecount) > limits.GetMaxFileCount() { //nolint:gosec // if filecount of uint64 is greater than the max value of int64 then it must be greater than GetMaxFileCount as that is an int64 + if filecount := fileCounter.Load(); limits.Apply() && filecount <= math.MaxInt64 && safecast.ToInt64(filecount) > limits.GetMaxFileCount() { return fileList, filecount, totalSizeOnDisk.Load(), fmt.Errorf("%w: more than %v files were created while unzipping %v (%v files created already)", commonerrors.ErrTooLarge, limits.GetMaxFileCount(), source, filecount) } } diff --git a/utils/idgen/uuid.go b/utils/idgen/uuid.go index 472c15870d..cad22e145a 100644 --- a/utils/idgen/uuid.go +++ b/utils/idgen/uuid.go @@ -4,13 +4,19 @@ */ package idgen -import "github.com/gofrs/uuid" +import ( + "fmt" + + "github.com/gofrs/uuid" + + "github.com/ARM-software/golang-utils/utils/commonerrors" +) // Generates a UUID. func GenerateUUID4() (string, error) { uuid, err := uuid.NewV4() if err != nil { - return "", err + return "", fmt.Errorf("%w: failed generating uuid: %v", commonerrors.ErrUnexpected, err.Error()) } return uuid.String(), nil } diff --git a/utils/idgen/uuid_test.go b/utils/idgen/uuid_test.go index 1e3686c767..2e9ebec1bc 100644 --- a/utils/idgen/uuid_test.go +++ b/utils/idgen/uuid_test.go @@ -13,17 +13,17 @@ import ( func TestUuidUniqueness(t *testing.T) { uuid1, err := GenerateUUID4() - require.Nil(t, err) + require.NoError(t, err) uuid2, err := GenerateUUID4() - require.Nil(t, err) + require.NoError(t, err) assert.NotEqual(t, uuid1, uuid2) } func TestUuidLength(t *testing.T) { uuid, err := GenerateUUID4() - require.Nil(t, err) + require.NoError(t, err) assert.Equal(t, 36, len(uuid)) } diff --git a/utils/platform/os.go b/utils/platform/os.go index 5c88173087..bc43f1e16d 100644 --- a/utils/platform/os.go +++ b/utils/platform/os.go @@ -19,6 +19,7 @@ import ( "github.com/shirou/gopsutil/v3/mem" "github.com/ARM-software/golang-utils/utils/commonerrors" + "github.com/ARM-software/golang-utils/utils/safecast" ) var ( @@ -114,7 +115,7 @@ func UpTime() (uptime time.Duration, err error) { err = fmt.Errorf("%w: could not convert uptime '%v' to duration as it exceeds the upper limit for time.Duration", commonerrors.ErrOutOfRange, _uptime) return } - uptime = time.Duration(_uptime) * time.Second //nolint:gosec // we have verified the value of _uptime is whithin the upper limit for time.Duration in the above check + uptime = time.Duration(safecast.ToInt64(_uptime)) * time.Second return } @@ -128,7 +129,7 @@ func BootTime() (bootime time.Time, err error) { err = fmt.Errorf("%w: could not convert uptime '%v' to duration as it exceeds the upper limit for time.Duration", commonerrors.ErrOutOfRange, _bootime) return } - bootime = time.Unix(int64(_bootime), 0) //nolint:gosec // we have verified the value of _bootime is whithin the upper limit for time.Duration in the above check + bootime = time.Unix(safecast.ToInt64(_bootime), 0) return } diff --git a/utils/proc/process.go b/utils/proc/process.go index af425b886d..fb79d332e4 100644 --- a/utils/proc/process.go +++ b/utils/proc/process.go @@ -14,6 +14,7 @@ import ( "github.com/ARM-software/golang-utils/utils/collection" "github.com/ARM-software/golang-utils/utils/commonerrors" "github.com/ARM-software/golang-utils/utils/parallelisation" + "github.com/ARM-software/golang-utils/utils/safecast" ) const ( @@ -245,7 +246,7 @@ func isProcessRunning(p *process.Process) (running bool) { // to get more information about the process. An error will be returned // if the process does not exist. func NewProcess(ctx context.Context, pid int) (pr IProcess, err error) { - p, err := process.NewProcessWithContext(ctx, int32(pid)) //nolint:gosec // Max PID is 2^22 which is within int32 range https://stackoverflow.com/a/6294196 + p, err := process.NewProcessWithContext(ctx, safecast.ToInt32(pid)) err = ConvertProcessError(err) if err != nil { return diff --git a/utils/reflection/reflection.go b/utils/reflection/reflection.go index 0ad0b96a52..937aef9285 100644 --- a/utils/reflection/reflection.go +++ b/utils/reflection/reflection.go @@ -20,7 +20,7 @@ func GetStructureField(field reflect.Value) interface{} { if !field.IsValid() { return nil } - return reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())). //nolint:gosec // this conversion is is between types recommended by Go https://cs.opensource.google/go/go/+/master:src/reflect/value.go;l=2445 + return reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())). //nolint:gosec // this conversion is between types recommended by Go https://cs.opensource.google/go/go/+/master:src/reflect/value.go;l=2445 Elem(). Interface() } @@ -31,7 +31,7 @@ func SetStructureField(field reflect.Value, value interface{}) { if !field.IsValid() { return } - reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())). //nolint:gosec // this conversion is is between types recommended by Go https://cs.opensource.google/go/go/+/master:src/reflect/value.go;l=2445 + reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())). //nolint:gosec // this conversion is between types recommended by Go https://cs.opensource.google/go/go/+/master:src/reflect/value.go;l=2445 Elem(). Set(reflect.ValueOf(value)) } diff --git a/utils/retry/retry.go b/utils/retry/retry.go index b53e9aad45..315a71e21b 100644 --- a/utils/retry/retry.go +++ b/utils/retry/retry.go @@ -9,6 +9,7 @@ import ( "github.com/go-logr/logr" "github.com/ARM-software/golang-utils/utils/commonerrors" + "github.com/ARM-software/golang-utils/utils/safecast" ) // RetryIf will retry fn when the value returned from retryConditionFn is true @@ -39,7 +40,7 @@ func RetryIf(ctx context.Context, logger logr.Logger, retryPolicy *RetryPolicyCo retry.MaxDelay(retryPolicy.RetryWaitMax), retry.MaxJitter(25*time.Millisecond), retry.DelayType(retryType), - retry.Attempts(uint(retryPolicy.RetryMax)), //nolint:gosec // in normal use this will have had Validate() called which enforces that the minimum number of RetryMax is 0 so it won't overflow + retry.Attempts(safecast.ToUint(retryPolicy.RetryMax)), retry.RetryIf(retryConditionFn), retry.LastErrorOnly(true), retry.Context(ctx), diff --git a/utils/safecast/ReadMe.md b/utils/safecast/ReadMe.md new file mode 100644 index 0000000000..ffa69a57ac --- /dev/null +++ b/utils/safecast/ReadMe.md @@ -0,0 +1,19 @@ +# Safecast + +the purpose of this utilities is to perform safe number conversion in go similarly to [go-safecast](https://github.com/ccoVeille/go-safecast) from which they are inspired from. +It should help tackling gosec [G115 rule](https://github.com/securego/gosec/pull/1149) + + G115: Potential overflow when converting between integer types. + + and [CWE-190](https://cwe.mitre.org/data/definitions/190.html) + + + infinite loop + access to wrong resource by id + grant access to someone who exhausted their quota + +Contrary to `go-safecast` no error is returned when attempting casting and the MAX or MIN value of the type is returned instead if the value is beyond the allowed window. +For instance, `toInt8(255)-> 127` and `toInt8(-255)-> -128` + + + diff --git a/utils/safecast/boundary.go b/utils/safecast/boundary.go new file mode 100644 index 0000000000..9f761ee34d --- /dev/null +++ b/utils/safecast/boundary.go @@ -0,0 +1,35 @@ +package safecast + +func greaterThanUpperBoundary[C1 IConvertable, C2 IConvertable](value C1, upperBoundary C2) (greater bool) { + if value <= 0 { + return + } + + switch f := any(value).(type) { + case float64: + greater = f >= float64(upperBoundary) + case float32: + greater = float64(f) >= float64(upperBoundary) + default: + // for all other integer types, it fits in an uint64 without overflow as we know value is positive. + greater = uint64(value) > uint64(upperBoundary) + } + + return +} + +func lessThanLowerBoundary[T IConvertable, T2 IConvertable](value T, boundary T2) (lower bool) { + if value >= 0 { + return + } + + switch f := any(value).(type) { + case float64: + lower = f <= float64(boundary) + case float32: + lower = float64(f) <= float64(boundary) + default: + lower = int64(value) < int64(boundary) + } + return +} diff --git a/utils/safecast/cast.go b/utils/safecast/cast.go new file mode 100644 index 0000000000..279570d8d3 --- /dev/null +++ b/utils/safecast/cast.go @@ -0,0 +1,133 @@ +package safecast + +import "math" + +// ToInt attempts to convert any [IConvertable] value to an int. +// If the conversion results in a value outside the range of an int, +// the closest boundary value will be returned. +func ToInt[C IConvertable](i C) int { + if lessThanLowerBoundary(i, math.MinInt) { + return math.MinInt + } + if greaterThanUpperBoundary(i, math.MaxInt) { + return math.MaxInt + } + return int(i) +} + +// ToUint attempts to convert any [IConvertable] value to an uint. +// If the conversion results in a value outside the range of an uint, +// the closest boundary value will be returned. +func ToUint[C IConvertable](i C) uint { + if lessThanLowerBoundary(i, uint(0)) { + return 0 + } + if greaterThanUpperBoundary(i, uint(math.MaxUint)) { + return math.MaxUint + } + return uint(i) +} + +// ToInt8 attempts to convert any [IConvertable] value to an int8. +// If the conversion results in a value outside the range of an int8, +// the closest boundary value will be returned. +func ToInt8[C IConvertable](i C) int8 { + if lessThanLowerBoundary(i, math.MinInt8) { + return math.MinInt8 + } + if greaterThanUpperBoundary(i, math.MaxInt8) { + return math.MaxInt8 + } + return int8(i) +} + +// ToUint8 attempts to convert any [IConvertable] value to an uint8. +// If the conversion results in a value outside the range of an uint8, +// the closest boundary value will be returned. +func ToUint8[C IConvertable](i C) uint8 { + if lessThanLowerBoundary(i, 0) { + return 0 + } + if greaterThanUpperBoundary(i, math.MaxUint8) { + return math.MaxUint8 + } + return uint8(i) +} + +// ToInt16 attempts to convert any [IConvertable] value to an int16. +// If the conversion results in a value outside the range of an int16, +// the closest boundary value will be returned. +func ToInt16[C IConvertable](i C) int16 { + if lessThanLowerBoundary(i, math.MinInt16) { + return math.MinInt16 + } + if greaterThanUpperBoundary(i, math.MaxInt16) { + return math.MaxInt16 + } + return int16(i) +} + +// ToUint16 attempts to convert any [IConvertable] value to an uint16. +// If the conversion results in a value outside the range of an uint16, +// the closest boundary value will be returned. +func ToUint16[C IConvertable](i C) uint16 { + if lessThanLowerBoundary(i, 0) { + return 0 + } + if greaterThanUpperBoundary(i, math.MaxUint16) { + return math.MaxUint16 + } + return uint16(i) +} + +// ToInt32 attempts to convert any [IConvertable] value to an int32. +// If the conversion results in a value outside the range of an int32, +// the closest boundary value will be returned. +func ToInt32[C IConvertable](i C) int32 { + if lessThanLowerBoundary(i, math.MinInt32) { + return math.MinInt32 + } + if greaterThanUpperBoundary(i, math.MaxInt32) { + return math.MaxInt32 + } + return int32(i) +} + +// ToUint32 attempts to convert any [IConvertable] value to an uint32. +// If the conversion results in a value outside the range of an uint32, +// the closest boundary value will be returned. +func ToUint32[C IConvertable](i C) uint32 { + if lessThanLowerBoundary(i, 0) { + return 0 + } + if greaterThanUpperBoundary(i, math.MaxUint32) { + return math.MaxUint32 + } + return uint32(i) +} + +// ToInt64 attempts to convert any [IConvertable] value to an int64. +// If the conversion results in a value outside the range of an int64, +// the closest boundary value will be returned. +func ToInt64[C IConvertable](i C) int64 { + if lessThanLowerBoundary(i, math.MinInt64) { + return math.MinInt64 + } + if greaterThanUpperBoundary(i, math.MaxInt64) { + return math.MaxInt64 + } + return int64(i) +} + +// ToUint64 attempts to convert any [IConvertable] value to an uint64. +// If the conversion results in a value outside the range of an uint64, +// the closest boundary value will be returned. +func ToUint64[C IConvertable](i C) uint64 { + if lessThanLowerBoundary(i, uint64(0)) { + return 0 + } + if greaterThanUpperBoundary(i, uint64(math.MaxUint64)) { + return math.MaxUint64 + } + return uint64(i) +} diff --git a/utils/safecast/cast_test.go b/utils/safecast/cast_test.go new file mode 100644 index 0000000000..a4bf591d89 --- /dev/null +++ b/utils/safecast/cast_test.go @@ -0,0 +1,499 @@ +package safecast + +import ( + "fmt" + "math" + "testing" + + "github.com/stretchr/testify/assert" +) + +type testCase[C1 IConvertable, C2 IConvertable] struct { + name string + ctype string + value C1 + expected C2 + testCaseFunc func(t *testing.T, tCase *testCase[C1, C2]) +} + +func TestCastingToInt(t *testing.T) { + tests := []testCase[int64, int64]{ + { + name: "zero", + ctype: "int8", + value: 0, + expected: 0, + testCaseFunc: func(r *testing.T, tCase *testCase[int64, int64]) { + assert.Equal(r, int8(tCase.expected), ToInt8(tCase.value)) //nolint: gosec //G115: testing + }, + }, + { + name: "zero", + ctype: "uint8", + value: 0, + expected: 0, + testCaseFunc: func(r *testing.T, tCase *testCase[int64, int64]) { + assert.Equal(r, uint8(tCase.expected), ToUint8(tCase.value)) //nolint: gosec //G115: testing + }, + }, + { + name: "1", + ctype: "int8", + value: 1, + expected: 1, + testCaseFunc: func(r *testing.T, tCase *testCase[int64, int64]) { + assert.Equal(r, int8(tCase.expected), ToInt8(tCase.value)) //nolint: gosec //G115: testing + }, + }, + { + name: "1", + ctype: "uint8", + value: 1, + expected: 1, + testCaseFunc: func(r *testing.T, tCase *testCase[int64, int64]) { + assert.Equal(r, uint8(tCase.expected), ToUint8(tCase.value)) //nolint: gosec //G115: testing + }, + }, + { + name: "-1", + ctype: "int8", + value: -1, + expected: -1, + testCaseFunc: func(r *testing.T, tCase *testCase[int64, int64]) { + assert.Equal(r, int8(tCase.expected), ToInt8(tCase.value)) //nolint: gosec //G115: testing + }, + }, + { + name: "-1", + ctype: "uint8", + value: -1, + expected: 0, + testCaseFunc: func(r *testing.T, tCase *testCase[int64, int64]) { + assert.Equal(r, uint8(tCase.expected), ToUint8(tCase.value)) //nolint: gosec //G115: testing + }, + }, + { + name: "Max", + ctype: "int8", + value: math.MaxInt8 + 1, + expected: math.MaxInt8, + testCaseFunc: func(r *testing.T, tCase *testCase[int64, int64]) { + assert.Equal(r, int8(tCase.expected), ToInt8(tCase.value)) //nolint: gosec //G115: testing + }, + }, + { + name: "Max", + ctype: "uint8", + value: math.MaxInt8 + 1, + expected: math.MaxInt8 + 1, + testCaseFunc: func(r *testing.T, tCase *testCase[int64, int64]) { + assert.Equal(r, uint8(tCase.expected), ToUint8(tCase.value)) //nolint: gosec //G115: testing + }, + }, + { + name: "Max", + ctype: "uint8", + value: math.MaxUint8 + 1, + expected: math.MaxUint8, + testCaseFunc: func(r *testing.T, tCase *testCase[int64, int64]) { + assert.Equal(r, uint8(tCase.expected), ToUint8(tCase.value)) //nolint: gosec //G115: testing + }, + }, + { + name: "Min", + ctype: "int8", + value: math.MinInt8 - 1, + expected: math.MinInt8, + testCaseFunc: func(r *testing.T, tCase *testCase[int64, int64]) { + assert.Equal(r, int8(tCase.expected), ToInt8(tCase.value)) //nolint: gosec //G115: testing + }, + }, + { + name: "Min", + ctype: "uint8", + value: math.MinInt8 - 1, + expected: 0, + testCaseFunc: func(r *testing.T, tCase *testCase[int64, int64]) { + assert.Equal(r, uint8(tCase.expected), ToUint8(tCase.value)) //nolint: gosec //G115: testing + }, + }, + { + name: "zero", + ctype: "int16", + value: 0, + expected: 0, + testCaseFunc: func(r *testing.T, tCase *testCase[int64, int64]) { + assert.Equal(r, int16(tCase.expected), ToInt16(tCase.value)) //nolint: gosec //G115: testing + }, + }, + { + name: "zero", + ctype: "uint16", + value: 0, + expected: 0, + testCaseFunc: func(r *testing.T, tCase *testCase[int64, int64]) { + assert.Equal(r, uint16(tCase.expected), ToUint16(tCase.value)) //nolint: gosec //G115: testing + }, + }, + { + name: "1", + ctype: "int16", + value: 1, + expected: 1, + testCaseFunc: func(r *testing.T, tCase *testCase[int64, int64]) { + assert.Equal(r, int16(tCase.expected), ToInt16(tCase.value)) //nolint: gosec //G115: testing + }, + }, + { + name: "1", + ctype: "uint16", + value: 1, + expected: 1, + testCaseFunc: func(r *testing.T, tCase *testCase[int64, int64]) { + assert.Equal(r, uint16(tCase.expected), ToUint16(tCase.value)) //nolint: gosec //G115: testing + }, + }, + { + name: "-1", + ctype: "int16", + value: -1, + expected: -1, + testCaseFunc: func(r *testing.T, tCase *testCase[int64, int64]) { + assert.Equal(r, int16(tCase.expected), ToInt16(tCase.value)) //nolint: gosec //G115: testing + }, + }, + { + name: "-1", + ctype: "uint16", + value: -1, + expected: 0, + testCaseFunc: func(r *testing.T, tCase *testCase[int64, int64]) { + assert.Equal(r, uint16(tCase.expected), ToUint16(tCase.value)) //nolint: gosec //G115: testing + }, + }, + { + name: "Max", + ctype: "int16", + value: math.MaxInt16 + 1, + expected: math.MaxInt16, + testCaseFunc: func(r *testing.T, tCase *testCase[int64, int64]) { + assert.Equal(r, int16(tCase.expected), ToInt16(tCase.value)) //nolint: gosec //G115: testing + }, + }, + { + name: "Max", + ctype: "uint16", + value: math.MaxInt16 + 1, + expected: math.MaxInt16 + 1, + testCaseFunc: func(r *testing.T, tCase *testCase[int64, int64]) { + assert.Equal(r, uint16(tCase.expected), ToUint16(tCase.value)) //nolint: gosec //G115: testing + }, + }, + { + name: "Max", + ctype: "uint16", + value: math.MaxUint16 + 1, + expected: math.MaxUint16, + testCaseFunc: func(r *testing.T, tCase *testCase[int64, int64]) { + assert.Equal(r, uint16(tCase.expected), ToUint16(tCase.value)) //nolint: gosec //G115: testing + }, + }, + { + name: "Min", + ctype: "int16", + value: math.MinInt16 - 1, + expected: math.MinInt16, + testCaseFunc: func(r *testing.T, tCase *testCase[int64, int64]) { + assert.Equal(r, int16(tCase.expected), ToInt16(tCase.value)) //nolint: gosec //G115: testing + }, + }, + { + name: "Min", + ctype: "uint16", + value: math.MinInt16 - 1, + expected: 0, + testCaseFunc: func(r *testing.T, tCase *testCase[int64, int64]) { + assert.Equal(r, uint16(tCase.expected), ToUint16(tCase.value)) //nolint: gosec //G115: testing + }, + }, + { + name: "zero", + ctype: "int32", + value: 0, + expected: 0, + testCaseFunc: func(r *testing.T, tCase *testCase[int64, int64]) { + assert.Equal(r, int32(tCase.expected), ToInt32(tCase.value)) //nolint: gosec //G115: testing + }, + }, + { + name: "zero", + ctype: "uint32", + value: 0, + expected: 0, + testCaseFunc: func(r *testing.T, tCase *testCase[int64, int64]) { + assert.Equal(r, uint32(tCase.expected), ToUint32(tCase.value)) //nolint: gosec //G115: testing + }, + }, + { + name: "zero", + ctype: "uint32", + value: 0, + expected: 0, + testCaseFunc: func(r *testing.T, tCase *testCase[int64, int64]) { + assert.Equal(r, uint32(tCase.expected), ToUint32(tCase.value)) //nolint: gosec //G115: testing + }, + }, + { + name: "1", + ctype: "int32", + value: 1, + expected: 1, + testCaseFunc: func(r *testing.T, tCase *testCase[int64, int64]) { + assert.Equal(r, int32(tCase.expected), ToInt32(tCase.value)) //nolint: gosec //G115: testing + }, + }, + { + name: "1", + ctype: "uint32", + value: 1, + expected: 1, + testCaseFunc: func(r *testing.T, tCase *testCase[int64, int64]) { + assert.Equal(r, uint32(tCase.expected), ToUint32(tCase.value)) //nolint: gosec //G115: testing + }, + }, + { + name: "-1", + ctype: "int32", + value: -1, + expected: -1, + testCaseFunc: func(r *testing.T, tCase *testCase[int64, int64]) { + assert.Equal(r, int32(tCase.expected), ToInt32(tCase.value)) //nolint: gosec //G115: testing + }, + }, + { + name: "-1", + ctype: "uint32", + value: -1, + expected: 0, + testCaseFunc: func(r *testing.T, tCase *testCase[int64, int64]) { + assert.Equal(r, uint32(tCase.expected), ToUint32(tCase.value)) //nolint: gosec //G115: testing + }, + }, + { + name: "Max", + ctype: "int32", + value: math.MaxInt32 + 1, + expected: math.MaxInt32, + testCaseFunc: func(r *testing.T, tCase *testCase[int64, int64]) { + assert.Equal(r, int32(tCase.expected), ToInt32(tCase.value)) //nolint: gosec //G115: testing + }, + }, + { + name: "Max", + ctype: "uint32", + value: math.MaxInt32 + 1, + expected: math.MaxInt32 + 1, + testCaseFunc: func(r *testing.T, tCase *testCase[int64, int64]) { + assert.Equal(r, uint32(tCase.expected), ToUint32(tCase.value)) //nolint: gosec //G115: testing + }, + }, + { + name: "Max", + ctype: "uint32", + value: math.MaxUint32 + 1, + expected: math.MaxUint32, + testCaseFunc: func(r *testing.T, tCase *testCase[int64, int64]) { + assert.Equal(r, uint32(tCase.expected), ToUint32(tCase.value)) //nolint: gosec //G115: testing + }, + }, + { + name: "Min", + ctype: "int32", + value: math.MinInt32 - 1, + expected: math.MinInt32, + testCaseFunc: func(r *testing.T, tCase *testCase[int64, int64]) { + assert.Equal(r, int32(tCase.expected), ToInt32(tCase.value)) //nolint: gosec //G115: testing + }, + }, + { + name: "Min", + ctype: "uint32", + value: math.MinInt32 - 1, + expected: 0, + testCaseFunc: func(r *testing.T, tCase *testCase[int64, int64]) { + assert.Equal(r, uint32(tCase.expected), ToUint32(tCase.value)) //nolint: gosec //G115: testing + }, + }, + { + name: "zero", + ctype: "int", + value: 0, + expected: 0, + testCaseFunc: func(r *testing.T, tCase *testCase[int64, int64]) { + assert.Equal(r, int(tCase.expected), ToInt(tCase.value)) //nolint: gosec //G115: testing + }, + }, + { + name: "zero", + ctype: "uint", + value: 0, + expected: 0, + testCaseFunc: func(r *testing.T, tCase *testCase[int64, int64]) { + assert.Equal(r, uint(tCase.expected), ToUint(tCase.value)) //nolint: gosec //G115: testing + }, + }, + { + name: "1", + ctype: "int", + value: 1, + expected: 1, + testCaseFunc: func(r *testing.T, tCase *testCase[int64, int64]) { + assert.Equal(r, int(tCase.expected), ToInt(tCase.value)) //nolint: gosec //G115: testing + }, + }, + { + name: "1", + ctype: "uint", + value: 1, + expected: 1, + testCaseFunc: func(r *testing.T, tCase *testCase[int64, int64]) { + assert.Equal(r, uint(tCase.expected), ToUint(tCase.value)) //nolint: gosec //G115: testing + }, + }, + { + name: "-1", + ctype: "int", + value: -1, + expected: -1, + testCaseFunc: func(r *testing.T, tCase *testCase[int64, int64]) { + assert.Equal(r, int(tCase.expected), ToInt(tCase.value)) //nolint: gosec //G115: testing + }, + }, + { + name: "-1", + ctype: "uint", + value: -1, + expected: 0, + testCaseFunc: func(r *testing.T, tCase *testCase[int64, int64]) { + assert.Equal(r, uint(tCase.expected), ToUint(tCase.value)) //nolint: gosec //G115: testing + }, + }, + { + name: "zero", + ctype: "int64", + value: 0, + expected: 0, + testCaseFunc: func(r *testing.T, tCase *testCase[int64, int64]) { + assert.Equal(r, tCase.expected, ToInt64(tCase.value)) + }, + }, + { + name: "zero", + ctype: "uint64", + value: 0, + expected: 0, + testCaseFunc: func(r *testing.T, tCase *testCase[int64, int64]) { + assert.Equal(r, uint64(tCase.expected), ToUint64(tCase.value)) //nolint: gosec //G115: testing + }, + }, + { + name: "1", + ctype: "int64", + value: 1, + expected: 1, + testCaseFunc: func(r *testing.T, tCase *testCase[int64, int64]) { + assert.Equal(r, tCase.expected, ToInt64(tCase.value)) + }, + }, + { + name: "1", + ctype: "uint64", + value: 1, + expected: 1, + testCaseFunc: func(r *testing.T, tCase *testCase[int64, int64]) { + assert.Equal(r, uint64(tCase.expected), ToUint64(tCase.value)) //nolint: gosec //G115: testing + }, + }, + { + name: "-1", + ctype: "int64", + value: -1, + expected: -1, + testCaseFunc: func(r *testing.T, tCase *testCase[int64, int64]) { + assert.Equal(r, tCase.expected, ToInt64(tCase.value)) + }, + }, + { + name: "-1", + ctype: "uint64", + value: -1, + expected: 0, + testCaseFunc: func(r *testing.T, tCase *testCase[int64, int64]) { + assert.Equal(r, uint64(tCase.expected), ToUint64(tCase.value)) //nolint: gosec //G115: testing + }, + }, + } + + for i := range tests { + test := tests[i] + t.Run(fmt.Sprintf("%v/%v", test.ctype, test.name), func(r *testing.T) { + defer func() { + if r := recover(); r != nil { + t.Fatalf("panic: %v", r) + } + }() + test.testCaseFunc(r, &test) + }) + } + + t.Run("float", func(t *testing.T) { + t.Run("int8", func(t *testing.T) { + assert.Equal(t, int8(-4), ToInt8(-4.6)) //nolint: gosec //G115: testing + assert.Equal(t, int8(4), ToInt8(4.6)) //nolint: gosec //G115: testing + assert.Equal(t, uint8(0), ToUint8(-4.6)) //nolint: gosec //G115: testing + assert.Equal(t, uint8(4), ToUint8(4.6)) //nolint: gosec //G115: testing + assert.Equal(t, int8(math.MaxInt8), ToInt8(256.4)) //nolint: gosec //G115: testing + assert.Equal(t, uint8(math.MaxUint8), ToUint8(256.4)) //nolint: gosec //G115: testing + }) + t.Run("int16", func(t *testing.T) { + assert.Equal(t, int16(-4), ToInt16(-4.6)) //nolint: gosec //G115: testing + assert.Equal(t, int16(4), ToInt16(4.6)) //nolint: gosec //G115: testing + assert.Equal(t, uint16(0), ToUint16(-4.6)) //nolint: gosec //G115: testing + assert.Equal(t, uint16(4), ToUint16(4.6)) //nolint: gosec //G115: testing + assert.Equal(t, int16(math.MaxInt16), ToInt16(40000.4)) //nolint: gosec //G115: testing + assert.Equal(t, int16(math.MaxInt16), ToInt16(float32(40000.4))) //nolint: gosec //G115: testing + assert.Equal(t, int16(math.MinInt16), ToInt16(-32768.4)) //nolint: gosec //G115: testing + assert.Equal(t, uint16(math.MaxUint16), ToUint16(70000.4)) //nolint: gosec //G115: testing + }) + t.Run("int32", func(t *testing.T) { + assert.Equal(t, int32(-4), ToInt32(-4.6)) //nolint: gosec //G115: testing + assert.Equal(t, int32(4), ToInt32(4.6)) //nolint: gosec //G115: testing + assert.Equal(t, uint32(0), ToUint32(-4.6)) //nolint: gosec //G115: testing + assert.Equal(t, uint32(4), ToUint32(4.6)) //nolint: gosec //G115: testing + assert.Equal(t, int32(math.MaxInt32), ToInt32(2147483647.4)) //nolint: gosec //G115: testing + assert.Equal(t, int32(math.MaxInt32), ToInt32(float32(2147483647.4))) //nolint: gosec //G115: testing + assert.Equal(t, int32(math.MinInt32), ToInt32(float32(-2147483648.4))) //nolint: gosec //G115: testing + assert.Equal(t, uint32(math.MaxUint32), ToUint32(4294967295.4)) //nolint: gosec //G115: testing + }) + t.Run("int", func(t *testing.T) { + assert.Equal(t, -4, ToInt(-4.6)) //nolint: gosec //G115: testing + assert.Equal(t, 4, ToInt(4.6)) //nolint: gosec //G115: testing + assert.Equal(t, uint(0), ToUint(-4.6)) //nolint: gosec //G115: testing + assert.Equal(t, uint(4), ToUint(4.6)) //nolint: gosec //G115: testing + assert.Equal(t, math.MaxInt, ToInt(9223372036854775807.4)) //nolint: gosec //G115: testing + assert.Equal(t, uint(math.MaxUint), ToUint(18446744073709551615.4)) //nolint: gosec //G115: testing + assert.Equal(t, math.MinInt, ToInt(-9223372036854775808.4)) //nolint: gosec //G115: testing + assert.Equal(t, uint(0), ToUint(-18446744073709551615.4)) //nolint: gosec //G115: testing + }) + t.Run("int64", func(t *testing.T) { + assert.Equal(t, int64(-4), ToInt64(-4.6)) //nolint: gosec //G115: testing + assert.Equal(t, int64(4), ToInt64(4.6)) //nolint: gosec //G115: testing + assert.Equal(t, uint64(0), ToUint64(-4.6)) //nolint: gosec //G115: testing + assert.Equal(t, uint64(4), ToUint64(4.6)) //nolint: gosec //G115: testing + assert.Equal(t, int64(math.MaxInt64), ToInt64(9223372036854775807.4)) //nolint: gosec //G115: testing + assert.Equal(t, uint64(math.MaxUint64), ToUint64(18446744073709551616.4)) //nolint: gosec //G115: testing + assert.Equal(t, int64(math.MinInt64), ToInt64(-9223372036854775808.4)) //nolint: gosec //G115: testing + assert.Equal(t, uint64(0), ToUint64(-18446744073709551616.4)) //nolint: gosec //G115: testing + }) + }) +} diff --git a/utils/safecast/fuzzcast_test.go b/utils/safecast/fuzzcast_test.go new file mode 100644 index 0000000000..b8777b2cdc --- /dev/null +++ b/utils/safecast/fuzzcast_test.go @@ -0,0 +1,141 @@ +package safecast + +import ( + "math" + "testing" +) + +func FuzzToInt(f *testing.F) { + f.Add(0) + f.Add(math.MinInt) + f.Add(math.MaxInt) + f.Fuzz(func(t *testing.T, from int) { + defer func() { + if r := recover(); r != nil { + t.Fatalf("panic: %v", r) + } + }() + _ = ToInt(from) + }) +} + +func FuzzToInt8(f *testing.F) { + f.Add(int8(0)) + f.Add(int8(math.MinInt8)) + f.Add(int8(math.MaxInt8)) + f.Fuzz(func(t *testing.T, from int8) { + defer func() { + if r := recover(); r != nil { + t.Fatalf("panic: %v", r) + } + }() + _ = ToInt8(from) + }) +} + +func FuzzToInt16(f *testing.F) { + f.Add(int16(0)) + f.Add(int16(math.MinInt16)) + f.Add(int16(math.MaxInt16)) + f.Fuzz(func(t *testing.T, from int16) { + defer func() { + if r := recover(); r != nil { + t.Fatalf("panic: %v", r) + } + }() + _ = ToInt16(from) + }) +} + +func FuzzToInt32(f *testing.F) { + f.Add(int32(0)) + f.Add(int32(math.MinInt32)) + f.Add(int32(math.MaxInt32)) + f.Fuzz(func(t *testing.T, from int32) { + defer func() { + if r := recover(); r != nil { + t.Fatalf("panic: %v", r) + } + }() + _ = ToInt32(from) + }) +} + +func FuzzToInt64(f *testing.F) { + f.Add(int64(0)) + f.Add(int64(math.MinInt64)) + f.Add(int64(math.MaxInt64)) + f.Fuzz(func(t *testing.T, from int64) { + defer func() { + if r := recover(); r != nil { + t.Fatalf("panic: %v", r) + } + }() + _ = ToInt64(from) + }) +} + +func FuzzToUint(f *testing.F) { + f.Add(uint(0)) + f.Add(uint(math.MaxUint)) + f.Fuzz(func(t *testing.T, from uint) { + defer func() { + if r := recover(); r != nil { + t.Fatalf("panic: %v", r) + } + }() + _ = ToUint(from) + }) +} + +func FuzzToUint8(f *testing.F) { + f.Add(uint8(0)) + f.Add(uint8(math.MaxUint8)) + f.Fuzz(func(t *testing.T, from uint8) { + defer func() { + if r := recover(); r != nil { + t.Fatalf("panic: %v", r) + } + }() + _ = ToUint8(from) + }) +} + +func FuzzToUint16(f *testing.F) { + f.Add(uint16(0)) + f.Add(uint16(math.MaxUint16)) + f.Fuzz(func(t *testing.T, from uint16) { + defer func() { + if r := recover(); r != nil { + t.Fatalf("panic: %v", r) + } + }() + _ = ToUint16(from) + }) +} + +func FuzzToUint32(f *testing.F) { + f.Add(uint32(0)) + f.Add(uint32(math.MaxUint32)) + f.Fuzz(func(t *testing.T, from uint32) { + defer func() { + if r := recover(); r != nil { + t.Fatalf("panic: %v", r) + } + }() + _ = ToUint32(from) + }) +} + +func FuzzToUint64(f *testing.F) { + f.Add(uint64(0)) + f.Add(uint64(math.MaxUint64)) + f.Fuzz(func(t *testing.T, from uint64) { + defer func() { + if r := recover(); r != nil { + t.Fatalf("panic: %v", r) + } + }() + _ = ToUint64(from) + }) +} diff --git a/utils/safecast/number.go b/utils/safecast/number.go new file mode 100644 index 0000000000..ebb8af7733 --- /dev/null +++ b/utils/safecast/number.go @@ -0,0 +1,33 @@ +package safecast + +// This file is highly inspired from https://pkg.go.dev/golang.org/x/exp/constraints + +// ISignedInteger is an alias for all signed integers: int, int8, int16, int32, and int64 types. +type ISignedInteger interface { + ~int | ~int8 | ~int16 | ~int32 | ~int64 +} + +// IUnsignedInteger is an alias for all unsigned integers: uint, uint8, uint16, uint32, and uint64 types. +type IUnsignedInteger interface { + ~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 +} + +// IInteger is an alias for the all unsigned and signed integers +type IInteger interface { + ISignedInteger | IUnsignedInteger +} + +// IFloat is an alias for the float32 and float64 types. +type IFloat interface { + ~float32 | ~float64 +} + +// INumber is an alias for all integers and floats +type INumber interface { + IInteger | IFloat +} + +// IConvertable is an alias for everything that can be converted +type IConvertable interface { + INumber +}