Skip to content

Commit

Permalink
Collect Prometheus metrics with token introspection result status
Browse files Browse the repository at this point in the history
  • Loading branch information
vasayxtx committed Dec 11, 2024
1 parent db558a7 commit 8e54d4a
Show file tree
Hide file tree
Showing 9 changed files with 96 additions and 15 deletions.
2 changes: 1 addition & 1 deletion idptoken/grpc_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func NewGRPCClientWithOpts(
client: pb.NewIDPTokenServiceClient(conn),
clientConn: conn,
reqTimeout: opts.RequestTimeout,
promMetrics: metrics.GetPrometheusMetrics(opts.PrometheusLibInstanceLabel, "grpc_client"),
promMetrics: metrics.GetPrometheusMetrics(opts.PrometheusLibInstanceLabel, metrics.SourceGRPCClient),
}, nil
}

Expand Down
4 changes: 1 addition & 3 deletions idptoken/introspector.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@ import (

const minAccessTokenProviderInvalidationInterval = time.Minute

const tokenIntrospectorPromSource = "token_introspector"

const (
// DefaultIntrospectionClaimsCacheMaxEntries is a default maximum number of entries in the claims cache.
// Claims cache is used for storing introspected active tokens.
Expand Down Expand Up @@ -250,7 +248,7 @@ func NewIntrospectorWithOpts(accessTokenProvider IntrospectionTokenProvider, opt
}
scopeFilterFormURLEncoded := values.Encode()

promMetrics := metrics.GetPrometheusMetrics(opts.PrometheusLibInstanceLabel, tokenIntrospectorPromSource)
promMetrics := metrics.GetPrometheusMetrics(opts.PrometheusLibInstanceLabel, metrics.SourceTokenIntrospector)

claimsCache := makeIntrospectionClaimsCache(opts.ClaimsCache, DefaultIntrospectionClaimsCacheMaxEntries, promMetrics)
if opts.ClaimsCache.TTL == 0 {
Expand Down
2 changes: 1 addition & 1 deletion idptoken/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ func NewMultiSourceProviderWithOpts(sources []Source, opts ProviderOpts) *MultiS
minRefreshPeriod: opts.MinRefreshPeriod,
logger: idputil.PrepareLogger(opts.Logger),
tokenIssuers: make(map[string]*oauth2Issuer),
promMetrics: metrics.GetPrometheusMetrics(opts.PrometheusLibInstanceLabel, "token_provider"),
promMetrics: metrics.GetPrometheusMetrics(opts.PrometheusLibInstanceLabel, metrics.SourceTokenProvider),
customHeaders: opts.CustomHeaders,
cache: opts.CustomCacheInstance,
httpClient: opts.HTTPClient,
Expand Down
4 changes: 2 additions & 2 deletions idptoken/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ func TestProviderWithCache(t *testing.T) {
metrics.HTTPClientRequestLabelStatusCode: "500",
metrics.HTTPClientRequestLabelError: "unexpected_status_code",
}
promMetrics := metrics.GetPrometheusMetrics("", "token_provider")
promMetrics := metrics.GetPrometheusMetrics("", metrics.SourceTokenProvider)
hist := promMetrics.HTTPClientRequestDuration.With(labels).(prometheus.Histogram)
testutil.AssertSamplesCountInHistogram(t, hist, 1)
})
Expand Down Expand Up @@ -287,7 +287,7 @@ func TestProviderWithCache(t *testing.T) {
metrics.HTTPClientRequestLabelStatusCode: "200",
metrics.HTTPClientRequestLabelError: "",
}
promMetrics := metrics.GetPrometheusMetrics("", "token_provider")
promMetrics := metrics.GetPrometheusMetrics("", metrics.SourceTokenProvider)
hist := promMetrics.HTTPClientRequestDuration.With(labels).(prometheus.Histogram)
testutil.AssertSamplesCountInHistogram(t, hist, 1)
})
Expand Down
41 changes: 39 additions & 2 deletions internal/metrics/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,31 @@ const (

GRPCClientRequestLabelMethod = "grpc_method"
GRPCClientRequestLabelCode = "grpc_code"

TokenIntrospectionLabelStatus = "status"
)

