Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: implement send extension message rate limit #576

Merged
merged 6 commits into from Jul 25, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -2,6 +2,7 @@

import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonPropertyOrder;
import lombok.AccessLevel;
import lombok.AllArgsConstructor;
import lombok.Builder;
Expand All @@ -23,6 +24,7 @@
@NoArgsConstructor
@AllArgsConstructor
@JsonInclude(JsonInclude.Include.NON_NULL)
@JsonPropertyOrder(alphabetic = true) // used in http client to avoid extra deserialization
public class SendPubSubMessageInput {

/**
Expand Down
Expand Up @@ -14,6 +14,7 @@ public class TwitchHelixClientIdInterceptor implements RequestInterceptor {

public static final String AUTH_HEADER = "Authorization";
public static final String BEARER_PREFIX = "Bearer ";
public static final String CLIENT_HEADER = "Client-Id";

/**
* User Agent
Expand Down Expand Up @@ -65,8 +66,8 @@ public void apply(RequestTemplate template) {
}

// set headers
if (!template.headers().containsKey("Client-Id"))
template.header("Client-Id", clientId);
if (!template.headers().containsKey(CLIENT_HEADER))
template.header(CLIENT_HEADER, clientId);
template.header("User-Agent", userAgent);
}

Expand Down
Expand Up @@ -8,9 +8,12 @@
import java.io.IOException;
import java.lang.reflect.Type;
import java.util.Collection;
import java.util.Collections;

import static com.github.twitch4j.helix.interceptor.TwitchHelixClientIdInterceptor.AUTH_HEADER;
import static com.github.twitch4j.helix.interceptor.TwitchHelixClientIdInterceptor.BEARER_PREFIX;
import static com.github.twitch4j.helix.interceptor.TwitchHelixClientIdInterceptor.CLIENT_HEADER;
import static com.github.twitch4j.helix.interceptor.TwitchHelixHttpClient.getExtensionPubSubTarget;

public class TwitchHelixDecoder extends JacksonDecoder {

Expand All @@ -26,7 +29,8 @@ public TwitchHelixDecoder(ObjectMapper mapper, TwitchHelixRateLimitTracker rateL
@Override
public Object decode(Response response, Type type) throws IOException {
// track rate limit for token
String token = singleFirst(response.request().headers().get(AUTH_HEADER));
Request request = response.request();
String token = singleFirst(request.headers().get(AUTH_HEADER));
if (token != null && token.startsWith(BEARER_PREFIX)) {
// Parse remaining
String remainingStr = singleFirst(response.headers().get(REMAINING_HEADER));
Expand All @@ -40,9 +44,19 @@ public Object decode(Response response, Type type) throws IOException {
// Synchronize library buckets with twitch data
if (remaining != null) {
String bearer = token.substring(BEARER_PREFIX.length());
if (response.request().httpMethod() == Request.HttpMethod.POST && response.request().requestTemplate().path().endsWith("/clips")) {
if (request.httpMethod() == Request.HttpMethod.POST && request.requestTemplate().path().endsWith("/clips")) {
// Create Clip has a separate rate limit to synchronize
rateLimitTracker.updateRemainingCreateClip(bearer, remaining);
} else if (request.httpMethod() == Request.HttpMethod.POST && request.requestTemplate().path().endsWith("/extensions/chat")) {
// Send Extension Chat Message rate limit
String clientId = request.headers().getOrDefault(CLIENT_HEADER, Collections.emptyList()).iterator().next();
String channelId = request.requestTemplate().queries().getOrDefault("broadcaster_id", Collections.emptyList()).iterator().next();
rateLimitTracker.updateRemainingExtensionChat(clientId, channelId, remaining);
} else if (request.httpMethod() == Request.HttpMethod.POST && request.requestTemplate().path().endsWith("/extensions/pubsub")) {
// Send Extension PubSub Message rate limit
String clientId = request.headers().getOrDefault(CLIENT_HEADER, Collections.emptyList()).iterator().next();
String target = getExtensionPubSubTarget(request.body());
rateLimitTracker.updateRemainingExtensionPubSub(clientId, target, remaining);
} else {
// Normal/global helix rate limit synchronization
rateLimitTracker.updateRemaining(bearer, remaining);
Expand Down
Expand Up @@ -20,6 +20,7 @@

import static com.github.twitch4j.helix.interceptor.TwitchHelixClientIdInterceptor.AUTH_HEADER;
import static com.github.twitch4j.helix.interceptor.TwitchHelixClientIdInterceptor.BEARER_PREFIX;
import static com.github.twitch4j.helix.interceptor.TwitchHelixClientIdInterceptor.CLIENT_HEADER;
import static com.github.twitch4j.helix.interceptor.TwitchHelixDecoder.singleFirst;

@Slf4j
Expand Down Expand Up @@ -99,10 +100,47 @@ private Response delegatedExecute(Request request, Request.Options options) thro
return executeAgainstBucket(clipBucket, () -> client.execute(request, options));
}

// Extensions API: sendExtensionChatMessage has a stricter per-channel bucket
if (request.httpMethod() == Request.HttpMethod.POST && templatePath.endsWith("/extensions/chat")) {
// Obtain the bucket key
String clientId = request.headers().getOrDefault(CLIENT_HEADER, Collections.emptyList()).iterator().next();
String channelId = request.requestTemplate().queries().getOrDefault("broadcaster_id", Collections.emptyList()).iterator().next();

// Conform to endpoint-specific bucket
Bucket chatBucket = rateLimitTracker.getExtensionChatBucket(clientId, channelId);
return executeAgainstBucket(chatBucket, () -> client.execute(request, options));
}

// Extensions API: sendExtensionPubSubMessage has a stricter bucket depending on the target
if (request.httpMethod() == Request.HttpMethod.POST && templatePath.endsWith("/extensions/pubsub")) {
// Obtain the bucket key
String clientId = request.headers().getOrDefault(CLIENT_HEADER, Collections.emptyList()).iterator().next();
String target = getExtensionPubSubTarget(request.body());

// Conform to endpoint-specific bucket
Bucket pubSubBucket = rateLimitTracker.getExtensionPubSubBucket(clientId, target);
return executeAgainstBucket(pubSubBucket, () -> client.execute(request, options));
}

// no endpoint-specific rate limiting was needed; simply perform network request now
return client.execute(request, options);
}

static String getExtensionPubSubTarget(byte[] body) {
iProdigy marked this conversation as resolved.
Show resolved Hide resolved
String bodyStr = new String(body);
int i = bodyStr.indexOf("\"broadcaster_id\":");
String target; // if no broadcaster is specified, the target is global. alphabetical field order provides a check that we aren't grabbing an id from within the specified message
if (i < 0 || i > bodyStr.indexOf("\"message\":")) {
target = "global";
} else {
i += "\"broadcaster_id\":".length();
int start = bodyStr.indexOf('"', i) + 1;
int end = bodyStr.indexOf('"', start);
target = end > start ? bodyStr.substring(start, end) : "global";
}
return target;
}

private <T> T executeAgainstBucket(Bucket bucket, Callable<T> call) throws IOException {
try {
return BucketUtils.scheduleAgainstBucket(bucket, executor, call).get(timeout, TimeUnit.MILLISECONDS);
Expand Down
Expand Up @@ -5,8 +5,10 @@
import com.github.philippheuer.credentialmanager.domain.OAuth2Credential;
import com.github.twitch4j.common.annotation.Unofficial;
import com.github.twitch4j.common.util.BucketUtils;
import com.github.twitch4j.helix.domain.SendPubSubMessageInput;
import io.github.bucket4j.Bandwidth;
import io.github.bucket4j.Bucket;
import io.github.bucket4j.Refill;
import lombok.RequiredArgsConstructor;
import org.apache.commons.lang3.StringUtils;
import org.jetbrains.annotations.NotNull;
Expand All @@ -19,6 +21,18 @@
@SuppressWarnings("ConstantConditions")
public final class TwitchHelixRateLimitTracker {

/**
* Officially documented per-channel rate limit on {@link com.github.twitch4j.helix.TwitchHelix#sendExtensionChatMessage(String, String, String, String, String)}
*/
private static final Bandwidth EXT_CHAT_BANDWIDTH = Bandwidth.simple(12, Duration.ofMinutes(1L));

/**
* Officially documented bucket size (but unofficial refill rate) for {@link com.github.twitch4j.helix.TwitchHelix#sendExtensionPubSubMessage(String, String, SendPubSubMessageInput)}
*
* @see <a href="https://github.com/twitchdev/issues/issues/612">Issue report</a>
*/
private static final Bandwidth EXT_PUBSUB_BANDWIDTH = Bandwidth.classic(100, Refill.greedy(1, Duration.ofSeconds(1L)));

/**
* Empirically determined rate limit on helix bans and unbans, per channel
*/
Expand All @@ -41,9 +55,23 @@ public final class TwitchHelixRateLimitTracker {
* Rate limit buckets by user/app
*/
private final Cache<String, Bucket> primaryBuckets = Caffeine.newBuilder()
.expireAfterAccess(80, TimeUnit.SECONDS)
.build();

/**
* Extensions API: send chat message rate limit buckets per channel
*/
private final Cache<String, Bucket> extensionChatBuckets = Caffeine.newBuilder()
.expireAfterAccess(1, TimeUnit.MINUTES)
.build();

/**
* Extensions API: send pubsub message rate limit buckets per channel
*/
private final Cache<String, Bucket> extensionPubSubBuckets = Caffeine.newBuilder()
.expireAfterAccess(100, TimeUnit.SECONDS)
.build();

/**
* Moderation API: ban and unban rate limit buckets per channel
*/
Expand Down Expand Up @@ -98,6 +126,16 @@ String getPrimaryBucketKey(@NotNull OAuth2Credential credential) {
* Secondary (endpoint-specific) rate limit buckets
*/

@NotNull
Bucket getExtensionChatBucket(@NotNull String clientId, @NotNull String channelId) {
return extensionChatBuckets.get(clientId + ':' + channelId, k -> BucketUtils.createBucket(EXT_CHAT_BANDWIDTH));
}

@NotNull
Bucket getExtensionPubSubBucket(@NotNull String clientId, @NotNull String channelId) {
return extensionPubSubBuckets.get(clientId + ':' + channelId, k -> BucketUtils.createBucket(EXT_PUBSUB_BANDWIDTH));
}

@NotNull
@Unofficial
Bucket getModerationBucket(@NotNull String channelId) {
Expand All @@ -124,6 +162,14 @@ public void updateRemaining(@NotNull String token, int remaining) {
this.updateRemainingGeneric(token, remaining, this::getPrimaryBucketKey, this::getOrInitializeBucket);
}

public void updateRemainingExtensionChat(@NotNull String clientId, @NotNull String channelId, int remaining) {
this.updateRemainingConservative(getExtensionChatBucket(clientId, channelId), remaining);
}

public void updateRemainingExtensionPubSub(@NotNull String clientId, @NotNull String target, int remaining) {
this.updateRemainingConservative(getExtensionPubSubBucket(clientId, target), remaining);
}

public void updateRemainingCreateClip(@NotNull String token, int remaining) {
this.updateRemainingGeneric(token, remaining, OAuth2Credential::getUserId, this::getClipBucket);
}
Expand All @@ -143,6 +189,10 @@ private void updateRemainingGeneric(String token, int remaining, Function<OAuth2
if (key == null) return;

Bucket bucket = keyToBucket.apply(key);
updateRemainingConservative(bucket, remaining);
}

private void updateRemainingConservative(Bucket bucket, int remaining) {
long diff = bucket.getAvailableTokens() - remaining;
if (diff > 0) bucket.tryConsumeAsMuchAsPossible(diff);
}
Expand Down