diff --git a/src/services/machineLearning/imageSceneService.ts b/src/services/machineLearning/imageSceneService.ts index aa6a70ee0..f97ae4e84 100644 --- a/src/services/machineLearning/imageSceneService.ts +++ b/src/services/machineLearning/imageSceneService.ts @@ -23,6 +23,7 @@ class ImageScene implements SceneDetectionService { } private async init() { + console.log('ImageScene init called'); if (this.model) { return; } @@ -31,21 +32,20 @@ class ImageScene implements SceneDetectionService { await fetch('/models/imagescene/sceneMap.json') ).json(); - const model = await tfjsConverter.loadGraphModel( + this.model = await tfjsConverter.loadGraphModel( '/models/imagescene/model.json' ); console.log('loaded ImageScene model', tf.getBackend()); - this.model = model; - // warmup the model - const warmupResult = this.model.predict( - tf.zeros([1, 224, 224, 3]) - ) as tf.Tensor; - await warmupResult.data(); - warmupResult.dispose(); + tf.tidy(() => { + const zeroTensor = tf.zeros([1, 224, 224, 3]); + // warmup the model + this.model.predict(zeroTensor) as tf.Tensor; + }); } private async getImageSceneModel() { + console.log('ImageScene getImageSceneModel called'); if (!this.ready) { this.ready = this.init(); } @@ -54,7 +54,6 @@ class ImageScene implements SceneDetectionService { } async detectScenes(image: ImageBitmap, minScore: number) { - await tf.ready(); const resized = resizeToSquare(image, SCENE_DETECTION_IMAGE_SIZE); const model = await this.getImageSceneModel(); @@ -66,24 +65,24 @@ class ImageScene implements SceneDetectionService { return output; }); - const data = await output.data(); + const data = (await output.data()) as Float32Array; output.dispose(); - const scenes = this.getScenes( - data as Float32Array, + const scenes = this.parseSceneDetectionResult( + data, + minScore, image.width, - image.height, - minScore + image.height ); return scenes; } - private getScenes( + private parseSceneDetectionResult( outputData: Float32Array, + minScore: number, width: number, - height: number, - minScore: number + height: number ): ObjectDetection[] { const scenes = []; for (let i = 0; i < outputData.length; i++) { diff --git a/src/utils/image/index.ts b/src/utils/image/index.ts index 8fcbc7c49..cff331035 100644 --- a/src/utils/image/index.ts +++ b/src/utils/image/index.ts @@ -12,8 +12,8 @@ export function resizeToSquare(img: ImageBitmap, size: number) { const ctx = offscreen.getContext('2d'); ctx.imageSmoothingQuality = 'high'; ctx.drawImage(img, 0, 0, width, height); - - return { image: offscreen.transferToImageBitmap(), width, height }; + const resizedImage = offscreen.transferToImageBitmap(); + return { image: resizedImage, width, height }; } export function transform(