[mob][photos] Use vectors everywhere in cluster suggestion

This commit is contained in:
laurenspriem 2024-04-24 16:01:03 +05:30
parent 4b6641d7d8
commit e829f7b62f

View file

@ -3,6 +3,7 @@ import "dart:math" show Random, min;
import "package:flutter/foundation.dart";
import "package:logging/logging.dart";
import "package:ml_linalg/linalg.dart";
import "package:photos/core/event_bus.dart";
import "package:photos/db/files_db.dart";
// import "package:photos/events/files_updated_event.dart";
@ -245,13 +246,13 @@ class ClusterFeedbackService {
final ignoredClusters = await faceMlDb.getPersonIgnoredClusters(p.remoteID);
final personClusters = await faceMlDb.getPersonClusterIDs(p.remoteID);
dev.log(
'existing clusters for ${p.data.name} are $personClusters',
'${p.data.name} has ${personClusters.length} existing clusters',
name: "ClusterFeedbackService",
);
// Get and update the cluster summary to get the avg (centroid) and count
final EnteWatch watch = EnteWatch("ClusterFeedbackService")..start();
final Map<int, List<double>> clusterAvg = await _getUpdateClusterAvg(
final Map<int, Vector> clusterAvg = await _getUpdateClusterAvg(
allClusterIdsToCountMap,
ignoredClusters,
);
@ -466,19 +467,19 @@ class ClusterFeedbackService {
final ignoredClusters = await faceMlDb.getPersonIgnoredClusters(p.remoteID);
final personClusters = await faceMlDb.getPersonClusterIDs(p.remoteID);
dev.log(
'existing clusters for ${p.data.name} are $personClusters, getting all database data took ${DateTime.now().difference(startTime).inMilliseconds} ms',
'${p.data.name} has ${personClusters.length} existing clusters, getting all database data took ${DateTime.now().difference(startTime).inMilliseconds} ms',
name: "getSuggestionsUsingMedian",
);
// First only do a simple check on the big clusters, if the person does not have small clusters yet
final EnteWatch watch = EnteWatch("ClusterFeedbackService")..start();
final w = (kDebugMode ? EnteWatch('getSuggestions') : null)?..start();
final smallestPersonClusterSize = personClusters
.map((clusterID) => allClusterIdsToCountMap[clusterID] ?? 0)
.reduce((value, element) => min(value, element));
final checkSizes = [kMinimumClusterSizeSearchResult, 20, 10, 5, 1];
for (final minimumSize in checkSizes.toSet()) {
if (smallestPersonClusterSize >= minimumSize) {
final Map<int, List<double>> clusterAvgBigClusters =
final Map<int, Vector> clusterAvgBigClusters =
await _getUpdateClusterAvg(
allClusterIdsToCountMap,
ignoredClusters,
@ -487,6 +488,7 @@ class ClusterFeedbackService {
dev.log(
'computed avg for ${clusterAvgBigClusters.length} clusters, in ${DateTime.now().difference(startTime).inMilliseconds} ms',
);
w?.log('Calculate avg for min size $minimumSize');
final List<(int, double)> suggestionsMeanBigClusters =
_calcSuggestionsMean(
clusterAvgBigClusters,
@ -494,6 +496,7 @@ class ClusterFeedbackService {
ignoredClusters,
goodMeanDistance,
);
w?.log('Calculate suggestions using mean for min size $minimumSize');
if (suggestionsMeanBigClusters.isNotEmpty) {
return suggestionsMeanBigClusters
.map((e) => (e.$1, e.$2, true))
@ -501,9 +504,10 @@ class ClusterFeedbackService {
}
}
}
w?.reset();
// Get and update the cluster summary to get the avg (centroid) and count
final Map<int, List<double>> clusterAvg = await _getUpdateClusterAvg(
final Map<int, Vector> clusterAvg = await _getUpdateClusterAvg(
allClusterIdsToCountMap,
ignoredClusters,
);
@ -547,7 +551,7 @@ class ClusterFeedbackService {
"Found potential suggestions from loose mean for median test: $otherClusterIdsCandidates",
);
watch.logAndReset("Starting median test");
w?.logAndReset("Starting median test");
// Take the embeddings from the person's clusters in one big list and sample from it
final List<Uint8List> personEmbeddingsProto = [];
for (final clusterID in personClusters) {
@ -600,7 +604,7 @@ class ClusterFeedbackService {
}
}
}
watch.log("Finished median test");
w?.log("Finished median test");
if (suggestionsMedian.isEmpty) {
_logger.info("No suggestions found using median");
return [];
@ -632,7 +636,7 @@ class ClusterFeedbackService {
return finalSuggestionsMedian;
}
Future<Map<int, List<double>>> _getUpdateClusterAvg(
Future<Map<int, Vector>> _getUpdateClusterAvg(
Map<int, int> allClusterIdsToCountMap,
Set<int> ignoredClusters, {
int minClusterSize = 1,
@ -649,7 +653,7 @@ class ClusterFeedbackService {
await faceMlDb.getAllClusterSummary(minClusterSize);
final Map<int, (Uint8List, int)> updatesForClusterSummary = {};
final Map<int, List<double>> clusterAvg = {};
final Map<int, Vector> clusterAvg = {};
dev.log(
'getUpdateClusterAvg database call for getAllClusterSummary took ${DateTime.now().difference(startTime).inMilliseconds} ms',
@ -666,7 +670,9 @@ class ClusterFeedbackService {
}
if (clusterToSummary[id]?.$2 == allClusterIdsToCountMap[id]) {
allClusterIds.remove(id);
clusterAvg[id] = EVector.fromBuffer(clusterToSummary[id]!.$1).values;
clusterAvg[id] = Vector.fromList(
EVector.fromBuffer(clusterToSummary[id]!.$1).values,
dtype: DType.float32,);
alreadyUpdatedClustersCnt++;
}
if (allClusterIdsToCountMap[id]! < minClusterSize) {
@ -731,19 +737,15 @@ class ClusterFeedbackService {
);
for (final clusterID in clusterEmbeddings.keys) {
late List<double> avg;
final Iterable<Uint8List> embedings = clusterEmbeddings[clusterID]!;
final List<double> sum = List.filled(192, 0);
for (final embedding in embedings) {
final data = EVector.fromBuffer(embedding).values;
for (int i = 0; i < sum.length; i++) {
sum[i] += data[i];
}
}
avg = sum.map((e) => e / embedings.length).toList();
final avgEmbeedingBuffer = EVector(values: avg).writeToBuffer();
final Iterable<Uint8List> embeddings = clusterEmbeddings[clusterID]!;
final Iterable<Vector> vectors = embeddings.map((e) => Vector.fromList(
EVector.fromBuffer(e).values,
dtype: DType.float32,
),);
final avg = vectors.reduce((a, b) => a + b) / vectors.length;
final avgEmbeddingBuffer = EVector(values: avg).writeToBuffer();
updatesForClusterSummary[clusterID] =
(avgEmbeedingBuffer, embedings.length);
(avgEmbeddingBuffer, embeddings.length);
// store the intermediate updates
indexedInCurrentRun++;
if (updatesForClusterSummary.length > 100) {
@ -770,7 +772,7 @@ class ClusterFeedbackService {
/// Returns a map of person's clusterID to map of closest clusterID to with disstance
List<(int, double)> _calcSuggestionsMean(
Map<int, List<double>> clusterAvg,
Map<int, Vector> clusterAvg,
Set<int> personClusters,
Set<int> ignoredClusters,
double maxClusterDistance, {
@ -779,23 +781,14 @@ class ClusterFeedbackService {
final Map<int, List<(int, double)>> suggestions = {};
int suggestionCount = 0;
final w = (kDebugMode ? EnteWatch('getSuggestions') : null)?..start();
final clusterAvgVectors = clusterAvg.map(
(key, value) => MapEntry(
key,
Vector.fromList(
value,
dtype: DType.float32,
),
),
);
w?.log('converted avg to vectors for ${clusterAvg.length} averages');
for (final otherClusterID in clusterAvgVectors.keys) {
for (final otherClusterID in clusterAvg.keys) {
// ignore the cluster that belong to the person or is ignored
if (personClusters.contains(otherClusterID) ||
ignoredClusters.contains(otherClusterID)) {
continue;
}
final otherAvg = clusterAvgVectors[otherClusterID]!;
final Vector otherAvg = clusterAvg[otherClusterID]!;
int? nearestPersonCluster;
double? minDistance;
for (final personCluster in personClusters) {
@ -803,7 +796,7 @@ class ClusterFeedbackService {
_logger.info('no avg for cluster $personCluster');
continue;
}
final avg = clusterAvgVectors[personCluster]!;
final Vector avg = clusterAvg[personCluster]!;
final distance = 1 - avg.dot(otherAvg);
if (distance < maxClusterDistance) {
if (minDistance == null || distance < minDistance) {