update getRGBData logic

This commit is contained in:
Abhinav 2024-01-05 11:55:25 +05:30
parent ebe8966ca9
commit 74a5274249

View file

@ -12,6 +12,7 @@ import fetch from 'node-fetch';
import { writeNodeStream } from './fs'; import { writeNodeStream } from './fs';
import { getPlatform } from '../utils/common/platform'; import { getPlatform } from '../utils/common/platform';
import { CustomErrors } from '../constants/errors'; import { CustomErrors } from '../constants/errors';
const jpeg = require('jpeg-js');
const CLIP_MODEL_PATH_PLACEHOLDER = 'CLIP_MODEL'; const CLIP_MODEL_PATH_PLACEHOLDER = 'CLIP_MODEL';
const GGMLCLIP_PATH_PLACEHOLDER = 'GGML_PATH'; const GGMLCLIP_PATH_PLACEHOLDER = 'GGML_PATH';
@ -34,8 +35,7 @@ const TEXT_EMBEDDING_EXTRACT_CMD: string[] = [
]; ];
const ort = require('onnxruntime-node'); const ort = require('onnxruntime-node');
import Tokenizer from '../utils/clip-bpe-ts/mod'; import Tokenizer from '../utils/clip-bpe-ts/mod';
import { readFile } from 'promise-fs';
const { createCanvas, Image } = require('canvas');
const TEXT_MODEL_DOWNLOAD_URL = { const TEXT_MODEL_DOWNLOAD_URL = {
ggml: 'https://models.ente.io/clip-vit-base-patch32_ggml-text-model-f16.gguf', ggml: 'https://models.ente.io/clip-vit-base-patch32_ggml-text-model-f16.gguf',
@ -261,7 +261,7 @@ export async function computeONNXImageEmbedding(
): Promise<Float32Array> { ): Promise<Float32Array> {
try { try {
const imageSession = await getOnnxImageSession(); const imageSession = await getOnnxImageSession();
const rgbData = await getRgbData(inputFilePath); const rgbData = await getRGBData(inputFilePath);
const feeds = { const feeds = {
input: new ort.Tensor('float32', rgbData, [1, 3, 224, 224]), input: new ort.Tensor('float32', rgbData, [1, 3, 224, 224]),
}; };
@ -350,53 +350,72 @@ export async function computeONNXTextEmbedding(
} }
} }
async function getRgbData(inputFilePath: string) { async function getRGBData(inputFilePath: string) {
const width = 224; const jpegData = await readFile(inputFilePath);
const height = 224; const rawImageData = jpeg.decode(jpegData, {
// let blob = await fetch(imgUrl, {referrer:""}).then(r => r.blob()); useTArray: true,
formatAsRGBA: false,
});
const img = new Image(); const nx: number = rawImageData.width;
img.src = inputFilePath; const ny: number = rawImageData.height;
const inputImage: Uint8Array = rawImageData.data;
const canvas = createCanvas(width, height); const nx2: number = 224;
const ctx = canvas.getContext('2d'); const ny2: number = 224;
const totalSize: number = 3 * nx2 * ny2;
// scale img to fit the shorter side to the canvas size const result: number[] = Array(totalSize).fill(0);
const scale = Math.max( const scale: number = Math.max(nx, ny) / 224;
canvas.width / img.width,
canvas.height / img.height
);
// compute new image dimensions that would maintain the original aspect ratio const nx3: number = Math.round(nx / scale);
const scaledW = img.width * scale; const ny3: number = Math.round(ny / scale);
const scaledH = img.height * scale;
// compute position to center the image const mean: number[] = [0.48145466, 0.4578275, 0.40821073];
const posX = (canvas.width - scaledW) / 2; const std: number[] = [0.26862954, 0.26130258, 0.27577711];
const posY = (canvas.height - scaledH) / 2;
// draw the image centered and scaled on the canvas for (let y = 0; y < ny3; y++) {
ctx.drawImage(img, posX, posY, scaledW, scaledH); for (let x = 0; x < nx3; x++) {
for (let c = 0; c < 3; c++) {
// linear interpolation
const sx: number = (x + 0.5) * scale - 0.5;
const sy: number = (y + 0.5) * scale - 0.5;
const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height); const x0: number = Math.max(0, Math.floor(sx));
const rgbData: [number[][], number[][], number[][]] = [[], [], []]; // [r, g, b] const y0: number = Math.max(0, Math.floor(sy));
// remove alpha and put into correct shape:
const d = imageData.data; const x1: number = Math.min(x0 + 1, nx - 1);
for (let i = 0; i < d.length; i += 4) { const y1: number = Math.min(y0 + 1, ny - 1);
const x = (i / 4) % width;
const y = Math.floor(i / 4 / width); const dx: number = sx - x0;
if (!rgbData[0][y]) rgbData[0][y] = []; const dy: number = sy - y0;
if (!rgbData[1][y]) rgbData[1][y] = [];
if (!rgbData[2][y]) rgbData[2][y] = []; const j00: number = 3 * (y0 * nx + x0) + c;
rgbData[0][y][x] = d[i + 0] / 255; const j01: number = 3 * (y0 * nx + x1) + c;
rgbData[1][y][x] = d[i + 1] / 255; const j10: number = 3 * (y1 * nx + x0) + c;
rgbData[2][y][x] = d[i + 2] / 255; const j11: number = 3 * (y1 * nx + x1) + c;
// From CLIP repo: Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
rgbData[0][y][x] = (rgbData[0][y][x] - 0.48145466) / 0.26862954; const v00: number = inputImage[j00];
rgbData[1][y][x] = (rgbData[1][y][x] - 0.4578275) / 0.26130258; const v01: number = inputImage[j01];
rgbData[2][y][x] = (rgbData[2][y][x] - 0.40821073) / 0.27577711; const v10: number = inputImage[j10];
const v11: number = inputImage[j11];
const v0: number = v00 * (1 - dx) + v01 * dx;
const v1: number = v10 * (1 - dx) + v11 * dx;
const v: number = v0 * (1 - dy) + v1 * dy;
const v2: number = Math.min(Math.max(Math.round(v), 0), 255);
// createTensorWithDataList is dump compared to reshape and hence has to be given with one channel after another
const i: number = y * nx3 + x + (c % 3) * 224 * 224;
result[i] = (v2 / 255 - mean[c]) / std[c];
}
}
} }
return Float32Array.from(rgbData.flat().flat());
return result;
} }
export const computeClipMatchScore = async ( export const computeClipMatchScore = async (