[mob] Use vector for cosine dist
This commit is contained in:
parent
4cb7334868
commit
226808aadb
|
@ -5,6 +5,8 @@ import "dart:math" show max;
|
|||
import "dart:typed_data";
|
||||
|
||||
import "package:logging/logging.dart";
|
||||
import "package:ml_linalg/dtype.dart";
|
||||
import "package:ml_linalg/vector.dart";
|
||||
import "package:photos/generated/protos/ente/common/vector.pb.dart";
|
||||
import 'package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart';
|
||||
import "package:photos/services/machine_learning/face_ml/face_ml_result.dart";
|
||||
|
@ -12,14 +14,16 @@ import "package:synchronized/synchronized.dart";
|
|||
|
||||
class FaceInfo {
|
||||
final String faceID;
|
||||
final List<double> embedding;
|
||||
final List<double>? embedding;
|
||||
final Vector? vEmbedding;
|
||||
int? clusterId;
|
||||
String? closestFaceId;
|
||||
int? closestDist;
|
||||
int? fileCreationTime;
|
||||
FaceInfo({
|
||||
required this.faceID,
|
||||
required this.embedding,
|
||||
this.embedding,
|
||||
this.vEmbedding,
|
||||
this.clusterId,
|
||||
this.fileCreationTime,
|
||||
});
|
||||
|
@ -230,7 +234,10 @@ class FaceLinearClustering {
|
|||
faceInfos.add(
|
||||
FaceInfo(
|
||||
faceID: entry.key,
|
||||
embedding: EVector.fromBuffer(entry.value.$2).values,
|
||||
vEmbedding: Vector.fromList(
|
||||
EVector.fromBuffer(entry.value.$2).values,
|
||||
dtype: DType.float32,
|
||||
),
|
||||
clusterId: entry.value.$1,
|
||||
fileCreationTime:
|
||||
fileIDToCreationTime?[getFileIdFromFaceId(entry.key)],
|
||||
|
@ -300,17 +307,25 @@ class FaceLinearClustering {
|
|||
}
|
||||
continue;
|
||||
}
|
||||
final currentEmbedding = sortedFaceInfos[i].embedding;
|
||||
|
||||
int closestIdx = -1;
|
||||
double closestDistance = double.infinity;
|
||||
if (i % 250 == 0) {
|
||||
log("[ClusterIsolate] ${DateTime.now()} Processing $i faces");
|
||||
}
|
||||
for (int j = i - 1; j >= 0; j--) {
|
||||
final double distance = cosineDistForNormVectors(
|
||||
currentEmbedding,
|
||||
sortedFaceInfos[j].embedding,
|
||||
);
|
||||
late double distance;
|
||||
if (sortedFaceInfos[i].vEmbedding != null) {
|
||||
distance = 1.0 -
|
||||
sortedFaceInfos[i]
|
||||
.vEmbedding!
|
||||
.dot(sortedFaceInfos[j].vEmbedding!);
|
||||
} else {
|
||||
distance = cosineDistForNormVectors(
|
||||
sortedFaceInfos[i].embedding!,
|
||||
sortedFaceInfos[j].embedding!,
|
||||
);
|
||||
}
|
||||
if (distance < closestDistance) {
|
||||
closestDistance = distance;
|
||||
closestIdx = j;
|
||||
|
@ -339,7 +354,7 @@ class FaceLinearClustering {
|
|||
|
||||
stopwatchClustering.stop();
|
||||
log(
|
||||
' [ClusterIsolate] ${DateTime.now()} Clustering for ${sortedFaceInfos.length} embeddings (${sortedFaceInfos[0].embedding.length} size) executed in ${stopwatchClustering.elapsedMilliseconds}ms, clusters $clusterID',
|
||||
' [ClusterIsolate] ${DateTime.now()} Clustering for ${sortedFaceInfos.length} embeddings executed in ${stopwatchClustering.elapsedMilliseconds}ms, clusters $clusterID',
|
||||
);
|
||||
|
||||
// analyze the results
|
||||
|
|
|
@ -389,6 +389,10 @@ class ClusterFeedbackService {
|
|||
if (ignoredClusters.contains(clusterID)) {
|
||||
continue;
|
||||
}
|
||||
if (allClusterIdsToCountMap[clusterID]! < 2) {
|
||||
continue;
|
||||
}
|
||||
|
||||
late List<double> avg;
|
||||
if (clusterToSummary[clusterID]?.$2 ==
|
||||
allClusterIdsToCountMap[clusterID]) {
|
||||
|
@ -412,6 +416,11 @@ class ClusterFeedbackService {
|
|||
if (updatesForClusterSummary.length > 100) {
|
||||
await faceMlDb.clusterSummaryUpdate(updatesForClusterSummary);
|
||||
updatesForClusterSummary.clear();
|
||||
if (kDebugMode) {
|
||||
_logger.info(
|
||||
'start getUpdateClusterAvg for ${allClusterIdsToCountMap.length} clusters',
|
||||
);
|
||||
}
|
||||
}
|
||||
clusterAvg[clusterID] = avg;
|
||||
}
|
||||
|
@ -441,6 +450,10 @@ class ClusterFeedbackService {
|
|||
int? nearestPersonCluster;
|
||||
double? minDistance;
|
||||
for (final personCluster in personClusters) {
|
||||
if (clusterAvg[personCluster] == null) {
|
||||
_logger.info('no avg for cluster $personCluster');
|
||||
continue;
|
||||
}
|
||||
final avg = clusterAvg[personCluster]!;
|
||||
final distance = cosineDistForNormVectors(avg, otherAvg);
|
||||
if (distance < maxClusterDistance) {
|
||||
|
|
|
@ -93,7 +93,7 @@ class PersonFaceWidget extends StatelessWidget {
|
|||
);
|
||||
if (face == null) {
|
||||
debugPrint(
|
||||
"No cover face for person: $personId and cluster $clusterID",
|
||||
"No cover face for person: $personId and cluster $clusterID and recentFile ${file.uploadedFileID}",
|
||||
);
|
||||
return null;
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue