diff --git a/src/constants/machineLearning/config.ts b/src/constants/machineLearning/config.ts index 454599fae..536ec425a 100644 --- a/src/constants/machineLearning/config.ts +++ b/src/constants/machineLearning/config.ts @@ -90,4 +90,4 @@ export const TESSERACT_MIN_IMAGE_WIDTH = 44; export const TESSERACT_MIN_IMAGE_HEIGHT = 20; export const TESSERACT_MAX_IMAGE_DIMENSION = 720; -export const SCENE_DETECTION_IMAGE_SIZE = [224, 224]; +export const SCENE_DETECTION_IMAGE_SIZE = 224; diff --git a/src/services/machineLearning/imageSceneService.ts b/src/services/machineLearning/imageSceneService.ts index b48448630..21ae8949a 100644 --- a/src/services/machineLearning/imageSceneService.ts +++ b/src/services/machineLearning/imageSceneService.ts @@ -7,6 +7,7 @@ import { Versioned, } from 'types/machineLearning'; import { SCENE_DETECTION_IMAGE_SIZE } from 'constants/machineLearning/config'; +import { resizeToSquare } from 'utils/image'; class ImageScene implements SceneDetectionService { method: Versioned; @@ -54,24 +55,16 @@ class ImageScene implements SceneDetectionService { async detectScenes(image: ImageBitmap, minScore: number) { await tf.ready(); + // scene detection model takes fixed-shaped (224x224) inputs + // https://tfhub.dev/sayannath/lite-model/image-scene/1 + const resized = resizeToSquare(image, SCENE_DETECTION_IMAGE_SIZE); const model = await this.getImageSceneModel(); const output = tf.tidy(() => { - const tensor = tf.browser.fromPixels(image); - - // This model takes fixed-shaped (224x224) inputs - // https://tfhub.dev/sayannath/lite-model/image-scene/1 - let resizedTensor = tf.image.resizeBilinear( - tensor, - SCENE_DETECTION_IMAGE_SIZE as [number, number] - ); - - resizedTensor = tf.expandDims(resizedTensor); - resizedTensor = tf.cast(resizedTensor, 'float32'); - - const output = model.predict(resizedTensor) as tf.Tensor; - + const tfImage = tf.browser.fromPixels(resized.image); + const input = tf.expandDims(tf.cast(tfImage, 'float32')); + const output = model.predict(input) as tf.Tensor; return output; });