diff --git a/pkg/protocol/uri.go b/pkg/protocol/uri.go index 4fd8788e5..73b5b984c 100644 --- a/pkg/protocol/uri.go +++ b/pkg/protocol/uri.go @@ -49,7 +49,6 @@ import ( "github.com/cloudwego/hertz/internal/bytesconv" "github.com/cloudwego/hertz/internal/bytestr" "github.com/cloudwego/hertz/internal/nocopy" - "github.com/cloudwego/hertz/pkg/common/hlog" ) // AcquireURI returns an empty URI instance from the pool. @@ -388,11 +387,7 @@ func getScheme(rawURL []byte) (scheme, path []byte) { return nil, rawURL } case c == ':': - if i == 0 { - hlog.Errorf("error happened when try to parse the rawURL(%s): missing protocol scheme", rawURL) - return nil, nil - } - return rawURL[:i], rawURL[i+1:] + return checkSchemeWhenCharIsColon(i, rawURL) default: // we have encountered an invalid character, // so there is no valid scheme diff --git a/pkg/protocol/uri_unix.go b/pkg/protocol/uri_unix.go index 0127ceef0..d3726d8aa 100644 --- a/pkg/protocol/uri_unix.go +++ b/pkg/protocol/uri_unix.go @@ -44,6 +44,8 @@ package protocol +import "github.com/cloudwego/hertz/pkg/common/hlog" + func addLeadingSlash(dst, src []byte) []byte { // add leading slash for unix paths if len(src) == 0 || src[0] != '/' { @@ -52,3 +54,13 @@ func addLeadingSlash(dst, src []byte) []byte { return dst } + +// checkSchemeWhenCharIsColon check url begin with : +// Scenarios that handle protocols like "http:" +func checkSchemeWhenCharIsColon(i int, rawURL []byte) (scheme, path []byte) { + if i == 0 { + hlog.Errorf("error happened when try to parse the rawURL(%s): missing protocol scheme", rawURL) + return + } + return rawURL[:i], rawURL[i+1:] +} diff --git a/pkg/protocol/uri_unix_test.go b/pkg/protocol/uri_unix_test.go new file mode 100644 index 000000000..f89cd112c --- /dev/null +++ b/pkg/protocol/uri_unix_test.go @@ -0,0 +1,44 @@ +//go:build !windows +// +build !windows + +/* + * Copyright 2023 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package protocol + +import ( + "testing" + + "github.com/cloudwego/hertz/pkg/common/test/assert" +) + +func TestGetScheme(t *testing.T) { + scheme, path := getScheme([]byte("https://foo.com")) + assert.DeepEqual(t, "https", string(scheme)) + assert.DeepEqual(t, "//foo.com", string(path)) + + scheme, path = getScheme([]byte(":")) + assert.DeepEqual(t, "", string(scheme)) + assert.DeepEqual(t, "", string(path)) + + scheme, path = getScheme([]byte("ws://127.0.0.1")) + assert.DeepEqual(t, "ws", string(scheme)) + assert.DeepEqual(t, "//127.0.0.1", string(path)) + + scheme, path = getScheme([]byte("/hertz/demo")) + assert.DeepEqual(t, "", string(scheme)) + assert.DeepEqual(t, "/hertz/demo", string(path)) +} diff --git a/pkg/protocol/uri_windows.go b/pkg/protocol/uri_windows.go index abf13e72f..2e0bf9df8 100644 --- a/pkg/protocol/uri_windows.go +++ b/pkg/protocol/uri_windows.go @@ -44,6 +44,8 @@ package protocol +import "github.com/cloudwego/hertz/pkg/common/hlog" + func addLeadingSlash(dst, src []byte) []byte { // zero length and "C:/" case isDisk := len(src) > 2 && src[1] == ':' @@ -53,3 +55,20 @@ func addLeadingSlash(dst, src []byte) []byte { return dst } + +// checkSchemeWhenCharIsColon check url begin with : +// Scenarios that handle protocols like "http:" +// Add the path to the win file, e.g. "E:\gopath", "E:\". +func checkSchemeWhenCharIsColon(i int, rawURL []byte) (scheme, path []byte) { + if i == 0 { + hlog.Errorf("error happened when trying to parse the rawURL(%s): missing protocol scheme", rawURL) + return + } + + // case :\ + if i+1 < len(rawURL) && rawURL[i+1] == '\\' { + return nil, rawURL + } + + return rawURL[:i], rawURL[i+1:] +} diff --git a/pkg/protocol/uri_windows_test.go b/pkg/protocol/uri_windows_test.go index 507924b97..0ec5e4e78 100644 --- a/pkg/protocol/uri_windows_test.go +++ b/pkg/protocol/uri_windows_test.go @@ -14,7 +14,11 @@ package protocol -import "testing" +import ( + "testing" + + "github.com/cloudwego/hertz/pkg/common/test/assert" +) func TestURIPathNormalizeIssue86(t *testing.T) { t.Parallel() @@ -26,3 +30,25 @@ func TestURIPathNormalizeIssue86(t *testing.T) { testURIPathNormalize(t, &u, "/..\\..\\..\\..\\..\\", "/") testURIPathNormalize(t, &u, "/..%5c..%5cfoo", "/foo") } + +func TestGetScheme(t *testing.T) { + scheme, path := getScheme([]byte("E:\\file.go")) + assert.DeepEqual(t, "", string(scheme)) + assert.DeepEqual(t, "E:\\file.go", string(path)) + + scheme, path = getScheme([]byte("E:\\")) + assert.DeepEqual(t, "", string(scheme)) + assert.DeepEqual(t, "E:\\", string(path)) + + scheme, path = getScheme([]byte("https://foo.com")) + assert.DeepEqual(t, "https", string(scheme)) + assert.DeepEqual(t, "//foo.com", string(path)) + + scheme, path = getScheme([]byte("://")) + assert.DeepEqual(t, "", string(scheme)) + assert.DeepEqual(t, "", string(path)) + + scheme, path = getScheme([]byte("ws://127.0.0.1")) + assert.DeepEqual(t, "ws", string(scheme)) + assert.DeepEqual(t, "//127.0.0.1", string(path)) +}