[mob][photos] Use cosineDistanceSIMD

This commit is contained in:
laurenspriem 2024-04-24 16:37:39 +05:30
parent 05a4e9f90b
commit 462d1d4854
4 changed files with 37 additions and 21 deletions

View file

@ -1,5 +1,18 @@
import 'dart:math' show sqrt;
import "package:ml_linalg/vector.dart";
/// Calculates the cosine distance between two embeddings/vectors using SIMD from ml_linalg
///
/// WARNING: This assumes both vectors are already normalized!
double cosineDistanceSIMD(Vector vector1, Vector vector2) {
if (vector1.length != vector2.length) {
throw ArgumentError('Vectors must be the same length');
}
return 1 - vector1.dot(vector2);
}
/// Calculates the cosine distance between two embeddings/vectors.
///
/// Throws an ArgumentError if the vectors are of different lengths or

View file

@ -560,10 +560,10 @@ class FaceClusteringService {
for (int j = i - 1; j >= 0; j--) {
late double distance;
if (sortedFaceInfos[i].vEmbedding != null) {
distance = 1.0 -
sortedFaceInfos[i]
.vEmbedding!
.dot(sortedFaceInfos[j].vEmbedding!);
distance = cosineDistanceSIMD(
sortedFaceInfos[i].vEmbedding!,
sortedFaceInfos[j].vEmbedding!,
);
} else {
distance = cosineDistForNormVectors(
sortedFaceInfos[i].embedding!,
@ -804,8 +804,10 @@ class FaceClusteringService {
double closestDistance = double.infinity;
for (int j = 0; j < totalFaces; j++) {
if (i == j) continue;
final double distance =
1.0 - faceInfos[i].vEmbedding!.dot(faceInfos[j].vEmbedding!);
final double distance = cosineDistanceSIMD(
faceInfos[i].vEmbedding!,
faceInfos[j].vEmbedding!,
);
if (distance < closestDistance) {
closestDistance = distance;
closestIdx = j;
@ -855,10 +857,10 @@ class FaceClusteringService {
for (int i = 0; i < clusterIds.length; i++) {
for (int j = 0; j < clusterIds.length; j++) {
if (i == j) continue;
final double newDistance = 1.0 -
clusterIdToMeanEmbeddingAndWeight[clusterIds[i]]!.$1.dot(
clusterIdToMeanEmbeddingAndWeight[clusterIds[j]]!.$1,
);
final double newDistance = cosineDistanceSIMD(
clusterIdToMeanEmbeddingAndWeight[clusterIds[i]]!.$1,
clusterIdToMeanEmbeddingAndWeight[clusterIds[j]]!.$1,
);
if (newDistance < distance) {
distance = newDistance;
clusterIDsToMerge = (clusterIds[i], clusterIds[j]);

View file

@ -12,6 +12,7 @@ import "package:photos/face/db.dart";
import "package:photos/face/model/person.dart";
import "package:photos/generated/protos/ente/common/vector.pb.dart";
import "package:photos/models/file/file.dart";
import "package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart";
import "package:photos/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart";
import "package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart";
import "package:photos/services/machine_learning/face_ml/face_ml_result.dart";
@ -594,7 +595,7 @@ class ClusterFeedbackService {
final List<double> distances = [];
for (final otherEmbedding in sampledOtherEmbeddings) {
for (final embedding in sampledEmbeddings) {
distances.add(1 - embedding.dot(otherEmbedding));
distances.add(cosineDistanceSIMD(embedding,otherEmbedding));
}
}
distances.sort();
@ -799,7 +800,7 @@ class ClusterFeedbackService {
continue;
}
final Vector avg = clusterAvg[personCluster]!;
final distance = 1 - avg.dot(otherAvg);
final distance = cosineDistanceSIMD(avg,otherAvg);
if (distance < maxClusterDistance) {
if (minDistance == null || distance < minDistance) {
minDistance = distance;
@ -950,7 +951,7 @@ class ClusterFeedbackService {
final fileIdToDistanceMap = {};
for (final entry in faceIdToVectorMap.entries) {
fileIdToDistanceMap[getFileIdFromFaceId(entry.key)] =
1 - personAvg.dot(entry.value);
cosineDistanceSIMD(personAvg,entry.value);
}
w?.log('calculated distances for cluster $clusterID');
suggestion.filesInCluster.sort((b, a) {

View file

@ -207,14 +207,14 @@ class _AppBarWidgetState extends State<ClusterAppBar> {
if (embedding.key == otherEmbedding.key) {
continue;
}
final distance64 = 1.0 -
Vector.fromList(embedding.value, dtype: DType.float64).dot(
Vector.fromList(otherEmbedding.value, dtype: DType.float64),
);
final distance32 = 1.0 -
Vector.fromList(embedding.value, dtype: DType.float32).dot(
Vector.fromList(otherEmbedding.value, dtype: DType.float32),
);
final distance64 = cosineDistanceSIMD(
Vector.fromList(embedding.value, dtype: DType.float64),
Vector.fromList(otherEmbedding.value, dtype: DType.float64),
);
final distance32 = cosineDistanceSIMD(
Vector.fromList(embedding.value, dtype: DType.float32),
Vector.fromList(otherEmbedding.value, dtype: DType.float32),
);
final distance = cosineDistForNormVectors(
embedding.value,
otherEmbedding.value,