diff --git a/web/apps/photos/src/services/clip-service.ts b/web/apps/photos/src/services/clip-service.ts index 703c89cf4..aa724b4d5 100644 --- a/web/apps/photos/src/services/clip-service.ts +++ b/web/apps/photos/src/services/clip-service.ts @@ -11,7 +11,7 @@ import { Embedding } from "types/embedding"; import { EnteFile } from "types/file"; import { getPersonalFiles } from "utils/file"; import downloadManager from "./download"; -import { getLocalEmbeddings, putEmbedding } from "./embeddingService"; +import { localCLIPEmbeddings, putEmbedding } from "./embeddingService"; import { getAllLocalFiles, getLocalFiles } from "./fileService"; /** Status of CLIP indexing on the images in the user's local library. */ @@ -195,7 +195,7 @@ class CLIPService { return; } const localFiles = getPersonalFiles(await getAllLocalFiles(), user); - const existingEmbeddings = await getLocalEmbeddings(); + const existingEmbeddings = await localCLIPEmbeddings(); const pendingFiles = await getNonClipEmbeddingExtractedFiles( localFiles, existingEmbeddings, @@ -394,7 +394,7 @@ export const computeClipMatchScore = async ( const initialIndexingStatus = async (): Promise => { const user = getData(LS_KEYS.USER); if (!user) throw new Error("Orphan CLIP indexing without a login"); - const allEmbeddings = await getLocalEmbeddings(); + const allEmbeddings = await localCLIPEmbeddings(); const localFiles = getPersonalFiles(await getLocalFiles(), user); const pendingFiles = await getNonClipEmbeddingExtractedFiles( localFiles, diff --git a/web/apps/photos/src/services/embeddingService.ts b/web/apps/photos/src/services/embeddingService.ts index 59096349e..2b1ebc0db 100644 --- a/web/apps/photos/src/services/embeddingService.ts +++ b/web/apps/photos/src/services/embeddingService.ts @@ -20,22 +20,26 @@ import { getLocalCollections } from "./collectionService"; import { getAllLocalFiles } from "./fileService"; import { getLocalTrashedFiles } from "./trashService"; -const ENDPOINT = getEndpoint(); - const DIFF_LIMIT = 500; -const EMBEDDINGS_TABLE_V1 = "embeddings"; -const EMBEDDINGS_TABLE = "embeddings_v2"; +/** Local storage key suffix for embedding sync times */ +const embeddingSyncTimeLSKeySuffix = "embedding_sync_time"; +/** Local storage key for CLIP embeddings. */ +const clipEmbeddingsLSKey = "embeddings_v2"; const FILE_EMBEDING_TABLE = "file_embeddings"; -const EMBEDDING_SYNC_TIME_TABLE = "embedding_sync_time"; -export const getAllLocalEmbeddings = async () => { +/** Return all CLIP embeddings that we have available locally. */ +export const localCLIPEmbeddings = async () => + (await storedCLIPEmbeddings()).filter(({ model }) => model === "onnx-clip"); + +const storedCLIPEmbeddings = async () => { const embeddings: Array = - await localForage.getItem(EMBEDDINGS_TABLE); + await localForage.getItem(clipEmbeddingsLSKey); if (!embeddings) { - await localForage.removeItem(EMBEDDINGS_TABLE_V1); - await localForage.removeItem(EMBEDDING_SYNC_TIME_TABLE); - await localForage.setItem(EMBEDDINGS_TABLE, []); + // Migrate + await localForage.removeItem("embeddings"); + await localForage.removeItem("embedding_sync_time"); + await localForage.setItem(clipEmbeddingsLSKey, []); return []; } return embeddings; @@ -50,15 +54,10 @@ export const getFileMLEmbeddings = async (): Promise => { return embeddings; }; -export const getLocalEmbeddings = async () => { - const embeddings = await getAllLocalEmbeddings(); - return embeddings.filter((embedding) => embedding.model === "onnx-clip"); -}; - const getModelEmbeddingSyncTime = async (model: EmbeddingModel) => { return ( (await localForage.getItem( - `${model}-${EMBEDDING_SYNC_TIME_TABLE}`, + `${model}-${embeddingSyncTimeLSKeySuffix}`, )) ?? 0 ); }; @@ -67,13 +66,13 @@ const setModelEmbeddingSyncTime = async ( model: EmbeddingModel, time: number, ) => { - await localForage.setItem(`${model}-${EMBEDDING_SYNC_TIME_TABLE}`, time); + await localForage.setItem(`${model}-${embeddingSyncTimeLSKeySuffix}`, time); }; export const syncEmbeddings = async () => { const models: EmbeddingModel[] = ["onnx-clip"]; try { - let allEmbeddings = await getAllLocalEmbeddings(); + let allEmbeddings = await storedCLIPEmbeddings(); const localFiles = await getAllLocalFiles(); const hiddenAlbums = await getLocalCollections("hidden"); const localTrashFiles = await getLocalTrashedFiles(); @@ -85,7 +84,7 @@ export const syncEmbeddings = async () => { await cleanupDeletedEmbeddings( allLocalFiles, allEmbeddings, - EMBEDDINGS_TABLE, + clipEmbeddingsLSKey, ); log.info(`Syncing embeddings localCount: ${allEmbeddings.length}`); for (const model of models) { @@ -144,7 +143,7 @@ export const syncEmbeddings = async () => { if (response.diff.length) { modelLastSinceTime = response.diff.slice(-1)[0].updatedAt; } - await localForage.setItem(EMBEDDINGS_TABLE, allEmbeddings); + await localForage.setItem(clipEmbeddingsLSKey, allEmbeddings); await setModelEmbeddingSyncTime(model, modelLastSinceTime); log.info( `Syncing embeddings syncedEmbeddingsCount: ${allEmbeddings.length}`, @@ -281,7 +280,7 @@ export const getEmbeddingsDiff = async ( return; } const response = await HTTPService.get( - `${ENDPOINT}/embeddings/diff`, + `${getEndpoint()}/embeddings/diff`, { sinceTime, limit: DIFF_LIMIT, @@ -310,7 +309,7 @@ export const putEmbedding = async ( throw Error(CustomError.TOKEN_MISSING); } const resp = await HTTPService.put( - `${ENDPOINT}/embeddings`, + `${getEndpoint()}/embeddings`, putEmbeddingReq, null, { diff --git a/web/apps/photos/src/services/searchService.ts b/web/apps/photos/src/services/searchService.ts index ef49b126d..a212fc9dc 100644 --- a/web/apps/photos/src/services/searchService.ts +++ b/web/apps/photos/src/services/searchService.ts @@ -20,7 +20,7 @@ import ComlinkSearchWorker from "utils/comlink/ComlinkSearchWorker"; import { getUniqueFiles } from "utils/file"; import { getFormattedDate } from "utils/search"; import { clipService, computeClipMatchScore } from "./clip-service"; -import { getLocalEmbeddings } from "./embeddingService"; +import { localCLIPEmbeddings } from "./embeddingService"; import { getLatestEntities } from "./entityService"; import locationSearchService, { City } from "./locationSearchService"; @@ -375,7 +375,7 @@ const searchClip = async ( await clipService.getTextEmbeddingIfAvailable(searchPhrase); if (!textEmbedding) return undefined; - const imageEmbeddings = await getLocalEmbeddings(); + const imageEmbeddings = await localCLIPEmbeddings(); const clipSearchResult = new Map( ( await Promise.all(