Add MobileNetV1

This commit is contained in:
vishnukvmd 2023-03-14 15:11:49 +05:30
parent e8d3aa4e91
commit a5646511e0
10 changed files with 1403 additions and 200 deletions

File diff suppressed because it is too large Load diff

View file

@ -4,18 +4,18 @@ import "dart:typed_data";
import "package:logging/logging.dart";
import "package:photos/services/object_detection/models/predictions.dart";
import 'package:photos/services/object_detection/models/recognition.dart';
import "package:photos/services/object_detection/tflite/classifier.dart";
import 'package:photos/services/object_detection/tflite/cocossd_classifier.dart';
import "package:photos/services/object_detection/tflite/mobilenet_classifier.dart";
import "package:photos/services/object_detection/utils/isolate_utils.dart";
class ObjectDetectionService {
static const scoreThreshold = 0.6;
static const scoreThreshold = 0.5;
final _logger = Logger("ObjectDetectionService");
/// Instance of [ObjectClassifier]
late ObjectClassifier _classifier;
late CocoSSDClassifier _objectClassifier;
late MobileNetClassifier _mobileNetClassifier;
/// Instance of [IsolateUtils]
late IsolateUtils _isolateUtils;
ObjectDetectionService._privateConstructor();
@ -23,7 +23,8 @@ class ObjectDetectionService {
Future<void> init() async {
_isolateUtils = IsolateUtils();
await _isolateUtils.start();
_classifier = ObjectClassifier();
_objectClassifier = CocoSSDClassifier();
_mobileNetClassifier = MobileNetClassifier();
}
static ObjectDetectionService instance =
@ -31,18 +32,11 @@ class ObjectDetectionService {
Future<List<String>> predict(Uint8List bytes) async {
try {
final isolateData = IsolateData(
bytes,
_classifier.interpreter.address,
_classifier.labels,
);
final predictions = await _inference(isolateData);
final Set<String> results = {};
for (final Recognition result in predictions.recognitions) {
if (result.score > scoreThreshold) {
results.add(result.label);
}
}
final results = <String>{};
final objectResults = await _getObjects(bytes);
results.addAll(objectResults);
final mobileNetResults = await _getMobileNetResults(bytes);
results.addAll(mobileNetResults);
return results.toList();
} catch (e, s) {
_logger.severe(e, s);
@ -50,6 +44,40 @@ class ObjectDetectionService {
}
}
/// Detects objects in [bytes] with the CocoSSD model and returns the labels
/// of recognitions scoring above [scoreThreshold].
Future<List<String>> _getObjects(Uint8List bytes) {
  return _runClassifier(
    bytes,
    _objectClassifier.interpreter.address,
    _objectClassifier.labels,
    ClassifierType.cocossd,
  );
}

/// Classifies [bytes] with the MobileNet model and returns the labels of
/// recognitions scoring above [scoreThreshold].
Future<List<String>> _getMobileNetResults(Uint8List bytes) {
  return _runClassifier(
    bytes,
    _mobileNetClassifier.interpreter.address,
    _mobileNetClassifier.labels,
    ClassifierType.mobilenet,
  );
}

/// Shared inference pipeline for both classifiers.
///
/// Packs the request into an [IsolateData], runs [_inference] in the worker
/// isolate, and returns the de-duplicated labels whose score exceeds
/// [scoreThreshold].
Future<List<String>> _runClassifier(
  Uint8List bytes,
  int interpreterAddress,
  List<String> labels,
  ClassifierType type,
) async {
  final isolateData = IsolateData(
    bytes,
    interpreterAddress,
    labels,
    type,
  );
  final predictions = await _inference(isolateData);
  // A Set removes duplicate labels when several recognitions share one.
  final results = <String>{};
  for (final Recognition result in predictions.recognitions) {
    if (result.score > scoreThreshold) {
      results.add(result.label);
    }
  }
  return results.toList();
}
/// Runs inference in another isolate
Future<Predictions> _inference(IsolateData isolateData) async {
final responsePort = ReceivePort();

View file

@ -1,179 +1,6 @@
import 'dart:math';
import 'package:image/image.dart' as imageLib;
import "package:logging/logging.dart";
import 'package:photos/services/object_detection/models/predictions.dart';
import 'package:photos/services/object_detection/models/recognition.dart';
import "package:photos/services/object_detection/models/stats.dart";
import "package:tflite_flutter/tflite_flutter.dart";
import "package:tflite_flutter_helper/tflite_flutter_helper.dart";
import "package:photos/services/object_detection/models/predictions.dart";
/// Classifier
class ObjectClassifier {
final _logger = Logger("Classifier");
/// Instance of Interpreter
late Interpreter _interpreter;
/// Labels file loaded as list
late List<String> _labels;
/// Input size of image (height = width = 300)
static const int inputSize = 300;
/// Result score threshold
static const double threshold = 0.5;
static const String modelFileName = "detect.tflite";
static const String labelFileName = "labelmap.txt";
/// [ImageProcessor] used to pre-process the image
ImageProcessor? imageProcessor;
/// Padding the image to transform into square
late int padSize;
/// Shapes of output tensors
late List<List<int>> _outputShapes;
/// Types of output tensors
late List<TfLiteType> _outputTypes;
/// Number of results to show
static const int numResults = 10;
ObjectClassifier({
Interpreter? interpreter,
List<String>? labels,
}) {
loadModel(interpreter);
loadLabels(labels);
}
/// Loads interpreter from asset
void loadModel(Interpreter? interpreter) async {
try {
_interpreter = interpreter ??
await Interpreter.fromAsset(
"models/" + modelFileName,
options: InterpreterOptions()..threads = 4,
);
final outputTensors = _interpreter.getOutputTensors();
_outputShapes = [];
_outputTypes = [];
outputTensors.forEach((tensor) {
_outputShapes.add(tensor.shape);
_outputTypes.add(tensor.type);
});
_logger.info("Interpreter initialized");
} catch (e, s) {
_logger.severe("Error while creating interpreter", e, s);
}
}
/// Loads labels from assets
void loadLabels(List<String>? labels) async {
try {
_labels =
labels ?? await FileUtil.loadLabels("assets/models/" + labelFileName);
_logger.info("Labels initialized");
} catch (e, s) {
_logger.severe("Error while loading labels", e, s);
}
}
/// Pre-process the image
TensorImage _getProcessedImage(TensorImage inputImage) {
padSize = max(inputImage.height, inputImage.width);
imageProcessor ??= ImageProcessorBuilder()
.add(ResizeWithCropOrPadOp(padSize, padSize))
.add(ResizeOp(inputSize, inputSize, ResizeMethod.BILINEAR))
.build();
inputImage = imageProcessor!.process(inputImage);
return inputImage;
}
/// Runs object detection on the input image
Predictions? predict(imageLib.Image image) {
final predictStartTime = DateTime.now().millisecondsSinceEpoch;
final preProcessStart = DateTime.now().millisecondsSinceEpoch;
// Create TensorImage from image
TensorImage inputImage = TensorImage.fromImage(image);
// Pre-process TensorImage
inputImage = _getProcessedImage(inputImage);
final preProcessElapsedTime =
DateTime.now().millisecondsSinceEpoch - preProcessStart;
// TensorBuffers for output tensors
final outputLocations = TensorBufferFloat(_outputShapes[0]);
final outputClasses = TensorBufferFloat(_outputShapes[1]);
final outputScores = TensorBufferFloat(_outputShapes[2]);
final numLocations = TensorBufferFloat(_outputShapes[3]);
// Inputs object for runForMultipleInputs
// Use [TensorImage.buffer] or [TensorBuffer.buffer] to pass by reference
final inputs = [inputImage.buffer];
// Outputs map
final outputs = {
0: outputLocations.buffer,
1: outputClasses.buffer,
2: outputScores.buffer,
3: numLocations.buffer,
};
final inferenceTimeStart = DateTime.now().millisecondsSinceEpoch;
// run inference
_interpreter.runForMultipleInputs(inputs, outputs);
final inferenceTimeElapsed =
DateTime.now().millisecondsSinceEpoch - inferenceTimeStart;
// Maximum number of results to show
final resultsCount = min(numResults, numLocations.getIntValue(0));
// Using labelOffset = 1 as ??? at index 0
const labelOffset = 1;
final recognitions = <Recognition>[];
for (int i = 0; i < resultsCount; i++) {
// Prediction score
final score = outputScores.getDoubleValue(i);
// Label string
final labelIndex = outputClasses.getIntValue(i) + labelOffset;
final label = _labels.elementAt(labelIndex);
if (score > threshold) {
recognitions.add(
Recognition(i, label, score),
);
}
}
final predictElapsedTime =
DateTime.now().millisecondsSinceEpoch - predictStartTime;
_logger.info(recognitions);
return Predictions(
recognitions,
Stats(
predictElapsedTime,
predictElapsedTime,
inferenceTimeElapsed,
preProcessElapsedTime,
),
);
}
/// Gets the interpreter instance
Interpreter get interpreter => _interpreter;
/// Gets the loaded labels
List<String> get labels => _labels;
/// Common interface for all on-device TFLite classifiers.
///
/// Implementations run a model over a decoded [imageLib.Image] and return the
/// resulting [Predictions], or `null` when inference fails.
abstract class Classifier {
/// Runs the model on [image] and returns its predictions.
Predictions? predict(imageLib.Image image);
}

View file

@ -0,0 +1,180 @@
import 'dart:math';
import 'package:image/image.dart' as imageLib;
import "package:logging/logging.dart";
import 'package:photos/services/object_detection/models/predictions.dart';
import 'package:photos/services/object_detection/models/recognition.dart';
import "package:photos/services/object_detection/models/stats.dart";
import "package:photos/services/object_detection/tflite/classifier.dart";
import "package:tflite_flutter/tflite_flutter.dart";
import "package:tflite_flutter_helper/tflite_flutter_helper.dart";
/// Classifier
/// TFLite classifier backed by the COCO-SSD object-detection model.
///
/// Loads the interpreter and label map from bundled assets, unless both are
/// injected via the constructor (e.g. inside an isolate that shares an
/// interpreter by address).
class CocoSSDClassifier extends Classifier {
  final _logger = Logger("Classifier");

  /// Instance of Interpreter
  late Interpreter _interpreter;

  /// Labels file loaded as list
  late List<String> _labels;

  /// Input size of image (height = width = 300)
  static const int inputSize = 300;

  /// Result score threshold
  static const double threshold = 0.5;

  static const String modelFileName = "model.tflite";
  static const String labelFileName = "labels.txt";

  /// [ImageProcessor] used to pre-process the image
  ImageProcessor? imageProcessor;

  /// Padding the image to transform into square
  late int padSize;

  /// Shapes of output tensors
  late List<List<int>> _outputShapes;

  /// Types of output tensors
  late List<TfLiteType> _outputTypes;

  /// Number of results to show
  static const int numResults = 10;

  CocoSSDClassifier({
    Interpreter? interpreter,
    List<String>? labels,
  }) {
    // NOTE: these loads are asynchronous and unawaited (constructors cannot
    // await). The `late` fields throw if [predict] runs before loading
    // completes; callers needing a guarantee can await the methods directly.
    loadModel(interpreter);
    loadLabels(labels);
  }

  /// Loads the interpreter from assets, or adopts [interpreter] when given.
  ///
  /// Returns a [Future] (instead of the previous async `void`) so callers can
  /// await initialization and observe errors instead of dropping them.
  Future<void> loadModel(Interpreter? interpreter) async {
    try {
      _interpreter = interpreter ??
          await Interpreter.fromAsset(
            "models/cocossd/" + modelFileName,
            options: InterpreterOptions()..threads = 4,
          );
      // Cache output tensor shapes/types so predict() can size its buffers.
      final outputTensors = _interpreter.getOutputTensors();
      _outputShapes = [];
      _outputTypes = [];
      for (final tensor in outputTensors) {
        _outputShapes.add(tensor.shape);
        _outputTypes.add(tensor.type);
      }
      _logger.info("Interpreter initialized");
    } catch (e, s) {
      _logger.severe("Error while creating interpreter", e, s);
    }
  }

  /// Loads labels from assets, or adopts [labels] when given.
  Future<void> loadLabels(List<String>? labels) async {
    try {
      _labels = labels ??
          await FileUtil.loadLabels("assets/models/cocossd/" + labelFileName);
      _logger.info("Labels initialized");
    } catch (e, s) {
      _logger.severe("Error while loading labels", e, s);
    }
  }

  /// Pre-processes [inputImage]: pads to a square of [padSize], then resizes
  /// to [inputSize] x [inputSize] with bilinear interpolation.
  TensorImage _getProcessedImage(TensorImage inputImage) {
    padSize = max(inputImage.height, inputImage.width);
    imageProcessor ??= ImageProcessorBuilder()
        .add(ResizeWithCropOrPadOp(padSize, padSize))
        .add(ResizeOp(inputSize, inputSize, ResizeMethod.BILINEAR))
        .build();
    inputImage = imageProcessor!.process(inputImage);
    return inputImage;
  }

  /// Runs object detection on the input image.
  ///
  /// Considers at most [numResults] detections and keeps those scoring above
  /// [threshold], returning them with timing [Stats].
  @override
  Predictions? predict(imageLib.Image image) {
    final predictStartTime = DateTime.now().millisecondsSinceEpoch;
    final preProcessStart = DateTime.now().millisecondsSinceEpoch;
    // Create TensorImage from image
    TensorImage inputImage = TensorImage.fromImage(image);
    // Pre-process TensorImage
    inputImage = _getProcessedImage(inputImage);
    final preProcessElapsedTime =
        DateTime.now().millisecondsSinceEpoch - preProcessStart;
    // TensorBuffers for output tensors
    final outputLocations = TensorBufferFloat(_outputShapes[0]);
    final outputClasses = TensorBufferFloat(_outputShapes[1]);
    final outputScores = TensorBufferFloat(_outputShapes[2]);
    final numLocations = TensorBufferFloat(_outputShapes[3]);
    // Inputs object for runForMultipleInputs
    // Use [TensorImage.buffer] or [TensorBuffer.buffer] to pass by reference
    final inputs = [inputImage.buffer];
    // Outputs map
    final outputs = {
      0: outputLocations.buffer,
      1: outputClasses.buffer,
      2: outputScores.buffer,
      3: numLocations.buffer,
    };
    final inferenceTimeStart = DateTime.now().millisecondsSinceEpoch;
    // run inference
    _interpreter.runForMultipleInputs(inputs, outputs);
    final inferenceTimeElapsed =
        DateTime.now().millisecondsSinceEpoch - inferenceTimeStart;
    // Maximum number of results to show
    final resultsCount = min(numResults, numLocations.getIntValue(0));
    // Index 0 of the label file is the "???" placeholder, so model class
    // indices are shifted by one.
    const labelOffset = 1;
    final recognitions = <Recognition>[];
    for (int i = 0; i < resultsCount; i++) {
      // Prediction score
      final score = outputScores.getDoubleValue(i);
      // Label string
      final labelIndex = outputClasses.getIntValue(i) + labelOffset;
      final label = _labels.elementAt(labelIndex);
      if (score > threshold) {
        recognitions.add(
          Recognition(i, label, score),
        );
      }
    }
    final predictElapsedTime =
        DateTime.now().millisecondsSinceEpoch - predictStartTime;
    _logger.info(recognitions);
    return Predictions(
      recognitions,
      Stats(
        predictElapsedTime,
        predictElapsedTime,
        inferenceTimeElapsed,
        preProcessElapsedTime,
      ),
    );
  }

  /// Gets the interpreter instance
  Interpreter get interpreter => _interpreter;

  /// Gets the loaded labels
  List<String> get labels => _labels;
}

View file

@ -0,0 +1,151 @@
import 'dart:math';
import 'package:image/image.dart' as imageLib;
import "package:logging/logging.dart";
import 'package:photos/services/object_detection/models/predictions.dart';
import 'package:photos/services/object_detection/models/recognition.dart';
import "package:photos/services/object_detection/models/stats.dart";
import "package:photos/services/object_detection/tflite/classifier.dart";
import "package:tflite_flutter/tflite_flutter.dart";
import "package:tflite_flutter_helper/tflite_flutter_helper.dart";
/// TFLite classifier backed by a quantized MobileNet v1 image-classification
/// model.
///
/// Loads the interpreter and label map from bundled assets, unless both are
/// injected via the constructor (e.g. inside an isolate that shares an
/// interpreter by address).
class MobileNetClassifier extends Classifier {
  final _logger = Logger("MobileNetClassifier");

  /// Instance of Interpreter
  late Interpreter _interpreter;

  /// Labels file loaded as list
  late List<String> _labels;

  /// Input size of image (height = width = 224)
  static const int inputSize = 224;

  /// Result score threshold
  static const double threshold = 0.5;

  static const String modelFileName = "mobilenet_v1_1.0_224_quant.tflite";
  static const String labelFileName = "labels_mobilenet_quant_v1_224.txt";

  /// [ImageProcessor] used to pre-process the image
  ImageProcessor? imageProcessor;

  /// Padding the image to transform into square
  late int padSize;

  /// Shapes of output tensors
  late List<List<int>> _outputShapes;

  /// Types of output tensors
  late List<TfLiteType> _outputTypes;

  /// Number of results to show
  static const int numResults = 10;

  MobileNetClassifier({
    Interpreter? interpreter,
    List<String>? labels,
  }) {
    // NOTE: these loads are asynchronous and unawaited (constructors cannot
    // await). The `late` fields throw if [predict] runs before loading
    // completes; callers needing a guarantee can await the methods directly.
    loadModel(interpreter);
    loadLabels(labels);
  }

  /// Loads the interpreter from assets, or adopts [interpreter] when given.
  ///
  /// Returns a [Future] (instead of the previous async `void`) so callers can
  /// await initialization and observe errors instead of dropping them.
  Future<void> loadModel(Interpreter? interpreter) async {
    try {
      _interpreter = interpreter ??
          await Interpreter.fromAsset(
            "models/mobilenet/" + modelFileName,
            options: InterpreterOptions()..threads = 4,
          );
      // Cache output tensor shapes/types so predict() can size its buffers.
      final outputTensors = _interpreter.getOutputTensors();
      _outputShapes = [];
      _outputTypes = [];
      for (final tensor in outputTensors) {
        _outputShapes.add(tensor.shape);
        _outputTypes.add(tensor.type);
      }
      _logger.info("Interpreter initialized");
    } catch (e, s) {
      _logger.severe("Error while creating interpreter", e, s);
    }
  }

  /// Loads labels from assets, or adopts [labels] when given.
  Future<void> loadLabels(List<String>? labels) async {
    try {
      _labels = labels ??
          await FileUtil.loadLabels("assets/models/mobilenet/" + labelFileName);
      _logger.info("Labels initialized");
    } catch (e, s) {
      _logger.severe("Error while loading labels", e, s);
    }
  }

  /// Pre-processes [inputImage]: resizes to [inputSize] x [inputSize] with
  /// bilinear interpolation (square padding deliberately disabled below).
  TensorImage _getProcessedImage(TensorImage inputImage) {
    padSize = max(inputImage.height, inputImage.width);
    imageProcessor ??= ImageProcessorBuilder()
        // .add(ResizeWithCropOrPadOp(padSize, padSize))
        .add(ResizeOp(inputSize, inputSize, ResizeMethod.BILINEAR))
        .build();
    inputImage = imageProcessor!.process(inputImage);
    return inputImage;
  }

  /// Runs classification on the input image.
  ///
  /// Keeps every class scoring at or above [threshold] and returns it with
  /// timing [Stats].
  @override
  Predictions? predict(imageLib.Image image) {
    final predictStartTime = DateTime.now().millisecondsSinceEpoch;
    final preProcessStart = DateTime.now().millisecondsSinceEpoch;
    // Create TensorImage from image
    TensorImage inputImage = TensorImage.fromImage(image);
    // Pre-process TensorImage
    inputImage = _getProcessedImage(inputImage);
    final preProcessElapsedTime =
        DateTime.now().millisecondsSinceEpoch - preProcessStart;
    // TensorBuffers for output tensors
    final output = TensorBufferUint8(_outputShapes[0]);
    final inferenceTimeStart = DateTime.now().millisecondsSinceEpoch;
    // run inference
    _interpreter.run(inputImage.buffer, output.buffer);
    final inferenceTimeElapsed =
        DateTime.now().millisecondsSinceEpoch - inferenceTimeStart;
    final recognitions = <Recognition>[];
    // Derive the class count from the output buffer instead of hard-coding
    // 1001, and clamp to the label list to avoid out-of-range lookups if the
    // label file is shorter than the model's output.
    final scores = output.getDoubleList();
    final classCount = min(scores.length, _labels.length);
    for (int i = 0; i < classCount; i++) {
      // Quantized uint8 output: map 0-255 to a 0-1 confidence score.
      final score = scores[i] / 255;
      if (score >= threshold) {
        final label = _labels.elementAt(i);
        recognitions.add(
          Recognition(i, "#" + label, score),
        );
      }
    }
    final predictElapsedTime =
        DateTime.now().millisecondsSinceEpoch - predictStartTime;
    _logger.info(recognitions);
    return Predictions(
      recognitions,
      Stats(
        predictElapsedTime,
        predictElapsedTime,
        inferenceTimeElapsed,
        preProcessElapsedTime,
      ),
    );
  }

  /// Gets the interpreter instance
  Interpreter get interpreter => _interpreter;

  /// Gets the loaded labels
  List<String> get labels => _labels;
}

View file

@ -2,7 +2,8 @@ import 'dart:isolate';
import "dart:typed_data";
import 'package:image/image.dart' as imgLib;
import "package:photos/services/object_detection/tflite/classifier.dart";
import 'package:photos/services/object_detection/tflite/cocossd_classifier.dart';
import "package:photos/services/object_detection/tflite/mobilenet_classifier.dart";
import 'package:tflite_flutter/tflite_flutter.dart';
/// Manages separate Isolate instance for inference
@ -29,10 +30,17 @@ class IsolateUtils {
sendPort.send(port.sendPort);
await for (final IsolateData isolateData in port) {
final classifier = ObjectClassifier(
interpreter: Interpreter.fromAddress(isolateData.interpreterAddress),
labels: isolateData.labels,
);
final classifier = isolateData.type == ClassifierType.cocossd
? CocoSSDClassifier(
interpreter:
Interpreter.fromAddress(isolateData.interpreterAddress),
labels: isolateData.labels,
)
: MobileNetClassifier(
interpreter:
Interpreter.fromAddress(isolateData.interpreterAddress),
labels: isolateData.labels,
);
final image = imgLib.decodeImage(isolateData.input);
final results = classifier.predict(image!);
isolateData.responsePort.send(results);
@ -45,11 +53,18 @@ class IsolateData {
Uint8List input;
int interpreterAddress;
List<String> labels;
ClassifierType type;
late SendPort responsePort;
IsolateData(
this.input,
this.interpreterAddress,
this.labels,
this.type,
);
}
/// Identifies which classifier implementation should handle an [IsolateData]
/// request inside the inference isolate.
enum ClassifierType {
/// COCO-SSD object detector.
cocossd,
/// MobileNet v1 image classifier.
mobilenet,
}

View file

@ -165,7 +165,8 @@ flutter_native_splash:
flutter:
assets:
- assets/
- assets/models/
- assets/models/cocossd/
- assets/models/mobilenet/
fonts:
- family: Inter
fonts: