import onnxruntime as ort
from transformers import AutoProcessor
from PIL import Image
import numpy as np
import os

# Set the current working directory to the directory of this file.
os.chdir(os.path.dirname(os.path.abspath(__file__)))

# Embedding models
vision_encoder = ort.InferenceSession("vision_encoder.onnx", providers=['CPUExecutionProvider'])
text_embed = ort.InferenceSession("embed_tokens.onnx", providers=['CPUExecutionProvider'])

# Encoder
encoder = ort.InferenceSession("encoder_model.onnx", providers=['CPUExecutionProvider'])

# Decoder: the full-precision model handles the prefill step, the merged Q4 model the decode steps.
decoder_prefill = ort.InferenceSession("decoder_model.onnx", providers=['CPUExecutionProvider'])
decoder_decode = ort.InferenceSession("decoder_model_merged_q4.onnx", providers=['CPUExecutionProvider'])

# 1. Prepare the processor
processor = AutoProcessor.from_pretrained("/home/zt/rk3588-nn/expr/Florence-2-base-ft", trust_remote_code=True)

# 2. Prepare the image
image = Image.open("./lena.png")
# Resize the image to 512x512
image = image.resize((512, 512))

# 3. Prepare the text prompt (Florence-2 normally expects a task token such as "<CAPTION>"; left empty here)
prompt = ""
inputs = processor(text=prompt, images=image, return_tensors="np", do_resize=False)
for k, v in inputs.items():
    print(k, v.shape)

# 4. Run the vision encoder
image_features = vision_encoder.run(None, {
    "pixel_values": inputs["pixel_values"]
})
for output in image_features:
    print(output.shape)
image_features = image_features[0]
np.save("image_features.npy", image_features)

# 5. Run the token embedding
inputs_embeds = text_embed.run(None, {
    "input_ids": inputs["input_ids"]
})
for output in inputs_embeds:
    print(output.shape)
inputs_embeds = inputs_embeds[0]

# 6. Concatenate the image features and the text embeddings
batch_size, image_token_length = image_features.shape[:-1]
image_attention_mask = np.ones((batch_size, image_token_length))
task_prefix_embeds = inputs_embeds
task_prefix_attention_mask = np.ones((batch_size, task_prefix_embeds.shape[1]))
if len(task_prefix_attention_mask.shape) == 3:
    task_prefix_attention_mask = task_prefix_attention_mask[:, 0]
inputs_embeds = np.concatenate([image_features, task_prefix_embeds], axis=1)
attention_mask = np.concatenate([image_attention_mask, task_prefix_attention_mask], axis=1)

# 7. Run the encoder
encoder_out = encoder.run(None, {
    "inputs_embeds": inputs_embeds,
    "attention_mask": attention_mask.astype(np.int64)
})
for output in encoder_out:
    print(output.shape)
encoder_hidden_states = encoder_out[0]

# 8. Run the decoder prefill stage.
# The embedding of the final text token (</s>) is used as the decoder start token.
decoder_outs = decoder_prefill.run(None, {
    "inputs_embeds": inputs_embeds[:, -1:],
    "encoder_hidden_states": encoder_hidden_states,
    "encoder_attention_mask": attention_mask.astype(np.int64)
})
for output in decoder_outs:
    print(output.shape)
encoder_kv = decoder_outs[1:]
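
# Note (added as an assumption, not in the original script): the decode loop below
# assumes the prefill decoder outputs [logits] followed by four present-KV tensors per
# layer, ordered decoder.key, decoder.value, encoder.key, encoder.value. If your export
# orders them differently, inspect the declared output names and adjust the offsets.
for out in decoder_prefill.get_outputs():
    print(out.name, out.shape)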
# 9. Run the decoder decode stage (autoregressive greedy decoding)
generated_tokens = []
max_new_tokens = 32
while len(generated_tokens) < max_new_tokens:
    # Take the outputs of the previous step
    logits = decoder_outs[0]
    decoder_kv = decoder_outs[1:]

    # Keep only the logits of the last position
    next_token_logits = logits[:, -1, :]

    # Pick the next token with argmax (greedy decoding)
    next_token = np.argmax(next_token_logits, axis=-1)[0]
    print("next_token: ", next_token)

    # Append the newly generated token to the result
    generated_tokens.append(next_token)

    # Stop once the end-of-sequence token (id 2) has been generated
    if next_token == 2:
        break

    # Prepare the input for the next step
    next_input_embeds = text_embed.run(None, {
        "input_ids": np.array([[next_token]], dtype=np.int64)
    })[0]

    # Run the decode stage of the merged decoder. The decoder self-attention cache
    # comes from the previous step, while the encoder cross-attention cache stays
    # fixed from the prefill step. The exported decoder has 6 layers
    # (past_key_values.0 ... past_key_values.5), each with 4 KV tensors.
    decoder_inputs = {
        "use_cache_branch": np.array([True], dtype=np.bool_),
        "inputs_embeds": next_input_embeds,
        "encoder_hidden_states": encoder_hidden_states,
        "encoder_attention_mask": attention_mask.astype(np.int64),
    }
    for i in range(6):
        decoder_inputs[f"past_key_values.{i}.decoder.key"] = decoder_kv[i * 4 + 0]
        decoder_inputs[f"past_key_values.{i}.decoder.value"] = decoder_kv[i * 4 + 1]
        decoder_inputs[f"past_key_values.{i}.encoder.key"] = encoder_kv[i * 4 + 2]
        decoder_inputs[f"past_key_values.{i}.encoder.value"] = encoder_kv[i * 4 + 3]
    decoder_outs = decoder_decode.run(None, decoder_inputs)
    for output in decoder_outs:
        print(output.shape)
    # print("generated_token: ", processor.decode(next_token, skip_special_tokens=False))

# Optionally drop the first token
# generated_tokens = generated_tokens[1:]

# Convert the generated tokens to text
print("generated_tokens: ", generated_tokens)
generated_text = processor.batch_decode([generated_tokens], skip_special_tokens=False)[0]
print("Generated Text:", generated_text)
parsed_answer = processor.post_process_generation(generated_text, task=prompt, image_size=(image.width, image.height))
print("Parsed Answer:", parsed_answer)
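
# Optional follow-up (a sketch, not part of the original script): the EOS id (2) and the
# number of decoder layers (6) are hard-coded above. Assuming the processor exposes a
# tokenizer (as transformers processors usually do), both can be derived instead:
eos_token_id = processor.tokenizer.eos_token_id  # expected to be 2 for the BART-based Florence-2 LM
num_decoder_layers = sum(
    1 for inp in decoder_decode.get_inputs() if inp.name.endswith(".decoder.key")
)
print("eos_token_id:", eos_token_id, "num_decoder_layers:", num_decoder_layers)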