diff --git a/.gitignore b/.gitignore index 8699b46ee..0901b55d6 100644 --- a/.gitignore +++ b/.gitignore @@ -1,8 +1,6 @@ -# Let folks use their custom .vscode settings +# Let folks use their custom editor settings .vscode +.idea # macOS .DS_Store -.idea -.ente.authenticator.db -.ente.offline_authenticator.db diff --git a/auth/ios/Podfile.lock b/auth/ios/Podfile.lock index 3db46ac14..814568fba 100644 --- a/auth/ios/Podfile.lock +++ b/auth/ios/Podfile.lock @@ -85,20 +85,16 @@ PODS: - SDWebImage (5.19.2): - SDWebImage/Core (= 5.19.2) - SDWebImage/Core (5.19.2) - - Sentry/HybridSDK (8.21.0): - - SentryPrivate (= 8.21.0) - - sentry_flutter (0.0.1): + - Sentry/HybridSDK (8.25.0) + - sentry_flutter (7.20.1): - Flutter - FlutterMacOS - - Sentry/HybridSDK (= 8.21.0) - - SentryPrivate (8.21.0) + - Sentry/HybridSDK (= 8.25.0) - share_plus (0.0.1): - Flutter - shared_preferences_foundation (0.0.1): - Flutter - FlutterMacOS - - smart_auth (0.0.1): - - Flutter - sodium_libs (2.2.1): - Flutter - sqflite (0.0.3): @@ -115,7 +111,7 @@ PODS: - sqlite3/common - sqlite3_flutter_libs (0.0.1): - Flutter - - sqlite3 (~> 3.45.1) + - "sqlite3 (~> 3.45.3+1)" - sqlite3/fts5 - sqlite3/perf-threadsafe - sqlite3/rtree @@ -148,7 +144,6 @@ DEPENDENCIES: - sentry_flutter (from `.symlinks/plugins/sentry_flutter/ios`) - share_plus (from `.symlinks/plugins/share_plus/ios`) - shared_preferences_foundation (from `.symlinks/plugins/shared_preferences_foundation/darwin`) - - smart_auth (from `.symlinks/plugins/smart_auth/ios`) - sodium_libs (from `.symlinks/plugins/sodium_libs/ios`) - sqflite (from `.symlinks/plugins/sqflite/darwin`) - sqlite3_flutter_libs (from `.symlinks/plugins/sqlite3_flutter_libs/ios`) @@ -163,7 +158,6 @@ SPEC REPOS: - ReachabilitySwift - SDWebImage - Sentry - - SentryPrivate - sqlite3 - SwiftyGif - Toast @@ -215,8 +209,6 @@ EXTERNAL SOURCES: :path: ".symlinks/plugins/share_plus/ios" shared_preferences_foundation: :path: ".symlinks/plugins/shared_preferences_foundation/darwin" - smart_auth: - :path: ".symlinks/plugins/smart_auth/ios" sodium_libs: :path: ".symlinks/plugins/sodium_libs/ios" sqflite: @@ -236,13 +228,13 @@ SPEC CHECKSUMS: file_saver: 503e386464dbe118f630e17b4c2e1190fa0cf808 fk_user_agent: 1f47ec39291e8372b1d692b50084b0d54103c545 Flutter: e0871f40cf51350855a761d2e70bf5af5b9b5de7 - flutter_email_sender: 02d7443217d8c41483223627972bfdc09f74276b + flutter_email_sender: 10a22605f92809a11ef52b2f412db806c6082d40 flutter_inappwebview_ios: 97215cf7d4677db55df76782dbd2930c5e1c1ea0 flutter_local_authentication: 1172a4dd88f6306dadce067454e2c4caf07977bb flutter_local_notifications: 4cde75091f6327eb8517fa068a0a5950212d2086 flutter_native_splash: edf599c81f74d093a4daf8e17bd7a018854bc778 flutter_secure_storage: 23fc622d89d073675f2eaa109381aefbcf5a49be - fluttertoast: 31b00dabfa7fb7bacd9e7dbee580d7a2ff4bf265 + fluttertoast: 9f2f8e81bb5ce18facb9748d7855bf5a756fe3db local_auth_darwin: c7e464000a6a89e952235699e32b329457608d98 move_to_background: 39a5b79b26d577b0372cbe8a8c55e7aa9fcd3a2d MTBBarcodeScanner: f453b33c4b7dfe545d8c6484ed744d55671788cb @@ -253,16 +245,14 @@ SPEC CHECKSUMS: qr_code_scanner: bb67d64904c3b9658ada8c402e8b4d406d5d796e ReachabilitySwift: 2128f3a8c9107e1ad33574c6e58e8285d460b149 SDWebImage: dfe95b2466a9823cf9f0c6d01217c06550d7b29a - Sentry: ebc12276bd17613a114ab359074096b6b3725203 - sentry_flutter: dff1df05dc39c83d04f9330b36360fc374574c5e - SentryPrivate: d651efb234cf385ec9a1cdd3eff94b5e78a0e0fe + Sentry: cd86fc55628f5b7c572cabe66cc8f95a9d2f165a + sentry_flutter: 
4cb24c1055c556d7b27262ab2e179d1e5a0b9b0c share_plus: c3fef564749587fc939ef86ffb283ceac0baf9f5 shared_preferences_foundation: b4c3b4cddf1c21f02770737f147a3f5da9d39695 - smart_auth: 4bedbc118723912d0e45a07e8ab34039c19e04f2 sodium_libs: 1faae17af662384acbd13e41867a0008cd2e2318 sqflite: 673a0e54cc04b7d6dba8d24fb8095b31c3a99eec sqlite3: 02d1f07eaaa01f80a1c16b4b31dfcbb3345ee01a - sqlite3_flutter_libs: af0e8fe9bce48abddd1ffdbbf839db0302d72d80 + sqlite3_flutter_libs: 9bfe005308998aeca155330bbc2ea6dddf834a3b SwiftyGif: 706c60cf65fa2bc5ee0313beece843c8eb8194d4 Toast: 1f5ea13423a1e6674c4abdac5be53587ae481c4e url_launcher_ios: 6116280ddcfe98ab8820085d8d76ae7449447586 diff --git a/auth/lib/l10n/arb/app_ru.arb b/auth/lib/l10n/arb/app_ru.arb index ca98611ee..42571a166 100644 --- a/auth/lib/l10n/arb/app_ru.arb +++ b/auth/lib/l10n/arb/app_ru.arb @@ -188,6 +188,8 @@ "recoveryKeySaveDescription": "Мы не храним этот ключ, пожалуйста, сохраните этот ключ в безопасном месте.", "doThisLater": "Сделать позже", "saveKey": "Сохранить ключ", + "save": "Сохранить", + "send": "Отправить", "back": "Вернуться", "createAccount": "Создать аккаунт", "passwordStrength": "Мощность пароля: {passwordStrengthValue}", @@ -394,5 +396,13 @@ "signOutOtherDevices": "Выйти из других устройств", "doNotSignOut": "Не выходить", "hearUsWhereTitle": "Как вы узнали о Ente? (необязательно)", - "hearUsExplanation": "Будет полезно, если вы укажете, где нашли нас, так как мы не отслеживаем установки приложения" + "hearUsExplanation": "Будет полезно, если вы укажете, где нашли нас, так как мы не отслеживаем установки приложения", + "waitingForVerification": "Ожидание подтверждения...", + "developerSettingsWarning": "Вы уверены, что хотите изменить настройки разработчика?", + "developerSettings": "Настройки разработчика", + "serverEndpoint": "Конечная точка сервера", + "invalidEndpoint": "Неверная конечная точка", + "invalidEndpointMessage": "Извините, введенная вами конечная точка неверна. 
Пожалуйста, введите корректную конечную точку и повторите попытку.", + "endpointUpdatedMessage": "Конечная точка успешно обновлена", + "customEndpoint": "Подключено к {endpoint}" } \ No newline at end of file diff --git a/auth/macos/Podfile.lock b/auth/macos/Podfile.lock index 87c5556d2..92d05104e 100644 --- a/auth/macos/Podfile.lock +++ b/auth/macos/Podfile.lock @@ -29,20 +29,16 @@ PODS: - ReachabilitySwift (5.2.2) - screen_retriever (0.0.1): - FlutterMacOS - - Sentry/HybridSDK (8.21.0): - - SentryPrivate (= 8.21.0) - - sentry_flutter (0.0.1): + - Sentry/HybridSDK (8.25.0) + - sentry_flutter (7.20.1): - Flutter - FlutterMacOS - - Sentry/HybridSDK (= 8.21.0) - - SentryPrivate (8.21.0) + - Sentry/HybridSDK (= 8.25.0) - share_plus (0.0.1): - FlutterMacOS - shared_preferences_foundation (0.0.1): - Flutter - FlutterMacOS - - smart_auth (0.0.1): - - FlutterMacOS - sodium_libs (2.2.1): - FlutterMacOS - sqflite (0.0.3): @@ -59,7 +55,7 @@ PODS: - sqlite3/common - sqlite3_flutter_libs (0.0.1): - FlutterMacOS - - sqlite3 (~> 3.45.1) + - "sqlite3 (~> 3.45.3+1)" - sqlite3/fts5 - sqlite3/perf-threadsafe - sqlite3/rtree @@ -87,7 +83,6 @@ DEPENDENCIES: - sentry_flutter (from `Flutter/ephemeral/.symlinks/plugins/sentry_flutter/macos`) - share_plus (from `Flutter/ephemeral/.symlinks/plugins/share_plus/macos`) - shared_preferences_foundation (from `Flutter/ephemeral/.symlinks/plugins/shared_preferences_foundation/darwin`) - - smart_auth (from `Flutter/ephemeral/.symlinks/plugins/smart_auth/macos`) - sodium_libs (from `Flutter/ephemeral/.symlinks/plugins/sodium_libs/macos`) - sqflite (from `Flutter/ephemeral/.symlinks/plugins/sqflite/darwin`) - sqlite3_flutter_libs (from `Flutter/ephemeral/.symlinks/plugins/sqlite3_flutter_libs/macos`) @@ -100,7 +95,6 @@ SPEC REPOS: - OrderedSet - ReachabilitySwift - Sentry - - SentryPrivate - sqlite3 EXTERNAL SOURCES: @@ -136,8 +130,6 @@ EXTERNAL SOURCES: :path: Flutter/ephemeral/.symlinks/plugins/share_plus/macos shared_preferences_foundation: :path: Flutter/ephemeral/.symlinks/plugins/shared_preferences_foundation/darwin - smart_auth: - :path: Flutter/ephemeral/.symlinks/plugins/smart_auth/macos sodium_libs: :path: Flutter/ephemeral/.symlinks/plugins/sodium_libs/macos sqflite: @@ -167,16 +159,14 @@ SPEC CHECKSUMS: path_provider_foundation: 3784922295ac71e43754bd15e0653ccfd36a147c ReachabilitySwift: 2128f3a8c9107e1ad33574c6e58e8285d460b149 screen_retriever: 59634572a57080243dd1bf715e55b6c54f241a38 - Sentry: ebc12276bd17613a114ab359074096b6b3725203 - sentry_flutter: dff1df05dc39c83d04f9330b36360fc374574c5e - SentryPrivate: d651efb234cf385ec9a1cdd3eff94b5e78a0e0fe + Sentry: cd86fc55628f5b7c572cabe66cc8f95a9d2f165a + sentry_flutter: 4cb24c1055c556d7b27262ab2e179d1e5a0b9b0c share_plus: 76dd39142738f7a68dd57b05093b5e8193f220f7 shared_preferences_foundation: b4c3b4cddf1c21f02770737f147a3f5da9d39695 - smart_auth: b38e3ab4bfe089eacb1e233aca1a2340f96c28e9 sodium_libs: d39bd76697736cb11ce4a0be73b9b4bc64466d6f sqflite: 673a0e54cc04b7d6dba8d24fb8095b31c3a99eec sqlite3: 02d1f07eaaa01f80a1c16b4b31dfcbb3345ee01a - sqlite3_flutter_libs: 06a05802529659a272beac4ee1350bfec294f386 + sqlite3_flutter_libs: 8d204ef443cf0d5c1c8b058044eab53f3943a9c5 tray_manager: 9064e219c56d75c476e46b9a21182087930baf90 url_launcher_macos: d2691c7dd33ed713bf3544850a623080ec693d95 window_manager: 3a1844359a6295ab1e47659b1a777e36773cd6e8 diff --git a/auth/pubspec.yaml b/auth/pubspec.yaml index c4aa6503f..d28f1c27d 100644 --- a/auth/pubspec.yaml +++ b/auth/pubspec.yaml @@ -1,6 +1,6 @@ name: ente_auth description: 
ente two-factor authenticator -version: 2.0.58+258 +version: 3.0.2+302 publish_to: none environment: diff --git a/desktop/package.json b/desktop/package.json index 2f2cc8f16..236dd5592 100644 --- a/desktop/package.json +++ b/desktop/package.json @@ -14,7 +14,7 @@ "build:ci": "yarn build-renderer && tsc", "build:quick": "yarn build-renderer && yarn build-main:quick", "dev": "concurrently --kill-others --success first --names 'main,rndr' \"yarn dev-main\" \"yarn dev-renderer\"", - "dev-main": "tsc && electron app/main.js", + "dev-main": "tsc && electron .", "dev-renderer": "cd ../web && yarn install && yarn dev:photos", "postinstall": "electron-builder install-app-deps", "lint": "yarn prettier --check --log-level warn . && eslint --ext .ts src && yarn tsc", diff --git a/desktop/src/main.ts b/desktop/src/main.ts index 7ffbdeced..463774dc2 100644 --- a/desktop/src/main.ts +++ b/desktop/src/main.ts @@ -241,7 +241,7 @@ const uniqueSavePath = (dirPath: string, fileName: string) => { * * @param webContents The renderer to configure. */ -export const allowExternalLinks = (webContents: WebContents) => { +export const allowExternalLinks = (webContents: WebContents) => // By default, if the user were to open a link, say // https://github.com/ente-io/ente/discussions, then it would open a _new_ // BrowserWindow within our app. @@ -253,13 +253,37 @@ // Returning `action` "deny" accomplishes this. webContents.setWindowOpenHandler(({ url }) => { if (!url.startsWith(rendererURL)) { + // This does not work in Ubuntu currently: mailto links seem to just + // get ignored, and HTTP links open in the text editor instead of in + // the browser. + // https://github.com/electron/electron/issues/31485 void shell.openExternal(url); return { action: "deny" }; } else { return { action: "allow" }; } }); -}; + +/** + * Allow uploading to arbitrary S3 buckets. + * + * The files in the desktop app are served over the ente:// protocol. During + * testing or self-hosting, we might be using an S3 bucket that does not allow + * whitelisting a custom URI scheme. To avoid requiring the bucket to set an + * "Access-Control-Allow-Origin: *" or do an echo-back of `Origin`, we add a + * workaround here instead, intercepting the ACAO header and allowing `*`. + */ +export const allowAllCORSOrigins = (webContents: WebContents) => + webContents.session.webRequest.onHeadersReceived( + ({ responseHeaders }, callback) => { + const headers: NonNullable<typeof responseHeaders> = {}; + for (const [key, value] of Object.entries(responseHeaders ?? {})) + if (key.toLowerCase() != "access-control-allow-origin") + headers[key] = value; + headers["Access-Control-Allow-Origin"] = ["*"]; + callback({ responseHeaders: headers }); + }, + ); /** * Add an icon for our app in the system tray. * @@ -291,32 +315,18 @@ const setupTrayItem = (mainWindow: BrowserWindow) => { /** * Older versions of our app used to maintain a cache dir using the main - * process. This has been removed in favor of cache on the web layer. + * process. This has been removed in favor of cache on the web layer. Delete the + * old cache dir if it exists. * - * Delete the old cache dir if it exists. - * - * This will happen in two phases. The cache had three subdirectories: - * - * - Two of them, "thumbs" and "files", will be removed now (v1.7.0, May 2024). - * - * - The third one, "face-crops" will be removed once we finish the face search - * changes. See: [Note: Legacy face crops].
- * - * This migration code can be removed after some time once most people have - * upgraded to newer versions. + * Added May 2024, v1.7.0. This migration code can be removed after some time + * once most people have upgraded to newer versions. */ const deleteLegacyDiskCacheDirIfExists = async () => { - const removeIfExists = async (dirPath: string) => { - if (existsSync(dirPath)) { - log.info(`Removing legacy disk cache from ${dirPath}`); - await fs.rm(dirPath, { recursive: true }); - } - }; // [Note: Getting the cache path] // // The existing code was passing "cache" as a parameter to getPath. // - // However, "cache" is not a valid parameter to getPath. It works! (for + // However, "cache" is not a valid parameter to getPath. It works (for // example, on macOS I get `~/Library/Caches`), but it is intentionally not // documented as part of the public API: // @@ -329,8 +339,8 @@ const deleteLegacyDiskCacheDirIfExists = async () => { // @ts-expect-error "cache" works but is not part of the public API. const cacheDir = path.join(app.getPath("cache"), "ente"); if (existsSync(cacheDir)) { - await removeIfExists(path.join(cacheDir, "thumbs")); - await removeIfExists(path.join(cacheDir, "files")); + log.info(`Removing legacy disk cache from ${cacheDir}`); + await fs.rm(cacheDir, { recursive: true }); } }; @@ -390,8 +400,10 @@ const main = () => { registerStreamProtocol(); // Configure the renderer's environment. - setDownloadPath(mainWindow.webContents); - allowExternalLinks(mainWindow.webContents); + const webContents = mainWindow.webContents; + setDownloadPath(webContents); + allowExternalLinks(webContents); + allowAllCORSOrigins(webContents); // Start loading the renderer. void mainWindow.loadURL(rendererURL); diff --git a/desktop/src/main/ipc.ts b/desktop/src/main/ipc.ts index e74d5e9d2..6e7df7cde 100644 --- a/desktop/src/main/ipc.ts +++ b/desktop/src/main/ipc.ts @@ -24,7 +24,6 @@ import { updateOnNextRestart, } from "./services/app-update"; import { - legacyFaceCrop, openDirectory, openLogDirectory, selectDirectory, @@ -43,10 +42,10 @@ import { import { convertToJPEG, generateImageThumbnail } from "./services/image"; import { logout } from "./services/logout"; import { - clipImageEmbedding, - clipTextEmbeddingIfAvailable, + computeCLIPImageEmbedding, + computeCLIPTextEmbeddingIfAvailable, } from "./services/ml-clip"; -import { detectFaces, faceEmbeddings } from "./services/ml-face"; +import { computeFaceEmbeddings, detectFaces } from "./services/ml-face"; import { encryptionKey, saveEncryptionKey } from "./services/store"; import { clearPendingUploads, @@ -170,24 +169,22 @@ export const attachIPCHandlers = () => { // - ML - ipcMain.handle("clipImageEmbedding", (_, jpegImageData: Uint8Array) => - clipImageEmbedding(jpegImageData), + ipcMain.handle( + "computeCLIPImageEmbedding", + (_, jpegImageData: Uint8Array) => + computeCLIPImageEmbedding(jpegImageData), ); - ipcMain.handle("clipTextEmbeddingIfAvailable", (_, text: string) => - clipTextEmbeddingIfAvailable(text), + ipcMain.handle("computeCLIPTextEmbeddingIfAvailable", (_, text: string) => + computeCLIPTextEmbeddingIfAvailable(text), ); ipcMain.handle("detectFaces", (_, input: Float32Array) => detectFaces(input), ); - ipcMain.handle("faceEmbeddings", (_, input: Float32Array) => - faceEmbeddings(input), - ); - - ipcMain.handle("legacyFaceCrop", (_, faceID: string) => - legacyFaceCrop(faceID), + ipcMain.handle("computeFaceEmbeddings", (_, input: Float32Array) => + computeFaceEmbeddings(input), ); // - Upload diff --git 
a/desktop/src/main/log.ts b/desktop/src/main/log.ts index d1d65f8ae..9718dfea5 100644 --- a/desktop/src/main/log.ts +++ b/desktop/src/main/log.ts @@ -70,8 +70,9 @@ const logInfo = (...params: unknown[]) => { const message = params .map((p) => (typeof p == "string" ? p : util.inspect(p))) .join(" "); - log.info(`[main] ${message}`); - if (isDev) console.log(`[info] ${message}`); + const m = `[info] ${message}`; + if (isDev) console.log(m); + log.info(`[main] ${m}`); }; const logDebug = (param: () => unknown) => { diff --git a/desktop/src/main/services/app-update.ts b/desktop/src/main/services/app-update.ts index 6e3890e16..8b2d07a49 100644 --- a/desktop/src/main/services/app-update.ts +++ b/desktop/src/main/services/app-update.ts @@ -163,7 +163,7 @@ const checkForUpdatesAndNotify = async (mainWindow: BrowserWindow) => { }; /** - * Return the version of the desktop app + * Return the version of the desktop app. * * The return value is of the form `v1.2.3`. */ diff --git a/desktop/src/main/services/dir.ts b/desktop/src/main/services/dir.ts index d97cad6fb..4b1f748fe 100644 --- a/desktop/src/main/services/dir.ts +++ b/desktop/src/main/services/dir.ts @@ -1,7 +1,5 @@ import { shell } from "electron/common"; import { app, dialog } from "electron/main"; -import { existsSync } from "fs"; -import fs from "node:fs/promises"; import path from "node:path"; import { posixPath } from "../utils/electron"; @@ -53,14 +51,6 @@ export const openLogDirectory = () => openDirectory(logDirectoryPath()); * "userData" directory. This is the **primary** place applications are meant to * store user's data, e.g. various configuration files and saved state. * - * During development, our app name is "Electron", so this'd be, for example, - * `~/Library/Application Support/Electron` if we run using `yarn dev`. For the - * packaged production app, our app name is "ente", so this would be: - * - * - Windows: `%APPDATA%\ente`, e.g. `C:\Users\<username>\AppData\Local\ente` - * - Linux: `~/.config/ente` - * - macOS: `~/Library/Application Support/ente` - * * Note that Chromium also stores the browser state, e.g. localStorage or disk * caches, in userData. * @@ -73,21 +63,7 @@ export const openLogDirectory = () => openDirectory(logDirectoryPath()); * "ente.log", it can be found at: * * - macOS: ~/Library/Logs/ente/ente.log (production) - * - macOS: ~/Library/Logs/Electron/ente.log (dev) * - Linux: ~/.config/ente/logs/ente.log * - Windows: %USERPROFILE%\AppData\Roaming\ente\logs\ente.log */ const logDirectoryPath = () => app.getPath("logs"); - -/** - * See: [Note: Legacy face crops] - */ -export const legacyFaceCrop = async ( - faceID: string, -): Promise<Uint8Array | undefined> => { - // See: [Note: Getting the cache path] - // @ts-expect-error "cache" works but is not part of the public API. - const cacheDir = path.join(app.getPath("cache"), "ente"); - const filePath = path.join(cacheDir, "face-crops", faceID); - return existsSync(filePath) ?
await fs.readFile(filePath) : undefined; -}; diff --git a/desktop/src/main/services/image.ts b/desktop/src/main/services/image.ts index c07b051a1..fca4628b6 100644 --- a/desktop/src/main/services/image.ts +++ b/desktop/src/main/services/image.ts @@ -3,7 +3,6 @@ import fs from "node:fs/promises"; import path from "node:path"; import { CustomErrorMessage, type ZipItem } from "../../types/ipc"; -import log from "../log"; import { execAsync, isDev } from "../utils/electron"; import { deleteTempFileIgnoringErrors, @@ -93,9 +92,6 @@ export const generateImageThumbnail = async ( let thumbnail: Uint8Array; do { await execAsync(command); - // TODO(MR): release 1.7 - // TODO(MR): imagemagick debugging. Remove me after verifying logs. - log.info(`Generated thumbnail using ${command.join(" ")}`); thumbnail = new Uint8Array(await fs.readFile(outputFilePath)); quality -= 10; command = generateImageThumbnailCommand( diff --git a/desktop/src/main/services/ml-clip.ts b/desktop/src/main/services/ml-clip.ts index e3dd99204..cea1574e0 100644 --- a/desktop/src/main/services/ml-clip.ts +++ b/desktop/src/main/services/ml-clip.ts @@ -11,7 +11,7 @@ import * as ort from "onnxruntime-node"; import Tokenizer from "../../thirdparty/clip-bpe-ts/mod"; import log from "../log"; import { writeStream } from "../stream"; -import { ensure } from "../utils/common"; +import { ensure, wait } from "../utils/common"; import { deleteTempFile, makeTempFilePath } from "../utils/temp"; import { makeCachedInferenceSession } from "./ml"; @@ -20,7 +20,7 @@ const cachedCLIPImageSession = makeCachedInferenceSession( 351468764 /* 335.2 MB */, ); -export const clipImageEmbedding = async (jpegImageData: Uint8Array) => { +export const computeCLIPImageEmbedding = async (jpegImageData: Uint8Array) => { const tempFilePath = await makeTempFilePath(); const imageStream = new Response(jpegImageData.buffer).body; await writeStream(tempFilePath, ensure(imageStream)); @@ -42,7 +42,7 @@ const clipImageEmbedding_ = async (jpegFilePath: string) => { const results = await session.run(feeds); log.debug( () => - `onnx/clip image embedding took ${Date.now() - t1} ms (prep: ${t2 - t1} ms, inference: ${Date.now() - t2} ms)`, + `ONNX/CLIP image embedding took ${Date.now() - t1} ms (prep: ${t2 - t1} ms, inference: ${Date.now() - t2} ms)`, ); /* Need these model specific casts to type the result */ const imageEmbedding = ensure(results.output).data as Float32Array; @@ -140,21 +140,23 @@ const getTokenizer = () => { return _tokenizer; }; -export const clipTextEmbeddingIfAvailable = async (text: string) => { - const sessionOrStatus = await Promise.race([ +export const computeCLIPTextEmbeddingIfAvailable = async (text: string) => { + const sessionOrSkip = await Promise.race([ cachedCLIPTextSession(), - "downloading-model", + // Wait for a tick to get the session promise to resolve the first time + // this code runs on each app start (and the model has been downloaded). + wait(0).then(() => 1), ]); - // Don't wait for the download to complete - if (typeof sessionOrStatus == "string") { + // Don't wait for the download to complete.
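+ // (An already-resolved session promise settles before the wait(0) macrotask, so once the model is present the race yields the session rather than the numeric skip sentinel.)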
+ if (typeof sessionOrSkip == "number") { log.info( "Ignoring CLIP text embedding request because model download is pending", ); return undefined; } - const session = sessionOrStatus; + const session = sessionOrSkip; const t1 = Date.now(); const tokenizer = getTokenizer(); const tokenizedText = Int32Array.from(tokenizer.encodeForCLIP(text)); @@ -165,7 +167,7 @@ const results = await session.run(feeds); log.debug( () => - `onnx/clip text embedding took ${Date.now() - t1} ms (prep: ${t2 - t1} ms, inference: ${Date.now() - t2} ms)`, + `ONNX/CLIP text embedding took ${Date.now() - t1} ms (prep: ${t2 - t1} ms, inference: ${Date.now() - t2} ms)`, ); const textEmbedding = ensure(results.output).data as Float32Array; return normalizeEmbedding(textEmbedding); diff --git a/desktop/src/main/services/ml-face.ts b/desktop/src/main/services/ml-face.ts index 33157694f..b6fb5c90f 100644 --- a/desktop/src/main/services/ml-face.ts +++ b/desktop/src/main/services/ml-face.ts @@ -23,7 +23,7 @@ export const detectFaces = async (input: Float32Array) => { input: new ort.Tensor("float32", input, [1, 3, 640, 640]), }; const results = await session.run(feeds); - log.debug(() => `onnx/yolo face detection took ${Date.now() - t} ms`); + log.debug(() => `ONNX/YOLO face detection took ${Date.now() - t} ms`); return ensure(results.output).data; }; @@ -32,7 +32,7 @@ const cachedFaceEmbeddingSession = makeCachedInferenceSession( 5286998 /* 5 MB */, ); -export const faceEmbeddings = async (input: Float32Array) => { +export const computeFaceEmbeddings = async (input: Float32Array) => { // Dimension of each face (alias) const mobileFaceNetFaceSize = 112; // Smaller alias @@ -45,7 +45,7 @@ const t = Date.now(); const feeds = { img_inputs: inputTensor }; const results = await session.run(feeds); - log.debug(() => `onnx/yolo face embedding took ${Date.now() - t} ms`); + log.debug(() => `ONNX/MFNT face embedding took ${Date.now() - t} ms`); /* Need these model specific casts to extract and type the result */ return (results.embeddings as unknown as Record<string, unknown>) .cpuData as Float32Array; diff --git a/desktop/src/main/services/store.ts b/desktop/src/main/services/store.ts index 471928d76..253c2cbf0 100644 --- a/desktop/src/main/services/store.ts +++ b/desktop/src/main/services/store.ts @@ -18,10 +18,7 @@ export const clearStores = () => { * [Note: Safe storage keys] * * On macOS, `safeStorage` stores our data under a Keychain entry named - * "<app name> Safe Storage". Which resolves to: - * - * - Electron Safe Storage (dev) - * - ente Safe Storage (prod) + * "<app name> Safe Storage". In our case, "ente Safe Storage". */ export const saveEncryptionKey = (encryptionKey: string) => { const encryptedKey = safeStorage.encryptString(encryptionKey); diff --git a/desktop/src/main/utils/common.ts b/desktop/src/main/utils/common.ts index 5ed46aa8a..929281d74 100644 --- a/desktop/src/main/utils/common.ts +++ b/desktop/src/main/utils/common.ts @@ -13,3 +13,12 @@ export const ensure = <T>(v: T | null | undefined): T => { if (v === undefined) throw new Error("Required value was not found"); return v; }; + +/** + * Wait for {@link ms} milliseconds + * + * This function is a promisified `setTimeout`. It returns a promise that + * resolves after {@link ms} milliseconds.
+ */ +export const wait = (ms: number) => + new Promise((resolve) => setTimeout(resolve, ms)); diff --git a/desktop/src/main/utils/electron.ts b/desktop/src/main/utils/electron.ts index 133edf87c..c11391dd6 100644 --- a/desktop/src/main/utils/electron.ts +++ b/desktop/src/main/utils/electron.ts @@ -55,9 +55,7 @@ export const execAsync = async (command: string | string[]) => { : command; const startTime = Date.now(); const result = await execAsync_(escapedCommand); - log.debug( - () => `${escapedCommand} (${Math.round(Date.now() - startTime)} ms)`, - ); + log.debug(() => `${escapedCommand} (${Date.now() - startTime} ms)`); return result; }; diff --git a/desktop/src/preload.ts b/desktop/src/preload.ts index c5a1d0d31..85475031d 100644 --- a/desktop/src/preload.ts +++ b/desktop/src/preload.ts @@ -65,7 +65,7 @@ const selectDirectory = () => ipcRenderer.invoke("selectDirectory"); const logout = () => { watchRemoveListeners(); - ipcRenderer.send("logout"); + return ipcRenderer.invoke("logout"); }; const encryptionKey = () => ipcRenderer.invoke("encryptionKey"); @@ -153,20 +153,17 @@ const ffmpegExec = ( // - ML -const clipImageEmbedding = (jpegImageData: Uint8Array) => - ipcRenderer.invoke("clipImageEmbedding", jpegImageData); +const computeCLIPImageEmbedding = (jpegImageData: Uint8Array) => + ipcRenderer.invoke("computeCLIPImageEmbedding", jpegImageData); -const clipTextEmbeddingIfAvailable = (text: string) => - ipcRenderer.invoke("clipTextEmbeddingIfAvailable", text); +const computeCLIPTextEmbeddingIfAvailable = (text: string) => + ipcRenderer.invoke("computeCLIPTextEmbeddingIfAvailable", text); const detectFaces = (input: Float32Array) => ipcRenderer.invoke("detectFaces", input); -const faceEmbeddings = (input: Float32Array) => - ipcRenderer.invoke("faceEmbeddings", input); - -const legacyFaceCrop = (faceID: string) => - ipcRenderer.invoke("legacyFaceCrop", faceID); +const computeFaceEmbeddings = (input: Float32Array) => + ipcRenderer.invoke("computeFaceEmbeddings", input); // - Watch @@ -340,11 +337,10 @@ contextBridge.exposeInMainWorld("electron", { // - ML - clipImageEmbedding, - clipTextEmbeddingIfAvailable, + computeCLIPImageEmbedding, + computeCLIPTextEmbeddingIfAvailable, detectFaces, - faceEmbeddings, - legacyFaceCrop, + computeFaceEmbeddings, // - Watch diff --git a/mobile/android/app/build.gradle b/mobile/android/app/build.gradle index 01ec11ff8..b5225db8e 100644 --- a/mobile/android/app/build.gradle +++ b/mobile/android/app/build.gradle @@ -43,7 +43,7 @@ android { defaultConfig { applicationId "io.ente.photos" - minSdkVersion 21 + minSdkVersion 26 targetSdkVersion 33 versionCode flutterVersionCode.toInteger() versionName flutterVersionName @@ -70,6 +70,10 @@ android { dimension "default" applicationIdSuffix ".dev" } + face { + dimension "default" + applicationIdSuffix ".face" + } playstore { dimension "default" } diff --git a/mobile/android/app/src/face/AndroidManifest.xml b/mobile/android/app/src/face/AndroidManifest.xml new file mode 100644 index 000000000..cbf1924b2 --- /dev/null +++ b/mobile/android/app/src/face/AndroidManifest.xml @@ -0,0 +1,10 @@ + + + + + + + diff --git a/mobile/android/app/src/face/res/values/strings.xml b/mobile/android/app/src/face/res/values/strings.xml new file mode 100644 index 000000000..4932deb96 --- /dev/null +++ b/mobile/android/app/src/face/res/values/strings.xml @@ -0,0 +1,4 @@ + + ente face + backup face + diff --git a/mobile/assets/models/cocossd/labels.txt b/mobile/assets/models/cocossd/labels.txt deleted file mode 100644 index 
fc674c0b9..000000000 --- a/mobile/assets/models/cocossd/labels.txt +++ /dev/null @@ -1,91 +0,0 @@ -unknown -person -bicycle -car -motorcycle -airplane -bus -train -truck -boat -traffic light -fire hydrant -unknown -stop sign -parking meter -bench -bird -cat -dog -horse -sheep -cow -elephant -bear -zebra -giraffe -unknown -backpack -umbrella -unknown -unknown -handbag -tie -suitcase -frisbee -skis -snowboard -sports ball -kite -baseball bat -baseball glove -skateboard -surfboard -tennis racket -bottle -unknown -wine glass -cup -fork -knife -spoon -bowl -banana -apple -sandwich -orange -broccoli -carrot -hot dog -pizza -donut -cake -chair -couch -potted plant -bed -unknown -dining table -unknown -unknown -toilet -unknown -tv -laptop -mouse -remote -keyboard -cell phone -microwave -oven -toaster -sink -refrigerator -unknown -book -clock -vase -scissors -teddy bear -hair drier -toothbrush diff --git a/mobile/assets/models/cocossd/model.tflite b/mobile/assets/models/cocossd/model.tflite deleted file mode 100644 index 8015ee5d8..000000000 Binary files a/mobile/assets/models/cocossd/model.tflite and /dev/null differ diff --git a/mobile/assets/models/mobilenet/labels_mobilenet_quant_v1_224.txt b/mobile/assets/models/mobilenet/labels_mobilenet_quant_v1_224.txt deleted file mode 100644 index fe811239d..000000000 --- a/mobile/assets/models/mobilenet/labels_mobilenet_quant_v1_224.txt +++ /dev/null @@ -1,1001 +0,0 @@ -background -tench -goldfish -great white shark -tiger shark -hammerhead -electric ray -stingray -cock -hen -ostrich -brambling -goldfinch -house finch -junco -indigo bunting -robin -bulbul -jay -magpie -chickadee -water ouzel -kite -bald eagle -vulture -great grey owl -European fire salamander -common newt -eft -spotted salamander -axolotl -bullfrog -tree frog -tailed frog -loggerhead -leatherback turtle -mud turtle -terrapin -box turtle -banded gecko -common iguana -American chameleon -whiptail -agama -frilled lizard -alligator lizard -Gila monster -green lizard -African chameleon -Komodo dragon -African crocodile -American alligator -triceratops -thunder snake -ringneck snake -hognose snake -green snake -king snake -garter snake -water snake -vine snake -night snake -boa constrictor -rock python -Indian cobra -green mamba -sea snake -horned viper -diamondback -sidewinder -trilobite -harvestman -scorpion -black and gold garden spider -barn spider -garden spider -black widow -tarantula -wolf spider -tick -centipede -black grouse -ptarmigan -ruffed grouse -prairie chicken -peacock -quail -partridge -African grey -macaw -sulphur-crested cockatoo -lorikeet -coucal -bee eater -hornbill -hummingbird -jacamar -toucan -drake -red-breasted merganser -goose -black swan -tusker -echidna -platypus -wallaby -koala -wombat -jellyfish -sea anemone -brain coral -flatworm -nematode -conch -snail -slug -sea slug -chiton -chambered nautilus -Dungeness crab -rock crab -fiddler crab -king crab -American lobster -spiny lobster -crayfish -hermit crab -isopod -white stork -black stork -spoonbill -flamingo -little blue heron -American egret -bittern -crane -limpkin -European gallinule -American coot -bustard -ruddy turnstone -red-backed sandpiper -redshank -dowitcher -oystercatcher -pelican -king penguin -albatross -grey whale -killer whale -dugong -sea lion -Chihuahua -Japanese spaniel -Maltese dog -Pekinese -Shih-Tzu -Blenheim spaniel -papillon -toy terrier -Rhodesian ridgeback -Afghan hound -basset -beagle -bloodhound -bluetick -black-and-tan coonhound -Walker hound -English foxhound -redbone -borzoi -Irish 
wolfhound -Italian greyhound -whippet -Ibizan hound -Norwegian elkhound -otterhound -Saluki -Scottish deerhound -Weimaraner -Staffordshire bullterrier -American Staffordshire terrier -Bedlington terrier -Border terrier -Kerry blue terrier -Irish terrier -Norfolk terrier -Norwich terrier -Yorkshire terrier -wire-haired fox terrier -Lakeland terrier -Sealyham terrier -Airedale -cairn -Australian terrier -Dandie Dinmont -Boston bull -miniature schnauzer -giant schnauzer -standard schnauzer -Scotch terrier -Tibetan terrier -silky terrier -soft-coated wheaten terrier -West Highland white terrier -Lhasa -flat-coated retriever -curly-coated retriever -golden retriever -Labrador retriever -Chesapeake Bay retriever -German short-haired pointer -vizsla -English setter -Irish setter -Gordon setter -Brittany spaniel -clumber -English springer -Welsh springer spaniel -cocker spaniel -Sussex spaniel -Irish water spaniel -kuvasz -schipperke -groenendael -malinois -briard -kelpie -komondor -Old English sheepdog -Shetland sheepdog -collie -Border collie -Bouvier des Flandres -Rottweiler -German shepherd -Doberman -miniature pinscher -Greater Swiss Mountain dog -Bernese mountain dog -Appenzeller -EntleBucher -boxer -bull mastiff -Tibetan mastiff -French bulldog -Great Dane -Saint Bernard -Eskimo dog -malamute -Siberian husky -dalmatian -affenpinscher -basenji -pug -Leonberg -Newfoundland -Great Pyrenees -Samoyed -Pomeranian -chow -keeshond -Brabancon griffon -Pembroke -Cardigan -toy poodle -miniature poodle -standard poodle -Mexican hairless -timber wolf -white wolf -red wolf -coyote -dingo -dhole -African hunting dog -hyena -red fox -kit fox -Arctic fox -grey fox -tabby -tiger cat -Persian cat -Siamese cat -Egyptian cat -cougar -lynx -leopard -snow leopard -jaguar -lion -tiger -cheetah -brown bear -American black bear -ice bear -sloth bear -mongoose -meerkat -tiger beetle -ladybug -ground beetle -long-horned beetle -leaf beetle -dung beetle -rhinoceros beetle -weevil -fly -bee -ant -grasshopper -cricket -walking stick -cockroach -mantis -cicada -leafhopper -lacewing -dragonfly -damselfly -admiral -ringlet -monarch -cabbage butterfly -sulphur butterfly -lycaenid -starfish -sea urchin -sea cucumber -wood rabbit -hare -Angora -hamster -porcupine -fox squirrel -marmot -beaver -guinea pig -sorrel -zebra -hog -wild boar -warthog -hippopotamus -ox -water buffalo -bison -ram -bighorn -ibex -hartebeest -impala -gazelle -Arabian camel -llama -weasel -mink -polecat -black-footed ferret -otter -skunk -badger -armadillo -three-toed sloth -orangutan -gorilla -chimpanzee -gibbon -siamang -guenon -patas -baboon -macaque -langur -colobus -proboscis monkey -marmoset -capuchin -howler monkey -titi -spider monkey -squirrel monkey -Madagascar cat -indri -Indian elephant -African elephant -lesser panda -giant panda -barracouta -eel -coho -rock beauty -anemone fish -sturgeon -gar -lionfish -puffer -abacus -abaya -academic gown -accordion -acoustic guitar -aircraft carrier -airliner -airship -altar -ambulance -amphibian -analog clock -apiary -apron -ashcan -assault rifle -backpack -bakery -balance beam -balloon -ballpoint -Band Aid -banjo -bannister -barbell -barber chair -barbershop -barn -barometer -barrel -barrow -baseball -basketball -bassinet -bassoon -bathing cap -bath towel -bathtub -beach wagon -beacon -beaker -bearskin -beer bottle -beer glass -bell cote -bib -bicycle-built-for-two -bikini -binder -binoculars -birdhouse -boathouse -bobsled -bolo tie -bonnet -bookcase -bookshop -bottlecap -bow -bow tie -brass -brassiere 
-breakwater -breastplate -broom -bucket -buckle -bulletproof vest -bullet train -butcher shop -cab -caldron -candle -cannon -canoe -can opener -cardigan -car mirror -carousel -carpenter's kit -carton -car wheel -cash machine -cassette -cassette player -castle -catamaran -CD player -cello -cellular telephone -chain -chainlink fence -chain mail -chain saw -chest -chiffonier -chime -china cabinet -Christmas stocking -church -cinema -cleaver -cliff dwelling -cloak -clog -cocktail shaker -coffee mug -coffeepot -coil -combination lock -computer keyboard -confectionery -container ship -convertible -corkscrew -cornet -cowboy boot -cowboy hat -cradle -crane -crash helmet -crate -crib -Crock Pot -croquet ball -crutch -cuirass -dam -desk -desktop computer -dial telephone -diaper -digital clock -digital watch -dining table -dishrag -dishwasher -disk brake -dock -dogsled -dome -doormat -drilling platform -drum -drumstick -dumbbell -Dutch oven -electric fan -electric guitar -electric locomotive -entertainment center -envelope -espresso maker -face powder -feather boa -file -fireboat -fire engine -fire screen -flagpole -flute -folding chair -football helmet -forklift -fountain -fountain pen -four-poster -freight car -French horn -frying pan -fur coat -garbage truck -gasmask -gas pump -goblet -go-kart -golf ball -golfcart -gondola -gong -gown -grand piano -greenhouse -grille -grocery store -guillotine -hair slide -hair spray -half track -hammer -hamper -hand blower -hand-held computer -handkerchief -hard disc -harmonica -harp -harvester -hatchet -holster -home theater -honeycomb -hook -hoopskirt -horizontal bar -horse cart -hourglass -iPod -iron -jack-o'-lantern -jean -jeep -jersey -jigsaw puzzle -jinrikisha -joystick -kimono -knee pad -knot -lab coat -ladle -lampshade -laptop -lawn mower -lens cap -letter opener -library -lifeboat -lighter -limousine -liner -lipstick -Loafer -lotion -loudspeaker -loupe -lumbermill -magnetic compass -mailbag -mailbox -maillot -maillot -manhole cover -maraca -marimba -mask -matchstick -maypole -maze -measuring cup -medicine chest -megalith -microphone -microwave -military uniform -milk can -minibus -miniskirt -minivan -missile -mitten -mixing bowl -mobile home -Model T -modem -monastery -monitor -moped -mortar -mortarboard -mosque -mosquito net -motor scooter -mountain bike -mountain tent -mouse -mousetrap -moving van -muzzle -nail -neck brace -necklace -nipple -notebook -obelisk -oboe -ocarina -odometer -oil filter -organ -oscilloscope -overskirt -oxcart -oxygen mask -packet -paddle -paddlewheel -padlock -paintbrush -pajama -palace -panpipe -paper towel -parachute -parallel bars -park bench -parking meter -passenger car -patio -pay-phone -pedestal -pencil box -pencil sharpener -perfume -Petri dish -photocopier -pick -pickelhaube -picket fence -pickup -pier -piggy bank -pill bottle -pillow -ping-pong ball -pinwheel -pirate -pitcher -plane -planetarium -plastic bag -plate rack -plow -plunger -Polaroid camera -pole -police van -poncho -pool table -pop bottle -pot -potter's wheel -power drill -prayer rug -printer -prison -projectile -projector -puck -punching bag -purse -quill -quilt -racer -racket -radiator -radio -radio telescope -rain barrel -recreational vehicle -reel -reflex camera -refrigerator -remote control -restaurant -revolver -rifle -rocking chair -rotisserie -rubber eraser -rugby ball -rule -running shoe -safe -safety pin -saltshaker -sandal -sarong -sax -scabbard -scale -school bus -schooner -scoreboard -screen -screw -screwdriver -seat belt -sewing machine 
-shield -shoe shop -shoji -shopping basket -shopping cart -shovel -shower cap -shower curtain -ski -ski mask -sleeping bag -slide rule -sliding door -slot -snorkel -snowmobile -snowplow -soap dispenser -soccer ball -sock -solar dish -sombrero -soup bowl -space bar -space heater -space shuttle -spatula -speedboat -spider web -spindle -sports car -spotlight -stage -steam locomotive -steel arch bridge -steel drum -stethoscope -stole -stone wall -stopwatch -stove -strainer -streetcar -stretcher -studio couch -stupa -submarine -suit -sundial -sunglass -sunglasses -sunscreen -suspension bridge -swab -sweatshirt -swimming trunks -swing -switch -syringe -table lamp -tank -tape player -teapot -teddy -television -tennis ball -thatch -theater curtain -thimble -thresher -throne -tile roof -toaster -tobacco shop -toilet seat -torch -totem pole -tow truck -toyshop -tractor -trailer truck -tray -trench coat -tricycle -trimaran -tripod -triumphal arch -trolleybus -trombone -tub -turnstile -typewriter keyboard -umbrella -unicycle -upright -vacuum -vase -vault -velvet -vending machine -vestment -viaduct -violin -volleyball -waffle iron -wall clock -wallet -wardrobe -warplane -washbasin -washer -water bottle -water jug -water tower -whiskey jug -whistle -wig -window screen -window shade -Windsor tie -wine bottle -wing -wok -wooden spoon -wool -worm fence -wreck -yawl -yurt -web site -comic book -crossword puzzle -street sign -traffic light -book jacket -menu -plate -guacamole -consomme -hot pot -trifle -ice cream -ice lolly -French loaf -bagel -pretzel -cheeseburger -hotdog -mashed potato -head cabbage -broccoli -cauliflower -zucchini -spaghetti squash -acorn squash -butternut squash -cucumber -artichoke -bell pepper -cardoon -mushroom -Granny Smith -strawberry -orange -lemon -fig -pineapple -banana -jackfruit -custard apple -pomegranate -hay -carbonara -chocolate sauce -dough -meat loaf -pizza -potpie -burrito -red wine -espresso -cup -eggnog -alp -bubble -cliff -coral reef -geyser -lakeside -promontory -sandbar -seashore -valley -volcano -ballplayer -groom -scuba diver -rapeseed -daisy -yellow lady's slipper -corn -acorn -hip -buckeye -coral fungus -agaric -gyromitra -stinkhorn -earthstar -hen-of-the-woods -bolete -ear -toilet tissue diff --git a/mobile/assets/models/mobilenet/mobilenet_v1_1.0_224_quant.tflite b/mobile/assets/models/mobilenet/mobilenet_v1_1.0_224_quant.tflite deleted file mode 100644 index 437640b06..000000000 Binary files a/mobile/assets/models/mobilenet/mobilenet_v1_1.0_224_quant.tflite and /dev/null differ diff --git a/mobile/assets/models/scenes/labels.txt b/mobile/assets/models/scenes/labels.txt deleted file mode 100644 index e0df14082..000000000 --- a/mobile/assets/models/scenes/labels.txt +++ /dev/null @@ -1,30 +0,0 @@ -waterfall -snow -landscape -underwater -architecture -sunset / sunrise -blue sky -cloudy sky -greenery -autumn leaves -portrait -flower -night shot -stage concert -fireworks -candle light -neon lights -indoor -backlight -text documents -qr images -group portrait -computer screens -kids -dog -cat -macro -food -beach -mountain diff --git a/mobile/assets/models/scenes/model.tflite b/mobile/assets/models/scenes/model.tflite deleted file mode 100644 index f2c942354..000000000 Binary files a/mobile/assets/models/scenes/model.tflite and /dev/null differ diff --git a/mobile/ios/Podfile.lock b/mobile/ios/Podfile.lock index 731514957..558a27910 100644 --- a/mobile/ios/Podfile.lock +++ b/mobile/ios/Podfile.lock @@ -6,6 +6,8 @@ PODS: - connectivity_plus (0.0.1): - Flutter - 
FlutterMacOS + - dart_ui_isolate (0.0.1): + - Flutter - device_info_plus (0.0.1): - Flutter - file_saver (0.0.1): @@ -226,6 +228,7 @@ DEPENDENCIES: - background_fetch (from `.symlinks/plugins/background_fetch/ios`) - battery_info (from `.symlinks/plugins/battery_info/ios`) - connectivity_plus (from `.symlinks/plugins/connectivity_plus/darwin`) + - dart_ui_isolate (from `.symlinks/plugins/dart_ui_isolate/ios`) - device_info_plus (from `.symlinks/plugins/device_info_plus/ios`) - file_saver (from `.symlinks/plugins/file_saver/ios`) - firebase_core (from `.symlinks/plugins/firebase_core/ios`) @@ -302,6 +305,8 @@ EXTERNAL SOURCES: :path: ".symlinks/plugins/battery_info/ios" connectivity_plus: :path: ".symlinks/plugins/connectivity_plus/darwin" + dart_ui_isolate: + :path: ".symlinks/plugins/dart_ui_isolate/ios" device_info_plus: :path: ".symlinks/plugins/device_info_plus/ios" file_saver: @@ -397,6 +402,7 @@ SPEC CHECKSUMS: background_fetch: 2319bf7e18237b4b269430b7f14d177c0df09c5a battery_info: 09f5c9ee65394f2291c8c6227bedff345b8a730c connectivity_plus: ddd7f30999e1faaef5967c23d5b6d503d10434db + dart_ui_isolate: d5bcda83ca4b04f129d70eb90110b7a567aece14 device_info_plus: c6fb39579d0f423935b0c9ce7ee2f44b71b9fce6 file_saver: 503e386464dbe118f630e17b4c2e1190fa0cf808 Firebase: 91fefd38712feb9186ea8996af6cbdef41473442 diff --git a/mobile/ios/Runner.xcodeproj/project.pbxproj b/mobile/ios/Runner.xcodeproj/project.pbxproj index c88f9da38..22d5e8e68 100644 --- a/mobile/ios/Runner.xcodeproj/project.pbxproj +++ b/mobile/ios/Runner.xcodeproj/project.pbxproj @@ -293,6 +293,7 @@ "${BUILT_PRODUCTS_DIR}/background_fetch/background_fetch.framework", "${BUILT_PRODUCTS_DIR}/battery_info/battery_info.framework", "${BUILT_PRODUCTS_DIR}/connectivity_plus/connectivity_plus.framework", + "${BUILT_PRODUCTS_DIR}/dart_ui_isolate/dart_ui_isolate.framework", "${BUILT_PRODUCTS_DIR}/device_info_plus/device_info_plus.framework", "${BUILT_PRODUCTS_DIR}/file_saver/file_saver.framework", "${BUILT_PRODUCTS_DIR}/fk_user_agent/fk_user_agent.framework", @@ -374,6 +375,7 @@ "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/background_fetch.framework", "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/battery_info.framework", "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/connectivity_plus.framework", + "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/dart_ui_isolate.framework", "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/device_info_plus.framework", "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/file_saver.framework", "${TARGET_BUILD_DIR}/${FRAMEWORKS_FOLDER_PATH}/fk_user_agent.framework", diff --git a/mobile/ios/Runner/Info.plist b/mobile/ios/Runner/Info.plist index 9afb874e5..fe571afeb 100644 --- a/mobile/ios/Runner/Info.plist +++ b/mobile/ios/Runner/Info.plist @@ -65,9 +65,9 @@ ITSAppUsesNonExemptEncryption FLTEnableImpeller - + FLTEnableWideGamut - + NSFaceIDUsageDescription Please allow ente to lock itself with FaceID or TouchID NSCameraUsageDescription diff --git a/mobile/lib/core/configuration.dart b/mobile/lib/core/configuration.dart index 334da4af9..4809ba863 100644 --- a/mobile/lib/core/configuration.dart +++ b/mobile/lib/core/configuration.dart @@ -19,6 +19,7 @@ import 'package:photos/db/upload_locks_db.dart'; import "package:photos/events/endpoint_updated_event.dart"; import 'package:photos/events/signed_in_event.dart'; import 'package:photos/events/user_logged_out_event.dart'; +import "package:photos/face/db.dart"; import 'package:photos/models/key_attributes.dart'; import 'package:photos/models/key_gen_result.dart'; import 
'package:photos/models/private_key_attributes.dart'; @@ -187,6 +188,7 @@ class Configuration { : null; await CollectionsDB.instance.clearTable(); await MemoriesDB.instance.clearTable(); + await FaceMLDataDB.instance.clearTable(); await UploadLocksDB.instance.clearTable(); await IgnoredFilesService.instance.reset(); diff --git a/mobile/lib/core/constants.dart b/mobile/lib/core/constants.dart index 77764ee65..02923b6c4 100644 --- a/mobile/lib/core/constants.dart +++ b/mobile/lib/core/constants.dart @@ -99,6 +99,9 @@ const blackThumbnailBase64 = '/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAEBAQEBAQEB' 'AKACgAoAKACgAoAKACgAoAKACgAoAKACgAoAKACgAoAKACgAoAKACgAoAKACgAoAKACgAoAKACgAoAKACgAoAKACgAoAKACgAoAKACgAo' + 'AKACgAoAKACgAoAKACgAoAKACgAoAKACgAoAKACgAoAKACgAoAKACgAoAKACgAoAKACgAoAKACgAoAKACgAoAKACgAoAKACgAoAKACgAoAKACgAoAKACgD/9k='; +const localFileServer = + String.fromEnvironment("localFileServer", defaultValue: ""); + const uploadTempFilePrefix = "upload_file_"; final tempDirCleanUpInterval = kDebugMode ? const Duration(seconds: 30).inMicroseconds diff --git a/mobile/lib/db/entities_db.dart b/mobile/lib/db/entities_db.dart index b8b48fbe4..cee32641a 100644 --- a/mobile/lib/db/entities_db.dart +++ b/mobile/lib/db/entities_db.dart @@ -9,7 +9,7 @@ extension EntitiesDB on FilesDB { List<LocalEntityData> data, { ConflictAlgorithm conflictAlgorithm = ConflictAlgorithm.replace, }) async { - debugPrint("Inserting missing PathIDToLocalIDMapping"); + debugPrint("entitiesDB: upsertEntities ${data.length} entities"); final db = await database; var batch = db.batch(); int batchCounter = 0; @@ -62,4 +62,17 @@ return LocalEntityData.fromJson(maps[i]); }); } + + Future<LocalEntityData?> getEntity(EntityType type, String id) async { + final db = await database; + final List<Map<String, dynamic>> maps = await db.query( + "entities", + where: "type = ?
AND id = ?", + whereArgs: [type.typeToString(), id], + ); + if (maps.isEmpty) { + return null; + } + return LocalEntityData.fromJson(maps.first); + } } diff --git a/mobile/lib/db/files_db.dart b/mobile/lib/db/files_db.dart index 7022100b7..f72ecb32a 100644 --- a/mobile/lib/db/files_db.dart +++ b/mobile/lib/db/files_db.dart @@ -491,6 +491,18 @@ class FilesDB { return convertToFiles(results)[0]; } + Future getAnyUploadedFile(int uploadedID) async { + final db = await instance.sqliteAsyncDB; + final results = await db.getAll( + 'SELECT * FROM $filesTable WHERE $columnUploadedFileID = ?', + [uploadedID], + ); + if (results.isEmpty) { + return null; + } + return convertToFiles(results)[0]; + } + Future> getUploadedFileIDs(int collectionID) async { final db = await instance.sqliteAsyncDB; final results = await db.getAll( @@ -683,6 +695,17 @@ class FilesDB { return files; } + Future> getAllFilesFromCollections( + Iterable collectionID, + ) async { + final db = await instance.sqliteAsyncDB; + final String sql = + 'SELECT * FROM $filesTable WHERE $columnCollectionID IN (${collectionID.join(',')})'; + final results = await db.getAll(sql); + final files = convertToFiles(results); + return files; + } + Future> getNewFilesInCollection( int collectionID, int addedTime, @@ -1304,6 +1327,23 @@ class FilesDB { return result; } + Future> getFileIDToCreationTime() async { + final db = await instance.sqliteAsyncDB; + final rows = await db.getAll( + ''' + SELECT $columnUploadedFileID, $columnCreationTime + FROM $filesTable + WHERE + ($columnUploadedFileID IS NOT NULL AND $columnUploadedFileID IS NOT -1); + ''', + ); + final result = {}; + for (final row in rows) { + result[row[columnUploadedFileID] as int] = row[columnCreationTime] as int; + } + return result; + } + // getCollectionFileFirstOrLast returns the first or last uploaded file in // the collection based on the given collectionID and the order. Future getCollectionFileFirstOrLast( @@ -1643,13 +1683,14 @@ class FilesDB { } Future> getOwnedFileIDs(int ownerID) async { - final db = await instance.database; - final results = await db.query( - filesTable, - columns: [columnUploadedFileID], - where: - '($columnOwnerID = $ownerID AND $columnUploadedFileID IS NOT NULL AND $columnUploadedFileID IS NOT -1)', - distinct: true, + final db = await instance.sqliteAsyncDB; + final results = await db.getAll( + ''' + SELECT DISTINCT $columnUploadedFileID FROM $filesTable + WHERE ($columnOwnerID = ? 
AND $columnUploadedFileID IS NOT NULL AND + $columnUploadedFileID IS NOT -1) + ''', + [ownerID], ); final ids = <int>[]; for (final result in results) { @@ -1659,16 +1700,17 @@ } Future<List<EnteFile>> getUploadedFiles(List<int> uploadedIDs) async { - final db = await instance.database; + final db = await instance.sqliteAsyncDB; String inParam = ""; for (final id in uploadedIDs) { inParam += "'" + id.toString() + "',"; } inParam = inParam.substring(0, inParam.length - 1); - final results = await db.query( - filesTable, - where: '$columnUploadedFileID IN ($inParam)', - groupBy: columnUploadedFileID, + final results = await db.getAll( + ''' + SELECT * FROM $filesTable WHERE $columnUploadedFileID IN ($inParam) + GROUP BY $columnUploadedFileID +''', ); if (results.isEmpty) { return []; diff --git a/mobile/lib/events/files_updated_event.dart b/mobile/lib/events/files_updated_event.dart index 18aa8757b..2fc67d646 100644 --- a/mobile/lib/events/files_updated_event.dart +++ b/mobile/lib/events/files_updated_event.dart @@ -26,4 +26,6 @@ enum EventType { hide, unhide, coverChanged, + peopleChanged, + peopleClusterChanged, } diff --git a/mobile/lib/events/people_changed_event.dart b/mobile/lib/events/people_changed_event.dart new file mode 100644 index 000000000..51f4eaeef --- /dev/null +++ b/mobile/lib/events/people_changed_event.dart @@ -0,0 +1,22 @@ +import "package:photos/events/event.dart"; +import "package:photos/models/file/file.dart"; + +class PeopleChangedEvent extends Event { + final List<EnteFile>? relevantFiles; + final PeopleEventType type; + final String source; + + PeopleChangedEvent({ + this.relevantFiles, + this.type = PeopleEventType.defaultType, + this.source = "", + }); + + @override + String get reason => '$runtimeType{type: ${type.name}, "via": $source}'; +} + +enum PeopleEventType { + defaultType, + removedFilesFromCluster, +} \ No newline at end of file diff --git a/mobile/lib/extensions/ml_linalg_extensions.dart b/mobile/lib/extensions/ml_linalg_extensions.dart new file mode 100644 index 000000000..85a980855 --- /dev/null +++ b/mobile/lib/extensions/ml_linalg_extensions.dart @@ -0,0 +1,193 @@ +import 'dart:math' as math show sin, cos, atan2, sqrt, pow; +import 'package:ml_linalg/linalg.dart'; + +extension SetVectorValues on Vector { + Vector setValues(int start, int end, Iterable<double> values) { + if (values.length > length) { + throw Exception('Values cannot be larger than vector'); + } else if (end - start != values.length) { + throw Exception('Values must be same length as range'); + } else if (start < 0 || end > length) { + throw Exception('Range must be within vector'); + } + final tempList = toList(); + tempList.replaceRange(start, end, values); + final newVector = Vector.fromList(tempList); + return newVector; + } +} + +extension SetMatrixValues on Matrix { + Matrix setSubMatrix( + int startRow, + int endRow, + int startColumn, + int endColumn, + Iterable<Iterable<double>> values, + ) { + if (values.length > rowCount) { + throw Exception('New values cannot have more rows than original matrix'); + } else if (values.elementAt(0).length > columnCount) { + throw Exception( + 'New values cannot have more columns than original matrix', + ); + } else if (endRow - startRow != values.length) { + throw Exception('Values (number of rows) must be same length as range'); + } else if (endColumn - startColumn != values.elementAt(0).length) { + throw Exception( + 'Values (number of columns) must be same length as range', + ); + } else if (startRow < 0 || + endRow > rowCount || + startColumn < 0 || + endColumn > columnCount) {
+ throw Exception('Range must be within matrix'); + } + final tempList = asFlattenedList + .toList(); // You need `.toList()` here to make sure the list is growable, otherwise `replaceRange` will throw an error + for (var i = startRow; i < endRow; i++) { + tempList.replaceRange( + i * columnCount + startColumn, + i * columnCount + endColumn, + values.elementAt(i).toList(), + ); + } + final newMatrix = Matrix.fromFlattenedList(tempList, rowCount, columnCount); + return newMatrix; + } + + Matrix setValues( + int startRow, + int endRow, + int startColumn, + int endColumn, + Iterable<double> values, + ) { + if ((startRow - endRow) * (startColumn - endColumn) != values.length) { + throw Exception('Values must be same length as range'); + } else if (startRow < 0 || + endRow > rowCount || + startColumn < 0 || + endColumn > columnCount) { + throw Exception('Range must be within matrix'); + } + + final tempList = asFlattenedList + .toList(); // You need `.toList()` here to make sure the list is growable, otherwise `replaceRange` will throw an error + var index = 0; + for (var i = startRow; i < endRow; i++) { + for (var j = startColumn; j < endColumn; j++) { + tempList[i * columnCount + j] = values.elementAt(index); + index++; + } + } + final newMatrix = Matrix.fromFlattenedList(tempList, rowCount, columnCount); + return newMatrix; + } + + Matrix setValue(int row, int column, double value) { + if (row < 0 || row > rowCount || column < 0 || column > columnCount) { + throw Exception('Index must be within range of matrix'); + } + final tempList = asFlattenedList; + tempList[row * columnCount + column] = value; + final newMatrix = Matrix.fromFlattenedList(tempList, rowCount, columnCount); + return newMatrix; + } + + Matrix appendRow(List<double> row) { + final oldNumberOfRows = rowCount; + final oldNumberOfColumns = columnCount; + if (row.length != oldNumberOfColumns) { + throw Exception('Row must have same number of columns as matrix'); + } + final flatListMatrix = asFlattenedList; + flatListMatrix.addAll(row); + return Matrix.fromFlattenedList( + flatListMatrix, + oldNumberOfRows + 1, + oldNumberOfColumns, + ); + } +} + +extension MatrixCalculations on Matrix { + double determinant() { + final int length = rowCount; + if (length != columnCount) { + throw Exception('Matrix must be square'); + } + if (length == 1) { + return this[0][0]; + } else if (length == 2) { + return this[0][0] * this[1][1] - this[0][1] * this[1][0]; + } else { + throw Exception('Determinant for Matrix larger than 2x2 not implemented'); + } + } + + /// Computes the singular value decomposition of a matrix, using https://lucidar.me/en/mathematics/singular-value-decomposition-of-a-2x2-matrix/ as reference, but with slightly different signs for the second columns of U and V + Map<String, dynamic> svd() { + if (rowCount != 2 || columnCount != 2) { + throw Exception('Matrix must be 2x2'); + } + final a = this[0][0]; + final b = this[0][1]; + final c = this[1][0]; + final d = this[1][1]; + + // Computation of U matrix + final tempCalc = a * a + b * b - c * c - d * d; + final theta = 0.5 * math.atan2(2 * a * c + 2 * b * d, tempCalc); + final U = Matrix.fromList([ + [math.cos(theta), math.sin(theta)], + [math.sin(theta), -math.cos(theta)], + ]); + + // Computation of S matrix + // ignore: non_constant_identifier_names + final S1 = a * a + b * b + c * c + d * d; + // ignore: non_constant_identifier_names + final S2 = + math.sqrt(math.pow(tempCalc, 2) + 4 * math.pow(a * c + b * d, 2)); + final sigma1 = math.sqrt((S1 + S2) / 2); + final sigma2 = math.sqrt((S1 - S2) / 2);
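+ // S2 >= 0 by construction, so sigma1 >= sigma2 and the singular values are in the conventional descending order.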
+
+  int matrixRank() {
+    final svdResult = svd();
+    final Vector S = svdResult['S']!;
+    final rank = S.toList().where((element) => element > 1e-10).length;
+    return rank;
+  }
+}
+
+extension TransformMatrix on Matrix {
+  List<List<double>> to2DList() {
+    final List<List<double>> outerList = [];
+    for (var i = 0; i < rowCount; i++) {
+      final innerList = this[i].toList();
+      outerList.add(innerList);
+    }
+    return outerList;
+  }
+}
diff --git a/mobile/lib/extensions/stop_watch.dart b/mobile/lib/extensions/stop_watch.dart
index a381fcbc1..708af081b 100644
--- a/mobile/lib/extensions/stop_watch.dart
+++ b/mobile/lib/extensions/stop_watch.dart
@@ -23,4 +23,9 @@ class EnteWatch extends Stopwatch {
     reset();
     previousElapsed = 0;
   }
+
+  void stopWithLog(String msg) {
+    log(msg);
+    stop();
+  }
 }
diff --git a/mobile/lib/face/db.dart b/mobile/lib/face/db.dart
new file mode 100644
index 000000000..9316916a3
--- /dev/null
+++ b/mobile/lib/face/db.dart
@@ -0,0 +1,1018 @@
+import 'dart:async';
+import "dart:math";
+
+import "package:collection/collection.dart";
+import "package:flutter/foundation.dart";
+import 'package:logging/logging.dart';
+import 'package:path/path.dart' show join;
+import 'package:path_provider/path_provider.dart';
+import "package:photos/extensions/stop_watch.dart";
+import 'package:photos/face/db_fields.dart';
+import "package:photos/face/db_model_mappers.dart";
+import "package:photos/face/model/face.dart";
+import "package:photos/models/file/file.dart";
+import "package:photos/services/machine_learning/face_ml/face_clustering/face_info_for_clustering.dart";
+import 'package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart';
+import 'package:sqlite_async/sqlite_async.dart';
+
+/// Stores all data for the FacesML-related features. The database can be
+/// accessed through `FaceMLDataDB.instance.asyncDB`.
+///
+/// This includes:
+/// [facesTable] - Stores all the detected faces and their embeddings in the images.
+/// [createFaceClustersTable] - Stores all the mappings from the faces (faceID) to the clusters (clusterID).
+/// [clusterPersonTable] - Stores all the clusters that are mapped to a certain person.
+/// [clusterSummaryTable] - Stores a summary of each cluster, containing the mean embedding and the number of faces in the cluster.
+/// [notPersonFeedback] - Stores the clusters that are confirmed not to belong to a certain person by the user.
+class FaceMLDataDB {
+  static final Logger _logger = Logger("FaceMLDataDB");
+
+  static const _databaseName = "ente.face_ml_db.db";
+  static const _databaseVersion = 1;
+
+  FaceMLDataDB._privateConstructor();
+
+  static final FaceMLDataDB instance = FaceMLDataDB._privateConstructor();
+
+  // only have a single app-wide reference to the database
+  static Future<SqliteDatabase>? _sqliteAsyncDBFuture;
+
+  Future<SqliteDatabase> get asyncDB async {
+    _sqliteAsyncDBFuture ??= _initSqliteAsyncDatabase();
+    return _sqliteAsyncDBFuture!;
+  }
+
+  Future<SqliteDatabase> _initSqliteAsyncDatabase() async {
+    final documentsDirectory = await getApplicationDocumentsDirectory();
+    final String databaseDirectory =
+        join(documentsDirectory.path, _databaseName);
+    _logger.info("Opening sqlite_async access: DB path " + databaseDirectory);
+    final asyncDBConnection =
+        SqliteDatabase(path: databaseDirectory, maxReaders: 2);
+    await _onCreate(asyncDBConnection);
+    return asyncDBConnection;
+  }
+
+  Future<void> _onCreate(SqliteDatabase asyncDBConnection) async {
+    final migrations = SqliteMigrations()
+      ..add(
+        SqliteMigration(_databaseVersion, (tx) async {
+          await tx.execute(createFacesTable);
+          await tx.execute(createFaceClustersTable);
+          await tx.execute(createClusterPersonTable);
+          await tx.execute(createClusterSummaryTable);
+          await tx.execute(createNotPersonFeedbackTable);
+          await tx.execute(fcClusterIDIndex);
+        }),
+      );
+    await migrations.migrate(asyncDBConnection);
+  }
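+  // Illustrative sketch (added note, not part of the original change): the
+  // expected write path is roughly indexing -> clustering -> naming, e.g.:
+  //   await FaceMLDataDB.instance.bulkInsertFaces(faces);
+  //   await FaceMLDataDB.instance.updateFaceIdToClusterId(faceIDToClusterID);
+  //   await FaceMLDataDB.instance.assignClusterToPerson(
+  //     personID: personID, // hypothetical IDs
+  //     clusterID: clusterID,
+  //   );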
+
+  // bulkInsertFaces inserts the faces in the database in batches of 500.
+  // This is done to avoid the error "too many SQL variables" when inserting
+  // a large number of faces.
+  Future<void> bulkInsertFaces(List<Face> faces) async {
+    final db = await instance.asyncDB;
+    const batchSize = 500;
+    final numBatches = (faces.length / batchSize).ceil();
+    for (int i = 0; i < numBatches; i++) {
+      final start = i * batchSize;
+      final end = min((i + 1) * batchSize, faces.length);
+      final batch = faces.sublist(start, end);
+
+      const String sql = '''
+        INSERT INTO $facesTable (
+          $fileIDColumn, $faceIDColumn, $faceDetectionColumn, $faceEmbeddingBlob, $faceScore, $faceBlur, $isSideways, $imageHeight, $imageWidth, $mlVersionColumn
+        ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
+        ON CONFLICT($fileIDColumn, $faceIDColumn) DO UPDATE SET $faceIDColumn = excluded.$faceIDColumn, $faceDetectionColumn = excluded.$faceDetectionColumn, $faceEmbeddingBlob = excluded.$faceEmbeddingBlob, $faceScore = excluded.$faceScore, $faceBlur = excluded.$faceBlur, $isSideways = excluded.$isSideways, $imageHeight = excluded.$imageHeight, $imageWidth = excluded.$imageWidth, $mlVersionColumn = excluded.$mlVersionColumn
+      ''';
+      final parameterSets = batch.map((face) {
+        final map = mapRemoteToFaceDB(face);
+        return [
+          map[fileIDColumn],
+          map[faceIDColumn],
+          map[faceDetectionColumn],
+          map[faceEmbeddingBlob],
+          map[faceScore],
+          map[faceBlur],
+          map[isSideways],
+          map[imageHeight],
+          map[imageWidth],
+          map[mlVersionColumn],
+        ];
+      }).toList();
+
+      await db.executeBatch(sql, parameterSets);
+    }
+  }
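+  // Added note: the ON CONFLICT clause above makes this an upsert, so
+  // re-indexing a file simply overwrites the stored detection, embedding and
+  // scores for an existing (fileID, faceID) pair instead of failing:
+  //   await FaceMLDataDB.instance.bulkInsertFaces([face]); // inserts
+  //   await FaceMLDataDB.instance.bulkInsertFaces([face]); // updates same row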
+
+  Future<void> updateFaceIdToClusterId(
+    Map<String, int> faceIDToClusterID,
+  ) async {
+    final db = await instance.asyncDB;
+    const batchSize = 500;
+    final numBatches = (faceIDToClusterID.length / batchSize).ceil();
+    for (int i = 0; i < numBatches; i++) {
+      final start = i * batchSize;
+      final end = min((i + 1) * batchSize, faceIDToClusterID.length);
+      final batch = faceIDToClusterID.entries.toList().sublist(start, end);
+
+      const String sql = '''
+        INSERT INTO $faceClustersTable ($fcFaceId, $fcClusterID)
+        VALUES (?, ?)
+        ON CONFLICT($fcFaceId) DO UPDATE SET $fcClusterID = excluded.$fcClusterID
+      ''';
+      final parameterSets = batch.map((e) => [e.key, e.value]).toList();
+
+      await db.executeBatch(sql, parameterSets);
+    }
+  }
+
+  /// Returns a map of fileID to the indexed ML version
+  Future<Map<int, int>> getIndexedFileIds({int? minimumMlVersion}) async {
+    final db = await instance.asyncDB;
+    String query = '''
+        SELECT $fileIDColumn, $mlVersionColumn
+        FROM $facesTable
+    ''';
+    if (minimumMlVersion != null) {
+      query += ' WHERE $mlVersionColumn >= $minimumMlVersion';
+    }
+    final List<Map<String, dynamic>> maps = await db.getAll(query);
+    final Map<int, int> result = {};
+    for (final map in maps) {
+      result[map[fileIDColumn] as int] = map[mlVersionColumn] as int;
+    }
+    return result;
+  }
+
+  Future<int> getIndexedFileCount({int? minimumMlVersion}) async {
+    final db = await instance.asyncDB;
+    String query =
+        'SELECT COUNT(DISTINCT $fileIDColumn) as count FROM $facesTable';
+    if (minimumMlVersion != null) {
+      query += ' WHERE $mlVersionColumn >= $minimumMlVersion';
+    }
+    final List<Map<String, dynamic>> maps = await db.getAll(query);
+    return maps.first['count'] as int;
+  }
+
+  Future<Map<int, int>> clusterIdToFaceCount() async {
+    final db = await instance.asyncDB;
+    final List<Map<String, dynamic>> maps = await db.getAll(
+      'SELECT $fcClusterID, COUNT(*) as count FROM $faceClustersTable where $fcClusterID IS NOT NULL GROUP BY $fcClusterID ',
+    );
+    final Map<int, int> result = {};
+    for (final map in maps) {
+      result[map[fcClusterID] as int] = map['count'] as int;
+    }
+    return result;
+  }
+
+  Future<Set<int>> getPersonIgnoredClusters(String personID) async {
+    final db = await instance.asyncDB;
+    // find out clusterIds that are assigned to other persons using the clusters table
+    final List<Map<String, dynamic>> otherPersonMaps = await db.getAll(
+      'SELECT $clusterIDColumn FROM $clusterPersonTable WHERE $personIdColumn != ? AND $personIdColumn IS NOT NULL',
+      [personID],
+    );
+    final Set<int> ignoredClusterIDs =
+        otherPersonMaps.map((e) => e[clusterIDColumn] as int).toSet();
+    final List<Map<String, dynamic>> rejectMaps = await db.getAll(
+      'SELECT $clusterIDColumn FROM $notPersonFeedback WHERE $personIdColumn = ?',
+      [personID],
+    );
+    final Set<int> rejectClusterIDs =
+        rejectMaps.map((e) => e[clusterIDColumn] as int).toSet();
+    return ignoredClusterIDs.union(rejectClusterIDs);
+  }
+
+  Future<Set<int>> getPersonClusterIDs(String personID) async {
+    final db = await instance.asyncDB;
+    final List<Map<String, dynamic>> maps = await db.getAll(
+      'SELECT $clusterIDColumn FROM $clusterPersonTable WHERE $personIdColumn = ?',
+      [personID],
+    );
+    return maps.map((e) => e[clusterIDColumn] as int).toSet();
+  }
+
+  Future<void> clearTable() async {
+    final db = await instance.asyncDB;
+
+    await db.execute(deleteFacesTable);
+    await db.execute(dropClusterPersonTable);
+    await db.execute(dropClusterSummaryTable);
+    await db.execute(deletePersonTable);
+    await db.execute(dropNotPersonFeedbackTable);
+  }
+
+  Future<Iterable<Uint8List>> getFaceEmbeddingsForCluster(
+    int clusterID, {
+    int? limit,
+  }) async {
+    final db = await instance.asyncDB;
+    final List<Map<String, dynamic>> maps = await db.getAll(
+      'SELECT $faceEmbeddingBlob FROM $facesTable WHERE $faceIDColumn in (SELECT $fcFaceId from $faceClustersTable where $fcClusterID = ?) ${limit != null ? 'LIMIT $limit' : ''}',
+      [clusterID],
+    );
+    return maps.map((e) => e[faceEmbeddingBlob] as Uint8List);
+  }
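+  // Illustrative sketch (added note): the returned blobs are protobuf-encoded
+  // float vectors; they can be decoded with the same EVector message that
+  // db_model_mappers.dart uses for (de)serialization:
+  //   final blobs = await FaceMLDataDB.instance.getFaceEmbeddingsForCluster(id);
+  //   final embeddings = blobs.map((b) => EVector.fromBuffer(b).values).toList();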
+
+  Future<Map<int, List<Uint8List>>> getFaceEmbeddingsForClusters(
+    Iterable<int> clusterIDs, {
+    int? limit,
+  }) async {
+    final db = await instance.asyncDB;
+    final Map<int, List<Uint8List>> result = {};
+
+    final selectQuery = '''
+      SELECT fc.$fcClusterID, fe.$faceEmbeddingBlob
+      FROM $faceClustersTable fc
+      INNER JOIN $facesTable fe ON fc.$fcFaceId = fe.$faceIDColumn
+      WHERE fc.$fcClusterID IN (${clusterIDs.join(',')})
+      ${limit != null ? 'LIMIT $limit' : ''}
+    ''';
+
+    final List<Map<String, dynamic>> maps = await db.getAll(selectQuery);
+
+    for (final map in maps) {
+      final clusterID = map[fcClusterID] as int;
+      final faceEmbedding = map[faceEmbeddingBlob] as Uint8List;
+      result.putIfAbsent(clusterID, () => []).add(faceEmbedding);
+    }
+
+    return result;
+  }
+
+  Future<Face?> getCoverFaceForPerson({
+    required int recentFileID,
+    String? personID,
+    String? avatarFaceId,
+    int? clusterID,
+  }) async {
+    // read person from db
+    final db = await instance.asyncDB;
+    if (personID != null) {
+      final List<int> fileId = [recentFileID];
+      int? avatarFileId;
+      if (avatarFaceId != null) {
+        avatarFileId = int.tryParse(avatarFaceId.split('_')[0]);
+        if (avatarFileId != null) {
+          fileId.add(avatarFileId);
+        }
+      }
+      const String queryClusterID = '''
+        SELECT $clusterIDColumn
+        FROM $clusterPersonTable
+        WHERE $personIdColumn = ?
+      ''';
+      final clusterRows = await db.getAll(
+        queryClusterID,
+        [personID],
+      );
+      final clusterIDs =
+          clusterRows.map((e) => e[clusterIDColumn] as int).toList();
+      final List<Map<String, dynamic>> faceMaps = await db.getAll(
+        'SELECT * FROM $facesTable where '
+        '$faceIDColumn in (SELECT $fcFaceId from $faceClustersTable where $fcClusterID IN (${clusterIDs.join(",")})) '
+        'AND $fileIDColumn in (${fileId.join(",")}) AND $faceScore > $kMinimumQualityFaceScore ORDER BY $faceScore DESC',
+      );
+      if (faceMaps.isNotEmpty) {
+        if (avatarFileId != null) {
+          final row = faceMaps.firstWhereOrNull(
+            (element) => (element[fileIDColumn] as int) == avatarFileId,
+          );
+          if (row != null) {
+            return mapRowToFace(row);
+          }
+        }
+        return mapRowToFace(faceMaps.first);
+      }
+    }
+    if (clusterID != null) {
+      const String queryFaceID = '''
+        SELECT $fcFaceId
+        FROM $faceClustersTable
+        WHERE $fcClusterID = ?
+      ''';
+      final List<Map<String, dynamic>> faceMaps = await db.getAll(
+        queryFaceID,
+        [clusterID],
+      );
+      final List<Face>? faces = await getFacesForGivenFileID(recentFileID);
+      if (faces != null) {
+        for (final face in faces) {
+          if (faceMaps
+              .any((element) => (element[fcFaceId] as String) == face.faceID)) {
+            return face;
+          }
+        }
+      }
+    }
+    if (personID == null && clusterID == null) {
+      throw Exception("personID and clusterID cannot be null");
+    }
+    return null;
+  }
+
+  Future<List<Face>?> getFacesForGivenFileID(int fileUploadID) async {
+    final db = await instance.asyncDB;
+    const String query = '''
+      SELECT * FROM $facesTable
+      WHERE $fileIDColumn = ?
+    ''';
+    final List<Map<String, dynamic>> maps = await db.getAll(
+      query,
+      [fileUploadID],
+    );
+    if (maps.isEmpty) {
+      return null;
+    }
+    return maps.map((e) => mapRowToFace(e)).toList();
+  }
+
+  Future<Face?> getFaceForFaceID(String faceID) async {
+    final db = await instance.asyncDB;
+    final result = await db.getAll(
+      'SELECT * FROM $facesTable where $faceIDColumn = ?',
+      [faceID],
+    );
+    if (result.isEmpty) {
+      return null;
+    }
+    return mapRowToFace(result.first);
+  }
+
+  Future<Map<int, List<String>>> getClusterToFaceIDs(
+    Set<int> clusterIDs,
+  ) async {
+    final db = await instance.asyncDB;
+    final Map<int, List<String>> result = {};
+    final List<Map<String, dynamic>> maps = await db.getAll(
+      'SELECT $fcClusterID, $fcFaceId FROM $faceClustersTable WHERE $fcClusterID IN (${clusterIDs.join(",")})',
+    );
+    for (final map in maps) {
+      final clusterID = map[fcClusterID] as int;
+      final faceID = map[fcFaceId] as String;
+      result.putIfAbsent(clusterID, () => []).add(faceID);
+    }
+    return result;
+  }
+
+  Future<int?> getClusterIDForFaceID(String faceID) async {
+    final db = await instance.asyncDB;
+    final List<Map<String, dynamic>> maps = await db.getAll(
+      'SELECT $fcClusterID FROM $faceClustersTable WHERE $fcFaceId = ?',
+      [faceID],
+    );
+    if (maps.isEmpty) {
+      return null;
+    }
+    return maps.first[fcClusterID] as int;
+  }
+
+  Future<Map<int, List<String>>> getAllClusterIdToFaceIDs() async {
+    final db = await instance.asyncDB;
+    final Map<int, List<String>> result = {};
+    final List<Map<String, dynamic>> maps = await db.getAll(
+      'SELECT $fcClusterID, $fcFaceId FROM $faceClustersTable',
+    );
+    for (final map in maps) {
+      final clusterID = map[fcClusterID] as int;
+      final faceID = map[fcFaceId] as String;
+      result.putIfAbsent(clusterID, () => []).add(faceID);
+    }
+    return result;
+  }
+
+  Future<Set<String>> getFaceIDsForCluster(int clusterID) async {
+    final db = await instance.asyncDB;
+    final List<Map<String, dynamic>> maps = await db.getAll(
+      'SELECT $fcFaceId FROM $faceClustersTable '
+      'WHERE $faceClustersTable.$fcClusterID = ?',
+      [clusterID],
+    );
+    return maps.map((e) => e[fcFaceId] as String).toSet();
+  }
+
+  // Get Map of personID to Map of clusterID to faceIDs
+  Future<Map<String, Map<int, Set<String>>>>
+      getPersonToClusterIdToFaceIds() async {
+    final db = await instance.asyncDB;
+    final List<Map<String, dynamic>> maps = await db.getAll(
+      'SELECT $personIdColumn, $faceClustersTable.$fcClusterID, $fcFaceId FROM $clusterPersonTable '
+      'LEFT JOIN $faceClustersTable ON $clusterPersonTable.$clusterIDColumn = $faceClustersTable.$fcClusterID',
+    );
+    final Map<String, Map<int, Set<String>>> result = {};
+    for (final map in maps) {
+      final personID = map[personIdColumn] as String;
+      final clusterID = map[fcClusterID] as int;
+      final faceID = map[fcFaceId] as String;
+      result.putIfAbsent(personID, () => {}).putIfAbsent(clusterID, () => {})
+        ..add(faceID);
+    }
+    return result;
+  }
+
+  Future<Set<String>> getFaceIDsForPerson(String personID) async {
+    final db = await instance.asyncDB;
+    final faceIdsResult = await db.getAll(
+      'SELECT $fcFaceId FROM $faceClustersTable LEFT JOIN $clusterPersonTable '
+      'ON $faceClustersTable.$fcClusterID = $clusterPersonTable.$clusterIDColumn '
+      'WHERE $clusterPersonTable.$personIdColumn = ?',
+      [personID],
+    );
+    return faceIdsResult.map((e) => e[fcFaceId] as String).toSet();
+  }
+
+  Future<Set<double>> getBlurValuesForCluster(int clusterID) async {
+    final db = await instance.asyncDB;
+    const String query = '''
+      SELECT $facesTable.$faceBlur
+      FROM $facesTable
+      JOIN $faceClustersTable ON $facesTable.$faceIDColumn = $faceClustersTable.$fcFaceId
+      WHERE $faceClustersTable.$fcClusterID = ?
+    ''';
+    // const String query2 = '''
+    //   SELECT $faceBlur
+    //   FROM $facesTable
+    //   WHERE $faceIDColumn IN (SELECT $fcFaceId FROM $faceClustersTable WHERE $fcClusterID = ?)
+    // ''';
+    final List<Map<String, dynamic>> maps = await db.getAll(
+      query,
+      [clusterID],
+    );
+    return maps.map((e) => e[faceBlur] as double).toSet();
+  }
+
+  Future<Map<String, double>> getFaceIDsToBlurValues(
+    int maxBlurValue,
+  ) async {
+    final db = await instance.asyncDB;
+    final List<Map<String, dynamic>> maps = await db.getAll(
+      'SELECT $faceIDColumn, $faceBlur FROM $facesTable WHERE $faceBlur < $maxBlurValue AND $faceBlur > 1 ORDER BY $faceBlur ASC',
+    );
+    final Map<String, double> result = {};
+    for (final map in maps) {
+      result[map[faceIDColumn] as String] = map[faceBlur] as double;
+    }
+    return result;
+  }
+
+  Future<Map<String, int?>> getFaceIdsToClusterIds(
+    Iterable<String> faceIds,
+  ) async {
+    final db = await instance.asyncDB;
+    final List<Map<String, dynamic>> maps = await db.getAll(
+      'SELECT $fcFaceId, $fcClusterID FROM $faceClustersTable where $fcFaceId IN (${faceIds.map((id) => "'$id'").join(",")})',
+    );
+    final Map<String, int?> result = {};
+    for (final map in maps) {
+      result[map[fcFaceId] as String] = map[fcClusterID] as int?;
+    }
+    return result;
+  }
+
+  Future<Map<int, Set<int>>> getFileIdToClusterIds() async {
+    final Map<int, Set<int>> result = {};
+    final db = await instance.asyncDB;
+    final List<Map<String, dynamic>> maps = await db.getAll(
+      'SELECT $fcClusterID, $fcFaceId FROM $faceClustersTable',
+    );
+
+    for (final map in maps) {
+      final clusterID = map[fcClusterID] as int;
+      final faceID = map[fcFaceId] as String;
+      final x = faceID.split('_').first;
+      final fileID = int.parse(x);
+      result[fileID] = (result[fileID] ?? {})..add(clusterID);
+    }
+    return result;
+  }
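+  // Added note: faceIDs are expected to embed the uploaded fileID as a prefix
+  // before an underscore (e.g. a hypothetical "10001_3"), which is why the
+  // fileID can be recovered above with faceID.split('_').first.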
+
+  Future<void> forceUpdateClusterIds(
+    Map<String, int> faceIDToClusterID,
+  ) async {
+    final db = await instance.asyncDB;
+
+    const String sql = '''
+      INSERT INTO $faceClustersTable ($fcFaceId, $fcClusterID)
+      VALUES (?, ?)
+      ON CONFLICT($fcFaceId) DO UPDATE SET $fcClusterID = excluded.$fcClusterID
+    ''';
+    final parameterSets =
+        faceIDToClusterID.entries.map((e) => [e.key, e.value]).toList();
+    await db.executeBatch(sql, parameterSets);
+  }
+
+  Future<void> removePerson(String personID) async {
+    final db = await instance.asyncDB;
+
+    await db.writeTransaction((tx) async {
+      await tx.execute(
+        'DELETE FROM $clusterPersonTable WHERE $personIdColumn = ?',
+        [personID],
+      );
+      await tx.execute(
+        'DELETE FROM $notPersonFeedback WHERE $personIdColumn = ?',
+        [personID],
+      );
+    });
+  }
+
+  Future<List<FaceInfoForClustering>> getFaceInfoForClustering({
+    double minScore = kMinimumQualityFaceScore,
+    int minClarity = kLaplacianHardThreshold,
+    int maxFaces = 20000,
+    int offset = 0,
+    int batchSize = 10000,
+  }) async {
+    try {
+      final EnteWatch w = EnteWatch("getFaceEmbeddingMap")..start();
+      w.logAndReset(
+        'reading as float offset: $offset, maxFaces: $maxFaces, batchSize: $batchSize',
+      );
+      final db = await instance.asyncDB;
+
+      final List<FaceInfoForClustering> result = [];
+      while (true) {
+        // Query a batch of rows
+        final List<Map<String, dynamic>> maps = await db.getAll(
+          'SELECT $faceIDColumn, $faceEmbeddingBlob, $faceScore, $faceBlur, $isSideways FROM $facesTable'
+          ' WHERE $faceScore > $minScore AND $faceBlur > $minClarity'
+          ' ORDER BY $faceIDColumn'
+          ' DESC LIMIT $batchSize OFFSET $offset',
+        );
+        // Break the loop if no more rows
+        if (maps.isEmpty) {
+          break;
+        }
+        final List<String> faceIds = [];
+        for (final map in maps) {
+          faceIds.add(map[faceIDColumn] as String);
+        }
+        final faceIdToClusterId = await getFaceIdsToClusterIds(faceIds);
+        for (final map in maps) {
+          final faceID = map[faceIDColumn] as String;
+          final faceInfo = FaceInfoForClustering(
+            faceID: faceID,
+            clusterId: faceIdToClusterId[faceID],
+            embeddingBytes: map[faceEmbeddingBlob] as Uint8List,
+            faceScore: map[faceScore] as double,
+            blurValue: map[faceBlur] as double,
+            isSideways: (map[isSideways] as int) == 1,
+          );
+          result.add(faceInfo);
+        }
+        if (result.length >= maxFaces) {
+          break;
+        }
+        offset += batchSize;
+      }
+      w.stopWithLog('done reading face embeddings ${result.length}');
+      return result;
+    } catch (e) {
+      _logger.severe('err in getFaceInfoForClustering', e);
+      rethrow;
+    }
+  }
+
+  Future<Map<String, Uint8List>> getFaceEmbeddingMapForFile(
+    List<int> fileIDs,
+  ) async {
+    _logger.info('reading face embeddings for ${fileIDs.length} files');
+    final db = await instance.asyncDB;
+
+    // Define the batch size
+    const batchSize = 10000;
+    int offset = 0;
+
+    final Map<String, Uint8List> result = {};
+    while (true) {
+      // Query a batch of rows
+      final List<Map<String, dynamic>> maps = await db.getAll('''
+        SELECT $faceIDColumn, $faceEmbeddingBlob
+        FROM $facesTable
+        WHERE $faceScore > $kMinimumQualityFaceScore AND $faceBlur > $kLaplacianHardThreshold AND $fileIDColumn IN (${fileIDs.join(",")})
+        ORDER BY $faceIDColumn DESC
+        LIMIT $batchSize OFFSET $offset
+      ''');
+      // final List<Map<String, dynamic>> maps = await db.query(
+      //   facesTable,
+      //   columns: [faceIDColumn, faceEmbeddingBlob],
+      //   where:
+      //       '$faceScore > $kMinimumQualityFaceScore AND $faceBlur > $kLaplacianHardThreshold AND $fileIDColumn IN (${fileIDs.join(",")})',
+      //   limit: batchSize,
+      //   offset: offset,
+      //   orderBy: '$faceIDColumn DESC',
+      // );
+      // Break the loop if no more rows
+      if (maps.isEmpty) {
+        break;
+      }
+      for (final map in maps) {
+        final faceID = map[faceIDColumn] as String;
+        result[faceID] = map[faceEmbeddingBlob] as Uint8List;
+      }
+      if (result.length > 10000) {
+        break;
+      }
+      offset += batchSize;
+    }
+    _logger.info('done reading face embeddings for ${fileIDs.length} files');
+    return result;
+  }
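+  // Added note: getFaceInfoForClustering and getFaceEmbeddingMapForFile both
+  // page through the faces table with LIMIT/OFFSET and break out of the loop
+  // once enough rows have been collected, so memory stays bounded even for
+  // very large libraries.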
+
+  Future<Map<String, Uint8List>> getFaceEmbeddingMapForFaces(
+    Iterable<String> faceIDs,
+  ) async {
+    _logger.info('reading face embeddings for ${faceIDs.length} faces');
+    final db = await instance.asyncDB;
+
+    // Define the batch size
+    const batchSize = 10000;
+    int offset = 0;
+
+    final Map<String, Uint8List> result = {};
+    while (true) {
+      // Query a batch of rows
+      final String query = '''
+        SELECT $faceIDColumn, $faceEmbeddingBlob
+        FROM $facesTable
+        WHERE $faceIDColumn IN (${faceIDs.map((id) => "'$id'").join(",")})
+        ORDER BY $faceIDColumn DESC
+        LIMIT $batchSize OFFSET $offset
+      ''';
+      final List<Map<String, dynamic>> maps = await db.getAll(query);
+      // Break the loop if no more rows
+      if (maps.isEmpty) {
+        break;
+      }
+      for (final map in maps) {
+        final faceID = map[faceIDColumn] as String;
+        result[faceID] = map[faceEmbeddingBlob] as Uint8List;
+      }
+      if (result.length > 10000) {
+        break;
+      }
+      offset += batchSize;
+    }
+    _logger.info('done reading face embeddings for ${faceIDs.length} faces');
+    return result;
+  }
+
+  Future<int> getTotalFaceCount({
+    double minFaceScore = kMinimumQualityFaceScore,
+  }) async {
+    final db = await instance.asyncDB;
+    final List<Map<String, dynamic>> maps = await db.getAll(
+      'SELECT COUNT(*) as count FROM $facesTable WHERE $faceScore > $minFaceScore AND $faceBlur > $kLaplacianHardThreshold',
+    );
+    return maps.first['count'] as int;
+  }
+
+  Future<int> getClusteredFaceCount() async {
+    final db = await instance.asyncDB;
+    final List<Map<String, dynamic>> maps = await db.getAll(
+      'SELECT COUNT(DISTINCT $fcFaceId) as count FROM $faceClustersTable',
+    );
+    return maps.first['count'] as int;
+  }
+
+  Future<double> getClusteredToTotalFacesRatio() async {
+    final int totalFaces = await getTotalFaceCount();
+    final int clusteredFaces = await getClusteredFaceCount();
+
+    return clusteredFaces / totalFaces;
+  }
+
+  Future<int> getBlurryFaceCount([
+    int blurThreshold = kLaplacianHardThreshold,
+  ]) async {
+    final db = await instance.asyncDB;
+    final List<Map<String, dynamic>> maps = await db.getAll(
+      'SELECT COUNT(*) as count FROM $facesTable WHERE $faceBlur <= $blurThreshold AND $faceScore > $kMinimumQualityFaceScore',
+    );
+    return maps.first['count'] as int;
+  }
+
+  /// WARNING: This method does not drop the persons and other feedback. Consider using [dropClustersAndPersonTable] instead.
+  Future<void> resetClusterIDs() async {
+    try {
+      final db = await instance.asyncDB;
+
+      await db.execute(dropFaceClustersTable);
+      await db.execute(createFaceClustersTable);
+      await db.execute(fcClusterIDIndex);
+    } catch (e, s) {
+      _logger.severe('Error resetting clusterIDs', e, s);
+    }
+  }
+
+  Future<void> assignClusterToPerson({
+    required String personID,
+    required int clusterID,
+  }) async {
+    final db = await instance.asyncDB;
+
+    const String sql = '''
+      INSERT INTO $clusterPersonTable ($personIdColumn, $clusterIDColumn) VALUES (?, ?) ON CONFLICT($personIdColumn, $clusterIDColumn) DO NOTHING
+    ''';
+    await db.execute(sql, [personID, clusterID]);
+  }
+
+  Future<void> bulkAssignClusterToPersonID(
+    Map<int, String> clusterToPersonID,
+  ) async {
+    final db = await instance.asyncDB;
+
+    const String sql = '''
+      INSERT INTO $clusterPersonTable ($personIdColumn, $clusterIDColumn) VALUES (?, ?) ON CONFLICT($personIdColumn, $clusterIDColumn) DO NOTHING
+    ''';
+    final parameterSets =
+        clusterToPersonID.entries.map((e) => [e.value, e.key]).toList();
+    await db.executeBatch(sql, parameterSets);
+    // final batch = db.batch();
+    // for (final entry in clusterToPersonID.entries) {
+    //   final clusterID = entry.key;
+    //   final personID = entry.value;
+    //   batch.insert(
+    //     clusterPersonTable,
+    //     {
+    //       personIdColumn: personID,
+    //       clusterIDColumn: clusterID,
+    //     },
+    //     conflictAlgorithm: ConflictAlgorithm.replace,
+    //   );
+    // }
+    // await batch.commit(noResult: true);
+  }
+
+  Future<void> captureNotPersonFeedback({
+    required String personID,
+    required int clusterID,
+  }) async {
+    final db = await instance.asyncDB;
+
+    const String sql = '''
+      INSERT INTO $notPersonFeedback ($personIdColumn, $clusterIDColumn) VALUES (?, ?) ON CONFLICT DO NOTHING
+    ''';
+    await db.execute(sql, [personID, clusterID]);
+  }
+
+  Future<void> bulkCaptureNotPersonFeedback(
+    Map<int, String> clusterToPersonID,
+  ) async {
+    final db = await instance.asyncDB;
+
+    const String sql = '''
+      INSERT INTO $notPersonFeedback ($personIdColumn, $clusterIDColumn) VALUES (?, ?) ON CONFLICT DO NOTHING
+    ''';
+    final parameterSets =
+        clusterToPersonID.entries.map((e) => [e.value, e.key]).toList();
+
+    await db.executeBatch(sql, parameterSets);
+  }
+
+  Future<void> removeClusterToPerson({
+    required String personID,
+    required int clusterID,
+  }) async {
+    final db = await instance.asyncDB;
+
+    const String sql = '''
+      DELETE FROM $clusterPersonTable WHERE $personIdColumn = ? AND $clusterIDColumn = ?
+    ''';
+    await db.execute(sql, [personID, clusterID]);
+  }
+
+  // for a given personID, return a map of fileID to the set of clusterIDs using a join query
+  Future<Map<int, Set<int>>> getFileIdToClusterIDSet(String personID) {
+    final db = instance.asyncDB;
+    return db.then((db) async {
+      final List<Map<String, dynamic>> maps = await db.getAll(
+        'SELECT $faceClustersTable.$fcClusterID, $fcFaceId FROM $faceClustersTable '
+        'INNER JOIN $clusterPersonTable '
+        'ON $faceClustersTable.$fcClusterID = $clusterPersonTable.$clusterIDColumn '
+        'WHERE $clusterPersonTable.$personIdColumn = ?',
+        [personID],
+      );
+      final Map<int, Set<int>> result = {};
+      for (final map in maps) {
+        final clusterID = map[clusterIDColumn] as int;
+        final String faceID = map[fcFaceId] as String;
+        final fileID = int.parse(faceID.split('_').first);
+        result[fileID] = (result[fileID] ?? {})..add(clusterID);
+      }
+      return result;
+    });
+  }
+
+  Future<Map<int, Set<int>>> getFileIdToClusterIDSetForCluster(
+    Set<int> clusterIDs,
+  ) {
+    final db = instance.asyncDB;
+    return db.then((db) async {
+      final List<Map<String, dynamic>> maps = await db.getAll(
+        'SELECT $fcClusterID, $fcFaceId FROM $faceClustersTable '
+        'WHERE $fcClusterID IN (${clusterIDs.join(",")})',
+      );
+      final Map<int, Set<int>> result = {};
+      for (final map in maps) {
+        final clusterID = map[fcClusterID] as int;
+        final faceId = map[fcFaceId] as String;
+        final fileID = int.parse(faceId.split("_").first);
+        result[fileID] = (result[fileID] ?? {})..add(clusterID);
+      }
+      return result;
+    });
+  }
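+  // Added note: the cluster summaries below are (mean embedding blob, face
+  // count) pairs modelled as Dart records. Illustrative read-back sketch:
+  //   final summaries = await FaceMLDataDB.instance.getAllClusterSummary();
+  //   final (avg, count) = summaries[clusterID]!; // hypothetical clusterID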
+
+  Future<void> clusterSummaryUpdate(Map<int, (Uint8List, int)> summary) async {
+    final db = await instance.asyncDB;
+
+    const String sql = '''
+      INSERT INTO $clusterSummaryTable ($clusterIDColumn, $avgColumn, $countColumn) VALUES (?, ?, ?) ON CONFLICT($clusterIDColumn) DO UPDATE SET $avgColumn = excluded.$avgColumn, $countColumn = excluded.$countColumn
+    ''';
+    final List<List<Object?>> parameterSets = [];
+    int batchCounter = 0;
+    for (final entry in summary.entries) {
+      if (batchCounter == 400) {
+        await db.executeBatch(sql, parameterSets);
+        batchCounter = 0;
+        parameterSets.clear();
+      }
+      final int clusterID = entry.key;
+      final int count = entry.value.$2;
+      final Uint8List avg = entry.value.$1;
+      parameterSets.add([clusterID, avg, count]);
+      batchCounter++;
+    }
+    await db.executeBatch(sql, parameterSets);
+  }
+
+  Future<void> deleteClusterSummary(int clusterID) async {
+    final db = await instance.asyncDB;
+    const String sqlDelete =
+        'DELETE FROM $clusterSummaryTable WHERE $clusterIDColumn = ?';
+    await db.execute(sqlDelete, [clusterID]);
+  }
+
+  /// Returns a map of clusterID to (avg embedding, count)
+  Future<Map<int, (Uint8List, int)>> getAllClusterSummary([
+    int? minClusterSize,
+  ]) async {
+    final db = await instance.asyncDB;
+    final Map<int, (Uint8List, int)> result = {};
+    final rows = await db.getAll(
+      'SELECT * FROM $clusterSummaryTable${minClusterSize != null ? ' WHERE $countColumn >= $minClusterSize' : ''}',
+    );
+    for (final r in rows) {
+      final id = r[clusterIDColumn] as int;
+      final avg = r[avgColumn] as Uint8List;
+      final count = r[countColumn] as int;
+      result[id] = (avg, count);
+    }
+    return result;
+  }
+
+  Future<Map<int, (Uint8List, int)>> getClusterToClusterSummary(
+    Iterable<int> clusterIDs,
+  ) async {
+    final db = await instance.asyncDB;
+    final Map<int, (Uint8List, int)> result = {};
+    final rows = await db.getAll(
+      'SELECT * FROM $clusterSummaryTable WHERE $clusterIDColumn IN (${clusterIDs.join(",")})',
+    );
+    for (final r in rows) {
+      final id = r[clusterIDColumn] as int;
+      final avg = r[avgColumn] as Uint8List;
+      final count = r[countColumn] as int;
+      result[id] = (avg, count);
+    }
+    return result;
+  }
+
+  Future<Map<int, String>> getClusterIDToPersonID() async {
+    final db = await instance.asyncDB;
+    final List<Map<String, dynamic>> maps = await db.getAll(
+      'SELECT $personIdColumn, $clusterIDColumn FROM $clusterPersonTable',
+    );
+    final Map<int, String> result = {};
+    for (final map in maps) {
+      result[map[clusterIDColumn] as int] = map[personIdColumn] as String;
+    }
+    return result;
+  }
+
+  /// WARNING: This will delete ALL data in the database! Only use this for debug/testing purposes!
+  Future<void> dropClustersAndPersonTable({bool faces = false}) async {
+    try {
+      final db = await instance.asyncDB;
+      if (faces) {
+        await db.execute(deleteFacesTable);
+        await db.execute(createFacesTable);
+        await db.execute(dropFaceClustersTable);
+        await db.execute(createFaceClustersTable);
+        await db.execute(fcClusterIDIndex);
+      }
+
+      await db.execute(deletePersonTable);
+      await db.execute(dropClusterPersonTable);
+      await db.execute(dropNotPersonFeedbackTable);
+      await db.execute(dropClusterSummaryTable);
+      await db.execute(dropFaceClustersTable);
+
+      await db.execute(createClusterPersonTable);
+      await db.execute(createNotPersonFeedbackTable);
+      await db.execute(createClusterSummaryTable);
+      await db.execute(createFaceClustersTable);
+      await db.execute(fcClusterIDIndex);
+    } catch (e, s) {
+      _logger.severe('Error dropping clusters and person table', e, s);
+    }
+  }
+
+  /// WARNING: This will delete ALL data in the tables! Only use this for debug/testing purposes!
+  Future<void> dropFeedbackTables() async {
+    try {
+      final db = await instance.asyncDB;
+
+      // Drop the tables
+      await db.execute(deletePersonTable);
+      await db.execute(dropClusterPersonTable);
+      await db.execute(dropNotPersonFeedbackTable);
+
+      // Recreate the tables
+      await db.execute(createClusterPersonTable);
+      await db.execute(createNotPersonFeedbackTable);
+    } catch (e) {
+      _logger.severe('Error dropping feedback tables', e);
+    }
+  }
+
+  Future<void> removeFilesFromPerson(
+    List<EnteFile> files,
+    String personID,
+  ) async {
+    final db = await instance.asyncDB;
+    final faceIdsResult = await db.getAll(
+      'SELECT $fcFaceId FROM $faceClustersTable LEFT JOIN $clusterPersonTable '
+      'ON $faceClustersTable.$fcClusterID = $clusterPersonTable.$clusterIDColumn '
+      'WHERE $clusterPersonTable.$personIdColumn = ?',
+      [personID],
+    );
+    final Set<String> fileIds = {};
+    for (final enteFile in files) {
+      fileIds.add(enteFile.uploadedFileID.toString());
+    }
+    int maxClusterID = DateTime.now().microsecondsSinceEpoch;
+    final Map<String, int> faceIDToClusterID = {};
+    for (final row in faceIdsResult) {
+      final faceID = row[fcFaceId] as String;
+      if (fileIds.contains(faceID.split('_').first)) {
+        maxClusterID += 1;
+        faceIDToClusterID[faceID] = maxClusterID;
+      }
+    }
+    await forceUpdateClusterIds(faceIDToClusterID);
+  }
+
+  Future<void> removeFilesFromCluster(
+    List<EnteFile> files,
+    int clusterID,
+  ) async {
+    final db = await instance.asyncDB;
+    final faceIdsResult = await db.getAll(
+      'SELECT $fcFaceId FROM $faceClustersTable '
+      'WHERE $faceClustersTable.$fcClusterID = ?',
+      [clusterID],
+    );
+    final Set<String> fileIds = {};
+    for (final enteFile in files) {
+      fileIds.add(enteFile.uploadedFileID.toString());
+    }
+    int maxClusterID = DateTime.now().microsecondsSinceEpoch;
+    final Map<String, int> faceIDToClusterID = {};
+    for (final row in faceIdsResult) {
+      final faceID = row[fcFaceId] as String;
+      if (fileIds.contains(faceID.split('_').first)) {
+        maxClusterID += 1;
+        faceIDToClusterID[faceID] = maxClusterID;
+      }
+    }
+    await forceUpdateClusterIds(faceIDToClusterID);
+  }
+
+  Future<void> addFacesToCluster(
+    List<String> faceIDs,
+    int clusterID,
+  ) async {
+    final faceIDToClusterID = <String, int>{};
+    for (final faceID in faceIDs) {
+      faceIDToClusterID[faceID] = clusterID;
+    }
+
+    await forceUpdateClusterIds(faceIDToClusterID);
+  }
+}
diff --git a/mobile/lib/face/db_fields.dart b/mobile/lib/face/db_fields.dart
new file mode 100644
index 000000000..e6a70a7d4
--- /dev/null
+++ b/mobile/lib/face/db_fields.dart
@@ -0,0 +1,103 @@
+// Faces Table Fields & Schema Queries
+import 'package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart';
+
+const facesTable = 'faces';
+const fileIDColumn = 'file_id';
+const faceIDColumn = 'face_id';
+const faceDetectionColumn = 'detection';
+const faceEmbeddingBlob = 'eBlob';
+const faceScore = 'score';
+const faceBlur = 'blur';
+const isSideways = 'is_sideways';
+const imageWidth = 'width';
+const imageHeight = 'height';
+const faceClusterId = 'cluster_id';
+const mlVersionColumn = 'ml_version';
+
+const createFacesTable = '''CREATE TABLE IF NOT EXISTS $facesTable (
+  $fileIDColumn INTEGER NOT NULL,
+  $faceIDColumn TEXT NOT NULL UNIQUE,
+  $faceDetectionColumn TEXT NOT NULL,
+  $faceEmbeddingBlob BLOB NOT NULL,
+  $faceScore REAL NOT NULL,
+  $faceBlur REAL NOT NULL DEFAULT $kLapacianDefault,
+  $isSideways INTEGER NOT NULL DEFAULT 0,
+  $imageHeight INTEGER NOT NULL DEFAULT 0,
+  $imageWidth INTEGER NOT NULL DEFAULT 0,
+  $mlVersionColumn INTEGER NOT NULL DEFAULT -1,
+  PRIMARY KEY($fileIDColumn, $faceIDColumn)
+  );
+  ''';
+
+const deleteFacesTable = 'DROP TABLE IF EXISTS $facesTable';
+// End of Faces Table Fields & Schema Queries
+
+//##region Face Clusters Table Fields & Schema Queries
+const faceClustersTable = 'face_clusters';
+const fcClusterID = 'cluster_id';
+const fcFaceId = 'face_id';
+
+// fcFaceId is the primary key, and fcClusterId is a foreign key to the faces table
+const createFaceClustersTable = '''
+CREATE TABLE IF NOT EXISTS $faceClustersTable (
+  $fcFaceId TEXT NOT NULL,
+  $fcClusterID INTEGER NOT NULL,
+  PRIMARY KEY($fcFaceId)
+);
+''';
+// -- Creating a non-unique index on clusterID for query optimization
+const fcClusterIDIndex =
+    '''CREATE INDEX IF NOT EXISTS idx_fcClusterID ON $faceClustersTable($fcClusterID);''';
+const dropFaceClustersTable = 'DROP TABLE IF EXISTS $faceClustersTable';
+//##endregion
+
+// People Table Fields & Schema Queries
+const personTable = 'person';
+
+const deletePersonTable = 'DROP TABLE IF EXISTS $personTable';
+// End People Table Fields & Schema Queries
+
+// Clusters Table Fields & Schema Queries
+const clusterPersonTable = 'cluster_person';
+const personIdColumn = 'person_id';
+const clusterIDColumn = 'cluster_id';
+
+const createClusterPersonTable = '''
+CREATE TABLE IF NOT EXISTS $clusterPersonTable (
+  $personIdColumn TEXT NOT NULL,
+  $clusterIDColumn INTEGER NOT NULL,
+  PRIMARY KEY($personIdColumn, $clusterIDColumn)
+);
+''';
+const dropClusterPersonTable = 'DROP TABLE IF EXISTS $clusterPersonTable';
+// End Clusters Table Fields & Schema Queries
+
+/// Cluster Summary Table Fields & Schema Queries
+const clusterSummaryTable = 'cluster_summary';
+const avgColumn = 'avg';
+const countColumn = 'count';
+const createClusterSummaryTable = '''
+CREATE TABLE IF NOT EXISTS $clusterSummaryTable (
+  $clusterIDColumn INTEGER NOT NULL,
+  $avgColumn BLOB NOT NULL,
+  $countColumn INTEGER NOT NULL,
+  PRIMARY KEY($clusterIDColumn)
+);
+''';
+
+const dropClusterSummaryTable = 'DROP TABLE IF EXISTS $clusterSummaryTable';
+
+/// End Cluster Summary Table Fields & Schema Queries
+
+/// notPersonFeedback Table Fields & Schema Queries
+const notPersonFeedback = 'not_person_feedback';
+
+const createNotPersonFeedbackTable = '''
+CREATE TABLE IF NOT EXISTS $notPersonFeedback (
+  $personIdColumn TEXT NOT NULL,
+  $clusterIDColumn INTEGER NOT NULL,
+  PRIMARY KEY($personIdColumn, $clusterIDColumn)
+);
+''';
+const dropNotPersonFeedbackTable = 'DROP TABLE IF EXISTS $notPersonFeedback';
+// End notPersonFeedback Table Fields & Schema Queries
diff --git a/mobile/lib/face/db_model_mappers.dart b/mobile/lib/face/db_model_mappers.dart
new file mode 100644
index 000000000..70dc77915
--- /dev/null
+++ b/mobile/lib/face/db_model_mappers.dart
@@ -0,0 +1,57 @@
+import "dart:convert";
+
+import 'package:photos/face/db_fields.dart';
+import "package:photos/face/model/detection.dart";
+import "package:photos/face/model/face.dart";
+import "package:photos/generated/protos/ente/common/vector.pb.dart";
+import "package:photos/models/ml/ml_versions.dart";
+
+int boolToSQLInt(bool? value, {bool defaultValue = false}) {
+  final bool v = value ?? defaultValue;
+  if (v == false) {
+    return 0;
+  } else {
+    return 1;
+  }
+}
+
+bool sqlIntToBool(int? value, {bool defaultValue = false}) {
+  final int v = value ?? (defaultValue ? 1 : 0);
+  if (v == 0) {
+    return false;
+  } else {
+    return true;
+  }
+}
+
+Map<String, dynamic> mapRemoteToFaceDB(Face face) {
+  return {
+    faceIDColumn: face.faceID,
+    fileIDColumn: face.fileID,
+    faceDetectionColumn: json.encode(face.detection.toJson()),
+    faceEmbeddingBlob: EVector(
+      values: face.embedding,
+    ).writeToBuffer(),
+    faceScore: face.score,
+    faceBlur: face.blur,
+    isSideways: face.detection.faceIsSideways() ? 1 : 0,
+    mlVersionColumn: faceMlVersion,
+    imageWidth: face.fileInfo?.imageWidth ?? 0,
+    imageHeight: face.fileInfo?.imageHeight ?? 0,
+  };
+}
+
+Face mapRowToFace(Map<String, dynamic> row) {
+  return Face(
+    row[faceIDColumn] as String,
+    row[fileIDColumn] as int,
+    EVector.fromBuffer(row[faceEmbeddingBlob] as List<int>).values,
+    row[faceScore] as double,
+    Detection.fromJson(json.decode(row[faceDetectionColumn] as String)),
+    row[faceBlur] as double,
+    fileInfo: FileInfo(
+      imageWidth: row[imageWidth] as int,
+      imageHeight: row[imageHeight] as int,
+    ),
+  );
+}
diff --git a/mobile/lib/face/model/box.dart b/mobile/lib/face/model/box.dart
new file mode 100644
index 000000000..73d7dea38
--- /dev/null
+++ b/mobile/lib/face/model/box.dart
@@ -0,0 +1,43 @@
+/// Bounding box of a face.
+///
+/// [xMin] and [yMin] are the coordinates of the top left corner of the box, and
+/// [width] and [height] are the width and height of the box.
+///
+/// WARNING: All values are relative to the original image size, so in the range [0, 1].
+class FaceBox {
+  final double xMin;
+  final double yMin;
+  final double width;
+  final double height;
+
+  FaceBox({
+    required this.xMin,
+    required this.yMin,
+    required this.width,
+    required this.height,
+  });
+
+  factory FaceBox.fromJson(Map<String, dynamic> json) {
+    return FaceBox(
+      xMin: (json['xMin'] is int
+          ? (json['xMin'] as int).toDouble()
+          : json['xMin'] as double),
+      yMin: (json['yMin'] is int
+          ? (json['yMin'] as int).toDouble()
+          : json['yMin'] as double),
+      width: (json['width'] is int
+          ? (json['width'] as int).toDouble()
+          : json['width'] as double),
+      height: (json['height'] is int
+          ? (json['height'] as int).toDouble()
+          : json['height'] as double),
+    );
+  }
+
+  Map<String, dynamic> toJson() => {
+        'xMin': xMin,
+        'yMin': yMin,
+        'width': width,
+        'height': height,
+      };
+}
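+// Illustrative sketch (added note, not part of the original change): since the
+// box is stored relative to the original image, absolute pixel coordinates for
+// an image of known size can be recovered as:
+//   final left = box.xMin * imageWidth;
+//   final top = box.yMin * imageHeight;
+//   final widthPx = box.width * imageWidth;
+//   final heightPx = box.height * imageHeight;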
diff --git a/mobile/lib/face/model/detection.dart b/mobile/lib/face/model/detection.dart
new file mode 100644
index 000000000..6fc5fa07b
--- /dev/null
+++ b/mobile/lib/face/model/detection.dart
@@ -0,0 +1,120 @@
+import "dart:math" show min, max;
+
+import "package:photos/face/model/box.dart";
+import "package:photos/face/model/landmark.dart";
+import "package:photos/services/machine_learning/face_ml/face_detection/detection.dart";
+
+/// Stores the face detection data, notably the bounding box and landmarks.
+///
+/// - Bounding box: [FaceBox] with xMin, yMin (so top left corner), width, height
+/// - Landmarks: list of [Landmark]s, namely leftEye, rightEye, nose, leftMouth, rightMouth
+///
+/// WARNING: All coordinates are relative to the image size, so in the range [0, 1]!
+class Detection {
+  FaceBox box;
+  List<Landmark> landmarks;
+
+  Detection({
+    required this.box,
+    required this.landmarks,
+  });
+
+  bool get isEmpty => box.width == 0 && box.height == 0 && landmarks.isEmpty;
+
+  // empty box
+  Detection.empty()
+      : box = FaceBox(
+          xMin: 0,
+          yMin: 0,
+          width: 0,
+          height: 0,
+        ),
+        landmarks = [];
+
+  Map<String, dynamic> toJson() => {
+        'box': box.toJson(),
+        'landmarks': landmarks.map((x) => x.toJson()).toList(),
+      };
+
+  factory Detection.fromJson(Map<String, dynamic> json) {
+    return Detection(
+      box: FaceBox.fromJson(json['box'] as Map<String, dynamic>),
+      landmarks: List<Landmark>.from(
+        json['landmarks']
+            .map((x) => Landmark.fromJson(x as Map<String, dynamic>)),
+      ),
+    );
+  }
+
+  int getFaceArea(int imageWidth, int imageHeight) {
+    return (box.width * imageWidth * box.height * imageHeight).toInt();
+  }
+
+  FaceDirection getFaceDirection() {
+    if (isEmpty) {
+      return FaceDirection.straight;
+    }
+    final leftEye = [landmarks[0].x, landmarks[0].y];
+    final rightEye = [landmarks[1].x, landmarks[1].y];
+    final nose = [landmarks[2].x, landmarks[2].y];
+    final leftMouth = [landmarks[3].x, landmarks[3].y];
+    final rightMouth = [landmarks[4].x, landmarks[4].y];
+
+    final double eyeDistanceX = (rightEye[0] - leftEye[0]).abs();
+    final double eyeDistanceY = (rightEye[1] - leftEye[1]).abs();
+    final double mouthDistanceY = (rightMouth[1] - leftMouth[1]).abs();
+
+    final bool faceIsUpright =
+        (max(leftEye[1], rightEye[1]) + 0.5 * eyeDistanceY < nose[1]) &&
+            (nose[1] + 0.5 * mouthDistanceY < min(leftMouth[1], rightMouth[1]));
+
+    final bool noseStickingOutLeft = (nose[0] < min(leftEye[0], rightEye[0])) &&
+        (nose[0] < min(leftMouth[0], rightMouth[0]));
+    final bool noseStickingOutRight =
+        (nose[0] > max(leftEye[0], rightEye[0])) &&
+            (nose[0] > max(leftMouth[0], rightMouth[0]));
+
+    final bool noseCloseToLeftEye =
+        (nose[0] - leftEye[0]).abs() < 0.2 * eyeDistanceX;
+    final bool noseCloseToRightEye =
+        (nose[0] - rightEye[0]).abs() < 0.2 * eyeDistanceX;
+
+    // if (faceIsUpright && (noseStickingOutLeft || noseCloseToLeftEye)) {
+    if (noseStickingOutLeft || (faceIsUpright && noseCloseToLeftEye)) {
+      return FaceDirection.left;
+      // } else if (faceIsUpright && (noseStickingOutRight || noseCloseToRightEye)) {
+    } else if (noseStickingOutRight || (faceIsUpright && noseCloseToRightEye)) {
+      return FaceDirection.right;
+    }
+
+    return FaceDirection.straight;
+  }
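+  // Illustrative example (added note; coordinates are hypothetical): with eyes
+  // at x = 0.3 and 0.5 (so eyeDistanceX = 0.2), mouth corners below them, and
+  // the nose at x = 0.05, the nose lies more than 0.5 * eyeDistanceX to the
+  // left of both eyes and left of both mouth corners, so an upright face is
+  // reported as sideways by faceIsSideways() below.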
+
+  bool faceIsSideways() {
+    if (isEmpty) {
+      return false;
+    }
+    final leftEye = [landmarks[0].x, landmarks[0].y];
+    final rightEye = [landmarks[1].x, landmarks[1].y];
+    final nose = [landmarks[2].x, landmarks[2].y];
+    final leftMouth = [landmarks[3].x, landmarks[3].y];
+    final rightMouth = [landmarks[4].x, landmarks[4].y];
+
+    final double eyeDistanceX = (rightEye[0] - leftEye[0]).abs();
+    final double eyeDistanceY = (rightEye[1] - leftEye[1]).abs();
+    final double mouthDistanceY = (rightMouth[1] - leftMouth[1]).abs();
+
+    final bool faceIsUpright =
+        (max(leftEye[1], rightEye[1]) + 0.5 * eyeDistanceY < nose[1]) &&
+            (nose[1] + 0.5 * mouthDistanceY < min(leftMouth[1], rightMouth[1]));
+
+    final bool noseStickingOutLeft =
+        (nose[0] < min(leftEye[0], rightEye[0]) - 0.5 * eyeDistanceX) &&
+            (nose[0] < min(leftMouth[0], rightMouth[0]));
+    final bool noseStickingOutRight =
+        (nose[0] > max(leftEye[0], rightEye[0]) + 0.5 * eyeDistanceX) &&
+            (nose[0] > max(leftMouth[0], rightMouth[0]));
+
+    return faceIsUpright && (noseStickingOutLeft || noseStickingOutRight);
+  }
+}
diff --git a/mobile/lib/face/model/dimension.dart b/mobile/lib/face/model/dimension.dart
new file mode 100644
index 000000000..d4ae7a3bc
--- /dev/null
+++ b/mobile/lib/face/model/dimension.dart
@@ -0,0 +1,25 @@
+class Dimensions {
+  final int width;
+  final int height;
+
+  const Dimensions({required this.width, required this.height});
+
+  @override
+  String toString() {
+    return 'Dimensions(width: $width, height: $height)';
+  }
+
+  Map<String, dynamic> toJson() {
+    return {
+      'width': width,
+      'height': height,
+    };
+  }
+
+  factory Dimensions.fromJson(Map<String, dynamic> json) {
+    return Dimensions(
+      width: json['width'] as int,
+      height: json['height'] as int,
+    );
+  }
+}
diff --git a/mobile/lib/face/model/face.dart b/mobile/lib/face/model/face.dart
new file mode 100644
index 000000000..c21538949
--- /dev/null
+++ b/mobile/lib/face/model/face.dart
@@ -0,0 +1,85 @@
+import "package:photos/face/model/detection.dart";
+import 'package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart';
+import "package:photos/services/machine_learning/face_ml/face_ml_result.dart";
+
+// FileInfo contains the image width and height of the image the face was detected in.
+class FileInfo {
+  int? imageWidth;
+  int? imageHeight;
+  FileInfo({
+    this.imageWidth,
+    this.imageHeight,
+  });
+}
+
+class Face {
+  final String faceID;
+  final List<double> embedding;
+  Detection detection;
+  final double score;
+  final double blur;
+
+  ///#region Local DB fields
+  // This is not stored on the server, using it for local DB row
+  FileInfo? fileInfo;
+  final int fileID;
+  ///#endregion
+
+  bool get isBlurry => blur < kLaplacianHardThreshold;
+
+  bool get hasHighScore => score > kMinimumQualityFaceScore;
+
+  bool get isHighQuality => (!isBlurry) && hasHighScore;
+
+  int area({int? w, int? h}) {
+    return detection.getFaceArea(
+      fileInfo?.imageWidth ?? w ?? 0,
+      fileInfo?.imageHeight ?? h ?? 0,
+    );
+  }
+
+  Face(
+    this.faceID,
+    this.fileID,
+    this.embedding,
+    this.score,
+    this.detection,
+    this.blur, {
+    this.fileInfo,
+  });
+
+  factory Face.empty(int fileID, {bool error = false}) {
+    return Face(
+      "$fileID-0",
+      fileID,
+      <double>[],
+      error ? -1.0 : 0.0,
+      Detection.empty(),
+      0.0,
+    );
+  }
+
+  factory Face.fromJson(Map<String, dynamic> json) {
+    final String faceID = json['faceID'] as String;
+    final int fileID = getFileIdFromFaceId(faceID);
+    return Face(
+      faceID,
+      fileID,
+      List<double>.from((json['embedding'] ?? json['embeddings']) as List),
+      json['score'] as double,
+      Detection.fromJson(json['detection'] as Map<String, dynamic>),
+      // high value means the face is sharp, low value means it is blurry
+      (json['blur'] ?? kLapacianDefault) as double,
+    );
+  }
+
+  // Note: Keep the information in toJson minimum. Keep in sync with desktop.
+  // Derive fields like fileID from other values whenever possible
+  Map<String, dynamic> toJson() => {
+        'faceID': faceID,
+        'embedding': embedding,
+        'detection': detection.toJson(),
+        'score': score,
+        'blur': blur,
+      };
+}
diff --git a/mobile/lib/face/model/landmark.dart b/mobile/lib/face/model/landmark.dart
new file mode 100644
index 000000000..320afbabd
--- /dev/null
+++ b/mobile/lib/face/model/landmark.dart
@@ -0,0 +1,33 @@
+/// Landmark coordinate data.
+///
+/// WARNING: All coordinates are relative to the image size, so in the range [0, 1]!
+class Landmark {
+  double x;
+  double y;
+
+  Landmark({
+    required this.x,
+    required this.y,
+  });
+
+  Map<String, dynamic> toJson() => {
+        'x': x,
+        'y': y,
+      };
+
+  factory Landmark.fromJson(Map<String, dynamic> json) {
+    return Landmark(
+      x: (json['x'] is int
+          ? (json['x'] as int).toDouble()
+          : json['x'] as double),
+      y: (json['y'] is int
+          ? (json['y'] as int).toDouble()
+          : json['y'] as double),
+    );
+  }
+
+  @override
+  String toString() {
+    return '(x: ${x.toStringAsFixed(4)}, y: ${y.toStringAsFixed(4)})';
+  }
+}
diff --git a/mobile/lib/face/model/person.dart b/mobile/lib/face/model/person.dart
new file mode 100644
index 000000000..cedec7a0d
--- /dev/null
+++ b/mobile/lib/face/model/person.dart
@@ -0,0 +1,139 @@
+// PersonEntity represents information about a Person in the context of FaceClustering that is stored.
+// On the remote server, the PersonEntity is stored as {Entity} with type person.
+// On the device, this information is stored as [LocalEntityData] with type person.
+import "package:flutter/foundation.dart";
+
+class PersonEntity {
+  final String remoteID;
+  final PersonData data;
+  PersonEntity(
+    this.remoteID,
+    this.data,
+  );
+
+  // copyWith
+  PersonEntity copyWith({
+    String? remoteID,
+    PersonData? data,
+  }) {
+    return PersonEntity(
+      remoteID ?? this.remoteID,
+      data ?? this.data,
+    );
+  }
+}
+
+class ClusterInfo {
+  final int id;
+  final Set<String> faces;
+  ClusterInfo({
+    required this.id,
+    required this.faces,
+  });
+
+  // toJson
+  Map<String, dynamic> toJson() => {
+        'id': id,
+        'faces': faces.toList(),
+      };
+
+  // from Json
+  factory ClusterInfo.fromJson(Map<String, dynamic> json) {
+    return ClusterInfo(
+      id: json['id'] as int,
+      faces: (json['faces'] as List).map((e) => e as String).toSet(),
+    );
+  }
+}
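+// Illustrative example (added note; values are hypothetical) of the JSON shape
+// produced by ClusterInfo.toJson() and consumed by ClusterInfo.fromJson():
+//   {"id": 42, "faces": ["10001_0", "10002_1"]}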
+
+class PersonData {
+  final String name;
+  final bool isHidden;
+  String? avatarFaceId;
+  List<ClusterInfo>? assigned = List.empty();
+  List<ClusterInfo>? rejected = List.empty();
+  final String? birthDate;
+
+  bool hasAvatar() => avatarFaceId != null;
+
+  bool get isIgnored =>
+      (name.isEmpty || name == '(hidden)' || name == '(ignored)');
+
+  PersonData({
+    required this.name,
+    this.assigned,
+    this.rejected,
+    this.avatarFaceId,
+    this.isHidden = false,
+    this.birthDate,
+  });
+
+  // copyWith
+  PersonData copyWith({
+    String? name,
+    List<ClusterInfo>? assigned,
+    String? avatarFaceId,
+    bool? isHidden,
+    int? version,
+    String? birthDate,
+  }) {
+    return PersonData(
+      name: name ?? this.name,
+      assigned: assigned ?? this.assigned,
+      avatarFaceId: avatarFaceId ?? this.avatarFaceId,
+      isHidden: isHidden ?? this.isHidden,
+      birthDate: birthDate ?? this.birthDate,
+    );
+  }
+
+  void logStats() {
+    if (kDebugMode == false) return;
+    // log number of assigned and rejected clusters and total number of faces in each cluster
+    final StringBuffer sb = StringBuffer();
+    sb.writeln('Person: $name');
+    int assignedCount = 0;
+    for (final a in (assigned ?? <ClusterInfo>[])) {
+      assignedCount += a.faces.length;
+    }
+    sb.writeln('Assigned: ${assigned?.length} withFaces $assignedCount');
+    sb.writeln('Rejected: ${rejected?.length}');
+    if (assigned != null) {
+      for (var cluster in assigned!) {
+        sb.writeln('Cluster: ${cluster.id} - ${cluster.faces.length}');
+      }
+    }
+    debugPrint(sb.toString());
+  }
+
+  // toJson
+  Map<String, dynamic> toJson() => {
+        'name': name,
+        'assigned': assigned?.map((e) => e.toJson()).toList(),
+        'rejected': rejected?.map((e) => e.toJson()).toList(),
+        'avatarFaceId': avatarFaceId,
+        'isHidden': isHidden,
+        'birthDate': birthDate,
+      };
+
+  // fromJson
+  factory PersonData.fromJson(Map<String, dynamic> json) {
+    final assigned = (json['assigned'] == null || json['assigned'].length == 0)
+        ? <ClusterInfo>[]
+        : List<ClusterInfo>.from(
+            json['assigned'].map((x) => ClusterInfo.fromJson(x)),
+          );
+
+    final rejected = (json['rejected'] == null || json['rejected'].length == 0)
+        ? <ClusterInfo>[]
+        : List<ClusterInfo>.from(
+            json['rejected'].map((x) => ClusterInfo.fromJson(x)),
+          );
+    return PersonData(
+      name: json['name'] as String,
+      assigned: assigned,
+      rejected: rejected,
+      avatarFaceId: json['avatarFaceId'] as String?,
+      isHidden: json['isHidden'] as bool? ?? false,
+      birthDate: json['birthDate'] as String?,
+    );
+  }
+}
diff --git a/mobile/lib/generated/intl/messages_cs.dart b/mobile/lib/generated/intl/messages_cs.dart
index 8db8489d3..4506011b1 100644
--- a/mobile/lib/generated/intl/messages_cs.dart
+++ b/mobile/lib/generated/intl/messages_cs.dart
@@ -34,6 +34,8 @@ class MessageLookup extends MessageLookupByLibrary {
         "addViewers": m1,
         "changeLocationOfSelectedItems": MessageLookupByLibrary.simpleMessage(
             "Change location of selected items?"),
+        "clusteringProgress":
+            MessageLookupByLibrary.simpleMessage("Clustering progress"),
         "contacts": MessageLookupByLibrary.simpleMessage("Contacts"),
         "createCollaborativeLink":
             MessageLookupByLibrary.simpleMessage("Create collaborative link"),
@@ -44,7 +46,14 @@ class MessageLookup extends MessageLookupByLibrary {
         "editsToLocationWillOnlyBeSeenWithinEnte":
             MessageLookupByLibrary.simpleMessage(
                 "Edits to location will only be seen within Ente"),
+        "enterPersonName":
+            MessageLookupByLibrary.simpleMessage("Enter person name"),
+        "faceRecognition":
+            MessageLookupByLibrary.simpleMessage("Face recognition"),
+        "faceRecognitionIndexingDescription": MessageLookupByLibrary.simpleMessage(
+            "Please note that this will result in a higher bandwidth and battery usage until all items are indexed."),
         "fileTypes": MessageLookupByLibrary.simpleMessage("File types"),
+        "foundFaces": MessageLookupByLibrary.simpleMessage("Found faces"),
         "joinDiscord": MessageLookupByLibrary.simpleMessage("Join Discord"),
         "locations": MessageLookupByLibrary.simpleMessage("Locations"),
         "longPressAnEmailToVerifyEndToEndEncryption":
@@ -55,6 +64,8 @@ class MessageLookup extends MessageLookupByLibrary {
             "Modify your query, or try searching for"),
         "moveToHiddenAlbum":
             MessageLookupByLibrary.simpleMessage("Move to hidden album"),
+        "removePersonLabel":
+            MessageLookupByLibrary.simpleMessage("Remove person label"),
         "search": MessageLookupByLibrary.simpleMessage("Search"),
         "selectALocation":
             MessageLookupByLibrary.simpleMessage("Select a location"),
diff --git a/mobile/lib/generated/intl/messages_de.dart b/mobile/lib/generated/intl/messages_de.dart
index 442cae919..0ff50cfa4 100644
--- a/mobile/lib/generated/intl/messages_de.dart
+++ b/mobile/lib/generated/intl/messages_de.dart
@@ -227,6 +227,7 @@ class MessageLookup extends MessageLookupByLibrary {
             "Ich verstehe, dass ich meine Daten verlieren kann, wenn ich mein Passwort vergesse, da meine Daten Ende-zu-Ende-verschlüsselt sind."),
         "activeSessions":
             MessageLookupByLibrary.simpleMessage("Aktive Sitzungen"),
+        "addAName": MessageLookupByLibrary.simpleMessage("Add a name"),
         "addANewEmail": MessageLookupByLibrary.simpleMessage(
             "Neue E-Mail-Adresse hinzufügen"),
         "addCollaborator":
@@ -435,6 +436,8 @@ class MessageLookup extends MessageLookupByLibrary {
             "Nach Aufnahmezeit gruppieren"),
         "clubByFileName":
             MessageLookupByLibrary.simpleMessage("Nach Dateiname gruppieren"),
+        "clusteringProgress":
+            MessageLookupByLibrary.simpleMessage("Clustering progress"),
         "codeAppliedPageTitle":
             MessageLookupByLibrary.simpleMessage("Code eingelöst"),
         "codeCopiedToClipboard": MessageLookupByLibrary.simpleMessage(
@@ -675,6 +678,8 @@ class MessageLookup extends MessageLookupByLibrary {
             MessageLookupByLibrary.simpleMessage("Passwort eingeben"),
"enterPasswordToEncrypt": MessageLookupByLibrary.simpleMessage( "Gib ein Passwort ein, mit dem wir deine Daten verschlüsseln können"), + "enterPersonName": + MessageLookupByLibrary.simpleMessage("Enter person name"), "enterReferralCode": MessageLookupByLibrary.simpleMessage( "Gib den Weiterempfehlungs-Code ein"), "enterThe6digitCodeFromnyourAuthenticatorApp": @@ -699,6 +704,10 @@ class MessageLookup extends MessageLookupByLibrary { MessageLookupByLibrary.simpleMessage("Protokolle exportieren"), "exportYourData": MessageLookupByLibrary.simpleMessage("Daten exportieren"), + "faceRecognition": + MessageLookupByLibrary.simpleMessage("Face recognition"), + "faceRecognitionIndexingDescription": MessageLookupByLibrary.simpleMessage( + "Please note that this will result in a higher bandwidth and battery usage until all items are indexed."), "faces": MessageLookupByLibrary.simpleMessage("Gesichter"), "failedToApplyCode": MessageLookupByLibrary.simpleMessage( "Der Code konnte nicht aktiviert werden"), @@ -738,11 +747,14 @@ class MessageLookup extends MessageLookupByLibrary { "filesBackedUpInAlbum": m23, "filesDeleted": MessageLookupByLibrary.simpleMessage("Dateien gelöscht"), + "findPeopleByName": MessageLookupByLibrary.simpleMessage( + "Find people quickly by searching by name"), "flip": MessageLookupByLibrary.simpleMessage("Spiegeln"), "forYourMemories": MessageLookupByLibrary.simpleMessage("Als Erinnerung"), "forgotPassword": MessageLookupByLibrary.simpleMessage("Passwort vergessen"), + "foundFaces": MessageLookupByLibrary.simpleMessage("Found faces"), "freeStorageClaimed": MessageLookupByLibrary.simpleMessage( "Kostenlos hinzugefügter Speicherplatz"), "freeStorageOnReferralSuccess": m24, @@ -1164,6 +1176,8 @@ class MessageLookup extends MessageLookupByLibrary { "removeParticipant": MessageLookupByLibrary.simpleMessage("Teilnehmer entfernen"), "removeParticipantBody": m43, + "removePersonLabel": + MessageLookupByLibrary.simpleMessage("Remove person label"), "removePublicLink": MessageLookupByLibrary.simpleMessage("Öffentlichen Link entfernen"), "removeShareItemsWarning": MessageLookupByLibrary.simpleMessage( diff --git a/mobile/lib/generated/intl/messages_en.dart b/mobile/lib/generated/intl/messages_en.dart index aab7f47bd..ee799aeb9 100644 --- a/mobile/lib/generated/intl/messages_en.dart +++ b/mobile/lib/generated/intl/messages_en.dart @@ -132,7 +132,7 @@ class MessageLookup extends MessageLookupByLibrary { "Please talk to ${providerName} support if you were charged"; static String m38(endDate) => - "Free trial valid till ${endDate}.\nYou can purchase a paid plan afterwards."; + "Free trial valid till ${endDate}.\nYou can choose a paid plan afterwards."; static String m39(toEmail) => "Please email us at ${toEmail}"; @@ -225,6 +225,7 @@ class MessageLookup extends MessageLookupByLibrary { "I understand that if I lose my password, I may lose my data since my data is end-to-end encrypted."), "activeSessions": MessageLookupByLibrary.simpleMessage("Active sessions"), + "addAName": MessageLookupByLibrary.simpleMessage("Add a name"), "addANewEmail": MessageLookupByLibrary.simpleMessage("Add a new email"), "addCollaborator": MessageLookupByLibrary.simpleMessage("Add collaborator"), @@ -434,6 +435,8 @@ class MessageLookup extends MessageLookupByLibrary { MessageLookupByLibrary.simpleMessage("Club by capture time"), "clubByFileName": MessageLookupByLibrary.simpleMessage("Club by file name"), + "clusteringProgress": + MessageLookupByLibrary.simpleMessage("Clustering progress"), "codeAppliedPageTitle": 
MessageLookupByLibrary.simpleMessage("Code applied"), "codeCopiedToClipboard": @@ -675,6 +678,8 @@ class MessageLookup extends MessageLookupByLibrary { "enterPassword": MessageLookupByLibrary.simpleMessage("Enter password"), "enterPasswordToEncrypt": MessageLookupByLibrary.simpleMessage( "Enter a password we can use to encrypt your data"), + "enterPersonName": + MessageLookupByLibrary.simpleMessage("Enter person name"), "enterReferralCode": MessageLookupByLibrary.simpleMessage("Enter referral code"), "enterThe6digitCodeFromnyourAuthenticatorApp": @@ -697,6 +702,10 @@ class MessageLookup extends MessageLookupByLibrary { "exportLogs": MessageLookupByLibrary.simpleMessage("Export logs"), "exportYourData": MessageLookupByLibrary.simpleMessage("Export your data"), + "faceRecognition": + MessageLookupByLibrary.simpleMessage("Face recognition"), + "faceRecognitionIndexingDescription": MessageLookupByLibrary.simpleMessage( + "Please note that this will result in a higher bandwidth and battery usage until all items are indexed."), "faces": MessageLookupByLibrary.simpleMessage("Faces"), "failedToApplyCode": MessageLookupByLibrary.simpleMessage("Failed to apply code"), @@ -736,11 +745,14 @@ class MessageLookup extends MessageLookupByLibrary { "filesDeleted": MessageLookupByLibrary.simpleMessage("Files deleted"), "filesSavedToGallery": MessageLookupByLibrary.simpleMessage("Files saved to gallery"), + "findPeopleByName": + MessageLookupByLibrary.simpleMessage("Find people quickly by name"), "flip": MessageLookupByLibrary.simpleMessage("Flip"), "forYourMemories": MessageLookupByLibrary.simpleMessage("for your memories"), "forgotPassword": MessageLookupByLibrary.simpleMessage("Forgot password"), + "foundFaces": MessageLookupByLibrary.simpleMessage("Found faces"), "freeStorageClaimed": MessageLookupByLibrary.simpleMessage("Free storage claimed"), "freeStorageOnReferralSuccess": m24, @@ -1022,6 +1034,7 @@ class MessageLookup extends MessageLookupByLibrary { "paymentFailedTalkToProvider": m37, "pendingItems": MessageLookupByLibrary.simpleMessage("Pending items"), "pendingSync": MessageLookupByLibrary.simpleMessage("Pending sync"), + "people": MessageLookupByLibrary.simpleMessage("People"), "peopleUsingYourCode": MessageLookupByLibrary.simpleMessage("People using your code"), "permDeleteWarning": MessageLookupByLibrary.simpleMessage( @@ -1151,6 +1164,8 @@ class MessageLookup extends MessageLookupByLibrary { "removeParticipant": MessageLookupByLibrary.simpleMessage("Remove participant"), "removeParticipantBody": m43, + "removePersonLabel": + MessageLookupByLibrary.simpleMessage("Remove person label"), "removePublicLink": MessageLookupByLibrary.simpleMessage("Remove public link"), "removeShareItemsWarning": MessageLookupByLibrary.simpleMessage( @@ -1208,8 +1223,8 @@ class MessageLookup extends MessageLookupByLibrary { "Add descriptions like \"#trip\" in photo info to quickly find them here"), "searchDatesEmptySection": MessageLookupByLibrary.simpleMessage( "Search by a date, month or year"), - "searchFaceEmptySection": - MessageLookupByLibrary.simpleMessage("Find all photos of a person"), + "searchFaceEmptySection": MessageLookupByLibrary.simpleMessage( + "Persons will be shown here once indexing is done"), "searchFileTypesAndNamesEmptySection": MessageLookupByLibrary.simpleMessage("File types and names"), "searchHint1": diff --git a/mobile/lib/generated/intl/messages_es.dart b/mobile/lib/generated/intl/messages_es.dart index a6294d4a4..879f0f8c1 100644 --- a/mobile/lib/generated/intl/messages_es.dart +++ 
b/mobile/lib/generated/intl/messages_es.dart @@ -367,6 +367,8 @@ class MessageLookup extends MessageLookupByLibrary { "close": MessageLookupByLibrary.simpleMessage("Cerrar"), "clubByCaptureTime": MessageLookupByLibrary.simpleMessage( "Agrupar por tiempo de captura"), + "clusteringProgress": + MessageLookupByLibrary.simpleMessage("Clustering progress"), "codeAppliedPageTitle": MessageLookupByLibrary.simpleMessage("Código aplicado"), "codeCopiedToClipboard": MessageLookupByLibrary.simpleMessage( @@ -585,6 +587,8 @@ class MessageLookup extends MessageLookupByLibrary { MessageLookupByLibrary.simpleMessage("Introduzca contraseña"), "enterPasswordToEncrypt": MessageLookupByLibrary.simpleMessage( "Introduzca una contraseña que podamos usar para cifrar sus datos"), + "enterPersonName": + MessageLookupByLibrary.simpleMessage("Enter person name"), "enterReferralCode": MessageLookupByLibrary.simpleMessage( "Ingresar código de referencia"), "enterThe6digitCodeFromnyourAuthenticatorApp": @@ -609,6 +613,10 @@ class MessageLookup extends MessageLookupByLibrary { MessageLookupByLibrary.simpleMessage("Exportar registros"), "exportYourData": MessageLookupByLibrary.simpleMessage("Exportar tus datos"), + "faceRecognition": + MessageLookupByLibrary.simpleMessage("Face recognition"), + "faceRecognitionIndexingDescription": MessageLookupByLibrary.simpleMessage( + "Please note that this will result in a higher bandwidth and battery usage until all items are indexed."), "failedToApplyCode": MessageLookupByLibrary.simpleMessage("Error al aplicar el código"), "failedToCancel": @@ -647,6 +655,7 @@ class MessageLookup extends MessageLookupByLibrary { MessageLookupByLibrary.simpleMessage("para tus recuerdos"), "forgotPassword": MessageLookupByLibrary.simpleMessage("Olvidé mi contraseña"), + "foundFaces": MessageLookupByLibrary.simpleMessage("Found faces"), "freeStorageClaimed": MessageLookupByLibrary.simpleMessage( "Almacenamiento gratuito reclamado"), "freeStorageOnReferralSuccess": m24, @@ -997,6 +1006,8 @@ class MessageLookup extends MessageLookupByLibrary { "removeParticipant": MessageLookupByLibrary.simpleMessage("Quitar participante"), "removeParticipantBody": m43, + "removePersonLabel": + MessageLookupByLibrary.simpleMessage("Remove person label"), "removePublicLink": MessageLookupByLibrary.simpleMessage("Quitar enlace público"), "removeShareItemsWarning": MessageLookupByLibrary.simpleMessage( diff --git a/mobile/lib/generated/intl/messages_fr.dart b/mobile/lib/generated/intl/messages_fr.dart index 82125afcc..47817371e 100644 --- a/mobile/lib/generated/intl/messages_fr.dart +++ b/mobile/lib/generated/intl/messages_fr.dart @@ -425,6 +425,8 @@ class MessageLookup extends MessageLookupByLibrary { MessageLookupByLibrary.simpleMessage("Grouper par durée"), "clubByFileName": MessageLookupByLibrary.simpleMessage("Grouper par nom de fichier"), + "clusteringProgress": + MessageLookupByLibrary.simpleMessage("Clustering progress"), "codeAppliedPageTitle": MessageLookupByLibrary.simpleMessage("Code appliqué"), "codeCopiedToClipboard": MessageLookupByLibrary.simpleMessage( @@ -665,6 +667,8 @@ class MessageLookup extends MessageLookupByLibrary { MessageLookupByLibrary.simpleMessage("Saisissez le mot de passe"), "enterPasswordToEncrypt": MessageLookupByLibrary.simpleMessage( "Entrez un mot de passe que nous pouvons utiliser pour chiffrer vos données"), + "enterPersonName": + MessageLookupByLibrary.simpleMessage("Enter person name"), "enterReferralCode": MessageLookupByLibrary.simpleMessage( "Entrez le code de parrainage"), 
"enterThe6digitCodeFromnyourAuthenticatorApp": @@ -688,6 +692,10 @@ class MessageLookup extends MessageLookupByLibrary { "exportLogs": MessageLookupByLibrary.simpleMessage("Exporter les logs"), "exportYourData": MessageLookupByLibrary.simpleMessage("Exportez vos données"), + "faceRecognition": + MessageLookupByLibrary.simpleMessage("Face recognition"), + "faceRecognitionIndexingDescription": MessageLookupByLibrary.simpleMessage( + "Please note that this will result in a higher bandwidth and battery usage until all items are indexed."), "faces": MessageLookupByLibrary.simpleMessage("Visages"), "failedToApplyCode": MessageLookupByLibrary.simpleMessage( "Impossible d\'appliquer le code"), @@ -732,6 +740,7 @@ class MessageLookup extends MessageLookupByLibrary { MessageLookupByLibrary.simpleMessage("pour vos souvenirs"), "forgotPassword": MessageLookupByLibrary.simpleMessage("Mot de passe oublié"), + "foundFaces": MessageLookupByLibrary.simpleMessage("Found faces"), "freeStorageClaimed": MessageLookupByLibrary.simpleMessage("Stockage gratuit réclamé"), "freeStorageOnReferralSuccess": m24, @@ -1129,6 +1138,8 @@ class MessageLookup extends MessageLookupByLibrary { "removeParticipant": MessageLookupByLibrary.simpleMessage("Supprimer le participant"), "removeParticipantBody": m43, + "removePersonLabel": + MessageLookupByLibrary.simpleMessage("Remove person label"), "removePublicLink": MessageLookupByLibrary.simpleMessage("Supprimer le lien public"), "removeShareItemsWarning": MessageLookupByLibrary.simpleMessage( diff --git a/mobile/lib/generated/intl/messages_it.dart b/mobile/lib/generated/intl/messages_it.dart index e6db5b380..6dbae342c 100644 --- a/mobile/lib/generated/intl/messages_it.dart +++ b/mobile/lib/generated/intl/messages_it.dart @@ -411,6 +411,8 @@ class MessageLookup extends MessageLookupByLibrary { MessageLookupByLibrary.simpleMessage("Club per tempo di cattura"), "clubByFileName": MessageLookupByLibrary.simpleMessage("Unisci per nome file"), + "clusteringProgress": + MessageLookupByLibrary.simpleMessage("Clustering progress"), "codeAppliedPageTitle": MessageLookupByLibrary.simpleMessage("Codice applicato"), "codeCopiedToClipboard": MessageLookupByLibrary.simpleMessage( @@ -644,6 +646,8 @@ class MessageLookup extends MessageLookupByLibrary { MessageLookupByLibrary.simpleMessage("Inserisci password"), "enterPasswordToEncrypt": MessageLookupByLibrary.simpleMessage( "Inserisci una password per criptare i tuoi dati"), + "enterPersonName": + MessageLookupByLibrary.simpleMessage("Enter person name"), "enterReferralCode": MessageLookupByLibrary.simpleMessage( "Inserisci il codice di invito"), "enterThe6digitCodeFromnyourAuthenticatorApp": @@ -665,6 +669,10 @@ class MessageLookup extends MessageLookupByLibrary { "Questo link è scaduto. 
Si prega di selezionare un nuovo orario di scadenza o disabilitare la scadenza del link."), "exportLogs": MessageLookupByLibrary.simpleMessage("Esporta log"), "exportYourData": MessageLookupByLibrary.simpleMessage("Esporta dati"), + "faceRecognition": + MessageLookupByLibrary.simpleMessage("Face recognition"), + "faceRecognitionIndexingDescription": MessageLookupByLibrary.simpleMessage( + "Please note that this will result in a higher bandwidth and battery usage until all items are indexed."), "failedToApplyCode": MessageLookupByLibrary.simpleMessage( "Impossibile applicare il codice"), "failedToCancel": @@ -704,6 +712,7 @@ class MessageLookup extends MessageLookupByLibrary { MessageLookupByLibrary.simpleMessage("per i tuoi ricordi"), "forgotPassword": MessageLookupByLibrary.simpleMessage("Password dimenticata"), + "foundFaces": MessageLookupByLibrary.simpleMessage("Found faces"), "freeStorageClaimed": MessageLookupByLibrary.simpleMessage("Spazio gratuito richiesto"), "freeStorageOnReferralSuccess": m24, @@ -1090,6 +1099,8 @@ class MessageLookup extends MessageLookupByLibrary { "removeParticipant": MessageLookupByLibrary.simpleMessage("Rimuovi partecipante"), "removeParticipantBody": m43, + "removePersonLabel": + MessageLookupByLibrary.simpleMessage("Remove person label"), "removePublicLink": MessageLookupByLibrary.simpleMessage("Rimuovi link pubblico"), "removeShareItemsWarning": MessageLookupByLibrary.simpleMessage( diff --git a/mobile/lib/generated/intl/messages_ko.dart b/mobile/lib/generated/intl/messages_ko.dart index c91d849f6..65e26e631 100644 --- a/mobile/lib/generated/intl/messages_ko.dart +++ b/mobile/lib/generated/intl/messages_ko.dart @@ -34,6 +34,8 @@ class MessageLookup extends MessageLookupByLibrary { "addViewers": m1, "changeLocationOfSelectedItems": MessageLookupByLibrary.simpleMessage( "Change location of selected items?"), + "clusteringProgress": + MessageLookupByLibrary.simpleMessage("Clustering progress"), "contacts": MessageLookupByLibrary.simpleMessage("Contacts"), "createCollaborativeLink": MessageLookupByLibrary.simpleMessage("Create collaborative link"), @@ -44,7 +46,14 @@ class MessageLookup extends MessageLookupByLibrary { "editsToLocationWillOnlyBeSeenWithinEnte": MessageLookupByLibrary.simpleMessage( "Edits to location will only be seen within Ente"), + "enterPersonName": + MessageLookupByLibrary.simpleMessage("Enter person name"), + "faceRecognition": + MessageLookupByLibrary.simpleMessage("Face recognition"), + "faceRecognitionIndexingDescription": MessageLookupByLibrary.simpleMessage( + "Please note that this will result in a higher bandwidth and battery usage until all items are indexed."), "fileTypes": MessageLookupByLibrary.simpleMessage("File types"), + "foundFaces": MessageLookupByLibrary.simpleMessage("Found faces"), "joinDiscord": MessageLookupByLibrary.simpleMessage("Join Discord"), "locations": MessageLookupByLibrary.simpleMessage("Locations"), "longPressAnEmailToVerifyEndToEndEncryption": @@ -55,6 +64,8 @@ class MessageLookup extends MessageLookupByLibrary { "Modify your query, or try searching for"), "moveToHiddenAlbum": MessageLookupByLibrary.simpleMessage("Move to hidden album"), + "removePersonLabel": + MessageLookupByLibrary.simpleMessage("Remove person label"), "search": MessageLookupByLibrary.simpleMessage("Search"), "selectALocation": MessageLookupByLibrary.simpleMessage("Select a location"), diff --git a/mobile/lib/generated/intl/messages_nl.dart b/mobile/lib/generated/intl/messages_nl.dart index f6987973c..b0f7b601f 100644 --- 
a/mobile/lib/generated/intl/messages_nl.dart +++ b/mobile/lib/generated/intl/messages_nl.dart @@ -447,6 +447,8 @@ class MessageLookup extends MessageLookupByLibrary { MessageLookupByLibrary.simpleMessage("Samenvoegen op tijd"), "clubByFileName": MessageLookupByLibrary.simpleMessage("Samenvoegen op bestandsnaam"), + "clusteringProgress": + MessageLookupByLibrary.simpleMessage("Clustering progress"), "codeAppliedPageTitle": MessageLookupByLibrary.simpleMessage("Code toegepast"), "codeCopiedToClipboard": MessageLookupByLibrary.simpleMessage( @@ -723,6 +725,10 @@ class MessageLookup extends MessageLookupByLibrary { MessageLookupByLibrary.simpleMessage("Logboek exporteren"), "exportYourData": MessageLookupByLibrary.simpleMessage("Exporteer je gegevens"), + "faceRecognition": + MessageLookupByLibrary.simpleMessage("Face recognition"), + "faceRecognitionIndexingDescription": MessageLookupByLibrary.simpleMessage( + "Please note that this will result in a higher bandwidth and battery usage until all items are indexed."), "faces": MessageLookupByLibrary.simpleMessage("Gezichten"), "failedToApplyCode": MessageLookupByLibrary.simpleMessage("Code toepassen mislukt"), @@ -771,6 +777,7 @@ class MessageLookup extends MessageLookupByLibrary { MessageLookupByLibrary.simpleMessage("voor uw herinneringen"), "forgotPassword": MessageLookupByLibrary.simpleMessage("Wachtwoord vergeten"), + "foundFaces": MessageLookupByLibrary.simpleMessage("Found faces"), "freeStorageClaimed": MessageLookupByLibrary.simpleMessage("Gratis opslag geclaimd"), "freeStorageOnReferralSuccess": m24, diff --git a/mobile/lib/generated/intl/messages_no.dart b/mobile/lib/generated/intl/messages_no.dart index 0e5bd97b2..88d2b1632 100644 --- a/mobile/lib/generated/intl/messages_no.dart +++ b/mobile/lib/generated/intl/messages_no.dart @@ -39,6 +39,8 @@ class MessageLookup extends MessageLookupByLibrary { "cancel": MessageLookupByLibrary.simpleMessage("Avbryt"), "changeLocationOfSelectedItems": MessageLookupByLibrary.simpleMessage( "Change location of selected items?"), + "clusteringProgress": + MessageLookupByLibrary.simpleMessage("Clustering progress"), "confirmAccountDeletion": MessageLookupByLibrary.simpleMessage("Bekreft sletting av konto"), "confirmDeletePrompt": MessageLookupByLibrary.simpleMessage( @@ -57,12 +59,19 @@ class MessageLookup extends MessageLookupByLibrary { MessageLookupByLibrary.simpleMessage( "Edits to location will only be seen within Ente"), "email": MessageLookupByLibrary.simpleMessage("E-post"), + "enterPersonName": + MessageLookupByLibrary.simpleMessage("Enter person name"), "enterValidEmail": MessageLookupByLibrary.simpleMessage( "Vennligst skriv inn en gyldig e-postadresse."), "enterYourEmailAddress": MessageLookupByLibrary.simpleMessage( "Skriv inn e-postadressen din"), + "faceRecognition": + MessageLookupByLibrary.simpleMessage("Face recognition"), + "faceRecognitionIndexingDescription": MessageLookupByLibrary.simpleMessage( + "Please note that this will result in a higher bandwidth and battery usage until all items are indexed."), "feedback": MessageLookupByLibrary.simpleMessage("Tilbakemelding"), "fileTypes": MessageLookupByLibrary.simpleMessage("File types"), + "foundFaces": MessageLookupByLibrary.simpleMessage("Found faces"), "invalidEmailAddress": MessageLookupByLibrary.simpleMessage("Ugyldig e-postadresse"), "joinDiscord": MessageLookupByLibrary.simpleMessage("Join Discord"), @@ -77,6 +86,8 @@ class MessageLookup extends MessageLookupByLibrary { "Modify your query, or try searching for"), 
"moveToHiddenAlbum": MessageLookupByLibrary.simpleMessage("Move to hidden album"), + "removePersonLabel": + MessageLookupByLibrary.simpleMessage("Remove person label"), "search": MessageLookupByLibrary.simpleMessage("Search"), "selectALocation": MessageLookupByLibrary.simpleMessage("Select a location"), diff --git a/mobile/lib/generated/intl/messages_pl.dart b/mobile/lib/generated/intl/messages_pl.dart index b3a922b0a..096a0eb65 100644 --- a/mobile/lib/generated/intl/messages_pl.dart +++ b/mobile/lib/generated/intl/messages_pl.dart @@ -49,6 +49,8 @@ class MessageLookup extends MessageLookupByLibrary { MessageLookupByLibrary.simpleMessage("Zmień hasło"), "checkInboxAndSpamFolder": MessageLookupByLibrary.simpleMessage( "Sprawdź swoją skrzynkę odbiorczą (i spam), aby zakończyć weryfikację"), + "clusteringProgress": + MessageLookupByLibrary.simpleMessage("Clustering progress"), "codeCopiedToClipboard": MessageLookupByLibrary.simpleMessage( "Kod został skopiowany do schowka"), "confirm": MessageLookupByLibrary.simpleMessage("Potwierdź"), @@ -101,6 +103,8 @@ class MessageLookup extends MessageLookupByLibrary { "Wprowadź nowe hasło, którego możemy użyć do zaszyfrowania Twoich danych"), "enterPasswordToEncrypt": MessageLookupByLibrary.simpleMessage( "Wprowadź hasło, którego możemy użyć do zaszyfrowania Twoich danych"), + "enterPersonName": + MessageLookupByLibrary.simpleMessage("Enter person name"), "enterValidEmail": MessageLookupByLibrary.simpleMessage( "Podaj poprawny adres e-mail."), "enterYourEmailAddress": @@ -109,10 +113,15 @@ class MessageLookup extends MessageLookupByLibrary { MessageLookupByLibrary.simpleMessage("Wprowadź hasło"), "enterYourRecoveryKey": MessageLookupByLibrary.simpleMessage( "Wprowadź swój klucz odzyskiwania"), + "faceRecognition": + MessageLookupByLibrary.simpleMessage("Face recognition"), + "faceRecognitionIndexingDescription": MessageLookupByLibrary.simpleMessage( + "Please note that this will result in a higher bandwidth and battery usage until all items are indexed."), "feedback": MessageLookupByLibrary.simpleMessage("Informacja zwrotna"), "fileTypes": MessageLookupByLibrary.simpleMessage("File types"), "forgotPassword": MessageLookupByLibrary.simpleMessage("Nie pamiętam hasła"), + "foundFaces": MessageLookupByLibrary.simpleMessage("Found faces"), "generatingEncryptionKeys": MessageLookupByLibrary.simpleMessage( "Generowanie kluczy szyfrujących..."), "howItWorks": MessageLookupByLibrary.simpleMessage("Jak to działa"), @@ -166,6 +175,8 @@ class MessageLookup extends MessageLookupByLibrary { "Jeśli zapomnisz hasła, jedynym sposobem odzyskania danych jest ten klucz."), "recoverySuccessful": MessageLookupByLibrary.simpleMessage("Odzyskano pomyślnie!"), + "removePersonLabel": + MessageLookupByLibrary.simpleMessage("Remove person label"), "resendEmail": MessageLookupByLibrary.simpleMessage("Wyślij e-mail ponownie"), "resetPasswordTitle": diff --git a/mobile/lib/generated/intl/messages_pt.dart b/mobile/lib/generated/intl/messages_pt.dart index e17cb674e..aa17fc422 100644 --- a/mobile/lib/generated/intl/messages_pt.dart +++ b/mobile/lib/generated/intl/messages_pt.dart @@ -445,6 +445,8 @@ class MessageLookup extends MessageLookupByLibrary { "Agrupar por tempo de captura"), "clubByFileName": MessageLookupByLibrary.simpleMessage( "Agrupar pelo nome de arquivo"), + "clusteringProgress": + MessageLookupByLibrary.simpleMessage("Clustering progress"), "codeAppliedPageTitle": MessageLookupByLibrary.simpleMessage("Código aplicado"), "codeCopiedToClipboard": 
MessageLookupByLibrary.simpleMessage( @@ -587,7 +589,7 @@ class MessageLookup extends MessageLookupByLibrary { "descriptions": MessageLookupByLibrary.simpleMessage("Descrições"), "deselectAll": MessageLookupByLibrary.simpleMessage("Desmarcar todos"), "designedToOutlive": - MessageLookupByLibrary.simpleMessage("Feito para ter logenvidade"), + MessageLookupByLibrary.simpleMessage("Feito para ter longevidade"), "details": MessageLookupByLibrary.simpleMessage("Detalhes"), "devAccountChanged": MessageLookupByLibrary.simpleMessage( "A conta de desenvolvedor que usamos para publicar o Ente na App Store foi alterada. Por esse motivo, você precisará fazer entrar novamente.\n\nPedimos desculpas pelo inconveniente, mas isso era inevitável."), @@ -714,6 +716,10 @@ class MessageLookup extends MessageLookupByLibrary { "exportLogs": MessageLookupByLibrary.simpleMessage("Exportar logs"), "exportYourData": MessageLookupByLibrary.simpleMessage("Exportar seus dados"), + "faceRecognition": + MessageLookupByLibrary.simpleMessage("Face recognition"), + "faceRecognitionIndexingDescription": MessageLookupByLibrary.simpleMessage( + "Please note that this will result in a higher bandwidth and battery usage until all items are indexed."), "faces": MessageLookupByLibrary.simpleMessage("Rostos"), "failedToApplyCode": MessageLookupByLibrary.simpleMessage("Falha ao aplicar o código"), @@ -760,6 +766,7 @@ class MessageLookup extends MessageLookupByLibrary { MessageLookupByLibrary.simpleMessage("para suas memórias"), "forgotPassword": MessageLookupByLibrary.simpleMessage("Esqueceu sua senha"), + "foundFaces": MessageLookupByLibrary.simpleMessage("Found faces"), "freeStorageClaimed": MessageLookupByLibrary.simpleMessage( "Armazenamento gratuito reivindicado"), "freeStorageOnReferralSuccess": m24, diff --git a/mobile/lib/generated/intl/messages_zh.dart b/mobile/lib/generated/intl/messages_zh.dart index db60c5e0b..63b8668b5 100644 --- a/mobile/lib/generated/intl/messages_zh.dart +++ b/mobile/lib/generated/intl/messages_zh.dart @@ -382,6 +382,8 @@ class MessageLookup extends MessageLookupByLibrary { "close": MessageLookupByLibrary.simpleMessage("关闭"), "clubByCaptureTime": MessageLookupByLibrary.simpleMessage("按拍摄时间分组"), "clubByFileName": MessageLookupByLibrary.simpleMessage("按文件名排序"), + "clusteringProgress": + MessageLookupByLibrary.simpleMessage("Clustering progress"), "codeAppliedPageTitle": MessageLookupByLibrary.simpleMessage("代码已应用"), "codeCopiedToClipboard": MessageLookupByLibrary.simpleMessage("代码已复制到剪贴板"), @@ -543,7 +545,7 @@ class MessageLookup extends MessageLookupByLibrary { "emailVerificationToggle": MessageLookupByLibrary.simpleMessage("电子邮件验证"), "emailYourLogs": MessageLookupByLibrary.simpleMessage("通过电子邮件发送您的日志"), - "empty": MessageLookupByLibrary.simpleMessage("空的"), + "empty": MessageLookupByLibrary.simpleMessage("清空"), "emptyTrash": MessageLookupByLibrary.simpleMessage("要清空回收站吗?"), "enableMaps": MessageLookupByLibrary.simpleMessage("启用地图"), "enableMapsDesc": MessageLookupByLibrary.simpleMessage( @@ -592,6 +594,10 @@ class MessageLookup extends MessageLookupByLibrary { MessageLookupByLibrary.simpleMessage("此链接已过期。请选择新的过期时间或禁用链接有效期。"), "exportLogs": MessageLookupByLibrary.simpleMessage("导出日志"), "exportYourData": MessageLookupByLibrary.simpleMessage("导出您的数据"), + "faceRecognition": + MessageLookupByLibrary.simpleMessage("Face recognition"), + "faceRecognitionIndexingDescription": MessageLookupByLibrary.simpleMessage( + "Please note that this will result in a higher bandwidth and battery usage until all items are 
indexed."), "faces": MessageLookupByLibrary.simpleMessage("人脸"), "failedToApplyCode": MessageLookupByLibrary.simpleMessage("无法使用此代码"), "failedToCancel": MessageLookupByLibrary.simpleMessage("取消失败"), @@ -626,6 +632,7 @@ class MessageLookup extends MessageLookupByLibrary { "flip": MessageLookupByLibrary.simpleMessage("上下翻转"), "forYourMemories": MessageLookupByLibrary.simpleMessage("为您的回忆"), "forgotPassword": MessageLookupByLibrary.simpleMessage("忘记密码"), + "foundFaces": MessageLookupByLibrary.simpleMessage("Found faces"), "freeStorageClaimed": MessageLookupByLibrary.simpleMessage("已领取的免费存储"), "freeStorageOnReferralSuccess": m24, "freeStorageSpace": m25, diff --git a/mobile/lib/generated/l10n.dart b/mobile/lib/generated/l10n.dart index 4c7679154..4e2c53e29 100644 --- a/mobile/lib/generated/l10n.dart +++ b/mobile/lib/generated/l10n.dart @@ -4034,10 +4034,10 @@ class S { ); } - /// `Free trial valid till {endDate}.\nYou can purchase a paid plan afterwards.` + /// `Free trial valid till {endDate}.\nYou can choose a paid plan afterwards.` String playStoreFreeTrialValidTill(Object endDate) { return Intl.message( - 'Free trial valid till $endDate.\nYou can purchase a paid plan afterwards.', + 'Free trial valid till $endDate.\nYou can choose a paid plan afterwards.', name: 'playStoreFreeTrialValidTill', desc: '', args: [endDate], @@ -6969,10 +6969,10 @@ class S { ); } - /// `Find all photos of a person` + /// `Persons will be shown here once indexing is done` String get searchFaceEmptySection { return Intl.message( - 'Find all photos of a person', + 'Persons will be shown here once indexing is done', name: 'searchFaceEmptySection', desc: '', args: [], @@ -8168,6 +8168,16 @@ class S { ); } + /// `People` + String get people { + return Intl.message( + 'People', + name: 'people', + desc: '', + args: [], + ); + } + /// `Contents` String get contents { return Intl.message( @@ -8388,26 +8398,6 @@ class S { ); } - /// `Auto pair` - String get autoPair { - return Intl.message( - 'Auto pair', - name: 'autoPair', - desc: '', - args: [], - ); - } - - /// `Pair with PIN` - String get pairWithPin { - return Intl.message( - 'Pair with PIN', - name: 'pairWithPin', - desc: '', - args: [], - ); - } - /// `Device not found` String get deviceNotFound { return Intl.message( @@ -8468,6 +8458,26 @@ class S { ); } + /// `Add a name` + String get addAName { + return Intl.message( + 'Add a name', + name: 'addAName', + desc: '', + args: [], + ); + } + + /// `Find people quickly by name` + String get findPeopleByName { + return Intl.message( + 'Find people quickly by name', + name: 'findPeopleByName', + desc: '', + args: [], + ); + } + /// `{count, plural, zero {Add viewer} one {Add viewer} other {Add viewers}}` String addViewers(num count) { return Intl.plural( @@ -8594,6 +8604,26 @@ class S { ); } + /// `Enter person name` + String get enterPersonName { + return Intl.message( + 'Enter person name', + name: 'enterPersonName', + desc: '', + args: [], + ); + } + + /// `Remove person label` + String get removePersonLabel { + return Intl.message( + 'Remove person label', + name: 'removePersonLabel', + desc: '', + args: [], + ); + } + /// `Auto pair works only with devices that support Chromecast.` String get autoPairDesc { return Intl.message( @@ -8703,6 +8733,66 @@ class S { args: [], ); } + + /// `Auto pair` + String get autoPair { + return Intl.message( + 'Auto pair', + name: 'autoPair', + desc: '', + args: [], + ); + } + + /// `Pair with PIN` + String get pairWithPin { + return Intl.message( + 'Pair with PIN', + name: 
'pairWithPin', + desc: '', + args: [], + ); + } + + /// `Face recognition` + String get faceRecognition { + return Intl.message( + 'Face recognition', + name: 'faceRecognition', + desc: '', + args: [], + ); + } + + /// `Please note that this will result in a higher bandwidth and battery usage until all items are indexed.` + String get faceRecognitionIndexingDescription { + return Intl.message( + 'Please note that this will result in a higher bandwidth and battery usage until all items are indexed.', + name: 'faceRecognitionIndexingDescription', + desc: '', + args: [], + ); + } + + /// `Found faces` + String get foundFaces { + return Intl.message( + 'Found faces', + name: 'foundFaces', + desc: '', + args: [], + ); + } + + /// `Clustering progress` + String get clusteringProgress { + return Intl.message( + 'Clustering progress', + name: 'clusteringProgress', + desc: '', + args: [], + ); + } } class AppLocalizationDelegate extends LocalizationsDelegate<S> { diff --git a/mobile/lib/generated/protos/ente/common/box.pb.dart b/mobile/lib/generated/protos/ente/common/box.pb.dart new file mode 100644 index 000000000..41518e9ae --- /dev/null +++ b/mobile/lib/generated/protos/ente/common/box.pb.dart @@ -0,0 +1,111 @@ +// +// Generated code. Do not modify. +// source: ente/common/box.proto +// +// @dart = 2.12 + +// ignore_for_file: annotate_overrides, camel_case_types, comment_references +// ignore_for_file: constant_identifier_names, library_prefixes +// ignore_for_file: non_constant_identifier_names, prefer_final_fields +// ignore_for_file: unnecessary_import, unnecessary_this, unused_import + +import 'dart:core' as $core; + +import 'package:protobuf/protobuf.dart' as $pb; + +/// CenterBox is a box where x,y is the center of the box +class CenterBox extends $pb.GeneratedMessage { + factory CenterBox({ + $core.double? x, + $core.double? y, + $core.double? height, + $core.double? width, + }) { + final $result = create(); + if (x != null) { + $result.x = x; + } + if (y != null) { + $result.y = y; + } + if (height != null) { + $result.height = height; + } + if (width != null) { + $result.width = width; + } + return $result; + } + CenterBox._() : super(); + factory CenterBox.fromBuffer($core.List<$core.int> i, [$pb.ExtensionRegistry r = $pb.ExtensionRegistry.EMPTY]) => create()..mergeFromBuffer(i, r); + factory CenterBox.fromJson($core.String i, [$pb.ExtensionRegistry r = $pb.ExtensionRegistry.EMPTY]) => create()..mergeFromJson(i, r); + + static final $pb.BuilderInfo _i = $pb.BuilderInfo(_omitMessageNames ? '' : 'CenterBox', package: const $pb.PackageName(_omitMessageNames ? '' : 'ente.common'), createEmptyInstance: create) + ..a<$core.double>(1, _omitFieldNames ? '' : 'x', $pb.PbFieldType.OF) + ..a<$core.double>(2, _omitFieldNames ? '' : 'y', $pb.PbFieldType.OF) + ..a<$core.double>(3, _omitFieldNames ? '' : 'height', $pb.PbFieldType.OF) + ..a<$core.double>(4, _omitFieldNames ? '' : 'width', $pb.PbFieldType.OF) + ..hasRequiredFields = false + ; + + @$core.Deprecated( + 'Using this can add significant overhead to your binary. ' + 'Use [GeneratedMessageGenericExtensions.deepCopy] instead. ' + 'Will be removed in next major version') + CenterBox clone() => CenterBox()..mergeFromMessage(this); + @$core.Deprecated( + 'Using this can add significant overhead to your binary. ' + 'Use [GeneratedMessageGenericExtensions.rebuild] instead.
' + 'Will be removed in next major version') + CenterBox copyWith(void Function(CenterBox) updates) => super.copyWith((message) => updates(message as CenterBox)) as CenterBox; + + $pb.BuilderInfo get info_ => _i; + + @$core.pragma('dart2js:noInline') + static CenterBox create() => CenterBox._(); + CenterBox createEmptyInstance() => create(); + static $pb.PbList<CenterBox> createRepeated() => $pb.PbList<CenterBox>(); + @$core.pragma('dart2js:noInline') + static CenterBox getDefault() => _defaultInstance ??= $pb.GeneratedMessage.$_defaultFor(create); + static CenterBox? _defaultInstance; + + @$pb.TagNumber(1) + $core.double get x => $_getN(0); + @$pb.TagNumber(1) + set x($core.double v) { $_setFloat(0, v); } + @$pb.TagNumber(1) + $core.bool hasX() => $_has(0); + @$pb.TagNumber(1) + void clearX() => clearField(1); + + @$pb.TagNumber(2) + $core.double get y => $_getN(1); + @$pb.TagNumber(2) + set y($core.double v) { $_setFloat(1, v); } + @$pb.TagNumber(2) + $core.bool hasY() => $_has(1); + @$pb.TagNumber(2) + void clearY() => clearField(2); + + @$pb.TagNumber(3) + $core.double get height => $_getN(2); + @$pb.TagNumber(3) + set height($core.double v) { $_setFloat(2, v); } + @$pb.TagNumber(3) + $core.bool hasHeight() => $_has(2); + @$pb.TagNumber(3) + void clearHeight() => clearField(3); + + @$pb.TagNumber(4) + $core.double get width => $_getN(3); + @$pb.TagNumber(4) + set width($core.double v) { $_setFloat(3, v); } + @$pb.TagNumber(4) + $core.bool hasWidth() => $_has(3); + @$pb.TagNumber(4) + void clearWidth() => clearField(4); +} + + +const _omitFieldNames = $core.bool.fromEnvironment('protobuf.omit_field_names'); +const _omitMessageNames = $core.bool.fromEnvironment('protobuf.omit_message_names'); diff --git a/mobile/lib/generated/protos/ente/common/box.pbenum.dart b/mobile/lib/generated/protos/ente/common/box.pbenum.dart new file mode 100644 index 000000000..7310e57a0 --- /dev/null +++ b/mobile/lib/generated/protos/ente/common/box.pbenum.dart @@ -0,0 +1,11 @@ +// +// Generated code. Do not modify. +// source: ente/common/box.proto +// +// @dart = 2.12 + +// ignore_for_file: annotate_overrides, camel_case_types, comment_references +// ignore_for_file: constant_identifier_names, library_prefixes +// ignore_for_file: non_constant_identifier_names, prefer_final_fields +// ignore_for_file: unnecessary_import, unnecessary_this, unused_import + diff --git a/mobile/lib/generated/protos/ente/common/box.pbjson.dart b/mobile/lib/generated/protos/ente/common/box.pbjson.dart new file mode 100644 index 000000000..6c9ab3cb2 --- /dev/null +++ b/mobile/lib/generated/protos/ente/common/box.pbjson.dart @@ -0,0 +1,38 @@ +// +// Generated code. Do not modify.
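The generated CenterBox message above exposes the standard protobuf-Dart surface: a named-parameter factory, fromBuffer/fromJson constructors, and typed accessors with has/clear pairs for each optional field. A minimal round-trip sketch (the package:mobile import path is an assumption; adjust it to the app's actual pubspec name):

import 'package:mobile/generated/protos/ente/common/box.pb.dart';

void main() {
  // A normalized face box centred at (0.5, 0.5); all four fields are
  // optional floats (tags 1-4 in the descriptor below).
  final box = CenterBox(x: 0.5, y: 0.5, width: 0.25, height: 0.5);

  // writeToBuffer() emits the compact binary encoding; the generated
  // fromBuffer() factory parses it back into a new message.
  final decoded = CenterBox.fromBuffer(box.writeToBuffer());
  assert(decoded.hasWidth() && decoded.width == 0.25);
}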
+// source: ente/common/box.proto +// +// @dart = 2.12 + +// ignore_for_file: annotate_overrides, camel_case_types, comment_references +// ignore_for_file: constant_identifier_names, library_prefixes +// ignore_for_file: non_constant_identifier_names, prefer_final_fields +// ignore_for_file: unnecessary_import, unnecessary_this, unused_import + +import 'dart:convert' as $convert; +import 'dart:core' as $core; +import 'dart:typed_data' as $typed_data; + +@$core.Deprecated('Use centerBoxDescriptor instead') +const CenterBox$json = { + '1': 'CenterBox', + '2': [ + {'1': 'x', '3': 1, '4': 1, '5': 2, '9': 0, '10': 'x', '17': true}, + {'1': 'y', '3': 2, '4': 1, '5': 2, '9': 1, '10': 'y', '17': true}, + {'1': 'height', '3': 3, '4': 1, '5': 2, '9': 2, '10': 'height', '17': true}, + {'1': 'width', '3': 4, '4': 1, '5': 2, '9': 3, '10': 'width', '17': true}, + ], + '8': [ + {'1': '_x'}, + {'1': '_y'}, + {'1': '_height'}, + {'1': '_width'}, + ], +}; + +/// Descriptor for `CenterBox`. Decode as a `google.protobuf.DescriptorProto`. +final $typed_data.Uint8List centerBoxDescriptor = $convert.base64Decode( + 'CglDZW50ZXJCb3gSEQoBeBgBIAEoAkgAUgF4iAEBEhEKAXkYAiABKAJIAVIBeYgBARIbCgZoZW' + 'lnaHQYAyABKAJIAlIGaGVpZ2h0iAEBEhkKBXdpZHRoGAQgASgCSANSBXdpZHRoiAEBQgQKAl94' + 'QgQKAl95QgkKB19oZWlnaHRCCAoGX3dpZHRo'); + diff --git a/mobile/lib/generated/protos/ente/common/box.pbserver.dart b/mobile/lib/generated/protos/ente/common/box.pbserver.dart new file mode 100644 index 000000000..1e8625388 --- /dev/null +++ b/mobile/lib/generated/protos/ente/common/box.pbserver.dart @@ -0,0 +1,14 @@ +// +// Generated code. Do not modify. +// source: ente/common/box.proto +// +// @dart = 2.12 + +// ignore_for_file: annotate_overrides, camel_case_types, comment_references +// ignore_for_file: constant_identifier_names +// ignore_for_file: deprecated_member_use_from_same_package, library_prefixes +// ignore_for_file: non_constant_identifier_names, prefer_final_fields +// ignore_for_file: unnecessary_import, unnecessary_this, unused_import + +export 'box.pb.dart'; + diff --git a/mobile/lib/generated/protos/ente/common/point.pb.dart b/mobile/lib/generated/protos/ente/common/point.pb.dart new file mode 100644 index 000000000..47f9b87ce --- /dev/null +++ b/mobile/lib/generated/protos/ente/common/point.pb.dart @@ -0,0 +1,83 @@ +// +// Generated code. Do not modify. +// source: ente/common/point.proto +// +// @dart = 2.12 + +// ignore_for_file: annotate_overrides, camel_case_types, comment_references +// ignore_for_file: constant_identifier_names, library_prefixes +// ignore_for_file: non_constant_identifier_names, prefer_final_fields +// ignore_for_file: unnecessary_import, unnecessary_this, unused_import + +import 'dart:core' as $core; + +import 'package:protobuf/protobuf.dart' as $pb; + +/// EPoint is a point in 2D space +class EPoint extends $pb.GeneratedMessage { + factory EPoint({ + $core.double? x, + $core.double? y, + }) { + final $result = create(); + if (x != null) { + $result.x = x; + } + if (y != null) { + $result.y = y; + } + return $result; + } + EPoint._() : super(); + factory EPoint.fromBuffer($core.List<$core.int> i, [$pb.ExtensionRegistry r = $pb.ExtensionRegistry.EMPTY]) => create()..mergeFromBuffer(i, r); + factory EPoint.fromJson($core.String i, [$pb.ExtensionRegistry r = $pb.ExtensionRegistry.EMPTY]) => create()..mergeFromJson(i, r); + + static final $pb.BuilderInfo _i = $pb.BuilderInfo(_omitMessageNames ? '' : 'EPoint', package: const $pb.PackageName(_omitMessageNames ? 
'' : 'ente.common'), createEmptyInstance: create) + ..a<$core.double>(1, _omitFieldNames ? '' : 'x', $pb.PbFieldType.OF) + ..a<$core.double>(2, _omitFieldNames ? '' : 'y', $pb.PbFieldType.OF) + ..hasRequiredFields = false + ; + + @$core.Deprecated( + 'Using this can add significant overhead to your binary. ' + 'Use [GeneratedMessageGenericExtensions.deepCopy] instead. ' + 'Will be removed in next major version') + EPoint clone() => EPoint()..mergeFromMessage(this); + @$core.Deprecated( + 'Using this can add significant overhead to your binary. ' + 'Use [GeneratedMessageGenericExtensions.rebuild] instead. ' + 'Will be removed in next major version') + EPoint copyWith(void Function(EPoint) updates) => super.copyWith((message) => updates(message as EPoint)) as EPoint; + + $pb.BuilderInfo get info_ => _i; + + @$core.pragma('dart2js:noInline') + static EPoint create() => EPoint._(); + EPoint createEmptyInstance() => create(); + static $pb.PbList<EPoint> createRepeated() => $pb.PbList<EPoint>(); + @$core.pragma('dart2js:noInline') + static EPoint getDefault() => _defaultInstance ??= $pb.GeneratedMessage.$_defaultFor(create); + static EPoint? _defaultInstance; + + @$pb.TagNumber(1) + $core.double get x => $_getN(0); + @$pb.TagNumber(1) + set x($core.double v) { $_setFloat(0, v); } + @$pb.TagNumber(1) + $core.bool hasX() => $_has(0); + @$pb.TagNumber(1) + void clearX() => clearField(1); + + @$pb.TagNumber(2) + $core.double get y => $_getN(1); + @$pb.TagNumber(2) + set y($core.double v) { $_setFloat(1, v); } + @$pb.TagNumber(2) + $core.bool hasY() => $_has(1); + @$pb.TagNumber(2) + void clearY() => clearField(2); +} + + +const _omitFieldNames = $core.bool.fromEnvironment('protobuf.omit_field_names'); +const _omitMessageNames = $core.bool.fromEnvironment('protobuf.omit_message_names'); diff --git a/mobile/lib/generated/protos/ente/common/point.pbenum.dart b/mobile/lib/generated/protos/ente/common/point.pbenum.dart new file mode 100644 index 000000000..3c242a2fc --- /dev/null +++ b/mobile/lib/generated/protos/ente/common/point.pbenum.dart @@ -0,0 +1,11 @@ +// +// Generated code. Do not modify. +// source: ente/common/point.proto +// +// @dart = 2.12 + +// ignore_for_file: annotate_overrides, camel_case_types, comment_references +// ignore_for_file: constant_identifier_names, library_prefixes +// ignore_for_file: non_constant_identifier_names, prefer_final_fields +// ignore_for_file: unnecessary_import, unnecessary_this, unused_import + diff --git a/mobile/lib/generated/protos/ente/common/point.pbjson.dart b/mobile/lib/generated/protos/ente/common/point.pbjson.dart new file mode 100644 index 000000000..44d2d0712 --- /dev/null +++ b/mobile/lib/generated/protos/ente/common/point.pbjson.dart @@ -0,0 +1,33 @@ +// +// Generated code. Do not modify. +// source: ente/common/point.proto +// +// @dart = 2.12 + +// ignore_for_file: annotate_overrides, camel_case_types, comment_references +// ignore_for_file: constant_identifier_names, library_prefixes +// ignore_for_file: non_constant_identifier_names, prefer_final_fields +// ignore_for_file: unnecessary_import, unnecessary_this, unused_import + +import 'dart:convert' as $convert; +import 'dart:core' as $core; +import 'dart:typed_data' as $typed_data; + +@$core.Deprecated('Use ePointDescriptor instead') +const EPoint$json = { + '1': 'EPoint', + '2': [ + {'1': 'x', '3': 1, '4': 1, '5': 2, '9': 0, '10': 'x', '17': true}, + {'1': 'y', '3': 2, '4': 1, '5': 2, '9': 1, '10': 'y', '17': true}, + ], + '8': [ + {'1': '_x'}, + {'1': '_y'}, + ], +}; + +/// Descriptor for `EPoint`.
Decode as a `google.protobuf.DescriptorProto`. +final $typed_data.Uint8List ePointDescriptor = $convert.base64Decode( + 'CgZFUG9pbnQSEQoBeBgBIAEoAkgAUgF4iAEBEhEKAXkYAiABKAJIAVIBeYgBAUIECgJfeEIECg' + 'JfeQ=='); + diff --git a/mobile/lib/generated/protos/ente/common/point.pbserver.dart b/mobile/lib/generated/protos/ente/common/point.pbserver.dart new file mode 100644 index 000000000..66728e123 --- /dev/null +++ b/mobile/lib/generated/protos/ente/common/point.pbserver.dart @@ -0,0 +1,14 @@ +// +// Generated code. Do not modify. +// source: ente/common/point.proto +// +// @dart = 2.12 + +// ignore_for_file: annotate_overrides, camel_case_types, comment_references +// ignore_for_file: constant_identifier_names +// ignore_for_file: deprecated_member_use_from_same_package, library_prefixes +// ignore_for_file: non_constant_identifier_names, prefer_final_fields +// ignore_for_file: unnecessary_import, unnecessary_this, unused_import + +export 'point.pb.dart'; + diff --git a/mobile/lib/generated/protos/ente/common/vector.pb.dart b/mobile/lib/generated/protos/ente/common/vector.pb.dart new file mode 100644 index 000000000..44aa7d748 --- /dev/null +++ b/mobile/lib/generated/protos/ente/common/vector.pb.dart @@ -0,0 +1,64 @@ +// +// Generated code. Do not modify. +// source: ente/common/vector.proto +// +// @dart = 2.12 + +// ignore_for_file: annotate_overrides, camel_case_types, comment_references +// ignore_for_file: constant_identifier_names, library_prefixes +// ignore_for_file: non_constant_identifier_names, prefer_final_fields +// ignore_for_file: unnecessary_import, unnecessary_this, unused_import + +import 'dart:core' as $core; + +import 'package:protobuf/protobuf.dart' as $pb; + +/// Vector is generic message for dealing with lists of doubles +/// It should ideally be used independently and not as a submessage +class EVector extends $pb.GeneratedMessage { + factory EVector({ + $core.Iterable<$core.double>? values, + }) { + final $result = create(); + if (values != null) { + $result.values.addAll(values); + } + return $result; + } + EVector._() : super(); + factory EVector.fromBuffer($core.List<$core.int> i, [$pb.ExtensionRegistry r = $pb.ExtensionRegistry.EMPTY]) => create()..mergeFromBuffer(i, r); + factory EVector.fromJson($core.String i, [$pb.ExtensionRegistry r = $pb.ExtensionRegistry.EMPTY]) => create()..mergeFromJson(i, r); + + static final $pb.BuilderInfo _i = $pb.BuilderInfo(_omitMessageNames ? '' : 'EVector', package: const $pb.PackageName(_omitMessageNames ? '' : 'ente.common'), createEmptyInstance: create) + ..p<$core.double>(1, _omitFieldNames ? '' : 'values', $pb.PbFieldType.KD) + ..hasRequiredFields = false + ; + + @$core.Deprecated( + 'Using this can add significant overhead to your binary. ' + 'Use [GeneratedMessageGenericExtensions.deepCopy] instead. ' + 'Will be removed in next major version') + EVector clone() => EVector()..mergeFromMessage(this); + @$core.Deprecated( + 'Using this can add significant overhead to your binary. ' + 'Use [GeneratedMessageGenericExtensions.rebuild] instead. 
' + 'Will be removed in next major version') + EVector copyWith(void Function(EVector) updates) => super.copyWith((message) => updates(message as EVector)) as EVector; + + $pb.BuilderInfo get info_ => _i; + + @$core.pragma('dart2js:noInline') + static EVector create() => EVector._(); + EVector createEmptyInstance() => create(); + static $pb.PbList<EVector> createRepeated() => $pb.PbList<EVector>(); + @$core.pragma('dart2js:noInline') + static EVector getDefault() => _defaultInstance ??= $pb.GeneratedMessage.$_defaultFor(create); + static EVector? _defaultInstance; + + @$pb.TagNumber(1) + $core.List<$core.double> get values => $_getList(0); +} + + +const _omitFieldNames = $core.bool.fromEnvironment('protobuf.omit_field_names'); +const _omitMessageNames = $core.bool.fromEnvironment('protobuf.omit_message_names'); diff --git a/mobile/lib/generated/protos/ente/common/vector.pbenum.dart b/mobile/lib/generated/protos/ente/common/vector.pbenum.dart new file mode 100644 index 000000000..c88d2648a --- /dev/null +++ b/mobile/lib/generated/protos/ente/common/vector.pbenum.dart @@ -0,0 +1,11 @@ +// +// Generated code. Do not modify. +// source: ente/common/vector.proto +// +// @dart = 2.12 + +// ignore_for_file: annotate_overrides, camel_case_types, comment_references +// ignore_for_file: constant_identifier_names, library_prefixes +// ignore_for_file: non_constant_identifier_names, prefer_final_fields +// ignore_for_file: unnecessary_import, unnecessary_this, unused_import + diff --git a/mobile/lib/generated/protos/ente/common/vector.pbjson.dart b/mobile/lib/generated/protos/ente/common/vector.pbjson.dart new file mode 100644 index 000000000..1aff5cb29 --- /dev/null +++ b/mobile/lib/generated/protos/ente/common/vector.pbjson.dart @@ -0,0 +1,27 @@ +// +// Generated code. Do not modify. +// source: ente/common/vector.proto +// +// @dart = 2.12 + +// ignore_for_file: annotate_overrides, camel_case_types, comment_references +// ignore_for_file: constant_identifier_names, library_prefixes +// ignore_for_file: non_constant_identifier_names, prefer_final_fields +// ignore_for_file: unnecessary_import, unnecessary_this, unused_import + +import 'dart:convert' as $convert; +import 'dart:core' as $core; +import 'dart:typed_data' as $typed_data; + +@$core.Deprecated('Use eVectorDescriptor instead') +const EVector$json = { + '1': 'EVector', + '2': [ + {'1': 'values', '3': 1, '4': 3, '5': 1, '10': 'values'}, + ], +}; + +/// Descriptor for `EVector`. Decode as a `google.protobuf.DescriptorProto`. +final $typed_data.Uint8List eVectorDescriptor = $convert.base64Decode( + 'CgdFVmVjdG9yEhYKBnZhbHVlcxgBIAMoAVIGdmFsdWVz'); + diff --git a/mobile/lib/generated/protos/ente/common/vector.pbserver.dart b/mobile/lib/generated/protos/ente/common/vector.pbserver.dart new file mode 100644 index 000000000..dbf5ac36f --- /dev/null +++ b/mobile/lib/generated/protos/ente/common/vector.pbserver.dart @@ -0,0 +1,14 @@ +// +// Generated code. Do not modify.
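Per its doc comment, EVector above is the schema's generic carrier for lists of doubles, meant to be used standalone rather than embedded in other messages. Since field 1 is a packed repeated double (type 1 in the EVector$json descriptor), values round-trip exactly; a short sketch with illustrative values, under the same package-name assumption as the earlier example:

import 'package:mobile/generated/protos/ente/common/vector.pb.dart';

void main() {
  // The factory copies the Iterable into the message's repeated field.
  final embedding = EVector(values: [0.125, -0.5, 0.75]);

  // `repeated double` is 64-bit on the wire, so these values survive
  // serialization without loss.
  final restored = EVector.fromBuffer(embedding.writeToBuffer());
  assert(restored.values.length == 3 && restored.values[1] == -0.5);
}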
+// source: ente/common/vector.proto +// +// @dart = 2.12 + +// ignore_for_file: annotate_overrides, camel_case_types, comment_references +// ignore_for_file: constant_identifier_names +// ignore_for_file: deprecated_member_use_from_same_package, library_prefixes +// ignore_for_file: non_constant_identifier_names, prefer_final_fields +// ignore_for_file: unnecessary_import, unnecessary_this, unused_import + +export 'vector.pb.dart'; + diff --git a/mobile/lib/generated/protos/ente/ml/face.pb.dart b/mobile/lib/generated/protos/ente/ml/face.pb.dart new file mode 100644 index 000000000..55d512b66 --- /dev/null +++ b/mobile/lib/generated/protos/ente/ml/face.pb.dart @@ -0,0 +1,169 @@ +// +// Generated code. Do not modify. +// source: ente/ml/face.proto +// +// @dart = 2.12 + +// ignore_for_file: annotate_overrides, camel_case_types, comment_references +// ignore_for_file: constant_identifier_names, library_prefixes +// ignore_for_file: non_constant_identifier_names, prefer_final_fields +// ignore_for_file: unnecessary_import, unnecessary_this, unused_import + +import 'dart:core' as $core; + +import 'package:protobuf/protobuf.dart' as $pb; + +import '../common/box.pb.dart' as $0; +import '../common/point.pb.dart' as $1; + +class Detection extends $pb.GeneratedMessage { + factory Detection({ + $0.CenterBox? box, + $1.EPoint? landmarks, + }) { + final $result = create(); + if (box != null) { + $result.box = box; + } + if (landmarks != null) { + $result.landmarks = landmarks; + } + return $result; + } + Detection._() : super(); + factory Detection.fromBuffer($core.List<$core.int> i, [$pb.ExtensionRegistry r = $pb.ExtensionRegistry.EMPTY]) => create()..mergeFromBuffer(i, r); + factory Detection.fromJson($core.String i, [$pb.ExtensionRegistry r = $pb.ExtensionRegistry.EMPTY]) => create()..mergeFromJson(i, r); + + static final $pb.BuilderInfo _i = $pb.BuilderInfo(_omitMessageNames ? '' : 'Detection', package: const $pb.PackageName(_omitMessageNames ? '' : 'ente.ml'), createEmptyInstance: create) + ..aOM<$0.CenterBox>(1, _omitFieldNames ? '' : 'box', subBuilder: $0.CenterBox.create) + ..aOM<$1.EPoint>(2, _omitFieldNames ? '' : 'landmarks', subBuilder: $1.EPoint.create) + ..hasRequiredFields = false + ; + + @$core.Deprecated( + 'Using this can add significant overhead to your binary. ' + 'Use [GeneratedMessageGenericExtensions.deepCopy] instead. ' + 'Will be removed in next major version') + Detection clone() => Detection()..mergeFromMessage(this); + @$core.Deprecated( + 'Using this can add significant overhead to your binary. ' + 'Use [GeneratedMessageGenericExtensions.rebuild] instead. ' + 'Will be removed in next major version') + Detection copyWith(void Function(Detection) updates) => super.copyWith((message) => updates(message as Detection)) as Detection; + + $pb.BuilderInfo get info_ => _i; + + @$core.pragma('dart2js:noInline') + static Detection create() => Detection._(); + Detection createEmptyInstance() => create(); + static $pb.PbList<Detection> createRepeated() => $pb.PbList<Detection>(); + @$core.pragma('dart2js:noInline') + static Detection getDefault() => _defaultInstance ??= $pb.GeneratedMessage.$_defaultFor(create); + static Detection?
_defaultInstance; + + @$pb.TagNumber(1) + $0.CenterBox get box => $_getN(0); + @$pb.TagNumber(1) + set box($0.CenterBox v) { setField(1, v); } + @$pb.TagNumber(1) + $core.bool hasBox() => $_has(0); + @$pb.TagNumber(1) + void clearBox() => clearField(1); + @$pb.TagNumber(1) + $0.CenterBox ensureBox() => $_ensure(0); + + @$pb.TagNumber(2) + $1.EPoint get landmarks => $_getN(1); + @$pb.TagNumber(2) + set landmarks($1.EPoint v) { setField(2, v); } + @$pb.TagNumber(2) + $core.bool hasLandmarks() => $_has(1); + @$pb.TagNumber(2) + void clearLandmarks() => clearField(2); + @$pb.TagNumber(2) + $1.EPoint ensureLandmarks() => $_ensure(1); +} + +class Face extends $pb.GeneratedMessage { + factory Face({ + $core.String? id, + Detection? detection, + $core.double? confidence, + }) { + final $result = create(); + if (id != null) { + $result.id = id; + } + if (detection != null) { + $result.detection = detection; + } + if (confidence != null) { + $result.confidence = confidence; + } + return $result; + } + Face._() : super(); + factory Face.fromBuffer($core.List<$core.int> i, [$pb.ExtensionRegistry r = $pb.ExtensionRegistry.EMPTY]) => create()..mergeFromBuffer(i, r); + factory Face.fromJson($core.String i, [$pb.ExtensionRegistry r = $pb.ExtensionRegistry.EMPTY]) => create()..mergeFromJson(i, r); + + static final $pb.BuilderInfo _i = $pb.BuilderInfo(_omitMessageNames ? '' : 'Face', package: const $pb.PackageName(_omitMessageNames ? '' : 'ente.ml'), createEmptyInstance: create) + ..aOS(1, _omitFieldNames ? '' : 'id') + ..aOM<Detection>(2, _omitFieldNames ? '' : 'detection', subBuilder: Detection.create) + ..a<$core.double>(3, _omitFieldNames ? '' : 'confidence', $pb.PbFieldType.OF) + ..hasRequiredFields = false + ; + + @$core.Deprecated( + 'Using this can add significant overhead to your binary. ' + 'Use [GeneratedMessageGenericExtensions.deepCopy] instead. ' + 'Will be removed in next major version') + Face clone() => Face()..mergeFromMessage(this); + @$core.Deprecated( + 'Using this can add significant overhead to your binary. ' + 'Use [GeneratedMessageGenericExtensions.rebuild] instead. ' + 'Will be removed in next major version') + Face copyWith(void Function(Face) updates) => super.copyWith((message) => updates(message as Face)) as Face; + + $pb.BuilderInfo get info_ => _i; + + @$core.pragma('dart2js:noInline') + static Face create() => Face._(); + Face createEmptyInstance() => create(); + static $pb.PbList<Face> createRepeated() => $pb.PbList<Face>(); + @$core.pragma('dart2js:noInline') + static Face getDefault() => _defaultInstance ??= $pb.GeneratedMessage.$_defaultFor(create); + static Face?
_defaultInstance; + + @$pb.TagNumber(1) + $core.String get id => $_getSZ(0); + @$pb.TagNumber(1) + set id($core.String v) { $_setString(0, v); } + @$pb.TagNumber(1) + $core.bool hasId() => $_has(0); + @$pb.TagNumber(1) + void clearId() => clearField(1); + + @$pb.TagNumber(2) + Detection get detection => $_getN(1); + @$pb.TagNumber(2) + set detection(Detection v) { setField(2, v); } + @$pb.TagNumber(2) + $core.bool hasDetection() => $_has(1); + @$pb.TagNumber(2) + void clearDetection() => clearField(2); + @$pb.TagNumber(2) + Detection ensureDetection() => $_ensure(1); + + @$pb.TagNumber(3) + $core.double get confidence => $_getN(2); + @$pb.TagNumber(3) + set confidence($core.double v) { $_setFloat(2, v); } + @$pb.TagNumber(3) + $core.bool hasConfidence() => $_has(2); + @$pb.TagNumber(3) + void clearConfidence() => clearField(3); +} + + +const _omitFieldNames = $core.bool.fromEnvironment('protobuf.omit_field_names'); +const _omitMessageNames = $core.bool.fromEnvironment('protobuf.omit_message_names'); diff --git a/mobile/lib/generated/protos/ente/ml/face.pbenum.dart b/mobile/lib/generated/protos/ente/ml/face.pbenum.dart new file mode 100644 index 000000000..2eefe1f44 --- /dev/null +++ b/mobile/lib/generated/protos/ente/ml/face.pbenum.dart @@ -0,0 +1,11 @@ +// +// Generated code. Do not modify. +// source: ente/ml/face.proto +// +// @dart = 2.12 + +// ignore_for_file: annotate_overrides, camel_case_types, comment_references +// ignore_for_file: constant_identifier_names, library_prefixes +// ignore_for_file: non_constant_identifier_names, prefer_final_fields +// ignore_for_file: unnecessary_import, unnecessary_this, unused_import + diff --git a/mobile/lib/generated/protos/ente/ml/face.pbjson.dart b/mobile/lib/generated/protos/ente/ml/face.pbjson.dart new file mode 100644 index 000000000..5aa614a8b --- /dev/null +++ b/mobile/lib/generated/protos/ente/ml/face.pbjson.dart @@ -0,0 +1,55 @@ +// +// Generated code. Do not modify. +// source: ente/ml/face.proto +// +// @dart = 2.12 + +// ignore_for_file: annotate_overrides, camel_case_types, comment_references +// ignore_for_file: constant_identifier_names, library_prefixes +// ignore_for_file: non_constant_identifier_names, prefer_final_fields +// ignore_for_file: unnecessary_import, unnecessary_this, unused_import + +import 'dart:convert' as $convert; +import 'dart:core' as $core; +import 'dart:typed_data' as $typed_data; + +@$core.Deprecated('Use detectionDescriptor instead') +const Detection$json = { + '1': 'Detection', + '2': [ + {'1': 'box', '3': 1, '4': 1, '5': 11, '6': '.ente.common.CenterBox', '9': 0, '10': 'box', '17': true}, + {'1': 'landmarks', '3': 2, '4': 1, '5': 11, '6': '.ente.common.EPoint', '9': 1, '10': 'landmarks', '17': true}, + ], + '8': [ + {'1': '_box'}, + {'1': '_landmarks'}, + ], +}; + +/// Descriptor for `Detection`. Decode as a `google.protobuf.DescriptorProto`. 
+final $typed_data.Uint8List detectionDescriptor = $convert.base64Decode( + 'CglEZXRlY3Rpb24SLQoDYm94GAEgASgLMhYuZW50ZS5jb21tb24uQ2VudGVyQm94SABSA2JveI' + 'gBARI2CglsYW5kbWFya3MYAiABKAsyEy5lbnRlLmNvbW1vbi5FUG9pbnRIAVIJbGFuZG1hcmtz' + 'iAEBQgYKBF9ib3hCDAoKX2xhbmRtYXJrcw=='); + +@$core.Deprecated('Use faceDescriptor instead') +const Face$json = { + '1': 'Face', + '2': [ + {'1': 'id', '3': 1, '4': 1, '5': 9, '9': 0, '10': 'id', '17': true}, + {'1': 'detection', '3': 2, '4': 1, '5': 11, '6': '.ente.ml.Detection', '9': 1, '10': 'detection', '17': true}, + {'1': 'confidence', '3': 3, '4': 1, '5': 2, '9': 2, '10': 'confidence', '17': true}, + ], + '8': [ + {'1': '_id'}, + {'1': '_detection'}, + {'1': '_confidence'}, + ], +}; + +/// Descriptor for `Face`. Decode as a `google.protobuf.DescriptorProto`. +final $typed_data.Uint8List faceDescriptor = $convert.base64Decode( + 'CgRGYWNlEhMKAmlkGAEgASgJSABSAmlkiAEBEjUKCWRldGVjdGlvbhgCIAEoCzISLmVudGUubW' + 'wuRGV0ZWN0aW9uSAFSCWRldGVjdGlvbogBARIjCgpjb25maWRlbmNlGAMgASgCSAJSCmNvbmZp' + 'ZGVuY2WIAQFCBQoDX2lkQgwKCl9kZXRlY3Rpb25CDQoLX2NvbmZpZGVuY2U='); + diff --git a/mobile/lib/generated/protos/ente/ml/face.pbserver.dart b/mobile/lib/generated/protos/ente/ml/face.pbserver.dart new file mode 100644 index 000000000..a2cd6ff85 --- /dev/null +++ b/mobile/lib/generated/protos/ente/ml/face.pbserver.dart @@ -0,0 +1,14 @@ +// +// Generated code. Do not modify. +// source: ente/ml/face.proto +// +// @dart = 2.12 + +// ignore_for_file: annotate_overrides, camel_case_types, comment_references +// ignore_for_file: constant_identifier_names +// ignore_for_file: deprecated_member_use_from_same_package, library_prefixes +// ignore_for_file: non_constant_identifier_names, prefer_final_fields +// ignore_for_file: unnecessary_import, unnecessary_this, unused_import + +export 'face.pb.dart'; + diff --git a/mobile/lib/generated/protos/ente/ml/fileml.pb.dart b/mobile/lib/generated/protos/ente/ml/fileml.pb.dart new file mode 100644 index 000000000..853f89bac --- /dev/null +++ b/mobile/lib/generated/protos/ente/ml/fileml.pb.dart @@ -0,0 +1,179 @@ +// +// Generated code. Do not modify. +// source: ente/ml/fileml.proto +// +// @dart = 2.12 + +// ignore_for_file: annotate_overrides, camel_case_types, comment_references +// ignore_for_file: constant_identifier_names, library_prefixes +// ignore_for_file: non_constant_identifier_names, prefer_final_fields +// ignore_for_file: unnecessary_import, unnecessary_this, unused_import + +import 'dart:core' as $core; + +import 'package:fixnum/fixnum.dart' as $fixnum; +import 'package:protobuf/protobuf.dart' as $pb; + +import 'face.pb.dart' as $2; + +class FileML extends $pb.GeneratedMessage { + factory FileML({ + $fixnum.Int64? id, + $core.Iterable<$core.double>? clip, + }) { + final $result = create(); + if (id != null) { + $result.id = id; + } + if (clip != null) { + $result.clip.addAll(clip); + } + return $result; + } + FileML._() : super(); + factory FileML.fromBuffer($core.List<$core.int> i, [$pb.ExtensionRegistry r = $pb.ExtensionRegistry.EMPTY]) => create()..mergeFromBuffer(i, r); + factory FileML.fromJson($core.String i, [$pb.ExtensionRegistry r = $pb.ExtensionRegistry.EMPTY]) => create()..mergeFromJson(i, r); + + static final $pb.BuilderInfo _i = $pb.BuilderInfo(_omitMessageNames ? '' : 'FileML', package: const $pb.PackageName(_omitMessageNames ? '' : 'ente.ml'), createEmptyInstance: create) + ..aInt64(1, _omitFieldNames ? '' : 'id') + ..p<$core.double>(2, _omitFieldNames ? 
'' : 'clip', $pb.PbFieldType.KD) + ..hasRequiredFields = false + ; + + @$core.Deprecated( + 'Using this can add significant overhead to your binary. ' + 'Use [GeneratedMessageGenericExtensions.deepCopy] instead. ' + 'Will be removed in next major version') + FileML clone() => FileML()..mergeFromMessage(this); + @$core.Deprecated( + 'Using this can add significant overhead to your binary. ' + 'Use [GeneratedMessageGenericExtensions.rebuild] instead. ' + 'Will be removed in next major version') + FileML copyWith(void Function(FileML) updates) => super.copyWith((message) => updates(message as FileML)) as FileML; + + $pb.BuilderInfo get info_ => _i; + + @$core.pragma('dart2js:noInline') + static FileML create() => FileML._(); + FileML createEmptyInstance() => create(); + static $pb.PbList<FileML> createRepeated() => $pb.PbList<FileML>(); + @$core.pragma('dart2js:noInline') + static FileML getDefault() => _defaultInstance ??= $pb.GeneratedMessage.$_defaultFor<FileML>(create); + static FileML? _defaultInstance; + + @$pb.TagNumber(1) + $fixnum.Int64 get id => $_getI64(0); + @$pb.TagNumber(1) + set id($fixnum.Int64 v) { $_setInt64(0, v); } + @$pb.TagNumber(1) + $core.bool hasId() => $_has(0); + @$pb.TagNumber(1) + void clearId() => clearField(1); + + @$pb.TagNumber(2) + $core.List<$core.double> get clip => $_getList(1); +} + +class FileFaces extends $pb.GeneratedMessage { + factory FileFaces({ + $core.Iterable<$2.Face>? faces, + $core.int? height, + $core.int? width, + $core.int? version, + $core.String? error, + }) { + final $result = create(); + if (faces != null) { + $result.faces.addAll(faces); + } + if (height != null) { + $result.height = height; + } + if (width != null) { + $result.width = width; + } + if (version != null) { + $result.version = version; + } + if (error != null) { + $result.error = error; + } + return $result; + } + FileFaces._() : super(); + factory FileFaces.fromBuffer($core.List<$core.int> i, [$pb.ExtensionRegistry r = $pb.ExtensionRegistry.EMPTY]) => create()..mergeFromBuffer(i, r); + factory FileFaces.fromJson($core.String i, [$pb.ExtensionRegistry r = $pb.ExtensionRegistry.EMPTY]) => create()..mergeFromJson(i, r); + + static final $pb.BuilderInfo _i = $pb.BuilderInfo(_omitMessageNames ? '' : 'FileFaces', package: const $pb.PackageName(_omitMessageNames ? '' : 'ente.ml'), createEmptyInstance: create) + ..pc<$2.Face>(1, _omitFieldNames ? '' : 'faces', $pb.PbFieldType.PM, subBuilder: $2.Face.create) + ..a<$core.int>(2, _omitFieldNames ? '' : 'height', $pb.PbFieldType.O3) + ..a<$core.int>(3, _omitFieldNames ? '' : 'width', $pb.PbFieldType.O3) + ..a<$core.int>(4, _omitFieldNames ? '' : 'version', $pb.PbFieldType.O3) + ..aOS(5, _omitFieldNames ? '' : 'error') + ..hasRequiredFields = false + ; + + @$core.Deprecated( + 'Using this can add significant overhead to your binary. ' + 'Use [GeneratedMessageGenericExtensions.deepCopy] instead. ' + 'Will be removed in next major version') + FileFaces clone() => FileFaces()..mergeFromMessage(this); + @$core.Deprecated( + 'Using this can add significant overhead to your binary. ' + 'Use [GeneratedMessageGenericExtensions.rebuild] instead. 
' + 'Will be removed in next major version') + FileFaces copyWith(void Function(FileFaces) updates) => super.copyWith((message) => updates(message as FileFaces)) as FileFaces; + + $pb.BuilderInfo get info_ => _i; + + @$core.pragma('dart2js:noInline') + static FileFaces create() => FileFaces._(); + FileFaces createEmptyInstance() => create(); + static $pb.PbList<FileFaces> createRepeated() => $pb.PbList<FileFaces>(); + @$core.pragma('dart2js:noInline') + static FileFaces getDefault() => _defaultInstance ??= $pb.GeneratedMessage.$_defaultFor<FileFaces>(create); + static FileFaces? _defaultInstance; + + @$pb.TagNumber(1) + $core.List<$2.Face> get faces => $_getList(0); + + @$pb.TagNumber(2) + $core.int get height => $_getIZ(1); + @$pb.TagNumber(2) + set height($core.int v) { $_setSignedInt32(1, v); } + @$pb.TagNumber(2) + $core.bool hasHeight() => $_has(1); + @$pb.TagNumber(2) + void clearHeight() => clearField(2); + + @$pb.TagNumber(3) + $core.int get width => $_getIZ(2); + @$pb.TagNumber(3) + set width($core.int v) { $_setSignedInt32(2, v); } + @$pb.TagNumber(3) + $core.bool hasWidth() => $_has(2); + @$pb.TagNumber(3) + void clearWidth() => clearField(3); + + @$pb.TagNumber(4) + $core.int get version => $_getIZ(3); + @$pb.TagNumber(4) + set version($core.int v) { $_setSignedInt32(3, v); } + @$pb.TagNumber(4) + $core.bool hasVersion() => $_has(3); + @$pb.TagNumber(4) + void clearVersion() => clearField(4); + + @$pb.TagNumber(5) + $core.String get error => $_getSZ(4); + @$pb.TagNumber(5) + set error($core.String v) { $_setString(4, v); } + @$pb.TagNumber(5) + $core.bool hasError() => $_has(4); + @$pb.TagNumber(5) + void clearError() => clearField(5); +} + + +const _omitFieldNames = $core.bool.fromEnvironment('protobuf.omit_field_names'); +const _omitMessageNames = $core.bool.fromEnvironment('protobuf.omit_message_names'); diff --git a/mobile/lib/generated/protos/ente/ml/fileml.pbenum.dart b/mobile/lib/generated/protos/ente/ml/fileml.pbenum.dart new file mode 100644 index 000000000..71d796efe --- /dev/null +++ b/mobile/lib/generated/protos/ente/ml/fileml.pbenum.dart @@ -0,0 +1,11 @@ +// +// Generated code. Do not modify. +// source: ente/ml/fileml.proto +// +// @dart = 2.12 + +// ignore_for_file: annotate_overrides, camel_case_types, comment_references +// ignore_for_file: constant_identifier_names, library_prefixes +// ignore_for_file: non_constant_identifier_names, prefer_final_fields +// ignore_for_file: unnecessary_import, unnecessary_this, unused_import + diff --git a/mobile/lib/generated/protos/ente/ml/fileml.pbjson.dart b/mobile/lib/generated/protos/ente/ml/fileml.pbjson.dart new file mode 100644 index 000000000..824741733 --- /dev/null +++ b/mobile/lib/generated/protos/ente/ml/fileml.pbjson.dart @@ -0,0 +1,57 @@ +// +// Generated code. Do not modify. +// source: ente/ml/fileml.proto +// +// @dart = 2.12 + +// ignore_for_file: annotate_overrides, camel_case_types, comment_references +// ignore_for_file: constant_identifier_names, library_prefixes +// ignore_for_file: non_constant_identifier_names, prefer_final_fields +// ignore_for_file: unnecessary_import, unnecessary_this, unused_import + +import 'dart:convert' as $convert; +import 'dart:core' as $core; +import 'dart:typed_data' as $typed_data; + +@$core.Deprecated('Use fileMLDescriptor instead') +const FileML$json = { + '1': 'FileML', + '2': [ + {'1': 'id', '3': 1, '4': 1, '5': 3, '9': 0, '10': 'id', '17': true}, + {'1': 'clip', '3': 2, '4': 3, '5': 1, '10': 'clip'}, + ], + '8': [ + {'1': '_id'}, + ], +}; + +/// Descriptor for `FileML`. 
Decode as a `google.protobuf.DescriptorProto`. +final $typed_data.Uint8List fileMLDescriptor = $convert.base64Decode( + 'CgZGaWxlTUwSEwoCaWQYASABKANIAFICaWSIAQESEgoEY2xpcBgCIAMoAVIEY2xpcEIFCgNfaW' + 'Q='); + +@$core.Deprecated('Use fileFacesDescriptor instead') +const FileFaces$json = { + '1': 'FileFaces', + '2': [ + {'1': 'faces', '3': 1, '4': 3, '5': 11, '6': '.ente.ml.Face', '10': 'faces'}, + {'1': 'height', '3': 2, '4': 1, '5': 5, '9': 0, '10': 'height', '17': true}, + {'1': 'width', '3': 3, '4': 1, '5': 5, '9': 1, '10': 'width', '17': true}, + {'1': 'version', '3': 4, '4': 1, '5': 5, '9': 2, '10': 'version', '17': true}, + {'1': 'error', '3': 5, '4': 1, '5': 9, '9': 3, '10': 'error', '17': true}, + ], + '8': [ + {'1': '_height'}, + {'1': '_width'}, + {'1': '_version'}, + {'1': '_error'}, + ], +}; + +/// Descriptor for `FileFaces`. Decode as a `google.protobuf.DescriptorProto`. +final $typed_data.Uint8List fileFacesDescriptor = $convert.base64Decode( + 'CglGaWxlRmFjZXMSIwoFZmFjZXMYASADKAsyDS5lbnRlLm1sLkZhY2VSBWZhY2VzEhsKBmhlaW' + 'dodBgCIAEoBUgAUgZoZWlnaHSIAQESGQoFd2lkdGgYAyABKAVIAVIFd2lkdGiIAQESHQoHdmVy' + 'c2lvbhgEIAEoBUgCUgd2ZXJzaW9uiAEBEhkKBWVycm9yGAUgASgJSANSBWVycm9yiAEBQgkKB1' + '9oZWlnaHRCCAoGX3dpZHRoQgoKCF92ZXJzaW9uQggKBl9lcnJvcg=='); + diff --git a/mobile/lib/generated/protos/ente/ml/fileml.pbserver.dart b/mobile/lib/generated/protos/ente/ml/fileml.pbserver.dart new file mode 100644 index 000000000..4cb208d27 --- /dev/null +++ b/mobile/lib/generated/protos/ente/ml/fileml.pbserver.dart @@ -0,0 +1,14 @@ +// +// Generated code. Do not modify. +// source: ente/ml/fileml.proto +// +// @dart = 2.12 + +// ignore_for_file: annotate_overrides, camel_case_types, comment_references +// ignore_for_file: constant_identifier_names +// ignore_for_file: deprecated_member_use_from_same_package, library_prefixes +// ignore_for_file: non_constant_identifier_names, prefer_final_fields +// ignore_for_file: unnecessary_import, unnecessary_this, unused_import + +export 'fileml.pb.dart'; + diff --git a/mobile/lib/l10n/intl_cs.arb b/mobile/lib/l10n/intl_cs.arb index e7d374725..449bdb760 100644 --- a/mobile/lib/l10n/intl_cs.arb +++ b/mobile/lib/l10n/intl_cs.arb @@ -18,5 +18,11 @@ "addCollaborators": "{count, plural, zero {Add collaborator} one {Add collaborator} other {Add collaborators}}", "longPressAnEmailToVerifyEndToEndEncryption": "Long press an email to verify end to end encryption.", "createCollaborativeLink": "Create collaborative link", - "search": "Search" + "search": "Search", + "enterPersonName": "Enter person name", + "removePersonLabel": "Remove person label", + "faceRecognition": "Face recognition", + "faceRecognitionIndexingDescription": "Please note that this will result in a higher bandwidth and battery usage until all items are indexed.", + "foundFaces": "Found faces", + "clusteringProgress": "Clustering progress" } \ No newline at end of file diff --git a/mobile/lib/l10n/intl_de.arb b/mobile/lib/l10n/intl_de.arb index 0e5807e1e..acee623ab 100644 --- a/mobile/lib/l10n/intl_de.arb +++ b/mobile/lib/l10n/intl_de.arb @@ -1187,6 +1187,8 @@ "changeLocationOfSelectedItems": "Standort der gewählten Elemente ändern?", "editsToLocationWillOnlyBeSeenWithinEnte": "Änderungen des Standorts werden nur in ente sichtbar sein", "cleanUncategorized": "Unkategorisiert leeren", + "addAName": "Add a name", + "findPeopleByName": "Find people quickly by searching by name", "cleanUncategorizedDescription": "Entferne alle Dateien von \"Unkategorisiert\" die in anderen Alben vorhanden sind", 
"waitingForVerification": "Warte auf Bestätigung...", "passkey": "Passkey", @@ -1204,5 +1206,11 @@ "addCollaborators": "{count, plural, zero {Add collaborator} one {Add collaborator} other {Add collaborators}}", "longPressAnEmailToVerifyEndToEndEncryption": "Long press an email to verify end to end encryption.", "createCollaborativeLink": "Create collaborative link", - "search": "Search" + "search": "Search", + "enterPersonName": "Enter person name", + "removePersonLabel": "Remove person label", + "faceRecognition": "Face recognition", + "faceRecognitionIndexingDescription": "Please note that this will result in a higher bandwidth and battery usage until all items are indexed.", + "foundFaces": "Found faces", + "clusteringProgress": "Clustering progress" } \ No newline at end of file diff --git a/mobile/lib/l10n/intl_en.arb b/mobile/lib/l10n/intl_en.arb index 6bc8b5926..0fe06c95a 100644 --- a/mobile/lib/l10n/intl_en.arb +++ b/mobile/lib/l10n/intl_en.arb @@ -569,7 +569,7 @@ "freeTrialValidTill": "Free trial valid till {endDate}", "validTill": "Valid till {endDate}", "addOnValidTill": "Your {storageAmount} add-on is valid till {endDate}", - "playStoreFreeTrialValidTill": "Free trial valid till {endDate}.\nYou can purchase a paid plan afterwards.", + "playStoreFreeTrialValidTill": "Free trial valid till {endDate}.\nYou can choose a paid plan afterwards.", "subWillBeCancelledOn": "Your subscription will be cancelled on {endDate}", "subscription": "Subscription", "paymentDetails": "Payment details", @@ -987,7 +987,7 @@ "fileTypesAndNames": "File types and names", "location": "Location", "moments": "Moments", - "searchFaceEmptySection": "Find all photos of a person", + "searchFaceEmptySection": "Persons will be shown here once indexing is done", "searchDatesEmptySection": "Search by a date, month or year", "searchLocationEmptySection": "Group photos that are taken within some radius of a photo", "searchPeopleEmptySection": "Invite people, and you'll see all photos shared by them here", @@ -1171,6 +1171,7 @@ } }, "faces": "Faces", + "people": "People", "contents": "Contents", "addNew": "Add new", "@addNew": { @@ -1196,14 +1197,14 @@ "verifyPasskey": "Verify passkey", "playOnTv": "Play album on TV", "pair": "Pair", - "autoPair": "Auto pair", - "pairWithPin": "Pair with PIN", "deviceNotFound": "Device not found", "castInstruction": "Visit cast.ente.io on the device you want to pair.\n\nEnter the code below to play the album on your TV.", "deviceCodeHint": "Enter the code", "joinDiscord": "Join Discord", "locations": "Locations", "descriptions": "Descriptions", + "addAName": "Add a name", + "findPeopleByName": "Find people quickly by name", "addViewers": "{count, plural, zero {Add viewer} one {Add viewer} other {Add viewers}}", "addCollaborators": "{count, plural, zero {Add collaborator} one {Add collaborator} other {Add collaborators}}", "longPressAnEmailToVerifyEndToEndEncryption": "Long press an email to verify end to end encryption.", @@ -1216,6 +1217,8 @@ "customEndpoint": "Connected to {endpoint}", "createCollaborativeLink": "Create collaborative link", "search": "Search", + "enterPersonName": "Enter person name", + "removePersonLabel": "Remove person label", "autoPairDesc": "Auto pair works only with devices that support Chromecast.", "manualPairDesc": "Pair with PIN works with any screen you wish to view your album on.", "connectToDevice": "Connect to device", @@ -1226,5 +1229,11 @@ "stopCastingBody": "Do you want to stop casting?", "castIPMismatchTitle": "Failed to cast album", 
"castIPMismatchBody": "Please make sure you are on the same network as the TV.", - "pairingComplete": "Pairing complete" + "pairingComplete": "Pairing complete", + "autoPair": "Auto pair", + "pairWithPin": "Pair with PIN", + "faceRecognition": "Face recognition", + "faceRecognitionIndexingDescription": "Please note that this will result in a higher bandwidth and battery usage until all items are indexed.", + "foundFaces": "Found faces", + "clusteringProgress": "Clustering progress" } \ No newline at end of file diff --git a/mobile/lib/l10n/intl_es.arb b/mobile/lib/l10n/intl_es.arb index 6515371fa..a472aaf8e 100644 --- a/mobile/lib/l10n/intl_es.arb +++ b/mobile/lib/l10n/intl_es.arb @@ -980,5 +980,11 @@ "addCollaborators": "{count, plural, zero {Add collaborator} one {Add collaborator} other {Add collaborators}}", "longPressAnEmailToVerifyEndToEndEncryption": "Long press an email to verify end to end encryption.", "createCollaborativeLink": "Create collaborative link", - "search": "Search" + "search": "Search", + "enterPersonName": "Enter person name", + "removePersonLabel": "Remove person label", + "faceRecognition": "Face recognition", + "faceRecognitionIndexingDescription": "Please note that this will result in a higher bandwidth and battery usage until all items are indexed.", + "foundFaces": "Found faces", + "clusteringProgress": "Clustering progress" } \ No newline at end of file diff --git a/mobile/lib/l10n/intl_fr.arb b/mobile/lib/l10n/intl_fr.arb index 1d8e5f6d3..a5b2f2fd0 100644 --- a/mobile/lib/l10n/intl_fr.arb +++ b/mobile/lib/l10n/intl_fr.arb @@ -1161,5 +1161,11 @@ "addCollaborators": "{count, plural, zero {Add collaborator} one {Add collaborator} other {Add collaborators}}", "longPressAnEmailToVerifyEndToEndEncryption": "Long press an email to verify end to end encryption.", "createCollaborativeLink": "Create collaborative link", - "search": "Search" + "search": "Search", + "enterPersonName": "Enter person name", + "removePersonLabel": "Remove person label", + "faceRecognition": "Face recognition", + "faceRecognitionIndexingDescription": "Please note that this will result in a higher bandwidth and battery usage until all items are indexed.", + "foundFaces": "Found faces", + "clusteringProgress": "Clustering progress" } \ No newline at end of file diff --git a/mobile/lib/l10n/intl_it.arb b/mobile/lib/l10n/intl_it.arb index c9655dd06..e81ac6377 100644 --- a/mobile/lib/l10n/intl_it.arb +++ b/mobile/lib/l10n/intl_it.arb @@ -1123,5 +1123,11 @@ "addCollaborators": "{count, plural, zero {Add collaborator} one {Add collaborator} other {Add collaborators}}", "longPressAnEmailToVerifyEndToEndEncryption": "Long press an email to verify end to end encryption.", "createCollaborativeLink": "Create collaborative link", - "search": "Search" + "search": "Search", + "enterPersonName": "Enter person name", + "removePersonLabel": "Remove person label", + "faceRecognition": "Face recognition", + "faceRecognitionIndexingDescription": "Please note that this will result in a higher bandwidth and battery usage until all items are indexed.", + "foundFaces": "Found faces", + "clusteringProgress": "Clustering progress" } \ No newline at end of file diff --git a/mobile/lib/l10n/intl_ko.arb b/mobile/lib/l10n/intl_ko.arb index e7d374725..449bdb760 100644 --- a/mobile/lib/l10n/intl_ko.arb +++ b/mobile/lib/l10n/intl_ko.arb @@ -18,5 +18,11 @@ "addCollaborators": "{count, plural, zero {Add collaborator} one {Add collaborator} other {Add collaborators}}", "longPressAnEmailToVerifyEndToEndEncryption": "Long press 
an email to verify end to end encryption.", "createCollaborativeLink": "Create collaborative link", - "search": "Search" + "search": "Search", + "enterPersonName": "Enter person name", + "removePersonLabel": "Remove person label", + "faceRecognition": "Face recognition", + "faceRecognitionIndexingDescription": "Please note that this will result in a higher bandwidth and battery usage until all items are indexed.", + "foundFaces": "Found faces", + "clusteringProgress": "Clustering progress" } \ No newline at end of file diff --git a/mobile/lib/l10n/intl_nl.arb b/mobile/lib/l10n/intl_nl.arb index a8f854a43..682aee259 100644 --- a/mobile/lib/l10n/intl_nl.arb +++ b/mobile/lib/l10n/intl_nl.arb @@ -1226,5 +1226,9 @@ "stopCastingBody": "Wil je stoppen met casten?", "castIPMismatchTitle": "Album casten mislukt", "castIPMismatchBody": "Zorg ervoor dat je op hetzelfde netwerk zit als de tv.", - "pairingComplete": "Koppeling voltooid" + "pairingComplete": "Koppeling voltooid", + "faceRecognition": "Face recognition", + "faceRecognitionIndexingDescription": "Please note that this will result in a higher bandwidth and battery usage until all items are indexed.", + "foundFaces": "Found faces", + "clusteringProgress": "Clustering progress" } \ No newline at end of file diff --git a/mobile/lib/l10n/intl_no.arb b/mobile/lib/l10n/intl_no.arb index 8908eadb0..697a9f3c4 100644 --- a/mobile/lib/l10n/intl_no.arb +++ b/mobile/lib/l10n/intl_no.arb @@ -32,5 +32,11 @@ "addCollaborators": "{count, plural, zero {Add collaborator} one {Add collaborator} other {Add collaborators}}", "longPressAnEmailToVerifyEndToEndEncryption": "Long press an email to verify end to end encryption.", "createCollaborativeLink": "Create collaborative link", - "search": "Search" + "search": "Search", + "enterPersonName": "Enter person name", + "removePersonLabel": "Remove person label", + "faceRecognition": "Face recognition", + "faceRecognitionIndexingDescription": "Please note that this will result in a higher bandwidth and battery usage until all items are indexed.", + "foundFaces": "Found faces", + "clusteringProgress": "Clustering progress" } \ No newline at end of file diff --git a/mobile/lib/l10n/intl_pl.arb b/mobile/lib/l10n/intl_pl.arb index 13d740614..f9b66901e 100644 --- a/mobile/lib/l10n/intl_pl.arb +++ b/mobile/lib/l10n/intl_pl.arb @@ -119,5 +119,11 @@ "addCollaborators": "{count, plural, zero {Add collaborator} one {Add collaborator} other {Add collaborators}}", "longPressAnEmailToVerifyEndToEndEncryption": "Long press an email to verify end to end encryption.", "createCollaborativeLink": "Create collaborative link", - "search": "Search" + "search": "Search", + "enterPersonName": "Enter person name", + "removePersonLabel": "Remove person label", + "faceRecognition": "Face recognition", + "faceRecognitionIndexingDescription": "Please note that this will result in a higher bandwidth and battery usage until all items are indexed.", + "foundFaces": "Found faces", + "clusteringProgress": "Clustering progress" } \ No newline at end of file diff --git a/mobile/lib/l10n/intl_pt.arb b/mobile/lib/l10n/intl_pt.arb index f2ca4ba13..f47dd89e9 100644 --- a/mobile/lib/l10n/intl_pt.arb +++ b/mobile/lib/l10n/intl_pt.arb @@ -1226,5 +1226,9 @@ "stopCastingBody": "Você quer parar a transmissão?", "castIPMismatchTitle": "Falha ao transmitir álbum", "castIPMismatchBody": "Certifique-se de estar na mesma rede que a TV.", - "pairingComplete": "Pareamento concluído" + "pairingComplete": "Pareamento concluído", + "faceRecognition": "Face recognition", 
+ "faceRecognitionIndexingDescription": "Please note that this will result in a higher bandwidth and battery usage until all items are indexed.", + "foundFaces": "Found faces", + "clusteringProgress": "Clustering progress" } \ No newline at end of file diff --git a/mobile/lib/l10n/intl_zh.arb b/mobile/lib/l10n/intl_zh.arb index de473e54b..933eea126 100644 --- a/mobile/lib/l10n/intl_zh.arb +++ b/mobile/lib/l10n/intl_zh.arb @@ -1226,5 +1226,9 @@ "stopCastingBody": "您想停止投放吗?", "castIPMismatchTitle": "投放相册失败", "castIPMismatchBody": "请确保您的设备与电视处于同一网络。", - "pairingComplete": "配对完成" + "pairingComplete": "配对完成", + "faceRecognition": "Face recognition", + "faceRecognitionIndexingDescription": "Please note that this will result in a higher bandwidth and battery usage until all items are indexed.", + "foundFaces": "Found faces", + "clusteringProgress": "Clustering progress" } \ No newline at end of file diff --git a/mobile/lib/main.dart b/mobile/lib/main.dart index 52c9c715a..6a42a0a3b 100644 --- a/mobile/lib/main.dart +++ b/mobile/lib/main.dart @@ -21,6 +21,7 @@ import 'package:photos/core/errors.dart'; import 'package:photos/core/network/network.dart'; import 'package:photos/db/upload_locks_db.dart'; import 'package:photos/ente_theme_data.dart'; +import "package:photos/face/db.dart"; import "package:photos/l10n/l10n.dart"; import "package:photos/service_locator.dart"; import 'package:photos/services/app_lifecycle_service.dart'; @@ -32,6 +33,9 @@ import 'package:photos/services/home_widget_service.dart'; import 'package:photos/services/local_file_update_service.dart'; import 'package:photos/services/local_sync_service.dart'; import "package:photos/services/location_service.dart"; +import 'package:photos/services/machine_learning/face_ml/face_ml_service.dart'; +import "package:photos/services/machine_learning/face_ml/person/person_service.dart"; +import 'package:photos/services/machine_learning/file_ml/remote_fileml_service.dart'; import "package:photos/services/machine_learning/machine_learning_controller.dart"; import 'package:photos/services/machine_learning/semantic_search/semantic_search_service.dart'; import 'package:photos/services/memories_service.dart'; @@ -213,6 +217,7 @@ Future _init(bool isBackground, {String via = ''}) async { LocalFileUpdateService.instance.init(preferences); SearchService.instance.init(); StorageBonusService.instance.init(preferences); + RemoteFileMLService.instance.init(preferences); if (!isBackground && Platform.isAndroid && await HomeWidgetService.instance.countHomeWidgets() == 0) { @@ -233,9 +238,23 @@ Future _init(bool isBackground, {String via = ''}) async { // Can not including existing tf/ml binaries as they are not being built // from source. 
// See https://gitlab.com/fdroid/fdroiddata/-/merge_requests/12671#note_1294346819 - // if (!UpdateService.instance.isFdroidFlavor()) { - // unawaited(ObjectDetectionService.instance.init()); - // } + if (!UpdateService.instance.isFdroidFlavor()) { + // unawaited(ObjectDetectionService.instance.init()); + if (flagService.faceSearchEnabled) { + unawaited(FaceMlService.instance.init()); + FaceMlService.instance.listenIndexOnDiffSync(); + FaceMlService.instance.listenOnPeopleChangedSync(); + } else { + if (LocalSettings.instance.isFaceIndexingEnabled) { + unawaited(LocalSettings.instance.toggleFaceIndexing()); + } + } + } + PersonService.init( + EntityService.instance, + FaceMLDataDB.instance, + preferences, + ); _logger.info("Initialization done"); } diff --git a/mobile/lib/models/api/entity/type.dart b/mobile/lib/models/api/entity/type.dart index 3631792de..88e60d62f 100644 --- a/mobile/lib/models/api/entity/type.dart +++ b/mobile/lib/models/api/entity/type.dart @@ -2,6 +2,7 @@ import "package:flutter/foundation.dart"; enum EntityType { location, + person, unknown, } @@ -9,6 +10,8 @@ EntityType typeFromString(String type) { switch (type) { case "location": return EntityType.location; + case "person": + return EntityType.person; } debugPrint("unexpected collection type $type"); return EntityType.unknown; @@ -19,6 +22,8 @@ extension EntityTypeExtn on EntityType { switch (this) { case EntityType.location: return "location"; + case EntityType.person: + return "person"; case EntityType.unknown: return "unknown"; } diff --git a/mobile/lib/models/file/file.dart b/mobile/lib/models/file/file.dart index d96a81e1c..9df25bb05 100644 --- a/mobile/lib/models/file/file.dart +++ b/mobile/lib/models/file/file.dart @@ -243,6 +243,9 @@ class EnteFile { } String get downloadUrl { + if (localFileServer.isNotEmpty) { + return "$localFileServer/$uploadedFileID"; + } final endpoint = Configuration.instance.getHttpEndpoint(); if (endpoint != kDefaultProductionEndpoint || flagService.disableCFWorker) { return endpoint + "/files/download/" + uploadedFileID.toString(); @@ -256,6 +259,9 @@ class EnteFile { } String get thumbnailUrl { + if (localFileServer.isNotEmpty) { + return "$localFileServer/thumb/$uploadedFileID"; + } final endpoint = Configuration.instance.getHttpEndpoint(); if (endpoint != kDefaultProductionEndpoint || flagService.disableCFWorker) { return endpoint + "/files/preview/" + uploadedFileID.toString(); diff --git a/mobile/lib/models/gallery_type.dart b/mobile/lib/models/gallery_type.dart index 40426f701..bb02f1bbc 100644 --- a/mobile/lib/models/gallery_type.dart +++ b/mobile/lib/models/gallery_type.dart @@ -18,6 +18,8 @@ enum GalleryType { searchResults, locationTag, quickLink, + peopleTag, + cluster, } extension GalleyTypeExtension on GalleryType { @@ -32,12 +34,14 @@ extension GalleyTypeExtension on GalleryType { case GalleryType.locationTag: case GalleryType.quickLink: case GalleryType.uncategorized: + case GalleryType.peopleTag: case GalleryType.sharedCollection: return true; case GalleryType.hiddenSection: case GalleryType.hiddenOwnedCollection: case GalleryType.trash: + case GalleryType.cluster: return false; } } @@ -50,6 +54,7 @@ extension GalleyTypeExtension on GalleryType { return true; case GalleryType.hiddenSection: + case GalleryType.peopleTag: case GalleryType.hiddenOwnedCollection: case GalleryType.favorite: case GalleryType.searchResults: @@ -59,6 +64,7 @@ extension GalleyTypeExtension on GalleryType { case GalleryType.trash: case GalleryType.sharedCollection: case 
GalleryType.locationTag: + case GalleryType.cluster: return false; } } @@ -75,12 +81,14 @@ extension GalleyTypeExtension on GalleryType { case GalleryType.uncategorized: case GalleryType.locationTag: case GalleryType.quickLink: + case GalleryType.peopleTag: return true; case GalleryType.trash: case GalleryType.archive: case GalleryType.hiddenSection: case GalleryType.hiddenOwnedCollection: case GalleryType.sharedCollection: + case GalleryType.cluster: return false; } } @@ -98,8 +106,10 @@ extension GalleyTypeExtension on GalleryType { case GalleryType.localFolder: case GalleryType.locationTag: case GalleryType.quickLink: + case GalleryType.peopleTag: return true; case GalleryType.trash: + case GalleryType.cluster: case GalleryType.sharedCollection: return false; } @@ -114,8 +124,10 @@ extension GalleyTypeExtension on GalleryType { case GalleryType.archive: case GalleryType.uncategorized: case GalleryType.locationTag: + case GalleryType.peopleTag: return true; case GalleryType.hiddenSection: + case GalleryType.cluster: case GalleryType.hiddenOwnedCollection: case GalleryType.localFolder: case GalleryType.trash: @@ -132,6 +144,7 @@ extension GalleyTypeExtension on GalleryType { case GalleryType.quickLink: return true; case GalleryType.hiddenSection: + case GalleryType.peopleTag: case GalleryType.hiddenOwnedCollection: case GalleryType.uncategorized: case GalleryType.favorite: @@ -139,6 +152,7 @@ extension GalleyTypeExtension on GalleryType { case GalleryType.homepage: case GalleryType.archive: case GalleryType.localFolder: + case GalleryType.cluster: case GalleryType.trash: case GalleryType.locationTag: return false; @@ -154,6 +168,7 @@ extension GalleyTypeExtension on GalleryType { return true; case GalleryType.hiddenSection: + case GalleryType.peopleTag: case GalleryType.hiddenOwnedCollection: case GalleryType.favorite: case GalleryType.searchResults: @@ -162,6 +177,7 @@ extension GalleyTypeExtension on GalleryType { case GalleryType.trash: case GalleryType.sharedCollection: case GalleryType.locationTag: + case GalleryType.cluster: return false; } } @@ -182,10 +198,12 @@ extension GalleyTypeExtension on GalleryType { return true; case GalleryType.hiddenSection: + case GalleryType.peopleTag: case GalleryType.hiddenOwnedCollection: case GalleryType.localFolder: case GalleryType.trash: case GalleryType.favorite: + case GalleryType.cluster: case GalleryType.sharedCollection: return false; } @@ -203,12 +221,14 @@ extension GalleyTypeExtension on GalleryType { case GalleryType.searchResults: case GalleryType.uncategorized: case GalleryType.locationTag: + case GalleryType.peopleTag: return true; case GalleryType.hiddenSection: case GalleryType.hiddenOwnedCollection: case GalleryType.quickLink: case GalleryType.favorite: + case GalleryType.cluster: case GalleryType.archive: case GalleryType.localFolder: case GalleryType.trash: @@ -244,7 +264,7 @@ extension GalleyTypeExtension on GalleryType { } bool showEditLocation() { - return this != GalleryType.sharedCollection; + return this != GalleryType.sharedCollection && this != GalleryType.cluster; } } @@ -334,7 +354,9 @@ extension GalleryAppBarExtn on GalleryType { case GalleryType.locationTag: case GalleryType.searchResults: return false; + case GalleryType.cluster: case GalleryType.uncategorized: + case GalleryType.peopleTag: case GalleryType.ownedCollection: case GalleryType.sharedCollection: case GalleryType.quickLink: diff --git a/mobile/lib/models/local_entity_data.dart b/mobile/lib/models/local_entity_data.dart index 9066e16fd..910167b13 100644 
--- a/mobile/lib/models/local_entity_data.dart +++ b/mobile/lib/models/local_entity_data.dart @@ -1,6 +1,7 @@ import "package:equatable/equatable.dart"; import "package:photos/models/api/entity/type.dart"; +// LocalEntityData represents the data of an entity stored locally. class LocalEntityData { final String id; final EntityType type; diff --git a/mobile/lib/models/ml/ml_typedefs.dart b/mobile/lib/models/ml/ml_typedefs.dart new file mode 100644 index 000000000..bcb23251e --- /dev/null +++ b/mobile/lib/models/ml/ml_typedefs.dart @@ -0,0 +1,7 @@ +typedef Embedding = List<double>; + +typedef Num3DInputMatrix = List<List<List<num>>>; + +typedef Int3DInputMatrix = List<List<List<int>>>; + +typedef Double3DInputMatrix = List<List<List<double>>>; diff --git a/mobile/lib/models/ml/ml_versions.dart b/mobile/lib/models/ml/ml_versions.dart new file mode 100644 index 000000000..857bef33c --- /dev/null +++ b/mobile/lib/models/ml/ml_versions.dart @@ -0,0 +1,3 @@ +const faceMlVersion = 1; +const clusterMlVersion = 1; +const minimumClusterSize = 2; \ No newline at end of file diff --git a/mobile/lib/models/search/generic_search_result.dart b/mobile/lib/models/search/generic_search_result.dart index 352886a50..a40f71fd3 100644 --- a/mobile/lib/models/search/generic_search_result.dart +++ b/mobile/lib/models/search/generic_search_result.dart @@ -8,8 +8,15 @@ class GenericSearchResult extends SearchResult { final List<EnteFile> _files; final ResultType _type; final Function(BuildContext context)? onResultTap; + final Map<String, dynamic> params; - GenericSearchResult(this._type, this._name, this._files, {this.onResultTap}); + GenericSearchResult( + this._type, + this._name, + this._files, { + this.onResultTap, + this.params = const {}, + }); @override String name() { diff --git a/mobile/lib/models/search/search_constants.dart b/mobile/lib/models/search/search_constants.dart new file mode 100644 index 000000000..6a0bcb886 --- /dev/null +++ b/mobile/lib/models/search/search_constants.dart @@ -0,0 +1,3 @@ +const kPersonParamID = 'person_id'; +const kClusterParamId = 'cluster_id'; +const kFileID = 'file_id';
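The new `params` map lets a search result carry structured identifiers (a person ID, a cluster ID) alongside its files. A minimal sketch of how a face result could be assembled from the constants above (hedged: `GenericSearchResult`, `ResultType.faces`, `kPersonParamID`, and `kClusterParamId` come from this diff; the helper function, its arguments, and the file list are invented for illustration):

import "package:photos/models/file/file.dart";
import "package:photos/models/search/generic_search_result.dart";
import "package:photos/models/search/search_constants.dart";
import "package:photos/models/search/search_types.dart";

// Hypothetical helper: wraps a named person cluster into a search result.
GenericSearchResult personSearchResult(
  String name,
  String personId,
  int clusterId,
  List<EnteFile> files,
) {
  return GenericSearchResult(
    ResultType.faces,
    name,
    files,
    params: {
      kPersonParamID: personId,
      kClusterParamId: clusterId,
    },
  );
}

A consumer can then read `result.params[kPersonParamID]` to route a tap on the result to the right person page.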
diff --git a/mobile/lib/models/search/search_types.dart b/mobile/lib/models/search/search_types.dart index 1ec197c7e..a13fd57dc 100644 --- a/mobile/lib/models/search/search_types.dart +++ b/mobile/lib/models/search/search_types.dart @@ -6,6 +6,7 @@ import "package:photos/core/event_bus.dart"; import "package:photos/events/collection_updated_event.dart"; import "package:photos/events/event.dart"; import "package:photos/events/location_tag_updated_event.dart"; +import "package:photos/events/people_changed_event.dart"; import "package:photos/generated/l10n.dart"; import "package:photos/models/collection/collection.dart"; import "package:photos/models/collection/collection_items.dart"; @@ -33,6 +34,7 @@ enum ResultType { fileCaption, event, shared, + faces, magic, } @@ -55,7 +57,7 @@ extension SectionTypeExtensions on SectionType { String sectionTitle(BuildContext context) { switch (this) { case SectionType.face: - return S.of(context).faces; + return S.of(context).people; case SectionType.content: return S.of(context).contents; case SectionType.moment: @@ -117,10 +119,12 @@ extension SectionTypeExtensions on SectionType { } } + bool get sortByName => this != SectionType.face; + bool get isEmptyCTAVisible { switch (this) { case SectionType.face: - return true; + return false; case SectionType.content: return false; case SectionType.moment: @@ -245,8 +249,7 @@ extension SectionTypeExtensions on SectionType { }) { switch (this) { case SectionType.face: - return Future.value(List.empty()); - + return SearchService.instance.getAllFace(limit); case SectionType.content: return Future.value(List.empty()); @@ -277,6 +280,8 @@ return [Bus.instance.on<LocationTagUpdatedEvent>()]; case SectionType.album: return [Bus.instance.on<CollectionUpdatedEvent>()]; + case SectionType.face: + return [Bus.instance.on<PeopleChangedEvent>()]; default: return []; } diff --git a/mobile/lib/services/entity_service.dart b/mobile/lib/services/entity_service.dart index e681f37b7..6ffe87358 100644 --- a/mobile/lib/services/entity_service.dart +++ b/mobile/lib/services/entity_service.dart @@ -50,6 +50,10 @@ class EntityService { return await _db.getEntities(type); } + Future<LocalEntityData?> getEntity(EntityType type, String id) async { + return await _db.getEntity(type, id); + } + Future<LocalEntityData> addOrUpdate( EntityType type, String plainText, { @@ -57,13 +61,16 @@ class EntityService { }) async { final key = await getOrCreateEntityKey(type); final encryptedKeyData = await CryptoUtil.encryptChaCha( - utf8.encode(plainText) as Uint8List, + utf8.encode(plainText), key, ); final String encryptedData = CryptoUtil.bin2base64(encryptedKeyData.encryptedData!); final String header = CryptoUtil.bin2base64(encryptedKeyData.header!); - debugPrint("Adding entity of type: " + type.typeToString()); + debugPrint( + " ${id == null ? 'Adding' : 'Updating'} entity of type: " + + type.typeToString(), + ); final EntityData data = id == null ? await _gateway.createEntity(type, encryptedData, header) : await _gateway.updateEntity(type, id, encryptedData, header); @@ -87,6 +94,7 @@ class EntityService { Future<void> syncEntities() async { try { await _remoteToLocalSync(EntityType.location); + await _remoteToLocalSync(EntityType.person); } catch (e) { _logger.severe("Failed to sync entities", e); } diff --git a/mobile/lib/services/machine_learning/face_ml/face_alignment/alignment_result.dart b/mobile/lib/services/machine_learning/face_ml/face_alignment/alignment_result.dart new file mode 100644 index 000000000..41fd88b61 --- /dev/null +++ b/mobile/lib/services/machine_learning/face_ml/face_alignment/alignment_result.dart @@ -0,0 +1,36 @@ +class AlignmentResult { + final List<List<double>> affineMatrix; // 3x3 + final List<double> center; // [x, y] + final double size; // 1 / scale + final double rotation; // atan2(simRotation[1][0], simRotation[0][0]); + + AlignmentResult({required this.affineMatrix, required this.center, required this.size, required this.rotation}); + + AlignmentResult.empty() + : affineMatrix = <List<double>>[ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1], + ], + center = [0, 0], + size = 1, + rotation = 0; + + factory AlignmentResult.fromJson(Map<String, dynamic> json) { + return AlignmentResult( + affineMatrix: (json['affineMatrix'] as List) + .map((item) => List<double>.from(item)) + .toList(), + center: List<double>.from(json['center'] as List), + size: json['size'] as double, + rotation: json['rotation'] as double, + ); + } + + Map<String, dynamic> toJson() => { + 'affineMatrix': affineMatrix, + 'center': center, + 'size': size, + 'rotation': rotation, + }; +} \ No newline at end of file
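`AlignmentResult` round-trips through JSON, so a quick sanity check of `toJson`/`fromJson` looks like this (a sketch; the matrix, center, and angle values are made up):

void main() {
  final result = AlignmentResult(
    affineMatrix: [
      [0.8, -0.1, 12.0],
      [0.1, 0.8, 20.0],
      [0.0, 0.0, 1.0],
    ],
    center: [56.0, 71.0],
    size: 1.25, // 1 / scale
    rotation: 0.12, // radians
  );
  // Serialize and restore; the restored object should match field for field.
  final restored = AlignmentResult.fromJson(result.toJson());
  assert(restored.size == result.size);
  assert(restored.affineMatrix[0][2] == result.affineMatrix[0][2]);
}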
diff --git a/mobile/lib/services/machine_learning/face_ml/face_alignment/similarity_transform.dart b/mobile/lib/services/machine_learning/face_ml/face_alignment/similarity_transform.dart new file mode 100644 index 000000000..0d8e7ab3a --- /dev/null +++ b/mobile/lib/services/machine_learning/face_ml/face_alignment/similarity_transform.dart @@ -0,0 +1,171 @@ +import 'dart:math' show atan2; +import 'package:ml_linalg/linalg.dart'; +import 'package:photos/extensions/ml_linalg_extensions.dart'; +import 'package:photos/services/machine_learning/face_ml/face_alignment/alignment_result.dart'; + +/// Class to compute the similarity transform between two sets of points. +/// +/// The class estimates the parameters of the similarity transformation via the `estimate` function. +/// After estimation, the transformation can be applied to an image using the `warpAffine` function. +class SimilarityTransform { + Matrix _params = Matrix.fromList([ + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0, 0, 1], + ]); + List<double> _center = [0, 0]; // [x, y] + double _size = 1; // 1 / scale + double _rotation = 0; // atan2(simRotation[1][0], simRotation[0][0]); + + final arcface4Landmarks = [ + [38.2946, 51.6963], + [73.5318, 51.5014], + [56.0252, 71.7366], + [56.1396, 92.2848], + ]; + final arcface5Landmarks = [ + [38.2946, 51.6963], + [73.5318, 51.5014], + [56.0252, 71.7366], + [41.5493, 92.3655], + [70.7299, 92.2041], + ]; + get arcfaceNormalized4 => arcface4Landmarks + .map((list) => list.map((value) => value / 112.0).toList()) + .toList(); + get arcfaceNormalized5 => arcface5Landmarks + .map((list) => list.map((value) => value / 112.0).toList()) + .toList(); + + List<List<double>> get paramsList => _params.to2DList(); + + // singleton pattern + SimilarityTransform._privateConstructor(); + static final instance = SimilarityTransform._privateConstructor(); + factory SimilarityTransform() => instance; + + void _cleanParams() { + _params = Matrix.fromList([ + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0, 0, 1], + ]); + _center = [0, 0]; + _size = 1; + _rotation = 0; + } + + /// Function to estimate the parameters of the similarity transformation. These parameters are stored in the class variable [_params]. + /// + /// Returns a tuple of (AlignmentResult, bool). The bool indicates whether the parameters are valid or not. + /// + /// Runs efficiently in about 1-3 ms after initial warm-up. + /// + /// It takes the source points as input and aligns them to the normalized + /// arcface landmarks. The bool is false if the parameters cannot be + /// estimated. The function estimates the parameters by solving a + /// least-squares problem using the Umeyama algorithm, via [_umeyama]. + (AlignmentResult, bool) estimate(List<List<double>> src) { + _cleanParams(); + final (params, center, size, rotation) = + _umeyama(src, arcfaceNormalized5, true); + _params = params; + _center = center; + _size = size; + _rotation = rotation; + final alignmentResult = AlignmentResult( + affineMatrix: paramsList, + center: _center, + size: _size, + rotation: _rotation, + ); + // We check for NaN in the transformation matrix params. + final isNoNanInParam = + !_params.asFlattenedList.any((element) => element.isNaN); + return (alignmentResult, isNoNanInParam); + } + + static (Matrix, List<double>, double, double) _umeyama( + List<List<double>> src, + List<List<double>> dst, [ + bool estimateScale = true, + ]) { + final srcMat = Matrix.fromList( + src, + // .map((list) => list.map((value) => value.toDouble()).toList()) + // .toList(), + ); + final dstMat = Matrix.fromList(dst); + final num = srcMat.rowCount; + final dim = srcMat.columnCount; + + // Compute mean of src and dst. + final srcMean = srcMat.mean(Axis.columns); + final dstMean = dstMat.mean(Axis.columns); + + // Subtract mean from src and dst. + final srcDemean = srcMat.mapRows((vector) => vector - srcMean); + final dstDemean = dstMat.mapRows((vector) => vector - dstMean); + + // Eq. (38). + final A = (dstDemean.transpose() * srcDemean) / num; + + // Eq. (39). 
+ var d = Vector.filled(dim, 1.0); + if (A.determinant() < 0) { + d = d.set(dim - 1, -1); + } + + var T = Matrix.identity(dim + 1); + + final svdResult = A.svd(); + final Matrix U = svdResult['U']!; + final Vector S = svdResult['S']!; + final Matrix V = svdResult['V']!; + + // Eq. (40) and (43). + final rank = A.matrixRank(); + if (rank == 0) { + return (T * double.nan, [0, 0], 1, 0); + } else if (rank == dim - 1) { + if (U.determinant() * V.determinant() > 0) { + T = T.setSubMatrix(0, dim, 0, dim, U * V); + } else { + final s = d[dim - 1]; + d = d.set(dim - 1, -1); + final replacement = U * Matrix.diagonal(d.toList()) * V; + T = T.setSubMatrix(0, dim, 0, dim, replacement); + d = d.set(dim - 1, s); + } + } else { + final replacement = U * Matrix.diagonal(d.toList()) * V; + T = T.setSubMatrix(0, dim, 0, dim, replacement); + } + final Matrix simRotation = U * Matrix.diagonal(d.toList()) * V; + + var scale = 1.0; + if (estimateScale) { + // Eq. (41) and (42). + scale = 1.0 / srcDemean.variance(Axis.columns).sum() * (S * d).sum(); + } + + final subTIndices = Iterable.generate(dim, (index) => index); + final subT = T.sample(rowIndices: subTIndices, columnIndices: subTIndices); + final newSubT = dstMean - (subT * srcMean) * scale; + T = T.setValues(0, dim, dim, dim + 1, newSubT); + final newNewSubT = + T.sample(rowIndices: subTIndices, columnIndices: subTIndices) * scale; + T = T.setSubMatrix(0, dim, 0, dim, newNewSubT); + + // final List<double> translation = [T[0][2], T[1][2]]; + // final simRotation = replacement?; + final size = 1 / scale; + final rotation = atan2(simRotation[1][0], simRotation[0][0]); + final meanTranslation = (dstMean - 0.5) * size; + final centerMat = srcMean - meanTranslation; + final List<double> center = [centerMat[0], centerMat[1]]; + + return (T, center, size, rotation); + } +}
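A sketch of the intended call pattern for `SimilarityTransform.estimate` (the five landmark coordinates are invented; in the pipeline they come from the face detector, ordered like `arcface5Landmarks`: left eye, right eye, nose, left mouth corner, right mouth corner):

void main() {
  final landmarks = <List<double>>[
    [0.35, 0.45],
    [0.65, 0.46],
    [0.50, 0.62],
    [0.38, 0.80],
    [0.63, 0.79],
  ];
  final (alignment, isValid) =
      SimilarityTransform.instance.estimate(landmarks);
  if (isValid) {
    // The 3x3 alignment.affineMatrix can now be used to warp the face crop.
    print('rotation: ${alignment.rotation}, size: ${alignment.size}');
  }
}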
 diff --git a/mobile/lib/services/machine_learning/face_ml/face_clustering/clusters_mapping.dart b/mobile/lib/services/machine_learning/face_ml/face_clustering/clusters_mapping.dart new file mode 100644 index 000000000..77be47e2b --- /dev/null +++ b/mobile/lib/services/machine_learning/face_ml/face_clustering/clusters_mapping.dart @@ -0,0 +1,22 @@ +import "package:photos/face/model/person.dart"; + +enum MappingSource { + local, + remote, +} + +class ClustersMapping { + final Map<int, Set<int>> fileIDToClusterIDs; + final Map<int, String> clusterToPersonID; + // personIDToPerson is a map of personID to PersonEntity, and it's the same + // for both local and remote sources + final Map<String, PersonEntity> personIDToPerson; + final MappingSource source; + + ClustersMapping({ + required this.fileIDToClusterIDs, + required this.clusterToPersonID, + required this.personIDToPerson, + required this.source, + }); +} diff --git a/mobile/lib/services/machine_learning/face_ml/face_clustering/cosine_distance.dart b/mobile/lib/services/machine_learning/face_ml/face_clustering/cosine_distance.dart new file mode 100644 index 000000000..0611a1d83 --- /dev/null +++ b/mobile/lib/services/machine_learning/face_ml/face_clustering/cosine_distance.dart @@ -0,0 +1,79 @@ +import 'dart:math' show sqrt; + +import "package:ml_linalg/linalg.dart"; + +/// Calculates the cosine distance between two embeddings/vectors using SIMD from ml_linalg +/// +/// WARNING: This assumes both vectors are already normalized! +double cosineDistanceSIMD(Vector vector1, Vector vector2) { + if (vector1.length != vector2.length) { + throw ArgumentError('Vectors must be the same length'); + } + + return 1 - vector1.dot(vector2); +} + +/// Calculates the cosine distance between two embeddings/vectors using SIMD from ml_linalg +/// +/// WARNING: Only use when you're not sure if vectors are normalized. If you're sure they are, use [cosineDistanceSIMD] instead for better performance. +double cosineDistanceSIMDSafe(Vector vector1, Vector vector2) { + if (vector1.length != vector2.length) { + throw ArgumentError('Vectors must be the same length'); + } + + return vector1.distanceTo(vector2, distance: Distance.cosine); +} + +/// Calculates the cosine distance between two embeddings/vectors. +/// +/// Throws an ArgumentError if the vectors are of different lengths or +/// if either of the vectors has a magnitude of zero. +double cosineDistance(List<double> vector1, List<double> vector2) { + if (vector1.length != vector2.length) { + throw ArgumentError('Vectors must be the same length'); + } + + double dotProduct = 0.0; + double magnitude1 = 0.0; + double magnitude2 = 0.0; + + for (int i = 0; i < vector1.length; i++) { + dotProduct += vector1[i] * vector2[i]; + magnitude1 += vector1[i] * vector1[i]; + magnitude2 += vector2[i] * vector2[i]; + } + + magnitude1 = sqrt(magnitude1); + magnitude2 = sqrt(magnitude2); + + // Avoid division by zero. This should never happen. If it does, then one of the vectors contains only zeros. + if (magnitude1 == 0 || magnitude2 == 0) { + throw ArgumentError('Vectors must not have a magnitude of zero'); + } + + final double similarity = dotProduct / (magnitude1 * magnitude2); + + // Cosine distance is the complement of cosine similarity + return 1.0 - similarity; +}
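// Illustrative check (not part of the file above): on unit-length input the
// SIMD and plain variants agree. With hand-normalized 3-d vectors, where
// dot(a, b) = 0.6, each call below returns a distance of ~0.4:
//
//   final a = [1.0, 0.0, 0.0];
//   final b = [0.6, 0.8, 0.0]; // |b| = 1
//   final va = Vector.fromList(a); // ml_linalg, float32 by default
//   final vb = Vector.fromList(b);
//   cosineDistance(a, b);
//   cosineDistanceSIMD(va, vb);
//   cosineDistForNormVectors(a, b); // defined below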
 + +// cosineDistForNormVectors calculates the cosine distance between two normalized embeddings/vectors. +@pragma('vm:entry-point') +double cosineDistForNormVectors(List<double> vector1, List<double> vector2) { + if (vector1.length != vector2.length) { + throw ArgumentError('Vectors must be the same length'); + } + double dotProduct = 0.0; + for (int i = 0; i < vector1.length; i++) { + dotProduct += vector1[i] * vector2[i]; + } + return 1.0 - dotProduct; +} + +double calculateSqrDistance(List<double> v1, List<double> v2) { + double sum = 0; + for (int i = 0; i < v1.length; i++) { + sum += (v1[i] - v2[i]) * (v1[i] - v2[i]); + } + return sqrt(sum); +} diff --git a/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart b/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart new file mode 100644 index 000000000..1b8d9c3bd --- /dev/null +++ b/mobile/lib/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart @@ -0,0 +1,1029 @@ +import "dart:async"; +import "dart:developer"; +import "dart:isolate"; +import "dart:math" show max; +import "dart:typed_data" show Uint8List; + +import "package:computer/computer.dart"; +import "package:flutter/foundation.dart" show kDebugMode; +import "package:logging/logging.dart"; +import "package:ml_linalg/dtype.dart"; +import "package:ml_linalg/vector.dart"; +import "package:photos/generated/protos/ente/common/vector.pb.dart"; +import 'package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart'; +import "package:photos/services/machine_learning/face_ml/face_clustering/face_info_for_clustering.dart"; +import "package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart"; +import "package:photos/services/machine_learning/face_ml/face_ml_result.dart"; +import "package:simple_cluster/simple_cluster.dart"; +import "package:synchronized/synchronized.dart"; + +class FaceInfo { + final String faceID; + final double? faceScore; + final double? blurValue; + final bool? badFace; + final List<double>? embedding; + final Vector? vEmbedding; + int? clusterId; + String? closestFaceId; + int? closestDist; + int? fileCreationTime; + FaceInfo({ + required this.faceID, + this.faceScore, + this.blurValue, + this.badFace, + this.embedding, + this.vEmbedding, + this.clusterId, + this.fileCreationTime, + }); +} + +enum ClusterOperation { linearIncrementalClustering, dbscanClustering } + +class ClusteringResult { + final Map<String, int> newFaceIdToCluster; + final Map<int, List<String>>? newClusterIdToFaceIds; + final Map<int, (Uint8List, int)>? newClusterSummaries; + + bool get isEmpty => newFaceIdToCluster.isEmpty; + + ClusteringResult({ + required this.newFaceIdToCluster, + this.newClusterSummaries, + this.newClusterIdToFaceIds, + }); +} + +class FaceClusteringService { + final _logger = Logger("FaceLinearClustering"); + final _computer = Computer.shared(); + + Timer? _inactivityTimer; + final Duration _inactivityDuration = const Duration(minutes: 3); + int _activeTasks = 0; + + final _initLock = Lock(); + + late Isolate _isolate; + late ReceivePort _receivePort = ReceivePort(); + late SendPort _mainSendPort; + + bool isSpawned = false; + bool isRunning = false; + + static const kRecommendedDistanceThreshold = 0.24; + static const kConservativeDistanceThreshold = 0.16; + + // singleton pattern + FaceClusteringService._privateConstructor(); + + /// Use this instance to access the FaceClustering service. + /// e.g. 
`FaceClusteringService.instance.predictLinear(dataset)` + static final instance = FaceClusteringService._privateConstructor(); + factory FaceClusteringService() => instance; + + Future<void> init() async { + return _initLock.synchronized(() async { + if (isSpawned) return; + + _receivePort = ReceivePort(); + + try { + _isolate = await Isolate.spawn( + _isolateMain, + _receivePort.sendPort, + ); + _mainSendPort = await _receivePort.first as SendPort; + isSpawned = true; + + _resetInactivityTimer(); + } catch (e) { + _logger.severe('Could not spawn isolate', e); + isSpawned = false; + } + }); + } + + Future<void> ensureSpawned() async { + if (!isSpawned) { + await init(); + } + } + + /// The main execution function of the isolate. + static void _isolateMain(SendPort mainSendPort) async { + final receivePort = ReceivePort(); + mainSendPort.send(receivePort.sendPort); + + receivePort.listen((message) async { + final functionIndex = message[0] as int; + final function = ClusterOperation.values[functionIndex]; + final args = message[1] as Map; + final sendPort = message[2] as SendPort; + + try { + switch (function) { + case ClusterOperation.linearIncrementalClustering: + final result = FaceClusteringService.runLinearClustering(args); + sendPort.send(result); + break; + case ClusterOperation.dbscanClustering: + final result = FaceClusteringService._runDbscanClustering(args); + sendPort.send(result); + break; + } + } catch (e, stackTrace) { + sendPort + .send({'error': e.toString(), 'stackTrace': stackTrace.toString()}); + } + }); + } + + /// The common method to run any operation in the isolate. It sends the [message] to [_isolateMain] and waits for the result. + Future<dynamic> _runInIsolate( + (ClusterOperation, Map<String, dynamic>) message, + ) async { + await ensureSpawned(); + _resetInactivityTimer(); + final completer = Completer(); + final answerPort = ReceivePort(); + + _activeTasks++; + _mainSendPort.send([message.$1.index, message.$2, answerPort.sendPort]); + + answerPort.listen((receivedMessage) { + if (receivedMessage is Map && receivedMessage.containsKey('error')) { + // Handle the error + final errorMessage = receivedMessage['error']; + final errorStackTrace = receivedMessage['stackTrace']; + final exception = Exception(errorMessage); + final stackTrace = StackTrace.fromString(errorStackTrace); + _activeTasks--; + completer.completeError(exception, stackTrace); + } else { + _activeTasks--; + completer.complete(receivedMessage); + } + }); + + return completer.future; + } + + /// Resets a timer that kills the isolate after a certain amount of inactivity. + /// + /// Should be called after initialization (e.g. inside `init()`) and after every call to isolate (e.g. inside `_runInIsolate()`) + void _resetInactivityTimer() { + _inactivityTimer?.cancel(); + _inactivityTimer = Timer(_inactivityDuration, () { + if (_activeTasks > 0) { + _logger.info('Tasks are still running. Delaying isolate disposal.'); + // Optionally, reschedule the timer to check again later. + _resetInactivityTimer(); + } else { + _logger.info( + 'Clustering Isolate has been inactive for ${_inactivityDuration.inSeconds} seconds with no tasks running. Killing isolate.', + ); + dispose(); + } + }); + } + + /// Disposes the isolate worker. + void dispose() { + if (!isSpawned) return; + + isSpawned = false; + _isolate.kill(); + _receivePort.close(); + _inactivityTimer?.cancel(); + }
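// Illustrative call-site sketch (the local names are invented; the inputs
// are prepared by the face indexing pipeline):
//
//   final ClusteringResult? result =
//       await FaceClusteringService.instance.predictLinear(
//     faceInfos, // Set<FaceInfoForClustering>
//     fileIDToCreationTime: creationTimes, // Map<int, int>
//     oldClusterSummaries: summaries, // Map<int, (Uint8List, int)>
//   );
//   // result?.newFaceIdToCluster maps each newly clustered faceID to its
//   // clusterID; null means the input was empty or clustering was already
//   // running.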
 + + /// Runs the clustering algorithm [runLinearClustering] on the given [input], in an isolate. + /// + /// Returns the clustering result, which maps each newly clustered faceID to a clusterID. + /// + /// WARNING: Make sure to always input data in the same ordering, otherwise the clustering can be less deterministic. + Future<ClusteringResult?> predictLinear( + Set<FaceInfoForClustering> input, { + Map<int, int>? fileIDToCreationTime, + double distanceThreshold = kRecommendedDistanceThreshold, + double conservativeDistanceThreshold = kConservativeDistanceThreshold, + bool useDynamicThreshold = true, + int? offset, + required Map<int, (Uint8List, int)> oldClusterSummaries, + }) async { + if (input.isEmpty) { + _logger.warning( + "Clustering dataset of embeddings is empty, returning empty list.", + ); + return null; + } + if (isRunning) { + _logger.warning("Clustering is already running, returning empty list."); + return null; + } + + isRunning = true; + try { + // Clustering inside the isolate + _logger.info( + "Start clustering on ${input.length} embeddings inside computer isolate", + ); + final stopwatchClustering = Stopwatch()..start(); + // final Map faceIdToCluster = + // await _runLinearClusteringInComputer(input); + final ClusteringResult? faceIdToCluster = await _runInIsolate( + ( + ClusterOperation.linearIncrementalClustering, + { + 'input': input, + 'fileIDToCreationTime': fileIDToCreationTime, + 'distanceThreshold': distanceThreshold, + 'conservativeDistanceThreshold': conservativeDistanceThreshold, + 'useDynamicThreshold': useDynamicThreshold, + 'offset': offset, + 'oldClusterSummaries': oldClusterSummaries, + } + ), + ); + // return _runLinearClusteringInComputer(input); + _logger.info( + 'predictLinear Clustering executed in ${stopwatchClustering.elapsed.inSeconds} seconds', + ); + + isRunning = false; + return faceIdToCluster; + } catch (e, stackTrace) { + _logger.severe('Error while running clustering', e, stackTrace); + isRunning = false; + rethrow; + } + } + + /// Runs the clustering algorithm [runLinearClustering] on the given [input], in computer, without any dynamic thresholding + Future<ClusteringResult?> predictLinearComputer( + Map<String, Uint8List> input, { + Map<int, int>? fileIDToCreationTime, + double distanceThreshold = kRecommendedDistanceThreshold, + }) async { + if (input.isEmpty) { + _logger.warning( + "Linear Clustering dataset of embeddings is empty, returning empty list.", + ); + return null; + } + + // Clustering inside the isolate + _logger.info( + "Start Linear clustering on ${input.length} embeddings inside computer isolate", + ); + + try { + final clusteringInput = input + .map((key, value) { + return MapEntry( + key, + FaceInfoForClustering( + faceID: key, + embeddingBytes: value, + faceScore: kMinimumQualityFaceScore + 0.01, + blurValue: kLapacianDefault, + ), + ); + }) + .values + .toSet(); + final startTime = DateTime.now(); + final faceIdToCluster = await _computer.compute( + runLinearClustering, + param: { + "input": clusteringInput, + "fileIDToCreationTime": fileIDToCreationTime, + "distanceThreshold": distanceThreshold, + "conservativeDistanceThreshold": distanceThreshold - 0.08, + "useDynamicThreshold": false, + }, + taskName: "createImageEmbedding", + ) as ClusteringResult; + final endTime = DateTime.now(); + _logger.info( + "Linear Clustering took: ${endTime.difference(startTime).inMilliseconds}ms", + ); + return faceIdToCluster; + } catch (e, s) { + _logger.severe(e, s); + rethrow; + } + } + + /// Runs the clustering algorithm [runCompleteClustering] on the given [input], in computer. + /// + /// WARNING: Only use on small datasets, as it is not optimized for large datasets. 
+ Future<ClusteringResult> predictCompleteComputer( + Map<String, Uint8List> input, { + Map<int, int>? fileIDToCreationTime, + double distanceThreshold = kRecommendedDistanceThreshold, + double mergeThreshold = 0.30, + }) async { + if (input.isEmpty) { + _logger.warning( + "Complete Clustering dataset of embeddings is empty, returning empty list.", + ); + return ClusteringResult(newFaceIdToCluster: {}); + } + + // Clustering inside the isolate + _logger.info( + "Start Complete clustering on ${input.length} embeddings inside computer isolate", + ); + + try { + final startTime = DateTime.now(); + final clusteringResult = await _computer.compute( + runCompleteClustering, + param: { + "input": input, + "fileIDToCreationTime": fileIDToCreationTime, + "distanceThreshold": distanceThreshold, + "mergeThreshold": mergeThreshold, + }, + taskName: "createImageEmbedding", + ) as ClusteringResult; + final endTime = DateTime.now(); + _logger.info( + "Complete Clustering took: ${endTime.difference(startTime).inMilliseconds}ms", + ); + return clusteringResult; + } catch (e, s) { + _logger.severe(e, s); + rethrow; + } + } + + Future<ClusteringResult?> predictWithinClusterComputer( + Map<String, Uint8List> input, { + Map<int, int>? fileIDToCreationTime, + double distanceThreshold = kRecommendedDistanceThreshold, + }) async { + _logger.info( + '`predictWithinClusterComputer` called with ${input.length} faces and distance threshold $distanceThreshold', + ); + try { + if (input.length < 500) { + final mergeThreshold = distanceThreshold; + _logger.info( + 'Running complete clustering on ${input.length} faces with distance threshold $mergeThreshold', + ); + final result = await predictCompleteComputer( + input, + fileIDToCreationTime: fileIDToCreationTime, + distanceThreshold: distanceThreshold - 0.08, + mergeThreshold: mergeThreshold, + ); + if (result.newFaceIdToCluster.isEmpty) return null; + return result; + } else { + _logger.info( + 'Running linear clustering on ${input.length} faces with distance threshold $distanceThreshold', + ); + final clusterResult = await predictLinearComputer( + input, + fileIDToCreationTime: fileIDToCreationTime, + distanceThreshold: distanceThreshold, + ); + return clusterResult; + } + } catch (e, s) { + _logger.severe(e, s); + rethrow; + } + } + + Future<List<List<String>>> predictDbscan( + Map<String, Uint8List> input, { + Map<int, int>? fileIDToCreationTime, + double eps = 0.3, + int minPts = 5, + }) async { + if (input.isEmpty) { + _logger.warning( + "DBSCAN Clustering dataset of embeddings is empty, returning empty list.", + ); + return []; + } + if (isRunning) { + _logger.warning( + "DBSCAN Clustering is already running, returning empty list.", + ); + return []; + } + + isRunning = true; + + // Clustering inside the isolate + _logger.info( + "Start DBSCAN clustering on ${input.length} embeddings inside computer isolate", + ); + final stopwatchClustering = Stopwatch()..start(); + // final Map faceIdToCluster = + // await _runLinearClusteringInComputer(input); + final List<List<String>> clusterFaceIDs = await _runInIsolate( + ( + ClusterOperation.dbscanClustering, + { + 'input': input, + 'fileIDToCreationTime': fileIDToCreationTime, + 'eps': eps, + 'minPts': minPts, + } + ), + ); + // return _runLinearClusteringInComputer(input); + _logger.info( + 'DBSCAN Clustering executed in ${stopwatchClustering.elapsed.inSeconds} seconds', + ); + + isRunning = false; + + return clusterFaceIDs; + } + + static ClusteringResult? 
runLinearClustering(Map args) { + // final input = args['input'] as Map; + final input = args['input'] as Set; + final fileIDToCreationTime = args['fileIDToCreationTime'] as Map?; + final distanceThreshold = args['distanceThreshold'] as double; + final conservativeDistanceThreshold = + args['conservativeDistanceThreshold'] as double; + final useDynamicThreshold = args['useDynamicThreshold'] as bool; + final offset = args['offset'] as int?; + final oldClusterSummaries = + args['oldClusterSummaries'] as Map?; + + log( + "[ClusterIsolate] ${DateTime.now()} Copied to isolate ${input.length} faces", + ); + + // Organize everything into a list of FaceInfo objects + final List faceInfos = []; + for (final face in input) { + faceInfos.add( + FaceInfo( + faceID: face.faceID, + faceScore: face.faceScore, + blurValue: face.blurValue, + badFace: face.faceScore < kMinimumQualityFaceScore || + face.blurValue < kLaplacianSoftThreshold || + (face.blurValue < kLaplacianVerySoftThreshold && + face.faceScore < kMediumQualityFaceScore) || + face.isSideways, + vEmbedding: Vector.fromList( + EVector.fromBuffer(face.embeddingBytes).values, + dtype: DType.float32, + ), + clusterId: face.clusterId, + fileCreationTime: + fileIDToCreationTime?[getFileIdFromFaceId(face.faceID)], + ), + ); + } + + // Assert that the embeddings are normalized + for (final faceInfo in faceInfos) { + if (faceInfo.vEmbedding != null) { + final norm = faceInfo.vEmbedding!.norm(); + assert((norm - 1.0).abs() < 1e-5); + } + } + + // Sort the faceInfos based on fileCreationTime, in ascending order, so oldest faces are first + if (fileIDToCreationTime != null) { + faceInfos.sort((a, b) { + if (a.fileCreationTime == null && b.fileCreationTime == null) { + return 0; + } else if (a.fileCreationTime == null) { + return 1; + } else if (b.fileCreationTime == null) { + return -1; + } else { + return a.fileCreationTime!.compareTo(b.fileCreationTime!); + } + }); + } + + // Sort the faceInfos such that the ones with null clusterId are at the end + final List facesWithClusterID = []; + final List facesWithoutClusterID = []; + for (final FaceInfo faceInfo in faceInfos) { + if (faceInfo.clusterId == null) { + facesWithoutClusterID.add(faceInfo); + } else { + facesWithClusterID.add(faceInfo); + } + } + final alreadyClusteredCount = facesWithClusterID.length; + final sortedFaceInfos = []; + sortedFaceInfos.addAll(facesWithClusterID); + sortedFaceInfos.addAll(facesWithoutClusterID); + + log( + "[ClusterIsolate] ${DateTime.now()} Clustering ${facesWithoutClusterID.length} new faces without clusterId, and $alreadyClusteredCount faces with clusterId", + ); + + // Make sure the first face has a clusterId + final int totalFaces = sortedFaceInfos.length; + int dynamicThresholdCount = 0; + + if (sortedFaceInfos.isEmpty) { + return null; + } + + // Start actual clustering + log( + "[ClusterIsolate] ${DateTime.now()} Processing $totalFaces faces in total in this round ${offset != null ? 
"on top of ${offset + facesWithClusterID.length} earlier processed faces" : ""}", + ); + // set current epoch time as clusterID + int clusterID = DateTime.now().microsecondsSinceEpoch; + if (facesWithClusterID.isEmpty) { + // assign a clusterID to the first face + sortedFaceInfos[0].clusterId = clusterID; + clusterID++; + } + final stopwatchClustering = Stopwatch()..start(); + for (int i = 1; i < totalFaces; i++) { + // Incremental clustering, so we can skip faces that already have a clusterId + if (sortedFaceInfos[i].clusterId != null) { + clusterID = max(clusterID, sortedFaceInfos[i].clusterId!); + continue; + } + + int closestIdx = -1; + double closestDistance = double.infinity; + late double thresholdValue; + if (useDynamicThreshold) { + thresholdValue = sortedFaceInfos[i].badFace! + ? conservativeDistanceThreshold + : distanceThreshold; + if (sortedFaceInfos[i].badFace!) dynamicThresholdCount++; + } else { + thresholdValue = distanceThreshold; + } + if (i % 250 == 0) { + log("[ClusterIsolate] ${DateTime.now()} Processed ${offset != null ? i + offset : i} faces"); + } + for (int j = i - 1; j >= 0; j--) { + late double distance; + if (sortedFaceInfos[i].vEmbedding != null) { + distance = cosineDistanceSIMD( + sortedFaceInfos[i].vEmbedding!, + sortedFaceInfos[j].vEmbedding!, + ); + } else { + distance = cosineDistForNormVectors( + sortedFaceInfos[i].embedding!, + sortedFaceInfos[j].embedding!, + ); + } + if (distance < closestDistance) { + if (sortedFaceInfos[j].badFace! && + distance > conservativeDistanceThreshold) { + continue; + } + closestDistance = distance; + closestIdx = j; + } + } + + if (closestDistance < thresholdValue) { + if (sortedFaceInfos[closestIdx].clusterId == null) { + // Ideally this should never happen, but just in case log it + log( + " [ClusterIsolate] [WARNING] ${DateTime.now()} Found new cluster $clusterID", + ); + clusterID++; + sortedFaceInfos[closestIdx].clusterId = clusterID; + } + sortedFaceInfos[i].clusterId = sortedFaceInfos[closestIdx].clusterId; + } else { + clusterID++; + sortedFaceInfos[i].clusterId = clusterID; + } + } + + // Finally, assign the new clusterId to the faces + final Map newFaceIdToCluster = {}; + final newClusteredFaceInfos = + sortedFaceInfos.sublist(alreadyClusteredCount); + for (final faceInfo in newClusteredFaceInfos) { + newFaceIdToCluster[faceInfo.faceID] = faceInfo.clusterId!; + } + + // Create a map of clusterId to faceIds + final Map> clusterIdToFaceIds = {}; + for (final entry in newFaceIdToCluster.entries) { + final clusterID = entry.value; + if (clusterIdToFaceIds.containsKey(clusterID)) { + clusterIdToFaceIds[clusterID]!.add(entry.key); + } else { + clusterIdToFaceIds[clusterID] = [entry.key]; + } + } + + stopwatchClustering.stop(); + log( + ' [ClusterIsolate] ${DateTime.now()} Clustering for ${sortedFaceInfos.length} embeddings executed in ${stopwatchClustering.elapsedMilliseconds}ms', + ); + if (useDynamicThreshold) { + log( + "[ClusterIsolate] ${DateTime.now()} Dynamic thresholding: $dynamicThresholdCount faces had a low face score or low blur clarity", + ); + } + + // Now calculate the mean of the embeddings for each cluster and update the cluster summaries + Map? 
newClusterSummaries; + if (oldClusterSummaries != null) { + newClusterSummaries = FaceClusteringService.updateClusterSummaries( + oldSummary: oldClusterSummaries, + newFaceInfos: newClusteredFaceInfos, + ); + } + + // analyze the results + // FaceClusteringService._analyzeClusterResults(sortedFaceInfos); + + return ClusteringResult( + newFaceIdToCluster: newFaceIdToCluster, + newClusterSummaries: newClusterSummaries, + newClusterIdToFaceIds: clusterIdToFaceIds, + ); + } + + static Map updateClusterSummaries({ + required Map oldSummary, + required List newFaceInfos, + }) { + final calcSummariesStart = DateTime.now(); + final Map> newClusterIdToFaceInfos = {}; + for (final faceInfo in newFaceInfos) { + if (newClusterIdToFaceInfos.containsKey(faceInfo.clusterId!)) { + newClusterIdToFaceInfos[faceInfo.clusterId!]!.add(faceInfo); + } else { + newClusterIdToFaceInfos[faceInfo.clusterId!] = [faceInfo]; + } + } + + final Map newClusterSummaries = {}; + for (final clusterId in newClusterIdToFaceInfos.keys) { + final List newEmbeddings = newClusterIdToFaceInfos[clusterId]! + .map((faceInfo) => faceInfo.vEmbedding!) + .toList(); + final newCount = newEmbeddings.length; + if (oldSummary.containsKey(clusterId)) { + final oldMean = Vector.fromList( + EVector.fromBuffer(oldSummary[clusterId]!.$1).values, + dtype: DType.float32, + ); + final oldCount = oldSummary[clusterId]!.$2; + final oldEmbeddings = oldMean * oldCount; + newEmbeddings.add(oldEmbeddings); + final newMeanVector = + newEmbeddings.reduce((a, b) => a + b) / (oldCount + newCount); + final newMeanVectorNormalized = newMeanVector / newMeanVector.norm(); + newClusterSummaries[clusterId] = ( + EVector(values: newMeanVectorNormalized.toList()).writeToBuffer(), + oldCount + newCount + ); + } else { + final newMeanVector = newEmbeddings.reduce((a, b) => a + b); + final newMeanVectorNormalized = newMeanVector / newMeanVector.norm(); + newClusterSummaries[clusterId] = ( + EVector(values: newMeanVectorNormalized.toList()).writeToBuffer(), + newCount + ); + } + } + log( + "[ClusterIsolate] ${DateTime.now()} Calculated cluster summaries in ${DateTime.now().difference(calcSummariesStart).inMilliseconds}ms", + ); + + return newClusterSummaries; + } + + static void _analyzeClusterResults(List sortedFaceInfos) { + if (!kDebugMode) return; + final stopwatch = Stopwatch()..start(); + + final Map faceIdToCluster = {}; + for (final faceInfo in sortedFaceInfos) { + faceIdToCluster[faceInfo.faceID] = faceInfo.clusterId!; + } + + // Find faceIDs that are part of a cluster which is larger than 5 and are new faceIDs + final Map clusterIdToSize = {}; + faceIdToCluster.forEach((key, value) { + if (clusterIdToSize.containsKey(value)) { + clusterIdToSize[value] = clusterIdToSize[value]! 
+ 1; + } else { + clusterIdToSize[value] = 1; + } + }); + + // compute the size of each cluster id + final clusterIds = faceIdToCluster.values.toSet(); + final clusterSizes = clusterIds.map((clusterId) { + return faceIdToCluster.values.where((id) => id == clusterId).length; + }).toList(); + clusterSizes.sort(); + // count clusters by size bucket + int oneClusterCount = 0; + int moreThan5Count = 0; + int moreThan10Count = 0; + int moreThan20Count = 0; + int moreThan50Count = 0; + int moreThan100Count = 0; + + for (int i = 0; i < clusterSizes.length; i++) { + if (clusterSizes[i] > 100) { + moreThan100Count++; + } else if (clusterSizes[i] > 50) { + moreThan50Count++; + } else if (clusterSizes[i] > 20) { + moreThan20Count++; + } else if (clusterSizes[i] > 10) { + moreThan10Count++; + } else if (clusterSizes[i] > 5) { + moreThan5Count++; + } else if (clusterSizes[i] == 1) { + oneClusterCount++; + } + } + + // print the metrics + log( + "[ClusterIsolate] Total clusters ${clusterIds.length}: \n oneClusterCount $oneClusterCount \n moreThan5Count $moreThan5Count \n moreThan10Count $moreThan10Count \n moreThan20Count $moreThan20Count \n moreThan50Count $moreThan50Count \n moreThan100Count $moreThan100Count", + ); + stopwatch.stop(); + log( + "[ClusterIsolate] Clustering additional analysis took ${stopwatch.elapsedMilliseconds} ms", + ); + } + + static ClusteringResult runCompleteClustering(Map args) { + final input = args['input'] as Map; + final fileIDToCreationTime = args['fileIDToCreationTime'] as Map?; + final distanceThreshold = args['distanceThreshold'] as double; + final mergeThreshold = args['mergeThreshold'] as double; + + log( + "[CompleteClustering] ${DateTime.now()} Copied to isolate ${input.length} faces for clustering", + ); + + // Organize everything into a list of FaceInfo objects + final List faceInfos = []; + for (final entry in input.entries) { + faceInfos.add( + FaceInfo( + faceID: entry.key, + vEmbedding: Vector.fromList( + EVector.fromBuffer(entry.value).values, + dtype: DType.float32, + ), + fileCreationTime: + fileIDToCreationTime?[getFileIdFromFaceId(entry.key)], + ), + ); + } + + // Sort the faceInfos based on fileCreationTime, in ascending order, so oldest faces are first + if (fileIDToCreationTime != null) { + faceInfos.sort((a, b) { + if (a.fileCreationTime == null && b.fileCreationTime == null) { + return 0; + } else if (a.fileCreationTime == null) { + return 1; + } else if (b.fileCreationTime == null) { + return -1; + } else { + return a.fileCreationTime!.compareTo(b.fileCreationTime!); + } + }); + } + + if (faceInfos.isEmpty) { + return ClusteringResult(newFaceIdToCluster: {}); + } + final int totalFaces = faceInfos.length; + + // Start actual clustering + log( + "[CompleteClustering] ${DateTime.now()} Processing $totalFaces faces in one single round of complete clustering", + ); + + // set current epoch time as clusterID + int clusterID = DateTime.now().microsecondsSinceEpoch; + + // Start actual clustering + final Map newFaceIdToCluster = {}; + final stopwatchClustering = Stopwatch()..start(); + for (int i = 0; i < totalFaces; i++) { + if ((i + 1) % 250 == 0) { + log("[CompleteClustering] ${DateTime.now()} Processed ${i + 1} faces"); + } + if (faceInfos[i].clusterId != null) continue; + int closestIdx = -1; + double closestDistance = double.infinity; + for (int j = 0; j < totalFaces; j++) { + if (i == j) continue; + final double distance = cosineDistanceSIMD( + faceInfos[i].vEmbedding!, + faceInfos[j].vEmbedding!, + ); +
if (distance < closestDistance) { + closestDistance = distance; + closestIdx = j; + } + } + + if (closestDistance < distanceThreshold) { + if (faceInfos[closestIdx].clusterId == null) { + clusterID++; + faceInfos[closestIdx].clusterId = clusterID; + } + faceInfos[i].clusterId = faceInfos[closestIdx].clusterId!; + } else { + clusterID++; + faceInfos[i].clusterId = clusterID; + } + } + + // Now calculate the mean of the embeddings for each cluster + final Map> clusterIdToFaceInfos = {}; + for (final faceInfo in faceInfos) { + if (clusterIdToFaceInfos.containsKey(faceInfo.clusterId)) { + clusterIdToFaceInfos[faceInfo.clusterId]!.add(faceInfo); + } else { + clusterIdToFaceInfos[faceInfo.clusterId!] = [faceInfo]; + } + } + final Map clusterIdToMeanEmbeddingAndWeight = {}; + for (final clusterId in clusterIdToFaceInfos.keys) { + final List embeddings = clusterIdToFaceInfos[clusterId]! + .map((faceInfo) => faceInfo.vEmbedding!) + .toList(); + final count = clusterIdToFaceInfos[clusterId]!.length; + final Vector meanEmbedding = embeddings.reduce((a, b) => a + b) / count; + final Vector meanEmbeddingNormalized = + meanEmbedding / meanEmbedding.norm(); + clusterIdToMeanEmbeddingAndWeight[clusterId] = + (meanEmbeddingNormalized, count); + } + + // Now merge the clusters that are close to each other, based on mean embedding + final List<(int, int)> mergedClustersList = []; + final List clusterIds = + clusterIdToMeanEmbeddingAndWeight.keys.toList(); + log(' [CompleteClustering] ${DateTime.now()} ${clusterIds.length} clusters found, now checking for merges'); + while (true) { + if (clusterIds.length < 2) break; + double distance = double.infinity; + (int, int) clusterIDsToMerge = (-1, -1); + for (int i = 0; i < clusterIds.length; i++) { + for (int j = 0; j < clusterIds.length; j++) { + if (i == j) continue; + final double newDistance = cosineDistanceSIMD( + clusterIdToMeanEmbeddingAndWeight[clusterIds[i]]!.$1, + clusterIdToMeanEmbeddingAndWeight[clusterIds[j]]!.$1, + ); + if (newDistance < distance) { + distance = newDistance; + clusterIDsToMerge = (clusterIds[i], clusterIds[j]); + } + } + } + if (distance < mergeThreshold) { + mergedClustersList.add(clusterIDsToMerge); + final clusterID1 = clusterIDsToMerge.$1; + final clusterID2 = clusterIDsToMerge.$2; + final mean1 = clusterIdToMeanEmbeddingAndWeight[clusterID1]!.$1; + final mean2 = clusterIdToMeanEmbeddingAndWeight[clusterID2]!.$1; + final count1 = clusterIdToMeanEmbeddingAndWeight[clusterID1]!.$2; + final count2 = clusterIdToMeanEmbeddingAndWeight[clusterID2]!.$2; + final weight1 = count1 / (count1 + count2); + final weight2 = count2 / (count1 + count2); + final weightedMean = mean1 * weight1 + mean2 * weight2; + final weightedMeanNormalized = weightedMean / weightedMean.norm(); + clusterIdToMeanEmbeddingAndWeight[clusterID1] = ( + weightedMeanNormalized, + count1 + count2, + ); + clusterIdToMeanEmbeddingAndWeight.remove(clusterID2); + clusterIds.remove(clusterID2); + } else { + break; + } + } + log(' [CompleteClustering] ${DateTime.now()} ${mergedClustersList.length} clusters merged'); + + // Now assign the new clusterId to the faces + for (final faceInfo in faceInfos) { + for (final mergedClusters in mergedClustersList) { + if (faceInfo.clusterId == mergedClusters.$2) { + faceInfo.clusterId = mergedClusters.$1; + } + } + } + + // Finally, assign the new clusterId to the faces + for (final faceInfo in faceInfos) { + newFaceIdToCluster[faceInfo.faceID] = faceInfo.clusterId!; + } + + final Map> clusterIdToFaceIds = {}; + for (final entry in 
newFaceIdToCluster.entries) { + final clusterID = entry.value; + if (clusterIdToFaceIds.containsKey(clusterID)) { + clusterIdToFaceIds[clusterID]!.add(entry.key); + } else { + clusterIdToFaceIds[clusterID] = [entry.key]; + } + } + + final newClusterSummaries = FaceClusteringService.updateClusterSummaries( + oldSummary: {}, + newFaceInfos: faceInfos, + ); + + stopwatchClustering.stop(); + log( + ' [CompleteClustering] ${DateTime.now()} Clustering for ${faceInfos.length} embeddings executed in ${stopwatchClustering.elapsedMilliseconds}ms', + ); + + return ClusteringResult( + newFaceIdToCluster: newFaceIdToCluster, + newClusterSummaries: newClusterSummaries, + newClusterIdToFaceIds: clusterIdToFaceIds, + ); + } + + static List> _runDbscanClustering(Map args) { + final input = args['input'] as Map; + final fileIDToCreationTime = args['fileIDToCreationTime'] as Map?; + final eps = args['eps'] as double; + final minPts = args['minPts'] as int; + + log( + "[ClusterIsolate] ${DateTime.now()} Copied to isolate ${input.length} faces", + ); + + final DBSCAN dbscan = DBSCAN( + epsilon: eps, + minPoints: minPts, + distanceMeasure: cosineDistForNormVectors, + ); + + // Organize everything into a list of FaceInfo objects + final List faceInfos = []; + for (final entry in input.entries) { + faceInfos.add( + FaceInfo( + faceID: entry.key, + embedding: EVector.fromBuffer(entry.value).values, + fileCreationTime: + fileIDToCreationTime?[getFileIdFromFaceId(entry.key)], + ), + ); + } + + // Sort the faceInfos based on fileCreationTime, in ascending order, so oldest faces are first + if (fileIDToCreationTime != null) { + faceInfos.sort((a, b) { + if (a.fileCreationTime == null && b.fileCreationTime == null) { + return 0; + } else if (a.fileCreationTime == null) { + return 1; + } else if (b.fileCreationTime == null) { + return -1; + } else { + return a.fileCreationTime!.compareTo(b.fileCreationTime!); + } + }); + } + + // Get the embeddings + final List> embeddings = + faceInfos.map((faceInfo) => faceInfo.embedding!).toList(); + + // Run the DBSCAN clustering + final List> clusterOutput = dbscan.run(embeddings); + // final List> clusteredFaceInfos = clusterOutput + // .map((cluster) => cluster.map((idx) => faceInfos[idx]).toList()) + // .toList(); + final List> clusteredFaceIDs = clusterOutput + .map((cluster) => cluster.map((idx) => faceInfos[idx].faceID).toList()) + .toList(); + + return clusteredFaceIDs; + } +} diff --git a/mobile/lib/services/machine_learning/face_ml/face_clustering/face_info_for_clustering.dart b/mobile/lib/services/machine_learning/face_ml/face_clustering/face_info_for_clustering.dart new file mode 100644 index 000000000..b2f5c2e9e --- /dev/null +++ b/mobile/lib/services/machine_learning/face_ml/face_clustering/face_info_for_clustering.dart @@ -0,0 +1,25 @@ +import "dart:typed_data" show Uint8List; + +class FaceInfoForClustering { + final String faceID; + int? clusterId; + final Uint8List embeddingBytes; + final double faceScore; + final double blurValue; + final bool isSideways; + int? 
_fileID; + + int get fileID { + _fileID ??= int.parse(faceID.split('_').first); + return _fileID!; + } + + FaceInfoForClustering({ + required this.faceID, + this.clusterId, + required this.embeddingBytes, + required this.faceScore, + required this.blurValue, + this.isSideways = false, + }); +} diff --git a/mobile/lib/services/machine_learning/face_ml/face_detection/detection.dart b/mobile/lib/services/machine_learning/face_ml/face_detection/detection.dart new file mode 100644 index 000000000..de8535c87 --- /dev/null +++ b/mobile/lib/services/machine_learning/face_ml/face_detection/detection.dart @@ -0,0 +1,516 @@ +import 'dart:math' show max, min, pow, sqrt; + +import "package:photos/face/model/dimension.dart"; + +enum FaceDirection { left, right, straight } + +extension FaceDirectionExtension on FaceDirection { + String toDirectionString() { + switch (this) { + case FaceDirection.left: + return 'Left'; + case FaceDirection.right: + return 'Right'; + case FaceDirection.straight: + return 'Straight'; + default: + throw Exception('Unknown FaceDirection'); + } + } +} + +abstract class Detection { + final double score; + + Detection({required this.score}); + + const Detection.empty() : score = 0; + + get width; + get height; + + @override + String toString(); +} + +@Deprecated('Old method only used in other deprecated methods') +extension BBoxExtension on List { + void roundBoxToDouble() { + final widthRounded = (this[2] - this[0]).roundToDouble(); + final heightRounded = (this[3] - this[1]).roundToDouble(); + this[0] = this[0].roundToDouble(); + this[1] = this[1].roundToDouble(); + this[2] = this[0] + widthRounded; + this[3] = this[1] + heightRounded; + } + + // double get xMinBox => + // isNotEmpty ? this[0] : throw IndexError.withLength(0, length); + // double get yMinBox => + // length >= 2 ? this[1] : throw IndexError.withLength(1, length); + // double get xMaxBox => + // length >= 3 ? this[2] : throw IndexError.withLength(2, length); + // double get yMaxBox => + // length >= 4 ? this[3] : throw IndexError.withLength(3, length); +} + +/// This class represents a face detection with relative coordinates in the range [0, 1]. +/// The coordinates are relative to the image size. The pattern for the coordinates is always [x, y], where x is the horizontal coordinate and y is the vertical coordinate. +/// +/// The [score] attribute is a double representing the confidence of the face detection. +/// +/// The [box] attribute is a list of 4 doubles, representing the coordinates of the bounding box of the face detection. +/// The four values of the box in order are: [xMinBox, yMinBox, xMaxBox, yMaxBox]. +/// +/// The [allKeypoints] attribute is a list of 5 lists of 2 doubles, representing the coordinates of the keypoints of the face detection. +/// The five lists of two values in order are: [leftEye, rightEye, nose, leftMouth, rightMouth]. Again, all in [x, y] order.
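To make the relative-coordinate convention concrete before the class definition, a tiny worked sketch (hypothetical numbers) of the multiplication that `toAbsolute` performs further down:

// Sketch only: scaling a relative [xMin, yMin, xMax, yMax] box to pixels.
void main() {
  const box = [0.25, 0.10, 0.75, 0.90];
  const imageWidth = 1920, imageHeight = 1080;
  final absolute = [
    box[0] * imageWidth, // 480.0
    box[1] * imageHeight, // 108.0
    box[2] * imageWidth, // 1440.0
    box[3] * imageHeight, // 972.0
  ];
  print(absolute); // [480.0, 108.0, 1440.0, 972.0]
}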
+class FaceDetectionRelative extends Detection { + final List box; + final List> allKeypoints; + + double get xMinBox => box[0]; + double get yMinBox => box[1]; + double get xMaxBox => box[2]; + double get yMaxBox => box[3]; + + List get leftEye => allKeypoints[0]; + List get rightEye => allKeypoints[1]; + List get nose => allKeypoints[2]; + List get leftMouth => allKeypoints[3]; + List get rightMouth => allKeypoints[4]; + + FaceDetectionRelative({ + required double score, + required List box, + required List> allKeypoints, + }) : assert( + box.every((e) => e >= -0.1 && e <= 1.1), + "Bounding box values must be in the range [0, 1], with only a small margin of error allowed.", + ), + assert( + allKeypoints + .every((sublist) => sublist.every((e) => e >= -0.1 && e <= 1.1)), + "All keypoints must be in the range [0, 1], with only a small margin of error allowed.", + ), + box = List.from(box.map((e) => e.clamp(0.0, 1.0))), + allKeypoints = allKeypoints + .map( + (sublist) => + List.from(sublist.map((e) => e.clamp(0.0, 1.0))), + ) + .toList(), + super(score: score); + + factory FaceDetectionRelative.zero() { + return FaceDetectionRelative( + score: 0, + box: [0, 0, 0, 0], + allKeypoints: >[ + [0, 0], + [0, 0], + [0, 0], + [0, 0], + [0, 0], + ], + ); + } + + /// This is used to initialize the FaceDetectionRelative object with default values. + /// This constructor is useful because it can be used to initialize a FaceDetectionRelative object as a constant. + /// Contrary to the `FaceDetectionRelative.zero()` constructor, this one gives immutable attributes [box] and [allKeypoints]. + FaceDetectionRelative.defaultInitialization() + : box = const [0, 0, 0, 0], + allKeypoints = const >[ + [0, 0], + [0, 0], + [0, 0], + [0, 0], + [0, 0], + ], + super.empty(); + + FaceDetectionRelative getNearestDetection( + List detections, + ) { + if (detections.isEmpty) { + throw ArgumentError("The detection list cannot be empty."); + } + + var nearestDetection = detections[0]; + var minDistance = double.infinity; + + // Calculate the center of the current instance + final centerX1 = (xMinBox + xMaxBox) / 2; + final centerY1 = (yMinBox + yMaxBox) / 2; + + for (var detection in detections) { + final centerX2 = (detection.xMinBox + detection.xMaxBox) / 2; + final centerY2 = (detection.yMinBox + detection.yMaxBox) / 2; + final distance = + sqrt(pow(centerX2 - centerX1, 2) + pow(centerY2 - centerY1, 2)); + if (distance < minDistance) { + minDistance = distance; + nearestDetection = detection; + } + } + return nearestDetection; + } + + void transformRelativeToOriginalImage( + List fromBox, // [xMin, yMin, xMax, yMax] + List toBox, // [xMin, yMin, xMax, yMax] + ) { + // Return if all elements of fromBox and toBox are equal + for (int i = 0; i < fromBox.length; i++) { + if (fromBox[i] != toBox[i]) { + break; + } + if (i == fromBox.length - 1) { + return; + } + } + + // Account for padding + final double paddingXRatio = + (fromBox[0] - toBox[0]) / (toBox[2] - toBox[0]); + final double paddingYRatio = + (fromBox[1] - toBox[1]) / (toBox[3] - toBox[1]); + + // Calculate the scaling and translation + final double scaleX = (fromBox[2] - fromBox[0]) / (1 - 2 * paddingXRatio); + final double scaleY = (fromBox[3] - fromBox[1]) / (1 - 2 * paddingYRatio); + final double translateX = fromBox[0] - paddingXRatio * scaleX; + final double translateY = fromBox[1] - paddingYRatio * scaleY; + + // Transform Box + _transformBox(box, scaleX, scaleY, translateX, translateY); + + // Transform All Keypoints + for (int i = 0; i < 
allKeypoints.length; i++) { + allKeypoints[i] = _transformPoint( + allKeypoints[i], + scaleX, + scaleY, + translateX, + translateY, + ); + } + } + + void correctForMaintainedAspectRatio( + Dimensions originalSize, + Dimensions newSize, + ) { + // Return if both are the same size, meaning no scaling was done on both width and height + if (originalSize == newSize) { + return; + } + + // Calculate the scaling + final double scaleX = originalSize.width / newSize.width; + final double scaleY = originalSize.height / newSize.height; + const double translateX = 0; + const double translateY = 0; + + // Transform Box + _transformBox(box, scaleX, scaleY, translateX, translateY); + + // Transform All Keypoints + for (int i = 0; i < allKeypoints.length; i++) { + allKeypoints[i] = _transformPoint( + allKeypoints[i], + scaleX, + scaleY, + translateX, + translateY, + ); + } + } + + void _transformBox( + List box, + double scaleX, + double scaleY, + double translateX, + double translateY, + ) { + box[0] = (box[0] * scaleX + translateX).clamp(0.0, 1.0); + box[1] = (box[1] * scaleY + translateY).clamp(0.0, 1.0); + box[2] = (box[2] * scaleX + translateX).clamp(0.0, 1.0); + box[3] = (box[3] * scaleY + translateY).clamp(0.0, 1.0); + } + + List _transformPoint( + List point, + double scaleX, + double scaleY, + double translateX, + double translateY, + ) { + return [ + (point[0] * scaleX + translateX).clamp(0.0, 1.0), + (point[1] * scaleY + translateY).clamp(0.0, 1.0), + ]; + } + + FaceDetectionAbsolute toAbsolute({ + required int imageWidth, + required int imageHeight, + }) { + final scoreCopy = score; + final boxCopy = List.from(box, growable: false); + final allKeypointsCopy = allKeypoints + .map((sublist) => List.from(sublist, growable: false)) + .toList(); + + boxCopy[0] *= imageWidth; + boxCopy[1] *= imageHeight; + boxCopy[2] *= imageWidth; + boxCopy[3] *= imageHeight; + // final intbox = boxCopy.map((e) => e.toInt()).toList(); + + for (List keypoint in allKeypointsCopy) { + keypoint[0] *= imageWidth; + keypoint[1] *= imageHeight; + } + // final intKeypoints = + // allKeypointsCopy.map((e) => e.map((e) => e.toInt()).toList()).toList(); + return FaceDetectionAbsolute( + score: scoreCopy, + box: boxCopy, + allKeypoints: allKeypointsCopy, + ); + } + + String toFaceID({required int fileID}) { + // Assert that the values are within the expected range + assert( + (xMinBox >= 0 && xMinBox <= 1) && + (yMinBox >= 0 && yMinBox <= 1) && + (xMaxBox >= 0 && xMaxBox <= 1) && + (yMaxBox >= 0 && yMaxBox <= 1), + "Bounding box values must be in the range [0, 1]", + ); + + // Extract bounding box values + final String xMin = + xMinBox.clamp(0.0, 0.999999).toStringAsFixed(5).substring(2); + final String yMin = + yMinBox.clamp(0.0, 0.999999).toStringAsFixed(5).substring(2); + final String xMax = + xMaxBox.clamp(0.0, 0.999999).toStringAsFixed(5).substring(2); + final String yMax = + yMaxBox.clamp(0.0, 0.999999).toStringAsFixed(5).substring(2); + + // Convert the bounding box values to string and concatenate + final String rawID = "${xMin}_${yMin}_${xMax}_$yMax"; + + final faceID = fileID.toString() + '_' + rawID.toString(); + + // Return the hexadecimal representation of the hash + return faceID; + } + + /// This method is used to generate a faceID for a face detection that was manually added by the user. + static String toFaceIDEmpty({required int fileID}) { + return fileID.toString() + '_0'; + } + + /// This method is used to check if a faceID corresponds to a manually added face detection and not an actual face detection. 
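A worked sketch of the faceID layout that `toFaceID` produces and that the clustering code's `getFileIdFromFaceId` / `FaceInfoForClustering.fileID` parse back (all values hypothetical):

// Sketch only: the '<fileID>_<xMin>_<yMin>_<xMax>_<yMax>' layout, with each
// coordinate truncated to five decimals and the leading '0.' prefix dropped.
void main() {
  const fileID = 1234;
  const xMin = 0.075, yMin = 0.2, xMax = 0.3, yMax = 0.45;
  String enc(double v) =>
      v.clamp(0.0, 0.999999).toStringAsFixed(5).substring(2);
  final faceID =
      '${fileID}_${enc(xMin)}_${enc(yMin)}_${enc(xMax)}_${enc(yMax)}';
  print(faceID); // 1234_07500_20000_30000_45000
  print(int.parse(faceID.split('_').first)); // 1234
  print('${fileID}_0'.split('_')[1] == '0'); // true: the manually-added marker
}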
+ static bool isFaceIDEmpty(String faceID) { + return faceID.split('_')[1] == '0'; + } + + @override + String toString() { + return 'FaceDetectionRelative( with relative coordinates: \n score: $score \n Box: xMinBox: $xMinBox, yMinBox: $yMinBox, xMaxBox: $xMaxBox, yMaxBox: $yMaxBox, \n Keypoints: leftEye: $leftEye, rightEye: $rightEye, nose: $nose, leftMouth: $leftMouth, rightMouth: $rightMouth \n )'; + } + + Map toJson() { + return { + 'score': score, + 'box': box, + 'allKeypoints': allKeypoints, + }; + } + + factory FaceDetectionRelative.fromJson(Map json) { + return FaceDetectionRelative( + score: (json['score'] as num).toDouble(), + box: List.from(json['box']), + allKeypoints: (json['allKeypoints'] as List) + .map((item) => List.from(item)) + .toList(), + ); + } + + @override + + /// The width of the bounding box of the face detection, in relative range [0, 1]. + double get width => xMaxBox - xMinBox; + @override + + /// The height of the bounding box of the face detection, in relative range [0, 1]. + double get height => yMaxBox - yMinBox; +} + +/// This class represents a face detection with absolute coordinates in pixels, in the range [0, imageWidth] for the horizontal coordinates and [0, imageHeight] for the vertical coordinates. +/// The pattern for the coordinates is always [x, y], where x is the horizontal coordinate and y is the vertical coordinate. +/// +/// The [score] attribute is a double representing the confidence of the face detection. +/// +/// The [box] attribute is a list of 4 doubles, representing the coordinates of the bounding box of the face detection in pixels. +/// The four values of the box in order are: [xMinBox, yMinBox, xMaxBox, yMaxBox]. +/// +/// The [allKeypoints] attribute is a list of 5 lists of 2 doubles, representing the coordinates of the keypoints of the face detection in pixels. +/// The five lists of two values in order are: [leftEye, rightEye, nose, leftMouth, rightMouth]. Again, all in [x, y] order.
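Since both detection classes carry the same toJson/fromJson shape, persisting them is a plain JSON round trip. A hedged sketch (the jsonEncode wrapping is an assumption, not something this diff does):

import 'dart:convert' show jsonDecode, jsonEncode;

// Sketch only: round-tripping the detection JSON shape from above.
void main() {
  final detection = {
    'score': 0.93,
    'box': [0.25, 0.10, 0.75, 0.90],
    'allKeypoints': [
      [0.4, 0.3], [0.6, 0.3], [0.5, 0.5], [0.42, 0.7], [0.58, 0.7],
    ],
  };
  final decoded = jsonDecode(jsonEncode(detection)) as Map<String, dynamic>;
  print((decoded['score'] as num).toDouble()); // 0.93
}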
+class FaceDetectionAbsolute extends Detection { + final List box; + final List> allKeypoints; + + double get xMinBox => box[0]; + double get yMinBox => box[1]; + double get xMaxBox => box[2]; + double get yMaxBox => box[3]; + + List get leftEye => allKeypoints[0]; + List get rightEye => allKeypoints[1]; + List get nose => allKeypoints[2]; + List get leftMouth => allKeypoints[3]; + List get rightMouth => allKeypoints[4]; + + FaceDetectionAbsolute({ + required double score, + required this.box, + required this.allKeypoints, + }) : super(score: score); + + factory FaceDetectionAbsolute._zero() { + return FaceDetectionAbsolute( + score: 0, + box: [0, 0, 0, 0], + allKeypoints: >[ + [0, 0], + [0, 0], + [0, 0], + [0, 0], + [0, 0], + ], + ); + } + + FaceDetectionAbsolute.defaultInitialization() + : box = const [0, 0, 0, 0], + allKeypoints = const >[ + [0, 0], + [0, 0], + [0, 0], + [0, 0], + [0, 0], + ], + super.empty(); + + @override + String toString() { + return 'FaceDetectionAbsolute( with absolute coordinates: \n score: $score \n Box: xMinBox: $xMinBox, yMinBox: $yMinBox, xMaxBox: $xMaxBox, yMaxBox: $yMaxBox, \n Keypoints: leftEye: $leftEye, rightEye: $rightEye, nose: $nose, leftMouth: $leftMouth, rightMouth: $rightMouth \n )'; + } + + Map toJson() { + return { + 'score': score, + 'box': box, + 'allKeypoints': allKeypoints, + }; + } + + factory FaceDetectionAbsolute.fromJson(Map json) { + return FaceDetectionAbsolute( + score: (json['score'] as num).toDouble(), + box: List.from(json['box']), + allKeypoints: (json['allKeypoints'] as List) + .map((item) => List.from(item)) + .toList(), + ); + } + + static FaceDetectionAbsolute empty = FaceDetectionAbsolute._zero(); + + @override + + /// The width of the bounding box of the face detection, in number of pixels, range [0, imageWidth]. + double get width => xMaxBox - xMinBox; + @override + + /// The height of the bounding box of the face detection, in number of pixels, range [0, imageHeight]. 
+ double get height => yMaxBox - yMinBox; + + FaceDirection getFaceDirection() { + final double eyeDistanceX = (rightEye[0] - leftEye[0]).abs(); + final double eyeDistanceY = (rightEye[1] - leftEye[1]).abs(); + final double mouthDistanceY = (rightMouth[1] - leftMouth[1]).abs(); + + final bool faceIsUpright = + (max(leftEye[1], rightEye[1]) + 0.5 * eyeDistanceY < nose[1]) && + (nose[1] + 0.5 * mouthDistanceY < min(leftMouth[1], rightMouth[1])); + + final bool noseStickingOutLeft = (nose[0] < min(leftEye[0], rightEye[0])) && + (nose[0] < min(leftMouth[0], rightMouth[0])); + final bool noseStickingOutRight = + (nose[0] > max(leftEye[0], rightEye[0])) && + (nose[0] > max(leftMouth[0], rightMouth[0])); + + final bool noseCloseToLeftEye = + (nose[0] - leftEye[0]).abs() < 0.2 * eyeDistanceX; + final bool noseCloseToRightEye = + (nose[0] - rightEye[0]).abs() < 0.2 * eyeDistanceX; + + // if (faceIsUpright && (noseStickingOutLeft || noseCloseToLeftEye)) { + if (noseStickingOutLeft || (faceIsUpright && noseCloseToLeftEye)) { + return FaceDirection.left; + // } else if (faceIsUpright && (noseStickingOutRight || noseCloseToRightEye)) { + } else if (noseStickingOutRight || (faceIsUpright && noseCloseToRightEye)) { + return FaceDirection.right; + } + + return FaceDirection.straight; + } +} + +List relativeToAbsoluteDetections({ + required List relativeDetections, + required int imageWidth, + required int imageHeight, +}) { + final numberOfDetections = relativeDetections.length; + final absoluteDetections = List.filled( + numberOfDetections, + FaceDetectionAbsolute._zero(), + ); + for (var i = 0; i < relativeDetections.length; i++) { + final relativeDetection = relativeDetections[i]; + final absoluteDetection = relativeDetection.toAbsolute( + imageWidth: imageWidth, + imageHeight: imageHeight, + ); + + absoluteDetections[i] = absoluteDetection; + } + + return absoluteDetections; +} + +/// Returns an enlarged version of the [box] by a factor of [factor]. +List getEnlargedRelativeBox(List box, [double factor = 2]) { + final boxCopy = List.from(box, growable: false); + // The four values of the box in order are: [xMinBox, yMinBox, xMaxBox, yMaxBox]. 
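// Editor's worked example (hypothetical values): with box = [0.4, 0.4, 0.6, 0.6]
// and factor = 2, width = height = 0.2, so each edge below moves outward by
// width * (factor - 1) / 2 = 0.1, giving [0.3, 0.3, 0.7, 0.7].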
+ + final width = boxCopy[2] - boxCopy[0]; + final height = boxCopy[3] - boxCopy[1]; + + boxCopy[0] -= width * (factor - 1) / 2; + boxCopy[1] -= height * (factor - 1) / 2; + boxCopy[2] += width * (factor - 1) / 2; + boxCopy[3] += height * (factor - 1) / 2; + + return boxCopy; +} diff --git a/mobile/lib/services/machine_learning/face_ml/face_detection/face_detection_exceptions.dart b/mobile/lib/services/machine_learning/face_ml/face_detection/face_detection_exceptions.dart new file mode 100644 index 000000000..ed2f97791 --- /dev/null +++ b/mobile/lib/services/machine_learning/face_ml/face_detection/face_detection_exceptions.dart @@ -0,0 +1,3 @@ +class YOLOFaceInterpreterInitializationException implements Exception {} + +class YOLOFaceInterpreterRunException implements Exception {} diff --git a/mobile/lib/services/machine_learning/face_ml/face_detection/face_detection_service.dart b/mobile/lib/services/machine_learning/face_ml/face_detection/face_detection_service.dart new file mode 100644 index 000000000..443df50f2 --- /dev/null +++ b/mobile/lib/services/machine_learning/face_ml/face_detection/face_detection_service.dart @@ -0,0 +1,788 @@ +import "dart:async"; +import "dart:developer" as dev show log; +import "dart:io" show File; +import "dart:isolate"; +import 'dart:typed_data' show ByteData, Float32List, Uint8List; +import 'dart:ui' as ui show Image; + +import "package:computer/computer.dart"; +import 'package:flutter/material.dart'; +import 'package:logging/logging.dart'; +import 'package:onnxruntime/onnxruntime.dart'; +import "package:photos/face/model/dimension.dart"; +import 'package:photos/services/machine_learning/face_ml/face_detection/detection.dart'; +import 'package:photos/services/machine_learning/face_ml/face_detection/face_detection_exceptions.dart'; +import 'package:photos/services/machine_learning/face_ml/face_detection/naive_non_max_suppression.dart'; +import 'package:photos/services/machine_learning/face_ml/face_detection/yolo_filter_extract_detections.dart'; +import "package:photos/services/remote_assets_service.dart"; +import "package:photos/utils/image_ml_isolate.dart"; +import "package:photos/utils/image_ml_util.dart"; +import "package:synchronized/synchronized.dart"; + +enum FaceDetectionOperation { yoloInferenceAndPostProcessing } + +/// This class is responsible for running the face detection model (YOLOv5Face) on ONNX runtime, and can be accessed through the singleton instance [FaceDetectionService.instance]. +class FaceDetectionService { + static final _logger = Logger('YOLOFaceDetectionService'); + + final _computer = Computer.shared(); + + int sessionAddress = 0; + + static const String kModelBucketEndpoint = "https://models.ente.io/"; + static const String kRemoteBucketModelPath = + "yolov5s_face_640_640_dynamic.onnx"; + // static const kRemoteBucketModelPath = "yolov5n_face_640_640.onnx"; + static const String modelRemotePath = + kModelBucketEndpoint + kRemoteBucketModelPath; + + static const int kInputWidth = 640; + static const int kInputHeight = 640; + static const double kIouThreshold = 0.4; + static const double kMinScoreSigmoidThreshold = 0.7; + static const int kNumKeypoints = 5; + + bool isInitialized = false; + + // Isolate things + Timer? 
_inactivityTimer; + final Duration _inactivityDuration = const Duration(seconds: 30); + + final _initLock = Lock(); + final _computerLock = Lock(); + + late Isolate _isolate; + late ReceivePort _receivePort = ReceivePort(); + late SendPort _mainSendPort; + + bool isSpawned = false; + bool isRunning = false; + + // singleton pattern + FaceDetectionService._privateConstructor(); + + /// Use this instance to access the FaceDetection service. Make sure to call `init()` before using it. + /// e.g. `await FaceDetection.instance.init();` + /// + /// Then you can use `predict()` to get the bounding boxes of the faces, so `FaceDetection.instance.predict(imageData)` + /// + /// config options: yoloV5FaceN // + static final instance = FaceDetectionService._privateConstructor(); + + factory FaceDetectionService() => instance; + + /// Check if the interpreter is initialized, if not initialize it with `loadModel()` + Future init() async { + if (!isInitialized) { + _logger.info('init is called'); + final model = + await RemoteAssetsService.instance.getAsset(modelRemotePath); + final startTime = DateTime.now(); + // Doing this from main isolate since `rootBundle` cannot be accessed outside it + sessionAddress = await _computer.compute( + _loadModel, + param: { + "modelPath": model.path, + }, + ); + final endTime = DateTime.now(); + _logger.info( + "Face detection model loaded, took: ${(endTime.millisecondsSinceEpoch - startTime.millisecondsSinceEpoch).toString()}ms", + ); + if (sessionAddress != -1) { + isInitialized = true; + } + } + } + + Future release() async { + if (isInitialized) { + await _computer + .compute(_releaseModel, param: {'address': sessionAddress}); + isInitialized = false; + sessionAddress = 0; + } + } + + Future initIsolate() async { + return _initLock.synchronized(() async { + if (isSpawned) return; + + _receivePort = ReceivePort(); + + try { + _isolate = await Isolate.spawn( + _isolateMain, + _receivePort.sendPort, + ); + _mainSendPort = await _receivePort.first as SendPort; + isSpawned = true; + + _resetInactivityTimer(); + } catch (e) { + _logger.severe('Could not spawn isolate', e); + isSpawned = false; + } + }); + } + + Future ensureSpawnedIsolate() async { + if (!isSpawned) { + await initIsolate(); + } + } + + /// The main execution function of the isolate. 
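Before the implementation, a minimal standalone sketch of the SendPort handshake that `initIsolate` and `_isolateMain` rely on (illustrative only, not this diff's code): the worker sends its own SendPort back first, then serves request/reply pairs, with each request carrying the port to answer on.

import 'dart:isolate';

// Sketch only: the handshake pattern used by the service's isolate code.
Future<void> main() async {
  final handshake = ReceivePort();
  await Isolate.spawn(_worker, handshake.sendPort);
  final workerPort = await handshake.first as SendPort;

  final answerPort = ReceivePort();
  workerPort.send(['ping', answerPort.sendPort]);
  print(await answerPort.first); // pong: ping
}

void _worker(SendPort mainSendPort) {
  final requests = ReceivePort();
  mainSendPort.send(requests.sendPort); // step 1: hand our SendPort back
  requests.listen((message) {
    final payload = message[0] as String;
    final replyTo = message[1] as SendPort; // step 2: reply on the given port
    replyTo.send('pong: $payload');
  });
}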
+ static void _isolateMain(SendPort mainSendPort) async { + final receivePort = ReceivePort(); + mainSendPort.send(receivePort.sendPort); + + receivePort.listen((message) async { + final functionIndex = message[0] as int; + final function = FaceDetectionOperation.values[functionIndex]; + final args = message[1] as Map; + final sendPort = message[2] as SendPort; + + try { + switch (function) { + case FaceDetectionOperation.yoloInferenceAndPostProcessing: + final inputImageList = args['inputImageList'] as Float32List; + final inputShape = args['inputShape'] as List; + final newSize = args['newSize'] as Dimensions; + final sessionAddress = args['sessionAddress'] as int; + final timeSentToIsolate = args['timeNow'] as DateTime; + final delaySentToIsolate = + DateTime.now().difference(timeSentToIsolate).inMilliseconds; + + final Stopwatch stopwatchPrepare = Stopwatch()..start(); + final inputOrt = OrtValueTensor.createTensorWithDataList( + inputImageList, + inputShape, + ); + final inputs = {'input': inputOrt}; + stopwatchPrepare.stop(); + dev.log( + '[YOLOFaceDetectionService] data preparation is finished, in ${stopwatchPrepare.elapsedMilliseconds}ms', + ); + + stopwatchPrepare.reset(); + stopwatchPrepare.start(); + final runOptions = OrtRunOptions(); + final session = OrtSession.fromAddress(sessionAddress); + stopwatchPrepare.stop(); + dev.log( + '[YOLOFaceDetectionService] session preparation is finished, in ${stopwatchPrepare.elapsedMilliseconds}ms', + ); + + final stopwatchInterpreter = Stopwatch()..start(); + late final List outputs; + try { + outputs = session.run(runOptions, inputs); + } catch (e, s) { + dev.log( + '[YOLOFaceDetectionService] Error while running inference: $e \n $s', + ); + throw YOLOFaceInterpreterRunException(); + } + stopwatchInterpreter.stop(); + dev.log( + '[YOLOFaceDetectionService] interpreter.run is finished, in ${stopwatchInterpreter.elapsedMilliseconds} ms', + ); + + final relativeDetections = + _yoloPostProcessOutputs(outputs, newSize); + + sendPort + .send((relativeDetections, delaySentToIsolate, DateTime.now())); + break; + } + } catch (e, stackTrace) { + sendPort + .send({'error': e.toString(), 'stackTrace': stackTrace.toString()}); + } + }); + } + + /// The common method to run any operation in the isolate. It sends the [message] to [_isolateMain] and waits for the result. + Future _runInIsolate( + (FaceDetectionOperation, Map) message, + ) async { + await ensureSpawnedIsolate(); + _resetInactivityTimer(); + final completer = Completer(); + final answerPort = ReceivePort(); + + _mainSendPort.send([message.$1.index, message.$2, answerPort.sendPort]); + + answerPort.listen((receivedMessage) { + if (receivedMessage is Map && receivedMessage.containsKey('error')) { + // Handle the error + final errorMessage = receivedMessage['error']; + final errorStackTrace = receivedMessage['stackTrace']; + final exception = Exception(errorMessage); + final stackTrace = StackTrace.fromString(errorStackTrace); + completer.completeError(exception, stackTrace); + } else { + completer.complete(receivedMessage); + } + }); + + return completer.future; + } + + /// Resets a timer that kills the isolate after a certain amount of inactivity. + /// + /// Should be called after initialization (e.g. inside `init()`) and after every call to isolate (e.g. 
inside `_runInIsolate()`) + void _resetInactivityTimer() { + _inactivityTimer?.cancel(); + _inactivityTimer = Timer(_inactivityDuration, () { + _logger.info( + 'Face detection (YOLO ONNX) Isolate has been inactive for ${_inactivityDuration.inSeconds} seconds. Killing isolate.', + ); + disposeIsolate(); + }); + } + + /// Disposes the isolate worker. + void disposeIsolate() { + if (!isSpawned) return; + + isSpawned = false; + _isolate.kill(); + _receivePort.close(); + _inactivityTimer?.cancel(); + } + + /// Detects faces in the given image data. + Future<(List, Dimensions)> predict( + Uint8List imageData, + ) async { + assert(isInitialized); + + final stopwatch = Stopwatch()..start(); + + final stopwatchDecoding = Stopwatch()..start(); + final (inputImageList, originalSize, newSize) = + await ImageMlIsolate.instance.preprocessImageYoloOnnx( + imageData, + normalize: true, + requiredWidth: kInputWidth, + requiredHeight: kInputHeight, + maintainAspectRatio: true, + quality: FilterQuality.medium, + ); + + // final input = [inputImageList]; + final inputShape = [ + 1, + 3, + kInputHeight, + kInputWidth, + ]; + final inputOrt = OrtValueTensor.createTensorWithDataList( + inputImageList, + inputShape, + ); + final inputs = {'input': inputOrt}; + stopwatchDecoding.stop(); + _logger.info( + 'Image decoding and preprocessing is finished, in ${stopwatchDecoding.elapsedMilliseconds}ms', + ); + _logger.info('original size: $originalSize \n new size: $newSize'); + + // Run inference + final stopwatchInterpreter = Stopwatch()..start(); + List? outputs; + try { + final runOptions = OrtRunOptions(); + final session = OrtSession.fromAddress(sessionAddress); + outputs = session.run(runOptions, inputs); + // inputOrt.release(); + // runOptions.release(); + } catch (e, s) { + _logger.severe('Error while running inference: $e \n $s'); + throw YOLOFaceInterpreterRunException(); + } + stopwatchInterpreter.stop(); + _logger.info( + 'interpreter.run is finished, in ${stopwatchInterpreter.elapsedMilliseconds} ms', + ); + + final relativeDetections = _yoloPostProcessOutputs(outputs, newSize); + + stopwatch.stop(); + _logger.info( + 'predict() face detection executed in ${stopwatch.elapsedMilliseconds}ms', + ); + + return (relativeDetections, originalSize); + } + + /// Detects faces in the given image data. + static Future<(List, Dimensions)> predictSync( + ui.Image image, + ByteData imageByteData, + int sessionAddress, + ) async { + assert(sessionAddress != 0 && sessionAddress != -1); + + final stopwatch = Stopwatch()..start(); + + final stopwatchPreprocessing = Stopwatch()..start(); + final (inputImageList, originalSize, newSize) = + await preprocessImageToFloat32ChannelsFirst( + image, + imageByteData, + normalization: 1, + requiredWidth: kInputWidth, + requiredHeight: kInputHeight, + maintainAspectRatio: true, + ); + + // final input = [inputImageList]; + final inputShape = [ + 1, + 3, + kInputHeight, + kInputWidth, + ]; + final inputOrt = OrtValueTensor.createTensorWithDataList( + inputImageList, + inputShape, + ); + final inputs = {'input': inputOrt}; + stopwatchPreprocessing.stop(); + dev.log( + 'Face detection image preprocessing is finished, in ${stopwatchPreprocessing.elapsedMilliseconds}ms', + ); + _logger.info( + 'Image decoding and preprocessing is finished, in ${stopwatchPreprocessing.elapsedMilliseconds}ms', + ); + _logger.info('original size: $originalSize \n new size: $newSize'); + + // Run inference + final stopwatchInterpreter = Stopwatch()..start(); + List? 
outputs; + try { + final runOptions = OrtRunOptions(); + final session = OrtSession.fromAddress(sessionAddress); + outputs = session.run(runOptions, inputs); + // inputOrt.release(); + // runOptions.release(); + } catch (e, s) { + _logger.severe('Error while running inference: $e \n $s'); + throw YOLOFaceInterpreterRunException(); + } + stopwatchInterpreter.stop(); + _logger.info( + 'interpreter.run is finished, in ${stopwatchInterpreter.elapsedMilliseconds} ms', + ); + + final relativeDetections = _yoloPostProcessOutputs(outputs, newSize); + + stopwatch.stop(); + _logger.info( + 'predict() face detection executed in ${stopwatch.elapsedMilliseconds}ms', + ); + + return (relativeDetections, originalSize); + } + + /// Detects faces in the given image data. + Future<(List, Dimensions)> predictInIsolate( + Uint8List imageData, + ) async { + await ensureSpawnedIsolate(); + assert(isInitialized); + + _logger.info('predictInIsolate() is called'); + + final stopwatch = Stopwatch()..start(); + + final stopwatchDecoding = Stopwatch()..start(); + final (inputImageList, originalSize, newSize) = + await ImageMlIsolate.instance.preprocessImageYoloOnnx( + imageData, + normalize: true, + requiredWidth: kInputWidth, + requiredHeight: kInputHeight, + maintainAspectRatio: true, + quality: FilterQuality.medium, + ); + // final input = [inputImageList]; + final inputShape = [ + 1, + 3, + kInputHeight, + kInputWidth, + ]; + + stopwatchDecoding.stop(); + _logger.info( + 'Image decoding and preprocessing is finished, in ${stopwatchDecoding.elapsedMilliseconds}ms', + ); + _logger.info('original size: $originalSize \n new size: $newSize'); + + final ( + List relativeDetections, + delaySentToIsolate, + timeSentToMain + ) = await _runInIsolate( + ( + FaceDetectionOperation.yoloInferenceAndPostProcessing, + { + 'inputImageList': inputImageList, + 'inputShape': inputShape, + 'newSize': newSize, + 'sessionAddress': sessionAddress, + 'timeNow': DateTime.now(), + } + ), + ) as (List, int, DateTime); + + final delaySentToMain = + DateTime.now().difference(timeSentToMain).inMilliseconds; + + stopwatch.stop(); + _logger.info( + 'predictInIsolate() face detection executed in ${stopwatch.elapsedMilliseconds}ms, with ${delaySentToIsolate}ms delay sent to isolate, and ${delaySentToMain}ms delay sent to main, for a total of ${delaySentToIsolate + delaySentToMain}ms delay due to isolate', + ); + + return (relativeDetections, originalSize); + } + + Future<(List, Dimensions)> predictInComputer( + String imagePath, + ) async { + assert(isInitialized); + + _logger.info('predictInComputer() is called'); + + final stopwatch = Stopwatch()..start(); + + final stopwatchDecoding = Stopwatch()..start(); + final imageData = await File(imagePath).readAsBytes(); + final (inputImageList, originalSize, newSize) = + await ImageMlIsolate.instance.preprocessImageYoloOnnx( + imageData, + normalize: true, + requiredWidth: kInputWidth, + requiredHeight: kInputHeight, + maintainAspectRatio: true, + quality: FilterQuality.medium, + ); + // final input = [inputImageList]; + return await _computerLock.synchronized(() async { + final inputShape = [ + 1, + 3, + kInputHeight, + kInputWidth, + ]; + + stopwatchDecoding.stop(); + _logger.info( + 'Image decoding and preprocessing is finished, in ${stopwatchDecoding.elapsedMilliseconds}ms', + ); + _logger.info('original size: $originalSize \n new size: $newSize'); + + final ( + List relativeDetections, + delaySentToIsolate, + timeSentToMain + ) = await _computer.compute( + inferenceAndPostProcess, + param: { + 
'inputImageList': inputImageList, + 'inputShape': inputShape, + 'newSize': newSize, + 'sessionAddress': sessionAddress, + 'timeNow': DateTime.now(), + }, + ) as (List, int, DateTime); + + final delaySentToMain = + DateTime.now().difference(timeSentToMain).inMilliseconds; + + stopwatch.stop(); + _logger.info( + 'predictInIsolate() face detection executed in ${stopwatch.elapsedMilliseconds}ms, with ${delaySentToIsolate}ms delay sent to isolate, and ${delaySentToMain}ms delay sent to main, for a total of ${delaySentToIsolate + delaySentToMain}ms delay due to isolate', + ); + + return (relativeDetections, originalSize); + }); + } + + /// Detects faces in the given image data. + /// This method is optimized for batch processing. + /// + /// `imageDataList`: The image data to analyze. + /// + /// WARNING: Currently this method only returns the detections for the first image in the batch. + /// Change the function to output all detection before actually using it in production. + Future> predictBatch( + List imageDataList, + ) async { + assert(isInitialized); + + final stopwatch = Stopwatch()..start(); + + final stopwatchDecoding = Stopwatch()..start(); + final List inputImageDataLists = []; + final List<(Dimensions, Dimensions)> originalAndNewSizeList = []; + int concatenatedImageInputsLength = 0; + for (final imageData in imageDataList) { + final (inputImageList, originalSize, newSize) = + await ImageMlIsolate.instance.preprocessImageYoloOnnx( + imageData, + normalize: true, + requiredWidth: kInputWidth, + requiredHeight: kInputHeight, + maintainAspectRatio: true, + quality: FilterQuality.medium, + ); + inputImageDataLists.add(inputImageList); + originalAndNewSizeList.add((originalSize, newSize)); + concatenatedImageInputsLength += inputImageList.length; + } + + final inputImageList = Float32List(concatenatedImageInputsLength); + + int offset = 0; + for (int i = 0; i < inputImageDataLists.length; i++) { + final inputImageData = inputImageDataLists[i]; + inputImageList.setRange( + offset, + offset + inputImageData.length, + inputImageData, + ); + offset += inputImageData.length; + } + + // final input = [inputImageList]; + final inputShape = [ + inputImageDataLists.length, + 3, + kInputHeight, + kInputWidth, + ]; + final inputOrt = OrtValueTensor.createTensorWithDataList( + inputImageList, + inputShape, + ); + final inputs = {'input': inputOrt}; + stopwatchDecoding.stop(); + _logger.info( + 'Image decoding and preprocessing is finished, in ${stopwatchDecoding.elapsedMilliseconds}ms', + ); + // _logger.info('original size: $originalSize \n new size: $newSize'); + + _logger.info('interpreter.run is called'); + // Run inference + final stopwatchInterpreter = Stopwatch()..start(); + List? 
outputs; + try { + final runOptions = OrtRunOptions(); + final session = OrtSession.fromAddress(sessionAddress); + outputs = session.run(runOptions, inputs); + inputOrt.release(); + runOptions.release(); + } catch (e, s) { + _logger.severe('Error while running inference: $e \n $s'); + throw YOLOFaceInterpreterRunException(); + } + stopwatchInterpreter.stop(); + _logger.info( + 'interpreter.run is finished, in ${stopwatchInterpreter.elapsedMilliseconds} ms, or ${stopwatchInterpreter.elapsedMilliseconds / inputImageDataLists.length} ms per image', + ); + + _logger.info('outputs: $outputs'); + + const int imageOutputToUse = 0; + + // // Get output tensors + final nestedResults = + outputs[0]?.value as List>>; // [b, 25200, 16] + final selectedResults = nestedResults[imageOutputToUse]; // [25200, 16] + + // final rawScores = []; + // for (final result in firstResults) { + // rawScores.add(result[4]); + // } + // final rawScoresCopy = List.from(rawScores); + // rawScoresCopy.sort(); + // _logger.info('rawScores minimum: ${rawScoresCopy.first}'); + // _logger.info('rawScores maximum: ${rawScoresCopy.last}'); + + var relativeDetections = yoloOnnxFilterExtractDetections( + kMinScoreSigmoidThreshold, + kInputWidth, + kInputHeight, + results: selectedResults, + ); + + // Release outputs + for (var element in outputs) { + element?.release(); + } + + // Account for the fact that the aspect ratio was maintained + for (final faceDetection in relativeDetections) { + faceDetection.correctForMaintainedAspectRatio( + const Dimensions( + width: kInputWidth, + height: kInputHeight, + ), + originalAndNewSizeList[imageOutputToUse].$2, + ); + } + + // Non-maximum suppression to remove duplicate detections + relativeDetections = naiveNonMaxSuppression( + detections: relativeDetections, + iouThreshold: kIouThreshold, + ); + + if (relativeDetections.isEmpty) { + _logger.info('No face detected'); + return []; + } + + stopwatch.stop(); + _logger.info( + 'predict() face detection executed in ${stopwatch.elapsedMilliseconds}ms', + ); + + return relativeDetections; + } + + static List _yoloPostProcessOutputs( + List? outputs, + Dimensions newSize, + ) { + // // Get output tensors + final nestedResults = + outputs?[0]?.value as List>>; // [1, 25200, 16] + final firstResults = nestedResults[0]; // [25200, 16] + + // final rawScores = []; + // for (final result in firstResults) { + // rawScores.add(result[4]); + // } + // final rawScoresCopy = List.from(rawScores); + // rawScoresCopy.sort(); + // _logger.info('rawScores minimum: ${rawScoresCopy.first}'); + // _logger.info('rawScores maximum: ${rawScoresCopy.last}'); + + var relativeDetections = yoloOnnxFilterExtractDetections( + kMinScoreSigmoidThreshold, + kInputWidth, + kInputHeight, + results: firstResults, + ); + + // Release outputs + // outputs?.forEach((element) { + // element?.release(); + // }); + + // Account for the fact that the aspect ratio was maintained + for (final faceDetection in relativeDetections) { + faceDetection.correctForMaintainedAspectRatio( + const Dimensions( + width: kInputWidth, + height: kInputHeight, + ), + newSize, + ); + } + + // Non-maximum suppression to remove duplicate detections + relativeDetections = naiveNonMaxSuppression( + detections: relativeDetections, + iouThreshold: kIouThreshold, + ); + + return relativeDetections; + } + + /// Initialize the interpreter by loading the model file. 
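On the `correctForMaintainedAspectRatio` step inside `_yoloPostProcessOutputs` above: as I read the diff, when e.g. a 1920x1080 photo is fitted into the 640x640 model input, only a 640x360 strip carries image content, so coordinates relative to the full canvas must be stretched back. A hedged numeric sketch (sizes hypothetical):

// Sketch only: undoing the letterbox scaling for a maintained aspect ratio.
void main() {
  const inputSide = 640.0; // model canvas, both dimensions
  const contentHeight = 360.0; // 1920x1080 scaled down to fit 640 wide
  const scaleY = inputSide / contentHeight; // ~1.778
  const yMinOnCanvas = 0.1; // relative to the 640px canvas
  print(yMinOnCanvas * scaleY); // ~0.178, now relative to the actual image
}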
+  static Future<int> _loadModel(Map args) async {
+    final sessionOptions = OrtSessionOptions()
+      ..setInterOpNumThreads(1)
+      ..setIntraOpNumThreads(1)
+      ..setSessionGraphOptimizationLevel(GraphOptimizationLevel.ortEnableAll);
+    try {
+      // _logger.info('Loading face detection model');
+      final session =
+          OrtSession.fromFile(File(args["modelPath"]), sessionOptions);
+      // _logger.info('Face detection model loaded');
+      return session.address;
+    } catch (e, _) {
+      // _logger.severe('Face detection model not loaded', e, s);
+    }
+    return -1;
+  }
+
+  static Future<void> _releaseModel(Map args) async {
+    final address = args['address'] as int;
+    if (address == 0) {
+      return;
+    }
+    final session = OrtSession.fromAddress(address);
+    session.release();
+    return;
+  }
+
+  static Future<(List<FaceDetectionRelative>, int, DateTime)>
+      inferenceAndPostProcess(
+    Map args,
+  ) async {
+    final inputImageList = args['inputImageList'] as Float32List;
+    final inputShape = args['inputShape'] as List<int>;
+    final newSize = args['newSize'] as Dimensions;
+    final sessionAddress = args['sessionAddress'] as int;
+    final timeSentToIsolate = args['timeNow'] as DateTime;
+    final delaySentToIsolate =
+        DateTime.now().difference(timeSentToIsolate).inMilliseconds;
+
+    final Stopwatch stopwatchPrepare = Stopwatch()..start();
+    final inputOrt = OrtValueTensor.createTensorWithDataList(
+      inputImageList,
+      inputShape,
+    );
+    final inputs = {'input': inputOrt};
+    stopwatchPrepare.stop();
+    dev.log(
+      '[YOLOFaceDetectionService] data preparation is finished, in ${stopwatchPrepare.elapsedMilliseconds}ms',
+    );
+
+    stopwatchPrepare.reset();
+    stopwatchPrepare.start();
+    final runOptions = OrtRunOptions();
+    final session = OrtSession.fromAddress(sessionAddress);
+    stopwatchPrepare.stop();
+    dev.log(
+      '[YOLOFaceDetectionService] session preparation is finished, in ${stopwatchPrepare.elapsedMilliseconds}ms',
+    );
+
+    final stopwatchInterpreter = Stopwatch()..start();
+    late final List<OrtValue?> outputs;
+    try {
+      outputs = session.run(runOptions, inputs);
+    } catch (e, s) {
+      dev.log(
+        '[YOLOFaceDetectionService] Error while running inference: $e \n $s',
+      );
+      throw YOLOFaceInterpreterRunException();
+    }
+    stopwatchInterpreter.stop();
+    dev.log(
+      '[YOLOFaceDetectionService] interpreter.run is finished, in ${stopwatchInterpreter.elapsedMilliseconds} ms',
+    );
+
+    final relativeDetections = _yoloPostProcessOutputs(outputs, newSize);
+
+    return (relativeDetections, delaySentToIsolate, DateTime.now());
+  }
+}
diff --git a/mobile/lib/services/machine_learning/face_ml/face_detection/naive_non_max_suppression.dart b/mobile/lib/services/machine_learning/face_ml/face_detection/naive_non_max_suppression.dart
new file mode 100644
index 000000000..624181a66
--- /dev/null
+++ b/mobile/lib/services/machine_learning/face_ml/face_detection/naive_non_max_suppression.dart
@@ -0,0 +1,49 @@
+import 'dart:math' as math show max, min;
+
+import 'package:photos/services/machine_learning/face_ml/face_detection/detection.dart';
+
+List<FaceDetectionRelative> naiveNonMaxSuppression({
+  required List<FaceDetectionRelative> detections,
+  required double iouThreshold,
+}) {
+  // Sort the detections by score, the highest first
+  detections.sort((a, b) => b.score.compareTo(a.score));
+
+  // Loop through the detections and calculate the IOU
+  for (var i = 0; i < detections.length - 1; i++) {
+    for (var j = i + 1; j < detections.length; j++) {
+      final iou = _calculateIOU(detections[i], detections[j]);
+      if (iou >= iouThreshold) {
+        detections.removeAt(j);
+        j--;
+      }
+    }
+  }
+  return detections;
+}
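+
+/// Computes the intersection over union (IoU) of two detections:
+/// IoU = area(A ∩ B) / (area(A) + area(B) - area(A ∩ B)).
+///
+/// A worked example with illustrative numbers: two identical boxes give
+/// IoU = 1.0, disjoint boxes give 0.0, and a box shifted sideways by half its
+/// width against an otherwise identical copy gives 0.5 / 1.5 = 1/3, so the
+/// suppression above would drop the lower-scoring one of such a pair only if
+/// `iouThreshold` were at most 1/3.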
+
+double _calculateIOU(
+  FaceDetectionRelative detectionA,
+  FaceDetectionRelative detectionB,
+) {
+  final areaA = detectionA.width * detectionA.height;
+  final areaB = detectionB.width * detectionB.height;
+
+  final intersectionMinX = math.max(detectionA.xMinBox, detectionB.xMinBox);
+  final intersectionMinY = math.max(detectionA.yMinBox, detectionB.yMinBox);
+  final intersectionMaxX = math.min(detectionA.xMaxBox, detectionB.xMaxBox);
+  final intersectionMaxY = math.min(detectionA.yMaxBox, detectionB.yMaxBox);
+
+  final intersectionWidth = intersectionMaxX - intersectionMinX;
+  final intersectionHeight = intersectionMaxY - intersectionMinY;
+
+  if (intersectionWidth < 0 || intersectionHeight < 0) {
+    return 0.0; // If boxes do not overlap, IoU is 0
+  }
+
+  final intersectionArea = intersectionWidth * intersectionHeight;
+
+  final unionArea = areaA + areaB - intersectionArea;
+
+  return intersectionArea / unionArea;
+}
diff --git a/mobile/lib/services/machine_learning/face_ml/face_detection/yolo_filter_extract_detections.dart b/mobile/lib/services/machine_learning/face_ml/face_detection/yolo_filter_extract_detections.dart
new file mode 100644
index 000000000..ec546533a
--- /dev/null
+++ b/mobile/lib/services/machine_learning/face_ml/face_detection/yolo_filter_extract_detections.dart
@@ -0,0 +1,95 @@
+import 'dart:developer' as dev show log;
+
+import 'package:photos/services/machine_learning/face_ml/face_detection/detection.dart';
+
+List<FaceDetectionRelative> yoloOnnxFilterExtractDetections(
+  double minScoreSigmoidThreshold,
+  int inputWidth,
+  int inputHeight, {
+  required List<List<double>> results, // // [25200, 16]
+}) {
+  final outputDetections = <FaceDetectionRelative>[];
+  final output = <List<double>>[];
+
+  // Go through the raw output and check the scores
+  for (final result in results) {
+    // Filter out raw detections with low scores
+    if (result[4] < minScoreSigmoidThreshold) {
+      continue;
+    }
+
+    // Get the raw detection
+    final rawDetection = List<double>.from(result);
+
+    // Append the processed raw detection to the output
+    output.add(rawDetection);
+  }
+
+  if (output.isEmpty) {
+    double maxScore = 0;
+    for (final result in results) {
+      if (result[4] > maxScore) {
+        maxScore = result[4];
+      }
+    }
+    dev.log(
+      'No face detections found above the minScoreSigmoidThreshold of $minScoreSigmoidThreshold. The max score was $maxScore.',
+    );
+  }
+
+  for (final List<double> rawDetection in output) {
+    // Get absolute bounding box coordinates in format [xMin, yMin, xMax, yMax] https://github.com/deepcam-cn/yolov5-face/blob/eb23d18defe4a76cc06449a61cd51004c59d2697/utils/general.py#L216
+    final xMinAbs = rawDetection[0] - rawDetection[2] / 2;
+    final yMinAbs = rawDetection[1] - rawDetection[3] / 2;
+    final xMaxAbs = rawDetection[0] + rawDetection[2] / 2;
+    final yMaxAbs = rawDetection[1] + rawDetection[3] / 2;
+
+    // Get the relative bounding box coordinates in format [xMin, yMin, xMax, yMax]
+    final box = [
+      xMinAbs / inputWidth,
+      yMinAbs / inputHeight,
+      xMaxAbs / inputWidth,
+      yMaxAbs / inputHeight,
+    ];
+
+    // Get the keypoints coordinates in format [x, y]
+    final allKeypoints = <List<double>>[
+      [
+        rawDetection[5] / inputWidth,
+        rawDetection[6] / inputHeight,
+      ],
+      [
+        rawDetection[7] / inputWidth,
+        rawDetection[8] / inputHeight,
+      ],
+      [
+        rawDetection[9] / inputWidth,
+        rawDetection[10] / inputHeight,
+      ],
+      [
+        rawDetection[11] / inputWidth,
+        rawDetection[12] / inputHeight,
+      ],
+      [
+        rawDetection[13] / inputWidth,
+        rawDetection[14] / inputHeight,
+      ],
+    ];
+
+    // Get the score
+    final score =
+        rawDetection[4]; // Or should it be rawDetection[4]*rawDetection[15]?
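+    // For reference, based on the yolov5-face export linked above: the 16
+    // values per detection read [xCenter, yCenter, width, height, objectness
+    // score, five (x, y) landmark pairs covering eyes, nose and mouth corners,
+    // face class score]. That layout is why indices 0-14 are consumed above
+    // and why rawDetection[15] is the candidate extra factor in the question
+    // about the score.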
+ + // Create the relative detection + final detection = FaceDetectionRelative( + score: score, + box: box, + allKeypoints: allKeypoints, + ); + + // Append the relative detection to the output + outputDetections.add(detection); + } + + return outputDetections; +} diff --git a/mobile/lib/services/machine_learning/face_ml/face_embedding/face_embedding_exceptions.dart b/mobile/lib/services/machine_learning/face_ml/face_embedding/face_embedding_exceptions.dart new file mode 100644 index 000000000..548b80a95 --- /dev/null +++ b/mobile/lib/services/machine_learning/face_ml/face_embedding/face_embedding_exceptions.dart @@ -0,0 +1,11 @@ +class MobileFaceNetInterpreterInitializationException implements Exception {} + +class MobileFaceNetImagePreprocessingException implements Exception {} + +class MobileFaceNetEmptyInput implements Exception {} + +class MobileFaceNetWrongInputSize implements Exception {} + +class MobileFaceNetWrongInputRange implements Exception {} + +class MobileFaceNetInterpreterRunException implements Exception {} \ No newline at end of file diff --git a/mobile/lib/services/machine_learning/face_ml/face_embedding/face_embedding_service.dart b/mobile/lib/services/machine_learning/face_ml/face_embedding/face_embedding_service.dart new file mode 100644 index 000000000..777e79376 --- /dev/null +++ b/mobile/lib/services/machine_learning/face_ml/face_embedding/face_embedding_service.dart @@ -0,0 +1,249 @@ +import "dart:io" show File; +import 'dart:math' as math show max, min, sqrt; +import 'dart:typed_data' show Float32List; + +import 'package:computer/computer.dart'; +import 'package:logging/logging.dart'; +import 'package:onnxruntime/onnxruntime.dart'; +import 'package:photos/services/machine_learning/face_ml/face_detection/detection.dart'; +import "package:photos/services/remote_assets_service.dart"; +import "package:photos/utils/image_ml_isolate.dart"; +import "package:synchronized/synchronized.dart"; + +/// This class is responsible for running the face embedding model (MobileFaceNet) on ONNX runtime, and can be accessed through the singleton instance [FaceEmbeddingService.instance]. +class FaceEmbeddingService { + static const kModelBucketEndpoint = "https://models.ente.io/"; + static const kRemoteBucketModelPath = "mobilefacenet_opset15.onnx"; + static const modelRemotePath = kModelBucketEndpoint + kRemoteBucketModelPath; + + static const int kInputSize = 112; + static const int kEmbeddingSize = 192; + static const int kNumChannels = 3; + static const bool kPreWhiten = false; + + static final _logger = Logger('FaceEmbeddingOnnx'); + + bool isInitialized = false; + int sessionAddress = 0; + + final _computer = Computer.shared(); + + final _computerLock = Lock(); + + // singleton pattern + FaceEmbeddingService._privateConstructor(); + + /// Use this instance to access the FaceEmbedding service. Make sure to call `init()` before using it. + /// e.g. 
`await FaceEmbeddingService.instance.init();`
+  ///
+  /// Then you can use `predict()` to get the embedding of a face, so `FaceEmbeddingService.instance.predict(imageData)`
+  ///
+  /// config options: faceEmbeddingEnte
+  static final instance = FaceEmbeddingService._privateConstructor();
+  factory FaceEmbeddingService() => instance;
+
+  /// Check if the interpreter is initialized, if not initialize it with `loadModel()`
+  Future<void> init() async {
+    if (!isInitialized) {
+      _logger.info('init is called');
+      final model =
+          await RemoteAssetsService.instance.getAsset(modelRemotePath);
+      final startTime = DateTime.now();
+      // Doing this from main isolate since `rootBundle` cannot be accessed outside it
+      sessionAddress = await _computer.compute(
+        _loadModel,
+        param: {
+          "modelPath": model.path,
+        },
+      );
+      final endTime = DateTime.now();
+      _logger.info(
+        "Face embedding model loaded, took: ${(endTime.millisecondsSinceEpoch - startTime.millisecondsSinceEpoch).toString()}ms",
+      );
+      if (sessionAddress != -1) {
+        isInitialized = true;
+      }
+    }
+  }
+
+  Future<void> release() async {
+    if (isInitialized) {
+      await _computer
+          .compute(_releaseModel, param: {'address': sessionAddress});
+      isInitialized = false;
+      sessionAddress = 0;
+    }
+  }
+
+  static Future<int> _loadModel(Map args) async {
+    final sessionOptions = OrtSessionOptions()
+      ..setInterOpNumThreads(1)
+      ..setIntraOpNumThreads(1)
+      ..setSessionGraphOptimizationLevel(GraphOptimizationLevel.ortEnableAll);
+    try {
+      // _logger.info('Loading face embedding model');
+      final session =
+          OrtSession.fromFile(File(args["modelPath"]), sessionOptions);
+      // _logger.info('Face embedding model loaded');
+      return session.address;
+    } catch (e, _) {
+      // _logger.severe('Face embedding model not loaded', e, s);
+    }
+    return -1;
+  }
+
+  static Future<void> _releaseModel(Map args) async {
+    final address = args['address'] as int;
+    if (address == 0) {
+      return;
+    }
+    final session = OrtSession.fromAddress(address);
+    session.release();
+    return;
+  }
+
+  Future<(List<double>, bool, double)> predictFromImageDataInComputer(
+    String imagePath,
+    FaceDetectionRelative face,
+  ) async {
+    assert(sessionAddress != 0 && sessionAddress != -1 && isInitialized);
+
+    try {
+      final stopwatchDecoding = Stopwatch()..start();
+      final (inputImageList, _, isBlur, blurValue, _) =
+          await ImageMlIsolate.instance.preprocessMobileFaceNetOnnx(
+        imagePath,
+        [face],
+      );
+      stopwatchDecoding.stop();
+      _logger.info(
+        'MobileFaceNet image decoding and preprocessing is finished, in ${stopwatchDecoding.elapsedMilliseconds}ms',
+      );
+
+      final stopwatch = Stopwatch()..start();
+      _logger.info('MobileFaceNet interpreter.run is called');
+      final embedding = await _computer.compute(
+        inferFromMap,
+        param: {
+          'input': inputImageList,
+          'address': sessionAddress,
+          'inputSize': kInputSize,
+        },
+        taskName: 'createFaceEmbedding',
+      ) as List<double>;
+      stopwatch.stop();
+      _logger.info(
+        'MobileFaceNet interpreter.run is finished, in ${stopwatch.elapsedMilliseconds}ms',
+      );
+
+      _logger.info(
+        'MobileFaceNet results (only first few numbers): embedding ${embedding.sublist(0, 5)}',
+      );
+      _logger.info(
+        'Mean of embedding: ${embedding.reduce((a, b) => a + b) / embedding.length}',
+      );
+      _logger.info(
+        'Max of embedding: ${embedding.reduce(math.max)}',
+      );
+      _logger.info(
+        'Min of embedding: ${embedding.reduce(math.min)}',
+      );
+
+      return (embedding, isBlur[0], blurValue[0]);
+    } catch (e) {
+      _logger.info('MobileFaceNet Error while running inference: $e');
+      rethrow;
+    }
+  }
+
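+  // Note: `infer()` below L2-normalizes every embedding before returning it,
+  // so on these unit-length vectors cosine similarity reduces to a plain dot
+  // product. An illustrative helper (not part of this class's API):
+  //
+  //   double cosineSimilarity(List<double> a, List<double> b) {
+  //     double dot = 0;
+  //     for (int i = 0; i < a.length; i++) {
+  //       dot += a[i] * b[i];
+  //     }
+  //     return dot; // equals cosine similarity, since |a| = |b| = 1
+  //   }
+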
+  Future<List<List<double>>> predictInComputer(Float32List input) async {
+    assert(sessionAddress != 0 && sessionAddress != -1 && isInitialized);
+    return await _computerLock.synchronized(() async {
+      try {
+        final stopwatch = Stopwatch()..start();
+        _logger.info('MobileFaceNet interpreter.run is called');
+        final embeddings = await _computer.compute(
+          inferFromMap,
+          param: {
+            'input': input,
+            'address': sessionAddress,
+            'inputSize': kInputSize,
+          },
+          taskName: 'createFaceEmbedding',
+        ) as List<List<double>>;
+        stopwatch.stop();
+        _logger.info(
+          'MobileFaceNet interpreter.run is finished, in ${stopwatch.elapsedMilliseconds}ms',
+        );
+
+        return embeddings;
+      } catch (e) {
+        _logger.info('MobileFaceNet Error while running inference: $e');
+        rethrow;
+      }
+    });
+  }
+
+  static Future<List<List<double>>> predictSync(
+    Float32List input,
+    int sessionAddress,
+  ) async {
+    assert(sessionAddress != 0 && sessionAddress != -1);
+    try {
+      final stopwatch = Stopwatch()..start();
+      _logger.info('MobileFaceNet interpreter.run is called');
+      final embeddings = await infer(
+        input,
+        sessionAddress,
+        kInputSize,
+      );
+      stopwatch.stop();
+      _logger.info(
+        'MobileFaceNet interpreter.run is finished, in ${stopwatch.elapsedMilliseconds}ms',
+      );
+
+      return embeddings;
+    } catch (e) {
+      _logger.info('MobileFaceNet Error while running inference: $e');
+      rethrow;
+    }
+  }
+
+  static Future<List<List<double>>> inferFromMap(Map args) async {
+    final inputImageList = args['input'] as Float32List;
+    final address = args['address'] as int;
+    final inputSize = args['inputSize'] as int;
+    return await infer(inputImageList, address, inputSize);
+  }
+
+  static Future<List<List<double>>> infer(
+    Float32List inputImageList,
+    int address,
+    int inputSize,
+  ) async {
+    final runOptions = OrtRunOptions();
+    final int numberOfFaces =
+        inputImageList.length ~/ (inputSize * inputSize * 3);
+    final inputOrt = OrtValueTensor.createTensorWithDataList(
+      inputImageList,
+      [numberOfFaces, inputSize, inputSize, 3],
+    );
+    final inputs = {'img_inputs': inputOrt};
+    final session = OrtSession.fromAddress(address);
+    final List<OrtValue?> outputs = session.run(runOptions, inputs);
+    final embeddings = outputs[0]?.value as List<List<double>>;
+
+    for (final embedding in embeddings) {
+      double normalization = 0;
+      for (int i = 0; i < kEmbeddingSize; i++) {
+        normalization += embedding[i] * embedding[i];
+      }
+      final double sqrtNormalization = math.sqrt(normalization);
+      for (int i = 0; i < kEmbeddingSize; i++) {
+        embedding[i] = embedding[i] / sqrtNormalization;
+      }
+    }
+
+    return embeddings;
+  }
+}
diff --git a/mobile/lib/services/machine_learning/face_ml/face_filtering/blur_detection_service.dart b/mobile/lib/services/machine_learning/face_ml/face_filtering/blur_detection_service.dart
new file mode 100644
index 000000000..9c8d2d8c8
--- /dev/null
+++ b/mobile/lib/services/machine_learning/face_ml/face_filtering/blur_detection_service.dart
@@ -0,0 +1,155 @@
+import 'package:logging/logging.dart';
+import "package:photos/services/machine_learning/face_ml/face_detection/detection.dart";
+import 'package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart';
+
+class BlurDetectionService {
+  final _logger = Logger('BlurDetectionService');
+
+  // singleton pattern
+  BlurDetectionService._privateConstructor();
+  static final instance = BlurDetectionService._privateConstructor();
+  factory BlurDetectionService() => instance;
+
+  Future<(bool, double)> predictIsBlurGrayLaplacian(
+    List<List<int>> grayImage, {
+    int threshold = kLaplacianHardThreshold,
+    FaceDirection faceDirection = FaceDirection.straight,
+  }) async {
+    final List<List<int>> laplacian =
+        _applyLaplacian(grayImage, faceDirection: faceDirection);
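+    // Variance of the Laplacian as a sharpness proxy: the Laplacian responds
+    // strongly at edges, so a sharp crop yields a wide spread of values (high
+    // variance) while a blurry crop yields mostly near-zero responses (low
+    // variance). Anything below `threshold` is reported as blurry.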
+    final double variance = _calculateVariance(laplacian);
+    _logger.info('Variance: $variance');
+    return (variance < threshold, variance);
+  }
+
+  double _calculateVariance(List<List<int>> matrix) {
+    final int numRows = matrix.length;
+    final int numCols = matrix[0].length;
+    final int totalElements = numRows * numCols;
+
+    // Calculate the mean
+    double mean = 0;
+    for (var row in matrix) {
+      for (var value in row) {
+        mean += value;
+      }
+    }
+    mean /= totalElements;
+
+    // Calculate the variance
+    double variance = 0;
+    for (var row in matrix) {
+      for (var value in row) {
+        final double diff = value - mean;
+        variance += diff * diff;
+      }
+    }
+    variance /= totalElements;
+
+    return variance;
+  }
+
+  List<List<int>> _padImage(
+    List<List<int>> image, {
+    int removeSideColumns = 56,
+    FaceDirection faceDirection = FaceDirection.straight,
+  }) {
+    // Throw an exception if removeSideColumns is not even
+    if (removeSideColumns % 2 != 0) {
+      throw Exception('removeSideColumns must be even');
+    }
+
+    final int numRows = image.length;
+    final int numCols = image[0].length;
+    final int paddedNumCols = numCols + 2 - removeSideColumns;
+    final int paddedNumRows = numRows + 2;
+
+    // Create a new matrix with extra padding
+    final List<List<int>> paddedImage = List.generate(
+      paddedNumRows,
+      (i) => List.generate(
+        paddedNumCols,
+        (j) => 0,
+        growable: false,
+      ),
+      growable: false,
+    );
+
+    // Copy original image into the center of the padded image, taking into account the face direction
+    if (faceDirection == FaceDirection.straight) {
+      for (int i = 0; i < numRows; i++) {
+        for (int j = 0; j < (paddedNumCols - 2); j++) {
+          paddedImage[i + 1][j + 1] =
+              image[i][j + (removeSideColumns / 2).round()];
+        }
+      }
+      // If the face is facing left, we only take the right side of the face image
+    } else if (faceDirection == FaceDirection.left) {
+      for (int i = 0; i < numRows; i++) {
+        for (int j = 0; j < (paddedNumCols - 2); j++) {
+          paddedImage[i + 1][j + 1] = image[i][j + removeSideColumns];
+        }
+      }
+      // If the face is facing right, we only take the left side of the face image
+    } else if (faceDirection == FaceDirection.right) {
+      for (int i = 0; i < numRows; i++) {
+        for (int j = 0; j < (paddedNumCols - 2); j++) {
+          paddedImage[i + 1][j + 1] = image[i][j];
+        }
+      }
+    }
+
+    // Reflect padding
+    // Top and bottom rows
+    for (int j = 1; j <= (paddedNumCols - 2); j++) {
+      paddedImage[0][j] = paddedImage[2][j]; // Top row
+      paddedImage[numRows + 1][j] = paddedImage[numRows - 1][j]; // Bottom row
+    }
+    // Left and right columns
+    for (int i = 0; i < numRows + 2; i++) {
+      paddedImage[i][0] = paddedImage[i][2]; // Left column
+      paddedImage[i][paddedNumCols - 1] =
+          paddedImage[i][paddedNumCols - 3]; // Right column
+    }
+
+    return paddedImage;
+  }
+
+  List<List<int>> _applyLaplacian(
+    List<List<int>> image, {
+    FaceDirection faceDirection = FaceDirection.straight,
+  }) {
+    final List<List<int>> paddedImage =
+        _padImage(image, faceDirection: faceDirection);
+    final int numRows = paddedImage.length - 2;
+    final int numCols = paddedImage[0].length - 2;
+    final List<List<int>> outputImage = List.generate(
+      numRows,
+      (i) => List.generate(numCols, (j) => 0, growable: false),
+      growable: false,
+    );
+
+    // Define the Laplacian kernel
+    final List<List<int>> kernel = [
+      [0, 1, 0],
+      [1, -4, 1],
+      [0, 1, 0],
+    ];
+
+    // Apply the kernel to each pixel
+    for (int i = 0; i < numRows; i++) {
+      for (int j = 0; j < numCols; j++) {
+        int sum = 0;
+        for (int ki = 0; ki < 3; ki++) {
+          for (int kj = 0; kj < 3; kj++) {
+            sum += paddedImage[i + ki][j + kj] *
kernel[ki][kj]; + } + } + // Adjust the output value if necessary (e.g., clipping) + outputImage[i][j] = sum; //.clamp(0, 255); + } + } + + return outputImage; + } +} diff --git a/mobile/lib/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart b/mobile/lib/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart new file mode 100644 index 000000000..b0f954f8f --- /dev/null +++ b/mobile/lib/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart @@ -0,0 +1,20 @@ +import 'package:photos/services/machine_learning/face_ml/face_detection/face_detection_service.dart'; + +/// Blur detection threshold +const kLaplacianHardThreshold = 10; +const kLaplacianSoftThreshold = 50; +const kLaplacianVerySoftThreshold = 200; + +/// Default blur value +const kLapacianDefault = 10000.0; + +/// The minimum score for a face to be considered a high quality face for clustering and person detection +const kMinimumQualityFaceScore = 0.80; +const kMediumQualityFaceScore = 0.85; +const kHighQualityFaceScore = 0.90; + +/// The minimum score for a face to be detected, regardless of quality. Use [kMinimumQualityFaceScore] for high quality faces. +const kMinFaceDetectionScore = FaceDetectionService.kMinScoreSigmoidThreshold; + +/// The minimum cluster size for displaying a cluster in the UI +const kMinimumClusterSizeSearchResult = 20; diff --git a/mobile/lib/services/machine_learning/face_ml/face_ml_exceptions.dart b/mobile/lib/services/machine_learning/face_ml/face_ml_exceptions.dart new file mode 100644 index 000000000..78a4bcb1f --- /dev/null +++ b/mobile/lib/services/machine_learning/face_ml/face_ml_exceptions.dart @@ -0,0 +1,30 @@ + +class GeneralFaceMlException implements Exception { + final String message; + + GeneralFaceMlException(this.message); + + @override + String toString() => 'GeneralFaceMlException: $message'; +} + +class CouldNotRetrieveAnyFileData implements Exception {} + +class CouldNotInitializeFaceDetector implements Exception {} + +class CouldNotRunFaceDetector implements Exception {} + +class CouldNotWarpAffine implements Exception {} + +class CouldNotInitializeFaceEmbeddor implements Exception {} + +class InputProblemFaceEmbeddor implements Exception { + final String message; + + InputProblemFaceEmbeddor(this.message); + + @override + String toString() => 'InputProblemFaceEmbeddor: $message'; +} + +class CouldNotRunFaceEmbeddor implements Exception {} \ No newline at end of file diff --git a/mobile/lib/services/machine_learning/face_ml/face_ml_methods.dart b/mobile/lib/services/machine_learning/face_ml/face_ml_methods.dart new file mode 100644 index 000000000..5745234b5 --- /dev/null +++ b/mobile/lib/services/machine_learning/face_ml/face_ml_methods.dart @@ -0,0 +1,90 @@ +import 'package:photos/services/machine_learning/face_ml/face_ml_version.dart'; + +/// Represents a face detection method with a specific version. 
+class FaceDetectionMethod extends VersionedMethod { + /// Creates a [FaceDetectionMethod] instance with a specific `method` and `version` (default `1`) + FaceDetectionMethod(String method, {int version = 1}) + : super(method, version); + + /// Creates a [FaceDetectionMethod] instance with 'Empty method' as the method, and a specific `version` (default `1`) + const FaceDetectionMethod.empty() : super.empty(); + + /// Creates a [FaceDetectionMethod] instance with 'BlazeFace' as the method, and a specific `version` (default `1`) + FaceDetectionMethod.blazeFace({int version = 1}) + : super('BlazeFace', version); + + static FaceDetectionMethod fromMlVersion(int version) { + switch (version) { + case 1: + return FaceDetectionMethod.blazeFace(version: version); + default: + return const FaceDetectionMethod.empty(); + } + } + + static FaceDetectionMethod fromJson(Map json) { + return FaceDetectionMethod( + json['method'], + version: json['version'], + ); + } +} + +/// Represents a face alignment method with a specific version. +class FaceAlignmentMethod extends VersionedMethod { + /// Creates a [FaceAlignmentMethod] instance with a specific `method` and `version` (default `1`) + FaceAlignmentMethod(String method, {int version = 1}) + : super(method, version); + + /// Creates a [FaceAlignmentMethod] instance with 'Empty method' as the method, and a specific `version` (default `1`) + const FaceAlignmentMethod.empty() : super.empty(); + + /// Creates a [FaceAlignmentMethod] instance with 'ArcFace' as the method, and a specific `version` (default `1`) + FaceAlignmentMethod.arcFace({int version = 1}) : super('ArcFace', version); + + static FaceAlignmentMethod fromMlVersion(int version) { + switch (version) { + case 1: + return FaceAlignmentMethod.arcFace(version: version); + default: + return const FaceAlignmentMethod.empty(); + } + } + + static FaceAlignmentMethod fromJson(Map json) { + return FaceAlignmentMethod( + json['method'], + version: json['version'], + ); + } +} + +/// Represents a face embedding method with a specific version. 
+class FaceEmbeddingMethod extends VersionedMethod { + /// Creates a [FaceEmbeddingMethod] instance with a specific `method` and `version` (default `1`) + FaceEmbeddingMethod(String method, {int version = 1}) + : super(method, version); + + /// Creates a [FaceEmbeddingMethod] instance with 'Empty method' as the method, and a specific `version` (default `1`) + const FaceEmbeddingMethod.empty() : super.empty(); + + /// Creates a [FaceEmbeddingMethod] instance with 'MobileFaceNet' as the method, and a specific `version` (default `1`) + FaceEmbeddingMethod.mobileFaceNet({int version = 1}) + : super('MobileFaceNet', version); + + static FaceEmbeddingMethod fromMlVersion(int version) { + switch (version) { + case 1: + return FaceEmbeddingMethod.mobileFaceNet(version: version); + default: + return const FaceEmbeddingMethod.empty(); + } + } + + static FaceEmbeddingMethod fromJson(Map json) { + return FaceEmbeddingMethod( + json['method'], + version: json['version'], + ); + } +} diff --git a/mobile/lib/services/machine_learning/face_ml/face_ml_result.dart b/mobile/lib/services/machine_learning/face_ml/face_ml_result.dart new file mode 100644 index 000000000..19f954013 --- /dev/null +++ b/mobile/lib/services/machine_learning/face_ml/face_ml_result.dart @@ -0,0 +1,314 @@ +import "dart:convert" show jsonEncode, jsonDecode; + +import "package:flutter/material.dart" show immutable; +import "package:logging/logging.dart"; +import "package:photos/face/model/dimension.dart"; +import "package:photos/models/file/file.dart"; +import 'package:photos/models/ml/ml_typedefs.dart'; +import "package:photos/models/ml/ml_versions.dart"; +import 'package:photos/services/machine_learning/face_ml/face_alignment/alignment_result.dart'; +import 'package:photos/services/machine_learning/face_ml/face_detection/detection.dart'; +import 'package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart'; +import 'package:photos/services/machine_learning/face_ml/face_ml_methods.dart'; + +final _logger = Logger('ClusterResult_FaceMlResult'); + +@immutable +class FaceMlResult { + final int fileId; + + final List faces; + + final Dimensions decodedImageSize; + + final int mlVersion; + final bool errorOccured; + final bool onlyThumbnailUsed; + + bool get hasFaces => faces.isNotEmpty; + int get numberOfFaces => faces.length; + + List get allFaceEmbeddings { + return faces.map((face) => face.embedding).toList(); + } + + List get allFaceIds { + return faces.map((face) => face.faceId).toList(); + } + + List get fileIdForEveryFace { + return List.filled(faces.length, fileId); + } + + FaceDetectionMethod get faceDetectionMethod => + FaceDetectionMethod.fromMlVersion(mlVersion); + FaceAlignmentMethod get faceAlignmentMethod => + FaceAlignmentMethod.fromMlVersion(mlVersion); + FaceEmbeddingMethod get faceEmbeddingMethod => + FaceEmbeddingMethod.fromMlVersion(mlVersion); + + const FaceMlResult({ + required this.fileId, + required this.faces, + required this.mlVersion, + required this.errorOccured, + required this.onlyThumbnailUsed, + required this.decodedImageSize, + }); + + Map _toJson() => { + 'fileId': fileId, + 'faces': faces.map((face) => face.toJson()).toList(), + 'mlVersion': mlVersion, + 'errorOccured': errorOccured, + 'onlyThumbnailUsed': onlyThumbnailUsed, + 'decodedImageSize': { + 'width': decodedImageSize.width, + 'height': decodedImageSize.height, + }, + }; + + String toJsonString() => jsonEncode(_toJson()); + + static FaceMlResult _fromJson(Map json) { + return FaceMlResult( + fileId: 
json['fileId'], + faces: (json['faces'] as List) + .map((item) => FaceResult.fromJson(item as Map)) + .toList(), + mlVersion: json['mlVersion'], + errorOccured: json['errorOccured'] ?? false, + onlyThumbnailUsed: json['onlyThumbnailUsed'] ?? false, + decodedImageSize: json['decodedImageSize'] != null + ? Dimensions( + width: json['decodedImageSize']['width'], + height: json['decodedImageSize']['height'], + ) + : json['faceDetectionImageSize'] == null + ? const Dimensions(width: -1, height: -1) + : Dimensions( + width: (json['faceDetectionImageSize']['width'] as double) + .truncate(), + height: (json['faceDetectionImageSize']['height'] as double) + .truncate(), + ), + ); + } + + static FaceMlResult fromJsonString(String jsonString) { + return _fromJson(jsonDecode(jsonString)); + } + + /// Sets the embeddings of the faces with the given faceIds to [10, 10,..., 10]. + /// + /// Throws an exception if a faceId is not found in the FaceMlResult. + void setEmbeddingsToTen(List faceIds) { + for (final faceId in faceIds) { + final faceIndex = faces.indexWhere((face) => face.faceId == faceId); + if (faceIndex == -1) { + throw Exception("No face found with faceId $faceId"); + } + for (var i = 0; i < faces[faceIndex].embedding.length; i++) { + faces[faceIndex].embedding[i] = 10; + } + } + } + + FaceDetectionRelative getDetectionForFaceId(String faceId) { + final faceIndex = faces.indexWhere((face) => face.faceId == faceId); + if (faceIndex == -1) { + throw Exception("No face found with faceId $faceId"); + } + return faces[faceIndex].detection; + } +} + +class FaceMlResultBuilder { + int fileId; + + List faces = []; + + Dimensions decodedImageSize; + + int mlVersion; + bool errorOccured; + bool onlyThumbnailUsed; + + FaceMlResultBuilder({ + this.fileId = -1, + this.mlVersion = faceMlVersion, + this.errorOccured = false, + this.onlyThumbnailUsed = false, + this.decodedImageSize = const Dimensions(width: -1, height: -1), + }); + + FaceMlResultBuilder.fromEnteFile( + EnteFile file, { + this.mlVersion = faceMlVersion, + this.errorOccured = false, + this.onlyThumbnailUsed = false, + this.decodedImageSize = const Dimensions(width: -1, height: -1), + }) : fileId = file.uploadedFileID ?? 
-1; + + FaceMlResultBuilder.fromEnteFileID( + int fileID, { + this.mlVersion = faceMlVersion, + this.errorOccured = false, + this.onlyThumbnailUsed = false, + this.decodedImageSize = const Dimensions(width: -1, height: -1), + }) : fileId = fileID; + + void addNewlyDetectedFaces( + List faceDetections, + Dimensions originalSize, + ) { + decodedImageSize = originalSize; + for (var i = 0; i < faceDetections.length; i++) { + faces.add( + FaceResultBuilder.fromFaceDetection( + faceDetections[i], + resultBuilder: this, + ), + ); + } + } + + void addAlignmentResults( + List alignmentResults, + List blurValues, + ) { + if (alignmentResults.length != faces.length) { + throw Exception( + "The amount of alignment results (${alignmentResults.length}) does not match the number of faces (${faces.length})", + ); + } + + for (var i = 0; i < alignmentResults.length; i++) { + faces[i].alignment = alignmentResults[i]; + faces[i].blurValue = blurValues[i]; + } + } + + void addEmbeddingsToExistingFaces( + List embeddings, + ) { + if (embeddings.length != faces.length) { + throw Exception( + "The amount of embeddings (${embeddings.length}) does not match the number of faces (${faces.length})", + ); + } + for (var faceIndex = 0; faceIndex < faces.length; faceIndex++) { + faces[faceIndex].embedding = embeddings[faceIndex]; + } + } + + FaceMlResult build() { + final faceResults = []; + for (var i = 0; i < faces.length; i++) { + faceResults.add(faces[i].build()); + } + return FaceMlResult( + fileId: fileId, + faces: faceResults, + mlVersion: mlVersion, + errorOccured: errorOccured, + onlyThumbnailUsed: onlyThumbnailUsed, + decodedImageSize: decodedImageSize, + ); + } + + FaceMlResult buildNoFaceDetected() { + faces = []; + return build(); + } + + FaceMlResult buildErrorOccurred() { + faces = []; + errorOccured = true; + return build(); + } +} + +@immutable +class FaceResult { + final FaceDetectionRelative detection; + final double blurValue; + final AlignmentResult alignment; + final Embedding embedding; + final int fileId; + final String faceId; + + bool get isBlurry => blurValue < kLaplacianHardThreshold; + + const FaceResult({ + required this.detection, + required this.blurValue, + required this.alignment, + required this.embedding, + required this.fileId, + required this.faceId, + }); + + Map toJson() => { + 'detection': detection.toJson(), + 'blurValue': blurValue, + 'alignment': alignment.toJson(), + 'embedding': embedding, + 'fileId': fileId, + 'faceId': faceId, + }; + + static FaceResult fromJson(Map json) { + return FaceResult( + detection: FaceDetectionRelative.fromJson(json['detection']), + blurValue: json['blurValue'], + alignment: AlignmentResult.fromJson(json['alignment']), + embedding: Embedding.from(json['embedding']), + fileId: json['fileId'], + faceId: json['faceId'], + ); + } +} + +class FaceResultBuilder { + FaceDetectionRelative detection = + FaceDetectionRelative.defaultInitialization(); + double blurValue = 1000; + AlignmentResult alignment = AlignmentResult.empty(); + Embedding embedding = []; + int fileId = -1; + String faceId = ''; + + bool get isBlurry => blurValue < kLaplacianHardThreshold; + + FaceResultBuilder({ + required this.fileId, + required this.faceId, + }); + + FaceResultBuilder.fromFaceDetection( + FaceDetectionRelative faceDetection, { + required FaceMlResultBuilder resultBuilder, + }) { + fileId = resultBuilder.fileId; + faceId = faceDetection.toFaceID(fileID: resultBuilder.fileId); + detection = faceDetection; + } + + FaceResult build() { + 
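+    // Sanity checks: detections are stored in relative coordinates, so the
+    // box and keypoint values asserted below must not exceed 1.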
assert(detection.allKeypoints[0][0] <= 1); + assert(detection.box[0] <= 1); + return FaceResult( + detection: detection, + blurValue: blurValue, + alignment: alignment, + embedding: embedding, + fileId: fileId, + faceId: faceId, + ); + } +} + +int getFileIdFromFaceId(String faceId) { + return int.parse(faceId.split("_")[0]); +} diff --git a/mobile/lib/services/machine_learning/face_ml/face_ml_service.dart b/mobile/lib/services/machine_learning/face_ml/face_ml_service.dart new file mode 100644 index 000000000..0bfbfcd22 --- /dev/null +++ b/mobile/lib/services/machine_learning/face_ml/face_ml_service.dart @@ -0,0 +1,1374 @@ +import "dart:async"; +import "dart:developer" as dev show log; +import "dart:io" show File, Platform; +import "dart:isolate"; +import "dart:math" show min; +import "dart:typed_data" show Uint8List, Float32List, ByteData; +import "dart:ui" show Image; + +import "package:computer/computer.dart"; +import "package:dart_ui_isolate/dart_ui_isolate.dart"; +import "package:flutter/foundation.dart" show debugPrint, kDebugMode; +import "package:flutter_image_compress/flutter_image_compress.dart"; +import "package:logging/logging.dart"; +import "package:onnxruntime/onnxruntime.dart"; +import "package:package_info_plus/package_info_plus.dart"; +import "package:photos/core/configuration.dart"; +import "package:photos/core/event_bus.dart"; +import "package:photos/db/files_db.dart"; +import "package:photos/events/diff_sync_complete_event.dart"; +import "package:photos/events/machine_learning_control_event.dart"; +import "package:photos/events/people_changed_event.dart"; +import "package:photos/extensions/list.dart"; +import "package:photos/extensions/stop_watch.dart"; +import "package:photos/face/db.dart"; +import "package:photos/face/model/box.dart"; +import "package:photos/face/model/detection.dart" as face_detection; +import "package:photos/face/model/face.dart"; +import "package:photos/face/model/landmark.dart"; +import "package:photos/models/file/extensions/file_props.dart"; +import "package:photos/models/file/file.dart"; +import "package:photos/models/file/file_type.dart"; +import "package:photos/models/ml/ml_versions.dart"; +import "package:photos/service_locator.dart"; +import 'package:photos/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart'; +import "package:photos/services/machine_learning/face_ml/face_clustering/face_info_for_clustering.dart"; +import 'package:photos/services/machine_learning/face_ml/face_detection/detection.dart'; +import 'package:photos/services/machine_learning/face_ml/face_detection/face_detection_exceptions.dart'; +import 'package:photos/services/machine_learning/face_ml/face_detection/face_detection_service.dart'; +import 'package:photos/services/machine_learning/face_ml/face_embedding/face_embedding_exceptions.dart'; +import 'package:photos/services/machine_learning/face_ml/face_embedding/face_embedding_service.dart'; +import 'package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart'; +import 'package:photos/services/machine_learning/face_ml/face_ml_exceptions.dart'; +import 'package:photos/services/machine_learning/face_ml/face_ml_result.dart'; +import "package:photos/services/machine_learning/face_ml/person/person_service.dart"; +import 'package:photos/services/machine_learning/file_ml/file_ml.dart'; +import 'package:photos/services/machine_learning/file_ml/remote_fileml_service.dart'; +import "package:photos/services/search_service.dart"; +import 
"package:photos/utils/file_util.dart"; +import 'package:photos/utils/image_ml_isolate.dart'; +import "package:photos/utils/image_ml_util.dart"; +import "package:photos/utils/local_settings.dart"; +import "package:photos/utils/network_util.dart"; +import "package:photos/utils/thumbnail_util.dart"; +import "package:synchronized/synchronized.dart"; + +enum FileDataForML { thumbnailData, fileData, compressedFileData } + +enum FaceMlOperation { analyzeImage } + +/// This class is responsible for running the full face ml pipeline on images. +/// +/// WARNING: For getting the ML results needed for the UI, you should use `FaceSearchService` instead of this class! +/// +/// The pipeline consists of face detection, face alignment and face embedding. +class FaceMlService { + final _logger = Logger("FaceMlService"); + + // Flutter isolate things for running the image ml pipeline + Timer? _inactivityTimer; + final Duration _inactivityDuration = const Duration(seconds: 120); + int _activeTasks = 0; + final _initLockIsolate = Lock(); + late DartUiIsolate _isolate; + late ReceivePort _receivePort = ReceivePort(); + late SendPort _mainSendPort; + + bool isIsolateSpawned = false; + + // singleton pattern + FaceMlService._privateConstructor(); + + static final instance = FaceMlService._privateConstructor(); + + factory FaceMlService() => instance; + + final _initLock = Lock(); + final _functionLock = Lock(); + + final _computer = Computer.shared(); + + bool isInitialized = false; + late String client; + + bool canRunMLController = false; + bool isImageIndexRunning = false; + bool isClusteringRunning = false; + bool shouldSyncPeople = false; + + final int _fileDownloadLimit = 15; + final int _embeddingFetchLimit = 200; + + Future init({bool initializeImageMlIsolate = false}) async { + if (LocalSettings.instance.isFaceIndexingEnabled == false) { + return; + } + return _initLock.synchronized(() async { + if (isInitialized) { + return; + } + _logger.info("init called"); + await _computer.compute(initOrtEnv); + try { + await FaceDetectionService.instance.init(); + } catch (e, s) { + _logger.severe("Could not initialize yolo onnx", e, s); + } + if (initializeImageMlIsolate) { + try { + await ImageMlIsolate.instance.init(); + } catch (e, s) { + _logger.severe("Could not initialize image ml isolate", e, s); + } + } + try { + await FaceEmbeddingService.instance.init(); + } catch (e, s) { + _logger.severe("Could not initialize mobilefacenet", e, s); + } + + // Get client name + final packageInfo = await PackageInfo.fromPlatform(); + client = "${packageInfo.packageName}/${packageInfo.version}"; + _logger.info("client: $client"); + + isInitialized = true; + canRunMLController = !Platform.isAndroid || kDebugMode; + + /// hooking FaceML into [MachineLearningController] + if (Platform.isAndroid && !kDebugMode) { + Bus.instance.on().listen((event) { + if (LocalSettings.instance.isFaceIndexingEnabled == false) { + return; + } + canRunMLController = event.shouldRun; + if (canRunMLController) { + unawaited(indexAndClusterAll()); + } else { + pauseIndexing(); + } + }); + } else { + if (!kDebugMode) { + unawaited(indexAndClusterAll()); + } + } + }); + } + + static void initOrtEnv() async { + OrtEnv.instance.init(); + } + + void listenIndexOnDiffSync() { + Bus.instance.on().listen((event) async { + if (LocalSettings.instance.isFaceIndexingEnabled == false || kDebugMode) { + return; + } + // [neeraj] intentional delay in starting indexing on diff sync, this gives time for the user + // to disable face-indexing in case it's causing 
+
+  void listenIndexOnDiffSync() {
+    Bus.instance.on<DiffSyncCompleteEvent>().listen((event) async {
+      if (LocalSettings.instance.isFaceIndexingEnabled == false || kDebugMode) {
+        return;
+      }
+      // [neeraj] intentional delay in starting indexing on diff sync, this gives time for the user
+      // to disable face-indexing in case it's causing crashes. In the future, we
+      // should have a better way to handle this.
+      shouldSyncPeople = true;
+      Future.delayed(const Duration(seconds: 10), () {
+        unawaited(indexAndClusterAll());
+      });
+    });
+  }
+
+  void listenOnPeopleChangedSync() {
+    Bus.instance.on<PeopleChangedEvent>().listen((event) {
+      shouldSyncPeople = true;
+    });
+  }
+
+  Future<void> ensureInitialized() async {
+    if (!isInitialized) {
+      await init();
+    }
+  }
+
+  Future<void> release() async {
+    return _initLock.synchronized(() async {
+      _logger.info("dispose called");
+      if (!isInitialized) {
+        return;
+      }
+      try {
+        await FaceDetectionService.instance.release();
+      } catch (e, s) {
+        _logger.severe("Could not dispose yolo onnx", e, s);
+      }
+      try {
+        ImageMlIsolate.instance.dispose();
+      } catch (e, s) {
+        _logger.severe("Could not dispose image ml isolate", e, s);
+      }
+      try {
+        await FaceEmbeddingService.instance.release();
+      } catch (e, s) {
+        _logger.severe("Could not dispose mobilefacenet", e, s);
+      }
+      OrtEnv.instance.release();
+      isInitialized = false;
+    });
+  }
+
+  Future<void> initIsolate() async {
+    return _initLockIsolate.synchronized(() async {
+      if (isIsolateSpawned) return;
+      _logger.info("initIsolate called");
+
+      _receivePort = ReceivePort();
+
+      try {
+        _isolate = await DartUiIsolate.spawn(
+          _isolateMain,
+          _receivePort.sendPort,
+        );
+        _mainSendPort = await _receivePort.first as SendPort;
+        isIsolateSpawned = true;
+
+        _resetInactivityTimer();
+      } catch (e) {
+        _logger.severe('Could not spawn isolate', e);
+        isIsolateSpawned = false;
+      }
+    });
+  }
+
+  Future<void> ensureSpawnedIsolate() async {
+    if (!isIsolateSpawned) {
+      await initIsolate();
+    }
+  }
+
+  /// The main execution function of the isolate.
+  static void _isolateMain(SendPort mainSendPort) async {
+    final receivePort = ReceivePort();
+    mainSendPort.send(receivePort.sendPort);
+
+    receivePort.listen((message) async {
+      final functionIndex = message[0] as int;
+      final function = FaceMlOperation.values[functionIndex];
+      final args = message[1] as Map;
+      final sendPort = message[2] as SendPort;
+
+      try {
+        switch (function) {
+          case FaceMlOperation.analyzeImage:
+            final time = DateTime.now();
+            final FaceMlResult result =
+                await FaceMlService.analyzeImageSync(args);
+            dev.log(
+              "`analyzeImageSync` function executed in ${DateTime.now().difference(time).inMilliseconds} ms",
+            );
+            sendPort.send(result.toJsonString());
+            break;
+        }
+      } catch (e, stackTrace) {
+        dev.log(
+          "[SEVERE] Error in FaceML isolate: $e",
+          error: e,
+          stackTrace: stackTrace,
+        );
+        sendPort
+            .send({'error': e.toString(), 'stackTrace': stackTrace.toString()});
+      }
+    });
+  }
+
+  /// The common method to run any operation in the isolate. It sends the [message] to [_isolateMain] and waits for the result.
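+  ///
+  /// The message is a three-element list: the operation's enum index, the
+  /// argument map, and a SendPort for the reply. The isolate answers either
+  /// with the operation's result or with a map carrying 'error' and
+  /// 'stackTrace' keys, which is surfaced below by completing the future with
+  /// that error.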
+ Future _runInIsolate( + (FaceMlOperation, Map) message, + ) async { + await ensureSpawnedIsolate(); + return _functionLock.synchronized(() async { + _resetInactivityTimer(); + + if (isImageIndexRunning == false || canRunMLController == false) { + return null; + } + + final completer = Completer(); + final answerPort = ReceivePort(); + + _activeTasks++; + _mainSendPort.send([message.$1.index, message.$2, answerPort.sendPort]); + + answerPort.listen((receivedMessage) { + if (receivedMessage is Map && receivedMessage.containsKey('error')) { + // Handle the error + final errorMessage = receivedMessage['error']; + final errorStackTrace = receivedMessage['stackTrace']; + final exception = Exception(errorMessage); + final stackTrace = StackTrace.fromString(errorStackTrace); + completer.completeError(exception, stackTrace); + } else { + completer.complete(receivedMessage); + } + }); + _activeTasks--; + + return completer.future; + }); + } + + /// Resets a timer that kills the isolate after a certain amount of inactivity. + /// + /// Should be called after initialization (e.g. inside `init()`) and after every call to isolate (e.g. inside `_runInIsolate()`) + void _resetInactivityTimer() { + _inactivityTimer?.cancel(); + _inactivityTimer = Timer(_inactivityDuration, () { + if (_activeTasks > 0) { + _logger.info('Tasks are still running. Delaying isolate disposal.'); + // Optionally, reschedule the timer to check again later. + _resetInactivityTimer(); + } else { + _logger.info( + 'Clustering Isolate has been inactive for ${_inactivityDuration.inSeconds} seconds with no tasks running. Killing isolate.', + ); + disposeIsolate(); + } + }); + } + + void disposeIsolate() async { + if (!isIsolateSpawned) return; + await release(); + + isIsolateSpawned = false; + _isolate.kill(); + _receivePort.close(); + _inactivityTimer?.cancel(); + } + + Future indexAndClusterAll() async { + if (isClusteringRunning || isImageIndexRunning) { + _logger.info("indexing or clustering is already running, skipping"); + return; + } + if (shouldSyncPeople) { + await PersonService.instance.reconcileClusters(); + shouldSyncPeople = false; + } + await indexAllImages(); + final indexingCompleteRatio = await _getIndexedDoneRatio(); + if (indexingCompleteRatio < 0.95) { + _logger.info( + "Indexing is not far enough, skipping clustering. Indexing is at $indexingCompleteRatio", + ); + return; + } else { + await clusterAllImages(); + } + } + + Future clusterAllImages({ + double minFaceScore = kMinimumQualityFaceScore, + bool clusterInBuckets = true, + }) async { + if (!canRunMLController) { + _logger + .info("MLController does not allow running ML, skipping clustering"); + return; + } + if (isClusteringRunning) { + _logger.info("clusterAllImages is already running, skipping"); + return; + } + // verify faces is enabled + if (LocalSettings.instance.isFaceIndexingEnabled == false) { + _logger.warning("clustering is disabled by user"); + return; + } + + final indexingCompleteRatio = await _getIndexedDoneRatio(); + if (indexingCompleteRatio < 0.95) { + _logger.info( + "Indexing is not far enough, skipping clustering. 
Indexing is at $indexingCompleteRatio", + ); + return; + } + + _logger.info("`clusterAllImages()` called"); + isClusteringRunning = true; + final clusterAllImagesTime = DateTime.now(); + + try { + // Get a sense of the total number of faces in the database + final int totalFaces = await FaceMLDataDB.instance + .getTotalFaceCount(minFaceScore: minFaceScore); + final fileIDToCreationTime = + await FilesDB.instance.getFileIDToCreationTime(); + final startEmbeddingFetch = DateTime.now(); + // read all embeddings + final result = await FaceMLDataDB.instance.getFaceInfoForClustering( + minScore: minFaceScore, + maxFaces: totalFaces, + ); + final Set missingFileIDs = {}; + final allFaceInfoForClustering = []; + for (final faceInfo in result) { + if (!fileIDToCreationTime.containsKey(faceInfo.fileID)) { + missingFileIDs.add(faceInfo.fileID); + } else { + allFaceInfoForClustering.add(faceInfo); + } + } + // sort the embeddings based on file creation time, oldest first + allFaceInfoForClustering.sort((a, b) { + return fileIDToCreationTime[a.fileID]! + .compareTo(fileIDToCreationTime[b.fileID]!); + }); + _logger.info( + 'Getting and sorting embeddings took ${DateTime.now().difference(startEmbeddingFetch).inMilliseconds} ms for ${allFaceInfoForClustering.length} embeddings' + 'and ${missingFileIDs.length} missing fileIDs', + ); + + // Get the current cluster statistics + final Map oldClusterSummaries = + await FaceMLDataDB.instance.getAllClusterSummary(); + + if (clusterInBuckets) { + const int bucketSize = 20000; + const int offsetIncrement = 7500; + int offset = 0; + int bucket = 1; + + while (true) { + if (!canRunMLController) { + _logger.info( + "MLController does not allow running ML, stopping before clustering bucket $bucket", + ); + break; + } + if (offset > allFaceInfoForClustering.length - 1) { + _logger.warning( + 'faceIdToEmbeddingBucket is empty, this should ideally not happen as it should have stopped earlier. offset: $offset, totalFaces: $totalFaces', + ); + break; + } + if (offset > totalFaces) { + _logger.warning( + 'offset > totalFaces, this should ideally not happen. 
offset: $offset, totalFaces: $totalFaces',
+            );
+            break;
+          }
+
+          final bucketStartTime = DateTime.now();
+          final faceInfoForClustering = allFaceInfoForClustering.sublist(
+            offset,
+            min(offset + bucketSize, allFaceInfoForClustering.length),
+          );
+
+          final clusteringResult =
+              await FaceClusteringService.instance.predictLinear(
+            faceInfoForClustering.toSet(),
+            fileIDToCreationTime: fileIDToCreationTime,
+            offset: offset,
+            oldClusterSummaries: oldClusterSummaries,
+          );
+          if (clusteringResult == null) {
+            _logger.warning("faceIdToCluster is null");
+            return;
+          }
+
+          await FaceMLDataDB.instance
+              .updateFaceIdToClusterId(clusteringResult.newFaceIdToCluster);
+          await FaceMLDataDB.instance
+              .clusterSummaryUpdate(clusteringResult.newClusterSummaries!);
+          for (final faceInfo in faceInfoForClustering) {
+            faceInfo.clusterId ??=
+                clusteringResult.newFaceIdToCluster[faceInfo.faceID];
+          }
+          for (final clusterUpdate
+              in clusteringResult.newClusterSummaries!.entries) {
+            oldClusterSummaries[clusterUpdate.key] = clusterUpdate.value;
+          }
+          _logger.info(
+            'Done with clustering ${offset + faceInfoForClustering.length} embeddings (${(100 * (offset + faceInfoForClustering.length) / totalFaces).toStringAsFixed(0)}%) in bucket $bucket, offset: $offset, in ${DateTime.now().difference(bucketStartTime).inSeconds} seconds',
+          );
+          if (offset + bucketSize >= totalFaces) {
+            _logger.info('All faces clustered');
+            break;
+          }
+          offset += offsetIncrement;
+          bucket++;
+        }
+      } else {
+        final clusterStartTime = DateTime.now();
+        // Cluster the embeddings using the linear clustering algorithm, returning a map from faceID to clusterID
+        final clusteringResult =
+            await FaceClusteringService.instance.predictLinear(
+          allFaceInfoForClustering.toSet(),
+          fileIDToCreationTime: fileIDToCreationTime,
+          oldClusterSummaries: oldClusterSummaries,
+        );
+        if (clusteringResult == null) {
+          _logger.warning("faceIdToCluster is null");
+          return;
+        }
+        final clusterDoneTime = DateTime.now();
+        _logger.info(
+          'done with clustering ${allFaceInfoForClustering.length} in ${clusterDoneTime.difference(clusterStartTime).inSeconds} seconds ',
+        );
+
+        // Store the updated clusterIDs in the database
+        _logger.info(
+          'Updating ${clusteringResult.newFaceIdToCluster.length} FaceIDs with clusterIDs in the DB',
+        );
+        await FaceMLDataDB.instance
+            .updateFaceIdToClusterId(clusteringResult.newFaceIdToCluster);
+        await FaceMLDataDB.instance
+            .clusterSummaryUpdate(clusteringResult.newClusterSummaries!);
+        _logger.info('Done updating FaceIDs with clusterIDs in the DB, in '
+            '${DateTime.now().difference(clusterDoneTime).inSeconds} seconds');
+      }
+      Bus.instance.fire(PeopleChangedEvent());
+      _logger.info('clusterAllImages() finished, in '
+          '${DateTime.now().difference(clusterAllImagesTime).inSeconds} seconds');
+      isClusteringRunning = false;
+    } catch (e, s) {
+      _logger.severe("`clusterAllImages` failed", e, s);
+    }
+  }
+
+  /// Analyzes all the images in the database with the latest ml version and stores the results in the database.
+  ///
+  /// This function first checks if the image has already been analyzed with the latest faceMlVersion and stored in the database. If so, it skips the image.
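+  ///
+  /// If fetching remote embeddings fails, the function reschedules itself with
+  /// `retryFetchCount` doubled (the same value also serves as the retry delay
+  /// in seconds), giving up once the count reaches 1000.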
+  Future<void> indexAllImages({int retryFetchCount = 10}) async {
+    if (isImageIndexRunning) {
+      _logger.warning("indexAllImages is already running, skipping");
+      return;
+    }
+    // verify faces is enabled
+    if (LocalSettings.instance.isFaceIndexingEnabled == false) {
+      _logger.warning("indexing is disabled by user");
+      return;
+    }
+    try {
+      isImageIndexRunning = true;
+      _logger.info('starting image indexing');
+
+      final w = (kDebugMode ? EnteWatch('prepare indexing files') : null)
+        ?..start();
+      final Map<int, int> alreadyIndexedFiles =
+          await FaceMLDataDB.instance.getIndexedFileIds();
+      w?.log('getIndexedFileIds');
+      final List<EnteFile> enteFiles =
+          await SearchService.instance.getAllFiles();
+      w?.log('getAllFiles');
+
+      // Make sure the image conversion isolate is spawned
+      // await ImageMlIsolate.instance.ensureSpawned();
+      await ensureInitialized();
+
+      int fileAnalyzedCount = 0;
+      int fileSkippedCount = 0;
+      final stopwatch = Stopwatch()..start();
+      final List<EnteFile> filesWithLocalID = [];
+      final List<EnteFile> filesWithoutLocalID = [];
+      final List<EnteFile> hiddenFilesToIndex = [];
+      w?.log('getIndexableFileIDs');
+
+      for (final EnteFile enteFile in enteFiles) {
+        if (_skipAnalysisEnteFile(enteFile, alreadyIndexedFiles)) {
+          fileSkippedCount++;
+          continue;
+        }
+        if ((enteFile.localID ?? '').isEmpty) {
+          filesWithoutLocalID.add(enteFile);
+        } else {
+          filesWithLocalID.add(enteFile);
+        }
+      }
+      w?.log('sifting through all normal files');
+      final List<EnteFile> hiddenFiles =
+          await SearchService.instance.getHiddenFiles();
+      w?.log('getHiddenFiles: ${hiddenFiles.length} hidden files');
+      for (final EnteFile enteFile in hiddenFiles) {
+        if (_skipAnalysisEnteFile(enteFile, alreadyIndexedFiles)) {
+          fileSkippedCount++;
+          continue;
+        }
+        hiddenFilesToIndex.add(enteFile);
+      }
+
+      // list of files where files with localID are first
+      final sortedBylocalID = <EnteFile>[];
+      sortedBylocalID.addAll(filesWithLocalID);
+      sortedBylocalID.addAll(filesWithoutLocalID);
+      sortedBylocalID.addAll(hiddenFilesToIndex);
+      w?.log('preparing all files to index');
+      final List<List<EnteFile>> chunks =
+          sortedBylocalID.chunks(_embeddingFetchLimit);
+      outerLoop:
+      for (final chunk in chunks) {
+        final futures = <Future<bool>>[];
+
+        if (LocalSettings.instance.remoteFetchEnabled) {
+          try {
+            final List<int> fileIds = [];
+            // Try to find embeddings on the remote server
+            for (final f in chunk) {
+              fileIds.add(f.uploadedFileID!);
+            }
+            final EnteWatch? w =
+                flagService.internalUser ?
EnteWatch("face_em_fetch") : null; + w?.start(); + w?.log('starting remote fetch for ${fileIds.length} files'); + final res = + await RemoteFileMLService.instance.getFilessEmbedding(fileIds); + w?.logAndReset('fetched ${res.mlData.length} embeddings'); + final List faces = []; + final remoteFileIdToVersion = {}; + for (FileMl fileMl in res.mlData.values) { + if (shouldDiscardRemoteEmbedding(fileMl)) continue; + if (fileMl.faceEmbedding.faces.isEmpty) { + faces.add( + Face.empty( + fileMl.fileID, + ), + ); + } else { + for (final f in fileMl.faceEmbedding.faces) { + f.fileInfo = FileInfo( + imageHeight: fileMl.height, + imageWidth: fileMl.width, + ); + faces.add(f); + } + } + remoteFileIdToVersion[fileMl.fileID] = + fileMl.faceEmbedding.version; + } + if (res.noEmbeddingFileIDs.isNotEmpty) { + _logger.info( + 'No embeddings found for ${res.noEmbeddingFileIDs.length} files', + ); + for (final fileID in res.noEmbeddingFileIDs) { + faces.add(Face.empty(fileID, error: false)); + remoteFileIdToVersion[fileID] = faceMlVersion; + } + } + + await FaceMLDataDB.instance.bulkInsertFaces(faces); + w?.logAndReset('stored embeddings'); + for (final entry in remoteFileIdToVersion.entries) { + alreadyIndexedFiles[entry.key] = entry.value; + } + _logger + .info('already indexed files ${remoteFileIdToVersion.length}'); + } catch (e, s) { + _logger.severe("err while getting files embeddings", e, s); + if (retryFetchCount < 1000) { + Future.delayed(Duration(seconds: retryFetchCount), () { + unawaited(indexAllImages(retryFetchCount: retryFetchCount * 2)); + }); + return; + } else { + _logger.severe( + "Failed to fetch embeddings for files after multiple retries", + e, + s, + ); + rethrow; + } + } + } + if (!await canUseHighBandwidth()) { + continue; + } + final smallerChunks = chunk.chunks(_fileDownloadLimit); + for (final smallestChunk in smallerChunks) { + for (final enteFile in smallestChunk) { + if (isImageIndexRunning == false) { + _logger.info("indexAllImages() was paused, stopping"); + break outerLoop; + } + if (_skipAnalysisEnteFile( + enteFile, + alreadyIndexedFiles, + )) { + fileSkippedCount++; + continue; + } + futures.add(processImage(enteFile)); + } + final awaitedFutures = await Future.wait(futures); + final sumFutures = awaitedFutures.fold( + 0, + (previousValue, element) => previousValue + (element ? 1 : 0), + ); + fileAnalyzedCount += sumFutures; + } + } + + stopwatch.stop(); + _logger.info( + "`indexAllImages()` finished. Analyzed $fileAnalyzedCount images, in ${stopwatch.elapsed.inSeconds} seconds (avg of ${stopwatch.elapsed.inSeconds / fileAnalyzedCount} seconds per image, skipped $fileSkippedCount images. MLController status: $canRunMLController)", + ); + } catch (e, s) { + _logger.severe("indexAllImages failed", e, s); + } finally { + isImageIndexRunning = false; + } + } + + bool shouldDiscardRemoteEmbedding(FileMl fileMl) { + if (fileMl.faceEmbedding.version < faceMlVersion) { + debugPrint("Discarding remote embedding for fileID ${fileMl.fileID} " + "because version is ${fileMl.faceEmbedding.version} and we need $faceMlVersion"); + return true; + } + // are all landmarks equal? 
+ bool allLandmarksEqual = true; + if (fileMl.faceEmbedding.faces.isEmpty) { + debugPrint("No face for ${fileMl.fileID}"); + allLandmarksEqual = false; + } + for (final face in fileMl.faceEmbedding.faces) { + if (face.detection.landmarks.isEmpty) { + allLandmarksEqual = false; + break; + } + if (face.detection.landmarks + .any((landmark) => landmark.x != landmark.y)) { + allLandmarksEqual = false; + break; + } + } + if (allLandmarksEqual) { + debugPrint("Discarding remote embedding for fileID ${fileMl.fileID} " + "because landmarks are equal"); + debugPrint( + fileMl.faceEmbedding.faces + .map((e) => e.detection.landmarks.toString()) + .toList() + .toString(), + ); + return true; + } + if (fileMl.width == null || fileMl.height == null) { + debugPrint("Discarding remote embedding for fileID ${fileMl.fileID} " + "because width is null"); + return true; + } + return false; + } + + Future processImage(EnteFile enteFile) async { + _logger.info( + "`processImage` start processing image with uploadedFileID: ${enteFile.uploadedFileID}", + ); + + try { + final FaceMlResult? result = await analyzeImageInSingleIsolate( + enteFile, + // preferUsingThumbnailForEverything: false, + // disposeImageIsolateAfterUse: false, + ); + if (result == null) { + return false; + } + final List faces = []; + if (!result.hasFaces) { + debugPrint( + 'No faces detected for file with name:${enteFile.displayName}', + ); + faces.add( + Face.empty(result.fileId, error: result.errorOccured), + ); + } else { + if (result.decodedImageSize.width == -1 || + result.decodedImageSize.height == -1) { + _logger + .severe("decodedImageSize is not stored correctly for image with " + "ID: ${enteFile.uploadedFileID}"); + _logger.info( + "Using aligned image size for image with ID: ${enteFile.uploadedFileID}. 
This size is ${result.decodedImageSize.width}x${result.decodedImageSize.height} compared to size of ${enteFile.width}x${enteFile.height} in the metadata", + ); + } + for (int i = 0; i < result.faces.length; ++i) { + final FaceResult faceRes = result.faces[i]; + final detection = face_detection.Detection( + box: FaceBox( + xMin: faceRes.detection.xMinBox, + yMin: faceRes.detection.yMinBox, + width: faceRes.detection.width, + height: faceRes.detection.height, + ), + landmarks: faceRes.detection.allKeypoints + .map( + (keypoint) => Landmark( + x: keypoint[0], + y: keypoint[1], + ), + ) + .toList(), + ); + faces.add( + Face( + faceRes.faceId, + result.fileId, + faceRes.embedding, + faceRes.detection.score, + detection, + faceRes.blurValue, + fileInfo: FileInfo( + imageHeight: result.decodedImageSize.height, + imageWidth: result.decodedImageSize.width, + ), + ), + ); + } + } + _logger.info("inserting ${faces.length} faces for ${result.fileId}"); + if (!result.errorOccured) { + await RemoteFileMLService.instance.putFileEmbedding( + enteFile, + FileMl( + enteFile.uploadedFileID!, + FaceEmbeddings( + faces, + result.mlVersion, + client: client, + ), + height: result.decodedImageSize.height, + width: result.decodedImageSize.width, + ), + ); + } else { + _logger.warning( + 'Skipped putting embedding because of error ${result.toJsonString()}', + ); + } + await FaceMLDataDB.instance.bulkInsertFaces(faces); + return true; + } catch (e, s) { + _logger.severe( + "Failed to analyze using FaceML for image with ID: ${enteFile.uploadedFileID}", + e, + s, + ); + return true; + } + } + + void pauseIndexing() { + isImageIndexRunning = false; + } + + /// Analyzes the given image data by running the full pipeline for faces, using [analyzeImageSync] in the isolate. + Future analyzeImageInSingleIsolate(EnteFile enteFile) async { + _checkEnteFileForID(enteFile); + await ensureInitialized(); + + final String? filePath = + await _getImagePathForML(enteFile, typeOfData: FileDataForML.fileData); + + if (filePath == null) { + _logger.severe( + "Failed to get any data for enteFile with uploadedFileID ${enteFile.uploadedFileID}", + ); + throw CouldNotRetrieveAnyFileData(); + } + + final Stopwatch stopwatch = Stopwatch()..start(); + late FaceMlResult result; + + try { + final resultJsonString = await _runInIsolate( + ( + FaceMlOperation.analyzeImage, + { + "enteFileID": enteFile.uploadedFileID ?? 
-1, + "filePath": filePath, + "faceDetectionAddress": + FaceDetectionService.instance.sessionAddress, + "faceEmbeddingAddress": + FaceEmbeddingService.instance.sessionAddress, + } + ), + ) as String?; + if (resultJsonString == null) { + return null; + } + result = FaceMlResult.fromJsonString(resultJsonString); + } catch (e, s) { + _logger.severe( + "Could not analyze image with ID ${enteFile.uploadedFileID} \n", + e, + s, + ); + debugPrint( + "This image with ID ${enteFile.uploadedFileID} has name ${enteFile.displayName}.", + ); + final resultBuilder = FaceMlResultBuilder.fromEnteFile(enteFile); + return resultBuilder.buildErrorOccurred(); + } + stopwatch.stop(); + _logger.info( + "Finished Analyze image (${result.faces.length} faces) with uploadedFileID ${enteFile.uploadedFileID}, in " + "${stopwatch.elapsedMilliseconds} ms (including time waiting for inference engine availability)", + ); + + return result; + } + + static Future analyzeImageSync(Map args) async { + try { + final int enteFileID = args["enteFileID"] as int; + final String imagePath = args["filePath"] as String; + final int faceDetectionAddress = args["faceDetectionAddress"] as int; + final int faceEmbeddingAddress = args["faceEmbeddingAddress"] as int; + + final resultBuilder = FaceMlResultBuilder.fromEnteFileID(enteFileID); + + dev.log( + "Start analyzing image with uploadedFileID: $enteFileID inside the isolate", + ); + final stopwatchTotal = Stopwatch()..start(); + final stopwatch = Stopwatch()..start(); + + // Decode the image once to use for both face detection and alignment + final imageData = await File(imagePath).readAsBytes(); + final image = await decodeImageFromData(imageData); + final ByteData imgByteData = await getByteDataFromImage(image); + dev.log('Reading and decoding image took ' + '${stopwatch.elapsedMilliseconds} ms'); + stopwatch.reset(); + + // Get the faces + final List faceDetectionResult = + await FaceMlService.detectFacesSync( + image, + imgByteData, + faceDetectionAddress, + resultBuilder: resultBuilder, + ); + + dev.log( + "${faceDetectionResult.length} faces detected with scores ${faceDetectionResult.map((e) => e.score).toList()}: completed `detectFacesSync` function, in " + "${stopwatch.elapsedMilliseconds} ms"); + + // If no faces were detected, return a result with no faces. Otherwise, continue. 
+ if (faceDetectionResult.isEmpty) { + dev.log( + "No faceDetectionResult, Completed analyzing image with uploadedFileID $enteFileID, in " + "${stopwatch.elapsedMilliseconds} ms"); + return resultBuilder.buildNoFaceDetected(); + } + + stopwatch.reset(); + // Align the faces + final Float32List faceAlignmentResult = + await FaceMlService.alignFacesSync( + image, + imgByteData, + faceDetectionResult, + resultBuilder: resultBuilder, + ); + + dev.log("Completed `alignFacesSync` function, in " + "${stopwatch.elapsedMilliseconds} ms"); + + stopwatch.reset(); + // Get the embeddings of the faces + final embeddings = await FaceMlService.embedFacesSync( + faceAlignmentResult, + faceEmbeddingAddress, + resultBuilder: resultBuilder, + ); + + dev.log("Completed `embedFacesSync` function, in " + "${stopwatch.elapsedMilliseconds} ms"); + + stopwatch.stop(); + stopwatchTotal.stop(); + dev.log("Finished Analyze image (${embeddings.length} faces) with " + "uploadedFileID $enteFileID, in " + "${stopwatchTotal.elapsedMilliseconds} ms"); + + return resultBuilder.build(); + } catch (e, s) { + dev.log("Could not analyze image: \n e: $e \n s: $s"); + rethrow; + } + } + + Future _getImagePathForML( + EnteFile enteFile, { + FileDataForML typeOfData = FileDataForML.fileData, + }) async { + String? imagePath; + + switch (typeOfData) { + case FileDataForML.fileData: + final stopwatch = Stopwatch()..start(); + File? file; + if (enteFile.fileType == FileType.video) { + file = await getThumbnailForUploadedFile(enteFile); + } else { + file = await getFile(enteFile, isOrigin: true); + } + if (file == null) { + _logger.warning("Could not get file for $enteFile"); + imagePath = null; + break; + } + imagePath = file.path; + stopwatch.stop(); + _logger.info( + "Getting file data for uploadedFileID ${enteFile.uploadedFileID} took ${stopwatch.elapsedMilliseconds} ms", + ); + break; + + case FileDataForML.thumbnailData: + final stopwatch = Stopwatch()..start(); + final File? thumbnail = await getThumbnailForUploadedFile(enteFile); + if (thumbnail == null) { + _logger.warning("Could not get thumbnail for $enteFile"); + imagePath = null; + break; + } + imagePath = thumbnail.path; + stopwatch.stop(); + _logger.info( + "Getting thumbnail data for uploadedFileID ${enteFile.uploadedFileID} took ${stopwatch.elapsedMilliseconds} ms", + ); + break; + + case FileDataForML.compressedFileData: + _logger.warning( + "Getting compressed file data for uploadedFileID ${enteFile.uploadedFileID} is not implemented yet", + ); + imagePath = null; + break; + } + + return imagePath; + } + + @Deprecated('Deprecated in favor of `_getImagePathForML`') + Future _getDataForML( + EnteFile enteFile, { + FileDataForML typeOfData = FileDataForML.fileData, + }) async { + Uint8List? data; + + switch (typeOfData) { + case FileDataForML.fileData: + final stopwatch = Stopwatch()..start(); + final File? 
actualIoFile = await getFile(enteFile, isOrigin: true);
+        if (actualIoFile != null) {
+          data = await actualIoFile.readAsBytes();
+        }
+        stopwatch.stop();
+        _logger.info(
+          "Getting file data for uploadedFileID ${enteFile.uploadedFileID} took ${stopwatch.elapsedMilliseconds} ms",
+        );
+
+        break;
+
+      case FileDataForML.thumbnailData:
+        final stopwatch = Stopwatch()..start();
+        data = await getThumbnail(enteFile);
+        stopwatch.stop();
+        _logger.info(
+          "Getting thumbnail data for uploadedFileID ${enteFile.uploadedFileID} took ${stopwatch.elapsedMilliseconds} ms",
+        );
+        break;
+
+      case FileDataForML.compressedFileData:
+        final stopwatch = Stopwatch()..start();
+        final String tempPath = Configuration.instance.getTempDirectory() +
+            "${enteFile.uploadedFileID!}";
+        final File? actualIoFile = await getFile(enteFile);
+        if (actualIoFile != null) {
+          final compressResult = await FlutterImageCompress.compressAndGetFile(
+            actualIoFile.path,
+            tempPath + ".jpg",
+          );
+          if (compressResult != null) {
+            data = await compressResult.readAsBytes();
+          }
+        }
+        stopwatch.stop();
+        _logger.info(
+          "Getting compressed file data for uploadedFileID ${enteFile.uploadedFileID} took ${stopwatch.elapsedMilliseconds} ms",
+        );
+        break;
+    }
+
+    return data;
+  }
+
+  /// Detects faces in the given image data.
+  ///
+  /// `imageData`: The image data to analyze.
+  ///
+  /// Returns a list of face detection results.
+  ///
+  /// Throws [CouldNotInitializeFaceDetector], [CouldNotRunFaceDetector] or [GeneralFaceMlException] if something goes wrong.
+  Future<List<FaceDetectionRelative>> _detectFacesIsolate(
+    String imagePath,
+    // Uint8List fileData,
+    {
+    FaceMlResultBuilder? resultBuilder,
+  }) async {
+    try {
+      // Get the bounding boxes of the faces
+      final (List<FaceDetectionRelative> faces, dataSize) =
+          await FaceDetectionService.instance.predictInComputer(imagePath);
+
+      // Add detected faces to the resultBuilder
+      if (resultBuilder != null) {
+        resultBuilder.addNewlyDetectedFaces(faces, dataSize);
+      }
+
+      return faces;
+    } on YOLOFaceInterpreterInitializationException {
+      throw CouldNotInitializeFaceDetector();
+    } on YOLOFaceInterpreterRunException {
+      throw CouldNotRunFaceDetector();
+    } catch (e) {
+      _logger.severe('Face detection failed: $e');
+      throw GeneralFaceMlException('Face detection failed: $e');
+    }
+  }
+
+  /// Detects faces in the given image data.
+  ///
+  /// `imageData`: The image data to analyze.
+  ///
+  /// Returns a list of face detection results.
+  ///
+  /// Throws [CouldNotInitializeFaceDetector], [CouldNotRunFaceDetector] or [GeneralFaceMlException] if something goes wrong.
+  static Future<List<FaceDetectionRelative>> detectFacesSync(
+    Image image,
+    ByteData imageByteData,
+    int interpreterAddress, {
+    FaceMlResultBuilder? resultBuilder,
+  }) async {
+    try {
+      // Get the bounding boxes of the faces
+      final (List<FaceDetectionRelative> faces, dataSize) =
+          await FaceDetectionService.predictSync(
+        image,
+        imageByteData,
+        interpreterAddress,
+      );
+
+      // Add detected faces to the resultBuilder
+      if (resultBuilder != null) {
+        resultBuilder.addNewlyDetectedFaces(faces, dataSize);
+      }
+
+      return faces;
+    } on YOLOFaceInterpreterInitializationException {
+      throw CouldNotInitializeFaceDetector();
+    } on YOLOFaceInterpreterRunException {
+      throw CouldNotRunFaceDetector();
+    } catch (e) {
+      dev.log('[SEVERE] Face detection failed: $e');
+      throw GeneralFaceMlException('Face detection failed: $e');
+    }
+  }
+
+  /// Aligns multiple faces from the given image data.
+  ///
+  /// `imageData`: The image data in [Uint8List] that contains the faces.
+  /// `faces`: The face detection results in a list of [FaceDetectionAbsolute] for the faces to align.
+  ///
+  /// Returns a list of the aligned faces as image data.
+  ///
+  /// Throws [CouldNotWarpAffine] or [GeneralFaceMlException] if the face alignment fails.
+  Future<Float32List> _alignFaces(
+    String imagePath,
+    List<FaceDetectionRelative> faces, {
+    FaceMlResultBuilder? resultBuilder,
+  }) async {
+    try {
+      final (alignedFaces, alignmentResults, _, blurValues, _) =
+          await ImageMlIsolate.instance
+              .preprocessMobileFaceNetOnnx(imagePath, faces);
+
+      if (resultBuilder != null) {
+        resultBuilder.addAlignmentResults(
+          alignmentResults,
+          blurValues,
+        );
+      }
+
+      return alignedFaces;
+    } catch (e, s) {
+      _logger.severe('Face alignment failed: $e', e, s);
+      throw CouldNotWarpAffine();
+    }
+  }
+
+  /// Aligns multiple faces from the given image data.
+  ///
+  /// `imageData`: The image data in [Uint8List] that contains the faces.
+  /// `faces`: The face detection results in a list of [FaceDetectionAbsolute] for the faces to align.
+  ///
+  /// Returns a list of the aligned faces as image data.
+  ///
+  /// Throws [CouldNotWarpAffine] or [GeneralFaceMlException] if the face alignment fails.
+  static Future<Float32List> alignFacesSync(
+    Image image,
+    ByteData imageByteData,
+    List<FaceDetectionRelative> faces, {
+    FaceMlResultBuilder? resultBuilder,
+  }) async {
+    try {
+      final stopwatch = Stopwatch()..start();
+      final (alignedFaces, alignmentResults, _, blurValues, _) =
+          await preprocessToMobileFaceNetFloat32List(
+        image,
+        imageByteData,
+        faces,
+      );
+      stopwatch.stop();
+      dev.log(
+        "Face alignment image decoding and processing took ${stopwatch.elapsedMilliseconds} ms",
+      );
+
+      if (resultBuilder != null) {
+        resultBuilder.addAlignmentResults(
+          alignmentResults,
+          blurValues,
+        );
+      }
+
+      return alignedFaces;
+    } catch (e, s) {
+      dev.log('[SEVERE] Face alignment failed: $e $s');
+      throw CouldNotWarpAffine();
+    }
+  }
+
+  /// Embeds multiple faces from the given input matrices.
+  ///
+  /// `facesMatrices`: The input matrices of the faces to embed.
+  ///
+  /// Returns a list of the face embeddings as lists of doubles.
+  ///
+  /// Throws [CouldNotInitializeFaceEmbeddor], [CouldNotRunFaceEmbeddor], [InputProblemFaceEmbeddor] or [GeneralFaceMlException] if the face embedding fails.
+  Future<List<List<double>>> _embedFaces(
+    Float32List facesList, {
+    FaceMlResultBuilder? resultBuilder,
+  }) async {
+    try {
+      // Get the embedding of the faces
+      final List<List<double>> embeddings =
+          await FaceEmbeddingService.instance.predictInComputer(facesList);
+
+      // Add the embeddings to the resultBuilder
+      if (resultBuilder != null) {
+        resultBuilder.addEmbeddingsToExistingFaces(embeddings);
+      }
+
+      return embeddings;
+    } on MobileFaceNetInterpreterInitializationException {
+      throw CouldNotInitializeFaceEmbeddor();
+    } on MobileFaceNetInterpreterRunException {
+      throw CouldNotRunFaceEmbeddor();
+    } on MobileFaceNetEmptyInput {
+      throw InputProblemFaceEmbeddor("Input is empty");
+    } on MobileFaceNetWrongInputSize {
+      throw InputProblemFaceEmbeddor("Input size is wrong");
+    } on MobileFaceNetWrongInputRange {
+      throw InputProblemFaceEmbeddor("Input range is wrong");
+      // ignore: avoid_catches_without_on_clauses
+    } catch (e) {
+      _logger.severe('Face embedding (batch) failed: $e');
+      throw GeneralFaceMlException('Face embedding (batch) failed: $e');
+    }
+  }
+
+  static Future<List<List<double>>> embedFacesSync(
+    Float32List facesList,
+    int interpreterAddress, {
+    FaceMlResultBuilder? resultBuilder,
+  }) async {
+    try {
+      // Get the embedding of the faces
+      final List<List<double>> embeddings =
+          await FaceEmbeddingService.predictSync(facesList, interpreterAddress);
+
+      // Add the embeddings to the resultBuilder
+      if (resultBuilder != null) {
+        resultBuilder.addEmbeddingsToExistingFaces(embeddings);
+      }
+
+      return embeddings;
+    } on MobileFaceNetInterpreterInitializationException {
+      throw CouldNotInitializeFaceEmbeddor();
+    } on MobileFaceNetInterpreterRunException {
+      throw CouldNotRunFaceEmbeddor();
+    } on MobileFaceNetEmptyInput {
+      throw InputProblemFaceEmbeddor("Input is empty");
+    } on MobileFaceNetWrongInputSize {
+      throw InputProblemFaceEmbeddor("Input size is wrong");
+    } on MobileFaceNetWrongInputRange {
+      throw InputProblemFaceEmbeddor("Input range is wrong");
+      // ignore: avoid_catches_without_on_clauses
+    } catch (e) {
+      dev.log('[SEVERE] Face embedding (batch) failed: $e');
+      throw GeneralFaceMlException('Face embedding (batch) failed: $e');
+    }
+  }
+
+  /// Checks if the ente file to be analyzed actually can be analyzed: it must be uploaded and in the correct format.
+  void _checkEnteFileForID(EnteFile enteFile) {
+    if (_skipAnalysisEnteFile(enteFile, {})) {
+      _logger.warning(
+        '''Skipped analysis of image with enteFile, it might be the wrong format or has no uploadedFileID, or MLController doesn't allow it to run.
+        enteFile: ${enteFile.toString()}
+        isImageIndexRunning: $isImageIndexRunning
+        canRunML: $canRunMLController
+        ''',
+      );
+      throw CouldNotRetrieveAnyFileData();
+    }
+  }
+
+  Future<double> _getIndexedDoneRatio() async {
+    final w = (kDebugMode ? EnteWatch('_getIndexedDoneRatio') : null)?..start();
+
+    final int alreadyIndexedCount = await FaceMLDataDB.instance
+        .getIndexedFileCount(minimumMlVersion: faceMlVersion);
+    final int totalIndexableCount = (await getIndexableFileIDs()).length;
+    final ratio = alreadyIndexedCount / totalIndexableCount;
+
+    w?.log('getIndexedDoneRatio');
+
+    return ratio;
+  }
+
+  static Future<List<int>> getIndexableFileIDs() async {
+    return FilesDB.instance
+        .getOwnedFileIDs(Configuration.instance.getUserID()!);
+  }
+
+  bool _skipAnalysisEnteFile(EnteFile enteFile, Map<int, int> indexedFileIds) {
+    if (isImageIndexRunning == false || canRunMLController == false) {
+      return true;
+    }
+    // Skip if the file is not uploaded or not owned by the user
+    if (!enteFile.isUploaded || enteFile.isOwner == false) {
+      return true;
+    }
+    // I don't know how motionPhotos and livePhotos work, so I'm also just skipping them for now
+    if (enteFile.fileType == FileType.other) {
+      return true;
+    }
+    // Skip if the file is already analyzed with the latest ml version
+    final id = enteFile.uploadedFileID!;
+
+    return indexedFileIds.containsKey(id) &&
+        indexedFileIds[id]! >= faceMlVersion;
+  }
+}
diff --git a/mobile/lib/services/machine_learning/face_ml/face_ml_version.dart b/mobile/lib/services/machine_learning/face_ml/face_ml_version.dart
new file mode 100644
index 000000000..a91c4c843
--- /dev/null
+++ b/mobile/lib/services/machine_learning/face_ml/face_ml_version.dart
@@ -0,0 +1,15 @@
+abstract class VersionedMethod {
+  final String method;
+  final int version;
+
+  VersionedMethod(this.method, [this.version = 0]);
+
+  const VersionedMethod.empty()
+      : method = 'Empty method',
+        version = 0;
+
+  Map<String, dynamic> toJson() => {
+        'method': method,
+        'version': version,
+      };
+}
diff --git a/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart
new file mode 100644
index 000000000..8567e8868
--- /dev/null
+++ b/mobile/lib/services/machine_learning/face_ml/feedback/cluster_feedback.dart
@@ -0,0 +1,1226 @@
+import 'dart:developer' as dev;
+import "dart:math" show Random, min;
+
+import "package:computer/computer.dart";
+import "package:flutter/foundation.dart";
+import "package:logging/logging.dart";
+import "package:ml_linalg/linalg.dart";
+import "package:photos/core/event_bus.dart";
+import "package:photos/db/files_db.dart";
+import "package:photos/events/people_changed_event.dart";
+import "package:photos/extensions/stop_watch.dart";
+import "package:photos/face/db.dart";
+import "package:photos/face/model/person.dart";
+import "package:photos/generated/protos/ente/common/vector.pb.dart";
+import "package:photos/models/file/file.dart";
+import "package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart";
+import "package:photos/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart";
+import "package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart";
+import "package:photos/services/machine_learning/face_ml/face_ml_result.dart";
+import "package:photos/services/machine_learning/face_ml/person/person_service.dart";
+import "package:photos/services/search_service.dart";
+
+class ClusterSuggestion {
+  final int clusterIDToMerge;
+  final double distancePersonToCluster;
+  final bool usedOnlyMeanForSuggestion;
+  final List<EnteFile> filesInCluster;
+  final List<String> faceIDsInCluster;
+
+  ClusterSuggestion(
+    this.clusterIDToMerge,
+    this.distancePersonToCluster,
+    this.usedOnlyMeanForSuggestion,
+    this.filesInCluster,
+    this.faceIDsInCluster,
+  );
+}
+
+class ClusterFeedbackService {
+  final Logger _logger = Logger("ClusterFeedbackService");
+  final _computer = Computer.shared();
+  ClusterFeedbackService._privateConstructor();
+
+  static final ClusterFeedbackService instance =
+      ClusterFeedbackService._privateConstructor();
+
+  static int lastViewedClusterID = -1;
+  static setLastViewedClusterID(int clusterID) {
+    lastViewedClusterID = clusterID;
+  }
+
+  static resetLastViewedClusterID() {
+    lastViewedClusterID = -1;
+  }
+
+  /// Returns a list of cluster suggestions for a person. Each suggestion is a tuple of the following elements:
+  /// 1. clusterID: the ID of the cluster
+  /// 2. distance: the distance between the person's cluster and the suggestion
+  /// 3. bool: whether the suggestion was found using the mean (true) or the median (false)
+  /// 4. List<EnteFile>: the files in the cluster
+  Future<List<ClusterSuggestion>> getSuggestionForPerson(
+    PersonEntity person, {
+    bool extremeFilesFirst = true,
+  }) async {
+    _logger.info(
+      'getSuggestionForPerson ${kDebugMode ? person.data.name : person.remoteID}',
+    );
+
+    try {
+      // Get the suggestions for the person using centroids and median
+      final startTime = DateTime.now();
+      final List<(int, double, bool)> foundSuggestions =
+          await _getSuggestions(person);
+      final findSuggestionsTime = DateTime.now();
+      _logger.info(
+        'getSuggestionForPerson `_getSuggestions`: Found ${foundSuggestions.length} suggestions in ${findSuggestionsTime.difference(startTime).inMilliseconds} ms',
+      );
+
+      // Get the files for the suggestions
+      final suggestionClusterIDs = foundSuggestions.map((e) => e.$1).toSet();
+      final Map<int, Set<int>> fileIdToClusterID =
+          await FaceMLDataDB.instance.getFileIdToClusterIDSetForCluster(
+        suggestionClusterIDs,
+      );
+      final clusterIdToFaceIDs =
+          await FaceMLDataDB.instance.getClusterToFaceIDs(suggestionClusterIDs);
+      final Map<int, List<EnteFile>> clusterIDToFiles = {};
+      final allFiles = await SearchService.instance.getAllFiles();
+      for (final f in allFiles) {
+        if (!fileIdToClusterID.containsKey(f.uploadedFileID ?? -1)) {
+          continue;
+        }
+        final clusterIds = fileIdToClusterID[f.uploadedFileID ?? -1]!;
+        for (final cluster in clusterIds) {
+          if (clusterIDToFiles.containsKey(cluster)) {
+            clusterIDToFiles[cluster]!.add(f);
+          } else {
+            clusterIDToFiles[cluster] = [f];
+          }
+        }
+      }
+
+      final List<ClusterSuggestion> finalSuggestions = [];
+      for (final clusterSuggestion in foundSuggestions) {
+        if (clusterIDToFiles.containsKey(clusterSuggestion.$1)) {
+          finalSuggestions.add(
+            ClusterSuggestion(
+              clusterSuggestion.$1,
+              clusterSuggestion.$2,
+              clusterSuggestion.$3,
+              clusterIDToFiles[clusterSuggestion.$1]!,
+              clusterIdToFaceIDs[clusterSuggestion.$1]!.toList(),
+            ),
+          );
+        }
+      }
+      final getFilesTime = DateTime.now();
+
+      final sortingStartTime = DateTime.now();
+      if (extremeFilesFirst) {
+        await _sortSuggestionsOnDistanceToPerson(person, finalSuggestions);
+      }
+      _logger.info(
+        'getSuggestionForPerson post-processing suggestions took ${DateTime.now().difference(findSuggestionsTime).inMilliseconds} ms, of which sorting took ${DateTime.now().difference(sortingStartTime).inMilliseconds} ms and getting files took ${getFilesTime.difference(findSuggestionsTime).inMilliseconds} ms',
+      );
+
+      return finalSuggestions;
+    } catch (e, s) {
+      _logger.severe("Error in getClusterFilesForPersonID", e, s);
+      rethrow;
+    }
+  }
+
+  Future<void> removeFilesFromPerson(
+    List<EnteFile> files,
+    PersonEntity p,
+  ) async {
+    try {
+      // Get the relevant faces to be removed
+      final faceIDs = await FaceMLDataDB.instance
+          .getFaceIDsForPerson(p.remoteID)
+          .then((iterable) => iterable.toList());
+      faceIDs.retainWhere((faceID) {
+        final fileID = getFileIdFromFaceId(faceID);
+        return files.any((file) => file.uploadedFileID == fileID);
+      });
+      final embeddings =
+          await FaceMLDataDB.instance.getFaceEmbeddingMapForFaces(faceIDs);
+
+      final fileIDToCreationTime =
+          await FilesDB.instance.getFileIDToCreationTime();
+
+      // Re-cluster within the deleted faces
+      final clusterResult =
+          await FaceClusteringService.instance.predictWithinClusterComputer(
+        embeddings,
+        fileIDToCreationTime: fileIDToCreationTime,
+        distanceThreshold: 0.20,
+      );
+      if (clusterResult == null || clusterResult.isEmpty) {
+        return;
+      }
+      final newFaceIdToClusterID = clusterResult.newFaceIdToCluster;
+
+      // Update the deleted faces
+      await FaceMLDataDB.instance.forceUpdateClusterIds(newFaceIdToClusterID);
+      await FaceMLDataDB.instance
+          .clusterSummaryUpdate(clusterResult.newClusterSummaries!);
+
+      // Make sure the deleted faces don't get suggested in the future
+      final notClusterIdToPersonId = <int, String>{};
+      for (final clusterId in
newFaceIdToClusterID.values.toSet()) { + notClusterIdToPersonId[clusterId] = p.remoteID; + } + await FaceMLDataDB.instance + .bulkCaptureNotPersonFeedback(notClusterIdToPersonId); + + Bus.instance.fire(PeopleChangedEvent()); + return; + } catch (e, s) { + _logger.severe("Error in removeFilesFromPerson", e, s); + rethrow; + } + } + + Future removeFilesFromCluster( + List files, + int clusterID, + ) async { + try { + // Get the relevant faces to be removed + final faceIDs = await FaceMLDataDB.instance + .getFaceIDsForCluster(clusterID) + .then((iterable) => iterable.toList()); + faceIDs.retainWhere((faceID) { + final fileID = getFileIdFromFaceId(faceID); + return files.any((file) => file.uploadedFileID == fileID); + }); + final embeddings = + await FaceMLDataDB.instance.getFaceEmbeddingMapForFaces(faceIDs); + + final fileIDToCreationTime = + await FilesDB.instance.getFileIDToCreationTime(); + + // Re-cluster within the deleted faces + final clusterResult = + await FaceClusteringService.instance.predictWithinClusterComputer( + embeddings, + fileIDToCreationTime: fileIDToCreationTime, + distanceThreshold: 0.20, + ); + if (clusterResult == null || clusterResult.isEmpty) { + return; + } + final newFaceIdToClusterID = clusterResult.newFaceIdToCluster; + + // Update the deleted faces + await FaceMLDataDB.instance.forceUpdateClusterIds(newFaceIdToClusterID); + await FaceMLDataDB.instance + .clusterSummaryUpdate(clusterResult.newClusterSummaries!); + + Bus.instance.fire( + PeopleChangedEvent( + relevantFiles: files, + type: PeopleEventType.removedFilesFromCluster, + source: "$clusterID", + ), + ); + // Bus.instance.fire( + // LocalPhotosUpdatedEvent( + // files, + // type: EventType.peopleClusterChanged, + // source: "$clusterID", + // ), + // ); + return; + } catch (e, s) { + _logger.severe("Error in removeFilesFromCluster", e, s); + rethrow; + } + } + + Future addFilesToCluster(List faceIDs, int clusterID) async { + await FaceMLDataDB.instance.addFacesToCluster(faceIDs, clusterID); + Bus.instance.fire(PeopleChangedEvent()); + return; + } + + Future checkAndDoAutomaticMerges( + PersonEntity p, { + required int personClusterID, + }) async { + final faceMlDb = FaceMLDataDB.instance; + final faceIDs = await faceMlDb.getFaceIDsForCluster(personClusterID); + final ignoredClusters = await faceMlDb.getPersonIgnoredClusters(p.remoteID); + if (faceIDs.length < 2 * kMinimumClusterSizeSearchResult) { + final fileIDs = faceIDs.map(getFileIdFromFaceId).toSet(); + if (fileIDs.length < kMinimumClusterSizeSearchResult) { + _logger.info( + 'Cluster $personClusterID has less than $kMinimumClusterSizeSearchResult faces, not doing automatic merges', + ); + return false; + } + } + final allClusterIdsToCountMap = (await faceMlDb.clusterIdToFaceCount()); + _logger.info( + '${kDebugMode ? 
p.data.name : "private"} has existing clusterID $personClusterID, checking if we can automatically merge more', + ); + + // Get and update the cluster summary to get the avg (centroid) and count + final EnteWatch watch = EnteWatch("ClusterFeedbackService")..start(); + final Map clusterAvg = await _getUpdateClusterAvg( + allClusterIdsToCountMap, + ignoredClusters, + minClusterSize: kMinimumClusterSizeSearchResult, + ); + watch.log('computed avg for ${clusterAvg.length} clusters'); + + // Find the actual closest clusters for the person + final List<(int, double)> suggestions = await calcSuggestionsMeanInComputer( + clusterAvg, + {personClusterID}, + ignoredClusters, + 0.24, + ); + + if (suggestions.isEmpty) { + _logger.info( + 'No automatic merge suggestions for ${kDebugMode ? p.data.name : "private"}', + ); + return false; + } + + // log suggestions + _logger.info( + 'suggestions for ${kDebugMode ? p.data.name : "private"} for cluster ID ${p.remoteID} are suggestions $suggestions}', + ); + + for (final suggestion in suggestions) { + final clusterID = suggestion.$1; + await FaceMLDataDB.instance.assignClusterToPerson( + personID: p.remoteID, + clusterID: clusterID, + ); + } + + Bus.instance.fire(PeopleChangedEvent()); + + return true; + } + + Future ignoreCluster(int clusterID) async { + await PersonService.instance.addPerson('', clusterID); + Bus.instance.fire(PeopleChangedEvent()); + return; + } + + Future> checkForMixedClusters() async { + final faceMlDb = FaceMLDataDB.instance; + final allClusterToFaceCount = await faceMlDb.clusterIdToFaceCount(); + final clustersToInspect = []; + for (final clusterID in allClusterToFaceCount.keys) { + if (allClusterToFaceCount[clusterID]! > 20 && + allClusterToFaceCount[clusterID]! < 500) { + clustersToInspect.add(clusterID); + } + } + + final fileIDToCreationTime = + await FilesDB.instance.getFileIDToCreationTime(); + + final susClusters = <(int, int)>[]; + + final inspectionStart = DateTime.now(); + for (final clusterID in clustersToInspect) { + final int originalClusterSize = allClusterToFaceCount[clusterID]!; + final faceIDs = await faceMlDb.getFaceIDsForCluster(clusterID); + + final embeddings = await faceMlDb.getFaceEmbeddingMapForFaces(faceIDs); + + final clusterResult = + await FaceClusteringService.instance.predictWithinClusterComputer( + embeddings, + fileIDToCreationTime: fileIDToCreationTime, + distanceThreshold: 0.22, + ); + + if (clusterResult == null || + clusterResult.newClusterIdToFaceIds == null || + clusterResult.isEmpty) { + _logger.warning( + '[CheckMixedClusters] Clustering did not seem to work for cluster $clusterID of size ${allClusterToFaceCount[clusterID]}', + ); + continue; + } + + final newClusterIdToCount = + clusterResult.newClusterIdToFaceIds!.map((key, value) { + return MapEntry(key, value.length); + }); + final amountOfNewClusters = newClusterIdToCount.length; + + _logger.info( + '[CheckMixedClusters] Broke up cluster $clusterID into $amountOfNewClusters clusters \n ${newClusterIdToCount.toString()}', + ); + + // Now find the sizes of the biggest and second biggest cluster + final int biggestClusterID = newClusterIdToCount.keys.reduce((a, b) { + return newClusterIdToCount[a]! > newClusterIdToCount[b]! ? 
a : b; + }); + final int biggestSize = newClusterIdToCount[biggestClusterID]!; + final biggestRatio = biggestSize / originalClusterSize; + if (newClusterIdToCount.length > 1) { + final List clusterIDs = newClusterIdToCount.keys.toList(); + clusterIDs.remove(biggestClusterID); + final int secondBiggestClusterID = clusterIDs.reduce((a, b) { + return newClusterIdToCount[a]! > newClusterIdToCount[b]! ? a : b; + }); + final int secondBiggestSize = + newClusterIdToCount[secondBiggestClusterID]!; + final secondBiggestRatio = secondBiggestSize / originalClusterSize; + + if (biggestRatio < 0.5 || secondBiggestRatio > 0.2) { + final faceIdsOfCluster = + await faceMlDb.getFaceIDsForCluster(clusterID); + final uniqueFileIDs = + faceIdsOfCluster.map(getFileIdFromFaceId).toSet(); + susClusters.add((clusterID, uniqueFileIDs.length)); + _logger.info( + '[CheckMixedClusters] Detected that cluster $clusterID with size ${uniqueFileIDs.length} might be mixed', + ); + } + } else { + _logger.info( + '[CheckMixedClusters] For cluster $clusterID we only found one cluster after reclustering', + ); + } + } + _logger.info( + '[CheckMixedClusters] Inspection took ${DateTime.now().difference(inspectionStart).inSeconds} seconds', + ); + if (susClusters.isNotEmpty) { + _logger.info( + '[CheckMixedClusters] Found ${susClusters.length} clusters that might be mixed: $susClusters', + ); + } else { + _logger.info('[CheckMixedClusters] No mixed clusters found'); + } + return susClusters; + } + + // TODO: iterate over this method to find sweet spot + Future breakUpCluster( + int clusterID, { + bool useDbscan = false, + }) async { + _logger.info( + 'breakUpCluster called for cluster $clusterID with dbscan $useDbscan', + ); + final faceMlDb = FaceMLDataDB.instance; + + final faceIDs = await faceMlDb.getFaceIDsForCluster(clusterID); + final originalFaceIDsSet = faceIDs.toSet(); + + final embeddings = await faceMlDb.getFaceEmbeddingMapForFaces(faceIDs); + + final fileIDToCreationTime = + await FilesDB.instance.getFileIDToCreationTime(); + + final clusterResult = + await FaceClusteringService.instance.predictWithinClusterComputer( + embeddings, + fileIDToCreationTime: fileIDToCreationTime, + distanceThreshold: 0.22, + ); + + if (clusterResult == null || clusterResult.newClusterIdToFaceIds == null || clusterResult.isEmpty) { + _logger.warning('No clusters found or something went wrong'); + return ClusteringResult(newFaceIdToCluster: {}); + } + + final clusterIdToCount = + clusterResult.newClusterIdToFaceIds!.map((key, value) { + return MapEntry(key, value.length); + }); + final amountOfNewClusters = clusterIdToCount.length; + + _logger.info( + 'Broke up cluster $clusterID into $amountOfNewClusters clusters \n ${clusterIdToCount.toString()}', + ); + + if (kDebugMode) { + final Set allClusteredFaceIDsSet = {}; + for (final List value + in clusterResult.newClusterIdToFaceIds!.values) { + allClusteredFaceIDsSet.addAll(value); + } + assert((originalFaceIDsSet.difference(allClusteredFaceIDsSet)).isEmpty); + } + + return clusterResult; + } + + /// WARNING: this method is purely for debugging purposes, never use in production + Future createFakeClustersByBlurValue() async { + try { + // Delete old clusters + await FaceMLDataDB.instance.dropClustersAndPersonTable(); + final List persons = + await PersonService.instance.getPersons(); + for (final PersonEntity p in persons) { + await PersonService.instance.deletePerson(p.remoteID); + } + + // Create new fake clusters based on blur value. 
One for values between 0 and 10, one for 10-20, etc till 200 + final int startClusterID = DateTime.now().microsecondsSinceEpoch; + final faceIDsToBlurValues = + await FaceMLDataDB.instance.getFaceIDsToBlurValues(200); + final faceIdToCluster = {}; + for (final entry in faceIDsToBlurValues.entries) { + final faceID = entry.key; + final blurValue = entry.value; + final newClusterID = startClusterID + blurValue ~/ 10; + faceIdToCluster[faceID] = newClusterID; + } + await FaceMLDataDB.instance.updateFaceIdToClusterId(faceIdToCluster); + + Bus.instance.fire(PeopleChangedEvent()); + } catch (e, s) { + _logger.severe("Error in createFakeClustersByBlurValue", e, s); + rethrow; + } + } + + Future debugLogClusterBlurValues( + int clusterID, { + int? clusterSize, + bool logClusterSummary = false, + bool logBlurValues = false, + }) async { + if (!kDebugMode) return; + + // Logging the clusterID + _logger.info( + "Debug logging for cluster $clusterID${clusterSize != null ? ' with $clusterSize photos' : ''}", + ); + const int biggestClusterID = 1715061228725148; + + // Logging the cluster summary for the cluster + if (logClusterSummary) { + final summaryMap = await FaceMLDataDB.instance.getClusterToClusterSummary( + [clusterID, biggestClusterID], + ); + final summary = summaryMap[clusterID]; + if (summary != null) { + _logger.info( + "Cluster summary for cluster $clusterID says the amount of faces is: ${summary.$2}", + ); + } + + final biggestClusterSummary = summaryMap[biggestClusterID]; + final clusterSummary = summaryMap[clusterID]; + if (biggestClusterSummary != null && clusterSummary != null) { + _logger.info( + "Cluster summary for biggest cluster $biggestClusterID says the size is: ${biggestClusterSummary.$2}", + ); + _logger.info( + "Cluster summary for current cluster $clusterID says the size is: ${clusterSummary.$2}", + ); + + // Mean distance + final biggestMean = Vector.fromList( + EVector.fromBuffer(biggestClusterSummary.$1).values, + dtype: DType.float32, + ); + final currentMean = Vector.fromList( + EVector.fromBuffer(clusterSummary.$1).values, + dtype: DType.float32, + ); + final bigClustersMeanDistance = + cosineDistanceSIMD(biggestMean, currentMean); + _logger.info( + "Mean distance between biggest cluster and current cluster: $bigClustersMeanDistance", + ); + _logger.info( + 'Element differences between the two means are ${biggestMean - currentMean}', + ); + final currentL2Norm = currentMean.norm(); + _logger.info( + 'L2 norm of current mean: $currentL2Norm', + ); + final trueDistance = + biggestMean.distanceTo(currentMean, distance: Distance.cosine); + _logger.info('True distance between the two means: $trueDistance'); + + // Median distance + const sampleSize = 100; + final Iterable biggestEmbeddings = await FaceMLDataDB + .instance + .getFaceEmbeddingsForCluster(biggestClusterID); + final List biggestSampledEmbeddingsProto = + _randomSampleWithoutReplacement( + biggestEmbeddings, + sampleSize, + ); + final List biggestSampledEmbeddings = + biggestSampledEmbeddingsProto + .map( + (embedding) => Vector.fromList( + EVector.fromBuffer(embedding).values, + dtype: DType.float32, + ), + ) + .toList(growable: false); + + final Iterable currentEmbeddings = + await FaceMLDataDB.instance.getFaceEmbeddingsForCluster(clusterID); + final List currentSampledEmbeddingsProto = + _randomSampleWithoutReplacement( + currentEmbeddings, + sampleSize, + ); + final List currentSampledEmbeddings = + currentSampledEmbeddingsProto + .map( + (embedding) => Vector.fromList( + 
EVector.fromBuffer(embedding).values,
+                  dtype: DType.float32,
+                ),
+              )
+              .toList(growable: false);
+
+      // Calculate distances and find the median
+      final List<double> distances = [];
+      final List<double> trueDistances = [];
+      for (final biggestEmbedding in biggestSampledEmbeddings) {
+        for (final currentEmbedding in currentSampledEmbeddings) {
+          distances
+              .add(cosineDistanceSIMD(biggestEmbedding, currentEmbedding));
+          trueDistances.add(
+            biggestEmbedding.distanceTo(
+              currentEmbedding,
+              distance: Distance.cosine,
+            ),
+          );
+        }
+      }
+      distances.sort();
+      trueDistances.sort();
+      final double medianDistance = distances[distances.length ~/ 2];
+      final double trueMedianDistance =
+          trueDistances[trueDistances.length ~/ 2];
+      _logger.info(
+        "Median distance between biggest cluster and current cluster: $medianDistance (using sample of $sampleSize)",
+      );
+      _logger.info(
+        'True distance median between the two embeddings: $trueMedianDistance',
+      );
+    }
+  }
+
+    // Logging the blur values for the cluster
+    if (logBlurValues) {
+      final List<double> blurValues = await FaceMLDataDB.instance
+          .getBlurValuesForCluster(clusterID)
+          .then((value) => value.toList());
+      final blurValuesIntegers =
+          blurValues.map((value) => value.round()).toList();
+      blurValuesIntegers.sort();
+      _logger.info(
+        "Blur values for cluster $clusterID${clusterSize != null ? ' with $clusterSize photos' : ''}: $blurValuesIntegers",
+      );
+    }
+
+    return;
+  }
+
+  /// Returns a list of suggestions. For each suggestion we return a record consisting of the following elements:
+  /// 1. clusterID: the ID of the cluster
+  /// 2. distance: the distance between the person's cluster and the suggestion
+  /// 3. usedMean: whether the suggestion was found using the mean (true) or the median (false)
+  Future<List<(int, double, bool)>> _getSuggestions(
+    PersonEntity p, {
+    int sampleSize = 50,
+    double maxMedianDistance = 0.62,
+    double goodMedianDistance = 0.55,
+    double maxMeanDistance = 0.65,
+    double goodMeanDistance = 0.45,
+  }) async {
+    final w = (kDebugMode ? EnteWatch('getSuggestions') : null)?..start();
+    // Get all the cluster data
+    final faceMlDb = FaceMLDataDB.instance;
+    final allClusterIdsToCountMap = await faceMlDb.clusterIdToFaceCount();
+    final ignoredClusters = await faceMlDb.getPersonIgnoredClusters(p.remoteID);
+    final personClusters = await faceMlDb.getPersonClusterIDs(p.remoteID);
+    final personFaceIDs =
+        await FaceMLDataDB.instance.getFaceIDsForPerson(p.remoteID);
+    final personFileIDs = personFaceIDs.map(getFileIdFromFaceId).toSet();
+    w?.log(
+      '${p.data.name} has ${personClusters.length} existing clusters, getting all database data done',
+    );
+    final allClusterIdToFaceIDs =
+        await FaceMLDataDB.instance.getAllClusterIdToFaceIDs();
+    w?.log('getAllClusterIdToFaceIDs done');
+
+    // First only do a simple check on the big clusters, if the person does not have small clusters yet
+    final smallestPersonClusterSize = personClusters
+        .map((clusterID) => allClusterIdsToCountMap[clusterID] ??
0) + .reduce((value, element) => min(value, element)); + final checkSizes = [100, 20, kMinimumClusterSizeSearchResult, 10, 5, 1]; + late Map clusterAvgBigClusters; + final List<(int, double)> suggestionsMean = []; + for (final minimumSize in checkSizes.toSet()) { + if (smallestPersonClusterSize >= + min(minimumSize, kMinimumClusterSizeSearchResult)) { + clusterAvgBigClusters = await _getUpdateClusterAvg( + allClusterIdsToCountMap, + ignoredClusters, + minClusterSize: minimumSize, + ); + w?.log( + 'Calculate avg for ${clusterAvgBigClusters.length} clusters of min size $minimumSize', + ); + final List<(int, double)> suggestionsMeanBigClusters = + await calcSuggestionsMeanInComputer( + clusterAvgBigClusters, + personClusters, + ignoredClusters, + (minimumSize == 100) ? goodMeanDistance + 0.15 : goodMeanDistance, + ); + w?.log( + 'Calculate suggestions using mean for ${clusterAvgBigClusters.length} clusters of min size $minimumSize', + ); + for (final suggestion in suggestionsMeanBigClusters) { + // Skip suggestions that have a high overlap with the person's files + final suggestionSet = allClusterIdToFaceIDs[suggestion.$1]! + .map((faceID) => getFileIdFromFaceId(faceID)) + .toSet(); + final overlap = personFileIDs.intersection(suggestionSet); + if (overlap.isNotEmpty && + ((overlap.length / suggestionSet.length) > 0.5)) { + await FaceMLDataDB.instance.captureNotPersonFeedback( + personID: p.remoteID, + clusterID: suggestion.$1, + ); + continue; + } + suggestionsMean.add(suggestion); + } + if (suggestionsMean.isNotEmpty) { + return suggestionsMean + .map((e) => (e.$1, e.$2, true)) + .toList(growable: false); + } + } + } + w?.reset(); + + // Find the other cluster candidates based on the median + final clusterAvg = clusterAvgBigClusters; + final List<(int, double)> moreSuggestionsMean = + await calcSuggestionsMeanInComputer( + clusterAvg, + personClusters, + ignoredClusters, + maxMeanDistance, + ); + if (moreSuggestionsMean.isEmpty) { + _logger + .info("No suggestions found using mean, even with higher threshold"); + return []; + } + + moreSuggestionsMean.sort((a, b) => a.$2.compareTo(b.$2)); + final otherClusterIdsCandidates = moreSuggestionsMean + .map( + (e) => e.$1, + ) + .toList(growable: false); + _logger.info( + "Found potential suggestions from loose mean for median test: $otherClusterIdsCandidates", + ); + + w?.logAndReset("Starting median test"); + // Take the embeddings from the person's clusters in one big list and sample from it + final List personEmbeddingsProto = []; + for (final clusterID in personClusters) { + final Iterable embeddings = + await FaceMLDataDB.instance.getFaceEmbeddingsForCluster(clusterID); + personEmbeddingsProto.addAll(embeddings); + } + final List sampledEmbeddingsProto = + _randomSampleWithoutReplacement( + personEmbeddingsProto, + sampleSize, + ); + final List sampledEmbeddings = sampledEmbeddingsProto + .map( + (embedding) => Vector.fromList( + EVector.fromBuffer(embedding).values, + dtype: DType.float32, + ), + ) + .toList(growable: false); + + // Find the actual closest clusters for the person using median + final List<(int, double)> suggestionsMedian = []; + final List<(int, double)> greatSuggestionsMedian = []; + double minMedianDistance = maxMedianDistance; + for (final otherClusterId in otherClusterIdsCandidates) { + final Iterable otherEmbeddingsProto = + await FaceMLDataDB.instance.getFaceEmbeddingsForCluster( + otherClusterId, + ); + final sampledOtherEmbeddingsProto = _randomSampleWithoutReplacement( + otherEmbeddingsProto, + sampleSize, + ); + 
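+      // Deserialize the sampled candidate embeddings into float32 vectors and
+      // take the median of the pairwise cosine distances below; the median is
+      // less sensitive to outlier faces in a cluster than the mean test above.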
+      final List<Vector> sampledOtherEmbeddings = sampledOtherEmbeddingsProto
+          .map(
+            (embedding) => Vector.fromList(
+              EVector.fromBuffer(embedding).values,
+              dtype: DType.float32,
+            ),
+          )
+          .toList(growable: false);
+
+      // Calculate distances and find the median
+      final List<double> distances = [];
+      for (final otherEmbedding in sampledOtherEmbeddings) {
+        for (final embedding in sampledEmbeddings) {
+          distances.add(cosineDistanceSIMD(embedding, otherEmbedding));
+        }
+      }
+      distances.sort();
+      final double medianDistance = distances[distances.length ~/ 2];
+      if (medianDistance < minMedianDistance) {
+        suggestionsMedian.add((otherClusterId, medianDistance));
+        minMedianDistance = medianDistance;
+        if (medianDistance < goodMedianDistance) {
+          greatSuggestionsMedian.add((otherClusterId, medianDistance));
+          break;
+        }
+      }
+    }
+    w?.log("Finished median test");
+    if (suggestionsMedian.isEmpty) {
+      _logger.info("No suggestions found using median");
+      return [];
+    } else {
+      _logger.info("Found suggestions using median: $suggestionsMedian");
+    }
+
+    final List<(int, double, bool)> finalSuggestionsMedian = suggestionsMedian
+        .map(((e) => (e.$1, e.$2, false)))
+        .toList(growable: false)
+        .reversed
+        .toList(growable: false);
+
+    if (greatSuggestionsMedian.isNotEmpty) {
+      _logger.info(
+        "Found great suggestion using median: $greatSuggestionsMedian",
+      );
+      // // Return the largest size cluster by using allClusterIdsToCountMap
+      // final List<int> greatSuggestionsMedianClusterIds =
+      //     greatSuggestionsMedian.map((e) => e.$1).toList(growable: false);
+      // greatSuggestionsMedianClusterIds.sort(
+      //   (a, b) =>
+      //       allClusterIdsToCountMap[b]!.compareTo(allClusterIdsToCountMap[a]!),
+      // );
+
+      // return [greatSuggestionsMedian.last.$1, ...finalSuggestionsMedian];
+    }
+
+    return finalSuggestionsMedian;
+  }
+
+  Future<Map<int, Vector>> _getUpdateClusterAvg(
+    Map<int, int> allClusterIdsToCountMap,
+    Set<int> ignoredClusters, {
+    int minClusterSize = 1,
+    int maxClusterInCurrentRun = 500,
+    int maxEmbeddingToRead = 10000,
+  }) async {
+    final w = (kDebugMode ?
EnteWatch('_getUpdateClusterAvg') : null)?..start(); + final startTime = DateTime.now(); + final faceMlDb = FaceMLDataDB.instance; + _logger.info( + 'start getUpdateClusterAvg for ${allClusterIdsToCountMap.length} clusters, minClusterSize $minClusterSize, maxClusterInCurrentRun $maxClusterInCurrentRun', + ); + + final Map clusterToSummary = + await faceMlDb.getAllClusterSummary(minClusterSize); + final Map updatesForClusterSummary = {}; + + w?.log( + 'getUpdateClusterAvg database call for getAllClusterSummary', + ); + + final serializationEmbeddings = await _computer.compute( + checkAndSerializeCurrentClusterMeans, + param: { + 'allClusterIdsToCountMap': allClusterIdsToCountMap, + 'minClusterSize': minClusterSize, + 'ignoredClusters': ignoredClusters, + 'clusterToSummary': clusterToSummary, + }, + ) as (Map, Set, int, int, int); + final clusterAvg = serializationEmbeddings.$1; + final allClusterIds = serializationEmbeddings.$2; + final ignoredClustersCnt = serializationEmbeddings.$3; + final alreadyUpdatedClustersCnt = serializationEmbeddings.$4; + final smallerClustersCnt = serializationEmbeddings.$5; + + // Assert that all existing clusterAvg are normalized + for (final avg in clusterAvg.values) { + assert((avg.norm() - 1.0).abs() < 1e-5); + } + + w?.log( + 'serialization of embeddings', + ); + _logger.info( + 'Ignored $ignoredClustersCnt clusters, already updated $alreadyUpdatedClustersCnt clusters, $smallerClustersCnt clusters are smaller than $minClusterSize', + ); + + if (allClusterIds.isEmpty) { + _logger.info( + 'No clusters to update, getUpdateClusterAvg done in ${DateTime.now().difference(startTime).inMilliseconds} ms', + ); + return clusterAvg; + } + + // get clusterIDs sorted by count in descending order + final sortedClusterIDs = allClusterIds.toList(); + sortedClusterIDs.sort( + (a, b) => + allClusterIdsToCountMap[b]!.compareTo(allClusterIdsToCountMap[a]!), + ); + int indexedInCurrentRun = 0; + w?.reset(); + + int currentPendingRead = 0; + final List clusterIdsToRead = []; + for (final clusterID in sortedClusterIDs) { + if (maxClusterInCurrentRun-- <= 0) { + break; + } + if (currentPendingRead == 0) { + currentPendingRead = allClusterIdsToCountMap[clusterID] ?? 0; + clusterIdsToRead.add(clusterID); + } else { + if ((currentPendingRead + allClusterIdsToCountMap[clusterID]!) 
< + maxEmbeddingToRead) { + clusterIdsToRead.add(clusterID); + currentPendingRead += allClusterIdsToCountMap[clusterID]!; + } else { + break; + } + } + } + + final Map> clusterEmbeddings = await FaceMLDataDB + .instance + .getFaceEmbeddingsForClusters(clusterIdsToRead); + + w?.logAndReset( + 'read $currentPendingRead embeddings for ${clusterEmbeddings.length} clusters', + ); + + for (final clusterID in clusterEmbeddings.keys) { + final Iterable embeddings = clusterEmbeddings[clusterID]!; + final Iterable vectors = embeddings.map( + (e) => Vector.fromList( + EVector.fromBuffer(e).values, + dtype: DType.float32, + ), + ); + final avg = vectors.reduce((a, b) => a + b) / vectors.length; + final avgNormalized = avg / avg.norm(); + final avgEmbeddingBuffer = EVector(values: avgNormalized).writeToBuffer(); + updatesForClusterSummary[clusterID] = + (avgEmbeddingBuffer, embeddings.length); + // store the intermediate updates + indexedInCurrentRun++; + if (updatesForClusterSummary.length > 100) { + await faceMlDb.clusterSummaryUpdate(updatesForClusterSummary); + updatesForClusterSummary.clear(); + if (kDebugMode) { + _logger.info( + 'getUpdateClusterAvg $indexedInCurrentRun clusters in current one', + ); + } + } + clusterAvg[clusterID] = avgNormalized; + } + if (updatesForClusterSummary.isNotEmpty) { + await faceMlDb.clusterSummaryUpdate(updatesForClusterSummary); + } + w?.logAndReset('done computing avg '); + _logger.info( + 'end getUpdateClusterAvg for ${clusterAvg.length} clusters, done in ${DateTime.now().difference(startTime).inMilliseconds} ms', + ); + + return clusterAvg; + } + + Future> calcSuggestionsMeanInComputer( + Map clusterAvg, + Set personClusters, + Set ignoredClusters, + double maxClusterDistance, + ) async { + return await _computer.compute( + _calcSuggestionsMean, + param: { + 'clusterAvg': clusterAvg, + 'personClusters': personClusters, + 'ignoredClusters': ignoredClusters, + 'maxClusterDistance': maxClusterDistance, + }, + ); + } + + List _randomSampleWithoutReplacement( + Iterable embeddings, + int sampleSize, + ) { + final random = Random(); + + if (sampleSize >= embeddings.length) { + return embeddings.toList(); + } + + // If sampleSize is more than half the list size, shuffle and take first sampleSize elements + if (sampleSize > embeddings.length / 2) { + final List shuffled = List.from(embeddings)..shuffle(random); + return shuffled.take(sampleSize).toList(growable: false); + } + + // Otherwise, use the set-based method for efficiency + final selectedIndices = {}; + final sampledEmbeddings = []; + while (sampledEmbeddings.length < sampleSize) { + final int index = random.nextInt(embeddings.length); + if (!selectedIndices.contains(index)) { + selectedIndices.add(index); + sampledEmbeddings.add(embeddings.elementAt(index)); + } + } + + return sampledEmbeddings; + } + + Future _sortSuggestionsOnDistanceToPerson( + PersonEntity person, + List suggestions, { + bool onlySortBigSuggestions = true, + }) async { + if (suggestions.isEmpty) { + debugPrint('No suggestions to sort'); + return; + } + if (onlySortBigSuggestions) { + final bigSuggestions = suggestions + .where( + (s) => s.filesInCluster.length > kMinimumClusterSizeSearchResult, + ) + .toList(); + if (bigSuggestions.isEmpty) { + debugPrint('No big suggestions to sort'); + return; + } + } + final startTime = DateTime.now(); + final faceMlDb = FaceMLDataDB.instance; + + // Get the cluster averages for the person's clusters and the suggestions' clusters + final personClusters = await 
faceMlDb.getPersonClusterIDs(person.remoteID); + final Map personClusterToSummary = + await faceMlDb.getClusterToClusterSummary(personClusters); + final clusterSummaryCallTime = DateTime.now(); + + // Calculate the avg embedding of the person + final w = (kDebugMode ? EnteWatch('sortSuggestions') : null)?..start(); + final personEmbeddingsCount = personClusters + .map((e) => personClusterToSummary[e]!.$2) + .reduce((a, b) => a + b); + Vector personAvg = Vector.filled(192, 0); + for (final personClusterID in personClusters) { + final personClusterBlob = personClusterToSummary[personClusterID]!.$1; + final personClusterAvg = Vector.fromList( + EVector.fromBuffer(personClusterBlob).values, + dtype: DType.float32, + ); + final clusterWeight = + personClusterToSummary[personClusterID]!.$2 / personEmbeddingsCount; + personAvg += personClusterAvg * clusterWeight; + } + w?.log('calculated person avg'); + + // Sort the suggestions based on the distance to the person + for (final suggestion in suggestions) { + if (onlySortBigSuggestions) { + if (suggestion.filesInCluster.length <= 8) { + continue; + } + } + final clusterID = suggestion.clusterIDToMerge; + final faceIDs = suggestion.faceIDsInCluster; + final faceIdToEmbeddingMap = await faceMlDb.getFaceEmbeddingMapForFaces( + faceIDs, + ); + final faceIdToVectorMap = faceIdToEmbeddingMap.map( + (key, value) => MapEntry( + key, + Vector.fromList( + EVector.fromBuffer(value).values, + dtype: DType.float32, + ), + ), + ); + w?.log( + 'got ${faceIdToEmbeddingMap.values.length} embeddings for ${suggestion.filesInCluster.length} files for cluster $clusterID', + ); + final fileIdToDistanceMap = {}; + for (final entry in faceIdToVectorMap.entries) { + fileIdToDistanceMap[getFileIdFromFaceId(entry.key)] = + cosineDistanceSIMD(personAvg, entry.value); + } + w?.log('calculated distances for cluster $clusterID'); + suggestion.filesInCluster.sort((b, a) { + //todo: review with @laurens, added this to avoid null safety issue + final double distanceA = fileIdToDistanceMap[a.uploadedFileID!] ?? -1; + final double distanceB = fileIdToDistanceMap[b.uploadedFileID!] ?? -1; + return distanceA.compareTo(distanceB); + }); + w?.log('sorted files for cluster $clusterID'); + + debugPrint( + "[${_logger.name}] Sorted suggestions for cluster $clusterID based on distance to person: ${suggestion.filesInCluster.map((e) => fileIdToDistanceMap[e.uploadedFileID]).toList()}", + ); + } + + final endTime = DateTime.now(); + _logger.info( + "Sorting suggestions based on distance to person took ${endTime.difference(startTime).inMilliseconds} ms for ${suggestions.length} suggestions, of which ${clusterSummaryCallTime.difference(startTime).inMilliseconds} ms was spent on the cluster summary call", + ); + } +} + +/// Returns a map of person's clusterID to map of closest clusterID to with disstance +List<(int, double)> _calcSuggestionsMean(Map args) { + // Fill in args + final Map clusterAvg = args['clusterAvg']; + final Set personClusters = args['personClusters']; + final Set ignoredClusters = args['ignoredClusters']; + final double maxClusterDistance = args['maxClusterDistance']; + + final Map> suggestions = {}; + const suggestionMax = 2000; + int suggestionCount = 0; + int comparisons = 0; + final w = (kDebugMode ? 
EnteWatch('getSuggestions') : null)?..start(); + + // ignore the clusters that belong to the person or is ignored + Set otherClusters = clusterAvg.keys.toSet().difference(personClusters); + otherClusters = otherClusters.difference(ignoredClusters); + + for (final otherClusterID in otherClusters) { + final Vector? otherAvg = clusterAvg[otherClusterID]; + if (otherAvg == null) { + dev.log('[WARNING] no avg for othercluster $otherClusterID'); + continue; + } + int? nearestPersonCluster; + double? minDistance; + for (final personCluster in personClusters) { + if (clusterAvg[personCluster] == null) { + dev.log('[WARNING] no avg for personcluster $personCluster'); + continue; + } + final Vector avg = clusterAvg[personCluster]!; + final distance = cosineDistanceSIMD(avg, otherAvg); + comparisons++; + if (distance < maxClusterDistance) { + if (minDistance == null || distance < minDistance) { + minDistance = distance; + nearestPersonCluster = personCluster; + } + } + } + if (nearestPersonCluster != null && minDistance != null) { + suggestions + .putIfAbsent(nearestPersonCluster, () => []) + .add((otherClusterID, minDistance)); + suggestionCount++; + } + if (suggestionCount >= suggestionMax) { + break; + } + } + w?.log( + 'calculation inside calcSuggestionsMean for ${personClusters.length} person clusters and ${otherClusters.length} other clusters (so ${personClusters.length * otherClusters.length} combinations, $comparisons comparisons made resulted in $suggestionCount suggestions)', + ); + + if (suggestions.isNotEmpty) { + final List<(int, double)> suggestClusterIds = []; + for (final List<(int, double)> suggestion in suggestions.values) { + suggestClusterIds.addAll(suggestion); + } + suggestClusterIds.sort( + (a, b) => a.$2.compareTo(b.$2), + ); // sort by distance + + dev.log( + "Already found ${suggestClusterIds.length} good suggestions using mean", + ); + return suggestClusterIds.sublist(0, min(suggestClusterIds.length, 20)); + } else { + dev.log("No suggestions found using mean"); + return <(int, double)>[]; + } +} + +Future<(Map, Set, int, int, int)> + checkAndSerializeCurrentClusterMeans( + Map args, +) async { + final Map allClusterIdsToCountMap = args['allClusterIdsToCountMap']; + final int minClusterSize = args['minClusterSize'] ?? 1; + final Set ignoredClusters = args['ignoredClusters'] ?? {}; + final Map clusterToSummary = args['clusterToSummary']; + + final Map clusterAvg = {}; + + final allClusterIds = allClusterIdsToCountMap.keys.toSet(); + int ignoredClustersCnt = 0, alreadyUpdatedClustersCnt = 0; + int smallerClustersCnt = 0; + for (final id in allClusterIdsToCountMap.keys) { + if (ignoredClusters.contains(id)) { + allClusterIds.remove(id); + ignoredClustersCnt++; + } + if (clusterToSummary[id]?.$2 == allClusterIdsToCountMap[id]) { + allClusterIds.remove(id); + clusterAvg[id] = Vector.fromList( + EVector.fromBuffer(clusterToSummary[id]!.$1).values, + dtype: DType.float32, + ); + alreadyUpdatedClustersCnt++; + } + if (allClusterIdsToCountMap[id]! 
+
+Future<(Map<int, Vector>, Set<int>, int, int, int)>
+    checkAndSerializeCurrentClusterMeans(
+  Map<String, dynamic> args,
+) async {
+  final Map<int, int> allClusterIdsToCountMap = args['allClusterIdsToCountMap'];
+  final int minClusterSize = args['minClusterSize'] ?? 1;
+  final Set<int> ignoredClusters = args['ignoredClusters'] ?? {};
+  final Map<int, (Uint8List, int)> clusterToSummary = args['clusterToSummary'];
+
+  final Map<int, Vector> clusterAvg = {};
+
+  final allClusterIds = allClusterIdsToCountMap.keys.toSet();
+  int ignoredClustersCnt = 0, alreadyUpdatedClustersCnt = 0;
+  int smallerClustersCnt = 0;
+  for (final id in allClusterIdsToCountMap.keys) {
+    if (ignoredClusters.contains(id)) {
+      allClusterIds.remove(id);
+      ignoredClustersCnt++;
+    }
+    if (clusterToSummary[id]?.$2 == allClusterIdsToCountMap[id]) {
+      allClusterIds.remove(id);
+      clusterAvg[id] = Vector.fromList(
+        EVector.fromBuffer(clusterToSummary[id]!.$1).values,
+        dtype: DType.float32,
+      );
+      alreadyUpdatedClustersCnt++;
+    }
+    if (allClusterIdsToCountMap[id]! < minClusterSize) {
+      allClusterIds.remove(id);
+      smallerClustersCnt++;
+    }
+  }
+
+  return (
+    clusterAvg,
+    allClusterIds,
+    ignoredClustersCnt,
+    alreadyUpdatedClustersCnt,
+    smallerClustersCnt
+  );
+}
diff --git a/mobile/lib/services/machine_learning/face_ml/person/person_service.dart b/mobile/lib/services/machine_learning/face_ml/person/person_service.dart
new file mode 100644
index 000000000..7517d057d
--- /dev/null
+++ b/mobile/lib/services/machine_learning/face_ml/person/person_service.dart
@@ -0,0 +1,294 @@
+import "dart:convert";
+import "dart:developer";
+
+import "package:flutter/foundation.dart";
+import "package:logging/logging.dart";
+import "package:photos/core/event_bus.dart";
+import "package:photos/events/people_changed_event.dart";
+import "package:photos/extensions/stop_watch.dart";
+import "package:photos/face/db.dart";
+import "package:photos/face/model/person.dart";
+import "package:photos/models/api/entity/type.dart";
+import "package:photos/services/entity_service.dart";
+import "package:shared_preferences/shared_preferences.dart";
+
+class PersonService {
+  final EntityService entityService;
+  final FaceMLDataDB faceMLDataDB;
+  final SharedPreferences prefs;
+  PersonService(this.entityService, this.faceMLDataDB, this.prefs);
+
+  static PersonService? _instance;
+  static PersonService get instance {
+    if (_instance == null) {
+      throw Exception("PersonService not initialized");
+    }
+    return _instance!;
+  }
+
+  late Logger logger = Logger("PersonService");
+
+  static void init(
+    EntityService entityService,
+    FaceMLDataDB faceMLDataDB,
+    SharedPreferences prefs,
+  ) {
+    _instance = PersonService(entityService, faceMLDataDB, prefs);
+  }
+
+  Future<List<PersonEntity>> getPersons() async {
+    final entities = await entityService.getEntities(EntityType.person);
+    return entities
+        .map(
+          (e) => PersonEntity(e.id, PersonData.fromJson(json.decode(e.data))),
+        )
+        .toList();
+  }
+
+  Future<PersonEntity?> getPerson(String id) {
+    return entityService.getEntity(EntityType.person, id).then((e) {
+      if (e == null) {
+        return null;
+      }
+      return PersonEntity(e.id, PersonData.fromJson(json.decode(e.data)));
+    });
+  }
+
+  Future<Map<String, PersonEntity>> getPersonsMap() async {
+    final entities = await entityService.getEntities(EntityType.person);
+    final Map<String, PersonEntity> map = {};
+    for (var e in entities) {
+      final person =
+          PersonEntity(e.id, PersonData.fromJson(json.decode(e.data)));
+      map[person.remoteID] = person;
+    }
+    return map;
+  }
+
+  Future<Set<String>> personIDs() async {
+    final entities = await entityService.getEntities(EntityType.person);
+    return entities.map((e) => e.id).toSet();
+  }
EnteWatch("reconcileClusters") : null; + w?.start(); + await storeRemoteFeedback(); + w?.log("Stored remote feedback"); + final dbPersonClusterInfo = + await faceMLDataDB.getPersonToClusterIdToFaceIds(); + w?.log("Got DB person cluster info"); + final persons = await getPersonsMap(); + w?.log("Got persons"); + for (var personID in dbPersonClusterInfo.keys) { + final person = persons[personID]; + if (person == null) { + logger.warning("Person $personID not found"); + continue; + } + final personData = person.data; + final Map> dbPersonCluster = + dbPersonClusterInfo[personID]!; + if (_shouldUpdateRemotePerson(personData, dbPersonCluster)) { + final personData = person.data; + personData.assigned = dbPersonCluster.entries + .map( + (e) => ClusterInfo( + id: e.key, + faces: e.value, + ), + ) + .toList(); + entityService + .addOrUpdate( + EntityType.person, + json.encode(personData.toJson()), + id: personID, + ) + .ignore(); + personData.logStats(); + } + } + w?.log("Reconciled clusters for ${persons.length} persons"); + } + + bool _shouldUpdateRemotePerson( + PersonData personData, + Map> dbPersonCluster, + ) { + bool result = false; + if ((personData.assigned?.length ?? 0) != dbPersonCluster.length) { + log( + "Person ${personData.name} has ${personData.assigned?.length} clusters, but ${dbPersonCluster.length} clusters found in DB", + name: "PersonService", + ); + result = true; + } else { + for (ClusterInfo info in personData.assigned!) { + final dbCluster = dbPersonCluster[info.id]; + if (dbCluster == null) { + log( + "Cluster ${info.id} not found in DB for person ${personData.name}", + name: "PersonService", + ); + result = true; + continue; + } + if (info.faces.length != dbCluster.length) { + log( + "Cluster ${info.id} has ${info.faces.length} faces, but ${dbCluster.length} faces found in DB", + name: "PersonService", + ); + result = true; + } + for (var faceId in info.faces) { + if (!dbCluster.contains(faceId)) { + log( + "Face $faceId not found in cluster ${info.id} for person ${personData.name}", + name: "PersonService", + ); + result = true; + } + } + } + } + return result; + } + + Future addPerson( + String name, + int clusterID, { + bool isHidden = false, + }) async { + final faceIds = await faceMLDataDB.getFaceIDsForCluster(clusterID); + final data = PersonData( + name: name, + assigned: [ + ClusterInfo( + id: clusterID, + faces: faceIds.toSet(), + ), + ], + isHidden: isHidden, + ); + final result = await entityService.addOrUpdate( + EntityType.person, + json.encode(data.toJson()), + ); + await faceMLDataDB.assignClusterToPerson( + personID: result.id, + clusterID: clusterID, + ); + return PersonEntity(result.id, data); + } + + Future removeClusterToPerson({ + required String personID, + required int clusterID, + }) async { + final person = (await getPerson(personID))!; + final personData = person.data; + personData.assigned!.removeWhere((element) => element.id != clusterID); + await entityService.addOrUpdate( + EntityType.person, + json.encode(personData.toJson()), + id: personID, + ); + await faceMLDataDB.removeClusterToPerson( + personID: personID, + clusterID: clusterID, + ); + personData.logStats(); + } + + Future deletePerson(String personID, {bool onlyMapping = false}) async { + if (onlyMapping) { + final PersonEntity? 
+
+  Future<void> deletePerson(String personID, {bool onlyMapping = false}) async {
+    if (onlyMapping) {
+      final PersonEntity? entity = await getPerson(personID);
+      if (entity == null) {
+        return;
+      }
+      final PersonEntity justName =
+          PersonEntity(personID, PersonData(name: entity.data.name));
+      await entityService.addOrUpdate(
+        EntityType.person,
+        json.encode(justName.data.toJson()),
+        id: personID,
+      );
+      await faceMLDataDB.removePerson(personID);
+      justName.data.logStats();
+    } else {
+      await entityService.deleteEntry(personID);
+      await faceMLDataDB.removePerson(personID);
+    }
+
+    // Fire PeopleChangedEvent so listeners refresh
+    Bus.instance.fire(PeopleChangedEvent());
+  }
+
+  Future<void> storeRemoteFeedback() async {
+    await entityService.syncEntities();
+    final entities = await entityService.getEntities(EntityType.person);
+    entities.sort((a, b) => a.updatedAt.compareTo(b.updatedAt));
+    final Map<String, int> faceIdToClusterID = {};
+    final Map<int, String> clusterToPersonID = {};
+    for (var e in entities) {
+      final personData = PersonData.fromJson(json.decode(e.data));
+      int faceCount = 0;
+      for (var cluster in personData.assigned!) {
+        faceCount += cluster.faces.length;
+        for (var faceId in cluster.faces) {
+          if (faceIdToClusterID.containsKey(faceId)) {
+            final otherPersonID = clusterToPersonID[faceIdToClusterID[faceId]!];
+            if (otherPersonID != e.id) {
+              final otherPerson = await getPerson(otherPersonID!);
+              throw Exception(
+                "Face $faceId is already assigned to person $otherPersonID (${otherPerson!.data.name}) and person ${e.id} (${personData.name})",
+              );
+            }
+          }
+          faceIdToClusterID[faceId] = cluster.id;
+        }
+        clusterToPersonID[cluster.id] = e.id;
+      }
+      if (kDebugMode) {
+        logger.info(
+          "Person ${e.id} ${personData.name} has ${personData.assigned!.length} clusters with $faceCount faces",
+        );
+      }
+    }
+
+    logger.info("Storing feedback for ${faceIdToClusterID.length} faces");
+    await faceMLDataDB.updateFaceIdToClusterId(faceIdToClusterID);
+    await faceMLDataDB.bulkAssignClusterToPersonID(clusterToPersonID);
+  }
+
+  Future<void> updateAttributes(
+    String id, {
+    String? name,
+    String? avatarFaceId,
+    bool? isHidden,
+    int? version,
+    String? birthDate,
+  }) async {
+    final person = (await getPerson(id))!;
+    final updatedPerson = person.copyWith(
+      data: person.data.copyWith(
+        name: name,
+        avatarFaceId: avatarFaceId,
+        isHidden: isHidden,
+        version: version,
+        birthDate: birthDate,
+      ),
+    );
+    await _updatePerson(updatedPerson);
+  }
+
+  Future<void> _updatePerson(PersonEntity updatePerson) async {
+    await entityService.addOrUpdate(
+      EntityType.person,
+      json.encode(updatePerson.data.toJson()),
+      id: updatePerson.remoteID,
+    );
+    updatePerson.data.logStats();
+  }
+}
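`storeRemoteFeedback` above refuses to proceed when two persons claim the same face. The invariant it enforces can be distilled to the following sketch, simplified to plain maps; a hypothetical helper, not the app's code:

```dart
/// Builds a face -> person ownership map, throwing on the first face that
/// two different persons claim, mirroring the guard in storeRemoteFeedback.
Map<String, String> assignFaceOwners(Map<String, List<String>> personToFaces) {
  final owner = <String, String>{};
  for (final entry in personToFaces.entries) {
    for (final faceId in entry.value) {
      final existing = owner[faceId];
      if (existing != null && existing != entry.key) {
        throw StateError(
          'Face $faceId is claimed by both $existing and ${entry.key}',
        );
      }
      owner[faceId] = entry.key;
    }
  }
  return owner;
}
```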
clipEmbedding; + + FileMl( + this.fileID, + this.faceEmbedding, { + this.height, + this.width, + this.clipEmbedding, + }); + + // toJson + Map toJson() => { + 'fileID': fileID, + 'height': height, + 'width': width, + 'faceEmbedding': faceEmbedding.toJson(), + 'clipEmbedding': clipEmbedding?.toJson(), + }; + // fromJson + factory FileMl.fromJson(Map json) { + return FileMl( + json['fileID'] as int, + FaceEmbeddings.fromJson(json['faceEmbedding'] as Map), + height: json['height'] as int?, + width: json['width'] as int?, + clipEmbedding: json['clipEmbedding'] == null + ? null + : ClipEmbedding.fromJson( + json['clipEmbedding'] as Map, + ), + ); + } +} + +class FaceEmbeddings { + final List faces; + final int version; + // pkgname/version + final String client; + + FaceEmbeddings( + this.faces, + this.version, { + required this.client, + }); + + // toJson + Map toJson() => { + 'faces': faces.map((x) => x.toJson()).toList(), + 'version': version, + 'client': client, + }; + // fromJson + factory FaceEmbeddings.fromJson(Map json) { + return FaceEmbeddings( + List.from( + json['faces'].map((x) => Face.fromJson(x as Map)), + ), + json['version'] as int, + client: json['client'] ?? + 'unknown', + ); + } +} + +class ClipEmbedding { + final int? version; + final List embedding; + ClipEmbedding(this.embedding, {this.version}); + // toJson + Map toJson() => { + 'version': version, + 'embedding': embedding, + }; + // fromJson + factory ClipEmbedding.fromJson(Map json) { + return ClipEmbedding( + List.from(json['embedding'] as List), + version: json['version'] as int?, + ); + } +} diff --git a/mobile/lib/services/machine_learning/file_ml/files_ml_data_response.dart b/mobile/lib/services/machine_learning/file_ml/files_ml_data_response.dart new file mode 100644 index 000000000..475f52d0a --- /dev/null +++ b/mobile/lib/services/machine_learning/file_ml/files_ml_data_response.dart @@ -0,0 +1,19 @@ +import 'package:photos/services/machine_learning/file_ml/file_ml.dart'; + +class FilesMLDataResponse { + final Map mlData; + // fileIDs that were indexed but they don't contain any meaningful embeddings + // and hence should be discarded for re-indexing + final Set noEmbeddingFileIDs; + // fetchErrorFileIDs are the fileIDs for whom we failed failed to fetch embeddings + // from the storage + final Set fetchErrorFileIDs; + // pendingIndexFileIDs are the fileIDs that were never indexed + final Set pendingIndexFileIDs; + FilesMLDataResponse( + this.mlData, { + required this.noEmbeddingFileIDs, + required this.fetchErrorFileIDs, + required this.pendingIndexFileIDs, + }); +} diff --git a/mobile/lib/services/machine_learning/file_ml/remote_fileml_service.dart b/mobile/lib/services/machine_learning/file_ml/remote_fileml_service.dart new file mode 100644 index 000000000..eafbc6323 --- /dev/null +++ b/mobile/lib/services/machine_learning/file_ml/remote_fileml_service.dart @@ -0,0 +1,138 @@ +import "dart:async"; +import "dart:convert"; + +import "package:logging/logging.dart"; +import "package:photos/core/network/network.dart"; +import "package:photos/db/files_db.dart"; +import "package:photos/models/file/file.dart"; +import 'package:photos/services/machine_learning/file_ml/file_ml.dart'; +import "package:photos/services/machine_learning/file_ml/files_ml_data_response.dart"; +import "package:photos/services/machine_learning/semantic_search/embedding_store.dart"; +import "package:photos/services/machine_learning/semantic_search/remote_embedding.dart"; +import "package:photos/utils/crypto_util.dart"; +import 
"package:photos/utils/file_download_util.dart"; +import "package:shared_preferences/shared_preferences.dart"; + +class RemoteFileMLService { + RemoteFileMLService._privateConstructor(); + + static final RemoteFileMLService instance = + RemoteFileMLService._privateConstructor(); + + final _logger = Logger("RemoteFileMLService"); + final _dio = NetworkClient.instance.enteDio; + + void init(SharedPreferences prefs) {} + + Future putFileEmbedding(EnteFile file, FileMl fileML) async { + final encryptionKey = getFileKey(file); + final embeddingJSON = jsonEncode(fileML.toJson()); + final encryptedEmbedding = await CryptoUtil.encryptChaCha( + utf8.encode(embeddingJSON), + encryptionKey, + ); + final encryptedData = + CryptoUtil.bin2base64(encryptedEmbedding.encryptedData!); + final header = CryptoUtil.bin2base64(encryptedEmbedding.header!); + try { + final _ = await _dio.put( + "/embeddings", + data: { + "fileID": file.uploadedFileID!, + "model": 'file-ml-clip-face', + "encryptedEmbedding": encryptedData, + "decryptionHeader": header, + }, + ); + // final updationTime = response.data["updatedAt"]; + } catch (e, s) { + _logger.severe("Failed to put embedding", e, s); + rethrow; + } + } + + Future getFilessEmbedding( + List fileIds, + ) async { + try { + final res = await _dio.post( + "/embeddings/files", + data: { + "fileIDs": fileIds, + "model": 'file-ml-clip-face', + }, + ); + final remoteEmb = res.data['embeddings'] as List; + final pendingIndexFiles = res.data['pendingIndexFileIDs'] as List; + final noEmbeddingFiles = res.data['noEmbeddingFileIDs'] as List; + final errFileIds = res.data['errFileIDs'] as List; + + final List remoteEmbeddings = []; + for (var entry in remoteEmb) { + final embedding = RemoteEmbedding.fromMap(entry); + remoteEmbeddings.add(embedding); + } + + final fileIDToFileMl = await decryptFileMLData(remoteEmbeddings); + return FilesMLDataResponse( + fileIDToFileMl, + noEmbeddingFileIDs: + Set.from(noEmbeddingFiles.map((x) => x as int)), + fetchErrorFileIDs: Set.from(errFileIds.map((x) => x as int)), + pendingIndexFileIDs: + Set.from(pendingIndexFiles.map((x) => x as int)), + ); + } catch (e, s) { + _logger.severe("Failed to get embeddings", e, s); + rethrow; + } + } + + Future> decryptFileMLData( + List remoteEmbeddings, + ) async { + final result = {}; + if (remoteEmbeddings.isEmpty) { + return result; + } + final inputs = []; + final fileMap = await FilesDB.instance + .getFilesFromIDs(remoteEmbeddings.map((e) => e.fileID).toList()); + for (final embedding in remoteEmbeddings) { + final file = fileMap[embedding.fileID]; + if (file == null) { + continue; + } + final fileKey = getFileKey(file); + final input = EmbeddingsDecoderInput(embedding, fileKey); + inputs.add(input); + } + // todo: use compute or isolate + return decryptFileMLComputer( + { + "inputs": inputs, + }, + ); + } + + Future> decryptFileMLComputer( + Map args, + ) async { + final result = {}; + final inputs = args["inputs"] as List; + for (final input in inputs) { + final decryptArgs = {}; + decryptArgs["source"] = + CryptoUtil.base642bin(input.embedding.encryptedEmbedding); + decryptArgs["key"] = input.decryptionKey; + decryptArgs["header"] = + CryptoUtil.base642bin(input.embedding.decryptionHeader); + final embeddingData = chachaDecryptData(decryptArgs); + final decodedJson = jsonDecode(utf8.decode(embeddingData)); + final FileMl decodedEmbedding = + FileMl.fromJson(decodedJson as Map); + result[input.embedding.fileID] = decodedEmbedding; + } + return result; + } +} diff --git 
a/mobile/lib/services/machine_learning/machine_learning_controller.dart b/mobile/lib/services/machine_learning/machine_learning_controller.dart index 145670f2c..65daf614c 100644 --- a/mobile/lib/services/machine_learning/machine_learning_controller.dart +++ b/mobile/lib/services/machine_learning/machine_learning_controller.dart @@ -22,7 +22,7 @@ class MachineLearningController { bool _isDeviceHealthy = true; bool _isUserInteracting = true; - bool _isRunningML = false; + bool _canRunML = false; late Timer _userInteractionTimer; void init() { @@ -35,6 +35,7 @@ class MachineLearningController { }); } else { // Always run Machine Learning on iOS + _canRunML = true; Bus.instance.fire(MachineLearningControlEvent(true)); } } @@ -53,10 +54,10 @@ class MachineLearningController { void _fireControlEvent() { final shouldRunML = _isDeviceHealthy && !_isUserInteracting; - if (shouldRunML != _isRunningML) { - _isRunningML = shouldRunML; + if (shouldRunML != _canRunML) { + _canRunML = shouldRunML; _logger.info( - "Firing event with device health: $_isDeviceHealthy and user interaction: $_isUserInteracting", + "Firing event with $shouldRunML, device health: $_isDeviceHealthy and user interaction: $_isUserInteracting", ); Bus.instance.fire(MachineLearningControlEvent(shouldRunML)); } diff --git a/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart b/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart index 99aa3a011..d85b4ceb5 100644 --- a/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart +++ b/mobile/lib/services/machine_learning/semantic_search/semantic_search_service.dart @@ -1,6 +1,7 @@ import "dart:async"; import "dart:collection"; import "dart:io"; +import "dart:math" show min; import "package:computer/computer.dart"; import "package:logging/logging.dart"; @@ -164,8 +165,10 @@ class SemanticSearchService { } Future getIndexStatus() async { + final indexableFileIDs = await FilesDB.instance + .getOwnedFileIDs(Configuration.instance.getUserID()!); return IndexStatus( - _cachedEmbeddings.length, + min(_cachedEmbeddings.length, indexableFileIDs.length), (await _getFileIDsToBeIndexed()).length, ); } diff --git a/mobile/lib/services/search_service.dart b/mobile/lib/services/search_service.dart index e27ca7582..5e21b0334 100644 --- a/mobile/lib/services/search_service.dart +++ b/mobile/lib/services/search_service.dart @@ -11,6 +11,8 @@ import 'package:photos/data/years.dart'; import 'package:photos/db/files_db.dart'; import 'package:photos/events/local_photos_updated_event.dart'; import "package:photos/extensions/string_ext.dart"; +import "package:photos/face/db.dart"; +import "package:photos/face/model/person.dart"; import "package:photos/models/api/collection/user.dart"; import 'package:photos/models/collection/collection.dart'; import 'package:photos/models/collection/collection_items.dart'; @@ -22,19 +24,25 @@ import "package:photos/models/location/location.dart"; import "package:photos/models/location_tag/location_tag.dart"; import 'package:photos/models/search/album_search_result.dart'; import 'package:photos/models/search/generic_search_result.dart'; +import "package:photos/models/search/search_constants.dart"; import "package:photos/models/search/search_types.dart"; import 'package:photos/services/collections_service.dart'; import "package:photos/services/location_service.dart"; +import "package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart"; +import 
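The controller change above renames `_isRunningML` to `_canRunML`, which better matches what the flag tracks: permission to run, not activity. The gating rule itself is small and worth seeing in isolation; a sketch, with `print` standing in for the app's event bus:

```dart
class MlGate {
  bool _canRunML = false;

  /// ML may run only while the device is healthy and the user is idle;
  /// an event fires only when the verdict actually flips.
  void update({required bool deviceHealthy, required bool userInteracting}) {
    final shouldRun = deviceHealthy && !userInteracting;
    if (shouldRun != _canRunML) {
      _canRunML = shouldRun;
      print('MachineLearningControlEvent($shouldRun)');
    }
  }
}
```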
"package:photos/services/machine_learning/face_ml/person/person_service.dart"; import 'package:photos/services/machine_learning/semantic_search/semantic_search_service.dart'; import "package:photos/states/location_screen_state.dart"; import "package:photos/ui/viewer/location/add_location_sheet.dart"; import "package:photos/ui/viewer/location/location_screen.dart"; +import "package:photos/ui/viewer/people/cluster_page.dart"; +import "package:photos/ui/viewer/people/people_page.dart"; import 'package:photos/utils/date_time_util.dart'; import "package:photos/utils/navigation_util.dart"; import 'package:tuple/tuple.dart'; class SearchService { Future>? _cachedFilesFuture; + Future>? _cachedHiddenFilesFuture; final _logger = Logger((SearchService).toString()); final _collectionService = CollectionsService.instance; static const _maximumResultsLimit = 20; @@ -47,6 +55,7 @@ class SearchService { Bus.instance.on().listen((event) { // only invalidate, let the load happen on demand _cachedFilesFuture = null; + _cachedHiddenFilesFuture = null; }); } @@ -66,8 +75,21 @@ class SearchService { return _cachedFilesFuture!; } + Future> getHiddenFiles() async { + if (_cachedHiddenFilesFuture != null) { + return _cachedHiddenFilesFuture!; + } + _logger.fine("Reading hidden files from db"); + final hiddenCollections = + CollectionsService.instance.getHiddenCollectionIds(); + _cachedHiddenFilesFuture = + FilesDB.instance.getAllFilesFromCollections(hiddenCollections); + return _cachedHiddenFilesFuture!; + } + void clearCache() { _cachedFilesFuture = null; + _cachedHiddenFilesFuture = null; } // getFilteredCollectionsWithThumbnail removes deleted or archived or @@ -704,6 +726,169 @@ class SearchService { return searchResults; } + Future>> getClusterFilesForPersonID( + String personID, + ) async { + _logger.info('getClusterFilesForPersonID $personID'); + final Map> fileIdToClusterID = + await FaceMLDataDB.instance.getFileIdToClusterIDSet(personID); + _logger.info('faceDbDone getClusterFilesForPersonID $personID'); + final Map> clusterIDToFiles = {}; + final allFiles = await getAllFiles(); + for (final f in allFiles) { + if (!fileIdToClusterID.containsKey(f.uploadedFileID ?? -1)) { + continue; + } + final cluserIds = fileIdToClusterID[f.uploadedFileID ?? -1]!; + for (final cluster in cluserIds) { + if (clusterIDToFiles.containsKey(cluster)) { + clusterIDToFiles[cluster]!.add(f); + } else { + clusterIDToFiles[cluster] = [f]; + } + } + } + _logger.info('done getClusterFilesForPersonID $personID'); + return clusterIDToFiles; + } + + Future> getAllFace(int? limit) async { + try { + // Don't return anything if clustering is not nearly complete yet + final foundFaces = await FaceMLDataDB.instance.getTotalFaceCount(); + final clusteredFaces = + await FaceMLDataDB.instance.getClusteredFaceCount(); + final clusteringDoneRatio = clusteredFaces / foundFaces; + if (clusteringDoneRatio < 0.9) { + return []; + } + + debugPrint("getting faces"); + final Map> fileIdToClusterID = + await FaceMLDataDB.instance.getFileIdToClusterIds(); + final Map personIdToPerson = + await PersonService.instance.getPersonsMap(); + final clusterIDToPersonID = + await FaceMLDataDB.instance.getClusterIDToPersonID(); + + final List facesResult = []; + final Map> clusterIdToFiles = {}; + final Map> personIdToFiles = {}; + final allFiles = await getAllFiles(); + for (final f in allFiles) { + if (!fileIdToClusterID.containsKey(f.uploadedFileID ?? -1)) { + continue; + } + final cluserIds = fileIdToClusterID[f.uploadedFileID ?? 
-1]!; + for (final cluster in cluserIds) { + final PersonEntity? p = + personIdToPerson[clusterIDToPersonID[cluster] ?? ""]; + if (p != null) { + if (personIdToFiles.containsKey(p.remoteID)) { + personIdToFiles[p.remoteID]!.add(f); + } else { + personIdToFiles[p.remoteID] = [f]; + } + } else { + if (clusterIdToFiles.containsKey(cluster)) { + clusterIdToFiles[cluster]!.add(f); + } else { + clusterIdToFiles[cluster] = [f]; + } + } + } + } + // get sorted personId by files count + final sortedPersonIds = personIdToFiles.keys.toList() + ..sort( + (a, b) => personIdToFiles[b]!.length.compareTo( + personIdToFiles[a]!.length, + ), + ); + for (final personID in sortedPersonIds) { + final files = personIdToFiles[personID]!; + if (files.isEmpty) { + continue; + } + final PersonEntity p = personIdToPerson[personID]!; + if (p.data.isIgnored) continue; + facesResult.add( + GenericSearchResult( + ResultType.faces, + p.data.name, + files, + params: { + kPersonParamID: personID, + kFileID: files.first.uploadedFileID, + }, + onResultTap: (ctx) { + routeToPage( + ctx, + PeoplePage( + tagPrefix: "${ResultType.faces.toString()}_${p.data.name}", + person: p, + ), + ); + }, + ), + ); + } + final sortedClusterIds = clusterIdToFiles.keys.toList() + ..sort( + (a, b) => clusterIdToFiles[b]! + .length + .compareTo(clusterIdToFiles[a]!.length), + ); + + for (final clusterId in sortedClusterIds) { + final files = clusterIdToFiles[clusterId]!; + // final String clusterName = "ID:$clusterId, ${files.length}"; + // final String clusterName = "${files.length}"; + // const String clusterName = ""; + final String clusterName = "$clusterId"; + + if (clusterIDToPersonID[clusterId] != null) { + throw Exception( + "Cluster $clusterId should not have person id ${clusterIDToPersonID[clusterId]}", + ); + } + if (files.length < kMinimumClusterSizeSearchResult && + sortedClusterIds.length > 3) { + continue; + } + facesResult.add( + GenericSearchResult( + ResultType.faces, + clusterName, + files, + params: { + kClusterParamId: clusterId, + kFileID: files.first.uploadedFileID, + }, + onResultTap: (ctx) { + routeToPage( + ctx, + ClusterPage( + files, + tagPrefix: "${ResultType.faces.toString()}_$clusterName", + clusterID: clusterId, + ), + ); + }, + ), + ); + } + if (limit != null) { + return facesResult.sublist(0, min(limit, facesResult.length)); + } else { + return facesResult; + } + } catch (e, s) { + _logger.severe("Error in getAllFace", e, s); + rethrow; + } + } + Future> getAllLocationTags(int? 
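`getAllFace` above spends most of its lines on one idiom: bucketing files by cluster (or person) and then ordering buckets by size. Factored out, the idiom is just the following generic sketch (hypothetical helpers, not the app's code):

```dart
/// Groups [items] into buckets keyed by [keyOf], e.g. files by cluster ID.
Map<K, List<T>> groupBy<K, T>(Iterable<T> items, K Function(T) keyOf) {
  final buckets = <K, List<T>>{};
  for (final item in items) {
    buckets.putIfAbsent(keyOf(item), () => []).add(item);
  }
  return buckets;
}

/// Bucket keys ordered by bucket size, largest first, as the search
/// results are presented above.
List<K> keysBySizeDesc<K, T>(Map<K, List<T>> buckets) =>
    buckets.keys.toList()
      ..sort((a, b) => buckets[b]!.length.compareTo(buckets[a]!.length));
```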
limit) async { try { final Map, List> tagToItemsMap = {}; diff --git a/mobile/lib/states/all_sections_examples_state.dart b/mobile/lib/states/all_sections_examples_state.dart index fdeb6fcdf..a40ecd925 100644 --- a/mobile/lib/states/all_sections_examples_state.dart +++ b/mobile/lib/states/all_sections_examples_state.dart @@ -6,6 +6,7 @@ import "package:logging/logging.dart"; import "package:photos/core/constants.dart"; import "package:photos/core/event_bus.dart"; import "package:photos/events/files_updated_event.dart"; +import "package:photos/events/people_changed_event.dart"; import "package:photos/events/tab_changed_event.dart"; import "package:photos/models/search/search_result.dart"; import "package:photos/models/search/search_types.dart"; @@ -31,6 +32,7 @@ class _AllSectionsExamplesProviderState Future>> allSectionsExamplesFuture = Future.value([]); late StreamSubscription _filesUpdatedEvent; + late StreamSubscription _onPeopleChangedEvent; late StreamSubscription _tabChangeEvent; bool hasPendingUpdate = false; bool isOnSearchTab = false; @@ -46,16 +48,11 @@ class _AllSectionsExamplesProviderState super.initState(); //add all common events for all search sections to reload to here. _filesUpdatedEvent = Bus.instance.on().listen((event) { - if (!isOnSearchTab) { - if (kDebugMode) { - _logger.finest('Skip reload till user clicks on search tab'); - } - hasPendingUpdate = true; - return; - } else { - hasPendingUpdate = false; - reloadAllSections(); - } + onDataUpdate(); + }); + _onPeopleChangedEvent = + Bus.instance.on().listen((event) { + onDataUpdate(); }); _tabChangeEvent = Bus.instance.on().listen((event) { if (event.source == TabChangedEventSource.pageView && @@ -72,6 +69,18 @@ class _AllSectionsExamplesProviderState reloadAllSections(); } + void onDataUpdate() { + if (!isOnSearchTab) { + if (kDebugMode) { + _logger.finest('Skip reload till user clicks on search tab'); + } + hasPendingUpdate = true; + } else { + hasPendingUpdate = false; + reloadAllSections(); + } + } + void reloadAllSections() { _logger.info('queue reload all sections'); _debouncer.run(() async { @@ -79,22 +88,28 @@ class _AllSectionsExamplesProviderState _logger.info("'_debounceTimer: reloading all sections in search tab"); final allSectionsExamples = >>[]; for (SectionType sectionType in SectionType.values) { - if (sectionType == SectionType.face || - sectionType == SectionType.content) { + if (sectionType == SectionType.content) { continue; } allSectionsExamples.add( sectionType.getData(context, limit: kSearchSectionLimit), ); } - allSectionsExamplesFuture = - Future.wait>(allSectionsExamples); + try { + allSectionsExamplesFuture = Future.wait>( + allSectionsExamples, + eagerError: false, + ); + } catch (e) { + _logger.severe("Error reloading all sections: $e"); + } }); }); } @override void dispose() { + _onPeopleChangedEvent.cancel(); _filesUpdatedEvent.cancel(); _tabChangeEvent.cancel(); _debouncer.cancelDebounce(); diff --git a/mobile/lib/ui/components/bottom_action_bar/bottom_action_bar_widget.dart b/mobile/lib/ui/components/bottom_action_bar/bottom_action_bar_widget.dart index b896e0f1f..a0c50be21 100644 --- a/mobile/lib/ui/components/bottom_action_bar/bottom_action_bar_widget.dart +++ b/mobile/lib/ui/components/bottom_action_bar/bottom_action_bar_widget.dart @@ -1,5 +1,6 @@ import 'package:flutter/material.dart'; import 'package:photos/core/constants.dart'; +import "package:photos/face/model/person.dart"; import 'package:photos/models/collection/collection.dart'; import 
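`reloadAllSections` above funnels every trigger through a debouncer, so bursts of `FilesUpdatedEvent`/`PeopleChangedEvent` collapse into a single reload. The app has its own `Debouncer` implementation; a minimal equivalent for reference, simplified to a synchronous callback:

```dart
import 'dart:async';

class Debouncer {
  Debouncer(this.delay);
  final Duration delay;
  Timer? _timer;

  /// Restarts the timer on every call; [action] runs only once [delay]
  /// has elapsed with no further run() calls.
  void run(void Function() action) {
    _timer?.cancel();
    _timer = Timer(delay, action);
  }

  void cancelDebounce() => _timer?.cancel();
}
```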
"package:photos/models/gallery_type.dart"; import 'package:photos/models/selected_files.dart'; @@ -11,6 +12,8 @@ import "package:photos/ui/viewer/actions/file_selection_actions_widget.dart"; class BottomActionBarWidget extends StatelessWidget { final GalleryType galleryType; final Collection? collection; + final PersonEntity? person; + final int? clusterID; final SelectedFiles selectedFiles; final VoidCallback? onCancel; final Color? backgroundColor; @@ -19,6 +22,8 @@ class BottomActionBarWidget extends StatelessWidget { required this.galleryType, required this.selectedFiles, this.collection, + this.person, + this.clusterID, this.onCancel, this.backgroundColor, super.key, @@ -54,6 +59,8 @@ class BottomActionBarWidget extends StatelessWidget { galleryType, selectedFiles, collection: collection, + person: person, + clusterID: clusterID, ), const DividerWidget(dividerType: DividerType.bottomBar), ActionBarWidget( diff --git a/mobile/lib/ui/components/buttons/icon_button_widget.dart b/mobile/lib/ui/components/buttons/icon_button_widget.dart index 258b339d7..3e51f8789 100644 --- a/mobile/lib/ui/components/buttons/icon_button_widget.dart +++ b/mobile/lib/ui/components/buttons/icon_button_widget.dart @@ -17,6 +17,7 @@ class IconButtonWidget extends StatefulWidget { final Color? pressedColor; final Color? iconColor; final double size; + final bool roundedIcon; const IconButtonWidget({ required this.icon, required this.iconButtonType, @@ -26,6 +27,7 @@ class IconButtonWidget extends StatefulWidget { this.pressedColor, this.iconColor, this.size = 24, + this.roundedIcon = true, super.key, }); @@ -68,22 +70,31 @@ class _IconButtonWidgetState extends State { Widget _iconButton(EnteColorScheme colorTheme) { return Padding( padding: const EdgeInsets.all(4.0), - child: AnimatedContainer( - duration: const Duration(milliseconds: 20), - padding: const EdgeInsets.all(8), - decoration: BoxDecoration( - borderRadius: BorderRadius.circular(widget.size), - color: iconStateColor, - ), - child: Icon( - widget.icon, - color: widget.iconColor ?? - (widget.iconButtonType == IconButtonType.secondary - ? colorTheme.strokeMuted - : colorTheme.strokeBase), - size: widget.size, - ), - ), + child: widget.roundedIcon + ? AnimatedContainer( + duration: const Duration(milliseconds: 20), + padding: const EdgeInsets.all(8), + decoration: BoxDecoration( + borderRadius: BorderRadius.circular(widget.size), + color: iconStateColor, + ), + child: Icon( + widget.icon, + color: widget.iconColor ?? + (widget.iconButtonType == IconButtonType.secondary + ? colorTheme.strokeMuted + : colorTheme.strokeBase), + size: widget.size, + ), + ) + : Icon( + widget.icon, + color: widget.iconColor ?? + (widget.iconButtonType == IconButtonType.secondary + ? colorTheme.strokeMuted + : colorTheme.strokeBase), + size: widget.size, + ), ); } diff --git a/mobile/lib/ui/components/info_item_widget.dart b/mobile/lib/ui/components/info_item_widget.dart index 5bec95ccf..73517e052 100644 --- a/mobile/lib/ui/components/info_item_widget.dart +++ b/mobile/lib/ui/components/info_item_widget.dart @@ -11,6 +11,7 @@ class InfoItemWidget extends StatelessWidget { final Widget? endSection; final Future> subtitleSection; final bool hasChipButtons; + final bool biggerSpinner; final VoidCallback? 
onTap; const InfoItemWidget({ required this.leadingIcon, @@ -19,6 +20,7 @@ class InfoItemWidget extends StatelessWidget { this.endSection, required this.subtitleSection, this.hasChipButtons = false, + this.biggerSpinner = false, this.onTap, super.key, }); @@ -57,10 +59,11 @@ class InfoItemWidget extends StatelessWidget { } } else { child = EnteLoadingWidget( - padding: 3, - size: 11, + padding: biggerSpinner ? 6 : 3, + size: biggerSpinner ? 20 : 11, color: getEnteColorScheme(context).strokeMuted, - alignment: Alignment.centerLeft, + alignment: + biggerSpinner ? Alignment.center : Alignment.centerLeft, ); } return AnimatedSwitcher( diff --git a/mobile/lib/ui/components/notification_widget.dart b/mobile/lib/ui/components/notification_widget.dart index 6779a58fa..864e4c29c 100644 --- a/mobile/lib/ui/components/notification_widget.dart +++ b/mobile/lib/ui/components/notification_widget.dart @@ -10,6 +10,7 @@ import 'package:photos/ui/components/buttons/icon_button_widget.dart'; enum NotificationType { warning, banner, + greenBanner, goldenBanner, notice, } @@ -67,6 +68,18 @@ class NotificationWidget extends StatelessWidget { ); boxShadow = Theme.of(context).colorScheme.enteTheme.shadowMenu; break; + case NotificationType.greenBanner: + backgroundGradient = LinearGradient( + colors: [ + getEnteColorScheme(context).primary700, + getEnteColorScheme(context).primary500, + ], + stops: const [0.25, 1], + begin: Alignment.bottomCenter, + end: Alignment.topCenter, + ); + boxShadow = Theme.of(context).colorScheme.enteTheme.shadowMenu; + break; case NotificationType.notice: backgroundColor = colorScheme.backgroundElevated2; mainTextStyle = textTheme.bodyBold; diff --git a/mobile/lib/ui/settings/debug_section_widget.dart b/mobile/lib/ui/settings/debug/debug_section_widget.dart similarity index 99% rename from mobile/lib/ui/settings/debug_section_widget.dart rename to mobile/lib/ui/settings/debug/debug_section_widget.dart index 039655ca3..56070c214 100644 --- a/mobile/lib/ui/settings/debug_section_widget.dart +++ b/mobile/lib/ui/settings/debug/debug_section_widget.dart @@ -67,7 +67,6 @@ class DebugSectionWidget extends StatelessWidget { showShortToast(context, "Done"); }, ), - sectionOptionSpacing, ], ); } diff --git a/mobile/lib/ui/settings/debug/face_debug_section_widget.dart b/mobile/lib/ui/settings/debug/face_debug_section_widget.dart new file mode 100644 index 000000000..01b10ff80 --- /dev/null +++ b/mobile/lib/ui/settings/debug/face_debug_section_widget.dart @@ -0,0 +1,341 @@ +import "dart:async"; + +import "package:flutter/foundation.dart"; +import 'package:flutter/material.dart'; +import "package:logging/logging.dart"; +import "package:photos/core/event_bus.dart"; +import "package:photos/events/people_changed_event.dart"; +import "package:photos/face/db.dart"; +import "package:photos/face/model/person.dart"; +import 'package:photos/services/machine_learning/face_ml/face_ml_service.dart'; +import "package:photos/services/machine_learning/face_ml/feedback/cluster_feedback.dart"; +import "package:photos/services/machine_learning/face_ml/person/person_service.dart"; +import 'package:photos/theme/ente_theme.dart'; +import 'package:photos/ui/components/captioned_text_widget.dart'; +import 'package:photos/ui/components/expandable_menu_item_widget.dart'; +import 'package:photos/ui/components/menu_item_widget/menu_item_widget.dart'; +import 'package:photos/ui/settings/common_settings.dart'; +import "package:photos/utils/dialog_util.dart"; +import "package:photos/utils/local_settings.dart"; +import 
'package:photos/utils/toast_util.dart'; + +class FaceDebugSectionWidget extends StatefulWidget { + const FaceDebugSectionWidget({Key? key}) : super(key: key); + + @override + State createState() => _FaceDebugSectionWidgetState(); +} + +class _FaceDebugSectionWidgetState extends State { + Timer? _timer; + @override + void initState() { + super.initState(); + _timer = Timer.periodic(const Duration(seconds: 5), (timer) { + setState(() { + // Your state update logic here + }); + }); + } + + @override + void dispose() { + _timer?.cancel(); + super.dispose(); + } + + @override + Widget build(BuildContext context) { + return ExpandableMenuItemWidget( + title: "Faces Debug", + selectionOptionsWidget: _getSectionOptions(context), + leadingIcon: Icons.bug_report_outlined, + ); + } + + Widget _getSectionOptions(BuildContext context) { + final Logger _logger = Logger("FaceDebugSectionWidget"); + return Column( + children: [ + MenuItemWidget( + captionedTextWidget: FutureBuilder( + future: FaceMLDataDB.instance.getIndexedFileCount(), + builder: (context, snapshot) { + if (snapshot.hasData) { + return CaptionedTextWidget( + title: LocalSettings.instance.isFaceIndexingEnabled + ? "Disable faces (${snapshot.data!} files done)" + : "Enable faces (${snapshot.data!} files done)", + ); + } + return const SizedBox.shrink(); + }, + ), + pressedColor: getEnteColorScheme(context).fillFaint, + trailingIcon: Icons.chevron_right_outlined, + trailingIconIsMuted: true, + onTap: () async { + try { + final isEnabled = + await LocalSettings.instance.toggleFaceIndexing(); + if (!isEnabled) { + FaceMlService.instance.pauseIndexing(); + } + if (mounted) { + setState(() {}); + } + } catch (e, s) { + _logger.warning('indexing failed ', e, s); + await showGenericErrorDialog(context: context, error: e); + } + }, + ), + sectionOptionSpacing, + MenuItemWidget( + captionedTextWidget: CaptionedTextWidget( + title: LocalSettings.instance.remoteFetchEnabled + ? "Remote fetch enabled" + : "Remote fetch disabled", + ), + pressedColor: getEnteColorScheme(context).fillFaint, + trailingIcon: Icons.chevron_right_outlined, + trailingIconIsMuted: true, + onTap: () async { + try { + await LocalSettings.instance.toggleRemoteFetch(); + if (mounted) { + setState(() {}); + } + } catch (e, s) { + _logger.warning('indexing failed ', e, s); + await showGenericErrorDialog(context: context, error: e); + } + }, + ), + sectionOptionSpacing, + MenuItemWidget( + captionedTextWidget: CaptionedTextWidget( + title: FaceMlService.instance.canRunMLController + ? 
"canRunML enabled" + : "canRunML disabled", + ), + pressedColor: getEnteColorScheme(context).fillFaint, + trailingIcon: Icons.chevron_right_outlined, + trailingIconIsMuted: true, + onTap: () async { + try { + FaceMlService.instance.canRunMLController = + !FaceMlService.instance.canRunMLController; + if (mounted) { + setState(() {}); + } + } catch (e, s) { + _logger.warning('canRunML toggle failed ', e, s); + await showGenericErrorDialog(context: context, error: e); + } + }, + ), + sectionOptionSpacing, + MenuItemWidget( + captionedTextWidget: const CaptionedTextWidget( + title: "Run sync, indexing, clustering", + ), + pressedColor: getEnteColorScheme(context).fillFaint, + trailingIcon: Icons.chevron_right_outlined, + trailingIconIsMuted: true, + onTap: () async { + try { + unawaited(FaceMlService.instance.indexAndClusterAll()); + } catch (e, s) { + _logger.warning('indexAndClusterAll failed ', e, s); + await showGenericErrorDialog(context: context, error: e); + } + }, + ), + sectionOptionSpacing, + MenuItemWidget( + captionedTextWidget: const CaptionedTextWidget( + title: "Run indexing", + ), + pressedColor: getEnteColorScheme(context).fillFaint, + trailingIcon: Icons.chevron_right_outlined, + trailingIconIsMuted: true, + onTap: () async { + try { + unawaited(FaceMlService.instance.indexAllImages()); + } catch (e, s) { + _logger.warning('indexing failed ', e, s); + await showGenericErrorDialog(context: context, error: e); + } + }, + ), + sectionOptionSpacing, + MenuItemWidget( + captionedTextWidget: FutureBuilder( + future: FaceMLDataDB.instance.getClusteredToTotalFacesRatio(), + builder: (context, snapshot) { + if (snapshot.hasData) { + return CaptionedTextWidget( + title: + "Run clustering (${(100 * snapshot.data!).toStringAsFixed(0)}% done)", + ); + } + return const SizedBox.shrink(); + }, + ), + pressedColor: getEnteColorScheme(context).fillFaint, + trailingIcon: Icons.chevron_right_outlined, + trailingIconIsMuted: true, + onTap: () async { + try { + await PersonService.instance.storeRemoteFeedback(); + await FaceMlService.instance + .clusterAllImages(clusterInBuckets: true); + Bus.instance.fire(PeopleChangedEvent()); + showShortToast(context, "Done"); + } catch (e, s) { + _logger.warning('clustering failed ', e, s); + await showGenericErrorDialog(context: context, error: e); + } + }, + ), + sectionOptionSpacing, + MenuItemWidget( + captionedTextWidget: const CaptionedTextWidget( + title: "Check for mixed clusters", + ), + pressedColor: getEnteColorScheme(context).fillFaint, + trailingIcon: Icons.chevron_right_outlined, + trailingIconIsMuted: true, + onTap: () async { + try { + final susClusters = + await ClusterFeedbackService.instance.checkForMixedClusters(); + for (final clusterinfo in susClusters) { + Future.delayed(const Duration(seconds: 4), () { + showToast( + context, + 'Cluster with ${clusterinfo.$2} photos is sus', + ); + }); + } + } catch (e, s) { + _logger.warning('Checking for mixed clusters failed', e, s); + await showGenericErrorDialog(context: context, error: e); + } + }, + ), + sectionOptionSpacing, + MenuItemWidget( + captionedTextWidget: const CaptionedTextWidget( + title: "Sync person mappings ", + ), + pressedColor: getEnteColorScheme(context).fillFaint, + trailingIcon: Icons.chevron_right_outlined, + trailingIconIsMuted: true, + onTap: () async { + try { + await PersonService.instance.reconcileClusters(); + Bus.instance.fire(PeopleChangedEvent()); + showShortToast(context, "Done"); + } catch (e, s) { + _logger.warning('sync person mappings failed ', e, s); + await 
showGenericErrorDialog(context: context, error: e); + } + }, + ), + sectionOptionSpacing, + MenuItemWidget( + captionedTextWidget: const CaptionedTextWidget( + title: "Reset feedback", + ), + pressedColor: getEnteColorScheme(context).fillFaint, + trailingIcon: Icons.chevron_right_outlined, + trailingIconIsMuted: true, + alwaysShowSuccessState: true, + onTap: () async { + await showChoiceDialog( + context, + title: "Are you sure?", + body: + "This will drop all people and their related feedback. It will keep clustering labels and embeddings untouched.", + firstButtonLabel: "Yes, confirm", + firstButtonOnTap: () async { + try { + await FaceMLDataDB.instance.dropFeedbackTables(); + Bus.instance.fire(PeopleChangedEvent()); + showShortToast(context, "Done"); + } catch (e, s) { + _logger.warning('reset feedback failed ', e, s); + await showGenericErrorDialog(context: context, error: e); + } + }, + ); + }, + ), + sectionOptionSpacing, + MenuItemWidget( + captionedTextWidget: const CaptionedTextWidget( + title: "Reset feedback and clustering", + ), + pressedColor: getEnteColorScheme(context).fillFaint, + trailingIcon: Icons.chevron_right_outlined, + trailingIconIsMuted: true, + onTap: () async { + await showChoiceDialog( + context, + title: "Are you sure?", + body: + "This will delete all people, their related feedback and clustering labels. It will keep embeddings untouched.", + firstButtonLabel: "Yes, confirm", + firstButtonOnTap: () async { + try { + final List persons = + await PersonService.instance.getPersons(); + for (final PersonEntity p in persons) { + await PersonService.instance.deletePerson(p.remoteID); + } + await FaceMLDataDB.instance.dropClustersAndPersonTable(); + Bus.instance.fire(PeopleChangedEvent()); + showShortToast(context, "Done"); + } catch (e, s) { + _logger.warning('peopleToPersonMapping remove failed ', e, s); + await showGenericErrorDialog(context: context, error: e); + } + }, + ); + }, + ), + sectionOptionSpacing, + MenuItemWidget( + captionedTextWidget: const CaptionedTextWidget( + title: "Reset everything (embeddings)", + ), + pressedColor: getEnteColorScheme(context).fillFaint, + trailingIcon: Icons.chevron_right_outlined, + trailingIconIsMuted: true, + onTap: () async { + await showChoiceDialog( + context, + title: "Are you sure?", + body: + "You will need to again re-index all the faces. 
You can drop feedback if you want to label again", + firstButtonLabel: "Yes, confirm", + firstButtonOnTap: () async { + try { + await FaceMLDataDB.instance + .dropClustersAndPersonTable(faces: true); + Bus.instance.fire(PeopleChangedEvent()); + showShortToast(context, "Done"); + } catch (e, s) { + _logger.warning('drop feedback failed ', e, s); + await showGenericErrorDialog(context: context, error: e); + } + }, + ); + }, + ), + ], + ); + } +} diff --git a/mobile/lib/ui/settings/machine_learning_settings_page.dart b/mobile/lib/ui/settings/machine_learning_settings_page.dart index 9820d882f..9fa98e46f 100644 --- a/mobile/lib/ui/settings/machine_learning_settings_page.dart +++ b/mobile/lib/ui/settings/machine_learning_settings_page.dart @@ -1,11 +1,15 @@ import "dart:async"; +import "dart:math" show max, min; import "package:flutter/material.dart"; import "package:intl/intl.dart"; import "package:photos/core/event_bus.dart"; import 'package:photos/events/embedding_updated_event.dart'; +import "package:photos/face/db.dart"; import "package:photos/generated/l10n.dart"; +import "package:photos/models/ml/ml_versions.dart"; import "package:photos/service_locator.dart"; +import "package:photos/services/machine_learning/face_ml/face_ml_service.dart"; import 'package:photos/services/machine_learning/semantic_search/frameworks/ml_framework.dart'; import 'package:photos/services/machine_learning/semantic_search/semantic_search_service.dart'; import "package:photos/services/remote_assets_service.dart"; @@ -60,6 +64,7 @@ class _MachineLearningSettingsPageState @override Widget build(BuildContext context) { + final bool facesFlag = flagService.faceSearchEnabled; return Scaffold( body: CustomScrollView( primary: false, @@ -91,6 +96,10 @@ class _MachineLearningSettingsPageState mainAxisSize: MainAxisSize.min, children: [ _getMagicSearchSettings(context), + const SizedBox(height: 12), + facesFlag + ? _getFacesSearchSettings(context) + : const SizedBox.shrink(), ], ), ), @@ -176,6 +185,51 @@ class _MachineLearningSettingsPageState ], ); } + + Widget _getFacesSearchSettings(BuildContext context) { + final colorScheme = getEnteColorScheme(context); + final hasEnabled = LocalSettings.instance.isFaceIndexingEnabled; + return Column( + children: [ + MenuItemWidget( + captionedTextWidget: CaptionedTextWidget( + title: S.of(context).faceRecognition, + ), + menuItemColor: colorScheme.fillFaint, + trailingWidget: ToggleSwitchWidget( + value: () => LocalSettings.instance.isFaceIndexingEnabled, + onChanged: () async { + final isEnabled = + await LocalSettings.instance.toggleFaceIndexing(); + if (isEnabled) { + unawaited(FaceMlService.instance.ensureInitialized()); + } else { + FaceMlService.instance.pauseIndexing(); + } + if (mounted) { + setState(() {}); + } + }, + ), + singleBorderRadius: 8, + alignCaptionedTextToLeft: true, + isGestureDetectorDisabled: true, + ), + const SizedBox( + height: 4, + ), + MenuSectionDescriptionWidget( + content: S.of(context).faceRecognitionIndexingDescription, + ), + const SizedBox( + height: 12, + ), + hasEnabled + ? const FaceRecognitionStatusWidget() + : const SizedBox.shrink(), + ], + ); + } } class ModelLoadingState extends StatefulWidget { @@ -356,3 +410,134 @@ class _MagicSearchIndexStatsWidgetState ); } } + +class FaceRecognitionStatusWidget extends StatefulWidget { + const FaceRecognitionStatusWidget({ + super.key, + }); + + @override + State createState() => + FaceRecognitionStatusWidgetState(); +} + +class FaceRecognitionStatusWidgetState + extends State { + Timer? 
_timer; + @override + void initState() { + super.initState(); + _timer = Timer.periodic(const Duration(seconds: 10), (timer) { + setState(() { + // Your state update logic here + }); + }); + } + + Future<(int, int, int, double)> getIndexStatus() async { + final indexedFiles = await FaceMLDataDB.instance + .getIndexedFileCount(minimumMlVersion: faceMlVersion); + final indexableFiles = (await FaceMlService.getIndexableFileIDs()).length; + final showIndexedFiles = min(indexedFiles, indexableFiles); + final pendingFiles = max(indexableFiles - indexedFiles, 0); + final foundFaces = await FaceMLDataDB.instance.getTotalFaceCount(); + final clusteredFaces = await FaceMLDataDB.instance.getClusteredFaceCount(); + final clusteringDoneRatio = clusteredFaces / foundFaces; + + return (showIndexedFiles, pendingFiles, foundFaces, clusteringDoneRatio); + } + + @override + void dispose() { + _timer?.cancel(); + super.dispose(); + } + + @override + Widget build(BuildContext context) { + return Column( + children: [ + Row( + children: [ + MenuSectionTitle(title: S.of(context).status), + Expanded(child: Container()), + ], + ), + FutureBuilder( + future: getIndexStatus(), + builder: (context, snapshot) { + if (snapshot.hasData) { + final int indexedFiles = snapshot.data!.$1; + final int pendingFiles = snapshot.data!.$2; + final int foundFaces = snapshot.data!.$3; + final double clusteringDoneRatio = snapshot.data!.$4; + final double clusteringPercentage = + (clusteringDoneRatio * 100).clamp(0, 100); + + return Column( + children: [ + MenuItemWidget( + captionedTextWidget: CaptionedTextWidget( + title: S.of(context).indexedItems, + ), + trailingWidget: Text( + NumberFormat().format(indexedFiles), + style: Theme.of(context).textTheme.bodySmall, + ), + singleBorderRadius: 8, + alignCaptionedTextToLeft: true, + isGestureDetectorDisabled: true, + key: ValueKey("indexed_items_" + indexedFiles.toString()), + ), + MenuItemWidget( + captionedTextWidget: CaptionedTextWidget( + title: S.of(context).pendingItems, + ), + trailingWidget: Text( + NumberFormat().format(pendingFiles), + style: Theme.of(context).textTheme.bodySmall, + ), + singleBorderRadius: 8, + alignCaptionedTextToLeft: true, + isGestureDetectorDisabled: true, + key: ValueKey("pending_items_" + pendingFiles.toString()), + ), + MenuItemWidget( + captionedTextWidget: CaptionedTextWidget( + title: S.of(context).foundFaces, + ), + trailingWidget: Text( + NumberFormat().format(foundFaces), + style: Theme.of(context).textTheme.bodySmall, + ), + singleBorderRadius: 8, + alignCaptionedTextToLeft: true, + isGestureDetectorDisabled: true, + key: ValueKey("found_faces_" + foundFaces.toString()), + ), + MenuItemWidget( + captionedTextWidget: CaptionedTextWidget( + title: S.of(context).clusteringProgress, + ), + trailingWidget: Text( + "${clusteringPercentage.toStringAsFixed(0)}%", + style: Theme.of(context).textTheme.bodySmall, + ), + singleBorderRadius: 8, + alignCaptionedTextToLeft: true, + isGestureDetectorDisabled: true, + key: ValueKey( + "clustering_progress_" + + clusteringPercentage.toStringAsFixed(0), + ), + ), + ], + ); + } + return const EnteLoadingWidget(); + }, + ), + ], + ); + } +} diff --git a/mobile/lib/ui/settings_page.dart b/mobile/lib/ui/settings_page.dart index d5ba1254f..cc0064a30 100644 --- a/mobile/lib/ui/settings_page.dart +++ b/mobile/lib/ui/settings_page.dart @@ -17,7 +17,8 @@ import 'package:photos/ui/settings/about_section_widget.dart'; import 'package:photos/ui/settings/account_section_widget.dart'; import 
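The status widget above clamps its numbers so the UI never shows more indexed files than indexable ones, or a negative pending count; both can transiently happen while files are still syncing. The arithmetic in isolation, mirroring `getIndexStatus`:

```dart
import 'dart:math' show max, min;

/// (indexed, pending) with indexed capped at indexable and pending
/// floored at zero.
(int, int) indexCounts(int indexedFiles, int indexableFiles) => (
      min(indexedFiles, indexableFiles),
      max(indexableFiles - indexedFiles, 0),
    );
```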
'package:photos/ui/settings/app_version_widget.dart'; import 'package:photos/ui/settings/backup/backup_section_widget.dart'; -import 'package:photos/ui/settings/debug_section_widget.dart'; +import 'package:photos/ui/settings/debug/debug_section_widget.dart'; +import "package:photos/ui/settings/debug/face_debug_section_widget.dart"; import "package:photos/ui/settings/developer_settings_widget.dart"; import 'package:photos/ui/settings/general_section_widget.dart'; import 'package:photos/ui/settings/inherited_settings_state.dart'; @@ -53,6 +54,7 @@ class SettingsPage extends StatelessWidget { final hasLoggedIn = Configuration.instance.isLoggedIn(); final enteTextTheme = getEnteTextTheme(context); final List contents = []; + const sectionSpacing = SizedBox(height: 8); contents.add( GestureDetector( onDoubleTap: () { @@ -82,7 +84,7 @@ class SettingsPage extends StatelessWidget { ), ), ); - const sectionSpacing = SizedBox(height: 8); + contents.add(const SizedBox(height: 8)); if (hasLoggedIn) { final showStorageBonusBanner = @@ -142,6 +144,9 @@ class SettingsPage extends StatelessWidget { if (hasLoggedIn && flagService.internalUser) { contents.addAll([sectionSpacing, const DebugSectionWidget()]); + if (flagService.faceSearchEnabled) { + contents.addAll([sectionSpacing, const FaceDebugSectionWidget()]); + } } contents.add(const AppVersionWidget()); contents.add(const DeveloperSettingsWidget()); diff --git a/mobile/lib/ui/tools/app_lock.dart b/mobile/lib/ui/tools/app_lock.dart index c27555df0..c9af24f71 100644 --- a/mobile/lib/ui/tools/app_lock.dart +++ b/mobile/lib/ui/tools/app_lock.dart @@ -113,6 +113,7 @@ class _AppLockState extends State with WidgetsBindingObserver { theme: widget.lightTheme, darkTheme: widget.darkTheme, locale: widget.locale, + debugShowCheckedModeBanner: false, supportedLocales: appSupportedLocales, localeListResolutionCallback: localResolutionCallBack, localizationsDelegates: const [ diff --git a/mobile/lib/ui/viewer/actions/file_selection_actions_widget.dart b/mobile/lib/ui/viewer/actions/file_selection_actions_widget.dart index e805927a6..beeb9164d 100644 --- a/mobile/lib/ui/viewer/actions/file_selection_actions_widget.dart +++ b/mobile/lib/ui/viewer/actions/file_selection_actions_widget.dart @@ -1,11 +1,15 @@ import "dart:async"; import 'package:fast_base58/fast_base58.dart'; +import "package:flutter/cupertino.dart"; import 'package:flutter/material.dart'; import 'package:flutter/services.dart'; import "package:logging/logging.dart"; import "package:modal_bottom_sheet/modal_bottom_sheet.dart"; import 'package:photos/core/configuration.dart'; +import "package:photos/core/event_bus.dart"; +import "package:photos/events/people_changed_event.dart"; +import "package:photos/face/model/person.dart"; import "package:photos/generated/l10n.dart"; import 'package:photos/models/collection/collection.dart'; import 'package:photos/models/device_collection.dart'; @@ -17,6 +21,8 @@ import "package:photos/models/metadata/common_keys.dart"; import 'package:photos/models/selected_files.dart'; import 'package:photos/services/collections_service.dart'; import 'package:photos/services/hidden_service.dart'; +import 'package:photos/services/machine_learning/face_ml/feedback/cluster_feedback.dart'; +import "package:photos/services/machine_learning/face_ml/person/person_service.dart"; import "package:photos/theme/colors.dart"; import "package:photos/theme/ente_theme.dart"; import 'package:photos/ui/actions/collection/collection_file_actions.dart'; @@ -42,12 +48,16 @@ class 
FileSelectionActionsWidget extends StatefulWidget { final Collection? collection; final DeviceCollection? deviceCollection; final SelectedFiles selectedFiles; + final PersonEntity? person; + final int? clusterID; const FileSelectionActionsWidget( this.type, this.selectedFiles, { Key? key, this.collection, + this.person, + this.clusterID, this.deviceCollection, }) : super(key: key); @@ -123,7 +133,24 @@ class _FileSelectionActionsWidgetState //and set [shouldShow] to false for items that should not be shown and true //for items that should be shown. final List items = []; - + if (widget.type == GalleryType.peopleTag && widget.person != null) { + items.add( + SelectionActionButton( + icon: Icons.remove_circle_outline, + labelText: 'Not ${widget.person!.data.name}?', + onTap: anyUploadedFiles ? _onNotpersonClicked : null, + ), + ); + if (ownedFilesCount == 1) { + items.add( + SelectionActionButton( + icon: Icons.image_outlined, + labelText: 'Use as cover', + onTap: anyUploadedFiles ? _setPersonCover : null, + ), + ); + } + } if (widget.type.showCreateLink()) { if (_cachedCollectionForSharedLink != null && anyUploadedFiles) { items.add( @@ -390,36 +417,50 @@ class _FileSelectionActionsWidgetState ), ); - final scrollController = ScrollController(); - // h4ck: https://github.com/flutter/flutter/issues/57920#issuecomment-893970066 - return MediaQuery( - data: MediaQuery.of(context).removePadding(removeBottom: true), - child: SafeArea( - child: Scrollbar( - radius: const Radius.circular(1), - thickness: 2, - controller: scrollController, - thumbVisibility: true, - child: SingleChildScrollView( - physics: const BouncingScrollPhysics( - decelerationRate: ScrollDecelerationRate.fast, - ), - scrollDirection: Axis.horizontal, - child: Container( - padding: const EdgeInsets.only(bottom: 24), - child: Row( - crossAxisAlignment: CrossAxisAlignment.start, - children: [ - const SizedBox(width: 4), - ...items, - const SizedBox(width: 4), - ], + // if (widget.type == GalleryType.cluster && widget.clusterID != null) { + if (widget.type == GalleryType.cluster && widget.clusterID != null) { + items.add( + SelectionActionButton( + labelText: 'Remove', + icon: CupertinoIcons.minus, + onTap: anyUploadedFiles ? 
_onRemoveFromClusterClicked : null, + ), + ); + } + + if (items.isNotEmpty) { + final scrollController = ScrollController(); + // h4ck: https://github.com/flutter/flutter/issues/57920#issuecomment-893970066 + return MediaQuery( + data: MediaQuery.of(context).removePadding(removeBottom: true), + child: SafeArea( + child: Scrollbar( + radius: const Radius.circular(1), + thickness: 2, + controller: scrollController, + thumbVisibility: true, + child: SingleChildScrollView( + physics: const BouncingScrollPhysics( + decelerationRate: ScrollDecelerationRate.fast, + ), + scrollDirection: Axis.horizontal, + child: Container( + padding: const EdgeInsets.only(bottom: 24), + child: Row( + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + const SizedBox(width: 4), + ...items, + const SizedBox(width: 4), + ], + ), ), ), ), ), - ), - ); + ); + } + return const SizedBox(); } Future<void> _moveFiles() async { @@ -620,6 +661,101 @@ class _FileSelectionActionsWidgetState } } + Future<void> _setPersonCover() async { + final EnteFile file = widget.selectedFiles.files.first; + await PersonService.instance.updateAttributes( + widget.person!.remoteID, + avatarFaceId: file.uploadedFileID.toString(), + ); + widget.selectedFiles.clearAll(); + if (mounted) { + setState(() => {}); + } + Bus.instance.fire(PeopleChangedEvent()); + } + + Future<void> _onNotPersonClicked() async { + final actionResult = await showActionSheet( + context: context, + buttons: [ + ButtonWidget( + labelText: S.of(context).yesRemove, + buttonType: ButtonType.neutral, + buttonSize: ButtonSize.large, + shouldStickToDarkTheme: true, + buttonAction: ButtonAction.first, + isInAlert: true, + ), + ButtonWidget( + labelText: S.of(context).cancel, + buttonType: ButtonType.secondary, + buttonSize: ButtonSize.large, + buttonAction: ButtonAction.second, + shouldStickToDarkTheme: true, + isInAlert: true, + ), + ], + title: "Remove these photos for ${widget.person!.data.name}?", + actionSheetType: ActionSheetType.defaultActionSheet, + ); + if (actionResult?.action != null) { + if (actionResult!.action == ButtonAction.first) { + await ClusterFeedbackService.instance.removeFilesFromPerson( + widget.selectedFiles.files.toList(), + widget.person!, + ); + } + Bus.instance.fire(PeopleChangedEvent()); + } + widget.selectedFiles.clearAll(); + if (mounted) { + setState(() => {}); + } + } + + Future<void> _onRemoveFromClusterClicked() async { + if (widget.clusterID == null) { + showShortToast(context, 'Cluster ID is null. 
Cannot remove files.'); + return; + } + final actionResult = await showActionSheet( + context: context, + buttons: [ + ButtonWidget( + labelText: S.of(context).yesRemove, + buttonType: ButtonType.neutral, + buttonSize: ButtonSize.large, + shouldStickToDarkTheme: true, + buttonAction: ButtonAction.first, + isInAlert: true, + ), + ButtonWidget( + labelText: S.of(context).cancel, + buttonType: ButtonType.secondary, + buttonSize: ButtonSize.large, + buttonAction: ButtonAction.second, + shouldStickToDarkTheme: true, + isInAlert: true, + ), + ], + title: "Remove these photos?", + actionSheetType: ActionSheetType.defaultActionSheet, + ); + if (actionResult?.action != null) { + if (actionResult!.action == ButtonAction.first) { + await ClusterFeedbackService.instance.removeFilesFromCluster( + widget.selectedFiles.files.toList(), + widget.clusterID!, + ); + } + Bus.instance.fire(PeopleChangedEvent()); + } + widget.selectedFiles.clearAll(); + if (mounted) { + setState(() => {}); + } + } + Future _copyLink() async { if (_cachedCollectionForSharedLink != null) { final String collectionKey = Base58Encode( diff --git a/mobile/lib/ui/viewer/actions/file_selection_overlay_bar.dart b/mobile/lib/ui/viewer/actions/file_selection_overlay_bar.dart index bc832c573..8e2260c74 100644 --- a/mobile/lib/ui/viewer/actions/file_selection_overlay_bar.dart +++ b/mobile/lib/ui/viewer/actions/file_selection_overlay_bar.dart @@ -1,4 +1,5 @@ import 'package:flutter/material.dart'; +import "package:photos/face/model/person.dart"; import 'package:photos/models/collection/collection.dart'; import 'package:photos/models/gallery_type.dart'; import 'package:photos/models/selected_files.dart'; @@ -10,12 +11,16 @@ class FileSelectionOverlayBar extends StatefulWidget { final SelectedFiles selectedFiles; final Collection? collection; final Color? backgroundColor; + final PersonEntity? person; + final int? clusterID; const FileSelectionOverlayBar( this.galleryType, this.selectedFiles, { this.collection, this.backgroundColor, + this.person, + this.clusterID, Key? 
key, }) : super(key: key); @@ -65,6 +70,8 @@ class _FileSelectionOverlayBarState extends State { selectedFiles: widget.selectedFiles, galleryType: widget.galleryType, collection: widget.collection, + person: widget.person, + clusterID: widget.clusterID, onCancel: () { if (widget.selectedFiles.files.isNotEmpty) { widget.selectedFiles.clearAll(); diff --git a/mobile/lib/ui/viewer/file/file_details_widget.dart b/mobile/lib/ui/viewer/file/file_details_widget.dart index f8e7abb8e..d87a806cc 100644 --- a/mobile/lib/ui/viewer/file/file_details_widget.dart +++ b/mobile/lib/ui/viewer/file/file_details_widget.dart @@ -1,7 +1,11 @@ +import "dart:async" show StreamSubscription; + import "package:exif/exif.dart"; import "package:flutter/material.dart"; import "package:logging/logging.dart"; import "package:photos/core/configuration.dart"; +import "package:photos/core/event_bus.dart"; +import "package:photos/events/people_changed_event.dart"; import "package:photos/generated/l10n.dart"; import 'package:photos/models/file/file.dart'; import 'package:photos/models/file/file_type.dart'; @@ -18,9 +22,9 @@ import "package:photos/ui/viewer/file_details/albums_item_widget.dart"; import 'package:photos/ui/viewer/file_details/backed_up_time_item_widget.dart'; import "package:photos/ui/viewer/file_details/creation_time_item_widget.dart"; import 'package:photos/ui/viewer/file_details/exif_item_widgets.dart'; +import "package:photos/ui/viewer/file_details/faces_item_widget.dart"; import "package:photos/ui/viewer/file_details/file_properties_item_widget.dart"; import "package:photos/ui/viewer/file_details/location_tags_widget.dart"; -import "package:photos/ui/viewer/file_details/objects_item_widget.dart"; import "package:photos/utils/exif_util.dart"; class FileDetailsWidget extends StatefulWidget { @@ -51,6 +55,8 @@ class _FileDetailsWidgetState extends State { "longRef": null, }; + late final StreamSubscription _peopleChangedEvent; + bool _isImage = false; late int _currentUserID; bool showExifListTile = false; @@ -65,6 +71,10 @@ class _FileDetailsWidgetState extends State { _isImage = widget.file.fileType == FileType.image || widget.file.fileType == FileType.livePhoto; + _peopleChangedEvent = Bus.instance.on().listen((event) { + setState(() {}); + }); + _exifNotifier.addListener(() { if (_exifNotifier.value != null && !widget.file.hasLocation) { _updateLocationFromExif(_exifNotifier.value!).ignore(); @@ -93,6 +103,7 @@ class _FileDetailsWidgetState extends State { @override void dispose() { _exifNotifier.dispose(); + _peopleChangedEvent.cancel(); super.dispose(); } @@ -221,7 +232,8 @@ class _FileDetailsWidgetState extends State { if (!UpdateService.instance.isFdroidFlavor()) { fileDetailsTiles.addAll([ - ObjectsItemWidget(file), + // ObjectsItemWidget(file), + FacesItemWidget(file), const FileDetailsDivider(), ]); } diff --git a/mobile/lib/ui/viewer/file_details/face_widget.dart b/mobile/lib/ui/viewer/file_details/face_widget.dart new file mode 100644 index 000000000..1ec7a2eb2 --- /dev/null +++ b/mobile/lib/ui/viewer/file_details/face_widget.dart @@ -0,0 +1,506 @@ +import "dart:developer" show log; +import "dart:typed_data"; + +import "package:flutter/cupertino.dart"; +import "package:flutter/foundation.dart" show kDebugMode; +import "package:flutter/material.dart"; +import "package:photos/extensions/stop_watch.dart"; +import "package:photos/face/db.dart"; +import "package:photos/face/model/face.dart"; +import "package:photos/face/model/person.dart"; +import 'package:photos/models/file/file.dart'; +import 
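// ---------------------------------------------------------------------------
// Editor's sketch (not part of the diff): FileDetailsWidget above pairs a
// people-changed subscription created in initState with a cancel() added in
// dispose, so the tile rebuilds when people data changes without leaking the
// listener. The same lifecycle with a plain Stream standing in for the
// app's event bus:
import 'dart:async';
import 'package:flutter/widgets.dart';

class RebuildOnEvent extends StatefulWidget {
  final Stream<void> events;
  final WidgetBuilder builder;
  const RebuildOnEvent({super.key, required this.events, required this.builder});

  @override
  State<RebuildOnEvent> createState() => _RebuildOnEventState();
}

class _RebuildOnEventState extends State<RebuildOnEvent> {
  late final StreamSubscription<void> _sub;

  @override
  void initState() {
    super.initState();
    _sub = widget.events.listen((_) => setState(() {})); // rebuild on event
  }

  @override
  void dispose() {
    _sub.cancel(); // mirrors the cancel() added in dispose() above
    super.dispose();
  }

  @override
  Widget build(BuildContext context) => widget.builder(context);
}
// ---------------------------------------------------------------------------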
"package:photos/services/machine_learning/face_ml/face_detection/detection.dart"; +import "package:photos/services/machine_learning/face_ml/feedback/cluster_feedback.dart"; +import "package:photos/services/search_service.dart"; +import "package:photos/theme/ente_theme.dart"; +import "package:photos/ui/viewer/file/no_thumbnail_widget.dart"; +import "package:photos/ui/viewer/people/cluster_page.dart"; +import "package:photos/ui/viewer/people/cropped_face_image_view.dart"; +import "package:photos/ui/viewer/people/people_page.dart"; +import "package:photos/utils/face/face_box_crop.dart"; +import "package:photos/utils/thumbnail_util.dart"; +// import "package:photos/utils/toast_util.dart"; + +const useGeneratedFaceCrops = true; + +class FaceWidget extends StatefulWidget { + final EnteFile file; + final Face face; + final Future?>? faceCrops; + final PersonEntity? person; + final int? clusterID; + final bool highlight; + final bool editMode; + + const FaceWidget( + this.file, + this.face, { + this.faceCrops, + this.person, + this.clusterID, + this.highlight = false, + this.editMode = false, + Key? key, + }) : super(key: key); + + @override + State createState() => _FaceWidgetState(); +} + +class _FaceWidgetState extends State { + bool isJustRemoved = false; + + @override + Widget build(BuildContext context) { + final bool givenFaces = widget.faceCrops != null; + if (useGeneratedFaceCrops) { + return _buildFaceImageGenerated(givenFaces); + } else { + return _buildFaceImageFlutterZoom(); + } + } + + Widget _buildFaceImageGenerated(bool givenFaces) { + return FutureBuilder?>( + future: givenFaces ? widget.faceCrops : getFaceCrop(), + builder: (context, snapshot) { + if (snapshot.hasData) { + final ImageProvider imageProvider = + MemoryImage(snapshot.data![widget.face.faceID]!); + + return GestureDetector( + onTap: () async { + if (widget.editMode) return; + + log( + "FaceWidget is tapped, with person ${widget.person} and clusterID ${widget.clusterID}", + name: "FaceWidget", + ); + if (widget.person == null && widget.clusterID == null) { + // Get faceID and double check that it doesn't belong to an existing clusterID. If it does, push that cluster page + final w = (kDebugMode ? EnteWatch('FaceWidget') : null) + ?..start(); + final existingClusterID = await FaceMLDataDB.instance + .getClusterIDForFaceID(widget.face.faceID); + w?.log('getting existing clusterID for faceID'); + if (existingClusterID != null) { + final fileIdsToClusterIds = + await FaceMLDataDB.instance.getFileIdToClusterIds(); + final files = await SearchService.instance.getAllFiles(); + final clusterFiles = files + .where( + (file) => + fileIdsToClusterIds[file.uploadedFileID] + ?.contains(existingClusterID) ?? 
+ false, + ) + .toList(); + await Navigator.of(context).push( + MaterialPageRoute( + builder: (context) => ClusterPage( + clusterFiles, + clusterID: existingClusterID, + ), + ), + ); + } + + // Create new clusterID for the faceID and update DB to assign the faceID to the new clusterID + final int newClusterID = DateTime.now().microsecondsSinceEpoch; + await FaceMLDataDB.instance.updateFaceIdToClusterId( + {widget.face.faceID: newClusterID}, + ); + + // Push page for the new cluster + await Navigator.of(context).push( + MaterialPageRoute( + builder: (context) => ClusterPage( + [widget.file], + clusterID: newClusterID, + ), + ), + ); + } + if (widget.person != null) { + await Navigator.of(context).push( + MaterialPageRoute( + builder: (context) => PeoplePage( + person: widget.person!, + ), + ), + ); + } else if (widget.clusterID != null) { + final fileIdsToClusterIds = + await FaceMLDataDB.instance.getFileIdToClusterIds(); + final files = await SearchService.instance.getAllFiles(); + final clusterFiles = files + .where( + (file) => + fileIdsToClusterIds[file.uploadedFileID] + ?.contains(widget.clusterID) ?? + false, + ) + .toList(); + await Navigator.of(context).push( + MaterialPageRoute( + builder: (context) => ClusterPage( + clusterFiles, + clusterID: widget.clusterID!, + ), + ), + ); + } + }, + child: Column( + children: [ + Stack( + children: [ + Container( + height: 60, + width: 60, + decoration: ShapeDecoration( + shape: RoundedRectangleBorder( + borderRadius: const BorderRadius.all( + Radius.elliptical(16, 12), + ), + side: widget.highlight + ? BorderSide( + color: getEnteColorScheme(context).primary700, + width: 1.0, + ) + : BorderSide.none, + ), + ), + child: ClipRRect( + borderRadius: + const BorderRadius.all(Radius.elliptical(16, 12)), + child: SizedBox( + width: 60, + height: 60, + child: Image( + image: imageProvider, + fit: BoxFit.cover, + ), + ), + ), + ), + // TODO: the edges of the green line are still not properly rounded around ClipRRect + if (widget.editMode) + Positioned( + right: 0, + top: 0, + child: GestureDetector( + onTap: _cornerIconPressed, + child: isJustRemoved + ? const Icon( + CupertinoIcons.add_circled_solid, + color: Colors.green, + ) + : const Icon( + Icons.cancel, + color: Colors.red, + ), + ), + ), + ], + ), + const SizedBox(height: 8), + if (widget.person != null) + Text( + widget.person!.data.isIgnored + ? 
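// ---------------------------------------------------------------------------
// Editor's sketch (not part of the diff): the tap handler above resolves a
// face to a cluster in two steps — look up an existing clusterID for the
// faceID, otherwise mint a new ID from the current timestamp and persist the
// face-to-cluster mapping. Distilled below; the FaceMLDataDB method names
// follow the document, and the early return is this editor's reading of the
// intent (without it, the handler would also create a fresh cluster after
// visiting an existing one).
Future<int> resolveOrCreateClusterID(String faceID) async {
  final existing =
      await FaceMLDataDB.instance.getClusterIDForFaceID(faceID);
  if (existing != null) return existing; // reuse the cluster we already have
  final newClusterID = DateTime.now().microsecondsSinceEpoch;
  await FaceMLDataDB.instance
      .updateFaceIdToClusterId({faceID: newClusterID});
  return newClusterID;
}
// ---------------------------------------------------------------------------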
'(ignored)' + : widget.person!.data.name.trim(), + style: Theme.of(context).textTheme.bodySmall, + overflow: TextOverflow.ellipsis, + maxLines: 1, + ), + if (kDebugMode) + Text( + 'S: ${widget.face.score.toStringAsFixed(3)}', + style: Theme.of(context).textTheme.bodySmall, + maxLines: 1, + ), + if (kDebugMode) + Text( + 'B: ${widget.face.blur.toStringAsFixed(0)}', + style: Theme.of(context).textTheme.bodySmall, + maxLines: 1, + ), + if (kDebugMode) + Text( + 'D: ${widget.face.detection.getFaceDirection().toDirectionString()}', + style: Theme.of(context).textTheme.bodySmall, + maxLines: 1, + ), + if (kDebugMode) + Text( + 'Sideways: ${widget.face.detection.faceIsSideways().toString()}', + style: Theme.of(context).textTheme.bodySmall, + maxLines: 1, + ), + if (kDebugMode && widget.face.score < 0.75) + Text( + '[Debug only]', + style: Theme.of(context).textTheme.bodySmall, + maxLines: 1, + ), + ], + ), + ); + } else { + if (snapshot.connectionState == ConnectionState.waiting) { + return const ClipRRect( + borderRadius: BorderRadius.all(Radius.elliptical(16, 12)), + child: SizedBox( + width: 60, + height: 60, + child: CircularProgressIndicator(), + ), + ); + } + if (snapshot.hasError) { + log('Error getting face: ${snapshot.error}'); + } + return const ClipRRect( + borderRadius: BorderRadius.all(Radius.elliptical(16, 12)), + child: SizedBox( + width: 60, + height: 60, + child: NoThumbnailWidget(), + ), + ); + } + }, + ); + } + + void _cornerIconPressed() async { + log('face widget (file info) corner icon is pressed'); + try { + if (isJustRemoved) { + await ClusterFeedbackService.instance + .addFilesToCluster([widget.face.faceID], widget.clusterID!); + } else { + await ClusterFeedbackService.instance + .removeFilesFromCluster([widget.file], widget.clusterID!); + } + + setState(() { + isJustRemoved = !isJustRemoved; + }); + } catch (e, s) { + log("removing face/file from cluster from file info widget failed: $e, \n $s"); + } + } + + Future?> getFaceCrop() async { + try { + final Uint8List? cachedFace = faceCropCache.get(widget.face.faceID); + if (cachedFace != null) { + return {widget.face.faceID: cachedFace}; + } + final faceCropCacheFile = cachedFaceCropPath(widget.face.faceID); + if ((await faceCropCacheFile.exists())) { + final data = await faceCropCacheFile.readAsBytes(); + faceCropCache.put(widget.face.faceID, data); + return {widget.face.faceID: data}; + } + + final result = await poolFullFileFaceGenerations.withResource( + () async => await getFaceCrops( + widget.file, + { + widget.face.faceID: widget.face.detection.box, + }, + ), + ); + final Uint8List? computedCrop = result?[widget.face.faceID]; + if (computedCrop != null) { + faceCropCache.put(widget.face.faceID, computedCrop); + faceCropCacheFile.writeAsBytes(computedCrop).ignore(); + } + return {widget.face.faceID: computedCrop!}; + } catch (e, s) { + log( + "Error getting face for faceID: ${widget.face.faceID}", + error: e, + stackTrace: s, + ); + return null; + } + } + + Widget _buildFaceImageFlutterZoom() { + return Builder( + builder: (context) { + return GestureDetector( + onTap: () async { + log( + "FaceWidget is tapped, with person ${widget.person} and clusterID ${widget.clusterID}", + name: "FaceWidget", + ); + if (widget.person == null && widget.clusterID == null) { + // Get faceID and double check that it doesn't belong to an existing clusterID. If it does, push that cluster page + final w = (kDebugMode ? 
EnteWatch('FaceWidget') : null)?..start(); + final existingClusterID = await FaceMLDataDB.instance + .getClusterIDForFaceID(widget.face.faceID); + w?.log('getting existing clusterID for faceID'); + if (existingClusterID != null) { + final fileIdsToClusterIds = + await FaceMLDataDB.instance.getFileIdToClusterIds(); + final files = await SearchService.instance.getAllFiles(); + final clusterFiles = files + .where( + (file) => + fileIdsToClusterIds[file.uploadedFileID] + ?.contains(existingClusterID) ?? + false, + ) + .toList(); + await Navigator.of(context).push( + MaterialPageRoute( + builder: (context) => ClusterPage( + clusterFiles, + clusterID: existingClusterID, + ), + ), + ); + } + + // Create new clusterID for the faceID and update DB to assign the faceID to the new clusterID + final int newClusterID = DateTime.now().microsecondsSinceEpoch; + await FaceMLDataDB.instance.updateFaceIdToClusterId( + {widget.face.faceID: newClusterID}, + ); + + // Push page for the new cluster + await Navigator.of(context).push( + MaterialPageRoute( + builder: (context) => ClusterPage( + [widget.file], + clusterID: newClusterID, + ), + ), + ); + } + if (widget.person != null) { + await Navigator.of(context).push( + MaterialPageRoute( + builder: (context) => PeoplePage( + person: widget.person!, + ), + ), + ); + } else if (widget.clusterID != null) { + final fileIdsToClusterIds = + await FaceMLDataDB.instance.getFileIdToClusterIds(); + final files = await SearchService.instance.getAllFiles(); + final clusterFiles = files + .where( + (file) => + fileIdsToClusterIds[file.uploadedFileID] + ?.contains(widget.clusterID) ?? + false, + ) + .toList(); + await Navigator.of(context).push( + MaterialPageRoute( + builder: (context) => ClusterPage( + clusterFiles, + clusterID: widget.clusterID!, + ), + ), + ); + } + }, + child: Column( + children: [ + Stack( + children: [ + Container( + height: 60, + width: 60, + decoration: ShapeDecoration( + shape: RoundedRectangleBorder( + borderRadius: const BorderRadius.all( + Radius.elliptical(16, 12), + ), + side: widget.highlight + ? BorderSide( + color: getEnteColorScheme(context).primary700, + width: 1.0, + ) + : BorderSide.none, + ), + ), + child: ClipRRect( + borderRadius: + const BorderRadius.all(Radius.elliptical(16, 12)), + child: SizedBox( + width: 60, + height: 60, + child: CroppedFaceImageView( + enteFile: widget.file, + face: widget.face, + ), + ), + ), + ), + if (widget.editMode) + Positioned( + right: 0, + top: 0, + child: GestureDetector( + onTap: _cornerIconPressed, + child: isJustRemoved + ? 
const Icon( + CupertinoIcons.add_circled_solid, + color: Colors.green, + ) + : const Icon( + Icons.cancel, + color: Colors.red, + ), + ), + ), + ], + ), + const SizedBox(height: 8), + if (widget.person != null) + Text( + widget.person!.data.name.trim(), + style: Theme.of(context).textTheme.bodySmall, + overflow: TextOverflow.ellipsis, + maxLines: 1, + ), + if (kDebugMode) + Text( + 'S: ${widget.face.score.toStringAsFixed(3)}', + style: Theme.of(context).textTheme.bodySmall, + maxLines: 1, + ), + if (kDebugMode) + Text( + 'B: ${widget.face.blur.toStringAsFixed(0)}', + style: Theme.of(context).textTheme.bodySmall, + maxLines: 1, + ), + if (kDebugMode) + Text( + 'D: ${widget.face.detection.getFaceDirection().toDirectionString()}', + style: Theme.of(context).textTheme.bodySmall, + maxLines: 1, + ), + if (kDebugMode) + Text( + 'Sideways: ${widget.face.detection.faceIsSideways().toString()}', + style: Theme.of(context).textTheme.bodySmall, + maxLines: 1, + ), + ], + ), + ); + }, + ); + } +} diff --git a/mobile/lib/ui/viewer/file_details/faces_item_widget.dart b/mobile/lib/ui/viewer/file_details/faces_item_widget.dart new file mode 100644 index 000000000..ed2fb0f12 --- /dev/null +++ b/mobile/lib/ui/viewer/file_details/faces_item_widget.dart @@ -0,0 +1,229 @@ +import "dart:developer" as dev show log; + +import "package:flutter/foundation.dart" show Uint8List, kDebugMode; +import "package:flutter/material.dart"; +import "package:logging/logging.dart"; +import "package:photos/face/db.dart"; +import "package:photos/face/model/box.dart"; +import "package:photos/face/model/face.dart"; +import "package:photos/face/model/person.dart"; +import "package:photos/models/file/file.dart"; +import "package:photos/services/machine_learning/face_ml/feedback/cluster_feedback.dart"; +import "package:photos/services/machine_learning/face_ml/person/person_service.dart"; +import "package:photos/ui/components/buttons/chip_button_widget.dart"; +import "package:photos/ui/components/info_item_widget.dart"; +import "package:photos/ui/viewer/file_details/face_widget.dart"; +import "package:photos/utils/face/face_box_crop.dart"; +import "package:photos/utils/thumbnail_util.dart"; + +class FacesItemWidget extends StatefulWidget { + final EnteFile file; + const FacesItemWidget(this.file, {super.key}); + + @override + State createState() => _FacesItemWidgetState(); +} + +class _FacesItemWidgetState extends State { + bool editMode = false; + + @override + void initState() { + super.initState(); + setState(() {}); + } + + @override + Widget build(BuildContext context) { + return InfoItemWidget( + key: const ValueKey("Faces"), + leadingIcon: Icons.face_retouching_natural_outlined, + subtitleSection: _faceWidgets(context, widget.file, editMode), + hasChipButtons: true, + biggerSpinner: true, + // editOnTap: _toggleEditMode, // TODO: re-enable at later time when the UI is less ugly + ); + } + + void _toggleEditMode() { + setState(() { + editMode = !editMode; + }); + } + + Future> _faceWidgets( + BuildContext context, + EnteFile file, + bool editMode, + ) async { + try { + if (file.uploadedFileID == null) { + return [ + const ChipButtonWidget( + "File not uploaded yet", + noChips: true, + ), + ]; + } + + final List? 
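// ---------------------------------------------------------------------------
// Editor's note (not part of the diff): the widget below orders faces with
// two consecutive sorts — named-before-unnamed/by-score first, hidden-last
// second. Dart's List.sort is not guaranteed to be stable, so the second
// pass can scramble the first; a single comparator expresses the combined
// ordering directly. The boolean/score inputs are placeholders for the
// lookups done below.
int compareFaces(
  ({bool hidden, bool named, double score}) a,
  ({bool hidden, bool named, double score}) b,
) {
  if (a.hidden != b.hidden) return a.hidden ? 1 : -1; // hidden faces last
  if (a.named != b.named) return a.named ? -1 : 1; // named faces first
  return b.score.compareTo(a.score); // then by descending detection score
}
// ---------------------------------------------------------------------------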
faces = await FaceMLDataDB.instance + .getFacesForGivenFileID(file.uploadedFileID!); + if (faces == null) { + return [ + const ChipButtonWidget( + "Image not analyzed", + noChips: true, + ), + ]; + } + + // Remove faces with low scores + if (!kDebugMode) { + faces.removeWhere((face) => (face.score < 0.75)); + } else { + faces.removeWhere((face) => (face.score < 0.5)); + } + + if (faces.isEmpty) { + return [ + const ChipButtonWidget( + "No faces found", + noChips: true, + ), + ]; + } + + final faceIdsToClusterIds = await FaceMLDataDB.instance + .getFaceIdsToClusterIds(faces.map((face) => face.faceID)); + final Map persons = + await PersonService.instance.getPersonsMap(); + final clusterIDToPerson = + await FaceMLDataDB.instance.getClusterIDToPersonID(); + + // Sort faces by name and score + final faceIdToPersonID = {}; + for (final face in faces) { + final clusterID = faceIdsToClusterIds[face.faceID]; + if (clusterID != null) { + final personID = clusterIDToPerson[clusterID]; + if (personID != null) { + faceIdToPersonID[face.faceID] = personID; + } + } + } + faces.sort((Face a, Face b) { + final aPersonID = faceIdToPersonID[a.faceID]; + final bPersonID = faceIdToPersonID[b.faceID]; + if (aPersonID != null && bPersonID == null) { + return -1; + } else if (aPersonID == null && bPersonID != null) { + return 1; + } else { + return b.score.compareTo(a.score); + } + }); + // Make sure hidden faces are last + faces.sort((Face a, Face b) { + final aIsHidden = + persons[faceIdToPersonID[a.faceID]]?.data.isIgnored ?? false; + final bIsHidden = + persons[faceIdToPersonID[b.faceID]]?.data.isIgnored ?? false; + if (aIsHidden && !bIsHidden) { + return 1; + } else if (!aIsHidden && bIsHidden) { + return -1; + } else { + return 0; + } + }); + + final lastViewedClusterID = ClusterFeedbackService.lastViewedClusterID; + + final faceWidgets = []; + + // await generation of the face crops here, so that the file info shows one central loading spinner + final _ = await getRelevantFaceCrops(faces); + + final faceCrops = getRelevantFaceCrops(faces); + for (final Face face in faces) { + final int? clusterID = faceIdsToClusterIds[face.faceID]; + final PersonEntity? person = clusterIDToPerson[clusterID] != null + ? persons[clusterIDToPerson[clusterID]!] + : null; + final highlight = + (clusterID == lastViewedClusterID) && (person == null); + faceWidgets.add( + FaceWidget( + file, + face, + faceCrops: faceCrops, + clusterID: clusterID, + person: person, + highlight: highlight, + editMode: highlight ? editMode : false, + ), + ); + } + + return faceWidgets; + } catch (e, s) { + Logger("FacesItemWidget").info(e, s); + return []; + } + } + + Future?> getRelevantFaceCrops( + Iterable faces, + ) async { + try { + final faceIdToCrop = {}; + final facesWithoutCrops = {}; + for (final face in faces) { + final Uint8List? cachedFace = faceCropCache.get(face.faceID); + if (cachedFace != null) { + faceIdToCrop[face.faceID] = cachedFace; + } else { + final faceCropCacheFile = cachedFaceCropPath(face.faceID); + if ((await faceCropCacheFile.exists())) { + final data = await faceCropCacheFile.readAsBytes(); + faceCropCache.put(face.faceID, data); + faceIdToCrop[face.faceID] = data; + } else { + facesWithoutCrops[face.faceID] = face.detection.box; + } + } + } + + if (facesWithoutCrops.isEmpty) { + return faceIdToCrop; + } + + final result = await poolFullFileFaceGenerations.withResource( + () async => await getFaceCrops( + widget.file, + facesWithoutCrops, + ), + ); + if (result == null) { + return (faceIdToCrop.isEmpty) ? 
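// ---------------------------------------------------------------------------
// Editor's sketch (not part of the diff): getRelevantFaceCrops above is a
// three-tier lookup — in-memory cache, then on-disk cache, then a pooled
// recompute for whatever is still missing, with results written back to both
// tiers. The shape of that pattern with generic tiers (all names here are
// placeholders):
Future<Map<K, V>> layeredLookup<K, V>(
  Iterable<K> keys, {
  required V? Function(K) fromMemory,
  required Future<V?> Function(K) fromDisk,
  required Future<Map<K, V>> Function(Set<K>) compute,
  required void Function(K, V) writeBack,
}) async {
  final found = <K, V>{};
  final missing = <K>{};
  for (final key in keys) {
    final hit = fromMemory(key) ?? await fromDisk(key);
    if (hit != null) {
      found[key] = hit;
    } else {
      missing.add(key);
    }
  }
  if (missing.isEmpty) return found;
  final computed = await compute(missing);
  computed.forEach((key, value) {
    writeBack(key, value); // backfill memory + disk, as the widget does
    found[key] = value;
  });
  return found;
}
// ---------------------------------------------------------------------------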
null : faceIdToCrop; + } + for (final entry in result.entries) { + final Uint8List? computedCrop = result[entry.key]; + if (computedCrop != null) { + faceCropCache.put(entry.key, computedCrop); + final faceCropCacheFile = cachedFaceCropPath(entry.key); + faceCropCacheFile.writeAsBytes(computedCrop).ignore(); + faceIdToCrop[entry.key] = computedCrop; + } + } + return (faceIdToCrop.isEmpty) ? null : faceIdToCrop; + } catch (e, s) { + dev.log( + "Error getting face crops for faceIDs: ${faces.map((face) => face.faceID).toList()}", + error: e, + stackTrace: s, + ); + return null; + } + } +} diff --git a/mobile/lib/ui/viewer/file_details/objects_item_widget.dart b/mobile/lib/ui/viewer/file_details/objects_item_widget.dart index 5b91b9b12..c02576c11 100644 --- a/mobile/lib/ui/viewer/file_details/objects_item_widget.dart +++ b/mobile/lib/ui/viewer/file_details/objects_item_widget.dart @@ -27,6 +27,7 @@ class ObjectsItemWidget extends StatelessWidget { try { final chipButtons = []; var objectTags = {}; + // final thumbnail = await getThumbnail(file); // if (thumbnail != null) { // objectTags = await ObjectDetectionService.instance.predict(thumbnail); diff --git a/mobile/lib/ui/viewer/gallery/gallery_app_bar_widget.dart b/mobile/lib/ui/viewer/gallery/gallery_app_bar_widget.dart index d2b7a6ec3..c62d1f738 100644 --- a/mobile/lib/ui/viewer/gallery/gallery_app_bar_widget.dart +++ b/mobile/lib/ui/viewer/gallery/gallery_app_bar_widget.dart @@ -3,6 +3,7 @@ import 'dart:io'; import 'dart:math' as math; import "package:flutter/cupertino.dart"; +import "package:flutter/foundation.dart"; import 'package:flutter/material.dart'; import 'package:logging/logging.dart'; import 'package:photos/core/configuration.dart'; @@ -736,7 +737,7 @@ class _GalleryAppBarWidgetState extends State { // stop any existing cast session gw.revokeAllTokens().ignore(); - if (!Platform.isAndroid) { + if (!Platform.isAndroid && !kDebugMode) { await _pairWithPin(gw, ''); } else { final result = await showDialog( diff --git a/mobile/lib/ui/viewer/people/add_person_action_sheet.dart b/mobile/lib/ui/viewer/people/add_person_action_sheet.dart new file mode 100644 index 000000000..7a0c3a471 --- /dev/null +++ b/mobile/lib/ui/viewer/people/add_person_action_sheet.dart @@ -0,0 +1,324 @@ +import "dart:async"; +import "dart:developer"; +import "dart:math" as math; + +import 'package:flutter/material.dart'; +import "package:logging/logging.dart"; +import 'package:modal_bottom_sheet/modal_bottom_sheet.dart'; +import "package:photos/core/event_bus.dart"; +import "package:photos/events/people_changed_event.dart"; +import "package:photos/face/db.dart"; +import "package:photos/face/model/person.dart"; +import "package:photos/generated/l10n.dart"; +import "package:photos/models/file/file.dart"; +import 'package:photos/services/machine_learning/face_ml/feedback/cluster_feedback.dart'; +import "package:photos/services/machine_learning/face_ml/person/person_service.dart"; +import "package:photos/services/search_service.dart"; +import 'package:photos/theme/colors.dart'; +import 'package:photos/theme/ente_theme.dart'; +import 'package:photos/ui/common/loading_widget.dart'; +import 'package:photos/ui/components/bottom_of_title_bar_widget.dart'; +import 'package:photos/ui/components/buttons/button_widget.dart'; +import 'package:photos/ui/components/models/button_type.dart'; +import "package:photos/ui/components/text_input_widget.dart"; +import 'package:photos/ui/components/title_bar_title_widget.dart'; +import 
"package:photos/ui/viewer/people/new_person_item_widget.dart"; +import "package:photos/ui/viewer/people/person_row_item.dart"; +import "package:photos/utils/dialog_util.dart"; +import "package:photos/utils/toast_util.dart"; + +enum PersonActionType { + assignPerson, +} + +String _actionName( + BuildContext context, + PersonActionType type, +) { + String text = ""; + switch (type) { + case PersonActionType.assignPerson: + text = "Add name"; + break; + } + return text; +} + +Future showAssignPersonAction( + BuildContext context, { + required int clusterID, + PersonActionType actionType = PersonActionType.assignPerson, + bool showOptionToCreateNewAlbum = true, +}) { + return showBarModalBottomSheet( + context: context, + builder: (context) { + return PersonActionSheet( + actionType: actionType, + showOptionToCreateNewPerson: showOptionToCreateNewAlbum, + cluserID: clusterID, + ); + }, + shape: const RoundedRectangleBorder( + side: BorderSide(width: 0), + borderRadius: BorderRadius.vertical( + top: Radius.circular(5), + ), + ), + topControl: const SizedBox.shrink(), + backgroundColor: getEnteColorScheme(context).backgroundElevated, + barrierColor: backdropFaintDark, + enableDrag: false, + ); +} + +class PersonActionSheet extends StatefulWidget { + final PersonActionType actionType; + final int cluserID; + final bool showOptionToCreateNewPerson; + const PersonActionSheet({ + required this.actionType, + required this.cluserID, + required this.showOptionToCreateNewPerson, + super.key, + }); + + @override + State createState() => _PersonActionSheetState(); +} + +class _PersonActionSheetState extends State { + static const int cancelButtonSize = 80; + String _searchQuery = ""; + bool userAlreadyAssigned = false; + + @override + void initState() { + super.initState(); + } + + @override + Widget build(BuildContext context) { + final bottomInset = MediaQuery.of(context).viewInsets.bottom; + final isKeyboardUp = bottomInset > 100; + return Padding( + padding: EdgeInsets.only( + bottom: isKeyboardUp ? 
bottomInset - cancelButtonSize : 0, + ), + child: Row( + mainAxisAlignment: MainAxisAlignment.center, + children: [ + ConstrainedBox( + constraints: BoxConstraints( + maxWidth: math.min(428, MediaQuery.of(context).size.width), + ), + child: Padding( + padding: const EdgeInsets.fromLTRB(0, 32, 0, 8), + child: Column( + mainAxisSize: MainAxisSize.max, + children: [ + Expanded( + child: Column( + children: [ + BottomOfTitleBarWidget( + title: TitleBarTitleWidget( + title: _actionName(context, widget.actionType), + ), + // caption: 'Select or create a ', + ), + Padding( + padding: const EdgeInsets.only( + top: 16, + left: 16, + right: 16, + ), + child: TextInputWidget( + hintText: 'Person name', + prefixIcon: Icons.search_rounded, + onChange: (value) { + setState(() { + _searchQuery = value; + }); + }, + isClearable: true, + shouldUnfocusOnClearOrSubmit: true, + borderRadius: 2, + ), + ), + _getPersonItems(), + ], + ), + ), + SafeArea( + child: Container( + //inner stroke of 1pt + 15 pts of top padding = 16 pts + padding: const EdgeInsets.fromLTRB(16, 15, 16, 8), + decoration: BoxDecoration( + border: Border( + top: BorderSide( + color: getEnteColorScheme(context).strokeFaint, + ), + ), + ), + child: ButtonWidget( + buttonType: ButtonType.secondary, + buttonAction: ButtonAction.cancel, + isInAlert: true, + labelText: S.of(context).cancel, + ), + ), + ), + ], + ), + ), + ), + ], + ), + ); + } + + Flexible _getPersonItems() { + return Flexible( + child: Padding( + padding: const EdgeInsets.fromLTRB(16, 24, 4, 0), + child: FutureBuilder<List<(PersonEntity, EnteFile)>>( + future: _getPersons(), + builder: (context, snapshot) { + if (snapshot.hasError) { + log("Error: ${snapshot.error} ${snapshot.stackTrace}"); + //Need to show an error on the UI here + return const SizedBox.shrink(); + } else if (snapshot.hasData) { + final persons = snapshot.data!; + final searchResults = _searchQuery.isNotEmpty + ? persons + .where( + (element) => element.$1.data.name + .toLowerCase() + .contains(_searchQuery.toLowerCase()), + ) + .toList() + : persons; + final shouldShowAddPerson = widget.showOptionToCreateNewPerson && + (_searchQuery.isEmpty || searchResults.isEmpty); + + return Scrollbar( + thumbVisibility: true, + radius: const Radius.circular(2), + child: Padding( + padding: const EdgeInsets.only(right: 12), + child: ListView.separated( + itemCount: + searchResults.length + (shouldShowAddPerson ? 1 : 0), + itemBuilder: (context, index) { + if (index == 0 && shouldShowAddPerson) { + return GestureDetector( + behavior: HitTestBehavior.opaque, + child: const NewPersonItemWidget(), + onTap: () async => addNewPerson( + context, + initValue: _searchQuery.trim(), + clusterID: widget.clusterID, + ), + ); + } + final person = + searchResults[index - (shouldShowAddPerson ? 
1 : 0)]; + return PersonRowItem( + person: person.$1, + personFile: person.$2, + onTap: () async { + if (userAlreadyAssigned) { + return; + } + userAlreadyAssigned = true; + await FaceMLDataDB.instance.assignClusterToPerson( + personID: person.$1.remoteID, + clusterID: widget.clusterID, + ); + Bus.instance.fire(PeopleChangedEvent()); + + Navigator.pop(context, person); + }, + ); + }, + separatorBuilder: (context, index) { + return const SizedBox(height: 6); + }, + ), + ), + ); + } else { + return const EnteLoadingWidget(); + } + }, + ), + ), + ); + } + + Future<void> addNewPerson( + BuildContext context, { + String initValue = '', + required int clusterID, + }) async { + final result = await showTextInputDialog( + context, + title: "New person", + submitButtonLabel: 'Add', + hintText: 'Add name', + alwaysShowSuccessState: false, + initialValue: initValue, + textCapitalization: TextCapitalization.words, + onSubmit: (String text) async { + if (userAlreadyAssigned) { + return; + } + // indicates user cancelled the rename request + if (text.trim() == "") { + return; + } + try { + userAlreadyAssigned = true; + final PersonEntity p = + await PersonService.instance.addPerson(text, clusterID); + final bool extraPhotosFound = await ClusterFeedbackService.instance + .checkAndDoAutomaticMerges(p, personClusterID: clusterID); + if (extraPhotosFound) { + showShortToast(context, "Extra photos found for $text"); + } + Bus.instance.fire(PeopleChangedEvent()); + Navigator.pop(context, p); + } catch (e, s) { + Logger("_PersonActionSheetState") + .severe("Failed to add person", e, s); + rethrow; + } + }, + ); + if (result is Exception) { + await showGenericErrorDialog(context: context, error: result); + } + } + + Future<List<(PersonEntity, EnteFile)>> _getPersons({ + bool excludeHidden = true, + }) async { + final persons = await PersonService.instance.getPersons(); + if (excludeHidden) { + persons.removeWhere((person) => person.data.isIgnored); + } + final List<(PersonEntity, EnteFile)> personAndFileID = []; + for (final person in persons) { + final clustersToFiles = + await SearchService.instance.getClusterFilesForPersonID( + person.remoteID, + ); + final files = clustersToFiles.values.expand((e) => e).toList(); + personAndFileID.add((person, files.first)); + } + return personAndFileID; + } +} diff --git a/mobile/lib/ui/viewer/people/cluster_app_bar.dart b/mobile/lib/ui/viewer/people/cluster_app_bar.dart new file mode 100644 index 000000000..0896d0689 --- /dev/null +++ b/mobile/lib/ui/viewer/people/cluster_app_bar.dart @@ -0,0 +1,341 @@ +import 'dart:async'; + +import "package:flutter/foundation.dart"; +import 'package:flutter/material.dart'; +import 'package:logging/logging.dart'; +import "package:ml_linalg/linalg.dart"; +import 'package:photos/core/configuration.dart'; +import 'package:photos/core/event_bus.dart'; +import "package:photos/db/files_db.dart"; +import "package:photos/events/people_changed_event.dart"; +import 'package:photos/events/subscription_purchased_event.dart'; +import "package:photos/face/db.dart"; +import "package:photos/face/model/person.dart"; +import "package:photos/generated/protos/ente/common/vector.pb.dart"; +import "package:photos/models/file/file.dart"; +import 'package:photos/models/gallery_type.dart'; +import 'package:photos/models/selected_files.dart'; +import 'package:photos/services/collections_service.dart'; +import "package:photos/services/machine_learning/face_ml/face_clustering/cosine_distance.dart"; +import "package:photos/services/machine_learning/face_ml/face_ml_result.dart"; +import 
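// ---------------------------------------------------------------------------
// Editor's sketch (not part of the diff): both assignment paths above — the
// tap on an existing person and the "new person" dialog — consult a single
// `userAlreadyAssigned` flag so a second tap cannot assign the cluster
// twice. The guard in isolation; unlike the sheet above, this variant
// re-arms on failure so the user can retry. `action` is a hypothetical
// callback.
class SingleSubmitGuard {
  bool _submitted = false;

  Future<void> run(Future<void> Function() action) async {
    if (_submitted) return; // ignore repeat taps
    _submitted = true;
    try {
      await action();
    } catch (_) {
      _submitted = false; // allow a retry after an error
      rethrow;
    }
  }
}
// ---------------------------------------------------------------------------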
"package:photos/services/machine_learning/face_ml/feedback/cluster_feedback.dart"; +import 'package:photos/ui/actions/collection/collection_sharing_actions.dart'; +import "package:photos/ui/common/popup_item.dart"; +import "package:photos/ui/viewer/people/cluster_breakup_page.dart"; +import "package:photos/ui/viewer/people/cluster_page.dart"; +import "package:photos/utils/dialog_util.dart"; + +class ClusterAppBar extends StatefulWidget { + final GalleryType type; + final String? title; + final SelectedFiles selectedFiles; + final int clusterID; + final PersonEntity? person; + + const ClusterAppBar( + this.type, + this.title, + this.selectedFiles, + this.clusterID, { + this.person, + Key? key, + }) : super(key: key); + + @override + State createState() => _AppBarWidgetState(); +} + +enum ClusterPopupAction { + setCover, + breakupCluster, + breakupClusterDebug, + ignore, +} + +class _AppBarWidgetState extends State { + final _logger = Logger("_AppBarWidgetState"); + late StreamSubscription _userAuthEventSubscription; + late Function() _selectedFilesListener; + String? _appBarTitle; + late CollectionActions collectionActions; + final GlobalKey shareButtonKey = GlobalKey(); + bool isQuickLink = false; + late GalleryType galleryType; + + @override + void initState() { + super.initState(); + _selectedFilesListener = () { + setState(() {}); + }; + collectionActions = CollectionActions(CollectionsService.instance); + widget.selectedFiles.addListener(_selectedFilesListener); + _userAuthEventSubscription = + Bus.instance.on().listen((event) { + setState(() {}); + }); + _appBarTitle = widget.title; + galleryType = widget.type; + } + + @override + void dispose() { + _userAuthEventSubscription.cancel(); + widget.selectedFiles.removeListener(_selectedFilesListener); + super.dispose(); + } + + @override + Widget build(BuildContext context) { + return AppBar( + elevation: 0, + centerTitle: false, + title: Text( + _appBarTitle!, + style: + Theme.of(context).textTheme.headlineSmall!.copyWith(fontSize: 16), + maxLines: 2, + overflow: TextOverflow.ellipsis, + ), + actions: kDebugMode ? 
_getDefaultActions(context) : null, + ); + } + + List<Widget> _getDefaultActions(BuildContext context) { + final List<Widget> actions = []; + // If the user has selected files, don't show any actions + if (widget.selectedFiles.files.isNotEmpty || + !Configuration.instance.hasConfiguredAccount()) { + return actions; + } + + final List<EntePopupMenuItem<ClusterPopupAction>> items = []; + + items.addAll( + [ + EntePopupMenuItem( + "Ignore person", + value: ClusterPopupAction.ignore, + icon: Icons.hide_image_outlined, + ), + EntePopupMenuItem( + "Mixed grouping?", + value: ClusterPopupAction.breakupCluster, + icon: Icons.analytics_outlined, + ), + ], + ); + if (kDebugMode) { + items.add( + EntePopupMenuItem( + "Debug mixed grouping", + value: ClusterPopupAction.breakupClusterDebug, + icon: Icons.analytics_outlined, + ), + ); + } + + if (items.isNotEmpty) { + actions.add( + PopupMenuButton<ClusterPopupAction>( + itemBuilder: (context) { + return items; + }, + onSelected: (ClusterPopupAction value) async { + if (value == ClusterPopupAction.breakupCluster) { + // ignore: unawaited_futures + await _breakUpCluster(context); + } else if (value == ClusterPopupAction.ignore) { + await _onIgnoredClusterClicked(context); + } else if (value == ClusterPopupAction.breakupClusterDebug) { + await _breakUpClusterDebug(context); + } + // else if (value == ClusterPopupAction.setCover) { + // await setCoverPhoto(context); + }, + ), + ); + } + + return actions; + } + + @Deprecated( + 'Used for debugging an issue with conflicts on cluster IDs, resolved now', + ) + Future<void> _validateCluster(BuildContext context) async { + _logger.info('_validateCluster called'); + final faceMlDb = FaceMLDataDB.instance; + + final faceIDs = await faceMlDb.getFaceIDsForCluster(widget.clusterID); + final fileIDs = faceIDs.map((e) => getFileIdFromFaceId(e)).toList(); + + final embeddingsBlobs = await faceMlDb.getFaceEmbeddingMapForFile(fileIDs); + embeddingsBlobs.removeWhere((key, value) => !faceIDs.contains(key)); + final embeddings = embeddingsBlobs + .map((key, value) => MapEntry(key, EVector.fromBuffer(value).values)); + + for (final MapEntry<String, List<double>> embedding in embeddings.entries) { + double closestDistance = double.infinity; + double closestDistance32 = double.infinity; + double closestDistance64 = double.infinity; + String? closestFaceID; + for (final MapEntry<String, List<double>> otherEmbedding + in embeddings.entries) { + if (embedding.key == otherEmbedding.key) { + continue; + } + final distance64 = cosineDistanceSIMD( + Vector.fromList(embedding.value, dtype: DType.float64), + Vector.fromList(otherEmbedding.value, dtype: DType.float64), + ); + final distance32 = cosineDistanceSIMD( + Vector.fromList(embedding.value, dtype: DType.float32), + Vector.fromList(otherEmbedding.value, dtype: DType.float32), + ); + final distance = cosineDistForNormVectors( + embedding.value, + otherEmbedding.value, + ); + if (distance < closestDistance) { + closestDistance = distance; + closestDistance32 = distance32; + closestDistance64 = distance64; + closestFaceID = otherEmbedding.key; + } + } + if (closestDistance > 0.3) { + _logger.severe( + "Face ${embedding.key} is similar to $closestFaceID with distance $closestDistance, and float32 distance $closestDistance32, and float64 distance $closestDistance64", + ); + } + } + } + + Future<void> _onIgnoredClusterClicked(BuildContext context) async { + await showChoiceDialog( + context, + title: "Are you sure you want to ignore this person?", + body: + "The person grouping will not be displayed in the discovery tab anymore. 
Photos will remain untouched.", + firstButtonLabel: "Yes, confirm", + firstButtonOnTap: () async { + try { + await ClusterFeedbackService.instance.ignoreCluster(widget.clusterID); + Navigator.of(context).pop(); // Close the cluster page + } catch (e, s) { + _logger.severe('Ignoring a cluster failed', e, s); + // await showGenericErrorDialog(context: context, error: e); + } + }, + ); + } + + Future _breakUpCluster(BuildContext context) async { + bool userConfirmed = false; + List biggestClusterFiles = []; + int biggestClusterID = -1; + await showChoiceDialog( + context, + title: "Does this grouping contain multiple people?", + body: + "We will automatically analyze the grouping to determine if there are multiple people present, and separate them out again. This may take a few seconds.", + firstButtonLabel: "Yes, confirm", + firstButtonOnTap: () async { + try { + final breakupResult = await ClusterFeedbackService.instance + .breakUpCluster(widget.clusterID); + final Map> newClusterIDToFaceIDs = + breakupResult.newClusterIdToFaceIds!; + final Map newFaceIdToClusterID = + breakupResult.newFaceIdToCluster; + + // Update to delete the old clusters and save the new clusters + await FaceMLDataDB.instance.deleteClusterSummary(widget.clusterID); + await FaceMLDataDB.instance + .clusterSummaryUpdate(breakupResult.newClusterSummaries!); + await FaceMLDataDB.instance + .updateFaceIdToClusterId(newFaceIdToClusterID); + + // Find the biggest cluster + biggestClusterID = -1; + int biggestClusterSize = 0; + for (final MapEntry> clusterToFaces + in newClusterIDToFaceIDs.entries) { + if (clusterToFaces.value.length > biggestClusterSize) { + biggestClusterSize = clusterToFaces.value.length; + biggestClusterID = clusterToFaces.key; + } + } + // Get the files for the biggest new cluster + final biggestClusterFileIDs = newClusterIDToFaceIDs[biggestClusterID]! + .map((e) => getFileIdFromFaceId(e)) + .toList(); + biggestClusterFiles = await FilesDB.instance + .getFilesFromIDs( + biggestClusterFileIDs, + ) + .then((mapping) => mapping.values.toList()); + // Sort the files to prevent issues with the order of the files in gallery + biggestClusterFiles + .sort((a, b) => b.creationTime!.compareTo(a.creationTime!)); + + userConfirmed = true; + } catch (e, s) { + _logger.severe('Breakup cluster failed', e, s); + // await showGenericErrorDialog(context: context, error: e); + } + }, + ); + if (userConfirmed) { + // Close the old cluster page + Navigator.of(context).pop(); + + // Push the new cluster page + await Navigator.of(context).push( + MaterialPageRoute( + builder: (context) => ClusterPage( + biggestClusterFiles, + clusterID: biggestClusterID, + ), + ), + ); + Bus.instance.fire(PeopleChangedEvent()); + } + } + + Future _breakUpClusterDebug(BuildContext context) async { + final breakupResult = + await ClusterFeedbackService.instance.breakUpCluster(widget.clusterID); + + final Map> newClusterIDToFaceIDs = + breakupResult.newClusterIdToFaceIds!; + + final allFileIDs = newClusterIDToFaceIDs.values + .expand((e) => e) + .map((e) => getFileIdFromFaceId(e)) + .toList(); + + final fileIDtoFile = await FilesDB.instance.getFilesFromIDs( + allFileIDs, + ); + + final newClusterIDToFiles = newClusterIDToFaceIDs.map( + (key, value) => MapEntry( + key, + value + .map((faceId) => fileIDtoFile[getFileIdFromFaceId(faceId)]!) 
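// ---------------------------------------------------------------------------
// Editor's note (not part of the diff): _validateCluster above computes the
// same cosine distance three ways (SIMD float32, SIMD float64, and a plain
// loop) to spot numeric drift, and flags any face whose nearest neighbour in
// the cluster is further than 0.3. For embeddings that are already
// L2-normalised, cosine distance reduces to 1 minus the dot product:
double cosineDistanceNormalized(List<double> a, List<double> b) {
  assert(a.length == b.length, 'embeddings must have the same dimension');
  var dot = 0.0;
  for (var i = 0; i < a.length; i++) {
    dot += a[i] * b[i];
  }
  return 1.0 - dot; // 0 = same direction, 1 = orthogonal, 2 = opposite
}
// ---------------------------------------------------------------------------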
+ .toList(), + ), + ); + + await Navigator.of(context).push( + MaterialPageRoute( + builder: (context) => ClusterBreakupPage( + newClusterIDToFiles, + "(Analysis)", + ), + ), + ); + } +} diff --git a/mobile/lib/ui/viewer/people/cluster_breakup_page.dart b/mobile/lib/ui/viewer/people/cluster_breakup_page.dart new file mode 100644 index 000000000..e91909f47 --- /dev/null +++ b/mobile/lib/ui/viewer/people/cluster_breakup_page.dart @@ -0,0 +1,124 @@ +import "package:flutter/material.dart"; +import "package:photos/models/file/file.dart"; +import "package:photos/theme/ente_theme.dart"; +import "package:photos/ui/viewer/file/no_thumbnail_widget.dart"; +import "package:photos/ui/viewer/people/cluster_page.dart"; +import "package:photos/ui/viewer/search/result/person_face_widget.dart"; + +class ClusterBreakupPage extends StatefulWidget { + final Map> newClusterIDsToFiles; + final String title; + + const ClusterBreakupPage( + this.newClusterIDsToFiles, + this.title, { + super.key, + }); + + @override + State createState() => _ClusterBreakupPageState(); +} + +class _ClusterBreakupPageState extends State { + @override + Widget build(BuildContext context) { + final keys = widget.newClusterIDsToFiles.keys.toList(); + final clusterIDsToFiles = widget.newClusterIDsToFiles; + + return Scaffold( + appBar: AppBar( + title: Text(widget.title), + ), + body: ListView.builder( + itemCount: widget.newClusterIDsToFiles.keys.length, + itemBuilder: (context, index) { + final int clusterID = keys[index]; + final List files = clusterIDsToFiles[keys[index]]!; + return InkWell( + onTap: () { + Navigator.of(context).push( + MaterialPageRoute( + builder: (context) => ClusterPage( + files, + clusterID: index, + appendTitle: "(Analysis)", + ), + ), + ); + }, + child: Container( + padding: const EdgeInsets.all(8.0), + child: Row( + children: [ + SizedBox( + width: 64, + height: 64, + child: files.isNotEmpty + ? 
ClipRRect( + borderRadius: const BorderRadius.all( + Radius.elliptical(16, 12),), + child: PersonFaceWidget( + files.first, + clusterID: clusterID, + ), + ) + : const ClipRRect( + borderRadius: + BorderRadius.all(Radius.elliptical(16, 12)), + child: NoThumbnailWidget( + addBorder: false, + ), + ), + ), + const SizedBox( + width: 8.0, + ), // Add some spacing between the thumbnail and the text + Expanded( + child: Padding( + padding: const EdgeInsets.symmetric(horizontal: 8.0), + child: Row( + mainAxisAlignment: MainAxisAlignment.spaceBetween, + children: [ + Text( + "${clusterIDsToFiles[keys[index]]!.length} photos", + style: getEnteTextTheme(context).body, + ), + // GestureDetector( + // onTap: () async { + // try { + // final int result = await FaceMLDataDB + // .instance + // .removeClusterToPerson( + // personID: widget.person.remoteID, + // clusterID: clusterID, + // ); + // _logger.info( + // "Removed cluster $clusterID from person ${widget.person.remoteID}, result: $result", + // ); + // Bus.instance.fire(PeopleChangedEvent()); + // setState(() {}); + // } catch (e) { + // _logger.severe( + // "removing cluster from person,", + // e, + // ); + // } + // }, + // child: const Icon( + // CupertinoIcons.minus_circled, + // color: Colors.red, + // ), + // ), + ], + ), + ), + ), + ], + ), + ), + ); + }, + ), + ); + } +} diff --git a/mobile/lib/ui/viewer/people/cluster_page.dart b/mobile/lib/ui/viewer/people/cluster_page.dart new file mode 100644 index 000000000..f6b720f02 --- /dev/null +++ b/mobile/lib/ui/viewer/people/cluster_page.dart @@ -0,0 +1,208 @@ +import "dart:async"; + +import "package:flutter/foundation.dart"; +import 'package:flutter/material.dart'; +import 'package:photos/core/event_bus.dart'; +import 'package:photos/events/files_updated_event.dart'; +import 'package:photos/events/local_photos_updated_event.dart'; +import "package:photos/events/people_changed_event.dart"; +import "package:photos/face/model/person.dart"; +import "package:photos/generated/l10n.dart"; +import 'package:photos/models/file/file.dart'; +import 'package:photos/models/file_load_result.dart'; +import 'package:photos/models/gallery_type.dart'; +import 'package:photos/models/selected_files.dart'; +import "package:photos/services/machine_learning/face_ml/feedback/cluster_feedback.dart"; +import 'package:photos/ui/viewer/actions/file_selection_overlay_bar.dart'; +import 'package:photos/ui/viewer/gallery/gallery.dart'; +import "package:photos/ui/viewer/people/add_person_action_sheet.dart"; +import "package:photos/ui/viewer/people/cluster_app_bar.dart"; +import "package:photos/ui/viewer/people/people_banner.dart"; +import "package:photos/ui/viewer/people/people_page.dart"; +import "package:photos/ui/viewer/search/result/person_face_widget.dart"; +import "package:photos/ui/viewer/search/result/search_result_page.dart"; +import "package:photos/utils/navigation_util.dart"; +import "package:photos/utils/toast_util.dart"; + +class ClusterPage extends StatefulWidget { + final List searchResult; + final bool enableGrouping; + final String tagPrefix; + final int clusterID; + final PersonEntity? personID; + final String appendTitle; + final bool showNamingBanner; + + static const GalleryType appBarType = GalleryType.cluster; + static const GalleryType overlayType = GalleryType.cluster; + + const ClusterPage( + this.searchResult, { + this.enableGrouping = true, + this.tagPrefix = "", + required this.clusterID, + this.personID, + this.appendTitle = "", + this.showNamingBanner = true, + Key? 
key, + }) : super(key: key); + + @override + State<ClusterPage> createState() => _ClusterPageState(); +} + +class _ClusterPageState extends State<ClusterPage> { + final _selectedFiles = SelectedFiles(); + late final List<EnteFile> files; + late final StreamSubscription<FilesUpdatedEvent> _filesUpdatedEvent; + late final StreamSubscription<PeopleChangedEvent> _peopleChangedEvent; + + bool get showNamingBanner => + (!userDismissedNamingBanner && widget.showNamingBanner); + + bool userDismissedNamingBanner = false; + + @override + void initState() { + super.initState(); + ClusterFeedbackService.setLastViewedClusterID(widget.clusterID); + files = widget.searchResult; + _filesUpdatedEvent = + Bus.instance.on<FilesUpdatedEvent>().listen((event) { + if (event.type == EventType.deletedFromDevice || + event.type == EventType.deletedFromEverywhere || + event.type == EventType.deletedFromRemote || + event.type == EventType.hide) { + for (var updatedFile in event.updatedFiles) { + files.remove(updatedFile); + } + setState(() {}); + } + }); + _peopleChangedEvent = Bus.instance.on<PeopleChangedEvent>().listen((event) { + if (event.type == PeopleEventType.removedFilesFromCluster && + (event.source == widget.clusterID.toString())) { + for (var updatedFile in event.relevantFiles!) { + files.remove(updatedFile); + } + setState(() {}); + } + }); + if (kDebugMode) { + ClusterFeedbackService.instance.debugLogClusterBlurValues( + widget.clusterID, + clusterSize: files.length, + ); + } + } + + @override + void dispose() { + _filesUpdatedEvent.cancel(); + _peopleChangedEvent.cancel(); + super.dispose(); + } + + @override + Widget build(BuildContext context) { + final gallery = Gallery( + asyncLoader: (creationStartTime, creationEndTime, {limit, asc}) { + final result = files + .where( + (file) => + file.creationTime! >= creationStartTime && + file.creationTime! <= creationEndTime, + ) + .toList(); + return Future.value( + FileLoadResult( + result, + result.length < files.length, + ), + ); + }, + reloadEvent: Bus.instance.on<LocalPhotosUpdatedEvent>(), + forceReloadEvents: [Bus.instance.on<PeopleChangedEvent>()], + removalEventTypes: const { + EventType.deletedFromRemote, + EventType.deletedFromEverywhere, + EventType.hide, + EventType.peopleClusterChanged, + }, + tagPrefix: widget.tagPrefix, + selectedFiles: _selectedFiles, + enableFileGrouping: widget.enableGrouping, + initialFiles: [widget.searchResult.first], + ); + return Scaffold( + appBar: PreferredSize( + preferredSize: const Size.fromHeight(50.0), + child: ClusterAppBar( + SearchResultPage.appBarType, + "${files.length} memories${widget.appendTitle}", + _selectedFiles, + widget.clusterID, + key: ValueKey(files.length), + ), + ), + body: Column( + children: [ + Expanded( + child: Stack( + alignment: Alignment.bottomCenter, + children: [ + gallery, + FileSelectionOverlayBar( + ClusterPage.overlayType, + _selectedFiles, + clusterID: widget.clusterID, + ), + ], + ), + ), + showNamingBanner + ? 
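// ---------------------------------------------------------------------------
// Editor's sketch (not part of the diff): unlike DB-backed galleries,
// ClusterPage above already holds its files in memory, so its asyncLoader
// just filters by the requested creation-time window and reports whether
// anything fell outside it (FileLoadResult's second argument). The same
// logic over plain timestamps:
(List<int> window, bool hasMore) loadCreationTimeWindow(
  List<int> creationTimes,
  int startTime,
  int endTime,
) {
  final window = creationTimes
      .where((t) => t >= startTime && t <= endTime)
      .toList();
  return (window, window.length < creationTimes.length);
}
// ---------------------------------------------------------------------------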
Dismissible( + key: const Key("namingBanner"), + direction: DismissDirection.horizontal, + onDismissed: (direction) { + setState(() { + userDismissedNamingBanner = true; + }); + }, + child: PeopleBanner( + type: PeopleBannerType.addName, + faceWidget: PersonFaceWidget( + files.first, + clusterID: widget.clusterID, + ), + actionIcon: Icons.add_outlined, + text: S.of(context).addAName, + subText: S.of(context).findPeopleByName, + onTap: () async { + if (widget.personID == null) { + final result = await showAssignPersonAction( + context, + clusterID: widget.clusterID, + ); + if (result != null && + result is (PersonEntity, EnteFile)) { + Navigator.pop(context); + // ignore: unawaited_futures + routeToPage(context, PeoplePage(person: result.$1)); + } else if (result != null && result is PersonEntity) { + Navigator.pop(context); + // ignore: unawaited_futures + routeToPage(context, PeoplePage(person: result)); + } + } else { + showShortToast(context, "No personID or clusterID"); + } + }, + ), + ) + : const SizedBox.shrink(), + ], + ), + ); + } +} diff --git a/mobile/lib/ui/viewer/people/cropped_face_image_view.dart b/mobile/lib/ui/viewer/people/cropped_face_image_view.dart new file mode 100644 index 000000000..823a979fc --- /dev/null +++ b/mobile/lib/ui/viewer/people/cropped_face_image_view.dart @@ -0,0 +1,123 @@ +import 'dart:developer' show log; +import "dart:io" show File; + +import 'package:flutter/material.dart'; +import "package:photos/face/model/face.dart"; +import "package:photos/models/file/file.dart"; +import "package:photos/models/file/file_type.dart"; +import "package:photos/ui/viewer/file/thumbnail_widget.dart"; +import "package:photos/utils/file_util.dart"; +import "package:photos/utils/thumbnail_util.dart"; + +class CroppedFaceInfo { + final Image image; + final double scale; + final double offsetX; + final double offsetY; + + const CroppedFaceInfo({ + required this.image, + required this.scale, + required this.offsetX, + required this.offsetY, + }); +} + +class CroppedFaceImageView extends StatelessWidget { + final EnteFile enteFile; + final Face face; + + const CroppedFaceImageView({ + Key? 
key, + required this.enteFile, + required this.face, + }) : super(key: key); + + @override + Widget build(BuildContext context) { + return FutureBuilder( + future: getImage(), + builder: (context, snapshot) { + if (snapshot.hasData) { + return LayoutBuilder( + builder: ((context, constraints) { + final double imageAspectRatio = enteFile.width / enteFile.height; + final Image image = snapshot.data!; + + final double viewWidth = constraints.maxWidth; + final double viewHeight = constraints.maxHeight; + + final faceBox = face.detection.box; + + final double relativeFaceCenterX = + faceBox.xMin + faceBox.width / 2; + final double relativeFaceCenterY = + faceBox.yMin + faceBox.height / 2; + + const double desiredFaceHeightRelativeToWidget = 8 / 10; + final double scale = + (1 / faceBox.height) * desiredFaceHeightRelativeToWidget; + + final double widgetCenterX = viewWidth / 2; + final double widgetCenterY = viewHeight / 2; + + final double widgetAspectRatio = viewWidth / viewHeight; + final double imageToWidgetRatio = + imageAspectRatio / widgetAspectRatio; + + double offsetX = + (widgetCenterX - relativeFaceCenterX * viewWidth) * scale; + double offsetY = + (widgetCenterY - relativeFaceCenterY * viewHeight) * scale; + + if (imageAspectRatio < widgetAspectRatio) { + // Landscape Image: Adjust offsetX more conservatively + offsetX = offsetX * imageToWidgetRatio; + } else { + // Portrait Image: Adjust offsetY more conservatively + offsetY = offsetY / imageToWidgetRatio; + } + return ClipRRect( + borderRadius: const BorderRadius.all(Radius.elliptical(16, 12)), + child: Transform.translate( + offset: Offset( + offsetX, + offsetY, + ), + child: Transform.scale( + scale: scale, + child: image, + ), + ), + ); + }), + ); + } else { + if (snapshot.hasError) { + log('Error getting cover face for person: ${snapshot.error}'); + } + return ThumbnailWidget( + enteFile, + ); + } + }, + ); + } + + Future getImage() async { + final File? 
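+    // Videos fall back to their cached thumbnail; images are decoded from
+    // the full file.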
ioFile;
+    if (enteFile.fileType == FileType.video) {
+      ioFile = await getThumbnailForUploadedFile(enteFile);
+    } else {
+      ioFile = await getFile(enteFile);
+    }
+    if (ioFile == null) {
+      return null;
+    }
+
+    final imageData = await ioFile.readAsBytes();
+    final image = Image.memory(imageData, fit: BoxFit.contain);
+
+    return image;
+  }
+}
diff --git a/mobile/lib/ui/viewer/people/new_person_item_widget.dart b/mobile/lib/ui/viewer/people/new_person_item_widget.dart
new file mode 100644
index 000000000..c60f89259
--- /dev/null
+++ b/mobile/lib/ui/viewer/people/new_person_item_widget.dart
@@ -0,0 +1,73 @@
+import 'package:dotted_border/dotted_border.dart';
+import 'package:flutter/material.dart';
+import 'package:photos/theme/ente_theme.dart';
+
+///https://www.figma.com/file/SYtMyLBs5SAOkTbfMMzhqt/ente-Visual-Design?node-id=10854%3A57947&t=H5AvR79OYDnB9ekw-4
+class NewPersonItemWidget extends StatelessWidget {
+  const NewPersonItemWidget({
+    super.key,
+  });
+
+  @override
+  Widget build(BuildContext context) {
+    final textTheme = getEnteTextTheme(context);
+    final colorScheme = getEnteColorScheme(context);
+    const sideOfThumbnail = 60.0;
+    return LayoutBuilder(
+      builder: (context, constraints) {
+        return Stack(
+          alignment: Alignment.center,
+          children: [
+            Row(
+              children: [
+                ClipRRect(
+                  borderRadius: const BorderRadius.horizontal(
+                    left: Radius.circular(4),
+                  ),
+                  child: SizedBox(
+                    height: sideOfThumbnail,
+                    width: sideOfThumbnail,
+                    child: Icon(
+                      Icons.add_outlined,
+                      color: colorScheme.strokeMuted,
+                    ),
+                  ),
+                ),
+                Padding(
+                  padding: const EdgeInsets.only(left: 12),
+                  child: Text(
+                    'Add new person',
+                    style:
+                        textTheme.body.copyWith(color: colorScheme.textMuted),
+                  ),
+                ),
+              ],
+            ),
+            IgnorePointer(
+              child: DottedBorder(
+                dashPattern: const [4],
+                color: colorScheme.strokeFainter,
+                strokeWidth: 1,
+                padding: const EdgeInsets.all(0),
+                borderType: BorderType.RRect,
+                radius: const Radius.circular(4),
+                child: SizedBox(
+                  // Shrink the box by 1 pt: DottedBorder draws its stroke with
+                  // strokeAlign.center (0.5 inside, 0.5 outside), but for this
+                  // row the stroke should sit fully inside. Shrinking by 0.5
+                  // on every side makes the centered stroke look like
+                  // strokeAlign.inside relative to the row.
+                  height: sideOfThumbnail - 1,
+                  // This width only works if the row takes up the full size of
+                  // its parent (the stack).
+ width: constraints.maxWidth - 1, + ), + ), + ), + ], + ); + }, + ); + } +} diff --git a/mobile/lib/ui/viewer/people/people_app_bar.dart b/mobile/lib/ui/viewer/people/people_app_bar.dart new file mode 100644 index 000000000..d53059327 --- /dev/null +++ b/mobile/lib/ui/viewer/people/people_app_bar.dart @@ -0,0 +1,337 @@ +import 'dart:async'; + +import "package:flutter/cupertino.dart"; +import 'package:flutter/material.dart'; +import 'package:logging/logging.dart'; +import 'package:photos/core/configuration.dart'; +import 'package:photos/core/event_bus.dart'; +import "package:photos/events/people_changed_event.dart"; +import 'package:photos/events/subscription_purchased_event.dart'; +import "package:photos/face/model/person.dart"; +import "package:photos/generated/l10n.dart"; +import "package:photos/models/file/file.dart"; +import 'package:photos/models/gallery_type.dart'; +import 'package:photos/models/selected_files.dart'; +import 'package:photos/services/collections_service.dart'; +import "package:photos/services/machine_learning/face_ml/person/person_service.dart"; +import 'package:photos/ui/actions/collection/collection_sharing_actions.dart'; +import "package:photos/ui/viewer/people/add_person_action_sheet.dart"; +import "package:photos/ui/viewer/people/people_page.dart"; +import "package:photos/ui/viewer/people/person_cluster_suggestion.dart"; +import 'package:photos/ui/viewer/people/person_clusters_page.dart'; +import "package:photos/utils/dialog_util.dart"; +import "package:photos/utils/navigation_util.dart"; + +class PeopleAppBar extends StatefulWidget { + final GalleryType type; + final String? title; + final SelectedFiles selectedFiles; + final PersonEntity person; + + bool get isIgnored => person.data.isIgnored; + + const PeopleAppBar( + this.type, + this.title, + this.selectedFiles, + this.person, { + Key? key, + }) : super(key: key); + + @override + State createState() => _AppBarWidgetState(); +} + +enum PeoplePopupAction { + rename, + setCover, + removeLabel, + viewPhotos, + confirmPhotos, + unignore, +} + +class _AppBarWidgetState extends State { + final _logger = Logger("_AppBarWidgetState"); + late StreamSubscription _userAuthEventSubscription; + late Function() _selectedFilesListener; + String? 
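+  // Held in state so a rename can refresh the title immediately.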
_appBarTitle;
+  late CollectionActions collectionActions;
+  final GlobalKey shareButtonKey = GlobalKey();
+  bool isQuickLink = false;
+  late GalleryType galleryType;
+
+  @override
+  void initState() {
+    super.initState();
+    _selectedFilesListener = () {
+      setState(() {});
+    };
+    collectionActions = CollectionActions(CollectionsService.instance);
+    widget.selectedFiles.addListener(_selectedFilesListener);
+    _userAuthEventSubscription =
+        Bus.instance.on<SubscriptionPurchasedEvent>().listen((event) {
+      setState(() {});
+    });
+    _appBarTitle = widget.title;
+    galleryType = widget.type;
+  }
+
+  @override
+  void dispose() {
+    _userAuthEventSubscription.cancel();
+    widget.selectedFiles.removeListener(_selectedFilesListener);
+    super.dispose();
+  }
+
+  @override
+  Widget build(BuildContext context) {
+    return AppBar(
+      elevation: 0,
+      centerTitle: false,
+      title: Text(
+        _appBarTitle!,
+        style:
+            Theme.of(context).textTheme.headlineSmall!.copyWith(fontSize: 16),
+        maxLines: 2,
+        overflow: TextOverflow.ellipsis,
+      ),
+      actions: _getDefaultActions(context),
+    );
+  }
+
+  Future<void> _renamePerson(BuildContext context) async {
+    final result = await showTextInputDialog(
+      context,
+      title: S.of(context).rename,
+      submitButtonLabel: S.of(context).done,
+      hintText: S.of(context).enterPersonName,
+      alwaysShowSuccessState: true,
+      initialValue: widget.person.data.name,
+      textCapitalization: TextCapitalization.words,
+      onSubmit: (String text) async {
+        // An empty or unchanged name is treated as a cancelled rename.
+        if (text == "" || text == _appBarTitle!) {
+          return;
+        }
+
+        try {
+          await PersonService.instance
+              .updateAttributes(widget.person.remoteID, name: text);
+          if (mounted) {
+            _appBarTitle = text;
+            setState(() {});
+          }
+          Bus.instance.fire(PeopleChangedEvent());
+        } catch (e, s) {
+          _logger.severe("Failed to rename person", e, s);
+          rethrow;
+        }
+      },
+    );
+    if (result is Exception) {
+      await showGenericErrorDialog(context: context, error: result);
+    }
+  }
+
+  List<Widget> _getDefaultActions(BuildContext context) {
+    final List<Widget> actions = <Widget>[];
+    // Hide all actions while files are selected or when no account is
+    // configured.
+    if (widget.selectedFiles.files.isNotEmpty ||
+        !Configuration.instance.hasConfiguredAccount()) {
+      return actions;
+    }
+
+    final List<PopupMenuItem<PeoplePopupAction>> items = [];
+
+    if (!widget.isIgnored) {
+      items.addAll(
+        [
+          PopupMenuItem(
+            value: PeoplePopupAction.rename,
+            child: Row(
+              children: [
+                const Icon(Icons.edit),
+                const Padding(
+                  padding: EdgeInsets.all(8),
+                ),
+                Text(S.of(context).rename),
+              ],
+            ),
+          ),
+          // PopupMenuItem(
+          //   value: PeoplePopupAction.setCover,
+          //   child: Row(
+          //     children: [
+          //       const Icon(Icons.image_outlined),
+          //       const Padding(
+          //         padding: EdgeInsets.all(8),
+          //       ),
+          //       Text(S.of(context).setCover),
+          //     ],
+          //   ),
+          // ),
+
+          PopupMenuItem(
+            value: PeoplePopupAction.removeLabel,
+            child: Row(
+              children: [
+                const Icon(Icons.remove_circle_outline),
+                const Padding(
+                  padding: EdgeInsets.all(8),
+                ),
+                Text(S.of(context).removePersonLabel),
+              ],
+            ),
+          ),
+          const PopupMenuItem(
+            value: PeoplePopupAction.viewPhotos,
+            child: Row(
+              children: [
+                Icon(Icons.view_array_outlined),
+                Padding(
+                  padding: EdgeInsets.all(8),
+                ),
+                Text('View confirmed photos'),
+              ],
+            ),
+          ),
+          const PopupMenuItem(
+            value: PeoplePopupAction.confirmPhotos,
+            child: Row(
+              children: [
+                Icon(CupertinoIcons.square_stack_3d_down_right),
+                Padding(
+                  padding: EdgeInsets.all(8),
+                ),
+                Text('Review suggestions'),
+              ],
+            ),
+          ),
+        ],
+      );
+    } else {
+      items.addAll(
+        [
+          const PopupMenuItem(
+            value: PeoplePopupAction.unignore,
+            child: Row(
+              children: [
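+                // "Show person" puts an ignored person back in the people
+                // section.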
Icon(Icons.visibility_outlined), + Padding( + padding: EdgeInsets.all(8), + ), + Text("Show person"), + ], + ), + ), + ], + ); + } + + if (items.isNotEmpty) { + actions.add( + PopupMenuButton( + itemBuilder: (context) { + return items; + }, + onSelected: (PeoplePopupAction value) async { + if (value == PeoplePopupAction.viewPhotos) { + // ignore: unawaited_futures + unawaited( + Navigator.of(context).push( + MaterialPageRoute( + builder: (context) => PersonClustersPage(widget.person), + ), + ), + ); + } else if (value == PeoplePopupAction.confirmPhotos) { + // ignore: unawaited_futures + unawaited( + Navigator.of(context).push( + MaterialPageRoute( + builder: (context) => + PersonReviewClusterSuggestion(widget.person), + ), + ), + ); + } else if (value == PeoplePopupAction.rename) { + await _renamePerson(context); + } else if (value == PeoplePopupAction.setCover) { + await setCoverPhoto(context); + } else if (value == PeoplePopupAction.unignore) { + await _showPerson(context); + } else if (value == PeoplePopupAction.removeLabel) { + await _removePersonLabel(context); + } + }, + ), + ); + } + + return actions; + } + + Future _removePersonLabel(BuildContext context) async { + await showChoiceDialog( + context, + title: "Are you sure you want to remove this person label?", + body: + "All groupings for this person will be reset, and you will lose all suggestions made for this person", + firstButtonLabel: "Yes, remove person", + firstButtonOnTap: () async { + try { + await PersonService.instance.deletePerson(widget.person.remoteID); + Navigator.of(context).pop(); + } catch (e, s) { + _logger.severe('Removing person label failed', e, s); + } + }, + ); + } + + Future _showPerson(BuildContext context) async { + bool assignName = false; + await showChoiceDialog( + context, + title: + "Are you sure you want to show this person in people section again?", + firstButtonLabel: "Yes, show person", + firstButtonOnTap: () async { + try { + await PersonService.instance + .deletePerson(widget.person.remoteID, onlyMapping: false); + Bus.instance.fire(PeopleChangedEvent()); + assignName = true; + } catch (e, s) { + _logger.severe('Unignoring/showing and naming person failed', e, s); + // await showGenericErrorDialog(context: context, error: e); + } + }, + ); + if (assignName) { + final result = await showAssignPersonAction( + context, + clusterID: widget.person.data.assigned!.first.id, + ); + Navigator.pop(context); + if (result != null && result is (PersonEntity, EnteFile)) { + // ignore: unawaited_futures + routeToPage(context, PeoplePage(person: result.$1)); + } else if (result != null && result is PersonEntity) { + // ignore: unawaited_futures + routeToPage(context, PeoplePage(person: result)); + } + } + } + + Future setCoverPhoto(BuildContext context) async { + // final int? 
coverPhotoID = await showPickCoverPhotoSheet( + // context, + // widget.collection!, + // ); + // if (coverPhotoID != null) { + // unawaited(changeCoverPhoto(context, widget.collection!, coverPhotoID)); + // } + } +} diff --git a/mobile/lib/ui/viewer/people/people_banner.dart b/mobile/lib/ui/viewer/people/people_banner.dart new file mode 100644 index 000000000..db242a523 --- /dev/null +++ b/mobile/lib/ui/viewer/people/people_banner.dart @@ -0,0 +1,132 @@ +import "package:flutter/material.dart"; +import "package:flutter_animate/flutter_animate.dart"; +import "package:photos/ente_theme_data.dart"; +import "package:photos/theme/ente_theme.dart"; +import "package:photos/ui/components/buttons/icon_button_widget.dart"; +import "package:photos/ui/viewer/search/result/person_face_widget.dart"; + +enum PeopleBannerType { + addName, + suggestion, +} + +class PeopleBanner extends StatelessWidget { + final PeopleBannerType type; + final IconData? startIcon; + final PersonFaceWidget? faceWidget; + final IconData actionIcon; + final String text; + final String? subText; + final GestureTapCallback onTap; + + const PeopleBanner({ + Key? key, + required this.type, + this.startIcon, + this.faceWidget, + required this.actionIcon, + required this.text, + required this.onTap, + this.subText, + }) : super(key: key); + + @override + Widget build(BuildContext context) { + final colorScheme = getEnteColorScheme(context); + final textTheme = getEnteTextTheme(context); + final backgroundColor = colorScheme.backgroundElevated2; + final TextStyle mainTextStyle = textTheme.bodyBold; + final TextStyle subTextStyle = textTheme.miniMuted; + late final Widget startWidget; + late final bool roundedActionIcon; + switch (type) { + case PeopleBannerType.suggestion: + assert(startIcon != null); + startWidget = Padding( + padding: + const EdgeInsets.only(top: 10, bottom: 10, left: 6, right: 4), + child: Icon( + startIcon!, + size: 40, + color: colorScheme.primary500, + ), + ); + roundedActionIcon = true; + break; + case PeopleBannerType.addName: + assert(faceWidget != null); + startWidget = SizedBox( + width: 56, + height: 56, + child: ClipRRect( + borderRadius: const BorderRadius.all( + Radius.circular(4), + ), + child: faceWidget!, + ), + ); + roundedActionIcon = false; + } + + return RepaintBoundary( + child: Center( + child: GestureDetector( + onTap: onTap, + child: Container( + decoration: BoxDecoration( + boxShadow: Theme.of(context).colorScheme.enteTheme.shadowMenu, + color: backgroundColor, + ), + child: Padding( + padding: const EdgeInsets.symmetric(horizontal: 2, vertical: 2), + child: Row( + mainAxisAlignment: MainAxisAlignment.spaceBetween, + children: [ + startWidget, + const SizedBox(width: 12), + Expanded( + child: Column( + mainAxisAlignment: MainAxisAlignment.spaceEvenly, + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + Text( + text, + style: mainTextStyle, + textAlign: TextAlign.left, + ), + subText != null + ? const SizedBox(height: 6) + : const SizedBox.shrink(), + subText != null + ? 
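+                      // Optional second line; collapses to nothing when no
+                      // subText is provided.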
Text( + subText!, + style: subTextStyle, + ) + : const SizedBox.shrink(), + ], + ), + ), + const SizedBox(width: 12), + IconButtonWidget( + icon: actionIcon, + iconButtonType: IconButtonType.primary, + iconColor: colorScheme.strokeBase, + defaultColor: colorScheme.fillFaint, + pressedColor: colorScheme.fillMuted, + roundedIcon: roundedActionIcon, + onTap: onTap, + ), + const SizedBox(width: 6), + ], + ), + ), + ), + ), + ).animate(onPlay: (controller) => controller.repeat()).shimmer( + duration: 1000.ms, + delay: 3200.ms, + size: 0.6, + ), + ); + } +} diff --git a/mobile/lib/ui/viewer/people/people_page.dart b/mobile/lib/ui/viewer/people/people_page.dart new file mode 100644 index 000000000..8b399ced0 --- /dev/null +++ b/mobile/lib/ui/viewer/people/people_page.dart @@ -0,0 +1,215 @@ +import "dart:async"; +import "dart:developer"; + +import 'package:flutter/material.dart'; +import "package:logging/logging.dart"; +import 'package:photos/core/event_bus.dart'; +import 'package:photos/events/files_updated_event.dart'; +import 'package:photos/events/local_photos_updated_event.dart'; +import "package:photos/events/people_changed_event.dart"; +import "package:photos/face/model/person.dart"; +import 'package:photos/models/file/file.dart'; +import 'package:photos/models/file_load_result.dart'; +import 'package:photos/models/gallery_type.dart'; +import 'package:photos/models/selected_files.dart'; +import "package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart"; +import "package:photos/services/machine_learning/face_ml/feedback/cluster_feedback.dart"; +import "package:photos/services/search_service.dart"; +import 'package:photos/ui/viewer/actions/file_selection_overlay_bar.dart'; +import 'package:photos/ui/viewer/gallery/gallery.dart'; +import "package:photos/ui/viewer/people/people_app_bar.dart"; +import "package:photos/ui/viewer/people/people_banner.dart"; +import "package:photos/ui/viewer/people/person_cluster_suggestion.dart"; + +class PeoplePage extends StatefulWidget { + final String tagPrefix; + final PersonEntity person; + + static const GalleryType appBarType = GalleryType.peopleTag; + static const GalleryType overlayType = GalleryType.peopleTag; + + const PeoplePage({ + this.tagPrefix = "", + required this.person, + Key? key, + }) : super(key: key); + + @override + State createState() => _PeoplePageState(); +} + +class _PeoplePageState extends State { + final Logger _logger = Logger("_PeoplePageState"); + final _selectedFiles = SelectedFiles(); + List? files; + int? smallestClusterSize; + Future> filesFuture = Future.value([]); + + bool get showSuggestionBanner => (!userDismissedSuggestionBanner && + smallestClusterSize != null && + smallestClusterSize! 
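+      // Only nudge for reasonably large persons: the smallest cluster must
+      // meet the search-result size threshold and the person needs more than
+      // 200 files.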
>= kMinimumClusterSizeSearchResult && + files != null && + files!.isNotEmpty && + files!.length > 200); + + bool userDismissedSuggestionBanner = false; + + late final StreamSubscription _filesUpdatedEvent; + late final StreamSubscription _peopleChangedEvent; + + @override + void initState() { + super.initState(); + ClusterFeedbackService.resetLastViewedClusterID(); + _peopleChangedEvent = Bus.instance.on().listen((event) { + setState(() {}); + }); + + filesFuture = loadPersonFiles(); + + _filesUpdatedEvent = + Bus.instance.on().listen((event) { + if (event.type == EventType.deletedFromDevice || + event.type == EventType.deletedFromEverywhere || + event.type == EventType.deletedFromRemote || + event.type == EventType.hide) { + for (var updatedFile in event.updatedFiles) { + files?.remove(updatedFile); + } + setState(() {}); + } + }); + } + + Future> loadPersonFiles() async { + log("loadPersonFiles"); + final result = await SearchService.instance + .getClusterFilesForPersonID(widget.person.remoteID); + smallestClusterSize = result.values.fold(result.values.first.length, + (previousValue, element) { + return element.length < previousValue ? element.length : previousValue; + }); + final List resultFiles = []; + for (final e in result.entries) { + resultFiles.addAll(e.value); + } + final List sortedFiles = List.from(resultFiles); + sortedFiles.sort((a, b) => b.creationTime!.compareTo(a.creationTime!)); + files = sortedFiles; + return sortedFiles; + } + + @override + void dispose() { + _filesUpdatedEvent.cancel(); + _peopleChangedEvent.cancel(); + super.dispose(); + } + + @override + Widget build(BuildContext context) { + _logger.info("Building for ${widget.person.data.name}"); + return Scaffold( + appBar: PreferredSize( + preferredSize: const Size.fromHeight(50.0), + child: PeopleAppBar( + GalleryType.peopleTag, + widget.person.data.name, + _selectedFiles, + widget.person, + ), + ), + body: FutureBuilder>( + future: filesFuture, + builder: (context, snapshot) { + if (snapshot.hasData) { + final personFiles = snapshot.data as List; + return Column( + children: [ + Expanded( + child: Stack( + alignment: Alignment.bottomCenter, + children: [ + Gallery( + asyncLoader: ( + creationStartTime, + creationEndTime, { + limit, + asc, + }) async { + final result = await loadPersonFiles(); + return Future.value( + FileLoadResult( + result, + false, + ), + ); + }, + reloadEvent: Bus.instance.on(), + forceReloadEvents: [ + Bus.instance.on(), + ], + removalEventTypes: const { + EventType.deletedFromRemote, + EventType.deletedFromEverywhere, + EventType.hide, + }, + tagPrefix: widget.tagPrefix + widget.tagPrefix, + selectedFiles: _selectedFiles, + initialFiles: + personFiles.isNotEmpty ? [personFiles.first] : [], + ), + FileSelectionOverlayBar( + PeoplePage.overlayType, + _selectedFiles, + person: widget.person, + ), + ], + ), + ), + showSuggestionBanner + ? 
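+              // Swipe-to-dismiss banner pointing to the cluster-suggestion
+              // review flow; the dismissal flag lives in memory only, so the
+              // banner returns on the next visit.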
Dismissible( + key: const Key("suggestionBanner"), + direction: DismissDirection.horizontal, + onDismissed: (direction) { + setState(() { + userDismissedSuggestionBanner = true; + }); + }, + child: PeopleBanner( + type: PeopleBannerType.suggestion, + startIcon: Icons.face_retouching_natural, + actionIcon: Icons.search_outlined, + text: "Review suggestions", + subText: "Improve the results", + onTap: () async { + unawaited( + Navigator.of(context).push( + MaterialPageRoute( + builder: (context) => + PersonReviewClusterSuggestion( + widget.person, + ), + ), + ), + ); + }, + ), + ) + : const SizedBox.shrink(), + ], + ); + } else if (snapshot.hasError) { + log("Error: ${snapshot.error} ${snapshot.stackTrace}}"); + //Need to show an error on the UI here + return const SizedBox.shrink(); + } else { + return const Center( + child: CircularProgressIndicator(), + ); + } + }, + ), + ); + } +} diff --git a/mobile/lib/ui/viewer/people/person_cluster_suggestion.dart b/mobile/lib/ui/viewer/people/person_cluster_suggestion.dart new file mode 100644 index 000000000..2a904720b --- /dev/null +++ b/mobile/lib/ui/viewer/people/person_cluster_suggestion.dart @@ -0,0 +1,452 @@ +import "dart:async" show StreamSubscription, unawaited; +import "dart:math"; +import "dart:typed_data"; + +import "package:flutter/foundation.dart" show kDebugMode; +import "package:flutter/material.dart"; +import "package:photos/core/event_bus.dart"; +import "package:photos/events/people_changed_event.dart"; +import "package:photos/face/db.dart"; +import "package:photos/face/model/person.dart"; +import "package:photos/models/file/file.dart"; +import 'package:photos/services/machine_learning/face_ml/feedback/cluster_feedback.dart'; +import "package:photos/theme/ente_theme.dart"; +import "package:photos/ui/components/buttons/button_widget.dart"; +import "package:photos/ui/components/models/button_type.dart"; +// import "package:photos/ui/viewer/people/add_person_action_sheet.dart"; +import "package:photos/ui/viewer/people/cluster_page.dart"; +import "package:photos/ui/viewer/people/person_clusters_page.dart"; +import "package:photos/ui/viewer/search/result/person_face_widget.dart"; + +class PersonReviewClusterSuggestion extends StatefulWidget { + final PersonEntity person; + + const PersonReviewClusterSuggestion( + this.person, { + super.key, + }); + + @override + State createState() => _PersonClustersState(); +} + +class _PersonClustersState extends State { + int currentSuggestionIndex = 0; + bool fetch = true; + Key futureBuilderKeySuggestions = UniqueKey(); + Key futureBuilderKeyFaceThumbnails = UniqueKey(); + bool canGiveFeedback = true; + + // Declare a variable for the future + late Future> futureClusterSuggestions; + late StreamSubscription _peopleChangedEvent; + + @override + void initState() { + super.initState(); + // Initialize the future in initState + if (fetch) _fetchClusterSuggestions(); + fetch = true; + } + + @override + void dispose() { + _peopleChangedEvent.cancel(); + super.dispose(); + } + + @override + Widget build(BuildContext context) { + return Scaffold( + appBar: AppBar( + title: const Text('Review suggestions'), + actions: [ + IconButton( + icon: const Icon(Icons.history_outlined), + onPressed: () { + Navigator.of(context).push( + MaterialPageRoute( + builder: (context) => PersonClustersPage(widget.person), + ), + ); + }, + ), + ], + ), + body: FutureBuilder>( + key: futureBuilderKeySuggestions, + future: futureClusterSuggestions, + builder: (context, snapshot) { + if (snapshot.hasData) { + if 
(snapshot.data!.isEmpty) { + return Center( + child: Text( + "No suggestions for ${widget.person.data.name}", + style: getEnteTextTheme(context).largeMuted, + ), + ); + } + + final allSuggestions = snapshot.data!; + final numberOfDifferentSuggestions = allSuggestions.length; + final currentSuggestion = allSuggestions[currentSuggestionIndex]; + final int clusterID = currentSuggestion.clusterIDToMerge; + final double distance = currentSuggestion.distancePersonToCluster; + final bool usingMean = currentSuggestion.usedOnlyMeanForSuggestion; + final List files = currentSuggestion.filesInCluster; + + final Future> generateFacedThumbnails = + _generateFaceThumbnails( + files.sublist(0, min(files.length, 8)), + clusterID, + ); + + _peopleChangedEvent = + Bus.instance.on().listen((event) { + if (event.type == PeopleEventType.removedFilesFromCluster && + (event.source == clusterID.toString())) { + for (var updatedFile in event.relevantFiles!) { + files.remove(updatedFile); + } + fetch = false; + setState(() {}); + } + }); + return InkWell( + onTap: () { + final List sortedFiles = + List.from(currentSuggestion.filesInCluster); + sortedFiles.sort( + (a, b) => b.creationTime!.compareTo(a.creationTime!), + ); + Navigator.of(context).push( + MaterialPageRoute( + builder: (context) => ClusterPage( + sortedFiles, + personID: widget.person, + clusterID: clusterID, + showNamingBanner: false, + ), + ), + ); + }, + child: Container( + padding: const EdgeInsets.symmetric( + horizontal: 8.0, + vertical: 20, + ), + child: _buildSuggestionView( + clusterID, + distance, + usingMean, + files, + numberOfDifferentSuggestions, + allSuggestions, + generateFacedThumbnails, + ), + ), + ); + } else if (snapshot.hasError) { + // log the error + return const Center(child: Text("Error")); + } else { + return const Center(child: CircularProgressIndicator()); + } + }, + ), + ); + } + + Future _handleUserClusterChoice( + int clusterID, + bool yesOrNo, + int numberOfSuggestions, + ) async { + // Perform the action based on clusterID, e.g., assignClusterToPerson or captureNotPersonFeedback + if (!canGiveFeedback) { + return; + } + if (yesOrNo) { + canGiveFeedback = false; + await FaceMLDataDB.instance.assignClusterToPerson( + personID: widget.person.remoteID, + clusterID: clusterID, + ); + Bus.instance.fire(PeopleChangedEvent()); + // Increment the suggestion index + if (mounted) { + setState(() => currentSuggestionIndex++); + } + + // Check if we need to fetch new data + if (currentSuggestionIndex >= (numberOfSuggestions)) { + setState(() { + currentSuggestionIndex = 0; + futureBuilderKeySuggestions = + UniqueKey(); // Reset to trigger FutureBuilder + futureBuilderKeyFaceThumbnails = UniqueKey(); + _fetchClusterSuggestions(); + }); + } else { + futureBuilderKeyFaceThumbnails = UniqueKey(); + fetch = false; + setState(() {}); + } + } else { + await _rejectSuggestion(clusterID, numberOfSuggestions); + } + } + + Future _rejectSuggestion( + int clusterID, + int numberOfSuggestions, + ) async { + canGiveFeedback = false; + await FaceMLDataDB.instance.captureNotPersonFeedback( + personID: widget.person.remoteID, + clusterID: clusterID, + ); + // Recalculate the suggestions when a suggestion is rejected + setState(() { + currentSuggestionIndex = 0; + futureBuilderKeySuggestions = + UniqueKey(); // Reset to trigger FutureBuilder + futureBuilderKeyFaceThumbnails = UniqueKey(); + _fetchClusterSuggestions(); + }); + } + + // Method to fetch cluster suggestions + void _fetchClusterSuggestions() { + futureClusterSuggestions = + 
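+        // Recomputed from scratch; the FutureBuilder above re-runs whenever
+        // futureBuilderKeySuggestions is replaced with a fresh key.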
ClusterFeedbackService.instance.getSuggestionForPerson(widget.person);
+  }
+
+  Widget _buildSuggestionView(
+    int clusterID,
+    double distance,
+    bool usingMean,
+    List<EnteFile> files,
+    int numberOfSuggestions,
+    List<ClusterSuggestion> allSuggestions,
+    Future<Map<int, Uint8List?>> generateFaceThumbnails,
+  ) {
+    final widgetToReturn = Column(
+      key: ValueKey("cluster_id-$clusterID-files-${files.length}"),
+      children: [
+        if (kDebugMode)
+          Text(
+            "ClusterID: $clusterID, Distance: ${distance.toStringAsFixed(3)}, usingMean: $usingMean",
+            style: getEnteTextTheme(context).smallMuted,
+          ),
+        Text(
+          // TODO: come up with a better copy for strings below!
+          "${widget.person.data.name}?",
+          style: getEnteTextTheme(context).largeMuted,
+        ),
+        const SizedBox(height: 24),
+        _buildThumbnailWidget(
+          files,
+          clusterID,
+          generateFaceThumbnails,
+        ),
+        const SizedBox(
+          height: 24.0,
+        ),
+        Padding(
+          padding: const EdgeInsets.symmetric(horizontal: 24.0),
+          child: Row(
+            children: [
+              Expanded(
+                child: ButtonWidget(
+                  buttonType: ButtonType.critical,
+                  labelText: 'No',
+                  buttonSize: ButtonSize.large,
+                  onTap: () async => {
+                    await _handleUserClusterChoice(
+                      clusterID,
+                      false,
+                      numberOfSuggestions,
+                    ),
+                  },
+                ),
+              ),
+              const SizedBox(width: 12.0),
+              Expanded(
+                child: ButtonWidget(
+                  buttonType: ButtonType.primary,
+                  labelText: 'Yes',
+                  buttonSize: ButtonSize.large,
+                  onTap: () async => {
+                    await _handleUserClusterChoice(
+                      clusterID,
+                      true,
+                      numberOfSuggestions,
+                    ),
+                  },
+                ),
+              ),
+            ],
+          ),
+        ),
+        // const SizedBox(
+        //   height: 24.0,
+        // ),
+        // ButtonWidget(
+        //   shouldSurfaceExecutionStates: false,
+        //   buttonType: ButtonType.neutral,
+        //   labelText: 'Assign different person',
+        //   buttonSize: ButtonSize.small,
+        //   onTap: () async {
+        //     final result = await showAssignPersonAction(
+        //       context,
+        //       clusterID: clusterID,
+        //     );
+        //     if (result != null &&
+        //         (result is (PersonEntity, EnteFile) ||
+        //             result is PersonEntity)) {
+        //       await _rejectSuggestion(clusterID, numberOfSuggestions);
+        //     }
+        //   },
+        // ),
+      ],
+    );
+    // Precompute face thumbnails for the next suggestions, in case the user
+    // moves through them quickly.
+    const precomputeSuggestions = 8;
+    const maxPrecomputations = 8;
+    int compCount = 0;
+    if (allSuggestions.length > currentSuggestionIndex + 1) {
+      outerLoop:
+      for (final suggestion in allSuggestions.sublist(
+        currentSuggestionIndex + 1,
+        min(
+          allSuggestions.length,
+          currentSuggestionIndex + precomputeSuggestions,
+        ),
+      )) {
+        final files = suggestion.filesInCluster;
+        final clusterID = suggestion.clusterIDToMerge;
+        for (final file in files.sublist(0, min(files.length, 8))) {
+          unawaited(
+            PersonFaceWidget.precomputeNextFaceCrops(
+              file,
+              clusterID,
+              useFullFile: false,
+            ),
+          );
+          compCount++;
+          if (compCount >= maxPrecomputations) {
+            debugPrint(
+              'Prefetching $compCount face thumbnails for suggestions',
+            );
+            break outerLoop;
+          }
+        }
+      }
+    }
+    return widgetToReturn;
+  }
+
+  Widget _buildThumbnailWidget(
+    List<EnteFile> files,
+    int clusterID,
+    Future<Map<int, Uint8List?>> generateFaceThumbnails,
+  ) {
+    return SizedBox(
+      height: MediaQuery.of(context).size.height * 0.4,
+      child: FutureBuilder<Map<int, Uint8List?>>(
+        key: futureBuilderKeyFaceThumbnails,
+        future: generateFaceThumbnails,
+        builder: (context, snapshot) {
+          if (snapshot.hasData) {
+            final faceThumbnails = snapshot.data!;
+            canGiveFeedback = true;
+            return Column(
+              children: [
+                Row(
+                  mainAxisAlignment: MainAxisAlignment.center,
+                  children: _buildThumbnailWidgetsRow(
+                    files,
+                    clusterID,
+                    faceThumbnails,
+                  ),
+                ),
+                if (files.length > 4) const SizedBox(height: 24),
+                if (files.length > 4)
+                  Row(
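+                    // Second row for faces five through eight.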
mainAxisAlignment: MainAxisAlignment.center,
+                    children: _buildThumbnailWidgetsRow(
+                      files,
+                      clusterID,
+                      faceThumbnails,
+                      start: 4,
+                    ),
+                  ),
+                const SizedBox(height: 24.0),
+                Text(
+                  "${files.length} photos",
+                  style: getEnteTextTheme(context).body,
+                ),
+              ],
+            );
+          } else if (snapshot.hasError) {
+            // log the error
+            return const Center(child: Text("Error"));
+          } else {
+            canGiveFeedback = false;
+            return const Center(child: CircularProgressIndicator());
+          }
+        },
+      ),
+    );
+  }
+
+  List<Widget> _buildThumbnailWidgetsRow(
+    List<EnteFile> files,
+    int clusterID,
+    Map<int, Uint8List?> faceThumbnails, {
+    int start = 0,
+  }) {
+    return List.generate(
+      min(4, max(0, files.length - start)),
+      (index) => Padding(
+        padding: const EdgeInsets.all(8.0),
+        child: SizedBox(
+          width: 72,
+          height: 72,
+          child: ClipOval(
+            child: PersonFaceWidget(
+              files[start + index],
+              clusterID: clusterID,
+              useFullFile: false,
+              thumbnailFallback: false,
+              faceCrop: faceThumbnails[files[start + index].uploadedFileID!],
+            ),
+          ),
+        ),
+      ),
+    );
+  }
+
+  Future<Map<int, Uint8List?>> _generateFaceThumbnails(
+    List<EnteFile> files,
+    int clusterID,
+  ) async {
+    final futures = <Future<Uint8List?>>[];
+    for (final file in files) {
+      futures.add(
+        PersonFaceWidget.precomputeNextFaceCrops(
+          file,
+          clusterID,
+          useFullFile: false,
+        ),
+      );
+    }
+    final faceCropsList = await Future.wait(futures);
+    final faceCrops = <int, Uint8List?>{};
+    for (var i = 0; i < faceCropsList.length; i++) {
+      faceCrops[files[i].uploadedFileID!] = faceCropsList[i];
+    }
+    return faceCrops;
+  }
+}
diff --git a/mobile/lib/ui/viewer/people/person_clusters_page.dart b/mobile/lib/ui/viewer/people/person_clusters_page.dart
new file mode 100644
index 000000000..2c493fc21
--- /dev/null
+++ b/mobile/lib/ui/viewer/people/person_clusters_page.dart
@@ -0,0 +1,144 @@
+import "package:flutter/cupertino.dart";
+import "package:flutter/material.dart";
+import "package:logging/logging.dart";
+import "package:photos/core/event_bus.dart";
+import "package:photos/events/people_changed_event.dart";
+import "package:photos/face/model/person.dart";
+import "package:photos/models/file/file.dart";
+import "package:photos/services/machine_learning/face_ml/person/person_service.dart";
+import "package:photos/services/search_service.dart";
+import "package:photos/theme/ente_theme.dart";
+import "package:photos/ui/viewer/file/no_thumbnail_widget.dart";
+// import "package:photos/ui/viewer/file/thumbnail_widget.dart";
+import "package:photos/ui/viewer/people/cluster_page.dart";
+import "package:photos/ui/viewer/search/result/person_face_widget.dart";
+
+class PersonClustersPage extends StatefulWidget {
+  final PersonEntity person;
+
+  const PersonClustersPage(
+    this.person, {
+    super.key,
+  });
+
+  @override
+  State<PersonClustersPage> createState() => _PersonClustersPageState();
+}
+
+class _PersonClustersPageState extends State<PersonClustersPage> {
+  final Logger _logger = Logger("_PersonClustersState");
+  @override
+  Widget build(BuildContext context) {
+    return Scaffold(
+      appBar: AppBar(
+        title: Text(widget.person.data.name),
+      ),
+      body: FutureBuilder<Map<int, List<EnteFile>>>(
+        future: SearchService.instance
+            .getClusterFilesForPersonID(widget.person.remoteID),
+        builder: (context, snapshot) {
+          if (snapshot.hasData) {
+            final List<int> keys = snapshot.data!.keys.toList();
+            return ListView.builder(
+              itemCount: keys.length,
+              itemBuilder: (context, index) {
+                final int clusterID = keys[index];
+                final List<EnteFile> files = snapshot.data![keys[index]]!;
+                return InkWell(
+                  onTap: () {
+                    Navigator.of(context).push(
+                      MaterialPageRoute(
+                        builder: (context) => ClusterPage(
+                          files,
+                          personID: widget.person,
+                          clusterID: clusterID,
), + ), + ); + }, + child: Container( + padding: const EdgeInsets.all(8.0), + child: Row( + children: [ + SizedBox( + width: 64, + height: 64, + child: files.isNotEmpty + ? ClipRRect( + borderRadius: const BorderRadius.all( + Radius.elliptical(16, 12), + ), + child: PersonFaceWidget( + files.first, + clusterID: clusterID, + ), + ) + : const ClipRRect( + borderRadius: BorderRadius.all( + Radius.elliptical(16, 12), + ), + child: NoThumbnailWidget( + addBorder: false, + ), + ), + ), + const SizedBox( + width: 8.0, + ), // Add some spacing between the thumbnail and the text + Expanded( + child: Padding( + padding: + const EdgeInsets.symmetric(horizontal: 8.0), + child: Row( + mainAxisAlignment: MainAxisAlignment.spaceBetween, + children: [ + Text( + "${snapshot.data![keys[index]]!.length} photos", + style: getEnteTextTheme(context).body, + ), + GestureDetector( + onTap: () async { + try { + await PersonService.instance + .removeClusterToPerson( + personID: widget.person.remoteID, + clusterID: clusterID, + ); + _logger.info( + "Removed cluster $clusterID from person ${widget.person.remoteID}", + ); + Bus.instance.fire(PeopleChangedEvent()); + setState(() {}); + } catch (e) { + _logger.severe( + "removing cluster from person,", + e, + ); + } + }, + child: const Icon( + CupertinoIcons.minus_circled, + color: Colors.red, + ), + ), + ], + ), + ), + ), + ], + ), + ), + ); + }, + ); + } else if (snapshot.hasError) { + _logger.warning("Failed to get cluster", snapshot.error); + return const Center(child: Text("Error")); + } else { + return const Center(child: CircularProgressIndicator()); + } + }, + ), + ); + } +} diff --git a/mobile/lib/ui/viewer/people/person_row_item.dart b/mobile/lib/ui/viewer/people/person_row_item.dart new file mode 100644 index 000000000..831fe9729 --- /dev/null +++ b/mobile/lib/ui/viewer/people/person_row_item.dart @@ -0,0 +1,36 @@ +import "package:flutter/material.dart"; +import "package:photos/face/model/person.dart"; +import "package:photos/models/file/file.dart"; +import "package:photos/ui/viewer/search/result/person_face_widget.dart"; + +class PersonRowItem extends StatelessWidget { + final PersonEntity person; + final EnteFile personFile; + final VoidCallback onTap; + + const PersonRowItem({ + Key? 
key,
+    required this.person,
+    required this.personFile,
+    required this.onTap,
+  }) : super(key: key);
+
+  @override
+  Widget build(BuildContext context) {
+    return ListTile(
+      dense: false,
+      leading: SizedBox(
+        width: 56,
+        height: 56,
+        child: ClipRRect(
+          borderRadius: const BorderRadius.all(
+            Radius.elliptical(16, 12),
+          ),
+          child: PersonFaceWidget(personFile, personId: person.remoteID),
+        ),
+      ),
+      title: Text(person.data.name),
+      onTap: onTap,
+    );
+  }
+}
diff --git a/mobile/lib/ui/viewer/search/result/no_result_widget.dart b/mobile/lib/ui/viewer/search/result/no_result_widget.dart
index 9ebb9cf80..48ba811df 100644
--- a/mobile/lib/ui/viewer/search/result/no_result_widget.dart
+++ b/mobile/lib/ui/viewer/search/result/no_result_widget.dart
@@ -21,7 +21,6 @@ class _NoResultWidgetState extends State<NoResultWidget> {
     super.initState();
     searchTypes = SectionType.values.toList(growable: true);
     // remove face and content sectionType
-    searchTypes.remove(SectionType.face);
     searchTypes.remove(SectionType.content);
   }
 
diff --git a/mobile/lib/ui/viewer/search/result/person_face_widget.dart b/mobile/lib/ui/viewer/search/result/person_face_widget.dart
new file mode 100644
index 000000000..8be99e5f6
--- /dev/null
+++ b/mobile/lib/ui/viewer/search/result/person_face_widget.dart
@@ -0,0 +1,270 @@
+import "dart:developer";
+// import "dart:io";
+import "dart:typed_data";
+
+import 'package:flutter/widgets.dart';
+import "package:photos/db/files_db.dart";
+import "package:photos/face/db.dart";
+import "package:photos/face/model/face.dart";
+import "package:photos/face/model/person.dart";
+import 'package:photos/models/file/file.dart';
+import "package:photos/services/machine_learning/face_ml/person/person_service.dart";
+import "package:photos/ui/common/loading_widget.dart";
+import "package:photos/ui/viewer/file/thumbnail_widget.dart";
+import "package:photos/ui/viewer/file_details/face_widget.dart";
+import "package:photos/ui/viewer/people/cropped_face_image_view.dart";
+import "package:photos/utils/face/face_box_crop.dart";
+import "package:photos/utils/thumbnail_util.dart";
+import "package:pool/pool.dart";
+
+class PersonFaceWidget extends StatelessWidget {
+  final EnteFile file;
+  final String? personId;
+  final int? clusterID;
+  final bool useFullFile;
+  final bool thumbnailFallback;
+  final Uint8List? faceCrop;
+
+  // Requires either personId or clusterID to be non-null (see the assert
+  // below); the file must be non-null as well.
+  const PersonFaceWidget(
+    this.file, {
+    this.personId,
+    this.clusterID,
+    this.useFullFile = true,
+    this.thumbnailFallback = true,
+    this.faceCrop,
+    Key? key,
+  }) : assert(
+          personId != null || clusterID != null,
+          "PersonFaceWidget requires either personId or clusterID to be non-null",
+        ),
+        super(key: key);
+
+  @override
+  Widget build(BuildContext context) {
+    if (faceCrop != null) {
+      return Stack(
+        fit: StackFit.expand,
+        children: [
+          Image(
+            image: MemoryImage(faceCrop!),
+            fit: BoxFit.cover,
+          ),
+        ],
+      );
+    }
+    if (useGeneratedFaceCrops) {
+      return FutureBuilder<Uint8List?>(
+        future: getFaceCrop(),
+        builder: (context, snapshot) {
+          if (snapshot.hasData) {
+            final ImageProvider imageProvider = MemoryImage(snapshot.data!);
+            return Stack(
+              fit: StackFit.expand,
+              children: [
+                Image(
+                  image: imageProvider,
+                  fit: BoxFit.cover,
+                ),
+              ],
+            );
+          } else {
+            if (snapshot.hasError) {
+              log('Error getting cover face for person: ${snapshot.error}');
+            }
+            return thumbnailFallback
+                ?
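+                // While the generated crop loads, or if it fails, fall back
+                // to the file thumbnail or a spinner, per thumbnailFallback.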
ThumbnailWidget(file) + : const EnteLoadingWidget(); + } + }, + ); + } else { + return FutureBuilder( + future: _getFace(), + builder: (context, snapshot) { + if (snapshot.hasData) { + final Face face = snapshot.data!; + return Stack( + fit: StackFit.expand, + children: [ + CroppedFaceImageView(enteFile: file, face: face), + ], + ); + } else { + if (snapshot.hasError) { + log('Error getting cover face for person: ${snapshot.error}'); + } + return thumbnailFallback + ? ThumbnailWidget(file) + : const EnteLoadingWidget(); + } + }, + ); + } + } + + Future _getFace() async { + String? personAvatarFaceID; + if (personId != null) { + final PersonEntity? personEntity = + await PersonService.instance.getPerson(personId!); + if (personEntity != null) { + personAvatarFaceID = personEntity.data.avatarFaceId; + } + } + return await FaceMLDataDB.instance.getCoverFaceForPerson( + recentFileID: file.uploadedFileID!, + avatarFaceId: personAvatarFaceID, + personID: personId, + clusterID: clusterID, + ); + } + + Future getFaceCrop() async { + try { + final Face? face = await _getFace(); + if (face == null) { + debugPrint( + "No cover face for person: $personId and cluster $clusterID and recentFile ${file.uploadedFileID}", + ); + return null; + } + final Uint8List? cachedFace = faceCropCache.get(face.faceID); + if (cachedFace != null) { + return cachedFace; + } + final faceCropCacheFile = cachedFaceCropPath(face.faceID); + if ((await faceCropCacheFile.exists())) { + final data = await faceCropCacheFile.readAsBytes(); + faceCropCache.put(face.faceID, data); + return data; + } + if (!useFullFile) { + final Uint8List? cachedFaceThumbnail = + faceCropThumbnailCache.get(face.faceID); + if (cachedFaceThumbnail != null) { + return cachedFaceThumbnail; + } + } + EnteFile? fileForFaceCrop = file; + if (face.fileID != file.uploadedFileID!) { + fileForFaceCrop = + await FilesDB.instance.getAnyUploadedFile(face.fileID); + } + if (fileForFaceCrop == null) { + return null; + } + + late final Pool relevantResourcePool; + if (useFullFile) { + relevantResourcePool = poolFullFileFaceGenerations; + } else { + relevantResourcePool = poolThumbnailFaceGenerations; + } + final result = await relevantResourcePool.withResource( + () async => await getFaceCrops( + fileForFaceCrop!, + { + face.faceID: face.detection.box, + }, + useFullFile: useFullFile, + ), + ); + final Uint8List? computedCrop = result?[face.faceID]; + if (computedCrop != null) { + if (useFullFile) { + faceCropCache.put(face.faceID, computedCrop); + faceCropCacheFile.writeAsBytes(computedCrop).ignore(); + } else { + faceCropThumbnailCache.put(face.faceID, computedCrop); + } + } + return computedCrop; + } catch (e, s) { + log( + "Error getting cover face for person: $personId and cluster $clusterID", + error: e, + stackTrace: s, + ); + return null; + } + } + + static Future precomputeNextFaceCrops( + file, + clusterID, { + required bool useFullFile, + }) async { + try { + final Face? face = await FaceMLDataDB.instance.getCoverFaceForPerson( + recentFileID: file.uploadedFileID!, + clusterID: clusterID, + ); + if (face == null) { + debugPrint( + "No cover face for cluster $clusterID and recentFile ${file.uploadedFileID}", + ); + return null; + } + final Uint8List? 
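+      // Same lookup order as getFaceCrop: in-memory crop cache, then the
+      // on-disk cached crop, then the thumbnail cache, before computing a
+      // fresh crop.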
cachedFace = faceCropCache.get(face.faceID); + if (cachedFace != null) { + return cachedFace; + } + final faceCropCacheFile = cachedFaceCropPath(face.faceID); + if ((await faceCropCacheFile.exists())) { + final data = await faceCropCacheFile.readAsBytes(); + faceCropCache.put(face.faceID, data); + return data; + } + if (!useFullFile) { + final Uint8List? cachedFaceThumbnail = + faceCropThumbnailCache.get(face.faceID); + if (cachedFaceThumbnail != null) { + return cachedFaceThumbnail; + } + } + EnteFile? fileForFaceCrop = file; + if (face.fileID != file.uploadedFileID!) { + fileForFaceCrop = + await FilesDB.instance.getAnyUploadedFile(face.fileID); + } + if (fileForFaceCrop == null) { + return null; + } + + late final Pool relevantResourcePool; + if (useFullFile) { + relevantResourcePool = poolFullFileFaceGenerations; + } else { + relevantResourcePool = poolThumbnailFaceGenerations; + } + final result = await relevantResourcePool.withResource( + () async => await getFaceCrops( + fileForFaceCrop!, + { + face.faceID: face.detection.box, + }, + useFullFile: useFullFile, + ), + ); + final Uint8List? computedCrop = result?[face.faceID]; + if (computedCrop != null) { + if (useFullFile) { + faceCropCache.put(face.faceID, computedCrop); + faceCropCacheFile.writeAsBytes(computedCrop).ignore(); + } else { + faceCropThumbnailCache.put(face.faceID, computedCrop); + } + } + return computedCrop; + } catch (e, s) { + log( + "Error getting cover face for cluster $clusterID", + error: e, + stackTrace: s, + ); + return null; + } + } +} diff --git a/mobile/lib/ui/viewer/search/result/search_result_widget.dart b/mobile/lib/ui/viewer/search/result/search_result_widget.dart index 5564af7c9..fbd77531a 100644 --- a/mobile/lib/ui/viewer/search/result/search_result_widget.dart +++ b/mobile/lib/ui/viewer/search/result/search_result_widget.dart @@ -13,12 +13,14 @@ class SearchResultWidget extends StatelessWidget { final SearchResult searchResult; final Future? resultCount; final Function? onResultTap; + final Map? params; const SearchResultWidget( this.searchResult, { Key? 
key, this.resultCount, this.onResultTap, + this.params, }) : super(key: key); @override @@ -42,6 +44,7 @@ class SearchResultWidget extends StatelessWidget { SearchThumbnailWidget( searchResult.previewThumbnail(), heroTagPrefix, + searchResult: searchResult, ), const SizedBox(width: 12), Padding( @@ -143,6 +146,8 @@ class SearchResultWidget extends StatelessWidget { return "Magic"; case ResultType.shared: return "Shared"; + case ResultType.faces: + return "Person"; default: return type.name.toUpperCase(); } diff --git a/mobile/lib/ui/viewer/search/result/search_section_all_page.dart b/mobile/lib/ui/viewer/search/result/search_section_all_page.dart index 59761009a..17dea1f84 100644 --- a/mobile/lib/ui/viewer/search/result/search_section_all_page.dart +++ b/mobile/lib/ui/viewer/search/result/search_section_all_page.dart @@ -1,5 +1,6 @@ import "dart:async"; +import "package:collection/collection.dart"; import "package:flutter/material.dart"; import "package:flutter_animate/flutter_animate.dart"; import "package:photos/events/event.dart"; @@ -109,7 +110,12 @@ class _SearchSectionAllPageState extends State { builder: (context, snapshot) { if (snapshot.hasData) { List sectionResults = snapshot.data!; - sectionResults.sort((a, b) => a.name().compareTo(b.name())); + if (widget.sectionType.sortByName) { + sectionResults.sort( + (a, b) => + compareAsciiLowerCaseNatural(b.name(), a.name()), + ); + } if (widget.sectionType == SectionType.location) { final result = sectionResults.splitMatch( (e) => e.type() == ResultType.location, diff --git a/mobile/lib/ui/viewer/search/result/search_thumbnail_widget.dart b/mobile/lib/ui/viewer/search/result/search_thumbnail_widget.dart index 13b303fec..514c65b99 100644 --- a/mobile/lib/ui/viewer/search/result/search_thumbnail_widget.dart +++ b/mobile/lib/ui/viewer/search/result/search_thumbnail_widget.dart @@ -1,15 +1,22 @@ import 'package:flutter/widgets.dart'; import 'package:photos/models/file/file.dart'; +import "package:photos/models/search/generic_search_result.dart"; +import "package:photos/models/search/search_constants.dart"; +import "package:photos/models/search/search_result.dart"; +import "package:photos/models/search/search_types.dart"; import 'package:photos/ui/viewer/file/no_thumbnail_widget.dart'; import 'package:photos/ui/viewer/file/thumbnail_widget.dart'; +import 'package:photos/ui/viewer/search/result/person_face_widget.dart'; class SearchThumbnailWidget extends StatelessWidget { final EnteFile? file; + final SearchResult? searchResult; final String tagPrefix; const SearchThumbnailWidget( this.file, this.tagPrefix, { + this.searchResult, Key? key, }) : super(key: key); @@ -23,9 +30,18 @@ class SearchThumbnailWidget extends StatelessWidget { child: ClipRRect( borderRadius: const BorderRadius.horizontal(left: Radius.circular(4)), child: file != null - ? ThumbnailWidget( - file!, - ) + ? (searchResult != null && + searchResult!.type() == ResultType.faces) + ? 
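+                    // Face results render the person's face crop rather than
+                    // the generic file thumbnail.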
PersonFaceWidget( + file!, + personId: (searchResult as GenericSearchResult) + .params[kPersonParamID], + clusterID: (searchResult as GenericSearchResult) + .params[kClusterParamId], + ) + : ThumbnailWidget( + file!, + ) : const NoThumbnailWidget( addBorder: false, ), diff --git a/mobile/lib/ui/viewer/search/result/searchable_item.dart b/mobile/lib/ui/viewer/search/result/searchable_item.dart index 1124d925e..f8e2ed1ac 100644 --- a/mobile/lib/ui/viewer/search/result/searchable_item.dart +++ b/mobile/lib/ui/viewer/search/result/searchable_item.dart @@ -30,6 +30,8 @@ class SearchableItemWidget extends StatelessWidget { final heroTagPrefix = additionalPrefix + searchResult.heroTag(); final textTheme = getEnteTextTheme(context); final colorScheme = getEnteColorScheme(context); + final bool isCluster = (searchResult.type() == ResultType.faces && + int.tryParse(searchResult.name()) != null); return GestureDetector( onTap: () { @@ -66,6 +68,7 @@ class SearchableItemWidget extends StatelessWidget { child: SearchThumbnailWidget( searchResult.previewThumbnail(), heroTagPrefix, + searchResult: searchResult, ), ), const SizedBox(width: 12), @@ -75,14 +78,16 @@ class SearchableItemWidget extends StatelessWidget { child: Column( crossAxisAlignment: CrossAxisAlignment.start, children: [ - Text( - searchResult.name(), - style: searchResult.type() == - ResultType.locationSuggestion - ? textTheme.bodyFaint - : textTheme.body, - overflow: TextOverflow.ellipsis, - ), + isCluster + ? const SizedBox.shrink() + : Text( + searchResult.name(), + style: searchResult.type() == + ResultType.locationSuggestion + ? textTheme.bodyFaint + : textTheme.body, + overflow: TextOverflow.ellipsis, + ), const SizedBox( height: 2, ), diff --git a/mobile/lib/ui/viewer/search/search_widget.dart b/mobile/lib/ui/viewer/search/search_widget.dart index 1c6c7b693..c917d60e9 100644 --- a/mobile/lib/ui/viewer/search/search_widget.dart +++ b/mobile/lib/ui/viewer/search/search_widget.dart @@ -2,10 +2,12 @@ import "dart:async"; import "package:flutter/material.dart"; import "package:flutter/scheduler.dart"; +import "package:logging/logging.dart"; import "package:photos/core/event_bus.dart"; import "package:photos/events/clear_and_unfocus_search_bar_event.dart"; import "package:photos/events/tab_changed_event.dart"; import "package:photos/generated/l10n.dart"; +import "package:photos/models/search/generic_search_result.dart"; import "package:photos/models/search/index_of_indexed_stack.dart"; import "package:photos/models/search/search_result.dart"; import "package:photos/services/search_service.dart"; @@ -41,6 +43,7 @@ class SearchWidgetState extends State { TextEditingController textController = TextEditingController(); late final StreamSubscription _clearAndUnfocusSearchBar; + late final Logger _logger = Logger("SearchWidgetState"); @override void initState() { @@ -200,7 +203,7 @@ class SearchWidgetState extends State { String query, ) { int resultCount = 0; - final maxResultCount = _isYearValid(query) ? 11 : 10; + final maxResultCount = _isYearValid(query) ? 13 : 12; final streamController = StreamController>(); if (query.isEmpty) { @@ -215,6 +218,11 @@ class SearchWidgetState extends State { if (resultCount == maxResultCount) { streamController.close(); } + if (resultCount > maxResultCount) { + _logger.warning( + "More results than expected. 
Expected: $maxResultCount, actual: $resultCount", + ); + } } if (_isYearValid(query)) { @@ -252,6 +260,17 @@ class SearchWidgetState extends State { onResultsReceived(locationResult); }, ); + _searchService.getAllFace(null).then( + (locationResult) { + final List filteredResults = []; + for (final result in locationResult) { + if (result.name().toLowerCase().contains(query.toLowerCase())) { + filteredResults.add(result); + } + } + onResultsReceived(filteredResults); + }, + ); _searchService.getCollectionSearchResults(query).then( (collectionResults) { diff --git a/mobile/lib/ui/viewer/search_tab/people_section.dart b/mobile/lib/ui/viewer/search_tab/people_section.dart new file mode 100644 index 000000000..13e2f8a81 --- /dev/null +++ b/mobile/lib/ui/viewer/search_tab/people_section.dart @@ -0,0 +1,329 @@ +import "dart:async"; + +import "package:collection/collection.dart"; +import "package:flutter/material.dart"; +import "package:photos/core/constants.dart"; +import "package:photos/events/event.dart"; +import "package:photos/face/model/person.dart"; +import "package:photos/models/file/file.dart"; +import "package:photos/models/search/album_search_result.dart"; +import "package:photos/models/search/generic_search_result.dart"; +import "package:photos/models/search/recent_searches.dart"; +import "package:photos/models/search/search_constants.dart"; +import "package:photos/models/search/search_result.dart"; +import "package:photos/models/search/search_types.dart"; +import "package:photos/theme/ente_theme.dart"; +import "package:photos/ui/settings/machine_learning_settings_page.dart"; +import "package:photos/ui/viewer/file/no_thumbnail_widget.dart"; +import "package:photos/ui/viewer/file/thumbnail_widget.dart"; +import "package:photos/ui/viewer/people/add_person_action_sheet.dart"; +import "package:photos/ui/viewer/people/people_page.dart"; +import 'package:photos/ui/viewer/search/result/person_face_widget.dart'; +import "package:photos/ui/viewer/search/result/search_result_page.dart"; +import 'package:photos/ui/viewer/search/result/search_section_all_page.dart'; +import "package:photos/ui/viewer/search/search_section_cta.dart"; +import "package:photos/utils/navigation_util.dart"; + +class PeopleSection extends StatefulWidget { + final SectionType sectionType = SectionType.face; + final List examples; + final int limit; + + const PeopleSection({ + Key? 
key,
+    required this.examples,
+    this.limit = 7,
+  }) : super(key: key);
+
+  @override
+  State<PeopleSection> createState() => _PeopleSectionState();
+}
+
+class _PeopleSectionState extends State<PeopleSection> {
+  late List<SearchResult> _examples;
+  final streamSubscriptions = <StreamSubscription>[];
+
+  @override
+  void initState() {
+    super.initState();
+    _examples = widget.examples;
+
+    final streamsToListenTo = widget.sectionType.sectionUpdateEvents();
+    for (Stream stream in streamsToListenTo) {
+      streamSubscriptions.add(
+        stream.listen((event) async {
+          _examples = await widget.sectionType.getData(
+            context,
+            limit: kSearchSectionLimit,
+          );
+          setState(() {});
+        }),
+      );
+    }
+  }
+
+  @override
+  void dispose() {
+    for (final subscription in streamSubscriptions) {
+      subscription.cancel();
+    }
+    super.dispose();
+  }
+
+  @override
+  void didUpdateWidget(covariant PeopleSection oldWidget) {
+    super.didUpdateWidget(oldWidget);
+    _examples = widget.examples;
+  }
+
+  @override
+  Widget build(BuildContext context) {
+    debugPrint("Building section for ${widget.sectionType.name}");
+    final shouldShowMore = _examples.length >= widget.limit - 1;
+    final textTheme = getEnteTextTheme(context);
+    return _examples.isNotEmpty
+        ? GestureDetector(
+            behavior: HitTestBehavior.opaque,
+            onTap: () {
+              if (shouldShowMore) {
+                routeToPage(
+                  context,
+                  SearchSectionAllPage(
+                    sectionType: widget.sectionType,
+                  ),
+                );
+              }
+            },
+            child: Column(
+              crossAxisAlignment: CrossAxisAlignment.start,
+              children: [
+                Row(
+                  mainAxisAlignment: MainAxisAlignment.spaceBetween,
+                  children: [
+                    Padding(
+                      padding: const EdgeInsets.all(12),
+                      child: Text(
+                        widget.sectionType.sectionTitle(context),
+                        style: textTheme.largeBold,
+                      ),
+                    ),
+                    shouldShowMore
+                        ? Padding(
+                            padding: const EdgeInsets.all(12),
+                            child: Icon(
+                              Icons.chevron_right_outlined,
+                              color: getEnteColorScheme(context).strokeMuted,
+                            ),
+                          )
+                        : const SizedBox.shrink(),
+                  ],
+                ),
+                const SizedBox(height: 2),
+                SearchExampleRow(_examples, widget.sectionType),
+              ],
+            ),
+          )
+        : GestureDetector(
+            behavior: HitTestBehavior.opaque,
+            onTap: () {
+              routeToPage(
+                context,
+                const MachineLearningSettingsPage(),
+              );
+            },
+            child: Padding(
+              padding: const EdgeInsets.only(left: 16, right: 8),
+              child: Row(
+                children: [
+                  Expanded(
+                    child: Padding(
+                      padding: const EdgeInsets.symmetric(vertical: 12),
+                      child: Column(
+                        crossAxisAlignment: CrossAxisAlignment.start,
+                        children: [
+                          Text(
+                            widget.sectionType.sectionTitle(context),
+                            style: textTheme.largeBold,
+                          ),
+                          const SizedBox(height: 24),
+                          Text(
+                            widget.sectionType.getEmptyStateText(context),
+                            style: textTheme.smallMuted,
+                          ),
+                        ],
+                      ),
+                    ),
+                  ),
+                  const SizedBox(width: 8),
+                  SearchSectionEmptyCTAIcon(widget.sectionType),
+                ],
+              ),
+            ),
+          );
+  }
+}
+
+class SearchExampleRow extends StatelessWidget {
+  final SectionType sectionType;
+  final List<SearchResult> examples;
+
+  const SearchExampleRow(this.examples, this.sectionType, {super.key});
+
+  @override
+  Widget build(BuildContext context) {
+    // Cannot use ListView.builder here
+    final scrollableExamples = <Widget>[];
+    examples.forEachIndexed((index, element) {
+      scrollableExamples.add(
+        SearchExample(
+          searchResult: element,
+        ),
+      );
+    });
+    return SizedBox(
+      child: SingleChildScrollView(
+        physics: const BouncingScrollPhysics(),
+        scrollDirection: Axis.horizontal,
+        child: Row(
+          crossAxisAlignment: CrossAxisAlignment.start,
+          children: scrollableExamples,
+        ),
+      ),
+    );
+  }
+}
+
+class SearchExample extends StatelessWidget {
+  final SearchResult searchResult;
+  const SearchExample({required this.searchResult, super.key});
+
+  @override
+  Widget build(BuildContext context) {
+    final textScaleFactor = MediaQuery.textScaleFactorOf(context);
+    final bool isCluster = (searchResult.type() == ResultType.faces &&
+        int.tryParse(searchResult.name()) != null);
+    late final double width;
+    if (textScaleFactor <= 1.0) {
+      width = 85.0;
+    } else {
+      width = 85.0 + ((textScaleFactor - 1.0) * 64);
+    }
+    final heroTag =
+        searchResult.heroTag() + (searchResult.previewThumbnail()?.tag ?? "");
+    return GestureDetector(
+      onTap: () {
+        RecentSearches().add(searchResult.name());
+
+        if (searchResult is GenericSearchResult) {
+          final genericSearchResult = searchResult as GenericSearchResult;
+          if (genericSearchResult.onResultTap != null) {
+            genericSearchResult.onResultTap!(context);
+          } else {
+            routeToPage(
+              context,
+              SearchResultPage(searchResult),
+            );
+          }
+        } else if (searchResult is AlbumSearchResult) {
+          final albumSearchResult = searchResult as AlbumSearchResult;
+          routeToPage(
+            context,
+            SearchResultPage(
+              albumSearchResult,
+              tagPrefix: albumSearchResult.heroTag(),
+            ),
+          );
+        }
+      },
+      child: SizedBox(
+        width: width,
+        child: Padding(
+          padding: const EdgeInsets.only(left: 6, right: 6, top: 8),
+          child: Column(
+            mainAxisSize: MainAxisSize.min,
+            children: [
+              SizedBox(
+                width: 64,
+                height: 64,
+                child: searchResult.previewThumbnail() != null
+                    ? Hero(
+                        tag: heroTag,
+                        child: ClipRRect(
+                          borderRadius:
+                              const BorderRadius.all(Radius.elliptical(16, 12)),
+                          child: searchResult.type() != ResultType.faces
+                              ? ThumbnailWidget(
+                                  searchResult.previewThumbnail()!,
+                                  shouldShowSyncStatus: false,
+                                )
+                              : FaceSearchResult(searchResult, heroTag),
+                        ),
+                      )
+                    : const ClipRRect(
+                        borderRadius:
+                            BorderRadius.all(Radius.elliptical(16, 12)),
+                        child: NoThumbnailWidget(
+                          addBorder: false,
+                        ),
+                      ),
+              ),
+              isCluster
+                  ?
GestureDetector( + behavior: HitTestBehavior.translucent, + onTap: () async { + final result = await showAssignPersonAction( + context, + clusterID: int.parse(searchResult.name()), + ); + if (result != null && + result is (PersonEntity, EnteFile)) { + // ignore: unawaited_futures + routeToPage(context, PeoplePage(person: result.$1)); + } else if (result != null && result is PersonEntity) { + // ignore: unawaited_futures + routeToPage(context, PeoplePage(person: result)); + } + }, + child: Padding( + padding: const EdgeInsets.only(top: 10, bottom: 16), + child: Text( + "Add name", + maxLines: 1, + textAlign: TextAlign.center, + overflow: TextOverflow.ellipsis, + style: getEnteTextTheme(context).mini, + ), + ), + ) + : Padding( + padding: const EdgeInsets.only(top: 10, bottom: 16), + child: Text( + searchResult.name(), + maxLines: 2, + textAlign: TextAlign.center, + overflow: TextOverflow.ellipsis, + style: getEnteTextTheme(context).mini, + ), + ), + ], + ), + ), + ), + ); + } +} + +class FaceSearchResult extends StatelessWidget { + final SearchResult searchResult; + final String heroTagPrefix; + const FaceSearchResult(this.searchResult, this.heroTagPrefix, {super.key}); + + @override + Widget build(BuildContext context) { + return PersonFaceWidget( + searchResult.previewThumbnail()!, + personId: (searchResult as GenericSearchResult).params[kPersonParamID], + clusterID: (searchResult as GenericSearchResult).params[kClusterParamId], + ); + } +} diff --git a/mobile/lib/ui/viewer/search_tab/search_tab.dart b/mobile/lib/ui/viewer/search_tab/search_tab.dart index bfb35600a..46dcfda03 100644 --- a/mobile/lib/ui/viewer/search_tab/search_tab.dart +++ b/mobile/lib/ui/viewer/search_tab/search_tab.dart @@ -1,9 +1,12 @@ import "package:fade_indexed_stack/fade_indexed_stack.dart"; +import "package:flutter/foundation.dart"; import "package:flutter/material.dart"; import "package:flutter_animate/flutter_animate.dart"; +import "package:logging/logging.dart"; import "package:photos/models/search/album_search_result.dart"; import "package:photos/models/search/generic_search_result.dart"; import "package:photos/models/search/index_of_indexed_stack.dart"; +import "package:photos/models/search/search_result.dart"; import "package:photos/models/search/search_types.dart"; import "package:photos/states/all_sections_examples_state.dart"; import "package:photos/ui/common/loading_widget.dart"; @@ -16,6 +19,8 @@ import "package:photos/ui/viewer/search_tab/descriptions_section.dart"; import "package:photos/ui/viewer/search_tab/file_type_section.dart"; import "package:photos/ui/viewer/search_tab/locations_section.dart"; import "package:photos/ui/viewer/search_tab/moments_section.dart"; +import "package:photos/ui/viewer/search_tab/people_section.dart"; +import "package:photos/utils/local_settings.dart"; class SearchTab extends StatefulWidget { const SearchTab({Key? 
key}) : super(key: key);
@@ -73,17 +78,17 @@ class AllSearchSections extends StatefulWidget {
 }
 
 class _AllSearchSectionsState extends State<AllSearchSections> {
+  final Logger _logger = Logger('_AllSearchSectionsState');
   @override
   Widget build(BuildContext context) {
     final searchTypes = SectionType.values.toList(growable: true);
-    // remove face and content sectionType
-    searchTypes.remove(SectionType.face);
     searchTypes.remove(SectionType.content);
+
     return Padding(
       padding: const EdgeInsets.only(top: 8),
       child: Stack(
         children: [
-          FutureBuilder(
+          FutureBuilder<List<List<SearchResult>>>(
            future: InheritedAllSectionsExamples.of(context)
                .allSectionsExamplesFuture,
            builder: (context, snapshot) {
@@ -94,6 +99,14 @@
                   child: SearchTabEmptyState(),
                 );
               }
+              if (snapshot.data!.length != searchTypes.length) {
+                return Padding(
+                  padding: const EdgeInsets.only(bottom: 72),
+                  child: Text(
+                    'Sections length mismatch: ${snapshot.data!.length} != ${searchTypes.length}',
+                  ),
+                );
+              }
               return ListView.builder(
                 padding: const EdgeInsets.only(bottom: 180),
                 physics: const BouncingScrollPhysics(),
@@ -101,7 +114,16 @@
                 // ignore: body_might_complete_normally_nullable
                 itemBuilder: (context, index) {
                   switch (searchTypes[index]) {
+                    case SectionType.face:
+                      if (!LocalSettings.instance.isFaceIndexingEnabled) {
+                        return const SizedBox.shrink();
+                      }
+                      return PeopleSection(
+                        examples: snapshot.data!.elementAt(index)
+                            as List<SearchResult>,
+                      );
                     case SectionType.album:
                       return AlbumsSection(
                         snapshot.data!.elementAt(index)
                             as List<AlbumSearchResult>,
@@ -150,6 +172,17 @@
                 curve: Curves.easeOut,
               );
             } else if (snapshot.hasError) {
+              _logger.severe(
+                'Failed to load sections',
+                snapshot.error,
+                snapshot.stackTrace,
+              );
+              if (kDebugMode) {
+                return Padding(
+                  padding: const EdgeInsets.only(bottom: 72),
+                  child: Text('Error: ${snapshot.error}'),
+                );
+              }
              // Errors are handled upstream; the understanding is that this
              // else-if branch will never be reached.
               return const Padding(
diff --git a/mobile/lib/utils/debug_ml_export_data.dart b/mobile/lib/utils/debug_ml_export_data.dart
new file mode 100644
index 000000000..f7a5e9646
--- /dev/null
+++ b/mobile/lib/utils/debug_ml_export_data.dart
@@ -0,0 +1,40 @@
+import "dart:convert";
+import "dart:developer" show log;
+import "dart:io";
+
+import "package:path_provider/path_provider.dart";
+
+Future<void> encodeAndSaveData(
+  dynamic nestedData,
+  String fileName, [
+  String? service,
+]) async {
+  // Convert map keys to strings if nestedData is a map
+  final dataToEncode = nestedData is Map
+      ? nestedData.map((key, value) => MapEntry(key.toString(), value))
+      : nestedData;
+  // Step 1: Serialize the data to JSON
+  final String jsonData = jsonEncode(dataToEncode);
+
+  // Step 2 (optional, currently disabled): Encode the JSON string to Base64
+  // final String base64String = base64Encode(utf8.encode(jsonData));
+
+  // Step 3: Write the JSON string to a file
+  try {
+    final File file = await _writeStringToFile(jsonData, fileName);
+    // Success, handle the file, e.g., print the file path
+    log('[$service]: File saved at ${file.path}');
+  } catch (e) {
+    // If an error occurs, handle it.
+    log('[$service]: Error saving file: $e');
+  }
+}
+
+Future<File> _writeStringToFile(
+  String dataString,
+  String fileName,
+) async {
+  final directory = await getExternalStorageDirectory();
+  final file = File('${directory!.path}/$fileName.json');
+  return file.writeAsString(dataString);
+}
diff --git a/mobile/lib/utils/dialog_util.dart b/mobile/lib/utils/dialog_util.dart
index f6e9eb021..d57a6990a 100644
--- a/mobile/lib/utils/dialog_util.dart
+++ b/mobile/lib/utils/dialog_util.dart
@@ -109,7 +109,11 @@ String parseErrorForUI(
       errorInfo = "Reason: " + dioError.type.toString();
     }
   } else {
-    errorInfo = error.toString().split('Source stack')[0];
+    if (kDebugMode) {
+      errorInfo = error.toString();
+    } else {
+      errorInfo = error.toString().split('Source stack')[0];
+    }
   }
   if (errorInfo.isNotEmpty) {
     return "$genericError\n\n$errorInfo";
diff --git a/mobile/lib/utils/face/face_box_crop.dart b/mobile/lib/utils/face/face_box_crop.dart
new file mode 100644
index 000000000..281c0ef49
--- /dev/null
+++ b/mobile/lib/utils/face/face_box_crop.dart
@@ -0,0 +1,56 @@
+import "dart:io" show File;
+
+import "package:flutter/foundation.dart";
+import "package:photos/core/cache/lru_map.dart";
+import "package:photos/face/model/box.dart";
+import "package:photos/models/file/file.dart";
+import "package:photos/models/file/file_type.dart";
+// import "package:photos/utils/face/face_util.dart";
+import "package:photos/utils/file_util.dart";
+import "package:photos/utils/image_ml_isolate.dart";
+import "package:photos/utils/thumbnail_util.dart";
+import "package:pool/pool.dart";
+
+final LRUMap<String, Uint8List> faceCropCache = LRUMap(1000);
+final LRUMap<String, Uint8List> faceCropThumbnailCache = LRUMap(1000);
+final poolFullFileFaceGenerations =
+    Pool(20, timeout: const Duration(seconds: 15));
+final poolThumbnailFaceGenerations =
+    Pool(100, timeout: const Duration(seconds: 15));
+Future<Map<String, Uint8List>?> getFaceCrops(
+  EnteFile file,
+  Map<String, FaceBox> faceBoxMap, {
+  bool useFullFile = true,
+}) async {
+  late String? imagePath;
+  if (useFullFile && file.fileType != FileType.video) {
+    final File? ioFile = await getFile(file);
+    if (ioFile == null) {
+      return null;
+    }
+    imagePath = ioFile.path;
+  } else {
+    final thumbnail = await getThumbnailForUploadedFile(file);
+    if (thumbnail == null) {
+      return null;
+    }
+    imagePath = thumbnail.path;
+  }
+  final List<String> faceIds = [];
+  final List<FaceBox> faceBoxes = [];
+  for (final e in faceBoxMap.entries) {
+    faceIds.add(e.key);
+    faceBoxes.add(e.value);
+  }
+  final List<Uint8List> faceCrops =
+      await ImageMlIsolate.instance.generateFaceThumbnailsForImageUsingCanvas(
+    // await generateJpgFaceThumbnails(
+    imagePath,
+    faceBoxes,
+  );
+  final Map<String, Uint8List> result = {};
+  for (int i = 0; i < faceIds.length; i++) {
+    result[faceIds[i]] = faceCrops[i];
+  }
+  return result;
+}
diff --git a/mobile/lib/utils/face/face_util.dart b/mobile/lib/utils/face/face_util.dart
new file mode 100644
index 000000000..c49d57b40
--- /dev/null
+++ b/mobile/lib/utils/face/face_util.dart
@@ -0,0 +1,175 @@
+import "dart:math";
+import "dart:typed_data";
+
+import "package:computer/computer.dart";
+import "package:flutter_image_compress/flutter_image_compress.dart";
+import "package:image/image.dart" as img;
+import "package:logging/logging.dart";
+import "package:photos/face/model/box.dart";
+
+/// Bounding box of a face.
+///
+/// [xMin] and [yMin] are the coordinates of the top left corner of the box, and
+/// [width] and [height] are the width and height of the box.
+///
+/// One unit is equal to one pixel in the original image.
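+///
+/// For illustration (hypothetical values): a face whose top-left corner sits
+/// at pixel (40, 60) and which spans 100x120 pixels would be represented as
+/// `FaceBoxImage(xMin: 40, yMin: 60, width: 100, height: 120)`.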
+class FaceBoxImage {
+  final int xMin;
+  final int yMin;
+  final int width;
+  final int height;
+
+  FaceBoxImage({
+    required this.xMin,
+    required this.yMin,
+    required this.width,
+    required this.height,
+  });
+}
+
+final _logger = Logger("FaceUtil");
+final _computer = Computer.shared();
+const _faceImageBufferFactor = 0.2;
+
+/// Convert img.Image to ui.Image and use RawImage to display.
+Future<List<img.Image>> generateImgFaceThumbnails(
+  String imagePath,
+  List<FaceBox> faceBoxes,
+) async {
+  final faceThumbnails = <img.Image>[];
+
+  final image = await decodeToImgImage(imagePath);
+
+  for (FaceBox faceBox in faceBoxes) {
+    final croppedImage = cropFaceBoxFromImage(image, faceBox);
+    faceThumbnails.add(croppedImage);
+  }
+
+  return faceThumbnails;
+}
+
+Future<List<Uint8List>> generateJpgFaceThumbnails(
+  String imagePath,
+  List<FaceBox> faceBoxes,
+) async {
+  final image = await decodeToImgImage(imagePath);
+  final croppedImages = <img.Image>[];
+  for (FaceBox faceBox in faceBoxes) {
+    final croppedImage = cropFaceBoxFromImage(image, faceBox);
+    croppedImages.add(croppedImage);
+  }
+
+  return await _computer
+      .compute(_encodeImagesToJpg, param: {"images": croppedImages});
+}
+
+Future<img.Image> decodeToImgImage(String imagePath) async {
+  img.Image? image =
+      await _computer.compute(_decodeImageFile, param: {"filePath": imagePath});
+
+  if (image == null) {
+    _logger.info(
+      "Failed to decode image. Compressing to jpg and decoding",
+    );
+    final compressedJPGImage =
+        await FlutterImageCompress.compressWithFile(imagePath);
+    image = await _computer.compute(
+      _decodeJpg,
+      param: {"image": compressedJPGImage},
+    );
+
+    if (image == null) {
+      throw Exception("Failed to decode image");
+    } else {
+      return image;
+    }
+  } else {
+    return image;
+  }
+}
+
+/// Returns an Image from 'package:image/image.dart'
+img.Image cropFaceBoxFromImage(img.Image image, FaceBox faceBox) {
+  final squareFaceBox = _getSquareFaceBoxImage(image, faceBox);
+  final squareFaceBoxWithBuffer =
+      _addBufferAroundFaceBox(squareFaceBox, _faceImageBufferFactor);
+  return img.copyCrop(
+    image,
+    x: squareFaceBoxWithBuffer.xMin,
+    y: squareFaceBoxWithBuffer.yMin,
+    width: squareFaceBoxWithBuffer.width,
+    height: squareFaceBoxWithBuffer.height,
+    antialias: false,
+  );
+}
+
+/// Returns a square face box image from the original image with
+/// side length equal to the maximum of the width and height of the face box
+/// in the original image.
+FaceBoxImage _getSquareFaceBoxImage(img.Image image, FaceBox faceBox) {
+  final width = (image.width * faceBox.width).round();
+  final height = (image.height * faceBox.height).round();
+  final side = max(width, height);
+  final xImage = (image.width * faceBox.xMin).round();
+  final yImage = (image.height * faceBox.yMin).round();
+
+  if (height >= width) {
+    final xImageAdj = (xImage - (height - width) / 2).round();
+    return FaceBoxImage(
+      xMin: xImageAdj,
+      yMin: yImage,
+      width: side,
+      height: side,
+    );
+  } else {
+    final yImageAdj = (yImage - (width - height) / 2).round();
+    return FaceBoxImage(
+      xMin: xImage,
+      yMin: yImageAdj,
+      width: side,
+      height: side,
+    );
+  }
+}
+
+/// To add some buffer around the face box so that the face isn't cropped
+/// too close to the face.
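+///
+/// Worked example with the current [_faceImageBufferFactor] of 0.2: a 100x100
+/// box at (50, 50) gains a 20 px margin on each side and becomes a 140x140
+/// box at (30, 30). A 100x100 box at (10, 10) would end up at (-10, -10),
+/// out of bounds, so it is returned unchanged.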
+FaceBoxImage _addBufferAroundFaceBox(
+  FaceBoxImage faceBoxImage,
+  double bufferFactor,
+) {
+  final heightBuffer = faceBoxImage.height * bufferFactor;
+  final widthBuffer = faceBoxImage.width * bufferFactor;
+  final xMinWithBuffer = faceBoxImage.xMin - widthBuffer;
+  final yMinWithBuffer = faceBoxImage.yMin - heightBuffer;
+  final widthWithBuffer = faceBoxImage.width + 2 * widthBuffer;
+  final heightWithBuffer = faceBoxImage.height + 2 * heightBuffer;
+  // Do not add the buffer if the top left edge of the box would fall out of
+  // bounds after adding it.
+  if (xMinWithBuffer < 0 || yMinWithBuffer < 0) {
+    return faceBoxImage;
+  }
+  // A similar case that could be handled is when the bottom right edge of the
+  // box is out of bounds after adding the buffer. But the visual difference
+  // is not as significant as when the top left edge is out of bounds, so we
+  // are not handling that case.
+  return FaceBoxImage(
+    xMin: xMinWithBuffer.round(),
+    yMin: yMinWithBuffer.round(),
+    width: widthWithBuffer.round(),
+    height: heightWithBuffer.round(),
+  );
+}
+
+List<Uint8List> _encodeImagesToJpg(Map args) {
+  final images = args["images"] as List<img.Image>;
+  return images.map((img.Image image) => img.encodeJpg(image)).toList();
+}
+
+Future<img.Image?> _decodeImageFile(Map args) async {
+  return await img.decodeImageFile(args["filePath"]);
+}
+
+img.Image? _decodeJpg(Map args) {
+  return img.decodeJpg(args["image"]);
+}
diff --git a/mobile/lib/utils/file_download_util.dart b/mobile/lib/utils/file_download_util.dart
index a8847e3fd..6db6ecbe0 100644
--- a/mobile/lib/utils/file_download_util.dart
+++ b/mobile/lib/utils/file_download_util.dart
@@ -47,9 +47,9 @@ Future downloadAndDecrypt(
         ),
         onReceiveProgress: (a, b) {
           if (kDebugMode && a >= 0 && b >= 0) {
-            _logger.fine(
-              "$logPrefix download progress: ${formatBytes(a)} / ${formatBytes(b)}",
-            );
+            // _logger.fine(
+            //   "$logPrefix download progress: ${formatBytes(a)} / ${formatBytes(b)}",
+            // );
           }
           progressCallback?.call(a, b);
         },
@@ -89,7 +89,8 @@
       getFileKey(file),
     );
     fakeProgress?.stop();
-    _logger.info('$logPrefix decryption completed');
+    _logger
+        .info('$logPrefix decryption completed (genID ${file.generatedID})');
   } catch (e, s) {
     fakeProgress?.stop();
     _logger.severe("Critical: $logPrefix failed to decrypt", e, s);
diff --git a/mobile/lib/utils/file_uploader.dart b/mobile/lib/utils/file_uploader.dart
index ad1015303..9b1b37fb4 100644
--- a/mobile/lib/utils/file_uploader.dart
+++ b/mobile/lib/utils/file_uploader.dart
@@ -5,7 +5,6 @@ import 'dart:io';
 import 'dart:math' as math;
 
 import 'package:collection/collection.dart';
-import 'package:connectivity_plus/connectivity_plus.dart';
 import 'package:dio/dio.dart';
 import 'package:flutter/foundation.dart';
 import 'package:logging/logging.dart';
@@ -39,6 +38,7 @@ import 'package:photos/utils/crypto_util.dart';
 import 'package:photos/utils/file_download_util.dart';
 import 'package:photos/utils/file_uploader_util.dart';
 import "package:photos/utils/file_util.dart";
+import "package:photos/utils/network_util.dart";
 import 'package:shared_preferences/shared_preferences.dart';
 import 'package:tuple/tuple.dart';
 import "package:uuid/uuid.dart";
@@ -382,18 +382,7 @@
     if (isForceUpload) {
       return;
     }
-    final List<ConnectivityResult> connections =
-        await (Connectivity().checkConnectivity());
-    bool canUploadUnderCurrentNetworkConditions = true;
-    if (!Configuration.instance.shouldBackupOverMobileData()) {
-      if (connections.any((element) => element == ConnectivityResult.mobile)) {
-
canUploadUnderCurrentNetworkConditions = false; - } else { - _logger.info( - "mobileBackupDisabled, backing up with connections: ${connections.map((e) => e.name).toString()}", - ); - } - } + final canUploadUnderCurrentNetworkConditions = await canUseHighBandwidth(); if (!canUploadUnderCurrentNetworkConditions) { throw WiFiUnavailableError(); diff --git a/mobile/lib/utils/file_util.dart b/mobile/lib/utils/file_util.dart index 5c9dcede1..b845d2ff6 100644 --- a/mobile/lib/utils/file_util.dart +++ b/mobile/lib/utils/file_util.dart @@ -278,7 +278,9 @@ Future<_LivePhoto?> _downloadLivePhoto( if (imageFileCache != null && videoFileCache != null) { return _LivePhoto(imageFileCache, videoFileCache); } else { - debugPrint("Warning: Either image or video is missing from remoteLive"); + debugPrint( + "Warning: ${file.tag} either image ${imageFileCache == null} or video ${videoFileCache == null} is missing from remoteLive", + ); return null; } }).catchError((e) { diff --git a/mobile/lib/utils/image_ml_isolate.dart b/mobile/lib/utils/image_ml_isolate.dart new file mode 100644 index 000000000..66de0c255 --- /dev/null +++ b/mobile/lib/utils/image_ml_isolate.dart @@ -0,0 +1,562 @@ +import 'dart:async'; +import "dart:io" show File; +import 'dart:isolate'; +import 'dart:typed_data' show Float32List, Uint8List; +import 'dart:ui'; + +import "package:dart_ui_isolate/dart_ui_isolate.dart"; +import "package:logging/logging.dart"; +import "package:photos/face/model/box.dart"; +import "package:photos/face/model/dimension.dart"; +import 'package:photos/models/ml/ml_typedefs.dart'; +import 'package:photos/services/machine_learning/face_ml/face_alignment/alignment_result.dart'; +import 'package:photos/services/machine_learning/face_ml/face_detection/detection.dart'; +import "package:photos/utils/image_ml_util.dart"; +import "package:synchronized/synchronized.dart"; + +enum ImageOperation { + @Deprecated("No longer using BlazeFace`") + preprocessBlazeFace, + preprocessYoloOnnx, + preprocessFaceAlign, + preprocessMobileFaceNet, + preprocessMobileFaceNetOnnx, + generateFaceThumbnails, + generateFaceThumbnailsUsingCanvas, + cropAndPadFace, +} + +/// The isolate below uses functions from ["package:photos/utils/image_ml_util.dart"] to preprocess images for ML models. + +/// This class is responsible for all image operations needed for ML models. It runs in a separate isolate to avoid jank. +/// +/// It can be accessed through the singleton `ImageConversionIsolate.instance`. e.g. `ImageConversionIsolate.instance.convert(imageData)` +/// +/// IMPORTANT: Make sure to dispose of the isolate when you're done with it with `dispose()`, e.g. `ImageConversionIsolate.instance.dispose();` +class ImageMlIsolate { + // static const String debugName = 'ImageMlIsolate'; + + final _logger = Logger('ImageMlIsolate'); + + Timer? _inactivityTimer; + final Duration _inactivityDuration = const Duration(seconds: 60); + int _activeTasks = 0; + + final _initLock = Lock(); + final _functionLock = Lock(); + + late DartUiIsolate _isolate; + late ReceivePort _receivePort = ReceivePort(); + late SendPort _mainSendPort; + + bool isSpawned = false; + + // singleton pattern + ImageMlIsolate._privateConstructor(); + + /// Use this instance to access the ImageConversionIsolate service. Make sure to call `init()` before using it. + /// e.g. `await ImageConversionIsolate.instance.init();` + /// And kill the isolate when you're done with it with `dispose()`, e.g. 
`ImageConversionIsolate.instance.dispose();` + /// + /// Then you can use `convert()` to get the image, so `ImageConversionIsolate.instance.convert(imageData, imagePath: imagePath)` + static final ImageMlIsolate instance = ImageMlIsolate._privateConstructor(); + factory ImageMlIsolate() => instance; + + Future init() async { + return _initLock.synchronized(() async { + if (isSpawned) return; + + _receivePort = ReceivePort(); + + try { + _isolate = await DartUiIsolate.spawn( + _isolateMain, + _receivePort.sendPort, + ); + _mainSendPort = await _receivePort.first as SendPort; + isSpawned = true; + + _resetInactivityTimer(); + } catch (e) { + _logger.severe('Could not spawn isolate', e); + isSpawned = false; + } + }); + } + + Future ensureSpawned() async { + if (!isSpawned) { + await init(); + } + } + + @pragma('vm:entry-point') + static void _isolateMain(SendPort mainSendPort) async { + final receivePort = ReceivePort(); + mainSendPort.send(receivePort.sendPort); + + receivePort.listen((message) async { + final functionIndex = message[0] as int; + final function = ImageOperation.values[functionIndex]; + final args = message[1] as Map; + final sendPort = message[2] as SendPort; + + try { + switch (function) { + case ImageOperation.preprocessBlazeFace: + final imageData = args['imageData'] as Uint8List; + final normalize = args['normalize'] as bool; + final int normalization = normalize ? 2 : -1; + final requiredWidth = args['requiredWidth'] as int; + final requiredHeight = args['requiredHeight'] as int; + final qualityIndex = args['quality'] as int; + final maintainAspectRatio = args['maintainAspectRatio'] as bool; + final quality = FilterQuality.values[qualityIndex]; + final (result, originalSize, newSize) = + await preprocessImageToMatrix( + imageData, + normalization: normalization, + requiredWidth: requiredWidth, + requiredHeight: requiredHeight, + quality: quality, + maintainAspectRatio: maintainAspectRatio, + ); + sendPort.send({ + 'inputs': result, + 'originalWidth': originalSize.width, + 'originalHeight': originalSize.height, + 'newWidth': newSize.width, + 'newHeight': newSize.height, + }); + case ImageOperation.preprocessYoloOnnx: + final imageData = args['imageData'] as Uint8List; + final normalize = args['normalize'] as bool; + final int normalization = normalize ? 
1 : -1; + final requiredWidth = args['requiredWidth'] as int; + final requiredHeight = args['requiredHeight'] as int; + final maintainAspectRatio = args['maintainAspectRatio'] as bool; + final Image image = await decodeImageFromData(imageData); + final imageByteData = await getByteDataFromImage(image); + final (result, originalSize, newSize) = + await preprocessImageToFloat32ChannelsFirst( + image, + imageByteData, + normalization: normalization, + requiredWidth: requiredWidth, + requiredHeight: requiredHeight, + maintainAspectRatio: maintainAspectRatio, + ); + sendPort.send({ + 'inputs': result, + 'originalWidth': originalSize.width, + 'originalHeight': originalSize.height, + 'newWidth': newSize.width, + 'newHeight': newSize.height, + }); + case ImageOperation.preprocessFaceAlign: + final imageData = args['imageData'] as Uint8List; + final faceLandmarks = + args['faceLandmarks'] as List>>; + final List result = await preprocessFaceAlignToUint8List( + imageData, + faceLandmarks, + ); + sendPort.send(List.from(result)); + case ImageOperation.preprocessMobileFaceNet: + final imageData = args['imageData'] as Uint8List; + final facesJson = args['facesJson'] as List>; + final ( + inputs, + alignmentResults, + isBlurs, + blurValues, + originalSize + ) = await preprocessToMobileFaceNetInput( + imageData, + facesJson, + ); + final List> alignmentResultsJson = + alignmentResults.map((result) => result.toJson()).toList(); + sendPort.send({ + 'inputs': inputs, + 'alignmentResultsJson': alignmentResultsJson, + 'isBlurs': isBlurs, + 'blurValues': blurValues, + 'originalWidth': originalSize.width, + 'originalHeight': originalSize.height, + }); + case ImageOperation.preprocessMobileFaceNetOnnx: + final imagePath = args['imagePath'] as String; + final facesJson = args['facesJson'] as List>; + final List relativeFaces = facesJson + .map((face) => FaceDetectionRelative.fromJson(face)) + .toList(); + final imageData = await File(imagePath).readAsBytes(); + final Image image = await decodeImageFromData(imageData); + final imageByteData = await getByteDataFromImage(image); + final ( + inputs, + alignmentResults, + isBlurs, + blurValues, + originalSize + ) = await preprocessToMobileFaceNetFloat32List( + image, + imageByteData, + relativeFaces, + ); + final List> alignmentResultsJson = + alignmentResults.map((result) => result.toJson()).toList(); + sendPort.send({ + 'inputs': inputs, + 'alignmentResultsJson': alignmentResultsJson, + 'isBlurs': isBlurs, + 'blurValues': blurValues, + 'originalWidth': originalSize.width, + 'originalHeight': originalSize.height, + }); + case ImageOperation.generateFaceThumbnails: + final imagePath = args['imagePath'] as String; + final Uint8List imageData = await File(imagePath).readAsBytes(); + final faceBoxesJson = + args['faceBoxesList'] as List>; + final List faceBoxes = + faceBoxesJson.map((json) => FaceBox.fromJson(json)).toList(); + final List results = await generateFaceThumbnails( + imageData, + faceBoxes, + ); + sendPort.send(List.from(results)); + case ImageOperation.generateFaceThumbnailsUsingCanvas: + final imagePath = args['imagePath'] as String; + final Uint8List imageData = await File(imagePath).readAsBytes(); + final faceBoxesJson = + args['faceBoxesList'] as List>; + final List faceBoxes = + faceBoxesJson.map((json) => FaceBox.fromJson(json)).toList(); + final List results = + await generateFaceThumbnailsUsingCanvas( + imageData, + faceBoxes, + ); + sendPort.send(List.from(results)); + case ImageOperation.cropAndPadFace: + final imageData = args['imageData'] as 
Uint8List; + final faceBox = args['faceBox'] as List; + final Uint8List result = + await cropAndPadFaceData(imageData, faceBox); + sendPort.send([result]); + } + } catch (e, stackTrace) { + sendPort + .send({'error': e.toString(), 'stackTrace': stackTrace.toString()}); + } + }); + } + + /// The common method to run any operation in the isolate. It sends the [message] to [_isolateMain] and waits for the result. + Future _runInIsolate( + (ImageOperation, Map) message, + ) async { + await ensureSpawned(); + return _functionLock.synchronized(() async { + _resetInactivityTimer(); + final completer = Completer(); + final answerPort = ReceivePort(); + + _activeTasks++; + _mainSendPort.send([message.$1.index, message.$2, answerPort.sendPort]); + + answerPort.listen((receivedMessage) { + if (receivedMessage is Map && receivedMessage.containsKey('error')) { + // Handle the error + final errorMessage = receivedMessage['error']; + final errorStackTrace = receivedMessage['stackTrace']; + final exception = Exception(errorMessage); + final stackTrace = StackTrace.fromString(errorStackTrace); + completer.completeError(exception, stackTrace); + } else { + completer.complete(receivedMessage); + } + }); + _activeTasks--; + + return completer.future; + }); + } + + /// Resets a timer that kills the isolate after a certain amount of inactivity. + /// + /// Should be called after initialization (e.g. inside `init()`) and after every call to isolate (e.g. inside `_runInIsolate()`) + void _resetInactivityTimer() { + _inactivityTimer?.cancel(); + _inactivityTimer = Timer(_inactivityDuration, () { + if (_activeTasks > 0) { + _logger.info('Tasks are still running. Delaying isolate disposal.'); + // Optionally, reschedule the timer to check again later. + _resetInactivityTimer(); + } else { + _logger.info( + 'Clustering Isolate has been inactive for ${_inactivityDuration.inSeconds} seconds with no tasks running. Killing isolate.', + ); + dispose(); + } + }); + } + + /// Disposes the isolate worker. + void dispose() { + if (!isSpawned) return; + + isSpawned = false; + _isolate.kill(); + _receivePort.close(); + _inactivityTimer?.cancel(); + } + + /// Preprocesses [imageData] for standard ML models inside a separate isolate. + /// + /// Returns a [Num3DInputMatrix] image usable for ML inference with BlazeFace. + /// + /// Uses [preprocessImageToMatrix] inside the isolate. + @Deprecated("No longer using BlazeFace") + Future<(Num3DInputMatrix, Size, Size)> preprocessImageBlazeFace( + Uint8List imageData, { + required bool normalize, + required int requiredWidth, + required int requiredHeight, + FilterQuality quality = FilterQuality.medium, + bool maintainAspectRatio = true, + }) async { + final Map results = await _runInIsolate( + ( + ImageOperation.preprocessBlazeFace, + { + 'imageData': imageData, + 'normalize': normalize, + 'requiredWidth': requiredWidth, + 'requiredHeight': requiredHeight, + 'quality': quality.index, + 'maintainAspectRatio': maintainAspectRatio, + }, + ), + ); + final inputs = results['inputs'] as Num3DInputMatrix; + final originalSize = Size( + results['originalWidth'] as double, + results['originalHeight'] as double, + ); + final newSize = Size( + results['newWidth'] as double, + results['newHeight'] as double, + ); + return (inputs, originalSize, newSize); + } + + /// Uses [preprocessImageToFloat32ChannelsFirst] inside the isolate. 
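+///
+/// Illustrative call (the 640x640 input size here is an assumption made for
+/// this example, not necessarily the model's real input size):
+///
+///   final (input, originalSize, newSize) =
+///       await ImageMlIsolate.instance.preprocessImageYoloOnnx(
+///     imageData,
+///     normalize: true,
+///     requiredWidth: 640,
+///     requiredHeight: 640,
+///   );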
+ @Deprecated( + "Old method, not needed since we now run the whole ML pipeline for faces in a single isolate", + ) + Future<(Float32List, Dimensions, Dimensions)> preprocessImageYoloOnnx( + Uint8List imageData, { + required bool normalize, + required int requiredWidth, + required int requiredHeight, + FilterQuality quality = FilterQuality.medium, + bool maintainAspectRatio = true, + }) async { + final Map results = await _runInIsolate( + ( + ImageOperation.preprocessYoloOnnx, + { + 'imageData': imageData, + 'normalize': normalize, + 'requiredWidth': requiredWidth, + 'requiredHeight': requiredHeight, + 'quality': quality.index, + 'maintainAspectRatio': maintainAspectRatio, + }, + ), + ); + final inputs = results['inputs'] as Float32List; + final originalSize = Dimensions( + width: results['originalWidth'] as int, + height: results['originalHeight'] as int, + ); + final newSize = Dimensions( + width: results['newWidth'] as int, + height: results['newHeight'] as int, + ); + return (inputs, originalSize, newSize); + } + + /// Preprocesses [imageData] for face alignment inside a separate isolate, to display the aligned faces. Mostly used for debugging. + /// + /// Returns a list of [Uint8List] images, one for each face, in png format. + /// + /// Uses [preprocessFaceAlignToUint8List] inside the isolate. + /// + /// WARNING: For preprocessing for MobileFaceNet, use [preprocessMobileFaceNet] instead! + @Deprecated( + "Old method, not needed since we now run the whole ML pipeline for faces in a single isolate", + ) + Future> preprocessFaceAlign( + Uint8List imageData, + List faces, + ) async { + final faceLandmarks = faces.map((face) => face.allKeypoints).toList(); + return await _runInIsolate( + ( + ImageOperation.preprocessFaceAlign, + { + 'imageData': imageData, + 'faceLandmarks': faceLandmarks, + }, + ), + ).then((value) => value.cast()); + } + + /// Preprocesses [imageData] for MobileFaceNet input inside a separate isolate. + /// + /// Returns a list of [Num3DInputMatrix] images, one for each face. + /// + /// Uses [preprocessToMobileFaceNetInput] inside the isolate. + @Deprecated("Old method used in TensorFlow Lite") + Future< + ( + List, + List, + List, + List, + Size, + )> preprocessMobileFaceNet( + Uint8List imageData, + List faces, + ) async { + final List> facesJson = + faces.map((face) => face.toJson()).toList(); + final Map results = await _runInIsolate( + ( + ImageOperation.preprocessMobileFaceNet, + { + 'imageData': imageData, + 'facesJson': facesJson, + }, + ), + ); + final inputs = results['inputs'] as List; + final alignmentResultsJson = + results['alignmentResultsJson'] as List>; + final alignmentResults = alignmentResultsJson.map((json) { + return AlignmentResult.fromJson(json); + }).toList(); + final isBlurs = results['isBlurs'] as List; + final blurValues = results['blurValues'] as List; + final originalSize = Size( + results['originalWidth'] as double, + results['originalHeight'] as double, + ); + return (inputs, alignmentResults, isBlurs, blurValues, originalSize); + } + + /// Uses [preprocessToMobileFaceNetFloat32List] inside the isolate. 
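+///
+/// Sketch of a call (assumes `faces` holds the relative face detections
+/// produced by the detector for the image at `imagePath`):
+///
+///   final (input, alignments, isBlurs, blurValues, size) =
+///       await ImageMlIsolate.instance.preprocessMobileFaceNetOnnx(
+///     imagePath,
+///     faces,
+///   );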
+ @Deprecated( + "Old method, not needed since we now run the whole ML pipeline for faces in a single isolate", + ) + Future<(Float32List, List, List, List, Size)> + preprocessMobileFaceNetOnnx( + String imagePath, + List faces, + ) async { + final List> facesJson = + faces.map((face) => face.toJson()).toList(); + final Map results = await _runInIsolate( + ( + ImageOperation.preprocessMobileFaceNetOnnx, + { + 'imagePath': imagePath, + 'facesJson': facesJson, + }, + ), + ); + final inputs = results['inputs'] as Float32List; + final alignmentResultsJson = + results['alignmentResultsJson'] as List>; + final alignmentResults = alignmentResultsJson.map((json) { + return AlignmentResult.fromJson(json); + }).toList(); + final isBlurs = results['isBlurs'] as List; + final blurValues = results['blurValues'] as List; + final originalSize = Size( + results['originalWidth'] as double, + results['originalHeight'] as double, + ); + + return (inputs, alignmentResults, isBlurs, blurValues, originalSize); + } + + /// Generates face thumbnails for all [faceBoxes] in [imageData]. + /// + /// Uses [generateFaceThumbnails] inside the isolate. + Future> generateFaceThumbnailsForImage( + String imagePath, + List faceBoxes, + ) async { + final List> faceBoxesJson = + faceBoxes.map((box) => box.toJson()).toList(); + return await _runInIsolate( + ( + ImageOperation.generateFaceThumbnails, + { + 'imagePath': imagePath, + 'faceBoxesList': faceBoxesJson, + }, + ), + ).then((value) => value.cast()); + } + + /// Generates face thumbnails for all [faceBoxes] in [imageData]. + /// + /// Uses [generateFaceThumbnailsUsingCanvas] inside the isolate. + Future> generateFaceThumbnailsForImageUsingCanvas( + String imagePath, + List faceBoxes, + ) async { + final List> faceBoxesJson = + faceBoxes.map((box) => box.toJson()).toList(); + return await _runInIsolate( + ( + ImageOperation.generateFaceThumbnailsUsingCanvas, + { + 'imagePath': imagePath, + 'faceBoxesList': faceBoxesJson, + }, + ), + ).then((value) => value.cast()); + } + + @Deprecated('For second pass of BlazeFace, no longer used') + + /// Generates cropped and padded image data from [imageData] and a [faceBox]. + /// + /// The steps are: + /// 1. Crop the image to the face bounding box + /// 2. Resize this cropped image to a square that is half the BlazeFace input size + /// 3. Pad the image to the BlazeFace input size + /// + /// Uses [cropAndPadFaceData] inside the isolate. 
+ Future cropAndPadFace( + Uint8List imageData, + List faceBox, + ) async { + return await _runInIsolate( + ( + ImageOperation.cropAndPadFace, + { + 'imageData': imageData, + 'faceBox': List.from(faceBox), + }, + ), + ).then((value) => value[0] as Uint8List); + } +} diff --git a/mobile/lib/utils/image_ml_util.dart b/mobile/lib/utils/image_ml_util.dart new file mode 100644 index 000000000..8a6051793 --- /dev/null +++ b/mobile/lib/utils/image_ml_util.dart @@ -0,0 +1,1583 @@ +import "dart:async"; +import "dart:developer" show log; +import "dart:io" show File; +import "dart:math" show min, max; +import "dart:typed_data" show Float32List, Uint8List, ByteData; +import "dart:ui"; + +// import 'package:flutter/material.dart' +// show +// ImageProvider, +// ImageStream, +// ImageStreamListener, +// ImageInfo, +// MemoryImage, +// ImageConfiguration; +// import 'package:flutter/material.dart' as material show Image; +import 'package:flutter/painting.dart' as paint show decodeImageFromList; +import 'package:ml_linalg/linalg.dart'; +import "package:photos/face/model/box.dart"; +import "package:photos/face/model/dimension.dart"; +import 'package:photos/models/ml/ml_typedefs.dart'; +import 'package:photos/services/machine_learning/face_ml/face_alignment/alignment_result.dart'; +import 'package:photos/services/machine_learning/face_ml/face_alignment/similarity_transform.dart'; +import 'package:photos/services/machine_learning/face_ml/face_detection/detection.dart'; +import 'package:photos/services/machine_learning/face_ml/face_filtering/blur_detection_service.dart'; + +/// All of the functions in this file are helper functions for the [ImageMlIsolate] isolate. +/// Don't use them outside of the isolate, unless you are okay with UI jank!!!! + +/// Reads the pixel color at the specified coordinates. +Color readPixelColor( + Image image, + ByteData byteData, + int x, + int y, +) { + if (x < 0 || x >= image.width || y < 0 || y >= image.height) { + // throw ArgumentError('Invalid pixel coordinates.'); + if (y != -1) { + log('[WARNING] `readPixelColor`: Invalid pixel coordinates, out of bounds'); + } + return const Color.fromARGB(0, 0, 0, 0); + } + assert(byteData.lengthInBytes == 4 * image.width * image.height); + + final int byteOffset = 4 * (image.width * y + x); + return Color(_rgbaToArgb(byteData.getUint32(byteOffset))); +} + +void setPixelColor( + Size imageSize, + ByteData byteData, + int x, + int y, + Color color, +) { + if (x < 0 || x >= imageSize.width || y < 0 || y >= imageSize.height) { + log('[WARNING] `setPixelColor`: Invalid pixel coordinates, out of bounds'); + return; + } + assert(byteData.lengthInBytes == 4 * imageSize.width * imageSize.height); + + final int byteOffset = 4 * (imageSize.width.toInt() * y + x); + byteData.setUint32(byteOffset, _argbToRgba(color.value)); +} + +int _rgbaToArgb(int rgbaColor) { + final int a = rgbaColor & 0xFF; + final int rgb = rgbaColor >> 8; + return rgb + (a << 24); +} + +int _argbToRgba(int argbColor) { + final int r = (argbColor >> 16) & 0xFF; + final int g = (argbColor >> 8) & 0xFF; + final int b = argbColor & 0xFF; + final int a = (argbColor >> 24) & 0xFF; + return (r << 24) + (g << 16) + (b << 8) + a; +} + +@Deprecated('Used in TensorFlow Lite only, no longer needed') + +/// Creates an empty matrix with the specified shape. 
+/// +/// The `shape` argument must be a list of length 2 or 3, where the first +/// element represents the number of rows, the second element represents +/// the number of columns, and the optional third element represents the +/// number of channels. The function returns a matrix filled with zeros. +/// +/// Throws an [ArgumentError] if the `shape` argument is invalid. +List createEmptyOutputMatrix(List shape, [double fillValue = 0.0]) { + if (shape.length > 5) { + throw ArgumentError('Shape must have length 1-5'); + } + + if (shape.length == 1) { + return List.filled(shape[0], fillValue); + } else if (shape.length == 2) { + return List.generate(shape[0], (_) => List.filled(shape[1], fillValue)); + } else if (shape.length == 3) { + return List.generate( + shape[0], + (_) => List.generate(shape[1], (_) => List.filled(shape[2], fillValue)), + ); + } else if (shape.length == 4) { + return List.generate( + shape[0], + (_) => List.generate( + shape[1], + (_) => List.generate(shape[2], (_) => List.filled(shape[3], fillValue)), + ), + ); + } else if (shape.length == 5) { + return List.generate( + shape[0], + (_) => List.generate( + shape[1], + (_) => List.generate( + shape[2], + (_) => + List.generate(shape[3], (_) => List.filled(shape[4], fillValue)), + ), + ), + ); + } else { + throw ArgumentError('Shape must have length 2 or 3'); + } +} + +/// Creates an input matrix from the specified image, which can be used for inference +/// +/// Returns a matrix with the shape [image.height, image.width, 3], where the third dimension represents the RGB channels, as [Num3DInputMatrix]. +/// In fact, this is either a [Double3DInputMatrix] or a [Int3DInputMatrix] depending on the `normalize` argument. +/// If `normalize` is true, the pixel values are normalized doubles in range [-1, 1]. Otherwise, they are integers in range [0, 255]. +/// +/// The `image` argument must be an ui.[Image] object. The function returns a matrix +/// with the shape `[image.height, image.width, 3]`, where the third dimension +/// represents the RGB channels. 
+/// +/// bool `normalize`: Normalize the image to range [-1, 1] +Num3DInputMatrix createInputMatrixFromImage( + Image image, + ByteData byteDataRgba, { + double Function(num) normFunction = normalizePixelRange2, +}) { + return List.generate( + image.height, + (y) => List.generate( + image.width, + (x) { + final pixel = readPixelColor(image, byteDataRgba, x, y); + return [ + normFunction(pixel.red), + normFunction(pixel.green), + normFunction(pixel.blue), + ]; + }, + ), + ); +} + +void addInputImageToFloat32List( + Image image, + ByteData byteDataRgba, + Float32List float32List, + int startIndex, { + double Function(num) normFunction = normalizePixelRange2, +}) { + int pixelIndex = startIndex; + for (var h = 0; h < image.height; h++) { + for (var w = 0; w < image.width; w++) { + final pixel = readPixelColor(image, byteDataRgba, w, h); + float32List[pixelIndex] = normFunction(pixel.red); + float32List[pixelIndex + 1] = normFunction(pixel.green); + float32List[pixelIndex + 2] = normFunction(pixel.blue); + pixelIndex += 3; + } + } + return; +} + +List> createGrayscaleIntMatrixFromImage( + Image image, + ByteData byteDataRgba, +) { + return List.generate( + image.height, + (y) => List.generate( + image.width, + (x) { + // 0.299 ∙ Red + 0.587 ∙ Green + 0.114 ∙ Blue + final pixel = readPixelColor(image, byteDataRgba, x, y); + return (0.299 * pixel.red + 0.587 * pixel.green + 0.114 * pixel.blue) + .round() + .clamp(0, 255); + }, + ), + ); +} + +List> createGrayscaleIntMatrixFromNormalized2List( + Float32List imageList, + int startIndex, { + int width = 112, + int height = 112, +}) { + return List.generate( + height, + (y) => List.generate( + width, + (x) { + // 0.299 ∙ Red + 0.587 ∙ Green + 0.114 ∙ Blue + final pixelIndex = startIndex + 3 * (y * width + x); + return (0.299 * unnormalizePixelRange2(imageList[pixelIndex]) + + 0.587 * unnormalizePixelRange2(imageList[pixelIndex + 1]) + + 0.114 * unnormalizePixelRange2(imageList[pixelIndex + 2])) + .round() + .clamp(0, 255); + // return unnormalizePixelRange2( + // (0.299 * imageList[pixelIndex] + + // 0.587 * imageList[pixelIndex + 1] + + // 0.114 * imageList[pixelIndex + 2]), + // ).round().clamp(0, 255); + }, + ), + ); +} + +Float32List createFloat32ListFromImageChannelsFirst( + Image image, + ByteData byteDataRgba, { + double Function(num) normFunction = normalizePixelRange2, +}) { + final convertedBytes = Float32List(3 * image.height * image.width); + final buffer = Float32List.view(convertedBytes.buffer); + + int pixelIndex = 0; + final int channelOffsetGreen = image.height * image.width; + final int channelOffsetBlue = 2 * image.height * image.width; + for (var h = 0; h < image.height; h++) { + for (var w = 0; w < image.width; w++) { + final pixel = readPixelColor(image, byteDataRgba, w, h); + buffer[pixelIndex] = normFunction(pixel.red); + buffer[pixelIndex + channelOffsetGreen] = normFunction(pixel.green); + buffer[pixelIndex + channelOffsetBlue] = normFunction(pixel.blue); + pixelIndex++; + } + } + return convertedBytes.buffer.asFloat32List(); +} + +/// Creates an input matrix from the specified image, which can be used for inference +/// +/// Returns a matrix with the shape `[3, image.height, image.width]`, where the first dimension represents the RGB channels, as [Num3DInputMatrix]. +/// In fact, this is either a [Double3DInputMatrix] or a [Int3DInputMatrix] depending on the `normalize` argument. +/// If `normalize` is true, the pixel values are normalized doubles in range [-1, 1]. Otherwise, they are integers in range [0, 255]. 
+/// +/// The `image` argument must be an ui.[Image] object. The function returns a matrix +/// with the shape `[3, image.height, image.width]`, where the first dimension +/// represents the RGB channels. +/// +/// bool `normalize`: Normalize the image to range [-1, 1] +Num3DInputMatrix createInputMatrixFromImageChannelsFirst( + Image image, + ByteData byteDataRgba, { + bool normalize = true, +}) { + // Create an empty 3D list. + final Num3DInputMatrix imageMatrix = List.generate( + 3, + (i) => List.generate( + image.height, + (j) => List.filled(image.width, 0), + ), + ); + + // Determine which function to use to get the pixel value. + final pixelValue = normalize ? normalizePixelRange2 : (num value) => value; + + for (int y = 0; y < image.height; y++) { + for (int x = 0; x < image.width; x++) { + // Get the pixel at (x, y). + final pixel = readPixelColor(image, byteDataRgba, x, y); + + // Assign the color channels to the respective lists. + imageMatrix[0][y][x] = pixelValue(pixel.red); + imageMatrix[1][y][x] = pixelValue(pixel.green); + imageMatrix[2][y][x] = pixelValue(pixel.blue); + } + } + return imageMatrix; +} + +/// Function normalizes the pixel value to be in range [-1, 1]. +/// +/// It assumes that the pixel value is originally in range [0, 255] +double normalizePixelRange2(num pixelValue) { + return (pixelValue / 127.5) - 1; +} + +/// Function unnormalizes the pixel value to be in range [0, 255]. +/// +/// It assumes that the pixel value is originally in range [-1, 1] +int unnormalizePixelRange2(double pixelValue) { + return ((pixelValue + 1) * 127.5).round().clamp(0, 255); +} + +/// Function normalizes the pixel value to be in range [0, 1]. +/// +/// It assumes that the pixel value is originally in range [0, 255] +double normalizePixelRange1(num pixelValue) { + return (pixelValue / 255); +} + +double normalizePixelNoRange(num pixelValue) { + return pixelValue.toDouble(); +} + +/// Decodes [Uint8List] image data to an ui.[Image] object. +Future decodeImageFromData(Uint8List imageData) async { + // Decoding using flutter paint. This is the fastest and easiest method. + final Image image = await paint.decodeImageFromList(imageData); + return image; + + // // Similar decoding as above, but without using flutter paint. This is not faster than the above. + // final Codec codec = await instantiateImageCodecFromBuffer( + // await ImmutableBuffer.fromUint8List(imageData), + // ); + // final FrameInfo frameInfo = await codec.getNextFrame(); + // return frameInfo.image; + + // Decoding using the ImageProvider, same as `image_pixels` package. This is not faster than the above. + // final Completer completer = Completer(); + // final ImageProvider provider = MemoryImage(imageData); + // final ImageStream stream = provider.resolve(const ImageConfiguration()); + // final ImageStreamListener listener = + // ImageStreamListener((ImageInfo info, bool _) { + // completer.complete(info.image); + // }); + // stream.addListener(listener); + // final Image image = await completer.future; + // stream.removeListener(listener); + // return image; + + // // Decoding using the ImageProvider from material.Image. This is not faster than the above, and also the code below is not finished! + // final materialImage = material.Image.memory(imageData); + // final ImageProvider uiImage = await materialImage.image; +} + +/// Decodes [Uint8List] RGBA bytes to an ui.[Image] object. 
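+///
+/// Sketch of a round trip through the raw RGBA representation, using the
+/// helpers in this file (assumes `image` is an existing ui.Image):
+///
+///   final byteData = await getByteDataFromImage(image);
+///   final copy = await decodeImageFromRgbaBytes(
+///     byteData.buffer.asUint8List(),
+///     image.width,
+///     image.height,
+///   );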
+Future decodeImageFromRgbaBytes( + Uint8List rgbaBytes, + int width, + int height, +) { + final Completer completer = Completer(); + decodeImageFromPixels( + rgbaBytes, + width, + height, + PixelFormat.rgba8888, + (Image image) { + completer.complete(image); + }, + ); + return completer.future; +} + +/// Returns the [ByteData] object of the image, in rawRgba format. +/// +/// Throws an exception if the image could not be converted to ByteData. +Future getByteDataFromImage( + Image image, { + ImageByteFormat format = ImageByteFormat.rawRgba, +}) async { + final ByteData? byteDataRgba = await image.toByteData(format: format); + if (byteDataRgba == null) { + log('[ImageMlUtils] Could not convert image to ByteData'); + throw Exception('Could not convert image to ByteData'); + } + return byteDataRgba; +} + +/// Encodes an [Image] object to a [Uint8List], by default in the png format. +/// +/// Note that the result can be used with `Image.memory()` only if the [format] is png. +Future encodeImageToUint8List( + Image image, { + ImageByteFormat format = ImageByteFormat.png, +}) async { + final ByteData byteDataPng = + await getByteDataFromImage(image, format: format); + final encodedImage = byteDataPng.buffer.asUint8List(); + + return encodedImage; +} + +/// Resizes the [image] to the specified [width] and [height]. +/// Returns the resized image and its size as a [Size] object. Note that this size excludes any empty pixels, hence it can be different than the actual image size if [maintainAspectRatio] is true. +/// +/// [quality] determines the interpolation quality. The default [FilterQuality.medium] works best for most cases, unless you're scaling by a factor of 5-10 or more +/// [maintainAspectRatio] determines whether to maintain the aspect ratio of the original image or not. 
Note that maintaining aspect ratio here does not change the size of the image, but instead often means empty pixels that have to be taken into account +Future<(Image, Size)> resizeImage( + Image image, + int width, + int height, { + FilterQuality quality = FilterQuality.medium, + bool maintainAspectRatio = false, +}) async { + if (image.width == width && image.height == height) { + return (image, Size(width.toDouble(), height.toDouble())); + } + final recorder = PictureRecorder(); + final canvas = Canvas( + recorder, + Rect.fromPoints( + const Offset(0, 0), + Offset(width.toDouble(), height.toDouble()), + ), + ); + // Pre-fill the canvas with RGB color (114, 114, 114) + canvas.drawRect( + Rect.fromPoints( + const Offset(0, 0), + Offset(width.toDouble(), height.toDouble()), + ), + Paint()..color = const Color.fromARGB(255, 114, 114, 114), + ); + + double scaleW = width / image.width; + double scaleH = height / image.height; + if (maintainAspectRatio) { + final scale = min(width / image.width, height / image.height); + scaleW = scale; + scaleH = scale; + } + final scaledWidth = (image.width * scaleW).round(); + final scaledHeight = (image.height * scaleH).round(); + + canvas.drawImageRect( + image, + Rect.fromPoints( + const Offset(0, 0), + Offset(image.width.toDouble(), image.height.toDouble()), + ), + Rect.fromPoints( + const Offset(0, 0), + Offset(scaledWidth.toDouble(), scaledHeight.toDouble()), + ), + Paint()..filterQuality = quality, + ); + + final picture = recorder.endRecording(); + final resizedImage = await picture.toImage(width, height); + return (resizedImage, Size(scaledWidth.toDouble(), scaledHeight.toDouble())); +} + +Future resizeAndCenterCropImage( + Image image, + int size, { + FilterQuality quality = FilterQuality.medium, +}) async { + if (image.width == size && image.height == size) { + return image; + } + final recorder = PictureRecorder(); + final canvas = Canvas( + recorder, + Rect.fromPoints( + const Offset(0, 0), + Offset(size.toDouble(), size.toDouble()), + ), + ); + + final scale = max(size / image.width, size / image.height); + final scaledWidth = (image.width * scale).round(); + final scaledHeight = (image.height * scale).round(); + + canvas.drawImageRect( + image, + Rect.fromPoints( + const Offset(0, 0), + Offset(image.width.toDouble(), image.height.toDouble()), + ), + Rect.fromPoints( + const Offset(0, 0), + Offset(scaledWidth.toDouble(), scaledHeight.toDouble()), + ), + Paint()..filterQuality = quality, + ); + + final picture = recorder.endRecording(); + final resizedImage = await picture.toImage(size, size); + return resizedImage; +} + +/// Crops an [image] based on the specified [x], [y], [width] and [height]. 
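+///
+/// Hypothetical usage extracting a 200x200 region whose top-left corner is
+/// at pixel (100, 50) of `image`:
+///
+///   final byteData = await getByteDataFromImage(image);
+///   final crop = await cropImage(
+///     image,
+///     byteData,
+///     x: 100,
+///     y: 50,
+///     width: 200,
+///     height: 200,
+///   );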
+Future cropImage( + Image image, + ByteData imgByteData, { + required int x, + required int y, + required int width, + required int height, +}) async { + final newByteData = ByteData(width * height * 4); + for (var h = y; h < y + height; h++) { + for (var w = x; w < x + width; w++) { + final pixel = readPixelColor(image, imgByteData, w, h); + setPixelColor( + Size(width.toDouble(), height.toDouble()), + newByteData, + w - x, + h - y, + pixel, + ); + } + } + final newImage = await decodeImageFromRgbaBytes( + newByteData.buffer.asUint8List(), + width, + height, + ); + + return newImage; +} + +Future cropImageWithCanvasSimple( + Image image, { + required double x, + required double y, + required double width, + required double height, +}) async { + final recorder = PictureRecorder(); + final canvas = Canvas( + recorder, + Rect.fromPoints( + const Offset(0, 0), + Offset(width, height), + ), + ); + + canvas.drawImageRect( + image, + Rect.fromPoints( + Offset(x, y), + Offset(x + width, y + height), + ), + Rect.fromPoints( + const Offset(0, 0), + Offset(width, height), + ), + Paint()..filterQuality = FilterQuality.medium, + ); + + final picture = recorder.endRecording(); + return picture.toImage(width.toInt(), height.toInt()); +} + +@Deprecated('Old image processing method, use `cropImage` instead!') +/// Crops an [image] based on the specified [x], [y], [width] and [height]. +/// Optionally, the cropped image can be resized to comply with a [maxSize] and/or [minSize]. +/// Optionally, the cropped image can be rotated from the center by [rotation] radians. +/// Optionally, the [quality] of the resizing interpolation can be specified. +Future cropImageWithCanvas( + Image image, { + required double x, + required double y, + required double width, + required double height, + Size? maxSize, + Size? minSize, + double rotation = 0.0, // rotation in radians + FilterQuality quality = FilterQuality.medium, +}) async { + // Calculate the scale for resizing based on maxSize and minSize + double scaleX = 1.0; + double scaleY = 1.0; + if (maxSize != null) { + final minScale = min(maxSize.width / width, maxSize.height / height); + if (minScale < 1.0) { + scaleX = minScale; + scaleY = minScale; + } + } + if (minSize != null) { + final maxScale = max(minSize.width / width, minSize.height / height); + if (maxScale > 1.0) { + scaleX = maxScale; + scaleY = maxScale; + } + } + + // Calculate the final dimensions + final targetWidth = (width * scaleX).round(); + final targetHeight = (height * scaleY).round(); + + // Create the canvas + final recorder = PictureRecorder(); + final canvas = Canvas( + recorder, + Rect.fromPoints( + const Offset(0, 0), + Offset(targetWidth.toDouble(), targetHeight.toDouble()), + ), + ); + + // Apply rotation + final center = Offset(targetWidth / 2, targetHeight / 2); + canvas.translate(center.dx, center.dy); + canvas.rotate(rotation); + + // Enlarge both the source and destination boxes to account for the rotation (i.e. 
avoid cropping the corners of the image) + final List enlargedSrc = + getEnlargedAbsoluteBox([x, y, x + width, y + height], 1.5); + final List enlargedDst = getEnlargedAbsoluteBox( + [ + -center.dx, + -center.dy, + -center.dx + targetWidth, + -center.dy + targetHeight, + ], + 1.5, + ); + + canvas.drawImageRect( + image, + Rect.fromPoints( + Offset(enlargedSrc[0], enlargedSrc[1]), + Offset(enlargedSrc[2], enlargedSrc[3]), + ), + Rect.fromPoints( + Offset(enlargedDst[0], enlargedDst[1]), + Offset(enlargedDst[2], enlargedDst[3]), + ), + Paint()..filterQuality = quality, + ); + + final picture = recorder.endRecording(); + + return picture.toImage(targetWidth, targetHeight); +} + +/// Adds padding around an [Image] object. +Future addPaddingToImage( + Image image, [ + double padding = 0.5, +]) async { + const Color paddingColor = Color.fromARGB(0, 0, 0, 0); + final originalWidth = image.width; + final originalHeight = image.height; + + final paddedWidth = (originalWidth + 2 * padding * originalWidth).toInt(); + final paddedHeight = (originalHeight + 2 * padding * originalHeight).toInt(); + + final recorder = PictureRecorder(); + final canvas = Canvas( + recorder, + Rect.fromPoints( + const Offset(0, 0), + Offset(paddedWidth.toDouble(), paddedHeight.toDouble()), + ), + ); + + final paint = Paint(); + paint.color = paddingColor; + + // Draw the padding + canvas.drawRect( + Rect.fromPoints( + const Offset(0, 0), + Offset(paddedWidth.toDouble(), paddedHeight.toDouble()), + ), + paint, + ); + + // Draw the original image on top of the padding + canvas.drawImageRect( + image, + Rect.fromPoints( + const Offset(0, 0), + Offset(image.width.toDouble(), image.height.toDouble()), + ), + Rect.fromPoints( + Offset(padding * originalWidth, padding * originalHeight), + Offset( + (1 + padding) * originalWidth, + (1 + padding) * originalHeight, + ), + ), + Paint()..filterQuality = FilterQuality.none, + ); + + final picture = recorder.endRecording(); + return picture.toImage(paddedWidth, paddedHeight); +} + +/// Preprocesses [imageData] for standard ML models. +/// Returns a [Num3DInputMatrix] image, ready for inference. +/// Also returns the original image size and the new image size, respectively. +/// +/// The [imageData] argument must be a [Uint8List] object. +/// The [normalize] argument determines whether the image is normalized to range [-1, 1]. +/// The [requiredWidth] and [requiredHeight] arguments determine the size of the output image. +/// The [quality] argument determines the quality of the resizing interpolation. +/// The [maintainAspectRatio] argument determines whether the aspect ratio of the image is maintained. +@Deprecated("Old method used in blazeface") +Future<(Num3DInputMatrix, Size, Size)> preprocessImageToMatrix( + Uint8List imageData, { + required int normalization, + required int requiredWidth, + required int requiredHeight, + FilterQuality quality = FilterQuality.medium, + maintainAspectRatio = true, +}) async { + final normFunction = normalization == 2 + ? normalizePixelRange2 + : normalization == 1 + ? 
+  final Image image = await decodeImageFromData(imageData);
+  final originalSize = Size(image.width.toDouble(), image.height.toDouble());
+
+  if (image.width == requiredWidth && image.height == requiredHeight) {
+    final ByteData imgByteData = await getByteDataFromImage(image);
+    return (
+      createInputMatrixFromImage(
+        image,
+        imgByteData,
+        normFunction: normFunction,
+      ),
+      originalSize,
+      originalSize
+    );
+  }
+
+  final (resizedImage, newSize) = await resizeImage(
+    image,
+    requiredWidth,
+    requiredHeight,
+    quality: quality,
+    maintainAspectRatio: maintainAspectRatio,
+  );
+
+  final ByteData imgByteData = await getByteDataFromImage(resizedImage);
+  final Num3DInputMatrix imageMatrix = createInputMatrixFromImage(
+    resizedImage,
+    imgByteData,
+    normFunction: normFunction,
+  );
+
+  return (imageMatrix, originalSize, newSize);
+}
+
+Future<(Float32List, Dimensions, Dimensions)>
+    preprocessImageToFloat32ChannelsFirst(
+  Image image,
+  ByteData imgByteData, {
+  required int normalization,
+  required int requiredWidth,
+  required int requiredHeight,
+  Color Function(num, num, Image, ByteData) getPixel = getPixelBilinear,
+  maintainAspectRatio = true,
+}) async {
+  final normFunction = normalization == 2
+      ? normalizePixelRange2
+      : normalization == 1
+          ? normalizePixelRange1
+          : normalizePixelNoRange;
+  final originalSize = Dimensions(width: image.width, height: image.height);
+
+  if (image.width == requiredWidth && image.height == requiredHeight) {
+    return (
+      createFloat32ListFromImageChannelsFirst(
+        image,
+        imgByteData,
+        normFunction: normFunction,
+      ),
+      originalSize,
+      originalSize
+    );
+  }
+
+  double scaleW = requiredWidth / image.width;
+  double scaleH = requiredHeight / image.height;
+  if (maintainAspectRatio) {
+    final scale =
+        min(requiredWidth / image.width, requiredHeight / image.height);
+    scaleW = scale;
+    scaleH = scale;
+  }
+  final scaledWidth = (image.width * scaleW).round().clamp(0, requiredWidth);
+  final scaledHeight = (image.height * scaleH).round().clamp(0, requiredHeight);
+
+  final processedBytes = Float32List(3 * requiredHeight * requiredWidth);
+
+  final buffer = Float32List.view(processedBytes.buffer);
+  int pixelIndex = 0;
+  final int channelOffsetGreen = requiredHeight * requiredWidth;
+  final int channelOffsetBlue = 2 * requiredHeight * requiredWidth;
+  for (var h = 0; h < requiredHeight; h++) {
+    for (var w = 0; w < requiredWidth; w++) {
+      late Color pixel;
+      if (w >= scaledWidth || h >= scaledHeight) {
+        pixel = const Color.fromRGBO(114, 114, 114, 1.0);
+      } else {
+        pixel = getPixel(
+          w / scaleW,
+          h / scaleH,
+          image,
+          imgByteData,
+        );
+      }
+      buffer[pixelIndex] = normFunction(pixel.red);
+      buffer[pixelIndex + channelOffsetGreen] = normFunction(pixel.green);
+      buffer[pixelIndex + channelOffsetBlue] = normFunction(pixel.blue);
+      pixelIndex++;
+    }
+  }
+
+  return (
+    processedBytes,
+    originalSize,
+    Dimensions(width: scaledWidth, height: scaledHeight)
+  );
+}
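+// Layout note for `preprocessImageToFloat32ChannelsFirst`: the output buffer
+// is channels-first (CHW). For a pixel at (w, h), red is stored at index
+// h * requiredWidth + w, green at that index plus H*W, and blue at plus 2*H*W.
+// Pixels outside the scaled image are filled with RGB (114, 114, 114), the
+// gray conventionally used as letterbox padding by YOLO-style detectors.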
+@Deprecated(
+  'Replaced by `preprocessImageToFloat32ChannelsFirst` to avoid issue with iOS canvas',
+)
+Future<(Float32List, Size, Size)> preprocessImageToFloat32ChannelsFirstCanvas(
+  Uint8List imageData, {
+  required int normalization,
+  required int requiredWidth,
+  required int requiredHeight,
+  FilterQuality quality = FilterQuality.medium,
+  maintainAspectRatio = true,
+}) async {
+  final normFunction = normalization == 2
+      ? normalizePixelRange2
+      : normalization == 1
+          ? normalizePixelRange1
+          : normalizePixelNoRange;
+  final stopwatch = Stopwatch()..start();
+  final Image image = await decodeImageFromData(imageData);
+  stopwatch.stop();
+  log("Face Detection decoding ui image took: ${stopwatch.elapsedMilliseconds} ms");
+  final originalSize = Size(image.width.toDouble(), image.height.toDouble());
+  late final Image resizedImage;
+  late final Size newSize;
+
+  if (image.width == requiredWidth && image.height == requiredHeight) {
+    resizedImage = image;
+    newSize = originalSize;
+  } else {
+    (resizedImage, newSize) = await resizeImage(
+      image,
+      requiredWidth,
+      requiredHeight,
+      quality: quality,
+      maintainAspectRatio: maintainAspectRatio,
+    );
+  }
+  final ByteData imgByteData = await getByteDataFromImage(resizedImage);
+  final Float32List imageFloat32List = createFloat32ListFromImageChannelsFirst(
+    resizedImage,
+    imgByteData,
+    normFunction: normFunction,
+  );
+
+  return (imageFloat32List, originalSize, newSize);
+}
+
+/// Preprocesses [imageData] based on [faceLandmarks] to align the faces in the images.
+///
+/// Returns a list of [Uint8List] images, one for each face, in png format.
+@Deprecated("Old method used in blazeface")
+Future<List<Uint8List>> preprocessFaceAlignToUint8List(
+  Uint8List imageData,
+  List<List<List<double>>> faceLandmarks, {
+  int width = 112,
+  int height = 112,
+}) async {
+  final alignedImages = <Uint8List>[];
+  final Image image = await decodeImageFromData(imageData);
+
+  for (final faceLandmark in faceLandmarks) {
+    final (alignmentResult, correctlyEstimated) =
+        SimilarityTransform.instance.estimate(faceLandmark);
+    if (!correctlyEstimated) {
+      alignedImages.add(Uint8List(0));
+      continue;
+    }
+    final alignmentBox = getAlignedFaceBox(alignmentResult);
+    final Image alignedFace = await cropImageWithCanvas(
+      image,
+      x: alignmentBox[0],
+      y: alignmentBox[1],
+      width: alignmentBox[2] - alignmentBox[0],
+      height: alignmentBox[3] - alignmentBox[1],
+      maxSize: Size(width.toDouble(), height.toDouble()),
+      minSize: Size(width.toDouble(), height.toDouble()),
+      rotation: alignmentResult.rotation,
+    );
+    final Uint8List alignedFacePng = await encodeImageToUint8List(alignedFace);
+    alignedImages.add(alignedFacePng);
+
+    // final Uint8List alignedImageRGBA = await warpAffineToUint8List(
+    //   image,
+    //   imgByteData,
+    //   alignmentResult.affineMatrix
+    //       .map(
+    //         (row) => row.map((e) {
+    //           if (e != 1.0) {
+    //             return e * 112;
+    //           } else {
+    //             return 1.0;
+    //           }
+    //         }).toList(),
+    //       )
+    //       .toList(),
+    //   width: width,
+    //   height: height,
+    // );
+    // final Image alignedImage =
+    //     await decodeImageFromRgbaBytes(alignedImageRGBA, width, height);
+    // final Uint8List alignedImagePng =
+    //     await encodeImageToUint8List(alignedImage);
+
+    // alignedImages.add(alignedImagePng);
+  }
+  return alignedImages;
+}
+
+/// Preprocesses [imageData] based on [faceLandmarks] to align the faces in the images
+///
+/// Returns a list of [Num3DInputMatrix] images, one for each face, ready for MobileFaceNet inference
+@Deprecated("Old method used in TensorFlow Lite")
+Future<
+    (
+      List<Num3DInputMatrix>,
+      List<AlignmentResult>,
+      List<bool>,
+      List<double>,
+      Size,
+    )> preprocessToMobileFaceNetInput(
+  Uint8List imageData,
+  List<Map<String, dynamic>> facesJson, {
+  int width = 112,
+  int height = 112,
+}) async {
+  final Image image = await decodeImageFromData(imageData);
+  final Size originalSize =
+      Size(image.width.toDouble(), image.height.toDouble());
+
+  final List<FaceDetectionRelative> relativeFaces =
+      facesJson.map((face) => FaceDetectionRelative.fromJson(face)).toList();
+
+  final List<FaceDetectionAbsolute> absoluteFaces =
+      relativeToAbsoluteDetections(
+    relativeDetections: relativeFaces,
+    imageWidth: image.width,
+    imageHeight: image.height,
+  );
+
+  final List<List<List<double>>> faceLandmarks =
+      absoluteFaces.map((face) => face.allKeypoints).toList();
+
+  final alignedImages = <Num3DInputMatrix>[];
+  final alignmentResults = <AlignmentResult>[];
+  final isBlurs = <bool>[];
+  final blurValues = <double>[];
+
+  for (final faceLandmark in faceLandmarks) {
+    final (alignmentResult, correctlyEstimated) =
+        SimilarityTransform.instance.estimate(faceLandmark);
+    if (!correctlyEstimated) {
+      alignedImages.add([]);
+      alignmentResults.add(AlignmentResult.empty());
+      continue;
+    }
+    final alignmentBox = getAlignedFaceBox(alignmentResult);
+    final Image alignedFace = await cropImageWithCanvas(
+      image,
+      x: alignmentBox[0],
+      y: alignmentBox[1],
+      width: alignmentBox[2] - alignmentBox[0],
+      height: alignmentBox[3] - alignmentBox[1],
+      maxSize: Size(width.toDouble(), height.toDouble()),
+      minSize: Size(width.toDouble(), height.toDouble()),
+      rotation: alignmentResult.rotation,
+      quality: FilterQuality.medium,
+    );
+    final alignedFaceByteData = await getByteDataFromImage(alignedFace);
+    final alignedFaceMatrix = createInputMatrixFromImage(
+      alignedFace,
+      alignedFaceByteData,
+      normFunction: normalizePixelRange2,
+    );
+    alignedImages.add(alignedFaceMatrix);
+    alignmentResults.add(alignmentResult);
+    final faceGrayMatrix = createGrayscaleIntMatrixFromImage(
+      alignedFace,
+      alignedFaceByteData,
+    );
+    final (isBlur, blurValue) = await BlurDetectionService.instance
+        .predictIsBlurGrayLaplacian(faceGrayMatrix);
+    isBlurs.add(isBlur);
+    blurValues.add(blurValue);
+
+    // final Double3DInputMatrix alignedImage = await warpAffineToMatrix(
+    //   image,
+    //   imgByteData,
+    //   transformationMatrix,
+    //   width: width,
+    //   height: height,
+    //   normalize: true,
+    // );
+    // alignedImages.add(alignedImage);
+    // transformationMatrices.add(transformationMatrix);
+  }
+  return (alignedImages, alignmentResults, isBlurs, blurValues, originalSize);
+}
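+// Note on the loop above: when the similarity transform cannot be estimated,
+// placeholder entries ([] and AlignmentResult.empty()) are appended so that
+// `alignedImages` and `alignmentResults` stay index-aligned with the input
+// faces; `isBlurs` and `blurValues` only receive entries for aligned faces.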
+
+@Deprecated("Old image manipulation that used canvas, causing issues on iOS")
+Future<(Float32List, List<AlignmentResult>, List<bool>, List<double>, Size)>
+    preprocessToMobileFaceNetFloat32ListCanvas(
+  String imagePath,
+  List<FaceDetectionRelative> relativeFaces, {
+  int width = 112,
+  int height = 112,
+}) async {
+  final Uint8List imageData = await File(imagePath).readAsBytes();
+  final stopwatch = Stopwatch()..start();
+  final Image image = await decodeImageFromData(imageData);
+  stopwatch.stop();
+  log("Face Alignment decoding ui image took: ${stopwatch.elapsedMilliseconds} ms");
+  final Size originalSize =
+      Size(image.width.toDouble(), image.height.toDouble());
+
+  final List<FaceDetectionAbsolute> absoluteFaces =
+      relativeToAbsoluteDetections(
+    relativeDetections: relativeFaces,
+    imageWidth: image.width,
+    imageHeight: image.height,
+  );
+
+  final List<List<List<double>>> faceLandmarks =
+      absoluteFaces.map((face) => face.allKeypoints).toList();
+
+  final alignedImagesFloat32List =
+      Float32List(3 * width * height * faceLandmarks.length);
+  final alignmentResults = <AlignmentResult>[];
+  final isBlurs = <bool>[];
+  final blurValues = <double>[];
+
+  int alignedImageIndex = 0;
+  for (final faceLandmark in faceLandmarks) {
+    final (alignmentResult, correctlyEstimated) =
+        SimilarityTransform.instance.estimate(faceLandmark);
+    if (!correctlyEstimated) {
+      alignedImageIndex += 3 * width * height;
+      alignmentResults.add(AlignmentResult.empty());
+      continue;
+    }
+    final alignmentBox = getAlignedFaceBox(alignmentResult);
+    final Image alignedFace = await cropImageWithCanvas(
+      image,
+      x: alignmentBox[0],
+      y: alignmentBox[1],
+      width: alignmentBox[2] - alignmentBox[0],
+      height: alignmentBox[3] - alignmentBox[1],
+      maxSize: Size(width.toDouble(), height.toDouble()),
+      minSize: Size(width.toDouble(), height.toDouble()),
+      rotation: alignmentResult.rotation,
+      quality: FilterQuality.medium,
+    );
+    final alignedFaceByteData = await getByteDataFromImage(alignedFace);
+    addInputImageToFloat32List(
+      alignedFace,
+      alignedFaceByteData,
+      alignedImagesFloat32List,
+      alignedImageIndex,
+      normFunction: normalizePixelRange2,
+    );
+    alignedImageIndex += 3 * width * height;
+    alignmentResults.add(alignmentResult);
+    final blurDetectionStopwatch = Stopwatch()..start();
+    final faceGrayMatrix = createGrayscaleIntMatrixFromImage(
+      alignedFace,
+      alignedFaceByteData,
+    );
+    final grayscalems = blurDetectionStopwatch.elapsedMilliseconds;
+    log('creating grayscale matrix took $grayscalems ms');
+    final (isBlur, blurValue) = await BlurDetectionService.instance
+        .predictIsBlurGrayLaplacian(faceGrayMatrix);
+    final blurms = blurDetectionStopwatch.elapsedMilliseconds - grayscalems;
+    log('blur detection took $blurms ms');
+    log(
+      'total blur detection took ${blurDetectionStopwatch.elapsedMilliseconds} ms',
+    );
+    blurDetectionStopwatch.stop();
+    isBlurs.add(isBlur);
+    blurValues.add(blurValue);
+  }
+  return (
+    alignedImagesFloat32List,
+    alignmentResults,
+    isBlurs,
+    blurValues,
+    originalSize
+  );
+}
+
+Future<(Float32List, List<AlignmentResult>, List<bool>, List<double>, Size)>
+    preprocessToMobileFaceNetFloat32List(
+  Image image,
+  ByteData imageByteData,
+  List<FaceDetectionRelative> relativeFaces, {
+  int width = 112,
+  int height = 112,
+}) async {
+  final stopwatch = Stopwatch()..start();
+
+  final Size originalSize =
+      Size(image.width.toDouble(), image.height.toDouble());
+
+  final List<FaceDetectionAbsolute> absoluteFaces =
+      relativeToAbsoluteDetections(
+    relativeDetections: relativeFaces,
+    imageWidth: image.width,
+    imageHeight: image.height,
+  );
+
+  final alignedImagesFloat32List =
+      Float32List(3 * width * height * absoluteFaces.length);
+  final alignmentResults = <AlignmentResult>[];
+  final isBlurs = <bool>[];
+  final blurValues = <double>[];
+
+  int alignedImageIndex = 0;
+  for (final face in absoluteFaces) {
+    final (alignmentResult, correctlyEstimated) =
+        SimilarityTransform.instance.estimate(face.allKeypoints);
+    if (!correctlyEstimated) {
+      alignedImageIndex += 3 * width * height;
+      alignmentResults.add(AlignmentResult.empty());
+      continue;
+    }
+    alignmentResults.add(alignmentResult);
+
+    warpAffineFloat32List(
+      image,
+      imageByteData,
+      alignmentResult.affineMatrix,
+      alignedImagesFloat32List,
+      alignedImageIndex,
+    );
+
+    final blurDetectionStopwatch = Stopwatch()..start();
+    final faceGrayMatrix = createGrayscaleIntMatrixFromNormalized2List(
+      alignedImagesFloat32List,
+      alignedImageIndex,
+    );
+
+    alignedImageIndex += 3 * width * height;
+    final grayscalems = blurDetectionStopwatch.elapsedMilliseconds;
+    log('creating grayscale matrix took $grayscalems ms');
+    final (isBlur, blurValue) =
+        await BlurDetectionService.instance.predictIsBlurGrayLaplacian(
+      faceGrayMatrix,
+      faceDirection: face.getFaceDirection(),
+    );
+    final blurms = blurDetectionStopwatch.elapsedMilliseconds - grayscalems;
+    log('blur detection took $blurms ms');
+    log(
+      'total blur detection took ${blurDetectionStopwatch.elapsedMilliseconds} ms',
+    );
+    blurDetectionStopwatch.stop();
+    isBlurs.add(isBlur);
+    blurValues.add(blurValue);
+  }
+  stopwatch.stop();
+  log("Face Alignment took: ${stopwatch.elapsedMilliseconds} ms");
+  return (
+    alignedImagesFloat32List,
+    alignmentResults,
+    isBlurs,
+    blurValues,
+    originalSize
+  );
+}
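+// `warpAffineFloat32List` below walks the 112x112 destination grid and maps
+// each destination pixel back through the inverse affine transform:
+//   [xOrigin, yOrigin] = A^-1 * ([xTrans, yTrans] - B)
+// where A is the 2x2 linear part and B the translation of the (scaled)
+// alignment matrix. Sampling in source space this way guarantees every output
+// pixel receives exactly one bicubically interpolated value, with no holes.
+// The matrix entries (except the homogeneous 1.0) are multiplied by 112,
+// which appears to rescale a transform estimated for a unit output square to
+// the 112x112 pixel grid that MobileFaceNet expects.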
+
+void warpAffineFloat32List(
+  Image inputImage,
+  ByteData imgByteDataRgba,
+  List<List<double>> affineMatrix,
+  Float32List outputList,
+  int startIndex, {
+  int width = 112,
+  int height = 112,
+}) {
+  if (width != 112 || height != 112) {
+    throw Exception(
+      'Width and height must be 112, other transformations are not supported yet.',
+    );
+  }
+
+  final transformationMatrix = affineMatrix
+      .map(
+        (row) => row.map((e) {
+          if (e != 1.0) {
+            return e * 112;
+          } else {
+            return 1.0;
+          }
+        }).toList(),
+      )
+      .toList();
+
+  final A = Matrix.fromList([
+    [transformationMatrix[0][0], transformationMatrix[0][1]],
+    [transformationMatrix[1][0], transformationMatrix[1][1]],
+  ]);
+  final aInverse = A.inverse();
+  // final aInverseMinus = aInverse * -1;
+  final B = Vector.fromList(
+    [transformationMatrix[0][2], transformationMatrix[1][2]],
+  );
+  final b00 = B[0];
+  final b10 = B[1];
+  final a00Prime = aInverse[0][0];
+  final a01Prime = aInverse[0][1];
+  final a10Prime = aInverse[1][0];
+  final a11Prime = aInverse[1][1];
+
+  for (int yTrans = 0; yTrans < height; ++yTrans) {
+    for (int xTrans = 0; xTrans < width; ++xTrans) {
+      // Perform inverse affine transformation (original implementation, intuitive but slow)
+      // final X = aInverse * (Vector.fromList([xTrans, yTrans]) - B);
+      // final X = aInverseMinus * (B - [xTrans, yTrans]);
+      // final xList = X.asFlattenedList;
+      // num xOrigin = xList[0];
+      // num yOrigin = xList[1];
+
+      // Perform inverse affine transformation (fast implementation, less intuitive)
+      final num xOrigin = (xTrans - b00) * a00Prime + (yTrans - b10) * a01Prime;
+      final num yOrigin = (xTrans - b00) * a10Prime + (yTrans - b10) * a11Prime;
+
+      final Color pixel =
+          getPixelBicubic(xOrigin, yOrigin, inputImage, imgByteDataRgba);
+
+      // Set the new pixel
+      outputList[startIndex + 3 * (yTrans * width + xTrans)] =
+          normalizePixelRange2(pixel.red);
+      outputList[startIndex + 3 * (yTrans * width + xTrans) + 1] =
+          normalizePixelRange2(pixel.green);
+      outputList[startIndex + 3 * (yTrans * width + xTrans) + 2] =
+          normalizePixelRange2(pixel.blue);
+    }
+  }
+}
+
+Future<List<Uint8List>> generateFaceThumbnails(
+  Uint8List imageData,
+  List<FaceBox> faceBoxes,
+) async {
+  final stopwatch = Stopwatch()..start();
+
+  final Image img = await decodeImageFromData(imageData);
+  final ByteData imgByteData = await getByteDataFromImage(img);
+
+  try {
+    final List<Uint8List> faceThumbnails = [];
+
+    for (final faceBox in faceBoxes) {
+      // Note that the faceBox values are relative to the image size, so we need to convert them to absolute values first
+      final double xMinAbs = faceBox.xMin * img.width;
+      final double yMinAbs = faceBox.yMin * img.height;
+      final double widthAbs = faceBox.width * img.width;
+      final double heightAbs = faceBox.height * img.height;
+
+      final int xCrop = (xMinAbs - widthAbs / 2).round().clamp(0, img.width);
+      final int yCrop = (yMinAbs - heightAbs / 2).round().clamp(0, img.height);
+      final int widthCrop = min((widthAbs * 2).round(), img.width - xCrop);
+      final int heightCrop = min((heightAbs * 2).round(), img.height - yCrop);
+      final Image faceThumbnail = await cropImage(
+        img,
+        imgByteData,
+        x: xCrop,
+        y: yCrop,
+        width: widthCrop,
+        height: heightCrop,
+      );
+      final Uint8List faceThumbnailPng = await encodeImageToUint8List(
+        faceThumbnail,
+        format: ImageByteFormat.png,
+      );
+      faceThumbnails.add(faceThumbnailPng);
+    }
+    stopwatch.stop();
+    log('Face thumbnail generation took: ${stopwatch.elapsedMilliseconds} ms');
+
+    return faceThumbnails;
+  } catch (e, s) {
+    log('[ImageMlUtils] Error generating face thumbnails: $e, \n stackTrace: $s');
+    rethrow;
+  }
+}
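+// Crop geometry above: xCrop/yCrop start half a face-width/height before the
+// box and widthCrop/heightCrop span twice the face size, so each thumbnail
+// shows the face centered with roughly 50% context on every side (clamped at
+// the image borders).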
+/// Generates face thumbnails from [imageData] and the given [faceBoxes].
+///
+/// Returns a list of [Uint8List] images, in png format.
+Future<List<Uint8List>> generateFaceThumbnailsUsingCanvas(
+  Uint8List imageData,
+  List<FaceBox> faceBoxes,
+) async {
+  final Image img = await decodeImageFromData(imageData);
+  int i = 0;
+
+  try {
+    final futureFaceThumbnails = <Future<Uint8List>>[];
+    for (final faceBox in faceBoxes) {
+      // Note that the faceBox values are relative to the image size, so we need to convert them to absolute values first
+      final double xMinAbs = faceBox.xMin * img.width;
+      final double yMinAbs = faceBox.yMin * img.height;
+      final double widthAbs = faceBox.width * img.width;
+      final double heightAbs = faceBox.height * img.height;
+
+      // Calculate the crop values by adding some padding around the face and making sure it's centered
+      const regularPadding = 0.4;
+      const minimumPadding = 0.1;
+      final num xCrop = (xMinAbs - widthAbs * regularPadding);
+      final num xOvershoot = min(0, xCrop).abs() / widthAbs;
+      final num widthCrop = widthAbs * (1 + 2 * regularPadding) -
+          2 * min(xOvershoot, regularPadding - minimumPadding) * widthAbs;
+      final num yCrop = (yMinAbs - heightAbs * regularPadding);
+      final num yOvershoot = min(0, yCrop).abs() / heightAbs;
+      final num heightCrop = heightAbs * (1 + 2 * regularPadding) -
+          2 * min(yOvershoot, regularPadding - minimumPadding) * heightAbs;
+
+      // Prevent the face from going out of image bounds
+      final xCropSafe = xCrop.clamp(0, img.width);
+      final yCropSafe = yCrop.clamp(0, img.height);
+      final widthCropSafe = widthCrop.clamp(0, img.width - xCropSafe);
+      final heightCropSafe = heightCrop.clamp(0, img.height - yCropSafe);
+
+      futureFaceThumbnails.add(
+        cropAndEncodeCanvas(
+          img,
+          x: xCropSafe.toDouble(),
+          y: yCropSafe.toDouble(),
+          width: widthCropSafe.toDouble(),
+          height: heightCropSafe.toDouble(),
+        ),
+      );
+      i++;
+    }
+    final List<Uint8List> faceThumbnails =
+        await Future.wait(futureFaceThumbnails);
+    return faceThumbnails;
+  } catch (e) {
+    log('[ImageMlUtils] Error generating face thumbnails: $e');
+    log('[ImageMlUtils] cropImage problematic input argument: ${faceBoxes[i]}');
+    return [];
+  }
+}
+
+Future<Uint8List> cropAndEncodeCanvas(
+  Image image, {
+  required double x,
+  required double y,
+  required double width,
+  required double height,
+}) async {
+  final croppedImage = await cropImageWithCanvasSimple(
+    image,
+    x: x,
+    y: y,
+    width: width,
+    height: height,
+  );
+  return await encodeImageToUint8List(
+    croppedImage,
+    format: ImageByteFormat.png,
+  );
+}
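+// Padding math in `generateFaceThumbnailsUsingCanvas`: each face box is grown
+// by `regularPadding` (0.4) of its size on every side. If that would cross an
+// image edge, the overshoot is subtracted from both sides (never dropping the
+// padding below `minimumPadding`, 0.1) so the face stays centered, and the
+// final crop is then clamped to the image bounds.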
+@Deprecated('For second pass of BlazeFace, no longer used')
+
+/// Generates cropped and padded image data from [imageData] and a [faceBox].
+///
+/// The steps are:
+/// 1. Crop the image to the face bounding box
+/// 2. Resize this cropped image to a square that is half the BlazeFace input size
+/// 3. Pad the image to the BlazeFace input size
+///
+/// Note that [faceBox] is a list of the following values: [xMinBox, yMinBox, xMaxBox, yMaxBox].
+Future<Uint8List> cropAndPadFaceData(
+  Uint8List imageData,
+  List<double> faceBox,
+) async {
+  final Image image = await decodeImageFromData(imageData);
+
+  final Image faceCrop = await cropImageWithCanvas(
+    image,
+    x: (faceBox[0] * image.width),
+    y: (faceBox[1] * image.height),
+    width: ((faceBox[2] - faceBox[0]) * image.width),
+    height: ((faceBox[3] - faceBox[1]) * image.height),
+    maxSize: const Size(128, 128),
+    minSize: const Size(128, 128),
+  );
+
+  final Image facePadded = await addPaddingToImage(
+    faceCrop,
+    0.5,
+  );
+
+  return await encodeImageToUint8List(facePadded);
+}
+
+Color getPixelBilinear(num fx, num fy, Image image, ByteData byteDataRgba) {
+  // Clamp to image boundaries
+  fx = fx.clamp(0, image.width - 1);
+  fy = fy.clamp(0, image.height - 1);
+
+  // Get the surrounding coordinates and their weights
+  final int x0 = fx.floor();
+  final int x1 = fx.ceil();
+  final int y0 = fy.floor();
+  final int y1 = fy.ceil();
+  final dx = fx - x0;
+  final dy = fy - y0;
+  final dx1 = 1.0 - dx;
+  final dy1 = 1.0 - dy;
+
+  // Get the original pixels
+  final Color pixel1 = readPixelColor(image, byteDataRgba, x0, y0);
+  final Color pixel2 = readPixelColor(image, byteDataRgba, x1, y0);
+  final Color pixel3 = readPixelColor(image, byteDataRgba, x0, y1);
+  final Color pixel4 = readPixelColor(image, byteDataRgba, x1, y1);
+
+  int bilinear(
+    num val1,
+    num val2,
+    num val3,
+    num val4,
+  ) =>
+      (val1 * dx1 * dy1 + val2 * dx * dy1 + val3 * dx1 * dy + val4 * dx * dy)
+          .round();
+
+  // Calculate the weighted sum of pixels
+  final int r = bilinear(pixel1.red, pixel2.red, pixel3.red, pixel4.red);
+  final int g =
+      bilinear(pixel1.green, pixel2.green, pixel3.green, pixel4.green);
+  final int b = bilinear(pixel1.blue, pixel2.blue, pixel3.blue, pixel4.blue);
+
+  return Color.fromRGBO(r, g, b, 1.0);
+}
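+// Worked example for `getPixelBilinear`: for fx = 2.3, fy = 5.8 the fractions
+// are dx = 0.3, dy = 0.8, so the sample is
+//   0.14 * p(2,5) + 0.06 * p(3,5) + 0.56 * p(2,6) + 0.24 * p(3,6)
+// (weights (1-dx)(1-dy), dx(1-dy), (1-dx)dy and dx*dy, which sum to 1).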
+
+/// Get the pixel value using Bicubic Interpolation. Code taken mainly from https://github.com/brendan-duncan/image/blob/6e407612752ffdb90b28cd5863c7f65856349348/lib/src/image/image.dart#L697
+Color getPixelBicubic(num fx, num fy, Image image, ByteData byteDataRgba) {
+  fx = fx.clamp(0, image.width - 1);
+  fy = fy.clamp(0, image.height - 1);
+
+  final x = fx.toInt() - (fx >= 0.0 ? 0 : 1);
+  final px = x - 1;
+  final nx = x + 1;
+  final ax = x + 2;
+  final y = fy.toInt() - (fy >= 0.0 ? 0 : 1);
+  final py = y - 1;
+  final ny = y + 1;
+  final ay = y + 2;
+  final dx = fx - x;
+  final dy = fy - y;
+  num cubic(num dx, num ipp, num icp, num inp, num iap) =>
+      icp +
+      0.5 *
+          (dx * (-ipp + inp) +
+              dx * dx * (2 * ipp - 5 * icp + 4 * inp - iap) +
+              dx * dx * dx * (-ipp + 3 * icp - 3 * inp + iap));
+
+  final icc = readPixelColor(image, byteDataRgba, x, y);
+
+  final ipp =
+      px < 0 || py < 0 ? icc : readPixelColor(image, byteDataRgba, px, py);
+  final icp = py < 0 ? icc : readPixelColor(image, byteDataRgba, x, py);
+  final inp = py < 0 || nx >= image.width
+      ? icc
+      : readPixelColor(image, byteDataRgba, nx, py);
+  final iap = ax >= image.width || py < 0
+      ? icc
+      : readPixelColor(image, byteDataRgba, ax, py);
+
+  final ip0 = cubic(dx, ipp.red, icp.red, inp.red, iap.red);
+  final ip1 = cubic(dx, ipp.green, icp.green, inp.green, iap.green);
+  final ip2 = cubic(dx, ipp.blue, icp.blue, inp.blue, iap.blue);
+  // final ip3 = cubic(dx, ipp.a, icp.a, inp.a, iap.a);
+
+  final ipc = px < 0 ? icc : readPixelColor(image, byteDataRgba, px, y);
+  final inc =
+      nx >= image.width ? icc : readPixelColor(image, byteDataRgba, nx, y);
+  final iac =
+      ax >= image.width ? icc : readPixelColor(image, byteDataRgba, ax, y);
+
+  final ic0 = cubic(dx, ipc.red, icc.red, inc.red, iac.red);
+  final ic1 = cubic(dx, ipc.green, icc.green, inc.green, iac.green);
+  final ic2 = cubic(dx, ipc.blue, icc.blue, inc.blue, iac.blue);
+  // final ic3 = cubic(dx, ipc.a, icc.a, inc.a, iac.a);
+
+  final ipn = px < 0 || ny >= image.height
+      ? icc
+      : readPixelColor(image, byteDataRgba, px, ny);
+  final icn =
+      ny >= image.height ? icc : readPixelColor(image, byteDataRgba, x, ny);
+  final inn = nx >= image.width || ny >= image.height
+      ? icc
+      : readPixelColor(image, byteDataRgba, nx, ny);
+  final ian = ax >= image.width || ny >= image.height
+      ? icc
+      : readPixelColor(image, byteDataRgba, ax, ny);
+
+  final in0 = cubic(dx, ipn.red, icn.red, inn.red, ian.red);
+  final in1 = cubic(dx, ipn.green, icn.green, inn.green, ian.green);
+  final in2 = cubic(dx, ipn.blue, icn.blue, inn.blue, ian.blue);
+  // final in3 = cubic(dx, ipn.a, icn.a, inn.a, ian.a);
+
+  final ipa = px < 0 || ay >= image.height
+      ? icc
+      : readPixelColor(image, byteDataRgba, px, ay);
+  final ica =
+      ay >= image.height ? icc : readPixelColor(image, byteDataRgba, x, ay);
+  final ina = nx >= image.width || ay >= image.height
+      ? icc
+      : readPixelColor(image, byteDataRgba, nx, ay);
+  final iaa = ax >= image.width || ay >= image.height
+      ? icc
+      : readPixelColor(image, byteDataRgba, ax, ay);
+
+  final ia0 = cubic(dx, ipa.red, ica.red, ina.red, iaa.red);
+  final ia1 = cubic(dx, ipa.green, ica.green, ina.green, iaa.green);
+  final ia2 = cubic(dx, ipa.blue, ica.blue, ina.blue, iaa.blue);
+  // final ia3 = cubic(dx, ipa.a, ica.a, ina.a, iaa.a);
+
+  final c0 = cubic(dy, ip0, ic0, in0, ia0).clamp(0, 255).toInt();
+  final c1 = cubic(dy, ip1, ic1, in1, ia1).clamp(0, 255).toInt();
+  final c2 = cubic(dy, ip2, ic2, in2, ia2).clamp(0, 255).toInt();
+  // final c3 = cubic(dy, ip3, ic3, in3, ia3);
+
+  return Color.fromRGBO(c0, c1, c2, 1.0);
+}
+
+@Deprecated('Old method only used in other deprecated methods')
+List<double> getAlignedFaceBox(AlignmentResult alignment) {
+  final List<double> box = [
+    // [xMinBox, yMinBox, xMaxBox, yMaxBox]
+    alignment.center[0] - alignment.size / 2,
+    alignment.center[1] - alignment.size / 2,
+    alignment.center[0] + alignment.size / 2,
+    alignment.center[1] + alignment.size / 2,
+  ];
+  box.roundBoxToDouble();
+  return box;
+}
+
+/// Returns an enlarged version of the [box] by a factor of [factor].
+/// The [box] is in absolute coordinates: [xMinBox, yMinBox, xMaxBox, yMaxBox].
+List<double> getEnlargedAbsoluteBox(List<double> box, [double factor = 2]) {
+  final boxCopy = List<double>.from(box, growable: false);
+  // The four values of the box in order are: [xMinBox, yMinBox, xMaxBox, yMaxBox].
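+  // e.g. factor = 2 turns [10, 10, 20, 20] into [5, 5, 25, 25]: the center
+  // (15, 15) is preserved while width and height double.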
+
+  final width = boxCopy[2] - boxCopy[0];
+  final height = boxCopy[3] - boxCopy[1];
+
+  boxCopy[0] -= width * (factor - 1) / 2;
+  boxCopy[1] -= height * (factor - 1) / 2;
+  boxCopy[2] += width * (factor - 1) / 2;
+  boxCopy[3] += height * (factor - 1) / 2;
+
+  return boxCopy;
+}
diff --git a/mobile/lib/utils/image_util.dart b/mobile/lib/utils/image_util.dart
index a5bcb03a7..e5b0d72fa 100644
--- a/mobile/lib/utils/image_util.dart
+++ b/mobile/lib/utils/image_util.dart
@@ -1,6 +1,8 @@
 import 'dart:async';
+import 'dart:ui' as ui;
 
 import 'package:flutter/widgets.dart';
+import 'package:image/image.dart' as img;
 
 Future<ImageInfo> getImageInfo(ImageProvider imageProvider) {
   final completer = Completer<ImageInfo>();
@@ -14,3 +16,35 @@ Future<ImageInfo> getImageInfo(ImageProvider imageProvider) {
   completer.future.whenComplete(() => imageStream.removeListener(listener));
   return completer.future;
 }
+
+Future<ui.Image> convertImageToFlutterUi(img.Image image) async {
+  if (image.format != img.Format.uint8 || image.numChannels != 4) {
+    final cmd = img.Command()
+      ..image(image)
+      ..convert(format: img.Format.uint8, numChannels: 4);
+    final rgba8 = await cmd.getImageThread();
+    if (rgba8 != null) {
+      image = rgba8;
+    }
+  }
+
+  final ui.ImmutableBuffer buffer =
+      await ui.ImmutableBuffer.fromUint8List(image.toUint8List());
+
+  final ui.ImageDescriptor id = ui.ImageDescriptor.raw(
+    buffer,
+    height: image.height,
+    width: image.width,
+    pixelFormat: ui.PixelFormat.rgba8888,
+  );
+
+  final ui.Codec codec = await id.instantiateCodec(
+    targetHeight: image.height,
+    targetWidth: image.width,
+  );
+
+  final ui.FrameInfo fi = await codec.getNextFrame();
+  final ui.Image uiImage = fi.image;
+
+  return uiImage;
+}
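+// Illustrative usage sketch for convertImageToFlutterUi (the decode step is an
+// assumption, not part of this file):
+//   final img.Image? decoded = img.decodeImage(bytes);
+//   if (decoded != null) {
+//     final ui.Image uiImage = await convertImageToFlutterUi(decoded);
+//     // uiImage can now be drawn on a Canvas or wrapped in a RawImage widget.
+//   }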
diff --git a/mobile/lib/utils/local_settings.dart b/mobile/lib/utils/local_settings.dart
index 2f277c80b..6b81e7697 100644
--- a/mobile/lib/utils/local_settings.dart
+++ b/mobile/lib/utils/local_settings.dart
@@ -14,6 +14,8 @@ class LocalSettings {
   static const kCollectionSortPref = "collection_sort_pref";
   static const kPhotoGridSize = "photo_grid_size";
   static const kEnableMagicSearch = "enable_magic_search";
+  static const kEnableFaceIndexing = "enable_face_indexing";
+  static const kEnableFaceClustering = "enable_face_clustering";
   static const kRateUsShownCount = "rate_us_shown_count";
   static const kRateUsPromptThreshold = 2;
@@ -69,4 +71,30 @@ class LocalSettings {
   bool shouldPromptToRateUs() {
     return getRateUsShownCount() < kRateUsPromptThreshold;
   }
+
+  bool get isFaceIndexingEnabled =>
+      _prefs.getBool(kEnableFaceIndexing) ?? false;
+
+  bool get isFaceClusteringEnabled =>
+      _prefs.getBool(kEnableFaceClustering) ?? false;
+
+  /// toggleFaceIndexing toggles the face indexing setting and returns the new value
+  Future<bool> toggleFaceIndexing() async {
+    await _prefs.setBool(kEnableFaceIndexing, !isFaceIndexingEnabled);
+    return isFaceIndexingEnabled;
+  }
+
+  //#region todo:(NG) remove this section, only needed for internal testing to see
+  // if the OS stops the app during indexing
+  bool get remoteFetchEnabled => _prefs.getBool("remoteFetchEnabled") ?? true;
+  Future<void> toggleRemoteFetch() async {
+    await _prefs.setBool("remoteFetchEnabled", !remoteFetchEnabled);
+  }
+  //#endregion
+
+  /// toggleFaceClustering toggles the face clustering setting and returns the new value
+  Future<bool> toggleFaceClustering() async {
+    await _prefs.setBool(kEnableFaceClustering, !isFaceClusteringEnabled);
+    return isFaceClusteringEnabled;
+  }
 }
diff --git a/mobile/lib/utils/network_util.dart b/mobile/lib/utils/network_util.dart
new file mode 100644
index 000000000..a3b28561c
--- /dev/null
+++ b/mobile/lib/utils/network_util.dart
@@ -0,0 +1,21 @@
+import "package:connectivity_plus/connectivity_plus.dart";
+import "package:flutter/foundation.dart" show debugPrint;
+import "package:photos/core/configuration.dart";
+
+Future<bool> canUseHighBandwidth() async {
+  // Connections will contain a list of currently active connections.
+  // could be vpn and wifi or mobile and vpn, but should not be wifi and mobile
+  final List<ConnectivityResult> connections =
+      await (Connectivity().checkConnectivity());
+  bool canUploadUnderCurrentNetworkConditions = true;
+  if (!Configuration.instance.shouldBackupOverMobileData()) {
+    if (connections.any((element) => element == ConnectivityResult.mobile)) {
+      canUploadUnderCurrentNetworkConditions = false;
+    } else {
+      debugPrint(
+        "[canUseHighBandwidth] mobileBackupDisabled, backing up with connections: ${connections.map((e) => e.name).toString()}",
+      );
+    }
+  }
+  return canUploadUnderCurrentNetworkConditions;
+}
diff --git a/mobile/lib/utils/thumbnail_util.dart b/mobile/lib/utils/thumbnail_util.dart
index dc2167632..db7648b92 100644
--- a/mobile/lib/utils/thumbnail_util.dart
+++ b/mobile/lib/utils/thumbnail_util.dart
@@ -217,3 +217,11 @@ File cachedThumbnailPath(EnteFile file) {
     thumbnailCacheDirectory + "/" + file.uploadedFileID.toString(),
   );
 }
+
+File cachedFaceCropPath(String faceID) {
+  final thumbnailCacheDirectory =
+      Configuration.instance.getThumbnailCacheDirectory();
+  return File(
+    thumbnailCacheDirectory + "/" + faceID,
+  );
+}
diff --git a/mobile/plugins/ente_feature_flag/lib/src/service.dart b/mobile/plugins/ente_feature_flag/lib/src/service.dart
index 47539eeb5..8d7f22679 100644
--- a/mobile/plugins/ente_feature_flag/lib/src/service.dart
+++ b/mobile/plugins/ente_feature_flag/lib/src/service.dart
@@ -67,7 +67,7 @@ class FlagService {
 
   bool get mapEnabled => flags.mapEnabled;
 
-  bool get faceSearchEnabled => flags.faceSearchEnabled;
+  bool get faceSearchEnabled => internalUser || flags.faceSearchEnabled;
 
   bool get passKeyEnabled => flags.passKeyEnabled || internalOrBetaUser;
 
diff --git a/mobile/pubspec.lock b/mobile/pubspec.lock
index ae74068eb..1d1082bfd 100644
--- a/mobile/pubspec.lock
+++ b/mobile/pubspec.lock
@@ -363,6 +363,14 @@ packages:
       url: "https://pub.dev"
     source: hosted
    version: "2.3.2"
+  dart_ui_isolate:
+    dependency: "direct main"
+    description:
+      name: dart_ui_isolate
+      sha256: bd531558002a00de0ac7dd73c84887dd01e652bd254d3098d7763881535196d7
+      url: "https://pub.dev"
+    source: hosted
+    version: "1.1.1"
   dbus:
    dependency: transitive
    description:
@@ -1380,6 +1388,14 @@ packages:
       url: "https://pub.dev"
    source: hosted
    version: "1.0.5"
+  ml_linalg:
+    dependency: "direct main"
+    description:
+      name: ml_linalg
+      sha256: "304cb8a2a172f2303226d672d0b6f18dbfe558e2db49d27c8aa9f3e15475c0cd"
+      url: "https://pub.dev"
+    source: hosted
+    version: "13.12.2"
   modal_bottom_sheet:
    dependency: "direct main"
    description:
@@ -1712,7 +1728,7 @@ packages:
    source: hosted
    version: "1.0.1"
   pool:
-    dependency: transitive
+    dependency: "direct main"
    description:
      name: pool
      sha256: "20fe868b6314b322ea036ba325e6fc0711a22948856475e2c2b6306e8ab39c2a"
@@ -1736,7 +1752,7 @@ packages:
    source: hosted
    version: "2.1.0"
   protobuf:
-    dependency: transitive
+    dependency: "direct main"
    description:
      name: protobuf
      sha256: "68645b24e0716782e58948f8467fd42a880f255096a821f9e7d0ec625b00c84d"
@@ -1975,6 +1991,14 @@ packages:
      url: "https://pub.dev"
    source: hosted
    version: "1.0.4"
+  simple_cluster:
+    dependency: "direct main"
+    description:
+      name: simple_cluster
+      sha256: "64d6b7d60d641299ad8c3f012417c711532792c1bc61ac6a7f52b942cdba65da"
+      url: "https://pub.dev"
+    source: hosted
+    version: "0.3.0"
   sky_engine:
    dependency: transitive
    description: flutter
@@ -2149,7 +2173,7 @@ packages:
    source: hosted
    version: "19.4.56"
   synchronized:
-    dependency: transitive
+    dependency: "direct main"
    description:
      name: synchronized
      sha256: "539ef412b170d65ecdafd780f924e5be3f60032a1128df156adad6c5b373d558"
diff --git a/mobile/pubspec.yaml b/mobile/pubspec.yaml
index 1f14bb037..6464496f5 100644
--- a/mobile/pubspec.yaml
+++ b/mobile/pubspec.yaml
@@ -12,7 +12,7 @@ description: ente photos application
 # Read more about iOS versioning at
 # https://developer.apple.com/library/archive/documentation/General/Reference/InfoPlistKeyReference/Articles/CoreFoundationKeys.html
 
-version: 0.8.96+616
+version: 0.8.97+617
 
 publish_to: none
 
 environment:
@@ -39,7 +39,8 @@ dependencies:
   connectivity_plus: ^6.0.2
   cross_file: ^0.3.3
   crypto: ^3.0.2
-  cupertino_icons: ^1.0.8
+  cupertino_icons: ^1.0.0
+  dart_ui_isolate: ^1.1.1
   defer_pointer: ^0.0.2
   device_info_plus: ^9.0.3
   dio: ^4.0.6
@@ -112,6 +113,7 @@ dependencies:
   media_kit: ^1.1.10+1
   media_kit_libs_video: ^1.0.4
   media_kit_video: ^1.2.4
+  ml_linalg: ^13.11.31
   modal_bottom_sheet: ^3.0.0-pre
   motion_photos:
     git: "https://github.com/ente-io/motion_photo.git"
@@ -134,6 +136,8 @@ dependencies:
   photo_view: ^0.14.0
   pinput: ^1.2.2
   pointycastle: ^3.7.3
+  pool: ^1.5.1
+  protobuf: ^3.1.0
   provider: ^6.0.0
   quiver: ^3.0.1
   receive_sharing_intent: ^1.7.0
@@ -142,6 +146,7 @@ dependencies:
   sentry_flutter: ^7.9.0
   share_plus: 7.2.2
   shared_preferences: ^2.0.5
+  simple_cluster: ^0.3.0
   sqflite: ^2.3.0
   sqflite_migration: ^0.3.0
   sqlite3_flutter_libs: ^0.5.20
@@ -150,11 +155,7 @@ dependencies:
   styled_text: ^7.0.0
   syncfusion_flutter_core: ^19.2.49
   syncfusion_flutter_sliders: ^19.2.49
-  # tflite_flutter: ^0.9.0
-  # tflite_flutter_helper:
-  #   git:
-  #     url: https://github.com/pnyompen/tflite_flutter_helper.git
-  #     ref: 43e87d4b9627539266dc20250beb35bf36320dce
+  synchronized: ^3.1.0
   tuple: ^2.0.0
   uni_links: ^0.5.1
   url_launcher: ^6.0.3
@@ -228,9 +229,6 @@ flutter_native_splash:
 flutter:
   assets:
     - assets/
-    - assets/models/cocossd/
-    - assets/models/mobilenet/
-    - assets/models/scenes/
     - assets/models/clip/
   fonts:
     - family: Inter
diff --git a/server/cmd/museum/main.go b/server/cmd/museum/main.go
index 84c34189d..8ccb43cc0 100644
--- a/server/cmd/museum/main.go
+++ b/server/cmd/museum/main.go
@@ -678,7 +678,7 @@ func main() {
     pushHandler := &api.PushHandler{PushController: pushController}
     privateAPI.POST("/push/token", pushHandler.AddToken)
 
-    embeddingController := &embeddingCtrl.Controller{Repo: embeddingRepo, AccessCtrl: accessCtrl, ObjectCleanupController: objectCleanupController, S3Config: s3Config, FileRepo: fileRepo, CollectionRepo: collectionRepo, QueueRepo: queueRepo, TaskLockingRepo: taskLockingRepo, HostName: hostName}
+    embeddingController := embeddingCtrl.New(embeddingRepo, accessCtrl, objectCleanupController, s3Config, queueRepo, taskLockingRepo, fileRepo, collectionRepo, hostName)
     embeddingHandler := &api.EmbeddingHandler{Controller: embeddingController}
 
     privateAPI.PUT("/embeddings", embeddingHandler.InsertOrUpdate)
diff --git a/server/configurations/local.yaml b/server/configurations/local.yaml
index 196c56f1f..87502c271 100644
--- a/server/configurations/local.yaml
+++ b/server/configurations/local.yaml
@@ -125,6 +125,16 @@ s3:
         endpoint:
         region:
         bucket:
+    wasabi-eu-central-2-derived:
+        key:
+        secret:
+        endpoint:
+        region:
+        bucket:
+
+    # Derived storage bucket is used for storing derived data like embeddings, preview etc.
+    # By default, it is the same as the hot storage bucket.
+    # derived-storage: wasabi-eu-central-2-derived
+
     # If true, enable some workarounds to allow us to use a local minio instance
     # for object storage.
     #
diff --git a/server/migrations/86_add_dc_embedding.down.sql b/server/migrations/86_add_dc_embedding.down.sql
new file mode 100644
index 000000000..b705b29b6
--- /dev/null
+++ b/server/migrations/86_add_dc_embedding.down.sql
@@ -0,0 +1,18 @@
+-- Drop the datacenters column that was added for the derived data, and
+-- restore the updated_at trigger on embeddings
+ALTER TABLE embeddings DROP COLUMN IF EXISTS datacenters;
+
+DO
+$$
+    BEGIN
+        IF NOT EXISTS (SELECT 1 FROM pg_trigger WHERE tgname = 'update_embeddings_updated_at') THEN
+            CREATE TRIGGER update_embeddings_updated_at
+                BEFORE UPDATE
+                ON embeddings
+                FOR EACH ROW
+            EXECUTE PROCEDURE
+                trigger_updated_at_microseconds_column();
+        ELSE
+            RAISE NOTICE 'Trigger update_embeddings_updated_at already exists.';
+        END IF;
+    END
+$$;
\ No newline at end of file
diff --git a/server/migrations/86_add_dc_embedding.up.sql b/server/migrations/86_add_dc_embedding.up.sql
new file mode 100644
index 000000000..9d8e28ba7
--- /dev/null
+++ b/server/migrations/86_add_dc_embedding.up.sql
@@ -0,0 +1,4 @@
+-- Add types for the new dcs that are introduced for the derived data
+ALTER TYPE s3region ADD VALUE 'wasabi-eu-central-2-derived';
+DROP TRIGGER IF EXISTS update_embeddings_updated_at ON embeddings;
+ALTER TABLE embeddings ADD COLUMN IF NOT EXISTS datacenters s3region[] default '{b2-eu-cen}';
diff --git a/server/pkg/controller/embedding/controller.go b/server/pkg/controller/embedding/controller.go
index bf317ccfe..6f3de3ca7 100644
--- a/server/pkg/controller/embedding/controller.go
+++ b/server/pkg/controller/embedding/controller.go
@@ -6,8 +6,10 @@ import (
     "encoding/json"
     "errors"
     "fmt"
+    "github.com/aws/aws-sdk-go/aws/awserr"
     "github.com/ente-io/museum/pkg/utils/array"
     "strconv"
+    "strings"
     "sync"
     gTime "time"
 
@@ -22,7 +24,6 @@ import (
     "github.com/ente-io/museum/pkg/utils/auth"
     "github.com/ente-io/museum/pkg/utils/network"
     "github.com/ente-io/museum/pkg/utils/s3config"
-    "github.com/ente-io/museum/pkg/utils/time"
     "github.com/ente-io/stacktrace"
     "github.com/gin-gonic/gin"
     log "github.com/sirupsen/logrus"
@@ -31,20 +32,54 @@ import (
 const (
     // minEmbeddingDataSize is the min size of an embedding object in bytes
     minEmbeddingDataSize  = 2048
-    embeddingFetchTimeout = 15 * gTime.Second
+    embeddingFetchTimeout = 10 * gTime.Second
 )
 
+// _fetchConfig is the configuration for fetching objects from S3
+type _fetchConfig struct {
+    RetryCount     int
+    InitialTimeout gTime.Duration
+    MaxTimeout     gTime.Duration
+}
+
+var _defaultFetchConfig = _fetchConfig{RetryCount: 3, InitialTimeout: 10 * gTime.Second, MaxTimeout: 30 * gTime.Second}
+var _b2FetchConfig = _fetchConfig{RetryCount: 3, InitialTimeout: 15 * gTime.Second, MaxTimeout: 30 * gTime.Second}
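+// With these settings a fetch makes RetryCount+1 = 4 attempts, doubling the
+// per-attempt timeout and capping it at MaxTimeout: 10s, 20s, 30s, 30s for
+// the default config and 15s, 30s, 30s, 30s for the B2 config.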
 
 type Controller struct {
-    Repo                    *embedding.Repository
-    AccessCtrl              access.Controller
-    ObjectCleanupController *controller.ObjectCleanupController
-    S3Config                *s3config.S3Config
-    QueueRepo               *repo.QueueRepository
-    TaskLockingRepo         *repo.TaskLockRepository
-    FileRepo                *repo.FileRepository
-    CollectionRepo          *repo.CollectionRepository
-    HostName                string
-    cleanupCronRunning      bool
+    Repo                     *embedding.Repository
+    AccessCtrl               access.Controller
+    ObjectCleanupController  *controller.ObjectCleanupController
+    S3Config                 *s3config.S3Config
+    QueueRepo                *repo.QueueRepository
+    TaskLockingRepo          *repo.TaskLockRepository
+    FileRepo                 *repo.FileRepository
+    CollectionRepo           *repo.CollectionRepository
+    HostName                 string
+    cleanupCronRunning       bool
+    derivedStorageDataCenter string
+    downloadManagerCache     map[string]*s3manager.Downloader
 }
+
+func New(repo *embedding.Repository, accessCtrl access.Controller, objectCleanupController *controller.ObjectCleanupController, s3Config *s3config.S3Config, queueRepo *repo.QueueRepository, taskLockingRepo *repo.TaskLockRepository, fileRepo *repo.FileRepository, collectionRepo *repo.CollectionRepository, hostName string) *Controller {
+    embeddingDcs := []string{s3Config.GetHotBackblazeDC(), s3Config.GetHotWasabiDC(), s3Config.GetWasabiDerivedDC(), s3Config.GetDerivedStorageDataCenter()}
+    cache := make(map[string]*s3manager.Downloader, len(embeddingDcs))
+    for i := range embeddingDcs {
+        s3Client := s3Config.GetS3Client(embeddingDcs[i])
+        cache[embeddingDcs[i]] = s3manager.NewDownloaderWithClient(&s3Client)
+    }
+    return &Controller{
+        Repo:                     repo,
+        AccessCtrl:               accessCtrl,
+        ObjectCleanupController:  objectCleanupController,
+        S3Config:                 s3Config,
+        QueueRepo:                queueRepo,
+        TaskLockingRepo:          taskLockingRepo,
+        FileRepo:                 fileRepo,
+        CollectionRepo:           collectionRepo,
+        HostName:                 hostName,
+        derivedStorageDataCenter: s3Config.GetDerivedStorageDataCenter(),
+        downloadManagerCache:     cache,
+    }
+}
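+// Note: New eagerly builds one s3manager.Downloader per embedding datacenter.
+// The AWS SDK documents downloaders as safe for concurrent use, so the cached
+// instances can be shared across request goroutines instead of being
+// recreated on every fetch as before.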
 
 func (c *Controller) InsertOrUpdate(ctx *gin.Context, req ente.InsertOrUpdateEmbeddingRequest) (*ente.Embedding, error) {
@@ -77,12 +112,12 @@ func (c *Controller) InsertOrUpdate(ctx *gin.Context, req ente.InsertOrUpdateEmbeddingRequest) (*ente.Embedding, error) {
         DecryptionHeader: req.DecryptionHeader,
         Client:           network.GetPrettyUA(ctx.GetHeader("User-Agent")) + "/" + ctx.GetHeader("X-Client-Version"),
     }
-    size, uploadErr := c.uploadObject(obj, c.getObjectKey(userID, req.FileID, req.Model))
+    size, uploadErr := c.uploadObject(obj, c.getObjectKey(userID, req.FileID, req.Model), c.derivedStorageDataCenter)
     if uploadErr != nil {
         log.Error(uploadErr)
         return nil, stacktrace.Propagate(uploadErr, "")
     }
-    embedding, err := c.Repo.InsertOrUpdate(ctx, userID, req, size, version)
+    embedding, err := c.Repo.InsertOrUpdate(ctx, userID, req, size, version, c.derivedStorageDataCenter)
     embedding.Version = &version
     if err != nil {
         return nil, stacktrace.Propagate(err, "")
@@ -113,7 +148,7 @@ func (c *Controller) GetDiff(ctx *gin.Context, req ente.GetEmbeddingDiffRequest)
 
     // Fetch missing embeddings in parallel
     if len(objectKeys) > 0 {
-        embeddingObjects, err := c.getEmbeddingObjectsParallel(objectKeys)
+        embeddingObjects, err := c.getEmbeddingObjectsParallel(objectKeys, c.derivedStorageDataCenter)
         if err != nil {
             return nil, stacktrace.Propagate(err, "")
         }
@@ -146,7 +181,7 @@ func (c *Controller) GetFilesEmbedding(ctx *gin.Context, req ente.GetFilesEmbeddingRequest) (*ente.GetFilesEmbeddingResponse, error) {
     embeddingsWithData := make([]ente.Embedding, 0)
     noEmbeddingFileIds := make([]int64, 0)
     dbFileIds := make([]int64, 0)
-    // fileIDs that were indexed but they don't contain any embedding information
+    // fileIDs that were indexed, but they don't contain any embedding information
     for i := range userFileEmbeddings {
         dbFileIds = append(dbFileIds, userFileEmbeddings[i].FileID)
         if userFileEmbeddings[i].Size != nil && *userFileEmbeddings[i].Size < minEmbeddingDataSize {
@@ -159,7 +194,7 @@ func (c *Controller) GetFilesEmbedding(ctx *gin.Context, req ente.GetFilesEmbeddingRequest) (*ente.GetFilesEmbeddingResponse, error) {
     errFileIds := make([]int64, 0)
 
     // Fetch missing userFileEmbeddings in parallel
-    embeddingObjects, err := c.getEmbeddingObjectsParallelV2(userID, embeddingsWithData)
+    embeddingObjects, err := c.getEmbeddingObjectsParallelV2(userID, embeddingsWithData, c.derivedStorageDataCenter)
     if err != nil {
         return nil, stacktrace.Propagate(err, "")
     }
@@ -189,82 +224,6 @@ func (c *Controller) GetFilesEmbedding(ctx *gin.Context, req ente.GetFilesEmbeddingRequest) (*ente.GetFilesEmbeddingResponse, error) {
     }, nil
 }
 
-func (c *Controller) DeleteAll(ctx *gin.Context) error {
-    userID := auth.GetUserID(ctx.Request.Header)
-
-    err := c.Repo.DeleteAll(ctx, userID)
-    if err != nil {
-        return stacktrace.Propagate(err, "")
-    }
-    return nil
-}
-
-// CleanupDeletedEmbeddings clears all embeddings for deleted files from the object store
-func (c *Controller) CleanupDeletedEmbeddings() {
-    log.Info("Cleaning up deleted embeddings")
-    if c.cleanupCronRunning {
-        log.Info("Skipping CleanupDeletedEmbeddings cron run as another instance is still running")
-        return
-    }
-    c.cleanupCronRunning = true
-    defer func() {
-        c.cleanupCronRunning = false
-    }()
-    items, err := c.QueueRepo.GetItemsReadyForDeletion(repo.DeleteEmbeddingsQueue, 200)
-    if err != nil {
-        log.WithError(err).Error("Failed to fetch items from queue")
-        return
-    }
-    for _, i := range items {
-        c.deleteEmbedding(i)
-    }
-}
-
-func (c *Controller) deleteEmbedding(qItem repo.QueueItem) {
-    lockName := fmt.Sprintf("Embedding:%s", qItem.Item)
-    lockStatus, err := c.TaskLockingRepo.AcquireLock(lockName, time.MicrosecondsAfterHours(1), c.HostName)
-    ctxLogger := log.WithField("item", qItem.Item).WithField("queue_id", qItem.Id)
-    if err != nil || !lockStatus {
-        ctxLogger.Warn("unable to acquire lock")
-        return
-    }
-    defer func() {
-        err = c.TaskLockingRepo.ReleaseLock(lockName)
-        if err != nil {
-            ctxLogger.Errorf("Error while releasing lock %s", err)
-        }
-    }()
-    ctxLogger.Info("Deleting all embeddings")
-
-    fileID, _ := strconv.ParseInt(qItem.Item, 10, 64)
-    ownerID, err := c.FileRepo.GetOwnerID(fileID)
-    if err != nil {
-        ctxLogger.WithError(err).Error("Failed to fetch ownerID")
-        return
-    }
-    prefix := c.getEmbeddingObjectPrefix(ownerID, fileID)
-
-    err = c.ObjectCleanupController.DeleteAllObjectsWithPrefix(prefix, c.S3Config.GetHotDataCenter())
-    if err != nil {
-        ctxLogger.WithError(err).Error("Failed to delete all objects")
-        return
-    }
-
-    err = c.Repo.Delete(fileID)
-    if err != nil {
-        ctxLogger.WithError(err).Error("Failed to remove from db")
-        return
-    }
-
-    err = c.QueueRepo.DeleteItem(repo.DeleteEmbeddingsQueue, qItem.Item)
-    if err != nil {
-        ctxLogger.WithError(err).Error("Failed to remove item from the queue")
-        return
-    }
-
-    ctxLogger.Info("Successfully deleted all embeddings")
-}
-
 func (c *Controller) getObjectKey(userID int64, fileID int64, model string) string {
     return c.getEmbeddingObjectPrefix(userID, fileID) + model + ".json"
 }
@@ -273,12 +232,23 @@ func (c *Controller) getObjectKey(userID int64, fileID int64, model string) string {
 func (c *Controller) getEmbeddingObjectPrefix(userID int64, fileID int64) string {
     return strconv.FormatInt(userID, 10) + "/ml-data/" + strconv.FormatInt(fileID, 10) + "/"
 }
 
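+// Worked example (values illustrative): getObjectKey builds keys of the form
+// "<userID>/ml-data/<fileID>/<model>.json", so for "123/ml-data/456/clip.json"
+// getEmbeddingObjectDetails returns userID = 123, model = "clip", fileID = 456.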
+// Get userId, model and fileID from the object key
+func (c *Controller) getEmbeddingObjectDetails(objectKey string) (userID int64, model string, fileID int64) {
+    split := strings.Split(objectKey, "/")
+    userID, _ = strconv.ParseInt(split[0], 10, 64)
+    fileID, _ = strconv.ParseInt(split[2], 10, 64)
+    model = strings.Split(split[3], ".")[0]
+    return userID, model, fileID
+}
+
 // uploadObject uploads the embedding object to the object store and returns the object size
-func (c *Controller) uploadObject(obj ente.EmbeddingObject, key string) (int, error) {
+func (c *Controller) uploadObject(obj ente.EmbeddingObject, key string, dc string) (int, error) {
     embeddingObj, _ := json.Marshal(obj)
-    uploader := s3manager.NewUploaderWithClient(c.S3Config.GetHotS3Client())
+    s3Client := c.S3Config.GetS3Client(dc)
+    s3Bucket := c.S3Config.GetBucket(dc)
+    uploader := s3manager.NewUploaderWithClient(&s3Client)
     up := s3manager.UploadInput{
-        Bucket: c.S3Config.GetHotBucket(),
+        Bucket: s3Bucket,
         Key:    &key,
         Body:   bytes.NewReader(embeddingObj),
     }
@@ -296,12 +266,10 @@ func (c *Controller) uploadObject(obj ente.EmbeddingObject, key string) (int, error) {
 
 var globalDiffFetchSemaphore = make(chan struct{}, 300)
 
 var globalFileFetchSemaphore = make(chan struct{}, 400)
 
-func (c *Controller) getEmbeddingObjectsParallel(objectKeys []string) ([]ente.EmbeddingObject, error) {
+func (c *Controller) getEmbeddingObjectsParallel(objectKeys []string, dc string) ([]ente.EmbeddingObject, error) {
     var wg sync.WaitGroup
     var errs []error
     embeddingObjects := make([]ente.EmbeddingObject, len(objectKeys))
-    downloader := s3manager.NewDownloaderWithClient(c.S3Config.GetHotS3Client())
-
     for i, objectKey := range objectKeys {
         wg.Add(1)
         globalDiffFetchSemaphore <- struct{}{} // Acquire from global semaphore
@@ -309,7 +277,7 @@ func (c *Controller) getEmbeddingObjectsParallel(objectKeys []string) ([]ente.EmbeddingObject, error) {
             defer wg.Done()
             defer func() { <-globalDiffFetchSemaphore }() // Release back to global semaphore
 
-            obj, err := c.getEmbeddingObject(context.Background(), objectKey, downloader)
+            obj, err := c.getEmbeddingObject(context.Background(), objectKey, dc)
             if err != nil {
                 errs = append(errs, err)
                 log.Error("error fetching embedding object: "+objectKey, err)
@@ -334,10 +302,9 @@ type embeddingObjectResult struct {
     err error
 }
 
-func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows []ente.Embedding) ([]embeddingObjectResult, error) {
+func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows []ente.Embedding, dc string) ([]embeddingObjectResult, error) {
     var wg sync.WaitGroup
     embeddingObjects := make([]embeddingObjectResult, len(dbEmbeddingRows))
-    downloader := s3manager.NewDownloaderWithClient(c.S3Config.GetHotS3Client())
 
     for i, dbEmbeddingRow := range dbEmbeddingRows {
         wg.Add(1)
@@ -346,9 +313,7 @@ func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows []ente.Embedding) ([]embeddingObjectResult, error) {
             defer wg.Done()
             defer func() { <-globalFileFetchSemaphore }() // Release back to global semaphore
             objectKey := c.getObjectKey(userID, dbEmbeddingRow.FileID, dbEmbeddingRow.Model)
-            ctx, cancel := context.WithTimeout(context.Background(), embeddingFetchTimeout)
-            defer cancel()
-            obj, err := c.getEmbeddingObjectWithRetries(ctx, objectKey, downloader, 0)
+            obj, err := c.getEmbeddingObject(context.Background(), objectKey, dc)
             if err != nil {
                 log.Error("error fetching embedding object: "+objectKey, err)
                 embeddingObjects[i] = embeddingObjectResult{
@@ -368,32 +333,125 @@ func (c *Controller) getEmbeddingObjectsParallelV2(userID int64, dbEmbeddingRows []ente.Embedding) ([]embeddingObjectResult, error) {
     return embeddingObjects, nil
 }
 
-func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, downloader *s3manager.Downloader) (ente.EmbeddingObject, error) {
-    return c.getEmbeddingObjectWithRetries(ctx, objectKey, downloader, 3)
+func (c *Controller) getEmbeddingObject(ctx context.Context, objectKey string, dc string) (ente.EmbeddingObject, error) {
+    opt := _defaultFetchConfig
+    if dc == c.S3Config.GetHotBackblazeDC() {
+        opt = _b2FetchConfig
+    }
+    ctxLogger := log.WithField("objectKey", objectKey).WithField("dc", dc)
+    totalAttempts := opt.RetryCount + 1
+    timeout := opt.InitialTimeout
+    for i := 0; i < totalAttempts; i++ {
+        if i > 0 {
+            timeout = timeout * 2
+            if timeout > opt.MaxTimeout {
+                timeout = opt.MaxTimeout
+            }
+        }
+        fetchCtx, cancel := context.WithTimeout(ctx, timeout)
+        select {
+        case <-ctx.Done():
+            cancel()
+            return ente.EmbeddingObject{}, stacktrace.Propagate(ctx.Err(), "")
+        default:
+            obj, err := c.downloadObject(fetchCtx, objectKey, dc)
+            cancel() // Ensure cancel is called to release resources
+            if err == nil {
+                if i > 0 {
+                    ctxLogger.Infof("Fetched object after %d attempts", i)
+                }
+                return obj, nil
+            }
+            // Check if the error is due to context timeout or cancellation
+            if fetchCtx.Err() != nil {
+                ctxLogger.Error("Fetch timed out or cancelled: ", fetchCtx.Err())
+            } else {
+                // check if the error is due to object not found
+                if s3Err, ok := err.(awserr.RequestFailure); ok {
+                    if s3Err.Code() == s3.ErrCodeNoSuchKey {
+                        var srcDc, destDc string
+                        destDc = c.S3Config.GetDerivedStorageDataCenter()
+                        // todo:(neeraj) Refactor this later to get the available DCs from the DB
+                        // in one go. This will help in case of multiple DCs and avoid querying
+                        // the DB for each object.
+                        // For initial migration, as we know that the original DC was b2, if the
+                        // embedding is not found in the new derived DC, we can try to fetch it
+                        // from the B2 DC.
+                        if c.derivedStorageDataCenter != c.S3Config.GetHotBackblazeDC() {
+                            // embeddings should ideally be in the default hot bucket b2
+                            srcDc = c.S3Config.GetHotBackblazeDC()
+                        } else {
+                            _, modelName, fileID := c.getEmbeddingObjectDetails(objectKey)
+                            activeDcs, err := c.Repo.GetOtherDCsForFileAndModel(context.Background(), fileID, modelName, c.derivedStorageDataCenter)
+                            if err != nil {
+                                return ente.EmbeddingObject{}, stacktrace.Propagate(err, "failed to get other dc")
+                            }
+                            if len(activeDcs) > 0 {
+                                srcDc = activeDcs[0]
+                            } else {
+                                ctxLogger.Error("Object not found in any dc ", s3Err)
+                                return ente.EmbeddingObject{}, stacktrace.Propagate(errors.New("object not found"), "")
+                            }
+                        }
+                        copyEmbeddingObject, err := c.copyEmbeddingObject(ctx, objectKey, srcDc, destDc)
+                        if err == nil {
+                            ctxLogger.Infof("Got object from dc %s", srcDc)
+                            return *copyEmbeddingObject, nil
+                        } else {
+                            ctxLogger.WithError(err).Errorf("Failed to get object from fallback dc %s", srcDc)
+                        }
+                        return ente.EmbeddingObject{}, stacktrace.Propagate(errors.New("object not found"), "")
+                    }
+                }
+                ctxLogger.Error("Failed to fetch object: ", err)
+            }
+        }
+    }
+    return ente.EmbeddingObject{}, stacktrace.Propagate(errors.New("failed to fetch object"), "")
+}
 
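+// Fallback flow above: a NoSuchKey from the derived DC triggers a copy-on-read
+// migration. The object is fetched from B2 (or, when the derived DC already is
+// B2, from whichever other DC the embeddings table lists), returned to the
+// caller immediately, and re-uploaded to the derived bucket in the background
+// so that subsequent reads are served from the derived DC directly.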
-func (c *Controller) getEmbeddingObjectWithRetries(ctx context.Context, objectKey string, downloader *s3manager.Downloader, retryCount int) (ente.EmbeddingObject, error) {
+func (c *Controller) downloadObject(ctx context.Context, objectKey string, dc string) (ente.EmbeddingObject, error) {
     var obj ente.EmbeddingObject
     buff := &aws.WriteAtBuffer{}
+    bucket := c.S3Config.GetBucket(dc)
+    downloader := c.downloadManagerCache[dc]
     _, err := downloader.DownloadWithContext(ctx, buff, &s3.GetObjectInput{
-        Bucket: c.S3Config.GetHotBucket(),
+        Bucket: bucket,
         Key:    &objectKey,
     })
     if err != nil {
-        log.Error(err)
-        if retryCount > 0 {
-            return c.getEmbeddingObjectWithRetries(ctx, objectKey, downloader, retryCount-1)
-        }
-        return obj, stacktrace.Propagate(err, "")
+        return obj, err
     }
     err = json.Unmarshal(buff.Bytes(), &obj)
     if err != nil {
-        log.Error(err)
-        return obj, stacktrace.Propagate(err, "")
+        return obj, stacktrace.Propagate(err, "unmarshal failed")
     }
     return obj, nil
 }
 
+// download the embedding object from hot bucket and upload to embeddings bucket
+func (c *Controller) copyEmbeddingObject(ctx context.Context, objectKey string, srcDC, destDC string) (*ente.EmbeddingObject, error) {
+    if srcDC == destDC {
+        return nil, stacktrace.Propagate(errors.New("src and dest dc can not be same"), "")
+    }
+    obj, err := c.downloadObject(ctx, objectKey, srcDC)
+    if err != nil {
+        return nil, stacktrace.Propagate(err, fmt.Sprintf("failed to download object from %s", srcDC))
+    }
+    go func() {
+        userID, modelName, fileID := c.getEmbeddingObjectDetails(objectKey)
+        size, uploadErr := c.uploadObject(obj, objectKey, c.derivedStorageDataCenter)
+        if uploadErr != nil {
+            log.WithField("object", objectKey).Error("Failed to copy to embeddings bucket: ", uploadErr)
+            return
+        }
+        updateDcErr := c.Repo.AddNewDC(context.Background(), fileID, ente.Model(modelName), userID, size, destDC)
+        if updateDcErr != nil {
+            log.WithField("object", objectKey).Error("Failed to update dc in db: ", updateDcErr)
+            return
+        }
+    }()
+    return &obj, nil
+}
+
 func (c *Controller) _validateGetFileEmbeddingsRequest(ctx *gin.Context, userID int64, req ente.GetFilesEmbeddingRequest) error {
     if req.Model == "" {
         return ente.NewBadRequestWithMessage("model is required")
diff --git a/server/pkg/controller/embedding/delete.go b/server/pkg/controller/embedding/delete.go
new file mode 100644
index 000000000..91a70963f
--- /dev/null
+++ b/server/pkg/controller/embedding/delete.go
@@ -0,0 +1,110 @@
+package embedding
+
+import (
+    "context"
+    "fmt"
+    "github.com/ente-io/museum/pkg/repo"
+    "github.com/ente-io/museum/pkg/utils/auth"
+    "github.com/ente-io/museum/pkg/utils/time"
+    "github.com/ente-io/stacktrace"
+    "github.com/gin-gonic/gin"
+    log "github.com/sirupsen/logrus"
+    "strconv"
+)
+
+func (c *Controller) DeleteAll(ctx *gin.Context) error {
+    userID := auth.GetUserID(ctx.Request.Header)
+
+    err := c.Repo.DeleteAll(ctx, userID)
+    if err != nil {
+        return stacktrace.Propagate(err, "")
+    }
+    return nil
+}
+
+// CleanupDeletedEmbeddings clears all embeddings for deleted files from the object store
+func (c *Controller) CleanupDeletedEmbeddings() {
+    log.Info("Cleaning up deleted embeddings")
+    if c.cleanupCronRunning {
+        log.Info("Skipping CleanupDeletedEmbeddings cron run as another instance is still running")
+        return
+    }
+    c.cleanupCronRunning = true
+    defer func() {
+        c.cleanupCronRunning = false
+    }()
+    items, err := c.QueueRepo.GetItemsReadyForDeletion(repo.DeleteEmbeddingsQueue, 200)
+    if err != nil {
+        log.WithError(err).Error("Failed to fetch items from queue")
+        return
+    }
+    for _, i := range items {
+        c.deleteEmbedding(i)
+    }
+}
+
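+// deleteEmbedding removes the objects from every datacenter recorded in the
+// embeddings table before deleting the DB row; if deletion in any DC fails,
+// the queue item is left in place so a later cron run can retry.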
ctxLogger.Info("Deleting all embeddings") + + fileID, _ := strconv.ParseInt(qItem.Item, 10, 64) + ownerID, err := c.FileRepo.GetOwnerID(fileID) + if err != nil { + ctxLogger.WithError(err).Error("Failed to fetch ownerID") + return + } + prefix := c.getEmbeddingObjectPrefix(ownerID, fileID) + datacenters, err := c.Repo.GetDatacenters(context.Background(), fileID) + if err != nil { + ctxLogger.WithError(err).Error("Failed to fetch datacenters") + return + } + ctxLogger.Infof("Deleting from all datacenters %v", datacenters) + for i := range datacenters { + dc := datacenters[i] + err = c.ObjectCleanupController.DeleteAllObjectsWithPrefix(prefix, dc) + if err != nil { + ctxLogger.WithError(err). + WithField("dc", dc). + Errorf("Failed to delete all objects from %s", datacenters[i]) + return + } else { + removeErr := c.Repo.RemoveDatacenter(context.Background(), fileID, datacenters[i]) + if removeErr != nil { + ctxLogger.WithError(removeErr). + WithField("dc", dc). + Error("Failed to remove datacenter from db") + return + } + } + } + + noDcs, noDcErr := c.Repo.GetDatacenters(context.Background(), fileID) + if len(noDcs) > 0 || noDcErr != nil { + ctxLogger.Errorf("Failed to delete from all datacenters %s", noDcs) + return + } + err = c.Repo.Delete(fileID) + if err != nil { + ctxLogger.WithError(err).Error("Failed to remove from db") + return + } + err = c.QueueRepo.DeleteItem(repo.DeleteEmbeddingsQueue, qItem.Item) + if err != nil { + ctxLogger.WithError(err).Error("Failed to remove item from the queue") + return + } + ctxLogger.Info("Successfully deleted all embeddings") +} diff --git a/server/pkg/controller/object_cleanup.go b/server/pkg/controller/object_cleanup.go index a1ba2dba5..91426cb56 100644 --- a/server/pkg/controller/object_cleanup.go +++ b/server/pkg/controller/object_cleanup.go @@ -260,7 +260,10 @@ func (c *ObjectCleanupController) DeleteAllObjectsWithPrefix(prefix string, dc s Prefix: &prefix, }) if err != nil { - log.Error(err) + log.WithFields(log.Fields{ + "prefix": prefix, + "dc": dc, + }).WithError(err).Error("Failed to list objects") return stacktrace.Propagate(err, "") } var keys []string @@ -270,7 +273,10 @@ func (c *ObjectCleanupController) DeleteAllObjectsWithPrefix(prefix string, dc s for _, key := range keys { err = c.DeleteObjectFromDataCenter(key, dc) if err != nil { - log.Error(err) + log.WithFields(log.Fields{ + "object_key": key, + "dc": dc, + }).WithError(err).Error("Failed to delete object") return stacktrace.Propagate(err, "") } } diff --git a/server/pkg/repo/embedding/repository.go b/server/pkg/repo/embedding/repository.go index 86915fde5..5cfbd35c5 100644 --- a/server/pkg/repo/embedding/repository.go +++ b/server/pkg/repo/embedding/repository.go @@ -3,11 +3,11 @@ package embedding import ( "context" "database/sql" + "errors" "fmt" - "github.com/lib/pq" - "github.com/ente-io/museum/ente" "github.com/ente-io/stacktrace" + "github.com/lib/pq" "github.com/sirupsen/logrus" ) @@ -18,15 +18,26 @@ type Repository struct { } // Create inserts a new embedding - -func (r *Repository) InsertOrUpdate(ctx context.Context, ownerID int64, entry ente.InsertOrUpdateEmbeddingRequest, size int, version int) (ente.Embedding, error) { +func (r *Repository) InsertOrUpdate(ctx context.Context, ownerID int64, entry ente.InsertOrUpdateEmbeddingRequest, size int, version int, dc string) (ente.Embedding, error) { var updatedAt int64 - err := r.DB.QueryRowContext(ctx, `INSERT INTO embeddings - (file_id, owner_id, model, size, version) - VALUES ($1, $2, $3, $4, $5) - ON CONFLICT ON 
CONSTRAINT unique_embeddings_file_id_model
-        DO UPDATE SET updated_at = now_utc_micro_seconds(), size = $4, version = $5
-        RETURNING updated_at`, entry.FileID, ownerID, entry.Model, size, version).Scan(&updatedAt)
+    err := r.DB.QueryRowContext(ctx, `
+        INSERT INTO embeddings
+            (file_id, owner_id, model, size, version, datacenters)
+        VALUES
+            ($1, $2, $3, $4, $5, ARRAY[$6]::s3region[])
+        ON CONFLICT ON CONSTRAINT unique_embeddings_file_id_model
+        DO UPDATE
+        SET
+            updated_at = now_utc_micro_seconds(),
+            size = $4,
+            version = $5,
+            datacenters = CASE
+                WHEN $6 = ANY(COALESCE(embeddings.datacenters, ARRAY['b2-eu-cen']::s3region[])) THEN embeddings.datacenters
+                ELSE array_append(COALESCE(embeddings.datacenters, ARRAY['b2-eu-cen']::s3region[]), $6::s3region)
+            END
+        RETURNING updated_at`,
+        entry.FileID, ownerID, entry.Model, size, version, dc).Scan(&updatedAt)
+
     if err != nil {
         // check if error is due to model enum invalid value
         if err.Error() == fmt.Sprintf("pq: invalid input value for enum model: \"%s\"", entry.Model) {
@@ -82,6 +93,89 @@ func (r *Repository) Delete(fileID int64) error {
     return nil
 }

+// GetDatacenters returns the unique list of datacenters where derived embeddings are stored
+func (r *Repository) GetDatacenters(ctx context.Context, fileID int64) ([]string, error) {
+    rows, err := r.DB.QueryContext(ctx, `SELECT datacenters FROM embeddings WHERE file_id = $1`, fileID)
+    if err != nil {
+        return nil, stacktrace.Propagate(err, "")
+    }
+    defer rows.Close()
+    uniqueDatacenters := make(map[string]struct{})
+    for rows.Next() {
+        var datacenters []string
+        err = rows.Scan(pq.Array(&datacenters))
+        if err != nil {
+            return nil, stacktrace.Propagate(err, "")
+        }
+        for _, dc := range datacenters {
+            uniqueDatacenters[dc] = struct{}{}
+        }
+    }
+    datacenters := make([]string, 0, len(uniqueDatacenters))
+    for dc := range uniqueDatacenters {
+        datacenters = append(datacenters, dc)
+    }
+    return datacenters, nil
+}
+
+// GetOtherDCsForFileAndModel returns the list of datacenters where the embeddings are stored for a given file and model, excluding the ignoredDC
+func (r *Repository) GetOtherDCsForFileAndModel(ctx context.Context, fileID int64, model string, ignoredDC string) ([]string, error) {
+    rows, err := r.DB.QueryContext(ctx, `SELECT datacenters FROM embeddings WHERE file_id = $1 AND model = $2`, fileID, model)
+    if err != nil {
+        return nil, stacktrace.Propagate(err, "")
+    }
+    defer rows.Close()
+    uniqueDatacenters := make(map[string]bool)
+    for rows.Next() {
+        var datacenters []string
+        err = rows.Scan(pq.Array(&datacenters))
+        if err != nil {
+            return nil, stacktrace.Propagate(err, "")
+        }
+        for _, dc := range datacenters {
+            // add to uniqueDatacenters if it is not the ignoredDC
+            if dc != ignoredDC {
+                uniqueDatacenters[dc] = true
+            }
+        }
+    }
+    datacenters := make([]string, 0, len(uniqueDatacenters))
+    for dc := range uniqueDatacenters {
+        datacenters = append(datacenters, dc)
+    }
+    return datacenters, nil
+}
+
+// RemoveDatacenter removes the given datacenter from the list of datacenters
+func (r *Repository) RemoveDatacenter(ctx context.Context, fileID int64, dc string) error {
+    _, err := r.DB.ExecContext(ctx, `UPDATE embeddings SET datacenters = array_remove(datacenters, $1) WHERE file_id = $2`, dc, fileID)
+    if err != nil {
+        return stacktrace.Propagate(err, "")
+    }
+    return nil
+}
+
+// AddNewDC adds the dc name to the list of datacenters, if it doesn't exist already, for a given file, model and user.
+// It also updates the size of the embedding.
+func (r *Repository) AddNewDC(ctx context.Context, fileID int64, model ente.Model, userID int64, size int, dc string) error {
+    res, err := r.DB.ExecContext(ctx, `
+        UPDATE embeddings
+        SET size = $1,
+            datacenters = CASE
+                WHEN $2::s3region = ANY(datacenters) THEN datacenters
+                ELSE array_append(datacenters, $2::s3region)
+            END
+        WHERE file_id = $3 AND model = $4 AND owner_id = $5`, size, dc, fileID, model, userID)
+    if err != nil {
+        return stacktrace.Propagate(err, "")
+    }
+    rowsAffected, err := res.RowsAffected()
+    if err != nil {
+        return stacktrace.Propagate(err, "")
+    }
+    if rowsAffected == 0 {
+        return stacktrace.Propagate(errors.New("no row got updated"), "")
+    }
+    return nil
+}
+
 func convertRowsToEmbeddings(rows *sql.Rows) ([]ente.Embedding, error) {
     defer func() {
         if err := rows.Close(); err != nil {
diff --git a/server/pkg/utils/s3config/s3config.go b/server/pkg/utils/s3config/s3config.go
index 9b273bd61..a562e5181 100644
--- a/server/pkg/utils/s3config/s3config.go
+++ b/server/pkg/utils/s3config/s3config.go
@@ -28,6 +28,8 @@ type S3Config struct {
     hotDC string
     // Secondary (hot) data center
     secondaryHotDC string
+    // Derived data center for derived files like ML embeddings & preview files
+    derivedStorageDC string
     // A map from data centers to S3 configurations
     s3Configs map[string]*aws.Config
     // A map from data centers to pre-created S3 clients
@@ -71,6 +73,7 @@ var (
     dcWasabiEuropeCentralDeprecated string = "wasabi-eu-central-2"
     dcWasabiEuropeCentral_v3        string = "wasabi-eu-central-2-v3"
     dcSCWEuropeFrance_v3            string = "scw-eu-fr-v3"
+    dcWasabiEuropeCentralDerived    string = "wasabi-eu-central-2-derived"
 )

 // Number of days that the wasabi bucket is configured to retain objects.
@@ -86,9 +89,9 @@ func NewS3Config() *S3Config {
 }

 func (config *S3Config) initialize() {
-    dcs := [5]string{
+    dcs := [6]string{
         dcB2EuropeCentral, dcSCWEuropeFranceLockedDeprecated, dcWasabiEuropeCentralDeprecated,
-        dcWasabiEuropeCentral_v3, dcSCWEuropeFrance_v3}
+        dcWasabiEuropeCentral_v3, dcSCWEuropeFrance_v3, dcWasabiEuropeCentralDerived}

     config.hotDC = dcB2EuropeCentral
     config.secondaryHotDC = dcWasabiEuropeCentral_v3
@@ -99,6 +102,12 @@ func (config *S3Config) initialize() {
         config.secondaryHotDC = hs2
         log.Infof("Hot storage: %s (secondary: %s)", hs1, hs2)
     }
+    config.derivedStorageDC = config.hotDC
+    embeddingsDC := viper.GetString("s3.derived-storage")
+    if embeddingsDC != "" && array.StringInList(embeddingsDC, dcs[:]) {
+        config.derivedStorageDC = embeddingsDC
+        log.Infof("Derived storage: %s", embeddingsDC)
+    }

     config.buckets = make(map[string]string)
     config.s3Configs = make(map[string]*aws.Config)
@@ -171,6 +180,18 @@ func (config *S3Config) GetHotS3Client() *s3.S3 {
     return &s3Client
 }

+func (config *S3Config) GetDerivedStorageDataCenter() string {
+    return config.derivedStorageDC
+}
+
+func (config *S3Config) GetDerivedStorageBucket() *string {
+    return config.GetBucket(config.derivedStorageDC)
+}
+
+func (config *S3Config) GetDerivedStorageS3Client() *s3.S3 {
+    s3Client := config.GetS3Client(config.derivedStorageDC)
+    return &s3Client
+}
+
 // Return the name of the hot Backblaze data center
 func (config *S3Config) GetHotBackblazeDC() string {
     return dcB2EuropeCentral
@@ -181,6 +202,10 @@ func (config *S3Config) GetHotWasabiDC() string {
     return dcWasabiEuropeCentral_v3
 }

+func (config *S3Config) GetWasabiDerivedDC() string {
+    return dcWasabiEuropeCentralDerived
+}
+
 // Return the name of the cold Scaleway data center
 func (config *S3Config)
GetColdScalewayDC() string { return dcSCWEuropeFrance_v3 diff --git a/web/apps/photos/src/components/Search/SearchBar/searchInput/index.tsx b/web/apps/photos/src/components/Search/SearchBar/searchInput/index.tsx index 1e62422dc..62d4a1f43 100644 --- a/web/apps/photos/src/components/Search/SearchBar/searchInput/index.tsx +++ b/web/apps/photos/src/components/Search/SearchBar/searchInput/index.tsx @@ -9,7 +9,7 @@ import { useCallback, useContext, useEffect, useRef, useState } from "react"; import { components } from "react-select"; import AsyncSelect from "react-select/async"; import { InputActionMeta } from "react-select/src/types"; -import { Person } from "services/face/types"; +import type { Person } from "services/face/people"; import { City } from "services/locationSearchService"; import { getAutoCompleteSuggestions, diff --git a/web/apps/photos/src/components/ml/MLSearchSettings.tsx b/web/apps/photos/src/components/ml/MLSearchSettings.tsx index 409df4fc6..d71dffab7 100644 --- a/web/apps/photos/src/components/ml/MLSearchSettings.tsx +++ b/web/apps/photos/src/components/ml/MLSearchSettings.tsx @@ -270,14 +270,7 @@ function EnableMLSearch({ onClose, enableMlSearch, onRootClose }) { {" "} {/* */} -

-                        We're putting finishing touches, coming back soon!
-                        Existing indexed faces will continue to show.
+                        We're putting finishing touches, coming back soon!
{isInternalUserForML() && ( diff --git a/web/apps/photos/src/components/ml/PeopleList.tsx b/web/apps/photos/src/components/ml/PeopleList.tsx index 9e5620a5c..da003d97d 100644 --- a/web/apps/photos/src/components/ml/PeopleList.tsx +++ b/web/apps/photos/src/components/ml/PeopleList.tsx @@ -1,10 +1,11 @@ +import { blobCache } from "@/next/blob-cache"; import log from "@/next/log"; import { Skeleton, styled } from "@mui/material"; import { Legend } from "components/PhotoViewer/styledComponents/Legend"; import { t } from "i18next"; import React, { useEffect, useState } from "react"; import mlIDbStorage from "services/face/db"; -import { Face, Person, type MlFileData } from "services/face/types"; +import type { Person } from "services/face/people"; import { EnteFile } from "types/file"; const FaceChipContainer = styled("div")` @@ -57,10 +58,7 @@ export const PeopleList = React.memo((props: PeopleListProps) => { props.onSelect && props.onSelect(person, index) } > - + ))} @@ -108,7 +106,7 @@ export function UnidentifiedFaces(props: { file: EnteFile; updateMLDataIndex: number; }) { - const [faces, setFaces] = useState>([]); + const [faces, setFaces] = useState<{ id: string }[]>([]); useEffect(() => { let didCancel = false; @@ -136,10 +134,7 @@ export function UnidentifiedFaces(props: { {faces && faces.map((face, index) => ( - + ))} @@ -149,29 +144,22 @@ export function UnidentifiedFaces(props: { interface FaceCropImageViewProps { faceID: string; - cacheKey?: string; } -const FaceCropImageView: React.FC = ({ - faceID, - cacheKey, -}) => { +const FaceCropImageView: React.FC = ({ faceID }) => { const [objectURL, setObjectURL] = useState(); useEffect(() => { let didCancel = false; - const electron = globalThis.electron; - - if (faceID && electron) { - electron - .legacyFaceCrop(faceID) - /* - cachedOrNew("face-crops", cacheKey, async () => { - return machineLearningService.regenerateFaceCrop( - faceId, - ); - })*/ + if (faceID) { + blobCache("face-crops") + .then((cache) => cache.get(faceID)) .then((data) => { + /* + TODO(MR): regen if needed and get this to work on web too. + cachedOrNew("face-crops", cacheKey, async () => { + return regenerateFaceCrop(faceId); + })*/ if (data) { const blob = new Blob([data]); if (!didCancel) setObjectURL(URL.createObjectURL(blob)); @@ -183,7 +171,7 @@ const FaceCropImageView: React.FC = ({ didCancel = true; if (objectURL) URL.revokeObjectURL(objectURL); }; - }, [faceID, cacheKey]); + }, [faceID]); return objectURL ? 
( @@ -192,9 +180,9 @@ const FaceCropImageView: React.FC = ({ ); }; -async function getPeopleList(file: EnteFile): Promise> { +async function getPeopleList(file: EnteFile): Promise { let startTime = Date.now(); - const mlFileData: MlFileData = await mlIDbStorage.getFile(file.id); + const mlFileData = await mlIDbStorage.getFile(file.id); log.info( "getPeopleList:mlFilesStore:getItem", Date.now() - startTime, @@ -226,8 +214,8 @@ async function getPeopleList(file: EnteFile): Promise> { return peopleList; } -async function getUnidentifiedFaces(file: EnteFile): Promise> { - const mlFileData: MlFileData = await mlIDbStorage.getFile(file.id); +async function getUnidentifiedFaces(file: EnteFile): Promise<{ id: string }[]> { + const mlFileData = await mlIDbStorage.getFile(file.id); return mlFileData?.faces?.filter( (f) => f.personId === null || f.personId === undefined, diff --git a/web/apps/photos/src/pages/gallery/index.tsx b/web/apps/photos/src/pages/gallery/index.tsx index 9ade12fc5..cb0ae1bf1 100644 --- a/web/apps/photos/src/pages/gallery/index.tsx +++ b/web/apps/photos/src/pages/gallery/index.tsx @@ -717,10 +717,10 @@ export default function Gallery() { await syncTrash(collections, setTrashedFiles); await syncEntities(); await syncMapEnabled(); - await syncCLIPEmbeddings(); const electron = globalThis.electron; - if (isInternalUserForML() && electron) { - await syncFaceEmbeddings(); + if (electron) { + await syncCLIPEmbeddings(); + if (isInternalUserForML()) await syncFaceEmbeddings(); } if (clipService.isPlatformSupported()) { void clipService.scheduleImageEmbeddingExtraction(); diff --git a/web/apps/photos/src/services/clip-service.ts b/web/apps/photos/src/services/clip-service.ts index eb5d7ada5..915f9ae03 100644 --- a/web/apps/photos/src/services/clip-service.ts +++ b/web/apps/photos/src/services/clip-service.ts @@ -184,7 +184,7 @@ class CLIPService { }; getTextEmbeddingIfAvailable = async (text: string) => { - return ensureElectron().clipTextEmbeddingIfAvailable(text); + return ensureElectron().computeCLIPTextEmbeddingIfAvailable(text); }; private runClipEmbeddingExtraction = async (canceller: AbortController) => { @@ -294,7 +294,7 @@ class CLIPService { const file = await localFile .arrayBuffer() .then((buffer) => new Uint8Array(buffer)); - return await ensureElectron().clipImageEmbedding(file); + return await ensureElectron().computeCLIPImageEmbedding(file); }; private encryptAndUploadEmbedding = async ( @@ -328,7 +328,8 @@ class CLIPService { private extractFileClipImageEmbedding = async (file: EnteFile) => { const thumb = await downloadManager.getThumbnail(file); - const embedding = await ensureElectron().clipImageEmbedding(thumb); + const embedding = + await ensureElectron().computeCLIPImageEmbedding(thumb); return embedding; }; diff --git a/web/apps/photos/src/services/download/index.ts b/web/apps/photos/src/services/download/index.ts index 0618cd0e6..d0be660c9 100644 --- a/web/apps/photos/src/services/download/index.ts +++ b/web/apps/photos/src/services/download/index.ts @@ -1,6 +1,6 @@ import { FILE_TYPE } from "@/media/file-type"; import { decodeLivePhoto } from "@/media/live-photo"; -import { openCache, type BlobCache } from "@/next/blob-cache"; +import { blobCache, type BlobCache } from "@/next/blob-cache"; import log from "@/next/log"; import { APPS } from "@ente/shared/apps/constants"; import ComlinkCryptoWorker from "@ente/shared/crypto"; @@ -91,7 +91,7 @@ class DownloadManagerImpl { } this.downloadClient = createDownloadClient(app, tokens); try { - this.thumbnailCache = 
await openCache("thumbs"); + this.thumbnailCache = await blobCache("thumbs"); } catch (e) { log.error( "Failed to open thumbnail cache, will continue without it", @@ -100,7 +100,7 @@ class DownloadManagerImpl { } // TODO (MR): Revisit full file caching cf disk space usage // try { - // if (isElectron()) this.fileCache = await openCache("files"); + // if (isElectron()) this.fileCache = await cache("files"); // } catch (e) { // log.error("Failed to open file cache, will continue without it", e); // } diff --git a/web/apps/photos/src/services/embeddingService.ts b/web/apps/photos/src/services/embeddingService.ts index 17ea5a396..56cebe5a0 100644 --- a/web/apps/photos/src/services/embeddingService.ts +++ b/web/apps/photos/src/services/embeddingService.ts @@ -7,7 +7,7 @@ import HTTPService from "@ente/shared/network/HTTPService"; import { getEndpoint } from "@ente/shared/network/api"; import localForage from "@ente/shared/storage/localForage"; import { getToken } from "@ente/shared/storage/localStorage/helpers"; -import { FileML } from "services/machineLearning/machineLearningService"; +import { FileML } from "services/face/remote"; import type { Embedding, EmbeddingModel, diff --git a/web/apps/photos/src/services/face/align.ts b/web/apps/photos/src/services/face/align.ts deleted file mode 100644 index 7a3bf7a04..000000000 --- a/web/apps/photos/src/services/face/align.ts +++ /dev/null @@ -1,88 +0,0 @@ -import { Matrix } from "ml-matrix"; -import { Point } from "services/face/geom"; -import { FaceAlignment, FaceDetection } from "services/face/types"; -import { getSimilarityTransformation } from "similarity-transformation"; - -const ARCFACE_LANDMARKS = [ - [38.2946, 51.6963], - [73.5318, 51.5014], - [56.0252, 71.7366], - [56.1396, 92.2848], -] as Array<[number, number]>; - -const ARCFACE_LANDMARKS_FACE_SIZE = 112; - -const ARC_FACE_5_LANDMARKS = [ - [38.2946, 51.6963], - [73.5318, 51.5014], - [56.0252, 71.7366], - [41.5493, 92.3655], - [70.7299, 92.2041], -] as Array<[number, number]>; - -/** - * Compute and return an {@link FaceAlignment} for the given face detection. - * - * @param faceDetection A geometry indicating a face detected in an image. - */ -export const faceAlignment = (faceDetection: FaceDetection): FaceAlignment => { - const landmarkCount = faceDetection.landmarks.length; - return getFaceAlignmentUsingSimilarityTransform( - faceDetection, - normalizeLandmarks( - landmarkCount === 5 ? 
ARC_FACE_5_LANDMARKS : ARCFACE_LANDMARKS, - ARCFACE_LANDMARKS_FACE_SIZE, - ), - ); -}; - -function getFaceAlignmentUsingSimilarityTransform( - faceDetection: FaceDetection, - alignedLandmarks: Array<[number, number]>, -): FaceAlignment { - const landmarksMat = new Matrix( - faceDetection.landmarks - .map((p) => [p.x, p.y]) - .slice(0, alignedLandmarks.length), - ).transpose(); - const alignedLandmarksMat = new Matrix(alignedLandmarks).transpose(); - - const simTransform = getSimilarityTransformation( - landmarksMat, - alignedLandmarksMat, - ); - - const RS = Matrix.mul(simTransform.rotation, simTransform.scale); - const TR = simTransform.translation; - - const affineMatrix = [ - [RS.get(0, 0), RS.get(0, 1), TR.get(0, 0)], - [RS.get(1, 0), RS.get(1, 1), TR.get(1, 0)], - [0, 0, 1], - ]; - - const size = 1 / simTransform.scale; - const meanTranslation = simTransform.toMean.sub(0.5).mul(size); - const centerMat = simTransform.fromMean.sub(meanTranslation); - const center = new Point(centerMat.get(0, 0), centerMat.get(1, 0)); - const rotation = -Math.atan2( - simTransform.rotation.get(0, 1), - simTransform.rotation.get(0, 0), - ); - - return { - affineMatrix, - center, - size, - rotation, - }; -} - -function normalizeLandmarks( - landmarks: Array<[number, number]>, - faceSize: number, -): Array<[number, number]> { - return landmarks.map((landmark) => - landmark.map((p) => p / faceSize), - ) as Array<[number, number]>; -} diff --git a/web/apps/photos/src/services/face/blur.ts b/web/apps/photos/src/services/face/blur.ts deleted file mode 100644 index c79081297..000000000 --- a/web/apps/photos/src/services/face/blur.ts +++ /dev/null @@ -1,187 +0,0 @@ -import { Face } from "services/face/types"; -import { createGrayscaleIntMatrixFromNormalized2List } from "utils/image"; -import { mobileFaceNetFaceSize } from "./embed"; - -/** - * Laplacian blur detection. 
- */ -export const detectBlur = ( - alignedFaces: Float32Array, - faces: Face[], -): number[] => { - const numFaces = Math.round( - alignedFaces.length / - (mobileFaceNetFaceSize * mobileFaceNetFaceSize * 3), - ); - const blurValues: number[] = []; - for (let i = 0; i < numFaces; i++) { - const face = faces[i]; - const direction = faceDirection(face); - const faceImage = createGrayscaleIntMatrixFromNormalized2List( - alignedFaces, - i, - ); - const laplacian = applyLaplacian(faceImage, direction); - blurValues.push(matrixVariance(laplacian)); - } - return blurValues; -}; - -type FaceDirection = "left" | "right" | "straight"; - -const faceDirection = (face: Face): FaceDirection => { - const landmarks = face.detection.landmarks; - const leftEye = landmarks[0]; - const rightEye = landmarks[1]; - const nose = landmarks[2]; - const leftMouth = landmarks[3]; - const rightMouth = landmarks[4]; - - const eyeDistanceX = Math.abs(rightEye.x - leftEye.x); - const eyeDistanceY = Math.abs(rightEye.y - leftEye.y); - const mouthDistanceY = Math.abs(rightMouth.y - leftMouth.y); - - const faceIsUpright = - Math.max(leftEye.y, rightEye.y) + 0.5 * eyeDistanceY < nose.y && - nose.y + 0.5 * mouthDistanceY < Math.min(leftMouth.y, rightMouth.y); - - const noseStickingOutLeft = - nose.x < Math.min(leftEye.x, rightEye.x) && - nose.x < Math.min(leftMouth.x, rightMouth.x); - - const noseStickingOutRight = - nose.x > Math.max(leftEye.x, rightEye.x) && - nose.x > Math.max(leftMouth.x, rightMouth.x); - - const noseCloseToLeftEye = - Math.abs(nose.x - leftEye.x) < 0.2 * eyeDistanceX; - const noseCloseToRightEye = - Math.abs(nose.x - rightEye.x) < 0.2 * eyeDistanceX; - - if (noseStickingOutLeft || (faceIsUpright && noseCloseToLeftEye)) { - return "left"; - } else if (noseStickingOutRight || (faceIsUpright && noseCloseToRightEye)) { - return "right"; - } - - return "straight"; -}; - -/** - * Return a new image by applying a Laplacian blur kernel to each pixel. - */ -const applyLaplacian = ( - image: number[][], - direction: FaceDirection, -): number[][] => { - const paddedImage: number[][] = padImage(image, direction); - const numRows = paddedImage.length - 2; - const numCols = paddedImage[0].length - 2; - - // Create an output image initialized to 0. - const outputImage: number[][] = Array.from({ length: numRows }, () => - new Array(numCols).fill(0), - ); - - // Define the Laplacian kernel. - const kernel: number[][] = [ - [0, 1, 0], - [1, -4, 1], - [0, 1, 0], - ]; - - // Apply the kernel to each pixel - for (let i = 0; i < numRows; i++) { - for (let j = 0; j < numCols; j++) { - let sum = 0; - for (let ki = 0; ki < 3; ki++) { - for (let kj = 0; kj < 3; kj++) { - sum += paddedImage[i + ki][j + kj] * kernel[ki][kj]; - } - } - // Adjust the output value if necessary (e.g., clipping). - outputImage[i][j] = sum; - } - } - - return outputImage; -}; - -const padImage = (image: number[][], direction: FaceDirection): number[][] => { - const removeSideColumns = 56; /* must be even */ - - const numRows = image.length; - const numCols = image[0].length; - const paddedNumCols = numCols + 2 - removeSideColumns; - const paddedNumRows = numRows + 2; - - // Create a new matrix with extra padding. - const paddedImage: number[][] = Array.from({ length: paddedNumRows }, () => - new Array(paddedNumCols).fill(0), - ); - - if (direction === "straight") { - // Copy original image into the center of the padded image. 
- for (let i = 0; i < numRows; i++) { - for (let j = 0; j < paddedNumCols - 2; j++) { - paddedImage[i + 1][j + 1] = - image[i][j + Math.round(removeSideColumns / 2)]; - } - } - } else if (direction === "left") { - // If the face is facing left, we only take the right side of the face image. - for (let i = 0; i < numRows; i++) { - for (let j = 0; j < paddedNumCols - 2; j++) { - paddedImage[i + 1][j + 1] = image[i][j + removeSideColumns]; - } - } - } else if (direction === "right") { - // If the face is facing right, we only take the left side of the face image. - for (let i = 0; i < numRows; i++) { - for (let j = 0; j < paddedNumCols - 2; j++) { - paddedImage[i + 1][j + 1] = image[i][j]; - } - } - } - - // Reflect padding - // Top and bottom rows - for (let j = 1; j <= paddedNumCols - 2; j++) { - paddedImage[0][j] = paddedImage[2][j]; // Top row - paddedImage[numRows + 1][j] = paddedImage[numRows - 1][j]; // Bottom row - } - // Left and right columns - for (let i = 0; i < numRows + 2; i++) { - paddedImage[i][0] = paddedImage[i][2]; // Left column - paddedImage[i][paddedNumCols - 1] = paddedImage[i][paddedNumCols - 3]; // Right column - } - - return paddedImage; -}; - -const matrixVariance = (matrix: number[][]): number => { - const numRows = matrix.length; - const numCols = matrix[0].length; - const totalElements = numRows * numCols; - - // Calculate the mean. - let mean: number = 0; - matrix.forEach((row) => { - row.forEach((value) => { - mean += value; - }); - }); - mean /= totalElements; - - // Calculate the variance. - let variance: number = 0; - matrix.forEach((row) => { - row.forEach((value) => { - const diff: number = value - mean; - variance += diff * diff; - }); - }); - variance /= totalElements; - - return variance; -}; diff --git a/web/apps/photos/src/services/face/cluster.ts b/web/apps/photos/src/services/face/cluster.ts index 9ddf156cc..41ba76504 100644 --- a/web/apps/photos/src/services/face/cluster.ts +++ b/web/apps/photos/src/services/face/cluster.ts @@ -1,8 +1,9 @@ import { Hdbscan, type DebugInfo } from "hdbscan"; -import { type Cluster } from "services/face/types"; + +export type Cluster = number[]; export interface ClusterFacesResult { - clusters: Array; + clusters: Cluster[]; noise: Cluster; debugInfo?: DebugInfo; } diff --git a/web/apps/photos/src/services/face/crop.ts b/web/apps/photos/src/services/face/crop.ts deleted file mode 100644 index acd49228e..000000000 --- a/web/apps/photos/src/services/face/crop.ts +++ /dev/null @@ -1,32 +0,0 @@ -import { Box, enlargeBox } from "services/face/geom"; -import { FaceCrop, FaceDetection } from "services/face/types"; -import { cropWithRotation } from "utils/image"; -import { faceAlignment } from "./align"; - -export const getFaceCrop = ( - imageBitmap: ImageBitmap, - faceDetection: FaceDetection, -): FaceCrop => { - const alignment = faceAlignment(faceDetection); - - const padding = 0.25; - const maxSize = 256; - - const alignmentBox = new Box({ - x: alignment.center.x - alignment.size / 2, - y: alignment.center.y - alignment.size / 2, - width: alignment.size, - height: alignment.size, - }).round(); - const scaleForPadding = 1 + padding * 2; - const paddedBox = enlargeBox(alignmentBox, scaleForPadding).round(); - const faceImageBitmap = cropWithRotation(imageBitmap, paddedBox, 0, { - width: maxSize, - height: maxSize, - }); - - return { - image: faceImageBitmap, - imageBox: paddedBox, - }; -}; diff --git a/web/apps/photos/src/services/face/db.ts b/web/apps/photos/src/services/face/db.ts index 399bfff1a..4742dd9d7 100644 --- 
a/web/apps/photos/src/services/face/db.ts +++ b/web/apps/photos/src/services/face/db.ts @@ -9,7 +9,8 @@ import { openDB, } from "idb"; import isElectron from "is-electron"; -import { Face, MLLibraryData, MlFileData, Person } from "services/face/types"; +import type { Person } from "services/face/people"; +import type { MlFileData } from "services/face/types"; import { DEFAULT_ML_SEARCH_CONFIG, MAX_ML_SYNC_ERROR_COUNT, @@ -23,6 +24,18 @@ export interface IndexStatus { peopleIndexSynced: boolean; } +/** + * TODO(MR): Transient type with an intersection of values that both existing + * and new types during the migration will have. Eventually we'll store the the + * server ML data shape here exactly. + */ +export interface MinimalPersistedFileData { + fileId: number; + mlVersion: number; + errorCount: number; + faces?: { personId?: number; id: string }[]; +} + interface Config {} export const ML_SEARCH_CONFIG_NAME = "ml-search"; @@ -31,7 +44,7 @@ const MLDATA_DB_NAME = "mldata"; interface MLDb extends DBSchema { files: { key: number; - value: MlFileData; + value: MinimalPersistedFileData; indexes: { mlVersion: [number, number] }; }; people: { @@ -50,7 +63,7 @@ interface MLDb extends DBSchema { }; library: { key: string; - value: MLLibraryData; + value: unknown; }; configs: { key: string; @@ -177,6 +190,7 @@ class MLIDbStorage { ML_SEARCH_CONFIG_NAME, ); + db.deleteObjectStore("library"); db.deleteObjectStore("things"); } catch { // TODO: ignore for now as we finalize the new version @@ -210,38 +224,6 @@ class MLIDbStorage { await this.db; } - public async getAllFileIds() { - const db = await this.db; - return db.getAllKeys("files"); - } - - public async putAllFilesInTx(mlFiles: Array) { - const db = await this.db; - const tx = db.transaction("files", "readwrite"); - await Promise.all(mlFiles.map((mlFile) => tx.store.put(mlFile))); - await tx.done; - } - - public async removeAllFilesInTx(fileIds: Array) { - const db = await this.db; - const tx = db.transaction("files", "readwrite"); - - await Promise.all(fileIds.map((fileId) => tx.store.delete(fileId))); - await tx.done; - } - - public async newTransaction< - Name extends StoreNames, - Mode extends IDBTransactionMode = "readonly", - >(storeNames: Name, mode?: Mode) { - const db = await this.db; - return db.transaction(storeNames, mode); - } - - public async commit(tx: IDBPTransaction) { - return tx.done; - } - public async getAllFileIdsForUpdate( tx: IDBPTransaction, ) { @@ -275,16 +257,11 @@ class MLIDbStorage { return fileIds; } - public async getFile(fileId: number) { + public async getFile(fileId: number): Promise { const db = await this.db; return db.get("files", fileId); } - public async getAllFiles() { - const db = await this.db; - return db.getAll("files"); - } - public async putFile(mlFile: MlFileData) { const db = await this.db; return db.put("files", mlFile); @@ -292,7 +269,7 @@ class MLIDbStorage { public async upsertFileInTx( fileId: number, - upsert: (mlFile: MlFileData) => MlFileData, + upsert: (mlFile: MinimalPersistedFileData) => MinimalPersistedFileData, ) { const db = await this.db; const tx = db.transaction("files", "readwrite"); @@ -305,7 +282,7 @@ class MLIDbStorage { } public async putAllFiles( - mlFiles: Array, + mlFiles: MinimalPersistedFileData[], tx: IDBPTransaction, ) { await Promise.all(mlFiles.map((mlFile) => tx.store.put(mlFile))); @@ -318,44 +295,6 @@ class MLIDbStorage { await Promise.all(fileIds.map((fileId) => tx.store.delete(fileId))); } - public async getFace(fileID: number, faceId: string) { - const file = 
await this.getFile(fileID); - const face = file.faces.filter((f) => f.id === faceId); - return face[0]; - } - - public async getAllFacesMap() { - const startTime = Date.now(); - const db = await this.db; - const allFiles = await db.getAll("files"); - const allFacesMap = new Map>(); - allFiles.forEach( - (mlFileData) => - mlFileData.faces && - allFacesMap.set(mlFileData.fileId, mlFileData.faces), - ); - log.info("getAllFacesMap", Date.now() - startTime, "ms"); - - return allFacesMap; - } - - public async updateFaces(allFacesMap: Map) { - const startTime = Date.now(); - const db = await this.db; - const tx = db.transaction("files", "readwrite"); - let cursor = await tx.store.openCursor(); - while (cursor) { - if (allFacesMap.has(cursor.key)) { - const mlFileData = { ...cursor.value }; - mlFileData.faces = allFacesMap.get(cursor.key); - cursor.update(mlFileData); - } - cursor = await cursor.continue(); - } - await tx.done; - log.info("updateFaces", Date.now() - startTime, "ms"); - } - public async getPerson(id: number) { const db = await this.db; return db.get("people", id); @@ -366,21 +305,6 @@ class MLIDbStorage { return db.getAll("people"); } - public async putPerson(person: Person) { - const db = await this.db; - return db.put("people", person); - } - - public async clearAllPeople() { - const db = await this.db; - return db.clear("people"); - } - - public async getIndexVersion(index: string) { - const db = await this.db; - return db.get("versions", index); - } - public async incrementIndexVersion(index: StoreNames) { if (index === "versions") { throw new Error("versions store can not be versioned"); @@ -395,21 +319,6 @@ class MLIDbStorage { return version; } - public async setIndexVersion(index: string, version: number) { - const db = await this.db; - return db.put("versions", version, index); - } - - public async getLibraryData() { - const db = await this.db; - return db.get("library", "data"); - } - - public async putLibraryData(data: MLLibraryData) { - const db = await this.db; - return db.put("library", data, "data"); - } - public async getConfig(name: string, def: T) { const db = await this.db; const tx = db.transaction("configs", "readwrite"); @@ -473,66 +382,6 @@ class MLIDbStorage { peopleIndexVersion === filesIndexVersion, }; } - - // for debug purpose - public async getAllMLData() { - const db = await this.db; - const tx = db.transaction(db.objectStoreNames, "readonly"); - const allMLData: any = {}; - for (const store of tx.objectStoreNames) { - const keys = await tx.objectStore(store).getAllKeys(); - const data = await tx.objectStore(store).getAll(); - - allMLData[store] = {}; - for (let i = 0; i < keys.length; i++) { - allMLData[store][keys[i]] = data[i]; - } - } - await tx.done; - - const files = allMLData["files"]; - for (const fileId of Object.keys(files)) { - const fileData = files[fileId]; - fileData.faces?.forEach( - (f) => (f.embedding = Array.from(f.embedding)), - ); - } - - return allMLData; - } - - // for debug purpose, this will overwrite all data - public async putAllMLData(allMLData: Map) { - const db = await this.db; - const tx = db.transaction(db.objectStoreNames, "readwrite"); - for (const store of tx.objectStoreNames) { - const records = allMLData[store]; - if (!records) { - continue; - } - const txStore = tx.objectStore(store); - - if (store === "files") { - const files = records; - for (const fileId of Object.keys(files)) { - const fileData = files[fileId]; - fileData.faces?.forEach( - (f) => (f.embedding = Float32Array.from(f.embedding)), - ); - } - } - - 
await txStore.clear(); - for (const key of Object.keys(records)) { - if (txStore.keyPath) { - txStore.put(records[key]); - } else { - txStore.put(records[key], key); - } - } - } - await tx.done; - } } export default new MLIDbStorage(); diff --git a/web/apps/photos/src/services/face/detect.ts b/web/apps/photos/src/services/face/detect.ts deleted file mode 100644 index 39b843062..000000000 --- a/web/apps/photos/src/services/face/detect.ts +++ /dev/null @@ -1,316 +0,0 @@ -import { workerBridge } from "@/next/worker/worker-bridge"; -import { euclidean } from "hdbscan"; -import { - Box, - Dimensions, - Point, - boxFromBoundingBox, - newBox, -} from "services/face/geom"; -import { FaceDetection } from "services/face/types"; -import { - Matrix, - applyToPoint, - compose, - scale, - translate, -} from "transformation-matrix"; -import { - clamp, - getPixelBilinear, - normalizePixelBetween0And1, -} from "utils/image"; - -/** - * Detect faces in the given {@link imageBitmap}. - * - * The model used is YOLO, running in an ONNX runtime. - */ -export const detectFaces = async ( - imageBitmap: ImageBitmap, -): Promise> => { - const maxFaceDistancePercent = Math.sqrt(2) / 100; - const maxFaceDistance = imageBitmap.width * maxFaceDistancePercent; - const preprocessResult = preprocessImageBitmapToFloat32ChannelsFirst( - imageBitmap, - 640, - 640, - ); - const data = preprocessResult.data; - const resized = preprocessResult.newSize; - const outputData = await workerBridge.detectFaces(data); - const faces = getFacesFromYOLOOutput(outputData as Float32Array, 0.7); - const inBox = newBox(0, 0, resized.width, resized.height); - const toBox = newBox(0, 0, imageBitmap.width, imageBitmap.height); - const transform = computeTransformToBox(inBox, toBox); - const faceDetections: Array = faces?.map((f) => { - const box = transformBox(f.box, transform); - const normLandmarks = f.landmarks; - const landmarks = transformPoints(normLandmarks, transform); - return { - box, - landmarks, - probability: f.probability as number, - } as FaceDetection; - }); - return removeDuplicateDetections(faceDetections, maxFaceDistance); -}; - -const preprocessImageBitmapToFloat32ChannelsFirst = ( - imageBitmap: ImageBitmap, - requiredWidth: number, - requiredHeight: number, - maintainAspectRatio: boolean = true, - normFunction: (pixelValue: number) => number = normalizePixelBetween0And1, -) => { - // Create an OffscreenCanvas and set its size. 
- const offscreenCanvas = new OffscreenCanvas( - imageBitmap.width, - imageBitmap.height, - ); - const ctx = offscreenCanvas.getContext("2d"); - ctx.drawImage(imageBitmap, 0, 0, imageBitmap.width, imageBitmap.height); - const imageData = ctx.getImageData( - 0, - 0, - imageBitmap.width, - imageBitmap.height, - ); - const pixelData = imageData.data; - - let scaleW = requiredWidth / imageBitmap.width; - let scaleH = requiredHeight / imageBitmap.height; - if (maintainAspectRatio) { - const scale = Math.min( - requiredWidth / imageBitmap.width, - requiredHeight / imageBitmap.height, - ); - scaleW = scale; - scaleH = scale; - } - const scaledWidth = clamp( - Math.round(imageBitmap.width * scaleW), - 0, - requiredWidth, - ); - const scaledHeight = clamp( - Math.round(imageBitmap.height * scaleH), - 0, - requiredHeight, - ); - - const processedImage = new Float32Array( - 1 * 3 * requiredWidth * requiredHeight, - ); - - // Populate the Float32Array with normalized pixel values - let pixelIndex = 0; - const channelOffsetGreen = requiredHeight * requiredWidth; - const channelOffsetBlue = 2 * requiredHeight * requiredWidth; - for (let h = 0; h < requiredHeight; h++) { - for (let w = 0; w < requiredWidth; w++) { - let pixel: { - r: number; - g: number; - b: number; - }; - if (w >= scaledWidth || h >= scaledHeight) { - pixel = { r: 114, g: 114, b: 114 }; - } else { - pixel = getPixelBilinear( - w / scaleW, - h / scaleH, - pixelData, - imageBitmap.width, - imageBitmap.height, - ); - } - processedImage[pixelIndex] = normFunction(pixel.r); - processedImage[pixelIndex + channelOffsetGreen] = normFunction( - pixel.g, - ); - processedImage[pixelIndex + channelOffsetBlue] = normFunction( - pixel.b, - ); - pixelIndex++; - } - } - - return { - data: processedImage, - originalSize: { - width: imageBitmap.width, - height: imageBitmap.height, - }, - newSize: { width: scaledWidth, height: scaledHeight }, - }; -}; - -/** - * @param rowOutput A Float32Array of shape [25200, 16], where each row - * represents a bounding box. - */ -const getFacesFromYOLOOutput = ( - rowOutput: Float32Array, - minScore: number, -): Array => { - const faces: Array = []; - // Iterate over each row. 
- for (let i = 0; i < rowOutput.length; i += 16) { - const score = rowOutput[i + 4]; - if (score < minScore) { - continue; - } - // The first 4 values represent the bounding box's coordinates: - // - // (x1, y1, x2, y2) - // - const xCenter = rowOutput[i]; - const yCenter = rowOutput[i + 1]; - const width = rowOutput[i + 2]; - const height = rowOutput[i + 3]; - const xMin = xCenter - width / 2.0; // topLeft - const yMin = yCenter - height / 2.0; // topLeft - - const leftEyeX = rowOutput[i + 5]; - const leftEyeY = rowOutput[i + 6]; - const rightEyeX = rowOutput[i + 7]; - const rightEyeY = rowOutput[i + 8]; - const noseX = rowOutput[i + 9]; - const noseY = rowOutput[i + 10]; - const leftMouthX = rowOutput[i + 11]; - const leftMouthY = rowOutput[i + 12]; - const rightMouthX = rowOutput[i + 13]; - const rightMouthY = rowOutput[i + 14]; - - const box = new Box({ - x: xMin, - y: yMin, - width: width, - height: height, - }); - const probability = score as number; - const landmarks = [ - new Point(leftEyeX, leftEyeY), - new Point(rightEyeX, rightEyeY), - new Point(noseX, noseY), - new Point(leftMouthX, leftMouthY), - new Point(rightMouthX, rightMouthY), - ]; - faces.push({ box, landmarks, probability }); - } - return faces; -}; - -export const getRelativeDetection = ( - faceDetection: FaceDetection, - dimensions: Dimensions, -): FaceDetection => { - const oldBox: Box = faceDetection.box; - const box = new Box({ - x: oldBox.x / dimensions.width, - y: oldBox.y / dimensions.height, - width: oldBox.width / dimensions.width, - height: oldBox.height / dimensions.height, - }); - const oldLandmarks: Point[] = faceDetection.landmarks; - const landmarks = oldLandmarks.map((l) => { - return new Point(l.x / dimensions.width, l.y / dimensions.height); - }); - const probability = faceDetection.probability; - return { box, landmarks, probability }; -}; - -/** - * Removes duplicate face detections from an array of detections. - * - * This function sorts the detections by their probability in descending order, - * then iterates over them. - * - * For each detection, it calculates the Euclidean distance to all other - * detections. - * - * If the distance is less than or equal to the specified threshold - * (`withinDistance`), the other detection is considered a duplicate and is - * removed. - * - * @param detections - An array of face detections to remove duplicates from. - * - * @param withinDistance - The maximum Euclidean distance between two detections - * for them to be considered duplicates. - * - * @returns An array of face detections with duplicates removed. 
- */ -const removeDuplicateDetections = ( - detections: Array, - withinDistance: number, -) => { - detections.sort((a, b) => b.probability - a.probability); - const isSelected = new Map(); - for (let i = 0; i < detections.length; i++) { - if (isSelected.get(i) === false) { - continue; - } - isSelected.set(i, true); - for (let j = i + 1; j < detections.length; j++) { - if (isSelected.get(j) === false) { - continue; - } - const centeri = getDetectionCenter(detections[i]); - const centerj = getDetectionCenter(detections[j]); - const dist = euclidean( - [centeri.x, centeri.y], - [centerj.x, centerj.y], - ); - if (dist <= withinDistance) { - isSelected.set(j, false); - } - } - } - - const uniques: Array = []; - for (let i = 0; i < detections.length; i++) { - isSelected.get(i) && uniques.push(detections[i]); - } - return uniques; -}; - -function getDetectionCenter(detection: FaceDetection) { - const center = new Point(0, 0); - // TODO: first 4 landmarks is applicable to blazeface only - // this needs to consider eyes, nose and mouth landmarks to take center - detection.landmarks?.slice(0, 4).forEach((p) => { - center.x += p.x; - center.y += p.y; - }); - - return new Point(center.x / 4, center.y / 4); -} - -function computeTransformToBox(inBox: Box, toBox: Box): Matrix { - return compose( - translate(toBox.x, toBox.y), - scale(toBox.width / inBox.width, toBox.height / inBox.height), - ); -} - -function transformPoint(point: Point, transform: Matrix) { - const txdPoint = applyToPoint(transform, point); - return new Point(txdPoint.x, txdPoint.y); -} - -function transformPoints(points: Point[], transform: Matrix) { - return points?.map((p) => transformPoint(p, transform)); -} - -function transformBox(box: Box, transform: Matrix) { - const topLeft = transformPoint(box.topLeft, transform); - const bottomRight = transformPoint(box.bottomRight, transform); - - return boxFromBoundingBox({ - left: topLeft.x, - top: topLeft.y, - right: bottomRight.x, - bottom: bottomRight.y, - }); -} diff --git a/web/apps/photos/src/services/face/embed.ts b/web/apps/photos/src/services/face/embed.ts deleted file mode 100644 index 2e0977ea1..000000000 --- a/web/apps/photos/src/services/face/embed.ts +++ /dev/null @@ -1,26 +0,0 @@ -import { workerBridge } from "@/next/worker/worker-bridge"; -import { FaceEmbedding } from "services/face/types"; - -export const mobileFaceNetFaceSize = 112; - -/** - * Compute embeddings for the given {@link faceData}. - * - * The model used is MobileFaceNet, running in an ONNX runtime. 
- */ -export const faceEmbeddings = async ( - faceData: Float32Array, -): Promise> => { - const outputData = await workerBridge.faceEmbeddings(faceData); - - const embeddingSize = 192; - const embeddings = new Array( - outputData.length / embeddingSize, - ); - for (let i = 0; i < embeddings.length; i++) { - embeddings[i] = new Float32Array( - outputData.slice(i * embeddingSize, (i + 1) * embeddingSize), - ); - } - return embeddings; -}; diff --git a/web/apps/photos/src/services/face/f-index.ts b/web/apps/photos/src/services/face/f-index.ts index db054ac29..853cd15af 100644 --- a/web/apps/photos/src/services/face/f-index.ts +++ b/web/apps/photos/src/services/face/f-index.ts @@ -1,194 +1,742 @@ -import { openCache } from "@/next/blob-cache"; +import { FILE_TYPE } from "@/media/file-type"; +import { blobCache } from "@/next/blob-cache"; import log from "@/next/log"; -import { faceAlignment } from "services/face/align"; -import mlIDbStorage from "services/face/db"; -import { detectFaces, getRelativeDetection } from "services/face/detect"; -import { faceEmbeddings, mobileFaceNetFaceSize } from "services/face/embed"; +import { workerBridge } from "@/next/worker/worker-bridge"; +import { euclidean } from "hdbscan"; +import { Matrix } from "ml-matrix"; import { - DetectedFace, + Box, + Dimensions, + Point, + enlargeBox, + roundBox, +} from "services/face/geom"; +import type { Face, - MLSyncFileContext, - type FaceAlignment, + FaceAlignment, + FaceDetection, + MlFileData, } from "services/face/types"; -import { imageBitmapToBlob, warpAffineFloat32List } from "utils/image"; -import { detectBlur } from "./blur"; -import { getFaceCrop } from "./crop"; +import { defaultMLVersion } from "services/machineLearning/machineLearningService"; +import { getSimilarityTransformation } from "similarity-transformation"; +import type { EnteFile } from "types/file"; +import { fetchImageBitmap, getLocalFileImageBitmap } from "./file"; import { - fetchImageBitmap, - fetchImageBitmapForContext, - getFaceId, - getLocalFile, + clamp, + grayscaleIntMatrixFromNormalized2List, + pixelRGBBilinear, + warpAffineFloat32List, } from "./image"; +import { transformFaceDetections } from "./transform-box"; -export const syncFileAnalyzeFaces = async (fileContext: MLSyncFileContext) => { - const { newMlFile } = fileContext; +/** + * Index faces in the given file. + * + * This function is the entry point to the indexing pipeline. The file goes + * through various stages: + * + * 1. Downloading the original if needed. + * 2. Detect faces using ONNX/YOLO + * 3. Align the face rectangles, compute blur. + * 4. Compute embeddings for the detected face (crops). + * + * Once all of it is done, it returns the face rectangles and embeddings so that + * they can be saved locally for offline use, and encrypts and uploads them to + * the user's remote storage so that their other devices can download them + * instead of needing to reindex. 
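+ *
+ * The returned MlFileData carries the image dimensions and, per face, the
+ * detection box and landmarks, the alignment, a blur score, and an
+ * embedding (192-dimensional, going by the MobileFaceNet output size in
+ * the deleted embed.ts).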
+ */ +export const indexFaces = async (enteFile: EnteFile, localFile?: File) => { const startTime = Date.now(); - await syncFileFaceDetections(fileContext); - - if (newMlFile.faces && newMlFile.faces.length > 0) { - await syncFileFaceCrops(fileContext); - - const alignedFacesData = await syncFileFaceAlignments(fileContext); - - await syncFileFaceEmbeddings(fileContext, alignedFacesData); - - await syncFileFaceMakeRelativeDetections(fileContext); + const imageBitmap = await fetchOrCreateImageBitmap(enteFile, localFile); + let mlFile: MlFileData; + try { + mlFile = await indexFaces_(enteFile, imageBitmap); + } finally { + imageBitmap.close(); } - log.debug( - () => - `Face detection for file ${fileContext.enteFile.id} took ${Math.round(Date.now() - startTime)} ms`, - ); -}; -const syncFileFaceDetections = async (fileContext: MLSyncFileContext) => { - const { newMlFile } = fileContext; - newMlFile.faceDetectionMethod = { - value: "YoloFace", - version: 1, - }; - fileContext.newDetection = true; - const imageBitmap = await fetchImageBitmapForContext(fileContext); - const faceDetections = await detectFaces(imageBitmap); - // TODO: reenable faces filtering based on width - const detectedFaces = faceDetections?.map((detection) => { - return { - fileId: fileContext.enteFile.id, - detection, - } as DetectedFace; + log.debug(() => { + const nf = mlFile.faces?.length ?? 0; + const ms = Date.now() - startTime; + return `Indexed ${nf} faces in file ${enteFile.id} (${ms} ms)`; }); - newMlFile.faces = detectedFaces?.map((detectedFace) => ({ - ...detectedFace, - id: getFaceId(detectedFace, newMlFile.imageDimensions), + return mlFile; +}; + +/** + * Return a {@link ImageBitmap}, using {@link localFile} if present otherwise + * downloading the source image corresponding to {@link enteFile} from remote. + */ +const fetchOrCreateImageBitmap = async ( + enteFile: EnteFile, + localFile: File, +) => { + const fileType = enteFile.metadata.fileType; + if (localFile) { + // TODO-ML(MR): Could also be image part of live photo? 
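+        // Note: the remote branch below does accept LIVE_PHOTO files, so
+        // only this local-file fast path is limited to plain images.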
+ if (fileType !== FILE_TYPE.IMAGE) + throw new Error("Local file of only image type is supported"); + + return await getLocalFileImageBitmap(enteFile, localFile); + } else if ([FILE_TYPE.IMAGE, FILE_TYPE.LIVE_PHOTO].includes(fileType)) { + return await fetchImageBitmap(enteFile); + } else { + throw new Error(`Cannot index unsupported file type ${fileType}`); + } +}; + +const indexFaces_ = async (enteFile: EnteFile, imageBitmap: ImageBitmap) => { + const fileID = enteFile.id; + const { width, height } = imageBitmap; + const imageDimensions = { width, height }; + const mlFile: MlFileData = { + fileId: fileID, + mlVersion: defaultMLVersion, + imageDimensions, + errorCount: 0, + }; + + const faceDetections = await detectFaces(imageBitmap); + const detectedFaces = faceDetections.map((detection) => ({ + id: makeFaceID(fileID, detection, imageDimensions), + fileId: fileID, + detection, })); - // ?.filter((f) => - // f.box.width > syncContext.config.faceDetection.minFaceSize - // ); - log.info("[MLService] Detected Faces: ", newMlFile.faces?.length); -}; + mlFile.faces = detectedFaces; -const syncFileFaceCrops = async (fileContext: MLSyncFileContext) => { - const { newMlFile } = fileContext; - const imageBitmap = await fetchImageBitmapForContext(fileContext); - newMlFile.faceCropMethod = { - value: "ArcFace", - version: 1, - }; + if (detectedFaces.length > 0) { + const alignments: FaceAlignment[] = []; - for (const face of newMlFile.faces) { - await saveFaceCrop(imageBitmap, face); - } -}; + for (const face of mlFile.faces) { + const alignment = faceAlignment(face.detection); + face.alignment = alignment; + alignments.push(alignment); -const syncFileFaceAlignments = async ( - fileContext: MLSyncFileContext, -): Promise => { - const { newMlFile } = fileContext; - newMlFile.faceAlignmentMethod = { - value: "ArcFace", - version: 1, - }; - fileContext.newAlignment = true; - const imageBitmap = - fileContext.imageBitmap || - (await fetchImageBitmapForContext(fileContext)); + await saveFaceCrop(imageBitmap, face); + } - // Execute the face alignment calculations - for (const face of newMlFile.faces) { - face.alignment = faceAlignment(face.detection); - } - // Extract face images and convert to Float32Array - const faceAlignments = newMlFile.faces.map((f) => f.alignment); - const faceImages = await extractFaceImagesToFloat32( - faceAlignments, - mobileFaceNetFaceSize, - imageBitmap, - ); - const blurValues = detectBlur(faceImages, newMlFile.faces); - newMlFile.faces.forEach((f, i) => (f.blurValue = blurValues[i])); - - imageBitmap.close(); - log.info("[MLService] alignedFaces: ", newMlFile.faces?.length); - - return faceImages; -}; - -const syncFileFaceEmbeddings = async ( - fileContext: MLSyncFileContext, - alignedFacesInput: Float32Array, -) => { - const { newMlFile } = fileContext; - newMlFile.faceEmbeddingMethod = { - value: "MobileFaceNet", - version: 2, - }; - // TODO: when not storing face crops, image will be needed to extract faces - // fileContext.imageBitmap || - // (await this.getImageBitmap(fileContext)); - - const embeddings = await faceEmbeddings(alignedFacesInput); - newMlFile.faces.forEach((f, i) => (f.embedding = embeddings[i])); - - log.info("[MLService] facesWithEmbeddings: ", newMlFile.faces.length); -}; - -const syncFileFaceMakeRelativeDetections = async ( - fileContext: MLSyncFileContext, -) => { - const { newMlFile } = fileContext; - for (let i = 0; i < newMlFile.faces.length; i++) { - const face = newMlFile.faces[i]; - if (face.detection.box.x + face.detection.box.width < 2) 
continue; // Skip if somehow already relative - face.detection = getRelativeDetection( - face.detection, - newMlFile.imageDimensions, + const alignedFacesData = convertToMobileFaceNetInput( + imageBitmap, + alignments, ); - } -}; -export const saveFaceCrop = async (imageBitmap: ImageBitmap, face: Face) => { - const faceCrop = getFaceCrop(imageBitmap, face.detection); + const blurValues = detectBlur(alignedFacesData, mlFile.faces); + mlFile.faces.forEach((f, i) => (f.blurValue = blurValues[i])); - const blob = await imageBitmapToBlob(faceCrop.image); + const embeddings = await computeEmbeddings(alignedFacesData); + mlFile.faces.forEach((f, i) => (f.embedding = embeddings[i])); - const cache = await openCache("face-crops"); - await cache.put(face.id, blob); - - faceCrop.image.close(); - - return blob; -}; - -export const regenerateFaceCrop = async (faceID: string) => { - const fileID = Number(faceID.split("-")[0]); - const personFace = await mlIDbStorage.getFace(fileID, faceID); - if (!personFace) { - throw Error("Face not found"); + mlFile.faces.forEach((face) => { + face.detection = relativeDetection(face.detection, imageDimensions); + }); } - const file = await getLocalFile(personFace.fileId); - const imageBitmap = await fetchImageBitmap(file); - return await saveFaceCrop(imageBitmap, personFace); + return mlFile; }; -async function extractFaceImagesToFloat32( - faceAlignments: Array, +/** + * Detect faces in the given {@link imageBitmap}. + * + * The model used is YOLO, running in an ONNX runtime. + */ +const detectFaces = async ( + imageBitmap: ImageBitmap, +): Promise => { + const rect = ({ width, height }: Dimensions) => + new Box({ x: 0, y: 0, width, height }); + + const { yoloInput, yoloSize } = + convertToYOLOInputFloat32ChannelsFirst(imageBitmap); + const yoloOutput = await workerBridge.detectFaces(yoloInput); + const faces = faceDetectionsFromYOLOOutput(yoloOutput); + const faceDetections = transformFaceDetections( + faces, + rect(yoloSize), + rect(imageBitmap), + ); + + const maxFaceDistancePercent = Math.sqrt(2) / 100; + const maxFaceDistance = imageBitmap.width * maxFaceDistancePercent; + return removeDuplicateDetections(faceDetections, maxFaceDistance); +}; + +/** + * Convert {@link imageBitmap} into the format that the YOLO face detection + * model expects. + */ +const convertToYOLOInputFloat32ChannelsFirst = (imageBitmap: ImageBitmap) => { + const requiredWidth = 640; + const requiredHeight = 640; + + const { width, height } = imageBitmap; + + // Create an OffscreenCanvas and set its size. + const offscreenCanvas = new OffscreenCanvas(width, height); + const ctx = offscreenCanvas.getContext("2d"); + ctx.drawImage(imageBitmap, 0, 0, width, height); + const imageData = ctx.getImageData(0, 0, width, height); + const pixelData = imageData.data; + + // Maintain aspect ratio. + const scale = Math.min(requiredWidth / width, requiredHeight / height); + + const scaledWidth = clamp(Math.round(width * scale), 0, requiredWidth); + const scaledHeight = clamp(Math.round(height * scale), 0, requiredHeight); + + const yoloInput = new Float32Array(1 * 3 * requiredWidth * requiredHeight); + const yoloSize = { width: scaledWidth, height: scaledHeight }; + + // Populate the Float32Array with normalized pixel values. 
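+    // The tensor layout is channels-first (1, 3, H, W): all red values
+    // first, then all green (offset H*W), then all blue (offset 2*H*W).
+    // Pixels outside the letterboxed image keep the (114, 114, 114) pad
+    // colour, and every value is normalized to [0, 1] by dividing by 255.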
+ let pi = 0; + const channelOffsetGreen = requiredHeight * requiredWidth; + const channelOffsetBlue = 2 * requiredHeight * requiredWidth; + for (let h = 0; h < requiredHeight; h++) { + for (let w = 0; w < requiredWidth; w++) { + const { r, g, b } = + w >= scaledWidth || h >= scaledHeight + ? { r: 114, g: 114, b: 114 } + : pixelRGBBilinear( + w / scale, + h / scale, + pixelData, + width, + height, + ); + yoloInput[pi] = r / 255.0; + yoloInput[pi + channelOffsetGreen] = g / 255.0; + yoloInput[pi + channelOffsetBlue] = b / 255.0; + pi++; + } + } + + return { yoloInput, yoloSize }; +}; + +/** + * Extract detected faces from the YOLO's output. + * + * Only detections that exceed a minimum score are returned. + * + * @param rows A Float32Array of shape [25200, 16], where each row + * represents a bounding box. + */ +const faceDetectionsFromYOLOOutput = (rows: Float32Array): FaceDetection[] => { + const faces: FaceDetection[] = []; + // Iterate over each row. + for (let i = 0; i < rows.length; i += 16) { + const score = rows[i + 4]; + if (score < 0.7) continue; + + const xCenter = rows[i]; + const yCenter = rows[i + 1]; + const width = rows[i + 2]; + const height = rows[i + 3]; + const xMin = xCenter - width / 2.0; // topLeft + const yMin = yCenter - height / 2.0; // topLeft + + const leftEyeX = rows[i + 5]; + const leftEyeY = rows[i + 6]; + const rightEyeX = rows[i + 7]; + const rightEyeY = rows[i + 8]; + const noseX = rows[i + 9]; + const noseY = rows[i + 10]; + const leftMouthX = rows[i + 11]; + const leftMouthY = rows[i + 12]; + const rightMouthX = rows[i + 13]; + const rightMouthY = rows[i + 14]; + + const box = new Box({ + x: xMin, + y: yMin, + width: width, + height: height, + }); + const probability = score as number; + const landmarks = [ + new Point(leftEyeX, leftEyeY), + new Point(rightEyeX, rightEyeY), + new Point(noseX, noseY), + new Point(leftMouthX, leftMouthY), + new Point(rightMouthX, rightMouthY), + ]; + faces.push({ box, landmarks, probability }); + } + return faces; +}; + +/** + * Removes duplicate face detections from an array of detections. + * + * This function sorts the detections by their probability in descending order, + * then iterates over them. + * + * For each detection, it calculates the Euclidean distance to all other + * detections. + * + * If the distance is less than or equal to the specified threshold + * (`withinDistance`), the other detection is considered a duplicate and is + * removed. + * + * @param detections - An array of face detections to remove duplicates from. + * + * @param withinDistance - The maximum Euclidean distance between two detections + * for them to be considered duplicates. + * + * @returns An array of face detections with duplicates removed. 
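+ *
+ * For example (a hypothetical call, not taken from the code above): with a
+ * `withinDistance` of 10, three detections with probabilities 0.9, 0.8 and
+ * 0.7 whose centers all lie within 5 px of each other collapse into just
+ * the 0.9 one.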
+ */
+const removeDuplicateDetections = (
+    detections: FaceDetection[],
+    withinDistance: number,
+) => {
+    detections.sort((a, b) => b.probability - a.probability);
+
+    const dupIndices = new Set<number>();
+    for (let i = 0; i < detections.length; i++) {
+        if (dupIndices.has(i)) continue;
+
+        for (let j = i + 1; j < detections.length; j++) {
+            if (dupIndices.has(j)) continue;
+
+            const centeri = faceDetectionCenter(detections[i]);
+            const centerj = faceDetectionCenter(detections[j]);
+            const dist = euclidean(
+                [centeri.x, centeri.y],
+                [centerj.x, centerj.y],
+            );
+
+            if (dist <= withinDistance) dupIndices.add(j);
+        }
+    }
+
+    return detections.filter((_, i) => !dupIndices.has(i));
+};
+
+const faceDetectionCenter = (detection: FaceDetection) => {
+    const center = new Point(0, 0);
+    // TODO-ML(LAURENS): Taking the first 4 landmarks only applies to
+    // BlazeFace; this needs to consider the eye, nose and mouth landmarks to
+    // compute the center.
+    detection.landmarks?.slice(0, 4).forEach((p) => {
+        center.x += p.x;
+        center.y += p.y;
+    });
+    return new Point(center.x / 4, center.y / 4);
+};
+
+const makeFaceID = (
+    fileID: number,
+    detection: FaceDetection,
+    imageDims: Dimensions,
+) => {
+    const part = (v: number) => clamp(v, 0.0, 0.999999).toFixed(5).substring(2);
+    const xMin = part(detection.box.x / imageDims.width);
+    const yMin = part(detection.box.y / imageDims.height);
+    const xMax = part(
+        (detection.box.x + detection.box.width) / imageDims.width,
+    );
+    const yMax = part(
+        (detection.box.y + detection.box.height) / imageDims.height,
+    );
+    return [`${fileID}`, xMin, yMin, xMax, yMax].join("_");
+};
+
+/**
+ * Compute and return an {@link FaceAlignment} for the given face detection.
+ *
+ * @param faceDetection A geometry indicating a face detected in an image.
+ */
+const faceAlignment = (faceDetection: FaceDetection): FaceAlignment =>
+    faceAlignmentUsingSimilarityTransform(
+        faceDetection,
+        normalizeLandmarks(idealMobileFaceNetLandmarks, mobileFaceNetFaceSize),
+    );
+
+/**
+ * The ideal location of the landmarks (eye etc) that the MobileFaceNet
+ * embedding model expects.
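+ *
+ * These are (x, y) coordinates within a 112 x 112 face crop, in the same
+ * order as the detection landmarks: left eye, right eye, nose, left mouth
+ * corner, right mouth corner.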
+ */ +const idealMobileFaceNetLandmarks: [number, number][] = [ + [38.2946, 51.6963], + [73.5318, 51.5014], + [56.0252, 71.7366], + [41.5493, 92.3655], + [70.7299, 92.2041], +]; + +const normalizeLandmarks = ( + landmarks: [number, number][], faceSize: number, - image: ImageBitmap, -): Promise { +): [number, number][] => + landmarks.map(([x, y]) => [x / faceSize, y / faceSize]); + +const faceAlignmentUsingSimilarityTransform = ( + faceDetection: FaceDetection, + alignedLandmarks: [number, number][], +): FaceAlignment => { + const landmarksMat = new Matrix( + faceDetection.landmarks + .map((p) => [p.x, p.y]) + .slice(0, alignedLandmarks.length), + ).transpose(); + const alignedLandmarksMat = new Matrix(alignedLandmarks).transpose(); + + const simTransform = getSimilarityTransformation( + landmarksMat, + alignedLandmarksMat, + ); + + const RS = Matrix.mul(simTransform.rotation, simTransform.scale); + const TR = simTransform.translation; + + const affineMatrix = [ + [RS.get(0, 0), RS.get(0, 1), TR.get(0, 0)], + [RS.get(1, 0), RS.get(1, 1), TR.get(1, 0)], + [0, 0, 1], + ]; + + const size = 1 / simTransform.scale; + const meanTranslation = simTransform.toMean.sub(0.5).mul(size); + const centerMat = simTransform.fromMean.sub(meanTranslation); + const center = new Point(centerMat.get(0, 0), centerMat.get(1, 0)); + const rotation = -Math.atan2( + simTransform.rotation.get(0, 1), + simTransform.rotation.get(0, 0), + ); + + return { affineMatrix, center, size, rotation }; +}; + +const convertToMobileFaceNetInput = ( + imageBitmap: ImageBitmap, + faceAlignments: FaceAlignment[], +): Float32Array => { + const faceSize = mobileFaceNetFaceSize; const faceData = new Float32Array( faceAlignments.length * faceSize * faceSize * 3, ); for (let i = 0; i < faceAlignments.length; i++) { - const alignedFace = faceAlignments[i]; + const { affineMatrix } = faceAlignments[i]; const faceDataOffset = i * faceSize * faceSize * 3; warpAffineFloat32List( - image, - alignedFace, + imageBitmap, + affineMatrix, faceSize, faceData, faceDataOffset, ); } return faceData; -} +}; + +/** + * Laplacian blur detection. + * + * Return an array of detected blur values, one for each face in {@link faces}. + * The face data is taken from the slice of {@link alignedFacesData} + * corresponding to each face of {@link faces}. 
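+ *
+ * Higher values indicate sharper faces. The number returned is the variance
+ * of the Laplacian response, so callers would typically compare it against
+ * some empirically chosen threshold; no such cutoff is applied here.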
+ */ +const detectBlur = (alignedFacesData: Float32Array, faces: Face[]): number[] => + faces.map((face, i) => { + const faceImage = grayscaleIntMatrixFromNormalized2List( + alignedFacesData, + i, + mobileFaceNetFaceSize, + mobileFaceNetFaceSize, + ); + return matrixVariance(applyLaplacian(faceImage, faceDirection(face))); + }); + +type FaceDirection = "left" | "right" | "straight"; + +const faceDirection = (face: Face): FaceDirection => { + const landmarks = face.detection.landmarks; + const leftEye = landmarks[0]; + const rightEye = landmarks[1]; + const nose = landmarks[2]; + const leftMouth = landmarks[3]; + const rightMouth = landmarks[4]; + + const eyeDistanceX = Math.abs(rightEye.x - leftEye.x); + const eyeDistanceY = Math.abs(rightEye.y - leftEye.y); + const mouthDistanceY = Math.abs(rightMouth.y - leftMouth.y); + + const faceIsUpright = + Math.max(leftEye.y, rightEye.y) + 0.5 * eyeDistanceY < nose.y && + nose.y + 0.5 * mouthDistanceY < Math.min(leftMouth.y, rightMouth.y); + + const noseStickingOutLeft = + nose.x < Math.min(leftEye.x, rightEye.x) && + nose.x < Math.min(leftMouth.x, rightMouth.x); + + const noseStickingOutRight = + nose.x > Math.max(leftEye.x, rightEye.x) && + nose.x > Math.max(leftMouth.x, rightMouth.x); + + const noseCloseToLeftEye = + Math.abs(nose.x - leftEye.x) < 0.2 * eyeDistanceX; + const noseCloseToRightEye = + Math.abs(nose.x - rightEye.x) < 0.2 * eyeDistanceX; + + if (noseStickingOutLeft || (faceIsUpright && noseCloseToLeftEye)) { + return "left"; + } else if (noseStickingOutRight || (faceIsUpright && noseCloseToRightEye)) { + return "right"; + } + + return "straight"; +}; + +/** + * Return a new image by applying a Laplacian blur kernel to each pixel. + */ +const applyLaplacian = ( + image: number[][], + direction: FaceDirection, +): number[][] => { + const paddedImage = padImage(image, direction); + const numRows = paddedImage.length - 2; + const numCols = paddedImage[0].length - 2; + + // Create an output image initialized to 0. + const outputImage: number[][] = Array.from({ length: numRows }, () => + new Array(numCols).fill(0), + ); + + // Define the Laplacian kernel. + const kernel = [ + [0, 1, 0], + [1, -4, 1], + [0, 1, 0], + ]; + + // Apply the kernel to each pixel + for (let i = 0; i < numRows; i++) { + for (let j = 0; j < numCols; j++) { + let sum = 0; + for (let ki = 0; ki < 3; ki++) { + for (let kj = 0; kj < 3; kj++) { + sum += paddedImage[i + ki][j + kj] * kernel[ki][kj]; + } + } + // Adjust the output value if necessary (e.g., clipping). + outputImage[i][j] = sum; + } + } + + return outputImage; +}; + +const padImage = (image: number[][], direction: FaceDirection): number[][] => { + const removeSideColumns = 56; /* must be even */ + + const numRows = image.length; + const numCols = image[0].length; + const paddedNumCols = numCols + 2 - removeSideColumns; + const paddedNumRows = numRows + 2; + + // Create a new matrix with extra padding. + const paddedImage: number[][] = Array.from({ length: paddedNumRows }, () => + new Array(paddedNumCols).fill(0), + ); + + if (direction === "straight") { + // Copy original image into the center of the padded image. + for (let i = 0; i < numRows; i++) { + for (let j = 0; j < paddedNumCols - 2; j++) { + paddedImage[i + 1][j + 1] = + image[i][j + Math.round(removeSideColumns / 2)]; + } + } + } else if (direction === "left") { + // If the face is facing left, we only take the right side of the face + // image. 
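+        // Shifting the copy window by `removeSideColumns` keeps the
+        // Laplacian statistics focused on the visible half of the face.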
+        for (let i = 0; i < numRows; i++) {
+            for (let j = 0; j < paddedNumCols - 2; j++) {
+                paddedImage[i + 1][j + 1] = image[i][j + removeSideColumns];
+            }
+        }
+    } else if (direction === "right") {
+        // If the face is facing right, we only take the left side of the face
+        // image.
+        for (let i = 0; i < numRows; i++) {
+            for (let j = 0; j < paddedNumCols - 2; j++) {
+                paddedImage[i + 1][j + 1] = image[i][j];
+            }
+        }
+    }
+
+    // Reflect padding
+    // - Top and bottom rows
+    for (let j = 1; j <= paddedNumCols - 2; j++) {
+        // Top row
+        paddedImage[0][j] = paddedImage[2][j];
+        // Bottom row
+        paddedImage[numRows + 1][j] = paddedImage[numRows - 1][j];
+    }
+    // - Left and right columns
+    for (let i = 0; i < numRows + 2; i++) {
+        // Left column
+        paddedImage[i][0] = paddedImage[i][2];
+        // Right column
+        paddedImage[i][paddedNumCols - 1] = paddedImage[i][paddedNumCols - 3];
+    }
+
+    return paddedImage;
+};
+
+const matrixVariance = (matrix: number[][]): number => {
+    const numRows = matrix.length;
+    const numCols = matrix[0].length;
+    const totalElements = numRows * numCols;
+
+    // Calculate the mean.
+    let mean: number = 0;
+    matrix.forEach((row) => {
+        row.forEach((value) => {
+            mean += value;
+        });
+    });
+    mean /= totalElements;
+
+    // Calculate the variance.
+    let variance: number = 0;
+    matrix.forEach((row) => {
+        row.forEach((value) => {
+            const diff: number = value - mean;
+            variance += diff * diff;
+        });
+    });
+    variance /= totalElements;
+
+    return variance;
+};
+
+const mobileFaceNetFaceSize = 112;
+const mobileFaceNetEmbeddingSize = 192;
+
+/**
+ * Compute embeddings for the given {@link faceData}.
+ *
+ * The model used is MobileFaceNet, running in an ONNX runtime.
+ */
+const computeEmbeddings = async (
+    faceData: Float32Array,
+): Promise<Float32Array[]> => {
+    const outputData = await workerBridge.computeFaceEmbeddings(faceData);
+
+    const embeddingSize = mobileFaceNetEmbeddingSize;
+    const embeddings = new Array<Float32Array>(
+        outputData.length / embeddingSize,
+    );
+    for (let i = 0; i < embeddings.length; i++) {
+        embeddings[i] = new Float32Array(
+            outputData.slice(i * embeddingSize, (i + 1) * embeddingSize),
+        );
+    }
+    return embeddings;
+};
+
+/**
+ * Convert the detection's coordinates to the 0-1 range, normalized by the
+ * image's dimensions.
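+ *
+ * For example, a 100 x 50 box at (200, 100) in a 1000 x 500 pixel image
+ * becomes a 0.1 x 0.1 box at (0.2, 0.2).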
+ */ +const relativeDetection = ( + faceDetection: FaceDetection, + { width, height }: Dimensions, +): FaceDetection => { + const oldBox: Box = faceDetection.box; + const box = new Box({ + x: oldBox.x / width, + y: oldBox.y / height, + width: oldBox.width / width, + height: oldBox.height / height, + }); + const landmarks = faceDetection.landmarks.map((l) => { + return new Point(l.x / width, l.y / height); + }); + const probability = faceDetection.probability; + return { box, landmarks, probability }; +}; + +export const saveFaceCrop = async (imageBitmap: ImageBitmap, face: Face) => { + const faceCrop = extractFaceCrop(imageBitmap, face.alignment); + const blob = await imageBitmapToBlob(faceCrop); + faceCrop.close(); + + const cache = await blobCache("face-crops"); + await cache.put(face.id, blob); + + return blob; +}; + +const imageBitmapToBlob = (imageBitmap: ImageBitmap) => { + const canvas = new OffscreenCanvas(imageBitmap.width, imageBitmap.height); + canvas.getContext("2d").drawImage(imageBitmap, 0, 0); + return canvas.convertToBlob({ type: "image/jpeg", quality: 0.8 }); +}; + +const extractFaceCrop = ( + imageBitmap: ImageBitmap, + alignment: FaceAlignment, +): ImageBitmap => { + const alignmentBox = new Box({ + x: alignment.center.x - alignment.size / 2, + y: alignment.center.y - alignment.size / 2, + width: alignment.size, + height: alignment.size, + }); + + const padding = 0.25; + const scaleForPadding = 1 + padding * 2; + const paddedBox = roundBox(enlargeBox(alignmentBox, scaleForPadding)); + + // TODO-ML(LAURENS): The rotation doesn't seem to be used? it's set to 0. + return cropWithRotation(imageBitmap, paddedBox, 0, 256); +}; + +const cropWithRotation = ( + imageBitmap: ImageBitmap, + cropBox: Box, + rotation: number, + maxDimension: number, +) => { + const box = roundBox(cropBox); + + const outputSize = { width: box.width, height: box.height }; + + const scale = Math.min(maxDimension / box.width, maxDimension / box.height); + if (scale < 1) { + outputSize.width = Math.round(scale * box.width); + outputSize.height = Math.round(scale * box.height); + } + + const offscreen = new OffscreenCanvas(outputSize.width, outputSize.height); + const offscreenCtx = offscreen.getContext("2d"); + offscreenCtx.imageSmoothingQuality = "high"; + + offscreenCtx.translate(outputSize.width / 2, outputSize.height / 2); + rotation && offscreenCtx.rotate(rotation); + + const outputBox = new Box({ + x: -outputSize.width / 2, + y: -outputSize.height / 2, + width: outputSize.width, + height: outputSize.height, + }); + + const enlargedBox = enlargeBox(box, 1.5); + const enlargedOutputBox = enlargeBox(outputBox, 1.5); + + offscreenCtx.drawImage( + imageBitmap, + enlargedBox.x, + enlargedBox.y, + enlargedBox.width, + enlargedBox.height, + enlargedOutputBox.x, + enlargedOutputBox.y, + enlargedOutputBox.width, + enlargedOutputBox.height, + ); + + return offscreen.transferToImageBitmap(); +}; diff --git a/web/apps/photos/src/services/face/face.worker.ts b/web/apps/photos/src/services/face/face.worker.ts index 8083406bf..0ba2233e7 100644 --- a/web/apps/photos/src/services/face/face.worker.ts +++ b/web/apps/photos/src/services/face/face.worker.ts @@ -12,20 +12,16 @@ export class DedicatedMLWorker { public async syncLocalFile( token: string, userID: number, + userAgent: string, enteFile: EnteFile, localFile: globalThis.File, ) { - mlService.syncLocalFile(token, userID, enteFile, localFile); + mlService.syncLocalFile(token, userID, userAgent, enteFile, localFile); } - public async sync(token: string, userID: 
number) { + public async sync(token: string, userID: number, userAgent: string) { await downloadManager.init(APPS.PHOTOS, { token }); - return mlService.sync(token, userID); - } - - public async regenerateFaceCrop(token: string, faceID: string) { - await downloadManager.init(APPS.PHOTOS, { token }); - return mlService.regenerateFaceCrop(faceID); + return mlService.sync(token, userID, userAgent); } } diff --git a/web/apps/photos/src/services/face/file.ts b/web/apps/photos/src/services/face/file.ts new file mode 100644 index 000000000..b482af3fb --- /dev/null +++ b/web/apps/photos/src/services/face/file.ts @@ -0,0 +1,37 @@ +import { FILE_TYPE } from "@/media/file-type"; +import { decodeLivePhoto } from "@/media/live-photo"; +import DownloadManager from "services/download"; +import { getLocalFiles } from "services/fileService"; +import { EnteFile } from "types/file"; +import { getRenderableImage } from "utils/file"; + +export async function getLocalFile(fileId: number) { + const localFiles = await getLocalFiles(); + return localFiles.find((f) => f.id === fileId); +} + +export const fetchImageBitmap = async (file: EnteFile) => + fetchRenderableBlob(file).then(createImageBitmap); + +async function fetchRenderableBlob(file: EnteFile) { + const fileStream = await DownloadManager.getFile(file); + const fileBlob = await new Response(fileStream).blob(); + if (file.metadata.fileType === FILE_TYPE.IMAGE) { + return await getRenderableImage(file.metadata.title, fileBlob); + } else { + const { imageFileName, imageData } = await decodeLivePhoto( + file.metadata.title, + fileBlob, + ); + return await getRenderableImage(imageFileName, new Blob([imageData])); + } +} + +export async function getLocalFileImageBitmap( + enteFile: EnteFile, + localFile: globalThis.File, +) { + let fileBlob = localFile as Blob; + fileBlob = await getRenderableImage(enteFile.metadata.title, fileBlob); + return createImageBitmap(fileBlob); +} diff --git a/web/apps/photos/src/services/face/geom.ts b/web/apps/photos/src/services/face/geom.ts index 556e2b309..5f6456ca6 100644 --- a/web/apps/photos/src/services/face/geom.ts +++ b/web/apps/photos/src/services/face/geom.ts @@ -13,13 +13,6 @@ export interface Dimensions { height: number; } -export interface IBoundingBox { - left: number; - top: number; - right: number; - bottom: number; -} - export interface IRect { x: number; y: number; @@ -27,24 +20,6 @@ export interface IRect { height: number; } -export function newBox(x: number, y: number, width: number, height: number) { - return new Box({ x, y, width, height }); -} - -export const boxFromBoundingBox = ({ - left, - top, - right, - bottom, -}: IBoundingBox) => { - return new Box({ - x: left, - y: top, - width: right - left, - height: bottom - top, - }); -}; - export class Box implements IRect { public x: number; public y: number; @@ -57,36 +32,26 @@ export class Box implements IRect { this.width = width; this.height = height; } - - public get topLeft(): Point { - return new Point(this.x, this.y); - } - - public get bottomRight(): Point { - return new Point(this.x + this.width, this.y + this.height); - } - - public round(): Box { - const [x, y, width, height] = [ - this.x, - this.y, - this.width, - this.height, - ].map((val) => Math.round(val)); - return new Box({ x, y, width, height }); - } } -export function enlargeBox(box: Box, factor: number = 1.5) { +/** Round all the components of the box. 
*/ +export const roundBox = (box: Box): Box => { + const [x, y, width, height] = [box.x, box.y, box.width, box.height].map( + (val) => Math.round(val), + ); + return new Box({ x, y, width, height }); +}; + +/** Increase the size of the given {@link box} by {@link factor}. */ +export const enlargeBox = (box: Box, factor: number) => { const center = new Point(box.x + box.width / 2, box.y + box.height / 2); + const newWidth = factor * box.width; + const newHeight = factor * box.height; - const size = new Point(box.width, box.height); - const newHalfSize = new Point((factor * size.x) / 2, (factor * size.y) / 2); - - return boxFromBoundingBox({ - left: center.x - newHalfSize.x, - top: center.y - newHalfSize.y, - right: center.x + newHalfSize.x, - bottom: center.y + newHalfSize.y, + return new Box({ + x: center.x - newWidth / 2, + y: center.y - newHeight / 2, + width: newWidth, + height: newHeight, }); -} +}; diff --git a/web/apps/photos/src/services/face/image.ts b/web/apps/photos/src/services/face/image.ts index 1ddcc70f6..12f49db54 100644 --- a/web/apps/photos/src/services/face/image.ts +++ b/web/apps/photos/src/services/face/image.ts @@ -1,121 +1,295 @@ -import { FILE_TYPE } from "@/media/file-type"; -import { decodeLivePhoto } from "@/media/live-photo"; -import log from "@/next/log"; -import DownloadManager from "services/download"; -import { Dimensions } from "services/face/geom"; -import { DetectedFace, MLSyncFileContext } from "services/face/types"; -import { getLocalFiles } from "services/fileService"; -import { EnteFile } from "types/file"; -import { getRenderableImage } from "utils/file"; -import { clamp } from "utils/image"; +import { Matrix, inverse } from "ml-matrix"; -export const fetchImageBitmapForContext = async ( - fileContext: MLSyncFileContext, +/** + * Clamp {@link value} to between {@link min} and {@link max}, inclusive. + */ +export const clamp = (value: number, min: number, max: number) => + Math.min(max, Math.max(min, value)); + +/** + * Returns the pixel value (RGB) at the given coordinates ({@link fx}, + * {@link fy}) using bilinear interpolation. + */ +export function pixelRGBBilinear( + fx: number, + fy: number, + imageData: Uint8ClampedArray, + imageWidth: number, + imageHeight: number, +) { + // Clamp to image boundaries. + fx = clamp(fx, 0, imageWidth - 1); + fy = clamp(fy, 0, imageHeight - 1); + + // Get the surrounding coordinates and their weights. + const x0 = Math.floor(fx); + const x1 = Math.ceil(fx); + const y0 = Math.floor(fy); + const y1 = Math.ceil(fy); + const dx = fx - x0; + const dy = fy - y0; + const dx1 = 1.0 - dx; + const dy1 = 1.0 - dy; + + // Get the original pixels. + const pixel1 = pixelRGBA(imageData, imageWidth, imageHeight, x0, y0); + const pixel2 = pixelRGBA(imageData, imageWidth, imageHeight, x1, y0); + const pixel3 = pixelRGBA(imageData, imageWidth, imageHeight, x0, y1); + const pixel4 = pixelRGBA(imageData, imageWidth, imageHeight, x1, y1); + + const bilinear = (val1: number, val2: number, val3: number, val4: number) => + Math.round( + val1 * dx1 * dy1 + + val2 * dx * dy1 + + val3 * dx1 * dy + + val4 * dx * dy, + ); + + // Return interpolated pixel colors. 
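+    //
+    // Each channel is the weighted average of the four neighbouring pixels,
+    // with pixel1..pixel4 at (x0, y0), (x1, y0), (x0, y1), (x1, y1):
+    //
+    //     c(fx, fy) = c1·(1-dx)(1-dy) + c2·dx(1-dy) + c3·(1-dx)dy + c4·dx·dy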
+ return { + r: bilinear(pixel1.r, pixel2.r, pixel3.r, pixel4.r), + g: bilinear(pixel1.g, pixel2.g, pixel3.g, pixel4.g), + b: bilinear(pixel1.b, pixel2.b, pixel3.b, pixel4.b), + }; +} + +const pixelRGBA = ( + imageData: Uint8ClampedArray, + width: number, + height: number, + x: number, + y: number, ) => { - if (fileContext.imageBitmap) { - return fileContext.imageBitmap; + if (x < 0 || x >= width || y < 0 || y >= height) { + return { r: 0, g: 0, b: 0, a: 0 }; } - if (fileContext.localFile) { - if (fileContext.enteFile.metadata.fileType !== FILE_TYPE.IMAGE) { - throw new Error("Local file of only image type is supported"); - } - fileContext.imageBitmap = await getLocalFileImageBitmap( - fileContext.enteFile, - fileContext.localFile, - ); - } else if ( - [FILE_TYPE.IMAGE, FILE_TYPE.LIVE_PHOTO].includes( - fileContext.enteFile.metadata.fileType, - ) - ) { - fileContext.imageBitmap = await fetchImageBitmap(fileContext.enteFile); - } else { - // TODO-ML(MR): We don't do it on videos, when will we ever come - // here? - fileContext.imageBitmap = await getThumbnailImageBitmap( - fileContext.enteFile, - ); - } - - fileContext.newMlFile.imageSource = "Original"; - const { width, height } = fileContext.imageBitmap; - fileContext.newMlFile.imageDimensions = { width, height }; - - return fileContext.imageBitmap; + const index = (y * width + x) * 4; + return { + r: imageData[index], + g: imageData[index + 1], + b: imageData[index + 2], + a: imageData[index + 3], + }; }; -export async function getLocalFile(fileId: number) { - const localFiles = await getLocalFiles(); - return localFiles.find((f) => f.id === fileId); -} +/** + * Returns the pixel value (RGB) at the given coordinates ({@link fx}, + * {@link fy}) using bicubic interpolation. + */ +const pixelRGBBicubic = ( + fx: number, + fy: number, + imageData: Uint8ClampedArray, + imageWidth: number, + imageHeight: number, +) => { + // Clamp to image boundaries. + fx = clamp(fx, 0, imageWidth - 1); + fy = clamp(fy, 0, imageHeight - 1); -export function getFaceId(detectedFace: DetectedFace, imageDims: Dimensions) { - const xMin = clamp( - detectedFace.detection.box.x / imageDims.width, - 0.0, - 0.999999, - ) - .toFixed(5) - .substring(2); - const yMin = clamp( - detectedFace.detection.box.y / imageDims.height, - 0.0, - 0.999999, - ) - .toFixed(5) - .substring(2); - const xMax = clamp( - (detectedFace.detection.box.x + detectedFace.detection.box.width) / - imageDims.width, - 0.0, - 0.999999, - ) - .toFixed(5) - .substring(2); - const yMax = clamp( - (detectedFace.detection.box.y + detectedFace.detection.box.height) / - imageDims.height, - 0.0, - 0.999999, - ) - .toFixed(5) - .substring(2); + const x = Math.trunc(fx) - (fx >= 0.0 ? 0 : 1); + const px = x - 1; + const nx = x + 1; + const ax = x + 2; + const y = Math.trunc(fy) - (fy >= 0.0 ? 0 : 1); + const py = y - 1; + const ny = y + 1; + const ay = y + 2; + const dx = fx - x; + const dy = fy - y; - const rawFaceID = `${xMin}_${yMin}_${xMax}_${yMax}`; - const faceID = `${detectedFace.fileId}_${rawFaceID}`; + const cubic = ( + dx: number, + ipp: number, + icp: number, + inp: number, + iap: number, + ) => + icp + + 0.5 * + (dx * (-ipp + inp) + + dx * dx * (2 * ipp - 5 * icp + 4 * inp - iap) + + dx * dx * dx * (-ipp + 3 * icp - 3 * inp + iap)); - return faceID; -} + const icc = pixelRGBA(imageData, imageWidth, imageHeight, x, y); -export const fetchImageBitmap = async (file: EnteFile) => - fetchRenderableBlob(file).then(createImageBitmap); + const ipp = + px < 0 || py < 0 + ? 
icc + : pixelRGBA(imageData, imageWidth, imageHeight, px, py); + const icp = + px < 0 ? icc : pixelRGBA(imageData, imageWidth, imageHeight, x, py); + const inp = + py < 0 || nx >= imageWidth + ? icc + : pixelRGBA(imageData, imageWidth, imageHeight, nx, py); + const iap = + ax >= imageWidth || py < 0 + ? icc + : pixelRGBA(imageData, imageWidth, imageHeight, ax, py); -async function fetchRenderableBlob(file: EnteFile) { - const fileStream = await DownloadManager.getFile(file); - const fileBlob = await new Response(fileStream).blob(); - if (file.metadata.fileType === FILE_TYPE.IMAGE) { - return await getRenderableImage(file.metadata.title, fileBlob); - } else { - const { imageFileName, imageData } = await decodeLivePhoto( - file.metadata.title, - fileBlob, - ); - return await getRenderableImage(imageFileName, new Blob([imageData])); + const ip0 = cubic(dx, ipp.r, icp.r, inp.r, iap.r); + const ip1 = cubic(dx, ipp.g, icp.g, inp.g, iap.g); + const ip2 = cubic(dx, ipp.b, icp.b, inp.b, iap.b); + // const ip3 = cubic(dx, ipp.a, icp.a, inp.a, iap.a); + + const ipc = + px < 0 ? icc : pixelRGBA(imageData, imageWidth, imageHeight, px, y); + const inc = + nx >= imageWidth + ? icc + : pixelRGBA(imageData, imageWidth, imageHeight, nx, y); + const iac = + ax >= imageWidth + ? icc + : pixelRGBA(imageData, imageWidth, imageHeight, ax, y); + + const ic0 = cubic(dx, ipc.r, icc.r, inc.r, iac.r); + const ic1 = cubic(dx, ipc.g, icc.g, inc.g, iac.g); + const ic2 = cubic(dx, ipc.b, icc.b, inc.b, iac.b); + // const ic3 = cubic(dx, ipc.a, icc.a, inc.a, iac.a); + + const ipn = + px < 0 || ny >= imageHeight + ? icc + : pixelRGBA(imageData, imageWidth, imageHeight, px, ny); + const icn = + ny >= imageHeight + ? icc + : pixelRGBA(imageData, imageWidth, imageHeight, x, ny); + const inn = + nx >= imageWidth || ny >= imageHeight + ? icc + : pixelRGBA(imageData, imageWidth, imageHeight, nx, ny); + const ian = + ax >= imageWidth || ny >= imageHeight + ? icc + : pixelRGBA(imageData, imageWidth, imageHeight, ax, ny); + + const in0 = cubic(dx, ipn.r, icn.r, inn.r, ian.r); + const in1 = cubic(dx, ipn.g, icn.g, inn.g, ian.g); + const in2 = cubic(dx, ipn.b, icn.b, inn.b, ian.b); + // const in3 = cubic(dx, ipn.a, icn.a, inn.a, ian.a); + + const ipa = + px < 0 || ay >= imageHeight + ? icc + : pixelRGBA(imageData, imageWidth, imageHeight, px, ay); + const ica = + ay >= imageHeight + ? icc + : pixelRGBA(imageData, imageWidth, imageHeight, x, ay); + const ina = + nx >= imageWidth || ay >= imageHeight + ? icc + : pixelRGBA(imageData, imageWidth, imageHeight, nx, ay); + const iaa = + ax >= imageWidth || ay >= imageHeight + ? icc + : pixelRGBA(imageData, imageWidth, imageHeight, ax, ay); + + const ia0 = cubic(dx, ipa.r, ica.r, ina.r, iaa.r); + const ia1 = cubic(dx, ipa.g, ica.g, ina.g, iaa.g); + const ia2 = cubic(dx, ipa.b, ica.b, ina.b, iaa.b); + // const ia3 = cubic(dx, ipa.a, ica.a, ina.a, iaa.a); + + const c0 = Math.trunc(clamp(cubic(dy, ip0, ic0, in0, ia0), 0, 255)); + const c1 = Math.trunc(clamp(cubic(dy, ip1, ic1, in1, ia1), 0, 255)); + const c2 = Math.trunc(clamp(cubic(dy, ip2, ic2, in2, ia2), 0, 255)); + // const c3 = cubic(dy, ip3, ic3, in3, ia3); + + return { r: c0, g: c1, b: c2 }; +}; + +/** + * Transform {@link inputData} starting at {@link inputStartIndex}. + */ +export const warpAffineFloat32List = ( + imageBitmap: ImageBitmap, + faceAlignmentAffineMatrix: number[][], + faceSize: number, + inputData: Float32Array, + inputStartIndex: number, +): void => { + const { width, height } = imageBitmap; + + // Get the pixel data. 
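+    // (The 2D canvas here is only a way to read back the ImageBitmap's raw
+    // RGBA bytes. The warp itself is done manually below: we iterate over the
+    // destination pixels and map each one back through the inverse of the
+    // affine matrix, which avoids the holes a forward mapping would leave.)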
+ const offscreenCanvas = new OffscreenCanvas(width, height); + const ctx = offscreenCanvas.getContext("2d"); + ctx.drawImage(imageBitmap, 0, 0, width, height); + const imageData = ctx.getImageData(0, 0, width, height); + const pixelData = imageData.data; + + const transformationMatrix = faceAlignmentAffineMatrix.map((row) => + row.map((val) => (val != 1.0 ? val * faceSize : 1.0)), + ); // 3x3 + + const A: Matrix = new Matrix([ + [transformationMatrix[0][0], transformationMatrix[0][1]], + [transformationMatrix[1][0], transformationMatrix[1][1]], + ]); + const Ainverse = inverse(A); + + const b00 = transformationMatrix[0][2]; + const b10 = transformationMatrix[1][2]; + const a00Prime = Ainverse.get(0, 0); + const a01Prime = Ainverse.get(0, 1); + const a10Prime = Ainverse.get(1, 0); + const a11Prime = Ainverse.get(1, 1); + + for (let yTrans = 0; yTrans < faceSize; ++yTrans) { + for (let xTrans = 0; xTrans < faceSize; ++xTrans) { + // Perform inverse affine transformation. + const xOrigin = + a00Prime * (xTrans - b00) + a01Prime * (yTrans - b10); + const yOrigin = + a10Prime * (xTrans - b00) + a11Prime * (yTrans - b10); + + // Get the pixel RGB using bicubic interpolation. + const { r, g, b } = pixelRGBBicubic( + xOrigin, + yOrigin, + pixelData, + width, + height, + ); + + // Set the pixel in the input data. + const index = (yTrans * faceSize + xTrans) * 3; + inputData[inputStartIndex + index] = rgbToBipolarFloat(r); + inputData[inputStartIndex + index + 1] = rgbToBipolarFloat(g); + inputData[inputStartIndex + index + 2] = rgbToBipolarFloat(b); + } } -} +}; -export async function getThumbnailImageBitmap(file: EnteFile) { - const thumb = await DownloadManager.getThumbnail(file); - log.info("[MLService] Got thumbnail: ", file.id.toString()); +/** Convert a RGB component 0-255 to a floating point value between -1 and 1. */ +const rgbToBipolarFloat = (pixelValue: number) => pixelValue / 127.5 - 1.0; - return createImageBitmap(new Blob([thumb])); -} +/** Convert a floating point value between -1 and 1 to a RGB component 0-255. 
*/ +const bipolarFloatToRGB = (pixelValue: number) => + clamp(Math.round((pixelValue + 1.0) * 127.5), 0, 255); -export async function getLocalFileImageBitmap( - enteFile: EnteFile, - localFile: globalThis.File, -) { - let fileBlob = localFile as Blob; - fileBlob = await getRenderableImage(enteFile.metadata.title, fileBlob); - return createImageBitmap(fileBlob); -} +export const grayscaleIntMatrixFromNormalized2List = ( + imageList: Float32Array, + faceNumber: number, + width: number, + height: number, +): number[][] => { + const startIndex = faceNumber * width * height * 3; + return Array.from({ length: height }, (_, y) => + Array.from({ length: width }, (_, x) => { + // 0.299 ∙ Red + 0.587 ∙ Green + 0.114 ∙ Blue + const pixelIndex = startIndex + 3 * (y * width + x); + return clamp( + Math.round( + 0.299 * bipolarFloatToRGB(imageList[pixelIndex]) + + 0.587 * bipolarFloatToRGB(imageList[pixelIndex + 1]) + + 0.114 * bipolarFloatToRGB(imageList[pixelIndex + 2]), + ), + 0, + 255, + ); + }), + ); +}; diff --git a/web/apps/photos/src/services/face/people.ts b/web/apps/photos/src/services/face/people.ts index 416ba9e4e..d118cb4f9 100644 --- a/web/apps/photos/src/services/face/people.ts +++ b/web/apps/photos/src/services/face/people.ts @@ -1,37 +1,54 @@ -import log from "@/next/log"; -import mlIDbStorage from "services/face/db"; -import { Face, Person } from "services/face/types"; -import { type MLSyncContext } from "services/machineLearning/machineLearningService"; -import { clusterFaces } from "./cluster"; -import { saveFaceCrop } from "./f-index"; -import { fetchImageBitmap, getLocalFile } from "./image"; +export interface Person { + id: number; + name?: string; + files: Array; + displayFaceId?: string; +} + +// TODO-ML(MR): Forced disable clustering. It doesn't currently work, +// need to finalize it before we move out of beta. +// +// > Error: Failed to execute 'transferToImageBitmap' on +// > 'OffscreenCanvas': ImageBitmap construction failed + +/* +export const syncPeopleIndex = async () => { + + if ( + syncContext.outOfSyncFiles.length <= 0 || + (syncContext.nSyncedFiles === batchSize && Math.random() < 0) + ) { + await this.syncIndex(syncContext); + } + + public async syncIndex(syncContext: MLSyncContext) { + await this.getMLLibraryData(syncContext); + + await syncPeopleIndex(syncContext); + + await this.persistMLLibraryData(syncContext); + } -export const syncPeopleIndex = async (syncContext: MLSyncContext) => { const filesVersion = await mlIDbStorage.getIndexVersion("files"); if (filesVersion <= (await mlIDbStorage.getIndexVersion("people"))) { return; } + // TODO: have faces addresable through fileId + faceId // to avoid index based addressing, which is prone to wrong results // one way could be to match nearest face within threshold in the file + const allFacesMap = syncContext.allSyncedFacesMap ?? 
(syncContext.allSyncedFacesMap = await mlIDbStorage.getAllFacesMap()); - const allFaces = [...allFacesMap.values()].flat(); - await runFaceClustering(syncContext, allFaces); - await syncPeopleFromClusters(syncContext, allFacesMap, allFaces); - await mlIDbStorage.setIndexVersion("people", filesVersion); -}; - -const runFaceClustering = async ( - syncContext: MLSyncContext, - allFaces: Array, -) => { // await this.init(); + const allFacesMap = await mlIDbStorage.getAllFacesMap(); + const allFaces = [...allFacesMap.values()].flat(); + if (!allFaces || allFaces.length < 50) { log.info( `Skipping clustering since number of faces (${allFaces.length}) is less than the clustering threshold (50)`, @@ -40,34 +57,15 @@ const runFaceClustering = async ( } log.info("Running clustering allFaces: ", allFaces.length); - syncContext.mlLibraryData.faceClusteringResults = await clusterFaces( + const faceClusteringResults = await clusterFaces( allFaces.map((f) => Array.from(f.embedding)), ); - syncContext.mlLibraryData.faceClusteringMethod = { - value: "Hdbscan", - version: 1, - }; log.info( "[MLService] Got face clustering results: ", - JSON.stringify(syncContext.mlLibraryData.faceClusteringResults), + JSON.stringify(faceClusteringResults), ); - // syncContext.faceClustersWithNoise = { - // clusters: syncContext.faceClusteringResults.clusters.map( - // (faces) => ({ - // faces, - // }) - // ), - // noise: syncContext.faceClusteringResults.noise, - // }; -}; - -const syncPeopleFromClusters = async ( - syncContext: MLSyncContext, - allFacesMap: Map>, - allFaces: Array, -) => { - const clusters = syncContext.mlLibraryData.faceClusteringResults?.clusters; + const clusters = faceClusteringResults?.clusters; if (!clusters || clusters.length < 1) { return; } @@ -86,17 +84,18 @@ const syncPeopleFromClusters = async ( : best, ); + if (personFace && !personFace.crop?.cacheKey) { const file = await getLocalFile(personFace.fileId); const imageBitmap = await fetchImageBitmap(file); await saveFaceCrop(imageBitmap, personFace); } + const person: Person = { id: index, files: faces.map((f) => f.fileId), displayFaceId: personFace?.id, - faceCropCacheKey: personFace?.crop?.cacheKey, }; await mlIDbStorage.putPerson(person); @@ -108,4 +107,24 @@ const syncPeopleFromClusters = async ( } await mlIDbStorage.updateFaces(allFacesMap); + + // await mlIDbStorage.setIndexVersion("people", filesVersion); }; + + public async regenerateFaceCrop(token: string, faceID: string) { + await downloadManager.init(APPS.PHOTOS, { token }); + return mlService.regenerateFaceCrop(faceID); + } + +export const regenerateFaceCrop = async (faceID: string) => { + const fileID = Number(faceID.split("-")[0]); + const personFace = await mlIDbStorage.getFace(fileID, faceID); + if (!personFace) { + throw Error("Face not found"); + } + + const file = await getLocalFile(personFace.fileId); + const imageBitmap = await fetchImageBitmap(file); + return await saveFaceCrop(imageBitmap, personFace); +}; +*/ diff --git a/web/apps/photos/src/services/face/remote.ts b/web/apps/photos/src/services/face/remote.ts new file mode 100644 index 000000000..fcd8775a9 --- /dev/null +++ b/web/apps/photos/src/services/face/remote.ts @@ -0,0 +1,158 @@ +import log from "@/next/log"; +import ComlinkCryptoWorker from "@ente/shared/crypto"; +import { putEmbedding } from "services/embeddingService"; +import type { EnteFile } from "types/file"; +import type { Point } from "./geom"; +import type { Face, FaceDetection, MlFileData } from "./types"; + +export const putFaceEmbedding = async ( + 
enteFile: EnteFile, + mlFileData: MlFileData, + userAgent: string, +) => { + const serverMl = LocalFileMlDataToServerFileMl(mlFileData, userAgent); + log.debug(() => ({ t: "Local ML file data", mlFileData })); + log.debug(() => ({ + t: "Uploaded ML file data", + d: JSON.stringify(serverMl), + })); + + const comlinkCryptoWorker = await ComlinkCryptoWorker.getInstance(); + const { file: encryptedEmbeddingData } = + await comlinkCryptoWorker.encryptMetadata(serverMl, enteFile.key); + log.info( + `putEmbedding embedding to server for file: ${enteFile.metadata.title} fileID: ${enteFile.id}`, + ); + const res = await putEmbedding({ + fileID: enteFile.id, + encryptedEmbedding: encryptedEmbeddingData.encryptedData, + decryptionHeader: encryptedEmbeddingData.decryptionHeader, + model: "file-ml-clip-face", + }); + log.info("putEmbedding response: ", res); +}; + +export interface FileML extends ServerFileMl { + updatedAt: number; +} + +class ServerFileMl { + public fileID: number; + public height?: number; + public width?: number; + public faceEmbedding: ServerFaceEmbeddings; + + public constructor( + fileID: number, + faceEmbedding: ServerFaceEmbeddings, + height?: number, + width?: number, + ) { + this.fileID = fileID; + this.height = height; + this.width = width; + this.faceEmbedding = faceEmbedding; + } +} + +class ServerFaceEmbeddings { + public faces: ServerFace[]; + public version: number; + public client: string; + + public constructor(faces: ServerFace[], client: string, version: number) { + this.faces = faces; + this.client = client; + this.version = version; + } +} + +class ServerFace { + public faceID: string; + public embedding: number[]; + public detection: ServerDetection; + public score: number; + public blur: number; + + public constructor( + faceID: string, + embedding: number[], + detection: ServerDetection, + score: number, + blur: number, + ) { + this.faceID = faceID; + this.embedding = embedding; + this.detection = detection; + this.score = score; + this.blur = blur; + } +} + +class ServerDetection { + public box: ServerFaceBox; + public landmarks: Point[]; + + public constructor(box: ServerFaceBox, landmarks: Point[]) { + this.box = box; + this.landmarks = landmarks; + } +} + +class ServerFaceBox { + public xMin: number; + public yMin: number; + public width: number; + public height: number; + + public constructor( + xMin: number, + yMin: number, + width: number, + height: number, + ) { + this.xMin = xMin; + this.yMin = yMin; + this.width = width; + this.height = height; + } +} + +function LocalFileMlDataToServerFileMl( + localFileMlData: MlFileData, + userAgent: string, +): ServerFileMl { + if (localFileMlData.errorCount > 0) { + return null; + } + const imageDimensions = localFileMlData.imageDimensions; + + const faces: ServerFace[] = []; + for (let i = 0; i < localFileMlData.faces.length; i++) { + const face: Face = localFileMlData.faces[i]; + const faceID = face.id; + const embedding = face.embedding; + const score = face.detection.probability; + const blur = face.blurValue; + const detection: FaceDetection = face.detection; + const box = detection.box; + const landmarks = detection.landmarks; + const newBox = new ServerFaceBox(box.x, box.y, box.width, box.height); + + const newFaceObject = new ServerFace( + faceID, + Array.from(embedding), + new ServerDetection(newBox, landmarks), + score, + blur, + ); + faces.push(newFaceObject); + } + const faceEmbeddings = new ServerFaceEmbeddings(faces, userAgent, 1); + return new ServerFileMl( + localFileMlData.fileId, + 
faceEmbeddings, + imageDimensions.height, + imageDimensions.width, + ); +} diff --git a/web/apps/photos/src/services/face/transform-box.ts b/web/apps/photos/src/services/face/transform-box.ts new file mode 100644 index 000000000..01fa2a977 --- /dev/null +++ b/web/apps/photos/src/services/face/transform-box.ts @@ -0,0 +1,57 @@ +import { Box, Point } from "services/face/geom"; +import type { FaceDetection } from "services/face/types"; +// TODO-ML(LAURENS): Do we need two separate Matrix libraries? +// +// Keeping this in a separate file so that we can audit this. If these can be +// expressed using ml-matrix, then we can move this code to f-index.ts +import { + Matrix, + applyToPoint, + compose, + scale, + translate, +} from "transformation-matrix"; + +/** + * Transform the given {@link faceDetections} from their coordinate system in + * which they were detected ({@link inBox}) back to the coordinate system of the + * original image ({@link toBox}). + */ +export const transformFaceDetections = ( + faceDetections: FaceDetection[], + inBox: Box, + toBox: Box, +): FaceDetection[] => { + const transform = boxTransformationMatrix(inBox, toBox); + return faceDetections.map((f) => ({ + box: transformBox(f.box, transform), + landmarks: f.landmarks.map((p) => transformPoint(p, transform)), + probability: f.probability, + })); +}; + +const boxTransformationMatrix = (inBox: Box, toBox: Box): Matrix => + compose( + translate(toBox.x, toBox.y), + scale(toBox.width / inBox.width, toBox.height / inBox.height), + ); + +const transformPoint = (point: Point, transform: Matrix) => { + const txdPoint = applyToPoint(transform, point); + return new Point(txdPoint.x, txdPoint.y); +}; + +const transformBox = (box: Box, transform: Matrix) => { + const topLeft = transformPoint(new Point(box.x, box.y), transform); + const bottomRight = transformPoint( + new Point(box.x + box.width, box.y + box.height), + transform, + ); + + return new Box({ + x: topLeft.x, + y: topLeft.y, + width: bottomRight.x - topLeft.x, + height: bottomRight.y - topLeft.y, + }); +}; diff --git a/web/apps/photos/src/services/face/types.ts b/web/apps/photos/src/services/face/types.ts index 99244bf61..e1fa32785 100644 --- a/web/apps/photos/src/services/face/types.ts +++ b/web/apps/photos/src/services/face/types.ts @@ -1,161 +1,39 @@ -import type { ClusterFacesResult } from "services/face/cluster"; -import { Dimensions } from "services/face/geom"; -import { EnteFile } from "types/file"; -import { Box, Point } from "./geom"; - -export interface MLSyncResult { - nOutOfSyncFiles: number; - nSyncedFiles: number; - nSyncedFaces: number; - nFaceClusters: number; - nFaceNoise: number; - error?: Error; -} - -export declare type FaceDescriptor = Float32Array; - -export declare type Cluster = Array; - -export interface FacesCluster { - faces: Cluster; - summary?: FaceDescriptor; -} - -export interface FacesClustersWithNoise { - clusters: Array; - noise: Cluster; -} - -export interface NearestCluster { - cluster: FacesCluster; - distance: number; -} - -export declare type Landmark = Point; - -export declare type ImageType = "Original" | "Preview"; - -export declare type FaceDetectionMethod = "YoloFace"; - -export declare type FaceCropMethod = "ArcFace"; - -export declare type FaceAlignmentMethod = "ArcFace"; - -export declare type FaceEmbeddingMethod = "MobileFaceNet"; - -export declare type BlurDetectionMethod = "Laplacian"; - -export declare type ClusteringMethod = "Hdbscan" | "Dbscan"; - -export class AlignedBox { - box: Box; - rotation: number; -} - -export 
interface Versioned { - value: T; - version: number; -} +import { Box, Dimensions, Point } from "services/face/geom"; export interface FaceDetection { // box and landmarks is relative to image dimentions stored at mlFileData box: Box; - landmarks?: Array; + landmarks?: Point[]; probability?: number; } -export interface DetectedFace { - fileId: number; - detection: FaceDetection; -} - -export interface DetectedFaceWithId extends DetectedFace { - id: string; -} - -export interface FaceCrop { - image: ImageBitmap; - // imageBox is relative to image dimentions stored at mlFileData - imageBox: Box; -} - -export interface StoredFaceCrop { - cacheKey: string; - imageBox: Box; -} - -export interface CroppedFace extends DetectedFaceWithId { - crop?: StoredFaceCrop; -} - export interface FaceAlignment { - // TODO: remove affine matrix as rotation, size and center + // TODO-ML(MR): remove affine matrix as rotation, size and center // are simple to store and use, affine matrix adds complexity while getting crop - affineMatrix: Array>; + affineMatrix: number[][]; rotation: number; // size and center is relative to image dimentions stored at mlFileData size: number; center: Point; } -export interface AlignedFace extends CroppedFace { +export interface Face { + fileId: number; + detection: FaceDetection; + id: string; + alignment?: FaceAlignment; blurValue?: number; -} -export declare type FaceEmbedding = Float32Array; + embedding?: Float32Array; -export interface FaceWithEmbedding extends AlignedFace { - embedding?: FaceEmbedding; -} - -export interface Face extends FaceWithEmbedding { personId?: number; } -export interface Person { - id: number; - name?: string; - files: Array; - displayFaceId?: string; - faceCropCacheKey?: string; -} - export interface MlFileData { fileId: number; faces?: Face[]; - imageSource?: ImageType; imageDimensions?: Dimensions; - faceDetectionMethod?: Versioned; - faceCropMethod?: Versioned; - faceAlignmentMethod?: Versioned; - faceEmbeddingMethod?: Versioned; mlVersion: number; errorCount: number; - lastErrorMessage?: string; } - -export interface MLSearchConfig { - enabled: boolean; -} - -export interface MLSyncFileContext { - enteFile: EnteFile; - localFile?: globalThis.File; - - oldMlFile?: MlFileData; - newMlFile?: MlFileData; - - imageBitmap?: ImageBitmap; - - newDetection?: boolean; - newAlignment?: boolean; -} - -export interface MLLibraryData { - faceClusteringMethod?: Versioned; - faceClusteringResults?: ClusterFacesResult; - faceClustersWithNoise?: FacesClustersWithNoise; -} - -export declare type MLIndex = "files" | "people"; diff --git a/web/apps/photos/src/services/heic-convert.ts b/web/apps/photos/src/services/heic-convert.ts index 2b37c3198..d2e05d9ec 100644 --- a/web/apps/photos/src/services/heic-convert.ts +++ b/web/apps/photos/src/services/heic-convert.ts @@ -51,9 +51,7 @@ class HEICConverter { const startTime = Date.now(); const convertedHEIC = await worker.heicToJPEG(fileBlob); - const ms = Math.round( - Date.now() - startTime, - ); + const ms = Date.now() - startTime; log.debug(() => `heic => jpeg (${ms} ms)`); clearTimeout(timeout); resolve(convertedHEIC); diff --git a/web/apps/photos/src/services/machineLearning/machineLearningService.ts b/web/apps/photos/src/services/machineLearning/machineLearningService.ts index 43e0459ce..954a88c66 100644 --- a/web/apps/photos/src/services/machineLearning/machineLearningService.ts +++ b/web/apps/photos/src/services/machineLearning/machineLearningService.ts @@ -1,41 +1,26 @@ -import { haveWindow } from "@/next/env"; 
import log from "@/next/log"; -import { ComlinkWorker } from "@/next/worker/comlink-worker"; -import ComlinkCryptoWorker, { - getDedicatedCryptoWorker, -} from "@ente/shared/crypto"; -import { DedicatedCryptoWorker } from "@ente/shared/crypto/internal/crypto.worker"; import { CustomError, parseUploadErrorCodes } from "@ente/shared/error"; import PQueue from "p-queue"; -import { putEmbedding } from "services/embeddingService"; -import mlIDbStorage, { ML_SEARCH_CONFIG_NAME } from "services/face/db"; -import { - Face, - FaceDetection, - Landmark, - MLLibraryData, - MLSearchConfig, - MLSyncFileContext, - MLSyncResult, - MlFileData, -} from "services/face/types"; +import mlIDbStorage, { + ML_SEARCH_CONFIG_NAME, + type MinimalPersistedFileData, +} from "services/face/db"; +import { putFaceEmbedding } from "services/face/remote"; import { getLocalFiles } from "services/fileService"; import { EnteFile } from "types/file"; import { isInternalUserForML } from "utils/user"; -import { regenerateFaceCrop, syncFileAnalyzeFaces } from "../face/f-index"; -import { fetchImageBitmapForContext } from "../face/image"; -import { syncPeopleIndex } from "../face/people"; +import { indexFaces } from "../face/f-index"; -/** - * TODO-ML(MR): What and why. - * Also, needs to be 1 (in sync with mobile) when we move out of beta. - */ -export const defaultMLVersion = 3; +export const defaultMLVersion = 1; const batchSize = 200; export const MAX_ML_SYNC_ERROR_COUNT = 1; +export interface MLSearchConfig { + enabled: boolean; +} + export const DEFAULT_ML_SEARCH_CONFIG: MLSearchConfig = { enabled: false, }; @@ -56,107 +41,54 @@ export async function updateMLSearchConfig(newConfig: MLSearchConfig) { return mlIDbStorage.putConfig(ML_SEARCH_CONFIG_NAME, newConfig); } -export interface MLSyncContext { - token: string; - userID: number; - - localFilesMap: Map; - outOfSyncFiles: EnteFile[]; - nSyncedFiles: number; - nSyncedFaces: number; - allSyncedFacesMap?: Map>; - - error?: Error; - - // oldMLLibraryData: MLLibraryData; - mlLibraryData: MLLibraryData; - - syncQueue: PQueue; - - getEnteWorker(id: number): Promise; - dispose(): Promise; -} - -export class LocalMLSyncContext implements MLSyncContext { +class MLSyncContext { public token: string; public userID: number; + public userAgent: string; public localFilesMap: Map; public outOfSyncFiles: EnteFile[]; public nSyncedFiles: number; - public nSyncedFaces: number; - public allSyncedFacesMap?: Map>; - public error?: Error; - public mlLibraryData: MLLibraryData; - public syncQueue: PQueue; - // TODO: wheather to limit concurrent downloads - // private downloadQueue: PQueue; - private concurrency: number; - private comlinkCryptoWorker: Array< - ComlinkWorker - >; - private enteWorkers: Array; - - constructor(token: string, userID: number, concurrency?: number) { + constructor(token: string, userID: number, userAgent: string) { this.token = token; this.userID = userID; + this.userAgent = userAgent; this.outOfSyncFiles = []; this.nSyncedFiles = 0; - this.nSyncedFaces = 0; - this.concurrency = concurrency ?? 
getConcurrency(); - - log.info("Using concurrency: ", this.concurrency); - // timeout is added on downloads - // timeout on queue will keep the operation open till worker is terminated - this.syncQueue = new PQueue({ concurrency: this.concurrency }); - logQueueStats(this.syncQueue, "sync"); - // this.downloadQueue = new PQueue({ concurrency: 1 }); - // logQueueStats(this.downloadQueue, 'download'); - - this.comlinkCryptoWorker = new Array(this.concurrency); - this.enteWorkers = new Array(this.concurrency); - } - - public async getEnteWorker(id: number): Promise { - const wid = id % this.enteWorkers.length; - console.log("getEnteWorker: ", id, wid); - if (!this.enteWorkers[wid]) { - this.comlinkCryptoWorker[wid] = getDedicatedCryptoWorker(); - this.enteWorkers[wid] = await this.comlinkCryptoWorker[wid].remote; - } - - return this.enteWorkers[wid]; + const concurrency = getConcurrency(); + this.syncQueue = new PQueue({ concurrency }); } public async dispose() { this.localFilesMap = undefined; await this.syncQueue.onIdle(); this.syncQueue.removeAllListeners(); - for (const enteComlinkWorker of this.comlinkCryptoWorker) { - enteComlinkWorker?.terminate(); - } } } -export const getConcurrency = () => - haveWindow() && Math.max(2, Math.ceil(navigator.hardwareConcurrency / 2)); +const getConcurrency = () => + Math.max(2, Math.ceil(navigator.hardwareConcurrency / 2)); class MachineLearningService { private localSyncContext: Promise; private syncContext: Promise; - public async sync(token: string, userID: number): Promise { + public async sync( + token: string, + userID: number, + userAgent: string, + ): Promise { if (!token) { throw Error("Token needed by ml service to sync file"); } - const syncContext = await this.getSyncContext(token, userID); + const syncContext = await this.getSyncContext(token, userID, userAgent); await this.syncLocalFiles(syncContext); @@ -166,38 +98,9 @@ class MachineLearningService { await this.syncFiles(syncContext); } - // TODO-ML(MR): Forced disable clustering. It doesn't currently work, - // need to finalize it before we move out of beta. 
- // - // > Error: Failed to execute 'transferToImageBitmap' on - // > 'OffscreenCanvas': ImageBitmap construction failed - /* - if ( - syncContext.outOfSyncFiles.length <= 0 || - (syncContext.nSyncedFiles === batchSize && Math.random() < 0) - ) { - await this.syncIndex(syncContext); - } - */ - - const mlSyncResult: MLSyncResult = { - nOutOfSyncFiles: syncContext.outOfSyncFiles.length, - nSyncedFiles: syncContext.nSyncedFiles, - nSyncedFaces: syncContext.nSyncedFaces, - nFaceClusters: - syncContext.mlLibraryData?.faceClusteringResults?.clusters - .length, - nFaceNoise: - syncContext.mlLibraryData?.faceClusteringResults?.noise.length, - error: syncContext.error, - }; - // log.info('[MLService] sync results: ', mlSyncResult); - - return mlSyncResult; - } - - public async regenerateFaceCrop(faceID: string) { - return regenerateFaceCrop(faceID); + const error = syncContext.error; + const nOutOfSyncFiles = syncContext.outOfSyncFiles.length; + return !error && nOutOfSyncFiles > 0; } private newMlData(fileId: number) { @@ -205,7 +108,7 @@ class MachineLearningService { fileId, mlVersion: 0, errorCount: 0, - } as MlFileData; + } as MinimalPersistedFileData; } private async getLocalFilesMap(syncContext: MLSyncContext) { @@ -309,7 +212,6 @@ class MachineLearningService { syncContext.error = error; } await syncContext.syncQueue.onIdle(); - log.info("allFaces: ", syncContext.nSyncedFaces); // TODO: In case syncJob has to use multiple ml workers // do in same transaction with each file update @@ -318,13 +220,17 @@ class MachineLearningService { // await this.disposeMLModels(); } - private async getSyncContext(token: string, userID: number) { + private async getSyncContext( + token: string, + userID: number, + userAgent: string, + ) { if (!this.syncContext) { log.info("Creating syncContext"); // TODO-ML(MR): Keep as promise for now. this.syncContext = new Promise((resolve) => { - resolve(new LocalMLSyncContext(token, userID)); + resolve(new MLSyncContext(token, userID, userAgent)); }); } else { log.info("reusing existing syncContext"); @@ -332,13 +238,17 @@ class MachineLearningService { return this.syncContext; } - private async getLocalSyncContext(token: string, userID: number) { + private async getLocalSyncContext( + token: string, + userID: number, + userAgent: string, + ) { // TODO-ML(MR): This is updating the file ML version. verify. if (!this.localSyncContext) { log.info("Creating localSyncContext"); // TODO-ML(MR): this.localSyncContext = new Promise((resolve) => { - resolve(new LocalMLSyncContext(token, userID)); + resolve(new MLSyncContext(token, userID, userAgent)); }); } else { log.info("reusing existing localSyncContext"); @@ -358,10 +268,15 @@ class MachineLearningService { public async syncLocalFile( token: string, userID: number, + userAgent: string, enteFile: EnteFile, localFile?: globalThis.File, ) { - const syncContext = await this.getLocalSyncContext(token, userID); + const syncContext = await this.getLocalSyncContext( + token, + userID, + userAgent, + ); try { await this.syncFileWithErrorHandler( @@ -385,11 +300,11 @@ class MachineLearningService { localFile?: globalThis.File, ) { try { - console.log( - `Indexing ${enteFile.title ?? 
""} ${enteFile.id}`, + const mlFileData = await this.syncFile( + enteFile, + localFile, + syncContext.userAgent, ); - const mlFileData = await this.syncFile(enteFile, localFile); - syncContext.nSyncedFaces += mlFileData.faces?.length || 0; syncContext.nSyncedFiles += 1; return mlFileData; } catch (e) { @@ -421,62 +336,22 @@ class MachineLearningService { } } - private async syncFile(enteFile: EnteFile, localFile?: globalThis.File) { - log.debug(() => ({ a: "Syncing file", enteFile })); - const fileContext: MLSyncFileContext = { enteFile, localFile }; - const oldMlFile = await this.getMLFileData(enteFile.id); + private async syncFile( + enteFile: EnteFile, + localFile: globalThis.File | undefined, + userAgent: string, + ) { + const oldMlFile = await mlIDbStorage.getFile(enteFile.id); if (oldMlFile && oldMlFile.mlVersion) { return oldMlFile; } - const newMlFile = (fileContext.newMlFile = this.newMlData(enteFile.id)); - newMlFile.mlVersion = defaultMLVersion; - - try { - await fetchImageBitmapForContext(fileContext); - await syncFileAnalyzeFaces(fileContext); - newMlFile.errorCount = 0; - newMlFile.lastErrorMessage = undefined; - await this.persistOnServer(newMlFile, enteFile); - await mlIDbStorage.putFile(newMlFile); - } catch (e) { - log.error("ml detection failed", e); - newMlFile.mlVersion = oldMlFile.mlVersion; - throw e; - } finally { - fileContext.imageBitmap && fileContext.imageBitmap.close(); - } - + const newMlFile = await indexFaces(enteFile, localFile); + await putFaceEmbedding(enteFile, newMlFile, userAgent); + await mlIDbStorage.putFile(newMlFile); return newMlFile; } - private async persistOnServer(mlFileData: MlFileData, enteFile: EnteFile) { - const serverMl = LocalFileMlDataToServerFileMl(mlFileData); - log.debug(() => ({ t: "Local ML file data", mlFileData })); - log.debug(() => ({ - t: "Uploaded ML file data", - d: JSON.stringify(serverMl), - })); - - const comlinkCryptoWorker = await ComlinkCryptoWorker.getInstance(); - const { file: encryptedEmbeddingData } = - await comlinkCryptoWorker.encryptMetadata(serverMl, enteFile.key); - log.info( - `putEmbedding embedding to server for file: ${enteFile.metadata.title} fileID: ${enteFile.id}`, - ); - const res = await putEmbedding({ - fileID: enteFile.id, - encryptedEmbedding: encryptedEmbeddingData.encryptedData, - decryptionHeader: encryptedEmbeddingData.decryptionHeader, - model: "file-ml-clip-face", - }); - log.info("putEmbedding response: ", res); - } - - private async getMLFileData(fileId: number) { - return mlIDbStorage.getFile(fileId); - } - private async persistMLFileSyncError(enteFile: EnteFile, e: Error) { try { await mlIDbStorage.upsertFileInTx(enteFile.id, (mlFileData) => { @@ -484,7 +359,7 @@ class MachineLearningService { mlFileData = this.newMlData(enteFile.id); } mlFileData.errorCount = (mlFileData.errorCount || 0) + 1; - mlFileData.lastErrorMessage = e.message; + console.error(`lastError for ${enteFile.id}`, e); return mlFileData; }); @@ -493,183 +368,6 @@ class MachineLearningService { console.error("Error while storing ml sync error", e); } } - - private async getMLLibraryData(syncContext: MLSyncContext) { - syncContext.mlLibraryData = await mlIDbStorage.getLibraryData(); - if (!syncContext.mlLibraryData) { - syncContext.mlLibraryData = {}; - } - } - - private async persistMLLibraryData(syncContext: MLSyncContext) { - return mlIDbStorage.putLibraryData(syncContext.mlLibraryData); - } - - public async syncIndex(syncContext: MLSyncContext) { - await this.getMLLibraryData(syncContext); - - // TODO-ML(MR): Ensure 
this doesn't run until fixed. - await syncPeopleIndex(syncContext); - - await this.persistMLLibraryData(syncContext); - } } export default new MachineLearningService(); - -export interface FileML extends ServerFileMl { - updatedAt: number; -} - -class ServerFileMl { - public fileID: number; - public height?: number; - public width?: number; - public faceEmbedding: ServerFaceEmbeddings; - - public constructor( - fileID: number, - faceEmbedding: ServerFaceEmbeddings, - height?: number, - width?: number, - ) { - this.fileID = fileID; - this.height = height; - this.width = width; - this.faceEmbedding = faceEmbedding; - } -} - -class ServerFaceEmbeddings { - public faces: ServerFace[]; - public version: number; - public client?: string; - public error?: boolean; - - public constructor( - faces: ServerFace[], - version: number, - client?: string, - error?: boolean, - ) { - this.faces = faces; - this.version = version; - this.client = client; - this.error = error; - } -} - -class ServerFace { - public faceID: string; - public embeddings: number[]; - public detection: ServerDetection; - public score: number; - public blur: number; - - public constructor( - faceID: string, - embeddings: number[], - detection: ServerDetection, - score: number, - blur: number, - ) { - this.faceID = faceID; - this.embeddings = embeddings; - this.detection = detection; - this.score = score; - this.blur = blur; - } -} - -class ServerDetection { - public box: ServerFaceBox; - public landmarks: Landmark[]; - - public constructor(box: ServerFaceBox, landmarks: Landmark[]) { - this.box = box; - this.landmarks = landmarks; - } -} - -class ServerFaceBox { - public xMin: number; - public yMin: number; - public width: number; - public height: number; - - public constructor( - xMin: number, - yMin: number, - width: number, - height: number, - ) { - this.xMin = xMin; - this.yMin = yMin; - this.width = width; - this.height = height; - } -} - -function LocalFileMlDataToServerFileMl( - localFileMlData: MlFileData, -): ServerFileMl { - if ( - localFileMlData.errorCount > 0 && - localFileMlData.lastErrorMessage !== undefined - ) { - return null; - } - const imageDimensions = localFileMlData.imageDimensions; - - const faces: ServerFace[] = []; - for (let i = 0; i < localFileMlData.faces.length; i++) { - const face: Face = localFileMlData.faces[i]; - const faceID = face.id; - const embedding = face.embedding; - const score = face.detection.probability; - const blur = face.blurValue; - const detection: FaceDetection = face.detection; - const box = detection.box; - const landmarks = detection.landmarks; - const newBox = new ServerFaceBox(box.x, box.y, box.width, box.height); - const newLandmarks: Landmark[] = []; - for (let j = 0; j < landmarks.length; j++) { - newLandmarks.push({ - x: landmarks[j].x, - y: landmarks[j].y, - } as Landmark); - } - - const newFaceObject = new ServerFace( - faceID, - Array.from(embedding), - new ServerDetection(newBox, newLandmarks), - score, - blur, - ); - faces.push(newFaceObject); - } - const faceEmbeddings = new ServerFaceEmbeddings( - faces, - 1, - localFileMlData.lastErrorMessage, - ); - return new ServerFileMl( - localFileMlData.fileId, - faceEmbeddings, - imageDimensions.height, - imageDimensions.width, - ); -} - -export function logQueueStats(queue: PQueue, name: string) { - queue.on("active", () => - log.info( - `queuestats: ${name}: Active, Size: ${queue.size} Pending: ${queue.pending}`, - ), - ); - queue.on("idle", () => log.info(`queuestats: ${name}: Idle`)); - queue.on("error", (error) => - 
console.error(`queuestats: ${name}: Error, `, error),
-    ),
-);
-}
diff --git a/web/apps/photos/src/services/machineLearning/mlWorkManager.ts b/web/apps/photos/src/services/machineLearning/mlWorkManager.ts
index 1cb61af00..c1b2ef6a7 100644
--- a/web/apps/photos/src/services/machineLearning/mlWorkManager.ts
+++ b/web/apps/photos/src/services/machineLearning/mlWorkManager.ts
@@ -1,6 +1,8 @@
 import { FILE_TYPE } from "@/media/file-type";
+import { ensureElectron } from "@/next/electron";
 import log from "@/next/log";
 import { ComlinkWorker } from "@/next/worker/comlink-worker";
+import { clientPackageNamePhotosDesktop } from "@ente/shared/apps/constants";
 import { eventBus, Events } from "@ente/shared/events";
 import { getToken, getUserID } from "@ente/shared/storage/localStorage/helpers";
 import debounce from "debounce";
@@ -8,25 +10,18 @@ import PQueue from "p-queue";
 import { createFaceComlinkWorker } from "services/face";
 import mlIDbStorage from "services/face/db";
 import type { DedicatedMLWorker } from "services/face/face.worker";
-import { MLSyncResult } from "services/face/types";
 import { EnteFile } from "types/file";
-import { logQueueStats } from "./machineLearningService";
 
 export type JobState = "Scheduled" | "Running" | "NotScheduled";
 
-export interface MLSyncJobResult {
-    shouldBackoff: boolean;
-    mlSyncResult: MLSyncResult;
-}
-
 export class MLSyncJob {
-    private runCallback: () => Promise<MLSyncJobResult>;
+    private runCallback: () => Promise<boolean>;
     private state: JobState;
     private stopped: boolean;
     private intervalSec: number;
     private nextTimeoutId: ReturnType<typeof setTimeout>;
 
-    constructor(runCallback: () => Promise<MLSyncJobResult>) {
+    constructor(runCallback: () => Promise<boolean>) {
         this.runCallback = runCallback;
         this.state = "NotScheduled";
         this.stopped = true;
@@ -65,13 +60,11 @@
         this.state = "Running";
 
         try {
-            const jobResult = await this.runCallback();
-            if (jobResult && jobResult.shouldBackoff) {
-                this.intervalSec = Math.min(960, this.intervalSec * 2);
-            } else {
+            if (await this.runCallback()) {
                 this.resetInterval();
+            } else {
+                this.intervalSec = Math.min(960, this.intervalSec * 2);
             }
-            log.info("Job completed");
         } catch (e) {
             console.error("Error while running Job: ", e);
         } finally {
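An aside on the hunk above: MLSyncJob now encodes its scheduling decision in a single boolean, backing off exponentially (capped) when a run was idle or failed, and resetting after a productive run. A self-contained TypeScript sketch of that policy follows; the 5 second starting interval and the function name are illustrative assumptions, only the doubling and the 960 second cap come from the code above.

// Sketch: the capped exponential backoff policy used by MLSyncJob.
// DEFAULT_INTERVAL_SEC and nextIntervalSec are assumed names for this example.
const DEFAULT_INTERVAL_SEC = 5;
const MAX_INTERVAL_SEC = 960;

let intervalSec = DEFAULT_INTERVAL_SEC;

// After a productive run, reset; after an idle or failed run, double (capped).
const nextIntervalSec = (hadWork: boolean): number =>
    (intervalSec = hadWork
        ? DEFAULT_INTERVAL_SEC
        : Math.min(MAX_INTERVAL_SEC, intervalSec * 2));

// Three idle runs in a row back off as 10, 20, 40, eventually capping at 960.
console.log(nextIntervalSec(false), nextIntervalSec(false), nextIntervalSec(false));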
@@ -236,8 +229,15 @@
             this.stopSyncJob();
             const token = getToken();
             const userID = getUserID();
+            const userAgent = await getUserAgent();
             const mlWorker = await this.getLiveSyncWorker();
-            return mlWorker.syncLocalFile(token, userID, enteFile, localFile);
+            return mlWorker.syncLocalFile(
+                token,
+                userID,
+                userAgent,
+                enteFile,
+                localFile,
+            );
         });
     }
 
@@ -255,7 +255,14 @@
         this.syncJobWorker = undefined;
     }
 
-    private async runMLSyncJob(): Promise<MLSyncJobResult> {
+    /**
+     * Returns `false` to indicate that either an error occurred, there are
+     * no more files to process, or we cannot currently process files.
+     *
+     * Conversely, when it returns `true`, all is well and there is more work
+     * pending, so we should continue processing at full speed.
+     */
+    private async runMLSyncJob(): Promise<boolean> {
         try {
             // TODO: skipping is not required if we are caching chunks through service worker
             // currently worker chunk itself is not loaded when network is not there
             if (!navigator.onLine) {
                 log.info(
                     "Skipping ml-sync job run as not connected to internet.",
                 );
-                return {
-                    shouldBackoff: true,
-                    mlSyncResult: undefined,
-                };
+                return false;
             }
 
             const token = getToken();
             const userID = getUserID();
+            const userAgent = await getUserAgent();
             const jobWorkerProxy = await this.getSyncJobWorker();
 
-            const mlSyncResult = await jobWorkerProxy.sync(token, userID);
-
+            return await jobWorkerProxy.sync(token, userID, userAgent);
             // this.terminateSyncJobWorker();
-            const jobResult: MLSyncJobResult = {
-                shouldBackoff:
-                    !!mlSyncResult.error || mlSyncResult.nOutOfSyncFiles < 1,
-                mlSyncResult,
-            };
-            log.info("ML Sync Job result: ", JSON.stringify(jobResult));
-
-            // TODO: redirect/refresh to gallery in case of session_expired, stop ml sync job
-
-            return jobResult;
         } catch (e) {
             log.error("Failed to run MLSync Job", e);
         }
     }
@@ -323,3 +318,22 @@
 }
 
 export default new MLWorkManager();
+
+export function logQueueStats(queue: PQueue, name: string) {
+    queue.on("active", () =>
+        log.info(
+            `queuestats: ${name}: Active, Size: ${queue.size} Pending: ${queue.pending}`,
+        ),
+    );
+    queue.on("idle", () => log.info(`queuestats: ${name}: Idle`));
+    queue.on("error", (error) =>
+        console.error(`queuestats: ${name}: Error, `, error),
+    );
+}
+
+const getUserAgent = async () => {
+    const electron = ensureElectron();
+    const name = clientPackageNamePhotosDesktop;
+    const version = await electron.appVersion();
+    return `${name}/${version}`;
+};
diff --git a/web/apps/photos/src/services/searchService.ts b/web/apps/photos/src/services/searchService.ts
index d646ecd00..4bbab115c 100644
--- a/web/apps/photos/src/services/searchService.ts
+++ b/web/apps/photos/src/services/searchService.ts
@@ -3,7 +3,7 @@ import log from "@/next/log";
 import * as chrono from "chrono-node";
 import { t } from "i18next";
 import mlIDbStorage from "services/face/db";
-import { Person } from "services/face/types";
+import type { Person } from "services/face/people";
 import { defaultMLVersion } from "services/machineLearning/machineLearningService";
 import { Collection } from "types/collection";
 import { EntityType, LocationTag, LocationTagData } from "types/entity";
diff --git a/web/apps/photos/src/types/search/index.ts b/web/apps/photos/src/types/search/index.ts
index aa5f12804..33f5eba9a 100644
--- a/web/apps/photos/src/types/search/index.ts
+++ b/web/apps/photos/src/types/search/index.ts
@@ -1,6 +1,6 @@
 import { FILE_TYPE } from "@/media/file-type";
 import { IndexStatus } from "services/face/db";
-import { Person } from "services/face/types";
+import type { Person } from "services/face/people";
 import { City } from "services/locationSearchService";
 import { LocationTagData } from "types/entity";
 import { EnteFile } from "types/file";
diff --git a/web/apps/photos/src/utils/comlink/ComlinkSearchWorker.ts b/web/apps/photos/src/utils/comlink/ComlinkSearchWorker.ts
index 4886bacda..0d7c52a96 100644
--- a/web/apps/photos/src/utils/comlink/ComlinkSearchWorker.ts
+++ b/web/apps/photos/src/utils/comlink/ComlinkSearchWorker.ts
@@ -5,11 +5,13 @@ import { type DedicatedSearchWorker } from "worker/search.worker";
 
 class ComlinkSearchWorker {
     private comlinkWorkerInstance: Remote<DedicatedSearchWorker>;
+    private comlinkWorker: ComlinkWorker<typeof DedicatedSearchWorker>;
 
     async getInstance() {
         if (!this.comlinkWorkerInstance) {
-
this.comlinkWorkerInstance = - await getDedicatedSearchWorker().remote; + if (!this.comlinkWorker) + this.comlinkWorker = getDedicatedSearchWorker(); + this.comlinkWorkerInstance = await this.comlinkWorker.remote; } return this.comlinkWorkerInstance; } diff --git a/web/apps/photos/src/utils/image/index.ts b/web/apps/photos/src/utils/image/index.ts deleted file mode 100644 index 7583f97c2..000000000 --- a/web/apps/photos/src/utils/image/index.ts +++ /dev/null @@ -1,468 +0,0 @@ -// these utils only work in env where OffscreenCanvas is available - -import { Matrix, inverse } from "ml-matrix"; -import { Box, Dimensions, enlargeBox } from "services/face/geom"; -import { FaceAlignment } from "services/face/types"; - -export function normalizePixelBetween0And1(pixelValue: number) { - return pixelValue / 255.0; -} - -export function normalizePixelBetweenMinus1And1(pixelValue: number) { - return pixelValue / 127.5 - 1.0; -} - -export function unnormalizePixelFromBetweenMinus1And1(pixelValue: number) { - return clamp(Math.round((pixelValue + 1.0) * 127.5), 0, 255); -} - -export function readPixelColor( - imageData: Uint8ClampedArray, - width: number, - height: number, - x: number, - y: number, -) { - if (x < 0 || x >= width || y < 0 || y >= height) { - return { r: 0, g: 0, b: 0, a: 0 }; - } - const index = (y * width + x) * 4; - return { - r: imageData[index], - g: imageData[index + 1], - b: imageData[index + 2], - a: imageData[index + 3], - }; -} - -export function clamp(value: number, min: number, max: number) { - return Math.min(max, Math.max(min, value)); -} - -export function getPixelBicubic( - fx: number, - fy: number, - imageData: Uint8ClampedArray, - imageWidth: number, - imageHeight: number, -) { - // Clamp to image boundaries - fx = clamp(fx, 0, imageWidth - 1); - fy = clamp(fy, 0, imageHeight - 1); - - const x = Math.trunc(fx) - (fx >= 0.0 ? 0 : 1); - const px = x - 1; - const nx = x + 1; - const ax = x + 2; - const y = Math.trunc(fy) - (fy >= 0.0 ? 0 : 1); - const py = y - 1; - const ny = y + 1; - const ay = y + 2; - const dx = fx - x; - const dy = fy - y; - - function cubic( - dx: number, - ipp: number, - icp: number, - inp: number, - iap: number, - ) { - return ( - icp + - 0.5 * - (dx * (-ipp + inp) + - dx * dx * (2 * ipp - 5 * icp + 4 * inp - iap) + - dx * dx * dx * (-ipp + 3 * icp - 3 * inp + iap)) - ); - } - - const icc = readPixelColor(imageData, imageWidth, imageHeight, x, y); - - const ipp = - px < 0 || py < 0 - ? icc - : readPixelColor(imageData, imageWidth, imageHeight, px, py); - const icp = - px < 0 - ? icc - : readPixelColor(imageData, imageWidth, imageHeight, x, py); - const inp = - py < 0 || nx >= imageWidth - ? icc - : readPixelColor(imageData, imageWidth, imageHeight, nx, py); - const iap = - ax >= imageWidth || py < 0 - ? icc - : readPixelColor(imageData, imageWidth, imageHeight, ax, py); - - const ip0 = cubic(dx, ipp.r, icp.r, inp.r, iap.r); - const ip1 = cubic(dx, ipp.g, icp.g, inp.g, iap.g); - const ip2 = cubic(dx, ipp.b, icp.b, inp.b, iap.b); - // const ip3 = cubic(dx, ipp.a, icp.a, inp.a, iap.a); - - const ipc = - px < 0 - ? icc - : readPixelColor(imageData, imageWidth, imageHeight, px, y); - const inc = - nx >= imageWidth - ? icc - : readPixelColor(imageData, imageWidth, imageHeight, nx, y); - const iac = - ax >= imageWidth - ? 
icc - : readPixelColor(imageData, imageWidth, imageHeight, ax, y); - - const ic0 = cubic(dx, ipc.r, icc.r, inc.r, iac.r); - const ic1 = cubic(dx, ipc.g, icc.g, inc.g, iac.g); - const ic2 = cubic(dx, ipc.b, icc.b, inc.b, iac.b); - // const ic3 = cubic(dx, ipc.a, icc.a, inc.a, iac.a); - - const ipn = - px < 0 || ny >= imageHeight - ? icc - : readPixelColor(imageData, imageWidth, imageHeight, px, ny); - const icn = - ny >= imageHeight - ? icc - : readPixelColor(imageData, imageWidth, imageHeight, x, ny); - const inn = - nx >= imageWidth || ny >= imageHeight - ? icc - : readPixelColor(imageData, imageWidth, imageHeight, nx, ny); - const ian = - ax >= imageWidth || ny >= imageHeight - ? icc - : readPixelColor(imageData, imageWidth, imageHeight, ax, ny); - - const in0 = cubic(dx, ipn.r, icn.r, inn.r, ian.r); - const in1 = cubic(dx, ipn.g, icn.g, inn.g, ian.g); - const in2 = cubic(dx, ipn.b, icn.b, inn.b, ian.b); - // const in3 = cubic(dx, ipn.a, icn.a, inn.a, ian.a); - - const ipa = - px < 0 || ay >= imageHeight - ? icc - : readPixelColor(imageData, imageWidth, imageHeight, px, ay); - const ica = - ay >= imageHeight - ? icc - : readPixelColor(imageData, imageWidth, imageHeight, x, ay); - const ina = - nx >= imageWidth || ay >= imageHeight - ? icc - : readPixelColor(imageData, imageWidth, imageHeight, nx, ay); - const iaa = - ax >= imageWidth || ay >= imageHeight - ? icc - : readPixelColor(imageData, imageWidth, imageHeight, ax, ay); - - const ia0 = cubic(dx, ipa.r, ica.r, ina.r, iaa.r); - const ia1 = cubic(dx, ipa.g, ica.g, ina.g, iaa.g); - const ia2 = cubic(dx, ipa.b, ica.b, ina.b, iaa.b); - // const ia3 = cubic(dx, ipa.a, ica.a, ina.a, iaa.a); - - const c0 = Math.trunc(clamp(cubic(dy, ip0, ic0, in0, ia0), 0, 255)); - const c1 = Math.trunc(clamp(cubic(dy, ip1, ic1, in1, ia1), 0, 255)); - const c2 = Math.trunc(clamp(cubic(dy, ip2, ic2, in2, ia2), 0, 255)); - // const c3 = cubic(dy, ip3, ic3, in3, ia3); - - return { r: c0, g: c1, b: c2 }; -} - -/// Returns the pixel value (RGB) at the given coordinates using bilinear interpolation. 
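An aside before the full RGB implementation that follows: bilinear interpolation blends the four pixels surrounding the sample point, weighting each by the area of the rectangle opposite it. A minimal single-channel TypeScript sketch, where `pixelAt` is an assumed accessor for the value at integer coordinates (boundary clamping omitted):

// Sketch: single-channel bilinear sampling under the assumptions above.
const sampleBilinear = (
    pixelAt: (x: number, y: number) => number,
    fx: number,
    fy: number,
): number => {
    const x0 = Math.floor(fx);
    const y0 = Math.floor(fy);
    const x1 = Math.ceil(fx);
    const y1 = Math.ceil(fy);
    const dx = fx - x0;
    const dy = fy - y0;
    // Each corner is weighted by the area of the opposite sub-rectangle.
    return Math.round(
        pixelAt(x0, y0) * (1 - dx) * (1 - dy) +
            pixelAt(x1, y0) * dx * (1 - dy) +
            pixelAt(x0, y1) * (1 - dx) * dy +
            pixelAt(x1, y1) * dx * dy,
    );
};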
-export function getPixelBilinear( - fx: number, - fy: number, - imageData: Uint8ClampedArray, - imageWidth: number, - imageHeight: number, -) { - // Clamp to image boundaries - fx = clamp(fx, 0, imageWidth - 1); - fy = clamp(fy, 0, imageHeight - 1); - - // Get the surrounding coordinates and their weights - const x0 = Math.floor(fx); - const x1 = Math.ceil(fx); - const y0 = Math.floor(fy); - const y1 = Math.ceil(fy); - const dx = fx - x0; - const dy = fy - y0; - const dx1 = 1.0 - dx; - const dy1 = 1.0 - dy; - - // Get the original pixels - const pixel1 = readPixelColor(imageData, imageWidth, imageHeight, x0, y0); - const pixel2 = readPixelColor(imageData, imageWidth, imageHeight, x1, y0); - const pixel3 = readPixelColor(imageData, imageWidth, imageHeight, x0, y1); - const pixel4 = readPixelColor(imageData, imageWidth, imageHeight, x1, y1); - - function bilinear(val1: number, val2: number, val3: number, val4: number) { - return Math.round( - val1 * dx1 * dy1 + - val2 * dx * dy1 + - val3 * dx1 * dy + - val4 * dx * dy, - ); - } - - // Interpolate the pixel values - const red = bilinear(pixel1.r, pixel2.r, pixel3.r, pixel4.r); - const green = bilinear(pixel1.g, pixel2.g, pixel3.g, pixel4.g); - const blue = bilinear(pixel1.b, pixel2.b, pixel3.b, pixel4.b); - - return { r: red, g: green, b: blue }; -} - -export function warpAffineFloat32List( - imageBitmap: ImageBitmap, - faceAlignment: FaceAlignment, - faceSize: number, - inputData: Float32Array, - inputStartIndex: number, -): void { - // Get the pixel data - const offscreenCanvas = new OffscreenCanvas( - imageBitmap.width, - imageBitmap.height, - ); - const ctx = offscreenCanvas.getContext("2d"); - ctx.drawImage(imageBitmap, 0, 0, imageBitmap.width, imageBitmap.height); - const imageData = ctx.getImageData( - 0, - 0, - imageBitmap.width, - imageBitmap.height, - ); - const pixelData = imageData.data; - - const transformationMatrix = faceAlignment.affineMatrix.map((row) => - row.map((val) => (val != 1.0 ? 
val * faceSize : 1.0)), - ); // 3x3 - - const A: Matrix = new Matrix([ - [transformationMatrix[0][0], transformationMatrix[0][1]], - [transformationMatrix[1][0], transformationMatrix[1][1]], - ]); - const Ainverse = inverse(A); - - const b00 = transformationMatrix[0][2]; - const b10 = transformationMatrix[1][2]; - const a00Prime = Ainverse.get(0, 0); - const a01Prime = Ainverse.get(0, 1); - const a10Prime = Ainverse.get(1, 0); - const a11Prime = Ainverse.get(1, 1); - - for (let yTrans = 0; yTrans < faceSize; ++yTrans) { - for (let xTrans = 0; xTrans < faceSize; ++xTrans) { - // Perform inverse affine transformation - const xOrigin = - a00Prime * (xTrans - b00) + a01Prime * (yTrans - b10); - const yOrigin = - a10Prime * (xTrans - b00) + a11Prime * (yTrans - b10); - - // Get the pixel from interpolation - const pixel = getPixelBicubic( - xOrigin, - yOrigin, - pixelData, - imageBitmap.width, - imageBitmap.height, - ); - - // Set the pixel in the input data - const index = (yTrans * faceSize + xTrans) * 3; - inputData[inputStartIndex + index] = - normalizePixelBetweenMinus1And1(pixel.r); - inputData[inputStartIndex + index + 1] = - normalizePixelBetweenMinus1And1(pixel.g); - inputData[inputStartIndex + index + 2] = - normalizePixelBetweenMinus1And1(pixel.b); - } - } -} - -export function createGrayscaleIntMatrixFromNormalized2List( - imageList: Float32Array, - faceNumber: number, - width: number = 112, - height: number = 112, -): number[][] { - const startIndex = faceNumber * width * height * 3; - return Array.from({ length: height }, (_, y) => - Array.from({ length: width }, (_, x) => { - // 0.299 ∙ Red + 0.587 ∙ Green + 0.114 ∙ Blue - const pixelIndex = startIndex + 3 * (y * width + x); - return clamp( - Math.round( - 0.299 * - unnormalizePixelFromBetweenMinus1And1( - imageList[pixelIndex], - ) + - 0.587 * - unnormalizePixelFromBetweenMinus1And1( - imageList[pixelIndex + 1], - ) + - 0.114 * - unnormalizePixelFromBetweenMinus1And1( - imageList[pixelIndex + 2], - ), - ), - 0, - 255, - ); - }), - ); -} - -export function resizeToSquare(img: ImageBitmap, size: number) { - const scale = size / Math.max(img.height, img.width); - const width = scale * img.width; - const height = scale * img.height; - const offscreen = new OffscreenCanvas(size, size); - const ctx = offscreen.getContext("2d"); - ctx.imageSmoothingQuality = "high"; - ctx.drawImage(img, 0, 0, width, height); - const resizedImage = offscreen.transferToImageBitmap(); - return { image: resizedImage, width, height }; -} - -export function transform( - imageBitmap: ImageBitmap, - affineMat: number[][], - outputWidth: number, - outputHeight: number, -) { - const offscreen = new OffscreenCanvas(outputWidth, outputHeight); - const context = offscreen.getContext("2d"); - context.imageSmoothingQuality = "high"; - - context.transform( - affineMat[0][0], - affineMat[1][0], - affineMat[0][1], - affineMat[1][1], - affineMat[0][2], - affineMat[1][2], - ); - - context.drawImage(imageBitmap, 0, 0); - return offscreen.transferToImageBitmap(); -} - -export function crop(imageBitmap: ImageBitmap, cropBox: Box, size: number) { - const dimensions: Dimensions = { - width: size, - height: size, - }; - - return cropWithRotation(imageBitmap, cropBox, 0, dimensions, dimensions); -} - -export function cropWithRotation( - imageBitmap: ImageBitmap, - cropBox: Box, - rotation?: number, - maxSize?: Dimensions, - minSize?: Dimensions, -) { - const box = cropBox.round(); - - const outputSize = { width: box.width, height: box.height }; - if (maxSize) { - const minScale 
= Math.min( - maxSize.width / box.width, - maxSize.height / box.height, - ); - if (minScale < 1) { - outputSize.width = Math.round(minScale * box.width); - outputSize.height = Math.round(minScale * box.height); - } - } - - if (minSize) { - const maxScale = Math.max( - minSize.width / box.width, - minSize.height / box.height, - ); - if (maxScale > 1) { - outputSize.width = Math.round(maxScale * box.width); - outputSize.height = Math.round(maxScale * box.height); - } - } - - // log.info({ imageBitmap, box, outputSize }); - - const offscreen = new OffscreenCanvas(outputSize.width, outputSize.height); - const offscreenCtx = offscreen.getContext("2d"); - offscreenCtx.imageSmoothingQuality = "high"; - - offscreenCtx.translate(outputSize.width / 2, outputSize.height / 2); - rotation && offscreenCtx.rotate(rotation); - - const outputBox = new Box({ - x: -outputSize.width / 2, - y: -outputSize.height / 2, - width: outputSize.width, - height: outputSize.height, - }); - - const enlargedBox = enlargeBox(box, 1.5); - const enlargedOutputBox = enlargeBox(outputBox, 1.5); - - offscreenCtx.drawImage( - imageBitmap, - enlargedBox.x, - enlargedBox.y, - enlargedBox.width, - enlargedBox.height, - enlargedOutputBox.x, - enlargedOutputBox.y, - enlargedOutputBox.width, - enlargedOutputBox.height, - ); - - return offscreen.transferToImageBitmap(); -} - -export function addPadding(image: ImageBitmap, padding: number) { - const scale = 1 + padding * 2; - const width = scale * image.width; - const height = scale * image.height; - const offscreen = new OffscreenCanvas(width, height); - const ctx = offscreen.getContext("2d"); - ctx.imageSmoothingEnabled = false; - ctx.drawImage( - image, - width / 2 - image.width / 2, - height / 2 - image.height / 2, - image.width, - image.height, - ); - - return offscreen.transferToImageBitmap(); -} - -export interface BlobOptions { - type?: string; - quality?: number; -} - -export async function imageBitmapToBlob(imageBitmap: ImageBitmap) { - const offscreen = new OffscreenCanvas( - imageBitmap.width, - imageBitmap.height, - ); - offscreen.getContext("2d").drawImage(imageBitmap, 0, 0); - - return offscreen.convertToBlob({ - type: "image/jpeg", - quality: 0.8, - }); -} - -export async function imageBitmapFromBlob(blob: Blob) { - return createImageBitmap(blob); -} diff --git a/web/apps/photos/src/worker/ffmpeg.worker.ts b/web/apps/photos/src/worker/ffmpeg.worker.ts index d9d6c718f..06ba05be9 100644 --- a/web/apps/photos/src/worker/ffmpeg.worker.ts +++ b/web/apps/photos/src/worker/ffmpeg.worker.ts @@ -82,7 +82,7 @@ const ffmpegExec = async ( const result = ffmpeg.FS("readFile", outputPath); - const ms = Math.round(Date.now() - startTime); + const ms = Date.now() - startTime; log.debug(() => `[wasm] ffmpeg ${cmd.join(" ")} (${ms} ms)`); return result; } finally { diff --git a/web/package.json b/web/package.json index 25fea2176..ec096189a 100644 --- a/web/package.json +++ b/web/package.json @@ -22,10 +22,10 @@ "dev": "yarn dev:photos", "dev:accounts": "yarn workspace accounts next dev -p 3001", "dev:albums": "yarn workspace photos next dev -p 3002", - "dev:auth": "yarn workspace auth next dev", + "dev:auth": "yarn workspace auth next dev -p 3000", "dev:cast": "yarn workspace cast next dev -p 3001", "dev:payments": "yarn workspace payments dev", - "dev:photos": "yarn workspace photos next dev", + "dev:photos": "yarn workspace photos next dev -p 3000", "dev:staff": "yarn workspace staff dev", "lint": "concurrently --names 'prettier,eslint,tsc' \"yarn prettier --check --log-level warn .\" 
\"yarn workspaces run eslint --report-unused-disable-directives .\" \"yarn workspaces run tsc\"", "lint-fix": "concurrently --names 'prettier,eslint,tsc' \"yarn prettier --write --log-level warn .\" \"yarn workspaces run eslint --report-unused-disable-directives --fix .\" \"yarn workspaces run tsc\"", diff --git a/web/packages/accounts/services/logout.ts b/web/packages/accounts/services/logout.ts index 70d67b22f..1858ec7cc 100644 --- a/web/packages/accounts/services/logout.ts +++ b/web/packages/accounts/services/logout.ts @@ -1,4 +1,4 @@ -import { clearCaches } from "@/next/blob-cache"; +import { clearBlobCaches } from "@/next/blob-cache"; import log from "@/next/log"; import InMemoryStore from "@ente/shared/storage/InMemoryStore"; import localForage from "@ente/shared/storage/localForage"; @@ -43,7 +43,7 @@ export const accountLogout = async () => { log.error("Ignoring error during logout (local forage)", e); } try { - await clearCaches(); + await clearBlobCaches(); } catch (e) { log.error("Ignoring error during logout (cache)", e); } diff --git a/web/packages/next/blob-cache.ts b/web/packages/next/blob-cache.ts index e6c3734df..7223d0fdc 100644 --- a/web/packages/next/blob-cache.ts +++ b/web/packages/next/blob-cache.ts @@ -20,8 +20,8 @@ export type BlobCacheNamespace = (typeof blobCacheNames)[number]; * * This cache is suitable for storing large amounts of data (entire files). * - * To obtain a cache for a given namespace, use {@link openCache}. To clear all - * cached data (e.g. during logout), use {@link clearCaches}. + * To obtain a cache for a given namespace, use {@link openBlobCache}. To clear all + * cached data (e.g. during logout), use {@link clearBlobCaches}. * * [Note: Caching files] * @@ -69,14 +69,31 @@ export interface BlobCache { delete: (key: string) => Promise; } +const cachedCaches = new Map(); + /** * Return the {@link BlobCache} corresponding to the given {@link name}. * + * This is a wrapper over {@link openBlobCache} that caches (pun intended) the + * cache and returns the same one each time it is called with the same name. + * It'll open the cache lazily the first time it is invoked. + */ +export const blobCache = async ( + name: BlobCacheNamespace, +): Promise => { + let c = cachedCaches.get(name); + if (!c) cachedCaches.set(name, (c = await openBlobCache(name))); + return c; +}; + +/** + * Create a new {@link BlobCache} corresponding to the given {@link name}. + * * @param name One of the arbitrary but predefined namespaces of type * {@link BlobCacheNamespace} which group related data and allow us to use the * same key across namespaces. */ -export const openCache = async ( +export const openBlobCache = async ( name: BlobCacheNamespace, ): Promise => isElectron() ? openOPFSCacheWeb(name) : openWebCache(name); @@ -194,7 +211,7 @@ export const cachedOrNew = async ( key: string, get: () => Promise, ): Promise => { - const cache = await openCache(cacheName); + const cache = await openBlobCache(cacheName); const cachedBlob = await cache.get(key); if (cachedBlob) return cachedBlob; @@ -204,15 +221,17 @@ export const cachedOrNew = async ( }; /** - * Delete all cached data. + * Delete all cached data, including cached caches. * * Meant for use during logout, to reset the state of the user's account. */ -export const clearCaches = async () => - isElectron() ? clearOPFSCaches() : clearWebCaches(); +export const clearBlobCaches = async () => { + cachedCaches.clear(); + return isElectron() ? 
clearOPFSCaches() : clearWebCaches(); +}; const clearWebCaches = async () => { - await Promise.all(blobCacheNames.map((name) => caches.delete(name))); + await Promise.allSettled(blobCacheNames.map((name) => caches.delete(name))); }; const clearOPFSCaches = async () => { diff --git a/web/packages/next/locales/es-ES/translation.json b/web/packages/next/locales/es-ES/translation.json index ec46bb718..abd06b510 100644 --- a/web/packages/next/locales/es-ES/translation.json +++ b/web/packages/next/locales/es-ES/translation.json @@ -352,7 +352,7 @@ "ADD_COLLABORATORS": "", "ADD_NEW_EMAIL": "", "shared_with_people_zero": "", - "shared_with_people_one": "", + "shared_with_people_one": "Compartido con 1 persona", "shared_with_people_other": "", "participants_zero": "", "participants_one": "", @@ -362,8 +362,8 @@ "CHANGE_PERMISSIONS_TO_COLLABORATOR": "", "CONVERT_TO_VIEWER": "", "CONVERT_TO_COLLABORATOR": "", - "CHANGE_PERMISSION": "", - "REMOVE_PARTICIPANT": "", + "CHANGE_PERMISSION": "¿Cambiar Permiso?", + "REMOVE_PARTICIPANT": "¿Eliminar?", "CONFIRM_REMOVE": "", "MANAGE": "", "ADDED_AS": "", @@ -415,8 +415,8 @@ "albums_other": "{{count}} álbumes", "ALL_ALBUMS": "Todos los álbumes", "ALBUMS": "Álbumes", - "ALL_HIDDEN_ALBUMS": "", - "HIDDEN_ALBUMS": "", + "ALL_HIDDEN_ALBUMS": "Todos los álbumes ocultos", + "HIDDEN_ALBUMS": "Álbumes ocultos", "HIDDEN_ITEMS": "", "ENTER_TWO_FACTOR_OTP": "Ingrese el código de seis dígitos de su aplicación de autenticación a continuación.", "CREATE_ACCOUNT": "Crear cuenta", @@ -518,7 +518,7 @@ "PUBLIC_COLLECT_SUBTEXT": "Permitir a las personas con el enlace añadir fotos al álbum compartido.", "STOP_EXPORT": "Stop", "EXPORT_PROGRESS": "{{progress.success}} / {{progress.total}} archivos exportados", - "MIGRATING_EXPORT": "", + "MIGRATING_EXPORT": "Preparando...", "RENAMING_COLLECTION_FOLDERS": "", "TRASHING_DELETED_FILES": "", "TRASHING_DELETED_COLLECTIONS": "", @@ -543,7 +543,7 @@ "at": "a las", "AUTH_NEXT": "siguiente", "AUTH_DOWNLOAD_MOBILE_APP": "Descarga nuestra aplicación móvil para administrar tus secretos", - "HIDDEN": "", + "HIDDEN": "Oculto", "HIDE": "Ocultar", "UNHIDE": "Mostrar", "UNHIDE_TO_COLLECTION": "Hacer visible al álbum", @@ -571,7 +571,7 @@ "CONVERT": "", "CONFIRM_EDITOR_CLOSE_MESSAGE": "", "CONFIRM_EDITOR_CLOSE_DESCRIPTION": "", - "BRIGHTNESS": "", + "BRIGHTNESS": "Brillo", "CONTRAST": "", "SATURATION": "", "BLUR": "", @@ -620,7 +620,7 @@ "PASSKEY_LOGIN_FAILED": "", "PASSKEY_LOGIN_URL_INVALID": "", "PASSKEY_LOGIN_ERRORED": "", - "TRY_AGAIN": "", + "TRY_AGAIN": "Inténtelo de nuevo", "PASSKEY_FOLLOW_THE_STEPS_FROM_YOUR_BROWSER": "", "LOGIN_WITH_PASSKEY": "", "autogenerated_first_album_name": "", diff --git a/web/packages/next/locales/ru-RU/translation.json b/web/packages/next/locales/ru-RU/translation.json index 2d2af0293..7861d339a 100644 --- a/web/packages/next/locales/ru-RU/translation.json +++ b/web/packages/next/locales/ru-RU/translation.json @@ -168,7 +168,7 @@ "UPDATE_PAYMENT_METHOD": "Обновить платёжную информацию", "MONTHLY": "Ежемесячно", "YEARLY": "Ежегодно", - "update_subscription_title": "", + "update_subscription_title": "Подтвердить изменение плана", "UPDATE_SUBSCRIPTION_MESSAGE": "Хотите сменить текущий план?", "UPDATE_SUBSCRIPTION": "Изменить план", "CANCEL_SUBSCRIPTION": "Отменить подписку", @@ -623,6 +623,6 @@ "TRY_AGAIN": "Пробовать снова", "PASSKEY_FOLLOW_THE_STEPS_FROM_YOUR_BROWSER": "Следуйте инструкциям в вашем браузере, чтобы продолжить вход в систему.", "LOGIN_WITH_PASSKEY": "Войдите в систему с помощью пароля", - 
"autogenerated_first_album_name": "", - "autogenerated_default_album_name": "" + "autogenerated_first_album_name": "Мой первый альбом", + "autogenerated_default_album_name": "Новый альбом" } diff --git a/web/packages/next/types/ipc.ts b/web/packages/next/types/ipc.ts index 7d5866cdb..806a00cd5 100644 --- a/web/packages/next/types/ipc.ts +++ b/web/packages/next/types/ipc.ts @@ -297,7 +297,9 @@ export interface Electron { * * @returns A CLIP embedding. */ - clipImageEmbedding: (jpegImageData: Uint8Array) => Promise; + computeCLIPImageEmbedding: ( + jpegImageData: Uint8Array, + ) => Promise; /** * Return a CLIP embedding of the given image if we already have the model @@ -319,7 +321,7 @@ export interface Electron { * * @returns A CLIP embedding. */ - clipTextEmbeddingIfAvailable: ( + computeCLIPTextEmbeddingIfAvailable: ( text: string, ) => Promise; @@ -337,29 +339,7 @@ export interface Electron { * Both the input and output are opaque binary data whose internal structure * is specific to our implementation and the model (MobileFaceNet) we use. */ - faceEmbeddings: (input: Float32Array) => Promise; - - /** - * Return a face crop stored by a previous version of ML. - * - * [Note: Legacy face crops] - * - * Older versions of ML generated and stored face crops in a "face-crops" - * cache directory on the Electron side. For the time being, we have - * disabled the face search whilst we put finishing touches to it. However, - * it'll be nice to still show the existing faces that have been clustered - * for people who opted in to the older beta. - * - * So we retain the older "face-crops" disk cache, and use this method to - * serve faces from it when needed. - * - * @param faceID An identifier corresponding to which the face crop had been - * stored by the older version of our app. - * - * @returns the JPEG data of the face crop if a file is found for the given - * {@link faceID}, otherwise undefined. 
@@ -337,29 +339,7 @@
      * Both the input and output are opaque binary data whose internal structure
      * is specific to our implementation and the model (MobileFaceNet) we use.
      */
-    faceEmbeddings: (input: Float32Array) => Promise<Float32Array>;
-
-    /**
-     * Return a face crop stored by a previous version of ML.
-     *
-     * [Note: Legacy face crops]
-     *
-     * Older versions of ML generated and stored face crops in a "face-crops"
-     * cache directory on the Electron side. For the time being, we have
-     * disabled the face search whilst we put finishing touches to it. However,
-     * it'll be nice to still show the existing faces that have been clustered
-     * for people who opted in to the older beta.
-     *
-     * So we retain the older "face-crops" disk cache, and use this method to
-     * serve faces from it when needed.
-     *
-     * @param faceID An identifier corresponding to which the face crop had been
-     * stored by the older version of our app.
-     *
-     * @returns the JPEG data of the face crop if a file is found for the given
-     * {@link faceID}, otherwise undefined.
-     */
-    legacyFaceCrop: (faceID: string) => Promise<Uint8Array | undefined>;
+    computeFaceEmbeddings: (input: Float32Array) => Promise<Float32Array>;
 
     // - Watch
diff --git a/web/packages/next/worker/comlink-worker.ts b/web/packages/next/worker/comlink-worker.ts
index cb90d85f8..b388cd413 100644
--- a/web/packages/next/worker/comlink-worker.ts
+++ b/web/packages/next/worker/comlink-worker.ts
@@ -47,8 +47,8 @@ const workerBridge = {
     convertToJPEG: (imageData: Uint8Array) =>
         ensureElectron().convertToJPEG(imageData),
     detectFaces: (input: Float32Array) => ensureElectron().detectFaces(input),
-    faceEmbeddings: (input: Float32Array) =>
-        ensureElectron().faceEmbeddings(input),
+    computeFaceEmbeddings: (input: Float32Array) =>
+        ensureElectron().computeFaceEmbeddings(input),
 };
 
 export type WorkerBridge = typeof workerBridge;
diff --git a/web/packages/shared/apps/constants.ts b/web/packages/shared/apps/constants.ts
index d35a5e8c4..b679fb912 100644
--- a/web/packages/shared/apps/constants.ts
+++ b/web/packages/shared/apps/constants.ts
@@ -14,6 +14,8 @@ export const CLIENT_PACKAGE_NAMES = new Map([
     [APPS.ACCOUNTS, "io.ente.accounts.web"],
 ]);
 
+export const clientPackageNamePhotosDesktop = "io.ente.photos.desktop";
+
 export const APP_TITLES = new Map([
     [APPS.ALBUMS, "Ente Albums"],
     [APPS.PHOTOS, "Ente Photos"],
diff --git a/web/packages/shared/network/HTTPService.ts b/web/packages/shared/network/HTTPService.ts
index eda0709f5..7ef99e0d7 100644
--- a/web/packages/shared/network/HTTPService.ts
+++ b/web/packages/shared/network/HTTPService.ts
@@ -28,8 +28,8 @@ class HTTPService {
             const responseData = response.data;
             log.error(
                 `HTTP Service Error - ${JSON.stringify({
-                    url: config.url,
-                    method: config.method,
+                    url: config?.url,
+                    method: config?.method,
                     xRequestId: response.headers["x-request-id"],
                     httpStatus: response.status,
                     errMessage: responseData.message,
diff --git a/web/packages/utils/promise.ts b/web/packages/utils/promise.ts
index 4cb7648fd..34f821b6d 100644
--- a/web/packages/utils/promise.ts
+++ b/web/packages/utils/promise.ts
@@ -10,6 +10,10 @@ export const wait = (ms: number) =>
 /**
  * Await the given {@link promise} for {@link ms} milliseconds. If it
  * does not resolve within {@link ms}, then reject with a timeout error.
+ *
+ * Note that this does not abort {@link promise} itself; it will still run to
+ * completion, but its result will be ignored if it settles after we have
+ * already timed out.
  */
 export const withTimeout = async <T>(promise: Promise<T>, ms: number) => {
     let timeoutId: ReturnType<typeof setTimeout>;
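The diff ends above, cut off inside `withTimeout`. For reference, here is a TypeScript sketch of an implementation consistent with the documented semantics: race the promise against a timer, clear the timer either way, and let the original promise run to completion. This is an illustrative reconstruction under those assumptions, not necessarily the file's exact code.

// Sketch: reject after `ms` milliseconds, but never abort the original promise.
const withTimeout = async <T>(promise: Promise<T>, ms: number): Promise<T> => {
    let timeoutId: ReturnType<typeof setTimeout>;
    const timeout = new Promise<T>((_, reject) => {
        timeoutId = setTimeout(() => reject(new Error("Timed out")), ms);
    });
    try {
        // Whichever settles first wins; a late result is simply ignored.
        return await Promise.race([promise, timeout]);
    } finally {
        clearTimeout(timeoutId!);
    }
};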