[web] ML prune todos (#1791)

Manav Rathi 2024-05-21 11:56:13 +05:30 committed by GitHub
commit 4dbc8ab31e
6 changed files with 105 additions and 79 deletions

View file

@@ -27,7 +27,7 @@
"leaflet-defaulticon-compatibility": "^0.1.1",
"localforage": "^1.9.0",
"memoize-one": "^6.0.0",
"ml-matrix": "^6.10.4",
"ml-matrix": "^6.11",
"otpauth": "^9.0.2",
"p-debounce": "^4.0.0",
"p-queue": "^7.1.0",
@@ -42,7 +42,7 @@
"react-window": "^1.8.6",
"sanitize-filename": "^1.6.3",
"similarity-transformation": "^0.0.1",
"transformation-matrix": "^2.15.0",
"transformation-matrix": "^2.16",
"uuid": "^9.0.1",
"vscode-uri": "^3.0.7",
"xml-js": "^1.6.11",

View file

@@ -2,7 +2,6 @@ import { FILE_TYPE } from "@/media/file-type";
import { blobCache } from "@/next/blob-cache";
import log from "@/next/log";
import { workerBridge } from "@/next/worker/worker-bridge";
import { euclidean } from "hdbscan";
import { Matrix } from "ml-matrix";
import {
Box,
@@ -19,6 +18,13 @@ import type {
} from "services/face/types";
import { defaultMLVersion } from "services/machineLearning/machineLearningService";
import { getSimilarityTransformation } from "similarity-transformation";
import {
Matrix as TransformationMatrix,
applyToPoint,
compose,
scale,
translate,
} from "transformation-matrix";
import type { EnteFile } from "types/file";
import { fetchImageBitmap, getLocalFileImageBitmap } from "./file";
import {
@@ -27,13 +33,7 @@ import {
pixelRGBBilinear,
warpAffineFloat32List,
} from "./image";
import {
Matrix as transformMatrix,
applyToPoint,
compose,
scale,
translate,
} from "transformation-matrix";
/**
* Index faces in the given file.
*
@@ -221,8 +221,16 @@ const convertToYOLOInputFloat32ChannelsFirst = (imageBitmap: ImageBitmap) => {
*
* Only detections that exceed a minimum score are returned.
*
* @param rows A Float32Array of shape [25200, 16], where each row
* represents a bounding box.
* @param rows A Float32Array of shape [25200, 16], where each row represents a
* face detection.
*
* YOLO always produces a fixed number of detections, 25200, for any given
* input. Each detection is a "row" of 16 values containing the bounding box,
* score, and landmarks of the detection.
*
* We prune out detections with a score lower than our threshold. However, we
* will still be left with some overlapping detections of the same face: these
* we will deduplicate in {@link removeDuplicateDetections}.
*/
const filterExtractDetectionsFromYOLOOutput = (
rows: Float32Array,
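
To make the row layout concrete, here is a minimal, illustrative sketch (not
part of this commit) of how a single YOLO output row could be decoded. The
exact field order ([x, y, w, h, score, ...landmarks]) and the center-based box
convention are assumptions, not something this diff confirms.

// Illustrative only. Assumes each of the 25200 rows is laid out as
// [x, y, w, h, score, ...landmarks], with (x, y) being the box center.
const parseYOLORow = (rows: Float32Array, i: number, minScore: number) => {
    const row = rows.slice(i * 16, (i + 1) * 16);
    const score = row[4];
    if (score < minScore) return undefined; // pruned by the score threshold
    const [cx, cy, w, h] = row;
    return {
        box: { x: cx - w / 2, y: cy - h / 2, width: w, height: h },
        score,
        // The remaining values are assumed to be landmark (x, y) pairs.
        landmarks: Array.from(row.slice(5, 15)),
    };
};
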
@@ -288,18 +296,21 @@ const transformFaceDetections = (
}));
};
const boxTransformationMatrix = (inBox: Box, toBox: Box): transformMatrix =>
const boxTransformationMatrix = (
inBox: Box,
toBox: Box,
): TransformationMatrix =>
compose(
translate(toBox.x, toBox.y),
scale(toBox.width / inBox.width, toBox.height / inBox.height),
);
const transformPoint = (point: Point, transform: transformMatrix) => {
const transformPoint = (point: Point, transform: TransformationMatrix) => {
const txdPoint = applyToPoint(transform, point);
return new Point(txdPoint.x, txdPoint.y);
};
const transformBox = (box: Box, transform: transformMatrix) => {
const transformBox = (box: Box, transform: TransformationMatrix) => {
const topLeft = transformPoint(new Point(box.x, box.y), transform);
const bottomRight = transformPoint(
new Point(box.x + box.width, box.y + box.height),
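
As a usage illustration (not part of the diff): the composed translate + scale
matrix maps coordinates from one box's space into another's. The sketch below
assumes a hypothetical 640x640 model input size and an arbitrary image size.

// Illustrative only. Maps a point from a hypothetical 640x640 model input
// space back into the coordinate space of the original image, using the
// helpers defined above.
const modelInputBox = new Box({ x: 0, y: 0, width: 640, height: 640 });
const imageBox = new Box({ x: 0, y: 0, width: 4032, height: 3024 });
const toImageSpace = boxTransformationMatrix(modelInputBox, imageBox);
// The center of the model input maps to the center of the image.
const centerInImage = transformPoint(new Point(320, 320), toImageSpace);
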
@@ -315,17 +326,24 @@ const transformBox = (box: Box, transform: transformMatrix) => {
};
/**
* Remove overlapping faces from an array of face detections through non-maximum suppression algorithm.
* Remove overlapping faces from an array of face detections using the
* non-maximum suppression algorithm.
*
* This function sorts the detections by their probability in descending order, then iterates over them.
* This function sorts the detections by their probability in descending order,
* then iterates over them.
*
* For each detection, it calculates the Intersection over Union (IoU) with all other detections.
* For each detection, it calculates the Intersection over Union (IoU) with all
* other detections.
*
* If the IoU is greater than or equal to the specified threshold (`iouThreshold`), the other detection is considered overlapping and is removed.
* If the IoU is greater than or equal to the specified threshold
* (`iouThreshold`), the other detection is considered overlapping and is
* removed.
*
* @param detections - An array of face detections to remove overlapping faces from.
* @param detections - An array of face detections to remove overlapping faces
* from.
*
* @param iouThreshold - The minimum IoU between two detections for them to be considered overlapping.
* @param iouThreshold - The minimum IoU between two detections for them to be
* considered overlapping.
*
* @returns An array of face detections with overlapping faces removed
*/
@@ -333,13 +351,13 @@ const naiveNonMaxSuppression = (
detections: FaceDetection[],
iouThreshold: number,
): FaceDetection[] => {
// Sort the detections by score, the highest first
// Sort the detections by score, the highest first.
detections.sort((a, b) => b.probability - a.probability);
// Loop through the detections and calculate the IOU
// Loop through the detections and calculate the IOU.
for (let i = 0; i < detections.length - 1; i++) {
for (let j = i + 1; j < detections.length; j++) {
const iou = calculateIOU(detections[i], detections[j]);
const iou = intersectionOverUnion(detections[i], detections[j]);
if (iou >= iouThreshold) {
detections.splice(j, 1);
j--;
@@ -350,7 +368,7 @@ const naiveNonMaxSuppression = (
return detections;
};
const calculateIOU = (a: FaceDetection, b: FaceDetection): number => {
const intersectionOverUnion = (a: FaceDetection, b: FaceDetection): number => {
const intersectionMinX = Math.max(a.box.x, b.box.x);
const intersectionMinY = Math.max(a.box.y, b.box.y);
const intersectionMaxX = Math.min(
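
For reference, a generic intersection-over-union computation for two
axis-aligned boxes is sketched below. It is illustrative, and not necessarily
identical to the elided body of intersectionOverUnion above.

// Illustrative only: IoU of two axis-aligned { x, y, width, height } boxes.
const iou = (a: Box, b: Box): number => {
    const minX = Math.max(a.x, b.x);
    const minY = Math.max(a.y, b.y);
    const maxX = Math.min(a.x + a.width, b.x + b.width);
    const maxY = Math.min(a.y + a.height, b.y + b.height);
    const intersection = Math.max(0, maxX - minX) * Math.max(0, maxY - minY);
    const union = a.width * a.height + b.width * b.height - intersection;
    return union > 0 ? intersection / union : 0;
};
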
@@ -453,12 +471,15 @@ const faceAlignmentUsingSimilarityTransform = (
const meanTranslation = simTransform.toMean.sub(0.5).mul(size);
const centerMat = simTransform.fromMean.sub(meanTranslation);
const center = new Point(centerMat.get(0, 0), centerMat.get(1, 0));
const rotation = -Math.atan2(
simTransform.rotation.get(0, 1),
simTransform.rotation.get(0, 0),
);
return { affineMatrix, center, size, rotation };
const boundingBox = new Box({
x: center.x - size / 2,
y: center.y - size / 2,
width: size,
height: size,
});
return { affineMatrix, boundingBox };
};
const convertToMobileFaceNetInput = (
@@ -733,33 +754,22 @@ const extractFaceCrop = (
imageBitmap: ImageBitmap,
alignment: FaceAlignment,
): ImageBitmap => {
const alignmentBox = new Box({
x: alignment.center.x - alignment.size / 2,
y: alignment.center.y - alignment.size / 2,
width: alignment.size,
height: alignment.size,
});
// TODO-ML: This algorithm is different from what is used by the mobile app.
// Also, it needs to be something that can work fully using the embedding we
// receive from remote - the `alignment.boundingBox` will not be available
// to us in such cases.
const paddedBox = roundBox(enlargeBox(alignment.boundingBox, 1.5));
const outputSize = { width: paddedBox.width, height: paddedBox.height };
const padding = 0.25;
const scaleForPadding = 1 + padding * 2;
const paddedBox = roundBox(enlargeBox(alignmentBox, scaleForPadding));
const maxDimension = 256;
const scale = Math.min(
maxDimension / paddedBox.width,
maxDimension / paddedBox.height,
);
return cropImage(imageBitmap, paddedBox, 256);
};
const cropImage = (
imageBitmap: ImageBitmap,
cropBox: Box,
maxDimension: number,
) => {
const box = roundBox(cropBox);
const outputSize = { width: box.width, height: box.height };
const scale = Math.min(maxDimension / box.width, maxDimension / box.height);
if (scale < 1) {
outputSize.width = Math.round(scale * box.width);
outputSize.height = Math.round(scale * box.height);
outputSize.width = Math.round(scale * paddedBox.width);
outputSize.height = Math.round(scale * paddedBox.height);
}
const offscreen = new OffscreenCanvas(outputSize.width, outputSize.height);
@@ -775,7 +785,7 @@ const cropImage = (
height: outputSize.height,
});
const enlargedBox = enlargeBox(box, 1.5);
const enlargedBox = enlargeBox(paddedBox, 1.5);
const enlargedOutputBox = enlargeBox(outputBox, 1.5);
offscreenCtx.drawImage(

View file

@@ -20,16 +20,18 @@ export const putFaceEmbedding = async (
const comlinkCryptoWorker = await ComlinkCryptoWorker.getInstance();
const { file: encryptedEmbeddingData } =
await comlinkCryptoWorker.encryptMetadata(serverMl, enteFile.key);
log.info(
`putEmbedding embedding to server for file: ${enteFile.metadata.title} fileID: ${enteFile.id}`,
);
const res = await putEmbedding({
// TODO-ML(MR): Do we need any of these fields
// log.info(
// `putEmbedding embedding to server for file: ${enteFile.metadata.title} fileID: ${enteFile.id}`,
// );
/*const res =*/ await putEmbedding({
fileID: enteFile.id,
encryptedEmbedding: encryptedEmbeddingData.encryptedData,
decryptionHeader: encryptedEmbeddingData.decryptionHeader,
model: "file-ml-clip-face",
});
log.info("putEmbedding response: ", res);
// TODO-ML(MR): Do we need any of these fields
// log.info("putEmbedding response: ", res);
};
export interface FileML extends ServerFileMl {

View file

@@ -8,13 +8,20 @@ export interface FaceDetection {
}
export interface FaceAlignment {
// TODO-ML(MR): remove affine matrix as rotation, size and center
// are simple to store and use, affine matrix adds complexity while getting crop
/**
* An affine transformation matrix (rotation, translation, scaling) to align
* the face extracted from the image.
*/
affineMatrix: number[][];
rotation: number;
// size and center is relative to image dimentions stored at mlFileData
size: number;
center: Point;
/**
* The bounding box of the transformed box.
*
* The affine transformation maps the original detection box to a new,
* transformed box (possibly rotated). This property is the bounding box of
* that transformed box, in the coordinate system of the original, full
* image on which the detection occurred.
*/
boundingBox: Box;
}
export interface Face {

View file

@@ -177,12 +177,19 @@ some cases.
## Face search
- [matrix](https://github.com/mljs/matrix) and
[similarity-transformation](https://github.com/shaileshpandit/similarity-transformation-js)
are used during face alignment.
- [transformation-matrix](https://github.com/chrvadala/transformation-matrix)
is used during face detection.
is used for performing 2D affine transformations using transformation
matrices. It is used during face detection.
- [matrix](https://github.com/mljs/matrix) is a mathematical matrix
abstraction. It is used along with
[similarity-transformation](https://github.com/shaileshpandit/similarity-transformation-js)
during face alignment (a short usage sketch of both matrix libraries follows
this list).
> Note that while both `transformation-matrix` and `matrix` are "matrix"
> libraries, they have different foci and purposes: `transformation-matrix`
> provides affine transforms, while `matrix` is for performing computations
> on matrices, say inverting them or performing their decomposition.
- [hdbscan](https://github.com/shaileshpandit/hdbscan-js) is used for face
clustering.
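
To make the distinction concrete, here is a minimal, illustrative sketch (not
from the codebase) showing the two matrix libraries side by side:

import { inverse, Matrix } from "ml-matrix";
import { applyToPoint, compose, scale, translate } from "transformation-matrix";

// transformation-matrix: build a 2D affine transform and apply it to a point.
const toImageSpace = compose(translate(10, 20), scale(2, 2));
const p = applyToPoint(toImageSpace, { x: 1, y: 1 }); // { x: 12, y: 22 }

// matrix (ml-matrix): general numeric matrices and linear algebra operations.
const m = new Matrix([
    [2, 0],
    [0, 4],
]);
const mInverse = inverse(m); // [[0.5, 0], [0, 0.25]]
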

View file

@@ -3528,7 +3528,7 @@ ml-array-rescale@^1.3.7:
ml-array-max "^1.2.4"
ml-array-min "^1.2.3"
ml-matrix@^6.10.4:
ml-matrix@^6.11:
version "6.11.0"
resolved "https://registry.yarnpkg.com/ml-matrix/-/ml-matrix-6.11.0.tgz#3cf2260ef04cbb8e0e0425e71d200f5cbcf82772"
integrity sha512-7jr9NmFRkaUxbKslfRu3aZOjJd2LkSitCGv+QH9PF0eJoEG7jIpjXra1Vw8/kgao8+kHCSsJONG6vfWmXQ+/Eg==
@@ -4628,7 +4628,7 @@ tr46@~0.0.3:
resolved "https://registry.yarnpkg.com/tr46/-/tr46-0.0.3.tgz#8184fd347dac9cdc185992f3a6622e14b9d9ab6a"
integrity sha512-N3WMsuqV66lT30CrXNbEjx4GEwlow3v6rr4mCcv6prnfwhS01rkgyFdjPNBYd9br7LpXV1+Emh01fHnq2Gdgrw==
transformation-matrix@^2.15.0:
transformation-matrix@^2.16:
version "2.16.1"
resolved "https://registry.yarnpkg.com/transformation-matrix/-/transformation-matrix-2.16.1.tgz#4a2de06331b94ae953193d1b9a5ba002ec5f658a"
integrity sha512-tdtC3wxVEuzU7X/ydL131Q3JU5cPMEn37oqVLITjRDSDsnSHVFzW2JiCLfZLIQEgWzZHdSy3J6bZzvKEN24jGA==