ente/lib/services/object_detection/object_detection_service.dart

90 lines
2.7 KiB
Dart
Raw Normal View History

2023-02-08 13:40:18 +00:00
import "dart:isolate";
import "dart:typed_data";
import "package:logging/logging.dart";
import "package:photos/services/object_detection/models/predictions.dart";
import 'package:photos/services/object_detection/models/recognition.dart';
2023-03-14 09:41:49 +00:00
import 'package:photos/services/object_detection/tflite/cocossd_classifier.dart';
import "package:photos/services/object_detection/tflite/mobilenet_classifier.dart";
2023-02-08 13:40:18 +00:00
import "package:photos/services/object_detection/utils/isolate_utils.dart";
/// Singleton service that labels an image with two TFLite classifiers
/// ([CocoSSDClassifier] and [MobileNetClassifier]).
///
/// Inference requests are handed to the isolate managed by [IsolateUtils].
/// [init] must complete before [predict] is called: the classifiers and the
/// isolate helper are `late` fields that are only assigned in [init].
class ObjectDetectionService {
  /// Score a recognition has to exceed (strictly) for its label to be kept.
  static const scoreThreshold = 0.5;

  final _logger = Logger("ObjectDetectionService");

  late CocoSSDClassifier _objectClassifier;
  late MobileNetClassifier _mobileNetClassifier;
  late IsolateUtils _isolateUtils;

  ObjectDetectionService._privateConstructor();

  /// Starts the inference isolate and creates both classifiers.
  Future<void> init() async {
    _isolateUtils = IsolateUtils();
    await _isolateUtils.start();
    _objectClassifier = CocoSSDClassifier();
    _mobileNetClassifier = MobileNetClassifier();
  }

  /// The shared instance of this service.
  static ObjectDetectionService instance =
      ObjectDetectionService._privateConstructor();

  /// Returns the distinct labels both classifiers assign to the image in
  /// [bytes] with a score above [scoreThreshold].
  ///
  /// The COCO-SSD request is awaited before the MobileNet request is sent.
  /// Any failure is logged at severe level and then rethrown to the caller.
  Future<List<String>> predict(Uint8List bytes) async {
    try {
      // A set literal de-duplicates labels reported by both classifiers.
      final labels = <String>{
        ...await _getObjects(bytes),
        ...await _getMobileNetResults(bytes),
      };
      return labels.toList();
    } catch (e, s) {
      _logger.severe(e, s);
      rethrow;
    }
  }

  /// Returns the labels the COCO-SSD classifier reports for [bytes].
  Future<List<String>> _getObjects(Uint8List bytes) async {
    return _labelsAboveThreshold(
      IsolateData(
        bytes,
        _objectClassifier.interpreter.address,
        _objectClassifier.labels,
        ClassifierType.cocossd,
      ),
    );
  }

  /// Returns the labels the MobileNet classifier reports for [bytes].
  Future<List<String>> _getMobileNetResults(Uint8List bytes) async {
    return _labelsAboveThreshold(
      IsolateData(
        bytes,
        _mobileNetClassifier.interpreter.address,
        _mobileNetClassifier.labels,
        ClassifierType.mobilenet,
      ),
    );
  }

  /// Runs inference for [isolateData] and returns the distinct labels of all
  /// recognitions whose score is strictly greater than [scoreThreshold].
  Future<List<String>> _labelsAboveThreshold(IsolateData isolateData) async {
    final predictions = await _inference(isolateData);
    return <String>{
      for (final Recognition recognition in predictions.recognitions)
        if (recognition.score > scoreThreshold) recognition.label,
    }.toList();
  }

  /// Runs inference in another isolate.
  ///
  /// Attaches a fresh reply port to [isolateData], sends it through the
  /// isolate helper's send port and completes with the first message that
  /// arrives on the reply port.
  Future<Predictions> _inference(IsolateData isolateData) async {
    final responsePort = ReceivePort();
    try {
      _isolateUtils.sendPort.send(
        isolateData..responsePort = responsePort.sendPort,
      );
    } catch (_) {
      // Nothing will ever answer on this port if the request could not be
      // sent, so close it here instead of leaving it open.
      responsePort.close();
      rethrow;
    }
    final response = await responsePort.first;
    return response as Predictions;
  }
}