[mob] merge mobile_face to fix_face_thumbnail
This commit is contained in:
commit
a577611e65
|
@ -98,7 +98,7 @@ class FaceMLDataDB {
|
|||
}
|
||||
}
|
||||
|
||||
Future<void> updateClusterIdToFaceId(
|
||||
Future<void> updateFaceIdToClusterId(
|
||||
Map<String, int> faceIDToClusterID,
|
||||
) async {
|
||||
final db = await instance.database;
|
||||
|
@ -146,8 +146,8 @@ class FaceMLDataDB {
|
|||
}
|
||||
|
||||
Future<Map<int, int>> clusterIdToFaceCount() async {
|
||||
final db = await instance.database;
|
||||
final List<Map<String, dynamic>> maps = await db.rawQuery(
|
||||
final db = await instance.sqliteAsyncDB;
|
||||
final List<Map<String, dynamic>> maps = await db.getAll(
|
||||
'SELECT $fcClusterID, COUNT(*) as count FROM $faceClustersTable where $fcClusterID IS NOT NULL GROUP BY $fcClusterID ',
|
||||
);
|
||||
final Map<int, int> result = {};
|
||||
|
@ -158,15 +158,15 @@ class FaceMLDataDB {
|
|||
}
|
||||
|
||||
Future<Set<int>> getPersonIgnoredClusters(String personID) async {
|
||||
final db = await instance.database;
|
||||
final db = await instance.sqliteAsyncDB;
|
||||
// find out clusterIds that are assigned to other persons using the clusters table
|
||||
final List<Map<String, dynamic>> maps = await db.rawQuery(
|
||||
final List<Map<String, dynamic>> maps = await db.getAll(
|
||||
'SELECT $clusterIDColumn FROM $clusterPersonTable WHERE $personIdColumn != ? AND $personIdColumn IS NOT NULL',
|
||||
[personID],
|
||||
);
|
||||
final Set<int> ignoredClusterIDs =
|
||||
maps.map((e) => e[clusterIDColumn] as int).toSet();
|
||||
final List<Map<String, dynamic>> rejectMaps = await db.rawQuery(
|
||||
final List<Map<String, dynamic>> rejectMaps = await db.getAll(
|
||||
'SELECT $clusterIDColumn FROM $notPersonFeedback WHERE $personIdColumn = ?',
|
||||
[personID],
|
||||
);
|
||||
|
@ -176,8 +176,8 @@ class FaceMLDataDB {
|
|||
}
|
||||
|
||||
Future<Set<int>> getPersonClusterIDs(String personID) async {
|
||||
final db = await instance.database;
|
||||
final List<Map<String, dynamic>> maps = await db.rawQuery(
|
||||
final db = await instance.sqliteAsyncDB;
|
||||
final List<Map<String, dynamic>> maps = await db.getAll(
|
||||
'SELECT $clusterIDColumn FROM $clusterPersonTable WHERE $personIdColumn = ?',
|
||||
[personID],
|
||||
);
|
||||
|
@ -197,8 +197,8 @@ class FaceMLDataDB {
|
|||
int clusterID, {
|
||||
int? limit,
|
||||
}) async {
|
||||
final db = await instance.database;
|
||||
final List<Map<String, dynamic>> maps = await db.rawQuery(
|
||||
final db = await instance.sqliteAsyncDB;
|
||||
final List<Map<String, dynamic>> maps = await db.getAll(
|
||||
'SELECT $faceEmbeddingBlob FROM $facesTable WHERE $faceIDColumn in (SELECT $fcFaceId from $faceClustersTable where $fcClusterID = ?) ${limit != null ? 'LIMIT $limit' : ''}',
|
||||
[clusterID],
|
||||
);
|
||||
|
@ -209,7 +209,7 @@ class FaceMLDataDB {
|
|||
Iterable<int> clusterIDs, {
|
||||
int? limit,
|
||||
}) async {
|
||||
final db = await instance.database;
|
||||
final db = await instance.sqliteAsyncDB;
|
||||
final Map<int, List<Uint8List>> result = {};
|
||||
|
||||
final selectQuery = '''
|
||||
|
@ -220,7 +220,7 @@ class FaceMLDataDB {
|
|||
${limit != null ? 'LIMIT $limit' : ''}
|
||||
''';
|
||||
|
||||
final List<Map<String, dynamic>> maps = await db.rawQuery(selectQuery);
|
||||
final List<Map<String, dynamic>> maps = await db.getAll(selectQuery);
|
||||
|
||||
for (final map in maps) {
|
||||
final clusterID = map[fcClusterID] as int;
|
||||
|
@ -321,8 +321,8 @@ class FaceMLDataDB {
|
|||
}
|
||||
|
||||
Future<Face?> getFaceForFaceID(String faceID) async {
|
||||
final db = await instance.database;
|
||||
final result = await db.rawQuery(
|
||||
final db = await instance.sqliteAsyncDB;
|
||||
final result = await db.getAll(
|
||||
'SELECT * FROM $facesTable where $faceIDColumn = ?',
|
||||
[faceID],
|
||||
);
|
||||
|
@ -332,6 +332,36 @@ class FaceMLDataDB {
|
|||
return mapRowToFace(result.first);
|
||||
}
|
||||
|
||||
Future<Map<int, Iterable<String>>> getClusterToFaceIDs(
|
||||
Set<int> clusterIDs,
|
||||
) async {
|
||||
final db = await instance.sqliteAsyncDB;
|
||||
final Map<int, List<String>> result = {};
|
||||
final List<Map<String, dynamic>> maps = await db.getAll(
|
||||
'SELECT $fcClusterID, $fcFaceId FROM $faceClustersTable WHERE $fcClusterID IN (${clusterIDs.join(",")})',
|
||||
);
|
||||
for (final map in maps) {
|
||||
final clusterID = map[fcClusterID] as int;
|
||||
final faceID = map[fcFaceId] as String;
|
||||
result.putIfAbsent(clusterID, () => <String>[]).add(faceID);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
Future<Map<int, Iterable<String>>> getAllClusterIdToFaceIDs() async {
|
||||
final db = await instance.sqliteAsyncDB;
|
||||
final Map<int, List<String>> result = {};
|
||||
final List<Map<String, dynamic>> maps = await db.getAll(
|
||||
'SELECT $fcClusterID, $fcFaceId FROM $faceClustersTable',
|
||||
);
|
||||
for (final map in maps) {
|
||||
final clusterID = map[fcClusterID] as int;
|
||||
final faceID = map[fcFaceId] as String;
|
||||
result.putIfAbsent(clusterID, () => <String>[]).add(faceID);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
Future<Iterable<String>> getFaceIDsForCluster(int clusterID) async {
|
||||
final db = await instance.sqliteAsyncDB;
|
||||
final List<Map<String, dynamic>> maps = await db.getAll(
|
||||
|
@ -390,8 +420,8 @@ class FaceMLDataDB {
|
|||
Future<Map<String, int?>> getFaceIdsToClusterIds(
|
||||
Iterable<String> faceIds,
|
||||
) async {
|
||||
final db = await instance.database;
|
||||
final List<Map<String, dynamic>> maps = await db.rawQuery(
|
||||
final db = await instance.sqliteAsyncDB;
|
||||
final List<Map<String, dynamic>> maps = await db.getAll(
|
||||
'SELECT $fcFaceId, $fcClusterID FROM $faceClustersTable where $fcFaceId IN (${faceIds.map((id) => "'$id'").join(",")})',
|
||||
);
|
||||
final Map<String, int?> result = {};
|
||||
|
@ -403,8 +433,8 @@ class FaceMLDataDB {
|
|||
|
||||
Future<Map<int, Set<int>>> getFileIdToClusterIds() async {
|
||||
final Map<int, Set<int>> result = {};
|
||||
final db = await instance.database;
|
||||
final List<Map<String, dynamic>> maps = await db.rawQuery(
|
||||
final db = await instance.sqliteAsyncDB;
|
||||
final List<Map<String, dynamic>> maps = await db.getAll(
|
||||
'SELECT $fcClusterID, $fcFaceId FROM $faceClustersTable',
|
||||
);
|
||||
|
||||
|
@ -761,9 +791,9 @@ class FaceMLDataDB {
|
|||
|
||||
// for a given personID, return a map of clusterID to fileIDs using join query
|
||||
Future<Map<int, Set<int>>> getFileIdToClusterIDSet(String personID) {
|
||||
final db = instance.database;
|
||||
final db = instance.sqliteAsyncDB;
|
||||
return db.then((db) async {
|
||||
final List<Map<String, dynamic>> maps = await db.rawQuery(
|
||||
final List<Map<String, dynamic>> maps = await db.getAll(
|
||||
'SELECT $faceClustersTable.$fcClusterID, $fcFaceId FROM $faceClustersTable '
|
||||
'INNER JOIN $clusterPersonTable '
|
||||
'ON $faceClustersTable.$fcClusterID = $clusterPersonTable.$clusterIDColumn '
|
||||
|
@ -784,9 +814,9 @@ class FaceMLDataDB {
|
|||
Future<Map<int, Set<int>>> getFileIdToClusterIDSetForCluster(
|
||||
Set<int> clusterIDs,
|
||||
) {
|
||||
final db = instance.database;
|
||||
final db = instance.sqliteAsyncDB;
|
||||
return db.then((db) async {
|
||||
final List<Map<String, dynamic>> maps = await db.rawQuery(
|
||||
final List<Map<String, dynamic>> maps = await db.getAll(
|
||||
'SELECT $fcClusterID, $fcFaceId FROM $faceClustersTable '
|
||||
'WHERE $fcClusterID IN (${clusterIDs.join(",")})',
|
||||
);
|
||||
|
@ -846,9 +876,26 @@ class FaceMLDataDB {
|
|||
return result;
|
||||
}
|
||||
|
||||
Future<Map<int, (Uint8List, int)>> getClusterToClusterSummary(
|
||||
Iterable<int> clusterIDs,
|
||||
) async {
|
||||
final db = await instance.sqliteAsyncDB;
|
||||
final Map<int, (Uint8List, int)> result = {};
|
||||
final rows = await db.getAll(
|
||||
'SELECT * FROM $clusterSummaryTable WHERE $clusterIDColumn IN (${clusterIDs.join(",")})',
|
||||
);
|
||||
for (final r in rows) {
|
||||
final id = r[clusterIDColumn] as int;
|
||||
final avg = r[avgColumn] as Uint8List;
|
||||
final count = r[countColumn] as int;
|
||||
result[id] = (avg, count);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
Future<Map<int, String>> getClusterIDToPersonID() async {
|
||||
final db = await instance.database;
|
||||
final List<Map<String, dynamic>> maps = await db.rawQuery(
|
||||
final db = await instance.sqliteAsyncDB;
|
||||
final List<Map<String, dynamic>> maps = await db.getAll(
|
||||
'SELECT $personIdColumn, $clusterIDColumn FROM $clusterPersonTable',
|
||||
);
|
||||
final Map<int, String> result = {};
|
||||
|
|
|
@ -61,7 +61,7 @@ class EntityService {
|
|||
}) async {
|
||||
final key = await getOrCreateEntityKey(type);
|
||||
final encryptedKeyData = await CryptoUtil.encryptChaCha(
|
||||
utf8.encode(plainText) as Uint8List,
|
||||
utf8.encode(plainText),
|
||||
key,
|
||||
);
|
||||
final String encryptedData =
|
||||
|
|
|
@ -1,5 +1,18 @@
|
|||
import 'dart:math' show sqrt;
|
||||
|
||||
import "package:ml_linalg/vector.dart";
|
||||
|
||||
/// Calculates the cosine distance between two embeddings/vectors using SIMD from ml_linalg
|
||||
///
|
||||
/// WARNING: This assumes both vectors are already normalized!
|
||||
double cosineDistanceSIMD(Vector vector1, Vector vector2) {
|
||||
if (vector1.length != vector2.length) {
|
||||
throw ArgumentError('Vectors must be the same length');
|
||||
}
|
||||
|
||||
return 1 - vector1.dot(vector2);
|
||||
}
|
||||
|
||||
/// Calculates the cosine distance between two embeddings/vectors.
|
||||
///
|
||||
/// Throws an ArgumentError if the vectors are of different lengths or
|
||||
|
|
|
@ -69,7 +69,7 @@ class FaceClusteringService {
|
|||
bool isRunning = false;
|
||||
|
||||
static const kRecommendedDistanceThreshold = 0.24;
|
||||
static const kConservativeDistanceThreshold = 0.06;
|
||||
static const kConservativeDistanceThreshold = 0.16;
|
||||
|
||||
// singleton pattern
|
||||
FaceClusteringService._privateConstructor();
|
||||
|
@ -560,10 +560,10 @@ class FaceClusteringService {
|
|||
for (int j = i - 1; j >= 0; j--) {
|
||||
late double distance;
|
||||
if (sortedFaceInfos[i].vEmbedding != null) {
|
||||
distance = 1.0 -
|
||||
sortedFaceInfos[i]
|
||||
.vEmbedding!
|
||||
.dot(sortedFaceInfos[j].vEmbedding!);
|
||||
distance = cosineDistanceSIMD(
|
||||
sortedFaceInfos[i].vEmbedding!,
|
||||
sortedFaceInfos[j].vEmbedding!,
|
||||
);
|
||||
} else {
|
||||
distance = cosineDistForNormVectors(
|
||||
sortedFaceInfos[i].embedding!,
|
||||
|
@ -804,8 +804,10 @@ class FaceClusteringService {
|
|||
double closestDistance = double.infinity;
|
||||
for (int j = 0; j < totalFaces; j++) {
|
||||
if (i == j) continue;
|
||||
final double distance =
|
||||
1.0 - faceInfos[i].vEmbedding!.dot(faceInfos[j].vEmbedding!);
|
||||
final double distance = cosineDistanceSIMD(
|
||||
faceInfos[i].vEmbedding!,
|
||||
faceInfos[j].vEmbedding!,
|
||||
);
|
||||
if (distance < closestDistance) {
|
||||
closestDistance = distance;
|
||||
closestIdx = j;
|
||||
|
@ -855,10 +857,10 @@ class FaceClusteringService {
|
|||
for (int i = 0; i < clusterIds.length; i++) {
|
||||
for (int j = 0; j < clusterIds.length; j++) {
|
||||
if (i == j) continue;
|
||||
final double newDistance = 1.0 -
|
||||
clusterIdToMeanEmbeddingAndWeight[clusterIds[i]]!.$1.dot(
|
||||
clusterIdToMeanEmbeddingAndWeight[clusterIds[j]]!.$1,
|
||||
);
|
||||
final double newDistance = cosineDistanceSIMD(
|
||||
clusterIdToMeanEmbeddingAndWeight[clusterIds[i]]!.$1,
|
||||
clusterIdToMeanEmbeddingAndWeight[clusterIds[j]]!.$1,
|
||||
);
|
||||
if (newDistance < distance) {
|
||||
distance = newDistance;
|
||||
clusterIDsToMerge = (clusterIds[i], clusterIds[j]);
|
||||
|
@ -959,9 +961,9 @@ class FaceClusteringService {
|
|||
|
||||
// Run the DBSCAN clustering
|
||||
final List<List<int>> clusterOutput = dbscan.run(embeddings);
|
||||
final List<List<FaceInfo>> clusteredFaceInfos = clusterOutput
|
||||
.map((cluster) => cluster.map((idx) => faceInfos[idx]).toList())
|
||||
.toList();
|
||||
// final List<List<FaceInfo>> clusteredFaceInfos = clusterOutput
|
||||
// .map((cluster) => cluster.map((idx) => faceInfos[idx]).toList())
|
||||
// .toList();
|
||||
final List<List<String>> clusteredFaceIDs = clusterOutput
|
||||
.map((cluster) => cluster.map((idx) => faceInfos[idx].faceID).toList())
|
||||
.toList();
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
import 'package:photos/services/machine_learning/face_ml/face_detection/face_detection_service.dart';
|
||||
|
||||
/// Blur detection threshold
|
||||
const kLaplacianHardThreshold = 15;
|
||||
const kLaplacianSoftThreshold = 100;
|
||||
const kLaplacianHardThreshold = 10;
|
||||
const kLaplacianSoftThreshold = 50;
|
||||
const kLaplacianVerySoftThreshold = 200;
|
||||
|
||||
/// Default blur value
|
||||
|
|
|
@ -350,7 +350,7 @@ class FaceMlService {
|
|||
}
|
||||
|
||||
await FaceMLDataDB.instance
|
||||
.updateClusterIdToFaceId(clusteringResult.newFaceIdToCluster);
|
||||
.updateFaceIdToClusterId(clusteringResult.newFaceIdToCluster);
|
||||
await FaceMLDataDB.instance
|
||||
.clusterSummaryUpdate(clusteringResult.newClusterSummaries!);
|
||||
_logger.info(
|
||||
|
@ -403,7 +403,7 @@ class FaceMlService {
|
|||
'Updating ${clusteringResult.newFaceIdToCluster.length} FaceIDs with clusterIDs in the DB',
|
||||
);
|
||||
await FaceMLDataDB.instance
|
||||
.updateClusterIdToFaceId(clusteringResult.newFaceIdToCluster);
|
||||
.updateFaceIdToClusterId(clusteringResult.newFaceIdToCluster);
|
||||
await FaceMLDataDB.instance
|
||||
.clusterSummaryUpdate(clusteringResult.newClusterSummaries!);
|
||||
_logger.info('Done updating FaceIDs with clusterIDs in the DB, in '
|
||||
|
|
|
@ -1,19 +1,18 @@
|
|||
import 'dart:developer' as dev;
|
||||
import "dart:math" show Random;
|
||||
import "dart:math" show Random, min;
|
||||
|
||||
import "package:flutter/foundation.dart";
|
||||
import "package:logging/logging.dart";
|
||||
import "package:ml_linalg/linalg.dart";
|
||||
import "package:photos/core/event_bus.dart";
|
||||
import "package:photos/db/files_db.dart";
|
||||
// import "package:photos/events/files_updated_event.dart";
|
||||
// import "package:photos/events/local_photos_updated_event.dart";
|
||||
import "package:photos/events/people_changed_event.dart";
|
||||
import "package:photos/extensions/stop_watch.dart";
|
||||
import "package:photos/face/db.dart";
|
||||
import "package:photos/face/model/person.dart";
|
||||
import "package:photos/generated/protos/ente/common/vector.pb.dart";
|
||||
import "package:photos/models/file/file.dart";
|
||||
import 'package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart';
|
||||
import "package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart";
|
||||
import "package:photos/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart";
|
||||
import "package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart";
|
||||
import "package:photos/services/machine_learning/face_ml/face_ml_result.dart";
|
||||
|
@ -25,12 +24,14 @@ class ClusterSuggestion {
|
|||
final double distancePersonToCluster;
|
||||
final bool usedOnlyMeanForSuggestion;
|
||||
final List<EnteFile> filesInCluster;
|
||||
final List<String> faceIDsInCluster;
|
||||
|
||||
ClusterSuggestion(
|
||||
this.clusterIDToMerge,
|
||||
this.distancePersonToCluster,
|
||||
this.usedOnlyMeanForSuggestion,
|
||||
this.filesInCluster,
|
||||
this.faceIDsInCluster,
|
||||
);
|
||||
}
|
||||
|
||||
|
@ -60,19 +61,27 @@ class ClusterFeedbackService {
|
|||
bool extremeFilesFirst = true,
|
||||
}) async {
|
||||
_logger.info(
|
||||
'getClusterFilesForPersonID ${kDebugMode ? person.data.name : person.remoteID}',
|
||||
'getSuggestionForPerson ${kDebugMode ? person.data.name : person.remoteID}',
|
||||
);
|
||||
|
||||
try {
|
||||
// Get the suggestions for the person using centroids and median
|
||||
final List<(int, double, bool)> suggestClusterIds =
|
||||
final startTime = DateTime.now();
|
||||
final List<(int, double, bool)> foundSuggestions =
|
||||
await _getSuggestions(person);
|
||||
final findSuggestionsTime = DateTime.now();
|
||||
_logger.info(
|
||||
'getSuggestionForPerson `_getSuggestions`: Found ${foundSuggestions.length} suggestions in ${findSuggestionsTime.difference(startTime).inMilliseconds} ms',
|
||||
);
|
||||
|
||||
// Get the files for the suggestions
|
||||
final suggestionClusterIDs = foundSuggestions.map((e) => e.$1).toSet();
|
||||
final Map<int, Set<int>> fileIdToClusterID =
|
||||
await FaceMLDataDB.instance.getFileIdToClusterIDSetForCluster(
|
||||
suggestClusterIds.map((e) => e.$1).toSet(),
|
||||
suggestionClusterIDs,
|
||||
);
|
||||
final clusterIdToFaceIDs =
|
||||
await FaceMLDataDB.instance.getClusterToFaceIDs(suggestionClusterIDs);
|
||||
final Map<int, List<EnteFile>> clusterIDToFiles = {};
|
||||
final allFiles = await SearchService.instance.getAllFiles();
|
||||
for (final f in allFiles) {
|
||||
|
@ -89,25 +98,31 @@ class ClusterFeedbackService {
|
|||
}
|
||||
}
|
||||
|
||||
final List<ClusterSuggestion> clusterIdAndFiles = [];
|
||||
for (final clusterSuggestion in suggestClusterIds) {
|
||||
final List<ClusterSuggestion> finalSuggestions = [];
|
||||
for (final clusterSuggestion in foundSuggestions) {
|
||||
if (clusterIDToFiles.containsKey(clusterSuggestion.$1)) {
|
||||
clusterIdAndFiles.add(
|
||||
finalSuggestions.add(
|
||||
ClusterSuggestion(
|
||||
clusterSuggestion.$1,
|
||||
clusterSuggestion.$2,
|
||||
clusterSuggestion.$3,
|
||||
clusterIDToFiles[clusterSuggestion.$1]!,
|
||||
clusterIdToFaceIDs[clusterSuggestion.$1]!.toList(),
|
||||
),
|
||||
);
|
||||
}
|
||||
}
|
||||
final getFilesTime = DateTime.now();
|
||||
|
||||
final sortingStartTime = DateTime.now();
|
||||
if (extremeFilesFirst) {
|
||||
await _sortSuggestionsOnDistanceToPerson(person, clusterIdAndFiles);
|
||||
await _sortSuggestionsOnDistanceToPerson(person, finalSuggestions);
|
||||
}
|
||||
_logger.info(
|
||||
'getSuggestionForPerson post-processing suggestions took ${DateTime.now().difference(findSuggestionsTime).inMilliseconds} ms, of which sorting took ${DateTime.now().difference(sortingStartTime).inMilliseconds} ms and getting files took ${getFilesTime.difference(findSuggestionsTime).inMilliseconds} ms',
|
||||
);
|
||||
|
||||
return clusterIdAndFiles;
|
||||
return finalSuggestions;
|
||||
} catch (e, s) {
|
||||
_logger.severe("Error in getClusterFilesForPersonID", e, s);
|
||||
rethrow;
|
||||
|
@ -229,13 +244,13 @@ class ClusterFeedbackService {
|
|||
final ignoredClusters = await faceMlDb.getPersonIgnoredClusters(p.remoteID);
|
||||
final personClusters = await faceMlDb.getPersonClusterIDs(p.remoteID);
|
||||
dev.log(
|
||||
'existing clusters for ${p.data.name} are $personClusters',
|
||||
'${p.data.name} has ${personClusters.length} existing clusters',
|
||||
name: "ClusterFeedbackService",
|
||||
);
|
||||
|
||||
// Get and update the cluster summary to get the avg (centroid) and count
|
||||
final EnteWatch watch = EnteWatch("ClusterFeedbackService")..start();
|
||||
final Map<int, List<double>> clusterAvg = await _getUpdateClusterAvg(
|
||||
final Map<int, Vector> clusterAvg = await _getUpdateClusterAvg(
|
||||
allClusterIdsToCountMap,
|
||||
ignoredClusters,
|
||||
);
|
||||
|
@ -397,7 +412,7 @@ class ClusterFeedbackService {
|
|||
final newClusterID = startClusterID + blurValue ~/ 10;
|
||||
faceIdToCluster[faceID] = newClusterID;
|
||||
}
|
||||
await FaceMLDataDB.instance.updateClusterIdToFaceId(faceIdToCluster);
|
||||
await FaceMLDataDB.instance.updateFaceIdToClusterId(faceIdToCluster);
|
||||
|
||||
Bus.instance.fire(PeopleChangedEvent());
|
||||
} catch (e, s) {
|
||||
|
@ -437,69 +452,81 @@ class ClusterFeedbackService {
|
|||
Future<List<(int, double, bool)>> _getSuggestions(
|
||||
PersonEntity p, {
|
||||
int sampleSize = 50,
|
||||
double maxMedianDistance = 0.65,
|
||||
double maxMedianDistance = 0.62,
|
||||
double goodMedianDistance = 0.55,
|
||||
double maxMeanDistance = 0.65,
|
||||
double goodMeanDistance = 0.5,
|
||||
double goodMeanDistance = 0.50,
|
||||
}) async {
|
||||
final w = (kDebugMode ? EnteWatch('getSuggestions') : null)?..start();
|
||||
// Get all the cluster data
|
||||
final startTime = DateTime.now();
|
||||
final faceMlDb = FaceMLDataDB.instance;
|
||||
// final Map<int, List<(int, double)>> suggestions = {};
|
||||
final allClusterIdsToCountMap = await faceMlDb.clusterIdToFaceCount();
|
||||
final ignoredClusters = await faceMlDb.getPersonIgnoredClusters(p.remoteID);
|
||||
final personClusters = await faceMlDb.getPersonClusterIDs(p.remoteID);
|
||||
dev.log(
|
||||
'existing clusters for ${p.data.name} are $personClusters, getting all database data took ${DateTime.now().difference(startTime).inMilliseconds} ms',
|
||||
name: "getSuggestionsUsingMedian",
|
||||
final personFaceIDs =
|
||||
await FaceMLDataDB.instance.getFaceIDsForPerson(p.remoteID);
|
||||
final personFileIDs = personFaceIDs.map(getFileIdFromFaceId).toSet();
|
||||
w?.log(
|
||||
'${p.data.name} has ${personClusters.length} existing clusters, getting all database data done',
|
||||
);
|
||||
final allClusterIdToFaceIDs =
|
||||
await FaceMLDataDB.instance.getAllClusterIdToFaceIDs();
|
||||
w?.log('getAllClusterIdToFaceIDs done');
|
||||
|
||||
// First only do a simple check on the big clusters
|
||||
final EnteWatch watch = EnteWatch("ClusterFeedbackService")..start();
|
||||
final Map<int, List<double>> clusterAvgBigClusters =
|
||||
await _getUpdateClusterAvg(
|
||||
allClusterIdsToCountMap,
|
||||
ignoredClusters,
|
||||
minClusterSize: kMinimumClusterSizeSearchResult,
|
||||
);
|
||||
dev.log(
|
||||
'computed avg for ${clusterAvgBigClusters.length} clusters, in ${DateTime.now().difference(startTime).inMilliseconds} ms',
|
||||
);
|
||||
final List<(int, double)> suggestionsMeanBigClusters = _calcSuggestionsMean(
|
||||
clusterAvgBigClusters,
|
||||
personClusters,
|
||||
ignoredClusters,
|
||||
goodMeanDistance,
|
||||
);
|
||||
if (suggestionsMeanBigClusters.isNotEmpty) {
|
||||
return suggestionsMeanBigClusters
|
||||
.map((e) => (e.$1, e.$2, true))
|
||||
.toList(growable: false);
|
||||
}
|
||||
|
||||
// Get and update the cluster summary to get the avg (centroid) and count
|
||||
final Map<int, List<double>> clusterAvg = await _getUpdateClusterAvg(
|
||||
allClusterIdsToCountMap,
|
||||
ignoredClusters,
|
||||
);
|
||||
dev.log(
|
||||
'computed avg for ${clusterAvg.length} clusters, in ${DateTime.now().difference(startTime).inMilliseconds} ms',
|
||||
);
|
||||
|
||||
// Find the other cluster candidates based on the mean
|
||||
final List<(int, double)> suggestionsMean = _calcSuggestionsMean(
|
||||
clusterAvg,
|
||||
personClusters,
|
||||
ignoredClusters,
|
||||
goodMeanDistance,
|
||||
);
|
||||
if (suggestionsMean.isNotEmpty) {
|
||||
return suggestionsMean
|
||||
.map((e) => (e.$1, e.$2, true))
|
||||
.toList(growable: false);
|
||||
// First only do a simple check on the big clusters, if the person does not have small clusters yet
|
||||
final smallestPersonClusterSize = personClusters
|
||||
.map((clusterID) => allClusterIdsToCountMap[clusterID] ?? 0)
|
||||
.reduce((value, element) => min(value, element));
|
||||
final checkSizes = [20, kMinimumClusterSizeSearchResult, 10, 5, 1];
|
||||
late Map<int, Vector> clusterAvgBigClusters;
|
||||
final List<(int, double)> suggestionsMean = [];
|
||||
for (final minimumSize in checkSizes.toSet()) {
|
||||
// if (smallestPersonClusterSize >= minimumSize) {
|
||||
clusterAvgBigClusters = await _getUpdateClusterAvg(
|
||||
allClusterIdsToCountMap,
|
||||
ignoredClusters,
|
||||
minClusterSize: minimumSize,
|
||||
);
|
||||
w?.log(
|
||||
'Calculate avg for ${clusterAvgBigClusters.length} clusters of min size $minimumSize',
|
||||
);
|
||||
final List<(int, double)> suggestionsMeanBigClusters =
|
||||
_calcSuggestionsMean(
|
||||
clusterAvgBigClusters,
|
||||
personClusters,
|
||||
ignoredClusters,
|
||||
goodMeanDistance,
|
||||
);
|
||||
w?.log(
|
||||
'Calculate suggestions using mean for ${clusterAvgBigClusters.length} clusters of min size $minimumSize',
|
||||
);
|
||||
for (final suggestion in suggestionsMeanBigClusters) {
|
||||
// Skip suggestions that have a high overlap with the person's files
|
||||
final suggestionSet = allClusterIdToFaceIDs[suggestion.$1]!
|
||||
.map((faceID) => getFileIdFromFaceId(faceID))
|
||||
.toSet();
|
||||
final overlap = personFileIDs.intersection(suggestionSet);
|
||||
if (overlap.isNotEmpty &&
|
||||
((overlap.length / suggestionSet.length) > 0.5)) {
|
||||
await FaceMLDataDB.instance.captureNotPersonFeedback(
|
||||
personID: p.remoteID,
|
||||
clusterID: suggestion.$1,
|
||||
);
|
||||
continue;
|
||||
}
|
||||
suggestionsMean.add(suggestion);
|
||||
}
|
||||
if (suggestionsMean.isNotEmpty) {
|
||||
return suggestionsMean
|
||||
.map((e) => (e.$1, e.$2, true))
|
||||
.toList(growable: false);
|
||||
// }
|
||||
}
|
||||
}
|
||||
w?.reset();
|
||||
|
||||
// Find the other cluster candidates based on the median
|
||||
final clusterAvg = clusterAvgBigClusters;
|
||||
final List<(int, double)> moreSuggestionsMean = _calcSuggestionsMean(
|
||||
clusterAvg,
|
||||
personClusters,
|
||||
|
@ -522,21 +549,26 @@ class ClusterFeedbackService {
|
|||
"Found potential suggestions from loose mean for median test: $otherClusterIdsCandidates",
|
||||
);
|
||||
|
||||
watch.logAndReset("Starting median test");
|
||||
w?.logAndReset("Starting median test");
|
||||
// Take the embeddings from the person's clusters in one big list and sample from it
|
||||
final List<Uint8List> personEmbeddingsProto = [];
|
||||
for (final clusterID in personClusters) {
|
||||
final Iterable<Uint8List> embedings =
|
||||
final Iterable<Uint8List> embeddings =
|
||||
await FaceMLDataDB.instance.getFaceEmbeddingsForCluster(clusterID);
|
||||
personEmbeddingsProto.addAll(embedings);
|
||||
personEmbeddingsProto.addAll(embeddings);
|
||||
}
|
||||
final List<Uint8List> sampledEmbeddingsProto =
|
||||
_randomSampleWithoutReplacement(
|
||||
personEmbeddingsProto,
|
||||
sampleSize,
|
||||
);
|
||||
final List<List<double>> sampledEmbeddings = sampledEmbeddingsProto
|
||||
.map((embedding) => EVector.fromBuffer(embedding).values)
|
||||
final List<Vector> sampledEmbeddings = sampledEmbeddingsProto
|
||||
.map(
|
||||
(embedding) => Vector.fromList(
|
||||
EVector.fromBuffer(embedding).values,
|
||||
dtype: DType.float32,
|
||||
),
|
||||
)
|
||||
.toList(growable: false);
|
||||
|
||||
// Find the actual closest clusters for the person using median
|
||||
|
@ -552,16 +584,20 @@ class ClusterFeedbackService {
|
|||
otherEmbeddingsProto,
|
||||
sampleSize,
|
||||
);
|
||||
final List<List<double>> sampledOtherEmbeddings =
|
||||
sampledOtherEmbeddingsProto
|
||||
.map((embedding) => EVector.fromBuffer(embedding).values)
|
||||
.toList(growable: false);
|
||||
final List<Vector> sampledOtherEmbeddings = sampledOtherEmbeddingsProto
|
||||
.map(
|
||||
(embedding) => Vector.fromList(
|
||||
EVector.fromBuffer(embedding).values,
|
||||
dtype: DType.float32,
|
||||
),
|
||||
)
|
||||
.toList(growable: false);
|
||||
|
||||
// Calculate distances and find the median
|
||||
final List<double> distances = [];
|
||||
for (final otherEmbedding in sampledOtherEmbeddings) {
|
||||
for (final embedding in sampledEmbeddings) {
|
||||
distances.add(cosineDistForNormVectors(embedding, otherEmbedding));
|
||||
distances.add(cosineDistanceSIMD(embedding, otherEmbedding));
|
||||
}
|
||||
}
|
||||
distances.sort();
|
||||
|
@ -575,7 +611,7 @@ class ClusterFeedbackService {
|
|||
}
|
||||
}
|
||||
}
|
||||
watch.log("Finished median test");
|
||||
w?.log("Finished median test");
|
||||
if (suggestionsMedian.isEmpty) {
|
||||
_logger.info("No suggestions found using median");
|
||||
return [];
|
||||
|
@ -607,13 +643,14 @@ class ClusterFeedbackService {
|
|||
return finalSuggestionsMedian;
|
||||
}
|
||||
|
||||
Future<Map<int, List<double>>> _getUpdateClusterAvg(
|
||||
Future<Map<int, Vector>> _getUpdateClusterAvg(
|
||||
Map<int, int> allClusterIdsToCountMap,
|
||||
Set<int> ignoredClusters, {
|
||||
int minClusterSize = 1,
|
||||
int maxClusterInCurrentRun = 500,
|
||||
int maxEmbeddingToRead = 10000,
|
||||
}) async {
|
||||
final w = (kDebugMode ? EnteWatch('_getUpdateClusterAvg') : null)?..start();
|
||||
final startTime = DateTime.now();
|
||||
final faceMlDb = FaceMLDataDB.instance;
|
||||
_logger.info(
|
||||
|
@ -624,16 +661,15 @@ class ClusterFeedbackService {
|
|||
await faceMlDb.getAllClusterSummary(minClusterSize);
|
||||
final Map<int, (Uint8List, int)> updatesForClusterSummary = {};
|
||||
|
||||
final Map<int, List<double>> clusterAvg = {};
|
||||
final Map<int, Vector> clusterAvg = {};
|
||||
|
||||
dev.log(
|
||||
'getUpdateClusterAvg database call for getAllClusterSummary took ${DateTime.now().difference(startTime).inMilliseconds} ms',
|
||||
w?.log(
|
||||
'getUpdateClusterAvg database call for getAllClusterSummary',
|
||||
);
|
||||
|
||||
final allClusterIds = allClusterIdsToCountMap.keys.toSet();
|
||||
int ignoredClustersCnt = 0, alreadyUpdatedClustersCnt = 0;
|
||||
int smallerClustersCnt = 0;
|
||||
final serializationTime = DateTime.now();
|
||||
for (final id in allClusterIdsToCountMap.keys) {
|
||||
if (ignoredClusters.contains(id)) {
|
||||
allClusterIds.remove(id);
|
||||
|
@ -641,7 +677,10 @@ class ClusterFeedbackService {
|
|||
}
|
||||
if (clusterToSummary[id]?.$2 == allClusterIdsToCountMap[id]) {
|
||||
allClusterIds.remove(id);
|
||||
clusterAvg[id] = EVector.fromBuffer(clusterToSummary[id]!.$1).values;
|
||||
clusterAvg[id] = Vector.fromList(
|
||||
EVector.fromBuffer(clusterToSummary[id]!.$1).values,
|
||||
dtype: DType.float32,
|
||||
);
|
||||
alreadyUpdatedClustersCnt++;
|
||||
}
|
||||
if (allClusterIdsToCountMap[id]! < minClusterSize) {
|
||||
|
@ -649,8 +688,8 @@ class ClusterFeedbackService {
|
|||
smallerClustersCnt++;
|
||||
}
|
||||
}
|
||||
dev.log(
|
||||
'serialization of embeddings took ${DateTime.now().difference(serializationTime).inMilliseconds} ms',
|
||||
w?.log(
|
||||
'serialization of embeddings',
|
||||
);
|
||||
_logger.info(
|
||||
'Ignored $ignoredClustersCnt clusters, already updated $alreadyUpdatedClustersCnt clusters, $smallerClustersCnt clusters are smaller than $minClusterSize',
|
||||
|
@ -670,12 +709,7 @@ class ClusterFeedbackService {
|
|||
allClusterIdsToCountMap[b]!.compareTo(allClusterIdsToCountMap[a]!),
|
||||
);
|
||||
int indexedInCurrentRun = 0;
|
||||
final EnteWatch? w = kDebugMode ? EnteWatch("computeAvg") : null;
|
||||
w?.start();
|
||||
|
||||
w?.log(
|
||||
'reading embeddings for $maxClusterInCurrentRun or ${sortedClusterIDs.length} clusters',
|
||||
);
|
||||
w?.reset();
|
||||
|
||||
int currentPendingRead = 0;
|
||||
final List<int> clusterIdsToRead = [];
|
||||
|
@ -706,19 +740,17 @@ class ClusterFeedbackService {
|
|||
);
|
||||
|
||||
for (final clusterID in clusterEmbeddings.keys) {
|
||||
late List<double> avg;
|
||||
final Iterable<Uint8List> embedings = clusterEmbeddings[clusterID]!;
|
||||
final List<double> sum = List.filled(192, 0);
|
||||
for (final embedding in embedings) {
|
||||
final data = EVector.fromBuffer(embedding).values;
|
||||
for (int i = 0; i < sum.length; i++) {
|
||||
sum[i] += data[i];
|
||||
}
|
||||
}
|
||||
avg = sum.map((e) => e / embedings.length).toList();
|
||||
final avgEmbeedingBuffer = EVector(values: avg).writeToBuffer();
|
||||
final Iterable<Uint8List> embeddings = clusterEmbeddings[clusterID]!;
|
||||
final Iterable<Vector> vectors = embeddings.map(
|
||||
(e) => Vector.fromList(
|
||||
EVector.fromBuffer(e).values,
|
||||
dtype: DType.float32,
|
||||
),
|
||||
);
|
||||
final avg = vectors.reduce((a, b) => a + b) / vectors.length;
|
||||
final avgEmbeddingBuffer = EVector(values: avg).writeToBuffer();
|
||||
updatesForClusterSummary[clusterID] =
|
||||
(avgEmbeedingBuffer, embedings.length);
|
||||
(avgEmbeddingBuffer, embeddings.length);
|
||||
// store the intermediate updates
|
||||
indexedInCurrentRun++;
|
||||
if (updatesForClusterSummary.length > 100) {
|
||||
|
@ -745,20 +777,22 @@ class ClusterFeedbackService {
|
|||
|
||||
/// Returns a map of person's clusterID to map of closest clusterID to with disstance
|
||||
List<(int, double)> _calcSuggestionsMean(
|
||||
Map<int, List<double>> clusterAvg,
|
||||
Map<int, Vector> clusterAvg,
|
||||
Set<int> personClusters,
|
||||
Set<int> ignoredClusters,
|
||||
double maxClusterDistance, {
|
||||
Map<int, int>? allClusterIdsToCountMap,
|
||||
}) {
|
||||
final Map<int, List<(int, double)>> suggestions = {};
|
||||
int suggestionCount = 0;
|
||||
final w = (kDebugMode ? EnteWatch('getSuggestions') : null)?..start();
|
||||
for (final otherClusterID in clusterAvg.keys) {
|
||||
// ignore the cluster that belong to the person or is ignored
|
||||
if (personClusters.contains(otherClusterID) ||
|
||||
ignoredClusters.contains(otherClusterID)) {
|
||||
continue;
|
||||
}
|
||||
final otherAvg = clusterAvg[otherClusterID]!;
|
||||
final Vector otherAvg = clusterAvg[otherClusterID]!;
|
||||
int? nearestPersonCluster;
|
||||
double? minDistance;
|
||||
for (final personCluster in personClusters) {
|
||||
|
@ -766,8 +800,8 @@ class ClusterFeedbackService {
|
|||
_logger.info('no avg for cluster $personCluster');
|
||||
continue;
|
||||
}
|
||||
final avg = clusterAvg[personCluster]!;
|
||||
final distance = cosineDistForNormVectors(avg, otherAvg);
|
||||
final Vector avg = clusterAvg[personCluster]!;
|
||||
final distance = cosineDistanceSIMD(avg, otherAvg);
|
||||
if (distance < maxClusterDistance) {
|
||||
if (minDistance == null || distance < minDistance) {
|
||||
minDistance = distance;
|
||||
|
@ -779,30 +813,35 @@ class ClusterFeedbackService {
|
|||
suggestions
|
||||
.putIfAbsent(nearestPersonCluster, () => [])
|
||||
.add((otherClusterID, minDistance));
|
||||
suggestionCount++;
|
||||
}
|
||||
if (suggestionCount >= 2000) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
w?.log('calculation inside calcSuggestionsMean');
|
||||
|
||||
if (suggestions.isNotEmpty) {
|
||||
final List<(int, double)> suggestClusterIds = [];
|
||||
for (final List<(int, double)> suggestion in suggestions.values) {
|
||||
suggestClusterIds.addAll(suggestion);
|
||||
}
|
||||
List<int>? suggestClusterIdsSizes;
|
||||
if (allClusterIdsToCountMap != null) {
|
||||
suggestClusterIds.sort(
|
||||
(a, b) => allClusterIdsToCountMap[b.$1]!
|
||||
.compareTo(allClusterIdsToCountMap[a.$1]!),
|
||||
);
|
||||
suggestClusterIdsSizes = suggestClusterIds
|
||||
.map((e) => allClusterIdsToCountMap[e.$1]!)
|
||||
.toList(growable: false);
|
||||
}
|
||||
final suggestClusterIdsDistances =
|
||||
suggestClusterIds.map((e) => e.$2).toList(growable: false);
|
||||
suggestClusterIds.sort(
|
||||
(a, b) => a.$2.compareTo(b.$2),
|
||||
); // sort by distance
|
||||
|
||||
// List<int>? suggestClusterIdsSizes;
|
||||
// if (allClusterIdsToCountMap != null) {
|
||||
// suggestClusterIdsSizes = suggestClusterIds
|
||||
// .map((e) => allClusterIdsToCountMap[e.$1]!)
|
||||
// .toList(growable: false);
|
||||
// }
|
||||
// final suggestClusterIdsDistances =
|
||||
// suggestClusterIds.map((e) => e.$2).toList(growable: false);
|
||||
_logger.info(
|
||||
"Already found good suggestions using mean: $suggestClusterIds, ${suggestClusterIdsSizes != null ? 'with sizes $suggestClusterIdsSizes' : ''} and distances $suggestClusterIdsDistances",
|
||||
"Already found ${suggestClusterIds.length} good suggestions using mean",
|
||||
);
|
||||
return suggestClusterIds;
|
||||
return suggestClusterIds.sublist(0, min(suggestClusterIds.length, 20));
|
||||
} else {
|
||||
_logger.info("No suggestions found using mean");
|
||||
return <(int, double)>[];
|
||||
|
@ -841,56 +880,88 @@ class ClusterFeedbackService {
|
|||
|
||||
Future<void> _sortSuggestionsOnDistanceToPerson(
|
||||
PersonEntity person,
|
||||
List<ClusterSuggestion> suggestions,
|
||||
) async {
|
||||
List<ClusterSuggestion> suggestions, {
|
||||
bool onlySortBigSuggestions = true,
|
||||
}) async {
|
||||
if (suggestions.isEmpty) {
|
||||
debugPrint('No suggestions to sort');
|
||||
return;
|
||||
}
|
||||
if (onlySortBigSuggestions) {
|
||||
final bigSuggestions = suggestions
|
||||
.where(
|
||||
(s) => s.filesInCluster.length > kMinimumClusterSizeSearchResult,
|
||||
)
|
||||
.toList();
|
||||
if (bigSuggestions.isEmpty) {
|
||||
debugPrint('No big suggestions to sort');
|
||||
return;
|
||||
}
|
||||
}
|
||||
final startTime = DateTime.now();
|
||||
final faceMlDb = FaceMLDataDB.instance;
|
||||
|
||||
// Get the cluster averages for the person's clusters and the suggestions' clusters
|
||||
final Map<int, (Uint8List, int)> clusterToSummary =
|
||||
await faceMlDb.getAllClusterSummary();
|
||||
final personClusters = await faceMlDb.getPersonClusterIDs(person.remoteID);
|
||||
final Map<int, (Uint8List, int)> personClusterToSummary =
|
||||
await faceMlDb.getClusterToClusterSummary(personClusters);
|
||||
final clusterSummaryCallTime = DateTime.now();
|
||||
|
||||
// Calculate the avg embedding of the person
|
||||
final personClusters = await faceMlDb.getPersonClusterIDs(person.remoteID);
|
||||
final w = (kDebugMode ? EnteWatch('sortSuggestions') : null)?..start();
|
||||
final personEmbeddingsCount = personClusters
|
||||
.map((e) => clusterToSummary[e]!.$2)
|
||||
.map((e) => personClusterToSummary[e]!.$2)
|
||||
.reduce((a, b) => a + b);
|
||||
final List<double> personAvg = List.filled(192, 0);
|
||||
Vector personAvg = Vector.filled(192, 0);
|
||||
for (final personClusterID in personClusters) {
|
||||
final personClusterBlob = clusterToSummary[personClusterID]!.$1;
|
||||
final personClusterAvg = EVector.fromBuffer(personClusterBlob).values;
|
||||
final personClusterBlob = personClusterToSummary[personClusterID]!.$1;
|
||||
final personClusterAvg = Vector.fromList(
|
||||
EVector.fromBuffer(personClusterBlob).values,
|
||||
dtype: DType.float32,
|
||||
);
|
||||
final clusterWeight =
|
||||
clusterToSummary[personClusterID]!.$2 / personEmbeddingsCount;
|
||||
for (int i = 0; i < personClusterAvg.length; i++) {
|
||||
personAvg[i] += personClusterAvg[i] *
|
||||
clusterWeight; // Weighted sum of the cluster averages
|
||||
}
|
||||
personClusterToSummary[personClusterID]!.$2 / personEmbeddingsCount;
|
||||
personAvg += personClusterAvg * clusterWeight;
|
||||
}
|
||||
w?.log('calculated person avg');
|
||||
|
||||
// Sort the suggestions based on the distance to the person
|
||||
for (final suggestion in suggestions) {
|
||||
if (onlySortBigSuggestions) {
|
||||
if (suggestion.filesInCluster.length <= 8) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
final clusterID = suggestion.clusterIDToMerge;
|
||||
final faceIdToEmbeddingMap = await faceMlDb.getFaceEmbeddingMapForFile(
|
||||
suggestion.filesInCluster.map((e) => e.uploadedFileID!).toList(),
|
||||
final faceIDs = suggestion.faceIDsInCluster;
|
||||
final faceIdToEmbeddingMap = await faceMlDb.getFaceEmbeddingMapForFaces(
|
||||
faceIDs,
|
||||
);
|
||||
final faceIdToVectorMap = faceIdToEmbeddingMap.map(
|
||||
(key, value) => MapEntry(
|
||||
key,
|
||||
Vector.fromList(
|
||||
EVector.fromBuffer(value).values,
|
||||
dtype: DType.float32,
|
||||
),
|
||||
),
|
||||
);
|
||||
w?.log(
|
||||
'got ${faceIdToEmbeddingMap.values.length} embeddings for ${suggestion.filesInCluster.length} files for cluster $clusterID',
|
||||
);
|
||||
final fileIdToDistanceMap = {};
|
||||
for (final entry in faceIdToEmbeddingMap.entries) {
|
||||
for (final entry in faceIdToVectorMap.entries) {
|
||||
fileIdToDistanceMap[getFileIdFromFaceId(entry.key)] =
|
||||
cosineDistForNormVectors(
|
||||
personAvg,
|
||||
EVector.fromBuffer(entry.value).values,
|
||||
);
|
||||
cosineDistanceSIMD(personAvg, entry.value);
|
||||
}
|
||||
w?.log('calculated distances for cluster $clusterID');
|
||||
suggestion.filesInCluster.sort((b, a) {
|
||||
//todo: review with @laurens, added this to avoid null safety issue
|
||||
final double distanceA = fileIdToDistanceMap[a.uploadedFileID!] ?? -1;
|
||||
final double distanceB = fileIdToDistanceMap[b.uploadedFileID!] ?? -1;
|
||||
return distanceA.compareTo(distanceB);
|
||||
});
|
||||
w?.log('sorted files for cluster $clusterID');
|
||||
|
||||
debugPrint(
|
||||
"[${_logger.name}] Sorted suggestions for cluster $clusterID based on distance to person: ${suggestion.filesInCluster.map((e) => fileIdToDistanceMap[e.uploadedFileID]).toList()}",
|
||||
|
@ -899,7 +970,7 @@ class ClusterFeedbackService {
|
|||
|
||||
final endTime = DateTime.now();
|
||||
_logger.info(
|
||||
"Sorting suggestions based on distance to person took ${endTime.difference(startTime).inMilliseconds} ms for ${suggestions.length} suggestions",
|
||||
"Sorting suggestions based on distance to person took ${endTime.difference(startTime).inMilliseconds} ms for ${suggestions.length} suggestions, of which ${clusterSummaryCallTime.difference(startTime).inMilliseconds} ms was spent on the cluster summary call",
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import "dart:async" show unawaited;
|
||||
import "dart:convert";
|
||||
|
||||
import "package:flutter/foundation.dart";
|
||||
|
@ -102,10 +103,12 @@ class PersonService {
|
|||
faces: faceIds.toSet(),
|
||||
);
|
||||
personData.assigned!.add(clusterInfo);
|
||||
await entityService.addOrUpdate(
|
||||
EntityType.person,
|
||||
json.encode(personData.toJson()),
|
||||
id: personID,
|
||||
unawaited(
|
||||
entityService.addOrUpdate(
|
||||
EntityType.person,
|
||||
json.encode(personData.toJson()),
|
||||
id: personID,
|
||||
),
|
||||
);
|
||||
await faceMLDataDB.assignClusterToPerson(
|
||||
personID: personID,
|
||||
|
@ -190,7 +193,7 @@ class PersonService {
|
|||
}
|
||||
|
||||
logger.info("Storing feedback for ${faceIdToClusterID.length} faces");
|
||||
await faceMLDataDB.updateClusterIdToFaceId(faceIdToClusterID);
|
||||
await faceMLDataDB.updateFaceIdToClusterId(faceIdToClusterID);
|
||||
await faceMLDataDB.bulkAssignClusterToPersonID(clusterToPersonID);
|
||||
}
|
||||
|
||||
|
|
|
@ -264,13 +264,56 @@ class _FaceWidgetState extends State<FaceWidget> {
|
|||
},
|
||||
child: Column(
|
||||
children: [
|
||||
SizedBox(
|
||||
width: 60,
|
||||
height: 60,
|
||||
child: CroppedFaceImgImageView(
|
||||
enteFile: widget.file,
|
||||
face: widget.face,
|
||||
),
|
||||
Stack(
|
||||
children: [
|
||||
Container(
|
||||
height: 60,
|
||||
width: 60,
|
||||
decoration: ShapeDecoration(
|
||||
shape: RoundedRectangleBorder(
|
||||
borderRadius: const BorderRadius.all(
|
||||
Radius.elliptical(16, 12),
|
||||
),
|
||||
side: widget.highlight
|
||||
? BorderSide(
|
||||
color: getEnteColorScheme(context).primary700,
|
||||
width: 1.0,
|
||||
)
|
||||
: BorderSide.none,
|
||||
),
|
||||
),
|
||||
child: ClipRRect(
|
||||
borderRadius:
|
||||
const BorderRadius.all(Radius.elliptical(16, 12)),
|
||||
child: SizedBox(
|
||||
width: 60,
|
||||
height: 60,
|
||||
child: CroppedFaceImgImageView(
|
||||
enteFile: widget.file,
|
||||
face: widget.face,
|
||||
),
|
||||
),
|
||||
),
|
||||
),
|
||||
// TODO: the edges of the green line are still not properly rounded around ClipRRect
|
||||
if (widget.editMode)
|
||||
Positioned(
|
||||
right: 0,
|
||||
top: 0,
|
||||
child: GestureDetector(
|
||||
onTap: _cornerIconPressed,
|
||||
child: isJustRemoved
|
||||
? const Icon(
|
||||
CupertinoIcons.add_circled_solid,
|
||||
color: Colors.green,
|
||||
)
|
||||
: const Icon(
|
||||
Icons.cancel,
|
||||
color: Colors.red,
|
||||
),
|
||||
),
|
||||
),
|
||||
],
|
||||
),
|
||||
const SizedBox(height: 8),
|
||||
if (widget.person != null)
|
||||
|
|
|
@ -71,9 +71,9 @@ class _FacesItemWidgetState extends State<FacesItemWidget> {
|
|||
];
|
||||
}
|
||||
|
||||
// Remove faces with low scores and blurry faces
|
||||
// Remove faces with low scores
|
||||
if (!kDebugMode) {
|
||||
faces.removeWhere((face) => (face.isBlurry || face.score < 0.75));
|
||||
faces.removeWhere((face) => (face.score < 0.75));
|
||||
}
|
||||
|
||||
if (faces.isEmpty) {
|
||||
|
@ -85,9 +85,6 @@ class _FacesItemWidgetState extends State<FacesItemWidget> {
|
|||
];
|
||||
}
|
||||
|
||||
// Sort the faces by score in descending order, so that the highest scoring face is first.
|
||||
faces.sort((Face a, Face b) => b.score.compareTo(a.score));
|
||||
|
||||
// TODO: add deduplication of faces of same person
|
||||
final faceIdsToClusterIds = await FaceMLDataDB.instance
|
||||
.getFaceIdsToClusterIds(faces.map((face) => face.faceID));
|
||||
|
@ -96,6 +93,29 @@ class _FacesItemWidgetState extends State<FacesItemWidget> {
|
|||
final clusterIDToPerson =
|
||||
await FaceMLDataDB.instance.getClusterIDToPersonID();
|
||||
|
||||
// Sort faces by name and score
|
||||
final faceIdToPersonID = <String, String>{};
|
||||
for (final face in faces) {
|
||||
final clusterID = faceIdsToClusterIds[face.faceID];
|
||||
if (clusterID != null) {
|
||||
final personID = clusterIDToPerson[clusterID];
|
||||
if (personID != null) {
|
||||
faceIdToPersonID[face.faceID] = personID;
|
||||
}
|
||||
}
|
||||
}
|
||||
faces.sort((Face a, Face b) {
|
||||
final aPersonID = faceIdToPersonID[a.faceID];
|
||||
final bPersonID = faceIdToPersonID[b.faceID];
|
||||
if (aPersonID != null && bPersonID == null) {
|
||||
return -1;
|
||||
} else if (aPersonID == null && bPersonID != null) {
|
||||
return 1;
|
||||
} else {
|
||||
return b.score.compareTo(a.score);
|
||||
}
|
||||
});
|
||||
|
||||
final lastViewedClusterID = ClusterFeedbackService.lastViewedClusterID;
|
||||
|
||||
final faceWidgets = <FaceWidget>[];
|
||||
|
|
|
@ -207,14 +207,14 @@ class _AppBarWidgetState extends State<ClusterAppBar> {
|
|||
if (embedding.key == otherEmbedding.key) {
|
||||
continue;
|
||||
}
|
||||
final distance64 = 1.0 -
|
||||
Vector.fromList(embedding.value, dtype: DType.float64).dot(
|
||||
Vector.fromList(otherEmbedding.value, dtype: DType.float64),
|
||||
);
|
||||
final distance32 = 1.0 -
|
||||
Vector.fromList(embedding.value, dtype: DType.float32).dot(
|
||||
Vector.fromList(otherEmbedding.value, dtype: DType.float32),
|
||||
);
|
||||
final distance64 = cosineDistanceSIMD(
|
||||
Vector.fromList(embedding.value, dtype: DType.float64),
|
||||
Vector.fromList(otherEmbedding.value, dtype: DType.float64),
|
||||
);
|
||||
final distance32 = cosineDistanceSIMD(
|
||||
Vector.fromList(embedding.value, dtype: DType.float32),
|
||||
Vector.fromList(otherEmbedding.value, dtype: DType.float32),
|
||||
);
|
||||
final distance = cosineDistForNormVectors(
|
||||
embedding.value,
|
||||
otherEmbedding.value,
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import "dart:async" show StreamSubscription, unawaited;
|
||||
import "dart:math";
|
||||
|
||||
import "package:flutter/foundation.dart" show kDebugMode;
|
||||
|
@ -29,16 +30,25 @@ class PersonReviewClusterSuggestion extends StatefulWidget {
|
|||
|
||||
class _PersonClustersState extends State<PersonReviewClusterSuggestion> {
|
||||
int currentSuggestionIndex = 0;
|
||||
bool fetch = true;
|
||||
Key futureBuilderKey = UniqueKey();
|
||||
|
||||
// Declare a variable for the future
|
||||
late Future<List<ClusterSuggestion>> futureClusterSuggestions;
|
||||
late StreamSubscription<PeopleChangedEvent> _peopleChangedEvent;
|
||||
|
||||
@override
|
||||
void initState() {
|
||||
super.initState();
|
||||
// Initialize the future in initState
|
||||
_fetchClusterSuggestions();
|
||||
if (fetch) _fetchClusterSuggestions();
|
||||
fetch = true;
|
||||
}
|
||||
|
||||
@override
|
||||
void dispose() {
|
||||
_peopleChangedEvent.cancel();
|
||||
super.dispose();
|
||||
}
|
||||
|
||||
@override
|
||||
|
@ -61,12 +71,27 @@ class _PersonClustersState extends State<PersonReviewClusterSuggestion> {
|
|||
),
|
||||
);
|
||||
}
|
||||
final numberOfDifferentSuggestions = snapshot.data!.length;
|
||||
final currentSuggestion = snapshot.data![currentSuggestionIndex];
|
||||
|
||||
final allSuggestions = snapshot.data!;
|
||||
final numberOfDifferentSuggestions = allSuggestions.length;
|
||||
final currentSuggestion = allSuggestions[currentSuggestionIndex];
|
||||
final int clusterID = currentSuggestion.clusterIDToMerge;
|
||||
final double distance = currentSuggestion.distancePersonToCluster;
|
||||
final bool usingMean = currentSuggestion.usedOnlyMeanForSuggestion;
|
||||
final List<EnteFile> files = currentSuggestion.filesInCluster;
|
||||
|
||||
_peopleChangedEvent =
|
||||
Bus.instance.on<PeopleChangedEvent>().listen((event) {
|
||||
if (event.type == PeopleEventType.removedFilesFromCluster &&
|
||||
(event.source == clusterID.toString())) {
|
||||
for (var updatedFile in event.relevantFiles!) {
|
||||
files.remove(updatedFile);
|
||||
}
|
||||
fetch = false;
|
||||
setState(() {});
|
||||
}
|
||||
});
|
||||
|
||||
return InkWell(
|
||||
onTap: () {
|
||||
Navigator.of(context).push(
|
||||
|
@ -90,6 +115,7 @@ class _PersonClustersState extends State<PersonReviewClusterSuggestion> {
|
|||
usingMean,
|
||||
files,
|
||||
numberOfDifferentSuggestions,
|
||||
allSuggestions,
|
||||
),
|
||||
),
|
||||
);
|
||||
|
@ -116,20 +142,25 @@ class _PersonClustersState extends State<PersonReviewClusterSuggestion> {
|
|||
clusterID: clusterID,
|
||||
);
|
||||
Bus.instance.fire(PeopleChangedEvent());
|
||||
// Increment the suggestion index
|
||||
if (mounted) {
|
||||
setState(() => currentSuggestionIndex++);
|
||||
}
|
||||
|
||||
// Check if we need to fetch new data
|
||||
if (currentSuggestionIndex >= (numberOfSuggestions)) {
|
||||
setState(() {
|
||||
currentSuggestionIndex = 0;
|
||||
futureBuilderKey = UniqueKey(); // Reset to trigger FutureBuilder
|
||||
_fetchClusterSuggestions();
|
||||
});
|
||||
}
|
||||
} else {
|
||||
await FaceMLDataDB.instance.captureNotPersonFeedback(
|
||||
personID: widget.person.remoteID,
|
||||
clusterID: clusterID,
|
||||
);
|
||||
}
|
||||
|
||||
// Increment the suggestion index
|
||||
if (mounted) {
|
||||
setState(() => currentSuggestionIndex++);
|
||||
}
|
||||
|
||||
// Check if we need to fetch new data
|
||||
if (currentSuggestionIndex >= (numberOfSuggestions)) {
|
||||
// Recalculate the suggestions when a suggestion is rejected
|
||||
setState(() {
|
||||
currentSuggestionIndex = 0;
|
||||
futureBuilderKey = UniqueKey(); // Reset to trigger FutureBuilder
|
||||
|
@ -150,9 +181,10 @@ class _PersonClustersState extends State<PersonReviewClusterSuggestion> {
|
|||
bool usingMean,
|
||||
List<EnteFile> files,
|
||||
int numberOfSuggestions,
|
||||
List<ClusterSuggestion> allSuggestions,
|
||||
) {
|
||||
return Column(
|
||||
key: ValueKey("cluster_id-$clusterID"),
|
||||
final widgetToReturn = Column(
|
||||
key: ValueKey("cluster_id-$clusterID-files-${files.length}"),
|
||||
children: <Widget>[
|
||||
if (kDebugMode)
|
||||
Text(
|
||||
|
@ -228,6 +260,28 @@ class _PersonClustersState extends State<PersonReviewClusterSuggestion> {
|
|||
),
|
||||
],
|
||||
);
|
||||
// Precompute face thumbnails for next suggestions, in case there are
|
||||
const precompute = 6;
|
||||
const maxComputations = 10;
|
||||
int compCount = 0;
|
||||
|
||||
if (allSuggestions.length > currentSuggestionIndex + 1) {
|
||||
for (final suggestion in allSuggestions.sublist(
|
||||
currentSuggestionIndex + 1,
|
||||
min(allSuggestions.length, currentSuggestionIndex + precompute),
|
||||
)) {
|
||||
final files = suggestion.filesInCluster;
|
||||
final clusterID = suggestion.clusterIDToMerge;
|
||||
for (final file in files.sublist(0, min(files.length, 8))) {
|
||||
unawaited(PersonFaceWidget.precomputeFaceCrops(file, clusterID));
|
||||
compCount++;
|
||||
if (compCount >= maxComputations) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return widgetToReturn;
|
||||
}
|
||||
|
||||
List<Widget> _buildThumbnailWidgets(
|
||||
|
|
|
@ -33,9 +33,64 @@ class PersonFaceWidget extends StatelessWidget {
|
|||
),
|
||||
super(key: key);
|
||||
|
||||
static Future<void> precomputeFaceCrops(file, clusterID) async {
|
||||
try {
|
||||
final Face? face = await FaceMLDataDB.instance.getCoverFaceForPerson(
|
||||
recentFileID: file.uploadedFileID!,
|
||||
clusterID: clusterID,
|
||||
);
|
||||
if (face == null) {
|
||||
debugPrint(
|
||||
"No cover face for cluster $clusterID and recentFile ${file.uploadedFileID}",
|
||||
);
|
||||
return;
|
||||
}
|
||||
final Uint8List? cachedFace = faceCropCache.get(face.faceID);
|
||||
if (cachedFace != null) {
|
||||
return;
|
||||
}
|
||||
final faceCropCacheFile = cachedFaceCropPath(face.faceID);
|
||||
if ((await faceCropCacheFile.exists())) {
|
||||
final data = await faceCropCacheFile.readAsBytes();
|
||||
faceCropCache.put(face.faceID, data);
|
||||
return;
|
||||
}
|
||||
EnteFile? fileForFaceCrop = file;
|
||||
if (face.fileID != file.uploadedFileID!) {
|
||||
fileForFaceCrop =
|
||||
await FilesDB.instance.getAnyUploadedFile(face.fileID);
|
||||
}
|
||||
if (fileForFaceCrop == null) {
|
||||
return;
|
||||
}
|
||||
|
||||
final result = await pool.withResource(
|
||||
() async => await getFaceCrops(
|
||||
fileForFaceCrop!,
|
||||
{
|
||||
face.faceID: face.detection.box,
|
||||
},
|
||||
),
|
||||
);
|
||||
final Uint8List? computedCrop = result?[face.faceID];
|
||||
if (computedCrop != null) {
|
||||
faceCropCache.put(face.faceID, computedCrop);
|
||||
faceCropCacheFile.writeAsBytes(computedCrop).ignore();
|
||||
}
|
||||
return;
|
||||
} catch (e, s) {
|
||||
log(
|
||||
"Error getting cover face for cluster $clusterID",
|
||||
error: e,
|
||||
stackTrace: s,
|
||||
);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
@override
|
||||
Widget build(BuildContext context) {
|
||||
if (useGeneratedFaceCrops) {
|
||||
if (!useGeneratedFaceCrops) {
|
||||
return FutureBuilder<Uint8List?>(
|
||||
future: getFaceCrop(),
|
||||
builder: (context, snapshot) {
|
||||
|
|
|
@ -11,7 +11,7 @@ import "package:photos/utils/thumbnail_util.dart";
|
|||
import "package:pool/pool.dart";
|
||||
|
||||
final LRUMap<String, Uint8List?> faceCropCache = LRUMap(1000);
|
||||
final pool = Pool(5, timeout: const Duration(seconds: 15));
|
||||
final pool = Pool(10, timeout: const Duration(seconds: 15));
|
||||
Future<Map<String, Uint8List>?> getFaceCrops(
|
||||
EnteFile file,
|
||||
Map<String, FaceBox> faceBoxeMap,
|
||||
|
|
Loading…
Reference in a new issue