diff options
Diffstat (limited to 'packages/backend/src/server/api')
| -rw-r--r-- | packages/backend/src/server/api/AuthenticateService.ts | 2 | ||||
| -rw-r--r-- | packages/backend/src/server/api/StreamingApiServerService.ts | 115 | ||||
| -rw-r--r-- | packages/backend/src/server/api/stream/index.ts | 21 |
3 files changed, 78 insertions, 60 deletions
diff --git a/packages/backend/src/server/api/AuthenticateService.ts b/packages/backend/src/server/api/AuthenticateService.ts index 6548c475b2..e23591d876 100644 --- a/packages/backend/src/server/api/AuthenticateService.ts +++ b/packages/backend/src/server/api/AuthenticateService.ts @@ -36,7 +36,7 @@ export class AuthenticateService { } @bindThis - public async authenticate(token: string | null | undefined): Promise<[LocalUser | null | undefined, AccessToken | null | undefined]> { + public async authenticate(token: string | null | undefined): Promise<[LocalUser | null, AccessToken | null]> { if (token == null) { return [null, null]; } diff --git a/packages/backend/src/server/api/StreamingApiServerService.ts b/packages/backend/src/server/api/StreamingApiServerService.ts index 258e8de034..fdda581ada 100644 --- a/packages/backend/src/server/api/StreamingApiServerService.ts +++ b/packages/backend/src/server/api/StreamingApiServerService.ts @@ -1,23 +1,25 @@ import { EventEmitter } from 'events'; import { Inject, Injectable } from '@nestjs/common'; import * as Redis from 'ioredis'; -import * as websocket from 'websocket'; +import * as WebSocket from 'ws'; import { DI } from '@/di-symbols.js'; -import type { UsersRepository, BlockingsRepository, ChannelFollowingsRepository, FollowingsRepository, MutingsRepository, UserProfilesRepository, RenoteMutingsRepository } from '@/models/index.js'; +import type { UsersRepository, AccessToken } from '@/models/index.js'; import type { Config } from '@/config.js'; import { NoteReadService } from '@/core/NoteReadService.js'; import { GlobalEventService } from '@/core/GlobalEventService.js'; import { NotificationService } from '@/core/NotificationService.js'; import { bindThis } from '@/decorators.js'; import { CacheService } from '@/core/CacheService.js'; -import { AuthenticateService } from './AuthenticateService.js'; +import { LocalUser } from '@/models/entities/User'; +import { AuthenticateService, AuthenticationError } from './AuthenticateService.js'; import MainStreamConnection from './stream/index.js'; import { ChannelsService } from './stream/ChannelsService.js'; -import type { ParsedUrlQuery } from 'querystring'; import type * as http from 'node:http'; @Injectable() export class StreamingApiServerService { + #wss: WebSocket.WebSocketServer; + constructor( @Inject(DI.config) private config: Config, @@ -28,24 +30,6 @@ export class StreamingApiServerService { @Inject(DI.usersRepository) private usersRepository: UsersRepository, - @Inject(DI.followingsRepository) - private followingsRepository: FollowingsRepository, - - @Inject(DI.mutingsRepository) - private mutingsRepository: MutingsRepository, - - @Inject(DI.renoteMutingsRepository) - private renoteMutingsRepository: RenoteMutingsRepository, - - @Inject(DI.blockingsRepository) - private blockingsRepository: BlockingsRepository, - - @Inject(DI.channelFollowingsRepository) - private channelFollowingsRepository: ChannelFollowingsRepository, - - @Inject(DI.userProfilesRepository) - private userProfilesRepository: UserProfilesRepository, - private cacheService: CacheService, private noteReadService: NoteReadService, private authenticateService: AuthenticateService, @@ -55,47 +39,75 @@ export class StreamingApiServerService { } @bindThis - public attachStreamingApi(server: http.Server) { - // Init websocket server - const ws = new websocket.server({ - httpServer: server, + public attach(server: http.Server): void { + this.#wss = new WebSocket.WebSocketServer({ + noServer: true, }); - ws.on('request', async (request) => { - const q = request.resourceURL.query as ParsedUrlQuery; - - // TODO: トークンが間違ってるなどしてauthenticateに失敗したら - // コネクション切断するなりエラーメッセージ返すなりする - // (現状はエラーがキャッチされておらずサーバーのログに流れて邪魔なので) - const [user, miapp] = await this.authenticateService.authenticate(q.i as string); - - if (user?.isSuspended) { - request.reject(400); + server.on('upgrade', async (request, socket, head) => { + if (request.url == null) { + socket.write('HTTP/1.1 400 Bad Request\r\n\r\n'); + socket.destroy(); return; } - const ev = new EventEmitter(); + const q = new URL(request.url, `http://${request.headers.host}`).searchParams; - async function onRedisMessage(_: string, data: string): Promise<void> { - const parsed = JSON.parse(data); - ev.emit(parsed.channel, parsed.message); + let user: LocalUser | null = null; + let app: AccessToken | null = null; + + try { + [user, app] = await this.authenticateService.authenticate(q.get('i')); + } catch (e) { + if (e instanceof AuthenticationError) { + socket.write('HTTP/1.1 401 Unauthorized\r\n\r\n'); + } else { + socket.write('HTTP/1.1 500 Internal Server Error\r\n\r\n'); + } + socket.destroy(); + return; } - this.redisForSub.on('message', onRedisMessage); + if (user?.isSuspended) { + socket.write('HTTP/1.1 403 Forbidden\r\n\r\n'); + socket.destroy(); + return; + } - const main = new MainStreamConnection( + const stream = new MainStreamConnection( this.channelsService, this.noteReadService, this.notificationService, this.cacheService, - ev, user, miapp, + user, app, ); - await main.init(); + await stream.init(); + + this.#wss.handleUpgrade(request, socket, head, (ws) => { + this.#wss.emit('connection', ws, request, { + stream, user, app, + }); + }); + }); + + this.#wss.on('connection', async (connection: WebSocket.WebSocket, request: http.IncomingMessage, ctx: { + stream: MainStreamConnection, + user: LocalUser | null; + app: AccessToken | null + }) => { + const { stream, user, app } = ctx; + + const ev = new EventEmitter(); + + async function onRedisMessage(_: string, data: string): Promise<void> { + const parsed = JSON.parse(data); + ev.emit(parsed.channel, parsed.message); + } - const connection = request.accept(); + this.redisForSub.on('message', onRedisMessage); - main.init2(connection); + await stream.listen(ev, connection); const intervalId = user ? setInterval(() => { this.usersRepository.update(user.id, { @@ -110,16 +122,23 @@ export class StreamingApiServerService { connection.once('close', () => { ev.removeAllListeners(); - main.dispose(); + stream.dispose(); this.redisForSub.off('message', onRedisMessage); if (intervalId) clearInterval(intervalId); }); connection.on('message', async (data) => { - if (data.type === 'utf8' && data.utf8Data === 'ping') { + if (data.toString() === 'ping') { connection.send('pong'); } }); }); } + + @bindThis + public detach(): Promise<void> { + return new Promise((resolve) => { + this.#wss.close(() => resolve()); + }); + } } diff --git a/packages/backend/src/server/api/stream/index.ts b/packages/backend/src/server/api/stream/index.ts index fee56e3668..8b1c2c09c9 100644 --- a/packages/backend/src/server/api/stream/index.ts +++ b/packages/backend/src/server/api/stream/index.ts @@ -1,3 +1,4 @@ +import * as WebSocket from 'ws'; import type { User } from '@/models/entities/User.js'; import type { AccessToken } from '@/models/entities/AccessToken.js'; import type { Packed } from '@/misc/json-schema.js'; @@ -7,7 +8,6 @@ import { bindThis } from '@/decorators.js'; import { CacheService } from '@/core/CacheService.js'; import { UserProfile } from '@/models/index.js'; import type { ChannelsService } from './ChannelsService.js'; -import type * as websocket from 'websocket'; import type { EventEmitter } from 'events'; import type Channel from './channel.js'; import type { StreamEventEmitter, StreamMessages } from './types.js'; @@ -18,7 +18,7 @@ import type { StreamEventEmitter, StreamMessages } from './types.js'; export default class Connection { public user?: User; public token?: AccessToken; - private wsConnection: websocket.connection; + private wsConnection: WebSocket.WebSocket; public subscriber: StreamEventEmitter; private channels: Channel[] = []; private subscribingNotes: any = {}; @@ -37,11 +37,9 @@ export default class Connection { private notificationService: NotificationService, private cacheService: CacheService, - subscriber: EventEmitter, user: User | null | undefined, token: AccessToken | null | undefined, ) { - this.subscriber = subscriber; if (user) this.user = user; if (token) this.token = token; } @@ -70,12 +68,16 @@ export default class Connection { if (this.user != null) { await this.fetch(); - this.fetchIntervalId = setInterval(this.fetch, 1000 * 10); + if (!this.fetchIntervalId) { + this.fetchIntervalId = setInterval(this.fetch, 1000 * 10); + } } } @bindThis - public async init2(wsConnection: websocket.connection) { + public async listen(subscriber: EventEmitter, wsConnection: WebSocket.WebSocket) { + this.subscriber = subscriber; + this.wsConnection = wsConnection; this.wsConnection.on('message', this.onWsConnectionMessage); @@ -88,14 +90,11 @@ export default class Connection { * クライアントからメッセージ受信時 */ @bindThis - private async onWsConnectionMessage(data: websocket.Message) { - if (data.type !== 'utf8') return; - if (data.utf8Data == null) return; - + private async onWsConnectionMessage(data: WebSocket.RawData) { let obj: Record<string, any>; try { - obj = JSON.parse(data.utf8Data); + obj = JSON.parse(data.toString()); } catch (e) { return; } |