From 610fde07d40a366da7562c53a1bbdadca4f9f931 Mon Sep 17 00:00:00 2001 From: Bart Jeukendrup Date: Sun, 10 Dec 2023 20:19:11 +0100 Subject: [PATCH] feat: allow response rewrite being set from authorization service --- cmd/filter-proxy/main.go | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/cmd/filter-proxy/main.go b/cmd/filter-proxy/main.go index f6c5d07..9012b7c 100644 --- a/cmd/filter-proxy/main.go +++ b/cmd/filter-proxy/main.go @@ -28,11 +28,8 @@ type ClaimsWithGroups struct { } type AuthorizationResponse struct { - User struct { - Id int64 - Username string - Name string - } + Result bool `json:"result"` + ResponseFilter string `json:"response_filter"` } func main() { @@ -53,12 +50,17 @@ func main() { utils.DelHopHeaders(r.Header) - authorizationStatusCode, _ := authorizeRequestWithService(config, backend, path, r) + authorizationStatusCode, authorizationResponse := authorizeRequestWithService(config, backend, path, r) if authorizationStatusCode != http.StatusOK { writeError(w, authorizationStatusCode, "unauthorized request") return } + if !authorizationResponse.Result { + writeError(w, http.StatusUnauthorized, "result field is not true") + return + } + allowedMethods := path.AllowedMethods if len(allowedMethods) == 0 { allowedMethods = []string{"GET"} @@ -183,13 +185,20 @@ func main() { defer proxyResp.Body.Close() - if path.ResponseRewrite != "" && proxyResp.StatusCode == http.StatusOK { + if proxyResp.StatusCode == http.StatusOK && (path.ResponseRewrite != "" || authorizationResponse.ResponseFilter != "") { body, _ := io.ReadAll(proxyResp.Body) var result map[string]interface{} json.Unmarshal(body, &result) - query, err := gojq.Parse(path.ResponseRewrite) + var responseRewrite = "" + if authorizationResponse.ResponseFilter != "" { + responseRewrite = authorizationResponse.ResponseFilter + } else { + responseRewrite = path.ResponseRewrite + } + + query, err := gojq.Parse(responseRewrite) if err != nil { writeError(w, http.StatusInternalServerError, "could not parse filter") return