pdufour committed
Commit abaea80 · verified · 1 Parent(s): 5c17b02

Update index.js

Files changed (1)
  1. index.js +154 -67
index.js CHANGED
@@ -68,21 +68,31 @@ async function parse(img, txt) {
   status.textContent = output;
 }
 
-async function imageTextToText(imagePath, query, vision = true) {
-  const config = await getModelJSON(BASE_MODEL, "config.json");
-
+
+export async function imageTextToText(
+  imagePath,
+  query,
+  vision = true
+) {
+  let ortSessionA, ortSessionB, ortSessionC, ortSessionD, ortSessionE;
+
   const prompt_head_len = new Tensor("int64", new BigInt64Array([5n]), [1]);
+  logger.tensor("prompt_head_len", prompt_head_len);
+
+  let position_ids;
+  let num_decode = 0;
   let history_len = new Tensor("int64", new BigInt64Array([0n]), [1]);
-  let pos_factor = new Tensor("float16", new Uint16Array([0]), [1]);
-  let attention_mask = new ort.Tensor("float16", new Uint16Array([0xfbff]), [1]);
-
+  logger.tensor("history_len", history_len);
+
+  var pos_factor_v = BigInt(1 - IMAGE_EMBED_SIZE + WIDTH_FACTOR);
+
   let past_key_states = new ort.Tensor(
     "float16",
     new Uint16Array(
       config.num_hidden_layers *
-      config.num_key_value_heads *
-      MAX_SEQ_LENGTH *
-      (config.hidden_size / config.num_attention_heads)
+        config.num_key_value_heads *
+        MAX_SEQ_LENGTH *
+        (config.hidden_size / config.num_attention_heads)
     ).fill(0),
     [
       config.num_hidden_layers,
@@ -91,8 +101,19 @@ async function imageTextToText(imagePath, query, vision = true) {
       config.hidden_size / config.num_attention_heads,
     ]
   );
+
   let past_value_states = past_key_states;
 
+  let attention_mask = new ort.Tensor(
+    "float16",
+    new Uint16Array([0xfbff]),
+    [1]
+  );
+
+  let pos_factor = new Tensor("float16", new Uint16Array([0]), [1]);
+  logger.tensor("pos_factor", pos_factor);
+
+  logger.groupCollapsed("[TOKENIZATION] Processing prompt...");
   const tokenizer = await AutoTokenizer.from_pretrained(BASE_MODEL);
   const prompt = `\n<|im_start|>user\n<|vision_start|><|vision_end|>${query}<|im_end|>\n<|im_start|>assistant\n`;
   const token = await tokenizer(prompt, {
@@ -101,72 +122,112 @@ async function imageTextToText(imagePath, query, vision = true) {
     tokenize: true,
   }).input_ids;
 
-  let ids_len = new Tensor("int64", new BigInt64Array([BigInt(token.dims[1])]), [1]);
+  const seq_length = token.dims[1];
+  let ids_len = new Tensor("int64", new BigInt64Array([BigInt(seq_length)]), [
+    1,
+  ]);
+
   let input_ids = new ort.Tensor(
     "int32",
     new Int32Array(MAX_SEQ_LENGTH).fill(0),
     [MAX_SEQ_LENGTH]
   );
-  input_ids.data.set(Array.from(token.data.slice(0, token.dims[1]), Number));
 
-
+  input_ids.data.set(Array.from(token.data.slice(0, seq_length), Number));
+
+  const dummy = new ort.Tensor("int32", new Int32Array([0]), []);
+
+  if (!ortSessionB) {
+  }
   let { hidden_states } = await ortSessionB.run({
     input_ids: input_ids,
     ids_len: ids_len,
   });
 
-  const dummy = new ort.Tensor("int32", new Int32Array([0]), []);
-  let { position_ids } = await ortSessionC.run({ dummy });
+  ({ position_ids } = await ortSessionC.run({
+    dummy: dummy,
+  }));
 
+  // Process image
   if (vision) {
     let image = await RawImage.fromURL(imagePath);
+
     image = await image.resize(INPUT_IMAGE_SIZE[0], INPUT_IMAGE_SIZE[1]);
-    image = image.rgb().toTensor("CHW").to("float32").div_(255.0);
+
+    image = image.rgb();
+
+    image = image.toTensor("CHW");
+    image = image.to("float32");
+    image = image.div_(255.0);
     const pixel_values = image.unsqueeze(0);
 
-    console.log('run session a');
-    const { image_embed } = await ortSessionA.run({ pixel_values });
-    console.log('finished session a');
+    const { image_embed } = await ortSessionA.run({
+      pixel_values: pixel_values,
+    });
+
     ids_len = ids_len.add(BigInt(IMAGE_EMBED_SIZE));
 
-    const ortSessionD = await ort.InferenceSession.create(
+    const split_factor = new Tensor(
+      "int32",
+      new Int32Array([
+        MAX_SEQ_LENGTH - Number(ids_len.item()) - IMAGE_EMBED_SIZE,
+      ]),
+      [1]
+    );
+
+    const ids_len_minus = new Tensor(
+      "int32",
+      new Int32Array([Number(ids_len.item()) - Number(prompt_head_len.item())]),
+      [1]
+    );
+
+    await ortSessionA.release();
+    ortSessionA = null;
+
+    logger.log("session d create");
+    ortSessionD = await ort.InferenceSession.create(
       await getModelFile(ONNX_MODEL, `onnx/QwenVL_D_${QUANT}.onnx`),
-      { executionProviders: ["webgpu"] }
+      {
+        executionProviders: ["webgpu"],
+      }
     );
 
-    console.log('run session d');
-    const result = await ortSessionD.run({
-      "hidden_states.1": hidden_states,
+    ({ hidden_states, position_ids } = await ortSessionD.run({
+      "hidden_states.1": hidden_states,
       image_embed,
      ids_len,
-      "ids_len_minus": new Tensor(
-        "int32",
-        new Int32Array([Number(ids_len.item()) - Number(prompt_head_len.item())]),
-        [1]
-      ),
-      "split_factor": new Tensor(
-        "int32",
-        new Int32Array([MAX_SEQ_LENGTH - Number(ids_len.item()) - IMAGE_EMBED_SIZE]),
-        [1]
-      ),
-    });
-    console.log('finished session d');
+      ids_len_minus,
+      split_factor,
+    }));
 
-    past_key_states = result.hidden_states;
-    position_ids = result.position_ids;
+    await ortSessionD.release();
+    ortSessionD = null;
   }
 
-  let num_decode = 0;
   let output = '';
-
-  while (num_decode < MAX_SINGLE_CHAT_LENGTH && Number(history_len.data[0]) < MAX_SEQ_LENGTH) {
-    const ortSessionE = await ort.InferenceSession.create(
-      await getModelFile(ONNX_MODEL, `onnx/QwenVL_E_${QUANT}.onnx`),
-      { executionProviders: ["wasm"] }
-    );
 
-    const result = await ortSessionE.run({
-      hidden_states: past_key_states,
+  while (
+    num_decode < MAX_SINGLE_CHAT_LENGTH &&
+    Number(history_len.data[0]) < MAX_SEQ_LENGTH
+  ) {
+    let token_id;
+
+    if (!ortSessionE) {
+      console.log("Create ortSessionE");
+      ortSessionE = await ort.InferenceSession.create(
+        await getModelFile(ONNX_MODEL, `onnx/QwenVL_E_${QUANT}.onnx`),
+        {
+          executionProviders: ["wasm"],
+        },
+      );
+    }
+
+    ({
+      max_logit_ids: token_id,
+      past_key_states: past_key_states,
+      past_value_states: past_value_states,
+    } = await ortSessionE.run({
+      hidden_states,
       attention_mask,
      "past_key_states.1": past_key_states,
      "past_value_states.1": past_value_states,
@@ -174,35 +235,61 @@ async function imageTextToText(imagePath, query, vision = true) {
      ids_len,
      position_ids,
      pos_factor,
-    });
-    console.log('finished session e');
+    }));
 
-    const token_id = result.max_logit_ids;
-    if (token_id === 151643 || token_id === 151645) break;
+    if (token_id === 151643 || token_id === 151645) {
+      break;
+    }
 
-    output += tokenizer.decode([...token_id.data]);
-
     num_decode++;
-    history_len = history_len.add(BigInt(1));
-    pos_factor = new Tensor(
-      "float16",
-      new Uint16Array([Number(pos_factor.data[0]) + 1]),
-      [1]
-    );
+    if (num_decode < 2) {
+      history_len = history_len.add(BigInt(ids_len.data[0]));
 
-    past_key_states = result.past_key_states;
-    past_value_states = result.past_value_states;
+      ids_len = new ort.Tensor("int64", new BigInt64Array([1n]), [1]);
 
-    input_ids.data[0] = Number(token_id.data[0]);
-    const { hidden_states } = await ortSessionB.run({
-      input_ids,
-      ids_len,
+      attention_mask = new ort.Tensor("float16", new Uint16Array([0]), [1]);
+
+      if (vision) {
+        pos_factor = new Tensor(
+          "float16",
+          new Uint16Array([int64ToFloat16(pos_factor_v + ids_len.data[0])]),
+          [1]
+        );
+      } else {
+        pos_factor = new Tensor(
+          "float16",
+          new Uint16Array([int64ToFloat16(history_len.data[0] + BigInt(1))]),
+          [1]
+        );
+      }
+
+    } else {
+      history_len = history_len.add(BigInt(1));
+      pos_factor = pos_factor.map((v) =>
+        int64ToFloat16(float16ToInt64(v) + BigInt(1))
+      );
+      logger.tensor("Updated history_len", history_len);
+      logger.tensor("Updated pos_factor", pos_factor);
+      logger.groupEnd();
+    }
+    (input_ids.data)[0] = Number(token_id.data[0]);
+
+    const result_B = await ortSessionB.run({
+      input_ids: input_ids,
+      ids_len: ids_len,
    });
+    hidden_states = result_B.hidden_states;
 
-    past_key_states = hidden_states;
+    if (
+      !Number.isInteger(token_id.data[0]) &&
+      !["bigint", "number"].includes(typeof token_id.data[0])
+    ) {
+      throw new Error(`Token ID is not an integer`);
+    } else {
+      const decoded = tokenizer.decode([...token_id.data])
+      output += decoded;
+    }
  }
-
-  return output;
 }
 
  await initializeSessions();
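
Note on the half-precision handling in this change: the ONNX decoder graphs take "float16" tensors, which onnxruntime-web backs with Uint16Array buffers of raw IEEE 754 half-precision bits. The 0xfbff used for the initial attention_mask is the bit pattern of -65504, the most negative finite float16, and pos_factor is kept in the same raw-bits form. The helpers int64ToFloat16 and float16ToInt64 are called in the decode loop but not defined in this hunk; the sketch below is only an illustration of what such helpers could look like, assuming plain float32-to-float16 truncation, and is not the commit's actual implementation.

// Hypothetical sketch only - the real helpers live elsewhere in index.js.
function int64ToFloat16(value) {
  // Reinterpret the number as float32 bits, then narrow to float16 bits.
  const f32 = new Float32Array(1);
  const u32 = new Uint32Array(f32.buffer);
  f32[0] = Number(value);
  const bits = u32[0];
  const sign = (bits >>> 16) & 0x8000;
  const exp = ((bits >>> 23) & 0xff) - 127 + 15;
  const frac = (bits >>> 13) & 0x3ff;
  if (exp <= 0) return sign;            // underflow -> signed zero
  if (exp >= 31) return sign | 0x7bff;  // overflow -> clamp to +/-65504
  return sign | (exp << 10) | frac;     // truncate the mantissa
}

function float16ToInt64(bits) {
  // Decode raw float16 bits; the values used in the loop are small whole numbers.
  const sign = bits & 0x8000 ? -1 : 1;
  const exp = (bits >>> 10) & 0x1f;
  const frac = bits & 0x3ff;
  const value =
    exp === 0 ? (frac / 1024) * 2 ** -14 : (1 + frac / 1024) * 2 ** (exp - 15);
  return BigInt(Math.round(sign * value));
}

Since imageTextToText is now exported, another module could import it and call, for example, await imageTextToText("./demo.jpg", "Describe this image.") after initializeSessions() has run; the image path and prompt here are placeholders, not values from the commit.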