ente/lib/services/semantic_search/semantic_search_service.dart

316 lines
9.2 KiB
Dart
Raw Normal View History

2023-09-22 18:16:03 +00:00
import "dart:async";
2023-10-03 19:07:42 +00:00
import "dart:collection";
2023-10-27 06:45:38 +00:00
import "dart:io";
2023-09-22 07:26:51 +00:00
import "package:clip_ggml/clip_ggml.dart";
2023-09-22 12:08:18 +00:00
import "package:computer/computer.dart";
2023-09-22 07:26:51 +00:00
import "package:logging/logging.dart";
2023-10-24 08:42:22 +00:00
import "package:photos/core/configuration.dart";
2023-10-03 17:08:18 +00:00
import "package:photos/core/event_bus.dart";
2023-09-22 15:47:33 +00:00
import "package:photos/db/files_db.dart";
import "package:photos/db/object_box.dart";
import "package:photos/events/diff_sync_complete_event.dart";
import 'package:photos/events/embedding_updated_event.dart';
import "package:photos/events/file_uploaded_event.dart";
2023-09-22 15:47:33 +00:00
import "package:photos/models/embedding.dart";
import "package:photos/models/file/file.dart";
2023-10-03 17:08:18 +00:00
import "package:photos/services/semantic_search/embedding_store.dart";
2023-10-25 14:04:41 +00:00
import "package:photos/services/semantic_search/model_loader.dart";
2023-10-13 14:53:59 +00:00
import "package:photos/utils/local_settings.dart";
2023-09-22 15:47:33 +00:00
import "package:photos/utils/thumbnail_util.dart";
2023-10-03 17:08:18 +00:00
import "package:shared_preferences/shared_preferences.dart";
2023-09-22 07:26:51 +00:00
/// Builds and queries CLIP embeddings for the user's uploaded files so they
/// can be found with free-text ("magic") search.
///
/// Image embeddings are computed from thumbnails on a [Computer] worker,
/// persisted via [EmbeddingStore], and mirrored in memory from the ObjectBox
/// embedding box.
class SemanticSearchService {
  SemanticSearchService._privateConstructor();

  static final SemanticSearchService instance =
      SemanticSearchService._privateConstructor();

  static final Computer _computer = Computer.shared();

  /// Model identifier recorded on every [Embedding] produced here.
  static const kModelName = "ggml-clip";

  /// Length an image embedding must have to be stored.
  static const kEmbeddingLength = 512;

  /// Minimum score a file needs in order to be returned by [search].
  static const kScoreThreshold = 0.23;

  final _logger = Logger("SemanticSearchService");

  /// Files waiting for an image embedding; consumed from the end.
  final _queue = Queue<EnteFile>();

  /// Whether the image and text models have finished loading.
  bool hasLoaded = false;

  /// Whether [_pollQueue] is currently draining [_queue].
  bool isComputingEmbeddings = false;

  /// The search currently executing, if any.
  Future<List<EnteFile>>? _ongoingRequest;

  /// The most recent search requested while [_ongoingRequest] was running.
  PendingQuery? _nextQuery;

  /// In-memory mirror of the ObjectBox embedding box.
  final _cachedEmbeddings = <Embedding>[];

  /// Initializes the service. Does nothing on iOS.
  ///
  /// Initializes [EmbeddingStore] and [ModelLoader], starts mirroring stored
  /// embeddings, subscribes to diff-sync and upload events, pushes embeddings
  /// when an account is configured, and loads the models in the background.
  Future<void> init(SharedPreferences preferences) async {
    if (Platform.isIOS) {
      return;
    }
    await EmbeddingStore.instance.init(preferences);
    await ModelLoader.instance.init(_computer);
    _setupCachedEmbeddings();
    Bus.instance.on<DiffSyncCompleteEvent>().listen((event) async {
      // Diff sync is complete, we can now pull embeddings from remote.
      try {
        await sync();
      } catch (e, s) {
        _logger.severe("Failed to sync embeddings", e, s);
      }
    });
    if (Configuration.instance.hasConfiguredAccount()) {
      EmbeddingStore.instance.pushEmbeddings();
    }
    unawaited(_loadModelsAndWarmUp());
    Bus.instance.on<FileUploadedEvent>().listen((event) {
      _addToQueue(event.file);
    });
  }

