diff --git a/src/services/machineLearning/machineLearningFactory.ts b/src/services/machineLearning/machineLearningFactory.ts index 9bccc6cb9..75f432fd5 100644 --- a/src/services/machineLearning/machineLearningFactory.ts +++ b/src/services/machineLearning/machineLearningFactory.ts @@ -15,6 +15,8 @@ import { ClusteringMethod, ClusteringService, MLLibraryData, + ObjectDetectionService, + ObjectDetectionMethod, } from 'types/machineLearning'; import { CONCURRENCY } from 'utils/common/concurrency'; import { ComlinkWorker, getDedicatedCryptoWorker } from 'utils/crypto'; @@ -25,6 +27,7 @@ import hdbscanClusteringService from './hdbscanClusteringService'; import blazeFaceDetectionService from './blazeFaceDetectionService'; import mobileFaceNetEmbeddingService from './mobileFaceNetEmbeddingService'; import dbscanClusteringService from './dbscanClusteringService'; +import ssdMobileNetV2Service from './ssdMobileNetV2Service'; export class MLFactory { public static getFaceDetectionService( @@ -37,6 +40,16 @@ export class MLFactory { throw Error('Unknon face detection method: ' + method); } + public static getObjectDetectionService( + method: ObjectDetectionMethod + ): ObjectDetectionService { + if (method === 'SSDMobileNetV2') { + return ssdMobileNetV2Service; + } + + throw Error('Unknown object detection method: ' + method); + } + public static getFaceCropService(method: FaceCropMethod) { if (method === 'ArcFace') { return arcfaceCropService; @@ -97,6 +110,7 @@ export class LocalMLSyncContext implements MLSyncContext { public faceAlignmentService: FaceAlignmentService; public faceEmbeddingService: FaceEmbeddingService; public faceClusteringService: ClusteringService; + public objectDetectionService: ObjectDetectionService; public localFilesMap: Map; public outOfSyncFiles: EnteFile[]; @@ -143,6 +157,10 @@ export class LocalMLSyncContext implements MLSyncContext { this.config.faceClustering.method ); + this.objectDetectionService = MLFactory.getObjectDetectionService( + this.config.ObjectDetection.method + ); + this.outOfSyncFiles = []; this.nSyncedFiles = 0; this.nSyncedFaces = 0;