diff --git a/core/trace/baggage.go b/core/trace/baggage.go new file mode 100644 index 000000000000..c17bfc7ed8cf --- /dev/null +++ b/core/trace/baggage.go @@ -0,0 +1,59 @@ +package trace + +import ( + "context" + + "github.com/zeromicro/go-zero/core/logc" + "go.opentelemetry.io/otel/baggage" +) + +// GetBaggageValue get baggage info from context, if key not exists, return "", false. +func GetBaggageValue(ctx context.Context, key string) (string, bool) { + b := baggage.FromContext(ctx) + m := b.Member(key) + + if m.Value() == "" { + return "", false + } + + return m.Value(), true +} + +// WithBaggage append baggage by string key val. +func WithBaggage(parent context.Context, key, val string) context.Context { + member, err := baggage.NewMember(key, val) + if err != nil { + logc.Error(parent, err) + return parent + } + + b := baggage.FromContext(parent) + b, err = b.SetMember(member) + if err != nil { + logc.Error(parent, err) + return parent + } + + return baggage.ContextWithBaggage(parent, b) +} + +// AddBaggagesFromMap append map kvs to current ctx baggage. +func AddBaggagesFromMap(parent context.Context, mp map[string]string) context.Context { + b := baggage.FromContext(parent) + + for k, v := range mp { + m, err := baggage.NewMember(k, v) + if err != nil { + logc.Error(parent, err) + return parent + } + + b, err = b.SetMember(m) + if err != nil { + logc.Error(parent, err) + return parent + } + } + + return baggage.ContextWithBaggage(parent, b) +} diff --git a/core/trace/baggage_test.go b/core/trace/baggage_test.go new file mode 100644 index 000000000000..5b481e53a499 --- /dev/null +++ b/core/trace/baggage_test.go @@ -0,0 +1,74 @@ +package trace + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "go.opentelemetry.io/otel/baggage" +) + +func TestAddBaggagesFromMap(t *testing.T) { + ctx := AddBaggagesFromMap(context.Background(), map[string]string{"test": "test", "aaa": "aaa"}) + b := baggage.FromContext(ctx) + assert.Equal(t, "test", b.Member("test").Value()) + assert.Equal(t, "aaa", b.Member("aaa").Value()) +} + +func TestGetBaggageValue(t *testing.T) { + ctx := AddBaggagesFromMap(context.Background(), map[string]string{"test": "aaa"}) + + type args struct { + ctx context.Context + key string + } + tests := []struct { + name string + args args + want string + want1 bool + }{ + { + "not exists", + args{ + ctx: context.Background(), + key: "test", + }, + "", + false, + }, + { + "exists", + args{ + ctx: ctx, + key: "test", + }, + "aaa", + true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, got1 := GetBaggageValue(tt.args.ctx, tt.args.key) + if got != tt.want { + t.Errorf("GetBaggageValue() got = %v, want %v", got, tt.want) + } + if got1 != tt.want1 { + t.Errorf("GetBaggageValue() got1 = %v, want %v", got1, tt.want1) + } + }) + } +} + +func TestWithBaggage(t *testing.T) { + ctx := context.Background() + ctx = WithBaggage(ctx, "aaa", "aaa") + val, ok := GetBaggageValue(ctx, "aaa") + assert.True(t, ok) + assert.Equal(t, "aaa", val) + + ctx = WithBaggage(ctx, "aaa", "bbb") + val, ok = GetBaggageValue(ctx, "aaa") + assert.True(t, ok) + assert.Equal(t, "bbb", val) +}