diff --git a/mobile/android/app/build.gradle b/mobile/android/app/build.gradle index a241cbe7e..7a0519484 100644 --- a/mobile/android/app/build.gradle +++ b/mobile/android/app/build.gradle @@ -47,7 +47,7 @@ android { defaultConfig { applicationId "io.ente.photos" - minSdkVersion 21 + minSdkVersion 26 targetSdkVersion 33 versionCode flutterVersionCode.toInteger() versionName flutterVersionName @@ -74,6 +74,10 @@ android { dimension "default" applicationIdSuffix ".dev" } + face { + dimension "default" + applicationIdSuffix ".face" + } playstore { dimension "default" } diff --git a/mobile/android/app/src/face/AndroidManifest.xml b/mobile/android/app/src/face/AndroidManifest.xml new file mode 100644 index 000000000..cbf1924b2 --- /dev/null +++ b/mobile/android/app/src/face/AndroidManifest.xml @@ -0,0 +1,10 @@ + + + + + + + diff --git a/mobile/android/app/src/face/res/values/strings.xml b/mobile/android/app/src/face/res/values/strings.xml new file mode 100644 index 000000000..4932deb96 --- /dev/null +++ b/mobile/android/app/src/face/res/values/strings.xml @@ -0,0 +1,4 @@ + + ente face + backup face + diff --git a/mobile/ios/Podfile.lock b/mobile/ios/Podfile.lock index cf6d6b875..6d2a86940 100644 --- a/mobile/ios/Podfile.lock +++ b/mobile/ios/Podfile.lock @@ -59,6 +59,8 @@ PODS: - flutter_inappwebview/Core (0.0.1): - Flutter - OrderedSet (~> 5.0) + - flutter_isolate (0.0.1): + - Flutter - flutter_local_notifications (0.0.1): - Flutter - flutter_native_splash (0.0.1): @@ -197,6 +199,28 @@ PODS: - sqlite3/fts5 - sqlite3/perf-threadsafe - sqlite3/rtree + - TensorFlowLiteC (2.12.0): + - TensorFlowLiteC/Core (= 2.12.0) + - TensorFlowLiteC/Core (2.12.0) + - TensorFlowLiteC/CoreML (2.12.0): + - TensorFlowLiteC/Core + - TensorFlowLiteC/Metal (2.12.0): + - TensorFlowLiteC/Core + - TensorFlowLiteSwift (2.12.0): + - TensorFlowLiteSwift/Core (= 2.12.0) + - TensorFlowLiteSwift/Core (2.12.0): + - TensorFlowLiteC (= 2.12.0) + - TensorFlowLiteSwift/CoreML (2.12.0): + - TensorFlowLiteC/CoreML (= 2.12.0) + - TensorFlowLiteSwift/Core (= 2.12.0) + - TensorFlowLiteSwift/Metal (2.12.0): + - TensorFlowLiteC/Metal (= 2.12.0) + - TensorFlowLiteSwift/Core (= 2.12.0) + - tflite_flutter (0.0.1): + - Flutter + - TensorFlowLiteSwift (= 2.12.0) + - TensorFlowLiteSwift/CoreML (= 2.12.0) + - TensorFlowLiteSwift/Metal (= 2.12.0) - Toast (4.1.0) - uni_links (0.0.1): - Flutter @@ -228,6 +252,7 @@ DEPENDENCIES: - flutter_email_sender (from `.symlinks/plugins/flutter_email_sender/ios`) - flutter_image_compress (from `.symlinks/plugins/flutter_image_compress/ios`) - flutter_inappwebview (from `.symlinks/plugins/flutter_inappwebview/ios`) + - flutter_isolate (from `.symlinks/plugins/flutter_isolate/ios`) - flutter_local_notifications (from `.symlinks/plugins/flutter_local_notifications/ios`) - flutter_native_splash (from `.symlinks/plugins/flutter_native_splash/ios`) - flutter_secure_storage (from `.symlinks/plugins/flutter_secure_storage/ios`) @@ -257,6 +282,7 @@ DEPENDENCIES: - shared_preferences_foundation (from `.symlinks/plugins/shared_preferences_foundation/darwin`) - sqflite (from `.symlinks/plugins/sqflite/darwin`) - sqlite3_flutter_libs (from `.symlinks/plugins/sqlite3_flutter_libs/ios`) + - tflite_flutter (from `.symlinks/plugins/tflite_flutter/ios`) - uni_links (from `.symlinks/plugins/uni_links/ios`) - url_launcher_ios (from `.symlinks/plugins/url_launcher_ios/ios`) - video_player_avfoundation (from `.symlinks/plugins/video_player_avfoundation/darwin`) @@ -287,6 +313,8 @@ SPEC REPOS: - Sentry - SentryPrivate - sqlite3 + - 
TensorFlowLiteC + - TensorFlowLiteSwift - Toast EXTERNAL SOURCES: @@ -314,6 +342,8 @@ EXTERNAL SOURCES: :path: ".symlinks/plugins/flutter_image_compress/ios" flutter_inappwebview: :path: ".symlinks/plugins/flutter_inappwebview/ios" + flutter_isolate: + :path: ".symlinks/plugins/flutter_isolate/ios" flutter_local_notifications: :path: ".symlinks/plugins/flutter_local_notifications/ios" flutter_native_splash: @@ -372,6 +402,8 @@ EXTERNAL SOURCES: :path: ".symlinks/plugins/sqflite/darwin" sqlite3_flutter_libs: :path: ".symlinks/plugins/sqlite3_flutter_libs/ios" + tflite_flutter: + :path: ".symlinks/plugins/tflite_flutter/ios" uni_links: :path: ".symlinks/plugins/uni_links/ios" url_launcher_ios: @@ -405,6 +437,7 @@ SPEC CHECKSUMS: flutter_email_sender: 02d7443217d8c41483223627972bfdc09f74276b flutter_image_compress: 5a5e9aee05b6553048b8df1c3bc456d0afaac433 flutter_inappwebview: 3d32228f1304635e7c028b0d4252937730bbc6cf + flutter_isolate: 0edf5081826d071adf21759d1eb10ff5c24503b5 flutter_local_notifications: 0c0b1ae97e741e1521e4c1629a459d04b9aec743 flutter_native_splash: 52501b97d1c0a5f898d687f1646226c1f93c56ef flutter_secure_storage: 23fc622d89d073675f2eaa109381aefbcf5a49be @@ -449,6 +482,9 @@ SPEC CHECKSUMS: sqflite: 673a0e54cc04b7d6dba8d24fb8095b31c3a99eec sqlite3: 73b7fc691fdc43277614250e04d183740cb15078 sqlite3_flutter_libs: aeb4d37509853dfa79d9b59386a2dac5dd079428 + TensorFlowLiteC: 20785a69299185a379ba9852b6625f00afd7984a + TensorFlowLiteSwift: 3a4928286e9e35bdd3e17970f48e53c80d25e793 + tflite_flutter: 9433d086a3060431bbc9f3c7c20d017db0e72d08 Toast: ec33c32b8688982cecc6348adeae667c1b9938da uni_links: d97da20c7701486ba192624d99bffaaffcfc298a url_launcher_ios: bbd758c6e7f9fd7b5b1d4cde34d2b95fcce5e812 diff --git a/mobile/ios/Runner.xcodeproj/project.pbxproj b/mobile/ios/Runner.xcodeproj/project.pbxproj index e9cbf0685..c05eaf21d 100644 --- a/mobile/ios/Runner.xcodeproj/project.pbxproj +++ b/mobile/ios/Runner.xcodeproj/project.pbxproj @@ -299,6 +299,7 @@ "${BUILT_PRODUCTS_DIR}/flutter_email_sender/flutter_email_sender.framework", "${BUILT_PRODUCTS_DIR}/flutter_image_compress/flutter_image_compress.framework", "${BUILT_PRODUCTS_DIR}/flutter_inappwebview/flutter_inappwebview.framework", + "${BUILT_PRODUCTS_DIR}/flutter_isolate/flutter_isolate.framework", "${BUILT_PRODUCTS_DIR}/flutter_local_notifications/flutter_local_notifications.framework", "${BUILT_PRODUCTS_DIR}/flutter_native_splash/flutter_native_splash.framework", "${BUILT_PRODUCTS_DIR}/flutter_secure_storage/flutter_secure_storage.framework", @@ -382,6 +383,7 @@ "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/flutter_email_sender.framework", "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/flutter_image_compress.framework", "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/flutter_inappwebview.framework", + "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/flutter_isolate.framework", "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/flutter_local_notifications.framework", "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/flutter_native_splash.framework", "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/flutter_secure_storage.framework", diff --git a/mobile/lib/core/configuration.dart b/mobile/lib/core/configuration.dart index f82486631..5a3ac03a0 100644 --- a/mobile/lib/core/configuration.dart +++ b/mobile/lib/core/configuration.dart @@ -18,6 +18,7 @@ import 'package:photos/db/trash_db.dart'; import 'package:photos/db/upload_locks_db.dart'; import 'package:photos/events/signed_in_event.dart'; import 'package:photos/events/user_logged_out_event.dart'; 
+import "package:photos/face/db.dart"; import 'package:photos/models/key_attributes.dart'; import 'package:photos/models/key_gen_result.dart'; import 'package:photos/models/private_key_attributes.dart'; @@ -164,6 +165,7 @@ class Configuration { : null; await CollectionsDB.instance.clearTable(); await MemoriesDB.instance.clearTable(); + await FaceMLDataDB.instance.clearTable(); await UploadLocksDB.instance.clearTable(); await IgnoredFilesService.instance.reset(); diff --git a/mobile/lib/db/ml_data_db.dart b/mobile/lib/db/ml_data_db.dart new file mode 100644 index 000000000..07150f09b --- /dev/null +++ b/mobile/lib/db/ml_data_db.dart @@ -0,0 +1,714 @@ +import 'dart:async'; + +import 'package:logging/logging.dart'; +import 'package:path/path.dart' show join; +import 'package:path_provider/path_provider.dart'; +import 'package:photos/models/ml/ml_typedefs.dart'; +import "package:photos/services/face_ml/face_feedback.dart/cluster_feedback.dart"; +import "package:photos/services/face_ml/face_feedback.dart/feedback_types.dart"; +import "package:photos/services/face_ml/face_ml_result.dart"; +import 'package:sqflite/sqflite.dart'; + +/// Stores all data for the ML-related features. The database can be accessed by `MlDataDB.instance.database`. +/// +/// This includes: +/// [facesTable] - Stores all the detected faces and its embeddings in the images. +/// [peopleTable] - Stores all the clusters of faces which are considered to be the same person. +class MlDataDB { + static final Logger _logger = Logger("MlDataDB"); + + // TODO: [BOB] put the db in files + static const _databaseName = "ente.ml_data.db"; + static const _databaseVersion = 1; + + static const facesTable = 'faces'; + static const fileIDColumn = 'file_id'; + static const faceMlResultColumn = 'face_ml_result'; + static const mlVersionColumn = 'ml_version'; + + static const peopleTable = 'people'; + static const personIDColumn = 'person_id'; + static const clusterResultColumn = 'cluster_result'; + static const centroidColumn = 'cluster_centroid'; + static const centroidDistanceThresholdColumn = 'centroid_distance_threshold'; + + static const feedbackTable = 'feedback'; + static const feedbackIDColumn = 'feedback_id'; + static const feedbackTypeColumn = 'feedback_type'; + static const feedbackDataColumn = 'feedback_data'; + static const feedbackTimestampColumn = 'feedback_timestamp'; + static const feedbackFaceMlVersionColumn = 'feedback_face_ml_version'; + static const feedbackClusterMlVersionColumn = 'feedback_cluster_ml_version'; + + static const createFacesTable = '''CREATE TABLE IF NOT EXISTS $facesTable ( + $fileIDColumn INTEGER NOT NULL UNIQUE, + $faceMlResultColumn TEXT NOT NULL, + $mlVersionColumn INTEGER NOT NULL, + PRIMARY KEY($fileIDColumn) + ); + '''; + static const createPeopleTable = '''CREATE TABLE IF NOT EXISTS $peopleTable ( + $personIDColumn INTEGER NOT NULL UNIQUE, + $clusterResultColumn TEXT NOT NULL, + $centroidColumn TEXT NOT NULL, + $centroidDistanceThresholdColumn REAL NOT NULL, + PRIMARY KEY($personIDColumn) + ); + '''; + static const createFeedbackTable = + '''CREATE TABLE IF NOT EXISTS $feedbackTable ( + $feedbackIDColumn TEXT NOT NULL UNIQUE, + $feedbackTypeColumn TEXT NOT NULL, + $feedbackDataColumn TEXT NOT NULL, + $feedbackTimestampColumn TEXT NOT NULL, + $feedbackFaceMlVersionColumn INTEGER NOT NULL, + $feedbackClusterMlVersionColumn INTEGER NOT NULL, + PRIMARY KEY($feedbackIDColumn) + ); + '''; + static const _deleteFacesTable = 'DROP TABLE IF EXISTS $facesTable'; + static const _deletePeopleTable = 'DROP 
TABLE IF EXISTS $peopleTable'; + static const _deleteFeedbackTable = 'DROP TABLE IF EXISTS $feedbackTable'; + + MlDataDB._privateConstructor(); + static final MlDataDB instance = MlDataDB._privateConstructor(); + + static Future? _dbFuture; + Future get database async { + _dbFuture ??= _initDatabase(); + return _dbFuture!; + } + + Future _initDatabase() async { + final documentsDirectory = await getApplicationDocumentsDirectory(); + final String databaseDirectory = + join(documentsDirectory.path, _databaseName); + return await openDatabase( + databaseDirectory, + version: _databaseVersion, + onCreate: _onCreate, + ); + } + + Future _onCreate(Database db, int version) async { + await db.execute(createFacesTable); + await db.execute(createPeopleTable); + await db.execute(createFeedbackTable); + } + + /// WARNING: This will delete ALL data in the database! Only use this for debug/testing purposes! + Future cleanTables({ + bool cleanFaces = false, + bool cleanPeople = false, + bool cleanFeedback = false, + }) async { + _logger.fine('`cleanTables()` called'); + final db = await instance.database; + + if (cleanFaces) { + _logger.fine('`cleanTables()`: Cleaning faces table'); + await db.execute(_deleteFacesTable); + } + + if (cleanPeople) { + _logger.fine('`cleanTables()`: Cleaning people table'); + await db.execute(_deletePeopleTable); + } + + if (cleanFeedback) { + _logger.fine('`cleanTables()`: Cleaning feedback table'); + await db.execute(_deleteFeedbackTable); + } + + if (!cleanFaces && !cleanPeople && !cleanFeedback) { + _logger.fine( + '`cleanTables()`: No tables cleaned, since no table was specified. Please be careful with this function!', + ); + } + + await db.execute(createFacesTable); + await db.execute(createPeopleTable); + await db.execute(createFeedbackTable); + } + + Future createFaceMlResult(FaceMlResult faceMlResult) async { + _logger.fine('createFaceMlResult called'); + + final existingResult = await getFaceMlResult(faceMlResult.fileId); + if (existingResult != null) { + if (faceMlResult.mlVersion <= existingResult.mlVersion) { + _logger.fine( + 'FaceMlResult with file ID ${faceMlResult.fileId} already exists with equal or higher version. Skipping insert.', + ); + return; + } + } + + final db = await instance.database; + await db.insert( + facesTable, + { + fileIDColumn: faceMlResult.fileId, + faceMlResultColumn: faceMlResult.toJsonString(), + mlVersionColumn: faceMlResult.mlVersion, + }, + conflictAlgorithm: ConflictAlgorithm.replace, + ); + } + + Future doesFaceMlResultExist(int fileId, {int? mlVersion}) async { + _logger.fine('doesFaceMlResultExist called'); + final db = await instance.database; + + String whereString = '$fileIDColumn = ?'; + final List whereArgs = [fileId]; + + if (mlVersion != null) { + whereString += ' AND $mlVersionColumn = ?'; + whereArgs.add(mlVersion); + } + + final result = await db.query( + facesTable, + where: whereString, + whereArgs: whereArgs, + limit: 1, + ); + return result.isNotEmpty; + } + + Future getFaceMlResult(int fileId, {int? 
mlVersion}) async { + _logger.fine('getFaceMlResult called'); + final db = await instance.database; + + String whereString = '$fileIDColumn = ?'; + final List whereArgs = [fileId]; + + if (mlVersion != null) { + whereString += ' AND $mlVersionColumn = ?'; + whereArgs.add(mlVersion); + } + + final result = await db.query( + facesTable, + where: whereString, + whereArgs: whereArgs, + limit: 1, + ); + if (result.isNotEmpty) { + return FaceMlResult.fromJsonString( + result.first[faceMlResultColumn] as String, + ); + } + _logger.fine( + 'No faceMlResult found for fileID $fileId and mlVersion $mlVersion (null if not specified)', + ); + return null; + } + + /// Returns the faceMlResults for the given [fileIds]. + Future> getSelectedFaceMlResults( + List fileIds, + ) async { + _logger.fine('getSelectedFaceMlResults called'); + final db = await instance.database; + + if (fileIds.isEmpty) { + _logger.warning('getSelectedFaceMlResults called with empty fileIds'); + return []; + } + + final List> results = await db.query( + facesTable, + columns: [faceMlResultColumn], + where: '$fileIDColumn IN (${fileIds.join(',')})', + orderBy: fileIDColumn, + ); + + return results + .map( + (result) => + FaceMlResult.fromJsonString(result[faceMlResultColumn] as String), + ) + .toList(); + } + + Future> getAllFaceMlResults({int? mlVersion}) async { + _logger.fine('getAllFaceMlResults called'); + final db = await instance.database; + + String? whereString; + List? whereArgs; + + if (mlVersion != null) { + whereString = '$mlVersionColumn = ?'; + whereArgs = [mlVersion]; + } + + final results = await db.query( + facesTable, + where: whereString, + whereArgs: whereArgs, + orderBy: fileIDColumn, + ); + + return results + .map( + (result) => + FaceMlResult.fromJsonString(result[faceMlResultColumn] as String), + ) + .toList(); + } + + /// getAllFileIDs returns a set of all fileIDs from the facesTable, meaning all the fileIDs for which a FaceMlResult exists, optionally filtered by mlVersion. + Future> getAllFaceMlResultFileIDs({int? mlVersion}) async { + _logger.fine('getAllFaceMlResultFileIDs called'); + final db = await instance.database; + + String? whereString; + List? whereArgs; + + if (mlVersion != null) { + whereString = '$mlVersionColumn = ?'; + whereArgs = [mlVersion]; + } + + final List> results = await db.query( + facesTable, + where: whereString, + whereArgs: whereArgs, + orderBy: fileIDColumn, + ); + + return results.map((result) => result[fileIDColumn] as int).toSet(); + } + + Future> getAllFaceMlResultFileIDsProcessedWithThumbnailOnly({ + int? mlVersion, + }) async { + _logger.fine('getAllFaceMlResultFileIDsProcessedWithThumbnailOnly called'); + final db = await instance.database; + + String? whereString; + List? whereArgs; + + if (mlVersion != null) { + whereString = '$mlVersionColumn = ?'; + whereArgs = [mlVersion]; + } + + final List> results = await db.query( + facesTable, + where: whereString, + whereArgs: whereArgs, + orderBy: fileIDColumn, + ); + + return results + .map( + (result) => + FaceMlResult.fromJsonString(result[faceMlResultColumn] as String), + ) + .where((element) => element.onlyThumbnailUsed) + .map((result) => result.fileId) + .toSet(); + } + + /// Updates the faceMlResult for the given [faceMlResult.fileId]. Update is done regardless of the [faceMlResult.mlVersion]. + /// However, if [updateHigherVersionOnly] is set to true, the update is only done if the [faceMlResult.mlVersion] is higher than the existing one. 
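+  ///
+  /// A minimal usage sketch (hypothetical values), assuming `newResult` was
+  /// produced by a newer model run than the stored row:
+  ///
+  /// ```dart
+  /// final changed = await MlDataDB.instance.updateFaceMlResult(
+  ///   newResult,
+  ///   updateHigherVersionOnly: true,
+  /// );
+  /// if (changed == 0) _logger.fine('stored result was already up to date');
+  /// ```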
+ Future updateFaceMlResult( + FaceMlResult faceMlResult, { + bool updateHigherVersionOnly = false, + }) async { + _logger.fine('updateFaceMlResult called'); + + if (updateHigherVersionOnly) { + final existingResult = await getFaceMlResult(faceMlResult.fileId); + if (existingResult != null) { + if (faceMlResult.mlVersion <= existingResult.mlVersion) { + _logger.fine( + 'FaceMlResult with file ID ${faceMlResult.fileId} already exists with equal or higher version. Skipping update.', + ); + return 0; + } + } + } + + final db = await instance.database; + return await db.update( + facesTable, + { + fileIDColumn: faceMlResult.fileId, + faceMlResultColumn: faceMlResult.toJsonString(), + mlVersionColumn: faceMlResult.mlVersion, + }, + where: '$fileIDColumn = ?', + whereArgs: [faceMlResult.fileId], + ); + } + + Future deleteFaceMlResult(int fileId) async { + _logger.fine('deleteFaceMlResult called'); + final db = await instance.database; + final deleteCount = await db.delete( + facesTable, + where: '$fileIDColumn = ?', + whereArgs: [fileId], + ); + _logger.fine('Deleted $deleteCount faceMlResults'); + return deleteCount; + } + + Future createAllClusterResults( + List clusterResults, { + bool cleanExistingClusters = true, + }) async { + _logger.fine('createClusterResults called'); + final db = await instance.database; + + if (clusterResults.isEmpty) { + _logger.fine('No clusterResults given, skipping insert.'); + return; + } + + // Completely clean the table and start fresh + if (cleanExistingClusters) { + await deleteAllClusterResults(); + } + + // Insert all the cluster results + for (final clusterResult in clusterResults) { + await db.insert( + peopleTable, + { + personIDColumn: clusterResult.personId, + clusterResultColumn: clusterResult.toJsonString(), + centroidColumn: clusterResult.medoid.toString(), + centroidDistanceThresholdColumn: + clusterResult.medoidDistanceThreshold, + }, + conflictAlgorithm: ConflictAlgorithm.replace, + ); + } + } + + Future getClusterResult(int personId) async { + _logger.fine('getClusterResult called'); + final db = await instance.database; + + final result = await db.query( + peopleTable, + where: '$personIDColumn = ?', + whereArgs: [personId], + limit: 1, + ); + if (result.isNotEmpty) { + return ClusterResult.fromJsonString( + result.first[clusterResultColumn] as String, + ); + } + _logger.fine('No clusterResult found for personID $personId'); + return null; + } + + /// Returns the ClusterResult objects for the given [personIDs]. + Future> getSelectedClusterResults( + List personIDs, + ) async { + _logger.fine('getSelectedClusterResults called'); + final db = await instance.database; + + if (personIDs.isEmpty) { + _logger.warning('getSelectedClusterResults called with empty personIDs'); + return []; + } + + final results = await db.query( + peopleTable, + where: '$personIDColumn IN (${personIDs.join(',')})', + orderBy: personIDColumn, + ); + + return results + .map( + (result) => ClusterResult.fromJsonString( + result[clusterResultColumn] as String, + ), + ) + .toList(); + } + + Future> getAllClusterResults() async { + _logger.fine('getAllClusterResults called'); + final db = await instance.database; + + final results = await db.query( + peopleTable, + ); + + return results + .map( + (result) => ClusterResult.fromJsonString( + result[clusterResultColumn] as String, + ), + ) + .toList(); + } + + /// Returns the personIDs of all clustered people in the database. 
+ Future> getAllClusterIds() async { + _logger.fine('getAllClusterIds called'); + final db = await instance.database; + + final results = await db.query( + peopleTable, + columns: [personIDColumn], + ); + + return results.map((result) => result[personIDColumn] as int).toList(); + } + + /// Returns the fileIDs of all files associated with a given [personId]. + Future> getClusterFileIds(int personId) async { + _logger.fine('getClusterFileIds called'); + + final ClusterResult? clusterResult = await getClusterResult(personId); + if (clusterResult == null) { + return []; + } + return clusterResult.uniqueFileIds; + } + + Future> getClusterFaceIds(int personId) async { + _logger.fine('getClusterFaceIds called'); + + final ClusterResult? clusterResult = await getClusterResult(personId); + if (clusterResult == null) { + return []; + } + return clusterResult.faceIDs; + } + + Future> getClusterEmbeddings( + int personId, + ) async { + _logger.fine('getClusterEmbeddings called'); + + final ClusterResult? clusterResult = await getClusterResult(personId); + if (clusterResult == null) return []; + + final fileIds = clusterResult.uniqueFileIds; + final faceIds = clusterResult.faceIDs; + if (fileIds.length != faceIds.length) { + _logger.severe( + 'fileIds and faceIds have different lengths: ${fileIds.length} vs ${faceIds.length}. This should not happen!', + ); + return []; + } + + final faceMlResults = await getSelectedFaceMlResults(fileIds); + if (faceMlResults.isEmpty) return []; + + final embeddings = []; + for (var i = 0; i < faceMlResults.length; i++) { + final faceMlResult = faceMlResults[i]; + final int faceIndex = faceMlResult.allFaceIds.indexOf(faceIds[i]); + if (faceIndex == -1) { + _logger.severe( + 'Could not find faceIndex for faceId ${faceIds[i]} in faceMlResult ${faceMlResult.fileId}', + ); + return []; + } + embeddings.add(faceMlResult.faces[faceIndex].embedding); + } + + return embeddings; + } + + Future updateClusterResult(ClusterResult clusterResult) async { + _logger.fine('updateClusterResult called'); + final db = await instance.database; + await db.update( + peopleTable, + { + personIDColumn: clusterResult.personId, + clusterResultColumn: clusterResult.toJsonString(), + centroidColumn: clusterResult.medoid.toString(), + centroidDistanceThresholdColumn: clusterResult.medoidDistanceThreshold, + }, + where: '$personIDColumn = ?', + whereArgs: [clusterResult.personId], + ); + } + + Future deleteClusterResult(int personId) async { + _logger.fine('deleteClusterResult called'); + final db = await instance.database; + final deleteCount = await db.delete( + peopleTable, + where: '$personIDColumn = ?', + whereArgs: [personId], + ); + _logger.fine('Deleted $deleteCount clusterResults'); + return deleteCount; + } + + Future deleteAllClusterResults() async { + _logger.fine('deleteAllClusterResults called'); + final db = await instance.database; + await db.execute(_deletePeopleTable); + await db.execute(createPeopleTable); + } + + // TODO: current function implementation will skip inserting for a similar feedback, which means I can't remove two photos from the same person in a row + Future createClusterFeedback( + T feedback, { + bool skipIfSimilarFeedbackExists = false, + }) async { + _logger.fine('createClusterFeedback called'); + + // TODO: this skipping might cause issues for adding photos to the same person in a row!! 
+ if (skipIfSimilarFeedbackExists && + await doesSimilarClusterFeedbackExist(feedback)) { + _logger.fine( + 'ClusterFeedback with ID ${feedback.feedbackID} already has a similar feedback installed. Skipping insert.', + ); + return; + } + + final db = await instance.database; + await db.insert( + feedbackTable, + { + feedbackIDColumn: feedback.feedbackID, + feedbackTypeColumn: feedback.typeString, + feedbackDataColumn: feedback.toJsonString(), + feedbackTimestampColumn: feedback.timestampString, + feedbackFaceMlVersionColumn: feedback.madeOnFaceMlVersion, + feedbackClusterMlVersionColumn: feedback.madeOnClusterMlVersion, + }, + conflictAlgorithm: ConflictAlgorithm.replace, + ); + return; + } + + Future doesSimilarClusterFeedbackExist( + T feedback, + ) async { + _logger.fine('doesClusterFeedbackExist called'); + + final List existingFeedback = + await getAllClusterFeedback(type: feedback.type); + + if (existingFeedback.isNotEmpty) { + for (final existingFeedbackItem in existingFeedback) { + assert( + existingFeedbackItem.type == feedback.type, + 'Feedback types should be the same!', + ); + if (feedback.looselyMatchesMedoid(existingFeedbackItem)) { + _logger.fine( + 'ClusterFeedback of type ${feedback.typeString} with ID ${feedback.feedbackID} already has a similar feedback installed!', + ); + return true; + } + } + } + return false; + } + + /// Returns all the clusterFeedbacks of type [T] which match the given [feedback], sorted by timestamp (latest first). + Future> getAllMatchingClusterFeedback( + T feedback, { + bool sortNewestFirst = true, + }) async { + _logger.fine('getAllMatchingClusterFeedback called'); + + final List existingFeedback = + await getAllClusterFeedback(type: feedback.type); + final List matchingFeedback = []; + if (existingFeedback.isNotEmpty) { + for (final existingFeedbackItem in existingFeedback) { + assert( + existingFeedbackItem.type == feedback.type, + 'Feedback types should be the same!', + ); + if (feedback.looselyMatchesMedoid(existingFeedbackItem)) { + _logger.fine( + 'ClusterFeedback of type ${feedback.typeString} with ID ${feedback.feedbackID} already has a similar feedback installed!', + ); + matchingFeedback.add(existingFeedbackItem); + } + } + } + if (sortNewestFirst) { + matchingFeedback.sort((a, b) => b.timestamp.compareTo(a.timestamp)); + } + return matchingFeedback; + } + + Future> getAllClusterFeedback({ + required FeedbackType type, + int? mlVersion, + int? clusterMlVersion, + }) async { + _logger.fine('getAllClusterFeedback called'); + final db = await instance.database; + + // TODO: implement the versions for FeedbackType.imageFeedback and FeedbackType.faceFeedback and rename this function to getAllFeedback? + + String whereString = '$feedbackTypeColumn = ?'; + final List whereArgs = [type.toValueString()]; + + if (mlVersion != null) { + whereString += ' AND $feedbackFaceMlVersionColumn = ?'; + whereArgs.add(mlVersion); + } + if (clusterMlVersion != null) { + whereString += ' AND $feedbackClusterMlVersionColumn = ?'; + whereArgs.add(clusterMlVersion); + } + + final results = await db.query( + feedbackTable, + where: whereString, + whereArgs: whereArgs, + ); + + if (results.isNotEmpty) { + if (ClusterFeedback.fromJsonStringRegistry.containsKey(type)) { + final Function(String) fromJsonString = + ClusterFeedback.fromJsonStringRegistry[type]!; + return results + .map((e) => fromJsonString(e[feedbackDataColumn] as String) as T) + .toList(); + } else { + _logger.severe( + 'No fromJsonString function found for type ${type.name}. 
This should not happen!', + ); + } + } + _logger.fine( + 'No clusterFeedback results found of type $type' + + (mlVersion != null ? ' and mlVersion $mlVersion' : '') + + (clusterMlVersion != null + ? ' and clusterMlVersion $clusterMlVersion' + : ''), + ); + return []; + } + + Future deleteClusterFeedback( + T feedback, + ) async { + _logger.fine('deleteClusterFeedback called'); + final db = await instance.database; + final deleteCount = await db.delete( + feedbackTable, + where: '$feedbackIDColumn = ?', + whereArgs: [feedback.feedbackID], + ); + _logger.fine('Deleted $deleteCount clusterFeedbacks'); + return deleteCount; + } +} diff --git a/mobile/lib/events/files_updated_event.dart b/mobile/lib/events/files_updated_event.dart index 18aa8757b..7d7779d49 100644 --- a/mobile/lib/events/files_updated_event.dart +++ b/mobile/lib/events/files_updated_event.dart @@ -26,4 +26,5 @@ enum EventType { hide, unhide, coverChanged, + peopleChanged, } diff --git a/mobile/lib/events/people_changed_event.dart b/mobile/lib/events/people_changed_event.dart new file mode 100644 index 000000000..e2d135866 --- /dev/null +++ b/mobile/lib/events/people_changed_event.dart @@ -0,0 +1,3 @@ +import "package:photos/events/event.dart"; + +class PeopleChangedEvent extends Event {} diff --git a/mobile/lib/extensions/ml_linalg_extensions.dart b/mobile/lib/extensions/ml_linalg_extensions.dart new file mode 100644 index 000000000..85a980855 --- /dev/null +++ b/mobile/lib/extensions/ml_linalg_extensions.dart @@ -0,0 +1,193 @@ +import 'dart:math' as math show sin, cos, atan2, sqrt, pow; +import 'package:ml_linalg/linalg.dart'; + +extension SetVectorValues on Vector { + Vector setValues(int start, int end, Iterable values) { + if (values.length > length) { + throw Exception('Values cannot be larger than vector'); + } else if (end - start != values.length) { + throw Exception('Values must be same length as range'); + } else if (start < 0 || end > length) { + throw Exception('Range must be within vector'); + } + final tempList = toList(); + tempList.replaceRange(start, end, values); + final newVector = Vector.fromList(tempList); + return newVector; + } +} + +extension SetMatrixValues on Matrix { + Matrix setSubMatrix( + int startRow, + int endRow, + int startColumn, + int endColumn, + Iterable> values, + ) { + if (values.length > rowCount) { + throw Exception('New values cannot have more rows than original matrix'); + } else if (values.elementAt(0).length > columnCount) { + throw Exception( + 'New values cannot have more columns than original matrix', + ); + } else if (endRow - startRow != values.length) { + throw Exception('Values (number of rows) must be same length as range'); + } else if (endColumn - startColumn != values.elementAt(0).length) { + throw Exception( + 'Values (number of columns) must be same length as range', + ); + } else if (startRow < 0 || + endRow > rowCount || + startColumn < 0 || + endColumn > columnCount) { + throw Exception('Range must be within matrix'); + } + final tempList = asFlattenedList + .toList(); // You need `.toList()` here to make sure the list is growable, otherwise `replaceRange` will throw an error + for (var i = startRow; i < endRow; i++) { + tempList.replaceRange( + i * columnCount + startColumn, + i * columnCount + endColumn, + values.elementAt(i).toList(), + ); + } + final newMatrix = Matrix.fromFlattenedList(tempList, rowCount, columnCount); + return newMatrix; + } + + Matrix setValues( + int startRow, + int endRow, + int startColumn, + int endColumn, + Iterable values, + ) { + if 
((startRow - endRow) * (startColumn - endColumn) != values.length) { + throw Exception('Values must be same length as range'); + } else if (startRow < 0 || + endRow > rowCount || + startColumn < 0 || + endColumn > columnCount) { + throw Exception('Range must be within matrix'); + } + + final tempList = asFlattenedList + .toList(); // You need `.toList()` here to make sure the list is growable, otherwise `replaceRange` will throw an error + var index = 0; + for (var i = startRow; i < endRow; i++) { + for (var j = startColumn; j < endColumn; j++) { + tempList[i * columnCount + j] = values.elementAt(index); + index++; + } + } + final newMatrix = Matrix.fromFlattenedList(tempList, rowCount, columnCount); + return newMatrix; + } + + Matrix setValue(int row, int column, double value) { + if (row < 0 || row > rowCount || column < 0 || column > columnCount) { + throw Exception('Index must be within range of matrix'); + } + final tempList = asFlattenedList; + tempList[row * columnCount + column] = value; + final newMatrix = Matrix.fromFlattenedList(tempList, rowCount, columnCount); + return newMatrix; + } + + Matrix appendRow(List row) { + final oldNumberOfRows = rowCount; + final oldNumberOfColumns = columnCount; + if (row.length != oldNumberOfColumns) { + throw Exception('Row must have same number of columns as matrix'); + } + final flatListMatrix = asFlattenedList; + flatListMatrix.addAll(row); + return Matrix.fromFlattenedList( + flatListMatrix, + oldNumberOfRows + 1, + oldNumberOfColumns, + ); + } +} + +extension MatrixCalculations on Matrix { + double determinant() { + final int length = rowCount; + if (length != columnCount) { + throw Exception('Matrix must be square'); + } + if (length == 1) { + return this[0][0]; + } else if (length == 2) { + return this[0][0] * this[1][1] - this[0][1] * this[1][0]; + } else { + throw Exception('Determinant for Matrix larger than 2x2 not implemented'); + } + } + + /// Computes the singular value decomposition of a matrix, using https://lucidar.me/en/mathematics/singular-value-decomposition-of-a-2x2-matrix/ as reference, but with slightly different signs for the second columns of U and V + Map svd() { + if (rowCount != 2 || columnCount != 2) { + throw Exception('Matrix must be 2x2'); + } + final a = this[0][0]; + final b = this[0][1]; + final c = this[1][0]; + final d = this[1][1]; + + // Computation of U matrix + final tempCalc = a * a + b * b - c * c - d * d; + final theta = 0.5 * math.atan2(2 * a * c + 2 * b * d, tempCalc); + final U = Matrix.fromList([ + [math.cos(theta), math.sin(theta)], + [math.sin(theta), -math.cos(theta)], + ]); + + // Computation of S matrix + // ignore: non_constant_identifier_names + final S1 = a * a + b * b + c * c + d * d; + // ignore: non_constant_identifier_names + final S2 = + math.sqrt(math.pow(tempCalc, 2) + 4 * math.pow(a * c + b * d, 2)); + final sigma1 = math.sqrt((S1 + S2) / 2); + final sigma2 = math.sqrt((S1 - S2) / 2); + final S = Vector.fromList([sigma1, sigma2]); + + // Computation of V matrix + final tempCalc2 = a * a - b * b + c * c - d * d; + final phi = 0.5 * math.atan2(2 * a * b + 2 * c * d, tempCalc2); + final s11 = (a * math.cos(theta) + c * math.sin(theta)) * math.cos(phi) + + (b * math.cos(theta) + d * math.sin(theta)) * math.sin(phi); + final s22 = (a * math.sin(theta) - c * math.cos(theta)) * math.sin(phi) + + (-b * math.sin(theta) + d * math.cos(theta)) * math.cos(phi); + final V = Matrix.fromList([ + [s11.sign * math.cos(phi), s22.sign * math.sin(phi)], + [s11.sign * math.sin(phi), -s22.sign * 
math.cos(phi)],
+    ]);
+
+    return {
+      'U': U,
+      'S': S,
+      'V': V,
+    };
+  }
+
+  int matrixRank() {
+    final svdResult = svd();
+    final Vector S = svdResult['S']!;
+    final rank = S.toList().where((element) => element > 1e-10).length;
+    return rank;
+  }
+}
+
+extension TransformMatrix on Matrix {
+  List<List<double>> to2DList() {
+    final List<List<double>> outerList = [];
+    for (var i = 0; i < rowCount; i++) {
+      final innerList = this[i].toList();
+      outerList.add(innerList);
+    }
+    return outerList;
+  }
+}
diff --git a/mobile/lib/face/db.dart b/mobile/lib/face/db.dart
new file mode 100644
index 000000000..cd82bcab1
--- /dev/null
+++ b/mobile/lib/face/db.dart
@@ -0,0 +1,679 @@
+import 'dart:async';
+import "dart:math";
+import "dart:typed_data";
+
+import "package:collection/collection.dart";
+import "package:flutter/foundation.dart";
+import 'package:logging/logging.dart';
+import 'package:path/path.dart' show join;
+import 'package:path_provider/path_provider.dart';
+import 'package:photos/face/db_fields.dart';
+import "package:photos/face/db_model_mappers.dart";
+import "package:photos/face/model/face.dart";
+import "package:photos/face/model/person.dart";
+import "package:photos/models/file/file.dart";
+import "package:photos/services/face_ml/blur_detection/blur_constants.dart";
+import 'package:sqflite/sqflite.dart';
+
+/// Stores all data for the ML-related features. The database can be accessed by `FaceMLDataDB.instance.database`.
+///
+/// This includes:
+/// [facesTable] - Stores all the detected faces and their embeddings in the images.
+/// [peopleTable] - Stores all the clusters of faces which are considered to be the same person.
class FaceMLDataDB {
+  static final Logger _logger = Logger("FaceMLDataDB");
+
+  static const _databaseName = "ente.face_ml_db.db";
+  static const _databaseVersion = 1;
+
+  FaceMLDataDB._privateConstructor();
+
+  static final FaceMLDataDB instance = FaceMLDataDB._privateConstructor();
+
+  static Future<Database>? _dbFuture;
+
+  Future<Database> get database async {
+    _dbFuture ??= _initDatabase();
+    return _dbFuture!;
+  }
+
+  Future<Database> _initDatabase() async {
+    final documentsDirectory = await getApplicationDocumentsDirectory();
+    final String databaseDirectory =
+        join(documentsDirectory.path, _databaseName);
+    return await openDatabase(
+      databaseDirectory,
+      version: _databaseVersion,
+      onCreate: _onCreate,
+    );
+  }
+
+  Future<void> _onCreate(Database db, int version) async {
+    await db.execute(createFacesTable);
+    await db.execute(createPeopleTable);
+    await db.execute(createClusterTable);
+    await db.execute(createClusterSummaryTable);
+    await db.execute(createNotPersonFeedbackTable);
+  }
+
+  // bulkInsertFaces inserts the faces in the database in batches of 500.
+  // This is done to avoid the error "too many SQL variables" when inserting
+  // a large number of faces.
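+  // Rough cost sketch (illustrative numbers): one sqflite batch commit is
+  // issued per 500 faces, so inserting 10000 faces results in 20 commits.
+  // With ConflictAlgorithm.ignore, rows whose (file_id, face_id) primary key
+  // already exists are left untouched rather than overwritten.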
+  Future<void> bulkInsertFaces(List<Face> faces) async {
+    final db = await instance.database;
+    const batchSize = 500;
+    final numBatches = (faces.length / batchSize).ceil();
+    for (int i = 0; i < numBatches; i++) {
+      final start = i * batchSize;
+      final end = min((i + 1) * batchSize, faces.length);
+      final batch = faces.sublist(start, end);
+      final batchInsert = db.batch();
+      for (final face in batch) {
+        batchInsert.insert(
+          facesTable,
+          mapRemoteToFaceDB(face),
+          conflictAlgorithm: ConflictAlgorithm.ignore,
+        );
+      }
+      await batchInsert.commit(noResult: true);
+    }
+  }
+
+  Future<Set<int>> getIndexedFileIds() async {
+    final db = await instance.database;
+    final List<Map<String, dynamic>> maps = await db.rawQuery(
+      'SELECT DISTINCT $fileIDColumn FROM $facesTable',
+    );
+    return maps.map((e) => e[fileIDColumn] as int).toSet();
+  }
+
+  Future<Map<int, int>> clusterIdToFaceCount() async {
+    final db = await instance.database;
+    final List<Map<String, dynamic>> maps = await db.rawQuery(
+      'SELECT $cluserIDColumn, COUNT(*) as count FROM $facesTable where $cluserIDColumn IS NOT NULL GROUP BY $cluserIDColumn ',
+    );
+    final Map<int, int> result = {};
+    for (final map in maps) {
+      result[map[cluserIDColumn] as int] = map['count'] as int;
+    }
+    return result;
+  }
+
+  Future<Set<int>> getPersonIgnoredClusters(String personID) async {
+    final db = await instance.database;
+    // find out clusterIds that are assigned to other persons using the clusters table
+    final List<Map<String, dynamic>> maps = await db.rawQuery(
+      'SELECT $cluserIDColumn FROM $clustersTable WHERE $personIdColumn != ? AND $personIdColumn IS NOT NULL',
+      [personID],
+    );
+    final Set<int> ignoredClusterIDs =
+        maps.map((e) => e[cluserIDColumn] as int).toSet();
+    final List<Map<String, dynamic>> rejectMaps = await db.rawQuery(
+      'SELECT $cluserIDColumn FROM $notPersonFeedback WHERE $personIdColumn = ?',
+      [personID],
+    );
+    final Set<int> rejectClusterIDs =
+        rejectMaps.map((e) => e[cluserIDColumn] as int).toSet();
+    return ignoredClusterIDs.union(rejectClusterIDs);
+  }
+
+  Future<Set<int>> getPersonClusterIDs(String personID) async {
+    final db = await instance.database;
+    final List<Map<String, dynamic>> maps = await db.rawQuery(
+      'SELECT $cluserIDColumn FROM $clustersTable WHERE $personIdColumn = ?',
+      [personID],
+    );
+    return maps.map((e) => e[cluserIDColumn] as int).toSet();
+  }
+
+  Future<void> clearTable() async {
+    final db = await instance.database;
+    await db.delete(facesTable);
+    await db.delete(clustersTable);
+    await db.delete(clusterSummaryTable);
+    await db.delete(peopleTable);
+    await db.delete(notPersonFeedback);
+  }
+
+  Future<Iterable<Uint8List>> getFaceEmbeddingsForCluster(
+    int clusterID, {
+    int? limit,
+  }) async {
+    final db = await instance.database;
+    final List<Map<String, dynamic>> maps = await db.rawQuery(
+      'SELECT $faceEmbeddingBlob FROM $facesTable WHERE $cluserIDColumn = ? ${limit != null ? 'LIMIT $limit' : ''}',
+      [clusterID],
+    );
+    return maps.map((e) => e[faceEmbeddingBlob] as Uint8List);
+  }
+
+  Future<Map<int, int>> getFileIdToCount() async {
+    final Map<int, int> result = {};
+    final db = await instance.database;
+    final List<Map<String, dynamic>> maps = await db.rawQuery(
+      'SELECT $fileIDColumn, COUNT(*) as count FROM $facesTable where $faceScore > 0.8 GROUP BY $fileIDColumn',
+    );
+
+    for (final map in maps) {
+      result[map[fileIDColumn] as int] = map['count'] as int;
+    }
+    return result;
+  }
+
+  Future<Face?> getCoverFaceForPerson({
+    required int recentFileID,
+    String? personID,
+    int?
clusterID, + }) async { + // read person from db + final db = await instance.database; + if (personID != null) { + final List> maps = await db.rawQuery( + 'SELECT * FROM $peopleTable where $idColumn = ?', + [personID], + ); + if (maps.isEmpty) { + throw Exception("Person with id $personID not found"); + } + + final person = mapRowToPerson(maps.first); + final List fileId = [recentFileID]; + int? avatarFileId; + if (person.attr.avatarFaceId != null) { + avatarFileId = int.tryParse(person.attr.avatarFaceId!.split('-')[0]); + if (avatarFileId != null) { + fileId.add(avatarFileId); + } + } + final cluterRows = await db.query( + clustersTable, + columns: [cluserIDColumn], + where: '$personIdColumn = ?', + whereArgs: [personID], + ); + final clusterIDs = + cluterRows.map((e) => e[cluserIDColumn] as int).toList(); + final List> faceMaps = await db.rawQuery( + 'SELECT * FROM $facesTable where $faceClusterId IN (${clusterIDs.join(",")}) AND $fileIDColumn in (${fileId.join(",")}) AND $faceScore > 0.8 ORDER BY $faceScore DESC', + ); + if (faceMaps.isNotEmpty) { + if (avatarFileId != null) { + final row = faceMaps.firstWhereOrNull( + (element) => (element[fileIDColumn] as int) == avatarFileId, + ); + if (row != null) { + return mapRowToFace(row); + } + } + return mapRowToFace(faceMaps.first); + } + } + if (clusterID != null) { + final clusterIDs = [clusterID]; + final List> faceMaps = await db.rawQuery( + 'SELECT * FROM $facesTable where $faceClusterId IN (${clusterIDs.join(",")}) AND $fileIDColumn = $recentFileID ', + ); + if (faceMaps.isNotEmpty) { + return mapRowToFace(faceMaps.first); + } + } + if (personID == null && clusterID == null) { + throw Exception("personID and clusterID cannot be null"); + } + return null; + } + + Future> getFacesForGivenFileID(int fileUploadID) async { + final db = await instance.database; + final List> maps = await db.query( + facesTable, + columns: [ + fileIDColumn, + faceIDColumn, + faceDetectionColumn, + faceEmbeddingBlob, + faceScore, + faceBlur, + faceClusterId, + faceClosestDistColumn, + faceClosestFaceID, + faceConfirmedColumn, + mlVersionColumn, + ], + where: '$fileIDColumn = ?', + whereArgs: [fileUploadID], + ); + return maps.map((e) => mapRowToFace(e)).toList(); + } + + Future getFaceForFaceID(String faceID) async { + final db = await instance.database; + final result = await db.rawQuery( + 'SELECT * FROM $facesTable where $faceIDColumn = ?', + [faceID], + ); + if (result.isEmpty) { + return null; + } + return mapRowToFace(result.first); + } + + Future> getFaceIdsToClusterIds( + Iterable faceIds, + ) async { + final db = await instance.database; + final List> maps = await db.rawQuery( + 'SELECT $faceIDColumn, $faceClusterId FROM $facesTable where $faceIDColumn IN (${faceIds.map((id) => "'$id'").join(",")})', + ); + final Map result = {}; + for (final map in maps) { + result[map[faceIDColumn] as String] = map[faceClusterId] as int?; + } + return result; + } + + Future>> getFileIdToClusterIds() async { + final Map> result = {}; + final db = await instance.database; + final List> maps = await db.rawQuery( + 'SELECT $faceClusterId, $fileIDColumn FROM $facesTable where $faceClusterId IS NOT NULL', + ); + + for (final map in maps) { + final personID = map[faceClusterId] as int; + final fileID = map[fileIDColumn] as int; + result[fileID] = (result[fileID] ?? 
{})..add(personID); + } + return result; + } + + Future updatePersonIDForFaceIDIFNotSet( + Map faceIDToPersonID, + ) async { + final db = await instance.database; + + // Start a batch + final batch = db.batch(); + + for (final map in faceIDToPersonID.entries) { + final faceID = map.key; + final personID = map.value; + batch.update( + facesTable, + {faceClusterId: personID}, + where: '$faceIDColumn = ? AND $faceClusterId IS NULL', + whereArgs: [faceID], + ); + } + // Commit the batch + await batch.commit(noResult: true); + } + + Future forceUpdateClusterIds( + Map faceIDToPersonID, + ) async { + final db = await instance.database; + + // Start a batch + final batch = db.batch(); + + for (final map in faceIDToPersonID.entries) { + final faceID = map.key; + final personID = map.value; + batch.update( + facesTable, + {faceClusterId: personID}, + where: '$faceIDColumn = ?', + whereArgs: [faceID], + ); + } + // Commit the batch + await batch.commit(noResult: true); + } + + /// Returns a map of faceID to record of faceClusterID and faceEmbeddingBlob + /// + /// Only selects faces with score greater than [minScore] and blur score greater than [minClarity] + Future> getFaceEmbeddingMap({ + double minScore = 0.8, + int minClarity = kLaplacianThreshold, + int maxRows = 20000, + }) async { + _logger.info('reading as float'); + final db = await instance.database; + + // Define the batch size + const batchSize = 10000; + int offset = 0; + + final Map result = {}; + while (true) { + // Query a batch of rows + final List> maps = await db.query( + facesTable, + columns: [faceIDColumn, faceClusterId, faceEmbeddingBlob], + where: '$faceScore > $minScore and $faceBlur > $minClarity', + limit: batchSize, + offset: offset, + // orderBy: '$faceClusterId DESC', + orderBy: '$faceIDColumn DESC', + ); + // Break the loop if no more rows + if (maps.isEmpty) { + break; + } + for (final map in maps) { + final faceID = map[faceIDColumn] as String; + result[faceID] = + (map[faceClusterId] as int?, map[faceEmbeddingBlob] as Uint8List); + } + if (result.length >= 20000) { + break; + } + offset += batchSize; + } + return result; + } + + Future> getFaceEmbeddingMapForFile( + List fileIDs, + ) async { + _logger.info('reading as float'); + final db = await instance.database; + + // Define the batch size + const batchSize = 10000; + int offset = 0; + + final Map result = {}; + while (true) { + // Query a batch of rows + final List> maps = await db.query( + facesTable, + columns: [faceIDColumn, faceEmbeddingBlob], + where: + '$faceScore > 0.8 AND $faceBlur > $kLaplacianThreshold AND $fileIDColumn IN (${fileIDs.join(",")})', + limit: batchSize, + offset: offset, + orderBy: '$faceIDColumn DESC', + ); + // Break the loop if no more rows + if (maps.isEmpty) { + break; + } + for (final map in maps) { + final faceID = map[faceIDColumn] as String; + result[faceID] = map[faceEmbeddingBlob] as Uint8List; + } + if (result.length > 10000) { + break; + } + offset += batchSize; + } + return result; + } + + Future resetClusterIDs() async { + final db = await instance.database; + await db.update( + facesTable, + {faceClusterId: null}, + ); + } + + Future insert(Person p, int cluserID) async { + debugPrint("inserting person"); + final db = await instance.database; + await db.insert( + peopleTable, + mapPersonToRow(p), + conflictAlgorithm: ConflictAlgorithm.replace, + ); + await db.insert( + clustersTable, + { + personIdColumn: p.remoteID, + cluserIDColumn: cluserID, + }, + conflictAlgorithm: ConflictAlgorithm.replace, + ); + } + + Future 
updatePerson(Person p) async { + final db = await instance.database; + await db.update( + peopleTable, + mapPersonToRow(p), + where: '$idColumn = ?', + whereArgs: [p.remoteID], + ); + } + + Future assignClusterToPerson({ + required String personID, + required int clusterID, + }) async { + final db = await instance.database; + await db.insert( + clustersTable, + { + personIdColumn: personID, + cluserIDColumn: clusterID, + }, + ); + } + + Future captureNotPersonFeedback({ + required String personID, + required int clusterID, + }) async { + final db = await instance.database; + await db.insert( + notPersonFeedback, + { + personIdColumn: personID, + cluserIDColumn: clusterID, + }, + ); + } + + Future removeClusterToPerson({ + required String personID, + required int clusterID, + }) async { + final db = await instance.database; + return db.delete( + clustersTable, + where: '$personIdColumn = ? AND $cluserIDColumn = ?', + whereArgs: [personID, clusterID], + ); + } + + // for a given personID, return a map of clusterID to fileIDs using join query + Future>> getFileIdToClusterIDSet(String personID) { + final db = instance.database; + return db.then((db) async { + final List> maps = await db.rawQuery( + 'SELECT $clustersTable.$cluserIDColumn, $fileIDColumn FROM $facesTable ' + 'INNER JOIN $clustersTable ' + 'ON $facesTable.$faceClusterId = $clustersTable.$cluserIDColumn ' + 'WHERE $clustersTable.$personIdColumn = ?', + [personID], + ); + final Map> result = {}; + for (final map in maps) { + final clusterID = map[cluserIDColumn] as int; + final fileID = map[fileIDColumn] as int; + result[fileID] = (result[fileID] ?? {})..add(clusterID); + } + return result; + }); + } + + Future>> getFileIdToClusterIDSetForCluster( + Set clusterIDs, + ) { + final db = instance.database; + return db.then((db) async { + final List> maps = await db.rawQuery( + 'SELECT $cluserIDColumn, $fileIDColumn FROM $facesTable ' + 'WHERE $cluserIDColumn IN (${clusterIDs.join(",")})', + ); + final Map> result = {}; + for (final map in maps) { + final clusterID = map[cluserIDColumn] as int; + final fileID = map[fileIDColumn] as int; + result[fileID] = (result[fileID] ?? 
{})..add(clusterID); + } + return result; + }); + } + + Future clusterSummaryUpdate(Map summary) async { + final db = await instance.database; + var batch = db.batch(); + int batchCounter = 0; + for (final entry in summary.entries) { + if (batchCounter == 400) { + await batch.commit(noResult: true); + batch = db.batch(); + batchCounter = 0; + } + final int cluserID = entry.key; + final int count = entry.value.$2; + final Uint8List avg = entry.value.$1; + batch.insert( + clusterSummaryTable, + { + cluserIDColumn: cluserID, + avgColumn: avg, + countColumn: count, + }, + conflictAlgorithm: ConflictAlgorithm.replace, + ); + batchCounter++; + } + await batch.commit(noResult: true); + } + + /// Returns a map of clusterID to (avg embedding, count) + Future> clusterSummaryAll() async { + final db = await instance.database; + final Map result = {}; + final rows = await db.rawQuery('SELECT * from $clusterSummaryTable'); + for (final r in rows) { + final id = r[cluserIDColumn] as int; + final avg = r[avgColumn] as Uint8List; + final count = r[countColumn] as int; + result[id] = (avg, count); + } + return result; + } + + Future> getCluserIDToPersonMap() async { + final db = await instance.database; + final List> maps = await db.rawQuery( + 'SELECT $personIdColumn, $cluserIDColumn FROM $clustersTable', + ); + final Map result = {}; + for (final map in maps) { + result[map[cluserIDColumn] as int] = map[personIdColumn] as String; + } + return result; + } + + Future<(Map, Map)> getClusterIdToPerson() async { + final db = await instance.database; + final Map peopleMap = await getPeopleMap(); + final List> maps = await db.rawQuery( + 'SELECT $personIdColumn, $cluserIDColumn FROM $clustersTable', + ); + + final Map result = {}; + for (final map in maps) { + final Person? p = peopleMap[map[personIdColumn] as String]; + if (p != null) { + result[map[cluserIDColumn] as int] = p; + } else { + _logger.warning( + 'Person with id ${map[personIdColumn]} not found', + ); + } + } + return (result, peopleMap); + } + + Future> getPeopleMap() async { + final db = await instance.database; + final List> maps = await db.query( + peopleTable, + columns: [ + idColumn, + nameColumn, + personHiddenColumn, + clusterToFaceIdJson, + coverFaceIDColumn, + ], + ); + final Map result = {}; + for (final map in maps) { + result[map[idColumn] as String] = mapRowToPerson(map); + } + return result; + } + + Future> getPeople() async { + final db = await instance.database; + final List> maps = await db.query( + peopleTable, + columns: [ + idColumn, + nameColumn, + personHiddenColumn, + clusterToFaceIdJson, + coverFaceIDColumn, + ], + ); + return maps.map((map) => mapRowToPerson(map)).toList(); + } + + /// WARNING: This will delete ALL data in the database! Only use this for debug/testing purposes! 
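+  ///
+  /// A debug-only sketch: wipe all cluster/people state but keep the computed
+  /// face embeddings, so clustering can be re-run from scratch:
+  ///
+  /// ```dart
+  /// await FaceMLDataDB.instance.dropClustersAndPeople(); // faces: false
+  /// await FaceMLDataDB.instance.resetClusterIDs();
+  /// ```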
+ Future dropClustersAndPeople({bool faces = false}) async { + final db = await instance.database; + if (faces) { + await db.execute(deleteFacesTable); + await db.execute(createFacesTable); + } + await db.execute(deletePeopleTable); + await db.execute(dropClustersTable); + await db.execute(dropClusterSummaryTable); + await db.execute(dropNotPersonFeedbackTable); + + // await db.execute(createFacesTable); + await db.execute(createPeopleTable); + await db.execute(createClusterTable); + await db.execute(createNotPersonFeedbackTable); + await db.execute(createClusterSummaryTable); + } + + Future removePersonFromFiles(List files, Person p) async { + final db = await instance.database; + final result = await db.rawQuery( + 'SELECT $faceIDColumn FROM $facesTable LEFT JOIN $clustersTable ' + 'ON $facesTable.$faceClusterId = $clustersTable.$cluserIDColumn ' + 'WHERE $clustersTable.$personIdColumn = ? AND $facesTable.$fileIDColumn IN (${files.map((e) => e.uploadedFileID).join(",")})', + [p.remoteID], + ); + // get max clusterID + final maxRows = + await db.rawQuery('SELECT max($faceClusterId) from $facesTable'); + int maxClusterID = maxRows.first.values.first as int; + final Map faceIDToClusterID = {}; + for (final faceRow in result) { + final faceID = faceRow[faceIDColumn] as String; + faceIDToClusterID[faceID] = maxClusterID + 1; + maxClusterID = maxClusterID + 1; + } + await forceUpdateClusterIds(faceIDToClusterID); + } +} diff --git a/mobile/lib/face/db_fields.dart b/mobile/lib/face/db_fields.dart new file mode 100644 index 000000000..a1f185e7e --- /dev/null +++ b/mobile/lib/face/db_fields.dart @@ -0,0 +1,99 @@ +// Faces Table Fields & Schema Queries +import "package:photos/services/face_ml/blur_detection/blur_constants.dart"; + +const facesTable = 'faces'; +const fileIDColumn = 'file_id'; +const faceIDColumn = 'face_id'; +const faceDetectionColumn = 'detection'; +const faceEmbeddingBlob = 'eBlob'; +const faceScore = 'score'; +const faceBlur = 'blur'; +const faceClusterId = 'cluster_id'; +const faceConfirmedColumn = 'confirmed'; +const faceClosestDistColumn = 'close_dist'; +const faceClosestFaceID = 'close_face_id'; +const mlVersionColumn = 'ml_version'; + +const createFacesTable = '''CREATE TABLE IF NOT EXISTS $facesTable ( + $fileIDColumn INTEGER NOT NULL, + $faceIDColumn TEXT NOT NULL, + $faceDetectionColumn TEXT NOT NULL, + $faceEmbeddingBlob BLOB NOT NULL, + $faceScore REAL NOT NULL, + $faceBlur REAL NOT NULL DEFAULT $kLapacianDefault, + $faceClusterId INTEGER, + $faceClosestDistColumn REAL, + $faceClosestFaceID TEXT, + $faceConfirmedColumn INTEGER NOT NULL DEFAULT 0, + $mlVersionColumn INTEGER NOT NULL DEFAULT -1, + PRIMARY KEY($fileIDColumn, $faceIDColumn) + ); + '''; + +const deleteFacesTable = 'DROP TABLE IF EXISTS $facesTable'; +// End of Faces Table Fields & Schema Queries + +// People Table Fields & Schema Queries +const peopleTable = 'people'; +const idColumn = 'id'; +const nameColumn = 'name'; +const personHiddenColumn = 'hidden'; +const clusterToFaceIdJson = 'clusterToFaceIds'; +const coverFaceIDColumn = 'cover_face_id'; + +const createPeopleTable = '''CREATE TABLE IF NOT EXISTS $peopleTable ( + $idColumn TEXT NOT NULL UNIQUE, + $nameColumn TEXT NOT NULL DEFAULT '', + $personHiddenColumn INTEGER NOT NULL DEFAULT 0, + $clusterToFaceIdJson TEXT NOT NULL DEFAULT '{}', + $coverFaceIDColumn TEXT, + PRIMARY KEY($idColumn) + ); + '''; + +const deletePeopleTable = 'DROP TABLE IF EXISTS $peopleTable'; +//End People Table Fields & Schema Queries + +// Clusters Table Fields & Schema Queries 
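+// The clusters table is a plain (person_id, cluster_id) join: one person can
+// own many clusters, and the composite primary key below keeps a given
+// cluster from being attached to the same person twice.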
+const clustersTable = 'clusters';
+const personIdColumn = 'person_id';
+const cluserIDColumn = 'cluster_id';
+
+const createClusterTable = '''
+CREATE TABLE IF NOT EXISTS $clustersTable (
+  $personIdColumn TEXT NOT NULL,
+  $cluserIDColumn INTEGER NOT NULL,
+  PRIMARY KEY($personIdColumn, $cluserIDColumn)
+);
+''';
+const dropClustersTable = 'DROP TABLE IF EXISTS $clustersTable';
+// End Clusters Table Fields & Schema Queries
+
+/// Cluster Summary Table Fields & Schema Queries
+const clusterSummaryTable = 'cluster_summary';
+const avgColumn = 'avg';
+const countColumn = 'count';
+const createClusterSummaryTable = '''
+CREATE TABLE IF NOT EXISTS $clusterSummaryTable (
+  $cluserIDColumn INTEGER NOT NULL,
+  $avgColumn BLOB NOT NULL,
+  $countColumn INTEGER NOT NULL,
+  PRIMARY KEY($cluserIDColumn)
+);
+''';
+
+const dropClusterSummaryTable = 'DROP TABLE IF EXISTS $clusterSummaryTable';
+
+/// End Cluster Summary Table Fields & Schema Queries
+
+/// notPersonFeedback Table Fields & Schema Queries
+const notPersonFeedback = 'not_person_feedback';
+
+const createNotPersonFeedbackTable = '''
+CREATE TABLE IF NOT EXISTS $notPersonFeedback (
+  $personIdColumn TEXT NOT NULL,
+  $cluserIDColumn INTEGER NOT NULL
+);
+''';
+const dropNotPersonFeedbackTable = 'DROP TABLE IF EXISTS $notPersonFeedback';
+// End notPersonFeedback Table Fields & Schema Queries
diff --git a/mobile/lib/face/db_model_mappers.dart b/mobile/lib/face/db_model_mappers.dart
new file mode 100644
index 000000000..57dc311aa
--- /dev/null
+++ b/mobile/lib/face/db_model_mappers.dart
@@ -0,0 +1,86 @@
+import "dart:convert";
+
+import 'package:photos/face/db_fields.dart';
+import "package:photos/face/model/detection.dart";
+import "package:photos/face/model/face.dart";
+import "package:photos/face/model/person.dart";
+import 'package:photos/face/model/person_face.dart';
+import "package:photos/generated/protos/ente/common/vector.pb.dart";
+
+int boolToSQLInt(bool? value, {bool defaultValue = false}) {
+  final bool v = value ?? defaultValue;
+  if (v == false) {
+    return 0;
+  } else {
+    return 1;
+  }
+}
+
+bool sqlIntToBool(int? value, {bool defaultValue = false}) {
+  final int v = value ?? (defaultValue ?
diff --git a/mobile/lib/face/db_model_mappers.dart b/mobile/lib/face/db_model_mappers.dart
new file mode 100644
index 000000000..57dc311aa
--- /dev/null
+++ b/mobile/lib/face/db_model_mappers.dart
@@ -0,0 +1,86 @@
+import "dart:convert";
+
+import 'package:photos/face/db_fields.dart';
+import "package:photos/face/model/detection.dart";
+import "package:photos/face/model/face.dart";
+import "package:photos/face/model/person.dart";
+import 'package:photos/face/model/person_face.dart';
+import "package:photos/generated/protos/ente/common/vector.pb.dart";
+
+int boolToSQLInt(bool? value, {bool defaultValue = false}) {
+  final bool v = value ?? defaultValue;
+  return v ? 1 : 0;
+}
+
+bool sqlIntToBool(int? value, {bool defaultValue = false}) {
+  final int v = value ?? (defaultValue ? 1 : 0);
+  return v != 0;
+}
+
+Map<String, dynamic> mapToFaceDB(PersonFace personFace) {
+  return {
+    faceIDColumn: personFace.face.faceID,
+    faceDetectionColumn: json.encode(personFace.face.detection.toJson()),
+    faceConfirmedColumn: boolToSQLInt(personFace.confirmed),
+    faceClusterId: personFace.personID,
+    faceClosestDistColumn: personFace.closeDist,
+    faceClosestFaceID: personFace.closeFaceID,
+  };
+}
+
+Map<String, dynamic> mapPersonToRow(Person p) {
+  return {
+    idColumn: p.remoteID,
+    nameColumn: p.attr.name,
+    personHiddenColumn: boolToSQLInt(p.attr.isHidden),
+    coverFaceIDColumn: p.attr.avatarFaceId,
+    clusterToFaceIdJson: jsonEncode(p.attr.faces.toList()),
+  };
+}
+
+Person mapRowToPerson(Map<String, dynamic> row) {
+  return Person(
+    row[idColumn] as String,
+    PersonAttr(
+      name: row[nameColumn] as String,
+      isHidden: sqlIntToBool(row[personHiddenColumn] as int),
+      avatarFaceId: row[coverFaceIDColumn] as String?,
+      faces: (jsonDecode(row[clusterToFaceIdJson] as String) as List)
+          .map((e) => e.toString())
+          .toList(),
+    ),
+  );
+}
+
+Map<String, dynamic> mapRemoteToFaceDB(Face face) {
+  return {
+    faceIDColumn: face.faceID,
+    fileIDColumn: face.fileID,
+    faceDetectionColumn: json.encode(face.detection.toJson()),
+    faceEmbeddingBlob: EVector(
+      values: face.embedding,
+    ).writeToBuffer(),
+    faceScore: face.score,
+    faceBlur: face.blur,
+    mlVersionColumn: 1,
+  };
+}
+
+Face mapRowToFace(Map<String, dynamic> row) {
+  return Face(
+    row[faceIDColumn] as String,
+    row[fileIDColumn] as int,
+    EVector.fromBuffer(row[faceEmbeddingBlob] as List<int>).values,
+    row[faceScore] as double,
+    Detection.fromJson(
+      json.decode(row[faceDetectionColumn] as String) as Map<String, dynamic>,
+    ),
+    row[faceBlur] as double,
+  );
+}
diff --git a/mobile/lib/face/feedback.dart b/mobile/lib/face/feedback.dart
new file mode 100644
index 000000000..e69de29bb
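A brief round-trip sketch for the mapper pair above (hedged: the sqflite `Database` handle and the conflict strategy are assumed, not part of this diff):

    // Persist a remotely fetched Face, then rebuild it from the same row.
    final row = mapRemoteToFaceDB(face);
    await db.insert(facesTable, row,
        conflictAlgorithm: ConflictAlgorithm.replace);
    final Face restored = mapRowToFace(row);
    assert(restored.faceID == face.faceID); // embedding survives via EVector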
diff --git a/mobile/lib/face/model/box.dart b/mobile/lib/face/model/box.dart
new file mode 100644
index 000000000..1ef89144c
--- /dev/null
+++ b/mobile/lib/face/model/box.dart
@@ -0,0 +1,42 @@
+/// Bounding box of a face.
+///
+/// [x] and [y] are the coordinates of the top-left corner of the box,
+/// i.e. the minimum x and y values of the box. [width] and [height] are
+/// the width and height of the box.
+/// All values are in absolute pixels relative to the original image size.
+class FaceBox {
+  final double x;
+  final double y;
+  final double width;
+  final double height;
+
+  FaceBox({
+    required this.x,
+    required this.y,
+    required this.width,
+    required this.height,
+  });
+
+  factory FaceBox.fromJson(Map<String, dynamic> json) {
+    // Accept both int and double encodings of the coordinates.
+    return FaceBox(
+      x: (json['x'] as num).toDouble(),
+      y: (json['y'] as num).toDouble(),
+      width: (json['width'] as num).toDouble(),
+      height: (json['height'] as num).toDouble(),
+    );
+  }
+
+  Map<String, dynamic> toJson() => {
+        'x': x,
+        'y': y,
+        'width': width,
+        'height': height,
+      };
+}
diff --git a/mobile/lib/face/model/detection.dart b/mobile/lib/face/model/detection.dart
new file mode 100644
index 000000000..7d5b02cc6
--- /dev/null
+++ b/mobile/lib/face/model/detection.dart
@@ -0,0 +1,37 @@
+import "package:photos/face/model/box.dart";
+import "package:photos/face/model/landmark.dart";
+
+class Detection {
+  FaceBox box;
+  List<Landmark> landmarks;
+
+  Detection({
+    required this.box,
+    required this.landmarks,
+  });
+
+  // empty detection: zero-sized box, no landmarks
+  Detection.empty()
+      : box = FaceBox(
+          x: 0,
+          y: 0,
+          width: 0,
+          height: 0,
+        ),
+        landmarks = [];
+
+  Map<String, dynamic> toJson() => {
+        'box': box.toJson(),
+        'landmarks': landmarks.map((x) => x.toJson()).toList(),
+      };
+
+  factory Detection.fromJson(Map<String, dynamic> json) {
+    return Detection(
+      box: FaceBox.fromJson(json['box'] as Map<String, dynamic>),
+      landmarks: List<Landmark>.from(
+        (json['landmarks'] as List)
+            .map((x) => Landmark.fromJson(x as Map<String, dynamic>)),
+      ),
+    );
+  }
+}
diff --git a/mobile/lib/face/model/face.dart b/mobile/lib/face/model/face.dart
new file mode 100644
index 000000000..49612c3d7
--- /dev/null
+++ b/mobile/lib/face/model/face.dart
@@ -0,0 +1,43 @@
+import "package:photos/face/model/detection.dart";
+import "package:photos/services/face_ml/blur_detection/blur_constants.dart";
+
+class Face {
+  final int fileID;
+  final String faceID;
+  final List<double> embedding;
+  Detection detection;
+  final double score;
+  final double blur;
+
+  bool get isBlurry => blur < kLaplacianThreshold;
+
+  Face(
+    this.faceID,
+    this.fileID,
+    this.embedding,
+    this.score,
+    this.detection,
+    this.blur,
+  );
+
+  factory Face.fromJson(Map<String, dynamic> json) {
+    return Face(
+      json['faceID'] as String,
+      json['fileID'] as int,
+      List<double>.from(json['embeddings'] as List),
+      json['score'] as double,
+      Detection.fromJson(json['detection'] as Map<String, dynamic>),
+      // a high value means a sharp face; the default marks "not yet scored"
+      (json['blur'] ?? kLapacianDefault) as double,
+    );
+  }
+
+  Map<String, dynamic> toJson() => {
+        'faceID': faceID,
+        'fileID': fileID,
+        'embeddings': embedding,
+        'detection': detection.toJson(),
+        'score': score,
+        'blur': blur,
+      };
+}
diff --git a/mobile/lib/face/model/landmark.dart b/mobile/lib/face/model/landmark.dart
new file mode 100644
index 000000000..03a68fd11
--- /dev/null
+++ b/mobile/lib/face/model/landmark.dart
@@ -0,0 +1,26 @@
+// Class for the 'landmark' sub-object
+class Landmark {
+  double x;
+  double y;
+
+  Landmark({
+    required this.x,
+    required this.y,
+  });
+
+  Map<String, dynamic> toJson() => {
+        'x': x,
+        'y': y,
+      };
+
+  factory Landmark.fromJson(Map<String, dynamic> json) {
+    // Accept both int and double encodings of the coordinates.
+    return Landmark(
+      x: (json['x'] as num).toDouble(),
+      y: (json['y'] as num).toDouble(),
+    );
+  }
+}
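These models are plain JSON round-trippers; a small sanity sketch (the classes from this diff plus dart:convert are assumed in scope):

    final det = Detection(
      box: FaceBox(x: 10, y: 20, width: 100, height: 120),
      landmarks: [Landmark(x: 40, y: 60), Landmark(x: 80, y: 60)],
    );
    // encode -> decode preserves box and landmark geometry
    final decoded = Detection.fromJson(
      json.decode(json.encode(det.toJson())) as Map<String, dynamic>,
    );
    assert(decoded.box.width == 100 && decoded.landmarks.length == 2);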
birthDatae;
+  PersonAttr({
+    required this.name,
+    required this.faces,
+    this.avatarFaceId,
+    this.isHidden = false,
+    this.birthDatae,
+  });
+
+  // copyWith
+  PersonAttr copyWith({
+    String? name,
+    List<String>? faces,
+    String? avatarFaceId,
+    bool? isHidden,
+    String? birthDatae,
+  }) {
+    return PersonAttr(
+      name: name ?? this.name,
+      faces: faces ?? this.faces,
+      avatarFaceId: avatarFaceId ?? this.avatarFaceId,
+      isHidden: isHidden ?? this.isHidden,
+      birthDatae: birthDatae ?? this.birthDatae,
+    );
+  }
+
+  // toJson
+  Map<String, dynamic> toJson() => {
+        'name': name,
+        'faces': faces.toList(),
+        'avatarFaceId': avatarFaceId,
+        'isHidden': isHidden,
+        'birthDatae': birthDatae,
+      };
+
+  // fromJson
+  factory PersonAttr.fromJson(Map<String, dynamic> json) {
+    return PersonAttr(
+      name: json['name'] as String,
+      faces: List<String>.from(json['faces'] as List),
+      avatarFaceId: json['avatarFaceId'] as String?,
+      isHidden: json['isHidden'] as bool? ?? false,
+      birthDatae: json['birthDatae'] as String?,
+    );
+  }
+}
diff --git a/mobile/lib/face/model/person_face.dart b/mobile/lib/face/model/person_face.dart
new file mode 100644
index 000000000..6d9744c27
--- /dev/null
+++ b/mobile/lib/face/model/person_face.dart
@@ -0,0 +1,37 @@
+import 'package:photos/face/model/face.dart';
+
+class PersonFace {
+  final Face face;
+  int? personID;
+  bool? confirmed;
+  double? closeDist;
+  String? closeFaceID;
+
+  PersonFace(
+    this.face,
+    this.personID,
+    this.closeDist,
+    this.closeFaceID, {
+    this.confirmed,
+  });
+
+  // toJson
+  Map<String, dynamic> toJson() => {
+        'face': face.toJson(),
+        'personID': personID,
+        'confirmed': confirmed ?? false,
+        'close_dist': closeDist,
+        'close_face_id': closeFaceID,
+      };
+
+  // fromJson
+  factory PersonFace.fromJson(Map<String, dynamic> json) {
+    return PersonFace(
+      Face.fromJson(json['face'] as Map<String, dynamic>),
+      json['personID'] as int?,
+      json['close_dist'] as double?,
+      json['close_face_id'] as String?,
+      confirmed: json['confirmed'] as bool?,
+    );
+  }
+}
diff --git a/mobile/lib/face/utils/import_from_zip.dart b/mobile/lib/face/utils/import_from_zip.dart
new file mode 100644
index 000000000..08dda4158
--- /dev/null
+++ b/mobile/lib/face/utils/import_from_zip.dart
@@ -0,0 +1,44 @@
+// Dev-only helper: fetches pre-computed face JSON dumps from a machine on
+// the local network and parses them into [Face] objects.
+import "package:dio/dio.dart";
+import "package:logging/logging.dart";
+import "package:photos/core/configuration.dart";
+import "package:photos/core/network/network.dart";
+import "package:photos/face/model/face.dart";
+
+final _logger = Logger("import_from_zip");
+
+Future<List<Face>> downloadZip() async {
+  final List<Face> result = [];
+  for (int i = 0; i < 2; i++) {
+    _logger.info("downloading $i");
+    final remoteZipUrl = "http://192.168.1.13:8700/ml/cx_ml_json_$i.json";
+    final response = await NetworkClient.instance.getDio().get(
+          remoteZipUrl,
+          options: Options(
+            headers: {"X-Auth-Token": Configuration.instance.getToken()},
+          ),
+        );
+
+    if (response.statusCode != 200) {
+      _logger.warning('download failed ${response.toString()}');
+      throw Exception("download failed");
+    }
+    final res = response.data as List;
+    for (final item in res) {
+      try {
+        result.add(Face.fromJson(item as Map<String, dynamic>));
+      } catch (e) {
+        _logger.warning("failed to parse $item");
+        rethrow;
+      }
+    }
+  }
+  // Sanity check: face IDs are expected to be unique across the dumps.
+  final Set<String> faceID = {};
+  for (final face in result) {
+    if (faceID.contains(face.faceID)) {
+      _logger.warning("duplicate faceID ${face.faceID}");
+    }
+    faceID.add(face.faceID);
+  }
+  return result;
+}
diff --git a/mobile/lib/generated/intl/messages_en.dart b/mobile/lib/generated/intl/messages_en.dart
index 8889b63c5..799b8f122 100644
--- a/mobile/lib/generated/intl/messages_en.dart
+++ 
b/mobile/lib/generated/intl/messages_en.dart @@ -973,6 +973,7 @@ class MessageLookup extends MessageLookupByLibrary { "paymentFailedWithReason": m36, "pendingItems": MessageLookupByLibrary.simpleMessage("Pending items"), "pendingSync": MessageLookupByLibrary.simpleMessage("Pending sync"), + "people": MessageLookupByLibrary.simpleMessage("People"), "peopleUsingYourCode": MessageLookupByLibrary.simpleMessage("People using your code"), "permDeleteWarning": MessageLookupByLibrary.simpleMessage( diff --git a/mobile/lib/generated/l10n.dart b/mobile/lib/generated/l10n.dart index b19a95551..1ae5cc7c3 100644 --- a/mobile/lib/generated/l10n.dart +++ b/mobile/lib/generated/l10n.dart @@ -8158,6 +8158,16 @@ class S { ); } + /// `People` + String get people { + return Intl.message( + 'People', + name: 'people', + desc: '', + args: [], + ); + } + /// `Contents` String get contents { return Intl.message( diff --git a/mobile/lib/generated/protos/ente/common/box.pb.dart b/mobile/lib/generated/protos/ente/common/box.pb.dart new file mode 100644 index 000000000..41518e9ae --- /dev/null +++ b/mobile/lib/generated/protos/ente/common/box.pb.dart @@ -0,0 +1,111 @@ +// +// Generated code. Do not modify. +// source: ente/common/box.proto +// +// @dart = 2.12 + +// ignore_for_file: annotate_overrides, camel_case_types, comment_references +// ignore_for_file: constant_identifier_names, library_prefixes +// ignore_for_file: non_constant_identifier_names, prefer_final_fields +// ignore_for_file: unnecessary_import, unnecessary_this, unused_import + +import 'dart:core' as $core; + +import 'package:protobuf/protobuf.dart' as $pb; + +/// CenterBox is a box where x,y is the center of the box +class CenterBox extends $pb.GeneratedMessage { + factory CenterBox({ + $core.double? x, + $core.double? y, + $core.double? height, + $core.double? width, + }) { + final $result = create(); + if (x != null) { + $result.x = x; + } + if (y != null) { + $result.y = y; + } + if (height != null) { + $result.height = height; + } + if (width != null) { + $result.width = width; + } + return $result; + } + CenterBox._() : super(); + factory CenterBox.fromBuffer($core.List<$core.int> i, [$pb.ExtensionRegistry r = $pb.ExtensionRegistry.EMPTY]) => create()..mergeFromBuffer(i, r); + factory CenterBox.fromJson($core.String i, [$pb.ExtensionRegistry r = $pb.ExtensionRegistry.EMPTY]) => create()..mergeFromJson(i, r); + + static final $pb.BuilderInfo _i = $pb.BuilderInfo(_omitMessageNames ? '' : 'CenterBox', package: const $pb.PackageName(_omitMessageNames ? '' : 'ente.common'), createEmptyInstance: create) + ..a<$core.double>(1, _omitFieldNames ? '' : 'x', $pb.PbFieldType.OF) + ..a<$core.double>(2, _omitFieldNames ? '' : 'y', $pb.PbFieldType.OF) + ..a<$core.double>(3, _omitFieldNames ? '' : 'height', $pb.PbFieldType.OF) + ..a<$core.double>(4, _omitFieldNames ? '' : 'width', $pb.PbFieldType.OF) + ..hasRequiredFields = false + ; + + @$core.Deprecated( + 'Using this can add significant overhead to your binary. ' + 'Use [GeneratedMessageGenericExtensions.deepCopy] instead. ' + 'Will be removed in next major version') + CenterBox clone() => CenterBox()..mergeFromMessage(this); + @$core.Deprecated( + 'Using this can add significant overhead to your binary. ' + 'Use [GeneratedMessageGenericExtensions.rebuild] instead. 
' + 'Will be removed in next major version') + CenterBox copyWith(void Function(CenterBox) updates) => super.copyWith((message) => updates(message as CenterBox)) as CenterBox; + + $pb.BuilderInfo get info_ => _i; + + @$core.pragma('dart2js:noInline') + static CenterBox create() => CenterBox._(); + CenterBox createEmptyInstance() => create(); + static $pb.PbList createRepeated() => $pb.PbList(); + @$core.pragma('dart2js:noInline') + static CenterBox getDefault() => _defaultInstance ??= $pb.GeneratedMessage.$_defaultFor(create); + static CenterBox? _defaultInstance; + + @$pb.TagNumber(1) + $core.double get x => $_getN(0); + @$pb.TagNumber(1) + set x($core.double v) { $_setFloat(0, v); } + @$pb.TagNumber(1) + $core.bool hasX() => $_has(0); + @$pb.TagNumber(1) + void clearX() => clearField(1); + + @$pb.TagNumber(2) + $core.double get y => $_getN(1); + @$pb.TagNumber(2) + set y($core.double v) { $_setFloat(1, v); } + @$pb.TagNumber(2) + $core.bool hasY() => $_has(1); + @$pb.TagNumber(2) + void clearY() => clearField(2); + + @$pb.TagNumber(3) + $core.double get height => $_getN(2); + @$pb.TagNumber(3) + set height($core.double v) { $_setFloat(2, v); } + @$pb.TagNumber(3) + $core.bool hasHeight() => $_has(2); + @$pb.TagNumber(3) + void clearHeight() => clearField(3); + + @$pb.TagNumber(4) + $core.double get width => $_getN(3); + @$pb.TagNumber(4) + set width($core.double v) { $_setFloat(3, v); } + @$pb.TagNumber(4) + $core.bool hasWidth() => $_has(3); + @$pb.TagNumber(4) + void clearWidth() => clearField(4); +} + + +const _omitFieldNames = $core.bool.fromEnvironment('protobuf.omit_field_names'); +const _omitMessageNames = $core.bool.fromEnvironment('protobuf.omit_message_names'); diff --git a/mobile/lib/generated/protos/ente/common/box.pbenum.dart b/mobile/lib/generated/protos/ente/common/box.pbenum.dart new file mode 100644 index 000000000..7310e57a0 --- /dev/null +++ b/mobile/lib/generated/protos/ente/common/box.pbenum.dart @@ -0,0 +1,11 @@ +// +// Generated code. Do not modify. +// source: ente/common/box.proto +// +// @dart = 2.12 + +// ignore_for_file: annotate_overrides, camel_case_types, comment_references +// ignore_for_file: constant_identifier_names, library_prefixes +// ignore_for_file: non_constant_identifier_names, prefer_final_fields +// ignore_for_file: unnecessary_import, unnecessary_this, unused_import + diff --git a/mobile/lib/generated/protos/ente/common/box.pbjson.dart b/mobile/lib/generated/protos/ente/common/box.pbjson.dart new file mode 100644 index 000000000..6c9ab3cb2 --- /dev/null +++ b/mobile/lib/generated/protos/ente/common/box.pbjson.dart @@ -0,0 +1,38 @@ +// +// Generated code. Do not modify. 
+// source: ente/common/box.proto +// +// @dart = 2.12 + +// ignore_for_file: annotate_overrides, camel_case_types, comment_references +// ignore_for_file: constant_identifier_names, library_prefixes +// ignore_for_file: non_constant_identifier_names, prefer_final_fields +// ignore_for_file: unnecessary_import, unnecessary_this, unused_import + +import 'dart:convert' as $convert; +import 'dart:core' as $core; +import 'dart:typed_data' as $typed_data; + +@$core.Deprecated('Use centerBoxDescriptor instead') +const CenterBox$json = { + '1': 'CenterBox', + '2': [ + {'1': 'x', '3': 1, '4': 1, '5': 2, '9': 0, '10': 'x', '17': true}, + {'1': 'y', '3': 2, '4': 1, '5': 2, '9': 1, '10': 'y', '17': true}, + {'1': 'height', '3': 3, '4': 1, '5': 2, '9': 2, '10': 'height', '17': true}, + {'1': 'width', '3': 4, '4': 1, '5': 2, '9': 3, '10': 'width', '17': true}, + ], + '8': [ + {'1': '_x'}, + {'1': '_y'}, + {'1': '_height'}, + {'1': '_width'}, + ], +}; + +/// Descriptor for `CenterBox`. Decode as a `google.protobuf.DescriptorProto`. +final $typed_data.Uint8List centerBoxDescriptor = $convert.base64Decode( + 'CglDZW50ZXJCb3gSEQoBeBgBIAEoAkgAUgF4iAEBEhEKAXkYAiABKAJIAVIBeYgBARIbCgZoZW' + 'lnaHQYAyABKAJIAlIGaGVpZ2h0iAEBEhkKBXdpZHRoGAQgASgCSANSBXdpZHRoiAEBQgQKAl94' + 'QgQKAl95QgkKB19oZWlnaHRCCAoGX3dpZHRo'); + diff --git a/mobile/lib/generated/protos/ente/common/box.pbserver.dart b/mobile/lib/generated/protos/ente/common/box.pbserver.dart new file mode 100644 index 000000000..1e8625388 --- /dev/null +++ b/mobile/lib/generated/protos/ente/common/box.pbserver.dart @@ -0,0 +1,14 @@ +// +// Generated code. Do not modify. +// source: ente/common/box.proto +// +// @dart = 2.12 + +// ignore_for_file: annotate_overrides, camel_case_types, comment_references +// ignore_for_file: constant_identifier_names +// ignore_for_file: deprecated_member_use_from_same_package, library_prefixes +// ignore_for_file: non_constant_identifier_names, prefer_final_fields +// ignore_for_file: unnecessary_import, unnecessary_this, unused_import + +export 'box.pb.dart'; + diff --git a/mobile/lib/generated/protos/ente/common/point.pb.dart b/mobile/lib/generated/protos/ente/common/point.pb.dart new file mode 100644 index 000000000..47f9b87ce --- /dev/null +++ b/mobile/lib/generated/protos/ente/common/point.pb.dart @@ -0,0 +1,83 @@ +// +// Generated code. Do not modify. +// source: ente/common/point.proto +// +// @dart = 2.12 + +// ignore_for_file: annotate_overrides, camel_case_types, comment_references +// ignore_for_file: constant_identifier_names, library_prefixes +// ignore_for_file: non_constant_identifier_names, prefer_final_fields +// ignore_for_file: unnecessary_import, unnecessary_this, unused_import + +import 'dart:core' as $core; + +import 'package:protobuf/protobuf.dart' as $pb; + +/// EPoint is a point in 2D space +class EPoint extends $pb.GeneratedMessage { + factory EPoint({ + $core.double? x, + $core.double? y, + }) { + final $result = create(); + if (x != null) { + $result.x = x; + } + if (y != null) { + $result.y = y; + } + return $result; + } + EPoint._() : super(); + factory EPoint.fromBuffer($core.List<$core.int> i, [$pb.ExtensionRegistry r = $pb.ExtensionRegistry.EMPTY]) => create()..mergeFromBuffer(i, r); + factory EPoint.fromJson($core.String i, [$pb.ExtensionRegistry r = $pb.ExtensionRegistry.EMPTY]) => create()..mergeFromJson(i, r); + + static final $pb.BuilderInfo _i = $pb.BuilderInfo(_omitMessageNames ? '' : 'EPoint', package: const $pb.PackageName(_omitMessageNames ? 
'' : 'ente.common'), createEmptyInstance: create) + ..a<$core.double>(1, _omitFieldNames ? '' : 'x', $pb.PbFieldType.OF) + ..a<$core.double>(2, _omitFieldNames ? '' : 'y', $pb.PbFieldType.OF) + ..hasRequiredFields = false + ; + + @$core.Deprecated( + 'Using this can add significant overhead to your binary. ' + 'Use [GeneratedMessageGenericExtensions.deepCopy] instead. ' + 'Will be removed in next major version') + EPoint clone() => EPoint()..mergeFromMessage(this); + @$core.Deprecated( + 'Using this can add significant overhead to your binary. ' + 'Use [GeneratedMessageGenericExtensions.rebuild] instead. ' + 'Will be removed in next major version') + EPoint copyWith(void Function(EPoint) updates) => super.copyWith((message) => updates(message as EPoint)) as EPoint; + + $pb.BuilderInfo get info_ => _i; + + @$core.pragma('dart2js:noInline') + static EPoint create() => EPoint._(); + EPoint createEmptyInstance() => create(); + static $pb.PbList createRepeated() => $pb.PbList(); + @$core.pragma('dart2js:noInline') + static EPoint getDefault() => _defaultInstance ??= $pb.GeneratedMessage.$_defaultFor(create); + static EPoint? _defaultInstance; + + @$pb.TagNumber(1) + $core.double get x => $_getN(0); + @$pb.TagNumber(1) + set x($core.double v) { $_setFloat(0, v); } + @$pb.TagNumber(1) + $core.bool hasX() => $_has(0); + @$pb.TagNumber(1) + void clearX() => clearField(1); + + @$pb.TagNumber(2) + $core.double get y => $_getN(1); + @$pb.TagNumber(2) + set y($core.double v) { $_setFloat(1, v); } + @$pb.TagNumber(2) + $core.bool hasY() => $_has(1); + @$pb.TagNumber(2) + void clearY() => clearField(2); +} + + +const _omitFieldNames = $core.bool.fromEnvironment('protobuf.omit_field_names'); +const _omitMessageNames = $core.bool.fromEnvironment('protobuf.omit_message_names'); diff --git a/mobile/lib/generated/protos/ente/common/point.pbenum.dart b/mobile/lib/generated/protos/ente/common/point.pbenum.dart new file mode 100644 index 000000000..3c242a2fc --- /dev/null +++ b/mobile/lib/generated/protos/ente/common/point.pbenum.dart @@ -0,0 +1,11 @@ +// +// Generated code. Do not modify. +// source: ente/common/point.proto +// +// @dart = 2.12 + +// ignore_for_file: annotate_overrides, camel_case_types, comment_references +// ignore_for_file: constant_identifier_names, library_prefixes +// ignore_for_file: non_constant_identifier_names, prefer_final_fields +// ignore_for_file: unnecessary_import, unnecessary_this, unused_import + diff --git a/mobile/lib/generated/protos/ente/common/point.pbjson.dart b/mobile/lib/generated/protos/ente/common/point.pbjson.dart new file mode 100644 index 000000000..44d2d0712 --- /dev/null +++ b/mobile/lib/generated/protos/ente/common/point.pbjson.dart @@ -0,0 +1,33 @@ +// +// Generated code. Do not modify. +// source: ente/common/point.proto +// +// @dart = 2.12 + +// ignore_for_file: annotate_overrides, camel_case_types, comment_references +// ignore_for_file: constant_identifier_names, library_prefixes +// ignore_for_file: non_constant_identifier_names, prefer_final_fields +// ignore_for_file: unnecessary_import, unnecessary_this, unused_import + +import 'dart:convert' as $convert; +import 'dart:core' as $core; +import 'dart:typed_data' as $typed_data; + +@$core.Deprecated('Use ePointDescriptor instead') +const EPoint$json = { + '1': 'EPoint', + '2': [ + {'1': 'x', '3': 1, '4': 1, '5': 2, '9': 0, '10': 'x', '17': true}, + {'1': 'y', '3': 2, '4': 1, '5': 2, '9': 1, '10': 'y', '17': true}, + ], + '8': [ + {'1': '_x'}, + {'1': '_y'}, + ], +}; + +/// Descriptor for `EPoint`. 
Decode as a `google.protobuf.DescriptorProto`. +final $typed_data.Uint8List ePointDescriptor = $convert.base64Decode( + 'CgZFUG9pbnQSEQoBeBgBIAEoAkgAUgF4iAEBEhEKAXkYAiABKAJIAVIBeYgBAUIECgJfeEIECg' + 'JfeQ=='); + diff --git a/mobile/lib/generated/protos/ente/common/point.pbserver.dart b/mobile/lib/generated/protos/ente/common/point.pbserver.dart new file mode 100644 index 000000000..66728e123 --- /dev/null +++ b/mobile/lib/generated/protos/ente/common/point.pbserver.dart @@ -0,0 +1,14 @@ +// +// Generated code. Do not modify. +// source: ente/common/point.proto +// +// @dart = 2.12 + +// ignore_for_file: annotate_overrides, camel_case_types, comment_references +// ignore_for_file: constant_identifier_names +// ignore_for_file: deprecated_member_use_from_same_package, library_prefixes +// ignore_for_file: non_constant_identifier_names, prefer_final_fields +// ignore_for_file: unnecessary_import, unnecessary_this, unused_import + +export 'point.pb.dart'; + diff --git a/mobile/lib/generated/protos/ente/common/vector.pb.dart b/mobile/lib/generated/protos/ente/common/vector.pb.dart new file mode 100644 index 000000000..44aa7d748 --- /dev/null +++ b/mobile/lib/generated/protos/ente/common/vector.pb.dart @@ -0,0 +1,64 @@ +// +// Generated code. Do not modify. +// source: ente/common/vector.proto +// +// @dart = 2.12 + +// ignore_for_file: annotate_overrides, camel_case_types, comment_references +// ignore_for_file: constant_identifier_names, library_prefixes +// ignore_for_file: non_constant_identifier_names, prefer_final_fields +// ignore_for_file: unnecessary_import, unnecessary_this, unused_import + +import 'dart:core' as $core; + +import 'package:protobuf/protobuf.dart' as $pb; + +/// Vector is generic message for dealing with lists of doubles +/// It should ideally be used independently and not as a submessage +class EVector extends $pb.GeneratedMessage { + factory EVector({ + $core.Iterable<$core.double>? values, + }) { + final $result = create(); + if (values != null) { + $result.values.addAll(values); + } + return $result; + } + EVector._() : super(); + factory EVector.fromBuffer($core.List<$core.int> i, [$pb.ExtensionRegistry r = $pb.ExtensionRegistry.EMPTY]) => create()..mergeFromBuffer(i, r); + factory EVector.fromJson($core.String i, [$pb.ExtensionRegistry r = $pb.ExtensionRegistry.EMPTY]) => create()..mergeFromJson(i, r); + + static final $pb.BuilderInfo _i = $pb.BuilderInfo(_omitMessageNames ? '' : 'EVector', package: const $pb.PackageName(_omitMessageNames ? '' : 'ente.common'), createEmptyInstance: create) + ..p<$core.double>(1, _omitFieldNames ? '' : 'values', $pb.PbFieldType.KD) + ..hasRequiredFields = false + ; + + @$core.Deprecated( + 'Using this can add significant overhead to your binary. ' + 'Use [GeneratedMessageGenericExtensions.deepCopy] instead. ' + 'Will be removed in next major version') + EVector clone() => EVector()..mergeFromMessage(this); + @$core.Deprecated( + 'Using this can add significant overhead to your binary. ' + 'Use [GeneratedMessageGenericExtensions.rebuild] instead. 
' + 'Will be removed in next major version') + EVector copyWith(void Function(EVector) updates) => super.copyWith((message) => updates(message as EVector)) as EVector; + + $pb.BuilderInfo get info_ => _i; + + @$core.pragma('dart2js:noInline') + static EVector create() => EVector._(); + EVector createEmptyInstance() => create(); + static $pb.PbList createRepeated() => $pb.PbList(); + @$core.pragma('dart2js:noInline') + static EVector getDefault() => _defaultInstance ??= $pb.GeneratedMessage.$_defaultFor(create); + static EVector? _defaultInstance; + + @$pb.TagNumber(1) + $core.List<$core.double> get values => $_getList(0); +} + + +const _omitFieldNames = $core.bool.fromEnvironment('protobuf.omit_field_names'); +const _omitMessageNames = $core.bool.fromEnvironment('protobuf.omit_message_names'); diff --git a/mobile/lib/generated/protos/ente/common/vector.pbenum.dart b/mobile/lib/generated/protos/ente/common/vector.pbenum.dart new file mode 100644 index 000000000..c88d2648a --- /dev/null +++ b/mobile/lib/generated/protos/ente/common/vector.pbenum.dart @@ -0,0 +1,11 @@ +// +// Generated code. Do not modify. +// source: ente/common/vector.proto +// +// @dart = 2.12 + +// ignore_for_file: annotate_overrides, camel_case_types, comment_references +// ignore_for_file: constant_identifier_names, library_prefixes +// ignore_for_file: non_constant_identifier_names, prefer_final_fields +// ignore_for_file: unnecessary_import, unnecessary_this, unused_import + diff --git a/mobile/lib/generated/protos/ente/common/vector.pbjson.dart b/mobile/lib/generated/protos/ente/common/vector.pbjson.dart new file mode 100644 index 000000000..1aff5cb29 --- /dev/null +++ b/mobile/lib/generated/protos/ente/common/vector.pbjson.dart @@ -0,0 +1,27 @@ +// +// Generated code. Do not modify. +// source: ente/common/vector.proto +// +// @dart = 2.12 + +// ignore_for_file: annotate_overrides, camel_case_types, comment_references +// ignore_for_file: constant_identifier_names, library_prefixes +// ignore_for_file: non_constant_identifier_names, prefer_final_fields +// ignore_for_file: unnecessary_import, unnecessary_this, unused_import + +import 'dart:convert' as $convert; +import 'dart:core' as $core; +import 'dart:typed_data' as $typed_data; + +@$core.Deprecated('Use eVectorDescriptor instead') +const EVector$json = { + '1': 'EVector', + '2': [ + {'1': 'values', '3': 1, '4': 3, '5': 1, '10': 'values'}, + ], +}; + +/// Descriptor for `EVector`. Decode as a `google.protobuf.DescriptorProto`. +final $typed_data.Uint8List eVectorDescriptor = $convert.base64Decode( + 'CgdFVmVjdG9yEhYKBnZhbHVlcxgBIAMoAVIGdmFsdWVz'); + diff --git a/mobile/lib/generated/protos/ente/common/vector.pbserver.dart b/mobile/lib/generated/protos/ente/common/vector.pbserver.dart new file mode 100644 index 000000000..dbf5ac36f --- /dev/null +++ b/mobile/lib/generated/protos/ente/common/vector.pbserver.dart @@ -0,0 +1,14 @@ +// +// Generated code. Do not modify. 
+// source: ente/common/vector.proto +// +// @dart = 2.12 + +// ignore_for_file: annotate_overrides, camel_case_types, comment_references +// ignore_for_file: constant_identifier_names +// ignore_for_file: deprecated_member_use_from_same_package, library_prefixes +// ignore_for_file: non_constant_identifier_names, prefer_final_fields +// ignore_for_file: unnecessary_import, unnecessary_this, unused_import + +export 'vector.pb.dart'; + diff --git a/mobile/lib/generated/protos/ente/ml/face.pb.dart b/mobile/lib/generated/protos/ente/ml/face.pb.dart new file mode 100644 index 000000000..55d512b66 --- /dev/null +++ b/mobile/lib/generated/protos/ente/ml/face.pb.dart @@ -0,0 +1,169 @@ +// +// Generated code. Do not modify. +// source: ente/ml/face.proto +// +// @dart = 2.12 + +// ignore_for_file: annotate_overrides, camel_case_types, comment_references +// ignore_for_file: constant_identifier_names, library_prefixes +// ignore_for_file: non_constant_identifier_names, prefer_final_fields +// ignore_for_file: unnecessary_import, unnecessary_this, unused_import + +import 'dart:core' as $core; + +import 'package:protobuf/protobuf.dart' as $pb; + +import '../common/box.pb.dart' as $0; +import '../common/point.pb.dart' as $1; + +class Detection extends $pb.GeneratedMessage { + factory Detection({ + $0.CenterBox? box, + $1.EPoint? landmarks, + }) { + final $result = create(); + if (box != null) { + $result.box = box; + } + if (landmarks != null) { + $result.landmarks = landmarks; + } + return $result; + } + Detection._() : super(); + factory Detection.fromBuffer($core.List<$core.int> i, [$pb.ExtensionRegistry r = $pb.ExtensionRegistry.EMPTY]) => create()..mergeFromBuffer(i, r); + factory Detection.fromJson($core.String i, [$pb.ExtensionRegistry r = $pb.ExtensionRegistry.EMPTY]) => create()..mergeFromJson(i, r); + + static final $pb.BuilderInfo _i = $pb.BuilderInfo(_omitMessageNames ? '' : 'Detection', package: const $pb.PackageName(_omitMessageNames ? '' : 'ente.ml'), createEmptyInstance: create) + ..aOM<$0.CenterBox>(1, _omitFieldNames ? '' : 'box', subBuilder: $0.CenterBox.create) + ..aOM<$1.EPoint>(2, _omitFieldNames ? '' : 'landmarks', subBuilder: $1.EPoint.create) + ..hasRequiredFields = false + ; + + @$core.Deprecated( + 'Using this can add significant overhead to your binary. ' + 'Use [GeneratedMessageGenericExtensions.deepCopy] instead. ' + 'Will be removed in next major version') + Detection clone() => Detection()..mergeFromMessage(this); + @$core.Deprecated( + 'Using this can add significant overhead to your binary. ' + 'Use [GeneratedMessageGenericExtensions.rebuild] instead. ' + 'Will be removed in next major version') + Detection copyWith(void Function(Detection) updates) => super.copyWith((message) => updates(message as Detection)) as Detection; + + $pb.BuilderInfo get info_ => _i; + + @$core.pragma('dart2js:noInline') + static Detection create() => Detection._(); + Detection createEmptyInstance() => create(); + static $pb.PbList createRepeated() => $pb.PbList(); + @$core.pragma('dart2js:noInline') + static Detection getDefault() => _defaultInstance ??= $pb.GeneratedMessage.$_defaultFor(create); + static Detection? 
_defaultInstance; + + @$pb.TagNumber(1) + $0.CenterBox get box => $_getN(0); + @$pb.TagNumber(1) + set box($0.CenterBox v) { setField(1, v); } + @$pb.TagNumber(1) + $core.bool hasBox() => $_has(0); + @$pb.TagNumber(1) + void clearBox() => clearField(1); + @$pb.TagNumber(1) + $0.CenterBox ensureBox() => $_ensure(0); + + @$pb.TagNumber(2) + $1.EPoint get landmarks => $_getN(1); + @$pb.TagNumber(2) + set landmarks($1.EPoint v) { setField(2, v); } + @$pb.TagNumber(2) + $core.bool hasLandmarks() => $_has(1); + @$pb.TagNumber(2) + void clearLandmarks() => clearField(2); + @$pb.TagNumber(2) + $1.EPoint ensureLandmarks() => $_ensure(1); +} + +class Face extends $pb.GeneratedMessage { + factory Face({ + $core.String? id, + Detection? detection, + $core.double? confidence, + }) { + final $result = create(); + if (id != null) { + $result.id = id; + } + if (detection != null) { + $result.detection = detection; + } + if (confidence != null) { + $result.confidence = confidence; + } + return $result; + } + Face._() : super(); + factory Face.fromBuffer($core.List<$core.int> i, [$pb.ExtensionRegistry r = $pb.ExtensionRegistry.EMPTY]) => create()..mergeFromBuffer(i, r); + factory Face.fromJson($core.String i, [$pb.ExtensionRegistry r = $pb.ExtensionRegistry.EMPTY]) => create()..mergeFromJson(i, r); + + static final $pb.BuilderInfo _i = $pb.BuilderInfo(_omitMessageNames ? '' : 'Face', package: const $pb.PackageName(_omitMessageNames ? '' : 'ente.ml'), createEmptyInstance: create) + ..aOS(1, _omitFieldNames ? '' : 'id') + ..aOM(2, _omitFieldNames ? '' : 'detection', subBuilder: Detection.create) + ..a<$core.double>(3, _omitFieldNames ? '' : 'confidence', $pb.PbFieldType.OF) + ..hasRequiredFields = false + ; + + @$core.Deprecated( + 'Using this can add significant overhead to your binary. ' + 'Use [GeneratedMessageGenericExtensions.deepCopy] instead. ' + 'Will be removed in next major version') + Face clone() => Face()..mergeFromMessage(this); + @$core.Deprecated( + 'Using this can add significant overhead to your binary. ' + 'Use [GeneratedMessageGenericExtensions.rebuild] instead. ' + 'Will be removed in next major version') + Face copyWith(void Function(Face) updates) => super.copyWith((message) => updates(message as Face)) as Face; + + $pb.BuilderInfo get info_ => _i; + + @$core.pragma('dart2js:noInline') + static Face create() => Face._(); + Face createEmptyInstance() => create(); + static $pb.PbList createRepeated() => $pb.PbList(); + @$core.pragma('dart2js:noInline') + static Face getDefault() => _defaultInstance ??= $pb.GeneratedMessage.$_defaultFor(create); + static Face? 
_defaultInstance; + + @$pb.TagNumber(1) + $core.String get id => $_getSZ(0); + @$pb.TagNumber(1) + set id($core.String v) { $_setString(0, v); } + @$pb.TagNumber(1) + $core.bool hasId() => $_has(0); + @$pb.TagNumber(1) + void clearId() => clearField(1); + + @$pb.TagNumber(2) + Detection get detection => $_getN(1); + @$pb.TagNumber(2) + set detection(Detection v) { setField(2, v); } + @$pb.TagNumber(2) + $core.bool hasDetection() => $_has(1); + @$pb.TagNumber(2) + void clearDetection() => clearField(2); + @$pb.TagNumber(2) + Detection ensureDetection() => $_ensure(1); + + @$pb.TagNumber(3) + $core.double get confidence => $_getN(2); + @$pb.TagNumber(3) + set confidence($core.double v) { $_setFloat(2, v); } + @$pb.TagNumber(3) + $core.bool hasConfidence() => $_has(2); + @$pb.TagNumber(3) + void clearConfidence() => clearField(3); +} + + +const _omitFieldNames = $core.bool.fromEnvironment('protobuf.omit_field_names'); +const _omitMessageNames = $core.bool.fromEnvironment('protobuf.omit_message_names'); diff --git a/mobile/lib/generated/protos/ente/ml/face.pbenum.dart b/mobile/lib/generated/protos/ente/ml/face.pbenum.dart new file mode 100644 index 000000000..2eefe1f44 --- /dev/null +++ b/mobile/lib/generated/protos/ente/ml/face.pbenum.dart @@ -0,0 +1,11 @@ +// +// Generated code. Do not modify. +// source: ente/ml/face.proto +// +// @dart = 2.12 + +// ignore_for_file: annotate_overrides, camel_case_types, comment_references +// ignore_for_file: constant_identifier_names, library_prefixes +// ignore_for_file: non_constant_identifier_names, prefer_final_fields +// ignore_for_file: unnecessary_import, unnecessary_this, unused_import + diff --git a/mobile/lib/generated/protos/ente/ml/face.pbjson.dart b/mobile/lib/generated/protos/ente/ml/face.pbjson.dart new file mode 100644 index 000000000..5aa614a8b --- /dev/null +++ b/mobile/lib/generated/protos/ente/ml/face.pbjson.dart @@ -0,0 +1,55 @@ +// +// Generated code. Do not modify. +// source: ente/ml/face.proto +// +// @dart = 2.12 + +// ignore_for_file: annotate_overrides, camel_case_types, comment_references +// ignore_for_file: constant_identifier_names, library_prefixes +// ignore_for_file: non_constant_identifier_names, prefer_final_fields +// ignore_for_file: unnecessary_import, unnecessary_this, unused_import + +import 'dart:convert' as $convert; +import 'dart:core' as $core; +import 'dart:typed_data' as $typed_data; + +@$core.Deprecated('Use detectionDescriptor instead') +const Detection$json = { + '1': 'Detection', + '2': [ + {'1': 'box', '3': 1, '4': 1, '5': 11, '6': '.ente.common.CenterBox', '9': 0, '10': 'box', '17': true}, + {'1': 'landmarks', '3': 2, '4': 1, '5': 11, '6': '.ente.common.EPoint', '9': 1, '10': 'landmarks', '17': true}, + ], + '8': [ + {'1': '_box'}, + {'1': '_landmarks'}, + ], +}; + +/// Descriptor for `Detection`. Decode as a `google.protobuf.DescriptorProto`. 
+final $typed_data.Uint8List detectionDescriptor = $convert.base64Decode( + 'CglEZXRlY3Rpb24SLQoDYm94GAEgASgLMhYuZW50ZS5jb21tb24uQ2VudGVyQm94SABSA2JveI' + 'gBARI2CglsYW5kbWFya3MYAiABKAsyEy5lbnRlLmNvbW1vbi5FUG9pbnRIAVIJbGFuZG1hcmtz' + 'iAEBQgYKBF9ib3hCDAoKX2xhbmRtYXJrcw=='); + +@$core.Deprecated('Use faceDescriptor instead') +const Face$json = { + '1': 'Face', + '2': [ + {'1': 'id', '3': 1, '4': 1, '5': 9, '9': 0, '10': 'id', '17': true}, + {'1': 'detection', '3': 2, '4': 1, '5': 11, '6': '.ente.ml.Detection', '9': 1, '10': 'detection', '17': true}, + {'1': 'confidence', '3': 3, '4': 1, '5': 2, '9': 2, '10': 'confidence', '17': true}, + ], + '8': [ + {'1': '_id'}, + {'1': '_detection'}, + {'1': '_confidence'}, + ], +}; + +/// Descriptor for `Face`. Decode as a `google.protobuf.DescriptorProto`. +final $typed_data.Uint8List faceDescriptor = $convert.base64Decode( + 'CgRGYWNlEhMKAmlkGAEgASgJSABSAmlkiAEBEjUKCWRldGVjdGlvbhgCIAEoCzISLmVudGUubW' + 'wuRGV0ZWN0aW9uSAFSCWRldGVjdGlvbogBARIjCgpjb25maWRlbmNlGAMgASgCSAJSCmNvbmZp' + 'ZGVuY2WIAQFCBQoDX2lkQgwKCl9kZXRlY3Rpb25CDQoLX2NvbmZpZGVuY2U='); + diff --git a/mobile/lib/generated/protos/ente/ml/face.pbserver.dart b/mobile/lib/generated/protos/ente/ml/face.pbserver.dart new file mode 100644 index 000000000..a2cd6ff85 --- /dev/null +++ b/mobile/lib/generated/protos/ente/ml/face.pbserver.dart @@ -0,0 +1,14 @@ +// +// Generated code. Do not modify. +// source: ente/ml/face.proto +// +// @dart = 2.12 + +// ignore_for_file: annotate_overrides, camel_case_types, comment_references +// ignore_for_file: constant_identifier_names +// ignore_for_file: deprecated_member_use_from_same_package, library_prefixes +// ignore_for_file: non_constant_identifier_names, prefer_final_fields +// ignore_for_file: unnecessary_import, unnecessary_this, unused_import + +export 'face.pb.dart'; + diff --git a/mobile/lib/generated/protos/ente/ml/fileml.pb.dart b/mobile/lib/generated/protos/ente/ml/fileml.pb.dart new file mode 100644 index 000000000..853f89bac --- /dev/null +++ b/mobile/lib/generated/protos/ente/ml/fileml.pb.dart @@ -0,0 +1,179 @@ +// +// Generated code. Do not modify. +// source: ente/ml/fileml.proto +// +// @dart = 2.12 + +// ignore_for_file: annotate_overrides, camel_case_types, comment_references +// ignore_for_file: constant_identifier_names, library_prefixes +// ignore_for_file: non_constant_identifier_names, prefer_final_fields +// ignore_for_file: unnecessary_import, unnecessary_this, unused_import + +import 'dart:core' as $core; + +import 'package:fixnum/fixnum.dart' as $fixnum; +import 'package:protobuf/protobuf.dart' as $pb; + +import 'face.pb.dart' as $2; + +class FileML extends $pb.GeneratedMessage { + factory FileML({ + $fixnum.Int64? id, + $core.Iterable<$core.double>? clip, + }) { + final $result = create(); + if (id != null) { + $result.id = id; + } + if (clip != null) { + $result.clip.addAll(clip); + } + return $result; + } + FileML._() : super(); + factory FileML.fromBuffer($core.List<$core.int> i, [$pb.ExtensionRegistry r = $pb.ExtensionRegistry.EMPTY]) => create()..mergeFromBuffer(i, r); + factory FileML.fromJson($core.String i, [$pb.ExtensionRegistry r = $pb.ExtensionRegistry.EMPTY]) => create()..mergeFromJson(i, r); + + static final $pb.BuilderInfo _i = $pb.BuilderInfo(_omitMessageNames ? '' : 'FileML', package: const $pb.PackageName(_omitMessageNames ? '' : 'ente.ml'), createEmptyInstance: create) + ..aInt64(1, _omitFieldNames ? '' : 'id') + ..p<$core.double>(2, _omitFieldNames ? 
'' : 'clip', $pb.PbFieldType.KD) + ..hasRequiredFields = false + ; + + @$core.Deprecated( + 'Using this can add significant overhead to your binary. ' + 'Use [GeneratedMessageGenericExtensions.deepCopy] instead. ' + 'Will be removed in next major version') + FileML clone() => FileML()..mergeFromMessage(this); + @$core.Deprecated( + 'Using this can add significant overhead to your binary. ' + 'Use [GeneratedMessageGenericExtensions.rebuild] instead. ' + 'Will be removed in next major version') + FileML copyWith(void Function(FileML) updates) => super.copyWith((message) => updates(message as FileML)) as FileML; + + $pb.BuilderInfo get info_ => _i; + + @$core.pragma('dart2js:noInline') + static FileML create() => FileML._(); + FileML createEmptyInstance() => create(); + static $pb.PbList createRepeated() => $pb.PbList(); + @$core.pragma('dart2js:noInline') + static FileML getDefault() => _defaultInstance ??= $pb.GeneratedMessage.$_defaultFor(create); + static FileML? _defaultInstance; + + @$pb.TagNumber(1) + $fixnum.Int64 get id => $_getI64(0); + @$pb.TagNumber(1) + set id($fixnum.Int64 v) { $_setInt64(0, v); } + @$pb.TagNumber(1) + $core.bool hasId() => $_has(0); + @$pb.TagNumber(1) + void clearId() => clearField(1); + + @$pb.TagNumber(2) + $core.List<$core.double> get clip => $_getList(1); +} + +class FileFaces extends $pb.GeneratedMessage { + factory FileFaces({ + $core.Iterable<$2.Face>? faces, + $core.int? height, + $core.int? width, + $core.int? version, + $core.String? error, + }) { + final $result = create(); + if (faces != null) { + $result.faces.addAll(faces); + } + if (height != null) { + $result.height = height; + } + if (width != null) { + $result.width = width; + } + if (version != null) { + $result.version = version; + } + if (error != null) { + $result.error = error; + } + return $result; + } + FileFaces._() : super(); + factory FileFaces.fromBuffer($core.List<$core.int> i, [$pb.ExtensionRegistry r = $pb.ExtensionRegistry.EMPTY]) => create()..mergeFromBuffer(i, r); + factory FileFaces.fromJson($core.String i, [$pb.ExtensionRegistry r = $pb.ExtensionRegistry.EMPTY]) => create()..mergeFromJson(i, r); + + static final $pb.BuilderInfo _i = $pb.BuilderInfo(_omitMessageNames ? '' : 'FileFaces', package: const $pb.PackageName(_omitMessageNames ? '' : 'ente.ml'), createEmptyInstance: create) + ..pc<$2.Face>(1, _omitFieldNames ? '' : 'faces', $pb.PbFieldType.PM, subBuilder: $2.Face.create) + ..a<$core.int>(2, _omitFieldNames ? '' : 'height', $pb.PbFieldType.O3) + ..a<$core.int>(3, _omitFieldNames ? '' : 'width', $pb.PbFieldType.O3) + ..a<$core.int>(4, _omitFieldNames ? '' : 'version', $pb.PbFieldType.O3) + ..aOS(5, _omitFieldNames ? '' : 'error') + ..hasRequiredFields = false + ; + + @$core.Deprecated( + 'Using this can add significant overhead to your binary. ' + 'Use [GeneratedMessageGenericExtensions.deepCopy] instead. ' + 'Will be removed in next major version') + FileFaces clone() => FileFaces()..mergeFromMessage(this); + @$core.Deprecated( + 'Using this can add significant overhead to your binary. ' + 'Use [GeneratedMessageGenericExtensions.rebuild] instead. 
' + 'Will be removed in next major version') + FileFaces copyWith(void Function(FileFaces) updates) => super.copyWith((message) => updates(message as FileFaces)) as FileFaces; + + $pb.BuilderInfo get info_ => _i; + + @$core.pragma('dart2js:noInline') + static FileFaces create() => FileFaces._(); + FileFaces createEmptyInstance() => create(); + static $pb.PbList createRepeated() => $pb.PbList(); + @$core.pragma('dart2js:noInline') + static FileFaces getDefault() => _defaultInstance ??= $pb.GeneratedMessage.$_defaultFor(create); + static FileFaces? _defaultInstance; + + @$pb.TagNumber(1) + $core.List<$2.Face> get faces => $_getList(0); + + @$pb.TagNumber(2) + $core.int get height => $_getIZ(1); + @$pb.TagNumber(2) + set height($core.int v) { $_setSignedInt32(1, v); } + @$pb.TagNumber(2) + $core.bool hasHeight() => $_has(1); + @$pb.TagNumber(2) + void clearHeight() => clearField(2); + + @$pb.TagNumber(3) + $core.int get width => $_getIZ(2); + @$pb.TagNumber(3) + set width($core.int v) { $_setSignedInt32(2, v); } + @$pb.TagNumber(3) + $core.bool hasWidth() => $_has(2); + @$pb.TagNumber(3) + void clearWidth() => clearField(3); + + @$pb.TagNumber(4) + $core.int get version => $_getIZ(3); + @$pb.TagNumber(4) + set version($core.int v) { $_setSignedInt32(3, v); } + @$pb.TagNumber(4) + $core.bool hasVersion() => $_has(3); + @$pb.TagNumber(4) + void clearVersion() => clearField(4); + + @$pb.TagNumber(5) + $core.String get error => $_getSZ(4); + @$pb.TagNumber(5) + set error($core.String v) { $_setString(4, v); } + @$pb.TagNumber(5) + $core.bool hasError() => $_has(4); + @$pb.TagNumber(5) + void clearError() => clearField(5); +} + + +const _omitFieldNames = $core.bool.fromEnvironment('protobuf.omit_field_names'); +const _omitMessageNames = $core.bool.fromEnvironment('protobuf.omit_message_names'); diff --git a/mobile/lib/generated/protos/ente/ml/fileml.pbenum.dart b/mobile/lib/generated/protos/ente/ml/fileml.pbenum.dart new file mode 100644 index 000000000..71d796efe --- /dev/null +++ b/mobile/lib/generated/protos/ente/ml/fileml.pbenum.dart @@ -0,0 +1,11 @@ +// +// Generated code. Do not modify. +// source: ente/ml/fileml.proto +// +// @dart = 2.12 + +// ignore_for_file: annotate_overrides, camel_case_types, comment_references +// ignore_for_file: constant_identifier_names, library_prefixes +// ignore_for_file: non_constant_identifier_names, prefer_final_fields +// ignore_for_file: unnecessary_import, unnecessary_this, unused_import + diff --git a/mobile/lib/generated/protos/ente/ml/fileml.pbjson.dart b/mobile/lib/generated/protos/ente/ml/fileml.pbjson.dart new file mode 100644 index 000000000..824741733 --- /dev/null +++ b/mobile/lib/generated/protos/ente/ml/fileml.pbjson.dart @@ -0,0 +1,57 @@ +// +// Generated code. Do not modify. +// source: ente/ml/fileml.proto +// +// @dart = 2.12 + +// ignore_for_file: annotate_overrides, camel_case_types, comment_references +// ignore_for_file: constant_identifier_names, library_prefixes +// ignore_for_file: non_constant_identifier_names, prefer_final_fields +// ignore_for_file: unnecessary_import, unnecessary_this, unused_import + +import 'dart:convert' as $convert; +import 'dart:core' as $core; +import 'dart:typed_data' as $typed_data; + +@$core.Deprecated('Use fileMLDescriptor instead') +const FileML$json = { + '1': 'FileML', + '2': [ + {'1': 'id', '3': 1, '4': 1, '5': 3, '9': 0, '10': 'id', '17': true}, + {'1': 'clip', '3': 2, '4': 3, '5': 1, '10': 'clip'}, + ], + '8': [ + {'1': '_id'}, + ], +}; + +/// Descriptor for `FileML`. 
Decode as a `google.protobuf.DescriptorProto`. +final $typed_data.Uint8List fileMLDescriptor = $convert.base64Decode( + 'CgZGaWxlTUwSEwoCaWQYASABKANIAFICaWSIAQESEgoEY2xpcBgCIAMoAVIEY2xpcEIFCgNfaW' + 'Q='); + +@$core.Deprecated('Use fileFacesDescriptor instead') +const FileFaces$json = { + '1': 'FileFaces', + '2': [ + {'1': 'faces', '3': 1, '4': 3, '5': 11, '6': '.ente.ml.Face', '10': 'faces'}, + {'1': 'height', '3': 2, '4': 1, '5': 5, '9': 0, '10': 'height', '17': true}, + {'1': 'width', '3': 3, '4': 1, '5': 5, '9': 1, '10': 'width', '17': true}, + {'1': 'version', '3': 4, '4': 1, '5': 5, '9': 2, '10': 'version', '17': true}, + {'1': 'error', '3': 5, '4': 1, '5': 9, '9': 3, '10': 'error', '17': true}, + ], + '8': [ + {'1': '_height'}, + {'1': '_width'}, + {'1': '_version'}, + {'1': '_error'}, + ], +}; + +/// Descriptor for `FileFaces`. Decode as a `google.protobuf.DescriptorProto`. +final $typed_data.Uint8List fileFacesDescriptor = $convert.base64Decode( + 'CglGaWxlRmFjZXMSIwoFZmFjZXMYASADKAsyDS5lbnRlLm1sLkZhY2VSBWZhY2VzEhsKBmhlaW' + 'dodBgCIAEoBUgAUgZoZWlnaHSIAQESGQoFd2lkdGgYAyABKAVIAVIFd2lkdGiIAQESHQoHdmVy' + 'c2lvbhgEIAEoBUgCUgd2ZXJzaW9uiAEBEhkKBWVycm9yGAUgASgJSANSBWVycm9yiAEBQgkKB1' + '9oZWlnaHRCCAoGX3dpZHRoQgoKCF92ZXJzaW9uQggKBl9lcnJvcg=='); + diff --git a/mobile/lib/generated/protos/ente/ml/fileml.pbserver.dart b/mobile/lib/generated/protos/ente/ml/fileml.pbserver.dart new file mode 100644 index 000000000..4cb208d27 --- /dev/null +++ b/mobile/lib/generated/protos/ente/ml/fileml.pbserver.dart @@ -0,0 +1,14 @@ +// +// Generated code. Do not modify. +// source: ente/ml/fileml.proto +// +// @dart = 2.12 + +// ignore_for_file: annotate_overrides, camel_case_types, comment_references +// ignore_for_file: constant_identifier_names +// ignore_for_file: deprecated_member_use_from_same_package, library_prefixes +// ignore_for_file: non_constant_identifier_names, prefer_final_fields +// ignore_for_file: unnecessary_import, unnecessary_this, unused_import + +export 'fileml.pb.dart'; + diff --git a/mobile/lib/l10n/intl_en.arb b/mobile/lib/l10n/intl_en.arb index e4ad661aa..abf2420e6 100644 --- a/mobile/lib/l10n/intl_en.arb +++ b/mobile/lib/l10n/intl_en.arb @@ -1170,6 +1170,7 @@ } }, "faces": "Faces", + "people": "People", "contents": "Contents", "addNew": "Add new", "@addNew": { diff --git a/mobile/lib/main.dart b/mobile/lib/main.dart index 7e639891f..4f315a97b 100644 --- a/mobile/lib/main.dart +++ b/mobile/lib/main.dart @@ -25,6 +25,7 @@ import 'package:photos/services/app_lifecycle_service.dart'; import 'package:photos/services/billing_service.dart'; import 'package:photos/services/collections_service.dart'; import "package:photos/services/entity_service.dart"; +import "package:photos/services/face_ml/face_ml_service.dart"; import 'package:photos/services/favorites_service.dart'; import 'package:photos/services/feature_flag_service.dart'; import 'package:photos/services/local_file_update_service.dart'; @@ -242,9 +243,11 @@ Future _init(bool isBackground, {String via = ''}) async { // Can not including existing tf/ml binaries as they are not being built // from source. 
// See https://gitlab.com/fdroid/fdroiddata/-/merge_requests/12671#note_1294346819 - // if (!UpdateService.instance.isFdroidFlavor()) { - // unawaited(ObjectDetectionService.instance.init()); - // } + if (!UpdateService.instance.isFdroidFlavor()) { + // unawaited(ObjectDetectionService.instance.init()); + unawaited(FaceMlService.instance.init()); + FaceMlService.instance.listenIndexOnDiffSync(); + } _logger.info("Initialization done"); } diff --git a/mobile/lib/models/gallery_type.dart b/mobile/lib/models/gallery_type.dart index ba0eb397f..b711e0f74 100644 --- a/mobile/lib/models/gallery_type.dart +++ b/mobile/lib/models/gallery_type.dart @@ -18,6 +18,8 @@ enum GalleryType { searchResults, locationTag, quickLink, + peopleTag, + cluster, } extension GalleyTypeExtension on GalleryType { @@ -32,12 +34,14 @@ extension GalleyTypeExtension on GalleryType { case GalleryType.locationTag: case GalleryType.quickLink: case GalleryType.uncategorized: + case GalleryType.peopleTag: return true; case GalleryType.hiddenSection: case GalleryType.hiddenOwnedCollection: case GalleryType.trash: case GalleryType.sharedCollection: + case GalleryType.cluster: return false; } } @@ -50,6 +54,7 @@ extension GalleyTypeExtension on GalleryType { return true; case GalleryType.hiddenSection: + case GalleryType.peopleTag: case GalleryType.hiddenOwnedCollection: case GalleryType.favorite: case GalleryType.searchResults: @@ -59,6 +64,7 @@ extension GalleyTypeExtension on GalleryType { case GalleryType.trash: case GalleryType.sharedCollection: case GalleryType.locationTag: + case GalleryType.cluster: return false; } } @@ -75,12 +81,14 @@ extension GalleyTypeExtension on GalleryType { case GalleryType.uncategorized: case GalleryType.locationTag: case GalleryType.quickLink: + case GalleryType.peopleTag: return true; case GalleryType.trash: case GalleryType.archive: case GalleryType.hiddenSection: case GalleryType.hiddenOwnedCollection: case GalleryType.sharedCollection: + case GalleryType.cluster: return false; } } @@ -98,8 +106,10 @@ extension GalleyTypeExtension on GalleryType { case GalleryType.localFolder: case GalleryType.locationTag: case GalleryType.quickLink: + case GalleryType.peopleTag: return true; case GalleryType.trash: + case GalleryType.cluster: case GalleryType.sharedCollection: return false; } @@ -114,8 +124,10 @@ extension GalleyTypeExtension on GalleryType { case GalleryType.archive: case GalleryType.uncategorized: case GalleryType.locationTag: + case GalleryType.peopleTag: return true; case GalleryType.hiddenSection: + case GalleryType.cluster: case GalleryType.hiddenOwnedCollection: case GalleryType.localFolder: case GalleryType.trash: @@ -132,6 +144,7 @@ extension GalleyTypeExtension on GalleryType { case GalleryType.quickLink: return true; case GalleryType.hiddenSection: + case GalleryType.peopleTag: case GalleryType.hiddenOwnedCollection: case GalleryType.uncategorized: case GalleryType.favorite: @@ -139,6 +152,7 @@ extension GalleyTypeExtension on GalleryType { case GalleryType.homepage: case GalleryType.archive: case GalleryType.localFolder: + case GalleryType.cluster: case GalleryType.trash: case GalleryType.locationTag: return false; @@ -154,6 +168,7 @@ extension GalleyTypeExtension on GalleryType { return true; case GalleryType.hiddenSection: + case GalleryType.peopleTag: case GalleryType.hiddenOwnedCollection: case GalleryType.favorite: case GalleryType.searchResults: @@ -162,6 +177,7 @@ extension GalleyTypeExtension on GalleryType { case GalleryType.trash: case GalleryType.sharedCollection: 
      case GalleryType.locationTag:
+      case GalleryType.cluster:
        return false;
    }
  }
@@ -182,10 +198,12 @@ extension GalleyTypeExtension on GalleryType {
        return true;

      case GalleryType.hiddenSection:
+      case GalleryType.peopleTag:
      case GalleryType.hiddenOwnedCollection:
      case GalleryType.localFolder:
      case GalleryType.trash:
      case GalleryType.favorite:
+      case GalleryType.cluster:
      case GalleryType.sharedCollection:
        return false;
    }
@@ -203,12 +221,14 @@ extension GalleyTypeExtension on GalleryType {
      case GalleryType.searchResults:
      case GalleryType.uncategorized:
      case GalleryType.locationTag:
+      case GalleryType.peopleTag:
        return true;

      case GalleryType.hiddenSection:
      case GalleryType.hiddenOwnedCollection:
      case GalleryType.quickLink:
      case GalleryType.favorite:
+      case GalleryType.cluster:
      case GalleryType.archive:
      case GalleryType.localFolder:
      case GalleryType.trash:
@@ -244,7 +264,7 @@ extension GalleyTypeExtension on GalleryType {
  }

  bool showEditLocation() {
-    return this != GalleryType.sharedCollection;
+    return this != GalleryType.sharedCollection && this != GalleryType.cluster;
  }
}
@@ -334,7 +354,9 @@ extension GalleryAppBarExtn on GalleryType {
      case GalleryType.locationTag:
      case GalleryType.searchResults:
        return false;
+      case GalleryType.cluster:
      case GalleryType.uncategorized:
+      case GalleryType.peopleTag:
      case GalleryType.ownedCollection:
      case GalleryType.sharedCollection:
      case GalleryType.quickLink:
diff --git a/mobile/lib/models/ml/ml_typedefs.dart b/mobile/lib/models/ml/ml_typedefs.dart
new file mode 100644
index 000000000..bcb23251e
--- /dev/null
+++ b/mobile/lib/models/ml/ml_typedefs.dart
@@ -0,0 +1,7 @@
+typedef Embedding = List<double>;
+
+typedef Num3DInputMatrix = List<List<List<num>>>;
+
+typedef Int3DInputMatrix = List<List<List<int>>>;
+
+typedef Double3DInputMatrix = List<List<List<double>>>;
diff --git a/mobile/lib/models/ml/ml_versions.dart b/mobile/lib/models/ml/ml_versions.dart
new file mode 100644
index 000000000..857bef33c
--- /dev/null
+++ b/mobile/lib/models/ml/ml_versions.dart
@@ -0,0 +1,3 @@
+const faceMlVersion = 1;
+const clusterMlVersion = 1;
+const minimumClusterSize = 2;
\ No newline at end of file
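The typedefs above encode tensor shapes only; a shape illustration follows (assuming a channel-last [row][column][channel] layout, which the diff itself does not state):

    // A 2x2 RGB patch as a Num3DInputMatrix.
    final Num3DInputMatrix patch = [
      [
        [0.1, 0.2, 0.3],
        [0.4, 0.5, 0.6],
      ],
      [
        [0.7, 0.8, 0.9],
        [1.0, 0.0, 0.5],
      ],
    ];
    final Embedding embedding = [0.12, -0.08, 0.33]; // length is model-defined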
diff --git a/mobile/lib/models/search/generic_search_result.dart b/mobile/lib/models/search/generic_search_result.dart
index 352886a50..a40f71fd3 100644
--- a/mobile/lib/models/search/generic_search_result.dart
+++ b/mobile/lib/models/search/generic_search_result.dart
@@ -8,8 +8,15 @@ class GenericSearchResult extends SearchResult {
  final List<EnteFile> _files;
  final ResultType _type;
  final Function(BuildContext context)? onResultTap;
+  final Map<String, dynamic> params;

-  GenericSearchResult(this._type, this._name, this._files, {this.onResultTap});
+  GenericSearchResult(
+    this._type,
+    this._name,
+    this._files, {
+    this.onResultTap,
+    this.params = const {},
+  });

  @override
  String name() {
diff --git a/mobile/lib/models/search/search_constants.dart b/mobile/lib/models/search/search_constants.dart
new file mode 100644
index 000000000..6a0bcb886
--- /dev/null
+++ b/mobile/lib/models/search/search_constants.dart
@@ -0,0 +1,3 @@
+const kPersonParamID = 'person_id';
+const kClusterParamId = 'cluster_id';
+const kFileID = 'file_id';
diff --git a/mobile/lib/models/search/search_types.dart b/mobile/lib/models/search/search_types.dart
index 9f5ee9b1d..1ae358c8c 100644
--- a/mobile/lib/models/search/search_types.dart
+++ b/mobile/lib/models/search/search_types.dart
@@ -33,6 +33,7 @@ enum ResultType {
  fileCaption,
  event,
  shared,
+  faces,
  magic,
}
@@ -55,7 +56,7 @@ extension SectionTypeExtensions on SectionType {
  String sectionTitle(BuildContext context) {
    switch (this) {
      case SectionType.face:
-        return S.of(context).faces;
+        return S.of(context).people;
      case SectionType.content:
        return S.of(context).contents;
      case SectionType.moment:
@@ -99,7 +100,7 @@ extension SectionTypeExtensions on SectionType {
  bool get isCTAVisible {
    switch (this) {
      case SectionType.face:
-        return false;
+        return true;
      case SectionType.content:
        return false;
      case SectionType.moment:
@@ -117,6 +118,8 @@ extension SectionTypeExtensions on SectionType {
    }
  }

+  bool get sortByName => this != SectionType.face;
+
  bool get isEmptyCTAVisible {
    switch (this) {
      case SectionType.face:
@@ -245,8 +248,7 @@ extension SectionTypeExtensions on SectionType {
  }) {
    switch (this) {
      case SectionType.face:
-        return Future.value(List.empty());
-
+        return SearchService.instance.getAllFace(limit);
      case SectionType.content:
        return Future.value(List.empty());
diff --git a/mobile/lib/services/face_ml/blur_detection/blur_constants.dart b/mobile/lib/services/face_ml/blur_detection/blur_constants.dart
new file mode 100644
index 000000000..4d770162c
--- /dev/null
+++ b/mobile/lib/services/face_ml/blur_detection/blur_constants.dart
@@ -0,0 +1,2 @@
+const kLaplacianThreshold = 10;
+const kLapacianDefault = 10000.0;
diff --git a/mobile/lib/services/face_ml/blur_detection/blur_detection_service.dart b/mobile/lib/services/face_ml/blur_detection/blur_detection_service.dart
new file mode 100644
index 000000000..ff5830468
--- /dev/null
+++ b/mobile/lib/services/face_ml/blur_detection/blur_detection_service.dart
@@ -0,0 +1,115 @@
+import 'package:logging/logging.dart';
+import "package:photos/services/face_ml/blur_detection/blur_constants.dart";
+
+class BlurDetectionService {
+  final _logger = Logger('BlurDetectionService');
+
+  // singleton pattern
+  BlurDetectionService._privateConstructor();
+  static final instance = BlurDetectionService._privateConstructor();
+  factory BlurDetectionService() => instance;
+
+  Future<(bool, double)> predictIsBlurGrayLaplacian(
+    List<List<int>> grayImage, {
+    int threshold = kLaplacianThreshold,
+  }) async {
+    final List<List<int>> laplacian = _applyLaplacian(grayImage);
+    final double variance = _calculateVariance(laplacian);
+    _logger.info('Variance: $variance');
+    return (variance < threshold, variance);
+  }
+
+  double _calculateVariance(List<List<int>> matrix) {
+    final int numRows = matrix.length;
+    final int numCols = matrix[0].length;
+    final int totalElements = numRows * numCols;
+
+    // Calculate the mean
+    double mean = 0;
+    for (var row in matrix) {
+      for (var value in row) {
+        mean += value;
+      }
+    }
+    mean /= totalElements;
+
+    // Calculate the variance
+    double variance = 0;
+    for (var row in matrix) {
+      for (var value in row) {
+        final double diff = value - mean;
+        variance += diff * diff;
+      }
+    }
+    variance /= totalElements;
+
+    return variance;
+  }
+
+  List<List<int>> _padImage(List<List<int>> image) {
+    final int numRows = image.length;
+    final int numCols = image[0].length;
+
+    // Create a new matrix with extra padding
+    final List<List<int>> paddedImage = List.generate(
+      numRows + 2,
+      (i) => List.generate(numCols + 2, (j) => 0, growable: false),
+      growable: false,
+    );
+
+    // Copy original image into the center of the padded image
+    for (int i = 0; i < numRows; i++) {
+      for (int j = 0; j < numCols; j++) {
+        paddedImage[i + 1][j + 1] = image[i][j];
+      }
+    }
+
+    // Reflect padding
+    // Top and bottom rows
+    for (int j = 1; j <= numCols; j++) {
+      paddedImage[0][j] = paddedImage[2][j]; // Top row
+      paddedImage[numRows + 1][j] = paddedImage[numRows - 1][j]; // Bottom row
+    }
+    // Left and right columns
+    for (int i = 0; i < numRows + 2; i++) {
+      paddedImage[i][0] = paddedImage[i][2]; // Left column
+      paddedImage[i][numCols + 1] = paddedImage[i][numCols - 1]; // Right column
+    }
+
+    return paddedImage;
+  }
+
+  List<List<int>> _applyLaplacian(List<List<int>> image) {
+    final List<List<int>> paddedImage = _padImage(image);
+    final int numRows = image.length;
+    final int numCols = image[0].length;
+    final List<List<int>> outputImage = List.generate(
+      numRows,
+      (i) => List.generate(numCols, (j) => 0, growable: false),
+      growable: false,
+    );
+
+    // Define the Laplacian kernel
+    final List<List<int>> kernel = [
+      [0, 1, 0],
+      [1, -4, 1],
+      [0, 1, 0],
+    ];
+
+    // Apply the kernel to each pixel
+    for (int i = 0; i < numRows; i++) {
+      for (int j = 0; j < numCols; j++) {
+        int sum = 0;
+        for (int ki = 0; ki < 3; ki++) {
+          for (int kj = 0; kj < 3; kj++) {
+            sum += paddedImage[i + ki][j + kj] * kernel[ki][kj];
+          }
+        }
+        // Adjust the output value if necessary (e.g., clipping)
+        outputImage[i][j] = sum; //.clamp(0, 255);
+      }
+    }
+
+    return outputImage;
+  }
+}
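A hedged usage sketch for the service above, on a tiny 8-bit grayscale patch (values are illustrative only):

    final gray = <List<int>>[
      [12, 13, 12, 11],
      [13, 12, 12, 12],
      [12, 12, 13, 12],
      [11, 12, 12, 13],
    ];
    // Near-uniform patches yield a small Laplacian variance, so this
    // should come back as blurry; a sharp face crop scores far above 10.
    final (isBlur, variance) =
        await BlurDetectionService.instance.predictIsBlurGrayLaplacian(gray);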
+= value; + } + } + mean /= totalElements; + + // Calculate the variance + double variance = 0; + for (var row in matrix) { + for (var value in row) { + final double diff = value - mean; + variance += diff * diff; + } + } + variance /= totalElements; + + return variance; + } + + List> _padImage(List> image) { + final int numRows = image.length; + final int numCols = image[0].length; + + // Create a new matrix with extra padding + final List> paddedImage = List.generate( + numRows + 2, + (i) => List.generate(numCols + 2, (j) => 0, growable: false), + growable: false, + ); + + // Copy original image into the center of the padded image + for (int i = 0; i < numRows; i++) { + for (int j = 0; j < numCols; j++) { + paddedImage[i + 1][j + 1] = image[i][j]; + } + } + + // Reflect padding + // Top and bottom rows + for (int j = 1; j <= numCols; j++) { + paddedImage[0][j] = paddedImage[2][j]; // Top row + paddedImage[numRows + 1][j] = paddedImage[numRows - 1][j]; // Bottom row + } + // Left and right columns + for (int i = 0; i < numRows + 2; i++) { + paddedImage[i][0] = paddedImage[i][2]; // Left column + paddedImage[i][numCols + 1] = paddedImage[i][numCols - 1]; // Right column + } + + return paddedImage; + } + + List> _applyLaplacian(List> image) { + final List> paddedImage = _padImage(image); + final int numRows = image.length; + final int numCols = image[0].length; + final List> outputImage = List.generate( + numRows, + (i) => List.generate(numCols, (j) => 0, growable: false), + growable: false, + ); + + // Define the Laplacian kernel + final List> kernel = [ + [0, 1, 0], + [1, -4, 1], + [0, 1, 0], + ]; + + // Apply the kernel to each pixel + for (int i = 0; i < numRows; i++) { + for (int j = 0; j < numCols; j++) { + int sum = 0; + for (int ki = 0; ki < 3; ki++) { + for (int kj = 0; kj < 3; kj++) { + sum += paddedImage[i + ki][j + kj] * kernel[ki][kj]; + } + } + // Adjust the output value if necessary (e.g., clipping) + outputImage[i][j] = sum; //.clamp(0, 255); + } + } + + return outputImage; + } +} diff --git a/mobile/lib/services/face_ml/face_alignment/alignment_result.dart b/mobile/lib/services/face_ml/face_alignment/alignment_result.dart new file mode 100644 index 000000000..41fd88b61 --- /dev/null +++ b/mobile/lib/services/face_ml/face_alignment/alignment_result.dart @@ -0,0 +1,36 @@ +class AlignmentResult { + final List> affineMatrix; // 3x3 + final List center; // [x, y] + final double size; // 1 / scale + final double rotation; // atan2(simRotation[1][0], simRotation[0][0]); + + AlignmentResult({required this.affineMatrix, required this.center, required this.size, required this.rotation}); + + AlignmentResult.empty() + : affineMatrix = >[ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1], + ], + center = [0, 0], + size = 1, + rotation = 0; + + factory AlignmentResult.fromJson(Map json) { + return AlignmentResult( + affineMatrix: (json['affineMatrix'] as List) + .map((item) => List.from(item)) + .toList(), + center: List.from(json['center'] as List), + size: json['size'] as double, + rotation: json['rotation'] as double, + ); + } + + Map toJson() => { + 'affineMatrix': affineMatrix, + 'center': center, + 'size': size, + 'rotation': rotation, + }; +} \ No newline at end of file diff --git a/mobile/lib/services/face_ml/face_alignment/similarity_transform.dart b/mobile/lib/services/face_ml/face_alignment/similarity_transform.dart new file mode 100644 index 000000000..4ae27794b --- /dev/null +++ b/mobile/lib/services/face_ml/face_alignment/similarity_transform.dart @@ -0,0 +1,171 @@ +import 'dart:math' 
show atan2; +import 'package:ml_linalg/linalg.dart'; +import 'package:photos/extensions/ml_linalg_extensions.dart'; +import "package:photos/services/face_ml/face_alignment/alignment_result.dart"; + +/// Class to compute the similarity transform between two sets of points. +/// +/// The class estimates the parameters of the similarity transformation via the `estimate` function. +/// After estimation, the transformation can be applied to an image using the `warpAffine` function. +class SimilarityTransform { + Matrix _params = Matrix.fromList([ + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0, 0, 1], + ]); + List _center = [0, 0]; // [x, y] + double _size = 1; // 1 / scale + double _rotation = 0; // atan2(simRotation[1][0], simRotation[0][0]); + + final arcface4Landmarks = [ + [38.2946, 51.6963], + [73.5318, 51.5014], + [56.0252, 71.7366], + [56.1396, 92.2848], + ]; + final arcface5Landmarks = [ + [38.2946, 51.6963], + [73.5318, 51.5014], + [56.0252, 71.7366], + [41.5493, 92.3655], + [70.7299, 92.2041], + ]; + get arcfaceNormalized4 => arcface4Landmarks + .map((list) => list.map((value) => value / 112.0).toList()) + .toList(); + get arcfaceNormalized5 => arcface5Landmarks + .map((list) => list.map((value) => value / 112.0).toList()) + .toList(); + + List> get paramsList => _params.to2DList(); + + // singleton pattern + SimilarityTransform._privateConstructor(); + static final instance = SimilarityTransform._privateConstructor(); + factory SimilarityTransform() => instance; + + void _cleanParams() { + _params = Matrix.fromList([ + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0, 0, 1], + ]); + _center = [0, 0]; + _size = 1; + _rotation = 0; + } + + /// Function to estimate the parameters of the affine transformation. These parameters are stored in the class variable params. + /// + /// Returns a tuple of (AlignmentResult, bool). The bool indicates whether the parameters are valid or not. + /// + /// Runs efficiently in about 1-3 ms after initial warm-up. + /// + /// It takes the source and destination points as input and returns the + /// parameters of the affine transformation as output. The function + /// returns false if the parameters cannot be estimated. The function + /// estimates the parameters by solving a least-squares problem using + /// the Umeyama algorithm, via [_umeyama]. + (AlignmentResult, bool) estimate(List> src) { + _cleanParams(); + final (params, center, size, rotation) = + _umeyama(src, arcfaceNormalized5, true); + _params = params; + _center = center; + _size = size; + _rotation = rotation; + final alignmentResult = AlignmentResult( + affineMatrix: paramsList, + center: _center, + size: _size, + rotation: _rotation, + ); + // We check for NaN in the transformation matrix params. + final isNoNanInParam = + !_params.asFlattenedList.any((element) => element.isNaN); + return (alignmentResult, isNoNanInParam); + } + + static (Matrix, List, double, double) _umeyama( + List> src, + List> dst, [ + bool estimateScale = true, + ]) { + final srcMat = Matrix.fromList( + src, + // .map((list) => list.map((value) => value.toDouble()).toList()) + // .toList(), + ); + final dstMat = Matrix.fromList(dst); + final num = srcMat.rowCount; + final dim = srcMat.columnCount; + + // Compute mean of src and dst. + final srcMean = srcMat.mean(Axis.columns); + final dstMean = dstMat.mean(Axis.columns); + + // Subtract mean from src and dst. + final srcDemean = srcMat.mapRows((vector) => vector - srcMean); + final dstDemean = dstMat.mapRows((vector) => vector - dstMean); + + // Eq. (38). 
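// Editorial note (illustrative, not part of the diff): the numbered equations
// referenced in this function are from Umeyama, "Least-Squares Estimation of
// Transformation Parameters Between Two Point Patterns", IEEE TPAMI 13(4),
// 1991, with src playing x and dst playing y:
//   Eq. (38): A = (1/n) * sum_i (y_i - mu_y) * (x_i - mu_x)^T
//   Eq. (39): d = (1, ..., 1), with the last entry flipped to -1 if det(A) < 0
//   Eq. (40): R = U * diag(d) * V^T, where A = U * S * V^T is the SVD
//   Eq. (41): t = mu_y - c * R * mu_x
//   Eq. (42): c = (1 / sigma_x^2) * trace(diag(d) * S)
// The code below computes exactly these quantities, plus the rank-deficient
// special cases.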
+ final A = (dstDemean.transpose() * srcDemean) / num; + + // Eq. (39). + var d = Vector.filled(dim, 1.0); + if (A.determinant() < 0) { + d = d.set(dim - 1, -1); + } + + var T = Matrix.identity(dim + 1); + + final svdResult = A.svd(); + final Matrix U = svdResult['U']!; + final Vector S = svdResult['S']!; + final Matrix V = svdResult['V']!; + + // Eq. (40) and (43). + final rank = A.matrixRank(); + if (rank == 0) { + return (T * double.nan, [0, 0], 1, 0); + } else if (rank == dim - 1) { + if (U.determinant() * V.determinant() > 0) { + T = T.setSubMatrix(0, dim, 0, dim, U * V); + } else { + final s = d[dim - 1]; + d = d.set(dim - 1, -1); + final replacement = U * Matrix.diagonal(d.toList()) * V; + T = T.setSubMatrix(0, dim, 0, dim, replacement); + d = d.set(dim - 1, s); + } + } else { + final replacement = U * Matrix.diagonal(d.toList()) * V; + T = T.setSubMatrix(0, dim, 0, dim, replacement); + } + final Matrix simRotation = U * Matrix.diagonal(d.toList()) * V; + + var scale = 1.0; + if (estimateScale) { + // Eq. (41) and (42). + scale = 1.0 / srcDemean.variance(Axis.columns).sum() * (S * d).sum(); + } + + final subTIndices = Iterable.generate(dim, (index) => index); + final subT = T.sample(rowIndices: subTIndices, columnIndices: subTIndices); + final newSubT = dstMean - (subT * srcMean) * scale; + T = T.setValues(0, dim, dim, dim + 1, newSubT); + final newNewSubT = + T.sample(rowIndices: subTIndices, columnIndices: subTIndices) * scale; + T = T.setSubMatrix(0, dim, 0, dim, newNewSubT); + + // final List translation = [T[0][2], T[1][2]]; + // final simRotation = replacement?; + final size = 1 / scale; + final rotation = atan2(simRotation[1][0], simRotation[0][0]); + final meanTranslation = (dstMean - 0.5) * size; + final centerMat = srcMean - meanTranslation; + final List center = [centerMat[0], centerMat[1]]; + + return (T, center, size, rotation); + } +} diff --git a/mobile/lib/services/face_ml/face_clustering/cosine_distance.dart b/mobile/lib/services/face_ml/face_clustering/cosine_distance.dart new file mode 100644 index 000000000..f8f2e68a8 --- /dev/null +++ b/mobile/lib/services/face_ml/face_clustering/cosine_distance.dart @@ -0,0 +1,55 @@ +import 'dart:math' show sqrt; + +/// Calculates the cosine distance between two embeddings/vectors. +/// +/// Throws an ArgumentError if the vectors are of different lengths or +/// if either of the vectors has a magnitude of zero. +double cosineDistance(List vector1, List vector2) { + if (vector1.length != vector2.length) { + throw ArgumentError('Vectors must be the same length'); + } + + double dotProduct = 0.0; + double magnitude1 = 0.0; + double magnitude2 = 0.0; + + for (int i = 0; i < vector1.length; i++) { + dotProduct += vector1[i] * vector2[i]; + magnitude1 += vector1[i] * vector1[i]; + magnitude2 += vector2[i] * vector2[i]; + } + + magnitude1 = sqrt(magnitude1); + magnitude2 = sqrt(magnitude2); + + // Avoid division by zero. This should never happen. If it does, then one of the vectors contains only zeros. + if (magnitude1 == 0 || magnitude2 == 0) { + throw ArgumentError('Vectors must not have a magnitude of zero'); + } + + final double similarity = dotProduct / (magnitude1 * magnitude2); + + // Cosine distance is the complement of cosine similarity + return 1.0 - similarity; +} + +// cosineDistForNormVectors calculates the cosine distance between two normalized embeddings/vectors. 
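Editorial note: because the embeddings are unit-normalized, the denominator of the general cosine formula is 1, so the distance collapses to 1 - dot(a, b), which is what cosineDistForNormVectors below exploits. A minimal self-contained check of that equivalence, with made-up vectors:

import 'dart:math' show sqrt;

void main() {
  // Two arbitrary 3-d vectors, normalized to unit length.
  final a = _normalize([1.0, 2.0, 2.0]);
  final b = _normalize([2.0, 1.0, 2.0]);
  double dot = 0.0;
  for (int i = 0; i < a.length; i++) {
    dot += a[i] * b[i];
  }
  // For unit vectors: cosineDistance(a, b) == 1 - dot.
  print(1.0 - dot); // 0.111... (similarity 8/9)
}

List<double> _normalize(List<double> v) {
  final norm = sqrt(v.fold<double>(0.0, (s, e) => s + e * e));
  return v.map((e) => e / norm).toList();
}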
+@pragma('vm:entry-point') +double cosineDistForNormVectors(List vector1, List vector2) { + if (vector1.length != vector2.length) { + throw ArgumentError('Vectors must be the same length'); + } + double dotProduct = 0.0; + for (int i = 0; i < vector1.length; i++) { + dotProduct += vector1[i] * vector2[i]; + } + return 1.0 - dotProduct; +} + +double calculateSqrDistance(List v1, List v2) { + double sum = 0; + for (int i = 0; i < v1.length; i++) { + sum += (v1[i] - v2[i]) * (v1[i] - v2[i]); + } + return sqrt(sum); +} diff --git a/mobile/lib/services/face_ml/face_clustering/linear_clustering_service.dart b/mobile/lib/services/face_ml/face_clustering/linear_clustering_service.dart new file mode 100644 index 000000000..8e7783859 --- /dev/null +++ b/mobile/lib/services/face_ml/face_clustering/linear_clustering_service.dart @@ -0,0 +1,405 @@ +import "dart:async"; +import "dart:developer"; +import "dart:isolate"; +import "dart:math" show max; +import "dart:typed_data"; + +import "package:logging/logging.dart"; +import "package:photos/generated/protos/ente/common/vector.pb.dart"; +import "package:photos/services/face_ml/face_clustering/cosine_distance.dart"; +import "package:synchronized/synchronized.dart"; + +class FaceInfo { + final String faceID; + final List embedding; + int? clusterId; + String? closestFaceId; + int? closestDist; + FaceInfo({ + required this.faceID, + required this.embedding, + this.clusterId, + }); +} + +enum ClusterOperation { linearIncrementalClustering } + +class FaceLinearClustering { + final _logger = Logger("FaceLinearClustering"); + + Timer? _inactivityTimer; + final Duration _inactivityDuration = const Duration(seconds: 30); + int _activeTasks = 0; + + + final _initLock = Lock(); + + late Isolate _isolate; + late ReceivePort _receivePort = ReceivePort(); + late SendPort _mainSendPort; + + bool isSpawned = false; + bool isRunning = false; + + static const recommendedDistanceThreshold = 0.3; + + // singleton pattern + FaceLinearClustering._privateConstructor(); + + /// Use this instance to access the FaceClustering service. + /// e.g. `FaceLinearClustering.instance.predict(dataset)` + static final instance = FaceLinearClustering._privateConstructor(); + factory FaceLinearClustering() => instance; + + Future init() async { + return _initLock.synchronized(() async { + if (isSpawned) return; + + _receivePort = ReceivePort(); + + try { + _isolate = await Isolate.spawn( + _isolateMain, + _receivePort.sendPort, + ); + _mainSendPort = await _receivePort.first as SendPort; + isSpawned = true; + + _resetInactivityTimer(); + } catch (e) { + _logger.severe('Could not spawn isolate', e); + isSpawned = false; + } + }); + } + + Future ensureSpawned() async { + if (!isSpawned) { + await init(); + } + } + + /// The main execution function of the isolate. 
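Editorial sketch: the spawn-and-handshake pattern this service uses (spawn the worker with our SendPort, the worker replies with its own port, every request then carries a per-call reply port), reduced to a hypothetical echo worker:

import 'dart:isolate';

// Worker entry point: hand back a SendPort first, then serve requests.
void _echoMain(SendPort mainSendPort) {
  final port = ReceivePort();
  mainSendPort.send(port.sendPort); // handshake
  port.listen((message) {
    final data = message[0];
    final replyTo = message[1] as SendPort;
    replyTo.send('echo: $data');
  });
}

Future<void> main() async {
  final receivePort = ReceivePort();
  await Isolate.spawn(_echoMain, receivePort.sendPort);
  final workerPort = await receivePort.first as SendPort; // handshake result

  final answerPort = ReceivePort();
  workerPort.send(['hello', answerPort.sendPort]);
  print(await answerPort.first); // echo: hello
}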
+ static void _isolateMain(SendPort mainSendPort) async { + final receivePort = ReceivePort(); + mainSendPort.send(receivePort.sendPort); + + receivePort.listen((message) async { + final functionIndex = message[0] as int; + final function = ClusterOperation.values[functionIndex]; + final args = message[1] as Map; + final sendPort = message[2] as SendPort; + + try { + switch (function) { + case ClusterOperation.linearIncrementalClustering: + final input = args['input'] as Map; + final result = FaceLinearClustering._runLinearClustering(input); + sendPort.send(result); + break; + } + } catch (e, stackTrace) { + sendPort + .send({'error': e.toString(), 'stackTrace': stackTrace.toString()}); + } + }); + } + + /// The common method to run any operation in the isolate. It sends the [message] to [_isolateMain] and waits for the result. + Future _runInIsolate( + (ClusterOperation, Map) message, + ) async { + await ensureSpawned(); + _resetInactivityTimer(); + final completer = Completer(); + final answerPort = ReceivePort(); + + _activeTasks++; + _mainSendPort.send([message.$1.index, message.$2, answerPort.sendPort]); + + answerPort.listen((receivedMessage) { + if (receivedMessage is Map && receivedMessage.containsKey('error')) { + // Handle the error + final errorMessage = receivedMessage['error']; + final errorStackTrace = receivedMessage['stackTrace']; + final exception = Exception(errorMessage); + final stackTrace = StackTrace.fromString(errorStackTrace); + completer.completeError(exception, stackTrace); + } else { + completer.complete(receivedMessage); + } + }); + _activeTasks--; + + return completer.future; + } + + /// Resets a timer that kills the isolate after a certain amount of inactivity. + /// + /// Should be called after initialization (e.g. inside `init()`) and after every call to isolate (e.g. inside `_runInIsolate()`) + void _resetInactivityTimer() { + _inactivityTimer?.cancel(); + _inactivityTimer = Timer(_inactivityDuration, () { + if (_activeTasks > 0) { + _logger.info('Tasks are still running. Delaying isolate disposal.'); + // Optionally, reschedule the timer to check again later. + _resetInactivityTimer(); + } else { + _logger.info( + 'Clustering Isolate has been inactive for ${_inactivityDuration.inSeconds} seconds with no tasks running. Killing isolate.', + ); + dispose(); + } + }); + } + + /// Disposes the isolate worker. + void dispose() { + if (!isSpawned) return; + + isSpawned = false; + _isolate.kill(); + _receivePort.close(); + _inactivityTimer?.cancel(); + } + + /// Runs the clustering algorithm on the given [input], in an isolate. + /// + /// Returns the clustering result, which is a list of clusters, where each cluster is a list of indices of the dataset. + /// + /// WARNING: Make sure to always input data in the same ordering, otherwise the clustering can less less deterministic. 
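Editorial sketch: stripped of the FaceInfo/EVector plumbing and the incremental-ID bookkeeping, the clustering below is a greedy single pass: each embedding joins the cluster of its nearest predecessor if that distance is under the threshold, otherwise it opens a new cluster. A simplified standalone version over plain lists of unit vectors:

Map<int, int> greedyLinearCluster(
  List<List<double>> embeddings, {
  double threshold = 0.3, // recommendedDistanceThreshold
}) {
  final clusterOf = <int, int>{};
  var nextClusterId = 1;
  for (int i = 0; i < embeddings.length; i++) {
    int closest = -1;
    var closestDist = double.infinity;
    for (int j = 0; j < i; j++) {
      double dot = 0;
      for (int k = 0; k < embeddings[i].length; k++) {
        dot += embeddings[i][k] * embeddings[j][k];
      }
      final dist = 1.0 - dot; // cosine distance for unit vectors
      if (dist < closestDist) {
        closestDist = dist;
        closest = j;
      }
    }
    if (closest != -1 && closestDist < threshold) {
      clusterOf[i] = clusterOf[closest]!; // join nearest neighbour's cluster
    } else {
      clusterOf[i] = nextClusterId++; // open a new cluster
    }
  }
  return clusterOf;
}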
+ Future?> predict( + Map input, + ) async { + if (input.isEmpty) { + _logger.warning( + "Clustering dataset of embeddings is empty, returning empty list.", + ); + return null; + } + if (isRunning) { + _logger.warning("Clustering is already running, returning empty list."); + return null; + } + + isRunning = true; + + // Clustering inside the isolate + _logger.info( + "Start clustering on ${input.length} embeddings inside computer isolate", + ); + final stopwatchClustering = Stopwatch()..start(); + // final Map faceIdToCluster = + // await _runLinearClusteringInComputer(input); + final Map faceIdToCluster = await _runInIsolate( + (ClusterOperation.linearIncrementalClustering, {'input': input}), + ); + // return _runLinearClusteringInComputer(input); + _logger.info( + 'Clustering executed in ${stopwatchClustering.elapsed.inSeconds} seconds', + ); + + isRunning = false; + + return faceIdToCluster; + } + + static Map _runLinearClustering( + Map x, + ) { + log( + "[ClusterIsolate] ${DateTime.now()} Copied to isolate ${x.length} faces", + ); + final List faceInfos = []; + for (final entry in x.entries) { + faceInfos.add( + FaceInfo( + faceID: entry.key, + embedding: EVector.fromBuffer(entry.value.$2).values, + clusterId: entry.value.$1, + ), + ); + } + // Sort the faceInfos such that the ones with null clusterId are at the end + faceInfos.sort((a, b) { + if (a.clusterId == null && b.clusterId == null) { + return 0; + } else if (a.clusterId == null) { + return 1; + } else if (b.clusterId == null) { + return -1; + } else { + return 0; + } + }); + // Count the amount of null values at the end + int nullCount = 0; + for (final faceInfo in faceInfos.reversed) { + if (faceInfo.clusterId == null) { + nullCount++; + } else { + break; + } + } + log( + "[ClusterIsolate] ${DateTime.now()} Clustering $nullCount new faces without clusterId, and ${faceInfos.length - nullCount} faces with clusterId", + ); + for (final clusteredFaceInfo + in faceInfos.sublist(0, faceInfos.length - nullCount)) { + assert(clusteredFaceInfo.clusterId != null); + } + + final int totalFaces = faceInfos.length; + int clusterID = 1; + if (faceInfos.isNotEmpty) { + faceInfos.first.clusterId = clusterID; + } + log( + "[ClusterIsolate] ${DateTime.now()} Processing $totalFaces faces", + ); + final stopwatchClustering = Stopwatch()..start(); + for (int i = 1; i < totalFaces; i++) { + // Incremental clustering, so we can skip faces that already have a clusterId + if (faceInfos[i].clusterId != null) { + clusterID = max(clusterID, faceInfos[i].clusterId!); + continue; + } + final currentEmbedding = faceInfos[i].embedding; + int closestIdx = -1; + double closestDistance = double.infinity; + if (i % 250 == 0) { + log("[ClusterIsolate] ${DateTime.now()} Processing $i faces"); + } + for (int j = 0; j < i; j++) { + final double distance = cosineDistForNormVectors( + currentEmbedding, + faceInfos[j].embedding, + ); + if (distance < closestDistance) { + closestDistance = distance; + closestIdx = j; + } + } + + if (closestDistance < recommendedDistanceThreshold) { + if (faceInfos[closestIdx].clusterId == null) { + // Ideally this should never happen, but just in case log it + log( + " [ClusterIsolate] ${DateTime.now()} Found new cluster $clusterID", + ); + clusterID++; + faceInfos[closestIdx].clusterId = clusterID; + } + faceInfos[i].clusterId = faceInfos[closestIdx].clusterId; + } else { + clusterID++; + faceInfos[i].clusterId = clusterID; + } + } + final Map result = {}; + for (final faceInfo in faceInfos) { + result[faceInfo.faceID] = 
faceInfo.clusterId!; + } + stopwatchClustering.stop(); + log( + ' [ClusterIsolate] ${DateTime.now()} Clustering for ${faceInfos.length} embeddings (${faceInfos[0].embedding.length} size) executed in ${stopwatchClustering.elapsedMilliseconds}ms, clusters $clusterID', + ); + // return result; + + // NOTe: The main clustering logic is done, the following is just filtering and logging + final input = x; + final faceIdToCluster = result; + stopwatchClustering.reset(); + stopwatchClustering.start(); + + final Set newFaceIds = {}; + input.forEach((key, value) { + if (value.$1 == null) { + newFaceIds.add(key); + } + }); + + // Find faceIDs that are part of a cluster which is larger than 5 and are new faceIDs + final Map clusterIdToSize = {}; + faceIdToCluster.forEach((key, value) { + if (clusterIdToSize.containsKey(value)) { + clusterIdToSize[value] = clusterIdToSize[value]! + 1; + } else { + clusterIdToSize[value] = 1; + } + }); + final Map faceIdToClusterFiltered = {}; + for (final entry in faceIdToCluster.entries) { + if (clusterIdToSize[entry.value]! > 0 && newFaceIds.contains(entry.key)) { + faceIdToClusterFiltered[entry.key] = entry.value; + } + } + + // print top 10 cluster ids and their sizes based on the internal cluster id + final clusterIds = faceIdToCluster.values.toSet(); + final clusterSizes = clusterIds.map((clusterId) { + return faceIdToCluster.values.where((id) => id == clusterId).length; + }).toList(); + clusterSizes.sort(); + // find clusters whose size is graeter than 1 + int oneClusterCount = 0; + int moreThan5Count = 0; + int moreThan10Count = 0; + int moreThan20Count = 0; + int moreThan50Count = 0; + int moreThan100Count = 0; + + // for (int i = 0; i < clusterSizes.length; i++) { + // if (clusterSizes[i] > 100) { + // moreThan100Count++; + // } else if (clusterSizes[i] > 50) { + // moreThan50Count++; + // } else if (clusterSizes[i] > 20) { + // moreThan20Count++; + // } else if (clusterSizes[i] > 10) { + // moreThan10Count++; + // } else if (clusterSizes[i] > 5) { + // moreThan5Count++; + // } else if (clusterSizes[i] == 1) { + // oneClusterCount++; + // } + // } + for (int i = 0; i < clusterSizes.length; i++) { + if (clusterSizes[i] > 100) { + moreThan100Count++; + } + if (clusterSizes[i] > 50) { + moreThan50Count++; + } + if (clusterSizes[i] > 20) { + moreThan20Count++; + } + if (clusterSizes[i] > 10) { + moreThan10Count++; + } + if (clusterSizes[i] > 5) { + moreThan5Count++; + } + if (clusterSizes[i] == 1) { + oneClusterCount++; + } + } + // print the metrics + log( + '[ClusterIsolate] Total clusters ${clusterIds.length}, ' + 'oneClusterCount $oneClusterCount, ' + 'moreThan5Count $moreThan5Count, ' + 'moreThan10Count $moreThan10Count, ' + 'moreThan20Count $moreThan20Count, ' + 'moreThan50Count $moreThan50Count, ' + 'moreThan100Count $moreThan100Count', + ); + stopwatchClustering.stop(); + log( + "[ClusterIsolate] Clustering additional steps took ${stopwatchClustering.elapsedMilliseconds} ms", + ); + + // log('Top clusters count ${clusterSizes.reversed.take(10).toList()}'); + return faceIdToClusterFiltered; + } +} diff --git a/mobile/lib/services/face_ml/face_detection/detection.dart b/mobile/lib/services/face_ml/face_detection/detection.dart new file mode 100644 index 000000000..5a3d12606 --- /dev/null +++ b/mobile/lib/services/face_ml/face_detection/detection.dart @@ -0,0 +1,469 @@ +import 'dart:convert' show utf8; +import 'dart:math' show sqrt, pow; +import 'dart:ui' show Size; +import 'package:crypto/crypto.dart' show sha256; + +abstract class Detection { + final 
double score; + + Detection({required this.score}); + + const Detection.empty() : score = 0; + + get width; + get height; + + @override + String toString(); +} + +extension BBoxExtension on List { + void roundBoxToDouble() { + final widthRounded = (this[2] - this[0]).roundToDouble(); + final heightRounded = (this[3] - this[1]).roundToDouble(); + this[0] = this[0].roundToDouble(); + this[1] = this[1].roundToDouble(); + this[2] = this[0] + widthRounded; + this[3] = this[1] + heightRounded; + } + + // double get xMinBox => + // isNotEmpty ? this[0] : throw IndexError.withLength(0, length); + // double get yMinBox => + // length >= 2 ? this[1] : throw IndexError.withLength(1, length); + // double get xMaxBox => + // length >= 3 ? this[2] : throw IndexError.withLength(2, length); + // double get yMaxBox => + // length >= 4 ? this[3] : throw IndexError.withLength(3, length); +} + +/// This class represents a face detection with relative coordinates in the range [0, 1]. +/// The coordinates are relative to the image size. The pattern for the coordinates is always [x, y], where x is the horizontal coordinate and y is the vertical coordinate. +/// +/// The [score] attribute is a double representing the confidence of the face detection. +/// +/// The [box] attribute is a list of 4 doubles, representing the coordinates of the bounding box of the face detection. +/// The four values of the box in order are: [xMinBox, yMinBox, xMaxBox, yMaxBox]. +/// +/// The [allKeypoints] attribute is a list of 6 lists of 2 doubles, representing the coordinates of the keypoints of the face detection. +/// The six lists of two values in order are: [leftEye, rightEye, nose, mouth, leftEar, rightEar]. Again, all in [x, y] order. +class FaceDetectionRelative extends Detection { + final List box; + final List> allKeypoints; + + double get xMinBox => box[0]; + double get yMinBox => box[1]; + double get xMaxBox => box[2]; + double get yMaxBox => box[3]; + + List get leftEye => allKeypoints[0]; + List get rightEye => allKeypoints[1]; + List get nose => allKeypoints[2]; + List get leftMouth => allKeypoints[3]; + List get rightMouth => allKeypoints[4]; + + FaceDetectionRelative({ + required double score, + required List box, + required List> allKeypoints, + }) : assert( + box.every((e) => e >= -0.1 && e <= 1.1), + "Bounding box values must be in the range [0, 1], with only a small margin of error allowed.", + ), + assert( + allKeypoints + .every((sublist) => sublist.every((e) => e >= -0.1 && e <= 1.1)), + "All keypoints must be in the range [0, 1], with only a small margin of error allowed.", + ), + box = List.from(box.map((e) => e.clamp(0.0, 1.0))), + allKeypoints = allKeypoints + .map( + (sublist) => + List.from(sublist.map((e) => e.clamp(0.0, 1.0))), + ) + .toList(), + super(score: score); + + factory FaceDetectionRelative.zero() { + return FaceDetectionRelative( + score: 0, + box: [0, 0, 0, 0], + allKeypoints: >[ + [0, 0], + [0, 0], + [0, 0], + [0, 0], + [0, 0], + ], + ); + } + + /// This is used to initialize the FaceDetectionRelative object with default values. + /// This constructor is useful because it can be used to initialize a FaceDetectionRelative object as a constant. + /// Contrary to the `FaceDetectionRelative.zero()` constructor, this one gives immutable attributes [box] and [allKeypoints]. 
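Editorial note: the relative-coordinate convention above makes conversions to pixel space plain multiplications. A tiny worked example with hypothetical image dimensions:

void main() {
  const imageWidth = 4000;
  const imageHeight = 3000;
  // A relative box [xMin, yMin, xMax, yMax] in the [0, 1] range.
  final box = <double>[0.25, 0.10, 0.75, 0.60];
  final absolute = <double>[
    box[0] * imageWidth, // 1000
    box[1] * imageHeight, // 300
    box[2] * imageWidth, // 3000
    box[3] * imageHeight, // 1800
  ];
  print(absolute); // [1000.0, 300.0, 3000.0, 1800.0]
}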
+ FaceDetectionRelative.defaultInitialization() + : box = const [0, 0, 0, 0], + allKeypoints = const >[ + [0, 0], + [0, 0], + [0, 0], + [0, 0], + [0, 0], + ], + super.empty(); + + FaceDetectionRelative getNearestDetection( + List detections, + ) { + if (detections.isEmpty) { + throw ArgumentError("The detection list cannot be empty."); + } + + var nearestDetection = detections[0]; + var minDistance = double.infinity; + + // Calculate the center of the current instance + final centerX1 = (xMinBox + xMaxBox) / 2; + final centerY1 = (yMinBox + yMaxBox) / 2; + + for (var detection in detections) { + final centerX2 = (detection.xMinBox + detection.xMaxBox) / 2; + final centerY2 = (detection.yMinBox + detection.yMaxBox) / 2; + final distance = + sqrt(pow(centerX2 - centerX1, 2) + pow(centerY2 - centerY1, 2)); + if (distance < minDistance) { + minDistance = distance; + nearestDetection = detection; + } + } + return nearestDetection; + } + + void transformRelativeToOriginalImage( + List fromBox, // [xMin, yMin, xMax, yMax] + List toBox, // [xMin, yMin, xMax, yMax] + ) { + // Return if all elements of fromBox and toBox are equal + for (int i = 0; i < fromBox.length; i++) { + if (fromBox[i] != toBox[i]) { + break; + } + if (i == fromBox.length - 1) { + return; + } + } + + // Account for padding + final double paddingXRatio = + (fromBox[0] - toBox[0]) / (toBox[2] - toBox[0]); + final double paddingYRatio = + (fromBox[1] - toBox[1]) / (toBox[3] - toBox[1]); + + // Calculate the scaling and translation + final double scaleX = (fromBox[2] - fromBox[0]) / (1 - 2 * paddingXRatio); + final double scaleY = (fromBox[3] - fromBox[1]) / (1 - 2 * paddingYRatio); + final double translateX = fromBox[0] - paddingXRatio * scaleX; + final double translateY = fromBox[1] - paddingYRatio * scaleY; + + // Transform Box + _transformBox(box, scaleX, scaleY, translateX, translateY); + + // Transform All Keypoints + for (int i = 0; i < allKeypoints.length; i++) { + allKeypoints[i] = _transformPoint( + allKeypoints[i], + scaleX, + scaleY, + translateX, + translateY, + ); + } + } + + void correctForMaintainedAspectRatio( + Size originalSize, + Size newSize, + ) { + // Return if both are the same size, meaning no scaling was done on both width and height + if (originalSize == newSize) { + return; + } + + // Calculate the scaling + final double scaleX = originalSize.width / newSize.width; + final double scaleY = originalSize.height / newSize.height; + const double translateX = 0; + const double translateY = 0; + + // Transform Box + _transformBox(box, scaleX, scaleY, translateX, translateY); + + // Transform All Keypoints + for (int i = 0; i < allKeypoints.length; i++) { + allKeypoints[i] = _transformPoint( + allKeypoints[i], + scaleX, + scaleY, + translateX, + translateY, + ); + } + } + + void _transformBox( + List box, + double scaleX, + double scaleY, + double translateX, + double translateY, + ) { + box[0] = (box[0] * scaleX + translateX).clamp(0.0, 1.0); + box[1] = (box[1] * scaleY + translateY).clamp(0.0, 1.0); + box[2] = (box[2] * scaleX + translateX).clamp(0.0, 1.0); + box[3] = (box[3] * scaleY + translateY).clamp(0.0, 1.0); + } + + List _transformPoint( + List point, + double scaleX, + double scaleY, + double translateX, + double translateY, + ) { + return [ + (point[0] * scaleX + translateX).clamp(0.0, 1.0), + (point[1] * scaleY + translateY).clamp(0.0, 1.0), + ]; + } + + FaceDetectionAbsolute toAbsolute({ + required int imageWidth, + required int imageHeight, + }) { + final scoreCopy = score; + final boxCopy = 
List.from(box, growable: false); + final allKeypointsCopy = allKeypoints + .map((sublist) => List.from(sublist, growable: false)) + .toList(); + + boxCopy[0] *= imageWidth; + boxCopy[1] *= imageHeight; + boxCopy[2] *= imageWidth; + boxCopy[3] *= imageHeight; + // final intbox = boxCopy.map((e) => e.toInt()).toList(); + + for (List keypoint in allKeypointsCopy) { + keypoint[0] *= imageWidth; + keypoint[1] *= imageHeight; + } + // final intKeypoints = + // allKeypointsCopy.map((e) => e.map((e) => e.toInt()).toList()).toList(); + return FaceDetectionAbsolute( + score: scoreCopy, + box: boxCopy, + allKeypoints: allKeypointsCopy, + ); + } + + String toFaceID({required int fileID}) { + // Assert that the values are within the expected range + assert( + (xMinBox >= 0 && xMinBox <= 1) && + (yMinBox >= 0 && yMinBox <= 1) && + (xMaxBox >= 0 && xMaxBox <= 1) && + (yMaxBox >= 0 && yMaxBox <= 1), + "Bounding box values must be in the range [0, 1]", + ); + + // Extract bounding box values + final String xMin = + xMinBox.clamp(0.0, 0.999999).toStringAsFixed(5).substring(2); + final String yMin = + yMinBox.clamp(0.0, 0.999999).toStringAsFixed(5).substring(2); + final String xMax = + xMaxBox.clamp(0.0, 0.999999).toStringAsFixed(5).substring(2); + final String yMax = + yMaxBox.clamp(0.0, 0.999999).toStringAsFixed(5).substring(2); + + // Convert the bounding box values to string and concatenate + final String rawID = "${xMin}_${yMin}_${xMax}_$yMax"; + + // Hash the concatenated string using SHA256 + final digest = sha256.convert(utf8.encode(rawID)); + + // Return the hexadecimal representation of the hash + return fileID.toString() + '_' + digest.toString(); + } + + /// This method is used to generate a faceID for a face detection that was manually added by the user. + static String toFaceIDEmpty({required int fileID}) { + return fileID.toString() + '_0'; + } + + /// This method is used to check if a faceID corresponds to a manually added face detection and not an actual face detection. + static bool isFaceIDEmpty(String faceID) { + return faceID.split('_')[1] == '0'; + } + + @override + String toString() { + return 'FaceDetectionRelative( with relative coordinates: \n score: $score \n Box: xMinBox: $xMinBox, yMinBox: $yMinBox, xMaxBox: $xMaxBox, yMaxBox: $yMaxBox, \n Keypoints: leftEye: $leftEye, rightEye: $rightEye, nose: $nose, leftMouth: $leftMouth, rightMouth: $rightMouth \n )'; + } + + Map toJson() { + return { + 'score': score, + 'box': box, + 'allKeypoints': allKeypoints, + }; + } + + factory FaceDetectionRelative.fromJson(Map json) { + return FaceDetectionRelative( + score: (json['score'] as num).toDouble(), + box: List.from(json['box']), + allKeypoints: (json['allKeypoints'] as List) + .map((item) => List.from(item)) + .toList(), + ); + } + + @override + + /// The width of the bounding box of the face detection, in relative range [0, 1]. + double get width => xMaxBox - xMinBox; + @override + + /// The height of the bounding box of the face detection, in relative range [0, 1]. + double get height => yMaxBox - yMinBox; +} + +/// This class represents a face detection with absolute coordinates in pixels, in the range [0, imageWidth] for the horizontal coordinates and [0, imageHeight] for the vertical coordinates. +/// The pattern for the coordinates is always [x, y], where x is the horizontal coordinate and y is the vertical coordinate. +/// +/// The [score] attribute is a double representing the confidence of the face detection. 
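Editorial sketch of the toFaceID scheme above, using the same crypto calls as the diff (values hypothetical): each bounding-box coordinate is clamped, printed with five decimals, stripped of the leading "0.", joined, and hashed, so a given file and box always yield the same ID.

import 'dart:convert' show utf8;
import 'package:crypto/crypto.dart' show sha256;

String faceIdFor(int fileID, List<double> relativeBox) {
  final parts = relativeBox
      .map((v) => v.clamp(0.0, 0.999999).toStringAsFixed(5).substring(2))
      .toList();
  final rawID = parts.join('_'); // e.g. 25000_10000_75000_60000
  final digest = sha256.convert(utf8.encode(rawID));
  return '${fileID}_$digest';
}

void main() {
  print(faceIdFor(42, [0.25, 0.10, 0.75, 0.60]));
}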
+/// +/// The [box] attribute is a list of 4 integers, representing the coordinates of the bounding box of the face detection. +/// The four values of the box in order are: [xMinBox, yMinBox, xMaxBox, yMaxBox]. +/// +/// The [allKeypoints] attribute is a list of 6 lists of 2 integers, representing the coordinates of the keypoints of the face detection. +/// The six lists of two values in order are: [leftEye, rightEye, nose, mouth, leftEar, rightEar]. Again, all in [x, y] order. +class FaceDetectionAbsolute extends Detection { + final List box; + final List> allKeypoints; + + double get xMinBox => box[0]; + double get yMinBox => box[1]; + double get xMaxBox => box[2]; + double get yMaxBox => box[3]; + + List get leftEye => allKeypoints[0]; + List get rightEye => allKeypoints[1]; + List get nose => allKeypoints[2]; + List get leftMouth => allKeypoints[3]; + List get rightMouth => allKeypoints[4]; + + FaceDetectionAbsolute({ + required double score, + required this.box, + required this.allKeypoints, + }) : super(score: score); + + factory FaceDetectionAbsolute._zero() { + return FaceDetectionAbsolute( + score: 0, + box: [0, 0, 0, 0], + allKeypoints: >[ + [0, 0], + [0, 0], + [0, 0], + [0, 0], + [0, 0], + ], + ); + } + + FaceDetectionAbsolute.defaultInitialization() + : box = const [0, 0, 0, 0], + allKeypoints = const >[ + [0, 0], + [0, 0], + [0, 0], + [0, 0], + [0, 0], + ], + super.empty(); + + @override + String toString() { + return 'FaceDetectionAbsolute( with absolute coordinates: \n score: $score \n Box: xMinBox: $xMinBox, yMinBox: $yMinBox, xMaxBox: $xMaxBox, yMaxBox: $yMaxBox, \n Keypoints: leftEye: $leftEye, rightEye: $rightEye, nose: $nose, leftMouth: $leftMouth, rightMouth: $rightMouth \n )'; + } + + Map toJson() { + return { + 'score': score, + 'box': box, + 'allKeypoints': allKeypoints, + }; + } + + factory FaceDetectionAbsolute.fromJson(Map json) { + return FaceDetectionAbsolute( + score: (json['score'] as num).toDouble(), + box: List.from(json['box']), + allKeypoints: (json['allKeypoints'] as List) + .map((item) => List.from(item)) + .toList(), + ); + } + + static FaceDetectionAbsolute empty = FaceDetectionAbsolute._zero(); + + @override + + /// The width of the bounding box of the face detection, in number of pixels, range [0, imageWidth]. + double get width => xMaxBox - xMinBox; + @override + + /// The height of the bounding box of the face detection, in number of pixels, range [0, imageHeight]. + double get height => yMaxBox - yMinBox; +} + +List relativeToAbsoluteDetections({ + required List relativeDetections, + required int imageWidth, + required int imageHeight, +}) { + final numberOfDetections = relativeDetections.length; + final absoluteDetections = List.filled( + numberOfDetections, + FaceDetectionAbsolute._zero(), + ); + for (var i = 0; i < relativeDetections.length; i++) { + final relativeDetection = relativeDetections[i]; + final absoluteDetection = relativeDetection.toAbsolute( + imageWidth: imageWidth, + imageHeight: imageHeight, + ); + + absoluteDetections[i] = absoluteDetection; + } + + return absoluteDetections; +} + +/// Returns an enlarged version of the [box] by a factor of [factor]. +List getEnlargedRelativeBox(List box, [double factor = 2]) { + final boxCopy = List.from(box, growable: false); + // The four values of the box in order are: [xMinBox, yMinBox, xMaxBox, yMaxBox]. 
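// Editorial worked example: enlarging box = [0.4, 0.4, 0.6, 0.6] by
// factor = 2 keeps the centre (0.5, 0.5) and doubles each side:
//   width = height = 0.2, so every edge moves outward by
//   width * (factor - 1) / 2 = 0.1, giving [0.3, 0.3, 0.7, 0.7].
// The result is not clamped here, so it can leave the [0, 1] range.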
+ + final width = boxCopy[2] - boxCopy[0]; + final height = boxCopy[3] - boxCopy[1]; + + boxCopy[0] -= width * (factor - 1) / 2; + boxCopy[1] -= height * (factor - 1) / 2; + boxCopy[2] += width * (factor - 1) / 2; + boxCopy[3] += height * (factor - 1) / 2; + + return boxCopy; +} diff --git a/mobile/lib/services/face_ml/face_detection/naive_non_max_suppression.dart b/mobile/lib/services/face_ml/face_detection/naive_non_max_suppression.dart new file mode 100644 index 000000000..ca1e4aba5 --- /dev/null +++ b/mobile/lib/services/face_ml/face_detection/naive_non_max_suppression.dart @@ -0,0 +1,49 @@ +import 'dart:math' as math show max, min; + +import "package:photos/services/face_ml/face_detection/detection.dart"; + +List naiveNonMaxSuppression({ + required List detections, + required double iouThreshold, +}) { + // Sort the detections by score, the highest first + detections.sort((a, b) => b.score.compareTo(a.score)); + + // Loop through the detections and calculate the IOU + for (var i = 0; i < detections.length - 1; i++) { + for (var j = i + 1; j < detections.length; j++) { + final iou = _calculateIOU(detections[i], detections[j]); + if (iou >= iouThreshold) { + detections.removeAt(j); + j--; + } + } + } + return detections; +} + +double _calculateIOU( + FaceDetectionRelative detectionA, + FaceDetectionRelative detectionB, +) { + final areaA = detectionA.width * detectionA.height; + final areaB = detectionB.width * detectionB.height; + + final intersectionMinX = math.max(detectionA.xMinBox, detectionB.xMinBox); + final intersectionMinY = math.max(detectionA.yMinBox, detectionB.yMinBox); + final intersectionMaxX = math.min(detectionA.xMaxBox, detectionB.xMaxBox); + final intersectionMaxY = math.min(detectionA.yMaxBox, detectionB.yMaxBox); + + final intersectionWidth = intersectionMaxX - intersectionMinX; + final intersectionHeight = intersectionMaxY - intersectionMinY; + + if (intersectionWidth < 0 || intersectionHeight < 0) { + return 0.0; // If boxes do not overlap, IoU is 0 + } + + final intersectionArea = intersectionWidth * intersectionHeight; + + final unionArea = areaA + areaB - intersectionArea; + + return intersectionArea / unionArea; +} diff --git a/mobile/lib/services/face_ml/face_detection/yolov5face/onnx_face_detection.dart b/mobile/lib/services/face_ml/face_detection/yolov5face/onnx_face_detection.dart new file mode 100644 index 000000000..94cfc2fdc --- /dev/null +++ b/mobile/lib/services/face_ml/face_detection/yolov5face/onnx_face_detection.dart @@ -0,0 +1,786 @@ +import "dart:async"; +import "dart:developer" as dev show log; +import "dart:io" show File; +import "dart:isolate"; +import 'dart:typed_data' show Float32List, Uint8List; + +import "package:computer/computer.dart"; +import 'package:flutter/material.dart'; +import 'package:logging/logging.dart'; +import 'package:onnxruntime/onnxruntime.dart'; +import "package:photos/services/face_ml/face_detection/detection.dart"; +import "package:photos/services/face_ml/face_detection/naive_non_max_suppression.dart"; +import "package:photos/services/face_ml/face_detection/yolov5face/yolo_face_detection_exceptions.dart"; +import "package:photos/services/face_ml/face_detection/yolov5face/yolo_filter_extract_detections.dart"; +import "package:photos/services/remote_assets_service.dart"; +import "package:photos/utils/image_ml_isolate.dart"; +import "package:photos/utils/image_ml_util.dart"; +import "package:synchronized/synchronized.dart"; + +enum FaceDetectionOperation { yoloInferenceAndPostProcessing } + +class YoloOnnxFaceDetection { 
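Editorial sketch: the IoU used by naiveNonMaxSuppression above is easy to sanity-check on paper. A standalone version over plain [xMin, yMin, xMax, yMax] lists, with a worked case:

import 'dart:math' as math;

double iou(List<double> a, List<double> b) {
  final interW = math.min(a[2], b[2]) - math.max(a[0], b[0]);
  final interH = math.min(a[3], b[3]) - math.max(a[1], b[1]);
  if (interW <= 0 || interH <= 0) return 0.0; // disjoint boxes
  final inter = interW * interH;
  final areaA = (a[2] - a[0]) * (a[3] - a[1]);
  final areaB = (b[2] - b[0]) * (b[3] - b[1]);
  return inter / (areaA + areaB - inter);
}

void main() {
  // Two unit squares offset by half a width:
  // intersection = 0.5, union = 1.5, IoU = 1/3.
  print(iou([0.0, 0.0, 1.0, 1.0], [0.5, 0.0, 1.5, 1.0])); // 0.333...
}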
+ static final _logger = Logger('YOLOFaceDetectionService'); + + final _computer = Computer.shared(); + + int sessionAddress = 0; + + static const kModelBucketEndpoint = "https://models.ente.io/"; + static const kRemoteBucketModelPath = "yolov5s_face_640_640_dynamic.onnx"; + // static const kRemoteBucketModelPath = "yolov5n_face_640_640.onnx"; + static const modelRemotePath = kModelBucketEndpoint + kRemoteBucketModelPath; + + static const kInputWidth = 640; + static const kInputHeight = 640; + static const kIouThreshold = 0.4; + static const kMinScoreSigmoidThreshold = 0.8; + + bool isInitialized = false; + + // Isolate things + Timer? _inactivityTimer; + final Duration _inactivityDuration = const Duration(seconds: 30); + + final _initLock = Lock(); + final _computerLock = Lock(); + + late Isolate _isolate; + late ReceivePort _receivePort = ReceivePort(); + late SendPort _mainSendPort; + + bool isSpawned = false; + bool isRunning = false; + + // singleton pattern + YoloOnnxFaceDetection._privateConstructor(); + + /// Use this instance to access the FaceDetection service. Make sure to call `init()` before using it. + /// e.g. `await FaceDetection.instance.init();` + /// + /// Then you can use `predict()` to get the bounding boxes of the faces, so `FaceDetection.instance.predict(imageData)` + /// + /// config options: yoloV5FaceN // + static final instance = YoloOnnxFaceDetection._privateConstructor(); + + factory YoloOnnxFaceDetection() => instance; + + /// Check if the interpreter is initialized, if not initialize it with `loadModel()` + Future init() async { + if (!isInitialized) { + _logger.info('init is called'); + final model = + await RemoteAssetsService.instance.getAsset(modelRemotePath); + final startTime = DateTime.now(); + // Doing this from main isolate since `rootBundle` cannot be accessed outside it + sessionAddress = await _computer.compute( + _loadModel, + param: { + "modelPath": model.path, + }, + ); + final endTime = DateTime.now(); + _logger.info( + "Face detection model loaded, took: ${(endTime.millisecondsSinceEpoch - startTime.millisecondsSinceEpoch).toString()}ms", + ); + if (sessionAddress != -1) { + isInitialized = true; + } + } + } + + Future release() async { + if (isInitialized) { + await _computer + .compute(_releaseModel, param: {'address': sessionAddress}); + isInitialized = false; + sessionAddress = 0; + } + } + + Future initIsolate() async { + return _initLock.synchronized(() async { + if (isSpawned) return; + + _receivePort = ReceivePort(); + + try { + _isolate = await Isolate.spawn( + _isolateMain, + _receivePort.sendPort, + ); + _mainSendPort = await _receivePort.first as SendPort; + isSpawned = true; + + _resetInactivityTimer(); + } catch (e) { + _logger.severe('Could not spawn isolate', e); + isSpawned = false; + } + }); + } + + Future ensureSpawnedIsolate() async { + if (!isSpawned) { + await initIsolate(); + } + } + + /// The main execution function of the isolate. 
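Editorial sketch: the service loads one OrtSession per model and reuses it everywhere, passing only the integer session address across isolates. The load-once / run-many pattern, using the same onnxruntime calls this file itself uses (path and shape hypothetical):

import 'dart:io' show File;
import 'dart:typed_data' show Float32List;

import 'package:onnxruntime/onnxruntime.dart';

Future<void> runOnce(String modelPath, Float32List input) async {
  final options = OrtSessionOptions();
  final session = OrtSession.fromFile(File(modelPath), options);
  final inputOrt = OrtValueTensor.createTensorWithDataList(
    input,
    [1, 3, 640, 640], // NCHW, the YOLOv5Face input shape used here
  );
  final outputs = session.run(OrtRunOptions(), {'input': inputOrt});
  // ... post-process `outputs` ...
  for (final o in outputs) {
    o?.release();
  }
  inputOrt.release();
  session.release();
}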
+ static void _isolateMain(SendPort mainSendPort) async { + final receivePort = ReceivePort(); + mainSendPort.send(receivePort.sendPort); + + receivePort.listen((message) async { + final functionIndex = message[0] as int; + final function = FaceDetectionOperation.values[functionIndex]; + final args = message[1] as Map; + final sendPort = message[2] as SendPort; + + try { + switch (function) { + case FaceDetectionOperation.yoloInferenceAndPostProcessing: + final inputImageList = args['inputImageList'] as Float32List; + final inputShape = args['inputShape'] as List; + final newSize = args['newSize'] as Size; + final sessionAddress = args['sessionAddress'] as int; + final timeSentToIsolate = args['timeNow'] as DateTime; + final delaySentToIsolate = + DateTime.now().difference(timeSentToIsolate).inMilliseconds; + + final Stopwatch stopwatchPrepare = Stopwatch()..start(); + final inputOrt = OrtValueTensor.createTensorWithDataList( + inputImageList, + inputShape, + ); + final inputs = {'input': inputOrt}; + stopwatchPrepare.stop(); + dev.log( + '[YOLOFaceDetectionService] data preparation is finished, in ${stopwatchPrepare.elapsedMilliseconds}ms', + ); + + stopwatchPrepare.reset(); + stopwatchPrepare.start(); + final runOptions = OrtRunOptions(); + final session = OrtSession.fromAddress(sessionAddress); + stopwatchPrepare.stop(); + dev.log( + '[YOLOFaceDetectionService] session preparation is finished, in ${stopwatchPrepare.elapsedMilliseconds}ms', + ); + + final stopwatchInterpreter = Stopwatch()..start(); + late final List outputs; + try { + outputs = session.run(runOptions, inputs); + } catch (e, s) { + dev.log( + '[YOLOFaceDetectionService] Error while running inference: $e \n $s', + ); + throw YOLOInterpreterRunException(); + } + stopwatchInterpreter.stop(); + dev.log( + '[YOLOFaceDetectionService] interpreter.run is finished, in ${stopwatchInterpreter.elapsedMilliseconds} ms', + ); + + final relativeDetections = + _yoloPostProcessOutputs(outputs, newSize); + + sendPort + .send((relativeDetections, delaySentToIsolate, DateTime.now())); + break; + } + } catch (e, stackTrace) { + sendPort + .send({'error': e.toString(), 'stackTrace': stackTrace.toString()}); + } + }); + } + + /// The common method to run any operation in the isolate. It sends the [message] to [_isolateMain] and waits for the result. + Future _runInIsolate( + (FaceDetectionOperation, Map) message, + ) async { + await ensureSpawnedIsolate(); + _resetInactivityTimer(); + final completer = Completer(); + final answerPort = ReceivePort(); + + _mainSendPort.send([message.$1.index, message.$2, answerPort.sendPort]); + + answerPort.listen((receivedMessage) { + if (receivedMessage is Map && receivedMessage.containsKey('error')) { + // Handle the error + final errorMessage = receivedMessage['error']; + final errorStackTrace = receivedMessage['stackTrace']; + final exception = Exception(errorMessage); + final stackTrace = StackTrace.fromString(errorStackTrace); + completer.completeError(exception, stackTrace); + } else { + completer.complete(receivedMessage); + } + }); + + return completer.future; + } + + /// Resets a timer that kills the isolate after a certain amount of inactivity. + /// + /// Should be called after initialization (e.g. inside `init()`) and after every call to isolate (e.g. 
inside `_runInIsolate()`) + void _resetInactivityTimer() { + _inactivityTimer?.cancel(); + _inactivityTimer = Timer(_inactivityDuration, () { + _logger.info( + 'Face detection (YOLO ONNX) Isolate has been inactive for ${_inactivityDuration.inSeconds} seconds. Killing isolate.', + ); + disposeIsolate(); + }); + } + + /// Disposes the isolate worker. + void disposeIsolate() { + if (!isSpawned) return; + + isSpawned = false; + _isolate.kill(); + _receivePort.close(); + _inactivityTimer?.cancel(); + } + + /// Detects faces in the given image data. + Future<(List, Size)> predict( + Uint8List imageData, + ) async { + assert(isInitialized); + + final stopwatch = Stopwatch()..start(); + + final stopwatchDecoding = Stopwatch()..start(); + final (inputImageList, originalSize, newSize) = + await ImageMlIsolate.instance.preprocessImageYoloOnnx( + imageData, + normalize: true, + requiredWidth: kInputWidth, + requiredHeight: kInputHeight, + maintainAspectRatio: true, + quality: FilterQuality.medium, + ); + + // final input = [inputImageList]; + final inputShape = [ + 1, + 3, + kInputHeight, + kInputWidth, + ]; + final inputOrt = OrtValueTensor.createTensorWithDataList( + inputImageList, + inputShape, + ); + final inputs = {'input': inputOrt}; + stopwatchDecoding.stop(); + _logger.info( + 'Image decoding and preprocessing is finished, in ${stopwatchDecoding.elapsedMilliseconds}ms', + ); + _logger.info('original size: $originalSize \n new size: $newSize'); + + // Run inference + final stopwatchInterpreter = Stopwatch()..start(); + List? outputs; + try { + final runOptions = OrtRunOptions(); + final session = OrtSession.fromAddress(sessionAddress); + outputs = session.run(runOptions, inputs); + // inputOrt.release(); + // runOptions.release(); + } catch (e, s) { + _logger.severe('Error while running inference: $e \n $s'); + throw YOLOInterpreterRunException(); + } + stopwatchInterpreter.stop(); + _logger.info( + 'interpreter.run is finished, in ${stopwatchInterpreter.elapsedMilliseconds} ms', + ); + + final relativeDetections = _yoloPostProcessOutputs(outputs, newSize); + + stopwatch.stop(); + _logger.info( + 'predict() face detection executed in ${stopwatch.elapsedMilliseconds}ms', + ); + + return (relativeDetections, originalSize); + } + + /// Detects faces in the given image data. + static Future<(List, Size)> predictSync( + String imagePath, + int sessionAddress, + ) async { + assert(sessionAddress != 0 && sessionAddress != -1); + + final stopwatch = Stopwatch()..start(); + + final stopwatchDecoding = Stopwatch()..start(); + final imageData = await File(imagePath).readAsBytes(); + final (inputImageList, originalSize, newSize) = + await preprocessImageToFloat32ChannelsFirst( + imageData, + normalization: 1, + requiredWidth: kInputWidth, + requiredHeight: kInputHeight, + maintainAspectRatio: true, + quality: FilterQuality.medium, + ); + + // final input = [inputImageList]; + final inputShape = [ + 1, + 3, + kInputHeight, + kInputWidth, + ]; + final inputOrt = OrtValueTensor.createTensorWithDataList( + inputImageList, + inputShape, + ); + final inputs = {'input': inputOrt}; + stopwatchDecoding.stop(); + dev.log( + 'Face detection image decoding and preprocessing is finished, in ${stopwatchDecoding.elapsedMilliseconds}ms', + ); + _logger.info( + 'Image decoding and preprocessing is finished, in ${stopwatchDecoding.elapsedMilliseconds}ms', + ); + _logger.info('original size: $originalSize \n new size: $newSize'); + + // Run inference + final stopwatchInterpreter = Stopwatch()..start(); + List? 
outputs; + try { + final runOptions = OrtRunOptions(); + final session = OrtSession.fromAddress(sessionAddress); + outputs = session.run(runOptions, inputs); + // inputOrt.release(); + // runOptions.release(); + } catch (e, s) { + _logger.severe('Error while running inference: $e \n $s'); + throw YOLOInterpreterRunException(); + } + stopwatchInterpreter.stop(); + _logger.info( + 'interpreter.run is finished, in ${stopwatchInterpreter.elapsedMilliseconds} ms', + ); + + final relativeDetections = _yoloPostProcessOutputs(outputs, newSize); + + stopwatch.stop(); + _logger.info( + 'predict() face detection executed in ${stopwatch.elapsedMilliseconds}ms', + ); + + return (relativeDetections, originalSize); + } + + /// Detects faces in the given image data. + Future<(List, Size)> predictInIsolate( + Uint8List imageData, + ) async { + await ensureSpawnedIsolate(); + assert(isInitialized); + + _logger.info('predictInIsolate() is called'); + + final stopwatch = Stopwatch()..start(); + + final stopwatchDecoding = Stopwatch()..start(); + final (inputImageList, originalSize, newSize) = + await ImageMlIsolate.instance.preprocessImageYoloOnnx( + imageData, + normalize: true, + requiredWidth: kInputWidth, + requiredHeight: kInputHeight, + maintainAspectRatio: true, + quality: FilterQuality.medium, + ); + // final input = [inputImageList]; + final inputShape = [ + 1, + 3, + kInputHeight, + kInputWidth, + ]; + + stopwatchDecoding.stop(); + _logger.info( + 'Image decoding and preprocessing is finished, in ${stopwatchDecoding.elapsedMilliseconds}ms', + ); + _logger.info('original size: $originalSize \n new size: $newSize'); + + final ( + List relativeDetections, + delaySentToIsolate, + timeSentToMain + ) = await _runInIsolate( + ( + FaceDetectionOperation.yoloInferenceAndPostProcessing, + { + 'inputImageList': inputImageList, + 'inputShape': inputShape, + 'newSize': newSize, + 'sessionAddress': sessionAddress, + 'timeNow': DateTime.now(), + } + ), + ) as (List, int, DateTime); + + final delaySentToMain = + DateTime.now().difference(timeSentToMain).inMilliseconds; + + stopwatch.stop(); + _logger.info( + 'predictInIsolate() face detection executed in ${stopwatch.elapsedMilliseconds}ms, with ${delaySentToIsolate}ms delay sent to isolate, and ${delaySentToMain}ms delay sent to main, for a total of ${delaySentToIsolate + delaySentToMain}ms delay due to isolate', + ); + + return (relativeDetections, originalSize); + } + + Future<(List, Size)> predictInComputer( + String imagePath, + ) async { + assert(isInitialized); + + _logger.info('predictInComputer() is called'); + + final stopwatch = Stopwatch()..start(); + + final stopwatchDecoding = Stopwatch()..start(); + final imageData = await File(imagePath).readAsBytes(); + final (inputImageList, originalSize, newSize) = + await ImageMlIsolate.instance.preprocessImageYoloOnnx( + imageData, + normalize: true, + requiredWidth: kInputWidth, + requiredHeight: kInputHeight, + maintainAspectRatio: true, + quality: FilterQuality.medium, + ); + // final input = [inputImageList]; + return await _computerLock.synchronized(() async { + final inputShape = [ + 1, + 3, + kInputHeight, + kInputWidth, + ]; + + stopwatchDecoding.stop(); + _logger.info( + 'Image decoding and preprocessing is finished, in ${stopwatchDecoding.elapsedMilliseconds}ms', + ); + _logger.info('original size: $originalSize \n new size: $newSize'); + + final ( + List relativeDetections, + delaySentToIsolate, + timeSentToMain + ) = await _computer.compute( + inferenceAndPostProcess, + param: { + 'inputImageList': 
inputImageList, + 'inputShape': inputShape, + 'newSize': newSize, + 'sessionAddress': sessionAddress, + 'timeNow': DateTime.now(), + }, + ) as (List, int, DateTime); + + final delaySentToMain = + DateTime.now().difference(timeSentToMain).inMilliseconds; + + stopwatch.stop(); + _logger.info( + 'predictInIsolate() face detection executed in ${stopwatch.elapsedMilliseconds}ms, with ${delaySentToIsolate}ms delay sent to isolate, and ${delaySentToMain}ms delay sent to main, for a total of ${delaySentToIsolate + delaySentToMain}ms delay due to isolate', + ); + + return (relativeDetections, originalSize); + }); + } + + /// Detects faces in the given image data. + /// This method is optimized for batch processing. + /// + /// `imageDataList`: The image data to analyze. + /// + /// WARNING: Currently this method only returns the detections for the first image in the batch. + /// Change the function to output all detection before actually using it in production. + Future> predictBatch( + List imageDataList, + ) async { + assert(isInitialized); + + final stopwatch = Stopwatch()..start(); + + final stopwatchDecoding = Stopwatch()..start(); + final List inputImageDataLists = []; + final List<(Size, Size)> originalAndNewSizeList = []; + int concatenatedImageInputsLength = 0; + for (final imageData in imageDataList) { + final (inputImageList, originalSize, newSize) = + await ImageMlIsolate.instance.preprocessImageYoloOnnx( + imageData, + normalize: true, + requiredWidth: kInputWidth, + requiredHeight: kInputHeight, + maintainAspectRatio: true, + quality: FilterQuality.medium, + ); + inputImageDataLists.add(inputImageList); + originalAndNewSizeList.add((originalSize, newSize)); + concatenatedImageInputsLength += inputImageList.length; + } + + final inputImageList = Float32List(concatenatedImageInputsLength); + + int offset = 0; + for (int i = 0; i < inputImageDataLists.length; i++) { + final inputImageData = inputImageDataLists[i]; + inputImageList.setRange( + offset, + offset + inputImageData.length, + inputImageData, + ); + offset += inputImageData.length; + } + + // final input = [inputImageList]; + final inputShape = [ + inputImageDataLists.length, + 3, + kInputHeight, + kInputWidth, + ]; + final inputOrt = OrtValueTensor.createTensorWithDataList( + inputImageList, + inputShape, + ); + final inputs = {'input': inputOrt}; + stopwatchDecoding.stop(); + _logger.info( + 'Image decoding and preprocessing is finished, in ${stopwatchDecoding.elapsedMilliseconds}ms', + ); + // _logger.info('original size: $originalSize \n new size: $newSize'); + + _logger.info('interpreter.run is called'); + // Run inference + final stopwatchInterpreter = Stopwatch()..start(); + List? 
outputs; + try { + final runOptions = OrtRunOptions(); + final session = OrtSession.fromAddress(sessionAddress); + outputs = session.run(runOptions, inputs); + inputOrt.release(); + runOptions.release(); + } catch (e, s) { + _logger.severe('Error while running inference: $e \n $s'); + throw YOLOInterpreterRunException(); + } + stopwatchInterpreter.stop(); + _logger.info( + 'interpreter.run is finished, in ${stopwatchInterpreter.elapsedMilliseconds} ms, or ${stopwatchInterpreter.elapsedMilliseconds / inputImageDataLists.length} ms per image', + ); + + _logger.info('outputs: $outputs'); + + const int imageOutputToUse = 0; + + // // Get output tensors + final nestedResults = + outputs[0]?.value as List>>; // [b, 25200, 16] + final selectedResults = nestedResults[imageOutputToUse]; // [25200, 16] + + // final rawScores = []; + // for (final result in firstResults) { + // rawScores.add(result[4]); + // } + // final rawScoresCopy = List.from(rawScores); + // rawScoresCopy.sort(); + // _logger.info('rawScores minimum: ${rawScoresCopy.first}'); + // _logger.info('rawScores maximum: ${rawScoresCopy.last}'); + + var relativeDetections = yoloOnnxFilterExtractDetections( + kMinScoreSigmoidThreshold, + kInputWidth, + kInputHeight, + results: selectedResults, + ); + + // Release outputs + for (var element in outputs) { + element?.release(); + } + + // Account for the fact that the aspect ratio was maintained + for (final faceDetection in relativeDetections) { + faceDetection.correctForMaintainedAspectRatio( + Size( + kInputWidth.toDouble(), + kInputHeight.toDouble(), + ), + originalAndNewSizeList[imageOutputToUse].$2, + ); + } + + // Non-maximum suppression to remove duplicate detections + relativeDetections = naiveNonMaxSuppression( + detections: relativeDetections, + iouThreshold: kIouThreshold, + ); + + if (relativeDetections.isEmpty) { + _logger.info('No face detected'); + return []; + } + + stopwatch.stop(); + _logger.info( + 'predict() face detection executed in ${stopwatch.elapsedMilliseconds}ms', + ); + + return relativeDetections; + } + + static List _yoloPostProcessOutputs( + List? outputs, + Size newSize, + ) { + // // Get output tensors + final nestedResults = + outputs?[0]?.value as List>>; // [1, 25200, 16] + final firstResults = nestedResults[0]; // [25200, 16] + + // final rawScores = []; + // for (final result in firstResults) { + // rawScores.add(result[4]); + // } + // final rawScoresCopy = List.from(rawScores); + // rawScoresCopy.sort(); + // _logger.info('rawScores minimum: ${rawScoresCopy.first}'); + // _logger.info('rawScores maximum: ${rawScoresCopy.last}'); + + var relativeDetections = yoloOnnxFilterExtractDetections( + kMinScoreSigmoidThreshold, + kInputWidth, + kInputHeight, + results: firstResults, + ); + + // Release outputs + // outputs?.forEach((element) { + // element?.release(); + // }); + + // Account for the fact that the aspect ratio was maintained + for (final faceDetection in relativeDetections) { + faceDetection.correctForMaintainedAspectRatio( + Size( + kInputWidth.toDouble(), + kInputHeight.toDouble(), + ), + newSize, + ); + } + + // Non-maximum suppression to remove duplicate detections + relativeDetections = naiveNonMaxSuppression( + detections: relativeDetections, + iouThreshold: kIouThreshold, + ); + + dev.log( + '[YOLOFaceDetectionService] ${relativeDetections.length} faces detected', + ); + + return relativeDetections; + } + + /// Initialize the interpreter by loading the model file. 
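Editorial sketch: condensing _yoloPostProcessOutputs above, the post-processing order is fixed: score filtering, aspect-ratio correction, then non-maximum suppression. For one image's [25200, 16] output slice, reusing this diff's own functions:

import 'dart:ui' show Size;

import 'package:photos/services/face_ml/face_detection/detection.dart';
import 'package:photos/services/face_ml/face_detection/naive_non_max_suppression.dart';
import 'package:photos/services/face_ml/face_detection/yolov5face/yolo_filter_extract_detections.dart';

List<FaceDetectionRelative> postProcessOneImage(
  List<List<double>> rawResults, // [25200, 16] slice for a single image
  Size newSize, // image size after the aspect-ratio-preserving resize
) {
  // 1. Drop rows under the score threshold, convert to relative coordinates.
  var detections = yoloOnnxFilterExtractDetections(
    0.8, // kMinScoreSigmoidThreshold
    640, // kInputWidth
    640, // kInputHeight
    results: rawResults,
  );
  // 2. Undo the scaling that preserved the aspect ratio during preprocessing.
  for (final d in detections) {
    d.correctForMaintainedAspectRatio(const Size(640, 640), newSize);
  }
  // 3. Drop near-duplicate detections of the same face.
  detections = naiveNonMaxSuppression(
    detections: detections,
    iouThreshold: 0.4, // kIouThreshold
  );
  return detections;
}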
+ static Future _loadModel(Map args) async { + final sessionOptions = OrtSessionOptions() + ..setInterOpNumThreads(1) + ..setIntraOpNumThreads(1) + ..setSessionGraphOptimizationLevel(GraphOptimizationLevel.ortEnableAll); + try { + // _logger.info('Loading face embedding model'); + final session = + OrtSession.fromFile(File(args["modelPath"]), sessionOptions); + // _logger.info('Face embedding model loaded'); + return session.address; + } catch (e, _) { + // _logger.severe('Face embedding model not loaded', e, s); + } + return -1; + } + + static Future _releaseModel(Map args) async { + final address = args['address'] as int; + if (address == 0) { + return; + } + final session = OrtSession.fromAddress(address); + session.release(); + return; + } + + static Future<(List, int, DateTime)> + inferenceAndPostProcess( + Map args, + ) async { + final inputImageList = args['inputImageList'] as Float32List; + final inputShape = args['inputShape'] as List; + final newSize = args['newSize'] as Size; + final sessionAddress = args['sessionAddress'] as int; + final timeSentToIsolate = args['timeNow'] as DateTime; + final delaySentToIsolate = + DateTime.now().difference(timeSentToIsolate).inMilliseconds; + + final Stopwatch stopwatchPrepare = Stopwatch()..start(); + final inputOrt = OrtValueTensor.createTensorWithDataList( + inputImageList, + inputShape, + ); + final inputs = {'input': inputOrt}; + stopwatchPrepare.stop(); + dev.log( + '[YOLOFaceDetectionService] data preparation is finished, in ${stopwatchPrepare.elapsedMilliseconds}ms', + ); + + stopwatchPrepare.reset(); + stopwatchPrepare.start(); + final runOptions = OrtRunOptions(); + final session = OrtSession.fromAddress(sessionAddress); + stopwatchPrepare.stop(); + dev.log( + '[YOLOFaceDetectionService] session preparation is finished, in ${stopwatchPrepare.elapsedMilliseconds}ms', + ); + + final stopwatchInterpreter = Stopwatch()..start(); + late final List outputs; + try { + outputs = session.run(runOptions, inputs); + } catch (e, s) { + dev.log( + '[YOLOFaceDetectionService] Error while running inference: $e \n $s', + ); + throw YOLOInterpreterRunException(); + } + stopwatchInterpreter.stop(); + dev.log( + '[YOLOFaceDetectionService] interpreter.run is finished, in ${stopwatchInterpreter.elapsedMilliseconds} ms', + ); + + final relativeDetections = _yoloPostProcessOutputs(outputs, newSize); + + return (relativeDetections, delaySentToIsolate, DateTime.now()); + } +} diff --git a/mobile/lib/services/face_ml/face_detection/yolov5face/yolo_face_detection_exceptions.dart b/mobile/lib/services/face_ml/face_detection/yolov5face/yolo_face_detection_exceptions.dart new file mode 100644 index 000000000..69b1edbc0 --- /dev/null +++ b/mobile/lib/services/face_ml/face_detection/yolov5face/yolo_face_detection_exceptions.dart @@ -0,0 +1,3 @@ +class YOLOInterpreterInitializationException implements Exception {} + +class YOLOInterpreterRunException implements Exception {} \ No newline at end of file diff --git a/mobile/lib/services/face_ml/face_detection/yolov5face/yolo_face_detection_options.dart b/mobile/lib/services/face_ml/face_detection/yolov5face/yolo_face_detection_options.dart new file mode 100644 index 000000000..f4cec3458 --- /dev/null +++ b/mobile/lib/services/face_ml/face_detection/yolov5face/yolo_face_detection_options.dart @@ -0,0 +1,31 @@ +import 'dart:math' as math show log; + +class FaceDetectionOptionsYOLO { + final double minScoreSigmoidThreshold; + final double iouThreshold; + final int inputWidth; + final int inputHeight; + final int 
numCoords; + final int numKeypoints; + final int numValuesPerKeypoint; + final int maxNumFaces; + final double scoreClippingThresh; + final double inverseSigmoidMinScoreThreshold; + final bool useSigmoidScore; + final bool flipVertically; + + FaceDetectionOptionsYOLO({ + required this.minScoreSigmoidThreshold, + required this.iouThreshold, + required this.inputWidth, + required this.inputHeight, + this.numCoords = 14, + this.numKeypoints = 5, + this.numValuesPerKeypoint = 2, + this.maxNumFaces = 100, + this.scoreClippingThresh = 100.0, + this.useSigmoidScore = true, + this.flipVertically = false, + }) : inverseSigmoidMinScoreThreshold = + math.log(minScoreSigmoidThreshold / (1 - minScoreSigmoidThreshold)); +} \ No newline at end of file diff --git a/mobile/lib/services/face_ml/face_detection/yolov5face/yolo_filter_extract_detections.dart b/mobile/lib/services/face_ml/face_detection/yolov5face/yolo_filter_extract_detections.dart new file mode 100644 index 000000000..6fb6744c4 --- /dev/null +++ b/mobile/lib/services/face_ml/face_detection/yolov5face/yolo_filter_extract_detections.dart @@ -0,0 +1,81 @@ +import "package:photos/services/face_ml/face_detection/detection.dart"; + +List<FaceDetectionRelative> yoloOnnxFilterExtractDetections( + double minScoreSigmoidThreshold, + int inputWidth, + int inputHeight, { + required List<List<double>> results, // [25200, 16] +}) { + final outputDetections = <FaceDetectionRelative>[]; + final output = <List<double>>[]; + + // Go through the raw output and check the scores + for (final result in results) { + // Filter out raw detections with low scores + if (result[4] < minScoreSigmoidThreshold) { + continue; + } + + // Get the raw detection + final rawDetection = List<double>.from(result); + + // Append the processed raw detection to the output + output.add(rawDetection); + } + + for (final List<double> rawDetection in output) { + // Get absolute bounding box coordinates in format [xMin, yMin, xMax, yMax] https://github.com/deepcam-cn/yolov5-face/blob/eb23d18defe4a76cc06449a61cd51004c59d2697/utils/general.py#L216 + final xMinAbs = rawDetection[0] - rawDetection[2] / 2; + final yMinAbs = rawDetection[1] - rawDetection[3] / 2; + final xMaxAbs = rawDetection[0] + rawDetection[2] / 2; + final yMaxAbs = rawDetection[1] + rawDetection[3] / 2; + + // Get the relative bounding box coordinates in format [xMin, yMin, xMax, yMax] + final box = [ + xMinAbs / inputWidth, + yMinAbs / inputHeight, + xMaxAbs / inputWidth, + yMaxAbs / inputHeight, + ]; + + // Get the keypoints coordinates in format [x, y] + final allKeypoints = <List<double>>[ + [ + rawDetection[5] / inputWidth, + rawDetection[6] / inputHeight, + ], + [ + rawDetection[7] / inputWidth, + rawDetection[8] / inputHeight, + ], + [ + rawDetection[9] / inputWidth, + rawDetection[10] / inputHeight, + ], + [ + rawDetection[11] / inputWidth, + rawDetection[12] / inputHeight, + ], + [ + rawDetection[13] / inputWidth, + rawDetection[14] / inputHeight, + ], + ]; + + // Get the score + final score = + rawDetection[4]; // Or should it be rawDetection[4]*rawDetection[15]? 
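+ + // NOTE (assumed output layout, inferred from the yolov5-face export linked above): each raw detection + // carries 16 values, [xCenter, yCenter, width, height, objectness, kp1x, kp1y, ..., kp5x, kp5y, classScore], + // which is why indices 0-3 yield the box, index 4 the score, and indices 5-14 the five keypoints.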
+ + // Create the relative detection + final detection = FaceDetectionRelative( + score: score, + box: box, + allKeypoints: allKeypoints, + ); + + // Append the relative detection to the output + outputDetections.add(detection); + } + + return outputDetections; +} diff --git a/mobile/lib/services/face_ml/face_detection/yolov5face/yolo_model_config.dart b/mobile/lib/services/face_ml/face_detection/yolov5face/yolo_model_config.dart new file mode 100644 index 000000000..c803beffd --- /dev/null +++ b/mobile/lib/services/face_ml/face_detection/yolov5face/yolo_model_config.dart @@ -0,0 +1,22 @@ +import "package:photos/services/face_ml/face_detection/yolov5face/yolo_face_detection_options.dart"; +import "package:photos/services/face_ml/model_file.dart"; + +class YOLOModelConfig { + final String modelPath; + final FaceDetectionOptionsYOLO faceOptions; + + YOLOModelConfig({ + required this.modelPath, + required this.faceOptions, + }); +} + +final YOLOModelConfig yoloV5FaceS640x640DynamicBatchonnx = YOLOModelConfig( + modelPath: ModelFile.yoloV5FaceS640x640DynamicBatchonnx, + faceOptions: FaceDetectionOptionsYOLO( + minScoreSigmoidThreshold: 0.8, + iouThreshold: 0.4, + inputWidth: 640, + inputHeight: 640, + ), +); diff --git a/mobile/lib/services/face_ml/face_embedding/face_embedding_exceptions.dart b/mobile/lib/services/face_ml/face_embedding/face_embedding_exceptions.dart new file mode 100644 index 000000000..548b80a95 --- /dev/null +++ b/mobile/lib/services/face_ml/face_embedding/face_embedding_exceptions.dart @@ -0,0 +1,11 @@ +class MobileFaceNetInterpreterInitializationException implements Exception {} + +class MobileFaceNetImagePreprocessingException implements Exception {} + +class MobileFaceNetEmptyInput implements Exception {} + +class MobileFaceNetWrongInputSize implements Exception {} + +class MobileFaceNetWrongInputRange implements Exception {} + +class MobileFaceNetInterpreterRunException implements Exception {} \ No newline at end of file diff --git a/mobile/lib/services/face_ml/face_embedding/face_embedding_options.dart b/mobile/lib/services/face_ml/face_embedding/face_embedding_options.dart new file mode 100644 index 000000000..6ac7f339a --- /dev/null +++ b/mobile/lib/services/face_ml/face_embedding/face_embedding_options.dart @@ -0,0 +1,15 @@ +class FaceEmbeddingOptions { + final int inputWidth; + final int inputHeight; + final int embeddingLength; + final int numChannels; + final bool preWhiten; + + FaceEmbeddingOptions({ + required this.inputWidth, + required this.inputHeight, + this.embeddingLength = 192, + this.numChannels = 3, + this.preWhiten = false, + }); +} diff --git a/mobile/lib/services/face_ml/face_embedding/face_embedding_service.dart b/mobile/lib/services/face_ml/face_embedding/face_embedding_service.dart new file mode 100644 index 000000000..2711550bd --- /dev/null +++ b/mobile/lib/services/face_ml/face_embedding/face_embedding_service.dart @@ -0,0 +1,279 @@ +import 'dart:io'; +import "dart:math" show min, max, sqrt; +// import 'dart:math' as math show min, max; +import 'dart:typed_data' show Uint8List; + +import "package:flutter/foundation.dart"; +import "package:logging/logging.dart"; +import 'package:photos/models/ml/ml_typedefs.dart'; +import "package:photos/services/face_ml/face_detection/detection.dart"; +import "package:photos/services/face_ml/face_embedding/face_embedding_exceptions.dart"; +import "package:photos/services/face_ml/face_embedding/face_embedding_options.dart"; +import 
"package:photos/services/face_ml/face_embedding/mobilefacenet_model_config.dart"; +import 'package:photos/utils/image_ml_isolate.dart'; +import 'package:photos/utils/image_ml_util.dart'; +import 'package:tflite_flutter/tflite_flutter.dart'; + +/// This class is responsible for running the MobileFaceNet model, and can be accessed through the singleton `FaceEmbedding.instance`. +class FaceEmbedding { + Interpreter? _interpreter; + IsolateInterpreter? _isolateInterpreter; + int get getAddress => _interpreter!.address; + + final outputShapes = >[]; + final outputTypes = []; + + final _logger = Logger("FaceEmbeddingService"); + + final MobileFaceNetModelConfig config; + final FaceEmbeddingOptions embeddingOptions; + // singleton pattern + FaceEmbedding._privateConstructor({required this.config}) + : embeddingOptions = config.faceEmbeddingOptions; + + /// Use this instance to access the FaceEmbedding service. Make sure to call `init()` before using it. + /// e.g. `await FaceEmbedding.instance.init();` + /// + /// Then you can use `predict()` to get the embedding of a face, so `FaceEmbedding.instance.predict(imageData)` + /// + /// config options: faceEmbeddingEnte + static final instance = + FaceEmbedding._privateConstructor(config: faceEmbeddingEnte); + factory FaceEmbedding() => instance; + + /// Check if the interpreter is initialized, if not initialize it with `loadModel()` + Future init() async { + if (_interpreter == null || _isolateInterpreter == null) { + await _loadModel(); + } + } + + Future dispose() async { + _logger.info('dispose() is called'); + + try { + _interpreter?.close(); + _interpreter = null; + await _isolateInterpreter?.close(); + _isolateInterpreter = null; + } catch (e) { + _logger.severe('Error while closing interpreter: $e'); + rethrow; + } + } + + /// WARNING: This function only works for one face at a time. it's better to use [predict], which can handle both single and multiple faces. 
+ Future> predictSingle( + Uint8List imageData, + FaceDetectionRelative face, + ) async { + assert(_interpreter != null && _isolateInterpreter != null); + + final stopwatch = Stopwatch()..start(); + + // Image decoding and preprocessing + List>>> input; + List output; + try { + final stopwatchDecoding = Stopwatch()..start(); + final (inputImageMatrix, _, _, _, _) = + await ImageMlIsolate.instance.preprocessMobileFaceNet( + imageData, + [face], + ); + input = inputImageMatrix; + stopwatchDecoding.stop(); + _logger.info( + 'Image decoding and preprocessing is finished, in ${stopwatchDecoding.elapsedMilliseconds}ms', + ); + + output = createEmptyOutputMatrix(outputShapes[0]); + } catch (e) { + _logger.severe('Error while decoding and preprocessing image: $e'); + throw MobileFaceNetImagePreprocessingException(); + } + + _logger.info('interpreter.run is called'); + // Run inference + try { + await _isolateInterpreter!.run(input, output); + // _interpreter!.run(input, output); + // ignore: avoid_catches_without_on_clauses + } catch (e) { + _logger.severe('Error while running inference: $e'); + throw MobileFaceNetInterpreterRunException(); + } + _logger.info('interpreter.run is finished'); + + // Get output tensors + final embedding = output[0] as List; + + // Normalize the embedding + final norm = sqrt(embedding.map((e) => e * e).reduce((a, b) => a + b)); + for (int i = 0; i < embedding.length; i++) { + embedding[i] /= norm; + } + + stopwatch.stop(); + _logger.info( + 'predict() executed in ${stopwatch.elapsedMilliseconds}ms', + ); + + // _logger.info( + // 'results (only first few numbers): embedding ${embedding.sublist(0, 5)}', + // ); + // _logger.info( + // 'Mean of embedding: ${embedding.reduce((a, b) => a + b) / embedding.length}', + // ); + // _logger.info( + // 'Max of embedding: ${embedding.reduce(math.max)}', + // ); + // _logger.info( + // 'Min of embedding: ${embedding.reduce(math.min)}', + // ); + + return embedding; + } + + Future>> predict( + List inputImageMatrix, + ) async { + assert(_interpreter != null && _isolateInterpreter != null); + + final stopwatch = Stopwatch()..start(); + + _checkPreprocessedInput(inputImageMatrix); // [inputHeight, inputWidth, 3] + final input = [inputImageMatrix]; + // await encodeAndSaveData(inputImageMatrix, 'input_mobilefacenet'); + + final output = {}; + final outputShape = outputShapes[0]; + outputShape[0] = inputImageMatrix.length; + output[0] = createEmptyOutputMatrix(outputShape); + // for (int i = 0; i < faces.length; i++) { + // output[i] = createEmptyOutputMatrix(outputShapes[0]); + // } + + _logger.info('interpreter.run is called'); + // Run inference + final stopwatchInterpreter = Stopwatch()..start(); + try { + await _isolateInterpreter!.runForMultipleInputs(input, output); + // _interpreter!.runForMultipleInputs(input, output); + // ignore: avoid_catches_without_on_clauses + } catch (e) { + _logger.severe('Error while running inference: $e'); + throw MobileFaceNetInterpreterRunException(); + } + stopwatchInterpreter.stop(); + _logger.info( + 'interpreter.run is finished, in ${stopwatchInterpreter.elapsedMilliseconds}ms', + ); + // _logger.info('output: $output'); + + // Get output tensors + final embeddings = >[]; + final outerEmbedding = output[0]! 
as Iterable; + for (int i = 0; i < inputImageMatrix.length; i++) { + final embedding = List<double>.from(outerEmbedding.toList()[i]); + // _logger.info("The $i-th embedding: $embedding"); + embeddings.add(embedding); + } + // await encodeAndSaveData(embeddings, 'output_mobilefacenet'); + + // Normalize the embedding + for (int i = 0; i < embeddings.length; i++) { + final embedding = embeddings[i]; + final norm = sqrt(embedding.map((e) => e * e).reduce((a, b) => a + b)); + for (int j = 0; j < embedding.length; j++) { + embedding[j] /= norm; + } + } + + stopwatch.stop(); + _logger.info( + 'predictBatch() executed in ${stopwatch.elapsedMilliseconds}ms', + ); + + return embeddings; + } + + Future<void> _loadModel() async { + _logger.info('loadModel is called'); + + try { + final interpreterOptions = InterpreterOptions(); + + // Android Delegates + // TODO: Make sure this works on both platforms: Android and iOS + if (Platform.isAndroid) { + // Use GPU Delegate (GPU). WARNING: It doesn't work on emulator + // if (!kDebugMode) { + // interpreterOptions.addDelegate(GpuDelegateV2()); + // } + // Use XNNPACK Delegate (CPU) + interpreterOptions.addDelegate(XNNPackDelegate()); + } + + // iOS Delegates + if (Platform.isIOS) { + // Use Metal Delegate (GPU) + interpreterOptions.addDelegate(GpuDelegate()); + } + + // Load model from assets + _interpreter ??= await Interpreter.fromAsset( + config.modelPath, + options: interpreterOptions, + ); + _isolateInterpreter ??= + await IsolateInterpreter.create(address: _interpreter!.address); + + _logger.info('Interpreter created from asset: ${config.modelPath}'); + + // Get tensor input shape [1, 112, 112, 3] + final inputTensors = _interpreter!.getInputTensors().first; + _logger.info('Input Tensors: $inputTensors'); + // Get tensor output shape [1, 192] + final outputTensors = _interpreter!.getOutputTensors(); + final outputTensor = outputTensors.first; + _logger.info('Output Tensors: $outputTensor'); + + for (var tensor in outputTensors) { + outputShapes.add(tensor.shape); + outputTypes.add(tensor.type); + } + _logger.info('outputShapes: $outputShapes'); + _logger.info('loadModel is finished'); + // ignore: avoid_catches_without_on_clauses + } catch (e) { + _logger.severe('Error while creating interpreter: $e'); + throw MobileFaceNetInterpreterInitializationException(); + } + } + + void _checkPreprocessedInput( + List inputMatrix, + ) { + final embeddingOptions = config.faceEmbeddingOptions; + + if (inputMatrix.isEmpty) { + // Check if the input is empty + throw MobileFaceNetEmptyInput(); + } + + // Check if the input is the correct size + if (inputMatrix[0].length != embeddingOptions.inputHeight || + inputMatrix[0][0].length != embeddingOptions.inputWidth) { + throw MobileFaceNetWrongInputSize(); + } + + final flattened = inputMatrix[0].expand((i) => i).expand((i) => i); + final minValue = flattened.reduce(min); + final maxValue = flattened.reduce(max); + + if (minValue < -1 || maxValue > 1) { + throw MobileFaceNetWrongInputRange(); + } + } +} diff --git a/mobile/lib/services/face_ml/face_embedding/mobilefacenet_model_config.dart b/mobile/lib/services/face_ml/face_embedding/mobilefacenet_model_config.dart new file mode 100644 index 000000000..d55a2d333 --- /dev/null +++ b/mobile/lib/services/face_ml/face_embedding/mobilefacenet_model_config.dart @@ -0,0 +1,20 @@ +import "package:photos/services/face_ml/face_embedding/face_embedding_options.dart"; +import "package:photos/services/face_ml/model_file.dart"; + +class MobileFaceNetModelConfig { + final String modelPath; + final 
FaceEmbeddingOptions faceEmbeddingOptions; + + MobileFaceNetModelConfig({ + required this.modelPath, + required this.faceEmbeddingOptions, + }); +} + +final MobileFaceNetModelConfig faceEmbeddingEnte = MobileFaceNetModelConfig( + modelPath: ModelFile.faceEmbeddingEnte, + faceEmbeddingOptions: FaceEmbeddingOptions( + inputWidth: 112, + inputHeight: 112, + ), +); diff --git a/mobile/lib/services/face_ml/face_embedding/onnx_face_embedding.dart b/mobile/lib/services/face_ml/face_embedding/onnx_face_embedding.dart new file mode 100644 index 000000000..0ab126253 --- /dev/null +++ b/mobile/lib/services/face_ml/face_embedding/onnx_face_embedding.dart @@ -0,0 +1,245 @@ +import "dart:io" show File; +import 'dart:math' as math show max, min, sqrt; +import 'dart:typed_data' show Float32List; + +import 'package:computer/computer.dart'; +import 'package:logging/logging.dart'; +import 'package:onnxruntime/onnxruntime.dart'; +import "package:photos/services/face_ml/face_detection/detection.dart"; +import "package:photos/services/remote_assets_service.dart"; +import "package:photos/utils/image_ml_isolate.dart"; +import "package:synchronized/synchronized.dart"; + +class FaceEmbeddingOnnx { + static const kModelBucketEndpoint = "https://models.ente.io/"; + static const kRemoteBucketModelPath = "mobilefacenet_opset15.onnx"; + static const modelRemotePath = kModelBucketEndpoint + kRemoteBucketModelPath; + + static const int kInputSize = 112; + static const int kEmbeddingSize = 192; + + static final _logger = Logger('FaceEmbeddingOnnx'); + + bool isInitialized = false; + int sessionAddress = 0; + + final _computer = Computer.shared(); + + final _computerLock = Lock(); + + // singleton pattern + FaceEmbeddingOnnx._privateConstructor(); + + /// Use this instance to access the FaceEmbedding service. Make sure to call `init()` before using it. + /// e.g. 
`await FaceEmbedding.instance.init();` + /// + /// Then you can use `predict()` to get the embedding of a face, so `FaceEmbedding.instance.predict(imageData)` + /// + /// config options: faceEmbeddingEnte + static final instance = FaceEmbeddingOnnx._privateConstructor(); + factory FaceEmbeddingOnnx() => instance; + + /// Check if the interpreter is initialized, if not initialize it with `loadModel()` + Future init() async { + if (!isInitialized) { + _logger.info('init is called'); + final model = + await RemoteAssetsService.instance.getAsset(modelRemotePath); + final startTime = DateTime.now(); + // Doing this from main isolate since `rootBundle` cannot be accessed outside it + sessionAddress = await _computer.compute( + _loadModel, + param: { + "modelPath": model.path, + }, + ); + final endTime = DateTime.now(); + _logger.info( + "Face embedding model loaded, took: ${(endTime.millisecondsSinceEpoch - startTime.millisecondsSinceEpoch).toString()}ms", + ); + if (sessionAddress != -1) { + isInitialized = true; + } + } + } + + Future release() async { + if (isInitialized) { + await _computer.compute(_releaseModel, param: {'address': sessionAddress}); + isInitialized = false; + sessionAddress = 0; + } + } + + static Future _loadModel(Map args) async { + final sessionOptions = OrtSessionOptions() + ..setInterOpNumThreads(1) + ..setIntraOpNumThreads(1) + ..setSessionGraphOptimizationLevel(GraphOptimizationLevel.ortEnableAll); + try { + // _logger.info('Loading face embedding model'); + final session = + OrtSession.fromFile(File(args["modelPath"]), sessionOptions); + // _logger.info('Face embedding model loaded'); + return session.address; + } catch (e, _) { + // _logger.severe('Face embedding model not loaded', e, s); + } + return -1; + } + + static Future _releaseModel(Map args) async { + final address = args['address'] as int; + if (address == 0) { + return; + } + final session = OrtSession.fromAddress(address); + session.release(); + return; + } + + Future<(List, bool, double)> predictFromImageDataInComputer( + String imagePath, + FaceDetectionRelative face, + ) async { + assert(sessionAddress != 0 && sessionAddress != -1 && isInitialized); + + try { + final stopwatchDecoding = Stopwatch()..start(); + final (inputImageList, alignmentResults, isBlur, blurValue, _) = + await ImageMlIsolate.instance.preprocessMobileFaceNetOnnx( + imagePath, + [face], + ); + stopwatchDecoding.stop(); + _logger.info( + 'MobileFaceNet image decoding and preprocessing is finished, in ${stopwatchDecoding.elapsedMilliseconds}ms', + ); + + final stopwatch = Stopwatch()..start(); + _logger.info('MobileFaceNet interpreter.run is called'); + final embedding = await _computer.compute( + inferFromMap, + param: { + 'input': inputImageList, + 'address': sessionAddress, + 'inputSize': kInputSize, + }, + taskName: 'createFaceEmbedding', + ) as List; + stopwatch.stop(); + _logger.info( + 'MobileFaceNet interpreter.run is finished, in ${stopwatch.elapsedMilliseconds}ms', + ); + + _logger.info( + 'MobileFaceNet results (only first few numbers): embedding ${embedding.sublist(0, 5)}', + ); + _logger.info( + 'Mean of embedding: ${embedding.reduce((a, b) => a + b) / embedding.length}', + ); + _logger.info( + 'Max of embedding: ${embedding.reduce(math.max)}', + ); + _logger.info( + 'Min of embedding: ${embedding.reduce(math.min)}', + ); + + return (embedding, isBlur[0], blurValue[0]); + } catch (e) { + _logger.info('MobileFaceNet Error while running inference: $e'); + rethrow; + } + } + + Future>> predictInComputer(Float32List input) 
async { + assert(sessionAddress != 0 && sessionAddress != -1 && isInitialized); + return await _computerLock.synchronized(() async { + try { + final stopwatch = Stopwatch()..start(); + _logger.info('MobileFaceNet interpreter.run is called'); + final embeddings = await _computer.compute( + inferFromMap, + param: { + 'input': input, + 'address': sessionAddress, + 'inputSize': kInputSize, + }, + taskName: 'createFaceEmbedding', + ) as List>; + stopwatch.stop(); + _logger.info( + 'MobileFaceNet interpreter.run is finished, in ${stopwatch.elapsedMilliseconds}ms', + ); + + return embeddings; + } catch (e) { + _logger.info('MobileFaceNet Error while running inference: $e'); + rethrow; + } + }); + } + + static Future>> predictSync( + Float32List input, + int sessionAddress, + ) async { + assert(sessionAddress != 0 && sessionAddress != -1); + try { + final stopwatch = Stopwatch()..start(); + _logger.info('MobileFaceNet interpreter.run is called'); + final embeddings = await infer( + input, + sessionAddress, + kInputSize, + ); + stopwatch.stop(); + _logger.info( + 'MobileFaceNet interpreter.run is finished, in ${stopwatch.elapsedMilliseconds}ms', + ); + + return embeddings; + } catch (e) { + _logger.info('MobileFaceNet Error while running inference: $e'); + rethrow; + } + } + + static Future>> inferFromMap(Map args) async { + final inputImageList = args['input'] as Float32List; + final address = args['address'] as int; + final inputSize = args['inputSize'] as int; + return await infer(inputImageList, address, inputSize); + } + + static Future>> infer( + Float32List inputImageList, + int address, + int inputSize, + ) async { + final runOptions = OrtRunOptions(); + final int numberOfFaces = + inputImageList.length ~/ (inputSize * inputSize * 3); + final inputOrt = OrtValueTensor.createTensorWithDataList( + inputImageList, + [numberOfFaces, inputSize, inputSize, 3], + ); + final inputs = {'img_inputs': inputOrt}; + final session = OrtSession.fromAddress(address); + final List outputs = session.run(runOptions, inputs); + final embeddings = outputs[0]?.value as List>; + + for (final embedding in embeddings) { + double normalization = 0; + for (int i = 0; i < kEmbeddingSize; i++) { + normalization += embedding[i] * embedding[i]; + } + final double sqrtNormalization = math.sqrt(normalization); + for (int i = 0; i < kEmbeddingSize; i++) { + embedding[i] = embedding[i] / sqrtNormalization; + } + } + + return embeddings; + } +} diff --git a/mobile/lib/services/face_ml/face_feedback.dart/cluster_feedback.dart b/mobile/lib/services/face_ml/face_feedback.dart/cluster_feedback.dart new file mode 100644 index 000000000..b99d3950a --- /dev/null +++ b/mobile/lib/services/face_ml/face_feedback.dart/cluster_feedback.dart @@ -0,0 +1,379 @@ +import "dart:convert"; + +import "package:photos/services/face_ml/face_clustering/cosine_distance.dart"; +import "package:photos/services/face_ml/face_feedback.dart/feedback.dart"; +import "package:photos/services/face_ml/face_feedback.dart/feedback_types.dart"; + +abstract class ClusterFeedback extends Feedback { + static final Map fromJsonStringRegistry = { + FeedbackType.deleteClusterFeedback: DeleteClusterFeedback.fromJsonString, + FeedbackType.mergeClusterFeedback: MergeClusterFeedback.fromJsonString, + FeedbackType.renameOrCustomThumbnailClusterFeedback: + RenameOrCustomThumbnailClusterFeedback.fromJsonString, + FeedbackType.removePhotosClusterFeedback: + RemovePhotosClusterFeedback.fromJsonString, + FeedbackType.addPhotosClusterFeedback: + 
AddPhotosClusterFeedback.fromJsonString, + }; + + final List medoid; + final double medoidDistanceThreshold; + // TODO: work out the optimal distance threshold so there's never an overlap between clusters + + ClusterFeedback( + FeedbackType type, + this.medoid, + this.medoidDistanceThreshold, { + String? feedbackID, + DateTime? timestamp, + int? madeOnFaceMlVersion, + int? madeOnClusterMlVersion, + }) : super( + type, + feedbackID: feedbackID, + timestamp: timestamp, + madeOnFaceMlVersion: madeOnFaceMlVersion, + madeOnClusterMlVersion: madeOnClusterMlVersion, + ); + + /// Compares this feedback with another [ClusterFeedback] to see if they are similar enough that only one should be kept. + /// + /// It checks this by comparing the distance between the two medoids with the medoidDistanceThreshold of each feedback. + /// + /// Returns true if they are similar enough, false otherwise. + /// // TODO: Should it maybe return a merged feedback instead, when they are similar enough? + bool looselyMatchesMedoid(ClusterFeedback other) { + // Compare the two medoids using cosine distance + final double distance = cosineDistance(medoid, other.medoid); + + // Check if the distance is less than either of the threshold values + return distance < medoidDistanceThreshold || + distance < other.medoidDistanceThreshold; + } + + bool exactlyMatchesMedoid(ClusterFeedback other) { + if (medoid.length != other.medoid.length) { + return false; + } + for (int i = 0; i < medoid.length; i++) { + if (medoid[i] != other.medoid[i]) { + return false; + } + } + return true; + } +} + +class DeleteClusterFeedback extends ClusterFeedback { + DeleteClusterFeedback({ + required List medoid, + required double medoidDistanceThreshold, + String? feedbackID, + DateTime? timestamp, + int? madeOnFaceMlVersion, + int? madeOnClusterMlVersion, + }) : super( + FeedbackType.deleteClusterFeedback, + medoid, + medoidDistanceThreshold, + feedbackID: feedbackID, + timestamp: timestamp, + madeOnFaceMlVersion: madeOnFaceMlVersion, + madeOnClusterMlVersion: madeOnClusterMlVersion, + ); + + @override + Map toJson() { + return { + 'type': type.toValueString(), + 'medoid': medoid, + 'medoidDistanceThreshold': medoidDistanceThreshold, + 'feedbackID': feedbackID, + 'timestamp': timestamp.toIso8601String(), + 'madeOnFaceMlVersion': madeOnFaceMlVersion, + 'madeOnClusterMlVersion': madeOnClusterMlVersion, + }; + } + + @override + String toJsonString() => jsonEncode(toJson()); + + static DeleteClusterFeedback fromJson(Map json) { + assert(json['type'] == FeedbackType.deleteClusterFeedback.toValueString()); + return DeleteClusterFeedback( + medoid: + (json['medoid'] as List?)?.map((item) => item as double).toList() ?? + [], + medoidDistanceThreshold: json['medoidDistanceThreshold'], + feedbackID: json['feedbackID'], + timestamp: DateTime.parse(json['timestamp']), + madeOnFaceMlVersion: json['madeOnFaceMlVersion'], + madeOnClusterMlVersion: json['madeOnClusterMlVersion'], + ); + } + + static DeleteClusterFeedback fromJsonString(String jsonString) { + return fromJson(jsonDecode(jsonString)); + } +} + +class MergeClusterFeedback extends ClusterFeedback { + final List medoidToMoveTo; + + MergeClusterFeedback({ + required List medoid, + required double medoidDistanceThreshold, + required this.medoidToMoveTo, + String? feedbackID, + DateTime? timestamp, + int? madeOnFaceMlVersion, + int? 
madeOnClusterMlVersion, + }) : super( + FeedbackType.mergeClusterFeedback, + medoid, + medoidDistanceThreshold, + feedbackID: feedbackID, + timestamp: timestamp, + madeOnFaceMlVersion: madeOnFaceMlVersion, + madeOnClusterMlVersion: madeOnClusterMlVersion, + ); + + @override + Map toJson() { + return { + 'type': type.toValueString(), + 'medoid': medoid, + 'medoidDistanceThreshold': medoidDistanceThreshold, + 'medoidToMoveTo': medoidToMoveTo, + 'feedbackID': feedbackID, + 'timestamp': timestamp.toIso8601String(), + 'madeOnFaceMlVersion': madeOnFaceMlVersion, + 'madeOnClusterMlVersion': madeOnClusterMlVersion, + }; + } + + @override + String toJsonString() => jsonEncode(toJson()); + + static MergeClusterFeedback fromJson(Map json) { + assert(json['type'] == FeedbackType.mergeClusterFeedback.toValueString()); + return MergeClusterFeedback( + medoid: + (json['medoid'] as List?)?.map((item) => item as double).toList() ?? + [], + medoidDistanceThreshold: json['medoidDistanceThreshold'], + medoidToMoveTo: (json['medoidToMoveTo'] as List?) + ?.map((item) => item as double) + .toList() ?? + [], + feedbackID: json['feedbackID'], + timestamp: DateTime.parse(json['timestamp']), + madeOnFaceMlVersion: json['madeOnFaceMlVersion'], + madeOnClusterMlVersion: json['madeOnClusterMlVersion'], + ); + } + + static MergeClusterFeedback fromJsonString(String jsonString) { + return fromJson(jsonDecode(jsonString)); + } +} + +class RenameOrCustomThumbnailClusterFeedback extends ClusterFeedback { + String? customName; + String? customThumbnailFaceId; + + RenameOrCustomThumbnailClusterFeedback({ + required List medoid, + required double medoidDistanceThreshold, + this.customName, + this.customThumbnailFaceId, + String? feedbackID, + DateTime? timestamp, + int? madeOnFaceMlVersion, + int? madeOnClusterMlVersion, + }) : assert( + customName != null || customThumbnailFaceId != null, + "Either customName or customThumbnailFaceId must be non-null!", + ), + super( + FeedbackType.renameOrCustomThumbnailClusterFeedback, + medoid, + medoidDistanceThreshold, + feedbackID: feedbackID, + timestamp: timestamp, + madeOnFaceMlVersion: madeOnFaceMlVersion, + madeOnClusterMlVersion: madeOnClusterMlVersion, + ); + + @override + Map toJson() { + return { + 'type': type.toValueString(), + 'medoid': medoid, + 'medoidDistanceThreshold': medoidDistanceThreshold, + if (customName != null) 'customName': customName, + if (customThumbnailFaceId != null) + 'customThumbnailFaceId': customThumbnailFaceId, + 'feedbackID': feedbackID, + 'timestamp': timestamp.toIso8601String(), + 'madeOnFaceMlVersion': madeOnFaceMlVersion, + 'madeOnClusterMlVersion': madeOnClusterMlVersion, + }; + } + + @override + String toJsonString() => jsonEncode(toJson()); + + static RenameOrCustomThumbnailClusterFeedback fromJson( + Map json, + ) { + assert( + json['type'] == + FeedbackType.renameOrCustomThumbnailClusterFeedback.toValueString(), + ); + return RenameOrCustomThumbnailClusterFeedback( + medoid: + (json['medoid'] as List?)?.map((item) => item as double).toList() ?? 
+ [], + medoidDistanceThreshold: json['medoidDistanceThreshold'], + customName: json['customName'], + customThumbnailFaceId: json['customThumbnailFaceId'], + feedbackID: json['feedbackID'], + timestamp: DateTime.parse(json['timestamp']), + madeOnFaceMlVersion: json['madeOnFaceMlVersion'], + madeOnClusterMlVersion: json['madeOnClusterMlVersion'], + ); + } + + static RenameOrCustomThumbnailClusterFeedback fromJsonString( + String jsonString, + ) { + return fromJson(jsonDecode(jsonString)); + } +} + +class RemovePhotosClusterFeedback extends ClusterFeedback { + final List removedPhotosFileID; + + RemovePhotosClusterFeedback({ + required List medoid, + required double medoidDistanceThreshold, + required this.removedPhotosFileID, + String? feedbackID, + DateTime? timestamp, + int? madeOnFaceMlVersion, + int? madeOnClusterMlVersion, + }) : super( + FeedbackType.removePhotosClusterFeedback, + medoid, + medoidDistanceThreshold, + feedbackID: feedbackID, + timestamp: timestamp, + madeOnFaceMlVersion: madeOnFaceMlVersion, + madeOnClusterMlVersion: madeOnClusterMlVersion, + ); + + @override + Map toJson() { + return { + 'type': type.toValueString(), + 'medoid': medoid, + 'medoidDistanceThreshold': medoidDistanceThreshold, + 'removedPhotosFileID': removedPhotosFileID, + 'feedbackID': feedbackID, + 'timestamp': timestamp.toIso8601String(), + 'madeOnFaceMlVersion': madeOnFaceMlVersion, + 'madeOnClusterMlVersion': madeOnClusterMlVersion, + }; + } + + @override + String toJsonString() => jsonEncode(toJson()); + + static RemovePhotosClusterFeedback fromJson(Map json) { + assert( + json['type'] == FeedbackType.removePhotosClusterFeedback.toValueString(), + ); + return RemovePhotosClusterFeedback( + medoid: + (json['medoid'] as List?)?.map((item) => item as double).toList() ?? + [], + medoidDistanceThreshold: json['medoidDistanceThreshold'], + removedPhotosFileID: (json['removedPhotosFileID'] as List?) + ?.map((item) => item as int) + .toList() ?? + [], + feedbackID: json['feedbackID'], + timestamp: DateTime.parse(json['timestamp']), + madeOnFaceMlVersion: json['madeOnFaceMlVersion'], + madeOnClusterMlVersion: json['madeOnClusterMlVersion'], + ); + } + + static RemovePhotosClusterFeedback fromJsonString(String jsonString) { + return fromJson(jsonDecode(jsonString)); + } +} + +class AddPhotosClusterFeedback extends ClusterFeedback { + final List addedPhotoFileIDs; + + AddPhotosClusterFeedback({ + required List medoid, + required double medoidDistanceThreshold, + required this.addedPhotoFileIDs, + String? feedbackID, + DateTime? timestamp, + int? madeOnFaceMlVersion, + int? 
madeOnClusterMlVersion, + }) : super( + FeedbackType.addPhotosClusterFeedback, + medoid, + medoidDistanceThreshold, + feedbackID: feedbackID, + timestamp: timestamp, + madeOnFaceMlVersion: madeOnFaceMlVersion, + madeOnClusterMlVersion: madeOnClusterMlVersion, + ); + + @override + Map toJson() { + return { + 'type': type.toValueString(), + 'medoid': medoid, + 'medoidDistanceThreshold': medoidDistanceThreshold, + 'addedPhotoFileIDs': addedPhotoFileIDs, + 'feedbackID': feedbackID, + 'timestamp': timestamp.toIso8601String(), + 'madeOnFaceMlVersion': madeOnFaceMlVersion, + 'madeOnClusterMlVersion': madeOnClusterMlVersion, + }; + } + + @override + String toJsonString() => jsonEncode(toJson()); + + static AddPhotosClusterFeedback fromJson(Map json) { + assert( + json['type'] == FeedbackType.addPhotosClusterFeedback.toValueString(), + ); + return AddPhotosClusterFeedback( + medoid: + (json['medoid'] as List?)?.map((item) => item as double).toList() ?? + [], + medoidDistanceThreshold: json['medoidDistanceThreshold'], + addedPhotoFileIDs: (json['addedPhotoFileIDs'] as List?) + ?.map((item) => item as int) + .toList() ?? + [], + feedbackID: json['feedbackID'], + timestamp: DateTime.parse(json['timestamp']), + madeOnFaceMlVersion: json['madeOnFaceMlVersion'], + madeOnClusterMlVersion: json['madeOnClusterMlVersion'], + ); + } + + static AddPhotosClusterFeedback fromJsonString(String jsonString) { + return fromJson(jsonDecode(jsonString)); + } +} diff --git a/mobile/lib/services/face_ml/face_feedback.dart/face_feedback_service.dart b/mobile/lib/services/face_ml/face_feedback.dart/face_feedback_service.dart new file mode 100644 index 000000000..0e95e3d7c --- /dev/null +++ b/mobile/lib/services/face_ml/face_feedback.dart/face_feedback_service.dart @@ -0,0 +1,416 @@ +import "package:logging/logging.dart"; +import "package:photos/db/ml_data_db.dart"; +import "package:photos/services/face_ml/face_detection/detection.dart"; +import "package:photos/services/face_ml/face_feedback.dart/cluster_feedback.dart"; +import "package:photos/services/face_ml/face_ml_result.dart"; + +class FaceFeedbackService { + final _logger = Logger("FaceFeedbackService"); + + final _mlDatabase = MlDataDB.instance; + + int executedFeedbackCount = 0; + final int _reclusterFeedbackThreshold = 10; + + // singleton pattern + FaceFeedbackService._privateConstructor(); + static final instance = FaceFeedbackService._privateConstructor(); + factory FaceFeedbackService() => instance; + + /// Returns the updated cluster after removing the given file from the given person's cluster. + /// + /// If the file is not in the cluster, returns null. + /// + /// The updated cluster is also updated in [MlDataDB]. + Future removePhotosFromCluster( + List fileIDs, + int personID, + ) async { + // TODO: check if photo was originally added to cluster by user. If so, we should remove that addition instead of changing the embedding, because there is no embedding... + _logger.info( + 'removePhotoFromCluster called with fileIDs $fileIDs and personID $personID', + ); + + if (fileIDs.isEmpty) { + _logger.severe( + "No fileIDs given, unable to add photos to cluster!", + ); + throw ArgumentError( + "No fileIDs given, unable to add photos to cluster!", + ); + } + + // Get the relevant cluster + final ClusterResult? 
cluster = await _mlDatabase.getClusterResult(personID); + if (cluster == null) { + _logger.severe( + "No cluster found for personID $personID, unable to remove photo from non-existent cluster!", + ); + throw ArgumentError( + "No cluster found for personID $personID, unable to remove photo from non-existent cluster!", + ); + } + // Get the relevant faceMlResults + final List faceMlResults = + await _mlDatabase.getSelectedFaceMlResults(fileIDs); + if (faceMlResults.length != fileIDs.length) { + final List foundFileIDs = + faceMlResults.map((faceMlResult) => faceMlResult.fileId).toList(); + _logger.severe( + "Couldn't find all facemlresults for fileIDs $fileIDs, only found for $foundFileIDs. Unable to remove unindexed photos from cluster!", + ); + throw ArgumentError( + "Couldn't find all facemlresults for fileIDs $fileIDs, only found for $foundFileIDs. Unable to remove unindexed photos from cluster!", + ); + } + + // Check if at least one of the files is in the cluster. If all files are already not in the cluster, return the cluster. + final List fileIDsInCluster = fileIDs + .where((fileID) => cluster.uniqueFileIds.contains(fileID)) + .toList(); + if (fileIDsInCluster.isEmpty) { + _logger.warning( + "All fileIDs are already not in the cluster, unable to remove photos from cluster!", + ); + return cluster; + } + final List faceMlResultsInCluster = faceMlResults + .where((faceMlResult) => fileIDsInCluster.contains(faceMlResult.fileId)) + .toList(); + assert(faceMlResultsInCluster.length == fileIDsInCluster.length); + + for (var i = 0; i < fileIDsInCluster.length; i++) { + // Find the faces/embeddings associated with both the fileID and personID + final List faceIDs = faceMlResultsInCluster[i].allFaceIds; + final List faceIDsInCluster = cluster.faceIDs; + final List relevantFaceIDs = + faceIDsInCluster.where((faceID) => faceIDs.contains(faceID)).toList(); + if (relevantFaceIDs.isEmpty) { + _logger.severe( + "No faces found in both cluster and file, unable to remove photo from cluster!", + ); + throw ArgumentError( + "No faces found in both cluster and file, unable to remove photo from cluster!", + ); + } + + // Set the embeddings to [10, 10,..., 10] and save the updated faceMlResult + faceMlResultsInCluster[i].setEmbeddingsToTen(relevantFaceIDs); + await _mlDatabase.updateFaceMlResult(faceMlResultsInCluster[i]); + + // Make sure there is a manual override for [10, 10,..., 10] embeddings (not actually here, but in building the clusters, see _checkIfClusterIsDeleted function) + + // Manually remove the fileID from the cluster + cluster.removeFileId(fileIDsInCluster[i]); + } + + // TODO: see below + // Re-cluster and check if this leads to more deletions. If so, save them and ask the user if they want to delete them too. 
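+ // Note: the [10, 10, ..., 10] value written above is a sentinel rather than a real embedding; the + // cluster-building code is assumed to recognize it explicitly (see the _checkIfClusterIsDeleted + // reference above), so these faces never re-enter a cluster after re-indexing.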
+ executedFeedbackCount++; + if (executedFeedbackCount % _reclusterFeedbackThreshold == 0) { + // await recluster(); + } + + // Update the cluster in the database + await _mlDatabase.updateClusterResult(cluster); + + // TODO: see below + // Save the given feedback to the database + final removePhotoFeedback = RemovePhotosClusterFeedback( + medoid: cluster.medoid, + medoidDistanceThreshold: cluster.medoidDistanceThreshold, + removedPhotosFileID: fileIDsInCluster, + ); + await _mlDatabase.createClusterFeedback( + removePhotoFeedback, + skipIfSimilarFeedbackExists: false, + ); + + // Return the updated cluster + return cluster; + } + + Future<ClusterResult> addPhotosToCluster(List<int> fileIDs, int personID) async { + _logger.info( + 'addPhotosToCluster called with fileIDs $fileIDs and personID $personID', + ); + + if (fileIDs.isEmpty) { + _logger.severe( + "No fileIDs given, unable to add photos to cluster!", + ); + throw ArgumentError( + "No fileIDs given, unable to add photos to cluster!", + ); + } + + // Get the relevant cluster + final ClusterResult? cluster = await _mlDatabase.getClusterResult(personID); + if (cluster == null) { + _logger.severe( + "No cluster found for personID $personID, unable to add photos to non-existent cluster!", + ); + throw ArgumentError( + "No cluster found for personID $personID, unable to add photos to non-existent cluster!", + ); + } + + // Check if at least one of the files is not in the cluster. If all files are already in the cluster, return the cluster. + final List<int> fileIDsNotInCluster = fileIDs + .where((fileID) => !cluster.uniqueFileIds.contains(fileID)) + .toList(); + if (fileIDsNotInCluster.isEmpty) { + _logger.warning( + "All fileIDs are already in the cluster, unable to add new photos to cluster!", + ); + return cluster; + } + final List<String> faceIDsNotInCluster = fileIDsNotInCluster + .map((fileID) => FaceDetectionRelative.toFaceIDEmpty(fileID: fileID)) + .toList(); + + // Add the new files to the cluster + cluster.addFileIDsAndFaceIDs(fileIDsNotInCluster, faceIDsNotInCluster); + + // Update the cluster in the database + await _mlDatabase.updateClusterResult(cluster); + + // Build the addPhotoFeedback + final AddPhotosClusterFeedback addPhotosFeedback = AddPhotosClusterFeedback( + medoid: cluster.medoid, + medoidDistanceThreshold: cluster.medoidDistanceThreshold, + addedPhotoFileIDs: fileIDsNotInCluster, + ); + + // TODO: check for exact match and update feedback if necessary + + // Save the addPhotoFeedback to the database + await _mlDatabase.createClusterFeedback( + addPhotosFeedback, + skipIfSimilarFeedbackExists: false, + ); + + // Return the updated cluster + return cluster; + } + + /// Deletes the given cluster completely. + Future<void> deleteCluster(int personID) async { + _logger.info( + 'deleteCluster called with personID $personID', + ); + + // Get the relevant cluster + final cluster = await _mlDatabase.getClusterResult(personID); + if (cluster == null) { + _logger.severe( + "No cluster found for personID $personID, unable to delete non-existent cluster!", + ); + throw ArgumentError( + "No cluster found for personID $personID, unable to delete non-existent cluster!", + ); + } + + // Delete the cluster from the database + await _mlDatabase.deleteClusterResult(cluster.personId); + + // TODO: look into the right threshold distance. 
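+ // The medoid plus medoidDistanceThreshold acts as the feedback's "address": after a future re-index, + // the feedback is assumed to re-attach to whichever new cluster has a medoid within the threshold in + // cosine distance (see ClusterFeedback.looselyMatchesMedoid).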
+ // Build the deleteClusterFeedback + final DeleteClusterFeedback deleteClusterFeedback = DeleteClusterFeedback( + medoid: cluster.medoid, + medoidDistanceThreshold: cluster.medoidDistanceThreshold, + ); + + // TODO: maybe I should merge the two feedbacks if they are similar enough? Or alternatively, I keep them both? + // Check if feedback doesn't already exist + if (await _mlDatabase + .doesSimilarClusterFeedbackExist(deleteClusterFeedback)) { + _logger.warning( + "Similar feedback already exists for deleting cluster $personID, not saving a duplicate feedback entry.", + ); + return; + } + + // Save the deleteClusterFeedback to the database + await _mlDatabase.createClusterFeedback(deleteClusterFeedback); + } + + /// Renames the given cluster and/or sets the thumbnail of the given cluster. + /// + /// Requires either a [customName] or a [customFaceID]. If both are given, both are used. If neither is given, an error is thrown. + Future<ClusterResult> renameOrSetThumbnailCluster( + int personID, { + String? customName, + String? customFaceID, + }) async { + _logger.info( + 'renameOrSetThumbnailCluster called with personID $personID, customName $customName, and customFaceID $customFaceID', + ); + + if (customFaceID != null && + FaceDetectionRelative.isFaceIDEmpty(customFaceID)) { + _logger.severe( + "customFaceID $customFaceID belongs to an empty detection, unable to set it as thumbnail of cluster!", + ); + customFaceID = null; + } + if (customName == null && customFaceID == null) { + _logger.severe( + "No name or faceID given, unable to rename or set thumbnail of cluster!", + ); + throw ArgumentError( + "No name or faceID given, unable to rename or set thumbnail of cluster!", + ); + } + + // Get the relevant cluster + final cluster = await _mlDatabase.getClusterResult(personID); + if (cluster == null) { + _logger.severe( + "No cluster found for personID $personID, unable to rename or set thumbnail of non-existent cluster!", + ); + throw ArgumentError( + "No cluster found for personID $personID, unable to rename or set thumbnail of non-existent cluster!", + ); + } + + // Update the cluster + if (customName != null) cluster.setUserDefinedName = customName; + if (customFaceID != null) cluster.setThumbnailFaceId = customFaceID; + + // Update the cluster in the database + await _mlDatabase.updateClusterResult(cluster); + + // Build the RenameOrCustomThumbnailClusterFeedback + final RenameOrCustomThumbnailClusterFeedback renameClusterFeedback = + RenameOrCustomThumbnailClusterFeedback( + medoid: cluster.medoid, + medoidDistanceThreshold: cluster.medoidDistanceThreshold, + customName: customName, + customThumbnailFaceId: customFaceID, + ); + + // TODO: maybe I should merge the two feedbacks if they are similar enough? + // Check if feedback doesn't already exist + final matchingFeedbacks = + await _mlDatabase.getAllMatchingClusterFeedback(renameClusterFeedback); + for (final matchingFeedback in matchingFeedbacks) { + // Update the current feedback wherever possible + renameClusterFeedback.customName ??= matchingFeedback.customName; + renameClusterFeedback.customThumbnailFaceId ??= + matchingFeedback.customThumbnailFaceId; + + // Delete the old feedback (since we want the user to be able to overwrite their earlier feedback) + await _mlDatabase.deleteClusterFeedback(matchingFeedback); + } + + // Save the RenameOrCustomThumbnailClusterFeedback to the database + await _mlDatabase.createClusterFeedback(renameClusterFeedback); + + // Return the updated cluster + return cluster; + } + + /// Merges the given clusters. The largest cluster is kept and the other clusters are deleted. 
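+ /// + /// Usage sketch (hypothetical person IDs): + /// ```dart + /// final merged = await FaceFeedbackService.instance.mergeClusters([3, 7, 12]); + /// print('merged cluster now has ${merged.clusterSize} photos'); + /// ```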
+ /// + /// Requires at least two [personIDs]; the corresponding clusters are fetched from the database. + Future<ClusterResult> mergeClusters(List<int> personIDs) async { + _logger.info( + 'mergeClusters called with personIDs $personIDs', + ); + + // Get the relevant clusters + final List<ClusterResult> clusters = + await _mlDatabase.getSelectedClusterResults(personIDs); + if (clusters.length <= 1) { + _logger.severe( + "Only ${clusters.length} cluster(s) found for personIDs $personIDs, need at least two clusters to merge!", + ); + throw ArgumentError( + "Only ${clusters.length} cluster(s) found for personIDs $personIDs, need at least two clusters to merge!", + ); + } + + // Find the largest cluster + clusters.sort((a, b) => b.clusterSize.compareTo(a.clusterSize)); + final ClusterResult largestCluster = clusters.first; + + // Now iterate through the clusters to be merged and deleted + for (var i = 1; i < clusters.length; i++) { + final ClusterResult clusterToBeMerged = clusters[i]; + + // Add the files and faces of the cluster to be merged to the largest cluster + largestCluster.addFileIDsAndFaceIDs( + clusterToBeMerged.fileIDsIncludingPotentialDuplicates, + clusterToBeMerged.faceIDs, + ); + + // TODO: maybe I should wrap the logic below in a separate function, since it's also used in renameOrSetThumbnailCluster + // Merge any names and thumbnails if the largest cluster doesn't have them + bool shouldCreateNamingFeedback = false; + String? nameToBeMerged; + String? thumbnailToBeMerged; + if (!largestCluster.hasUserDefinedName && + clusterToBeMerged.hasUserDefinedName) { + largestCluster.setUserDefinedName = clusterToBeMerged.userDefinedName!; + nameToBeMerged = clusterToBeMerged.userDefinedName!; + shouldCreateNamingFeedback = true; + } + if (!largestCluster.thumbnailFaceIdIsUserDefined && + clusterToBeMerged.thumbnailFaceIdIsUserDefined) { + largestCluster.setThumbnailFaceId = clusterToBeMerged.thumbnailFaceId; + thumbnailToBeMerged = clusterToBeMerged.thumbnailFaceId; + shouldCreateNamingFeedback = true; + } + if (shouldCreateNamingFeedback) { + final RenameOrCustomThumbnailClusterFeedback renameClusterFeedback = + RenameOrCustomThumbnailClusterFeedback( + medoid: largestCluster.medoid, + medoidDistanceThreshold: largestCluster.medoidDistanceThreshold, + customName: nameToBeMerged, + customThumbnailFaceId: thumbnailToBeMerged, + ); + // Check if feedback doesn't already exist + final matchingFeedbacks = await _mlDatabase + .getAllMatchingClusterFeedback(renameClusterFeedback); + for (final matchingFeedback in matchingFeedbacks) { + // Update the current feedback wherever possible + renameClusterFeedback.customName ??= matchingFeedback.customName; + renameClusterFeedback.customThumbnailFaceId ??= + matchingFeedback.customThumbnailFaceId; + + // Delete the old feedback (since we want the user to be able to overwrite their earlier feedback) + await _mlDatabase.deleteClusterFeedback(matchingFeedback); + } + + // Save the RenameOrCustomThumbnailClusterFeedback to the database + await _mlDatabase.createClusterFeedback(renameClusterFeedback); + } + + // Build the mergeClusterFeedback + final MergeClusterFeedback mergeClusterFeedback = MergeClusterFeedback( + medoid: clusterToBeMerged.medoid, + medoidDistanceThreshold: clusterToBeMerged.medoidDistanceThreshold, + medoidToMoveTo: largestCluster.medoid, + ); + + // Save the mergeClusterFeedback to the database and delete any old matching feedbacks + final matchingFeedbacks = + await _mlDatabase.getAllMatchingClusterFeedback(mergeClusterFeedback); + for (final matchingFeedback in 
matchingFeedbacks) { + await _mlDatabase.deleteClusterFeedback(matchingFeedback); + } + await _mlDatabase.createClusterFeedback(mergeClusterFeedback); + + // Delete the cluster from the database + await _mlDatabase.deleteClusterResult(clusterToBeMerged.personId); + } + + // TODO: should I update the medoid of this new cluster? My intuition says no, but I'm not sure. + // Update the largest cluster in the database + await _mlDatabase.updateClusterResult(largestCluster); + + // Return the merged cluster + return largestCluster; + } +} diff --git a/mobile/lib/services/face_ml/face_feedback.dart/feedback.dart b/mobile/lib/services/face_ml/face_feedback.dart/feedback.dart new file mode 100644 index 000000000..320ec64e9 --- /dev/null +++ b/mobile/lib/services/face_ml/face_feedback.dart/feedback.dart @@ -0,0 +1,34 @@ +import "package:photos/models/ml/ml_versions.dart"; +import "package:photos/services/face_ml/face_feedback.dart/feedback_types.dart"; +import "package:uuid/uuid.dart"; + +abstract class Feedback { + final FeedbackType type; + final String feedbackID; + final DateTime timestamp; + final int madeOnFaceMlVersion; + final int madeOnClusterMlVersion; + + get typeString => type.toValueString(); + + get timestampString => timestamp.toIso8601String(); + + Feedback( + this.type, { + String? feedbackID, + DateTime? timestamp, + int? madeOnFaceMlVersion, + int? madeOnClusterMlVersion, + }) : feedbackID = feedbackID ?? const Uuid().v4(), + timestamp = timestamp ?? DateTime.now(), + madeOnFaceMlVersion = madeOnFaceMlVersion ?? faceMlVersion, + madeOnClusterMlVersion = madeOnClusterMlVersion ?? clusterMlVersion; + + Map toJson(); + + String toJsonString(); + + // Feedback fromJson(Map json); + + // Feedback fromJsonString(String jsonString); +} diff --git a/mobile/lib/services/face_ml/face_feedback.dart/feedback_types.dart b/mobile/lib/services/face_ml/face_feedback.dart/feedback_types.dart new file mode 100644 index 000000000..8451d4790 --- /dev/null +++ b/mobile/lib/services/face_ml/face_feedback.dart/feedback_types.dart @@ -0,0 +1,26 @@ +enum FeedbackType { + removePhotosClusterFeedback, + addPhotosClusterFeedback, + deleteClusterFeedback, + mergeClusterFeedback, + renameOrCustomThumbnailClusterFeedback; // I have merged renameClusterFeedback and customThumbnailClusterFeedback, since I suspect they will be used together often + + factory FeedbackType.fromValueString(String value) { + switch (value) { + case 'deleteClusterFeedback': + return FeedbackType.deleteClusterFeedback; + case 'mergeClusterFeedback': + return FeedbackType.mergeClusterFeedback; + case 'renameOrCustomThumbnailClusterFeedback': + return FeedbackType.renameOrCustomThumbnailClusterFeedback; + case 'removePhotoClusterFeedback': + return FeedbackType.removePhotosClusterFeedback; + case 'addPhotoClusterFeedback': + return FeedbackType.addPhotosClusterFeedback; + default: + throw Exception('Invalid FeedbackType: $value'); + } + } + + String toValueString() => name; +} diff --git a/mobile/lib/services/face_ml/face_ml_exceptions.dart b/mobile/lib/services/face_ml/face_ml_exceptions.dart new file mode 100644 index 000000000..78a4bcb1f --- /dev/null +++ b/mobile/lib/services/face_ml/face_ml_exceptions.dart @@ -0,0 +1,30 @@ + +class GeneralFaceMlException implements Exception { + final String message; + + GeneralFaceMlException(this.message); + + @override + String toString() => 'GeneralFaceMlException: $message'; +} + +class CouldNotRetrieveAnyFileData implements Exception {} + +class CouldNotInitializeFaceDetector implements 
Exception {} + +class CouldNotRunFaceDetector implements Exception {} + +class CouldNotWarpAffine implements Exception {} + +class CouldNotInitializeFaceEmbeddor implements Exception {} + +class InputProblemFaceEmbeddor implements Exception { + final String message; + + InputProblemFaceEmbeddor(this.message); + + @override + String toString() => 'InputProblemFaceEmbeddor: $message'; +} + +class CouldNotRunFaceEmbeddor implements Exception {} \ No newline at end of file diff --git a/mobile/lib/services/face_ml/face_ml_methods.dart b/mobile/lib/services/face_ml/face_ml_methods.dart new file mode 100644 index 000000000..a6c967e52 --- /dev/null +++ b/mobile/lib/services/face_ml/face_ml_methods.dart @@ -0,0 +1,90 @@ +import "package:photos/services/face_ml/face_ml_version.dart"; + +/// Represents a face detection method with a specific version. +class FaceDetectionMethod extends VersionedMethod { + /// Creates a [FaceDetectionMethod] instance with a specific `method` and `version` (default `1`) + FaceDetectionMethod(String method, {int version = 1}) + : super(method, version); + + /// Creates a [FaceDetectionMethod] instance with 'Empty method' as the method, and a specific `version` (default `1`) + const FaceDetectionMethod.empty() : super.empty(); + + /// Creates a [FaceDetectionMethod] instance with 'BlazeFace' as the method, and a specific `version` (default `1`) + FaceDetectionMethod.blazeFace({int version = 1}) + : super('BlazeFace', version); + + static FaceDetectionMethod fromMlVersion(int version) { + switch (version) { + case 1: + return FaceDetectionMethod.blazeFace(version: version); + default: + return const FaceDetectionMethod.empty(); + } + } + + static FaceDetectionMethod fromJson(Map json) { + return FaceDetectionMethod( + json['method'], + version: json['version'], + ); + } +} + +/// Represents a face alignment method with a specific version. +class FaceAlignmentMethod extends VersionedMethod { + /// Creates a [FaceAlignmentMethod] instance with a specific `method` and `version` (default `1`) + FaceAlignmentMethod(String method, {int version = 1}) + : super(method, version); + + /// Creates a [FaceAlignmentMethod] instance with 'Empty method' as the method, and a specific `version` (default `1`) + const FaceAlignmentMethod.empty() : super.empty(); + + /// Creates a [FaceAlignmentMethod] instance with 'ArcFace' as the method, and a specific `version` (default `1`) + FaceAlignmentMethod.arcFace({int version = 1}) : super('ArcFace', version); + + static FaceAlignmentMethod fromMlVersion(int version) { + switch (version) { + case 1: + return FaceAlignmentMethod.arcFace(version: version); + default: + return const FaceAlignmentMethod.empty(); + } + } + + static FaceAlignmentMethod fromJson(Map json) { + return FaceAlignmentMethod( + json['method'], + version: json['version'], + ); + } +} + +/// Represents a face embedding method with a specific version. 
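+/// +/// JSON round-trip sketch (illustrative; the map keys are taken from `fromJson` below): +/// ```dart +/// final method = FaceEmbeddingMethod.fromJson({'method': 'MobileFaceNet', 'version': 1}); +/// ```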
+/// Represents a face embedding method with a specific version.
+class FaceEmbeddingMethod extends VersionedMethod {
+  /// Creates a [FaceEmbeddingMethod] instance with a specific `method` and `version` (default `1`)
+  FaceEmbeddingMethod(String method, {int version = 1})
+      : super(method, version);
+
+  /// Creates a [FaceEmbeddingMethod] instance with 'Empty method' as the method and `0` as the version
+  const FaceEmbeddingMethod.empty() : super.empty();
+
+  /// Creates a [FaceEmbeddingMethod] instance with 'MobileFaceNet' as the method, and a specific `version` (default `1`)
+  FaceEmbeddingMethod.mobileFaceNet({int version = 1})
+      : super('MobileFaceNet', version);
+
+  static FaceEmbeddingMethod fromMlVersion(int version) {
+    switch (version) {
+      case 1:
+        return FaceEmbeddingMethod.mobileFaceNet(version: version);
+      default:
+        return const FaceEmbeddingMethod.empty();
+    }
+  }
+
+  static FaceEmbeddingMethod fromJson(Map<String, dynamic> json) {
+    return FaceEmbeddingMethod(
+      json['method'],
+      version: json['version'],
+    );
+  }
+}
diff --git a/mobile/lib/services/face_ml/face_ml_result.dart b/mobile/lib/services/face_ml/face_ml_result.dart
new file mode 100644
index 000000000..c770efde7
--- /dev/null
+++ b/mobile/lib/services/face_ml/face_ml_result.dart
@@ -0,0 +1,753 @@
+import "dart:convert" show jsonEncode, jsonDecode;
+
+import "package:flutter/material.dart" show Size, debugPrint, immutable;
+import "package:logging/logging.dart";
+import "package:photos/db/ml_data_db.dart";
+import "package:photos/models/file/file.dart";
+import 'package:photos/models/ml/ml_typedefs.dart';
+import "package:photos/models/ml/ml_versions.dart";
+import "package:photos/services/face_ml/blur_detection/blur_constants.dart";
+import "package:photos/services/face_ml/face_alignment/alignment_result.dart";
+import "package:photos/services/face_ml/face_clustering/cosine_distance.dart";
+import "package:photos/services/face_ml/face_detection/detection.dart";
+import "package:photos/services/face_ml/face_feedback.dart/cluster_feedback.dart";
+import "package:photos/services/face_ml/face_ml_methods.dart";
+
+final _logger = Logger('ClusterResult_FaceMlResult');
+
+// TODO: consider adding [faceMlVersion] and [clusterMlVersion] to the [ClusterResult] class.
+class ClusterResult {
+  final int personId;
+  String? userDefinedName;
+  bool get hasUserDefinedName => userDefinedName != null;
+
+  String _thumbnailFaceId;
+  bool thumbnailFaceIdIsUserDefined;
+
+  final List<int> _fileIds;
+  final List<String> _faceIds;
+
+  final Embedding medoid;
+  double medoidDistanceThreshold;
+
+  List<int> get uniqueFileIds => _fileIds.toSet().toList();
+  List<int> get fileIDsIncludingPotentialDuplicates => _fileIds;
+
+  List<String> get faceIDs => _faceIds;
+
+  String get thumbnailFaceId => _thumbnailFaceId;
+
+  int get thumbnailFileId => _getFileIdFromFaceId(_thumbnailFaceId);
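+
+  /// Example (sketch): face IDs embed the owning file ID as the prefix before
+  /// the first underscore, which is what [thumbnailFileId] above relies on via
+  /// `_getFileIdFromFaceId`. The suffix format comes from
+  /// [FaceDetectionRelative.toFaceID] and is treated as opaque here:
+  /// ```dart
+  /// int fileIdFromFaceId(String faceId) => int.parse(faceId.split('_')[0]);
+  ///
+  /// // hypothetical suffix, for illustration only:
+  /// // fileIdFromFaceId('12345_face0') == 12345
+  /// ```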
+  /// Sets the thumbnail faceId to the given faceId.
+  /// Throws an exception if the faceId is not in the list of faceIds.
+  set setThumbnailFaceId(String faceId) {
+    if (!_faceIds.contains(faceId)) {
+      throw Exception(
+        "The faceId $faceId is not in the list of faceIds: $_faceIds",
+      );
+    }
+    _thumbnailFaceId = faceId;
+    thumbnailFaceIdIsUserDefined = true;
+  }
+
+  /// Sets the [userDefinedName] to the given [customName]
+  set setUserDefinedName(String customName) {
+    userDefinedName = customName;
+  }
+
+  int get clusterSize => _fileIds.toSet().length;
+
+  ClusterResult({
+    required this.personId,
+    required String thumbnailFaceId,
+    required List<int> fileIds,
+    required List<String> faceIds,
+    required this.medoid,
+    required this.medoidDistanceThreshold,
+    this.userDefinedName,
+    this.thumbnailFaceIdIsUserDefined = false,
+  })  : _thumbnailFaceId = thumbnailFaceId,
+        _faceIds = faceIds,
+        _fileIds = fileIds;
+
+  void addFileIDsAndFaceIDs(List<int> fileIDs, List<String> faceIDs) {
+    assert(fileIDs.length == faceIDs.length);
+    _fileIds.addAll(fileIDs);
+    _faceIds.addAll(faceIDs);
+  }
+
+  // TODO: Consider whether we should recalculate the medoid and threshold when deleting or adding a file from the cluster
+  int removeFileId(int fileId) {
+    assert(_fileIds.length == _faceIds.length);
+    if (!_fileIds.contains(fileId)) {
+      throw Exception(
+        "The fileId $fileId is not in the list of fileIds: $_fileIds, so it's not in the cluster and cannot be removed.",
+      );
+    }
+
+    int removedCount = 0;
+    for (var i = 0; i < _fileIds.length; i++) {
+      if (_fileIds[i] == fileId) {
+        assert(_getFileIdFromFaceId(_faceIds[i]) == fileId);
+        _fileIds.removeAt(i);
+        _faceIds.removeAt(i);
+        debugPrint(
+          "Removed fileId $fileId from cluster $personId at index ${i + removedCount}",
+        );
+        i--; // Adjust index due to removal
+        removedCount++;
+      }
+    }
+
+    _ensureClusterSizeIsAboveMinimum();
+
+    return removedCount;
+  }
+
+  int addFileID(int fileID) {
+    assert(_fileIds.length == _faceIds.length);
+    if (_fileIds.contains(fileID)) {
+      return 0;
+    }
+
+    _fileIds.add(fileID);
+    _faceIds.add(FaceDetectionRelative.toFaceIDEmpty(fileID: fileID));
+
+    return 1;
+  }
+
+  void ensureThumbnailFaceIdIsInCluster() {
+    if (!_faceIds.contains(_thumbnailFaceId)) {
+      _thumbnailFaceId = _faceIds[0];
+    }
+  }
+
+  void _ensureClusterSizeIsAboveMinimum() {
+    if (clusterSize < minimumClusterSize) {
+      throw Exception(
+        "Cluster size is below minimum cluster size of $minimumClusterSize",
+      );
+    }
+  }
+
+  Map<String, dynamic> _toJson() => {
+        'personId': personId,
+        'thumbnailFaceId': _thumbnailFaceId,
+        'fileIds': _fileIds,
+        'faceIds': _faceIds,
+        'medoid': medoid,
+        'medoidDistanceThreshold': medoidDistanceThreshold,
+        if (userDefinedName != null) 'userDefinedName': userDefinedName,
+        'thumbnailFaceIdIsUserDefined': thumbnailFaceIdIsUserDefined,
+      };
+
+  String toJsonString() => jsonEncode(_toJson());
+
+  static ClusterResult _fromJson(Map<String, dynamic> json) {
+    return ClusterResult(
+      personId: json['personId'] ?? -1,
+      thumbnailFaceId: json['thumbnailFaceId'] ?? '',
+      fileIds:
+          (json['fileIds'] as List?)?.map((item) => item as int).toList() ?? [],
+      faceIds:
+          (json['faceIds'] as List?)?.map((item) => item as String).toList() ??
+              [],
+      medoid:
+          (json['medoid'] as List?)?.map((item) => item as double).toList() ??
+              [],
+      medoidDistanceThreshold: json['medoidDistanceThreshold'] ?? 0,
+      userDefinedName: json['userDefinedName'],
+      thumbnailFaceIdIsUserDefined:
+          json['thumbnailFaceIdIsUserDefined'] as bool,
+    );
+  }
+
+  static ClusterResult fromJsonString(String jsonString) {
+    return _fromJson(jsonDecode(jsonString));
+  }
+}
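+
+/// Example (sketch): applying user edits to a cluster through the setters
+/// above; `cluster` is a hypothetical existing [ClusterResult]:
+/// ```dart
+/// void renameAndRethumb(ClusterResult cluster, String faceId) {
+///   cluster.setUserDefinedName = 'Alice'; // plain assignment-style setter
+///   cluster.setThumbnailFaceId = faceId; // throws if faceId is not in the cluster
+///   assert(cluster.hasUserDefinedName);
+/// }
+/// ```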
+class ClusterResultBuilder {
+  int personId = -1;
+  String? userDefinedName;
+  String thumbnailFaceId = '';
+  bool thumbnailFaceIdIsUserDefined = false;
+
+  List<int> fileIds = [];
+  List<String> faceIds = [];
+
+  List<Embedding> embeddings = [];
+  Embedding medoid = [];
+  double medoidDistanceThreshold = 0;
+  bool medoidAndThresholdCalculated = false;
+  final int k = 5;
+
+  ClusterResultBuilder.createFromIndices({
+    required List<int> clusterIndices,
+    required List<int> labels,
+    required List<Embedding> allEmbeddings,
+    required List<int> allFileIds,
+    required List<String> allFaceIds,
+  }) {
+    final clusteredFileIds =
+        clusterIndices.map((fileIndex) => allFileIds[fileIndex]).toList();
+    final clusteredFaceIds =
+        clusterIndices.map((fileIndex) => allFaceIds[fileIndex]).toList();
+    final clusteredEmbeddings =
+        clusterIndices.map((fileIndex) => allEmbeddings[fileIndex]).toList();
+    personId = labels[clusterIndices[0]];
+    fileIds = clusteredFileIds;
+    faceIds = clusteredFaceIds;
+    thumbnailFaceId = faceIds[0];
+    embeddings = clusteredEmbeddings;
+  }
+
+  void calculateAndSetMedoidAndThreshold() {
+    if (embeddings.isEmpty) {
+      throw Exception("Cannot calculate medoid and threshold for empty list");
+    }
+
+    // Calculate the medoid and threshold
+    final (tempMedoid, distanceThreshold) =
+        _calculateMedoidAndDistanceThreshold(embeddings);
+
+    // Update the medoid
+    medoid = List<double>.from(tempMedoid);
+
+    // Update the medoidDistanceThreshold as the distance of the medoid to its k-th nearest neighbor
+    medoidDistanceThreshold = distanceThreshold;
+
+    medoidAndThresholdCalculated = true;
+  }
+
+  (List<double>, double) _calculateMedoidAndDistanceThreshold(
+    List<List<double>> embeddings,
+  ) {
+    double minDistance = double.infinity;
+    List<double>? medoid;
+
+    // Calculate the distance between all pairs
+    for (int i = 0; i < embeddings.length; ++i) {
+      double totalDistance = 0;
+      for (int j = 0; j < embeddings.length; ++j) {
+        if (i != j) {
+          totalDistance += cosineDistance(embeddings[i], embeddings[j]);
+
+          // Break early if we already exceed minDistance
+          if (totalDistance > minDistance) {
+            break;
+          }
+        }
+      }
+
+      // Find the minimum total distance
+      if (totalDistance < minDistance) {
+        minDistance = totalDistance;
+        medoid = embeddings[i];
+      }
+    }
+
+    // Now, calculate the k-th nearest neighbor for the medoid
+    final List<double> distancesToMedoid = [];
+    for (List<double> embedding in embeddings) {
+      if (embedding != medoid) {
+        distancesToMedoid.add(cosineDistance(medoid!, embedding));
+      }
+    }
+    distancesToMedoid.sort();
+    // TODO: empirically determine the best k. It should probably scale with cluster size, larger for large clusters and smaller for small ones, since there are many very small clusters and only a few very large ones.
+    final double kthDistance = distancesToMedoid[
+        distancesToMedoid.length >= k ? k - 1 : distancesToMedoid.length - 1];
+
+    return (medoid!, kthDistance);
+  }
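+
+  /// Example (sketch): the medoid is the member embedding with the smallest
+  /// summed cosine distance to all other members. A simplified version of the
+  /// method above, without its early-exit optimization and threshold
+  /// computation, using the same `cosineDistance` helper:
+  /// ```dart
+  /// List<double> medoidOf(List<List<double>> embeddings) {
+  ///   var best = embeddings.first;
+  ///   var bestSum = double.infinity;
+  ///   for (final candidate in embeddings) {
+  ///     double sum = 0;
+  ///     for (final other in embeddings) {
+  ///       sum += cosineDistance(candidate, other);
+  ///     }
+  ///     if (sum < bestSum) {
+  ///       bestSum = sum;
+  ///       best = candidate;
+  ///     }
+  ///   }
+  ///   return best;
+  /// }
+  /// ```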
+
+  Future<bool> _checkIfClusterIsDeleted() async {
+    assert(medoidAndThresholdCalculated);
+
+    // Check if the medoid is the default medoid for deleted faces
+    if (cosineDistance(medoid, List.filled(medoid.length, 10.0)) < 0.001) {
+      return true;
+    }
+
+    final tempFeedback = DeleteClusterFeedback(
+      medoid: medoid,
+      medoidDistanceThreshold: medoidDistanceThreshold,
+    );
+    return await MlDataDB.instance
+        .doesSimilarClusterFeedbackExist(tempFeedback);
+  }
+
+  Future<void> _checkAndAddPhotos() async {
+    assert(medoidAndThresholdCalculated);
+
+    final tempFeedback = AddPhotosClusterFeedback(
+      medoid: medoid,
+      medoidDistanceThreshold: medoidDistanceThreshold,
+      addedPhotoFileIDs: [],
+    );
+    final allAddPhotosFeedbacks =
+        await MlDataDB.instance.getAllMatchingClusterFeedback(tempFeedback);
+
+    for (final addPhotosFeedback in allAddPhotosFeedbacks) {
+      final fileIDsToAdd = addPhotosFeedback.addedPhotoFileIDs;
+      final faceIDsToAdd = fileIDsToAdd
+          .map((fileID) => FaceDetectionRelative.toFaceIDEmpty(fileID: fileID))
+          .toList();
+      addFileIDsAndFaceIDs(fileIDsToAdd, faceIDsToAdd);
+    }
+  }
+
+  Future<void> _checkAndAddCustomName() async {
+    assert(medoidAndThresholdCalculated);
+
+    final tempFeedback = RenameOrCustomThumbnailClusterFeedback(
+      medoid: medoid,
+      medoidDistanceThreshold: medoidDistanceThreshold,
+      customName: 'test',
+    );
+    final allRenameFeedbacks =
+        await MlDataDB.instance.getAllMatchingClusterFeedback(tempFeedback);
+
+    for (final nameFeedback in allRenameFeedbacks) {
+      userDefinedName ??= nameFeedback.customName;
+      if (!thumbnailFaceIdIsUserDefined) {
+        thumbnailFaceId = nameFeedback.customThumbnailFaceId ?? thumbnailFaceId;
+        thumbnailFaceIdIsUserDefined =
+            nameFeedback.customThumbnailFaceId != null;
+      }
+    }
+    return;
+  }
+
+  void changeThumbnailFaceId(String faceId) {
+    if (!faceIds.contains(faceId)) {
+      throw Exception(
+        "The faceId $faceId is not in the list of faceIds: $faceIds",
+      );
+    }
+    thumbnailFaceId = faceId;
+  }
+
+  void addFileIDsAndFaceIDs(List<int> addedFileIDs, List<String> addedFaceIDs) {
+    assert(addedFileIDs.length == addedFaceIDs.length);
+    fileIds.addAll(addedFileIDs);
+    faceIds.addAll(addedFaceIDs);
+  }
+
+  static Future<List<ClusterResult>> buildClusters(
+    List<ClusterResultBuilder> clusterBuilders,
+  ) async {
+    final List<int> deletedClusterIndices = [];
+    for (var i = 0; i < clusterBuilders.length; i++) {
+      final clusterBuilder = clusterBuilders[i];
+      clusterBuilder.calculateAndSetMedoidAndThreshold();
+
+      // Check if the cluster has been deleted
+      if (await clusterBuilder._checkIfClusterIsDeleted()) {
+        deletedClusterIndices.add(i);
+      }
+
+      await clusterBuilder._checkAndAddPhotos();
+    }
+
+    // Check if a cluster should be merged with another cluster
+    for (var i = 0; i < clusterBuilders.length; i++) {
+      // Don't check for clusters that have been deleted
+      if (deletedClusterIndices.contains(i)) {
+        continue;
+      }
+      final clusterBuilder = clusterBuilders[i];
+      final List<MergeClusterFeedback> allMatchingMergeFeedback =
+          await MlDataDB.instance.getAllMatchingClusterFeedback(
+        MergeClusterFeedback(
+          medoid: clusterBuilder.medoid,
+          medoidDistanceThreshold: clusterBuilder.medoidDistanceThreshold,
+          medoidToMoveTo: clusterBuilder.medoid,
+        ),
+      );
+      if (allMatchingMergeFeedback.isEmpty) {
+        continue;
+      }
+      // Merge the cluster using the first matching merge feedback
+      final mainFeedback = allMatchingMergeFeedback.first;
+      if (allMatchingMergeFeedback.length > 1) {
+        // Known issue: when several merge feedbacks match, only the first one is applied and the rest are ignored.
+        _logger.warning(
+          "There are ${allMatchingMergeFeedback.length} merge feedbacks for cluster ${clusterBuilder.personId}. Using the first one.",
+        );
+      }
+      for (var j = 0; j < clusterBuilders.length; j++) {
+        if (i == j) continue;
+        final clusterBuilderToMergeTo = clusterBuilders[j];
+        final distance = cosineDistance(
+          // The medoids of all builders have already been calculated in the first loop above, so this comparison is safe.
+          mainFeedback.medoidToMoveTo,
+          clusterBuilderToMergeTo.medoid,
+        );
+        if (distance < mainFeedback.medoidDistanceThreshold ||
+            distance < clusterBuilderToMergeTo.medoidDistanceThreshold) {
+          clusterBuilderToMergeTo.addFileIDsAndFaceIDs(
+            clusterBuilder.fileIds,
+            clusterBuilder.faceIds,
+          );
+          deletedClusterIndices.add(i);
+        }
+      }
+    }
+
+    final clusterResults = <ClusterResult>[];
+    for (var i = 0; i < clusterBuilders.length; i++) {
+      // Don't build the cluster if it has been deleted or merged
+      if (deletedClusterIndices.contains(i)) {
+        continue;
+      }
+      final clusterBuilder = clusterBuilders[i];
+      // Check if the cluster has a custom name or thumbnail
+      await clusterBuilder._checkAndAddCustomName();
+
+      // Build the clusterResult
+      clusterResults.add(
+        ClusterResult(
+          personId: clusterBuilder.personId,
+          thumbnailFaceId: clusterBuilder.thumbnailFaceId,
+          fileIds: clusterBuilder.fileIds,
+          faceIds: clusterBuilder.faceIds,
+          medoid: clusterBuilder.medoid,
+          medoidDistanceThreshold: clusterBuilder.medoidDistanceThreshold,
+          userDefinedName: clusterBuilder.userDefinedName,
+          thumbnailFaceIdIsUserDefined:
+              clusterBuilder.thumbnailFaceIdIsUserDefined,
+        ),
+      );
+    }
+
+    return clusterResults;
+  }
+
+  // TODO: this function should incorporate the user's feedback. It should also stay nullable, since the user may want to delete the cluster.
+  Future<ClusterResult?> _buildSingleCluster() async {
+    calculateAndSetMedoidAndThreshold();
+    if (await _checkIfClusterIsDeleted()) {
+      return null;
+    }
+    await _checkAndAddCustomName();
+    return ClusterResult(
+      personId: personId,
+      thumbnailFaceId: thumbnailFaceId,
+      fileIds: fileIds,
+      faceIds: faceIds,
+      medoid: medoid,
+      medoidDistanceThreshold: medoidDistanceThreshold,
+    );
+  }
+}
+
+@immutable
+class FaceMlResult {
+  final int fileId;
+
+  final List<FaceResult> faces;
+
+  final Size? faceDetectionImageSize;
+  final Size?
faceAlignmentImageSize; + + final int mlVersion; + final bool errorOccured; + final bool onlyThumbnailUsed; + + bool get hasFaces => faces.isNotEmpty; + int get numberOfFaces => faces.length; + + List get allFaceEmbeddings { + return faces.map((face) => face.embedding).toList(); + } + + List get allFaceIds { + return faces.map((face) => face.faceId).toList(); + } + + List get fileIdForEveryFace { + return List.filled(faces.length, fileId); + } + + FaceDetectionMethod get faceDetectionMethod => + FaceDetectionMethod.fromMlVersion(mlVersion); + FaceAlignmentMethod get faceAlignmentMethod => + FaceAlignmentMethod.fromMlVersion(mlVersion); + FaceEmbeddingMethod get faceEmbeddingMethod => + FaceEmbeddingMethod.fromMlVersion(mlVersion); + + const FaceMlResult({ + required this.fileId, + required this.faces, + required this.mlVersion, + required this.errorOccured, + required this.onlyThumbnailUsed, + required this.faceDetectionImageSize, + this.faceAlignmentImageSize, + }); + + Map _toJson() => { + 'fileId': fileId, + 'faces': faces.map((face) => face.toJson()).toList(), + 'mlVersion': mlVersion, + 'errorOccured': errorOccured, + 'onlyThumbnailUsed': onlyThumbnailUsed, + if (faceDetectionImageSize != null) + 'faceDetectionImageSize': { + 'width': faceDetectionImageSize!.width, + 'height': faceDetectionImageSize!.height, + }, + if (faceAlignmentImageSize != null) + 'faceAlignmentImageSize': { + 'width': faceAlignmentImageSize!.width, + 'height': faceAlignmentImageSize!.height, + }, + }; + + String toJsonString() => jsonEncode(_toJson()); + + static FaceMlResult _fromJson(Map json) { + return FaceMlResult( + fileId: json['fileId'], + faces: (json['faces'] as List) + .map((item) => FaceResult.fromJson(item as Map)) + .toList(), + mlVersion: json['mlVersion'], + errorOccured: json['errorOccured'] ?? false, + onlyThumbnailUsed: json['onlyThumbnailUsed'] ?? false, + faceDetectionImageSize: json['faceDetectionImageSize'] == null + ? null + : Size( + json['faceDetectionImageSize']['width'], + json['faceDetectionImageSize']['height'], + ), + faceAlignmentImageSize: json['faceAlignmentImageSize'] == null + ? null + : Size( + json['faceAlignmentImageSize']['width'], + json['faceAlignmentImageSize']['height'], + ), + ); + } + + static FaceMlResult fromJsonString(String jsonString) { + return _fromJson(jsonDecode(jsonString)); + } + + /// Sets the embeddings of the faces with the given faceIds to [10, 10,..., 10]. + /// + /// Throws an exception if a faceId is not found in the FaceMlResult. + void setEmbeddingsToTen(List faceIds) { + for (final faceId in faceIds) { + final faceIndex = faces.indexWhere((face) => face.faceId == faceId); + if (faceIndex == -1) { + throw Exception("No face found with faceId $faceId"); + } + for (var i = 0; i < faces[faceIndex].embedding.length; i++) { + faces[faceIndex].embedding[i] = 10; + } + } + } + + FaceDetectionRelative getDetectionForFaceId(String faceId) { + final faceIndex = faces.indexWhere((face) => face.faceId == faceId); + if (faceIndex == -1) { + throw Exception("No face found with faceId $faceId"); + } + return faces[faceIndex].detection; + } +} + +class FaceMlResultBuilder { + int fileId; + + List faces = []; + + Size? faceDetectionImageSize; + Size? 
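+
+  /// Example (sketch): results are persisted as JSON strings, so a round-trip
+  /// through [FaceMlResult.toJsonString] and [FaceMlResult.fromJsonString]
+  /// preserves the analysis; a hypothetical check:
+  /// ```dart
+  /// void roundTrip(FaceMlResult result) {
+  ///   final restored = FaceMlResult.fromJsonString(result.toJsonString());
+  ///   assert(restored.fileId == result.fileId);
+  ///   assert(restored.numberOfFaces == result.numberOfFaces);
+  /// }
+  /// ```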
faceAlignmentImageSize; + + int mlVersion; + bool errorOccured; + bool onlyThumbnailUsed; + + FaceMlResultBuilder({ + this.fileId = -1, + this.mlVersion = faceMlVersion, + this.errorOccured = false, + this.onlyThumbnailUsed = false, + }); + + FaceMlResultBuilder.fromEnteFile( + EnteFile file, { + this.mlVersion = faceMlVersion, + this.errorOccured = false, + this.onlyThumbnailUsed = false, + }) : fileId = file.uploadedFileID ?? -1; + + FaceMlResultBuilder.fromEnteFileID( + int fileID, { + this.mlVersion = faceMlVersion, + this.errorOccured = false, + this.onlyThumbnailUsed = false, + }) : fileId = fileID; + + void addNewlyDetectedFaces( + List faceDetections, + Size originalSize, + ) { + faceDetectionImageSize = originalSize; + for (var i = 0; i < faceDetections.length; i++) { + faces.add( + FaceResultBuilder.fromFaceDetection( + faceDetections[i], + resultBuilder: this, + ), + ); + } + } + + void addAlignmentResults( + List alignmentResults, + List blurValues, + Size imageSizeUsedForAlignment, + ) { + if (alignmentResults.length != faces.length) { + throw Exception( + "The amount of alignment results (${alignmentResults.length}) does not match the number of faces (${faces.length})", + ); + } + + for (var i = 0; i < alignmentResults.length; i++) { + faces[i].alignment = alignmentResults[i]; + faces[i].blurValue = blurValues[i]; + } + faceAlignmentImageSize = imageSizeUsedForAlignment; + } + + void addEmbeddingsToExistingFaces( + List embeddings, + ) { + if (embeddings.length != faces.length) { + throw Exception( + "The amount of embeddings (${embeddings.length}) does not match the number of faces (${faces.length})", + ); + } + for (var faceIndex = 0; faceIndex < faces.length; faceIndex++) { + faces[faceIndex].embedding = embeddings[faceIndex]; + } + } + + FaceMlResult build() { + final faceResults = []; + for (var i = 0; i < faces.length; i++) { + faceResults.add(faces[i].build()); + } + return FaceMlResult( + fileId: fileId, + faces: faceResults, + mlVersion: mlVersion, + errorOccured: errorOccured, + onlyThumbnailUsed: onlyThumbnailUsed, + faceDetectionImageSize: faceDetectionImageSize, + faceAlignmentImageSize: faceAlignmentImageSize, + ); + } + + FaceMlResult buildNoFaceDetected() { + faces = []; + return build(); + } + + FaceMlResult buildErrorOccurred() { + faces = []; + errorOccured = true; + return build(); + } +} + +@immutable +class FaceResult { + final FaceDetectionRelative detection; + final double blurValue; + final AlignmentResult alignment; + final Embedding embedding; + final int fileId; + final String faceId; + + bool get isBlurry => blurValue < kLaplacianThreshold; + + const FaceResult({ + required this.detection, + required this.blurValue, + required this.alignment, + required this.embedding, + required this.fileId, + required this.faceId, + }); + + Map toJson() => { + 'detection': detection.toJson(), + 'blurValue': blurValue, + 'alignment': alignment.toJson(), + 'embedding': embedding, + 'fileId': fileId, + 'faceId': faceId, + }; + + static FaceResult fromJson(Map json) { + return FaceResult( + detection: FaceDetectionRelative.fromJson(json['detection']), + blurValue: json['blurValue'], + alignment: AlignmentResult.fromJson(json['alignment']), + embedding: Embedding.from(json['embedding']), + fileId: json['fileId'], + faceId: json['faceId'], + ); + } +} + +class FaceResultBuilder { + FaceDetectionRelative detection = + FaceDetectionRelative.defaultInitialization(); + double blurValue = 1000; + AlignmentResult alignment = AlignmentResult.empty(); + Embedding embedding = 
[]; + int fileId = -1; + String faceId = ''; + + bool get isBlurry => blurValue < kLaplacianThreshold; + + FaceResultBuilder({ + required this.fileId, + required this.faceId, + }); + + FaceResultBuilder.fromFaceDetection( + FaceDetectionRelative faceDetection, { + required FaceMlResultBuilder resultBuilder, + }) { + fileId = resultBuilder.fileId; + faceId = faceDetection.toFaceID(fileID: resultBuilder.fileId); + detection = faceDetection; + } + + FaceResult build() { + assert(detection.allKeypoints[0][0] <= 1); + assert(detection.box[0] <= 1); + return FaceResult( + detection: detection, + blurValue: blurValue, + alignment: alignment, + embedding: embedding, + fileId: fileId, + faceId: faceId, + ); + } +} + +int _getFileIdFromFaceId(String faceId) { + return int.parse(faceId.split("_")[0]); +} diff --git a/mobile/lib/services/face_ml/face_ml_service.dart b/mobile/lib/services/face_ml/face_ml_service.dart new file mode 100644 index 000000000..c2523309e --- /dev/null +++ b/mobile/lib/services/face_ml/face_ml_service.dart @@ -0,0 +1,1149 @@ +import "dart:async"; +import "dart:developer" as dev show log; +import "dart:io" show File; +import "dart:isolate"; +import "dart:typed_data" show Uint8List, Float32List; + +import "package:computer/computer.dart"; +import "package:flutter/foundation.dart"; +import "package:flutter_image_compress/flutter_image_compress.dart"; +import "package:flutter_isolate/flutter_isolate.dart"; +import "package:logging/logging.dart"; +import "package:onnxruntime/onnxruntime.dart"; +import "package:photos/core/configuration.dart"; +import "package:photos/core/constants.dart"; +import "package:photos/core/event_bus.dart"; +import "package:photos/db/ml_data_db.dart"; +import "package:photos/events/diff_sync_complete_event.dart"; +import "package:photos/extensions/list.dart"; +import "package:photos/face/db.dart"; +import "package:photos/face/model/box.dart"; +import "package:photos/face/model/detection.dart" as face_detection; +import "package:photos/face/model/face.dart"; +import "package:photos/face/model/landmark.dart"; +import "package:photos/models/file/extensions/file_props.dart"; +import "package:photos/models/file/file.dart"; +import "package:photos/models/file/file_type.dart"; +import "package:photos/models/ml/ml_versions.dart"; +import "package:photos/services/face_ml/face_clustering/linear_clustering_service.dart"; +import "package:photos/services/face_ml/face_detection/detection.dart"; +import 'package:photos/services/face_ml/face_detection/yolov5face/onnx_face_detection.dart'; +import "package:photos/services/face_ml/face_detection/yolov5face/yolo_face_detection_exceptions.dart"; +import "package:photos/services/face_ml/face_embedding/face_embedding_exceptions.dart"; +import 'package:photos/services/face_ml/face_embedding/onnx_face_embedding.dart'; +import "package:photos/services/face_ml/face_ml_exceptions.dart"; +import "package:photos/services/face_ml/face_ml_result.dart"; +import "package:photos/services/search_service.dart"; +import "package:photos/utils/file_util.dart"; +import 'package:photos/utils/image_ml_isolate.dart'; +import "package:photos/utils/image_ml_util.dart"; +import "package:photos/utils/local_settings.dart"; +import "package:photos/utils/thumbnail_util.dart"; +import "package:synchronized/synchronized.dart"; + +enum FileDataForML { thumbnailData, fileData, compressedFileData } + +enum FaceMlOperation { analyzeImage } + +/// This class is responsible for running the full face ml pipeline on images. 
+/// +/// WARNING: For getting the ML results needed for the UI, you should use `FaceSearchService` instead of this class! +/// +/// The pipeline consists of face detection, face alignment and face embedding. +class FaceMlService { + final _logger = Logger("FaceMlService"); + + // Flutter isolate things for running the image ml pipeline + Timer? _inactivityTimer; + final Duration _inactivityDuration = const Duration(seconds: 120); + int _activeTasks = 0; + final _initLockIsolate = Lock(); + late FlutterIsolate _isolate; + late ReceivePort _receivePort = ReceivePort(); + late SendPort _mainSendPort; + + bool isIsolateSpawned = false; + + // singleton pattern + FaceMlService._privateConstructor(); + static final instance = FaceMlService._privateConstructor(); + factory FaceMlService() => instance; + + final _initLock = Lock(); + final _functionLock = Lock(); + + final _computer = Computer.shared(); + + bool isInitialized = false; + bool isImageIndexRunning = false; + int kParallelism = 15; + + Future init({bool initializeImageMlIsolate = false}) async { + return _initLock.synchronized(() async { + if (isInitialized) { + return; + } + _logger.info("init called"); + await _computer.compute(initOrtEnv); + try { + await YoloOnnxFaceDetection.instance.init(); + } catch (e, s) { + _logger.severe("Could not initialize yolo onnx", e, s); + } + if (initializeImageMlIsolate) { + try { + await ImageMlIsolate.instance.init(); + } catch (e, s) { + _logger.severe("Could not initialize image ml isolate", e, s); + } + } + try { + await FaceEmbeddingOnnx.instance.init(); + } catch (e, s) { + _logger.severe("Could not initialize mobilefacenet", e, s); + } + + isInitialized = true; + }); + } + + static void initOrtEnv() async { + OrtEnv.instance.init(); + } + + void listenIndexOnDiffSync() { + Bus.instance.on().listen((event) async { + if (LocalSettings.instance.isFaceIndexingEnabled == false) { + return; + } + unawaited(indexAllImages()); + }); + } + + Future ensureInitialized() async { + if (!isInitialized) { + await init(); + } + } + + Future release() async { + return _initLock.synchronized(() async { + _logger.info("dispose called"); + if (!isInitialized) { + return; + } + try { + await YoloOnnxFaceDetection.instance.release(); + } catch (e, s) { + _logger.severe("Could not dispose yolo onnx", e, s); + } + try { + ImageMlIsolate.instance.dispose(); + } catch (e, s) { + _logger.severe("Could not dispose image ml isolate", e, s); + } + try { + await FaceEmbeddingOnnx.instance.release(); + } catch (e, s) { + _logger.severe("Could not dispose mobilefacenet", e, s); + } + OrtEnv.instance.release(); + isInitialized = false; + }); + } + + Future initIsolate() async { + return _initLockIsolate.synchronized(() async { + if (isIsolateSpawned) return; + _logger.info("initIsolate called"); + + _receivePort = ReceivePort(); + + try { + _isolate = await FlutterIsolate.spawn( + _isolateMain, + _receivePort.sendPort, + ); + _mainSendPort = await _receivePort.first as SendPort; + isIsolateSpawned = true; + + _resetInactivityTimer(); + } catch (e) { + _logger.severe('Could not spawn isolate', e); + isIsolateSpawned = false; + } + }); + } + + Future ensureSpawnedIsolate() async { + if (!isIsolateSpawned) { + await initIsolate(); + } + } + + /// The main execution function of the isolate. 
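+
+  /// Example (sketch): requests into the isolate are sent as positional lists
+  /// of `[operation index, args map, reply SendPort]`, which is exactly what
+  /// [_isolateMain] below unpacks. A hypothetical hand-rolled call site (paths
+  /// and session addresses are placeholders):
+  /// ```dart
+  /// void sendAnalyzeRequest(SendPort toIsolate, SendPort replyPort) {
+  ///   toIsolate.send([
+  ///     FaceMlOperation.analyzeImage.index,
+  ///     <String, dynamic>{
+  ///       'enteFileID': 1,
+  ///       'smallDataPath': '/tmp/thumb.jpg',
+  ///       'largeDataPath': '/tmp/full.jpg',
+  ///       'faceDetectionAddress': 0, // real addresses come from the ONNX singletons
+  ///       'faceEmbeddingAddress': 0,
+  ///     },
+  ///     replyPort,
+  ///   ]);
+  /// }
+  /// ```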
+ static void _isolateMain(SendPort mainSendPort) async { + final receivePort = ReceivePort(); + mainSendPort.send(receivePort.sendPort); + + receivePort.listen((message) async { + final functionIndex = message[0] as int; + final function = FaceMlOperation.values[functionIndex]; + final args = message[1] as Map; + final sendPort = message[2] as SendPort; + + try { + switch (function) { + case FaceMlOperation.analyzeImage: + final int enteFileID = args["enteFileID"] as int; + final String smallDataPath = args["smallDataPath"] as String; + final String largeDataPath = args["largeDataPath"] as String; + final int faceDetectionAddress = + args["faceDetectionAddress"] as int; + final int faceEmbeddingAddress = + args["faceEmbeddingAddress"] as int; + + final resultBuilder = + FaceMlResultBuilder.fromEnteFileID(enteFileID); + + dev.log( + "Start analyzing image with uploadedFileID: $enteFileID inside the isolate", + ); + final stopwatchTotal = Stopwatch()..start(); + final stopwatch = Stopwatch()..start(); + + // Get the faces + final List faceDetectionResult = + await FaceMlService.detectFacesSync( + smallDataPath, + faceDetectionAddress, + resultBuilder: resultBuilder, + ); + + dev.log("Completed `detectFaces` function, in " + "${stopwatch.elapsedMilliseconds} ms"); + + // If no faces were detected, return a result with no faces. Otherwise, continue. + if (faceDetectionResult.isEmpty) { + dev.log( + "No faceDetectionResult, Completed analyzing image with uploadedFileID $enteFileID, in " + "${stopwatch.elapsedMilliseconds} ms"); + sendPort.send(resultBuilder.buildNoFaceDetected().toJsonString()); + break; + } + + stopwatch.reset(); + // Align the faces + final Float32List faceAlignmentResult = + await FaceMlService.alignFacesSync( + largeDataPath, + faceDetectionResult, + resultBuilder: resultBuilder, + ); + + dev.log("Completed `alignFaces` function, in " + "${stopwatch.elapsedMilliseconds} ms"); + + stopwatch.reset(); + // Get the embeddings of the faces + final embeddings = await FaceMlService.embedFacesSync( + faceAlignmentResult, + faceEmbeddingAddress, + resultBuilder: resultBuilder, + ); + + dev.log("Completed `embedBatchFaces` function, in " + "${stopwatch.elapsedMilliseconds} ms"); + + stopwatch.stop(); + stopwatchTotal.stop(); + dev.log("Finished Analyze image (${embeddings.length} faces) with " + "uploadedFileID $enteFileID, in " + "${stopwatchTotal.elapsedMilliseconds} ms"); + + sendPort.send(resultBuilder.build().toJsonString()); + break; + } + } catch (e, stackTrace) { + dev.log( + "[SEVERE] Error in FaceML isolate: $e", + error: e, + stackTrace: stackTrace, + ); + sendPort + .send({'error': e.toString(), 'stackTrace': stackTrace.toString()}); + } + }); + } + + /// The common method to run any operation in the isolate. It sends the [message] to [_isolateMain] and waits for the result. 
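+
+  /// Example (sketch): replies from the isolate are either the result JSON
+  /// string on success or an `{'error': ..., 'stackTrace': ...}` map on
+  /// failure; the listener below turns the latter back into a thrown
+  /// exception. A minimal mirror of that decoding step:
+  /// ```dart
+  /// void handleReply(Object? reply) {
+  ///   if (reply is Map && reply.containsKey('error')) {
+  ///     throw Exception(reply['error']); // the isolate also sends the stack trace
+  ///   }
+  ///   // otherwise `reply` is a FaceMlResult JSON string
+  /// }
+  /// ```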
+  Future<dynamic> _runInIsolate(
+    (FaceMlOperation, Map<String, dynamic>) message,
+  ) async {
+    await ensureSpawnedIsolate();
+    return _functionLock.synchronized(() async {
+      _resetInactivityTimer();
+
+      if (isImageIndexRunning == false) {
+        return null;
+      }
+
+      final completer = Completer();
+      final answerPort = ReceivePort();
+
+      _activeTasks++;
+      _mainSendPort.send([message.$1.index, message.$2, answerPort.sendPort]);
+
+      answerPort.listen((receivedMessage) {
+        // Only count the task as done once the isolate has answered, so the inactivity timer cannot dispose the isolate while a task is still in flight.
+        _activeTasks--;
+        if (receivedMessage is Map && receivedMessage.containsKey('error')) {
+          // Handle the error
+          final errorMessage = receivedMessage['error'];
+          final errorStackTrace = receivedMessage['stackTrace'];
+          final exception = Exception(errorMessage);
+          final stackTrace = StackTrace.fromString(errorStackTrace);
+          completer.completeError(exception, stackTrace);
+        } else {
+          completer.complete(receivedMessage);
+        }
+      });
+
+      return completer.future;
+    });
+  }
+
+  /// Resets a timer that kills the isolate after a certain amount of inactivity.
+  ///
+  /// Should be called after initialization (e.g. inside `init()`) and after every call to the isolate (e.g. inside `_runInIsolate()`)
+  void _resetInactivityTimer() {
+    _inactivityTimer?.cancel();
+    _inactivityTimer = Timer(_inactivityDuration, () {
+      if (_activeTasks > 0) {
+        _logger.info('Tasks are still running. Delaying isolate disposal.');
+        // Optionally, reschedule the timer to check again later.
+        _resetInactivityTimer();
+      } else {
+        _logger.info(
+          'Clustering Isolate has been inactive for ${_inactivityDuration.inSeconds} seconds with no tasks running. Killing isolate.',
+        );
+        disposeIsolate();
+      }
+    });
+  }
+
+  void disposeIsolate() async {
+    if (!isIsolateSpawned) return;
+    await release();
+
+    isIsolateSpawned = false;
+    _isolate.kill();
+    _receivePort.close();
+    _inactivityTimer?.cancel();
+  }
+
+  Future<void> indexAndClusterAllImages() async {
+    // Run the analysis on all images to make sure everything is analyzed
+    await indexAllImages();
+
+    // Cluster all the images
+    await clusterAllImages();
+  }
+
+  Future<void> clusterAllImages({double minFaceScore = 0.75}) async {
+    _logger.info("`clusterAllImages()` called");
+
+    try {
+      // Read all the embeddings from the database, in a map from faceID to embedding
+      final clusterStartTime = DateTime.now();
+      final faceIdToEmbedding = await FaceMLDataDB.instance.getFaceEmbeddingMap(
+        minScore: minFaceScore,
+      );
+      _logger.info('read ${faceIdToEmbedding.length} embeddings');
+
+      // Cluster the embeddings using the linear clustering algorithm, returning a map from faceID to clusterID
+      final faceIdToCluster =
+          await FaceLinearClustering.instance.predict(faceIdToEmbedding);
+      if (faceIdToCluster == null) {
+        _logger.warning("faceIdToCluster is null");
+        return;
+      }
+      final clusterDoneTime = DateTime.now();
+      _logger.info(
+        'done with clustering ${faceIdToEmbedding.length} in ${clusterDoneTime.difference(clusterStartTime).inSeconds} seconds',
+      );
+
+      // Store the updated clusterIDs in the database
+      _logger.info(
+        'Updating ${faceIdToCluster.length} FaceIDs with clusterIDs in the DB',
+      );
+      await FaceMLDataDB.instance
+          .updatePersonIDForFaceIDIFNotSet(faceIdToCluster);
+      _logger.info('Done updating FaceIDs with clusterIDs in the DB, in '
+          '${DateTime.now().difference(clusterDoneTime).inSeconds} seconds');
+    } catch (e, s) {
+      _logger.severe("`clusterAllImages` failed", e, s);
+    }
+  }
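+
+  /// Example (sketch): the clustering flow above, condensed to its three
+  /// steps; the names follow the calls in [clusterAllImages]:
+  /// ```dart
+  /// Future<void> clusterOnce() async {
+  ///   // 1. read faceID -> embedding, filtered by minimum detection score
+  ///   final embeddings =
+  ///       await FaceMLDataDB.instance.getFaceEmbeddingMap(minScore: 0.75);
+  ///   // 2. faceID -> clusterID via the linear clustering isolate
+  ///   final clusters = await FaceLinearClustering.instance.predict(embeddings);
+  ///   if (clusters == null) return; // clustering did not produce a result
+  ///   // 3. persist, only setting clusterIDs that are not yet set
+  ///   await FaceMLDataDB.instance.updatePersonIDForFaceIDIFNotSet(clusters);
+  /// }
+  /// ```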
+  /// Analyzes all the images in the database with the latest ml version and stores the results in the database.
+  ///
+  /// This function first checks if the image has already been analyzed with the latest faceMlVersion and stored in the database. If so, it skips the image.
+  Future<void> indexAllImages() async {
+    if (isImageIndexRunning) {
+      _logger.warning("indexAllImages is already running, skipping");
+      return;
+    }
+    // verify indexing is enabled
+    if (LocalSettings.instance.isFaceIndexingEnabled == false) {
+      _logger.warning("indexAllImages is disabled");
+      return;
+    }
+    try {
+      isImageIndexRunning = true;
+      _logger.info('starting image indexing');
+
+      final List<EnteFile> enteFiles =
+          await SearchService.instance.getAllFiles();
+      final Set<int> alreadyIndexedFiles =
+          await FaceMLDataDB.instance.getIndexedFileIds();
+
+      // Make sure the image conversion isolate is spawned
+      // await ImageMlIsolate.instance.ensureSpawned();
+      await ensureInitialized();
+
+      int fileAnalyzedCount = 0;
+      int fileSkippedCount = 0;
+      final stopwatch = Stopwatch()..start();
+      final split = enteFiles.splitMatch((e) => (e.localID ?? '') == '');
+      // list of files where files with localID are first
+      final sortedByLocalID = <EnteFile>[];
+      sortedByLocalID.addAll(split.unmatched);
+      sortedByLocalID.addAll(split.matched);
+      final List<List<EnteFile>> chunks = sortedByLocalID.chunks(kParallelism);
+      outerLoop:
+      for (final chunk in chunks) {
+        final futures = <Future<void>>[];
+        for (final enteFile in chunk) {
+          if (isImageIndexRunning == false) {
+            _logger.info("indexAllImages() was paused, stopping");
+            break outerLoop;
+          }
+          if (_skipAnalysisEnteFile(
+            enteFile,
+            alreadyIndexedFiles,
+          )) {
+            fileSkippedCount++;
+            continue;
+          }
+          futures.add(processImage(enteFile, alreadyIndexedFiles));
+        }
+        await Future.wait(futures);
+        fileAnalyzedCount += futures.length;
+      }
+
+      stopwatch.stop();
+      _logger.info(
+        "`indexAllImages()` finished. Analyzed $fileAnalyzedCount images, in ${stopwatch.elapsed.inSeconds} seconds (avg of ${stopwatch.elapsed.inSeconds / fileAnalyzedCount} seconds per image, skipped $fileSkippedCount images)",
+      );
+
+      // Dispose of all the isolates
+      // ImageMlIsolate.instance.dispose();
+      // await release();
+    } catch (e, s) {
+      _logger.severe("indexAllImages failed", e, s);
+    } finally {
+      isImageIndexRunning = false;
+    }
+  }
+
+  Future<void> processImage(
+    EnteFile enteFile,
+    Set<int> alreadyIndexedFiles,
+  ) async {
+    _logger.info(
+      "`processImage()` started processing image with uploadedFileID: ${enteFile.uploadedFileID}",
+    );
+
+    try {
+      final FaceMlResult? result = await analyzeImageInSingleIsolate(
+        enteFile,
+        // preferUsingThumbnailForEverything: false,
+        // disposeImageIsolateAfterUse: false,
+      );
+      if (result == null) {
+        _logger.warning(
+          "Image not analyzed with uploadedFileID: ${enteFile.uploadedFileID}",
+        );
+        return;
+      }
+      final List<Face> faces = [];
+      if (!result.hasFaces) {
+        faces.add(
+          Face(
+            '${result.fileId}-0',
+            result.fileId,
+            [],
+            result.errorOccured ? -1.0 : 0.0,
+            face_detection.Detection.empty(),
+            0.0,
+          ),
+        );
+      } else {
+        if (result.faceDetectionImageSize == null ||
+            result.faceAlignmentImageSize == null) {
+          _logger.severe(
+              "faceDetectionImageSize or faceAlignmentImageSize is null for image with "
+              "ID: ${enteFile.uploadedFileID}");
+        }
+        final bool useAlign = result.faceAlignmentImageSize != null &&
+            result.faceAlignmentImageSize!.width > 0 &&
+            result.faceAlignmentImageSize!.height > 0 &&
+            result.onlyThumbnailUsed == false;
+        if (useAlign) {
+          _logger.info(
+            "Using aligned image size for image with ID: ${enteFile.uploadedFileID}. This size is ${result.faceAlignmentImageSize!.width}x${result.faceAlignmentImageSize!.height} compared to size of ${enteFile.width}x${enteFile.height} in the metadata",
+          );
+        }
+        for (int i = 0; i < result.faces.length; ++i) {
+          final FaceResult faceRes = result.faces[i];
+          final FaceDetectionAbsolute absoluteDetection =
+              faceRes.detection.toAbsolute(
+            imageWidth: useAlign
+                ? result.faceAlignmentImageSize!.width.toInt()
+                : enteFile.width,
+            imageHeight: useAlign
+                ? result.faceAlignmentImageSize!.height.toInt()
+                : enteFile.height,
+          );
+          final detection = face_detection.Detection(
+            box: FaceBox(
+              x: absoluteDetection.xMinBox,
+              y: absoluteDetection.yMinBox,
+              width: absoluteDetection.width,
+              height: absoluteDetection.height,
+            ),
+            landmarks: absoluteDetection.allKeypoints
+                .map(
+                  (keypoint) => Landmark(
+                    x: keypoint[0],
+                    y: keypoint[1],
+                  ),
+                )
+                .toList(),
+          );
+          faces.add(
+            Face(
+              faceRes.faceId,
+              result.fileId,
+              faceRes.embedding,
+              faceRes.detection.score,
+              detection,
+              faceRes.blurValue,
+            ),
+          );
+        }
+      }
+      _logger.info("inserting ${faces.length} faces for ${result.fileId}");
+      await FaceMLDataDB.instance.bulkInsertFaces(faces);
+    } catch (e, s) {
+      _logger.severe(
+        "Failed to analyze using FaceML for image with ID: ${enteFile.uploadedFileID}",
+        e,
+        s,
+      );
+    }
+  }
+
+  void pauseIndexing() {
+    isImageIndexRunning = false;
+  }
+
+  /// Analyzes the given image data by running the full pipeline using [analyzeImageInComputerAndImageIsolate] and stores the result in the database [MlDataDB].
+  /// This function first checks if the image has already been analyzed (with latest ml version) and stored in the database. If so, it returns the stored result.
+  ///
+  /// 'enteFile': The ente file to analyze.
+  ///
+  /// Returns an immutable [FaceMlResult] instance containing the results of the analysis. The result is also stored in the database.
+  Future<FaceMlResult> indexImage(EnteFile enteFile) async {
+    _logger.info(
+      "`indexImage` called on image with uploadedFileID ${enteFile.uploadedFileID}",
+    );
+    _checkEnteFileForID(enteFile);
+
+    // Check if the image has already been analyzed and stored in the database with the latest ml version
+    final existingResult = await _checkForExistingUpToDateResult(enteFile);
+    if (existingResult != null) {
+      return existingResult;
+    }
+
+    // If the image has not been analyzed and stored in the database, analyze it and store the result in the database
+    _logger.info(
+      "Image with uploadedFileID ${enteFile.uploadedFileID} has not been analyzed and stored in the database. Analyzing it now.",
+    );
+    FaceMlResult result;
+    try {
+      result = await analyzeImageInComputerAndImageIsolate(enteFile);
+    } catch (e, s) {
+      _logger.severe(
+        "`indexImage` failed on image with uploadedFileID ${enteFile.uploadedFileID}",
+        e,
+        s,
+      );
+      throw GeneralFaceMlException(
+        "`indexImage` failed on image with uploadedFileID ${enteFile.uploadedFileID}",
+      );
+    }
+
+    // Store the result in the database
+    await MlDataDB.instance.createFaceMlResult(result);
+
+    return result;
+  }
+
+  /// Analyzes the given image data by running the full pipeline (face detection, face alignment, face embedding).
+  ///
+  /// [enteFile] The ente file to analyze.
+  ///
+  /// [preferUsingThumbnailForEverything] If true, the thumbnail will be used for everything (face detection, face alignment, face embedding), and file data will be used only if a thumbnail is unavailable.
+ /// If false, thumbnail will only be used for detection, and the original image will be used for face alignment and face embedding. + /// + /// Returns an immutable [FaceMlResult] instance containing the results of the analysis. + /// Does not store the result in the database, for that you should use [indexImage]. + /// Throws [CouldNotRetrieveAnyFileData] or [GeneralFaceMlException] if something goes wrong. + /// TODO: improve function such that it only uses full image if it is already on the device, otherwise it uses thumbnail. And make sure to store what is used! + Future analyzeImageInComputerAndImageIsolate( + EnteFile enteFile, { + bool preferUsingThumbnailForEverything = false, + bool disposeImageIsolateAfterUse = true, + }) async { + _checkEnteFileForID(enteFile); + + final String? thumbnailPath = await _getImagePathForML( + enteFile, + typeOfData: FileDataForML.thumbnailData, + ); + String? filePath; + + // // TODO: remove/optimize this later. Not now though: premature optimization + // fileData = + // await _getDataForML(enteFile, typeOfData: FileDataForML.fileData); + + if (thumbnailPath == null) { + filePath = await _getImagePathForML( + enteFile, + typeOfData: FileDataForML.fileData, + ); + if (thumbnailPath == null && filePath == null) { + _logger.severe( + "Failed to get any data for enteFile with uploadedFileID ${enteFile.uploadedFileID}", + ); + throw CouldNotRetrieveAnyFileData(); + } + } + // TODO: use smallData and largeData instead of thumbnailData and fileData again! + final String smallDataPath = thumbnailPath ?? filePath!; + + final resultBuilder = FaceMlResultBuilder.fromEnteFile(enteFile); + + _logger.info( + "Analyzing image with uploadedFileID: ${enteFile.uploadedFileID} ${kDebugMode ? enteFile.displayName : ''}", + ); + final stopwatch = Stopwatch()..start(); + + try { + // Get the faces + final List faceDetectionResult = + await _detectFacesIsolate( + smallDataPath, + resultBuilder: resultBuilder, + ); + + _logger.info("Completed `detectFaces` function"); + + // If no faces were detected, return a result with no faces. Otherwise, continue. + if (faceDetectionResult.isEmpty) { + _logger.info( + "No faceDetectionResult, Completed analyzing image with uploadedFileID ${enteFile.uploadedFileID}, in " + "${stopwatch.elapsedMilliseconds} ms"); + return resultBuilder.buildNoFaceDetected(); + } + + if (!preferUsingThumbnailForEverything) { + filePath ??= await _getImagePathForML( + enteFile, + typeOfData: FileDataForML.fileData, + ); + } + resultBuilder.onlyThumbnailUsed = filePath == null; + final String largeDataPath = filePath ?? 
thumbnailPath!; + + // Align the faces + final Float32List faceAlignmentResult = await _alignFaces( + largeDataPath, + faceDetectionResult, + resultBuilder: resultBuilder, + ); + + _logger.info("Completed `alignFaces` function"); + + // Get the embeddings of the faces + final embeddings = await _embedFaces( + faceAlignmentResult, + resultBuilder: resultBuilder, + ); + + _logger.info("Completed `embedBatchFaces` function"); + + stopwatch.stop(); + _logger.info("Finished Analyze image (${embeddings.length} faces) with " + "uploadedFileID ${enteFile.uploadedFileID}, in " + "${stopwatch.elapsedMilliseconds} ms"); + + if (disposeImageIsolateAfterUse) { + // Close the image conversion isolate + ImageMlIsolate.instance.dispose(); + } + + return resultBuilder.build(); + } catch (e, s) { + _logger.severe( + "Could not analyze image with ID ${enteFile.uploadedFileID} \n", + e, + s, + ); + // throw GeneralFaceMlException("Could not analyze image"); + return resultBuilder.buildErrorOccurred(); + } + } + + Future analyzeImageInSingleIsolate(EnteFile enteFile) async { + _checkEnteFileForID(enteFile); + await ensureInitialized(); + + final String? thumbnailPath = await _getImagePathForML( + enteFile, + typeOfData: FileDataForML.thumbnailData, + ); + final String? filePath = + await _getImagePathForML(enteFile, typeOfData: FileDataForML.fileData); + + if (thumbnailPath == null && filePath == null) { + _logger.severe( + "Failed to get any data for enteFile with uploadedFileID ${enteFile.uploadedFileID}", + ); + throw CouldNotRetrieveAnyFileData(); + } + + final String smallDataPath = thumbnailPath ?? filePath!; + final String largeDataPath = filePath ?? thumbnailPath!; + + final Stopwatch stopwatch = Stopwatch()..start(); + late FaceMlResult result; + + try { + final resultJsonString = await _runInIsolate( + ( + FaceMlOperation.analyzeImage, + { + "enteFileID": enteFile.uploadedFileID ?? -1, + "smallDataPath": smallDataPath, + "largeDataPath": largeDataPath, + "faceDetectionAddress": + YoloOnnxFaceDetection.instance.sessionAddress, + "faceEmbeddingAddress": FaceEmbeddingOnnx.instance.sessionAddress, + } + ), + ) as String?; + if (resultJsonString == null) { + return null; + } + result = FaceMlResult.fromJsonString(resultJsonString); + } catch (e, s) { + _logger.severe( + "Could not analyze image with ID ${enteFile.uploadedFileID} \n", + e, + s, + ); + final resultBuilder = FaceMlResultBuilder.fromEnteFile(enteFile); + return resultBuilder.buildErrorOccurred(); + } + stopwatch.stop(); + _logger.info( + "Finished Analyze image (${result.faces.length} faces) with uploadedFileID ${enteFile.uploadedFileID}, in " + "${stopwatch.elapsedMilliseconds} ms", + ); + + return result; + } + + Future _getImagePathForML( + EnteFile enteFile, { + FileDataForML typeOfData = FileDataForML.fileData, + }) async { + String? imagePath; + + switch (typeOfData) { + case FileDataForML.fileData: + final stopwatch = Stopwatch()..start(); + final File? file = await getFile(enteFile, isOrigin: true); + if (file == null) { + _logger.warning("Could not get file for $enteFile"); + imagePath = null; + break; + } + imagePath = file.path; + stopwatch.stop(); + _logger.info( + "Getting file data for uploadedFileID ${enteFile.uploadedFileID} took ${stopwatch.elapsedMilliseconds} ms", + ); + break; + + case FileDataForML.thumbnailData: + final stopwatch = Stopwatch()..start(); + final File? 
thumbnail = await getThumbnailForUploadedFile(enteFile); + if (thumbnail == null) { + _logger.warning("Could not get thumbnail for $enteFile"); + imagePath = null; + break; + } + imagePath = thumbnail.path; + stopwatch.stop(); + _logger.info( + "Getting thumbnail data for uploadedFileID ${enteFile.uploadedFileID} took ${stopwatch.elapsedMilliseconds} ms", + ); + break; + + case FileDataForML.compressedFileData: + _logger.warning( + "Getting compressed file data for uploadedFileID ${enteFile.uploadedFileID} is not implemented yet", + ); + imagePath = null; + break; + } + + return imagePath; + } + + Future _getDataForML( + EnteFile enteFile, { + FileDataForML typeOfData = FileDataForML.fileData, + }) async { + Uint8List? data; + + switch (typeOfData) { + case FileDataForML.fileData: + final stopwatch = Stopwatch()..start(); + final File? actualIoFile = await getFile(enteFile, isOrigin: true); + if (actualIoFile != null) { + data = await actualIoFile.readAsBytes(); + } + stopwatch.stop(); + _logger.info( + "Getting file data for uploadedFileID ${enteFile.uploadedFileID} took ${stopwatch.elapsedMilliseconds} ms", + ); + + break; + + case FileDataForML.thumbnailData: + final stopwatch = Stopwatch()..start(); + data = await getThumbnail(enteFile); + stopwatch.stop(); + _logger.info( + "Getting thumbnail data for uploadedFileID ${enteFile.uploadedFileID} took ${stopwatch.elapsedMilliseconds} ms", + ); + break; + + case FileDataForML.compressedFileData: + final stopwatch = Stopwatch()..start(); + final String tempPath = Configuration.instance.getTempDirectory() + + "${enteFile.uploadedFileID!}"; + final File? actualIoFile = await getFile(enteFile); + if (actualIoFile != null) { + final compressResult = await FlutterImageCompress.compressAndGetFile( + actualIoFile.path, + tempPath + ".jpg", + ); + if (compressResult != null) { + data = await compressResult.readAsBytes(); + } + } + stopwatch.stop(); + _logger.info( + "Getting compressed file data for uploadedFileID ${enteFile.uploadedFileID} took ${stopwatch.elapsedMilliseconds} ms", + ); + break; + } + + return data; + } + + /// Detects faces in the given image data. + /// + /// `imageData`: The image data to analyze. + /// + /// Returns a list of face detection results. + /// + /// Throws [CouldNotInitializeFaceDetector], [CouldNotRunFaceDetector] or [GeneralFaceMlException] if something goes wrong. + Future> _detectFacesIsolate( + String imagePath, + // Uint8List fileData, + { + FaceMlResultBuilder? resultBuilder, + }) async { + try { + // Get the bounding boxes of the faces + final (List faces, dataSize) = + await YoloOnnxFaceDetection.instance.predictInComputer(imagePath); + + // Add detected faces to the resultBuilder + if (resultBuilder != null) { + resultBuilder.addNewlyDetectedFaces(faces, dataSize); + } + + return faces; + } on YOLOInterpreterInitializationException { + throw CouldNotInitializeFaceDetector(); + } on YOLOInterpreterRunException { + throw CouldNotRunFaceDetector(); + } catch (e) { + _logger.severe('Face detection failed: $e'); + throw GeneralFaceMlException('Face detection failed: $e'); + } + } + + /// Detects faces in the given image data. + /// + /// `imageData`: The image data to analyze. + /// + /// Returns a list of face detection results. + /// + /// Throws [CouldNotInitializeFaceDetector], [CouldNotRunFaceDetector] or [GeneralFaceMlException] if something goes wrong. + static Future> detectFacesSync( + String imagePath, + int interpreterAddress, { + FaceMlResultBuilder? 
resultBuilder, + }) async { + try { + // Get the bounding boxes of the faces + final (List faces, dataSize) = + await YoloOnnxFaceDetection.predictSync( + imagePath, + interpreterAddress, + ); + + // Add detected faces to the resultBuilder + if (resultBuilder != null) { + resultBuilder.addNewlyDetectedFaces(faces, dataSize); + } + + return faces; + } on YOLOInterpreterInitializationException { + throw CouldNotInitializeFaceDetector(); + } on YOLOInterpreterRunException { + throw CouldNotRunFaceDetector(); + } catch (e) { + dev.log('[SEVERE] Face detection failed: $e'); + throw GeneralFaceMlException('Face detection failed: $e'); + } + } + + /// Aligns multiple faces from the given image data. + /// + /// `imageData`: The image data in [Uint8List] that contains the faces. + /// `faces`: The face detection results in a list of [FaceDetectionAbsolute] for the faces to align. + /// + /// Returns a list of the aligned faces as image data. + /// + /// Throws [CouldNotWarpAffine] or [GeneralFaceMlException] if the face alignment fails. + Future _alignFaces( + String imagePath, + List faces, { + FaceMlResultBuilder? resultBuilder, + }) async { + try { + final ( + alignedFaces, + alignmentResults, + isBlurs, + blurValues, + originalImageSize + ) = await ImageMlIsolate.instance + .preprocessMobileFaceNetOnnx(imagePath, faces); + + if (resultBuilder != null) { + resultBuilder.addAlignmentResults( + alignmentResults, + blurValues, + originalImageSize, + ); + } + + return alignedFaces; + } catch (e, s) { + _logger.severe('Face alignment failed: $e', e, s); + throw CouldNotWarpAffine(); + } + } + + /// Aligns multiple faces from the given image data. + /// + /// `imageData`: The image data in [Uint8List] that contains the faces. + /// `faces`: The face detection results in a list of [FaceDetectionAbsolute] for the faces to align. + /// + /// Returns a list of the aligned faces as image data. + /// + /// Throws [CouldNotWarpAffine] or [GeneralFaceMlException] if the face alignment fails. + static Future alignFacesSync( + String imagePath, + List faces, { + FaceMlResultBuilder? resultBuilder, + }) async { + try { + final stopwatch = Stopwatch()..start(); + final ( + alignedFaces, + alignmentResults, + isBlurs, + blurValues, + originalImageSize + ) = await preprocessToMobileFaceNetFloat32List(imagePath, faces); + stopwatch.stop(); + dev.log( + "Face alignment image decoding and processing took ${stopwatch.elapsedMilliseconds} ms", + ); + + if (resultBuilder != null) { + resultBuilder.addAlignmentResults( + alignmentResults, + blurValues, + originalImageSize, + ); + } + + return alignedFaces; + } catch (e, s) { + dev.log('[SEVERE] Face alignment failed: $e $s'); + throw CouldNotWarpAffine(); + } + } + + /// Embeds multiple faces from the given input matrices. + /// + /// `facesMatrices`: The input matrices of the faces to embed. + /// + /// Returns a list of the face embeddings as lists of doubles. + /// + /// Throws [CouldNotInitializeFaceEmbeddor], [CouldNotRunFaceEmbeddor], [InputProblemFaceEmbeddor] or [GeneralFaceMlException] if the face embedding fails. + Future>> _embedFaces( + Float32List facesList, { + FaceMlResultBuilder? 
resultBuilder, + }) async { + try { + // Get the embedding of the faces + final List> embeddings = + await FaceEmbeddingOnnx.instance.predictInComputer(facesList); + + // Add the embeddings to the resultBuilder + if (resultBuilder != null) { + resultBuilder.addEmbeddingsToExistingFaces(embeddings); + } + + return embeddings; + } on MobileFaceNetInterpreterInitializationException { + throw CouldNotInitializeFaceEmbeddor(); + } on MobileFaceNetInterpreterRunException { + throw CouldNotRunFaceEmbeddor(); + } on MobileFaceNetEmptyInput { + throw InputProblemFaceEmbeddor("Input is empty"); + } on MobileFaceNetWrongInputSize { + throw InputProblemFaceEmbeddor("Input size is wrong"); + } on MobileFaceNetWrongInputRange { + throw InputProblemFaceEmbeddor("Input range is wrong"); + // ignore: avoid_catches_without_on_clauses + } catch (e) { + _logger.severe('Face embedding (batch) failed: $e'); + throw GeneralFaceMlException('Face embedding (batch) failed: $e'); + } + } + + static Future>> embedFacesSync( + Float32List facesList, + int interpreterAddress, { + FaceMlResultBuilder? resultBuilder, + }) async { + try { + // Get the embedding of the faces + final List> embeddings = + await FaceEmbeddingOnnx.predictSync(facesList, interpreterAddress); + + // Add the embeddings to the resultBuilder + if (resultBuilder != null) { + resultBuilder.addEmbeddingsToExistingFaces(embeddings); + } + + return embeddings; + } on MobileFaceNetInterpreterInitializationException { + throw CouldNotInitializeFaceEmbeddor(); + } on MobileFaceNetInterpreterRunException { + throw CouldNotRunFaceEmbeddor(); + } on MobileFaceNetEmptyInput { + throw InputProblemFaceEmbeddor("Input is empty"); + } on MobileFaceNetWrongInputSize { + throw InputProblemFaceEmbeddor("Input size is wrong"); + } on MobileFaceNetWrongInputRange { + throw InputProblemFaceEmbeddor("Input range is wrong"); + // ignore: avoid_catches_without_on_clauses + } catch (e) { + dev.log('[SEVERE] Face embedding (batch) failed: $e'); + throw GeneralFaceMlException('Face embedding (batch) failed: $e'); + } + } + + /// Checks if the ente file to be analyzed actually can be analyzed: it must be uploaded and in the correct format. 
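+
+  /// Example (sketch): the skip rules applied before analysis, mirroring
+  /// [_skipAnalysisEnteFile] below: files that are not uploaded or not owned,
+  /// videos, files of type `other` (including live/motion photos for now),
+  /// and already-indexed files are all skipped:
+  /// ```dart
+  /// bool shouldSkip(EnteFile f, Set<int> indexed) =>
+  ///     !f.isUploaded ||
+  ///     f.isOwner == false ||
+  ///     f.fileType == FileType.video ||
+  ///     f.fileType == FileType.other ||
+  ///     indexed.contains(f.uploadedFileID!);
+  /// ```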
+ void _checkEnteFileForID(EnteFile enteFile) { + if (_skipAnalysisEnteFile(enteFile, {})) { + _logger.severe( + "Skipped analysis of image with enteFile ${enteFile.toString()} because it is the wrong format or has no uploadedFileID", + ); + throw CouldNotRetrieveAnyFileData(); + } + } + + bool _skipAnalysisEnteFile(EnteFile enteFile, Set indexedFileIds) { + // Skip if the file is not uploaded or not owned by the user + if (!enteFile.isUploaded || enteFile.isOwner == false) { + return true; + } + + // Skip if the file is a video + if (enteFile.fileType == FileType.video) { + return true; + } + // I don't know how motionPhotos and livePhotos work, so I'm also just skipping them for now + if (enteFile.fileType == FileType.other) { + return true; + } + // Skip if the file is already analyzed with the latest ml version + final id = enteFile.uploadedFileID!; + return indexedFileIds.contains(id); + } + + Future _checkForExistingUpToDateResult( + EnteFile enteFile, + ) async { + // Check if the image has already been analyzed and stored in the database + final existingResult = + await MlDataDB.instance.getFaceMlResult(enteFile.uploadedFileID!); + + // If the image has already been analyzed and stored in the database, return the stored result + if (existingResult != null) { + if (existingResult.mlVersion >= faceMlVersion) { + _logger.info( + "Image with uploadedFileID ${enteFile.uploadedFileID} has already been analyzed and stored in the database with the latest ml version. Returning the stored result.", + ); + return existingResult; + } + } + return null; + } +} diff --git a/mobile/lib/services/face_ml/face_ml_version.dart b/mobile/lib/services/face_ml/face_ml_version.dart new file mode 100644 index 000000000..a91c4c843 --- /dev/null +++ b/mobile/lib/services/face_ml/face_ml_version.dart @@ -0,0 +1,15 @@ +abstract class VersionedMethod { + final String method; + final int version; + + VersionedMethod(this.method, [this.version = 0]); + + const VersionedMethod.empty() + : method = 'Empty method', + version = 0; + + Map toJson() => { + 'method': method, + 'version': version, + }; +} diff --git a/mobile/lib/services/face_ml/face_search_service.dart b/mobile/lib/services/face_ml/face_search_service.dart new file mode 100644 index 000000000..c373f73e4 --- /dev/null +++ b/mobile/lib/services/face_ml/face_search_service.dart @@ -0,0 +1,120 @@ +import "dart:typed_data"; + +import "package:logging/logging.dart"; +import "package:photos/db/files_db.dart"; +import "package:photos/db/ml_data_db.dart"; +import "package:photos/models/file/file.dart"; +import 'package:photos/utils/image_ml_isolate.dart'; +import "package:photos/utils/thumbnail_util.dart"; + +class FaceSearchService { + final _logger = Logger("FaceSearchService"); + + final _mlDatabase = MlDataDB.instance; + final _filesDatabase = FilesDB.instance; + + // singleton pattern + FaceSearchService._privateConstructor(); + static final instance = FaceSearchService._privateConstructor(); + factory FaceSearchService() => instance; + + /// Returns the personIDs of all clustered people in the database. + Future> getAllPeople() async { + final peopleIds = await _mlDatabase.getAllClusterIds(); + return peopleIds; + } + + /// Returns the thumbnail associated with a given personId. 
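+
+  /// Example (sketch): typical read-side usage of this service from UI code;
+  /// a hypothetical caller, assuming person IDs are the integer cluster IDs
+  /// used throughout this file:
+  /// ```dart
+  /// Future<void> showFirstPerson() async {
+  ///   final people = await FaceSearchService.instance.getAllPeople();
+  ///   if (people.isEmpty) return;
+  ///   final thumbnail =
+  ///       await FaceSearchService.instance.getPersonThumbnail(people.first);
+  ///   final files =
+  ///       await FaceSearchService.instance.getFilesForPerson(people.first);
+  ///   print('person ${people.first}: ${files.length} files, '
+  ///       'thumbnail available: ${thumbnail != null}');
+  /// }
+  /// ```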
+  Future<Uint8List?> getPersonThumbnail(int personID) async {
+    // get the cluster associated with the personID
+    final cluster = await _mlDatabase.getClusterResult(personID);
+    if (cluster == null) {
+      _logger.warning(
+        "No cluster found for personID $personID, unable to get thumbnail.",
+      );
+      return null;
+    }
+
+    // get the faceID and fileID you want to use to generate the thumbnail
+    final String thumbnailFaceID = cluster.thumbnailFaceId;
+    final int thumbnailFileID = cluster.thumbnailFileId;
+
+    // get the full file thumbnail
+    final EnteFile enteFile = await _filesDatabase
+        .getFilesFromIDs([thumbnailFileID]).then((value) => value.values.first);
+    final Uint8List? fileThumbnail = await getThumbnail(enteFile);
+    if (fileThumbnail == null) {
+      _logger.warning(
+        "No full file thumbnail found for thumbnail faceID $thumbnailFaceID, unable to get thumbnail.",
+      );
+      return null;
+    }
+
+    // get the face detection for the thumbnail
+    final thumbnailMlResult =
+        await _mlDatabase.getFaceMlResult(thumbnailFileID);
+    if (thumbnailMlResult == null) {
+      _logger.warning(
+        "No face ml result found for thumbnail faceID $thumbnailFaceID, unable to get thumbnail.",
+      );
+      return null;
+    }
+    final detection = thumbnailMlResult.getDetectionForFaceId(thumbnailFaceID);
+
+    // create the thumbnail from the full file thumbnail and the face detection
+    Uint8List faceThumbnail;
+    try {
+      faceThumbnail = await ImageMlIsolate.instance.generateFaceThumbnail(
+        fileThumbnail,
+        detection,
+      );
+    } catch (e, s) {
+      _logger.warning(
+        "Unable to generate face thumbnail for thumbnail faceID $thumbnailFaceID, unable to get thumbnail.",
+        e,
+        s,
+      );
+      return null;
+    }
+
+    return faceThumbnail;
+  }
+
+  /// Returns all files associated with a given personId.
+  Future<List<EnteFile>> getFilesForPerson(int personID) async {
+    final fileIDs = await _mlDatabase.getClusterFileIds(personID);
+
+    final Map<int, EnteFile> files =
+        await _filesDatabase.getFilesFromIDs(fileIDs);
+    return files.values.toList();
+  }
+
+  Future<List<EnteFile>> getFilesForIntersectOfPeople(
+    List<int> personIDs,
+  ) async {
+    if (personIDs.length <= 1) {
+      _logger
+          .warning('Cannot get intersection of files for less than 2 people');
+      return [];
+    }
+
+    final Set<int> fileIDsFirstCluster = await _mlDatabase
+        .getClusterFileIds(personIDs.first)
+        .then((value) => value.toSet());
+    for (final personID in personIDs.sublist(1)) {
+      final fileIDsSingleCluster =
+          await _mlDatabase.getClusterFileIds(personID);
+      fileIDsFirstCluster.retainAll(fileIDsSingleCluster);
+
+      // Early termination if intersection is empty
+      if (fileIDsFirstCluster.isEmpty) {
+        return [];
+      }
+    }
+
+    final Map<int, EnteFile> files =
+        await _filesDatabase.getFilesFromIDs(fileIDsFirstCluster.toList());
+
+    return files.values.toList();
+  }
+}
diff --git a/mobile/lib/services/face_ml/feedback/cluster_feedback.dart b/mobile/lib/services/face_ml/feedback/cluster_feedback.dart
new file mode 100644
index 000000000..129eee83f
--- /dev/null
+++ b/mobile/lib/services/face_ml/feedback/cluster_feedback.dart
@@ -0,0 +1,464 @@
+import 'dart:developer' as dev;
+import "dart:math" show Random;
+import "dart:typed_data";
+
+import "package:flutter/foundation.dart";
+import "package:logging/logging.dart";
+import "package:photos/core/event_bus.dart";
+import "package:photos/events/people_changed_event.dart";
+import "package:photos/extensions/stop_watch.dart";
+import "package:photos/face/db.dart";
+import "package:photos/face/model/person.dart";
+import "package:photos/generated/protos/ente/common/vector.pb.dart";
+import "package:photos/models/file/file.dart";
+import "package:photos/services/face_ml/face_clustering/cosine_distance.dart";
+import "package:photos/services/search_service.dart";
+
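+/// Handles user feedback on face clusters: it computes cluster-merge
+/// suggestions for a named [Person] (using mean and median embedding
+/// distances) and applies assignments, removals and automatic merges.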
"package:photos/services/face_ml/face_clustering/cosine_distance.dart"; +import "package:photos/services/search_service.dart"; + +class ClusterFeedbackService { + final Logger _logger = Logger("ClusterFeedbackService"); + ClusterFeedbackService._privateConstructor(); + + static final ClusterFeedbackService instance = + ClusterFeedbackService._privateConstructor(); + + /// Returns a map of person's clusterID to map of closest clusterID to with disstance + Future>> getSuggestionsUsingMean( + Person p, { + double maxClusterDistance = 0.4, + }) async { + // Get all the cluster data + final faceMlDb = FaceMLDataDB.instance; + + final allClusterIdsToCountMap = (await faceMlDb.clusterIdToFaceCount()); + final ignoredClusters = await faceMlDb.getPersonIgnoredClusters(p.remoteID); + final personClusters = await faceMlDb.getPersonClusterIDs(p.remoteID); + dev.log( + 'existing clusters for ${p.attr.name} are $personClusters', + name: "ClusterFeedbackService", + ); + + // Get and update the cluster summary to get the avg (centroid) and count + final EnteWatch watch = EnteWatch("ClusterFeedbackService")..start(); + final Map> clusterAvg = await _getUpdateClusterAvg( + allClusterIdsToCountMap, + ignoredClusters, + ); + watch.log('computed avg for ${clusterAvg.length} clusters'); + + // Find the actual closest clusters for the person + final Map> suggestions = _calcSuggestionsMean( + clusterAvg, + personClusters, + ignoredClusters, + maxClusterDistance, + ); + + // log suggestions + for (final entry in suggestions.entries) { + dev.log( + ' ${entry.value.length} suggestion for ${p.attr.name} for cluster ID ${entry.key} are suggestions ${entry.value}}', + name: "ClusterFeedbackService", + ); + } + return suggestions; + } + + Future> getSuggestionsUsingMedian( + Person p, { + int sampleSize = 50, + double maxMedianDistance = 0.65, + double goodMedianDistance = 0.55, + double maxMeanDistance = 0.65, + double goodMeanDistance = 0.4, + }) async { + // Get all the cluster data + final faceMlDb = FaceMLDataDB.instance; + // final Map> suggestions = {}; + final allClusterIdsToCountMap = (await faceMlDb.clusterIdToFaceCount()); + final ignoredClusters = await faceMlDb.getPersonIgnoredClusters(p.remoteID); + final personClusters = await faceMlDb.getPersonClusterIDs(p.remoteID); + dev.log( + 'existing clusters for ${p.attr.name} are $personClusters', + name: "ClusterFeedbackService", + ); + + // Get and update the cluster summary to get the avg (centroid) and count + final EnteWatch watch = EnteWatch("ClusterFeedbackService")..start(); + final Map> clusterAvg = await _getUpdateClusterAvg( + allClusterIdsToCountMap, + ignoredClusters, + ); + watch.log('computed avg for ${clusterAvg.length} clusters'); + + // Find the other cluster candidates based on the mean + final Map> suggestionsMean = _calcSuggestionsMean( + clusterAvg, + personClusters, + ignoredClusters, + goodMeanDistance, + ); + if (suggestionsMean.isNotEmpty) { + final List<(int, double)> suggestClusterIds = []; + for (final List<(int, double)> suggestion in suggestionsMean.values) { + suggestClusterIds.addAll(suggestion); + } + suggestClusterIds.sort( + (a, b) => allClusterIdsToCountMap[b.$1]! + .compareTo(allClusterIdsToCountMap[a.$1]!), + ); + final suggestClusterIdsSizes = suggestClusterIds + .map((e) => allClusterIdsToCountMap[e.$1]!) 
+    // Find the other cluster candidates based on the median
+    final Map<int, List<(int, double)>> moreSuggestionsMean =
+        _calcSuggestionsMean(
+      clusterAvg,
+      personClusters,
+      ignoredClusters,
+      maxMeanDistance,
+    );
+    if (moreSuggestionsMean.isEmpty) {
+      _logger
+          .info("No suggestions found using mean, even with higher threshold");
+      return [];
+    }
+
+    final List<(int, double)> temp = [];
+    for (final List<(int, double)> suggestion in moreSuggestionsMean.values) {
+      temp.addAll(suggestion);
+    }
+    temp.sort((a, b) => a.$2.compareTo(b.$2));
+    final otherClusterIdsCandidates = temp
+        .map(
+          (e) => e.$1,
+        )
+        .toList(growable: false);
+    _logger.info(
+      "Found potential suggestions from loose mean for median test: $otherClusterIdsCandidates",
+    );
+
+    watch.logAndReset("Starting median test");
+    // Take the embeddings from the person's clusters in one big list and sample from it
+    final List<Uint8List> personEmbeddingsProto = [];
+    for (final clusterID in personClusters) {
+      final Iterable<Uint8List> embeddings =
+          await FaceMLDataDB.instance.getFaceEmbeddingsForCluster(clusterID);
+      personEmbeddingsProto.addAll(embeddings);
+    }
+    final List<Uint8List> sampledEmbeddingsProto =
+        _randomSampleWithoutReplacement(
+      personEmbeddingsProto,
+      sampleSize,
+    );
+    final List<List<double>> sampledEmbeddings = sampledEmbeddingsProto
+        .map((embedding) => EVector.fromBuffer(embedding).values)
+        .toList(growable: false);
+
+    // Find the actual closest clusters for the person using median
+    final List<(int, double)> suggestionsMedian = [];
+    final List<(int, double)> greatSuggestionsMedian = [];
+    double minMedianDistance = maxMedianDistance;
+    for (final otherClusterId in otherClusterIdsCandidates) {
+      final Iterable<Uint8List> otherEmbeddingsProto =
+          await FaceMLDataDB.instance.getFaceEmbeddingsForCluster(
+        otherClusterId,
+      );
+      final sampledOtherEmbeddingsProto = _randomSampleWithoutReplacement(
+        otherEmbeddingsProto,
+        sampleSize,
+      );
+      final List<List<double>> sampledOtherEmbeddings =
+          sampledOtherEmbeddingsProto
+              .map((embedding) => EVector.fromBuffer(embedding).values)
+              .toList(growable: false);
+
+      // Calculate distances and find the median
+      final List<double> distances = [];
+      for (final otherEmbedding in sampledOtherEmbeddings) {
+        for (final embedding in sampledEmbeddings) {
+          distances.add(cosineDistForNormVectors(embedding, otherEmbedding));
+        }
+      }
+      distances.sort();
+      final double medianDistance = distances[distances.length ~/ 2];
+      if (medianDistance < minMedianDistance) {
+        suggestionsMedian.add((otherClusterId, medianDistance));
+        minMedianDistance = medianDistance;
+        if (medianDistance < goodMedianDistance) {
+          greatSuggestionsMedian.add((otherClusterId, medianDistance));
+          break;
+        }
+      }
+    }
+    watch.log("Finished median test");
+    if (suggestionsMedian.isEmpty) {
+      _logger.info("No suggestions found using median");
+      return [];
+    } else {
+      _logger.info("Found suggestions using median: $suggestionsMedian");
+    }
+
+    final List<int> finalSuggestionsMedian = suggestionsMedian
+        .map((e) => e.$1)
+        .toList(growable: false)
+        .reversed
+        .toList(growable: false);
+
+    if (greatSuggestionsMedian.isNotEmpty) {
+      _logger.info(
+        "Found great suggestion using median: $greatSuggestionsMedian",
+      );
+      // // Return the largest size cluster by using allClusterIdsToCountMap
+      // final List<int> greatSuggestionsMedianClusterIds =
+      //     greatSuggestionsMedian.map((e) => e.$1).toList(growable: false);
+      // greatSuggestionsMedianClusterIds.sort(
+      //   (a, b) =>
+      //       allClusterIdsToCountMap[b]!.compareTo(allClusterIdsToCountMap[a]!),
+      // );
+
+      // return [greatSuggestionsMedian.last.$1, ...finalSuggestionsMedian];
+    }
+
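+    // suggestionsMedian was built with a strictly decreasing running minimum,
+    // so reversing it above puts the closest clusters first.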
+    return finalSuggestionsMedian;
+  }
+
+  Future<List<(int, List<EnteFile>)>> getClusterFilesForPersonID(
+    Person person,
+  ) async {
+    _logger.info(
+      'getClusterFilesForPersonID ${kDebugMode ? person.attr.name : person.remoteID}',
+    );
+
+    // Get the suggestions for the person using only centroids
+    // final Map<int, List<(int, double)>> suggestions =
+    //     await getSuggestionsUsingMean(person);
+    // final Set<int> suggestClusterIds = {};
+    // for (final List<(int, double)> suggestion in suggestions.values) {
+    //   for (final clusterNeighbors in suggestion) {
+    //     suggestClusterIds.add(clusterNeighbors.$1);
+    //   }
+    // }
+
+    try {
+      // Get the suggestions for the person using centroids and median
+      final List<int> suggestClusterIds =
+          await getSuggestionsUsingMedian(person);
+
+      // Get the files for the suggestions
+      final Map<int, Set<int>> fileIdToClusterID = await FaceMLDataDB.instance
+          .getFileIdToClusterIDSetForCluster(suggestClusterIds.toSet());
+      final Map<int, List<EnteFile>> clusterIDToFiles = {};
+      final allFiles = await SearchService.instance.getAllFiles();
+      for (final f in allFiles) {
+        if (!fileIdToClusterID.containsKey(f.uploadedFileID ?? -1)) {
+          continue;
+        }
+        final clusterIds = fileIdToClusterID[f.uploadedFileID ?? -1]!;
+        for (final cluster in clusterIds) {
+          if (clusterIDToFiles.containsKey(cluster)) {
+            clusterIDToFiles[cluster]!.add(f);
+          } else {
+            clusterIDToFiles[cluster] = [f];
+          }
+        }
+      }
+
+      final List<(int, List<EnteFile>)> clusterIdAndFiles = [];
+      for (final clusterId in suggestClusterIds) {
+        if (clusterIDToFiles.containsKey(clusterId)) {
+          clusterIdAndFiles.add(
+            (clusterId, clusterIDToFiles[clusterId]!),
+          );
+        }
+      }
+
+      return clusterIdAndFiles;
+    } catch (e, s) {
+      _logger.severe("Error in getClusterFilesForPersonID", e, s);
+      rethrow;
+    }
+  }
+
+  Future<void> removePersonFromFiles(List<EnteFile> files, Person p) {
+    return FaceMLDataDB.instance.removePersonFromFiles(files, p);
+  }
+
+  Future<bool> checkAndDoAutomaticMerges(Person p) async {
+    final faceMlDb = FaceMLDataDB.instance;
+    final allClusterIdsToCountMap = (await faceMlDb.clusterIdToFaceCount());
+    final ignoredClusters = await faceMlDb.getPersonIgnoredClusters(p.remoteID);
+    final personClusters = await faceMlDb.getPersonClusterIDs(p.remoteID);
+    dev.log(
+      'existing clusters for ${p.attr.name} are $personClusters',
+      name: "ClusterFeedbackService",
+    );
+
+    // Get and update the cluster summary to get the avg (centroid) and count
+    final EnteWatch watch = EnteWatch("ClusterFeedbackService")..start();
+    final Map<int, List<double>> clusterAvg = await _getUpdateClusterAvg(
+      allClusterIdsToCountMap,
+      ignoredClusters,
+    );
+    watch.log('computed avg for ${clusterAvg.length} clusters');
+
+    // Find the actual closest clusters for the person
+    final Map<int, List<(int, double)>> suggestions = _calcSuggestionsMean(
+      clusterAvg,
+      personClusters,
+      ignoredClusters,
+      0.3,
+    );
+
+    if (suggestions.isEmpty) {
+      dev.log(
+        'No automatic merge suggestions for ${p.attr.name}',
+        name: "ClusterFeedbackService",
+      );
+      return false;
+    }
+
+    // log suggestions
+    for (final entry in suggestions.entries) {
+      dev.log(
+        '${entry.value.length} suggestions for ${p.attr.name} for cluster ID ${entry.key}: ${entry.value}',
+        name: "ClusterFeedbackService",
+      );
+    }
+
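+    // Every suggestion here passed the strict 0.3 mean-distance threshold,
+    // so the clusters are merged into the person right away.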
+    for (final suggestionsPerCluster in suggestions.values) {
+      for (final suggestion in suggestionsPerCluster) {
+        final clusterID = suggestion.$1;
+        await faceMlDb.assignClusterToPerson(
+          personID: p.remoteID,
+          clusterID: clusterID,
+        );
+      }
+    }
+
+    Bus.instance.fire(PeopleChangedEvent());
+
+    return true;
+  }
+
+  Future<Map<int, List<double>>> _getUpdateClusterAvg(
+    Map<int, int> allClusterIdsToCountMap,
+    Set<int> ignoredClusters,
+  ) async {
+    final faceMlDb = FaceMLDataDB.instance;
+
+    final Map<int, (Uint8List, int)> clusterToSummary =
+        await faceMlDb.clusterSummaryAll();
+    final Map<int, (Uint8List, int)> updatesForClusterSummary = {};
+
+    final Map<int, List<double>> clusterAvg = {};
+
+    final allClusterIds = allClusterIdsToCountMap.keys;
+    for (final clusterID in allClusterIds) {
+      if (ignoredClusters.contains(clusterID)) {
+        continue;
+      }
+      late List<double> avg;
+      if (clusterToSummary[clusterID]?.$2 ==
+          allClusterIdsToCountMap[clusterID]) {
+        avg = EVector.fromBuffer(clusterToSummary[clusterID]!.$1).values;
+      } else {
+        final Iterable<Uint8List> embeddings =
+            await FaceMLDataDB.instance.getFaceEmbeddingsForCluster(clusterID);
+        final List<double> sum = List.filled(192, 0);
+        for (final embedding in embeddings) {
+          final data = EVector.fromBuffer(embedding).values;
+          for (int i = 0; i < sum.length; i++) {
+            sum[i] += data[i];
+          }
+        }
+        avg = sum.map((e) => e / embeddings.length).toList();
+        final avgEmbeddingBuffer = EVector(values: avg).writeToBuffer();
+        updatesForClusterSummary[clusterID] =
+            (avgEmbeddingBuffer, embeddings.length);
+      }
+      clusterAvg[clusterID] = avg;
+    }
+    if (updatesForClusterSummary.isNotEmpty) {
+      await faceMlDb.clusterSummaryUpdate(updatesForClusterSummary);
+    }
+
+    return clusterAvg;
+  }
+
+  Map<int, List<(int, double)>> _calcSuggestionsMean(
+    Map<int, List<double>> clusterAvg,
+    Set<int> personClusters,
+    Set<int> ignoredClusters,
+    double maxClusterDistance,
+  ) {
+    final Map<int, List<(int, double)>> suggestions = {};
+    for (final otherClusterID in clusterAvg.keys) {
+      // ignore clusters that belong to the person or are explicitly ignored
+      if (personClusters.contains(otherClusterID) ||
+          ignoredClusters.contains(otherClusterID)) {
+        continue;
+      }
+      final otherAvg = clusterAvg[otherClusterID]!;
+      int? nearestPersonCluster;
+      double? minDistance;
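+      // Keep only the nearest of the person's own clusters, and only when it
+      // lies within maxClusterDistance.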
+      for (final personCluster in personClusters) {
+        final avg = clusterAvg[personCluster]!;
+        final distance = cosineDistForNormVectors(avg, otherAvg);
+        if (distance < maxClusterDistance) {
+          if (minDistance == null || distance < minDistance) {
+            minDistance = distance;
+            nearestPersonCluster = personCluster;
+          }
+        }
+      }
+      if (nearestPersonCluster != null && minDistance != null) {
+        suggestions
+            .putIfAbsent(nearestPersonCluster, () => [])
+            .add((otherClusterID, minDistance));
+      }
+    }
+    for (final entry in suggestions.entries) {
+      entry.value.sort((a, b) => a.$1.compareTo(b.$1));
+    }
+
+    return suggestions;
+  }
+
+  List<Uint8List> _randomSampleWithoutReplacement(
+    Iterable<Uint8List> embeddings,
+    int sampleSize,
+  ) {
+    final random = Random();
+
+    if (sampleSize >= embeddings.length) {
+      return embeddings.toList();
+    }
+
+    // If sampleSize is more than half the list size, shuffle and take first sampleSize elements
+    if (sampleSize > embeddings.length / 2) {
+      final List<Uint8List> shuffled = List.from(embeddings)..shuffle(random);
+      return shuffled.take(sampleSize).toList(growable: false);
+    }
+
+    // Otherwise, use the set-based method for efficiency
+    final selectedIndices = <int>{};
+    final sampledEmbeddings = <Uint8List>[];
+    while (sampledEmbeddings.length < sampleSize) {
+      final int index = random.nextInt(embeddings.length);
+      if (!selectedIndices.contains(index)) {
+        selectedIndices.add(index);
+        sampledEmbeddings.add(embeddings.elementAt(index));
+      }
+    }
+
+    return sampledEmbeddings;
+  }
+}
diff --git a/mobile/lib/services/face_ml/model_file.dart b/mobile/lib/services/face_ml/model_file.dart
new file mode 100644
index 000000000..66a457109
--- /dev/null
+++ b/mobile/lib/services/face_ml/model_file.dart
@@ -0,0 +1,11 @@
+mixin ModelFile {
+  static const String faceDetectionBackWeb =
+      'assets/models/blazeface/blazeface_back_ente_web.tflite';
+  // TODO: which of the two mobilefacenet models should I use now?
+  // static const String faceEmbeddingEnte =
+  //     'assets/models/mobilefacenet/mobilefacenet_ente_web.tflite';
+  static const String faceEmbeddingEnte =
+      'assets/models/mobilefacenet/mobilefacenet_unq_TF211.tflite';
+  static const String yoloV5FaceS640x640DynamicBatchonnx =
+      'assets/models/yolov5face/yolov5s_face_640_640_dynamic.onnx';
+}
diff --git a/mobile/lib/services/search_service.dart b/mobile/lib/services/search_service.dart
index fa2317836..f9c17bf1d 100644
--- a/mobile/lib/services/search_service.dart
+++ b/mobile/lib/services/search_service.dart
@@ -11,6 +11,8 @@ import 'package:photos/data/years.dart';
 import 'package:photos/db/files_db.dart';
 import 'package:photos/events/local_photos_updated_event.dart';
 import "package:photos/extensions/string_ext.dart";
+import "package:photos/face/db.dart";
+import "package:photos/face/model/person.dart";
 import "package:photos/models/api/collection/user.dart";
 import 'package:photos/models/collection/collection.dart';
 import 'package:photos/models/collection/collection_items.dart';
@@ -22,6 +24,7 @@ import "package:photos/models/location/location.dart";
 import "package:photos/models/location_tag/location_tag.dart";
 import 'package:photos/models/search/album_search_result.dart';
 import 'package:photos/models/search/generic_search_result.dart';
+import "package:photos/models/search/search_constants.dart";
 import "package:photos/models/search/search_types.dart";
 import 'package:photos/services/collections_service.dart';
 import "package:photos/services/location_service.dart";
@@ -29,6 +32,8 @@ import 'package:photos/services/machine_learning/semantic_search/semantic_search
 import "package:photos/states/location_screen_state.dart";
 import "package:photos/ui/viewer/location/add_location_sheet.dart";
 import "package:photos/ui/viewer/location/location_screen.dart";
+import "package:photos/ui/viewer/people/cluster_page.dart";
+import "package:photos/ui/viewer/people/people_page.dart";
 import 'package:photos/utils/date_time_util.dart';
 import "package:photos/utils/navigation_util.dart";
 import 'package:tuple/tuple.dart';
@@ -704,6 +709,146 @@ class SearchService {
     return searchResults;
   }
 
+  Future<Map<int, List<EnteFile>>> getClusterFilesForPersonID(
+    String personID,
+  ) async {
+    _logger.info('getClusterFilesForPersonID $personID');
+    final Map<int, Set<int>> fileIdToClusterID =
+        await FaceMLDataDB.instance.getFileIdToClusterIDSet(personID);
+    _logger.info('faceDbDone getClusterFilesForPersonID $personID');
+    final Map<int, List<EnteFile>> clusterIDToFiles = {};
+    final allFiles = await getAllFiles();
+    for (final f in allFiles) {
+      if (!fileIdToClusterID.containsKey(f.uploadedFileID ?? -1)) {
+        continue;
+      }
+      final clusterIds = fileIdToClusterID[f.uploadedFileID ?? -1]!;
+      for (final cluster in clusterIds) {
+        if (clusterIDToFiles.containsKey(cluster)) {
+          clusterIDToFiles[cluster]!.add(f);
+        } else {
+          clusterIDToFiles[cluster] = [f];
+        }
+      }
+    }
+    _logger.info('done getClusterFilesForPersonID $personID');
+    return clusterIDToFiles;
+  }
+
+  Future<List<GenericSearchResult>> getAllFace(int? limit) async {
+    debugPrint("getting faces");
+    final Map<int, Set<int>> fileIdToClusterID =
+        await FaceMLDataDB.instance.getFileIdToClusterIds();
+    final (clusterIDToPerson, personIdToPerson) =
+        await FaceMLDataDB.instance.getClusterIdToPerson();
+
+    debugPrint("building result");
+    final List<GenericSearchResult> facesResult = [];
+    final Map<int, List<EnteFile>> clusterIdToFiles = {};
+    final Map<String, List<EnteFile>> personIdToFiles = {};
+    final allFiles = await getAllFiles();
+    for (final f in allFiles) {
+      if (!fileIdToClusterID.containsKey(f.uploadedFileID ?? -1)) {
+        continue;
+      }
+      final clusterIds = fileIdToClusterID[f.uploadedFileID ?? -1]!;
+      for (final cluster in clusterIds) {
+        final Person? p = clusterIDToPerson[cluster];
+        if (p != null) {
+          if (personIdToFiles.containsKey(p.remoteID)) {
+            personIdToFiles[p.remoteID]!.add(f);
+          } else {
+            personIdToFiles[p.remoteID] = [f];
+          }
+        } else {
+          if (clusterIdToFiles.containsKey(cluster)) {
+            clusterIdToFiles[cluster]!.add(f);
+          } else {
+            clusterIdToFiles[cluster] = [f];
+          }
+        }
+      }
+    }
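+    // Named people are surfaced first, ordered by file count; unnamed
+    // clusters follow below.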
+    // get sorted personId by files count
+    final sortedPersonIds = personIdToFiles.keys.toList()
+      ..sort(
+        (a, b) => personIdToFiles[b]!.length.compareTo(
+              personIdToFiles[a]!.length,
+            ),
+      );
+    for (final personID in sortedPersonIds) {
+      final files = personIdToFiles[personID]!;
+      if (files.isEmpty) {
+        continue;
+      }
+      final Person p = personIdToPerson[personID]!;
+      facesResult.add(
+        GenericSearchResult(
+          ResultType.faces,
+          p.attr.name,
+          files,
+          params: {
+            kPersonParamID: personID,
+            kFileID: files.first.uploadedFileID,
+          },
+          onResultTap: (ctx) {
+            routeToPage(
+              ctx,
+              PeoplePage(
+                tagPrefix: "${ResultType.faces.toString()}_${p.attr.name}",
+                person: p,
+              ),
+            );
+          },
+        ),
+      );
+    }
+    final sortedClusterIds = clusterIdToFiles.keys.toList()
+      ..sort(
+        (a, b) =>
+            clusterIdToFiles[b]!.length.compareTo(clusterIdToFiles[a]!.length),
+      );
+
+    for (final clusterId in sortedClusterIds) {
+      final files = clusterIdToFiles[clusterId]!;
+      // final String clusterName = "ID:$clusterId, ${files.length}";
+      final String clusterName = "${files.length}";
+      final Person? p = clusterIDToPerson[clusterId];
+      if (p != null) {
+        throw Exception("Person should be null");
+      }
+      if (files.length < 3) {
+        continue;
+      }
+      facesResult.add(
+        GenericSearchResult(
+          ResultType.faces,
+          clusterName,
+          files,
+          params: {
+            kClusterParamId: clusterId,
+            kFileID: files.first.uploadedFileID,
+          },
+          onResultTap: (ctx) {
+            routeToPage(
+              ctx,
+              ClusterPage(
+                files,
+                tagPrefix: "${ResultType.faces.toString()}_$clusterName",
+                cluserID: clusterId,
+              ),
+            );
+          },
+        ),
+      );
+    }
+    if (limit != null) {
+      return facesResult.sublist(0, min(limit, facesResult.length));
+    } else {
+      return facesResult;
+    }
+  }
+
   Future<List<GenericSearchResult>> getAllLocationTags(int? limit) async {
     try {
       final Map<LocalEntity<LocationTag>, List<EnteFile>> tagToItemsMap = {};
diff --git a/mobile/lib/states/all_sections_examples_state.dart b/mobile/lib/states/all_sections_examples_state.dart
index fdeb6fcdf..a40ecd925 100644
--- a/mobile/lib/states/all_sections_examples_state.dart
+++ b/mobile/lib/states/all_sections_examples_state.dart
@@ -6,6 +6,7 @@ import "package:logging/logging.dart";
 import "package:photos/core/constants.dart";
 import "package:photos/core/event_bus.dart";
 import "package:photos/events/files_updated_event.dart";
+import "package:photos/events/people_changed_event.dart";
 import "package:photos/events/tab_changed_event.dart";
 import "package:photos/models/search/search_result.dart";
 import "package:photos/models/search/search_types.dart";
@@ -31,6 +32,7 @@ class _AllSectionsExamplesProviderState
   Future<List<List<SearchResult>>> allSectionsExamplesFuture =
       Future.value([]);
   late StreamSubscription<FilesUpdatedEvent> _filesUpdatedEvent;
+  late StreamSubscription<PeopleChangedEvent> _onPeopleChangedEvent;
   late StreamSubscription<TabChangedEvent> _tabChangeEvent;
   bool hasPendingUpdate = false;
   bool isOnSearchTab = false;
@@ -46,16 +48,11 @@ class _AllSectionsExamplesProviderState
     super.initState();
     //add all common events for all search sections to reload to here.
     _filesUpdatedEvent = Bus.instance.on<FilesUpdatedEvent>().listen((event) {
-      if (!isOnSearchTab) {
-        if (kDebugMode) {
-          _logger.finest('Skip reload till user clicks on search tab');
-        }
-        hasPendingUpdate = true;
-        return;
-      } else {
-        hasPendingUpdate = false;
-        reloadAllSections();
-      }
+      onDataUpdate();
+    });
+    _onPeopleChangedEvent =
+        Bus.instance.on<PeopleChangedEvent>().listen((event) {
+      onDataUpdate();
     });
     _tabChangeEvent = Bus.instance.on<TabChangedEvent>().listen((event) {
       if (event.source == TabChangedEventSource.pageView &&
@@ -72,6 +69,18 @@
     reloadAllSections();
   }
 
+  void onDataUpdate() {
+    if (!isOnSearchTab) {
+      if (kDebugMode) {
+        _logger.finest('Skip reload till user clicks on search tab');
+      }
+      hasPendingUpdate = true;
+    } else {
+      hasPendingUpdate = false;
+      reloadAllSections();
+    }
+  }
+
   void reloadAllSections() {
     _logger.info('queue reload all sections');
     _debouncer.run(() async {
@@ -79,22 +88,28 @@
       _logger.info("'_debounceTimer: reloading all sections in search tab");
       final allSectionsExamples = <Future<List<SearchResult>>>[];
       for (SectionType sectionType in SectionType.values) {
-        if (sectionType == SectionType.face ||
-            sectionType == SectionType.content) {
+        if (sectionType == SectionType.content) {
           continue;
         }
         allSectionsExamples.add(
           sectionType.getData(context, limit: kSearchSectionLimit),
         );
       }
-      allSectionsExamplesFuture =
-          Future.wait<List<SearchResult>>(allSectionsExamples);
+      try {
+        allSectionsExamplesFuture = Future.wait<List<SearchResult>>(
+          allSectionsExamples,
+          eagerError: false,
+        );
+      } catch (e) {
+        _logger.severe("Error reloading all sections: $e");
+      }
     });
   }
 
   @override
   void dispose() {
+    _onPeopleChangedEvent.cancel();
     _filesUpdatedEvent.cancel();
     _tabChangeEvent.cancel();
     _debouncer.cancelDebounce();
diff --git a/mobile/lib/ui/components/bottom_action_bar/bottom_action_bar_widget.dart b/mobile/lib/ui/components/bottom_action_bar/bottom_action_bar_widget.dart
index b896e0f1f..8228e30e2 100644
--- a/mobile/lib/ui/components/bottom_action_bar/bottom_action_bar_widget.dart
+++ b/mobile/lib/ui/components/bottom_action_bar/bottom_action_bar_widget.dart
@@ -1,5 +1,6 @@
 import 'package:flutter/material.dart';
 import 'package:photos/core/constants.dart';
+import "package:photos/face/model/person.dart";
 import 'package:photos/models/collection/collection.dart';
 import "package:photos/models/gallery_type.dart";
 import 'package:photos/models/selected_files.dart';
@@ -11,6 +12,8 @@ import "package:photos/ui/viewer/actions/file_selection_actions_widget.dart";
 class BottomActionBarWidget extends StatelessWidget {
   final GalleryType galleryType;
   final Collection? collection;
+  final Person? person;
+  final int? clusterID;
   final SelectedFiles selectedFiles;
   final VoidCallback? onCancel;
   final Color? backgroundColor;
@@ -19,6 +22,8 @@ class BottomActionBarWidget extends StatelessWidget {
     required this.galleryType,
     required this.selectedFiles,
     this.collection,
+    this.person,
+    this.clusterID,
     this.onCancel,
     this.backgroundColor,
     super.key,
@@ -54,6 +59,8 @@ class BottomActionBarWidget extends StatelessWidget {
             galleryType,
             selectedFiles,
             collection: collection,
+            person: person,
+            clusterID: clusterID,
           ),
           const DividerWidget(dividerType: DividerType.bottomBar),
           ActionBarWidget(
diff --git a/mobile/lib/ui/settings/debug_section_widget.dart b/mobile/lib/ui/settings/debug/debug_section_widget.dart
similarity index 99%
rename from mobile/lib/ui/settings/debug_section_widget.dart
rename to mobile/lib/ui/settings/debug/debug_section_widget.dart
index 039655ca3..56070c214 100644
--- a/mobile/lib/ui/settings/debug_section_widget.dart
+++ b/mobile/lib/ui/settings/debug/debug_section_widget.dart
@@ -67,7 +67,6 @@ class DebugSectionWidget extends StatelessWidget {
           showShortToast(context, "Done");
         },
       ),
-      sectionOptionSpacing,
     ],
   );
 }
diff --git a/mobile/lib/ui/settings/debug/face_debug_section_widget.dart b/mobile/lib/ui/settings/debug/face_debug_section_widget.dart
new file mode 100644
index 000000000..d9a03b37a
--- /dev/null
+++ b/mobile/lib/ui/settings/debug/face_debug_section_widget.dart
@@ -0,0 +1,214 @@
+import "dart:async";
+
+import "package:flutter/foundation.dart";
+import 'package:flutter/material.dart';
+import "package:logging/logging.dart";
+import "package:photos/core/event_bus.dart";
+import "package:photos/events/people_changed_event.dart";
+import "package:photos/extensions/stop_watch.dart";
+import "package:photos/face/db.dart";
+import "package:photos/face/model/person.dart";
+import "package:photos/services/face_ml/face_ml_service.dart";
+import "package:photos/services/face_ml/feedback/cluster_feedback.dart";
+import 'package:photos/theme/ente_theme.dart';
+import 'package:photos/ui/components/captioned_text_widget.dart';
+import 'package:photos/ui/components/expandable_menu_item_widget.dart';
+import 'package:photos/ui/components/menu_item_widget/menu_item_widget.dart';
+import 'package:photos/ui/settings/common_settings.dart';
+import "package:photos/utils/dialog_util.dart";
+import "package:photos/utils/local_settings.dart";
+import 'package:photos/utils/toast_util.dart';
+
+class FaceDebugSectionWidget extends StatefulWidget {
+  const FaceDebugSectionWidget({Key? key}) : super(key: key);
+
+  @override
+  State<FaceDebugSectionWidget> createState() =>
+      _FaceDebugSectionWidgetState();
+}
+
+class _FaceDebugSectionWidgetState extends State<FaceDebugSectionWidget> {
+  Timer? _timer;
+  @override
+  void initState() {
+    super.initState();
+    _timer = Timer.periodic(const Duration(seconds: 5), (timer) {
+      setState(() {
+        // Rebuild periodically so the file counts shown below stay fresh
+      });
+    });
+  }
+
+  @override
+  void dispose() {
+    _timer?.cancel();
+    super.dispose();
+  }
+
+  @override
+  Widget build(BuildContext context) {
+    return ExpandableMenuItemWidget(
+      title: "Face Beta",
+      selectionOptionsWidget: _getSectionOptions(context),
+      leadingIcon: Icons.bug_report_outlined,
+    );
+  }
+
+  Widget _getSectionOptions(BuildContext context) {
+    final Logger _logger = Logger("FaceDebugSectionWidget");
+    return Column(
+      children: [
+        MenuItemWidget(
+          captionedTextWidget: FutureBuilder<Set<int>>(
+            future: FaceMLDataDB.instance.getIndexedFileIds(),
+            builder: (context, snapshot) {
+              if (snapshot.hasData) {
+                return CaptionedTextWidget(
+                  title: LocalSettings.instance.isFaceIndexingEnabled
+                      ? "Disable indexing (${snapshot.data!.length})"
"Disable Indexing (${snapshot.data!.length})" + : "Enable indexing (${snapshot.data!.length})", + ); + } + return const SizedBox.shrink(); + }, + ), + pressedColor: getEnteColorScheme(context).fillFaint, + trailingIcon: Icons.chevron_right_outlined, + trailingIconIsMuted: true, + onTap: () async { + try { + final isEnabled = + await LocalSettings.instance.toggleFaceIndexing(); + if (isEnabled) { + FaceMlService.instance.indexAllImages().ignore(); + } else { + FaceMlService.instance.pauseIndexing(); + } + if (mounted) { + setState(() {}); + } + } catch (e, s) { + _logger.warning('indexing failed ', e, s); + await showGenericErrorDialog(context: context, error: e); + } + }, + ), + MenuItemWidget( + captionedTextWidget: const CaptionedTextWidget( + title: "Run Clustering", + ), + pressedColor: getEnteColorScheme(context).fillFaint, + trailingIcon: Icons.chevron_right_outlined, + trailingIconIsMuted: true, + onTap: () async { + await FaceMlService.instance.clusterAllImages(minFaceScore: 0.75); + Bus.instance.fire(PeopleChangedEvent()); + showShortToast(context, "Done"); + }, + ), + sectionOptionSpacing, + MenuItemWidget( + captionedTextWidget: const CaptionedTextWidget( + title: "Reset feedback & labels", + ), + pressedColor: getEnteColorScheme(context).fillFaint, + trailingIcon: Icons.chevron_right_outlined, + trailingIconIsMuted: true, + onTap: () async { + await FaceMLDataDB.instance.resetClusterIDs(); + await FaceMLDataDB.instance.dropClustersAndPeople(); + Bus.instance.fire(PeopleChangedEvent()); + showShortToast(context, "Done"); + }, + ), + sectionOptionSpacing, + MenuItemWidget( + captionedTextWidget: const CaptionedTextWidget( + title: "Drop embeddings & feedback", + ), + pressedColor: getEnteColorScheme(context).fillFaint, + trailingIcon: Icons.chevron_right_outlined, + trailingIconIsMuted: true, + onTap: () async { + await showChoiceDialog( + context, + title: "Are you sure?", + body: + "You will need to again re-index all the faces. 
+              firstButtonLabel: "Yes, confirm",
+              firstButtonOnTap: () async {
+                await FaceMLDataDB.instance.dropClustersAndPeople(faces: true);
+                Bus.instance.fire(PeopleChangedEvent());
+                showShortToast(context, "Done");
+              },
+            );
+          },
+        ),
+        if (kDebugMode) sectionOptionSpacing,
+        if (kDebugMode)
+          MenuItemWidget(
+            captionedTextWidget: const CaptionedTextWidget(
+              title: "Pull Embeddings From Local",
+            ),
+            pressedColor: getEnteColorScheme(context).fillFaint,
+            trailingIcon: Icons.chevron_right_outlined,
+            trailingIconIsMuted: true,
+            onTap: () async {
+              try {
+                final List<Person> persons =
+                    await FaceMLDataDB.instance.getPeople();
+                final EnteWatch w = EnteWatch('feedback')..start();
+                for (final Person p in persons) {
+                  await ClusterFeedbackService.instance
+                      .getSuggestionsUsingMean(p);
+                  w.logAndReset('suggestions calculated for ${p.attr.name}');
+                }
+                w.log("done with feedback");
+                showShortToast(context, "done avg");
+                // await FaceMLDataDB.instance.bulkInsertFaces([]);
+                // final EnteWatch watch = EnteWatch("face_time")..start();
+
+                // final results = await downloadZip();
+                // watch.logAndReset('downloaded and de-serialized');
+                // await FaceMLDataDB.instance.bulkInsertFaces(results);
+                // watch.logAndReset('inserted in to db');
+                // showShortToast(context, "Got ${results.length} results");
+              } catch (e, s) {
+                _logger.warning('download failed ', e, s);
+                await showGenericErrorDialog(context: context, error: e);
+              }
+              // _showKeyAttributesDialog(context);
+            },
+          ),
+        if (kDebugMode) sectionOptionSpacing,
+        if (kDebugMode)
+          MenuItemWidget(
+            captionedTextWidget: FutureBuilder<Set<int>>(
+              future: FaceMLDataDB.instance.getIndexedFileIds(),
+              builder: (context, snapshot) {
+                if (snapshot.hasData) {
+                  return CaptionedTextWidget(
+                    title: "Read embeddings for ${snapshot.data!.length} files",
+                  );
+                }
+                return const CaptionedTextWidget(
+                  title: "Loading...",
+                );
+              },
+            ),
+            pressedColor: getEnteColorScheme(context).fillFaint,
+            trailingIcon: Icons.chevron_right_outlined,
+            trailingIconIsMuted: true,
+            onTap: () async {
+              final EnteWatch watch = EnteWatch("read_embeddings")..start();
+              final result = await FaceMLDataDB.instance.getFaceEmbeddingMap();
+              watch.logAndReset('read embeddings ${result.length} ');
+              showShortToast(
+                context,
+                "Done in ${watch.elapsed.inSeconds} secs",
+              );
+            },
+          ),
+      ],
+    );
+  }
+}
diff --git a/mobile/lib/ui/settings_page.dart b/mobile/lib/ui/settings_page.dart
index 85ce16afb..5c45589a5 100644
--- a/mobile/lib/ui/settings_page.dart
+++ b/mobile/lib/ui/settings_page.dart
@@ -7,7 +7,6 @@ import 'package:photos/core/configuration.dart';
 import 'package:photos/core/event_bus.dart';
 import 'package:photos/events/opened_settings_event.dart';
 import "package:photos/generated/l10n.dart";
-import 'package:photos/services/feature_flag_service.dart';
 import "package:photos/services/storage_bonus_service.dart";
 import 'package:photos/theme/colors.dart';
 import 'package:photos/theme/ente_theme.dart';
@@ -17,7 +16,8 @@ import 'package:photos/ui/settings/about_section_widget.dart';
 import 'package:photos/ui/settings/account_section_widget.dart';
 import 'package:photos/ui/settings/app_version_widget.dart';
 import 'package:photos/ui/settings/backup/backup_section_widget.dart';
-import 'package:photos/ui/settings/debug_section_widget.dart';
+import 'package:photos/ui/settings/debug/debug_section_widget.dart';
+import "package:photos/ui/settings/debug/face_debug_section_widget.dart";
 import 'package:photos/ui/settings/general_section_widget.dart';
 import 'package:photos/ui/settings/inherited_settings_state.dart';
 import 'package:photos/ui/settings/security_section_widget.dart';
@@ -52,6 +52,10 @@ class SettingsPage extends StatelessWidget {
     final hasLoggedIn = Configuration.instance.isLoggedIn();
     final enteTextTheme = getEnteTextTheme(context);
     final List<Widget> contents = [];
+    const sectionSpacing = SizedBox(height: 8);
+    if (kDebugMode) {
+      contents.addAll([const FaceDebugSectionWidget(), sectionSpacing]);
+    }
     contents.add(
       GestureDetector(
         onDoubleTap: () {
@@ -81,7 +85,7 @@ class SettingsPage extends StatelessWidget {
         ),
       ),
     );
-    const sectionSpacing = SizedBox(height: 8);
+    contents.add(const SizedBox(height: 8));
 
     if (hasLoggedIn) {
       final showStorageBonusBanner =
@@ -139,9 +143,9 @@ class SettingsPage extends StatelessWidget {
       const AboutSectionWidget(),
     ]);
 
-    if (hasLoggedIn &&
-        FeatureFlagService.instance.isInternalUserOrDebugBuild()) {
+    if (hasLoggedIn) {
       contents.addAll([sectionSpacing, const DebugSectionWidget()]);
+      contents.addAll([sectionSpacing, const FaceDebugSectionWidget()]);
     }
     contents.add(const AppVersionWidget());
     contents.add(
diff --git a/mobile/lib/ui/tools/app_lock.dart b/mobile/lib/ui/tools/app_lock.dart
index 1fbc1678e..fe05e12b4 100644
--- a/mobile/lib/ui/tools/app_lock.dart
+++ b/mobile/lib/ui/tools/app_lock.dart
@@ -113,6 +113,7 @@ class _AppLockState extends State<AppLock> with WidgetsBindingObserver {
         theme: widget.lightTheme,
         darkTheme: widget.darkTheme,
         locale: widget.locale,
+        debugShowCheckedModeBanner: false,
         supportedLocales: appSupportedLocales,
         localeListResolutionCallback: localResolutionCallBack,
         localizationsDelegates: const [
diff --git a/mobile/lib/ui/viewer/actions/file_selection_actions_widget.dart b/mobile/lib/ui/viewer/actions/file_selection_actions_widget.dart
index dff39ef60..369269673 100644
--- a/mobile/lib/ui/viewer/actions/file_selection_actions_widget.dart
+++ b/mobile/lib/ui/viewer/actions/file_selection_actions_widget.dart
@@ -1,10 +1,15 @@
 import "dart:async";
 
 import 'package:fast_base58/fast_base58.dart';
+import "package:flutter/cupertino.dart";
 import 'package:flutter/material.dart';
 import 'package:flutter/services.dart';
 import "package:modal_bottom_sheet/modal_bottom_sheet.dart";
 import 'package:photos/core/configuration.dart';
+import "package:photos/core/event_bus.dart";
+import "package:photos/events/people_changed_event.dart";
+import "package:photos/face/db.dart";
+import "package:photos/face/model/person.dart";
 import "package:photos/generated/l10n.dart";
 import 'package:photos/models/collection/collection.dart';
 import 'package:photos/models/device_collection.dart';
@@ -15,6 +20,7 @@ import 'package:photos/models/gallery_type.dart';
 import "package:photos/models/metadata/common_keys.dart";
 import 'package:photos/models/selected_files.dart';
 import 'package:photos/services/collections_service.dart';
+import "package:photos/services/face_ml/feedback/cluster_feedback.dart";
 import 'package:photos/services/hidden_service.dart';
 import "package:photos/theme/colors.dart";
 import "package:photos/theme/ente_theme.dart";
@@ -39,12 +45,16 @@ class FileSelectionActionsWidget extends StatefulWidget {
   final Collection? collection;
   final DeviceCollection? deviceCollection;
   final SelectedFiles selectedFiles;
+  final Person? person;
+  final int? clusterID;
 
   const FileSelectionActionsWidget(
     this.type,
     this.selectedFiles, {
    Key? key,
     this.collection,
+    this.person,
+    this.clusterID,
     this.deviceCollection,
   }) : super(key: key);
@@ -116,7 +126,24 @@ class _FileSelectionActionsWidgetState
     //and set [shouldShow] to false for items that should not be shown and true
     //for items that should be shown.
     final List<SelectionActionButton> items = [];
-
+    if (widget.type == GalleryType.peopleTag && widget.person != null) {
+      items.add(
+        SelectionActionButton(
+          icon: Icons.remove_circle_outline,
+          labelText: 'Not ${widget.person!.attr.name}?',
+          onTap: anyUploadedFiles ? _onNotPersonClicked : null,
+        ),
+      );
+      if (ownedFilesCount == 1) {
+        items.add(
+          SelectionActionButton(
+            icon: Icons.image_outlined,
+            labelText: 'Use as cover',
+            onTap: anyUploadedFiles ? _setPersonCover : null,
+          ),
+        );
+      }
+    }
     if (widget.type.showCreateLink()) {
       if (_cachedCollectionForSharedLink != null && anyUploadedFiles) {
         items.add(
@@ -374,6 +401,16 @@ class _FileSelectionActionsWidgetState
       ),
     );
 
+    if (widget.type == GalleryType.cluster) {
+      items.add(
+        SelectionActionButton(
+          labelText: 'Remove',
+          icon: CupertinoIcons.minus,
+          onTap: () => showToast(context, 'yet to implement'),
+        ),
+      );
+    }
+
     if (items.isNotEmpty) {
       final scrollController = ScrollController();
       // h4ck: https://github.com/flutter/flutter/issues/57920#issuecomment-893970066
@@ -613,6 +650,59 @@ class _FileSelectionActionsWidgetState
     }
   }
 
+  Future<void> _setPersonCover() async {
+    final EnteFile file = widget.selectedFiles.files.first;
+    final Person newPerson = widget.person!.copyWith(
+      attr: widget.person!.attr
+          .copyWith(avatarFaceId: file.uploadedFileID.toString()),
+    );
+    await FaceMLDataDB.instance.updatePerson(newPerson);
+    widget.selectedFiles.clearAll();
+    if (mounted) {
+      setState(() => {});
+    }
+    Bus.instance.fire(PeopleChangedEvent());
+  }
+
+  Future<void> _onNotPersonClicked() async {
+    final actionResult = await showActionSheet(
+      context: context,
+      buttons: [
+        ButtonWidget(
+          labelText: S.of(context).yesRemove,
+          buttonType: ButtonType.neutral,
+          buttonSize: ButtonSize.large,
+          shouldStickToDarkTheme: true,
+          buttonAction: ButtonAction.first,
+          isInAlert: true,
+        ),
+        ButtonWidget(
+          labelText: S.of(context).cancel,
+          buttonType: ButtonType.secondary,
+          buttonSize: ButtonSize.large,
+          buttonAction: ButtonAction.second,
+          shouldStickToDarkTheme: true,
+          isInAlert: true,
+        ),
+      ],
+      title: "Remove these photos for ${widget.person!.attr.name}?",
+      actionSheetType: ActionSheetType.defaultActionSheet,
+    );
+    if (actionResult?.action != null) {
+      if (actionResult!.action == ButtonAction.first) {
+        await ClusterFeedbackService.instance.removePersonFromFiles(
+          widget.selectedFiles.files.toList(),
+          widget.person!,
+        );
+      }
+      Bus.instance.fire(PeopleChangedEvent());
+    }
+    widget.selectedFiles.clearAll();
+    if (mounted) {
+      setState(() => {});
+    }
+  }
+
   Future<void> _copyLink() async {
     if (_cachedCollectionForSharedLink != null) {
       final String collectionKey = Base58Encode(
diff --git a/mobile/lib/ui/viewer/actions/file_selection_overlay_bar.dart b/mobile/lib/ui/viewer/actions/file_selection_overlay_bar.dart
index bc832c573..f0d258956 100644
--- a/mobile/lib/ui/viewer/actions/file_selection_overlay_bar.dart
+++ b/mobile/lib/ui/viewer/actions/file_selection_overlay_bar.dart
@@ -1,4 +1,5 @@
 import 'package:flutter/material.dart';
+import "package:photos/face/model/person.dart";
 import 'package:photos/models/collection/collection.dart';
 import 'package:photos/models/gallery_type.dart';
 import 'package:photos/models/selected_files.dart';
@@ -10,12 +11,14 @@
   final SelectedFiles selectedFiles;
   final Collection? collection;
   final Color? backgroundColor;
+  final Person? person;
 
   const FileSelectionOverlayBar(
     this.galleryType,
     this.selectedFiles, {
     this.collection,
     this.backgroundColor,
+    this.person,
     Key? key,
   }) : super(key: key);
@@ -65,6 +68,7 @@ class _FileSelectionOverlayBarState extends State<FileSelectionOverlayBar> {
               selectedFiles: widget.selectedFiles,
               galleryType: widget.galleryType,
               collection: widget.collection,
+              person: widget.person,
               onCancel: () {
                 if (widget.selectedFiles.files.isNotEmpty) {
                   widget.selectedFiles.clearAll();
diff --git a/mobile/lib/ui/viewer/file/file_details_widget.dart b/mobile/lib/ui/viewer/file/file_details_widget.dart
index f8e7abb8e..13c0e8b79 100644
--- a/mobile/lib/ui/viewer/file/file_details_widget.dart
+++ b/mobile/lib/ui/viewer/file/file_details_widget.dart
@@ -18,9 +18,9 @@ import "package:photos/ui/viewer/file_details/albums_item_widget.dart";
 import 'package:photos/ui/viewer/file_details/backed_up_time_item_widget.dart';
 import "package:photos/ui/viewer/file_details/creation_time_item_widget.dart";
 import 'package:photos/ui/viewer/file_details/exif_item_widgets.dart';
+import "package:photos/ui/viewer/file_details/faces_item_widget.dart";
 import "package:photos/ui/viewer/file_details/file_properties_item_widget.dart";
 import "package:photos/ui/viewer/file_details/location_tags_widget.dart";
-import "package:photos/ui/viewer/file_details/objects_item_widget.dart";
 import "package:photos/utils/exif_util.dart";
 
 class FileDetailsWidget extends StatefulWidget {
@@ -221,7 +221,8 @@ class _FileDetailsWidgetState extends State<FileDetailsWidget> {
 
     if (!UpdateService.instance.isFdroidFlavor()) {
       fileDetailsTiles.addAll([
-        ObjectsItemWidget(file),
+        // ObjectsItemWidget(file),
+        FacesItemWidget(file),
         const FileDetailsDivider(),
       ]);
     }
diff --git a/mobile/lib/ui/viewer/file/zoomable_image.dart b/mobile/lib/ui/viewer/file/zoomable_image.dart
index 3329e6955..266dbbaaa 100644
--- a/mobile/lib/ui/viewer/file/zoomable_image.dart
+++ b/mobile/lib/ui/viewer/file/zoomable_image.dart
@@ -1,5 +1,5 @@
 import 'dart:async';
-import 'dart:io';
+import 'dart:io' as io;
 
 import 'package:flutter/material.dart';
 import 'package:flutter/widgets.dart';
@@ -198,7 +198,7 @@ class _ZoomableImageState extends State<ZoomableImage>
     _loadingFinalImage = true;
     getFile(
       _photo,
-      isOrigin: Platform.isIOS &&
+      isOrigin: io.Platform.isIOS &&
           _isGIF(), // since on iOS GIFs playback only when origin-files are loaded
     ).then((file) {
       if (file != null && file.existsSync()) {
@@ -240,7 +240,25 @@
     }
   }
 
-  void _onFinalImageLoaded(ImageProvider imageProvider) {
+  void _onFinalImageLoaded(ImageProvider imageProvider) async {
+    // // final result = await FaceMlService.instance.analyzeImage(
+    // //   _photo,
+    // //   preferUsingThumbnailForEverything: false,
+    // //   disposeImageIsolateAfterUse: false,
+    // // );
+    // // _logger.info("FaceMlService result: $result");
+    // // _logger.info("Number of faces detected: ${result.faces.length}");
+    // // _logger.info("Box: ${result.faces[0].detection.box}");
+    // // _logger.info("Landmarks: ${result.faces[0].detection.allKeypoints}");
+    // // final embedding = result.faces[0].embedding;
+    // // Calculate the magnitude of the embedding vector
+    // double sum = 0;
+    // for (final double value in embedding) {
+    //   sum += value * value;
+    // }
+    // final magnitude = math.sqrt(sum);
+    // log("Magnitude: $magnitude");
+    // log("Embedding: $embedding");
     if (mounted) {
       precacheImage(imageProvider, context).then((value) async {
         if (mounted) {
diff --git a/mobile/lib/ui/viewer/file_details/face_widget.dart b/mobile/lib/ui/viewer/file_details/face_widget.dart
new file mode 100644
index 000000000..1ea94c8e3
--- /dev/null
+++ b/mobile/lib/ui/viewer/file_details/face_widget.dart
@@ -0,0 +1,160 @@
+import "dart:developer" show log;
+import "dart:typed_data";
+
+import "package:flutter/material.dart";
+import "package:photos/face/db.dart";
+import "package:photos/face/model/face.dart";
+import "package:photos/face/model/person.dart";
+import 'package:photos/models/file/file.dart';
+import "package:photos/services/search_service.dart";
+import "package:photos/ui/viewer/file/no_thumbnail_widget.dart";
+import "package:photos/ui/viewer/people/cluster_page.dart";
+import "package:photos/ui/viewer/people/people_page.dart";
+import "package:photos/utils/face/face_box_crop.dart";
+import "package:photos/utils/thumbnail_util.dart";
+
+class FaceWidget extends StatelessWidget {
+  final EnteFile file;
+  final Face face;
+  final Person? person;
+  final int? clusterID;
+
+  const FaceWidget(
+    this.file,
+    this.face, {
+    this.person,
+    this.clusterID,
+    Key? key,
+  }) : super(key: key);
+
+  @override
+  Widget build(BuildContext context) {
+    return FutureBuilder<Uint8List?>(
+      future: getFaceCrop(),
+      builder: (context, snapshot) {
+        if (snapshot.hasData) {
+          final ImageProvider imageProvider = MemoryImage(snapshot.data!);
+          return GestureDetector(
+            onTap: () async {
+              log(
+                "FaceWidget is tapped, with person $person and clusterID $clusterID",
+                name: "FaceWidget",
+              );
+              if (person == null && clusterID == null) {
+                return;
+              }
+              if (person != null) {
+                await Navigator.of(context).push(
+                  MaterialPageRoute(
+                    builder: (context) => PeoplePage(
+                      person: person!,
+                    ),
+                  ),
+                );
+              } else if (clusterID != null) {
+                final fileIdsToClusterIds =
+                    await FaceMLDataDB.instance.getFileIdToClusterIds();
+                final files = await SearchService.instance.getAllFiles();
+                final clusterFiles = files
+                    .where(
+                      (file) =>
+                          fileIdsToClusterIds[file.uploadedFileID]
+                              ?.contains(clusterID) ??
+                          false,
+                    )
+                    .toList();
+                await Navigator.of(context).push(
+                  MaterialPageRoute(
+                    builder: (context) => ClusterPage(
+                      clusterFiles,
+                      cluserID: clusterID!,
+                    ),
+                  ),
+                );
+              }
+            },
+            child: Column(
+              children: [
+                ClipOval(
+                  child: SizedBox(
+                    width: 60,
+                    height: 60,
+                    child: Image(
+                      image: imageProvider,
+                      fit: BoxFit.cover,
+                    ),
+                  ),
+                ),
+                const SizedBox(height: 8),
+                if (person != null)
+                  Text(
+                    person!.attr.name.trim(),
+                    style: Theme.of(context).textTheme.bodySmall,
+                    overflow: TextOverflow.ellipsis,
+                    maxLines: 1,
+                  ),
+              ],
+            ),
+          );
+        } else {
+          if (snapshot.connectionState == ConnectionState.waiting) {
+            return const ClipOval(
+              child: SizedBox(
+                width: 60, // Ensure consistent sizing
+                height: 60,
+                child: CircularProgressIndicator(),
+              ),
+            );
+          }
+          if (snapshot.hasError) {
+            log('Error getting face: ${snapshot.error}');
+          }
+          return const ClipOval(
+            child: SizedBox(
+              width: 60, // Ensure consistent sizing
+              height: 60,
+              child: NoThumbnailWidget(),
+            ),
+          );
+        }
+      },
+    );
+  }
+
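+  /// Returns the cropped face for [face], trying the in-memory cache first,
+  /// then the on-disk cache, before computing the crop from the file itself.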
+  Future<Uint8List?> getFaceCrop() async {
+    try {
+      final Uint8List? cachedFace = faceCropCache.get(face.faceID);
+      if (cachedFace != null) {
+        return cachedFace;
+      }
+      final faceCropCacheFile = cachedFaceCropPath(face.faceID);
+      if ((await faceCropCacheFile.exists())) {
+        final data = await faceCropCacheFile.readAsBytes();
+        faceCropCache.put(face.faceID, data);
+        return data;
+      }
+
+      final result = await pool.withResource(
+        () async => await getFaceCrops(
+          file,
+          {
+            face.faceID: face.detection.box,
+          },
+        ),
+      );
+      final Uint8List? computedCrop = result?[face.faceID];
+      if (computedCrop != null) {
+        faceCropCache.put(face.faceID, computedCrop);
+        faceCropCacheFile.writeAsBytes(computedCrop).ignore();
+      }
+      return computedCrop;
+    } catch (e, s) {
+      log(
+        "Error getting face for faceID: ${face.faceID}",
+        error: e,
+        stackTrace: s,
+      );
+      return null;
+    }
+  }
+}
diff --git a/mobile/lib/ui/viewer/file_details/faces_item_widget.dart b/mobile/lib/ui/viewer/file_details/faces_item_widget.dart
new file mode 100644
index 000000000..a5d2ba809
--- /dev/null
+++ b/mobile/lib/ui/viewer/file_details/faces_item_widget.dart
@@ -0,0 +1,79 @@
+import "package:flutter/material.dart";
+import "package:logging/logging.dart";
+import "package:photos/face/db.dart";
+import "package:photos/face/model/face.dart";
+import "package:photos/face/model/person.dart";
+import "package:photos/models/file/file.dart";
+import "package:photos/ui/components/buttons/chip_button_widget.dart";
+import "package:photos/ui/components/info_item_widget.dart";
+import "package:photos/ui/viewer/file_details/face_widget.dart";
+
+class FacesItemWidget extends StatelessWidget {
+  final EnteFile file;
+  const FacesItemWidget(this.file, {super.key});
+
+  @override
+  Widget build(BuildContext context) {
+    return InfoItemWidget(
+      key: const ValueKey("Faces"),
+      leadingIcon: Icons.face_retouching_natural_outlined,
+      subtitleSection: _faceWidgets(context, file),
+      hasChipButtons: true,
+    );
+  }
+
+  Future<List<Widget>> _faceWidgets(
+    BuildContext context,
+    EnteFile file,
+  ) async {
+    try {
+      if (file.uploadedFileID == null) {
+        return [
+          const ChipButtonWidget(
+            "File not uploaded yet",
+            noChips: true,
+          ),
+        ];
+      }
+
+      final List<Face> faces = await FaceMLDataDB.instance
+          .getFacesForGivenFileID(file.uploadedFileID!);
+      if (faces.isEmpty || faces.every((face) => face.score < 0.5)) {
+        return [
+          const ChipButtonWidget(
+            "No faces found",
+            noChips: true,
+          ),
+        ];
+      }
+
+      // Sort the faces by score in descending order, so that the highest scoring face is first.
+      faces.sort((Face a, Face b) => b.score.compareTo(a.score));
+
+      // TODO: add deduplication of faces of same person
+      final faceIdsToClusterIds = await FaceMLDataDB.instance
+          .getFaceIdsToClusterIds(faces.map((face) => face.faceID));
+      final (clusterIDToPerson, personIdToPerson) =
+          await FaceMLDataDB.instance.getClusterIdToPerson();
+
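+      // Pair every face with its cluster and, when the cluster has been
+      // named, with the corresponding person.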
+      final faceWidgets = <Widget>[];
+      for (final Face face in faces) {
+        final int? clusterID = faceIdsToClusterIds[face.faceID];
+        final Person? person = clusterIDToPerson[clusterID];
+        faceWidgets.add(
+          FaceWidget(
+            file,
+            face,
+            clusterID: clusterID,
+            person: person,
+          ),
+        );
+      }
+
+      return faceWidgets;
+    } catch (e, s) {
+      Logger("FacesItemWidget").info(e, s);
+      return [];
+    }
+  }
+}
diff --git a/mobile/lib/ui/viewer/file_details/objects_item_widget.dart b/mobile/lib/ui/viewer/file_details/objects_item_widget.dart
index 5b91b9b12..c02576c11 100644
--- a/mobile/lib/ui/viewer/file_details/objects_item_widget.dart
+++ b/mobile/lib/ui/viewer/file_details/objects_item_widget.dart
@@ -27,6 +27,7 @@ class ObjectsItemWidget extends StatelessWidget {
     try {
       final chipButtons = [];
       var objectTags = {};
+      // final thumbnail = await getThumbnail(file);
       // if (thumbnail != null) {
       //   objectTags = await ObjectDetectionService.instance.predict(thumbnail);
diff --git a/mobile/lib/ui/viewer/people/add_person_action_sheet.dart b/mobile/lib/ui/viewer/people/add_person_action_sheet.dart
new file mode 100644
index 000000000..4a072280f
--- /dev/null
+++ b/mobile/lib/ui/viewer/people/add_person_action_sheet.dart
@@ -0,0 +1,301 @@
+import "dart:async";
+import "dart:developer";
+import "dart:math" as math;
+
+import 'package:flutter/material.dart';
+import "package:logging/logging.dart";
+import 'package:modal_bottom_sheet/modal_bottom_sheet.dart';
+import "package:photos/core/event_bus.dart";
+import "package:photos/events/people_changed_event.dart";
+import "package:photos/face/db.dart";
+import "package:photos/face/model/person.dart";
+import "package:photos/generated/l10n.dart";
+import "package:photos/services/face_ml/feedback/cluster_feedback.dart";
+import 'package:photos/theme/colors.dart';
+import 'package:photos/theme/ente_theme.dart';
+import 'package:photos/ui/common/loading_widget.dart';
+import 'package:photos/ui/components/bottom_of_title_bar_widget.dart';
+import 'package:photos/ui/components/buttons/button_widget.dart';
+import 'package:photos/ui/components/models/button_type.dart';
+import "package:photos/ui/components/text_input_widget.dart";
+import 'package:photos/ui/components/title_bar_title_widget.dart';
+import "package:photos/ui/viewer/people/new_person_item_widget.dart";
+import "package:photos/ui/viewer/people/person_row_item.dart";
+import "package:photos/utils/dialog_util.dart";
+import "package:photos/utils/toast_util.dart";
+import "package:uuid/uuid.dart";
+
+enum PersonActionType {
+  assignPerson,
+}
+
+String _actionName(
+  BuildContext context,
+  PersonActionType type,
+) {
+  String text = "";
+  switch (type) {
+    case PersonActionType.assignPerson:
+      text = "Add name";
+      break;
+  }
+  return text;
+}
+
+Future<dynamic> showAssignPersonAction(
+  BuildContext context, {
+  required int clusterID,
+  PersonActionType actionType = PersonActionType.assignPerson,
+  bool showOptionToCreateNewAlbum = true,
+}) {
+  return showBarModalBottomSheet(
+    context: context,
+    builder: (context) {
+      return PersonActionSheet(
+        actionType: actionType,
+        showOptionToCreateNewAlbum: showOptionToCreateNewAlbum,
+        cluserID: clusterID,
+      );
+    },
+    shape: const RoundedRectangleBorder(
+      side: BorderSide(width: 0),
+      borderRadius: BorderRadius.vertical(
+        top: Radius.circular(5),
+      ),
+    ),
+    topControl: const SizedBox.shrink(),
+    backgroundColor: getEnteColorScheme(context).backgroundElevated,
+    barrierColor: backdropFaintDark,
+    enableDrag: false,
+  );
+}
+
+class PersonActionSheet extends StatefulWidget {
+  final PersonActionType actionType;
+  final int cluserID;
+  final bool showOptionToCreateNewAlbum;
+
+  const PersonActionSheet({
+    required this.actionType,
+    required this.cluserID,
+    required this.showOptionToCreateNewAlbum,
+    super.key,
+  });
+
+  @override
+  State<PersonActionSheet> createState() => _PersonActionSheetState();
+}
+
+class _PersonActionSheetState extends State<PersonActionSheet> {
+  static const int cancelButtonSize = 80;
+  String _searchQuery = "";
+
+  @override
+  void initState() {
+    super.initState();
+  }
+
+  @override
+  Widget build(BuildContext context) {
+    final bottomInset = MediaQuery.of(context).viewInsets.bottom;
+    final isKeyboardUp = bottomInset > 100;
+    return Padding(
+      padding: EdgeInsets.only(
+        bottom: isKeyboardUp ? bottomInset - cancelButtonSize : 0,
+      ),
+      child: Row(
+        mainAxisAlignment: MainAxisAlignment.center,
+        children: [
+          ConstrainedBox(
+            constraints: BoxConstraints(
+              maxWidth: math.min(428, MediaQuery.of(context).size.width),
+            ),
+            child: Padding(
+              padding: const EdgeInsets.fromLTRB(0, 32, 0, 8),
+              child: Column(
+                mainAxisSize: MainAxisSize.max,
+                children: [
+                  Expanded(
+                    child: Column(
+                      children: [
+                        BottomOfTitleBarWidget(
+                          title: TitleBarTitleWidget(
+                            title: _actionName(context, widget.actionType),
+                          ),
+                          // caption: 'Select or create a ',
+                        ),
+                        Padding(
+                          padding: const EdgeInsets.only(
+                            top: 16,
+                            left: 16,
+                            right: 16,
+                          ),
+                          child: TextInputWidget(
+                            hintText: 'Person name',
+                            prefixIcon: Icons.search_rounded,
+                            onChange: (value) {
+                              setState(() {
+                                _searchQuery = value;
+                              });
+                            },
+                            isClearable: true,
+                            shouldUnfocusOnClearOrSubmit: true,
+                            borderRadius: 2,
+                          ),
+                        ),
+                        _getPersonItems(),
+                      ],
+                    ),
+                  ),
+                  SafeArea(
+                    child: Container(
+                      //inner stroke of 1pt + 15 pts of top padding = 16 pts
+                      padding: const EdgeInsets.fromLTRB(16, 15, 16, 8),
+                      decoration: BoxDecoration(
+                        border: Border(
+                          top: BorderSide(
+                            color: getEnteColorScheme(context).strokeFaint,
+                          ),
+                        ),
+                      ),
+                      child: ButtonWidget(
+                        buttonType: ButtonType.secondary,
+                        buttonAction: ButtonAction.cancel,
+                        isInAlert: true,
+                        labelText: S.of(context).cancel,
+                      ),
+                    ),
+                  ),
+                ],
+              ),
+            ),
+          ),
+        ],
+      ),
+    );
+  }
+
+  Flexible _getPersonItems() {
+    return Flexible(
+      child: Padding(
+        padding: const EdgeInsets.fromLTRB(16, 24, 4, 0),
+        child: FutureBuilder<List<Person>>(
+          future: _getPersons(),
+          builder: (context, snapshot) {
+            if (snapshot.hasError) {
+              log("Error: ${snapshot.error} ${snapshot.stackTrace}");
+              // Need to show an error on the UI here
+              return const SizedBox.shrink();
+            } else if (snapshot.hasData) {
+              final persons = snapshot.data as List<Person>;
+              final searchResults = _searchQuery.isNotEmpty
+                  ? persons
+                      .where(
+                        (element) => element.attr.name
+                            .toLowerCase()
+                            .contains(_searchQuery.toLowerCase()),
+                      )
+                      .toList()
+                  : persons;
+              final shouldShowCreateAlbum = widget.showOptionToCreateNewAlbum &&
+                  (_searchQuery.isEmpty || searchResults.isEmpty);
+
+              return Scrollbar(
+                thumbVisibility: true,
+                radius: const Radius.circular(2),
+                child: Padding(
+                  padding: const EdgeInsets.only(right: 12),
+                  child: ListView.separated(
+                    itemCount:
+                        searchResults.length + (shouldShowCreateAlbum ? 1 : 0),
+                    itemBuilder: (context, index) {
+                      if (index == 0 && shouldShowCreateAlbum) {
+                        return GestureDetector(
+                          child: const NewPersonItemWidget(),
+                          onTap: () async => {
+                            addNewPerson(
+                              context,
+                              initValue: _searchQuery.trim(),
+                              clusterID: widget.cluserID,
+                            ),
+                          },
+                        );
+                      }
+                      final person = searchResults[
+                          index - (shouldShowCreateAlbum ? 1 : 0)];
1 : 0)]; + return PersonRowItem( + person: person, + onTap: () async { + await FaceMLDataDB.instance.assignClusterToPerson( + personID: person.remoteID, + clusterID: widget.cluserID, + ); + Bus.instance.fire(PeopleChangedEvent()); + + Navigator.pop(context, person); + }, + ); + }, + separatorBuilder: (context, index) { + return const SizedBox(height: 2); + }, + ), + ), + ); + } else { + return const EnteLoadingWidget(); + } + }, + ), + ), + ); + } + + Future addNewPerson( + BuildContext context, { + String initValue = '', + required int clusterID, + }) async { + final result = await showTextInputDialog( + context, + title: "New person", + submitButtonLabel: 'Add', + hintText: 'Add name', + alwaysShowSuccessState: false, + initialValue: initValue, + textCapitalization: TextCapitalization.words, + onSubmit: (String text) async { + // indicates user cancelled the rename request + if (text.trim() == "") { + return; + } + try { + final String id = const Uuid().v4().toString(); + final Person p = Person( + id, + PersonAttr(name: text, faces: []), + ); + await FaceMLDataDB.instance.insert(p, clusterID); + final bool extraPhotosFound = + await ClusterFeedbackService.instance.checkAndDoAutomaticMerges(p); + if (extraPhotosFound) { + showShortToast(context, "Extra photos found for $text"); + } + Bus.instance.fire(PeopleChangedEvent()); + Navigator.pop(context, p); + log("inserted person"); + } catch (e, s) { + Logger("_PersonActionSheetState") + .severe("Failed to rename album", e, s); + rethrow; + } + }, + ); + if (result is Exception) { + await showGenericErrorDialog(context: context, error: result); + } + } + + Future> _getPersons() async { + return FaceMLDataDB.instance.getPeople(); + } +} diff --git a/mobile/lib/ui/viewer/people/cluster_page.dart b/mobile/lib/ui/viewer/people/cluster_page.dart new file mode 100644 index 000000000..4fdb2e977 --- /dev/null +++ b/mobile/lib/ui/viewer/people/cluster_page.dart @@ -0,0 +1,140 @@ +import "dart:async"; + +import 'package:flutter/material.dart'; +import 'package:photos/core/event_bus.dart'; +import 'package:photos/events/files_updated_event.dart'; +import 'package:photos/events/local_photos_updated_event.dart'; +import "package:photos/face/model/person.dart"; +import 'package:photos/models/file/file.dart'; +import 'package:photos/models/file_load_result.dart'; +import 'package:photos/models/gallery_type.dart'; +import 'package:photos/models/selected_files.dart'; +import 'package:photos/ui/viewer/actions/file_selection_overlay_bar.dart'; +import 'package:photos/ui/viewer/gallery/gallery.dart'; +import 'package:photos/ui/viewer/gallery/gallery_app_bar_widget.dart'; +import "package:photos/ui/viewer/people/add_person_action_sheet.dart"; +import "package:photos/ui/viewer/people/people_page.dart"; +import "package:photos/ui/viewer/search/result/search_result_page.dart"; +import "package:photos/utils/navigation_util.dart"; +import "package:photos/utils/toast_util.dart"; + +class ClusterPage extends StatefulWidget { + final List searchResult; + final bool enableGrouping; + final String tagPrefix; + final int cluserID; + final Person? personID; + + static const GalleryType appBarType = GalleryType.cluster; + static const GalleryType overlayType = GalleryType.cluster; + + const ClusterPage( + this.searchResult, { + this.enableGrouping = true, + this.tagPrefix = "", + required this.cluserID, + this.personID, + Key? 
key,
+  }) : super(key: key);
+
+  @override
+  State<ClusterPage> createState() => _ClusterPageState();
+}
+
+class _ClusterPageState extends State<ClusterPage> {
+  final _selectedFiles = SelectedFiles();
+  late final List<EnteFile> files;
+  late final StreamSubscription<LocalPhotosUpdatedEvent> _filesUpdatedEvent;
+
+  @override
+  void initState() {
+    super.initState();
+    files = widget.searchResult;
+    _filesUpdatedEvent =
+        Bus.instance.on<LocalPhotosUpdatedEvent>().listen((event) {
+      if (event.type == EventType.deletedFromDevice ||
+          event.type == EventType.deletedFromEverywhere ||
+          event.type == EventType.deletedFromRemote ||
+          event.type == EventType.hide) {
+        for (var updatedFile in event.updatedFiles) {
+          files.remove(updatedFile);
+        }
+        setState(() {});
+      }
+    });
+  }
+
+  @override
+  void dispose() {
+    _filesUpdatedEvent.cancel();
+    super.dispose();
+  }
+
+  @override
+  Widget build(BuildContext context) {
+    final gallery = Gallery(
+      asyncLoader: (creationStartTime, creationEndTime, {limit, asc}) {
+        final result = files
+            .where(
+              (file) =>
+                  file.creationTime! >= creationStartTime &&
+                  file.creationTime! <= creationEndTime,
+            )
+            .toList();
+        return Future.value(
+          FileLoadResult(
+            result,
+            result.length < files.length,
+          ),
+        );
+      },
+      reloadEvent: Bus.instance.on<LocalPhotosUpdatedEvent>(),
+      removalEventTypes: const {
+        EventType.deletedFromRemote,
+        EventType.deletedFromEverywhere,
+        EventType.hide,
+      },
+      tagPrefix: widget.tagPrefix + widget.tagPrefix,
+      selectedFiles: _selectedFiles,
+      enableFileGrouping: widget.enableGrouping,
+      initialFiles: [widget.searchResult.first],
+    );
+    return Scaffold(
+      appBar: PreferredSize(
+        preferredSize: const Size.fromHeight(50.0),
+        child: GestureDetector(
+          onTap: () async {
+            if (widget.personID == null) {
+              final result = await showAssignPersonAction(
+                context,
+                clusterID: widget.cluserID,
+              );
+              if (result != null && result is Person) {
+                Navigator.pop(context);
+                // ignore: unawaited_futures
+                routeToPage(context, PeoplePage(person: result));
+              }
+            } else {
+              showShortToast(context, "No personID or clusterID");
+            }
+          },
+          child: GalleryAppBarWidget(
+            SearchResultPage.appBarType,
+            widget.personID != null ?
widget.personID!.attr.name : "Add name", + _selectedFiles, + ), + ), + ), + body: Stack( + alignment: Alignment.bottomCenter, + children: [ + gallery, + FileSelectionOverlayBar( + ClusterPage.overlayType, + _selectedFiles, + ), + ], + ), + ); + } +} diff --git a/mobile/lib/ui/viewer/people/new_person_item_widget.dart b/mobile/lib/ui/viewer/people/new_person_item_widget.dart new file mode 100644 index 000000000..83d5e157d --- /dev/null +++ b/mobile/lib/ui/viewer/people/new_person_item_widget.dart @@ -0,0 +1,73 @@ +import 'package:dotted_border/dotted_border.dart'; +import 'package:flutter/material.dart'; +import 'package:photos/theme/ente_theme.dart'; + +///https://www.figma.com/file/SYtMyLBs5SAOkTbfMMzhqt/ente-Visual-Design?node-id=10854%3A57947&t=H5AvR79OYDnB9ekw-4 +class NewPersonItemWidget extends StatelessWidget { + const NewPersonItemWidget({ + super.key, + }); + + @override + Widget build(BuildContext context) { + final textTheme = getEnteTextTheme(context); + final colorScheme = getEnteColorScheme(context); + const sideOfThumbnail = 60.0; + return LayoutBuilder( + builder: (context, constraints) { + return Stack( + alignment: Alignment.center, + children: [ + Row( + children: [ + ClipRRect( + borderRadius: const BorderRadius.horizontal( + left: Radius.circular(4), + ), + child: SizedBox( + height: sideOfThumbnail, + width: sideOfThumbnail, + child: Icon( + Icons.add_outlined, + color: colorScheme.strokeMuted, + ), + ), + ), + Padding( + padding: const EdgeInsets.only(left: 12), + child: Text( + 'Add person', + style: + textTheme.body.copyWith(color: colorScheme.textMuted), + ), + ), + ], + ), + IgnorePointer( + child: DottedBorder( + dashPattern: const [4], + color: colorScheme.strokeFainter, + strokeWidth: 1, + padding: const EdgeInsets.all(0), + borderType: BorderType.RRect, + radius: const Radius.circular(4), + child: SizedBox( + //Have to decrease the height and width by 1 pt as the stroke + //dotted border gives is of strokeAlign.center, so 0.5 inside and + // outside. Here for the row, stroke should be inside so we + //decrease the size of this sizedBox by 1 (so it shrinks 0.5 from + //every side) so that the strokeAlign.center of this sizedBox + //looks like a strokeAlign.inside in the row. + height: sideOfThumbnail - 1, + //This width will work for this only if the row widget takes up the + //full size it's parent (stack). 
+ width: constraints.maxWidth - 1, + ), + ), + ), + ], + ); + }, + ); + } +} diff --git a/mobile/lib/ui/viewer/people/people_app_bar.dart b/mobile/lib/ui/viewer/people/people_app_bar.dart new file mode 100644 index 000000000..fe5af20bf --- /dev/null +++ b/mobile/lib/ui/viewer/people/people_app_bar.dart @@ -0,0 +1,256 @@ +import 'dart:async'; + +import "package:flutter/cupertino.dart"; +import 'package:flutter/material.dart'; +import 'package:logging/logging.dart'; +import 'package:photos/core/configuration.dart'; +import 'package:photos/core/event_bus.dart'; +import "package:photos/events/people_changed_event.dart"; +import 'package:photos/events/subscription_purchased_event.dart'; +import "package:photos/face/db.dart"; +import "package:photos/face/model/person.dart"; +import "package:photos/generated/l10n.dart"; +import 'package:photos/models/gallery_type.dart'; +import 'package:photos/models/selected_files.dart'; +import 'package:photos/services/collections_service.dart'; +import 'package:photos/ui/actions/collection/collection_sharing_actions.dart'; +import "package:photos/ui/viewer/people/person_cluserts.dart"; +import "package:photos/ui/viewer/people/person_cluster_suggestion.dart"; +import "package:photos/utils/dialog_util.dart"; + +class PeopleAppBar extends StatefulWidget { + final GalleryType type; + final String? title; + final SelectedFiles selectedFiles; + final Person person; + + const PeopleAppBar( + this.type, + this.title, + this.selectedFiles, + this.person, { + Key? key, + }) : super(key: key); + + @override + State createState() => _AppBarWidgetState(); +} + +enum PeoplPopupAction { + rename, + setCover, + viewPhotos, + confirmPhotos, + hide, +} + +class _AppBarWidgetState extends State { + final _logger = Logger("_AppBarWidgetState"); + late StreamSubscription _userAuthEventSubscription; + late Function() _selectedFilesListener; + String? _appBarTitle; + late CollectionActions collectionActions; + final GlobalKey shareButtonKey = GlobalKey(); + bool isQuickLink = false; + late GalleryType galleryType; + + @override + void initState() { + super.initState(); + _selectedFilesListener = () { + setState(() {}); + }; + collectionActions = CollectionActions(CollectionsService.instance); + widget.selectedFiles.addListener(_selectedFilesListener); + _userAuthEventSubscription = + Bus.instance.on().listen((event) { + setState(() {}); + }); + _appBarTitle = widget.title; + galleryType = widget.type; + } + + @override + void dispose() { + _userAuthEventSubscription.cancel(); + widget.selectedFiles.removeListener(_selectedFilesListener); + super.dispose(); + } + + @override + Widget build(BuildContext context) { + return AppBar( + elevation: 0, + centerTitle: false, + title: Text( + _appBarTitle!, + style: + Theme.of(context).textTheme.headlineSmall!.copyWith(fontSize: 16), + maxLines: 2, + overflow: TextOverflow.ellipsis, + ), + actions: _getDefaultActions(context), + ); + } + + Future _renameAlbum(BuildContext context) async { + final result = await showTextInputDialog( + context, + title: 'Rename', + submitButtonLabel: S.of(context).done, + hintText: S.of(context).enterAlbumName, + alwaysShowSuccessState: true, + initialValue: widget.person.attr.name, + textCapitalization: TextCapitalization.words, + onSubmit: (String text) async { + // indicates user cancelled the rename request + if (text == "" || text == _appBarTitle!) 
{ + return; + } + + try { + final updatePerson = widget.person + .copyWith(attr: widget.person.attr.copyWith(name: text)); + await FaceMLDataDB.instance.updatePerson(updatePerson); + if (mounted) { + _appBarTitle = text; + setState(() {}); + } + Bus.instance.fire(PeopleChangedEvent()); + } catch (e, s) { + _logger.severe("Failed to rename album", e, s); + rethrow; + } + }, + ); + if (result is Exception) { + await showGenericErrorDialog(context: context, error: result); + } + } + + List _getDefaultActions(BuildContext context) { + final List actions = []; + // If the user has selected files, don't show any actions + if (widget.selectedFiles.files.isNotEmpty || + !Configuration.instance.hasConfiguredAccount()) { + return actions; + } + + final List> items = []; + + items.addAll( + [ + PopupMenuItem( + value: PeoplPopupAction.rename, + child: Row( + children: [ + const Icon(Icons.edit), + const Padding( + padding: EdgeInsets.all(8), + ), + Text(S.of(context).rename), + ], + ), + ), + // PopupMenuItem( + // value: PeoplPopupAction.setCover, + // child: Row( + // children: [ + // const Icon(Icons.image_outlined), + // const Padding( + // padding: EdgeInsets.all(8), + // ), + // Text(S.of(context).setCover), + // ], + // ), + // ), + // PopupMenuItem( + // value: PeoplPopupAction.rename, + // child: Row( + // children: [ + // const Icon(Icons.visibility_off), + // const Padding( + // padding: EdgeInsets.all(8), + // ), + // Text(S.of(context).hide), + // ], + // ), + // ), + const PopupMenuItem( + value: PeoplPopupAction.viewPhotos, + child: Row( + children: [ + Icon(Icons.view_array_outlined), + Padding( + padding: EdgeInsets.all(8), + ), + Text('View confirmed photos'), + ], + ), + ), + const PopupMenuItem( + value: PeoplPopupAction.confirmPhotos, + child: Row( + children: [ + Icon(CupertinoIcons.square_stack_3d_down_right), + Padding( + padding: EdgeInsets.all(8), + ), + Text('Review suggestions'), + ], + ), + ), + ], + ); + + if (items.isNotEmpty) { + actions.add( + PopupMenuButton( + itemBuilder: (context) { + return items; + }, + onSelected: (PeoplPopupAction value) async { + if (value == PeoplPopupAction.viewPhotos) { + // ignore: unawaited_futures + unawaited( + Navigator.of(context).push( + MaterialPageRoute( + builder: (context) => PersonClusters(widget.person), + ), + ), + ); + } else if (value == PeoplPopupAction.confirmPhotos) { + // ignore: unawaited_futures + unawaited( + Navigator.of(context).push( + MaterialPageRoute( + builder: (context) => + PersonReviewClusterSuggestion(widget.person), + ), + ), + ); + } else if (value == PeoplPopupAction.rename) { + await _renameAlbum(context); + } else if (value == PeoplPopupAction.setCover) { + await setCoverPhoto(context); + } else if (value == PeoplPopupAction.hide) { + // ignore: unawaited_futures + } + }, + ), + ); + } + + return actions; + } + + Future setCoverPhoto(BuildContext context) async { + // final int? 
coverPhotoID = await showPickCoverPhotoSheet( + // context, + // widget.collection!, + // ); + // if (coverPhotoID != null) { + // unawaited(changeCoverPhoto(context, widget.collection!, coverPhotoID)); + // } + } +} diff --git a/mobile/lib/ui/viewer/people/people_page.dart b/mobile/lib/ui/viewer/people/people_page.dart new file mode 100644 index 000000000..8ae365368 --- /dev/null +++ b/mobile/lib/ui/viewer/people/people_page.dart @@ -0,0 +1,155 @@ +import "dart:async"; +import "dart:developer"; + +import 'package:flutter/material.dart'; +import "package:logging/logging.dart"; +import 'package:photos/core/event_bus.dart'; +import 'package:photos/events/files_updated_event.dart'; +import 'package:photos/events/local_photos_updated_event.dart'; +import "package:photos/events/people_changed_event.dart"; +import "package:photos/face/model/person.dart"; +import 'package:photos/models/file/file.dart'; +import 'package:photos/models/file_load_result.dart'; +import 'package:photos/models/gallery_type.dart'; +import 'package:photos/models/selected_files.dart'; +import "package:photos/services/search_service.dart"; +import 'package:photos/ui/viewer/actions/file_selection_overlay_bar.dart'; +import 'package:photos/ui/viewer/gallery/gallery.dart'; +import "package:photos/ui/viewer/people/people_app_bar.dart"; + +class PeoplePage extends StatefulWidget { + final String tagPrefix; + final Person person; + + static const GalleryType appBarType = GalleryType.peopleTag; + static const GalleryType overlayType = GalleryType.peopleTag; + + const PeoplePage({ + this.tagPrefix = "", + required this.person, + Key? key, + }) : super(key: key); + + @override + State createState() => _PeoplePageState(); +} + +class _PeoplePageState extends State { + final Logger _logger = Logger("_PeoplePageState"); + final _selectedFiles = SelectedFiles(); + List? 
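+  // Populated by loadPersonFiles() on each gallery load; the
+  // FilesUpdatedEvent listener below prunes deleted or hidden entries.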
files; + + late final StreamSubscription _filesUpdatedEvent; + late final StreamSubscription _peopleChangedEvent; + + @override + void initState() { + super.initState(); + _peopleChangedEvent = Bus.instance.on().listen((event) { + setState(() {}); + }); + + _filesUpdatedEvent = + Bus.instance.on().listen((event) { + if (event.type == EventType.deletedFromDevice || + event.type == EventType.deletedFromEverywhere || + event.type == EventType.deletedFromRemote || + event.type == EventType.hide) { + for (var updatedFile in event.updatedFiles) { + files?.remove(updatedFile); + } + setState(() {}); + } + }); + } + + Future> loadPersonFiles() async { + log("loadPersonFiles"); + final result = await SearchService.instance + .getClusterFilesForPersonID(widget.person.remoteID); + final List resultFiles = []; + for (final e in result.entries) { + resultFiles.addAll(e.value); + } + files = resultFiles; + return resultFiles; + } + + @override + void dispose() { + _filesUpdatedEvent.cancel(); + _peopleChangedEvent.cancel(); + super.dispose(); + } + + @override + Widget build(BuildContext context) { + _logger.info("Building for ${widget.person.attr.name}"); + return Scaffold( + appBar: PreferredSize( + preferredSize: const Size.fromHeight(50.0), + child: PeopleAppBar( + GalleryType.peopleTag, + widget.person.attr.name, + _selectedFiles, + widget.person, + ), + ), + body: Stack( + alignment: Alignment.bottomCenter, + children: [ + FutureBuilder>( + future: loadPersonFiles(), + builder: (context, snapshot) { + if (snapshot.hasData) { + final personFiles = snapshot.data as List; + return Gallery( + asyncLoader: ( + creationStartTime, + creationEndTime, { + limit, + asc, + }) async { + final result = await loadPersonFiles(); + return Future.value( + FileLoadResult( + result, + false, + ), + ); + }, + reloadEvent: Bus.instance.on(), + forceReloadEvents: [ + Bus.instance.on(), + ], + removalEventTypes: const { + EventType.deletedFromRemote, + EventType.deletedFromEverywhere, + EventType.hide, + }, + tagPrefix: widget.tagPrefix + widget.tagPrefix, + selectedFiles: _selectedFiles, + initialFiles: + personFiles.isNotEmpty ? 
[personFiles.first] : [], + ); + } else if (snapshot.hasError) { + log("Error: ${snapshot.error} ${snapshot.stackTrace}}"); + //Need to show an error on the UI here + return const SizedBox.shrink(); + } else { + return const Center( + child: CircularProgressIndicator(), + ); + } + }, + ), + FileSelectionOverlayBar( + PeoplePage.overlayType, + _selectedFiles, + person: widget.person, + ), + ], + ), + ); + } +} diff --git a/mobile/lib/ui/viewer/people/person_cluserts.dart b/mobile/lib/ui/viewer/people/person_cluserts.dart new file mode 100644 index 000000000..d9cf15706 --- /dev/null +++ b/mobile/lib/ui/viewer/people/person_cluserts.dart @@ -0,0 +1,139 @@ +import "package:flutter/cupertino.dart"; +import "package:flutter/material.dart"; +import "package:logging/logging.dart"; +import "package:photos/core/event_bus.dart"; +import "package:photos/events/people_changed_event.dart"; +import "package:photos/face/db.dart"; +import "package:photos/face/model/person.dart"; +import "package:photos/models/file/file.dart"; +import "package:photos/services/search_service.dart"; +import "package:photos/theme/ente_theme.dart"; +import "package:photos/ui/viewer/file/no_thumbnail_widget.dart"; +// import "package:photos/ui/viewer/file/thumbnail_widget.dart"; +import "package:photos/ui/viewer/people/cluster_page.dart"; +import "package:photos/ui/viewer/search/result/person_face_widget.dart"; + +class PersonClusters extends StatefulWidget { + final Person person; + + const PersonClusters( + this.person, { + super.key, + }); + + @override + State createState() => _PersonClustersState(); +} + +class _PersonClustersState extends State { + final Logger _logger = Logger("_PersonClustersState"); + @override + Widget build(BuildContext context) { + return Scaffold( + appBar: AppBar( + title: Text(widget.person.attr.name), + ), + body: FutureBuilder>>( + future: SearchService.instance + .getClusterFilesForPersonID(widget.person.remoteID), + builder: (context, snapshot) { + if (snapshot.hasData) { + final List keys = snapshot.data!.keys.toList(); + return ListView.builder( + itemCount: keys.length, + itemBuilder: (context, index) { + final int clusterID = keys[index]; + final List files = snapshot.data![keys[index]]!; + return InkWell( + onTap: () { + Navigator.of(context).push( + MaterialPageRoute( + builder: (context) => ClusterPage( + files, + personID: widget.person, + cluserID: index, + ), + ), + ); + }, + child: Container( + padding: const EdgeInsets.all(8.0), + child: Row( + children: [ + SizedBox( + width: 64, + height: 64, + child: files.isNotEmpty + ? 
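+                        // Cover face crop for the cluster when it has files;
+                        // otherwise fall back to a generic placeholder.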
ClipOval( + child: PersonFaceWidget( + files.first, + clusterID: clusterID, + ), + ) + : const ClipOval( + child: NoThumbnailWidget( + addBorder: false, + ), + ), + ), + const SizedBox( + width: 8.0, + ), // Add some spacing between the thumbnail and the text + Expanded( + child: Padding( + padding: + const EdgeInsets.symmetric(horizontal: 8.0), + child: Row( + mainAxisAlignment: MainAxisAlignment.spaceBetween, + children: [ + Text( + "${snapshot.data![keys[index]]!.length} photos", + style: getEnteTextTheme(context).body, + ), + GestureDetector( + onTap: () async { + try { + final int result = await FaceMLDataDB + .instance + .removeClusterToPerson( + personID: widget.person.remoteID, + clusterID: clusterID, + ); + _logger.info( + "Removed cluster $clusterID from person ${widget.person.remoteID}, result: $result", + ); + Bus.instance.fire(PeopleChangedEvent()); + setState(() {}); + } catch (e) { + _logger.severe( + "removing cluster from person,", + e, + ); + } + }, + child: const Icon( + CupertinoIcons.minus_circled, + color: Colors.red, + ), + ), + ], + ), + ), + ), + ], + ), + ), + ); + }, + ); + } else if (snapshot.hasError) { + _logger.warning("Failed to get cluster", snapshot.error); + return const Center(child: Text("Error")); + } else { + return const Center(child: CircularProgressIndicator()); + } + }, + ), + ); + } +} diff --git a/mobile/lib/ui/viewer/people/person_cluster_suggestion.dart b/mobile/lib/ui/viewer/people/person_cluster_suggestion.dart new file mode 100644 index 000000000..3ec179856 --- /dev/null +++ b/mobile/lib/ui/viewer/people/person_cluster_suggestion.dart @@ -0,0 +1,244 @@ +import "dart:math"; + +import "package:flutter/material.dart"; +import "package:photos/core/event_bus.dart"; +import "package:photos/events/people_changed_event.dart"; +import "package:photos/face/db.dart"; +import "package:photos/face/model/person.dart"; +import "package:photos/models/file/file.dart"; +import "package:photos/services/face_ml/feedback/cluster_feedback.dart"; +import "package:photos/theme/ente_theme.dart"; +import "package:photos/ui/components/buttons/button_widget.dart"; +import "package:photos/ui/components/models/button_type.dart"; +import "package:photos/ui/viewer/people/cluster_page.dart"; +import "package:photos/ui/viewer/search/result/person_face_widget.dart"; + +class PersonReviewClusterSuggestion extends StatefulWidget { + final Person person; + + const PersonReviewClusterSuggestion( + this.person, { + super.key, + }); + + @override + State createState() => _PersonClustersState(); +} + +class _PersonClustersState extends State { + int currentSuggestionIndex = 0; + Key futureBuilderKey = UniqueKey(); + + // Declare a variable for the future + late Future)>> futureClusterSuggestions; + + @override + void initState() { + super.initState(); + // Initialize the future in initState + _fetchClusterSuggestions(); + // futureClusterSuggestions = ClusterFeedbackService.instance + // .getClusterFilesForPersonID(widget.person); + } + + @override + Widget build(BuildContext context) { + return Scaffold( + appBar: AppBar( + title: const Text('Review suggestions'), + ), + body: FutureBuilder)>>( + key: futureBuilderKey, + future: futureClusterSuggestions, + builder: (context, snapshot) { + if (snapshot.hasData) { + // final List keys = snapshot.data!.map((e) => e.$1).toList(); + if (snapshot.data!.isEmpty) { + return Center( + child: Text( + "No suggestions for ${widget.person.attr.name}", + style: getEnteTextTheme(context).largeMuted, + ), + ); + } + final 
numberOfDifferentSuggestions = snapshot.data!.length; + final currentSuggestion = snapshot.data![currentSuggestionIndex]; + final int clusterID = currentSuggestion.$1; + final List files = currentSuggestion.$2; + return InkWell( + onTap: () { + Navigator.of(context).push( + MaterialPageRoute( + builder: (context) => ClusterPage( + files, + personID: widget.person, + cluserID: clusterID, + ), + ), + ); + }, + child: Container( + padding: const EdgeInsets.symmetric( + horizontal: 8.0, + vertical: 20, + ), + child: _buildSuggestionView( + clusterID, + files, + numberOfDifferentSuggestions, + ), + ), + ); + } else if (snapshot.hasError) { + // log the error + return const Center(child: Text("Error")); + } else { + return const Center(child: CircularProgressIndicator()); + } + }, + ), + ); + } + + Future _handleUserClusterChoice( + int clusterID, + bool yesOrNo, + int numberOfSuggestions, + ) async { + // Perform the action based on clusterID, e.g., assignClusterToPerson or captureNotPersonFeedback + if (yesOrNo) { + await FaceMLDataDB.instance.assignClusterToPerson( + personID: widget.person.remoteID, + clusterID: clusterID, + ); + Bus.instance.fire(PeopleChangedEvent()); + } else { + await FaceMLDataDB.instance.captureNotPersonFeedback( + personID: widget.person.remoteID, + clusterID: clusterID, + ); + } + + // Increment the suggestion index + if (mounted) { + setState(() => currentSuggestionIndex++); + } + + // Check if we need to fetch new data + if (currentSuggestionIndex >= (numberOfSuggestions)) { + setState(() { + currentSuggestionIndex = 0; + futureBuilderKey = UniqueKey(); // Reset to trigger FutureBuilder + _fetchClusterSuggestions(); + }); + } + } + + // Method to fetch cluster suggestions + void _fetchClusterSuggestions() { + futureClusterSuggestions = ClusterFeedbackService.instance + .getClusterFilesForPersonID(widget.person); + } + + Widget _buildSuggestionView( + int clusterID, + List files, + int numberOfSuggestions, + ) { + return Column( + key: ValueKey("cluster_id-$clusterID"), + children: [ + Text( + files.length > 1 + ? "These photos belong to ${widget.person.attr.name}?" 
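+              // Singular copy for single-photo suggestions.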
+ : "This photo belongs to ${widget.person.attr.name}?", + style: getEnteTextTheme(context).largeMuted, + ), + const SizedBox(height: 24), + Row( + mainAxisAlignment: MainAxisAlignment.center, + children: _buildThumbnailWidgets( + files, + clusterID, + ), + ), + if (files.length > 4) const SizedBox(height: 24), + if (files.length > 4) + Row( + mainAxisAlignment: MainAxisAlignment.center, + children: _buildThumbnailWidgets( + files, + clusterID, + start: 4, + ), + ), + const SizedBox( + height: 24.0, + ), + Text( + "${files.length} photos", + style: getEnteTextTheme(context).body, + ), + const SizedBox( + height: 24.0, + ), // Add some spacing between the thumbnail and the text + Padding( + padding: const EdgeInsets.symmetric(horizontal: 24.0), + child: Column( + mainAxisAlignment: MainAxisAlignment.spaceBetween, + children: [ + ButtonWidget( + buttonType: ButtonType.primary, + labelText: 'Yes, confirm', + buttonSize: ButtonSize.large, + onTap: () async => { + await _handleUserClusterChoice( + clusterID, + true, + numberOfSuggestions, + ), + }, + ), + const SizedBox(height: 12.0), // Add some spacing + ButtonWidget( + buttonType: ButtonType.critical, + labelText: 'No', + buttonSize: ButtonSize.large, + onTap: () async => { + await _handleUserClusterChoice( + clusterID, + false, + numberOfSuggestions, + ), + }, + ), + ], + ), + ), + ], + ); + } + + List _buildThumbnailWidgets( + List files, + int cluserId, { + int start = 0, + }) { + return List.generate( + min(4, max(0, files.length - start)), + (index) => Padding( + padding: const EdgeInsets.all(8.0), + child: SizedBox( + width: 72, + height: 72, + child: ClipOval( + child: PersonFaceWidget( + files[start + index], + clusterID: cluserId, + ), + ), + ), + ), + ); + } +} diff --git a/mobile/lib/ui/viewer/people/person_row_item.dart b/mobile/lib/ui/viewer/people/person_row_item.dart new file mode 100644 index 000000000..74fa3c875 --- /dev/null +++ b/mobile/lib/ui/viewer/people/person_row_item.dart @@ -0,0 +1,24 @@ +import "package:flutter/material.dart"; +import "package:photos/face/model/person.dart"; + +class PersonRowItem extends StatelessWidget { + final Person person; + final VoidCallback onTap; + + const PersonRowItem({ + Key? 
key, + required this.person, + required this.onTap, + }) : super(key: key); + + @override + Widget build(BuildContext context) { + return ListTile( + leading: CircleAvatar( + child: Text(person.attr.name.substring(0, 1)), + ), + title: Text(person.attr.name), + onTap: onTap, + ); + } +} diff --git a/mobile/lib/ui/viewer/search/result/no_result_widget.dart b/mobile/lib/ui/viewer/search/result/no_result_widget.dart index 9ebb9cf80..48ba811df 100644 --- a/mobile/lib/ui/viewer/search/result/no_result_widget.dart +++ b/mobile/lib/ui/viewer/search/result/no_result_widget.dart @@ -21,7 +21,6 @@ class _NoResultWidgetState extends State { super.initState(); searchTypes = SectionType.values.toList(growable: true); // remove face and content sectionType - searchTypes.remove(SectionType.face); searchTypes.remove(SectionType.content); } diff --git a/mobile/lib/ui/viewer/search/result/person_face_widget.dart b/mobile/lib/ui/viewer/search/result/person_face_widget.dart new file mode 100644 index 000000000..09b457b98 --- /dev/null +++ b/mobile/lib/ui/viewer/search/result/person_face_widget.dart @@ -0,0 +1,108 @@ +import "dart:developer"; +import "dart:typed_data"; + +import 'package:flutter/widgets.dart'; +import "package:photos/db/files_db.dart"; +import "package:photos/face/db.dart"; +import "package:photos/face/model/face.dart"; +import 'package:photos/models/file/file.dart'; +import 'package:photos/ui/viewer/file/thumbnail_widget.dart'; +import "package:photos/utils/face/face_box_crop.dart"; +import "package:photos/utils/thumbnail_util.dart"; + +class PersonFaceWidget extends StatelessWidget { + final EnteFile file; + final String? personId; + final int? clusterID; + + const PersonFaceWidget( + this.file, { + this.personId, + this.clusterID, + Key? key, + }) : super(key: key); + + @override + Widget build(BuildContext context) { + return FutureBuilder( + future: getFaceCrop(), + builder: (context, snapshot) { + if (snapshot.hasData) { + final ImageProvider imageProvider = MemoryImage(snapshot.data!); + return Stack( + fit: StackFit.expand, + children: [ + Image( + image: imageProvider, + fit: BoxFit.cover, + ), + ], + ); + } else { + if (snapshot.hasError) { + log('Error getting cover face for person: ${snapshot.error}'); + } + return ThumbnailWidget( + file, + ); + } + }, + ); + } + + Future getFaceCrop() async { + try { + final Face? face = await FaceMLDataDB.instance.getCoverFaceForPerson( + recentFileID: file.uploadedFileID!, + personID: personId, + clusterID: clusterID, + ); + if (face == null) { + debugPrint( + "No cover face for person: $personId and cluster $clusterID", + ); + return null; + } + final Uint8List? cachedFace = faceCropCache.get(face.faceID); + if (cachedFace != null) { + return cachedFace; + } + final faceCropCacheFile = cachedFaceCropPath(face.faceID); + if ((await faceCropCacheFile.exists())) { + final data = await faceCropCacheFile.readAsBytes(); + faceCropCache.put(face.faceID, data); + return data; + } + EnteFile? fileForFaceCrop = file; + if (face.fileID != file.uploadedFileID!) { + fileForFaceCrop = + await FilesDB.instance.getAnyUploadedFile(face.fileID!); + } + if (fileForFaceCrop == null) { + return null; + } + + final result = await pool.withResource( + () async => await getFaceCrops( + fileForFaceCrop!, + { + face.faceID: face.detection.box, + }, + ), + ); + final Uint8List? 
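+      // Persist freshly computed crops to the in-memory LRU and the on-disk
+      // cache below so that later lookups skip the isolate round trip.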
computedCrop = result?[face.faceID]; + if (computedCrop != null) { + faceCropCache.put(face.faceID, computedCrop); + faceCropCacheFile.writeAsBytes(computedCrop).ignore(); + } + return computedCrop; + } catch (e, s) { + log( + "Error getting cover face for person: $personId and cluster $clusterID", + error: e, + stackTrace: s, + ); + return null; + } + } +} diff --git a/mobile/lib/ui/viewer/search/result/search_result_widget.dart b/mobile/lib/ui/viewer/search/result/search_result_widget.dart index 5564af7c9..fbd77531a 100644 --- a/mobile/lib/ui/viewer/search/result/search_result_widget.dart +++ b/mobile/lib/ui/viewer/search/result/search_result_widget.dart @@ -13,12 +13,14 @@ class SearchResultWidget extends StatelessWidget { final SearchResult searchResult; final Future? resultCount; final Function? onResultTap; + final Map? params; const SearchResultWidget( this.searchResult, { Key? key, this.resultCount, this.onResultTap, + this.params, }) : super(key: key); @override @@ -42,6 +44,7 @@ class SearchResultWidget extends StatelessWidget { SearchThumbnailWidget( searchResult.previewThumbnail(), heroTagPrefix, + searchResult: searchResult, ), const SizedBox(width: 12), Padding( @@ -143,6 +146,8 @@ class SearchResultWidget extends StatelessWidget { return "Magic"; case ResultType.shared: return "Shared"; + case ResultType.faces: + return "Person"; default: return type.name.toUpperCase(); } diff --git a/mobile/lib/ui/viewer/search/result/search_section_all_page.dart b/mobile/lib/ui/viewer/search/result/search_section_all_page.dart index 59761009a..17dea1f84 100644 --- a/mobile/lib/ui/viewer/search/result/search_section_all_page.dart +++ b/mobile/lib/ui/viewer/search/result/search_section_all_page.dart @@ -1,5 +1,6 @@ import "dart:async"; +import "package:collection/collection.dart"; import "package:flutter/material.dart"; import "package:flutter_animate/flutter_animate.dart"; import "package:photos/events/event.dart"; @@ -109,7 +110,12 @@ class _SearchSectionAllPageState extends State { builder: (context, snapshot) { if (snapshot.hasData) { List sectionResults = snapshot.data!; - sectionResults.sort((a, b) => a.name().compareTo(b.name())); + if (widget.sectionType.sortByName) { + sectionResults.sort( + (a, b) => + compareAsciiLowerCaseNatural(b.name(), a.name()), + ); + } if (widget.sectionType == SectionType.location) { final result = sectionResults.splitMatch( (e) => e.type() == ResultType.location, diff --git a/mobile/lib/ui/viewer/search/result/search_thumbnail_widget.dart b/mobile/lib/ui/viewer/search/result/search_thumbnail_widget.dart index 13b303fec..514c65b99 100644 --- a/mobile/lib/ui/viewer/search/result/search_thumbnail_widget.dart +++ b/mobile/lib/ui/viewer/search/result/search_thumbnail_widget.dart @@ -1,15 +1,22 @@ import 'package:flutter/widgets.dart'; import 'package:photos/models/file/file.dart'; +import "package:photos/models/search/generic_search_result.dart"; +import "package:photos/models/search/search_constants.dart"; +import "package:photos/models/search/search_result.dart"; +import "package:photos/models/search/search_types.dart"; import 'package:photos/ui/viewer/file/no_thumbnail_widget.dart'; import 'package:photos/ui/viewer/file/thumbnail_widget.dart'; +import 'package:photos/ui/viewer/search/result/person_face_widget.dart'; class SearchThumbnailWidget extends StatelessWidget { final EnteFile? file; + final SearchResult? searchResult; final String tagPrefix; const SearchThumbnailWidget( this.file, this.tagPrefix, { + this.searchResult, Key? 
key, }) : super(key: key); @@ -23,9 +30,18 @@ class SearchThumbnailWidget extends StatelessWidget { child: ClipRRect( borderRadius: const BorderRadius.horizontal(left: Radius.circular(4)), child: file != null - ? ThumbnailWidget( - file!, - ) + ? (searchResult != null && + searchResult!.type() == ResultType.faces) + ? PersonFaceWidget( + file!, + personId: (searchResult as GenericSearchResult) + .params[kPersonParamID], + clusterID: (searchResult as GenericSearchResult) + .params[kClusterParamId], + ) + : ThumbnailWidget( + file!, + ) : const NoThumbnailWidget( addBorder: false, ), diff --git a/mobile/lib/ui/viewer/search/result/searchable_item.dart b/mobile/lib/ui/viewer/search/result/searchable_item.dart index 1124d925e..327c2e899 100644 --- a/mobile/lib/ui/viewer/search/result/searchable_item.dart +++ b/mobile/lib/ui/viewer/search/result/searchable_item.dart @@ -66,6 +66,7 @@ class SearchableItemWidget extends StatelessWidget { child: SearchThumbnailWidget( searchResult.previewThumbnail(), heroTagPrefix, + searchResult: searchResult, ), ), const SizedBox(width: 12), diff --git a/mobile/lib/ui/viewer/search/search_widget.dart b/mobile/lib/ui/viewer/search/search_widget.dart index 2d9132875..dcf87ad4c 100644 --- a/mobile/lib/ui/viewer/search/search_widget.dart +++ b/mobile/lib/ui/viewer/search/search_widget.dart @@ -2,9 +2,11 @@ import "dart:async"; import "package:flutter/material.dart"; import "package:flutter/scheduler.dart"; +import "package:logging/logging.dart"; import "package:photos/core/event_bus.dart"; import "package:photos/events/clear_and_unfocus_search_bar_event.dart"; import "package:photos/events/tab_changed_event.dart"; +import "package:photos/models/search/generic_search_result.dart"; import "package:photos/models/search/index_of_indexed_stack.dart"; import "package:photos/models/search/search_result.dart"; import "package:photos/services/search_service.dart"; @@ -40,6 +42,7 @@ class SearchWidgetState extends State { TextEditingController textController = TextEditingController(); late final StreamSubscription _clearAndUnfocusSearchBar; + late final Logger _logger = Logger("SearchWidgetState"); @override void initState() { @@ -200,7 +203,7 @@ class SearchWidgetState extends State { String query, ) { int resultCount = 0; - final maxResultCount = _isYearValid(query) ? 11 : 10; + final maxResultCount = _isYearValid(query) ? 13 : 12; final streamController = StreamController>(); if (query.isEmpty) { @@ -215,6 +218,11 @@ class SearchWidgetState extends State { if (resultCount == maxResultCount) { streamController.close(); } + if (resultCount > maxResultCount) { + _logger.warning( + "More results than expected. 
Expected: $maxResultCount, actual: $resultCount", + ); + } } if (_isYearValid(query)) { @@ -252,6 +260,17 @@ class SearchWidgetState extends State { onResultsReceived(locationResult); }, ); + _searchService.getAllFace(null).then( + (locationResult) { + final List filteredResults = []; + for (final result in locationResult) { + if (result.name().toLowerCase().contains(query.toLowerCase())) { + filteredResults.add(result); + } + } + onResultsReceived(filteredResults); + }, + ); _searchService.getCollectionSearchResults(query).then( (collectionResults) { diff --git a/mobile/lib/ui/viewer/search_tab/people_section.dart b/mobile/lib/ui/viewer/search_tab/people_section.dart new file mode 100644 index 000000000..da97b1aef --- /dev/null +++ b/mobile/lib/ui/viewer/search_tab/people_section.dart @@ -0,0 +1,290 @@ +import "dart:async"; + +import "package:collection/collection.dart"; +import "package:flutter/material.dart"; +import "package:photos/core/constants.dart"; +import "package:photos/events/event.dart"; +import "package:photos/models/search/album_search_result.dart"; +import "package:photos/models/search/generic_search_result.dart"; +import "package:photos/models/search/recent_searches.dart"; +import "package:photos/models/search/search_constants.dart"; +import "package:photos/models/search/search_result.dart"; +import "package:photos/models/search/search_types.dart"; +import "package:photos/theme/ente_theme.dart"; +import "package:photos/ui/viewer/file/no_thumbnail_widget.dart"; +import "package:photos/ui/viewer/file/thumbnail_widget.dart"; +import "package:photos/ui/viewer/gallery/collection_page.dart"; +import 'package:photos/ui/viewer/search/result/person_face_widget.dart'; +import "package:photos/ui/viewer/search/result/search_result_page.dart"; +import 'package:photos/ui/viewer/search/result/search_section_all_page.dart'; +import "package:photos/ui/viewer/search/search_section_cta.dart"; +import "package:photos/utils/navigation_util.dart"; + +class SearchSection extends StatefulWidget { + final SectionType sectionType; + final List examples; + final int limit; + + const SearchSection({ + Key? key, + required this.sectionType, + required this.examples, + required this.limit, + }) : super(key: key); + + @override + State createState() => _SearchSectionState(); +} + +class _SearchSectionState extends State { + late List _examples; + final streamSubscriptions = []; + + @override + void initState() { + super.initState(); + _examples = widget.examples; + + final streamsToListenTo = widget.sectionType.sectionUpdateEvents(); + for (Stream stream in streamsToListenTo) { + streamSubscriptions.add( + stream.listen((event) async { + _examples = await widget.sectionType.getData( + context, + limit: kSearchSectionLimit, + ); + setState(() {}); + }), + ); + } + } + + @override + void dispose() { + for (var subscriptions in streamSubscriptions) { + subscriptions.cancel(); + } + super.dispose(); + } + + @override + void didUpdateWidget(covariant SearchSection oldWidget) { + super.didUpdateWidget(oldWidget); + _examples = widget.examples; + } + + @override + Widget build(BuildContext context) { + debugPrint("Building section for ${widget.sectionType.name}"); + final shouldShowMore = _examples.length >= widget.limit - 1; + final textTheme = getEnteTextTheme(context); + return Padding( + padding: const EdgeInsets.symmetric(vertical: 8), + child: _examples.isNotEmpty + ? 
GestureDetector( + behavior: HitTestBehavior.opaque, + onTap: () { + if (shouldShowMore) { + routeToPage( + context, + SearchSectionAllPage( + sectionType: widget.sectionType, + ), + ); + } + }, + child: Column( + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + Row( + mainAxisAlignment: MainAxisAlignment.spaceBetween, + children: [ + Padding( + padding: const EdgeInsets.all(12), + child: Text( + widget.sectionType.sectionTitle(context), + style: textTheme.largeBold, + ), + ), + shouldShowMore + ? Padding( + padding: const EdgeInsets.all(12), + child: Icon( + Icons.chevron_right_outlined, + color: getEnteColorScheme(context).strokeMuted, + ), + ) + : const SizedBox.shrink(), + ], + ), + const SizedBox(height: 2), + SearchExampleRow(_examples, widget.sectionType), + ], + ), + ) + : Padding( + padding: const EdgeInsets.only(left: 16, right: 8), + child: Row( + children: [ + Expanded( + child: Padding( + padding: const EdgeInsets.symmetric(vertical: 12), + child: Column( + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + Text( + widget.sectionType.sectionTitle(context), + style: textTheme.largeBold, + ), + const SizedBox(height: 24), + Text( + widget.sectionType.getEmptyStateText(context), + style: textTheme.smallMuted, + ), + ], + ), + ), + ), + const SizedBox(width: 8), + SearchSectionEmptyCTAIcon(widget.sectionType), + ], + ), + ), + ); + } +} + +class SearchExampleRow extends StatelessWidget { + final SectionType sectionType; + final List examples; + + const SearchExampleRow(this.examples, this.sectionType, {super.key}); + + @override + Widget build(BuildContext context) { + //Cannot use listView.builder here + final scrollableExamples = []; + if (sectionType == SectionType.location) { + scrollableExamples.add(const GoToMapWidget()); + } + examples.forEachIndexed((index, element) { + scrollableExamples.add( + SearchExample( + searchResult: examples.elementAt(index), + ), + ); + }); + scrollableExamples.add(SearchSectionCTAIcon(sectionType)); + return SizedBox( + child: SingleChildScrollView( + physics: const BouncingScrollPhysics(), + scrollDirection: Axis.horizontal, + child: Row( + crossAxisAlignment: CrossAxisAlignment.start, + children: scrollableExamples, + ), + ), + ); + } +} + +class SearchExample extends StatelessWidget { + final SearchResult searchResult; + const SearchExample({required this.searchResult, super.key}); + + @override + Widget build(BuildContext context) { + final textScaleFactor = MediaQuery.textScaleFactorOf(context); + late final double width; + if (textScaleFactor <= 1.0) { + width = 85.0; + } else { + width = 85.0 + ((textScaleFactor - 1.0) * 64); + } + final heroTag = + searchResult.heroTag() + (searchResult.previewThumbnail()?.tag ?? 
""); + return GestureDetector( + onTap: () { + RecentSearches().add(searchResult.name()); + + if (searchResult is GenericSearchResult) { + final genericSearchResult = searchResult as GenericSearchResult; + if (genericSearchResult.onResultTap != null) { + genericSearchResult.onResultTap!(context); + } else { + routeToPage( + context, + SearchResultPage(searchResult), + ); + } + } else if (searchResult is AlbumSearchResult) { + final albumSearchResult = searchResult as AlbumSearchResult; + routeToPage( + context, + CollectionPage( + albumSearchResult.collectionWithThumbnail, + tagPrefix: albumSearchResult.heroTag(), + ), + ); + } + }, + child: SizedBox( + width: width, + child: Padding( + padding: const EdgeInsets.symmetric(horizontal: 6, vertical: 10), + child: Column( + mainAxisSize: MainAxisSize.min, + children: [ + SizedBox( + width: 64, + height: 64, + child: searchResult.previewThumbnail() != null + ? Hero( + tag: heroTag, + child: ClipOval( + child: searchResult.type() != ResultType.faces + ? ThumbnailWidget( + searchResult.previewThumbnail()!, + shouldShowSyncStatus: false, + ) + : FaceSearchResult(searchResult, heroTag), + ), + ) + : const ClipOval( + child: NoThumbnailWidget( + addBorder: false, + ), + ), + ), + const SizedBox( + height: 10, + ), + Text( + searchResult.name(), + maxLines: 2, + textAlign: TextAlign.center, + overflow: TextOverflow.ellipsis, + style: getEnteTextTheme(context).mini, + ), + ], + ), + ), + ), + ); + } +} + +class FaceSearchResult extends StatelessWidget { + final SearchResult searchResult; + final String heroTagPrefix; + const FaceSearchResult(this.searchResult, this.heroTagPrefix, {super.key}); + + @override + Widget build(BuildContext context) { + return PersonFaceWidget( + searchResult.previewThumbnail()!, + personId: (searchResult as GenericSearchResult).params[kPersonParamID], + clusterID: (searchResult as GenericSearchResult).params[kClusterParamId], + ); + } +} diff --git a/mobile/lib/ui/viewer/search_tab/search_tab.dart b/mobile/lib/ui/viewer/search_tab/search_tab.dart index bfb35600a..b20ef87ea 100644 --- a/mobile/lib/ui/viewer/search_tab/search_tab.dart +++ b/mobile/lib/ui/viewer/search_tab/search_tab.dart @@ -1,6 +1,8 @@ import "package:fade_indexed_stack/fade_indexed_stack.dart"; +import "package:flutter/foundation.dart"; import "package:flutter/material.dart"; import "package:flutter_animate/flutter_animate.dart"; +import "package:logging/logging.dart"; import "package:photos/models/search/album_search_result.dart"; import "package:photos/models/search/generic_search_result.dart"; import "package:photos/models/search/index_of_indexed_stack.dart"; @@ -73,11 +75,12 @@ class AllSearchSections extends StatefulWidget { } class _AllSearchSectionsState extends State { + final Logger _logger = Logger('_AllSearchSectionsState'); @override Widget build(BuildContext context) { final searchTypes = SectionType.values.toList(growable: true); // remove face and content sectionType - searchTypes.remove(SectionType.face); + // searchTypes.remove(SectionType.face); searchTypes.remove(SectionType.content); return Padding( padding: const EdgeInsets.only(top: 8), @@ -150,6 +153,13 @@ class _AllSearchSectionsState extends State { curve: Curves.easeOut, ); } else if (snapshot.hasError) { + _logger.severe('Failed to load sections: ', snapshot.error); + if (kDebugMode) { + return Padding( + padding: const EdgeInsets.only(bottom: 72), + child: Text('Error: ${snapshot.error}'), + ); + } //Errors are handled and this else if condition will be false always //is the 
understanding.
        return const Padding(
diff --git a/mobile/lib/utils/debug_ml_export_data.dart b/mobile/lib/utils/debug_ml_export_data.dart
new file mode 100644
index 000000000..f7a5e9646
--- /dev/null
+++ b/mobile/lib/utils/debug_ml_export_data.dart
@@ -0,0 +1,40 @@
+import "dart:convert";
+import "dart:developer" show log;
+import "dart:io";
+
+import "package:path_provider/path_provider.dart";
+
+Future<void> encodeAndSaveData(
+  dynamic nestedData,
+  String fileName, [
+  String? service,
+]) async {
+  // Convert map keys to strings if nestedData is a map
+  final dataToEncode = nestedData is Map
+      ? nestedData.map((key, value) => MapEntry(key.toString(), value))
+      : nestedData;
+  // Step 1: Serialize the data to a JSON string
+  final String jsonData = jsonEncode(dataToEncode);
+
+  // Step 2 (disabled): optionally Base64-encode the JSON string
+  // final String base64String = base64Encode(utf8.encode(jsonData));
+
+  // Step 3: Write the JSON string to a file
+  try {
+    final File file = await _writeStringToFile(jsonData, fileName);
+    // Success: log the file path
+    log('[$service]: File saved at ${file.path}');
+  } catch (e) {
+    // If an error occurs, handle it.
+    log('[$service]: Error saving file: $e');
+  }
+}
+
+Future<File> _writeStringToFile(
+  String dataString,
+  String fileName,
+) async {
+  final directory = await getExternalStorageDirectory();
+  final file = File('${directory!.path}/$fileName.json');
+  return file.writeAsString(dataString);
+}
diff --git a/mobile/lib/utils/dialog_util.dart b/mobile/lib/utils/dialog_util.dart
index ae4425620..c43229152 100644
--- a/mobile/lib/utils/dialog_util.dart
+++ b/mobile/lib/utils/dialog_util.dart
@@ -110,7 +110,11 @@ String parseErrorForUI(
      errorInfo = "Reason: " + dioError.type.toString();
    }
  } else {
-    errorInfo = error.toString().split('Source stack')[0];
+    if (kDebugMode) {
+      errorInfo = error.toString();
+    } else {
+      errorInfo = error.toString().split('Source stack')[0];
+    }
  }
  if (errorInfo.isNotEmpty) {
    return "$genericError\n\n$errorInfo";
diff --git a/mobile/lib/utils/face/face_box_crop.dart b/mobile/lib/utils/face/face_box_crop.dart
new file mode 100644
index 000000000..2f63ca7e2
--- /dev/null
+++ b/mobile/lib/utils/face/face_box_crop.dart
@@ -0,0 +1,45 @@
+import "dart:io";
+
+import "package:flutter/foundation.dart";
+import "package:photos/core/cache/lru_map.dart";
+import "package:photos/face/model/box.dart";
+import "package:photos/models/file/file.dart";
+import "package:photos/models/file/file_type.dart";
+import "package:photos/utils/file_util.dart";
+import "package:photos/utils/image_ml_isolate.dart";
+import "package:photos/utils/thumbnail_util.dart";
+import "package:pool/pool.dart";
+
+final LRUMap<String, Uint8List> faceCropCache = LRUMap(1000);
+final pool = Pool(5, timeout: const Duration(seconds: 15));
+Future<Map<String, Uint8List>?> getFaceCrops(
+  EnteFile file,
+  Map<String, FaceBox> faceBoxeMap,
+) async {
+  late Uint8List? ioFileBytes;
+  if (file.fileType != FileType.video) {
+    final File?
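+    // Images: read the original bytes; for videos the else branch below
+    // falls back to the thumbnail instead of decoding a frame.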
ioFile = await getFile(file); + if (ioFile == null) { + return null; + } + ioFileBytes = await ioFile.readAsBytes(); + } else { + ioFileBytes = await getThumbnail(file); + } + final List faceIds = []; + final List faceBoxes = []; + for (final e in faceBoxeMap.entries) { + faceIds.add(e.key); + faceBoxes.add(e.value); + } + final List faceCrop = + await ImageMlIsolate.instance.generateFaceThumbnailsForImage( + ioFileBytes!, + faceBoxes, + ); + final Map result = {}; + for (int i = 0; i < faceIds.length; i++) { + result[faceIds[i]] = faceCrop[i]; + } + return result; +} diff --git a/mobile/lib/utils/file_download_util.dart b/mobile/lib/utils/file_download_util.dart index f99a43527..6f8219ed5 100644 --- a/mobile/lib/utils/file_download_util.dart +++ b/mobile/lib/utils/file_download_util.dart @@ -38,9 +38,9 @@ Future downloadAndDecrypt( ), onReceiveProgress: (a, b) { if (kDebugMode && a >= 0 && b >= 0) { - _logger.fine( - "$logPrefix download progress: ${formatBytes(a)} / ${formatBytes(b)}", - ); + // _logger.fine( + // "$logPrefix download progress: ${formatBytes(a)} / ${formatBytes(b)}", + // ); } progressCallback?.call(a, b); }, diff --git a/mobile/lib/utils/image_ml_isolate.dart b/mobile/lib/utils/image_ml_isolate.dart new file mode 100644 index 000000000..9427d800f --- /dev/null +++ b/mobile/lib/utils/image_ml_isolate.dart @@ -0,0 +1,536 @@ +import 'dart:async'; +import 'dart:isolate'; +import 'dart:typed_data' show Float32List, Uint8List; +import 'dart:ui'; + +import 'package:flutter_isolate/flutter_isolate.dart'; +import "package:logging/logging.dart"; +import "package:photos/face/model/box.dart"; +import 'package:photos/models/ml/ml_typedefs.dart'; +import "package:photos/services/face_ml/face_alignment/alignment_result.dart"; +import "package:photos/services/face_ml/face_detection/detection.dart"; +import "package:photos/utils/image_ml_util.dart"; +import "package:synchronized/synchronized.dart"; + +enum ImageOperation { + preprocessBlazeFace, + preprocessYoloOnnx, + preprocessFaceAlign, + preprocessMobileFaceNet, + preprocessMobileFaceNetOnnx, + generateFaceThumbnail, + generateFaceThumbnailsForImage, + cropAndPadFace, +} + +/// The isolate below uses functions from ["package:photos/utils/image_ml_util.dart"] to preprocess images for ML models. + +/// This class is responsible for all image operations needed for ML models. It runs in a separate isolate to avoid jank. +/// +/// It can be accessed through the singleton `ImageConversionIsolate.instance`. e.g. `ImageConversionIsolate.instance.convert(imageData)` +/// +/// IMPORTANT: Make sure to dispose of the isolate when you're done with it with `dispose()`, e.g. `ImageConversionIsolate.instance.dispose();` +class ImageMlIsolate { + // static const String debugName = 'ImageMlIsolate'; + + final _logger = Logger('ImageMlIsolate'); + + Timer? _inactivityTimer; + final Duration _inactivityDuration = const Duration(seconds: 60); + int _activeTasks = 0; + + final _initLock = Lock(); + final _functionLock = Lock(); + + late FlutterIsolate _isolate; + late ReceivePort _receivePort = ReceivePort(); + late SendPort _mainSendPort; + + bool isSpawned = false; + + // singleton pattern + ImageMlIsolate._privateConstructor(); + + /// Use this instance to access the ImageConversionIsolate service. Make sure to call `init()` before using it. + /// e.g. `await ImageConversionIsolate.instance.init();` + /// And kill the isolate when you're done with it with `dispose()`, e.g. 
`ImageConversionIsolate.instance.dispose();` + /// + /// Then you can use `convert()` to get the image, so `ImageConversionIsolate.instance.convert(imageData, imagePath: imagePath)` + static final ImageMlIsolate instance = ImageMlIsolate._privateConstructor(); + factory ImageMlIsolate() => instance; + + Future init() async { + return _initLock.synchronized(() async { + if (isSpawned) return; + + _receivePort = ReceivePort(); + + try { + _isolate = await FlutterIsolate.spawn( + _isolateMain, + _receivePort.sendPort, + ); + _mainSendPort = await _receivePort.first as SendPort; + isSpawned = true; + + _resetInactivityTimer(); + } catch (e) { + _logger.severe('Could not spawn isolate', e); + isSpawned = false; + } + }); + } + + Future ensureSpawned() async { + if (!isSpawned) { + await init(); + } + } + + @pragma('vm:entry-point') + static void _isolateMain(SendPort mainSendPort) async { + final receivePort = ReceivePort(); + mainSendPort.send(receivePort.sendPort); + + receivePort.listen((message) async { + final functionIndex = message[0] as int; + final function = ImageOperation.values[functionIndex]; + final args = message[1] as Map; + final sendPort = message[2] as SendPort; + + try { + switch (function) { + case ImageOperation.preprocessBlazeFace: + final imageData = args['imageData'] as Uint8List; + final normalize = args['normalize'] as bool; + final int normalization = normalize ? 2 : -1; + final requiredWidth = args['requiredWidth'] as int; + final requiredHeight = args['requiredHeight'] as int; + final qualityIndex = args['quality'] as int; + final maintainAspectRatio = args['maintainAspectRatio'] as bool; + final quality = FilterQuality.values[qualityIndex]; + final (result, originalSize, newSize) = + await preprocessImageToMatrix( + imageData, + normalization: normalization, + requiredWidth: requiredWidth, + requiredHeight: requiredHeight, + quality: quality, + maintainAspectRatio: maintainAspectRatio, + ); + sendPort.send({ + 'inputs': result, + 'originalWidth': originalSize.width, + 'originalHeight': originalSize.height, + 'newWidth': newSize.width, + 'newHeight': newSize.height, + }); + case ImageOperation.preprocessYoloOnnx: + final imageData = args['imageData'] as Uint8List; + final normalize = args['normalize'] as bool; + final int normalization = normalize ? 
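+            // The normalization code is forwarded to the preprocessing helper;
+            // -1 appears to disable scaling, while 1 (here) and 2 (in the
+            // BlazeFace case above) select different normalization schemes.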
1 : -1; + final requiredWidth = args['requiredWidth'] as int; + final requiredHeight = args['requiredHeight'] as int; + final qualityIndex = args['quality'] as int; + final maintainAspectRatio = args['maintainAspectRatio'] as bool; + final quality = FilterQuality.values[qualityIndex]; + final (result, originalSize, newSize) = + await preprocessImageToFloat32ChannelsFirst( + imageData, + normalization: normalization, + requiredWidth: requiredWidth, + requiredHeight: requiredHeight, + quality: quality, + maintainAspectRatio: maintainAspectRatio, + ); + sendPort.send({ + 'inputs': result, + 'originalWidth': originalSize.width, + 'originalHeight': originalSize.height, + 'newWidth': newSize.width, + 'newHeight': newSize.height, + }); + case ImageOperation.preprocessFaceAlign: + final imageData = args['imageData'] as Uint8List; + final faceLandmarks = + args['faceLandmarks'] as List>>; + final List result = await preprocessFaceAlignToUint8List( + imageData, + faceLandmarks, + ); + sendPort.send(List.from(result)); + case ImageOperation.preprocessMobileFaceNet: + final imageData = args['imageData'] as Uint8List; + final facesJson = args['facesJson'] as List>; + final ( + inputs, + alignmentResults, + isBlurs, + blurValues, + originalSize + ) = await preprocessToMobileFaceNetInput( + imageData, + facesJson, + ); + final List> alignmentResultsJson = + alignmentResults.map((result) => result.toJson()).toList(); + sendPort.send({ + 'inputs': inputs, + 'alignmentResultsJson': alignmentResultsJson, + 'isBlurs': isBlurs, + 'blurValues': blurValues, + 'originalWidth': originalSize.width, + 'originalHeight': originalSize.height, + }); + case ImageOperation.preprocessMobileFaceNetOnnx: + final imagePath = args['imagePath'] as String; + final facesJson = args['facesJson'] as List>; + final List relativeFaces = facesJson + .map((face) => FaceDetectionRelative.fromJson(face)) + .toList(); + final ( + inputs, + alignmentResults, + isBlurs, + blurValues, + originalSize + ) = await preprocessToMobileFaceNetFloat32List( + imagePath, + relativeFaces, + ); + final List> alignmentResultsJson = + alignmentResults.map((result) => result.toJson()).toList(); + sendPort.send({ + 'inputs': inputs, + 'alignmentResultsJson': alignmentResultsJson, + 'isBlurs': isBlurs, + 'blurValues': blurValues, + 'originalWidth': originalSize.width, + 'originalHeight': originalSize.height, + }); + case ImageOperation.generateFaceThumbnail: + final imageData = args['imageData'] as Uint8List; + final faceDetectionJson = + args['faceDetection'] as Map; + final faceDetection = + FaceDetectionRelative.fromJson(faceDetectionJson); + final Uint8List result = + await generateFaceThumbnailFromData(imageData, faceDetection); + sendPort.send([result]); + case ImageOperation.generateFaceThumbnailsForImage: + final imageData = args['imageData'] as Uint8List; + final faceBoxesJson = + args['faceBoxesList'] as List>; + final List faceBoxes = + faceBoxesJson.map((json) => FaceBox.fromJson(json)).toList(); + final List results = + await generateFaceThumbnailsFromDataAndDetections( + imageData, + faceBoxes, + ); + sendPort.send(List.from(results)); + case ImageOperation.cropAndPadFace: + final imageData = args['imageData'] as Uint8List; + final faceBox = args['faceBox'] as List; + final Uint8List result = + await cropAndPadFaceData(imageData, faceBox); + sendPort.send([result]); + } + } catch (e, stackTrace) { + sendPort + .send({'error': e.toString(), 'stackTrace': stackTrace.toString()}); + } + }); + } + + /// The common method to run any operation in the 
+  /// The common method to run any operation in the isolate. It sends the
+  /// [message] to [_isolateMain] and waits for the result.
+  Future<dynamic> _runInIsolate(
+    (ImageOperation, Map<String, dynamic>) message,
+  ) async {
+    await ensureSpawned();
+    return _functionLock.synchronized(() async {
+      _resetInactivityTimer();
+      final completer = Completer();
+      final answerPort = ReceivePort();
+
+      _activeTasks++;
+      _mainSendPort.send([message.$1.index, message.$2, answerPort.sendPort]);
+
+      answerPort.listen((receivedMessage) {
+        // Only mark the task as finished once the isolate has replied;
+        // decrementing synchronously after `listen` would let the inactivity
+        // timer dispose the isolate while a task is still running.
+        _activeTasks--;
+        if (receivedMessage is Map && receivedMessage.containsKey('error')) {
+          // Handle the error
+          final errorMessage = receivedMessage['error'];
+          final errorStackTrace = receivedMessage['stackTrace'];
+          final exception = Exception(errorMessage);
+          final stackTrace = StackTrace.fromString(errorStackTrace);
+          completer.completeError(exception, stackTrace);
+        } else {
+          completer.complete(receivedMessage);
+        }
+      });
+
+      return completer.future;
+    });
+  }
+
+  /// Resets a timer that kills the isolate after a certain amount of inactivity.
+  ///
+  /// Should be called after initialization (e.g. inside `init()`) and after
+  /// every call to the isolate (e.g. inside `_runInIsolate()`).
+  void _resetInactivityTimer() {
+    _inactivityTimer?.cancel();
+    _inactivityTimer = Timer(_inactivityDuration, () {
+      if (_activeTasks > 0) {
+        _logger.info('Tasks are still running. Delaying isolate disposal.');
+        // Optionally, reschedule the timer to check again later.
+        _resetInactivityTimer();
+      } else {
+        _logger.info(
+          'Image ML Isolate has been inactive for ${_inactivityDuration.inSeconds} seconds with no tasks running. Killing isolate.',
+        );
+        dispose();
+      }
+    });
+  }
+
+  /// Disposes the isolate worker.
+  void dispose() {
+    if (!isSpawned) return;
+
+    isSpawned = false;
+    _isolate.kill();
+    _receivePort.close();
+    _inactivityTimer?.cancel();
+  }
+
+  /// Preprocesses [imageData] for standard ML models inside a separate isolate.
+  ///
+  /// Returns a [Num3DInputMatrix] image usable for ML inference with BlazeFace.
+  ///
+  /// Uses [preprocessImageToMatrix] inside the isolate.
+  Future<(Num3DInputMatrix, Size, Size)> preprocessImageBlazeFace(
+    Uint8List imageData, {
+    required bool normalize,
+    required int requiredWidth,
+    required int requiredHeight,
+    FilterQuality quality = FilterQuality.medium,
+    bool maintainAspectRatio = true,
+  }) async {
+    final Map results = await _runInIsolate(
+      (
+        ImageOperation.preprocessBlazeFace,
+        {
+          'imageData': imageData,
+          'normalize': normalize,
+          'requiredWidth': requiredWidth,
+          'requiredHeight': requiredHeight,
+          'quality': quality.index,
+          'maintainAspectRatio': maintainAspectRatio,
+        },
+      ),
+    );
+    final inputs = results['inputs'] as Num3DInputMatrix;
+    final originalSize = Size(
+      results['originalWidth'] as double,
+      results['originalHeight'] as double,
+    );
+    final newSize = Size(
+      results['newWidth'] as double,
+      results['newHeight'] as double,
+    );
+    return (inputs, originalSize, newSize);
+  }
+
+  /// Uses [preprocessImageToFloat32ChannelsFirst] inside the isolate.
+  Future<(Float32List, Size, Size)> preprocessImageYoloOnnx(
+    Uint8List imageData, {
+    required bool normalize,
+    required int requiredWidth,
+    required int requiredHeight,
+    FilterQuality quality = FilterQuality.medium,
+    bool maintainAspectRatio = true,
+  }) async {
+    final Map results = await _runInIsolate(
+      (
+        ImageOperation.preprocessYoloOnnx,
+        {
+          'imageData': imageData,
+          'normalize': normalize,
+          'requiredWidth': requiredWidth,
+          'requiredHeight': requiredHeight,
+          'quality': quality.index,
+          'maintainAspectRatio': maintainAspectRatio,
+        },
+      ),
+    );
+    final inputs = results['inputs'] as Float32List;
+    final originalSize = Size(
+      results['originalWidth'] as double,
+      results['originalHeight'] as double,
+    );
+    final newSize = Size(
+      results['newWidth'] as double,
+      results['newHeight'] as double,
+    );
+    return (inputs, originalSize, newSize);
+  }
+
+  /// Preprocesses [imageData] for face alignment inside a separate isolate, to
+  /// display the aligned faces. Mostly used for debugging.
+  ///
+  /// Returns a list of [Uint8List] images, one for each face, in png format.
+  ///
+  /// Uses [preprocessFaceAlignToUint8List] inside the isolate.
+  ///
+  /// WARNING: For preprocessing for MobileFaceNet, use [preprocessMobileFaceNet] instead!
+  Future<List<Uint8List>> preprocessFaceAlign(
+    Uint8List imageData,
+    List<FaceDetectionAbsolute> faces,
+  ) async {
+    final faceLandmarks = faces.map((face) => face.allKeypoints).toList();
+    return await _runInIsolate(
+      (
+        ImageOperation.preprocessFaceAlign,
+        {
+          'imageData': imageData,
+          'faceLandmarks': faceLandmarks,
+        },
+      ),
+    ).then((value) => value.cast<Uint8List>());
+  }
+
+  /// Preprocesses [imageData] for MobileFaceNet input inside a separate isolate.
+  ///
+  /// Returns a list of [Num3DInputMatrix] images, one for each face.
+  ///
+  /// Uses [preprocessToMobileFaceNetInput] inside the isolate.
+  Future<
+      (
+        List<Num3DInputMatrix>,
+        List<AlignmentResult>,
+        List<bool>,
+        List<double>,
+        Size,
+      )> preprocessMobileFaceNet(
+    Uint8List imageData,
+    List<FaceDetectionRelative> faces,
+  ) async {
+    final List<Map<String, dynamic>> facesJson =
+        faces.map((face) => face.toJson()).toList();
+    final Map results = await _runInIsolate(
+      (
+        ImageOperation.preprocessMobileFaceNet,
+        {
+          'imageData': imageData,
+          'facesJson': facesJson,
+        },
+      ),
+    );
+    final inputs = results['inputs'] as List<Num3DInputMatrix>;
+    final alignmentResultsJson =
+        results['alignmentResultsJson'] as List<Map<String, dynamic>>;
+    final alignmentResults = alignmentResultsJson.map((json) {
+      return AlignmentResult.fromJson(json);
+    }).toList();
+    final isBlurs = results['isBlurs'] as List<bool>;
+    final blurValues = results['blurValues'] as List<double>;
+    final originalSize = Size(
+      results['originalWidth'] as double,
+      results['originalHeight'] as double,
+    );
+    return (inputs, alignmentResults, isBlurs, blurValues, originalSize);
+  }
+
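One subtlety worth noting for callers of `preprocessImageYoloOnnx`: when `maintainAspectRatio` is true, the returned `newSize` is the scaled content size without the gray letterbox padding, so model outputs can be mapped back to original image coordinates with a single uniform scale. A hedged sketch (assumes top-left anchored resizing, which is what `resizeImage` in `image_ml_util.dart` below does; not part of this diff):

```dart
import 'dart:ui' show Size;

/// Maps a point from the model's input space back to original image pixels.
/// Only valid for the top-left anchored, aspect-preserving resize used here.
(double, double) mapBackToOriginal(
  double x,
  double y,
  Size originalSize,
  Size newSize,
) {
  // newSize excludes the padding, so the scale is uniform on both axes.
  final scale = originalSize.width / newSize.width;
  return (x * scale, y * scale);
}
```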
+  /// Uses [preprocessToMobileFaceNetFloat32List] inside the isolate.
+  Future<(Float32List, List<AlignmentResult>, List<bool>, List<double>, Size)>
+      preprocessMobileFaceNetOnnx(
+    String imagePath,
+    List<FaceDetectionRelative> faces,
+  ) async {
+    final List<Map<String, dynamic>> facesJson =
+        faces.map((face) => face.toJson()).toList();
+    final Map results = await _runInIsolate(
+      (
+        ImageOperation.preprocessMobileFaceNetOnnx,
+        {
+          'imagePath': imagePath,
+          'facesJson': facesJson,
+        },
+      ),
+    );
+    final inputs = results['inputs'] as Float32List;
+    final alignmentResultsJson =
+        results['alignmentResultsJson'] as List<Map<String, dynamic>>;
+    final alignmentResults = alignmentResultsJson.map((json) {
+      return AlignmentResult.fromJson(json);
+    }).toList();
+    final isBlurs = results['isBlurs'] as List<bool>;
+    final blurValues = results['blurValues'] as List<double>;
+    final originalSize = Size(
+      results['originalWidth'] as double,
+      results['originalHeight'] as double,
+    );
+
+    return (inputs, alignmentResults, isBlurs, blurValues, originalSize);
+  }
+
+  /// Generates a face thumbnail from [imageData] and a [faceDetection].
+  ///
+  /// Uses [generateFaceThumbnailFromData] inside the isolate.
+  Future<Uint8List> generateFaceThumbnail(
+    Uint8List imageData,
+    FaceDetectionRelative faceDetection,
+  ) async {
+    return await _runInIsolate(
+      (
+        ImageOperation.generateFaceThumbnail,
+        {
+          'imageData': imageData,
+          'faceDetection': faceDetection.toJson(),
+        },
+      ),
+    ).then((value) => value[0] as Uint8List);
+  }
+
+  /// Generates face thumbnails for all [faceBoxes] in [imageData].
+  ///
+  /// Uses [generateFaceThumbnailsFromDataAndDetections] inside the isolate.
+  Future<List<Uint8List>> generateFaceThumbnailsForImage(
+    Uint8List imageData,
+    List<FaceBox> faceBoxes,
+  ) async {
+    final List<Map<String, dynamic>> faceBoxesJson =
+        faceBoxes.map((box) => box.toJson()).toList();
+    return await _runInIsolate(
+      (
+        ImageOperation.generateFaceThumbnailsForImage,
+        {
+          'imageData': imageData,
+          'faceBoxesList': faceBoxesJson,
+        },
+      ),
+    ).then((value) => value.cast<Uint8List>());
+  }
+
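All of the wrappers above follow the same marshalling convention: typed model objects (`FaceDetectionRelative`, `FaceBox`, `AlignmentResult`) are flattened to JSON-safe maps with `toJson()` before crossing the port, and rebuilt with `fromJson()` on the other side. A stripped-down sketch of the shape such a class needs (a hypothetical stand-in, not the real `FaceBox`):

```dart
/// Illustrative only: the toJson/fromJson round trip used to carry model
/// objects across the isolate boundary as plain maps.
class FaceBoxLike {
  final double x, y, width, height;
  const FaceBoxLike(this.x, this.y, this.width, this.height);

  Map<String, dynamic> toJson() =>
      {'x': x, 'y': y, 'width': width, 'height': height};

  factory FaceBoxLike.fromJson(Map<String, dynamic> json) => FaceBoxLike(
        json['x'] as double,
        json['y'] as double,
        json['width'] as double,
        json['height'] as double,
      );
}
```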
+  /// Generates cropped and padded image data from [imageData] and a [faceBox].
+  ///
+  /// The steps are:
+  /// 1. Crop the image to the face bounding box
+  /// 2. Resize this cropped image to a square that is half the BlazeFace input size
+  /// 3. Pad the image to the BlazeFace input size
+  ///
+  /// Uses [cropAndPadFaceData] inside the isolate.
+  Future<Uint8List> cropAndPadFace(
+    Uint8List imageData,
+    List<double> faceBox,
+  ) async {
+    return await _runInIsolate(
+      (
+        ImageOperation.cropAndPadFace,
+        {
+          'imageData': imageData,
+          'faceBox': List<double>.from(faceBox),
+        },
+      ),
+    ).then((value) => value[0] as Uint8List);
+  }
+}
diff --git a/mobile/lib/utils/image_ml_util.dart b/mobile/lib/utils/image_ml_util.dart
new file mode 100644
index 000000000..b3e6e2dfc
--- /dev/null
+++ b/mobile/lib/utils/image_ml_util.dart
@@ -0,0 +1,1273 @@
+import "dart:async";
+import "dart:developer" show log;
+import "dart:io" show File;
+import "dart:math" show min, max;
+import "dart:typed_data" show Float32List, Uint8List, ByteData;
+import "dart:ui";
+
+// import 'package:flutter/material.dart'
+//     show
+//         ImageProvider,
+//         ImageStream,
+//         ImageStreamListener,
+//         ImageInfo,
+//         MemoryImage,
+//         ImageConfiguration;
+// import 'package:flutter/material.dart' as material show Image;
+import 'package:flutter/painting.dart' as paint show decodeImageFromList;
+import 'package:ml_linalg/linalg.dart';
+import "package:photos/face/model/box.dart";
+import 'package:photos/models/ml/ml_typedefs.dart';
+import "package:photos/services/face_ml/blur_detection/blur_detection_service.dart";
+import "package:photos/services/face_ml/face_alignment/alignment_result.dart";
+import "package:photos/services/face_ml/face_alignment/similarity_transform.dart";
+import "package:photos/services/face_ml/face_detection/detection.dart";
+
+/// All of the functions in this file are helper functions for the
+/// [ImageMlIsolate] isolate. Don't use them outside of the isolate, unless
+/// you are okay with UI jank!!!!
+
+/// Reads the pixel color at the specified coordinates.
+Color readPixelColor(
+  Image image,
+  ByteData byteData,
+  int x,
+  int y,
+) {
+  if (x < 0 || x >= image.width || y < 0 || y >= image.height) {
+    // throw ArgumentError('Invalid pixel coordinates.');
+    return const Color(0x00000000);
+  }
+  assert(byteData.lengthInBytes == 4 * image.width * image.height);
+
+  final int byteOffset = 4 * (image.width * y + x);
+  return Color(_rgbaToArgb(byteData.getUint32(byteOffset)));
+}
+
+int _rgbaToArgb(int rgbaColor) {
+  final int a = rgbaColor & 0xFF;
+  final int rgb = rgbaColor >> 8;
+  return rgb + (a << 24);
+}
+
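Background on the byte shuffle above: `Image.toByteData` yields RGBA-ordered bytes, while `Color` expects a 32-bit ARGB value, so `_rgbaToArgb` moves the alpha byte from the bottom to the top. A quick worked check of the bit arithmetic (illustrative only, not part of this diff):

```dart
// 0xRRGGBBAA -> 0xAARRGGBB, same logic as _rgbaToArgb above.
int rgbaToArgb(int rgbaColor) {
  final int a = rgbaColor & 0xFF;   // alpha is the lowest byte in RGBA
  final int rgb = rgbaColor >> 8;   // drop alpha, keep 0xRRGGBB
  return rgb + (a << 24);           // put alpha in the top byte
}

void main() {
  // Red=0x11, Green=0x22, Blue=0x33, Alpha=0x44.
  assert(rgbaToArgb(0x11223344) == 0x44112233);
}
```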
+/// Creates an empty matrix with the specified shape.
+///
+/// The `shape` argument must be a list of length 1 to 5, where each element
+/// gives the size of the corresponding dimension. The function returns a
+/// nested list filled with [fillValue].
+///
+/// Throws an [ArgumentError] if the `shape` argument is invalid.
+List createEmptyOutputMatrix(List<int> shape, [double fillValue = 0.0]) {
+  if (shape.isEmpty || shape.length > 5) {
+    throw ArgumentError('Shape must have length 1-5');
+  }
+
+  if (shape.length == 1) {
+    return List.filled(shape[0], fillValue);
+  } else if (shape.length == 2) {
+    return List.generate(shape[0], (_) => List.filled(shape[1], fillValue));
+  } else if (shape.length == 3) {
+    return List.generate(
+      shape[0],
+      (_) => List.generate(shape[1], (_) => List.filled(shape[2], fillValue)),
+    );
+  } else if (shape.length == 4) {
+    return List.generate(
+      shape[0],
+      (_) => List.generate(
+        shape[1],
+        (_) => List.generate(shape[2], (_) => List.filled(shape[3], fillValue)),
+      ),
+    );
+  } else {
+    return List.generate(
+      shape[0],
+      (_) => List.generate(
+        shape[1],
+        (_) => List.generate(
+          shape[2],
+          (_) =>
+              List.generate(shape[3], (_) => List.filled(shape[4], fillValue)),
+        ),
+      ),
+    );
+  }
+}
+
+/// Creates an input matrix from the specified image, which can be used for inference.
+///
+/// Returns a matrix with the shape `[image.height, image.width, 3]`, where the
+/// third dimension represents the RGB channels, as [Num3DInputMatrix].
+///
+/// The [normFunction] argument determines how the raw pixel values in range
+/// [0, 255] are mapped; the default [normalizePixelRange2] normalizes them to
+/// doubles in range [-1, 1].
+Num3DInputMatrix createInputMatrixFromImage(
+  Image image,
+  ByteData byteDataRgba, {
+  double Function(num) normFunction = normalizePixelRange2,
+}) {
+  return List.generate(
+    image.height,
+    (y) => List.generate(
+      image.width,
+      (x) {
+        final pixel = readPixelColor(image, byteDataRgba, x, y);
+        return [
+          normFunction(pixel.red),
+          normFunction(pixel.green),
+          normFunction(pixel.blue),
+        ];
+      },
+    ),
+  );
+}
+
+void addInputImageToFloat32List(
+  Image image,
+  ByteData byteDataRgba,
+  Float32List float32List,
+  int startIndex, {
+  double Function(num) normFunction = normalizePixelRange2,
+}) {
+  int pixelIndex = startIndex;
+  for (var h = 0; h < image.height; h++) {
+    for (var w = 0; w < image.width; w++) {
+      final pixel = readPixelColor(image, byteDataRgba, w, h);
+      float32List[pixelIndex] = normFunction(pixel.red);
+      float32List[pixelIndex + 1] = normFunction(pixel.green);
+      float32List[pixelIndex + 2] = normFunction(pixel.blue);
+      pixelIndex += 3;
+    }
+  }
+  return;
+}
+
+List<List<int>> createGrayscaleIntMatrixFromImage(
+  Image image,
+  ByteData byteDataRgba,
+) {
+  return List.generate(
+    image.height,
+    (y) => List.generate(
+      image.width,
+      (x) {
+        // 0.299 ∙ Red + 0.587 ∙ Green + 0.114 ∙ Blue
+        final pixel = readPixelColor(image, byteDataRgba, x, y);
+        return (0.299 * pixel.red + 0.587 * pixel.green + 0.114 * pixel.blue)
+            .round()
+            .clamp(0, 255);
+      },
+    ),
+  );
+}
+
+Float32List createFloat32ListFromImageChannelsFirst(
+  Image image,
+  ByteData byteDataRgba, {
+  double Function(num) normFunction = normalizePixelRange2,
+}) {
+  final convertedBytes = Float32List(3 * image.height * image.width);
+  final buffer = Float32List.view(convertedBytes.buffer);
+
+  int pixelIndex = 0;
+  final int channelOffsetGreen = image.height * image.width;
+  final int channelOffsetBlue = 2 *
image.height * image.width; + for (var h = 0; h < image.height; h++) { + for (var w = 0; w < image.width; w++) { + final pixel = readPixelColor(image, byteDataRgba, w, h); + buffer[pixelIndex] = normFunction(pixel.red); + buffer[pixelIndex + channelOffsetGreen] = normFunction(pixel.green); + buffer[pixelIndex + channelOffsetBlue] = normFunction(pixel.blue); + pixelIndex++; + } + } + return convertedBytes.buffer.asFloat32List(); +} + +/// Creates an input matrix from the specified image, which can be used for inference +/// +/// Returns a matrix with the shape `[3, image.height, image.width]`, where the first dimension represents the RGB channels, as [Num3DInputMatrix]. +/// In fact, this is either a [Double3DInputMatrix] or a [Int3DInputMatrix] depending on the `normalize` argument. +/// If `normalize` is true, the pixel values are normalized doubles in range [-1, 1]. Otherwise, they are integers in range [0, 255]. +/// +/// The `image` argument must be an ui.[Image] object. The function returns a matrix +/// with the shape `[3, image.height, image.width]`, where the first dimension +/// represents the RGB channels. +/// +/// bool `normalize`: Normalize the image to range [-1, 1] +Num3DInputMatrix createInputMatrixFromImageChannelsFirst( + Image image, + ByteData byteDataRgba, { + bool normalize = true, +}) { + // Create an empty 3D list. + final Num3DInputMatrix imageMatrix = List.generate( + 3, + (i) => List.generate( + image.height, + (j) => List.filled(image.width, 0), + ), + ); + + // Determine which function to use to get the pixel value. + final pixelValue = normalize ? normalizePixelRange2 : (num value) => value; + + for (int y = 0; y < image.height; y++) { + for (int x = 0; x < image.width; x++) { + // Get the pixel at (x, y). + final pixel = readPixelColor(image, byteDataRgba, x, y); + + // Assign the color channels to the respective lists. + imageMatrix[0][y][x] = pixelValue(pixel.red); + imageMatrix[1][y][x] = pixelValue(pixel.green); + imageMatrix[2][y][x] = pixelValue(pixel.blue); + } + } + return imageMatrix; +} + +/// Function normalizes the pixel value to be in range [-1, 1]. +/// +/// It assumes that the pixel value is originally in range [0, 255] +double normalizePixelRange2(num pixelValue) { + return (pixelValue / 127.5) - 1; +} + +/// Function normalizes the pixel value to be in range [0, 1]. +/// +/// It assumes that the pixel value is originally in range [0, 255] +double normalizePixelRange1(num pixelValue) { + return (pixelValue / 255); +} + +double normalizePixelNoRange(num pixelValue) { + return pixelValue.toDouble(); +} + +/// Decodes [Uint8List] image data to an ui.[Image] object. +Future decodeImageFromData(Uint8List imageData) async { + // Decoding using flutter paint. This is the fastest and easiest method. + final Image image = await paint.decodeImageFromList(imageData); + return image; + + // // Similar decoding as above, but without using flutter paint. This is not faster than the above. + // final Codec codec = await instantiateImageCodecFromBuffer( + // await ImmutableBuffer.fromUint8List(imageData), + // ); + // final FrameInfo frameInfo = await codec.getNextFrame(); + // return frameInfo.image; + + // Decoding using the ImageProvider, same as `image_pixels` package. This is not faster than the above. 
+ // final Completer completer = Completer(); + // final ImageProvider provider = MemoryImage(imageData); + // final ImageStream stream = provider.resolve(const ImageConfiguration()); + // final ImageStreamListener listener = + // ImageStreamListener((ImageInfo info, bool _) { + // completer.complete(info.image); + // }); + // stream.addListener(listener); + // final Image image = await completer.future; + // stream.removeListener(listener); + // return image; + + // // Decoding using the ImageProvider from material.Image. This is not faster than the above, and also the code below is not finished! + // final materialImage = material.Image.memory(imageData); + // final ImageProvider uiImage = await materialImage.image; +} + +/// Decodes [Uint8List] RGBA bytes to an ui.[Image] object. +Future decodeImageFromRgbaBytes( + Uint8List rgbaBytes, + int width, + int height, +) { + final Completer completer = Completer(); + decodeImageFromPixels( + rgbaBytes, + width, + height, + PixelFormat.rgba8888, + (Image image) { + completer.complete(image); + }, + ); + return completer.future; +} + +/// Returns the [ByteData] object of the image, in rawRgba format. +/// +/// Throws an exception if the image could not be converted to ByteData. +Future getByteDataFromImage( + Image image, { + ImageByteFormat format = ImageByteFormat.rawRgba, +}) async { + final ByteData? byteDataRgba = await image.toByteData(format: format); + if (byteDataRgba == null) { + log('[ImageMlUtils] Could not convert image to ByteData'); + throw Exception('Could not convert image to ByteData'); + } + return byteDataRgba; +} + +/// Encodes an [Image] object to a [Uint8List], by default in the png format. +/// +/// Note that the result can be used with `Image.memory()` only if the [format] is png. +Future encodeImageToUint8List( + Image image, { + ImageByteFormat format = ImageByteFormat.png, +}) async { + final ByteData byteDataPng = + await getByteDataFromImage(image, format: format); + final encodedImage = byteDataPng.buffer.asUint8List(); + + return encodedImage; +} + +/// Resizes the [image] to the specified [width] and [height]. +/// Returns the resized image and its size as a [Size] object. Note that this size excludes any empty pixels, hence it can be different than the actual image size if [maintainAspectRatio] is true. +/// +/// [quality] determines the interpolation quality. The default [FilterQuality.medium] works best for most cases, unless you're scaling by a factor of 5-10 or more +/// [maintainAspectRatio] determines whether to maintain the aspect ratio of the original image or not. 
Note that maintaining aspect ratio here does not change the size of the image, but instead often means empty pixels that have to be taken into account +Future<(Image, Size)> resizeImage( + Image image, + int width, + int height, { + FilterQuality quality = FilterQuality.medium, + bool maintainAspectRatio = false, +}) async { + if (image.width == width && image.height == height) { + return (image, Size(width.toDouble(), height.toDouble())); + } + final recorder = PictureRecorder(); + final canvas = Canvas( + recorder, + Rect.fromPoints( + const Offset(0, 0), + Offset(width.toDouble(), height.toDouble()), + ), + ); + // Pre-fill the canvas with RGB color (114, 114, 114) + canvas.drawRect( + Rect.fromPoints( + const Offset(0, 0), + Offset(width.toDouble(), height.toDouble()), + ), + Paint()..color = const Color.fromARGB(255, 114, 114, 114), + ); + + double scaleW = width / image.width; + double scaleH = height / image.height; + if (maintainAspectRatio) { + final scale = min(width / image.width, height / image.height); + scaleW = scale; + scaleH = scale; + } + final scaledWidth = (image.width * scaleW).round(); + final scaledHeight = (image.height * scaleH).round(); + + canvas.drawImageRect( + image, + Rect.fromPoints( + const Offset(0, 0), + Offset(image.width.toDouble(), image.height.toDouble()), + ), + Rect.fromPoints( + const Offset(0, 0), + Offset(scaledWidth.toDouble(), scaledHeight.toDouble()), + ), + Paint()..filterQuality = quality, + ); + + final picture = recorder.endRecording(); + final resizedImage = await picture.toImage(width, height); + return (resizedImage, Size(scaledWidth.toDouble(), scaledHeight.toDouble())); +} + +Future resizeAndCenterCropImage( + Image image, + int size, { + FilterQuality quality = FilterQuality.medium, +}) async { + if (image.width == size && image.height == size) { + return image; + } + final recorder = PictureRecorder(); + final canvas = Canvas( + recorder, + Rect.fromPoints( + const Offset(0, 0), + Offset(size.toDouble(), size.toDouble()), + ), + ); + + final scale = max(size / image.width, size / image.height); + final scaledWidth = (image.width * scale).round(); + final scaledHeight = (image.height * scale).round(); + + canvas.drawImageRect( + image, + Rect.fromPoints( + const Offset(0, 0), + Offset(image.width.toDouble(), image.height.toDouble()), + ), + Rect.fromPoints( + const Offset(0, 0), + Offset(scaledWidth.toDouble(), scaledHeight.toDouble()), + ), + Paint()..filterQuality = quality, + ); + + final picture = recorder.endRecording(); + final resizedImage = await picture.toImage(size, size); + return resizedImage; +} + +/// Crops an [image] based on the specified [x], [y], [width] and [height]. +/// Optionally, the cropped image can be resized to comply with a [maxSize] and/or [minSize]. +/// Optionally, the cropped image can be rotated from the center by [rotation] radians. +/// Optionally, the [quality] of the resizing interpolation can be specified. +Future cropImage( + Image image, { + required double x, + required double y, + required double width, + required double height, + Size? maxSize, + Size? 
minSize, + double rotation = 0.0, // rotation in radians + FilterQuality quality = FilterQuality.medium, +}) async { + // Calculate the scale for resizing based on maxSize and minSize + double scaleX = 1.0; + double scaleY = 1.0; + if (maxSize != null) { + final minScale = min(maxSize.width / width, maxSize.height / height); + if (minScale < 1.0) { + scaleX = minScale; + scaleY = minScale; + } + } + if (minSize != null) { + final maxScale = max(minSize.width / width, minSize.height / height); + if (maxScale > 1.0) { + scaleX = maxScale; + scaleY = maxScale; + } + } + + // Calculate the final dimensions + final targetWidth = (width * scaleX).round(); + final targetHeight = (height * scaleY).round(); + + // Create the canvas + final recorder = PictureRecorder(); + final canvas = Canvas( + recorder, + Rect.fromPoints( + const Offset(0, 0), + Offset(targetWidth.toDouble(), targetHeight.toDouble()), + ), + ); + + // Apply rotation + final center = Offset(targetWidth / 2, targetHeight / 2); + canvas.translate(center.dx, center.dy); + canvas.rotate(rotation); + + // Enlarge both the source and destination boxes to account for the rotation (i.e. avoid cropping the corners of the image) + final List enlargedSrc = + getEnlargedAbsoluteBox([x, y, x + width, y + height], 1.5); + final List enlargedDst = getEnlargedAbsoluteBox( + [ + -center.dx, + -center.dy, + -center.dx + targetWidth, + -center.dy + targetHeight, + ], + 1.5, + ); + + canvas.drawImageRect( + image, + Rect.fromPoints( + Offset(enlargedSrc[0], enlargedSrc[1]), + Offset(enlargedSrc[2], enlargedSrc[3]), + ), + Rect.fromPoints( + Offset(enlargedDst[0], enlargedDst[1]), + Offset(enlargedDst[2], enlargedDst[3]), + ), + Paint()..filterQuality = quality, + ); + + final picture = recorder.endRecording(); + + return picture.toImage(targetWidth, targetHeight); +} + +/// Adds padding around an [Image] object. +Future addPaddingToImage( + Image image, [ + double padding = 0.5, +]) async { + const Color paddingColor = Color.fromARGB(0, 0, 0, 0); + final originalWidth = image.width; + final originalHeight = image.height; + + final paddedWidth = (originalWidth + 2 * padding * originalWidth).toInt(); + final paddedHeight = (originalHeight + 2 * padding * originalHeight).toInt(); + + final recorder = PictureRecorder(); + final canvas = Canvas( + recorder, + Rect.fromPoints( + const Offset(0, 0), + Offset(paddedWidth.toDouble(), paddedHeight.toDouble()), + ), + ); + + final paint = Paint(); + paint.color = paddingColor; + + // Draw the padding + canvas.drawRect( + Rect.fromPoints( + const Offset(0, 0), + Offset(paddedWidth.toDouble(), paddedHeight.toDouble()), + ), + paint, + ); + + // Draw the original image on top of the padding + canvas.drawImageRect( + image, + Rect.fromPoints( + const Offset(0, 0), + Offset(image.width.toDouble(), image.height.toDouble()), + ), + Rect.fromPoints( + Offset(padding * originalWidth, padding * originalHeight), + Offset( + (1 + padding) * originalWidth, + (1 + padding) * originalHeight, + ), + ), + Paint()..filterQuality = FilterQuality.none, + ); + + final picture = recorder.endRecording(); + return picture.toImage(paddedWidth, paddedHeight); +} + +/// Preprocesses [imageData] for standard ML models. +/// Returns a [Num3DInputMatrix] image, ready for inference. +/// Also returns the original image size and the new image size, respectively. +/// +/// The [imageData] argument must be a [Uint8List] object. +/// The [normalize] argument determines whether the image is normalized to range [-1, 1]. 
+/// The [requiredWidth] and [requiredHeight] arguments determine the size of the output image. +/// The [quality] argument determines the quality of the resizing interpolation. +/// The [maintainAspectRatio] argument determines whether the aspect ratio of the image is maintained. +@Deprecated("Old method used in blazeface") +Future<(Num3DInputMatrix, Size, Size)> preprocessImageToMatrix( + Uint8List imageData, { + required int normalization, + required int requiredWidth, + required int requiredHeight, + FilterQuality quality = FilterQuality.medium, + maintainAspectRatio = true, +}) async { + final normFunction = normalization == 2 + ? normalizePixelRange2 + : normalization == 1 + ? normalizePixelRange1 + : normalizePixelNoRange; + final Image image = await decodeImageFromData(imageData); + final originalSize = Size(image.width.toDouble(), image.height.toDouble()); + + if (image.width == requiredWidth && image.height == requiredHeight) { + final ByteData imgByteData = await getByteDataFromImage(image); + return ( + createInputMatrixFromImage( + image, + imgByteData, + normFunction: normFunction, + ), + originalSize, + originalSize + ); + } + + final (resizedImage, newSize) = await resizeImage( + image, + requiredWidth, + requiredHeight, + quality: quality, + maintainAspectRatio: maintainAspectRatio, + ); + + final ByteData imgByteData = await getByteDataFromImage(resizedImage); + final Num3DInputMatrix imageMatrix = createInputMatrixFromImage( + resizedImage, + imgByteData, + normFunction: normFunction, + ); + + return (imageMatrix, originalSize, newSize); +} + +Future<(Float32List, Size, Size)> preprocessImageToFloat32ChannelsFirst( + Uint8List imageData, { + required int normalization, + required int requiredWidth, + required int requiredHeight, + FilterQuality quality = FilterQuality.medium, + maintainAspectRatio = true, +}) async { + final normFunction = normalization == 2 + ? normalizePixelRange2 + : normalization == 1 + ? normalizePixelRange1 + : normalizePixelNoRange; + final stopwatch = Stopwatch()..start(); + final Image image = await decodeImageFromData(imageData); + stopwatch.stop(); + log("Face Detection decoding ui image took: ${stopwatch.elapsedMilliseconds} ms"); + final originalSize = Size(image.width.toDouble(), image.height.toDouble()); + late final Image resizedImage; + late final Size newSize; + + if (image.width == requiredWidth && image.height == requiredHeight) { + resizedImage = image; + newSize = originalSize; + } else { + (resizedImage, newSize) = await resizeImage( + image, + requiredWidth, + requiredHeight, + quality: quality, + maintainAspectRatio: maintainAspectRatio, + ); + } + final ByteData imgByteData = await getByteDataFromImage(resizedImage); + final Float32List imageFloat32List = createFloat32ListFromImageChannelsFirst( + resizedImage, + imgByteData, + normFunction: normFunction, + ); + + return (imageFloat32List, originalSize, newSize); +} + +/// Preprocesses [imageData] based on [faceLandmarks] to align the faces in the images. +/// +/// Returns a list of [Uint8List] images, one for each face, in png format. 
+@Deprecated("Old method used in blazeface")
+Future<List<Uint8List>> preprocessFaceAlignToUint8List(
+  Uint8List imageData,
+  List<List<List<double>>> faceLandmarks, {
+  int width = 112,
+  int height = 112,
+}) async {
+  final alignedImages = <Uint8List>[];
+  final Image image = await decodeImageFromData(imageData);
+
+  for (final faceLandmark in faceLandmarks) {
+    final (alignmentResult, correctlyEstimated) =
+        SimilarityTransform.instance.estimate(faceLandmark);
+    if (!correctlyEstimated) {
+      alignedImages.add(Uint8List(0));
+      continue;
+    }
+    final alignmentBox = getAlignedFaceBox(alignmentResult);
+    final Image alignedFace = await cropImage(
+      image,
+      x: alignmentBox[0],
+      y: alignmentBox[1],
+      width: alignmentBox[2] - alignmentBox[0],
+      height: alignmentBox[3] - alignmentBox[1],
+      maxSize: Size(width.toDouble(), height.toDouble()),
+      minSize: Size(width.toDouble(), height.toDouble()),
+      rotation: alignmentResult.rotation,
+    );
+    final Uint8List alignedFacePng = await encodeImageToUint8List(alignedFace);
+    alignedImages.add(alignedFacePng);
+
+    // final Uint8List alignedImageRGBA = await warpAffineToUint8List(
+    //   image,
+    //   imgByteData,
+    //   alignmentResult.affineMatrix
+    //       .map(
+    //         (row) => row.map((e) {
+    //           if (e != 1.0) {
+    //             return e * 112;
+    //           } else {
+    //             return 1.0;
+    //           }
+    //         }).toList(),
+    //       )
+    //       .toList(),
+    //   width: width,
+    //   height: height,
+    // );
+    // final Image alignedImage =
+    //     await decodeImageFromRgbaBytes(alignedImageRGBA, width, height);
+    // final Uint8List alignedImagePng =
+    //     await encodeImageToUint8List(alignedImage);
+
+    // alignedImages.add(alignedImagePng);
+  }
+  return alignedImages;
+}
+
+/// Preprocesses [imageData] based on the face detections in [facesJson] to
+/// align the faces in the image.
+///
+/// Returns a list of [Num3DInputMatrix] images, one for each face, ready for MobileFaceNet inference.
+Future<
+    (
+      List<Num3DInputMatrix>,
+      List<AlignmentResult>,
+      List<bool>,
+      List<double>,
+      Size,
+    )> preprocessToMobileFaceNetInput(
+  Uint8List imageData,
+  List<Map<String, dynamic>> facesJson, {
+  int width = 112,
+  int height = 112,
+}) async {
+  final Image image = await decodeImageFromData(imageData);
+  final Size originalSize =
+      Size(image.width.toDouble(), image.height.toDouble());
+
+  final List<FaceDetectionRelative> relativeFaces =
+      facesJson.map((face) => FaceDetectionRelative.fromJson(face)).toList();
+
+  final List<FaceDetectionAbsolute> absoluteFaces =
+      relativeToAbsoluteDetections(
+    relativeDetections: relativeFaces,
+    imageWidth: image.width,
+    imageHeight: image.height,
+  );
+
+  final List<List<List<double>>> faceLandmarks =
+      absoluteFaces.map((face) => face.allKeypoints).toList();
+
+  final alignedImages = <Num3DInputMatrix>[];
+  final alignmentResults = <AlignmentResult>[];
+  final isBlurs = <bool>[];
+  final blurValues = <double>[];
+
+  for (final faceLandmark in faceLandmarks) {
+    final (alignmentResult, correctlyEstimated) =
+        SimilarityTransform.instance.estimate(faceLandmark);
+    if (!correctlyEstimated) {
+      alignedImages.add([]);
+      alignmentResults.add(AlignmentResult.empty());
+      continue;
+    }
+    final alignmentBox = getAlignedFaceBox(alignmentResult);
+    final Image alignedFace = await cropImage(
+      image,
+      x: alignmentBox[0],
+      y: alignmentBox[1],
+      width: alignmentBox[2] - alignmentBox[0],
+      height: alignmentBox[3] - alignmentBox[1],
+      maxSize: Size(width.toDouble(), height.toDouble()),
+      minSize: Size(width.toDouble(), height.toDouble()),
+      rotation: alignmentResult.rotation,
+      quality: FilterQuality.medium,
+    );
+    final alignedFaceByteData = await getByteDataFromImage(alignedFace);
+    final alignedFaceMatrix = createInputMatrixFromImage(
+      alignedFace,
+      alignedFaceByteData,
+      normFunction: normalizePixelRange2,
+    );
+    alignedImages.add(alignedFaceMatrix);
+    alignmentResults.add(alignmentResult);
+    final faceGrayMatrix = createGrayscaleIntMatrixFromImage(
+      alignedFace,
+      alignedFaceByteData,
+    );
+    final (isBlur, blurValue) = await BlurDetectionService.instance
+        .predictIsBlurGrayLaplacian(faceGrayMatrix);
+    isBlurs.add(isBlur);
+    blurValues.add(blurValue);
+
+    // final Double3DInputMatrix alignedImage = await warpAffineToMatrix(
+    //   image,
+    //   imgByteData,
+    //   transformationMatrix,
+    //   width: width,
+    //   height: height,
+    //   normalize: true,
+    // );
+    // alignedImages.add(alignedImage);
+    // transformationMatrices.add(transformationMatrix);
+  }
+  return (alignedImages, alignmentResults, isBlurs, blurValues, originalSize);
+}
+
+Future<(Float32List, List<AlignmentResult>, List<bool>, List<double>, Size)>
+    preprocessToMobileFaceNetFloat32List(
+  String imagePath,
+  List<FaceDetectionRelative> relativeFaces, {
+  int width = 112,
+  int height = 112,
+}) async {
+  final Uint8List imageData = await File(imagePath).readAsBytes();
+  final stopwatch = Stopwatch()..start();
+  final Image image = await decodeImageFromData(imageData);
+  stopwatch.stop();
+  log("Face Alignment decoding ui image took: ${stopwatch.elapsedMilliseconds} ms");
+  final Size originalSize =
+      Size(image.width.toDouble(), image.height.toDouble());
+
+  final List<FaceDetectionAbsolute> absoluteFaces =
+      relativeToAbsoluteDetections(
+    relativeDetections: relativeFaces,
+    imageWidth: image.width,
+    imageHeight: image.height,
+  );
+
+  final List<List<List<double>>> faceLandmarks =
+      absoluteFaces.map((face) => face.allKeypoints).toList();
+
+  final alignedImagesFloat32List =
+      Float32List(3 * width * height * faceLandmarks.length);
+  final alignmentResults = <AlignmentResult>[];
+  final isBlurs = <bool>[];
+  final blurValues = <double>[];
+
+  int alignedImageIndex = 0;
+  for (final faceLandmark in faceLandmarks) {
+    final (alignmentResult, correctlyEstimated) =
+        SimilarityTransform.instance.estimate(faceLandmark);
+    if (!correctlyEstimated) {
+      alignedImageIndex += 3 * width * height;
+      alignmentResults.add(AlignmentResult.empty());
+      continue;
+    }
+    final alignmentBox = getAlignedFaceBox(alignmentResult);
+    final Image alignedFace = await cropImage(
+      image,
+      x: alignmentBox[0],
+      y: alignmentBox[1],
+      width: alignmentBox[2] - alignmentBox[0],
+      height: alignmentBox[3] - alignmentBox[1],
+      maxSize: Size(width.toDouble(), height.toDouble()),
+      minSize: Size(width.toDouble(), height.toDouble()),
+      rotation: alignmentResult.rotation,
+      quality: FilterQuality.medium,
+    );
+    final alignedFaceByteData = await getByteDataFromImage(alignedFace);
+    addInputImageToFloat32List(
+      alignedFace,
+      alignedFaceByteData,
+      alignedImagesFloat32List,
+      alignedImageIndex,
+      normFunction: normalizePixelRange2,
+    );
+    alignedImageIndex += 3 * width * height;
+    alignmentResults.add(alignmentResult);
+    final blurDetectionStopwatch = Stopwatch()..start();
+    final faceGrayMatrix = createGrayscaleIntMatrixFromImage(
+      alignedFace,
+      alignedFaceByteData,
+    );
+    final grayscaleMs = blurDetectionStopwatch.elapsedMilliseconds;
+    log('creating grayscale matrix took $grayscaleMs ms');
+    final (isBlur, blurValue) = await BlurDetectionService.instance
+        .predictIsBlurGrayLaplacian(faceGrayMatrix);
+    final blurMs = blurDetectionStopwatch.elapsedMilliseconds - grayscaleMs;
+    log('blur detection took $blurMs ms');
+    log(
+      'total blur detection took ${blurDetectionStopwatch.elapsedMilliseconds} ms',
+    );
+    blurDetectionStopwatch.stop();
+    isBlurs.add(isBlur);
+    blurValues.add(blurValue);
+  }
+  return (
+    alignedImagesFloat32List,
+    alignmentResults,
+    isBlurs,
+    blurValues,
+    originalSize
+  );
+}
+
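A note on the buffer layout produced above: `preprocessToMobileFaceNetFloat32List` packs every aligned face into one flat buffer via `addInputImageToFloat32List`, which writes interleaved RGB (HWC) values, one contiguous `3 * width * height` block per face. A sketch of how a consumer would index into it (illustrative only, not part of this diff):

```dart
import 'dart:typed_data';

/// Reads one channel value from the interleaved RGB buffer written by
/// addInputImageToFloat32List. `face` selects a 3*width*height block;
/// `channel` is 0, 1, or 2 for R, G, B.
double rgbValue(
  Float32List buffer,
  int face,
  int x,
  int y,
  int channel, {
  int width = 112,
  int height = 112,
}) {
  final faceOffset = face * 3 * width * height;
  return buffer[faceOffset + 3 * (y * width + x) + channel];
}
```

This is the opposite of `createFloat32ListFromImageChannelsFirst`, which stores planar CHW data for the YOLO path; the two layouts are not interchangeable.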
+/// Function to warp an image [imageData] with an affine transformation using the estimated [transformationMatrix]. +/// +/// Returns the warped image in the specified width and height, in [Uint8List] RGBA format. +Future warpAffineToUint8List( + Image inputImage, + ByteData imgByteDataRgba, + List> transformationMatrix, { + required int width, + required int height, +}) async { + final Uint8List outputList = Uint8List(4 * width * height); + + if (width != 112 || height != 112) { + throw Exception( + 'Width and height must be 112, other transformations are not supported yet.', + ); + } + + final A = Matrix.fromList([ + [transformationMatrix[0][0], transformationMatrix[0][1]], + [transformationMatrix[1][0], transformationMatrix[1][1]], + ]); + final aInverse = A.inverse(); + // final aInverseMinus = aInverse * -1; + final B = Vector.fromList( + [transformationMatrix[0][2], transformationMatrix[1][2]], + ); + final b00 = B[0]; + final b10 = B[1]; + final a00Prime = aInverse[0][0]; + final a01Prime = aInverse[0][1]; + final a10Prime = aInverse[1][0]; + final a11Prime = aInverse[1][1]; + + for (int yTrans = 0; yTrans < height; ++yTrans) { + for (int xTrans = 0; xTrans < width; ++xTrans) { + // Perform inverse affine transformation (original implementation, intuitive but slow) + // final X = aInverse * (Vector.fromList([xTrans, yTrans]) - B); + // final X = aInverseMinus * (B - [xTrans, yTrans]); + // final xList = X.asFlattenedList; + // num xOrigin = xList[0]; + // num yOrigin = xList[1]; + + // Perform inverse affine transformation (fast implementation, less intuitive) + num xOrigin = (xTrans - b00) * a00Prime + (yTrans - b10) * a01Prime; + num yOrigin = (xTrans - b00) * a10Prime + (yTrans - b10) * a11Prime; + + // Clamp to image boundaries + xOrigin = xOrigin.clamp(0, inputImage.width - 1); + yOrigin = yOrigin.clamp(0, inputImage.height - 1); + + // Bilinear interpolation + final int x0 = xOrigin.floor(); + final int x1 = xOrigin.ceil(); + final int y0 = yOrigin.floor(); + final int y1 = yOrigin.ceil(); + + // Get the original pixels + final Color pixel1 = readPixelColor(inputImage, imgByteDataRgba, x0, y0); + final Color pixel2 = readPixelColor(inputImage, imgByteDataRgba, x1, y0); + final Color pixel3 = readPixelColor(inputImage, imgByteDataRgba, x0, y1); + final Color pixel4 = readPixelColor(inputImage, imgByteDataRgba, x1, y1); + + // Calculate the weights for each pixel + final fx = xOrigin - x0; + final fy = yOrigin - y0; + final fx1 = 1.0 - fx; + final fy1 = 1.0 - fy; + + // Calculate the weighted sum of pixels + final int r = bilinearInterpolation( + pixel1.red, + pixel2.red, + pixel3.red, + pixel4.red, + fx, + fy, + fx1, + fy1, + ); + final int g = bilinearInterpolation( + pixel1.green, + pixel2.green, + pixel3.green, + pixel4.green, + fx, + fy, + fx1, + fy1, + ); + final int b = bilinearInterpolation( + pixel1.blue, + pixel2.blue, + pixel3.blue, + pixel4.blue, + fx, + fy, + fx1, + fy1, + ); + + // Set the new pixel + outputList[4 * (yTrans * width + xTrans)] = r; + outputList[4 * (yTrans * width + xTrans) + 1] = g; + outputList[4 * (yTrans * width + xTrans) + 2] = b; + outputList[4 * (yTrans * width + xTrans) + 3] = 255; + } + } + + return outputList; +} + +/// Function to warp an image [imageData] with an affine transformation using the estimated [transformationMatrix]. +/// +/// Returns a [Num3DInputMatrix], potentially normalized (RGB) and ready to be used as input for a ML model. 
+Future<List<List<List<double>>>> warpAffineToMatrix(
+  Image inputImage,
+  ByteData imgByteDataRgba,
+  List<List<double>> transformationMatrix, {
+  required int width,
+  required int height,
+  bool normalize = true,
+}) async {
+  final List<List<List<double>>> outputMatrix = List.generate(
+    height,
+    (y) => List.generate(
+      width,
+      (_) => List.filled(3, 0.0),
+    ),
+  );
+  final double Function(num) pixelValue =
+      normalize ? normalizePixelRange2 : (num value) => value.toDouble();
+
+  if (width != 112 || height != 112) {
+    throw Exception(
+      'Width and height must be 112, other transformations are not supported yet.',
+    );
+  }
+
+  final A = Matrix.fromList([
+    [transformationMatrix[0][0], transformationMatrix[0][1]],
+    [transformationMatrix[1][0], transformationMatrix[1][1]],
+  ]);
+  final aInverse = A.inverse();
+  // final aInverseMinus = aInverse * -1;
+  final B = Vector.fromList(
+    [transformationMatrix[0][2], transformationMatrix[1][2]],
+  );
+  final b00 = B[0];
+  final b10 = B[1];
+  final a00Prime = aInverse[0][0];
+  final a01Prime = aInverse[0][1];
+  final a10Prime = aInverse[1][0];
+  final a11Prime = aInverse[1][1];
+
+  for (int yTrans = 0; yTrans < height; ++yTrans) {
+    for (int xTrans = 0; xTrans < width; ++xTrans) {
+      // Perform inverse affine transformation (original implementation, intuitive but slow)
+      // final X = aInverse * (Vector.fromList([xTrans, yTrans]) - B);
+      // final X = aInverseMinus * (B - [xTrans, yTrans]);
+      // final xList = X.asFlattenedList;
+      // num xOrigin = xList[0];
+      // num yOrigin = xList[1];
+
+      // Perform inverse affine transformation (fast implementation, less intuitive)
+      num xOrigin = (xTrans - b00) * a00Prime + (yTrans - b10) * a01Prime;
+      num yOrigin = (xTrans - b00) * a10Prime + (yTrans - b10) * a11Prime;
+
+      // Clamp to image boundaries
+      xOrigin = xOrigin.clamp(0, inputImage.width - 1);
+      yOrigin = yOrigin.clamp(0, inputImage.height - 1);
+
+      // Bilinear interpolation
+      final int x0 = xOrigin.floor();
+      final int x1 = xOrigin.ceil();
+      final int y0 = yOrigin.floor();
+      final int y1 = yOrigin.ceil();
+
+      // Get the original pixels
+      final Color pixel1 = readPixelColor(inputImage, imgByteDataRgba, x0, y0);
+      final Color pixel2 = readPixelColor(inputImage, imgByteDataRgba, x1, y0);
+      final Color pixel3 = readPixelColor(inputImage, imgByteDataRgba, x0, y1);
+      final Color pixel4 = readPixelColor(inputImage, imgByteDataRgba, x1, y1);
+
+      // Calculate the weights for each pixel
+      final fx = xOrigin - x0;
+      final fy = yOrigin - y0;
+      final fx1 = 1.0 - fx;
+      final fy1 = 1.0 - fy;
+
+      // Calculate the weighted sum of pixels
+      final int r = bilinearInterpolation(
+        pixel1.red,
+        pixel2.red,
+        pixel3.red,
+        pixel4.red,
+        fx,
+        fy,
+        fx1,
+        fy1,
+      );
+      final int g = bilinearInterpolation(
+        pixel1.green,
+        pixel2.green,
+        pixel3.green,
+        pixel4.green,
+        fx,
+        fy,
+        fx1,
+        fy1,
+      );
+      final int b = bilinearInterpolation(
+        pixel1.blue,
+        pixel2.blue,
+        pixel3.blue,
+        pixel4.blue,
+        fx,
+        fy,
+        fx1,
+        fy1,
+      );
+
+      // Set the new pixel
+      outputMatrix[yTrans][xTrans] = [
+        pixelValue(r),
+        pixelValue(g),
+        pixelValue(b),
+      ];
+    }
+  }
+
+  return outputMatrix;
+}
+
+/// Generates a face thumbnail from [imageData] and a [faceDetection].
+///
+/// Returns a [Uint8List] image, in png format.
+Future<Uint8List> generateFaceThumbnailFromData(
+  Uint8List imageData,
+  FaceDetectionRelative faceDetection,
+) async {
+  final Image image = await decodeImageFromData(imageData);
+
+  final Image faceThumbnail = await cropImage(
+    image,
+    x: (faceDetection.xMinBox * image.width).roundToDouble() - 20,
+    y: (faceDetection.yMinBox * image.height).roundToDouble() - 30,
+    width: (faceDetection.width * image.width).roundToDouble() + 40,
+    height: (faceDetection.height * image.height).roundToDouble() + 60,
+  );
+
+  return await encodeImageToUint8List(
+    faceThumbnail,
+    format: ImageByteFormat.png,
+  );
+}
+
+/// Generates face thumbnails from [imageData] for all [faceBoxes].
+///
+/// Returns a list of [Uint8List] images, in png format.
+Future<List<Uint8List>> generateFaceThumbnailsFromDataAndDetections(
+  Uint8List imageData,
+  List<FaceBox> faceBoxes,
+) async {
+  final Image image = await decodeImageFromData(imageData);
+  int i = 0;
+
+  try {
+    final List<Uint8List> faceThumbnails = [];
+
+    for (final faceBox in faceBoxes) {
+      final Image faceThumbnail = await cropImage(
+        image,
+        x: faceBox.x - faceBox.width / 2,
+        y: faceBox.y - faceBox.height / 2,
+        width: faceBox.width * 2,
+        height: faceBox.height * 2,
+      );
+      final Uint8List faceThumbnailPng = await encodeImageToUint8List(
+        faceThumbnail,
+        format: ImageByteFormat.png,
+      );
+      faceThumbnails.add(faceThumbnailPng);
+      i++;
+    }
+    return faceThumbnails;
+  } catch (e) {
+    log('[ImageMlUtils] Error generating face thumbnails: $e');
+    log('[ImageMlUtils] cropImage problematic input argument: ${faceBoxes[i]}');
+    return [];
+  }
+}
+
+/// Generates cropped and padded image data from [imageData] and a [faceBox].
+///
+/// The steps are:
+/// 1. Crop the image to the face bounding box
+/// 2. Resize this cropped image to a square that is half the BlazeFace input size
+/// 3. Pad the image to the BlazeFace input size
+///
+/// Note that [faceBox] is a list of the following values: [xMinBox, yMinBox, xMaxBox, yMaxBox].
+Future<Uint8List> cropAndPadFaceData(
+  Uint8List imageData,
+  List<double> faceBox,
+) async {
+  final Image image = await decodeImageFromData(imageData);
+
+  final Image faceCrop = await cropImage(
+    image,
+    x: (faceBox[0] * image.width),
+    y: (faceBox[1] * image.height),
+    width: ((faceBox[2] - faceBox[0]) * image.width),
+    height: ((faceBox[3] - faceBox[1]) * image.height),
+    maxSize: const Size(128, 128),
+    minSize: const Size(128, 128),
+  );
+
+  final Image facePadded = await addPaddingToImage(
+    faceCrop,
+    0.5,
+  );
+
+  return await encodeImageToUint8List(facePadded);
+}
+
+int bilinearInterpolation(
+  num val1,
+  num val2,
+  num val3,
+  num val4,
+  num fx,
+  num fy,
+  num fx1,
+  num fy1,
+) {
+  return (val1 * fx1 * fy1 + val2 * fx * fy1 + val3 * fx1 * fy + val4 * fx * fy)
+      .round();
+}
+
+List<double> getAlignedFaceBox(AlignmentResult alignment) {
+  final List<double> box = [
+    // [xMinBox, yMinBox, xMaxBox, yMaxBox]
+    alignment.center[0] - alignment.size / 2,
+    alignment.center[1] - alignment.size / 2,
+    alignment.center[0] + alignment.size / 2,
+    alignment.center[1] + alignment.size / 2,
+  ];
+  box.roundBoxToDouble();
+  return box;
+}
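On `bilinearInterpolation` above: the four values are the neighboring pixels, `fx`/`fy` are the fractional offsets of the sample point within the pixel cell, and `fx1`/`fy1` their complements. A quick sanity check of the weighting, inlining the same formula (illustrative only, not part of this diff):

```dart
void main() {
  // Sampling at the exact center of four corner values 0, 100, 100, 200
  // should give their average, 100.
  const fx = 0.5, fy = 0.5, fx1 = 0.5, fy1 = 0.5;
  final v = (0 * fx1 * fy1 + 100 * fx * fy1 + 100 * fx1 * fy + 200 * fx * fy)
      .round();
  assert(v == 100);
}
```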
+
+/// Returns an enlarged version of the [box] by a factor of [factor].
+/// The [box] is in absolute coordinates: [xMinBox, yMinBox, xMaxBox, yMaxBox].
+List<double> getEnlargedAbsoluteBox(List<double> box, [double factor = 2]) {
+  final boxCopy = List<double>.from(box, growable: false);
+  // The four values of the box in order are: [xMinBox, yMinBox, xMaxBox, yMaxBox].
+
+  final width = boxCopy[2] - boxCopy[0];
+  final height = boxCopy[3] - boxCopy[1];
+
+  boxCopy[0] -= width * (factor - 1) / 2;
+  boxCopy[1] -= height * (factor - 1) / 2;
+  boxCopy[2] += width * (factor - 1) / 2;
+  boxCopy[3] += height * (factor - 1) / 2;
+
+  return boxCopy;
+}
diff --git a/mobile/lib/utils/local_settings.dart b/mobile/lib/utils/local_settings.dart
index 2f277c80b..cba446065 100644
--- a/mobile/lib/utils/local_settings.dart
+++ b/mobile/lib/utils/local_settings.dart
@@ -14,6 +14,8 @@ class LocalSettings {
   static const kCollectionSortPref = "collection_sort_pref";
   static const kPhotoGridSize = "photo_grid_size";
   static const kEnableMagicSearch = "enable_magic_search";
+  static const kEnableFaceIndexing = "enable_face_indexing";
+  static const kEnableFaceClustering = "enable_face_clustering";
   static const kRateUsShownCount = "rate_us_shown_count";
   static const kRateUsPromptThreshold = 2;

@@ -44,9 +46,10 @@ class LocalSettings {
   }

   bool hasEnabledMagicSearch() {
-    if (_prefs.containsKey(kEnableMagicSearch)) {
-      return _prefs.getBool(kEnableMagicSearch)!;
-    }
+    // TODO: change this back by uncommenting the lines below
+    // if (_prefs.containsKey(kEnableMagicSearch)) {
+    //   return _prefs.getBool(kEnableMagicSearch)!;
+    // }
     return false;
   }

@@ -69,4 +72,22 @@ class LocalSettings {
   bool shouldPromptToRateUs() {
     return getRateUsShownCount() < kRateUsPromptThreshold;
   }
+
+  bool get isFaceIndexingEnabled =>
+      _prefs.getBool(kEnableFaceIndexing) ?? false;
+
+  bool get isFaceClusteringEnabled =>
+      _prefs.getBool(kEnableFaceClustering) ?? false;
+
+  /// toggleFaceIndexing toggles the face indexing setting and returns the new value
+  Future<bool> toggleFaceIndexing() async {
+    await _prefs.setBool(kEnableFaceIndexing, !isFaceIndexingEnabled);
+    return isFaceIndexingEnabled;
+  }
+
+  /// toggleFaceClustering toggles the face clustering setting and returns the new value
+  Future<bool> toggleFaceClustering() async {
+    await _prefs.setBool(kEnableFaceClustering, !isFaceClusteringEnabled);
+    return isFaceClusteringEnabled;
+  }
 }
diff --git a/mobile/lib/utils/thumbnail_util.dart b/mobile/lib/utils/thumbnail_util.dart
index dc2167632..db7648b92 100644
--- a/mobile/lib/utils/thumbnail_util.dart
+++ b/mobile/lib/utils/thumbnail_util.dart
@@ -217,3 +217,11 @@ File cachedThumbnailPath(EnteFile file) {
     thumbnailCacheDirectory + "/" + file.uploadedFileID.toString(),
   );
 }
+
+File cachedFaceCropPath(String faceID) {
+  final thumbnailCacheDirectory =
+      Configuration.instance.getThumbnailCacheDirectory();
+  return File(
+    thumbnailCacheDirectory + "/" + faceID,
+  );
+}
diff --git a/mobile/pubspec.lock b/mobile/pubspec.lock
index 7298e7134..9f8ee1b28 100644
--- a/mobile/pubspec.lock
+++ b/mobile/pubspec.lock
@@ -686,6 +686,14 @@ packages:
       url: "https://pub.dev"
     source: hosted
     version: "5.8.0"
+  flutter_isolate:
+    dependency: "direct main"
+    description:
+      name: flutter_isolate
+      sha256: "994ddec596da4ca12ca52154fd59404077584643eb7e3f1008a55fda9fe0b76b"
+      url: "https://pub.dev"
+    source: hosted
+    version: "2.0.4"
   flutter_launcher_icons:
     dependency: "direct main"
     description:
@@ -1336,6 +1344,14 @@ packages:
       url: "https://pub.dev"
     source: hosted
     version: "1.0.4"
+  ml_linalg:
+    dependency: "direct main"
+    description:
+      name: ml_linalg
+      sha256: "36b658c6619c4e1b47f8e2edd100308fd75a5b71d1e6359d8be0cbe06214b441"
+      url: "https://pub.dev"
+    source: hosted
+    version: "13.11.31"
   modal_bottom_sheet:
     dependency: "direct main"
     description:
@@ -1604,7 +1620,7 @@ packages:
     source: hosted
     version: "1.0.1"
   pool:
-    dependency: 
transitive + dependency: "direct main" description: name: pool sha256: "20fe868b6314b322ea036ba325e6fc0711a22948856475e2c2b6306e8ab39c2a" @@ -1627,6 +1643,14 @@ packages: url: "https://pub.dev" source: hosted version: "2.1.0" + protobuf: + dependency: "direct main" + description: + name: protobuf + sha256: "68645b24e0716782e58948f8467fd42a880f255096a821f9e7d0ec625b00c84d" + url: "https://pub.dev" + source: hosted + version: "3.1.0" provider: dependency: "direct main" description: @@ -1891,6 +1915,14 @@ packages: url: "https://pub.dev" source: hosted version: "1.0.4" + simple_cluster: + dependency: "direct main" + description: + name: simple_cluster + sha256: "64d6b7d60d641299ad8c3f012417c711532792c1bc61ac6a7f52b942cdba65da" + url: "https://pub.dev" + source: hosted + version: "0.3.0" sky_engine: dependency: transitive description: flutter @@ -2057,7 +2089,7 @@ packages: source: hosted version: "19.4.56" synchronized: - dependency: transitive + dependency: "direct main" description: name: synchronized sha256: "539ef412b170d65ecdafd780f924e5be3f60032a1128df156adad6c5b373d558" @@ -2096,6 +2128,14 @@ packages: url: "https://pub.dev" source: hosted version: "0.5.3" + tflite_flutter: + dependency: "direct main" + description: + name: tflite_flutter + sha256: ffb8651fdb116ab0131d6dc47ff73883e0f634ad1ab12bb2852eef1bbeab4a6a + url: "https://pub.dev" + source: hosted + version: "0.10.4" time: dependency: transitive description: diff --git a/mobile/pubspec.yaml b/mobile/pubspec.yaml index 578331c8d..4bc1d80a8 100644 --- a/mobile/pubspec.yaml +++ b/mobile/pubspec.yaml @@ -113,6 +113,7 @@ dependencies: logging: ^1.0.1 lottie: ^1.2.2 media_extension: ^1.0.1 + ml_linalg: ^13.11.31 media_kit: ^1.1.10+1 media_kit_libs_video: ^1.0.4 media_kit_video: ^1.2.4 @@ -147,6 +148,7 @@ dependencies: sentry_flutter: ^7.9.0 share_plus: ^4.0.10 shared_preferences: ^2.0.5 + simple_cluster: ^0.3.0 sqflite: ^2.3.0 sqflite_migration: ^0.3.0 sqlite3: ^2.1.0 @@ -155,11 +157,11 @@ dependencies: styled_text: ^7.0.0 syncfusion_flutter_core: ^19.2.49 syncfusion_flutter_sliders: ^19.2.49 - # tflite_flutter: ^0.9.0 - # tflite_flutter_helper: - # git: - # url: https://github.com/pnyompen/tflite_flutter_helper.git - # ref: 43e87d4b9627539266dc20250beb35bf36320dce + synchronized: ^3.1.0 + tflite_flutter: ^0.10.1 + # tflite_flutter_helper: + # git: + # url: https://github.com/pnyompen/tflite_flutter_helper.git # Fixes https://github.com/am15h/tflite_flutter_helper/issues/57 tuple: ^2.0.0 uni_links: ^0.5.1 url_launcher: ^6.0.3 @@ -175,6 +177,9 @@ dependencies: wallpaper_manager_flutter: ^0.0.2 wechat_assets_picker: ^8.6.3 widgets_to_image: ^0.0.2 + flutter_isolate: ^2.0.4 + protobuf: ^3.1.0 + pool: ^1.5.1 workmanager: ^0.5.2 dependency_overrides: @@ -238,8 +243,10 @@ flutter: assets: - assets/ - assets/models/cocossd/ + - assets/models/mobilefacenet/ - assets/models/mobilenet/ - assets/models/scenes/ + - assets/models/yolov5face/ - assets/models/clip/ fonts: - family: Inter