import torch
import torch.nn as nn
import torch.nn.functional as F
from safetensors import safe_open
import json
import gradio as gr
from PIL import Image
import numpy as np
from huggingface_hub import snapshot_download
from mistral_common.protocol.instruct.messages import UserMessage, TextChunk, ImageChunk
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
import spaces
# Download model files
model_path = snapshot_download(repo_id="mistral-community/pixtral-12b-240910")

# Load model parameters and tokenizer configuration
with open(f'{model_path}/params.json', 'r') as f:
    params = json.load(f)
with open(f'{model_path}/tekken.json', 'r') as f:
    tokenizer_config = json.load(f)
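# The 'vision_encoder' section of params.json is expected to provide at least:
# num_channels, hidden_size, patch_size, image_size, rope_theta,
# intermediate_size, num_hidden_layers and num_attention_heads; these are the
# keys read by VisionEncoder and the Gradio UI below.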
class GELU(nn.Module):
    def __init__(self, dim_in, dim_out, approximate='none', bias=True):
        super().__init__()
        self.linear = nn.Linear(dim_in, dim_out, bias=bias)
        self.approximate = approximate

    def forward(self, x):
        # Project first, then apply GELU (exact or tanh approximation).
        x = self.linear(x)
        if self.approximate == 'tanh':
            return 0.5 * x * (1 + torch.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * torch.pow(x, 3))))
        return F.gelu(x)
class Rope2D(nn.Module):
    def __init__(self, dim, max_position_embeddings=1024, base=10000):
        super().__init__()
        # Precompute rotary frequency tables and cache cos/sin for up to
        # max_position_embeddings positions.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(self.max_seq_len_cached, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)

    def forward(self, x, seq_len=None):
        # Rebuild the cached tables if the requested length exceeds the cache.
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
            self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )
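# Sketch: how the cos/sin tables from Rope2D would typically be applied to the
# query/key tensors inside a custom attention module. The encoder below relies
# on nn.TransformerEncoderLayer, which has no hook for rotary embeddings, so
# these helpers are illustrative only; `q` and `k` are assumed to be shaped
# (batch, num_heads, seq_len, head_dim).
def rotate_half(x):
    # Split the last dimension in half and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin):
    # Standard rotary application: mix of the original and rotated halves.
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed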
class VisionEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        # Patchify the image with a strided convolution (one token per patch).
        self.embed = nn.Conv2d(config['num_channels'], config['hidden_size'],
                               kernel_size=config['patch_size'], stride=config['patch_size'])
        self.rope = Rope2D(config['hidden_size'] // config['num_attention_heads'], base=config['rope_theta'])
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=config['hidden_size'],
                nhead=config['num_attention_heads'],
                dim_feedforward=config['intermediate_size'],
                batch_first=True,  # inputs below are (batch, seq, dim)
            )
            for _ in range(config['num_hidden_layers'])
        ])
        self.norm = nn.LayerNorm(config['hidden_size'])
        self.gelu = GELU(config['hidden_size'], config['hidden_size'])

    def forward(self, pixel_values):
        x = self.embed(pixel_values)        # (B, hidden, H/P, W/P)
        b, c, h, w = x.shape
        x = x.flatten(2).transpose(1, 2)    # (B, num_patches, hidden)
        # Rotary tables are computed here, but nn.TransformerEncoderLayer has no
        # way to consume them, so they are not applied inside these layers.
        cos, sin = self.rope(x, seq_len=h * w)
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)
        x = self.gelu(x)
        return x
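# Hypothetical shape check for the encoder (kept out of the Space's startup path):
#   enc = VisionEncoder(params['vision_encoder'])
#   dummy = torch.zeros(1, params['vision_encoder']['num_channels'],
#                       params['vision_encoder']['image_size'],
#                       params['vision_encoder']['image_size'])
#   print(enc(dummy).shape)  # (1, (image_size // patch_size) ** 2, hidden_size)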
class PixtralModel(nn.Module):
    def __init__(self, params):
        super().__init__()
        self.vision_encoder = VisionEncoder(params['vision_encoder'])
        # Add text generation components here

    def forward(self, image):
        vision_output = self.vision_encoder(image)
        # Add text generation logic here
        return vision_output
def load_model(params, model_path):
    model = PixtralModel(params)
    # Copy checkpoint tensors whose names match this module's parameter names;
    # anything that does not match keeps its random initialization.
    with safe_open(f'{model_path}/consolidated.safetensors', framework="pt", device="cpu") as f:
        for name, param in model.named_parameters():
            if name in f.keys():
                param.data = f.get_tensor(name)
    model.eval()
    return model
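# Hypothetical sanity check: report how many parameters were actually found in
# the checkpoint (unmatched names are left at their random initialization).
#   with safe_open(f'{model_path}/consolidated.safetensors', framework="pt", device="cpu") as f:
#       matched = [n for n, _ in model.named_parameters() if n in f.keys()]
#       print(f"Loaded {len(matched)} tensors from the checkpoint")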
# Initialize the model
model = load_model(params, model_path)

# Initialize the tokenizer
tokenizer = MistralTokenizer.from_model("pixtral")
@spaces.GPU  # ZeroGPU Spaces: allocate a GPU for the duration of this call
def process_image_and_text(image, prompt):
    # Prepare the image
    image = image.convert('RGB')
    image = image.resize((params['vision_encoder']['image_size'], params['vision_encoder']['image_size']))
    image_tensor = torch.tensor(np.array(image)).permute(2, 0, 1).unsqueeze(0).float() / 255.0
    image_tensor = image_tensor.cuda()

    # Tokenize the input
    tokenized = tokenizer.encode_chat_completion(
        ChatCompletionRequest(
            messages=[
                UserMessage(
                    content=[
                        TextChunk(text=prompt),
                        ImageChunk(image=image),
                    ]
                )
            ],
            model="pixtral",
        )
    )
    tokens, text, images = tokenized.tokens, tokenized.text, tokenized.images

    # Process the image and generate text
    with torch.no_grad():
        model.cuda()  # Move model to GPU only when processing
        vision_output = model(image_tensor)
        model.cpu()  # Move model back to CPU after processing

    # Add text generation logic here
    generated_text = f"Generated text based on the image and prompt: {prompt}"

    return generated_text, len(tokens), len(images)
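# Hypothetical direct call, outside the Gradio UI ("example.jpg" is a placeholder path):
#   text, n_tokens, n_images = process_image_and_text(Image.open("example.jpg"), "Describe this image.")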
# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Pixtral Image-to-Text Model Demo")
    gr.Markdown("Upload an image and provide a prompt to generate text based on it.")

    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil")
            input_prompt = gr.Textbox(label="Prompt")
            submit_btn = gr.Button("Generate Text")
        with gr.Column(scale=1):
            output_text = gr.Textbox(label="Generated Text")
            token_count = gr.Number(label="Number of Tokens")
            image_count = gr.Number(label="Number of Images")

    submit_btn.click(
        fn=process_image_and_text,
        inputs=[input_image, input_prompt],
        outputs=[output_text, token_count, image_count]
    )
gr.Markdown("## How it works") | |
gr.Markdown("1. The image is processed by a Vision Encoder using 2D ROPE (Rotary Position Embedding).") | |
gr.Markdown("2. The encoder uses GELU activation in its layers.") | |
gr.Markdown("3. The encoded image and the prompt are used to generate descriptive text.") | |
gr.Markdown("## Model Details") | |
gr.Markdown(f"- Vision Encoder Hidden Size: {params['vision_encoder']['hidden_size']}") | |
gr.Markdown(f"- Number of Vision Encoder Layers: {params['vision_encoder']['num_hidden_layers']}") | |
gr.Markdown(f"- Number of Attention Heads: {params['vision_encoder']['num_attention_heads']}") | |
gr.Markdown(f"- Image Size: {params['vision_encoder']['image_size']}x{params['vision_encoder']['image_size']}") | |
gr.Markdown(f"- Patch Size: {params['vision_encoder']['patch_size']}x{params['vision_encoder']['patch_size']}") | |
if __name__ == "__main__": | |
demo.launch() |