Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: implement send extension message rate limit #576

Merged
merged 6 commits into from Jul 25, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
35 changes: 24 additions & 11 deletions rest-helix/src/main/java/com/github/twitch4j/helix/TwitchHelix.java
Expand Up @@ -842,6 +842,22 @@ HystrixCommand<ExtensionLiveChannelsList> getExtensionLiveChannels(
@Param("after") String after
);

@Deprecated // meant to only be called internally
@SuppressWarnings("DeprecatedIsStillUsed")
@RequestLine("POST /extensions/pubsub")
@Headers({
"Authorization: Bearer {token}",
"Client-Id: {extension_id}",
"Content-Type: application/json",
"Twitch4J-Target: {twitch4j_target}"
})
HystrixCommand<Void> sendExtensionPubSubMessage(
@Param("token") String jwtToken,
@Param("extension_id") String extensionId,
@Param("twitch4j_target") String target,
SendPubSubMessageInput input
);

/**
* Twitch provides a publish-subscribe system for your EBS to communicate with both the broadcaster and viewers.
* Calling this endpoint forwards your message using the same mechanism as the send JavaScript helper function.
Expand All @@ -857,17 +873,14 @@ HystrixCommand<ExtensionLiveChannelsList> getExtensionLiveChannels(
* @param input Details on the message to be sent and its targets.
* @return 204 No Content upon a successful request.
*/
@RequestLine("POST /extensions/pubsub")
@Headers({
"Authorization: Bearer {token}",
"Client-Id: {extension_id}",
"Content-Type: application/json"
})
HystrixCommand<Void> sendExtensionPubSubMessage(
@Param("token") String jwtToken,
@Param("extension_id") String extensionId,
SendPubSubMessageInput input
);
default HystrixCommand<Void> sendExtensionPubSubMessage(
String jwtToken,
@NotNull String extensionId,
@NotNull SendPubSubMessageInput input
) {
final String target = input.isGlobalBroadcast() ? "broadcast" : input.getBroadcasterId() != null ? input.getBroadcasterId() : input.getTargets().get(0);
return this.sendExtensionPubSubMessage(jwtToken, extensionId, target, input);
}

