Refactor of clustering

This commit is contained in:
laurenspriem 2024-03-21 16:59:55 +05:30
parent 212208ae01
commit b5cff212bb

View file

@ -7,6 +7,7 @@ import "dart:typed_data";
import "package:logging/logging.dart";
import "package:photos/generated/protos/ente/common/vector.pb.dart";
import 'package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart';
import "package:photos/services/machine_learning/face_ml/face_ml_result.dart";
import "package:synchronized/synchronized.dart";
class FaceInfo {
@ -15,10 +16,12 @@ class FaceInfo {
int? clusterId;
String? closestFaceId;
int? closestDist;
int? fileCreationTime;
FaceInfo({
required this.faceID,
required this.embedding,
this.clusterId,
this.fileCreationTime,
});
}
@ -31,7 +34,6 @@ class FaceLinearClustering {
final Duration _inactivityDuration = const Duration(seconds: 30);
int _activeTasks = 0;
final _initLock = Lock();
late Isolate _isolate;
@ -220,6 +222,8 @@ class FaceLinearClustering {
log(
"[ClusterIsolate] ${DateTime.now()} Copied to isolate ${x.length} faces",
);
// Organize everything into a list of FaceInfo objects
final List<FaceInfo> faceInfos = [];
for (final entry in x.entries) {
faceInfos.add(
@ -249,59 +253,61 @@ class FaceLinearClustering {
}
// Sort the faceInfos such that the ones with null clusterId are at the end
faceInfos.sort((a, b) {
if (a.clusterId == null && b.clusterId == null) {
return 0;
} else if (a.clusterId == null) {
return 1;
} else if (b.clusterId == null) {
return -1;
} else {
return 0;
}
});
// Count the amount of null values at the end
int nullCount = 0;
for (final faceInfo in faceInfos.reversed) {
final List<FaceInfo> facesWithClusterID = <FaceInfo>[];
final List<FaceInfo> facesWithoutClusterID = <FaceInfo>[];
for (final FaceInfo faceInfo in faceInfos) {
if (faceInfo.clusterId == null) {
nullCount++;
facesWithoutClusterID.add(faceInfo);
} else {
break;
facesWithClusterID.add(faceInfo);
}
}
final sortedFaceInfos = <FaceInfo>[];
sortedFaceInfos.addAll(facesWithClusterID);
sortedFaceInfos.addAll(facesWithoutClusterID);
log(
"[ClusterIsolate] ${DateTime.now()} Clustering $nullCount new faces without clusterId, and ${faceInfos.length - nullCount} faces with clusterId",
"[ClusterIsolate] ${DateTime.now()} Clustering ${facesWithoutClusterID.length} new faces without clusterId, and ${facesWithClusterID.length} faces with clusterId",
);
for (final clusteredFaceInfo
in faceInfos.sublist(0, faceInfos.length - nullCount)) {
assert(clusteredFaceInfo.clusterId != null);
// Make sure the first face has a clusterId
final int totalFaces = sortedFaceInfos.length;
int clusterID = 1;
if (sortedFaceInfos.isNotEmpty) {
if (sortedFaceInfos.first.clusterId == null) {
sortedFaceInfos.first.clusterId = clusterID;
} else {
clusterID = sortedFaceInfos.first.clusterId!;
}
} else {
return {};
}
final int totalFaces = faceInfos.length;
int clusterID = 1;
if (faceInfos.isNotEmpty) {
faceInfos.first.clusterId = clusterID;
}
// Start actual clustering
log(
"[ClusterIsolate] ${DateTime.now()} Processing $totalFaces faces",
);
final Map<String, int> newFaceIdToCluster = {};
final stopwatchClustering = Stopwatch()..start();
for (int i = 1; i < totalFaces; i++) {
// Incremental clustering, so we can skip faces that already have a clusterId
if (faceInfos[i].clusterId != null) {
clusterID = max(clusterID, faceInfos[i].clusterId!);
if (sortedFaceInfos[i].clusterId != null) {
clusterID = max(clusterID, sortedFaceInfos[i].clusterId!);
if (i % 250 == 0) {
log("[ClusterIsolate] ${DateTime.now()} First $i faces already had a clusterID");
}
continue;
}
final currentEmbedding = faceInfos[i].embedding;
final currentEmbedding = sortedFaceInfos[i].embedding;
int closestIdx = -1;
double closestDistance = double.infinity;
if (i % 250 == 0) {
log("[ClusterIsolate] ${DateTime.now()} Processing $i faces");
}
for (int j = 0; j < i; j++) {
for (int j = i - 1; j >= 0; j--) {
final double distance = cosineDistForNormVectors(
currentEmbedding,
faceInfos[j].embedding,
sortedFaceInfos[j].embedding,
);
if (distance < closestDistance) {
closestDistance = distance;
@ -310,42 +316,43 @@ class FaceLinearClustering {
}
if (closestDistance < recommendedDistanceThreshold) {
if (faceInfos[closestIdx].clusterId == null) {
if (sortedFaceInfos[closestIdx].clusterId == null) {
// Ideally this should never happen, but just in case log it
log(
" [ClusterIsolate] ${DateTime.now()} Found new cluster $clusterID",
" [ClusterIsolate] [WARNING] ${DateTime.now()} Found new cluster $clusterID",
);
clusterID++;
faceInfos[closestIdx].clusterId = clusterID;
sortedFaceInfos[closestIdx].clusterId = clusterID;
newFaceIdToCluster[sortedFaceInfos[closestIdx].faceID] = clusterID;
}
faceInfos[i].clusterId = faceInfos[closestIdx].clusterId;
sortedFaceInfos[i].clusterId = sortedFaceInfos[closestIdx].clusterId;
newFaceIdToCluster[sortedFaceInfos[i].faceID] =
sortedFaceInfos[closestIdx].clusterId!;
} else {
clusterID++;
faceInfos[i].clusterId = clusterID;
sortedFaceInfos[i].clusterId = clusterID;
newFaceIdToCluster[sortedFaceInfos[i].faceID] = clusterID;
}
}
final Map<String, int> result = {};
for (final faceInfo in faceInfos) {
result[faceInfo.faceID] = faceInfo.clusterId!;
}
stopwatchClustering.stop();
log(
' [ClusterIsolate] ${DateTime.now()} Clustering for ${faceInfos.length} embeddings (${faceInfos[0].embedding.length} size) executed in ${stopwatchClustering.elapsedMilliseconds}ms, clusters $clusterID',
' [ClusterIsolate] ${DateTime.now()} Clustering for ${sortedFaceInfos.length} embeddings (${sortedFaceInfos[0].embedding.length} size) executed in ${stopwatchClustering.elapsedMilliseconds}ms, clusters $clusterID',
);
// return result;
// NOTe: The main clustering logic is done, the following is just filtering and logging
final input = x;
final faceIdToCluster = result;
stopwatchClustering.reset();
stopwatchClustering.start();
// analyze the results
FaceLinearClustering._analyzeClusterResults(sortedFaceInfos);
final Set<String> newFaceIds = <String>{};
input.forEach((key, value) {
if (value.$1 == null) {
newFaceIds.add(key);
return newFaceIdToCluster;
}
static void _analyzeClusterResults(List<FaceInfo> sortedFaceInfos) {
final stopwatch = Stopwatch()..start();
final Map<String, int> faceIdToCluster = {};
for (final faceInfo in sortedFaceInfos) {
faceIdToCluster[faceInfo.faceID] = faceInfo.clusterId!;
}
});
// Find faceIDs that are part of a cluster which is larger than 5 and are new faceIDs
final Map<int, int> clusterIdToSize = {};
@ -356,12 +363,6 @@ class FaceLinearClustering {
clusterIdToSize[value] = 1;
}
});
final Map<String, int> faceIdToClusterFiltered = {};
for (final entry in faceIdToCluster.entries) {
if (clusterIdToSize[entry.value]! > 0 && newFaceIds.contains(entry.key)) {
faceIdToClusterFiltered[entry.key] = entry.value;
}
}
// print top 10 cluster ids and their sizes based on the internal cluster id
final clusterIds = faceIdToCluster.values.toSet();
@ -369,7 +370,7 @@ class FaceLinearClustering {
return faceIdToCluster.values.where((id) => id == clusterId).length;
}).toList();
clusterSizes.sort();
// find clusters whose size is graeter than 1
// find clusters whose size is greater than 1
int oneClusterCount = 0;
int moreThan5Count = 0;
int moreThan10Count = 0;
@ -377,57 +378,29 @@ class FaceLinearClustering {
int moreThan50Count = 0;
int moreThan100Count = 0;
// for (int i = 0; i < clusterSizes.length; i++) {
// if (clusterSizes[i] > 100) {
// moreThan100Count++;
// } else if (clusterSizes[i] > 50) {
// moreThan50Count++;
// } else if (clusterSizes[i] > 20) {
// moreThan20Count++;
// } else if (clusterSizes[i] > 10) {
// moreThan10Count++;
// } else if (clusterSizes[i] > 5) {
// moreThan5Count++;
// } else if (clusterSizes[i] == 1) {
// oneClusterCount++;
// }
// }
for (int i = 0; i < clusterSizes.length; i++) {
if (clusterSizes[i] > 100) {
moreThan100Count++;
}
if (clusterSizes[i] > 50) {
} else if (clusterSizes[i] > 50) {
moreThan50Count++;
}
if (clusterSizes[i] > 20) {
} else if (clusterSizes[i] > 20) {
moreThan20Count++;
}
if (clusterSizes[i] > 10) {
} else if (clusterSizes[i] > 10) {
moreThan10Count++;
}
if (clusterSizes[i] > 5) {
} else if (clusterSizes[i] > 5) {
moreThan5Count++;
}
if (clusterSizes[i] == 1) {
} else if (clusterSizes[i] == 1) {
oneClusterCount++;
}
}
// print the metrics
log(
'[ClusterIsolate] Total clusters ${clusterIds.length}, '
'oneClusterCount $oneClusterCount, '
'moreThan5Count $moreThan5Count, '
'moreThan10Count $moreThan10Count, '
'moreThan20Count $moreThan20Count, '
'moreThan50Count $moreThan50Count, '
'moreThan100Count $moreThan100Count',
"[ClusterIsolate] Total clusters ${clusterIds.length}: \n oneClusterCount $oneClusterCount \n moreThan5Count $moreThan5Count \n moreThan10Count $moreThan10Count \n moreThan20Count $moreThan20Count \n moreThan50Count $moreThan50Count \n moreThan100Count $moreThan100Count",
);
stopwatchClustering.stop();
stopwatch.stop();
log(
"[ClusterIsolate] Clustering additional steps took ${stopwatchClustering.elapsedMilliseconds} ms",
"[ClusterIsolate] Clustering additional analysis took ${stopwatch.elapsedMilliseconds} ms",
);
// log('Top clusters count ${clusterSizes.reversed.take(10).toList()}');
return faceIdToClusterFiltered;
}
}