[mob] Option to add/remove face to cluster from file info

This commit is contained in:
laurenspriem 2024-04-04 18:47:09 +05:30
parent 19007c38b5
commit f1fd74b119
4 changed files with 161 additions and 62 deletions

View file

@ -763,4 +763,16 @@ class FaceMLDataDB {
}
await forceUpdateClusterIds(faceIDToClusterID);
}
Future<void> addFacesToCluster(
List<String> faceIDs,
int clusterID,
) async {
final faceIDToClusterID = <String, int>{};
for (final faceID in faceIDs) {
faceIDToClusterID[faceID] = clusterID;
}
await forceUpdateClusterIds(faceIDToClusterID);
}
}

View file

@ -340,6 +340,12 @@ class ClusterFeedbackService {
return;
}
Future<void> addFilesToCluster(List<String> faceIDs, int clusterID) async {
await FaceMLDataDB.instance.addFacesToCluster(faceIDs, clusterID);
Bus.instance.fire(PeopleChangedEvent());
return;
}
Future<bool> checkAndDoAutomaticMerges(Person p) async {
final faceMlDb = FaceMLDataDB.instance;
final allClusterIdsToCountMap = (await faceMlDb.clusterIdToFaceCount());

View file

@ -2,12 +2,14 @@ import "dart:developer" show log;
import "dart:io" show Platform;
import "dart:typed_data";
import "package:flutter/cupertino.dart";
import "package:flutter/foundation.dart" show kDebugMode;
import "package:flutter/material.dart";
import "package:photos/face/db.dart";
import "package:photos/face/model/face.dart";
import "package:photos/face/model/person.dart";
import 'package:photos/models/file/file.dart';
import "package:photos/services/machine_learning/face_ml/feedback/cluster_feedback.dart";
import "package:photos/services/search_service.dart";
import "package:photos/theme/ente_theme.dart";
import "package:photos/ui/viewer/file/no_thumbnail_widget.dart";
@ -16,13 +18,15 @@ import "package:photos/ui/viewer/people/cropped_face_image_view.dart";
import "package:photos/ui/viewer/people/people_page.dart";
import "package:photos/utils/face/face_box_crop.dart";
import "package:photos/utils/thumbnail_util.dart";
// import "package:photos/utils/toast_util.dart";
class FaceWidget extends StatelessWidget {
class FaceWidget extends StatefulWidget {
final EnteFile file;
final Face face;
final Person? person;
final int? clusterID;
final bool highlight;
final bool editMode;
const FaceWidget(
this.file,
@ -30,9 +34,17 @@ class FaceWidget extends StatelessWidget {
this.person,
this.clusterID,
this.highlight = false,
this.editMode = false,
Key? key,
}) : super(key: key);
@override
State<FaceWidget> createState() => _FaceWidgetState();
}
class _FaceWidgetState extends State<FaceWidget> {
bool isJustRemoved = false;
@override
Widget build(BuildContext context) {
if (Platform.isIOS || Platform.isAndroid) {
@ -43,22 +55,27 @@ class FaceWidget extends StatelessWidget {
final ImageProvider imageProvider = MemoryImage(snapshot.data!);
return GestureDetector(
onTap: () async {
log(
"FaceWidget is tapped, with person $person and clusterID $clusterID",
name: "FaceWidget",
);
if (person == null && clusterID == null) {
if (widget.editMode) {
_cornerIconPressed();
return;
}
if (person != null) {
log(
"FaceWidget is tapped, with person ${widget.person} and clusterID ${widget.clusterID}",
name: "FaceWidget",
);
if (widget.person == null && widget.clusterID == null) {
return;
}
if (widget.person != null) {
await Navigator.of(context).push(
MaterialPageRoute(
builder: (context) => PeoplePage(
person: person!,
person: widget.person!,
),
),
);
} else if (clusterID != null) {
} else if (widget.clusterID != null) {
final fileIdsToClusterIds =
await FaceMLDataDB.instance.getFileIdToClusterIds();
final files = await SearchService.instance.getAllFiles();
@ -66,7 +83,7 @@ class FaceWidget extends StatelessWidget {
.where(
(file) =>
fileIdsToClusterIds[file.uploadedFileID]
?.contains(clusterID) ??
?.contains(widget.clusterID) ??
false,
)
.toList();
@ -74,7 +91,7 @@ class FaceWidget extends StatelessWidget {
MaterialPageRoute(
builder: (context) => ClusterPage(
clusterFiles,
clusterID: clusterID!,
clusterID: widget.clusterID!,
),
),
);
@ -82,17 +99,20 @@ class FaceWidget extends StatelessWidget {
},
child: Column(
children: [
// TODO: the edges of the green line are still not properly rounded around ClipRRect
Stack(
children: [
Container(
height: 60,
width: 60,
decoration: ShapeDecoration(
shape: RoundedRectangleBorder(
borderRadius:
const BorderRadius.all(Radius.elliptical(16, 12)),
side: highlight
borderRadius: const BorderRadius.all(
Radius.elliptical(16, 12),
),
side: widget.highlight
? BorderSide(
color: getEnteColorScheme(context).primary700,
color:
getEnteColorScheme(context).primary700,
width: 1.0,
)
: BorderSide.none,
@ -111,17 +131,37 @@ class FaceWidget extends StatelessWidget {
),
),
),
// TODO: the edges of the green line are still not properly rounded around ClipRRect
if (widget.editMode)
Positioned(
right: 0,
top: 0,
child: GestureDetector(
onTap: _cornerIconPressed,
child: isJustRemoved
? const Icon(
CupertinoIcons.add_circled_solid,
color: Colors.green,
)
: const Icon(
Icons.cancel,
color: Colors.red,
),
),
),
],
),
const SizedBox(height: 8),
if (person != null)
if (widget.person != null)
Text(
person!.attr.name.trim(),
widget.person!.attr.name.trim(),
style: Theme.of(context).textTheme.bodySmall,
overflow: TextOverflow.ellipsis,
maxLines: 1,
),
if (kDebugMode)
Text(
'S: ${face.score.toStringAsFixed(3)}',
'S: ${widget.face.score.toStringAsFixed(3)}',
style: Theme.of(context).textTheme.bodySmall,
maxLines: 1,
),
@ -168,21 +208,21 @@ class FaceWidget extends StatelessWidget {
return GestureDetector(
onTap: () async {
log(
"FaceWidget is tapped, with person $person and clusterID $clusterID",
"FaceWidget is tapped, with person ${widget.person} and clusterID ${widget.clusterID}",
name: "FaceWidget",
);
if (person == null && clusterID == null) {
if (widget.person == null && widget.clusterID == null) {
return;
}
if (person != null) {
if (widget.person != null) {
await Navigator.of(context).push(
MaterialPageRoute(
builder: (context) => PeoplePage(
person: person!,
person: widget.person!,
),
),
);
} else if (clusterID != null) {
} else if (widget.clusterID != null) {
final fileIdsToClusterIds =
await FaceMLDataDB.instance.getFileIdToClusterIds();
final files = await SearchService.instance.getAllFiles();
@ -190,7 +230,7 @@ class FaceWidget extends StatelessWidget {
.where(
(file) =>
fileIdsToClusterIds[file.uploadedFileID]
?.contains(clusterID) ??
?.contains(widget.clusterID) ??
false,
)
.toList();
@ -198,7 +238,7 @@ class FaceWidget extends StatelessWidget {
MaterialPageRoute(
builder: (context) => ClusterPage(
clusterFiles,
clusterID: clusterID!,
clusterID: widget.clusterID!,
),
),
);
@ -213,7 +253,7 @@ class FaceWidget extends StatelessWidget {
shape: RoundedRectangleBorder(
borderRadius:
const BorderRadius.all(Radius.elliptical(16, 12)),
side: highlight
side: widget.highlight
? BorderSide(
color: getEnteColorScheme(context).primary700,
width: 2.0,
@ -228,23 +268,23 @@ class FaceWidget extends StatelessWidget {
width: 60,
height: 60,
child: CroppedFaceImageView(
enteFile: file,
face: face,
enteFile: widget.file,
face: widget.face,
),
),
),
),
const SizedBox(height: 8),
if (person != null)
if (widget.person != null)
Text(
person!.attr.name.trim(),
widget.person!.attr.name.trim(),
style: Theme.of(context).textTheme.bodySmall,
overflow: TextOverflow.ellipsis,
maxLines: 1,
),
if (kDebugMode)
Text(
'S: ${face.score.toStringAsFixed(3)}',
'S: ${widget.face.score.toStringAsFixed(3)}',
style: Theme.of(context).textTheme.bodySmall,
maxLines: 1,
),
@ -256,36 +296,55 @@ class FaceWidget extends StatelessWidget {
}
}
void _cornerIconPressed() async {
log('face widget (file info) corner icon is pressed');
try {
if (isJustRemoved) {
await ClusterFeedbackService.instance
.addFilesToCluster([widget.face.faceID], widget.clusterID!);
} else {
await ClusterFeedbackService.instance
.removeFilesFromCluster([widget.file], widget.clusterID!);
}
setState(() {
isJustRemoved = !isJustRemoved;
});
} catch (e, s) {
log("removing face/file from cluster from file info widget failed: $e, \n $s");
}
}
Future<Uint8List?> getFaceCrop() async {
try {
final Uint8List? cachedFace = faceCropCache.get(face.faceID);
final Uint8List? cachedFace = faceCropCache.get(widget.face.faceID);
if (cachedFace != null) {
return cachedFace;
}
final faceCropCacheFile = cachedFaceCropPath(face.faceID);
final faceCropCacheFile = cachedFaceCropPath(widget.face.faceID);
if ((await faceCropCacheFile.exists())) {
final data = await faceCropCacheFile.readAsBytes();
faceCropCache.put(face.faceID, data);
faceCropCache.put(widget.face.faceID, data);
return data;
}
final result = await pool.withResource(
() async => await getFaceCrops(
file,
widget.file,
{
face.faceID: face.detection.box,
widget.face.faceID: widget.face.detection.box,
},
),
);
final Uint8List? computedCrop = result?[face.faceID];
final Uint8List? computedCrop = result?[widget.face.faceID];
if (computedCrop != null) {
faceCropCache.put(face.faceID, computedCrop);
faceCropCache.put(widget.face.faceID, computedCrop);
faceCropCacheFile.writeAsBytes(computedCrop).ignore();
}
return computedCrop;
} catch (e, s) {
log(
"Error getting face for faceID: ${face.faceID}",
"Error getting face for faceID: ${widget.face.faceID}",
error: e,
stackTrace: s,
);

View file

@ -9,23 +9,44 @@ import "package:photos/ui/components/buttons/chip_button_widget.dart";
import "package:photos/ui/components/info_item_widget.dart";
import "package:photos/ui/viewer/file_details/face_widget.dart";
class FacesItemWidget extends StatelessWidget {
class FacesItemWidget extends StatefulWidget {
final EnteFile file;
const FacesItemWidget(this.file, {super.key});
@override
State<FacesItemWidget> createState() => _FacesItemWidgetState();
}
class _FacesItemWidgetState extends State<FacesItemWidget> {
bool editMode = false;
@override
void initState() {
super.initState();
setState(() {});
}
@override
Widget build(BuildContext context) {
return InfoItemWidget(
key: const ValueKey("Faces"),
leadingIcon: Icons.face_retouching_natural_outlined,
subtitleSection: _faceWidgets(context, file),
subtitleSection: _faceWidgets(context, widget.file, editMode),
hasChipButtons: true,
editOnTap: _toggleEditMode,
);
}
void _toggleEditMode() {
setState(() {
editMode = !editMode;
});
}
Future<List<Widget>> _faceWidgets(
BuildContext context,
EnteFile file,
bool editMode,
) async {
try {
if (file.uploadedFileID == null) {
@ -84,6 +105,7 @@ class FacesItemWidget extends StatelessWidget {
clusterID: clusterID,
person: person,
highlight: highlight,
editMode: highlight ? editMode : false,
),
);
}