// Hugging Face file view header: author pdufour, commit fde6d8b ("Update index.js", verified), 6.15 kB.
import { env, AutoTokenizer, RawImage, Tensor } from 'https://cdn.jsdelivr.net/npm/@huggingface/transformers';
import { getModelJSON, getModelFile } from "https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.0.2/src/utils/hub.js";
import * as ort from "https://cdn.jsdelivr.net/npm/onnxruntime-web@1.20.0/dist/ort.webgpu.mjs";
// Model files come from the Hugging Face Hub, so transformers.js must not probe for local copies.
env.allowLocalModels = false;

// DOM elements the demo reads from or writes to (looked up in this order).
const [status, fileUpload, imageContainer, example] = ['status', 'upload', 'container', 'example']
  .map((id) => document.getElementById(id));

// Sample image analysed when the "example" link is clicked.
const EXAMPLE_URL = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg";

// Images are resized to this fixed [width, height] before the vision encoder.
const INPUT_IMAGE_SIZE = [960, 960];
// IMAGE_EMBED_SIZE is the number of sequence slots the image embedding occupies
// (HEIGHT_FACTOR x WIDTH_FACTOR); it is added to the prompt length after encoding.
const HEIGHT_FACTOR = 10;
const WIDTH_FACTOR = 10;
const IMAGE_EMBED_SIZE = HEIGHT_FACTOR * WIDTH_FACTOR;

// Capacity of the fixed-size token buffer and KV cache.
const MAX_SEQ_LENGTH = 1024;
// Upper bound on the number of tokens generated per call.
const MAX_SINGLE_CHAT_LENGTH = 10;

// NOTE(review): ONNX_URL is not referenced by the code in this file (local dev endpoint?).
const ONNX_URL = "http://localhost:3004/onnx";
const BASE_MODEL = "Qwen/Qwen2-VL-2B-Instruct";
const ONNX_MODEL = "pdufour/Qwen2-VL-2B-Instruct-ONNX-Q4-F16";
const QUANT = "q4f16";

// Nothing is loaded between these two assignments, so the page shows 'Ready' immediately.
status.textContent = 'Loading model...';
status.textContent = 'Ready';
// Reports a failed read/analysis in the status line instead of leaving an
// unhandled promise rejection behind.
function showError(err) {
  console.error(err);
  status.textContent = `Error: ${err?.message ?? String(err)}`;
}

// "Example" link: analyse the sample image with a fixed question.
example.addEventListener('click', (event) => {
  event.preventDefault();
  parse(EXAMPLE_URL, 'Describe this image.').catch(showError);
});

// File picker: read the chosen file as a data: URL, then analyse it with an empty question.
fileUpload.addEventListener('change', (event) => {
  const file = event.target.files[0];
  if (!file) {
    return;
  }
  const reader = new FileReader();
  reader.onload = () => {
    parse(reader.result, '').catch(showError);
  };
  // Without this, a failed read would silently leave the UI unchanged.
  reader.onerror = () => showError(reader.error);
  reader.readAsDataURL(file);
});
/**
 * Previews `img` in the container, runs imageTextToText on it and shows the
 * generated text (or the failure) in the status line.
 *
 * @param {string} img - Image URL or data: URL; used as the CSS background and as the model input.
 * @param {string} txt - Question inserted into the prompt after the image ('' for none).
 * @returns {Promise<void>} Resolves once the status line is updated; errors are reported there, not rethrown.
 */
async function parse(img, txt) {
  imageContainer.innerHTML = '';
  // Quote the URL so values containing spaces or parentheses remain a valid CSS url().
  imageContainer.style.backgroundImage = `url("${img}")`;
  status.textContent = 'Analysing...';
  try {
    const output = await imageTextToText(img, txt);
    // Model output is inserted via textContent (never innerHTML); non-string results show as empty.
    status.textContent = typeof output === 'string' ? output : '';
  } catch (err) {
    console.error(err);
    status.textContent = `Error: ${err?.message ?? String(err)}`;
  }
}
// Token ids that end generation (the values the original loop compared against;
// presumably Qwen2's end-of-text / end-of-turn tokens — TODO confirm).
const STOP_TOKEN_IDS = new Set([151643, 151645]);

/**
 * Encodes a JS number as IEEE-754 binary16 bits, i.e. the value stored in the
 * Uint16Array that backs a "float16" onnxruntime-web tensor.
 * Rounds to nearest (ties away from zero); magnitudes beyond the half range
 * become ±Infinity and values below the smallest subnormal become ±0.
 *
 * @param {number} value
 * @returns {number} 16-bit pattern in [0, 0xffff]
 */
