diff --git a/src/services/machineLearning/ssdMobileNetV2Service.ts b/src/services/machineLearning/ssdMobileNetV2Service.ts index 0e769c2f6..f51b6e21f 100644 --- a/src/services/machineLearning/ssdMobileNetV2Service.ts +++ b/src/services/machineLearning/ssdMobileNetV2Service.ts @@ -12,7 +12,7 @@ import { // SSDMobileNetV2Model, // } from './modelWrapper/SSDMobileNetV2'; -import * as SSDMobileNet from 'ssd-mobilenet'; +import * as SSDMobileNet from '@tensorflow-models/coco-ssd'; class SSDMobileNetV2 implements ObjectDetectionService { private ssdMobileNetV2Model: SSDMobileNet.ObjectDetection; @@ -27,9 +27,9 @@ class SSDMobileNetV2 implements ObjectDetectionService { private async init() { this.ssdMobileNetV2Model = await SSDMobileNet.load({ - base: 'mobilenet_v2', - dataset: 'open_images', - modelUrl: '/models/open-images-ssd-mobilenet-v2/model.json', + // base: 'mobilenet_v2', + // dataset: 'open_images', + // modelUrl: '/models/open-images-ssd-mobilenet-v2/model.json', }); console.log( 'loaded ssdMobileNetV2Model', @@ -78,7 +78,8 @@ class SSDMobileNetV2 implements ObjectDetectionService { public async detectObjectUsingModel(imageBitmap: ImageBitmap) { const ssdMobileNetV2Model = await this.getSSDMobileNetV2Model(); - const predictions = await ssdMobileNetV2Model.detect(imageBitmap); + const tfImage = tf.browser.fromPixels(imageBitmap); + const predictions = await ssdMobileNetV2Model.detect(tfImage); return predictions; }