import os
import numpy as np
import onnxruntime as ort
from PIL import Image
from transformers import AutoProcessor
# set current working directory to the directory of this file
os.chdir(os.path.dirname(os.path.abspath(__file__)))
# embeddings
vision_encoder = ort.InferenceSession("vision_encoder.onnx", providers=['CPUExecutionProvider'])
text_embed = ort.InferenceSession("embed_tokens.onnx", providers=['CPUExecutionProvider'])
# encoder
encoder = ort.InferenceSession("encoder_model.onnx", providers=['CPUExecutionProvider'])
# decoder
decoder_prefill = ort.InferenceSession("decoder_model.onnx", providers=['CPUExecutionProvider'])
decoder_decode = ort.InferenceSession("decoder_model_merged_q4.onnx", providers=['CPUExecutionProvider'])
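# decoder_model.onnx handles the cache-less prefill pass; the merged q4 model
# then runs the autoregressive decode steps with past key/values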
# 1. load the processor
processor = AutoProcessor.from_pretrained("/home/zt/rk3588-nn/expr/Florence-2-base-ft", trust_remote_code=True)
# 2. prepare image
image = Image.open("./lena.png")
# resize image to 512x512
image = image.resize((512, 512))
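# the processor call below passes do_resize=False, so the image has to be
# brought to the expected input resolution (512x512 here) manually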
# 3. prepare text
prompt = "<MORE_DETAILED_CAPTION>"
inputs = processor(text=prompt, images=image, return_tensors="np", do_resize=False)
for k, v in inputs.items():
    print(k, v.shape)
# 4. run vision encoder
image_features = vision_encoder.run(None, {
    "pixel_values": inputs["pixel_values"]
})
for output in image_features:
    print(output.shape)
image_features = image_features[0]
np.save("image_features.npy", image_features)
# 5. run text embed
inputs_embeds = text_embed.run(None, {
    "input_ids": inputs["input_ids"]
})
for output in inputs_embeds:
    print(output.shape)
inputs_embeds = inputs_embeds[0]
# 6. concat image features and text embed
batch_size, image_token_length = image_features.shape[:-1]
image_attention_mask = np.ones((batch_size, image_token_length))
task_prefix_embeds = inputs_embeds
task_prefix_attention_mask = np.ones((batch_size, task_prefix_embeds.shape[1]))
if task_prefix_attention_mask.ndim == 3:
    task_prefix_attention_mask = task_prefix_attention_mask[:, 0]
inputs_embeds = np.concatenate([image_features, task_prefix_embeds], axis=1)
attention_mask = np.concatenate([image_attention_mask, task_prefix_attention_mask], axis=1)
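# the fused sequence is [image tokens | prompt tokens], all attended to (mask of ones)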
# 7. run encoder
encoder_out = encoder.run(None, {
    "inputs_embeds": inputs_embeds,
    "attention_mask": attention_mask.astype(np.int64)
})
for output in encoder_out:
    print(output.shape)
encoder_hidden_states = encoder_out[0]
# 8. run decoder prefill stage
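# note: only the last prompt embedding is fed here; for this BART-style decoder
# the trailing </s> (id 2) presumably doubles as the decoder start token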
decoder_outs = decoder_prefill.run(None, {
    "inputs_embeds": inputs_embeds[:, -1:],
    "encoder_hidden_states": encoder_hidden_states,
    "encoder_attention_mask": attention_mask.astype(np.int64)
})
for output in decoder_outs:
    print(output.shape)
encoder_kv = decoder_outs[1:]
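# decoder_outs[1:] holds the KV cache in groups of four per layer:
# (decoder.key, decoder.value, encoder.key, encoder.value); the encoder
# cross-attention entries are fixed after prefill, so they are reused each step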
# 9. run decoder decode stage (autoregressive)
generated_tokens = []
max_new_tokens = 32
while len(generated_tokens) < max_new_tokens:
    # take the previous step's outputs
    logits = decoder_outs[0]
    decoder_kv = decoder_outs[1:]
    # select the logits of the last token
    next_token_logits = logits[:, -1, :]
    # pick the next token with argmax (greedy decoding)
    next_token = np.argmax(next_token_logits, axis=-1)[0]
    print("next_token: ", next_token)
    # append the newly generated token to the result
    generated_tokens.append(next_token)
    # stop once the end-of-sequence token is generated
    if next_token == 2:  # </s>
        break
    # prepare the input for the next step
    next_input_embeds = text_embed.run(None, {
        "input_ids": np.array([[next_token]], dtype=np.int64)
    })[0]
    # run the decode stage of the merged decoder with the cached key/values;
    # decoder entries come from the previous step, encoder entries from prefill
    decode_inputs = {
        "use_cache_branch": np.array([True], dtype=np.bool_),
        "inputs_embeds": next_input_embeds,
        "encoder_hidden_states": encoder_hidden_states,
        "encoder_attention_mask": attention_mask.astype(np.int64),
    }
    for layer in range(6):
        base = 4 * layer
        decode_inputs[f"past_key_values.{layer}.decoder.key"] = decoder_kv[base]
        decode_inputs[f"past_key_values.{layer}.decoder.value"] = decoder_kv[base + 1]
        decode_inputs[f"past_key_values.{layer}.encoder.key"] = encoder_kv[base + 2]
        decode_inputs[f"past_key_values.{layer}.encoder.value"] = encoder_kv[base + 3]
    decoder_outs = decoder_decode.run(None, decode_inputs)
    for output in decoder_outs:
        print(output.shape)
    # print("generated_token: ", processor.decode(next_token, skip_special_tokens=False))
# drop the first token
# generated_tokens = generated_tokens[1:]
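# note: this loop decodes greedily; the upstream Florence-2 example generates
# with beam search (num_beams=3), so captions may differ from the PyTorch pipeline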
# convert the generated tokens to text
print("generated_tokens: ", generated_tokens)
generated_text = processor.batch_decode([generated_tokens], skip_special_tokens=False)[0]
print("Generated Text:", generated_text)
parsed_answer = processor.post_process_generation(generated_text, task=prompt, image_size=(image.width, image.height))
print("Parsed Answer:", parsed_answer) |