diff --git a/mobile/lib/db/embeddings_sqlite_db.dart b/mobile/lib/db/embeddings_sqlite_db.dart new file mode 100644 index 000000000..cdbe64fb7 --- /dev/null +++ b/mobile/lib/db/embeddings_sqlite_db.dart @@ -0,0 +1,129 @@ +import "dart:io"; +import "dart:typed_data"; + +import "package:path/path.dart"; +import 'package:path_provider/path_provider.dart'; +import "package:photos/core/event_bus.dart"; +import "package:photos/events/embedding_updated_event.dart"; +import "package:photos/models/embedding.dart"; +import "package:sqlite_async/sqlite_async.dart"; + +class EmbeddingsDB { + EmbeddingsDB._privateConstructor(); + + static final EmbeddingsDB instance = EmbeddingsDB._privateConstructor(); + + static const databaseName = "ente.embeddings.db"; + static const tableName = "embeddings"; + static const columnFileID = "file_id"; + static const columnModel = "model"; + static const columnEmbedding = "embedding"; + static const columnUpdationTime = "updation_time"; + + static Future? _dbFuture; + + Future get _database async { + _dbFuture ??= _initDatabase(); + return _dbFuture!; + } + + Future _initDatabase() async { + final Directory documentsDirectory = + await getApplicationDocumentsDirectory(); + final String path = join(documentsDirectory.path, databaseName); + final migrations = SqliteMigrations() + ..add( + SqliteMigration( + 1, + (tx) async { + await tx.execute( + 'CREATE TABLE $tableName IF NOT EXISTS ($columnFileID INTEGER NOT NULL, $columnModel TEXT NOT NULL, $columnEmbedding BLOB NOT NULL, $columnUpdationTime INTEGER, UNIQUE ($columnFileID, $columnModel))', + ); + }, + ), + ); + final database = SqliteDatabase(path: path); + await migrations.migrate(database); + return database; + } + + Future clearTable() async { + final db = await _database; + await db.execute('DELETE * FROM $tableName'); + } + + Future> getAll(Model model) async { + final db = await _database; + final results = await db.getAll('SELECT * FROM $tableName'); + return _convertToEmbeddings(results); + } + + Future put(Embedding embedding) async { + final db = await _database; + await db.execute( + 'INSERT OR REPLACE INTO $tableName ($columnFileID, $columnModel, $columnEmbedding) VALUES (?, ?, ?, ?)', + _getRowFromEmbedding(embedding), + ); + Bus.instance.fire(EmbeddingUpdatedEvent()); + } + + Future putMany(List embeddings) async { + final db = await _database; + final inputs = embeddings.map((e) => _getRowFromEmbedding(e)).toList(); + await db.executeBatch( + 'INSERT OR REPLACE INTO $tableName ($columnFileID, $columnModel, $columnEmbedding) values(?, ?, ?, ?)', + inputs, + ); + Bus.instance.fire(EmbeddingUpdatedEvent()); + } + + Future> getUnsyncedEmbeddings() async { + final db = await _database; + final results = await db.getAll( + 'SELECT * FROM $tableName WHERE $columnUpdationTime IS NULL', + ); + return _convertToEmbeddings(results); + } + + Future deleteEmbeddings(List fileIDs) async { + final db = await _database; + await db.execute( + 'DELETE FROM $tableName WHERE $columnFileID IN (${fileIDs.join(", ")})', + ); + Bus.instance.fire(EmbeddingUpdatedEvent()); + } + + Future deleteAllForModel(Model model) async { + final db = await _database; + await db.execute( + 'DELETE FROM $tableName WHERE $columnModel = ?', + [serialize(model)], + ); + Bus.instance.fire(EmbeddingUpdatedEvent()); + } + + List _convertToEmbeddings(List> results) { + final List embeddings = []; + for (final result in results) { + embeddings.add(_getEmbeddingFromRow(result)); + } + return embeddings; + } + + Embedding _getEmbeddingFromRow(Map row) { + final fileID = row[columnFileID]; + final model = deserialize(row[columnModel]); + final bytes = row[columnEmbedding] as Uint8List; + final list = Float32List.view(bytes.buffer); + return Embedding(fileID: fileID, model: model, embedding: list); + } + + List _getRowFromEmbedding(Embedding embedding) { + return [ + embedding.fileID, + serialize(embedding.model), + Float32List.fromList(embedding.embedding).buffer.asUint8List(), + embedding.updationTime, + ]; + } +}