diff --git a/internal/client/client_batch.go b/internal/client/client_batch.go index d29606f90a..5ba07c25ac 100644 --- a/internal/client/client_batch.go +++ b/internal/client/client_batch.go @@ -588,7 +588,7 @@ func (c *batchCommandsClient) send(forwardedHost string, req *tikvpb.BatchComman zap.String("forwardedHost", forwardedHost), zap.Error(err), ) - c.failPendingRequests(err) + c.failRequestsByIDs(err, req.RequestIds) // fast fail requests. return } @@ -604,23 +604,50 @@ func (c *batchCommandsClient) send(forwardedHost string, req *tikvpb.BatchComman zap.Uint64s("requestIDs", req.RequestIds), zap.Error(err), ) - c.failPendingRequests(err) + c.failRequestsByIDs(err, req.RequestIds) // fast fail requests. } } // `failPendingRequests` must be called in locked contexts in order to avoid double closing channels. -func (c *batchCommandsClient) failPendingRequests(err error) { +// when enable-forwarding is true, the `forwardedHost` maybe not empty. +// failPendingRequests fails all pending requests which req.forwardedHost equals to forwardedHost parameter. +// Why need check `forwardedHost`? Here is an example, when enable-forwarding is true, and this client has network issue with store1: +// - some requests are sent to store1 with forwarding, such as forwardedHost is store2, those requests will succeed. +// - some requests are sent to store1 without forwarding, and may fail then `failPendingRequests` would be called, +// if we don't check `forwardedHost` and fail all pending requests, the requests with forwarding will be failed too. this may cause some issue: +// 1. data race. see https://github.com/tikv/client-go/issues/1222 and TestRandomRestartStoreAndForwarding. +// 2. panic which cause by `send on closed channel`, since failPendingRequests will close the entry.res channel, +// but in another batchRecvLoop goroutine, it may receive the response from forwardedHost store2 and try to send the response to entry.res channel, +// then panic by send on closed channel. +func (c *batchCommandsClient) failPendingRequests(err error, forwardedHost string) { util.EvalFailpoint("panicInFailPendingRequests") c.batched.Range(func(key, value interface{}) bool { id, _ := key.(uint64) entry, _ := value.(*batchCommandsEntry) - c.batched.Delete(id) - c.sent.Add(-1) - entry.error(err) + if entry.forwardedHost == forwardedHost { + c.failRequest(err, id, entry) + } return true }) } +// failRequestsByIDs fails requests by requestID. +func (c *batchCommandsClient) failRequestsByIDs(err error, requestIDs []uint64) { + for _, requestID := range requestIDs { + value, ok := c.batched.Load(requestID) + if !ok { + continue + } + c.failRequest(err, requestID, value.(*batchCommandsEntry)) + } +} + +func (c *batchCommandsClient) failRequest(err error, requestID uint64, entry *batchCommandsEntry) { + c.batched.Delete(requestID) + c.sent.Add(-1) + entry.error(err) +} + func (c *batchCommandsClient) waitConnReady() (err error) { state := c.conn.GetState() if state == connectivity.Ready { @@ -793,7 +820,7 @@ func (c *batchCommandsClient) recreateStreamingClient(err error, streamClient *b } *epoch++ - c.failPendingRequests(err) // fail all pending requests. + c.failPendingRequests(err, streamClient.forwardedHost) // fail all pending requests. b := retry.NewBackofferWithVars(context.Background(), math.MaxInt32, nil) for { // try to re-create the streaming in the loop. if c.isStopped() { diff --git a/internal/client/client_test.go b/internal/client/client_test.go index 60d282ea74..3436406ecc 100644 --- a/internal/client/client_test.go +++ b/internal/client/client_test.go @@ -59,7 +59,6 @@ import ( "github.com/tikv/client-go/v2/internal/client/mockserver" "github.com/tikv/client-go/v2/internal/logutil" "github.com/tikv/client-go/v2/tikvrpc" - "github.com/tikv/client-go/v2/util/israce" "go.uber.org/zap" "google.golang.org/grpc/connectivity" "google.golang.org/grpc/metadata" @@ -888,9 +887,6 @@ func TestBatchClientReceiveHealthFeedback(t *testing.T) { } func TestRandomRestartStoreAndForwarding(t *testing.T) { - if israce.RaceEnabled { - t.Skip("skip since race bug in issue #1222") - } store1, port1 := mockserver.StartMockTikvService() require.True(t, port1 > 0) require.True(t, store1.IsRunning()) @@ -908,6 +904,8 @@ func TestRandomRestartStoreAndForwarding(t *testing.T) { wg := sync.WaitGroup{} done := int64(0) concurrency := 500 + addr1 := store1.Addr() + addr2 := store2.Addr() wg.Add(1) go func() { defer wg.Done() @@ -931,7 +929,7 @@ func TestRandomRestartStoreAndForwarding(t *testing.T) { } }() - conn, err := client1.getConnArray(store1.Addr(), true) + conn, err := client1.getConnArray(addr1, true) assert.Nil(t, err) for j := 0; j < concurrency; j++ { wg.Add(1) @@ -944,9 +942,9 @@ func TestRandomRestartStoreAndForwarding(t *testing.T) { req := &tikvpb.BatchCommandsRequest_Request{Cmd: &tikvpb.BatchCommandsRequest_Request_Coprocessor{Coprocessor: &coprocessor.Request{}}} forwardedHost := "" if i%2 != 0 { - forwardedHost = store2.Addr() + forwardedHost = addr2 } - _, err := sendBatchRequest(context.Background(), store1.Addr(), forwardedHost, conn.batchConn, req, time.Millisecond*50, 0) + _, err := sendBatchRequest(context.Background(), addr1, forwardedHost, conn.batchConn, req, time.Millisecond*50, 0) if err == nil || err.Error() == "EOF" || err.Error() == "rpc error: code = Unavailable desc = error reading from server: EOF" || @@ -964,11 +962,24 @@ func TestRandomRestartStoreAndForwarding(t *testing.T) { for _, cli := range conn.batchConn.batchCommandsClients { require.Equal(t, int64(9223372036854775807), cli.maxConcurrencyRequestLimit.Load()) require.True(t, cli.available() > 0, fmt.Sprintf("sent: %d", cli.sent.Load())) - // TODO(crazycs520): fix me, see https://github.com/tikv/client-go/pull/1219 - //require.True(t, cli.sent.Load() >= 0, fmt.Sprintf("sent: %d", cli.sent.Load())) + require.True(t, cli.sent.Load() >= 0, fmt.Sprintf("sent: %d", cli.sent.Load())) } } +func TestFastFailRequest(t *testing.T) { + client := NewRPCClient() + defer func() { + err := client.Close() + require.NoError(t, err) + }() + start := time.Now() + unknownAddr := "127.0.0.1:52027" + req := tikvrpc.NewRequest(tikvrpc.CmdGet, &kvrpcpb.GetRequest{Key: []byte("key")}) + _, err := client.sendRequest(context.Background(), unknownAddr, req, time.Second*20) + require.Equal(t, "context deadline exceeded", errors.Cause(err).Error()) + require.True(t, time.Since(start) < time.Second*6) // fast fail when dial target failed. +} + func TestErrConn(t *testing.T) { e := errors.New("conn error") err1 := &ErrConn{Err: e, Addr: "127.0.0.1", Ver: 10}