From 22f01c79af6d52e9b525bbeb3e5db5878592f57a Mon Sep 17 00:00:00 2001 From: Noah Dietz Date: Tue, 9 Jul 2019 19:28:50 -0700 Subject: [PATCH] add OPTIONS handler for CORS (#4) --- server/server.go | 17 ++++++++++++- server/server_test.go | 59 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 75 insertions(+), 1 deletion(-) diff --git a/server/server.go b/server/server.go index 6ace2a1..66ff5c2 100644 --- a/server/server.go +++ b/server/server.go @@ -28,6 +28,8 @@ import ( "google.golang.org/grpc/status" ) +const fallbackPath = "/$rpc/{service:[.a-zA-Z0-9]+}/{method:[a-zA-Z]+}" + // FallbackServer is a grpc-fallback HTTP server. type FallbackServer struct { backend string @@ -84,7 +86,10 @@ func (f *FallbackServer) preStart() { // setup grpc-fallback complient router r := mux.NewRouter() - r.HandleFunc("/$rpc/{service:[.a-zA-Z0-9]+}/{method:[a-zA-Z]+}", f.handler).Headers("Content-Type", "application/x-protobuf") + r.HandleFunc(fallbackPath, f.options). + Methods(http.MethodOptions) + r.HandleFunc(fallbackPath, f.handler). + Headers("Content-Type", "application/x-protobuf") f.server.Handler = r } @@ -144,3 +149,13 @@ func (f *FallbackServer) dial() (connection, error) { return grpc.Dial(f.backend, opts...) } + +// options is a handler for the OPTIONS call that precedes CORS-enabled calls. +func (f *FallbackServer) options(w http.ResponseWriter, r *http.Request) { + w.Header().Add("access-control-allow-credentials", "true") + w.Header().Add("access-control-allow-headers", "*") + w.Header().Add("access-control-allow-methods", http.MethodPost) + w.Header().Add("access-control-allow-origin", "*") + w.Header().Add("access-control-max-age", "3600") + w.WriteHeader(http.StatusOK) +} diff --git a/server/server_test.go b/server/server_test.go index 6d78e8e..69eed85 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -249,3 +249,62 @@ func TestFallbackServer_preStart(t *testing.T) { }) } } + +func TestFallbackServer_options(t *testing.T) { + type fields struct { + backend string + server http.Server + cc connection + } + type args struct { + w http.ResponseWriter + r *http.Request + } + + req, _ := http.NewRequest("OPTIONS", "/test", nil) + hdr := make(http.Header) + hdr.Add("access-control-allow-credentials", "true") + hdr.Add("access-control-allow-headers", "*") + hdr.Add("access-control-allow-methods", http.MethodPost) + hdr.Add("access-control-allow-origin", "*") + hdr.Add("access-control-max-age", "3600") + + tests := []struct { + name string + fields fields + args args + wantHeader http.Header + }{ + { + name: "basic", + args: args{ + r: req, + w: &testRespWriter{}, + }, + fields: fields{ + cc: &testConnection{}, + }, + wantHeader: hdr, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + f := &FallbackServer{ + backend: tt.fields.backend, + server: tt.fields.server, + cc: tt.fields.cc, + } + f.options(tt.args.w, tt.args.r) + + resp := tt.args.w.(*testRespWriter) + + if resp.code != http.StatusOK { + t.Errorf("handler() %s: got = %d, want = %d", tt.name, resp.code, http.StatusOK) + } + + if !reflect.DeepEqual(resp.Header(), tt.wantHeader) { + t.Errorf("handler() %s: got = %s, want = %s", tt.name, resp.Header(), tt.wantHeader) + } + }) + } +}