ente/lib/services/semantic_search/semantic_search_service.dart

314 lines
9.1 KiB
Dart
Raw Normal View History

2023-09-22 18:16:03 +00:00
import "dart:async";
2023-10-03 19:07:42 +00:00
import "dart:collection";
2023-09-22 07:26:51 +00:00
import "dart:io";
import "package:clip_ggml/clip_ggml.dart";
2023-09-22 12:08:18 +00:00
import "package:computer/computer.dart";
2023-09-22 07:26:51 +00:00
import "package:flutter/services.dart";
import "package:logging/logging.dart";
import "package:path_provider/path_provider.dart";
2023-10-03 17:08:18 +00:00
import "package:photos/core/event_bus.dart";
2023-09-22 15:47:33 +00:00
import "package:photos/db/files_db.dart";
2023-10-13 14:53:59 +00:00
import "package:photos/events/file_indexed_event.dart";
import "package:photos/events/file_uploaded_event.dart";
2023-10-03 17:08:18 +00:00
import "package:photos/events/sync_status_update_event.dart";
2023-09-22 15:47:33 +00:00
import "package:photos/models/embedding.dart";
import "package:photos/models/file/file.dart";
2023-10-03 17:08:18 +00:00
import "package:photos/services/semantic_search/embedding_store.dart";
2023-10-13 14:53:59 +00:00
import "package:photos/utils/local_settings.dart";
2023-09-22 15:47:33 +00:00
import "package:photos/utils/thumbnail_util.dart";
2023-10-03 17:08:18 +00:00
import "package:shared_preferences/shared_preferences.dart";
2023-09-22 07:26:51 +00:00
class SemanticSearchService {
SemanticSearchService._privateConstructor();
static final SemanticSearchService instance =
SemanticSearchService._privateConstructor();
2023-09-22 12:08:18 +00:00
static final Computer _computer = Computer.shared();
2023-09-22 07:26:51 +00:00
2023-10-02 17:36:17 +00:00
static const int batchSize = 1;
2023-10-03 19:02:37 +00:00
static const kModelPath =
"assets/models/clip/openai_clip-vit-base-patch32.ggmlv0.f16.bin";
2023-09-23 18:28:19 +00:00
2023-09-22 07:26:51 +00:00
final _logger = Logger("SemanticSearchService");
2023-10-03 19:07:42 +00:00
final _queue = Queue<EnteFile>();
bool hasLoaded = false;
bool isComputingEmbeddings = false;
2023-09-22 18:16:03 +00:00
Future<List<EnteFile>>? _ongoingRequest;
PendingQuery? _nextQuery;
2023-09-22 07:26:51 +00:00
2023-10-03 17:08:18 +00:00
/// Prepares the service for use.
///
/// Loads the CLIP model, initialises the [EmbeddingStore] with [preferences],
/// starts the back-fill, pushes embeddings through the store, and finally
/// subscribes to sync-status and file-upload events on the event bus.
Future<void> init(SharedPreferences preferences) async {
  await _loadModel();

  final store = EmbeddingStore.instance;
  await store.init(preferences);

  // Not awaited: the back-fill keeps running while init continues.
  startBackFill();
  await store.pushEmbeddings();

  final bus = Bus.instance;
  bus.on<SyncStatusUpdate>().listen((update) async {
    // Only a completed diff sync triggers a pull of embeddings.
    if (update.status != SyncStatus.diffSynced) {
      return;
    }
    await store.pullEmbeddings();
  });
  bus.on<FileUploadedEvent>().listen((uploaded) async {
    addToQueue(uploaded.file);
  });
}
2023-09-22 18:16:03 +00:00
/// Returns the files matching [query], running at most one search at a time.
///
/// If no search is in flight, [query] is executed immediately via
/// [getMatchingFiles]. Otherwise it is parked as the single pending query and
/// executed as soon as the in-flight search finishes.
///
/// Parking a new query supersedes any previously parked one. The superseded
/// caller's future completes with an empty list so that it does not wait
/// forever. Errors from a search are delivered to the caller of that search.
Future<List<EnteFile>> search(String query) async {
  if (_ongoingRequest != null) {
    // Release whoever was waiting on the query we are about to replace.
    final superseded = _nextQuery;
    if (superseded != null && !superseded.completer.isCompleted) {
      superseded.completer.complete(<EnteFile>[]);
    }
    final pending = PendingQuery(query, Completer<List<EnteFile>>());
    _nextQuery = pending;
    return pending.completer.future;
  }

  final request = getMatchingFiles(query).whenComplete(() {
    // Runs on success and on failure, so a failed search cannot leave the
    // service permanently marked as busy.
    _ongoingRequest = null;
    final next = _nextQuery;
    if (next == null) {
      return;
    }
    _nextQuery = null;
    // Forward both the result and any error to the parked caller.
    search(next.query).then(
      (nextResult) => next.completer.complete(nextResult),
      onError: (Object error, StackTrace stackTrace) =>
          next.completer.completeError(error, stackTrace),
    );
  });
  _ongoingRequest = request;
  return request;
}
2023-09-22 12:08:18 +00:00
2023-09-22 18:16:03 +00:00
/// Finds the files whose stored image embedding best matches [query].
///
/// Creates a text embedding for [query] through the shared [Computer],
/// scores it against every embedding returned by
/// `FilesDB.getAllEmbeddings`, orders the scores from highest to lowest,
/// drops results scoring below 0.25, and resolves the remaining IDs to files.
/// The duration of each stage is logged.
Future<List<EnteFile>> getMatchingFiles(String query) async {
  // Formats the wall-clock time elapsed since [since] for the timing logs.
  String elapsedSince(DateTime since) {
    final millis = DateTime.now().millisecondsSinceEpoch -
        since.millisecondsSinceEpoch;
    return "${millis}ms";
  }

  _logger.info("Searching for $query");

  var stageStart = DateTime.now();
  final textEmbedding = await _computer.compute(
    createTextEmbedding,
    param: {"text": query},
    taskName: "createTextEmbedding",
  );
  _logger.info("createTextEmbedding took: ${elapsedSince(stageStart)}");

  stageStart = DateTime.now();
  final storedEmbeddings = await FilesDB.instance.getAllEmbeddings();
  _logger.info("Fetching embeddings took: ${elapsedSince(stageStart)}");

  stageStart = DateTime.now();
  final ranked = <QueryResult>[
    for (final stored in storedEmbeddings)
      QueryResult(
        stored.fileID,
        computeScore({
          "imageEmbedding": stored.embedding,
          "textEmbedding": textEmbedding,
        }),
      ),
  ];
  // Highest score first, then discard everything under the cut-off.
  ranked.sort((a, b) => b.score.compareTo(a.score));
  ranked.removeWhere((candidate) => candidate.score < 0.25);
  _logger.info("computingScores took: ${elapsedSince(stageStart)}");

  stageStart = DateTime.now();
  final candidateIDs = ranked.map((candidate) => candidate.id).toList();
  final filesByID =
      await FilesDB.instance.getFilesFromGeneratedIDs(candidateIDs);
  // Walk the ranked list so the output keeps the score ordering; IDs the
  // database did not return are left out.
  final results = <EnteFile>[
    for (final candidate in ranked)
      if (filesByID.containsKey(candidate.id)) filesByID[candidate.id]!,
  ];
  _logger.info("Fetching files took: ${elapsedSince(stageStart)}");

  _logger.info("${results.length} results");
  return results;
}
2023-10-13 12:25:27 +00:00
/// Appends [file] to the embedding queue and triggers queue processing.
///
/// Does nothing when `LocalSettings.hasEnabledMagicSearch` reports false.
void addToQueue(EnteFile file) {
  final isMagicSearchEnabled = LocalSettings.instance.hasEnabledMagicSearch();
  if (isMagicSearchEnabled) {
    _logger.info("Adding $file to the queue");
    _queue.add(file);
    // Not awaited: processing continues asynchronously.
    _pollQueue();
  }
}
2023-10-13 14:53:59 +00:00
/// Reports indexing progress.
///
/// The indexed count is the number of embeddings returned by
/// `FilesDB.getAllEmbeddings`; the pending count is the current length of
/// the in-memory queue.
Future<IndexStatus> getIndexStatus() async {
  final indexedCount = (await FilesDB.instance.getAllEmbeddings()).length;
  final pendingCount = _queue.length;
  return IndexStatus(indexedCount, pendingCount);
}
2023-09-22 07:26:51 +00:00
/// Copies the bundled model asset to a readable path, loads it into CLIP,
/// logs how long loading took, and marks the service as loaded.
Future<void> _loadModel() async {
  final modelPath = await _getAccessiblePathForAsset(kModelPath, "model.bin");

  final loadStartedAt = DateTime.now().millisecondsSinceEpoch;
  CLIP.loadModel(modelPath);
  final loadMillis = DateTime.now().millisecondsSinceEpoch - loadStartedAt;

  _logger.info("Loading model took: ${loadMillis}ms");
  hasLoaded = true;
}
/// Copies the bundled asset at [assetPath] into a file named [tempName] and
/// returns that file's path.
///
/// The asset is loaded through [rootBundle] and written out with
/// [_writeToFile].
Future<String> _getAccessiblePathForAsset(
  String assetPath,
  String tempName,
) async {
  final byteData = await rootBundle.load(assetPath);
  // A ByteData may be a view onto a larger buffer, so restrict the byte list
  // to the view's own window instead of exposing the whole backing buffer.
  final bytes = byteData.buffer.asUint8List(
    byteData.offsetInBytes,
    byteData.lengthInBytes,
  );
  return _writeToFile(bytes, tempName);
}
/// Writes [bytes] to a file called [fileName] inside the directory returned
/// by `getTemporaryDirectory` and returns the path of the written file.
Future<String> _writeToFile(Uint8List bytes, String fileName) async {
  final directory = await getTemporaryDirectory();
  final target = File('${directory.path}/$fileName');
  final written = await target.writeAsBytes(bytes);
  return written.path;
}
2023-09-22 18:16:03 +00:00
2023-10-13 14:53:59 +00:00
/// Queues every file reported by `FilesDB.getFilesWithoutEmbeddings` and
/// triggers queue processing.
///
/// Does nothing when `LocalSettings.hasEnabledMagicSearch` reports false.
Future<void> startBackFill() async {
  final isMagicSearchEnabled = LocalSettings.instance.hasEnabledMagicSearch();
  if (isMagicSearchEnabled) {
    final pending = await FilesDB.instance.getFilesWithoutEmbeddings();
    _logger.info("${pending.length} pending to be embedded");
    _queue.addAll(pending);
    // Not awaited: the queue is drained asynchronously.
    _pollQueue();
  }
}
/// Drains the queue, computing embeddings in batches of up to [batchSize].
///
/// Only one drain runs at a time; calls made while a drain is active return
/// immediately. The drain keeps going until the queue is empty, so files
/// enqueued while a batch is being processed are picked up by the same run.
/// A batch that fails is logged and skipped so that it neither aborts the
/// remaining work nor leaves [isComputingEmbeddings] stuck at true.
Future<void> _pollQueue() async {
  if (isComputingEmbeddings) {
    return;
  }
  isComputingEmbeddings = true;
  try {
    while (_queue.isNotEmpty) {
      final batch = <EnteFile>[];
      while (_queue.isNotEmpty && batch.length < batchSize) {
        batch.add(_queue.removeFirst());
      }
      try {
        await _computeImageEmbeddings(batch);
      } catch (e, s) {
        // Best effort: keep indexing the rest of the queue.
        _logger.severe("Could not compute embeddings for a batch", e, s);
      }
    }
  } finally {
    isComputingEmbeddings = false;
  }
}
2023-09-23 18:28:19 +00:00
/// Computes CLIP embeddings for [files] and stores them through
/// [EmbeddingStore], then fires a [FileIndexedEvent].
///
/// Does nothing if the model has not been loaded or [files] is empty. A file
/// without an uploaded ID or without an available thumbnail is skipped with a
/// warning instead of failing the whole batch.
Future<void> _computeImageEmbeddings(List<EnteFile> files) async {
  if (!hasLoaded || files.isEmpty) {
    return;
  }

  // These three lists are kept index-aligned so that embedding i can be
  // attributed to the right file after unusable files have been filtered out.
  final indexableFiles = <EnteFile>[];
  final uploadedIDs = <int>[];
  final filePaths = <String>[];
  for (final file in files) {
    final uploadedID = file.uploadedFileID;
    if (uploadedID == null) {
      _logger.warning("Skipping $file: no uploaded file ID");
      continue;
    }
    final thumbnail = await getThumbnailFile(file);
    if (thumbnail == null) {
      _logger.warning("Skipping $file: thumbnail unavailable");
      continue;
    }
    indexableFiles.add(file);
    uploadedIDs.add(uploadedID);
    filePaths.add(thumbnail.path);
  }
  if (filePaths.isEmpty) {
    return;
  }

  _logger.info("Running clip over ${filePaths.length} items");
  final startTime = DateTime.now();
  final imageEmbeddings = <List<double>>[];
  if (filePaths.length == 1) {
    // Single item: use the non-batch entry point.
    final result = await _computer.compute(
      createImageEmbedding,
      param: {
        "imagePath": filePaths.first,
      },
      taskName: "createImageEmbedding",
    ) as List<double>;
    imageEmbeddings.add(result);
  } else {
    final result = await _computer.compute(
      createImageEmbeddings,
      param: {
        "imagePaths": filePaths,
      },
      taskName: "createImageEmbeddings",
    ) as List<List<double>>;
    imageEmbeddings.addAll(result);
  }
  final elapsedMillis = DateTime.now().millisecondsSinceEpoch -
      startTime.millisecondsSinceEpoch;
  _logger.info(
    "createImageEmbeddings took: ${elapsedMillis}ms for "
    "${imageEmbeddings.length} items",
  );

  for (int i = 0; i < imageEmbeddings.length; i++) {
    await EmbeddingStore.instance.storeEmbedding(
      indexableFiles[i],
      Embedding(
        uploadedIDs[i],
        "c_uq",
        imageEmbeddings[i],
      ),
    );
  }
  Bus.instance.fire(FileIndexedEvent());
}
2023-09-22 07:26:51 +00:00
}
2023-09-22 12:08:18 +00:00
2023-09-23 18:28:19 +00:00
/// Creates embeddings for all image paths in `args["imagePaths"]` with a
/// single batched CLIP call.
///
/// Takes a [Map] so it can be handed to `Computer.compute` as a task.
List<List<double>> createImageEmbeddings(Map args) =>
    CLIP.createBatchImageEmbedding(args["imagePaths"]);
