Retrieve OrtSession from cached address

This commit is contained in:
vishnukvmd 2023-12-13 14:55:51 +05:30
parent 1396444e3b
commit 3ea8fca2f2
2 changed files with 11 additions and 13 deletions

View file

@@ -13,6 +13,7 @@ class ONNX extends MLFramework {
final _clipImage = OnnxImageEncoder(); final _clipImage = OnnxImageEncoder();
final _clipText = OnnxTextEncoder(); final _clipText = OnnxTextEncoder();
int _textEncoderAddress = 0; int _textEncoderAddress = 0;
int _imageEncoderAddress = 0;
@override @override
String getFrameworkName() { String getFrameworkName() {
@@ -32,7 +33,7 @@ class ONNX extends MLFramework {
@override @override
Future<void> loadImageModel(String path) async { Future<void> loadImageModel(String path) async {
final startTime = DateTime.now(); final startTime = DateTime.now();
await _computer.compute( _imageEncoderAddress = await _computer.compute(
_clipImage.loadModel, _clipImage.loadModel,
param: { param: {
"imageModelPath": path, "imageModelPath": path,
@@ -68,6 +69,7 @@ class ONNX extends MLFramework {
_clipImage.inferByImage, _clipImage.inferByImage,
param: { param: {
"imagePath": imagePath, "imagePath": imagePath,
"address": _imageEncoderAddress,
}, },
taskName: "createImageEmbedding", taskName: "createImageEmbedding",
) as List<double>; ) as List<double>;

View file

@@ -7,8 +7,6 @@ import "package:logging/logging.dart";
import "package:onnxruntime/onnxruntime.dart"; import "package:onnxruntime/onnxruntime.dart";
class OnnxImageEncoder { class OnnxImageEncoder {
OrtSessionOptions? _sessionOptions;
OrtSession? _session;
final _logger = Logger("OnnxImageEncoder"); final _logger = Logger("OnnxImageEncoder");
OnnxImageEncoder() { OnnxImageEncoder() {
@@ -16,26 +14,23 @@ class OnnxImageEncoder {
} }
release() { release() {
_sessionOptions?.release();
_sessionOptions = null;
_session?.release();
_session = null;
OrtEnv.instance.release(); OrtEnv.instance.release();
} }
Future<void> loadModel(Map args) async { Future<int> loadModel(Map args) async {
_sessionOptions = OrtSessionOptions() final sessionOptions = OrtSessionOptions()
..setInterOpNumThreads(1) ..setInterOpNumThreads(1)
..setIntraOpNumThreads(1) ..setIntraOpNumThreads(1)
..setSessionGraphOptimizationLevel(GraphOptimizationLevel.ortEnableAll); ..setSessionGraphOptimizationLevel(GraphOptimizationLevel.ortEnableAll);
try { try {
final bytes = File(args["imageModelPath"]).readAsBytesSync(); final bytes = File(args["imageModelPath"]).readAsBytesSync();
_session = OrtSession.fromBuffer(bytes, _sessionOptions!); final session = OrtSession.fromBuffer(bytes, sessionOptions);
_logger.info('image model loaded'); _logger.info('image model loaded');
return session.address;
} catch (e, s) { } catch (e, s) {
_logger.severe(e, s); _logger.severe(e, s);
rethrow;
} }
return -1;
} }
List<double> inferByImage(Map args) { List<double> inferByImage(Map args) {
@@ -81,8 +76,9 @@ class OnnxImageEncoder {
[1, 3, 224, 224], [1, 3, 224, 224],
); );
final inputs = {'input': inputOrt}; final inputs = {'input': inputOrt};
final outputs = _session?.run(runOptions, inputs); final session = OrtSession.fromAddress(args["address"]);
final embedding = (outputs?[0]?.value as List<List<double>>)[0]; final outputs = session.run(runOptions, inputs);
final embedding = (outputs[0]?.value as List<List<double>>)[0];
double imageNormalization = 0; double imageNormalization = 0;
for (int i = 0; i < 512; i++) { for (int i = 0; i < 512; i++) {
imageNormalization += embedding[i] * embedding[i]; imageNormalization += embedding[i] * embedding[i];