import torch
import torch.nn as nn
import torch.nn.functional as F
from safetensors import safe_open
import json
import gradio as gr
from PIL import Image
import numpy as np
from huggingface_hub import snapshot_download
from mistral_common.protocol.instruct.messages import UserMessage, TextChunk, ImageChunk
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
import spaces
import math
from typing import List, Optional, Tuple

title = "# **WIP / DEMO** 🙋🏻‍♂️Welcome to Tonic's Pixtral Model Demo"

description = """
This demo showcases two capabilities of the Pixtral model:

1. Image-to-Text Generation
2. Image Similarity Comparison

### Join us:

🌟TeamTonic🌟 is always making cool demos! Join our active builder's 🛠️community 👻
[![Join us on Discord](https://img.shields.io/discord/1109943800132010065?label=Discord&logo=discord&style=flat-square)](https://discord.gg/qdfnvSPcqP)
On 🤗Huggingface: [MultiTransformer](https://huggingface.co/MultiTransformer)
On 🌐Github: [Tonic-AI](https://github.com/tonic-ai) & contribute to 🌟[Build Tonic](https://git.tonic-ai.com/contribute)
🤗 Big thanks to Yuvi Sharma and all the folks at Hugging Face for the community grant 🤗
"""

# Download the checkpoint and read the model / tokenizer configuration files.
model_path = snapshot_download(repo_id="mistralai/Pixtral-12B-2409")

with open(f'{model_path}/params.json', 'r') as f:
    params = json.load(f)

with open(f'{model_path}/tekken.json', 'r') as f:
    tokenizer_config = json.load(f)


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight


def precompute_freqs_cis_2d(dim: int, height: int, width: int, theta: float) -> torch.Tensor:
    """Precompute complex rotary embeddings for a 2D grid of image patches.

    Even frequency indices encode the row (height) position and odd indices the
    column (width) position, giving a (height, width, dim // 2) complex tensor.
    """
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    h = torch.arange(height)
    w = torch.arange(width)
    freqs_h = torch.outer(h, freqs[::2]).float()
    freqs_w = torch.outer(w, freqs[1::2]).float()
    freqs_2d = torch.cat(
        [
            freqs_h[:, None, :].repeat(1, width, 1),
            freqs_w[None, :, :].repeat(height, 1, 1),
        ],
        dim=-1,
    )
    return torch.polar(torch.ones_like(freqs_2d), freqs_2d)


def apply_rotary_emb_vit(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate query/key vectors by the precomputed complex frequencies."""
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    # freqs_cis arrives as (1, num_patches, head_dim // 2); add a head axis so it
    # broadcasts against (batch, num_patches, n_heads, head_dim // 2).
    freqs_cis = freqs_cis.view(*freqs_cis.shape[:2], 1, freqs_cis.shape[-1])
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)


class Attention(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.n_heads = args['num_attention_heads']
        self.head_dim = args['hidden_size'] // args['num_attention_heads']
        self.wq = nn.Linear(args['hidden_size'], args['hidden_size'], bias=False)
        self.wk = nn.Linear(args['hidden_size'], args['hidden_size'], bias=False)
        self.wv = nn.Linear(args['hidden_size'], args['hidden_size'], bias=False)
        self.wo = nn.Linear(args['hidden_size'], args['hidden_size'], bias=False)

    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
        batch, patches, _ = x.shape
        q, k, v = self.wq(x), self.wk(x), self.wv(x)
        q = q.reshape(batch, patches, self.n_heads, self.head_dim)
        k = k.reshape(batch, patches, self.n_heads, self.head_dim)
        v = v.reshape(batch, patches, self.n_heads, self.head_dim)
        q, k = apply_rotary_emb_vit(q, k, freqs_cis=freqs_cis)
        # Move the head axis in front of the patch axis so attention is computed
        # across patches rather than across heads.
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.head_dim)
        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, v)
        out = out.transpose(1, 2).reshape(batch, patches, self.n_heads * self.head_dim)
        return self.wo(out)
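# Shape walk-through (illustrative only; the concrete numbers assume the published
# Pixtral-12B-2409 vision config of image_size=1024, patch_size=16, hidden_size=1024
# and 16 heads — the actual values are read from params.json at runtime and may differ):
#   patches per side   = 1024 // 16 = 64
#   patches per image  = 64 * 64 = 4096
#   freqs_cis grid     = (64, 64, head_dim // 2) -> flattened to (1, 4096, 32)
#   q / k / v          = (batch, 4096, 16, 64) before the head transpose above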
class FeedForward(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.w1 = nn.Linear(args['hidden_size'], args['intermediate_size'], bias=False)
        self.w2 = nn.Linear(args['intermediate_size'], args['hidden_size'], bias=False)
        self.w3 = nn.Linear(args['hidden_size'], args['intermediate_size'], bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU-style gating: SiLU(w1(x)) * w3(x), projected back down by w2.
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class TransformerBlock(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.attention = Attention(args)
        self.feed_forward = FeedForward(args)
        self.attention_norm = RMSNorm(args['hidden_size'], eps=1e-5)
        self.ffn_norm = RMSNorm(args['hidden_size'], eps=1e-5)

    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
        r = self.attention(self.attention_norm(x), freqs_cis=freqs_cis)
        h = x + r
        r = self.feed_forward(self.ffn_norm(h))
        out = h + r
        return out


class VisionTransformer(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.args = args
        self.patch_conv = nn.Conv2d(
            args['num_channels'],
            args['hidden_size'],
            kernel_size=args['patch_size'],
            stride=args['patch_size'],
            bias=False,
        )
        self.ln_pre = RMSNorm(args['hidden_size'], eps=1e-5)
        self.transformer = nn.ModuleList([TransformerBlock(args) for _ in range(args['num_hidden_layers'])])
        self.max_patches_per_side = args['image_size'] // args['patch_size']
        self._freqs_cis = None

    @property
    def freqs_cis(self) -> torch.Tensor:
        if self._freqs_cis is None:
            self._freqs_cis = precompute_freqs_cis_2d(
                dim=self.args['hidden_size'] // self.args['num_attention_heads'],
                height=self.max_patches_per_side,
                width=self.max_patches_per_side,
                theta=self.args['rope_theta'],
            )
        return self._freqs_cis.to(self.patch_conv.weight.device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.patch_conv(x)
        x = x.flatten(2).transpose(1, 2)
        x = self.ln_pre(x)
        # Inputs are resized to the full image_size, so the patch sequence always
        # covers the whole grid; flatten it to (1, num_patches, head_dim // 2).
        freqs_cis = self.freqs_cis
        freqs_cis = freqs_cis.reshape(1, -1, freqs_cis.shape[-1])
        for layer in self.transformer:
            x = layer(x, freqs_cis=freqs_cis)
        return x


class VisionLanguageAdapter(nn.Module):
    def __init__(self, args, dim: int):
        super().__init__()
        self.w_in = nn.Linear(args['hidden_size'], dim, bias=True)
        self.gelu = nn.GELU()
        self.w_out = nn.Linear(dim, dim, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w_out(self.gelu(self.w_in(x)))


class PixtralModel(nn.Module):
    def __init__(self, params):
        super().__init__()
        self.vision_encoder = VisionTransformer(params['vision_encoder'])
        self.vision_language_adapter = VisionLanguageAdapter(params['vision_encoder'], params['dim'])
        # Simplified stand-in for the 12B decoder: a vanilla nn.TransformerDecoder.
        # Its parameter names do not line up with the Mistral consolidated checkpoint,
        # so it effectively keeps its random initialization — text generation is a
        # placeholder in this WIP demo.
        self.language_model = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(
                d_model=params['dim'],
                nhead=params['n_heads'],
                dim_feedforward=params['hidden_dim'],
            ),
            num_layers=params['n_layers'],
        )
        self.lm_head = nn.Linear(params['dim'], params['vocab_size'], bias=False)

    def forward(self, image, input_ids=None):
        vision_output = self.vision_encoder(image)
        vision_output = self.vision_language_adapter(vision_output)
        if input_ids is not None:
            # Reuse the lm_head weights as (tied) token embeddings, then feed the
            # decoder with shapes (seq, batch, dim) / (patches, batch, dim), the
            # default layout expected by nn.TransformerDecoder.
            tgt = self.lm_head.weight[input_ids].transpose(0, 1)
            output = self.language_model(tgt, vision_output.transpose(0, 1))
            logits = self.lm_head(output.transpose(0, 1))  # back to (batch, seq, vocab)
            return logits
        else:
            return vision_output


def load_model(params, model_path):
    model = PixtralModel(params)
    with safe_open(f'{model_path}/consolidated.safetensors', framework="pt", device="cpu") as f:
        for name, param in model.named_parameters():
            if name in f.keys():
                # Checkpoint tensors may be stored in bfloat16; cast to the module's
                # dtype so float32 image tensors can flow through unchanged.
                param.data = f.get_tensor(name).to(param.dtype)
    model.eval()
    return model
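# Optional sanity check (commented out so nothing extra runs at import time): safetensors'
# safe_open()/keys() can list the checkpoint tensor names, which helps when debugging
# which of this module's named_parameters() actually receive pretrained weights.
#
#     with safe_open(f'{model_path}/consolidated.safetensors', framework="pt", device="cpu") as f:
#         print(sorted(f.keys())[:20])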
model = load_model(params, model_path)
tokenizer = MistralTokenizer.from_model("pixtral")


def preprocess_image(image):
    if image is None:
        raise ValueError("No image provided")
    image = image.convert('RGB')
    image = image.resize((params['vision_encoder']['image_size'], params['vision_encoder']['image_size']))
    # (H, W, C) uint8 -> (1, C, H, W) float in [0, 1]
    image_tensor = torch.tensor(np.array(image)).permute(2, 0, 1).unsqueeze(0).float() / 255.0
    return image_tensor


@spaces.GPU(duration=120)
def generate_text(image, prompt, max_tokens):
    try:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        image_tensor = preprocess_image(image).to(device)
        model.to(device)

        tokenized = tokenizer.encode_chat_completion(
            ChatCompletionRequest(
                messages=[UserMessage(content=[TextChunk(text=prompt), ImageChunk(image=image)])],
                model="pixtral",
            )
        )
        input_ids = torch.tensor(tokenized.tokens).unsqueeze(0).to(device)

        # MistralTokenizer does not expose eos_token_id directly; the underlying
        # Tekken tokenizer carries the EOS id.
        eos_id = tokenizer.instruct_tokenizer.tokenizer.eos_id

        # Greedy decoding loop.
        with torch.no_grad():
            for _ in range(max_tokens):
                logits = model(image_tensor, input_ids)
                next_token_logits = logits[0, -1, :]
                next_token = torch.argmax(next_token_logits, dim=-1)
                input_ids = torch.cat([input_ids, next_token.unsqueeze(0).unsqueeze(0)], dim=-1)
                if next_token.item() == eos_id:
                    break

        generated_text = tokenizer.decode(input_ids[0].tolist())
        # model.to("cpu")
        return generated_text, len(input_ids[0]), 1
    except Exception as e:
        return f"Error: {str(e)}", 0, 0


@spaces.GPU(duration=60)
def calculate_similarity(image1, image2):
    try:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        tensor1 = preprocess_image(image1).to(device)
        tensor2 = preprocess_image(image2).to(device)
        model.to(device)

        # Mean-pool the patch embeddings into one vector per image, then compare.
        with torch.no_grad():
            embedding1 = model(tensor1).mean(dim=1)
            embedding2 = model(tensor2).mean(dim=1)
        similarity = F.cosine_similarity(embedding1, embedding2).item()
        # model.to("cpu")
        return similarity
    except Exception as e:
        return f"Error: {str(e)}"
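# Quick local smoke test (hedged sketch; kept commented out because it needs GPU time
# and the image paths below are placeholders, not part of this repo):
#
#     img_a = Image.open("cat.jpg")
#     img_b = Image.open("dog.jpg")
#     text, n_tokens, n_images = generate_text(img_a, "Describe this image.", 50)
#     score = calculate_similarity(img_a, img_b)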
with gr.Blocks() as demo:
    gr.Markdown(title)

    gr.Markdown("## Model Details")
    gr.Markdown(f"- Model Dimension: {params['dim']}")
    gr.Markdown(f"- Number of Layers: {params['n_layers']}")
    gr.Markdown(f"- Number of Attention Heads: {params['n_heads']}")
    gr.Markdown(f"- Vision Encoder Hidden Size: {params['vision_encoder']['hidden_size']}")
    gr.Markdown(f"- Number of Vision Encoder Layers: {params['vision_encoder']['num_hidden_layers']}")
    gr.Markdown(f"- Number of Vision Encoder Attention Heads: {params['vision_encoder']['num_attention_heads']}")
    gr.Markdown(f"- Image Size: {params['vision_encoder']['image_size']}x{params['vision_encoder']['image_size']}")
    gr.Markdown(f"- Patch Size: {params['vision_encoder']['patch_size']}x{params['vision_encoder']['patch_size']}")

    gr.Markdown("## How it works")
    gr.Markdown("1. The image is processed by a Vision Encoder using 2D RoPE (Rotary Position Embedding).")
    gr.Markdown("2. The encoder uses SiLU activation in its feed-forward layers.")
    gr.Markdown("3. The encoded image is used for text generation or similarity comparison.")
    gr.Markdown(description)

    with gr.Tabs():
        with gr.TabItem("Image-to-Text Generation"):
            with gr.Row():
                with gr.Column():
                    input_image = gr.Image(type="pil", label="Input Image")
                    input_prompt = gr.Textbox(label="Prompt")
                    max_tokens_slider = gr.Slider(minimum=10, maximum=500, value=100, step=10, label="Max Tokens")
                    submit_btn = gr.Button("Generate Text")
                with gr.Column():
                    output_text = gr.Textbox(label="Generated Text")
                    token_count = gr.Number(label="Number of Tokens")
                    image_count = gr.Number(label="Number of Images Processed")

            submit_btn.click(
                fn=generate_text,
                inputs=[input_image, input_prompt, max_tokens_slider],
                outputs=[output_text, token_count, image_count]
            )

        with gr.TabItem("Image Similarity Comparison"):
            with gr.Row():
                image1_input = gr.Image(type="pil", label="Image 1")
                image2_input = gr.Image(type="pil", label="Image 2")
            similarity_btn = gr.Button("📸🌬️Calculate Similarity")
            similarity_output = gr.Number(label="Similarity Score (cosine, -1.0 to 1.0)")

            similarity_btn.click(
                fn=calculate_similarity,
                inputs=[image1_input, image2_input],
                outputs=[similarity_output]
            )

if __name__ == "__main__":
    demo.launch()
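# Running outside Hugging Face Spaces? A public shareable link can be requested with
# Gradio's standard flag (assumes default launch settings are otherwise fine):
#
#     demo.launch(share=True)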