Skip to content

Commit

Permalink
feat: implement send extension message rate limit (#576)
Browse files Browse the repository at this point in the history
  • Loading branch information
iProdigy committed Jul 25, 2022
1 parent 6cb62f4 commit 55b99ad
Show file tree
Hide file tree
Showing 6 changed files with 152 additions and 36 deletions.
35 changes: 24 additions & 11 deletions rest-helix/src/main/java/com/github/twitch4j/helix/TwitchHelix.java
Expand Up @@ -842,6 +842,22 @@ HystrixCommand<ExtensionLiveChannelsList> getExtensionLiveChannels(
@Param("after") String after
);

@Deprecated // meant to only be called internally
@SuppressWarnings("DeprecatedIsStillUsed")
@RequestLine("POST /extensions/pubsub")
@Headers({
"Authorization: Bearer {token}",
"Client-Id: {extension_id}",
"Content-Type: application/json",
"Twitch4J-Target: {twitch4j_target}"
})
HystrixCommand<Void> sendExtensionPubSubMessage(
@Param("token") String jwtToken,
@Param("extension_id") String extensionId,
@Param("twitch4j_target") String target,
SendPubSubMessageInput input
);

/**
* Twitch provides a publish-subscribe system for your EBS to communicate with both the broadcaster and viewers.
* Calling this endpoint forwards your message using the same mechanism as the send JavaScript helper function.
Expand All @@ -857,17 +873,14 @@ HystrixCommand<ExtensionLiveChannelsList> getExtensionLiveChannels(
* @param input Details on the message to be sent and its targets.
* @return 204 No Content upon a successful request.
*/
@RequestLine("POST /extensions/pubsub")
@Headers({
"Authorization: Bearer {token}",
"Client-Id: {extension_id}",
"Content-Type: application/json"
})
HystrixCommand<Void> sendExtensionPubSubMessage(
@Param("token") String jwtToken,
@Param("extension_id") String extensionId,
SendPubSubMessageInput input
);
default HystrixCommand<Void> sendExtensionPubSubMessage(
String jwtToken,
@NotNull String extensionId,
@NotNull SendPubSubMessageInput input
) {
final String target = input.isGlobalBroadcast() ? SendPubSubMessageInput.GLOBAL_TARGET : input.getBroadcasterId() != null ? input.getBroadcasterId() : input.getTargets().get(0);
return this.sendExtensionPubSubMessage(jwtToken, extensionId, target, input);
}

/**
* Gets information about a released Extension; either the current version or a specified version.
Expand Down
Expand Up @@ -2,6 +2,7 @@

import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonPropertyOrder;
import lombok.AccessLevel;
import lombok.AllArgsConstructor;
import lombok.Builder;
Expand All @@ -23,6 +24,7 @@
@NoArgsConstructor
@AllArgsConstructor
@JsonInclude(JsonInclude.Include.NON_NULL)
@JsonPropertyOrder(alphabetic = true) // used in http client to avoid extra deserialization
public class SendPubSubMessageInput {

/**
Expand Down
Expand Up @@ -14,6 +14,7 @@ public class TwitchHelixClientIdInterceptor implements RequestInterceptor {

public static final String AUTH_HEADER = "Authorization";
public static final String BEARER_PREFIX = "Bearer ";
public static final String CLIENT_HEADER = "Client-Id";

/**
* User Agent
Expand Down Expand Up @@ -44,7 +45,8 @@ public void apply(RequestTemplate template) {

// if a oauth token is passed is has to match that client id, default to global client id otherwise (for ie. token verification)
if (template.headers().containsKey(AUTH_HEADER)) {
String oauthToken = template.headers().get(AUTH_HEADER).iterator().next().substring(BEARER_PREFIX.length());
// noinspection ConstantConditions
String oauthToken = TwitchHelixHttpClient.getFirst(AUTH_HEADER, template.headers()).substring(BEARER_PREFIX.length());

if (oauthToken.isEmpty()) {
try {
Expand All @@ -65,8 +67,8 @@ public void apply(RequestTemplate template) {
}

// set headers
if (!template.headers().containsKey("Client-Id"))
template.header("Client-Id", clientId);
if (!template.headers().containsKey(CLIENT_HEADER))
template.header(CLIENT_HEADER, clientId);
template.header("User-Agent", userAgent);
if (template.body() != null && !template.headers().containsKey("Content-Type"))
template.header("Content-Type", "application/json");
Expand Down
Expand Up @@ -7,10 +7,12 @@

import java.io.IOException;
import java.lang.reflect.Type;
import java.util.Collection;

import static com.github.twitch4j.helix.interceptor.TwitchHelixClientIdInterceptor.AUTH_HEADER;
import static com.github.twitch4j.helix.interceptor.TwitchHelixClientIdInterceptor.BEARER_PREFIX;
import static com.github.twitch4j.helix.interceptor.TwitchHelixClientIdInterceptor.CLIENT_HEADER;
import static com.github.twitch4j.helix.interceptor.TwitchHelixHttpClient.getFirstHeader;
import static com.github.twitch4j.helix.interceptor.TwitchHelixHttpClient.getFirstParam;

public class TwitchHelixDecoder extends JacksonDecoder {

Expand All @@ -26,10 +28,11 @@ public TwitchHelixDecoder(ObjectMapper mapper, TwitchHelixRateLimitTracker rateL
@Override
public Object decode(Response response, Type type) throws IOException {
// track rate limit for token
String token = singleFirst(response.request().headers().get(AUTH_HEADER));
Request request = response.request();
String token = getFirstHeader(AUTH_HEADER, request);
if (token != null && token.startsWith(BEARER_PREFIX)) {
// Parse remaining
String remainingStr = singleFirst(response.headers().get(REMAINING_HEADER));
String remainingStr = getFirstHeader(REMAINING_HEADER, request);
Integer remaining;
try {
remaining = Integer.parseInt(remainingStr);
Expand All @@ -40,9 +43,19 @@ public Object decode(Response response, Type type) throws IOException {
// Synchronize library buckets with twitch data
if (remaining != null) {
String bearer = token.substring(BEARER_PREFIX.length());
if (response.request().httpMethod() == Request.HttpMethod.POST && response.request().requestTemplate().path().endsWith("/clips")) {
if (request.httpMethod() == Request.HttpMethod.POST && request.requestTemplate().path().endsWith("/clips")) {
// Create Clip has a separate rate limit to synchronize
rateLimitTracker.updateRemainingCreateClip(bearer, remaining);
} else if (request.httpMethod() == Request.HttpMethod.POST && request.requestTemplate().path().endsWith("/extensions/chat")) {
// Send Extension Chat Message rate limit
String clientId = getFirstHeader(CLIENT_HEADER, request);
String channelId = getFirstParam("broadcaster_id", request);
rateLimitTracker.updateRemainingExtensionChat(clientId, channelId, remaining);
} else if (request.httpMethod() == Request.HttpMethod.POST && request.requestTemplate().path().endsWith("/extensions/pubsub")) {
// Send Extension PubSub Message rate limit
String clientId = getFirstHeader(CLIENT_HEADER, request);
String target = getFirstHeader("Twitch4J-Target", request);
rateLimitTracker.updateRemainingExtensionPubSub(clientId, target, remaining);
} else {
// Normal/global helix rate limit synchronization
rateLimitTracker.updateRemaining(bearer, remaining);
Expand All @@ -54,9 +67,4 @@ public Object decode(Response response, Type type) throws IOException {
return super.decode(response, type);
}

static String singleFirst(Collection<String> collection) {
if (collection == null || collection.size() != 1) return null;
return collection.toArray(new String[1])[0];
}

}
Expand Up @@ -8,9 +8,12 @@
import feign.okhttp.OkHttpClient;
import io.github.bucket4j.Bucket;
import lombok.extern.slf4j.Slf4j;
import org.jetbrains.annotations.Nullable;

import java.io.IOException;
import java.util.Collections;
import java.util.Collection;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ScheduledExecutorService;
Expand All @@ -20,7 +23,7 @@

import static com.github.twitch4j.helix.interceptor.TwitchHelixClientIdInterceptor.AUTH_HEADER;
import static com.github.twitch4j.helix.interceptor.TwitchHelixClientIdInterceptor.BEARER_PREFIX;
import static com.github.twitch4j.helix.interceptor.TwitchHelixDecoder.singleFirst;
import static com.github.twitch4j.helix.interceptor.TwitchHelixClientIdInterceptor.CLIENT_HEADER;

@Slf4j
public class TwitchHelixHttpClient implements Client {
Expand All @@ -42,7 +45,7 @@ public TwitchHelixHttpClient(OkHttpClient client, ScheduledThreadPoolExecutor ex
@Override
public Response execute(Request request, Request.Options options) throws IOException {
// Check whether this request should be delayed to conform to rate limits
String token = singleFirst(request.headers().get(AUTH_HEADER));
String token = getFirstHeader(AUTH_HEADER, request);
if (token != null && token.startsWith(BEARER_PREFIX)) {
OAuth2Credential credential = tokenManager.getIfPresent(token.substring(BEARER_PREFIX.length()));
if (credential != null) {
Expand Down Expand Up @@ -70,7 +73,7 @@ private Response delegatedExecute(Request request, Request.Options options) thro
// Channels API: addChannelVip and removeChannelVip (likely) share a bucket per channel id
if (templatePath.endsWith("/channels/vips")) {
// Obtain the channel id
String channelId = request.requestTemplate().queries().getOrDefault("broadcaster_id", Collections.emptyList()).iterator().next();
String channelId = getFirstParam("broadcaster_id", request);

// Conform to endpoint-specific bucket
Bucket vipBucket;
Expand All @@ -89,7 +92,7 @@ private Response delegatedExecute(Request request, Request.Options options) thro
// Moderation API: Check AutoMod Status has a stricter bucket that applies per channel id
if (request.httpMethod() == Request.HttpMethod.POST && templatePath.endsWith("/moderation/enforcements/status")) {
// Obtain the channel id
String channelId = request.requestTemplate().queries().getOrDefault("broadcaster_id", Collections.emptyList()).iterator().next();
String channelId = getFirstParam("broadcaster_id", request);

// Conform to endpoint-specific bucket
Bucket autoModBucket = rateLimitTracker.getAutomodStatusBucket(channelId);
Expand All @@ -99,7 +102,7 @@ private Response delegatedExecute(Request request, Request.Options options) thro
// Moderation API: banUser and unbanUser share a bucket per channel id
if (templatePath.endsWith("/moderation/bans")) {
// Obtain the channel id
String channelId = request.requestTemplate().queries().getOrDefault("broadcaster_id", Collections.emptyList()).iterator().next();
String channelId = getFirstParam("broadcaster_id", request);

// Conform to endpoint-specific bucket
Bucket modBucket = rateLimitTracker.getModerationBucket(channelId);
Expand All @@ -109,7 +112,7 @@ private Response delegatedExecute(Request request, Request.Options options) thro
// Moderation API: addBlockedTerm and removeBlockedTerm share a bucket per channel id
if (templatePath.endsWith("/moderation/blocked_terms") && (request.httpMethod() == Request.HttpMethod.POST || request.httpMethod() == Request.HttpMethod.DELETE)) {
// Obtain the channel id
String channelId = request.requestTemplate().queries().getOrDefault("broadcaster_id", Collections.emptyList()).iterator().next();
String channelId = getFirstParam("broadcaster_id", request);

// Conform to endpoint-specific bucket
Bucket termsBucket = rateLimitTracker.getTermsBucket(channelId);
Expand All @@ -119,7 +122,7 @@ private Response delegatedExecute(Request request, Request.Options options) thro
// Moderation API: addChannelModerator and removeChannelModerator have independent buckets per channel id
if (templatePath.endsWith("/moderation/moderators")) {
// Obtain the channel id
String channelId = request.requestTemplate().queries().getOrDefault("broadcaster_id", Collections.emptyList()).iterator().next();
String channelId = getFirstParam("broadcaster_id", request);

// Conform to endpoint-specific bucket
Bucket modsBucket;
Expand All @@ -138,7 +141,7 @@ private Response delegatedExecute(Request request, Request.Options options) thro
// Clips API: createClip has a stricter bucket that applies per user id
if (request.httpMethod() == Request.HttpMethod.POST && templatePath.endsWith("/clips")) {
// Obtain user id
String token = request.headers().get(AUTH_HEADER).iterator().next().substring(BEARER_PREFIX.length());
String token = Objects.requireNonNull(getFirstHeader(AUTH_HEADER, request)).substring(BEARER_PREFIX.length());
OAuth2Credential cred = tokenManager.getIfPresent(token);
String userId = cred != null ? cred.getUserId() : "";

Expand All @@ -147,31 +150,69 @@ private Response delegatedExecute(Request request, Request.Options options) thro
return executeAgainstBucket(clipBucket, () -> client.execute(request, options));
}

// Extensions API: sendExtensionChatMessage has a stricter per-channel bucket
if (request.httpMethod() == Request.HttpMethod.POST && templatePath.endsWith("/extensions/chat")) {
// Obtain the bucket key
String clientId = getFirstHeader(CLIENT_HEADER, request);
String channelId = getFirstParam("broadcaster_id", request);

// Conform to endpoint-specific bucket
Bucket chatBucket = rateLimitTracker.getExtensionChatBucket(Objects.requireNonNull(clientId), Objects.requireNonNull(channelId));
return executeAgainstBucket(chatBucket, () -> client.execute(request, options));
}

// Extensions API: sendExtensionPubSubMessage has a stricter bucket depending on the target
if (request.httpMethod() == Request.HttpMethod.POST && templatePath.endsWith("/extensions/pubsub")) {
// Obtain the bucket key
String clientId = getFirstHeader(CLIENT_HEADER, request);
String target = getFirstHeader("Twitch4J-Target", request);

// Conform to endpoint-specific bucket
Bucket pubSubBucket = rateLimitTracker.getExtensionPubSubBucket(Objects.requireNonNull(clientId), Objects.requireNonNull(target));
return executeAgainstBucket(pubSubBucket, () -> client.execute(request, options));
}

// Raids API: startRaid and cancelRaid have a stricter bucket that applies per channel id
if (templatePath.endsWith("/raids")) {
// Obtain the channel id
String param = request.httpMethod() == Request.HttpMethod.POST ? "from_broadcaster_id" : "broadcaster_id";
String channelId = request.requestTemplate().queries().getOrDefault(param, Collections.emptyList()).iterator().next();
String channelId = getFirstParam(param, request);

// Conform to endpoint-specific bucket
Bucket raidBucket = rateLimitTracker.getRaidsBucket(channelId);
Bucket raidBucket = rateLimitTracker.getRaidsBucket(Objects.requireNonNull(channelId));
return executeAgainstBucket(raidBucket, () -> client.execute(request, options));
}

// Whispers API: sendWhisper has a stricter bucket that applies per user id
if (templatePath.endsWith("/whispers")) {
// Obtain the user id
String userId = request.requestTemplate().queries().getOrDefault("from_user_id", Collections.emptyList()).iterator().next();
String userId = getFirstParam("from_user_id", request);

// Conform to endpoint-specific bucket
Bucket whisperBucket = rateLimitTracker.getWhispersBucket(userId);
Bucket whisperBucket = rateLimitTracker.getWhispersBucket(Objects.requireNonNull(userId));
return executeAgainstBucket(whisperBucket, () -> client.execute(request, options));
}

// no endpoint-specific rate limiting was needed; simply perform network request now
return client.execute(request, options);
}

@Nullable
static String getFirstHeader(String key, Request request) {
return getFirst(key, request.headers());
}

@Nullable
static String getFirstParam(String key, Request request) {
return getFirst(key, request.requestTemplate().queries());
}

@Nullable
static String getFirst(String key, Map<String, Collection<String>> map) {
final Collection<String> values = map.get(key);
return values != null && !values.isEmpty() ? values.iterator().next() : null;
}

private <T> T executeAgainstBucket(Bucket bucket, Callable<T> call) throws IOException {
try {
return BucketUtils.scheduleAgainstBucket(bucket, executor, call).get(timeout, TimeUnit.MILLISECONDS);
Expand Down

0 comments on commit 55b99ad

Please sign in to comment.