diff --git a/http2/clientconn_test.go b/http2/clientconn_test.go index 36f080b9bf..7aae704b34 100644 --- a/http2/clientconn_test.go +++ b/http2/clientconn_test.go @@ -10,13 +10,16 @@ package http2 import ( "bytes" "context" + "crypto/tls" "fmt" "io" + "net" "net/http" "os" "reflect" "runtime" "slices" + "sync" "sync/atomic" "testing" "time" @@ -81,6 +84,56 @@ func TestTestClientConn(t *testing.T) { rt.wantBody(nil) } +// TestConnectTimeout tests that a request does not exceed request timeout + dial timeout +func TestConnectTimeout(t *testing.T) { + tr := &Transport{ + DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) { + // mock a net dialler with 1s timeout, encountering network issue + // keeping dialing until timeout + var dialer = net.Dialer{Timeout: time.Duration(-1)} + select { + case <-time.After(time.Second): + case <-ctx.Done(): + } + return dialer.DialContext(ctx, network, addr) + }, + AllowHTTP: true, + } + + var sg sync.WaitGroup + parentCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + + for j := 0; j < 2; j++ { + sg.Add(1) + go func() { + for i := 0; i < 10000; i++ { + sg.Add(1) + go func() { + ctx, _ := context.WithTimeout(parentCtx, time.Second) + req, err := http.NewRequestWithContext(ctx, "GET", "http://127.0.0.1:80", nil) + if err != nil { + t.Errorf("NewRequest: %v", err) + } + + start := time.Now() + tr.RoundTrip(req) + duration := time.Since(start) + // duration should not exceed request timeout + dial timeout + if duration > 2*time.Second { + t.Errorf("RoundTrip took %s; want <2s", duration.String()) + } + sg.Done() + }() + time.Sleep(1 * time.Millisecond) + } + sg.Done() + }() + } + + sg.Wait() +} + // A testClientConn allows testing ClientConn.RoundTrip against a fake server. // // A test using testClientConn consists of: