2023-12-05 17:06:02 +00:00
|
|
|
import "dart:io";
|
|
|
|
|
2023-12-13 07:24:03 +00:00
|
|
|
import "package:flutter/services.dart";
|
2023-12-05 17:06:02 +00:00
|
|
|
import "package:logging/logging.dart";
|
|
|
|
import "package:path/path.dart";
|
|
|
|
import "package:path_provider/path_provider.dart";
|
2024-01-11 08:25:42 +00:00
|
|
|
import "package:photos/core/event_bus.dart";
|
2023-12-05 17:06:02 +00:00
|
|
|
import "package:photos/core/network/network.dart";
|
2024-01-11 08:25:42 +00:00
|
|
|
import "package:photos/events/event.dart";
|
2023-12-05 17:06:02 +00:00
|
|
|
|
|
|
|
/// Base class for frameworks that compute image and text embeddings.
///
/// Subclasses provide the remote model locations and the load / embedding
/// implementations. This class takes care of downloading and caching the
/// models, tracking [initializationState], and scoring embedding pairs.
abstract class MLFramework {
  /// When false, [init] skips downloading and loading the image model.
  static const kImageEncoderEnabled = true;

  /// Total number of attempts [_downloadFile] makes before giving up.
  static const kMaximumRetrials = 3;

  InitializationState _state = InitializationState.notInitialized;

  final _logger = Logger("MLFramework");

  /// The current initialization state of the framework.
  ///
  /// Every change is also fired on the event bus as an
  /// [MLFrameworkInitializationEvent].
  InitializationState get initializationState => _state;

  /// Records [state], logs it, and then notifies listeners.
  ///
  /// The field is assigned *before* the event is fired so that a listener
  /// reading [initializationState] in response always sees the new value.
  set _initState(InitializationState state) {
    _state = state;
    _logger.info("Init state is $state");
    Bus.instance.fire(MLFrameworkInitializationEvent(state));
  }

  /// Returns the path of the Image Model hosted remotely.
  String getImageModelRemotePath();

  /// Returns the path of the Text Model hosted remotely.
  String getTextModelRemotePath();

  /// Loads the Image Model stored at [path] into the framework.
  Future<void> loadImageModel(String path);

  /// Loads the Text Model stored at [path] into the framework.
  Future<void> loadTextModel(String path);

  /// Returns the Image Embedding for a file stored at [imagePath].
  Future<List<double>> getImageEmbedding(String imagePath);

  /// Returns the Text Embedding for [text].
  Future<List<double>> getTextEmbedding(String text);

  /// Downloads the models from remote, caches them and loads them into the
  /// framework.
  ///
  /// The image and text models are prepared concurrently. Override this
  /// method if you would like to control the initialization, e.g. to load
  /// the model from `/assets` instead of a CDN.
  Future<void> init() async {
    await Future.wait([_initImageModel(), _initTextModel()]);
    _initState = InitializationState.initialized;
  }

  /// Releases any resources held by the framework. No-op by default.
  Future<void> release() async {}

  /// Returns the dot product of [imageEmbedding] and [textEmbedding].
  ///
  /// This equals their cosine similarity only when both embeddings are
  /// L2-normalized; no normalization is performed here.
  /// NOTE(review): presumably the models emit normalized embeddings —
  /// confirm in the concrete subclasses.
  ///
  /// Throws an [ArgumentError] if the embeddings differ in length. An
  /// `assert` would be stripped in release builds, where a mismatch would
  /// otherwise either throw a [RangeError] or silently ignore the tail of
  /// [textEmbedding].
  double computeScore(
    List<double> imageEmbedding,
    List<double> textEmbedding,
  ) {
    if (imageEmbedding.length != textEmbedding.length) {
      throw ArgumentError(
        "The two embeddings should have the same length",
      );
    }
    double score = 0;
    for (int index = 0; index < imageEmbedding.length; index++) {
      score += imageEmbedding[index] * textEmbedding[index];
    }
    return score;
  }

  /// Copies the bundled asset at [assetPath] into the temporary directory as
  /// [tempName] and returns the path of the copy.
  ///
  /// Any existing file with that name is overwritten. Presumably intended
  /// for subclasses that ship their models as assets (see [init]) — confirm.
  Future<String> getAccessiblePathForAsset(
    String assetPath,
    String tempName,
  ) async {
    final byteData = await rootBundle.load(assetPath);
    final tempDir = await getTemporaryDirectory();
    // Write only the bytes covered by this ByteData view. The view need not
    // start at offset 0 or span its whole backing buffer, so a bare
    // `buffer.asUint8List()` could write unrelated bytes.
    final bytes = byteData.buffer.asUint8List(
      byteData.offsetInBytes,
      byteData.lengthInBytes,
    );
    final file = await File('${tempDir.path}/$tempName').writeAsBytes(bytes);
    return file.path;
  }

