integrated scene detection
This commit is contained in:
parent
2521073b04
commit
e835010716
|
@ -52,6 +52,9 @@ export const DEFAULT_ML_SYNC_CONFIG: MLSyncConfig = {
|
|||
method: 'Tesseract',
|
||||
minAccuracy: 75,
|
||||
},
|
||||
sceneDetection: {
|
||||
method: 'Image-Scene',
|
||||
},
|
||||
// tsne: {
|
||||
// samples: 200,
|
||||
// dim: 2,
|
||||
|
|
|
@ -1,30 +1,40 @@
|
|||
import React, { useEffect, useState } from 'react';
|
||||
import sceneDetectionService from 'services/machineLearning/sceneDetectionService';
|
||||
import sceneDetectionService from 'services/machineLearning/imageSceneService';
|
||||
|
||||
function SceneDebug() {
|
||||
const [selectedFile, setSelectedFile] = useState<File>(null);
|
||||
const [selectedFiles, setSelectedFiles] = useState<File[]>(null);
|
||||
|
||||
const changeHandler = (event: React.ChangeEvent<HTMLInputElement>) => {
|
||||
setSelectedFile(event.target.files[0]);
|
||||
setSelectedFiles([...event.target.files]);
|
||||
};
|
||||
|
||||
const handleSubmission = async () => {
|
||||
await sceneDetectionService.init();
|
||||
await sceneDetectionService.run(selectedFile);
|
||||
for (const file of selectedFiles) {
|
||||
await sceneDetectionService.detectByFile(file);
|
||||
}
|
||||
console.log('done with scene detection');
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
console.log(selectedFile);
|
||||
}, [selectedFile]);
|
||||
console.log('loaded', selectedFiles);
|
||||
}, [selectedFiles]);
|
||||
|
||||
return (
|
||||
<div>
|
||||
<input type="file" name="file" onChange={changeHandler} />
|
||||
<input
|
||||
type="file"
|
||||
name="file"
|
||||
multiple={true}
|
||||
onChange={changeHandler}
|
||||
/>
|
||||
<div>
|
||||
<button onClick={handleSubmission}>Submit</button>
|
||||
</div>
|
||||
{selectedFile && (
|
||||
<img src={URL.createObjectURL(selectedFile)} width={'400px'} />
|
||||
{selectedFiles?.length > 0 && (
|
||||
<img
|
||||
src={URL.createObjectURL(selectedFiles[0])}
|
||||
width={'400px'}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
|
|
116
src/services/machineLearning/imageSceneService.ts
Normal file
116
src/services/machineLearning/imageSceneService.ts
Normal file
|
@ -0,0 +1,116 @@
|
|||
import * as tf from '@tensorflow/tfjs';
|
||||
import {
|
||||
ObjectDetection,
|
||||
SceneDetectionMethod,
|
||||
SceneDetectionService,
|
||||
Versioned,
|
||||
} from 'types/machineLearning';
|
||||
import sceneMap from 'utils/machineLearning/sceneMap';
|
||||
|
||||
const MIN_SCENE_DETECTION_SCORE = 0.1;
|
||||
|
||||
class ImageScene implements SceneDetectionService {
|
||||
method: Versioned<SceneDetectionMethod>;
|
||||
model: tf.GraphModel;
|
||||
|
||||
public constructor() {
|
||||
this.method = {
|
||||
value: 'Image-Scene',
|
||||
version: 1,
|
||||
};
|
||||
}
|
||||
|
||||
private async init() {
|
||||
if (this.model) {
|
||||
return;
|
||||
}
|
||||
|
||||
const model = await tf.loadGraphModel('/models/imagescene/model.json');
|
||||
console.log('loaded image-scene model', model, tf.getBackend());
|
||||
this.model = model;
|
||||
}
|
||||
|
||||
async detectByFile(file: File) {
|
||||
const bmp = await createImageBitmap(file);
|
||||
|
||||
await tf.ready();
|
||||
|
||||
if (!this.model) {
|
||||
await this.init();
|
||||
}
|
||||
|
||||
const currTime = new Date().getTime();
|
||||
const output = tf.tidy(() => {
|
||||
let tensor = tf.browser.fromPixels(bmp);
|
||||
|
||||
tensor = tf.image.resizeBilinear(tensor, [224, 224]);
|
||||
tensor = tf.expandDims(tensor);
|
||||
tensor = tf.cast(tensor, 'float32');
|
||||
|
||||
const output = this.model.predict(tensor, {
|
||||
verbose: true,
|
||||
});
|
||||
|
||||
return output;
|
||||
});
|
||||
|
||||
console.log('done in', new Date().getTime() - currTime, 'ms');
|
||||
|
||||
const data = await (output as tf.Tensor).data();
|
||||
const scenes = this.getScenes(
|
||||
data as Float32Array,
|
||||
bmp.width,
|
||||
bmp.height
|
||||
);
|
||||
console.log(`scenes for ${file.name}`, scenes);
|
||||
}
|
||||
|
||||
async detectScenes(image: ImageBitmap) {
|
||||
await tf.ready();
|
||||
|
||||
if (!this.model) {
|
||||
await this.init();
|
||||
}
|
||||
|
||||
const output = tf.tidy(() => {
|
||||
let tensor = tf.browser.fromPixels(image);
|
||||
|
||||
tensor = tf.image.resizeBilinear(tensor, [224, 224]);
|
||||
tensor = tf.expandDims(tensor);
|
||||
tensor = tf.cast(tensor, 'float32');
|
||||
|
||||
const output = this.model.predict(tensor);
|
||||
|
||||
return output;
|
||||
});
|
||||
|
||||
const data = await (output as tf.Tensor).data();
|
||||
const scenes = this.getScenes(
|
||||
data as Float32Array,
|
||||
image.width,
|
||||
image.height
|
||||
);
|
||||
|
||||
return scenes;
|
||||
}
|
||||
|
||||
private getScenes(
|
||||
outputData: Float32Array,
|
||||
width: number,
|
||||
height: number
|
||||
): ObjectDetection[] {
|
||||
const scenes = [];
|
||||
for (let i = 0; i < outputData.length; i++) {
|
||||
if (outputData[i] >= MIN_SCENE_DETECTION_SCORE) {
|
||||
scenes.push({
|
||||
class: sceneMap.get(i),
|
||||
score: outputData[i],
|
||||
bbox: [0, 0, width, height],
|
||||
});
|
||||
}
|
||||
}
|
||||
return scenes;
|
||||
}
|
||||
}
|
||||
|
||||
export default new ImageScene();
|
|
@ -19,6 +19,8 @@ import {
|
|||
ObjectDetectionMethod,
|
||||
TextDetectionMethod,
|
||||
TextDetectionService,
|
||||
SceneDetectionService,
|
||||
SceneDetectionMethod,
|
||||
} from 'types/machineLearning';
|
||||
import { CONCURRENCY } from 'utils/common/concurrency';
|
||||
import { ComlinkWorker, getDedicatedCryptoWorker } from 'utils/crypto';
|
||||
|
@ -31,6 +33,7 @@ import mobileFaceNetEmbeddingService from './mobileFaceNetEmbeddingService';
|
|||
import dbscanClusteringService from './dbscanClusteringService';
|
||||
import ssdMobileNetV2Service from './ssdMobileNetV2Service';
|
||||
import tesseractService from './tesseractService';
|
||||
import imageSceneService from './imageSceneService';
|
||||
|
||||
export class MLFactory {
|
||||
public static getFaceDetectionService(
|
||||
|
@ -53,6 +56,16 @@ export class MLFactory {
|
|||
throw Error('Unknown object detection method: ' + method);
|
||||
}
|
||||
|
||||
/**
 * Resolves a scene detection method name to its service singleton.
 *
 * @throws when the method is not a known scene detection backend.
 */
public static getSceneDetectionService(
    method: SceneDetectionMethod
): SceneDetectionService {
    switch (method) {
        case 'Image-Scene':
            return imageSceneService;
        default:
            throw Error('Unknown scene detection method: ' + method);
    }
}
|
||||
|
||||
public static getTextDetectionService(
|
||||
method: TextDetectionMethod
|
||||
): TextDetectionService {
|
||||
|
@ -124,6 +137,7 @@ export class LocalMLSyncContext implements MLSyncContext {
|
|||
public faceEmbeddingService: FaceEmbeddingService;
|
||||
public faceClusteringService: ClusteringService;
|
||||
public objectDetectionService: ObjectDetectionService;
|
||||
public sceneDetectionService: SceneDetectionService;
|
||||
public textDetectionService: TextDetectionService;
|
||||
|
||||
public localFilesMap: Map<number, EnteFile>;
|
||||
|
@ -174,6 +188,10 @@ export class LocalMLSyncContext implements MLSyncContext {
|
|||
this.objectDetectionService = MLFactory.getObjectDetectionService(
|
||||
this.config.objectDetection.method
|
||||
);
|
||||
this.sceneDetectionService = MLFactory.getSceneDetectionService(
|
||||
this.config.sceneDetection.method
|
||||
);
|
||||
|
||||
this.textDetectionService = MLFactory.getTextDetectionService(
|
||||
this.config.textDetection.method
|
||||
);
|
||||
|
|
|
@ -24,17 +24,25 @@ class ObjectService {
|
|||
oldMlFile?.objectDetectionMethod,
|
||||
syncContext.objectDetectionService.method
|
||||
) &&
|
||||
!isDifferentOrOld(
|
||||
oldMlFile?.sceneDetectionMethod,
|
||||
syncContext.sceneDetectionService.method
|
||||
) &&
|
||||
oldMlFile?.imageSource === syncContext.config.imageSource
|
||||
) {
|
||||
newMlFile.things = oldMlFile?.things;
|
||||
newMlFile.imageSource = oldMlFile.imageSource;
|
||||
newMlFile.imageDimensions = oldMlFile.imageDimensions;
|
||||
newMlFile.objectDetectionMethod = oldMlFile.objectDetectionMethod;
|
||||
newMlFile.sceneDetectionMethod = oldMlFile.sceneDetectionMethod;
|
||||
return;
|
||||
}
|
||||
|
||||
newMlFile.objectDetectionMethod =
|
||||
syncContext.objectDetectionService.method;
|
||||
newMlFile.sceneDetectionMethod =
|
||||
syncContext.sceneDetectionService.method;
|
||||
|
||||
fileContext.newDetection = true;
|
||||
const imageBitmap = await ReaderService.getImageBitmap(
|
||||
syncContext,
|
||||
|
@ -46,6 +54,11 @@ class ObjectService {
|
|||
syncContext.config.objectDetection.maxNumBoxes,
|
||||
syncContext.config.objectDetection.minScore
|
||||
);
|
||||
objectDetections.push(
|
||||
...(await syncContext.sceneDetectionService.detectScenes(
|
||||
imageBitmap
|
||||
))
|
||||
);
|
||||
// console.log('3 TF Memory stats: ', tf.memory());
|
||||
// TODO: reenable faces filtering based on width
|
||||
const detectedObjects = objectDetections?.map((detection) => {
|
||||
|
|
|
@ -1,62 +0,0 @@
|
|||
import * as tf from '@tensorflow/tfjs';
|
||||
import '@tensorflow/tfjs-backend-webgl';
|
||||
import '@tensorflow/tfjs-backend-cpu';
|
||||
import sceneMap from 'utils/machineLearning/sceneMap';
|
||||
|
||||
// Scenes scoring below this threshold are dropped from the results.
const MIN_SCENE_DETECTION_SCORE = 0.25;

// Legacy debug-only scene detection service (this commit deletes it in
// favour of imageSceneService, which implements the shared
// SceneDetectionService interface).
class SceneDetectionService {
    model: tf.GraphModel;

    // Loads the image-scene graph model once; a no-op on later calls.
    async init() {
        if (this.model) {
            return;
        }

        const model = await tf.loadGraphModel('/models/imagescene/model.json');
        console.log('loaded image-scene model', model, tf.getBackend());
        this.model = model;
    }

    // Runs scene detection on a picked File and logs the scenes and timing.
    // NOTE(review): run() never loads the model itself — callers must call
    // init() first (the SceneDebug component does exactly that).
    // NOTE(review): the tensor returned out of tf.tidy is never disposed
    // here, so each run leaks the output tensor.
    async run(file: File) {
        const bmp = await createImageBitmap(file);

        await tf.ready();

        const currTime = new Date().getTime();
        const output = tf.tidy(() => {
            let tensor = tf.browser.fromPixels(bmp);

            // The model expects a single 224x224 float32 image batch.
            tensor = tf.image.resizeBilinear(tensor, [224, 224]);
            tensor = tf.expandDims(tensor);
            tensor = tf.cast(tensor, 'float32');

            const output = this.model.predict(tensor, {
                verbose: true,
            });

            return output;
        });

        console.log('done in', new Date().getTime() - currTime, 'ms');

        const data = await (output as tf.Tensor).data();
        const scenes = this.getScenes(data as Float32Array);
        console.log('scenes', scenes);
    }

    // Converts raw per-class scores into { name, score } entries for every
    // class scoring at least MIN_SCENE_DETECTION_SCORE, using sceneMap to
    // translate class indices into labels.
    getScenes(outputData: Float32Array) {
        const scenes = [];
        for (let i = 0; i < outputData.length; i++) {
            if (outputData[i] >= MIN_SCENE_DETECTION_SCORE) {
                scenes.push({
                    name: sceneMap.get(i),
                    score: outputData[i],
                });
            }
        }
        return scenes;
    }
}

export default new SceneDetectionService();
|
|
@ -95,6 +95,8 @@ export declare type FaceDetectionMethod = 'BlazeFace' | 'FaceApiSSD';
|
|||
|
||||
export declare type ObjectDetectionMethod = 'SSDMobileNetV2';
|
||||
|
||||
export declare type SceneDetectionMethod = 'Image-Scene';
|
||||
|
||||
export declare type TextDetectionMethod = 'Tesseract';
|
||||
|
||||
export declare type FaceCropMethod = 'ArcFace';
|
||||
|
@ -233,6 +235,7 @@ export interface MlFileData {
|
|||
faceAlignmentMethod?: Versioned<FaceAlignmentMethod>;
|
||||
faceEmbeddingMethod?: Versioned<FaceEmbeddingMethod>;
|
||||
objectDetectionMethod?: Versioned<ObjectDetectionMethod>;
|
||||
sceneDetectionMethod?: Versioned<SceneDetectionMethod>;
|
||||
textDetectionMethod?: Versioned<TextDetectionMethod>;
|
||||
mlVersion: number;
|
||||
errorCount: number;
|
||||
|
@ -250,6 +253,10 @@ export interface ObjectDetectionConfig {
|
|||
minScore: number;
|
||||
}
|
||||
|
||||
/** Configuration for scene detection: selects which backend to run. */
export interface SceneDetectionConfig {
    method: SceneDetectionMethod;
}
|
||||
|
||||
export interface TextDetectionConfig {
|
||||
method: TextDetectionMethod;
|
||||
minAccuracy: number;
|
||||
|
@ -299,6 +306,7 @@ export interface MLSyncConfig extends Config {
|
|||
faceEmbedding: FaceEmbeddingConfig;
|
||||
faceClustering: FaceClusteringConfig;
|
||||
objectDetection: ObjectDetectionConfig;
|
||||
sceneDetection: SceneDetectionConfig;
|
||||
textDetection: TextDetectionConfig;
|
||||
tsne?: TSNEConfig;
|
||||
mlVersion: number;
|
||||
|
@ -319,6 +327,7 @@ export interface MLSyncContext {
|
|||
faceEmbeddingService: FaceEmbeddingService;
|
||||
faceClusteringService: ClusteringService;
|
||||
objectDetectionService?: ObjectDetectionService;
|
||||
sceneDetectionService?: SceneDetectionService;
|
||||
textDetectionService?: TextDetectionService;
|
||||
|
||||
localFilesMap: Map<number, EnteFile>;
|
||||
|
@ -381,6 +390,12 @@ export interface ObjectDetectionService {
|
|||
dispose(): Promise<void>;
|
||||
}
|
||||
|
||||
/**
 * Contract implemented by scene detection backends (currently only the
 * 'Image-Scene' TF.js model).
 */
export interface SceneDetectionService {
    // Versioned method identifier, so previously computed results can be
    // invalidated when the detection method or its version changes.
    method: Versioned<SceneDetectionMethod>;
    // init(): Promise<void>;
    // Runs scene detection on a decoded image and resolves with the
    // detected scenes as ObjectDetection entries.
    detectScenes(image: ImageBitmap): Promise<ObjectDetection[]>;
}
|
||||
|
||||
export interface TextDetectionService {
|
||||
method: Versioned<TextDetectionMethod>;
|
||||
// init(): Promise<void>;
|
||||
|
|
|
@ -1,34 +1,34 @@
|
|||
const sceneMap = new Map([
|
||||
[0, 'Waterfall'],
|
||||
[1, 'Snow'],
|
||||
[2, 'Landscape'],
|
||||
[3, 'Underwater'],
|
||||
[4, 'Architecture'],
|
||||
[5, 'Sunset Sunrise'],
|
||||
[6, 'Blue Sky'],
|
||||
[7, 'Cloudy Sky'],
|
||||
[8, 'Greenery'],
|
||||
[9, 'Autumn Leaves'],
|
||||
[10, 'Potrait'],
|
||||
[11, 'Flower'],
|
||||
[12, 'Night Shot'],
|
||||
[13, 'Stage Concert'],
|
||||
[14, 'Fireworks'],
|
||||
[15, 'Candle Light'],
|
||||
[16, 'Neon Lights'],
|
||||
[17, 'Indoor'],
|
||||
[18, 'Backlight'],
|
||||
[19, 'Text Documents'],
|
||||
[20, 'QR Images'],
|
||||
[21, 'Group Potrait'],
|
||||
[22, 'Computer Screens'],
|
||||
[23, 'Kids'],
|
||||
[24, 'Dog'],
|
||||
[25, 'Cat'],
|
||||
[26, 'Macro'],
|
||||
[27, 'Food'],
|
||||
[28, 'Beach'],
|
||||
[29, 'Mountain'],
|
||||
[0, 'waterfall'],
|
||||
[1, 'snow'],
|
||||
[2, 'landscape'],
|
||||
[3, 'underwater'],
|
||||
[4, 'architecture'],
|
||||
[5, 'sunset / sunrise'],
|
||||
[6, 'blue sky'],
|
||||
[7, 'cloudy sky'],
|
||||
[8, 'greenery'],
|
||||
[9, 'autumn leaves'],
|
||||
[10, 'potrait'],
|
||||
[11, 'flower'],
|
||||
[12, 'night shot'],
|
||||
[13, 'stage concert'],
|
||||
[14, 'fireworks'],
|
||||
[15, 'candle light'],
|
||||
[16, 'neon lights'],
|
||||
[17, 'indoor'],
|
||||
[18, 'backlight'],
|
||||
[19, 'text documents'],
|
||||
[20, 'qr images'],
|
||||
[21, 'group potrait'],
|
||||
[22, 'computer screens'],
|
||||
[23, 'kids'],
|
||||
[24, 'dog'],
|
||||
[25, 'cat'],
|
||||
[26, 'macro'],
|
||||
[27, 'food'],
|
||||
[28, 'beach'],
|
||||
[29, 'mountain'],
|
||||
]);
|
||||
|
||||
export default sceneMap;
|
||||
|
|
Loading…
Reference in a new issue