[mob] Add merges to predictComplete method
This commit is contained in:
parent
7a5e1263e0
commit
ba58ac1358
|
@ -249,6 +249,7 @@ class FaceClusteringService {
|
||||||
Map<String, Uint8List> input, {
|
Map<String, Uint8List> input, {
|
||||||
Map<int, int>? fileIDToCreationTime,
|
Map<int, int>? fileIDToCreationTime,
|
||||||
double distanceThreshold = kRecommendedDistanceThreshold,
|
double distanceThreshold = kRecommendedDistanceThreshold,
|
||||||
|
double mergeThreshold = 0.30,
|
||||||
}) async {
|
}) async {
|
||||||
if (input.isEmpty) {
|
if (input.isEmpty) {
|
||||||
_logger.warning(
|
_logger.warning(
|
||||||
|
@ -270,6 +271,7 @@ class FaceClusteringService {
|
||||||
"input": input,
|
"input": input,
|
||||||
"fileIDToCreationTime": fileIDToCreationTime,
|
"fileIDToCreationTime": fileIDToCreationTime,
|
||||||
"distanceThreshold": distanceThreshold,
|
"distanceThreshold": distanceThreshold,
|
||||||
|
"mergeThreshold": mergeThreshold,
|
||||||
},
|
},
|
||||||
taskName: "createImageEmbedding",
|
taskName: "createImageEmbedding",
|
||||||
) as Map<String, int>;
|
) as Map<String, int>;
|
||||||
|
@ -578,12 +580,11 @@ class FaceClusteringService {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
static Map<String, int> runCompleteClustering(Map args) {
|
static Map<String, int> runCompleteClustering(Map args) {
|
||||||
final input = args['input'] as Map<String, Uint8List>;
|
final input = args['input'] as Map<String, Uint8List>;
|
||||||
final fileIDToCreationTime = args['fileIDToCreationTime'] as Map<int, int>?;
|
final fileIDToCreationTime = args['fileIDToCreationTime'] as Map<int, int>?;
|
||||||
final distanceThreshold = args['distanceThreshold'] as double;
|
final distanceThreshold = args['distanceThreshold'] as double;
|
||||||
|
final mergeThreshold = args['mergeThreshold'] as double;
|
||||||
|
|
||||||
log(
|
log(
|
||||||
"[CompleteClustering] ${DateTime.now()} Copied to isolate ${input.length} faces for clustering",
|
"[CompleteClustering] ${DateTime.now()} Copied to isolate ${input.length} faces for clustering",
|
||||||
|
@ -637,10 +638,11 @@ class FaceClusteringService {
|
||||||
final Map<String, int> newFaceIdToCluster = {};
|
final Map<String, int> newFaceIdToCluster = {};
|
||||||
final stopwatchClustering = Stopwatch()..start();
|
final stopwatchClustering = Stopwatch()..start();
|
||||||
for (int i = 0; i < totalFaces; i++) {
|
for (int i = 0; i < totalFaces; i++) {
|
||||||
|
if (faceInfos[i].clusterId != null) continue;
|
||||||
int closestIdx = -1;
|
int closestIdx = -1;
|
||||||
double closestDistance = double.infinity;
|
double closestDistance = double.infinity;
|
||||||
if (i % 250 == 0) {
|
if (i + 1 % 250 == 0) {
|
||||||
log("[CompleteClustering] ${DateTime.now()} Processed $i faces");
|
log("[CompleteClustering] ${DateTime.now()} Processed ${i - 1} faces");
|
||||||
}
|
}
|
||||||
for (int j = 0; j < totalFaces; j++) {
|
for (int j = 0; j < totalFaces; j++) {
|
||||||
if (i == j) continue;
|
if (i == j) continue;
|
||||||
|
@ -656,18 +658,91 @@ class FaceClusteringService {
|
||||||
if (faceInfos[closestIdx].clusterId == null) {
|
if (faceInfos[closestIdx].clusterId == null) {
|
||||||
clusterID++;
|
clusterID++;
|
||||||
faceInfos[closestIdx].clusterId = clusterID;
|
faceInfos[closestIdx].clusterId = clusterID;
|
||||||
newFaceIdToCluster[faceInfos[closestIdx].faceID] = clusterID;
|
|
||||||
}
|
}
|
||||||
faceInfos[i].clusterId = faceInfos[closestIdx].clusterId!;
|
faceInfos[i].clusterId = faceInfos[closestIdx].clusterId!;
|
||||||
newFaceIdToCluster[faceInfos[i].faceID] =
|
|
||||||
faceInfos[closestIdx].clusterId!;
|
|
||||||
} else {
|
} else {
|
||||||
clusterID++;
|
clusterID++;
|
||||||
faceInfos[i].clusterId = clusterID;
|
faceInfos[i].clusterId = clusterID;
|
||||||
newFaceIdToCluster[faceInfos[i].faceID] = clusterID;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Now calculate the mean of the embeddings for each cluster
|
||||||
|
final Map<int, List<FaceInfo>> clusterIdToFaceInfos = {};
|
||||||
|
for (final faceInfo in faceInfos) {
|
||||||
|
if (clusterIdToFaceInfos.containsKey(faceInfo.clusterId)) {
|
||||||
|
clusterIdToFaceInfos[faceInfo.clusterId]!.add(faceInfo);
|
||||||
|
} else {
|
||||||
|
clusterIdToFaceInfos[faceInfo.clusterId!] = [faceInfo];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
final Map<int, (Vector, int)> clusterIdToMeanEmbeddingAndWeight = {};
|
||||||
|
for (final clusterId in clusterIdToFaceInfos.keys) {
|
||||||
|
final List<Vector> embeddings = clusterIdToFaceInfos[clusterId]!
|
||||||
|
.map((faceInfo) => faceInfo.vEmbedding!)
|
||||||
|
.toList();
|
||||||
|
final count = clusterIdToFaceInfos[clusterId]!.length;
|
||||||
|
final Vector meanEmbedding = embeddings.reduce((a, b) => a + b) / count;
|
||||||
|
clusterIdToMeanEmbeddingAndWeight[clusterId] = (meanEmbedding, count);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now merge the clusters that are close to each other, based on mean embedding
|
||||||
|
final List<(int, int)> mergedClustersList = [];
|
||||||
|
final List<int> clusterIds =
|
||||||
|
clusterIdToMeanEmbeddingAndWeight.keys.toList();
|
||||||
|
log(' [CompleteClustering] ${DateTime.now()} ${clusterIds.length} clusters found, now checking for merges');
|
||||||
|
while (true) {
|
||||||
|
if (clusterIds.length < 2) break;
|
||||||
|
double distance = double.infinity;
|
||||||
|
(int, int) clusterIDsToMerge = (-1, -1);
|
||||||
|
for (int i = 0; i < clusterIds.length; i++) {
|
||||||
|
for (int j = 0; j < clusterIds.length; j++) {
|
||||||
|
if (i == j) continue;
|
||||||
|
final double newDistance = 1.0 -
|
||||||
|
clusterIdToMeanEmbeddingAndWeight[clusterIds[i]]!.$1.dot(
|
||||||
|
clusterIdToMeanEmbeddingAndWeight[clusterIds[j]]!.$1,
|
||||||
|
);
|
||||||
|
if (newDistance < distance) {
|
||||||
|
distance = newDistance;
|
||||||
|
clusterIDsToMerge = (clusterIds[i], clusterIds[j]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (distance < mergeThreshold) {
|
||||||
|
mergedClustersList.add(clusterIDsToMerge);
|
||||||
|
final clusterID1 = clusterIDsToMerge.$1;
|
||||||
|
final clusterID2 = clusterIDsToMerge.$2;
|
||||||
|
final mean1 = clusterIdToMeanEmbeddingAndWeight[clusterID1]!.$1;
|
||||||
|
final mean2 = clusterIdToMeanEmbeddingAndWeight[clusterID2]!.$1;
|
||||||
|
final count1 = clusterIdToMeanEmbeddingAndWeight[clusterID1]!.$2;
|
||||||
|
final count2 = clusterIdToMeanEmbeddingAndWeight[clusterID2]!.$2;
|
||||||
|
final weight1 = count1 / (count1 + count2);
|
||||||
|
final weight2 = count2 / (count1 + count2);
|
||||||
|
clusterIdToMeanEmbeddingAndWeight[clusterID1] = (
|
||||||
|
mean1 * weight1 + mean2 * weight2,
|
||||||
|
count1 + count2,
|
||||||
|
);
|
||||||
|
clusterIdToMeanEmbeddingAndWeight.remove(clusterID2);
|
||||||
|
clusterIds.remove(clusterID2);
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
log(' [CompleteClustering] ${DateTime.now()} ${mergedClustersList.length} clusters merged');
|
||||||
|
|
||||||
|
// Now assign the new clusterId to the faces
|
||||||
|
for (final faceInfo in faceInfos) {
|
||||||
|
for (final mergedClusters in mergedClustersList) {
|
||||||
|
if (faceInfo.clusterId == mergedClusters.$2) {
|
||||||
|
faceInfo.clusterId = mergedClusters.$1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Finally, assign the new clusterId to the faces
|
||||||
|
for (final faceInfo in faceInfos) {
|
||||||
|
newFaceIdToCluster[faceInfo.faceID] = faceInfo.clusterId!;
|
||||||
|
}
|
||||||
|
|
||||||
stopwatchClustering.stop();
|
stopwatchClustering.stop();
|
||||||
log(
|
log(
|
||||||
' [CompleteClustering] ${DateTime.now()} Clustering for ${faceInfos.length} embeddings executed in ${stopwatchClustering.elapsedMilliseconds}ms',
|
' [CompleteClustering] ${DateTime.now()} Clustering for ${faceInfos.length} embeddings executed in ${stopwatchClustering.elapsedMilliseconds}ms',
|
||||||
|
|
|
@ -13,6 +13,8 @@ import "package:photos/generated/protos/ente/common/vector.pb.dart";
|
||||||
import "package:photos/models/file/file.dart";
|
import "package:photos/models/file/file.dart";
|
||||||
import 'package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart';
|
import 'package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart';
|
||||||
import "package:photos/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart";
|
import "package:photos/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart";
|
||||||
|
// import "package:photos/services/machine_learning/face_ml/face_clustering/face_info_for_clustering.dart";
|
||||||
|
// import "package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart";
|
||||||
import "package:photos/services/machine_learning/face_ml/face_ml_result.dart";
|
import "package:photos/services/machine_learning/face_ml/face_ml_result.dart";
|
||||||
import "package:photos/services/machine_learning/face_ml/person/person_service.dart";
|
import "package:photos/services/machine_learning/face_ml/person/person_service.dart";
|
||||||
import "package:photos/services/search_service.dart";
|
import "package:photos/services/search_service.dart";
|
||||||
|
@ -232,14 +234,36 @@ class ClusterFeedbackService {
|
||||||
maxClusterID++;
|
maxClusterID++;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
// final clusteringInput = embeddings
|
||||||
|
// .map((key, value) {
|
||||||
|
// return MapEntry(
|
||||||
|
// key,
|
||||||
|
// FaceInfoForClustering(
|
||||||
|
// faceID: key,
|
||||||
|
// embeddingBytes: value,
|
||||||
|
// faceScore: kMinHighQualityFaceScore + 0.01,
|
||||||
|
// blurValue: kLapacianDefault,
|
||||||
|
// ),
|
||||||
|
// );
|
||||||
|
// })
|
||||||
|
// .values
|
||||||
|
// .toSet();
|
||||||
|
// final faceIdToCluster =
|
||||||
|
// await FaceClusteringService.instance.predictLinear(
|
||||||
|
// clusteringInput,
|
||||||
|
// fileIDToCreationTime: fileIDToCreationTime,
|
||||||
|
// distanceThreshold: 0.23,
|
||||||
|
// useDynamicThreshold: false,
|
||||||
|
// );
|
||||||
final faceIdToCluster =
|
final faceIdToCluster =
|
||||||
await FaceClusteringService.instance.predictComplete(
|
await FaceClusteringService.instance.predictComplete(
|
||||||
embeddings,
|
embeddings,
|
||||||
fileIDToCreationTime: fileIDToCreationTime,
|
fileIDToCreationTime: fileIDToCreationTime,
|
||||||
distanceThreshold: 0.30,
|
distanceThreshold: 0.30,
|
||||||
|
mergeThreshold: 0.30,
|
||||||
);
|
);
|
||||||
|
|
||||||
if (faceIdToCluster.isEmpty) {
|
if (faceIdToCluster == null || faceIdToCluster.isEmpty) {
|
||||||
_logger.info('No clusters found');
|
_logger.info('No clusters found');
|
||||||
return {};
|
return {};
|
||||||
} else {
|
} else {
|
||||||
|
|
Loading…
Reference in a new issue