ObjectDetection: Return score with label

This commit is contained in:
Neeraj Gupta 2023-05-11 14:14:34 +05:30
parent 2a4d266bcc
commit 76d4f7ae29
No known key found for this signature in database
GPG key ID: 3C5A1684DC1729E1
3 changed files with 53 additions and 26 deletions

View file

@ -1,4 +1,5 @@
import "dart:isolate"; import "dart:isolate";
import "dart:math";
import "dart:typed_data"; import "dart:typed_data";
import "package:logging/logging.dart"; import "package:logging/logging.dart";
@ -47,23 +48,32 @@ class ObjectDetectionService {
static ObjectDetectionService instance = static ObjectDetectionService instance =
ObjectDetectionService._privateConstructor(); ObjectDetectionService._privateConstructor();
Future<List<String>> predict(Uint8List bytes) async { Future<Map<String, double>> predict(Uint8List bytes) async {
try { try {
if (!inInitiated) { if (!inInitiated) {
return Future.error("ObjectDetectionService init is not completed"); return Future.error("ObjectDetectionService init is not completed");
} }
final results = <String>{}; final results = <String, double>{};
results.addAll(await _getObjects(bytes)); final methods = [_getObjects, _getMobileNetResults, _getSceneResults];
results.addAll(await _getMobileNetResults(bytes));
results.addAll(await _getSceneResults(bytes)); for (var method in methods) {
return results.toList(); final methodResults = await method(bytes);
methodResults.forEach((key, value) {
results.update(
key,
(existingValue) => max(existingValue, value),
ifAbsent: () => value,
);
});
}
return results;
} catch (e, s) { } catch (e, s) {
_logger.severe(e, s); _logger.severe(e, s);
rethrow; rethrow;
} }
} }
Future<List<String>> _getObjects(Uint8List bytes) async { Future<Map<String, double>> _getObjects(Uint8List bytes) async {
try { try {
final isolateData = IsolateData( final isolateData = IsolateData(
bytes, bytes,
@ -75,10 +85,10 @@ class ObjectDetectionService {
} catch (e, s) { } catch (e, s) {
_logger.severe("Could not run cocossd", e, s); _logger.severe("Could not run cocossd", e, s);
} }
return []; return {};
} }
Future<List<String>> _getMobileNetResults(Uint8List bytes) async { Future<Map<String, double>> _getMobileNetResults(Uint8List bytes) async {
try { try {
final isolateData = IsolateData( final isolateData = IsolateData(
bytes, bytes,
@ -90,10 +100,10 @@ class ObjectDetectionService {
} catch (e, s) { } catch (e, s) {
_logger.severe("Could not run mobilenet", e, s); _logger.severe("Could not run mobilenet", e, s);
} }
return []; return {};
} }
Future<List<String>> _getSceneResults(Uint8List bytes) async { Future<Map<String, double>> _getSceneResults(Uint8List bytes) async {
try { try {
final isolateData = IsolateData( final isolateData = IsolateData(
bytes, bytes,
@ -105,32 +115,35 @@ class ObjectDetectionService {
} catch (e, s) { } catch (e, s) {
_logger.severe("Could not run scene detection", e, s); _logger.severe("Could not run scene detection", e, s);
} }
return []; return {};
} }
Future<List<String>> _getPredictions(IsolateData isolateData) async { Future<Map<String, double>> _getPredictions(IsolateData isolateData) async {
final predictions = await _inference(isolateData); final predictions = await _inference(isolateData);
final Set<String> results = {}; final Map<String, double> results = {};
if (predictions.error == null) { if (predictions.error == null) {
for (final Recognition result in predictions.recognitions!) { for (final Recognition result in predictions.recognitions!) {
if (result.score > scoreThreshold) { if (result.score > scoreThreshold) {
results.add(result.label); // Update the result score only if it's higher than the current score
if (!results.containsKey(result.label) ||
results[result.label]! < result.score) {
results[result.label] = result.score;
}
} }
} }
_logger.info( _logger.info(
"Time taken for " + "Time taken for ${isolateData.type}: ${predictions.stats!.totalElapsedTime}ms",
isolateData.type.toString() +
": " +
predictions.stats!.totalElapsedTime.toString() +
"ms",
); );
} else { } else {
_logger.severe( _logger.severe(
"Error while fetching predictions for " + isolateData.type.toString(), "Error while fetching predictions for ${isolateData.type}",
predictions.error, predictions.error,
); );
} }
return results.toList();
return results;
} }
/// Runs inference in another isolate /// Runs inference in another isolate

View file

@ -1,6 +1,5 @@
import 'dart:math'; import 'dart:math';
import "package:flutter/foundation.dart";
import 'package:image/image.dart' as image_lib; import 'package:image/image.dart' as image_lib;
import "package:logging/logging.dart"; import "package:logging/logging.dart";
import 'package:photos/services/object_detection/models/predictions.dart'; import 'package:photos/services/object_detection/models/predictions.dart';

View file

@ -1,3 +1,4 @@
import "package:flutter/foundation.dart";
import "package:flutter/material.dart"; import "package:flutter/material.dart";
import "package:logging/logging.dart"; import "package:logging/logging.dart";
import "package:photos/generated/l10n.dart"; import "package:photos/generated/l10n.dart";
@ -27,7 +28,7 @@ class ObjectsItemWidget extends StatelessWidget {
) async { ) async {
try { try {
final chipButtons = <ChipButtonWidget>[]; final chipButtons = <ChipButtonWidget>[];
var objectTags = <String>[]; var objectTags = <String, double>{};
final thumbnail = await getThumbnail(file); final thumbnail = await getThumbnail(file);
if (thumbnail != null) { if (thumbnail != null) {
objectTags = await ObjectDetectionService.instance.predict(thumbnail); objectTags = await ObjectDetectionService.instance.predict(thumbnail);
@ -40,9 +41,23 @@ class ObjectsItemWidget extends StatelessWidget {
) )
]; ];
} }
for (String objectTag in objectTags) { // sort by values
chipButtons.add(ChipButtonWidget(objectTag)); objectTags = Map.fromEntries(
objectTags.entries.toList()
..sort((e1, e2) => e2.value.compareTo(e1.value)),
);
for (MapEntry<String, double> entry in objectTags.entries) {
chipButtons.add(
ChipButtonWidget(
entry.key +
(kDebugMode
? "-" + (entry.value * 100).round().toString()
: ""),
),
);
} }
return chipButtons; return chipButtons;
} catch (e, s) { } catch (e, s) {
Logger("ObjctsItemWidget").info(e, s); Logger("ObjctsItemWidget").info(e, s);