diff --git a/pkg/tidb/forwarder.go b/pkg/tidb/forwarder.go index 88d07e9ccf..33d9045ea5 100644 --- a/pkg/tidb/forwarder.go +++ b/pkg/tidb/forwarder.go @@ -7,6 +7,7 @@ import ( "fmt" "net" "strconv" + "sync" "time" "github.com/cenkalti/backoff/v4" @@ -37,6 +38,8 @@ type Forwarder struct { sqlPort int statusProxy *proxy statusPort int + + wg *sync.WaitGroup } func (f *Forwarder) Start(ctx context.Context) error { @@ -53,9 +56,10 @@ func (f *Forwarder) Start(ctx context.Context) error { f.sqlPort = f.sqlProxy.port() f.statusPort = f.statusProxy.port() + f.wg.Add(3) go f.pollingForTiDB() - go f.sqlProxy.run(ctx) - go f.statusProxy.run(ctx) + go f.sqlProxy.run(ctx, f.wg) + go f.statusProxy.run(ctx, f.wg) return nil } @@ -70,6 +74,7 @@ func (f *Forwarder) createProxy() (*proxy, error) { } func (f *Forwarder) pollingForTiDB() { + defer f.wg.Done() ebo := backoff.NewExponentialBackOff() ebo.MaxInterval = f.config.TiDBPollInterval bo := backoff.WithContext(ebo, f.lifecycleCtx) @@ -119,9 +124,14 @@ func newForwarder(lc fx.Lifecycle, etcdClient *clientv3.Client) *Forwarder { ProxyCheckInterval: 2 * time.Second, }, etcdClient: etcdClient, + wg: &sync.WaitGroup{}, } lc.Append(fx.Hook{ OnStart: f.Start, + OnStop: func(context.Context) error { + f.wg.Wait() + return nil + }, }) return f } diff --git a/pkg/tidb/proxy.go b/pkg/tidb/proxy.go index 3a2effa7b3..f502f78ab0 100644 --- a/pkg/tidb/proxy.go +++ b/pkg/tidb/proxy.go @@ -201,7 +201,8 @@ func (p *proxy) doCheck(ctx context.Context) { } } -func (p *proxy) run(ctx context.Context) { +func (p *proxy) run(ctx context.Context, wg *sync.WaitGroup) { + defer wg.Done() endpoints := make([]string, 0) p.remotes.Range(func(key, value interface{}) bool { r := value.(*remote) @@ -220,15 +221,24 @@ func (p *proxy) run(ctx context.Context) { } // serve for { + type accepted struct { + conn net.Conn + err error + } + accept := make(chan accepted, 1) + go func() { + conn, err := p.listener.Accept() + accept <- accepted{conn, err} + }() + select { case <-ctx.Done(): return - default: - incoming, err := p.listener.Accept() - if err != nil { - log.Warn("got err from listener", zap.Error(err), zap.String("from", p.listener.Addr().String())) + case a := <-accept: + if a.err != nil { + log.Warn("got err from listener", zap.Error(a.err), zap.String("from", p.listener.Addr().String())) } else { - go p.serve(incoming) + go p.serve(a.conn) } } } diff --git a/pkg/tidb/proxy_test.go b/pkg/tidb/proxy_test.go index e360624f4c..6804ca97e7 100644 --- a/pkg/tidb/proxy_test.go +++ b/pkg/tidb/proxy_test.go @@ -11,6 +11,7 @@ import ( "net/http/httptest" "net/url" "strconv" + "sync" "testing" "time" @@ -37,7 +38,9 @@ func TestProxy(t *testing.T) { } p := newProxy(l, map[string]string{"test": fmt.Sprintf("%s:%s", u.Hostname(), u.Port())}, 0, 0) ctx, cancel := context.WithCancel(context.Background()) - go p.run(ctx) + wg := &sync.WaitGroup{} + wg.Add(1) + go p.run(ctx, wg) defer cancel() u.Host = l.Addr().String() @@ -86,7 +89,9 @@ func TestProxyPick(t *testing.T) { } p := newProxy(l, endpoints, 0, 0) ctx, cancel := context.WithCancel(context.Background()) - go p.run(ctx) + wg := &sync.WaitGroup{} + wg.Add(1) + go p.run(ctx, wg) defer cancel() for i := 0; i < n; i++ { diff --git a/tests/integration/info/info_test.go b/tests/integration/info/info_test.go index d78e32edae..4684b18f48 100644 --- a/tests/integration/info/info_test.go +++ b/tests/integration/info/info_test.go @@ -4,6 +4,7 @@ package info import ( "bytes" + "context" "encoding/json" "net/http" "testing" @@ -42,7 +43,8 @@ func TestInfoSuite(t *testing.T) { fx.Populate(&infoService), fx.Populate(&codeService), ) - app.RequireStart() + ctx, cancel := context.WithCancel(context.Background()) + app.RequireStart(ctx) suite.Run(t, &testInfoSuite{ db: db, @@ -51,6 +53,8 @@ func TestInfoSuite(t *testing.T) { codeService: codeService, }) + // exit the app + cancel() app.RequireStop() } diff --git a/tests/integration/user/user_test.go b/tests/integration/user/user_test.go index c266cd6de6..116e77c869 100644 --- a/tests/integration/user/user_test.go +++ b/tests/integration/user/user_test.go @@ -4,6 +4,7 @@ package user import ( "bytes" + "context" "encoding/json" "net/http" "testing" @@ -41,7 +42,8 @@ func TestUserSuite(t *testing.T) { fx.Populate(&authService), fx.Populate(&infoService), ) - app.RequireStart() + ctx, cancel := context.WithCancel(context.Background()) + app.RequireStart(ctx) suite.Run(t, &testUserSuite{ db: db, @@ -49,6 +51,8 @@ func TestUserSuite(t *testing.T) { infoService: infoService, }) + // exit the app + cancel() app.RequireStop() } diff --git a/tests/util/mock_app.go b/tests/util/mock_app.go index 8f6cb7141b..39cad98481 100644 --- a/tests/util/mock_app.go +++ b/tests/util/mock_app.go @@ -39,8 +39,8 @@ func NewMockApp(tb fxtest.TB, tidbVersion string, c *config.Config, opts ...fx.O // RequireStart calls Start, failing the test if an error is encountered. // It also sleep 5 seconds to wait for the server to start. -func (app *App) RequireStart() *App { - if err := app.Start(context.Background()); err != nil { +func (app *App) RequireStart(ctx context.Context) *App { + if err := app.Start(ctx); err != nil { app.tb.Errorf("application didn't start cleanly: %v", err) app.tb.FailNow() }