diff --git a/internal/client/client_batch.go b/internal/client/client_batch.go index 5e6d0e97b4..a4d72486c7 100644 --- a/internal/client/client_batch.go +++ b/internal/client/client_batch.go @@ -588,6 +588,7 @@ func (c *batchCommandsClient) send(forwardedHost string, req *tikvpb.BatchComman zap.String("forwardedHost", forwardedHost), zap.Error(err), ) + c.fastFailRequests(err, req.RequestIds) return } @@ -603,6 +604,7 @@ func (c *batchCommandsClient) send(forwardedHost string, req *tikvpb.BatchComman zap.Uint64s("requestIDs", req.RequestIds), zap.Error(err), ) + c.fastFailRequests(err, req.RequestIds) } } @@ -623,14 +625,29 @@ func (c *batchCommandsClient) failPendingRequests(err error, forwardedHost strin id, _ := key.(uint64) entry, _ := value.(*batchCommandsEntry) if entry.forwardedHost == forwardedHost { - c.batched.Delete(id) - c.sent.Add(-1) - entry.error(err) + c.failRequest(err, id, entry) } return true }) } +// fastFailRequests fast fails requests by requestID. +func (c *batchCommandsClient) fastFailRequests(err error, requestIDs []uint64) { + for _, requestID := range requestIDs { + value, ok := c.batched.Load(requestID) + if !ok { + continue + } + c.failRequest(err, requestID, value.(*batchCommandsEntry)) + } +} + +func (c *batchCommandsClient) failRequest(err error, requestID uint64, entry *batchCommandsEntry) { + c.batched.Delete(requestID) + c.sent.Add(-1) + entry.error(err) +} + func (c *batchCommandsClient) waitConnReady() (err error) { state := c.conn.GetState() if state == connectivity.Ready { diff --git a/internal/client/client_test.go b/internal/client/client_test.go index 56b0b398c1..85734d7189 100644 --- a/internal/client/client_test.go +++ b/internal/client/client_test.go @@ -931,6 +931,8 @@ func TestRandomRestartStoreAndForwarding(t *testing.T) { conn, err := client1.getConnArray(addr1, true) assert.Nil(t, err) + count := int64(0) + start := time.Now() for j := 0; j < concurrency; j++ { wg.Add(1) go func() { @@ -945,6 +947,7 @@ func TestRandomRestartStoreAndForwarding(t *testing.T) { forwardedHost = addr2 } _, err := sendBatchRequest(context.Background(), addr1, forwardedHost, conn.batchConn, req, time.Millisecond*50, 0) + atomic.AddInt64(&count, 1) if err == nil || err.Error() == "EOF" || err.Error() == "rpc error: code = Unavailable desc = error reading from server: EOF" || @@ -958,6 +961,11 @@ func TestRandomRestartStoreAndForwarding(t *testing.T) { }() } wg.Wait() + qps := int64(float64(atomic.LoadInt64(&count)) / time.Since(start).Seconds()) + logutil.BgLogger().Info("TestRandomRestartStoreAndForwarding QPS", + zap.Int64("qps", qps), + zap.Int64("qps/concurrency", qps/int64(concurrency)), + zap.Duration("cost", time.Since(start))) for _, cli := range conn.batchConn.batchCommandsClients { require.Equal(t, int64(9223372036854775807), cli.maxConcurrencyRequestLimit.Load()) @@ -966,6 +974,21 @@ func TestRandomRestartStoreAndForwarding(t *testing.T) { } } +func TestFastFailRequest(t *testing.T) { + client1 := NewRPCClient() + defer func() { + err := client1.Close() + require.NoError(t, err) + }() + + start := time.Now() + unknownAddr := "127.0.0.1:52027" + req := tikvrpc.NewRequest(tikvrpc.CmdGet, &kvrpcpb.GetRequest{Key: []byte("key")}) + _, err := client1.sendRequest(context.Background(), unknownAddr, req, time.Second*20) + require.Equal(t, "context deadline exceeded", errors.Cause(err).Error()) + require.True(t, time.Since(start) < time.Second*6) // fast fail when dial target failed. +} + func TestErrConn(t *testing.T) { e := errors.New("conn error") err1 := &ErrConn{Err: e, Addr: "127.0.0.1", Ver: 10}