Skip to content

Commit

Permalink
refactor: abstract identify throttling and correct max_concurrency ha…
Browse files Browse the repository at this point in the history
…ndling (#9375)

* refactor: properly support max_concurrency ratelimit keys

* fix: properly block for same key

* chore: export session state

* chore: throttler no longer requires manager

* refactor: abstract throttlers

* chore: proper member order

* chore: remove leftover debug log

* chore: use @link tag in doc comment

Co-authored-by: Jiralite <33201955+Jiralite@users.noreply.github.com>

* chore: suggested changes

* fix(WebSocketShard): cancel identify if the shard closed in the meantime

* refactor(throttlers): support abort signals

* fix: memory leak

* chore: remove leftover

---------

Co-authored-by: Jiralite <33201955+Jiralite@users.noreply.github.com>
Co-authored-by: kodiakhq[bot] <49736102+kodiakhq[bot]@users.noreply.github.com>
  • Loading branch information
3 people committed Apr 14, 2023
1 parent cac3c07 commit 02dfaf1
Show file tree
Hide file tree
Showing 16 changed files with 279 additions and 161 deletions.
30 changes: 15 additions & 15 deletions packages/ws/__tests__/strategy/WorkerShardingStrategy.test.ts
Expand Up @@ -57,9 +57,9 @@ vi.mock('node:worker_threads', async () => {
this.emit('online');
// same deal here
setImmediate(() => {
const message = {
const message: WorkerReceivePayload = {
op: WorkerReceivePayloadOp.WorkerReady,
} satisfies WorkerReceivePayload;
};
this.emit('message', message);
});
});
Expand All @@ -68,39 +68,39 @@ vi.mock('node:worker_threads', async () => {
public postMessage(message: WorkerSendPayload) {
switch (message.op) {
case WorkerSendPayloadOp.Connect: {
const response = {
const response: WorkerReceivePayload = {
op: WorkerReceivePayloadOp.Connected,
shardId: message.shardId,
} satisfies WorkerReceivePayload;
};
this.emit('message', response);
break;
}

case WorkerSendPayloadOp.Destroy: {
const response = {
const response: WorkerReceivePayload = {
op: WorkerReceivePayloadOp.Destroyed,
shardId: message.shardId,
} satisfies WorkerReceivePayload;
};
this.emit('message', response);
break;
}

case WorkerSendPayloadOp.Send: {
if (message.payload.op === GatewayOpcodes.RequestGuildMembers) {
const response = {
const response: WorkerReceivePayload = {
op: WorkerReceivePayloadOp.Event,
shardId: message.shardId,
event: WebSocketShardEvents.Dispatch,
data: memberChunkData,
} satisfies WorkerReceivePayload;
};
this.emit('message', response);

// Fetch session info
const sessionFetch = {
const sessionFetch: WorkerReceivePayload = {
op: WorkerReceivePayloadOp.RetrieveSessionInfo,
shardId: message.shardId,
nonce: Math.random(),
} satisfies WorkerReceivePayload;
};
this.emit('message', sessionFetch);
}

Expand All @@ -111,16 +111,16 @@ vi.mock('node:worker_threads', async () => {
case WorkerSendPayloadOp.SessionInfoResponse: {
message.session ??= sessionInfo;

const session = {
const session: WorkerReceivePayload = {
op: WorkerReceivePayloadOp.UpdateSessionInfo,
shardId: message.session.shardId,
session: { ...message.session, sequence: message.session.sequence + 1 },
} satisfies WorkerReceivePayload;
};
this.emit('message', session);
break;
}

case WorkerSendPayloadOp.ShardCanIdentify: {
case WorkerSendPayloadOp.ShardIdentifyResponse: {
break;
}

Expand Down Expand Up @@ -198,10 +198,10 @@ test('spawn, connect, send a message, session info, and destroy', async () => {
expect.objectContaining({ workerData: expect.objectContaining({ shardIds: [0, 1] }) }),
);

const payload = {
const payload: GatewaySendPayload = {
op: GatewayOpcodes.RequestGuildMembers,
d: { guild_id: '123', limit: 0, query: '' },
} satisfies GatewaySendPayload;
};
await manager.send(0, payload);
expect(mockSend).toHaveBeenCalledWith(0, payload);
expect(managerEmitSpy).toHaveBeenCalledWith(WebSocketShardEvents.Dispatch, {
Expand Down
46 changes: 0 additions & 46 deletions packages/ws/__tests__/util/IdentifyThrottler.test.ts

This file was deleted.

32 changes: 32 additions & 0 deletions packages/ws/__tests__/util/SimpleIdentifyThrottler.test.ts
@@ -0,0 +1,32 @@
import { setTimeout as sleep } from 'node:timers/promises';
import { expect, test, vi, type Mock } from 'vitest';
import { SimpleIdentifyThrottler } from '../../src/index.js';

vi.mock('node:timers/promises', () => ({
setTimeout: vi.fn(),
}));

const throttler = new SimpleIdentifyThrottler(2);

vi.useFakeTimers();

const NOW = vi.fn().mockReturnValue(Date.now());
global.Date.now = NOW;

test('basic case', async () => {
// Those shouldn't wait since they're in different keys

await throttler.waitForIdentify(0);
expect(sleep).not.toHaveBeenCalled();

await throttler.waitForIdentify(1);
expect(sleep).not.toHaveBeenCalled();

// Those should wait

await throttler.waitForIdentify(2);
expect(sleep).toHaveBeenCalledTimes(1);

await throttler.waitForIdentify(3);
expect(sleep).toHaveBeenCalledTimes(2);
});
4 changes: 3 additions & 1 deletion packages/ws/src/index.ts
Expand Up @@ -6,8 +6,10 @@ export * from './strategies/sharding/IShardingStrategy.js';
export * from './strategies/sharding/SimpleShardingStrategy.js';
export * from './strategies/sharding/WorkerShardingStrategy.js';

export * from './throttling/IIdentifyThrottler.js';
export * from './throttling/SimpleIdentifyThrottler.js';

export * from './utils/constants.js';
export * from './utils/IdentifyThrottler.js';
export * from './utils/WorkerBootstrapper.js';

export * from './ws/WebSocketManager.js';
Expand Down
28 changes: 23 additions & 5 deletions packages/ws/src/strategies/context/IContextFetchingStrategy.ts
Expand Up @@ -5,7 +5,13 @@ import type { SessionInfo, WebSocketManager, WebSocketManagerOptions } from '../
export interface FetchingStrategyOptions
extends Omit<
WebSocketManagerOptions,
'buildStrategy' | 'rest' | 'retrieveSessionInfo' | 'shardCount' | 'shardIds' | 'updateSessionInfo'
| 'buildIdentifyThrottler'
| 'buildStrategy'
| 'rest'
| 'retrieveSessionInfo'
| 'shardCount'
| 'shardIds'
| 'updateSessionInfo'
> {
readonly gatewayInformation: APIGatewayBotInfo;
readonly shardCount: number;
Expand All @@ -18,13 +24,25 @@ export interface IContextFetchingStrategy {
readonly options: FetchingStrategyOptions;
retrieveSessionInfo(shardId: number): Awaitable<SessionInfo | null>;
updateSessionInfo(shardId: number, sessionInfo: SessionInfo | null): Awaitable<void>;
waitForIdentify(): Promise<void>;
/**
* Resolves once the given shard should be allowed to identify, or rejects if the operation was aborted
*/
waitForIdentify(shardId: number, signal: AbortSignal): Promise<void>;
}

export async function managerToFetchingStrategyOptions(manager: WebSocketManager): Promise<FetchingStrategyOptions> {
// eslint-disable-next-line @typescript-eslint/unbound-method
const { buildStrategy, retrieveSessionInfo, updateSessionInfo, shardCount, shardIds, rest, ...managerOptions } =
manager.options;
/* eslint-disable @typescript-eslint/unbound-method */
const {
buildIdentifyThrottler,
buildStrategy,
retrieveSessionInfo,
updateSessionInfo,
shardCount,
shardIds,
rest,
...managerOptions
} = manager.options;
/* eslint-enable @typescript-eslint/unbound-method */

return {
...managerOptions,
Expand Down
30 changes: 14 additions & 16 deletions packages/ws/src/strategies/context/SimpleContextFetchingStrategy.ts
@@ -1,29 +1,26 @@
import { IdentifyThrottler } from '../../utils/IdentifyThrottler.js';
import type { IIdentifyThrottler } from '../../throttling/IIdentifyThrottler.js';
import type { SessionInfo, WebSocketManager } from '../../ws/WebSocketManager.js';
import type { FetchingStrategyOptions, IContextFetchingStrategy } from './IContextFetchingStrategy.js';

export class SimpleContextFetchingStrategy implements IContextFetchingStrategy {
// This strategy assumes every shard is running under the same process - therefore we need a single
// IdentifyThrottler per manager.
private static throttlerCache = new WeakMap<WebSocketManager, IdentifyThrottler>();
private static throttlerCache = new WeakMap<WebSocketManager, IIdentifyThrottler>();

private static ensureThrottler(manager: WebSocketManager): IdentifyThrottler {
const existing = SimpleContextFetchingStrategy.throttlerCache.get(manager);
if (existing) {
return existing;
private static async ensureThrottler(manager: WebSocketManager): Promise<IIdentifyThrottler> {
const throttler = SimpleContextFetchingStrategy.throttlerCache.get(manager);
if (throttler) {
return throttler;
}

const throttler = new IdentifyThrottler(manager);
SimpleContextFetchingStrategy.throttlerCache.set(manager, throttler);
return throttler;
}

private readonly throttler: IdentifyThrottler;
const newThrottler = await manager.options.buildIdentifyThrottler(manager);
SimpleContextFetchingStrategy.throttlerCache.set(manager, newThrottler);

public constructor(private readonly manager: WebSocketManager, public readonly options: FetchingStrategyOptions) {
this.throttler = SimpleContextFetchingStrategy.ensureThrottler(manager);
return newThrottler;
}

public constructor(private readonly manager: WebSocketManager, public readonly options: FetchingStrategyOptions) {}

public async retrieveSessionInfo(shardId: number): Promise<SessionInfo | null> {
return this.manager.options.retrieveSessionInfo(shardId);
}
Expand All @@ -32,7 +29,8 @@ export class SimpleContextFetchingStrategy implements IContextFetchingStrategy {
return this.manager.options.updateSessionInfo(shardId, sessionInfo);
}

public async waitForIdentify(): Promise<void> {
await this.throttler.waitForIdentify();
public async waitForIdentify(shardId: number, signal: AbortSignal): Promise<void> {
const throttler = await SimpleContextFetchingStrategy.ensureThrottler(this.manager);
await throttler.waitForIdentify(shardId, signal);
}
}
60 changes: 47 additions & 13 deletions packages/ws/src/strategies/context/WorkerContextFetchingStrategy.ts
Expand Up @@ -9,10 +9,17 @@ import {
} from '../sharding/WorkerShardingStrategy.js';
import type { FetchingStrategyOptions, IContextFetchingStrategy } from './IContextFetchingStrategy.js';

// Because the global types are incomplete for whatever reason
interface PolyFillAbortSignal {
readonly aborted: boolean;
addEventListener(type: 'abort', listener: () => void): void;
removeEventListener(type: 'abort', listener: () => void): void;
}

export class WorkerContextFetchingStrategy implements IContextFetchingStrategy {
private readonly sessionPromises = new Collection<number, (session: SessionInfo | null) => void>();

private readonly waitForIdentifyPromises = new Collection<number, () => void>();
private readonly waitForIdentifyPromises = new Collection<number, { reject(): void; resolve(): void }>();

public constructor(public readonly options: FetchingStrategyOptions) {
if (isMainThread) {
Expand All @@ -25,44 +32,71 @@ export class WorkerContextFetchingStrategy implements IContextFetchingStrategy {
this.sessionPromises.delete(payload.nonce);
}

if (payload.op === WorkerSendPayloadOp.ShardCanIdentify) {
this.waitForIdentifyPromises.get(payload.nonce)?.();
if (payload.op === WorkerSendPayloadOp.ShardIdentifyResponse) {
const promise = this.waitForIdentifyPromises.get(payload.nonce);
if (payload.ok) {
promise?.resolve();
} else {
promise?.reject();
}

this.waitForIdentifyPromises.delete(payload.nonce);
}
});
}

public async retrieveSessionInfo(shardId: number): Promise<SessionInfo | null> {
const nonce = Math.random();
const payload = {
const payload: WorkerReceivePayload = {
op: WorkerReceivePayloadOp.RetrieveSessionInfo,
shardId,
nonce,
} satisfies WorkerReceivePayload;
};
// eslint-disable-next-line no-promise-executor-return
const promise = new Promise<SessionInfo | null>((resolve) => this.sessionPromises.set(nonce, resolve));
parentPort!.postMessage(payload);
return promise;
}

public updateSessionInfo(shardId: number, sessionInfo: SessionInfo | null) {
const payload = {
const payload: WorkerReceivePayload = {
op: WorkerReceivePayloadOp.UpdateSessionInfo,
shardId,
session: sessionInfo,
} satisfies WorkerReceivePayload;
};
parentPort!.postMessage(payload);
}

public async waitForIdentify(): Promise<void> {
public async waitForIdentify(shardId: number, signal: AbortSignal): Promise<void> {
const nonce = Math.random();
const payload = {

const payload: WorkerReceivePayload = {
op: WorkerReceivePayloadOp.WaitForIdentify,
nonce,
} satisfies WorkerReceivePayload;
// eslint-disable-next-line no-promise-executor-return
const promise = new Promise<void>((resolve) => this.waitForIdentifyPromises.set(nonce, resolve));
shardId,
};
const promise = new Promise<void>((resolve, reject) =>
// eslint-disable-next-line no-promise-executor-return
this.waitForIdentifyPromises.set(nonce, { resolve, reject }),
);

parentPort!.postMessage(payload);
return promise;

const listener = () => {
const payload: WorkerReceivePayload = {
op: WorkerReceivePayloadOp.CancelIdentify,
nonce,
};

parentPort!.postMessage(payload);
};

(signal as unknown as PolyFillAbortSignal).addEventListener('abort', listener);

try {
await promise;
} finally {
(signal as unknown as PolyFillAbortSignal).removeEventListener('abort', listener);
}
}
}
Expand Up @@ -23,6 +23,7 @@ export class SimpleShardingStrategy implements IShardingStrategy {
*/
public async spawn(shardIds: number[]) {
const strategyOptions = await managerToFetchingStrategyOptions(this.manager);

for (const shardId of shardIds) {
const strategy = new SimpleContextFetchingStrategy(this.manager, strategyOptions);
const shard = new WebSocketShard(strategy, shardId);
Expand Down

0 comments on commit 02dfaf1

Please sign in to comment.