[mob][photos] Use cosineDistanceSIMD
This commit is contained in:
parent
05a4e9f90b
commit
462d1d4854
|
@ -1,5 +1,18 @@
|
|||
import 'dart:math' show sqrt;
|
||||
|
||||
import "package:ml_linalg/vector.dart";
|
||||
|
||||
/// Calculates the cosine distance between two embeddings/vectors using SIMD from ml_linalg
|
||||
///
|
||||
/// WARNING: This assumes both vectors are already normalized!
|
||||
double cosineDistanceSIMD(Vector vector1, Vector vector2) {
|
||||
if (vector1.length != vector2.length) {
|
||||
throw ArgumentError('Vectors must be the same length');
|
||||
}
|
||||
|
||||
return 1 - vector1.dot(vector2);
|
||||
}
|
||||
|
||||
/// Calculates the cosine distance between two embeddings/vectors.
|
||||
///
|
||||
/// Throws an ArgumentError if the vectors are of different lengths or
|
||||
|
|
|
@ -560,10 +560,10 @@ class FaceClusteringService {
|
|||
for (int j = i - 1; j >= 0; j--) {
|
||||
late double distance;
|
||||
if (sortedFaceInfos[i].vEmbedding != null) {
|
||||
distance = 1.0 -
|
||||
sortedFaceInfos[i]
|
||||
.vEmbedding!
|
||||
.dot(sortedFaceInfos[j].vEmbedding!);
|
||||
distance = cosineDistanceSIMD(
|
||||
sortedFaceInfos[i].vEmbedding!,
|
||||
sortedFaceInfos[j].vEmbedding!,
|
||||
);
|
||||
} else {
|
||||
distance = cosineDistForNormVectors(
|
||||
sortedFaceInfos[i].embedding!,
|
||||
|
@ -804,8 +804,10 @@ class FaceClusteringService {
|
|||
double closestDistance = double.infinity;
|
||||
for (int j = 0; j < totalFaces; j++) {
|
||||
if (i == j) continue;
|
||||
final double distance =
|
||||
1.0 - faceInfos[i].vEmbedding!.dot(faceInfos[j].vEmbedding!);
|
||||
final double distance = cosineDistanceSIMD(
|
||||
faceInfos[i].vEmbedding!,
|
||||
faceInfos[j].vEmbedding!,
|
||||
);
|
||||
if (distance < closestDistance) {
|
||||
closestDistance = distance;
|
||||
closestIdx = j;
|
||||
|
@ -855,10 +857,10 @@ class FaceClusteringService {
|
|||
for (int i = 0; i < clusterIds.length; i++) {
|
||||
for (int j = 0; j < clusterIds.length; j++) {
|
||||
if (i == j) continue;
|
||||
final double newDistance = 1.0 -
|
||||
clusterIdToMeanEmbeddingAndWeight[clusterIds[i]]!.$1.dot(
|
||||
clusterIdToMeanEmbeddingAndWeight[clusterIds[j]]!.$1,
|
||||
);
|
||||
final double newDistance = cosineDistanceSIMD(
|
||||
clusterIdToMeanEmbeddingAndWeight[clusterIds[i]]!.$1,
|
||||
clusterIdToMeanEmbeddingAndWeight[clusterIds[j]]!.$1,
|
||||
);
|
||||
if (newDistance < distance) {
|
||||
distance = newDistance;
|
||||
clusterIDsToMerge = (clusterIds[i], clusterIds[j]);
|
||||
|
|
|
@ -12,6 +12,7 @@ import "package:photos/face/db.dart";
|
|||
import "package:photos/face/model/person.dart";
|
||||
import "package:photos/generated/protos/ente/common/vector.pb.dart";
|
||||
import "package:photos/models/file/file.dart";
|
||||
import "package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart";
|
||||
import "package:photos/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart";
|
||||
import "package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart";
|
||||
import "package:photos/services/machine_learning/face_ml/face_ml_result.dart";
|
||||
|
@ -594,7 +595,7 @@ class ClusterFeedbackService {
|
|||
final List<double> distances = [];
|
||||
for (final otherEmbedding in sampledOtherEmbeddings) {
|
||||
for (final embedding in sampledEmbeddings) {
|
||||
distances.add(1 - embedding.dot(otherEmbedding));
|
||||
distances.add(cosineDistanceSIMD(embedding,otherEmbedding));
|
||||
}
|
||||
}
|
||||
distances.sort();
|
||||
|
@ -799,7 +800,7 @@ class ClusterFeedbackService {
|
|||
continue;
|
||||
}
|
||||
final Vector avg = clusterAvg[personCluster]!;
|
||||
final distance = 1 - avg.dot(otherAvg);
|
||||
final distance = cosineDistanceSIMD(avg,otherAvg);
|
||||
if (distance < maxClusterDistance) {
|
||||
if (minDistance == null || distance < minDistance) {
|
||||
minDistance = distance;
|
||||
|
@ -950,7 +951,7 @@ class ClusterFeedbackService {
|
|||
final fileIdToDistanceMap = {};
|
||||
for (final entry in faceIdToVectorMap.entries) {
|
||||
fileIdToDistanceMap[getFileIdFromFaceId(entry.key)] =
|
||||
1 - personAvg.dot(entry.value);
|
||||
cosineDistanceSIMD(personAvg,entry.value);
|
||||
}
|
||||
w?.log('calculated distances for cluster $clusterID');
|
||||
suggestion.filesInCluster.sort((b, a) {
|
||||
|
|
|
@ -207,14 +207,14 @@ class _AppBarWidgetState extends State<ClusterAppBar> {
|
|||
if (embedding.key == otherEmbedding.key) {
|
||||
continue;
|
||||
}
|
||||
final distance64 = 1.0 -
|
||||
Vector.fromList(embedding.value, dtype: DType.float64).dot(
|
||||
Vector.fromList(otherEmbedding.value, dtype: DType.float64),
|
||||
);
|
||||
final distance32 = 1.0 -
|
||||
Vector.fromList(embedding.value, dtype: DType.float32).dot(
|
||||
Vector.fromList(otherEmbedding.value, dtype: DType.float32),
|
||||
);
|
||||
final distance64 = cosineDistanceSIMD(
|
||||
Vector.fromList(embedding.value, dtype: DType.float64),
|
||||
Vector.fromList(otherEmbedding.value, dtype: DType.float64),
|
||||
);
|
||||
final distance32 = cosineDistanceSIMD(
|
||||
Vector.fromList(embedding.value, dtype: DType.float32),
|
||||
Vector.fromList(otherEmbedding.value, dtype: DType.float32),
|
||||
);
|
||||
final distance = cosineDistForNormVectors(
|
||||
embedding.value,
|
||||
otherEmbedding.value,
|
||||
|
|
Loading…
Reference in a new issue