Remove ssdMobileNetV2

This commit is contained in:
Manav Rathi 2024-04-11 11:22:47 +05:30
parent da3b58661a
commit 03df858dcc
No known key found for this signature in database
3 changed files with 0 additions and 85 deletions

View file

@ -22,8 +22,6 @@ import {
MLLibraryData,
MLSyncConfig,
MLSyncContext,
ObjectDetectionMethod,
ObjectDetectionService,
SceneDetectionMethod,
SceneDetectionService,
} from "types/machineLearning";
@ -35,7 +33,6 @@ import hdbscanClusteringService from "./hdbscanClusteringService";
import imageSceneService from "./imageSceneService";
import laplacianBlurDetectionService from "./laplacianBlurDetectionService";
import mobileFaceNetEmbeddingService from "./mobileFaceNetEmbeddingService";
import ssdMobileNetV2Service from "./ssdMobileNetV2Service";
import yoloFaceDetectionService from "./yoloFaceDetectionService";
export class MLFactory {
@ -49,16 +46,6 @@ export class MLFactory {
throw Error("Unknon face detection method: " + method);
}
public static getObjectDetectionService(
method: ObjectDetectionMethod,
): ObjectDetectionService {
if (method === "SSDMobileNetV2") {
return ssdMobileNetV2Service;
}
throw Error("Unknown object detection method: " + method);
}
public static getSceneDetectionService(
method: SceneDetectionMethod,
): SceneDetectionService {
@ -147,7 +134,6 @@ export class LocalMLSyncContext implements MLSyncContext {
public blurDetectionService: BlurDetectionService;
public faceEmbeddingService: FaceEmbeddingService;
public faceClusteringService: ClusteringService;
public objectDetectionService: ObjectDetectionService;
public sceneDetectionService: SceneDetectionService;
public localFilesMap: Map<number, EnteFile>;
@ -202,9 +188,6 @@ export class LocalMLSyncContext implements MLSyncContext {
this.config.faceClustering.method,
);
this.objectDetectionService = MLFactory.getObjectDetectionService(
this.config.objectDetection.method,
);
this.sceneDetectionService = MLFactory.getSceneDetectionService(
this.config.sceneDetection.method,
);

View file

@ -1,66 +0,0 @@
import log from "@/next/log";
import * as tf from "@tensorflow/tfjs-core";
import {
ObjectDetection,
ObjectDetectionMethod,
ObjectDetectionService,
Versioned,
} from "types/machineLearning";
import * as SSDMobileNet from "@tensorflow-models/coco-ssd";
import { OBJECT_DETECTION_IMAGE_SIZE } from "constants/mlConfig";
import { resizeToSquare } from "utils/image";
class SSDMobileNetV2 implements ObjectDetectionService {
private ssdMobileNetV2Model: SSDMobileNet.ObjectDetection;
public method: Versioned<ObjectDetectionMethod>;
private ready: Promise<void>;
public constructor() {
this.method = {
value: "SSDMobileNetV2",
version: 1,
};
}
private async init() {
this.ssdMobileNetV2Model = await SSDMobileNet.load({
base: "mobilenet_v2",
modelUrl: "/models/ssdmobilenet/model.json",
});
log.info("loaded ssdMobileNetV2Model", tf.getBackend());
}
private async getSSDMobileNetV2Model() {
if (!this.ready) {
this.ready = this.init();
}
await this.ready;
return this.ssdMobileNetV2Model;
}
public async detectObjects(
image: ImageBitmap,
maxNumberBoxes: number,
minScore: number,
): Promise<ObjectDetection[]> {
const ssdMobileNetV2Model = await this.getSSDMobileNetV2Model();
const resized = resizeToSquare(image, OBJECT_DETECTION_IMAGE_SIZE);
const tfImage = tf.browser.fromPixels(resized.image);
const detections = await ssdMobileNetV2Model.detect(
tfImage,
maxNumberBoxes,
minScore,
);
tfImage.dispose();
return detections;
}
public async dispose() {
const ssdMobileNetV2Model = await this.getSSDMobileNetV2Model();
ssdMobileNetV2Model?.dispose();
this.ssdMobileNetV2Model = null;
}
}
export default new SSDMobileNetV2();

View file

@ -265,7 +265,6 @@ export interface MLSyncContext {
faceEmbeddingService: FaceEmbeddingService;
blurDetectionService: BlurDetectionService;
faceClusteringService: ClusteringService;
objectDetectionService: ObjectDetectionService;
sceneDetectionService: SceneDetectionService;
localFilesMap: Map<number, EnteFile>;
@ -273,7 +272,6 @@ export interface MLSyncContext {
nSyncedFiles: number;
nSyncedFaces: number;
allSyncedFacesMap?: Map<number, Array<Face>>;
allSyncedObjectsMap?: Map<number, Array<RealWorldObject>>;
tsne?: any;
error?: Error;