Update index.js
index.js CHANGED
@@ -2,6 +2,7 @@ import { env, AutoTokenizer, RawImage, Tensor } from 'https://cdn.jsdelivr.net/n
 import { getModelJSON, getModelFile } from "https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.0.2/src/utils/hub.js";
 import * as ort from "https://cdn.jsdelivr.net/npm/onnxruntime-web@1.20.0/dist/ort.webgpu.mjs";
 
+const EXAMPLE_URL = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg";
 const INPUT_IMAGE_SIZE = [960, 960];
 const HEIGHT_FACTOR = 10;
 const WIDTH_FACTOR = 10;
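The added EXAMPLE_URL points at the sample image wired to the demo link further down; INPUT_IMAGE_SIZE is the fixed input resolution the exported ONNX graphs expect, and HEIGHT_FACTOR / WIDTH_FACTOR are presumably the patch-grid factors. A quick sketch of fetching and resizing the sample with the RawImage helper already imported on line 1 (illustrative only; the real preprocessing lives in imageTextToText, outside these hunks):

// Illustrative sketch, not part of this commit. RawImage comes from the
// transformers import on line 1 of index.js.
let image = await RawImage.fromURL(EXAMPLE_URL);
image = await image.resize(INPUT_IMAGE_SIZE[0], INPUT_IMAGE_SIZE[1]); // 960x960
console.log(image.width, image.height); // 960 960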
@@ -12,9 +13,17 @@ const ONNX_MODEL = "pdufour/Qwen2-VL-2B-Instruct-ONNX-Q4-F16";
 const QUANT = "q4f16";
 const MAX_SINGLE_CHAT_LENGTH = 10;
 
+// UI Elements
+const status = document.getElementById('status');
+const fileUpload = document.getElementById('upload');
+const imageContainer = document.getElementById('container');
+const example = document.getElementById('example');
+
 let ortSessionA, ortSessionB, ortSessionC;
 
 async function initializeSessions() {
+  status.textContent = 'Loading model...';
+
   ortSessionA = await ort.InferenceSession.create(
     await getModelFile(ONNX_MODEL, `onnx/QwenVL_A_${QUANT}.onnx`),
     { executionProviders: ["webgpu"] }
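All three sessions are created with executionProviders pinned to ["webgpu"], so InferenceSession.create will reject on browsers without WebGPU. If a fallback is wanted, one option is to pick the provider list up front; createSession below is a hypothetical helper, not something this commit adds:

// Hypothetical helper: prefer WebGPU when available, else the WASM backend.
async function createSession(modelBuffer) {
  const executionProviders = ('gpu' in navigator) ? ['webgpu'] : ['wasm'];
  return await ort.InferenceSession.create(modelBuffer, { executionProviders });
}

// e.g. ortSessionA = await createSession(await getModelFile(ONNX_MODEL, `onnx/QwenVL_A_${QUANT}.onnx`));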
@@ -29,9 +38,35 @@ async function initializeSessions() {
     await getModelFile(ONNX_MODEL, `onnx/QwenVL_C_${QUANT}.onnx`),
     { executionProviders: ["webgpu"] }
   );
+
+  status.textContent = 'Ready';
 }
 
-export async function imageTextToText(imagePath, query, vision = true) {
+// UI Event Handlers
+example.addEventListener('click', (e) => {
+  e.preventDefault();
+  parse(EXAMPLE_URL, 'Describe this image.');
+});
+
+fileUpload.addEventListener('change', function(e) {
+  const file = e.target.files[0];
+  if (!file) return;
+
+  const reader = new FileReader();
+  reader.onload = e2 => parse(e2.target.result, '');
+  reader.readAsDataURL(file);
+});
+
+async function parse(img, txt) {
+  imageContainer.innerHTML = '';
+  imageContainer.style.backgroundImage = `url(${img})`;
+  status.textContent = 'Analysing...';
+  const output = await imageTextToText(img, txt);
+  status.textContent = '';
+  imageContainer.textContent = output;
+}
+
+async function imageTextToText(imagePath, query, vision = true) {
   const config = await getModelJSON(BASE_MODEL, "config.json");
 
   const prompt_head_len = new Tensor("int64", new BigInt64Array([5n]), [1]);
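The two new handlers converge on parse(): the example link passes a plain URL, while the upload path reads the file as a base64 data: URI, so downstream code sees either form of image source. A sketch of the resulting flow, assuming index.js runs as a module and calls initializeSessions() at load (that call sits outside the hunks shown):

// Sketch of the flow this change wires up (the boot call is assumed):
await initializeSessions();                        // 'Loading model...' -> 'Ready'
await parse(EXAMPLE_URL, 'Describe this image.');  // what the example link triggers
// An upload takes the same route via FileReader:
//   reader.readAsDataURL(file) -> 'data:image/jpeg;base64,...' -> parse(dataUrl, '')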
@@ -72,7 +107,6 @@ export async function imageTextToText(imagePath, query, vision = true) {
   );
   input_ids.data.set(Array.from(token.data.slice(0, token.dims[1]), Number));
 
-  // Get position IDs from Session C
   const dummy = new ort.Tensor("int32", new Int32Array([0]), []);
   let { position_ids } = await ortSessionC.run({ dummy });
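With the inline comment gone, it is worth keeping in mind what this pair of lines does: session C takes a throwaway scalar and returns the position IDs that seed the decoding loop. A hedged sketch of the same call with its output inspected; the logged shape is an illustrative assumption, not read from the model:

// Illustrative: the empty dims array makes `dummy` a scalar tensor.
const dummy = new ort.Tensor("int32", new Int32Array([0]), []);
const { position_ids } = await ortSessionC.run({ dummy });
console.log(position_ids.dims); // e.g. [3, 1, seq_len] for Qwen2-VL-style multimodal RoPE (assumed)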