ente/lib/services/object_detection/utils/isolate_utils.dart
2023-03-31 18:03:43 +05:30

89 lines
2.4 KiB
Dart

import 'dart:isolate';
import "dart:typed_data";
import 'package:image/image.dart' as imgLib;
import "package:photos/services/object_detection/models/predictions.dart";
import "package:photos/services/object_detection/tflite/classifier.dart";
import 'package:photos/services/object_detection/tflite/cocossd_classifier.dart';
import "package:photos/services/object_detection/tflite/mobilenet_classifier.dart";
import "package:photos/services/object_detection/tflite/scene_classifier.dart";
import 'package:tflite_flutter/tflite_flutter.dart';
/// Manages a separate [Isolate] instance that runs inference.
///
/// Call [start] to spawn the isolate, then send [IsolateData] requests to
/// [sendPort]. Each request receives exactly one reply on its
/// [IsolateData.responsePort]: either the classifier's prediction result or
/// a [Predictions] carrying the error that occurred while handling it.
class IsolateUtils {
  static const String debugName = "InferenceIsolate";

  late SendPort _sendPort;
  final _receivePort = ReceivePort();

  /// Port used to send [IsolateData] requests to the inference isolate.
  ///
  /// Only valid after [start] has completed; reading it earlier throws a
  /// `LateInitializationError`.
  SendPort get sendPort => _sendPort;

  /// Spawns the inference isolate and waits for it to report its [SendPort].
  ///
  /// NOTE(review): [_receivePort] is consumed via `first`, so this can only
  /// succeed once per instance.
  Future<void> start() async {
    await Isolate.spawn<SendPort>(
      entryPoint,
      _receivePort.sendPort,
      debugName: debugName,
    );
    // The first message the isolate sends back is its own request port.
    _sendPort = await _receivePort.first;
  }

  /// Entry point executed inside the spawned isolate.
  ///
  /// Replies to [sendPort] with a new [SendPort], then handles incoming
  /// [IsolateData] requests one at a time.
  ///
  /// All per-request work runs inside the `try`: classifier construction,
  /// image decoding and prediction. A failure is therefore reported back on
  /// the request's `responsePort` instead of escaping the loop. An escaped
  /// exception would terminate the loop and leave this and every later
  /// request without a reply.
  static void entryPoint(SendPort sendPort) async {
    final port = ReceivePort();
    sendPort.send(port.sendPort);
    await for (final IsolateData isolateData in port) {
      try {
        final classifier = _getClassifier(isolateData);
        final image = imgLib.decodeImage(isolateData.input);
        if (image == null) {
          // decodeImage signals an unrecognised/corrupt input with null;
          // route it through the same error reply as any other failure.
          throw const FormatException("Unable to decode input image");
        }
        isolateData.responsePort.send(classifier.predict(image));
      } catch (e) {
        isolateData.responsePort.send(Predictions(null, null, error: e));
      }
    }
  }

  /// Builds the [Classifier] requested by [isolateData].
  ///
  /// Re-attaches to the interpreter at [IsolateData.interpreterAddress] and
  /// wraps it in the implementation selected by [IsolateData.type], using
  /// [IsolateData.labels]. Any type other than `cocossd` or `mobilenet`
  /// yields a [SceneClassifier].
  static Classifier _getClassifier(IsolateData isolateData) {
    final interpreter = Interpreter.fromAddress(isolateData.interpreterAddress);
    final labels = isolateData.labels;
    if (isolateData.type == ClassifierType.cocossd) {
      return CocoSSDClassifier(interpreter: interpreter, labels: labels);
    } else if (isolateData.type == ClassifierType.mobilenet) {
      return MobileNetClassifier(interpreter: interpreter, labels: labels);
    } else {
      return SceneClassifier(interpreter: interpreter, labels: labels);
    }
  }
}
/// A single inference request sent to the inference isolate.
///
/// Carries the encoded image, what the isolate needs to rebuild a classifier,
/// and the port on which the isolate should reply.
class IsolateData {
  IsolateData(
    this.input,
    this.interpreterAddress,
    this.labels,
    this.type,
  );

  /// Encoded image bytes; decoded inside the isolate before prediction.
  Uint8List input;

  /// Address of an existing interpreter, re-attached inside the isolate
  /// with `Interpreter.fromAddress`.
  int interpreterAddress;

  /// Labels handed to the classifier that is built for this request.
  List<String> labels;

  /// Which classifier implementation should handle this request.
  ClassifierType type;

  /// Port on which the isolate sends its reply.
  ///
  /// Not set by the constructor: assign it before sending the request,
  /// otherwise reading it throws a `LateInitializationError`.
  late SendPort responsePort;
}
/// Selects the classifier the inference isolate builds for a request.
enum ClassifierType { cocossd, mobilenet, scenes }