diff --git a/backend/src/main/java/com/bakdata/conquery/models/auth/oidc/JwtPkceVerifyingRealm.java b/backend/src/main/java/com/bakdata/conquery/models/auth/oidc/JwtPkceVerifyingRealm.java index f18bfb31f4..79bb0cf1af 100644 --- a/backend/src/main/java/com/bakdata/conquery/models/auth/oidc/JwtPkceVerifyingRealm.java +++ b/backend/src/main/java/com/bakdata/conquery/models/auth/oidc/JwtPkceVerifyingRealm.java @@ -1,5 +1,11 @@ package com.bakdata.conquery.models.auth.oidc; +import java.lang.reflect.Array; +import java.security.PublicKey; +import java.util.List; +import java.util.Optional; +import java.util.function.Supplier; + import com.bakdata.conquery.io.storage.MetaStorage; import com.bakdata.conquery.models.auth.ConqueryAuthenticationInfo; import com.bakdata.conquery.models.auth.ConqueryAuthenticationRealm; @@ -9,21 +15,19 @@ import com.bakdata.conquery.models.identifiable.ids.specific.UserId; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; -import org.apache.shiro.authc.*; +import org.apache.shiro.authc.AuthenticationException; +import org.apache.shiro.authc.AuthenticationToken; +import org.apache.shiro.authc.BearerToken; +import org.apache.shiro.authc.IncorrectCredentialsException; +import org.apache.shiro.authc.UnknownAccountException; import org.apache.shiro.authc.pam.UnsupportedTokenException; import org.apache.shiro.realm.AuthenticatingRealm; import org.keycloak.TokenVerifier; import org.keycloak.common.VerificationException; -import org.keycloak.exceptions.TokenNotActiveException; +import org.keycloak.jose.JOSEParser; import org.keycloak.representations.AccessToken; import org.keycloak.representations.JsonWebToken; -import java.lang.reflect.Array; -import java.util.ArrayList; -import java.util.List; -import java.util.Optional; -import java.util.function.Supplier; - /** * This realm uses the configured public key to verify the signature of a provided JWT and extracts informations about * the authenticated user from it. @@ -65,17 +69,26 @@ public ConqueryAuthenticationInfo doGetAuthenticationInfo(AuthenticationToken to return null; } JwtPkceVerifyingRealmFactory.IdpConfiguration idpConfiguration = idpConfigurationOpt.get(); + final BearerToken bearerToken = (BearerToken) token; + + log.trace("Parsing token ({}) to extract key id from header", bearerToken.getToken()); + final String keyId = JOSEParser.parse(bearerToken.getToken()).getHeader().getKeyId(); + log.trace("Key id of token signer: {}", keyId); + final PublicKey publicKey = idpConfiguration.signingKeys().get(keyId); + + if (publicKey == null) { + throw new UnsupportedTokenException("Token was signed by a key with an unknown Id: " + keyId); + } log.trace("Creating token verifier"); - TokenVerifier verifier = TokenVerifier.create(((BearerToken) token).getToken(), AccessToken.class) - .withChecks(new TokenVerifier.RealmUrlCheck(idpConfiguration.getIssuer()), TokenVerifier.SUBJECT_EXISTS_CHECK, activeVerifier) + TokenVerifier verifier = TokenVerifier.create(bearerToken.getToken(), AccessToken.class) + .withChecks(new TokenVerifier.RealmUrlCheck(idpConfiguration.issuer()), TokenVerifier.SUBJECT_EXISTS_CHECK, activeVerifier) .withChecks(tokenChecks) - .publicKey(idpConfiguration.getPublicKey()) + .publicKey(publicKey) .audience(allowedAudience); - String subject; log.trace("Verifying token"); - AccessToken accessToken = null; + final AccessToken accessToken; try { verifier.verify(); accessToken = verifier.getToken(); @@ -84,14 +97,14 @@ public ConqueryAuthenticationInfo doGetAuthenticationInfo(AuthenticationToken to log.trace("Verification failed", e); throw new IncorrectCredentialsException(e); } - subject = accessToken.getSubject(); + final String subject = accessToken.getSubject(); if (subject == null) { // Should not happen, as sub is mandatory in an access_token throw new UnsupportedTokenException("Unable to extract a subject from the provided token."); } - log.trace("Authentication successfull for subject {}", subject); + log.trace("Authentication was successful for subject: {}", subject); UserId userId = new UserId(subject); @@ -102,7 +115,6 @@ public ConqueryAuthenticationInfo doGetAuthenticationInfo(AuthenticationToken to } // Try alternative ids - List alternativeIds = new ArrayList<>(); for (String alternativeIdClaim : alternativeIdClaims) { Object altId = accessToken.getOtherClaims().get(alternativeIdClaim); if (!(altId instanceof String)) { @@ -120,15 +132,4 @@ public ConqueryAuthenticationInfo doGetAuthenticationInfo(AuthenticationToken to throw new UnknownAccountException("The user id was unknown: " + subject); } - public static final TokenVerifier.Predicate IS_ACTIVE = new TokenVerifier.Predicate() { - @Override - public boolean test(JsonWebToken t) throws VerificationException { - if (!t.isActive()) { - throw new TokenNotActiveException(t, "Token is not active"); - } - - return true; - } - }; - } diff --git a/backend/src/main/java/com/bakdata/conquery/models/config/auth/JwtPkceVerifyingRealmFactory.java b/backend/src/main/java/com/bakdata/conquery/models/config/auth/JwtPkceVerifyingRealmFactory.java index 024f142b2d..b496d0c74a 100644 --- a/backend/src/main/java/com/bakdata/conquery/models/config/auth/JwtPkceVerifyingRealmFactory.java +++ b/backend/src/main/java/com/bakdata/conquery/models/config/auth/JwtPkceVerifyingRealmFactory.java @@ -1,20 +1,51 @@ package com.bakdata.conquery.models.config.auth; -import com.bakdata.conquery.apiv1.RequestAwareUriBuilder; +import java.io.IOException; +import java.net.URI; +import java.security.PublicKey; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Date; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.UUID; +import java.util.function.BiFunction; +import java.util.function.Supplier; +import java.util.stream.Collectors; + +import javax.validation.constraints.Min; +import javax.validation.constraints.NotEmpty; +import javax.validation.constraints.NotNull; +import javax.ws.rs.client.Client; +import javax.ws.rs.container.ContainerRequestContext; +import javax.ws.rs.core.Cookie; +import javax.ws.rs.core.HttpHeaders; +import javax.ws.rs.core.MediaType; +import javax.ws.rs.core.NewCookie; +import javax.ws.rs.core.Response; +import javax.ws.rs.core.UriBuilder; + import com.bakdata.conquery.apiv1.RequestHelper; import com.bakdata.conquery.commands.ManagerNode; import com.bakdata.conquery.io.cps.CPSType; import com.bakdata.conquery.io.jackson.Jackson; import com.bakdata.conquery.models.auth.ConqueryAuthenticationRealm; import com.bakdata.conquery.models.auth.oidc.JwtPkceVerifyingRealm; -import com.bakdata.conquery.models.auth.web.AuthCookieFilter; import com.bakdata.conquery.models.auth.web.RedirectingAuthFilter; import com.bakdata.conquery.resources.admin.AdminServlet; import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.JsonNode; import com.nimbusds.jwt.JWTParser; -import com.nimbusds.oauth2.sdk.*; +import com.nimbusds.oauth2.sdk.AccessTokenResponse; +import com.nimbusds.oauth2.sdk.AuthorizationCode; +import com.nimbusds.oauth2.sdk.AuthorizationCodeGrant; +import com.nimbusds.oauth2.sdk.AuthorizationGrant; +import com.nimbusds.oauth2.sdk.ParseException; +import com.nimbusds.oauth2.sdk.RefreshTokenGrant; +import com.nimbusds.oauth2.sdk.TokenRequest; +import com.nimbusds.oauth2.sdk.TokenResponse; import com.nimbusds.oauth2.sdk.http.HTTPResponse; import com.nimbusds.oauth2.sdk.id.ClientID; import com.nimbusds.oauth2.sdk.token.RefreshToken; @@ -22,7 +53,10 @@ import groovy.lang.GroovyShell; import groovy.lang.Script; import io.dropwizard.validation.ValidationMethod; -import lombok.*; +import lombok.Data; +import lombok.NoArgsConstructor; +import lombok.NonNull; +import lombok.SneakyThrows; import lombok.extern.slf4j.Slf4j; import org.codehaus.groovy.control.CompilerConfiguration; import org.codehaus.groovy.control.customizers.ImportCustomizer; @@ -33,21 +67,6 @@ import org.keycloak.jose.jwk.JWKParser; import org.keycloak.representations.AccessToken; -import javax.validation.constraints.Min; -import javax.validation.constraints.NotEmpty; -import javax.validation.constraints.NotNull; -import javax.ws.rs.client.Client; -import javax.ws.rs.container.ContainerRequestContext; -import javax.ws.rs.core.*; -import javax.ws.rs.core.Response; - -import java.io.IOException; -import java.net.URI; -import java.security.PublicKey; -import java.util.*; -import java.util.function.BiFunction; -import java.util.function.Supplier; - /** * A realm that verifies oauth tokens using PKCE. * @@ -85,7 +104,7 @@ public class JwtPkceVerifyingRealmFactory implements AuthenticationRealmFactory /** * See wellKnownEndpoint. */ - private IdpConfiguration idpConfiguration; + private volatile IdpConfiguration idpConfiguration; /** * A leeway for token's expiration in seconds, this should be a short time. @@ -119,25 +138,15 @@ public boolean isConfigurationAvailable() { return wellKnownEndpoint != null || idpConfiguration != null; } - @AllArgsConstructor - @Getter - public static class IdpConfiguration { - - /** - * The public key information that is used to validate signed JWT. - * It can be retrieved from the IDP. - */ - @NonNull - private final PublicKey publicKey; - - @NonNull - private final URI authorizationEndpoint; - - @NonNull - private final URI tokenEndpoint; - - @NotEmpty - private final String issuer; + /** + * @param signingKeys The public key information that is used to validate signed JWT. + * It can be retrieved from the IDP. + */ + public record IdpConfiguration( + @NonNull Map signingKeys, + @NonNull URI authorizationEndpoint, + @NonNull URI tokenEndpoint, + @NotEmpty String issuer) { } public ConqueryAuthenticationRealm createRealm(ManagerNode manager) { @@ -230,14 +239,19 @@ private IdpConfiguration retrieveIdpConfiguration(final Client client) { } - final List keys = jwks.getKeys(); - if (keys.size() != 1) { - throw new IllegalStateException("Expected exactly 1 jwk for realm but found: " + keys.size()); - } + // Filter for keys that are used for signing (discard encryption keys) + final Map signingKeys = jwks.getKeys().stream() + .filter(jwk -> JWK.Use.SIG.name().equals(jwk.getPublicKeyUse())) + .collect(Collectors.toMap(JWK::getKeyId, JwtPkceVerifyingRealmFactory::getPublicKey)); - JWK jwk = keys.get(0); - return new IdpConfiguration(getPublicKey(jwk), authorizationEndpoint, tokenEndpoint, issuer); + if (signingKeys.isEmpty()) { + throw new IllegalStateException("No signing keys could be retrieved from IDP. Received these JWKs (Key Ids):" + jwks.getKeys() + .stream() + .map(JWK::getKeyId)); + } + + return new IdpConfiguration(signingKeys, authorizationEndpoint, tokenEndpoint, issuer); } @@ -290,7 +304,7 @@ private URI initiateLogin(ContainerRequestContext request) { return null; } JwtPkceVerifyingRealmFactory.IdpConfiguration idpConfiguration = idpConfigurationOpt.get(); - return UriBuilder.fromUri(idpConfiguration.getAuthorizationEndpoint()) + return UriBuilder.fromUri(idpConfiguration.authorizationEndpoint()) .queryParam("response_type", "code") .queryParam("client_id", client) .queryParam("redirect_uri", UriBuilder.fromUri(RequestHelper.getRequestURL(request)).path(AdminServlet.ADMIN_UI).build()) @@ -419,7 +433,7 @@ private AccessTokenResponse getTokenResponse(ContainerRequestContext request, Au // Send the auth code/refresh token to the IDP to redeem them for a new access and refresh token final TokenRequest tokenRequest = new TokenRequest( - UriBuilder.fromUri(idpConfiguration.getTokenEndpoint()).build(), + UriBuilder.fromUri(idpConfiguration.tokenEndpoint()).build(), new ClientID(client), authzGrant ); diff --git a/backend/src/test/java/com/bakdata/conquery/models/auth/oidc/JwtPkceVerifyingRealmTest.java b/backend/src/test/java/com/bakdata/conquery/models/auth/oidc/JwtPkceVerifyingRealmTest.java index a475fc4818..6b544c3e2d 100644 --- a/backend/src/test/java/com/bakdata/conquery/models/auth/oidc/JwtPkceVerifyingRealmTest.java +++ b/backend/src/test/java/com/bakdata/conquery/models/auth/oidc/JwtPkceVerifyingRealmTest.java @@ -11,6 +11,7 @@ import java.security.interfaces.RSAPublicKey; import java.util.Date; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.UUID; @@ -23,6 +24,7 @@ import com.bakdata.conquery.util.NonPersistentStoreFactory; import org.apache.commons.lang3.time.DateUtils; import org.apache.shiro.authc.BearerToken; +import org.apache.shiro.authc.pam.UnsupportedTokenException; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInstance; @@ -38,6 +40,7 @@ class JwtPkceVerifyingRealmTest { private static final String ALTERNATIVE_ID_CLAIM = "alternativeId"; public static final int TOKEN_LEEWAY = 60; private static JwtPkceVerifyingRealm REALM; + private static final String KEY_ID = "valid_key_id"; private static RSAPrivateKey PRIVATE_KEY; private static RSAPublicKey PUBLIC_KEY; @@ -53,7 +56,7 @@ static void setup() throws NoSuchAlgorithmException { // Create the realm REALM = new JwtPkceVerifyingRealm( - () -> Optional.of(new JwtPkceVerifyingRealmFactory.IdpConfiguration(PUBLIC_KEY, URI.create("auth"), URI.create("token"), HTTP_REALM_URL)), + () -> Optional.of(new JwtPkceVerifyingRealmFactory.IdpConfiguration(Map.of(KEY_ID, PUBLIC_KEY), URI.create("auth"), URI.create("token"), HTTP_REALM_URL)), AUDIENCE, List.of(JwtPkceVerifyingRealmFactory.ScriptedTokenChecker.create("t.getOtherClaims().get(\"groups\").equals(\"conquery\")")), List.of(ALTERNATIVE_ID_CLAIM), @@ -81,6 +84,7 @@ void verifyToken() { .withClaim("groups", "conquery") .withIssuedAt(issueDate) .withExpiresAt(expDate) + .withKeyId(KEY_ID) .sign(Algorithm.RSA256(PUBLIC_KEY, PRIVATE_KEY)); BearerToken accessToken = new BearerToken(token); @@ -108,6 +112,7 @@ void verifyTokenInLeeway() { .withClaim("groups", "conquery") .withIssuedAt(issueDate) .withExpiresAt(expDate) + .withKeyId(KEY_ID) .sign(Algorithm.RSA256(PUBLIC_KEY, PRIVATE_KEY)); BearerToken accessToken = new BearerToken(token); @@ -133,6 +138,7 @@ void verifyTokenAlternativeId() { .withIssuedAt(issueDate) .withExpiresAt(expDate) .withClaim(ALTERNATIVE_ID_CLAIM, expected.getName()) + .withKeyId(KEY_ID) .sign(Algorithm.RSA256(PUBLIC_KEY, PRIVATE_KEY)); BearerToken accessToken = new BearerToken(token); @@ -154,6 +160,7 @@ void falsifyTokenMissingCustomClaim() { .withSubject(expected.getName()) .withIssuedAt(issueDate) .withExpiresAt(expDate) + .withKeyId(KEY_ID) .sign(Algorithm.RSA256(PUBLIC_KEY, PRIVATE_KEY)); BearerToken accessToken = new BearerToken(token); @@ -175,6 +182,7 @@ void falsifyTokenWrongAudience() { .withClaim("groups", "conquery") .withIssuedAt(issueDate) .withExpiresAt(expDate) + .withKeyId(KEY_ID) .sign(Algorithm.RSA256(PUBLIC_KEY, PRIVATE_KEY)); BearerToken accessToken = new BearerToken(token); @@ -194,6 +202,7 @@ void falsifyTokenOutdated() { .withClaim("groups", "conquery") .withIssuedAt(issueDate) .withExpiresAt(expDate) + .withKeyId(KEY_ID) .sign(Algorithm.RSA256(PUBLIC_KEY, PRIVATE_KEY)); BearerToken accessToken = new BearerToken(token); @@ -217,9 +226,35 @@ void falsifyTokenWrongIssuer() { .withClaim("groups", "conquery") .withIssuedAt(issueDate) .withExpiresAt(expDate) + .withKeyId(KEY_ID) .sign(Algorithm.RSA256(PUBLIC_KEY, PRIVATE_KEY)); BearerToken accessToken = new BearerToken(token); assertThatCode(() -> REALM.doGetAuthenticationInfo(accessToken)).hasCauseInstanceOf(VerificationException.class); } + + @Test + void falsifyTokenUnknownKid() { + + // Setup the expected user id + User expected = new User("Test", "Test", STORAGE); + STORAGE.updateUser(expected); + + Date issueDate = new Date(); + Date expDate = DateUtils.addMinutes(issueDate, 1); + String token = JWT.create() + .withIssuer(HTTP_REALM_URL) + .withAudience(AUDIENCE) + .withSubject(expected.getName()) + .withIssuedAt(issueDate) + .withExpiresAt(expDate) + .withClaim("groups", "conquery") + .withIssuedAt(issueDate) + .withExpiresAt(expDate) + .withKeyId("unknown_key_id") + .sign(Algorithm.RSA256(PUBLIC_KEY, PRIVATE_KEY)); + BearerToken accessToken = new BearerToken(token); + + assertThatCode(() -> REALM.doGetAuthenticationInfo(accessToken)).isInstanceOf(UnsupportedTokenException.class); + } } \ No newline at end of file