function toFloat16Bits(value) {
  const view = new DataView(new ArrayBuffer(4));
  view.setFloat32(0, value);
  const bits = view.getUint32(0);
  const sign = (bits >>> 16) & 0x8000;
  const exponent = (bits >>> 23) & 0xff;
  const mantissa = bits & 0x7fffff;
  if (exponent === 0xff) {
    // Infinity keeps a zero mantissa; NaN becomes a quiet NaN.
    return sign | 0x7c00 | (mantissa === 0 ? 0 : 0x200);
  }
  const halfExponent = exponent - 127 + 15;
  if (halfExponent >= 0x1f) {
    return sign | 0x7c00;
  }
  if (halfExponent <= 0) {
    if (halfExponent < -10) {
      return sign;
    }
    // Subnormal half: shift the full significand (with its implicit leading 1) into place.
    const significand = mantissa | 0x800000;
    const shift = 14 - halfExponent;
    return sign | ((significand >>> shift) + ((significand >>> (shift - 1)) & 1));
  }
  // A rounding carry out of the mantissa correctly bumps the exponent (up to Infinity).
  return sign | (((halfExponent << 10) | (mantissa >>> 13)) + ((mantissa >>> 12) & 1));
}

/** Shape-[1] int64 onnxruntime-web tensor holding `value`. */
function int64Scalar(value) {
  return new ort.Tensor('int64', BigInt64Array.of(BigInt(value)), [1]);
}

/** Shape-[1] int32 onnxruntime-web tensor holding `value`. */
function int32Scalar(value) {
  return new ort.Tensor('int32', Int32Array.of(value), [1]);
}

/** Shape-[1] float16 onnxruntime-web tensor holding `value` (properly encoded, not raw bits). */
function float16Scalar(value) {
  return new ort.Tensor('float16', Uint16Array.of(toFloat16Bits(value)), [1]);
}

/**
 * All-zero float16 cache tensor of shape
 * [num_hidden_layers, num_key_value_heads, MAX_SEQ_LENGTH, hidden_size / num_attention_heads].
 */
function emptyKvCache(config) {
  const dims = [
    config.num_hidden_layers,
    config.num_key_value_heads,
    MAX_SEQ_LENGTH,
    config.hidden_size / config.num_attention_heads,
  ];
  const size = dims.reduce((product, dim) => product * dim, 1);
  // Typed arrays are zero-initialised, so no explicit fill is needed.
  return new ort.Tensor('float16', new Uint16Array(size), dims);
}

/** Downloads `onnx/QwenVL_<part>_<QUANT>.onnx` from ONNX_MODEL and opens an inference session on it. */
async function loadQwenSession(part, executionProvider) {
  const model = await getModelFile(ONNX_MODEL, `onnx/QwenVL_${part}_${QUANT}.onnx`);
  return ort.InferenceSession.create(model, { executionProviders: [executionProvider] });
}

/**
 * Runs the split Qwen2-VL ONNX pipeline on one image and returns the generated answer.
 *
 * Sessions used (inputs/outputs as fed below):
 *   A: pixel_values -> image_embed
 *   B: input_ids, ids_len -> hidden_states (token embedding)
 *   D: prompt hidden_states + image_embed -> merged hidden_states, position_ids
 *   E: one decoder step -> max_logit_ids plus updated past_key_states / past_value_states
 * Sessions are created per call and released before returning.
 *
 * @param {string} imagePath - URL or data: URL of the image.
 * @param {string} query - User text placed after the image in the prompt.
 * @param {boolean} [vision=true] - Must be true; text-only mode is not implemented.
 * @returns {Promise<string>} Decoded generated text (special tokens skipped), at most
 *   MAX_SINGLE_CHAT_LENGTH tokens.
 * @throws {Error} If `vision` is false or the prompt plus image slots exceed MAX_SEQ_LENGTH.
 */
