[mob][photos] Use vectors everywhere in cluster suggestion
This commit is contained in:
parent
4b6641d7d8
commit
e829f7b62f
|
@ -3,6 +3,7 @@ import "dart:math" show Random, min;
|
||||||
|
|
||||||
import "package:flutter/foundation.dart";
|
import "package:flutter/foundation.dart";
|
||||||
import "package:logging/logging.dart";
|
import "package:logging/logging.dart";
|
||||||
|
import "package:ml_linalg/linalg.dart";
|
||||||
import "package:photos/core/event_bus.dart";
|
import "package:photos/core/event_bus.dart";
|
||||||
import "package:photos/db/files_db.dart";
|
import "package:photos/db/files_db.dart";
|
||||||
// import "package:photos/events/files_updated_event.dart";
|
// import "package:photos/events/files_updated_event.dart";
|
||||||
|
@ -245,13 +246,13 @@ class ClusterFeedbackService {
|
||||||
final ignoredClusters = await faceMlDb.getPersonIgnoredClusters(p.remoteID);
|
final ignoredClusters = await faceMlDb.getPersonIgnoredClusters(p.remoteID);
|
||||||
final personClusters = await faceMlDb.getPersonClusterIDs(p.remoteID);
|
final personClusters = await faceMlDb.getPersonClusterIDs(p.remoteID);
|
||||||
dev.log(
|
dev.log(
|
||||||
'existing clusters for ${p.data.name} are $personClusters',
|
'${p.data.name} has ${personClusters.length} existing clusters',
|
||||||
name: "ClusterFeedbackService",
|
name: "ClusterFeedbackService",
|
||||||
);
|
);
|
||||||
|
|
||||||
// Get and update the cluster summary to get the avg (centroid) and count
|
// Get and update the cluster summary to get the avg (centroid) and count
|
||||||
final EnteWatch watch = EnteWatch("ClusterFeedbackService")..start();
|
final EnteWatch watch = EnteWatch("ClusterFeedbackService")..start();
|
||||||
final Map<int, List<double>> clusterAvg = await _getUpdateClusterAvg(
|
final Map<int, Vector> clusterAvg = await _getUpdateClusterAvg(
|
||||||
allClusterIdsToCountMap,
|
allClusterIdsToCountMap,
|
||||||
ignoredClusters,
|
ignoredClusters,
|
||||||
);
|
);
|
||||||
|
@ -466,19 +467,19 @@ class ClusterFeedbackService {
|
||||||
final ignoredClusters = await faceMlDb.getPersonIgnoredClusters(p.remoteID);
|
final ignoredClusters = await faceMlDb.getPersonIgnoredClusters(p.remoteID);
|
||||||
final personClusters = await faceMlDb.getPersonClusterIDs(p.remoteID);
|
final personClusters = await faceMlDb.getPersonClusterIDs(p.remoteID);
|
||||||
dev.log(
|
dev.log(
|
||||||
'existing clusters for ${p.data.name} are $personClusters, getting all database data took ${DateTime.now().difference(startTime).inMilliseconds} ms',
|
'${p.data.name} has ${personClusters.length} existing clusters, getting all database data took ${DateTime.now().difference(startTime).inMilliseconds} ms',
|
||||||
name: "getSuggestionsUsingMedian",
|
name: "getSuggestionsUsingMedian",
|
||||||
);
|
);
|
||||||
|
|
||||||
// First only do a simple check on the big clusters, if the person does not have small clusters yet
|
// First only do a simple check on the big clusters, if the person does not have small clusters yet
|
||||||
final EnteWatch watch = EnteWatch("ClusterFeedbackService")..start();
|
final w = (kDebugMode ? EnteWatch('getSuggestions') : null)?..start();
|
||||||
final smallestPersonClusterSize = personClusters
|
final smallestPersonClusterSize = personClusters
|
||||||
.map((clusterID) => allClusterIdsToCountMap[clusterID] ?? 0)
|
.map((clusterID) => allClusterIdsToCountMap[clusterID] ?? 0)
|
||||||
.reduce((value, element) => min(value, element));
|
.reduce((value, element) => min(value, element));
|
||||||
final checkSizes = [kMinimumClusterSizeSearchResult, 20, 10, 5, 1];
|
final checkSizes = [kMinimumClusterSizeSearchResult, 20, 10, 5, 1];
|
||||||
for (final minimumSize in checkSizes.toSet()) {
|
for (final minimumSize in checkSizes.toSet()) {
|
||||||
if (smallestPersonClusterSize >= minimumSize) {
|
if (smallestPersonClusterSize >= minimumSize) {
|
||||||
final Map<int, List<double>> clusterAvgBigClusters =
|
final Map<int, Vector> clusterAvgBigClusters =
|
||||||
await _getUpdateClusterAvg(
|
await _getUpdateClusterAvg(
|
||||||
allClusterIdsToCountMap,
|
allClusterIdsToCountMap,
|
||||||
ignoredClusters,
|
ignoredClusters,
|
||||||
|
@ -487,6 +488,7 @@ class ClusterFeedbackService {
|
||||||
dev.log(
|
dev.log(
|
||||||
'computed avg for ${clusterAvgBigClusters.length} clusters, in ${DateTime.now().difference(startTime).inMilliseconds} ms',
|
'computed avg for ${clusterAvgBigClusters.length} clusters, in ${DateTime.now().difference(startTime).inMilliseconds} ms',
|
||||||
);
|
);
|
||||||
|
w?.log('Calculate avg for min size $minimumSize');
|
||||||
final List<(int, double)> suggestionsMeanBigClusters =
|
final List<(int, double)> suggestionsMeanBigClusters =
|
||||||
_calcSuggestionsMean(
|
_calcSuggestionsMean(
|
||||||
clusterAvgBigClusters,
|
clusterAvgBigClusters,
|
||||||
|
@ -494,6 +496,7 @@ class ClusterFeedbackService {
|
||||||
ignoredClusters,
|
ignoredClusters,
|
||||||
goodMeanDistance,
|
goodMeanDistance,
|
||||||
);
|
);
|
||||||
|
w?.log('Calculate suggestions using mean for min size $minimumSize');
|
||||||
if (suggestionsMeanBigClusters.isNotEmpty) {
|
if (suggestionsMeanBigClusters.isNotEmpty) {
|
||||||
return suggestionsMeanBigClusters
|
return suggestionsMeanBigClusters
|
||||||
.map((e) => (e.$1, e.$2, true))
|
.map((e) => (e.$1, e.$2, true))
|
||||||
|
@ -501,9 +504,10 @@ class ClusterFeedbackService {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
w?.reset();
|
||||||
|
|
||||||
// Get and update the cluster summary to get the avg (centroid) and count
|
// Get and update the cluster summary to get the avg (centroid) and count
|
||||||
final Map<int, List<double>> clusterAvg = await _getUpdateClusterAvg(
|
final Map<int, Vector> clusterAvg = await _getUpdateClusterAvg(
|
||||||
allClusterIdsToCountMap,
|
allClusterIdsToCountMap,
|
||||||
ignoredClusters,
|
ignoredClusters,
|
||||||
);
|
);
|
||||||
|
@ -547,7 +551,7 @@ class ClusterFeedbackService {
|
||||||
"Found potential suggestions from loose mean for median test: $otherClusterIdsCandidates",
|
"Found potential suggestions from loose mean for median test: $otherClusterIdsCandidates",
|
||||||
);
|
);
|
||||||
|
|
||||||
watch.logAndReset("Starting median test");
|
w?.logAndReset("Starting median test");
|
||||||
// Take the embeddings from the person's clusters in one big list and sample from it
|
// Take the embeddings from the person's clusters in one big list and sample from it
|
||||||
final List<Uint8List> personEmbeddingsProto = [];
|
final List<Uint8List> personEmbeddingsProto = [];
|
||||||
for (final clusterID in personClusters) {
|
for (final clusterID in personClusters) {
|
||||||
|
@ -600,7 +604,7 @@ class ClusterFeedbackService {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
watch.log("Finished median test");
|
w?.log("Finished median test");
|
||||||
if (suggestionsMedian.isEmpty) {
|
if (suggestionsMedian.isEmpty) {
|
||||||
_logger.info("No suggestions found using median");
|
_logger.info("No suggestions found using median");
|
||||||
return [];
|
return [];
|
||||||
|
@ -632,7 +636,7 @@ class ClusterFeedbackService {
|
||||||
return finalSuggestionsMedian;
|
return finalSuggestionsMedian;
|
||||||
}
|
}
|
||||||
|
|
||||||
Future<Map<int, List<double>>> _getUpdateClusterAvg(
|
Future<Map<int, Vector>> _getUpdateClusterAvg(
|
||||||
Map<int, int> allClusterIdsToCountMap,
|
Map<int, int> allClusterIdsToCountMap,
|
||||||
Set<int> ignoredClusters, {
|
Set<int> ignoredClusters, {
|
||||||
int minClusterSize = 1,
|
int minClusterSize = 1,
|
||||||
|
@ -649,7 +653,7 @@ class ClusterFeedbackService {
|
||||||
await faceMlDb.getAllClusterSummary(minClusterSize);
|
await faceMlDb.getAllClusterSummary(minClusterSize);
|
||||||
final Map<int, (Uint8List, int)> updatesForClusterSummary = {};
|
final Map<int, (Uint8List, int)> updatesForClusterSummary = {};
|
||||||
|
|
||||||
final Map<int, List<double>> clusterAvg = {};
|
final Map<int, Vector> clusterAvg = {};
|
||||||
|
|
||||||
dev.log(
|
dev.log(
|
||||||
'getUpdateClusterAvg database call for getAllClusterSummary took ${DateTime.now().difference(startTime).inMilliseconds} ms',
|
'getUpdateClusterAvg database call for getAllClusterSummary took ${DateTime.now().difference(startTime).inMilliseconds} ms',
|
||||||
|
@ -666,7 +670,9 @@ class ClusterFeedbackService {
|
||||||
}
|
}
|
||||||
if (clusterToSummary[id]?.$2 == allClusterIdsToCountMap[id]) {
|
if (clusterToSummary[id]?.$2 == allClusterIdsToCountMap[id]) {
|
||||||
allClusterIds.remove(id);
|
allClusterIds.remove(id);
|
||||||
clusterAvg[id] = EVector.fromBuffer(clusterToSummary[id]!.$1).values;
|
clusterAvg[id] = Vector.fromList(
|
||||||
|
EVector.fromBuffer(clusterToSummary[id]!.$1).values,
|
||||||
|
dtype: DType.float32,);
|
||||||
alreadyUpdatedClustersCnt++;
|
alreadyUpdatedClustersCnt++;
|
||||||
}
|
}
|
||||||
if (allClusterIdsToCountMap[id]! < minClusterSize) {
|
if (allClusterIdsToCountMap[id]! < minClusterSize) {
|
||||||
|
@ -731,19 +737,15 @@ class ClusterFeedbackService {
|
||||||
);
|
);
|
||||||
|
|
||||||
for (final clusterID in clusterEmbeddings.keys) {
|
for (final clusterID in clusterEmbeddings.keys) {
|
||||||
late List<double> avg;
|
final Iterable<Uint8List> embeddings = clusterEmbeddings[clusterID]!;
|
||||||
final Iterable<Uint8List> embedings = clusterEmbeddings[clusterID]!;
|
final Iterable<Vector> vectors = embeddings.map((e) => Vector.fromList(
|
||||||
final List<double> sum = List.filled(192, 0);
|
EVector.fromBuffer(e).values,
|
||||||
for (final embedding in embedings) {
|
dtype: DType.float32,
|
||||||
final data = EVector.fromBuffer(embedding).values;
|
),);
|
||||||
for (int i = 0; i < sum.length; i++) {
|
final avg = vectors.reduce((a, b) => a + b) / vectors.length;
|
||||||
sum[i] += data[i];
|
final avgEmbeddingBuffer = EVector(values: avg).writeToBuffer();
|
||||||
}
|
|
||||||
}
|
|
||||||
avg = sum.map((e) => e / embedings.length).toList();
|
|
||||||
final avgEmbeedingBuffer = EVector(values: avg).writeToBuffer();
|
|
||||||
updatesForClusterSummary[clusterID] =
|
updatesForClusterSummary[clusterID] =
|
||||||
(avgEmbeedingBuffer, embedings.length);
|
(avgEmbeddingBuffer, embeddings.length);
|
||||||
// store the intermediate updates
|
// store the intermediate updates
|
||||||
indexedInCurrentRun++;
|
indexedInCurrentRun++;
|
||||||
if (updatesForClusterSummary.length > 100) {
|
if (updatesForClusterSummary.length > 100) {
|
||||||
|
@ -770,7 +772,7 @@ class ClusterFeedbackService {
|
||||||
|
|
||||||
/// Returns a map of person's clusterID to map of closest clusterID with distance
|
/// Returns a map of person's clusterID to map of closest clusterID with distance
|
||||||
List<(int, double)> _calcSuggestionsMean(
|
List<(int, double)> _calcSuggestionsMean(
|
||||||
Map<int, List<double>> clusterAvg,
|
Map<int, Vector> clusterAvg,
|
||||||
Set<int> personClusters,
|
Set<int> personClusters,
|
||||||
Set<int> ignoredClusters,
|
Set<int> ignoredClusters,
|
||||||
double maxClusterDistance, {
|
double maxClusterDistance, {
|
||||||
|
@ -779,23 +781,14 @@ class ClusterFeedbackService {
|
||||||
final Map<int, List<(int, double)>> suggestions = {};
|
final Map<int, List<(int, double)>> suggestions = {};
|
||||||
int suggestionCount = 0;
|
int suggestionCount = 0;
|
||||||
final w = (kDebugMode ? EnteWatch('getSuggestions') : null)?..start();
|
final w = (kDebugMode ? EnteWatch('getSuggestions') : null)?..start();
|
||||||
final clusterAvgVectors = clusterAvg.map(
|
|
||||||
(key, value) => MapEntry(
|
|
||||||
key,
|
|
||||||
Vector.fromList(
|
|
||||||
value,
|
|
||||||
dtype: DType.float32,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
);
|
|
||||||
w?.log('converted avg to vectors for ${clusterAvg.length} averages');
|
w?.log('converted avg to vectors for ${clusterAvg.length} averages');
|
||||||
for (final otherClusterID in clusterAvgVectors.keys) {
|
for (final otherClusterID in clusterAvg.keys) {
|
||||||
// ignore the cluster that belong to the person or is ignored
|
// ignore the cluster that belong to the person or is ignored
|
||||||
if (personClusters.contains(otherClusterID) ||
|
if (personClusters.contains(otherClusterID) ||
|
||||||
ignoredClusters.contains(otherClusterID)) {
|
ignoredClusters.contains(otherClusterID)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
final otherAvg = clusterAvgVectors[otherClusterID]!;
|
final Vector otherAvg = clusterAvg[otherClusterID]!;
|
||||||
int? nearestPersonCluster;
|
int? nearestPersonCluster;
|
||||||
double? minDistance;
|
double? minDistance;
|
||||||
for (final personCluster in personClusters) {
|
for (final personCluster in personClusters) {
|
||||||
|
@ -803,7 +796,7 @@ class ClusterFeedbackService {
|
||||||
_logger.info('no avg for cluster $personCluster');
|
_logger.info('no avg for cluster $personCluster');
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
final avg = clusterAvgVectors[personCluster]!;
|
final Vector avg = clusterAvg[personCluster]!;
|
||||||
final distance = 1 - avg.dot(otherAvg);
|
final distance = 1 - avg.dot(otherAvg);
|
||||||
if (distance < maxClusterDistance) {
|
if (distance < maxClusterDistance) {
|
||||||
if (minDistance == null || distance < minDistance) {
|
if (minDistance == null || distance < minDistance) {
|
||||||
|
|
Loading…
Reference in a new issue