diff --git a/src/components/MachineLearning/MLFileDebugView.tsx b/src/components/MachineLearning/MLFileDebugView.tsx index 3553b755b..ba809d27c 100644 --- a/src/components/MachineLearning/MLFileDebugView.tsx +++ b/src/components/MachineLearning/MLFileDebugView.tsx @@ -97,6 +97,7 @@ export default function MLFileDebugView(props: MLFileDebugViewProps) { const objectDetections = await ssdMobileNetV2Service.detectObjects( imageBitmap, + DEFAULT_ML_SYNC_CONFIG.objectDetection.maxNumBoxes, DEFAULT_ML_SYNC_CONFIG.objectDetection.minScore ); console.log('detectedObjects: ', objectDetections); diff --git a/src/constants/machineLearning/config.ts b/src/constants/machineLearning/config.ts index d6572c8d6..601af2729 100644 --- a/src/constants/machineLearning/config.ts +++ b/src/constants/machineLearning/config.ts @@ -45,6 +45,7 @@ export const DEFAULT_ML_SYNC_CONFIG: MLSyncConfig = { }, objectDetection: { method: 'SSDMobileNetV2', + maxNumBoxes: 20, minScore: 0.2, }, textDetection: { diff --git a/src/services/machineLearning/objectService.ts b/src/services/machineLearning/objectService.ts index 14286d7c1..667b1acf7 100644 --- a/src/services/machineLearning/objectService.ts +++ b/src/services/machineLearning/objectService.ts @@ -42,6 +42,7 @@ class ObjectService { const objectDetections = await syncContext.objectDetectionService.detectObjects( imageBitmap, + syncContext.config.objectDetection.maxNumBoxes, syncContext.config.objectDetection.minScore ); // console.log('3 TF Memory stats: ', tf.memory()); diff --git a/src/services/machineLearning/ssdMobileNetV2Service.ts b/src/services/machineLearning/ssdMobileNetV2Service.ts index 4371256bf..c8272ecc9 100644 --- a/src/services/machineLearning/ssdMobileNetV2Service.ts +++ b/src/services/machineLearning/ssdMobileNetV2Service.ts @@ -42,13 +42,14 @@ class SSDMobileNetV2 implements ObjectDetectionService { public async detectObjects( image: ImageBitmap, - minScore?: number + maxNumberBoxes: number, + minScore: number ): Promise { const ssdMobileNetV2Model = await this.getSSDMobileNetV2Model(); const tfImage = tf.browser.fromPixels(image); const detections = await ssdMobileNetV2Model.detect( tfImage, - undefined, + maxNumberBoxes, minScore ); return detections; diff --git a/src/types/machineLearning/index.ts b/src/types/machineLearning/index.ts index 4d1d8e801..a44e02192 100644 --- a/src/types/machineLearning/index.ts +++ b/src/types/machineLearning/index.ts @@ -246,6 +246,7 @@ export interface FaceDetectionConfig { export interface ObjectDetectionConfig { method: ObjectDetectionMethod; + maxNumBoxes: number; minScore: number; } @@ -386,7 +387,8 @@ export interface ObjectDetectionService { // init(): Promise; detectObjects( image: ImageBitmap, - minScore?: number + maxNumBoxes: number, + minScore: number ): Promise; dispose(): Promise; }