Skip to content

Commit

Permalink
refactor: create TwitchHelixTokenManager
Browse files Browse the repository at this point in the history
  • Loading branch information
iProdigy committed Apr 21, 2022
1 parent 9b444ae commit a4321fe
Show file tree
Hide file tree
Showing 9 changed files with 208 additions and 151 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import io.github.bucket4j.Bandwidth;
import io.github.bucket4j.Bucket;
import io.github.bucket4j.local.LocalBucketBuilder;
import org.jetbrains.annotations.NotNull;

import java.util.concurrent.Callable;
import java.util.concurrent.CompletableFuture;
Expand All @@ -16,7 +17,8 @@ public class BucketUtils {
* @param limit the bandwidth
* @return the bucket
*/
public static Bucket createBucket(Bandwidth limit) {
@NotNull
public static Bucket createBucket(@NotNull Bandwidth limit) {
return Bucket.builder().addLimit(limit).build();
}

Expand All @@ -26,7 +28,8 @@ public static Bucket createBucket(Bandwidth limit) {
* @param limits the bandwidths
* @return the bucket
*/
public static Bucket createBucket(Bandwidth... limits) {
@NotNull
public static Bucket createBucket(@NotNull Bandwidth... limits) {
LocalBucketBuilder builder = Bucket.builder();
for (Bandwidth limit : limits) {
builder.addLimit(limit);
Expand All @@ -40,7 +43,8 @@ public static Bucket createBucket(Bandwidth... limits) {
* @param limits the bandwidths
* @return the bucket
*/
public static Bucket createBucket(Iterable<Bandwidth> limits) {
@NotNull
public static Bucket createBucket(@NotNull Iterable<Bandwidth> limits) {
LocalBucketBuilder builder = Bucket.builder();
for (Bandwidth limit : limits) {
builder.addLimit(limit);
Expand All @@ -58,7 +62,8 @@ public static Bucket createBucket(Iterable<Bandwidth> limits) {
* @param call task that requires a bucket point
* @return the future result of the call
*/
public static <T> CompletableFuture<T> scheduleAgainstBucket(Bucket bucket, ScheduledExecutorService executor, Callable<T> call) {
@NotNull
public static <T> CompletableFuture<T> scheduleAgainstBucket(@NotNull Bucket bucket, @NotNull ScheduledExecutorService executor, @NotNull Callable<T> call) {
if (bucket.tryConsume(1L))
return CompletableFuture.supplyAsync(new SneakySupplier<>(call));

Expand All @@ -75,7 +80,8 @@ public static <T> CompletableFuture<T> scheduleAgainstBucket(Bucket bucket, Sche
* @param action runnable that requires a bucket point
* @return a future to track completion progress
*/
public static CompletableFuture<Void> scheduleAgainstBucket(Bucket bucket, ScheduledExecutorService executor, Runnable action) {
@NotNull
public static CompletableFuture<Void> scheduleAgainstBucket(@NotNull Bucket bucket, @NotNull ScheduledExecutorService executor, @NotNull Runnable action) {
if (bucket.tryConsume(1L))
return CompletableFuture.runAsync(action);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import lombok.RequiredArgsConstructor;
import lombok.SneakyThrows;
import org.jetbrains.annotations.NotNull;

import java.util.concurrent.Callable;
import java.util.function.Supplier;
Expand All @@ -19,6 +20,7 @@ public final class SneakySupplier<T> implements Supplier<T> {
/**
* The action to compute the supplied value, possibly throwing an exception.
*/
@NotNull
private final Callable<T> callable;

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@
import com.github.twitch4j.common.util.TypeConvert;
import com.github.twitch4j.helix.domain.CustomReward;
import com.github.twitch4j.helix.interceptor.CustomRewardEncodeMixIn;
import com.github.twitch4j.helix.interceptor.TwitchHelixTokenManager;
import com.github.twitch4j.helix.interceptor.TwitchHelixClientIdInterceptor;
import com.github.twitch4j.helix.interceptor.TwitchHelixDecoder;
import com.github.twitch4j.helix.interceptor.TwitchHelixHttpClient;
import com.github.twitch4j.helix.interceptor.TwitchHelixRateLimitTracker;
import com.netflix.config.ConfigurationManager;
import feign.Logger;
import feign.Request;
Expand Down Expand Up @@ -167,15 +169,16 @@ public TwitchHelix build() {
apiRateLimit = DEFAULT_BANDWIDTH;

// Feign
TwitchHelixClientIdInterceptor interceptor = new TwitchHelixClientIdInterceptor(this);
TwitchHelixTokenManager tokenManager = new TwitchHelixTokenManager(clientId, clientSecret, defaultAuthToken);
TwitchHelixRateLimitTracker rateLimitTracker = new TwitchHelixRateLimitTracker(apiRateLimit, tokenManager);
return HystrixFeign.builder()
.client(new TwitchHelixHttpClient(new OkHttpClient(clientBuilder.build()), scheduledThreadPoolExecutor, interceptor, timeout))
.client(new TwitchHelixHttpClient(new OkHttpClient(clientBuilder.build()), scheduledThreadPoolExecutor, tokenManager, rateLimitTracker, timeout))
.encoder(new JacksonEncoder(serializer))
.decoder(new TwitchHelixDecoder(mapper, interceptor))
.decoder(new TwitchHelixDecoder(mapper, rateLimitTracker))
.logger(new Slf4jLogger())
.logLevel(logLevel)
.errorDecoder(new TwitchHelixErrorDecoder(new JacksonDecoder(), interceptor))
.requestInterceptor(interceptor)
.errorDecoder(new TwitchHelixErrorDecoder(new JacksonDecoder(), rateLimitTracker))
.requestInterceptor(new TwitchHelixClientIdInterceptor(userAgent, tokenManager))
.options(new Request.Options(timeout / 3, TimeUnit.MILLISECONDS, timeout, TimeUnit.MILLISECONDS, true))
.retryer(new Retryer.Default(500, timeout, 2))
.target(TwitchHelix.class, baseUrl);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import com.github.twitch4j.common.exception.UnauthorizedException;
import com.github.twitch4j.common.util.TypeConvert;
import com.github.twitch4j.helix.domain.TwitchHelixError;
import com.github.twitch4j.helix.interceptor.TwitchHelixClientIdInterceptor;
import com.github.twitch4j.helix.interceptor.TwitchHelixRateLimitTracker;
import feign.Request;
import feign.RequestTemplate;
import feign.Response;
Expand All @@ -25,8 +25,8 @@ public class TwitchHelixErrorDecoder implements ErrorDecoder {
// Decoder
final Decoder decoder;

// Interceptor
final TwitchHelixClientIdInterceptor interceptor;
// Rate Limit Tracker
final TwitchHelixRateLimitTracker rateLimitTracker;

// Error Decoder
final ErrorDecoder defaultDecoder = new ErrorDecoder.Default();
Expand All @@ -37,12 +37,12 @@ public class TwitchHelixErrorDecoder implements ErrorDecoder {
/**
* Constructor
*
* @param decoder Feign Decoder
* @param interceptor Helix Interceptor
* @param decoder Feign Decoder
* @param rateLimitTracker Helix Rate Limit Tracker
*/
public TwitchHelixErrorDecoder(Decoder decoder, TwitchHelixClientIdInterceptor interceptor) {
public TwitchHelixErrorDecoder(Decoder decoder, TwitchHelixRateLimitTracker rateLimitTracker) {
this.decoder = decoder;
this.interceptor = interceptor;
this.rateLimitTracker = rateLimitTracker;
}

/**
Expand All @@ -54,7 +54,7 @@ public TwitchHelixErrorDecoder(Decoder decoder, TwitchHelixClientIdInterceptor i
*/
@Override
public Exception decode(String methodKey, Response response) {
Exception ex = null;
Exception ex;

try {
String responseBody = response.body() == null ? "" : IOUtils.toString(response.body().asInputStream(), StandardCharsets.UTF_8.name());
Expand All @@ -79,7 +79,7 @@ public Exception decode(String methodKey, Response response) {
RequestTemplate template = response.request().requestTemplate();
if (template.path().endsWith("/moderation/bans")) {
String channelId = template.queries().get("broadcaster_id").iterator().next();
interceptor.getRateLimitTracker().markDepletedBanBucket(channelId);
rateLimitTracker.markDepletedBanBucket(channelId);
}
} else if (response.status() == 503) {
// If you get an HTTP 503 (Service Unavailable) error, retry once.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,22 +1,11 @@
package com.github.twitch4j.helix.interceptor;

import com.github.benmanes.caffeine.cache.Cache;
import com.github.benmanes.caffeine.cache.Caffeine;
import com.github.philippheuer.credentialmanager.domain.OAuth2Credential;
import com.github.twitch4j.auth.providers.TwitchIdentityProvider;
import com.github.twitch4j.helix.TwitchHelixBuilder;
import feign.RequestInterceptor;
import feign.RequestTemplate;
import io.github.bucket4j.Bandwidth;
import lombok.AccessLevel;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;

import java.util.Optional;
import java.util.concurrent.TimeUnit;

/**
* Injects ClientId Header, the User Agent and other common headers into each API Request
*/
Expand All @@ -27,64 +16,21 @@ public class TwitchHelixClientIdInterceptor implements RequestInterceptor {
public static final String BEARER_PREFIX = "Bearer ";

/**
* Reference to the Client Builder
*/
private final TwitchHelixBuilder twitchAPIBuilder;

/**
* Helix Rate Limit
* User Agent
*/
@Getter(AccessLevel.PROTECTED)
private final Bandwidth apiRateLimit;

/**
* Helix Rate Limit Tracker
*/
@Getter
private final TwitchHelixRateLimitTracker rateLimitTracker;

/**
* Reference to the twitch identity provider
*/
@Setter
private TwitchIdentityProvider twitchIdentityProvider;
private final String userAgent;

/**
* Access token cache
*/
@Getter(value = AccessLevel.PROTECTED)
private final Cache<String, OAuth2Credential> accessTokenCache = Caffeine.newBuilder()
.expireAfterAccess(15, TimeUnit.MINUTES)
.maximumSize(10_000)
.build();

/**
* The default app access token that is used if no oauth was passed by the user
*/
private volatile OAuth2Credential defaultAuthToken;

/**
* The default client id, typically associated with {@link TwitchHelixClientIdInterceptor#defaultAuthToken}
*/
private volatile String defaultClientId;
private final TwitchHelixTokenManager tokenManager;

/**
* Constructor
*
* @param twitchHelixBuilder Twitch Client Builder
*/
public TwitchHelixClientIdInterceptor(TwitchHelixBuilder twitchHelixBuilder) {
this.twitchAPIBuilder = twitchHelixBuilder;
twitchIdentityProvider = new TwitchIdentityProvider(twitchHelixBuilder.getClientId(), twitchHelixBuilder.getClientSecret(), null);
this.defaultClientId = twitchAPIBuilder.getClientId();
this.apiRateLimit = twitchAPIBuilder.getApiRateLimit();
this.defaultAuthToken = twitchHelixBuilder.getDefaultAuthToken();
this.rateLimitTracker = new TwitchHelixRateLimitTracker(this);
if (defaultAuthToken != null)
twitchIdentityProvider.getAdditionalCredentialInformation(defaultAuthToken).ifPresent(oauth -> {
this.defaultClientId = (String) oauth.getContext().get("client_id");
accessTokenCache.put(oauth.getAccessToken(), oauth);
});
public TwitchHelixClientIdInterceptor(String userAgent, TwitchHelixTokenManager tokenManager) {
this.userAgent = userAgent;
this.tokenManager = tokenManager;
}

/**
Expand All @@ -94,40 +40,25 @@ public TwitchHelixClientIdInterceptor(TwitchHelixBuilder twitchHelixBuilder) {
*/
@Override
public void apply(RequestTemplate template) {
String clientId = this.defaultClientId;
String clientId = tokenManager.getDefaultClientId();

// if a oauth token is passed is has to match that client id, default to global client id otherwise (for ie. token verification)
if (template.headers().containsKey(AUTH_HEADER)) {
String oauthToken = template.headers().get(AUTH_HEADER).iterator().next().substring(BEARER_PREFIX.length());

if (oauthToken.isEmpty()) {
String clientSecret = twitchAPIBuilder.getClientSecret();
if (defaultAuthToken == null && (StringUtils.isEmpty(clientId) || StringUtils.isEmpty(clientSecret) || clientSecret.charAt(0) == '*'))
throw new RuntimeException("Necessary OAuth token was missing from Helix call, without the means to generate one!");

try {
oauthToken = getOrCreateAuthToken().getAccessToken();
oauthToken = tokenManager.getDefaultAuthToken().getAccessToken();
clientId = tokenManager.getDefaultClientId();
} catch (Exception e) {
throw new RuntimeException("Failed to generate an app access token as no oauth token was passed to this Helix call", e);
}

template.removeHeader(AUTH_HEADER);
template.header(AUTH_HEADER, BEARER_PREFIX + oauthToken);
} else if (!StringUtils.contains(oauthToken, '.')) {
OAuth2Credential verifiedCredential = accessTokenCache.getIfPresent(oauthToken);
if (verifiedCredential == null) {
log.debug("Getting matching client-id for authorization token {}", oauthToken.substring(0, 5));

Optional<OAuth2Credential> requestedCredential = twitchIdentityProvider.getAdditionalCredentialInformation(new OAuth2Credential("twitch", oauthToken));
if (!requestedCredential.isPresent()) {
throw new RuntimeException("Failed to get the client_id for the provided authentication token, the authentication token may be invalid!");
}

verifiedCredential = requestedCredential.get();
accessTokenCache.put(oauthToken, verifiedCredential);
}

clientId = (String) verifiedCredential.getContext().get("client_id");
OAuth2Credential verifiedCredential = tokenManager.getOrPopulateCache(oauthToken);
clientId = TwitchHelixTokenManager.extractClientId(verifiedCredential);
}

log.debug("Setting new client-id {} for token {}", clientId, oauthToken.substring(0, 5));
Expand All @@ -136,28 +67,7 @@ public void apply(RequestTemplate template) {
// set headers
if (!template.headers().containsKey("Client-Id"))
template.header("Client-Id", clientId);
template.header("User-Agent", twitchAPIBuilder.getUserAgent());
}

public void clearDefaultToken() {
this.defaultAuthToken = null;
}

private OAuth2Credential getOrCreateAuthToken() {
if (defaultAuthToken == null)
synchronized (this) {
if (defaultAuthToken == null) {
String clientId = twitchAPIBuilder.getClientId();
OAuth2Credential token = twitchIdentityProvider.getAppAccessToken();
token.getContext().put("client_id", clientId);
rateLimitTracker.getOrInitializeBucket(clientId);
accessTokenCache.put(token.getAccessToken(), token);
this.defaultClientId = clientId;
return this.defaultAuthToken = token;
}
}

return this.defaultAuthToken;
template.header("User-Agent", userAgent);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@ public class TwitchHelixDecoder extends JacksonDecoder {

public static final String REMAINING_HEADER = "Ratelimit-Remaining";

private final TwitchHelixClientIdInterceptor interceptor;
private final TwitchHelixRateLimitTracker rateLimitTracker;

public TwitchHelixDecoder(ObjectMapper mapper, TwitchHelixClientIdInterceptor interceptor) {
public TwitchHelixDecoder(ObjectMapper mapper, TwitchHelixRateLimitTracker rateLimitTracker) {
super(mapper);
this.interceptor = interceptor;
this.rateLimitTracker = rateLimitTracker;
}

@Override
Expand All @@ -42,10 +42,10 @@ public Object decode(Response response, Type type) throws IOException {
String bearer = token.substring(BEARER_PREFIX.length());
if (response.request().httpMethod() == Request.HttpMethod.POST && response.request().requestTemplate().path().endsWith("/clips")) {
// Create Clip has a separate rate limit to synchronize
interceptor.getRateLimitTracker().updateRemainingCreateClip(bearer, remaining);
rateLimitTracker.updateRemainingCreateClip(bearer, remaining);
} else {
// Normal/global helix rate limit synchronization
interceptor.getRateLimitTracker().updateRemaining(bearer, remaining);
rateLimitTracker.updateRemaining(bearer, remaining);
}
}
}
Expand Down

0 comments on commit a4321fe

Please sign in to comment.