[mob] Use ml_linalg Vector dot product for cosine distance in face clustering

This commit is contained in:
Neeraj Gupta 2024-04-02 11:53:40 +05:30
parent 4cb7334868
commit 226808aadb
3 changed files with 38 additions and 10 deletions

View file

@ -5,6 +5,8 @@ import "dart:math" show max;
import "dart:typed_data";
import "package:logging/logging.dart";
import "package:ml_linalg/dtype.dart";
import "package:ml_linalg/vector.dart";
import "package:photos/generated/protos/ente/common/vector.pb.dart";
import 'package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart';
import "package:photos/services/machine_learning/face_ml/face_ml_result.dart";
@ -12,14 +14,16 @@ import "package:synchronized/synchronized.dart";
class FaceInfo {
final String faceID;
final List<double> embedding;
final List<double>? embedding;
final Vector? vEmbedding;
int? clusterId;
String? closestFaceId;
int? closestDist;
int? fileCreationTime;
FaceInfo({
required this.faceID,
required this.embedding,
this.embedding,
this.vEmbedding,
this.clusterId,
this.fileCreationTime,
});
@ -230,7 +234,10 @@ class FaceLinearClustering {
faceInfos.add(
FaceInfo(
faceID: entry.key,
embedding: EVector.fromBuffer(entry.value.$2).values,
vEmbedding: Vector.fromList(
EVector.fromBuffer(entry.value.$2).values,
dtype: DType.float32,
),
clusterId: entry.value.$1,
fileCreationTime:
fileIDToCreationTime?[getFileIdFromFaceId(entry.key)],
@ -300,17 +307,25 @@ class FaceLinearClustering {
}
continue;
}
final currentEmbedding = sortedFaceInfos[i].embedding;
int closestIdx = -1;
double closestDistance = double.infinity;
if (i % 250 == 0) {
log("[ClusterIsolate] ${DateTime.now()} Processing $i faces");
}
for (int j = i - 1; j >= 0; j--) {
final double distance = cosineDistForNormVectors(
currentEmbedding,
sortedFaceInfos[j].embedding,
);
late double distance;
if (sortedFaceInfos[i].vEmbedding != null) {
distance = 1.0 -
sortedFaceInfos[i]
.vEmbedding!
.dot(sortedFaceInfos[j].vEmbedding!);
} else {
distance = cosineDistForNormVectors(
sortedFaceInfos[i].embedding!,
sortedFaceInfos[j].embedding!,
);
}
if (distance < closestDistance) {
closestDistance = distance;
closestIdx = j;
@ -339,7 +354,7 @@ class FaceLinearClustering {
stopwatchClustering.stop();
log(
' [ClusterIsolate] ${DateTime.now()} Clustering for ${sortedFaceInfos.length} embeddings (${sortedFaceInfos[0].embedding.length} size) executed in ${stopwatchClustering.elapsedMilliseconds}ms, clusters $clusterID',
' [ClusterIsolate] ${DateTime.now()} Clustering for ${sortedFaceInfos.length} embeddings executed in ${stopwatchClustering.elapsedMilliseconds}ms, clusters $clusterID',
);
// analyze the results

View file

@ -389,6 +389,10 @@ class ClusterFeedbackService {
if (ignoredClusters.contains(clusterID)) {
continue;
}
if (allClusterIdsToCountMap[clusterID]! < 2) {
continue;
}
late List<double> avg;
if (clusterToSummary[clusterID]?.$2 ==
allClusterIdsToCountMap[clusterID]) {
@ -412,6 +416,11 @@ class ClusterFeedbackService {
if (updatesForClusterSummary.length > 100) {
await faceMlDb.clusterSummaryUpdate(updatesForClusterSummary);
updatesForClusterSummary.clear();
if (kDebugMode) {
_logger.info(
'start getUpdateClusterAvg for ${allClusterIdsToCountMap.length} clusters',
);
}
}
clusterAvg[clusterID] = avg;
}
@ -441,6 +450,10 @@ class ClusterFeedbackService {
int? nearestPersonCluster;
double? minDistance;
for (final personCluster in personClusters) {
if (clusterAvg[personCluster] == null) {
_logger.info('no avg for cluster $personCluster');
continue;
}
final avg = clusterAvg[personCluster]!;
final distance = cosineDistForNormVectors(avg, otherAvg);
if (distance < maxClusterDistance) {

View file

@ -93,7 +93,7 @@ class PersonFaceWidget extends StatelessWidget {
);
if (face == null) {
debugPrint(
"No cover face for person: $personId and cluster $clusterID",
"No cover face for person: $personId and cluster $clusterID and recentFile ${file.uploadedFileID}",
);
return null;
}