update getRGBData logic

This commit is contained in:
Abhinav 2024-01-05 11:55:25 +05:30
parent ebe8966ca9
commit 74a5274249

View file

@ -12,6 +12,7 @@ import fetch from 'node-fetch';
import { writeNodeStream } from './fs';
import { getPlatform } from '../utils/common/platform';
import { CustomErrors } from '../constants/errors';
const jpeg = require('jpeg-js');
const CLIP_MODEL_PATH_PLACEHOLDER = 'CLIP_MODEL';
const GGMLCLIP_PATH_PLACEHOLDER = 'GGML_PATH';
@ -34,8 +35,7 @@ const TEXT_EMBEDDING_EXTRACT_CMD: string[] = [
];
const ort = require('onnxruntime-node');
import Tokenizer from '../utils/clip-bpe-ts/mod';
const { createCanvas, Image } = require('canvas');
import { readFile } from 'promise-fs';
const TEXT_MODEL_DOWNLOAD_URL = {
ggml: 'https://models.ente.io/clip-vit-base-patch32_ggml-text-model-f16.gguf',
@ -261,7 +261,7 @@ export async function computeONNXImageEmbedding(
): Promise<Float32Array> {
try {
const imageSession = await getOnnxImageSession();
const rgbData = await getRgbData(inputFilePath);
const rgbData = await getRGBData(inputFilePath);
const feeds = {
input: new ort.Tensor('float32', rgbData, [1, 3, 224, 224]),
};
@ -350,53 +350,72 @@ export async function computeONNXTextEmbedding(
}
}
async function getRgbData(inputFilePath: string) {
const width = 224;
const height = 224;
// let blob = await fetch(imgUrl, {referrer:""}).then(r => r.blob());
async function getRGBData(inputFilePath: string) {
const jpegData = await readFile(inputFilePath);
const rawImageData = jpeg.decode(jpegData, {
useTArray: true,
formatAsRGBA: false,
});
const img = new Image();
img.src = inputFilePath;
const nx: number = rawImageData.width;
const ny: number = rawImageData.height;
const inputImage: Uint8Array = rawImageData.data;
const canvas = createCanvas(width, height);
const ctx = canvas.getContext('2d');
const nx2: number = 224;
const ny2: number = 224;
const totalSize: number = 3 * nx2 * ny2;
// scale img to fit the shorter side to the canvas size
const scale = Math.max(
canvas.width / img.width,
canvas.height / img.height
);
const result: number[] = Array(totalSize).fill(0);
const scale: number = Math.max(nx, ny) / 224;
// compute new image dimensions that would maintain the original aspect ratio
const scaledW = img.width * scale;
const scaledH = img.height * scale;
const nx3: number = Math.round(nx / scale);
const ny3: number = Math.round(ny / scale);
// compute position to center the image
const posX = (canvas.width - scaledW) / 2;
const posY = (canvas.height - scaledH) / 2;
const mean: number[] = [0.48145466, 0.4578275, 0.40821073];
const std: number[] = [0.26862954, 0.26130258, 0.27577711];
// draw the image centered and scaled on the canvas
ctx.drawImage(img, posX, posY, scaledW, scaledH);
for (let y = 0; y < ny3; y++) {
for (let x = 0; x < nx3; x++) {
for (let c = 0; c < 3; c++) {
// linear interpolation
const sx: number = (x + 0.5) * scale - 0.5;
const sy: number = (y + 0.5) * scale - 0.5;
const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);
const rgbData: [number[][], number[][], number[][]] = [[], [], []]; // [r, g, b]
// remove alpha and put into correct shape:
const d = imageData.data;
for (let i = 0; i < d.length; i += 4) {
const x = (i / 4) % width;
const y = Math.floor(i / 4 / width);
if (!rgbData[0][y]) rgbData[0][y] = [];
if (!rgbData[1][y]) rgbData[1][y] = [];
if (!rgbData[2][y]) rgbData[2][y] = [];
rgbData[0][y][x] = d[i + 0] / 255;
rgbData[1][y][x] = d[i + 1] / 255;
rgbData[2][y][x] = d[i + 2] / 255;
// From CLIP repo: Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
rgbData[0][y][x] = (rgbData[0][y][x] - 0.48145466) / 0.26862954;
rgbData[1][y][x] = (rgbData[1][y][x] - 0.4578275) / 0.26130258;
rgbData[2][y][x] = (rgbData[2][y][x] - 0.40821073) / 0.27577711;
const x0: number = Math.max(0, Math.floor(sx));
const y0: number = Math.max(0, Math.floor(sy));
const x1: number = Math.min(x0 + 1, nx - 1);
const y1: number = Math.min(y0 + 1, ny - 1);
const dx: number = sx - x0;
const dy: number = sy - y0;
const j00: number = 3 * (y0 * nx + x0) + c;
const j01: number = 3 * (y0 * nx + x1) + c;
const j10: number = 3 * (y1 * nx + x0) + c;
const j11: number = 3 * (y1 * nx + x1) + c;
const v00: number = inputImage[j00];
const v01: number = inputImage[j01];
const v10: number = inputImage[j10];
const v11: number = inputImage[j11];
const v0: number = v00 * (1 - dx) + v01 * dx;
const v1: number = v10 * (1 - dx) + v11 * dx;
const v: number = v0 * (1 - dy) + v1 * dy;
const v2: number = Math.min(Math.max(Math.round(v), 0), 255);
// createTensorWithDataList is dump compared to reshape and hence has to be given with one channel after another
const i: number = y * nx3 + x + (c % 3) * 224 * 224;
result[i] = (v2 / 255 - mean[c]) / std[c];
}
}
}
return Float32Array.from(rgbData.flat().flat());
return result;
}
export const computeClipMatchScore = async (