import "dart:io"; import "package:logging/logging.dart"; import "package:path/path.dart"; import "package:path_provider/path_provider.dart"; import "package:photos/core/network/network.dart"; abstract class MLFramework { final _logger = Logger("MLFramework"); /// Returns the path of the Image Model hosted remotely String getImageModelRemotePath(); /// Returns the path of the Text Model hosted remotely String getTextModelRemotePath(); /// Loads the Image Model stored at [path] into the framework Future loadImageModel(String path); /// Loads the Text Model stored at [path] into the framework Future loadTextModel(String path); /// Returns the Image Embedding for a file stored at [imagePath] Future> getImageEmbedding(String imagePath); /// Returns the Text Embedding for [text] Future> getTextEmbedding(String text); /// Downloads the models from remote, caches them and loads them into the /// framework. Override this method if you would like to control the /// initialization. For eg. if you wish to load the model from `/assets` /// instead of a CDN. Future init() async { await _initImageModel(); await _initTextModel(); } /// Returns the cosine similarity between [imageEmbedding] and [textEmbedding] double computeScore(List imageEmbedding, List textEmbedding) { assert( imageEmbedding.length == textEmbedding.length, "The two embeddings should have the same length", ); double score = 0; for (int index = 0; index < imageEmbedding.length; index++) { score += imageEmbedding[index] * textEmbedding[index]; } return score; } // --- // Private methods // --- Future _initImageModel() async { //final path = await _getLocalImageModelPath(); const path = "assets/models/clip/clip-image-vit-32-float32.onnx"; await loadImageModel(path); // if (File(path).existsSync()) { // await loadImageModel(path); // } else { // final tempFile = File(path + ".temp"); // await _downloadFile(getImageModelRemotePath(), tempFile.path); // await tempFile.rename(path); // await loadImageModel(path); // } } Future _initTextModel() async { //final path = await _getLocalTextModelPath(); const path = "assets/models/clip/clip-text-vit-32-float32.onnx"; await loadTextModel(path); // if (File(path).existsSync()) { // await loadTextModel(path); // } else { // final tempFile = File(path + ".temp"); // await _downloadFile(getTextModelRemotePath(), tempFile.path); // await tempFile.rename(path); // await loadTextModel(path); // } } Future _getLocalImageModelPath() async { return (await getTemporaryDirectory()).path + "/models/" + basename(getImageModelRemotePath()); } Future _getLocalTextModelPath() async { return (await getTemporaryDirectory()).path + "/models/" + basename(getTextModelRemotePath()); } Future _downloadFile(String url, String savePath) async { _logger.info("Downloading " + url); final existingFile = File(savePath); if (await existingFile.exists()) { await existingFile.delete(); } await NetworkClient.instance.getDio().download(url, savePath); } }