Skip to content

Commit

Permalink
Merge pull request #15729 from Awk34/patch-1
Browse files Browse the repository at this point in the history
Add HTTP Basic Auth option to OAuthHandler
  • Loading branch information
itsankit-google authored Jan 10, 2025
2 parents 9e317c1 + 82c4e26 commit c7478d1
Show file tree
Hide file tree
Showing 4 changed files with 151 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,18 @@ public class OAuthProvider {
private final String tokenRefreshURL;
@Nullable
private final OAuthClientCredentials clientCreds;
private final CredentialEncodingStrategy strategy;

public OAuthProvider(String name,
String loginURL,
String tokenRefreshURL,
@Nullable OAuthClientCredentials clientCreds) {
@Nullable OAuthClientCredentials clientCreds,
@Nullable CredentialEncodingStrategy strategy) {
this.name = name;
this.loginURL = loginURL;
this.tokenRefreshURL = tokenRefreshURL;
this.clientCreds = clientCreds;
this.strategy = strategy;
}

public String getName() {
Expand All @@ -54,6 +57,17 @@ public OAuthClientCredentials getClientCredentials() {
return clientCreds;
}

public CredentialEncodingStrategy getCredentialEncodingStrategy() {
return strategy;
}

public enum CredentialEncodingStrategy {
// (default) Sends client ID & secret as part of the POST request body
FORM_BODY,
// Sends client ID & secret as part of a HTTP Basic Auth header
BASIC_AUTH,
}

public static Builder newBuilder() {
return new Builder();
}
Expand All @@ -66,6 +80,7 @@ public static class Builder {
private String loginURL;
private String tokenRefreshURL;
private OAuthClientCredentials clientCreds;
private CredentialEncodingStrategy strategy;

public Builder() {}

Expand All @@ -89,11 +104,20 @@ public Builder withClientCredentials(@Nullable OAuthClientCredentials clientCred
return this;
}

public Builder withCredentialEncodingStrategy(@Nullable CredentialEncodingStrategy strategy) {
this.strategy = strategy;
return this;
}

public OAuthProvider build() {
Preconditions.checkNotNull(name, "OAuth provider name missing");
Preconditions.checkNotNull(loginURL, "Login URL missing");
Preconditions.checkNotNull(tokenRefreshURL, "Token refresh URL missing");
return new OAuthProvider(name, loginURL, tokenRefreshURL, clientCreds);
// Default to FORM_BODY strategy
if (strategy == null) {
this.strategy = CredentialEncodingStrategy.FORM_BODY;
}
return new OAuthProvider(name, loginURL, tokenRefreshURL, clientCreds, strategy);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,23 @@ public class PutOAuthProviderRequest {
private final String tokenRefreshURL;
private final String clientId;
private final String clientSecret;
private final OAuthProvider.CredentialEncodingStrategy strategy;

public PutOAuthProviderRequest(String loginURL, String tokenRefreshURL, String clientId, String clientSecret) {
this(loginURL, tokenRefreshURL, clientId, clientSecret, OAuthProvider.CredentialEncodingStrategy.FORM_BODY);
}

public PutOAuthProviderRequest(
String loginURL,
String tokenRefreshURL,
String clientId,
String clientSecret,
OAuthProvider.CredentialEncodingStrategy strategy) {
this.loginURL = loginURL;
this.tokenRefreshURL = tokenRefreshURL;
this.clientId = clientId;
this.clientSecret = clientSecret;
this.strategy = strategy;
}

public String getLoginURL() {
Expand All @@ -48,4 +59,8 @@ public String getClientId() {
public String getClientSecret() {
return clientSecret;
}

public OAuthProvider.CredentialEncodingStrategy getCredentialEncodingStrategy() {
return strategy;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import io.cdap.cdap.datapipeline.oauth.GetAccessTokenResponse;
import io.cdap.cdap.datapipeline.oauth.OAuthClientCredentials;
import io.cdap.cdap.datapipeline.oauth.OAuthProvider;
import io.cdap.cdap.datapipeline.oauth.OAuthProvider.CredentialEncodingStrategy;
import io.cdap.cdap.datapipeline.oauth.OAuthRefreshToken;
import io.cdap.cdap.datapipeline.oauth.OAuthStore;
import io.cdap.cdap.datapipeline.oauth.OAuthStoreException;
Expand All @@ -43,6 +44,7 @@
import java.net.MalformedURLException;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.Optional;
import javax.ws.rs.DefaultValue;
import javax.ws.rs.GET;
Expand Down Expand Up @@ -115,6 +117,7 @@ public void putOAuthProvider(HttpServiceRequest request, HttpServiceResponder re
PutOAuthProviderRequest putOAuthProviderRequest = GSON.fromJson(
StandardCharsets.UTF_8.decode(request.getContent()).toString(),
PutOAuthProviderRequest.class);
CredentialEncodingStrategy strategy = putOAuthProviderRequest.getCredentialEncodingStrategy();
// Validate URLs
URL loginURL = new URL(putOAuthProviderRequest.getLoginURL());
URL tokenRefreshURL = new URL(putOAuthProviderRequest.getTokenRefreshURL());
Expand All @@ -132,6 +135,7 @@ public void putOAuthProvider(HttpServiceRequest request, HttpServiceResponder re
.withLoginURL(loginURL.toString())
.withTokenRefreshURL(tokenRefreshURL.toString())
.withClientCredentials(clientCredentials)
.withCredentialEncodingStrategy(strategy)
.build();
oauthStore.writeProvider(provider, reuseClientCredentials);
responder.sendStatus(HttpURLConnection.HTTP_OK);
Expand Down Expand Up @@ -190,7 +194,7 @@ public void putOAuthCredential(HttpServiceRequest request, HttpServiceResponder
+ response.getResponseCode()
+ " , response message: "
+ response.getResponseMessage()
+ " , respone body: "
+ " , response body: "
+ response.getResponseBodyAsString());
}

Expand Down Expand Up @@ -248,7 +252,7 @@ public void getOAuthCredential(HttpServiceRequest request, HttpServiceResponder
+ response.getResponseCode()
+ " , response message: "
+ response.getResponseMessage()
+ " , respone body: "
+ " , response body: "
+ response.getResponseBodyAsString());
}

Expand Down Expand Up @@ -307,37 +311,102 @@ private boolean checkCredIsValid(HttpResponse response) throws OAuthServiceExcep
return !(refreshTokenResponse.getAccessToken() == null || refreshTokenResponse.getAccessToken().isEmpty());
}

/**
* Create the request body for refresh token & access token requests
* @param strategy which encoding strategy is used to send client ID + secret
* @param grantType whether an authorization code used to fetch a refresh token or a refresh token used to fetch an
* access token is used
* @param code used when building a request to get a refresh token
* @param redirectURI used when building a request to get an access token
* @param refreshToken used when building a request to get an access token
* @param clientCreds the client ID + secret
* @return request body
*/
private String buildRequestBody(CredentialEncodingStrategy strategy,
String grantType,
String code,
String redirectURI,
String refreshToken,
OAuthClientCredentials clientCreds) {
switch (strategy) {
case BASIC_AUTH:
return grantType.equals("authorization_code")
? String.format("code=%s&redirect_uri=%s&grant_type=%s", code, redirectURI, grantType)
: String.format("grant_type=%s&refresh_token=%s", grantType, refreshToken);
case FORM_BODY: // fall-through
default:
return grantType.equals("authorization_code")
? String.format("code=%s&redirect_uri=%s&client_id=%s&client_secret=%s&grant_type=%s",
code, redirectURI, clientCreds.getClientId(), clientCreds.getClientSecret(), grantType)
: String.format("grant_type=%s&client_id=%s&client_secret=%s&refresh_token=%s",
grantType, clientCreds.getClientId(), clientCreds.getClientSecret(), refreshToken);
}
}

/** Build HTTP request for getting tokens */
private HttpRequest.Builder buildHttpRequest(String body,
CredentialEncodingStrategy strategy,
OAuthClientCredentials clientCreds,
String refreshTokenURL,
boolean addContentType) throws MalformedURLException {
HttpRequest.Builder requestBuilder = HttpRequest.post(new URL(refreshTokenURL))
.withBody(body);

if (addContentType) {
requestBuilder.addHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_FORM_URLENCODED);
}

if (strategy == CredentialEncodingStrategy.BASIC_AUTH) {
requestBuilder.addHeader(HttpHeaders.AUTHORIZATION, getBasicAuthHeader(clientCreds));
}

return requestBuilder;
}

/**
* Build the HttpRequest to request a refresh token from the OAuth provider
* @param provider
* @param code the authorization code given after the user accepts OAuth from the provider
* @param redirectURI
*/
private HttpRequest createGetRefreshTokenRequest(OAuthProvider provider, String code, String redirectURI)
throws OAuthServiceException {
OAuthClientCredentials clientCreds = provider.getClientCredentials();
CredentialEncodingStrategy strategy = provider.getCredentialEncodingStrategy();
String tokenRefreshURL = provider.getTokenRefreshURL();
String body = buildRequestBody(strategy, "authorization_code", code, redirectURI, null, clientCreds);

try {
return HttpRequest.post(new URL(provider.getTokenRefreshURL()))
.withBody(String.format(
"code=%s&redirect_uri=%s&client_id=%s&client_secret=%s&grant_type=authorization_code",
code, redirectURI, clientCreds.getClientId(), clientCreds.getClientSecret()))
.addHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_FORM_URLENCODED)
.build();
return buildHttpRequest(body, strategy, clientCreds, tokenRefreshURL, false).build();
} catch (MalformedURLException e) {
throw new OAuthServiceException(HttpURLConnection.HTTP_INTERNAL_ERROR, "Malformed URL", e);
}
}

/**
* Build the HttpRequest to request an access token for making data requests from the OAuth provider
* @param provider
* @param refreshToken the refresh token requested previously from the provider
*/
private HttpRequest createGetAccessTokenRequest(OAuthProvider provider, String refreshToken)
throws OAuthServiceException {
OAuthClientCredentials clientCreds = provider.getClientCredentials();
CredentialEncodingStrategy strategy = provider.getCredentialEncodingStrategy();
String tokenRefreshURL = provider.getTokenRefreshURL();
String body = buildRequestBody(strategy, "refresh_token", null, null, refreshToken, clientCreds);

try {
return HttpRequest.post(new URL(provider.getTokenRefreshURL()))
.withBody(
String.format("grant_type=refresh_token&client_id=%s&client_secret=%s&refresh_token=%s",
clientCreds.getClientId(),
clientCreds.getClientSecret(),
refreshToken))
.build();
return buildHttpRequest(body, strategy, clientCreds, tokenRefreshURL, true).build();
} catch (MalformedURLException e) {
throw new OAuthServiceException(HttpURLConnection.HTTP_INTERNAL_ERROR, "Malformed URL", e);
}
}

private String getBasicAuthHeader(OAuthClientCredentials clientCreds) {
String authInfo = String.format("%s:%s", clientCreds.getClientId(), clientCreds.getClientSecret());
return String.format("Basic %s", Base64.getEncoder().encode(authInfo.getBytes()));
}

private OAuthProvider getProvider(String provider) throws OAuthServiceException {
try {
Optional<OAuthProvider> providerOptional = oauthStore.getProvider(provider);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,13 @@

package io.cdap.cdap.datapipeline;

import com.google.common.collect.Multimap;
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import io.cdap.cdap.common.http.DefaultHttpRequestConfig;
import io.cdap.cdap.datapipeline.oauth.OAuthProvider;
import io.cdap.cdap.datapipeline.oauth.PutOAuthProviderRequest;
import io.cdap.cdap.datapipeline.oauth.PutOAuthCredentialRequest;
import io.cdap.common.http.HttpMethod;
import io.cdap.common.http.HttpRequest;
import io.cdap.common.http.HttpRequests;
Expand Down Expand Up @@ -104,6 +107,29 @@ public void testCreateProviderWithReuseClientCredentialsFalse() throws IOExcepti
Assert.assertEquals(400, createResponse.getResponseCode());
}

@Test
public void testCreateProviderWithBasicAuth() throws IOException {
// Attempt to create provider
String loginURL = "http://www.example.com/login31";
String tokenRefreshURL = "http://www.example.com/token31";
String clientId = "clientid";
String clientSecret = "clientsecret";
PutOAuthProviderRequest request = new PutOAuthProviderRequest(
loginURL,
tokenRefreshURL,
clientId,
clientSecret,
OAuthProvider.CredentialEncodingStrategy.BASIC_AUTH);
HttpResponse createOauthProviderResponse = makePutCall("provider/testprovider31", request);
Assert.assertEquals(200, createOauthProviderResponse.getResponseCode());

// Grab OAuth login URL to verify write succeeded
HttpResponse getAuthUrlResponse = makeGetCall("provider/testprovider31/authurl");
Assert.assertEquals(200, getAuthUrlResponse.getResponseCode());
String authURL = getAuthUrlResponse.getResponseBodyAsString();
Assert.assertEquals("http://www.example.com/login31?client_id=clientid&redirect_uri=null", authURL);
}

@Test
public void testGetAuthURLForMissingClientCredentials() throws IOException {
// Attempt to create provider with missing client credentials and 'reuse_client_credentials'
Expand Down

0 comments on commit c7478d1

Please sign in to comment.