From cb8f66fcaa12f633413beb19a37d109cb57bdedd Mon Sep 17 00:00:00 2001 From: laurenspriem Date: Wed, 3 Apr 2024 13:06:46 +0530 Subject: [PATCH] [mob][wip] break up cluster method --- mobile/lib/face/db.dart | 11 ++++++ .../face_ml/feedback/cluster_feedback.dart | 38 +++++++++++++++++++ 2 files changed, 49 insertions(+) diff --git a/mobile/lib/face/db.dart b/mobile/lib/face/db.dart index 13db25f21..0dca25e93 100644 --- a/mobile/lib/face/db.dart +++ b/mobile/lib/face/db.dart @@ -322,6 +322,17 @@ class FaceMLDataDB { return mapRowToFace(result.first); } + Future> getFaceIDsForCluster(int clusterID) async { + final db = await instance.database; + final List> maps = await db.query( + faceClustersTable, + columns: [fcFaceId], + where: '$fcClusterID = ?', + whereArgs: [clusterID], + ); + return maps.map((e) => e[fcFaceId] as String).toSet(); + } + Future> getFaceIdsToClusterIds( Iterable faceIds, ) async { diff --git a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart index 2da6fc1f5..4fd501bd8 100644 --- a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart +++ b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart @@ -11,6 +11,7 @@ import "package:photos/face/model/person.dart"; import "package:photos/generated/protos/ente/common/vector.pb.dart"; import "package:photos/models/file/file.dart"; import 'package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart'; +import "package:photos/services/machine_learning/face_ml/face_clustering/linear_clustering_service.dart"; import "package:photos/services/machine_learning/face_ml/face_ml_result.dart"; import "package:photos/services/search_service.dart"; @@ -365,6 +366,43 @@ class ClusterFeedbackService { return true; } + Future>> breakUpCluster(int clusterID) async { + final faceMlDb = FaceMLDataDB.instance; + + final faceIDs = await faceMlDb.getFaceIDsForCluster(clusterID); + final fileIDs = faceIDs.map((e) => getFileIdFromFaceId(e)).toList(); + + final embeddings = await faceMlDb.getFaceEmbeddingMapForFile(fileIDs); + embeddings.removeWhere((key, value) => !faceIDs.contains(key)); + final clusteringInput = embeddings.map((key, value) { + return MapEntry(key, (null, value)); + }); + + final faceIdToCluster = await FaceLinearClustering.instance + .predict(clusteringInput, distanceThreshold: 0.15); + + if (faceIdToCluster == null) { + return {}; + } + + final clusterIdToFaceIds = >{}; + for (final entry in faceIdToCluster.entries) { + final clusterID = entry.value; + if (clusterIdToFaceIds.containsKey(clusterID)) { + clusterIdToFaceIds[clusterID]!.add(entry.key); + } else { + clusterIdToFaceIds[clusterID] = [entry.key]; + } + } + + final clusterIdToCount = clusterIdToFaceIds.map((key, value) { + return MapEntry(key, value.length); + }); + final amountOfNewClusters = clusterIdToCount.length; + + return clusterIdToFaceIds; + } + Future>> _getUpdateClusterAvg( Map allClusterIdsToCountMap, Set ignoredClusters, {