set max bounding boxes of objects to be detected

This commit is contained in:
Abhinav 2022-04-12 16:22:41 +05:30
parent 952d00df45
commit 2dce1407b0
5 changed files with 9 additions and 3 deletions

View file

@ -97,6 +97,7 @@ export default function MLFileDebugView(props: MLFileDebugViewProps) {
const objectDetections = await ssdMobileNetV2Service.detectObjects( const objectDetections = await ssdMobileNetV2Service.detectObjects(
imageBitmap, imageBitmap,
DEFAULT_ML_SYNC_CONFIG.objectDetection.maxNumBoxes,
DEFAULT_ML_SYNC_CONFIG.objectDetection.minScore DEFAULT_ML_SYNC_CONFIG.objectDetection.minScore
); );
console.log('detectedObjects: ', objectDetections); console.log('detectedObjects: ', objectDetections);

View file

@ -45,6 +45,7 @@ export const DEFAULT_ML_SYNC_CONFIG: MLSyncConfig = {
}, },
objectDetection: { objectDetection: {
method: 'SSDMobileNetV2', method: 'SSDMobileNetV2',
maxNumBoxes: 20,
minScore: 0.2, minScore: 0.2,
}, },
textDetection: { textDetection: {

View file

@ -42,6 +42,7 @@ class ObjectService {
const objectDetections = const objectDetections =
await syncContext.objectDetectionService.detectObjects( await syncContext.objectDetectionService.detectObjects(
imageBitmap, imageBitmap,
syncContext.config.objectDetection.maxNumBoxes,
syncContext.config.objectDetection.minScore syncContext.config.objectDetection.minScore
); );
// console.log('3 TF Memory stats: ', tf.memory()); // console.log('3 TF Memory stats: ', tf.memory());

View file

@ -42,13 +42,14 @@ class SSDMobileNetV2 implements ObjectDetectionService {
public async detectObjects( public async detectObjects(
image: ImageBitmap, image: ImageBitmap,
minScore?: number maxNumberBoxes: number,
minScore: number
): Promise<ObjectDetection[]> { ): Promise<ObjectDetection[]> {
const ssdMobileNetV2Model = await this.getSSDMobileNetV2Model(); const ssdMobileNetV2Model = await this.getSSDMobileNetV2Model();
const tfImage = tf.browser.fromPixels(image); const tfImage = tf.browser.fromPixels(image);
const detections = await ssdMobileNetV2Model.detect( const detections = await ssdMobileNetV2Model.detect(
tfImage, tfImage,
undefined, maxNumberBoxes,
minScore minScore
); );
return detections; return detections;

View file

@ -246,6 +246,7 @@ export interface FaceDetectionConfig {
export interface ObjectDetectionConfig { export interface ObjectDetectionConfig {
method: ObjectDetectionMethod; method: ObjectDetectionMethod;
maxNumBoxes: number;
minScore: number; minScore: number;
} }
@ -386,7 +387,8 @@ export interface ObjectDetectionService {
// init(): Promise<void>; // init(): Promise<void>;
detectObjects( detectObjects(
image: ImageBitmap, image: ImageBitmap,
minScore?: number maxNumBoxes: number,
minScore: number
): Promise<ObjectDetection[]>; ): Promise<ObjectDetection[]>;
dispose(): Promise<void>; dispose(): Promise<void>;
} }