From 0ea9d6ec5d4f037b37a98603f8942404530f2802 Mon Sep 17 00:00:00 2001 From: Hazelnoot Date: Wed, 11 Dec 2024 09:10:11 -0500 Subject: use atomic variant of Leaky Bucket for safe concurrent rate limits --- .../backend/src/server/api/SkRateLimiterService.ts | 206 ++++++++++++--------- 1 file changed, 114 insertions(+), 92 deletions(-) (limited to 'packages/backend/src/server/api') diff --git a/packages/backend/src/server/api/SkRateLimiterService.ts b/packages/backend/src/server/api/SkRateLimiterService.ts index 71681aadc9..d349e192e1 100644 --- a/packages/backend/src/server/api/SkRateLimiterService.ts +++ b/packages/backend/src/server/api/SkRateLimiterService.ts @@ -8,8 +8,7 @@ import Redis from 'ioredis'; import { TimeService } from '@/core/TimeService.js'; import { EnvService } from '@/core/EnvService.js'; import { BucketRateLimit, LegacyRateLimit, LimitInfo, RateLimit, hasMinLimit, isLegacyRateLimit, Keyed, hasMaxLimit, disabledLimitInfo, MaxLegacyLimit, MinLegacyLimit } from '@/misc/rate-limit-utils.js'; -import { RedisConnectionPool } from '@/core/RedisConnectionPool.js'; -import { TimeoutService } from '@/core/TimeoutService.js'; +import { DI } from '@/di-symbols.js'; @Injectable() export class SkRateLimiterService { @@ -19,11 +18,8 @@ export class SkRateLimiterService { @Inject(TimeService) private readonly timeService: TimeService, - @Inject(TimeoutService) - private readonly timeoutService: TimeoutService, - - @Inject(RedisConnectionPool) - private readonly redisPool: RedisConnectionPool, + @Inject(DI.redis) + private readonly redisClient: Redis.Redis, @Inject(EnvService) envService: EnvService, @@ -31,6 +27,12 @@ export class SkRateLimiterService { this.disabled = envService.env.NODE_ENV === 'test'; } + /** + * Check & increment a rate limit + * @param limit The limit definition + * @param actor Client who is calling this limit + * @param factor Scaling factor - smaller = larger limit (less restrictive) + */ public async limit(limit: Keyed, actor: string, factor = 1): Promise { if (this.disabled || factor === 0) { return disabledLimitInfo; @@ -40,52 +42,28 @@ export class SkRateLimiterService { throw new Error(`Rate limit factor is zero or negative: ${factor}`); } - const redis = await this.redisPool.alloc(); - try { - return await this.tryLimit(redis, limit, actor, factor); - } finally { - await this.redisPool.free(redis); - } + return await this.tryLimit(limit, actor, factor); } - private async tryLimit(redis: Redis.Redis, limit: Keyed, actor: string, factor: number, retry = 0): Promise { - try { - if (retry > 0) { - // Real-world testing showed the need for backoff to "spread out" bursty traffic. - const backoff = Math.round(Math.pow(2, retry + Math.random())); - await this.timeoutService.delay(backoff); - } - - if (isLegacyRateLimit(limit)) { - return await this.limitLegacy(redis, limit, actor, factor); - } else { - return await this.limitBucket(redis, limit, actor, factor); - } - } catch (err) { - // We may experience collision errors from optimistic locking. - // This is expected, so we should retry a few times before giving up. - // https://redis.io/docs/latest/develop/interact/transactions/#optimistic-locking-using-check-and-set - if (err instanceof ConflictError && retry < 4) { - // We can reuse the same connection to reduce pool contention, but we have to reset it first. - await redis.reset(); - return await this.tryLimit(redis, limit, actor, factor, retry + 1); - } - - throw err; + private async tryLimit(limit: Keyed, actor: string, factor: number): Promise { + if (isLegacyRateLimit(limit)) { + return await this.limitLegacy(limit, actor, factor); + } else { + return await this.limitBucket(limit, actor, factor); } } - private async limitLegacy(redis: Redis.Redis, limit: Keyed, actor: string, factor: number): Promise { + private async limitLegacy(limit: Keyed, actor: string, factor: number): Promise { if (hasMaxLimit(limit)) { - return await this.limitMaxLegacy(redis, limit, actor, factor); + return await this.limitMaxLegacy(limit, actor, factor); } else if (hasMinLimit(limit)) { - return await this.limitMinLegacy(redis, limit, actor, factor); + return await this.limitMinLegacy(limit, actor, factor); } else { return disabledLimitInfo; } } - private async limitMaxLegacy(redis: Redis.Redis, limit: Keyed, actor: string, factor: number): Promise { + private async limitMaxLegacy(limit: Keyed, actor: string, factor: number): Promise { if (limit.duration === 0) return disabledLimitInfo; if (limit.duration < 0) throw new Error(`Invalid rate limit ${limit.key}: duration is negative (${limit.duration})`); if (limit.max < 1) throw new Error(`Invalid rate limit ${limit.key}: max is less than 1 (${limit.max})`); @@ -106,10 +84,10 @@ export class SkRateLimiterService { dripRate, dripSize, }; - return await this.limitBucket(redis, bucketLimit, actor, factor); + return await this.limitBucket(bucketLimit, actor, factor); } - private async limitMinLegacy(redis: Redis.Redis, limit: Keyed, actor: string, factor: number): Promise { + private async limitMinLegacy(limit: Keyed, actor: string, factor: number): Promise { if (limit.minInterval === 0) return disabledLimitInfo; if (limit.minInterval < 0) throw new Error(`Invalid rate limit ${limit.key}: minInterval is negative (${limit.minInterval})`); @@ -121,33 +99,83 @@ export class SkRateLimiterService { dripRate, dripSize: 1, }; - return await this.limitBucket(redis, bucketLimit, actor, factor); + return await this.limitBucket(bucketLimit, actor, factor); } - private async limitBucket(redis: Redis.Redis, limit: Keyed, actor: string, factor: number): Promise { + /** + * Implementation of Leaky Bucket rate limiting - see SkRateLimiterService.md for details. + */ + private async limitBucket(limit: Keyed, actor: string, factor: number): Promise { if (limit.size < 1) throw new Error(`Invalid rate limit ${limit.key}: size is less than 1 (${limit.size})`); if (limit.dripRate != null && limit.dripRate < 1) throw new Error(`Invalid rate limit ${limit.key}: dripRate is less than 1 (${limit.dripRate})`); if (limit.dripSize != null && limit.dripSize < 1) throw new Error(`Invalid rate limit ${limit.key}: dripSize is less than 1 (${limit.dripSize})`); - const redisKey = createLimitKey(limit, actor); + // 0 - Calculate + const now = this.timeService.now; const bucketSize = Math.max(Math.ceil(limit.size / factor), 1); const dripRate = Math.ceil(limit.dripRate ?? 1000); const dripSize = Math.ceil(limit.dripSize ?? 1); - const expirationSec = Math.max(Math.ceil(bucketSize / dripRate), 1); - - // Simulate bucket drips - const counter = await this.getLimitCounter(redis, redisKey); - if (counter.counter > 0) { - const dripsSinceLastTick = Math.floor((this.timeService.now - counter.timestamp) / dripRate) * dripSize; - counter.counter = Math.max(counter.counter - dripsSinceLastTick, 0); + const expirationSec = Math.max(Math.ceil((dripRate * Math.ceil(bucketSize / dripSize)) / 1000), 1); + + // 1 - Read + const counterKey = createLimitKey(limit, actor, 'c'); + const timestampKey = createLimitKey(limit, actor, 't'); + const counter = await this.getLimitCounter(counterKey, timestampKey); + + // 2 - Drip + const dripsSinceLastTick = Math.floor((now - counter.timestamp) / dripRate) * dripSize; + const deltaCounter = Math.min(dripsSinceLastTick, counter.counter); + const deltaTimestamp = dripsSinceLastTick * dripRate; + if (deltaCounter > 0) { + // Execute the next drip(s) + const results = await this.executeRedisMulti( + ['get', timestampKey], + ['incrby', timestampKey, deltaTimestamp], + ['expire', timestampKey, expirationSec], + ['get', timestampKey], + ['decrby', counterKey, deltaCounter], + ['expire', counterKey, expirationSec], + ['get', counterKey], + ); + const expectedTimestamp = counter.timestamp; + const canaryTimestamp = results[0] ? parseInt(results[0]) : 0; + counter.timestamp = results[3] ? parseInt(results[3]) : 0; + counter.counter = results[6] ? parseInt(results[6]) : 0; + + // Check for a data collision and rollback + if (canaryTimestamp !== expectedTimestamp) { + const rollbackResults = await this.executeRedisMulti( + ['decrby', timestampKey, deltaTimestamp], + ['get', timestampKey], + ['incrby', counterKey, deltaCounter], + ['get', counterKey], + ); + counter.timestamp = rollbackResults[1] ? parseInt(rollbackResults[1]) : 0; + counter.counter = rollbackResults[3] ? parseInt(rollbackResults[3]) : 0; + } } - // Increment the limit, then synchronize with redis + // 3 - Check const blocked = counter.counter >= bucketSize; if (!blocked) { - counter.counter++; - counter.timestamp = this.timeService.now; - await this.updateLimitCounter(redis, redisKey, expirationSec, counter); + if (counter.timestamp === 0) { + const results = await this.executeRedisMulti( + ['set', timestampKey, now], + ['expire', timestampKey, expirationSec], + ['incr', counterKey], + ['expire', counterKey, expirationSec], + ['get', counterKey], + ); + counter.timestamp = now; + counter.counter = results[4] ? parseInt(results[4]) : 0; + } else { + const results = await this.executeRedisMulti( + ['incr', counterKey], + ['expire', counterKey, expirationSec], + ['get', counterKey], + ); + counter.counter = results[2] ? parseInt(results[2]) : 0; + } } // Calculate how much time is needed to free up a bucket slot @@ -164,37 +192,20 @@ export class SkRateLimiterService { return { blocked, remaining, resetSec, resetMs, fullResetSec, fullResetMs }; } - private async getLimitCounter(redis: Redis.Redis, key: string): Promise { - const counter: LimitCounter = { counter: 0, timestamp: 0 }; - - // Watch the key BEFORE reading it! - await redis.watch(key); - const data = await redis.get(key); - - // Data may be missing or corrupt if the key doesn't exist. - // This is an expected edge case. - if (data) { - const parts = data.split(':'); - if (parts.length === 2) { - counter.counter = parseInt(parts[0]); - counter.timestamp = parseInt(parts[1]); - } - } - - return counter; - } - - private async updateLimitCounter(redis: Redis.Redis, key: string, expirationSec: number, counter: LimitCounter): Promise { - const data = `${counter.counter}:${counter.timestamp}`; - - await this.executeRedisMulti( - redis, - [['set', key, data, 'EX', expirationSec]], + private async getLimitCounter(counterKey: string, timestampKey: string): Promise { + const [counter, timestamp] = await this.executeRedisMulti( + ['get', counterKey], + ['get', timestampKey], ); + + return { + counter: counter ? parseInt(counter) : 0, + timestamp: timestamp ? parseInt(timestamp) : 0, + }; } - private async executeRedisMulti(redis: Redis.Redis, batch: RedisBatch): Promise> { - const results = await redis.multi(batch).exec(); + private async executeRedisMulti(...batch: RedisCommand[]): Promise { + const results = await this.redisClient.multi(batch).exec(); // Transaction conflict (retryable) if (!results) { @@ -206,21 +217,32 @@ export class SkRateLimiterService { throw new Error('Redis error: failed to execute batch'); } + // Map responses + const errors: Error[] = []; + const responses: RedisResult[] = []; + for (const [error, response] of results) { + if (error) errors.push(error); + responses.push(response as RedisResult); + } + // Command failed (fatal) - const errors = results.map(r => r[0]).filter(e => e != null); if (errors.length > 0) { - throw new AggregateError(errors, `Redis error: failed to execute command(s): '${errors.join('\', \'')}'`); + const errorMessages = errors + .map((e, i) => `Error in command ${i}: ${e}`) + .join('\', \''); + throw new AggregateError(errors, `Redis error: failed to execute command(s): '${errorMessages}'`); } - return results.map(r => r[1]) as RedisResults; + return responses; } } -type RedisBatch = [string, ...unknown[]][] & { length: Num }; -type RedisResults = (string | null)[] & { length: Num }; +// Not correct, but good enough for the basic commands we use. +type RedisResult = string | null; +type RedisCommand = [command: string, ...args: unknown[]]; -function createLimitKey(limit: Keyed, actor: string): string { - return `rl_${actor}_${limit.key}`; +function createLimitKey(limit: Keyed, actor: string, value: string): string { + return `rl_${actor}_${limit.key}_${value}`; } class ConflictError extends Error {} -- cgit v1.2.3-freya