diff --git a/appsTransport.go b/appsTransport.go index a67510c..2cfb914 100644 --- a/appsTransport.go +++ b/appsTransport.go @@ -24,7 +24,7 @@ type AppsTransport struct { BaseURL string // BaseURL is the scheme and host for GitHub API, defaults to https://api.github.com Client Client // Client to use to refresh tokens, defaults to http.Client with provided transport tr http.RoundTripper // tr is the underlying roundtripper being wrapped - signer Signer // signer signs JWT tokens. + signer SignerWithContext // signer signs JWT tokens. appID int64 // appID is the GitHub App's ID } @@ -94,7 +94,7 @@ func (t *AppsTransport) RoundTrip(req *http.Request) (*http.Response, error) { Issuer: strconv.FormatInt(t.appID, 10), } - ss, err := t.signer.Sign(claims) + ss, err := t.signer.SignContext(req.Context(), claims) if err != nil { return nil, fmt.Errorf("could not sign jwt: %s", err) } @@ -113,8 +113,19 @@ func (t *AppsTransport) AppID() int64 { type AppsTransportOption func(*AppsTransport) -// WithSigner configures the AppsTransport to use the given Signer for generating JWT tokens. +// WithSigner configures the AppsTransport to use the given Signer for +// generating JWT tokens. +// +// Deprecated: Use [WithContextSigner] instead. func WithSigner(signer Signer) AppsTransportOption { + return func(at *AppsTransport) { + at.signer = SignerWithContextAdapter{signer} + } +} + +// WithContextSigner configures the AppsTransport to use the given Signer for +// generating JWT tokens. +func WithContextSigner(signer SignerWithContext) AppsTransportOption { return func(at *AppsTransport) { at.signer = signer } diff --git a/appsTransport_test.go b/appsTransport_test.go index b6e93c8..ccfbb9f 100644 --- a/appsTransport_test.go +++ b/appsTransport_test.go @@ -2,6 +2,7 @@ package ghinstallation import ( "bytes" + "context" "fmt" "io/ioutil" "net/http" @@ -111,33 +112,76 @@ func TestJWTExpiry(t *testing.T) { } func TestCustomSigner(t *testing.T) { - check := RoundTrip{ - rt: func(req *http.Request) (*http.Response, error) { - h, ok := req.Header["Authorization"] - if !ok { - t.Error("Header Accept not set") - } - want := []string{"Bearer hunter2"} - if diff := cmp.Diff(want, h); diff != "" { - t.Errorf("HTTP Accept headers want->got: %s", diff) - } - return nil, nil + tc := []struct { + nm string + option AppsTransportOption + bearerSuffix string + }{ + { + nm: "context-free signer", + option: WithSigner(&noopSigner{}), + bearerSuffix: "", + }, + { + nm: "context signer", + option: WithContextSigner(&noopSigner{}), + bearerSuffix: ":context", }, } - tr, err := NewAppsTransportWithOptions(check, appID, WithSigner(&noopSigner{})) - if err != nil { - t.Fatalf("NewAppsTransportWithOptions: %v", err) - } + for _, c := range tc { + t.Run(c.nm, func(t *testing.T) { + check := AuthCaptureRoundTripper{} - req := httptest.NewRequest(http.MethodGet, "http://example.com", new(bytes.Buffer)) - if _, err := tr.RoundTrip(req); err != nil { - t.Fatalf("error calling RoundTrip: %v", err) + tr, err := NewAppsTransportWithOptions(&check, appID, c.option) + if err != nil { + t.Fatalf("NewAppsTransportWithOptions: %v", err) + } + + req := httptest.NewRequest(http.MethodGet, "http://example.com", new(bytes.Buffer)) + req = req.WithContext( + context.WithValue(context.Background(), contextSignerKey("test"), c.bearerSuffix), + ) + + if _, err := tr.RoundTrip(req); err != nil { + t.Fatalf("error calling RoundTrip: %v", err) + } + + if !check.Captured { + t.Error("Header Authorization not set") + } + + want := []string{"Bearer hunter2" + c.bearerSuffix} + if diff := cmp.Diff(want, check.Value); diff != "" { + t.Errorf("HTTP Authorization header want->got: %s", diff) + } + }) } } +type contextSignerKey string + type noopSigner struct{} func (noopSigner) Sign(jwt.Claims) (string, error) { return "hunter2", nil } + +func (noopSigner) SignContext(ctx context.Context, _ jwt.Claims) (string, error) { + // mark the returned token with the context suffix expected by the test + v := ctx.Value(contextSignerKey("test")) + return fmt.Sprintf("hunter2%v", v), nil +} + +type AuthCaptureRoundTripper struct { + Captured bool + Value []string +} + +func (a *AuthCaptureRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + h, ok := req.Header["Authorization"] + a.Captured = ok + a.Value = h + + return nil, nil +} diff --git a/sign.go b/sign.go index 928e10e..1039ec6 100644 --- a/sign.go +++ b/sign.go @@ -1,19 +1,39 @@ package ghinstallation import ( + "context" "crypto/rsa" jwt "github.com/golang-jwt/jwt/v4" ) -// Signer is a JWT token signer. This is a wrapper around [jwt.SigningMethod] with predetermined -// key material. +// Signer is a JWT token signer. This is a wrapper around [jwt.SigningMethod] +// with predetermined key material. type Signer interface { // Sign signs the given claims and returns a JWT token string, as specified // by [jwt.Token.SignedString] Sign(claims jwt.Claims) (string, error) } +// SignerWithContext is a JWT token signer. This is a wrapper around +// [jwt.SigningMethod] with predetermined key material. +type SignerWithContext interface { + // SignContext signs the given claims and returns a JWT token string, as + // specified by [jwt.Token.SignedString]. The signing operation should use the + // provided context as appropriate. + SignContext(ctx context.Context, claims jwt.Claims) (string, error) +} + +// SignerWithContextAdapter is a simple [Signer] wrapper that allows it to act +// as a [SignerWithContext]. +type SignerWithContextAdapter struct { + Signer Signer +} + +func (s SignerWithContextAdapter) SignContext(_ context.Context, claims jwt.Claims) (string, error) { + return s.Signer.Sign(claims) +} + // RSASigner signs JWT tokens using RSA keys. type RSASigner struct { method *jwt.SigningMethodRSA @@ -28,6 +48,13 @@ func NewRSASigner(method *jwt.SigningMethodRSA, key *rsa.PrivateKey) *RSASigner } // Sign signs the JWT claims with the RSA key. +// +// Deprecated: Use [SignContext] instead. func (s *RSASigner) Sign(claims jwt.Claims) (string, error) { return jwt.NewWithClaims(s.method, claims).SignedString(s.key) } + +// Sign signs the JWT claims with the RSA key. +func (s *RSASigner) SignContext(_ context.Context, claims jwt.Claims) (string, error) { + return s.Sign(claims) +}