From 35dad74ea2fdca42f664d5b1c561ed4f9303bc08 Mon Sep 17 00:00:00 2001 From: gaowenju Date: Wed, 8 Nov 2023 18:56:45 +0800 Subject: [PATCH] test: add ut --- pkg/app/server/hertz_test.go | 8 ++++---- pkg/protocol/uri.go | 3 +++ pkg/protocol/uri_test.go | 37 ++++++++++++++++++++++++++++++++++++ 3 files changed, 44 insertions(+), 4 deletions(-) diff --git a/pkg/app/server/hertz_test.go b/pkg/app/server/hertz_test.go index 5baa3ecf5..036df76a5 100644 --- a/pkg/app/server/hertz_test.go +++ b/pkg/app/server/hertz_test.go @@ -259,7 +259,7 @@ func TestNotAbsolutePath(t *testing.T) { go engine.Run() time.Sleep(200 * time.Microsecond) - s := "POST ?a=b HTTP/1.1\r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343" + s := "POST ?a=b HTTP/1.1\r\nHost: a.b.c\r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343" zr := mock.NewZeroCopyReader(s) ctx := app.NewContext(0) @@ -270,7 +270,7 @@ func TestNotAbsolutePath(t *testing.T) { assert.DeepEqual(t, consts.StatusOK, ctx.Response.StatusCode()) assert.DeepEqual(t, ctx.Request.Body(), ctx.Response.Body()) - s = "POST a?a=b HTTP/1.1\r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343" + s = "POST a?a=b HTTP/1.1\r\nHost: a.b.c\r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343" zr = mock.NewZeroCopyReader(s) ctx = app.NewContext(0) @@ -291,7 +291,7 @@ func TestNotAbsolutePathWithRawPath(t *testing.T) { go engine.Run() time.Sleep(200 * time.Microsecond) - s := "POST ?a=b HTTP/1.1\r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343" + s := "POST ?a=b HTTP/1.1\r\nHost: a.b.c\r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343" zr := mock.NewZeroCopyReader(s) ctx := app.NewContext(0) @@ -302,7 +302,7 @@ func TestNotAbsolutePathWithRawPath(t *testing.T) { assert.DeepEqual(t, consts.StatusBadRequest, ctx.Response.StatusCode()) assert.DeepEqual(t, default400Body, ctx.Response.Body()) - s = "POST a?a=b HTTP/1.1\r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343" + s = "POST a?a=b HTTP/1.1\r\nHost: a.b.c\r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343" zr = mock.NewZeroCopyReader(s) ctx = app.NewContext(0) diff --git a/pkg/protocol/uri.go b/pkg/protocol/uri.go index 25edce466..ca483bb57 100644 --- a/pkg/protocol/uri.go +++ b/pkg/protocol/uri.go @@ -610,6 +610,9 @@ func (u *URI) updateBytes(newURI, buf []byte) []byte { if len(u.scheme) > 0 { schemeOriginal = append([]byte(nil), u.scheme...) } + if n == 0 { + newURI = bytes.Join([][]byte{u.scheme, bytestr.StrColon, newURI}, nil) + } u.Parse(nil, newURI) if len(schemeOriginal) > 0 && len(u.scheme) == 0 { u.scheme = append(u.scheme[:0], schemeOriginal...) diff --git a/pkg/protocol/uri_test.go b/pkg/protocol/uri_test.go index e245faf26..2e8109dc7 100644 --- a/pkg/protocol/uri_test.go +++ b/pkg/protocol/uri_test.go @@ -42,6 +42,7 @@ package protocol import ( + "bytes" "path/filepath" "reflect" "runtime" @@ -468,3 +469,39 @@ func TestParseURI(t *testing.T) { uri := string(ParseURI(expectURI).FullURI()) assert.DeepEqual(t, expectURI, uri) } + +func TestSplitHostURI(t *testing.T) { + cases := []struct { + host, uri []byte + wantScheme, wantHost, wantPath []byte + }{ + { + []byte("example.com"), []byte("/foobar"), + []byte("http"), []byte("example.com"), []byte("/foobar"), + }, + { + []byte("example2.com"), []byte("http://example2.com"), + []byte("http"), []byte("example2.com"), []byte("/"), + }, + { + []byte("example2.com"), []byte("http://example3.com"), + []byte("http"), []byte("example3.com"), []byte("/"), + }, + { + []byte("example3.com"), []byte("https://foobar.com?a=b"), + []byte("https"), []byte("foobar.com"), []byte("?a=b"), + }, + { + []byte("example.com"), []byte("//www.google.com/://../../ping"), + []byte("http"), []byte("example.com"), []byte("//www.google.com/://../../ping"), + }, + } + + for _, c := range cases { + gotScheme, gotHost, gotPath := splitHostURI(c.host, c.uri) + if !bytes.Equal(gotScheme, c.wantScheme) || !bytes.Equal(gotHost, c.wantHost) || !bytes.Equal(gotPath, c.wantPath) { + t.Errorf("splitHostURI(%q, %q) == (%q, %q, %q), want (%q, %q, %q)", + c.host, c.uri, gotScheme, gotHost, gotPath, c.wantScheme, c.wantHost, c.wantPath) + } + } +}