export async function imageTextToText(
  imagePath,
  query,
  vision = true
) {
  if (!vision) {
    // position_ids are only produced by session D on the vision path, and session E needs them.
    throw new Error('imageTextToText: text-only mode (vision = false) is not supported');
  }
  // Presumably the number of prompt tokens up to and including <|vision_start|> — TODO confirm.
  const promptHeadLen = 5;

  const config = await getModelJSON(BASE_MODEL, 'config.json');
  const tokenizer = await AutoTokenizer.from_pretrained(BASE_MODEL);

  const prompt = `\n<|im_start|>user\n<|vision_start|><|vision_end|>${query}<|im_end|>\n<|im_start|>assistant\n`;
  const promptIds = tokenizer(prompt).input_ids;
  const seqLength = promptIds.dims[1];
  if (seqLength + IMAGE_EMBED_SIZE > MAX_SEQ_LENGTH) {
    throw new Error(
      `imageTextToText: ${seqLength} prompt tokens + ${IMAGE_EMBED_SIZE} image slots exceed MAX_SEQ_LENGTH (${MAX_SEQ_LENGTH})`
    );
  }

  // Fixed-size token buffer: holds the whole prompt for prefill, then one token per decode step.
  const inputIds = new ort.Tensor('int32', new Int32Array(MAX_SEQ_LENGTH), [MAX_SEQ_LENGTH]);
  inputIds.data.set(Array.from(promptIds.data.slice(0, seqLength), Number));

  let image = await RawImage.fromURL(imagePath);
  image = await image.resize(INPUT_IMAGE_SIZE[0], INPUT_IMAGE_SIZE[1]);
  const pixels = image.rgb().toTensor('CHW').to('float32').div_(255.0).unsqueeze(0);
  // Re-wrap as an `ort` tensor so every feed comes from the same onnxruntime-web instance.
  const pixelValues = new ort.Tensor('float32', pixels.data, pixels.dims);

  const sessions = [];
  const open = async (part, executionProvider) => {
    const session = await loadQwenSession(part, executionProvider);
    sessions.push(session);
    return session;
  };

  try {
    const sessionA = await open('A', 'webgpu');
    const sessionB = await open('B', 'webgpu');
    const sessionD = await open('D', 'webgpu');
    const sessionE = await open('E', 'wasm');

    // Prefill: embed the text prompt, encode the image, then merge the two.
    let idsLen = seqLength;
    let { hidden_states: hiddenStates } = await sessionB.run({
      input_ids: inputIds,
      ids_len: int64Scalar(idsLen),
    });
    const { image_embed: imageEmbed } = await sessionA.run({ pixel_values: pixelValues });
    idsLen += IMAGE_EMBED_SIZE;
    const merged = await sessionD.run({
      'hidden_states.1': hiddenStates,
      image_embed: imageEmbed,
      ids_len: int64Scalar(idsLen),
      ids_len_minus: int32Scalar(idsLen - promptHeadLen),
      split_factor: int32Scalar(MAX_SEQ_LENGTH - idsLen - IMAGE_EMBED_SIZE),
    });
    hiddenStates = merged.hidden_states;
    const positionIds = merged.position_ids;

    // The all-zero initial cache is only read, so keys and values can share one tensor.
    let pastKeyStates = emptyKvCache(config);
    let pastValueStates = pastKeyStates;
    let historyLen = 0;
    let posFactor = 0;
    // -65504 (most negative finite float16) while the whole prompt is processed; 0 once decoding.
    let attentionMask = float16Scalar(-65504);
    const generated = [];

    while (generated.length < MAX_SINGLE_CHAT_LENGTH && historyLen < MAX_SEQ_LENGTH) {
      const step = await sessionE.run({
        hidden_states: hiddenStates,
        attention_mask: attentionMask,
        'past_key_states.1': pastKeyStates,
        'past_value_states.1': pastValueStates,
        history_len: int64Scalar(historyLen),
        ids_len: int64Scalar(idsLen),
        position_ids: positionIds,
        pos_factor: float16Scalar(posFactor),
      });
      // max_logit_ids is a tensor; compare its value, not the tensor object.
      const tokenId = Number(step.max_logit_ids.data[0]);
      if (STOP_TOKEN_IDS.has(tokenId)) {
        break;
      }
      generated.push(tokenId);
      pastKeyStates = step.past_key_states;
      pastValueStates = step.past_value_states;

      if (generated.length === 1) {
        // The cache now holds the full prompt; from here on feed exactly one new token per step.
        historyLen += idsLen;
        idsLen = 1;
        attentionMask = float16Scalar(0);
      } else {
        historyLen += 1;
      }
      // NOTE(review): keeps the original 0, 1, 2, ... progression (now as real float16 values);
      // confirm whether the export expects an image-dependent offset after prefill.
      posFactor += 1;

      inputIds.data[0] = tokenId;
      ({ hidden_states: hiddenStates } = await sessionB.run({
        input_ids: inputIds,
        ids_len: int64Scalar(idsLen),
      }));
    }

    return tokenizer.decode(generated, { skip_special_tokens: true });
  } finally {
    // Free session resources (e.g. WebGPU buffers) even if a run failed; never mask the original error.
    await Promise.allSettled(sessions.map((session) => session.release()));
  }
}