import "dart:async";
import "dart:collection";
import "dart:io";

import "package:clip_ggml/clip_ggml.dart";
import "package:computer/computer.dart";
import "package:logging/logging.dart";
import "package:photos/core/configuration.dart";
import "package:photos/core/event_bus.dart";
import "package:photos/db/files_db.dart";
import "package:photos/db/object_box.dart";
import "package:photos/events/file_indexed_event.dart";
import "package:photos/events/file_uploaded_event.dart";
import "package:photos/events/sync_status_update_event.dart";
import "package:photos/models/embedding.dart";
import "package:photos/models/file/file.dart";
import "package:photos/services/semantic_search/embedding_store.dart";
import "package:photos/services/semantic_search/model_loader.dart";
import "package:photos/utils/local_settings.dart";
import "package:photos/utils/thumbnail_util.dart";
import "package:shared_preferences/shared_preferences.dart";

/// Runs CLIP-based semantic search over the user's uploaded files.
///
/// Image embeddings are computed for uploaded files on a [Computer] worker,
/// persisted through [EmbeddingStore], and mirrored in an in-memory cache so
/// that a text query can be scored against every indexed image.
class SemanticSearchService {
  SemanticSearchService._privateConstructor();

  static final SemanticSearchService instance =
      SemanticSearchService._privateConstructor();
  static final Computer _computer = Computer.shared();

  static const kModelName = "ggml-clip";
  static const kEmbeddingLength = 512;
  static const kScoreThreshold = 0.23;

  final _logger = Logger("SemanticSearchService");

  /// Files waiting for an image embedding; drained by [_pollQueue].
  final _queue = Queue<EnteFile>();

  /// In-memory copy of the stored embeddings, used for scoring queries.
  final _cachedEmbeddings = <Embedding>[];

  /// Whether the image and text models have finished loading.
  bool hasLoaded = false;

  /// Whether [_pollQueue] is currently draining the queue.
  bool isComputingEmbeddings = false;

  /// The query currently being scored, if any.
  Future<List<EnteFile>>? _ongoingRequest;

  /// The most recent query that arrived while [_ongoingRequest] was running.
  PendingQuery? _nextQuery;

  /// Initializes storage and models and starts background indexing.
  ///
  /// Re-caches embeddings after each diff sync and queues newly uploaded
  /// files for indexing. Does nothing on iOS.
  Future<void> init(SharedPreferences preferences) async {
    if (Platform.isIOS) {
      return;
    }
    await EmbeddingStore.instance.init(preferences);
    await ModelLoader.instance.init(_computer);
    unawaited(_cacheEmbeddings());
    Bus.instance.on<SyncStatusUpdateEvent>().listen((event) async {
      if (event.status == SyncStatus.diffSynced) {
        await EmbeddingStore.instance.pullEmbeddings();
        await _cacheEmbeddings();
      }
    });
    if (Configuration.instance.hasConfiguredAccount()) {
      EmbeddingStore.instance.pushEmbeddings();
    }
    _loadModels().then((_) {
      startBackFill();
      // Run one throwaway query so the first real search is not slowed by
      // text-encoder warm-up.
      _getTextEmbedding("warm up text encoder");
    });
    Bus.instance.on<FileUploadedEvent>().listen((event) {
      addToQueue(event.file);
    });
  }

  /// Returns the files matching [query], best match first.
  ///
  /// Only one query is scored at a time. A query that arrives while another
  /// is running is parked and runs when the current one finishes. If a newer
  /// query replaces a parked one, the replaced query completes with an empty
  /// list, so its caller is never left waiting. Returns an empty list on iOS.
  Future<List<EnteFile>> search(String query) async {
    if (Platform.isIOS) {
      return [];
    }
    if (_ongoingRequest != null) {
      // Only the latest parked query is worth running; release the previous
      // one with no results instead of abandoning its completer.
      _nextQuery?.completer.complete(<EnteFile>[]);
      final pending = PendingQuery(query, Completer<List<EnteFile>>());
      _nextQuery = pending;
      return pending.completer.future;
    }
    // whenComplete also runs on failure, so one failed query cannot block
    // every later search.
    final request = getMatchingFiles(query).whenComplete(() {
      _ongoingRequest = null;
      _runNextQuery();
    });
    _ongoingRequest = request;
    return request;
  }

