[mob] Use ml_linalg Vector dot product for cosine distance in face clustering

This commit is contained in:
Neeraj Gupta 2024-04-02 11:53:40 +05:30
parent 4cb7334868
commit 226808aadb
3 changed files with 38 additions and 10 deletions

View file

@ -5,6 +5,8 @@ import "dart:math" show max;
import "dart:typed_data"; import "dart:typed_data";
import "package:logging/logging.dart"; import "package:logging/logging.dart";
import "package:ml_linalg/dtype.dart";
import "package:ml_linalg/vector.dart";
import "package:photos/generated/protos/ente/common/vector.pb.dart"; import "package:photos/generated/protos/ente/common/vector.pb.dart";
import 'package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart'; import 'package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart';
import "package:photos/services/machine_learning/face_ml/face_ml_result.dart"; import "package:photos/services/machine_learning/face_ml/face_ml_result.dart";
@ -12,14 +14,16 @@ import "package:synchronized/synchronized.dart";
class FaceInfo { class FaceInfo {
final String faceID; final String faceID;
final List<double> embedding; final List<double>? embedding;
final Vector? vEmbedding;
int? clusterId; int? clusterId;
String? closestFaceId; String? closestFaceId;
int? closestDist; int? closestDist;
int? fileCreationTime; int? fileCreationTime;
FaceInfo({ FaceInfo({
required this.faceID, required this.faceID,
required this.embedding, this.embedding,
this.vEmbedding,
this.clusterId, this.clusterId,
this.fileCreationTime, this.fileCreationTime,
}); });
@ -230,7 +234,10 @@ class FaceLinearClustering {
faceInfos.add( faceInfos.add(
FaceInfo( FaceInfo(
faceID: entry.key, faceID: entry.key,
embedding: EVector.fromBuffer(entry.value.$2).values, vEmbedding: Vector.fromList(
EVector.fromBuffer(entry.value.$2).values,
dtype: DType.float32,
),
clusterId: entry.value.$1, clusterId: entry.value.$1,
fileCreationTime: fileCreationTime:
fileIDToCreationTime?[getFileIdFromFaceId(entry.key)], fileIDToCreationTime?[getFileIdFromFaceId(entry.key)],
@ -300,17 +307,25 @@ class FaceLinearClustering {
} }
continue; continue;
} }
final currentEmbedding = sortedFaceInfos[i].embedding;
int closestIdx = -1; int closestIdx = -1;
double closestDistance = double.infinity; double closestDistance = double.infinity;
if (i % 250 == 0) { if (i % 250 == 0) {
log("[ClusterIsolate] ${DateTime.now()} Processing $i faces"); log("[ClusterIsolate] ${DateTime.now()} Processing $i faces");
} }
for (int j = i - 1; j >= 0; j--) { for (int j = i - 1; j >= 0; j--) {
final double distance = cosineDistForNormVectors( late double distance;
currentEmbedding, if (sortedFaceInfos[i].vEmbedding != null) {
sortedFaceInfos[j].embedding, distance = 1.0 -
); sortedFaceInfos[i]
.vEmbedding!
.dot(sortedFaceInfos[j].vEmbedding!);
} else {
distance = cosineDistForNormVectors(
sortedFaceInfos[i].embedding!,
sortedFaceInfos[j].embedding!,
);
}
if (distance < closestDistance) { if (distance < closestDistance) {
closestDistance = distance; closestDistance = distance;
closestIdx = j; closestIdx = j;
@ -339,7 +354,7 @@ class FaceLinearClustering {
stopwatchClustering.stop(); stopwatchClustering.stop();
log( log(
' [ClusterIsolate] ${DateTime.now()} Clustering for ${sortedFaceInfos.length} embeddings (${sortedFaceInfos[0].embedding.length} size) executed in ${stopwatchClustering.elapsedMilliseconds}ms, clusters $clusterID', ' [ClusterIsolate] ${DateTime.now()} Clustering for ${sortedFaceInfos.length} embeddings executed in ${stopwatchClustering.elapsedMilliseconds}ms, clusters $clusterID',
); );
// analyze the results // analyze the results

View file

@ -389,6 +389,10 @@ class ClusterFeedbackService {
if (ignoredClusters.contains(clusterID)) { if (ignoredClusters.contains(clusterID)) {
continue; continue;
} }
if (allClusterIdsToCountMap[clusterID]! < 2) {
continue;
}
late List<double> avg; late List<double> avg;
if (clusterToSummary[clusterID]?.$2 == if (clusterToSummary[clusterID]?.$2 ==
allClusterIdsToCountMap[clusterID]) { allClusterIdsToCountMap[clusterID]) {
@ -412,6 +416,11 @@ class ClusterFeedbackService {
if (updatesForClusterSummary.length > 100) { if (updatesForClusterSummary.length > 100) {
await faceMlDb.clusterSummaryUpdate(updatesForClusterSummary); await faceMlDb.clusterSummaryUpdate(updatesForClusterSummary);
updatesForClusterSummary.clear(); updatesForClusterSummary.clear();
if (kDebugMode) {
_logger.info(
'start getUpdateClusterAvg for ${allClusterIdsToCountMap.length} clusters',
);
}
} }
clusterAvg[clusterID] = avg; clusterAvg[clusterID] = avg;
} }
@ -441,6 +450,10 @@ class ClusterFeedbackService {
int? nearestPersonCluster; int? nearestPersonCluster;
double? minDistance; double? minDistance;
for (final personCluster in personClusters) { for (final personCluster in personClusters) {
if (clusterAvg[personCluster] == null) {
_logger.info('no avg for cluster $personCluster');
continue;
}
final avg = clusterAvg[personCluster]!; final avg = clusterAvg[personCluster]!;
final distance = cosineDistForNormVectors(avg, otherAvg); final distance = cosineDistForNormVectors(avg, otherAvg);
if (distance < maxClusterDistance) { if (distance < maxClusterDistance) {

View file

@ -93,7 +93,7 @@ class PersonFaceWidget extends StatelessWidget {
); );
if (face == null) { if (face == null) {
debugPrint( debugPrint(
"No cover face for person: $personId and cluster $clusterID", "No cover face for person: $personId and cluster $clusterID and recentFile ${file.uploadedFileID}",
); );
return null; return null;
} }