Upload ONNX weights

#1
by Xenova HF staff - opened

Conversion script:

import os
import torch
import torch.nn as nn
from transformers import (
    AutoProcessor,
    PaliGemmaForConditionalGeneration,
    DynamicCache,
)

# Repository id of the PaliGemma checkpoint to convert; passed to
# ``from_pretrained`` and used to name the output folder.
model_id = "optimum-internal-testing/tiny-random-PaliGemmaForConditionalGeneration"

def new_len(self: torch.Tensor):
    """Return the size of the tensor's leading dimension (``self.shape[0]``).

    Installed below as ``torch.Tensor.__len__``, so ``len(t)`` for any tensor
    ``t`` is answered by this function for the rest of the process.
    """
    shape = self.shape
    return shape[0]

# Global monkeypatch: affects every tensor from here on, including during the
# ONNX tracing later in the script.
# NOTE(review): presumably done so that len() on tensors behaves well under
# torch.onnx tracing -- confirm the motivation before removing it.
setattr(torch.Tensor, "__len__", new_len)


class VisionEncoder(nn.Module):
    """Export wrapper for PaliGemma's image branch.

    Holds references to (not copies of) the ``vision_tower`` and
    ``multi_modal_projector`` submodules of a PaliGemma model, so the image
    path can be traced and exported to ONNX on its own.
    """

    def __init__(self, paligemma_model):
        """Keep the config and the image submodules of ``paligemma_model``.

        Args:
            paligemma_model: A loaded PaliGemma model exposing ``config``,
                ``vision_tower`` and ``multi_modal_projector``.
        """
        super().__init__()
        self.config = paligemma_model.config
        self.vision_tower = paligemma_model.vision_tower
        self.multi_modal_projector = paligemma_model.multi_modal_projector

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        """Obtain image features from the vision tower and apply the multimodal projection.

        Runs the vision tower, takes its ``last_hidden_state``, projects it with
        ``multi_modal_projector`` and divides the result by
        ``sqrt(config.text_config.hidden_size)``.

        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
                The tensors corresponding to the input images.

        Returns:
            image_features (`torch.Tensor`): Image feature tensor of shape
            `(num_images, image_length, embed_dim)`.
        """
        image_outputs = self.vision_tower(pixel_values)
        selected_image_feature = image_outputs.last_hidden_state
        image_features = self.multi_modal_projector(selected_image_feature)
        # Scale by 1/sqrt(hidden_size).
        # NOTE(review): presumably mirrors the normalisation PaliGemma applies
        # to image features in its own forward -- confirm against the
        # transformers version in use.
        image_features = image_features / (self.config.text_config.hidden_size**0.5)
        return image_features


