From f92fb3bb8c70b5ca35f4f61a437b101bd5bae7e8 Mon Sep 17 00:00:00 2001 From: Hazelnoot Date: Wed, 5 Feb 2025 11:16:39 -0500 Subject: move SkRateLimiterService to correct directory --- .../backend/src/server/SkRateLimiterService.ts | 280 +++++++++++++++++++++ 1 file changed, 280 insertions(+) create mode 100644 packages/backend/src/server/SkRateLimiterService.ts (limited to 'packages/backend/src/server/SkRateLimiterService.ts') diff --git a/packages/backend/src/server/SkRateLimiterService.ts b/packages/backend/src/server/SkRateLimiterService.ts new file mode 100644 index 0000000000..038f12cb25 --- /dev/null +++ b/packages/backend/src/server/SkRateLimiterService.ts @@ -0,0 +1,280 @@ +/* + * SPDX-FileCopyrightText: hazelnoot and other Sharkey contributors + * SPDX-License-Identifier: AGPL-3.0-only + */ + +import { Inject, Injectable } from '@nestjs/common'; +import Redis from 'ioredis'; +import type { TimeService } from '@/core/TimeService.js'; +import type { EnvService } from '@/core/EnvService.js'; +import { BucketRateLimit, LegacyRateLimit, LimitInfo, RateLimit, hasMinLimit, isLegacyRateLimit, Keyed, hasMaxLimit, disabledLimitInfo, MaxLegacyLimit, MinLegacyLimit } from '@/misc/rate-limit-utils.js'; +import { DI } from '@/di-symbols.js'; +import { MemoryKVCache } from '@/misc/cache.js'; +import type { MiUser } from '@/models/_.js'; +import type { RoleService } from '@/core/RoleService.js'; + +// Sentinel value used for caching the default role template. +// Required because MemoryKVCache doesn't support null keys. +const defaultUserKey = ''; + +@Injectable() +export class SkRateLimiterService { + // 1-minute cache interval + private readonly factorCache = new MemoryKVCache(1000 * 60); + private readonly disabled: boolean; + + constructor( + @Inject('TimeService') + private readonly timeService: TimeService, + + @Inject(DI.redis) + private readonly redisClient: Redis.Redis, + + @Inject('RoleService') + private readonly roleService: RoleService, + + @Inject('EnvService') + envService: EnvService, + ) { + this.disabled = envService.env.NODE_ENV === 'test'; + } + + /** + * Check & increment a rate limit for a client. + * + * If the client (actorOrUser) is passed as a string, then it uses the default rate limit factor from the role template. + * If the client (actorOrUser) is passed as an MiUser, then it queries the user's actual rate limit factor from their assigned roles. + * + * A factor of zero (0) will disable the limit, while any negative number will produce an error. + * A factor between zero (0) and one (1) will increase the limit from its default values (allowing more actions per time interval). + * A factor greater than one (1) will decrease the limit from its default values (allowing fewer actions per time interval). + * + * @param limit The limit definition + * @param actorOrUser authenticated client user or IP hash + */ + public async limit(limit: Keyed, actorOrUser: string | MiUser): Promise { + if (this.disabled) { + return disabledLimitInfo; + } + + const actor = typeof(actorOrUser) === 'object' ? actorOrUser.id : actorOrUser; + const userCacheKey = typeof(actorOrUser) === 'object' ? actorOrUser.id : defaultUserKey; + const userRoleKey = typeof(actorOrUser) === 'object' ? actorOrUser.id : null; + const factor = this.factorCache.get(userCacheKey) ?? await this.factorCache.fetch(userCacheKey, async () => { + const role = await this.roleService.getUserPolicies(userRoleKey); + return role.rateLimitFactor; + }); + + if (factor === 0) { + return disabledLimitInfo; + } + + if (factor < 0) { + throw new Error(`Rate limit factor is zero or negative: ${factor}`); + } + + if (isLegacyRateLimit(limit)) { + return await this.limitLegacy(limit, actor, factor); + } else { + return await this.limitBucket(limit, actor, factor); + } + } + + private async limitLegacy(limit: Keyed, actor: string, factor: number): Promise { + if (hasMaxLimit(limit)) { + return await this.limitLegacyMinMax(limit, actor, factor); + } else if (hasMinLimit(limit)) { + return await this.limitLegacyMinOnly(limit, actor, factor); + } else { + return disabledLimitInfo; + } + } + + private async limitLegacyMinMax(limit: Keyed, actor: string, factor: number): Promise { + if (limit.duration === 0) return disabledLimitInfo; + if (limit.duration < 0) throw new Error(`Invalid rate limit ${limit.key}: duration is negative (${limit.duration})`); + if (limit.max < 1) throw new Error(`Invalid rate limit ${limit.key}: max is less than 1 (${limit.max})`); + + // Derive initial dripRate from minInterval OR duration/max. + const initialDripRate = Math.max(limit.minInterval ?? Math.round(limit.duration / limit.max), 1); + + // Calculate dripSize to reach max at exactly duration + const dripSize = Math.max(Math.round(limit.max / (limit.duration / initialDripRate)), 1); + + // Calculate final dripRate from dripSize and duration/max + const dripRate = Math.max(Math.round(limit.duration / (limit.max / dripSize)), 1); + + const bucketLimit: Keyed = { + type: 'bucket', + key: limit.key, + size: limit.max, + dripRate, + dripSize, + }; + return await this.limitBucket(bucketLimit, actor, factor); + } + + private async limitLegacyMinOnly(limit: Keyed, actor: string, factor: number): Promise { + if (limit.minInterval === 0) return disabledLimitInfo; + if (limit.minInterval < 0) throw new Error(`Invalid rate limit ${limit.key}: minInterval is negative (${limit.minInterval})`); + + const dripRate = Math.max(Math.round(limit.minInterval), 1); + const bucketLimit: Keyed = { + type: 'bucket', + key: limit.key, + size: 1, + dripRate, + dripSize: 1, + }; + return await this.limitBucket(bucketLimit, actor, factor); + } + + /** + * Implementation of Leaky Bucket rate limiting - see SkRateLimiterService.md for details. + */ + private async limitBucket(limit: Keyed, actor: string, factor: number): Promise { + if (limit.size < 1) throw new Error(`Invalid rate limit ${limit.key}: size is less than 1 (${limit.size})`); + if (limit.dripRate != null && limit.dripRate < 1) throw new Error(`Invalid rate limit ${limit.key}: dripRate is less than 1 (${limit.dripRate})`); + if (limit.dripSize != null && limit.dripSize < 1) throw new Error(`Invalid rate limit ${limit.key}: dripSize is less than 1 (${limit.dripSize})`); + + // 0 - Calculate + const now = this.timeService.now; + const bucketSize = Math.max(Math.ceil(limit.size / factor), 1); + const dripRate = Math.ceil(limit.dripRate ?? 1000); + const dripSize = Math.ceil(limit.dripSize ?? 1); + const expirationSec = Math.max(Math.ceil((dripRate * Math.ceil(bucketSize / dripSize)) / 1000), 1); + + // 1 - Read + const counterKey = createLimitKey(limit, actor, 'c'); + const timestampKey = createLimitKey(limit, actor, 't'); + const counter = await this.getLimitCounter(counterKey, timestampKey); + + // 2 - Drip + const dripsSinceLastTick = Math.floor((now - counter.timestamp) / dripRate) * dripSize; + const deltaCounter = Math.min(dripsSinceLastTick, counter.counter); + const deltaTimestamp = dripsSinceLastTick * dripRate; + if (deltaCounter > 0) { + // Execute the next drip(s) + const results = await this.executeRedisMulti( + ['get', timestampKey], + ['incrby', timestampKey, deltaTimestamp], + ['expire', timestampKey, expirationSec], + ['get', timestampKey], + ['decrby', counterKey, deltaCounter], + ['expire', counterKey, expirationSec], + ['get', counterKey], + ); + const expectedTimestamp = counter.timestamp; + const canaryTimestamp = results[0] ? parseInt(results[0]) : 0; + counter.timestamp = results[3] ? parseInt(results[3]) : 0; + counter.counter = results[6] ? parseInt(results[6]) : 0; + + // Check for a data collision and rollback + if (canaryTimestamp !== expectedTimestamp) { + const rollbackResults = await this.executeRedisMulti( + ['decrby', timestampKey, deltaTimestamp], + ['get', timestampKey], + ['incrby', counterKey, deltaCounter], + ['get', counterKey], + ); + counter.timestamp = rollbackResults[1] ? parseInt(rollbackResults[1]) : 0; + counter.counter = rollbackResults[3] ? parseInt(rollbackResults[3]) : 0; + } + } + + // 3 - Check + const blocked = counter.counter >= bucketSize; + if (!blocked) { + if (counter.timestamp === 0) { + const results = await this.executeRedisMulti( + ['set', timestampKey, now], + ['expire', timestampKey, expirationSec], + ['incr', counterKey], + ['expire', counterKey, expirationSec], + ['get', counterKey], + ); + counter.timestamp = now; + counter.counter = results[4] ? parseInt(results[4]) : 0; + } else { + const results = await this.executeRedisMulti( + ['incr', counterKey], + ['expire', counterKey, expirationSec], + ['get', counterKey], + ); + counter.counter = results[2] ? parseInt(results[2]) : 0; + } + } + + // Calculate how much time is needed to free up a bucket slot + const overflow = Math.max((counter.counter + 1) - bucketSize, 0); + const dripsNeeded = Math.ceil(overflow / dripSize); + const timeNeeded = Math.max((dripRate * dripsNeeded) - (this.timeService.now - counter.timestamp), 0); + + // Calculate limit status + const remaining = Math.max(bucketSize - counter.counter, 0); + const resetMs = timeNeeded; + const resetSec = Math.ceil(resetMs / 1000); + const fullResetMs = Math.ceil(counter.counter / dripSize) * dripRate; + const fullResetSec = Math.ceil(fullResetMs / 1000); + return { blocked, remaining, resetSec, resetMs, fullResetSec, fullResetMs }; + } + + private async getLimitCounter(counterKey: string, timestampKey: string): Promise { + const [counter, timestamp] = await this.executeRedisMulti( + ['get', counterKey], + ['get', timestampKey], + ); + + return { + counter: counter ? parseInt(counter) : 0, + timestamp: timestamp ? parseInt(timestamp) : 0, + }; + } + + private async executeRedisMulti(...batch: RedisCommand[]): Promise { + const results = await this.redisClient.multi(batch).exec(); + + // Transaction conflict (retryable) + if (!results) { + throw new ConflictError('Redis error: transaction conflict'); + } + + // Transaction failed (fatal) + if (results.length !== batch.length) { + throw new Error('Redis error: failed to execute batch'); + } + + // Map responses + const errors: Error[] = []; + const responses: RedisResult[] = []; + for (const [error, response] of results) { + if (error) errors.push(error); + responses.push(response as RedisResult); + } + + // Command failed (fatal) + if (errors.length > 0) { + const errorMessages = errors + .map((e, i) => `Error in command ${i}: ${e}`) + .join('\', \''); + throw new AggregateError(errors, `Redis error: failed to execute command(s): '${errorMessages}'`); + } + + return responses; + } +} + +// Not correct, but good enough for the basic commands we use. +type RedisResult = string | null; +type RedisCommand = [command: string, ...args: unknown[]]; + +function createLimitKey(limit: Keyed, actor: string, value: string): string { + return `rl_${actor}_${limit.key}_${value}`; +} + +class ConflictError extends Error {} + +interface LimitCounter { + timestamp: number; + counter: number; +} -- cgit v1.2.3-freya