Run ML only when device is healthy (#1728)

This commit is contained in:
Vishnu Mohandas 2024-02-18 17:36:03 +05:30 committed by GitHub
commit 8f781915e1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
18 changed files with 160 additions and 66 deletions

View file

@ -13,7 +13,7 @@ import 'package:photos/ente_theme_data.dart';
import "package:photos/generated/l10n.dart";
import "package:photos/l10n/l10n.dart";
import 'package:photos/services/app_lifecycle_service.dart';
import "package:photos/services/semantic_search/semantic_search_service.dart";
import "package:photos/services/machine_learning/machine_learning_controller.dart";
import 'package:photos/services/sync_service.dart';
import 'package:photos/ui/tabs/home_widget.dart';
import "package:photos/ui/viewer/actions/file_viewer.dart";
@ -43,12 +43,8 @@ class EnteApp extends StatefulWidget {
}
class _EnteAppState extends State<EnteApp> with WidgetsBindingObserver {
static const initialInteractionTimeout = Duration(seconds: 10);
static const defaultInteractionTimeout = Duration(seconds: 5);
final _logger = Logger("EnteAppState");
late Locale locale;
late Timer _userInteractionTimer;
@override
void initState() {
@ -57,7 +53,6 @@ class _EnteAppState extends State<EnteApp> with WidgetsBindingObserver {
locale = widget.locale;
setupIntentAction();
WidgetsBinding.instance.addObserver(this);
_setupInteractionTimer(timeout: initialInteractionTimeout);
}
setLocale(Locale newLocale) {
@ -76,30 +71,12 @@ class _EnteAppState extends State<EnteApp> with WidgetsBindingObserver {
}
}
void _resetTimer() {
_userInteractionTimer.cancel();
_setupInteractionTimer();
}
void _setupInteractionTimer({Duration timeout = defaultInteractionTimeout}) {
if (Platform.isAndroid || kDebugMode) {
_userInteractionTimer = Timer(timeout, () {
debugPrint("user is not interacting with the app");
SemanticSearchService.instance.startIndexing();
});
} else {
SemanticSearchService.instance.startIndexing();
}
}
@override
Widget build(BuildContext context) {
if (Platform.isAndroid || kDebugMode) {
return Listener(
onPointerDown: (event) {
SemanticSearchService.instance.pauseIndexing();
debugPrint("user is interacting with the app");
_resetTimer();
MachineLearningController.instance.onUserInteraction();
},
child: AdaptiveTheme(
light: lightThemeData,
@ -149,7 +126,6 @@ class _EnteAppState extends State<EnteApp> with WidgetsBindingObserver {
@override
void dispose() {
WidgetsBinding.instance.removeObserver(this);
_userInteractionTimer.cancel();
super.dispose();
}

View file

@ -25,9 +25,9 @@ import 'package:photos/services/billing_service.dart';
import 'package:photos/services/collections_service.dart';
import 'package:photos/services/favorites_service.dart';
import 'package:photos/services/ignored_files_service.dart';
import 'package:photos/services/machine_learning/semantic_search/semantic_search_service.dart';
import 'package:photos/services/memories_service.dart';
import 'package:photos/services/search_service.dart';
import "package:photos/services/semantic_search/semantic_search_service.dart";
import 'package:photos/services/sync_service.dart';
import 'package:photos/utils/crypto_util.dart';
import 'package:photos/utils/file_uploader.dart';

View file

@ -0,0 +1,7 @@
import "package:photos/events/event.dart";
class MachineLearningControlEvent extends Event {
final bool shouldRun;
MachineLearningControlEvent(this.shouldRun);
}

View file

@ -30,11 +30,12 @@ import 'package:photos/services/feature_flag_service.dart';
import 'package:photos/services/local_file_update_service.dart';
import 'package:photos/services/local_sync_service.dart';
import "package:photos/services/location_service.dart";
import "package:photos/services/machine_learning/machine_learning_controller.dart";
import 'package:photos/services/machine_learning/semantic_search/semantic_search_service.dart';
import 'package:photos/services/memories_service.dart';
import 'package:photos/services/push_service.dart';
import 'package:photos/services/remote_sync_service.dart';
import 'package:photos/services/search_service.dart';
import 'package:photos/services/semantic_search/semantic_search_service.dart';
import "package:photos/services/storage_bonus_service.dart";
import 'package:photos/services/sync_service.dart';
import 'package:photos/services/trash_sync_service.dart';
@ -193,7 +194,8 @@ Future<void> _init(bool isBackground, {String via = ''}) async {
});
}
unawaited(FeatureFlagService.instance.init());
unawaited(SemanticSearchService.instance.init(isInBackground: isBackground));
unawaited(SemanticSearchService.instance.init());
MachineLearningController.instance.init();
// Can not including existing tf/ml binaries as they are not being built
// from source.
// See https://gitlab.com/fdroid/fdroiddata/-/merge_requests/12671#note_1294346819

View file

@ -0,0 +1,97 @@
import "dart:async";
import "dart:io";
import "package:battery_info/battery_info_plugin.dart";
import "package:battery_info/model/android_battery_info.dart";
import "package:logging/logging.dart";
import "package:photos/core/event_bus.dart";
import "package:photos/events/machine_learning_control_event.dart";
class MachineLearningController {
MachineLearningController._privateConstructor();
static final MachineLearningController instance =
MachineLearningController._privateConstructor();
final _logger = Logger("MachineLearningController");
static const kMaximumTemperature = 36; // 36 degree celsius
static const kMinimumBatteryLevel = 20; // 20%
static const kInitialInteractionTimeout = Duration(seconds: 10);
static const kDefaultInteractionTimeout = Duration(seconds: 5);
static const kUnhealthyStates = ["over_heat", "over_voltage", "dead"];
bool _isDeviceHealthy = true;
bool _isUserInteracting = true;
bool _isRunningML = false;
late Timer _userInteractionTimer;
void init() {
if (Platform.isAndroid) {
_startInteractionTimer(timeout: kInitialInteractionTimeout);
BatteryInfoPlugin()
.androidBatteryInfoStream
.listen((AndroidBatteryInfo? batteryInfo) {
_onBatteryStateUpdate(batteryInfo);
});
} else {
// Always run Machine Learning on iOS
Bus.instance.fire(MachineLearningControlEvent(true));
}
}
void onUserInteraction() {
_logger.info("User is interacting with the app");
_isUserInteracting = true;
_fireControlEvent();
_resetTimer();
}
void _fireControlEvent() {
final shouldRunML = _isDeviceHealthy && !_isUserInteracting;
if (shouldRunML != _isRunningML) {
_isRunningML = shouldRunML;
_logger.info(
"Firing event with device health: $_isDeviceHealthy and user interaction: $_isUserInteracting",
);
Bus.instance.fire(MachineLearningControlEvent(shouldRunML));
}
}
void _startInteractionTimer({Duration timeout = kDefaultInteractionTimeout}) {
_userInteractionTimer = Timer(timeout, () {
_logger.info("User is not interacting with the app");
_isUserInteracting = false;
_fireControlEvent();
});
}
void _resetTimer() {
_userInteractionTimer.cancel();
_startInteractionTimer();
}
void _onBatteryStateUpdate(AndroidBatteryInfo? batteryInfo) {
_logger.info("Battery info: ${batteryInfo!.toJson()}");
_isDeviceHealthy = _computeIsDeviceHealthy(batteryInfo);
_fireControlEvent();
}
bool _computeIsDeviceHealthy(AndroidBatteryInfo info) {
return _hasSufficientBattery(info.batteryLevel ?? kMinimumBatteryLevel) &&
_isAcceptableTemperature(info.temperature ?? kMaximumTemperature) &&
_isBatteryHealthy(info.health ?? "");
}
bool _hasSufficientBattery(int batteryLevel) {
return batteryLevel >= kMinimumBatteryLevel;
}
bool _isAcceptableTemperature(int temperature) {
return temperature <= kMaximumTemperature;
}
bool _isBatteryHealthy(String health) {
return !kUnhealthyStates.contains(health);
}
}

View file

@ -9,7 +9,7 @@ import "package:photos/db/embeddings_db.dart";
import "package:photos/db/files_db.dart";
import "package:photos/models/embedding.dart";
import "package:photos/models/file/file.dart";
import "package:photos/services/semantic_search/remote_embedding.dart";
import 'package:photos/services/machine_learning/semantic_search/remote_embedding.dart';
import "package:photos/utils/crypto_util.dart";
import "package:photos/utils/file_download_util.dart";
import "package:shared_preferences/shared_preferences.dart";

View file

@ -1,7 +1,7 @@
import "package:clip_ggml/clip_ggml.dart";
import "package:computer/computer.dart";
import "package:logging/logging.dart";
import 'package:photos/services/semantic_search/frameworks/ml_framework.dart';
import 'package:photos/services/machine_learning/semantic_search/frameworks/ml_framework.dart';
class GGML extends MLFramework {
static const kModelBucketEndpoint = "https://models.ente.io/";

View file

@ -1,9 +1,9 @@
import "package:computer/computer.dart";
import "package:logging/logging.dart";
import "package:onnxruntime/onnxruntime.dart";
import "package:photos/services/semantic_search/frameworks/ml_framework.dart";
import "package:photos/services/semantic_search/frameworks/onnx/onnx_image_encoder.dart";
import "package:photos/services/semantic_search/frameworks/onnx/onnx_text_encoder.dart";
import 'package:photos/services/machine_learning/semantic_search/frameworks/ml_framework.dart';
import 'package:photos/services/machine_learning/semantic_search/frameworks/onnx/onnx_image_encoder.dart';
import 'package:photos/services/machine_learning/semantic_search/frameworks/onnx/onnx_text_encoder.dart';
class ONNX extends MLFramework {
static const kModelBucketEndpoint = "https://models.ente.io/";

View file

@ -5,7 +5,7 @@ import "dart:typed_data";
import "package:flutter/services.dart";
import "package:logging/logging.dart";
import "package:onnxruntime/onnxruntime.dart";
import "package:photos/services/semantic_search/frameworks/onnx/onnx_text_tokenizer.dart";
import 'package:photos/services/machine_learning/semantic_search/frameworks/onnx/onnx_text_tokenizer.dart';
class OnnxTextEncoder {
static const kVocabFilePath = "assets/models/clip/bpe_simple_vocab_16e6.txt";

View file

@ -11,13 +11,14 @@ import "package:photos/db/files_db.dart";
import "package:photos/events/diff_sync_complete_event.dart";
import 'package:photos/events/embedding_updated_event.dart';
import "package:photos/events/file_uploaded_event.dart";
import "package:photos/events/machine_learning_control_event.dart";
import "package:photos/models/embedding.dart";
import "package:photos/models/file/file.dart";
import "package:photos/services/collections_service.dart";
import "package:photos/services/semantic_search/embedding_store.dart";
import "package:photos/services/semantic_search/frameworks/ggml.dart";
import "package:photos/services/semantic_search/frameworks/ml_framework.dart";
import 'package:photos/services/semantic_search/frameworks/onnx/onnx.dart';
import 'package:photos/services/machine_learning/semantic_search/embedding_store.dart';
import 'package:photos/services/machine_learning/semantic_search/frameworks/ggml.dart';
import 'package:photos/services/machine_learning/semantic_search/frameworks/ml_framework.dart';
import 'package:photos/services/machine_learning/semantic_search/frameworks/onnx/onnx.dart';
import "package:photos/utils/debouncer.dart";
import "package:photos/utils/device_info.dart";
import "package:photos/utils/local_settings.dart";
@ -50,26 +51,11 @@ class SemanticSearchService {
Future<List<EnteFile>>? _ongoingRequest;
List<Embedding> _cachedEmbeddings = <Embedding>[];
PendingQuery? _nextQuery;
Completer<void> _userInteraction = Completer<void>();
Completer<void> _mlController = Completer<void>();
get hasInitialized => _hasInitialized;
void startIndexing() {
_logger.info("Start indexing");
_userInteraction.complete();
}
void pauseIndexing() {
if (_userInteraction.isCompleted) {
_logger.info("Pausing indexing");
_userInteraction = Completer<void>();
}
}
Future<void> init({
bool shouldSyncImmediately = false,
bool isInBackground = false,
}) async {
Future<void> init({bool shouldSyncImmediately = false}) async {
if (!LocalSettings.instance.hasEnabledMagicSearch()) {
return;
}
@ -114,10 +100,13 @@ class SemanticSearchService {
if (shouldSyncImmediately) {
unawaited(sync());
}
if (isInBackground) {
// Do not block on user interactions
startIndexing();
}
Bus.instance.on<MachineLearningControlEvent>().listen((event) {
if (event.shouldRun) {
_startIndexing();
} else {
_pauseIndexing();
}
});
}
Future<void> release() async {
@ -301,9 +290,9 @@ class SemanticSearchService {
if (!_frameworkInitialization.isCompleted) {
return;
}
if (!_userInteraction.isCompleted) {
_logger.info("Waiting for user interactions to stop...");
await _userInteraction.future;
if (!_mlController.isCompleted) {
_logger.info("Waiting for a green signal from controller...");
await _mlController.future;
}
try {
final thumbnail = await getThumbnailForUploadedFile(file);
@ -376,6 +365,20 @@ class SemanticSearchService {
return Model.onnxClip;
}
}
void _startIndexing() {
_logger.info("Start indexing");
if (!_mlController.isCompleted) {
_mlController.complete();
}
}
void _pauseIndexing() {
if (_mlController.isCompleted) {
_logger.info("Pausing indexing");
_mlController = Completer<void>();
}
}
}
List<QueryResult> computeBulkScore(Map args) {

View file

@ -25,7 +25,7 @@ import 'package:photos/models/search/generic_search_result.dart';
import "package:photos/models/search/search_types.dart";
import 'package:photos/services/collections_service.dart';
import "package:photos/services/location_service.dart";
import 'package:photos/services/semantic_search/semantic_search_service.dart';
import 'package:photos/services/machine_learning/semantic_search/semantic_search_service.dart';
import "package:photos/states/location_screen_state.dart";
import "package:photos/ui/viewer/location/add_location_sheet.dart";
import "package:photos/ui/viewer/location/location_screen.dart";

View file

@ -6,8 +6,8 @@ import "package:photos/core/event_bus.dart";
import 'package:photos/events/embedding_updated_event.dart';
import "package:photos/generated/l10n.dart";
import "package:photos/services/feature_flag_service.dart";
import "package:photos/services/semantic_search/frameworks/ml_framework.dart";
import "package:photos/services/semantic_search/semantic_search_service.dart";
import 'package:photos/services/machine_learning/semantic_search/frameworks/ml_framework.dart';
import 'package:photos/services/machine_learning/semantic_search/semantic_search_service.dart';
import "package:photos/theme/ente_theme.dart";
import "package:photos/ui/common/loading_widget.dart";
import "package:photos/ui/components/buttons/icon_button_widget.dart";

View file

@ -89,6 +89,14 @@ packages:
url: "https://pub.dev"
source: hosted
version: "1.2.1"
battery_info:
dependency: "direct main"
description:
name: battery_info
sha256: "5d5249c87a600a0a20d6b2f5ffdf90d711bccb1bfd3a58e5a6228f270031c680"
url: "https://pub.dev"
source: hosted
version: "1.1.1"
bip39:
dependency: "direct main"
description:

View file

@ -23,6 +23,7 @@ dependencies:
animated_list_plus: ^0.4.5
archive: ^3.1.2
background_fetch: ^1.2.1
battery_info: ^1.1.1
bip39: ^1.0.6
cached_network_image: ^3.0.0
chewie: