ObjectDetection: Return score with label

This commit is contained in:
Neeraj Gupta 2023-05-11 14:14:34 +05:30
parent 2a4d266bcc
commit 76d4f7ae29
No known key found for this signature in database
GPG key ID: 3C5A1684DC1729E1
3 changed files with 53 additions and 26 deletions

View file

@ -1,4 +1,5 @@
import "dart:isolate";
import "dart:math";
import "dart:typed_data";
import "package:logging/logging.dart";
@ -47,23 +48,32 @@ class ObjectDetectionService {
static ObjectDetectionService instance =
ObjectDetectionService._privateConstructor();
Future<List<String>> predict(Uint8List bytes) async {
Future<Map<String, double>> predict(Uint8List bytes) async {
try {
if (!inInitiated) {
return Future.error("ObjectDetectionService init is not completed");
}
final results = <String>{};
results.addAll(await _getObjects(bytes));
results.addAll(await _getMobileNetResults(bytes));
results.addAll(await _getSceneResults(bytes));
return results.toList();
final results = <String, double>{};
final methods = [_getObjects, _getMobileNetResults, _getSceneResults];
for (var method in methods) {
final methodResults = await method(bytes);
methodResults.forEach((key, value) {
results.update(
key,
(existingValue) => max(existingValue, value),
ifAbsent: () => value,
);
});
}
return results;
} catch (e, s) {
_logger.severe(e, s);
rethrow;
}
}
Future<List<String>> _getObjects(Uint8List bytes) async {
Future<Map<String, double>> _getObjects(Uint8List bytes) async {
try {
final isolateData = IsolateData(
bytes,
@ -75,10 +85,10 @@ class ObjectDetectionService {
} catch (e, s) {
_logger.severe("Could not run cocossd", e, s);
}
return [];
return {};
}
Future<List<String>> _getMobileNetResults(Uint8List bytes) async {
Future<Map<String, double>> _getMobileNetResults(Uint8List bytes) async {
try {
final isolateData = IsolateData(
bytes,
@ -90,10 +100,10 @@ class ObjectDetectionService {
} catch (e, s) {
_logger.severe("Could not run mobilenet", e, s);
}
return [];
return {};
}
Future<List<String>> _getSceneResults(Uint8List bytes) async {
Future<Map<String, double>> _getSceneResults(Uint8List bytes) async {
try {
final isolateData = IsolateData(
bytes,
@ -105,32 +115,35 @@ class ObjectDetectionService {
} catch (e, s) {
_logger.severe("Could not run scene detection", e, s);
}
return [];
return {};
}
Future<List<String>> _getPredictions(IsolateData isolateData) async {
Future<Map<String, double>> _getPredictions(IsolateData isolateData) async {
final predictions = await _inference(isolateData);
final Set<String> results = {};
final Map<String, double> results = {};
if (predictions.error == null) {
for (final Recognition result in predictions.recognitions!) {
if (result.score > scoreThreshold) {
results.add(result.label);
// Update the result score only if it's higher than the current score
if (!results.containsKey(result.label) ||
results[result.label]! < result.score) {
results[result.label] = result.score;
}
}
}
_logger.info(
"Time taken for " +
isolateData.type.toString() +
": " +
predictions.stats!.totalElapsedTime.toString() +
"ms",
"Time taken for ${isolateData.type}: ${predictions.stats!.totalElapsedTime}ms",
);
} else {
_logger.severe(
"Error while fetching predictions for " + isolateData.type.toString(),
"Error while fetching predictions for ${isolateData.type}",
predictions.error,
);
}
return results.toList();
return results;
}
/// Runs inference in another isolate

View file

@ -1,6 +1,5 @@
import 'dart:math';
import "package:flutter/foundation.dart";
import 'package:image/image.dart' as image_lib;
import "package:logging/logging.dart";
import 'package:photos/services/object_detection/models/predictions.dart';

View file

@ -1,3 +1,4 @@
import "package:flutter/foundation.dart";
import "package:flutter/material.dart";
import "package:logging/logging.dart";
import "package:photos/generated/l10n.dart";
@ -27,7 +28,7 @@ class ObjectsItemWidget extends StatelessWidget {
) async {
try {
final chipButtons = <ChipButtonWidget>[];
var objectTags = <String>[];
var objectTags = <String, double>{};
final thumbnail = await getThumbnail(file);
if (thumbnail != null) {
objectTags = await ObjectDetectionService.instance.predict(thumbnail);
@ -40,9 +41,23 @@ class ObjectsItemWidget extends StatelessWidget {
)
];
}
for (String objectTag in objectTags) {
chipButtons.add(ChipButtonWidget(objectTag));
// sort by values
objectTags = Map.fromEntries(
objectTags.entries.toList()
..sort((e1, e2) => e2.value.compareTo(e1.value)),
);
for (MapEntry<String, double> entry in objectTags.entries) {
chipButtons.add(
ChipButtonWidget(
entry.key +
(kDebugMode
? "-" + (entry.value * 100).round().toString()
: ""),
),
);
}
return chipButtons;
} catch (e, s) {
Logger("ObjctsItemWidget").info(e, s);