From 1c44dbeeb85a12800b42d902c00b83fe50e6f40e Mon Sep 17 00:00:00 2001 From: kevin Date: Sun, 29 Dec 2024 23:46:37 +0800 Subject: [PATCH] fix: service group not working well when callback takes long time Signed-off-by: kevin --- core/proc/shutdown.go | 3 +++ core/proc/shutdown_test.go | 37 ++++++++++++++++++++++++++++++++++++ core/service/servicegroup.go | 7 ++++++- gateway/server.go | 18 ++++++++++++++++++ gateway/server_test.go | 2 +- 5 files changed, 65 insertions(+), 2 deletions(-) diff --git a/core/proc/shutdown.go b/core/proc/shutdown.go index 763742a298df3..71ce429a4b0a6 100644 --- a/core/proc/shutdown.go +++ b/core/proc/shutdown.go @@ -82,6 +82,9 @@ func (lm *listenerManager) addListener(fn func()) (waitForCalled func()) { }) lm.lock.Unlock() + // we can return lm.waitGroup.Wait directly, + // but we want to make the returned func more readable. + // creating an extra closure would be negligible in practice. return func() { lm.waitGroup.Wait() } diff --git a/core/proc/shutdown_test.go b/core/proc/shutdown_test.go index 79c7deffa8af4..64517f0fb9274 100644 --- a/core/proc/shutdown_test.go +++ b/core/proc/shutdown_test.go @@ -3,6 +3,7 @@ package proc import ( + "sync/atomic" "testing" "time" @@ -29,6 +30,42 @@ func TestShutdown(t *testing.T) { assert.Equal(t, 3, val) } +func TestShutdownWithMultipleServices(t *testing.T) { + SetTimeToForceQuit(time.Hour) + assert.Equal(t, time.Hour, delayTimeBeforeForceQuit) + + var val int32 + called1 := AddShutdownListener(func() { + atomic.AddInt32(&val, 1) + }) + called2 := AddShutdownListener(func() { + atomic.AddInt32(&val, 2) + }) + Shutdown() + called1() + called2() + + assert.Equal(t, int32(3), atomic.LoadInt32(&val)) +} + +func TestWrapUpWithMultipleServices(t *testing.T) { + SetTimeToForceQuit(time.Hour) + assert.Equal(t, time.Hour, delayTimeBeforeForceQuit) + + var val int32 + called1 := AddWrapUpListener(func() { + atomic.AddInt32(&val, 1) + }) + called2 := AddWrapUpListener(func() { + atomic.AddInt32(&val, 2) + }) + WrapUp() + called1() + called2() + + assert.Equal(t, int32(3), atomic.LoadInt32(&val)) +} + func TestNotifyMoreThanOnce(t *testing.T) { ch := make(chan struct{}, 1) diff --git a/core/service/servicegroup.go b/core/service/servicegroup.go index 031abb5d30402..9281fb299a483 100644 --- a/core/service/servicegroup.go +++ b/core/service/servicegroup.go @@ -76,9 +76,14 @@ func (sg *ServiceGroup) doStart() { } func (sg *ServiceGroup) doStop() { + group := threading.NewRoutineGroup() for _, service := range sg.services { - service.Stop() + // new variable to avoid closure problems, can be removed after go 1.22 + // see https://golang.org/doc/faq#closures_and_goroutines + service := service + group.Run(service.Stop) } + group.Wait() } // WithStart wraps a start func as a Service. diff --git a/gateway/server.go b/gateway/server.go index 71d1e554eff00..a03e1bcecc47d 100644 --- a/gateway/server.go +++ b/gateway/server.go @@ -11,6 +11,7 @@ import ( "github.com/jhump/protoreflect/grpcreflect" "github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/core/mr" + "github.com/zeromicro/go-zero/core/threading" "github.com/zeromicro/go-zero/gateway/internal" "github.com/zeromicro/go-zero/rest" "github.com/zeromicro/go-zero/rest/httpx" @@ -23,6 +24,7 @@ type ( Server struct { *rest.Server upstreams []Upstream + conns []zrpc.Client processHeader func(http.Header) []string dialer func(conf zrpc.RpcClientConf) zrpc.Client } @@ -52,7 +54,22 @@ func (s *Server) Start() { // Stop stops the gateway server. func (s *Server) Stop() { + // stop the HTTP server first, then close gRPC connections. + // in case of the gRPC server is stopped first, + // the HTTP server may still be running to accept requests. s.Server.Stop() + + group := threading.NewRoutineGroup() + for _, conn := range s.conns { + // new variable to avoid closure problems, can be removed after go 1.22 + // see https://golang.org/doc/faq#closures_and_goroutines + conn := conn + group.Run(func() { + // ignore the error when closing the connection + _ = conn.Conn().Close() + }) + } + group.Wait() } func (s *Server) build() error { @@ -71,6 +88,7 @@ func (s *Server) build() error { } else { cli = zrpc.MustNewClient(up.Grpc) } + s.conns = append(s.conns, cli) source, err := s.createDescriptorSource(cli, up) if err != nil { diff --git a/gateway/server_test.go b/gateway/server_test.go index 74168559a5739..68b56ade80afe 100644 --- a/gateway/server_test.go +++ b/gateway/server_test.go @@ -46,7 +46,7 @@ func dialer() func(context.Context, string) (net.Conn, error) { func TestMustNewServer(t *testing.T) { var c GatewayConf assert.NoError(t, conf.FillDefault(&c)) - // avoid popup alert on macos for asking permissions + // avoid popup alert on MacOS for asking permissions c.DevServer.Host = "localhost" c.Host = "localhost" c.Port = 18881