diff --git a/packages/rest/__tests__/RequestHandler.test.ts b/packages/rest/__tests__/RequestHandler.test.ts index 67c285b..d2e4ca2 100644 --- a/packages/rest/__tests__/RequestHandler.test.ts +++ b/packages/rest/__tests__/RequestHandler.test.ts @@ -1,8 +1,9 @@ import nock from 'nock'; -import { DefaultRestOptions, DiscordAPIError, HTTPError, REST, RESTEvents } from '../src'; +import { DefaultRestOptions, DiscordAPIError, HTTPError, RateLimitError, REST, RESTEvents } from '../src'; const api = new REST({ timeout: 2000, offset: 5 }).setToken('A-Very-Fake-Token'); const invalidAuthApi = new REST({ timeout: 2000 }).setToken('Definitely-Not-A-Fake-Token'); +const rateLimitErrorApi = new REST({ rejectOnRateLimit: ['/channels'] }).setToken('Obviouslly-Not-A-Fake-Token'); let resetAfter = 0; let sublimitResetAfter = 0; @@ -17,6 +18,9 @@ const sublimitIntervals = { retry: null, }; +const sublimit = { body: { name: 'newname' } }; +const noSublimit = { body: { bitrate: 40000 } }; + nock(`${DefaultRestOptions.api}/v${DefaultRestOptions.version}`) .persist() .replyDate() @@ -253,8 +257,6 @@ test('Handle global rate limits', async () => { }); test('Handle sublimits', async () => { - const sublimit = { body: { name: 'newname' } }; - const noSublimit = { body: { bitrate: 40000 } }; // Return the current time on these results as their response does not indicate anything // Queue all requests, don't wait, to allow retroactive check const [aP, bP, cP, dP, eP] = [ @@ -325,6 +327,19 @@ test('Unauthorized', async () => { await expect(promise).rejects.toBeInstanceOf(DiscordAPIError); }); +test('Reject on RateLimit', async () => { + const [aP, bP, cP] = [ + rateLimitErrorApi.patch('/channels/:id', sublimit), + rateLimitErrorApi.patch('/channels/:id', sublimit), + rateLimitErrorApi.patch('/channels/:id', sublimit), + ]; + await expect(aP).resolves; + await expect(bP).rejects.toThrowError(); + await expect(bP).rejects.toBeInstanceOf(RateLimitError); + await expect(cP).rejects.toThrowError(); + await expect(cP).rejects.toBeInstanceOf(RateLimitError); +}); + test('malformedRequest', async () => { expect(await api.get('/malformedRequest')).toBe(null); }); diff --git a/packages/rest/src/index.ts b/packages/rest/src/index.ts index 86e9f1a..a34b851 100644 --- a/packages/rest/src/index.ts +++ b/packages/rest/src/index.ts @@ -4,6 +4,7 @@ export * from './lib/CDN'; export * from './lib/errors/DiscordAPIError'; export * from './lib/errors/HTTPError'; +export * from './lib/errors/RateLimitError'; export * from './lib/RequestManager'; export * from './lib/REST'; export * from './lib/utils/constants'; diff --git a/packages/rest/src/lib/REST.ts b/packages/rest/src/lib/REST.ts index 83382b1..ef07257 100644 --- a/packages/rest/src/lib/REST.ts +++ b/packages/rest/src/lib/REST.ts @@ -38,6 +38,14 @@ export interface RESTOptions { * @default 50 */ offset: number; + /** + * Determines how how rate limiting and pre-emptive throttling should be handled. + * When an array of strings, each element is treated as a prefix for the request route + * (e.g. `/channels/` to match any route starting with `/channels` such as `/channels/:id/messages`) + * for which to throw {@link RateLimitError}s. All other requests routes will be queued normally + * @default null + */ + rejectOnRateLimit: string[] | RateLimitQueueFilter | null; /** * The number of retries for errors with the 500 code, or errors * that timeout @@ -81,6 +89,10 @@ export interface RateLimitData { * The bucket hash for this request */ hash: string; + /** + * The full url for this request + */ + url: string; /** * The route being hit in this request */ @@ -98,6 +110,11 @@ export interface RateLimitData { global: boolean; } +/** + * A function that determines whether the rate limit hit should throw an Error + */ +export type RateLimitQueueFilter = (rateLimitData: RateLimitData) => boolean | Promise; + export interface InvalidRequestWarningData { /** * Number of invalid requests that have been made in the window @@ -109,7 +126,7 @@ export interface InvalidRequestWarningData { remainingTime: number; } -interface RestEvents { +export interface RestEvents { invalidRequestWarning: [invalidRequestInfo: InvalidRequestWarningData]; restDebug: [info: string]; rateLimited: [rateLimitInfo: RateLimitData]; diff --git a/packages/rest/src/lib/RequestManager.ts b/packages/rest/src/lib/RequestManager.ts index 73b631d..191387b 100644 --- a/packages/rest/src/lib/RequestManager.ts +++ b/packages/rest/src/lib/RequestManager.ts @@ -8,6 +8,7 @@ import type { IHandler } from './handlers/IHandler'; import { SequentialHandler } from './handlers/SequentialHandler'; import type { RESTOptions } from './REST'; import { DefaultRestOptions, DefaultUserAgent } from './utils/constants'; +import type { RestEvents } from '..'; const agent = new Agent({ keepAlive: true }); @@ -103,6 +104,23 @@ export interface RouteData { original: string; } +export interface RequestManager { + on(event: K, listener: (...args: RestEvents[K]) => void): this; + on(event: Exclude, listener: (...args: any[]) => void): this; + + once(event: K, listener: (...args: RestEvents[K]) => void): this; + once(event: Exclude, listener: (...args: any[]) => void): this; + + emit(event: K, ...args: RestEvents[K]): boolean; + emit(event: Exclude, ...args: any[]): boolean; + + off(event: K, listener: (...args: RestEvents[K]) => void): this; + off(event: Exclude, listener: (...args: any[]) => void): this; + + removeAllListeners(event?: K): this; + removeAllListeners(event?: Exclude): this; +} + /** * Represents the class that manages handlers for endpoints */ diff --git a/packages/rest/src/lib/errors/RateLimitError.ts b/packages/rest/src/lib/errors/RateLimitError.ts new file mode 100644 index 0000000..c6b5701 --- /dev/null +++ b/packages/rest/src/lib/errors/RateLimitError.ts @@ -0,0 +1,30 @@ +import type { RateLimitData } from '../REST'; + +export class RateLimitError extends Error implements RateLimitData { + public timeToReset: number; + public limit: number; + public method: string; + public hash: string; + public url: string; + public route: string; + public majorParameter: string; + public global: boolean; + public constructor({ timeToReset, limit, method, hash, url, route, majorParameter, global }: RateLimitData) { + super(); + this.timeToReset = timeToReset; + this.limit = limit; + this.method = method; + this.hash = hash; + this.url = url; + this.route = route; + this.majorParameter = majorParameter; + this.global = global; + } + + /** + * The name of the error + */ + public get name(): string { + return `${RateLimitError.name}[${this.route}]`; + } +} diff --git a/packages/rest/src/lib/handlers/SequentialHandler.ts b/packages/rest/src/lib/handlers/SequentialHandler.ts index 33de022..29e95f2 100644 --- a/packages/rest/src/lib/handlers/SequentialHandler.ts +++ b/packages/rest/src/lib/handlers/SequentialHandler.ts @@ -6,6 +6,8 @@ import { HTTPError } from '../errors/HTTPError'; import type { InternalRequest, RequestManager, RouteData } from '../RequestManager'; import { RESTEvents } from '../utils/constants'; import { hasSublimit, parseResponse } from '../utils/utils'; +import type { RateLimitData } from '../REST'; +import { RateLimitError } from '../..'; /* Invalid request limiting is done on a per-IP basis, not a per-token basis. * The best we can do is track invalid counts process-wide (on the theory that @@ -139,6 +141,22 @@ export class SequentialHandler { this.manager.globalDelay = null; } + /* + * Determines whether the request should be queued or whether a RateLimitError should be thrown + */ + private async onRateLimit(rateLimitData: RateLimitData) { + const { options } = this.manager; + if (!options.rejectOnRateLimit) return; + + const shouldThrow = + typeof options.rejectOnRateLimit === 'function' + ? await options.rejectOnRateLimit(rateLimitData) + : options.rejectOnRateLimit.some((route) => rateLimitData.route.startsWith(route.toLowerCase())); + if (shouldThrow) { + throw new RateLimitError(rateLimitData); + } + } + /** * Queues a request to be sent * @param routeId The generalized api route with literal ids for major parameters @@ -236,16 +254,21 @@ export class SequentialHandler { timeout = this.timeToReset; delay = sleep(timeout, undefined, { ref: false }); } - // Let library users know they have hit a rate limit - this.manager.emit(RESTEvents.RateLimited, { + const rateLimitData: RateLimitData = { timeToReset: timeout, limit, - method: options.method, + method: options.method ?? 'get', hash: this.hash, + url, route: routeId.bucketRoute, majorParameter: this.majorParameter, global: isGlobal, - }); + }; + // Let library users know they have hit a rate limit + this.manager.emit(RESTEvents.RateLimited, rateLimitData); + // Determine whether a RateLimitError should be thrown + await this.onRateLimit(rateLimitData); + // When not erroring, emit debug for what is happening if (isGlobal) { this.debug(`Global rate limit hit, blocking all requests for ${timeout}ms`); } else { @@ -354,6 +377,29 @@ export class SequentialHandler { ` Retry After : ${retryAfter}ms`, ].join('\n'), ); + const isGlobal = this.globalLimited; + let limit: number; + let timeout: number; + + if (isGlobal) { + // Set RateLimitData based on the globl limit + limit = this.manager.options.globalRequestsPerSecond; + timeout = this.manager.globalReset + this.manager.options.offset - Date.now(); + } else { + // Set RateLimitData based on the route-specific limit + limit = this.limit; + timeout = this.timeToReset; + } + await this.onRateLimit({ + timeToReset: timeout, + limit, + method, + hash: this.hash, + url, + route: routeId.bucketRoute, + majorParameter: this.majorParameter, + global: isGlobal, + }); // If caused by a sublimit, wait it out here so other requests on the route can be handled if (sublimitTimeout) { // Normally the sublimit queue will not exist, however, if a sublimit is hit while in the sublimit queue, it will diff --git a/packages/rest/src/lib/utils/constants.ts b/packages/rest/src/lib/utils/constants.ts index 629826d..9328db4 100644 --- a/packages/rest/src/lib/utils/constants.ts +++ b/packages/rest/src/lib/utils/constants.ts @@ -13,6 +13,7 @@ export const DefaultRestOptions: Required = { invalidRequestWarningInterval: 0, globalRequestsPerSecond: 50, offset: 50, + rejectOnRateLimit: null, retries: 3, timeout: 15_000, userAgentAppendix: `Node.js ${process.version}`,