Use GGML on Graphene (#1717)

Use GGML on Graphene
This commit is contained in:
Vishnu Mohandas 2024-02-14 20:37:02 +05:30 committed by GitHub
commit 9144e003e2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 27 additions and 9 deletions

View file

@ -666,7 +666,7 @@ class MessageLookup extends MessageLookupByLibrary {
MessageLookupByLibrary.simpleMessage("项目显示永久删除前剩余的天数"),
"itemsWillBeRemovedFromAlbum":
MessageLookupByLibrary.simpleMessage("所选项目将从此相册中移除"),
"joinDiscord": MessageLookupByLibrary.simpleMessage("Join Discord"),
"joinDiscord": MessageLookupByLibrary.simpleMessage("加入 Discord"),
"keepPhotos": MessageLookupByLibrary.simpleMessage("保留照片"),
"kiloMeterUnit": MessageLookupByLibrary.simpleMessage("公里"),
"kindlyHelpUsWithThisInformation":

View file

@ -19,6 +19,7 @@ import "package:photos/services/semantic_search/frameworks/ggml.dart";
import "package:photos/services/semantic_search/frameworks/ml_framework.dart";
import 'package:photos/services/semantic_search/frameworks/onnx/onnx.dart';
import "package:photos/utils/debouncer.dart";
import "package:photos/utils/device_info.dart";
import "package:photos/utils/local_settings.dart";
import "package:photos/utils/thumbnail_util.dart";
@ -33,7 +34,6 @@ class SemanticSearchService {
static const kEmbeddingLength = 512;
static const kScoreThreshold = 0.23;
static const kShouldPushEmbeddings = true;
static const kCurrentModel = Model.onnxClip;
static const kDebounceDuration = Duration(milliseconds: 4000);
final _logger = Logger("SemanticSearchService");
@ -42,6 +42,7 @@ class SemanticSearchService {
final _embeddingLoaderDebouncer =
Debouncer(kDebounceDuration, executionInterval: kDebounceDuration);
late Model _currentModel;
late MLFramework _mlFramework;
bool _hasInitialized = false;
bool _isComputingEmbeddings = false;
@ -76,7 +77,8 @@ class SemanticSearchService {
_hasInitialized = true;
final shouldDownloadOverMobileData =
Configuration.instance.shouldBackupOverMobileData();
_mlFramework = kCurrentModel == Model.onnxClip
_currentModel = await _getCurrentModel();
_mlFramework = _currentModel == Model.onnxClip
? ONNX(shouldDownloadOverMobileData)
: GGML(shouldDownloadOverMobileData);
await EmbeddingsDB.instance.init();
@ -122,7 +124,7 @@ class SemanticSearchService {
return;
}
_isSyncing = true;
await EmbeddingStore.instance.pullEmbeddings(kCurrentModel);
await EmbeddingStore.instance.pullEmbeddings(_currentModel);
await _backFill();
_isSyncing = false;
}
@ -171,14 +173,14 @@ class SemanticSearchService {
}
/// Asks [EmbeddingStore] to clear the embeddings stored for the model in use,
/// then logs which model's indexes were cleared.
Future<void> clearIndexes() async {
await EmbeddingStore.instance.clearEmbeddings(kCurrentModel);
_logger.info("Indexes cleared for $kCurrentModel");
await EmbeddingStore.instance.clearEmbeddings(_currentModel);
_logger.info("Indexes cleared for $_currentModel");
}
Future<void> _loadEmbeddings() async {
_logger.info("Pulling cached embeddings");
final startTime = DateTime.now();
_cachedEmbeddings = await EmbeddingsDB.instance.getAll(kCurrentModel);
_cachedEmbeddings = await EmbeddingsDB.instance.getAll(_currentModel);
final endTime = DateTime.now();
_logger.info(
"Loading ${_cachedEmbeddings.length} took: ${(endTime.millisecondsSinceEpoch - startTime.millisecondsSinceEpoch)}ms",
@ -312,7 +314,7 @@ class SemanticSearchService {
final embedding = Embedding(
fileID: file.uploadedFileID!,
model: kCurrentModel,
model: _currentModel,
embedding: result,
);
await EmbeddingStore.instance.storeEmbedding(
@ -359,6 +361,14 @@ class SemanticSearchService {
);
return queryResults;
}
/// Picks the model variant to use on this device.
///
/// Resolves to [Model.ggmlClip] when [isGrapheneOS] reports `true`, and to
/// [Model.onnxClip] otherwise.
Future<Model> _getCurrentModel() async {
  final onGraphene = await isGrapheneOS();
  return onGraphene ? Model.ggmlClip : Model.onnxClip;
}
}
List<QueryResult> computeBulkScore(Map args) {

View file

@ -42,6 +42,14 @@ Future<bool> isLowSpecDevice() async {
return false;
}
/// Whether the device identifies itself as GrapheneOS.
///
/// Always `false` when not running on Android. On Android, the `host` value
/// from the Android device info is lower-cased and compared against
/// `"grapheneos"`.
Future<bool> isGrapheneOS() async {
  if (!Platform.isAndroid) {
    return false;
  }
  final info = await deviceInfoPlugin.androidInfo;
  final host = info.host.toLowerCase();
  return host == "grapheneos";
}
Future<bool> isAndroidSDKVersionLowerThan(int inputSDK) async {
if (Platform.isAndroid) {
final AndroidDeviceInfo androidInfo = await deviceInfoPlugin.androidInfo;

View file

@ -1382,7 +1382,7 @@ packages:
description:
path: "."
ref: HEAD
resolved-ref: "5f26aef45ed9f5e563c26f90c1e21b3339ed906d"
resolved-ref: "1318dce97f3aae5ec9bdf7491d5eff0ad6beb378"
url: "https://github.com/ente-io/onnxruntime.git"
source: git
version: "1.1.0"