import { AutoTokenizer, RawImage, Tensor } from "https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.0.2";
import { getModelJSON, getModelFile } from "https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.0.2/src/utils/hub.js";
import * as ort from "https://cdn.jsdelivr.net/npm/onnxruntime-web@1.20.0/dist/ort.webgpu.mjs";
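// Processing constants: input images are resized to 960x960 and encoded as a
// WIDTH_FACTOR x HEIGHT_FACTOR grid of vision tokens (IMAGE_EMBED_SIZE = 100).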
const EXAMPLE_URL = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg";
const INPUT_IMAGE_SIZE = [960, 960];
const HEIGHT_FACTOR = 10;
const WIDTH_FACTOR = 10;
const IMAGE_EMBED_SIZE = WIDTH_FACTOR * HEIGHT_FACTOR;
const MAX_SEQ_LENGTH = 1024;
const BASE_MODEL = "Qwen/Qwen2-VL-2B-Instruct";
const ONNX_MODEL = "pdufour/Qwen2-VL-2B-Instruct-ONNX-Q4-F16";
const QUANT = "q4f16";
const MAX_SINGLE_CHAT_LENGTH = 10;
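// The ONNX export is split into several parts, run as separate sessions:
//   A: vision encoder (pixel_values -> image_embed)
//   B: token embedding (input_ids -> hidden_states)
//   C: position id generation
//   D: merges the image embeddings into the text hidden states
//   E: single autoregressive decode step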
// UI Elements
const status = document.getElementById('status');
const fileUpload = document.getElementById('upload');
const imageContainer = document.getElementById('container');
const example = document.getElementById('example');
let ortSessionA, ortSessionB, ortSessionC;
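// Sessions A-C are created up front; D and E are created on demand in imageTextToText().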
async function initializeSessions() {
status.textContent = 'Loading model...';
ortSessionA = await ort.InferenceSession.create(
await getModelFile(ONNX_MODEL, `onnx/QwenVL_A_${QUANT}.onnx`),
{ executionProviders: ["webgpu"] }
);
console.log({ortSessionA});
ortSessionB = await ort.InferenceSession.create(
await getModelFile(ONNX_MODEL, `onnx/QwenVL_B_${QUANT}.onnx`),
{ executionProviders: ["webgpu"] }
);
console.log({ortSessionB});
ortSessionC = await ort.InferenceSession.create(
await getModelFile(ONNX_MODEL, `onnx/QwenVL_C_${QUANT}.onnx`),
{ executionProviders: ["webgpu"] }
);
console.log({ortSessionC});
status.textContent = 'Ready';
}
// UI Event Handlers
example.addEventListener('click', (e) => {
e.preventDefault();
parse(EXAMPLE_URL, 'Describe this image.');
});
fileUpload.addEventListener('change', function(e) {
const file = e.target.files[0];
if (!file) return;
const reader = new FileReader();
reader.onload = e2 => parse(e2.target.result, '');
reader.readAsDataURL(file);
});
async function parse(img, txt) {
imageContainer.innerHTML = '';
imageContainer.style.backgroundImage = `url(${img})`;
status.textContent = 'Analysing...';
const output = await imageTextToText(img, txt);
status.textContent = output;
}
async function imageTextToText(imagePath, query, vision = true) {
const config = await getModelJSON(BASE_MODEL, "config.json");
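// Generation-state tensors. prompt_head_len is the number of prompt tokens that
// precede the image slot; attention_mask is fp16 stored as raw uint16 bits, and
// 0xfbff is the bit pattern for -65504 (the most negative finite fp16 value).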
const prompt_head_len = new Tensor("int64", new BigInt64Array([5n]), [1]);
let history_len = new Tensor("int64", new BigInt64Array([0n]), [1]);
let pos_factor = new Tensor("float16", new Uint16Array([0]), [1]);
let attention_mask = new ort.Tensor("float16", new Uint16Array([0xfbff]), [1]);
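// Preallocate an empty fp16 KV cache of shape
// [num_hidden_layers, num_key_value_heads, MAX_SEQ_LENGTH, head_dim];
// the same zero-filled tensor doubles as the initial value cache below.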
let past_key_states = new ort.Tensor(
"float16",
new Uint16Array(
config.num_hidden_layers *
config.num_key_value_heads *
MAX_SEQ_LENGTH *
(config.hidden_size / config.num_attention_heads)
).fill(0),
[
config.num_hidden_layers,
config.num_key_value_heads,
MAX_SEQ_LENGTH,
config.hidden_size / config.num_attention_heads,
]
);
let past_value_states = past_key_states;
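// Qwen2-VL chat prompt; the empty <|vision_start|><|vision_end|> pair marks where
// the image embeddings are merged in when vision is enabled.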
const tokenizer = await AutoTokenizer.from_pretrained(BASE_MODEL);
const prompt = `\n<|im_start|>user\n<|vision_start|><|vision_end|>${query}<|im_end|>\n<|im_start|>assistant\n`;
const token = await tokenizer(prompt, {
return_tensors: "pt",
add_generation_prompt: false,
tokenize: true,
}).input_ids;
let ids_len = new Tensor("int64", new BigInt64Array([BigInt(token.dims[1])]), [1]);
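// Fixed-size input buffer: the prompt tokens fill the front, the rest stays zero-padded.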
let input_ids = new ort.Tensor(
"int32",
new Int32Array(MAX_SEQ_LENGTH).fill(0),
[MAX_SEQ_LENGTH]
);
input_ids.data.set(Array.from(token.data.slice(0, token.dims[1]), Number));
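// Session B embeds the padded prompt tokens; session C produces the initial position ids.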
let { hidden_states } = await ortSessionB.run({
input_ids: input_ids,
ids_len: ids_len,
});
const dummy = new ort.Tensor("int32", new Int32Array([0]), []);
let { position_ids } = await ortSessionC.run({ dummy });
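// Vision path: encode the image with session A, then merge the resulting image
// embeddings into the prompt hidden states with session D.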
if (vision) {
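// Resize to 960x960 and convert to a CHW float32 tensor scaled to [0, 1], with a batch dim.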
let image = await RawImage.fromURL(imagePath);
image = await image.resize(INPUT_IMAGE_SIZE[0], INPUT_IMAGE_SIZE[1]);
image = image.rgb().toTensor("CHW").to("float32").div_(255.0);
const pixel_values = image.unsqueeze(0);
console.log('run session a');
const { image_embed } = await ortSessionA.run({ pixel_values });
console.log('finished session a');
ids_len = ids_len.add(BigInt(IMAGE_EMBED_SIZE));
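// Session D inserts the image embeddings into the hidden states; ids_len_minus and
// split_factor control where the IMAGE_EMBED_SIZE embeddings land in the padded sequence.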
const ortSessionD = await ort.InferenceSession.create(
await getModelFile(ONNX_MODEL, `onnx/QwenVL_D_${QUANT}.onnx`),
{ executionProviders: ["webgpu"] }
);
console.log('run session d');
const result = await ortSessionD.run({
"hidden_states.1": past_key_states,
image_embed,
ids_len,
"ids_len_minus": new Tensor(
"int32",
new Int32Array([Number(ids_len.item()) - Number(prompt_head_len.item())]),
[1]
),
"split_factor": new Tensor(
"int32",
new Int32Array([MAX_SEQ_LENGTH - Number(ids_len.item()) - IMAGE_EMBED_SIZE]),
[1]
),
});
console.log('finished session d');
hidden_states = result.hidden_states;
position_ids = result.position_ids;
}
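// Decode loop: session E predicts one token per step and updates the KV cache;
// session B then embeds the new token for the next step.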
let num_decode = 0;
let output = '';
// Create session E once and reuse it for every decode step.
const ortSessionE = await ort.InferenceSession.create(
await getModelFile(ONNX_MODEL, `onnx/QwenVL_E_${QUANT}.onnx`),
{ executionProviders: ["wasm"] }
);
while (num_decode < MAX_SINGLE_CHAT_LENGTH && Number(history_len.data[0]) < MAX_SEQ_LENGTH) {
const result = await ortSessionE.run({
hidden_states,
attention_mask,
"past_key_states.1": past_key_states,
"past_value_states.1": past_value_states,
history_len,
ids_len,
position_ids,
pos_factor,
});
console.log('finished session e');
// 151643 is <|endoftext|> and 151645 is <|im_end|>; either one ends generation.
const token_id = Number(result.max_logit_ids.data[0]);
if (token_id === 151643 || token_id === 151645) break;
output += tokenizer.decode([token_id]);
num_decode++;
history_len = history_len.add(BigInt(1));
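// pos_factor is fp16 backed by raw uint16 bits; this increments the stored bit pattern.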
pos_factor = new Tensor(
"float16",
new Uint16Array([Number(pos_factor.data[0]) + 1]),
[1]
);
past_key_states = result.past_key_states;
past_value_states = result.past_value_states;
input_ids.data[0] = token_id;
// Embed the newly generated token for the next decode step.
({ hidden_states } = await ortSessionB.run({
input_ids,
ids_len,
}));
}
return output;
}
await initializeSessions();