Merge branch 'mobile_face' of https://github.com/ente-io/auth into mobile_face

This commit is contained in:
Neeraj Gupta 2024-04-05 16:00:09 +05:30
commit 1b9c81c50c
12 changed files with 503 additions and 111 deletions

View file

@ -481,6 +481,16 @@ class FaceMLDataDB {
return maps.first['count'] as int;
}
Future<int> getBlurryFaceCount([
int blurThreshold = kLaplacianThreshold,
]) async {
final db = await instance.database;
final List<Map<String, dynamic>> maps = await db.rawQuery(
'SELECT COUNT(*) as count FROM $facesTable WHERE $faceBlur <= $blurThreshold AND $faceScore > $kMinHighQualityFaceScore',
);
return maps.first['count'] as int;
}
Future<void> resetClusterIDs() async {
final db = await instance.database;
await db.execute(dropFaceClustersTable);
@ -726,7 +736,7 @@ class FaceMLDataDB {
for (final enteFile in files) {
fileIds.add(enteFile.uploadedFileID.toString());
}
int maxClusterID = DateTime.now().millisecondsSinceEpoch;
int maxClusterID = DateTime.now().microsecondsSinceEpoch;
final Map<String, int> faceIDToClusterID = {};
for (final row in faceIdsResult) {
final faceID = row[fcFaceId] as String;
@ -752,7 +762,7 @@ class FaceMLDataDB {
for (final enteFile in files) {
fileIds.add(enteFile.uploadedFileID.toString());
}
int maxClusterID = DateTime.now().millisecondsSinceEpoch;
int maxClusterID = DateTime.now().microsecondsSinceEpoch;
final Map<String, int> faceIDToClusterID = {};
for (final row in faceIdsResult) {
final faceID = row[fcFaceId] as String;
@ -763,4 +773,16 @@ class FaceMLDataDB {
}
await forceUpdateClusterIds(faceIDToClusterID);
}
Future<void> addFacesToCluster(
List<String> faceIDs,
int clusterID,
) async {
final faceIDToClusterID = <String, int>{};
for (final faceID in faceIDs) {
faceIDToClusterID[faceID] = clusterID;
}
await forceUpdateClusterIds(faceIDToClusterID);
}
}

View file

@ -3,6 +3,9 @@ import "package:photos/face/model/landmark.dart";
/// Stores the face detection data, notably the bounding box and landmarks.
///
/// - Bounding box: [FaceBox] with xMin, yMin (so top left corner), width, height
/// - Landmarks: list of [Landmark]s, namely leftEye, rightEye, nose, leftMouth, rightMouth
///
/// WARNING: All coordinates are relative to the image size, so in the range [0, 1]!
class Detection {
FaceBox box;
@ -39,4 +42,43 @@ class Detection {
),
);
}
// TODO: iterate on better area calculation, potentially using actual indexing image dimensions instead of file metadata
int getFaceArea(int imageWidth, int imageHeight) {
return (box.width * imageWidth * box.height * imageHeight).toInt();
}
// TODO: iterate on better scoring logic, current is a placeholder
int getVisibilityScore() {
final double aspectRatio = box.width / box.height;
final double eyeDistance = (landmarks[1].x - landmarks[0].x).abs();
final double mouthDistance = (landmarks[4].x - landmarks[3].x).abs();
final double noseEyeDistance =
(landmarks[2].y - ((landmarks[0].y + landmarks[1].y) / 2)).abs();
final double normalizedEyeDistance = eyeDistance / box.width;
final double normalizedMouthDistance = mouthDistance / box.width;
final double normalizedNoseEyeDistance = noseEyeDistance / box.height;
const double aspectRatioThreshold = 0.8;
const double eyeDistanceThreshold = 0.2;
const double mouthDistanceThreshold = 0.3;
const double noseEyeDistanceThreshold = 0.1;
double score = 0;
if (aspectRatio >= aspectRatioThreshold) {
score += 50;
}
if (normalizedEyeDistance >= eyeDistanceThreshold) {
score += 20;
}
if (normalizedMouthDistance >= mouthDistanceThreshold) {
score += 20;
}
if (normalizedNoseEyeDistance >= noseEyeDistanceThreshold) {
score += 10;
}
return score.clamp(0, 100).toInt();
}
}

View file

@ -353,14 +353,8 @@ class FaceClustering {
// Make sure the first face has a clusterId
final int totalFaces = sortedFaceInfos.length;
// set current epoch time as clusterID
int clusterID = DateTime.now().millisecondsSinceEpoch;
if (sortedFaceInfos.isNotEmpty) {
if (sortedFaceInfos.first.clusterId == null) {
sortedFaceInfos.first.clusterId = clusterID;
} else {
clusterID = sortedFaceInfos.first.clusterId!;
}
} else {
int clusterID = DateTime.now().microsecondsSinceEpoch;
if (sortedFaceInfos.isEmpty) {
return {};
}
@ -401,6 +395,12 @@ class FaceClustering {
if (distance < closestDistance) {
closestDistance = distance;
closestIdx = j;
// if (distance < distanceThreshold) {
// if (sortedFaceInfos[j].faceID.startsWith("14914702") ||
// sortedFaceInfos[j].faceID.startsWith("15488756")) {
// log('[XXX] faceIDs: ${sortedFaceInfos[j].faceID} and ${sortedFaceInfos[i].faceID} with distance $distance');
// }
// }
}
}
@ -414,10 +414,22 @@ class FaceClustering {
sortedFaceInfos[closestIdx].clusterId = clusterID;
newFaceIdToCluster[sortedFaceInfos[closestIdx].faceID] = clusterID;
}
// if (sortedFaceInfos[i].faceID.startsWith("14914702") ||
// sortedFaceInfos[i].faceID.startsWith("15488756")) {
// log(
// "[XXX] [ClusterIsolate] ${DateTime.now()} Found similar face ${sortedFaceInfos[i].faceID} to ${sortedFaceInfos[closestIdx].faceID} with distance $closestDistance",
// );
// }
sortedFaceInfos[i].clusterId = sortedFaceInfos[closestIdx].clusterId;
newFaceIdToCluster[sortedFaceInfos[i].faceID] =
sortedFaceInfos[closestIdx].clusterId!;
} else {
// if (sortedFaceInfos[i].faceID.startsWith("14914702") ||
// sortedFaceInfos[i].faceID.startsWith("15488756")) {
// log(
// "[XXX] [ClusterIsolate] ${DateTime.now()} Found new cluster $clusterID for face ${sortedFaceInfos[i].faceID}",
// );
// }
clusterID++;
sortedFaceInfos[i].clusterId = clusterID;
newFaceIdToCluster[sortedFaceInfos[i].faceID] = clusterID;

View file

@ -654,7 +654,7 @@ class FaceMlService {
.map(
(keypoint) => Landmark(
x: keypoint[0],
y: keypoint[0],
y: keypoint[1],
),
)
.toList(),

View file

@ -325,12 +325,25 @@ class ClusterFeedbackService {
}
}
Future<void> removeFilesFromPerson(List<EnteFile> files, Person p) {
return FaceMLDataDB.instance.removeFilesFromPerson(files, p);
Future<void> removeFilesFromPerson(List<EnteFile> files, Person p) async {
await FaceMLDataDB.instance.removeFilesFromPerson(files, p);
Bus.instance.fire(PeopleChangedEvent());
return;
}
Future<void> removeFilesFromCluster(List<EnteFile> files, int clusterID) {
return FaceMLDataDB.instance.removeFilesFromCluster(files, clusterID);
Future<void> removeFilesFromCluster(
List<EnteFile> files,
int clusterID,
) async {
await FaceMLDataDB.instance.removeFilesFromCluster(files, clusterID);
Bus.instance.fire(PeopleChangedEvent());
return;
}
Future<void> addFilesToCluster(List<String> faceIDs, int clusterID) async {
await FaceMLDataDB.instance.addFacesToCluster(faceIDs, clusterID);
Bus.instance.fire(PeopleChangedEvent());
return;
}
Future<bool> checkAndDoAutomaticMerges(Person p) async {
@ -413,7 +426,7 @@ class ClusterFeedbackService {
embeddings,
fileIDToCreationTime: fileIDToCreationTime,
eps: 0.30,
minPts: 5,
minPts: 8,
);
if (dbscanClusters.isEmpty) {

View file

@ -114,7 +114,9 @@ class _FaceDebugSectionWidgetState extends State<FaceDebugSectionWidget> {
.getTotalFaceCount(minFaceScore: 0.75);
final faces78 = await FaceMLDataDB.instance
.getTotalFaceCount(minFaceScore: kMinHighQualityFaceScore);
showShortToast(context, "Faces75: $faces75, Faces78: $faces78");
final blurryFaceCount =
await FaceMLDataDB.instance.getBlurryFaceCount(15);
showShortToast(context, "$blurryFaceCount blurry faces");
},
),
// MenuItemWidget(

View file

@ -2,12 +2,14 @@ import "dart:developer" show log;
import "dart:io" show Platform;
import "dart:typed_data";
import "package:flutter/cupertino.dart";
import "package:flutter/foundation.dart" show kDebugMode;
import "package:flutter/material.dart";
import "package:photos/face/db.dart";
import "package:photos/face/model/face.dart";
import "package:photos/face/model/person.dart";
import 'package:photos/models/file/file.dart';
import "package:photos/services/machine_learning/face_ml/feedback/cluster_feedback.dart";
import "package:photos/services/search_service.dart";
import "package:photos/theme/ente_theme.dart";
import "package:photos/ui/viewer/file/no_thumbnail_widget.dart";
@ -16,13 +18,15 @@ import "package:photos/ui/viewer/people/cropped_face_image_view.dart";
import "package:photos/ui/viewer/people/people_page.dart";
import "package:photos/utils/face/face_box_crop.dart";
import "package:photos/utils/thumbnail_util.dart";
// import "package:photos/utils/toast_util.dart";
class FaceWidget extends StatelessWidget {
class FaceWidget extends StatefulWidget {
final EnteFile file;
final Face face;
final Person? person;
final int? clusterID;
final bool highlight;
final bool editMode;
const FaceWidget(
this.file,
@ -30,9 +34,17 @@ class FaceWidget extends StatelessWidget {
this.person,
this.clusterID,
this.highlight = false,
this.editMode = false,
Key? key,
}) : super(key: key);
@override
State<FaceWidget> createState() => _FaceWidgetState();
}
class _FaceWidgetState extends State<FaceWidget> {
bool isJustRemoved = false;
@override
Widget build(BuildContext context) {
if (Platform.isIOS || Platform.isAndroid) {
@ -43,22 +55,24 @@ class FaceWidget extends StatelessWidget {
final ImageProvider imageProvider = MemoryImage(snapshot.data!);
return GestureDetector(
onTap: () async {
if (widget.editMode) return;
log(
"FaceWidget is tapped, with person $person and clusterID $clusterID",
"FaceWidget is tapped, with person ${widget.person} and clusterID ${widget.clusterID}",
name: "FaceWidget",
);
if (person == null && clusterID == null) {
if (widget.person == null && widget.clusterID == null) {
return;
}
if (person != null) {
if (widget.person != null) {
await Navigator.of(context).push(
MaterialPageRoute(
builder: (context) => PeoplePage(
person: person!,
person: widget.person!,
),
),
);
} else if (clusterID != null) {
} else if (widget.clusterID != null) {
final fileIdsToClusterIds =
await FaceMLDataDB.instance.getFileIdToClusterIds();
final files = await SearchService.instance.getAllFiles();
@ -66,7 +80,7 @@ class FaceWidget extends StatelessWidget {
.where(
(file) =>
fileIdsToClusterIds[file.uploadedFileID]
?.contains(clusterID) ??
?.contains(widget.clusterID) ??
false,
)
.toList();
@ -74,7 +88,7 @@ class FaceWidget extends StatelessWidget {
MaterialPageRoute(
builder: (context) => ClusterPage(
clusterFiles,
clusterID: clusterID!,
clusterID: widget.clusterID!,
),
),
);
@ -82,46 +96,87 @@ class FaceWidget extends StatelessWidget {
},
child: Column(
children: [
// TODO: the edges of the green line are still not properly rounded around ClipRRect
Container(
height: 60,
width: 60,
decoration: ShapeDecoration(
shape: RoundedRectangleBorder(
borderRadius:
const BorderRadius.all(Radius.elliptical(16, 12)),
side: highlight
? BorderSide(
color: getEnteColorScheme(context).primary700,
width: 2.0,
)
: BorderSide.none,
),
),
child: ClipRRect(
borderRadius:
const BorderRadius.all(Radius.elliptical(16, 12)),
child: SizedBox(
width: 60,
Stack(
children: [
Container(
height: 60,
child: Image(
image: imageProvider,
fit: BoxFit.cover,
width: 60,
decoration: ShapeDecoration(
shape: RoundedRectangleBorder(
borderRadius: const BorderRadius.all(
Radius.elliptical(16, 12),
),
side: widget.highlight
? BorderSide(
color:
getEnteColorScheme(context).primary700,
width: 1.0,
)
: BorderSide.none,
),
),
child: ClipRRect(
borderRadius:
const BorderRadius.all(Radius.elliptical(16, 12)),
child: SizedBox(
width: 60,
height: 60,
child: Image(
image: imageProvider,
fit: BoxFit.cover,
),
),
),
),
),
// TODO: the edges of the green line are still not properly rounded around ClipRRect
if (widget.editMode)
Positioned(
right: 0,
top: 0,
child: GestureDetector(
onTap: _cornerIconPressed,
child: isJustRemoved
? const Icon(
CupertinoIcons.add_circled_solid,
color: Colors.green,
)
: const Icon(
Icons.cancel,
color: Colors.red,
),
),
),
],
),
const SizedBox(height: 8),
if (person != null)
if (widget.person != null)
Text(
person!.attr.name.trim(),
widget.person!.attr.name.trim(),
style: Theme.of(context).textTheme.bodySmall,
overflow: TextOverflow.ellipsis,
maxLines: 1,
),
if (kDebugMode)
Text(
'S: ${face.score.toStringAsFixed(3)}',
'S: ${widget.face.score.toStringAsFixed(3)}',
style: Theme.of(context).textTheme.bodySmall,
maxLines: 1,
),
if (kDebugMode)
Text(
'B: ${widget.face.blur.toStringAsFixed(3)}',
style: Theme.of(context).textTheme.bodySmall,
maxLines: 1,
),
if (kDebugMode)
Text(
'V: ${widget.face.detection.getVisibilityScore()}',
style: Theme.of(context).textTheme.bodySmall,
maxLines: 1,
),
if (kDebugMode)
Text(
'A: ${widget.face.detection.getFaceArea(widget.file.width, widget.file.height)}',
style: Theme.of(context).textTheme.bodySmall,
maxLines: 1,
),
@ -168,21 +223,21 @@ class FaceWidget extends StatelessWidget {
return GestureDetector(
onTap: () async {
log(
"FaceWidget is tapped, with person $person and clusterID $clusterID",
"FaceWidget is tapped, with person ${widget.person} and clusterID ${widget.clusterID}",
name: "FaceWidget",
);
if (person == null && clusterID == null) {
if (widget.person == null && widget.clusterID == null) {
return;
}
if (person != null) {
if (widget.person != null) {
await Navigator.of(context).push(
MaterialPageRoute(
builder: (context) => PeoplePage(
person: person!,
person: widget.person!,
),
),
);
} else if (clusterID != null) {
} else if (widget.clusterID != null) {
final fileIdsToClusterIds =
await FaceMLDataDB.instance.getFileIdToClusterIds();
final files = await SearchService.instance.getAllFiles();
@ -190,7 +245,7 @@ class FaceWidget extends StatelessWidget {
.where(
(file) =>
fileIdsToClusterIds[file.uploadedFileID]
?.contains(clusterID) ??
?.contains(widget.clusterID) ??
false,
)
.toList();
@ -198,7 +253,7 @@ class FaceWidget extends StatelessWidget {
MaterialPageRoute(
builder: (context) => ClusterPage(
clusterFiles,
clusterID: clusterID!,
clusterID: widget.clusterID!,
),
),
);
@ -213,7 +268,7 @@ class FaceWidget extends StatelessWidget {
shape: RoundedRectangleBorder(
borderRadius:
const BorderRadius.all(Radius.elliptical(16, 12)),
side: highlight
side: widget.highlight
? BorderSide(
color: getEnteColorScheme(context).primary700,
width: 2.0,
@ -228,23 +283,23 @@ class FaceWidget extends StatelessWidget {
width: 60,
height: 60,
child: CroppedFaceImageView(
enteFile: file,
face: face,
enteFile: widget.file,
face: widget.face,
),
),
),
),
const SizedBox(height: 8),
if (person != null)
if (widget.person != null)
Text(
person!.attr.name.trim(),
widget.person!.attr.name.trim(),
style: Theme.of(context).textTheme.bodySmall,
overflow: TextOverflow.ellipsis,
maxLines: 1,
),
if (kDebugMode)
Text(
'S: ${face.score.toStringAsFixed(3)}',
'S: ${widget.face.score.toStringAsFixed(3)}',
style: Theme.of(context).textTheme.bodySmall,
maxLines: 1,
),
@ -256,36 +311,55 @@ class FaceWidget extends StatelessWidget {
}
}
void _cornerIconPressed() async {
log('face widget (file info) corner icon is pressed');
try {
if (isJustRemoved) {
await ClusterFeedbackService.instance
.addFilesToCluster([widget.face.faceID], widget.clusterID!);
} else {
await ClusterFeedbackService.instance
.removeFilesFromCluster([widget.file], widget.clusterID!);
}
setState(() {
isJustRemoved = !isJustRemoved;
});
} catch (e, s) {
log("removing face/file from cluster from file info widget failed: $e, \n $s");
}
}
Future<Uint8List?> getFaceCrop() async {
try {
final Uint8List? cachedFace = faceCropCache.get(face.faceID);
final Uint8List? cachedFace = faceCropCache.get(widget.face.faceID);
if (cachedFace != null) {
return cachedFace;
}
final faceCropCacheFile = cachedFaceCropPath(face.faceID);
final faceCropCacheFile = cachedFaceCropPath(widget.face.faceID);
if ((await faceCropCacheFile.exists())) {
final data = await faceCropCacheFile.readAsBytes();
faceCropCache.put(face.faceID, data);
faceCropCache.put(widget.face.faceID, data);
return data;
}
final result = await pool.withResource(
() async => await getFaceCrops(
file,
widget.file,
{
face.faceID: face.detection.box,
widget.face.faceID: widget.face.detection.box,
},
),
);
final Uint8List? computedCrop = result?[face.faceID];
final Uint8List? computedCrop = result?[widget.face.faceID];
if (computedCrop != null) {
faceCropCache.put(face.faceID, computedCrop);
faceCropCache.put(widget.face.faceID, computedCrop);
faceCropCacheFile.writeAsBytes(computedCrop).ignore();
}
return computedCrop;
} catch (e, s) {
log(
"Error getting face for faceID: ${face.faceID}",
"Error getting face for faceID: ${widget.face.faceID}",
error: e,
stackTrace: s,
);

View file

@ -1,3 +1,4 @@
import "package:flutter/foundation.dart" show kDebugMode;
import "package:flutter/material.dart";
import "package:logging/logging.dart";
import "package:photos/face/db.dart";
@ -9,23 +10,44 @@ import "package:photos/ui/components/buttons/chip_button_widget.dart";
import "package:photos/ui/components/info_item_widget.dart";
import "package:photos/ui/viewer/file_details/face_widget.dart";
class FacesItemWidget extends StatelessWidget {
class FacesItemWidget extends StatefulWidget {
final EnteFile file;
const FacesItemWidget(this.file, {super.key});
@override
State<FacesItemWidget> createState() => _FacesItemWidgetState();
}
class _FacesItemWidgetState extends State<FacesItemWidget> {
bool editMode = false;
@override
void initState() {
super.initState();
setState(() {});
}
@override
Widget build(BuildContext context) {
return InfoItemWidget(
key: const ValueKey("Faces"),
leadingIcon: Icons.face_retouching_natural_outlined,
subtitleSection: _faceWidgets(context, file),
subtitleSection: _faceWidgets(context, widget.file, editMode),
hasChipButtons: true,
editOnTap: _toggleEditMode,
);
}
void _toggleEditMode() {
setState(() {
editMode = !editMode;
});
}
Future<List<Widget>> _faceWidgets(
BuildContext context,
EnteFile file,
bool editMode,
) async {
try {
if (file.uploadedFileID == null) {
@ -47,8 +69,13 @@ class FacesItemWidget extends StatelessWidget {
),
];
}
if (faces.isEmpty ||
faces.every((face) => face.score < 0.75 || face.isBlurry)) {
// Remove faces with low scores and blurry faces
if (!kDebugMode) {
faces.removeWhere((face) => (face.isBlurry || face.score < 0.75));
}
if (faces.isEmpty) {
return [
const ChipButtonWidget(
"No faces found",
@ -60,9 +87,6 @@ class FacesItemWidget extends StatelessWidget {
// Sort the faces by score in descending order, so that the highest scoring face is first.
faces.sort((Face a, Face b) => b.score.compareTo(a.score));
// Remove faces with low scores and blurry faces
faces.removeWhere((face) => (face.isBlurry || face.score < 0.75));
// TODO: add deduplication of faces of same person
final faceIdsToClusterIds = await FaceMLDataDB.instance
.getFaceIdsToClusterIds(faces.map((face) => face.faceID));
@ -84,6 +108,7 @@ class FacesItemWidget extends StatelessWidget {
clusterID: clusterID,
person: person,
highlight: highlight,
editMode: highlight ? editMode : false,
),
);
}

View file

@ -3,21 +3,22 @@ import 'dart:async';
import "package:flutter/foundation.dart";
import 'package:flutter/material.dart';
import 'package:logging/logging.dart';
import "package:ml_linalg/linalg.dart";
import 'package:photos/core/configuration.dart';
import 'package:photos/core/event_bus.dart';
import "package:photos/db/files_db.dart";
// import "package:photos/events/people_changed_event.dart";
import 'package:photos/events/subscription_purchased_event.dart';
// import "package:photos/face/db.dart";
import "package:photos/face/db.dart";
import "package:photos/face/model/person.dart";
import "package:photos/generated/protos/ente/common/vector.pb.dart";
import 'package:photos/models/gallery_type.dart';
import 'package:photos/models/selected_files.dart';
import 'package:photos/services/collections_service.dart';
import "package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart";
import "package:photos/services/machine_learning/face_ml/face_ml_result.dart";
import "package:photos/services/machine_learning/face_ml/feedback/cluster_feedback.dart";
import 'package:photos/ui/actions/collection/collection_sharing_actions.dart';
import "package:photos/ui/viewer/people/cluster_page.dart";
// import "package:photos/utils/dialog_util.dart";
import "package:photos/ui/viewer/people/cluster_breakup_page.dart";
class ClusterAppBar extends StatefulWidget {
final GalleryType type;
@ -42,6 +43,7 @@ class ClusterAppBar extends StatefulWidget {
enum ClusterPopupAction {
setCover,
breakupCluster,
validateCluster,
hide,
}
@ -130,6 +132,18 @@ class _AppBarWidgetState extends State<ClusterAppBar> {
],
),
),
const PopupMenuItem(
value: ClusterPopupAction.validateCluster,
child: Row(
children: [
Icon(Icons.search_off_outlined),
Padding(
padding: EdgeInsets.all(8),
),
Text('Validate cluster'),
],
),
),
// PopupMenuItem(
// value: ClusterPopupAction.hide,
// child: Row(
@ -155,6 +169,8 @@ class _AppBarWidgetState extends State<ClusterAppBar> {
if (value == ClusterPopupAction.breakupCluster) {
// ignore: unawaited_futures
await _breakUpCluster(context);
} else if (value == ClusterPopupAction.validateCluster) {
await _validateCluster(context);
}
// else if (value == ClusterPopupAction.setCover) {
// await setCoverPhoto(context);
@ -169,28 +185,84 @@ class _AppBarWidgetState extends State<ClusterAppBar> {
return actions;
}
Future<void> _validateCluster(BuildContext context) async {
_logger.info('_validateCluster called');
final faceMlDb = FaceMLDataDB.instance;
final faceIDs = await faceMlDb.getFaceIDsForCluster(widget.clusterID);
final fileIDs = faceIDs.map((e) => getFileIdFromFaceId(e)).toList();
final embeddingsBlobs = await faceMlDb.getFaceEmbeddingMapForFile(fileIDs);
embeddingsBlobs.removeWhere((key, value) => !faceIDs.contains(key));
final embeddings = embeddingsBlobs
.map((key, value) => MapEntry(key, EVector.fromBuffer(value).values));
for (final MapEntry<String, List<double>> embedding in embeddings.entries) {
double closestDistance = double.infinity;
double closestDistance32 = double.infinity;
double closestDistance64 = double.infinity;
String? closestFaceID;
for (final MapEntry<String, List<double>> otherEmbedding
in embeddings.entries) {
if (embedding.key == otherEmbedding.key) {
continue;
}
final distance64 = 1.0 -
Vector.fromList(embedding.value, dtype: DType.float64).dot(
Vector.fromList(otherEmbedding.value, dtype: DType.float64),
);
final distance32 = 1.0 -
Vector.fromList(embedding.value, dtype: DType.float32).dot(
Vector.fromList(otherEmbedding.value, dtype: DType.float32),
);
final distance = cosineDistForNormVectors(
embedding.value,
otherEmbedding.value,
);
if (distance < closestDistance) {
closestDistance = distance;
closestDistance32 = distance32;
closestDistance64 = distance64;
closestFaceID = otherEmbedding.key;
}
}
if (closestDistance > 0.3) {
_logger.severe(
"Face ${embedding.key} is similar to $closestFaceID with distance $closestDistance, and float32 distance $closestDistance32, and float64 distance $closestDistance64",
);
}
}
}
Future<void> _breakUpCluster(BuildContext context) async {
final newClusterIDToFaceIDs =
await ClusterFeedbackService.instance.breakUpCluster(widget.clusterID);
for (final cluster in newClusterIDToFaceIDs.entries) {
// ignore: unawaited_futures
final newClusterID = cluster.key;
final faceIDs = cluster.value;
final files = await FilesDB.instance
.getFilesFromIDs(faceIDs.map((e) => getFileIdFromFaceId(e)).toList());
unawaited(
Navigator.of(context).push(
MaterialPageRoute(
builder: (context) => ClusterPage(
files.values.toList(),
appendTitle:
(newClusterID == -1) ? "(Analysis noise)" : "(Analysis)",
clusterID: newClusterID,
),
),
final allFileIDs = newClusterIDToFaceIDs.values
.expand((e) => e)
.map((e) => getFileIdFromFaceId(e))
.toList();
final fileIDtoFile = await FilesDB.instance.getFilesFromIDs(
allFileIDs,
);
final newClusterIDToFiles = newClusterIDToFaceIDs.map(
(key, value) => MapEntry(
key,
value
.map((faceId) => fileIDtoFile[getFileIdFromFaceId(faceId)]!)
.toList(),
),
);
await Navigator.of(context).push(
MaterialPageRoute(
builder: (context) => ClusterBreakupPage(
newClusterIDToFiles,
"(Analysis)",
),
);
}
),
);
}
}

View file

@ -0,0 +1,124 @@
import "package:flutter/material.dart";
import "package:photos/models/file/file.dart";
import "package:photos/theme/ente_theme.dart";
import "package:photos/ui/viewer/file/no_thumbnail_widget.dart";
import "package:photos/ui/viewer/people/cluster_page.dart";
import "package:photos/ui/viewer/search/result/person_face_widget.dart";
class ClusterBreakupPage extends StatefulWidget {
final Map<int, List<EnteFile>> newClusterIDsToFiles;
final String title;
const ClusterBreakupPage(
this.newClusterIDsToFiles,
this.title, {
super.key,
});
@override
State<ClusterBreakupPage> createState() => _ClusterBreakupPageState();
}
class _ClusterBreakupPageState extends State<ClusterBreakupPage> {
@override
Widget build(BuildContext context) {
final keys = widget.newClusterIDsToFiles.keys.toList();
final clusterIDsToFiles = widget.newClusterIDsToFiles;
return Scaffold(
appBar: AppBar(
title: Text(widget.title),
),
body: ListView.builder(
itemCount: widget.newClusterIDsToFiles.keys.length,
itemBuilder: (context, index) {
final int clusterID = keys[index];
final List<EnteFile> files = clusterIDsToFiles[keys[index]]!;
return InkWell(
onTap: () {
Navigator.of(context).push(
MaterialPageRoute(
builder: (context) => ClusterPage(
files,
clusterID: index,
appendTitle: "(Analysis)",
),
),
);
},
child: Container(
padding: const EdgeInsets.all(8.0),
child: Row(
children: <Widget>[
SizedBox(
width: 64,
height: 64,
child: files.isNotEmpty
? ClipRRect(
borderRadius: const BorderRadius.all(
Radius.elliptical(16, 12),),
child: PersonFaceWidget(
files.first,
clusterID: clusterID,
),
)
: const ClipRRect(
borderRadius:
BorderRadius.all(Radius.elliptical(16, 12)),
child: NoThumbnailWidget(
addBorder: false,
),
),
),
const SizedBox(
width: 8.0,
), // Add some spacing between the thumbnail and the text
Expanded(
child: Padding(
padding: const EdgeInsets.symmetric(horizontal: 8.0),
child: Row(
mainAxisAlignment: MainAxisAlignment.spaceBetween,
children: <Widget>[
Text(
"${clusterIDsToFiles[keys[index]]!.length} photos",
style: getEnteTextTheme(context).body,
),
// GestureDetector(
// onTap: () async {
// try {
// final int result = await FaceMLDataDB
// .instance
// .removeClusterToPerson(
// personID: widget.person.remoteID,
// clusterID: clusterID,
// );
// _logger.info(
// "Removed cluster $clusterID from person ${widget.person.remoteID}, result: $result",
// );
// Bus.instance.fire(PeopleChangedEvent());
// setState(() {});
// } catch (e) {
// _logger.severe(
// "removing cluster from person,",
// e,
// );
// }
// },
// child: const Icon(
// CupertinoIcons.minus_circled,
// color: Colors.red,
// ),
// ),
],
),
),
),
],
),
),
);
},
),
);
}
}

View file

@ -14,8 +14,8 @@ import 'package:photos/models/gallery_type.dart';
import 'package:photos/models/selected_files.dart';
import 'package:photos/services/collections_service.dart';
import 'package:photos/ui/actions/collection/collection_sharing_actions.dart';
import "package:photos/ui/viewer/people/person_cluserts.dart";
import "package:photos/ui/viewer/people/person_cluster_suggestion.dart";
import 'package:photos/ui/viewer/people/person_clusters_page.dart';
import "package:photos/utils/dialog_util.dart";
class PeopleAppBar extends StatefulWidget {
@ -215,7 +215,7 @@ class _AppBarWidgetState extends State<PeopleAppBar> {
unawaited(
Navigator.of(context).push(
MaterialPageRoute(
builder: (context) => PersonClusters(widget.person),
builder: (context) => PersonClustersPage(widget.person),
),
),
);

View file

@ -13,19 +13,19 @@ import "package:photos/ui/viewer/file/no_thumbnail_widget.dart";
import "package:photos/ui/viewer/people/cluster_page.dart";
import "package:photos/ui/viewer/search/result/person_face_widget.dart";
class PersonClusters extends StatefulWidget {
class PersonClustersPage extends StatefulWidget {
final Person person;
const PersonClusters(
const PersonClustersPage(
this.person, {
super.key,
});
@override
State<PersonClusters> createState() => _PersonClustersState();
State<PersonClustersPage> createState() => _PersonClustersPageState();
}
class _PersonClustersState extends State<PersonClusters> {
class _PersonClustersPageState extends State<PersonClustersPage> {
final Logger _logger = Logger("_PersonClustersState");
@override
Widget build(BuildContext context) {
@ -64,13 +64,19 @@ class _PersonClustersState extends State<PersonClusters> {
width: 64,
height: 64,
child: files.isNotEmpty
? ClipOval(
? ClipRRect(
borderRadius: const BorderRadius.all(
Radius.elliptical(16, 12),
),
child: PersonFaceWidget(
files.first,
clusterID: clusterID,
),
)
: const ClipOval(
: const ClipRRect(
borderRadius: BorderRadius.all(
Radius.elliptical(16, 12),
),
child: NoThumbnailWidget(
addBorder: false,
),