From ed3b3313b48025c1629ab5cb64031d3be6286fe1 Mon Sep 17 00:00:00 2001
From: Abhinav
Date: Wed, 16 Feb 2022 14:56:14 +0530
Subject: [PATCH] refactor ml service to move faceDetection logic to separate service

---
 .../machineLearning/faceDetectionService.ts   | 192 ++++++++++++++
 .../machineLearning/machineLearningService.ts | 240 ++----------------
 src/services/machineLearning/readerService.ts |  53 ++++
 src/utils/machineLearning/index.ts            |   8 +-
 4 files changed, 266 insertions(+), 227 deletions(-)
 create mode 100644 src/services/machineLearning/faceDetectionService.ts
 create mode 100644 src/services/machineLearning/readerService.ts

diff --git a/src/services/machineLearning/faceDetectionService.ts b/src/services/machineLearning/faceDetectionService.ts
new file mode 100644
index 000000000..e6b8e835c
--- /dev/null
+++ b/src/services/machineLearning/faceDetectionService.ts
@@ -0,0 +1,192 @@
+import {
+    MLSyncContext,
+    MLSyncFileContext,
+    DetectedFace,
+    Face,
+} from 'types/machineLearning';
+import {
+    isDifferentOrOld,
+    getFaceId,
+    areFaceIdsSame,
+    extractFaceImages,
+} from 'utils/machineLearning';
+import { storeFaceCrop } from 'utils/machineLearning/faceCrop';
+import ReaderService from './readerService';
+
+class FaceDetectionService {
+    async syncFileFaceDetections(
+        syncContext: MLSyncContext,
+        fileContext: MLSyncFileContext
+    ) {
+        const { oldMlFile, newMlFile } = fileContext;
+        if (
+            !isDifferentOrOld(
+                oldMlFile?.faceDetectionMethod,
+                syncContext.faceDetectionService.method
+            ) &&
+            oldMlFile?.imageSource === syncContext.config.imageSource
+        ) {
+            newMlFile.faces = oldMlFile?.faces?.map((existingFace) => ({
+                id: existingFace.id,
+                fileId: existingFace.fileId,
+                detection: existingFace.detection,
+            }));
+
+            newMlFile.imageSource = oldMlFile.imageSource;
+            newMlFile.imageDimentions = oldMlFile.imageDimentions;
+            newMlFile.faceDetectionMethod = oldMlFile.faceDetectionMethod;
+            return;
+        }
+
+        newMlFile.faceDetectionMethod = syncContext.faceDetectionService.method;
+        fileContext.newDetection = true;
+        const imageBitmap = await ReaderService.getImageBitmap(
+            syncContext,
+            fileContext
+        );
+        const faceDetections =
+            await syncContext.faceDetectionService.detectFaces(imageBitmap);
+        // console.log('3 TF Memory stats: ', tf.memory());
+        // TODO: reenable faces filtering based on width
+        const detectedFaces = faceDetections?.map((detection) => {
+            return {
+                fileId: fileContext.enteFile.id,
+                detection,
+            } as DetectedFace;
+        });
+        newMlFile.faces = detectedFaces?.map((detectedFace) => ({
+            ...detectedFace,
+            id: getFaceId(detectedFace, newMlFile.imageDimentions),
+        }));
+        // ?.filter((f) =>
+        //     f.box.width > syncContext.config.faceDetection.minFaceSize
+        // );
+        console.log('[MLService] Detected Faces: ', newMlFile.faces?.length);
+    }
+
+    async syncFileFaceCrops(
+        syncContext: MLSyncContext,
+        fileContext: MLSyncFileContext
+    ) {
+        const { oldMlFile, newMlFile } = fileContext;
+        if (
+            // !syncContext.config.faceCrop.enabled ||
+            !fileContext.newDetection &&
+            !isDifferentOrOld(
+                oldMlFile?.faceCropMethod,
+                syncContext.faceCropService.method
+            ) &&
+            areFaceIdsSame(newMlFile.faces, oldMlFile?.faces)
+        ) {
+            for (const [index, face] of newMlFile.faces.entries()) {
+                face.crop = oldMlFile.faces[index].crop;
+            }
+            newMlFile.faceCropMethod = oldMlFile.faceCropMethod;
+            return;
+        }
+
+        const imageBitmap = await ReaderService.getImageBitmap(
+            syncContext,
+            fileContext
+        );
+        newMlFile.faceCropMethod = syncContext.faceCropService.method;
+
+        for (const face of newMlFile.faces) {
+            await this.saveFaceCrop(imageBitmap, face, syncContext);
+        }
+    }
+
+    async syncFileFaceAlignments(
+        syncContext: MLSyncContext,
+        fileContext: MLSyncFileContext
+    ) {
+        const { oldMlFile, newMlFile } = fileContext;
+        if (
+            !fileContext.newDetection &&
+            !isDifferentOrOld(
+                oldMlFile?.faceAlignmentMethod,
+                syncContext.faceAlignmentService.method
+            ) &&
+            areFaceIdsSame(newMlFile.faces, oldMlFile?.faces)
+        ) {
+            for (const [index, face] of newMlFile.faces.entries()) {
+                face.alignment = oldMlFile.faces[index].alignment;
+            }
+            newMlFile.faceAlignmentMethod = oldMlFile.faceAlignmentMethod;
+            return;
+        }
+
+        newMlFile.faceAlignmentMethod = syncContext.faceAlignmentService.method;
+        fileContext.newAlignment = true;
+        for (const face of newMlFile.faces) {
+            face.alignment = syncContext.faceAlignmentService.getFaceAlignment(
+                face.detection
+            );
+        }
+        console.log('[MLService] alignedFaces: ', newMlFile.faces?.length);
+        // console.log('4 TF Memory stats: ', tf.memory());
+    }
+
+    async syncFileFaceEmbeddings(
+        syncContext: MLSyncContext,
+        fileContext: MLSyncFileContext
+    ) {
+        const { oldMlFile, newMlFile } = fileContext;
+        if (
+            !fileContext.newAlignment &&
+            !isDifferentOrOld(
+                oldMlFile?.faceEmbeddingMethod,
+                syncContext.faceEmbeddingService.method
+            ) &&
+            areFaceIdsSame(newMlFile.faces, oldMlFile?.faces)
+        ) {
+            for (const [index, face] of newMlFile.faces.entries()) {
+                face.embedding = oldMlFile.faces[index].embedding;
+            }
+            newMlFile.faceEmbeddingMethod = oldMlFile.faceEmbeddingMethod;
+            return;
+        }
+
+        newMlFile.faceEmbeddingMethod = syncContext.faceEmbeddingService.method;
+        // TODO: when not storing face crops, image will be needed to extract faces
+        // fileContext.imageBitmap ||
+        //     (await this.getImageBitmap(syncContext, fileContext));
+        const faceImages = await extractFaceImages(
+            newMlFile.faces,
+            syncContext.faceEmbeddingService.faceSize
+        );
+
+        const embeddings =
+            await syncContext.faceEmbeddingService.getFaceEmbeddings(
+                faceImages
+            );
+        faceImages.forEach((faceImage) => faceImage.close());
+        newMlFile.faces.forEach((f, i) => (f.embedding = embeddings[i]));
+
+        console.log(
+            '[MLService] facesWithEmbeddings: ',
+            newMlFile.faces.length
+        );
+        // console.log('5 TF Memory stats: ', tf.memory());
+    }
+
+    private async saveFaceCrop(
+        imageBitmap: ImageBitmap,
+        face: Face,
+        syncContext: MLSyncContext
+    ) {
+        const faceCrop = await syncContext.faceCropService.getFaceCrop(
+            imageBitmap,
+            face.detection,
+            syncContext.config.faceCrop
+        );
+        face.crop = await storeFaceCrop(
+            face.id,
+            faceCrop,
+            syncContext.config.faceCrop.blobOptions
+        );
+        faceCrop.image.close();
+    }
+}
+
+export default new FaceDetectionService();
diff --git a/src/services/machineLearning/machineLearningService.ts b/src/services/machineLearning/machineLearningService.ts
index 6505efbee..229794ce9 100644
--- a/src/services/machineLearning/machineLearningService.ts
+++ b/src/services/machineLearning/machineLearningService.ts
@@ -1,6 +1,5 @@
 import { getLocalFiles } from 'services/fileService';
 import { EnteFile } from 'types/file';
-import { FILE_TYPE } from 'constants/file';
 
 import * as tf from '@tensorflow/tfjs-core';
 import '@tensorflow/tfjs-backend-webgl';
@@ -9,7 +8,6 @@ import '@tensorflow/tfjs-backend-webgl';
 // import '@tensorflow/tfjs-backend-cpu';
 
 import {
-    DetectedFace,
     Face,
     MlFileData,
     MLSyncContext,
@@ -24,23 +22,18 @@ import { toTSNE } from 'utils/machineLearning/visualization';
 //     mlFilesStore
 // } from 'utils/storage/mlStorage';
 import {
-    areFaceIdsSame,
-    extractFaceImages,
     findFirstIfSorted,
     getAllFacesFromMap,
-    getFaceId,
     getLocalFile,
-    getLocalFileImageBitmap,
     getOriginalImageBitmap,
-    getThumbnailImageBitmap,
     isDifferentOrOld,
 } from 'utils/machineLearning';
 import { MLFactory } from './machineLearningFactory';
 import mlIDbStorage from 'utils/storage/mlIDbStorage';
-import { storeFaceCrop } from 'utils/machineLearning/faceCrop';
 import { getMLSyncConfig } from 'utils/machineLearning/config';
 import { CustomError, parseServerError } from 'utils/error';
 import { MAX_ML_SYNC_ERROR_COUNT } from 'constants/machineLearning/config';
+import FaceDetectionService from './faceDetectionService';
 
 class MachineLearningService {
     private initialized = false;
@@ -409,14 +402,26 @@
             newMlFile.mlVersion = fileContext.oldMlFile.mlVersion;
         }
 
-        await this.syncFileFaceDetections(syncContext, fileContext);
+        await FaceDetectionService.syncFileFaceDetections(
+            syncContext,
+            fileContext
+        );
 
         if (newMlFile.faces && newMlFile.faces.length > 0) {
-            await this.syncFileFaceCrops(syncContext, fileContext);
+            await FaceDetectionService.syncFileFaceCrops(
+                syncContext,
+                fileContext
+            );
 
-            await this.syncFileFaceAlignments(syncContext, fileContext);
+            await FaceDetectionService.syncFileFaceAlignments(
+                syncContext,
+                fileContext
+            );
 
-            await this.syncFileFaceEmbeddings(syncContext, fileContext);
+            await FaceDetectionService.syncFileFaceEmbeddings(
+                syncContext,
+                fileContext
+            );
         }
 
         fileContext.tfImage && fileContext.tfImage.dispose();
@@ -436,217 +441,6 @@
         return newMlFile;
     }
 
-    private async getImageBitmap(
-        syncContext: MLSyncContext,
-        fileContext: MLSyncFileContext
-    ) {
-        if (fileContext.imageBitmap) {
-            return fileContext.imageBitmap;
-        }
-        // console.log('1 TF Memory stats: ', tf.memory());
-        if (fileContext.localFile) {
-            if (fileContext.enteFile.metadata.fileType !== FILE_TYPE.IMAGE) {
-                throw new Error('Local file of only image type is supported');
-            }
-            fileContext.imageBitmap = await getLocalFileImageBitmap(
-                fileContext.enteFile,
-                fileContext.localFile,
-                () => syncContext.getEnteWorker(fileContext.enteFile.id)
-            );
-        } else if (
-            syncContext.config.imageSource === 'Original' &&
-            [FILE_TYPE.IMAGE, FILE_TYPE.LIVE_PHOTO].includes(
-                fileContext.enteFile.metadata.fileType
-            )
-        ) {
-            fileContext.imageBitmap = await getOriginalImageBitmap(
-                fileContext.enteFile,
-                syncContext.token,
-                await syncContext.getEnteWorker(fileContext.enteFile.id)
-            );
-        } else {
-            fileContext.imageBitmap = await getThumbnailImageBitmap(
-                fileContext.enteFile,
-                syncContext.token
-            );
-        }
-
-        fileContext.newMlFile.imageSource = syncContext.config.imageSource;
-        const { width, height } = fileContext.imageBitmap;
-        fileContext.newMlFile.imageDimentions = { width, height };
-        // console.log('2 TF Memory stats: ', tf.memory());
-
-        return fileContext.imageBitmap;
-    }
-
-    private async syncFileFaceDetections(
-        syncContext: MLSyncContext,
-        fileContext: MLSyncFileContext
-    ) {
-        const { oldMlFile, newMlFile } = fileContext;
-        if (
-            !isDifferentOrOld(
-                oldMlFile?.faceDetectionMethod,
-                syncContext.faceDetectionService.method
-            ) &&
-            oldMlFile?.imageSource === syncContext.config.imageSource
-        ) {
-            newMlFile.faces = oldMlFile?.faces?.map((existingFace) => ({
-                id: existingFace.id,
-                fileId: existingFace.fileId,
-                detection: existingFace.detection,
-            }));
-
-            newMlFile.imageSource = oldMlFile.imageSource;
-            newMlFile.imageDimentions = oldMlFile.imageDimentions;
-            newMlFile.faceDetectionMethod = oldMlFile.faceDetectionMethod;
-            return;
-        }
-
-        newMlFile.faceDetectionMethod = syncContext.faceDetectionService.method;
-        fileContext.newDetection = true;
-        const imageBitmap = await this.getImageBitmap(syncContext, fileContext);
-        const faceDetections =
-            await syncContext.faceDetectionService.detectFaces(imageBitmap);
-        // console.log('3 TF Memory stats: ', tf.memory());
-        // TODO: reenable faces filtering based on width
-        const detectedFaces = faceDetections?.map((detection) => {
-            return {
-                fileId: fileContext.enteFile.id,
-                detection,
-            } as DetectedFace;
-        });
-        newMlFile.faces = detectedFaces?.map((detectedFace) => ({
-            ...detectedFace,
-            id: getFaceId(detectedFace, newMlFile.imageDimentions),
-        }));
-        // ?.filter((f) =>
-        //     f.box.width > syncContext.config.faceDetection.minFaceSize
-        // );
-        console.log('[MLService] Detected Faces: ', newMlFile.faces?.length);
-    }
-
-    private async saveFaceCrop(
-        imageBitmap: ImageBitmap,
-        face: Face,
-        syncContext: MLSyncContext
-    ) {
-        const faceCrop = await syncContext.faceCropService.getFaceCrop(
-            imageBitmap,
-            face.detection,
-            syncContext.config.faceCrop
-        );
-        face.crop = await storeFaceCrop(
-            face.id,
-            faceCrop,
-            syncContext.config.faceCrop.blobOptions
-        );
-        faceCrop.image.close();
-    }
-
-    private async syncFileFaceCrops(
-        syncContext: MLSyncContext,
-        fileContext: MLSyncFileContext
-    ) {
-        const { oldMlFile, newMlFile } = fileContext;
-        if (
-            // !syncContext.config.faceCrop.enabled ||
-            !fileContext.newDetection &&
-            !isDifferentOrOld(
-                oldMlFile?.faceCropMethod,
-                syncContext.faceCropService.method
-            ) &&
-            areFaceIdsSame(newMlFile.faces, oldMlFile?.faces)
-        ) {
-            for (const [index, face] of newMlFile.faces.entries()) {
-                face.crop = oldMlFile.faces[index].crop;
-            }
-            newMlFile.faceCropMethod = oldMlFile.faceCropMethod;
-            return;
-        }
-
-        const imageBitmap = await this.getImageBitmap(syncContext, fileContext);
-        newMlFile.faceCropMethod = syncContext.faceCropService.method;
-
-        for (const face of newMlFile.faces) {
-            await this.saveFaceCrop(imageBitmap, face, syncContext);
-        }
-    }
-
-    private async syncFileFaceAlignments(
-        syncContext: MLSyncContext,
-        fileContext: MLSyncFileContext
-    ) {
-        const { oldMlFile, newMlFile } = fileContext;
-        if (
-            !fileContext.newDetection &&
-            !isDifferentOrOld(
-                oldMlFile?.faceAlignmentMethod,
-                syncContext.faceAlignmentService.method
-            ) &&
-            areFaceIdsSame(newMlFile.faces, oldMlFile?.faces)
-        ) {
-            for (const [index, face] of newMlFile.faces.entries()) {
-                face.alignment = oldMlFile.faces[index].alignment;
-            }
-            newMlFile.faceAlignmentMethod = oldMlFile.faceAlignmentMethod;
-            return;
-        }
-
-        newMlFile.faceAlignmentMethod = syncContext.faceAlignmentService.method;
-        fileContext.newAlignment = true;
-        for (const face of newMlFile.faces) {
-            face.alignment = syncContext.faceAlignmentService.getFaceAlignment(
-                face.detection
-            );
-        }
-        console.log('[MLService] alignedFaces: ', newMlFile.faces?.length);
-        // console.log('4 TF Memory stats: ', tf.memory());
-    }
-
-    private async syncFileFaceEmbeddings(
-        syncContext: MLSyncContext,
-        fileContext: MLSyncFileContext
-    ) {
-        const { oldMlFile, newMlFile } = fileContext;
-        if (
-            !fileContext.newAlignment &&
-            !isDifferentOrOld(
-                oldMlFile?.faceEmbeddingMethod,
-                syncContext.faceEmbeddingService.method
-            ) &&
-            areFaceIdsSame(newMlFile.faces, oldMlFile?.faces)
-        ) {
-            for (const [index, face] of newMlFile.faces.entries()) {
-                face.embedding = oldMlFile.faces[index].embedding;
-            }
-            newMlFile.faceEmbeddingMethod = oldMlFile.faceEmbeddingMethod;
-            return;
-        }
-
-        newMlFile.faceEmbeddingMethod = syncContext.faceEmbeddingService.method;
-        // TODO: when not storing face crops, image will be needed to extract faces
-        // fileContext.imageBitmap ||
-        //     (await this.getImageBitmap(syncContext, fileContext));
-        const faceImages = await extractFaceImages(
-            newMlFile.faces,
-            syncContext.faceEmbeddingService.faceSize
-        );
-
-        const embeddings =
-            await syncContext.faceEmbeddingService.getFaceEmbeddings(
-                faceImages
-            );
-        faceImages.forEach((faceImage) => faceImage.close());
-        newMlFile.faces.forEach((f, i) => (f.embedding = embeddings[i]));
-
-        console.log(
-            '[MLService] facesWithEmbeddings: ',
-            newMlFile.faces.length
-        );
-        // console.log('5 TF Memory stats: ', tf.memory());
-    }
-
     public async init() {
         if (this.initialized) {
             return;
diff --git a/src/services/machineLearning/readerService.ts b/src/services/machineLearning/readerService.ts
new file mode 100644
index 000000000..a246a1dbd
--- /dev/null
+++ b/src/services/machineLearning/readerService.ts
@@ -0,0 +1,53 @@
+import { FILE_TYPE } from 'constants/file';
+import { MLSyncContext, MLSyncFileContext } from 'types/machineLearning';
+import {
+    getLocalFileImageBitmap,
+    getOriginalImageBitmap,
+    getThumbnailImageBitmap,
+} from 'utils/machineLearning';
+
+class ReaderService {
+    async getImageBitmap(
+        syncContext: MLSyncContext,
+        fileContext: MLSyncFileContext
+    ) {
+        if (fileContext.imageBitmap) {
+            return fileContext.imageBitmap;
+        }
+        // console.log('1 TF Memory stats: ', tf.memory());
+        if (fileContext.localFile) {
+            if (fileContext.enteFile.metadata.fileType !== FILE_TYPE.IMAGE) {
+                throw new Error('Local file of only image type is supported');
+            }
+            fileContext.imageBitmap = await getLocalFileImageBitmap(
+                fileContext.enteFile,
+                fileContext.localFile,
+                () => syncContext.getEnteWorker(fileContext.enteFile.id)
+            );
+        } else if (
+            syncContext.config.imageSource === 'Original' &&
+            [FILE_TYPE.IMAGE, FILE_TYPE.LIVE_PHOTO].includes(
+                fileContext.enteFile.metadata.fileType
+            )
+        ) {
+            fileContext.imageBitmap = await getOriginalImageBitmap(
+                fileContext.enteFile,
+                syncContext.token,
+                await syncContext.getEnteWorker(fileContext.enteFile.id)
+            );
+        } else {
+            fileContext.imageBitmap = await getThumbnailImageBitmap(
+                fileContext.enteFile,
+                syncContext.token
+            );
+        }
+
+        fileContext.newMlFile.imageSource = syncContext.config.imageSource;
+        const { width, height } = fileContext.imageBitmap;
+        fileContext.newMlFile.imageDimentions = { width, height };
+        // console.log('2 TF Memory stats: ', tf.memory());
+
+        return fileContext.imageBitmap;
+    }
+}
+export default new ReaderService();
diff --git a/src/utils/machineLearning/index.ts b/src/utils/machineLearning/index.ts
index c6e5a45d6..7bad7abe2 100644
--- a/src/utils/machineLearning/index.ts
+++ b/src/utils/machineLearning/index.ts
@@ -278,7 +278,7 @@ export async function getTFImage(blob): Promise<tf.Tensor3D> {
     return tfImage;
 }
 
-export async function getImageBitmap(blob: Blob): Promise<ImageBitmap> {
+export async function getImageBlobBitmap(blob: Blob): Promise<ImageBitmap> {
     return await createImageBitmap(blob);
 }
 
@@ -352,7 +352,7 @@
     }
     console.log('[MLService] Got file: ', file.id.toString());
 
-    return getImageBitmap(fileBlob);
+    return getImageBlobBitmap(fileBlob);
 }
 
 export async function getThumbnailImageBitmap(file: EnteFile, token: string) {
@@ -365,7 +365,7 @@
 
     const thumbFile = await fetch(fileUrl);
 
-    return getImageBitmap(await thumbFile.blob());
+    return getImageBlobBitmap(await thumbFile.blob());
 }
 
 export async function getLocalFileImageBitmap(
@@ -378,7 +378,7 @@
         const enteWorker = await enteWorkerProvider();
         fileBlob = await convertForPreview(enteFile, fileBlob, enteWorker);
     }
-    return getImageBitmap(fileBlob);
+    return getImageBlobBitmap(fileBlob);
 }
 
 export async function getPeopleList(file: EnteFile): Promise<Array<Person>> {
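
Reviewer note (illustration only, not part of the patch): after this refactor,
the per-file face pipeline in MachineLearningService.syncFile reduces to plain
delegation. A minimal TypeScript sketch of the resulting call flow, using only
names introduced in the diff above:

    // sketch: per-file face pipeline after the refactor
    await FaceDetectionService.syncFileFaceDetections(syncContext, fileContext);
    if (newMlFile.faces && newMlFile.faces.length > 0) {
        // crops reuse the bitmap cached by ReaderService.getImageBitmap
        await FaceDetectionService.syncFileFaceCrops(syncContext, fileContext);
        // alignment is geometry-only; it derives from face.detection
        await FaceDetectionService.syncFileFaceAlignments(syncContext, fileContext);
        // embeddings run on face images produced by extractFaceImages
        await FaceDetectionService.syncFileFaceEmbeddings(syncContext, fileContext);
    }

ReaderService.getImageBitmap memoizes the decoded image on
fileContext.imageBitmap, so detection and cropping together cost at most one
decode per file.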