From ce06ac7baf4e64d11e59f5e134806e50ba269097 Mon Sep 17 00:00:00 2001 From: Abhinav Date: Wed, 23 Mar 2022 03:47:56 +0530 Subject: [PATCH] - added timeout - added accuracy and minScore for text and object detection respectively --- .../MachineLearning/MLFileDebugView.tsx | 4 +- src/constants/machineLearning/config.ts | 10 ++- .../machineLearning/machineLearningFactory.ts | 4 +- .../machineLearning/machineLearningService.ts | 71 ++++++++++++------- src/services/machineLearning/objectService.ts | 5 +- .../machineLearning/ssdMobileNetV2Service.ts | 11 ++- src/services/machineLearning/textService.ts | 11 +-- src/types/machineLearning/index.ts | 11 ++- src/utils/common/promiseTimeout.ts | 14 ++++ 9 files changed, 101 insertions(+), 40 deletions(-) create mode 100644 src/utils/common/promiseTimeout.ts diff --git a/src/components/MachineLearning/MLFileDebugView.tsx b/src/components/MachineLearning/MLFileDebugView.tsx index 7234bf604..db94d4e38 100644 --- a/src/components/MachineLearning/MLFileDebugView.tsx +++ b/src/components/MachineLearning/MLFileDebugView.tsx @@ -14,6 +14,7 @@ import { import { ibExtractFaceImageFromCrop } from 'utils/machineLearning/faceCrop'; import { FaceCropsRow, FaceImagesRow, ImageBitmapView } from './ImageViews'; import ssdMobileNetV2Service from 'services/machineLearning/ssdMobileNetV2Service'; +import { DEFAULT_ML_SYNC_CONFIG } from 'constants/machineLearning/config'; interface MLFileDebugViewProps { file: File; @@ -94,7 +95,8 @@ export default function MLFileDebugView(props: MLFileDebugViewProps) { console.log('detectedFaces: ', faceDetections.length); const objectDetections = await ssdMobileNetV2Service.detectObjects( - imageBitmap + imageBitmap, + DEFAULT_ML_SYNC_CONFIG.objectDetection.minScore ); console.log('detectedObjects: ', objectDetections); diff --git a/src/constants/machineLearning/config.ts b/src/constants/machineLearning/config.ts index 7f8d60201..0afd35215 100644 --- a/src/constants/machineLearning/config.ts +++ b/src/constants/machineLearning/config.ts @@ -43,10 +43,14 @@ export const DEFAULT_ML_SYNC_CONFIG: MLSyncConfig = { // maxDistanceInsideCluster: 0.4, generateDebugInfo: true, }, - ObjectDetection: { + objectDetection: { method: 'SSDMobileNetV2', + minScore: 0.2, + }, + textDetection: { + method: 'Tesseract', + minAccuracy: 75, }, - TextDetection: { method: 'Tesseract' }, // tsne: { // samples: 200, // dim: 2, @@ -66,3 +70,5 @@ export const ML_SYNC_DOWNLOAD_TIMEOUT_MS = 300000; export const MAX_FACE_DISTANCE_PERCENT = Math.sqrt(2) / 100; export const MAX_ML_SYNC_ERROR_COUNT = 4; + +export const ML_DETECTION_TIMEOUT_MS = 30000; diff --git a/src/services/machineLearning/machineLearningFactory.ts b/src/services/machineLearning/machineLearningFactory.ts index 46e2d411d..4a53e101c 100644 --- a/src/services/machineLearning/machineLearningFactory.ts +++ b/src/services/machineLearning/machineLearningFactory.ts @@ -172,10 +172,10 @@ export class LocalMLSyncContext implements MLSyncContext { ); this.objectDetectionService = MLFactory.getObjectDetectionService( - this.config.ObjectDetection.method + this.config.objectDetection.method ); this.textDetectionService = MLFactory.getTextDetectionService( - this.config.TextDetection.method + this.config.textDetection.method ); this.outOfSyncFiles = []; diff --git a/src/services/machineLearning/machineLearningService.ts b/src/services/machineLearning/machineLearningService.ts index e92a8f816..5b12adc15 100644 --- a/src/services/machineLearning/machineLearningService.ts +++ 
b/src/services/machineLearning/machineLearningService.ts @@ -25,11 +25,15 @@ import { MLFactory } from './machineLearningFactory'; import mlIDbStorage from 'utils/storage/mlIDbStorage'; import { getMLSyncConfig } from 'utils/machineLearning/config'; import { CustomError, parseServerError } from 'utils/error'; -import { MAX_ML_SYNC_ERROR_COUNT } from 'constants/machineLearning/config'; +import { + MAX_ML_SYNC_ERROR_COUNT, + ML_DETECTION_TIMEOUT_MS, +} from 'constants/machineLearning/config'; import FaceService from './faceService'; import PeopleService from './peopleService'; import ObjectService from './objectService'; import TextService from './textService'; +import { promiseWithTimeout } from 'utils/common/promiseTimeout'; class MachineLearningService { private initialized = false; // private faceDetectionService: FaceDetectionService; @@ -382,7 +386,9 @@ class MachineLearningService { localFile?: globalThis.File ) { const fileContext: MLSyncFileContext = { enteFile, localFile }; - fileContext.oldMlFile = await this.getMLFileData(enteFile.id); + const oldMlFile = (fileContext.oldMlFile = await this.getMLFileData( + enteFile.id + )); if ( fileContext.oldMlFile?.mlVersion === syncContext.config.mlVersion // TODO: reset mlversion of all files when user changes image source @@ -396,35 +402,50 @@ class MachineLearningService { } else if (fileContext.oldMlFile?.mlVersion) { newMlFile.mlVersion = fileContext.oldMlFile.mlVersion; } + try { + await FaceService.syncFileFaceDetections(syncContext, fileContext); - await FaceService.syncFileFaceDetections(syncContext, fileContext); + if (newMlFile.faces && newMlFile.faces.length > 0) { + await FaceService.syncFileFaceCrops(syncContext, fileContext); - if (newMlFile.faces && newMlFile.faces.length > 0) { - await FaceService.syncFileFaceCrops(syncContext, fileContext); + await FaceService.syncFileFaceAlignments( + syncContext, + fileContext + ); - await FaceService.syncFileFaceAlignments(syncContext, fileContext); + await FaceService.syncFileFaceEmbeddings( + syncContext, + fileContext + ); + } - await FaceService.syncFileFaceEmbeddings(syncContext, fileContext); + await ObjectService.syncFileObjectDetections( + syncContext, + fileContext + ); + + await promiseWithTimeout( + TextService.syncFileTextDetections(syncContext, fileContext), + ML_DETECTION_TIMEOUT_MS + ); + } catch (e) { + newMlFile.mlVersion = oldMlFile.mlVersion; + throw e; + } finally { + fileContext.tfImage && fileContext.tfImage.dispose(); + fileContext.imageBitmap && fileContext.imageBitmap.close(); + // console.log('8 TF Memory stats: ', tf.memory()); + newMlFile.errorCount = 0; + newMlFile.lastErrorMessage = undefined; + await this.persistMLFileData(syncContext, newMlFile); + + // TODO: enable once faceId changes go in + // await removeOldFaceCrops( + // fileContext.oldMlFile, + // fileContext.newMlFile + // ); } - await ObjectService.syncFileObjectDetections(syncContext, fileContext); - - await TextService.syncFileTextDetections(syncContext, fileContext); - - fileContext.tfImage && fileContext.tfImage.dispose(); - fileContext.imageBitmap && fileContext.imageBitmap.close(); - // console.log('8 TF Memory stats: ', tf.memory()); - - newMlFile.errorCount = 0; - newMlFile.lastErrorMessage = undefined; - await this.persistMLFileData(syncContext, newMlFile); - - // TODO: enable once faceId changes go in - // await removeOldFaceCrops( - // fileContext.oldMlFile, - // fileContext.newMlFile - // ); - return newMlFile; } diff --git a/src/services/machineLearning/objectService.ts 
b/src/services/machineLearning/objectService.ts
index c4a1bd7d4..14286d7c1 100644
--- a/src/services/machineLearning/objectService.ts
+++ b/src/services/machineLearning/objectService.ts
@@ -40,7 +40,10 @@ class ObjectService {
             fileContext
         );
         const objectDetections =
-            await syncContext.objectDetectionService.detectObjects(imageBitmap);
+            await syncContext.objectDetectionService.detectObjects(
+                imageBitmap,
+                syncContext.config.objectDetection.minScore
+            );
         // console.log('3 TF Memory stats: ', tf.memory());
         // TODO: reenable faces filtering based on width
         const detectedObjects = objectDetections?.map((detection) => {
diff --git a/src/services/machineLearning/ssdMobileNetV2Service.ts b/src/services/machineLearning/ssdMobileNetV2Service.ts
index eb1dad318..4371256bf 100644
--- a/src/services/machineLearning/ssdMobileNetV2Service.ts
+++ b/src/services/machineLearning/ssdMobileNetV2Service.ts
@@ -40,10 +40,17 @@ class SSDMobileNetV2 implements ObjectDetectionService {
         return this.ssdMobileNetV2Model;
     }
 
-    public async detectObjects(image: ImageBitmap): Promise<ObjectDetection[]> {
+    public async detectObjects(
+        image: ImageBitmap,
+        minScore?: number
+    ): Promise<ObjectDetection[]> {
         const ssdMobileNetV2Model = await this.getSSDMobileNetV2Model();
         const tfImage = tf.browser.fromPixels(image);
-        const detections = await ssdMobileNetV2Model.detect(tfImage);
+        const detections = await ssdMobileNetV2Model.detect(
+            tfImage,
+            undefined,
+            minScore
+        );
 
         return detections;
     }
diff --git a/src/services/machineLearning/textService.ts b/src/services/machineLearning/textService.ts
index dbff151dd..c4bc4540e 100644
--- a/src/services/machineLearning/textService.ts
+++ b/src/services/machineLearning/textService.ts
@@ -42,12 +42,15 @@ class TextService {
             )
         );
 
-        const detectedText: DetectedText[] = textDetections.data.words.map(
-            ({ bbox, confidence, text }) => ({
+        const detectedText: DetectedText[] = textDetections.data.words
+            .filter(
+                ({ confidence }) =>
+                    confidence >= syncContext.config.textDetection.minAccuracy
+            )
+            .map(({ bbox, confidence, text }) => ({
                 fileID: fileContext.enteFile.id,
                 detection: { bbox, confidence, word: text },
-            })
-        );
+            }));
         newMlFile.text = detectedText;
         console.log(
             '[MLService] Detected text: ',
diff --git a/src/types/machineLearning/index.ts b/src/types/machineLearning/index.ts
index da18d70f1..2cbda8ebb 100644
--- a/src/types/machineLearning/index.ts
+++ b/src/types/machineLearning/index.ts
@@ -246,10 +246,12 @@ export interface FaceDetectionConfig {
 
 export interface ObjectDetectionConfig {
     method: ObjectDetectionMethod;
+    minScore: number;
 }
 
 export interface TextDetectionConfig {
     method: TextDetectionMethod;
+    minAccuracy: number;
 }
 
 export interface FaceCropConfig {
@@ -295,8 +297,8 @@ export interface MLSyncConfig extends Config {
     faceAlignment: FaceAlignmentConfig;
     faceEmbedding: FaceEmbeddingConfig;
     faceClustering: FaceClusteringConfig;
-    ObjectDetection: ObjectDetectionConfig;
-    TextDetection: TextDetectionConfig;
+    objectDetection: ObjectDetectionConfig;
+    textDetection: TextDetectionConfig;
     tsne?: TSNEConfig;
     mlVersion: number;
 }
@@ -378,7 +380,10 @@ export interface FaceDetectionService {
 
 export interface ObjectDetectionService {
     method: Versioned<ObjectDetectionMethod>;
     // init(): Promise<void>;
-    detectObjects(image: ImageBitmap): Promise<ObjectDetection[]>;
+    detectObjects(
+        image: ImageBitmap,
+        minScore?: number
+    ): Promise<ObjectDetection[]>;
     dispose(): Promise<void>;
 }
diff --git a/src/utils/common/promiseTimeout.ts b/src/utils/common/promiseTimeout.ts
new file mode 100644
index 000000000..c7da91885
--- /dev/null
+++ b/src/utils/common/promiseTimeout.ts
@@ -0,0 +1,14 @@
+import { CustomError } from 'utils/error';
+
+export const promiseWithTimeout = async (
+    request: Promise<any>,
+    timeout: number
+) => {
+    const rejectOnTimeout = new Promise((_, reject) => {
+        setTimeout(
+            () => reject(Error(CustomError.WAIT_TIME_EXCEEDED)),
+            timeout
+        );
+    });
+    await Promise.race([request, rejectOnTimeout]);
+};
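
For reference, a minimal usage sketch of the new promiseWithTimeout helper. ML_DETECTION_TIMEOUT_MS, CustomError.WAIT_TIME_EXCEEDED, and the roll-back-then-rethrow behaviour come from the patch above; runLongDetection is a hypothetical stand-in for a detection step such as TextService.syncFileTextDetections(syncContext, fileContext).

import { promiseWithTimeout } from 'utils/common/promiseTimeout';
import { ML_DETECTION_TIMEOUT_MS } from 'constants/machineLearning/config';
import { CustomError } from 'utils/error';

// Hypothetical stand-in for a long-running detection step, e.g.
// TextService.syncFileTextDetections(syncContext, fileContext).
const runLongDetection = async (): Promise<void> => {
    // ... detection work ...
};

const detectWithTimeout = async () => {
    try {
        // Resolves when detection finishes first; rejects with
        // CustomError.WAIT_TIME_EXCEEDED once ML_DETECTION_TIMEOUT_MS (30s)
        // elapses, letting the caller restore the old mlVersion and rethrow.
        await promiseWithTimeout(runLongDetection(), ML_DETECTION_TIMEOUT_MS);
    } catch (e) {
        if (e instanceof Error && e.message === CustomError.WAIT_TIME_EXCEEDED) {
            // detection timed out
        }
        throw e;
    }
};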