diff --git a/src/components/MachineLearning/ObjectList.tsx b/src/components/MachineLearning/ObjectList.tsx index 599311b16..b3369fd04 100644 --- a/src/components/MachineLearning/ObjectList.tsx +++ b/src/components/MachineLearning/ObjectList.tsx @@ -1,17 +1,17 @@ import { Chip } from 'components/pages/gallery/Collections'; import React, { useState, useEffect } from 'react'; -import objectService from 'services/machineLearning/objectService'; import { EnteFile } from 'types/file'; +import mlIDbStorage from 'utils/storage/mlIDbStorage'; export function ObjectLabelList(props: { file: EnteFile }) { const [objects, setObjects] = useState>([]); useEffect(() => { let didCancel = false; const main = async () => { - const objects = await objectService.getAllSyncedThingsMap(); + const things = await mlIDbStorage.getAllThingsMap(); const uniqueObjectNames = [ ...new Set( - objects + things .get(props.file.id) .map((object) => object.detection.class) ), diff --git a/src/services/machineLearning/machineLearningService.ts b/src/services/machineLearning/machineLearningService.ts index e0b841db1..17624a18d 100644 --- a/src/services/machineLearning/machineLearningService.ts +++ b/src/services/machineLearning/machineLearningService.ts @@ -28,7 +28,7 @@ import { CustomError, parseServerError } from 'utils/error'; import { MAX_ML_SYNC_ERROR_COUNT } from 'constants/machineLearning/config'; import FaceService from './faceService'; import PeopleService from './peopleService'; -import objectService from './objectService'; +import ObjectService from './objectService'; class MachineLearningService { private initialized = false; // private faceDetectionService: FaceDetectionService; @@ -406,7 +406,7 @@ class MachineLearningService { await FaceService.syncFileFaceEmbeddings(syncContext, fileContext); } - await objectService.syncFileObjectDetections(syncContext, fileContext); + await ObjectService.syncFileObjectDetections(syncContext, fileContext); fileContext.tfImage && fileContext.tfImage.dispose(); fileContext.imageBitmap && fileContext.imageBitmap.close(); @@ -507,6 +507,8 @@ class MachineLearningService { // await this.init(); await PeopleService.syncPeopleIndex(syncContext); + await ObjectService.syncThingClassesIndex(syncContext); + await this.persistMLLibraryData(syncContext); } diff --git a/src/services/machineLearning/objectService.ts b/src/services/machineLearning/objectService.ts index 693411f22..c4a1bd7d4 100644 --- a/src/services/machineLearning/objectService.ts +++ b/src/services/machineLearning/objectService.ts @@ -25,12 +25,7 @@ class ObjectService { ) && oldMlFile?.imageSource === syncContext.config.imageSource ) { - newMlFile.things = oldMlFile?.things?.map((existingObject) => ({ - id: existingObject.id, - fileID: existingObject.fileID, - detection: existingObject.detection, - })); - + newMlFile.things = oldMlFile?.things; newMlFile.imageSource = oldMlFile.imageSource; newMlFile.imageDimensions = oldMlFile.imageDimensions; newMlFile.objectDetectionMethod = oldMlFile.objectDetectionMethod; @@ -54,9 +49,10 @@ class ObjectService { detection, } as DetectedObject; }); - newMlFile.things = detectedObjects?.map((detectedObjects) => ({ - ...detectedObjects, - id: getObjectId(detectedObjects, newMlFile.imageDimensions), + newMlFile.things = detectedObjects?.map((detectedObject) => ({ + ...detectedObject, + id: getObjectId(detectedObject, newMlFile.imageDimensions), + className: detectedObject.detection.class, })); // ?.filter((f) => // f.box.width > syncContext.config.faceDetection.minFaceSize @@ -64,12 +60,19 @@ class ObjectService { console.log('[MLService] Detected Objects: ', newMlFile.things?.length); } - public async getAllSyncedThingsMap() { - return await mlIDbStorage.getAllThingsMap(); + async getAllSyncedThingsMap(syncContext: MLSyncContext) { + if (syncContext.allSyncedThingsMap) { + return syncContext.allSyncedThingsMap; + } + + syncContext.allSyncedThingsMap = await mlIDbStorage.getAllThingsMap(); + return syncContext.allSyncedThingsMap; } - public async getThingClasses(): Promise { - const allObjectsMap = await this.getAllSyncedThingsMap(); + public async clusterThingClasses( + syncContext: MLSyncContext + ): Promise { + const allObjectsMap = await this.getAllSyncedThingsMap(syncContext); const allObjects = getAllThingsFromMap(allObjectsMap); const objectClusters = new Map(); allObjects.map((object) => { @@ -85,6 +88,40 @@ class ObjectService { files, })); } + + async syncThingClassesIndex(syncContext: MLSyncContext) { + const filesVersion = await mlIDbStorage.getIndexVersion('files'); + console.log( + 'thingClasses', + await mlIDbStorage.getIndexVersion('thingClasses') + ); + if ( + filesVersion <= (await mlIDbStorage.getIndexVersion('thingClasses')) + ) { + console.log( + '[MLService] Skipping people index as already synced to latest version' + ); + return; + } + + const thingClasses = await this.clusterThingClasses(syncContext); + + if (!thingClasses || thingClasses.length < 1) { + return; + } + + await mlIDbStorage.clearAllThingClasses(); + + for (const thingClass of thingClasses) { + await mlIDbStorage.putThingClass(thingClass); + } + + await mlIDbStorage.setIndexVersion('thingClasses', filesVersion); + } + + async getAllThingClasses() { + return await mlIDbStorage.getAllThingClasses(); + } } export default new ObjectService(); diff --git a/src/services/searchService.ts b/src/services/searchService.ts index 768362222..d08c26076 100644 --- a/src/services/searchService.ts +++ b/src/services/searchService.ts @@ -174,7 +174,7 @@ export function searchFiles(searchPhrase: string, files: EnteFile[]) { } export async function searchThing(searchPhrase: string) { - const thingClasses = await ObjectService.getThingClasses(); + const thingClasses = await ObjectService.getAllThingClasses(); return thingClasses .filter((thingClass) => thingClass.className.toLocaleLowerCase().includes(searchPhrase) diff --git a/src/types/machineLearning/index.ts b/src/types/machineLearning/index.ts index 3348cbe54..f5505f9ae 100644 --- a/src/types/machineLearning/index.ts +++ b/src/types/machineLearning/index.ts @@ -194,6 +194,7 @@ export interface DetectedObject { export interface Thing extends DetectedObject { id: string; + className: string; } export interface ThingClass { @@ -296,6 +297,7 @@ export interface MLSyncContext { nSyncedFiles: number; nSyncedFaces: number; allSyncedFacesMap?: Map>; + allSyncedThingsMap?: Map>; tsne?: any; error?: Error; diff --git a/src/utils/storage/mlIDbStorage.ts b/src/utils/storage/mlIDbStorage.ts index 641b7cd43..f01b9c4f2 100644 --- a/src/utils/storage/mlIDbStorage.ts +++ b/src/utils/storage/mlIDbStorage.ts @@ -19,6 +19,7 @@ import { MLLibraryData, Person, Thing, + ThingClass, } from 'types/machineLearning'; import { IndexStatus } from 'types/machineLearning/ui'; import { runningInBrowser } from 'utils/common'; @@ -39,6 +40,10 @@ interface MLDb extends DBSchema { key: number; value: Person; }; + thingClasses: { + key: number; + value: ThingClass; + }; versions: { key: string; value: number; @@ -98,6 +103,10 @@ class MLIDbStorage { keyPath: 'id', }); + db.createObjectStore('thingClasses', { + keyPath: 'id', + }); + db.createObjectStore('versions'); db.createObjectStore('library'); @@ -321,6 +330,20 @@ class MLIDbStorage { return db.clear('people'); } + public async getAllThingClasses() { + const db = await this.db; + return db.getAll('thingClasses'); + } + public async putThingClass(thingClass: ThingClass) { + const db = await this.db; + return db.put('thingClasses', thingClass); + } + + public async clearAllThingClasses() { + const db = await this.db; + return db.clear('thingClasses'); + } + public async getIndexVersion(index: string) { const db = await this.db; return db.get('versions', index);