Skip to content

Commit

Permalink
refactor(WebSocketShard): identify throttling (#8888)
Browse files Browse the repository at this point in the history
* refactor(WebSocketShard): identify throttling

* chore: add worker handling

* refactor: worker handling

* chore: update tests

* chore: use satisfies where applicable

* chore: add informative comment

* chore: apply suggestions

* refactor(SimpleContextFetchingStrategy): support multiple managers
  • Loading branch information
didinele committed Dec 2, 2022
1 parent 3fca638 commit 8f552a0
Show file tree
Hide file tree
Showing 9 changed files with 113 additions and 38 deletions.
14 changes: 11 additions & 3 deletions packages/ws/__tests__/strategy/WorkerShardingStrategy.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,21 +27,22 @@ const mockConstructor = vi.fn();
const mockSend = vi.fn();
const mockTerminate = vi.fn();

const memberChunkData: GatewayDispatchPayload = {
const memberChunkData = {
op: GatewayOpcodes.Dispatch,
s: 123,
t: GatewayDispatchEvents.GuildMembersChunk,
d: {
guild_id: '123',
members: [],
},
};
} as unknown as GatewayDispatchPayload;

const sessionInfo: SessionInfo = {
shardId: 0,
shardCount: 2,
sequence: 123,
sessionId: 'abc',
resumeURL: 'wss://ehehe.gg',
};

vi.mock('node:worker_threads', async () => {
Expand Down Expand Up @@ -109,6 +110,10 @@ vi.mock('node:worker_threads', async () => {
this.emit('message', session);
break;
}

case WorkerSendPayloadOp.ShardCanIdentify: {
break;
}
}
}

Expand Down Expand Up @@ -181,7 +186,10 @@ test('spawn, connect, send a message, session info, and destroy', async () => {
expect.objectContaining({ workerData: expect.objectContaining({ shardIds: [0, 1] }) }),
);

const payload: GatewaySendPayload = { op: GatewayOpcodes.RequestGuildMembers, d: { guild_id: '123', limit: 0 } };
const payload = {
op: GatewayOpcodes.RequestGuildMembers,
d: { guild_id: '123', limit: 0, query: '' },
} satisfies GatewaySendPayload;
await manager.send(0, payload);
expect(mockSend).toHaveBeenCalledWith(0, payload);
expect(managerEmitSpy).toHaveBeenCalledWith(WebSocketShardEvents.Dispatch, {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ export interface IContextFetchingStrategy {
readonly options: FetchingStrategyOptions;
retrieveSessionInfo(shardId: number): Awaitable<SessionInfo | null>;
updateSessionInfo(shardId: number, sessionInfo: SessionInfo | null): Awaitable<void>;
waitForIdentify(): Promise<void>;
}

export async function managerToFetchingStrategyOptions(manager: WebSocketManager): Promise<FetchingStrategyOptions> {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,28 @@
import { IdentifyThrottler } from '../../utils/IdentifyThrottler.js';
import type { SessionInfo, WebSocketManager } from '../../ws/WebSocketManager.js';
import type { FetchingStrategyOptions, IContextFetchingStrategy } from './IContextFetchingStrategy.js';

export class SimpleContextFetchingStrategy implements IContextFetchingStrategy {
public constructor(private readonly manager: WebSocketManager, public readonly options: FetchingStrategyOptions) {}
// This strategy assumes every shard is running under the same process - therefore we need a single
// IdentifyThrottler per manager.
private static throttlerCache = new WeakMap<WebSocketManager, IdentifyThrottler>();

private static ensureThrottler(manager: WebSocketManager): IdentifyThrottler {
const existing = SimpleContextFetchingStrategy.throttlerCache.get(manager);
if (existing) {
return existing;
}

const throttler = new IdentifyThrottler(manager);
SimpleContextFetchingStrategy.throttlerCache.set(manager, throttler);
return throttler;
}

private readonly throttler: IdentifyThrottler;

public constructor(private readonly manager: WebSocketManager, public readonly options: FetchingStrategyOptions) {
this.throttler = SimpleContextFetchingStrategy.ensureThrottler(manager);
}

public async retrieveSessionInfo(shardId: number): Promise<SessionInfo | null> {
return this.manager.options.retrieveSessionInfo(shardId);
Expand All @@ -11,4 +31,8 @@ export class SimpleContextFetchingStrategy implements IContextFetchingStrategy {
public updateSessionInfo(shardId: number, sessionInfo: SessionInfo | null) {
return this.manager.options.updateSessionInfo(shardId, sessionInfo);
}

public async waitForIdentify(): Promise<void> {
await this.throttler.waitForIdentify();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,39 +12,57 @@ import type { FetchingStrategyOptions, IContextFetchingStrategy } from './IConte
export class WorkerContextFetchingStrategy implements IContextFetchingStrategy {
private readonly sessionPromises = new Collection<number, (session: SessionInfo | null) => void>();

private readonly waitForIdentifyPromises = new Collection<number, () => void>();

public constructor(public readonly options: FetchingStrategyOptions) {
if (isMainThread) {
throw new Error('Cannot instantiate WorkerContextFetchingStrategy on the main thread');
}

parentPort!.on('message', (payload: WorkerSendPayload) => {
if (payload.op === WorkerSendPayloadOp.SessionInfoResponse) {
const resolve = this.sessionPromises.get(payload.nonce);
resolve?.(payload.session);
this.sessionPromises.get(payload.nonce)?.(payload.session);
this.sessionPromises.delete(payload.nonce);
}

if (payload.op === WorkerSendPayloadOp.ShardCanIdentify) {
this.waitForIdentifyPromises.get(payload.nonce)?.();
this.waitForIdentifyPromises.delete(payload.nonce);
}
});
}

public async retrieveSessionInfo(shardId: number): Promise<SessionInfo | null> {
const nonce = Math.random();
const payload: WorkerRecievePayload = {
const payload = {
op: WorkerRecievePayloadOp.RetrieveSessionInfo,
shardId,
nonce,
};
} satisfies WorkerRecievePayload;
// eslint-disable-next-line no-promise-executor-return
const promise = new Promise<SessionInfo | null>((resolve) => this.sessionPromises.set(nonce, resolve));
parentPort!.postMessage(payload);
return promise;
}

public updateSessionInfo(shardId: number, sessionInfo: SessionInfo | null) {
const payload: WorkerRecievePayload = {
const payload = {
op: WorkerRecievePayloadOp.UpdateSessionInfo,
shardId,
session: sessionInfo,
};
} satisfies WorkerRecievePayload;
parentPort!.postMessage(payload);
}

public async waitForIdentify(): Promise<void> {
const nonce = Math.random();
const payload = {
op: WorkerRecievePayloadOp.WaitForIdentify,
nonce,
} satisfies WorkerRecievePayload;
// eslint-disable-next-line no-promise-executor-return
const promise = new Promise<void>((resolve) => this.waitForIdentifyPromises.set(nonce, resolve));
parentPort!.postMessage(payload);
return promise;
}
}
5 changes: 0 additions & 5 deletions packages/ws/src/strategies/sharding/SimpleShardingStrategy.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import { Collection } from '@discordjs/collection';
import type { GatewaySendPayload } from 'discord-api-types/v10';
import { IdentifyThrottler } from '../../utils/IdentifyThrottler.js';
import type { WebSocketManager } from '../../ws/WebSocketManager';
import { WebSocketShard, WebSocketShardEvents, type WebSocketShardDestroyOptions } from '../../ws/WebSocketShard.js';
import { managerToFetchingStrategyOptions } from '../context/IContextFetchingStrategy.js';
Expand All @@ -15,11 +14,8 @@ export class SimpleShardingStrategy implements IShardingStrategy {

private readonly shards = new Collection<number, WebSocketShard>();

private readonly throttler: IdentifyThrottler;

public constructor(manager: WebSocketManager) {
this.manager = manager;
this.throttler = new IdentifyThrottler(manager);
}

/**
Expand All @@ -46,7 +42,6 @@ export class SimpleShardingStrategy implements IShardingStrategy {
const promises = [];

for (const shard of this.shards.values()) {
await this.throttler.waitForIdentify();
promises.push(shard.connect());
}

Expand Down
28 changes: 20 additions & 8 deletions packages/ws/src/strategies/sharding/WorkerShardingStrategy.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@ export enum WorkerSendPayloadOp {
Destroy,
Send,
SessionInfoResponse,
ShardCanIdentify,
}

export type WorkerSendPayload =
| { nonce: number; op: WorkerSendPayloadOp.SessionInfoResponse; session: SessionInfo | null }
| { nonce: number; op: WorkerSendPayloadOp.ShardCanIdentify }
| { op: WorkerSendPayloadOp.Connect; shardId: number }
| { op: WorkerSendPayloadOp.Destroy; options?: WebSocketShardDestroyOptions; shardId: number }
| { op: WorkerSendPayloadOp.Send; payload: GatewaySendPayload; shardId: number };
Expand All @@ -32,12 +34,14 @@ export enum WorkerRecievePayloadOp {
Event,
RetrieveSessionInfo,
UpdateSessionInfo,
WaitForIdentify,
}

export type WorkerRecievePayload =
// Can't seem to get a type-safe union based off of the event, so I'm sadly leaving data as any for now
| { data: any; event: WebSocketShardEvents; op: WorkerRecievePayloadOp.Event; shardId: number }
| { nonce: number; op: WorkerRecievePayloadOp.RetrieveSessionInfo; shardId: number }
| { nonce: number; op: WorkerRecievePayloadOp.WaitForIdentify }
| { op: WorkerRecievePayloadOp.Connected; shardId: number }
| { op: WorkerRecievePayloadOp.Destroyed; shardId: number }
| { op: WorkerRecievePayloadOp.UpdateSessionInfo; session: SessionInfo | null; shardId: number };
Expand Down Expand Up @@ -118,12 +122,10 @@ export class WorkerShardingStrategy implements IShardingStrategy {
const promises = [];

for (const [shardId, worker] of this.#workerByShardId.entries()) {
await this.throttler.waitForIdentify();

const payload: WorkerSendPayload = {
const payload = {
op: WorkerSendPayloadOp.Connect,
shardId,
};
} satisfies WorkerSendPayload;

// eslint-disable-next-line no-promise-executor-return
const promise = new Promise<void>((resolve) => this.connectPromises.set(shardId, resolve));
Expand All @@ -141,11 +143,11 @@ export class WorkerShardingStrategy implements IShardingStrategy {
const promises = [];

for (const [shardId, worker] of this.#workerByShardId.entries()) {
const payload: WorkerSendPayload = {
const payload = {
op: WorkerSendPayloadOp.Destroy,
shardId,
options,
};
} satisfies WorkerSendPayload;

promises.push(
// eslint-disable-next-line no-promise-executor-return, promise/prefer-await-to-then
Expand All @@ -169,11 +171,11 @@ export class WorkerShardingStrategy implements IShardingStrategy {
throw new Error(`No worker found for shard ${shardId}`);
}

const payload: WorkerSendPayload = {
const payload = {
op: WorkerSendPayloadOp.Send,
shardId,
payload: data,
};
} satisfies WorkerSendPayload;
worker.postMessage(payload);
}

Expand Down Expand Up @@ -213,6 +215,16 @@ export class WorkerShardingStrategy implements IShardingStrategy {
await this.manager.options.updateSessionInfo(payload.shardId, payload.session);
break;
}

case WorkerRecievePayloadOp.WaitForIdentify: {
await this.throttler.waitForIdentify();
const response: WorkerSendPayload = {
op: WorkerSendPayloadOp.ShardCanIdentify,
nonce: payload.nonce,
};
worker.postMessage(response);
break;
}
}
}
}
8 changes: 6 additions & 2 deletions packages/ws/src/strategies/sharding/worker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,12 @@ for (const shardId of data.shardIds) {
for (const event of Object.values(WebSocketShardEvents)) {
// @ts-expect-error: Event types incompatible
shard.on(event, (data) => {
const payload: WorkerRecievePayload = {
const payload = {
op: WorkerRecievePayloadOp.Event,
event,
data,
shardId,
};
} satisfies WorkerRecievePayload;
parentPort!.postMessage(payload);
});
}
Expand Down Expand Up @@ -93,5 +93,9 @@ parentPort!
case WorkerSendPayloadOp.SessionInfoResponse: {
break;
}

case WorkerSendPayloadOp.ShardCanIdentify: {
break;
}
}
});
36 changes: 23 additions & 13 deletions packages/ws/src/utils/IdentifyThrottler.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import { setTimeout as sleep } from 'node:timers/promises';
import type { WebSocketManager } from '../ws/WebSocketManager';
import { AsyncQueue } from '@sapphire/async-queue';
import type { WebSocketManager } from '../ws/WebSocketManager.js';

export class IdentifyThrottler {
private readonly queue = new AsyncQueue();

private identifyState = {
remaining: 0,
resetsAt: Number.POSITIVE_INFINITY,
Expand All @@ -10,20 +13,27 @@ export class IdentifyThrottler {
public constructor(private readonly manager: WebSocketManager) {}

public async waitForIdentify(): Promise<void> {
if (this.identifyState.remaining <= 0) {
const diff = this.identifyState.resetsAt - Date.now();
if (diff <= 5_000) {
const time = diff + Math.random() * 1_500;
await sleep(time);
await this.queue.wait();

try {
if (this.identifyState.remaining <= 0) {
const diff = this.identifyState.resetsAt - Date.now();
if (diff <= 5_000) {
// To account for the latency the IDENTIFY payload goes through, we add a bit more wait time
const time = diff + Math.random() * 1_500;
await sleep(time);
}

const info = await this.manager.fetchGatewayInformation();
this.identifyState = {
remaining: info.session_start_limit.max_concurrency,
resetsAt: Date.now() + 5_000,
};
}

const info = await this.manager.fetchGatewayInformation();
this.identifyState = {
remaining: info.session_start_limit.max_concurrency,
resetsAt: Date.now() + 5_000,
};
this.identifyState.remaining--;
} finally {
this.queue.shift();
}

this.identifyState.remaining--;
}
}
3 changes: 3 additions & 0 deletions packages/ws/src/ws/WebSocketShard.ts
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,9 @@ export class WebSocketShard extends AsyncEventEmitter<WebSocketShardEventsMap> {
`intents: ${this.strategy.options.intents}`,
`compression: ${this.inflate ? 'zlib-stream' : this.useIdentifyCompress ? 'identify' : 'none'}`,
]);

await this.strategy.waitForIdentify();

const d: GatewayIdentifyData = {
token: this.strategy.options.token,
properties: this.strategy.options.identifyProperties,
Expand Down

0 comments on commit 8f552a0

Please sign in to comment.