Retrieve OrtSession from cached address

This commit is contained in:
vishnukvmd 2023-12-13 14:55:51 +05:30
parent 1396444e3b
commit 3ea8fca2f2
2 changed files with 11 additions and 13 deletions

View file

@ -13,6 +13,7 @@ class ONNX extends MLFramework {
final _clipImage = OnnxImageEncoder();
final _clipText = OnnxTextEncoder();
int _textEncoderAddress = 0;
int _imageEncoderAddress = 0;
@override
String getFrameworkName() {
@ -32,7 +33,7 @@ class ONNX extends MLFramework {
@override
Future<void> loadImageModel(String path) async {
final startTime = DateTime.now();
await _computer.compute(
_imageEncoderAddress = await _computer.compute(
_clipImage.loadModel,
param: {
"imageModelPath": path,
@ -68,6 +69,7 @@ class ONNX extends MLFramework {
_clipImage.inferByImage,
param: {
"imagePath": imagePath,
"address": _imageEncoderAddress,
},
taskName: "createImageEmbedding",
) as List<double>;

View file

@ -7,8 +7,6 @@ import "package:logging/logging.dart";
import "package:onnxruntime/onnxruntime.dart";
class OnnxImageEncoder {
OrtSessionOptions? _sessionOptions;
OrtSession? _session;
final _logger = Logger("OnnxImageEncoder");
OnnxImageEncoder() {
@ -16,26 +14,23 @@ class OnnxImageEncoder {
}
release() {
_sessionOptions?.release();
_sessionOptions = null;
_session?.release();
_session = null;
OrtEnv.instance.release();
}
Future<void> loadModel(Map args) async {
_sessionOptions = OrtSessionOptions()
Future<int> loadModel(Map args) async {
final sessionOptions = OrtSessionOptions()
..setInterOpNumThreads(1)
..setIntraOpNumThreads(1)
..setSessionGraphOptimizationLevel(GraphOptimizationLevel.ortEnableAll);
try {
final bytes = File(args["imageModelPath"]).readAsBytesSync();
_session = OrtSession.fromBuffer(bytes, _sessionOptions!);
final session = OrtSession.fromBuffer(bytes, sessionOptions);
_logger.info('image model loaded');
return session.address;
} catch (e, s) {
_logger.severe(e, s);
rethrow;
}
return -1;
}
List<double> inferByImage(Map args) {
@ -81,8 +76,9 @@ class OnnxImageEncoder {
[1, 3, 224, 224],
);
final inputs = {'input': inputOrt};
final outputs = _session?.run(runOptions, inputs);
final embedding = (outputs?[0]?.value as List<List<double>>)[0];
final session = OrtSession.fromAddress(args["address"]);
final outputs = session.run(runOptions, inputs);
final embedding = (outputs[0]?.value as List<List<double>>)[0];
double imageNormalization = 0;
for (int i = 0; i < 512; i++) {
imageNormalization += embedding[i] * embedding[i];