diff --git a/apps/photos/src/services/embeddingService.ts b/apps/photos/src/services/embeddingService.ts index 492078192..a69f95ffa 100644 --- a/apps/photos/src/services/embeddingService.ts +++ b/apps/photos/src/services/embeddingService.ts @@ -44,11 +44,19 @@ export const getLocalEmbeddings = async (model: Model) => { return embeddings.filter((embedding) => embedding.model === model); }; -const getEmbeddingSyncTime = async () => { - return (await localForage.getItem(EMBEDDING_SYNC_TIME_TABLE)) ?? 0; +const getModelEmbeddingSyncTime = async (model: Model) => { + return ( + (await localForage.getItem( + `${model}-${EMBEDDING_SYNC_TIME_TABLE}` + )) ?? 0 + ); }; -export const syncEmbeddings = async (model: Model = Model.ONNX_CLIP) => { +const setModelEmbeddingSyncTime = async (model: Model, time: number) => { + await localForage.setItem(`${model}-${EMBEDDING_SYNC_TIME_TABLE}`, time); +}; + +export const syncEmbeddings = async (models: Model[] = [Model.ONNX_CLIP]) => { try { let allEmbeddings = await getAllLocalEmbeddings(); const localFiles = await getAllLocalFiles(); @@ -61,62 +69,74 @@ export const syncEmbeddings = async (model: Model = Model.ONNX_CLIP) => { }); await cleanupDeletedEmbeddings(allLocalFiles, allEmbeddings); addLogLine(`Syncing embeddings localCount: ${allEmbeddings.length}`); - let sinceTime = await getEmbeddingSyncTime(); - addLogLine(`Syncing embeddings sinceTime: ${sinceTime}`); - let response: GetEmbeddingDiffResponse; - do { - response = await getEmbeddingsDiff(sinceTime, model); - if (!response.diff?.length) { - return; - } - const newEmbeddings = await Promise.all( - response.diff.map(async (embedding) => { - try { - const { - encryptedEmbedding, - decryptionHeader, - ...rest - } = embedding; - const worker = await ComlinkCryptoWorker.getInstance(); - const fileKey = fileIdToKeyMap.get(embedding.fileID); - if (!fileKey) { - throw Error(CustomError.FILE_NOT_FOUND); - } - const decryptedData = await worker.decryptEmbedding( - encryptedEmbedding, - decryptionHeader, - fileIdToKeyMap.get(embedding.fileID) - ); - - return { - ...rest, - embedding: decryptedData, - } as Embedding; - } catch (e) { - let info: Record; - if (e.message === CustomError.FILE_NOT_FOUND) { - const hasHiddenAlbums = hiddenAlbums?.length > 0; - info = { - hasHiddenAlbums, - }; - } - logError(e, 'decryptEmbedding failed for file', info); - } - }) - ); - allEmbeddings = getLatestVersionEmbeddings([ - ...allEmbeddings, - ...newEmbeddings, - ]); - if (response.diff.length) { - sinceTime = response.diff.slice(-1)[0].updatedAt; - } - await localForage.setItem(EMBEDDINGS_TABLE, allEmbeddings); - await localForage.setItem(EMBEDDING_SYNC_TIME_TABLE, sinceTime); + for (const model of models) { + let modelLastSinceTime = await getModelEmbeddingSyncTime(model); addLogLine( - `Syncing embeddings syncedEmbeddingsCount: ${allEmbeddings.length}` + `Syncing ${model} model's embeddings sinceTime: ${modelLastSinceTime}` ); - } while (response.diff.length === DIFF_LIMIT); + let response: GetEmbeddingDiffResponse; + do { + response = await getEmbeddingsDiff(modelLastSinceTime, model); + if (!response.diff?.length) { + return; + } + const newEmbeddings = await Promise.all( + response.diff.map(async (embedding) => { + try { + const { + encryptedEmbedding, + decryptionHeader, + ...rest + } = embedding; + const worker = + await ComlinkCryptoWorker.getInstance(); + const fileKey = fileIdToKeyMap.get( + embedding.fileID + ); + if (!fileKey) { + throw Error(CustomError.FILE_NOT_FOUND); + } + const decryptedData = await worker.decryptEmbedding( + encryptedEmbedding, + decryptionHeader, + fileIdToKeyMap.get(embedding.fileID) + ); + + return { + ...rest, + embedding: decryptedData, + } as Embedding; + } catch (e) { + let info: Record; + if (e.message === CustomError.FILE_NOT_FOUND) { + const hasHiddenAlbums = + hiddenAlbums?.length > 0; + info = { + hasHiddenAlbums, + }; + } + logError( + e, + 'decryptEmbedding failed for file', + info + ); + } + }) + ); + allEmbeddings = getLatestVersionEmbeddings([ + ...allEmbeddings, + ...newEmbeddings, + ]); + if (response.diff.length) { + modelLastSinceTime = response.diff.slice(-1)[0].updatedAt; + } + await localForage.setItem(EMBEDDINGS_TABLE, allEmbeddings); + await setModelEmbeddingSyncTime(model, modelLastSinceTime); + addLogLine( + `Syncing embeddings syncedEmbeddingsCount: ${allEmbeddings.length}` + ); + } while (response.diff.length === DIFF_LIMIT); + } } catch (e) { logError(e, 'Sync embeddings failed'); }