This commit is contained in:
Manav Rathi 2024-05-14 15:32:02 +05:30
parent 8378b76a8c
commit 613324a4ae
No known key found for this signature in database
3 changed files with 26 additions and 27 deletions

View file

@ -11,7 +11,7 @@ import { Embedding } from "types/embedding";
import { EnteFile } from "types/file";
import { getPersonalFiles } from "utils/file";
import downloadManager from "./download";
import { getLocalEmbeddings, putEmbedding } from "./embeddingService";
import { localCLIPEmbeddings, putEmbedding } from "./embeddingService";
import { getAllLocalFiles, getLocalFiles } from "./fileService";
/** Status of CLIP indexing on the images in the user's local library. */
@ -195,7 +195,7 @@ class CLIPService {
return;
}
const localFiles = getPersonalFiles(await getAllLocalFiles(), user);
const existingEmbeddings = await getLocalEmbeddings();
const existingEmbeddings = await localCLIPEmbeddings();
const pendingFiles = await getNonClipEmbeddingExtractedFiles(
localFiles,
existingEmbeddings,
@ -394,7 +394,7 @@ export const computeClipMatchScore = async (
const initialIndexingStatus = async (): Promise<CLIPIndexingStatus> => {
const user = getData(LS_KEYS.USER);
if (!user) throw new Error("Orphan CLIP indexing without a login");
const allEmbeddings = await getLocalEmbeddings();
const allEmbeddings = await localCLIPEmbeddings();
const localFiles = getPersonalFiles(await getLocalFiles(), user);
const pendingFiles = await getNonClipEmbeddingExtractedFiles(
localFiles,

View file

@ -20,22 +20,26 @@ import { getLocalCollections } from "./collectionService";
import { getAllLocalFiles } from "./fileService";
import { getLocalTrashedFiles } from "./trashService";
const ENDPOINT = getEndpoint();
const DIFF_LIMIT = 500;
const EMBEDDINGS_TABLE_V1 = "embeddings";
const EMBEDDINGS_TABLE = "embeddings_v2";
/** Local storage key suffix for embedding sync times */
const embeddingSyncTimeLSKeySuffix = "embedding_sync_time";
/** Local storage key for CLIP embeddings. */
const clipEmbeddingsLSKey = "embeddings_v2";
const FILE_EMBEDING_TABLE = "file_embeddings";
const EMBEDDING_SYNC_TIME_TABLE = "embedding_sync_time";
export const getAllLocalEmbeddings = async () => {
/**
 * Return the CLIP embeddings that are available locally.
 *
 * Only the stored embeddings whose model is "onnx-clip" are included.
 */
export const localCLIPEmbeddings = async () => {
    const stored = await storedCLIPEmbeddings();
    return stored.filter((embedding) => embedding.model === "onnx-clip");
};
const storedCLIPEmbeddings = async () => {
const embeddings: Array<Embedding> =
await localForage.getItem<Embedding[]>(EMBEDDINGS_TABLE);
await localForage.getItem<Embedding[]>(clipEmbeddingsLSKey);
if (!embeddings) {
await localForage.removeItem(EMBEDDINGS_TABLE_V1);
await localForage.removeItem(EMBEDDING_SYNC_TIME_TABLE);
await localForage.setItem(EMBEDDINGS_TABLE, []);
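// Migrate: nothing is stored under the current key yet, so remove the
// entries kept under the older "embeddings" and "embedding_sync_time"
// keys and start afresh with an empty list.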
await localForage.removeItem("embeddings");
await localForage.removeItem("embedding_sync_time");
await localForage.setItem(clipEmbeddingsLSKey, []);
return [];
}
return embeddings;
@ -50,15 +54,10 @@ export const getFileMLEmbeddings = async (): Promise<FileML[]> => {
return embeddings;
};
/** Return the locally stored embeddings whose model is "onnx-clip". */
export const getLocalEmbeddings = async () =>
    (await getAllLocalEmbeddings()).filter(({ model }) => model === "onnx-clip");
/**
 * Read the embedding sync time kept in localForage for the given model.
 *
 * The lookup key is prefixed with the model name. When the lookup yields
 * null or undefined, 0 is returned instead.
 *
 * NOTE(review): two key expressions appear one after the other below (one
 * built from EMBEDDING_SYNC_TIME_TABLE, one from
 * embeddingSyncTimeLSKeySuffix). This looks like both sides of a diff shown
 * together; confirm which single key is the live one.
 *
 * @param model The embedding model whose sync time should be read.
 * @returns The stored number, or 0 if there is none.
 */
const getModelEmbeddingSyncTime = async (model: EmbeddingModel) => {
return (
(await localForage.getItem<number>(
`${model}-${EMBEDDING_SYNC_TIME_TABLE}`,
`${model}-${embeddingSyncTimeLSKeySuffix}`,
)) ?? 0
);
};
@ -67,13 +66,13 @@ const setModelEmbeddingSyncTime = async (
model: EmbeddingModel,
time: number,
) => {
await localForage.setItem(`${model}-${EMBEDDING_SYNC_TIME_TABLE}`, time);
await localForage.setItem(`${model}-${embeddingSyncTimeLSKeySuffix}`, time);
};
export const syncEmbeddings = async () => {
const models: EmbeddingModel[] = ["onnx-clip"];
try {
let allEmbeddings = await getAllLocalEmbeddings();
let allEmbeddings = await storedCLIPEmbeddings();
const localFiles = await getAllLocalFiles();
const hiddenAlbums = await getLocalCollections("hidden");
const localTrashFiles = await getLocalTrashedFiles();
@ -85,7 +84,7 @@ export const syncEmbeddings = async () => {
await cleanupDeletedEmbeddings(
allLocalFiles,
allEmbeddings,
EMBEDDINGS_TABLE,
clipEmbeddingsLSKey,
);
log.info(`Syncing embeddings localCount: ${allEmbeddings.length}`);
for (const model of models) {
@ -144,7 +143,7 @@ export const syncEmbeddings = async () => {
if (response.diff.length) {
modelLastSinceTime = response.diff.slice(-1)[0].updatedAt;
}
await localForage.setItem(EMBEDDINGS_TABLE, allEmbeddings);
await localForage.setItem(clipEmbeddingsLSKey, allEmbeddings);
await setModelEmbeddingSyncTime(model, modelLastSinceTime);
log.info(
`Syncing embeddings syncedEmbeddingsCount: ${allEmbeddings.length}`,
@ -281,7 +280,7 @@ export const getEmbeddingsDiff = async (
return;
}
const response = await HTTPService.get(
`${ENDPOINT}/embeddings/diff`,
`${getEndpoint()}/embeddings/diff`,
{
sinceTime,
limit: DIFF_LIMIT,
@ -310,7 +309,7 @@ export const putEmbedding = async (
throw Error(CustomError.TOKEN_MISSING);
}
const resp = await HTTPService.put(
`${ENDPOINT}/embeddings`,
`${getEndpoint()}/embeddings`,
putEmbeddingReq,
null,
{

View file

@ -20,7 +20,7 @@ import ComlinkSearchWorker from "utils/comlink/ComlinkSearchWorker";
import { getUniqueFiles } from "utils/file";
import { getFormattedDate } from "utils/search";
import { clipService, computeClipMatchScore } from "./clip-service";
import { getLocalEmbeddings } from "./embeddingService";
import { localCLIPEmbeddings } from "./embeddingService";
import { getLatestEntities } from "./entityService";
import locationSearchService, { City } from "./locationSearchService";
@ -375,7 +375,7 @@ const searchClip = async (
await clipService.getTextEmbeddingIfAvailable(searchPhrase);
if (!textEmbedding) return undefined;
const imageEmbeddings = await getLocalEmbeddings();
const imageEmbeddings = await localCLIPEmbeddings();
const clipSearchResult = new Map<number, number>(
(
await Promise.all(