diff options
Diffstat (limited to 'packages/backend/src/server/api/stream/Connection.ts')
| -rw-r--r-- | packages/backend/src/server/api/stream/Connection.ts | 59 |
1 files changed, 36 insertions, 23 deletions
diff --git a/packages/backend/src/server/api/stream/Connection.ts b/packages/backend/src/server/api/stream/Connection.ts index 7ea92eb797..9378b6c62b 100644 --- a/packages/backend/src/server/api/stream/Connection.ts +++ b/packages/backend/src/server/api/stream/Connection.ts @@ -14,6 +14,7 @@ import { CacheService } from '@/core/CacheService.js'; import { MiFollowing, MiUserProfile } from '@/models/_.js'; import type { StreamEventEmitter, GlobalEvents } from '@/core/GlobalEventService.js'; import { ChannelFollowingService } from '@/core/ChannelFollowingService.js'; +import type { JsonObject } from '@/misc/json-value.js'; import type { ChannelsService } from './ChannelsService.js'; import type { EventEmitter } from 'events'; import type Channel from './channel.js'; @@ -33,7 +34,7 @@ export default class Connection { private wsConnection: WebSocket.WebSocket; public subscriber: StreamEventEmitter; private channels: Channel[] = []; - private subscribingNotes: any = {}; + private subscribingNotes: Partial<Record<string, number>> = {}; private cachedNotes: Packed<'Note'>[] = []; public userProfile: MiUserProfile | null = null; public following: Record<string, Pick<MiFollowing, 'withReplies'> | undefined> = {}; @@ -64,7 +65,7 @@ export default class Connection { if (token) this.token = token; if (rateLimiter) this.rateLimiter = rateLimiter; - this.logger = loggerService.getLogger('streaming', 'coral', false); + this.logger = loggerService.getLogger('streaming', 'coral'); } @bindThis @@ -115,7 +116,7 @@ export default class Connection { */ @bindThis private async onWsConnectionMessage(data: WebSocket.RawData) { - let obj: Record<string, any>; + let obj: JsonObject; if (this.closingConnection) return; @@ -148,6 +149,8 @@ export default class Connection { const { type, body } = obj; + if (typeof body !== 'object' || body === null || Array.isArray(body)) return; + switch (type) { case 'readNotification': this.onReadNotification(body); break; case 'subNote': this.onSubscribeNote(body); break; @@ -188,7 +191,7 @@ export default class Connection { } @bindThis - private readNote(body: any) { + private readNote(body: JsonObject) { const id = body.id; const note = this.cachedNotes.find(n => n.id === id); @@ -200,7 +203,7 @@ export default class Connection { } @bindThis - private onReadNotification(payload: any) { + private onReadNotification(payload: JsonObject) { this.notificationService.readAllNotification(this.user!.id); } @@ -208,16 +211,14 @@ export default class Connection { * 投稿購読要求時 */ @bindThis - private onSubscribeNote(payload: any) { - if (!payload.id) return; - - if (this.subscribingNotes[payload.id] == null) { - this.subscribingNotes[payload.id] = 0; - } + private onSubscribeNote(payload: JsonObject) { + if (!payload.id || typeof payload.id !== 'string') return; - this.subscribingNotes[payload.id]++; + const current = this.subscribingNotes[payload.id] ?? 0; + const updated = current + 1; + this.subscribingNotes[payload.id] = updated; - if (this.subscribingNotes[payload.id] === 1) { + if (updated === 1) { this.subscriber.on(`noteStream:${payload.id}`, this.onNoteStreamMessage); } } @@ -226,11 +227,14 @@ export default class Connection { * 投稿購読解除要求時 */ @bindThis - private onUnsubscribeNote(payload: any) { - if (!payload.id) return; + private onUnsubscribeNote(payload: JsonObject) { + if (!payload.id || typeof payload.id !== 'string') return; - this.subscribingNotes[payload.id]--; - if (this.subscribingNotes[payload.id] <= 0) { + const current = this.subscribingNotes[payload.id]; + if (current == null) return; + const updated = current - 1; + this.subscribingNotes[payload.id] = updated; + if (updated <= 0) { delete this.subscribingNotes[payload.id]; this.subscriber.off(`noteStream:${payload.id}`, this.onNoteStreamMessage); } @@ -261,17 +265,22 @@ export default class Connection { * チャンネル接続要求時 */ @bindThis - private onChannelConnectRequested(payload: any) { + private onChannelConnectRequested(payload: JsonObject) { const { channel, id, params, pong } = payload; - this.connectChannel(id, params, channel, pong); + if (typeof id !== 'string') return; + if (typeof channel !== 'string') return; + if (typeof pong !== 'boolean' && typeof pong !== 'undefined' && pong !== null) return; + if (typeof params !== 'undefined' && (typeof params !== 'object' || params === null || Array.isArray(params))) return; + this.connectChannel(id, params, channel, pong ?? undefined); } /** * チャンネル切断要求時 */ @bindThis - private onChannelDisconnectRequested(payload: any) { + private onChannelDisconnectRequested(payload: JsonObject) { const { id } = payload; + if (typeof id !== 'string') return; this.disconnectChannel(id); } @@ -279,7 +288,7 @@ export default class Connection { * クライアントにメッセージ送信 */ @bindThis - public sendMessageToWs(type: string, payload: any) { + public sendMessageToWs(type: string, payload: JsonObject) { this.wsConnection.send(JSON.stringify({ type: type, body: payload, @@ -290,7 +299,7 @@ export default class Connection { * チャンネルに接続 */ @bindThis - public connectChannel(id: string, params: any, channel: string, pong = false) { + public connectChannel(id: string, params: JsonObject | undefined, channel: string, pong = false) { if (this.channels.length >= MAX_CHANNELS_PER_CONNECTION) { return; } @@ -341,7 +350,11 @@ export default class Connection { * @param data メッセージ */ @bindThis - private onChannelMessageRequested(data: any) { + private onChannelMessageRequested(data: JsonObject) { + if (typeof data.id !== 'string') return; + if (typeof data.type !== 'string') return; + if (typeof data.body === 'undefined') return; + const channel = this.channels.find(c => c.id === data.id); if (channel != null && channel.onMessage != null) { channel.onMessage(data.type, data.body); |