Compare commits
7 commits
main
...
rediscover
Author | SHA1 | Date | |
---|---|---|---|
c4f5ce8298 | |||
e8e26ee523 | |||
71f1494444 | |||
9134e75bf0 | |||
8573649cf7 | |||
6bd34014fb | |||
d132353291 |
|
@ -1371,10 +1371,9 @@ class FilesDB {
|
||||||
inParam += "'" + id.toString() + "',";
|
inParam += "'" + id.toString() + "',";
|
||||||
}
|
}
|
||||||
inParam = inParam.substring(0, inParam.length - 1);
|
inParam = inParam.substring(0, inParam.length - 1);
|
||||||
final db = await instance.database;
|
final db = await instance.sqliteAsyncDB;
|
||||||
final results = await db.query(
|
final results = await db.getAll(
|
||||||
filesTable,
|
'SELECT * FROM $filesTable WHERE $columnUploadedFileID IN ($inParam)',
|
||||||
where: '$columnUploadedFileID IN ($inParam)',
|
|
||||||
);
|
);
|
||||||
final files = convertToFiles(results);
|
final files = convertToFiles(results);
|
||||||
for (final file in files) {
|
for (final file in files) {
|
||||||
|
@ -1393,10 +1392,9 @@ class FilesDB {
|
||||||
inParam += "'" + id.toString() + "',";
|
inParam += "'" + id.toString() + "',";
|
||||||
}
|
}
|
||||||
inParam = inParam.substring(0, inParam.length - 1);
|
inParam = inParam.substring(0, inParam.length - 1);
|
||||||
final db = await instance.database;
|
final db = await instance.sqliteAsyncDB;
|
||||||
final results = await db.query(
|
final results = await db.getAll(
|
||||||
filesTable,
|
'SELECT * FROM $filesTable WHERE $columnGeneratedID IN ($inParam)',
|
||||||
where: '$columnGeneratedID IN ($inParam)',
|
|
||||||
);
|
);
|
||||||
final files = convertToFiles(results);
|
final files = convertToFiles(results);
|
||||||
for (final file in files) {
|
for (final file in files) {
|
||||||
|
|
|
@ -13,6 +13,7 @@ import "package:photos/models/search/generic_search_result.dart";
|
||||||
import "package:photos/models/search/search_result.dart";
|
import "package:photos/models/search/search_result.dart";
|
||||||
import "package:photos/models/typedefs.dart";
|
import "package:photos/models/typedefs.dart";
|
||||||
import "package:photos/services/collections_service.dart";
|
import "package:photos/services/collections_service.dart";
|
||||||
|
import "package:photos/services/magic_cache_service.dart";
|
||||||
import "package:photos/services/search_service.dart";
|
import "package:photos/services/search_service.dart";
|
||||||
import "package:photos/ui/viewer/gallery/collection_page.dart";
|
import "package:photos/ui/viewer/gallery/collection_page.dart";
|
||||||
import "package:photos/ui/viewer/location/add_location_sheet.dart";
|
import "package:photos/ui/viewer/location/add_location_sheet.dart";
|
||||||
|
@ -40,7 +41,7 @@ enum SectionType {
|
||||||
face,
|
face,
|
||||||
location,
|
location,
|
||||||
// Grouping based on ML or manual tagging
|
// Grouping based on ML or manual tagging
|
||||||
content,
|
magic,
|
||||||
// includes year, month , day, event ResultType
|
// includes year, month , day, event ResultType
|
||||||
moment,
|
moment,
|
||||||
album,
|
album,
|
||||||
|
@ -56,7 +57,7 @@ extension SectionTypeExtensions on SectionType {
|
||||||
switch (this) {
|
switch (this) {
|
||||||
case SectionType.face:
|
case SectionType.face:
|
||||||
return S.of(context).faces;
|
return S.of(context).faces;
|
||||||
case SectionType.content:
|
case SectionType.magic:
|
||||||
return S.of(context).contents;
|
return S.of(context).contents;
|
||||||
case SectionType.moment:
|
case SectionType.moment:
|
||||||
return S.of(context).moments;
|
return S.of(context).moments;
|
||||||
|
@ -77,7 +78,7 @@ extension SectionTypeExtensions on SectionType {
|
||||||
switch (this) {
|
switch (this) {
|
||||||
case SectionType.face:
|
case SectionType.face:
|
||||||
return S.of(context).searchFaceEmptySection;
|
return S.of(context).searchFaceEmptySection;
|
||||||
case SectionType.content:
|
case SectionType.magic:
|
||||||
return "Contents";
|
return "Contents";
|
||||||
case SectionType.moment:
|
case SectionType.moment:
|
||||||
return S.of(context).searchDatesEmptySection;
|
return S.of(context).searchDatesEmptySection;
|
||||||
|
@ -100,7 +101,7 @@ extension SectionTypeExtensions on SectionType {
|
||||||
switch (this) {
|
switch (this) {
|
||||||
case SectionType.face:
|
case SectionType.face:
|
||||||
return false;
|
return false;
|
||||||
case SectionType.content:
|
case SectionType.magic:
|
||||||
return false;
|
return false;
|
||||||
case SectionType.moment:
|
case SectionType.moment:
|
||||||
return false;
|
return false;
|
||||||
|
@ -121,7 +122,7 @@ extension SectionTypeExtensions on SectionType {
|
||||||
switch (this) {
|
switch (this) {
|
||||||
case SectionType.face:
|
case SectionType.face:
|
||||||
return true;
|
return true;
|
||||||
case SectionType.content:
|
case SectionType.magic:
|
||||||
return false;
|
return false;
|
||||||
case SectionType.moment:
|
case SectionType.moment:
|
||||||
return false;
|
return false;
|
||||||
|
@ -143,7 +144,7 @@ extension SectionTypeExtensions on SectionType {
|
||||||
case SectionType.face:
|
case SectionType.face:
|
||||||
// todo: later
|
// todo: later
|
||||||
return "Setup";
|
return "Setup";
|
||||||
case SectionType.content:
|
case SectionType.magic:
|
||||||
// todo: later
|
// todo: later
|
||||||
return "Add tags";
|
return "Add tags";
|
||||||
case SectionType.moment:
|
case SectionType.moment:
|
||||||
|
@ -165,7 +166,7 @@ extension SectionTypeExtensions on SectionType {
|
||||||
switch (this) {
|
switch (this) {
|
||||||
case SectionType.face:
|
case SectionType.face:
|
||||||
return Icons.adaptive.arrow_forward_outlined;
|
return Icons.adaptive.arrow_forward_outlined;
|
||||||
case SectionType.content:
|
case SectionType.magic:
|
||||||
return null;
|
return null;
|
||||||
case SectionType.moment:
|
case SectionType.moment:
|
||||||
return null;
|
return null;
|
||||||
|
@ -247,8 +248,8 @@ extension SectionTypeExtensions on SectionType {
|
||||||
case SectionType.face:
|
case SectionType.face:
|
||||||
return Future.value(List<GenericSearchResult>.empty());
|
return Future.value(List<GenericSearchResult>.empty());
|
||||||
|
|
||||||
case SectionType.content:
|
case SectionType.magic:
|
||||||
return Future.value(List<GenericSearchResult>.empty());
|
return MagicCacheService.instance.getMagicGenericSearchResult();
|
||||||
|
|
||||||
case SectionType.moment:
|
case SectionType.moment:
|
||||||
return SearchService.instance.getRandomMomentsSearchResults(context);
|
return SearchService.instance.getRandomMomentsSearchResults(context);
|
||||||
|
|
|
@ -267,6 +267,48 @@ class SemanticSearchService {
|
||||||
return results;
|
return results;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Future<List<int>> getMatchingFileIDs(String query, double minScore) async {
|
||||||
|
final textEmbedding = await _getTextEmbedding(query);
|
||||||
|
|
||||||
|
final queryResults =
|
||||||
|
await _getScores(textEmbedding, scoreThreshold: minScore);
|
||||||
|
|
||||||
|
final filesMap = await FilesDB.instance.getFilesFromIDs(
|
||||||
|
queryResults
|
||||||
|
.map(
|
||||||
|
(e) => e.id,
|
||||||
|
)
|
||||||
|
.toList(),
|
||||||
|
);
|
||||||
|
final results = <EnteFile>[];
|
||||||
|
|
||||||
|
final ignoredCollections =
|
||||||
|
CollectionsService.instance.getHiddenCollectionIds();
|
||||||
|
final deletedEntries = <int>[];
|
||||||
|
for (final result in queryResults) {
|
||||||
|
final file = filesMap[result.id];
|
||||||
|
if (file != null && !ignoredCollections.contains(file.collectionID)) {
|
||||||
|
results.add(file);
|
||||||
|
}
|
||||||
|
if (file == null) {
|
||||||
|
deletedEntries.add(result.id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
_logger.info(results.length.toString() + " results");
|
||||||
|
|
||||||
|
if (deletedEntries.isNotEmpty) {
|
||||||
|
unawaited(EmbeddingsDB.instance.deleteEmbeddings(deletedEntries));
|
||||||
|
}
|
||||||
|
|
||||||
|
final matchingFileIDs = <int>[];
|
||||||
|
for (EnteFile file in results) {
|
||||||
|
matchingFileIDs.add(file.uploadedFileID!);
|
||||||
|
}
|
||||||
|
|
||||||
|
return matchingFileIDs;
|
||||||
|
}
|
||||||
|
|
||||||
void _addToQueue(EnteFile file) {
|
void _addToQueue(EnteFile file) {
|
||||||
if (!LocalSettings.instance.hasEnabledMagicSearch()) {
|
if (!LocalSettings.instance.hasEnabledMagicSearch()) {
|
||||||
return;
|
return;
|
||||||
|
@ -355,13 +397,17 @@ class SemanticSearchService {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Future<List<QueryResult>> _getScores(List<double> textEmbedding) async {
|
Future<List<QueryResult>> _getScores(
|
||||||
|
List<double> textEmbedding, {
|
||||||
|
double? scoreThreshold,
|
||||||
|
}) async {
|
||||||
final startTime = DateTime.now();
|
final startTime = DateTime.now();
|
||||||
final List<QueryResult> queryResults = await _computer.compute(
|
final List<QueryResult> queryResults = await _computer.compute(
|
||||||
computeBulkScore,
|
computeBulkScore,
|
||||||
param: {
|
param: {
|
||||||
"imageEmbeddings": _cachedEmbeddings,
|
"imageEmbeddings": _cachedEmbeddings,
|
||||||
"textEmbedding": textEmbedding,
|
"textEmbedding": textEmbedding,
|
||||||
|
"scoreThreshold": scoreThreshold,
|
||||||
},
|
},
|
||||||
taskName: "computeBulkScore",
|
taskName: "computeBulkScore",
|
||||||
);
|
);
|
||||||
|
@ -402,12 +448,14 @@ List<QueryResult> computeBulkScore(Map args) {
|
||||||
final queryResults = <QueryResult>[];
|
final queryResults = <QueryResult>[];
|
||||||
final imageEmbeddings = args["imageEmbeddings"] as List<Embedding>;
|
final imageEmbeddings = args["imageEmbeddings"] as List<Embedding>;
|
||||||
final textEmbedding = args["textEmbedding"] as List<double>;
|
final textEmbedding = args["textEmbedding"] as List<double>;
|
||||||
|
final scoreThreshold = args["scoreThreshold"] as double? ??
|
||||||
|
SemanticSearchService.kScoreThreshold;
|
||||||
for (final imageEmbedding in imageEmbeddings) {
|
for (final imageEmbedding in imageEmbeddings) {
|
||||||
final score = computeScore(
|
final score = computeScore(
|
||||||
imageEmbedding.embedding,
|
imageEmbedding.embedding,
|
||||||
textEmbedding,
|
textEmbedding,
|
||||||
);
|
);
|
||||||
if (score >= SemanticSearchService.kScoreThreshold) {
|
if (score >= scoreThreshold) {
|
||||||
queryResults.add(QueryResult(imageEmbedding.fileID, score));
|
queryResults.add(QueryResult(imageEmbedding.fileID, score));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
192
mobile/lib/services/magic_cache_service.dart
Normal file
192
mobile/lib/services/magic_cache_service.dart
Normal file
|
@ -0,0 +1,192 @@
|
||||||
|
import "dart:convert";
|
||||||
|
import 'dart:math';
|
||||||
|
|
||||||
|
import "package:logging/logging.dart";
|
||||||
|
import "package:photos/models/file/file.dart";
|
||||||
|
import "package:photos/models/search/generic_search_result.dart";
|
||||||
|
import "package:photos/models/search/search_types.dart";
|
||||||
|
import "package:photos/services/machine_learning/semantic_search/semantic_search_service.dart";
|
||||||
|
import "package:photos/services/search_service.dart";
|
||||||
|
import "package:shared_preferences/shared_preferences.dart";
|
||||||
|
|
||||||
|
const _promptsJson = {
|
||||||
|
"prompts": [
|
||||||
|
{
|
||||||
|
"prompt": "identity document",
|
||||||
|
"title": "Identity Document",
|
||||||
|
"minimumScore": 0.269,
|
||||||
|
"minimumSize": 0.0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"prompt": "sunset at the beach",
|
||||||
|
"title": "Sunset",
|
||||||
|
"minimumScore": 0.25,
|
||||||
|
"minimumSize": 0.0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"prompt": "roadtrip",
|
||||||
|
"title": "Roadtrip",
|
||||||
|
"minimumScore": 0.26,
|
||||||
|
"minimumSize": 0.0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"prompt": "pizza pasta burger",
|
||||||
|
"title": "Food",
|
||||||
|
"minimumScore": 0.27,
|
||||||
|
"minimumSize": 0.0,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
};
|
||||||
|
|
||||||
|
class MagicCache {
|
||||||
|
final String title;
|
||||||
|
final List<int> fileUploadedIDs;
|
||||||
|
MagicCache(this.title, this.fileUploadedIDs);
|
||||||
|
|
||||||
|
factory MagicCache.fromJson(Map<String, dynamic> json) {
|
||||||
|
return MagicCache(
|
||||||
|
json['title'],
|
||||||
|
List<int>.from(json['fileUploadedIDs']),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
Map<String, dynamic> toJson() {
|
||||||
|
return {
|
||||||
|
'title': title,
|
||||||
|
'fileUploadedIDs': fileUploadedIDs,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
static String encodeListToJson(List<MagicCache> magicCaches) {
|
||||||
|
final jsonList = magicCaches.map((cache) => cache.toJson()).toList();
|
||||||
|
return jsonEncode(jsonList);
|
||||||
|
}
|
||||||
|
|
||||||
|
static List<MagicCache> decodeJsonToList(String jsonString) {
|
||||||
|
final jsonList = jsonDecode(jsonString) as List;
|
||||||
|
return jsonList.map((json) => MagicCache.fromJson(json)).toList();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
extension MagicCacheServiceExtension on MagicCache {
|
||||||
|
Future<GenericSearchResult> toGenericSearchResult() async {
|
||||||
|
final allEnteFiles = await SearchService.instance.getAllFiles();
|
||||||
|
final enteFilesInMagicCache = <EnteFile>[];
|
||||||
|
for (EnteFile file in allEnteFiles) {
|
||||||
|
if (file.uploadedFileID != null &&
|
||||||
|
fileUploadedIDs.contains(file.uploadedFileID as int)) {
|
||||||
|
enteFilesInMagicCache.add(file);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return GenericSearchResult(
|
||||||
|
ResultType.magic,
|
||||||
|
title,
|
||||||
|
enteFilesInMagicCache,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
class MagicCacheService {
|
||||||
|
static const _key = "magic_cache";
|
||||||
|
late SharedPreferences prefs;
|
||||||
|
final Logger _logger = Logger((MagicCacheService).toString());
|
||||||
|
MagicCacheService._privateConstructor();
|
||||||
|
|
||||||
|
static final MagicCacheService instance =
|
||||||
|
MagicCacheService._privateConstructor();
|
||||||
|
|
||||||
|
void init(SharedPreferences preferences) {
|
||||||
|
prefs = preferences;
|
||||||
|
}
|
||||||
|
|
||||||
|
List<Map<String, Object>> getRandomPrompts() {
|
||||||
|
final promptsJson = _promptsJson["prompts"];
|
||||||
|
final randomPrompts = <Map<String, Object>>[];
|
||||||
|
final randomNumbers =
|
||||||
|
_generateUniqueRandomNumbers(promptsJson!.length - 1, 4);
|
||||||
|
for (int i = 0; i < randomNumbers.length; i++) {
|
||||||
|
randomPrompts.add(promptsJson[randomNumbers[i]]);
|
||||||
|
}
|
||||||
|
|
||||||
|
return randomPrompts;
|
||||||
|
}
|
||||||
|
|
||||||
|
Future<Map<String, List<int>>> getMatchingFileIDsForPromptData(
|
||||||
|
Map<String, Object> promptData,
|
||||||
|
) async {
|
||||||
|
final result = await SemanticSearchService.instance.getMatchingFileIDs(
|
||||||
|
promptData["prompt"] as String,
|
||||||
|
promptData["minimumScore"] as double,
|
||||||
|
);
|
||||||
|
|
||||||
|
return {promptData["title"] as String: result};
|
||||||
|
}
|
||||||
|
|
||||||
|
Future<void> updateMagicCache(List<MagicCache> magicCaches) async {
|
||||||
|
await prefs.setString(
|
||||||
|
_key,
|
||||||
|
MagicCache.encodeListToJson(magicCaches),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
Future<List<MagicCache>?> getMagicCache() async {
|
||||||
|
final jsonString = prefs.getString(_key);
|
||||||
|
if (jsonString == null) {
|
||||||
|
_logger.info("No $_key in shared preferences");
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
return MagicCache.decodeJsonToList(jsonString);
|
||||||
|
}
|
||||||
|
|
||||||
|
Future<void> clearMagicCache() async {
|
||||||
|
await prefs.remove(_key);
|
||||||
|
}
|
||||||
|
|
||||||
|
Future<List<GenericSearchResult>> getMagicGenericSearchResult() async {
|
||||||
|
final magicCaches = await getMagicCache();
|
||||||
|
if (magicCaches == null) {
|
||||||
|
_logger.info("No magic cache found");
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
final List<GenericSearchResult> genericSearchResults = [];
|
||||||
|
for (MagicCache magicCache in magicCaches) {
|
||||||
|
final genericSearchResult = await magicCache.toGenericSearchResult();
|
||||||
|
genericSearchResults.add(genericSearchResult);
|
||||||
|
}
|
||||||
|
return genericSearchResults;
|
||||||
|
}
|
||||||
|
|
||||||
|
Future<void> reloadMagicCaches() async {
|
||||||
|
_logger.info("Reloading magic caches");
|
||||||
|
final randomPromptsData = MagicCacheService.instance.getRandomPrompts();
|
||||||
|
final promptResults = <Map<String, List<int>>>[];
|
||||||
|
final magicCaches = <MagicCache>[];
|
||||||
|
|
||||||
|
for (var randomPromptData in randomPromptsData) {
|
||||||
|
promptResults.add(
|
||||||
|
await MagicCacheService.instance
|
||||||
|
.getMatchingFileIDsForPromptData(randomPromptData),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
for (var promptResult in promptResults) {
|
||||||
|
magicCaches
|
||||||
|
.add(MagicCache(promptResult.keys.first, promptResult.values.first));
|
||||||
|
}
|
||||||
|
|
||||||
|
await MagicCacheService.instance.updateMagicCache(magicCaches);
|
||||||
|
}
|
||||||
|
|
||||||
|
///Generates from 0 to max unique random numbers
|
||||||
|
List<int> _generateUniqueRandomNumbers(int max, int count) {
|
||||||
|
final numbers = <int>[];
|
||||||
|
for (int i = 1; i <= count;) {
|
||||||
|
final randomNumber = Random().nextInt(max + 1);
|
||||||
|
if (numbers.contains(randomNumber)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
numbers.add(randomNumber);
|
||||||
|
i++;
|
||||||
|
}
|
||||||
|
return numbers;
|
||||||
|
}
|
||||||
|
}
|
|
@ -79,8 +79,7 @@ class _AllSectionsExamplesProviderState
|
||||||
_logger.info("'_debounceTimer: reloading all sections in search tab");
|
_logger.info("'_debounceTimer: reloading all sections in search tab");
|
||||||
final allSectionsExamples = <Future<List<SearchResult>>>[];
|
final allSectionsExamples = <Future<List<SearchResult>>>[];
|
||||||
for (SectionType sectionType in SectionType.values) {
|
for (SectionType sectionType in SectionType.values) {
|
||||||
if (sectionType == SectionType.face ||
|
if (sectionType == SectionType.face) {
|
||||||
sectionType == SectionType.content) {
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
allSectionsExamples.add(
|
allSectionsExamples.add(
|
||||||
|
|
|
@ -22,7 +22,7 @@ class _NoResultWidgetState extends State<NoResultWidget> {
|
||||||
searchTypes = SectionType.values.toList(growable: true);
|
searchTypes = SectionType.values.toList(growable: true);
|
||||||
// remove face and content sectionType
|
// remove face and content sectionType
|
||||||
searchTypes.remove(SectionType.face);
|
searchTypes.remove(SectionType.face);
|
||||||
searchTypes.remove(SectionType.content);
|
searchTypes.remove(SectionType.magic);
|
||||||
}
|
}
|
||||||
|
|
||||||
@override
|
@override
|
||||||
|
|
|
@ -78,7 +78,7 @@ class _AllSearchSectionsState extends State<AllSearchSections> {
|
||||||
final searchTypes = SectionType.values.toList(growable: true);
|
final searchTypes = SectionType.values.toList(growable: true);
|
||||||
// remove face and content sectionType
|
// remove face and content sectionType
|
||||||
searchTypes.remove(SectionType.face);
|
searchTypes.remove(SectionType.face);
|
||||||
searchTypes.remove(SectionType.content);
|
// searchTypes.remove(SectionType.magic);
|
||||||
return Padding(
|
return Padding(
|
||||||
padding: const EdgeInsets.only(top: 8),
|
padding: const EdgeInsets.only(top: 8),
|
||||||
child: Stack(
|
child: Stack(
|
||||||
|
@ -131,6 +131,11 @@ class _AllSearchSectionsState extends State<AllSearchSections> {
|
||||||
snapshot.data!.elementAt(index)
|
snapshot.data!.elementAt(index)
|
||||||
as List<GenericSearchResult>,
|
as List<GenericSearchResult>,
|
||||||
);
|
);
|
||||||
|
case SectionType.magic:
|
||||||
|
return MomentsSection(
|
||||||
|
snapshot.data!.elementAt(index)
|
||||||
|
as List<GenericSearchResult>,
|
||||||
|
);
|
||||||
default:
|
default:
|
||||||
const SizedBox.shrink();
|
const SizedBox.shrink();
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue