diff --git a/cerberus-auth-connector-okta/src/main/java/com/nike/cerberus/auth/connector/okta/OktaAuthConnector.java b/cerberus-auth-connector-okta/src/main/java/com/nike/cerberus/auth/connector/okta/OktaAuthConnector.java index dbb07bee6..e5e0c18d0 100644 --- a/cerberus-auth-connector-okta/src/main/java/com/nike/cerberus/auth/connector/okta/OktaAuthConnector.java +++ b/cerberus-auth-connector-okta/src/main/java/com/nike/cerberus/auth/connector/okta/OktaAuthConnector.java @@ -38,7 +38,10 @@ import com.okta.sdk.authc.credentials.TokenClientCredentials; import com.okta.sdk.client.Client; import com.okta.sdk.client.Clients; +import com.okta.sdk.resource.ResourceException; +import com.okta.sdk.resource.group.Group; import com.okta.sdk.resource.group.GroupList; +import com.okta.sdk.resource.group.GroupProfile; import com.okta.sdk.resource.user.User; import java.util.HashSet; import java.util.Map; @@ -217,20 +220,60 @@ public AuthResponse mfaCheck(String stateToken, String deviceId, String otpToken } } + /** + * Get a valid user from the identity provider if possible + * + * @param userId + * @return User corresponding to the id + * @throws ApiException if user cannot be resolved + */ + protected User getUserFromIDP(String userId) { + try { + return sdkClient.getUser(userId); + } catch (IllegalStateException ise) { + throw ApiException.newBuilder() + .withExceptionCause(ise) + .withApiErrors(DefaultApiError.IDENTITY_PROVIDER_BAD_GATEWAY) + .withExceptionMessage("Could not communicate properly with identity provider") + .build(); + } catch (ResourceException rexc) { + String msg = + String.format("Got invalid response from identity providers: %s", rexc.getMessage()); + throw ApiException.newBuilder() + .withExceptionCause(rexc) + .withApiErrors(DefaultApiError.IDENTITY_PROVIDER_BAD_GATEWAY) + .withExceptionMessage(msg) + .build(); + } catch (Exception exc) { + throw ApiException.newBuilder() + .withExceptionCause(exc) + .withApiErrors(DefaultApiError.INTERNAL_SERVER_ERROR) + .withExceptionMessage("Unknown error trying to getUser from identity provider") + .build(); + } + } + /** Obtains groups user belongs to. */ @Override public Set getGroups(AuthData authData) { Preconditions.checkNotNull(authData, "auth data cannot be null."); - User user = sdkClient.getUser(authData.getUserId()); + String userId = authData.getUserId(); + User user = getUserFromIDP(userId); GroupList userGroups = user.listGroups(); final Set groups = new HashSet<>(); if (userGroups == null) { return groups; } - userGroups.forEach(group -> groups.add(group.getProfile().getName())); + + for (Group group : userGroups) { + GroupProfile profile = group.getProfile(); + if (profile != null) { + groups.add(profile.getName()); + } + } return groups; } diff --git a/cerberus-auth-connector-okta/src/test/java/com/nike/cerberus/auth/connector/okta/OktaAuthConnectorTest.java b/cerberus-auth-connector-okta/src/test/java/com/nike/cerberus/auth/connector/okta/OktaAuthConnectorTest.java index 8ff70e93b..ed42c9aed 100644 --- a/cerberus-auth-connector-okta/src/test/java/com/nike/cerberus/auth/connector/okta/OktaAuthConnectorTest.java +++ b/cerberus-auth-connector-okta/src/test/java/com/nike/cerberus/auth/connector/okta/OktaAuthConnectorTest.java @@ -23,20 +23,34 @@ import static org.mockito.Mockito.*; import static org.mockito.MockitoAnnotations.initMocks; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Lists; +import com.nike.backstopper.apierror.ApiError; import com.nike.backstopper.exception.ApiException; import com.nike.cerberus.auth.connector.AuthData; import com.nike.cerberus.auth.connector.AuthResponse; import com.nike.cerberus.auth.connector.AuthStatus; import com.nike.cerberus.auth.connector.okta.statehandlers.InitialLoginStateHandler; import com.nike.cerberus.auth.connector.okta.statehandlers.MfaStateHandler; +import com.nike.cerberus.error.DefaultApiError; import com.okta.authn.sdk.client.AuthenticationClient; import com.okta.authn.sdk.impl.resource.DefaultVerifyPassCodeFactorRequest; import com.okta.jwt.AccessTokenVerifier; import com.okta.jwt.Jwt; import com.okta.jwt.JwtVerificationException; import com.okta.sdk.client.Client; +import com.okta.sdk.impl.client.DefaultClient; +import com.okta.sdk.impl.error.DefaultError; +import com.okta.sdk.resource.ResourceException; +import com.okta.sdk.resource.group.Group; +import com.okta.sdk.resource.group.GroupList; +import com.okta.sdk.resource.group.GroupProfile; +import com.okta.sdk.resource.user.User; import java.util.HashMap; +import java.util.HashSet; +import java.util.List; import java.util.Map; +import java.util.Set; import org.junit.Before; import org.junit.Test; import org.mockito.Mock; @@ -327,4 +341,130 @@ public void testGetAccessTokenVerifier() { AccessTokenVerifier verifier = this.oktaAuthConnector.getAccessTokenVerifier(); assertNotNull(verifier); } + + @Test + public void testGetGroups() { + AccessTokenVerifier verifier = mock(AccessTokenVerifier.class); + + GroupProfile groupProfile = mock(GroupProfile.class); + when(groupProfile.getName()).thenReturn("testGroup"); + + Group fakeGroup = mock(Group.class); + when(fakeGroup.getProfile()).thenReturn(groupProfile); + + List groupIteraterList = Lists.newArrayList(fakeGroup); + GroupList groupList = mock(GroupList.class); + when(groupList.iterator()).thenReturn(groupIteraterList.iterator()); + + User mockUser = mock(User.class); + when(mockUser.listGroups()).thenReturn(groupList); + + DefaultClient mockClient = mock(DefaultClient.class); + when(mockClient.getUser(anyString())).thenReturn(mockUser); + + OktaAuthConnector connector = + new OktaAuthConnector( + client, mockClient, "https://foo.bar/oauth2/skiddleydee", "dogs", verifier); + AuthData authData = AuthData.builder().userId("deadbeef").build(); + Set groups = connector.getGroups(authData); + assertEquals(groups, Set.of("testGroup")); + } + + @Test + public void testGetGroupsMissingProfile() { + AccessTokenVerifier verifier = mock(AccessTokenVerifier.class); + + Group fakeGroup = mock(Group.class); + when(fakeGroup.getProfile()).thenReturn(null); + + List groupIteraterList = Lists.newArrayList(fakeGroup); + GroupList groupList = mock(GroupList.class); + when(groupList.iterator()).thenReturn(groupIteraterList.iterator()); + + User mockUser = mock(User.class); + when(mockUser.listGroups()).thenReturn(groupList); + + DefaultClient mockClient = mock(DefaultClient.class); + when(mockClient.getUser(anyString())).thenReturn(mockUser); + + OktaAuthConnector connector = + new OktaAuthConnector( + client, mockClient, "https://foo.bar/oauth2/skiddleydee", "dogs", verifier); + AuthData authData = AuthData.builder().userId("deadbeef").build(); + Set groups = connector.getGroups(authData); + assertEquals(groups, new HashSet()); + } + + @Test + public void testGetGroupsNullGroups() { + AccessTokenVerifier verifier = mock(AccessTokenVerifier.class); + + User mockUser = mock(User.class); + when(mockUser.listGroups()).thenReturn(null); + + DefaultClient mockClient = mock(DefaultClient.class); + when(mockClient.getUser(anyString())).thenReturn(mockUser); + + OktaAuthConnector connector = + new OktaAuthConnector( + client, mockClient, "https://foo.bar/oauth2/skiddleydee", "dogs", verifier); + + AuthData authData = AuthData.builder().userId("deadbeef").build(); + Set groups = connector.getGroups(authData); + assertEquals(groups, new HashSet<>()); + } + + @Test(expected = ApiException.class) + public void testBadGetUser() { + AccessTokenVerifier verifier = mock(AccessTokenVerifier.class); + Client mockClient = mock(Client.class); + when(mockClient.getUser(anyString())).thenThrow(new RuntimeException("it's broke")); + OktaAuthConnector connector = + new OktaAuthConnector( + client, mockClient, "https://foo.bar/oauth2/skiddleydee", "dogs", verifier); + AuthData authData = AuthData.builder().userId("deadbeef").build(); + connector.getGroups(authData); + } + + @Test + public void testGetUserFromIdpCompletelyBrokenOkta() { + AccessTokenVerifier verifier = mock(AccessTokenVerifier.class); + Client mockClient = mock(Client.class); + String exceptionMessage = "who knows what broke?"; + when(mockClient.getUser(anyString())).thenThrow(new IllegalStateException(exceptionMessage)); + OktaAuthConnector connector = + new OktaAuthConnector( + client, mockClient, "https://foo.bar/oauth2/skiddleydee", "dogs", verifier); + try { + connector.getUserFromIDP("fooUser"); + } catch (ApiException exc) { + String actualMessage = exc.getMessage(); + assertEquals(actualMessage, "Could not communicate properly with identity provider"); + ApiError apiError = exc.getApiErrors().get(0); + assertEquals(apiError, DefaultApiError.IDENTITY_PROVIDER_BAD_GATEWAY); + String causeMessage = exc.getCause().getMessage(); + assertEquals(causeMessage, exceptionMessage); + } + } + + @Test + public void testGetUserFromIdpOktaProblem() { + AccessTokenVerifier verifier = mock(AccessTokenVerifier.class); + Client mockClient = mock(Client.class); + String excMessage = "A specific thing had a problem"; + String excpetionPrefix = "Got invalid response from identity providers"; + ResourceException resourceException = + new ResourceException(new DefaultError(ImmutableMap.of("message", excMessage))); + when(mockClient.getUser(anyString())).thenThrow(resourceException); + OktaAuthConnector connector = + new OktaAuthConnector( + client, mockClient, "https://foo.bar/oauth2/skiddleydee", "dogs", verifier); + try { + connector.getUserFromIDP("fooUser"); + } catch (ApiException exc) { + String actualMessage = exc.getMessage(); + assert actualMessage.startsWith(excpetionPrefix); + assertEquals(exc.getApiErrors().get(0), DefaultApiError.IDENTITY_PROVIDER_BAD_GATEWAY); + } + } } diff --git a/cerberus-core/src/main/java/com/nike/cerberus/error/DefaultApiError.java b/cerberus-core/src/main/java/com/nike/cerberus/error/DefaultApiError.java index 98db2eda8..6edf9141d 100644 --- a/cerberus-core/src/main/java/com/nike/cerberus/error/DefaultApiError.java +++ b/cerberus-core/src/main/java/com/nike/cerberus/error/DefaultApiError.java @@ -291,6 +291,9 @@ public enum DefaultApiError implements ApiError { /** Generic bad requests. This is useful because the blueprint error handling sucks. */ GENERIC_BAD_REQUEST(99999, "Request will not be completed.", SC_BAD_REQUEST), + /** Bad response from identity provider */ + IDENTITY_PROVIDER_BAD_GATEWAY(99988, "Bad response from identity provider", SC_BAD_GATEWAY), + /** * If we encounter an error where something expected is not setup correctly, meaning the service * is not functional. diff --git a/cerberus-web/src/main/java/com/nike/cerberus/error/DefaultApiErrorsImpl.java b/cerberus-web/src/main/java/com/nike/cerberus/error/DefaultApiErrorsImpl.java index 39ef7f830..820a4c9ea 100644 --- a/cerberus-web/src/main/java/com/nike/cerberus/error/DefaultApiErrorsImpl.java +++ b/cerberus-web/src/main/java/com/nike/cerberus/error/DefaultApiErrorsImpl.java @@ -18,6 +18,7 @@ import static com.nike.backstopper.apierror.ApiErrorConstants.*; import static com.nike.backstopper.apierror.projectspecificinfo.ProjectSpecificErrorCodeRange.ALLOW_ALL_ERROR_CODES; +import static javax.servlet.http.HttpServletResponse.SC_BAD_GATEWAY; import static javax.servlet.http.HttpServletResponse.SC_NOT_IMPLEMENTED; import com.nike.backstopper.apierror.ApiError; @@ -34,6 +35,7 @@ public class DefaultApiErrorsImpl extends SampleProjectApiErrorsBase { Arrays.asList( HTTP_STATUS_CODE_FORBIDDEN, HTTP_STATUS_CODE_UNAUTHORIZED, + SC_BAD_GATEWAY, HTTP_STATUS_CODE_SERVICE_UNAVAILABLE, HTTP_STATUS_CODE_TOO_MANY_REQUESTS, HTTP_STATUS_CODE_INTERNAL_SERVER_ERROR,