diff --git a/any.go b/any.go index 04d26a2..7c0dfd2 100644 --- a/any.go +++ b/any.go @@ -3,15 +3,69 @@ package dispatch import ( "fmt" "reflect" + "time" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/anypb" + "google.golang.org/protobuf/types/known/durationpb" + "google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/wrapperspb" ) // Any represents any value. -type Any struct { - proto *anypb.Any +type Any struct{ proto *anypb.Any } + +// Bool creates an Any that contains a boolean value. +func Bool(v bool) Any { + return knownAny(wrapperspb.Bool(v)) +} + +// Int creates an Any that contains an integer value. +func Int(v int64) Any { + // Note: we serialize all integers using wrapperspb.Int64, even + // though wrapperspb.Int32 is available. A variable-length + // format is used for the wire representation of the integer, so + // there's no penalty for using a wider variable-length type. + // It simplifies the implementation here and elsewhere if there's + // only one wrapper used. + return knownAny(wrapperspb.Int64(v)) +} + +// Uint creates an Any that contains an unsigned integer value. +func Uint(v uint64) Any { + // See note above about 64-bit wrapper. + return knownAny(wrapperspb.UInt64(v)) +} + +// Float creates an Any that contains a floating point value. +func Float(v float64) Any { + // See notes above. We also exclusively use the Double (float64) + // wrapper to carry 32-bit and 64-bit floats. Although there + // is a size penalty in some cases, we're not shipping around + // so many floats that this is an issue. Prefer simplifying the + // implementation here and elsewhere by limiting the number of + // wrappers that are used. + return knownAny(wrapperspb.Double(v)) +} + +// String creates an Any that contains a string value. +func String(v string) Any { + return knownAny(wrapperspb.String(v)) +} + +// Bytes creates an Any that contains a bytes value. +func Bytes(v []byte) Any { + return knownAny(wrapperspb.Bytes(v)) +} + +// Time creates an Any that contains a time value. +func Time(v time.Time) Any { + return knownAny(timestamppb.New(v)) +} + +// Duration creates an Any that contains a duration value. +func Duration(v time.Duration) Any { + return knownAny(durationpb.New(v)) } // NewAny creates an Any from a proto.Message. @@ -20,14 +74,53 @@ func NewAny(v any) (Any, error) { switch vv := v.(type) { case proto.Message: m = vv + + case bool: + m = wrapperspb.Bool(vv) + case int: m = wrapperspb.Int64(int64(vv)) + case int8: + m = wrapperspb.Int64(int64(vv)) + case int16: + m = wrapperspb.Int64(int64(vv)) + case int32: + m = wrapperspb.Int64(int64(vv)) + case int64: + m = wrapperspb.Int64(vv) + + case uint: + m = wrapperspb.UInt64(uint64(vv)) + case uint8: + m = wrapperspb.UInt64(uint64(vv)) + case uint16: + m = wrapperspb.UInt64(uint64(vv)) + case uint32: + m = wrapperspb.UInt64(uint64(vv)) + case uint64: + m = wrapperspb.UInt64(uint64(vv)) + + case float32: + m = wrapperspb.Double(float64(vv)) + case float64: + m = wrapperspb.Double(vv) + case string: m = wrapperspb.String(vv) + + case []byte: + m = wrapperspb.Bytes(vv) + + case time.Time: + m = timestamppb.New(vv) + case time.Duration: + m = durationpb.New(vv) + default: // TODO: support more types return Any{}, fmt.Errorf("unsupported type: %T", v) } + proto, err := anypb.New(m) if err != nil { return Any{}, err @@ -35,23 +128,18 @@ func NewAny(v any) (Any, error) { return Any{proto}, nil } -// Int creates an Any that contains an integer value. -func Int(v int) Any { - any, err := NewAny(wrapperspb.Int64(int64(v))) +func knownAny(v any) Any { + any, err := NewAny(v) if err != nil { panic(err) } return any } -// String creates an Any that contains a string value. -func String(v string) Any { - any, err := NewAny(wrapperspb.String(v)) - if err != nil { - panic(err) - } - return any -} +var ( + timeType = reflect.TypeFor[time.Time]() + durationType = reflect.TypeFor[time.Duration]() +) // Unmarshal unmarshals the value. func (a Any) Unmarshal(v any) error { @@ -59,41 +147,120 @@ func (a Any) Unmarshal(v any) error { return fmt.Errorf("empty Any") } - r := reflect.ValueOf(v) - if r.Kind() != reflect.Pointer || r.IsNil() { + rv := reflect.ValueOf(v) + if rv.Kind() != reflect.Pointer || rv.IsNil() { panic("Any.Unmarshal expects a pointer") } - elem := r.Elem() + elem := rv.Elem() m, err := a.proto.UnmarshalNew() if err != nil { return err } - rm := reflect.ValueOf(m) - if rm.Type() == elem.Type() { + + switch elem.Type() { + case rm.Type(): // e.g. a proto.Message impl elem.Set(rm) return nil + + case timeType: + v, ok := m.(*timestamppb.Timestamp) + if !ok { + return fmt.Errorf("cannot unmarshal %T into time.Time", m) + } else if err := v.CheckValid(); err != nil { + return fmt.Errorf("cannot unmarshal %T into time.Time: %w", m, err) + } + elem.Set(reflect.ValueOf(v.AsTime())) + return nil + + case durationType: + v, ok := m.(*durationpb.Duration) + if !ok { + return fmt.Errorf("cannot unmarshal %T into time.Duration", m) + } else if err := v.CheckValid(); err != nil { + return fmt.Errorf("cannot unmarshal %T into time.Duration: %w", m, err) + } + elem.SetInt(int64(v.AsDuration())) + return nil } switch elem.Kind() { - case reflect.Int: - v, ok := m.(*wrapperspb.Int64Value) + case reflect.Bool: + v, ok := m.(*wrapperspb.BoolValue) if !ok { - return fmt.Errorf("cannot unmarshal %T into int", m) + return fmt.Errorf("cannot unmarshal %T into bool", m) } - elem.SetInt(v.Value) + elem.SetBool(v.Value) + return nil + + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + var i int64 + if v, ok := m.(*wrapperspb.Int64Value); ok { + i = v.Value + } else if v, ok := m.(*wrapperspb.Int32Value); ok { + i = int64(v.Value) + } else { + return fmt.Errorf("cannot unmarshal %T into %T", m, elem.Interface()) + } + if elem.OverflowInt(i) { + return fmt.Errorf("cannot unmarshal %T of %v into %T", m, i, elem.Interface()) + } + elem.SetInt(i) + return nil + + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + var u uint64 + if v, ok := m.(*wrapperspb.UInt64Value); ok { + u = v.Value + } else if v, ok := m.(*wrapperspb.UInt32Value); ok { + u = uint64(v.Value) + } else { + return fmt.Errorf("cannot unmarshal %T into %T", m, elem.Interface()) + } + if elem.OverflowUint(u) { + return fmt.Errorf("cannot unmarshal %T of %v into %T", m, u, elem.Interface()) + } + elem.SetUint(u) + return nil + + case reflect.Float32, reflect.Float64: + var f float64 + if v, ok := m.(*wrapperspb.DoubleValue); ok { + f = v.Value + } else if v, ok := m.(*wrapperspb.FloatValue); ok { + f = float64(v.Value) + } else { + return fmt.Errorf("cannot unmarshal %T into %T", m, elem.Interface()) + } + if elem.OverflowFloat(f) { + return fmt.Errorf("cannot unmarshal %T of %v into %T", m, f, elem.Interface()) + } + elem.SetFloat(f) + return nil + case reflect.String: v, ok := m.(*wrapperspb.StringValue) if !ok { return fmt.Errorf("cannot unmarshal %T into string", m) } elem.SetString(v.Value) + return nil + default: + // Special case for []byte. Other reflect.Slice values aren't supported at this time. + if elem.Kind() == reflect.Slice && elem.Type().Elem().Kind() == reflect.Uint8 { + v, ok := m.(*wrapperspb.BytesValue) + if !ok { + return fmt.Errorf("cannot unmarshal %T into []byte", m) + } + elem.SetBytes(v.Value) + return nil + } + // TODO: support more types - return fmt.Errorf("unsupported type: %T", elem.Interface()) + return fmt.Errorf("unsupported type: %v (%v kind)", elem.Type(), elem.Kind()) } - return nil } // TypeURL is a URL that uniquely identifies the type of the diff --git a/any_test.go b/any_test.go new file mode 100644 index 0000000..f551d7e --- /dev/null +++ b/any_test.go @@ -0,0 +1,236 @@ +package dispatch_test + +import ( + "bytes" + "fmt" + "math" + "reflect" + "strings" + "testing" + "time" + + "github.com/dispatchrun/dispatch-go" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/durationpb" + "google.golang.org/protobuf/types/known/emptypb" + "google.golang.org/protobuf/types/known/timestamppb" + "google.golang.org/protobuf/types/known/wrapperspb" +) + +func TestAnyBool(t *testing.T) { + for _, v := range []bool{true, false} { + boxed := dispatch.Bool(v) + var got bool + if err := boxed.Unmarshal(&got); err != nil { + t.Fatal(err) + } else if got != v { + t.Errorf("unexpected bool: got %v, want %v", got, v) + } + } +} + +func TestAnyInt(t *testing.T) { + for _, v := range []int64{0, 11, -1, 2, math.MinInt, math.MaxInt} { + boxed := dispatch.Int(v) + var got int64 + if err := boxed.Unmarshal(&got); err != nil { + t.Fatal(err) + } else if got != v { + t.Errorf("unexpected int: got %v, want %v", got, v) + } + } +} + +func TestAnyUint(t *testing.T) { + for _, v := range []uint64{0, 11, 2, math.MaxUint} { + boxed := dispatch.Uint(v) + var got uint64 + if err := boxed.Unmarshal(&got); err != nil { + t.Fatal(err) + } else if got != v { + t.Errorf("unexpected uint: got %v, want %v", got, v) + } + } +} + +func TestAnyFloat(t *testing.T) { + for _, v := range []float64{0, 3.14, 11.11, math.MaxFloat64} { + boxed := dispatch.Float(v) + var got float64 + if err := boxed.Unmarshal(&got); err != nil { + t.Fatal(err) + } else if got != v { + t.Errorf("unexpected float: got %v, want %v", got, v) + } + } +} + +func TestAnyString(t *testing.T) { + for _, v := range []string{"", "x", "foobar", strings.Repeat("abc", 100)} { + boxed := dispatch.String(v) + var got string + if err := boxed.Unmarshal(&got); err != nil { + t.Fatal(err) + } else if got != v { + t.Errorf("unexpected string: got %v, want %v", got, v) + } + } +} + +func TestAnyBytes(t *testing.T) { + for _, v := range [][]byte{nil, []byte("foobar"), bytes.Repeat([]byte("abc"), 100)} { + boxed := dispatch.Bytes(v) + var got []byte + if err := boxed.Unmarshal(&got); err != nil { + t.Fatal(err) + } else if !bytes.Equal(v, got) { + t.Errorf("unexpected bytes: got %v, want %v", got, v) + } + } +} + +func TestAnyTime(t *testing.T) { + for _, v := range []time.Time{time.Now(), { /*zero*/ }, time.Date(2024, time.June, 10, 11, 30, 1, 2, time.UTC)} { + boxed := dispatch.Time(v) + var got time.Time + if err := boxed.Unmarshal(&got); err != nil { + t.Fatal(err) + } else if !got.Equal(v) { + t.Errorf("unexpected time: got %v, want %v", got, v) + } + } +} + +func TestAnyDuration(t *testing.T) { + for _, v := range []time.Duration{0, time.Second, 10 * time.Hour} { + boxed := dispatch.Duration(v) + var got time.Duration + if err := boxed.Unmarshal(&got); err != nil { + t.Fatal(err) + } else if got != v { + t.Errorf("unexpected duration: got %v, want %v", got, v) + } + } +} + +func TestOverflow(t *testing.T) { + var i8 int8 + if err := dispatch.Int(math.MinInt8 - 1).Unmarshal(&i8); err == nil || err.Error() != "cannot unmarshal *wrapperspb.Int64Value of -129 into int8" { + t.Errorf("unexpected error: %v", err) + } + if err := dispatch.Int(math.MaxInt8 + 1).Unmarshal(&i8); err == nil || err.Error() != "cannot unmarshal *wrapperspb.Int64Value of 128 into int8" { + t.Errorf("unexpected error: %v", err) + } + + var i16 int16 + if err := dispatch.Int(math.MinInt16 - 1).Unmarshal(&i16); err == nil || err.Error() != "cannot unmarshal *wrapperspb.Int64Value of -32769 into int16" { + t.Errorf("unexpected error: %v", err) + } + if err := dispatch.Int(math.MaxInt16 + 1).Unmarshal(&i16); err == nil || err.Error() != "cannot unmarshal *wrapperspb.Int64Value of 32768 into int16" { + t.Errorf("unexpected error: %v", err) + } + + var i32 int32 + if err := dispatch.Int(math.MinInt32 - 1).Unmarshal(&i32); err == nil || err.Error() != "cannot unmarshal *wrapperspb.Int64Value of -2147483649 into int32" { + t.Errorf("unexpected error: %v", err) + } + if err := dispatch.Int(math.MaxInt32 + 1).Unmarshal(&i32); err == nil || err.Error() != "cannot unmarshal *wrapperspb.Int64Value of 2147483648 into int32" { + t.Errorf("unexpected error: %v", err) + } + + var u8 uint8 + if err := dispatch.Uint(math.MaxUint8 + 1).Unmarshal(&u8); err == nil || err.Error() != "cannot unmarshal *wrapperspb.UInt64Value of 256 into uint8" { + t.Errorf("unexpected error: %v", err) + } + var u16 uint16 + if err := dispatch.Uint(math.MaxUint16 + 1).Unmarshal(&u16); err == nil || err.Error() != "cannot unmarshal *wrapperspb.UInt64Value of 65536 into uint16" { + t.Errorf("unexpected error: %v", err) + } + var u32 uint32 + if err := dispatch.Uint(math.MaxUint32 + 1).Unmarshal(&u32); err == nil || err.Error() != "cannot unmarshal *wrapperspb.UInt64Value of 4294967296 into uint32" { + t.Errorf("unexpected error: %v", err) + } + + var f32 float32 + if err := dispatch.Float(math.MaxFloat32 + math.MaxFloat32).Unmarshal(&f32); err == nil || err.Error() != "cannot unmarshal *wrapperspb.DoubleValue of 6.805646932770577e+38 into float32" { + t.Errorf("unexpected error: %v", err) + } + + badTime, err := dispatch.NewAny(×tamppb.Timestamp{Seconds: math.MinInt64}) + if err != nil { + t.Fatal(err) + } + var tt time.Time + if err := badTime.Unmarshal(&tt); err == nil { + t.Error("expected an error") + } + + badDuration, err := dispatch.NewAny(&durationpb.Duration{Seconds: math.MaxInt64}) + if err != nil { + t.Fatal(err) + } + var td time.Duration + if err := badDuration.Unmarshal(&td); err == nil { + t.Error("expected an error") + } +} + +func TestAny(t *testing.T) { + for _, v := range []any{ + true, + false, + + 11, + int8(-1), + int16(math.MaxInt16), + int32(23), + int64(math.MinInt64), + + uint(1), + uint8(128), + uint16(math.MaxUint16), + uint32(0xDEADBEEF), + uint64(math.MaxUint64), + + float32(3.14), + float64(11.11), + + "", + "foo", + + []byte("bar"), + + time.Now().UTC(), + + 11 * time.Second, + + // Raw proto.Message + &emptypb.Empty{}, + &wrapperspb.Int32Value{Value: 11}, + } { + t.Run(fmt.Sprintf("%v", v), func(t *testing.T) { + boxed, err := dispatch.NewAny(v) + if err != nil { + t.Fatalf("NewAny(%v): %v", v, err) + } + + rv := reflect.New(reflect.TypeOf(v)) + if err := boxed.Unmarshal(rv.Interface()); err != nil { + t.Fatal(err) + } + + got := rv.Elem().Interface() + want := reflect.ValueOf(v).Interface() + + var equal bool + if wantProto, ok := want.(proto.Message); ok { + equal = proto.Equal(got.(proto.Message), wantProto) + } else { + equal = reflect.DeepEqual(got, want) + } + if !equal { + t.Errorf("unexpected NewAny(%v).Unmarshal result: %#v", v, got) + } + }) + } +}