ObjectDetection: Return score with label
This commit is contained in:
parent
2a4d266bcc
commit
76d4f7ae29
|
@ -1,4 +1,5 @@
|
|||
import "dart:isolate";
|
||||
import "dart:math";
|
||||
import "dart:typed_data";
|
||||
|
||||
import "package:logging/logging.dart";
|
||||
|
@ -47,23 +48,32 @@ class ObjectDetectionService {
|
|||
static ObjectDetectionService instance =
|
||||
ObjectDetectionService._privateConstructor();
|
||||
|
||||
Future<List<String>> predict(Uint8List bytes) async {
|
||||
Future<Map<String, double>> predict(Uint8List bytes) async {
|
||||
try {
|
||||
if (!inInitiated) {
|
||||
return Future.error("ObjectDetectionService init is not completed");
|
||||
}
|
||||
final results = <String>{};
|
||||
results.addAll(await _getObjects(bytes));
|
||||
results.addAll(await _getMobileNetResults(bytes));
|
||||
results.addAll(await _getSceneResults(bytes));
|
||||
return results.toList();
|
||||
final results = <String, double>{};
|
||||
final methods = [_getObjects, _getMobileNetResults, _getSceneResults];
|
||||
|
||||
for (var method in methods) {
|
||||
final methodResults = await method(bytes);
|
||||
methodResults.forEach((key, value) {
|
||||
results.update(
|
||||
key,
|
||||
(existingValue) => max(existingValue, value),
|
||||
ifAbsent: () => value,
|
||||
);
|
||||
});
|
||||
}
|
||||
return results;
|
||||
} catch (e, s) {
|
||||
_logger.severe(e, s);
|
||||
rethrow;
|
||||
}
|
||||
}
|
||||
|
||||
Future<List<String>> _getObjects(Uint8List bytes) async {
|
||||
Future<Map<String, double>> _getObjects(Uint8List bytes) async {
|
||||
try {
|
||||
final isolateData = IsolateData(
|
||||
bytes,
|
||||
|
@ -75,10 +85,10 @@ class ObjectDetectionService {
|
|||
} catch (e, s) {
|
||||
_logger.severe("Could not run cocossd", e, s);
|
||||
}
|
||||
return [];
|
||||
return {};
|
||||
}
|
||||
|
||||
Future<List<String>> _getMobileNetResults(Uint8List bytes) async {
|
||||
Future<Map<String, double>> _getMobileNetResults(Uint8List bytes) async {
|
||||
try {
|
||||
final isolateData = IsolateData(
|
||||
bytes,
|
||||
|
@ -90,10 +100,10 @@ class ObjectDetectionService {
|
|||
} catch (e, s) {
|
||||
_logger.severe("Could not run mobilenet", e, s);
|
||||
}
|
||||
return [];
|
||||
return {};
|
||||
}
|
||||
|
||||
Future<List<String>> _getSceneResults(Uint8List bytes) async {
|
||||
Future<Map<String, double>> _getSceneResults(Uint8List bytes) async {
|
||||
try {
|
||||
final isolateData = IsolateData(
|
||||
bytes,
|
||||
|
@ -105,32 +115,35 @@ class ObjectDetectionService {
|
|||
} catch (e, s) {
|
||||
_logger.severe("Could not run scene detection", e, s);
|
||||
}
|
||||
return [];
|
||||
return {};
|
||||
}
|
||||
|
||||
Future<List<String>> _getPredictions(IsolateData isolateData) async {
|
||||
Future<Map<String, double>> _getPredictions(IsolateData isolateData) async {
|
||||
final predictions = await _inference(isolateData);
|
||||
final Set<String> results = {};
|
||||
final Map<String, double> results = {};
|
||||
|
||||
if (predictions.error == null) {
|
||||
for (final Recognition result in predictions.recognitions!) {
|
||||
if (result.score > scoreThreshold) {
|
||||
results.add(result.label);
|
||||
// Update the result score only if it's higher than the current score
|
||||
if (!results.containsKey(result.label) ||
|
||||
results[result.label]! < result.score) {
|
||||
results[result.label] = result.score;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
_logger.info(
|
||||
"Time taken for " +
|
||||
isolateData.type.toString() +
|
||||
": " +
|
||||
predictions.stats!.totalElapsedTime.toString() +
|
||||
"ms",
|
||||
"Time taken for ${isolateData.type}: ${predictions.stats!.totalElapsedTime}ms",
|
||||
);
|
||||
} else {
|
||||
_logger.severe(
|
||||
"Error while fetching predictions for " + isolateData.type.toString(),
|
||||
"Error while fetching predictions for ${isolateData.type}",
|
||||
predictions.error,
|
||||
);
|
||||
}
|
||||
return results.toList();
|
||||
|
||||
return results;
|
||||
}
|
||||
|
||||
/// Runs inference in another isolate
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
import 'dart:math';
|
||||
|
||||
import "package:flutter/foundation.dart";
|
||||
import 'package:image/image.dart' as image_lib;
|
||||
import "package:logging/logging.dart";
|
||||
import 'package:photos/services/object_detection/models/predictions.dart';
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import "package:flutter/foundation.dart";
|
||||
import "package:flutter/material.dart";
|
||||
import "package:logging/logging.dart";
|
||||
import "package:photos/generated/l10n.dart";
|
||||
|
@ -27,7 +28,7 @@ class ObjectsItemWidget extends StatelessWidget {
|
|||
) async {
|
||||
try {
|
||||
final chipButtons = <ChipButtonWidget>[];
|
||||
var objectTags = <String>[];
|
||||
var objectTags = <String, double>{};
|
||||
final thumbnail = await getThumbnail(file);
|
||||
if (thumbnail != null) {
|
||||
objectTags = await ObjectDetectionService.instance.predict(thumbnail);
|
||||
|
@ -40,9 +41,23 @@ class ObjectsItemWidget extends StatelessWidget {
|
|||
)
|
||||
];
|
||||
}
|
||||
for (String objectTag in objectTags) {
|
||||
chipButtons.add(ChipButtonWidget(objectTag));
|
||||
// sort by values
|
||||
objectTags = Map.fromEntries(
|
||||
objectTags.entries.toList()
|
||||
..sort((e1, e2) => e2.value.compareTo(e1.value)),
|
||||
);
|
||||
|
||||
for (MapEntry<String, double> entry in objectTags.entries) {
|
||||
chipButtons.add(
|
||||
ChipButtonWidget(
|
||||
entry.key +
|
||||
(kDebugMode
|
||||
? "-" + (entry.value * 100).round().toString()
|
||||
: ""),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
return chipButtons;
|
||||
} catch (e, s) {
|
||||
Logger("ObjctsItemWidget").info(e, s);
|
||||
|
|
Loading…
Reference in a new issue