[mob] Add merges to predictComplete method

This commit is contained in:
laurenspriem 2024-04-18 14:44:12 +05:30
parent 7a5e1263e0
commit ba58ac1358
2 changed files with 108 additions and 9 deletions

View file

@ -249,6 +249,7 @@ class FaceClusteringService {
Map<String, Uint8List> input, {
Map<int, int>? fileIDToCreationTime,
double distanceThreshold = kRecommendedDistanceThreshold,
double mergeThreshold = 0.30,
}) async {
if (input.isEmpty) {
_logger.warning(
@ -270,6 +271,7 @@ class FaceClusteringService {
"input": input,
"fileIDToCreationTime": fileIDToCreationTime,
"distanceThreshold": distanceThreshold,
"mergeThreshold": mergeThreshold,
},
taskName: "createImageEmbedding",
) as Map<String, int>;
@ -578,12 +580,11 @@ class FaceClusteringService {
);
}
static Map<String, int> runCompleteClustering(Map args) {
final input = args['input'] as Map<String, Uint8List>;
final fileIDToCreationTime = args['fileIDToCreationTime'] as Map<int, int>?;
final distanceThreshold = args['distanceThreshold'] as double;
final mergeThreshold = args['mergeThreshold'] as double;
log(
"[CompleteClustering] ${DateTime.now()} Copied to isolate ${input.length} faces for clustering",
@ -637,10 +638,11 @@ class FaceClusteringService {
final Map<String, int> newFaceIdToCluster = {};
final stopwatchClustering = Stopwatch()..start();
for (int i = 0; i < totalFaces; i++) {
if (faceInfos[i].clusterId != null) continue;
int closestIdx = -1;
double closestDistance = double.infinity;
// Progress log, emitted once every 250 faces.
// The parentheses are required: `%` binds tighter than `+`, so the
// unparenthesised `i + 1 % 250` evaluates as `i + (1 % 250)`, i.e. `i + 1`,
// which is never 0 for a non-negative index and would silence this log.
// NOTE(review): the message reports `i - 1`, whereas `i` faces precede the
// current index at this point — confirm which count is intended.
if ((i + 1) % 250 == 0) {
  log("[CompleteClustering] ${DateTime.now()} Processed ${i - 1} faces");
}
for (int j = 0; j < totalFaces; j++) {
if (i == j) continue;
@ -656,18 +658,91 @@ class FaceClusteringService {
if (faceInfos[closestIdx].clusterId == null) {
clusterID++;
faceInfos[closestIdx].clusterId = clusterID;
newFaceIdToCluster[faceInfos[closestIdx].faceID] = clusterID;
}
faceInfos[i].clusterId = faceInfos[closestIdx].clusterId!;
newFaceIdToCluster[faceInfos[i].faceID] =
faceInfos[closestIdx].clusterId!;
} else {
clusterID++;
faceInfos[i].clusterId = clusterID;
newFaceIdToCluster[faceInfos[i].faceID] = clusterID;
}
}
// Bucket the faces by the cluster they were assigned to above.
final Map<int, List<FaceInfo>> clusterIdToFaceInfos = {};
for (final face in faceInfos) {
  clusterIdToFaceInfos.putIfAbsent(face.clusterId!, () => []).add(face);
}
// Summarise every cluster as (mean embedding, number of member faces); the
// member count is later used as the weight when two clusters are combined.
final Map<int, (Vector, int)> clusterIdToMeanEmbeddingAndWeight = {};
for (final bucket in clusterIdToFaceInfos.entries) {
  final members = bucket.value;
  final memberCount = members.length;
  final Vector embeddingSum =
      members.map((member) => member.vEmbedding!).reduce((a, b) => a + b);
  clusterIdToMeanEmbeddingAndWeight[bucket.key] =
      (embeddingSum / memberCount, memberCount);
}
// Greedy merge pass over the cluster means: on every round find the closest
// pair of clusters (distance = 1 - dot product of their mean embeddings) and,
// if that pair is nearer than mergeThreshold, fold the second cluster into the
// first. Stop as soon as the closest pair is too far apart or fewer than two
// clusters remain.
// NOTE(review): the weighted means are not re-normalised in this block, so
// `1 - dot` is presumably only an approximation of cosine distance — confirm.
final List<(int, int)> mergedClustersList = [];
final List<int> clusterIds =
    clusterIdToMeanEmbeddingAndWeight.keys.toList();
log(' [CompleteClustering] ${DateTime.now()} ${clusterIds.length} clusters found, now checking for merges');
while (clusterIds.length >= 2) {
  double bestDistance = double.infinity;
  (int, int) bestPair = (-1, -1);
  for (final firstId in clusterIds) {
    final firstMean = clusterIdToMeanEmbeddingAndWeight[firstId]!.$1;
    for (final secondId in clusterIds) {
      // Cluster ids are unique map keys, so equal ids means the same slot.
      if (firstId == secondId) continue;
      final double candidate = 1.0 -
          firstMean.dot(clusterIdToMeanEmbeddingAndWeight[secondId]!.$1);
      if (candidate < bestDistance) {
        bestDistance = candidate;
        bestPair = (firstId, secondId);
      }
    }
  }
  // Even the closest remaining pair is not below the threshold: done.
  if (!(bestDistance < mergeThreshold)) break;

  // Record the merge as (surviving id, absorbed id) for the later remapping.
  mergedClustersList.add(bestPair);
  final (survivorId, absorbedId) = bestPair;
  final (survivorMean, survivorCount) =
      clusterIdToMeanEmbeddingAndWeight[survivorId]!;
  final (absorbedMean, absorbedCount) =
      clusterIdToMeanEmbeddingAndWeight[absorbedId]!;
  final survivorWeight = survivorCount / (survivorCount + absorbedCount);
  final absorbedWeight = absorbedCount / (survivorCount + absorbedCount);
  // The survivor's mean becomes the face-count-weighted average of both means.
  clusterIdToMeanEmbeddingAndWeight[survivorId] = (
    survivorMean * survivorWeight + absorbedMean * absorbedWeight,
    survivorCount + absorbedCount,
  );
  clusterIdToMeanEmbeddingAndWeight.remove(absorbedId);
  clusterIds.remove(absorbedId);
}
log(' [CompleteClustering] ${DateTime.now()} ${mergedClustersList.length} clusters merged');
// Propagate the merges to the faces and publish the final face -> cluster map.
for (final face in faceInfos) {
  // Replay the merges in the order they happened, so a cluster that was
  // absorbed and whose survivor was itself absorbed later ends up on the
  // final surviving id.
  for (final (survivorId, absorbedId) in mergedClustersList) {
    if (face.clusterId == absorbedId) {
      face.clusterId = survivorId;
    }
  }
  // Overwrites the pre-merge assignment stored for this face id.
  newFaceIdToCluster[face.faceID] = face.clusterId!;
}
stopwatchClustering.stop();
log(
' [CompleteClustering] ${DateTime.now()} Clustering for ${faceInfos.length} embeddings executed in ${stopwatchClustering.elapsedMilliseconds}ms',

View file

@ -13,6 +13,8 @@ import "package:photos/generated/protos/ente/common/vector.pb.dart";
import "package:photos/models/file/file.dart";
import 'package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart';
import "package:photos/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart";
// import "package:photos/services/machine_learning/face_ml/face_clustering/face_info_for_clustering.dart";
// import "package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart";
import "package:photos/services/machine_learning/face_ml/face_ml_result.dart";
import "package:photos/services/machine_learning/face_ml/person/person_service.dart";
import "package:photos/services/search_service.dart";
@ -232,14 +234,36 @@ class ClusterFeedbackService {
maxClusterID++;
}
} else {
// final clusteringInput = embeddings
// .map((key, value) {
// return MapEntry(
// key,
// FaceInfoForClustering(
// faceID: key,
// embeddingBytes: value,
// faceScore: kMinHighQualityFaceScore + 0.01,
// blurValue: kLapacianDefault,
// ),
// );
// })
// .values
// .toSet();
// final faceIdToCluster =
// await FaceClusteringService.instance.predictLinear(
// clusteringInput,
// fileIDToCreationTime: fileIDToCreationTime,
// distanceThreshold: 0.23,
// useDynamicThreshold: false,
// );
final faceIdToCluster =
await FaceClusteringService.instance.predictComplete(
embeddings,
fileIDToCreationTime: fileIDToCreationTime,
distanceThreshold: 0.30,
mergeThreshold: 0.30,
);
if (faceIdToCluster.isEmpty) {
if (faceIdToCluster == null || faceIdToCluster.isEmpty) {
_logger.info('No clusters found');
return {};
} else {