  /// Starts the parked query, if any, and forwards its outcome to the
  /// waiting caller.
  void _runNextQuery() {
    final next = _nextQuery;
    if (next == null) {
      return;
    }
    _nextQuery = null;
    search(next.query).then(
      (result) => next.completer.complete(result),
      onError: (Object e, StackTrace s) => next.completer.completeError(e, s),
    );
  }

  /// Scores every cached image embedding against [query].
  ///
  /// Returns the files that have a local record, best match first.
  Future<List<EnteFile>> getMatchingFiles(String query) async {
    final textEmbedding = await _getTextEmbedding(query);

    final scoreWatch = Stopwatch()..start();
    final List<QueryResult> queryResults = await _computer.compute(
      computeBulkScore,
      param: {
        "imageEmbeddings": _cachedEmbeddings,
        "textEmbedding": textEmbedding,
        "scoreThreshold": kScoreThreshold,
      },
      taskName: "computeBulkScore",
    );
    _logger.info("computingScores took: ${scoreWatch.elapsedMilliseconds}ms");

    final fetchWatch = Stopwatch()..start();
    final filesMap = await FilesDB.instance
        .getFilesFromIDs(queryResults.map((e) => e.id).toList());
    // Iterate queryResults (not the map) to preserve the score ordering.
    final results = <EnteFile>[
      for (final result in queryResults)
        if (filesMap.containsKey(result.id)) filesMap[result.id]!,
    ];
    _logger.info("Fetching files took: ${fetchWatch.elapsedMilliseconds}ms");
    _logger.info("${results.length} results");
    return results;
  }

  /// Queues [file] for embedding, if magic search is enabled.
  void addToQueue(EnteFile file) {
    if (!LocalSettings.instance.hasEnabledMagicSearch()) {
      return;
    }
    _logger.info("Adding $file to the queue");
    _queue.add(file);
    _pollQueue();
  }

  /// Returns the number of stored embeddings and the number of queued files.
  Future<IndexStatus> getIndexStatus() async {
    // count() avoids loading every embedding just to measure the box.
    final indexedItems = ObjectBox.instance.getEmbeddingBox().count();
    return IndexStatus(indexedItems, _queue.length);
  }

  /// Loads the image and text models, then marks the service ready.
  Future<void> _loadModels() async {
    await ModelLoader.instance.loadImageModel();
    await ModelLoader.instance.loadTextModel();
    hasLoaded = true;
  }

  /// Queues every file owned by the current user that has no cached
  /// embedding yet, then starts draining the queue.
  ///
  /// Does nothing when magic search is disabled or no user is signed in.
  Future<void> startBackFill() async {
    if (!LocalSettings.instance.hasEnabledMagicSearch()) {
      return;
    }
    final userID = Configuration.instance.getUserID();
    if (userID == null) {
      // No signed-in user, so there are no owned uploads to index.
      return;
    }
    final uploadedFileIDs = await FilesDB.instance.getOwnedFileIDs(userID);
    final embeddedFileIDs = _cachedEmbeddings.map((e) => e.fileID).toSet();
    uploadedFileIDs.removeWhere((id) => embeddedFileIDs.contains(id));
    final files = await FilesDB.instance.getUploadedFiles(uploadedFileIDs);
    _logger.info("${files.length} pending to be embedded");
    _queue.addAll(files);
    _pollQueue();
  }

  /// Drops every file waiting to be embedded.
  Future<void> clearQueue() async {
    _queue.clear();
  }

  /// Embeds queued files one at a time, most recently queued first.
  ///
  /// Re-entrant calls are no-ops while a drain is running. The drain also
  /// waits until the models are loaded, so queued files stay queued instead
  /// of being discarded.
  Future<void> _pollQueue() async {
    if (isComputingEmbeddings || !hasLoaded) {
      return;
    }
    isComputingEmbeddings = true;
    try {
      while (_queue.isNotEmpty) {
        await _computeImageEmbedding(_queue.removeLast());
      }
    } finally {
      isComputingEmbeddings = false;
    }
  }