  /// Pulls embeddings from remote, then queues owned files that still lack one.
  Future<void> sync() async {
    await EmbeddingStore.instance.pullEmbeddings();
    await _backFill();
  }

  /// Returns files matching [query], best match first.
  ///
  /// Returns an empty list when magic search is disabled or the models are not
  /// loaded yet. Only one search runs at a time. A query issued while another
  /// is running is deferred until that one finishes. If yet another query
  /// arrives before it starts, the deferred one is superseded and its caller
  /// receives an empty list.
  Future<List<EnteFile>> search(String query) async {
    if (!LocalSettings.instance.hasEnabledMagicSearch() || !hasLoaded) {
      return [];
    }
    if (_ongoingRequest != null) {
      // Resolve the superseded query instead of leaving its caller waiting
      // forever.
      final superseded = _nextQuery;
      if (superseded != null && !superseded.completer.isCompleted) {
        superseded.completer.complete(<EnteFile>[]);
      }
      final pending = PendingQuery(query, Completer<List<EnteFile>>());
      _nextQuery = pending;
      return pending.completer.future;
    }
    final request = _getMatchingFiles(query);
    _ongoingRequest = request;
    try {
      return await request;
    } finally {
      // Always release the slot, even if the request failed; otherwise every
      // later query would be deferred indefinitely.
      _ongoingRequest = null;
      final next = _nextQuery;
      if (next != null) {
        _nextQuery = null;
        next.completer.complete(search(next.query));
      }
    }
  }

  /// Reports the number of cached embeddings and the number of owned files
  /// that do not have an embedding yet (including those already queued).
  Future<IndexStatus> getIndexStatus() async {
    final pendingFileIDs = await _getFileIDsToBeIndexed(excludeQueued: false);
    return IndexStatus(_cachedEmbeddings.length, pendingFileIDs.length);
  }

  /// Keeps [_cachedEmbeddings] in sync with the ObjectBox embedding box and
  /// fires an [EmbeddingUpdatedEvent] after every refresh.
  void _setupCachedEmbeddings() {
    ObjectBox.instance
        .getEmbeddingBox()
        .query()
        .watch(triggerImmediately: true)
        .map((query) => query.find())
        .listen((embeddings) {
      _logger.info("Updated embeddings: ${embeddings.length}");
      _cachedEmbeddings.clear();
      _cachedEmbeddings.addAll(embeddings);
      Bus.instance.fire(EmbeddingUpdatedEvent());
    });
  }

  /// Loads both models, resumes any queued work, then runs one throwaway text
  /// embedding so the first real search does not pay the warm-up cost.
  Future<void> _loadModelsAndWarmUp() async {
    try {
      await _loadModels();
      // Files queued before the models were ready are still in the queue.
      unawaited(_pollQueue());
      await _getTextEmbedding("warm up text encoder");
    } catch (e, s) {
      _logger.severe("Failed to load or warm up models", e, s);
    }
  }

  /// Queues every owned file that has neither an embedding nor a queue entry.
  Future<void> _backFill() async {
    if (!LocalSettings.instance.hasEnabledMagicSearch()) {
      return;
    }
    _logger.info("Attempting backfill");
    final fileIDs = await _getFileIDsToBeIndexed();
    final files = await FilesDB.instance.getUploadedFiles(fileIDs);
    // Another backfill or an upload may have queued some of these files while
    // the database calls above were in flight.
    final queuedFileIDs = _queue.map((e) => e.uploadedFileID).toSet();
    final newFiles = files
        .where((file) => !queuedFileIDs.contains(file.uploadedFileID))
        .toList();
    _logger.info("${newFiles.length} to be embedded");
    _queue.addAll(newFiles);
    unawaited(_pollQueue());
  }

