Implement search
This commit is contained in:
parent
dbcb5691cd
commit
54881e0309
|
@ -1,3 +1,4 @@
|
|||
import "dart:async";
|
||||
import "dart:io";
|
||||
|
||||
import "package:clip_ggml/clip_ggml.dart";
|
||||
|
@ -18,97 +19,106 @@ class SemanticSearchService {
|
|||
static final Computer _computer = Computer.shared();
|
||||
|
||||
bool hasLoaded = false;
|
||||
bool isRunning = false;
|
||||
final _logger = Logger("SemanticSearchService");
|
||||
Future<List<EnteFile>>? _ongoingRequest;
|
||||
PendingQuery? _nextQuery;
|
||||
|
||||
Future<void> init() async {
|
||||
await _loadModel();
|
||||
final files = await FilesDB.instance.getFilesWithoutEmbeddings();
|
||||
_logger.info(files.length.toString() + " pending to be embedded");
|
||||
for (final file in files) {
|
||||
await runInference(file, "hello");
|
||||
_computeMissingEmbeddings();
|
||||
}
|
||||
|
||||
Future<List<EnteFile>> search(String query) async {
|
||||
if (_ongoingRequest == null) {
|
||||
_ongoingRequest = getMatchingFiles(query).then((result) {
|
||||
_ongoingRequest = null;
|
||||
if (_nextQuery != null) {
|
||||
final next = _nextQuery;
|
||||
_nextQuery = null;
|
||||
search(next!.query).then((nextResult) {
|
||||
next.completer.complete(nextResult);
|
||||
});
|
||||
}
|
||||
|
||||
return result;
|
||||
});
|
||||
return _ongoingRequest!;
|
||||
} else {
|
||||
// If there's an ongoing request, create or replace the nextCompleter.
|
||||
_nextQuery?.completer.future
|
||||
.timeout(const Duration(seconds: 0)); // Cancels the previous future.
|
||||
_nextQuery = PendingQuery(query, Completer<List<EnteFile>>());
|
||||
return _nextQuery!.completer.future;
|
||||
}
|
||||
}
|
||||
|
||||
Future<void> runInference(EnteFile file, String text) async {
|
||||
if (!hasLoaded || isRunning) {
|
||||
return;
|
||||
}
|
||||
isRunning = true;
|
||||
// _logger.info("Running clip");
|
||||
final imagePath = (await getThumbnailFile(file))!.path;
|
||||
|
||||
Future<List<EnteFile>> getMatchingFiles(String query) async {
|
||||
_logger.info("Searching for " + query);
|
||||
var startTime = DateTime.now();
|
||||
// ignore: prefer_typing_uninitialized_variables
|
||||
var imageEmbedding;
|
||||
final embeddings = await FilesDB.instance.getAllEmbeddings();
|
||||
bool hasCachedEmbedding = false;
|
||||
for (final embedding in embeddings) {
|
||||
if (embedding.id == file.generatedID) {
|
||||
imageEmbedding = embedding.embedding;
|
||||
hasCachedEmbedding = true;
|
||||
_logger.info("Found cached embedding");
|
||||
}
|
||||
}
|
||||
if (!hasCachedEmbedding) {
|
||||
imageEmbedding ??= await _computer.compute(
|
||||
createImageEmbedding,
|
||||
param: {
|
||||
"imagePath": imagePath,
|
||||
},
|
||||
taskName: "createImageEmbedding",
|
||||
);
|
||||
await FilesDB.instance
|
||||
.insertEmbedding(Embedding(file.generatedID!, imageEmbedding, -1));
|
||||
}
|
||||
final textEmbedding = await _computer.compute(
|
||||
createTextEmbedding,
|
||||
param: {
|
||||
"text": query,
|
||||
},
|
||||
taskName: "createTextEmbedding",
|
||||
);
|
||||
var endTime = DateTime.now();
|
||||
_logger.info(
|
||||
"createImageEmbedding took: " +
|
||||
"createTextEmbedding took: " +
|
||||
(endTime.millisecondsSinceEpoch - startTime.millisecondsSinceEpoch)
|
||||
.toString() +
|
||||
"ms",
|
||||
);
|
||||
|
||||
// startTime = DateTime.now();
|
||||
// final textEmbedding = await _computer.compute(
|
||||
// createTextEmbedding,
|
||||
// param: {
|
||||
// "text": text,
|
||||
// },
|
||||
// taskName: "createTextEmbedding",
|
||||
// );
|
||||
// endTime = DateTime.now();
|
||||
// _logger.info(
|
||||
// "createTextEmbedding took: " +
|
||||
// (endTime.millisecondsSinceEpoch - startTime.millisecondsSinceEpoch)
|
||||
// .toString() +
|
||||
// "ms",
|
||||
// );
|
||||
startTime = DateTime.now();
|
||||
final embeddings = await FilesDB.instance.getAllEmbeddings();
|
||||
endTime = DateTime.now();
|
||||
_logger.info(
|
||||
"Fetching embeddings took: " +
|
||||
(endTime.millisecondsSinceEpoch - startTime.millisecondsSinceEpoch)
|
||||
.toString() +
|
||||
"ms",
|
||||
);
|
||||
|
||||
// startTime = DateTime.now();
|
||||
// final score = computeScore({
|
||||
// "imageEmbedding": imageEmbedding,
|
||||
// "textEmbedding": textEmbedding,
|
||||
// });
|
||||
// endTime = DateTime.now();
|
||||
// _logger.info(
|
||||
// "computeScore took: " +
|
||||
// (endTime.millisecondsSinceEpoch - startTime.millisecondsSinceEpoch)
|
||||
// .toString() +
|
||||
// "ms",
|
||||
startTime = DateTime.now();
|
||||
final queryResults = <QueryResult>[];
|
||||
for (final embedding in embeddings) {
|
||||
final score = computeScore({
|
||||
"imageEmbedding": embedding.embedding,
|
||||
"textEmbedding": textEmbedding,
|
||||
});
|
||||
queryResults.add(QueryResult(embedding.id, score));
|
||||
}
|
||||
queryResults.sort((first, second) => second.score.compareTo(first.score));
|
||||
queryResults.removeWhere((element) => element.score < 0.27);
|
||||
endTime = DateTime.now();
|
||||
_logger.info(
|
||||
"computingScores took: " +
|
||||
(endTime.millisecondsSinceEpoch - startTime.millisecondsSinceEpoch)
|
||||
.toString() +
|
||||
"ms",
|
||||
);
|
||||
|
||||
// final score = await _computer.compute(
|
||||
// computeScore,
|
||||
// param: {
|
||||
// "imageEmbedding": imageEmbedding,
|
||||
// "textEmbedding": textEmbedding,
|
||||
// },
|
||||
// taskName: "computeScore",
|
||||
// );
|
||||
// );
|
||||
startTime = DateTime.now();
|
||||
final filesMap = await FilesDB.instance
|
||||
.getFilesFromGeneratedIDs(queryResults.map((e) => e.id).toList());
|
||||
final results = <EnteFile>[];
|
||||
for (final result in queryResults) {
|
||||
if (filesMap.containsKey(result.id)) {
|
||||
results.add(filesMap[result.id]!);
|
||||
}
|
||||
}
|
||||
endTime = DateTime.now();
|
||||
_logger.info(
|
||||
"Fetching files took: " +
|
||||
(endTime.millisecondsSinceEpoch - startTime.millisecondsSinceEpoch)
|
||||
.toString() +
|
||||
"ms",
|
||||
);
|
||||
|
||||
// _logger.info("Score: " + score.toString());
|
||||
isRunning = false;
|
||||
_logger.info(results.length.toString() + " results");
|
||||
|
||||
return results;
|
||||
}
|
||||
|
||||
Future<void> _loadModel() async {
|
||||
|
@ -141,6 +151,53 @@ class SemanticSearchService {
|
|||
final file = await File('${tempDir.path}/$fileName').writeAsBytes(bytes);
|
||||
return file.path;
|
||||
}
|
||||
|
||||
Future<void> _computeMissingEmbeddings() async {
|
||||
final files = await FilesDB.instance.getFilesWithoutEmbeddings();
|
||||
_logger.info(files.length.toString() + " pending to be embedded");
|
||||
for (final file in files) {
|
||||
await _computeImageEmbedding(file);
|
||||
}
|
||||
}
|
||||
|
||||
Future<void> _computeImageEmbedding(EnteFile file) async {
|
||||
if (!hasLoaded) {
|
||||
return;
|
||||
}
|
||||
// _logger.info("Running clip");
|
||||
final imagePath = (await getThumbnailFile(file))!.path;
|
||||
|
||||
final startTime = DateTime.now();
|
||||
// ignore: prefer_typing_uninitialized_variables
|
||||
var imageEmbedding;
|
||||
final embeddings = await FilesDB.instance.getAllEmbeddings();
|
||||
bool hasCachedEmbedding = false;
|
||||
for (final embedding in embeddings) {
|
||||
if (embedding.id == file.generatedID) {
|
||||
imageEmbedding = embedding.embedding;
|
||||
hasCachedEmbedding = true;
|
||||
_logger.info("Found cached embedding");
|
||||
}
|
||||
}
|
||||
if (!hasCachedEmbedding) {
|
||||
imageEmbedding ??= await _computer.compute(
|
||||
createImageEmbedding,
|
||||
param: {
|
||||
"imagePath": imagePath,
|
||||
},
|
||||
taskName: "createImageEmbedding",
|
||||
);
|
||||
await FilesDB.instance
|
||||
.insertEmbedding(Embedding(file.generatedID!, imageEmbedding, -1));
|
||||
}
|
||||
final endTime = DateTime.now();
|
||||
_logger.info(
|
||||
"createImageEmbedding took: " +
|
||||
(endTime.millisecondsSinceEpoch - startTime.millisecondsSinceEpoch)
|
||||
.toString() +
|
||||
"ms",
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
List<double> createImageEmbedding(Map args) {
|
||||
|
@ -157,3 +214,17 @@ double computeScore(Map args) {
|
|||
args["textEmbedding"] as List<double>,
|
||||
);
|
||||
}
|
||||
|
||||
class QueryResult {
|
||||
final int id;
|
||||
final double score;
|
||||
|
||||
QueryResult(this.id, this.score);
|
||||
}
|
||||
|
||||
class PendingQuery {
|
||||
final String query;
|
||||
final Completer<List<EnteFile>> completer;
|
||||
|
||||
PendingQuery(this.query, this.completer);
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue