[mob] Clustering with dynamic threshold based on face blur and score
This commit is contained in:
parent
72e677e9e5
commit
51d15cc441
|
@ -12,6 +12,7 @@ import 'package:photos/face/db_fields.dart';
|
|||
import "package:photos/face/db_model_mappers.dart";
|
||||
import "package:photos/face/model/face.dart";
|
||||
import "package:photos/models/file/file.dart";
|
||||
import "package:photos/services/machine_learning/face_ml/face_clustering/face_info_for_clustering.dart";
|
||||
import 'package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart';
|
||||
import 'package:sqflite/sqflite.dart';
|
||||
import 'package:sqlite_async/sqlite_async.dart' as sqlite_async;
|
||||
|
@ -444,12 +445,63 @@ class FaceMLDataDB {
|
|||
);
|
||||
}
|
||||
|
||||
Future<Set<FaceInfoForClustering>> getFaceInfoForClustering({
|
||||
double minScore = kMinHighQualityFaceScore,
|
||||
int minClarity = kLaplacianHardThreshold,
|
||||
int maxFaces = 20000,
|
||||
int offset = 0,
|
||||
int batchSize = 10000,
|
||||
}) async {
|
||||
final EnteWatch w = EnteWatch("getFaceEmbeddingMap")..start();
|
||||
w.logAndReset(
|
||||
'reading as float offset: $offset, maxFaces: $maxFaces, batchSize: $batchSize',
|
||||
);
|
||||
final db = await instance.sqliteAsyncDB;
|
||||
|
||||
final Set<FaceInfoForClustering> result = {};
|
||||
while (true) {
|
||||
// Query a batch of rows
|
||||
final List<Map<String, dynamic>> maps = await db.getAll(
|
||||
'SELECT $faceIDColumn, $faceEmbeddingBlob, $faceScore, $faceBlur FROM $facesTable'
|
||||
' WHERE $faceScore > $minScore AND $faceBlur > $minClarity'
|
||||
' ORDER BY $faceIDColumn'
|
||||
' DESC LIMIT $batchSize OFFSET $offset',
|
||||
);
|
||||
// Break the loop if no more rows
|
||||
if (maps.isEmpty) {
|
||||
break;
|
||||
}
|
||||
final List<String> faceIds = [];
|
||||
for (final map in maps) {
|
||||
faceIds.add(map[faceIDColumn] as String);
|
||||
}
|
||||
final faceIdToClusterId = await getFaceIdsToClusterIds(faceIds);
|
||||
for (final map in maps) {
|
||||
final faceID = map[faceIDColumn] as String;
|
||||
final faceInfo = FaceInfoForClustering(
|
||||
faceID: faceID,
|
||||
clusterId: faceIdToClusterId[faceID],
|
||||
embeddingBytes: map[faceEmbeddingBlob] as Uint8List,
|
||||
faceScore: map[faceScore] as double,
|
||||
blurValue: map[faceBlur] as double,
|
||||
);
|
||||
result.add(faceInfo);
|
||||
}
|
||||
if (result.length >= maxFaces) {
|
||||
break;
|
||||
}
|
||||
offset += batchSize;
|
||||
}
|
||||
w.stopWithLog('done reading face embeddings ${result.length}');
|
||||
return result;
|
||||
}
|
||||
|
||||
/// Returns a map of faceID to record of clusterId and faceEmbeddingBlob
|
||||
///
|
||||
/// Only selects faces with score greater than [minScore] and blur score greater than [minClarity]
|
||||
Future<Map<String, (int?, Uint8List)>> getFaceEmbeddingMap({
|
||||
double minScore = kMinHighQualityFaceScore,
|
||||
int minClarity = kLaplacianThreshold,
|
||||
int minClarity = kLaplacianHardThreshold,
|
||||
int maxFaces = 20000,
|
||||
int offset = 0,
|
||||
int batchSize = 10000,
|
||||
|
@ -515,7 +567,7 @@ class FaceMLDataDB {
|
|||
facesTable,
|
||||
columns: [faceIDColumn, faceEmbeddingBlob],
|
||||
where:
|
||||
'$faceScore > $kMinHighQualityFaceScore AND $faceBlur > $kLaplacianThreshold AND $fileIDColumn IN (${fileIDs.join(",")})',
|
||||
'$faceScore > $kMinHighQualityFaceScore AND $faceBlur > $kLaplacianHardThreshold AND $fileIDColumn IN (${fileIDs.join(",")})',
|
||||
limit: batchSize,
|
||||
offset: offset,
|
||||
orderBy: '$faceIDColumn DESC',
|
||||
|
@ -542,7 +594,7 @@ class FaceMLDataDB {
|
|||
}) async {
|
||||
final db = await instance.sqliteAsyncDB;
|
||||
final List<Map<String, dynamic>> maps = await db.getAll(
|
||||
'SELECT COUNT(*) as count FROM $facesTable WHERE $faceScore > $minFaceScore AND $faceBlur > $kLaplacianThreshold',
|
||||
'SELECT COUNT(*) as count FROM $facesTable WHERE $faceScore > $minFaceScore AND $faceBlur > $kLaplacianHardThreshold',
|
||||
);
|
||||
return maps.first['count'] as int;
|
||||
}
|
||||
|
@ -551,7 +603,7 @@ class FaceMLDataDB {
|
|||
final db = await instance.sqliteAsyncDB;
|
||||
|
||||
final List<Map<String, dynamic>> totalFacesMaps = await db.getAll(
|
||||
'SELECT COUNT(*) as count FROM $facesTable WHERE $faceScore > $kMinHighQualityFaceScore AND $faceBlur > $kLaplacianThreshold',
|
||||
'SELECT COUNT(*) as count FROM $facesTable WHERE $faceScore > $kMinHighQualityFaceScore AND $faceBlur > $kLaplacianHardThreshold',
|
||||
);
|
||||
final int totalFaces = totalFacesMaps.first['count'] as int;
|
||||
|
||||
|
@ -564,7 +616,7 @@ class FaceMLDataDB {
|
|||
}
|
||||
|
||||
Future<int> getBlurryFaceCount([
|
||||
int blurThreshold = kLaplacianThreshold,
|
||||
int blurThreshold = kLaplacianHardThreshold,
|
||||
]) async {
|
||||
final db = await instance.database;
|
||||
final List<Map<String, dynamic>> maps = await db.rawQuery(
|
||||
|
|
|
@ -20,7 +20,7 @@ class Face {
|
|||
final double blur;
|
||||
FileInfo? fileInfo;
|
||||
|
||||
bool get isBlurry => blur < kLaplacianThreshold;
|
||||
bool get isBlurry => blur < kLaplacianHardThreshold;
|
||||
|
||||
bool get hasHighScore => score > kMinHighQualityFaceScore;
|
||||
|
||||
|
|
|
@ -9,12 +9,16 @@ import "package:ml_linalg/dtype.dart";
|
|||
import "package:ml_linalg/vector.dart";
|
||||
import "package:photos/generated/protos/ente/common/vector.pb.dart";
|
||||
import 'package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart';
|
||||
import "package:photos/services/machine_learning/face_ml/face_clustering/face_info_for_clustering.dart";
|
||||
import "package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart";
|
||||
import "package:photos/services/machine_learning/face_ml/face_ml_result.dart";
|
||||
import "package:simple_cluster/simple_cluster.dart";
|
||||
import "package:synchronized/synchronized.dart";
|
||||
|
||||
class FaceInfo {
|
||||
final String faceID;
|
||||
final double? faceScore;
|
||||
final double? blurValue;
|
||||
final List<double>? embedding;
|
||||
final Vector? vEmbedding;
|
||||
int? clusterId;
|
||||
|
@ -23,6 +27,8 @@ class FaceInfo {
|
|||
int? fileCreationTime;
|
||||
FaceInfo({
|
||||
required this.faceID,
|
||||
this.faceScore,
|
||||
this.blurValue,
|
||||
this.embedding,
|
||||
this.vEmbedding,
|
||||
this.clusterId,
|
||||
|
@ -49,6 +55,7 @@ class FaceClusteringService {
|
|||
bool isRunning = false;
|
||||
|
||||
static const kRecommendedDistanceThreshold = 0.24;
|
||||
static const kConservativeDistanceThreshold = 0.06;
|
||||
|
||||
// singleton pattern
|
||||
FaceClusteringService._privateConstructor();
|
||||
|
@ -180,9 +187,11 @@ class FaceClusteringService {
|
|||
///
|
||||
/// WARNING: Make sure to always input data in the same ordering, otherwise the clustering can less less deterministic.
|
||||
Future<Map<String, int>?> predictLinear(
|
||||
Map<String, (int?, Uint8List)> input, {
|
||||
Set<FaceInfoForClustering> input, {
|
||||
Map<int, int>? fileIDToCreationTime,
|
||||
double distanceThreshold = kRecommendedDistanceThreshold,
|
||||
double conservativeDistanceThreshold = kConservativeDistanceThreshold,
|
||||
bool useDynamicThreshold = true,
|
||||
int? offset,
|
||||
}) async {
|
||||
if (input.isEmpty) {
|
||||
|
@ -212,6 +221,8 @@ class FaceClusteringService {
|
|||
'input': input,
|
||||
'fileIDToCreationTime': fileIDToCreationTime,
|
||||
'distanceThreshold': distanceThreshold,
|
||||
'conservativeDistanceThreshold': conservativeDistanceThreshold,
|
||||
'useDynamicThreshold': useDynamicThreshold,
|
||||
'offset': offset,
|
||||
}
|
||||
),
|
||||
|
@ -280,9 +291,13 @@ class FaceClusteringService {
|
|||
}
|
||||
|
||||
static Map<String, int> _runLinearClustering(Map args) {
|
||||
final input = args['input'] as Map<String, (int?, Uint8List)>;
|
||||
// final input = args['input'] as Map<String, (int?, Uint8List)>;
|
||||
final input = args['input'] as Set<FaceInfoForClustering>;
|
||||
final fileIDToCreationTime = args['fileIDToCreationTime'] as Map<int, int>?;
|
||||
final distanceThreshold = args['distanceThreshold'] as double;
|
||||
final conservativeDistanceThreshold =
|
||||
args['conservativeDistanceThreshold'] as double;
|
||||
final useDynamicThreshold = args['useDynamicThreshold'] as bool;
|
||||
final offset = args['offset'] as int?;
|
||||
|
||||
log(
|
||||
|
@ -291,17 +306,19 @@ class FaceClusteringService {
|
|||
|
||||
// Organize everything into a list of FaceInfo objects
|
||||
final List<FaceInfo> faceInfos = [];
|
||||
for (final entry in input.entries) {
|
||||
for (final face in input) {
|
||||
faceInfos.add(
|
||||
FaceInfo(
|
||||
faceID: entry.key,
|
||||
faceID: face.faceID,
|
||||
faceScore: face.faceScore,
|
||||
blurValue: face.blurValue,
|
||||
vEmbedding: Vector.fromList(
|
||||
EVector.fromBuffer(entry.value.$2).values,
|
||||
EVector.fromBuffer(face.embeddingBytes).values,
|
||||
dtype: DType.float32,
|
||||
),
|
||||
clusterId: entry.value.$1,
|
||||
clusterId: face.clusterId,
|
||||
fileCreationTime:
|
||||
fileIDToCreationTime?[getFileIdFromFaceId(entry.key)],
|
||||
fileIDToCreationTime?[getFileIdFromFaceId(face.faceID)],
|
||||
),
|
||||
);
|
||||
}
|
||||
|
@ -341,6 +358,7 @@ class FaceClusteringService {
|
|||
|
||||
// Make sure the first face has a clusterId
|
||||
final int totalFaces = sortedFaceInfos.length;
|
||||
int dynamicThresholdCount = 0;
|
||||
|
||||
if (sortedFaceInfos.isEmpty) {
|
||||
return {};
|
||||
|
@ -368,6 +386,17 @@ class FaceClusteringService {
|
|||
|
||||
int closestIdx = -1;
|
||||
double closestDistance = double.infinity;
|
||||
late double thresholdValue;
|
||||
if (useDynamicThreshold) {
|
||||
final bool badFace =
|
||||
(sortedFaceInfos[i].faceScore! < kMinHighQualityFaceScore ||
|
||||
sortedFaceInfos[i].blurValue! < kLaplacianSoftThreshold);
|
||||
thresholdValue =
|
||||
badFace ? conservativeDistanceThreshold : distanceThreshold;
|
||||
if (badFace) dynamicThresholdCount++;
|
||||
} else {
|
||||
thresholdValue = distanceThreshold;
|
||||
}
|
||||
if (i % 250 == 0) {
|
||||
log("[ClusterIsolate] ${DateTime.now()} Processed ${offset != null ? i + offset : i} faces");
|
||||
}
|
||||
|
@ -396,7 +425,7 @@ class FaceClusteringService {
|
|||
}
|
||||
}
|
||||
|
||||
if (closestDistance < distanceThreshold) {
|
||||
if (closestDistance < thresholdValue) {
|
||||
if (sortedFaceInfos[closestIdx].clusterId == null) {
|
||||
// Ideally this should never happen, but just in case log it
|
||||
log(
|
||||
|
@ -432,6 +461,11 @@ class FaceClusteringService {
|
|||
log(
|
||||
' [ClusterIsolate] ${DateTime.now()} Clustering for ${sortedFaceInfos.length} embeddings executed in ${stopwatchClustering.elapsedMilliseconds}ms',
|
||||
);
|
||||
if (useDynamicThreshold) {
|
||||
log(
|
||||
"[ClusterIsolate] ${DateTime.now()} Dynamic thresholding: $dynamicThresholdCount faces had a low face score or high blur value",
|
||||
);
|
||||
}
|
||||
|
||||
// analyze the results
|
||||
FaceClusteringService._analyzeClusterResults(sortedFaceInfos);
|
||||
|
|
|
@ -0,0 +1,18 @@
|
|||
|
||||
import "dart:typed_data" show Uint8List;
|
||||
|
||||
class FaceInfoForClustering {
|
||||
final String faceID;
|
||||
final int? clusterId;
|
||||
final Uint8List embeddingBytes;
|
||||
final double faceScore;
|
||||
final double blurValue;
|
||||
|
||||
FaceInfoForClustering({
|
||||
required this.faceID,
|
||||
this.clusterId,
|
||||
required this.embeddingBytes,
|
||||
required this.faceScore,
|
||||
required this.blurValue,
|
||||
});
|
||||
}
|
|
@ -12,7 +12,7 @@ class BlurDetectionService {
|
|||
|
||||
Future<(bool, double)> predictIsBlurGrayLaplacian(
|
||||
List<List<int>> grayImage, {
|
||||
int threshold = kLaplacianThreshold,
|
||||
int threshold = kLaplacianHardThreshold,
|
||||
FaceDirection faceDirection = FaceDirection.straight,
|
||||
}) async {
|
||||
final List<List<int>> laplacian =
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
import 'package:photos/services/machine_learning/face_ml/face_detection/face_detection_service.dart';
|
||||
|
||||
/// Blur detection threshold
|
||||
const kLaplacianThreshold = 15;
|
||||
const kLaplacianHardThreshold = 15;
|
||||
const kLaplacianSoftThreshold = 100;
|
||||
|
||||
/// Default blur value
|
||||
const kLapacianDefault = 10000.0;
|
||||
|
|
|
@ -504,7 +504,7 @@ class FaceResult {
|
|||
final int fileId;
|
||||
final String faceId;
|
||||
|
||||
bool get isBlurry => blurValue < kLaplacianThreshold;
|
||||
bool get isBlurry => blurValue < kLaplacianHardThreshold;
|
||||
|
||||
const FaceResult({
|
||||
required this.detection,
|
||||
|
@ -545,7 +545,7 @@ class FaceResultBuilder {
|
|||
int fileId = -1;
|
||||
String faceId = '';
|
||||
|
||||
bool get isBlurry => blurValue < kLaplacianThreshold;
|
||||
bool get isBlurry => blurValue < kLaplacianHardThreshold;
|
||||
|
||||
FaceResultBuilder({
|
||||
required this.fileId,
|
||||
|
|
|
@ -310,14 +310,14 @@ class FaceMlService {
|
|||
int bucket = 1;
|
||||
|
||||
while (true) {
|
||||
final faceIdToEmbeddingBucket =
|
||||
await FaceMLDataDB.instance.getFaceEmbeddingMap(
|
||||
final faceInfoForClustering =
|
||||
await FaceMLDataDB.instance.getFaceInfoForClustering(
|
||||
minScore: minFaceScore,
|
||||
maxFaces: bucketSize,
|
||||
offset: offset,
|
||||
batchSize: batchSize,
|
||||
);
|
||||
if (faceIdToEmbeddingBucket.isEmpty) {
|
||||
if (faceInfoForClustering.isEmpty) {
|
||||
_logger.warning(
|
||||
'faceIdToEmbeddingBucket is empty, this should ideally not happen as it should have stopped earlier. offset: $offset, totalFaces: $totalFaces',
|
||||
);
|
||||
|
@ -332,7 +332,7 @@ class FaceMlService {
|
|||
|
||||
final faceIdToCluster =
|
||||
await FaceClusteringService.instance.predictLinear(
|
||||
faceIdToEmbeddingBucket,
|
||||
faceInfoForClustering,
|
||||
fileIDToCreationTime: fileIDToCreationTime,
|
||||
offset: offset,
|
||||
);
|
||||
|
@ -343,7 +343,7 @@ class FaceMlService {
|
|||
|
||||
await FaceMLDataDB.instance.updateClusterIdToFaceId(faceIdToCluster);
|
||||
_logger.info(
|
||||
'Done with clustering ${offset + faceIdToEmbeddingBucket.length} embeddings (${(100 * (offset + faceIdToEmbeddingBucket.length) / totalFaces).toStringAsFixed(0)}%) in bucket $bucket, offset: $offset',
|
||||
'Done with clustering ${offset + faceInfoForClustering.length} embeddings (${(100 * (offset + faceInfoForClustering.length) / totalFaces).toStringAsFixed(0)}%) in bucket $bucket, offset: $offset',
|
||||
);
|
||||
if (offset + bucketSize >= totalFaces) {
|
||||
_logger.info('All faces clustered');
|
||||
|
@ -355,14 +355,14 @@ class FaceMlService {
|
|||
} else {
|
||||
// Read all the embeddings from the database, in a map from faceID to embedding
|
||||
final clusterStartTime = DateTime.now();
|
||||
final faceIdToEmbedding =
|
||||
await FaceMLDataDB.instance.getFaceEmbeddingMap(
|
||||
final faceInfoForClustering =
|
||||
await FaceMLDataDB.instance.getFaceInfoForClustering(
|
||||
minScore: minFaceScore,
|
||||
maxFaces: totalFaces,
|
||||
);
|
||||
final gotFaceEmbeddingsTime = DateTime.now();
|
||||
_logger.info(
|
||||
'read embeddings ${faceIdToEmbedding.length} in ${gotFaceEmbeddingsTime.difference(clusterStartTime).inMilliseconds} ms',
|
||||
'read embeddings ${faceInfoForClustering.length} in ${gotFaceEmbeddingsTime.difference(clusterStartTime).inMilliseconds} ms',
|
||||
);
|
||||
|
||||
// Read the creation times from Files DB, in a map from fileID to creation time
|
||||
|
@ -374,7 +374,7 @@ class FaceMlService {
|
|||
// Cluster the embeddings using the linear clustering algorithm, returning a map from faceID to clusterID
|
||||
final faceIdToCluster =
|
||||
await FaceClusteringService.instance.predictLinear(
|
||||
faceIdToEmbedding,
|
||||
faceInfoForClustering,
|
||||
fileIDToCreationTime: fileIDToCreationTime,
|
||||
);
|
||||
if (faceIdToCluster == null) {
|
||||
|
@ -383,7 +383,7 @@ class FaceMlService {
|
|||
}
|
||||
final clusterDoneTime = DateTime.now();
|
||||
_logger.info(
|
||||
'done with clustering ${faceIdToEmbedding.length} in ${clusterDoneTime.difference(clusterStartTime).inSeconds} seconds ',
|
||||
'done with clustering ${faceInfoForClustering.length} in ${clusterDoneTime.difference(clusterStartTime).inSeconds} seconds ',
|
||||
);
|
||||
|
||||
// Store the updated clusterIDs in the database
|
||||
|
|
|
@ -13,6 +13,8 @@ import "package:photos/generated/protos/ente/common/vector.pb.dart";
|
|||
import "package:photos/models/file/file.dart";
|
||||
import 'package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart';
|
||||
import "package:photos/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart";
|
||||
import "package:photos/services/machine_learning/face_ml/face_clustering/face_info_for_clustering.dart";
|
||||
import "package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart";
|
||||
import "package:photos/services/machine_learning/face_ml/face_ml_result.dart";
|
||||
import "package:photos/services/machine_learning/face_ml/person/person_service.dart";
|
||||
import "package:photos/services/search_service.dart";
|
||||
|
@ -233,14 +235,23 @@ class ClusterFeedbackService {
|
|||
}
|
||||
} else {
|
||||
final clusteringInput = embeddings.map((key, value) {
|
||||
return MapEntry(key, (null, value));
|
||||
});
|
||||
return MapEntry(
|
||||
key,
|
||||
FaceInfoForClustering(
|
||||
faceID: key,
|
||||
embeddingBytes: value,
|
||||
faceScore: kMinHighQualityFaceScore + 0.01,
|
||||
blurValue: kLapacianDefault,
|
||||
),
|
||||
);
|
||||
}).values.toSet();
|
||||
|
||||
final faceIdToCluster =
|
||||
await FaceClusteringService.instance.predictLinear(
|
||||
clusteringInput,
|
||||
fileIDToCreationTime: fileIDToCreationTime,
|
||||
distanceThreshold: 0.23,
|
||||
useDynamicThreshold: false,
|
||||
);
|
||||
|
||||
if (faceIdToCluster == null) {
|
||||
|
|
Loading…
Reference in a new issue