Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support serializing more primitive values #4

Merged
merged 12 commits into from
Jun 20, 2024
215 changes: 191 additions & 24 deletions any.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,69 @@ package dispatch
import (
"fmt"
"reflect"
"time"

"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb"
"google.golang.org/protobuf/types/known/durationpb"
"google.golang.org/protobuf/types/known/timestamppb"
"google.golang.org/protobuf/types/known/wrapperspb"
)

// Any represents any value.
type Any struct {
proto *anypb.Any
type Any struct{ proto *anypb.Any }

// Bool creates an Any that contains a boolean value.
func Bool(v bool) Any {
return knownAny(wrapperspb.Bool(v))
}

// Int creates an Any that contains an integer value.
func Int(v int64) Any {
// Note: we serialize all integers using wrapperspb.Int64, even
// though wrapperspb.Int32 is available. A variable-length
// format is used for the wire representation of the integer, so
// there's no penalty for using a wider variable-length type.
// It simplifies the implementation here and elsewhere if there's
// only one wrapper used.
return knownAny(wrapperspb.Int64(v))
}

// Uint creates an Any that contains an unsigned integer value.
func Uint(v uint64) Any {
// See note above about 64-bit wrapper.
return knownAny(wrapperspb.UInt64(v))
}

// Float creates an Any that contains a floating point value.
func Float(v float64) Any {
// See notes above. We also exclusively use the Double (float64)
// wrapper to carry 32-bit and 64-bit floats. Although there
// is a size penalty in some cases, we're not shipping around
// so many floats that this is an issue. Prefer simplifying the
// implementation here and elsewhere by limiting the number of
// wrappers that are used.
return knownAny(wrapperspb.Double(v))
}

// String creates an Any that contains a string value.
func String(v string) Any {
return knownAny(wrapperspb.String(v))
}

// Bytes creates an Any that contains a bytes value.
func Bytes(v []byte) Any {
return knownAny(wrapperspb.Bytes(v))
}

// Time creates an Any that contains a time value.
func Time(v time.Time) Any {
return knownAny(timestamppb.New(v))
}

// Duration creates an Any that contains a duration value.
func Duration(v time.Duration) Any {
return knownAny(durationpb.New(v))
}

// NewAny creates an Any from a proto.Message.
Expand All @@ -20,80 +74,193 @@ func NewAny(v any) (Any, error) {
switch vv := v.(type) {
case proto.Message:
m = vv

case bool:
m = wrapperspb.Bool(vv)

case int:
m = wrapperspb.Int64(int64(vv))
case int8:
m = wrapperspb.Int64(int64(vv))
case int16:
m = wrapperspb.Int64(int64(vv))
case int32:
m = wrapperspb.Int64(int64(vv))
case int64:
m = wrapperspb.Int64(vv)

case uint:
m = wrapperspb.UInt64(uint64(vv))
case uint8:
m = wrapperspb.UInt64(uint64(vv))
case uint16:
m = wrapperspb.UInt64(uint64(vv))
case uint32:
m = wrapperspb.UInt64(uint64(vv))
case uint64:
m = wrapperspb.UInt64(uint64(vv))

case float32:
m = wrapperspb.Double(float64(vv))
case float64:
m = wrapperspb.Double(vv)

case string:
m = wrapperspb.String(vv)

case []byte:
m = wrapperspb.Bytes(vv)

case time.Time:
m = timestamppb.New(vv)
case time.Duration:
m = durationpb.New(vv)

default:
// TODO: support more types
return Any{}, fmt.Errorf("unsupported type: %T", v)
}

proto, err := anypb.New(m)
if err != nil {
return Any{}, err
}
return Any{proto}, nil
}

// Int creates an Any that contains an integer value.
func Int(v int) Any {
any, err := NewAny(wrapperspb.Int64(int64(v)))
func knownAny(v any) Any {
any, err := NewAny(v)
if err != nil {
panic(err)
}
return any
}

// String creates an Any that contains a string value.
func String(v string) Any {
any, err := NewAny(wrapperspb.String(v))
if err != nil {
panic(err)
}
return any
}
var (
timeType = reflect.TypeFor[time.Time]()
durationType = reflect.TypeFor[time.Duration]()
)

// Unmarshal unmarshals the value.
func (a Any) Unmarshal(v any) error {
if a.proto == nil {
return fmt.Errorf("empty Any")
}

r := reflect.ValueOf(v)
if r.Kind() != reflect.Pointer || r.IsNil() {
rv := reflect.ValueOf(v)
if rv.Kind() != reflect.Pointer || rv.IsNil() {
panic("Any.Unmarshal expects a pointer")
}
elem := r.Elem()
elem := rv.Elem()

m, err := a.proto.UnmarshalNew()
if err != nil {
return err
}

rm := reflect.ValueOf(m)
if rm.Type() == elem.Type() {

switch elem.Type() {
case rm.Type(): // e.g. a proto.Message impl
elem.Set(rm)
return nil

case timeType:
v, ok := m.(*timestamppb.Timestamp)
if !ok {
return fmt.Errorf("cannot unmarshal %T into time.Time", m)
} else if err := v.CheckValid(); err != nil {
return fmt.Errorf("cannot unmarshal %T into time.Time: %w", m, err)
}
elem.Set(reflect.ValueOf(v.AsTime()))
return nil

case durationType:
v, ok := m.(*durationpb.Duration)
if !ok {
return fmt.Errorf("cannot unmarshal %T into time.Duration", m)
} else if err := v.CheckValid(); err != nil {
return fmt.Errorf("cannot unmarshal %T into time.Duration: %w", m, err)
}
elem.SetInt(int64(v.AsDuration()))
return nil
}

switch elem.Kind() {
case reflect.Int:
v, ok := m.(*wrapperspb.Int64Value)
case reflect.Bool:
v, ok := m.(*wrapperspb.BoolValue)
if !ok {
return fmt.Errorf("cannot unmarshal %T into int", m)
return fmt.Errorf("cannot unmarshal %T into bool", m)
}
elem.SetInt(v.Value)
elem.SetBool(v.Value)
return nil

case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
var i int64
if v, ok := m.(*wrapperspb.Int64Value); ok {
i = v.Value
} else if v, ok := m.(*wrapperspb.Int32Value); ok {
i = int64(v.Value)
} else {
return fmt.Errorf("cannot unmarshal %T into %T", m, elem.Interface())
}
if elem.OverflowInt(i) {
return fmt.Errorf("cannot unmarshal %T of %v into %T", m, i, elem.Interface())
}
elem.SetInt(i)
return nil

case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
var u uint64
if v, ok := m.(*wrapperspb.UInt64Value); ok {
u = v.Value
} else if v, ok := m.(*wrapperspb.UInt32Value); ok {
u = uint64(v.Value)
} else {
return fmt.Errorf("cannot unmarshal %T into %T", m, elem.Interface())
}
if elem.OverflowUint(u) {
return fmt.Errorf("cannot unmarshal %T of %v into %T", m, u, elem.Interface())
}
elem.SetUint(u)
return nil

case reflect.Float32, reflect.Float64:
var f float64
if v, ok := m.(*wrapperspb.DoubleValue); ok {
f = v.Value
} else if v, ok := m.(*wrapperspb.FloatValue); ok {
f = float64(v.Value)
} else {
return fmt.Errorf("cannot unmarshal %T into %T", m, elem.Interface())
}
if elem.OverflowFloat(f) {
return fmt.Errorf("cannot unmarshal %T of %v into %T", m, f, elem.Interface())
}
elem.SetFloat(f)
return nil

case reflect.String:
v, ok := m.(*wrapperspb.StringValue)
if !ok {
return fmt.Errorf("cannot unmarshal %T into string", m)
}
elem.SetString(v.Value)
return nil

default:
// Special case for []byte. Other reflect.Slice values aren't supported at this time.
if elem.Kind() == reflect.Slice && elem.Type().Elem().Kind() == reflect.Uint8 {
v, ok := m.(*wrapperspb.BytesValue)
if !ok {
return fmt.Errorf("cannot unmarshal %T into []byte", m)
}
elem.SetBytes(v.Value)
return nil
}

// TODO: support more types
return fmt.Errorf("unsupported type: %T", elem.Interface())
return fmt.Errorf("unsupported type: %v (%v kind)", elem.Type(), elem.Kind())
}
return nil
}

// TypeURL is a URL that uniquely identifies the type of the
Expand Down
Loading
Loading