create and use thingClasses index

This commit is contained in:
Abhinav 2022-03-14 15:49:43 +05:30
parent 1db047e3d3
commit 3b072b7d71
6 changed files with 83 additions and 19 deletions

View file

@ -1,17 +1,17 @@
import { Chip } from 'components/pages/gallery/Collections'; import { Chip } from 'components/pages/gallery/Collections';
import React, { useState, useEffect } from 'react'; import React, { useState, useEffect } from 'react';
import objectService from 'services/machineLearning/objectService';
import { EnteFile } from 'types/file'; import { EnteFile } from 'types/file';
import mlIDbStorage from 'utils/storage/mlIDbStorage';
export function ObjectLabelList(props: { file: EnteFile }) { export function ObjectLabelList(props: { file: EnteFile }) {
const [objects, setObjects] = useState<Array<string>>([]); const [objects, setObjects] = useState<Array<string>>([]);
useEffect(() => { useEffect(() => {
let didCancel = false; let didCancel = false;
const main = async () => { const main = async () => {
const objects = await objectService.getAllSyncedThingsMap(); const things = await mlIDbStorage.getAllThingsMap();
const uniqueObjectNames = [ const uniqueObjectNames = [
...new Set( ...new Set(
objects things
.get(props.file.id) .get(props.file.id)
.map((object) => object.detection.class) .map((object) => object.detection.class)
), ),

View file

@ -28,7 +28,7 @@ import { CustomError, parseServerError } from 'utils/error';
import { MAX_ML_SYNC_ERROR_COUNT } from 'constants/machineLearning/config'; import { MAX_ML_SYNC_ERROR_COUNT } from 'constants/machineLearning/config';
import FaceService from './faceService'; import FaceService from './faceService';
import PeopleService from './peopleService'; import PeopleService from './peopleService';
import objectService from './objectService'; import ObjectService from './objectService';
class MachineLearningService { class MachineLearningService {
private initialized = false; private initialized = false;
// private faceDetectionService: FaceDetectionService; // private faceDetectionService: FaceDetectionService;
@ -406,7 +406,7 @@ class MachineLearningService {
await FaceService.syncFileFaceEmbeddings(syncContext, fileContext); await FaceService.syncFileFaceEmbeddings(syncContext, fileContext);
} }
await objectService.syncFileObjectDetections(syncContext, fileContext); await ObjectService.syncFileObjectDetections(syncContext, fileContext);
fileContext.tfImage && fileContext.tfImage.dispose(); fileContext.tfImage && fileContext.tfImage.dispose();
fileContext.imageBitmap && fileContext.imageBitmap.close(); fileContext.imageBitmap && fileContext.imageBitmap.close();
@ -507,6 +507,8 @@ class MachineLearningService {
// await this.init(); // await this.init();
await PeopleService.syncPeopleIndex(syncContext); await PeopleService.syncPeopleIndex(syncContext);
await ObjectService.syncThingClassesIndex(syncContext);
await this.persistMLLibraryData(syncContext); await this.persistMLLibraryData(syncContext);
} }

View file

@ -25,12 +25,7 @@ class ObjectService {
) && ) &&
oldMlFile?.imageSource === syncContext.config.imageSource oldMlFile?.imageSource === syncContext.config.imageSource
) { ) {
newMlFile.things = oldMlFile?.things?.map((existingObject) => ({ newMlFile.things = oldMlFile?.things;
id: existingObject.id,
fileID: existingObject.fileID,
detection: existingObject.detection,
}));
newMlFile.imageSource = oldMlFile.imageSource; newMlFile.imageSource = oldMlFile.imageSource;
newMlFile.imageDimensions = oldMlFile.imageDimensions; newMlFile.imageDimensions = oldMlFile.imageDimensions;
newMlFile.objectDetectionMethod = oldMlFile.objectDetectionMethod; newMlFile.objectDetectionMethod = oldMlFile.objectDetectionMethod;
@ -54,9 +49,10 @@ class ObjectService {
detection, detection,
} as DetectedObject; } as DetectedObject;
}); });
newMlFile.things = detectedObjects?.map((detectedObjects) => ({ newMlFile.things = detectedObjects?.map((detectedObject) => ({
...detectedObjects, ...detectedObject,
id: getObjectId(detectedObjects, newMlFile.imageDimensions), id: getObjectId(detectedObject, newMlFile.imageDimensions),
className: detectedObject.detection.class,
})); }));
// ?.filter((f) => // ?.filter((f) =>
// f.box.width > syncContext.config.faceDetection.minFaceSize // f.box.width > syncContext.config.faceDetection.minFaceSize
@ -64,12 +60,19 @@ class ObjectService {
console.log('[MLService] Detected Objects: ', newMlFile.things?.length); console.log('[MLService] Detected Objects: ', newMlFile.things?.length);
} }
public async getAllSyncedThingsMap() { async getAllSyncedThingsMap(syncContext: MLSyncContext) {
return await mlIDbStorage.getAllThingsMap(); if (syncContext.allSyncedThingsMap) {
return syncContext.allSyncedThingsMap;
} }
public async getThingClasses(): Promise<ThingClass[]> { syncContext.allSyncedThingsMap = await mlIDbStorage.getAllThingsMap();
const allObjectsMap = await this.getAllSyncedThingsMap(); return syncContext.allSyncedThingsMap;
}
public async clusterThingClasses(
syncContext: MLSyncContext
): Promise<ThingClass[]> {
const allObjectsMap = await this.getAllSyncedThingsMap(syncContext);
const allObjects = getAllThingsFromMap(allObjectsMap); const allObjects = getAllThingsFromMap(allObjectsMap);
const objectClusters = new Map<string, number[]>(); const objectClusters = new Map<string, number[]>();
allObjects.map((object) => { allObjects.map((object) => {
@ -85,6 +88,40 @@ class ObjectService {
files, files,
})); }));
} }
async syncThingClassesIndex(syncContext: MLSyncContext) {
const filesVersion = await mlIDbStorage.getIndexVersion('files');
console.log(
'thingClasses',
await mlIDbStorage.getIndexVersion('thingClasses')
);
if (
filesVersion <= (await mlIDbStorage.getIndexVersion('thingClasses'))
) {
console.log(
'[MLService] Skipping people index as already synced to latest version'
);
return;
}
const thingClasses = await this.clusterThingClasses(syncContext);
if (!thingClasses || thingClasses.length < 1) {
return;
}
await mlIDbStorage.clearAllThingClasses();
for (const thingClass of thingClasses) {
await mlIDbStorage.putThingClass(thingClass);
}
await mlIDbStorage.setIndexVersion('thingClasses', filesVersion);
}
async getAllThingClasses() {
return await mlIDbStorage.getAllThingClasses();
}
} }
export default new ObjectService(); export default new ObjectService();

