From ba58ac1358983b81ef60e3d602daaa2d1594dfd0 Mon Sep 17 00:00:00 2001 From: laurenspriem Date: Thu, 18 Apr 2024 14:44:12 +0530 Subject: [PATCH] [mob] Add merges to predictComplete method --- .../face_clustering_service.dart | 91 +++++++++++++++++-- .../face_ml/feedback/cluster_feedback.dart | 26 +++++- 2 files changed, 108 insertions(+), 9 deletions(-) diff --git a/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart b/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart index 25259a0c0..2920cd760 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart @@ -249,6 +249,7 @@ class FaceClusteringService { Map input, { Map? fileIDToCreationTime, double distanceThreshold = kRecommendedDistanceThreshold, + double mergeThreshold = 0.30, }) async { if (input.isEmpty) { _logger.warning( @@ -270,6 +271,7 @@ class FaceClusteringService { "input": input, "fileIDToCreationTime": fileIDToCreationTime, "distanceThreshold": distanceThreshold, + "mergeThreshold": mergeThreshold, }, taskName: "createImageEmbedding", ) as Map; @@ -578,12 +580,11 @@ class FaceClusteringService { ); } - - static Map runCompleteClustering(Map args) { final input = args['input'] as Map; final fileIDToCreationTime = args['fileIDToCreationTime'] as Map?; final distanceThreshold = args['distanceThreshold'] as double; + final mergeThreshold = args['mergeThreshold'] as double; log( "[CompleteClustering] ${DateTime.now()} Copied to isolate ${input.length} faces for clustering", @@ -637,10 +638,11 @@ class FaceClusteringService { final Map newFaceIdToCluster = {}; final stopwatchClustering = Stopwatch()..start(); for (int i = 0; i < totalFaces; i++) { + if (faceInfos[i].clusterId != null) continue; int closestIdx = -1; double closestDistance = double.infinity; - if (i % 250 == 0) { - log("[CompleteClustering] ${DateTime.now()} Processed $i faces"); + if (i + 1 % 250 == 0) { + log("[CompleteClustering] ${DateTime.now()} Processed ${i - 1} faces"); } for (int j = 0; j < totalFaces; j++) { if (i == j) continue; @@ -656,18 +658,91 @@ class FaceClusteringService { if (faceInfos[closestIdx].clusterId == null) { clusterID++; faceInfos[closestIdx].clusterId = clusterID; - newFaceIdToCluster[faceInfos[closestIdx].faceID] = clusterID; } faceInfos[i].clusterId = faceInfos[closestIdx].clusterId!; - newFaceIdToCluster[faceInfos[i].faceID] = - faceInfos[closestIdx].clusterId!; } else { clusterID++; faceInfos[i].clusterId = clusterID; - newFaceIdToCluster[faceInfos[i].faceID] = clusterID; } } + // Now calculate the mean of the embeddings for each cluster + final Map> clusterIdToFaceInfos = {}; + for (final faceInfo in faceInfos) { + if (clusterIdToFaceInfos.containsKey(faceInfo.clusterId)) { + clusterIdToFaceInfos[faceInfo.clusterId]!.add(faceInfo); + } else { + clusterIdToFaceInfos[faceInfo.clusterId!] = [faceInfo]; + } + } + final Map clusterIdToMeanEmbeddingAndWeight = {}; + for (final clusterId in clusterIdToFaceInfos.keys) { + final List embeddings = clusterIdToFaceInfos[clusterId]! + .map((faceInfo) => faceInfo.vEmbedding!) + .toList(); + final count = clusterIdToFaceInfos[clusterId]!.length; + final Vector meanEmbedding = embeddings.reduce((a, b) => a + b) / count; + clusterIdToMeanEmbeddingAndWeight[clusterId] = (meanEmbedding, count); + } + + // Now merge the clusters that are close to each other, based on mean embedding + final List<(int, int)> mergedClustersList = []; + final List clusterIds = + clusterIdToMeanEmbeddingAndWeight.keys.toList(); + log(' [CompleteClustering] ${DateTime.now()} ${clusterIds.length} clusters found, now checking for merges'); + while (true) { + if (clusterIds.length < 2) break; + double distance = double.infinity; + (int, int) clusterIDsToMerge = (-1, -1); + for (int i = 0; i < clusterIds.length; i++) { + for (int j = 0; j < clusterIds.length; j++) { + if (i == j) continue; + final double newDistance = 1.0 - + clusterIdToMeanEmbeddingAndWeight[clusterIds[i]]!.$1.dot( + clusterIdToMeanEmbeddingAndWeight[clusterIds[j]]!.$1, + ); + if (newDistance < distance) { + distance = newDistance; + clusterIDsToMerge = (clusterIds[i], clusterIds[j]); + } + } + } + if (distance < mergeThreshold) { + mergedClustersList.add(clusterIDsToMerge); + final clusterID1 = clusterIDsToMerge.$1; + final clusterID2 = clusterIDsToMerge.$2; + final mean1 = clusterIdToMeanEmbeddingAndWeight[clusterID1]!.$1; + final mean2 = clusterIdToMeanEmbeddingAndWeight[clusterID2]!.$1; + final count1 = clusterIdToMeanEmbeddingAndWeight[clusterID1]!.$2; + final count2 = clusterIdToMeanEmbeddingAndWeight[clusterID2]!.$2; + final weight1 = count1 / (count1 + count2); + final weight2 = count2 / (count1 + count2); + clusterIdToMeanEmbeddingAndWeight[clusterID1] = ( + mean1 * weight1 + mean2 * weight2, + count1 + count2, + ); + clusterIdToMeanEmbeddingAndWeight.remove(clusterID2); + clusterIds.remove(clusterID2); + } else { + break; + } + } + log(' [CompleteClustering] ${DateTime.now()} ${mergedClustersList.length} clusters merged'); + + // Now assign the new clusterId to the faces + for (final faceInfo in faceInfos) { + for (final mergedClusters in mergedClustersList) { + if (faceInfo.clusterId == mergedClusters.$2) { + faceInfo.clusterId = mergedClusters.$1; + } + } + } + + // Finally, assign the new clusterId to the faces + for (final faceInfo in faceInfos) { + newFaceIdToCluster[faceInfo.faceID] = faceInfo.clusterId!; + } + stopwatchClustering.stop(); log( ' [CompleteClustering] ${DateTime.now()} Clustering for ${faceInfos.length} embeddings executed in ${stopwatchClustering.elapsedMilliseconds}ms', diff --git a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart index fe64218ec..e3edccc51 100644 --- a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart +++ b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart @@ -13,6 +13,8 @@ import "package:photos/generated/protos/ente/common/vector.pb.dart"; import "package:photos/models/file/file.dart"; import 'package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart'; import "package:photos/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart"; +// import "package:photos/services/machine_learning/face_ml/face_clustering/face_info_for_clustering.dart"; +// import "package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart"; import "package:photos/services/machine_learning/face_ml/face_ml_result.dart"; import "package:photos/services/machine_learning/face_ml/person/person_service.dart"; import "package:photos/services/search_service.dart"; @@ -232,14 +234,36 @@ class ClusterFeedbackService { maxClusterID++; } } else { + // final clusteringInput = embeddings + // .map((key, value) { + // return MapEntry( + // key, + // FaceInfoForClustering( + // faceID: key, + // embeddingBytes: value, + // faceScore: kMinHighQualityFaceScore + 0.01, + // blurValue: kLapacianDefault, + // ), + // ); + // }) + // .values + // .toSet(); + // final faceIdToCluster = + // await FaceClusteringService.instance.predictLinear( + // clusteringInput, + // fileIDToCreationTime: fileIDToCreationTime, + // distanceThreshold: 0.23, + // useDynamicThreshold: false, + // ); final faceIdToCluster = await FaceClusteringService.instance.predictComplete( embeddings, fileIDToCreationTime: fileIDToCreationTime, distanceThreshold: 0.30, + mergeThreshold: 0.30, ); - if (faceIdToCluster.isEmpty) { + if (faceIdToCluster == null || faceIdToCluster.isEmpty) { _logger.info('No clusters found'); return {}; } else {