From 226808aadb985440fc819a1e6c62b703a3e63f4c Mon Sep 17 00:00:00 2001 From: Neeraj Gupta <254676+ua741@users.noreply.github.com> Date: Tue, 2 Apr 2024 11:53:40 +0530 Subject: [PATCH] [mob] Use vector for cosine dist --- .../linear_clustering_service.dart | 33 ++++++++++++++----- .../face_ml/feedback/cluster_feedback.dart | 13 ++++++++ .../search/result/person_face_widget.dart | 2 +- 3 files changed, 38 insertions(+), 10 deletions(-) diff --git a/mobile/lib/services/machine_learning/face_ml/face_clustering/linear_clustering_service.dart b/mobile/lib/services/machine_learning/face_ml/face_clustering/linear_clustering_service.dart index 09dc39889..9f4e6b160 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_clustering/linear_clustering_service.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_clustering/linear_clustering_service.dart @@ -5,6 +5,8 @@ import "dart:math" show max; import "dart:typed_data"; import "package:logging/logging.dart"; +import "package:ml_linalg/dtype.dart"; +import "package:ml_linalg/vector.dart"; import "package:photos/generated/protos/ente/common/vector.pb.dart"; import 'package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart'; import "package:photos/services/machine_learning/face_ml/face_ml_result.dart"; @@ -12,14 +14,16 @@ import "package:synchronized/synchronized.dart"; class FaceInfo { final String faceID; - final List embedding; + final List? embedding; + final Vector? vEmbedding; int? clusterId; String? closestFaceId; int? closestDist; int? fileCreationTime; FaceInfo({ required this.faceID, - required this.embedding, + this.embedding, + this.vEmbedding, this.clusterId, this.fileCreationTime, }); @@ -230,7 +234,10 @@ class FaceLinearClustering { faceInfos.add( FaceInfo( faceID: entry.key, - embedding: EVector.fromBuffer(entry.value.$2).values, + vEmbedding: Vector.fromList( + EVector.fromBuffer(entry.value.$2).values, + dtype: DType.float32, + ), clusterId: entry.value.$1, fileCreationTime: fileIDToCreationTime?[getFileIdFromFaceId(entry.key)], @@ -300,17 +307,25 @@ class FaceLinearClustering { } continue; } - final currentEmbedding = sortedFaceInfos[i].embedding; + int closestIdx = -1; double closestDistance = double.infinity; if (i % 250 == 0) { log("[ClusterIsolate] ${DateTime.now()} Processing $i faces"); } for (int j = i - 1; j >= 0; j--) { - final double distance = cosineDistForNormVectors( - currentEmbedding, - sortedFaceInfos[j].embedding, - ); + late double distance; + if (sortedFaceInfos[i].vEmbedding != null) { + distance = 1.0 - + sortedFaceInfos[i] + .vEmbedding! + .dot(sortedFaceInfos[j].vEmbedding!); + } else { + distance = cosineDistForNormVectors( + sortedFaceInfos[i].embedding!, + sortedFaceInfos[j].embedding!, + ); + } if (distance < closestDistance) { closestDistance = distance; closestIdx = j; @@ -339,7 +354,7 @@ class FaceLinearClustering { stopwatchClustering.stop(); log( - ' [ClusterIsolate] ${DateTime.now()} Clustering for ${sortedFaceInfos.length} embeddings (${sortedFaceInfos[0].embedding.length} size) executed in ${stopwatchClustering.elapsedMilliseconds}ms, clusters $clusterID', + ' [ClusterIsolate] ${DateTime.now()} Clustering for ${sortedFaceInfos.length} embeddings executed in ${stopwatchClustering.elapsedMilliseconds}ms, clusters $clusterID', ); // analyze the results diff --git a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart index a3c1e89be..a1d8c650e 100644 --- a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart +++ b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart @@ -389,6 +389,10 @@ class ClusterFeedbackService { if (ignoredClusters.contains(clusterID)) { continue; } + if (allClusterIdsToCountMap[clusterID]! < 2) { + continue; + } + late List avg; if (clusterToSummary[clusterID]?.$2 == allClusterIdsToCountMap[clusterID]) { @@ -412,6 +416,11 @@ class ClusterFeedbackService { if (updatesForClusterSummary.length > 100) { await faceMlDb.clusterSummaryUpdate(updatesForClusterSummary); updatesForClusterSummary.clear(); + if (kDebugMode) { + _logger.info( + 'start getUpdateClusterAvg for ${allClusterIdsToCountMap.length} clusters', + ); + } } clusterAvg[clusterID] = avg; } @@ -441,6 +450,10 @@ class ClusterFeedbackService { int? nearestPersonCluster; double? minDistance; for (final personCluster in personClusters) { + if (clusterAvg[personCluster] == null) { + _logger.info('no avg for cluster $personCluster'); + continue; + } final avg = clusterAvg[personCluster]!; final distance = cosineDistForNormVectors(avg, otherAvg); if (distance < maxClusterDistance) { diff --git a/mobile/lib/ui/viewer/search/result/person_face_widget.dart b/mobile/lib/ui/viewer/search/result/person_face_widget.dart index 4aba9ae47..b2286dc65 100644 --- a/mobile/lib/ui/viewer/search/result/person_face_widget.dart +++ b/mobile/lib/ui/viewer/search/result/person_face_widget.dart @@ -93,7 +93,7 @@ class PersonFaceWidget extends StatelessWidget { ); if (face == null) { debugPrint( - "No cover face for person: $personId and cluster $clusterID", + "No cover face for person: $personId and cluster $clusterID and recentFile ${file.uploadedFileID}", ); return null; }