from rknnlite.api.rknn_lite import RKNNLite
from transformers import AutoProcessor
from PIL import Image
import numpy as np
import onnxruntime as ort
import time
import os

# Run from the script's directory so the relative model paths below resolve.
os.chdir(os.path.dirname(os.path.abspath(__file__)))

total_time = 0

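# The pipeline is split across five runtimes: the vision encoder, the text
# encoder, and the decoder prefill run on the NPU via RKNNLite; the token
# embedding and the KV-cached decode step run on the CPU via onnxruntime.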
rknn_vision_encoder = RKNNLite(verbose=False)
rknn_encoder = RKNNLite(verbose=False)
rknn_decoder_prefill = RKNNLite(verbose=False)

# load_rknn() and init_runtime() return 0 on success; fail fast otherwise.
ret = rknn_vision_encoder.load_rknn('./vision_encoder.rknn')
assert ret == 0, 'failed to load vision_encoder.rknn'
ret = rknn_encoder.load_rknn('./encoder_model.rknn')
assert ret == 0, 'failed to load encoder_model.rknn'
ret = rknn_decoder_prefill.load_rknn('./decoder_model.rknn')
assert ret == 0, 'failed to load decoder_model.rknn'

ret = rknn_vision_encoder.init_runtime()
assert ret == 0, 'failed to init vision encoder runtime'
ret = rknn_encoder.init_runtime()
assert ret == 0, 'failed to init encoder runtime'
ret = rknn_decoder_prefill.init_runtime()
assert ret == 0, 'failed to init decoder prefill runtime'

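# CPU-side ONNX sessions. The _q4 suffix on the merged decoder suggests a
# 4-bit weight-quantized export.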
text_embed = ort.InferenceSession("embed_tokens.onnx", providers=['CPUExecutionProvider'])
decoder_decode = ort.InferenceSession("decoder_model_merged_q4.onnx", providers=['CPUExecutionProvider'])

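# The Hugging Face processor handles tokenization, image preprocessing, and
# Florence-2's task-specific output parsing.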
processor = AutoProcessor.from_pretrained("/home/firefly/mnt/zt-rk3588-nn/expr/Florence-2-base-ft", trust_remote_code=True)

image = Image.open("./lena.png")

# Resize up front and pass do_resize=False: the RKNN vision encoder was
# presumably converted with a fixed 512x512 input shape, so the processor
# must not resize again.
image = image.resize((512, 512))

# Florence-2 task prompt: request a detailed caption.
prompt = "<MORE_DETAILED_CAPTION>"
inputs = processor(text=prompt, images=image, return_tensors="np", do_resize=False)
for k, v in inputs.items():
    print(k, v.shape)

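# Vision encoder (NPU): pixel values -> image token embeddings.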
start_time = time.time()
image_features = rknn_vision_encoder.inference(inputs=[inputs["pixel_values"]])[0]
end_time = time.time()
vision_encoder_time = (end_time - start_time) * 1000
total_time += vision_encoder_time
print(f"Vision encoder time: {vision_encoder_time:.2f} ms")
print(image_features.shape)
np.save("image_features.npy", image_features)  # dump for offline inspection

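# Embed the prompt token ids on the CPU.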
start_time = time.time()
inputs_embeds = text_embed.run(None, {
    "input_ids": inputs["input_ids"]
})[0]
end_time = time.time()
text_embed_time = (end_time - start_time) * 1000
total_time += text_embed_time
print(f"Text embed time: {text_embed_time:.2f} ms")
print(inputs_embeds.shape)

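# Build the multimodal encoder input: image tokens first, then the task
# prompt, with a matching attention mask.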
batch_size, image_token_length = image_features.shape[:-1]
image_attention_mask = np.ones((batch_size, image_token_length))
task_prefix_embeds = inputs_embeds
task_prefix_attention_mask = np.ones((batch_size, task_prefix_embeds.shape[1]))
if len(task_prefix_attention_mask.shape) == 3:
    task_prefix_attention_mask = task_prefix_attention_mask[:, 0]
inputs_embeds = np.concatenate([image_features, task_prefix_embeds], axis=1)
attention_mask = np.concatenate([image_attention_mask, task_prefix_attention_mask], axis=1)

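# Run the text encoder (NPU) over the fused image+text sequence.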
start_time = time.time()
encoder_out = rknn_encoder.inference(inputs=[attention_mask.astype(np.int64), inputs_embeds])
end_time = time.time()
encoder_time = (end_time - start_time) * 1000
total_time += encoder_time
print(f"Encoder time: {encoder_time:.2f} ms")
encoder_hidden_states = encoder_out[0]
print(encoder_hidden_states.shape)

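# Decoder prefill (NPU): a single step seeded with the embedding at the final
# prompt position yields the first-token logits and the initial KV cache.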
start_time = time.time()
decoder_outs = rknn_decoder_prefill.inference(inputs=[attention_mask.astype(np.int64), encoder_hidden_states, inputs_embeds[:, -1:]])
end_time = time.time()
decoder_prefill_time = (end_time - start_time) * 1000
total_time += decoder_prefill_time
print(f"Decoder prefill time: {decoder_prefill_time:.2f} ms")

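# Prefill returns [logits, *present_kv]; the cross-attention (encoder) KV is
# fixed for the whole generation, so capture it once here.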
encoder_kv = decoder_outs[1:]

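# Greedy autoregressive decoding, one token per iteration.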
generated_tokens = []
max_new_tokens = 32
decoder_decode_total_time = 0
while len(generated_tokens) < max_new_tokens:
    logits = decoder_outs[0]
    decoder_kv = decoder_outs[1:]  # refresh self-attention KV from the last step

    # Greedy decoding: pick the highest-scoring token at the last position.
    next_token_logits = logits[:, -1, :]
    next_token = np.argmax(next_token_logits, axis=-1)[0]
    generated_tokens.append(next_token)

    # Token id 2 is </s> (end of sequence).
    if next_token == 2:
        break

    # Embed the newly generated token on the CPU for the next decode step.
    start_time = time.time()
    next_input_embeds = text_embed.run(None, {
        "input_ids": np.array([[next_token]], dtype=np.int64)
    })[0]
    end_time = time.time()
    text_embed_time = (end_time - start_time) * 1000
    decoder_decode_total_time += text_embed_time

    # One cached decode step on the CPU. The feed carries 4 KV tensors per
    # layer: self-attention (decoder.*) KV from the previous step and
    # cross-attention (encoder.*) KV reused from prefill. use_cache_branch
    # selects the cached path of the merged decoder.
    start_time = time.time()
    decode_inputs = {
        "use_cache_branch": np.array([True], dtype=np.bool_),
        "inputs_embeds": next_input_embeds,
        "encoder_hidden_states": encoder_hidden_states,
        "encoder_attention_mask": attention_mask.astype(np.int64),
    }
    for layer in range(6):  # 6 decoder layers, matching past_key_values.0-5
        base = layer * 4
        decode_inputs[f"past_key_values.{layer}.decoder.key"] = decoder_kv[base]
        decode_inputs[f"past_key_values.{layer}.decoder.value"] = decoder_kv[base + 1]
        decode_inputs[f"past_key_values.{layer}.encoder.key"] = encoder_kv[base + 2]
        decode_inputs[f"past_key_values.{layer}.encoder.value"] = encoder_kv[base + 3]
    decoder_outs = decoder_decode.run(None, decode_inputs)
    end_time = time.time()
    decoder_decode_time = (end_time - start_time) * 1000
    decoder_decode_total_time += decoder_decode_time

total_time += decoder_decode_total_time
print(f"Decoder decode total time: {decoder_decode_total_time:.2f} ms")

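# Decode token ids to text and run Florence-2's task-aware post-processing.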
print("generated_tokens: ", generated_tokens) |
|
generated_text = processor.batch_decode([generated_tokens], skip_special_tokens=False)[0] |
|
print("Generated Text:", generated_text) |
|
parsed_answer = processor.post_process_generation(generated_text, task=prompt, image_size=(image.width, image.height)) |
|
print("Parsed Answer:", parsed_answer) |
|
|
|
print(f"Total inference time: {total_time:.2f} ms") |
|
|
|
|
|
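# Release the NPU contexts.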
rknn_vision_encoder.release()
rknn_encoder.release()
rknn_decoder_prefill.release()