File size: 5,828 Bytes
95dfa6c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
# Florence-2 inference pipeline driven entirely through exported ONNX graphs.
# The model is split into five sessions that are chained manually below:
# vision encoder -> token embedding -> text encoder -> decoder (prefill) -> decoder (decode).
import onnxruntime as ort
from transformers import AutoProcessor
from PIL import Image
import numpy as np
# set current working directory to the directory of this file
# (so the relative .onnx paths below resolve regardless of the caller's cwd)
import os
os.chdir(os.path.dirname(os.path.abspath(__file__)))

# embeddings
vision_encoder = ort.InferenceSession("vision_encoder.onnx", providers=['CPUExecutionProvider'])
text_embed = ort.InferenceSession("embed_tokens.onnx", providers=['CPUExecutionProvider'])
# encoder
encoder = ort.InferenceSession("encoder_model.onnx", providers=['CPUExecutionProvider'])
# decoder
# prefill graph (no KV-cache inputs) vs. merged q4-quantized graph used for
# the autoregressive decode steps with cached key/values.
decoder_prefill = ort.InferenceSession("decoder_model.onnx", providers=['CPUExecutionProvider'])
decoder_decode = ort.InferenceSession("decoder_model_merged_q4.onnx", providers=['CPUExecutionProvider'])

# 1. prepare inputs
# NOTE(review): hard-coded local checkpoint path — adjust for other machines.
processor = AutoProcessor.from_pretrained("/home/zt/rk3588-nn/expr/Florence-2-base-ft", trust_remote_code=True)

# 2. prepare image
image = Image.open("./lena.png")

# resize image to 512x512
# (resized here manually, so the processor is told not to resize again below)
image = image.resize((512, 512))
# 3. prepare text
# Florence-2 task token; selects the captioning task for post-processing too.
prompt = "<MORE_DETAILED_CAPTION>"
inputs = processor(text=prompt, images=image, return_tensors="np", do_resize=False)
for k, v in inputs.items():
    print(k, v.shape)

# 4. run vision encoder
# Produces image patch features; first output is (batch, image_tokens, hidden).
image_features = vision_encoder.run(None, {
    "pixel_values": inputs["pixel_values"]
})
for output in image_features:
    print(output.shape)
image_features = image_features[0]
np.save("image_features.npy", image_features)  # saved for offline inspection/debugging
# 5. run text embed
# Maps the prompt token ids to embedding vectors: (batch, text_tokens, hidden).
inputs_embeds = text_embed.run(None, {
    "input_ids": inputs["input_ids"]
})
for output in inputs_embeds:
    print(output.shape)

inputs_embeds = inputs_embeds[0]

# 6. concat image features and text embed
# Image tokens are prepended to the text tokens, mirroring Florence-2's
# _merge_input_ids_with_image_features; the attention mask covers both parts.
batch_size, image_token_length = image_features.shape[:-1]
image_attention_mask = np.ones((batch_size, image_token_length))
task_prefix_embeds = inputs_embeds
task_prefix_attention_mask = np.ones((batch_size, task_prefix_embeds.shape[1]))
# defensive squeeze: np.ones above is 2-D, so this branch is normally dead
if len(task_prefix_attention_mask.shape) == 3:
    task_prefix_attention_mask = task_prefix_attention_mask[:, 0]
inputs_embeds = np.concatenate([image_features, task_prefix_embeds], axis=1)
attention_mask = np.concatenate([image_attention_mask, task_prefix_attention_mask], axis=1)

# 6. run encoder
# Full-sequence (image + text) pass through the text encoder; the mask must be
# int64 because that is the dtype the exported ONNX graph declares.
encoder_out = encoder.run(None, {
    "inputs_embeds": inputs_embeds,
    "attention_mask": attention_mask.astype(np.int64)
})
for output in encoder_out:
    print(output.shape)

encoder_hidden_states = encoder_out[0]

# 7. run decoder prefill stage
# Only the last embedding is fed as the decoder start input; the prefill graph
# has no KV-cache inputs and emits the initial present key/values.
decoder_outs = decoder_prefill.run(None, {
    "inputs_embeds": inputs_embeds[:, -1:],
    "encoder_hidden_states": encoder_hidden_states,
    "encoder_attention_mask": attention_mask.astype(np.int64)
})
for output in decoder_outs:
    print(output.shape)

