Handle first search on app start

This commit is contained in:
Manav Rathi 2024-05-20 10:38:12 +05:30
parent 34a8bdcf47
commit 10934b08a8
No known key found for this signature in database
3 changed files with 21 additions and 6 deletions

View file

@ -11,7 +11,7 @@ import * as ort from "onnxruntime-node";
import Tokenizer from "../../thirdparty/clip-bpe-ts/mod"; import Tokenizer from "../../thirdparty/clip-bpe-ts/mod";
import log from "../log"; import log from "../log";
import { writeStream } from "../stream"; import { writeStream } from "../stream";
import { ensure } from "../utils/common"; import { ensure, wait } from "../utils/common";
import { deleteTempFile, makeTempFilePath } from "../utils/temp"; import { deleteTempFile, makeTempFilePath } from "../utils/temp";
import { makeCachedInferenceSession } from "./ml"; import { makeCachedInferenceSession } from "./ml";
@ -141,20 +141,22 @@ const getTokenizer = () => {
}; };
export const computeCLIPTextEmbeddingIfAvailable = async (text: string) => { export const computeCLIPTextEmbeddingIfAvailable = async (text: string) => {
const sessionOrStatus = await Promise.race([ const sessionOrSkip = await Promise.race([
cachedCLIPTextSession(), cachedCLIPTextSession(),
"downloading-model", // Wait for a tick to let the session promise resolve the first time
// this code runs on each app start (and the model has been downloaded).
wait(0).then(() => 1),
]); ]);
// Don't wait for the download to complete // Don't wait for the download to complete.
if (typeof sessionOrStatus == "string") { if (typeof sessionOrSkip == "number") {
log.info( log.info(
"Ignoring CLIP text embedding request because model download is pending", "Ignoring CLIP text embedding request because model download is pending",
); );
return undefined; return undefined;
} }
const session = sessionOrStatus; const session = sessionOrSkip;
const t1 = Date.now(); const t1 = Date.now();
const tokenizer = getTokenizer(); const tokenizer = getTokenizer();
const tokenizedText = Int32Array.from(tokenizer.encodeForCLIP(text)); const tokenizedText = Int32Array.from(tokenizer.encodeForCLIP(text));

View file

@ -13,3 +13,12 @@ export const ensure = <T>(v: T | null | undefined): T => {
if (v === undefined) throw new Error("Required value was not found"); if (v === undefined) throw new Error("Required value was not found");
return v; return v;
}; };
/**
 * Wait for {@link ms} milliseconds.
 *
 * This function is a promisified `setTimeout`. It returns a promise that
 * resolves (with no value) after {@link ms} milliseconds.
 *
 * @param ms The delay, in milliseconds, passed as is to `setTimeout`.
 * @returns A promise that resolves to `undefined` once the timer fires.
 */
export const wait = (ms: number): Promise<void> =>
    // Explicitly type the promise as `Promise<void>`; without the type
    // argument the constructor infers `Promise<unknown>`.
    new Promise<void>((resolve) => setTimeout(resolve, ms));

View file

@ -10,6 +10,10 @@ export const wait = (ms: number) =>
/** /**
* Await the given {@link promise} for {@link timeoutMS} milliseconds. If it * Await the given {@link promise} for {@link timeoutMS} milliseconds. If it
* does not resolve within {@link timeoutMS}, then reject with a timeout error. * does not resolve within {@link timeoutMS}, then reject with a timeout error.
*
* Note that this does not abort {@link promise} itself - it will still get
* resolved to completion, just its result will be ignored if it gets resolved
* after we've already timed out.
*/ */
export const withTimeout = async <T>(promise: Promise<T>, ms: number) => { export const withTimeout = async <T>(promise: Promise<T>, ms: number) => {
let timeoutId: ReturnType<typeof setTimeout>; let timeoutId: ReturnType<typeof setTimeout>;