- added timeout

- added minAccuracy and minScore thresholds for text and object detection respectively
This commit is contained in:
Abhinav 2022-03-23 03:47:56 +05:30
parent ec1d3e6e4e
commit ce06ac7baf
9 changed files with 101 additions and 40 deletions

View file

@ -14,6 +14,7 @@ import {
import { ibExtractFaceImageFromCrop } from 'utils/machineLearning/faceCrop';
import { FaceCropsRow, FaceImagesRow, ImageBitmapView } from './ImageViews';
import ssdMobileNetV2Service from 'services/machineLearning/ssdMobileNetV2Service';
import { DEFAULT_ML_SYNC_CONFIG } from 'constants/machineLearning/config';
interface MLFileDebugViewProps {
file: File;
@ -94,7 +95,8 @@ export default function MLFileDebugView(props: MLFileDebugViewProps) {
console.log('detectedFaces: ', faceDetections.length);
const objectDetections = await ssdMobileNetV2Service.detectObjects(
imageBitmap
imageBitmap,
DEFAULT_ML_SYNC_CONFIG.objectDetection.minScore
);
console.log('detectedObjects: ', objectDetections);

View file

@ -43,10 +43,14 @@ export const DEFAULT_ML_SYNC_CONFIG: MLSyncConfig = {
// maxDistanceInsideCluster: 0.4,
generateDebugInfo: true,
},
ObjectDetection: {
objectDetection: {
method: 'SSDMobileNetV2',
minScore: 0.2,
},
textDetection: {
method: 'Tesseract',
minAccuracy: 75,
},
TextDetection: { method: 'Tesseract' },
// tsne: {
// samples: 200,
// dim: 2,
@ -66,3 +70,5 @@ export const ML_SYNC_DOWNLOAD_TIMEOUT_MS = 300000;
export const MAX_FACE_DISTANCE_PERCENT = Math.sqrt(2) / 100;
export const MAX_ML_SYNC_ERROR_COUNT = 4;
export const ML_DETECTION_TIMEOUT_MS = 30000;

View file

@ -172,10 +172,10 @@ export class LocalMLSyncContext implements MLSyncContext {
);
this.objectDetectionService = MLFactory.getObjectDetectionService(
this.config.ObjectDetection.method
this.config.objectDetection.method
);
this.textDetectionService = MLFactory.getTextDetectionService(
this.config.TextDetection.method
this.config.textDetection.method
);
this.outOfSyncFiles = [];

View file

@ -25,11 +25,15 @@ import { MLFactory } from './machineLearningFactory';
import mlIDbStorage from 'utils/storage/mlIDbStorage';
import { getMLSyncConfig } from 'utils/machineLearning/config';
import { CustomError, parseServerError } from 'utils/error';
import { MAX_ML_SYNC_ERROR_COUNT } from 'constants/machineLearning/config';
import {
MAX_ML_SYNC_ERROR_COUNT,
ML_DETECTION_TIMEOUT_MS,
} from 'constants/machineLearning/config';
import FaceService from './faceService';
import PeopleService from './peopleService';
import ObjectService from './objectService';
import TextService from './textService';
import { promiseWithTimeout } from 'utils/common/promiseTimeout';
class MachineLearningService {
private initialized = false;
// private faceDetectionService: FaceDetectionService;
@ -382,7 +386,9 @@ class MachineLearningService {
localFile?: globalThis.File
) {
const fileContext: MLSyncFileContext = { enteFile, localFile };
fileContext.oldMlFile = await this.getMLFileData(enteFile.id);
const oldMlFile = (fileContext.oldMlFile = await this.getMLFileData(
enteFile.id
));
if (
fileContext.oldMlFile?.mlVersion === syncContext.config.mlVersion
// TODO: reset mlversion of all files when user changes image source
@ -396,25 +402,39 @@ class MachineLearningService {
} else if (fileContext.oldMlFile?.mlVersion) {
newMlFile.mlVersion = fileContext.oldMlFile.mlVersion;
}
try {
await FaceService.syncFileFaceDetections(syncContext, fileContext);
if (newMlFile.faces && newMlFile.faces.length > 0) {
await FaceService.syncFileFaceCrops(syncContext, fileContext);
await FaceService.syncFileFaceAlignments(syncContext, fileContext);
await FaceService.syncFileFaceAlignments(
syncContext,
fileContext
);
await FaceService.syncFileFaceEmbeddings(syncContext, fileContext);
await FaceService.syncFileFaceEmbeddings(
syncContext,
fileContext
);
}
await ObjectService.syncFileObjectDetections(syncContext, fileContext);
await TextService.syncFileTextDetections(syncContext, fileContext);
await ObjectService.syncFileObjectDetections(
syncContext,
fileContext
);
await promiseWithTimeout(
TextService.syncFileTextDetections(syncContext, fileContext),
ML_DETECTION_TIMEOUT_MS
);
} catch (e) {
newMlFile.mlVersion = oldMlFile.mlVersion;
throw e;
} finally {
fileContext.tfImage && fileContext.tfImage.dispose();
fileContext.imageBitmap && fileContext.imageBitmap.close();
// console.log('8 TF Memory stats: ', tf.memory());
newMlFile.errorCount = 0;
newMlFile.lastErrorMessage = undefined;
await this.persistMLFileData(syncContext, newMlFile);
@ -424,6 +444,7 @@ class MachineLearningService {
// fileContext.oldMlFile,
// fileContext.newMlFile
// );
}
return newMlFile;
}

View file

@ -40,7 +40,10 @@ class ObjectService {
fileContext
);
const objectDetections =
await syncContext.objectDetectionService.detectObjects(imageBitmap);
await syncContext.objectDetectionService.detectObjects(
imageBitmap,
syncContext.config.objectDetection.minScore
);
// console.log('3 TF Memory stats: ', tf.memory());
// TODO: reenable faces filtering based on width
const detectedObjects = objectDetections?.map((detection) => {

View file

@ -40,10 +40,17 @@ class SSDMobileNetV2 implements ObjectDetectionService {
return this.ssdMobileNetV2Model;
}
public async detectObjects(image: ImageBitmap): Promise<ObjectDetection[]> {
public async detectObjects(
image: ImageBitmap,
minScore?: number
): Promise<ObjectDetection[]> {
const ssdMobileNetV2Model = await this.getSSDMobileNetV2Model();
const tfImage = tf.browser.fromPixels(image);
const detections = await ssdMobileNetV2Model.detect(tfImage);
const detections = await ssdMobileNetV2Model.detect(
tfImage,
undefined,
minScore
);
return detections;
}

View file

@ -42,12 +42,15 @@ class TextService {
)
);
const detectedText: DetectedText[] = textDetections.data.words.map(
({ bbox, confidence, text }) => ({
const detectedText: DetectedText[] = textDetections.data.words
.filter(
({ confidence }) =>
confidence >= syncContext.config.textDetection.minAccuracy
)
.map(({ bbox, confidence, text }) => ({
fileID: fileContext.enteFile.id,
detection: { bbox, confidence, word: text },
})
);
}));
newMlFile.text = detectedText;
console.log(
'[MLService] Detected text: ',

View file

@ -246,10 +246,12 @@ export interface FaceDetectionConfig {
export interface ObjectDetectionConfig {
method: ObjectDetectionMethod;
minScore: number;
}
export interface TextDetectionConfig {
method: TextDetectionMethod;
minAccuracy: number;
}
export interface FaceCropConfig {
@ -295,8 +297,8 @@ export interface MLSyncConfig extends Config {
faceAlignment: FaceAlignmentConfig;
faceEmbedding: FaceEmbeddingConfig;
faceClustering: FaceClusteringConfig;
ObjectDetection: ObjectDetectionConfig;
TextDetection: TextDetectionConfig;
objectDetection: ObjectDetectionConfig;
textDetection: TextDetectionConfig;
tsne?: TSNEConfig;
mlVersion: number;
}
@ -378,7 +380,10 @@ export interface FaceDetectionService {
export interface ObjectDetectionService {
method: Versioned<ObjectDetectionMethod>;
// init(): Promise<void>;
detectObjects(image: ImageBitmap): Promise<ObjectDetection[]>;
detectObjects(
image: ImageBitmap,
minScore?: number
): Promise<ObjectDetection[]>;
dispose(): Promise<void>;
}

View file

@ -0,0 +1,14 @@
import { CustomError } from 'utils/error';
/**
 * Resolve with the result of `request`, or reject with a
 * WAIT_TIME_EXCEEDED error if it does not settle within `timeout` ms.
 *
 * Improvements over the original:
 * - generic `<T>` instead of `Promise<any>`, and the settled value is
 *   returned to the caller (previously it was awaited and discarded);
 * - the timer is cleared once the race settles, so a fast-resolving
 *   request no longer leaves a pending timeout behind;
 * - `new Error(...)` instead of the bare `Error(...)` call.
 *
 * Backward compatible: existing callers that ignore the return value
 * (e.g. `await promiseWithTimeout(p, ms)`) behave exactly as before.
 *
 * @param request promise to race against the timeout
 * @param timeout maximum wait in milliseconds
 * @throws Error with message CustomError.WAIT_TIME_EXCEEDED on timeout
 */
export const promiseWithTimeout = async <T>(
    request: Promise<T>,
    timeout: number
): Promise<T> => {
    let timerId: ReturnType<typeof setTimeout> | undefined;
    // Typed Promise<never> so the race's resolved type stays T.
    const rejectOnTimeout = new Promise<never>((_, reject) => {
        timerId = setTimeout(
            () => reject(new Error(CustomError.WAIT_TIME_EXCEEDED)),
            timeout
        );
    });
    try {
        return await Promise.race([request, rejectOnTimeout]);
    } finally {
        // Always clear the timer so it cannot outlive the race.
        clearTimeout(timerId);
    }
};