Retrieve OrtSession from cached address
This commit is contained in:
parent
1396444e3b
commit
3ea8fca2f2
|
@ -13,6 +13,7 @@ class ONNX extends MLFramework {
|
||||||
final _clipImage = OnnxImageEncoder();
|
final _clipImage = OnnxImageEncoder();
|
||||||
final _clipText = OnnxTextEncoder();
|
final _clipText = OnnxTextEncoder();
|
||||||
int _textEncoderAddress = 0;
|
int _textEncoderAddress = 0;
|
||||||
|
int _imageEncoderAddress = 0;
|
||||||
|
|
||||||
@override
|
@override
|
||||||
String getFrameworkName() {
|
String getFrameworkName() {
|
||||||
|
@ -32,7 +33,7 @@ class ONNX extends MLFramework {
|
||||||
@override
|
@override
|
||||||
Future<void> loadImageModel(String path) async {
|
Future<void> loadImageModel(String path) async {
|
||||||
final startTime = DateTime.now();
|
final startTime = DateTime.now();
|
||||||
await _computer.compute(
|
_imageEncoderAddress = await _computer.compute(
|
||||||
_clipImage.loadModel,
|
_clipImage.loadModel,
|
||||||
param: {
|
param: {
|
||||||
"imageModelPath": path,
|
"imageModelPath": path,
|
||||||
|
@ -68,6 +69,7 @@ class ONNX extends MLFramework {
|
||||||
_clipImage.inferByImage,
|
_clipImage.inferByImage,
|
||||||
param: {
|
param: {
|
||||||
"imagePath": imagePath,
|
"imagePath": imagePath,
|
||||||
|
"address": _imageEncoderAddress,
|
||||||
},
|
},
|
||||||
taskName: "createImageEmbedding",
|
taskName: "createImageEmbedding",
|
||||||
) as List<double>;
|
) as List<double>;
|
||||||
|
|
|
@ -7,8 +7,6 @@ import "package:logging/logging.dart";
|
||||||
import "package:onnxruntime/onnxruntime.dart";
|
import "package:onnxruntime/onnxruntime.dart";
|
||||||
|
|
||||||
class OnnxImageEncoder {
|
class OnnxImageEncoder {
|
||||||
OrtSessionOptions? _sessionOptions;
|
|
||||||
OrtSession? _session;
|
|
||||||
final _logger = Logger("OnnxImageEncoder");
|
final _logger = Logger("OnnxImageEncoder");
|
||||||
|
|
||||||
OnnxImageEncoder() {
|
OnnxImageEncoder() {
|
||||||
|
@ -16,26 +14,23 @@ class OnnxImageEncoder {
|
||||||
}
|
}
|
||||||
|
|
||||||
release() {
|
release() {
|
||||||
_sessionOptions?.release();
|
|
||||||
_sessionOptions = null;
|
|
||||||
_session?.release();
|
|
||||||
_session = null;
|
|
||||||
OrtEnv.instance.release();
|
OrtEnv.instance.release();
|
||||||
}
|
}
|
||||||
|
|
||||||
Future<void> loadModel(Map args) async {
|
Future<int> loadModel(Map args) async {
|
||||||
_sessionOptions = OrtSessionOptions()
|
final sessionOptions = OrtSessionOptions()
|
||||||
..setInterOpNumThreads(1)
|
..setInterOpNumThreads(1)
|
||||||
..setIntraOpNumThreads(1)
|
..setIntraOpNumThreads(1)
|
||||||
..setSessionGraphOptimizationLevel(GraphOptimizationLevel.ortEnableAll);
|
..setSessionGraphOptimizationLevel(GraphOptimizationLevel.ortEnableAll);
|
||||||
try {
|
try {
|
||||||
final bytes = File(args["imageModelPath"]).readAsBytesSync();
|
final bytes = File(args["imageModelPath"]).readAsBytesSync();
|
||||||
_session = OrtSession.fromBuffer(bytes, _sessionOptions!);
|
final session = OrtSession.fromBuffer(bytes, sessionOptions);
|
||||||
_logger.info('image model loaded');
|
_logger.info('image model loaded');
|
||||||
|
return session.address;
|
||||||
} catch (e, s) {
|
} catch (e, s) {
|
||||||
_logger.severe(e, s);
|
_logger.severe(e, s);
|
||||||
rethrow;
|
|
||||||
}
|
}
|
||||||
|
return -1;
|
||||||
}
|
}
|
||||||
|
|
||||||
List<double> inferByImage(Map args) {
|
List<double> inferByImage(Map args) {
|
||||||
|
@ -81,8 +76,9 @@ class OnnxImageEncoder {
|
||||||
[1, 3, 224, 224],
|
[1, 3, 224, 224],
|
||||||
);
|
);
|
||||||
final inputs = {'input': inputOrt};
|
final inputs = {'input': inputOrt};
|
||||||
final outputs = _session?.run(runOptions, inputs);
|
final session = OrtSession.fromAddress(args["address"]);
|
||||||
final embedding = (outputs?[0]?.value as List<List<double>>)[0];
|
final outputs = session.run(runOptions, inputs);
|
||||||
|
final embedding = (outputs[0]?.value as List<List<double>>)[0];
|
||||||
double imageNormalization = 0;
|
double imageNormalization = 0;
|
||||||
for (int i = 0; i < 512; i++) {
|
for (int i = 0; i < 512; i++) {
|
||||||
imageNormalization += embedding[i] * embedding[i];
|
imageNormalization += embedding[i] * embedding[i];
|
||||||
|
|
Loading…
Reference in a new issue