diff --git a/mobile/lib/face/db.dart b/mobile/lib/face/db.dart index e7fe12dc8..2778ec289 100644 --- a/mobile/lib/face/db.dart +++ b/mobile/lib/face/db.dart @@ -80,31 +80,28 @@ class FaceMLDataDB { } } - Future updatePersonIDForFaceIDIFNotSet( - Map faceIDToPersonID, + Future updateClusterIdToFaceId( + Map faceIDToClusterID, ) async { final db = await instance.database; const batchSize = 500; - final numBatches = (faceIDToPersonID.length / batchSize).ceil(); - + final numBatches = (faceIDToClusterID.length / batchSize).ceil(); for (int i = 0; i < numBatches; i++) { - _logger.info('updatePersonIDForFaceIDIFNotSet Batch $i of $numBatches'); final start = i * batchSize; - final end = min((i + 1) * batchSize, faceIDToPersonID.length); - final batch = faceIDToPersonID.entries.toList().sublist(start, end); + final end = min((i + 1) * batchSize, faceIDToClusterID.length); + final batch = faceIDToClusterID.entries.toList().sublist(start, end); final batchUpdate = db.batch(); for (final entry in batch) { final faceID = entry.key; - final personID = entry.value; + final clusterID = entry.value; batchUpdate.insert( faceClustersTable, - {fcClusterID: personID, fcFaceId: faceID}, + {fcClusterID: clusterID, fcFaceId: faceID}, conflictAlgorithm: ConflictAlgorithm.replace, ); } - await batchUpdate.commit(noResult: true); } } @@ -496,14 +493,7 @@ class FaceMLDataDB { mapPersonToRow(p), conflictAlgorithm: ConflictAlgorithm.replace, ); - await db.insert( - clusterPersonTable, - { - personIdColumn: p.remoteID, - cluserIDColumn: cluserID, - }, - conflictAlgorithm: ConflictAlgorithm.replace, - ); + await assignClusterToPerson(personID: p.remoteID, clusterID: cluserID); } Future updatePerson(PersonEntity p) async { diff --git a/mobile/lib/main.dart b/mobile/lib/main.dart index cc0bb60b1..08ce0e98a 100644 --- a/mobile/lib/main.dart +++ b/mobile/lib/main.dart @@ -20,6 +20,7 @@ import 'package:photos/core/errors.dart'; import 'package:photos/core/network/network.dart'; import 'package:photos/db/upload_locks_db.dart'; import 'package:photos/ente_theme_data.dart'; +import "package:photos/face/db.dart"; import "package:photos/l10n/l10n.dart"; import 'package:photos/services/app_lifecycle_service.dart'; import 'package:photos/services/billing_service.dart'; @@ -32,6 +33,7 @@ import 'package:photos/services/local_file_update_service.dart'; import 'package:photos/services/local_sync_service.dart'; import "package:photos/services/location_service.dart"; import 'package:photos/services/machine_learning/face_ml/face_ml_service.dart'; +import "package:photos/services/machine_learning/face_ml/person/person_service.dart"; import 'package:photos/services/machine_learning/file_ml/remote_fileml_service.dart'; import "package:photos/services/machine_learning/machine_learning_controller.dart"; import 'package:photos/services/machine_learning/semantic_search/semantic_search_service.dart'; @@ -238,6 +240,7 @@ Future _init(bool isBackground, {String via = ''}) async { unawaited(FaceMlService.instance.init()); FaceMlService.instance.listenIndexOnDiffSync(); } + PersonService.init(EntityService.instance, FaceMLDataDB.instance, preferences); _logger.info("Initialization done"); } diff --git a/mobile/lib/services/machine_learning/face_ml/face_ml_service.dart b/mobile/lib/services/machine_learning/face_ml/face_ml_service.dart index e163defe6..e407dc318 100644 --- a/mobile/lib/services/machine_learning/face_ml/face_ml_service.dart +++ b/mobile/lib/services/machine_learning/face_ml/face_ml_service.dart @@ -416,8 +416,7 @@ class FaceMlService { return; } - await FaceMLDataDB.instance - .updatePersonIDForFaceIDIFNotSet(faceIdToCluster); + await FaceMLDataDB.instance.updateClusterIdToFaceId(faceIdToCluster); offset += offsetIncrement; } } else { @@ -456,8 +455,7 @@ class FaceMlService { _logger.info( 'Updating ${faceIdToCluster.length} FaceIDs with clusterIDs in the DB', ); - await FaceMLDataDB.instance - .updatePersonIDForFaceIDIFNotSet(faceIdToCluster); + await FaceMLDataDB.instance.updateClusterIdToFaceId(faceIdToCluster); _logger.info('Done updating FaceIDs with clusterIDs in the DB, in ' '${DateTime.now().difference(clusterDoneTime).inSeconds} seconds'); } diff --git a/mobile/lib/services/machine_learning/face_ml/person/person_service.dart b/mobile/lib/services/machine_learning/face_ml/person/person_service.dart index 8004b45b2..d2f05b7cc 100644 --- a/mobile/lib/services/machine_learning/face_ml/person/person_service.dart +++ b/mobile/lib/services/machine_learning/face_ml/person/person_service.dart @@ -1,10 +1,52 @@ +import "dart:convert"; + import "package:photos/face/db.dart"; +import "package:photos/face/model/person.dart"; +import "package:photos/models/api/entity/type.dart"; import "package:photos/services/entity_service.dart"; import "package:shared_preferences/shared_preferences.dart"; class PersonService { final EntityService entityService; final FaceMLDataDB faceMLDataDB; - final SharedPreferences _prefs; - PersonService(this.entityService, this.faceMLDataDB, this._prefs); + final SharedPreferences prefs; + PersonService(this.entityService, this.faceMLDataDB, this.prefs); + // instance + static PersonService? _instance; + static PersonService get instance { + if (_instance == null) { + throw Exception("PersonService not initialized"); + } + return _instance!; + } + + static init( + EntityService entityService, + FaceMLDataDB faceMLDataDB, + SharedPreferences prefs, + ) { + _instance = PersonService(entityService, faceMLDataDB, prefs); + } + + Future addPerson(String name, int clusterID) async { + final faceIds = await faceMLDataDB.getFaceIDsForCluster(clusterID); + final data = PersonData( + name: name, + assigned: [ + ClusterInfo( + id: clusterID, + faces: faceIds.toSet(), + ), + ], + ); + final result = await entityService.addOrUpdate( + EntityType.person, + json.encode(data.toJson()), + ); + await faceMLDataDB.assignClusterToPerson( + personID: result.id, + clusterID: clusterID, + ); + return PersonEntity(result.id, data); + } } diff --git a/mobile/lib/ui/viewer/people/add_person_action_sheet.dart b/mobile/lib/ui/viewer/people/add_person_action_sheet.dart index c66770124..16213db51 100644 --- a/mobile/lib/ui/viewer/people/add_person_action_sheet.dart +++ b/mobile/lib/ui/viewer/people/add_person_action_sheet.dart @@ -11,6 +11,7 @@ import "package:photos/face/db.dart"; import "package:photos/face/model/person.dart"; import "package:photos/generated/l10n.dart"; import 'package:photos/services/machine_learning/face_ml/feedback/cluster_feedback.dart'; +import "package:photos/services/machine_learning/face_ml/person/person_service.dart"; import 'package:photos/theme/colors.dart'; import 'package:photos/theme/ente_theme.dart'; import 'package:photos/ui/common/loading_widget.dart'; @@ -23,7 +24,6 @@ import "package:photos/ui/viewer/people/new_person_item_widget.dart"; import "package:photos/ui/viewer/people/person_row_item.dart"; import "package:photos/utils/dialog_util.dart"; import "package:photos/utils/toast_util.dart"; -import "package:uuid/uuid.dart"; enum PersonActionType { assignPerson, @@ -269,12 +269,8 @@ class _PersonActionSheetState extends State { return; } try { - final String id = const Uuid().v4().toString(); - final PersonEntity p = PersonEntity( - id, - PersonData(name: text, assigned: []), - ); - await FaceMLDataDB.instance.insert(p, clusterID); + final PersonEntity p = + await PersonService.instance.addPerson(text, clusterID); final bool extraPhotosFound = await ClusterFeedbackService.instance .checkAndDoAutomaticMerges(p); if (extraPhotosFound) { @@ -282,7 +278,6 @@ class _PersonActionSheetState extends State { } Bus.instance.fire(PeopleChangedEvent()); Navigator.pop(context, p); - log("inserted person"); } catch (e, s) { Logger("_PersonActionSheetState") .severe("Failed to rename album", e, s); diff --git a/mobile/lib/ui/viewer/search/result/person_face_widget.dart b/mobile/lib/ui/viewer/search/result/person_face_widget.dart index b2286dc65..832f4abb3 100644 --- a/mobile/lib/ui/viewer/search/result/person_face_widget.dart +++ b/mobile/lib/ui/viewer/search/result/person_face_widget.dart @@ -17,12 +17,16 @@ class PersonFaceWidget extends StatelessWidget { final String? personId; final int? clusterID; + // PersonFaceWidget constructor checks that both personId and clusterID are not null + // and that the file is not null const PersonFaceWidget( this.file, { this.personId, this.clusterID, Key? key, - }) : super(key: key); + }) : assert(personId != null || clusterID != null, + "PersonFaceWidget requires either personId or clusterID to be non-null"), + super(key: key); @override Widget build(BuildContext context) { @@ -53,7 +57,7 @@ class PersonFaceWidget extends StatelessWidget { ); } else { return FutureBuilder( - future: getFace(), + future: _getFace(), builder: (context, snapshot) { if (snapshot.hasData) { final Face face = snapshot.data!; @@ -76,7 +80,7 @@ class PersonFaceWidget extends StatelessWidget { } } - Future getFace() async { + Future _getFace() async { return await FaceMLDataDB.instance.getCoverFaceForPerson( recentFileID: file.uploadedFileID!, personID: personId, @@ -86,11 +90,7 @@ class PersonFaceWidget extends StatelessWidget { Future getFaceCrop() async { try { - final Face? face = await FaceMLDataDB.instance.getCoverFaceForPerson( - recentFileID: file.uploadedFileID!, - personID: personId, - clusterID: clusterID, - ); + final Face? face = await _getFace(); if (face == null) { debugPrint( "No cover face for person: $personId and cluster $clusterID and recentFile ${file.uploadedFileID}",