integrated scene detection

This commit is contained in:
Rushikesh Tote 2022-05-30 20:16:52 +05:30
parent 2521073b04
commit e835010716
8 changed files with 215 additions and 102 deletions

View file

@ -52,6 +52,9 @@ export const DEFAULT_ML_SYNC_CONFIG: MLSyncConfig = {
method: 'Tesseract', method: 'Tesseract',
minAccuracy: 75, minAccuracy: 75,
}, },
sceneDetection: {
method: 'Image-Scene',
},
// tsne: { // tsne: {
// samples: 200, // samples: 200,
// dim: 2, // dim: 2,

View file

@ -1,30 +1,40 @@
import React, { useEffect, useState } from 'react'; import React, { useEffect, useState } from 'react';
import sceneDetectionService from 'services/machineLearning/sceneDetectionService'; import sceneDetectionService from 'services/machineLearning/imageSceneService';
function SceneDebug() { function SceneDebug() {
const [selectedFile, setSelectedFile] = useState<File>(null); const [selectedFiles, setSelectedFiles] = useState<File[]>(null);
const changeHandler = (event: React.ChangeEvent<HTMLInputElement>) => { const changeHandler = (event: React.ChangeEvent<HTMLInputElement>) => {
setSelectedFile(event.target.files[0]); setSelectedFiles([...event.target.files]);
}; };
const handleSubmission = async () => { const handleSubmission = async () => {
await sceneDetectionService.init(); for (const file of selectedFiles) {
await sceneDetectionService.run(selectedFile); await sceneDetectionService.detectByFile(file);
}
console.log('done with scene detection');
}; };
useEffect(() => { useEffect(() => {
console.log(selectedFile); console.log('loaded', selectedFiles);
}, [selectedFile]); }, [selectedFiles]);
return ( return (
<div> <div>
<input type="file" name="file" onChange={changeHandler} /> <input
type="file"
name="file"
multiple={true}
onChange={changeHandler}
/>
<div> <div>
<button onClick={handleSubmission}>Submit</button> <button onClick={handleSubmission}>Submit</button>
</div> </div>
{selectedFile && ( {selectedFiles?.length > 0 && (
<img src={URL.createObjectURL(selectedFile)} width={'400px'} /> <img
src={URL.createObjectURL(selectedFiles[0])}
width={'400px'}
/>
)} )}
</div> </div>
); );

View file

@ -0,0 +1,116 @@
import * as tf from '@tensorflow/tfjs';
import {
ObjectDetection,
SceneDetectionMethod,
SceneDetectionService,
Versioned,
} from 'types/machineLearning';
import sceneMap from 'utils/machineLearning/sceneMap';
const MIN_SCENE_DETECTION_SCORE = 0.1;
class ImageScene implements SceneDetectionService {
    method: Versioned<SceneDetectionMethod>;
    model: tf.GraphModel;

    public constructor() {
        this.method = {
            value: 'Image-Scene',
            version: 1,
        };
    }

    /**
     * Lazily loads the image-scene graph model. No-op once loaded.
     *
     * NOTE(review): two concurrent first calls can both load the model;
     * cache the load promise instead of the model if that matters.
     */
    private async init() {
        if (this.model) {
            return;
        }
        const model = await tf.loadGraphModel('/models/imagescene/model.json');
        console.log('loaded image-scene model', model, tf.getBackend());
        this.model = model;
    }

    /**
     * Shared inference path (previously duplicated in detectByFile and
     * detectScenes): preprocesses the bitmap to the model's 224x224 float32
     * input, runs the model, and maps the raw score vector to detections.
     */
    private async runInference(image: ImageBitmap): Promise<ObjectDetection[]> {
        await tf.ready();
        if (!this.model) {
            await this.init();
        }
        const output = tf.tidy(() => {
            let tensor = tf.browser.fromPixels(image);
            tensor = tf.image.resizeBilinear(tensor, [224, 224]);
            tensor = tf.expandDims(tensor);
            tensor = tf.cast(tensor, 'float32');
            return this.model.predict(tensor) as tf.Tensor;
        });
        try {
            const data = await output.data();
            return this.getScenes(
                data as Float32Array,
                image.width,
                image.height
            );
        } finally {
            // tf.tidy does NOT dispose the tensor it returns — free it here,
            // otherwise every detection leaks backend (GPU/CPU) memory.
            output.dispose();
        }
    }

    /**
     * Debug entry point: runs scene detection on a File, logging timing and
     * the detected scenes. Now also returns the detections; callers that
     * ignored the previous void result are unaffected.
     */
    async detectByFile(file: File): Promise<ObjectDetection[]> {
        const bmp = await createImageBitmap(file);
        try {
            const startTime = Date.now();
            const scenes = await this.runInference(bmp);
            console.log('done in', Date.now() - startTime, 'ms');
            console.log(`scenes for ${file.name}`, scenes);
            return scenes;
        } finally {
            // Release the bitmap's backing store promptly.
            bmp.close();
        }
    }

    /** Detects scenes in a bitmap (SceneDetectionService contract). */
    async detectScenes(image: ImageBitmap): Promise<ObjectDetection[]> {
        return this.runInference(image);
    }

    /**
     * Converts the model's per-class score vector into ObjectDetection
     * entries, keeping classes scoring >= MIN_SCENE_DETECTION_SCORE.
     * Scene classification has no localization, so the bbox is the full
     * image extent.
     */
    private getScenes(
        outputData: Float32Array,
        width: number,
        height: number
    ): ObjectDetection[] {
        const scenes: ObjectDetection[] = [];
        for (let i = 0; i < outputData.length; i++) {
            if (outputData[i] >= MIN_SCENE_DETECTION_SCORE) {
                scenes.push({
                    class: sceneMap.get(i),
                    score: outputData[i],
                    bbox: [0, 0, width, height],
                });
            }
        }
        return scenes;
    }
}

export default new ImageScene();

View file

@ -19,6 +19,8 @@ import {
ObjectDetectionMethod, ObjectDetectionMethod,
TextDetectionMethod, TextDetectionMethod,
TextDetectionService, TextDetectionService,
SceneDetectionService,
SceneDetectionMethod,
} from 'types/machineLearning'; } from 'types/machineLearning';
import { CONCURRENCY } from 'utils/common/concurrency'; import { CONCURRENCY } from 'utils/common/concurrency';
import { ComlinkWorker, getDedicatedCryptoWorker } from 'utils/crypto'; import { ComlinkWorker, getDedicatedCryptoWorker } from 'utils/crypto';
@ -31,6 +33,7 @@ import mobileFaceNetEmbeddingService from './mobileFaceNetEmbeddingService';
import dbscanClusteringService from './dbscanClusteringService'; import dbscanClusteringService from './dbscanClusteringService';
import ssdMobileNetV2Service from './ssdMobileNetV2Service'; import ssdMobileNetV2Service from './ssdMobileNetV2Service';
import tesseractService from './tesseractService'; import tesseractService from './tesseractService';
import imageSceneService from './imageSceneService';
export class MLFactory { export class MLFactory {
public static getFaceDetectionService( public static getFaceDetectionService(
@ -53,6 +56,16 @@ export class MLFactory {
throw Error('Unknown object detection method: ' + method); throw Error('Unknown object detection method: ' + method);
} }
public static getSceneDetectionService(
method: SceneDetectionMethod
): SceneDetectionService {
if (method === 'Image-Scene') {
return imageSceneService;
}
throw Error('Unknown scene detection method: ' + method);
}
public static getTextDetectionService( public static getTextDetectionService(
method: TextDetectionMethod method: TextDetectionMethod
): TextDetectionService { ): TextDetectionService {
@ -124,6 +137,7 @@ export class LocalMLSyncContext implements MLSyncContext {
public faceEmbeddingService: FaceEmbeddingService; public faceEmbeddingService: FaceEmbeddingService;
public faceClusteringService: ClusteringService; public faceClusteringService: ClusteringService;
public objectDetectionService: ObjectDetectionService; public objectDetectionService: ObjectDetectionService;
public sceneDetectionService: SceneDetectionService;
public textDetectionService: TextDetectionService; public textDetectionService: TextDetectionService;
public localFilesMap: Map<number, EnteFile>; public localFilesMap: Map<number, EnteFile>;
@ -174,6 +188,10 @@ export class LocalMLSyncContext implements MLSyncContext {
this.objectDetectionService = MLFactory.getObjectDetectionService( this.objectDetectionService = MLFactory.getObjectDetectionService(
this.config.objectDetection.method this.config.objectDetection.method
); );
this.sceneDetectionService = MLFactory.getSceneDetectionService(
this.config.sceneDetection.method
);
this.textDetectionService = MLFactory.getTextDetectionService( this.textDetectionService = MLFactory.getTextDetectionService(
this.config.textDetection.method this.config.textDetection.method
); );

View file

@ -24,17 +24,25 @@ class ObjectService {
oldMlFile?.objectDetectionMethod, oldMlFile?.objectDetectionMethod,
syncContext.objectDetectionService.method syncContext.objectDetectionService.method
) && ) &&
!isDifferentOrOld(
oldMlFile?.sceneDetectionMethod,
syncContext.sceneDetectionService.method
) &&
oldMlFile?.imageSource === syncContext.config.imageSource oldMlFile?.imageSource === syncContext.config.imageSource
) { ) {
newMlFile.things = oldMlFile?.things; newMlFile.things = oldMlFile?.things;
newMlFile.imageSource = oldMlFile.imageSource; newMlFile.imageSource = oldMlFile.imageSource;
newMlFile.imageDimensions = oldMlFile.imageDimensions; newMlFile.imageDimensions = oldMlFile.imageDimensions;
newMlFile.objectDetectionMethod = oldMlFile.objectDetectionMethod; newMlFile.objectDetectionMethod = oldMlFile.objectDetectionMethod;
newMlFile.sceneDetectionMethod = oldMlFile.sceneDetectionMethod;
return; return;
} }
newMlFile.objectDetectionMethod = newMlFile.objectDetectionMethod =
syncContext.objectDetectionService.method; syncContext.objectDetectionService.method;
newMlFile.sceneDetectionMethod =
syncContext.sceneDetectionService.method;
fileContext.newDetection = true; fileContext.newDetection = true;
const imageBitmap = await ReaderService.getImageBitmap( const imageBitmap = await ReaderService.getImageBitmap(
syncContext, syncContext,
@ -46,6 +54,11 @@ class ObjectService {
syncContext.config.objectDetection.maxNumBoxes, syncContext.config.objectDetection.maxNumBoxes,
syncContext.config.objectDetection.minScore syncContext.config.objectDetection.minScore
); );
objectDetections.push(
...(await syncContext.sceneDetectionService.detectScenes(
imageBitmap
))
);
// console.log('3 TF Memory stats: ', tf.memory()); // console.log('3 TF Memory stats: ', tf.memory());
// TODO: reenable faces filtering based on width // TODO: reenable faces filtering based on width
const detectedObjects = objectDetections?.map((detection) => { const detectedObjects = objectDetections?.map((detection) => {

View file

@ -1,62 +0,0 @@
import * as tf from '@tensorflow/tfjs';
import '@tensorflow/tfjs-backend-webgl';
import '@tensorflow/tfjs-backend-cpu';
import sceneMap from 'utils/machineLearning/sceneMap';
const MIN_SCENE_DETECTION_SCORE = 0.25;
class SceneDetectionService {
    model: tf.GraphModel;

    // Loads the image-scene graph model; no-op if it is already loaded.
    async init() {
        if (this.model) {
            return;
        }
        const loaded = await tf.loadGraphModel('/models/imagescene/model.json');
        console.log('loaded image-scene model', loaded, tf.getBackend());
        this.model = loaded;
    }

    // Runs scene classification on a file and logs the matched scenes.
    async run(file: File) {
        const bitmap = await createImageBitmap(file);
        await tf.ready();
        const startedAt = new Date().getTime();
        const prediction = tf.tidy(() => {
            const input = tf.cast(
                tf.expandDims(
                    tf.image.resizeBilinear(tf.browser.fromPixels(bitmap), [
                        224, 224,
                    ])
                ),
                'float32'
            );
            return this.model.predict(input, {
                verbose: true,
            });
        });
        console.log('done in', new Date().getTime() - startedAt, 'ms');
        const scores = await (prediction as tf.Tensor).data();
        const matched = this.getScenes(scores as Float32Array);
        console.log('scenes', matched);
    }

    // Maps raw per-class scores to {name, score} entries above the threshold.
    getScenes(outputData: Float32Array) {
        const matched = [];
        for (let index = 0; index < outputData.length; index++) {
            const score = outputData[index];
            if (score >= MIN_SCENE_DETECTION_SCORE) {
                matched.push({
                    name: sceneMap.get(index),
                    score: score,
                });
            }
        }
        return matched;
    }
}

export default new SceneDetectionService();

View file

@ -95,6 +95,8 @@ export declare type FaceDetectionMethod = 'BlazeFace' | 'FaceApiSSD';
export declare type ObjectDetectionMethod = 'SSDMobileNetV2'; export declare type ObjectDetectionMethod = 'SSDMobileNetV2';
export declare type SceneDetectionMethod = 'Image-Scene';
export declare type TextDetectionMethod = 'Tesseract'; export declare type TextDetectionMethod = 'Tesseract';
export declare type FaceCropMethod = 'ArcFace'; export declare type FaceCropMethod = 'ArcFace';
@ -233,6 +235,7 @@ export interface MlFileData {
faceAlignmentMethod?: Versioned<FaceAlignmentMethod>; faceAlignmentMethod?: Versioned<FaceAlignmentMethod>;
faceEmbeddingMethod?: Versioned<FaceEmbeddingMethod>; faceEmbeddingMethod?: Versioned<FaceEmbeddingMethod>;
objectDetectionMethod?: Versioned<ObjectDetectionMethod>; objectDetectionMethod?: Versioned<ObjectDetectionMethod>;
sceneDetectionMethod?: Versioned<SceneDetectionMethod>;
textDetectionMethod?: Versioned<TextDetectionMethod>; textDetectionMethod?: Versioned<TextDetectionMethod>;
mlVersion: number; mlVersion: number;
errorCount: number; errorCount: number;
@ -250,6 +253,10 @@ export interface ObjectDetectionConfig {
minScore: number; minScore: number;
} }
export interface SceneDetectionConfig {
method: SceneDetectionMethod;
}
export interface TextDetectionConfig { export interface TextDetectionConfig {
method: TextDetectionMethod; method: TextDetectionMethod;
minAccuracy: number; minAccuracy: number;
@ -299,6 +306,7 @@ export interface MLSyncConfig extends Config {
faceEmbedding: FaceEmbeddingConfig; faceEmbedding: FaceEmbeddingConfig;
faceClustering: FaceClusteringConfig; faceClustering: FaceClusteringConfig;
objectDetection: ObjectDetectionConfig; objectDetection: ObjectDetectionConfig;
sceneDetection: SceneDetectionConfig;
textDetection: TextDetectionConfig; textDetection: TextDetectionConfig;
tsne?: TSNEConfig; tsne?: TSNEConfig;
mlVersion: number; mlVersion: number;
@ -319,6 +327,7 @@ export interface MLSyncContext {
faceEmbeddingService: FaceEmbeddingService; faceEmbeddingService: FaceEmbeddingService;
faceClusteringService: ClusteringService; faceClusteringService: ClusteringService;
objectDetectionService?: ObjectDetectionService; objectDetectionService?: ObjectDetectionService;
sceneDetectionService?: SceneDetectionService;
textDetectionService?: TextDetectionService; textDetectionService?: TextDetectionService;
localFilesMap: Map<number, EnteFile>; localFilesMap: Map<number, EnteFile>;
@ -381,6 +390,12 @@ export interface ObjectDetectionService {
dispose(): Promise<void>; dispose(): Promise<void>;
} }
export interface SceneDetectionService {
method: Versioned<SceneDetectionMethod>;
// init(): Promise<void>;
detectScenes(image: ImageBitmap): Promise<ObjectDetection[]>;
}
export interface TextDetectionService { export interface TextDetectionService {
method: Versioned<TextDetectionMethod>; method: Versioned<TextDetectionMethod>;
// init(): Promise<void>; // init(): Promise<void>;

View file

@ -1,34 +1,34 @@
const sceneMap = new Map([ const sceneMap = new Map([
[0, 'Waterfall'], [0, 'waterfall'],
[1, 'Snow'], [1, 'snow'],
[2, 'Landscape'], [2, 'landscape'],
[3, 'Underwater'], [3, 'underwater'],
[4, 'Architecture'], [4, 'architecture'],
[5, 'Sunset Sunrise'], [5, 'sunset / sunrise'],
[6, 'Blue Sky'], [6, 'blue sky'],
[7, 'Cloudy Sky'], [7, 'cloudy sky'],
[8, 'Greenery'], [8, 'greenery'],
[9, 'Autumn Leaves'], [9, 'autumn leaves'],
[10, 'Potrait'], [10, 'potrait'],
[11, 'Flower'], [11, 'flower'],
[12, 'Night Shot'], [12, 'night shot'],
[13, 'Stage Concert'], [13, 'stage concert'],
[14, 'Fireworks'], [14, 'fireworks'],
[15, 'Candle Light'], [15, 'candle light'],
[16, 'Neon Lights'], [16, 'neon lights'],
[17, 'Indoor'], [17, 'indoor'],
[18, 'Backlight'], [18, 'backlight'],
[19, 'Text Documents'], [19, 'text documents'],
[20, 'QR Images'], [20, 'qr images'],
[21, 'Group Potrait'], [21, 'group potrait'],
[22, 'Computer Screens'], [22, 'computer screens'],
[23, 'Kids'], [23, 'kids'],
[24, 'Dog'], [24, 'dog'],
[25, 'Cat'], [25, 'cat'],
[26, 'Macro'], [26, 'macro'],
[27, 'Food'], [27, 'food'],
[28, 'Beach'], [28, 'beach'],
[29, 'Mountain'], [29, 'mountain'],
]); ]);
export default sceneMap; export default sceneMap;