pdufour committed
Commit
1477d49
1 Parent(s): a589fe4

Update index.js

Files changed (1)
1. index.js (+60 -108)
index.js CHANGED
@@ -1,81 +1,51 @@
- import { env, AutoTokenizer, RawImage, Tensor } from 'https://cdn.jsdelivr.net/npm/@huggingface/transformers';
- import { getModelJSON, getModelFile } from "https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.0.2/src/utils/hub.js";
- import * as ort from "https://cdn.jsdelivr.net/npm/onnxruntime-web@1.20.0/dist/ort.webgpu.mjs";

- // Since we will download the model from the Hugging Face Hub, we can skip the local model check
- env.allowLocalModels = false;
-
- // Reference the elements that we will need
- const status = document.getElementById('status');
- const fileUpload = document.getElementById('upload');
- const imageContainer = document.getElementById('container');
- const example = document.getElementById('example');
-
- const EXAMPLE_URL = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg";
  const INPUT_IMAGE_SIZE = [960, 960];
  const HEIGHT_FACTOR = 10;
  const WIDTH_FACTOR = 10;
  const IMAGE_EMBED_SIZE = WIDTH_FACTOR * HEIGHT_FACTOR;
  const MAX_SEQ_LENGTH = 1024;
- const ONNX_URL = "http://localhost:3004/onnx";
  const BASE_MODEL = "Qwen/Qwen2-VL-2B-Instruct";
  const ONNX_MODEL = "pdufour/Qwen2-VL-2B-Instruct-ONNX-Q4-F16";
  const QUANT = "q4f16";
  const MAX_SINGLE_CHAT_LENGTH = 10;

- status.textContent = 'Loading model...';
- status.textContent = 'Ready';
-
- example.addEventListener('click', (e) => {
- e.preventDefault();
- parse(EXAMPLE_URL, 'Describe this image.');
- });
-
- fileUpload.addEventListener('change', function (e) {
- const file = e.target.files[0];
- if (!file) {
- return;
- }
-
- const reader = new FileReader();

- // Set up a callback when the file is loaded
- reader.onload = e2 => parse(e2.target.result, '');
-
- reader.readAsDataURL(file);
- });

- async function parse(img, txt) {
- imageContainer.innerHTML = '';
- imageContainer.style.backgroundImage = `url(${img})`;

- status.textContent = 'Analysing...';
- const output = await imageTextToText(img, txt);
- status.textContent = '';
- output.forEach(renderBox);
  }

-
- export async function imageTextToText(
- imagePath,
- query,
- vision = true
- ) {
- const config = (await getModelJSON(BASE_MODEL, "config.json"))
-
  const prompt_head_len = new Tensor("int64", new BigInt64Array([5n]), [1]);
-
- let position_ids;
- let num_decode = 0;
  let history_len = new Tensor("int64", new BigInt64Array([0n]), [1]);
-
  let past_key_states = new ort.Tensor(
  "float16",
  new Uint16Array(
  config.num_hidden_layers *
- config.num_key_value_heads *
- MAX_SEQ_LENGTH *
- (config.hidden_size / config.num_attention_heads)
  ).fill(0),
  [
  config.num_hidden_layers,
@@ -84,17 +54,8 @@ export async function imageTextToText(
  config.hidden_size / config.num_attention_heads,
  ]
  );
-
  let past_value_states = past_key_states;

- let attention_mask = new ort.Tensor(
- "float16",
- new Uint16Array([0xfbff]), // -65504.0 in float16
- [1]
- );
-
- let pos_factor = new Tensor("float16", new Uint16Array([0]), [1]);
-
  const tokenizer = await AutoTokenizer.from_pretrained(BASE_MODEL);
  const prompt = `\n<|im_start|>user\n<|vision_start|><|vision_end|>${query}<|im_end|>\n<|im_start|>assistant\n`;
  const token = await tokenizer(prompt, {
@@ -103,18 +64,17 @@ export async function imageTextToText(
  tokenize: true,
  }).input_ids;

- const seq_length = token.dims[1];
- let ids_len = new Tensor("int64", new BigInt64Array([BigInt(seq_length)]), [
- 1,
- ]);
-
  let input_ids = new ort.Tensor(
  "int32",
  new Int32Array(MAX_SEQ_LENGTH).fill(0),
  [MAX_SEQ_LENGTH]
  );

- input_ids.data.set(Array.from(token.data.slice(0, seq_length), Number));

  if (vision) {
  let image = await RawImage.fromURL(imagePath);
@@ -122,51 +82,40 @@ export async function imageTextToText(
  image = image.rgb().toTensor("CHW").to("float32").div_(255.0);
  const pixel_values = image.unsqueeze(0);

- const ortSessionA = await ort.InferenceSession.create(
- await getModelFile(ONNX_MODEL, `onnx/QwenVL_A_${QUANT}.onnx`),
- { executionProviders: ["webgpu"] }
- );
-
  const { image_embed } = await ortSessionA.run({ pixel_values });
-
  ids_len = ids_len.add(BigInt(IMAGE_EMBED_SIZE));

  const ortSessionD = await ort.InferenceSession.create(
- `${BASE_URL}/QwenVL_D${suffix}.onnx`,
  { executionProviders: ["webgpu"] }
  );

- ({ hidden_states: past_key_states, position_ids } =
- await ortSessionD.run({
- "hidden_states.1": past_key_states,
- image_embed,
- ids_len,
- "ids_len_minus": new Tensor(
- "int32",
- new Int32Array([Number(ids_len.item()) - Number(prompt_head_len.item())]),
- [1]
- ),
- "split_factor": new Tensor(
- "int32",
- new Int32Array([
- MAX_SEQ_LENGTH - Number(ids_len.item()) - IMAGE_EMBED_SIZE,
- ]),
- [1]
- ),
- }));
- }

- const ortSessionB = await ort.InferenceSession.create(
- `${BASE_URL}/QwenVL_B${suffix}.onnx`,
- { executionProviders: ["webgpu"] }
- );

- while (
- num_decode < MAX_SINGLE_CHAT_LENGTH &&
- Number(history_len.data[0]) < MAX_SEQ_LENGTH
- ) {
  const ortSessionE = await ort.InferenceSession.create(
- `${BASE_URL}/QwenVL_E_q4f16.onnx`,
  { executionProviders: ["wasm"] }
  );
@@ -184,8 +133,9 @@ export async function imageTextToText(
  const token_id = result.max_logit_ids;
  if (token_id === 151643 || token_id === 151645) break;

  num_decode++;
-
  history_len = history_len.add(BigInt(1));
  pos_factor = new Tensor(
  "float16",
@@ -204,6 +154,8 @@ export async function imageTextToText(

  past_key_states = hidden_states;
  }
- }

+ import { env, AutoTokenizer, RawImage, Tensor } from '@huggingface/transformers';
+ import { getModelJSON, getModelFile } from "@huggingface/transformers/utils/hub.js";
+ import * as ort from "onnxruntime-web/webgpu";

  const INPUT_IMAGE_SIZE = [960, 960];
  const HEIGHT_FACTOR = 10;
  const WIDTH_FACTOR = 10;
  const IMAGE_EMBED_SIZE = WIDTH_FACTOR * HEIGHT_FACTOR;
  const MAX_SEQ_LENGTH = 1024;
  const BASE_MODEL = "Qwen/Qwen2-VL-2B-Instruct";
  const ONNX_MODEL = "pdufour/Qwen2-VL-2B-Instruct-ONNX-Q4-F16";
  const QUANT = "q4f16";
  const MAX_SINGLE_CHAT_LENGTH = 10;

+ let ortSessionA, ortSessionB, ortSessionC;

+ async function initializeSessions() {
+ ortSessionA = await ort.InferenceSession.create(
+ await getModelFile(ONNX_MODEL, `onnx/QwenVL_A_${QUANT}.onnx`),
+ { executionProviders: ["webgpu"] }
+ );

+ ortSessionB = await ort.InferenceSession.create(
+ await getModelFile(ONNX_MODEL, `onnx/QwenVL_B_${QUANT}.onnx`),
+ { executionProviders: ["webgpu"] }
+ );

+ ortSessionC = await ort.InferenceSession.create(
+ await getModelFile(ONNX_MODEL, `onnx/QwenVL_C_${QUANT}.onnx`),
+ { executionProviders: ["webgpu"] }
+ );
  }

+ export async function imageTextToText(imagePath, query, vision = true) {
+ const config = await getModelJSON(BASE_MODEL, "config.json");
+
  const prompt_head_len = new Tensor("int64", new BigInt64Array([5n]), [1]);
  let history_len = new Tensor("int64", new BigInt64Array([0n]), [1]);
+ let pos_factor = new Tensor("float16", new Uint16Array([0]), [1]);
+ let attention_mask = new ort.Tensor("float16", new Uint16Array([0xfbff]), [1]);
+
  let past_key_states = new ort.Tensor(
  "float16",
  new Uint16Array(
  config.num_hidden_layers *
+ config.num_key_value_heads *
+ MAX_SEQ_LENGTH *
+ (config.hidden_size / config.num_attention_heads)
  ).fill(0),
  [
  config.num_hidden_layers,

  config.hidden_size / config.num_attention_heads,
  ]
  );
  let past_value_states = past_key_states;

  const tokenizer = await AutoTokenizer.from_pretrained(BASE_MODEL);
  const prompt = `\n<|im_start|>user\n<|vision_start|><|vision_end|>${query}<|im_end|>\n<|im_start|>assistant\n`;
  const token = await tokenizer(prompt, {

  tokenize: true,
  }).input_ids;

+ let ids_len = new Tensor("int64", new BigInt64Array([BigInt(token.dims[1])]), [1]);
  let input_ids = new ort.Tensor(
  "int32",
  new Int32Array(MAX_SEQ_LENGTH).fill(0),
  [MAX_SEQ_LENGTH]
  );
+ input_ids.data.set(Array.from(token.data.slice(0, token.dims[1]), Number));

+ // Get position IDs from Session C
+ const dummy = new ort.Tensor("int32", new Int32Array([0]), []);
+ let { position_ids } = await ortSessionC.run({ dummy });

  if (vision) {
  let image = await RawImage.fromURL(imagePath);

  image = image.rgb().toTensor("CHW").to("float32").div_(255.0);
  const pixel_values = image.unsqueeze(0);

  const { image_embed } = await ortSessionA.run({ pixel_values });
  ids_len = ids_len.add(BigInt(IMAGE_EMBED_SIZE));

  const ortSessionD = await ort.InferenceSession.create(
+ await getModelFile(ONNX_MODEL, `onnx/QwenVL_D_${QUANT}.onnx`),
  { executionProviders: ["webgpu"] }
  );

+ const result = await ortSessionD.run({
+ "hidden_states.1": past_key_states,
+ image_embed,
+ ids_len,
+ "ids_len_minus": new Tensor(
+ "int32",
+ new Int32Array([Number(ids_len.item()) - Number(prompt_head_len.item())]),
+ [1]
+ ),
+ "split_factor": new Tensor(
+ "int32",
+ new Int32Array([MAX_SEQ_LENGTH - Number(ids_len.item()) - IMAGE_EMBED_SIZE]),
+ [1]
+ ),
+ });

+ past_key_states = result.hidden_states;
+ position_ids = result.position_ids;
+ }

+ let num_decode = 0;
+ let output = '';
+
+ while (num_decode < MAX_SINGLE_CHAT_LENGTH && Number(history_len.data[0]) < MAX_SEQ_LENGTH) {
  const ortSessionE = await ort.InferenceSession.create(
+ await getModelFile(ONNX_MODEL, `onnx/QwenVL_E_${QUANT}.onnx`),
  { executionProviders: ["wasm"] }
  );

  const token_id = result.max_logit_ids;
  if (token_id === 151643 || token_id === 151645) break;

+ output += tokenizer.decode([...token_id.data]);
+
  num_decode++;
  history_len = history_len.add(BigInt(1));
  pos_factor = new Tensor(
  "float16",

  past_key_states = hidden_states;
  }

+ return output;
+ }
+ await initializeSessions();
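
Note: this commit removes the DOM wiring (status text, example button, file-upload handler) from index.js and leaves only the exported imageTextToText plus the session setup. Below is a minimal usage sketch, not part of the commit, of how a page could drive the refactored module; it reuses the element IDs and example image URL from the removed code, and the './index.js' import path is an assumption for illustration.

// Usage sketch (hypothetical page, not part of this commit).
// Importing './index.js' runs `await initializeSessions()` at module load,
// so the ONNX sessions are downloaded before any inference call.
import { imageTextToText } from './index.js';

const EXAMPLE_URL = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg";

// Element IDs follow the markup the removed code expected.
const status = document.getElementById('status');
const fileUpload = document.getElementById('upload');
const example = document.getElementById('example');

async function describe(imageUrl) {
  status.textContent = 'Analysing...';
  // imageTextToText returns the text accumulated by the decode loop.
  status.textContent = await imageTextToText(imageUrl, 'Describe this image.');
}

example.addEventListener('click', (e) => {
  e.preventDefault();
  describe(EXAMPLE_URL);
});

fileUpload.addEventListener('change', (e) => {
  const file = e.target.files[0];
  if (!file) return;
  const reader = new FileReader();
  // Assumption: RawImage.fromURL (used inside imageTextToText) accepts the
  // data: URL produced by FileReader, as browser fetch does.
  reader.onload = (e2) => describe(e2.target.result);
  reader.readAsDataURL(file);
});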