[mob][photos] Use vectors everywhere in cluster suggestion

This commit is contained in:
laurenspriem 2024-04-24 16:01:03 +05:30
parent 4b6641d7d8
commit e829f7b62f

View file

@ -3,6 +3,7 @@ import "dart:math" show Random, min;
import "package:flutter/foundation.dart"; import "package:flutter/foundation.dart";
import "package:logging/logging.dart"; import "package:logging/logging.dart";
import "package:ml_linalg/linalg.dart";
import "package:photos/core/event_bus.dart"; import "package:photos/core/event_bus.dart";
import "package:photos/db/files_db.dart"; import "package:photos/db/files_db.dart";
// import "package:photos/events/files_updated_event.dart"; // import "package:photos/events/files_updated_event.dart";
@ -245,13 +246,13 @@ class ClusterFeedbackService {
final ignoredClusters = await faceMlDb.getPersonIgnoredClusters(p.remoteID); final ignoredClusters = await faceMlDb.getPersonIgnoredClusters(p.remoteID);
final personClusters = await faceMlDb.getPersonClusterIDs(p.remoteID); final personClusters = await faceMlDb.getPersonClusterIDs(p.remoteID);
dev.log( dev.log(
'existing clusters for ${p.data.name} are $personClusters', '${p.data.name} has ${personClusters.length} existing clusters',
name: "ClusterFeedbackService", name: "ClusterFeedbackService",
); );
// Get and update the cluster summary to get the avg (centroid) and count // Get and update the cluster summary to get the avg (centroid) and count
final EnteWatch watch = EnteWatch("ClusterFeedbackService")..start(); final EnteWatch watch = EnteWatch("ClusterFeedbackService")..start();
final Map<int, List<double>> clusterAvg = await _getUpdateClusterAvg( final Map<int, Vector> clusterAvg = await _getUpdateClusterAvg(
allClusterIdsToCountMap, allClusterIdsToCountMap,
ignoredClusters, ignoredClusters,
); );
@ -466,19 +467,19 @@ class ClusterFeedbackService {
final ignoredClusters = await faceMlDb.getPersonIgnoredClusters(p.remoteID); final ignoredClusters = await faceMlDb.getPersonIgnoredClusters(p.remoteID);
final personClusters = await faceMlDb.getPersonClusterIDs(p.remoteID); final personClusters = await faceMlDb.getPersonClusterIDs(p.remoteID);
dev.log( dev.log(
'existing clusters for ${p.data.name} are $personClusters, getting all database data took ${DateTime.now().difference(startTime).inMilliseconds} ms', '${p.data.name} has ${personClusters.length} existing clusters, getting all database data took ${DateTime.now().difference(startTime).inMilliseconds} ms',
name: "getSuggestionsUsingMedian", name: "getSuggestionsUsingMedian",
); );
// First only do a simple check on the big clusters, if the person does not have small clusters yet // First only do a simple check on the big clusters, if the person does not have small clusters yet
final EnteWatch watch = EnteWatch("ClusterFeedbackService")..start(); final w = (kDebugMode ? EnteWatch('getSuggestions') : null)?..start();
final smallestPersonClusterSize = personClusters final smallestPersonClusterSize = personClusters
.map((clusterID) => allClusterIdsToCountMap[clusterID] ?? 0) .map((clusterID) => allClusterIdsToCountMap[clusterID] ?? 0)
.reduce((value, element) => min(value, element)); .reduce((value, element) => min(value, element));
final checkSizes = [kMinimumClusterSizeSearchResult, 20, 10, 5, 1]; final checkSizes = [kMinimumClusterSizeSearchResult, 20, 10, 5, 1];
for (final minimumSize in checkSizes.toSet()) { for (final minimumSize in checkSizes.toSet()) {
if (smallestPersonClusterSize >= minimumSize) { if (smallestPersonClusterSize >= minimumSize) {
final Map<int, List<double>> clusterAvgBigClusters = final Map<int, Vector> clusterAvgBigClusters =
await _getUpdateClusterAvg( await _getUpdateClusterAvg(
allClusterIdsToCountMap, allClusterIdsToCountMap,
ignoredClusters, ignoredClusters,
@ -487,6 +488,7 @@ class ClusterFeedbackService {
dev.log( dev.log(
'computed avg for ${clusterAvgBigClusters.length} clusters, in ${DateTime.now().difference(startTime).inMilliseconds} ms', 'computed avg for ${clusterAvgBigClusters.length} clusters, in ${DateTime.now().difference(startTime).inMilliseconds} ms',
); );
w?.log('Calculate avg for min size $minimumSize');
final List<(int, double)> suggestionsMeanBigClusters = final List<(int, double)> suggestionsMeanBigClusters =
_calcSuggestionsMean( _calcSuggestionsMean(
clusterAvgBigClusters, clusterAvgBigClusters,
@ -494,6 +496,7 @@ class ClusterFeedbackService {
ignoredClusters, ignoredClusters,
goodMeanDistance, goodMeanDistance,
); );
w?.log('Calculate suggestions using mean for min size $minimumSize');
if (suggestionsMeanBigClusters.isNotEmpty) { if (suggestionsMeanBigClusters.isNotEmpty) {
return suggestionsMeanBigClusters return suggestionsMeanBigClusters
.map((e) => (e.$1, e.$2, true)) .map((e) => (e.$1, e.$2, true))
@ -501,9 +504,10 @@ class ClusterFeedbackService {
} }
} }
} }
w?.reset();
// Get and update the cluster summary to get the avg (centroid) and count // Get and update the cluster summary to get the avg (centroid) and count
final Map<int, List<double>> clusterAvg = await _getUpdateClusterAvg( final Map<int, Vector> clusterAvg = await _getUpdateClusterAvg(
allClusterIdsToCountMap, allClusterIdsToCountMap,
ignoredClusters, ignoredClusters,
); );
@ -547,7 +551,7 @@ class ClusterFeedbackService {
"Found potential suggestions from loose mean for median test: $otherClusterIdsCandidates", "Found potential suggestions from loose mean for median test: $otherClusterIdsCandidates",
); );
watch.logAndReset("Starting median test"); w?.logAndReset("Starting median test");
// Take the embeddings from the person's clusters in one big list and sample from it // Take the embeddings from the person's clusters in one big list and sample from it
final List<Uint8List> personEmbeddingsProto = []; final List<Uint8List> personEmbeddingsProto = [];
for (final clusterID in personClusters) { for (final clusterID in personClusters) {
@ -600,7 +604,7 @@ class ClusterFeedbackService {
} }
} }
} }
watch.log("Finished median test"); w?.log("Finished median test");
if (suggestionsMedian.isEmpty) { if (suggestionsMedian.isEmpty) {
_logger.info("No suggestions found using median"); _logger.info("No suggestions found using median");
return []; return [];
@ -632,7 +636,7 @@ class ClusterFeedbackService {
return finalSuggestionsMedian; return finalSuggestionsMedian;
} }
Future<Map<int, List<double>>> _getUpdateClusterAvg( Future<Map<int, Vector>> _getUpdateClusterAvg(
Map<int, int> allClusterIdsToCountMap, Map<int, int> allClusterIdsToCountMap,
Set<int> ignoredClusters, { Set<int> ignoredClusters, {
int minClusterSize = 1, int minClusterSize = 1,
@ -649,7 +653,7 @@ class ClusterFeedbackService {
await faceMlDb.getAllClusterSummary(minClusterSize); await faceMlDb.getAllClusterSummary(minClusterSize);
final Map<int, (Uint8List, int)> updatesForClusterSummary = {}; final Map<int, (Uint8List, int)> updatesForClusterSummary = {};
final Map<int, List<double>> clusterAvg = {}; final Map<int, Vector> clusterAvg = {};
dev.log( dev.log(
'getUpdateClusterAvg database call for getAllClusterSummary took ${DateTime.now().difference(startTime).inMilliseconds} ms', 'getUpdateClusterAvg database call for getAllClusterSummary took ${DateTime.now().difference(startTime).inMilliseconds} ms',
@ -666,7 +670,9 @@ class ClusterFeedbackService {
} }
if (clusterToSummary[id]?.$2 == allClusterIdsToCountMap[id]) { if (clusterToSummary[id]?.$2 == allClusterIdsToCountMap[id]) {
allClusterIds.remove(id); allClusterIds.remove(id);
clusterAvg[id] = EVector.fromBuffer(clusterToSummary[id]!.$1).values; clusterAvg[id] = Vector.fromList(
EVector.fromBuffer(clusterToSummary[id]!.$1).values,
dtype: DType.float32,);
alreadyUpdatedClustersCnt++; alreadyUpdatedClustersCnt++;
} }
if (allClusterIdsToCountMap[id]! < minClusterSize) { if (allClusterIdsToCountMap[id]! < minClusterSize) {
@ -731,19 +737,15 @@ class ClusterFeedbackService {
); );
for (final clusterID in clusterEmbeddings.keys) { for (final clusterID in clusterEmbeddings.keys) {
late List<double> avg; final Iterable<Uint8List> embeddings = clusterEmbeddings[clusterID]!;
final Iterable<Uint8List> embedings = clusterEmbeddings[clusterID]!; final Iterable<Vector> vectors = embeddings.map((e) => Vector.fromList(
final List<double> sum = List.filled(192, 0); EVector.fromBuffer(e).values,
for (final embedding in embedings) { dtype: DType.float32,
final data = EVector.fromBuffer(embedding).values; ),);
for (int i = 0; i < sum.length; i++) { final avg = vectors.reduce((a, b) => a + b) / vectors.length;
sum[i] += data[i]; final avgEmbeddingBuffer = EVector(values: avg).writeToBuffer();
}
}
avg = sum.map((e) => e / embedings.length).toList();
final avgEmbeedingBuffer = EVector(values: avg).writeToBuffer();
updatesForClusterSummary[clusterID] = updatesForClusterSummary[clusterID] =
(avgEmbeedingBuffer, embedings.length); (avgEmbeddingBuffer, embeddings.length);
// store the intermediate updates // store the intermediate updates
indexedInCurrentRun++; indexedInCurrentRun++;
if (updatesForClusterSummary.length > 100) { if (updatesForClusterSummary.length > 100) {
@ -770,7 +772,7 @@ class ClusterFeedbackService {
/// Returns a map of person's clusterID to map of closest clusterID to with disstance /// Returns a map of person's clusterID to map of closest clusterID to with disstance
List<(int, double)> _calcSuggestionsMean( List<(int, double)> _calcSuggestionsMean(
Map<int, List<double>> clusterAvg, Map<int, Vector> clusterAvg,
Set<int> personClusters, Set<int> personClusters,
Set<int> ignoredClusters, Set<int> ignoredClusters,
double maxClusterDistance, { double maxClusterDistance, {
@ -779,23 +781,14 @@ class ClusterFeedbackService {
final Map<int, List<(int, double)>> suggestions = {}; final Map<int, List<(int, double)>> suggestions = {};
int suggestionCount = 0; int suggestionCount = 0;
final w = (kDebugMode ? EnteWatch('getSuggestions') : null)?..start(); final w = (kDebugMode ? EnteWatch('getSuggestions') : null)?..start();
final clusterAvgVectors = clusterAvg.map(
(key, value) => MapEntry(
key,
Vector.fromList(
value,
dtype: DType.float32,
),
),
);
w?.log('converted avg to vectors for ${clusterAvg.length} averages'); w?.log('converted avg to vectors for ${clusterAvg.length} averages');
for (final otherClusterID in clusterAvgVectors.keys) { for (final otherClusterID in clusterAvg.keys) {
// ignore the cluster that belong to the person or is ignored // ignore the cluster that belong to the person or is ignored
if (personClusters.contains(otherClusterID) || if (personClusters.contains(otherClusterID) ||
ignoredClusters.contains(otherClusterID)) { ignoredClusters.contains(otherClusterID)) {
continue; continue;
} }
final otherAvg = clusterAvgVectors[otherClusterID]!; final Vector otherAvg = clusterAvg[otherClusterID]!;
int? nearestPersonCluster; int? nearestPersonCluster;
double? minDistance; double? minDistance;
for (final personCluster in personClusters) { for (final personCluster in personClusters) {
@ -803,7 +796,7 @@ class ClusterFeedbackService {
_logger.info('no avg for cluster $personCluster'); _logger.info('no avg for cluster $personCluster');
continue; continue;
} }
final avg = clusterAvgVectors[personCluster]!; final Vector avg = clusterAvg[personCluster]!;
final distance = 1 - avg.dot(otherAvg); final distance = 1 - avg.dot(otherAvg);
if (distance < maxClusterDistance) { if (distance < maxClusterDistance) {
if (minDistance == null || distance < minDistance) { if (minDistance == null || distance < minDistance) {