add cancel logic and fixed duplicate file entry

This commit is contained in:
Abhinav 2023-10-19 17:59:53 +05:30
parent 91b47673d5
commit d1f49d5b05

View file

@ -8,10 +8,11 @@ import { Embedding, Model } from 'types/embedding';
import ComlinkCryptoWorker from 'utils/comlink/ComlinkCryptoWorker'; import ComlinkCryptoWorker from 'utils/comlink/ComlinkCryptoWorker';
import { logError } from 'utils/sentry'; import { logError } from 'utils/sentry';
import { addLogLine } from 'utils/logging'; import { addLogLine } from 'utils/logging';
import { CustomError } from 'utils/error';
class ClipServiceImpl { class ClipServiceImpl {
private electronAPIs: ElectronAPIs; private electronAPIs: ElectronAPIs;
private embeddingExtractionInProgress = false; private embeddingExtractionInProgress: AbortController = null;
private reRunNeeded = false; private reRunNeeded = false;
constructor() { constructor() {
@ -31,12 +32,13 @@ class ClipServiceImpl {
'clip embedding extraction not in progress, starting clip embedding extraction' 'clip embedding extraction not in progress, starting clip embedding extraction'
); );
} }
this.embeddingExtractionInProgress = true; const canceller = new AbortController();
this.embeddingExtractionInProgress = canceller;
try { try {
await this.runClipEmbeddingExtraction(); await this.runClipEmbeddingExtraction(canceller);
} finally { } finally {
this.embeddingExtractionInProgress = false; this.embeddingExtractionInProgress = null;
if (this.reRunNeeded) { if (!canceller.signal.aborted && this.reRunNeeded) {
this.reRunNeeded = false; this.reRunNeeded = false;
addLogLine('re-running clip embedding extraction'); addLogLine('re-running clip embedding extraction');
setTimeout( setTimeout(
@ -46,7 +48,9 @@ class ClipServiceImpl {
} }
} }
} catch (e) { } catch (e) {
logError(e, 'failed to schedule clip embedding extraction'); if (e.message !== CustomError.REQUEST_CANCELLED) {
logError(e, 'failed to schedule clip embedding extraction');
}
} }
}; };
@ -59,10 +63,10 @@ class ClipServiceImpl {
} }
}; };
private runClipEmbeddingExtraction = async () => { private runClipEmbeddingExtraction = async (canceller: AbortController) => {
try { try {
const localFiles = await getLocalFiles(); const localFiles = await getLocalFiles();
const existingEmbeddings = await getLocalEmbeddings(); const existingEmbeddings = await getAllClipImageEmbeddings();
const pendingFiles = await getNonClipEmbeddingExtractedFiles( const pendingFiles = await getNonClipEmbeddingExtractedFiles(
localFiles, localFiles,
existingEmbeddings existingEmbeddings
@ -72,28 +76,41 @@ class ClipServiceImpl {
} }
for (const file of pendingFiles) { for (const file of pendingFiles) {
try { try {
const embedding = await this.extractClipImageEmbedding( if (canceller.signal.aborted) {
throw Error(CustomError.REQUEST_CANCELLED);
}
const embeddingData = await this.extractClipImageEmbedding(
file file
); );
const comlinkCryptoWorker = const comlinkCryptoWorker =
await ComlinkCryptoWorker.getInstance(); await ComlinkCryptoWorker.getInstance();
const { file: encryptedEmbedding } = const { file: encryptedEmbeddingData } =
await comlinkCryptoWorker.encryptEmbedding( await comlinkCryptoWorker.encryptEmbedding(
embedding, embeddingData,
file.key file.key
); );
await putEmbedding({ await putEmbedding({
fileID: file.id, fileID: file.id,
encryptedEmbedding: encryptedEmbedding.encryptedData, encryptedEmbedding:
decryptionHeader: encryptedEmbedding.decryptionHeader, encryptedEmbeddingData.encryptedData,
decryptionHeader:
encryptedEmbeddingData.decryptionHeader,
model: Model.GGML_CLIP, model: Model.GGML_CLIP,
}); });
} catch (e) { } catch (e) {
logError(e, 'failed to extract clip embedding for file'); if (e.message !== CustomError.REQUEST_CANCELLED) {
logError(
e,
'failed to extract clip embedding for file'
);
}
} }
} }
} catch (e) { } catch (e) {
logError(e, 'failed to extract clip embedding'); if (e.message !== CustomError.REQUEST_CANCELLED) {
logError(e, 'failed to extract clip embedding');
}
throw e;
} }
}; };
@ -118,7 +135,17 @@ const getNonClipEmbeddingExtractedFiles = async (
existingEmbeddings.forEach((embedding) => existingEmbeddings.forEach((embedding) =>
existingEmbeddingFileIds.add(embedding.fileID) existingEmbeddingFileIds.add(embedding.fileID)
); );
return files.filter((file) => !existingEmbeddingFileIds.has(file.id)); const idSet = new Set<number>();
return files.filter((file) => {
if (idSet.has(file.id)) {
return false;
}
if (existingEmbeddingFileIds.has(file.id)) {
return false;
}
idSet.add(file.id);
return true;
});
}; };
export const getAllClipImageEmbeddings = async () => { export const getAllClipImageEmbeddings = async () => {