From faa07a0704ac38056bb008c08109fec4db111593 Mon Sep 17 00:00:00 2001 From: Neeraj Gupta <254676+ua741@users.noreply.github.com> Date: Tue, 2 Apr 2024 16:56:55 +0530 Subject: [PATCH] [mob] compute suggestion in small batches --- .../face_ml/feedback/cluster_feedback.dart | 84 ++++++++++++------- 1 file changed, 53 insertions(+), 31 deletions(-) diff --git a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart index a0ede5809..745c73245 100644 --- a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart +++ b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart @@ -367,11 +367,13 @@ class ClusterFeedbackService { Future>> _getUpdateClusterAvg( Map allClusterIdsToCountMap, - Set ignoredClusters, - ) async { + Set ignoredClusters, { + int minClusterSize = 1, + int maxClusterInCurrentRun = 500, + }) async { final faceMlDb = FaceMLDataDB.instance; _logger.info( - 'start getUpdateClusterAvg for ${allClusterIdsToCountMap.length} clusters', + 'start getUpdateClusterAvg for ${allClusterIdsToCountMap.length} clusters, minClusterSize $minClusterSize, maxClusterInCurrentRun $maxClusterInCurrentRun', ); final Map clusterToSummary = @@ -380,42 +382,61 @@ class ClusterFeedbackService { final Map> clusterAvg = {}; - final allClusterIds = allClusterIdsToCountMap.keys; - for (final clusterID in allClusterIds) { - if (ignoredClusters.contains(clusterID)) { - continue; + final allClusterIds = allClusterIdsToCountMap.keys.toSet(); + int ignoredClustersCnt = 0, alreadyUpdatedClustersCnt = 0; + int smallerClustersCnt = 0; + for (final id in allClusterIdsToCountMap.keys) { + if (ignoredClusters.contains(id)) { + allClusterIds.remove(id); + ignoredClustersCnt++; } - if (allClusterIdsToCountMap[clusterID]! < 2) { - continue; + if (clusterToSummary[id]?.$2 == allClusterIdsToCountMap[id]) { + allClusterIds.remove(id); + clusterAvg[id] = EVector.fromBuffer(clusterToSummary[id]!.$1).values; + alreadyUpdatedClustersCnt++; } + if (allClusterIdsToCountMap[id]! < minClusterSize) { + allClusterIds.remove(id); + smallerClustersCnt++; + } + } + _logger.info( + 'Ignored $ignoredClustersCnt clusters, already updated $alreadyUpdatedClustersCnt clusters, $smallerClustersCnt clusters are smaller than $minClusterSize', + ); + // get clusterIDs sorted by count in descending order + final sortedClusterIDs = allClusterIds.toList(); + sortedClusterIDs.sort( + (a, b) => + allClusterIdsToCountMap[b]!.compareTo(allClusterIdsToCountMap[a]!), + ); + int indexedInCurrentRun = 0; - late List avg; - if (clusterToSummary[clusterID]?.$2 == - allClusterIdsToCountMap[clusterID]) { - avg = EVector.fromBuffer(clusterToSummary[clusterID]!.$1).values; - } else { - final Iterable embedings = - await FaceMLDataDB.instance.getFaceEmbeddingsForCluster(clusterID); - final List sum = List.filled(192, 0); - for (final embedding in embedings) { - final data = EVector.fromBuffer(embedding).values; - for (int i = 0; i < sum.length; i++) { - sum[i] += data[i]; - } - } - avg = sum.map((e) => e / embedings.length).toList(); - final avgEmbeedingBuffer = EVector(values: avg).writeToBuffer(); - updatesForClusterSummary[clusterID] = - (avgEmbeedingBuffer, embedings.length); + for (final clusterID in sortedClusterIDs) { + if (maxClusterInCurrentRun-- <= 0) { + break; } + indexedInCurrentRun++; + late List avg; + final Iterable embedings = + await FaceMLDataDB.instance.getFaceEmbeddingsForCluster(clusterID); + final List sum = List.filled(192, 0); + for (final embedding in embedings) { + final data = EVector.fromBuffer(embedding).values; + for (int i = 0; i < sum.length; i++) { + sum[i] += data[i]; + } + } + avg = sum.map((e) => e / embedings.length).toList(); + final avgEmbeedingBuffer = EVector(values: avg).writeToBuffer(); + updatesForClusterSummary[clusterID] = + (avgEmbeedingBuffer, embedings.length); // store the intermediate updates if (updatesForClusterSummary.length > 100) { await faceMlDb.clusterSummaryUpdate(updatesForClusterSummary); updatesForClusterSummary.clear(); if (kDebugMode) { _logger.info( - 'start getUpdateClusterAvg for ${allClusterIdsToCountMap.length} clusters', - ); + 'getUpdateClusterAvg $indexedInCurrentRun clusters in current one'); } } clusterAvg[clusterID] = avg; @@ -549,8 +570,9 @@ class ClusterFeedbackService { ); } suggestion.$4.sort((b, a) { - final double distanceA = fileIdToDistanceMap[a.uploadedFileID!]; - final double distanceB = fileIdToDistanceMap[b.uploadedFileID!]; + //todo: review with @laurens, added this to avoid null safety issue + final double distanceA = fileIdToDistanceMap[a.uploadedFileID!] ?? -1; + final double distanceB = fileIdToDistanceMap[b.uploadedFileID!] ?? -1; return distanceA.compareTo(distanceB); });