This commit is contained in:
Rushikesh Tote 2022-06-01 18:33:27 +05:30
parent b2acec63d8
commit 793cf9d5a1
2 changed files with 15 additions and 7 deletions

View file

@ -89,3 +89,5 @@ export const MOBILEFACENET_FACE_SIZE = 112;
export const TESSERACT_MIN_IMAGE_WIDTH = 44; export const TESSERACT_MIN_IMAGE_WIDTH = 44;
export const TESSERACT_MIN_IMAGE_HEIGHT = 20; export const TESSERACT_MIN_IMAGE_HEIGHT = 20;
export const TESSERACT_MAX_IMAGE_DIMENSION = 720; export const TESSERACT_MAX_IMAGE_DIMENSION = 720;
export const SCENE_DETECTION_IMAGE_SIZE = [224, 224];

View file

@ -6,6 +6,7 @@ import {
SceneDetectionService, SceneDetectionService,
Versioned, Versioned,
} from 'types/machineLearning'; } from 'types/machineLearning';
import { SCENE_DETECTION_IMAGE_SIZE } from 'constants/machineLearning/config';
class ImageScene implements SceneDetectionService { class ImageScene implements SceneDetectionService {
method: Versioned<SceneDetectionMethod>; method: Versioned<SceneDetectionMethod>;
@ -35,9 +36,11 @@ class ImageScene implements SceneDetectionService {
this.model = model; this.model = model;
// warmup the model // warmup the model
const warmupResult = this.model.predict(tf.zeros([1, 224, 224, 3])); const warmupResult = this.model.predict(
await (warmupResult as tf.Tensor).data(); tf.zeros([1, 224, 224, 3])
(warmupResult as tf.Tensor).dispose(); ) as tf.Tensor;
await warmupResult.data();
warmupResult.dispose();
} }
async detectScenes(image: ImageBitmap, minScore: number) { async detectScenes(image: ImageBitmap, minScore: number) {
@ -52,18 +55,21 @@ class ImageScene implements SceneDetectionService {
// This model takes fixed-shaped (224x224) inputs // This model takes fixed-shaped (224x224) inputs
// https://tfhub.dev/sayannath/lite-model/image-scene/1 // https://tfhub.dev/sayannath/lite-model/image-scene/1
let resizedTensor = tf.image.resizeBilinear(tensor, [224, 224]); let resizedTensor = tf.image.resizeBilinear(
tensor,
SCENE_DETECTION_IMAGE_SIZE as [number, number]
);
resizedTensor = tf.expandDims(resizedTensor); resizedTensor = tf.expandDims(resizedTensor);
resizedTensor = tf.cast(resizedTensor, 'float32'); resizedTensor = tf.cast(resizedTensor, 'float32');
const output = this.model.predict(resizedTensor); const output = this.model.predict(resizedTensor) as tf.Tensor;
return output; return output;
}); });
const data = await (output as tf.Tensor).data(); const data = await output.data();
(output as tf.Tensor).dispose(); output.dispose();
const scenes = this.getScenes( const scenes = this.getScenes(
data as Float32Array, data as Float32Array,