summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHazelnoot <acomputerdog@gmail.com>2024-12-07 10:22:45 -0500
committerHazelnoot <acomputerdog@gmail.com>2024-12-07 10:22:49 -0500
commitffc2737478c6f9efd5de9fbaf526b13164727f87 (patch)
tree416a391cdd024e11ad34dfc7707d28dcbbecce19
parentmerge: Fix Content-Length resetting for partial content length requests (!796) (diff)
downloadsharkey-ffc2737478c6f9efd5de9fbaf526b13164727f87.tar.gz
sharkey-ffc2737478c6f9efd5de9fbaf526b13164727f87.tar.bz2
sharkey-ffc2737478c6f9efd5de9fbaf526b13164727f87.zip
implement SkRateLimiterService with Leaky Bucket rate limiting
-rw-r--r--packages/backend/eslint.config.js7
-rw-r--r--packages/backend/src/core/CoreModule.ts6
-rw-r--r--packages/backend/src/core/EnvService.ts20
-rw-r--r--packages/backend/src/core/TimeService.ts27
-rw-r--r--packages/backend/src/server/ServerModule.ts6
-rw-r--r--packages/backend/src/server/api/ApiCallService.ts64
-rw-r--r--packages/backend/src/server/api/RateLimiterService.ts16
-rw-r--r--packages/backend/src/server/api/SkRateLimiterService.ts279
-rw-r--r--packages/backend/test/unit/server/api/SkRateLimiterServiceTests.ts703
9 files changed, 1102 insertions, 26 deletions
diff --git a/packages/backend/eslint.config.js b/packages/backend/eslint.config.js
index 452045bc3e..7ee9953478 100644
--- a/packages/backend/eslint.config.js
+++ b/packages/backend/eslint.config.js
@@ -42,6 +42,13 @@ export default [
name: '__filename',
message: 'Not in ESModule. Use `import.meta.url` instead.',
}],
+ // https://typescript-eslint.io/rules/prefer-nullish-coalescing/
+ '@typescript-eslint/prefer-nullish-coalescing': ['warn', {
+ ignorePrimitives: {
+ // Without this, the rule breaks for nullable booleans
+ boolean: true,
+ },
+ }],
},
},
{
diff --git a/packages/backend/src/core/CoreModule.ts b/packages/backend/src/core/CoreModule.ts
index c083068392..b18db7f366 100644
--- a/packages/backend/src/core/CoreModule.ts
+++ b/packages/backend/src/core/CoreModule.ts
@@ -14,6 +14,8 @@ import { AbuseReportNotificationService } from '@/core/AbuseReportNotificationSe
import { SystemWebhookService } from '@/core/SystemWebhookService.js';
import { UserSearchService } from '@/core/UserSearchService.js';
import { WebhookTestService } from '@/core/WebhookTestService.js';
+import { TimeService } from '@/core/TimeService.js';
+import { EnvService } from '@/core/EnvService.js';
import { AccountMoveService } from './AccountMoveService.js';
import { AccountUpdateService } from './AccountUpdateService.js';
import { AnnouncementService } from './AnnouncementService.js';
@@ -381,6 +383,8 @@ const $SponsorsService: Provider = { provide: 'SponsorsService', useExisting: Sp
ChannelFollowingService,
RegistryApiService,
ReversiService,
+ TimeService,
+ EnvService,
ChartLoggerService,
FederationChart,
@@ -680,6 +684,8 @@ const $SponsorsService: Provider = { provide: 'SponsorsService', useExisting: Sp
ChannelFollowingService,
RegistryApiService,
ReversiService,
+ TimeService,
+ EnvService,
FederationChart,
NotesChart,
diff --git a/packages/backend/src/core/EnvService.ts b/packages/backend/src/core/EnvService.ts
new file mode 100644
index 0000000000..8cc3b95735
--- /dev/null
+++ b/packages/backend/src/core/EnvService.ts
@@ -0,0 +1,20 @@
+/*
+ * SPDX-FileCopyrightText: hazelnoot and other Sharkey contributors
+ * SPDX-License-Identifier: AGPL-3.0-only
+ */
+
+import { Injectable } from '@nestjs/common';
+
+/**
+ * Provides access to the process environment variables.
+ * This exists for testing purposes, so that a test can mock the environment without corrupting state for other tests.
+ */
+@Injectable()
+export class EnvService {
+ /**
+ * Passthrough to process.env
+ */
+ public get env() {
+ return process.env;
+ }
+}
diff --git a/packages/backend/src/core/TimeService.ts b/packages/backend/src/core/TimeService.ts
new file mode 100644
index 0000000000..59c3d4c12b
--- /dev/null
+++ b/packages/backend/src/core/TimeService.ts
@@ -0,0 +1,27 @@
+/*
+ * SPDX-FileCopyrightText: hazelnoot and other Sharkey contributors
+ * SPDX-License-Identifier: AGPL-3.0-only
+ */
+
+import { Injectable } from '@nestjs/common';
+
+/**
+ * Provides abstractions to access the current time.
+ * Exists for unit testing purposes, so that tests can "simulate" any given time for consistency.
+ */
+@Injectable()
+export class TimeService {
+ /**
+ * Returns Date.now()
+ */
+ public get now() {
+ return Date.now();
+ }
+
+ /**
+ * Returns a new Date instance.
+ */
+ public get date() {
+ return new Date();
+ }
+}
diff --git a/packages/backend/src/server/ServerModule.ts b/packages/backend/src/server/ServerModule.ts
index 216e6b4fb8..890447a47f 100644
--- a/packages/backend/src/server/ServerModule.ts
+++ b/packages/backend/src/server/ServerModule.ts
@@ -6,6 +6,7 @@
import { Module } from '@nestjs/common';
import { EndpointsModule } from '@/server/api/EndpointsModule.js';
import { CoreModule } from '@/core/CoreModule.js';
+import { SkRateLimiterService } from '@/server/api/SkRateLimiterService.js';
import { ApiCallService } from './api/ApiCallService.js';
import { FileServerService } from './FileServerService.js';
import { HealthServerService } from './HealthServerService.js';
@@ -73,7 +74,10 @@ import { SigninWithPasskeyApiService } from './api/SigninWithPasskeyApiService.j
ApiLoggerService,
ApiServerService,
AuthenticateService,
- RateLimiterService,
+ {
+ provide: RateLimiterService,
+ useClass: SkRateLimiterService,
+ },
SigninApiService,
SigninWithPasskeyApiService,
SigninService,
diff --git a/packages/backend/src/server/api/ApiCallService.ts b/packages/backend/src/server/api/ApiCallService.ts
index 6f51825494..14367e02bb 100644
--- a/packages/backend/src/server/api/ApiCallService.ts
+++ b/packages/backend/src/server/api/ApiCallService.ts
@@ -8,6 +8,7 @@ import * as fs from 'node:fs';
import * as stream from 'node:stream/promises';
import { Inject, Injectable } from '@nestjs/common';
import * as Sentry from '@sentry/node';
+import { LimiterInfo } from 'ratelimiter';
import { DI } from '@/di-symbols.js';
import { getIpHash } from '@/misc/get-ip-hash.js';
import type { MiLocalUser, MiUser } from '@/models/User.js';
@@ -18,6 +19,7 @@ import { createTemp } from '@/misc/create-temp.js';
import { bindThis } from '@/decorators.js';
import { RoleService } from '@/core/RoleService.js';
import type { Config } from '@/config.js';
+import { isLimitInfo } from '@/server/api/SkRateLimiterService.js';
import { ApiError } from './error.js';
import { RateLimiterService } from './RateLimiterService.js';
import { ApiLoggerService } from './ApiLoggerService.js';
@@ -68,12 +70,17 @@ export class ApiCallService implements OnApplicationShutdown {
} else if (err.code === 'RATE_LIMIT_EXCEEDED') {
const info: unknown = err.info;
const unixEpochInSeconds = Date.now();
- if (typeof(info) === 'object' && info && 'resetMs' in info && typeof(info.resetMs) === 'number') {
+ if (isLimitInfo(info)) {
+ // Number of seconds to wait before trying again. Left for backwards compatibility.
+ reply.header('Retry-After', info.resetSec.toString());
+ // Number of milliseconds to wait before trying again.
+ reply.header('X-RateLimit-Reset', info.resetMs.toString());
+ } else if (typeof(info) === 'object' && info && 'resetMs' in info && typeof(info.resetMs) === 'number') {
const cooldownInSeconds = Math.ceil((info.resetMs - unixEpochInSeconds) / 1000);
// もしかするとマイナスになる可能性がなくはないのでマイナスだったら0にしておく
reply.header('Retry-After', Math.max(cooldownInSeconds, 0).toString(10));
} else {
- this.logger.warn(`rate limit information has unexpected type ${typeof(err.info?.reset)}`);
+ this.logger.warn(`rate limit information has unexpected type: ${JSON.stringify(info)}`);
}
} else if (err.kind === 'client') {
reply.header('WWW-Authenticate', `Bearer realm="Misskey", error="invalid_request", error_description="${err.message}"`);
@@ -168,7 +175,7 @@ export class ApiCallService implements OnApplicationShutdown {
return;
}
this.authenticateService.authenticate(token).then(([user, app]) => {
- this.call(endpoint, user, app, body, null, request).then((res) => {
+ this.call(endpoint, user, app, body, null, request, reply).then((res) => {
if (request.method === 'GET' && endpoint.meta.cacheSec && !token && !user) {
reply.header('Cache-Control', `public, max-age=${endpoint.meta.cacheSec}`);
}
@@ -229,7 +236,7 @@ export class ApiCallService implements OnApplicationShutdown {
this.call(endpoint, user, app, fields, {
name: multipartData.filename,
path: path,
- }, request).then((res) => {
+ }, request, reply).then((res) => {
this.send(reply, res);
}).catch((err: ApiError) => {
this.#sendApiError(reply, err);
@@ -304,6 +311,7 @@ export class ApiCallService implements OnApplicationShutdown {
path: string;
} | null,
request: FastifyRequest<{ Body: Record<string, unknown> | undefined, Querystring: Record<string, unknown> }>,
+ reply: FastifyReply,
) {
const isSecure = user != null && token == null;
@@ -339,19 +347,41 @@ export class ApiCallService implements OnApplicationShutdown {
if (factor > 0) {
// Rate limit
- await this.rateLimiterService.limit(limit as IEndpointMeta['limit'] & { key: NonNullable<string> }, limitActor, factor).catch(err => {
- if ('info' in err) {
- // errはLimiter.LimiterInfoであることが期待される
- throw new ApiError({
- message: 'Rate limit exceeded. Please try again later.',
- code: 'RATE_LIMIT_EXCEEDED',
- id: 'd5826d14-3982-4d2e-8011-b9e9f02499ef',
- httpStatusCode: 429,
- }, err.info);
- } else {
- throw new TypeError('information must be a rate-limiter information.');
- }
- });
+ const info = await this.rateLimiterService.limit(limit as IEndpointMeta['limit'] & { key: NonNullable<string> }, limitActor, factor)
+ .then(info => {
+ // We always want these headers, because clients need them for pacing.
+ // Conditional check in case we somehow revert to the old limiter, which does not return info.
+ if (info) {
+ // Number of seconds until the limit has fully reset.
+ reply.header('X-RateLimit-Clear', info.fullResetSec.toString());
+ // Number of calls that can be made before being limited.
+ reply.header('X-RateLimit-Remaining', info.remaining.toString());
+
+ // Only forward the info object if it's blocked, otherwise we'll reject *all* requests
+ if (info.blocked) {
+ return info;
+ }
+ }
+
+ return undefined;
+ })
+ .catch(err => {
+ // The old limiter throws info instead of returning it.
+ if ('info' in err) {
+ return err.info as LimiterInfo;
+ } else {
+ throw err;
+ }
+ });
+
+ if (info) {
+ throw new ApiError({
+ message: 'Rate limit exceeded. Please try again later.',
+ code: 'RATE_LIMIT_EXCEEDED',
+ id: 'd5826d14-3982-4d2e-8011-b9e9f02499ef',
+ httpStatusCode: 429,
+ }, info);
+ }
}
}
diff --git a/packages/backend/src/server/api/RateLimiterService.ts b/packages/backend/src/server/api/RateLimiterService.ts
index e9afb9d05a..33db016a7c 100644
--- a/packages/backend/src/server/api/RateLimiterService.ts
+++ b/packages/backend/src/server/api/RateLimiterService.ts
@@ -10,28 +10,28 @@ import { DI } from '@/di-symbols.js';
import type Logger from '@/logger.js';
import { LoggerService } from '@/core/LoggerService.js';
import { bindThis } from '@/decorators.js';
+import type { LimitInfo } from '@/server/api/SkRateLimiterService.js';
+import { EnvService } from '@/core/EnvService.js';
import type { IEndpointMeta } from './endpoints.js';
@Injectable()
export class RateLimiterService {
- private logger: Logger;
- private disabled = false;
+ protected readonly logger: Logger;
+ protected readonly disabled: boolean;
constructor(
@Inject(DI.redis)
- private redisClient: Redis.Redis,
+ protected readonly redisClient: Redis.Redis,
private loggerService: LoggerService,
+ envService: EnvService,
) {
this.logger = this.loggerService.getLogger('limiter');
-
- if (process.env.NODE_ENV !== 'production') {
- this.disabled = true;
- }
+ this.disabled = envService.env.NODE_ENV !== 'production';
}
@bindThis
- public limit(limitation: IEndpointMeta['limit'] & { key: NonNullable<string> }, actor: string, factor = 1) {
+ public limit(limitation: IEndpointMeta['limit'] & { key: NonNullable<string> }, actor: string, factor = 1): Promise<LimitInfo | void> {
return new Promise<void>((ok, reject) => {
if (this.disabled) ok();
diff --git a/packages/backend/src/server/api/SkRateLimiterService.ts b/packages/backend/src/server/api/SkRateLimiterService.ts
new file mode 100644
index 0000000000..c44accdb09
--- /dev/null
+++ b/packages/backend/src/server/api/SkRateLimiterService.ts
@@ -0,0 +1,279 @@
+/*
+ * SPDX-FileCopyrightText: hazelnoot and other Sharkey contributors
+ * SPDX-License-Identifier: AGPL-3.0-only
+ */
+
+import { Injectable } from '@nestjs/common';
+import Redis from 'ioredis';
+import type { IEndpointMeta } from '@/server/api/endpoints.js';
+import { LoggerService } from '@/core/LoggerService.js';
+import { TimeService } from '@/core/TimeService.js';
+import { EnvService } from '@/core/EnvService.js';
+import { RateLimiterService } from './RateLimiterService.js';
+
+/**
+ * Metadata about the current status of a rate limiter
+ */
+export interface LimitInfo {
+ /**
+ * True if the limit has been reached, and the call should be blocked.
+ */
+ blocked: boolean;
+
+ /**
+ * Number of calls that can be made before the limit is triggered.
+ */
+ remaining: number;
+
+ /**
+ * Time in seconds until the next call can be made, or zero if the next call can be made immediately.
+ * Rounded up to the nearest second.
+ */
+ resetSec: number;
+
+ /**
+ * Time in milliseconds until the next call can be made, or zero if the next call can be made immediately.
+ * Rounded up to the nearest milliseconds.
+ */
+ resetMs: number;
+
+ /**
+ * Time in seconds until the limit has fully reset.
+ * Rounded up to the nearest second.
+ */
+ fullResetSec: number;
+
+ /**
+ * Time in milliseconds until the limit has fully reset.
+ * Rounded up to the nearest millisecond.
+ */
+ fullResetMs: number;
+}
+
+export function isLimitInfo(info: unknown): info is LimitInfo {
+ if (info == null) return false;
+ if (typeof(info) !== 'object') return false;
+ if (!('blocked' in info) || typeof(info.blocked) !== 'boolean') return false;
+ if (!('remaining' in info) || typeof(info.remaining) !== 'number') return false;
+ if (!('resetSec' in info) || typeof(info.resetSec) !== 'number') return false;
+ if (!('resetMs' in info) || typeof(info.resetMs) !== 'number') return false;
+ if (!('fullResetSec' in info) || typeof(info.fullResetSec) !== 'number') return false;
+ if (!('fullResetMs' in info) || typeof(info.fullResetMs) !== 'number') return false;
+ return true;
+}
+
+/**
+ * Rate limit based on "leaky bucket" logic.
+ * The bucket count increases with each call, and decreases gradually at a given rate.
+ * The subject is blocked until the bucket count drops below the limit.
+ */
+export interface RateLimit {
+ /**
+ * Unique key identifying the particular resource (or resource group) being limited.
+ */
+ key: string;
+
+ /**
+ * Constant value identifying the type of rate limit.
+ */
+ type: 'bucket';
+
+ /**
+ * Size of the bucket, in number of requests.
+ * The subject will be blocked when the number of calls exceeds this size.
+ */
+ size: number;
+
+ /**
+ * How often the bucket should "drip" and reduce the counter, measured in milliseconds.
+ * Defaults to 1000 (1 second).
+ */
+ dripRate?: number;
+
+ /**
+ * Amount to reduce the counter on each drip.
+ * Defaults to 1.
+ */
+ dripSize?: number;
+}
+
+export type SupportedRateLimit = RateLimit | LegacyRateLimit;
+export type LegacyRateLimit = IEndpointMeta['limit'] & { key: NonNullable<string>, type: undefined | 'legacy' };
+
+export function isLegacyRateLimit(limit: SupportedRateLimit): limit is LegacyRateLimit {
+ return limit.type === undefined || limit.type === 'legacy';
+}
+
+export function hasMinLimit(limit: LegacyRateLimit): limit is LegacyRateLimit & { minInterval: number } {
+ return !!limit.minInterval;
+}
+
+@Injectable()
+export class SkRateLimiterService extends RateLimiterService {
+ constructor(
+ private readonly timeService: TimeService,
+ redisClient: Redis.Redis,
+ loggerService: LoggerService,
+ envService: EnvService,
+ ) {
+ super(redisClient, loggerService, envService);
+ }
+
+ public async limit(limit: SupportedRateLimit, actor: string, factor = 1): Promise<LimitInfo> {
+ if (this.disabled) {
+ return {
+ blocked: false,
+ remaining: Number.MAX_SAFE_INTEGER,
+ resetSec: 0,
+ resetMs: 0,
+ fullResetSec: 0,
+ fullResetMs: 0,
+ };
+ }
+
+ if (isLegacyRateLimit(limit)) {
+ return await this.limitLegacy(limit, actor, factor);
+ } else {
+ return await this.limitBucket(limit, actor, factor);
+ }
+ }
+
+ private async limitLegacy(limit: LegacyRateLimit, actor: string, factor: number): Promise<LimitInfo> {
+ const promises: Promise<LimitInfo | null>[] = [];
+
+ // The "min" limit - if present - is handled directly.
+ if (hasMinLimit(limit)) {
+ promises.push(
+ this.limitMin(limit, actor, factor),
+ );
+ }
+
+ // Convert the "max" limit into a leaky bucket with 1 drip / second rate.
+ if (limit.max && limit.duration) {
+ promises.push(
+ this.limitBucket({
+ type: 'bucket',
+ key: limit.key,
+ size: limit.max,
+ dripRate: Math.round(limit.duration / limit.max),
+ }, actor, factor),
+ );
+ }
+
+ const [lim1, lim2] = await Promise.all(promises);
+ return {
+ blocked: (lim1?.blocked || lim2?.blocked) ?? false,
+ remaining: Math.min(lim1?.remaining ?? 1, lim2?.remaining ?? 1),
+ resetSec: Math.max(lim1?.resetSec ?? 0, lim2?.resetSec ?? 0),
+ resetMs: Math.max(lim1?.resetMs ?? 0, lim2?.resetMs ?? 0),
+ fullResetSec: Math.max(lim1?.fullResetSec ?? 0, lim2?.fullResetSec ?? 0),
+ fullResetMs: Math.max(lim1?.fullResetMs ?? 0, lim2?.fullResetMs ?? 0),
+ };
+ }
+
+ private async limitMin(limit: LegacyRateLimit & { minInterval: number }, actor: string, factor: number): Promise<LimitInfo | null> {
+ const counter = await this.getLimitCounter(limit, actor, 'min');
+ const maxCalls = Math.max(Math.ceil(factor), 1);
+
+ // Update expiration
+ if (counter.c >= maxCalls) {
+ const isCleared = this.timeService.now - counter.t >= limit.minInterval;
+ if (isCleared) {
+ counter.c = 0;
+ }
+ }
+
+ const blocked = counter.c >= maxCalls;
+ if (!blocked) {
+ counter.c++;
+ counter.t = this.timeService.now;
+ }
+
+ // Calculate limit status
+ const remaining = Math.max(maxCalls - counter.c, 0);
+ const fullResetMs = Math.max(Math.ceil(limit.minInterval - (this.timeService.now - counter.t)), 0);
+ const fullResetSec = Math.ceil(fullResetMs / 1000);
+ const resetMs = remaining < 1 ? fullResetMs : 0;
+ const resetSec = remaining < 1 ? fullResetSec : 0;
+ const limitInfo: LimitInfo = { blocked, remaining, resetSec, resetMs, fullResetSec, fullResetMs,
+ };
+
+ // Update the limit counter, but not if blocked
+ if (!blocked) {
+ // Don't await, or we will slow down the API.
+ this.setLimitCounter(limit, actor, counter, resetMs, 'min')
+ .catch(err => this.logger.error(`Failed to update limit ${limit.key}:min for ${actor}:`, err));
+ }
+
+ return limitInfo;
+ }
+
+ private async limitBucket(limit: RateLimit, actor: string, factor: number): Promise<LimitInfo> {
+ const counter = await this.getLimitCounter(limit, actor);
+ const dripRate = (limit.dripRate ?? 1000);
+ const dripSize = (limit.dripSize ?? 1);
+ const bucketSize = (limit.size * factor);
+
+ // Update drips
+ if (counter.c > 0) {
+ const dripsSinceLastTick = Math.floor((this.timeService.now - counter.t) / dripRate) * dripSize;
+ counter.c = Math.max(counter.c - dripsSinceLastTick, 0);
+ }
+
+ const blocked = counter.c >= bucketSize;
+ if (!blocked) {
+ counter.c++;
+ counter.t = this.timeService.now;
+ }
+
+ // Calculate limit status
+ const remaining = Math.max(bucketSize - counter.c, 0);
+ const resetMs = remaining > 0 ? 0 : Math.max(dripRate - (this.timeService.now - counter.t), 0);
+ const resetSec = Math.ceil(resetMs / 1000);
+ const fullResetMs = Math.ceil(counter.c / dripSize) * dripRate;
+ const fullResetSec = Math.ceil(fullResetMs / 1000);
+ const limitInfo: LimitInfo = { blocked, remaining, resetSec, resetMs, fullResetSec, fullResetMs };
+
+ // Update the limit counter, but not if blocked
+ if (!blocked) {
+ // Don't await, or we will slow down the API.
+ this.setLimitCounter(limit, actor, counter, fullResetMs)
+ .catch(err => this.logger.error(`Failed to update limit ${limit.key} for ${actor}:`, err));
+ }
+
+ return limitInfo;
+ }
+
+ private async getLimitCounter(limit: SupportedRateLimit, actor: string, subject?: string): Promise<LimitCounter> {
+ const key = createLimitKey(limit, actor, subject);
+
+ const value = await this.redisClient.get(key);
+ if (value == null) {
+ return { t: 0, c: 0 };
+ }
+
+ return JSON.parse(value);
+ }
+
+ private async setLimitCounter(limit: SupportedRateLimit, actor: string, counter: LimitCounter, expirationMs: number, subject?: string): Promise<void> {
+ const key = createLimitKey(limit, actor, subject);
+ const value = JSON.stringify(counter);
+ await this.redisClient.set(key, value, 'PX', expirationMs);
+ }
+}
+
+function createLimitKey(limit: SupportedRateLimit, actor: string, subject?: string): string {
+ if (subject) {
+ return `rl_${actor}_${limit.key}_${subject}`;
+ } else {
+ return `rl_${actor}_${limit.key}`;
+ }
+}
+
+export interface LimitCounter {
+ /** Timestamp */
+ t: number;
+
+ /** Counter */
+ c: number;
+}
diff --git a/packages/backend/test/unit/server/api/SkRateLimiterServiceTests.ts b/packages/backend/test/unit/server/api/SkRateLimiterServiceTests.ts
new file mode 100644
index 0000000000..8554aa39ef
--- /dev/null
+++ b/packages/backend/test/unit/server/api/SkRateLimiterServiceTests.ts
@@ -0,0 +1,703 @@
+/*
+ * SPDX-FileCopyrightText: hazelnoot and other Sharkey contributors
+ * SPDX-License-Identifier: AGPL-3.0-only
+ */
+
+import { KEYWORD } from 'color-convert/conversions.js';
+import type Redis from 'ioredis';
+import { LegacyRateLimit, LimitCounter, RateLimit, SkRateLimiterService } from '@/server/api/SkRateLimiterService.js';
+import { LoggerService } from '@/core/LoggerService.js';
+
+/* eslint-disable @typescript-eslint/no-non-null-assertion */
+/* eslint-disable @typescript-eslint/no-unnecessary-condition */
+
+describe(SkRateLimiterService, () => {
+ let mockTimeService: { now: number, date: Date } = null!;
+ let mockRedisGet: ((key: string) => string | null) | undefined = undefined;
+ let mockRedisSet: ((args: unknown[]) => void) | undefined = undefined;
+ let mockEnvironment: Record<string, string | undefined> = null!;
+ let serviceUnderTest: () => SkRateLimiterService = null!;
+
+ let loggedMessages: { level: string, data: unknown[] }[] = [];
+
+ beforeEach(() => {
+ mockTimeService = {
+ now: 0,
+ get date() {
+ return new Date(mockTimeService.now);
+ },
+ };
+
+ mockRedisGet = undefined;
+ mockRedisSet = undefined;
+ const mockRedisClient = {
+ get(key: string) {
+ if (mockRedisGet) return Promise.resolve(mockRedisGet(key));
+ else return Promise.resolve(null);
+ },
+ set(...args: unknown[]): Promise<void> {
+ if (mockRedisSet) mockRedisSet(args);
+ return Promise.resolve();
+ },
+ } as unknown as Redis.Redis;
+
+ mockEnvironment = Object.create(process.env);
+ mockEnvironment.NODE_ENV = 'production';
+ const mockEnvService = {
+ env: mockEnvironment,
+ };
+
+ loggedMessages = [];
+ const mockLogService = {
+ getLogger() {
+ return {
+ createSubLogger(context: string, color?: KEYWORD) {
+ return mockLogService.getLogger(context, color);
+ },
+ error(...data: unknown[]) {
+ loggedMessages.push({ level: 'error', data });
+ },
+ warn(...data: unknown[]) {
+ loggedMessages.push({ level: 'warn', data });
+ },
+ succ(...data: unknown[]) {
+ loggedMessages.push({ level: 'succ', data });
+ },
+ debug(...data: unknown[]) {
+ loggedMessages.push({ level: 'debug', data });
+ },
+ info(...data: unknown[]) {
+ loggedMessages.push({ level: 'info', data });
+ },
+ };
+ },
+ } as unknown as LoggerService;
+
+ let service: SkRateLimiterService | undefined = undefined;
+ serviceUnderTest = () => {
+ return service ??= new SkRateLimiterService(mockTimeService, mockRedisClient, mockLogService, mockEnvService);
+ };
+ });
+
+ function expectNoUnhandledErrors() {
+ const unhandledErrors = loggedMessages.filter(m => m.level === 'error');
+ if (unhandledErrors.length > 0) {
+ throw new Error(`Test failed: got unhandled errors ${unhandledErrors.join('\n')}`);
+ }
+ }
+
+ describe('limit', () => {
+ const actor = 'actor';
+ const key = 'test';
+
+ let counter: LimitCounter | undefined = undefined;
+ let minCounter: LimitCounter | undefined = undefined;
+
+ beforeEach(() => {
+ counter = undefined;
+ minCounter = undefined;
+
+ mockRedisGet = (key: string) => {
+ if (key === 'rl_actor_test' && counter) {
+ return JSON.stringify(counter);
+ }
+
+ if (key === 'rl_actor_test_min' && minCounter) {
+ return JSON.stringify(minCounter);
+ }
+
+ return null;
+ };
+
+ mockRedisSet = (args: unknown[]) => {
+ const [key, value] = args;
+
+ if (key === 'rl_actor_test') {
+ if (value == null) counter = undefined;
+ else if (typeof(value) === 'string') counter = JSON.parse(value);
+ else throw new Error('invalid redis call');
+ }
+
+ if (key === 'rl_actor_test_min') {
+ if (value == null) minCounter = undefined;
+ else if (typeof(value) === 'string') minCounter = JSON.parse(value);
+ else throw new Error('invalid redis call');
+ }
+ };
+ });
+
+ it('should bypass in non-production', async () => {
+ mockEnvironment.NODE_ENV = 'test';
+
+ const info = await serviceUnderTest().limit({ key: 'l', type: undefined, max: 0 }, 'actor');
+
+ expect(info.blocked).toBeFalsy();
+ expect(info.remaining).toBe(Number.MAX_SAFE_INTEGER);
+ expect(info.resetSec).toBe(0);
+ expect(info.resetMs).toBe(0);
+ expect(info.fullResetSec).toBe(0);
+ expect(info.fullResetMs).toBe(0);
+ });
+
+ describe('with bucket limit', () => {
+ let limit: RateLimit = null!;
+
+ beforeEach(() => {
+ limit = {
+ type: 'bucket',
+ key: 'test',
+ size: 1,
+ };
+ });
+
+ it('should allow when limit is not reached', async () => {
+ const info = await serviceUnderTest().limit(limit, actor);
+
+ expect(info.blocked).toBeFalsy();
+ });
+
+ it('should not error when allowed', async () => {
+ await serviceUnderTest().limit(limit, actor);
+
+ expectNoUnhandledErrors();
+ });
+
+ it('should return correct info when allowed', async () => {
+ limit.size = 2;
+ counter = { c: 1, t: 0 };
+
+ const info = await serviceUnderTest().limit(limit, actor);
+
+ expect(info.remaining).toBe(0);
+ expect(info.resetSec).toBe(1);
+ expect(info.resetMs).toBe(1000);
+ expect(info.fullResetSec).toBe(2);
+ expect(info.fullResetMs).toBe(2000);
+ });
+
+ it('should increment counter when called', async () => {
+ await serviceUnderTest().limit(limit, actor);
+
+ expect(counter).not.toBeUndefined();
+ expect(counter?.c).toBe(1);
+ });
+
+ it('should set timestamp when called', async () => {
+ mockTimeService.now = 1000;
+
+ await serviceUnderTest().limit(limit, actor);
+
+ expect(counter).not.toBeUndefined();
+ expect(counter?.t).toBe(1000);
+ });
+
+ it('should decrement counter when dripRate has passed', async () => {
+ counter = { c: 2, t: 0 };
+ mockTimeService.now = 2000;
+
+ await serviceUnderTest().limit(limit, actor);
+
+ expect(counter).not.toBeUndefined();
+ expect(counter?.c).toBe(1); // 2 (starting) - 2 (2x1 drip) + 1 (call) = 1
+ });
+
+ it('should decrement counter by dripSize', async () => {
+ counter = { c: 2, t: 0 };
+ limit.dripSize = 2;
+ mockTimeService.now = 1000;
+
+ await serviceUnderTest().limit(limit, actor);
+
+ expect(counter).not.toBeUndefined();
+ expect(counter?.c).toBe(1); // 2 (starting) - 2 (1x2 drip) + 1 (call) = 1
+ });
+
+ it('should maintain counter between calls over time', async () => {
+ limit.size = 5;
+
+ await serviceUnderTest().limit(limit, actor); // 0 + 1 = 1
+ mockTimeService.now += 1000; // 1 - 1 = 0
+ await serviceUnderTest().limit(limit, actor); // 0 + 1 = 1
+ await serviceUnderTest().limit(limit, actor); // 1 + 1 = 2
+ await serviceUnderTest().limit(limit, actor); // 2 + 1 = 3
+ mockTimeService.now += 1000; // 3 - 1 = 2
+ mockTimeService.now += 1000; // 2 - 1 = 1
+ await serviceUnderTest().limit(limit, actor); // 1 + 1 = 2
+
+ expect(counter?.c).toBe(2);
+ expect(counter?.t).toBe(3000);
+ });
+
+ it('should log error and continue when update fails', async () => {
+ mockRedisSet = () => {
+ throw new Error('test error');
+ };
+
+ await serviceUnderTest().limit(limit, actor);
+
+ const matchingError = loggedMessages
+ .find(m => m.level === 'error' && m.data
+ .some(d => typeof(d) === 'string' && d.includes('Failed to update limit')));
+ expect(matchingError).toBeTruthy();
+ });
+
+ it('should block when bucket is filled', async () => {
+ counter = { c: 1, t: 0 };
+
+ const info = await serviceUnderTest().limit(limit, actor);
+
+ expect(info.blocked).toBeTruthy();
+ });
+
+ it('should calculate correct info when blocked', async () => {
+ counter = { c: 1, t: 0 };
+
+ const info = await serviceUnderTest().limit(limit, actor);
+
+ expect(info.resetSec).toBe(1);
+ expect(info.resetMs).toBe(1000);
+ expect(info.fullResetSec).toBe(1);
+ expect(info.fullResetMs).toBe(1000);
+ });
+
+ it('should allow when bucket is filled but should drip', async () => {
+ counter = { c: 1, t: 0 };
+ mockTimeService.now = 1000;
+
+ const info = await serviceUnderTest().limit(limit, actor);
+
+ expect(info.blocked).toBeFalsy();
+ });
+
+ it('should scale limit by factor', async () => {
+ counter = { c: 1, t: 0 };
+
+ const i1 = await serviceUnderTest().limit(limit, actor, 2); // 1 + 1 = 2
+ const i2 = await serviceUnderTest().limit(limit, actor, 2); // 2 + 1 = 3
+
+ expect(i1.blocked).toBeFalsy();
+ expect(i2.blocked).toBeTruthy();
+ });
+
+ it('should set key expiration', async () => {
+ mockRedisSet = args => {
+ expect(args[2]).toBe('PX');
+ expect(args[3]).toBe(1000);
+ };
+
+ await serviceUnderTest().limit(limit, actor);
+ });
+
+ it('should not increment when already blocked', async () => {
+ counter = { c: 1, t: 0 };
+ mockTimeService.now += 100;
+
+ await serviceUnderTest().limit(limit, actor);
+
+ expect(counter?.c).toBe(1);
+ expect(counter?.t).toBe(0);
+ });
+ });
+
+ describe('with min interval', () => {
+ let limit: MutableLegacyRateLimit = null!;
+
+ beforeEach(() => {
+ limit = {
+ type: undefined,
+ key,
+ minInterval: 1000,
+ };
+ });
+
+ it('should allow when limit is not reached', async () => {
+ const info = await serviceUnderTest().limit(limit, actor);
+
+ expect(info.blocked).toBeFalsy();
+ });
+
+ it('should not error when allowed', async () => {
+ await serviceUnderTest().limit(limit, actor);
+
+ expectNoUnhandledErrors();
+ });
+
+ it('should calculate correct info when allowed', async () => {
+ const info = await serviceUnderTest().limit(limit, actor);
+
+ expect(info.remaining).toBe(0);
+ expect(info.resetSec).toBe(1);
+ expect(info.resetMs).toBe(1000);
+ expect(info.fullResetSec).toBe(1);
+ expect(info.fullResetMs).toBe(1000);
+ });
+
+ it('should increment counter when called', async () => {
+ await serviceUnderTest().limit(limit, actor);
+
+ expect(minCounter).not.toBeUndefined();
+ expect(minCounter?.c).toBe(1);
+ });
+
+ it('should set timestamp when called', async () => {
+ mockTimeService.now = 1000;
+
+ await serviceUnderTest().limit(limit, actor);
+
+ expect(minCounter).not.toBeUndefined();
+ expect(minCounter?.t).toBe(1000);
+ });
+
+ it('should decrement counter when minInterval has passed', async () => {
+ minCounter = { c: 1, t: 0 };
+ mockTimeService.now = 1000;
+
+ await serviceUnderTest().limit(limit, actor);
+
+ expect(minCounter).not.toBeUndefined();
+ expect(minCounter?.c).toBe(1); // 1 (starting) - 1 (interval) + 1 (call) = 1
+ });
+
+ it('should reset counter entirely', async () => {
+ minCounter = { c: 2, t: 0 };
+ mockTimeService.now = 1000;
+
+ await serviceUnderTest().limit(limit, actor);
+
+ expect(minCounter).not.toBeUndefined();
+ expect(minCounter?.c).toBe(1); // 2 (starting) - 2 (interval) + 1 (call) = 1
+ });
+
+ it('should maintain counter between calls over time', async () => {
+ await serviceUnderTest().limit(limit, actor); // 0 + 1 = 1
+ mockTimeService.now += 1000; // 1 - 1 = 0
+ await serviceUnderTest().limit(limit, actor); // 0 + 1 = 1
+ await serviceUnderTest().limit(limit, actor); // blocked
+ await serviceUnderTest().limit(limit, actor); // blocked
+ mockTimeService.now += 1000; // 1 - 1 = 0
+ mockTimeService.now += 1000; // 0 - 1 = 0
+ await serviceUnderTest().limit(limit, actor); // 0 + 1 = 1
+
+ expect(minCounter?.c).toBe(1);
+ expect(minCounter?.t).toBe(3000);
+ });
+
+ it('should log error and continue when update fails', async () => {
+ mockRedisSet = () => {
+ throw new Error('test error');
+ };
+
+ await serviceUnderTest().limit(limit, actor);
+
+ const matchingError = loggedMessages
+ .find(m => m.level === 'error' && m.data
+ .some(d => typeof(d) === 'string' && d.includes('Failed to update limit')));
+ expect(matchingError).toBeTruthy();
+ });
+
+ it('should block when interval exceeded', async () => {
+ minCounter = { c: 1, t: 0 };
+
+ const info = await serviceUnderTest().limit(limit, actor);
+
+ expect(info.blocked).toBeTruthy();
+ });
+
+ it('should calculate correct info when blocked', async () => {
+ minCounter = { c: 1, t: 0 };
+
+ const info = await serviceUnderTest().limit(limit, actor);
+
+ expect(info.resetSec).toBe(1);
+ expect(info.resetMs).toBe(1000);
+ expect(info.fullResetSec).toBe(1);
+ expect(info.fullResetMs).toBe(1000);
+ });
+
+ it('should allow when bucket is filled but interval has passed', async () => {
+ minCounter = { c: 1, t: 0 };
+ mockTimeService.now = 1000;
+
+ const info = await serviceUnderTest().limit(limit, actor);
+
+ expect(info.blocked).toBeFalsy();
+ });
+
+ it('should scale limit by factor', async () => {
+ minCounter = { c: 1, t: 0 };
+
+ const i1 = await serviceUnderTest().limit(limit, actor, 2); // 1 + 1 = 2
+ const i2 = await serviceUnderTest().limit(limit, actor, 2); // 2 + 1 = 3
+
+ expect(i1.blocked).toBeFalsy();
+ expect(i2.blocked).toBeTruthy();
+ });
+
+ it('should set key expiration', async () => {
+ mockRedisSet = args => {
+ expect(args[2]).toBe('PX');
+ expect(args[3]).toBe(1000);
+ };
+
+ await serviceUnderTest().limit(limit, actor);
+ });
+
+ it('should not increment when already blocked', async () => {
+ minCounter = { c: 1, t: 0 };
+ mockTimeService.now += 100;
+
+ await serviceUnderTest().limit(limit, actor);
+
+ expect(minCounter?.c).toBe(1);
+ expect(minCounter?.t).toBe(0);
+ });
+ });
+
+ describe('with legacy limit', () => {
+ let limit: MutableLegacyRateLimit = null!;
+
+ beforeEach(() => {
+ limit = {
+ type: undefined,
+ key,
+ max: 1,
+ duration: 1000,
+ };
+ });
+
+ it('should allow when limit is not reached', async () => {
+ const info = await serviceUnderTest().limit(limit, actor);
+
+ expect(info.blocked).toBeFalsy();
+ });
+
+ it('should not error when allowed', async () => {
+ await serviceUnderTest().limit(limit, actor);
+
+ expectNoUnhandledErrors();
+ });
+
+ it('should infer dripRate from duration', async () => {
+ limit.max = 10;
+ limit.duration = 10000;
+ counter = { c: 10, t: 0 };
+
+ const i1 = await serviceUnderTest().limit(limit, actor);
+ mockTimeService.now += 1000;
+ const i2 = await serviceUnderTest().limit(limit, actor);
+ mockTimeService.now += 2000;
+ const i3 = await serviceUnderTest().limit(limit, actor);
+ const i4 = await serviceUnderTest().limit(limit, actor);
+ const i5 = await serviceUnderTest().limit(limit, actor);
+ mockTimeService.now += 2000;
+ const i6 = await serviceUnderTest().limit(limit, actor);
+
+ expect(i1.blocked).toBeTruthy();
+ expect(i2.blocked).toBeFalsy();
+ expect(i3.blocked).toBeFalsy();
+ expect(i4.blocked).toBeFalsy();
+ expect(i5.blocked).toBeTruthy();
+ expect(i6.blocked).toBeFalsy();
+ });
+
+ it('should calculate correct info when allowed', async () => {
+ limit.max = 10;
+ limit.duration = 10000;
+ counter = { c: 10, t: 0 };
+ mockTimeService.now += 2000;
+
+ const info = await serviceUnderTest().limit(limit, actor);
+
+ expect(info.remaining).toBe(1);
+ expect(info.resetSec).toBe(0);
+ expect(info.resetMs).toBe(0);
+ expect(info.fullResetSec).toBe(9);
+ expect(info.fullResetMs).toBe(9000);
+ });
+
+ it('should calculate correct info when blocked', async () => {
+ limit.max = 10;
+ limit.duration = 10000;
+ counter = { c: 10, t: 0 };
+
+ const info = await serviceUnderTest().limit(limit, actor);
+
+ expect(info.remaining).toBe(0);
+ expect(info.resetSec).toBe(1);
+ expect(info.resetMs).toBe(1000);
+ expect(info.fullResetSec).toBe(10);
+ expect(info.fullResetMs).toBe(10000);
+ });
+
+ it('should allow when bucket is filled but interval has passed', async () => {
+ counter = { c: 10, t: 0 };
+ mockTimeService.now = 1000;
+
+ const info = await serviceUnderTest().limit(limit, actor);
+
+ expect(info.blocked).toBeTruthy();
+ });
+
+ it('should scale limit by factor', async () => {
+ counter = { c: 10, t: 0 };
+
+ const info = await serviceUnderTest().limit(limit, actor, 2); // 10 + 1 = 11
+
+ expect(info.blocked).toBeTruthy();
+ });
+
+ it('should set key expiration', async () => {
+ mockRedisSet = args => {
+ expect(args[2]).toBe('PX');
+ expect(args[3]).toBe(1000);
+ };
+
+ await serviceUnderTest().limit(limit, actor);
+ });
+
+ it('should not increment when already blocked', async () => {
+ counter = { c: 1, t: 0 };
+ mockTimeService.now += 100;
+
+ await serviceUnderTest().limit(limit, actor);
+
+ expect(counter?.c).toBe(1);
+ expect(counter?.t).toBe(0);
+ });
+ });
+
+ describe('with legacy limit and min interval', () => {
+ let limit: MutableLegacyRateLimit = null!;
+
+ beforeEach(() => {
+ limit = {
+ type: undefined,
+ key,
+ max: 5,
+ duration: 5000,
+ minInterval: 1000,
+ };
+ });
+
+ it('should allow when limit is not reached', async () => {
+ const info = await serviceUnderTest().limit(limit, actor);
+
+ expect(info.blocked).toBeFalsy();
+ });
+
+ it('should not error when allowed', async () => {
+ await serviceUnderTest().limit(limit, actor);
+
+ expectNoUnhandledErrors();
+ });
+
+ it('should block when limit exceeded', async () => {
+ counter = { c: 5, t: 0 };
+
+ const info = await serviceUnderTest().limit(limit, actor);
+
+ expect(info.blocked).toBeTruthy();
+ });
+
+ it('should block when minInterval exceeded', async () => {
+ minCounter = { c: 1, t: 0 };
+
+ const info = await serviceUnderTest().limit(limit, actor);
+
+ expect(info.blocked).toBeTruthy();
+ });
+
+ it('should calculate correct info when allowed', async () => {
+ counter = { c: 1, t: 0 };
+
+ const info = await serviceUnderTest().limit(limit, actor);
+
+ expect(info.remaining).toBe(0);
+ expect(info.resetSec).toBe(1);
+ expect(info.resetMs).toBe(1000);
+ expect(info.fullResetSec).toBe(2);
+ expect(info.fullResetMs).toBe(2000);
+ });
+
+ it('should calculate correct info when blocked by limit', async () => {
+ counter = { c: 5, t: 0 };
+
+ const info = await serviceUnderTest().limit(limit, actor);
+
+ expect(info.remaining).toBe(0);
+ expect(info.resetSec).toBe(1);
+ expect(info.resetMs).toBe(1000);
+ expect(info.fullResetSec).toBe(5);
+ expect(info.fullResetMs).toBe(5000);
+ });
+
+ it('should calculate correct info when blocked by minInterval', async () => {
+ minCounter = { c: 1, t: 0 };
+
+ const info = await serviceUnderTest().limit(limit, actor);
+
+ expect(info.remaining).toBe(0);
+ expect(info.resetSec).toBe(1);
+ expect(info.resetMs).toBe(1000);
+ expect(info.fullResetSec).toBe(1);
+ expect(info.fullResetMs).toBe(1000);
+ });
+
+ it('should allow when counter is filled but interval has passed', async () => {
+ counter = { c: 5, t: 0 };
+ mockTimeService.now = 1000;
+
+ const info = await serviceUnderTest().limit(limit, actor);
+
+ expect(info.blocked).toBeFalsy();
+ });
+
+ it('should allow when minCounter is filled but interval has passed', async () => {
+ minCounter = { c: 1, t: 0 };
+ mockTimeService.now = 1000;
+
+ const info = await serviceUnderTest().limit(limit, actor);
+
+ expect(info.blocked).toBeFalsy();
+ });
+
+ it('should scale limit by factor', async () => {
+ minCounter = { c: 5, t: 0 };
+
+ const info = await serviceUnderTest().limit(limit, actor, 2);
+
+ expect(info.blocked).toBeTruthy();
+ });
+
+ it('should set key expiration', async () => {
+ mockRedisSet = args => {
+ expect(args[2]).toBe('PX');
+ expect(args[3]).toBe(1000);
+ };
+
+ await serviceUnderTest().limit(limit, actor);
+ });
+
+ it('should not increment when already blocked', async () => {
+ counter = { c: 5, t: 0 };
+ minCounter = { c: 1, t: 0 };
+ mockTimeService.now += 100;
+
+ await serviceUnderTest().limit(limit, actor);
+
+ expect(counter?.c).toBe(5);
+ expect(counter?.t).toBe(0);
+ expect(minCounter?.c).toBe(1);
+ expect(minCounter?.t).toBe(0);
+ });
+ });
+ });
+});
+
+// The same thing, but mutable
+interface MutableLegacyRateLimit extends LegacyRateLimit {
+ key: string;
+ duration?: number;
+ max?: number;
+ minInterval?: number;
+}