add MLFactory getObjectDetectionService logic
This commit is contained in:
parent
e0f136a19e
commit
ab4b7e3247
|
@ -15,6 +15,8 @@ import {
|
||||||
ClusteringMethod,
|
ClusteringMethod,
|
||||||
ClusteringService,
|
ClusteringService,
|
||||||
MLLibraryData,
|
MLLibraryData,
|
||||||
|
ObjectDetectionService,
|
||||||
|
ObjectDetectionMethod,
|
||||||
} from 'types/machineLearning';
|
} from 'types/machineLearning';
|
||||||
import { CONCURRENCY } from 'utils/common/concurrency';
|
import { CONCURRENCY } from 'utils/common/concurrency';
|
||||||
import { ComlinkWorker, getDedicatedCryptoWorker } from 'utils/crypto';
|
import { ComlinkWorker, getDedicatedCryptoWorker } from 'utils/crypto';
|
||||||
|
@ -25,6 +27,7 @@ import hdbscanClusteringService from './hdbscanClusteringService';
|
||||||
import blazeFaceDetectionService from './blazeFaceDetectionService';
|
import blazeFaceDetectionService from './blazeFaceDetectionService';
|
||||||
import mobileFaceNetEmbeddingService from './mobileFaceNetEmbeddingService';
|
import mobileFaceNetEmbeddingService from './mobileFaceNetEmbeddingService';
|
||||||
import dbscanClusteringService from './dbscanClusteringService';
|
import dbscanClusteringService from './dbscanClusteringService';
|
||||||
|
import ssdMobileNetV2Service from './ssdMobileNetV2Service';
|
||||||
|
|
||||||
export class MLFactory {
|
export class MLFactory {
|
||||||
public static getFaceDetectionService(
|
public static getFaceDetectionService(
|
||||||
|
@ -37,6 +40,16 @@ export class MLFactory {
|
||||||
throw Error('Unknon face detection method: ' + method);
|
throw Error('Unknon face detection method: ' + method);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static getObjectDetectionService(
|
||||||
|
method: ObjectDetectionMethod
|
||||||
|
): ObjectDetectionService {
|
||||||
|
if (method === 'SSDMobileNetV2') {
|
||||||
|
return ssdMobileNetV2Service;
|
||||||
|
}
|
||||||
|
|
||||||
|
throw Error('Unknown object detection method: ' + method);
|
||||||
|
}
|
||||||
|
|
||||||
public static getFaceCropService(method: FaceCropMethod) {
|
public static getFaceCropService(method: FaceCropMethod) {
|
||||||
if (method === 'ArcFace') {
|
if (method === 'ArcFace') {
|
||||||
return arcfaceCropService;
|
return arcfaceCropService;
|
||||||
|
@ -97,6 +110,7 @@ export class LocalMLSyncContext implements MLSyncContext {
|
||||||
public faceAlignmentService: FaceAlignmentService;
|
public faceAlignmentService: FaceAlignmentService;
|
||||||
public faceEmbeddingService: FaceEmbeddingService;
|
public faceEmbeddingService: FaceEmbeddingService;
|
||||||
public faceClusteringService: ClusteringService;
|
public faceClusteringService: ClusteringService;
|
||||||
|
public objectDetectionService: ObjectDetectionService;
|
||||||
|
|
||||||
public localFilesMap: Map<number, EnteFile>;
|
public localFilesMap: Map<number, EnteFile>;
|
||||||
public outOfSyncFiles: EnteFile[];
|
public outOfSyncFiles: EnteFile[];
|
||||||
|
@ -143,6 +157,10 @@ export class LocalMLSyncContext implements MLSyncContext {
|
||||||
this.config.faceClustering.method
|
this.config.faceClustering.method
|
||||||
);
|
);
|
||||||
|
|
||||||
|
this.objectDetectionService = MLFactory.getObjectDetectionService(
|
||||||
|
this.config.ObjectDetection.method
|
||||||
|
);
|
||||||
|
|
||||||
this.outOfSyncFiles = [];
|
this.outOfSyncFiles = [];
|
||||||
this.nSyncedFiles = 0;
|
this.nSyncedFiles = 0;
|
||||||
this.nSyncedFaces = 0;
|
this.nSyncedFaces = 0;
|
||||||
|
|
Loading…
Reference in a new issue