Add MobileNetv1
This commit is contained in:
parent
e8d3aa4e91
commit
a5646511e0
1001
assets/models/mobilenet/labels_mobilenet_quant_v1_224.txt
Normal file
1001
assets/models/mobilenet/labels_mobilenet_quant_v1_224.txt
Normal file
File diff suppressed because it is too large
Load diff
BIN
assets/models/mobilenet/mobilenet_v1_1.0_224_quant.tflite
Normal file
BIN
assets/models/mobilenet/mobilenet_v1_1.0_224_quant.tflite
Normal file
Binary file not shown.
|
@ -4,18 +4,18 @@ import "dart:typed_data";
|
|||
import "package:logging/logging.dart";
|
||||
import "package:photos/services/object_detection/models/predictions.dart";
|
||||
import 'package:photos/services/object_detection/models/recognition.dart';
|
||||
import "package:photos/services/object_detection/tflite/classifier.dart";
|
||||
import 'package:photos/services/object_detection/tflite/cocossd_classifier.dart';
|
||||
import "package:photos/services/object_detection/tflite/mobilenet_classifier.dart";
|
||||
import "package:photos/services/object_detection/utils/isolate_utils.dart";
|
||||
|
||||
class ObjectDetectionService {
|
||||
static const scoreThreshold = 0.6;
|
||||
static const scoreThreshold = 0.5;
|
||||
|
||||
final _logger = Logger("ObjectDetectionService");
|
||||
|
||||
/// Instance of [ObjectClassifier]
|
||||
late ObjectClassifier _classifier;
|
||||
late CocoSSDClassifier _objectClassifier;
|
||||
late MobileNetClassifier _mobileNetClassifier;
|
||||
|
||||
/// Instance of [IsolateUtils]
|
||||
late IsolateUtils _isolateUtils;
|
||||
|
||||
ObjectDetectionService._privateConstructor();
|
||||
|
@ -23,7 +23,8 @@ class ObjectDetectionService {
|
|||
Future<void> init() async {
|
||||
_isolateUtils = IsolateUtils();
|
||||
await _isolateUtils.start();
|
||||
_classifier = ObjectClassifier();
|
||||
_objectClassifier = CocoSSDClassifier();
|
||||
_mobileNetClassifier = MobileNetClassifier();
|
||||
}
|
||||
|
||||
static ObjectDetectionService instance =
|
||||
|
@ -31,10 +32,24 @@ class ObjectDetectionService {
|
|||
|
||||
Future<List<String>> predict(Uint8List bytes) async {
|
||||
try {
|
||||
final results = <String>{};
|
||||
final objectResults = await _getObjects(bytes);
|
||||
results.addAll(objectResults);
|
||||
final mobileNetResults = await _getMobileNetResults(bytes);
|
||||
results.addAll(mobileNetResults);
|
||||
return results.toList();
|
||||
} catch (e, s) {
|
||||
_logger.severe(e, s);
|
||||
rethrow;
|
||||
}
|
||||
}
|
||||
|
||||
Future<List<String>> _getObjects(Uint8List bytes) async {
|
||||
final isolateData = IsolateData(
|
||||
bytes,
|
||||
_classifier.interpreter.address,
|
||||
_classifier.labels,
|
||||
_objectClassifier.interpreter.address,
|
||||
_objectClassifier.labels,
|
||||
ClassifierType.cocossd,
|
||||
);
|
||||
final predictions = await _inference(isolateData);
|
||||
final Set<String> results = {};
|
||||
|
@ -44,10 +59,23 @@ class ObjectDetectionService {
|
|||
}
|
||||
}
|
||||
return results.toList();
|
||||
} catch (e, s) {
|
||||
_logger.severe(e, s);
|
||||
rethrow;
|
||||
}
|
||||
|
||||
Future<List<String>> _getMobileNetResults(Uint8List bytes) async {
|
||||
final isolateData = IsolateData(
|
||||
bytes,
|
||||
_mobileNetClassifier.interpreter.address,
|
||||
_mobileNetClassifier.labels,
|
||||
ClassifierType.mobilenet,
|
||||
);
|
||||
final predictions = await _inference(isolateData);
|
||||
final Set<String> results = {};
|
||||
for (final Recognition result in predictions.recognitions) {
|
||||
if (result.score > scoreThreshold) {
|
||||
results.add(result.label);
|
||||
}
|
||||
}
|
||||
return results.toList();
|
||||
}
|
||||
|
||||
/// Runs inference in another isolate
|
||||
|
|
|
@ -1,179 +1,6 @@
|
|||
import 'dart:math';
|
||||
|
||||
import 'package:image/image.dart' as imageLib;
|
||||
import "package:logging/logging.dart";
|
||||
import 'package:photos/services/object_detection/models/predictions.dart';
|
||||
import 'package:photos/services/object_detection/models/recognition.dart';
|
||||
import "package:photos/services/object_detection/models/stats.dart";
|
||||
import "package:tflite_flutter/tflite_flutter.dart";
|
||||
import "package:tflite_flutter_helper/tflite_flutter_helper.dart";
|
||||
import "package:photos/services/object_detection/models/predictions.dart";
|
||||
|
||||
/// Classifier
|
||||
class ObjectClassifier {
|
||||
final _logger = Logger("Classifier");
|
||||
|
||||
/// Instance of Interpreter
|
||||
late Interpreter _interpreter;
|
||||
|
||||
/// Labels file loaded as list
|
||||
late List<String> _labels;
|
||||
|
||||
/// Input size of image (height = width = 300)
|
||||
static const int inputSize = 300;
|
||||
|
||||
/// Result score threshold
|
||||
static const double threshold = 0.5;
|
||||
|
||||
static const String modelFileName = "detect.tflite";
|
||||
static const String labelFileName = "labelmap.txt";
|
||||
|
||||
/// [ImageProcessor] used to pre-process the image
|
||||
ImageProcessor? imageProcessor;
|
||||
|
||||
/// Padding the image to transform into square
|
||||
late int padSize;
|
||||
|
||||
/// Shapes of output tensors
|
||||
late List<List<int>> _outputShapes;
|
||||
|
||||
/// Types of output tensors
|
||||
late List<TfLiteType> _outputTypes;
|
||||
|
||||
/// Number of results to show
|
||||
static const int numResults = 10;
|
||||
|
||||
ObjectClassifier({
|
||||
Interpreter? interpreter,
|
||||
List<String>? labels,
|
||||
}) {
|
||||
loadModel(interpreter);
|
||||
loadLabels(labels);
|
||||
}
|
||||
|
||||
/// Loads interpreter from asset
|
||||
void loadModel(Interpreter? interpreter) async {
|
||||
try {
|
||||
_interpreter = interpreter ??
|
||||
await Interpreter.fromAsset(
|
||||
"models/" + modelFileName,
|
||||
options: InterpreterOptions()..threads = 4,
|
||||
);
|
||||
final outputTensors = _interpreter.getOutputTensors();
|
||||
_outputShapes = [];
|
||||
_outputTypes = [];
|
||||
outputTensors.forEach((tensor) {
|
||||
_outputShapes.add(tensor.shape);
|
||||
_outputTypes.add(tensor.type);
|
||||
});
|
||||
_logger.info("Interpreter initialized");
|
||||
} catch (e, s) {
|
||||
_logger.severe("Error while creating interpreter", e, s);
|
||||
}
|
||||
}
|
||||
|
||||
/// Loads labels from assets
|
||||
void loadLabels(List<String>? labels) async {
|
||||
try {
|
||||
_labels =
|
||||
labels ?? await FileUtil.loadLabels("assets/models/" + labelFileName);
|
||||
_logger.info("Labels initialized");
|
||||
} catch (e, s) {
|
||||
_logger.severe("Error while loading labels", e, s);
|
||||
}
|
||||
}
|
||||
|
||||
/// Pre-process the image
|
||||
TensorImage _getProcessedImage(TensorImage inputImage) {
|
||||
padSize = max(inputImage.height, inputImage.width);
|
||||
imageProcessor ??= ImageProcessorBuilder()
|
||||
.add(ResizeWithCropOrPadOp(padSize, padSize))
|
||||
.add(ResizeOp(inputSize, inputSize, ResizeMethod.BILINEAR))
|
||||
.build();
|
||||
inputImage = imageProcessor!.process(inputImage);
|
||||
return inputImage;
|
||||
}
|
||||
|
||||
/// Runs object detection on the input image
|
||||
Predictions? predict(imageLib.Image image) {
|
||||
final predictStartTime = DateTime.now().millisecondsSinceEpoch;
|
||||
|
||||
final preProcessStart = DateTime.now().millisecondsSinceEpoch;
|
||||
|
||||
// Create TensorImage from image
|
||||
TensorImage inputImage = TensorImage.fromImage(image);
|
||||
|
||||
// Pre-process TensorImage
|
||||
inputImage = _getProcessedImage(inputImage);
|
||||
|
||||
final preProcessElapsedTime =
|
||||
DateTime.now().millisecondsSinceEpoch - preProcessStart;
|
||||
|
||||
// TensorBuffers for output tensors
|
||||
final outputLocations = TensorBufferFloat(_outputShapes[0]);
|
||||
final outputClasses = TensorBufferFloat(_outputShapes[1]);
|
||||
final outputScores = TensorBufferFloat(_outputShapes[2]);
|
||||
final numLocations = TensorBufferFloat(_outputShapes[3]);
|
||||
|
||||
// Inputs object for runForMultipleInputs
|
||||
// Use [TensorImage.buffer] or [TensorBuffer.buffer] to pass by reference
|
||||
final inputs = [inputImage.buffer];
|
||||
|
||||
// Outputs map
|
||||
final outputs = {
|
||||
0: outputLocations.buffer,
|
||||
1: outputClasses.buffer,
|
||||
2: outputScores.buffer,
|
||||
3: numLocations.buffer,
|
||||
};
|
||||
|
||||
final inferenceTimeStart = DateTime.now().millisecondsSinceEpoch;
|
||||
|
||||
// run inference
|
||||
_interpreter.runForMultipleInputs(inputs, outputs);
|
||||
|
||||
final inferenceTimeElapsed =
|
||||
DateTime.now().millisecondsSinceEpoch - inferenceTimeStart;
|
||||
|
||||
// Maximum number of results to show
|
||||
final resultsCount = min(numResults, numLocations.getIntValue(0));
|
||||
|
||||
// Using labelOffset = 1 as ??? at index 0
|
||||
const labelOffset = 1;
|
||||
|
||||
final recognitions = <Recognition>[];
|
||||
|
||||
for (int i = 0; i < resultsCount; i++) {
|
||||
// Prediction score
|
||||
final score = outputScores.getDoubleValue(i);
|
||||
|
||||
// Label string
|
||||
final labelIndex = outputClasses.getIntValue(i) + labelOffset;
|
||||
final label = _labels.elementAt(labelIndex);
|
||||
|
||||
if (score > threshold) {
|
||||
recognitions.add(
|
||||
Recognition(i, label, score),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
final predictElapsedTime =
|
||||
DateTime.now().millisecondsSinceEpoch - predictStartTime;
|
||||
_logger.info(recognitions);
|
||||
return Predictions(
|
||||
recognitions,
|
||||
Stats(
|
||||
predictElapsedTime,
|
||||
predictElapsedTime,
|
||||
inferenceTimeElapsed,
|
||||
preProcessElapsedTime,
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
/// Gets the interpreter instance
|
||||
Interpreter get interpreter => _interpreter;
|
||||
|
||||
/// Gets the loaded labels
|
||||
List<String> get labels => _labels;
|
||||
abstract class Classifier {
|
||||
Predictions? predict(imageLib.Image image);
|
||||
}
|
||||
|
|
180
lib/services/object_detection/tflite/cocossd_classifier.dart
Normal file
180
lib/services/object_detection/tflite/cocossd_classifier.dart
Normal file
|
@ -0,0 +1,180 @@
|
|||
import 'dart:math';
|
||||
|
||||
import 'package:image/image.dart' as imageLib;
|
||||
import "package:logging/logging.dart";
|
||||
import 'package:photos/services/object_detection/models/predictions.dart';
|
||||
import 'package:photos/services/object_detection/models/recognition.dart';
|
||||
import "package:photos/services/object_detection/models/stats.dart";
|
||||
import "package:photos/services/object_detection/tflite/classifier.dart";
|
||||
import "package:tflite_flutter/tflite_flutter.dart";
|
||||
import "package:tflite_flutter_helper/tflite_flutter_helper.dart";
|
||||
|
||||
/// Classifier
|
||||
class CocoSSDClassifier extends Classifier {
|
||||
final _logger = Logger("Classifier");
|
||||
|
||||
/// Instance of Interpreter
|
||||
late Interpreter _interpreter;
|
||||
|
||||
/// Labels file loaded as list
|
||||
late List<String> _labels;
|
||||
|
||||
/// Input size of image (height = width = 300)
|
||||
static const int inputSize = 300;
|
||||
|
||||
/// Result score threshold
|
||||
static const double threshold = 0.5;
|
||||
|
||||
static const String modelFileName = "model.tflite";
|
||||
static const String labelFileName = "labels.txt";
|
||||
|
||||
/// [ImageProcessor] used to pre-process the image
|
||||
ImageProcessor? imageProcessor;
|
||||
|
||||
/// Padding the image to transform into square
|
||||
late int padSize;
|
||||
|
||||
/// Shapes of output tensors
|
||||
late List<List<int>> _outputShapes;
|
||||
|
||||
/// Types of output tensors
|
||||
late List<TfLiteType> _outputTypes;
|
||||
|
||||
/// Number of results to show
|
||||
static const int numResults = 10;
|
||||
|
||||
CocoSSDClassifier({
|
||||
Interpreter? interpreter,
|
||||
List<String>? labels,
|
||||
}) {
|
||||
loadModel(interpreter);
|
||||
loadLabels(labels);
|
||||
}
|
||||
|
||||
/// Loads interpreter from asset
|
||||
void loadModel(Interpreter? interpreter) async {
|
||||
try {
|
||||
_interpreter = interpreter ??
|
||||
await Interpreter.fromAsset(
|
||||
"models/cocossd/" + modelFileName,
|
||||
options: InterpreterOptions()..threads = 4,
|
||||
);
|
||||
final outputTensors = _interpreter.getOutputTensors();
|
||||
_outputShapes = [];
|
||||
_outputTypes = [];
|
||||
outputTensors.forEach((tensor) {
|
||||
_outputShapes.add(tensor.shape);
|
||||
_outputTypes.add(tensor.type);
|
||||
});
|
||||
_logger.info("Interpreter initialized");
|
||||
} catch (e, s) {
|
||||
_logger.severe("Error while creating interpreter", e, s);
|
||||
}
|
||||
}
|
||||
|
||||
/// Loads labels from assets
|
||||
void loadLabels(List<String>? labels) async {
|
||||
try {
|
||||
_labels = labels ??
|
||||
await FileUtil.loadLabels("assets/models/cocossd/" + labelFileName);
|
||||
_logger.info("Labels initialized");
|
||||
} catch (e, s) {
|
||||
_logger.severe("Error while loading labels", e, s);
|
||||
}
|
||||
}
|
||||
|
||||
/// Pre-process the image
|
||||
TensorImage _getProcessedImage(TensorImage inputImage) {
|
||||
padSize = max(inputImage.height, inputImage.width);
|
||||
imageProcessor ??= ImageProcessorBuilder()
|
||||
.add(ResizeWithCropOrPadOp(padSize, padSize))
|
||||
.add(ResizeOp(inputSize, inputSize, ResizeMethod.BILINEAR))
|
||||
.build();
|
||||
inputImage = imageProcessor!.process(inputImage);
|
||||
return inputImage;
|
||||
}
|
||||
|
||||
/// Runs object detection on the input image
|
||||
Predictions? predict(imageLib.Image image) {
|
||||
final predictStartTime = DateTime.now().millisecondsSinceEpoch;
|
||||
|
||||
final preProcessStart = DateTime.now().millisecondsSinceEpoch;
|
||||
|
||||
// Create TensorImage from image
|
||||
TensorImage inputImage = TensorImage.fromImage(image);
|
||||
|
||||
// Pre-process TensorImage
|
||||
inputImage = _getProcessedImage(inputImage);
|
||||
|
||||
final preProcessElapsedTime =
|
||||
DateTime.now().millisecondsSinceEpoch - preProcessStart;
|
||||
|
||||
// TensorBuffers for output tensors
|
||||
final outputLocations = TensorBufferFloat(_outputShapes[0]);
|
||||
final outputClasses = TensorBufferFloat(_outputShapes[1]);
|
||||
final outputScores = TensorBufferFloat(_outputShapes[2]);
|
||||
final numLocations = TensorBufferFloat(_outputShapes[3]);
|
||||
|
||||
// Inputs object for runForMultipleInputs
|
||||
// Use [TensorImage.buffer] or [TensorBuffer.buffer] to pass by reference
|
||||
final inputs = [inputImage.buffer];
|
||||
|
||||
// Outputs map
|
||||
final outputs = {
|
||||
0: outputLocations.buffer,
|
||||
1: outputClasses.buffer,
|
||||
2: outputScores.buffer,
|
||||
3: numLocations.buffer,
|
||||
};
|
||||
|
||||
final inferenceTimeStart = DateTime.now().millisecondsSinceEpoch;
|
||||
|
||||
// run inference
|
||||
_interpreter.runForMultipleInputs(inputs, outputs);
|
||||
|
||||
final inferenceTimeElapsed =
|
||||
DateTime.now().millisecondsSinceEpoch - inferenceTimeStart;
|
||||
|
||||
// Maximum number of results to show
|
||||
final resultsCount = min(numResults, numLocations.getIntValue(0));
|
||||
|
||||
// Using labelOffset = 1 as ??? at index 0
|
||||
const labelOffset = 1;
|
||||
|
||||
final recognitions = <Recognition>[];
|
||||
|
||||
for (int i = 0; i < resultsCount; i++) {
|
||||
// Prediction score
|
||||
final score = outputScores.getDoubleValue(i);
|
||||
|
||||
// Label string
|
||||
final labelIndex = outputClasses.getIntValue(i) + labelOffset;
|
||||
final label = _labels.elementAt(labelIndex);
|
||||
|
||||
if (score > threshold) {
|
||||
recognitions.add(
|
||||
Recognition(i, label, score),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
final predictElapsedTime =
|
||||
DateTime.now().millisecondsSinceEpoch - predictStartTime;
|
||||
_logger.info(recognitions);
|
||||
return Predictions(
|
||||
recognitions,
|
||||
Stats(
|
||||
predictElapsedTime,
|
||||
predictElapsedTime,
|
||||
inferenceTimeElapsed,
|
||||
preProcessElapsedTime,
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
/// Gets the interpreter instance
|
||||
Interpreter get interpreter => _interpreter;
|
||||
|
||||
/// Gets the loaded labels
|
||||
List<String> get labels => _labels;
|
||||
}
|
151
lib/services/object_detection/tflite/mobilenet_classifier.dart
Normal file
151
lib/services/object_detection/tflite/mobilenet_classifier.dart
Normal file
|
@ -0,0 +1,151 @@
|
|||
import 'dart:math';
|
||||
|
||||
import 'package:image/image.dart' as imageLib;
|
||||
import "package:logging/logging.dart";
|
||||
import 'package:photos/services/object_detection/models/predictions.dart';
|
||||
import 'package:photos/services/object_detection/models/recognition.dart';
|
||||
import "package:photos/services/object_detection/models/stats.dart";
|
||||
import "package:photos/services/object_detection/tflite/classifier.dart";
|
||||
import "package:tflite_flutter/tflite_flutter.dart";
|
||||
import "package:tflite_flutter_helper/tflite_flutter_helper.dart";
|
||||
|
||||
class MobileNetClassifier extends Classifier {
|
||||
final _logger = Logger("MobileNetClassifier");
|
||||
|
||||
/// Instance of Interpreter
|
||||
late Interpreter _interpreter;
|
||||
|
||||
/// Labels file loaded as list
|
||||
late List<String> _labels;
|
||||
|
||||
/// Input size of image (height = width = 300)
|
||||
static const int inputSize = 224;
|
||||
|
||||
/// Result score threshold
|
||||
static const double threshold = 0.5;
|
||||
|
||||
static const String modelFileName = "mobilenet_v1_1.0_224_quant.tflite";
|
||||
static const String labelFileName = "labels_mobilenet_quant_v1_224.txt";
|
||||
|
||||
/// [ImageProcessor] used to pre-process the image
|
||||
ImageProcessor? imageProcessor;
|
||||
|
||||
/// Padding the image to transform into square
|
||||
late int padSize;
|
||||
|
||||
/// Shapes of output tensors
|
||||
late List<List<int>> _outputShapes;
|
||||
|
||||
/// Types of output tensors
|
||||
late List<TfLiteType> _outputTypes;
|
||||
|
||||
/// Number of results to show
|
||||
static const int numResults = 10;
|
||||
|
||||
MobileNetClassifier({
|
||||
Interpreter? interpreter,
|
||||
List<String>? labels,
|
||||
}) {
|
||||
loadModel(interpreter);
|
||||
loadLabels(labels);
|
||||
}
|
||||
|
||||
/// Loads interpreter from asset
|
||||
void loadModel(Interpreter? interpreter) async {
|
||||
try {
|
||||
_interpreter = interpreter ??
|
||||
await Interpreter.fromAsset(
|
||||
"models/mobilenet/" + modelFileName,
|
||||
options: InterpreterOptions()..threads = 4,
|
||||
);
|
||||
final outputTensors = _interpreter.getOutputTensors();
|
||||
_outputShapes = [];
|
||||
_outputTypes = [];
|
||||
outputTensors.forEach((tensor) {
|
||||
_outputShapes.add(tensor.shape);
|
||||
_outputTypes.add(tensor.type);
|
||||
});
|
||||
_logger.info("Interpreter initialized");
|
||||
} catch (e, s) {
|
||||
_logger.severe("Error while creating interpreter", e, s);
|
||||
}
|
||||
}
|
||||
|
||||
/// Loads labels from assets
|
||||
void loadLabels(List<String>? labels) async {
|
||||
try {
|
||||
_labels = labels ??
|
||||
await FileUtil.loadLabels("assets/models/mobilenet/" + labelFileName);
|
||||
_logger.info("Labels initialized");
|
||||
} catch (e, s) {
|
||||
_logger.severe("Error while loading labels", e, s);
|
||||
}
|
||||
}
|
||||
|
||||
/// Pre-process the image
|
||||
TensorImage _getProcessedImage(TensorImage inputImage) {
|
||||
padSize = max(inputImage.height, inputImage.width);
|
||||
imageProcessor ??= ImageProcessorBuilder()
|
||||
// .add(ResizeWithCropOrPadOp(padSize, padSize))
|
||||
.add(ResizeOp(inputSize, inputSize, ResizeMethod.BILINEAR))
|
||||
.build();
|
||||
inputImage = imageProcessor!.process(inputImage);
|
||||
return inputImage;
|
||||
}
|
||||
|
||||
/// Runs object detection on the input image
|
||||
Predictions? predict(imageLib.Image image) {
|
||||
final predictStartTime = DateTime.now().millisecondsSinceEpoch;
|
||||
|
||||
final preProcessStart = DateTime.now().millisecondsSinceEpoch;
|
||||
|
||||
// Create TensorImage from image
|
||||
TensorImage inputImage = TensorImage.fromImage(image);
|
||||
|
||||
// Pre-process TensorImage
|
||||
inputImage = _getProcessedImage(inputImage);
|
||||
|
||||
final preProcessElapsedTime =
|
||||
DateTime.now().millisecondsSinceEpoch - preProcessStart;
|
||||
|
||||
// TensorBuffers for output tensors
|
||||
final output = TensorBufferUint8(_outputShapes[0]);
|
||||
final inferenceTimeStart = DateTime.now().millisecondsSinceEpoch;
|
||||
// run inference
|
||||
_interpreter.run(inputImage.buffer, output.buffer);
|
||||
|
||||
final inferenceTimeElapsed =
|
||||
DateTime.now().millisecondsSinceEpoch - inferenceTimeStart;
|
||||
|
||||
final recognitions = <Recognition>[];
|
||||
for (int i = 0; i < 1001; i++) {
|
||||
final score = output.getDoubleValue(i) / 255;
|
||||
if (score >= threshold) {
|
||||
final label = _labels.elementAt(i);
|
||||
|
||||
recognitions.add(
|
||||
Recognition(i, "#" + label, score),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
final predictElapsedTime =
|
||||
DateTime.now().millisecondsSinceEpoch - predictStartTime;
|
||||
_logger.info(recognitions);
|
||||
return Predictions(
|
||||
recognitions,
|
||||
Stats(
|
||||
predictElapsedTime,
|
||||
predictElapsedTime,
|
||||
inferenceTimeElapsed,
|
||||
preProcessElapsedTime,
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
/// Gets the interpreter instance
|
||||
Interpreter get interpreter => _interpreter;
|
||||
|
||||
/// Gets the loaded labels
|
||||
List<String> get labels => _labels;
|
||||
}
|
|
@ -2,7 +2,8 @@ import 'dart:isolate';
|
|||
import "dart:typed_data";
|
||||
|
||||
import 'package:image/image.dart' as imgLib;
|
||||
import "package:photos/services/object_detection/tflite/classifier.dart";
|
||||
import 'package:photos/services/object_detection/tflite/cocossd_classifier.dart';
|
||||
import "package:photos/services/object_detection/tflite/mobilenet_classifier.dart";
|
||||
import 'package:tflite_flutter/tflite_flutter.dart';
|
||||
|
||||
/// Manages separate Isolate instance for inference
|
||||
|
@ -29,8 +30,15 @@ class IsolateUtils {
|
|||
sendPort.send(port.sendPort);
|
||||
|
||||
await for (final IsolateData isolateData in port) {
|
||||
final classifier = ObjectClassifier(
|
||||
interpreter: Interpreter.fromAddress(isolateData.interpreterAddress),
|
||||
final classifier = isolateData.type == ClassifierType.cocossd
|
||||
? CocoSSDClassifier(
|
||||
interpreter:
|
||||
Interpreter.fromAddress(isolateData.interpreterAddress),
|
||||
labels: isolateData.labels,
|
||||
)
|
||||
: MobileNetClassifier(
|
||||
interpreter:
|
||||
Interpreter.fromAddress(isolateData.interpreterAddress),
|
||||
labels: isolateData.labels,
|
||||
);
|
||||
final image = imgLib.decodeImage(isolateData.input);
|
||||
|
@ -45,11 +53,18 @@ class IsolateData {
|
|||
Uint8List input;
|
||||
int interpreterAddress;
|
||||
List<String> labels;
|
||||
ClassifierType type;
|
||||
late SendPort responsePort;
|
||||
|
||||
IsolateData(
|
||||
this.input,
|
||||
this.interpreterAddress,
|
||||
this.labels,
|
||||
this.type,
|
||||
);
|
||||
}
|
||||
|
||||
enum ClassifierType {
|
||||
cocossd,
|
||||
mobilenet,
|
||||
}
|
||||
|
|
|
@ -165,7 +165,8 @@ flutter_native_splash:
|
|||
flutter:
|
||||
assets:
|
||||
- assets/
|
||||
- assets/models/
|
||||
- assets/models/cocossd/
|
||||
- assets/models/mobilenet/
|
||||
fonts:
|
||||
- family: Inter
|
||||
fonts:
|
||||
|
|
Loading…
Reference in a new issue