- added a timeout (ML_DETECTION_TIMEOUT_MS) around text detection during file sync
- added minAccuracy and minScore thresholds for text and object detection respectively
This commit is contained in:
parent
ec1d3e6e4e
commit
ce06ac7baf
|
@ -14,6 +14,7 @@ import {
|
||||||
import { ibExtractFaceImageFromCrop } from 'utils/machineLearning/faceCrop';
|
import { ibExtractFaceImageFromCrop } from 'utils/machineLearning/faceCrop';
|
||||||
import { FaceCropsRow, FaceImagesRow, ImageBitmapView } from './ImageViews';
|
import { FaceCropsRow, FaceImagesRow, ImageBitmapView } from './ImageViews';
|
||||||
import ssdMobileNetV2Service from 'services/machineLearning/ssdMobileNetV2Service';
|
import ssdMobileNetV2Service from 'services/machineLearning/ssdMobileNetV2Service';
|
||||||
|
import { DEFAULT_ML_SYNC_CONFIG } from 'constants/machineLearning/config';
|
||||||
|
|
||||||
interface MLFileDebugViewProps {
|
interface MLFileDebugViewProps {
|
||||||
file: File;
|
file: File;
|
||||||
|
@ -94,7 +95,8 @@ export default function MLFileDebugView(props: MLFileDebugViewProps) {
|
||||||
console.log('detectedFaces: ', faceDetections.length);
|
console.log('detectedFaces: ', faceDetections.length);
|
||||||
|
|
||||||
const objectDetections = await ssdMobileNetV2Service.detectObjects(
|
const objectDetections = await ssdMobileNetV2Service.detectObjects(
|
||||||
imageBitmap
|
imageBitmap,
|
||||||
|
DEFAULT_ML_SYNC_CONFIG.objectDetection.minScore
|
||||||
);
|
);
|
||||||
console.log('detectedObjects: ', objectDetections);
|
console.log('detectedObjects: ', objectDetections);
|
||||||
|
|
||||||
|
|
|
@ -43,10 +43,14 @@ export const DEFAULT_ML_SYNC_CONFIG: MLSyncConfig = {
|
||||||
// maxDistanceInsideCluster: 0.4,
|
// maxDistanceInsideCluster: 0.4,
|
||||||
generateDebugInfo: true,
|
generateDebugInfo: true,
|
||||||
},
|
},
|
||||||
ObjectDetection: {
|
objectDetection: {
|
||||||
method: 'SSDMobileNetV2',
|
method: 'SSDMobileNetV2',
|
||||||
|
minScore: 0.2,
|
||||||
|
},
|
||||||
|
textDetection: {
|
||||||
|
method: 'Tesseract',
|
||||||
|
minAccuracy: 75,
|
||||||
},
|
},
|
||||||
TextDetection: { method: 'Tesseract' },
|
|
||||||
// tsne: {
|
// tsne: {
|
||||||
// samples: 200,
|
// samples: 200,
|
||||||
// dim: 2,
|
// dim: 2,
|
||||||
|
@ -66,3 +70,5 @@ export const ML_SYNC_DOWNLOAD_TIMEOUT_MS = 300000;
|
||||||
export const MAX_FACE_DISTANCE_PERCENT = Math.sqrt(2) / 100;
|
export const MAX_FACE_DISTANCE_PERCENT = Math.sqrt(2) / 100;
|
||||||
|
|
||||||
export const MAX_ML_SYNC_ERROR_COUNT = 4;
|
export const MAX_ML_SYNC_ERROR_COUNT = 4;
|
||||||
|
|
||||||
|
export const ML_DETECTION_TIMEOUT_MS = 30000;
|
||||||
|
|
|
@ -172,10 +172,10 @@ export class LocalMLSyncContext implements MLSyncContext {
|
||||||
);
|
);
|
||||||
|
|
||||||
this.objectDetectionService = MLFactory.getObjectDetectionService(
|
this.objectDetectionService = MLFactory.getObjectDetectionService(
|
||||||
this.config.ObjectDetection.method
|
this.config.objectDetection.method
|
||||||
);
|
);
|
||||||
this.textDetectionService = MLFactory.getTextDetectionService(
|
this.textDetectionService = MLFactory.getTextDetectionService(
|
||||||
this.config.TextDetection.method
|
this.config.textDetection.method
|
||||||
);
|
);
|
||||||
|
|
||||||
this.outOfSyncFiles = [];
|
this.outOfSyncFiles = [];
|
||||||
|
|
|
@ -25,11 +25,15 @@ import { MLFactory } from './machineLearningFactory';
|
||||||
import mlIDbStorage from 'utils/storage/mlIDbStorage';
|
import mlIDbStorage from 'utils/storage/mlIDbStorage';
|
||||||
import { getMLSyncConfig } from 'utils/machineLearning/config';
|
import { getMLSyncConfig } from 'utils/machineLearning/config';
|
||||||
import { CustomError, parseServerError } from 'utils/error';
|
import { CustomError, parseServerError } from 'utils/error';
|
||||||
import { MAX_ML_SYNC_ERROR_COUNT } from 'constants/machineLearning/config';
|
import {
|
||||||
|
MAX_ML_SYNC_ERROR_COUNT,
|
||||||
|
ML_DETECTION_TIMEOUT_MS,
|
||||||
|
} from 'constants/machineLearning/config';
|
||||||
import FaceService from './faceService';
|
import FaceService from './faceService';
|
||||||
import PeopleService from './peopleService';
|
import PeopleService from './peopleService';
|
||||||
import ObjectService from './objectService';
|
import ObjectService from './objectService';
|
||||||
import TextService from './textService';
|
import TextService from './textService';
|
||||||
|
import { promiseWithTimeout } from 'utils/common/promiseTimeout';
|
||||||
class MachineLearningService {
|
class MachineLearningService {
|
||||||
private initialized = false;
|
private initialized = false;
|
||||||
// private faceDetectionService: FaceDetectionService;
|
// private faceDetectionService: FaceDetectionService;
|
||||||
|
@ -382,7 +386,9 @@ class MachineLearningService {
|
||||||
localFile?: globalThis.File
|
localFile?: globalThis.File
|
||||||
) {
|
) {
|
||||||
const fileContext: MLSyncFileContext = { enteFile, localFile };
|
const fileContext: MLSyncFileContext = { enteFile, localFile };
|
||||||
fileContext.oldMlFile = await this.getMLFileData(enteFile.id);
|
const oldMlFile = (fileContext.oldMlFile = await this.getMLFileData(
|
||||||
|
enteFile.id
|
||||||
|
));
|
||||||
if (
|
if (
|
||||||
fileContext.oldMlFile?.mlVersion === syncContext.config.mlVersion
|
fileContext.oldMlFile?.mlVersion === syncContext.config.mlVersion
|
||||||
// TODO: reset mlversion of all files when user changes image source
|
// TODO: reset mlversion of all files when user changes image source
|
||||||
|
@ -396,25 +402,39 @@ class MachineLearningService {
|
||||||
} else if (fileContext.oldMlFile?.mlVersion) {
|
} else if (fileContext.oldMlFile?.mlVersion) {
|
||||||
newMlFile.mlVersion = fileContext.oldMlFile.mlVersion;
|
newMlFile.mlVersion = fileContext.oldMlFile.mlVersion;
|
||||||
}
|
}
|
||||||
|
try {
|
||||||
await FaceService.syncFileFaceDetections(syncContext, fileContext);
|
await FaceService.syncFileFaceDetections(syncContext, fileContext);
|
||||||
|
|
||||||
if (newMlFile.faces && newMlFile.faces.length > 0) {
|
if (newMlFile.faces && newMlFile.faces.length > 0) {
|
||||||
await FaceService.syncFileFaceCrops(syncContext, fileContext);
|
await FaceService.syncFileFaceCrops(syncContext, fileContext);
|
||||||
|
|
||||||
await FaceService.syncFileFaceAlignments(syncContext, fileContext);
|
await FaceService.syncFileFaceAlignments(
|
||||||
|
syncContext,
|
||||||
|
fileContext
|
||||||
|
);
|
||||||
|
|
||||||
await FaceService.syncFileFaceEmbeddings(syncContext, fileContext);
|
await FaceService.syncFileFaceEmbeddings(
|
||||||
|
syncContext,
|
||||||
|
fileContext
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
await ObjectService.syncFileObjectDetections(syncContext, fileContext);
|
await ObjectService.syncFileObjectDetections(
|
||||||
|
syncContext,
|
||||||
await TextService.syncFileTextDetections(syncContext, fileContext);
|
fileContext
|
||||||
|
);
|
||||||
|
|
||||||
|
await promiseWithTimeout(
|
||||||
|
TextService.syncFileTextDetections(syncContext, fileContext),
|
||||||
|
ML_DETECTION_TIMEOUT_MS
|
||||||
|
);
|
||||||
|
} catch (e) {
|
||||||
|
newMlFile.mlVersion = oldMlFile.mlVersion;
|
||||||
|
throw e;
|
||||||
|
} finally {
|
||||||
fileContext.tfImage && fileContext.tfImage.dispose();
|
fileContext.tfImage && fileContext.tfImage.dispose();
|
||||||
fileContext.imageBitmap && fileContext.imageBitmap.close();
|
fileContext.imageBitmap && fileContext.imageBitmap.close();
|
||||||
// console.log('8 TF Memory stats: ', tf.memory());
|
// console.log('8 TF Memory stats: ', tf.memory());
|
||||||
|
|
||||||
newMlFile.errorCount = 0;
|
newMlFile.errorCount = 0;
|
||||||
newMlFile.lastErrorMessage = undefined;
|
newMlFile.lastErrorMessage = undefined;
|
||||||
await this.persistMLFileData(syncContext, newMlFile);
|
await this.persistMLFileData(syncContext, newMlFile);
|
||||||
|
@ -424,6 +444,7 @@ class MachineLearningService {
|
||||||
// fileContext.oldMlFile,
|
// fileContext.oldMlFile,
|
||||||
// fileContext.newMlFile
|
// fileContext.newMlFile
|
||||||
// );
|
// );
|
||||||
|
}
|
||||||
|
|
||||||
return newMlFile;
|
return newMlFile;
|
||||||
}
|
}
|
||||||
|
|
|
@ -40,7 +40,10 @@ class ObjectService {
|
||||||
fileContext
|
fileContext
|
||||||
);
|
);
|
||||||
const objectDetections =
|
const objectDetections =
|
||||||
await syncContext.objectDetectionService.detectObjects(imageBitmap);
|
await syncContext.objectDetectionService.detectObjects(
|
||||||
|
imageBitmap,
|
||||||
|
syncContext.config.objectDetection.minScore
|
||||||
|
);
|
||||||
// console.log('3 TF Memory stats: ', tf.memory());
|
// console.log('3 TF Memory stats: ', tf.memory());
|
||||||
// TODO: reenable faces filtering based on width
|
// TODO: reenable faces filtering based on width
|
||||||
const detectedObjects = objectDetections?.map((detection) => {
|
const detectedObjects = objectDetections?.map((detection) => {
|
||||||
|
|
|
@ -40,10 +40,17 @@ class SSDMobileNetV2 implements ObjectDetectionService {
|
||||||
return this.ssdMobileNetV2Model;
|
return this.ssdMobileNetV2Model;
|
||||||
}
|
}
|
||||||
|
|
||||||
public async detectObjects(image: ImageBitmap): Promise<ObjectDetection[]> {
|
public async detectObjects(
|
||||||
|
image: ImageBitmap,
|
||||||
|
minScore?: number
|
||||||
|
): Promise<ObjectDetection[]> {
|
||||||
const ssdMobileNetV2Model = await this.getSSDMobileNetV2Model();
|
const ssdMobileNetV2Model = await this.getSSDMobileNetV2Model();
|
||||||
const tfImage = tf.browser.fromPixels(image);
|
const tfImage = tf.browser.fromPixels(image);
|
||||||
const detections = await ssdMobileNetV2Model.detect(tfImage);
|
const detections = await ssdMobileNetV2Model.detect(
|
||||||
|
tfImage,
|
||||||
|
undefined,
|
||||||
|
minScore
|
||||||
|
);
|
||||||
return detections;
|
return detections;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -42,12 +42,15 @@ class TextService {
|
||||||
)
|
)
|
||||||
);
|
);
|
||||||
|
|
||||||
const detectedText: DetectedText[] = textDetections.data.words.map(
|
const detectedText: DetectedText[] = textDetections.data.words
|
||||||
({ bbox, confidence, text }) => ({
|
.filter(
|
||||||
|
({ confidence }) =>
|
||||||
|
confidence >= syncContext.config.textDetection.minAccuracy
|
||||||
|
)
|
||||||
|
.map(({ bbox, confidence, text }) => ({
|
||||||
fileID: fileContext.enteFile.id,
|
fileID: fileContext.enteFile.id,
|
||||||
detection: { bbox, confidence, word: text },
|
detection: { bbox, confidence, word: text },
|
||||||
})
|
}));
|
||||||
);
|
|
||||||
newMlFile.text = detectedText;
|
newMlFile.text = detectedText;
|
||||||
console.log(
|
console.log(
|
||||||
'[MLService] Detected text: ',
|
'[MLService] Detected text: ',
|
||||||
|
|
|
@ -246,10 +246,12 @@ export interface FaceDetectionConfig {
|
||||||
|
|
||||||
export interface ObjectDetectionConfig {
|
export interface ObjectDetectionConfig {
|
||||||
method: ObjectDetectionMethod;
|
method: ObjectDetectionMethod;
|
||||||
|
minScore: number;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface TextDetectionConfig {
|
export interface TextDetectionConfig {
|
||||||
method: TextDetectionMethod;
|
method: TextDetectionMethod;
|
||||||
|
minAccuracy: number;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface FaceCropConfig {
|
export interface FaceCropConfig {
|
||||||
|
@ -295,8 +297,8 @@ export interface MLSyncConfig extends Config {
|
||||||
faceAlignment: FaceAlignmentConfig;
|
faceAlignment: FaceAlignmentConfig;
|
||||||
faceEmbedding: FaceEmbeddingConfig;
|
faceEmbedding: FaceEmbeddingConfig;
|
||||||
faceClustering: FaceClusteringConfig;
|
faceClustering: FaceClusteringConfig;
|
||||||
ObjectDetection: ObjectDetectionConfig;
|
objectDetection: ObjectDetectionConfig;
|
||||||
TextDetection: TextDetectionConfig;
|
textDetection: TextDetectionConfig;
|
||||||
tsne?: TSNEConfig;
|
tsne?: TSNEConfig;
|
||||||
mlVersion: number;
|
mlVersion: number;
|
||||||
}
|
}
|
||||||
|
@ -378,7 +380,10 @@ export interface FaceDetectionService {
|
||||||
export interface ObjectDetectionService {
|
export interface ObjectDetectionService {
|
||||||
method: Versioned<ObjectDetectionMethod>;
|
method: Versioned<ObjectDetectionMethod>;
|
||||||
// init(): Promise<void>;
|
// init(): Promise<void>;
|
||||||
detectObjects(image: ImageBitmap): Promise<ObjectDetection[]>;
|
detectObjects(
|
||||||
|
image: ImageBitmap,
|
||||||
|
minScore?: number
|
||||||
|
): Promise<ObjectDetection[]>;
|
||||||
dispose(): Promise<void>;
|
dispose(): Promise<void>;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
14
src/utils/common/promiseTimeout.ts
Normal file
14
src/utils/common/promiseTimeout.ts
Normal file
|
@ -0,0 +1,14 @@
|
||||||
|
import { CustomError } from 'utils/error';
|
||||||
|
|
||||||
|
/**
 * Races `request` against a timer and settles with whichever finishes first.
 *
 * @param request - The promise to wait for. It is NOT cancelled when the
 * timeout wins; it keeps running and its eventual result is ignored.
 * @param timeout - Maximum time to wait, in milliseconds.
 * @returns The value `request` resolved with, if it settled before the
 * timeout.
 * @throws Error with message `CustomError.WAIT_TIME_EXCEEDED` when the timer
 * fires first; otherwise rethrows whatever `request` rejected with.
 */
export const promiseWithTimeout = async <T>(
    request: Promise<T>,
    timeout: number
): Promise<T> => {
    let timeoutId: ReturnType<typeof setTimeout> | undefined;
    const rejectOnTimeout = new Promise<never>((_, reject) => {
        timeoutId = setTimeout(
            () => reject(Error(CustomError.WAIT_TIME_EXCEEDED)),
            timeout
        );
    });
    try {
        // Return the winner's value so callers can use the result of `request`.
        return await Promise.race([request, rejectOnTimeout]);
    } finally {
        // Always release the timer: otherwise every call leaves a pending
        // timeout behind that later fires a rejection nobody listens to.
        clearTimeout(timeoutId);
    }
};
|
Loading…
Reference in a new issue