diff --git a/mobile/lib/services/magic_cache_service.dart b/mobile/lib/services/magic_cache_service.dart index 5fd93b688..3efe8b6dd 100644 --- a/mobile/lib/services/magic_cache_service.dart +++ b/mobile/lib/services/magic_cache_service.dart @@ -1,6 +1,7 @@ import 'dart:math'; import "package:logging/logging.dart"; +import "package:photos/services/machine_learning/semantic_search/semantic_search_service.dart"; import "package:shared_preferences/shared_preferences.dart"; const _promptsJson = { @@ -48,7 +49,8 @@ class MagicCacheService { List> getRandomPrompts() { final promptsJson = _promptsJson["prompts"]; final randomPrompts = >[]; - final randomNumbers = _generateRandomNumbers(promptsJson!.length - 1, 4); + final randomNumbers = + _generateUniqueRandomNumbers(promptsJson!.length - 1, 4); for (int i = 0; i < randomNumbers.length; i++) { randomPrompts.add(promptsJson[randomNumbers[i]]); } @@ -56,10 +58,25 @@ class MagicCacheService { return randomPrompts; } - List _generateRandomNumbers(int max, int count) { + Future> getMatchingFileIDsForPromptData( + Map promptData, + ) { + return SemanticSearchService.instance.getMatchingFileIDs( + promptData["prompt"] as String, + promptData["minimumScore"] as double, + ); + } + + ///Generates from 0 to max unique random numbers + List _generateUniqueRandomNumbers(int max, int count) { final numbers = []; - for (int i = 1; i <= count; i++) { - numbers.add(Random().nextInt(max + 1)); + for (int i = 1; i <= count;) { + final randomNumber = Random().nextInt(max + 1); + if (numbers.contains(randomNumber)) { + continue; + } + numbers.add(randomNumber); + i++; } return numbers; }