diff --git a/packages/backend/src/server/api/stream/Connection.ts b/packages/backend/src/server/api/stream/Connection.ts index 36f7cd93b2..d7340f99fd 100644 --- a/packages/backend/src/server/api/stream/Connection.ts +++ b/packages/backend/src/server/api/stream/Connection.ts @@ -52,7 +52,7 @@ export default class Connection { public token?: MiAccessToken; private wsConnection: WebSocket.WebSocket; public subscriber: StreamEventEmitter; - private channels: Channel[] = []; + private channels: Map = new Map(); private subscribingNotes: Partial> = {}; public userProfile: MiUserProfile | null = null; public following: Record | undefined> = {}; @@ -262,7 +262,11 @@ export default class Connection { */ @bindThis public async connectChannel(id: string, params: JsonObject | undefined, channel: string, pong = false) { - if (this.channels.length >= MAX_CHANNELS_PER_CONNECTION) { + if (this.channels.has(id)) { + this.disconnectChannel(id); + } + + if (this.channels.size >= MAX_CHANNELS_PER_CONNECTION) { return; } @@ -278,8 +282,12 @@ export default class Connection { } // 共有可能チャンネルに接続しようとしていて、かつそのチャンネルに既に接続していたら無意味なので無視 - if (channelConstructor.shouldShare && this.channels.some(c => c.chName === channel)) { - return; + if (channelConstructor.shouldShare) { + for (const c of this.channels.values()) { + if (c.chName === channel) { + return; + } + } } const contextId = ContextIdFactory.create(); @@ -289,8 +297,13 @@ export default class Connection { }, contextId); const ch: Channel = await this.moduleRef.create(channelConstructor, contextId); - this.channels.push(ch); - ch.init(params ?? {}); + this.channels.set(ch.id, ch); + const valid = await ch.init(params ?? {}); + if (typeof valid === 'boolean' && !valid) { + // 初期化処理の結果、接続拒否されたので切断 + this.disconnectChannel(id); + return; + } if (pong) { this.sendMessageToWs('connected', { @@ -332,11 +345,11 @@ export default class Connection { */ @bindThis public disconnectChannel(id: string) { - const channel = this.channels.find(c => c.id === id); + const channel = this.channels.get(id); if (channel) { if (channel.dispose) channel.dispose(); - this.channels = this.channels.filter(c => c.id !== id); + this.channels.delete(id); } } @@ -351,7 +364,7 @@ export default class Connection { if (typeof data.type !== 'string') return; if (typeof data.body === 'undefined') return; - const channel = this.channels.find(c => c.id === data.id); + const channel = this.channels.get(data.id); if (channel != null && channel.onMessage != null) { channel.onMessage(data.type, data.body); } @@ -363,7 +376,7 @@ export default class Connection { @bindThis public dispose() { if (this.fetchIntervalId) clearInterval(this.fetchIntervalId); - for (const c of this.channels.filter(c => c.dispose)) { + for (const c of this.channels.values()) { if (c.dispose) c.dispose(); } } diff --git a/packages/backend/src/server/api/stream/channel.ts b/packages/backend/src/server/api/stream/channel.ts index 21036ee351..010d47a7ff 100644 --- a/packages/backend/src/server/api/stream/channel.ts +++ b/packages/backend/src/server/api/stream/channel.ts @@ -8,6 +8,7 @@ import { isInstanceMuted } from '@/misc/is-instance-muted.js'; import { isUserRelated } from '@/misc/is-user-related.js'; import { isQuotePacked, isRenotePacked } from '@/misc/is-renote.js'; import { isChannelRelated } from '@/misc/is-channel-related.js'; +import type { Awaitable } from '@/types.js'; import type { Packed } from '@/misc/json-schema.js'; import type { JsonObject, JsonValue } from '@/misc/json-value.js'; import type Connection from './Connection.js'; @@ -141,7 +142,14 @@ export default abstract class Channel { }); } - public abstract init(params: JsonObject): void; + /** + * チャンネルの初期化処理(接続時点での接続可否チェックを兼ねる) + * + * - `void / Promise` を返す場合は、チェックなし + * - `true / Promise` を返す場合は、接続可能 + * - `false / Promise` を返す場合は、接続不可(接続を切断) + */ + public abstract init(params: JsonObject): Awaitable; public dispose?(): void; diff --git a/packages/backend/src/server/api/stream/channels/antenna.ts b/packages/backend/src/server/api/stream/channels/antenna.ts index a37c2b8fcf..3fb1ad63d4 100644 --- a/packages/backend/src/server/api/stream/channels/antenna.ts +++ b/packages/backend/src/server/api/stream/channels/antenna.ts @@ -4,6 +4,8 @@ */ import { Inject, Injectable, Scope } from '@nestjs/common'; +import { DI } from '@/di-symbols.js'; +import type { AntennasRepository } from '@/models/_.js'; import { NoteEntityService } from '@/core/entities/NoteEntityService.js'; import { NoteStreamingHidingService } from '../NoteStreamingHidingService.js'; import { bindThis } from '@/decorators.js'; @@ -25,6 +27,9 @@ export class AntennaChannel extends Channel { @Inject(REQUEST) request: ChannelRequest, + @Inject(DI.antennasRepository) + private antennasReposiotry: AntennasRepository, + private noteEntityService: NoteEntityService, private noteStreamingHidingService: NoteStreamingHidingService, ) { @@ -33,12 +38,25 @@ export class AntennaChannel extends Channel { } @bindThis - public async init(params: JsonObject) { - if (typeof params.antennaId !== 'string') return; + public async init(params: JsonObject): Promise { + if (typeof params.antennaId !== 'string') return false; + if (!this.user) return false; + this.antennaId = params.antennaId; + const antennaExists = await this.antennasReposiotry.exists({ + where: { + id: this.antennaId, + userId: this.user.id, + }, + }); + + if (!antennaExists) return false; + // Subscribe stream this.subscriber.on(`antennaStream:${this.antennaId}`, this.onEvent); + + return true; } @bindThis diff --git a/packages/backend/src/server/api/stream/channels/chat-room.ts b/packages/backend/src/server/api/stream/channels/chat-room.ts index 7f949032e2..f56522716e 100644 --- a/packages/backend/src/server/api/stream/channels/chat-room.ts +++ b/packages/backend/src/server/api/stream/channels/chat-room.ts @@ -4,12 +4,14 @@ */ import { Inject, Injectable, Scope } from '@nestjs/common'; +import { DI } from '@/di-symbols.js'; import { bindThis } from '@/decorators.js'; import type { GlobalEvents } from '@/core/GlobalEventService.js'; import type { JsonObject } from '@/misc/json-value.js'; import { ChatService } from '@/core/ChatService.js'; import Channel, { type ChannelRequest } from '../channel.js'; import { REQUEST } from '@nestjs/core'; +import type { ChatRoomsRepository } from '@/models/_.js'; @Injectable({ scope: Scope.TRANSIENT }) export class ChatRoomChannel extends Channel { @@ -23,17 +25,31 @@ export class ChatRoomChannel extends Channel { @Inject(REQUEST) request: ChannelRequest, + @Inject(DI.chatRoomsRepository) + private chatRoomsRepository: ChatRoomsRepository, + private chatService: ChatService, ) { super(request); } @bindThis - public async init(params: JsonObject) { - if (typeof params.roomId !== 'string') return; + public async init(params: JsonObject): Promise { + if (typeof params.roomId !== 'string') return false; + if (!this.user) return false; + this.roomId = params.roomId; + const room = await this.chatRoomsRepository.findOneBy({ + id: this.roomId, + }); + + if (room == null) return false; + if (!(await this.chatService.hasPermissionToViewRoomTimeline(this.user.id, room))) return false; + this.subscriber.on(`chatRoomStream:${this.roomId}`, this.onEvent); + + return true; } @bindThis diff --git a/packages/backend/src/server/api/stream/channels/chat-user.ts b/packages/backend/src/server/api/stream/channels/chat-user.ts index 36f3f67b28..6d96c658ad 100644 --- a/packages/backend/src/server/api/stream/channels/chat-user.ts +++ b/packages/backend/src/server/api/stream/channels/chat-user.ts @@ -29,11 +29,16 @@ export class ChatUserChannel extends Channel { } @bindThis - public async init(params: JsonObject) { - if (typeof params.otherId !== 'string') return; + public async init(params: JsonObject): Promise { + if (typeof params.otherId !== 'string') return false; + if (!this.user) return false; + if (params.otherId === this.user.id) return false; + this.otherId = params.otherId; - this.subscriber.on(`chatUserStream:${this.user!.id}-${this.otherId}`, this.onEvent); + this.subscriber.on(`chatUserStream:${this.user.id}-${this.otherId}`, this.onEvent); + + return true; } @bindThis diff --git a/packages/backend/src/server/api/stream/channels/hashtag.ts b/packages/backend/src/server/api/stream/channels/hashtag.ts index 66a0df3e6b..6ba8963044 100644 --- a/packages/backend/src/server/api/stream/channels/hashtag.ts +++ b/packages/backend/src/server/api/stream/channels/hashtag.ts @@ -32,17 +32,19 @@ export class HashtagChannel extends Channel { } @bindThis - public async init(params: JsonObject) { - if (!Array.isArray(params.q)) return; + public async init(params: JsonObject): Promise { + if (!Array.isArray(params.q)) return false; if (!params.q.every((x): x is string[] => ( Array.isArray(x) && x.length >= 1 && x.every(y => typeof y === 'string') - ))) return; + ))) return false; this.q = params.q; // Subscribe stream this.subscriber.on('notesStream', this.onNote); + + return true; } @bindThis diff --git a/packages/backend/src/server/api/stream/channels/main.ts b/packages/backend/src/server/api/stream/channels/main.ts index 68a249cbee..224d9e0d89 100644 --- a/packages/backend/src/server/api/stream/channels/main.ts +++ b/packages/backend/src/server/api/stream/channels/main.ts @@ -28,9 +28,10 @@ export class MainChannel extends Channel { } @bindThis - public async init(params: JsonObject) { - // Subscribe main stream channel - this.subscriber.on(`mainStream:${this.user!.id}`, async data => { + public async init(params: JsonObject): Promise { + if (!this.user) return false; + + this.subscriber.on(`mainStream:${this.user.id}`, async data => { switch (data.type) { case 'notification': { // Ignore notifications from instances the user has muted @@ -61,5 +62,7 @@ export class MainChannel extends Channel { this.send(data.type, data.body); }); + + return true; } } diff --git a/packages/backend/src/server/api/stream/channels/user-list.ts b/packages/backend/src/server/api/stream/channels/user-list.ts index ae6386855b..6df7e32005 100644 --- a/packages/backend/src/server/api/stream/channels/user-list.ts +++ b/packages/backend/src/server/api/stream/channels/user-list.ts @@ -45,8 +45,8 @@ export class UserListChannel extends Channel { } @bindThis - public async init(params: JsonObject) { - if (typeof params.listId !== 'string') return; + public async init(params: JsonObject): Promise { + if (typeof params.listId !== 'string') return false; this.listId = params.listId; this.withFiles = !!(params.withFiles ?? false); this.withRenotes = !!(params.withRenotes ?? true); @@ -58,7 +58,7 @@ export class UserListChannel extends Channel { userId: this.user!.id, }, }); - if (!listExist) return; + if (!listExist) return false; // Subscribe stream this.subscriber.on(`userListStream:${this.listId}`, this.send); @@ -67,6 +67,8 @@ export class UserListChannel extends Channel { this.updateListUsers(); this.listUsersClock = setInterval(this.updateListUsers, 5000); + + return true; } @bindThis diff --git a/packages/backend/src/types.ts b/packages/backend/src/types.ts index 24654b0017..d3b2ec1eb0 100644 --- a/packages/backend/src/types.ts +++ b/packages/backend/src/types.ts @@ -414,3 +414,5 @@ export type FilterUnionByProperty< Property extends string | number | symbol, Condition, > = Union extends Record ? Union : never; + +export type Awaitable = T | Promise;