class PatchedPaliGemmaForConditionalGeneration(PaliGemmaForConditionalGeneration):
    """PaliGemma whose ``forward`` takes a flat, positional, export-friendly signature.

    ``forward`` only runs ``self.language_model`` on precomputed
    ``inputs_embeds``. It accepts an optional KV cache passed as individual
    tensors and returns the logits plus the updated cache flattened into
    ``present.{i}.key`` / ``present.{i}.value`` entries. This lets the decoder
    be traced with plain positional tensor inputs.
    """

    def forward(self, *args):
        """Run the language model on embeddings with an optional flattened KV cache.

        Args:
            *args: ``inputs_embeds, position_ids`` followed by either no tensors
                (no cache) or exactly ``2 * num_hidden_layers`` past tensors in
                the order ``key_0, value_0, key_1, value_1, ...``.

        Returns:
            dict: ``"logits"`` followed by ``"present.{i}.key"`` and
            ``"present.{i}.value"`` for every layer in the returned cache.

        Raises:
            ValueError: If past tensors are given but there is not exactly one
                key and one value per hidden layer.
        """
        inputs_embeds, position_ids, *past_key_values_args = args
        # Use this instance's own config. The previous code read the
        # module-level ``model`` global, which only worked because the script
        # happened to bind its instance to that name.
        config = self.config.text_config
        num_layers = config.num_hidden_layers

        # Convert the flat past_key_values list to a DynamicCache.
        if not past_key_values_args:
            past_key_values = None
        else:
            expected = 2 * num_layers
            if len(past_key_values_args) != expected:
                raise ValueError(
                    f"Expected {expected} past key/value tensors (one key and one "
                    f"value per layer), got {len(past_key_values_args)}."
                )
            past_key_values = DynamicCache(num_layers)
            for i in range(num_layers):
                past_key_values.update(
                    key_states=past_key_values_args[2 * i],
                    value_states=past_key_values_args[2 * i + 1],
                    layer_idx=i,
                )

        batch_size = inputs_embeds.shape[0]

        o = self.language_model.forward(
            inputs_embeds=inputs_embeds,
            # Create a 4D attention mask of all zeros (attend to everything)
            attention_mask=torch.zeros(
                batch_size,
                1,  # num_attention_heads (1 -> expand to num_attention_heads)
                1,  # sequence_length (1 -> expand to sequence_length)
                1,  # total_sequence_length (1 -> expand to total_sequence_length)
                dtype=torch.float32,
            ),
            position_ids=position_ids,
            past_key_values=past_key_values,
        )

        flattened_past_key_values_outputs = {
            "logits": o.logits,
        }
        output_past_key_values: DynamicCache = o.past_key_values
        for i, (key, value) in enumerate(
            zip(output_past_key_values.key_cache, output_past_key_values.value_cache)
        ):
            flattened_past_key_values_outputs[f"present.{i}.key"] = key
            flattened_past_key_values_outputs[f"present.{i}.value"] = value

        return flattened_past_key_values_outputs


# Constants
# Folder layout under output/<model_id>/:
#   ./       saved configs and processor files
#   ./temp/  raw torch.onnx.export results
#   ./onnx/  post-processed ONNX models
OUTPUT_FOLDER = os.path.join("output", model_id)
TEMP_MODEL_OUTPUT_FOLDER = os.path.join(OUTPUT_FOLDER, "temp")
FINAL_MODEL_OUTPUT_FOLDER = os.path.join(OUTPUT_FOLDER, "onnx")

# File names of the three exported graphs (the same name is used in temp/ and onnx/).
TEXT_MODEL_NAME = "decoder_model_merged.onnx"
VISION_MODEL_NAME = "vision_encoder.onnx"
EMBED_MODEL_NAME = "embed_tokens.onnx"


# Load model and processor
# The patched subclass is loaded so its flat, positional ``forward`` is the one
# exercised by the test pass and the decoder export; .eval() switches the
# module to inference mode before tracing.
model = PatchedPaliGemmaForConditionalGeneration.from_pretrained(
    model_id,
).eval()
# Image-branch wrapper; it holds references to model's vision submodules.
vision_model = VisionEncoder(model)
# Token-embedding layer of the language model, exported as its own graph.
embed_layer = model.language_model.model.embed_tokens

processor = AutoProcessor.from_pretrained(model_id)

# Save model configs and processor
# NOTE(review): presumably processor.save_pretrained is what writes the
# tokenizer.json that is minified at the end of the script -- confirm.
model.config.save_pretrained(OUTPUT_FOLDER)
model.generation_config.save_pretrained(OUTPUT_FOLDER)
processor.save_pretrained(OUTPUT_FOLDER)
# Scratch folder for the raw (pre-post-processing) ONNX exports.
os.makedirs(TEMP_MODEL_OUTPUT_FOLDER, exist_ok=True)


# Configuration values
## Text model
# Decoder dimensions, read once from the loaded model's text config.
text_config = model.config.text_config
(
    num_attention_heads,
    num_key_value_heads,
    head_dim,
    num_layers,
    hidden_size,
) = (
    text_config.num_attention_heads,
    text_config.num_key_value_heads,
    text_config.head_dim,
    text_config.num_hidden_layers,
    text_config.hidden_size,
)

# Dummy input sizes (used for the test forward pass and for tracing).
batch_size = 2
sequence_length = 32
past_sequence_length = 8
total_sequence_length = sequence_length + past_sequence_length

