- Added a timeout for ML detection.

- Added minAccuracy and minScore thresholds for text and object detection, respectively.
This commit is contained in:
Abhinav 2022-03-23 03:47:56 +05:30
parent ec1d3e6e4e
commit ce06ac7baf
9 changed files with 101 additions and 40 deletions

View file

@@ -14,6 +14,7 @@ import {
import { ibExtractFaceImageFromCrop } from 'utils/machineLearning/faceCrop'; import { ibExtractFaceImageFromCrop } from 'utils/machineLearning/faceCrop';
import { FaceCropsRow, FaceImagesRow, ImageBitmapView } from './ImageViews'; import { FaceCropsRow, FaceImagesRow, ImageBitmapView } from './ImageViews';
import ssdMobileNetV2Service from 'services/machineLearning/ssdMobileNetV2Service'; import ssdMobileNetV2Service from 'services/machineLearning/ssdMobileNetV2Service';
import { DEFAULT_ML_SYNC_CONFIG } from 'constants/machineLearning/config';
interface MLFileDebugViewProps { interface MLFileDebugViewProps {
file: File; file: File;
@@ -94,7 +95,8 @@ export default function MLFileDebugView(props: MLFileDebugViewProps) {
console.log('detectedFaces: ', faceDetections.length); console.log('detectedFaces: ', faceDetections.length);
const objectDetections = await ssdMobileNetV2Service.detectObjects( const objectDetections = await ssdMobileNetV2Service.detectObjects(
imageBitmap imageBitmap,
DEFAULT_ML_SYNC_CONFIG.objectDetection.minScore
); );
console.log('detectedObjects: ', objectDetections); console.log('detectedObjects: ', objectDetections);

View file

@@ -43,10 +43,14 @@ export const DEFAULT_ML_SYNC_CONFIG: MLSyncConfig = {
// maxDistanceInsideCluster: 0.4, // maxDistanceInsideCluster: 0.4,
generateDebugInfo: true, generateDebugInfo: true,
}, },
ObjectDetection: { objectDetection: {
method: 'SSDMobileNetV2', method: 'SSDMobileNetV2',
minScore: 0.2,
},
textDetection: {
method: 'Tesseract',
minAccuracy: 75,
}, },
TextDetection: { method: 'Tesseract' },
// tsne: { // tsne: {
// samples: 200, // samples: 200,
// dim: 2, // dim: 2,
@@ -66,3 +70,5 @@ export const ML_SYNC_DOWNLOAD_TIMEOUT_MS = 300000;
export const MAX_FACE_DISTANCE_PERCENT = Math.sqrt(2) / 100; export const MAX_FACE_DISTANCE_PERCENT = Math.sqrt(2) / 100;
export const MAX_ML_SYNC_ERROR_COUNT = 4; export const MAX_ML_SYNC_ERROR_COUNT = 4;
export const ML_DETECTION_TIMEOUT_MS = 30000;

View file

@@ -172,10 +172,10 @@ export class LocalMLSyncContext implements MLSyncContext {
); );
this.objectDetectionService = MLFactory.getObjectDetectionService( this.objectDetectionService = MLFactory.getObjectDetectionService(
this.config.ObjectDetection.method this.config.objectDetection.method
); );
this.textDetectionService = MLFactory.getTextDetectionService( this.textDetectionService = MLFactory.getTextDetectionService(
this.config.TextDetection.method this.config.textDetection.method
); );
this.outOfSyncFiles = []; this.outOfSyncFiles = [];

View file

@@ -25,11 +25,15 @@ import { MLFactory } from './machineLearningFactory';
import mlIDbStorage from 'utils/storage/mlIDbStorage'; import mlIDbStorage from 'utils/storage/mlIDbStorage';
import { getMLSyncConfig } from 'utils/machineLearning/config'; import { getMLSyncConfig } from 'utils/machineLearning/config';
import { CustomError, parseServerError } from 'utils/error'; import { CustomError, parseServerError } from 'utils/error';
import { MAX_ML_SYNC_ERROR_COUNT } from 'constants/machineLearning/config'; import {
MAX_ML_SYNC_ERROR_COUNT,
ML_DETECTION_TIMEOUT_MS,
} from 'constants/machineLearning/config';
import FaceService from './faceService'; import FaceService from './faceService';
import PeopleService from './peopleService'; import PeopleService from './peopleService';
import ObjectService from './objectService'; import ObjectService from './objectService';
import TextService from './textService'; import TextService from './textService';
import { promiseWithTimeout } from 'utils/common/promiseTimeout';
class MachineLearningService { class MachineLearningService {
private initialized = false; private initialized = false;
// private faceDetectionService: FaceDetectionService; // private faceDetectionService: FaceDetectionService;
@@ -382,7 +386,9 @@ class MachineLearningService {
localFile?: globalThis.File localFile?: globalThis.File
) { ) {
const fileContext: MLSyncFileContext = { enteFile, localFile }; const fileContext: MLSyncFileContext = { enteFile, localFile };
fileContext.oldMlFile = await this.getMLFileData(enteFile.id); const oldMlFile = (fileContext.oldMlFile = await this.getMLFileData(
enteFile.id
));
if ( if (
fileContext.oldMlFile?.mlVersion === syncContext.config.mlVersion fileContext.oldMlFile?.mlVersion === syncContext.config.mlVersion
// TODO: reset mlversion of all files when user changes image source // TODO: reset mlversion of all files when user changes image source
@@ -396,35 +402,50 @@ class MachineLearningService {
} else if (fileContext.oldMlFile?.mlVersion) { } else if (fileContext.oldMlFile?.mlVersion) {
newMlFile.mlVersion = fileContext.oldMlFile.mlVersion; newMlFile.mlVersion = fileContext.oldMlFile.mlVersion;
} }
try {
await FaceService.syncFileFaceDetections(syncContext, fileContext);
await FaceService.syncFileFaceDetections(syncContext, fileContext); if (newMlFile.faces && newMlFile.faces.length > 0) {
await FaceService.syncFileFaceCrops(syncContext, fileContext);
if (newMlFile.faces && newMlFile.faces.length > 0) { await FaceService.syncFileFaceAlignments(
await FaceService.syncFileFaceCrops(syncContext, fileContext); syncContext,
fileContext
);
await FaceService.syncFileFaceAlignments(syncContext, fileContext); await FaceService.syncFileFaceEmbeddings(
syncContext,
fileContext
);
}
await FaceService.syncFileFaceEmbeddings(syncContext, fileContext); await ObjectService.syncFileObjectDetections(
syncContext,
fileContext
);
await promiseWithTimeout(
TextService.syncFileTextDetections(syncContext, fileContext),
ML_DETECTION_TIMEOUT_MS
);
} catch (e) {
newMlFile.mlVersion = oldMlFile.mlVersion;
throw e;
} finally {
fileContext.tfImage && fileContext.tfImage.dispose();
fileContext.imageBitmap && fileContext.imageBitmap.close();
// console.log('8 TF Memory stats: ', tf.memory());
newMlFile.errorCount = 0;
newMlFile.lastErrorMessage = undefined;
await this.persistMLFileData(syncContext, newMlFile);
// TODO: enable once faceId changes go in
// await removeOldFaceCrops(
// fileContext.oldMlFile,
// fileContext.newMlFile
// );
} }
await ObjectService.syncFileObjectDetections(syncContext, fileContext);
await TextService.syncFileTextDetections(syncContext, fileContext);
fileContext.tfImage && fileContext.tfImage.dispose();
fileContext.imageBitmap && fileContext.imageBitmap.close();
// console.log('8 TF Memory stats: ', tf.memory());
newMlFile.errorCount = 0;
newMlFile.lastErrorMessage = undefined;
await this.persistMLFileData(syncContext, newMlFile);
// TODO: enable once faceId changes go in
// await removeOldFaceCrops(
// fileContext.oldMlFile,
// fileContext.newMlFile
// );
return newMlFile; return newMlFile;
} }

View file

@@ -40,7 +40,10 @@ class ObjectService {
fileContext fileContext
); );
const objectDetections = const objectDetections =
await syncContext.objectDetectionService.detectObjects(imageBitmap); await syncContext.objectDetectionService.detectObjects(
imageBitmap,
syncContext.config.objectDetection.minScore
);
// console.log('3 TF Memory stats: ', tf.memory()); // console.log('3 TF Memory stats: ', tf.memory());
// TODO: reenable faces filtering based on width // TODO: reenable faces filtering based on width
const detectedObjects = objectDetections?.map((detection) => { const detectedObjects = objectDetections?.map((detection) => {

View file

@@ -40,10 +40,17 @@ class SSDMobileNetV2 implements ObjectDetectionService {
return this.ssdMobileNetV2Model; return this.ssdMobileNetV2Model;
} }
public async detectObjects(image: ImageBitmap): Promise<ObjectDetection[]> { public async detectObjects(
image: ImageBitmap,
minScore?: number
): Promise<ObjectDetection[]> {
const ssdMobileNetV2Model = await this.getSSDMobileNetV2Model(); const ssdMobileNetV2Model = await this.getSSDMobileNetV2Model();
const tfImage = tf.browser.fromPixels(image); const tfImage = tf.browser.fromPixels(image);
const detections = await ssdMobileNetV2Model.detect(tfImage); const detections = await ssdMobileNetV2Model.detect(
tfImage,
undefined,
minScore
);
return detections; return detections;
} }

View file

@@ -42,12 +42,15 @@ class TextService {
) )
); );
const detectedText: DetectedText[] = textDetections.data.words.map( const detectedText: DetectedText[] = textDetections.data.words
({ bbox, confidence, text }) => ({ .filter(
({ confidence }) =>
confidence >= syncContext.config.textDetection.minAccuracy
)
.map(({ bbox, confidence, text }) => ({
fileID: fileContext.enteFile.id, fileID: fileContext.enteFile.id,
detection: { bbox, confidence, word: text }, detection: { bbox, confidence, word: text },
}) }));
);
newMlFile.text = detectedText; newMlFile.text = detectedText;
console.log( console.log(
'[MLService] Detected text: ', '[MLService] Detected text: ',

View file

@@ -246,10 +246,12 @@ export interface FaceDetectionConfig {
export interface ObjectDetectionConfig { export interface ObjectDetectionConfig {
method: ObjectDetectionMethod; method: ObjectDetectionMethod;
minScore: number;
} }
export interface TextDetectionConfig { export interface TextDetectionConfig {
method: TextDetectionMethod; method: TextDetectionMethod;
minAccuracy: number;
} }
export interface FaceCropConfig { export interface FaceCropConfig {
@@ -295,8 +297,8 @@ export interface MLSyncConfig extends Config {
faceAlignment: FaceAlignmentConfig; faceAlignment: FaceAlignmentConfig;
faceEmbedding: FaceEmbeddingConfig; faceEmbedding: FaceEmbeddingConfig;
faceClustering: FaceClusteringConfig; faceClustering: FaceClusteringConfig;
ObjectDetection: ObjectDetectionConfig; objectDetection: ObjectDetectionConfig;
TextDetection: TextDetectionConfig; textDetection: TextDetectionConfig;
tsne?: TSNEConfig; tsne?: TSNEConfig;
mlVersion: number; mlVersion: number;
} }
@@ -378,7 +380,10 @@ export interface FaceDetectionService {
export interface ObjectDetectionService { export interface ObjectDetectionService {
method: Versioned<ObjectDetectionMethod>; method: Versioned<ObjectDetectionMethod>;
// init(): Promise<void>; // init(): Promise<void>;
detectObjects(image: ImageBitmap): Promise<ObjectDetection[]>; detectObjects(
image: ImageBitmap,
minScore?: number
): Promise<ObjectDetection[]>;
dispose(): Promise<void>; dispose(): Promise<void>;
} }

View file

@@ -0,0 +1,14 @@
import { CustomError } from 'utils/error';
export const promiseWithTimeout = async (
request: Promise<any>,
timeout: number
) => {
const rejectOnTimeout = new Promise((_, reject) => {
setTimeout(
() => reject(Error(CustomError.WAIT_TIME_EXCEEDED)),
timeout
);
});
await Promise.race([request, rejectOnTimeout]);
};