add face crop regeneration logic

This commit is contained in:
Abhinav 2023-12-15 14:16:46 +05:30
parent b694c6c9ba
commit ece409a63d
7 changed files with 83 additions and 6 deletions

View file

@ -3,7 +3,10 @@ import { Skeleton, styled } from '@mui/material';
import { imageBitmapToBlob } from 'utils/image'; import { imageBitmapToBlob } from 'utils/image';
import { logError } from '@ente/shared/sentry'; import { logError } from '@ente/shared/sentry';
import { getBlobFromCache } from '@ente/shared/storage/cacheStorage/helpers'; import { cached } from '@ente/shared/storage/cacheStorage/helpers';
import machineLearningService from 'services/machineLearning/machineLearningService';
import { LS_KEYS, getData } from '@ente/shared/storage/localStorage';
import { User } from '@ente/shared/user/types';
export const FaceCropsRow = styled('div')` export const FaceCropsRow = styled('div')`
& > img { & > img {
@ -19,19 +22,35 @@ export const FaceImagesRow = styled('div')`
} }
`; `;
export function ImageCacheView(props: { url: string; cacheName: string }) { export function ImageCacheView(props: {
url: string;
cacheName: string;
faceID: string;
fileID: number;
}) {
const [imageBlob, setImageBlob] = useState<Blob>(); const [imageBlob, setImageBlob] = useState<Blob>();
useEffect(() => { useEffect(() => {
let didCancel = false; let didCancel = false;
const user: User = getData(LS_KEYS.USER);
async function loadImage() { async function loadImage() {
try { try {
let blob: Blob; let blob: Blob;
if (!props.url || !props.cacheName) { if (!props.url || !props.cacheName) {
blob = undefined; blob = undefined;
} else { } else {
blob = await getBlobFromCache(props.cacheName, props.url); blob = await cached(
props.cacheName,
props.url,
async () => {
return machineLearningService.regenerateFaceCrop(
user.token,
user.id,
props.fileID,
props.faceID
);
}
);
} }
!didCancel && setImageBlob(blob); !didCancel && setImageBlob(blob);

View file

@ -172,6 +172,8 @@ export function UnidentifiedFaces(props: {
faces.map((face, index) => ( faces.map((face, index) => (
<FaceChip key={index}> <FaceChip key={index}>
<ImageCacheView <ImageCacheView
faceID={face.id}
fileID={face.fileId}
url={face.crop?.imageUrl} url={face.crop?.imageUrl}
cacheName={CACHES.FACE_CROPS} cacheName={CACHES.FACE_CROPS}
/> />

View file

@ -10,10 +10,13 @@ import {
getFaceId, getFaceId,
areFaceIdsSame, areFaceIdsSame,
extractFaceImages, extractFaceImages,
getLocalFile,
getOriginalImageBitmap,
} from 'utils/machineLearning'; } from 'utils/machineLearning';
import { storeFaceCrop } from 'utils/machineLearning/faceCrop'; import { storeFaceCrop } from 'utils/machineLearning/faceCrop';
import mlIDbStorage from 'utils/storage/mlIDbStorage'; import mlIDbStorage from 'utils/storage/mlIDbStorage';
import ReaderService from './readerService'; import ReaderService from './readerService';
import { imageBitmapToBlob } from 'utils/image';
class FaceService { class FaceService {
async syncFileFaceDetections( async syncFileFaceDetections(
@ -172,7 +175,10 @@ class FaceService {
async saveFaceCrop( async saveFaceCrop(
imageBitmap: ImageBitmap, imageBitmap: ImageBitmap,
face: Face, face: Face,
syncContext: MLSyncContext syncContext: MLSyncContext,
options?: {
returnCrop: boolean;
}
) { ) {
const faceCrop = await syncContext.faceCropService.getFaceCrop( const faceCrop = await syncContext.faceCropService.getFaceCrop(
imageBitmap, imageBitmap,
@ -184,7 +190,12 @@ class FaceService {
faceCrop, faceCrop,
syncContext.config.faceCrop.blobOptions syncContext.config.faceCrop.blobOptions
); );
let blob: Blob;
if (options?.returnCrop) {
blob = await imageBitmapToBlob(faceCrop.image);
}
faceCrop.image.close(); faceCrop.image.close();
return blob;
} }
async getAllSyncedFacesMap(syncContext: MLSyncContext) { async getAllSyncedFacesMap(syncContext: MLSyncContext) {
@ -234,6 +245,23 @@ class FaceService {
// noise: syncContext.faceClusteringResults.noise, // noise: syncContext.faceClusteringResults.noise,
// }; // };
} }
public async regenerateFaceCrop(
syncContext: MLSyncContext,
fileID: number,
faceID: string
) {
const personFace = await mlIDbStorage.getFace(fileID, faceID);
if (!personFace) {
return;
}
const file = await getLocalFile(personFace.fileId);
const imageBitmap = await getOriginalImageBitmap(file);
return await this.saveFaceCrop(imageBitmap, personFace, syncContext, {
returnCrop: true,
});
}
} }
export default new FaceService(); export default new FaceService();

View file

@ -116,6 +116,17 @@ class MachineLearningService {
return mlSyncResult; return mlSyncResult;
} }
public async regenerateFaceCrop(
token: string,
userID: number,
fileID: number,
faceID: string
) {
await downloadManager.init(APPS.PHOTOS, { token });
const syncContext = await this.getSyncContext(token, userID);
return FaceService.regenerateFaceCrop(syncContext, fileID, faceID);
}
private newMlData(fileId: number) { private newMlData(fileId: number) {
return { return {
fileId, fileId,

View file

@ -263,6 +263,12 @@ class MLIDbStorage {
await Promise.all(fileIds.map((fileId) => tx.store.delete(fileId))); await Promise.all(fileIds.map((fileId) => tx.store.delete(fileId)));
} }
public async getFace(fileID: number, faceId: string) {
const file = await this.getFile(fileID);
const face = file.faces.filter((f) => f.id === faceId);
return face[0];
}
public async getAllFacesMap() { public async getAllFacesMap() {
const startTime = Date.now(); const startTime = Date.now();
const db = await this.db; const db = await this.db;

View file

@ -37,6 +37,15 @@ export class DedicatedMLWorker implements MachineLearningWorker {
return mlService.sync(token, userID); return mlService.sync(token, userID);
} }
public async regenerateFaceCrop(
token: string,
userID: number,
fileID: number,
faceID: string
) {
return mlService.regenerateFaceCrop(token, userID, fileID, faceID);
}
public close() { public close() {
self.close(); self.close();
} }

View file

@ -33,7 +33,9 @@ export async function getBlobFromCache(
): Promise<Blob> { ): Promise<Blob> {
const cache = await CacheStorageService.open(cacheName); const cache = await CacheStorageService.open(cacheName);
const response = await cache.match(url); const response = await cache.match(url);
if (!response) {
return undefined;
}
return response.blob(); return response.blob();
} }