diff options
| author | Hazelnoot <acomputerdog@gmail.com> | 2024-12-10 19:01:35 -0500 |
|---|---|---|
| committer | Hazelnoot <acomputerdog@gmail.com> | 2024-12-10 19:01:35 -0500 |
| commit | 407b2423af31ecaf44035f66a180a0bbc40e3aaa (patch) | |
| tree | e93a48eee9dfb8d3b5237d4279f1f97573e5cee9 /packages/backend/src/server/api | |
| parent | enable rate limits for dev environment (diff) | |
| download | sharkey-407b2423af31ecaf44035f66a180a0bbc40e3aaa.tar.gz sharkey-407b2423af31ecaf44035f66a180a0bbc40e3aaa.tar.bz2 sharkey-407b2423af31ecaf44035f66a180a0bbc40e3aaa.zip | |
fix redis transaction implementation
Diffstat (limited to 'packages/backend/src/server/api')
| -rw-r--r-- | packages/backend/src/server/api/SkRateLimiterService.ts | 216 |
1 files changed, 101 insertions, 115 deletions
diff --git a/packages/backend/src/server/api/SkRateLimiterService.ts b/packages/backend/src/server/api/SkRateLimiterService.ts index b11d1556ba..71681aadc9 100644 --- a/packages/backend/src/server/api/SkRateLimiterService.ts +++ b/packages/backend/src/server/api/SkRateLimiterService.ts @@ -7,8 +7,9 @@ import { Inject, Injectable } from '@nestjs/common'; import Redis from 'ioredis'; import { TimeService } from '@/core/TimeService.js'; import { EnvService } from '@/core/EnvService.js'; -import { DI } from '@/di-symbols.js'; -import { BucketRateLimit, LegacyRateLimit, LimitInfo, RateLimit, hasMinLimit, isLegacyRateLimit, Keyed } from '@/misc/rate-limit-utils.js'; +import { BucketRateLimit, LegacyRateLimit, LimitInfo, RateLimit, hasMinLimit, isLegacyRateLimit, Keyed, hasMaxLimit, disabledLimitInfo, MaxLegacyLimit, MinLegacyLimit } from '@/misc/rate-limit-utils.js'; +import { RedisConnectionPool } from '@/core/RedisConnectionPool.js'; +import { TimeoutService } from '@/core/TimeoutService.js'; @Injectable() export class SkRateLimiterService { @@ -18,8 +19,11 @@ export class SkRateLimiterService { @Inject(TimeService) private readonly timeService: TimeService, - @Inject(DI.redis) - private readonly redisClient: Redis.Redis, + @Inject(TimeoutService) + private readonly timeoutService: TimeoutService, + + @Inject(RedisConnectionPool) + private readonly redisPool: RedisConnectionPool, @Inject(EnvService) envService: EnvService, @@ -29,117 +33,110 @@ export class SkRateLimiterService { public async limit(limit: Keyed<RateLimit>, actor: string, factor = 1): Promise<LimitInfo> { if (this.disabled || factor === 0) { - return { - blocked: false, - remaining: Number.MAX_SAFE_INTEGER, - resetSec: 0, - resetMs: 0, - fullResetSec: 0, - fullResetMs: 0, - }; + return disabledLimitInfo; } if (factor < 0) { throw new Error(`Rate limit factor is zero or negative: ${factor}`); } - return await this.tryLimit(limit, actor, factor); + const redis = await this.redisPool.alloc(); + try { + return await this.tryLimit(redis, limit, actor, factor); + } finally { + await this.redisPool.free(redis); + } } - private async tryLimit(limit: Keyed<RateLimit>, actor: string, factor: number, retry = 1): Promise<LimitInfo> { + private async tryLimit(redis: Redis.Redis, limit: Keyed<RateLimit>, actor: string, factor: number, retry = 0): Promise<LimitInfo> { try { + if (retry > 0) { + // Real-world testing showed the need for backoff to "spread out" bursty traffic. + const backoff = Math.round(Math.pow(2, retry + Math.random())); + await this.timeoutService.delay(backoff); + } + if (isLegacyRateLimit(limit)) { - return await this.limitLegacy(limit, actor, factor); + return await this.limitLegacy(redis, limit, actor, factor); } else { - return await this.limitBucket(limit, actor, factor); + return await this.limitBucket(redis, limit, actor, factor); } } catch (err) { // We may experience collision errors from optimistic locking. // This is expected, so we should retry a few times before giving up. // https://redis.io/docs/latest/develop/interact/transactions/#optimistic-locking-using-check-and-set - if (err instanceof TransactionError && retry < 3) { - return await this.tryLimit(limit, actor, factor, retry + 1); + if (err instanceof ConflictError && retry < 4) { + // We can reuse the same connection to reduce pool contention, but we have to reset it first. + await redis.reset(); + return await this.tryLimit(redis, limit, actor, factor, retry + 1); } throw err; } } - private async limitLegacy(limit: Keyed<LegacyRateLimit>, actor: string, factor: number): Promise<LimitInfo> { - const promises: Promise<LimitInfo | null>[] = []; - - // The "min" limit - if present - is handled directly. - if (hasMinLimit(limit)) { - promises.push( - this.limitMin(limit, actor, factor), - ); + private async limitLegacy(redis: Redis.Redis, limit: Keyed<LegacyRateLimit>, actor: string, factor: number): Promise<LimitInfo> { + if (hasMaxLimit(limit)) { + return await this.limitMaxLegacy(redis, limit, actor, factor); + } else if (hasMinLimit(limit)) { + return await this.limitMinLegacy(redis, limit, actor, factor); + } else { + return disabledLimitInfo; } + } - // Convert the "max" limit into a leaky bucket with 1 drip / second rate. - if (limit.max != null && limit.duration != null) { - promises.push( - this.limitBucket({ - type: 'bucket', - key: limit.key, - size: limit.max, - dripRate: Math.max(Math.round(limit.duration / limit.max), 1), - }, actor, factor), - ); - } + private async limitMaxLegacy(redis: Redis.Redis, limit: Keyed<MaxLegacyLimit>, actor: string, factor: number): Promise<LimitInfo> { + if (limit.duration === 0) return disabledLimitInfo; + if (limit.duration < 0) throw new Error(`Invalid rate limit ${limit.key}: duration is negative (${limit.duration})`); + if (limit.max < 1) throw new Error(`Invalid rate limit ${limit.key}: max is less than 1 (${limit.max})`); - const [lim1, lim2] = await Promise.all(promises); - return { - blocked: (lim1?.blocked || lim2?.blocked) ?? false, - remaining: Math.min(lim1?.remaining ?? Number.MAX_SAFE_INTEGER, lim2?.remaining ?? Number.MAX_SAFE_INTEGER), - resetSec: Math.max(lim1?.resetSec ?? 0, lim2?.resetSec ?? 0), - resetMs: Math.max(lim1?.resetMs ?? 0, lim2?.resetMs ?? 0), - fullResetSec: Math.max(lim1?.fullResetSec ?? 0, lim2?.fullResetSec ?? 0), - fullResetMs: Math.max(lim1?.fullResetMs ?? 0, lim2?.fullResetMs ?? 0), - }; - } + // Derive initial dripRate from minInterval OR duration/max. + const initialDripRate = Math.max(limit.minInterval ?? Math.round(limit.duration / limit.max), 1); - private async limitMin(limit: Keyed<LegacyRateLimit> & { minInterval: number }, actor: string, factor: number): Promise<LimitInfo | null> { - if (limit.minInterval === 0) return null; - if (limit.minInterval < 0) throw new Error(`Invalid rate limit ${limit.key}: minInterval is negative (${limit.minInterval})`); + // Calculate dripSize to reach max at exactly duration + const dripSize = Math.max(Math.round(limit.max / (limit.duration / initialDripRate)), 1); - const minInterval = Math.max(Math.ceil(limit.minInterval * factor), 0); - const expirationSec = Math.max(Math.ceil(minInterval / 1000), 1); + // Calculate final dripRate from dripSize and duration/max + const dripRate = Math.max(Math.round(limit.duration / (limit.max / dripSize)), 1); - // Check for window clear - const counter = await this.getLimitCounter(limit, actor, 'min'); - if (counter.counter > 0) { - const isCleared = this.timeService.now - counter.timestamp >= minInterval; - if (isCleared) { - counter.counter = 0; - } - } + const bucketLimit: Keyed<BucketRateLimit> = { + type: 'bucket', + key: limit.key, + size: limit.max, + dripRate, + dripSize, + }; + return await this.limitBucket(redis, bucketLimit, actor, factor); + } - // Increment the limit, then synchronize with redis - const blocked = counter.counter > 0; - if (!blocked) { - counter.counter++; - counter.timestamp = this.timeService.now; - await this.updateLimitCounter(limit, actor, 'min', expirationSec, counter); - } + private async limitMinLegacy(redis: Redis.Redis, limit: Keyed<MinLegacyLimit>, actor: string, factor: number): Promise<LimitInfo> { + if (limit.minInterval === 0) return disabledLimitInfo; + if (limit.minInterval < 0) throw new Error(`Invalid rate limit ${limit.key}: minInterval is negative (${limit.minInterval})`); - // Calculate limit status - const resetMs = Math.max(minInterval - (this.timeService.now - counter.timestamp), 0); - const resetSec = Math.ceil(resetMs / 1000); - return { blocked, remaining: 0, resetSec, resetMs, fullResetSec: resetSec, fullResetMs: resetMs }; + const dripRate = Math.max(Math.round(limit.minInterval), 1); + const bucketLimit: Keyed<BucketRateLimit> = { + type: 'bucket', + key: limit.key, + size: 1, + dripRate, + dripSize: 1, + }; + return await this.limitBucket(redis, bucketLimit, actor, factor); } - private async limitBucket(limit: Keyed<BucketRateLimit>, actor: string, factor: number): Promise<LimitInfo> { + private async limitBucket(redis: Redis.Redis, limit: Keyed<BucketRateLimit>, actor: string, factor: number): Promise<LimitInfo> { if (limit.size < 1) throw new Error(`Invalid rate limit ${limit.key}: size is less than 1 (${limit.size})`); if (limit.dripRate != null && limit.dripRate < 1) throw new Error(`Invalid rate limit ${limit.key}: dripRate is less than 1 (${limit.dripRate})`); if (limit.dripSize != null && limit.dripSize < 1) throw new Error(`Invalid rate limit ${limit.key}: dripSize is less than 1 (${limit.dripSize})`); + const redisKey = createLimitKey(limit, actor); const bucketSize = Math.max(Math.ceil(limit.size / factor), 1); const dripRate = Math.ceil(limit.dripRate ?? 1000); const dripSize = Math.ceil(limit.dripSize ?? 1); const expirationSec = Math.max(Math.ceil(bucketSize / dripRate), 1); // Simulate bucket drips - const counter = await this.getLimitCounter(limit, actor, 'bucket'); + const counter = await this.getLimitCounter(redis, redisKey); if (counter.counter > 0) { const dripsSinceLastTick = Math.floor((this.timeService.now - counter.timestamp) / dripRate) * dripSize; counter.counter = Math.max(counter.counter - dripsSinceLastTick, 0); @@ -150,7 +147,7 @@ export class SkRateLimiterService { if (!blocked) { counter.counter++; counter.timestamp = this.timeService.now; - await this.updateLimitCounter(limit, actor, 'bucket', expirationSec, counter); + await this.updateLimitCounter(redis, redisKey, expirationSec, counter); } // Calculate how much time is needed to free up a bucket slot @@ -167,60 +164,49 @@ export class SkRateLimiterService { return { blocked, remaining, resetSec, resetMs, fullResetSec, fullResetMs }; } - private async getLimitCounter(limit: Keyed<RateLimit>, actor: string, subject: string): Promise<LimitCounter> { - const timestampKey = createLimitKey(limit, actor, subject, 't'); - const counterKey = createLimitKey(limit, actor, subject, 'c'); + private async getLimitCounter(redis: Redis.Redis, key: string): Promise<LimitCounter> { + const counter: LimitCounter = { counter: 0, timestamp: 0 }; - const [timestamp, counter] = await this.executeRedis( - [ - ['get', timestampKey], - ['get', counterKey], - ], - [ - timestampKey, - counterKey, - ], - ); + // Watch the key BEFORE reading it! + await redis.watch(key); + const data = await redis.get(key); - return { - timestamp: timestamp ? parseInt(timestamp) : 0, - counter: counter ? parseInt(counter) : 0, - }; + // Data may be missing or corrupt if the key doesn't exist. + // This is an expected edge case. + if (data) { + const parts = data.split(':'); + if (parts.length === 2) { + counter.counter = parseInt(parts[0]); + counter.timestamp = parseInt(parts[1]); + } + } + + return counter; } - private async updateLimitCounter(limit: Keyed<RateLimit>, actor: string, subject: string, expirationSec: number, counter: LimitCounter): Promise<void> { - const timestampKey = createLimitKey(limit, actor, subject, 't'); - const counterKey = createLimitKey(limit, actor, subject, 'c'); + private async updateLimitCounter(redis: Redis.Redis, key: string, expirationSec: number, counter: LimitCounter): Promise<void> { + const data = `${counter.counter}:${counter.timestamp}`; - await this.executeRedis( - [ - ['set', timestampKey, counter.timestamp.toString(), 'EX', expirationSec], - ['set', counterKey, counter.counter.toString(), 'EX', expirationSec], - ], - [ - timestampKey, - counterKey, - ], + await this.executeRedisMulti( + redis, + [['set', key, data, 'EX', expirationSec]], ); } - private async executeRedis<Num extends number>(batch: RedisBatch<Num>, watch: string[]): Promise<RedisResults<Num>> { - const results = await this.redisClient - .multi(batch) - .watch(watch) - .exec(); + private async executeRedisMulti<Num extends number>(redis: Redis.Redis, batch: RedisBatch<Num>): Promise<RedisResults<Num>> { + const results = await redis.multi(batch).exec(); - // Transaction error + // Transaction conflict (retryable) if (!results) { - throw new TransactionError('Redis error: transaction conflict'); + throw new ConflictError('Redis error: transaction conflict'); } - // The entire call failed + // Transaction failed (fatal) if (results.length !== batch.length) { throw new Error('Redis error: failed to execute batch'); } - // A particular command failed + // Command failed (fatal) const errors = results.map(r => r[0]).filter(e => e != null); if (errors.length > 0) { throw new AggregateError(errors, `Redis error: failed to execute command(s): '${errors.join('\', \'')}'`); @@ -233,11 +219,11 @@ export class SkRateLimiterService { type RedisBatch<Num extends number> = [string, ...unknown[]][] & { length: Num }; type RedisResults<Num extends number> = (string | null)[] & { length: Num }; -function createLimitKey(limit: Keyed<RateLimit>, actor: string, subject: string, value: string): string { - return `rl_${actor}_${limit.key}_${subject}_${value}`; +function createLimitKey(limit: Keyed<RateLimit>, actor: string): string { + return `rl_${actor}_${limit.key}`; } -class TransactionError extends Error {} +class ConflictError extends Error {} interface LimitCounter { timestamp: number; |