move scene minScore to config

This commit is contained in:
Rushikesh Tote 2022-05-30 20:21:12 +05:30
parent e835010716
commit f6d550d689
5 changed files with 21 additions and 44 deletions

View file

@ -54,6 +54,7 @@ export const DEFAULT_ML_SYNC_CONFIG: MLSyncConfig = {
}, },
sceneDetection: { sceneDetection: {
method: 'Image-Scene', method: 'Image-Scene',
minScore: 0.1,
}, },
// tsne: { // tsne: {
// samples: 200, // samples: 200,

View file

@ -10,7 +10,13 @@ function SceneDebug() {
const handleSubmission = async () => { const handleSubmission = async () => {
for (const file of selectedFiles) { for (const file of selectedFiles) {
await sceneDetectionService.detectByFile(file); console.log(
`scene detection for file ${file.name}`,
await sceneDetectionService.detectScenes(
await createImageBitmap(file),
0.1
)
);
} }
console.log('done with scene detection'); console.log('done with scene detection');
}; };

View file

@ -7,8 +7,6 @@ import {
} from 'types/machineLearning'; } from 'types/machineLearning';
import sceneMap from 'utils/machineLearning/sceneMap'; import sceneMap from 'utils/machineLearning/sceneMap';
const MIN_SCENE_DETECTION_SCORE = 0.1;
class ImageScene implements SceneDetectionService { class ImageScene implements SceneDetectionService {
method: Versioned<SceneDetectionMethod>; method: Versioned<SceneDetectionMethod>;
model: tf.GraphModel; model: tf.GraphModel;
@ -30,42 +28,7 @@ class ImageScene implements SceneDetectionService {
this.model = model; this.model = model;
} }
async detectByFile(file: File) { async detectScenes(image: ImageBitmap, minScore: number) {
const bmp = await createImageBitmap(file);
await tf.ready();
if (!this.model) {
await this.init();
}
const currTime = new Date().getTime();
const output = tf.tidy(() => {
let tensor = tf.browser.fromPixels(bmp);
tensor = tf.image.resizeBilinear(tensor, [224, 224]);
tensor = tf.expandDims(tensor);
tensor = tf.cast(tensor, 'float32');
const output = this.model.predict(tensor, {
verbose: true,
});
return output;
});
console.log('done in', new Date().getTime() - currTime, 'ms');
const data = await (output as tf.Tensor).data();
const scenes = this.getScenes(
data as Float32Array,
bmp.width,
bmp.height
);
console.log(`scenes for ${file.name}`, scenes);
}
async detectScenes(image: ImageBitmap) {
await tf.ready(); await tf.ready();
if (!this.model) { if (!this.model) {
@ -88,7 +51,8 @@ class ImageScene implements SceneDetectionService {
const scenes = this.getScenes( const scenes = this.getScenes(
data as Float32Array, data as Float32Array,
image.width, image.width,
image.height image.height,
minScore
); );
return scenes; return scenes;
@ -97,11 +61,12 @@ class ImageScene implements SceneDetectionService {
private getScenes( private getScenes(
outputData: Float32Array, outputData: Float32Array,
width: number, width: number,
height: number height: number,
minScore: number
): ObjectDetection[] { ): ObjectDetection[] {
const scenes = []; const scenes = [];
for (let i = 0; i < outputData.length; i++) { for (let i = 0; i < outputData.length; i++) {
if (outputData[i] >= MIN_SCENE_DETECTION_SCORE) { if (outputData[i] >= minScore) {
scenes.push({ scenes.push({
class: sceneMap.get(i), class: sceneMap.get(i),
score: outputData[i], score: outputData[i],

View file

@ -56,7 +56,8 @@ class ObjectService {
); );
objectDetections.push( objectDetections.push(
...(await syncContext.sceneDetectionService.detectScenes( ...(await syncContext.sceneDetectionService.detectScenes(
imageBitmap imageBitmap,
syncContext.config.sceneDetection.minScore
)) ))
); );
// console.log('3 TF Memory stats: ', tf.memory()); // console.log('3 TF Memory stats: ', tf.memory());

View file

@ -255,6 +255,7 @@ export interface ObjectDetectionConfig {
export interface SceneDetectionConfig { export interface SceneDetectionConfig {
method: SceneDetectionMethod; method: SceneDetectionMethod;
minScore: number;
} }
export interface TextDetectionConfig { export interface TextDetectionConfig {
@ -393,7 +394,10 @@ export interface ObjectDetectionService {
export interface SceneDetectionService { export interface SceneDetectionService {
method: Versioned<SceneDetectionMethod>; method: Versioned<SceneDetectionMethod>;
// init(): Promise<void>; // init(): Promise<void>;
detectScenes(image: ImageBitmap): Promise<ObjectDetection[]>; detectScenes(
image: ImageBitmap,
minScore: number
): Promise<ObjectDetection[]>;
} }
export interface TextDetectionService { export interface TextDetectionService {