move scene minScore to config

This commit is contained in:
Rushikesh Tote 2022-05-30 20:21:12 +05:30
parent e835010716
commit f6d550d689
5 changed files with 21 additions and 44 deletions

View file

@ -54,6 +54,7 @@ export const DEFAULT_ML_SYNC_CONFIG: MLSyncConfig = {
},
sceneDetection: {
method: 'Image-Scene',
minScore: 0.1,
},
// tsne: {
// samples: 200,

View file

@ -10,7 +10,13 @@ function SceneDebug() {
const handleSubmission = async () => {
for (const file of selectedFiles) {
await sceneDetectionService.detectByFile(file);
console.log(
`scene detection for file ${file.name}`,
await sceneDetectionService.detectScenes(
await createImageBitmap(file),
0.1
)
);
}
console.log('done with scene detection');
};

View file

@ -7,8 +7,6 @@ import {
} from 'types/machineLearning';
import sceneMap from 'utils/machineLearning/sceneMap';
const MIN_SCENE_DETECTION_SCORE = 0.1;
class ImageScene implements SceneDetectionService {
method: Versioned<SceneDetectionMethod>;
model: tf.GraphModel;
@ -30,42 +28,7 @@ class ImageScene implements SceneDetectionService {
this.model = model;
}
async detectByFile(file: File) {
const bmp = await createImageBitmap(file);
await tf.ready();
if (!this.model) {
await this.init();
}
const currTime = new Date().getTime();
const output = tf.tidy(() => {
let tensor = tf.browser.fromPixels(bmp);
tensor = tf.image.resizeBilinear(tensor, [224, 224]);
tensor = tf.expandDims(tensor);
tensor = tf.cast(tensor, 'float32');
const output = this.model.predict(tensor, {
verbose: true,
});
return output;
});
console.log('done in', new Date().getTime() - currTime, 'ms');
const data = await (output as tf.Tensor).data();
const scenes = this.getScenes(
data as Float32Array,
bmp.width,
bmp.height
);
console.log(`scenes for ${file.name}`, scenes);
}
async detectScenes(image: ImageBitmap) {
async detectScenes(image: ImageBitmap, minScore: number) {
await tf.ready();
if (!this.model) {
@ -88,7 +51,8 @@ class ImageScene implements SceneDetectionService {
const scenes = this.getScenes(
data as Float32Array,
image.width,
image.height
image.height,
minScore
);
return scenes;
@ -97,11 +61,12 @@ class ImageScene implements SceneDetectionService {
private getScenes(
outputData: Float32Array,
width: number,
height: number
height: number,
minScore: number
): ObjectDetection[] {
const scenes = [];
for (let i = 0; i < outputData.length; i++) {
if (outputData[i] >= MIN_SCENE_DETECTION_SCORE) {
if (outputData[i] >= minScore) {
scenes.push({
class: sceneMap.get(i),
score: outputData[i],

View file

@ -56,7 +56,8 @@ class ObjectService {
);
objectDetections.push(
...(await syncContext.sceneDetectionService.detectScenes(
imageBitmap
imageBitmap,
syncContext.config.sceneDetection.minScore
))
);
// console.log('3 TF Memory stats: ', tf.memory());

View file

@ -255,6 +255,7 @@ export interface ObjectDetectionConfig {
export interface SceneDetectionConfig {
method: SceneDetectionMethod;
minScore: number;
}
export interface TextDetectionConfig {
@ -393,7 +394,10 @@ export interface ObjectDetectionService {
export interface SceneDetectionService {
method: Versioned<SceneDetectionMethod>;
// init(): Promise<void>;
detectScenes(image: ImageBitmap): Promise<ObjectDetection[]>;
detectScenes(
image: ImageBitmap,
minScore: number
): Promise<ObjectDetection[]>;
}
export interface TextDetectionService {