const (
HTTPRequestErrorDo = "do_request_error"
HTTPRequestErrorDecodeBody = "decode_body_error"
HTTPRequestErrorUnexpectedStatusCode = "unexpected_status_code"

TokenIntrospectionStatusActive = "active"
TokenIntrospectionStatusNotActive = "not_active"
TokenIntrospectionStatusNotNeeded = "not_needed"
TokenIntrospectionStatusNotIntrospectable = "not_introspectable"
TokenIntrospectionStatusError = "error"
)

type Source string

const (
SourceJWKSClient Source = "jwks_client"
SourceJWTParser Source = "jwt_parser"
SourceGRPCClient Source = "grpc_client"
SourceTokenIntrospector Source = "token_introspector"
SourceTokenProvider Source = "token_provider"
SourceHTTPMiddleware Source = "http_middleware"
)

var requestDurationBuckets = []float64{0.005, 0.01, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10}
Expand All @@ -58,12 +77,13 @@ var (
type PrometheusMetrics struct {
HTTPClientRequestDuration *prometheus.HistogramVec
GRPCClientRequestDuration *prometheus.HistogramVec
TokenIntrospectionsTotal *prometheus.CounterVec
TokenClaimsCache *lrucache.PrometheusMetrics
TokenNegativeCache *lrucache.PrometheusMetrics
EndpointDiscoveryCache *lrucache.PrometheusMetrics
}

func GetPrometheusMetrics(instance string, source string) *PrometheusMetrics {
func GetPrometheusMetrics(instance string, source Source) *PrometheusMetrics {
prometheusMetricsOnce.Do(func() {
prometheusMetrics = newPrometheusMetrics()
prometheusMetrics.MustRegister()
Expand All @@ -73,7 +93,7 @@ func GetPrometheusMetrics(instance string, source string) *PrometheusMetrics {
}
return prometheusMetrics.MustCurryWith(map[string]string{
PrometheusLibInstanceLabel: instance,
PrometheusLibSourceLabel: source,
PrometheusLibSourceLabel: string(source),
})
}

Expand All @@ -95,6 +115,7 @@ func newPrometheusMetrics() *PrometheusMetrics {
makeLabelNames(HTTPClientRequestLabelMethod, HTTPClientRequestLabelURL,
HTTPClientRequestLabelStatusCode, HTTPClientRequestLabelError),
)

grpcClientReqDuration := prometheus.NewHistogramVec(
prometheus.HistogramOpts{
Namespace: PrometheusNamespace,
Expand All @@ -106,6 +127,16 @@ func newPrometheusMetrics() *PrometheusMetrics {
makeLabelNames(GRPCClientRequestLabelMethod, GRPCClientRequestLabelCode),
)

tokenIntrospectionsTotal := prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: PrometheusNamespace,
Name: "token_introspections_total",
Help: "Total number of tokens' introspections",
ConstLabels: PrometheusLabels(),
},
makeLabelNames(TokenIntrospectionLabelStatus),
)

tokenClaimsCache := lrucache.NewPrometheusMetricsWithOpts(lrucache.PrometheusMetricsOpts{
Namespace: PrometheusNamespace + "_token_claims",
ConstLabels: PrometheusLabels(),
Expand All @@ -127,6 +158,7 @@ func newPrometheusMetrics() *PrometheusMetrics {
return &PrometheusMetrics{
HTTPClientRequestDuration: httpClientReqDuration,
GRPCClientRequestDuration: grpcClientReqDuration,
TokenIntrospectionsTotal: tokenIntrospectionsTotal,
TokenClaimsCache: tokenClaimsCache,
TokenNegativeCache: tokenNegativeCache,
EndpointDiscoveryCache: endpointDiscoveryCache,
Expand All @@ -138,6 +170,7 @@ func (pm *PrometheusMetrics) MustCurryWith(labels prometheus.Labels) *Prometheus
return &PrometheusMetrics{
HTTPClientRequestDuration: pm.HTTPClientRequestDuration.MustCurryWith(labels).(*prometheus.HistogramVec),
GRPCClientRequestDuration: pm.GRPCClientRequestDuration.MustCurryWith(labels).(*prometheus.HistogramVec),
TokenIntrospectionsTotal: pm.TokenIntrospectionsTotal.MustCurryWith(labels),
TokenClaimsCache: pm.TokenClaimsCache.MustCurryWith(labels),
TokenNegativeCache: pm.TokenNegativeCache.MustCurryWith(labels),
EndpointDiscoveryCache: pm.EndpointDiscoveryCache.MustCurryWith(labels),
Expand Down Expand Up @@ -183,3 +216,7 @@ func (pm *PrometheusMetrics) ObserveGRPCClientRequest(
GRPCClientRequestLabelCode: code.String(),
}).Observe(elapsed.Seconds())
}

func (pm *PrometheusMetrics) IncTokenIntrospectionsTotal(status string) {
pm.TokenIntrospectionsTotal.With(prometheus.Labels{TokenIntrospectionLabelStatus: status}).Inc()
}
2 changes: 1 addition & 1 deletion jwks/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ func NewClient() *Client {

// NewClientWithOpts returns a new Client with options.
func NewClientWithOpts(opts ClientOpts) *Client {
promMetrics := metrics.GetPrometheusMetrics(opts.PrometheusLibInstanceLabel, "jwks_client")
promMetrics := metrics.GetPrometheusMetrics(opts.PrometheusLibInstanceLabel, metrics.SourceJWKSClient)
if opts.HTTPClient == nil {
opts.HTTPClient = idputil.MakeDefaultHTTPClient(idputil.DefaultHTTPRequestTimeout, opts.LoggerProvider)
}
Expand Down
2 changes: 1 addition & 1 deletion jwt/caching_parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func NewCachingParser(keysProvider KeysProvider) (*CachingParser, error) {
func NewCachingParserWithOpts(
keysProvider KeysProvider, opts CachingParserOpts,
) (*CachingParser, error) {
promMetrics := metrics.GetPrometheusMetrics(opts.CachePrometheusInstanceLabel, "jwt_parser")
promMetrics := metrics.GetPrometheusMetrics(opts.CachePrometheusInstanceLabel, metrics.SourceJWTParser)
if opts.CacheMaxEntries == 0 {
opts.CacheMaxEntries = DefaultClaimsCacheMaxEntries
}
Expand Down
22 changes: 19 additions & 3 deletions middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (

"github.com/acronis/go-authkit/idptoken"
"github.com/acronis/go-authkit/internal/idputil"
"github.com/acronis/go-authkit/internal/metrics"
"github.com/acronis/go-authkit/jwt"
)

Expand Down Expand Up @@ -70,12 +71,14 @@ type jwtAuthHandler struct {
verifyAccess func(r *http.Request, claims jwt.Claims) bool
tokenIntrospector TokenIntrospector
loggerProvider func(ctx context.Context) log.FieldLogger
promMetrics *metrics.PrometheusMetrics
}

type jwtAuthMiddlewareOpts struct {
verifyAccess func(r *http.Request, claims jwt.Claims) bool
tokenIntrospector TokenIntrospector
loggerProvider func(ctx context.Context) log.FieldLogger
verifyAccess func(r *http.Request, claims jwt.Claims) bool
tokenIntrospector TokenIntrospector
loggerProvider func(ctx context.Context) log.FieldLogger
prometheusLibInstanceLabel string
}

// JWTAuthMiddlewareOption is an option for JWTAuthMiddleware.
Expand All @@ -102,6 +105,13 @@ func WithJWTAuthMiddlewareLoggerProvider(loggerProvider func(ctx context.Context
}
}

// WithJWTAuthMiddlewarePrometheusLibInstanceLabel is an option to set a label for Prometheus metrics that are used by JWTAuthMiddleware.
func WithJWTAuthMiddlewarePrometheusLibInstanceLabel(label string) JWTAuthMiddlewareOption {
return func(options *jwtAuthMiddlewareOpts) {
options.prometheusLibInstanceLabel = label
}
}

// JWTAuthMiddleware is a middleware that does authentication
// by Access Token from the "Authorization" HTTP header of incoming request.
// errorDomain is used for error responses. It is usually the name of the service that uses the middleware,
Expand All @@ -123,6 +133,7 @@ func JWTAuthMiddleware(errorDomain string, jwtParser JWTParser, opts ...JWTAuthM
verifyAccess: options.verifyAccess,
tokenIntrospector: options.tokenIntrospector,
loggerProvider: options.loggerProvider,
promMetrics: metrics.GetPrometheusMetrics(options.prometheusLibInstanceLabel, metrics.SourceHTTPMiddleware),
}
}
}
Expand All @@ -146,21 +157,25 @@ func (h *jwtAuthHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
h.logger(reqCtx).AtLevel(log.LevelDebug, func(logFunc log.LogFunc) {
logFunc("token's introspection is not needed")
})
h.promMetrics.IncTokenIntrospectionsTotal(metrics.TokenIntrospectionStatusNotNeeded)
case errors.Is(err, idptoken.ErrTokenNotIntrospectable):
// Token is not introspectable by some reason.
// In this case, we will parse it as JWT and use it for authZ.
h.logger(reqCtx).Warn("token is not introspectable, it will be used for authentication and authorization as is",
log.Error(err))
h.promMetrics.IncTokenIntrospectionsTotal(metrics.TokenIntrospectionStatusNotIntrospectable)
default:
logger := h.logger(reqCtx)
logger.Error("token's introspection failed", log.Error(err))
h.promMetrics.IncTokenIntrospectionsTotal(metrics.TokenIntrospectionStatusError)
apiErr := restapi.NewError(h.errorDomain, ErrCodeAuthenticationFailed, ErrMessageAuthenticationFailed)
restapi.RespondError(rw, http.StatusUnauthorized, apiErr, logger)
return
}
} else {
if !introspectionResult.IsActive() {
h.logger(reqCtx).Warn("token was successfully introspected, but it is not active")
h.promMetrics.IncTokenIntrospectionsTotal(metrics.TokenIntrospectionStatusNotActive)
apiErr := restapi.NewError(h.errorDomain, ErrCodeAuthenticationFailed, ErrMessageAuthenticationFailed)
restapi.RespondError(rw, http.StatusUnauthorized, apiErr, h.logger(reqCtx))
return
Expand All @@ -169,6 +184,7 @@ func (h *jwtAuthHandler) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
h.logger(reqCtx).AtLevel(log.LevelDebug, func(logFunc log.LogFunc) {
logFunc("token was successfully introspected")
})
h.promMetrics.IncTokenIntrospectionsTotal(metrics.TokenIntrospectionStatusActive)
}
}

Expand Down
32 changes: 31 additions & 1 deletion middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"github.com/stretchr/testify/require"

"github.com/acronis/go-authkit/idptoken"
"github.com/acronis/go-authkit/internal/metrics"
"github.com/acronis/go-authkit/jwt"
)

Expand Down Expand Up @@ -122,12 +123,18 @@ func TestJWTAuthMiddleware(t *testing.T) {
req.Header.Set(HeaderAuthorization, "Bearer a.b.c")
resp := httptest.NewRecorder()

testutil.RequireSamplesCountInCounter(t, metrics.GetPrometheusMetrics("", metrics.SourceHTTPMiddleware).
TokenIntrospectionsTotal.WithLabelValues(metrics.TokenIntrospectionStatusError), 0)

JWTAuthMiddleware(errDomain, parser, WithJWTAuthMiddlewareTokenIntrospector(introspector))(next).ServeHTTP(resp, req)

testutil.RequireErrorInRecorder(t, resp, http.StatusUnauthorized, errDomain, ErrCodeAuthenticationFailed)
require.Equal(t, 1, introspector.introspectCalled)
require.Equal(t, 0, parser.parseCalled)
require.Equal(t, 0, next.called)

testutil.RequireSamplesCountInCounter(t, metrics.GetPrometheusMetrics("", metrics.SourceHTTPMiddleware).
TokenIntrospectionsTotal.WithLabelValues(metrics.TokenIntrospectionStatusError), 1)
})

t.Run("introspection is not needed", func(t *testing.T) {
Expand All @@ -139,6 +146,9 @@ func TestJWTAuthMiddleware(t *testing.T) {
req.Header.Set(HeaderAuthorization, "Bearer a.b.c")
resp := httptest.NewRecorder()

testutil.RequireSamplesCountInCounter(t, metrics.GetPrometheusMetrics("", metrics.SourceHTTPMiddleware).
TokenIntrospectionsTotal.WithLabelValues(metrics.TokenIntrospectionStatusNotNeeded), 0)

JWTAuthMiddleware(errDomain, parser, WithJWTAuthMiddlewareTokenIntrospector(introspector))(next).ServeHTTP(resp, req)

require.Equal(t, http.StatusOK, resp.Code)
Expand All @@ -148,6 +158,9 @@ func TestJWTAuthMiddleware(t *testing.T) {
nextIssuer, err := next.jwtClaims.GetIssuer()
require.NoError(t, err)
require.Equal(t, issuer, nextIssuer)

testutil.RequireSamplesCountInCounter(t, metrics.GetPrometheusMetrics("", metrics.SourceHTTPMiddleware).
TokenIntrospectionsTotal.WithLabelValues(metrics.TokenIntrospectionStatusNotNeeded), 1)
})

t.Run("ok, token is not introspectable", func(t *testing.T) {
Expand All @@ -159,6 +172,9 @@ func TestJWTAuthMiddleware(t *testing.T) {
req.Header.Set(HeaderAuthorization, "Bearer a.b.c")
resp := httptest.NewRecorder()

testutil.RequireSamplesCountInCounter(t, metrics.GetPrometheusMetrics("", metrics.SourceHTTPMiddleware).
TokenIntrospectionsTotal.WithLabelValues(metrics.TokenIntrospectionStatusNotIntrospectable), 0)

JWTAuthMiddleware(errDomain, parser, WithJWTAuthMiddlewareTokenIntrospector(introspector))(next).ServeHTTP(resp, req)

require.Equal(t, http.StatusOK, resp.Code)
Expand All @@ -169,23 +185,31 @@ func TestJWTAuthMiddleware(t *testing.T) {
nextIssuer, err := next.jwtClaims.GetIssuer()
require.NoError(t, err)
require.Equal(t, issuer, nextIssuer)

testutil.RequireSamplesCountInCounter(t, metrics.GetPrometheusMetrics("", metrics.SourceHTTPMiddleware).
TokenIntrospectionsTotal.WithLabelValues(metrics.TokenIntrospectionStatusNotIntrospectable), 1)
})

t.Run("authentication failed, token is introspected but inactive", func(t *testing.T) {
const issuer = "my-idp.com"
parser := &mockJWTParser{}
introspector := &mockTokenIntrospector{resultToReturn: &idptoken.DefaultIntrospectionResult{Active: false}}
next := &mockJWTAuthMiddlewareNextHandler{}
req := httptest.NewRequest(http.MethodGet, "/", http.NoBody)
req.Header.Set(HeaderAuthorization, "Bearer a.b.c")
resp := httptest.NewRecorder()

testutil.RequireSamplesCountInCounter(t, metrics.GetPrometheusMetrics("", metrics.SourceHTTPMiddleware).
TokenIntrospectionsTotal.WithLabelValues(metrics.TokenIntrospectionStatusNotActive), 0)

JWTAuthMiddleware(errDomain, parser, WithJWTAuthMiddlewareTokenIntrospector(introspector))(next).ServeHTTP(resp, req)

testutil.RequireErrorInRecorder(t, resp, http.StatusUnauthorized, errDomain, ErrCodeAuthenticationFailed)
require.Equal(t, 1, introspector.introspectCalled)
require.Equal(t, 0, parser.parseCalled)
require.Equal(t, 0, next.called)

testutil.RequireSamplesCountInCounter(t, metrics.GetPrometheusMetrics("", metrics.SourceHTTPMiddleware).
TokenIntrospectionsTotal.WithLabelValues(metrics.TokenIntrospectionStatusNotActive), 1)
})

t.Run("ok, token is introspected and active", func(t *testing.T) {
Expand All @@ -198,6 +222,9 @@ func TestJWTAuthMiddleware(t *testing.T) {
req.Header.Set(HeaderAuthorization, "Bearer a.b.c")
resp := httptest.NewRecorder()

testutil.RequireSamplesCountInCounter(t, metrics.GetPrometheusMetrics("", metrics.SourceHTTPMiddleware).
TokenIntrospectionsTotal.WithLabelValues(metrics.TokenIntrospectionStatusActive), 0)

JWTAuthMiddleware(errDomain, parser, WithJWTAuthMiddlewareTokenIntrospector(introspector))(next).ServeHTTP(resp, req)

require.Equal(t, http.StatusOK, resp.Code)
Expand All @@ -208,6 +235,9 @@ func TestJWTAuthMiddleware(t *testing.T) {
nextIssuer, err := next.jwtClaims.GetIssuer()
require.NoError(t, err)
require.Equal(t, issuer, nextIssuer)

testutil.RequireSamplesCountInCounter(t, metrics.GetPrometheusMetrics("", metrics.SourceHTTPMiddleware).
TokenIntrospectionsTotal.WithLabelValues(metrics.TokenIntrospectionStatusActive), 1)
})
}

Expand Down

0 comments on commit 8e54d4a

Please sign in to comment.