Retrieve OrtSession from cached address
This commit is contained in:
parent
1396444e3b
commit
3ea8fca2f2
|
@ -13,6 +13,7 @@ class ONNX extends MLFramework {
|
|||
final _clipImage = OnnxImageEncoder();
|
||||
final _clipText = OnnxTextEncoder();
|
||||
int _textEncoderAddress = 0;
|
||||
int _imageEncoderAddress = 0;
|
||||
|
||||
@override
|
||||
String getFrameworkName() {
|
||||
|
@ -32,7 +33,7 @@ class ONNX extends MLFramework {
|
|||
@override
|
||||
Future<void> loadImageModel(String path) async {
|
||||
final startTime = DateTime.now();
|
||||
await _computer.compute(
|
||||
_imageEncoderAddress = await _computer.compute(
|
||||
_clipImage.loadModel,
|
||||
param: {
|
||||
"imageModelPath": path,
|
||||
|
@ -68,6 +69,7 @@ class ONNX extends MLFramework {
|
|||
_clipImage.inferByImage,
|
||||
param: {
|
||||
"imagePath": imagePath,
|
||||
"address": _imageEncoderAddress,
|
||||
},
|
||||
taskName: "createImageEmbedding",
|
||||
) as List<double>;
|
||||
|
|
|
@ -7,8 +7,6 @@ import "package:logging/logging.dart";
|
|||
import "package:onnxruntime/onnxruntime.dart";
|
||||
|
||||
class OnnxImageEncoder {
|
||||
OrtSessionOptions? _sessionOptions;
|
||||
OrtSession? _session;
|
||||
final _logger = Logger("OnnxImageEncoder");
|
||||
|
||||
OnnxImageEncoder() {
|
||||
|
@ -16,26 +14,23 @@ class OnnxImageEncoder {
|
|||
}
|
||||
|
||||
release() {
|
||||
_sessionOptions?.release();
|
||||
_sessionOptions = null;
|
||||
_session?.release();
|
||||
_session = null;
|
||||
OrtEnv.instance.release();
|
||||
}
|
||||
|
||||
Future<void> loadModel(Map args) async {
|
||||
_sessionOptions = OrtSessionOptions()
|
||||
Future<int> loadModel(Map args) async {
|
||||
final sessionOptions = OrtSessionOptions()
|
||||
..setInterOpNumThreads(1)
|
||||
..setIntraOpNumThreads(1)
|
||||
..setSessionGraphOptimizationLevel(GraphOptimizationLevel.ortEnableAll);
|
||||
try {
|
||||
final bytes = File(args["imageModelPath"]).readAsBytesSync();
|
||||
_session = OrtSession.fromBuffer(bytes, _sessionOptions!);
|
||||
final session = OrtSession.fromBuffer(bytes, sessionOptions);
|
||||
_logger.info('image model loaded');
|
||||
return session.address;
|
||||
} catch (e, s) {
|
||||
_logger.severe(e, s);
|
||||
rethrow;
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
List<double> inferByImage(Map args) {
|
||||
|
@ -81,8 +76,9 @@ class OnnxImageEncoder {
|
|||
[1, 3, 224, 224],
|
||||
);
|
||||
final inputs = {'input': inputOrt};
|
||||
final outputs = _session?.run(runOptions, inputs);
|
||||
final embedding = (outputs?[0]?.value as List<List<double>>)[0];
|
||||
final session = OrtSession.fromAddress(args["address"]);
|
||||
final outputs = session.run(runOptions, inputs);
|
||||
final embedding = (outputs[0]?.value as List<List<double>>)[0];
|
||||
double imageNormalization = 0;
|
||||
for (int i = 0; i < 512; i++) {
|
||||
imageNormalization += embedding[i] * embedding[i];
|
||||
|
|
Loading…
Reference in a new issue