[mob] Use vector for cosine dist
This commit is contained in:
parent
4cb7334868
commit
226808aadb
|
@ -5,6 +5,8 @@ import "dart:math" show max;
|
||||||
import "dart:typed_data";
|
import "dart:typed_data";
|
||||||
|
|
||||||
import "package:logging/logging.dart";
|
import "package:logging/logging.dart";
|
||||||
|
import "package:ml_linalg/dtype.dart";
|
||||||
|
import "package:ml_linalg/vector.dart";
|
||||||
import "package:photos/generated/protos/ente/common/vector.pb.dart";
|
import "package:photos/generated/protos/ente/common/vector.pb.dart";
|
||||||
import 'package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart';
|
import 'package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart';
|
||||||
import "package:photos/services/machine_learning/face_ml/face_ml_result.dart";
|
import "package:photos/services/machine_learning/face_ml/face_ml_result.dart";
|
||||||
|
@ -12,14 +14,16 @@ import "package:synchronized/synchronized.dart";
|
||||||
|
|
||||||
class FaceInfo {
|
class FaceInfo {
|
||||||
final String faceID;
|
final String faceID;
|
||||||
final List<double> embedding;
|
final List<double>? embedding;
|
||||||
|
final Vector? vEmbedding;
|
||||||
int? clusterId;
|
int? clusterId;
|
||||||
String? closestFaceId;
|
String? closestFaceId;
|
||||||
int? closestDist;
|
int? closestDist;
|
||||||
int? fileCreationTime;
|
int? fileCreationTime;
|
||||||
FaceInfo({
|
FaceInfo({
|
||||||
required this.faceID,
|
required this.faceID,
|
||||||
required this.embedding,
|
this.embedding,
|
||||||
|
this.vEmbedding,
|
||||||
this.clusterId,
|
this.clusterId,
|
||||||
this.fileCreationTime,
|
this.fileCreationTime,
|
||||||
});
|
});
|
||||||
|
@ -230,7 +234,10 @@ class FaceLinearClustering {
|
||||||
faceInfos.add(
|
faceInfos.add(
|
||||||
FaceInfo(
|
FaceInfo(
|
||||||
faceID: entry.key,
|
faceID: entry.key,
|
||||||
embedding: EVector.fromBuffer(entry.value.$2).values,
|
vEmbedding: Vector.fromList(
|
||||||
|
EVector.fromBuffer(entry.value.$2).values,
|
||||||
|
dtype: DType.float32,
|
||||||
|
),
|
||||||
clusterId: entry.value.$1,
|
clusterId: entry.value.$1,
|
||||||
fileCreationTime:
|
fileCreationTime:
|
||||||
fileIDToCreationTime?[getFileIdFromFaceId(entry.key)],
|
fileIDToCreationTime?[getFileIdFromFaceId(entry.key)],
|
||||||
|
@ -300,17 +307,25 @@ class FaceLinearClustering {
|
||||||
}
|
}
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
final currentEmbedding = sortedFaceInfos[i].embedding;
|
|
||||||
int closestIdx = -1;
|
int closestIdx = -1;
|
||||||
double closestDistance = double.infinity;
|
double closestDistance = double.infinity;
|
||||||
if (i % 250 == 0) {
|
if (i % 250 == 0) {
|
||||||
log("[ClusterIsolate] ${DateTime.now()} Processing $i faces");
|
log("[ClusterIsolate] ${DateTime.now()} Processing $i faces");
|
||||||
}
|
}
|
||||||
for (int j = i - 1; j >= 0; j--) {
|
for (int j = i - 1; j >= 0; j--) {
|
||||||
final double distance = cosineDistForNormVectors(
|
late double distance;
|
||||||
currentEmbedding,
|
if (sortedFaceInfos[i].vEmbedding != null) {
|
||||||
sortedFaceInfos[j].embedding,
|
distance = 1.0 -
|
||||||
);
|
sortedFaceInfos[i]
|
||||||
|
.vEmbedding!
|
||||||
|
.dot(sortedFaceInfos[j].vEmbedding!);
|
||||||
|
} else {
|
||||||
|
distance = cosineDistForNormVectors(
|
||||||
|
sortedFaceInfos[i].embedding!,
|
||||||
|
sortedFaceInfos[j].embedding!,
|
||||||
|
);
|
||||||
|
}
|
||||||
if (distance < closestDistance) {
|
if (distance < closestDistance) {
|
||||||
closestDistance = distance;
|
closestDistance = distance;
|
||||||
closestIdx = j;
|
closestIdx = j;
|
||||||
|
@ -339,7 +354,7 @@ class FaceLinearClustering {
|
||||||
|
|
||||||
stopwatchClustering.stop();
|
stopwatchClustering.stop();
|
||||||
log(
|
log(
|
||||||
' [ClusterIsolate] ${DateTime.now()} Clustering for ${sortedFaceInfos.length} embeddings (${sortedFaceInfos[0].embedding.length} size) executed in ${stopwatchClustering.elapsedMilliseconds}ms, clusters $clusterID',
|
' [ClusterIsolate] ${DateTime.now()} Clustering for ${sortedFaceInfos.length} embeddings executed in ${stopwatchClustering.elapsedMilliseconds}ms, clusters $clusterID',
|
||||||
);
|
);
|
||||||
|
|
||||||
// analyze the results
|
// analyze the results
|
||||||
|
|
|
@ -389,6 +389,10 @@ class ClusterFeedbackService {
|
||||||
if (ignoredClusters.contains(clusterID)) {
|
if (ignoredClusters.contains(clusterID)) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
if (allClusterIdsToCountMap[clusterID]! < 2) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
late List<double> avg;
|
late List<double> avg;
|
||||||
if (clusterToSummary[clusterID]?.$2 ==
|
if (clusterToSummary[clusterID]?.$2 ==
|
||||||
allClusterIdsToCountMap[clusterID]) {
|
allClusterIdsToCountMap[clusterID]) {
|
||||||
|
@ -412,6 +416,11 @@ class ClusterFeedbackService {
|
||||||
if (updatesForClusterSummary.length > 100) {
|
if (updatesForClusterSummary.length > 100) {
|
||||||
await faceMlDb.clusterSummaryUpdate(updatesForClusterSummary);
|
await faceMlDb.clusterSummaryUpdate(updatesForClusterSummary);
|
||||||
updatesForClusterSummary.clear();
|
updatesForClusterSummary.clear();
|
||||||
|
if (kDebugMode) {
|
||||||
|
_logger.info(
|
||||||
|
'start getUpdateClusterAvg for ${allClusterIdsToCountMap.length} clusters',
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
clusterAvg[clusterID] = avg;
|
clusterAvg[clusterID] = avg;
|
||||||
}
|
}
|
||||||
|
@ -441,6 +450,10 @@ class ClusterFeedbackService {
|
||||||
int? nearestPersonCluster;
|
int? nearestPersonCluster;
|
||||||
double? minDistance;
|
double? minDistance;
|
||||||
for (final personCluster in personClusters) {
|
for (final personCluster in personClusters) {
|
||||||
|
if (clusterAvg[personCluster] == null) {
|
||||||
|
_logger.info('no avg for cluster $personCluster');
|
||||||
|
continue;
|
||||||
|
}
|
||||||
final avg = clusterAvg[personCluster]!;
|
final avg = clusterAvg[personCluster]!;
|
||||||
final distance = cosineDistForNormVectors(avg, otherAvg);
|
final distance = cosineDistForNormVectors(avg, otherAvg);
|
||||||
if (distance < maxClusterDistance) {
|
if (distance < maxClusterDistance) {
|
||||||
|
|
|
@ -93,7 +93,7 @@ class PersonFaceWidget extends StatelessWidget {
|
||||||
);
|
);
|
||||||
if (face == null) {
|
if (face == null) {
|
||||||
debugPrint(
|
debugPrint(
|
||||||
"No cover face for person: $personId and cluster $clusterID",
|
"No cover face for person: $personId and cluster $clusterID and recentFile ${file.uploadedFileID}",
|
||||||
);
|
);
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue