-
Notifications
You must be signed in to change notification settings - Fork 589
Expand file tree
/
Copy pathnsfwjs.worker.ts
More file actions
123 lines (103 loc) · 3.27 KB
/
nsfwjs.worker.ts
File metadata and controls
123 lines (103 loc) · 3.27 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import type { ModelName, NSFWJS, PredictionType } from "nsfwjs";
import { load } from "nsfwjs/core";
import { InceptionV3Model } from "nsfwjs/models/inception_v3";
import { MobileNetV2Model } from "nsfwjs/models/mobilenet_v2";
import { MobileNetV2MidModel } from "nsfwjs/models/mobilenet_v2_mid";
// Memoized TensorFlow.js initialization; cleared on failure so a later call retries.
let tfReadyPromise: Promise<void> | null = null;

/**
 * Initializes TensorFlow.js exactly once: enables prod mode, tries to switch
 * to the WebGPU backend (silently keeping the default backend if that fails),
 * and waits for readiness. Concurrent callers share the single in-flight
 * promise; a failed attempt resets the cache so the next call can retry.
 */
const ensureTfReady = (): Promise<void> => {
  if (tfReadyPromise === null) {
    const initialize = async (): Promise<void> => {
      const tf = await import("@tensorflow/tfjs");
      await import("@tensorflow/tfjs-backend-webgpu");
      tf.enableProdMode();
      // WebGPU may be unsupported on this device; swallow the rejection and
      // stay on whatever backend tfjs picks by default.
      await tf.setBackend("webgpu").catch(() => false);
      await tf.ready();
    };
    tfReadyPromise = initialize().catch((error) => {
      // Allow a retry on the next call instead of caching the failure forever.
      tfReadyPromise = null;
      throw error;
    });
  }
  return tfReadyPromise;
};
/** Request posted from the main thread to this worker. */
export type Message = {
  /** "load" loads a model into the worker; "predict" classifies an image. */
  type: "load" | "predict";
  /** Which NSFWJS model to load — required for "load" requests. */
  modelName?: ModelName;
  /** Image file to classify — required for "predict" requests. */
  file?: File;
};
/** Response posted back to the main thread. Exactly one concern is populated per reply. */
export type ReturnMessage = {
  /** Result of a "load" request: true on success, false (with `error`) on failure. */
  modelLoaded?: boolean;
  /** Result of a successful "predict" request. */
  predictions?: PredictionType[];
  /** Human-readable failure description when a request could not be served. */
  error?: string;
};
/** Options forwarded to nsfwjs `load()` for a given model. */
interface NSFWJSOptions {
  /** Input image size the model expects, in pixels. */
  size?: number;
  /** Model format hint (e.g. "graph" for TF.js graph models). */
  type?: string;
}
/** Requires an options entry for every known ModelName. */
type ModelConfig = {
  [key in ModelName]: NSFWJSOptions;
};
// Per-model load options; an empty object means library defaults apply.
const modelOptions: ModelConfig = {
  MobileNetV2: {},
  // Shipped as a TF.js graph model rather than a layers model.
  MobileNetV2Mid: { type: "graph" },
  // InceptionV3 requires a 299px input size.
  InceptionV3: { size: 299 },
};
// Currently loaded model instance, or null before the first "load" request.
let model: NSFWJS | null = null;
// Name of the loaded model, used to skip redundant reloads of the same model.
let loadedModelName: ModelName | null = null;
/**
 * Worker message handler.
 *
 * - "load": loads the requested NSFWJS model, trying the IndexedDB cache
 *   first and falling back to a fresh download (which is then cached), and
 *   replies with { modelLoaded } or { modelLoaded: false, error }.
 * - "predict": rasterizes the given image File onto an OffscreenCanvas,
 *   classifies it with the loaded model, and replies with { predictions }
 *   or { error }.
 *
 * All failures are reported back via the `error` field; nothing is thrown
 * out of the handler.
 */
onmessage = async (event: MessageEvent<Message>) => {
  const { type, modelName, file } = event.data;
  if (type === "load" && modelName) {
    try {
      // Skip reloading when the requested model is already active.
      if (model && loadedModelName === modelName) {
        const reply: ReturnMessage = { modelLoaded: true };
        postMessage(reply);
        return;
      }
      await ensureTfReady();
      try {
        // Fast path: weights previously saved to IndexedDB.
        model = await load(`indexeddb://${modelName}`, modelOptions[modelName]);
        console.info("Loaded from IndexedDB cache");
      } catch {
        // Cache miss (or unreadable entry): download the model, then cache it.
        model = await load(modelName, {
          ...modelOptions[modelName],
          modelDefinitions: [
            MobileNetV2Model,
            MobileNetV2MidModel,
            InceptionV3Model,
          ],
        });
        await model.model.save(`indexeddb://${modelName}`);
      }
      loadedModelName = modelName;
      const reply: ReturnMessage = { modelLoaded: true };
      postMessage(reply);
    } catch (error) {
      const reply: ReturnMessage = {
        modelLoaded: false,
        error: error instanceof Error ? error.message : "Failed to load model",
      };
      postMessage(reply);
    }
  } else if (type === "predict" && file) {
    if (!model) {
      const reply: ReturnMessage = { error: "Model is not loaded" };
      postMessage(reply);
      return;
    }
    const offscreenCanvas = new OffscreenCanvas(1, 1);
    const ctx = offscreenCanvas.getContext("2d");
    if (!ctx) {
      const reply: ReturnMessage = { error: "2D canvas context is unavailable" };
      postMessage(reply);
      return;
    }
    // Decode inside the try/catch: createImageBitmap rejects on invalid image
    // data, and the previous code let that rejection escape the handler
    // unreported (and skipped the bitmap cleanup).
    let imgBitmap: ImageBitmap | null = null;
    try {
      imgBitmap = await createImageBitmap(file);
      offscreenCanvas.width = imgBitmap.width;
      offscreenCanvas.height = imgBitmap.height;
      ctx.drawImage(imgBitmap, 0, 0);
      const imageData = ctx.getImageData(0, 0, imgBitmap.width, imgBitmap.height);
      const predictions = await model.classify(imageData);
      const reply: ReturnMessage = { predictions };
      postMessage(reply);
    } catch (error) {
      const reply: ReturnMessage = {
        error: error instanceof Error ? error.message : "Prediction failed",
      };
      postMessage(reply);
    } finally {
      // Release the decoded bitmap's memory even when decoding or
      // classification failed partway through.
      imgBitmap?.close();
    }
  }
};