update getRGBData logic
This commit is contained in:
parent
ebe8966ca9
commit
74a5274249
|
@ -12,6 +12,7 @@ import fetch from 'node-fetch';
|
||||||
import { writeNodeStream } from './fs';
|
import { writeNodeStream } from './fs';
|
||||||
import { getPlatform } from '../utils/common/platform';
|
import { getPlatform } from '../utils/common/platform';
|
||||||
import { CustomErrors } from '../constants/errors';
|
import { CustomErrors } from '../constants/errors';
|
||||||
|
const jpeg = require('jpeg-js');
|
||||||
|
|
||||||
const CLIP_MODEL_PATH_PLACEHOLDER = 'CLIP_MODEL';
|
const CLIP_MODEL_PATH_PLACEHOLDER = 'CLIP_MODEL';
|
||||||
const GGMLCLIP_PATH_PLACEHOLDER = 'GGML_PATH';
|
const GGMLCLIP_PATH_PLACEHOLDER = 'GGML_PATH';
|
||||||
|
@ -34,8 +35,7 @@ const TEXT_EMBEDDING_EXTRACT_CMD: string[] = [
|
||||||
];
|
];
|
||||||
const ort = require('onnxruntime-node');
|
const ort = require('onnxruntime-node');
|
||||||
import Tokenizer from '../utils/clip-bpe-ts/mod';
|
import Tokenizer from '../utils/clip-bpe-ts/mod';
|
||||||
|
import { readFile } from 'promise-fs';
|
||||||
const { createCanvas, Image } = require('canvas');
|
|
||||||
|
|
||||||
const TEXT_MODEL_DOWNLOAD_URL = {
|
const TEXT_MODEL_DOWNLOAD_URL = {
|
||||||
ggml: 'https://models.ente.io/clip-vit-base-patch32_ggml-text-model-f16.gguf',
|
ggml: 'https://models.ente.io/clip-vit-base-patch32_ggml-text-model-f16.gguf',
|
||||||
|
@ -261,7 +261,7 @@ export async function computeONNXImageEmbedding(
|
||||||
): Promise<Float32Array> {
|
): Promise<Float32Array> {
|
||||||
try {
|
try {
|
||||||
const imageSession = await getOnnxImageSession();
|
const imageSession = await getOnnxImageSession();
|
||||||
const rgbData = await getRgbData(inputFilePath);
|
const rgbData = await getRGBData(inputFilePath);
|
||||||
const feeds = {
|
const feeds = {
|
||||||
input: new ort.Tensor('float32', rgbData, [1, 3, 224, 224]),
|
input: new ort.Tensor('float32', rgbData, [1, 3, 224, 224]),
|
||||||
};
|
};
|
||||||
|
@ -350,53 +350,72 @@ export async function computeONNXTextEmbedding(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async function getRgbData(inputFilePath: string) {
|
async function getRGBData(inputFilePath: string) {
|
||||||
const width = 224;
|
const jpegData = await readFile(inputFilePath);
|
||||||
const height = 224;
|
const rawImageData = jpeg.decode(jpegData, {
|
||||||
// let blob = await fetch(imgUrl, {referrer:""}).then(r => r.blob());
|
useTArray: true,
|
||||||
|
formatAsRGBA: false,
|
||||||
|
});
|
||||||
|
|
||||||
const img = new Image();
|
const nx: number = rawImageData.width;
|
||||||
img.src = inputFilePath;
|
const ny: number = rawImageData.height;
|
||||||
|
const inputImage: Uint8Array = rawImageData.data;
|
||||||
|
|
||||||
const canvas = createCanvas(width, height);
|
const nx2: number = 224;
|
||||||
const ctx = canvas.getContext('2d');
|
const ny2: number = 224;
|
||||||
|
const totalSize: number = 3 * nx2 * ny2;
|
||||||
|
|
||||||
// scale img to fit the shorter side to the canvas size
|
const result: number[] = Array(totalSize).fill(0);
|
||||||
const scale = Math.max(
|
const scale: number = Math.max(nx, ny) / 224;
|
||||||
canvas.width / img.width,
|
|
||||||
canvas.height / img.height
|
|
||||||
);
|
|
||||||
|
|
||||||
// compute new image dimensions that would maintain the original aspect ratio
|
const nx3: number = Math.round(nx / scale);
|
||||||
const scaledW = img.width * scale;
|
const ny3: number = Math.round(ny / scale);
|
||||||
const scaledH = img.height * scale;
|
|
||||||
|
|
||||||
// compute position to center the image
|
const mean: number[] = [0.48145466, 0.4578275, 0.40821073];
|
||||||
const posX = (canvas.width - scaledW) / 2;
|
const std: number[] = [0.26862954, 0.26130258, 0.27577711];
|
||||||
const posY = (canvas.height - scaledH) / 2;
|
|
||||||
|
|
||||||
// draw the image centered and scaled on the canvas
|
for (let y = 0; y < ny3; y++) {
|
||||||
ctx.drawImage(img, posX, posY, scaledW, scaledH);
|
for (let x = 0; x < nx3; x++) {
|
||||||
|
for (let c = 0; c < 3; c++) {
|
||||||
|
// linear interpolation
|
||||||
|
const sx: number = (x + 0.5) * scale - 0.5;
|
||||||
|
const sy: number = (y + 0.5) * scale - 0.5;
|
||||||
|
|
||||||
const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);
|
const x0: number = Math.max(0, Math.floor(sx));
|
||||||
const rgbData: [number[][], number[][], number[][]] = [[], [], []]; // [r, g, b]
|
const y0: number = Math.max(0, Math.floor(sy));
|
||||||
// remove alpha and put into correct shape:
|
|
||||||
const d = imageData.data;
|
const x1: number = Math.min(x0 + 1, nx - 1);
|
||||||
for (let i = 0; i < d.length; i += 4) {
|
const y1: number = Math.min(y0 + 1, ny - 1);
|
||||||
const x = (i / 4) % width;
|
|
||||||
const y = Math.floor(i / 4 / width);
|
const dx: number = sx - x0;
|
||||||
if (!rgbData[0][y]) rgbData[0][y] = [];
|
const dy: number = sy - y0;
|
||||||
if (!rgbData[1][y]) rgbData[1][y] = [];
|
|
||||||
if (!rgbData[2][y]) rgbData[2][y] = [];
|
const j00: number = 3 * (y0 * nx + x0) + c;
|
||||||
rgbData[0][y][x] = d[i + 0] / 255;
|
const j01: number = 3 * (y0 * nx + x1) + c;
|
||||||
rgbData[1][y][x] = d[i + 1] / 255;
|
const j10: number = 3 * (y1 * nx + x0) + c;
|
||||||
rgbData[2][y][x] = d[i + 2] / 255;
|
const j11: number = 3 * (y1 * nx + x1) + c;
|
||||||
// From CLIP repo: Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
|
|
||||||
rgbData[0][y][x] = (rgbData[0][y][x] - 0.48145466) / 0.26862954;
|
const v00: number = inputImage[j00];
|
||||||
rgbData[1][y][x] = (rgbData[1][y][x] - 0.4578275) / 0.26130258;
|
const v01: number = inputImage[j01];
|
||||||
rgbData[2][y][x] = (rgbData[2][y][x] - 0.40821073) / 0.27577711;
|
const v10: number = inputImage[j10];
|
||||||
|
const v11: number = inputImage[j11];
|
||||||
|
|
||||||
|
const v0: number = v00 * (1 - dx) + v01 * dx;
|
||||||
|
const v1: number = v10 * (1 - dx) + v11 * dx;
|
||||||
|
|
||||||
|
const v: number = v0 * (1 - dy) + v1 * dy;
|
||||||
|
|
||||||
|
const v2: number = Math.min(Math.max(Math.round(v), 0), 255);
|
||||||
|
|
||||||
|
// createTensorWithDataList is dump compared to reshape and hence has to be given with one channel after another
|
||||||
|
const i: number = y * nx3 + x + (c % 3) * 224 * 224;
|
||||||
|
|
||||||
|
result[i] = (v2 / 255 - mean[c]) / std[c];
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return Float32Array.from(rgbData.flat().flat());
|
|
||||||
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
export const computeClipMatchScore = async (
|
export const computeClipMatchScore = async (
|
||||||
|
|
Loading…
Reference in a new issue