Merge remote-tracking branch 'origin/mobile_face' into mobile_face

laurenspriem 2024-05-20 16:56:53 +05:30
commit 5172ce3126
51 changed files with 1616 additions and 2387 deletions

.gitignore (vendored): 6 lines changed
View file

@ -1,8 +1,6 @@
# Let folks use their custom .vscode settings
# Let folks use their custom editor settings
.vscode
.idea
# macOS
.DS_Store
.idea
.ente.authenticator.db
.ente.offline_authenticator.db

.gitmodules (vendored): 4 lines changed
View file

@ -2,6 +2,10 @@
path = auth/thirdparty/sentry-dart
url = https://github.com/ente-io/sentry-dart.git
branch = sentry_flutter_ente
[submodule "auth/flutter"]
path = auth/flutter
url = https://github.com/flutter/flutter.git
branch = stable
[submodule "auth/assets/simple-icons"]
path = auth/assets/simple-icons
url = https://github.com/simple-icons/simple-icons.git

View file

@ -14,7 +14,7 @@
"build:ci": "yarn build-renderer && tsc",
"build:quick": "yarn build-renderer && yarn build-main:quick",
"dev": "concurrently --kill-others --success first --names 'main,rndr' \"yarn dev-main\" \"yarn dev-renderer\"",
"dev-main": "tsc && electron app/main.js",
"dev-main": "tsc && electron .",
"dev-renderer": "cd ../web && yarn install && yarn dev:photos",
"postinstall": "electron-builder install-app-deps",
"lint": "yarn prettier --check --log-level warn . && eslint --ext .ts src && yarn tsc",

View file

@ -315,32 +315,18 @@ const setupTrayItem = (mainWindow: BrowserWindow) => {
/**
* Older versions of our app used to maintain a cache dir using the main
* process. This has been removed in favor of cache on the web layer.
* process. This has been removed in favor of cache on the web layer. Delete the
* old cache dir if it exists.
*
* Delete the old cache dir if it exists.
*
* This will happen in two phases. The cache had three subdirectories:
*
* - Two of them, "thumbs" and "files", will be removed now (v1.7.0, May 2024).
*
* - The third one, "face-crops" will be removed once we finish the face search
* changes. See: [Note: Legacy face crops].
*
* This migration code can be removed after some time once most people have
* upgraded to newer versions.
* Added May 2024, v1.7.0. This migration code can be removed after some time
* once most people have upgraded to newer versions.
*/
const deleteLegacyDiskCacheDirIfExists = async () => {
const removeIfExists = async (dirPath: string) => {
if (existsSync(dirPath)) {
log.info(`Removing legacy disk cache from ${dirPath}`);
await fs.rm(dirPath, { recursive: true });
}
};
// [Note: Getting the cache path]
//
// The existing code was passing "cache" as a parameter to getPath.
//
// However, "cache" is not a valid parameter to getPath. It works! (for
// However, "cache" is not a valid parameter to getPath. It works (for
// example, on macOS I get `~/Library/Caches`), but it is intentionally not
// documented as part of the public API:
//
@ -353,8 +339,8 @@ const deleteLegacyDiskCacheDirIfExists = async () => {
// @ts-expect-error "cache" works but is not part of the public API.
const cacheDir = path.join(app.getPath("cache"), "ente");
if (existsSync(cacheDir)) {
await removeIfExists(path.join(cacheDir, "thumbs"));
await removeIfExists(path.join(cacheDir, "files"));
log.info(`Removing legacy disk cache from ${cacheDir}`);
await fs.rm(cacheDir, { recursive: true });
}
};

View file

@ -24,7 +24,6 @@ import {
updateOnNextRestart,
} from "./services/app-update";
import {
legacyFaceCrop,
openDirectory,
openLogDirectory,
selectDirectory,
@ -43,10 +42,10 @@ import {
import { convertToJPEG, generateImageThumbnail } from "./services/image";
import { logout } from "./services/logout";
import {
clipImageEmbedding,
clipTextEmbeddingIfAvailable,
computeCLIPImageEmbedding,
computeCLIPTextEmbeddingIfAvailable,
} from "./services/ml-clip";
import { detectFaces, faceEmbeddings } from "./services/ml-face";
import { computeFaceEmbeddings, detectFaces } from "./services/ml-face";
import { encryptionKey, saveEncryptionKey } from "./services/store";
import {
clearPendingUploads,
@ -170,24 +169,22 @@ export const attachIPCHandlers = () => {
// - ML
ipcMain.handle("clipImageEmbedding", (_, jpegImageData: Uint8Array) =>
clipImageEmbedding(jpegImageData),
ipcMain.handle(
"computeCLIPImageEmbedding",
(_, jpegImageData: Uint8Array) =>
computeCLIPImageEmbedding(jpegImageData),
);
ipcMain.handle("clipTextEmbeddingIfAvailable", (_, text: string) =>
clipTextEmbeddingIfAvailable(text),
ipcMain.handle("computeCLIPTextEmbeddingIfAvailable", (_, text: string) =>
computeCLIPTextEmbeddingIfAvailable(text),
);
ipcMain.handle("detectFaces", (_, input: Float32Array) =>
detectFaces(input),
);
ipcMain.handle("faceEmbeddings", (_, input: Float32Array) =>
faceEmbeddings(input),
);
ipcMain.handle("legacyFaceCrop", (_, faceID: string) =>
legacyFaceCrop(faceID),
ipcMain.handle("computeFaceEmbeddings", (_, input: Float32Array) =>
computeFaceEmbeddings(input),
);
// - Upload

View file

@ -163,7 +163,7 @@ const checkForUpdatesAndNotify = async (mainWindow: BrowserWindow) => {
};
/**
* Return the version of the desktop app
* Return the version of the desktop app.
*
* The return value is of the form `v1.2.3`.
*/

View file

@ -1,7 +1,5 @@
import { shell } from "electron/common";
import { app, dialog } from "electron/main";
import { existsSync } from "fs";
import fs from "node:fs/promises";
import path from "node:path";
import { posixPath } from "../utils/electron";
@ -78,16 +76,3 @@ export const openLogDirectory = () => openDirectory(logDirectoryPath());
* - Windows: %USERPROFILE%\AppData\Roaming\ente\logs\ente.log
*/
const logDirectoryPath = () => app.getPath("logs");
/**
* See: [Note: Legacy face crops]
*/
export const legacyFaceCrop = async (
faceID: string,
): Promise<Uint8Array | undefined> => {
// See: [Note: Getting the cache path]
// @ts-expect-error "cache" works but is not part of the public API.
const cacheDir = path.join(app.getPath("cache"), "ente");
const filePath = path.join(cacheDir, "face-crops", faceID);
return existsSync(filePath) ? await fs.readFile(filePath) : undefined;
};

View file

@ -11,7 +11,7 @@ import * as ort from "onnxruntime-node";
import Tokenizer from "../../thirdparty/clip-bpe-ts/mod";
import log from "../log";
import { writeStream } from "../stream";
import { ensure } from "../utils/common";
import { ensure, wait } from "../utils/common";
import { deleteTempFile, makeTempFilePath } from "../utils/temp";
import { makeCachedInferenceSession } from "./ml";
@ -20,7 +20,7 @@ const cachedCLIPImageSession = makeCachedInferenceSession(
351468764 /* 335.2 MB */,
);
export const clipImageEmbedding = async (jpegImageData: Uint8Array) => {
export const computeCLIPImageEmbedding = async (jpegImageData: Uint8Array) => {
const tempFilePath = await makeTempFilePath();
const imageStream = new Response(jpegImageData.buffer).body;
await writeStream(tempFilePath, ensure(imageStream));
@ -42,7 +42,7 @@ const clipImageEmbedding_ = async (jpegFilePath: string) => {
const results = await session.run(feeds);
log.debug(
() =>
`onnx/clip image embedding took ${Date.now() - t1} ms (prep: ${t2 - t1} ms, inference: ${Date.now() - t2} ms)`,
`ONNX/CLIP image embedding took ${Date.now() - t1} ms (prep: ${t2 - t1} ms, inference: ${Date.now() - t2} ms)`,
);
/* Need these model specific casts to type the result */
const imageEmbedding = ensure(results.output).data as Float32Array;
@ -140,21 +140,23 @@ const getTokenizer = () => {
return _tokenizer;
};
export const clipTextEmbeddingIfAvailable = async (text: string) => {
const sessionOrStatus = await Promise.race([
export const computeCLIPTextEmbeddingIfAvailable = async (text: string) => {
const sessionOrSkip = await Promise.race([
cachedCLIPTextSession(),
"downloading-model",
// Wait for a tick to get the session promise to resolved the first time
// this code runs on each app start (and the model has been downloaded).
wait(0).then(() => 1),
]);
// Don't wait for the download to complete
if (typeof sessionOrStatus == "string") {
// Don't wait for the download to complete.
if (typeof sessionOrSkip == "number") {
log.info(
"Ignoring CLIP text embedding request because model download is pending",
);
return undefined;
}
const session = sessionOrStatus;
const session = sessionOrSkip;
const t1 = Date.now();
const tokenizer = getTokenizer();
const tokenizedText = Int32Array.from(tokenizer.encodeForCLIP(text));
@ -165,7 +167,7 @@ export const clipTextEmbeddingIfAvailable = async (text: string) => {
const results = await session.run(feeds);
log.debug(
() =>
`onnx/clip text embedding took ${Date.now() - t1} ms (prep: ${t2 - t1} ms, inference: ${Date.now() - t2} ms)`,
`ONNX/CLIP text embedding took ${Date.now() - t1} ms (prep: ${t2 - t1} ms, inference: ${Date.now() - t2} ms)`,
);
const textEmbedding = ensure(results.output).data as Float32Array;
return normalizeEmbedding(textEmbedding);

View file

@ -23,7 +23,7 @@ export const detectFaces = async (input: Float32Array) => {
input: new ort.Tensor("float32", input, [1, 3, 640, 640]),
};
const results = await session.run(feeds);
log.debug(() => `onnx/yolo face detection took ${Date.now() - t} ms`);
log.debug(() => `ONNX/YOLO face detection took ${Date.now() - t} ms`);
return ensure(results.output).data;
};
@ -32,7 +32,7 @@ const cachedFaceEmbeddingSession = makeCachedInferenceSession(
5286998 /* 5 MB */,
);
export const faceEmbeddings = async (input: Float32Array) => {
export const computeFaceEmbeddings = async (input: Float32Array) => {
// Dimension of each face (alias)
const mobileFaceNetFaceSize = 112;
// Smaller alias
@ -45,7 +45,7 @@ export const faceEmbeddings = async (input: Float32Array) => {
const t = Date.now();
const feeds = { img_inputs: inputTensor };
const results = await session.run(feeds);
log.debug(() => `onnx/yolo face embedding took ${Date.now() - t} ms`);
log.debug(() => `ONNX/MFNT face embedding took ${Date.now() - t} ms`);
/* Need these model specific casts to extract and type the result */
return (results.embeddings as unknown as Record<string, unknown>)
.cpuData as Float32Array;

View file

@ -13,3 +13,12 @@ export const ensure = <T>(v: T | null | undefined): T => {
if (v === undefined) throw new Error("Required value was not found");
return v;
};
/**
* Wait for {@link ms} milliseconds
*
* This function is a promisified `setTimeout`. It returns a promise that
* resolves after {@link ms} milliseconds.
*/
export const wait = (ms: number) =>
new Promise((resolve) => setTimeout(resolve, ms));
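
For context, here is a minimal sketch (the helper names other than `wait` are illustrative and not part of this commit) of the trick that computeCLIPTextEmbeddingIfAvailable above uses `wait` for: racing a promise against `wait(0)` so that an already-resolved promise wins the race, while a still-pending one is skipped instead of awaited.

const pendingSentinel = Symbol("pending");

/** Return the value of {@link promise} if it has already settled, else undefined. */
const valueIfAlreadySettled = async <T>(promise: Promise<T>) => {
    // A resolved promise wins the race: its continuation is a microtask that
    // runs before the setTimeout(0) macrotask scheduled by wait(0).
    const result = await Promise.race([
        promise,
        wait(0).then(() => pendingSentinel),
    ]);
    return result === pendingSentinel ? undefined : (result as T);
};

// e.g. valueIfAlreadySettled(cachedCLIPTextSession()) yields the ONNX session
// once the model has been downloaded, and undefined while it is still pending.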

View file

@ -55,9 +55,7 @@ export const execAsync = async (command: string | string[]) => {
: command;
const startTime = Date.now();
const result = await execAsync_(escapedCommand);
log.debug(
() => `${escapedCommand} (${Math.round(Date.now() - startTime)} ms)`,
);
log.debug(() => `${escapedCommand} (${Date.now() - startTime} ms)`);
return result;
};

View file

@ -153,20 +153,17 @@ const ffmpegExec = (
// - ML
const clipImageEmbedding = (jpegImageData: Uint8Array) =>
ipcRenderer.invoke("clipImageEmbedding", jpegImageData);
const computeCLIPImageEmbedding = (jpegImageData: Uint8Array) =>
ipcRenderer.invoke("computeCLIPImageEmbedding", jpegImageData);
const clipTextEmbeddingIfAvailable = (text: string) =>
ipcRenderer.invoke("clipTextEmbeddingIfAvailable", text);
const computeCLIPTextEmbeddingIfAvailable = (text: string) =>
ipcRenderer.invoke("computeCLIPTextEmbeddingIfAvailable", text);
const detectFaces = (input: Float32Array) =>
ipcRenderer.invoke("detectFaces", input);
const faceEmbeddings = (input: Float32Array) =>
ipcRenderer.invoke("faceEmbeddings", input);
const legacyFaceCrop = (faceID: string) =>
ipcRenderer.invoke("legacyFaceCrop", faceID);
const computeFaceEmbeddings = (input: Float32Array) =>
ipcRenderer.invoke("computeFaceEmbeddings", input);
// - Watch
@ -340,11 +337,10 @@ contextBridge.exposeInMainWorld("electron", {
// - ML
clipImageEmbedding,
clipTextEmbeddingIfAvailable,
computeCLIPImageEmbedding,
computeCLIPTextEmbeddingIfAvailable,
detectFaces,
faceEmbeddings,
legacyFaceCrop,
computeFaceEmbeddings,
// - Watch
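
For context, a minimal renderer-side usage sketch of the renamed bridge methods exposed above (the wrapper function and inputs are illustrative, not part of this commit; it assumes the typed `electron` object that the web app reads off `globalThis`):

const logCLIPEmbeddings = async (jpegBytes: Uint8Array) => {
    const electron = globalThis.electron;
    if (!electron) return; // Not running inside the desktop app.
    const imageEmbedding = await electron.computeCLIPImageEmbedding(jpegBytes);
    // Will be undefined while the text model is still being downloaded.
    const textEmbedding =
        await electron.computeCLIPTextEmbeddingIfAvailable("sunset on a beach");
    console.log(imageEmbedding.length, textEmbedding?.length);
};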

View file

@ -71,37 +71,21 @@ func (c *Controller) deleteEmbedding(qItem repo.QueueItem) {
ctxLogger.WithError(err).Error("Failed to fetch datacenters")
return
}
// Ensure that the objects are deleted from the active derived storage DC. Ideally, this section should never be executed
// unless there's a bug in storing the DC or the service restarts before removing the rows from the table
// todo:(neeraj): remove this section after a few weeks of deployment
if len(datacenters) == 0 {
ctxLogger.Warn("No datacenters found for file, ensuring deletion from derived storage and hot DC")
err = c.ObjectCleanupController.DeleteAllObjectsWithPrefix(prefix, c.S3Config.GetDerivedStorageDataCenter())
if err != nil {
ctxLogger.WithError(err).Error("Failed to delete all objects")
return
}
// if Derived DC is different from hot DC, delete from hot DC as well
if c.derivedStorageDataCenter != c.S3Config.GetHotDataCenter() {
err = c.ObjectCleanupController.DeleteAllObjectsWithPrefix(prefix, c.S3Config.GetHotDataCenter())
if err != nil {
ctxLogger.WithError(err).Error("Failed to delete all objects from hot DC")
return
}
}
} else {
ctxLogger.Infof("Deleting from all datacenters %v", datacenters)
}
ctxLogger.Infof("Deleting from all datacenters %v", datacenters)
for i := range datacenters {
err = c.ObjectCleanupController.DeleteAllObjectsWithPrefix(prefix, datacenters[i])
dc := datacenters[i]
err = c.ObjectCleanupController.DeleteAllObjectsWithPrefix(prefix, dc)
if err != nil {
ctxLogger.WithError(err).Errorf("Failed to delete all objects from %s", datacenters[i])
ctxLogger.WithError(err).
WithField("dc", dc).
Errorf("Failed to delete all objects from %s", datacenters[i])
return
} else {
removeErr := c.Repo.RemoveDatacenter(context.Background(), fileID, datacenters[i])
if removeErr != nil {
ctxLogger.WithError(removeErr).Error("Failed to remove datacenter from db")
ctxLogger.WithError(removeErr).
WithField("dc", dc).
Error("Failed to remove datacenter from db")
return
}
}

View file

@ -260,7 +260,10 @@ func (c *ObjectCleanupController) DeleteAllObjectsWithPrefix(prefix string, dc s
Prefix: &prefix,
})
if err != nil {
log.Error(err)
log.WithFields(log.Fields{
"prefix": prefix,
"dc": dc,
}).WithError(err).Error("Failed to list objects")
return stacktrace.Propagate(err, "")
}
var keys []string
@ -270,7 +273,10 @@ func (c *ObjectCleanupController) DeleteAllObjectsWithPrefix(prefix string, dc s
for _, key := range keys {
err = c.DeleteObjectFromDataCenter(key, dc)
if err != nil {
log.Error(err)
log.WithFields(log.Fields{
"object_key": key,
"dc": dc,
}).WithError(err).Error("Failed to delete object")
return stacktrace.Propagate(err, "")
}
}

View file

@ -9,7 +9,7 @@ import { useCallback, useContext, useEffect, useRef, useState } from "react";
import { components } from "react-select";
import AsyncSelect from "react-select/async";
import { InputActionMeta } from "react-select/src/types";
import { Person } from "services/face/types";
import type { Person } from "services/face/people";
import { City } from "services/locationSearchService";
import {
getAutoCompleteSuggestions,

View file

@ -270,14 +270,7 @@ function EnableMLSearch({ onClose, enableMlSearch, onRootClose }) {
{" "}
<Typography color="text.muted">
{/* <Trans i18nKey={"ENABLE_ML_SEARCH_DESCRIPTION"} /> */}
<p>
We're putting finishing touches, coming back soon!
</p>
<p>
<small>
Existing indexed faces will continue to show.
</small>
</p>
We're putting finishing touches, coming back soon!
</Typography>
</Box>
{isInternalUserForML() && (

View file

@ -1,10 +1,11 @@
import { blobCache } from "@/next/blob-cache";
import log from "@/next/log";
import { Skeleton, styled } from "@mui/material";
import { Legend } from "components/PhotoViewer/styledComponents/Legend";
import { t } from "i18next";
import React, { useEffect, useState } from "react";
import mlIDbStorage from "services/face/db";
import { Face, Person, type MlFileData } from "services/face/types";
import type { Person } from "services/face/people";
import { EnteFile } from "types/file";
const FaceChipContainer = styled("div")`
@ -57,10 +58,7 @@ export const PeopleList = React.memo((props: PeopleListProps) => {
props.onSelect && props.onSelect(person, index)
}
>
<FaceCropImageView
faceID={person.displayFaceId}
cacheKey={person.faceCropCacheKey}
/>
<FaceCropImageView faceID={person.displayFaceId} />
</FaceChip>
))}
</FaceChipContainer>
@ -108,7 +106,7 @@ export function UnidentifiedFaces(props: {
file: EnteFile;
updateMLDataIndex: number;
}) {
const [faces, setFaces] = useState<Array<Face>>([]);
const [faces, setFaces] = useState<{ id: string }[]>([]);
useEffect(() => {
let didCancel = false;
@ -136,10 +134,7 @@ export function UnidentifiedFaces(props: {
{faces &&
faces.map((face, index) => (
<FaceChip key={index}>
<FaceCropImageView
faceID={face.id}
cacheKey={face.crop?.cacheKey}
/>
<FaceCropImageView faceID={face.id} />
</FaceChip>
))}
</FaceChipContainer>
@ -149,29 +144,22 @@ export function UnidentifiedFaces(props: {
interface FaceCropImageViewProps {
faceID: string;
cacheKey?: string;
}
const FaceCropImageView: React.FC<FaceCropImageViewProps> = ({
faceID,
cacheKey,
}) => {
const FaceCropImageView: React.FC<FaceCropImageViewProps> = ({ faceID }) => {
const [objectURL, setObjectURL] = useState<string | undefined>();
useEffect(() => {
let didCancel = false;
const electron = globalThis.electron;
if (faceID && electron) {
electron
.legacyFaceCrop(faceID)
/*
cachedOrNew("face-crops", cacheKey, async () => {
return machineLearningService.regenerateFaceCrop(
faceId,
);
})*/
if (faceID) {
blobCache("face-crops")
.then((cache) => cache.get(faceID))
.then((data) => {
/*
TODO(MR): regen if needed and get this to work on web too.
cachedOrNew("face-crops", cacheKey, async () => {
return regenerateFaceCrop(faceId);
})*/
if (data) {
const blob = new Blob([data]);
if (!didCancel) setObjectURL(URL.createObjectURL(blob));
@ -183,7 +171,7 @@ const FaceCropImageView: React.FC<FaceCropImageViewProps> = ({
didCancel = true;
if (objectURL) URL.revokeObjectURL(objectURL);
};
}, [faceID, cacheKey]);
}, [faceID]);
return objectURL ? (
<img src={objectURL} />
@ -192,9 +180,9 @@ const FaceCropImageView: React.FC<FaceCropImageViewProps> = ({
);
};
async function getPeopleList(file: EnteFile): Promise<Array<Person>> {
async function getPeopleList(file: EnteFile): Promise<Person[]> {
let startTime = Date.now();
const mlFileData: MlFileData = await mlIDbStorage.getFile(file.id);
const mlFileData = await mlIDbStorage.getFile(file.id);
log.info(
"getPeopleList:mlFilesStore:getItem",
Date.now() - startTime,
@ -226,8 +214,8 @@ async function getPeopleList(file: EnteFile): Promise<Array<Person>> {
return peopleList;
}
async function getUnidentifiedFaces(file: EnteFile): Promise<Array<Face>> {
const mlFileData: MlFileData = await mlIDbStorage.getFile(file.id);
async function getUnidentifiedFaces(file: EnteFile): Promise<{ id: string }[]> {
const mlFileData = await mlIDbStorage.getFile(file.id);
return mlFileData?.faces?.filter(
(f) => f.personId === null || f.personId === undefined,

View file

@ -184,7 +184,7 @@ class CLIPService {
};
getTextEmbeddingIfAvailable = async (text: string) => {
return ensureElectron().clipTextEmbeddingIfAvailable(text);
return ensureElectron().computeCLIPTextEmbeddingIfAvailable(text);
};
private runClipEmbeddingExtraction = async (canceller: AbortController) => {
@ -294,7 +294,7 @@ class CLIPService {
const file = await localFile
.arrayBuffer()
.then((buffer) => new Uint8Array(buffer));
return await ensureElectron().clipImageEmbedding(file);
return await ensureElectron().computeCLIPImageEmbedding(file);
};
private encryptAndUploadEmbedding = async (
@ -328,7 +328,8 @@ class CLIPService {
private extractFileClipImageEmbedding = async (file: EnteFile) => {
const thumb = await downloadManager.getThumbnail(file);
const embedding = await ensureElectron().clipImageEmbedding(thumb);
const embedding =
await ensureElectron().computeCLIPImageEmbedding(thumb);
return embedding;
};

View file

@ -1,6 +1,6 @@
import { FILE_TYPE } from "@/media/file-type";
import { decodeLivePhoto } from "@/media/live-photo";
import { openCache, type BlobCache } from "@/next/blob-cache";
import { blobCache, type BlobCache } from "@/next/blob-cache";
import log from "@/next/log";
import { APPS } from "@ente/shared/apps/constants";
import ComlinkCryptoWorker from "@ente/shared/crypto";
@ -91,7 +91,7 @@ class DownloadManagerImpl {
}
this.downloadClient = createDownloadClient(app, tokens);
try {
this.thumbnailCache = await openCache("thumbs");
this.thumbnailCache = await blobCache("thumbs");
} catch (e) {
log.error(
"Failed to open thumbnail cache, will continue without it",
@ -100,7 +100,7 @@ class DownloadManagerImpl {
}
// TODO (MR): Revisit full file caching cf disk space usage
// try {
// if (isElectron()) this.fileCache = await openCache("files");
// if (isElectron()) this.fileCache = await cache("files");
// } catch (e) {
// log.error("Failed to open file cache, will continue without it", e);
// }

View file

@ -7,7 +7,7 @@ import HTTPService from "@ente/shared/network/HTTPService";
import { getEndpoint } from "@ente/shared/network/api";
import localForage from "@ente/shared/storage/localForage";
import { getToken } from "@ente/shared/storage/localStorage/helpers";
import { FileML } from "services/machineLearning/machineLearningService";
import { FileML } from "services/face/remote";
import type {
Embedding,
EmbeddingModel,

View file

@ -1,88 +0,0 @@
import { Matrix } from "ml-matrix";
import { Point } from "services/face/geom";
import { FaceAlignment, FaceDetection } from "services/face/types";
import { getSimilarityTransformation } from "similarity-transformation";
const ARCFACE_LANDMARKS = [
[38.2946, 51.6963],
[73.5318, 51.5014],
[56.0252, 71.7366],
[56.1396, 92.2848],
] as Array<[number, number]>;
const ARCFACE_LANDMARKS_FACE_SIZE = 112;
const ARC_FACE_5_LANDMARKS = [
[38.2946, 51.6963],
[73.5318, 51.5014],
[56.0252, 71.7366],
[41.5493, 92.3655],
[70.7299, 92.2041],
] as Array<[number, number]>;
/**
* Compute and return an {@link FaceAlignment} for the given face detection.
*
* @param faceDetection A geometry indicating a face detected in an image.
*/
export const faceAlignment = (faceDetection: FaceDetection): FaceAlignment => {
const landmarkCount = faceDetection.landmarks.length;
return getFaceAlignmentUsingSimilarityTransform(
faceDetection,
normalizeLandmarks(
landmarkCount === 5 ? ARC_FACE_5_LANDMARKS : ARCFACE_LANDMARKS,
ARCFACE_LANDMARKS_FACE_SIZE,
),
);
};
function getFaceAlignmentUsingSimilarityTransform(
faceDetection: FaceDetection,
alignedLandmarks: Array<[number, number]>,
): FaceAlignment {
const landmarksMat = new Matrix(
faceDetection.landmarks
.map((p) => [p.x, p.y])
.slice(0, alignedLandmarks.length),
).transpose();
const alignedLandmarksMat = new Matrix(alignedLandmarks).transpose();
const simTransform = getSimilarityTransformation(
landmarksMat,
alignedLandmarksMat,
);
const RS = Matrix.mul(simTransform.rotation, simTransform.scale);
const TR = simTransform.translation;
const affineMatrix = [
[RS.get(0, 0), RS.get(0, 1), TR.get(0, 0)],
[RS.get(1, 0), RS.get(1, 1), TR.get(1, 0)],
[0, 0, 1],
];
const size = 1 / simTransform.scale;
const meanTranslation = simTransform.toMean.sub(0.5).mul(size);
const centerMat = simTransform.fromMean.sub(meanTranslation);
const center = new Point(centerMat.get(0, 0), centerMat.get(1, 0));
const rotation = -Math.atan2(
simTransform.rotation.get(0, 1),
simTransform.rotation.get(0, 0),
);
return {
affineMatrix,
center,
size,
rotation,
};
}
function normalizeLandmarks(
landmarks: Array<[number, number]>,
faceSize: number,
): Array<[number, number]> {
return landmarks.map((landmark) =>
landmark.map((p) => p / faceSize),
) as Array<[number, number]>;
}

View file

@ -1,187 +0,0 @@
import { Face } from "services/face/types";
import { createGrayscaleIntMatrixFromNormalized2List } from "utils/image";
import { mobileFaceNetFaceSize } from "./embed";
/**
* Laplacian blur detection.
*/
export const detectBlur = (
alignedFaces: Float32Array,
faces: Face[],
): number[] => {
const numFaces = Math.round(
alignedFaces.length /
(mobileFaceNetFaceSize * mobileFaceNetFaceSize * 3),
);
const blurValues: number[] = [];
for (let i = 0; i < numFaces; i++) {
const face = faces[i];
const direction = faceDirection(face);
const faceImage = createGrayscaleIntMatrixFromNormalized2List(
alignedFaces,
i,
);
const laplacian = applyLaplacian(faceImage, direction);
blurValues.push(matrixVariance(laplacian));
}
return blurValues;
};
type FaceDirection = "left" | "right" | "straight";
const faceDirection = (face: Face): FaceDirection => {
const landmarks = face.detection.landmarks;
const leftEye = landmarks[0];
const rightEye = landmarks[1];
const nose = landmarks[2];
const leftMouth = landmarks[3];
const rightMouth = landmarks[4];
const eyeDistanceX = Math.abs(rightEye.x - leftEye.x);
const eyeDistanceY = Math.abs(rightEye.y - leftEye.y);
const mouthDistanceY = Math.abs(rightMouth.y - leftMouth.y);
const faceIsUpright =
Math.max(leftEye.y, rightEye.y) + 0.5 * eyeDistanceY < nose.y &&
nose.y + 0.5 * mouthDistanceY < Math.min(leftMouth.y, rightMouth.y);
const noseStickingOutLeft =
nose.x < Math.min(leftEye.x, rightEye.x) &&
nose.x < Math.min(leftMouth.x, rightMouth.x);
const noseStickingOutRight =
nose.x > Math.max(leftEye.x, rightEye.x) &&
nose.x > Math.max(leftMouth.x, rightMouth.x);
const noseCloseToLeftEye =
Math.abs(nose.x - leftEye.x) < 0.2 * eyeDistanceX;
const noseCloseToRightEye =
Math.abs(nose.x - rightEye.x) < 0.2 * eyeDistanceX;
if (noseStickingOutLeft || (faceIsUpright && noseCloseToLeftEye)) {
return "left";
} else if (noseStickingOutRight || (faceIsUpright && noseCloseToRightEye)) {
return "right";
}
return "straight";
};
/**
* Return a new image by applying a Laplacian blur kernel to each pixel.
*/
const applyLaplacian = (
image: number[][],
direction: FaceDirection,
): number[][] => {
const paddedImage: number[][] = padImage(image, direction);
const numRows = paddedImage.length - 2;
const numCols = paddedImage[0].length - 2;
// Create an output image initialized to 0.
const outputImage: number[][] = Array.from({ length: numRows }, () =>
new Array(numCols).fill(0),
);
// Define the Laplacian kernel.
const kernel: number[][] = [
[0, 1, 0],
[1, -4, 1],
[0, 1, 0],
];
// Apply the kernel to each pixel
for (let i = 0; i < numRows; i++) {
for (let j = 0; j < numCols; j++) {
let sum = 0;
for (let ki = 0; ki < 3; ki++) {
for (let kj = 0; kj < 3; kj++) {
sum += paddedImage[i + ki][j + kj] * kernel[ki][kj];
}
}
// Adjust the output value if necessary (e.g., clipping).
outputImage[i][j] = sum;
}
}
return outputImage;
};
const padImage = (image: number[][], direction: FaceDirection): number[][] => {
const removeSideColumns = 56; /* must be even */
const numRows = image.length;
const numCols = image[0].length;
const paddedNumCols = numCols + 2 - removeSideColumns;
const paddedNumRows = numRows + 2;
// Create a new matrix with extra padding.
const paddedImage: number[][] = Array.from({ length: paddedNumRows }, () =>
new Array(paddedNumCols).fill(0),
);
if (direction === "straight") {
// Copy original image into the center of the padded image.
for (let i = 0; i < numRows; i++) {
for (let j = 0; j < paddedNumCols - 2; j++) {
paddedImage[i + 1][j + 1] =
image[i][j + Math.round(removeSideColumns / 2)];
}
}
} else if (direction === "left") {
// If the face is facing left, we only take the right side of the face image.
for (let i = 0; i < numRows; i++) {
for (let j = 0; j < paddedNumCols - 2; j++) {
paddedImage[i + 1][j + 1] = image[i][j + removeSideColumns];
}
}
} else if (direction === "right") {
// If the face is facing right, we only take the left side of the face image.
for (let i = 0; i < numRows; i++) {
for (let j = 0; j < paddedNumCols - 2; j++) {
paddedImage[i + 1][j + 1] = image[i][j];
}
}
}
// Reflect padding
// Top and bottom rows
for (let j = 1; j <= paddedNumCols - 2; j++) {
paddedImage[0][j] = paddedImage[2][j]; // Top row
paddedImage[numRows + 1][j] = paddedImage[numRows - 1][j]; // Bottom row
}
// Left and right columns
for (let i = 0; i < numRows + 2; i++) {
paddedImage[i][0] = paddedImage[i][2]; // Left column
paddedImage[i][paddedNumCols - 1] = paddedImage[i][paddedNumCols - 3]; // Right column
}
return paddedImage;
};
const matrixVariance = (matrix: number[][]): number => {
const numRows = matrix.length;
const numCols = matrix[0].length;
const totalElements = numRows * numCols;
// Calculate the mean.
let mean: number = 0;
matrix.forEach((row) => {
row.forEach((value) => {
mean += value;
});
});
mean /= totalElements;
// Calculate the variance.
let variance: number = 0;
matrix.forEach((row) => {
row.forEach((value) => {
const diff: number = value - mean;
variance += diff * diff;
});
});
variance /= totalElements;
return variance;
};

View file

@ -1,8 +1,9 @@
import { Hdbscan, type DebugInfo } from "hdbscan";
import { type Cluster } from "services/face/types";
export type Cluster = number[];
export interface ClusterFacesResult {
clusters: Array<Cluster>;
clusters: Cluster[];
noise: Cluster;
debugInfo?: DebugInfo;
}

View file

@ -1,32 +0,0 @@
import { Box, enlargeBox } from "services/face/geom";
import { FaceCrop, FaceDetection } from "services/face/types";
import { cropWithRotation } from "utils/image";
import { faceAlignment } from "./align";
export const getFaceCrop = (
imageBitmap: ImageBitmap,
faceDetection: FaceDetection,
): FaceCrop => {
const alignment = faceAlignment(faceDetection);
const padding = 0.25;
const maxSize = 256;
const alignmentBox = new Box({
x: alignment.center.x - alignment.size / 2,
y: alignment.center.y - alignment.size / 2,
width: alignment.size,
height: alignment.size,
}).round();
const scaleForPadding = 1 + padding * 2;
const paddedBox = enlargeBox(alignmentBox, scaleForPadding).round();
const faceImageBitmap = cropWithRotation(imageBitmap, paddedBox, 0, {
width: maxSize,
height: maxSize,
});
return {
image: faceImageBitmap,
imageBox: paddedBox,
};
};

View file

@ -9,7 +9,8 @@ import {
openDB,
} from "idb";
import isElectron from "is-electron";
import { Face, MLLibraryData, MlFileData, Person } from "services/face/types";
import type { Person } from "services/face/people";
import type { MlFileData } from "services/face/types";
import {
DEFAULT_ML_SEARCH_CONFIG,
MAX_ML_SYNC_ERROR_COUNT,
@ -23,6 +24,18 @@ export interface IndexStatus {
peopleIndexSynced: boolean;
}
/**
* TODO(MR): Transient type holding the intersection of fields that both the
* existing and the new types will have during the migration. Eventually we'll
* store the server ML data shape here exactly.
*/
export interface MinimalPersistedFileData {
fileId: number;
mlVersion: number;
errorCount: number;
faces?: { personId?: number; id: string }[];
}
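
As an illustration, a value conforming to this transient shape (all field values below are hypothetical):

const exampleFileData: MinimalPersistedFileData = {
    fileId: 9001, // hypothetical file ID
    mlVersion: 1, // whichever version the indexer currently stamps
    errorCount: 0,
    // During the migration, faces only need an id here (and optionally the
    // personId of the cluster they were assigned to).
    faces: [{ id: "9001-face-1" }],
};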
interface Config {}
export const ML_SEARCH_CONFIG_NAME = "ml-search";
@ -31,7 +44,7 @@ const MLDATA_DB_NAME = "mldata";
interface MLDb extends DBSchema {
files: {
key: number;
value: MlFileData;
value: MinimalPersistedFileData;
indexes: { mlVersion: [number, number] };
};
people: {
@ -50,7 +63,7 @@ interface MLDb extends DBSchema {
};
library: {
key: string;
value: MLLibraryData;
value: unknown;
};
configs: {
key: string;
@ -177,6 +190,7 @@ class MLIDbStorage {
ML_SEARCH_CONFIG_NAME,
);
db.deleteObjectStore("library");
db.deleteObjectStore("things");
} catch {
// TODO: ignore for now as we finalize the new version
@ -210,38 +224,6 @@ class MLIDbStorage {
await this.db;
}
public async getAllFileIds() {
const db = await this.db;
return db.getAllKeys("files");
}
public async putAllFilesInTx(mlFiles: Array<MlFileData>) {
const db = await this.db;
const tx = db.transaction("files", "readwrite");
await Promise.all(mlFiles.map((mlFile) => tx.store.put(mlFile)));
await tx.done;
}
public async removeAllFilesInTx(fileIds: Array<number>) {
const db = await this.db;
const tx = db.transaction("files", "readwrite");
await Promise.all(fileIds.map((fileId) => tx.store.delete(fileId)));
await tx.done;
}
public async newTransaction<
Name extends StoreNames<MLDb>,
Mode extends IDBTransactionMode = "readonly",
>(storeNames: Name, mode?: Mode) {
const db = await this.db;
return db.transaction(storeNames, mode);
}
public async commit(tx: IDBPTransaction<MLDb>) {
return tx.done;
}
public async getAllFileIdsForUpdate(
tx: IDBPTransaction<MLDb, ["files"], "readwrite">,
) {
@ -275,16 +257,11 @@ class MLIDbStorage {
return fileIds;
}
public async getFile(fileId: number) {
public async getFile(fileId: number): Promise<MinimalPersistedFileData> {
const db = await this.db;
return db.get("files", fileId);
}
public async getAllFiles() {
const db = await this.db;
return db.getAll("files");
}
public async putFile(mlFile: MlFileData) {
const db = await this.db;
return db.put("files", mlFile);
@ -292,7 +269,7 @@ class MLIDbStorage {
public async upsertFileInTx(
fileId: number,
upsert: (mlFile: MlFileData) => MlFileData,
upsert: (mlFile: MinimalPersistedFileData) => MinimalPersistedFileData,
) {
const db = await this.db;
const tx = db.transaction("files", "readwrite");
@ -305,7 +282,7 @@ class MLIDbStorage {
}
public async putAllFiles(
mlFiles: Array<MlFileData>,
mlFiles: MinimalPersistedFileData[],
tx: IDBPTransaction<MLDb, ["files"], "readwrite">,
) {
await Promise.all(mlFiles.map((mlFile) => tx.store.put(mlFile)));
@ -318,44 +295,6 @@ class MLIDbStorage {
await Promise.all(fileIds.map((fileId) => tx.store.delete(fileId)));
}
public async getFace(fileID: number, faceId: string) {
const file = await this.getFile(fileID);
const face = file.faces.filter((f) => f.id === faceId);
return face[0];
}
public async getAllFacesMap() {
const startTime = Date.now();
const db = await this.db;
const allFiles = await db.getAll("files");
const allFacesMap = new Map<number, Array<Face>>();
allFiles.forEach(
(mlFileData) =>
mlFileData.faces &&
allFacesMap.set(mlFileData.fileId, mlFileData.faces),
);
log.info("getAllFacesMap", Date.now() - startTime, "ms");
return allFacesMap;
}
public async updateFaces(allFacesMap: Map<number, Face[]>) {
const startTime = Date.now();
const db = await this.db;
const tx = db.transaction("files", "readwrite");
let cursor = await tx.store.openCursor();
while (cursor) {
if (allFacesMap.has(cursor.key)) {
const mlFileData = { ...cursor.value };
mlFileData.faces = allFacesMap.get(cursor.key);
cursor.update(mlFileData);
}
cursor = await cursor.continue();
}
await tx.done;
log.info("updateFaces", Date.now() - startTime, "ms");
}
public async getPerson(id: number) {
const db = await this.db;
return db.get("people", id);
@ -366,21 +305,6 @@ class MLIDbStorage {
return db.getAll("people");
}
public async putPerson(person: Person) {
const db = await this.db;
return db.put("people", person);
}
public async clearAllPeople() {
const db = await this.db;
return db.clear("people");
}
public async getIndexVersion(index: string) {
const db = await this.db;
return db.get("versions", index);
}
public async incrementIndexVersion(index: StoreNames<MLDb>) {
if (index === "versions") {
throw new Error("versions store can not be versioned");
@ -395,21 +319,6 @@ class MLIDbStorage {
return version;
}
public async setIndexVersion(index: string, version: number) {
const db = await this.db;
return db.put("versions", version, index);
}
public async getLibraryData() {
const db = await this.db;
return db.get("library", "data");
}
public async putLibraryData(data: MLLibraryData) {
const db = await this.db;
return db.put("library", data, "data");
}
public async getConfig<T extends Config>(name: string, def: T) {
const db = await this.db;
const tx = db.transaction("configs", "readwrite");
@ -473,66 +382,6 @@ class MLIDbStorage {
peopleIndexVersion === filesIndexVersion,
};
}
// for debug purpose
public async getAllMLData() {
const db = await this.db;
const tx = db.transaction(db.objectStoreNames, "readonly");
const allMLData: any = {};
for (const store of tx.objectStoreNames) {
const keys = await tx.objectStore(store).getAllKeys();
const data = await tx.objectStore(store).getAll();
allMLData[store] = {};
for (let i = 0; i < keys.length; i++) {
allMLData[store][keys[i]] = data[i];
}
}
await tx.done;
const files = allMLData["files"];
for (const fileId of Object.keys(files)) {
const fileData = files[fileId];
fileData.faces?.forEach(
(f) => (f.embedding = Array.from(f.embedding)),
);
}
return allMLData;
}
// for debug purpose, this will overwrite all data
public async putAllMLData(allMLData: Map<string, any>) {
const db = await this.db;
const tx = db.transaction(db.objectStoreNames, "readwrite");
for (const store of tx.objectStoreNames) {
const records = allMLData[store];
if (!records) {
continue;
}
const txStore = tx.objectStore(store);
if (store === "files") {
const files = records;
for (const fileId of Object.keys(files)) {
const fileData = files[fileId];
fileData.faces?.forEach(
(f) => (f.embedding = Float32Array.from(f.embedding)),
);
}
}
await txStore.clear();
for (const key of Object.keys(records)) {
if (txStore.keyPath) {
txStore.put(records[key]);
} else {
txStore.put(records[key], key);
}
}
}
await tx.done;
}
}
export default new MLIDbStorage();

View file

@ -1,316 +0,0 @@
import { workerBridge } from "@/next/worker/worker-bridge";
import { euclidean } from "hdbscan";
import {
Box,
Dimensions,
Point,
boxFromBoundingBox,
newBox,
} from "services/face/geom";
import { FaceDetection } from "services/face/types";
import {
Matrix,
applyToPoint,
compose,
scale,
translate,
} from "transformation-matrix";
import {
clamp,
getPixelBilinear,
normalizePixelBetween0And1,
} from "utils/image";
/**
* Detect faces in the given {@link imageBitmap}.
*
* The model used is YOLO, running in an ONNX runtime.
*/
export const detectFaces = async (
imageBitmap: ImageBitmap,
): Promise<Array<FaceDetection>> => {
const maxFaceDistancePercent = Math.sqrt(2) / 100;
const maxFaceDistance = imageBitmap.width * maxFaceDistancePercent;
const preprocessResult = preprocessImageBitmapToFloat32ChannelsFirst(
imageBitmap,
640,
640,
);
const data = preprocessResult.data;
const resized = preprocessResult.newSize;
const outputData = await workerBridge.detectFaces(data);
const faces = getFacesFromYOLOOutput(outputData as Float32Array, 0.7);
const inBox = newBox(0, 0, resized.width, resized.height);
const toBox = newBox(0, 0, imageBitmap.width, imageBitmap.height);
const transform = computeTransformToBox(inBox, toBox);
const faceDetections: Array<FaceDetection> = faces?.map((f) => {
const box = transformBox(f.box, transform);
const normLandmarks = f.landmarks;
const landmarks = transformPoints(normLandmarks, transform);
return {
box,
landmarks,
probability: f.probability as number,
} as FaceDetection;
});
return removeDuplicateDetections(faceDetections, maxFaceDistance);
};
const preprocessImageBitmapToFloat32ChannelsFirst = (
imageBitmap: ImageBitmap,
requiredWidth: number,
requiredHeight: number,
maintainAspectRatio: boolean = true,
normFunction: (pixelValue: number) => number = normalizePixelBetween0And1,
) => {
// Create an OffscreenCanvas and set its size.
const offscreenCanvas = new OffscreenCanvas(
imageBitmap.width,
imageBitmap.height,
);
const ctx = offscreenCanvas.getContext("2d");
ctx.drawImage(imageBitmap, 0, 0, imageBitmap.width, imageBitmap.height);
const imageData = ctx.getImageData(
0,
0,
imageBitmap.width,
imageBitmap.height,
);
const pixelData = imageData.data;
let scaleW = requiredWidth / imageBitmap.width;
let scaleH = requiredHeight / imageBitmap.height;
if (maintainAspectRatio) {
const scale = Math.min(
requiredWidth / imageBitmap.width,
requiredHeight / imageBitmap.height,
);
scaleW = scale;
scaleH = scale;
}
const scaledWidth = clamp(
Math.round(imageBitmap.width * scaleW),
0,
requiredWidth,
);
const scaledHeight = clamp(
Math.round(imageBitmap.height * scaleH),
0,
requiredHeight,
);
const processedImage = new Float32Array(
1 * 3 * requiredWidth * requiredHeight,
);
// Populate the Float32Array with normalized pixel values
let pixelIndex = 0;
const channelOffsetGreen = requiredHeight * requiredWidth;
const channelOffsetBlue = 2 * requiredHeight * requiredWidth;
for (let h = 0; h < requiredHeight; h++) {
for (let w = 0; w < requiredWidth; w++) {
let pixel: {
r: number;
g: number;
b: number;
};
if (w >= scaledWidth || h >= scaledHeight) {
pixel = { r: 114, g: 114, b: 114 };
} else {
pixel = getPixelBilinear(
w / scaleW,
h / scaleH,
pixelData,
imageBitmap.width,
imageBitmap.height,
);
}
processedImage[pixelIndex] = normFunction(pixel.r);
processedImage[pixelIndex + channelOffsetGreen] = normFunction(
pixel.g,
);
processedImage[pixelIndex + channelOffsetBlue] = normFunction(
pixel.b,
);
pixelIndex++;
}
}
return {
data: processedImage,
originalSize: {
width: imageBitmap.width,
height: imageBitmap.height,
},
newSize: { width: scaledWidth, height: scaledHeight },
};
};
/**
* @param rowOutput A Float32Array of shape [25200, 16], where each row
* represents a bounding box.
*/
const getFacesFromYOLOOutput = (
rowOutput: Float32Array,
minScore: number,
): Array<FaceDetection> => {
const faces: Array<FaceDetection> = [];
// Iterate over each row.
for (let i = 0; i < rowOutput.length; i += 16) {
const score = rowOutput[i + 4];
if (score < minScore) {
continue;
}
// The first 4 values represent the bounding box's coordinates:
//
// (x1, y1, x2, y2)
//
const xCenter = rowOutput[i];
const yCenter = rowOutput[i + 1];
const width = rowOutput[i + 2];
const height = rowOutput[i + 3];
const xMin = xCenter - width / 2.0; // topLeft
const yMin = yCenter - height / 2.0; // topLeft
const leftEyeX = rowOutput[i + 5];
const leftEyeY = rowOutput[i + 6];
const rightEyeX = rowOutput[i + 7];
const rightEyeY = rowOutput[i + 8];
const noseX = rowOutput[i + 9];
const noseY = rowOutput[i + 10];
const leftMouthX = rowOutput[i + 11];
const leftMouthY = rowOutput[i + 12];
const rightMouthX = rowOutput[i + 13];
const rightMouthY = rowOutput[i + 14];
const box = new Box({
x: xMin,
y: yMin,
width: width,
height: height,
});
const probability = score as number;
const landmarks = [
new Point(leftEyeX, leftEyeY),
new Point(rightEyeX, rightEyeY),
new Point(noseX, noseY),
new Point(leftMouthX, leftMouthY),
new Point(rightMouthX, rightMouthY),
];
faces.push({ box, landmarks, probability });
}
return faces;
};
export const getRelativeDetection = (
faceDetection: FaceDetection,
dimensions: Dimensions,
): FaceDetection => {
const oldBox: Box = faceDetection.box;
const box = new Box({
x: oldBox.x / dimensions.width,
y: oldBox.y / dimensions.height,
width: oldBox.width / dimensions.width,
height: oldBox.height / dimensions.height,
});
const oldLandmarks: Point[] = faceDetection.landmarks;
const landmarks = oldLandmarks.map((l) => {
return new Point(l.x / dimensions.width, l.y / dimensions.height);
});
const probability = faceDetection.probability;
return { box, landmarks, probability };
};
/**
* Removes duplicate face detections from an array of detections.
*
* This function sorts the detections by their probability in descending order,
* then iterates over them.
*
* For each detection, it calculates the Euclidean distance to all other
* detections.
*
* If the distance is less than or equal to the specified threshold
* (`withinDistance`), the other detection is considered a duplicate and is
* removed.
*
* @param detections - An array of face detections to remove duplicates from.
*
* @param withinDistance - The maximum Euclidean distance between two detections
* for them to be considered duplicates.
*
* @returns An array of face detections with duplicates removed.
*/
const removeDuplicateDetections = (
detections: Array<FaceDetection>,
withinDistance: number,
) => {
detections.sort((a, b) => b.probability - a.probability);
const isSelected = new Map<number, boolean>();
for (let i = 0; i < detections.length; i++) {
if (isSelected.get(i) === false) {
continue;
}
isSelected.set(i, true);
for (let j = i + 1; j < detections.length; j++) {
if (isSelected.get(j) === false) {
continue;
}
const centeri = getDetectionCenter(detections[i]);
const centerj = getDetectionCenter(detections[j]);
const dist = euclidean(
[centeri.x, centeri.y],
[centerj.x, centerj.y],
);
if (dist <= withinDistance) {
isSelected.set(j, false);
}
}
}
const uniques: Array<FaceDetection> = [];
for (let i = 0; i < detections.length; i++) {
isSelected.get(i) && uniques.push(detections[i]);
}
return uniques;
};
function getDetectionCenter(detection: FaceDetection) {
const center = new Point(0, 0);
// TODO: using the first 4 landmarks is applicable to blazeface only; this
// needs to consider eyes, nose and mouth landmarks to take the center
detection.landmarks?.slice(0, 4).forEach((p) => {
center.x += p.x;
center.y += p.y;
});
return new Point(center.x / 4, center.y / 4);
}
function computeTransformToBox(inBox: Box, toBox: Box): Matrix {
return compose(
translate(toBox.x, toBox.y),
scale(toBox.width / inBox.width, toBox.height / inBox.height),
);
}
function transformPoint(point: Point, transform: Matrix) {
const txdPoint = applyToPoint(transform, point);
return new Point(txdPoint.x, txdPoint.y);
}
function transformPoints(points: Point[], transform: Matrix) {
return points?.map((p) => transformPoint(p, transform));
}
function transformBox(box: Box, transform: Matrix) {
const topLeft = transformPoint(box.topLeft, transform);
const bottomRight = transformPoint(box.bottomRight, transform);
return boxFromBoundingBox({
left: topLeft.x,
top: topLeft.y,
right: bottomRight.x,
bottom: bottomRight.y,
});
}

View file

@ -1,26 +0,0 @@
import { workerBridge } from "@/next/worker/worker-bridge";
import { FaceEmbedding } from "services/face/types";
export const mobileFaceNetFaceSize = 112;
/**
* Compute embeddings for the given {@link faceData}.
*
* The model used is MobileFaceNet, running in an ONNX runtime.
*/
export const faceEmbeddings = async (
faceData: Float32Array,
): Promise<Array<FaceEmbedding>> => {
const outputData = await workerBridge.faceEmbeddings(faceData);
const embeddingSize = 192;
const embeddings = new Array<FaceEmbedding>(
outputData.length / embeddingSize,
);
for (let i = 0; i < embeddings.length; i++) {
embeddings[i] = new Float32Array(
outputData.slice(i * embeddingSize, (i + 1) * embeddingSize),
);
}
return embeddings;
};

View file

@ -1,194 +1,742 @@
import { openCache } from "@/next/blob-cache";
import { FILE_TYPE } from "@/media/file-type";
import { blobCache } from "@/next/blob-cache";
import log from "@/next/log";
import { faceAlignment } from "services/face/align";
import mlIDbStorage from "services/face/db";
import { detectFaces, getRelativeDetection } from "services/face/detect";
import { faceEmbeddings, mobileFaceNetFaceSize } from "services/face/embed";
import { workerBridge } from "@/next/worker/worker-bridge";
import { euclidean } from "hdbscan";
import { Matrix } from "ml-matrix";
import {
DetectedFace,
Box,
Dimensions,
Point,
enlargeBox,
roundBox,
} from "services/face/geom";
import type {
Face,
MLSyncFileContext,
type FaceAlignment,
FaceAlignment,
FaceDetection,
MlFileData,
} from "services/face/types";
import { imageBitmapToBlob, warpAffineFloat32List } from "utils/image";
import { detectBlur } from "./blur";
import { getFaceCrop } from "./crop";
import { defaultMLVersion } from "services/machineLearning/machineLearningService";
import { getSimilarityTransformation } from "similarity-transformation";
import type { EnteFile } from "types/file";
import { fetchImageBitmap, getLocalFileImageBitmap } from "./file";
import {
fetchImageBitmap,
fetchImageBitmapForContext,
getFaceId,
getLocalFile,
clamp,
grayscaleIntMatrixFromNormalized2List,
pixelRGBBilinear,
warpAffineFloat32List,
} from "./image";
import { transformFaceDetections } from "./transform-box";
export const syncFileAnalyzeFaces = async (fileContext: MLSyncFileContext) => {
const { newMlFile } = fileContext;
/**
* Index faces in the given file.
*
* This function is the entry point to the indexing pipeline. The file goes
* through various stages:
*
* 1. Download the original if needed.
* 2. Detect faces using ONNX/YOLO.
* 3. Align the face rectangles, compute blur.
* 4. Compute embeddings for the detected face (crops).
*
* Once all of it is done, it returns the face rectangles and embeddings so that
* they can be saved locally for offline use, and encrypts and uploads them to
* the user's remote storage so that their other devices can download them
* instead of needing to reindex.
*/
export const indexFaces = async (enteFile: EnteFile, localFile?: File) => {
const startTime = Date.now();
await syncFileFaceDetections(fileContext);
if (newMlFile.faces && newMlFile.faces.length > 0) {
await syncFileFaceCrops(fileContext);
const alignedFacesData = await syncFileFaceAlignments(fileContext);
await syncFileFaceEmbeddings(fileContext, alignedFacesData);
await syncFileFaceMakeRelativeDetections(fileContext);
const imageBitmap = await fetchOrCreateImageBitmap(enteFile, localFile);
let mlFile: MlFileData;
try {
mlFile = await indexFaces_(enteFile, imageBitmap);
} finally {
imageBitmap.close();
}
log.debug(
() =>
`Face detection for file ${fileContext.enteFile.id} took ${Math.round(Date.now() - startTime)} ms`,
);
};
const syncFileFaceDetections = async (fileContext: MLSyncFileContext) => {
const { newMlFile } = fileContext;
newMlFile.faceDetectionMethod = {
value: "YoloFace",
version: 1,
};
fileContext.newDetection = true;
const imageBitmap = await fetchImageBitmapForContext(fileContext);
const faceDetections = await detectFaces(imageBitmap);
// TODO: reenable faces filtering based on width
const detectedFaces = faceDetections?.map((detection) => {
return {
fileId: fileContext.enteFile.id,
detection,
} as DetectedFace;
log.debug(() => {
const nf = mlFile.faces?.length ?? 0;
const ms = Date.now() - startTime;
return `Indexed ${nf} faces in file ${enteFile.id} (${ms} ms)`;
});
newMlFile.faces = detectedFaces?.map((detectedFace) => ({
...detectedFace,
id: getFaceId(detectedFace, newMlFile.imageDimensions),
return mlFile;
};
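
A minimal usage sketch (the wrapper and log line are illustrative; in the app this is driven by the ML sync machinery):

const indexAndLog = async (enteFile: EnteFile, localFile?: File) => {
    // Runs the stages listed above: fetch or reuse the image, ONNX/YOLO
    // detection, alignment and blur computation, MobileFaceNet embeddings.
    const mlFile = await indexFaces(enteFile, localFile);
    log.info(`Indexed ${mlFile.faces?.length ?? 0} faces in file ${enteFile.id}`);
    return mlFile;
};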
/**
* Return an {@link ImageBitmap}, using {@link localFile} if present, otherwise
* downloading the source image corresponding to {@link enteFile} from remote.
*/
const fetchOrCreateImageBitmap = async (
enteFile: EnteFile,
localFile: File,
) => {
const fileType = enteFile.metadata.fileType;
if (localFile) {
// TODO-ML(MR): Could also be image part of live photo?
if (fileType !== FILE_TYPE.IMAGE)
throw new Error("Local file of only image type is supported");
return await getLocalFileImageBitmap(enteFile, localFile);
} else if ([FILE_TYPE.IMAGE, FILE_TYPE.LIVE_PHOTO].includes(fileType)) {
return await fetchImageBitmap(enteFile);
} else {
throw new Error(`Cannot index unsupported file type ${fileType}`);
}
};
const indexFaces_ = async (enteFile: EnteFile, imageBitmap: ImageBitmap) => {
const fileID = enteFile.id;
const { width, height } = imageBitmap;
const imageDimensions = { width, height };
const mlFile: MlFileData = {
fileId: fileID,
mlVersion: defaultMLVersion,
imageDimensions,
errorCount: 0,
};
const faceDetections = await detectFaces(imageBitmap);
const detectedFaces = faceDetections.map((detection) => ({
id: makeFaceID(fileID, detection, imageDimensions),
fileId: fileID,
detection,
}));
// ?.filter((f) =>
// f.box.width > syncContext.config.faceDetection.minFaceSize
// );
log.info("[MLService] Detected Faces: ", newMlFile.faces?.length);
};
mlFile.faces = detectedFaces;
const syncFileFaceCrops = async (fileContext: MLSyncFileContext) => {
const { newMlFile } = fileContext;
const imageBitmap = await fetchImageBitmapForContext(fileContext);
newMlFile.faceCropMethod = {
value: "ArcFace",
version: 1,
};
if (detectedFaces.length > 0) {
const alignments: FaceAlignment[] = [];
for (const face of newMlFile.faces) {
await saveFaceCrop(imageBitmap, face);
}
};
for (const face of mlFile.faces) {
const alignment = faceAlignment(face.detection);
face.alignment = alignment;
alignments.push(alignment);
const syncFileFaceAlignments = async (
fileContext: MLSyncFileContext,
): Promise<Float32Array> => {
const { newMlFile } = fileContext;
newMlFile.faceAlignmentMethod = {
value: "ArcFace",
version: 1,
};
fileContext.newAlignment = true;
const imageBitmap =
fileContext.imageBitmap ||
(await fetchImageBitmapForContext(fileContext));
await saveFaceCrop(imageBitmap, face);
}
// Execute the face alignment calculations
for (const face of newMlFile.faces) {
face.alignment = faceAlignment(face.detection);
}
// Extract face images and convert to Float32Array
const faceAlignments = newMlFile.faces.map((f) => f.alignment);
const faceImages = await extractFaceImagesToFloat32(
faceAlignments,
mobileFaceNetFaceSize,
imageBitmap,
);
const blurValues = detectBlur(faceImages, newMlFile.faces);
newMlFile.faces.forEach((f, i) => (f.blurValue = blurValues[i]));
imageBitmap.close();
log.info("[MLService] alignedFaces: ", newMlFile.faces?.length);
return faceImages;
};
const syncFileFaceEmbeddings = async (
fileContext: MLSyncFileContext,
alignedFacesInput: Float32Array,
) => {
const { newMlFile } = fileContext;
newMlFile.faceEmbeddingMethod = {
value: "MobileFaceNet",
version: 2,
};
// TODO: when not storing face crops, image will be needed to extract faces
// fileContext.imageBitmap ||
// (await this.getImageBitmap(fileContext));
const embeddings = await faceEmbeddings(alignedFacesInput);
newMlFile.faces.forEach((f, i) => (f.embedding = embeddings[i]));
log.info("[MLService] facesWithEmbeddings: ", newMlFile.faces.length);
};
const syncFileFaceMakeRelativeDetections = async (
fileContext: MLSyncFileContext,
) => {
const { newMlFile } = fileContext;
for (let i = 0; i < newMlFile.faces.length; i++) {
const face = newMlFile.faces[i];
if (face.detection.box.x + face.detection.box.width < 2) continue; // Skip if somehow already relative
face.detection = getRelativeDetection(
face.detection,
newMlFile.imageDimensions,
const alignedFacesData = convertToMobileFaceNetInput(
imageBitmap,
alignments,
);
}
};
export const saveFaceCrop = async (imageBitmap: ImageBitmap, face: Face) => {
const faceCrop = getFaceCrop(imageBitmap, face.detection);
const blurValues = detectBlur(alignedFacesData, mlFile.faces);
mlFile.faces.forEach((f, i) => (f.blurValue = blurValues[i]));
const blob = await imageBitmapToBlob(faceCrop.image);
const embeddings = await computeEmbeddings(alignedFacesData);
mlFile.faces.forEach((f, i) => (f.embedding = embeddings[i]));
const cache = await openCache("face-crops");
await cache.put(face.id, blob);
faceCrop.image.close();
return blob;
};
export const regenerateFaceCrop = async (faceID: string) => {
const fileID = Number(faceID.split("-")[0]);
const personFace = await mlIDbStorage.getFace(fileID, faceID);
if (!personFace) {
throw Error("Face not found");
mlFile.faces.forEach((face) => {
face.detection = relativeDetection(face.detection, imageDimensions);
});
}
const file = await getLocalFile(personFace.fileId);
const imageBitmap = await fetchImageBitmap(file);
return await saveFaceCrop(imageBitmap, personFace);
return mlFile;
};
async function extractFaceImagesToFloat32(
faceAlignments: Array<FaceAlignment>,
/**
* Detect faces in the given {@link imageBitmap}.
*
* The model used is YOLO, running in an ONNX runtime.
*/
const detectFaces = async (
imageBitmap: ImageBitmap,
): Promise<FaceDetection[]> => {
const rect = ({ width, height }: Dimensions) =>
new Box({ x: 0, y: 0, width, height });
const { yoloInput, yoloSize } =
convertToYOLOInputFloat32ChannelsFirst(imageBitmap);
const yoloOutput = await workerBridge.detectFaces(yoloInput);
const faces = faceDetectionsFromYOLOOutput(yoloOutput);
const faceDetections = transformFaceDetections(
faces,
rect(yoloSize),
rect(imageBitmap),
);
const maxFaceDistancePercent = Math.sqrt(2) / 100;
const maxFaceDistance = imageBitmap.width * maxFaceDistancePercent;
return removeDuplicateDetections(faceDetections, maxFaceDistance);
};
/**
* Convert {@link imageBitmap} into the format that the YOLO face detection
* model expects.
*/
const convertToYOLOInputFloat32ChannelsFirst = (imageBitmap: ImageBitmap) => {
const requiredWidth = 640;
const requiredHeight = 640;
const { width, height } = imageBitmap;
// Create an OffscreenCanvas and set its size.
const offscreenCanvas = new OffscreenCanvas(width, height);
const ctx = offscreenCanvas.getContext("2d");
ctx.drawImage(imageBitmap, 0, 0, width, height);
const imageData = ctx.getImageData(0, 0, width, height);
const pixelData = imageData.data;
// Maintain aspect ratio.
const scale = Math.min(requiredWidth / width, requiredHeight / height);
const scaledWidth = clamp(Math.round(width * scale), 0, requiredWidth);
const scaledHeight = clamp(Math.round(height * scale), 0, requiredHeight);
const yoloInput = new Float32Array(1 * 3 * requiredWidth * requiredHeight);
const yoloSize = { width: scaledWidth, height: scaledHeight };
// Populate the Float32Array with normalized pixel values.
let pi = 0;
const channelOffsetGreen = requiredHeight * requiredWidth;
const channelOffsetBlue = 2 * requiredHeight * requiredWidth;
for (let h = 0; h < requiredHeight; h++) {
for (let w = 0; w < requiredWidth; w++) {
const { r, g, b } =
w >= scaledWidth || h >= scaledHeight
? { r: 114, g: 114, b: 114 }
: pixelRGBBilinear(
w / scale,
h / scale,
pixelData,
width,
height,
);
yoloInput[pi] = r / 255.0;
yoloInput[pi + channelOffsetGreen] = g / 255.0;
yoloInput[pi + channelOffsetBlue] = b / 255.0;
pi++;
}
}
return { yoloInput, yoloSize };
};
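// Illustrative usage sketch (hypothetical 1920x1080 input, assumes an
// OffscreenCanvas-capable context): the photo is scaled by
// min(640/1920, 640/1080) = 1/3 into a 640x360 region of the 640x640 tensor,
// and the remaining band is filled with the grey padding value 114 used above.
const exampleCanvas = new OffscreenCanvas(1920, 1080);
exampleCanvas.getContext("2d"); // the canvas needs a context before transfer
const exampleBitmap = exampleCanvas.transferToImageBitmap();
const { yoloSize: exampleYoloSize } =
    convertToYOLOInputFloat32ChannelsFirst(exampleBitmap);
// exampleYoloSize => { width: 640, height: 360 }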
/**
* Extract detected faces from the YOLO's output.
*
* Only detections that exceed a minimum score are returned.
*
* @param rows A Float32Array of shape [25200, 16], where each row
* represents a bounding box.
*/
const faceDetectionsFromYOLOOutput = (rows: Float32Array): FaceDetection[] => {
const faces: FaceDetection[] = [];
// Iterate over each row.
for (let i = 0; i < rows.length; i += 16) {
const score = rows[i + 4];
if (score < 0.7) continue;
const xCenter = rows[i];
const yCenter = rows[i + 1];
const width = rows[i + 2];
const height = rows[i + 3];
const xMin = xCenter - width / 2.0; // topLeft
const yMin = yCenter - height / 2.0; // topLeft
const leftEyeX = rows[i + 5];
const leftEyeY = rows[i + 6];
const rightEyeX = rows[i + 7];
const rightEyeY = rows[i + 8];
const noseX = rows[i + 9];
const noseY = rows[i + 10];
const leftMouthX = rows[i + 11];
const leftMouthY = rows[i + 12];
const rightMouthX = rows[i + 13];
const rightMouthY = rows[i + 14];
const box = new Box({
x: xMin,
y: yMin,
width: width,
height: height,
});
const probability = score as number;
const landmarks = [
new Point(leftEyeX, leftEyeY),
new Point(rightEyeX, rightEyeY),
new Point(noseX, noseY),
new Point(leftMouthX, leftMouthY),
new Point(rightMouthX, rightMouthY),
];
faces.push({ box, landmarks, probability });
}
return faces;
};
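// Illustrative example (hypothetical row values): each group of 16 numbers is
// one candidate box: center x/y, width, height, score, five (x, y) landmarks,
// and one trailing value that is not used here. A score below 0.7 would cause
// the row to be skipped.
const exampleYoloRow = new Float32Array([
    320, 240, 100, 120, 0.92, // center x, center y, width, height, score
    300, 220, 340, 220, 320, 250, 305, 280, 335, 280, // five landmarks
    0, // unused
]);
const [exampleDetection] = faceDetectionsFromYOLOOutput(exampleYoloRow);
// exampleDetection.box => { x: 270, y: 180, width: 100, height: 120 }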
/**
 * Remove duplicate face detections from an array of detections.
 *
 * This function sorts the detections by their probability in descending order
 * and then iterates over them. For each detection, it calculates the Euclidean
 * distance to all other detections; if that distance is less than or equal to
 * the specified threshold ({@link withinDistance}), the other detection is
 * considered a duplicate and is removed.
*
* @param detections - An array of face detections to remove duplicates from.
*
* @param withinDistance - The maximum Euclidean distance between two detections
* for them to be considered duplicates.
*
* @returns An array of face detections with duplicates removed.
*/
const removeDuplicateDetections = (
detections: FaceDetection[],
withinDistance: number,
) => {
detections.sort((a, b) => b.probability - a.probability);
const dupIndices = new Set<number>();
for (let i = 0; i < detections.length; i++) {
if (dupIndices.has(i)) continue;
for (let j = i + 1; j < detections.length; j++) {
if (dupIndices.has(j)) continue;
const centeri = faceDetectionCenter(detections[i]);
const centerj = faceDetectionCenter(detections[j]);
const dist = euclidean(
[centeri.x, centeri.y],
[centerj.x, centerj.y],
);
if (dist <= withinDistance) dupIndices.add(j);
}
}
return detections.filter((_, i) => !dupIndices.has(i));
};
const faceDetectionCenter = (detection: FaceDetection) => {
const center = new Point(0, 0);
    // TODO-ML(LAURENS): Averaging the first 4 landmarks is applicable to
    // BlazeFace only; this needs to consider the eyes, nose and mouth
    // landmarks to compute the center.
detection.landmarks?.slice(0, 4).forEach((p) => {
center.x += p.x;
center.y += p.y;
});
return new Point(center.x / 4, center.y / 4);
};
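// Illustrative sketch (hypothetical values): two overlapping detections of the
// same face whose centers are about 2 px apart collapse to the single higher
// scoring detection when the allowed distance is 10 px.
const exampleLandmarks = [
    new Point(110, 115),
    new Point(140, 115),
    new Point(125, 130),
    new Point(118, 142),
    new Point(132, 142),
];
const exampleDetectionA: FaceDetection = {
    box: new Box({ x: 100, y: 100, width: 50, height: 50 }),
    landmarks: exampleLandmarks,
    probability: 0.95,
};
const exampleDetectionB: FaceDetection = {
    box: new Box({ x: 102, y: 101, width: 50, height: 50 }),
    landmarks: exampleLandmarks.map((p) => new Point(p.x + 2, p.y + 1)),
    probability: 0.81,
};
// Only the higher scoring detection A remains.
const exampleDeduped = removeDuplicateDetections(
    [exampleDetectionA, exampleDetectionB],
    10,
);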
const makeFaceID = (
fileID: number,
detection: FaceDetection,
imageDims: Dimensions,
) => {
const part = (v: number) => clamp(v, 0.0, 0.999999).toFixed(5).substring(2);
const xMin = part(detection.box.x / imageDims.width);
const yMin = part(detection.box.y / imageDims.height);
const xMax = part(
(detection.box.x + detection.box.width) / imageDims.width,
);
const yMax = part(
(detection.box.y + detection.box.height) / imageDims.height,
);
return [`${fileID}`, xMin, yMin, xMax, yMax].join("_");
};
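// Illustrative example (hypothetical values): the face ID encodes the file ID
// followed by the relative bounding box, each coordinate written with 5
// decimal places and the leading "0." dropped. A detection spanning (250, 250)
// to (750, 750) in a 1000x1000 image of file 1234 yields
// "1234_25000_25000_75000_75000".
const exampleFaceID = makeFaceID(
    1234,
    {
        box: new Box({ x: 250, y: 250, width: 500, height: 500 }),
        landmarks: [],
        probability: 1,
    },
    { width: 1000, height: 1000 },
);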
/**
* Compute and return an {@link FaceAlignment} for the given face detection.
*
* @param faceDetection A geometry indicating a face detected in an image.
*/
const faceAlignment = (faceDetection: FaceDetection): FaceAlignment =>
faceAlignmentUsingSimilarityTransform(
faceDetection,
normalizeLandmarks(idealMobileFaceNetLandmarks, mobileFaceNetFaceSize),
);
/**
* The ideal location of the landmarks (eye etc) that the MobileFaceNet
* embedding model expects.
*/
const idealMobileFaceNetLandmarks: [number, number][] = [
[38.2946, 51.6963],
[73.5318, 51.5014],
[56.0252, 71.7366],
[41.5493, 92.3655],
[70.7299, 92.2041],
];
const normalizeLandmarks = (
landmarks: [number, number][],
    faceSize: number,
): [number, number][] =>
landmarks.map(([x, y]) => [x / faceSize, y / faceSize]);
const faceAlignmentUsingSimilarityTransform = (
faceDetection: FaceDetection,
alignedLandmarks: [number, number][],
): FaceAlignment => {
const landmarksMat = new Matrix(
faceDetection.landmarks
.map((p) => [p.x, p.y])
.slice(0, alignedLandmarks.length),
).transpose();
const alignedLandmarksMat = new Matrix(alignedLandmarks).transpose();
const simTransform = getSimilarityTransformation(
landmarksMat,
alignedLandmarksMat,
);
const RS = Matrix.mul(simTransform.rotation, simTransform.scale);
const TR = simTransform.translation;
const affineMatrix = [
[RS.get(0, 0), RS.get(0, 1), TR.get(0, 0)],
[RS.get(1, 0), RS.get(1, 1), TR.get(1, 0)],
[0, 0, 1],
];
const size = 1 / simTransform.scale;
const meanTranslation = simTransform.toMean.sub(0.5).mul(size);
const centerMat = simTransform.fromMean.sub(meanTranslation);
const center = new Point(centerMat.get(0, 0), centerMat.get(1, 0));
const rotation = -Math.atan2(
simTransform.rotation.get(0, 1),
simTransform.rotation.get(0, 0),
);
return { affineMatrix, center, size, rotation };
};
const convertToMobileFaceNetInput = (
imageBitmap: ImageBitmap,
faceAlignments: FaceAlignment[],
): Float32Array => {
const faceSize = mobileFaceNetFaceSize;
const faceData = new Float32Array(
faceAlignments.length * faceSize * faceSize * 3,
);
    for (let i = 0; i < faceAlignments.length; i++) {
        const { affineMatrix } = faceAlignments[i];
        const faceDataOffset = i * faceSize * faceSize * 3;
        warpAffineFloat32List(
            imageBitmap,
            affineMatrix,
            faceSize,
            faceData,
            faceDataOffset,
        );
    }
    return faceData;
};
/**
* Laplacian blur detection.
*
* Return an array of detected blur values, one for each face in {@link faces}.
* The face data is taken from the slice of {@link alignedFacesData}
* corresponding to each face of {@link faces}.
*/
const detectBlur = (alignedFacesData: Float32Array, faces: Face[]): number[] =>
faces.map((face, i) => {
const faceImage = grayscaleIntMatrixFromNormalized2List(
alignedFacesData,
i,
mobileFaceNetFaceSize,
mobileFaceNetFaceSize,
);
return matrixVariance(applyLaplacian(faceImage, faceDirection(face)));
});
type FaceDirection = "left" | "right" | "straight";
const faceDirection = (face: Face): FaceDirection => {
const landmarks = face.detection.landmarks;
const leftEye = landmarks[0];
const rightEye = landmarks[1];
const nose = landmarks[2];
const leftMouth = landmarks[3];
const rightMouth = landmarks[4];
const eyeDistanceX = Math.abs(rightEye.x - leftEye.x);
const eyeDistanceY = Math.abs(rightEye.y - leftEye.y);
const mouthDistanceY = Math.abs(rightMouth.y - leftMouth.y);
const faceIsUpright =
Math.max(leftEye.y, rightEye.y) + 0.5 * eyeDistanceY < nose.y &&
nose.y + 0.5 * mouthDistanceY < Math.min(leftMouth.y, rightMouth.y);
const noseStickingOutLeft =
nose.x < Math.min(leftEye.x, rightEye.x) &&
nose.x < Math.min(leftMouth.x, rightMouth.x);
const noseStickingOutRight =
nose.x > Math.max(leftEye.x, rightEye.x) &&
nose.x > Math.max(leftMouth.x, rightMouth.x);
const noseCloseToLeftEye =
Math.abs(nose.x - leftEye.x) < 0.2 * eyeDistanceX;
const noseCloseToRightEye =
Math.abs(nose.x - rightEye.x) < 0.2 * eyeDistanceX;
if (noseStickingOutLeft || (faceIsUpright && noseCloseToLeftEye)) {
return "left";
} else if (noseStickingOutRight || (faceIsUpright && noseCloseToRightEye)) {
return "right";
}
return "straight";
};
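// Illustrative example (hypothetical landmarks): a nose that sits to the left
// of both eyes and both mouth corners is classified as a left-facing profile.
const exampleProfileFace: Face = {
    fileId: 1,
    id: "1_example",
    detection: {
        box: new Box({ x: 80, y: 80, width: 80, height: 100 }),
        landmarks: [
            new Point(100, 100), // left eye
            new Point(140, 100), // right eye
            new Point(90, 120), // nose, left of both eyes and mouth corners
            new Point(105, 150), // left mouth corner
            new Point(135, 150), // right mouth corner
        ],
        probability: 0.9,
    },
};
const exampleDirection = faceDirection(exampleProfileFace); // => "left"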
/**
 * Return a new image by applying a Laplacian (edge detection) kernel to each pixel.
*/
const applyLaplacian = (
image: number[][],
direction: FaceDirection,
): number[][] => {
const paddedImage = padImage(image, direction);
const numRows = paddedImage.length - 2;
const numCols = paddedImage[0].length - 2;
// Create an output image initialized to 0.
const outputImage: number[][] = Array.from({ length: numRows }, () =>
new Array(numCols).fill(0),
);
// Define the Laplacian kernel.
const kernel = [
[0, 1, 0],
[1, -4, 1],
[0, 1, 0],
];
// Apply the kernel to each pixel
for (let i = 0; i < numRows; i++) {
for (let j = 0; j < numCols; j++) {
let sum = 0;
for (let ki = 0; ki < 3; ki++) {
for (let kj = 0; kj < 3; kj++) {
sum += paddedImage[i + ki][j + kj] * kernel[ki][kj];
}
}
// Adjust the output value if necessary (e.g., clipping).
outputImage[i][j] = sum;
}
}
return outputImage;
};
const padImage = (image: number[][], direction: FaceDirection): number[][] => {
const removeSideColumns = 56; /* must be even */
const numRows = image.length;
const numCols = image[0].length;
const paddedNumCols = numCols + 2 - removeSideColumns;
const paddedNumRows = numRows + 2;
// Create a new matrix with extra padding.
const paddedImage: number[][] = Array.from({ length: paddedNumRows }, () =>
new Array(paddedNumCols).fill(0),
);
if (direction === "straight") {
// Copy original image into the center of the padded image.
for (let i = 0; i < numRows; i++) {
for (let j = 0; j < paddedNumCols - 2; j++) {
paddedImage[i + 1][j + 1] =
image[i][j + Math.round(removeSideColumns / 2)];
}
}
} else if (direction === "left") {
// If the face is facing left, we only take the right side of the face
// image.
for (let i = 0; i < numRows; i++) {
for (let j = 0; j < paddedNumCols - 2; j++) {
paddedImage[i + 1][j + 1] = image[i][j + removeSideColumns];
}
}
} else if (direction === "right") {
// If the face is facing right, we only take the left side of the face
// image.
for (let i = 0; i < numRows; i++) {
for (let j = 0; j < paddedNumCols - 2; j++) {
paddedImage[i + 1][j + 1] = image[i][j];
}
}
}
// Reflect padding
// - Top and bottom rows
for (let j = 1; j <= paddedNumCols - 2; j++) {
// Top row
paddedImage[0][j] = paddedImage[2][j];
// Bottom row
paddedImage[numRows + 1][j] = paddedImage[numRows - 1][j];
}
// - Left and right columns
for (let i = 0; i < numRows + 2; i++) {
// Left column
paddedImage[i][0] = paddedImage[i][2];
// Right column
paddedImage[i][paddedNumCols - 1] = paddedImage[i][paddedNumCols - 3];
}
return paddedImage;
};
const matrixVariance = (matrix: number[][]): number => {
const numRows = matrix.length;
const numCols = matrix[0].length;
const totalElements = numRows * numCols;
// Calculate the mean.
let mean: number = 0;
matrix.forEach((row) => {
row.forEach((value) => {
mean += value;
});
});
mean /= totalElements;
// Calculate the variance.
let variance: number = 0;
matrix.forEach((row) => {
row.forEach((value) => {
const diff: number = value - mean;
variance += diff * diff;
});
});
variance /= totalElements;
return variance;
};
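// Illustrative sketch (hypothetical 3x3 responses): matrixVariance is what
// turns the Laplacian response into a single sharpness score: a flat response
// (no edges, likely blurred) scores 0, while a spiky response (strong edges,
// likely sharp) scores orders of magnitude higher.
const exampleFlatScore = matrixVariance([
    [0, 0, 0],
    [0, 0, 0],
    [0, 0, 0],
]); // => 0
const exampleSpikyScore = matrixVariance([
    [0, 0, 0],
    [0, 255, 0],
    [0, 0, 0],
]); // => ~6422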
const mobileFaceNetFaceSize = 112;
const mobileFaceNetEmbeddingSize = 192;
/**
* Compute embeddings for the given {@link faceData}.
*
* The model used is MobileFaceNet, running in an ONNX runtime.
*/
const computeEmbeddings = async (
faceData: Float32Array,
): Promise<Float32Array[]> => {
const outputData = await workerBridge.computeFaceEmbeddings(faceData);
const embeddingSize = mobileFaceNetEmbeddingSize;
const embeddings = new Array<Float32Array>(
outputData.length / embeddingSize,
);
for (let i = 0; i < embeddings.length; i++) {
embeddings[i] = new Float32Array(
outputData.slice(i * embeddingSize, (i + 1) * embeddingSize),
);
}
return embeddings;
};
/**
* Convert the coordinates to between 0-1, normalized by the image's dimensions.
*/
const relativeDetection = (
faceDetection: FaceDetection,
{ width, height }: Dimensions,
): FaceDetection => {
const oldBox: Box = faceDetection.box;
const box = new Box({
x: oldBox.x / width,
y: oldBox.y / height,
width: oldBox.width / width,
height: oldBox.height / height,
});
const landmarks = faceDetection.landmarks.map((l) => {
return new Point(l.x / width, l.y / height);
});
const probability = faceDetection.probability;
return { box, landmarks, probability };
};
export const saveFaceCrop = async (imageBitmap: ImageBitmap, face: Face) => {
const faceCrop = extractFaceCrop(imageBitmap, face.alignment);
const blob = await imageBitmapToBlob(faceCrop);
faceCrop.close();
const cache = await blobCache("face-crops");
await cache.put(face.id, blob);
return blob;
};
const imageBitmapToBlob = (imageBitmap: ImageBitmap) => {
const canvas = new OffscreenCanvas(imageBitmap.width, imageBitmap.height);
canvas.getContext("2d").drawImage(imageBitmap, 0, 0);
return canvas.convertToBlob({ type: "image/jpeg", quality: 0.8 });
};
const extractFaceCrop = (
imageBitmap: ImageBitmap,
alignment: FaceAlignment,
): ImageBitmap => {
const alignmentBox = new Box({
x: alignment.center.x - alignment.size / 2,
y: alignment.center.y - alignment.size / 2,
width: alignment.size,
height: alignment.size,
});
const padding = 0.25;
const scaleForPadding = 1 + padding * 2;
const paddedBox = roundBox(enlargeBox(alignmentBox, scaleForPadding));
// TODO-ML(LAURENS): The rotation doesn't seem to be used? it's set to 0.
return cropWithRotation(imageBitmap, paddedBox, 0, 256);
};
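// Illustrative example (hypothetical values): an alignment of size 200
// centered at (300, 300) gives a 200x200 alignment box, and the 25% padding on
// each side enlarges it to a 300x300 crop region.
const examplePaddedBox = roundBox(
    enlargeBox(new Box({ x: 200, y: 200, width: 200, height: 200 }), 1.5),
); // => { x: 150, y: 150, width: 300, height: 300 }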
const cropWithRotation = (
imageBitmap: ImageBitmap,
cropBox: Box,
rotation: number,
maxDimension: number,
) => {
const box = roundBox(cropBox);
const outputSize = { width: box.width, height: box.height };
const scale = Math.min(maxDimension / box.width, maxDimension / box.height);
if (scale < 1) {
outputSize.width = Math.round(scale * box.width);
outputSize.height = Math.round(scale * box.height);
}
const offscreen = new OffscreenCanvas(outputSize.width, outputSize.height);
const offscreenCtx = offscreen.getContext("2d");
offscreenCtx.imageSmoothingQuality = "high";
offscreenCtx.translate(outputSize.width / 2, outputSize.height / 2);
rotation && offscreenCtx.rotate(rotation);
const outputBox = new Box({
x: -outputSize.width / 2,
y: -outputSize.height / 2,
width: outputSize.width,
height: outputSize.height,
});
const enlargedBox = enlargeBox(box, 1.5);
const enlargedOutputBox = enlargeBox(outputBox, 1.5);
offscreenCtx.drawImage(
imageBitmap,
enlargedBox.x,
enlargedBox.y,
enlargedBox.width,
enlargedBox.height,
enlargedOutputBox.x,
enlargedOutputBox.y,
enlargedOutputBox.width,
enlargedOutputBox.height,
);
return offscreen.transferToImageBitmap();
};

View file

@ -12,20 +12,16 @@ export class DedicatedMLWorker {
public async syncLocalFile(
token: string,
userID: number,
userAgent: string,
enteFile: EnteFile,
localFile: globalThis.File,
) {
mlService.syncLocalFile(token, userID, enteFile, localFile);
mlService.syncLocalFile(token, userID, userAgent, enteFile, localFile);
}
public async sync(token: string, userID: number) {
public async sync(token: string, userID: number, userAgent: string) {
await downloadManager.init(APPS.PHOTOS, { token });
return mlService.sync(token, userID);
}
public async regenerateFaceCrop(token: string, faceID: string) {
await downloadManager.init(APPS.PHOTOS, { token });
return mlService.regenerateFaceCrop(faceID);
return mlService.sync(token, userID, userAgent);
}
}

View file

@ -0,0 +1,37 @@
import { FILE_TYPE } from "@/media/file-type";
import { decodeLivePhoto } from "@/media/live-photo";
import DownloadManager from "services/download";
import { getLocalFiles } from "services/fileService";
import { EnteFile } from "types/file";
import { getRenderableImage } from "utils/file";
export async function getLocalFile(fileId: number) {
const localFiles = await getLocalFiles();
return localFiles.find((f) => f.id === fileId);
}
export const fetchImageBitmap = async (file: EnteFile) =>
fetchRenderableBlob(file).then(createImageBitmap);
async function fetchRenderableBlob(file: EnteFile) {
const fileStream = await DownloadManager.getFile(file);
const fileBlob = await new Response(fileStream).blob();
if (file.metadata.fileType === FILE_TYPE.IMAGE) {
return await getRenderableImage(file.metadata.title, fileBlob);
} else {
const { imageFileName, imageData } = await decodeLivePhoto(
file.metadata.title,
fileBlob,
);
return await getRenderableImage(imageFileName, new Blob([imageData]));
}
}
export async function getLocalFileImageBitmap(
enteFile: EnteFile,
localFile: globalThis.File,
) {
let fileBlob = localFile as Blob;
fileBlob = await getRenderableImage(enteFile.metadata.title, fileBlob);
return createImageBitmap(fileBlob);
}

View file

@ -13,13 +13,6 @@ export interface Dimensions {
height: number;
}
export interface IBoundingBox {
left: number;
top: number;
right: number;
bottom: number;
}
export interface IRect {
x: number;
y: number;
@ -27,24 +20,6 @@ export interface IRect {
height: number;
}
export function newBox(x: number, y: number, width: number, height: number) {
return new Box({ x, y, width, height });
}
export const boxFromBoundingBox = ({
left,
top,
right,
bottom,
}: IBoundingBox) => {
return new Box({
x: left,
y: top,
width: right - left,
height: bottom - top,
});
};
export class Box implements IRect {
public x: number;
public y: number;
@ -57,36 +32,26 @@ export class Box implements IRect {
this.width = width;
this.height = height;
}
public get topLeft(): Point {
return new Point(this.x, this.y);
}
public get bottomRight(): Point {
return new Point(this.x + this.width, this.y + this.height);
}
public round(): Box {
const [x, y, width, height] = [
this.x,
this.y,
this.width,
this.height,
].map((val) => Math.round(val));
return new Box({ x, y, width, height });
}
}
export function enlargeBox(box: Box, factor: number = 1.5) {
/** Round all the components of the box. */
export const roundBox = (box: Box): Box => {
const [x, y, width, height] = [box.x, box.y, box.width, box.height].map(
(val) => Math.round(val),
);
return new Box({ x, y, width, height });
};
/** Increase the size of the given {@link box} by {@link factor}. */
export const enlargeBox = (box: Box, factor: number) => {
const center = new Point(box.x + box.width / 2, box.y + box.height / 2);
const newWidth = factor * box.width;
const newHeight = factor * box.height;
const size = new Point(box.width, box.height);
const newHalfSize = new Point((factor * size.x) / 2, (factor * size.y) / 2);
return boxFromBoundingBox({
left: center.x - newHalfSize.x,
top: center.y - newHalfSize.y,
right: center.x + newHalfSize.x,
bottom: center.y + newHalfSize.y,
return new Box({
x: center.x - newWidth / 2,
y: center.y - newHeight / 2,
width: newWidth,
height: newHeight,
});
}
};

View file

@ -1,121 +1,295 @@
import { Matrix, inverse } from "ml-matrix";
/**
* Clamp {@link value} to between {@link min} and {@link max}, inclusive.
*/
export const clamp = (value: number, min: number, max: number) =>
Math.min(max, Math.max(min, value));
/**
* Returns the pixel value (RGB) at the given coordinates ({@link fx},
* {@link fy}) using bilinear interpolation.
*/
export function pixelRGBBilinear(
fx: number,
fy: number,
imageData: Uint8ClampedArray,
imageWidth: number,
imageHeight: number,
) {
// Clamp to image boundaries.
fx = clamp(fx, 0, imageWidth - 1);
fy = clamp(fy, 0, imageHeight - 1);
// Get the surrounding coordinates and their weights.
const x0 = Math.floor(fx);
const x1 = Math.ceil(fx);
const y0 = Math.floor(fy);
const y1 = Math.ceil(fy);
const dx = fx - x0;
const dy = fy - y0;
const dx1 = 1.0 - dx;
const dy1 = 1.0 - dy;
// Get the original pixels.
const pixel1 = pixelRGBA(imageData, imageWidth, imageHeight, x0, y0);
const pixel2 = pixelRGBA(imageData, imageWidth, imageHeight, x1, y0);
const pixel3 = pixelRGBA(imageData, imageWidth, imageHeight, x0, y1);
const pixel4 = pixelRGBA(imageData, imageWidth, imageHeight, x1, y1);
const bilinear = (val1: number, val2: number, val3: number, val4: number) =>
Math.round(
val1 * dx1 * dy1 +
val2 * dx * dy1 +
val3 * dx1 * dy +
val4 * dx * dy,
);
// Return interpolated pixel colors.
return {
r: bilinear(pixel1.r, pixel2.r, pixel3.r, pixel4.r),
g: bilinear(pixel1.g, pixel2.g, pixel3.g, pixel4.g),
b: bilinear(pixel1.b, pixel2.b, pixel3.b, pixel4.b),
};
}
const pixelRGBA = (
imageData: Uint8ClampedArray,
width: number,
height: number,
x: number,
y: number,
) => {
    if (x < 0 || x >= width || y < 0 || y >= height) {
        return { r: 0, g: 0, b: 0, a: 0 };
    }
const index = (y * width + x) * 4;
return {
r: imageData[index],
g: imageData[index + 1],
b: imageData[index + 2],
a: imageData[index + 3],
};
};
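// Illustrative example (hypothetical 2x2 RGBA image): sampling at the exact
// midpoint of a black/white checkerboard averages all four pixels and yields
// mid grey.
const exampleCheckerboard = new Uint8ClampedArray([
    0, 0, 0, 255, /* (0, 0) black */ 255, 255, 255, 255, /* (1, 0) white */
    255, 255, 255, 255, /* (0, 1) white */ 0, 0, 0, 255, /* (1, 1) black */
]);
const exampleMidpoint = pixelRGBBilinear(0.5, 0.5, exampleCheckerboard, 2, 2);
// => { r: 128, g: 128, b: 128 }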
/**
* Returns the pixel value (RGB) at the given coordinates ({@link fx},
* {@link fy}) using bicubic interpolation.
*/
const pixelRGBBicubic = (
fx: number,
fy: number,
imageData: Uint8ClampedArray,
imageWidth: number,
imageHeight: number,
) => {
// Clamp to image boundaries.
fx = clamp(fx, 0, imageWidth - 1);
    fy = clamp(fy, 0, imageHeight - 1);
    const x = Math.trunc(fx) - (fx >= 0.0 ? 0 : 1);
const px = x - 1;
const nx = x + 1;
const ax = x + 2;
const y = Math.trunc(fy) - (fy >= 0.0 ? 0 : 1);
const py = y - 1;
const ny = y + 1;
const ay = y + 2;
const dx = fx - x;
const dy = fy - y;
    const cubic = (
dx: number,
ipp: number,
icp: number,
inp: number,
iap: number,
) =>
icp +
0.5 *
(dx * (-ipp + inp) +
dx * dx * (2 * ipp - 5 * icp + 4 * inp - iap) +
dx * dx * dx * (-ipp + 3 * icp - 3 * inp + iap));
    const icc = pixelRGBA(imageData, imageWidth, imageHeight, x, y);
    const ipp =
px < 0 || py < 0
? icc
: pixelRGBA(imageData, imageWidth, imageHeight, px, py);
const icp =
px < 0 ? icc : pixelRGBA(imageData, imageWidth, imageHeight, x, py);
const inp =
py < 0 || nx >= imageWidth
? icc
: pixelRGBA(imageData, imageWidth, imageHeight, nx, py);
const iap =
ax >= imageWidth || py < 0
? icc
: pixelRGBA(imageData, imageWidth, imageHeight, ax, py);
    const ip0 = cubic(dx, ipp.r, icp.r, inp.r, iap.r);
const ip1 = cubic(dx, ipp.g, icp.g, inp.g, iap.g);
const ip2 = cubic(dx, ipp.b, icp.b, inp.b, iap.b);
// const ip3 = cubic(dx, ipp.a, icp.a, inp.a, iap.a);
const ipc =
px < 0 ? icc : pixelRGBA(imageData, imageWidth, imageHeight, px, y);
const inc =
nx >= imageWidth
? icc
: pixelRGBA(imageData, imageWidth, imageHeight, nx, y);
const iac =
ax >= imageWidth
? icc
: pixelRGBA(imageData, imageWidth, imageHeight, ax, y);
const ic0 = cubic(dx, ipc.r, icc.r, inc.r, iac.r);
const ic1 = cubic(dx, ipc.g, icc.g, inc.g, iac.g);
const ic2 = cubic(dx, ipc.b, icc.b, inc.b, iac.b);
// const ic3 = cubic(dx, ipc.a, icc.a, inc.a, iac.a);
const ipn =
px < 0 || ny >= imageHeight
? icc
: pixelRGBA(imageData, imageWidth, imageHeight, px, ny);
const icn =
ny >= imageHeight
? icc
: pixelRGBA(imageData, imageWidth, imageHeight, x, ny);
const inn =
nx >= imageWidth || ny >= imageHeight
? icc
: pixelRGBA(imageData, imageWidth, imageHeight, nx, ny);
const ian =
ax >= imageWidth || ny >= imageHeight
? icc
: pixelRGBA(imageData, imageWidth, imageHeight, ax, ny);
const in0 = cubic(dx, ipn.r, icn.r, inn.r, ian.r);
const in1 = cubic(dx, ipn.g, icn.g, inn.g, ian.g);
const in2 = cubic(dx, ipn.b, icn.b, inn.b, ian.b);
// const in3 = cubic(dx, ipn.a, icn.a, inn.a, ian.a);
const ipa =
px < 0 || ay >= imageHeight
? icc
: pixelRGBA(imageData, imageWidth, imageHeight, px, ay);
const ica =
ay >= imageHeight
? icc
: pixelRGBA(imageData, imageWidth, imageHeight, x, ay);
const ina =
nx >= imageWidth || ay >= imageHeight
? icc
: pixelRGBA(imageData, imageWidth, imageHeight, nx, ay);
const iaa =
ax >= imageWidth || ay >= imageHeight
? icc
: pixelRGBA(imageData, imageWidth, imageHeight, ax, ay);
const ia0 = cubic(dx, ipa.r, ica.r, ina.r, iaa.r);
const ia1 = cubic(dx, ipa.g, ica.g, ina.g, iaa.g);
const ia2 = cubic(dx, ipa.b, ica.b, ina.b, iaa.b);
// const ia3 = cubic(dx, ipa.a, ica.a, ina.a, iaa.a);
const c0 = Math.trunc(clamp(cubic(dy, ip0, ic0, in0, ia0), 0, 255));
const c1 = Math.trunc(clamp(cubic(dy, ip1, ic1, in1, ia1), 0, 255));
const c2 = Math.trunc(clamp(cubic(dy, ip2, ic2, in2, ia2), 0, 255));
// const c3 = cubic(dy, ip3, ic3, in3, ia3);
return { r: c0, g: c1, b: c2 };
};
/**
 * Use {@link faceAlignmentAffineMatrix} to extract a {@link faceSize} x
 * {@link faceSize} face from {@link imageBitmap}, writing its normalized
 * pixels into {@link inputData} starting at {@link inputStartIndex}.
*/
export const warpAffineFloat32List = (
imageBitmap: ImageBitmap,
faceAlignmentAffineMatrix: number[][],
faceSize: number,
inputData: Float32Array,
inputStartIndex: number,
): void => {
const { width, height } = imageBitmap;
// Get the pixel data.
const offscreenCanvas = new OffscreenCanvas(width, height);
const ctx = offscreenCanvas.getContext("2d");
ctx.drawImage(imageBitmap, 0, 0, width, height);
const imageData = ctx.getImageData(0, 0, width, height);
const pixelData = imageData.data;
const transformationMatrix = faceAlignmentAffineMatrix.map((row) =>
row.map((val) => (val != 1.0 ? val * faceSize : 1.0)),
); // 3x3
const A: Matrix = new Matrix([
[transformationMatrix[0][0], transformationMatrix[0][1]],
[transformationMatrix[1][0], transformationMatrix[1][1]],
]);
const Ainverse = inverse(A);
const b00 = transformationMatrix[0][2];
const b10 = transformationMatrix[1][2];
const a00Prime = Ainverse.get(0, 0);
const a01Prime = Ainverse.get(0, 1);
const a10Prime = Ainverse.get(1, 0);
const a11Prime = Ainverse.get(1, 1);
for (let yTrans = 0; yTrans < faceSize; ++yTrans) {
for (let xTrans = 0; xTrans < faceSize; ++xTrans) {
// Perform inverse affine transformation.
const xOrigin =
a00Prime * (xTrans - b00) + a01Prime * (yTrans - b10);
const yOrigin =
a10Prime * (xTrans - b00) + a11Prime * (yTrans - b10);
// Get the pixel RGB using bicubic interpolation.
const { r, g, b } = pixelRGBBicubic(
xOrigin,
yOrigin,
pixelData,
width,
height,
);
// Set the pixel in the input data.
const index = (yTrans * faceSize + xTrans) * 3;
inputData[inputStartIndex + index] = rgbToBipolarFloat(r);
inputData[inputStartIndex + index + 1] = rgbToBipolarFloat(g);
inputData[inputStartIndex + index + 2] = rgbToBipolarFloat(b);
}
}
};
/** Convert a RGB component 0-255 to a floating point value between -1 and 1. */
const rgbToBipolarFloat = (pixelValue: number) => pixelValue / 127.5 - 1.0;
/** Convert a floating point value between -1 and 1 to a RGB component 0-255. */
const bipolarFloatToRGB = (pixelValue: number) =>
    clamp(Math.round((pixelValue + 1.0) * 127.5), 0, 255);
export const grayscaleIntMatrixFromNormalized2List = (
imageList: Float32Array,
faceNumber: number,
width: number,
height: number,
): number[][] => {
const startIndex = faceNumber * width * height * 3;
return Array.from({ length: height }, (_, y) =>
Array.from({ length: width }, (_, x) => {
// 0.299 ∙ Red + 0.587 ∙ Green + 0.114 ∙ Blue
const pixelIndex = startIndex + 3 * (y * width + x);
return clamp(
Math.round(
0.299 * bipolarFloatToRGB(imageList[pixelIndex]) +
0.587 * bipolarFloatToRGB(imageList[pixelIndex + 1]) +
0.114 * bipolarFloatToRGB(imageList[pixelIndex + 2]),
),
0,
255,
);
}),
);
};
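// Illustrative example (hypothetical input): a face whose normalized pixels
// are all pure red, i.e. (r, g, b) = (1, -1, -1), maps back to RGB (255, 0, 0)
// and therefore a grayscale value of Math.round(0.299 * 255) = 76 in every
// cell (112 x 112 is the MobileFaceNet face size used by the blur detection).
const examplePureRedFace = new Float32Array(112 * 112 * 3).map((_, i) =>
    i % 3 === 0 ? 1 : -1,
);
const exampleGray = grayscaleIntMatrixFromNormalized2List(
    examplePureRedFace,
    0,
    112,
    112,
); // every entry is 76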

View file

@ -1,37 +1,54 @@
import log from "@/next/log";
import mlIDbStorage from "services/face/db";
import { Face, Person } from "services/face/types";
import { type MLSyncContext } from "services/machineLearning/machineLearningService";
import { clusterFaces } from "./cluster";
import { saveFaceCrop } from "./f-index";
import { fetchImageBitmap, getLocalFile } from "./image";
export interface Person {
id: number;
name?: string;
files: Array<number>;
displayFaceId?: string;
}
// TODO-ML(MR): Forced disable clustering. It doesn't currently work,
// need to finalize it before we move out of beta.
//
// > Error: Failed to execute 'transferToImageBitmap' on
// > 'OffscreenCanvas': ImageBitmap construction failed
/*
export const syncPeopleIndex = async () => {
if (
syncContext.outOfSyncFiles.length <= 0 ||
(syncContext.nSyncedFiles === batchSize && Math.random() < 0)
) {
await this.syncIndex(syncContext);
}
public async syncIndex(syncContext: MLSyncContext) {
await this.getMLLibraryData(syncContext);
await syncPeopleIndex(syncContext);
await this.persistMLLibraryData(syncContext);
}
export const syncPeopleIndex = async (syncContext: MLSyncContext) => {
const filesVersion = await mlIDbStorage.getIndexVersion("files");
if (filesVersion <= (await mlIDbStorage.getIndexVersion("people"))) {
return;
}
    // TODO: Make faces addressable through fileId + faceId to avoid
    // index-based addressing, which is prone to wrong results. One way could
    // be to match the nearest face within a threshold in the file.
const allFacesMap =
syncContext.allSyncedFacesMap ??
(syncContext.allSyncedFacesMap = await mlIDbStorage.getAllFacesMap());
const allFaces = [...allFacesMap.values()].flat();
await runFaceClustering(syncContext, allFaces);
await syncPeopleFromClusters(syncContext, allFacesMap, allFaces);
await mlIDbStorage.setIndexVersion("people", filesVersion);
};
const runFaceClustering = async (
syncContext: MLSyncContext,
allFaces: Array<Face>,
) => {
// await this.init();
const allFacesMap = await mlIDbStorage.getAllFacesMap();
const allFaces = [...allFacesMap.values()].flat();
if (!allFaces || allFaces.length < 50) {
log.info(
`Skipping clustering since number of faces (${allFaces.length}) is less than the clustering threshold (50)`,
@ -40,34 +57,15 @@ const runFaceClustering = async (
}
log.info("Running clustering allFaces: ", allFaces.length);
syncContext.mlLibraryData.faceClusteringResults = await clusterFaces(
const faceClusteringResults = await clusterFaces(
allFaces.map((f) => Array.from(f.embedding)),
);
syncContext.mlLibraryData.faceClusteringMethod = {
value: "Hdbscan",
version: 1,
};
log.info(
"[MLService] Got face clustering results: ",
JSON.stringify(syncContext.mlLibraryData.faceClusteringResults),
JSON.stringify(faceClusteringResults),
);
// syncContext.faceClustersWithNoise = {
// clusters: syncContext.faceClusteringResults.clusters.map(
// (faces) => ({
// faces,
// })
// ),
// noise: syncContext.faceClusteringResults.noise,
// };
};
const syncPeopleFromClusters = async (
syncContext: MLSyncContext,
allFacesMap: Map<number, Array<Face>>,
allFaces: Array<Face>,
) => {
const clusters = syncContext.mlLibraryData.faceClusteringResults?.clusters;
const clusters = faceClusteringResults?.clusters;
if (!clusters || clusters.length < 1) {
return;
}
@ -86,17 +84,18 @@ const syncPeopleFromClusters = async (
: best,
);
if (personFace && !personFace.crop?.cacheKey) {
const file = await getLocalFile(personFace.fileId);
const imageBitmap = await fetchImageBitmap(file);
await saveFaceCrop(imageBitmap, personFace);
}
const person: Person = {
id: index,
files: faces.map((f) => f.fileId),
displayFaceId: personFace?.id,
faceCropCacheKey: personFace?.crop?.cacheKey,
};
await mlIDbStorage.putPerson(person);
@ -108,4 +107,24 @@ const syncPeopleFromClusters = async (
}
await mlIDbStorage.updateFaces(allFacesMap);
// await mlIDbStorage.setIndexVersion("people", filesVersion);
};
public async regenerateFaceCrop(token: string, faceID: string) {
await downloadManager.init(APPS.PHOTOS, { token });
return mlService.regenerateFaceCrop(faceID);
}
export const regenerateFaceCrop = async (faceID: string) => {
const fileID = Number(faceID.split("-")[0]);
const personFace = await mlIDbStorage.getFace(fileID, faceID);
if (!personFace) {
throw Error("Face not found");
}
const file = await getLocalFile(personFace.fileId);
const imageBitmap = await fetchImageBitmap(file);
return await saveFaceCrop(imageBitmap, personFace);
};
*/

View file

@ -0,0 +1,158 @@
import log from "@/next/log";
import ComlinkCryptoWorker from "@ente/shared/crypto";
import { putEmbedding } from "services/embeddingService";
import type { EnteFile } from "types/file";
import type { Point } from "./geom";
import type { Face, FaceDetection, MlFileData } from "./types";
export const putFaceEmbedding = async (
enteFile: EnteFile,
mlFileData: MlFileData,
userAgent: string,
) => {
const serverMl = LocalFileMlDataToServerFileMl(mlFileData, userAgent);
log.debug(() => ({ t: "Local ML file data", mlFileData }));
log.debug(() => ({
t: "Uploaded ML file data",
d: JSON.stringify(serverMl),
}));
const comlinkCryptoWorker = await ComlinkCryptoWorker.getInstance();
const { file: encryptedEmbeddingData } =
await comlinkCryptoWorker.encryptMetadata(serverMl, enteFile.key);
log.info(
`putEmbedding embedding to server for file: ${enteFile.metadata.title} fileID: ${enteFile.id}`,
);
const res = await putEmbedding({
fileID: enteFile.id,
encryptedEmbedding: encryptedEmbeddingData.encryptedData,
decryptionHeader: encryptedEmbeddingData.decryptionHeader,
model: "file-ml-clip-face",
});
log.info("putEmbedding response: ", res);
};
export interface FileML extends ServerFileMl {
updatedAt: number;
}
class ServerFileMl {
public fileID: number;
public height?: number;
public width?: number;
public faceEmbedding: ServerFaceEmbeddings;
public constructor(
fileID: number,
faceEmbedding: ServerFaceEmbeddings,
height?: number,
width?: number,
) {
this.fileID = fileID;
this.height = height;
this.width = width;
this.faceEmbedding = faceEmbedding;
}
}
class ServerFaceEmbeddings {
public faces: ServerFace[];
public version: number;
public client: string;
public constructor(faces: ServerFace[], client: string, version: number) {
this.faces = faces;
this.client = client;
this.version = version;
}
}
class ServerFace {
public faceID: string;
public embedding: number[];
public detection: ServerDetection;
public score: number;
public blur: number;
public constructor(
faceID: string,
embedding: number[],
detection: ServerDetection,
score: number,
blur: number,
) {
this.faceID = faceID;
this.embedding = embedding;
this.detection = detection;
this.score = score;
this.blur = blur;
}
}
class ServerDetection {
public box: ServerFaceBox;
public landmarks: Point[];
public constructor(box: ServerFaceBox, landmarks: Point[]) {
this.box = box;
this.landmarks = landmarks;
}
}
class ServerFaceBox {
public xMin: number;
public yMin: number;
public width: number;
public height: number;
public constructor(
xMin: number,
yMin: number,
width: number,
height: number,
) {
this.xMin = xMin;
this.yMin = yMin;
this.width = width;
this.height = height;
}
}
function LocalFileMlDataToServerFileMl(
localFileMlData: MlFileData,
userAgent: string,
): ServerFileMl {
if (localFileMlData.errorCount > 0) {
return null;
}
const imageDimensions = localFileMlData.imageDimensions;
const faces: ServerFace[] = [];
for (let i = 0; i < localFileMlData.faces.length; i++) {
const face: Face = localFileMlData.faces[i];
const faceID = face.id;
const embedding = face.embedding;
const score = face.detection.probability;
const blur = face.blurValue;
const detection: FaceDetection = face.detection;
const box = detection.box;
const landmarks = detection.landmarks;
const newBox = new ServerFaceBox(box.x, box.y, box.width, box.height);
const newFaceObject = new ServerFace(
faceID,
Array.from(embedding),
new ServerDetection(newBox, landmarks),
score,
blur,
);
faces.push(newFaceObject);
}
const faceEmbeddings = new ServerFaceEmbeddings(faces, userAgent, 1);
return new ServerFileMl(
localFileMlData.fileId,
faceEmbeddings,
imageDimensions.height,
imageDimensions.width,
);
}

View file

@ -0,0 +1,57 @@
import { Box, Point } from "services/face/geom";
import type { FaceDetection } from "services/face/types";
// TODO-ML(LAURENS): Do we need two separate Matrix libraries?
//
// Keeping this in a separate file so that we can audit this. If these can be
// expressed using ml-matrix, then we can move this code to f-index.ts
import {
Matrix,
applyToPoint,
compose,
scale,
translate,
} from "transformation-matrix";
/**
* Transform the given {@link faceDetections} from their coordinate system in
* which they were detected ({@link inBox}) back to the coordinate system of the
* original image ({@link toBox}).
*/
export const transformFaceDetections = (
faceDetections: FaceDetection[],
inBox: Box,
toBox: Box,
): FaceDetection[] => {
const transform = boxTransformationMatrix(inBox, toBox);
return faceDetections.map((f) => ({
box: transformBox(f.box, transform),
landmarks: f.landmarks.map((p) => transformPoint(p, transform)),
probability: f.probability,
}));
};
const boxTransformationMatrix = (inBox: Box, toBox: Box): Matrix =>
compose(
translate(toBox.x, toBox.y),
scale(toBox.width / inBox.width, toBox.height / inBox.height),
);
const transformPoint = (point: Point, transform: Matrix) => {
const txdPoint = applyToPoint(transform, point);
return new Point(txdPoint.x, txdPoint.y);
};
const transformBox = (box: Box, transform: Matrix) => {
const topLeft = transformPoint(new Point(box.x, box.y), transform);
const bottomRight = transformPoint(
new Point(box.x + box.width, box.y + box.height),
transform,
);
return new Box({
x: topLeft.x,
y: topLeft.y,
width: bottomRight.x - topLeft.x,
height: bottomRight.y - topLeft.y,
});
};
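// Illustrative example (hypothetical sizes): detections made in the 640x320
// letterboxed YOLO input of a 1920x960 photo scale back up by a factor of 3.
const exampleInYolo: FaceDetection = {
    box: new Box({ x: 100, y: 50, width: 60, height: 80 }),
    landmarks: [new Point(120, 70)],
    probability: 0.9,
};
const [exampleInOriginal] = transformFaceDetections(
    [exampleInYolo],
    new Box({ x: 0, y: 0, width: 640, height: 320 }),
    new Box({ x: 0, y: 0, width: 1920, height: 960 }),
);
// exampleInOriginal.box => { x: 300, y: 150, width: 180, height: 240 }
// exampleInOriginal.landmarks[0] => Point { x: 360, y: 210 }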

View file

@ -1,161 +1,39 @@
import type { ClusterFacesResult } from "services/face/cluster";
import { Dimensions } from "services/face/geom";
import { EnteFile } from "types/file";
import { Box, Point } from "./geom";
export interface MLSyncResult {
nOutOfSyncFiles: number;
nSyncedFiles: number;
nSyncedFaces: number;
nFaceClusters: number;
nFaceNoise: number;
error?: Error;
}
export declare type FaceDescriptor = Float32Array;
export declare type Cluster = Array<number>;
export interface FacesCluster {
faces: Cluster;
summary?: FaceDescriptor;
}
export interface FacesClustersWithNoise {
clusters: Array<FacesCluster>;
noise: Cluster;
}
export interface NearestCluster {
cluster: FacesCluster;
distance: number;
}
export declare type Landmark = Point;
export declare type ImageType = "Original" | "Preview";
export declare type FaceDetectionMethod = "YoloFace";
export declare type FaceCropMethod = "ArcFace";
export declare type FaceAlignmentMethod = "ArcFace";
export declare type FaceEmbeddingMethod = "MobileFaceNet";
export declare type BlurDetectionMethod = "Laplacian";
export declare type ClusteringMethod = "Hdbscan" | "Dbscan";
export class AlignedBox {
box: Box;
rotation: number;
}
export interface Versioned<T> {
value: T;
version: number;
}
import { Box, Dimensions, Point } from "services/face/geom";
export interface FaceDetection {
    // box and landmarks are relative to the image dimensions stored at mlFileData
box: Box;
landmarks?: Array<Landmark>;
landmarks?: Point[];
probability?: number;
}
export interface DetectedFace {
fileId: number;
detection: FaceDetection;
}
export interface DetectedFaceWithId extends DetectedFace {
id: string;
}
export interface FaceCrop {
image: ImageBitmap;
    // imageBox is relative to the image dimensions stored at mlFileData
imageBox: Box;
}
export interface StoredFaceCrop {
cacheKey: string;
imageBox: Box;
}
export interface CroppedFace extends DetectedFaceWithId {
crop?: StoredFaceCrop;
}
export interface FaceAlignment {
// TODO: remove affine matrix as rotation, size and center
// TODO-ML(MR): remove affine matrix as rotation, size and center
// are simple to store and use, affine matrix adds complexity while getting crop
affineMatrix: Array<Array<number>>;
affineMatrix: number[][];
rotation: number;
    // size and center are relative to the image dimensions stored at mlFileData
size: number;
center: Point;
}
export interface AlignedFace extends CroppedFace {
export interface Face {
fileId: number;
detection: FaceDetection;
id: string;
alignment?: FaceAlignment;
blurValue?: number;
}
export declare type FaceEmbedding = Float32Array;
embedding?: Float32Array;
export interface FaceWithEmbedding extends AlignedFace {
embedding?: FaceEmbedding;
}
export interface Face extends FaceWithEmbedding {
personId?: number;
}
export interface Person {
id: number;
name?: string;
files: Array<number>;
displayFaceId?: string;
faceCropCacheKey?: string;
}
export interface MlFileData {
fileId: number;
faces?: Face[];
imageSource?: ImageType;
imageDimensions?: Dimensions;
faceDetectionMethod?: Versioned<FaceDetectionMethod>;
faceCropMethod?: Versioned<FaceCropMethod>;
faceAlignmentMethod?: Versioned<FaceAlignmentMethod>;
faceEmbeddingMethod?: Versioned<FaceEmbeddingMethod>;
mlVersion: number;
errorCount: number;
lastErrorMessage?: string;
}
export interface MLSearchConfig {
enabled: boolean;
}
export interface MLSyncFileContext {
enteFile: EnteFile;
localFile?: globalThis.File;
oldMlFile?: MlFileData;
newMlFile?: MlFileData;
imageBitmap?: ImageBitmap;
newDetection?: boolean;
newAlignment?: boolean;
}
export interface MLLibraryData {
faceClusteringMethod?: Versioned<ClusteringMethod>;
faceClusteringResults?: ClusterFacesResult;
faceClustersWithNoise?: FacesClustersWithNoise;
}
export declare type MLIndex = "files" | "people";

View file

@ -51,9 +51,7 @@ class HEICConverter {
const startTime = Date.now();
const convertedHEIC =
await worker.heicToJPEG(fileBlob);
const ms = Math.round(
Date.now() - startTime,
);
const ms = Date.now() - startTime;
log.debug(() => `heic => jpeg (${ms} ms)`);
clearTimeout(timeout);
resolve(convertedHEIC);

View file

@ -1,41 +1,26 @@
import { haveWindow } from "@/next/env";
import log from "@/next/log";
import { ComlinkWorker } from "@/next/worker/comlink-worker";
import ComlinkCryptoWorker, {
getDedicatedCryptoWorker,
} from "@ente/shared/crypto";
import { DedicatedCryptoWorker } from "@ente/shared/crypto/internal/crypto.worker";
import { CustomError, parseUploadErrorCodes } from "@ente/shared/error";
import PQueue from "p-queue";
import { putEmbedding } from "services/embeddingService";
import mlIDbStorage, { ML_SEARCH_CONFIG_NAME } from "services/face/db";
import {
Face,
FaceDetection,
Landmark,
MLLibraryData,
MLSearchConfig,
MLSyncFileContext,
MLSyncResult,
MlFileData,
} from "services/face/types";
import mlIDbStorage, {
ML_SEARCH_CONFIG_NAME,
type MinimalPersistedFileData,
} from "services/face/db";
import { putFaceEmbedding } from "services/face/remote";
import { getLocalFiles } from "services/fileService";
import { EnteFile } from "types/file";
import { isInternalUserForML } from "utils/user";
import { regenerateFaceCrop, syncFileAnalyzeFaces } from "../face/f-index";
import { fetchImageBitmapForContext } from "../face/image";
import { syncPeopleIndex } from "../face/people";
import { indexFaces } from "../face/f-index";
/**
* TODO-ML(MR): What and why.
* Also, needs to be 1 (in sync with mobile) when we move out of beta.
*/
export const defaultMLVersion = 3;
export const defaultMLVersion = 1;
const batchSize = 200;
export const MAX_ML_SYNC_ERROR_COUNT = 1;
export interface MLSearchConfig {
enabled: boolean;
}
export const DEFAULT_ML_SEARCH_CONFIG: MLSearchConfig = {
enabled: false,
};
@ -56,107 +41,54 @@ export async function updateMLSearchConfig(newConfig: MLSearchConfig) {
return mlIDbStorage.putConfig(ML_SEARCH_CONFIG_NAME, newConfig);
}
export interface MLSyncContext {
token: string;
userID: number;
localFilesMap: Map<number, EnteFile>;
outOfSyncFiles: EnteFile[];
nSyncedFiles: number;
nSyncedFaces: number;
allSyncedFacesMap?: Map<number, Array<Face>>;
error?: Error;
// oldMLLibraryData: MLLibraryData;
mlLibraryData: MLLibraryData;
syncQueue: PQueue;
getEnteWorker(id: number): Promise<any>;
dispose(): Promise<void>;
}
export class LocalMLSyncContext implements MLSyncContext {
class MLSyncContext {
public token: string;
public userID: number;
public userAgent: string;
public localFilesMap: Map<number, EnteFile>;
public outOfSyncFiles: EnteFile[];
public nSyncedFiles: number;
public nSyncedFaces: number;
public allSyncedFacesMap?: Map<number, Array<Face>>;
public error?: Error;
public mlLibraryData: MLLibraryData;
public syncQueue: PQueue;
    // TODO: whether to limit concurrent downloads
// private downloadQueue: PQueue;
private concurrency: number;
private comlinkCryptoWorker: Array<
ComlinkWorker<typeof DedicatedCryptoWorker>
>;
private enteWorkers: Array<any>;
constructor(token: string, userID: number, concurrency?: number) {
constructor(token: string, userID: number, userAgent: string) {
this.token = token;
this.userID = userID;
this.userAgent = userAgent;
this.outOfSyncFiles = [];
this.nSyncedFiles = 0;
this.nSyncedFaces = 0;
this.concurrency = concurrency ?? getConcurrency();
log.info("Using concurrency: ", this.concurrency);
// timeout is added on downloads
// timeout on queue will keep the operation open till worker is terminated
this.syncQueue = new PQueue({ concurrency: this.concurrency });
logQueueStats(this.syncQueue, "sync");
// this.downloadQueue = new PQueue({ concurrency: 1 });
// logQueueStats(this.downloadQueue, 'download');
this.comlinkCryptoWorker = new Array(this.concurrency);
this.enteWorkers = new Array(this.concurrency);
}
public async getEnteWorker(id: number): Promise<any> {
const wid = id % this.enteWorkers.length;
console.log("getEnteWorker: ", id, wid);
if (!this.enteWorkers[wid]) {
this.comlinkCryptoWorker[wid] = getDedicatedCryptoWorker();
this.enteWorkers[wid] = await this.comlinkCryptoWorker[wid].remote;
}
return this.enteWorkers[wid];
const concurrency = getConcurrency();
this.syncQueue = new PQueue({ concurrency });
}
public async dispose() {
this.localFilesMap = undefined;
await this.syncQueue.onIdle();
this.syncQueue.removeAllListeners();
for (const enteComlinkWorker of this.comlinkCryptoWorker) {
enteComlinkWorker?.terminate();
}
}
}
export const getConcurrency = () =>
haveWindow() && Math.max(2, Math.ceil(navigator.hardwareConcurrency / 2));
const getConcurrency = () =>
Math.max(2, Math.ceil(navigator.hardwareConcurrency / 2));
class MachineLearningService {
private localSyncContext: Promise<MLSyncContext>;
private syncContext: Promise<MLSyncContext>;
public async sync(token: string, userID: number): Promise<MLSyncResult> {
public async sync(
token: string,
userID: number,
userAgent: string,
): Promise<boolean> {
if (!token) {
throw Error("Token needed by ml service to sync file");
}
const syncContext = await this.getSyncContext(token, userID);
const syncContext = await this.getSyncContext(token, userID, userAgent);
await this.syncLocalFiles(syncContext);
@ -166,38 +98,9 @@ class MachineLearningService {
await this.syncFiles(syncContext);
}
// TODO-ML(MR): Forced disable clustering. It doesn't currently work,
// need to finalize it before we move out of beta.
//
// > Error: Failed to execute 'transferToImageBitmap' on
// > 'OffscreenCanvas': ImageBitmap construction failed
/*
if (
syncContext.outOfSyncFiles.length <= 0 ||
(syncContext.nSyncedFiles === batchSize && Math.random() < 0)
) {
await this.syncIndex(syncContext);
}
*/
const mlSyncResult: MLSyncResult = {
nOutOfSyncFiles: syncContext.outOfSyncFiles.length,
nSyncedFiles: syncContext.nSyncedFiles,
nSyncedFaces: syncContext.nSyncedFaces,
nFaceClusters:
syncContext.mlLibraryData?.faceClusteringResults?.clusters
.length,
nFaceNoise:
syncContext.mlLibraryData?.faceClusteringResults?.noise.length,
error: syncContext.error,
};
// log.info('[MLService] sync results: ', mlSyncResult);
return mlSyncResult;
}
public async regenerateFaceCrop(faceID: string) {
return regenerateFaceCrop(faceID);
const error = syncContext.error;
const nOutOfSyncFiles = syncContext.outOfSyncFiles.length;
return !error && nOutOfSyncFiles > 0;
}
private newMlData(fileId: number) {
@ -205,7 +108,7 @@ class MachineLearningService {
fileId,
mlVersion: 0,
errorCount: 0,
} as MlFileData;
} as MinimalPersistedFileData;
}
private async getLocalFilesMap(syncContext: MLSyncContext) {
@ -309,7 +212,6 @@ class MachineLearningService {
syncContext.error = error;
}
await syncContext.syncQueue.onIdle();
log.info("allFaces: ", syncContext.nSyncedFaces);
// TODO: In case syncJob has to use multiple ml workers
// do in same transaction with each file update
@ -318,13 +220,17 @@ class MachineLearningService {
// await this.disposeMLModels();
}
private async getSyncContext(token: string, userID: number) {
private async getSyncContext(
token: string,
userID: number,
userAgent: string,
) {
if (!this.syncContext) {
log.info("Creating syncContext");
// TODO-ML(MR): Keep as promise for now.
this.syncContext = new Promise((resolve) => {
resolve(new LocalMLSyncContext(token, userID));
resolve(new MLSyncContext(token, userID, userAgent));
});
} else {
log.info("reusing existing syncContext");
@ -332,13 +238,17 @@ class MachineLearningService {
return this.syncContext;
}
private async getLocalSyncContext(token: string, userID: number) {
private async getLocalSyncContext(
token: string,
userID: number,
userAgent: string,
) {
// TODO-ML(MR): This is updating the file ML version. verify.
if (!this.localSyncContext) {
log.info("Creating localSyncContext");
// TODO-ML(MR):
this.localSyncContext = new Promise((resolve) => {
resolve(new LocalMLSyncContext(token, userID));
resolve(new MLSyncContext(token, userID, userAgent));
});
} else {
log.info("reusing existing localSyncContext");
@ -358,10 +268,15 @@ class MachineLearningService {
public async syncLocalFile(
token: string,
userID: number,
userAgent: string,
enteFile: EnteFile,
localFile?: globalThis.File,
) {
const syncContext = await this.getLocalSyncContext(token, userID);
const syncContext = await this.getLocalSyncContext(
token,
userID,
userAgent,
);
try {
await this.syncFileWithErrorHandler(
@ -385,11 +300,11 @@ class MachineLearningService {
localFile?: globalThis.File,
) {
try {
console.log(
`Indexing ${enteFile.title ?? "<untitled>"} ${enteFile.id}`,
const mlFileData = await this.syncFile(
enteFile,
localFile,
syncContext.userAgent,
);
const mlFileData = await this.syncFile(enteFile, localFile);
syncContext.nSyncedFaces += mlFileData.faces?.length || 0;
syncContext.nSyncedFiles += 1;
return mlFileData;
} catch (e) {
@ -421,62 +336,22 @@ class MachineLearningService {
}
}
private async syncFile(enteFile: EnteFile, localFile?: globalThis.File) {
log.debug(() => ({ a: "Syncing file", enteFile }));
const fileContext: MLSyncFileContext = { enteFile, localFile };
const oldMlFile = await this.getMLFileData(enteFile.id);
private async syncFile(
enteFile: EnteFile,
localFile: globalThis.File | undefined,
userAgent: string,
) {
const oldMlFile = await mlIDbStorage.getFile(enteFile.id);
if (oldMlFile && oldMlFile.mlVersion) {
return oldMlFile;
}
const newMlFile = (fileContext.newMlFile = this.newMlData(enteFile.id));
newMlFile.mlVersion = defaultMLVersion;
try {
await fetchImageBitmapForContext(fileContext);
await syncFileAnalyzeFaces(fileContext);
newMlFile.errorCount = 0;
newMlFile.lastErrorMessage = undefined;
await this.persistOnServer(newMlFile, enteFile);
await mlIDbStorage.putFile(newMlFile);
} catch (e) {
log.error("ml detection failed", e);
newMlFile.mlVersion = oldMlFile.mlVersion;
throw e;
} finally {
fileContext.imageBitmap && fileContext.imageBitmap.close();
}
const newMlFile = await indexFaces(enteFile, localFile);
await putFaceEmbedding(enteFile, newMlFile, userAgent);
await mlIDbStorage.putFile(newMlFile);
return newMlFile;
}
private async persistOnServer(mlFileData: MlFileData, enteFile: EnteFile) {
const serverMl = LocalFileMlDataToServerFileMl(mlFileData);
log.debug(() => ({ t: "Local ML file data", mlFileData }));
log.debug(() => ({
t: "Uploaded ML file data",
d: JSON.stringify(serverMl),
}));
const comlinkCryptoWorker = await ComlinkCryptoWorker.getInstance();
const { file: encryptedEmbeddingData } =
await comlinkCryptoWorker.encryptMetadata(serverMl, enteFile.key);
log.info(
`putEmbedding embedding to server for file: ${enteFile.metadata.title} fileID: ${enteFile.id}`,
);
const res = await putEmbedding({
fileID: enteFile.id,
encryptedEmbedding: encryptedEmbeddingData.encryptedData,
decryptionHeader: encryptedEmbeddingData.decryptionHeader,
model: "file-ml-clip-face",
});
log.info("putEmbedding response: ", res);
}
private async getMLFileData(fileId: number) {
return mlIDbStorage.getFile(fileId);
}
private async persistMLFileSyncError(enteFile: EnteFile, e: Error) {
try {
await mlIDbStorage.upsertFileInTx(enteFile.id, (mlFileData) => {
@ -484,7 +359,7 @@ class MachineLearningService {
mlFileData = this.newMlData(enteFile.id);
}
mlFileData.errorCount = (mlFileData.errorCount || 0) + 1;
mlFileData.lastErrorMessage = e.message;
console.error(`lastError for ${enteFile.id}`, e);
return mlFileData;
});
@ -493,183 +368,6 @@ class MachineLearningService {
console.error("Error while storing ml sync error", e);
}
}
private async getMLLibraryData(syncContext: MLSyncContext) {
syncContext.mlLibraryData = await mlIDbStorage.getLibraryData();
if (!syncContext.mlLibraryData) {
syncContext.mlLibraryData = {};
}
}
private async persistMLLibraryData(syncContext: MLSyncContext) {
return mlIDbStorage.putLibraryData(syncContext.mlLibraryData);
}
public async syncIndex(syncContext: MLSyncContext) {
await this.getMLLibraryData(syncContext);
// TODO-ML(MR): Ensure this doesn't run until fixed.
await syncPeopleIndex(syncContext);
await this.persistMLLibraryData(syncContext);
}
}
export default new MachineLearningService();
export interface FileML extends ServerFileMl {
updatedAt: number;
}
class ServerFileMl {
public fileID: number;
public height?: number;
public width?: number;
public faceEmbedding: ServerFaceEmbeddings;
public constructor(
fileID: number,
faceEmbedding: ServerFaceEmbeddings,
height?: number,
width?: number,
) {
this.fileID = fileID;
this.height = height;
this.width = width;
this.faceEmbedding = faceEmbedding;
}
}
class ServerFaceEmbeddings {
public faces: ServerFace[];
public version: number;
public client?: string;
public error?: boolean;
public constructor(
faces: ServerFace[],
version: number,
client?: string,
error?: boolean,
) {
this.faces = faces;
this.version = version;
this.client = client;
this.error = error;
}
}
class ServerFace {
public faceID: string;
public embeddings: number[];
public detection: ServerDetection;
public score: number;
public blur: number;
public constructor(
faceID: string,
embeddings: number[],
detection: ServerDetection,
score: number,
blur: number,
) {
this.faceID = faceID;
this.embeddings = embeddings;
this.detection = detection;
this.score = score;
this.blur = blur;
}
}
class ServerDetection {
public box: ServerFaceBox;
public landmarks: Landmark[];
public constructor(box: ServerFaceBox, landmarks: Landmark[]) {
this.box = box;
this.landmarks = landmarks;
}
}
class ServerFaceBox {
public xMin: number;
public yMin: number;
public width: number;
public height: number;
public constructor(
xMin: number,
yMin: number,
width: number,
height: number,
) {
this.xMin = xMin;
this.yMin = yMin;
this.width = width;
this.height = height;
}
}
function LocalFileMlDataToServerFileMl(
localFileMlData: MlFileData,
): ServerFileMl {
if (
localFileMlData.errorCount > 0 &&
localFileMlData.lastErrorMessage !== undefined
) {
return null;
}
const imageDimensions = localFileMlData.imageDimensions;
const faces: ServerFace[] = [];
for (let i = 0; i < localFileMlData.faces.length; i++) {
const face: Face = localFileMlData.faces[i];
const faceID = face.id;
const embedding = face.embedding;
const score = face.detection.probability;
const blur = face.blurValue;
const detection: FaceDetection = face.detection;
const box = detection.box;
const landmarks = detection.landmarks;
const newBox = new ServerFaceBox(box.x, box.y, box.width, box.height);
const newLandmarks: Landmark[] = [];
for (let j = 0; j < landmarks.length; j++) {
newLandmarks.push({
x: landmarks[j].x,
y: landmarks[j].y,
} as Landmark);
}
const newFaceObject = new ServerFace(
faceID,
Array.from(embedding),
new ServerDetection(newBox, newLandmarks),
score,
blur,
);
faces.push(newFaceObject);
}
const faceEmbeddings = new ServerFaceEmbeddings(
faces,
1,
localFileMlData.lastErrorMessage,
);
return new ServerFileMl(
localFileMlData.fileId,
faceEmbeddings,
imageDimensions.height,
imageDimensions.width,
);
}
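// For reference, a rough sketch of the serialized shape this produces once the
// result is JSON stringified (all values below are purely illustrative):
//
//     {
//         "fileID": 1234,
//         "height": 3024,
//         "width": 4032,
//         "faceEmbedding": {
//             "version": 1,
//             "faces": [{
//                 "faceID": "...",
//                 "embeddings": [0.01, -0.02, /* ... rest of the embedding */],
//                 "detection": {
//                     "box": { "xMin": 0.1, "yMin": 0.2, "width": 0.3, "height": 0.3 },
//                     "landmarks": [{ "x": 0.2, "y": 0.3 }, /* ... */]
//                 },
//                 "score": 0.9,
//                 "blur": 150
//             }]
//         }
//     }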
export function logQueueStats(queue: PQueue, name: string) {
queue.on("active", () =>
log.info(
`queuestats: ${name}: Active, Size: ${queue.size} Pending: ${queue.pending}`,
),
);
queue.on("idle", () => log.info(`queuestats: ${name}: Idle`));
queue.on("error", (error) =>
console.error(`queuestats: ${name}: Error, `, error),
);
}

View file

@ -1,6 +1,8 @@
import { FILE_TYPE } from "@/media/file-type";
import { ensureElectron } from "@/next/electron";
import log from "@/next/log";
import { ComlinkWorker } from "@/next/worker/comlink-worker";
import { clientPackageNamePhotosDesktop } from "@ente/shared/apps/constants";
import { eventBus, Events } from "@ente/shared/events";
import { getToken, getUserID } from "@ente/shared/storage/localStorage/helpers";
import debounce from "debounce";
@ -8,25 +10,18 @@ import PQueue from "p-queue";
import { createFaceComlinkWorker } from "services/face";
import mlIDbStorage from "services/face/db";
import type { DedicatedMLWorker } from "services/face/face.worker";
import { MLSyncResult } from "services/face/types";
import { EnteFile } from "types/file";
import { logQueueStats } from "./machineLearningService";
export type JobState = "Scheduled" | "Running" | "NotScheduled";
export interface MLSyncJobResult {
shouldBackoff: boolean;
mlSyncResult: MLSyncResult;
}
export class MLSyncJob {
private runCallback: () => Promise<MLSyncJobResult>;
private runCallback: () => Promise<boolean>;
private state: JobState;
private stopped: boolean;
private intervalSec: number;
private nextTimeoutId: ReturnType<typeof setTimeout>;
constructor(runCallback: () => Promise<MLSyncJobResult>) {
constructor(runCallback: () => Promise<boolean>) {
this.runCallback = runCallback;
this.state = "NotScheduled";
this.stopped = true;
@ -65,13 +60,11 @@ export class MLSyncJob {
this.state = "Running";
try {
const jobResult = await this.runCallback();
if (jobResult && jobResult.shouldBackoff) {
this.intervalSec = Math.min(960, this.intervalSec * 2);
} else {
if (await this.runCallback()) {
this.resetInterval();
} else {
this.intervalSec = Math.min(960, this.intervalSec * 2);
}
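// Net effect: every run that reports no more work (or fails) doubles the wait
// before the next run, capped at 960 seconds, while a run that reports more
// pending work resets the wait back to the base interval.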
log.info("Job completed");
} catch (e) {
console.error("Error while running Job: ", e);
} finally {
@ -236,8 +229,15 @@ class MLWorkManager {
this.stopSyncJob();
const token = getToken();
const userID = getUserID();
const userAgent = await getUserAgent();
const mlWorker = await this.getLiveSyncWorker();
return mlWorker.syncLocalFile(token, userID, enteFile, localFile);
return mlWorker.syncLocalFile(
token,
userID,
userAgent,
enteFile,
localFile,
);
});
}
@ -255,7 +255,14 @@ class MLWorkManager {
this.syncJobWorker = undefined;
}
private async runMLSyncJob(): Promise<MLSyncJobResult> {
/**
* Returns `false` to indicate that either an error occurred, or there are no
* more files to process, or that we cannot currently process files.
*
* Conversely, when it returns `true`, all is well and there are more files
* pending to be processed, so we should chug along at full speed.
*/
private async runMLSyncJob(): Promise<boolean> {
try {
// TODO: Skipping is not required if we cache chunks through the service
// worker; currently the worker chunk itself does not get loaded when there is
// no network.
@ -263,29 +270,17 @@ class MLWorkManager {
log.info(
"Skipping ml-sync job run as not connected to internet.",
);
return {
shouldBackoff: true,
mlSyncResult: undefined,
};
return false;
}
const token = getToken();
const userID = getUserID();
const userAgent = await getUserAgent();
const jobWorkerProxy = await this.getSyncJobWorker();
const mlSyncResult = await jobWorkerProxy.sync(token, userID);
return await jobWorkerProxy.sync(token, userID, userAgent);
// this.terminateSyncJobWorker();
const jobResult: MLSyncJobResult = {
shouldBackoff:
!!mlSyncResult.error || mlSyncResult.nOutOfSyncFiles < 1,
mlSyncResult,
};
log.info("ML Sync Job result: ", JSON.stringify(jobResult));
// TODO: redirect/refresh to gallery in case of session_expired, stop ml sync job
return jobResult;
} catch (e) {
log.error("Failed to run MLSync Job", e);
}
@ -323,3 +318,22 @@ class MLWorkManager {
}
export default new MLWorkManager();
export function logQueueStats(queue: PQueue, name: string) {
queue.on("active", () =>
log.info(
`queuestats: ${name}: Active, Size: ${queue.size} Pending: ${queue.pending}`,
),
);
queue.on("idle", () => log.info(`queuestats: ${name}: Idle`));
queue.on("error", (error) =>
console.error(`queuestats: ${name}: Error, `, error),
);
}
const getUserAgent = async () => {
const electron = ensureElectron();
const name = clientPackageNamePhotosDesktop;
const version = await electron.appVersion();
return `${name}/${version}`;
};
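// For example, assuming an app version of "1.7.0", this evaluates to
// "io.ente.photos.desktop/1.7.0".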

View file

@ -3,7 +3,7 @@ import log from "@/next/log";
import * as chrono from "chrono-node";
import { t } from "i18next";
import mlIDbStorage from "services/face/db";
import { Person } from "services/face/types";
import type { Person } from "services/face/people";
import { defaultMLVersion } from "services/machineLearning/machineLearningService";
import { Collection } from "types/collection";
import { EntityType, LocationTag, LocationTagData } from "types/entity";

View file

@ -1,6 +1,6 @@
import { FILE_TYPE } from "@/media/file-type";
import { IndexStatus } from "services/face/db";
import { Person } from "services/face/types";
import type { Person } from "services/face/people";
import { City } from "services/locationSearchService";
import { LocationTagData } from "types/entity";
import { EnteFile } from "types/file";

View file

@ -5,11 +5,13 @@ import { type DedicatedSearchWorker } from "worker/search.worker";
class ComlinkSearchWorker {
private comlinkWorkerInstance: Remote<DedicatedSearchWorker>;
private comlinkWorker: ComlinkWorker<typeof DedicatedSearchWorker>;
async getInstance() {
if (!this.comlinkWorkerInstance) {
this.comlinkWorkerInstance =
await getDedicatedSearchWorker().remote;
if (!this.comlinkWorker)
this.comlinkWorker = getDedicatedSearchWorker();
this.comlinkWorkerInstance = await this.comlinkWorker.remote;
}
return this.comlinkWorkerInstance;
}

View file

@ -1,468 +0,0 @@
// These utilities only work in environments where OffscreenCanvas is available.
import { Matrix, inverse } from "ml-matrix";
import { Box, Dimensions, enlargeBox } from "services/face/geom";
import { FaceAlignment } from "services/face/types";
export function normalizePixelBetween0And1(pixelValue: number) {
return pixelValue / 255.0;
}
export function normalizePixelBetweenMinus1And1(pixelValue: number) {
return pixelValue / 127.5 - 1.0;
}
export function unnormalizePixelFromBetweenMinus1And1(pixelValue: number) {
return clamp(Math.round((pixelValue + 1.0) * 127.5), 0, 255);
}
export function readPixelColor(
imageData: Uint8ClampedArray,
width: number,
height: number,
x: number,
y: number,
) {
if (x < 0 || x >= width || y < 0 || y >= height) {
return { r: 0, g: 0, b: 0, a: 0 };
}
const index = (y * width + x) * 4;
return {
r: imageData[index],
g: imageData[index + 1],
b: imageData[index + 2],
a: imageData[index + 3],
};
}
export function clamp(value: number, min: number, max: number) {
return Math.min(max, Math.max(min, value));
}
export function getPixelBicubic(
fx: number,
fy: number,
imageData: Uint8ClampedArray,
imageWidth: number,
imageHeight: number,
) {
// Clamp to image boundaries
fx = clamp(fx, 0, imageWidth - 1);
fy = clamp(fy, 0, imageHeight - 1);
const x = Math.trunc(fx) - (fx >= 0.0 ? 0 : 1);
const px = x - 1;
const nx = x + 1;
const ax = x + 2;
const y = Math.trunc(fy) - (fy >= 0.0 ? 0 : 1);
const py = y - 1;
const ny = y + 1;
const ay = y + 2;
const dx = fx - x;
const dy = fy - y;
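// cubic() below evaluates a Catmull-Rom style cubic convolution kernel along
// one axis, given four neighbouring samples (previous, current, next, after)
// and the fractional offset into the interval.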
function cubic(
dx: number,
ipp: number,
icp: number,
inp: number,
iap: number,
) {
return (
icp +
0.5 *
(dx * (-ipp + inp) +
dx * dx * (2 * ipp - 5 * icp + 4 * inp - iap) +
dx * dx * dx * (-ipp + 3 * icp - 3 * inp + iap))
);
}
const icc = readPixelColor(imageData, imageWidth, imageHeight, x, y);
const ipp =
px < 0 || py < 0
? icc
: readPixelColor(imageData, imageWidth, imageHeight, px, py);
const icp =
px < 0
? icc
: readPixelColor(imageData, imageWidth, imageHeight, x, py);
const inp =
py < 0 || nx >= imageWidth
? icc
: readPixelColor(imageData, imageWidth, imageHeight, nx, py);
const iap =
ax >= imageWidth || py < 0
? icc
: readPixelColor(imageData, imageWidth, imageHeight, ax, py);
const ip0 = cubic(dx, ipp.r, icp.r, inp.r, iap.r);
const ip1 = cubic(dx, ipp.g, icp.g, inp.g, iap.g);
const ip2 = cubic(dx, ipp.b, icp.b, inp.b, iap.b);
// const ip3 = cubic(dx, ipp.a, icp.a, inp.a, iap.a);
const ipc =
px < 0
? icc
: readPixelColor(imageData, imageWidth, imageHeight, px, y);
const inc =
nx >= imageWidth
? icc
: readPixelColor(imageData, imageWidth, imageHeight, nx, y);
const iac =
ax >= imageWidth
? icc
: readPixelColor(imageData, imageWidth, imageHeight, ax, y);
const ic0 = cubic(dx, ipc.r, icc.r, inc.r, iac.r);
const ic1 = cubic(dx, ipc.g, icc.g, inc.g, iac.g);
const ic2 = cubic(dx, ipc.b, icc.b, inc.b, iac.b);
// const ic3 = cubic(dx, ipc.a, icc.a, inc.a, iac.a);
const ipn =
px < 0 || ny >= imageHeight
? icc
: readPixelColor(imageData, imageWidth, imageHeight, px, ny);
const icn =
ny >= imageHeight
? icc
: readPixelColor(imageData, imageWidth, imageHeight, x, ny);
const inn =
nx >= imageWidth || ny >= imageHeight
? icc
: readPixelColor(imageData, imageWidth, imageHeight, nx, ny);
const ian =
ax >= imageWidth || ny >= imageHeight
? icc
: readPixelColor(imageData, imageWidth, imageHeight, ax, ny);
const in0 = cubic(dx, ipn.r, icn.r, inn.r, ian.r);
const in1 = cubic(dx, ipn.g, icn.g, inn.g, ian.g);
const in2 = cubic(dx, ipn.b, icn.b, inn.b, ian.b);
// const in3 = cubic(dx, ipn.a, icn.a, inn.a, ian.a);
const ipa =
px < 0 || ay >= imageHeight
? icc
: readPixelColor(imageData, imageWidth, imageHeight, px, ay);
const ica =
ay >= imageHeight
? icc
: readPixelColor(imageData, imageWidth, imageHeight, x, ay);
const ina =
nx >= imageWidth || ay >= imageHeight
? icc
: readPixelColor(imageData, imageWidth, imageHeight, nx, ay);
const iaa =
ax >= imageWidth || ay >= imageHeight
? icc
: readPixelColor(imageData, imageWidth, imageHeight, ax, ay);
const ia0 = cubic(dx, ipa.r, ica.r, ina.r, iaa.r);
const ia1 = cubic(dx, ipa.g, ica.g, ina.g, iaa.g);
const ia2 = cubic(dx, ipa.b, ica.b, ina.b, iaa.b);
// const ia3 = cubic(dx, ipa.a, ica.a, ina.a, iaa.a);
const c0 = Math.trunc(clamp(cubic(dy, ip0, ic0, in0, ia0), 0, 255));
const c1 = Math.trunc(clamp(cubic(dy, ip1, ic1, in1, ia1), 0, 255));
const c2 = Math.trunc(clamp(cubic(dy, ip2, ic2, in2, ia2), 0, 255));
// const c3 = cubic(dy, ip3, ic3, in3, ia3);
return { r: c0, g: c1, b: c2 };
}
/// Returns the pixel value (RGB) at the given coordinates using bilinear interpolation.
export function getPixelBilinear(
fx: number,
fy: number,
imageData: Uint8ClampedArray,
imageWidth: number,
imageHeight: number,
) {
// Clamp to image boundaries
fx = clamp(fx, 0, imageWidth - 1);
fy = clamp(fy, 0, imageHeight - 1);
// Get the surrounding coordinates and their weights
const x0 = Math.floor(fx);
const x1 = Math.ceil(fx);
const y0 = Math.floor(fy);
const y1 = Math.ceil(fy);
const dx = fx - x0;
const dy = fy - y0;
const dx1 = 1.0 - dx;
const dy1 = 1.0 - dy;
// Get the original pixels
const pixel1 = readPixelColor(imageData, imageWidth, imageHeight, x0, y0);
const pixel2 = readPixelColor(imageData, imageWidth, imageHeight, x1, y0);
const pixel3 = readPixelColor(imageData, imageWidth, imageHeight, x0, y1);
const pixel4 = readPixelColor(imageData, imageWidth, imageHeight, x1, y1);
function bilinear(val1: number, val2: number, val3: number, val4: number) {
return Math.round(
val1 * dx1 * dy1 +
val2 * dx * dy1 +
val3 * dx1 * dy +
val4 * dx * dy,
);
}
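// The four weights (dx1·dy1, dx·dy1, dx1·dy, dx·dy) sum to 1, so each channel
// is a convex combination of the four surrounding pixels.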
// Interpolate the pixel values
const red = bilinear(pixel1.r, pixel2.r, pixel3.r, pixel4.r);
const green = bilinear(pixel1.g, pixel2.g, pixel3.g, pixel4.g);
const blue = bilinear(pixel1.b, pixel2.b, pixel3.b, pixel4.b);
return { r: red, g: green, b: blue };
}
export function warpAffineFloat32List(
imageBitmap: ImageBitmap,
faceAlignment: FaceAlignment,
faceSize: number,
inputData: Float32Array,
inputStartIndex: number,
): void {
// Get the pixel data
const offscreenCanvas = new OffscreenCanvas(
imageBitmap.width,
imageBitmap.height,
);
const ctx = offscreenCanvas.getContext("2d");
ctx.drawImage(imageBitmap, 0, 0, imageBitmap.width, imageBitmap.height);
const imageData = ctx.getImageData(
0,
0,
imageBitmap.width,
imageBitmap.height,
);
const pixelData = imageData.data;
const transformationMatrix = faceAlignment.affineMatrix.map((row) =>
row.map((val) => (val != 1.0 ? val * faceSize : 1.0)),
); // 3x3
const A: Matrix = new Matrix([
[transformationMatrix[0][0], transformationMatrix[0][1]],
[transformationMatrix[1][0], transformationMatrix[1][1]],
]);
const Ainverse = inverse(A);
const b00 = transformationMatrix[0][2];
const b10 = transformationMatrix[1][2];
const a00Prime = Ainverse.get(0, 0);
const a01Prime = Ainverse.get(0, 1);
const a10Prime = Ainverse.get(1, 0);
const a11Prime = Ainverse.get(1, 1);
for (let yTrans = 0; yTrans < faceSize; ++yTrans) {
for (let xTrans = 0; xTrans < faceSize; ++xTrans) {
// Perform inverse affine transformation
const xOrigin =
a00Prime * (xTrans - b00) + a01Prime * (yTrans - b10);
const yOrigin =
a10Prime * (xTrans - b00) + a11Prime * (yTrans - b10);
// Get the pixel from interpolation
const pixel = getPixelBicubic(
xOrigin,
yOrigin,
pixelData,
imageBitmap.width,
imageBitmap.height,
);
// Set the pixel in the input data
const index = (yTrans * faceSize + xTrans) * 3;
inputData[inputStartIndex + index] =
normalizePixelBetweenMinus1And1(pixel.r);
inputData[inputStartIndex + index + 1] =
normalizePixelBetweenMinus1And1(pixel.g);
inputData[inputStartIndex + index + 2] =
normalizePixelBetweenMinus1And1(pixel.b);
}
}
}
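// In effect, for every destination pixel (xTrans, yTrans) the source location
// is found via the inverse affine map
//     [xOrigin, yOrigin]^T = A^(-1) * ([xTrans, yTrans]^T - b),
// the source image is sampled bicubically at that location, and the RGB values
// are written to inputData normalized to the [-1, 1] range.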
export function createGrayscaleIntMatrixFromNormalized2List(
imageList: Float32Array,
faceNumber: number,
width: number = 112,
height: number = 112,
): number[][] {
const startIndex = faceNumber * width * height * 3;
return Array.from({ length: height }, (_, y) =>
Array.from({ length: width }, (_, x) => {
// 0.299 ∙ Red + 0.587 ∙ Green + 0.114 ∙ Blue
const pixelIndex = startIndex + 3 * (y * width + x);
return clamp(
Math.round(
0.299 *
unnormalizePixelFromBetweenMinus1And1(
imageList[pixelIndex],
) +
0.587 *
unnormalizePixelFromBetweenMinus1And1(
imageList[pixelIndex + 1],
) +
0.114 *
unnormalizePixelFromBetweenMinus1And1(
imageList[pixelIndex + 2],
),
),
0,
255,
);
}),
);
}
export function resizeToSquare(img: ImageBitmap, size: number) {
const scale = size / Math.max(img.height, img.width);
const width = scale * img.width;
const height = scale * img.height;
const offscreen = new OffscreenCanvas(size, size);
const ctx = offscreen.getContext("2d");
ctx.imageSmoothingQuality = "high";
ctx.drawImage(img, 0, 0, width, height);
const resizedImage = offscreen.transferToImageBitmap();
return { image: resizedImage, width, height };
}
export function transform(
imageBitmap: ImageBitmap,
affineMat: number[][],
outputWidth: number,
outputHeight: number,
) {
const offscreen = new OffscreenCanvas(outputWidth, outputHeight);
const context = offscreen.getContext("2d");
context.imageSmoothingQuality = "high";
context.transform(
affineMat[0][0],
affineMat[1][0],
affineMat[0][1],
affineMat[1][1],
affineMat[0][2],
affineMat[1][2],
);
context.drawImage(imageBitmap, 0, 0);
return offscreen.transferToImageBitmap();
}
export function crop(imageBitmap: ImageBitmap, cropBox: Box, size: number) {
const dimensions: Dimensions = {
width: size,
height: size,
};
return cropWithRotation(imageBitmap, cropBox, 0, dimensions, dimensions);
}
export function cropWithRotation(
imageBitmap: ImageBitmap,
cropBox: Box,
rotation?: number,
maxSize?: Dimensions,
minSize?: Dimensions,
) {
const box = cropBox.round();
const outputSize = { width: box.width, height: box.height };
if (maxSize) {
const minScale = Math.min(
maxSize.width / box.width,
maxSize.height / box.height,
);
if (minScale < 1) {
outputSize.width = Math.round(minScale * box.width);
outputSize.height = Math.round(minScale * box.height);
}
}
if (minSize) {
const maxScale = Math.max(
minSize.width / box.width,
minSize.height / box.height,
);
if (maxScale > 1) {
outputSize.width = Math.round(maxScale * box.width);
outputSize.height = Math.round(maxScale * box.height);
}
}
// log.info({ imageBitmap, box, outputSize });
const offscreen = new OffscreenCanvas(outputSize.width, outputSize.height);
const offscreenCtx = offscreen.getContext("2d");
offscreenCtx.imageSmoothingQuality = "high";
offscreenCtx.translate(outputSize.width / 2, outputSize.height / 2);
rotation && offscreenCtx.rotate(rotation);
const outputBox = new Box({
x: -outputSize.width / 2,
y: -outputSize.height / 2,
width: outputSize.width,
height: outputSize.height,
});
const enlargedBox = enlargeBox(box, 1.5);
const enlargedOutputBox = enlargeBox(outputBox, 1.5);
offscreenCtx.drawImage(
imageBitmap,
enlargedBox.x,
enlargedBox.y,
enlargedBox.width,
enlargedBox.height,
enlargedOutputBox.x,
enlargedOutputBox.y,
enlargedOutputBox.width,
enlargedOutputBox.height,
);
return offscreen.transferToImageBitmap();
}
export function addPadding(image: ImageBitmap, padding: number) {
const scale = 1 + padding * 2;
const width = scale * image.width;
const height = scale * image.height;
const offscreen = new OffscreenCanvas(width, height);
const ctx = offscreen.getContext("2d");
ctx.imageSmoothingEnabled = false;
ctx.drawImage(
image,
width / 2 - image.width / 2,
height / 2 - image.height / 2,
image.width,
image.height,
);
return offscreen.transferToImageBitmap();
}
export interface BlobOptions {
type?: string;
quality?: number;
}
export async function imageBitmapToBlob(imageBitmap: ImageBitmap) {
const offscreen = new OffscreenCanvas(
imageBitmap.width,
imageBitmap.height,
);
offscreen.getContext("2d").drawImage(imageBitmap, 0, 0);
return offscreen.convertToBlob({
type: "image/jpeg",
quality: 0.8,
});
}
export async function imageBitmapFromBlob(blob: Blob) {
return createImageBitmap(blob);
}

View file

@ -82,7 +82,7 @@ const ffmpegExec = async (
const result = ffmpeg.FS("readFile", outputPath);
const ms = Math.round(Date.now() - startTime);
const ms = Date.now() - startTime;
log.debug(() => `[wasm] ffmpeg ${cmd.join(" ")} (${ms} ms)`);
return result;
} finally {

View file

@ -1,4 +1,4 @@
import { clearCaches } from "@/next/blob-cache";
import { clearBlobCaches } from "@/next/blob-cache";
import log from "@/next/log";
import InMemoryStore from "@ente/shared/storage/InMemoryStore";
import localForage from "@ente/shared/storage/localForage";
@ -43,7 +43,7 @@ export const accountLogout = async () => {
log.error("Ignoring error during logout (local forage)", e);
}
try {
await clearCaches();
await clearBlobCaches();
} catch (e) {
log.error("Ignoring error during logout (cache)", e);
}

View file

@ -20,8 +20,8 @@ export type BlobCacheNamespace = (typeof blobCacheNames)[number];
*
* This cache is suitable for storing large amounts of data (entire files).
*
* To obtain a cache for a given namespace, use {@link openCache}. To clear all
* cached data (e.g. during logout), use {@link clearCaches}.
* To obtain a cache for a given namespace, use {@link openBlobCache}. To clear all
* cached data (e.g. during logout), use {@link clearBlobCaches}.
*
* [Note: Caching files]
*
@ -69,14 +69,31 @@ export interface BlobCache {
delete: (key: string) => Promise<boolean>;
}
const cachedCaches = new Map<BlobCacheNamespace, BlobCache>();
/**
* Return the {@link BlobCache} corresponding to the given {@link name}.
*
* This is a wrapper over {@link openBlobCache} that caches (pun intended) the
* cache and returns the same one each time it is called with the same name.
* It'll open the cache lazily the first time it is invoked.
*/
export const blobCache = async (
name: BlobCacheNamespace,
): Promise<BlobCache> => {
let c = cachedCaches.get(name);
if (!c) cachedCaches.set(name, (c = await openBlobCache(name)));
return c;
};
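// A minimal usage sketch, assuming the interface also exposes a put
// counterpart to get (the namespace and key below are illustrative):
//
//     const cache = await blobCache("face-crops");
//     const cached = await cache.get(key);
//     if (!cached) await cache.put(key, await produceBlob());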
/**
* Create a new {@link BlobCache} corresponding to the given {@link name}.
*
* @param name One of the arbitrary but predefined namespaces of type
* {@link BlobCacheNamespace} which group related data and allow us to use the
* same key across namespaces.
*/
export const openCache = async (
export const openBlobCache = async (
name: BlobCacheNamespace,
): Promise<BlobCache> =>
isElectron() ? openOPFSCacheWeb(name) : openWebCache(name);
@ -194,7 +211,7 @@ export const cachedOrNew = async (
key: string,
get: () => Promise<Blob>,
): Promise<Blob> => {
const cache = await openCache(cacheName);
const cache = await openBlobCache(cacheName);
const cachedBlob = await cache.get(key);
if (cachedBlob) return cachedBlob;
@ -204,15 +221,17 @@ export const cachedOrNew = async (
};
/**
* Delete all cached data.
* Delete all cached data, including our in-memory map of already opened caches.
*
* Meant for use during logout, to reset the state of the user's account.
*/
export const clearCaches = async () =>
isElectron() ? clearOPFSCaches() : clearWebCaches();
export const clearBlobCaches = async () => {
cachedCaches.clear();
return isElectron() ? clearOPFSCaches() : clearWebCaches();
};
const clearWebCaches = async () => {
await Promise.all(blobCacheNames.map((name) => caches.delete(name)));
await Promise.allSettled(blobCacheNames.map((name) => caches.delete(name)));
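// allSettled (not all) so that a failure to delete one namespace does not
// short-circuit the deletion of the remaining namespaces.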
};
const clearOPFSCaches = async () => {

View file

@ -297,7 +297,9 @@ export interface Electron {
*
* @returns A CLIP embedding.
*/
clipImageEmbedding: (jpegImageData: Uint8Array) => Promise<Float32Array>;
computeCLIPImageEmbedding: (
jpegImageData: Uint8Array,
) => Promise<Float32Array>;
/**
* Return a CLIP embedding of the given image if we already have the model
@ -319,7 +321,7 @@ export interface Electron {
*
* @returns A CLIP embedding.
*/
clipTextEmbeddingIfAvailable: (
computeCLIPTextEmbeddingIfAvailable: (
text: string,
) => Promise<Float32Array | undefined>;
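// A sketch of how these two embeddings might be combined for semantic search.
// Here `electron` is an object implementing this interface, and
// cosineSimilarity is a hypothetical helper, not part of the interface:
//
//     const imageEmbedding = await electron.computeCLIPImageEmbedding(jpegData);
//     const textEmbedding =
//         await electron.computeCLIPTextEmbeddingIfAvailable("a dog on a beach");
//     if (textEmbedding) {
//         const score = cosineSimilarity(imageEmbedding, textEmbedding);
//     }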
@ -337,29 +339,7 @@ export interface Electron {
* Both the input and output are opaque binary data whose internal structure
* is specific to our implementation and the model (MobileFaceNet) we use.
*/
faceEmbeddings: (input: Float32Array) => Promise<Float32Array>;
/**
* Return a face crop stored by a previous version of ML.
*
* [Note: Legacy face crops]
*
* Older versions of ML generated and stored face crops in a "face-crops"
* cache directory on the Electron side. For the time being, we have
* disabled the face search whilst we put finishing touches to it. However,
* it'll be nice to still show the existing faces that have been clustered
* for people who opted in to the older beta.
*
* So we retain the older "face-crops" disk cache, and use this method to
* serve faces from it when needed.
*
* @param faceID An identifier corresponding to which the face crop had been
* stored by the older version of our app.
*
* @returns the JPEG data of the face crop if a file is found for the given
* {@link faceID}, otherwise undefined.
*/
legacyFaceCrop: (faceID: string) => Promise<Uint8Array | undefined>;
computeFaceEmbeddings: (input: Float32Array) => Promise<Float32Array>;
// - Watch

View file

@ -47,8 +47,8 @@ const workerBridge = {
convertToJPEG: (imageData: Uint8Array) =>
ensureElectron().convertToJPEG(imageData),
detectFaces: (input: Float32Array) => ensureElectron().detectFaces(input),
faceEmbeddings: (input: Float32Array) =>
ensureElectron().faceEmbeddings(input),
computeFaceEmbeddings: (input: Float32Array) =>
ensureElectron().computeFaceEmbeddings(input),
};
export type WorkerBridge = typeof workerBridge;

View file

@ -14,6 +14,8 @@ export const CLIENT_PACKAGE_NAMES = new Map([
[APPS.ACCOUNTS, "io.ente.accounts.web"],
]);
export const clientPackageNamePhotosDesktop = "io.ente.photos.desktop";
export const APP_TITLES = new Map([
[APPS.ALBUMS, "Ente Albums"],
[APPS.PHOTOS, "Ente Photos"],

View file

@ -28,8 +28,8 @@ class HTTPService {
const responseData = response.data;
log.error(
`HTTP Service Error - ${JSON.stringify({
url: config.url,
method: config.method,
url: config?.url,
method: config?.method,
xRequestId: response.headers["x-request-id"],
httpStatus: response.status,
errMessage: responseData.message,

View file

@ -10,6 +10,10 @@ export const wait = (ms: number) =>
/**
* Await the given {@link promise} for {@link ms} milliseconds. If it does not
* resolve within {@link ms}, then reject with a timeout error.
*
* Note that this does not abort {@link promise} itself - it will still run to
* completion; its result will just be ignored if it resolves after we've
* already timed out.
*/
export const withTimeout = async <T>(promise: Promise<T>, ms: number) => {
let timeoutId: ReturnType<typeof setTimeout>;