  /// Returns IDs of the current user's files that have no cached embedding.
  ///
  /// When [excludeQueued] is true (the default), files already waiting in the
  /// queue are left out as well.
  Future<List<int>> _getFileIDsToBeIndexed({bool excludeQueued = true}) async {
    final uploadedFileIDs = await FilesDB.instance
        .getOwnedFileIDs(Configuration.instance.getUserID()!);
    final embeddedFileIDs = _cachedEmbeddings.map((e) => e.fileID).toSet();
    final queuedFileIDs = excludeQueued
        ? _queue.map((e) => e.uploadedFileID).toSet()
        : <int?>{};
    uploadedFileIDs.removeWhere(
      (id) => embeddedFileIDs.contains(id) || queuedFileIDs.contains(id),
    );
    return uploadedFileIDs;
  }

  /// Drops every file still waiting for an embedding.
  Future<void> clearQueue() async {
    _queue.clear();
  }

  /// Embeds [query], scores it against the cached image embeddings and
  /// resolves the matching IDs to files, preserving score order.
  Future<List<EnteFile>> _getMatchingFiles(String query) async {
    final textEmbedding = await _getTextEmbedding(query);
    final queryResults = await _getScores(textEmbedding);
    final filesMap = await FilesDB.instance
        .getFilesFromIDs(queryResults.map((e) => e.id).toList());
    final results = <EnteFile>[];
    for (final result in queryResults) {
      final file = filesMap[result.id];
      if (file != null) {
        results.add(file);
      }
    }
    _logger.info("${results.length} results");
    return results;
  }

  /// Queues [file] for embedding when magic search is enabled.
  void _addToQueue(EnteFile file) {
    if (!LocalSettings.instance.hasEnabledMagicSearch()) {
      return;
    }
    _logger.info("Adding $file to the queue");
    _queue.add(file);
    unawaited(_pollQueue());
  }

  /// Loads the image and text models and marks the service as ready.
  Future<void> _loadModels() async {
    await ModelLoader.instance.loadImageModel();
    await ModelLoader.instance.loadTextModel();
    hasLoaded = true;
  }

  /// Drains [_queue] one file at a time unless a drain is already running.
  ///
  /// Does nothing until the models are loaded, so queued files are kept rather
  /// than being discarded.
  Future<void> _pollQueue() async {
    if (isComputingEmbeddings || !hasLoaded) {
      return;
    }
    isComputingEmbeddings = true;
    try {
      while (_queue.isNotEmpty) {
        await _computeImageEmbedding(_queue.removeLast());
      }
    } finally {
      isComputingEmbeddings = false;
    }
  }

  /// Computes and stores the image embedding for [file] from its thumbnail.
  ///
  /// Failures are logged rather than thrown so one bad file does not stop the
  /// queue.
  Future<void> _computeImageEmbedding(EnteFile file) async {
    if (!hasLoaded) {
      return;
    }
    try {
      final thumbnail = await getThumbnailForUploadedFile(file);
      if (thumbnail == null) {
        _logger.warning("No thumbnail available for $file, skipping");
        return;
      }
      final filePath = thumbnail.path;
      _logger.info("Running clip over $file");
      final stopwatch = Stopwatch()..start();
      final result = await _computer.compute(
        createImageEmbedding,
        param: {
          "imagePath": filePath,
        },
        taskName: "createImageEmbedding",
      ) as List<double>;
      _logger.info(
        "createImageEmbedding took: ${stopwatch.elapsedMilliseconds}ms",
      );
      if (result.length != kEmbeddingLength) {
        _logger.severe("Discovered incorrect embedding for $file - $result");
        return;
      }
      final embedding = Embedding(
        fileID: file.uploadedFileID!,
        model: kModelName,
        embedding: result,
      );
      await EmbeddingStore.instance.storeEmbedding(
        file,
        embedding,
      );
    } catch (e, s) {
      _logger.severe("Could not compute embedding for $file", e, s);
    }
  }

