Skip to content

Commit

Permalink
feat: register service with service level middleware (#1510)
Browse files Browse the repository at this point in the history
  • Loading branch information
joway authored Sep 10, 2024
1 parent efd5eca commit 240f4ab
Show file tree
Hide file tree
Showing 11 changed files with 258 additions and 188 deletions.
3 changes: 3 additions & 0 deletions internal/server/register_option.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package server

import "github.com/cloudwego/kitex/pkg/endpoint"

// RegisterOption is the only way to config service registration.
type RegisterOption struct {
F func(o *RegisterOptions)
Expand All @@ -24,6 +26,7 @@ type RegisterOption struct {
// RegisterOptions is used to config service registration.
type RegisterOptions struct {
IsFallbackService bool
Middlewares []endpoint.Middleware
}

// NewRegisterOptions creates a register options.
Expand Down
21 changes: 14 additions & 7 deletions pkg/acl/acl.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,20 +29,27 @@ import (
// Returns a reason if rejected, otherwise returns nil.
type RejectFunc func(ctx context.Context, request interface{}) (reason error)

func ApplyRules(ctx context.Context, request interface{}, rules []RejectFunc) error {
for _, r := range rules {
if err := r(ctx, request); err != nil {
if !errors.Is(err, kerrors.ErrACL) {
err = kerrors.ErrACL.WithCause(err)
}
return err
}
}
return nil
}

// NewACLMiddleware creates a new ACL middleware using the provided reject funcs.
func NewACLMiddleware(rules []RejectFunc) endpoint.Middleware {
if len(rules) == 0 {
return endpoint.DummyMiddleware
}
return func(next endpoint.Endpoint) endpoint.Endpoint {
return func(ctx context.Context, request, response interface{}) error {
for _, r := range rules {
if e := r(ctx, request); e != nil {
if !errors.Is(e, kerrors.ErrACL) {
e = kerrors.ErrACL.WithCause(e)
}
return e
}
if err := ApplyRules(ctx, request, rules); err != nil {
return err
}
return next(ctx, request, response)
}
Expand Down
1 change: 0 additions & 1 deletion server/invoke.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ func (s *tInvoker) Init() (err error) {
if len(s.server.svcs.svcMap) == 0 {
return errors.New("run: no service. Use RegisterService to set one")
}
s.buildFullInvokeChain()
s.initBasicRemoteOption()
// for server trans info handler
if len(s.server.opt.MetaHandlers) > 0 {
Expand Down
34 changes: 16 additions & 18 deletions server/middlewares.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,26 +24,24 @@ import (
"github.com/cloudwego/kitex/pkg/streaming"
)

func serverTimeoutMW(initCtx context.Context) endpoint.Middleware {
return func(next endpoint.Endpoint) endpoint.Endpoint {
return func(ctx context.Context, request, response interface{}) (err error) {
// Regardless of the underlying protocol, only by checking the RPCTimeout
// For TTHeader, it will be set by transmeta.ServerTTHeaderHandler (not added by default though)
// For GRPC/HTTP2, the timeout deadline is already set in the context, so no need to check it
ri := rpcinfo.GetRPCInfo(ctx)
timeout := ri.Config().RPCTimeout()
if timeout <= 0 {
return next(ctx, request, response)
}

ctx, cancel := context.WithTimeout(ctx, timeout)
defer func() {
if err != nil {
cancel()
}
}()
func serverTimeoutMW(next endpoint.Endpoint) endpoint.Endpoint {
return func(ctx context.Context, request, response interface{}) (err error) {
// Regardless of the underlying protocol, only by checking the RPCTimeout
// For TTHeader, it will be set by transmeta.ServerTTHeaderHandler (not added by default though)
// For GRPC/HTTP2, the timeout deadline is already set in the context, so no need to check it
ri := rpcinfo.GetRPCInfo(ctx)
timeout := ri.Config().RPCTimeout()
if timeout <= 0 {
return next(ctx, request, response)
}

ctx, cancel := context.WithTimeout(ctx, timeout)
defer func() {
if err != nil {
cancel()
}
}()
return next(ctx, request, response)
}
}

Expand Down
11 changes: 5 additions & 6 deletions server/middlewares_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,13 @@ func Test_serverTimeoutMW(t *testing.T) {
ri := rpcinfo.NewRPCInfo(from, to, nil, cfg, nil)
return rpcinfo.NewCtxWithRPCInfo(context.Background(), ri)
}
timeoutMW := serverTimeoutMW(context.Background())

t.Run("no_timeout(fastPath)", func(t *testing.T) {
// prepare
ctx := newCtxWithRPCInfo(0)

// test
err := timeoutMW(func(ctx context.Context, req, resp interface{}) (err error) {
err := serverTimeoutMW(func(ctx context.Context, req, resp interface{}) (err error) {
ddl, ok := ctx.Deadline()
test.Assert(t, !ok)
test.Assert(t, ddl.IsZero())
Expand All @@ -87,7 +86,7 @@ func Test_serverTimeoutMW(t *testing.T) {
waitFinish := make(chan struct{})

// test
err := timeoutMW(func(ctx context.Context, req, resp interface{}) (err error) {
err := serverTimeoutMW(func(ctx context.Context, req, resp interface{}) (err error) {
go func() {
timer := time.NewTimer(time.Millisecond * 20)
select {
Expand All @@ -112,7 +111,7 @@ func Test_serverTimeoutMW(t *testing.T) {
waitFinish := make(chan struct{})

// test
err := timeoutMW(func(ctx context.Context, req, resp interface{}) (err error) {
err := serverTimeoutMW(func(ctx context.Context, req, resp interface{}) (err error) {
go func() {
timer := time.NewTimer(time.Millisecond * 20)
select {
Expand Down Expand Up @@ -141,7 +140,7 @@ func Test_serverTimeoutMW(t *testing.T) {
waitFinish := make(chan struct{})

// test
err := timeoutMW(func(ctx context.Context, req, resp interface{}) (err error) {
err := serverTimeoutMW(func(ctx context.Context, req, resp interface{}) (err error) {
go func() {
timer := time.NewTimer(time.Millisecond * 60)
select {
Expand Down Expand Up @@ -171,7 +170,7 @@ func Test_serverTimeoutMW(t *testing.T) {
waitFinish := make(chan struct{})

// test
err := timeoutMW(func(ctx context.Context, req, resp interface{}) (err error) {
err := serverTimeoutMW(func(ctx context.Context, req, resp interface{}) (err error) {
go func() {
timer := time.NewTimer(time.Millisecond * 60)
select {
Expand Down
1 change: 1 addition & 0 deletions server/option_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ func TestOptionDebugInfo(t *testing.T) {
}))

svr := NewServer(opts...)
svr.(*server).init()

// check probe result
pp := md.ProbePairs()
Expand Down
12 changes: 12 additions & 0 deletions server/register_option.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package server

import (
internal_server "github.com/cloudwego/kitex/internal/server"
"github.com/cloudwego/kitex/pkg/endpoint"
)

// RegisterOption is the only way to config service registration.
Expand All @@ -26,6 +27,17 @@ type RegisterOption = internal_server.RegisterOption
// RegisterOptions is used to config service registration.
type RegisterOptions = internal_server.RegisterOptions

// WithServiceMiddleware add middleware for a single service
// The execution order of middlewares follows:
// - server middlewares
// - service middlewares
// - service handler
func WithServiceMiddleware(mw endpoint.Middleware) RegisterOption {
return RegisterOption{F: func(o *internal_server.RegisterOptions) {
o.Middlewares = append(o.Middlewares, mw)
}}
}

func WithFallbackService() RegisterOption {
return RegisterOption{F: func(o *internal_server.RegisterOptions) {
o.IsFallbackService = true
Expand Down
Loading

0 comments on commit 240f4ab

Please sign in to comment.