summaryrefslogtreecommitdiff
path: root/packages/backend/src/core/AiService.ts
diff options
context:
space:
mode:
authorpopkirby <popkirby@gmail.com>2023-07-08 07:27:26 +0900
committerGitHub <noreply@github.com>2023-07-08 07:27:26 +0900
commit8daca59ca61dac8fd7f68f3135d69092c70e599f (patch)
tree3bc50331092a6c6c27abae001467fd8f11226f2c /packages/backend/src/core/AiService.ts
parentcleanup: trim trailing whitespace (#11136) (diff)
downloadsharkey-8daca59ca61dac8fd7f68f3135d69092c70e599f.tar.gz
sharkey-8daca59ca61dac8fd7f68f3135d69092c70e599f.tar.bz2
sharkey-8daca59ca61dac8fd7f68f3135d69092c70e599f.zip
perf(backend): use mutex for nsfw model loading (#11109)
Co-authored-by: tamaina <tamaina@hotmail.co.jp>
Diffstat (limited to 'packages/backend/src/core/AiService.ts')
-rw-r--r--packages/backend/src/core/AiService.ts10
1 files changed, 9 insertions, 1 deletions
diff --git a/packages/backend/src/core/AiService.ts b/packages/backend/src/core/AiService.ts
index 02501b832b..c0596446dd 100644
--- a/packages/backend/src/core/AiService.ts
+++ b/packages/backend/src/core/AiService.ts
@@ -4,6 +4,7 @@ import { dirname } from 'node:path';
import { Inject, Injectable } from '@nestjs/common';
import * as nsfw from 'nsfwjs';
import si from 'systeminformation';
+import { Mutex } from 'async-mutex';
import type { Config } from '@/config.js';
import { DI } from '@/di-symbols.js';
import { bindThis } from '@/decorators.js';
@@ -17,6 +18,7 @@ let isSupportedCpu: undefined | boolean = undefined;
@Injectable()
export class AiService {
private model: nsfw.NSFWJS;
+ private modelLoadMutex: Mutex = new Mutex();
constructor(
@Inject(DI.config)
@@ -39,7 +41,13 @@ export class AiService {
const tf = await import('@tensorflow/tfjs-node');
- if (this.model == null) this.model = await nsfw.load(`file://${_dirname}/../../nsfw-model/`, { size: 299 });
+ if (this.model == null) {
+ await this.modelLoadMutex.runExclusive(async () => {
+ if (this.model == null) {
+ this.model = await nsfw.load(`file://${_dirname}/../../nsfw-model/`, { size: 299 });
+ }
+ });
+ }
const buffer = await fs.promises.readFile(path);
const image = await tf.node.decodeImage(buffer, 3) as any;