pdufour committed on
Commit
b2ca260
1 Parent(s): 4c4928f

Update index.js

Browse files
Files changed (1) hide show
  1. index.js +36 -2
index.js CHANGED
@@ -2,6 +2,7 @@ import { env, AutoTokenizer, RawImage, Tensor } from 'https://cdn.jsdelivr.net/n
2
  import { getModelJSON, getModelFile } from "https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.0.2/src/utils/hub.js";
3
  import * as ort from "https://cdn.jsdelivr.net/npm/onnxruntime-web@1.20.0/dist/ort.webgpu.mjs";
4
 
 
5
  const INPUT_IMAGE_SIZE = [960, 960];
6
  const HEIGHT_FACTOR = 10;
7
  const WIDTH_FACTOR = 10;
@@ -12,9 +13,17 @@ const ONNX_MODEL = "pdufour/Qwen2-VL-2B-Instruct-ONNX-Q4-F16";
12
  const QUANT = "q4f16";
13
  const MAX_SINGLE_CHAT_LENGTH = 10;
14
 
 
 
 
 
 
 
15
  let ortSessionA, ortSessionB, ortSessionC;
16
 
17
  async function initializeSessions() {
 
 
18
  ortSessionA = await ort.InferenceSession.create(
19
  await getModelFile(ONNX_MODEL, `onnx/QwenVL_A_${QUANT}.onnx`),
20
  { executionProviders: ["webgpu"] }
@@ -29,9 +38,35 @@ async function initializeSessions() {
29
  await getModelFile(ONNX_MODEL, `onnx/QwenVL_C_${QUANT}.onnx`),
30
  { executionProviders: ["webgpu"] }
31
  );
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  }
33
 
34
- export async function imageTextToText(imagePath, query, vision = true) {
35
  const config = await getModelJSON(BASE_MODEL, "config.json");
36
 
37
  const prompt_head_len = new Tensor("int64", new BigInt64Array([5n]), [1]);
@@ -72,7 +107,6 @@ export async function imageTextToText(imagePath, query, vision = true) {
72
  );
73
  input_ids.data.set(Array.from(token.data.slice(0, token.dims[1]), Number));
74
 
75
- // Get position IDs from Session C
76
  const dummy = new ort.Tensor("int32", new Int32Array([0]), []);
77
  let { position_ids } = await ortSessionC.run({ dummy });
78
 
 
2
  import { getModelJSON, getModelFile } from "https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.0.2/src/utils/hub.js";
3
  import * as ort from "https://cdn.jsdelivr.net/npm/onnxruntime-web@1.20.0/dist/ort.webgpu.mjs";
4
 
5
+ const EXAMPLE_URL = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg";
6
  const INPUT_IMAGE_SIZE = [960, 960];
7
  const HEIGHT_FACTOR = 10;
8
  const WIDTH_FACTOR = 10;
 
13
  const QUANT = "q4f16";
14
  const MAX_SINGLE_CHAT_LENGTH = 10;
15
 
16
+ // UI Elements
17
+ const status = document.getElementById('status');
18
+ const fileUpload = document.getElementById('upload');
19
+ const imageContainer = document.getElementById('container');
20
+ const example = document.getElementById('example');
21
+
22
  let ortSessionA, ortSessionB, ortSessionC;
23
 
24
  async function initializeSessions() {
25
+ status.textContent = 'Loading model...';
26
+
27
  ortSessionA = await ort.InferenceSession.create(
28
  await getModelFile(ONNX_MODEL, `onnx/QwenVL_A_${QUANT}.onnx`),
29
  { executionProviders: ["webgpu"] }
 
38
  await getModelFile(ONNX_MODEL, `onnx/QwenVL_C_${QUANT}.onnx`),
39
  { executionProviders: ["webgpu"] }
40
  );
41
+
42
+ status.textContent = 'Ready';
43
+ }
44
+
45
// UI Event Handlers
// Clicking the example link runs the bundled demo image through the pipeline.
const onExampleClick = (event) => {
  event.preventDefault();
  parse(EXAMPLE_URL, 'Describe this image.');
};
example.addEventListener('click', onExampleClick);
50
+
51
// Read a user-selected image file as a data URL and run it through the
// pipeline with an empty prompt.
// Fix: the original attached no `onerror` handler to the FileReader, so a
// failed read gave the user no feedback at all; errors are now surfaced
// via the status element.
fileUpload.addEventListener('change', (e) => {
  const file = e.target.files[0];
  if (!file) return;

  const reader = new FileReader();
  reader.onload = (loadEvent) => parse(loadEvent.target.result, '');
  reader.onerror = () => {
    status.textContent = 'Failed to read the selected file.';
  };
  reader.readAsDataURL(file);
});
59
+
60
/**
 * Show the image in the container, run the vision-language model on it,
 * and display the generated text.
 *
 * Fix: the original had no error handling around `imageTextToText`, so any
 * inference failure left the status stuck at 'Analysing...' and produced an
 * unhandled promise rejection. Failures are now caught and reported.
 *
 * @param {string} img - Image source (URL or data URL) to analyse.
 * @param {string} txt - User prompt passed to the model (may be empty).
 */
async function parse(img, txt) {
  imageContainer.innerHTML = '';
  imageContainer.style.backgroundImage = `url(${img})`;
  status.textContent = 'Analysing...';
  try {
    const output = await imageTextToText(img, txt);
    status.textContent = '';
    imageContainer.textContent = output;
  } catch (err) {
    status.textContent = 'Analysis failed — see console for details.';
    console.error(err);
  }
}
68
 
69
+ async function imageTextToText(imagePath, query, vision = true) {
70
  const config = await getModelJSON(BASE_MODEL, "config.json");
71
 
72
  const prompt_head_len = new Tensor("int64", new BigInt64Array([5n]), [1]);
 
107
  );
108
  input_ids.data.set(Array.from(token.data.slice(0, token.dims[1]), Number));
109
 
 
110
  const dummy = new ort.Tensor("int32", new Int32Array([0]), []);
111
  let { position_ids } = await ortSessionC.run({ dummy });
112