This commit is contained in:
Manav Rathi 2024-05-16 10:41:05 +05:30
parent 01108141c2
commit dd38232836
No known key found for this signature in database
3 changed files with 20 additions and 24 deletions

View file

@ -15,6 +15,7 @@ import ReaderService, {
getFaceId,
getLocalFile,
} from "./readerService";
import { DEFAULT_ML_SYNC_CONFIG } from "./machineLearningService";
class FaceService {
async syncFileFaceDetections(
@ -27,7 +28,7 @@ class FaceService {
oldMlFile?.faceDetectionMethod,
syncContext.faceDetectionService.method,
) &&
oldMlFile?.imageSource === syncContext.config.imageSource
oldMlFile?.imageSource === "Original"
) {
newMlFile.faces = oldMlFile?.faces?.map((existingFace) => ({
id: existingFace.id,
@ -223,10 +224,10 @@ class FaceService {
const faceCrop = await syncContext.faceCropService.getFaceCrop(
imageBitmap,
face.detection,
syncContext.config.faceCrop,
DEFAULT_ML_SYNC_CONFIG.faceCrop,
);
const blobOptions = syncContext.config.faceCrop.blobOptions;
const blobOptions = DEFAULT_ML_SYNC_CONFIG.faceCrop.blobOptions;
const blob = await imageBitmapToBlob(faceCrop.image, blobOptions);
const cache = await openCache("face-crops");
@ -252,7 +253,7 @@ class FaceService {
) {
// await this.init();
const clusteringConfig = syncContext.config.faceClustering;
const clusteringConfig = DEFAULT_ML_SYNC_CONFIG.faceClustering;
if (!allFaces || allFaces.length < clusteringConfig.minInputSize) {
log.info(
@ -266,7 +267,7 @@ class FaceService {
syncContext.mlLibraryData.faceClusteringResults =
await syncContext.faceClusteringService.cluster(
allFaces.map((f) => Array.from(f.embedding)),
syncContext.config.faceClustering,
DEFAULT_ML_SYNC_CONFIG.faceClustering,
);
syncContext.mlLibraryData.faceClusteringMethod =
syncContext.faceClusteringService.method;

View file

@ -55,7 +55,9 @@ import yoloFaceDetectionService from "./yoloFaceDetectionService";
*/
export const defaultMLVersion = 3;
const DEFAULT_ML_SYNC_CONFIG: MLSyncConfig = {
const batchSize = 200;
export const DEFAULT_ML_SYNC_CONFIG: MLSyncConfig = {
batchSize: 200,
imageSource: "Original",
faceDetection: {
@ -186,19 +188,13 @@ export class MLFactory {
config: MLSyncConfig,
shouldUpdateMLVersion: boolean = true,
) {
return new LocalMLSyncContext(
token,
userID,
config,
shouldUpdateMLVersion,
);
return new LocalMLSyncContext(token, userID, shouldUpdateMLVersion);
}
}
export class LocalMLSyncContext implements MLSyncContext {
public token: string;
public userID: number;
public config: MLSyncConfig;
public shouldUpdateMLVersion: boolean;
public faceDetectionService: FaceDetectionService;
@ -231,13 +227,11 @@ export class LocalMLSyncContext implements MLSyncContext {
constructor(
token: string,
userID: number,
config: MLSyncConfig,
shouldUpdateMLVersion: boolean = true,
concurrency?: number,
) {
this.token = token;
this.userID = userID;
this.config = config;
this.shouldUpdateMLVersion = shouldUpdateMLVersion;
this.faceDetectionService =
@ -318,8 +312,7 @@ class MachineLearningService {
// may be need to just take synced files on latest ml version for indexing
if (
syncContext.outOfSyncFiles.length <= 0 ||
(syncContext.nSyncedFiles === syncContext.config.batchSize &&
Math.random() < 0.2)
(syncContext.nSyncedFiles === batchSize && Math.random() < 0.2)
) {
await this.syncIndex(syncContext);
}
@ -425,8 +418,8 @@ class MachineLearningService {
private async getOutOfSyncFiles(syncContext: MLSyncContext) {
const startTime = Date.now();
const fileIds = await mlIDbStorage.getFileIds(
syncContext.config.batchSize,
syncContext.config.mlVersion,
batchSize,
defaultMLVersion,
MAX_ML_SYNC_ERROR_COUNT,
);
@ -535,7 +528,7 @@ class MachineLearningService {
localFile,
);
if (syncContext.nSyncedFiles >= syncContext.config.batchSize) {
if (syncContext.nSyncedFiles >= batchSize) {
await this.closeLocalSyncContext();
}
// await syncContext.dispose();
@ -603,7 +596,7 @@ class MachineLearningService {
(fileContext.oldMlFile = await this.getMLFileData(enteFile.id)) ??
this.newMlData(enteFile.id);
if (
fileContext.oldMlFile?.mlVersion === syncContext.config.mlVersion
fileContext.oldMlFile?.mlVersion === defaultMLVersion
// TODO: reset mlversion of all files when user changes image source
) {
return fileContext.oldMlFile;
@ -611,7 +604,7 @@ class MachineLearningService {
const newMlFile = (fileContext.newMlFile = this.newMlData(enteFile.id));
if (syncContext.shouldUpdateMLVersion) {
newMlFile.mlVersion = syncContext.config.mlVersion;
newMlFile.mlVersion = defaultMLVersion;
} else if (fileContext.oldMlFile?.mlVersion) {
newMlFile.mlVersion = fileContext.oldMlFile.mlVersion;
}

View file

@ -12,6 +12,7 @@ import {
import { EnteFile } from "types/file";
import { getRenderableImage } from "utils/file";
import { clamp } from "utils/image";
import { DEFAULT_ML_SYNC_CONFIG } from "./machineLearningService";
class ReaderService {
async getImageBitmap(
@ -35,7 +36,7 @@ class ReaderService {
fileContext.localFile,
);
} else if (
syncContext.config.imageSource === "Original" &&
DEFAULT_ML_SYNC_CONFIG.imageSource === "Original" &&
[FILE_TYPE.IMAGE, FILE_TYPE.LIVE_PHOTO].includes(
fileContext.enteFile.metadata.fileType,
)
@ -49,7 +50,8 @@ class ReaderService {
);
}
fileContext.newMlFile.imageSource = syncContext.config.imageSource;
fileContext.newMlFile.imageSource =
DEFAULT_ML_SYNC_CONFIG.imageSource;
const { width, height } = fileContext.imageBitmap;
fileContext.newMlFile.imageDimensions = { width, height };