From bcaa17032c0204480355ccc1afaf114de3bb9924 Mon Sep 17 00:00:00 2001 From: "yuxuan.wang1" Date: Wed, 28 Aug 2024 09:22:01 +0800 Subject: [PATCH 1/2] test(streaming): stream ctx diverge and nphttp2.GetServerConn --- thrift_streaming/thrift_handler.go | 39 ++++++++++++++++----- thrift_streaming/thrift_test.go | 54 ++++++++++++++++++++++++++++++ thrift_streaming/util.go | 9 +++++ 3 files changed, 93 insertions(+), 9 deletions(-) diff --git a/thrift_streaming/thrift_handler.go b/thrift_streaming/thrift_handler.go index ebb66d6..b84b562 100644 --- a/thrift_streaming/thrift_handler.go +++ b/thrift_streaming/thrift_handler.go @@ -17,6 +17,7 @@ package thrift_streaming import ( "context" "errors" + "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2" "strconv" "github.com/cloudwego/kitex/pkg/klog" @@ -55,19 +56,39 @@ func (e EchoServiceImpl) EchoBidirectional(stream echo.EchoService_EchoBidirecti func (e EchoServiceImpl) EchoClient(stream echo.EchoService_EchoClientServer) (err error) { klog.Infof("EchoClient: start") - count := GetInt(stream.Context(), KeyCount, 0) - var req *echo.EchoRequest - for i := 0; i < count; i++ { - req, err = stream.Recv() + resp := &echo.EchoResponse{} + doGetServerConn := GetBool(stream.Context(), KeyGetServerConn, false) + doInspectMWCtx := GetBool(stream.Context(), KeyInspectMWCtx, false) + switch { + case doGetServerConn: + _, err = nphttp2.GetServerConn(stream) if err != nil { - klog.Infof("EchoClient: recv error = %v", err) + klog.Errorf("EchoClient: GetServerConn failed, error = %v", err) return } - klog.Infof("EchoClient: recv req = %v", req) - } - resp := &echo.EchoResponse{ - Message: strconv.Itoa(count), + resp.Message = "GetServerConn Succeeded" + case doInspectMWCtx: + val, ok := stream.Context().Value("key").(string) + if !ok || val != "val" { + err = errors.New("can not get ctx value set in server MW") + klog.Errorf("EchoClient: InspectMWCtx failed, error = %v", err) + return + } + resp.Message = "InspectMWCtx Succeeded" + default: + count := GetInt(stream.Context(), KeyCount, 0) + var req *echo.EchoRequest + for i := 0; i < count; i++ { + req, err = stream.Recv() + if err != nil { + klog.Infof("EchoClient: recv error = %v", err) + return + } + klog.Infof("EchoClient: recv req = %v", req) + } + resp.Message = strconv.Itoa(count) } + if err = stream.SendAndClose(resp); err != nil { klog.Infof("EchoClient: send&close error = %v", err) return diff --git a/thrift_streaming/thrift_test.go b/thrift_streaming/thrift_test.go index 8c42892..c9e7a71 100644 --- a/thrift_streaming/thrift_test.go +++ b/thrift_streaming/thrift_test.go @@ -18,6 +18,7 @@ import ( "context" "errors" "fmt" + "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2" "io" "reflect" "strconv" @@ -553,6 +554,59 @@ func TestKitexServerMiddleware(t *testing.T) { test.Assert(t, err == nil, err) test.Assert(t, resp.Message == "2_send_middleware", resp.Message) }) + + t.Run("gRPC GetServerConn", func(t *testing.T) { + svr := RunThriftServer(&EchoServiceImpl{}, addr, + server.WithMiddleware(func(next endpoint.Endpoint) endpoint.Endpoint { + return func(ctx context.Context, args, result interface{}) (err error) { + streamArg, ok := args.(*streaming.Args) + test.Assert(t, ok) + _, err = nphttp2.GetServerConn(streamArg.Stream) + test.Assert(t, err == nil, err) + return next(ctx, args, result) + } + }), + ) + defer svr.Stop() + + cli := echoservice.MustNewStreamClient("service", streamclient.WithHostPorts(addr)) + ctx := metainfo.WithValue(context.Background(), KeyGetServerConn, "true") + stream, err := cli.EchoClient(ctx) + test.Assert(t, err == nil, err) + + err = stream.Send(&echo.EchoRequest{Message: "GetServerConn"}) + test.Assert(t, err == nil, err) + + resp, err := stream.CloseAndRecv() + test.Assert(t, err == nil, err) + test.Assert(t, resp.Message == "GetServerConn Succeeded") + }) + + t.Run("process ctx in middleware and reflect to Stream.Context()", func(t *testing.T) { + svr := RunThriftServer(&EchoServiceImpl{}, addr, + server.WithMiddleware(func(next endpoint.Endpoint) endpoint.Endpoint { + return func(ctx context.Context, args, result interface{}) (err error) { + _, ok := args.(*streaming.Args) + test.Assert(t, ok) + ctx = context.WithValue(ctx, "key", "val") + return next(ctx, args, result) + } + }), + ) + defer svr.Stop() + + cli := echoservice.MustNewStreamClient("service", streamclient.WithHostPorts(addr)) + ctx := metainfo.WithValue(context.Background(), KeyInspectMWCtx, "true") + stream, err := cli.EchoClient(ctx) + test.Assert(t, err == nil, err) + + err = stream.Send(&echo.EchoRequest{Message: "InspectMWCtx"}) + test.Assert(t, err == nil, err) + + resp, err := stream.CloseAndRecv() + test.Assert(t, err == nil, err) + test.Assert(t, resp.Message == "InspectMWCtx Succeeded") + }) } func TestTimeoutRecvSend(t *testing.T) { diff --git a/thrift_streaming/util.go b/thrift_streaming/util.go index 96f9178..86cfff5 100644 --- a/thrift_streaming/util.go +++ b/thrift_streaming/util.go @@ -29,6 +29,8 @@ const ( KeyCount = "COUNT" KeyServerRecvTimeoutMS = "RECV_TIMEOUT_MS" KeyServerSendTimeoutMS = "SEND_TIMEOUT_MS" + KeyGetServerConn = "GET_SERVER_CONN" + KeyInspectMWCtx = "INSPECT_MW_CTX" ) func GetError(ctx context.Context) error { @@ -81,3 +83,10 @@ func GetInt(ctx context.Context, key string, defaultValue int) int { } return defaultValue } + +func GetBool(ctx context.Context, key string, defaultBool bool) bool { + if b, err := strconv.ParseBool(GetValue(ctx, key, "")); err == nil { + return b + } + return defaultBool +} From f4275c263b5744e6325e464e81cd6a1d5a0d144d Mon Sep 17 00:00:00 2001 From: "yuxuan.wang1" Date: Wed, 28 Aug 2024 09:24:53 +0800 Subject: [PATCH 2/2] adjust log level --- thrift_streaming/thrift_handler.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/thrift_streaming/thrift_handler.go b/thrift_streaming/thrift_handler.go index b84b562..5f8d1f8 100644 --- a/thrift_streaming/thrift_handler.go +++ b/thrift_streaming/thrift_handler.go @@ -63,7 +63,7 @@ func (e EchoServiceImpl) EchoClient(stream echo.EchoService_EchoClientServer) (e case doGetServerConn: _, err = nphttp2.GetServerConn(stream) if err != nil { - klog.Errorf("EchoClient: GetServerConn failed, error = %v", err) + klog.Infof("EchoClient: GetServerConn failed, error = %v", err) return } resp.Message = "GetServerConn Succeeded" @@ -71,7 +71,7 @@ func (e EchoServiceImpl) EchoClient(stream echo.EchoService_EchoClientServer) (e val, ok := stream.Context().Value("key").(string) if !ok || val != "val" { err = errors.New("can not get ctx value set in server MW") - klog.Errorf("EchoClient: InspectMWCtx failed, error = %v", err) + klog.Infof("EchoClient: InspectMWCtx failed, error = %v", err) return } resp.Message = "InspectMWCtx Succeeded"