ente/lib/services/semantic_search/embedding_store.dart
2023-11-14 22:48:03 +05:30

198 lines
6 KiB
Dart

import "dart:async";
import "dart:convert";
import "dart:typed_data";
import "package:computer/computer.dart";
import "package:logging/logging.dart";
import "package:photos/core/network/network.dart";
import "package:photos/db/files_db.dart";
import "package:photos/db/object_box.dart";
import "package:photos/models/embedding.dart";
import "package:photos/models/file/file.dart";
import "package:photos/objectbox.g.dart";
import "package:photos/services/semantic_search/remote_embedding.dart";
import "package:photos/utils/crypto_util.dart";
import "package:photos/utils/file_download_util.dart";
import "package:shared_preferences/shared_preferences.dart";
class EmbeddingStore {
EmbeddingStore._privateConstructor();
static final EmbeddingStore instance = EmbeddingStore._privateConstructor();
static const kEmbeddingsSyncTimeKey = "sync_time_embeddings";
final _logger = Logger("EmbeddingStore");
final _dio = NetworkClient.instance.enteDio;
final _computer = Computer.shared();
late SharedPreferences _preferences;
Completer<void>? _syncStatus;
Future<void> init(SharedPreferences preferences) async {
_preferences = preferences;
}
Future<void> pullEmbeddings() async {
if (_syncStatus != null) {
return _syncStatus!.future;
}
_syncStatus = Completer();
var remoteEmbeddings = await _getRemoteEmbeddings();
await _storeRemoteEmbeddings(remoteEmbeddings.embeddings);
while (remoteEmbeddings.hasMore) {
remoteEmbeddings = await _getRemoteEmbeddings();
await _storeRemoteEmbeddings(remoteEmbeddings.embeddings);
}
_syncStatus!.complete();
_syncStatus = null;
}
Future<void> pushEmbeddings() async {
final query = (ObjectBox.instance
.getEmbeddingBox()
.query(Embedding_.updationTime.isNull()))
.build();
final pendingItems = query.find();
query.close();
for (final item in pendingItems) {
final file = await FilesDB.instance.getAnyUploadedFile(item.fileID);
await _pushEmbedding(file!, item);
}
}
Future<void> storeEmbedding(EnteFile file, Embedding embedding) async {
ObjectBox.instance.getEmbeddingBox().put(embedding);
_pushEmbedding(file, embedding);
}
Future<void> _pushEmbedding(EnteFile file, Embedding embedding) async {
final encryptionKey = getFileKey(file);
final embeddingJSON = jsonEncode(embedding.embedding);
final encryptedEmbedding = await CryptoUtil.encryptChaCha(
utf8.encode(embeddingJSON) as Uint8List,
encryptionKey,
);
final encryptedData =
CryptoUtil.bin2base64(encryptedEmbedding.encryptedData!);
final header = CryptoUtil.bin2base64(encryptedEmbedding.header!);
try {
final response = await _dio.put(
"/embeddings",
data: {
"fileID": embedding.fileID,
"model": embedding.model,
"encryptedEmbedding": encryptedData,
"decryptionHeader": header,
},
);
final updationTime = response.data["updationTime"];
embedding.updationTime = updationTime;
ObjectBox.instance.getEmbeddingBox().put(embedding);
} catch (e, s) {
_logger.severe(e, s);
}
}
Future<RemoteEmbeddings> _getRemoteEmbeddings({
int limit = 500,
}) async {
final remoteEmbeddings = <RemoteEmbedding>[];
try {
final sinceTime = _preferences.getInt(kEmbeddingsSyncTimeKey) ?? 0;
_logger.info("Fetching embeddings since $sinceTime");
final response = await _dio.get(
"/embeddings/diff",
queryParameters: {
"sinceTime": sinceTime,
"limit": limit,
},
);
final diff = response.data["diff"] as List;
for (var entry in diff) {
final embedding = RemoteEmbedding.fromMap(entry);
remoteEmbeddings.add(embedding);
}
} catch (e, s) {
_logger.severe(e, s);
}
_logger.info("${remoteEmbeddings.length} embeddings fetched");
return RemoteEmbeddings(
remoteEmbeddings,
remoteEmbeddings.length == limit,
);
}
Future<void> _storeRemoteEmbeddings(
List<RemoteEmbedding> remoteEmbeddings,
) async {
if (remoteEmbeddings.isEmpty) {
return;
}
final inputs = <EmbeddingsDecoderInput>[];
for (final embedding in remoteEmbeddings) {
final file = await FilesDB.instance.getAnyUploadedFile(embedding.fileID);
if (file == null) {
continue;
}
final fileKey = getFileKey(file);
final input = EmbeddingsDecoderInput(embedding, fileKey);
inputs.add(input);
}
final embeddings = await _computer.compute(
decodeEmbeddings,
param: {
"inputs": inputs,
},
);
_logger.info("${embeddings.length} embeddings decoded");
await ObjectBox.instance.getEmbeddingBox().putManyAsync(embeddings);
await _preferences.setInt(
kEmbeddingsSyncTimeKey,
embeddings.last.updationTime!,
);
_logger.info("${embeddings.length} embeddings stored");
}
}
Future<List<Embedding>> decodeEmbeddings(Map<String, dynamic> args) async {
final embeddings = <Embedding>[];
final inputs = args["inputs"] as List<EmbeddingsDecoderInput>;
for (final input in inputs) {
final decryptArgs = <String, dynamic>{};
decryptArgs["source"] =
CryptoUtil.base642bin(input.embedding.encryptedEmbedding);
decryptArgs["key"] = input.decryptionKey;
decryptArgs["header"] =
CryptoUtil.base642bin(input.embedding.decryptionHeader);
final embeddingData = chachaDecryptData(decryptArgs);
final List<double> decodedEmbedding = jsonDecode(utf8.decode(embeddingData))
.map((item) => item.toDouble())
.cast<double>()
.toList();
embeddings.add(
Embedding(
fileID: input.embedding.fileID,
model: input.embedding.model,
embedding: decodedEmbedding,
updationTime: input.embedding.updatedAt,
),
);
}
return embeddings;
}
class EmbeddingsDecoderInput {
final RemoteEmbedding embedding;
final Uint8List decryptionKey;
EmbeddingsDecoderInput(this.embedding, this.decryptionKey);
}