import "dart:async";
import "dart:convert";
import "dart:typed_data";

import "package:computer/computer.dart";
import "package:logging/logging.dart";
import "package:photos/core/network/network.dart";
import "package:photos/db/files_db.dart";
import "package:photos/db/object_box.dart";
import "package:photos/models/embedding.dart";
import "package:photos/models/file/file.dart";
import "package:photos/objectbox.g.dart";
import "package:photos/services/semantic_search/remote_embedding.dart";
import "package:photos/utils/crypto_util.dart";
import "package:photos/utils/file_download_util.dart";
import "package:shared_preferences/shared_preferences.dart";

/// Keeps locally stored [Embedding]s in sync with the remote `/embeddings`
/// API.
///
/// It pulls encrypted embeddings from the server and decrypts them with each
/// file's key. It also uploads local embeddings that have not been
/// acknowledged by the server yet.
class EmbeddingStore {
  EmbeddingStore._privateConstructor();

  static final EmbeddingStore instance =
      EmbeddingStore._privateConstructor();

  /// [SharedPreferences] key holding the `sinceTime` cursor sent to
  /// `/embeddings/diff`.
  static const kEmbeddingsSyncTimeKey = "embeddings_sync_time";

  final _logger = Logger("EmbeddingStore");
  final _dio = NetworkClient.instance.enteDio;

  late SharedPreferences _preferences;

  /// Whether a [pullEmbeddings] run is currently in progress.
  bool isSyncing = false;

  /// Stores [preferences], which hold the sync cursor.
  ///
  /// Must be called before [pullEmbeddings], because the cursor is read from
  /// and written to these preferences.
  Future<void> init(SharedPreferences preferences) async {
    _preferences = preferences;
  }

  /// Pages through `/embeddings/diff` and stores every fetched embedding
  /// locally.
  ///
  /// Returns immediately if a pull is already running. Exceptions from
  /// storing a page propagate to the caller.
  Future<void> pullEmbeddings() async {
    if (isSyncing) {
      return;
    }
    isSyncing = true;
    try {
      RemoteEmbeddings remoteEmbeddings;
      do {
        remoteEmbeddings = await _getRemoteEmbeddings();
        await _storeRemoteEmbeddings(remoteEmbeddings.embeddings);
      } while (remoteEmbeddings.hasMore);
    } finally {
      // Always release the flag; otherwise a single failure would turn every
      // later pull into a silent no-op.
      isSyncing = false;
    }
  }

  /// Uploads every locally stored embedding whose `updationTime` is still
  /// null.
  ///
  /// `updationTime` is only set after a successful upload in
  /// [_pushEmbedding]. Embeddings whose file cannot be found locally are
  /// skipped.
  Future<void> pushEmbeddings() async {
    final query = ObjectBox.instance
        .getEmbeddingBox()
        .query(Embedding_.updationTime.isNull())
        .build();
    final List<Embedding> pendingItems;
    try {
      pendingItems = query.find();
    } finally {
      query.close();
    }
    for (final item in pendingItems) {
      final file = await FilesDB.instance.getAnyUploadedFile(item.fileID);
      if (file == null) {
        // Without the file there is no key to encrypt with. Skip this item
        // instead of aborting the remaining uploads.
        _logger.warning(
          "No uploaded file found for fileID ${item.fileID}, skipping push",
        );
        continue;
      }
      await _pushEmbedding(file, item);
    }
  }

  /// Saves [embedding] locally and starts uploading it in the background.
  ///
  /// Completes once the local write has been issued; the upload is not
  /// awaited. If the upload fails and `updationTime` is still null,
  /// [pushEmbeddings] will pick it up again.
  Future<void> storeEmbedding(EnteFile file, Embedding embedding) async {
    ObjectBox.instance.getEmbeddingBox().put(embedding);
    unawaited(_pushEmbedding(file, embedding));
  }

  /// Encrypts [embedding] with [file]'s key, uploads it and stores the
  /// server-returned `updationTime` locally.
  ///
  /// All failures, including key lookup and encryption, are logged rather
  /// than thrown. This makes it safe to run unawaited.
  Future<void> _pushEmbedding(EnteFile file, Embedding embedding) async {
    try {
      final encryptionKey = getFileKey(file);
      final embeddingJSON = jsonEncode(embedding.embedding);
      final encryptedEmbedding = await CryptoUtil.encryptChaCha(
        Uint8List.fromList(utf8.encode(embeddingJSON)),
        encryptionKey,
      );
      final encryptedData =
          CryptoUtil.bin2base64(encryptedEmbedding.encryptedData!);
      final header = CryptoUtil.bin2base64(encryptedEmbedding.header!);
      final response = await _dio.put(
        "/embeddings",
        data: {
          "fileID": embedding.fileID,
          "model": embedding.model,
          "encryptedEmbedding": encryptedData,
          "decryptionHeader": header,
        },
      );
      embedding.updationTime = response.data["updationTime"];
      ObjectBox.instance.getEmbeddingBox().put(embedding);
    } catch (e, s) {
      _logger.severe(e, s);
    }
  }

