[mob] Add completeClustering functionality

This commit is contained in:
laurenspriem 2024-04-18 11:25:48 +05:30
parent e3fd836901
commit 45d18b187c

View file

@ -4,6 +4,7 @@ import "dart:isolate";
import "dart:math" show max;
import "dart:typed_data";
import "package:computer/computer.dart";
import "package:logging/logging.dart";
import "package:ml_linalg/dtype.dart";
import "package:ml_linalg/vector.dart";
@ -42,6 +43,7 @@ enum ClusterOperation { linearIncrementalClustering, dbscanClustering }
class FaceClusteringService {
final _logger = Logger("FaceLinearClustering");
final _computer = Computer.shared();
Timer? _inactivityTimer;
final Duration _inactivityDuration = const Duration(minutes: 3);
@ -243,6 +245,45 @@ class FaceClusteringService {
}
}
/// Runs one round of complete clustering over [input] via
/// `_computer.compute` and returns the resulting face ID -> cluster ID map.
///
/// [input] maps each face ID to its embedding buffer (the worker decodes it
/// with `EVector.fromBuffer`). [fileIDToCreationTime], when given, is
/// forwarded so the worker can order faces by file creation time.
/// [distanceThreshold] is forwarded unchanged as the maximum distance at
/// which two faces are linked.
///
/// Returns an empty map without dispatching any work when [input] is empty.
/// Any error raised while clustering is logged at severe level and rethrown.
Future<Map<String, int>> predictComplete(
  Map<String, Uint8List> input, {
  Map<int, int>? fileIDToCreationTime,
  double distanceThreshold = kRecommendedDistanceThreshold,
}) async {
  if (input.isEmpty) {
    _logger.warning(
      "Complete Clustering dataset of embeddings is empty, returning empty map.",
    );
    return {};
  }

  // Clustering inside the isolate
  _logger.info(
    "Start Complete clustering on ${input.length} embeddings inside computer isolate",
  );
  try {
    // Stopwatch is monotonic, so the reported duration cannot be skewed by
    // wall-clock adjustments the way a DateTime difference can.
    final stopwatch = Stopwatch()..start();
    final faceIdToCluster = await _computer.compute(
      runCompleteClustering,
      param: {
        "input": input,
        "fileIDToCreationTime": fileIDToCreationTime,
        "distanceThreshold": distanceThreshold,
      },
      // Label the job after the function it actually runs.
      taskName: "runCompleteClustering",
    ) as Map<String, int>;
    stopwatch.stop();
    _logger.info(
      "Complete Clustering took: ${stopwatch.elapsedMilliseconds}ms",
    );
    return faceIdToCluster;
  } catch (e, s) {
    _logger.severe(e, s);
    rethrow;
  }
}
Future<List<List<String>>> predictDbscan(
Map<String, Uint8List> input, {
Map<int, int>? fileIDToCreationTime,
@ -537,6 +578,104 @@ class FaceClusteringService {
);
}
/// Clusters every face in `args['input']` by linking each face to its
/// nearest neighbour, and returns a face ID -> cluster ID map.
///
/// [args] must contain:
///  * `input`: a `Map<String, Uint8List>` from face ID to an embedding
///    buffer readable by `EVector.fromBuffer`.
///  * `fileIDToCreationTime`: an optional `Map<int, int>`; when non-null the
///    faces are processed in ascending creation time, with faces lacking a
///    creation time placed last.
///  * `distanceThreshold`: a `double`; two faces are linked only when their
///    distance is strictly below it.
///
/// The distance between two faces is `1 - dot(a, b)`.
/// NOTE(review): this is a cosine distance only if the embeddings are
/// L2-normalised - confirm against the embedding producer.
///
/// For each face the nearest other face is found by scanning all faces
/// (quadratic in the number of faces). If that neighbour is within the
/// threshold the two end up in the same cluster; when both already belong
/// to different clusters, those clusters are merged so that faces linked
/// earlier are never left behind. A face with no neighbour within the
/// threshold gets a cluster of its own. Cluster IDs are generated by
/// incrementing from the current epoch time in microseconds.
static Map<String, int> runCompleteClustering(Map args) {
  final input = args['input'] as Map<String, Uint8List>;
  final fileIDToCreationTime = args['fileIDToCreationTime'] as Map<int, int>?;
  final distanceThreshold = args['distanceThreshold'] as double;

  log(
    "[CompleteClustering] ${DateTime.now()} Copied to isolate ${input.length} faces for clustering",
  );

  // Nothing to cluster; skip decoding and sorting entirely.
  if (input.isEmpty) {
    return {};
  }

  // Organize everything into a list of FaceInfo objects
  final faceInfos = <FaceInfo>[
    for (final entry in input.entries)
      FaceInfo(
        faceID: entry.key,
        vEmbedding: Vector.fromList(
          EVector.fromBuffer(entry.value).values,
          dtype: DType.float32,
        ),
        fileCreationTime:
            fileIDToCreationTime?[getFileIdFromFaceId(entry.key)],
      ),
  ];

  // Sort the faceInfos based on fileCreationTime, in ascending order, so
  // oldest faces are first; faces without a creation time go last.
  if (fileIDToCreationTime != null) {
    faceInfos.sort((a, b) {
      if (a.fileCreationTime == null && b.fileCreationTime == null) {
        return 0;
      } else if (a.fileCreationTime == null) {
        return 1;
      } else if (b.fileCreationTime == null) {
        return -1;
      } else {
        return a.fileCreationTime!.compareTo(b.fileCreationTime!);
      }
    });
  }

  final int totalFaces = faceInfos.length;
  // Resolve the embeddings once (after sorting) so the quadratic loop below
  // does not repeat the nullable field lookup for every pair.
  final embeddings = [for (final face in faceInfos) face.vEmbedding!];

  // Start actual clustering
  log(
    "[CompleteClustering] ${DateTime.now()} Processing $totalFaces faces in one single round of complete clustering",
  );

  // set current epoch time as clusterID
  int clusterID = DateTime.now().microsecondsSinceEpoch;
  final Map<String, int> newFaceIdToCluster = {};
  final stopwatchClustering = Stopwatch()..start();
  for (int i = 0; i < totalFaces; i++) {
    if (i % 250 == 0) {
      log("[CompleteClustering] ${DateTime.now()} Processed $i faces");
    }

    // Find the nearest other face.
    final embedding = embeddings[i];
    int closestIdx = -1;
    double closestDistance = double.infinity;
    for (int j = 0; j < totalFaces; j++) {
      if (i == j) continue;
      final double distance = 1.0 - embedding.dot(embeddings[j]);
      if (distance < closestDistance) {
        closestDistance = distance;
        closestIdx = j;
      }
    }

    final face = faceInfos[i];
    if (closestDistance < distanceThreshold) {
      final neighbour = faceInfos[closestIdx];
      final ownCluster = face.clusterId;
      final neighbourCluster = neighbour.clusterId;

      if (neighbourCluster == null) {
        // Reuse this face's cluster if it already has one (it was linked by
        // an earlier face); only mint a new ID when neither is clustered.
        final target = ownCluster ?? ++clusterID;
        neighbour.clusterId = target;
        newFaceIdToCluster[neighbour.faceID] = target;
        face.clusterId = target;
        newFaceIdToCluster[face.faceID] = target;
      } else if (ownCluster == null) {
        face.clusterId = neighbourCluster;
        newFaceIdToCluster[face.faceID] = neighbourCluster;
      } else if (ownCluster != neighbourCluster) {
        // Both faces are already clustered differently: fold this face's
        // whole cluster into the neighbour's so that faces linked to this
        // one earlier are not orphaned.
        for (final other in faceInfos) {
          if (other.clusterId == ownCluster) {
            other.clusterId = neighbourCluster;
            newFaceIdToCluster[other.faceID] = neighbourCluster;
          }
        }
      }
    } else if (face.clusterId == null) {
      // No neighbour close enough: the face starts its own cluster. The
      // null guard keeps an existing assignment intact.
      clusterID++;
      face.clusterId = clusterID;
      newFaceIdToCluster[face.faceID] = clusterID;
    }
  }
  stopwatchClustering.stop();
  log(
    ' [CompleteClustering] ${DateTime.now()} Clustering for ${faceInfos.length} embeddings executed in ${stopwatchClustering.elapsedMilliseconds}ms',
  );

  return newFaceIdToCluster;
}
static List<List<String>> _runDbscanClustering(Map args) {
final input = args['input'] as Map<String, Uint8List>;
final fileIDToCreationTime = args['fileIDToCreationTime'] as Map<int, int>?;