diff --git a/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart b/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart index 3219fe7f9..c84059e39 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart @@ -854,7 +854,10 @@ class FaceClusteringService { .toList(); final count = clusterIdToFaceInfos[clusterId]!.length; final Vector meanEmbedding = embeddings.reduce((a, b) => a + b) / count; - clusterIdToMeanEmbeddingAndWeight[clusterId] = (meanEmbedding, count); + final Vector meanEmbeddingNormalized = + meanEmbedding / meanEmbedding.norm(); + clusterIdToMeanEmbeddingAndWeight[clusterId] = + (meanEmbeddingNormalized, count); } // Now merge the clusters that are close to each other, based on mean embedding @@ -889,8 +892,10 @@ class FaceClusteringService { final count2 = clusterIdToMeanEmbeddingAndWeight[clusterID2]!.$2; final weight1 = count1 / (count1 + count2); final weight2 = count2 / (count1 + count2); + final weightedMean = mean1 * weight1 + mean2 * weight2; + final weightedMeanNormalized = weightedMean / weightedMean.norm(); clusterIdToMeanEmbeddingAndWeight[clusterID1] = ( - mean1 * weight1 + mean2 * weight2, + weightedMeanNormalized, count1 + count2, ); clusterIdToMeanEmbeddingAndWeight.remove(clusterID2);