diff --git a/lib/services/semantic_search/frameworks/onnx/onnx.dart b/lib/services/semantic_search/frameworks/onnx/onnx.dart index 9d8efc0d1..cf2464eba 100644 --- a/lib/services/semantic_search/frameworks/onnx/onnx.dart +++ b/lib/services/semantic_search/frameworks/onnx/onnx.dart @@ -13,6 +13,7 @@ class ONNX extends MLFramework { final _clipImage = OnnxImageEncoder(); final _clipText = OnnxTextEncoder(); int _textEncoderAddress = 0; + int _imageEncoderAddress = 0; @override String getFrameworkName() { @@ -32,7 +33,7 @@ class ONNX extends MLFramework { @override Future loadImageModel(String path) async { final startTime = DateTime.now(); - await _computer.compute( + _imageEncoderAddress = await _computer.compute( _clipImage.loadModel, param: { "imageModelPath": path, @@ -68,6 +69,7 @@ class ONNX extends MLFramework { _clipImage.inferByImage, param: { "imagePath": imagePath, + "address": _imageEncoderAddress, }, taskName: "createImageEmbedding", ) as List; diff --git a/lib/services/semantic_search/frameworks/onnx/onnx_image_encoder.dart b/lib/services/semantic_search/frameworks/onnx/onnx_image_encoder.dart index dbc6ba7d8..834121d2f 100644 --- a/lib/services/semantic_search/frameworks/onnx/onnx_image_encoder.dart +++ b/lib/services/semantic_search/frameworks/onnx/onnx_image_encoder.dart @@ -7,8 +7,6 @@ import "package:logging/logging.dart"; import "package:onnxruntime/onnxruntime.dart"; class OnnxImageEncoder { - OrtSessionOptions? _sessionOptions; - OrtSession? _session; final _logger = Logger("OnnxImageEncoder"); OnnxImageEncoder() { @@ -16,26 +14,23 @@ class OnnxImageEncoder { } release() { - _sessionOptions?.release(); - _sessionOptions = null; - _session?.release(); - _session = null; OrtEnv.instance.release(); } - Future loadModel(Map args) async { - _sessionOptions = OrtSessionOptions() + Future loadModel(Map args) async { + final sessionOptions = OrtSessionOptions() ..setInterOpNumThreads(1) ..setIntraOpNumThreads(1) ..setSessionGraphOptimizationLevel(GraphOptimizationLevel.ortEnableAll); try { final bytes = File(args["imageModelPath"]).readAsBytesSync(); - _session = OrtSession.fromBuffer(bytes, _sessionOptions!); + final session = OrtSession.fromBuffer(bytes, sessionOptions); _logger.info('image model loaded'); + return session.address; } catch (e, s) { _logger.severe(e, s); - rethrow; } + return -1; } List inferByImage(Map args) { @@ -81,8 +76,9 @@ class OnnxImageEncoder { [1, 3, 224, 224], ); final inputs = {'input': inputOrt}; - final outputs = _session?.run(runOptions, inputs); - final embedding = (outputs?[0]?.value as List>)[0]; + final session = OrtSession.fromAddress(args["address"]); + final outputs = session.run(runOptions, inputs); + final embedding = (outputs[0]?.value as List>)[0]; double imageNormalization = 0; for (int i = 0; i < 512; i++) { imageNormalization += embedding[i] * embedding[i];