create and use thingClasses index

This commit is contained in:
Abhinav 2022-03-14 15:49:43 +05:30
parent 1db047e3d3
commit 3b072b7d71
6 changed files with 83 additions and 19 deletions

View file

@ -1,17 +1,17 @@
import { Chip } from 'components/pages/gallery/Collections';
import React, { useState, useEffect } from 'react';
import objectService from 'services/machineLearning/objectService';
import { EnteFile } from 'types/file';
import mlIDbStorage from 'utils/storage/mlIDbStorage';
export function ObjectLabelList(props: { file: EnteFile }) {
const [objects, setObjects] = useState<Array<string>>([]);
useEffect(() => {
let didCancel = false;
const main = async () => {
const objects = await objectService.getAllSyncedThingsMap();
const things = await mlIDbStorage.getAllThingsMap();
const uniqueObjectNames = [
...new Set(
objects
things
.get(props.file.id)
.map((object) => object.detection.class)
),

View file

@ -28,7 +28,7 @@ import { CustomError, parseServerError } from 'utils/error';
import { MAX_ML_SYNC_ERROR_COUNT } from 'constants/machineLearning/config';
import FaceService from './faceService';
import PeopleService from './peopleService';
import objectService from './objectService';
import ObjectService from './objectService';
class MachineLearningService {
private initialized = false;
// private faceDetectionService: FaceDetectionService;
@ -406,7 +406,7 @@ class MachineLearningService {
await FaceService.syncFileFaceEmbeddings(syncContext, fileContext);
}
await objectService.syncFileObjectDetections(syncContext, fileContext);
await ObjectService.syncFileObjectDetections(syncContext, fileContext);
fileContext.tfImage && fileContext.tfImage.dispose();
fileContext.imageBitmap && fileContext.imageBitmap.close();
@ -507,6 +507,8 @@ class MachineLearningService {
// await this.init();
await PeopleService.syncPeopleIndex(syncContext);
await ObjectService.syncThingClassesIndex(syncContext);
await this.persistMLLibraryData(syncContext);
}

View file

@ -25,12 +25,7 @@ class ObjectService {
) &&
oldMlFile?.imageSource === syncContext.config.imageSource
) {
newMlFile.things = oldMlFile?.things?.map((existingObject) => ({
id: existingObject.id,
fileID: existingObject.fileID,
detection: existingObject.detection,
}));
newMlFile.things = oldMlFile?.things;
newMlFile.imageSource = oldMlFile.imageSource;
newMlFile.imageDimensions = oldMlFile.imageDimensions;
newMlFile.objectDetectionMethod = oldMlFile.objectDetectionMethod;
@ -54,9 +49,10 @@ class ObjectService {
detection,
} as DetectedObject;
});
newMlFile.things = detectedObjects?.map((detectedObjects) => ({
...detectedObjects,
id: getObjectId(detectedObjects, newMlFile.imageDimensions),
newMlFile.things = detectedObjects?.map((detectedObject) => ({
...detectedObject,
id: getObjectId(detectedObject, newMlFile.imageDimensions),
className: detectedObject.detection.class,
}));
// ?.filter((f) =>
// f.box.width > syncContext.config.faceDetection.minFaceSize
@ -64,12 +60,19 @@ class ObjectService {
console.log('[MLService] Detected Objects: ', newMlFile.things?.length);
}
public async getAllSyncedThingsMap() {
return await mlIDbStorage.getAllThingsMap();
async getAllSyncedThingsMap(syncContext: MLSyncContext) {
if (syncContext.allSyncedThingsMap) {
return syncContext.allSyncedThingsMap;
}
syncContext.allSyncedThingsMap = await mlIDbStorage.getAllThingsMap();
return syncContext.allSyncedThingsMap;
}
public async getThingClasses(): Promise<ThingClass[]> {
const allObjectsMap = await this.getAllSyncedThingsMap();
public async clusterThingClasses(
syncContext: MLSyncContext
): Promise<ThingClass[]> {
const allObjectsMap = await this.getAllSyncedThingsMap(syncContext);
const allObjects = getAllThingsFromMap(allObjectsMap);
const objectClusters = new Map<string, number[]>();
allObjects.map((object) => {
@ -85,6 +88,40 @@ class ObjectService {
files,
}));
}
async syncThingClassesIndex(syncContext: MLSyncContext) {
const filesVersion = await mlIDbStorage.getIndexVersion('files');
console.log(
'thingClasses',
await mlIDbStorage.getIndexVersion('thingClasses')
);
if (
filesVersion <= (await mlIDbStorage.getIndexVersion('thingClasses'))
) {
console.log(
'[MLService] Skipping people index as already synced to latest version'
);
return;
}
const thingClasses = await this.clusterThingClasses(syncContext);
if (!thingClasses || thingClasses.length < 1) {
return;
}
await mlIDbStorage.clearAllThingClasses();
for (const thingClass of thingClasses) {
await mlIDbStorage.putThingClass(thingClass);
}
await mlIDbStorage.setIndexVersion('thingClasses', filesVersion);
}
async getAllThingClasses() {
return await mlIDbStorage.getAllThingClasses();
}
}
export default new ObjectService();

View file

@ -174,7 +174,7 @@ export function searchFiles(searchPhrase: string, files: EnteFile[]) {
}
export async function searchThing(searchPhrase: string) {
const thingClasses = await ObjectService.getThingClasses();
const thingClasses = await ObjectService.getAllThingClasses();
return thingClasses
.filter((thingClass) =>
thingClass.className.toLocaleLowerCase().includes(searchPhrase)

View file

@ -194,6 +194,7 @@ export interface DetectedObject {
export interface Thing extends DetectedObject {
id: string;
className: string;
}
export interface ThingClass {
@ -296,6 +297,7 @@ export interface MLSyncContext {
nSyncedFiles: number;
nSyncedFaces: number;
allSyncedFacesMap?: Map<number, Array<Face>>;
allSyncedThingsMap?: Map<number, Array<Thing>>;
tsne?: any;
error?: Error;

View file

@ -19,6 +19,7 @@ import {
MLLibraryData,
Person,
Thing,
ThingClass,
} from 'types/machineLearning';
import { IndexStatus } from 'types/machineLearning/ui';
import { runningInBrowser } from 'utils/common';
@ -39,6 +40,10 @@ interface MLDb extends DBSchema {
key: number;
value: Person;
};
thingClasses: {
key: number;
value: ThingClass;
};
versions: {
key: string;
value: number;
@ -98,6 +103,10 @@ class MLIDbStorage {
keyPath: 'id',
});
db.createObjectStore('thingClasses', {
keyPath: 'id',
});
db.createObjectStore('versions');
db.createObjectStore('library');
@ -321,6 +330,20 @@ class MLIDbStorage {
return db.clear('people');
}
public async getAllThingClasses() {
const db = await this.db;
return db.getAll('thingClasses');
}
public async putThingClass(thingClass: ThingClass) {
const db = await this.db;
return db.put('thingClasses', thingClass);
}
public async clearAllThingClasses() {
const db = await this.db;
return db.clear('thingClasses');
}
public async getIndexVersion(index: string) {
const db = await this.db;
return db.get('versions', index);