  /// Fetches one page of embeddings changed since the stored sync cursor.
  ///
  /// A page holds at most [limit] entries. `hasMore` is true when a full
  /// page was received. Errors are logged rather than thrown; whatever was
  /// parsed before the error is returned.
  Future<RemoteEmbeddings> _getRemoteEmbeddings({
    int limit = 500,
  }) async {
    final remoteEmbeddings = <RemoteEmbedding>[];
    try {
      final sinceTime = _preferences.getInt(kEmbeddingsSyncTimeKey) ?? 0;
      _logger.info("Fetching embeddings since $sinceTime");
      final response = await _dio.get(
        "/embeddings/diff",
        queryParameters: {
          "sinceTime": sinceTime,
          "limit": limit,
        },
      );
      final diff = response.data["diff"] as List;
      for (final entry in diff) {
        remoteEmbeddings.add(RemoteEmbedding.fromMap(entry));
      }
    } catch (e, s) {
      _logger.severe(e, s);
    }
    _logger.info("${remoteEmbeddings.length} embeddings fetched");
    return RemoteEmbeddings(
      remoteEmbeddings,
      remoteEmbeddings.length == limit,
    );
  }

  /// Decrypts and stores the entries of [remoteEmbeddings] whose file is
  /// present locally.
  ///
  /// Afterwards it advances the sync cursor past the whole page. Entries
  /// whose file cannot be found locally are skipped.
  Future<void> _storeRemoteEmbeddings(
    List<RemoteEmbedding> remoteEmbeddings,
  ) async {
    if (remoteEmbeddings.isEmpty) {
      return;
    }
    final inputs = <EmbeddingsDecoderInput>[];
    for (final embedding in remoteEmbeddings) {
      final file = await FilesDB.instance.getAnyUploadedFile(embedding.fileID);
      if (file == null) {
        continue;
      }
      inputs.add(EmbeddingsDecoderInput(embedding, getFileKey(file)));
    }
    if (inputs.isNotEmpty) {
      final List<Embedding> embeddings = await Computer.shared().compute(
        decodeEmbeddings,
        param: {
          "inputs": inputs,
        },
      );
      _logger.info("${embeddings.length} embeddings decoded");
      await ObjectBox.instance.getEmbeddingBox().putManyAsync(embeddings);
      _logger.info("${embeddings.length} embeddings stored");
    }
    // Advance the cursor from the whole fetched page, not just the decoded
    // subset. If every file in the page is missing locally, nothing is
    // decoded. A cursor that did not move would then make pullEmbeddings
    // re-fetch the same page forever.
    final latestUpdate = remoteEmbeddings
        .map((e) => e.updatedAt)
        .reduce((a, b) => a > b ? a : b);
    await _preferences.setInt(kEmbeddingsSyncTimeKey, latestUpdate);
  }
}

/// Decrypts and JSON-decodes each [EmbeddingsDecoderInput] in
/// `args["inputs"]` into an [Embedding].
///
/// Passed to `Computer.shared().compute` by
/// `EmbeddingStore._storeRemoteEmbeddings`, which is why it is a top-level
/// function taking a single map argument.
Future<List<Embedding>> decodeEmbeddings(Map<String, dynamic> args) async {
  final inputs = (args["inputs"] as List).cast<EmbeddingsDecoderInput>();
  final embeddings = <Embedding>[];
  for (final input in inputs) {
    final embeddingData = chachaDecryptData({
      "source": CryptoUtil.base642bin(input.embedding.encryptedEmbedding),
      "key": input.decryptionKey,
      "header": CryptoUtil.base642bin(input.embedding.decryptionHeader),
    });
    // The decrypted payload is a JSON array of numbers (the inverse of the
    // jsonEncode in _pushEmbedding). Ints must be widened to doubles.
    final decoded = jsonDecode(utf8.decode(embeddingData)) as List;
    embeddings.add(
      Embedding(
        fileID: input.embedding.fileID,
        model: input.embedding.model,
        embedding: [for (final value in decoded) (value as num).toDouble()],
        updationTime: input.embedding.updatedAt,
      ),
    );
  }
  return embeddings;
}

/// A [RemoteEmbedding] paired with the file key used to decrypt it.
class EmbeddingsDecoderInput {
  final RemoteEmbedding embedding;
  final Uint8List decryptionKey;

  EmbeddingsDecoderInput(this.embedding, this.decryptionKey);
}