|
import os |
|
import sys |
|
import time |
|
import torch |
|
import numpy as np |
|
import requests |
|
import onnxruntime as ort |
|
from PIL import Image |
|
from io import BytesIO |
|
from transformers import Qwen2VLConfig, AutoTokenizer |
|
|
|
|
|
# --- Command-line arguments and model configuration --------------------------
# Usage: python <script> <model_path> <onnx_path>
#   model_path: handed to `from_pretrained` for both the config and the tokenizer.
#   onnx_path:  directory that is joined with the ONNX file names of the sub-models.
if len(sys.argv) < 3:
    # Exit with a readable usage message instead of a bare IndexError.
    sys.exit(f"Usage: {sys.argv[0]} <model_path> <onnx_path>")

model_path = sys.argv[1]
onnx_path = sys.argv[2]

model_config = Qwen2VLConfig.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Length of the fixed token buffer and of the sequence axis of the KV caches.
max_length = 1024
num_attention_heads = model_config.num_attention_heads
num_key_value_heads = model_config.num_key_value_heads
# Per-head width: hidden size split evenly across the attention heads.
head_dim = model_config.hidden_size // num_attention_heads
num_layers = model_config.num_hidden_layers
|
|
|
|
|
# One ONNX Runtime session per exported sub-model, all sharing the same options
# with the highest graph-optimization level enabled.
session_options = ort.SessionOptions()
session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

models = ['A', 'B', 'C', 'D', 'E']
model_paths = {
    name: os.path.join(onnx_path, f'QwenVL_{name}_q4f16.onnx')
    for name in models
}
sessions = {
    name: ort.InferenceSession(onnx_file, sess_options=session_options)
    for name, onnx_file in model_paths.items()
}

# Input tensor names per model. Start from the full name list of every session,
# then narrow: A and C keep only their first input name (as a plain string),
# B keeps exactly its first two, D and E keep the complete list.
inputs = {
    name: [node.name for node in session.get_inputs()]
    for name, session in sessions.items()
}
inputs['A'] = inputs['A'][0]
inputs['B'] = [inputs['B'][0], inputs['B'][1]]
inputs['C'] = inputs['C'][0]

# Output tensor names per model. A, B and C keep only their first output name
# (as a plain string); D and E keep the complete list.
outputs = {
    name: [node.name for node in session.get_outputs()]
    for name, session in sessions.items()
}
outputs['A'] = outputs['A'][0]
outputs['B'] = outputs['B'][0]
outputs['C'] = outputs['C'][0]
|
|
|
|
|
# Download the demo image. The timeout keeps the script from hanging forever on
# a stalled connection, and raise_for_status() reports HTTP failures directly
# instead of letting PIL choke on an error page.
image_url = 'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg'
response = requests.get(image_url, timeout=30)
response.raise_for_status()

# Resize to 960x960 and force three RGB channels.
image = Image.open(BytesIO(response.content)).resize((960, 960)).convert('RGB')

# (H, W, C) pixels -> float32, reordered to (C, H, W), given a leading axis of
# size 1, and scaled from [0, 255] down to [0, 1].
image_array = np.expand_dims(np.transpose(np.array(image).astype(np.float32), (2, 0, 1)), axis=0) / 255.
|
|
|
|
|
# Wrap the user prompt in the chat markup; the <|vision_start|><|vision_end|>
# pair is emitted with nothing between the two markers.
prompt = "Describe this image."
formatted_prompt = f"\n<|im_start|>user\n<|vision_start|><|vision_end|>{prompt}<|im_end|>\n<|im_start|>assistant\n"

# Tokenize, then copy the ids into the front of a zero-filled buffer of
# max_length entries; input_lengths records how many entries are real.
input_ids = tokenizer(formatted_prompt, return_tensors='pt')['input_ids']
input_lengths = np.array([input_ids.shape[1]], dtype=np.int64)
tokens = np.zeros((max_length,), dtype=np.int32)
tokens[:input_lengths[0]] = input_ids[0]
position = np.array([0], dtype=np.int64)

# Zero-initialised float16 key/value caches with identical shape:
# (num_layers, num_key_value_heads, max_length, head_dim).
key_cache = np.zeros(
    (num_layers, num_key_value_heads, max_length, head_dim),
    dtype=np.float16,
)
value_cache = np.zeros_like(key_cache)
|
|
|
|
|
# Model B: token buffer + valid length -> hidden states for the prompt.
hidden_states, = sessions['B'].run(
    [outputs['B']],
    {
        inputs['B'][0]: tokens,
        inputs['B'][1]: input_lengths,
    },
)

# Model C is fed an int32 zero scalar; its single output initialises
# `batch_size`, which is replaced by model D's second output further down.
# NOTE(review): the name `batch_size` may be misleading — confirm what models
# C/D actually return against the export script.
batch_size = sessions['C'].run(
    [outputs['C']],
    {inputs['C']: np.array(0, dtype=np.int32)},
)[0]

# Model A: preprocessed image -> image features.
image_features, = sessions['A'].run([outputs['A']], {inputs['A']: image_array})

# presumably the number of image-embedding positions added to the sequence —
# TODO confirm against the exported model.
total_ids = 100
input_lengths += total_ids
# NOTE(review): total_ids is subtracted here although input_lengths already
# includes it — confirm this is what model D expects.
remaining_tokens = np.array(max_length - input_lengths[0] - total_ids, dtype=np.int32)
tokens_to_stop = np.array(input_lengths[0] - 5, dtype=np.int32)

# Model D: values are bound to its input names strictly in declaration order;
# it returns the updated hidden states and the new `batch_size`.
hidden_states, batch_size = sessions['D'].run(
    outputs['D'],
    {
        name: value
        for name, value in zip(
            inputs['D'],
            (hidden_states, image_features, input_lengths, tokens_to_stop, remaining_tokens),
        )
    },
)
|
|
|
|
|
# Decode loop: at most 12 steps. Each step runs model E to obtain the next
# token and the updated KV caches, then re-embeds that token with model B.
# perf_counter() is used for the elapsed time because it is monotonic and not
# affected by system clock adjustments.
start_time = time.perf_counter()
for i in range(12):
    first_step = i == 0
    token, key_cache, value_cache = sessions['E'].run(
        outputs['E'],
        # Values are bound to model E's input names strictly in declaration order.
        dict(zip(inputs['E'], [
            hidden_states,
            # -65504 is the most negative finite float16; presumably an
            # attention-mask fill value used only on the first step — confirm.
            np.array([-65504. if first_step else 0.], dtype=np.float16),
            key_cache,
            value_cache,
            position,
            input_lengths,
            batch_size,
            np.array([1 - total_ids + 10 if first_step else position[0] + 1], dtype=np.float16),
        ])),
    )

    # Stop on either end token.
    # presumably Qwen's <|endoftext|> / <|im_end|> ids — verify against the tokenizer.
    if token in [151643, 151645]:
        break

    if first_step:
        # Advance past the whole prefill first, then switch to one token per step.
        # Order matters: position must be bumped before input_lengths is reset.
        position += input_lengths[0]
        input_lengths[0] = 1
    else:
        position += 1

    # Only slot 0 of the token buffer is read from now on (input_lengths == 1).
    tokens[0] = token
    hidden_states = sessions['B'].run(
        [outputs['B']],
        {inputs['B'][0]: tokens, inputs['B'][1]: input_lengths},
    )[0]
    print(tokenizer.decode(token), end='', flush=True)

print(f"\nTotal time: {time.perf_counter() - start_time:.2f}s")
|
|