diff --git a/lib/services/object_detection/object_detection_service.dart b/lib/services/object_detection/object_detection_service.dart index c5b0acbc6..f1e655467 100644 --- a/lib/services/object_detection/object_detection_service.dart +++ b/lib/services/object_detection/object_detection_service.dart @@ -1,4 +1,5 @@ import "dart:isolate"; +import "dart:math"; import "dart:typed_data"; import "package:logging/logging.dart"; @@ -47,23 +48,32 @@ class ObjectDetectionService { static ObjectDetectionService instance = ObjectDetectionService._privateConstructor(); - Future> predict(Uint8List bytes) async { + Future> predict(Uint8List bytes) async { try { if (!inInitiated) { return Future.error("ObjectDetectionService init is not completed"); } - final results = {}; - results.addAll(await _getObjects(bytes)); - results.addAll(await _getMobileNetResults(bytes)); - results.addAll(await _getSceneResults(bytes)); - return results.toList(); + final results = {}; + final methods = [_getObjects, _getMobileNetResults, _getSceneResults]; + + for (var method in methods) { + final methodResults = await method(bytes); + methodResults.forEach((key, value) { + results.update( + key, + (existingValue) => max(existingValue, value), + ifAbsent: () => value, + ); + }); + } + return results; } catch (e, s) { _logger.severe(e, s); rethrow; } } - Future> _getObjects(Uint8List bytes) async { + Future> _getObjects(Uint8List bytes) async { try { final isolateData = IsolateData( bytes, @@ -75,10 +85,10 @@ class ObjectDetectionService { } catch (e, s) { _logger.severe("Could not run cocossd", e, s); } - return []; + return {}; } - Future> _getMobileNetResults(Uint8List bytes) async { + Future> _getMobileNetResults(Uint8List bytes) async { try { final isolateData = IsolateData( bytes, @@ -90,10 +100,10 @@ class ObjectDetectionService { } catch (e, s) { _logger.severe("Could not run mobilenet", e, s); } - return []; + return {}; } - Future> _getSceneResults(Uint8List bytes) async { + Future> _getSceneResults(Uint8List bytes) async { try { final isolateData = IsolateData( bytes, @@ -105,32 +115,35 @@ class ObjectDetectionService { } catch (e, s) { _logger.severe("Could not run scene detection", e, s); } - return []; + return {}; } - Future> _getPredictions(IsolateData isolateData) async { + Future> _getPredictions(IsolateData isolateData) async { final predictions = await _inference(isolateData); - final Set results = {}; + final Map results = {}; + if (predictions.error == null) { for (final Recognition result in predictions.recognitions!) { if (result.score > scoreThreshold) { - results.add(result.label); + // Update the result score only if it's higher than the current score + if (!results.containsKey(result.label) || + results[result.label]! < result.score) { + results[result.label] = result.score; + } } } + _logger.info( - "Time taken for " + - isolateData.type.toString() + - ": " + - predictions.stats!.totalElapsedTime.toString() + - "ms", + "Time taken for ${isolateData.type}: ${predictions.stats!.totalElapsedTime}ms", ); } else { _logger.severe( - "Error while fetching predictions for " + isolateData.type.toString(), + "Error while fetching predictions for ${isolateData.type}", predictions.error, ); } - return results.toList(); + + return results; } /// Runs inference in another isolate diff --git a/lib/services/object_detection/tflite/cocossd_classifier.dart b/lib/services/object_detection/tflite/cocossd_classifier.dart index 7656c6414..919dd8383 100644 --- a/lib/services/object_detection/tflite/cocossd_classifier.dart +++ b/lib/services/object_detection/tflite/cocossd_classifier.dart @@ -1,6 +1,5 @@ import 'dart:math'; -import "package:flutter/foundation.dart"; import 'package:image/image.dart' as image_lib; import "package:logging/logging.dart"; import 'package:photos/services/object_detection/models/predictions.dart'; diff --git a/lib/ui/viewer/file_details/objects_item_widget.dart b/lib/ui/viewer/file_details/objects_item_widget.dart index d9dc805d4..d1e9c89b2 100644 --- a/lib/ui/viewer/file_details/objects_item_widget.dart +++ b/lib/ui/viewer/file_details/objects_item_widget.dart @@ -1,3 +1,4 @@ +import "package:flutter/foundation.dart"; import "package:flutter/material.dart"; import "package:logging/logging.dart"; import "package:photos/generated/l10n.dart"; @@ -27,7 +28,7 @@ class ObjectsItemWidget extends StatelessWidget { ) async { try { final chipButtons = []; - var objectTags = []; + var objectTags = {}; final thumbnail = await getThumbnail(file); if (thumbnail != null) { objectTags = await ObjectDetectionService.instance.predict(thumbnail); @@ -40,9 +41,23 @@ class ObjectsItemWidget extends StatelessWidget { ) ]; } - for (String objectTag in objectTags) { - chipButtons.add(ChipButtonWidget(objectTag)); + // sort by values + objectTags = Map.fromEntries( + objectTags.entries.toList() + ..sort((e1, e2) => e2.value.compareTo(e1.value)), + ); + + for (MapEntry entry in objectTags.entries) { + chipButtons.add( + ChipButtonWidget( + entry.key + + (kDebugMode + ? "-" + (entry.value * 100).round().toString() + : ""), + ), + ); } + return chipButtons; } catch (e, s) { Logger("ObjctsItemWidget").info(e, s);