/**
* Gets information about a released Extension; either the current version or a specified version.
Expand Down
Expand Up @@ -2,6 +2,7 @@

import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonPropertyOrder;
import lombok.AccessLevel;
import lombok.AllArgsConstructor;
import lombok.Builder;
Expand All @@ -23,6 +24,7 @@
@NoArgsConstructor
@AllArgsConstructor
@JsonInclude(JsonInclude.Include.NON_NULL)
@JsonPropertyOrder(alphabetic = true) // used in http client to avoid extra deserialization
public class SendPubSubMessageInput {

/**
Expand Down
Expand Up @@ -14,6 +14,7 @@ public class TwitchHelixClientIdInterceptor implements RequestInterceptor {

public static final String AUTH_HEADER = "Authorization";
public static final String BEARER_PREFIX = "Bearer ";
public static final String CLIENT_HEADER = "Client-Id";

/**
* User Agent
Expand Down Expand Up @@ -44,7 +45,8 @@ public void apply(RequestTemplate template) {

// if a oauth token is passed is has to match that client id, default to global client id otherwise (for ie. token verification)
if (template.headers().containsKey(AUTH_HEADER)) {
String oauthToken = template.headers().get(AUTH_HEADER).iterator().next().substring(BEARER_PREFIX.length());
// noinspection ConstantConditions
String oauthToken = TwitchHelixHttpClient.getFirst(AUTH_HEADER, template.headers()).substring(BEARER_PREFIX.length());

if (oauthToken.isEmpty()) {
try {
Expand All @@ -65,8 +67,8 @@ public void apply(RequestTemplate template) {
}

// set headers
if (!template.headers().containsKey("Client-Id"))
template.header("Client-Id", clientId);
if (!template.headers().containsKey(CLIENT_HEADER))
template.header(CLIENT_HEADER, clientId);
template.header("User-Agent", userAgent);
if (template.body() != null && !template.headers().containsKey("Content-Type"))
template.header("Content-Type", "application/json");
Expand Down
Expand Up @@ -7,10 +7,12 @@

import java.io.IOException;
import java.lang.reflect.Type;
import java.util.Collection;

import static com.github.twitch4j.helix.interceptor.TwitchHelixClientIdInterceptor.AUTH_HEADER;
import static com.github.twitch4j.helix.interceptor.TwitchHelixClientIdInterceptor.BEARER_PREFIX;
import static com.github.twitch4j.helix.interceptor.TwitchHelixClientIdInterceptor.CLIENT_HEADER;
import static com.github.twitch4j.helix.interceptor.TwitchHelixHttpClient.getFirstHeader;
import static com.github.twitch4j.helix.interceptor.TwitchHelixHttpClient.getFirstParam;

public class TwitchHelixDecoder extends JacksonDecoder {

Expand All @@ -26,10 +28,11 @@ public TwitchHelixDecoder(ObjectMapper mapper, TwitchHelixRateLimitTracker rateL
@Override
public Object decode(Response response, Type type) throws IOException {
// track rate limit for token
String token = singleFirst(response.request().headers().get(AUTH_HEADER));
Request request = response.request();
String token = getFirstHeader(AUTH_HEADER, request);
if (token != null && token.startsWith(BEARER_PREFIX)) {
// Parse remaining
String remainingStr = singleFirst(response.headers().get(REMAINING_HEADER));
String remainingStr = getFirstHeader(REMAINING_HEADER, request);
Integer remaining;
try {
remaining = Integer.parseInt(remainingStr);
Expand All @@ -40,9 +43,19 @@ public Object decode(Response response, Type type) throws IOException {
// Synchronize library buckets with twitch data
if (remaining != null) {
String bearer = token.substring(BEARER_PREFIX.length());
if (response.request().httpMethod() == Request.HttpMethod.POST && response.request().requestTemplate().path().endsWith("/clips")) {
if (request.httpMethod() == Request.HttpMethod.POST && request.requestTemplate().path().endsWith("/clips")) {
// Create Clip has a separate rate limit to synchronize
rateLimitTracker.updateRemainingCreateClip(bearer, remaining);
} else if (request.httpMethod() == Request.HttpMethod.POST && request.requestTemplate().path().endsWith("/extensions/chat")) {
// Send Extension Chat Message rate limit
String clientId = getFirstHeader(CLIENT_HEADER, request);
String channelId = getFirstParam("broadcaster_id", request);
rateLimitTracker.updateRemainingExtensionChat(clientId, channelId, remaining);
} else if (request.httpMethod() == Request.HttpMethod.POST && request.requestTemplate().path().endsWith("/extensions/pubsub")) {
// Send Extension PubSub Message rate limit
String clientId = getFirstHeader(CLIENT_HEADER, request);
String target = getFirstHeader("Twitch4J-Target", request);
rateLimitTracker.updateRemainingExtensionPubSub(clientId, target, remaining);
} else {
// Normal/global helix rate limit synchronization
rateLimitTracker.updateRemaining(bearer, remaining);
Expand All @@ -54,9 +67,4 @@ public Object decode(Response response, Type type) throws IOException {
return super.decode(response, type);
}

static String singleFirst(Collection<String> collection) {
if (collection == null || collection.size() != 1) return null;
return collection.toArray(new String[1])[0];
}

}
Expand Up @@ -8,9 +8,12 @@
import feign.okhttp.OkHttpClient;
import io.github.bucket4j.Bucket;
import lombok.extern.slf4j.Slf4j;
import org.jetbrains.annotations.Nullable;

import java.io.IOException;
import java.util.Collections;
import java.util.Collection;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ScheduledExecutorService;
Expand All @@ -20,7 +23,7 @@

import static com.github.twitch4j.helix.interceptor.TwitchHelixClientIdInterceptor.AUTH_HEADER;
import static com.github.twitch4j.helix.interceptor.TwitchHelixClientIdInterceptor.BEARER_PREFIX;
import static com.github.twitch4j.helix.interceptor.TwitchHelixDecoder.singleFirst;
import static com.github.twitch4j.helix.interceptor.TwitchHelixClientIdInterceptor.CLIENT_HEADER;

@Slf4j
public class TwitchHelixHttpClient implements Client {
Expand All @@ -42,7 +45,7 @@ public TwitchHelixHttpClient(OkHttpClient client, ScheduledThreadPoolExecutor ex
@Override
public Response execute(Request request, Request.Options options) throws IOException {
// Check whether this request should be delayed to conform to rate limits
String token = singleFirst(request.headers().get(AUTH_HEADER));
String token = getFirstHeader(AUTH_HEADER, request);
if (token != null && token.startsWith(BEARER_PREFIX)) {
OAuth2Credential credential = tokenManager.getIfPresent(token.substring(BEARER_PREFIX.length()));
if (credential != null) {
Expand Down Expand Up @@ -70,7 +73,7 @@ private Response delegatedExecute(Request request, Request.Options options) thro
// Channels API: addChannelVip and removeChannelVip (likely) share a bucket per channel id
if (templatePath.endsWith("/channels/vips")) {
// Obtain the channel id
String channelId = request.requestTemplate().queries().getOrDefault("broadcaster_id", Collections.emptyList()).iterator().next();
String channelId = getFirstParam("broadcaster_id", request);

// Conform to endpoint-specific bucket
Bucket vipBucket;
Expand All @@ -89,7 +92,7 @@ private Response delegatedExecute(Request request, Request.Options options) thro
// Moderation API: Check AutoMod Status has a stricter bucket that applies per channel id
if (request.httpMethod() == Request.HttpMethod.POST && templatePath.endsWith("/moderation/enforcements/status")) {
// Obtain the channel id
String channelId = request.requestTemplate().queries().getOrDefault("broadcaster_id", Collections.emptyList()).iterator().next();
String channelId = getFirstParam("broadcaster_id", request);

// Conform to endpoint-specific bucket
Bucket autoModBucket = rateLimitTracker.getAutomodStatusBucket(channelId);
Expand All @@ -99,7 +102,7 @@ private Response delegatedExecute(Request request, Request.Options options) thro
// Moderation API: banUser and unbanUser share a bucket per channel id
if (templatePath.endsWith("/moderation/bans")) {
// Obtain the channel id
String channelId = request.requestTemplate().queries().getOrDefault("broadcaster_id", Collections.emptyList()).iterator().next();
String channelId = getFirstParam("broadcaster_id", request);

// Conform to endpoint-specific bucket
Bucket modBucket = rateLimitTracker.getModerationBucket(channelId);
Expand All @@ -109,7 +112,7 @@ private Response delegatedExecute(Request request, Request.Options options) thro
// Moderation API: addBlockedTerm and removeBlockedTerm share a bucket per channel id
if (templatePath.endsWith("/moderation/blocked_terms") && (request.httpMethod() == Request.HttpMethod.POST || request.httpMethod() == Request.HttpMethod.DELETE)) {
// Obtain the channel id
String channelId = request.requestTemplate().queries().getOrDefault("broadcaster_id", Collections.emptyList()).iterator().next();
String channelId = getFirstParam("broadcaster_id", request);

// Conform to endpoint-specific bucket
Bucket termsBucket = rateLimitTracker.getTermsBucket(channelId);
Expand All @@ -119,7 +122,7 @@ private Response delegatedExecute(Request request, Request.Options options) thro
// Moderation API: addChannelModerator and removeChannelModerator have independent buckets per channel id
if (templatePath.endsWith("/moderation/moderators")) {
// Obtain the channel id
String channelId = request.requestTemplate().queries().getOrDefault("broadcaster_id", Collections.emptyList()).iterator().next();
String channelId = getFirstParam("broadcaster_id", request);

// Conform to endpoint-specific bucket
Bucket modsBucket;
Expand All @@ -138,7 +141,7 @@ private Response delegatedExecute(Request request, Request.Options options) thro
// Clips API: createClip has a stricter bucket that applies per user id
if (request.httpMethod() == Request.HttpMethod.POST && templatePath.endsWith("/clips")) {
// Obtain user id
String token = request.headers().get(AUTH_HEADER).iterator().next().substring(BEARER_PREFIX.length());
String token = Objects.requireNonNull(getFirstHeader(AUTH_HEADER, request)).substring(BEARER_PREFIX.length());
OAuth2Credential cred = tokenManager.getIfPresent(token);
String userId = cred != null ? cred.getUserId() : "";

Expand All @@ -147,31 +150,69 @@ private Response delegatedExecute(Request request, Request.Options options) thro
return executeAgainstBucket(clipBucket, () -> client.execute(request, options));
}

// Extensions API: sendExtensionChatMessage has a stricter per-channel bucket
if (request.httpMethod() == Request.HttpMethod.POST && templatePath.endsWith("/extensions/chat")) {
// Obtain the bucket key
String clientId = getFirstHeader(CLIENT_HEADER, request);
String channelId = getFirstParam("broadcaster_id", request);

// Conform to endpoint-specific bucket
Bucket chatBucket = rateLimitTracker.getExtensionChatBucket(Objects.requireNonNull(clientId), Objects.requireNonNull(channelId));
return executeAgainstBucket(chatBucket, () -> client.execute(request, options));
}

// Extensions API: sendExtensionPubSubMessage has a stricter bucket depending on the target
if (request.httpMethod() == Request.HttpMethod.POST && templatePath.endsWith("/extensions/pubsub")) {
// Obtain the bucket key
String clientId = getFirstHeader(CLIENT_HEADER, request);
String target = getFirstHeader("Twitch4J-Target", request);

// Conform to endpoint-specific bucket
Bucket pubSubBucket = rateLimitTracker.getExtensionPubSubBucket(Objects.requireNonNull(clientId), Objects.requireNonNull(target));
return executeAgainstBucket(pubSubBucket, () -> client.execute(request, options));
}

// Raids API: startRaid and cancelRaid have a stricter bucket that applies per channel id
if (templatePath.endsWith("/raids")) {
// Obtain the channel id
String param = request.httpMethod() == Request.HttpMethod.POST ? "from_broadcaster_id" : "broadcaster_id";
String channelId = request.requestTemplate().queries().getOrDefault(param, Collections.emptyList()).iterator().next();
String channelId = getFirstParam(param, request);

// Conform to endpoint-specific bucket
Bucket raidBucket = rateLimitTracker.getRaidsBucket(channelId);
Bucket raidBucket = rateLimitTracker.getRaidsBucket(Objects.requireNonNull(channelId));
return executeAgainstBucket(raidBucket, () -> client.execute(request, options));
}

// Whispers API: sendWhisper has a stricter bucket that applies per user id
if (templatePath.endsWith("/whispers")) {
// Obtain the user id
String userId = request.requestTemplate().queries().getOrDefault("from_user_id", Collections.emptyList()).iterator().next();
String userId = getFirstParam("from_user_id", request);

// Conform to endpoint-specific bucket
Bucket whisperBucket = rateLimitTracker.getWhispersBucket(userId);
Bucket whisperBucket = rateLimitTracker.getWhispersBucket(Objects.requireNonNull(userId));
return executeAgainstBucket(whisperBucket, () -> client.execute(request, options));
}

// no endpoint-specific rate limiting was needed; simply perform network request now
return client.execute(request, options);
}

@Nullable
static String getFirstHeader(String key, Request request) {
return getFirst(key, request.headers());
}

@Nullable
static String getFirstParam(String key, Request request) {
return getFirst(key, request.requestTemplate().queries());
}

@Nullable
static String getFirst(String key, Map<String, Collection<String>> map) {
final Collection<String> values = map.get(key);
return values != null && !values.isEmpty() ? values.iterator().next() : null;
}

private <T> T executeAgainstBucket(Bucket bucket, Callable<T> call) throws IOException {
try {
return BucketUtils.scheduleAgainstBucket(bucket, executor, call).get(timeout, TimeUnit.MILLISECONDS);
Expand Down