[mob] Clustering with dynamic threshold based on face blur and score

This commit is contained in:
laurenspriem 2024-04-17 16:38:47 +05:30
parent 72e677e9e5
commit 51d15cc441
9 changed files with 146 additions and 30 deletions

View file

@ -12,6 +12,7 @@ import 'package:photos/face/db_fields.dart';
import "package:photos/face/db_model_mappers.dart";
import "package:photos/face/model/face.dart";
import "package:photos/models/file/file.dart";
import "package:photos/services/machine_learning/face_ml/face_clustering/face_info_for_clustering.dart";
import 'package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart';
import 'package:sqflite/sqflite.dart';
import 'package:sqlite_async/sqlite_async.dart' as sqlite_async;
@ -444,12 +445,63 @@ class FaceMLDataDB {
);
}
/// Returns the faces that qualify for clustering, bundled with the data the
/// clustering step needs.
///
/// Every returned [FaceInfoForClustering] holds the face ID, the cluster ID
/// reported by [getFaceIdsToClusterIds] for that face (null when the lookup has
/// no entry), the raw embedding blob, the face score and the blur value.
///
/// Only selects faces with score greater than [minScore] and blur value
/// greater than [minClarity]. Rows are read ordered by face ID (descending),
/// starting at [offset], in pages of at most [batchSize] rows. Reading stops
/// when no more rows are available or once [maxFaces] faces were collected;
/// the result never holds more than [maxFaces] entries.
Future<Set<FaceInfoForClustering>> getFaceInfoForClustering({
  double minScore = kMinHighQualityFaceScore,
  int minClarity = kLaplacianHardThreshold,
  int maxFaces = 20000,
  int offset = 0,
  int batchSize = 10000,
}) async {
  // Label the watch after this method so timing logs are attributed correctly.
  final EnteWatch w = EnteWatch("getFaceInfoForClustering")..start();
  w.logAndReset(
    'reading as float offset: $offset, maxFaces: $maxFaces, batchSize: $batchSize',
  );
  final db = await instance.sqliteAsyncDB;
  final Set<FaceInfoForClustering> result = {};
  int nextOffset = offset;
  while (result.length < maxFaces) {
    // Shrink the last page so that the total never exceeds maxFaces; callers
    // that page through the table in buckets advance their own offset by the
    // bucket size and would otherwise see the surplus rows twice.
    final int remaining = maxFaces - result.length;
    final int limit = remaining < batchSize ? remaining : batchSize;

    // Query a batch of rows
    final List<Map<String, dynamic>> maps = await db.getAll(
      'SELECT $faceIDColumn, $faceEmbeddingBlob, $faceScore, $faceBlur FROM $facesTable'
      ' WHERE $faceScore > $minScore AND $faceBlur > $minClarity'
      ' ORDER BY $faceIDColumn'
      ' DESC LIMIT $limit OFFSET $nextOffset',
    );
    // Break the loop if no more rows
    if (maps.isEmpty) {
      break;
    }

    // Resolve the cluster assignments of the whole page in one lookup.
    final List<String> faceIds = [
      for (final map in maps) map[faceIDColumn] as String,
    ];
    final faceIdToClusterId = await getFaceIdsToClusterIds(faceIds);

    for (final map in maps) {
      final faceID = map[faceIDColumn] as String;
      result.add(
        FaceInfoForClustering(
          faceID: faceID,
          clusterId: faceIdToClusterId[faceID],
          embeddingBytes: map[faceEmbeddingBlob] as Uint8List,
          // Go through num so an integral value coming back from the driver
          // does not fail the cast.
          faceScore: (map[faceScore] as num).toDouble(),
          blurValue: (map[faceBlur] as num).toDouble(),
        ),
      );
    }

    // A short page means the query ran out of rows; skip the extra round trip.
    if (maps.length < limit) {
      break;
    }
    nextOffset += maps.length;
  }
  w.stopWithLog('done reading face embeddings ${result.length}');
  return result;
}
/// Returns a map of faceID to record of clusterId and faceEmbeddingBlob
///
/// Only selects faces with score greater than [minScore] and blur score greater than [minClarity]
Future<Map<String, (int?, Uint8List)>> getFaceEmbeddingMap({
double minScore = kMinHighQualityFaceScore,
int minClarity = kLaplacianThreshold,
int minClarity = kLaplacianHardThreshold,
int maxFaces = 20000,
int offset = 0,
int batchSize = 10000,
@ -515,7 +567,7 @@ class FaceMLDataDB {
facesTable,
columns: [faceIDColumn, faceEmbeddingBlob],
where:
'$faceScore > $kMinHighQualityFaceScore AND $faceBlur > $kLaplacianThreshold AND $fileIDColumn IN (${fileIDs.join(",")})',
'$faceScore > $kMinHighQualityFaceScore AND $faceBlur > $kLaplacianHardThreshold AND $fileIDColumn IN (${fileIDs.join(",")})',
limit: batchSize,
offset: offset,
orderBy: '$faceIDColumn DESC',
@ -542,7 +594,7 @@ class FaceMLDataDB {
}) async {
final db = await instance.sqliteAsyncDB;
final List<Map<String, dynamic>> maps = await db.getAll(
'SELECT COUNT(*) as count FROM $facesTable WHERE $faceScore > $minFaceScore AND $faceBlur > $kLaplacianThreshold',
'SELECT COUNT(*) as count FROM $facesTable WHERE $faceScore > $minFaceScore AND $faceBlur > $kLaplacianHardThreshold',
);
return maps.first['count'] as int;
}
@ -551,7 +603,7 @@ class FaceMLDataDB {
final db = await instance.sqliteAsyncDB;
final List<Map<String, dynamic>> totalFacesMaps = await db.getAll(
'SELECT COUNT(*) as count FROM $facesTable WHERE $faceScore > $kMinHighQualityFaceScore AND $faceBlur > $kLaplacianThreshold',
'SELECT COUNT(*) as count FROM $facesTable WHERE $faceScore > $kMinHighQualityFaceScore AND $faceBlur > $kLaplacianHardThreshold',
);
final int totalFaces = totalFacesMaps.first['count'] as int;
@ -564,7 +616,7 @@ class FaceMLDataDB {
}
Future<int> getBlurryFaceCount([
int blurThreshold = kLaplacianThreshold,
int blurThreshold = kLaplacianHardThreshold,
]) async {
final db = await instance.database;
final List<Map<String, dynamic>> maps = await db.rawQuery(

View file

@ -20,7 +20,7 @@ class Face {
final double blur;
FileInfo? fileInfo;
bool get isBlurry => blur < kLaplacianThreshold;
bool get isBlurry => blur < kLaplacianHardThreshold;
bool get hasHighScore => score > kMinHighQualityFaceScore;

View file

@ -9,12 +9,16 @@ import "package:ml_linalg/dtype.dart";
import "package:ml_linalg/vector.dart";
import "package:photos/generated/protos/ente/common/vector.pb.dart";
import 'package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart';
import "package:photos/services/machine_learning/face_ml/face_clustering/face_info_for_clustering.dart";
import "package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart";
import "package:photos/services/machine_learning/face_ml/face_ml_result.dart";
import "package:simple_cluster/simple_cluster.dart";
import "package:synchronized/synchronized.dart";
class FaceInfo {
final String faceID;
final double? faceScore;
final double? blurValue;
final List<double>? embedding;
final Vector? vEmbedding;
int? clusterId;
@ -23,6 +27,8 @@ class FaceInfo {
int? fileCreationTime;
FaceInfo({
required this.faceID,
this.faceScore,
this.blurValue,
this.embedding,
this.vEmbedding,
this.clusterId,
@ -49,6 +55,7 @@ class FaceClusteringService {
bool isRunning = false;
static const kRecommendedDistanceThreshold = 0.24;
static const kConservativeDistanceThreshold = 0.06;
// singleton pattern
FaceClusteringService._privateConstructor();
@ -180,9 +187,11 @@ class FaceClusteringService {
///
/// WARNING: Make sure to always input data in the same ordering, otherwise the clustering can be less deterministic.
Future<Map<String, int>?> predictLinear(
Map<String, (int?, Uint8List)> input, {
Set<FaceInfoForClustering> input, {
Map<int, int>? fileIDToCreationTime,
double distanceThreshold = kRecommendedDistanceThreshold,
double conservativeDistanceThreshold = kConservativeDistanceThreshold,
bool useDynamicThreshold = true,
int? offset,
}) async {
if (input.isEmpty) {
@ -212,6 +221,8 @@ class FaceClusteringService {
'input': input,
'fileIDToCreationTime': fileIDToCreationTime,
'distanceThreshold': distanceThreshold,
'conservativeDistanceThreshold': conservativeDistanceThreshold,
'useDynamicThreshold': useDynamicThreshold,
'offset': offset,
}
),
@ -280,9 +291,13 @@ class FaceClusteringService {
}
static Map<String, int> _runLinearClustering(Map args) {
final input = args['input'] as Map<String, (int?, Uint8List)>;
// final input = args['input'] as Map<String, (int?, Uint8List)>;
final input = args['input'] as Set<FaceInfoForClustering>;
final fileIDToCreationTime = args['fileIDToCreationTime'] as Map<int, int>?;
final distanceThreshold = args['distanceThreshold'] as double;
final conservativeDistanceThreshold =
args['conservativeDistanceThreshold'] as double;
final useDynamicThreshold = args['useDynamicThreshold'] as bool;
final offset = args['offset'] as int?;
log(
@ -291,17 +306,19 @@ class FaceClusteringService {
// Organize everything into a list of FaceInfo objects
final List<FaceInfo> faceInfos = [];
for (final entry in input.entries) {
for (final face in input) {
faceInfos.add(
FaceInfo(
faceID: entry.key,
faceID: face.faceID,
faceScore: face.faceScore,
blurValue: face.blurValue,
vEmbedding: Vector.fromList(
EVector.fromBuffer(entry.value.$2).values,
EVector.fromBuffer(face.embeddingBytes).values,
dtype: DType.float32,
),
clusterId: entry.value.$1,
clusterId: face.clusterId,
fileCreationTime:
fileIDToCreationTime?[getFileIdFromFaceId(entry.key)],
fileIDToCreationTime?[getFileIdFromFaceId(face.faceID)],
),
);
}
@ -341,6 +358,7 @@ class FaceClusteringService {
// Make sure the first face has a clusterId
final int totalFaces = sortedFaceInfos.length;
int dynamicThresholdCount = 0;
if (sortedFaceInfos.isEmpty) {
return {};
@ -368,6 +386,17 @@ class FaceClusteringService {
int closestIdx = -1;
double closestDistance = double.infinity;
late double thresholdValue;
if (useDynamicThreshold) {
final bool badFace =
(sortedFaceInfos[i].faceScore! < kMinHighQualityFaceScore ||
sortedFaceInfos[i].blurValue! < kLaplacianSoftThreshold);
thresholdValue =
badFace ? conservativeDistanceThreshold : distanceThreshold;
if (badFace) dynamicThresholdCount++;
} else {
thresholdValue = distanceThreshold;
}
if (i % 250 == 0) {
log("[ClusterIsolate] ${DateTime.now()} Processed ${offset != null ? i + offset : i} faces");
}
@ -396,7 +425,7 @@ class FaceClusteringService {
}
}
if (closestDistance < distanceThreshold) {
if (closestDistance < thresholdValue) {
if (sortedFaceInfos[closestIdx].clusterId == null) {
// Ideally this should never happen, but just in case log it
log(
@ -432,6 +461,11 @@ class FaceClusteringService {
log(
' [ClusterIsolate] ${DateTime.now()} Clustering for ${sortedFaceInfos.length} embeddings executed in ${stopwatchClustering.elapsedMilliseconds}ms',
);
if (useDynamicThreshold) {
log(
"[ClusterIsolate] ${DateTime.now()} Dynamic thresholding: $dynamicThresholdCount faces had a low face score or high blur value",
);
}
// analyze the results
FaceClusteringService._analyzeClusterResults(sortedFaceInfos);

View file

@ -0,0 +1,18 @@
import "dart:typed_data" show Uint8List;
/// Plain value holder describing one face that is fed into clustering.
///
/// All fields are set once through the constructor and never change.
class FaceInfoForClustering {
  /// Creates the holder; everything except [clusterId] has to be supplied.
  FaceInfoForClustering({
    required this.faceID,
    required this.embeddingBytes,
    required this.faceScore,
    required this.blurValue,
    this.clusterId,
  });

  /// Identifier of the face.
  final String faceID;

  /// Cluster the face is associated with, or null when none was given.
  final int? clusterId;

  /// Face embedding in its serialized byte form.
  final Uint8List embeddingBytes;

  /// Score of the face.
  final double faceScore;

  /// Blur value of the face.
  final double blurValue;
}

View file

@ -12,7 +12,7 @@ class BlurDetectionService {
Future<(bool, double)> predictIsBlurGrayLaplacian(
List<List<int>> grayImage, {
int threshold = kLaplacianThreshold,
int threshold = kLaplacianHardThreshold,
FaceDirection faceDirection = FaceDirection.straight,
}) async {
final List<List<int>> laplacian =

View file

@ -1,7 +1,8 @@
import 'package:photos/services/machine_learning/face_ml/face_detection/face_detection_service.dart';
/// Blur detection threshold
const kLaplacianThreshold = 15;
const kLaplacianHardThreshold = 15;
const kLaplacianSoftThreshold = 100;
/// Default blur value
const kLapacianDefault = 10000.0;

View file

@ -504,7 +504,7 @@ class FaceResult {
final int fileId;
final String faceId;
bool get isBlurry => blurValue < kLaplacianThreshold;
bool get isBlurry => blurValue < kLaplacianHardThreshold;
const FaceResult({
required this.detection,
@ -545,7 +545,7 @@ class FaceResultBuilder {
int fileId = -1;
String faceId = '';
bool get isBlurry => blurValue < kLaplacianThreshold;
bool get isBlurry => blurValue < kLaplacianHardThreshold;
FaceResultBuilder({
required this.fileId,

View file

@ -310,14 +310,14 @@ class FaceMlService {
int bucket = 1;
while (true) {
final faceIdToEmbeddingBucket =
await FaceMLDataDB.instance.getFaceEmbeddingMap(
final faceInfoForClustering =
await FaceMLDataDB.instance.getFaceInfoForClustering(
minScore: minFaceScore,
maxFaces: bucketSize,
offset: offset,
batchSize: batchSize,
);
if (faceIdToEmbeddingBucket.isEmpty) {
if (faceInfoForClustering.isEmpty) {
_logger.warning(
'faceIdToEmbeddingBucket is empty, this should ideally not happen as it should have stopped earlier. offset: $offset, totalFaces: $totalFaces',
);
@ -332,7 +332,7 @@ class FaceMlService {
final faceIdToCluster =
await FaceClusteringService.instance.predictLinear(
faceIdToEmbeddingBucket,
faceInfoForClustering,
fileIDToCreationTime: fileIDToCreationTime,
offset: offset,
);
@ -343,7 +343,7 @@ class FaceMlService {
await FaceMLDataDB.instance.updateClusterIdToFaceId(faceIdToCluster);
_logger.info(
'Done with clustering ${offset + faceIdToEmbeddingBucket.length} embeddings (${(100 * (offset + faceIdToEmbeddingBucket.length) / totalFaces).toStringAsFixed(0)}%) in bucket $bucket, offset: $offset',
'Done with clustering ${offset + faceInfoForClustering.length} embeddings (${(100 * (offset + faceInfoForClustering.length) / totalFaces).toStringAsFixed(0)}%) in bucket $bucket, offset: $offset',
);
if (offset + bucketSize >= totalFaces) {
_logger.info('All faces clustered');
@ -355,14 +355,14 @@ class FaceMlService {
} else {
// Read all the embeddings from the database, in a map from faceID to embedding
final clusterStartTime = DateTime.now();
final faceIdToEmbedding =
await FaceMLDataDB.instance.getFaceEmbeddingMap(
final faceInfoForClustering =
await FaceMLDataDB.instance.getFaceInfoForClustering(
minScore: minFaceScore,
maxFaces: totalFaces,
);
final gotFaceEmbeddingsTime = DateTime.now();
_logger.info(
'read embeddings ${faceIdToEmbedding.length} in ${gotFaceEmbeddingsTime.difference(clusterStartTime).inMilliseconds} ms',
'read embeddings ${faceInfoForClustering.length} in ${gotFaceEmbeddingsTime.difference(clusterStartTime).inMilliseconds} ms',
);
// Read the creation times from Files DB, in a map from fileID to creation time
@ -374,7 +374,7 @@ class FaceMlService {
// Cluster the embeddings using the linear clustering algorithm, returning a map from faceID to clusterID
final faceIdToCluster =
await FaceClusteringService.instance.predictLinear(
faceIdToEmbedding,
faceInfoForClustering,
fileIDToCreationTime: fileIDToCreationTime,
);
if (faceIdToCluster == null) {
@ -383,7 +383,7 @@ class FaceMlService {
}
final clusterDoneTime = DateTime.now();
_logger.info(
'done with clustering ${faceIdToEmbedding.length} in ${clusterDoneTime.difference(clusterStartTime).inSeconds} seconds ',
'done with clustering ${faceInfoForClustering.length} in ${clusterDoneTime.difference(clusterStartTime).inSeconds} seconds ',
);
// Store the updated clusterIDs in the database

View file

@ -13,6 +13,8 @@ import "package:photos/generated/protos/ente/common/vector.pb.dart";
import "package:photos/models/file/file.dart";
import 'package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart';
import "package:photos/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart";
import "package:photos/services/machine_learning/face_ml/face_clustering/face_info_for_clustering.dart";
import "package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart";
import "package:photos/services/machine_learning/face_ml/face_ml_result.dart";
import "package:photos/services/machine_learning/face_ml/person/person_service.dart";
import "package:photos/services/search_service.dart";
@ -233,14 +235,23 @@ class ClusterFeedbackService {
}
} else {
final clusteringInput = embeddings.map((key, value) {
return MapEntry(key, (null, value));
});
return MapEntry(
key,
FaceInfoForClustering(
faceID: key,
embeddingBytes: value,
faceScore: kMinHighQualityFaceScore + 0.01,
blurValue: kLapacianDefault,
),
);
}).values.toSet();
final faceIdToCluster =
await FaceClusteringService.instance.predictLinear(
clusteringInput,
fileIDToCreationTime: fileIDToCreationTime,
distanceThreshold: 0.23,
useDynamicThreshold: false,
);
if (faceIdToCluster == null) {