diff --git a/src/constants/machineLearning/config.ts b/src/constants/machineLearning/config.ts index 0cb3747d1..4355c9d88 100644 --- a/src/constants/machineLearning/config.ts +++ b/src/constants/machineLearning/config.ts @@ -52,6 +52,9 @@ export const DEFAULT_ML_SYNC_CONFIG: MLSyncConfig = { method: 'Tesseract', minAccuracy: 75, }, + sceneDetection: { + method: 'Image-Scene', + }, // tsne: { // samples: 200, // dim: 2, diff --git a/src/pages/scene-debug/index.tsx b/src/pages/scene-debug/index.tsx index 667e547cb..2a7153346 100644 --- a/src/pages/scene-debug/index.tsx +++ b/src/pages/scene-debug/index.tsx @@ -1,30 +1,40 @@ import React, { useEffect, useState } from 'react'; -import sceneDetectionService from 'services/machineLearning/sceneDetectionService'; +import sceneDetectionService from 'services/machineLearning/imageSceneService'; function SceneDebug() { - const [selectedFile, setSelectedFile] = useState(null); + const [selectedFiles, setSelectedFiles] = useState(null); const changeHandler = (event: React.ChangeEvent) => { - setSelectedFile(event.target.files[0]); + setSelectedFiles([...event.target.files]); }; const handleSubmission = async () => { - await sceneDetectionService.init(); - await sceneDetectionService.run(selectedFile); + for (const file of selectedFiles) { + await sceneDetectionService.detectByFile(file); + } + console.log('done with scene detection'); }; useEffect(() => { - console.log(selectedFile); - }, [selectedFile]); + console.log('loaded', selectedFiles); + }, [selectedFiles]); return (
- +
- {selectedFile && ( - + {selectedFiles?.length > 0 && ( + )}
); diff --git a/src/services/machineLearning/imageSceneService.ts b/src/services/machineLearning/imageSceneService.ts new file mode 100644 index 000000000..b092a9f91 --- /dev/null +++ b/src/services/machineLearning/imageSceneService.ts @@ -0,0 +1,116 @@ +import * as tf from '@tensorflow/tfjs'; +import { + ObjectDetection, + SceneDetectionMethod, + SceneDetectionService, + Versioned, +} from 'types/machineLearning'; +import sceneMap from 'utils/machineLearning/sceneMap'; + +const MIN_SCENE_DETECTION_SCORE = 0.1; + +class ImageScene implements SceneDetectionService { + method: Versioned; + model: tf.GraphModel; + + public constructor() { + this.method = { + value: 'Image-Scene', + version: 1, + }; + } + + private async init() { + if (this.model) { + return; + } + + const model = await tf.loadGraphModel('/models/imagescene/model.json'); + console.log('loaded image-scene model', model, tf.getBackend()); + this.model = model; + } + + async detectByFile(file: File) { + const bmp = await createImageBitmap(file); + + await tf.ready(); + + if (!this.model) { + await this.init(); + } + + const currTime = new Date().getTime(); + const output = tf.tidy(() => { + let tensor = tf.browser.fromPixels(bmp); + + tensor = tf.image.resizeBilinear(tensor, [224, 224]); + tensor = tf.expandDims(tensor); + tensor = tf.cast(tensor, 'float32'); + + const output = this.model.predict(tensor, { + verbose: true, + }); + + return output; + }); + + console.log('done in', new Date().getTime() - currTime, 'ms'); + + const data = await (output as tf.Tensor).data(); + const scenes = this.getScenes( + data as Float32Array, + bmp.width, + bmp.height + ); + console.log(`scenes for ${file.name}`, scenes); + } + + async detectScenes(image: ImageBitmap) { + await tf.ready(); + + if (!this.model) { + await this.init(); + } + + const output = tf.tidy(() => { + let tensor = tf.browser.fromPixels(image); + + tensor = tf.image.resizeBilinear(tensor, [224, 224]); + tensor = tf.expandDims(tensor); + tensor = 
tf.cast(tensor, 'float32'); + + const output = this.model.predict(tensor); + + return output; + }); + + const data = await (output as tf.Tensor).data(); + const scenes = this.getScenes( + data as Float32Array, + image.width, + image.height + ); + + return scenes; + } + + private getScenes( + outputData: Float32Array, + width: number, + height: number + ): ObjectDetection[] { + const scenes = []; + for (let i = 0; i < outputData.length; i++) { + if (outputData[i] >= MIN_SCENE_DETECTION_SCORE) { + scenes.push({ + class: sceneMap.get(i), + score: outputData[i], + bbox: [0, 0, width, height], + }); + } + } + return scenes; + } +} + +export default new ImageScene(); diff --git a/src/services/machineLearning/machineLearningFactory.ts b/src/services/machineLearning/machineLearningFactory.ts index 4a53e101c..1b3b2d6c8 100644 --- a/src/services/machineLearning/machineLearningFactory.ts +++ b/src/services/machineLearning/machineLearningFactory.ts @@ -19,6 +19,8 @@ import { ObjectDetectionMethod, TextDetectionMethod, TextDetectionService, + SceneDetectionService, + SceneDetectionMethod, } from 'types/machineLearning'; import { CONCURRENCY } from 'utils/common/concurrency'; import { ComlinkWorker, getDedicatedCryptoWorker } from 'utils/crypto'; @@ -31,6 +33,7 @@ import mobileFaceNetEmbeddingService from './mobileFaceNetEmbeddingService'; import dbscanClusteringService from './dbscanClusteringService'; import ssdMobileNetV2Service from './ssdMobileNetV2Service'; import tesseractService from './tesseractService'; +import imageSceneService from './imageSceneService'; export class MLFactory { public static getFaceDetectionService( @@ -53,6 +56,16 @@ export class MLFactory { throw Error('Unknown object detection method: ' + method); } + public static getSceneDetectionService( + method: SceneDetectionMethod + ): SceneDetectionService { + if (method === 'Image-Scene') { + return imageSceneService; + } + + throw Error('Unknown scene detection method: ' + method); + } + public static 
getTextDetectionService( method: TextDetectionMethod ): TextDetectionService { @@ -124,6 +137,7 @@ export class LocalMLSyncContext implements MLSyncContext { public faceEmbeddingService: FaceEmbeddingService; public faceClusteringService: ClusteringService; public objectDetectionService: ObjectDetectionService; + public sceneDetectionService: SceneDetectionService; public textDetectionService: TextDetectionService; public localFilesMap: Map; @@ -174,6 +188,10 @@ export class LocalMLSyncContext implements MLSyncContext { this.objectDetectionService = MLFactory.getObjectDetectionService( this.config.objectDetection.method ); + this.sceneDetectionService = MLFactory.getSceneDetectionService( + this.config.sceneDetection.method + ); + this.textDetectionService = MLFactory.getTextDetectionService( this.config.textDetection.method ); diff --git a/src/services/machineLearning/objectService.ts b/src/services/machineLearning/objectService.ts index 13d651472..30753da0c 100644 --- a/src/services/machineLearning/objectService.ts +++ b/src/services/machineLearning/objectService.ts @@ -24,17 +24,25 @@ class ObjectService { oldMlFile?.objectDetectionMethod, syncContext.objectDetectionService.method ) && + !isDifferentOrOld( + oldMlFile?.sceneDetectionMethod, + syncContext.sceneDetectionService.method + ) && oldMlFile?.imageSource === syncContext.config.imageSource ) { newMlFile.things = oldMlFile?.things; newMlFile.imageSource = oldMlFile.imageSource; newMlFile.imageDimensions = oldMlFile.imageDimensions; newMlFile.objectDetectionMethod = oldMlFile.objectDetectionMethod; + newMlFile.sceneDetectionMethod = oldMlFile.sceneDetectionMethod; return; } newMlFile.objectDetectionMethod = syncContext.objectDetectionService.method; + newMlFile.sceneDetectionMethod = + syncContext.sceneDetectionService.method; + fileContext.newDetection = true; const imageBitmap = await ReaderService.getImageBitmap( syncContext, @@ -46,6 +54,11 @@ class ObjectService { 
syncContext.config.objectDetection.maxNumBoxes, syncContext.config.objectDetection.minScore ); + objectDetections.push( + ...(await syncContext.sceneDetectionService.detectScenes( + imageBitmap + )) + ); // console.log('3 TF Memory stats: ', tf.memory()); // TODO: reenable faces filtering based on width const detectedObjects = objectDetections?.map((detection) => { diff --git a/src/services/machineLearning/sceneDetectionService.ts b/src/services/machineLearning/sceneDetectionService.ts deleted file mode 100644 index a64c9f594..000000000 --- a/src/services/machineLearning/sceneDetectionService.ts +++ /dev/null @@ -1,62 +0,0 @@ -import * as tf from '@tensorflow/tfjs'; -import '@tensorflow/tfjs-backend-webgl'; -import '@tensorflow/tfjs-backend-cpu'; -import sceneMap from 'utils/machineLearning/sceneMap'; - -const MIN_SCENE_DETECTION_SCORE = 0.25; - -class SceneDetectionService { - model: tf.GraphModel; - - async init() { - if (this.model) { - return; - } - - const model = await tf.loadGraphModel('/models/imagescene/model.json'); - console.log('loaded image-scene model', model, tf.getBackend()); - this.model = model; - } - - async run(file: File) { - const bmp = await createImageBitmap(file); - - await tf.ready(); - - const currTime = new Date().getTime(); - const output = tf.tidy(() => { - let tensor = tf.browser.fromPixels(bmp); - - tensor = tf.image.resizeBilinear(tensor, [224, 224]); - tensor = tf.expandDims(tensor); - tensor = tf.cast(tensor, 'float32'); - - const output = this.model.predict(tensor, { - verbose: true, - }); - - return output; - }); - - console.log('done in', new Date().getTime() - currTime, 'ms'); - - const data = await (output as tf.Tensor).data(); - const scenes = this.getScenes(data as Float32Array); - console.log('scenes', scenes); - } - - getScenes(outputData: Float32Array) { - const scenes = []; - for (let i = 0; i < outputData.length; i++) { - if (outputData[i] >= MIN_SCENE_DETECTION_SCORE) { - scenes.push({ - name: sceneMap.get(i), - 
score: outputData[i], - }); - } - } - return scenes; - } -} - -export default new SceneDetectionService(); diff --git a/src/types/machineLearning/index.ts b/src/types/machineLearning/index.ts index f0d2e7837..c23429ad0 100644 --- a/src/types/machineLearning/index.ts +++ b/src/types/machineLearning/index.ts @@ -95,6 +95,8 @@ export declare type FaceDetectionMethod = 'BlazeFace' | 'FaceApiSSD'; export declare type ObjectDetectionMethod = 'SSDMobileNetV2'; +export declare type SceneDetectionMethod = 'Image-Scene'; + export declare type TextDetectionMethod = 'Tesseract'; export declare type FaceCropMethod = 'ArcFace'; @@ -233,6 +235,7 @@ export interface MlFileData { faceAlignmentMethod?: Versioned; faceEmbeddingMethod?: Versioned; objectDetectionMethod?: Versioned; + sceneDetectionMethod?: Versioned; textDetectionMethod?: Versioned; mlVersion: number; errorCount: number; @@ -250,6 +253,10 @@ export interface ObjectDetectionConfig { minScore: number; } +export interface SceneDetectionConfig { + method: SceneDetectionMethod; +} + export interface TextDetectionConfig { method: TextDetectionMethod; minAccuracy: number; @@ -299,6 +306,7 @@ export interface MLSyncConfig extends Config { faceEmbedding: FaceEmbeddingConfig; faceClustering: FaceClusteringConfig; objectDetection: ObjectDetectionConfig; + sceneDetection: SceneDetectionConfig; textDetection: TextDetectionConfig; tsne?: TSNEConfig; mlVersion: number; @@ -319,6 +327,7 @@ export interface MLSyncContext { faceEmbeddingService: FaceEmbeddingService; faceClusteringService: ClusteringService; objectDetectionService?: ObjectDetectionService; + sceneDetectionService?: SceneDetectionService; textDetectionService?: TextDetectionService; localFilesMap: Map; @@ -381,6 +390,12 @@ export interface ObjectDetectionService { dispose(): Promise; } +export interface SceneDetectionService { + method: Versioned; + // init(): Promise; + detectScenes(image: ImageBitmap): Promise; +} + export interface TextDetectionService { method: 
Versioned; // init(): Promise; diff --git a/src/utils/machineLearning/sceneMap.ts b/src/utils/machineLearning/sceneMap.ts index 94c7d5e92..2fc6a0666 100644 --- a/src/utils/machineLearning/sceneMap.ts +++ b/src/utils/machineLearning/sceneMap.ts @@ -1,34 +1,34 @@ const sceneMap = new Map([ - [0, 'Waterfall'], - [1, 'Snow'], - [2, 'Landscape'], - [3, 'Underwater'], - [4, 'Architecture'], - [5, 'Sunset Sunrise'], - [6, 'Blue Sky'], - [7, 'Cloudy Sky'], - [8, 'Greenery'], - [9, 'Autumn Leaves'], - [10, 'Potrait'], - [11, 'Flower'], - [12, 'Night Shot'], - [13, 'Stage Concert'], - [14, 'Fireworks'], - [15, 'Candle Light'], - [16, 'Neon Lights'], - [17, 'Indoor'], - [18, 'Backlight'], - [19, 'Text Documents'], - [20, 'QR Images'], - [21, 'Group Potrait'], - [22, 'Computer Screens'], - [23, 'Kids'], - [24, 'Dog'], - [25, 'Cat'], - [26, 'Macro'], - [27, 'Food'], - [28, 'Beach'], - [29, 'Mountain'], + [0, 'waterfall'], + [1, 'snow'], + [2, 'landscape'], + [3, 'underwater'], + [4, 'architecture'], + [5, 'sunset / sunrise'], + [6, 'blue sky'], + [7, 'cloudy sky'], + [8, 'greenery'], + [9, 'autumn leaves'], + [10, 'portrait'], + [11, 'flower'], + [12, 'night shot'], + [13, 'stage concert'], + [14, 'fireworks'], + [15, 'candle light'], + [16, 'neon lights'], + [17, 'indoor'], + [18, 'backlight'], + [19, 'text documents'], + [20, 'qr images'], + [21, 'group portrait'], + [22, 'computer screens'], + [23, 'kids'], + [24, 'dog'], + [25, 'cat'], + [26, 'macro'], + [27, 'food'], + [28, 'beach'], + [29, 'mountain'], ]); export default sceneMap;