2023-03-14 09:41:49 +00:00
|
|
|
import 'dart:math';
|
|
|
|
|
2023-03-15 07:33:48 +00:00
|
|
|
import 'package:image/image.dart' as image_lib;
|
2023-03-15 07:58:10 +00:00
|
|
|
import "package:logging/logging.dart";
|
2023-03-14 09:41:49 +00:00
|
|
|
import 'package:photos/services/object_detection/models/predictions.dart';
|
|
|
|
import 'package:photos/services/object_detection/models/recognition.dart';
|
|
|
|
import "package:photos/services/object_detection/models/stats.dart";
|
|
|
|
import "package:photos/services/object_detection/tflite/classifier.dart";
|
|
|
|
import "package:tflite_flutter/tflite_flutter.dart";
|
|
|
|
import "package:tflite_flutter_helper/tflite_flutter_helper.dart";
|
|
|
|
|
|
|
|
/// Object-detection classifier backed by the COCO SSD TFLite model.
///
/// Model and label locations are given by [modelPath] and [labelPath]; the
/// [Classifier] base class loads them. [predict] runs inference on a single
/// image and reports the detections whose score exceeds [threshold].
class CocoSSDClassifier extends Classifier {
  static final _logger = Logger("CocoSSDClassifier");

  /// Minimum score (exclusive) a detection must have to be reported.
  static const double threshold = 0.4;

  @override
  String get modelPath => "models/cocossd/model.tflite";

  @override
  String get labelPath => "assets/models/cocossd/labels.txt";

  @override
  int get inputSize => 300;

  @override
  Logger get logger => _logger;

  /// Upper bound on the number of detections read from the model output.
  static const int numResults = 10;

  /// Creates the classifier, optionally reusing an existing [interpreter]
  /// and pre-loaded [labels].
  ///
  /// NOTE(review): presumably `loadModel` / `loadLabels` in [Classifier] fall
  /// back to [modelPath] / [labelPath] when given null — confirm.
  CocoSSDClassifier({
    Interpreter? interpreter,
    List<String>? labels,
  }) {
    loadModel(interpreter);
    loadLabels(labels);
  }

  /// Runs object detection on [image].
  ///
  /// Returns at most [numResults] detections whose score is above
  /// [threshold], together with timing [Stats] measured in milliseconds.
  /// Detections whose class id maps outside the loaded label list are logged
  /// and skipped rather than aborting the whole prediction.
  @override
  Predictions? predict(image_lib.Image image) {
    // Stopwatch is monotonic, unlike wall-clock DateTime differences.
    final predictWatch = Stopwatch()..start();

    final preProcessWatch = Stopwatch()..start();
    // Create a TensorImage from the image and apply the base-class
    // pre-processing.
    final TensorImage inputImage =
        getProcessedImage(TensorImage.fromImage(image));
    final preProcessElapsedTime = preProcessWatch.elapsedMilliseconds;

    // TensorBuffers for the four output tensors.
    final outputLocations = TensorBufferFloat(outputShapes[0]);
    final outputClasses = TensorBufferFloat(outputShapes[1]);
    final outputScores = TensorBufferFloat(outputShapes[2]);
    final numLocations = TensorBufferFloat(outputShapes[3]);

    // Inputs object for runForMultipleInputs.
    // Use [TensorImage.buffer] or [TensorBuffer.buffer] to pass by reference.
    final inputs = [inputImage.buffer];

    // Output tensor index -> backing buffer.
    final outputs = {
      0: outputLocations.buffer,
      1: outputClasses.buffer,
      2: outputScores.buffer,
      3: numLocations.buffer,
    };

    final inferenceWatch = Stopwatch()..start();
    interpreter.runForMultipleInputs(inputs, outputs);
    final inferenceTimeElapsed = inferenceWatch.elapsedMilliseconds;

    // Never read more results than numResults, nor more than the model
    // reports it produced.
    final resultsCount = min(numResults, numLocations.getIntValue(0));

    // The label list has a "???" placeholder at index 0, so class ids
    // emitted by the model are shifted by one.
    const labelOffset = 1;

    final recognitions = <Recognition>[];
    for (var i = 0; i < resultsCount; i++) {
      final score = outputScores.getDoubleValue(i);
      // Resolve the label only for detections we keep; low-score noise must
      // not be able to trigger a RangeError on the label list.
      if (score > threshold) {
        final labelIndex = outputClasses.getIntValue(i) + labelOffset;
        if (labelIndex < 0 || labelIndex >= labels.length) {
          _logger.warning(
            "Skipping detection $i: label index $labelIndex is out of range",
          );
          continue;
        }
        recognitions.add(
          Recognition(i, labels.elementAt(labelIndex), score),
        );
      }
    }

    final predictElapsedTime = predictWatch.elapsedMilliseconds;
    // NOTE(review): predictElapsedTime fills both of the first two Stats
    // fields, as before — confirm this matches Stats' constructor intent.
    return Predictions(
      recognitions,
      Stats(
        predictElapsedTime,
        predictElapsedTime,
        inferenceTimeElapsed,
        preProcessElapsedTime,
      ),
    );
  }
}
|