diff --git a/middleware/csrf/custom_errorfunc/main.go b/middleware/csrf/custom_errorfunc/main.go index 4642276..f4557e7 100644 --- a/middleware/csrf/custom_errorfunc/main.go +++ b/middleware/csrf/custom_errorfunc/main.go @@ -45,14 +45,21 @@ func myErrFunc(_ context.Context, ctx *app.RequestContext) { return } - if errors.Is(err, errMissingForm) || errors.Is(err, errMissingParam) || errors.Is(err, errMissingHeader) || errors.Is(err, errMissingQuery) { - ctx.String(http.StatusBadRequest, err.Error()) // extract csrf-token failed - } else if errors.Is(err, errMissingSalt) { - fmt.Println(err.Error()) - ctx.String(http.StatusInternalServerError, err.Error()) // get salt failed, which is unexpected - } else if errors.Is(err, errInvalidToken) { - ctx.String(http.StatusBadRequest, err.Error()) // csrf-token is invalid + switch err.Err.(type) { + case error: + switch { + case errors.Is(err, errMissingForm), errors.Is(err, errMissingParam), errors.Is(err, errMissingHeader), errors.Is(err, errMissingQuery): + ctx.String(http.StatusBadRequest, err.Error()) // extract csrf-token failed + case errors.Is(err, errMissingSalt): + fmt.Println(err.Error()) + ctx.String(http.StatusInternalServerError, err.Error()) // get salt failed, which is unexpected + case errors.Is(err, errInvalidToken): + ctx.String(http.StatusBadRequest, err.Error()) // csrf-token is invalid + default: + ctx.String(http.StatusInternalServerError, "Unknown error") // handle unknown errors + } } + ctx.Abort() }