## Text inputs
# One zero-filled (batch_size, num_key_value_heads, past_sequence_length, head_dim)
# tensor per layer for both key and value, inserted in the flat order
# key_0, value_0, key_1, value_1, ... that the decoder forward consumes.
dummy_past_key_values_kwargs = {
    f"past_key_values.{i}.{key}": torch.zeros(
        batch_size,
        num_key_value_heads,
        past_sequence_length,
        head_dim,
        dtype=torch.float32,
    )
    for i in range(num_layers)
    for key in ["key", "value"]
}
inputs_embeds = torch.randn(
    (batch_size, sequence_length, hidden_size),
)

# The new tokens follow the past_sequence_length cached ones. Positions are
# 1-based in this script (first id is 1), so the new tokens get ids
# past_sequence_length + 1 ... total_sequence_length.
# NOTE(review): the 1-based start presumably mirrors PaliGemma's position
# numbering -- confirm against the transformers version in use.
position_ids = torch.arange(
    past_sequence_length + 1, total_sequence_length + 1, dtype=torch.int64
).expand(batch_size, sequence_length)

text_inputs = dict(
    inputs_embeds=inputs_embeds,
    position_ids=position_ids,
    **dummy_past_key_values_kwargs,
)
# Dict insertion order defines the positional argument order (and the input names).
text_inputs_positional = tuple(text_inputs.values())
text_outputs = model.forward(*text_inputs_positional)  # Test forward pass

## Vision inputs
# Dummy images at the processor's configured resolution, shaped
# (batch_size, 3 channels, height, width). The batch reuses ``batch_size``
# instead of a hard-coded 2, so it stays in sync with the text dummy inputs.
size = processor.image_processor.size
w, h = size["width"], size["height"]
pixel_values = torch.randn(batch_size, 3, h, w, requires_grad=True)
vision_inputs = dict(pixel_values=pixel_values)
vision_inputs_positional = tuple(vision_inputs.values())
vision_outputs = vision_model.forward(*vision_inputs_positional)  # Test forward pass



# ONNX Exports
# NOTE: torch.onnx._globals is an underscore-prefixed (private) PyTorch module
# and may move or disappear between releases. The exporter's built-in ONNX
# shape inference is switched off here to work around a PyTorch bug; shape
# inference is run separately on the saved files during post-processing.
from torch.onnx._globals import GLOBALS
GLOBALS.onnx_shape_inference = False # Bug in pytorch

## Text model
TEXT_MODEL_OUTPUT_PATH = os.path.join(TEMP_MODEL_OUTPUT_FOLDER, TEXT_MODEL_NAME)
torch.onnx.export(
    model,
    args=text_inputs_positional,
    f=TEXT_MODEL_OUTPUT_PATH,
    export_params=True,
    opset_version=14,
    do_constant_folding=True,
    # Input names are the keys of text_inputs, in the same order as the args.
    input_names=list(text_inputs.keys()),
    # "logits" first, then one key and one value output per layer.
    output_names=[
        "logits",
        *(f"present.{i}.{kind}" for i in range(num_layers) for kind in ("key", "value")),
    ],
    # Batch and sequence axes stay dynamic; cache tensors vary along axis 2.
    dynamic_axes={
        "inputs_embeds": {0: "batch_size", 1: "sequence_length"},
        "position_ids": {0: "batch_size", 1: "sequence_length"},
        **{
            f"past_key_values.{i}.{kind}": {0: "batch_size", 2: "past_sequence_length"}
            for i in range(num_layers)
            for kind in ("key", "value")
        },
        "logits": {0: "batch_size", 1: "sequence_length"},
        **{
            f"present.{i}.{kind}": {0: "batch_size", 2: "total_sequence_length"}
            for i in range(num_layers)
            for kind in ("key", "value")
        },
    },
)

