// ente/lib/services/semantic_search/frameworks/onnx.dart

import "dart:convert";
import "dart:io";
import "dart:math";
import "dart:typed_data";
import "package:computer/computer.dart";
import "package:flutter/services.dart";
import "package:html_unescape/html_unescape.dart";
import "package:image/image.dart" as img;
import "package:logging/logging.dart";
import "package:onnxruntime/onnxruntime.dart";
import "package:photos/services/semantic_search/frameworks/ml_framework.dart";
import "package:tuple/tuple.dart";
class ONNX extends MLFramework {
  static const kModelBucketEndpoint = "https://models.ente.io/";
  // These file names appear to be leftovers from the GGML (.gguf) framework;
  // they are unused here, since the remote paths below are empty and the
  // models are loaded from the asset bundle instead.
  static const kImageModel =
      "clip-vit-base-patch32_ggml-vision-model-f16.gguf";
  static const kTextModel =
      "clip-vit-base-patch32_ggml-text-model-f16.gguf";
  final _computer = Computer.shared();
  final _logger = Logger("ONNX");
  final _clipImage = ClipImageEncoder();
  final _clipText = ClipTextEncoder();

  // The models ship with the app, so there is nothing to download.
  @override
  String getImageModelRemotePath() {
    return "";
  }

  @override
  String getTextModelRemotePath() {
    return "";
  }

  @override
  Future<void> loadImageModel(String path) async {
    final startTime = DateTime.now();
    await _computer.compute(
      _clipImage.loadModel,
      param: {
        "imageModelPath": path,
      },
    );
    final endTime = DateTime.now();
    _logger.info(
      "Loading image model took: ${endTime.millisecondsSinceEpoch - startTime.millisecondsSinceEpoch}ms",
    );
  }

  @override
  Future<void> loadTextModel(String path) async {
    final startTime = DateTime.now();
    await _computer.compute(
      _clipText.loadModel,
      param: {
        "textModelPath": path,
      },
    );
    final endTime = DateTime.now();
    _logger.info(
      "Loading text model took: ${endTime.millisecondsSinceEpoch - startTime.millisecondsSinceEpoch}ms",
    );
  }

  @override
  Future<List<double>> getImageEmbedding(String imagePath) async {
    try {
      final startTime = DateTime.now();
      final result = await _computer.compute(
        _clipImage.inferByImage,
        param: {
          "imagePath": imagePath,
        },
        taskName: "createImageEmbedding",
      ) as List<double>;
      final endTime = DateTime.now();
      _logger.info(
        "createImageEmbedding took: ${endTime.millisecondsSinceEpoch - startTime.millisecondsSinceEpoch}ms",
      );
      return result;
    } catch (e, s) {
      _logger.severe(e, s);
      rethrow;
    }
  }

  @override
  Future<List<double>> getTextEmbedding(String text) async {
    try {
      final startTime = DateTime.now();
      final result = await _computer.compute(
        _clipText.infer,
        param: {
          "text": text,
        },
        taskName: "createTextEmbedding",
      ) as List<double>;
      final endTime = DateTime.now();
      _logger.info(
        "createTextEmbedding took: ${endTime.millisecondsSinceEpoch - startTime.millisecondsSinceEpoch}ms",
      );
      return result;
    } catch (e, s) {
      _logger.severe(e, s);
      rethrow;
    }
  }
}
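
/// Runs the ONNX CLIP image encoder: loads the model from the asset bundle
/// and turns a decoded image into a unit-length 512-dimensional embedding.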
class ClipImageEncoder {
  OrtSessionOptions? _sessionOptions;
  OrtSession? _session;
  final _logger = Logger("CLIPImageEncoder");

  ClipImageEncoder() {
    OrtEnv.instance.init();
  }

  void release() {
    _sessionOptions?.release();
    _sessionOptions = null;
    _session?.release();
    _session = null;
    OrtEnv.instance.release();
  }
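
  /// Loads the CLIP image model from the asset bundle into an [OrtSession].
  /// Invoked through [Computer.compute], so [args] is a plain map with an
  /// "imageModelPath" entry.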
  Future<void> loadModel(Map args) async {
    _sessionOptions = OrtSessionOptions()
      ..setInterOpNumThreads(1)
      ..setIntraOpNumThreads(1)
      ..setSessionGraphOptimizationLevel(GraphOptimizationLevel.ortEnableAll);
    try {
      // Default bundled model: assets/models/clip-image-vit-32-float32.onnx
      // TODO: Check whether the path exists locally before loading it from
      // the asset bundle.
      final rawAssetFile = await rootBundle.load(args["imageModelPath"]);
      final bytes = rawAssetFile.buffer.asUint8List();
      _session = OrtSession.fromBuffer(bytes, _sessionOptions!);
      _logger.info("image model loaded");
    } catch (e, s) {
      _logger.severe(e, s);
      rethrow;
    }
  }
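
  /// Computes the CLIP embedding for the image at args["imagePath"]:
  /// decode, resize the shorter side to 224, center-crop to 224x224,
  /// normalize with CLIP's mean/std, run the session, and L2-normalize
  /// the 512-dimensional output.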
  List<double> inferByImage(Map args) {
    final runOptions = OrtRunOptions();
    // TODO: Check that imagePath exists locally before decoding.
    final rgb = img.decodeImage(File(args["imagePath"]).readAsBytesSync())!;
    // Resize the shorter side to 224, then center-crop to 224x224.
    img.Image inputImage;
    if (rgb.height >= rgb.width) {
      inputImage = img.copyResize(
        rgb,
        width: 224,
        interpolation: img.Interpolation.linear,
      );
      inputImage = img.copyCrop(
        inputImage,
        x: 0,
        y: (inputImage.height - 224) ~/ 2,
        width: 224,
        height: 224,
      );
    } else {
      inputImage = img.copyResize(
        rgb,
        height: 224,
        interpolation: img.Interpolation.linear,
      );
      inputImage = img.copyCrop(
        inputImage,
        x: (inputImage.width - 224) ~/ 2,
        y: 0,
        width: 224,
        height: 224,
      );
    }
    // CLIP's standard per-channel normalization constants.
    final mean = [0.48145466, 0.4578275, 0.40821073];
    final std = [0.26862954, 0.26130258, 0.27577711];
    // Normalize the resized and center-cropped image, not the original.
    final processedImage = imageToByteListFloat32(inputImage, 224, mean, std);
    final inputOrt = OrtValueTensor.createTensorWithDataList(
      processedImage,
      [1, 3, 224, 224],
    );
    final inputs = {"input": inputOrt};
    final outputs = _session?.run(runOptions, inputs);
    final embedding = (outputs?[0]?.value as List<List<double>>)[0];
    // L2-normalize the 512-dimensional embedding.
    double imageNormalization = 0;
    for (int i = 0; i < 512; i++) {
      imageNormalization += embedding[i] * embedding[i];
    }
    final norm = sqrt(imageNormalization);
    for (int i = 0; i < 512; i++) {
      embedding[i] = embedding[i] / norm;
    }
    inputOrt.release();
    runOptions.release();
    outputs?.forEach((element) => element?.release());
    return embedding;
  }
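
  /// Packs [image] into a Float32List in planar CHW order (all R values,
  /// then G, then B) with per-channel normalization, matching the
  /// [1, 3, inputSize, inputSize] input tensor.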
  Float32List imageToByteListFloat32(
    img.Image image,
    int inputSize,
    List<double> mean,
    List<double> std,
  ) {
    assert(mean.length == 3);
    assert(std.length == 3);
    final convertedBytes = Float32List(1 * 3 * inputSize * inputSize);
    final channelSize = inputSize * inputSize;
    // Fill the buffer in planar CHW order: the R plane, then G, then B,
    // as required by the [1, 3, 224, 224] tensor shape.
    for (var y = 0; y < inputSize; y++) {
      for (var x = 0; x < inputSize; x++) {
        final pixel = image.getPixel(x, y);
        final offset = y * inputSize + x;
        convertedBytes[offset] = ((pixel.r / 255) - mean[0]) / std[0];
        convertedBytes[channelSize + offset] =
            ((pixel.g / 255) - mean[1]) / std[1];
        convertedBytes[2 * channelSize + offset] =
            ((pixel.b / 255) - mean[2]) / std[2];
      }
    }
    return convertedBytes;
  }
}
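
/// Runs the ONNX CLIP text encoder: tokenizes a query with [CLIPTokenizer]
/// and produces a unit-length 512-dimensional embedding.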
class ClipTextEncoder {
  static const vocabFilePath = "assets/clip/bpe_simple_vocab_16e6.txt";
  final _logger = Logger("CLIPTextEncoder");
  OrtSessionOptions? _sessionOptions;
  OrtSession? _session;

  ClipTextEncoder() {
    OrtEnv.instance.init();
    for (final element in OrtEnv.instance.availableProviders()) {
      _logger.info("onnx provider=$element");
    }
  }

  void release() {
    _sessionOptions?.release();
    _sessionOptions = null;
    _session?.release();
    _session = null;
    OrtEnv.instance.release();
  }

  Future<void> loadModel(Map args) async {
    _sessionOptions = OrtSessionOptions()
      ..setInterOpNumThreads(1)
      ..setIntraOpNumThreads(1)
      ..setSessionGraphOptimizationLevel(GraphOptimizationLevel.ortEnableAll);
    try {
      // Default bundled model: assets/models/clip-text-vit-32-float32-int32.onnx
      // TODO: Check whether the path exists locally before loading it from
      // the asset bundle.
      final rawAssetFile = await rootBundle.load(args["textModelPath"]);
      final bytes = rawAssetFile.buffer.asUint8List();
      _session = OrtSession.fromBuffer(bytes, _sessionOptions!);
      _logger.info("text model loaded");
    } catch (e, s) {
      _logger.severe("text model not loaded", e, s);
      // Rethrow so a failed load is not reported to callers as success.
      rethrow;
    }
  }
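
  /// Computes the CLIP embedding for args["text"]: tokenize to a padded
  /// sequence of 77 ids, run the session, and L2-normalize the
  /// 512-dimensional output.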
  Future<List<double>> infer(Map args) async {
    final text = args["text"];
    final runOptions = OrtRunOptions();
    final tokenizer = CLIPTokenizer(vocabFilePath);
    await tokenizer.init();
    final data = List.filled(1, Int32List.fromList(tokenizer.tokenize(text)));
    final inputOrt = OrtValueTensor.createTensorWithDataList(data, [1, 77]);
    final inputs = {"input": inputOrt};
    final outputs = _session?.run(runOptions, inputs);
    final embedding = (outputs?[0]?.value as List<List<double>>)[0];
    // L2-normalize the embedding.
    double textNormalization = 0;
    for (int i = 0; i < 512; i++) {
      textNormalization += embedding[i] * embedding[i];
    }
    final norm = sqrt(textNormalization);
    for (int i = 0; i < 512; i++) {
      embedding[i] = embedding[i] / norm;
    }
    inputOrt.release();
    runOptions.release();
    outputs?.forEach((element) => element?.release());
    // The session is reused across inferences; it is freed in release().
    return embedding;
  }
}
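
/// Byte-pair-encoding tokenizer for CLIP, modeled on OpenAI's
/// SimpleTokenizer and driven by the bundled bpe_simple_vocab_16e6
/// merge list.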
class CLIPTokenizer {
  final String bpePath;
  late Map<int, String> byteEncoder;
  late Map<String, int> byteDecoder;
  late Map<int, String> decoder;
  late Map<String, int> encoder;
  late Map<Tuple2<String, String>, int> bpeRanks;
  final Map<String, String> cache = <String, String>{
    '<|startoftext|>': '<|startoftext|>',
    '<|endoftext|>': '<|endoftext|>',
  };
  // Dart's RegExp does not support the Unicode property escapes \p{L} and
  // \p{N} without the unicode flag, so [a-zA-Z]+ and [0-9]+ are used as
  // ASCII approximations.
  RegExp pat = RegExp(
    r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[a-zA-Z]+|[0-9]+|[^\s\p{L}\p{N}]+""",
    caseSensitive: false,
    multiLine: false,
  );
  late int sot;
  late int eot;

  CLIPTokenizer(this.bpePath);

  // Async because loadFile returns a Future and Dart constructors cannot be
  // async.
  Future<void> init() async {
    final bpe = await loadFile();
    byteEncoder = bytesToUnicode();
    byteDecoder = byteEncoder.map((k, v) => MapEntry(v, k));
    var split = bpe.split('\n');
    // Skip the header line and keep exactly the 48894 merges used by CLIP.
    split = split.sublist(1, 49152 - 256 - 2 + 1);
    final merges = split
        .map((merge) => Tuple2(merge.split(' ')[0], merge.split(' ')[1]))
        .toList();
    // Vocabulary: 256 byte tokens, their end-of-word (</w>) variants, one
    // token per merge, and the two special tokens.
    final vocab = byteEncoder.values.toList();
    vocab.addAll(vocab.map((v) => '$v</w>').toList());
    for (var merge = 0; merge < merges.length; merge++) {
      vocab.add(merges[merge].item1 + merges[merge].item2);
    }
    vocab.addAll(['<|startoftext|>', '<|endoftext|>']);
    // asMap gives a Map<int, String>: token id -> token string.
    decoder = vocab.asMap();
    encoder = decoder.map((k, v) => MapEntry(v, k));
    bpeRanks = Map.fromIterables(
      merges,
      List.generate(merges.length, (i) => i),
    );
    sot = encoder['<|startoftext|>']!;
    eot = encoder['<|endoftext|>']!;
  }

  Future<String> loadFile() async {
    return await rootBundle.loadString(bpePath);
  }
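
  /// Splits [text] with the CLIP pattern, maps each token's UTF-8 bytes
  /// through the byte-to-unicode table, applies BPE, and looks up the
  /// resulting subwords in the encoder to produce token ids.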
  List<int> encode(String text) {
    final List<int> bpeTokens = [];
    text = whitespaceClean(basicClean(text)).toLowerCase();
    for (Match match in pat.allMatches(text)) {
      String token = match[0]!;
      token = utf8.encode(token).map((b) => byteEncoder[b]).join();
      bpe(token)
          .split(' ')
          .forEach((bpeToken) => bpeTokens.add(encoder[bpeToken]!));
    }
    return bpeTokens;
  }
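
  /// Applies byte-pair encoding to a single [token]: starting from
  /// individual characters (with </w> marking the end of the word),
  /// repeatedly merges the adjacent pair with the lowest merge rank until
  /// no ranked pair remains, and returns the space-joined subwords.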
  String bpe(String token) {
    if (cache.containsKey(token)) {
      return cache[token]!;
    }
    var word = token.split('');
    word[word.length - 1] = '${word.last}</w>';
    var pairs = getPairs(word);
    if (pairs.isEmpty) {
      return '$token</w>';
    }
    while (true) {
      // Find the pair with the lowest (best) merge rank.
      Tuple2<String, String> bigram = pairs.first;
      for (var pair in pairs) {
        final rank1 = bpeRanks[pair] ?? double.infinity;
        final rank2 = bpeRanks[bigram] ?? double.infinity;
        if (rank1 < rank2) {
          bigram = pair;
        }
      }
      if (!bpeRanks.containsKey(bigram)) {
        break;
      }
      // Merge every occurrence of the bigram in the word.
      final first = bigram.item1;
      final second = bigram.item2;
      final newWord = <String>[];
      var i = 0;
      while (i < word.length) {
        final j = word.sublist(i).indexOf(first);
        if (j == -1) {
          newWord.addAll(word.sublist(i));
          break;
        }
        newWord.addAll(word.sublist(i, i + j));
        i = i + j;
        if (word[i] == first && i < word.length - 1 && word[i + 1] == second) {
          newWord.add(first + second);
          i += 2;
        } else {
          newWord.add(word[i]);
          i += 1;
        }
      }
      word = newWord;
      if (word.length == 1) {
        break;
      } else {
        pairs = getPairs(word);
      }
    }
    final wordStr = word.join(' ');
    cache[token] = wordStr;
    return wordStr;
  }
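
  /// Encodes [text] and brackets it with the start/end tokens, truncating
  /// so the result never exceeds nText + 1 (= 77) ids; with [pad] the list
  /// is zero-padded to exactly 77. Illustrative shape (ids are made up):
  /// tokenize("a dog") -> [sot, 320, 1929, eot, 0, 0, ..., 0].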
  List<int> tokenize(String text, {int nText = 76, bool pad = true}) {
    var tokens = encode(text);
    tokens = [sot] + tokens.sublist(0, min(nText - 1, tokens.length)) + [eot];
    if (pad) {
      // Zero-pad to a fixed length of nText + 1 (= 77 for CLIP).
      return tokens + List.filled(nText + 1 - tokens.length, 0);
    } else {
      return tokens;
    }
  }

  /// Zero-pads [x] to [padLength]. Not used within this file.
  List<int> pad(List<int> x, int padLength) {
    return x + List.filled(padLength - x.length, 0);
  }
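
  /// GPT-2 style reversible byte-to-unicode map: printable Latin-1 bytes map
  /// to themselves and the remaining bytes are remapped to code points at
  /// 256 and above, so every byte becomes a distinct printable character.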
  Map<int, String> bytesToUnicode() {
    final List<int> bs = [];
    for (int i = '!'.codeUnitAt(0); i <= '~'.codeUnitAt(0); i++) {
      bs.add(i);
    }
    for (int i = '¡'.codeUnitAt(0); i <= '¬'.codeUnitAt(0); i++) {
      bs.add(i);
    }
    for (int i = '®'.codeUnitAt(0); i <= 'ÿ'.codeUnitAt(0); i++) {
      bs.add(i);
    }
    final List<int> cs = List.from(bs);
    int n = 0;
    for (int b = 0; b < 256; b++) {
      if (!bs.contains(b)) {
        bs.add(b);
        cs.add(256 + n);
        n += 1;
      }
    }
    final List<String> ds = cs.map((c) => String.fromCharCode(c)).toList();
    return Map.fromIterables(bs, ds);
  }
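
  /// Returns the set of adjacent symbol pairs in [word], e.g.
  /// ["d", "o", "g</w>"] -> {(d, o), (o, g</w>)}.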
  Set<Tuple2<String, String>> getPairs(List<String> word) {
    final Set<Tuple2<String, String>> pairs = {};
    String prevChar = word[0];
    for (var i = 1; i < word.length; i++) {
      pairs.add(Tuple2(prevChar, word[i]));
      prevChar = word[i];
    }
    return pairs;
  }
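
  /// Unescapes HTML entities (twice, to handle double-escaped input) and
  /// trims surrounding whitespace.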
  String basicClean(String text) {
    final unescape = HtmlUnescape();
    text = unescape.convert(unescape.convert(text));
    return text.trim();
  }
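
  /// Collapses runs of whitespace into single spaces and trims.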
  String whitespaceClean(String text) {
    text = text.replaceAll(RegExp(r'\s+'), ' ');
    return text.trim();
  }
}