  /// Computes the CLIP text embedding for [query] on a [Computer] worker.
  Future<List<double>> _getTextEmbedding(String query) async {
    _logger.info("Searching for $query");
    final stopwatch = Stopwatch()..start();
    final embedding = await _computer.compute(
      createTextEmbedding,
      param: {
        "text": query,
      },
      taskName: "createTextEmbedding",
    ) as List<double>;
    _logger.info(
      "createTextEmbedding took: ${stopwatch.elapsedMilliseconds}ms",
    );
    return embedding;
  }

  /// Scores every cached image embedding against [textEmbedding] on a
  /// [Computer] worker and returns the results above [kScoreThreshold].
  Future<List<QueryResult>> _getScores(List<double> textEmbedding) async {
    final stopwatch = Stopwatch()..start();
    final List<QueryResult> queryResults = await _computer.compute(
      computeBulkScore,
      param: {
        "imageEmbeddings": _cachedEmbeddings,
        "textEmbedding": textEmbedding,
      },
      taskName: "computeBulkScore",
    );
    _logger.info(
      "computingScores took: ${stopwatch.elapsedMilliseconds}ms",
    );
    return queryResults;
  }
}
2023-09-22 12:08:18 +00:00
/// Isolate entry point that embeds the image found at `args["imagePath"]`
/// using CLIP.
List<double> createImageEmbedding(Map args) =>
    CLIP.createImageEmbedding(args["imagePath"]);
/// Isolate entry point that embeds the string found at `args["text"]` using
/// CLIP.
List<double> createTextEmbedding(Map args) =>
    CLIP.createTextEmbedding(args["text"]);
/// Isolate entry point that scores image embeddings against a text embedding.
///
/// Reads `args["imageEmbeddings"]` as a `List<Embedding>` and
/// `args["textEmbedding"]` as a `List<double>`. Keeps only results whose score
/// is at least [SemanticSearchService.kScoreThreshold] and returns them sorted
/// from highest to lowest score.
List<QueryResult> computeBulkScore(Map args) {
  final candidates = args["imageEmbeddings"] as List<Embedding>;
  final query = args["textEmbedding"] as List<double>;
  const threshold = SemanticSearchService.kScoreThreshold;

  final matches = <QueryResult>[];
  for (final candidate in candidates) {
    final similarity = CLIP.computeScore(candidate.embedding, query);
    // Positive `>=` test on purpose: a NaN score must be rejected.
    if (similarity >= threshold) {
      matches.add(QueryResult(candidate.fileID, similarity));
    }
  }
  return matches..sort((a, b) => b.score.compareTo(a.score));
}
2023-09-22 18:16:03 +00:00
/// A file ID paired with its similarity score, as produced by
/// [computeBulkScore].
class QueryResult {
  QueryResult(this.id, this.score);

  /// ID of the file the score belongs to.
  final int id;

  /// Similarity score for [id]; higher is a better match.
  final double score;
}
/// A search query waiting to run, plus the completer its caller awaits.
class PendingQuery {
  PendingQuery(this.query, this.completer);

  /// The text the caller searched for.
  final String query;

  /// Completed with the search results once this query is handled.
  final Completer<List<EnteFile>> completer;
}
2023-10-13 14:53:59 +00:00
/// Counts describing semantic-search indexing progress.
class IndexStatus {
  IndexStatus(this.indexedItems, this.pendingItems);

  /// Number of items that already have an embedding.
  final int indexedItems;

  /// Number of items still waiting for an embedding.
  final int pendingItems;
}