  // ---
  // Private methods
  // ---

  /// Ensures the image model is available locally and loads it, unless the
  /// image encoder is disabled via [kImageEncoderEnabled].
  Future<void> _initImageModel() async {
    if (!kImageEncoderEnabled) {
      return;
    }
    await _initModel(
      localPath: await _getLocalImageModelPath(),
      remotePath: getImageModelRemotePath(),
      downloadingState: InitializationState.downloadingImageModel,
      initializingState: InitializationState.initializingImageModel,
      load: loadImageModel,
    );
  }

  /// Ensures the text model is available locally and loads it.
  Future<void> _initTextModel() async {
    await _initModel(
      localPath: await _getLocalTextModelPath(),
      remotePath: getTextModelRemotePath(),
      downloadingState: InitializationState.downloadingTextModel,
      initializingState: InitializationState.initializingTextModel,
      load: loadTextModel,
    );
  }

  /// Downloads the model at [remotePath] to [localPath] if it is not cached
  /// yet, then loads it with [load].
  ///
  /// Reports [downloadingState] while downloading. Reports
  /// [initializingState] right before loading, whether the model came from
  /// the cache or from a fresh download.
  Future<void> _initModel({
    required String localPath,
    required String remotePath,
    required InitializationState downloadingState,
    required InitializationState initializingState,
    required Future<void> Function(String path) load,
  }) async {
    if (!File(localPath).existsSync()) {
      _initState = downloadingState;
      // Download to a side file and rename only after the download
      // completes. The existence check above then never mistakes an
      // interrupted, partial download for a complete model.
      final tempFile = File("$localPath.temp");
      await _downloadFile(remotePath, tempFile.path);
      await tempFile.rename(localPath);
    }
    _initState = initializingState;
    await load(localPath);
  }

  /// Local cache path for the image model: `<temp dir>/models/<file name>`.
  Future<String> _getLocalImageModelPath() async {
    final tempDir = await getTemporaryDirectory();
    return "${tempDir.path}/models/${basename(getImageModelRemotePath())}";
  }

  /// Local cache path for the text model: `<temp dir>/models/<file name>`.
  Future<String> _getLocalTextModelPath() async {
    final tempDir = await getTemporaryDirectory();
    return "${tempDir.path}/models/${basename(getTextModelRemotePath())}";
  }

  /// Downloads [url] to [savePath], retrying on failure.
  ///
  /// [trialCount] is the 1-based number of the current attempt. Up to
  /// [kMaximumRetrials] attempts are made in total; after the last one the
  /// error is rethrown to the caller.
  Future<void> _downloadFile(
    String url,
    String savePath, {
    int trialCount = 1,
  }) async {
    _logger.info("Downloading $url");
    final existingFile = File(savePath);
    if (await existingFile.exists()) {
      // Start from a clean file. Anything already at savePath (e.g. a
      // partial download from an earlier attempt) is discarded.
      await existingFile.delete();
    }
    try {
      await NetworkClient.instance.getDio().download(url, savePath);
    } catch (e, s) {
      _logger.severe(e, s);
      if (trialCount < kMaximumRetrials) {
        return _downloadFile(url, savePath, trialCount: trialCount + 1);
      }
      rethrow;
    }
  }
}
|
2024-01-11 08:25:42 +00:00
|
|
|
|
|
|
|
/// Event fired on the event bus each time an [MLFramework] changes its
/// [InitializationState].
class MLFrameworkInitializationEvent extends Event {
  /// Creates an event reporting that the framework moved to [state].
  MLFrameworkInitializationEvent(this.state);

  /// The state the framework has just entered.
  final InitializationState state;
}
|
|
|
|
|
|
|
|
/// Progress stages an [MLFramework] reports while preparing its models.
enum InitializationState {
  /// Starting state, before initialization has made any progress.
  notInitialized,

  /// Not assigned by [MLFramework] itself. Presumably set by callers while
  /// waiting for connectivity before downloads start — confirm.
  waitingForNetwork,

  /// The image model is being downloaded.
  downloadingImageModel,

  /// The text model is being downloaded.
  downloadingTextModel,

  /// The image model is being loaded into the framework.
  initializingImageModel,

  /// The text model is being loaded into the framework.
  initializingTextModel,

  /// [MLFramework.init] has finished preparing the models.
  initialized,
}
|