[mob] Add more validation for clustering

This commit is contained in:
laurenspriem 2024-04-05 15:50:52 +05:30
parent 723253a12c
commit 0c72fd2a69
2 changed files with 86 additions and 0 deletions

View file

@ -395,6 +395,12 @@ class FaceClustering {
if (distance < closestDistance) {
closestDistance = distance;
closestIdx = j;
// if (distance < distanceThreshold) {
// if (sortedFaceInfos[j].faceID.startsWith("14914702") ||
// sortedFaceInfos[j].faceID.startsWith("15488756")) {
// log('[XXX] faceIDs: ${sortedFaceInfos[j].faceID} and ${sortedFaceInfos[i].faceID} with distance $distance');
// }
// }
}
}
@ -408,10 +414,22 @@ class FaceClustering {
sortedFaceInfos[closestIdx].clusterId = clusterID;
newFaceIdToCluster[sortedFaceInfos[closestIdx].faceID] = clusterID;
}
// if (sortedFaceInfos[i].faceID.startsWith("14914702") ||
// sortedFaceInfos[i].faceID.startsWith("15488756")) {
// log(
// "[XXX] [ClusterIsolate] ${DateTime.now()} Found similar face ${sortedFaceInfos[i].faceID} to ${sortedFaceInfos[closestIdx].faceID} with distance $closestDistance",
// );
// }
sortedFaceInfos[i].clusterId = sortedFaceInfos[closestIdx].clusterId;
newFaceIdToCluster[sortedFaceInfos[i].faceID] =
sortedFaceInfos[closestIdx].clusterId!;
} else {
// if (sortedFaceInfos[i].faceID.startsWith("14914702") ||
// sortedFaceInfos[i].faceID.startsWith("15488756")) {
// log(
// "[XXX] [ClusterIsolate] ${DateTime.now()} Found new cluster $clusterID for face ${sortedFaceInfos[i].faceID}",
// );
// }
clusterID++;
sortedFaceInfos[i].clusterId = clusterID;
newFaceIdToCluster[sortedFaceInfos[i].faceID] = clusterID;

View file

@ -3,14 +3,18 @@ import 'dart:async';
import "package:flutter/foundation.dart";
import 'package:flutter/material.dart';
import 'package:logging/logging.dart';
import "package:ml_linalg/linalg.dart";
import 'package:photos/core/configuration.dart';
import 'package:photos/core/event_bus.dart';
import "package:photos/db/files_db.dart";
import 'package:photos/events/subscription_purchased_event.dart';
import "package:photos/face/db.dart";
import "package:photos/face/model/person.dart";
import "package:photos/generated/protos/ente/common/vector.pb.dart";
import 'package:photos/models/gallery_type.dart';
import 'package:photos/models/selected_files.dart';
import 'package:photos/services/collections_service.dart';
import "package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart";
import "package:photos/services/machine_learning/face_ml/face_ml_result.dart";
import "package:photos/services/machine_learning/face_ml/feedback/cluster_feedback.dart";
import 'package:photos/ui/actions/collection/collection_sharing_actions.dart';
@ -39,6 +43,7 @@ class ClusterAppBar extends StatefulWidget {
/// Actions that can be selected from the cluster app bar's popup menu.
enum ClusterPopupAction {
// Cover-photo action; the branch handling it in the menu callback is
// currently commented out.
setCover,
// Handled in the menu callback by calling `_breakUpCluster`.
breakupCluster,
// Handled in the menu callback by calling `_validateCluster`.
validateCluster,
// Hide action; its popup menu entry is currently commented out.
hide,
}
@ -127,6 +132,18 @@ class _AppBarWidgetState extends State<ClusterAppBar> {
],
),
),
const PopupMenuItem(
value: ClusterPopupAction.validateCluster,
child: Row(
children: [
Icon(Icons.search_off_outlined),
Padding(
padding: EdgeInsets.all(8),
),
Text('Validate cluster'),
],
),
),
// PopupMenuItem(
// value: ClusterPopupAction.hide,
// child: Row(
@ -152,6 +169,8 @@ class _AppBarWidgetState extends State<ClusterAppBar> {
if (value == ClusterPopupAction.breakupCluster) {
// ignore: unawaited_futures
await _breakUpCluster(context);
} else if (value == ClusterPopupAction.validateCluster) {
await _validateCluster(context);
}
// else if (value == ClusterPopupAction.setCover) {
// await setCoverPhoto(context);
@ -166,6 +185,55 @@ class _AppBarWidgetState extends State<ClusterAppBar> {
return actions;
}
/// Debug check of how well the faces in this widget's cluster fit together.
///
/// Loads the embeddings of every face assigned to `widget.clusterID`, finds
/// each face's nearest neighbour inside the cluster and logs a severe message
/// for every face whose nearest neighbour is farther away than the threshold.
/// Alongside the distance from [cosineDistForNormVectors], the float32 and
/// float64 `1 - dot` distances to that same neighbour are logged so the
/// numeric precision of the three computations can be compared.
///
/// [context] is accepted for symmetry with the other popup actions and is not
/// used. Nothing is modified; the only output is log lines.
Future<void> _validateCluster(BuildContext context) async {
  // Faces whose nearest in-cluster neighbour is farther than this are
  // reported.
  const maxNearestNeighbourDistance = 0.3;

  _logger.info('_validateCluster called');
  final faceMlDb = FaceMLDataDB.instance;
  final faceIDs = await faceMlDb.getFaceIDsForCluster(widget.clusterID);
  final fileIDs = faceIDs.map((e) => getFileIdFromFaceId(e)).toList();
  final embeddingsBlobs = await faceMlDb.getFaceEmbeddingMapForFile(fileIDs);

  // The lookup is per file, so it may return faces that belong to other
  // clusters. Use a set so the filter does not rescan the ID collection for
  // every entry.
  final clusterFaceIDs = faceIDs.toSet();
  embeddingsBlobs.removeWhere((key, value) => !clusterFaceIDs.contains(key));

  final embeddings = embeddingsBlobs
      .map((key, value) => MapEntry(key, EVector.fromBuffer(value).values));

  // With fewer than two faces there is no neighbour to compare against;
  // without this guard a single-face cluster would be reported with a `null`
  // neighbour at infinite distance.
  if (embeddings.length < 2) {
    _logger.info(
      '_validateCluster: cluster has ${embeddings.length} embedding(s), '
      'nothing to compare',
    );
    return;
  }

  // Build the vectors once per face instead of once per pair: the pairwise
  // loop below is O(n^2) and the conversions depend only on the face itself.
  final faces = embeddings.entries.toList(growable: false);
  final vectors64 = [
    for (final face in faces)
      Vector.fromList(face.value, dtype: DType.float64),
  ];
  final vectors32 = [
    for (final face in faces)
      Vector.fromList(face.value, dtype: DType.float32),
  ];

  for (var i = 0; i < faces.length; i++) {
    double closestDistance = double.infinity;
    double closestDistance32 = double.infinity;
    double closestDistance64 = double.infinity;
    String? closestFaceID;

    for (var j = 0; j < faces.length; j++) {
      // Map keys are unique, so equal indices are the same face.
      if (i == j) {
        continue;
      }
      final distance = cosineDistForNormVectors(
        faces[i].value,
        faces[j].value,
      );
      // Strict comparison keeps the first neighbour on ties.
      if (distance < closestDistance) {
        closestDistance = distance;
        // NOTE(review): `1 - dot` equals the cosine distance only for
        // unit-length embeddings — confirm the stored vectors are normalized.
        closestDistance32 = 1.0 - vectors32[i].dot(vectors32[j]);
        closestDistance64 = 1.0 - vectors64[i].dot(vectors64[j]);
        closestFaceID = faces[j].key;
      }
    }

    // Despite the "is similar to" wording (kept unchanged so existing log
    // searches still match), this fires when even the closest face is farther
    // away than the threshold.
    if (closestDistance > maxNearestNeighbourDistance) {
      _logger.severe(
        "Face ${faces[i].key} is similar to $closestFaceID with distance $closestDistance, and float32 distance $closestDistance32, and float64 distance $closestDistance64",
      );
    }
  }
}
Future<void> _breakUpCluster(BuildContext context) async {
final newClusterIDToFaceIDs =
await ClusterFeedbackService.instance.breakUpCluster(widget.clusterID);