add clipExtraction logic

This commit is contained in:
Abhinav 2023-10-18 15:02:48 +05:30
parent fbbd4c9c3a
commit 3acd38ba24
3 changed files with 130 additions and 0 deletions

View file

@ -122,6 +122,7 @@ import { constructUserIDToEmailMap } from 'services/collectionService';
import { getLocalFamilyData } from 'utils/user/family';
import InMemoryStore, { MS_KEYS } from 'services/InMemoryStore';
import { syncEmbeddings } from 'services/embeddingService';
import { ClipService } from 'services/clipService';
export const DeadCenter = styled('div')`
flex: 1;
@ -694,6 +695,7 @@ export default function Gallery() {
await syncEntities();
await syncMapEnabled();
await syncEmbeddings();
await ClipService.scheduleImageEmbeddingExtraction();
} catch (e) {
switch (e.message) {
case ServerErrorCodes.SESSION_EXPIRED:

View file

@ -0,0 +1,113 @@
import { EnteFile } from 'types/file';
import { putEmbedding, getLocalEmbeddings } from './embeddingService';
import { getLocalFiles } from './fileService';
import { ElectronAPIs } from 'types/electron';
import downloadManager from './downloadManager';
import { getToken } from 'utils/common/key';
import { Embedding, Model } from 'types/embedding';
import ComlinkCryptoWorker from 'utils/comlink/ComlinkCryptoWorker';
import { logError } from 'utils/sentry';
import { addLogLine } from 'utils/logging';
/**
 * Computes CLIP image embeddings for local files that do not have an
 * embedding yet, encrypts each one with the file's key and stores it via
 * `putEmbedding`.
 *
 * Only one extraction pass runs at a time. A request that arrives while a
 * pass is in progress is remembered and triggers one follow-up pass once the
 * current one finishes.
 */
class ClipServiceImpl {
    private electronAPIs: ElectronAPIs;
    /** True while an extraction pass is running. */
    private embeddingExtractionInProgress = false;
    /** Set when a request arrived during a running pass. */
    private reRunNeeded = false;

    constructor() {
        this.electronAPIs = globalThis['ElectronAPIs'];
    }

    /**
     * Starts an extraction pass, or records that another pass is needed if
     * one is already running. Never rejects: failures are reported through
     * `logError`.
     */
    scheduleImageEmbeddingExtraction = async () => {
        try {
            if (this.embeddingExtractionInProgress) {
                addLogLine(
                    'clip embedding extraction already in progress, scheduling re-run'
                );
                this.reRunNeeded = true;
                return;
            } else {
                addLogLine(
                    'clip embedding extraction not in progress, starting clip embedding extraction'
                );
            }
            this.embeddingExtractionInProgress = true;
            try {
                await this.runClipEmbeddingExtraction();
            } finally {
                this.embeddingExtractionInProgress = false;
                if (this.reRunNeeded) {
                    this.reRunNeeded = false;
                    addLogLine('re-running clip embedding extraction');
                    // Re-enter on a fresh task; the scheduled call handles
                    // its own errors, so the promise is deliberately not awaited.
                    setTimeout(
                        () => void this.scheduleImageEmbeddingExtraction(),
                        0
                    );
                }
            }
        } catch (e) {
            logError(e, 'failed to schedule clip embedding extraction');
        }
    };

    /**
     * Runs a single pass over all local files that have no local embedding.
     * A failure on one file is logged and does not stop the remaining files.
     */
    private runClipEmbeddingExtraction = async () => {
        try {
            // Without the electron bridge every file would download its
            // thumbnail and then fail, so bail out before doing any work.
            if (!this.electronAPIs?.computeImageEmbedding) {
                addLogLine(
                    'skipping clip embedding extraction, computeImageEmbedding is not available'
                );
                return;
            }
            const localFiles = await getLocalFiles();
            const existingEmbeddings = await getLocalEmbeddings();
            const pendingFiles = await getNonClipEmbeddingExtractedFiles(
                localFiles,
                existingEmbeddings
            );
            if (pendingFiles.length === 0) {
                return;
            }
            // Loop-invariant: fetch the crypto worker once, not per file.
            const comlinkCryptoWorker = await ComlinkCryptoWorker.getInstance();
            for (const file of pendingFiles) {
                try {
                    const embedding = await this.extractClipImageEmbedding(
                        file
                    );
                    const { file: encryptedEmbedding } =
                        await comlinkCryptoWorker.encryptEmbedding(
                            embedding,
                            file.key
                        );
                    await putEmbedding({
                        fileID: file.id,
                        encryptedEmbedding: encryptedEmbedding.encryptedData,
                        decryptionHeader: encryptedEmbedding.decryptionHeader,
                        model: Model.GGML_CLIP,
                    });
                } catch (e) {
                    logError(e, 'failed to extract clip embedding for file');
                }
            }
        } catch (e) {
            logError(e, 'failed to extract clip embedding');
        }
    };

    /**
     * Downloads the file's thumbnail and computes its image embedding through
     * the electron bridge.
     *
     * @throws Error when no auth token is available, so the caller never
     * receives `undefined` in place of an embedding.
     */
    private extractClipImageEmbedding = async (file: EnteFile) => {
        const token = getToken();
        if (!token) {
            throw new Error(
                'token missing, cannot download thumbnail for clip embedding'
            );
        }
        const thumb = await downloadManager.downloadThumb(token, file);
        return await this.electronAPIs.computeImageEmbedding(thumb);
    };
}
// Single ClipServiceImpl instance created at module load and exported for callers.
export const ClipService = new ClipServiceImpl();
/**
 * Returns the files from `files` whose `id` does not appear as the `fileID`
 * of any entry in `existingEmbeddings`, preserving the input order.
 */
const getNonClipEmbeddingExtractedFiles = async (
    files: EnteFile[],
    existingEmbeddings: Embedding[]
) => {
    const embeddedFileIDs = new Set<number>(
        existingEmbeddings.map(({ fileID }) => fileID)
    );
    const pending: EnteFile[] = [];
    for (const candidate of files) {
        if (embeddedFileIDs.has(candidate.id)) {
            continue;
        }
        pending.push(candidate);
    }
    return pending;
};

View file

@ -62,6 +62,21 @@ export class DedicatedCryptoWorker {
return libsodium.encryptChaChaOneShot(fileData, key);
}
/**
 * Encrypts an embedding with `key` using `libsodium.encryptChaChaOneShot`.
 *
 * @param embedding The embedding values; only the bytes covered by this
 * view are encrypted.
 * @param key The key passed through to `libsodium.encryptChaChaOneShot`.
 * @returns `{ file, key }`, where `file` carries the encrypted bytes
 * converted with `libsodium.toB64` as `encryptedData`, plus every other
 * field returned by the encryption call unchanged.
 */
async encryptEmbedding(embedding: Float32Array, key: string) {
    // A Float32Array may be a view into a larger ArrayBuffer. Bound the byte
    // view by byteOffset/byteLength so only the embedding's own bytes are
    // encrypted, not the whole backing buffer.
    const embeddingBytes = new Uint8Array(
        embedding.buffer,
        embedding.byteOffset,
        embedding.byteLength
    );
    const { file: encryptedEmbedding } = await libsodium.encryptChaChaOneShot(
        embeddingBytes,
        key
    );
    const { encryptedData, ...other } = encryptedEmbedding;
    return {
        file: {
            encryptedData: await libsodium.toB64(encryptedData),
            ...other,
        },
        key,
    };
}
// Encrypts `fileData` by delegating to `libsodium.encryptChaCha` and returns its result as-is.
async encryptFile(fileData: Uint8Array) {
return libsodium.encryptChaCha(fileData);
}