From 53fa9c6830fa771bf491aff03c972beebedae2b9 Mon Sep 17 00:00:00 2001 From: vishnukvmd Date: Tue, 24 Oct 2023 14:08:45 +0530 Subject: [PATCH] Load models separately --- .../semantic_search_service.dart | 21 ++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/lib/services/semantic_search/semantic_search_service.dart b/lib/services/semantic_search/semantic_search_service.dart index ad8b14c70..6279ca470 100644 --- a/lib/services/semantic_search/semantic_search_service.dart +++ b/lib/services/semantic_search/semantic_search_service.dart @@ -27,8 +27,11 @@ class SemanticSearchService { static final Computer _computer = Computer.shared(); static const int batchSize = 1; - static const kModelPath = - "assets/models/clip/openai_clip-vit-base-patch32.ggmlv0.f16.bin"; + static const kModelName = "ggml-clip"; + static const kImageModelPath = + "assets/models/clip/clip-vit-base-patch32_ggml-vision-model-f16.gguf"; + static const kTextModelPath = + "assets/models/clip/clip-vit-base-patch32_ggml-text-model-f16.gguf"; final _logger = Logger("SemanticSearchService"); final _queue = Queue(); @@ -39,7 +42,7 @@ class SemanticSearchService { PendingQuery? _nextQuery; Future init(SharedPreferences preferences) async { - await _loadModel(); + await _loadModels(); await EmbeddingStore.instance.init(preferences); startBackFill(); @@ -161,10 +164,14 @@ class SemanticSearchService { return IndexStatus(embeddings.length, _queue.length); } - Future _loadModel() async { - final path = await _getAccessiblePathForAsset(kModelPath, "model.bin"); + Future _loadModels() async { final startTime = DateTime.now(); - CLIP.loadModel(path); + final imageModelPath = + await _getAccessiblePathForAsset(kImageModelPath, "image_model.bin"); + CLIP.loadImageModel(imageModelPath); + final textModelPath = + await _getAccessiblePathForAsset(kTextModelPath, "text_model.bin"); + CLIP.loadTextModel(textModelPath); final endTime = DateTime.now(); _logger.info( "Loading model took: " + @@ -264,7 +271,7 @@ class SemanticSearchService { files[i], Embedding( files[i].uploadedFileID!, - "c_uq", + kModelName, imageEmbeddings[i], ), );