diff --git a/core/search/tree.go b/core/search/tree.go index c386660ce6b8..fdb8288165c1 100644 --- a/core/search/tree.go +++ b/core/search/tree.go @@ -3,6 +3,7 @@ package search import ( "errors" "fmt" + "strings" ) const ( @@ -80,17 +81,21 @@ func (t *Tree) Add(route string, item any) error { } // Search searches item that associates with given route. -func (t *Tree) Search(route string) (Result, bool) { +func (t *Tree) Search(route string, routeCaseSensitive bool) (Result, bool) { if len(route) == 0 || route[0] != slash { return NotFound, false } var result Result - ok := t.next(t.root, route[1:], &result) + ok := t.next(t.root, route[1:], routeCaseSensitive, &result) return result, ok } -func (t *Tree) next(n *node, route string, result *Result) bool { +func (t *Tree) next(n *node, route string, routeCaseSensitive bool, result *Result) bool { + if !routeCaseSensitive { + route = strings.ToLower(route) + } + if len(route) == 0 && n.item != nil { result.Item = n.item return true @@ -102,9 +107,16 @@ func (t *Tree) next(n *node, route string, result *Result) bool { } token := route[:i] + return n.forEach(func(k string, v *node) bool { + if !routeCaseSensitive { + k = strings.ToLower(k) + } + + fmt.Println(token, k) + r := match(k, token) - if !r.found || !t.next(v, route[i+1:], result) { + if !r.found || !t.next(v, route[i+1:], routeCaseSensitive, result) { return false } if r.named { @@ -116,6 +128,10 @@ func (t *Tree) next(n *node, route string, result *Result) bool { } return n.forEach(func(k string, v *node) bool { + if !routeCaseSensitive { + k = strings.ToLower(k) + } + if r := match(k, route); r.found && v.item != nil { result.Item = v.item if r.named { diff --git a/core/search/tree_test.go b/core/search/tree_test.go index 1bad563a62ae..7418fa4440d1 100644 --- a/core/search/tree_test.go +++ b/core/search/tree_test.go @@ -98,7 +98,7 @@ func TestSearch(t *testing.T) { for _, r := range routes { tree.Add(r.route, r.value) } - result, ok := tree.Search(test.query) + result, ok := tree.Search(test.query, true) assert.Equal(t, test.contains, ok) if ok { actual := result.Item.(int) @@ -122,7 +122,7 @@ func TestStrictSearch(t *testing.T) { } for i := 0; i < 1000; i++ { - result, ok := tree.Search(query) + result, ok := tree.Search(query, true) assert.True(t, ok) assert.Equal(t, 1, result.Item.(int)) } @@ -142,11 +142,22 @@ func TestStrictSearchSibling(t *testing.T) { tree.Add(r.route, r.value) } - result, ok := tree.Search(query) + result, ok := tree.Search(query, true) assert.True(t, ok) assert.Equal(t, 3, result.Item.(int)) } +func TestRoutePathCaseInsensitive(t *testing.T) { + tree := NewTree() + tree.Add("/api/:user/profile/name", 1) + + query := "/aPI/123/prOfiLe/NaMe" + + result, ok := tree.Search(query, false) + assert.True(t, ok) + assert.Equal(t, 1, result.Item.(int)) +} + func TestAddDuplicate(t *testing.T) { tree := NewTree() err := tree.Add("/a/b", 1) @@ -163,7 +174,7 @@ func TestPlain(t *testing.T) { assert.Nil(t, err) err = tree.Add("/a/c", 2) assert.Nil(t, err) - _, ok := tree.Search("/a/d") + _, ok := tree.Search("/a/d", true) assert.False(t, ok) } @@ -227,6 +238,6 @@ func BenchmarkSearchTree(b *testing.B) { } for i := 0; i < b.N; i++ { - tree.Search(query) + tree.Search(query, true) } } diff --git a/go.mod b/go.mod index 1d8026cac025..131761eda7df 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,8 @@ module github.com/zeromicro/go-zero -go 1.20 +go 1.21 + +toolchain go1.23.1 require ( github.com/DATA-DOG/go-sqlmock v1.5.2 diff --git a/go.sum b/go.sum index cedb04f7ad24..d27c3ea9d2ce 100644 --- a/go.sum +++ b/go.sum @@ -7,10 +7,13 @@ github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302/go.mod h1:SGn github.com/alicebob/miniredis/v2 v2.33.0 h1:uvTF0EDeu9RLnUEG27Db5I68ESoIxTiXbNUiji6lZrA= github.com/alicebob/miniredis/v2 v2.33.0/go.mod h1:MhP4a3EU7aENRi9aO+tHfTBZicLqQevyi/DJpoj6mi0= github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8= +github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= +github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= +github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= github.com/bufbuild/protocompile v0.14.1 h1:iA73zAf/fyljNjQKwYzUHD6AD4R8KMasmwa/FBatYVw= github.com/bufbuild/protocompile v0.14.1/go.mod h1:ppVdAIhbr2H8asPk6k4pY7t9zB1OU5DoEw9xY/FUi1c= github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= @@ -55,6 +58,7 @@ github.com/go-openapi/swag v0.22.4/go.mod h1:UzaqsxGiab7freDnrUUra0MwWfN/q7tE4j+ github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= +github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= @@ -75,6 +79,7 @@ github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/ github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1 h1:K6RDEckDVWvDI9JAJYCmNdQXq6neHJOYx3V6jnqNEec= +github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0 h1:bkypFPDjIYGfCYD5mRBvpqxfYX1YCS1PXdKYWi8FsN0= @@ -102,6 +107,7 @@ github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2 github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= @@ -132,15 +138,19 @@ github.com/nbio/st v0.0.0-20140626010706-e9e8d9816f32/go.mod h1:9wM+0iRr9ahx58uY github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= github.com/onsi/ginkgo/v2 v2.13.0 h1:0jY9lJquiL8fcf3M4LAXN5aMlS/b2BV86HFFPCPMgE4= +github.com/onsi/ginkgo/v2 v2.13.0/go.mod h1:TE309ZR8s5FsKKpuB1YAQYBzCaAfUgatB/xlT/ETL/o= github.com/onsi/gomega v1.29.0 h1:KIA/t2t5UBzoirT4H9tsML45GEbo3ouUnBHsCfD2tVg= +github.com/onsi/gomega v1.29.0/go.mod h1:9sxs+SwGrKI0+PWe4Fxa9tFQQBG5xSsSbMXOI8PPpoQ= github.com/openzipkin/zipkin-go v0.4.3 h1:9EGwpqkgnwdEIJ+Od7QVSEIH+ocmm5nPat0G7sjsSdg= github.com/openzipkin/zipkin-go v0.4.3/go.mod h1:M9wCJZFWCo2RiY+o1eBCEMe0Dp2S5LDHcMZmk3RmK7c= github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prashantv/gostub v1.1.0 h1:BTyx3RfQjRHnUWaGF9oQos79AlQ5k8WNktv7VGvVH4g= +github.com/prashantv/gostub v1.1.0/go.mod h1:A5zLQHz7ieHGG7is6LLXLz7I8+3LZzsrV0P1IAHhP5U= github.com/prometheus/client_golang v1.20.4 h1:Tgh3Yr67PaOv/uTqloMsCEdeuFTatm5zIq5+qNN23vI= github.com/prometheus/client_golang v1.20.4/go.mod h1:PIEt8X02hGcP8JWbeHyeZ53Y/jReSnHgO035n//V5WE= github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E= @@ -154,9 +164,11 @@ github.com/redis/go-redis/v9 v9.6.1/go.mod h1:0C0c6ycQsdpVNQpxb1njEQIqkx5UcsM8FJ github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rogpeppe/go-internal v1.10.0 h1:TMyTOH3F/DB16zRVcYyreMH6GnZZrwQVAoYjRBZyWFQ= +github.com/rogpeppe/go-internal v1.10.0/go.mod h1:UQnix2H7Ngw/k4C5ijL5+65zddjncjaFoBhdsK/akog= github.com/spaolacci/murmur3 v1.1.0 h1:7c1g84S4BPRrfL5Xrdp6fOJ206sU9y293DDHaoy0bLI= github.com/spaolacci/murmur3 v1.1.0/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= +github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= @@ -284,6 +296,7 @@ golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4f golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg= +golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/rest/config.go b/rest/config.go index eb5fdb0ba234..da43eab0b2a3 100644 --- a/rest/config.go +++ b/rest/config.go @@ -44,13 +44,14 @@ type ( // if with the name Conf, there will be two Conf inside Config. RestConf struct { service.ServiceConf - Host string `json:",default=0.0.0.0"` - Port int - CertFile string `json:",optional"` - KeyFile string `json:",optional"` - Verbose bool `json:",optional"` - MaxConns int `json:",default=10000"` - MaxBytes int64 `json:",default=1048576"` + Host string `json:",default=0.0.0.0"` + Port int + CertFile string `json:",optional"` + KeyFile string `json:",optional"` + Verbose bool `json:",optional"` + MaxConns int `json:",default=10000"` + MaxBytes int64 `json:",default=1048576"` + RoutePathsCaseSensitive bool `json:",default=true"` // milliseconds Timeout int64 `json:",default=3000"` CpuThreshold int64 `json:",default=900,range=[0:1000)"` diff --git a/rest/engine_test.go b/rest/engine_test.go index 4f86d2173efd..b769d923d12b 100644 --- a/rest/engine_test.go +++ b/rest/engine_test.go @@ -414,7 +414,7 @@ func TestEngine_start(t *testing.T) { Host: "localhost", Port: -1, }) - assert.Error(t, ng.start(router.NewRouter())) + assert.Error(t, ng.start(router.NewRouter(true))) }) t.Run("https", func(t *testing.T) { @@ -425,7 +425,7 @@ func TestEngine_start(t *testing.T) { KeyFile: "bar", }) ng.tlsConfig = &tls.Config{} - assert.Error(t, ng.start(router.NewRouter())) + assert.Error(t, ng.start(router.NewRouter(true))) }) } diff --git a/rest/httpc/requests_test.go b/rest/httpc/requests_test.go index 440752110857..a38fbb08bba5 100644 --- a/rest/httpc/requests_test.go +++ b/rest/httpc/requests_test.go @@ -70,7 +70,7 @@ func TestDo(t *testing.T) { Body string `json:"body"` } - rt := router.NewRouter() + rt := router.NewRouter(true) err := rt.Handle(http.MethodPost, "/nodes/:key", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var req Data @@ -107,7 +107,7 @@ func TestDo_Ptr(t *testing.T) { Body string `json:"body"` } - rt := router.NewRouter() + rt := router.NewRouter(true) err := rt.Handle(http.MethodPost, "/nodes/:key", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var req Data @@ -192,7 +192,7 @@ func TestDo_Json(t *testing.T) { Body chan int `json:"body"` } - rt := router.NewRouter() + rt := router.NewRouter(true) err := rt.Handle(http.MethodPost, "/nodes/:key", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var req Data diff --git a/rest/router/patrouter.go b/rest/router/patrouter.go index d4af14b361ad..1d21bac22892 100644 --- a/rest/router/patrouter.go +++ b/rest/router/patrouter.go @@ -24,15 +24,17 @@ var ( ) type patRouter struct { - trees map[string]*search.Tree - notFound http.Handler - notAllowed http.Handler + routePathsCaseSensitive bool + trees map[string]*search.Tree + notFound http.Handler + notAllowed http.Handler } // NewRouter returns a httpx.Router. -func NewRouter() httpx.Router { +func NewRouter(routePathsCaseSensitive bool) httpx.Router { return &patRouter{ - trees: make(map[string]*search.Tree), + routePathsCaseSensitive: routePathsCaseSensitive, + trees: make(map[string]*search.Tree), } } @@ -59,7 +61,7 @@ func (pr *patRouter) Handle(method, reqPath string, handler http.Handler) error func (pr *patRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) { reqPath := path.Clean(r.URL.Path) if tree, ok := pr.trees[r.Method]; ok { - if result, ok := tree.Search(reqPath); ok { + if result, ok := tree.Search(reqPath, pr.routePathsCaseSensitive); ok { if len(result.Params) > 0 { r = pathvar.WithVars(r, result.Params) } @@ -106,7 +108,7 @@ func (pr *patRouter) methodsAllowed(method, path string) (string, bool) { continue } - _, ok := tree.Search(path) + _, ok := tree.Search(path, pr.routePathsCaseSensitive) if ok { allows = append(allows, treeMethod) } diff --git a/rest/router/patrouter_test.go b/rest/router/patrouter_test.go index 4b3e8d1fe9ef..8a4dc9201ba0 100644 --- a/rest/router/patrouter_test.go +++ b/rest/router/patrouter_test.go @@ -45,7 +45,7 @@ func TestPatRouterHandleErrors(t *testing.T) { for _, test := range tests { t.Run(test.method, func(t *testing.T) { - router := NewRouter() + router := NewRouter(true) err := router.Handle(test.method, test.path, nil) assert.Equal(t, test.err, err) }) @@ -54,7 +54,7 @@ func TestPatRouterHandleErrors(t *testing.T) { func TestPatRouterNotFound(t *testing.T) { var notFound bool - router := NewRouter() + router := NewRouter(true) router.SetNotFoundHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { notFound = true })) @@ -69,7 +69,7 @@ func TestPatRouterNotFound(t *testing.T) { func TestPatRouterNotAllowed(t *testing.T) { var notAllowed bool - router := NewRouter() + router := NewRouter(true) router.SetNotAllowedHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { notAllowed = true })) @@ -102,7 +102,7 @@ func TestPatRouter(t *testing.T) { for _, test := range tests { t.Run(test.method+":"+test.path, func(t *testing.T) { routed := false - router := NewRouter() + router := NewRouter(true) err := router.Handle(test.method, "/a/:b", http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { routed = true @@ -143,7 +143,7 @@ func TestParseSlice(t *testing.T) { assert.Nil(t, err) r.Header.Set("Content-Type", "application/x-www-form-urlencoded") - rt := NewRouter() + rt := NewRouter(true) err = rt.Handle(http.MethodPost, "/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { v := struct { Names []string `form:"names"` @@ -167,7 +167,7 @@ func TestParseJsonPost(t *testing.T) { assert.Nil(t, err) r.Header.Set(httpx.ContentType, httpx.JsonContentType) - router := NewRouter() + router := NewRouter(true) err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(func( w http.ResponseWriter, r *http.Request) { v := struct { @@ -199,7 +199,7 @@ func TestParseJsonPostWithIntSlice(t *testing.T) { assert.Nil(t, err) r.Header.Set(httpx.ContentType, httpx.JsonContentType) - router := NewRouter() + router := NewRouter(true) err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(func( w http.ResponseWriter, r *http.Request) { v := struct { @@ -227,7 +227,7 @@ func TestParseJsonPostError(t *testing.T) { assert.Nil(t, err) r.Header.Set(httpx.ContentType, httpx.JsonContentType) - router := NewRouter() + router := NewRouter(true) err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { v := struct { @@ -255,7 +255,7 @@ func TestParseJsonPostInvalidRequest(t *testing.T) { assert.Nil(t, err) r.Header.Set(httpx.ContentType, httpx.JsonContentType) - router := NewRouter() + router := NewRouter(true) err = router.Handle(http.MethodPost, "/", http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { v := struct { @@ -277,7 +277,7 @@ func TestParseJsonPostRequired(t *testing.T) { assert.Nil(t, err) r.Header.Set(httpx.ContentType, httpx.JsonContentType) - router := NewRouter() + router := NewRouter(true) err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { v := struct { @@ -300,7 +300,7 @@ func TestParsePath(t *testing.T) { r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017", nil) assert.Nil(t, err) - router := NewRouter() + router := NewRouter(true) err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { v := struct { @@ -325,7 +325,7 @@ func TestParsePathRequired(t *testing.T) { r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin", nil) assert.Nil(t, err) - router := NewRouter() + router := NewRouter(true) err = router.Handle(http.MethodGet, "/:name/", http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { v := struct { @@ -346,7 +346,7 @@ func TestParseQuery(t *testing.T) { r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017?nickname=whatever&zipcode=200000", nil) assert.Nil(t, err) - router := NewRouter() + router := NewRouter(true) err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { v := struct { @@ -371,7 +371,7 @@ func TestParseQueryRequired(t *testing.T) { r, err := http.NewRequest(http.MethodPost, "http://hello.com/kevin/2017?nickname=whatever", nil) assert.Nil(t, err) - router := NewRouter() + router := NewRouter(true) err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { v := struct { Nickname string `form:"nickname"` @@ -391,7 +391,7 @@ func TestParseOptional(t *testing.T) { r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017?nickname=whatever&zipcode=", nil) assert.Nil(t, err) - router := NewRouter() + router := NewRouter(true) err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { v := struct { @@ -432,7 +432,7 @@ func TestParseNestedInRequestEmpty(t *testing.T) { } ) - router := NewRouter() + router := NewRouter(true) err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { var v WrappedRequest @@ -471,7 +471,7 @@ func TestParsePtrInRequest(t *testing.T) { } ) - router := NewRouter() + router := NewRouter(true) err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { var v WrappedRequest @@ -502,7 +502,7 @@ func TestParsePtrInRequestEmpty(t *testing.T) { } ) - router := NewRouter() + router := NewRouter(true) err = router.Handle(http.MethodPost, "/kevin", http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { var v WrappedRequest @@ -519,7 +519,7 @@ func TestParseQueryOptional(t *testing.T) { r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017?nickname=whatever&zipcode=", nil) assert.Nil(t, err) - router := NewRouter() + router := NewRouter(true) err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { v := struct { @@ -544,7 +544,7 @@ func TestParse(t *testing.T) { r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017?nickname=whatever&zipcode=200000", nil) assert.Nil(t, err) - router := NewRouter() + router := NewRouter(true) err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { v := struct { @@ -582,7 +582,7 @@ func TestParseWrappedRequest(t *testing.T) { } ) - router := NewRouter() + router := NewRouter(true) err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { var v WrappedRequest @@ -614,7 +614,7 @@ func TestParseWrappedGetRequestWithJsonHeader(t *testing.T) { } ) - router := NewRouter() + router := NewRouter(true) err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { var v WrappedRequest @@ -647,7 +647,7 @@ func TestParseWrappedHeadRequestWithJsonHeader(t *testing.T) { } ) - router := NewRouter() + router := NewRouter(true) err = router.Handle(http.MethodHead, "/:name/:year", http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { var v WrappedRequest @@ -679,7 +679,7 @@ func TestParseWrappedRequestPtr(t *testing.T) { } ) - router := NewRouter() + router := NewRouter(true) err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { var v WrappedRequest @@ -702,7 +702,7 @@ func TestParseWithAll(t *testing.T) { assert.Nil(t, err) r.Header.Set(httpx.ContentType, httpx.JsonContentType) - router := NewRouter() + router := NewRouter(true) err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { v := struct { Name string `path:"name"` @@ -733,7 +733,7 @@ func TestParseWithAllUtf8(t *testing.T) { assert.Nil(t, err) r.Header.Set(httpx.ContentType, header.JsonContentType) - router := NewRouter() + router := NewRouter(true) err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { v := struct { @@ -764,7 +764,7 @@ func TestParseWithMissingForm(t *testing.T) { bytes.NewBufferString(`{"location": "shanghai", "time": 20170912}`)) assert.Nil(t, err) - router := NewRouter() + router := NewRouter(true) err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { v := struct { @@ -791,7 +791,7 @@ func TestParseWithMissingAllForms(t *testing.T) { bytes.NewBufferString(`{"location": "shanghai", "time": 20170912}`)) assert.Nil(t, err) - router := NewRouter() + router := NewRouter(true) err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { v := struct { @@ -817,7 +817,7 @@ func TestParseWithMissingJson(t *testing.T) { bytes.NewBufferString(`{"location": "shanghai"}`)) assert.Nil(t, err) - router := NewRouter() + router := NewRouter(true) err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { v := struct { @@ -843,7 +843,7 @@ func TestParseWithMissingAllJsons(t *testing.T) { r, err := http.NewRequest(http.MethodGet, "http://hello.com/kevin/2017?nickname=whatever&zipcode=200000", nil) assert.Nil(t, err) - router := NewRouter() + router := NewRouter(true) err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { v := struct { @@ -870,7 +870,7 @@ func TestParseWithMissingPath(t *testing.T) { bytes.NewBufferString(`{"location": "shanghai", "time": 20170912}`)) assert.Nil(t, err) - router := NewRouter() + router := NewRouter(true) err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { v := struct { @@ -897,7 +897,7 @@ func TestParseWithMissingAllPaths(t *testing.T) { bytes.NewBufferString(`{"location": "shanghai", "time": 20170912}`)) assert.Nil(t, err) - router := NewRouter() + router := NewRouter(true) err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { v := struct { @@ -924,7 +924,7 @@ func TestParseGetWithContentLengthHeader(t *testing.T) { r.Header.Set(httpx.ContentType, header.JsonContentType) r.Header.Set(contentLength, "1024") - router := NewRouter() + router := NewRouter(true) err = router.Handle(http.MethodGet, "/:name/:year", http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { v := struct { @@ -951,7 +951,7 @@ func TestParseJsonPostWithTypeMismatch(t *testing.T) { assert.Nil(t, err) r.Header.Set(httpx.ContentType, header.JsonContentType) - router := NewRouter() + router := NewRouter(true) err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { v := struct { @@ -977,7 +977,7 @@ func TestParseJsonPostWithInt2String(t *testing.T) { assert.Nil(t, err) r.Header.Set(httpx.ContentType, header.JsonContentType) - router := NewRouter() + router := NewRouter(true) err = router.Handle(http.MethodPost, "/:name/:year", http.HandlerFunc( func(w http.ResponseWriter, r *http.Request) { v := struct { @@ -998,7 +998,7 @@ func TestParseJsonPostWithInt2String(t *testing.T) { func BenchmarkPatRouter(b *testing.B) { b.ReportAllocs() - router := NewRouter() + router := NewRouter(true) router.Handle(http.MethodGet, "/api/:user/:name", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { })) w := &mockedResponseWriter{} diff --git a/rest/server.go b/rest/server.go index b1e5487bd8a5..8460d76e2e5c 100644 --- a/rest/server.go +++ b/rest/server.go @@ -52,7 +52,7 @@ func NewServer(c RestConf, opts ...RunOption) (*Server, error) { server := &Server{ ngin: newEngine(c), - router: router.NewRouter(), + router: router.NewRouter(c.RoutePathsCaseSensitive), } opts = append([]RunOption{WithNotFoundHandler(nil)}, opts...) diff --git a/rest/server_test.go b/rest/server_test.go index 9a92d58f8203..3c79ce2fd64b 100644 --- a/rest/server_test.go +++ b/rest/server_test.go @@ -136,7 +136,7 @@ func TestWithMaxBytes(t *testing.T) { func TestWithMiddleware(t *testing.T) { m := make(map[string]string) - rt := router.NewRouter() + rt := router.NewRouter(false) handler := func(w http.ResponseWriter, r *http.Request) { var v struct { Nickname string `form:"nickname"` @@ -243,7 +243,7 @@ func TestWithFileServerMiddleware(t *testing.T) { func TestMultiMiddlewares(t *testing.T) { m := make(map[string]string) - rt := router.NewRouter() + rt := router.NewRouter(false) handler := func(w http.ResponseWriter, r *http.Request) { var v struct { Nickname string `form:"nickname"` @@ -392,7 +392,7 @@ Port: 54321 ` var cnf RestConf assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf)) - rt := router.NewRouter() + rt := router.NewRouter(cnf.RoutePathsCaseSensitive) svr, err := NewServer(cnf, WithRouter(rt)) assert.Nil(t, err) defer svr.Stop() @@ -408,7 +408,7 @@ Port: 54321 ` var cnf RestConf assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf)) - rt := router.NewRouter() + rt := router.NewRouter(cnf.RoutePathsCaseSensitive) svr, err := NewServer(cnf, WithRouter(rt)) assert.Nil(t, err) @@ -447,7 +447,7 @@ Port: 54321 ` var cnf RestConf assert.Nil(t, conf.LoadFromYamlBytes([]byte(configYaml), &cnf)) - rt := router.NewRouter() + rt := router.NewRouter(cnf.RoutePathsCaseSensitive) svr, err := NewServer(cnf, WithRouter(rt)) assert.Nil(t, err) defer svr.Stop() @@ -639,7 +639,7 @@ func TestServer_WithChain(t *testing.T) { }, }, ) - rt := router.NewRouter() + rt := router.NewRouter(true) assert.Nil(t, server.ngin.bindRoutes(rt)) req, err := http.NewRequest(http.MethodGet, "/", http.NoBody) assert.Nil(t, err) @@ -655,7 +655,7 @@ func TestServer_WithCors(t *testing.T) { next.ServeHTTP(w, r) }) } - r := router.NewRouter() + r := router.NewRouter(true) assert.Nil(t, r.Handle(http.MethodOptions, "/", middleware(http.NotFoundHandler()))) cr := &corsRouter{