Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: comply with undocumented helix rate limits #561

Merged
merged 11 commits into from May 4, 2022
Expand Up @@ -12,7 +12,6 @@
import feign.RetryableException;
import feign.codec.Decoder;
import feign.codec.ErrorDecoder;
import io.github.bucket4j.Bucket;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang3.exception.ContextedRuntimeException;
Expand Down Expand Up @@ -80,8 +79,7 @@ public Exception decode(String methodKey, Response response) {
RequestTemplate template = response.request().requestTemplate();
if (template.path().endsWith("/moderation/bans")) {
String channelId = template.queries().get("broadcaster_id").iterator().next();
Bucket modBucket = interceptor.getModerationBucket(channelId);
modBucket.consumeIgnoringRateLimits(Math.max(modBucket.tryConsumeAsMuchAsPossible(), 1)); // intentionally go negative to induce a pause
interceptor.getRateLimitTracker().markDepletedBanBucket(channelId);
}
} else if (response.status() == 503) {
// If you get an HTTP 503 (Service Unavailable) error, retry once.
Expand Down
Expand Up @@ -42,7 +42,7 @@ public class SendPubSubMessageInput {

/**
* Strings for valid PubSub targets.
* Valid values: "broadcast", "global", "whisper-<user-id>"
* Valid values: "broadcast", "global", "{@literal whisper-<user-id>}"
*/
@Singular
@JsonProperty("target")
Expand Down
Expand Up @@ -4,23 +4,18 @@
import com.github.benmanes.caffeine.cache.Caffeine;
import com.github.philippheuer.credentialmanager.domain.OAuth2Credential;
import com.github.twitch4j.auth.providers.TwitchIdentityProvider;
import com.github.twitch4j.common.annotation.Unofficial;
import com.github.twitch4j.common.util.BucketUtils;
import com.github.twitch4j.helix.TwitchHelixBuilder;
import feign.RequestInterceptor;
import feign.RequestTemplate;
import io.github.bucket4j.Bandwidth;
import io.github.bucket4j.Bucket;
import lombok.AccessLevel;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;

import java.time.Duration;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;

/**
* Injects ClientId Header, the User Agent and other common headers into each API Request
Expand All @@ -31,24 +26,6 @@ public class TwitchHelixClientIdInterceptor implements RequestInterceptor {
public static final String AUTH_HEADER = "Authorization";
public static final String BEARER_PREFIX = "Bearer ";

/**
* Empirically determined rate limit on helix bans and unbans, per channel
*/
@Unofficial
private static final Bandwidth BANS_BANDWIDTH = Bandwidth.simple(100, Duration.ofSeconds(30));

/**
* Empirically determined rate limit on the helix create clip endpoint, per user
*/
@Unofficial
private static final Bandwidth CLIPS_BANDWIDTH = Bandwidth.simple(600, Duration.ofSeconds(60));

/**
* Empirically determined rate limit on helix add and remove block term, per channel
*/
@Unofficial
private static final Bandwidth TERMS_BANDWIDTH = Bandwidth.simple(60, Duration.ofSeconds(60));

/**
* Reference to the Client Builder
*/
Expand All @@ -57,8 +34,15 @@ public class TwitchHelixClientIdInterceptor implements RequestInterceptor {
/**
* Helix Rate Limit
*/
@Getter(AccessLevel.PROTECTED)
private final Bandwidth apiRateLimit;

/**
* Helix Rate Limit Tracker
*/
@Getter
private final TwitchHelixRateLimitTracker rateLimitTracker;

/**
* Reference to the twitch identity provider
*/
Expand All @@ -74,34 +58,6 @@ public class TwitchHelixClientIdInterceptor implements RequestInterceptor {
.maximumSize(10_000)
.build();

/**
* Rate limit buckets by user/app
*/
private final Cache<String, Bucket> buckets = Caffeine.newBuilder()
.expireAfterAccess(1, TimeUnit.MINUTES)
.build();

/**
* Moderation API: ban and unban rate limit buckets per channel
*/
private final Cache<String, Bucket> bansByChannelId = Caffeine.newBuilder()
.expireAfterAccess(1, TimeUnit.MINUTES)
.build();

/**
* Create Clip API rate limit buckets per user
*/
private final Cache<String, Bucket> clipsByUserId = Caffeine.newBuilder()
.expireAfterAccess(1, TimeUnit.MINUTES)
.build();

/**
* Moderation API: add and remove blocked term rate limit buckets per channel
*/
private final Cache<String, Bucket> termsByChannelId = Caffeine.newBuilder()
.expireAfterAccess(1, TimeUnit.MINUTES)
.build();

/**
* The default app access token that is used if no oauth was passed by the user
*/
Expand All @@ -123,6 +79,7 @@ public TwitchHelixClientIdInterceptor(TwitchHelixBuilder twitchHelixBuilder) {
this.defaultClientId = twitchAPIBuilder.getClientId();
this.apiRateLimit = twitchAPIBuilder.getApiRateLimit();
this.defaultAuthToken = twitchHelixBuilder.getDefaultAuthToken();
this.rateLimitTracker = new TwitchHelixRateLimitTracker(this);
iProdigy marked this conversation as resolved.
Show resolved Hide resolved
if (defaultAuthToken != null)
twitchIdentityProvider.getAdditionalCredentialInformation(defaultAuthToken).ifPresent(oauth -> {
this.defaultClientId = (String) oauth.getContext().get("client_id");
Expand Down Expand Up @@ -182,47 +139,18 @@ public void apply(RequestTemplate template) {
template.header("User-Agent", twitchAPIBuilder.getUserAgent());
}

public void updateRemaining(String token, int remaining) {
this.updateRemainingGeneric(token, remaining, this::getKey, this::getOrInitializeBucket);
}

public void updateRemainingCreateClip(String token, int remaining) {
this.updateRemainingGeneric(token, remaining, OAuth2Credential::getUserId, this::getClipBucket);
}

public void clearDefaultToken() {
this.defaultAuthToken = null;
}

protected String getKey(OAuth2Credential credential) {
String clientId = (String) credential.getContext().get("client_id");
return clientId == null ? null : credential.getUserId() == null ? clientId : clientId + "-" + credential.getUserId();
}

protected Bucket getOrInitializeBucket(String key) {
return buckets.get(key, k -> BucketUtils.createBucket(this.apiRateLimit));
}

public Bucket getModerationBucket(String channelId) {
return bansByChannelId.get(channelId, k -> BucketUtils.createBucket(BANS_BANDWIDTH));
}

protected Bucket getClipBucket(String userId) {
return clipsByUserId.get(userId, k -> BucketUtils.createBucket(CLIPS_BANDWIDTH));
}

protected Bucket getTermsBucket(String channelId) {
return termsByChannelId.get(channelId, k -> BucketUtils.createBucket(TERMS_BANDWIDTH));
}

private OAuth2Credential getOrCreateAuthToken() {
if (defaultAuthToken == null)
synchronized (this) {
if (defaultAuthToken == null) {
String clientId = twitchAPIBuilder.getClientId();
OAuth2Credential token = twitchIdentityProvider.getAppAccessToken();
token.getContext().put("client_id", clientId);
getOrInitializeBucket(clientId);
rateLimitTracker.getOrInitializeBucket(clientId);
accessTokenCache.put(token.getAccessToken(), token);
this.defaultClientId = clientId;
return this.defaultAuthToken = token;
Expand All @@ -232,16 +160,4 @@ private OAuth2Credential getOrCreateAuthToken() {
return this.defaultAuthToken;
}

private void updateRemainingGeneric(String token, int remaining, Function<OAuth2Credential, String> credToKey, Function<String, Bucket> keyToBucket) {
OAuth2Credential credential = accessTokenCache.getIfPresent(token);
if (credential == null) return;

String key = credToKey.apply(credential);
if (key == null) return;

Bucket bucket = keyToBucket.apply(key);
long diff = bucket.getAvailableTokens() - remaining;
if (diff > 0) bucket.tryConsumeAsMuchAsPossible(diff);
}

}
Expand Up @@ -42,10 +42,10 @@ public Object decode(Response response, Type type) throws IOException {
String bearer = token.substring(BEARER_PREFIX.length());
if (response.request().httpMethod() == Request.HttpMethod.POST && response.request().requestTemplate().path().endsWith("/clips")) {
// Create Clip has a separate rate limit to synchronize
interceptor.updateRemainingCreateClip(bearer, remaining);
interceptor.getRateLimitTracker().updateRemainingCreateClip(bearer, remaining);
} else {
// Normal/global helix rate limit synchronization
interceptor.updateRemaining(bearer, remaining);
interceptor.getRateLimitTracker().updateRemaining(bearer, remaining);
}
}
}
Expand Down
Expand Up @@ -45,7 +45,7 @@ public Response execute(Request request, Request.Options options) throws IOExcep
OAuth2Credential credential = interceptor.getAccessTokenCache().getIfPresent(token.substring(BEARER_PREFIX.length()));
if (credential != null) {
// First consume from helix global rate limit (800/min by default)
Bucket bucket = interceptor.getOrInitializeBucket(interceptor.getKey(credential));
Bucket bucket = interceptor.getRateLimitTracker().getOrInitializeBucket(interceptor.getRateLimitTracker().getKey(credential));
return executeAgainstBucket(bucket, () -> delegatedExecute(request, options));
}
}
Expand All @@ -71,7 +71,7 @@ private Response delegatedExecute(Request request, Request.Options options) thro
String channelId = request.requestTemplate().queries().getOrDefault("broadcaster_id", Collections.emptyList()).iterator().next();

// Conform to endpoint-specific bucket
Bucket modBucket = interceptor.getModerationBucket(channelId);
Bucket modBucket = interceptor.getRateLimitTracker().getModerationBucket(channelId);
return executeAgainstBucket(modBucket, () -> client.execute(request, options));
}

Expand All @@ -81,7 +81,7 @@ private Response delegatedExecute(Request request, Request.Options options) thro
String channelId = request.requestTemplate().queries().getOrDefault("broadcaster_id", Collections.emptyList()).iterator().next();

// Conform to endpoint-specific bucket
Bucket termsBucket = interceptor.getTermsBucket(channelId);
Bucket termsBucket = interceptor.getRateLimitTracker().getTermsBucket(channelId);
return executeAgainstBucket(termsBucket, () -> client.execute(request, options));
}

Expand All @@ -93,7 +93,7 @@ private Response delegatedExecute(Request request, Request.Options options) thro
String userId = cred != null ? cred.getUserId() : "";

// Conform to endpoint-specific bucket
Bucket clipBucket = interceptor.getClipBucket(userId != null ? userId : "");
Bucket clipBucket = interceptor.getRateLimitTracker().getClipBucket(userId != null ? userId : "");
return executeAgainstBucket(clipBucket, () -> client.execute(request, options));
}

Expand Down
@@ -0,0 +1,136 @@
package com.github.twitch4j.helix.interceptor;

import com.github.benmanes.caffeine.cache.Cache;
import com.github.benmanes.caffeine.cache.Caffeine;
import com.github.philippheuer.credentialmanager.domain.OAuth2Credential;
import com.github.twitch4j.common.annotation.Unofficial;
import com.github.twitch4j.common.util.BucketUtils;
import io.github.bucket4j.Bandwidth;
import io.github.bucket4j.Bucket;
import lombok.RequiredArgsConstructor;

import java.time.Duration;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;

@RequiredArgsConstructor
public class TwitchHelixRateLimitTracker {

/**
* Empirically determined rate limit on helix bans and unbans, per channel
*/
@Unofficial
private static final Bandwidth BANS_BANDWIDTH = Bandwidth.simple(100, Duration.ofSeconds(30));

/**
* Empirically determined rate limit on the helix create clip endpoint, per user
*/
@Unofficial
private static final Bandwidth CLIPS_BANDWIDTH = Bandwidth.simple(600, Duration.ofSeconds(60));

/**
* Empirically determined rate limit on helix add and remove block term, per channel
*/
@Unofficial
private static final Bandwidth TERMS_BANDWIDTH = Bandwidth.simple(60, Duration.ofSeconds(60));

/**
* Rate limit buckets by user/app
*/
private final Cache<String, Bucket> primaryBuckets = Caffeine.newBuilder()
.expireAfterAccess(1, TimeUnit.MINUTES)
.build();

/**
* Moderation API: ban and unban rate limit buckets per channel
*/
private final Cache<String, Bucket> bansByChannelId = Caffeine.newBuilder()
.expireAfterAccess(1, TimeUnit.MINUTES)
.build();

/**
* Create Clip API rate limit buckets per user
*/
private final Cache<String, Bucket> clipsByUserId = Caffeine.newBuilder()
.expireAfterAccess(1, TimeUnit.MINUTES)
.build();

/**
* Moderation API: add and remove blocked term rate limit buckets per channel
*/
private final Cache<String, Bucket> termsByChannelId = Caffeine.newBuilder()
.expireAfterAccess(1, TimeUnit.MINUTES)
.build();

/**
* Twitch Helix Interceptor
*/
private final TwitchHelixClientIdInterceptor interceptor; // provided by RequiredArgsConstructor

/*
* Primary (global helix) rate limit bucket finder
*/

protected Bucket getOrInitializeBucket(String key) {
return primaryBuckets.get(key, k -> BucketUtils.createBucket(interceptor.getApiRateLimit()));
}

protected String getKey(OAuth2Credential credential) {
// App access tokens share the same bucket for a given client id
// User access tokens share the same bucket for a given client id & user id pair
// For this method to work, credential must have been augmented with information from getAdditionalCredentialInformation (which is done by the interceptor)
// Thus, this logic yields the key that is associated with each primary helix bucket
String clientId = (String) credential.getContext().get("client_id");
return clientId == null ? null : credential.getUserId() == null ? clientId : clientId + "-" + credential.getUserId();
}

/*
* Secondary (endpoint-specific) rate limit buckets
*/

@Unofficial
protected Bucket getModerationBucket(String channelId) {
return bansByChannelId.get(channelId, k -> BucketUtils.createBucket(BANS_BANDWIDTH));
}

@Unofficial
protected Bucket getClipBucket(String userId) {
return clipsByUserId.get(userId, k -> BucketUtils.createBucket(CLIPS_BANDWIDTH));
}

@Unofficial
protected Bucket getTermsBucket(String channelId) {
return termsByChannelId.get(channelId, k -> BucketUtils.createBucket(TERMS_BANDWIDTH));
}

/*
* Methods to conservatively update remaining points in rate limit buckets, based on incoming twitch statistics
*/

public void updateRemaining(String token, int remaining) {
this.updateRemainingGeneric(token, remaining, this::getKey, this::getOrInitializeBucket);
}

public void updateRemainingCreateClip(String token, int remaining) {
this.updateRemainingGeneric(token, remaining, OAuth2Credential::getUserId, this::getClipBucket);
}

public void markDepletedBanBucket(String channelId) {
// Called upon a 429 for banUser or unbanUser
Bucket modBucket = this.getModerationBucket(channelId);
modBucket.consumeIgnoringRateLimits(Math.max(modBucket.tryConsumeAsMuchAsPossible(), 1)); // intentionally go negative to induce a pause
}

private void updateRemainingGeneric(String token, int remaining, Function<OAuth2Credential, String> credToKey, Function<String, Bucket> keyToBucket) {
OAuth2Credential credential = interceptor.getAccessTokenCache().getIfPresent(token);
if (credential == null) return;

String key = credToKey.apply(credential);
if (key == null) return;

Bucket bucket = keyToBucket.apply(key);
long diff = bucket.getAvailableTokens() - remaining;
if (diff > 0) bucket.tryConsumeAsMuchAsPossible(diff);
}

}