integrated scene detection

Rushikesh Tote 2022-05-30 20:16:52 +05:30
parent 2521073b04
commit e835010716
8 changed files with 215 additions and 102 deletions


@@ -52,6 +52,9 @@ export const DEFAULT_ML_SYNC_CONFIG: MLSyncConfig = {
         method: 'Tesseract',
         minAccuracy: 75,
     },
+    sceneDetection: {
+        method: 'Image-Scene',
+    },
     // tsne: {
     //     samples: 200,
     //     dim: 2,
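The new sceneDetection entry carries only a method name; unlike objectDetection there are no tunable thresholds here, since the score cutoff lives in the service itself (see MIN_SCENE_DETECTION_SCORE below). A minimal sketch of how a consumer reads the entry, typed against the SceneDetectionConfig interface this commit adds (the helper function is illustrative, not part of the commit):

import { MLSyncConfig } from 'types/machineLearning';

// Illustrative: DEFAULT_ML_SYNC_CONFIG above now satisfies this shape, so the
// configured method can be pulled out and handed to the factory.
function sceneMethodOf(config: MLSyncConfig) {
    return config.sceneDetection.method; // 'Image-Scene'
}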


@@ -1,30 +1,40 @@
 import React, { useEffect, useState } from 'react';
-import sceneDetectionService from 'services/machineLearning/sceneDetectionService';
+import sceneDetectionService from 'services/machineLearning/imageSceneService';
 
 function SceneDebug() {
-    const [selectedFile, setSelectedFile] = useState<File>(null);
+    const [selectedFiles, setSelectedFiles] = useState<File[]>(null);
 
     const changeHandler = (event: React.ChangeEvent<HTMLInputElement>) => {
-        setSelectedFile(event.target.files[0]);
+        setSelectedFiles([...event.target.files]);
     };
 
     const handleSubmission = async () => {
-        await sceneDetectionService.init();
-        await sceneDetectionService.run(selectedFile);
+        for (const file of selectedFiles) {
+            await sceneDetectionService.detectByFile(file);
+        }
+        console.log('done with scene detection');
     };
 
     useEffect(() => {
-        console.log(selectedFile);
-    }, [selectedFile]);
+        console.log('loaded', selectedFiles);
+    }, [selectedFiles]);
 
     return (
         <div>
-            <input type="file" name="file" onChange={changeHandler} />
+            <input
+                type="file"
+                name="file"
+                multiple={true}
+                onChange={changeHandler}
+            />
             <div>
                 <button onClick={handleSubmission}>Submit</button>
             </div>
-            {selectedFile && (
-                <img src={URL.createObjectURL(selectedFile)} width={'400px'} />
+            {selectedFiles?.length > 0 && (
+                <img
+                    src={URL.createObjectURL(selectedFiles[0])}
+                    width={'400px'}
+                />
             )}
         </div>
     );
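The debug page now accepts multiple files and runs detection on them sequentially. Note that event.target.files is a FileList rather than an array, hence the spread into File[] before it is stored in state; the same pattern outside React looks like this (the helper name is an assumption):

import sceneDetectionService from 'services/machineLearning/imageSceneService';

// Illustrative: spread the live FileList into a plain File[] and run scene
// detection one file at a time, as the debug page's submit handler does.
async function runSceneDetectionOnInput(input: HTMLInputElement) {
    const files = [...(input.files ?? [])];
    for (const file of files) {
        await sceneDetectionService.detectByFile(file); // logs scenes per file
    }
}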


@@ -0,0 +1,116 @@
+import * as tf from '@tensorflow/tfjs';
+
+import {
+    ObjectDetection,
+    SceneDetectionMethod,
+    SceneDetectionService,
+    Versioned,
+} from 'types/machineLearning';
+
+import sceneMap from 'utils/machineLearning/sceneMap';
+
+const MIN_SCENE_DETECTION_SCORE = 0.1;
+
+class ImageScene implements SceneDetectionService {
+    method: Versioned<SceneDetectionMethod>;
+    model: tf.GraphModel;
+
+    public constructor() {
+        this.method = {
+            value: 'Image-Scene',
+            version: 1,
+        };
+    }
+
+    private async init() {
+        if (this.model) {
+            return;
+        }
+
+        const model = await tf.loadGraphModel('/models/imagescene/model.json');
+        console.log('loaded image-scene model', model, tf.getBackend());
+
+        this.model = model;
+    }
+
+    async detectByFile(file: File) {
+        const bmp = await createImageBitmap(file);
+
+        await tf.ready();
+
+        if (!this.model) {
+            await this.init();
+        }
+
+        const currTime = new Date().getTime();
+        const output = tf.tidy(() => {
+            let tensor = tf.browser.fromPixels(bmp);
+
+            tensor = tf.image.resizeBilinear(tensor, [224, 224]);
+            tensor = tf.expandDims(tensor);
+            tensor = tf.cast(tensor, 'float32');
+
+            const output = this.model.predict(tensor, {
+                verbose: true,
+            });
+
+            return output;
+        });
+        console.log('done in', new Date().getTime() - currTime, 'ms');
+
+        const data = await (output as tf.Tensor).data();
+        const scenes = this.getScenes(
+            data as Float32Array,
+            bmp.width,
+            bmp.height
+        );
+        console.log(`scenes for ${file.name}`, scenes);
+    }
+
+    async detectScenes(image: ImageBitmap) {
+        await tf.ready();
+
+        if (!this.model) {
+            await this.init();
+        }
+
+        const output = tf.tidy(() => {
+            let tensor = tf.browser.fromPixels(image);
+
+            tensor = tf.image.resizeBilinear(tensor, [224, 224]);
+            tensor = tf.expandDims(tensor);
+            tensor = tf.cast(tensor, 'float32');
+
+            const output = this.model.predict(tensor);
+
+            return output;
+        });
+
+        const data = await (output as tf.Tensor).data();
+        const scenes = this.getScenes(
+            data as Float32Array,
+            image.width,
+            image.height
+        );
+
+        return scenes;
+    }
+
+    private getScenes(
+        outputData: Float32Array,
+        width: number,
+        height: number
+    ): ObjectDetection[] {
+        const scenes = [];
+        for (let i = 0; i < outputData.length; i++) {
+            if (outputData[i] >= MIN_SCENE_DETECTION_SCORE) {
+                scenes.push({
+                    class: sceneMap.get(i),
+                    score: outputData[i],
+                    bbox: [0, 0, width, height],
+                });
+            }
+        }
+        return scenes;
+    }
+}
+
+export default new ImageScene();
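This new module replaces the standalone service deleted later in this commit: the model is now lazily loaded on first use, detections implement the shared ObjectDetection shape with a bbox spanning the whole image, and the per-class threshold drops from 0.25 to 0.1. detectByFile mirrors detectScenes but adds timing and per-file logging for the debug page. An illustrative caller (the blob source and names are assumptions, not from this commit):

import imageSceneService from 'services/machineLearning/imageSceneService';

// Illustrative: run scene detection on an arbitrary image Blob and keep only
// fairly confident scenes. The 0.5 cutoff is an example; the service itself
// already filters at MIN_SCENE_DETECTION_SCORE (0.1).
async function confidentScenes(blob: Blob) {
    const bitmap = await createImageBitmap(blob);
    const scenes = await imageSceneService.detectScenes(bitmap);
    bitmap.close(); // the bitmap is no longer needed once tensors are built
    return scenes.filter((scene) => scene.score >= 0.5);
}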


@@ -19,6 +19,8 @@ import {
     ObjectDetectionMethod,
     TextDetectionMethod,
     TextDetectionService,
+    SceneDetectionService,
+    SceneDetectionMethod,
 } from 'types/machineLearning';
 import { CONCURRENCY } from 'utils/common/concurrency';
 import { ComlinkWorker, getDedicatedCryptoWorker } from 'utils/crypto';
@@ -31,6 +33,7 @@ import mobileFaceNetEmbeddingService from './mobileFaceNetEmbeddingService';
 import dbscanClusteringService from './dbscanClusteringService';
 import ssdMobileNetV2Service from './ssdMobileNetV2Service';
 import tesseractService from './tesseractService';
+import imageSceneService from './imageSceneService';
 
 export class MLFactory {
     public static getFaceDetectionService(
@@ -53,6 +56,16 @@ export class MLFactory {
         throw Error('Unknown object detection method: ' + method);
     }
 
+    public static getSceneDetectionService(
+        method: SceneDetectionMethod
+    ): SceneDetectionService {
+        if (method === 'Image-Scene') {
+            return imageSceneService;
+        }
+
+        throw Error('Unknown scene detection method: ' + method);
+    }
+
     public static getTextDetectionService(
         method: TextDetectionMethod
     ): TextDetectionService {
@@ -124,6 +137,7 @@ export class LocalMLSyncContext implements MLSyncContext {
     public faceEmbeddingService: FaceEmbeddingService;
     public faceClusteringService: ClusteringService;
     public objectDetectionService: ObjectDetectionService;
+    public sceneDetectionService: SceneDetectionService;
     public textDetectionService: TextDetectionService;
 
     public localFilesMap: Map<number, EnteFile>;
@@ -174,6 +188,10 @@ export class LocalMLSyncContext implements MLSyncContext {
         this.objectDetectionService = MLFactory.getObjectDetectionService(
             this.config.objectDetection.method
         );
+        this.sceneDetectionService = MLFactory.getSceneDetectionService(
+            this.config.sceneDetection.method
+        );
         this.textDetectionService = MLFactory.getTextDetectionService(
             this.config.textDetection.method
         );
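LocalMLSyncContext resolves the service once per sync context through the factory, keyed by the configured method. The same lookup in isolation (the import path is assumed, since this page does not show file names; any method other than 'Image-Scene' makes the factory throw):

// Path assumed from the relative imports above; not shown on this page.
import { MLFactory } from 'services/machineLearning/machineLearningFactory';

// Illustrative: resolve and run the scene detection service directly.
async function detectScenesFor(imageBitmap: ImageBitmap) {
    const service = MLFactory.getSceneDetectionService('Image-Scene');
    return service.detectScenes(imageBitmap);
}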


@@ -24,17 +24,25 @@ class ObjectService {
                 oldMlFile?.objectDetectionMethod,
                 syncContext.objectDetectionService.method
             ) &&
+            !isDifferentOrOld(
+                oldMlFile?.sceneDetectionMethod,
+                syncContext.sceneDetectionService.method
+            ) &&
             oldMlFile?.imageSource === syncContext.config.imageSource
         ) {
             newMlFile.things = oldMlFile?.things;
             newMlFile.imageSource = oldMlFile.imageSource;
             newMlFile.imageDimensions = oldMlFile.imageDimensions;
             newMlFile.objectDetectionMethod = oldMlFile.objectDetectionMethod;
+            newMlFile.sceneDetectionMethod = oldMlFile.sceneDetectionMethod;
             return;
         }
 
         newMlFile.objectDetectionMethod =
             syncContext.objectDetectionService.method;
+        newMlFile.sceneDetectionMethod =
+            syncContext.sceneDetectionService.method;
         fileContext.newDetection = true;
         const imageBitmap = await ReaderService.getImageBitmap(
             syncContext,
@@ -46,6 +54,11 @@ class ObjectService {
             syncContext.config.objectDetection.maxNumBoxes,
             syncContext.config.objectDetection.minScore
         );
+        objectDetections.push(
+            ...(await syncContext.sceneDetectionService.detectScenes(
+                imageBitmap
+            ))
+        );
         // console.log('3 TF Memory stats: ', tf.memory());
 
         // TODO: reenable faces filtering based on width
         const detectedObjects = objectDetections?.map((detection) => {
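Scene results are simply appended to the SSD object detections, so they ride the existing things pipeline unchanged; this works because getScenes emits the ObjectDetection shape with a bbox covering the entire image. A sketch of that conversion in isolation (DetectionLike approximates ObjectDetection, whose declaration is not shown on this page):

// Approximation of the shared detection shape. Scene 'detections' always
// span the full image instead of a localized box.
interface DetectionLike {
    class: string;
    score: number;
    bbox: [number, number, number, number];
}

function sceneToDetection(
    label: string,
    score: number,
    width: number,
    height: number
): DetectionLike {
    return { class: label, score, bbox: [0, 0, width, height] };
}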


@@ -1,62 +0,0 @@
-import * as tf from '@tensorflow/tfjs';
-import '@tensorflow/tfjs-backend-webgl';
-import '@tensorflow/tfjs-backend-cpu';
-
-import sceneMap from 'utils/machineLearning/sceneMap';
-
-const MIN_SCENE_DETECTION_SCORE = 0.25;
-
-class SceneDetectionService {
-    model: tf.GraphModel;
-
-    async init() {
-        if (this.model) {
-            return;
-        }
-
-        const model = await tf.loadGraphModel('/models/imagescene/model.json');
-        console.log('loaded image-scene model', model, tf.getBackend());
-
-        this.model = model;
-    }
-
-    async run(file: File) {
-        const bmp = await createImageBitmap(file);
-
-        await tf.ready();
-
-        const currTime = new Date().getTime();
-        const output = tf.tidy(() => {
-            let tensor = tf.browser.fromPixels(bmp);
-
-            tensor = tf.image.resizeBilinear(tensor, [224, 224]);
-            tensor = tf.expandDims(tensor);
-            tensor = tf.cast(tensor, 'float32');
-
-            const output = this.model.predict(tensor, {
-                verbose: true,
-            });
-
-            return output;
-        });
-        console.log('done in', new Date().getTime() - currTime, 'ms');
-
-        const data = await (output as tf.Tensor).data();
-        const scenes = this.getScenes(data as Float32Array);
-        console.log('scenes', scenes);
-    }
-
-    getScenes(outputData: Float32Array) {
-        const scenes = [];
-        for (let i = 0; i < outputData.length; i++) {
-            if (outputData[i] >= MIN_SCENE_DETECTION_SCORE) {
-                scenes.push({
-                    name: sceneMap.get(i),
-                    score: outputData[i],
-                });
-            }
-        }
-        return scenes;
-    }
-}
-
-export default new SceneDetectionService();


@@ -95,6 +95,8 @@ export declare type FaceDetectionMethod = 'BlazeFace' | 'FaceApiSSD';
 
 export declare type ObjectDetectionMethod = 'SSDMobileNetV2';
 
+export declare type SceneDetectionMethod = 'Image-Scene';
+
 export declare type TextDetectionMethod = 'Tesseract';
 
 export declare type FaceCropMethod = 'ArcFace';
@@ -233,6 +235,7 @@ export interface MlFileData {
     faceAlignmentMethod?: Versioned<FaceAlignmentMethod>;
     faceEmbeddingMethod?: Versioned<FaceEmbeddingMethod>;
     objectDetectionMethod?: Versioned<ObjectDetectionMethod>;
+    sceneDetectionMethod?: Versioned<SceneDetectionMethod>;
     textDetectionMethod?: Versioned<TextDetectionMethod>;
     mlVersion: number;
     errorCount: number;
@@ -250,6 +253,10 @@ export interface ObjectDetectionConfig {
     minScore: number;
 }
 
+export interface SceneDetectionConfig {
+    method: SceneDetectionMethod;
+}
+
 export interface TextDetectionConfig {
     method: TextDetectionMethod;
     minAccuracy: number;
@@ -299,6 +306,7 @@ export interface MLSyncConfig extends Config {
     faceEmbedding: FaceEmbeddingConfig;
     faceClustering: FaceClusteringConfig;
     objectDetection: ObjectDetectionConfig;
+    sceneDetection: SceneDetectionConfig;
     textDetection: TextDetectionConfig;
     tsne?: TSNEConfig;
     mlVersion: number;
@@ -319,6 +327,7 @@ export interface MLSyncContext {
     faceEmbeddingService: FaceEmbeddingService;
     faceClusteringService: ClusteringService;
     objectDetectionService?: ObjectDetectionService;
+    sceneDetectionService?: SceneDetectionService;
     textDetectionService?: TextDetectionService;
 
     localFilesMap: Map<number, EnteFile>;
@@ -381,6 +390,12 @@ export interface ObjectDetectionService {
     dispose(): Promise<void>;
 }
 
+export interface SceneDetectionService {
+    method: Versioned<SceneDetectionMethod>;
+    // init(): Promise<void>;
+    detectScenes(image: ImageBitmap): Promise<ObjectDetection[]>;
+}
+
 export interface TextDetectionService {
     method: Versioned<TextDetectionMethod>;
     // init(): Promise<void>;
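The commented-out init() mirrors the other service interfaces: initialization is an internal detail, so the contract is just a versioned method descriptor plus detectScenes. A minimal stub that satisfies the new interface, to show what an alternative implementation would need (the stub itself is illustrative):

import {
    ObjectDetection,
    SceneDetectionMethod,
    SceneDetectionService,
    Versioned,
} from 'types/machineLearning';

// Illustrative no-op implementation of the contract added above.
class NoopSceneDetection implements SceneDetectionService {
    method: Versioned<SceneDetectionMethod> = {
        value: 'Image-Scene',
        version: 1,
    };

    async detectScenes(_image: ImageBitmap): Promise<ObjectDetection[]> {
        return []; // a real implementation returns one entry per likely scene
    }
}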


@@ -1,34 +1,34 @@
 const sceneMap = new Map([
-    [0, 'Waterfall'],
-    [1, 'Snow'],
-    [2, 'Landscape'],
-    [3, 'Underwater'],
-    [4, 'Architecture'],
-    [5, 'Sunset Sunrise'],
-    [6, 'Blue Sky'],
-    [7, 'Cloudy Sky'],
-    [8, 'Greenery'],
-    [9, 'Autumn Leaves'],
-    [10, 'Potrait'],
-    [11, 'Flower'],
-    [12, 'Night Shot'],
-    [13, 'Stage Concert'],
-    [14, 'Fireworks'],
-    [15, 'Candle Light'],
-    [16, 'Neon Lights'],
-    [17, 'Indoor'],
-    [18, 'Backlight'],
-    [19, 'Text Documents'],
-    [20, 'QR Images'],
-    [21, 'Group Potrait'],
-    [22, 'Computer Screens'],
-    [23, 'Kids'],
-    [24, 'Dog'],
-    [25, 'Cat'],
-    [26, 'Macro'],
-    [27, 'Food'],
-    [28, 'Beach'],
-    [29, 'Mountain'],
+    [0, 'waterfall'],
+    [1, 'snow'],
+    [2, 'landscape'],
+    [3, 'underwater'],
+    [4, 'architecture'],
+    [5, 'sunset / sunrise'],
+    [6, 'blue sky'],
+    [7, 'cloudy sky'],
+    [8, 'greenery'],
+    [9, 'autumn leaves'],
+    [10, 'portrait'],
+    [11, 'flower'],
+    [12, 'night shot'],
+    [13, 'stage concert'],
+    [14, 'fireworks'],
+    [15, 'candle light'],
+    [16, 'neon lights'],
+    [17, 'indoor'],
+    [18, 'backlight'],
+    [19, 'text documents'],
+    [20, 'qr images'],
+    [21, 'group portrait'],
+    [22, 'computer screens'],
+    [23, 'kids'],
+    [24, 'dog'],
+    [25, 'cat'],
+    [26, 'macro'],
+    [27, 'food'],
+    [28, 'beach'],
+    [29, 'mountain'],
 ]);
 
 export default sceneMap;
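The map translates the model's 30-way output vector into human-readable labels, now lowercased, presumably to match the casing of the object detection classes this commit merges them with. An illustrative helper showing that lookup (topScenes is not part of this commit):

import sceneMap from 'utils/machineLearning/sceneMap';

// Illustrative: rank the 30 class scores emitted by the image-scene model
// and return the k best label/score pairs.
function topScenes(output: Float32Array, k = 3) {
    return [...output]
        .map((score, index) => ({ label: sceneMap.get(index), score }))
        .sort((a, b) => b.score - a.score)
        .slice(0, k);
}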