refactor ml service to move faceDetection logic to separate service

Abhinav 2022-02-16 14:56:14 +05:30
parent 8a3b08f4a7
commit ed3b3313b4
4 changed files with 266 additions and 227 deletions

services/machineLearning/faceDetectionService.ts (new file)

@@ -0,0 +1,192 @@
import {
    MLSyncContext,
    MLSyncFileContext,
    DetectedFace,
    Face,
} from 'types/machineLearning';
import {
    isDifferentOrOld,
    getFaceId,
    areFaceIdsSame,
    extractFaceImages,
} from 'utils/machineLearning';
import { storeFaceCrop } from 'utils/machineLearning/faceCrop';
import ReaderService from './readerService';

class FaceDetectionService {
    async syncFileFaceDetections(
        syncContext: MLSyncContext,
        fileContext: MLSyncFileContext
    ) {
        const { oldMlFile, newMlFile } = fileContext;
        if (
            !isDifferentOrOld(
                oldMlFile?.faceDetectionMethod,
                syncContext.faceDetectionService.method
            ) &&
            oldMlFile?.imageSource === syncContext.config.imageSource
        ) {
            newMlFile.faces = oldMlFile?.faces?.map((existingFace) => ({
                id: existingFace.id,
                fileId: existingFace.fileId,
                detection: existingFace.detection,
            }));
            newMlFile.imageSource = oldMlFile.imageSource;
            newMlFile.imageDimentions = oldMlFile.imageDimentions;
            newMlFile.faceDetectionMethod = oldMlFile.faceDetectionMethod;
            return;
        }

        newMlFile.faceDetectionMethod = syncContext.faceDetectionService.method;
        fileContext.newDetection = true;
        const imageBitmap = await ReaderService.getImageBitmap(
            syncContext,
            fileContext
        );
        const faceDetections =
            await syncContext.faceDetectionService.detectFaces(imageBitmap);
        // console.log('3 TF Memory stats: ', tf.memory());
        // TODO: reenable faces filtering based on width
        const detectedFaces = faceDetections?.map((detection) => {
            return {
                fileId: fileContext.enteFile.id,
                detection,
            } as DetectedFace;
        });
        newMlFile.faces = detectedFaces?.map((detectedFace) => ({
            ...detectedFace,
            id: getFaceId(detectedFace, newMlFile.imageDimentions),
        }));
        // ?.filter((f) =>
        //     f.box.width > syncContext.config.faceDetection.minFaceSize
        // );
        console.log('[MLService] Detected Faces: ', newMlFile.faces?.length);
    }

    async syncFileFaceCrops(
        syncContext: MLSyncContext,
        fileContext: MLSyncFileContext
    ) {
        const { oldMlFile, newMlFile } = fileContext;
        if (
            // !syncContext.config.faceCrop.enabled ||
            !fileContext.newDetection &&
            !isDifferentOrOld(
                oldMlFile?.faceCropMethod,
                syncContext.faceCropService.method
            ) &&
            areFaceIdsSame(newMlFile.faces, oldMlFile?.faces)
        ) {
            for (const [index, face] of newMlFile.faces.entries()) {
                face.crop = oldMlFile.faces[index].crop;
            }
            newMlFile.faceCropMethod = oldMlFile.faceCropMethod;
            return;
        }

        const imageBitmap = await ReaderService.getImageBitmap(
            syncContext,
            fileContext
        );
        newMlFile.faceCropMethod = syncContext.faceCropService.method;
        for (const face of newMlFile.faces) {
            await this.saveFaceCrop(imageBitmap, face, syncContext);
        }
    }

    async syncFileFaceAlignments(
        syncContext: MLSyncContext,
        fileContext: MLSyncFileContext
    ) {
        const { oldMlFile, newMlFile } = fileContext;
        if (
            !fileContext.newDetection &&
            !isDifferentOrOld(
                oldMlFile?.faceAlignmentMethod,
                syncContext.faceAlignmentService.method
            ) &&
            areFaceIdsSame(newMlFile.faces, oldMlFile?.faces)
        ) {
            for (const [index, face] of newMlFile.faces.entries()) {
                face.alignment = oldMlFile.faces[index].alignment;
            }
            newMlFile.faceAlignmentMethod = oldMlFile.faceAlignmentMethod;
            return;
        }

        newMlFile.faceAlignmentMethod = syncContext.faceAlignmentService.method;
        fileContext.newAlignment = true;
        for (const face of newMlFile.faces) {
            face.alignment = syncContext.faceAlignmentService.getFaceAlignment(
                face.detection
            );
        }
        console.log('[MLService] alignedFaces: ', newMlFile.faces?.length);
        // console.log('4 TF Memory stats: ', tf.memory());
    }

    async syncFileFaceEmbeddings(
        syncContext: MLSyncContext,
        fileContext: MLSyncFileContext
    ) {
        const { oldMlFile, newMlFile } = fileContext;
        if (
            !fileContext.newAlignment &&
            !isDifferentOrOld(
                oldMlFile?.faceEmbeddingMethod,
                syncContext.faceEmbeddingService.method
            ) &&
            areFaceIdsSame(newMlFile.faces, oldMlFile?.faces)
        ) {
            for (const [index, face] of newMlFile.faces.entries()) {
                face.embedding = oldMlFile.faces[index].embedding;
            }
            newMlFile.faceEmbeddingMethod = oldMlFile.faceEmbeddingMethod;
            return;
        }

        newMlFile.faceEmbeddingMethod = syncContext.faceEmbeddingService.method;
        // TODO: when not storing face crops, image will be needed to extract faces
        // fileContext.imageBitmap ||
        //     (await this.getImageBitmap(syncContext, fileContext));
        const faceImages = await extractFaceImages(
            newMlFile.faces,
            syncContext.faceEmbeddingService.faceSize
        );
        const embeddings =
            await syncContext.faceEmbeddingService.getFaceEmbeddings(
                faceImages
            );
        faceImages.forEach((faceImage) => faceImage.close());
        newMlFile.faces.forEach((f, i) => (f.embedding = embeddings[i]));
        console.log(
            '[MLService] facesWithEmbeddings: ',
            newMlFile.faces.length
        );
        // console.log('5 TF Memory stats: ', tf.memory());
    }

    private async saveFaceCrop(
        imageBitmap: ImageBitmap,
        face: Face,
        syncContext: MLSyncContext
    ) {
        const faceCrop = await syncContext.faceCropService.getFaceCrop(
            imageBitmap,
            face.detection,
            syncContext.config.faceCrop
        );
        face.crop = await storeFaceCrop(
            face.id,
            faceCrop,
            syncContext.config.faceCrop.blobOptions
        );
        faceCrop.image.close();
    }
}

export default new FaceDetectionService();
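For orientation: every sync method above gates its cache-reuse path on two helpers imported from utils/machineLearning, whose implementations are not part of this diff. The following is a minimal sketch of plausible versions, assuming methods are stored as versioned { value, version } pairs; the exact Versioned shape is an assumption, not confirmed by this commit:

interface Versioned<T> {
    value: T;
    version: number;
}

function isDifferentOrOld(
    method: Versioned<string> | undefined,
    thanMethod: Versioned<string>
): boolean {
    // Recompute when there is no previous method, the method changed,
    // or the stored result came from an older version of the service.
    return (
        !method ||
        method.value !== thanMethod.value ||
        method.version < thanMethod.version
    );
}

function areFaceIdsSame(
    ownFaces: Array<{ id: string }> | undefined,
    otherFaces: Array<{ id: string }> | undefined
): boolean {
    // Cached per-face results (crops, alignments, embeddings) are only
    // reusable when the two face lists line up id-for-id.
    if (!ownFaces || !otherFaces) return false;
    return (
        ownFaces.length === otherFaces.length &&
        ownFaces.every((face, index) => face.id === otherFaces[index].id)
    );
}

Under that reading, bumping a service's version invalidates cached results even when the method name is unchanged.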

services/machineLearning/machineLearningService.ts

@@ -1,6 +1,5 @@
import { getLocalFiles } from 'services/fileService';
import { EnteFile } from 'types/file';
-import { FILE_TYPE } from 'constants/file';
import * as tf from '@tensorflow/tfjs-core';
import '@tensorflow/tfjs-backend-webgl';
@@ -9,7 +8,6 @@ import '@tensorflow/tfjs-backend-webgl';
// import '@tensorflow/tfjs-backend-cpu';
import {
    DetectedFace,
-    Face,
    MlFileData,
    MLSyncContext,
@@ -24,23 +22,18 @@ import { toTSNE } from 'utils/machineLearning/visualization';
//     mlFilesStore
// } from 'utils/storage/mlStorage';
import {
-    areFaceIdsSame,
-    extractFaceImages,
    findFirstIfSorted,
    getAllFacesFromMap,
-    getFaceId,
    getLocalFile,
-    getLocalFileImageBitmap,
-    getOriginalImageBitmap,
-    getThumbnailImageBitmap,
-    isDifferentOrOld,
} from 'utils/machineLearning';
import { MLFactory } from './machineLearningFactory';
import mlIDbStorage from 'utils/storage/mlIDbStorage';
-import { storeFaceCrop } from 'utils/machineLearning/faceCrop';
import { getMLSyncConfig } from 'utils/machineLearning/config';
import { CustomError, parseServerError } from 'utils/error';
import { MAX_ML_SYNC_ERROR_COUNT } from 'constants/machineLearning/config';
+import FaceDetectionService from './faceDetectionService';
class MachineLearningService {
    private initialized = false;
@@ -409,14 +402,26 @@ class MachineLearningService {
            newMlFile.mlVersion = fileContext.oldMlFile.mlVersion;
        }

-        await this.syncFileFaceDetections(syncContext, fileContext);
+        await FaceDetectionService.syncFileFaceDetections(
+            syncContext,
+            fileContext
+        );

        if (newMlFile.faces && newMlFile.faces.length > 0) {
-            await this.syncFileFaceCrops(syncContext, fileContext);
+            await FaceDetectionService.syncFileFaceCrops(
+                syncContext,
+                fileContext
+            );

-            await this.syncFileFaceAlignments(syncContext, fileContext);
+            await FaceDetectionService.syncFileFaceAlignments(
+                syncContext,
+                fileContext
+            );

-            await this.syncFileFaceEmbeddings(syncContext, fileContext);
+            await FaceDetectionService.syncFileFaceEmbeddings(
+                syncContext,
+                fileContext
+            );
        }

        fileContext.tfImage && fileContext.tfImage.dispose();
@@ -436,217 +441,6 @@ class MachineLearningService {

        return newMlFile;
    }
-    private async getImageBitmap(
-        syncContext: MLSyncContext,
-        fileContext: MLSyncFileContext
-    ) {
-        if (fileContext.imageBitmap) {
-            return fileContext.imageBitmap;
-        }
-        // console.log('1 TF Memory stats: ', tf.memory());
-        if (fileContext.localFile) {
-            if (fileContext.enteFile.metadata.fileType !== FILE_TYPE.IMAGE) {
-                throw new Error('Local file of only image type is supported');
-            }
-            fileContext.imageBitmap = await getLocalFileImageBitmap(
-                fileContext.enteFile,
-                fileContext.localFile,
-                () => syncContext.getEnteWorker(fileContext.enteFile.id)
-            );
-        } else if (
-            syncContext.config.imageSource === 'Original' &&
-            [FILE_TYPE.IMAGE, FILE_TYPE.LIVE_PHOTO].includes(
-                fileContext.enteFile.metadata.fileType
-            )
-        ) {
-            fileContext.imageBitmap = await getOriginalImageBitmap(
-                fileContext.enteFile,
-                syncContext.token,
-                await syncContext.getEnteWorker(fileContext.enteFile.id)
-            );
-        } else {
-            fileContext.imageBitmap = await getThumbnailImageBitmap(
-                fileContext.enteFile,
-                syncContext.token
-            );
-        }
-        fileContext.newMlFile.imageSource = syncContext.config.imageSource;
-        const { width, height } = fileContext.imageBitmap;
-        fileContext.newMlFile.imageDimentions = { width, height };
-        // console.log('2 TF Memory stats: ', tf.memory());
-        return fileContext.imageBitmap;
-    }
-
-    private async syncFileFaceDetections(
-        syncContext: MLSyncContext,
-        fileContext: MLSyncFileContext
-    ) {
-        const { oldMlFile, newMlFile } = fileContext;
-        if (
-            !isDifferentOrOld(
-                oldMlFile?.faceDetectionMethod,
-                syncContext.faceDetectionService.method
-            ) &&
-            oldMlFile?.imageSource === syncContext.config.imageSource
-        ) {
-            newMlFile.faces = oldMlFile?.faces?.map((existingFace) => ({
-                id: existingFace.id,
-                fileId: existingFace.fileId,
-                detection: existingFace.detection,
-            }));
-            newMlFile.imageSource = oldMlFile.imageSource;
-            newMlFile.imageDimentions = oldMlFile.imageDimentions;
-            newMlFile.faceDetectionMethod = oldMlFile.faceDetectionMethod;
-            return;
-        }
-        newMlFile.faceDetectionMethod = syncContext.faceDetectionService.method;
-        fileContext.newDetection = true;
-        const imageBitmap = await this.getImageBitmap(syncContext, fileContext);
-        const faceDetections =
-            await syncContext.faceDetectionService.detectFaces(imageBitmap);
-        // console.log('3 TF Memory stats: ', tf.memory());
-        // TODO: reenable faces filtering based on width
-        const detectedFaces = faceDetections?.map((detection) => {
-            return {
-                fileId: fileContext.enteFile.id,
-                detection,
-            } as DetectedFace;
-        });
-        newMlFile.faces = detectedFaces?.map((detectedFace) => ({
-            ...detectedFace,
-            id: getFaceId(detectedFace, newMlFile.imageDimentions),
-        }));
-        // ?.filter((f) =>
-        //     f.box.width > syncContext.config.faceDetection.minFaceSize
-        // );
-        console.log('[MLService] Detected Faces: ', newMlFile.faces?.length);
-    }
-
-    private async saveFaceCrop(
-        imageBitmap: ImageBitmap,
-        face: Face,
-        syncContext: MLSyncContext
-    ) {
-        const faceCrop = await syncContext.faceCropService.getFaceCrop(
-            imageBitmap,
-            face.detection,
-            syncContext.config.faceCrop
-        );
-        face.crop = await storeFaceCrop(
-            face.id,
-            faceCrop,
-            syncContext.config.faceCrop.blobOptions
-        );
-        faceCrop.image.close();
-    }
-
-    private async syncFileFaceCrops(
-        syncContext: MLSyncContext,
-        fileContext: MLSyncFileContext
-    ) {
-        const { oldMlFile, newMlFile } = fileContext;
-        if (
-            // !syncContext.config.faceCrop.enabled ||
-            !fileContext.newDetection &&
-            !isDifferentOrOld(
-                oldMlFile?.faceCropMethod,
-                syncContext.faceCropService.method
-            ) &&
-            areFaceIdsSame(newMlFile.faces, oldMlFile?.faces)
-        ) {
-            for (const [index, face] of newMlFile.faces.entries()) {
-                face.crop = oldMlFile.faces[index].crop;
-            }
-            newMlFile.faceCropMethod = oldMlFile.faceCropMethod;
-            return;
-        }
-        const imageBitmap = await this.getImageBitmap(syncContext, fileContext);
-        newMlFile.faceCropMethod = syncContext.faceCropService.method;
-        for (const face of newMlFile.faces) {
-            await this.saveFaceCrop(imageBitmap, face, syncContext);
-        }
-    }
-
-    private async syncFileFaceAlignments(
-        syncContext: MLSyncContext,
-        fileContext: MLSyncFileContext
-    ) {
-        const { oldMlFile, newMlFile } = fileContext;
-        if (
-            !fileContext.newDetection &&
-            !isDifferentOrOld(
-                oldMlFile?.faceAlignmentMethod,
-                syncContext.faceAlignmentService.method
-            ) &&
-            areFaceIdsSame(newMlFile.faces, oldMlFile?.faces)
-        ) {
-            for (const [index, face] of newMlFile.faces.entries()) {
-                face.alignment = oldMlFile.faces[index].alignment;
-            }
-            newMlFile.faceAlignmentMethod = oldMlFile.faceAlignmentMethod;
-            return;
-        }
-        newMlFile.faceAlignmentMethod = syncContext.faceAlignmentService.method;
-        fileContext.newAlignment = true;
-        for (const face of newMlFile.faces) {
-            face.alignment = syncContext.faceAlignmentService.getFaceAlignment(
-                face.detection
-            );
-        }
-        console.log('[MLService] alignedFaces: ', newMlFile.faces?.length);
-        // console.log('4 TF Memory stats: ', tf.memory());
-    }
-
-    private async syncFileFaceEmbeddings(
-        syncContext: MLSyncContext,
-        fileContext: MLSyncFileContext
-    ) {
-        const { oldMlFile, newMlFile } = fileContext;
-        if (
-            !fileContext.newAlignment &&
-            !isDifferentOrOld(
-                oldMlFile?.faceEmbeddingMethod,
-                syncContext.faceEmbeddingService.method
-            ) &&
-            areFaceIdsSame(newMlFile.faces, oldMlFile?.faces)
-        ) {
-            for (const [index, face] of newMlFile.faces.entries()) {
-                face.embedding = oldMlFile.faces[index].embedding;
-            }
-            newMlFile.faceEmbeddingMethod = oldMlFile.faceEmbeddingMethod;
-            return;
-        }
-        newMlFile.faceEmbeddingMethod = syncContext.faceEmbeddingService.method;
-        // TODO: when not storing face crops, image will be needed to extract faces
-        // fileContext.imageBitmap ||
-        //     (await this.getImageBitmap(syncContext, fileContext));
-        const faceImages = await extractFaceImages(
-            newMlFile.faces,
-            syncContext.faceEmbeddingService.faceSize
-        );
-        const embeddings =
-            await syncContext.faceEmbeddingService.getFaceEmbeddings(
-                faceImages
-            );
-        faceImages.forEach((faceImage) => faceImage.close());
-        newMlFile.faces.forEach((f, i) => (f.embedding = embeddings[i]));
-        console.log(
-            '[MLService] facesWithEmbeddings: ',
-            newMlFile.faces.length
-        );
-        // console.log('5 TF Memory stats: ', tf.memory());
-    }

    public async init() {
        if (this.initialized) {
            return;
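Taken together, this file's changes leave MachineLearningService as a pure orchestrator. A hedged sketch of the resulting per-file flow, highlighting the dirty-flag cascade between stages (the import path is assumed; error handling and the surrounding method are elided):

import FaceDetectionService from 'services/machineLearning/faceDetectionService';
import { MLSyncContext, MLSyncFileContext } from 'types/machineLearning';

async function syncFileFaces(
    syncContext: MLSyncContext,
    fileContext: MLSyncFileContext
) {
    // Stage 1 either reuses cached detections or sets
    // fileContext.newDetection, which forces stages 2 and 3 to recompute.
    await FaceDetectionService.syncFileFaceDetections(syncContext, fileContext);

    const { newMlFile } = fileContext;
    if (newMlFile.faces && newMlFile.faces.length > 0) {
        await FaceDetectionService.syncFileFaceCrops(syncContext, fileContext);
        // Stage 3 sets fileContext.newAlignment when it recomputes, which
        // in turn forces stage 4 to regenerate embeddings.
        await FaceDetectionService.syncFileFaceAlignments(syncContext, fileContext);
        await FaceDetectionService.syncFileFaceEmbeddings(syncContext, fileContext);
    }
}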

services/machineLearning/readerService.ts (new file)

@@ -0,0 +1,53 @@
import { FILE_TYPE } from 'constants/file';
import { MLSyncContext, MLSyncFileContext } from 'types/machineLearning';
import {
    getLocalFileImageBitmap,
    getOriginalImageBitmap,
    getThumbnailImageBitmap,
} from 'utils/machineLearning';

class ReaderService {
    async getImageBitmap(
        syncContext: MLSyncContext,
        fileContext: MLSyncFileContext
    ) {
        if (fileContext.imageBitmap) {
            return fileContext.imageBitmap;
        }
        // console.log('1 TF Memory stats: ', tf.memory());

        if (fileContext.localFile) {
            if (fileContext.enteFile.metadata.fileType !== FILE_TYPE.IMAGE) {
                throw new Error('Local file of only image type is supported');
            }
            fileContext.imageBitmap = await getLocalFileImageBitmap(
                fileContext.enteFile,
                fileContext.localFile,
                () => syncContext.getEnteWorker(fileContext.enteFile.id)
            );
        } else if (
            syncContext.config.imageSource === 'Original' &&
            [FILE_TYPE.IMAGE, FILE_TYPE.LIVE_PHOTO].includes(
                fileContext.enteFile.metadata.fileType
            )
        ) {
            fileContext.imageBitmap = await getOriginalImageBitmap(
                fileContext.enteFile,
                syncContext.token,
                await syncContext.getEnteWorker(fileContext.enteFile.id)
            );
        } else {
            fileContext.imageBitmap = await getThumbnailImageBitmap(
                fileContext.enteFile,
                syncContext.token
            );
        }

        fileContext.newMlFile.imageSource = syncContext.config.imageSource;
        const { width, height } = fileContext.imageBitmap;
        fileContext.newMlFile.imageDimentions = { width, height };
        // console.log('2 TF Memory stats: ', tf.memory());

        return fileContext.imageBitmap;
    }
}

export default new ReaderService();
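A hedged usage sketch of the new service: because the decoded bitmap is memoized on fileContext.imageBitmap, the detection and crop stages can both request it without the image being decoded twice. Construction of syncContext and fileContext (via MLFactory and the sync loop) is assumed:

const bitmap = await ReaderService.getImageBitmap(syncContext, fileContext);
const again = await ReaderService.getImageBitmap(syncContext, fileContext);
console.assert(bitmap === again); // second call returns the memoized bitmap

// Image dimensions are recorded on the new ML file as a side effect (note
// the codebase's existing "imageDimentions" spelling):
console.log(fileContext.newMlFile.imageDimentions); // { width, height }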

utils/machineLearning/index.ts

@@ -278,7 +278,7 @@ export async function getTFImage(blob): Promise<tf.Tensor3D> {
    return tfImage;
}

-export async function getImageBitmap(blob: Blob): Promise<ImageBitmap> {
+export async function getImageBlobBitmap(blob: Blob): Promise<ImageBitmap> {
    return await createImageBitmap(blob);
}

@@ -352,7 +352,7 @@ export async function getOriginalImageBitmap(
    }
    console.log('[MLService] Got file: ', file.id.toString());

-    return getImageBitmap(fileBlob);
+    return getImageBlobBitmap(fileBlob);
}

export async function getThumbnailImageBitmap(file: EnteFile, token: string) {
@@ -365,7 +365,7 @@ export async function getThumbnailImageBitmap(file: EnteFile, token: string) {
    const thumbFile = await fetch(fileUrl);

-    return getImageBitmap(await thumbFile.blob());
+    return getImageBlobBitmap(await thumbFile.blob());
}

export async function getLocalFileImageBitmap(
@@ -378,7 +378,7 @@ export async function getLocalFileImageBitmap(
        const enteWorker = await enteWorkerProvider();
        fileBlob = await convertForPreview(enteFile, fileBlob, enteWorker);
    }
-    return getImageBitmap(fileBlob);
+    return getImageBlobBitmap(fileBlob);
}
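The rename above frees the getImageBitmap name (now the context-aware ReaderService method) while keeping this low-level Blob decoder available as getImageBlobBitmap. A minimal sketch of the two levels, where thumbURL is a hypothetical placeholder:

// Low level: decode a Blob you already hold into an ImageBitmap.
const blob = await (await fetch(thumbURL)).blob(); // thumbURL: hypothetical
const bitmap = await getImageBlobBitmap(blob);

// High level: ReaderService.getImageBitmap (readerService.ts above) picks
// original vs. thumbnail per syncContext.config.imageSource and memoizes
// the result on the file context.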
export async function getPeopleList(file: EnteFile): Promise<Array<Person>> {