import torch
import torch.nn as nn
import torch.nn.functional as F
from safetensors import safe_open
import json
import gradio as gr
from PIL import Image
import numpy as np
from huggingface_hub import snapshot_download
from mistral_common.protocol.instruct.messages import UserMessage, TextChunk, ImageChunk
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
import spaces
import math
from typing import Tuple

title = "# **WIP / DEMO** 🙋🏻♂️Welcome to Tonic's Pixtral Model Demo" | |
description = """ | |
This demo showcases two capabilities of the Pixtral model: | |
1. Image-to-Text Generation | |
2. Image Similarity Comparison | |
### Join us : | |
🌟TeamTonic🌟 is always making cool demos! Join our active builder's 🛠️community 👻 [![Join us on Discord](https://img.shields.io/discord/1109943800132010065?label=Discord&logo=discord&style=flat-square)](https://discord.gg/qdfnvSPcqP) On 🤗Huggingface:[MultiTransformer](https://huggingface.co/MultiTransformer) On 🌐Github: [Tonic-AI](https://github.com/tonic-ai) & contribute to🌟 [Build Tonic](https://git.tonic-ai.com/contribute)🤗Big thanks to Yuvi Sharma and all the folks at huggingface for the community grant 🤗 | |
""" | |
model_path = snapshot_download(repo_id="mistralai/Pixtral-12B-2409")

with open(f'{model_path}/params.json', 'r') as f:
    params = json.load(f)

with open(f'{model_path}/tekken.json', 'r') as f:
    tokenizer_config = json.load(f)
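# params.json holds the architecture hyperparameters used throughout this file (dim,
# n_layers, n_heads, vocab_size, and the vision_encoder settings). tekken.json is the
# tokenizer config; it is only kept here for reference, since the tokenizer itself is
# built below with MistralTokenizer.from_model("pixtral").
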
class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight

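# 2D rotary position embeddings for the vision encoder: even-indexed frequency
# channels encode the patch row and odd-indexed channels the patch column, so every
# (row, col) patch position gets its own rotation. The result is a complex tensor of
# shape (height, width, dim // 2) with unit magnitude.
# Rough shape sketch (hypothetical numbers, not read from params.json): with
# image_size 1024 and patch_size 16, height = width = 64, and for head_dim = 64
# freqs_cis has shape (64, 64, 32).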
def precompute_freqs_cis_2d(dim: int, height: int, width: int, theta: float) -> torch.Tensor:
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    h = torch.arange(height)
    w = torch.arange(width)
    freqs_h = torch.outer(h, freqs[::2]).float()
    freqs_w = torch.outer(w, freqs[1::2]).float()
    freqs_2d = torch.cat([freqs_h[:, None, :].repeat(1, width, 1), freqs_w[None, :, :].repeat(height, 1, 1)], dim=-1)
    return torch.polar(torch.ones_like(freqs_2d), freqs_2d)

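# Applying RoPE: consecutive pairs of query/key channels are viewed as complex numbers
# and multiplied by the unit-magnitude freqs_cis, which rotates each pair by an angle
# that depends on the patch position, i.e. (x1, x2) -> (x1*cos θ - x2*sin θ,
# x1*sin θ + x2*cos θ).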
def apply_rotary_emb_vit(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    # freqs_cis comes in as (height, width, head_dim // 2); flatten it to one entry per
    # patch so it broadcasts against (batch, patches, heads, head_dim // 2).
    freqs_cis = freqs_cis.reshape(-1, freqs_cis.shape[-1])[None, :, None, :]
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)

class Attention(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.n_heads = args['num_attention_heads']
        self.head_dim = args['hidden_size'] // args['num_attention_heads']
        self.wq = nn.Linear(args['hidden_size'], args['hidden_size'], bias=False)
        self.wk = nn.Linear(args['hidden_size'], args['hidden_size'], bias=False)
        self.wv = nn.Linear(args['hidden_size'], args['hidden_size'], bias=False)
        self.wo = nn.Linear(args['hidden_size'], args['hidden_size'], bias=False)

    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
        batch, patches, _ = x.shape
        q, k, v = self.wq(x), self.wk(x), self.wv(x)
        q = q.reshape(batch, patches, self.n_heads, self.head_dim)
        k = k.reshape(batch, patches, self.n_heads, self.head_dim)
        v = v.reshape(batch, patches, self.n_heads, self.head_dim)
        q, k = apply_rotary_emb_vit(q, k, freqs_cis=freqs_cis)
        # Move heads in front of the patch dimension so attention runs over patches:
        # (batch, patches, heads, head_dim) -> (batch, heads, patches, head_dim).
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.head_dim)
        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, v)
        out = out.transpose(1, 2).reshape(batch, patches, self.n_heads * self.head_dim)
        return self.wo(out)

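# Note: the vision attention above is full (bidirectional) self-attention over all
# image patches; no causal or padding mask is applied in this demo.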
class FeedForward(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.w1 = nn.Linear(args['hidden_size'], args['intermediate_size'], bias=False)
        self.w2 = nn.Linear(args['intermediate_size'], args['hidden_size'], bias=False)
        self.w3 = nn.Linear(args['hidden_size'], args['intermediate_size'], bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

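# This is the SiLU-gated MLP (SwiGLU-style) mentioned in the UI notes below:
# w1 produces the gate, w3 the value, and w2 projects the gated product back down.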
class TransformerBlock(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.attention = Attention(args)
        self.feed_forward = FeedForward(args)
        self.attention_norm = RMSNorm(args['hidden_size'], eps=1e-5)
        self.ffn_norm = RMSNorm(args['hidden_size'], eps=1e-5)

    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
        r = self.attention(self.attention_norm(x), freqs_cis=freqs_cis)
        h = x + r
        r = self.feed_forward(self.ffn_norm(h))
        out = h + r
        return out

class VisionTransformer(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.args = args
        self.patch_conv = nn.Conv2d(args['num_channels'], args['hidden_size'], kernel_size=args['patch_size'], stride=args['patch_size'], bias=False)
        self.ln_pre = RMSNorm(args['hidden_size'], eps=1e-5)
        self.transformer = nn.ModuleList([TransformerBlock(args) for _ in range(args['num_hidden_layers'])])
        self.max_patches_per_side = args['image_size'] // args['patch_size']
        self._freqs_cis = None

    @property
    def freqs_cis(self) -> torch.Tensor:
        # Computed lazily and cached; forward() reads this as an attribute, so it must
        # be a property.
        if self._freqs_cis is None:
            self._freqs_cis = precompute_freqs_cis_2d(
                dim=self.args['hidden_size'] // self.args['num_attention_heads'],
                height=self.max_patches_per_side,
                width=self.max_patches_per_side,
                theta=self.args['rope_theta'],
            )
        return self._freqs_cis.to(self.patch_conv.weight.device)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.patch_conv(x)
        x = x.flatten(2).transpose(1, 2)
        x = self.ln_pre(x)
        freqs_cis = self.freqs_cis
        for layer in self.transformer:
            x = layer(x, freqs_cis=freqs_cis)
        return x

class VisionLanguageAdapter(nn.Module):
    def __init__(self, args, dim: int):
        super().__init__()
        self.w_in = nn.Linear(args['hidden_size'], dim, bias=True)
        self.gelu = nn.GELU()
        self.w_out = nn.Linear(dim, dim, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w_out(self.gelu(self.w_in(x)))

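# PixtralModel below wires the pieces together: vision encoder -> vision-language
# adapter (projects the vision hidden_size up to the decoder dim) -> text decoder ->
# LM head. The decoder here is a generic nn.TransformerDecoder used as a lightweight
# stand-in for this demo rather than the full Mistral-style decoder the released
# weights were trained with.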
class PixtralModel(nn.Module):
    def __init__(self, params):
        super().__init__()
        self.vision_encoder = VisionTransformer(params['vision_encoder'])
        self.vision_language_adapter = VisionLanguageAdapter(params['vision_encoder'], params['dim'])
        self.language_model = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=params['dim'], nhead=params['n_heads'], dim_feedforward=params['hidden_dim'], batch_first=True),
            num_layers=params['n_layers']
        )
        self.lm_head = nn.Linear(params['dim'], params['vocab_size'], bias=False)
    def forward(self, image, input_ids=None):
        vision_output = self.vision_encoder(image)
        vision_output = self.vision_language_adapter(vision_output)
        if input_ids is not None:
            # Reuse the LM head weights as tied token embeddings: (batch, seq, dim).
            tgt = self.lm_head.weight[input_ids]
            output = self.language_model(tgt, vision_output)
            logits = self.lm_head(output)
            return logits
        else:
            return vision_output

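# load_model copies a tensor from consolidated.safetensors only when a parameter with
# the exact same name exists in the checkpoint; anything that does not match (likely
# the generic decoder above, whose parameter names follow nn.TransformerDecoder) keeps
# its random initialization.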
def load_model(params, model_path):
    model = PixtralModel(params)
    with safe_open(f'{model_path}/consolidated.safetensors', framework="pt", device="cpu") as f:
        for name, param in model.named_parameters():
            if name in f.keys():
                param.data = f.get_tensor(name)
    model.eval()
    return model

model = load_model(params, model_path)
tokenizer = MistralTokenizer.from_model("pixtral")

def preprocess_image(image):
    if image is None:
        raise ValueError("No image provided")
    image = image.convert('RGB')
    image = image.resize((params['vision_encoder']['image_size'], params['vision_encoder']['image_size']))
    image_tensor = torch.tensor(np.array(image)).permute(2, 0, 1).unsqueeze(0).float() / 255.0
    return image_tensor

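# preprocess_image simply resizes to a square image_size x image_size and scales pixel
# values to [0, 1], producing a (1, 3, H, W) tensor. It applies no channel mean/std
# normalization and does not preserve aspect ratio, so it is likely a simplification of
# the model's original preprocessing.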
@spaces.GPU  # request a GPU for this call when running on ZeroGPU hardware
def generate_text(image, prompt, max_tokens):
    try:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        image_tensor = preprocess_image(image).to(device)
        model.to(device)
        tokenized = tokenizer.encode_chat_completion(
            ChatCompletionRequest(
                messages=[UserMessage(content=[TextChunk(text=prompt), ImageChunk(image=image)])],
                model="pixtral",
            )
        )
        input_ids = torch.tensor(tokenized.tokens).unsqueeze(0).to(device)
        eos_id = tokenizer.instruct_tokenizer.tokenizer.eos_id
        with torch.no_grad():
            for _ in range(max_tokens):
                logits = model(image_tensor, input_ids)
                next_token_logits = logits[0, -1, :]
                next_token = torch.argmax(next_token_logits, dim=-1)
                input_ids = torch.cat([input_ids, next_token.unsqueeze(0).unsqueeze(0)], dim=-1)
                if next_token.item() == eos_id:
                    break
        generated_text = tokenizer.decode(input_ids[0].tolist())
        # model.to("cpu")
        return generated_text, len(input_ids[0]), 1
    except Exception as e:
        return f"Error: {str(e)}", 0, 0

@spaces.GPU  # request a GPU for this call when running on ZeroGPU hardware
def calculate_similarity(image1, image2):
    try:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        tensor1 = preprocess_image(image1).to(device)
        tensor2 = preprocess_image(image2).to(device)
        model.to(device)
        with torch.no_grad():
            # Mean-pool the patch embeddings into one vector per image, then compare
            # them with cosine similarity (range -1.0 to 1.0).
            embedding1 = model(tensor1).mean(dim=1)
            embedding2 = model(tensor2).mean(dim=1)
            similarity = F.cosine_similarity(embedding1, embedding2).item()
        # model.to("cpu")
        return similarity
    except Exception as e:
        return f"Error: {str(e)}"

with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown("## Model Details")
    gr.Markdown(f"- Model Dimension: {params['dim']}")
    gr.Markdown(f"- Number of Layers: {params['n_layers']}")
    gr.Markdown(f"- Number of Attention Heads: {params['n_heads']}")
    gr.Markdown(f"- Vision Encoder Hidden Size: {params['vision_encoder']['hidden_size']}")
    gr.Markdown(f"- Number of Vision Encoder Layers: {params['vision_encoder']['num_hidden_layers']}")
    gr.Markdown(f"- Number of Vision Encoder Attention Heads: {params['vision_encoder']['num_attention_heads']}")
    gr.Markdown(f"- Image Size: {params['vision_encoder']['image_size']}x{params['vision_encoder']['image_size']}")
    gr.Markdown(f"- Patch Size: {params['vision_encoder']['patch_size']}x{params['vision_encoder']['patch_size']}")
    gr.Markdown("## How it works")
    gr.Markdown("1. The image is processed by a Vision Encoder that uses 2D RoPE (Rotary Position Embeddings).")
    gr.Markdown("2. The encoder uses SiLU activation in its feed-forward layers.")
    gr.Markdown("3. The encoded image is used for text generation or similarity comparison.")
    gr.Markdown(description)
    with gr.Tabs():
        with gr.TabItem("Image-to-Text Generation"):
            with gr.Row():
                with gr.Column():
                    input_image = gr.Image(type="pil", label="Input Image")
                    input_prompt = gr.Textbox(label="Prompt")
                    max_tokens_slider = gr.Slider(minimum=10, maximum=500, value=100, step=10, label="Max Tokens")
                    submit_btn = gr.Button("Generate Text")
                with gr.Column():
                    output_text = gr.Textbox(label="Generated Text")
                    token_count = gr.Number(label="Number of Tokens")
                    image_count = gr.Number(label="Number of Images Processed")
            submit_btn.click(
                fn=generate_text,
                inputs=[input_image, input_prompt, max_tokens_slider],
                outputs=[output_text, token_count, image_count]
            )
        with gr.TabItem("Image Similarity Comparison"):
            with gr.Row():
                image1_input = gr.Image(type="pil", label="Image 1")
                image2_input = gr.Image(type="pil", label="Image 2")
            similarity_btn = gr.Button("📸🌬️ Calculate Similarity")
            similarity_output = gr.Number(label="Similarity Score (-1.0 to 1.0)")
            similarity_btn.click(
                fn=calculate_similarity,
                inputs=[image1_input, image2_input],
                outputs=[similarity_output]
            )

if __name__ == "__main__":
    demo.launch()