Handle first search on app start
This commit is contained in:
parent
34a8bdcf47
commit
10934b08a8
|
@ -11,7 +11,7 @@ import * as ort from "onnxruntime-node";
|
||||||
import Tokenizer from "../../thirdparty/clip-bpe-ts/mod";
|
import Tokenizer from "../../thirdparty/clip-bpe-ts/mod";
|
||||||
import log from "../log";
|
import log from "../log";
|
||||||
import { writeStream } from "../stream";
|
import { writeStream } from "../stream";
|
||||||
import { ensure } from "../utils/common";
|
import { ensure, wait } from "../utils/common";
|
||||||
import { deleteTempFile, makeTempFilePath } from "../utils/temp";
|
import { deleteTempFile, makeTempFilePath } from "../utils/temp";
|
||||||
import { makeCachedInferenceSession } from "./ml";
|
import { makeCachedInferenceSession } from "./ml";
|
||||||
|
|
||||||
|
@ -141,20 +141,22 @@ const getTokenizer = () => {
|
||||||
};
|
};
|
||||||
|
|
||||||
export const computeCLIPTextEmbeddingIfAvailable = async (text: string) => {
|
export const computeCLIPTextEmbeddingIfAvailable = async (text: string) => {
|
||||||
const sessionOrStatus = await Promise.race([
|
const sessionOrSkip = await Promise.race([
|
||||||
cachedCLIPTextSession(),
|
cachedCLIPTextSession(),
|
||||||
"downloading-model",
|
// Wait for a tick to get the session promise to resolved the first time
|
||||||
|
// this code runs on each app start (and the model has been downloaded).
|
||||||
|
wait(0).then(() => 1),
|
||||||
]);
|
]);
|
||||||
|
|
||||||
// Don't wait for the download to complete
|
// Don't wait for the download to complete.
|
||||||
if (typeof sessionOrStatus == "string") {
|
if (typeof sessionOrSkip == "number") {
|
||||||
log.info(
|
log.info(
|
||||||
"Ignoring CLIP text embedding request because model download is pending",
|
"Ignoring CLIP text embedding request because model download is pending",
|
||||||
);
|
);
|
||||||
return undefined;
|
return undefined;
|
||||||
}
|
}
|
||||||
|
|
||||||
const session = sessionOrStatus;
|
const session = sessionOrSkip;
|
||||||
const t1 = Date.now();
|
const t1 = Date.now();
|
||||||
const tokenizer = getTokenizer();
|
const tokenizer = getTokenizer();
|
||||||
const tokenizedText = Int32Array.from(tokenizer.encodeForCLIP(text));
|
const tokenizedText = Int32Array.from(tokenizer.encodeForCLIP(text));
|
||||||
|
|
|
@ -13,3 +13,12 @@ export const ensure = <T>(v: T | null | undefined): T => {
|
||||||
if (v === undefined) throw new Error("Required value was not found");
|
if (v === undefined) throw new Error("Required value was not found");
|
||||||
return v;
|
return v;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Wait for {@link ms} milliseconds
|
||||||
|
*
|
||||||
|
* This function is a promisified `setTimeout`. It returns a promise that
|
||||||
|
* resolves after {@link ms} milliseconds.
|
||||||
|
*/
|
||||||
|
export const wait = (ms: number) =>
|
||||||
|
new Promise((resolve) => setTimeout(resolve, ms));
|
||||||
|
|
|
@ -10,6 +10,10 @@ export const wait = (ms: number) =>
|
||||||
/**
|
/**
|
||||||
* Await the given {@link promise} for {@link timeoutMS} milliseconds. If it
|
* Await the given {@link promise} for {@link timeoutMS} milliseconds. If it
|
||||||
* does not resolve within {@link timeoutMS}, then reject with a timeout error.
|
* does not resolve within {@link timeoutMS}, then reject with a timeout error.
|
||||||
|
*
|
||||||
|
* Note that this does not abort {@link promise} itself - it will still get
|
||||||
|
* resolved to completion, just its result will be ignored if it gets resolved
|
||||||
|
* after we've already timed out.
|
||||||
*/
|
*/
|
||||||
export const withTimeout = async <T>(promise: Promise<T>, ms: number) => {
|
export const withTimeout = async <T>(promise: Promise<T>, ms: number) => {
|
||||||
let timeoutId: ReturnType<typeof setTimeout>;
|
let timeoutId: ReturnType<typeof setTimeout>;
|
||||||
|
|
Loading…
Reference in a new issue