This commit is contained in:
Manav Rathi 2024-04-09 19:59:31 +05:30
parent b02600cb42
commit 67e39daff5
No known key found for this signature in database
3 changed files with 57 additions and 94 deletions

View file

@ -14,7 +14,7 @@ import { EnteMenuItem } from "components/Menu/EnteMenuItem";
import { MenuItemGroup } from "components/Menu/MenuItemGroup";
import isElectron from "is-electron";
import { AppContext } from "pages/_app";
import { ClipExtractionStatus, clipService } from "services/clip-service";
import { CLIPIndexingStatus, clipService } from "services/clip-service";
import { formatNumber } from "utils/number/format";
export default function AdvancedSettings({ open, onClose, onRootClose }) {
@ -44,17 +44,15 @@ export default function AdvancedSettings({ open, onClose, onRootClose }) {
log.error("toggleFasterUpload failed", e);
}
};
const [indexingStatus, setIndexingStatus] = useState<ClipExtractionStatus>({
const [indexingStatus, setIndexingStatus] = useState<CLIPIndexingStatus>({
indexed: 0,
pending: 0,
});
useEffect(() => {
const main = async () => {
setIndexingStatus(await clipService.getIndexingStatus());
clipService.setOnUpdateHandler(setIndexingStatus);
};
main();
clipService.setOnUpdateHandler(setIndexingStatus);
clipService.getIndexingStatus().then((st) => setIndexingStatus(st));
return () => clipService.setOnUpdateHandler(undefined);
}, []);
return (

View file

@ -14,10 +14,11 @@ import downloadManager from "./download";
import { getLocalEmbeddings, putEmbedding } from "./embeddingService";
import { getAllLocalFiles, getLocalFiles } from "./fileService";
const CLIP_EMBEDDING_LENGTH = 512;
export interface ClipExtractionStatus {
/** Status of CLIP indexing on the images in the user's local library. */
export interface CLIPIndexingStatus {
/** Number of items pending indexing. */
pending: number;
/** Number of items that have already been indexed. */
indexed: number;
}
@ -62,15 +63,14 @@ export interface ClipExtractionStatus {
*
* Both these currently have one (and only one) associated model.
*/
class ClipServiceImpl {
class ClipService {
private embeddingExtractionInProgress: AbortController | null = null;
private reRunNeeded = false;
private clipExtractionStatus: ClipExtractionStatus = {
private indexingStatus: CLIPIndexingStatus = {
pending: 0,
indexed: 0,
};
private onUpdateHandler: ((status: ClipExtractionStatus) => void) | null =
null;
private onUpdateHandler: ((status: CLIPIndexingStatus) => void) | undefined;
private liveEmbeddingExtractionQueue: PQueue;
private onFileUploadedHandler:
| ((arg: { enteFile: EnteFile; localFile: globalThis.File }) => void)
@ -137,28 +137,23 @@ class ClipServiceImpl {
};
getIndexingStatus = async () => {
try {
if (
!this.clipExtractionStatus ||
(this.clipExtractionStatus.pending === 0 &&
this.clipExtractionStatus.indexed === 0)
) {
this.clipExtractionStatus = await getClipExtractionStatus();
}
return this.clipExtractionStatus;
} catch (e) {
log.error("failed to get clip indexing status", e);
if (
this.indexingStatus.pending === 0 &&
this.indexingStatus.indexed === 0
) {
this.indexingStatus = await initialIndexingStatus();
}
return this.indexingStatus;
};
setOnUpdateHandler = (handler: (status: ClipExtractionStatus) => void) => {
/**
* Set the {@link handler} to invoke whenever our indexing status changes.
*/
setOnUpdateHandler = (handler?: (status: CLIPIndexingStatus) => void) => {
this.onUpdateHandler = handler;
handler(this.clipExtractionStatus);
};
scheduleImageEmbeddingExtraction = async (
model: Model = Model.ONNX_CLIP,
) => {
scheduleImageEmbeddingExtraction = async () => {
try {
if (this.embeddingExtractionInProgress) {
log.info(
@ -174,7 +169,7 @@ class ClipServiceImpl {
const canceller = new AbortController();
this.embeddingExtractionInProgress = canceller;
try {
await this.runClipEmbeddingExtraction(canceller, model);
await this.runClipEmbeddingExtraction(canceller);
} finally {
this.embeddingExtractionInProgress = null;
if (!canceller.signal.aborted && this.reRunNeeded) {
@ -193,12 +188,9 @@ class ClipServiceImpl {
}
};
getTextEmbedding = async (
text: string,
model: Model = Model.ONNX_CLIP,
): Promise<Float32Array> => {
getTextEmbedding = async (text: string): Promise<Float32Array> => {
try {
return ensureElectron().computeTextEmbedding(model, text);
return ensureElectron().computeTextEmbedding(Model.ONNX_CLIP, text);
} catch (e) {
if (e?.message?.includes(CustomError.UNSUPPORTED_PLATFORM)) {
this.unsupportedPlatform = true;
@ -208,10 +200,7 @@ class ClipServiceImpl {
}
};
private runClipEmbeddingExtraction = async (
canceller: AbortController,
model: Model,
) => {
private runClipEmbeddingExtraction = async (canceller: AbortController) => {
try {
if (this.unsupportedPlatform) {
log.info(
@ -224,12 +213,12 @@ class ClipServiceImpl {
return;
}
const localFiles = getPersonalFiles(await getAllLocalFiles(), user);
const existingEmbeddings = await getLocalEmbeddings(model);
const existingEmbeddings = await getLocalEmbeddings();
const pendingFiles = await getNonClipEmbeddingExtractedFiles(
localFiles,
existingEmbeddings,
);
this.updateClipEmbeddingExtractionStatus({
this.updateIndexingStatus({
indexed: existingEmbeddings.length,
pending: pendingFiles.length,
});
@ -249,15 +238,11 @@ class ClipServiceImpl {
throw Error(CustomError.REQUEST_CANCELLED);
}
const embeddingData =
await this.extractFileClipImageEmbedding(model, file);
await this.extractFileClipImageEmbedding(file);
log.info(
`successfully extracted clip embedding for file: ${file.metadata.title} fileID: ${file.id} embedding length: ${embeddingData?.length}`,
);
await this.encryptAndUploadEmbedding(
model,
file,
embeddingData,
);
await this.encryptAndUploadEmbedding(file, embeddingData);
this.onSuccessStatusUpdater();
log.info(
`successfully put clip embedding to server for file: ${file.metadata.title} fileID: ${file.id}`,
@ -290,13 +275,10 @@ class ClipServiceImpl {
}
};
private async runLocalFileClipExtraction(
arg: {
enteFile: EnteFile;
localFile: globalThis.File;
},
model: Model = Model.ONNX_CLIP,
) {
private async runLocalFileClipExtraction(arg: {
enteFile: EnteFile;
localFile: globalThis.File;
}) {
const { enteFile, localFile } = arg;
log.info(
`clip embedding extraction onFileUploadedHandler file: ${enteFile.metadata.title} fileID: ${enteFile.id}`,
@ -320,15 +302,9 @@ class ClipServiceImpl {
);
try {
await this.liveEmbeddingExtractionQueue.add(async () => {
const embedding = await this.extractLocalFileClipImageEmbedding(
model,
localFile,
);
await this.encryptAndUploadEmbedding(
model,
enteFile,
embedding,
);
const embedding =
await this.extractLocalFileClipImageEmbedding(localFile);
await this.encryptAndUploadEmbedding(enteFile, embedding);
});
log.info(
`successfully extracted clip embedding for file: ${enteFile.metadata.title} fileID: ${enteFile.id}`,
@ -338,26 +314,22 @@ class ClipServiceImpl {
}
}
private extractLocalFileClipImageEmbedding = async (
model: Model,
localFile: File,
) => {
private extractLocalFileClipImageEmbedding = async (localFile: File) => {
const file = await localFile
.arrayBuffer()
.then((buffer) => new Uint8Array(buffer));
const embedding = await ensureElectron().computeImageEmbedding(
model,
Model.ONNX_CLIP,
file,
);
return embedding;
};
private encryptAndUploadEmbedding = async (
model: Model,
file: EnteFile,
embeddingData: Float32Array,
) => {
if (embeddingData?.length !== CLIP_EMBEDDING_LENGTH) {
if (embeddingData?.length !== 512) {
throw Error(
`invalid length embedding data length: ${embeddingData?.length}`,
);
@ -368,6 +340,7 @@ class ClipServiceImpl {
log.info(
`putting clip embedding to server for file: ${file.metadata.title} fileID: ${file.id}`,
);
const model = Model.ONNX_CLIP;
await putEmbedding({
fileID: file.id,
encryptedEmbedding: encryptedEmbeddingData.encryptedData,
@ -376,34 +349,30 @@ class ClipServiceImpl {
});
};
updateClipEmbeddingExtractionStatus = (status: ClipExtractionStatus) => {
this.clipExtractionStatus = status;
if (this.onUpdateHandler) {
this.onUpdateHandler(status);
}
private updateIndexingStatus = (status: CLIPIndexingStatus) => {
this.indexingStatus = status;
const handler = this.onUpdateHandler;
if (handler) handler(status);
};
private extractFileClipImageEmbedding = async (
model: Model,
file: EnteFile,
) => {
private extractFileClipImageEmbedding = async (file: EnteFile) => {
const thumb = await downloadManager.getThumbnail(file);
const embedding = await ensureElectron().computeImageEmbedding(
model,
Model.ONNX_CLIP,
thumb,
);
return embedding;
};
private onSuccessStatusUpdater = () => {
this.updateClipEmbeddingExtractionStatus({
pending: this.clipExtractionStatus.pending - 1,
indexed: this.clipExtractionStatus.indexed + 1,
this.updateIndexingStatus({
pending: this.indexingStatus.pending - 1,
indexed: this.indexingStatus.indexed + 1,
});
};
}
export const clipService = new ClipServiceImpl();
export const clipService = new ClipService();
const getNonClipEmbeddingExtractedFiles = async (
files: EnteFile[],
@ -453,14 +422,10 @@ export const computeClipMatchScore = async (
return score;
};
const getClipExtractionStatus = async (
model: Model = Model.ONNX_CLIP,
): Promise<ClipExtractionStatus> => {
const initialIndexingStatus = async (): Promise<CLIPIndexingStatus> => {
const user = getData(LS_KEYS.USER);
if (!user) {
return;
}
const allEmbeddings = await getLocalEmbeddings(model);
if (!user) throw new Error("Orphan CLIP indexing without a login");
const allEmbeddings = await getLocalEmbeddings();
const localFiles = getPersonalFiles(await getLocalFiles(), user);
const pendingFiles = await getNonClipEmbeddingExtractedFiles(
localFiles,

View file

@ -38,9 +38,9 @@ export const getAllLocalEmbeddings = async () => {
return embeddings;
};
export const getLocalEmbeddings = async (model: Model) => {
export const getLocalEmbeddings = async () => {
const embeddings = await getAllLocalEmbeddings();
return embeddings.filter((embedding) => embedding.model === model);
return embeddings.filter((embedding) => embedding.model === Model.ONNX_CLIP);
};
const getModelEmbeddingSyncTime = async (model: Model) => {