Skip to content

Commit

Permalink
feat: Introduce a way to pass additional parameters to auhtorization …
Browse files Browse the repository at this point in the history
…url (#1134)

* feat: Introduce a way to pass additional parameters to auhtorization url

* casing

* Add custom params to token endpoint

* minor updates

* modify test to check for persistence of additional params
  • Loading branch information
sai-sunder-s committed Jun 16, 2023
1 parent 5fa7039 commit 3a2c5d3
Show file tree
Hide file tree
Showing 3 changed files with 171 additions and 3 deletions.
43 changes: 43 additions & 0 deletions oauth2_http/java/com/google/auth/oauth2/UserAuthorizer.java
Expand Up @@ -50,6 +50,7 @@
import java.util.Collection;
import java.util.Date;
import java.util.List;
import java.util.Map;

/** Handles an interactive 3-Legged-OAuth2 (3LO) user consent authorization. */
public class UserAuthorizer {
Expand Down Expand Up @@ -168,6 +169,20 @@ public TokenStore getTokenStore() {
* @return The URL that can be navigated or redirected to.
*/
public URL getAuthorizationUrl(String userId, String state, URI baseUri) {
return this.getAuthorizationUrl(userId, state, baseUri, null);
}

/**
* Return an URL that performs the authorization consent prompt web UI.
*
* @param userId Application's identifier for the end user.
* @param state State that is passed on to the OAuth2 callback URI after the consent.
* @param baseUri The URI to resolve the OAuth2 callback URI relative to.
* @param additionalParameters Additional query parameters to be added to the authorization URL.
* @return The URL that can be navigated or redirected to.
*/
public URL getAuthorizationUrl(
String userId, String state, URI baseUri, Map<String, String> additionalParameters) {
URI resolvedCallbackUri = getCallbackUri(baseUri);
String scopesString = Joiner.on(' ').join(scopes);

Expand All @@ -185,6 +200,13 @@ public URL getAuthorizationUrl(String userId, String state, URI baseUri) {
url.put("login_hint", userId);
}
url.put("include_granted_scopes", true);

if (additionalParameters != null) {
for (Map.Entry<String, String> entry : additionalParameters.entrySet()) {
url.put(entry.getKey(), entry.getValue());
}
}

if (pkce != null) {
url.put("code_challenge", pkce.getCodeChallenge());
url.put("code_challenge_method", pkce.getCodeChallengeMethod());
Expand Down Expand Up @@ -247,6 +269,21 @@ public UserCredentials getCredentials(String userId) throws IOException {
* @throws IOException An error from the server API call to get the tokens.
*/
public UserCredentials getCredentialsFromCode(String code, URI baseUri) throws IOException {
return getCredentialsFromCode(code, baseUri, null);
}

/**
* Returns a UserCredentials instance by exchanging an OAuth2 authorization code for tokens.
*
* @param code Code returned from OAuth2 consent prompt.
* @param baseUri The URI to resolve the OAuth2 callback URI relative to.
* @param additionalParameters Additional parameters to be added to the post body of token
* endpoint request.
* @return the UserCredentials instance created from the authorization code.
* @throws IOException An error from the server API call to get the tokens.
*/
public UserCredentials getCredentialsFromCode(
String code, URI baseUri, Map<String, String> additionalParameters) throws IOException {
Preconditions.checkNotNull(code);
URI resolvedCallbackUri = getCallbackUri(baseUri);

Expand All @@ -257,6 +294,12 @@ public UserCredentials getCredentialsFromCode(String code, URI baseUri) throws I
tokenData.put("redirect_uri", resolvedCallbackUri);
tokenData.put("grant_type", "authorization_code");

if (additionalParameters != null) {
for (Map.Entry<String, String> entry : additionalParameters.entrySet()) {
tokenData.put(entry.getKey(), entry.getValue());
}
}

if (pkce != null) {
tokenData.put("code_verifier", pkce.getCodeVerifier());
}
Expand Down
Expand Up @@ -65,6 +65,8 @@ public class MockTokenServerTransport extends MockHttpTransport {
final Map<String, String> serviceAccounts = new HashMap<String, String>();
final Map<String, String> gdchServiceAccounts = new HashMap<String, String>();
final Map<String, String> codes = new HashMap<String, String>();
final Map<String, Map<String, String>> additionalParameters =
new HashMap<String, Map<String, String>>();
URI tokenServerUri = OAuth2Utils.TOKEN_SERVER_URI;
private IOException error;
private final Queue<Future<LowLevelHttpResponse>> responseSequence = new ArrayDeque<>();
Expand All @@ -81,10 +83,18 @@ public void setTokenServerUri(URI tokenServerUri) {
}

public void addAuthorizationCode(
String code, String refreshToken, String accessToken, String grantedScopes) {
String code,
String refreshToken,
String accessToken,
String grantedScopes,
Map<String, String> additionalParameters) {
codes.put(code, refreshToken);
refreshTokens.put(refreshToken, accessToken);
this.grantedScopes.put(refreshToken, grantedScopes);

if (additionalParameters != null) {
this.additionalParameters.put(refreshToken, additionalParameters);
}
}

public void addClient(String clientId, String clientSecret) {
Expand Down Expand Up @@ -220,6 +230,29 @@ public LowLevelHttpResponse execute() throws IOException {
if (grantedScopes.containsKey(refreshToken)) {
grantedScopesString = grantedScopes.get(refreshToken);
}

if (additionalParameters.containsKey(refreshToken)) {
Map<String, String> additionalParametersMap = additionalParameters.get(refreshToken);
for (Map.Entry<String, String> entry : additionalParametersMap.entrySet()) {
String key = entry.getKey();
String expectedValue = entry.getValue();
if (!query.containsKey(key)) {
throw new IllegalArgumentException("Missing additional parameter: " + key);
} else {
String actualValue = query.get(key);
if (!expectedValue.equals(actualValue)) {
throw new IllegalArgumentException(
"For additional parameter "
+ key
+ ", Actual value: "
+ actualValue
+ ", Expected value: "
+ expectedValue);
}
}
}
}

} else if (query.containsKey("grant_type")) {
String grantType = query.get("grant_type");
String assertion = query.get("assertion");
Expand Down
Expand Up @@ -32,6 +32,7 @@
package com.google.auth.oauth2;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.fail;
Expand All @@ -43,6 +44,7 @@
import java.net.URL;
import java.util.Arrays;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.junit.Test;
Expand Down Expand Up @@ -170,6 +172,50 @@ public void getAuthorizationUrl() throws IOException {
assertEquals(pkce.getCodeChallengeMethod(), parameters.get("code_challenge_method"));
}

@Test
public void getAuthorizationUrl_additionalParameters() throws IOException {
final String CUSTOM_STATE = "custom_state";
final String PROTOCOL = "https";
final String HOST = "accounts.test.com";
final String PATH = "/o/o/oauth2/auth";
final URI AUTH_URI = URI.create(PROTOCOL + "://" + HOST + PATH);
final String EXPECTED_CALLBACK = "http://example.com" + CALLBACK_URI.toString();
UserAuthorizer authorizer =
UserAuthorizer.newBuilder()
.setClientId(CLIENT_ID)
.setScopes(DUMMY_SCOPES)
.setCallbackUri(CALLBACK_URI)
.setUserAuthUri(AUTH_URI)
.build();
Map<String, String> additionalParameters = new HashMap<String, String>();
additionalParameters.put("param1", "value1");
additionalParameters.put("param2", "value2");

// Verify that the authorization URL doesn't include the additional parameters if they are not
// passed in.
URL authorizationUrl = authorizer.getAuthorizationUrl(USER_ID, CUSTOM_STATE, BASE_URI);
String query = authorizationUrl.getQuery();
Map<String, String> parameters = TestUtils.parseQuery(query);
assertFalse(parameters.containsKey("param1"));
assertFalse(parameters.containsKey("param2"));

// Verify that the authorization URL includes the additional parameters if they are passed in.
authorizationUrl =
authorizer.getAuthorizationUrl(USER_ID, CUSTOM_STATE, BASE_URI, additionalParameters);
query = authorizationUrl.getQuery();
parameters = TestUtils.parseQuery(query);
assertEquals("value1", parameters.get("param1"));
assertEquals("value2", parameters.get("param2"));

// Verify that the authorization URL doesn't include the additional parameters passed in the
// previous call to the authorizer
authorizationUrl = authorizer.getAuthorizationUrl(USER_ID, CUSTOM_STATE, BASE_URI);
query = authorizationUrl.getQuery();
parameters = TestUtils.parseQuery(query);
assertFalse(parameters.containsKey("param1"));
assertFalse(parameters.containsKey("param2"));
}

@Test
public void getCredentials_noCredentials_returnsNull() throws IOException {
UserAuthorizer authorizer =
Expand Down Expand Up @@ -340,7 +386,41 @@ public void getCredentialsFromCode_conevertsCodeToTokens() throws IOException {
MockTokenServerTransportFactory transportFactory = new MockTokenServerTransportFactory();
transportFactory.transport.addClient(CLIENT_ID_VALUE, CLIENT_SECRET);
transportFactory.transport.addAuthorizationCode(
CODE, REFRESH_TOKEN, ACCESS_TOKEN_VALUE, GRANTED_SCOPES_STRING);
CODE, REFRESH_TOKEN, ACCESS_TOKEN_VALUE, GRANTED_SCOPES_STRING, null);
TokenStore tokenStore = new MemoryTokensStorage();
UserAuthorizer authorizer =
UserAuthorizer.newBuilder()
.setClientId(CLIENT_ID)
.setScopes(DUMMY_SCOPES)
.setTokenStore(tokenStore)
.setHttpTransportFactory(transportFactory)
.build();

UserCredentials credentials = authorizer.getCredentialsFromCode(CODE, BASE_URI);

assertEquals(REFRESH_TOKEN, credentials.getRefreshToken());
assertEquals(ACCESS_TOKEN_VALUE, credentials.getAccessToken().getTokenValue());
assertEquals(GRANTED_SCOPES, credentials.getAccessToken().getScopes());
}

@Test
public void getCredentialsFromCode_additionalParameters() throws IOException {
MockTokenServerTransportFactory transportFactory = new MockTokenServerTransportFactory();
transportFactory.transport.addClient(CLIENT_ID_VALUE, CLIENT_SECRET);

Map<String, String> additionalParameters = new HashMap<String, String>();
additionalParameters.put("param1", "value1");
additionalParameters.put("param2", "value2");

String code2 = "code2";
String refreshToken2 = "refreshToken2";
String accessTokenValue2 = "accessTokenValue2";

transportFactory.transport.addAuthorizationCode(
CODE, REFRESH_TOKEN, ACCESS_TOKEN_VALUE, GRANTED_SCOPES_STRING, null);
transportFactory.transport.addAuthorizationCode(
code2, refreshToken2, accessTokenValue2, GRANTED_SCOPES_STRING, additionalParameters);

TokenStore tokenStore = new MemoryTokensStorage();
UserAuthorizer authorizer =
UserAuthorizer.newBuilder()
Expand All @@ -350,8 +430,20 @@ public void getCredentialsFromCode_conevertsCodeToTokens() throws IOException {
.setHttpTransportFactory(transportFactory)
.build();

// Verify that the additional parameters are not attached to the post body when not specified
UserCredentials credentials = authorizer.getCredentialsFromCode(CODE, BASE_URI);
assertEquals(REFRESH_TOKEN, credentials.getRefreshToken());
assertEquals(ACCESS_TOKEN_VALUE, credentials.getAccessToken().getTokenValue());
assertEquals(GRANTED_SCOPES, credentials.getAccessToken().getScopes());

// Verify that the additional parameters are attached to the post body when specified
credentials = authorizer.getCredentialsFromCode(code2, BASE_URI, additionalParameters);
assertEquals(refreshToken2, credentials.getRefreshToken());
assertEquals(accessTokenValue2, credentials.getAccessToken().getTokenValue());
assertEquals(GRANTED_SCOPES, credentials.getAccessToken().getScopes());

// Verify that the additional parameters from previous request are not attached to the post body
credentials = authorizer.getCredentialsFromCode(CODE, BASE_URI);
assertEquals(REFRESH_TOKEN, credentials.getRefreshToken());
assertEquals(ACCESS_TOKEN_VALUE, credentials.getAccessToken().getTokenValue());
assertEquals(GRANTED_SCOPES, credentials.getAccessToken().getScopes());
Expand All @@ -376,7 +468,7 @@ public void getAndStoreCredentialsFromCode_getAndStoresCredentials() throws IOEx
MockTokenServerTransportFactory transportFactory = new MockTokenServerTransportFactory();
transportFactory.transport.addClient(CLIENT_ID_VALUE, CLIENT_SECRET);
transportFactory.transport.addAuthorizationCode(
CODE, REFRESH_TOKEN, accessTokenValue1, GRANTED_SCOPES_STRING);
CODE, REFRESH_TOKEN, accessTokenValue1, GRANTED_SCOPES_STRING, null);
TokenStore tokenStore = new MemoryTokensStorage();
UserAuthorizer authorizer =
UserAuthorizer.newBuilder()
Expand Down

0 comments on commit 3a2c5d3

Please sign in to comment.