## Vision model
VISION_MODEL_OUTPUT_PATH = os.path.join(TEMP_MODEL_OUTPUT_FOLDER, VISION_MODEL_NAME)
# Only the batch axis is dynamic; height/width stay at the size used for the
# dummy pixel_values.
torch.onnx.export(
    vision_model,
    f=VISION_MODEL_OUTPUT_PATH,
    args=vision_inputs_positional,
    opset_version=14,
    export_params=True,
    do_constant_folding=True,
    input_names=["pixel_values"],
    output_names=["image_features"],
    dynamic_axes=dict(
        pixel_values={0: "batch_size"},
        image_features={0: "batch_size"},
    ),
)

## Embedding model
# Random token ids in [0, embed_layer.num_embeddings) for tracing the lookup.
input_ids = torch.randint(0, embed_layer.num_embeddings, (batch_size, sequence_length))
EMBED_MODEL_OUTPUT_PATH = os.path.join(TEMP_MODEL_OUTPUT_FOLDER, EMBED_MODEL_NAME)
torch.onnx.export(
    embed_layer,
    args=(input_ids,),
    f=EMBED_MODEL_OUTPUT_PATH,
    export_params=True,
    opset_version=14,
    do_constant_folding=True,
    input_names=["input_ids"],
    output_names=["inputs_embeds"],
    # Input and output share dynamic (batch_size, sequence_length) leading axes.
    dynamic_axes={
        tensor_name: {0: "batch_size", 1: "sequence_length"}
        for tensor_name in ("input_ids", "inputs_embeds")
    },
)


# Post-processing
import onnx
import onnxslim
from optimum.onnx.graph_transformations import check_and_save_model

os.makedirs(FINAL_MODEL_OUTPUT_FOLDER, exist_ok=True)
for name in (TEXT_MODEL_NAME, VISION_MODEL_NAME, EMBED_MODEL_NAME):
    temp_model_path = os.path.join(TEMP_MODEL_OUTPUT_FOLDER, name)
    final_model_path = os.path.join(FINAL_MODEL_OUTPUT_FOLDER, name)

    # No output path is given, so the inferred model overwrites the temp file.
    onnx.shape_inference.infer_shapes_path(
        temp_model_path,
        check_type=True,
        strict_mode=True,
    )

    ## Best effort: try onnxslim, fall back to the unslimmed model on any error.
    try:
        onnx_model = onnxslim.slim(temp_model_path)
    except Exception as err:
        print(f"Failed to slim {temp_model_path}: {err}")
        onnx_model = onnx.load(temp_model_path)

    ## Save model
    check_and_save_model(onnx_model, final_model_path)

## Minify tokenizer.json
import json
tokenizer_path = os.path.join(OUTPUT_FOLDER, "tokenizer.json")
# JSON is UTF-8; be explicit so the result does not depend on the platform's
# default locale encoding.
with open(tokenizer_path, "r", encoding="utf-8") as f:
    tokenizer = json.load(f)
with open(tokenizer_path, "w", encoding="utf-8") as f:
    # Most compact form: no indentation and no spaces after separators.
    # ensure_ascii=False writes non-ASCII tokens as raw UTF-8 instead of
    # \uXXXX escapes (json.dump's default), which would inflate the file
    # this step is meant to shrink.
    json.dump(tokenizer, f, ensure_ascii=False, separators=(",", ":"))

# Add head_dim and num_image_tokens to config.json
config_path = os.path.join(OUTPUT_FOLDER, "config.json")
# Explicit UTF-8 so reading/writing does not depend on the platform locale.
with open(config_path, "r", encoding="utf-8") as f:
    config = json.load(f)
# head_dim is taken from the loaded model's text config.
config["text_config"]["head_dim"] = head_dim
# Copy num_image_tokens up to the top level; raises KeyError if the saved
# text_config does not contain it.
config["num_image_tokens"] = config["text_config"]["num_image_tokens"]
with open(config_path, "w", encoding="utf-8") as f:
    json.dump(config, f, indent=2)

## Cleanup
# Delete the scratch folder of raw exports; the post-processed models were
# already saved under FINAL_MODEL_OUTPUT_FOLDER.
import shutil
shutil.rmtree(TEMP_MODEL_OUTPUT_FOLDER)
Ready to merge
This branch is ready to get merged automatically.

Sign up or log in to comment