  /// Computes, stores and caches the image embedding for [file].
  ///
  /// The embedding is computed from the file's thumbnail. Failures are
  /// logged and the file is skipped.
  Future<void> _computeImageEmbedding(EnteFile file) async {
    if (!hasLoaded) {
      return;
    }
    try {
      final thumbnail = await getThumbnailForUploadedFile(file);
      if (thumbnail == null) {
        _logger.warning("Could not get thumbnail for $file");
        return;
      }
      _logger.info("Running clip over $file");
      final stopwatch = Stopwatch()..start();
      final result = await _computer.compute(
        createImageEmbedding,
        param: {
          "imagePath": thumbnail.path,
        },
        taskName: "createImageEmbedding",
      ) as List<double>;
      _logger.info(
        "createImageEmbedding took: ${stopwatch.elapsedMilliseconds}ms",
      );
      if (result.length != kEmbeddingLength) {
        _logger.severe("Discovered incorrect embedding for $file - $result");
        return;
      }
      final embedding = Embedding(
        fileID: file.uploadedFileID!,
        model: kModelName,
        embedding: result,
      );
      await EmbeddingStore.instance.storeEmbedding(
        file,
        embedding,
      );
      // Cache before announcing, so listeners that search immediately can
      // already see this file.
      _cachedEmbeddings.add(embedding);
      Bus.instance.fire(FileIndexedEvent());
    } catch (e, s) {
      _logger.severe(e, s);
    }
  }

  /// Computes the CLIP text embedding for [query] on a worker isolate.
  Future<List<double>> _getTextEmbedding(String query) async {
    _logger.info("Searching for $query");
    final stopwatch = Stopwatch()..start();
    final List<double> embedding = await _computer.compute(
      createTextEmbedding,
      param: {
        "text": query,
      },
      taskName: "createTextEmbedding",
    );
    _logger.info(
      "createTextEmbedding took: ${stopwatch.elapsedMilliseconds}ms",
    );
    return embedding;
  }

  /// Replaces the in-memory embedding cache with the contents of ObjectBox.
  Future<void> _cacheEmbeddings() async {
    final stopwatch = Stopwatch()..start();
    final embeddings = ObjectBox.instance.store.box<Embedding>().getAll();
    _cachedEmbeddings
      ..clear()
      ..addAll(embeddings);
    _logger.info(
      "Loading ${embeddings.length} embeddings took: ${stopwatch.elapsedMilliseconds}ms",
    );
  }
}

/// Returns the CLIP embedding of the image at `args["imagePath"]`.
///
/// Top-level so it can run on a [Computer] worker isolate.
List<double> createImageEmbedding(Map args) {
  return CLIP.createImageEmbedding(args["imagePath"]);
}

/// Returns the CLIP embedding of the text in `args["text"]`.
///
/// Top-level so it can run on a [Computer] worker isolate.
List<double> createTextEmbedding(Map args) {
  return CLIP.createTextEmbedding(args["text"]);
}

/// Scores each embedding in `args["imageEmbeddings"]` against
/// `args["textEmbedding"]`.
///
/// Returns the results whose score is at least `args["scoreThreshold"]`,
/// highest score first. The threshold defaults to
/// [SemanticSearchService.kScoreThreshold] when absent. Top-level so it can
/// run on a [Computer] worker isolate.
List<QueryResult> computeBulkScore(Map args) {
  final imageEmbeddings = args["imageEmbeddings"] as List<Embedding>;
  final textEmbedding = args["textEmbedding"] as List<double>;
  final scoreThreshold = (args["scoreThreshold"] as double?) ??
      SemanticSearchService.kScoreThreshold;
  final queryResults = <QueryResult>[];
  for (final imageEmbedding in imageEmbeddings) {
    final score = CLIP.computeScore(
      imageEmbedding.embedding,
      textEmbedding,
    );
    if (score >= scoreThreshold) {
      queryResults.add(QueryResult(imageEmbedding.fileID, score));
    }
  }
  queryResults.sort((first, second) => second.score.compareTo(first.score));
  return queryResults;
}

/// A file ID paired with its similarity score for a query.
class QueryResult {
  final int id;
  final double score;

  QueryResult(this.id, this.score);
}

/// A search query waiting for the in-flight query to finish.
///
/// [completer] delivers its results to the caller.
class PendingQuery {
  final String query;
  final Completer<List<EnteFile>> completer;

  PendingQuery(this.query, this.completer);
}

/// Snapshot of indexing progress.
///
/// [indexedItems] is the number of stored embeddings; [pendingItems] is the
/// number of files still queued for embedding.
class IndexStatus {
  final int indexedItems, pendingItems;

  IndexStatus(this.indexedItems, this.pendingItems);
}