ObjectDetection: Return score with label
This commit is contained in:
parent
2a4d266bcc
commit
76d4f7ae29
|
@ -1,4 +1,5 @@
|
||||||
import "dart:isolate";
|
import "dart:isolate";
|
||||||
|
import "dart:math";
|
||||||
import "dart:typed_data";
|
import "dart:typed_data";
|
||||||
|
|
||||||
import "package:logging/logging.dart";
|
import "package:logging/logging.dart";
|
||||||
|
@ -47,23 +48,32 @@ class ObjectDetectionService {
|
||||||
static ObjectDetectionService instance =
|
static ObjectDetectionService instance =
|
||||||
ObjectDetectionService._privateConstructor();
|
ObjectDetectionService._privateConstructor();
|
||||||
|
|
||||||
Future<List<String>> predict(Uint8List bytes) async {
|
Future<Map<String, double>> predict(Uint8List bytes) async {
|
||||||
try {
|
try {
|
||||||
if (!inInitiated) {
|
if (!inInitiated) {
|
||||||
return Future.error("ObjectDetectionService init is not completed");
|
return Future.error("ObjectDetectionService init is not completed");
|
||||||
}
|
}
|
||||||
final results = <String>{};
|
final results = <String, double>{};
|
||||||
results.addAll(await _getObjects(bytes));
|
final methods = [_getObjects, _getMobileNetResults, _getSceneResults];
|
||||||
results.addAll(await _getMobileNetResults(bytes));
|
|
||||||
results.addAll(await _getSceneResults(bytes));
|
for (var method in methods) {
|
||||||
return results.toList();
|
final methodResults = await method(bytes);
|
||||||
|
methodResults.forEach((key, value) {
|
||||||
|
results.update(
|
||||||
|
key,
|
||||||
|
(existingValue) => max(existingValue, value),
|
||||||
|
ifAbsent: () => value,
|
||||||
|
);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
return results;
|
||||||
} catch (e, s) {
|
} catch (e, s) {
|
||||||
_logger.severe(e, s);
|
_logger.severe(e, s);
|
||||||
rethrow;
|
rethrow;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Future<List<String>> _getObjects(Uint8List bytes) async {
|
Future<Map<String, double>> _getObjects(Uint8List bytes) async {
|
||||||
try {
|
try {
|
||||||
final isolateData = IsolateData(
|
final isolateData = IsolateData(
|
||||||
bytes,
|
bytes,
|
||||||
|
@ -75,10 +85,10 @@ class ObjectDetectionService {
|
||||||
} catch (e, s) {
|
} catch (e, s) {
|
||||||
_logger.severe("Could not run cocossd", e, s);
|
_logger.severe("Could not run cocossd", e, s);
|
||||||
}
|
}
|
||||||
return [];
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
Future<List<String>> _getMobileNetResults(Uint8List bytes) async {
|
Future<Map<String, double>> _getMobileNetResults(Uint8List bytes) async {
|
||||||
try {
|
try {
|
||||||
final isolateData = IsolateData(
|
final isolateData = IsolateData(
|
||||||
bytes,
|
bytes,
|
||||||
|
@ -90,10 +100,10 @@ class ObjectDetectionService {
|
||||||
} catch (e, s) {
|
} catch (e, s) {
|
||||||
_logger.severe("Could not run mobilenet", e, s);
|
_logger.severe("Could not run mobilenet", e, s);
|
||||||
}
|
}
|
||||||
return [];
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
Future<List<String>> _getSceneResults(Uint8List bytes) async {
|
Future<Map<String, double>> _getSceneResults(Uint8List bytes) async {
|
||||||
try {
|
try {
|
||||||
final isolateData = IsolateData(
|
final isolateData = IsolateData(
|
||||||
bytes,
|
bytes,
|
||||||
|
@ -105,32 +115,35 @@ class ObjectDetectionService {
|
||||||
} catch (e, s) {
|
} catch (e, s) {
|
||||||
_logger.severe("Could not run scene detection", e, s);
|
_logger.severe("Could not run scene detection", e, s);
|
||||||
}
|
}
|
||||||
return [];
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
Future<List<String>> _getPredictions(IsolateData isolateData) async {
|
Future<Map<String, double>> _getPredictions(IsolateData isolateData) async {
|
||||||
final predictions = await _inference(isolateData);
|
final predictions = await _inference(isolateData);
|
||||||
final Set<String> results = {};
|
final Map<String, double> results = {};
|
||||||
|
|
||||||
if (predictions.error == null) {
|
if (predictions.error == null) {
|
||||||
for (final Recognition result in predictions.recognitions!) {
|
for (final Recognition result in predictions.recognitions!) {
|
||||||
if (result.score > scoreThreshold) {
|
if (result.score > scoreThreshold) {
|
||||||
results.add(result.label);
|
// Update the result score only if it's higher than the current score
|
||||||
|
if (!results.containsKey(result.label) ||
|
||||||
|
results[result.label]! < result.score) {
|
||||||
|
results[result.label] = result.score;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
_logger.info(
|
_logger.info(
|
||||||
"Time taken for " +
|
"Time taken for ${isolateData.type}: ${predictions.stats!.totalElapsedTime}ms",
|
||||||
isolateData.type.toString() +
|
|
||||||
": " +
|
|
||||||
predictions.stats!.totalElapsedTime.toString() +
|
|
||||||
"ms",
|
|
||||||
);
|
);
|
||||||
} else {
|
} else {
|
||||||
_logger.severe(
|
_logger.severe(
|
||||||
"Error while fetching predictions for " + isolateData.type.toString(),
|
"Error while fetching predictions for ${isolateData.type}",
|
||||||
predictions.error,
|
predictions.error,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
return results.toList();
|
|
||||||
|
return results;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Runs inference in another isolate
|
/// Runs inference in another isolate
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
import 'dart:math';
|
import 'dart:math';
|
||||||
|
|
||||||
import "package:flutter/foundation.dart";
|
|
||||||
import 'package:image/image.dart' as image_lib;
|
import 'package:image/image.dart' as image_lib;
|
||||||
import "package:logging/logging.dart";
|
import "package:logging/logging.dart";
|
||||||
import 'package:photos/services/object_detection/models/predictions.dart';
|
import 'package:photos/services/object_detection/models/predictions.dart';
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
import "package:flutter/foundation.dart";
|
||||||
import "package:flutter/material.dart";
|
import "package:flutter/material.dart";
|
||||||
import "package:logging/logging.dart";
|
import "package:logging/logging.dart";
|
||||||
import "package:photos/generated/l10n.dart";
|
import "package:photos/generated/l10n.dart";
|
||||||
|
@ -27,7 +28,7 @@ class ObjectsItemWidget extends StatelessWidget {
|
||||||
) async {
|
) async {
|
||||||
try {
|
try {
|
||||||
final chipButtons = <ChipButtonWidget>[];
|
final chipButtons = <ChipButtonWidget>[];
|
||||||
var objectTags = <String>[];
|
var objectTags = <String, double>{};
|
||||||
final thumbnail = await getThumbnail(file);
|
final thumbnail = await getThumbnail(file);
|
||||||
if (thumbnail != null) {
|
if (thumbnail != null) {
|
||||||
objectTags = await ObjectDetectionService.instance.predict(thumbnail);
|
objectTags = await ObjectDetectionService.instance.predict(thumbnail);
|
||||||
|
@ -40,9 +41,23 @@ class ObjectsItemWidget extends StatelessWidget {
|
||||||
)
|
)
|
||||||
];
|
];
|
||||||
}
|
}
|
||||||
for (String objectTag in objectTags) {
|
// sort by values
|
||||||
chipButtons.add(ChipButtonWidget(objectTag));
|
objectTags = Map.fromEntries(
|
||||||
|
objectTags.entries.toList()
|
||||||
|
..sort((e1, e2) => e2.value.compareTo(e1.value)),
|
||||||
|
);
|
||||||
|
|
||||||
|
for (MapEntry<String, double> entry in objectTags.entries) {
|
||||||
|
chipButtons.add(
|
||||||
|
ChipButtonWidget(
|
||||||
|
entry.key +
|
||||||
|
(kDebugMode
|
||||||
|
? "-" + (entry.value * 100).round().toString()
|
||||||
|
: ""),
|
||||||
|
),
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
return chipButtons;
|
return chipButtons;
|
||||||
} catch (e, s) {
|
} catch (e, s) {
|
||||||
Logger("ObjctsItemWidget").info(e, s);
|
Logger("ObjctsItemWidget").info(e, s);
|
||||||
|
|
Loading…
Reference in a new issue