diff --git a/lib/main.dart b/lib/main.dart index ef1870783..10531e961 100644 --- a/lib/main.dart +++ b/lib/main.dart @@ -34,6 +34,7 @@ import 'package:photos/services/memories_service.dart'; import 'package:photos/services/push_service.dart'; import 'package:photos/services/remote_sync_service.dart'; import 'package:photos/services/search_service.dart'; +import "package:photos/services/semantic_search_service.dart"; import "package:photos/services/storage_bonus_service.dart"; import 'package:photos/services/sync_service.dart'; import 'package:photos/services/trash_sync_service.dart'; @@ -189,6 +190,7 @@ Future _init(bool isBackground, {String via = ''}) async { }); } FeatureFlagService.instance.init(); + SemanticSearchService.instance.init(); // Can not including existing tf/ml binaries as they are not being built // from source. diff --git a/lib/services/semantic_search_service.dart b/lib/services/semantic_search_service.dart new file mode 100644 index 000000000..40fcd508d --- /dev/null +++ b/lib/services/semantic_search_service.dart @@ -0,0 +1,71 @@ +import "dart:convert"; +import "dart:io"; + +import "package:clip_ggml/clip_ggml.dart"; +import "package:flutter/services.dart"; +import "package:logging/logging.dart"; +import "package:path_provider/path_provider.dart"; + +class SemanticSearchService { + SemanticSearchService._privateConstructor(); + + static final SemanticSearchService instance = + SemanticSearchService._privateConstructor(); + + late CLIP _clip; + final _logger = Logger("SemanticSearchService"); + + Future init() async { + _clip = CLIP(); + await _loadModel(); + _testJson(); + } + + Future _loadModel() async { + final clip = CLIP(); + const modelPath = + "assets/models/clip/openai_clip-vit-base-patch32.ggmlv0.f16.bin"; + + final path = await _getAccessiblePathForAsset(modelPath, "model.bin"); + final startTime = DateTime.now(); + clip.loadModel(path); + final endTime = DateTime.now(); + _logger.info( + "Loading model took: " + + (endTime.millisecondsSinceEpoch - startTime.millisecondsSinceEpoch) + .toString() + + "ms", + ); + + _testJson(); + } + + Future _testJson() async { + final startTime = DateTime.now(); + final input = { + "embedding": [1.1, 2.2], + }; + _logger.info(jsonEncode(input)); + final result = _clip.testJSON(jsonEncode(input)); + final endTime = DateTime.now(); + _logger.info( + "Output: " + + result + + " (" + + (endTime.millisecondsSinceEpoch - startTime.millisecondsSinceEpoch) + .toString() + + "ms)", + ); + } + + Future _getAccessiblePathForAsset( + String assetPath, + String tempName, + ) async { + final byteData = await rootBundle.load(assetPath); + final tempDir = await getTemporaryDirectory(); + final file = await File('${tempDir.path}/$tempName') + .writeAsBytes(byteData.buffer.asUint8List()); + return file.path; + } +}