Skip to content

Commit

Permalink
fix bugs in ApplyStandardMiddleware
Browse files Browse the repository at this point in the history
  • Loading branch information
johnwarden committed Oct 27, 2022
1 parent d3800c9 commit 4b0370f
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 7 deletions.
28 changes: 24 additions & 4 deletions request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,17 +59,29 @@ func TestPanic(t *testing.T) {
func TestApplyStandardMiddleware(t *testing.T) {
{
h := httperror.ApplyStandardMiddleware(okHandler, myMiddleware)
s, m := testRequest(h, "/")
s, _ := testRequest(h, "/")
assert.Equal(t, 200, s)
assert.Equal(t, "OK\nDid Middleware\n", m, "got middleware output")
}

{
h := httperror.ApplyStandardMiddleware(notFoundHandler, myMiddleware)
s, m := testRequest(h, "/")
assert.Equal(t, 404, s)
assert.Equal(t, "404 Not Found\nDid Middleware\n", m, "got middleware output")
assert.Equal(t, "404 Not Found\n", m, "got correct response status")
}

{
inner := httperror.XApplyStandardMiddleware[string](nameHandler, myMiddleware)

h := httperror.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
return inner(w, r, "Bill")
})

s, m := testRequest(h, "/")
assert.Equal(t, 200, s)
assert.Equal(t, "Hello, Bill\n", m, "got middleware output")
}

}

var getMeOuttaHere = httperror.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
Expand All @@ -94,6 +106,14 @@ var notFoundHandler = httperror.HandlerFunc(func(w http.ResponseWriter, _ *http.
return httperror.NotFound
})

var nameHandler = httperror.XHandlerFunc[string](func(w http.ResponseWriter, r *http.Request, name string) error {
w.Header().Set("Content-Type", "text/plain")

fmt.Fprintf(w, "Hello, %s\n", name)

return nil
})

func helloHandler(w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Content-Type", "text/plain")

Expand Down Expand Up @@ -128,7 +148,7 @@ func customErrorHandler(w http.ResponseWriter, err error) {

func myMiddleware(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Foo", "Bar")
h.ServeHTTP(w, r)
w.Write([]byte("Did Middleware\n"))
})
}
6 changes: 3 additions & 3 deletions standardmiddleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func XApplyStandardMiddleware[P any](h XHandler[P], m StandardMiddleware) XHandl
*errPtr = err
})

handler = m(h)
handler = m(handler)

return func(w http.ResponseWriter, r *http.Request, p P) error {
var err error
Expand All @@ -45,7 +45,6 @@ func XApplyStandardMiddleware[P any](h XHandler[P], m StandardMiddleware) XHandl
// ApplyStandardMiddleware applies middleware written for a standard [http.Handler] to an [httperror.XHandler].
// It works by passing parameters and returning errors through the context.
func ApplyStandardMiddleware(h Handler, m StandardMiddleware) HandlerFunc {
errPtrKey := contextKey("errPtr")

var handler http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
Expand All @@ -56,12 +55,13 @@ func ApplyStandardMiddleware(h Handler, m StandardMiddleware) HandlerFunc {
*errPtr = err
})

handler = m(h)
handler = m(handler)

return func(w http.ResponseWriter, r *http.Request) error {
var err error
c := r.Context()
c = context.WithValue(c, errPtrKey, &err)
c = context.WithValue(c, paramsKey, "no params")

handler.ServeHTTP(w, r.WithContext(c))

Expand Down

0 comments on commit 4b0370f

Please sign in to comment.