View file

@ -174,7 +174,7 @@ export function searchFiles(searchPhrase: string, files: EnteFile[]) {
} }
export async function searchThing(searchPhrase: string) { export async function searchThing(searchPhrase: string) {
const thingClasses = await ObjectService.getThingClasses(); const thingClasses = await ObjectService.getAllThingClasses();
return thingClasses return thingClasses
.filter((thingClass) => .filter((thingClass) =>
thingClass.className.toLocaleLowerCase().includes(searchPhrase) thingClass.className.toLocaleLowerCase().includes(searchPhrase)

View file

@ -194,6 +194,7 @@ export interface DetectedObject {
export interface Thing extends DetectedObject { export interface Thing extends DetectedObject {
id: string; id: string;
className: string;
} }
export interface ThingClass { export interface ThingClass {
@ -296,6 +297,7 @@ export interface MLSyncContext {
nSyncedFiles: number; nSyncedFiles: number;
nSyncedFaces: number; nSyncedFaces: number;
allSyncedFacesMap?: Map<number, Array<Face>>; allSyncedFacesMap?: Map<number, Array<Face>>;
allSyncedThingsMap?: Map<number, Array<Thing>>;
tsne?: any; tsne?: any;
error?: Error; error?: Error;

View file

@ -19,6 +19,7 @@ import {
MLLibraryData, MLLibraryData,
Person, Person,
Thing, Thing,
ThingClass,
} from 'types/machineLearning'; } from 'types/machineLearning';
import { IndexStatus } from 'types/machineLearning/ui'; import { IndexStatus } from 'types/machineLearning/ui';
import { runningInBrowser } from 'utils/common'; import { runningInBrowser } from 'utils/common';
@ -39,6 +40,10 @@ interface MLDb extends DBSchema {
key: number; key: number;
value: Person; value: Person;
}; };
thingClasses: {
key: number;
value: ThingClass;
};
versions: { versions: {
key: string; key: string;
value: number; value: number;
@ -98,6 +103,10 @@ class MLIDbStorage {
keyPath: 'id', keyPath: 'id',
}); });
db.createObjectStore('thingClasses', {
keyPath: 'id',
});
db.createObjectStore('versions'); db.createObjectStore('versions');
db.createObjectStore('library'); db.createObjectStore('library');
@ -321,6 +330,20 @@ class MLIDbStorage {
return db.clear('people'); return db.clear('people');
} }
public async getAllThingClasses() {
const db = await this.db;
return db.getAll('thingClasses');
}
public async putThingClass(thingClass: ThingClass) {
const db = await this.db;
return db.put('thingClasses', thingClass);
}
public async clearAllThingClasses() {
const db = await this.db;
return db.clear('thingClasses');
}
public async getIndexVersion(index: string) { public async getIndexVersion(index: string) {
const db = await this.db; const db = await this.db;
return db.get('versions', index); return db.get('versions', index);