import "dart:convert"; import "dart:io"; import "dart:math"; import "dart:typed_data"; import "package:computer/computer.dart"; import "package:flutter/services.dart"; import "package:html_unescape/html_unescape.dart"; import 'package:image/image.dart' as img; import "package:logging/logging.dart"; import "package:onnxruntime/onnxruntime.dart"; import "package:photos/services/semantic_search/frameworks/ml_framework.dart"; import "package:tuple/tuple.dart"; class ONNX extends MLFramework { static const kModelBucketEndpoint = "https://models.ente.io/"; static const kImageModel = "clip-vit-base-patch32_ggml-vision-model-f16.gguf"; static const kTextModel = "clip-vit-base-patch32_ggml-text-model-f16.gguf"; final _computer = Computer.shared(); final _logger = Logger("ONNX"); final _clipImage = ClipImageEncoder(); final _clipText = ClipTextEncoder(); @override String getImageModelRemotePath() { return ""; } @override String getTextModelRemotePath() { return ""; } @override Future loadImageModel(String path) async { final startTime = DateTime.now(); await _computer.compute( _clipImage.loadModel, param: { "imageModelPath": path, }, ); final endTime = DateTime.now(); _logger.info( "Loading image model took: ${(endTime.millisecondsSinceEpoch - startTime.millisecondsSinceEpoch).toString()}ms", ); } @override Future loadTextModel(String path) async { final startTime = DateTime.now(); await _computer.compute( _clipText.loadModel, param: { "textModelPath": path, }, ); final endTime = DateTime.now(); _logger.info( "Loading text model took: ${(endTime.millisecondsSinceEpoch - startTime.millisecondsSinceEpoch).toString()}ms", ); } @override Future> getImageEmbedding(String imagePath) async { try { final startTime = DateTime.now(); final result = await _computer.compute( _clipImage.inferByImage, param: { "imagePath": imagePath, }, taskName: "createImageEmbedding", ) as List; final endTime = DateTime.now(); _logger.info( "createImageEmbedding took: ${(endTime.millisecondsSinceEpoch - startTime.millisecondsSinceEpoch)}ms", ); return result; } catch (e, s) { _logger.severe(e, s); rethrow; } } @override Future> getTextEmbedding(String text) async { try { final startTime = DateTime.now(); final result = await _computer.compute( _clipText.infer, param: { "text": text, }, taskName: "createTextEmbedding", ) as List; final endTime = DateTime.now(); _logger.info( "createTextEmbedding took: ${(endTime.millisecondsSinceEpoch - startTime.millisecondsSinceEpoch)}ms", ); return result; } catch (e, s) { _logger.severe(e, s); rethrow; } } } class ClipImageEncoder { OrtSessionOptions? _sessionOptions; OrtSession? 
class ClipImageEncoder {
  OrtSessionOptions? _sessionOptions;
  OrtSession? _session;
  final _logger = Logger("CLIPImageEncoder");

  ClipImageEncoder() {
    OrtEnv.instance.init();
  }

  void release() {
    _sessionOptions?.release();
    _sessionOptions = null;
    _session?.release();
    _session = null;
    OrtEnv.instance.release();
  }

  Future<void> loadModel(Map args) async {
    _sessionOptions = OrtSessionOptions()
      ..setInterOpNumThreads(1)
      ..setIntraOpNumThreads(1)
      ..setSessionGraphOptimizationLevel(GraphOptimizationLevel.ortEnableAll);
    try {
      //const assetFileName = 'assets/models/clip-image-vit-32-float32.onnx';
      // Check if the path exists locally
      final rawAssetFile = await rootBundle.load(args["imageModelPath"]);
      final bytes = rawAssetFile.buffer.asUint8List();
      _session = OrtSession.fromBuffer(bytes, _sessionOptions!);
      _logger.info('image model loaded');
    } catch (e, s) {
      _logger.severe(e, s);
      rethrow;
    }
  }

  List<double> inferByImage(Map args) {
    final runOptions = OrtRunOptions();
    // Check the existence of imagePath locally
    final rgb = img.decodeImage(File(args["imagePath"]).readAsBytesSync())!;

    // Resize the shorter side to 224, then center-crop to 224x224.
    img.Image inputImage;
    if (rgb.height >= rgb.width) {
      inputImage = img.copyResize(
        rgb,
        width: 224,
        interpolation: img.Interpolation.linear,
      );
      inputImage = img.copyCrop(
        inputImage,
        x: 0,
        y: (inputImage.height - 224) ~/ 2,
        width: 224,
        height: 224,
      );
    } else {
      inputImage = img.copyResize(
        rgb,
        height: 224,
        interpolation: img.Interpolation.linear,
      );
      inputImage = img.copyCrop(
        inputImage,
        x: (inputImage.width - 224) ~/ 2,
        y: 0,
        width: 224,
        height: 224,
      );
    }

    final mean = [0.48145466, 0.4578275, 0.40821073];
    final std = [0.26862954, 0.26130258, 0.27577711];
    final processedImage = imageToByteListFloat32(inputImage, 224, mean, std);

    final inputOrt = OrtValueTensor.createTensorWithDataList(
      processedImage,
      [1, 3, 224, 224],
    );
    final inputs = {'input': inputOrt};
    final outputs = _session?.run(runOptions, inputs);
    final embedding = (outputs?[0]?.value as List<List<double>>)[0];

    // L2-normalize the embedding so that similarity can later be computed
    // as a plain dot product.
    double imageNormalization = 0;
    for (int i = 0; i < 512; i++) {
      imageNormalization += embedding[i] * embedding[i];
    }
    for (int i = 0; i < 512; i++) {
      embedding[i] = embedding[i] / sqrt(imageNormalization);
    }
    inputOrt.release();
    runOptions.release();
    return embedding;
  }

  Float32List imageToByteListFloat32(
    img.Image image,
    int inputSize,
    List<double> mean,
    List<double> std,
  ) {
    final convertedBytes = Float32List(1 * inputSize * inputSize * 3);
    final buffer = Float32List.view(convertedBytes.buffer);
    assert(mean.length == 3);
    assert(std.length == 3);
    // The [1, 3, 224, 224] input tensor is channel-first (CHW), so each
    // normalized channel is written into its own plane rather than
    // interleaving RGB values.
    final channelSize = inputSize * inputSize;
    for (var y = 0; y < inputSize; y++) {
      for (var x = 0; x < inputSize; x++) {
        final pixel = image.getPixel(x, y);
        final pixelIndex = y * inputSize + x;
        buffer[pixelIndex] = ((pixel.r / 255) - mean[0]) / std[0];
        buffer[channelSize + pixelIndex] = ((pixel.g / 255) - mean[1]) / std[1];
        buffer[2 * channelSize + pixelIndex] =
            ((pixel.b / 255) - mean[2]) / std[2];
      }
    }
    return convertedBytes.buffer.asFloat32List();
  }
}
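// Because inferByImage and ClipTextEncoder.infer both return L2-normalized
// 512-dimensional embeddings, cosine similarity between an image and a text
// query reduces to a dot product. A minimal sketch (this helper is
// illustrative and not part of this file):
//
//   double similarity(List<double> a, List<double> b) {
//     double dot = 0;
//     for (int i = 0; i < a.length; i++) {
//       dot += a[i] * b[i];
//     }
//     return dot;
//   }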
class ClipTextEncoder {
  static const vocabFilePath = "assets/clip/bpe_simple_vocab_16e6.txt";
  final _logger = Logger("CLIPTextEncoder");
  OrtSessionOptions? _sessionOptions;
  OrtSession? _session;

  ClipTextEncoder() {
    OrtEnv.instance.init();
    OrtEnv.instance.availableProviders().forEach((element) {
      _logger.info('onnx provider=$element');
    });
  }

  void release() {
    _sessionOptions?.release();
    _sessionOptions = null;
    _session?.release();
    _session = null;
    OrtEnv.instance.release();
  }

  Future<void> loadModel(Map args) async {
    _sessionOptions = OrtSessionOptions()
      ..setInterOpNumThreads(1)
      ..setIntraOpNumThreads(1)
      ..setSessionGraphOptimizationLevel(GraphOptimizationLevel.ortEnableAll);
    try {
      //const assetFileName = 'assets/models/clip-text-vit-32-float32-int32.onnx';
      // Check if path exists locally
      final rawAssetFile = await rootBundle.load(args["textModelPath"]);
      final bytes = rawAssetFile.buffer.asUint8List();
      _session = OrtSession.fromBuffer(bytes, _sessionOptions!);
      _logger.info('text model loaded');
    } catch (e, s) {
      _logger.severe('text model not loaded', e, s);
    }
  }

  Future<List<double>> infer(Map args) async {
    final text = args["text"];
    final runOptions = OrtRunOptions();
    final tokenizer = CLIPTokenizer(vocabFilePath);
    await tokenizer.init();
    final data = List.filled(1, Int32List.fromList(tokenizer.tokenize(text)));
    final inputOrt = OrtValueTensor.createTensorWithDataList(data, [1, 77]);
    final inputs = {'input': inputOrt};
    final outputs = _session?.run(runOptions, inputs);
    final embedding = (outputs?[0]?.value as List<List<double>>)[0];

    // L2-normalize, mirroring the image embedding. The session is kept alive
    // for subsequent queries; release() handles cleanup.
    double textNormalization = 0;
    for (int i = 0; i < 512; i++) {
      textNormalization += embedding[i] * embedding[i];
    }
    for (int i = 0; i < 512; i++) {
      embedding[i] = embedding[i] / sqrt(textNormalization);
    }
    inputOrt.release();
    runOptions.release();
    return embedding;
  }
}

class CLIPTokenizer {
  final String bpePath;
  late Map<int, String> byteEncoder;
  late Map<String, int> byteDecoder;
  late Map<int, String> decoder;
  late Map<String, int> encoder;
  late Map<Tuple2<String, String>, int> bpeRanks;

  Map<String, String> cache = {
    '<|startoftext|>': '<|startoftext|>',
    '<|endoftext|>': '<|endoftext|>',
  };

  // Dart's RegExp does not support the Unicode classes \p{L} and \p{N} here,
  // so [a-zA-Z]+ and [0-9]+ stand in for them.
  RegExp pat = RegExp(
    r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[a-zA-Z]+|[0-9]+|[^\s\p{L}\p{N}]+""",
    caseSensitive: false,
    multiLine: false,
  );

  late int sot;
  late int eot;

  CLIPTokenizer(this.bpePath);

  // Async init method, since loadFile returns a Future and a Dart constructor
  // cannot be async.
  Future<void> init() async {
    final bpe = await loadFile();
    byteEncoder = bytesToUnicode();
    byteDecoder = byteEncoder.map((k, v) => MapEntry(v, k));
    var split = bpe.split('\n');
    split = split.sublist(1, 49152 - 256 - 2 + 1);
    final merges = split
        .map((merge) => Tuple2(merge.split(' ')[0], merge.split(' ')[1]))
        .toList();
    final vocab = byteEncoder.values.toList();
    vocab.addAll(vocab.map((v) => '$v</w>').toList());
    for (var merge = 0; merge < merges.length; merge++) {
      vocab.add(merges[merge].item1 + merges[merge].item2);
    }
    vocab.addAll(['<|startoftext|>', '<|endoftext|>']);
    // asMap returns the list as a Map<int, String> keyed by index
    decoder = vocab.asMap();
    encoder = decoder.map((k, v) => MapEntry(v, k));
    bpeRanks = Map.fromIterables(
      merges,
      List.generate(merges.length, (i) => i),
    );
    sot = encoder['<|startoftext|>']!;
    eot = encoder['<|endoftext|>']!;
  }

  Future<String> loadFile() async {
    return await rootBundle.loadString(bpePath);
  }

  List<int> encode(String text) {
    final List<int> bpeTokens = [];
    text = whitespaceClean(basicClean(text)).toLowerCase();
    for (Match match in pat.allMatches(text)) {
      String token = match[0]!;
      token = utf8.encode(token).map((b) => byteEncoder[b]).join();
      bpe(token)
          .split(' ')
          .forEach((bpeToken) => bpeTokens.add(encoder[bpeToken]!));
    }
    return bpeTokens;
  }
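  // Byte-pair encoding: the token starts as a list of single characters, with
  // "</w>" appended to the last one to mark the end of the word. The adjacent
  // pair with the lowest merge rank in bpeRanks is repeatedly merged until no
  // ranked pair remains. For example, if ("l", "o") outranks ("o", "w</w>"),
  // ["l", "o", "w</w>"] becomes ["lo", "w</w>"] first. (The example merge is
  // illustrative; actual ranks come from the vocab file.)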
  String bpe(String token) {
    if (cache.containsKey(token)) {
      return cache[token]!;
    }
    var word = token.split('');
    word[word.length - 1] = '${word.last}</w>';
    var pairs = getPairs(word);
    if (pairs.isEmpty) {
      return '$token</w>';
    }
    while (true) {
      // Find the pair with the lowest merge rank.
      Tuple2<String, String> bigram = pairs.first;
      for (var pair in pairs) {
        final rank1 = bpeRanks[pair] ?? double.infinity;
        final rank2 = bpeRanks[bigram] ?? double.infinity;
        if (rank1 < rank2) {
          bigram = pair;
        }
      }
      if (!bpeRanks.containsKey(bigram)) {
        break;
      }
      final first = bigram.item1;
      final second = bigram.item2;
      final newWord = <String>[];
      var i = 0;
      while (i < word.length) {
        final j = word.sublist(i).indexOf(first);
        if (j == -1) {
          newWord.addAll(word.sublist(i));
          break;
        }
        newWord.addAll(word.sublist(i, i + j));
        i = i + j;
        if (word[i] == first && i < word.length - 1 && word[i + 1] == second) {
          newWord.add(first + second);
          i += 2;
        } else {
          newWord.add(word[i]);
          i += 1;
        }
      }
      word = newWord;
      if (word.length == 1) {
        break;
      } else {
        pairs = getPairs(word);
      }
    }
    final wordStr = word.join(' ');
    cache[token] = wordStr;
    return wordStr;
  }

  List<int> tokenize(String text, {int nText = 76, bool pad = true}) {
    var tokens = encode(text);
    tokens = [sot] + tokens.sublist(0, min(nText - 1, tokens.length)) + [eot];
    if (pad) {
      return tokens + List.filled(nText + 1 - tokens.length, 0);
    } else {
      return tokens;
    }
  }

  List<int> pad(List<int> x, int padLength) {
    return x + List.filled(padLength - x.length, 0);
  }

  // Maps every byte (0-255) to a printable Unicode character, mirroring
  // CLIP's bytes_to_unicode().
  Map<int, String> bytesToUnicode() {
    final List<int> bs = [];
    for (int i = '!'.codeUnitAt(0); i <= '~'.codeUnitAt(0); i++) {
      bs.add(i);
    }
    for (int i = '¡'.codeUnitAt(0); i <= '¬'.codeUnitAt(0); i++) {
      bs.add(i);
    }
    for (int i = '®'.codeUnitAt(0); i <= 'ÿ'.codeUnitAt(0); i++) {
      bs.add(i);
    }
    final List<int> cs = List.from(bs);
    int n = 0;
    for (int b = 0; b < 256; b++) {
      if (!bs.contains(b)) {
        bs.add(b);
        cs.add(256 + n);
        n += 1;
      }
    }
    final List<String> ds = cs.map((n) => String.fromCharCode(n)).toList();
    return Map.fromIterables(bs, ds);
  }

  Set<Tuple2<String, String>> getPairs(List<String> word) {
    final Set<Tuple2<String, String>> pairs = {};
    String prevChar = word[0];
    for (var i = 1; i < word.length; i++) {
      pairs.add(Tuple2(prevChar, word[i]));
      prevChar = word[i];
    }
    return pairs;
  }

  String basicClean(String text) {
    final unescape = HtmlUnescape();
    text = unescape.convert(unescape.convert(text));
    return text.trim();
  }

  String whitespaceClean(String text) {
    text = text.replaceAll(RegExp(r'\s+'), ' ');
    return text.trim();
  }
}
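// Tokenizer usage sketch (the asset path matches ClipTextEncoder.vocabFilePath;
// the query string is illustrative):
//
//   final tokenizer = CLIPTokenizer("assets/clip/bpe_simple_vocab_16e6.txt");
//   await tokenizer.init();
//   final tokens = tokenizer.tokenize("a photo of a dog");
//   // tokens has a fixed length of 77: [sot, ...bpe ids..., eot, 0, 0, ...]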