# Pixtral / app.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from safetensors import safe_open
import json
import gradio as gr
from PIL import Image
import numpy as np
from huggingface_hub import snapshot_download
from mistral_common.protocol.instruct.messages import UserMessage, TextChunk, ImageChunk
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
import spaces
import math
from typing import Tuple
title = "# **WIP / DEMO** 🙋🏻‍♂️Welcome to Tonic's Pixtral Model Demo"
description = """
This demo showcases two capabilities of the Pixtral model:
1. Image-to-Text Generation
2. Image Similarity Comparison
### Join us:
🌟TeamTonic🌟 is always making cool demos! Join our active builders' 🛠️ community 👻 [![Join us on Discord](https://img.shields.io/discord/1109943800132010065?label=Discord&logo=discord&style=flat-square)](https://discord.gg/qdfnvSPcqP) On 🤗Huggingface: [MultiTransformer](https://huggingface.co/MultiTransformer) On 🌐Github: [Tonic-AI](https://github.com/tonic-ai) & contribute to 🌟 [Build Tonic](https://git.tonic-ai.com/contribute). 🤗 Big thanks to Yuvi Sharma and all the folks at Huggingface for the community grant 🤗
"""
model_path = snapshot_download(repo_id="mistralai/Pixtral-12B-2409")
with open(f'{model_path}/params.json', 'r') as f:
params = json.load(f)
with open(f'{model_path}/tekken.json', 'r') as f:
tokenizer_config = json.load(f)
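# params.json supplies the model hyperparameters used below (dim, n_layers, n_heads,
# hidden_dim, vocab_size, plus a 'vision_encoder' sub-dict); tekken.json is the raw
# tokenizer definition. The tokenizer itself is instantiated further down via
# mistral_common's MistralTokenizer, so tokenizer_config is not used elsewhere in this script.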
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-5):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
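# 2D rotary position embeddings for the vision encoder: even-indexed frequency
# channels encode the patch row and odd-indexed channels the patch column, giving
# one complex rotation per (row, col, channel pair). The result has shape
# (height, width, dim // 2) with dtype complex64.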
def precompute_freqs_cis_2d(dim: int, height: int, width: int, theta: float) -> torch.Tensor:
freqs = 1.0 / (theta**(torch.arange(0, dim, 2).float() / dim))
h = torch.arange(height)
w = torch.arange(width)
freqs_h = torch.outer(h, freqs[::2]).float()
freqs_w = torch.outer(w, freqs[1::2]).float()
freqs_2d = torch.cat([freqs_h[:, None, :].repeat(1, width, 1), freqs_w[None, :, :].repeat(height, 1, 1)], dim=-1)
return torch.polar(torch.ones_like(freqs_2d), freqs_2d)
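# Applies the precomputed rotations to queries and keys by viewing the last
# dimension as complex pairs. Illustrative shapes (assuming head_dim = 64):
#   xq, xk:    (batch, patches, n_heads, 64)  ->  complex (batch, patches, n_heads, 32)
#   freqs_cis: (patches, 32), reshaped below to (1, patches, 1, 32) for broadcasting.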
def apply_rotary_emb_vit(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
freqs_cis = freqs_cis.view(1, freqs_cis.shape[0], 1, freqs_cis.shape[-1])  # (1, patches, 1, head_dim // 2) broadcasts over batch and heads
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
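# Standard multi-head self-attention over all image patches (no masking), with
# rotary embeddings applied to q/k before the scaled dot product. Memory and
# compute scale quadratically with the number of patches.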
class Attention(nn.Module):
def __init__(self, args):
super().__init__()
self.n_heads = args['num_attention_heads']
self.head_dim = args['hidden_size'] // args['num_attention_heads']
self.wq = nn.Linear(args['hidden_size'], args['hidden_size'], bias=False)
self.wk = nn.Linear(args['hidden_size'], args['hidden_size'], bias=False)
self.wv = nn.Linear(args['hidden_size'], args['hidden_size'], bias=False)
self.wo = nn.Linear(args['hidden_size'], args['hidden_size'], bias=False)
def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
batch, patches, _ = x.shape
q, k, v = self.wq(x), self.wk(x), self.wv(x)
q = q.reshape(batch, patches, self.n_heads, self.head_dim)
k = k.reshape(batch, patches, self.n_heads, self.head_dim)
v = v.reshape(batch, patches, self.n_heads, self.head_dim)
q, k = apply_rotary_emb_vit(q, k, freqs_cis=freqs_cis)
scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.head_dim)
attn = F.softmax(scores, dim=-1)
out = torch.matmul(attn, v)
out = out.reshape(batch, patches, self.n_heads * self.head_dim)
return self.wo(out)
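# SwiGLU-style feed-forward block: w2(SiLU(w1(x)) * w3(x)), projecting
# hidden_size -> intermediate_size -> hidden_size without biases.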
class FeedForward(nn.Module):
def __init__(self, args):
super().__init__()
self.w1 = nn.Linear(args['hidden_size'], args['intermediate_size'], bias=False)
self.w2 = nn.Linear(args['intermediate_size'], args['hidden_size'], bias=False)
self.w3 = nn.Linear(args['hidden_size'], args['intermediate_size'], bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.w2(F.silu(self.w1(x)) * self.w3(x))
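# Pre-norm residual block: x + Attn(RMSNorm(x)) followed by h + FFN(RMSNorm(h)).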
class TransformerBlock(nn.Module):
def __init__(self, args):
super().__init__()
self.attention = Attention(args)
self.feed_forward = FeedForward(args)
self.attention_norm = RMSNorm(args['hidden_size'], eps=1e-5)
self.ffn_norm = RMSNorm(args['hidden_size'], eps=1e-5)
def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
r = self.attention(self.attention_norm(x), freqs_cis=freqs_cis)
h = x + r
r = self.feed_forward(self.ffn_norm(h))
out = h + r
return out
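# Vision encoder: a strided Conv2d patchifies the image (patch_size x patch_size,
# no overlap), patches are flattened into a sequence, RMS-normalised, and passed
# through num_hidden_layers transformer blocks with 2D RoPE. With a 1024x1024 input
# and patch_size 16 this yields a 64x64 = 4096 patch sequence (illustrative numbers,
# assuming the published Pixtral-12B vision config).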
class VisionTransformer(nn.Module):
def __init__(self, args):
super().__init__()
self.args = args
self.patch_conv = nn.Conv2d(args['num_channels'], args['hidden_size'], kernel_size=args['patch_size'], stride=args['patch_size'], bias=False)
self.ln_pre = RMSNorm(args['hidden_size'], eps=1e-5)
self.transformer = nn.ModuleList([TransformerBlock(args) for _ in range(args['num_hidden_layers'])])
self.max_patches_per_side = args['image_size'] // args['patch_size']
self._freqs_cis = None
@property
def freqs_cis(self) -> torch.Tensor:
if self._freqs_cis is None:
self._freqs_cis = precompute_freqs_cis_2d(
dim=self.args['hidden_size'] // self.args['num_attention_heads'],
height=self.max_patches_per_side,
width=self.max_patches_per_side,
theta=self.args['rope_theta'],
)
return self._freqs_cis.to(self.patch_conv.weight.device)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.patch_conv(x)
h, w = x.shape[-2], x.shape[-1]
x = x.flatten(2).transpose(1, 2)
x = self.ln_pre(x)
# select the rotary frequencies for the actual patch grid and flatten them to
# (patches, head_dim // 2) so they broadcast against q/k in apply_rotary_emb_vit
freqs_cis = self.freqs_cis[:h, :w].reshape(-1, self.freqs_cis.shape[-1])
for layer in self.transformer:
x = layer(x, freqs_cis=freqs_cis)
return x
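# Two-layer MLP with GELU that projects vision-encoder features (hidden_size)
# into the language model's embedding space (dim).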
class VisionLanguageAdapter(nn.Module):
def __init__(self, args, dim: int):
super().__init__()
self.w_in = nn.Linear(args['hidden_size'], dim, bias=True)
self.gelu = nn.GELU()
self.w_out = nn.Linear(dim, dim, bias=True)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.w_out(self.gelu(self.w_in(x)))
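# NOTE: this wrapper pairs the real vision encoder + adapter with a lightweight
# stand-in decoder built from torch.nn.TransformerDecoder. It is NOT the released
# Mistral/Pixtral decoder: its parameter names will not match the checkpoint, so
# (see load_model below) it typically keeps its random initialisation and the
# text side of this demo is illustrative only.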
class PixtralModel(nn.Module):
def __init__(self, params):
super().__init__()
self.vision_encoder = VisionTransformer(params['vision_encoder'])
self.vision_language_adapter = VisionLanguageAdapter(params['vision_encoder'], params['dim'])
self.language_model = nn.TransformerDecoder(
nn.TransformerDecoderLayer(d_model=params['dim'], nhead=params['n_heads'], dim_feedforward=params['hidden_dim'], batch_first=True),
num_layers=params['n_layers']
)
self.lm_head = nn.Linear(params['dim'], params['vocab_size'], bias=False)
def forward(self, image, input_ids=None):
vision_output = self.vision_encoder(image)
vision_output = self.vision_language_adapter(vision_output)
if input_ids is not None:
# reuse the tied output-embedding matrix as input embeddings: (batch, seq, dim)
tgt = self.lm_head.weight[input_ids]
output = self.language_model(tgt, vision_output)
logits = self.lm_head(output)
return logits
else:
return vision_output
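# Weight loading copies only those parameters whose fully-qualified names appear
# as keys in consolidated.safetensors; anything without a matching key (e.g. the
# stand-in decoder above) is left at its random initialisation.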
def load_model(params, model_path):
model = PixtralModel(params)
with safe_open(f'{model_path}/consolidated.safetensors', framework="pt", device="cpu") as f:
for name, param in model.named_parameters():
if name in f.keys():
param.data = f.get_tensor(name)
model.eval()
return model
model = load_model(params, model_path)
tokenizer = MistralTokenizer.from_model("pixtral")
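# Preprocessing: resize to a square image_size x image_size RGB image and scale
# pixel values to [0, 1]. Note this skips any mean/std normalisation the reference
# pipeline may apply; treating plain [0, 1] scaling as sufficient is an assumption
# of this demo.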
def preprocess_image(image):
if image is None:
raise ValueError("No image provided")
image = image.convert('RGB')
image = image.resize((params['vision_encoder']['image_size'], params['vision_encoder']['image_size']))
image_tensor = torch.tensor(np.array(image)).permute(2, 0, 1).unsqueeze(0).float() / 255.0
return image_tensor
@spaces.GPU(duration=120)
def generate_text(image, prompt, max_tokens):
try:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
image_tensor = preprocess_image(image).to(device)
model.to(device)
tokenized = tokenizer.encode_chat_completion(
ChatCompletionRequest(
messages=[UserMessage(content=[TextChunk(text=prompt), ImageChunk(image=image)])],
model="pixtral",
)
)
input_ids = torch.tensor(tokenized.tokens).unsqueeze(0).to(device)
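# Greedy decoding sketch: each step re-runs the full model on the growing
# sequence (no KV cache), appends the argmax token, and stops at EOS or
# after max_tokens steps.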
for _ in range(int(max_tokens)):  # the slider value may arrive as a float
logits = model(image_tensor, input_ids)
next_token_logits = logits[0, -1, :]
next_token = torch.argmax(next_token_logits, dim=-1)
input_ids = torch.cat([input_ids, next_token.unsqueeze(0).unsqueeze(0)], dim=-1)
if next_token.item() == tokenizer.instruct_tokenizer.tokenizer.eos_id:  # EOS id lives on the underlying mistral_common tokenizer
break
generated_text = tokenizer.decode(input_ids[0].tolist())
# model.to("cpu")
return generated_text, len(input_ids[0]), 1
except Exception as e:
return f"Error: {str(e)}", 0, 0
@spaces.GPU(duration=60)
def calculate_similarity(image1, image2):
try:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tensor1 = preprocess_image(image1).to(device)
tensor2 = preprocess_image(image2).to(device)
model.to(device)
embedding1 = model(tensor1).mean(dim=1)
embedding2 = model(tensor2).mean(dim=1)
similarity = F.cosine_similarity(embedding1, embedding2).item()
# model.to("cpu")
return similarity
except Exception as e:
return f"Error: {str(e)}"
with gr.Blocks() as demo:
gr.Markdown(title)
gr.Markdown("## Model Details")
gr.Markdown(f"- Model Dimension: {params['dim']}")
gr.Markdown(f"- Number of Layers: {params['n_layers']}")
gr.Markdown(f"- Number of Attention Heads: {params['n_heads']}")
gr.Markdown(f"- Vision Encoder Hidden Size: {params['vision_encoder']['hidden_size']}")
gr.Markdown(f"- Number of Vision Encoder Layers: {params['vision_encoder']['num_hidden_layers']}")
gr.Markdown(f"- Number of Vision Encoder Attention Heads: {params['vision_encoder']['num_attention_heads']}")
gr.Markdown(f"- Image Size: {params['vision_encoder']['image_size']}x{params['vision_encoder']['image_size']}")
gr.Markdown(f"- Patch Size: {params['vision_encoder']['patch_size']}x{params['vision_encoder']['patch_size']}")
gr.Markdown("## How it works")
gr.Markdown("1. The image is processed by a Vision Encoder using 2D ROPE (Rotary Position Embedding).")
gr.Markdown("2. The encoder uses SiLU activation in its feed-forward layers.")
gr.Markdown("3. The encoded image is used for text generation or similarity comparison.")
gr.Markdown(description)
with gr.Tabs():
with gr.TabItem("Image-to-Text Generation"):
with gr.Row():
with gr.Column():
input_image = gr.Image(type="pil", label="Input Image")
input_prompt = gr.Textbox(label="Prompt")
max_tokens_slider = gr.Slider(minimum=10, maximum=500, value=100, step=10, label="Max Tokens")
submit_btn = gr.Button("Generate Text")
with gr.Column():
output_text = gr.Textbox(label="Generated Text")
token_count = gr.Number(label="Number of Tokens")
image_count = gr.Number(label="Number of Images Processed")
submit_btn.click(
fn=generate_text,
inputs=[input_image, input_prompt, max_tokens_slider],
outputs=[output_text, token_count, image_count]
)
with gr.TabItem("Image Similarity Comparison"):
with gr.Row():
image1_input = gr.Image(type="pil", label="Image 1")
image2_input = gr.Image(type="pil", label="Image 2")
similarity_btn = gr.Button("📸🌬️Calculate Similarity")
similarity_output = gr.Number(label="Similarity Score (-1.0 to 1.0)")
similarity_btn.click(
fn=calculate_similarity,
inputs=[image1_input, image2_input],
outputs=[similarity_output]
)
if __name__ == "__main__":
demo.launch()