2023-09-22 12:08:18 +00:00
/// Creates the CLIP embedding for the image at `args["imagePath"]`.
///
/// Takes a [Map] so it can be handed to `Computer.compute` as a task.
List<double> createImageEmbedding(Map args) =>
    CLIP.createImageEmbedding(args["imagePath"]);
/// Creates the CLIP embedding for the text in `args["text"]`.
///
/// Takes a [Map] so it can be handed to `Computer.compute` as a task.
List<double> createTextEmbedding(Map args) =>
    CLIP.createTextEmbedding(args["text"]);
/// Returns the CLIP score between `args["imageEmbedding"]` and
/// `args["textEmbedding"]`.
///
/// Both entries must be `List<double>`; the casts throw otherwise.
double computeScore(Map args) {
  final imageEmbedding = args["imageEmbedding"] as List<double>;
  final textEmbedding = args["textEmbedding"] as List<double>;
  return CLIP.computeScore(imageEmbedding, textEmbedding);
}
2023-09-22 18:16:03 +00:00
/// A scored search candidate: an ID paired with its match score.
class QueryResult {
  QueryResult(this.id, this.score);

  /// Identifier of the candidate.
  final int id;

  /// Match score of the candidate.
  final double score;
}
/// A search that is waiting to run, together with the [Completer] used to
/// deliver its result to the caller.
class PendingQuery {
  PendingQuery(this.query, this.completer);

  /// The search text to run.
  final String query;

  /// Completes with the files found for [query].
  final Completer<List<EnteFile>> completer;
}
2023-10-13 14:53:59 +00:00
/// A snapshot of indexing progress.
class IndexStatus {
  IndexStatus(this.indexedItems, this.pendingItems);

  /// Number of items that have already been indexed.
  final int indexedItems;

  /// Number of items still waiting to be indexed.
  final int pendingItems;
}