# Outputs after the logits are the present key/values, interleaved per layer as
# (decoder.key, decoder.value, encoder.key, encoder.value). The decode loop
# below reads only the encoder (cross-attention) entries from this snapshot,
# since those stay constant across generation steps.
encoder_kv = decoder_outs[1:]

# 8. run decoder decode stage (autoregressive)
# Greedy decoding: repeatedly pick the argmax token, feed its embedding back in
# together with the cached key/values until EOS or the token budget is hit.
generated_tokens = []
max_new_tokens = 32
EOS_TOKEN_ID = 2  # </s> in the Florence-2 (BART-style) vocabulary
while len(generated_tokens) < max_new_tokens:
    # outputs from the previous step (prefill on the first iteration)
    logits = decoder_outs[0]
    decoder_kv = decoder_outs[1:]

    # logits of the last position in the sequence
    next_token_logits = logits[:, -1, :]

    # greedy selection of the next token
    next_token = np.argmax(next_token_logits, axis=-1)[0]
    print("next_token: ", next_token)
    # record the newly generated token (EOS is appended before the break,
    # matching the original behavior)
    generated_tokens.append(next_token)

    # stop once the end-of-sequence token is produced
    if next_token == EOS_TOKEN_ID:
        break

    # embed the new token to build the next decoder input
    next_input_embeds = text_embed.run(None, {
        "input_ids": np.array([[next_token]], dtype=np.int64)
    })[0]

    # run one decode step with the KV cache.
    # Self-attention (decoder.*) caches come from the previous step; the
    # cross-attention (encoder.*) caches are reused from the prefill snapshot.
    # Indices follow the per-layer layout (dec.key, dec.value, enc.key, enc.value).
    decoder_outs = decoder_decode.run(None, {
        "use_cache_branch": np.array([True], dtype=np.bool_),
        "inputs_embeds": next_input_embeds,
        "encoder_hidden_states": encoder_hidden_states,
        "encoder_attention_mask": attention_mask.astype(np.int64),
        "past_key_values.0.decoder.key": decoder_kv[0],
        "past_key_values.0.decoder.value": decoder_kv[1],
        "past_key_values.0.encoder.key": encoder_kv[2],
        "past_key_values.0.encoder.value": encoder_kv[3],
        "past_key_values.1.decoder.key": decoder_kv[4],
        "past_key_values.1.decoder.value": decoder_kv[5],
        "past_key_values.1.encoder.key": encoder_kv[6],
        "past_key_values.1.encoder.value": encoder_kv[7],
        "past_key_values.2.decoder.key": decoder_kv[8],
        "past_key_values.2.decoder.value": decoder_kv[9],
        "past_key_values.2.encoder.key": encoder_kv[10],
        "past_key_values.2.encoder.value": encoder_kv[11],
        "past_key_values.3.decoder.key": decoder_kv[12],
        "past_key_values.3.decoder.value": decoder_kv[13],
        "past_key_values.3.encoder.key": encoder_kv[14],
        "past_key_values.3.encoder.value": encoder_kv[15],
        "past_key_values.4.decoder.key": decoder_kv[16],
        "past_key_values.4.decoder.value": decoder_kv[17],
        "past_key_values.4.encoder.key": encoder_kv[18],
        "past_key_values.4.encoder.value": encoder_kv[19],
        "past_key_values.5.decoder.key": decoder_kv[20],
        "past_key_values.5.decoder.value": decoder_kv[21],
        "past_key_values.5.encoder.key": encoder_kv[22],
        "past_key_values.5.encoder.value": encoder_kv[23],
    })
    for output in decoder_outs:
        print(output.shape)

    # print("generated_token: ", processor.decode(next_token, skip_special_tokens=False))

# drop the first token
# generated_tokens = generated_tokens[1:]
# decode the generated token ids back to text
print("generated_tokens: ", generated_tokens)
generated_text = processor.batch_decode([generated_tokens], skip_special_tokens=False)[0]
print("Generated Text:", generated_text)
# task-aware post-processing (strips special tokens, parses task output)
parsed_answer = processor.post_process_generation(generated_text, task=prompt, image_size=(image.width, image.height))
print("Parsed Answer:", parsed_answer)