diff --git a/lib/services/semantic_search_service.dart b/lib/services/semantic_search_service.dart
index 9ab66ba43..65c74c969 100644
--- a/lib/services/semantic_search_service.dart
+++ b/lib/services/semantic_search_service.dart
@@ -1,3 +1,4 @@
+import "dart:async";
 import "dart:io";
 
 import "package:clip_ggml/clip_ggml.dart";
@@ -18,97 +19,134 @@ class SemanticSearchService {
   static final Computer _computer = Computer.shared();
 
   bool hasLoaded = false;
-  bool isRunning = false;
   final _logger = Logger("SemanticSearchService");
+
+  /// Lowest score a file may have and still be returned by a search.
+  static const double _minimumScore = 0.27;
+
+  /// The search that is currently executing, or null when none is.
+  Future<List<EnteFile>>? _ongoingRequest;
+
+  /// The latest query received while [_ongoingRequest] was executing.
+  PendingQuery? _nextQuery;
 
   Future init() async {
     await _loadModel();
-    final files = await FilesDB.instance.getFilesWithoutEmbeddings();
-    _logger.info(files.length.toString() + " pending to be embedded");
-    for (final file in files) {
-      await runInference(file, "hello");
+    // Deliberately not awaited, so init() completes without waiting for the
+    // backlog. The error is logged because nothing else observes this future.
+    _computeMissingEmbeddings().catchError((Object error) {
+      _logger.info("Computing missing embeddings failed: " + error.toString());
+    });
+  }
+
+  /// Returns the files matching [query], executing one search at a time.
+  ///
+  /// While a search is executing only the latest additional query is kept. A
+  /// queued query that is replaced by a newer one completes with an empty
+  /// list, so its caller is not left waiting forever.
+  Future<List<EnteFile>> search(String query) async {
+    if (_ongoingRequest != null) {
+      // A future cannot be cancelled from the outside, so the superseded
+      // query is completed explicitly before it is replaced.
+      _nextQuery?.completer.complete(<EnteFile>[]);
+      final pending = PendingQuery(query, Completer<List<EnteFile>>());
+      _nextQuery = pending;
+      return pending.completer.future;
+    }
+    // whenComplete runs after success and after failure, so a failed search
+    // neither blocks later searches nor strands the queued query.
+    final request = getMatchingFiles(query).whenComplete(_startNextQuery);
+    _ongoingRequest = request;
+    return request;
+  }
+
+  /// Marks the service as idle and starts the queued query, if there is one.
+  void _startNextQuery() {
+    _ongoingRequest = null;
+    final next = _nextQuery;
+    if (next != null) {
+      _nextQuery = null;
+      search(next.query).then(
+        next.completer.complete,
+        onError: next.completer.completeError,
+      );
     }
   }
 
-  Future runInference(EnteFile file, String text) async {
-    if (!hasLoaded || isRunning) {
-      return;
-    }
-    isRunning = true;
-    // _logger.info("Running clip");
-    final imagePath = (await getThumbnailFile(file))!.path;
-
+  /// Computes a text embedding for [query], scores it against every stored
+  /// embedding and returns the matching files, highest score first.
+  Future<List<EnteFile>> getMatchingFiles(String query) async {
+    _logger.info("Searching for " + query);
     var startTime = DateTime.now();
-    // ignore: prefer_typing_uninitialized_variables
-    var imageEmbedding;
-    final embeddings = await FilesDB.instance.getAllEmbeddings();
-    bool hasCachedEmbedding = false;
-    for (final embedding in embeddings) {
-      if (embedding.id == file.generatedID) {
-        imageEmbedding = embedding.embedding;
-        hasCachedEmbedding = true;
-        _logger.info("Found cached embedding");
-      }
-    }
-    if (!hasCachedEmbedding) {
-      imageEmbedding ??= await _computer.compute(
-        createImageEmbedding,
-        param: {
-          "imagePath": imagePath,
-        },
-        taskName: "createImageEmbedding",
-      );
-      await FilesDB.instance
-          .insertEmbedding(Embedding(file.generatedID!, imageEmbedding, -1));
-    }
+    final textEmbedding = await _computer.compute(
+      createTextEmbedding,
+      param: {
+        "text": query,
+      },
+      taskName: "createTextEmbedding",
+    );
     var endTime = DateTime.now();
     _logger.info(
-      "createImageEmbedding took: " +
+      "createTextEmbedding took: " +
           (endTime.millisecondsSinceEpoch - startTime.millisecondsSinceEpoch)
               .toString() +
           "ms",
     );
 
-    // startTime = DateTime.now();
-    // final textEmbedding = await _computer.compute(
-    //   createTextEmbedding,
-    //   param: {
-    //     "text": text,
-    //   },
-    //   taskName: "createTextEmbedding",
-    // );
-    // endTime = DateTime.now();
-    // _logger.info(
-    //   "createTextEmbedding took: " +
-    //       (endTime.millisecondsSinceEpoch - startTime.millisecondsSinceEpoch)
-    //           .toString() +
-    //       "ms",
-    // );
+    startTime = DateTime.now();
+    final embeddings = await FilesDB.instance.getAllEmbeddings();
+    endTime = DateTime.now();
+    _logger.info(
+      "Fetching embeddings took: " +
+          (endTime.millisecondsSinceEpoch - startTime.millisecondsSinceEpoch)
+              .toString() +
+          "ms",
+    );
 
-    // startTime = DateTime.now();
-    // final score = computeScore({
-    //   "imageEmbedding": imageEmbedding,
-    //   "textEmbedding": textEmbedding,
-    // });
-    // endTime = DateTime.now();
-    // _logger.info(
-    //   "computeScore took: " +
-    //       (endTime.millisecondsSinceEpoch - startTime.millisecondsSinceEpoch)
-    //           .toString() +
-    //       "ms",
+    startTime = DateTime.now();
+    final queryResults = <QueryResult>[];
+    for (final embedding in embeddings) {
+      final score = computeScore({
+        "imageEmbedding": embedding.embedding,
+        "textEmbedding": textEmbedding,
+      });
+      // Weak matches are skipped here so that they never reach the sort.
+      if (score < _minimumScore) {
+        continue;
+      }
+      queryResults.add(QueryResult(embedding.id, score));
+    }
+    queryResults.sort((first, second) => second.score.compareTo(first.score));
+    endTime = DateTime.now();
+    _logger.info(
+      "computingScores took: " +
+          (endTime.millisecondsSinceEpoch - startTime.millisecondsSinceEpoch)
+              .toString() +
+          "ms",
+    );
 
-    // final score = await _computer.compute(
-    //   computeScore,
-    //   param: {
-    //     "imageEmbedding": imageEmbedding,
-    //     "textEmbedding": textEmbedding,
-    //   },
-    //   taskName: "computeScore",
-    // );
-    // );
+    startTime = DateTime.now();
+    final filesMap = await FilesDB.instance
+        .getFilesFromGeneratedIDs(queryResults.map((e) => e.id).toList());
+    final results = <EnteFile>[];
+    for (final result in queryResults) {
+      // Ids that are missing from filesMap are left out of the results.
+      final file = filesMap[result.id];
+      if (file != null) {
+        results.add(file);
+      }
+    }
+    endTime = DateTime.now();
+    _logger.info(
+      "Fetching files took: " +
+          (endTime.millisecondsSinceEpoch - startTime.millisecondsSinceEpoch)
+              .toString() +
+          "ms",
+    );
 
-    // _logger.info("Score: " + score.toString());
-    isRunning = false;
+    _logger.info(results.length.toString() + " results");
+
+    return results;
   }
 
   Future _loadModel() async {
@@ -141,6 +179,64 @@ class SemanticSearchService {
     final file = await File('${tempDir.path}/$fileName').writeAsBytes(bytes);
     return file.path;
   }
+
+  /// Computes embeddings for the files reported by
+  /// `FilesDB.getFilesWithoutEmbeddings`, one file at a time.
+  Future<void> _computeMissingEmbeddings() async {
+    final files = await FilesDB.instance.getFilesWithoutEmbeddings();
+    _logger.info(files.length.toString() + " pending to be embedded");
+    if (files.isEmpty) {
+      return;
+    }
+    // Read the stored embeddings once per batch instead of once per file;
+    // only their ids are needed to recognise files that are already embedded.
+    final embeddings = await FilesDB.instance.getAllEmbeddings();
+    final embeddedIDs = embeddings.map((embedding) => embedding.id).toSet();
+    for (final file in files) {
+      await _computeImageEmbedding(file, embeddedIDs);
+    }
+  }
+
+  /// Computes the embedding of [file]'s thumbnail and stores it.
+  ///
+  /// Does nothing when [hasLoaded] is false, when [embeddedIDs] already
+  /// contains the file's generated id, or when the file has no thumbnail.
+  Future<void> _computeImageEmbedding(
+    EnteFile file,
+    Set<int> embeddedIDs,
+  ) async {
+    if (!hasLoaded) {
+      return;
+    }
+    if (embeddedIDs.contains(file.generatedID)) {
+      _logger.info("Found cached embedding");
+      return;
+    }
+    // A missing thumbnail skips this file instead of aborting the whole batch.
+    final thumbnail = await getThumbnailFile(file);
+    if (thumbnail == null) {
+      _logger.info("No thumbnail available, skipping embedding");
+      return;
+    }
+
+    final startTime = DateTime.now();
+    final imageEmbedding = await _computer.compute(
+      createImageEmbedding,
+      param: {
+        "imagePath": thumbnail.path,
+      },
+      taskName: "createImageEmbedding",
+    );
+    await FilesDB.instance
+        .insertEmbedding(Embedding(file.generatedID!, imageEmbedding, -1));
+    final endTime = DateTime.now();
+    _logger.info(
+      "createImageEmbedding took: " +
+          (endTime.millisecondsSinceEpoch - startTime.millisecondsSinceEpoch)
+              .toString() +
+          "ms",
+    );
+  }
 }
 
 List createImageEmbedding(Map args) {
@@ -157,3 +253,20 @@ double computeScore(Map args) {
     args["textEmbedding"] as List,
   );
 }
+
+/// Pairs the id of a stored embedding with its score for a query.
+class QueryResult {
+  final int id;
+  final double score;
+
+  QueryResult(this.id, this.score);
+}
+
+/// A query waiting for the running search to finish, together with the
+/// completer that delivers its result to the caller.
+class PendingQuery {
+  final String query;
+  final Completer<List<EnteFile>> completer;
+
+  PendingQuery(this.query, this.completer);
+}