View faces with highest distance in cluster suggestion
This commit is contained in:
parent
c85692360c
commit
255b566342
|
@ -11,6 +11,7 @@ import "package:photos/face/model/person.dart";
|
||||||
import "package:photos/generated/protos/ente/common/vector.pb.dart";
|
import "package:photos/generated/protos/ente/common/vector.pb.dart";
|
||||||
import "package:photos/models/file/file.dart";
|
import "package:photos/models/file/file.dart";
|
||||||
import 'package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart';
|
import 'package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart';
|
||||||
|
import "package:photos/services/machine_learning/face_ml/face_ml_result.dart";
|
||||||
import "package:photos/services/search_service.dart";
|
import "package:photos/services/search_service.dart";
|
||||||
|
|
||||||
class ClusterFeedbackService {
|
class ClusterFeedbackService {
|
||||||
|
@ -241,22 +242,13 @@ class ClusterFeedbackService {
|
||||||
/// 3. bool: whether the suggestion was found using the mean (true) or the median (false)
|
/// 3. bool: whether the suggestion was found using the mean (true) or the median (false)
|
||||||
/// 4. List<EnteFile>: the files in the cluster
|
/// 4. List<EnteFile>: the files in the cluster
|
||||||
Future<List<(int, double, bool, List<EnteFile>)>> getClusterFilesForPersonID(
|
Future<List<(int, double, bool, List<EnteFile>)>> getClusterFilesForPersonID(
|
||||||
Person person,
|
Person person, {
|
||||||
) async {
|
bool extremeFilesFirst = true,
|
||||||
|
}) async {
|
||||||
_logger.info(
|
_logger.info(
|
||||||
'getClusterFilesForPersonID ${kDebugMode ? person.attr.name : person.remoteID}',
|
'getClusterFilesForPersonID ${kDebugMode ? person.attr.name : person.remoteID}',
|
||||||
);
|
);
|
||||||
|
|
||||||
// Get the suggestions for the person using only centroids
|
|
||||||
// final Map<int, List<(int, double)>> suggestions =
|
|
||||||
// await getSuggestionsUsingMean(person);
|
|
||||||
// final Set<int> suggestClusterIds = {};
|
|
||||||
// for (final List<(int, double)> suggestion in suggestions.values) {
|
|
||||||
// for (final clusterNeighbors in suggestion) {
|
|
||||||
// suggestClusterIds.add(clusterNeighbors.$1);
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
try {
|
try {
|
||||||
// Get the suggestions for the person using centroids and median
|
// Get the suggestions for the person using centroids and median
|
||||||
final List<(int, double, bool)> suggestClusterIds =
|
final List<(int, double, bool)> suggestClusterIds =
|
||||||
|
@ -297,6 +289,10 @@ class ClusterFeedbackService {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (extremeFilesFirst) {
|
||||||
|
await _sortSuggestionsOnDistanceToPerson(person, clusterIdAndFiles);
|
||||||
|
}
|
||||||
|
|
||||||
return clusterIdAndFiles;
|
return clusterIdAndFiles;
|
||||||
} catch (e, s) {
|
} catch (e, s) {
|
||||||
_logger.severe("Error in getClusterFilesForPersonID", e, s);
|
_logger.severe("Error in getClusterFilesForPersonID", e, s);
|
||||||
|
@ -505,4 +501,67 @@ class ClusterFeedbackService {
|
||||||
|
|
||||||
return sampledEmbeddings;
|
return sampledEmbeddings;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Future<void> _sortSuggestionsOnDistanceToPerson(
|
||||||
|
Person person,
|
||||||
|
List<(int, double, bool, List<EnteFile>)> suggestions,
|
||||||
|
) async {
|
||||||
|
if (suggestions.isEmpty) {
|
||||||
|
debugPrint('No suggestions to sort');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
final startTime = DateTime.now();
|
||||||
|
final faceMlDb = FaceMLDataDB.instance;
|
||||||
|
|
||||||
|
// Get the cluster averages for the person's clusters and the suggestions' clusters
|
||||||
|
final Map<int, (Uint8List, int)> clusterToSummary =
|
||||||
|
await faceMlDb.clusterSummaryAll();
|
||||||
|
|
||||||
|
// Calculate the avg embedding of the person
|
||||||
|
final personClusters = await faceMlDb.getPersonClusterIDs(person.remoteID);
|
||||||
|
final personEmbeddingsCount = personClusters
|
||||||
|
.map((e) => clusterToSummary[e]!.$2)
|
||||||
|
.reduce((a, b) => a + b);
|
||||||
|
final List<double> personAvg = List.filled(192, 0);
|
||||||
|
for (final personClusterID in personClusters) {
|
||||||
|
final personClusterBlob = clusterToSummary[personClusterID]!.$1;
|
||||||
|
final personClusterAvg = EVector.fromBuffer(personClusterBlob).values;
|
||||||
|
final clusterWeight =
|
||||||
|
clusterToSummary[personClusterID]!.$2 / personEmbeddingsCount;
|
||||||
|
for (int i = 0; i < personClusterAvg.length; i++) {
|
||||||
|
personAvg[i] += personClusterAvg[i] *
|
||||||
|
clusterWeight; // Weighted sum of the cluster averages
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sort the suggestions based on the distance to the person
|
||||||
|
for (final suggestion in suggestions) {
|
||||||
|
final clusterID = suggestion.$1;
|
||||||
|
final faceIdToEmbeddingMap = await faceMlDb.getFaceEmbeddingMapForFile(
|
||||||
|
suggestion.$4.map((e) => e.uploadedFileID!).toList(),
|
||||||
|
);
|
||||||
|
final fileIdToDistanceMap = {};
|
||||||
|
for (final entry in faceIdToEmbeddingMap.entries) {
|
||||||
|
fileIdToDistanceMap[getFileIdFromFaceId(entry.key)] =
|
||||||
|
cosineDistForNormVectors(
|
||||||
|
personAvg,
|
||||||
|
EVector.fromBuffer(entry.value).values,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
suggestion.$4.sort((b, a) {
|
||||||
|
final double distanceA = fileIdToDistanceMap[a.uploadedFileID!];
|
||||||
|
final double distanceB = fileIdToDistanceMap[b.uploadedFileID!];
|
||||||
|
return distanceA.compareTo(distanceB);
|
||||||
|
});
|
||||||
|
|
||||||
|
debugPrint(
|
||||||
|
"[${_logger.name}] Sorted suggestions for cluster $clusterID based on distance to person: ${suggestion.$4.map((e) => fileIdToDistanceMap[e.uploadedFileID]).toList()}",
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
final endTime = DateTime.now();
|
||||||
|
_logger.info(
|
||||||
|
"Sorting suggestions based on distance to person took ${endTime.difference(startTime).inMilliseconds} ms for ${suggestions.length} suggestions",
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue