Tonic committed on
Commit 337337b · unverified · 1 Parent(s): 2b9d6b2
Files changed (2)
  1. app.py +167 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,167 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from safetensors import safe_open
+ import json
+ import gradio as gr
+ from PIL import Image
+ import numpy as np
+ from mistral_common.protocol.instruct.messages import UserMessage, TextChunk, ImageChunk
+ from mistral_common.protocol.instruct.request import ChatCompletionRequest
+ from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
+
+ # Load model parameters and tokenizer configuration
+ with open('PARAMS.json', 'r') as f:
+     params = json.load(f)
+
+ # tokenizer_config is read for reference only; MistralTokenizer below
+ # handles tokenization itself.
+ with open('TEKKEN.json', 'r') as f:
+     tokenizer_config = json.load(f)
+
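+ # For reference, these are the keys this script reads from PARAMS.json.
+ # The values below are illustrative placeholders, not the shipped Pixtral
+ # configuration:
+ #
+ # {
+ #   "vision_encoder": {
+ #     "num_channels": 3, "hidden_size": 1024, "patch_size": 16,
+ #     "num_attention_heads": 16, "num_hidden_layers": 24,
+ #     "intermediate_size": 4096, "rope_theta": 10000.0, "image_size": 1024
+ #   }
+ # }
+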
+ class GELU(nn.Module):
+     """Linear projection followed by a GELU activation (exact or tanh-approximated)."""
+     def __init__(self, dim_in, dim_out, approximate='none', bias=True):
+         super().__init__()
+         self.linear = nn.Linear(dim_in, dim_out, bias=bias)
+         self.approximate = approximate
+
+     def forward(self, x):
+         # Project first, then activate, so both branches see the same input.
+         x = self.linear(x)
+         if self.approximate == 'tanh':
+             return 0.5 * x * (1.0 + torch.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * torch.pow(x, 3))))
+         return F.gelu(x)
+
+ class Rope2D(nn.Module):
+     """Caches cos/sin tables for rotary position embeddings (RoPE)."""
+     def __init__(self, dim, max_position_embeddings=1024, base=10000):
+         super().__init__()
+         inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
+         self.register_buffer("inv_freq", inv_freq)
+         self.max_seq_len_cached = max_position_embeddings
+         t = torch.arange(self.max_seq_len_cached, dtype=self.inv_freq.dtype)
+         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+         emb = torch.cat((freqs, freqs), dim=-1)
+         self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
+         self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
+
+     def forward(self, x, seq_len=None):
+         if seq_len is None:
+             # Guard against the None default; assume x is (batch, seq, dim).
+             seq_len = x.shape[1]
+         if seq_len > self.max_seq_len_cached:
+             # Grow the cached tables when a longer sequence comes in.
+             self.max_seq_len_cached = seq_len
+             t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
+             freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+             emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
+             self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
+             self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
+         return (
+             self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+             self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+         )
+
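+ # Rope2D only builds the cos/sin tables; nothing in this file applies them,
+ # because nn.TransformerEncoderLayer exposes no hook for rotating queries
+ # and keys. As a hedged sketch only (the helper names and half-rotation
+ # layout are assumptions, not part of this commit), a custom attention
+ # block would typically use the tables like this:
+ #
+ # def rotate_half(t):
+ #     t1, t2 = t.chunk(2, dim=-1)
+ #     return torch.cat((-t2, t1), dim=-1)
+ #
+ # def apply_rotary_pos_emb(q, k, cos, sin):
+ #     # q, k: (batch, heads, seq, head_dim); cos, sin broadcast over batch/heads
+ #     return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin
+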
+ class VisionEncoder(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         # Patchify: each patch_size x patch_size tile becomes one token.
+         self.embed = nn.Conv2d(config['num_channels'], config['hidden_size'], kernel_size=config['patch_size'], stride=config['patch_size'])
+         self.rope = Rope2D(config['hidden_size'] // config['num_attention_heads'], base=config['rope_theta'])
+         # batch_first=True so the layers accept (batch, seq, dim) directly.
+         self.layers = nn.ModuleList([
+             nn.TransformerEncoderLayer(
+                 d_model=config['hidden_size'],
+                 nhead=config['num_attention_heads'],
+                 dim_feedforward=config['intermediate_size'],
+                 batch_first=True,
+             )
+             for _ in range(config['num_hidden_layers'])
+         ])
+         self.norm = nn.LayerNorm(config['hidden_size'])
+         self.gelu = GELU(config['hidden_size'], config['hidden_size'])
+
+     def forward(self, pixel_values):
+         x = self.embed(pixel_values)        # (b, hidden, h, w)
+         b, c, h, w = x.shape
+         x = x.flatten(2).transpose(1, 2)    # (b, h*w, hidden)
+         # cos/sin are computed but not applied: the stock encoder layers
+         # offer no way to inject RoPE into their attention.
+         cos, sin = self.rope(x, seq_len=h * w)
+         for layer in self.layers:
+             x = layer(x)
+         x = self.norm(x)
+         x = self.gelu(x)
+         return x
+
+ class PixtralModel(nn.Module):
+     def __init__(self, params):
+         super().__init__()
+         self.vision_encoder = VisionEncoder(params['vision_encoder'])
+         # Add text generation components here
+
+     def forward(self, image):
+         vision_output = self.vision_encoder(image)
+         # Add text generation logic here
+         return vision_output
+
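+ # A hedged sketch of what the missing text path would typically add (every
+ # name below is hypothetical, not part of this commit): a projector from the
+ # vision hidden size into the decoder's embedding space, then an
+ # autoregressive decoder over the combined image and prompt tokens:
+ #
+ # class PixtralForCausalLM(nn.Module):
+ #     def __init__(self, params, decoder_dim):
+ #         super().__init__()
+ #         self.vision_encoder = VisionEncoder(params['vision_encoder'])
+ #         self.projector = nn.Linear(params['vision_encoder']['hidden_size'], decoder_dim)
+ #         self.decoder = ...  # autoregressive transformer with an LM head
+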
+ # Initialize the model
+ model = PixtralModel(params)
+
+ # Load the model weights; only parameters whose names match a key in the
+ # checkpoint are filled in, the rest keep their random initialization.
+ with safe_open('consolidated.safetensors', framework="pt", device="cpu") as f:
+     for name, param in model.named_parameters():
+         if name in f.keys():
+             param.data = f.get_tensor(name)
+
+ model.eval()
+
+ # Initialize the tokenizer
+ tokenizer = MistralTokenizer.from_model("pixtral")
+
+ def process_image_and_text(image, prompt):
+     # Prepare the image: RGB, square-resized, scaled to [0, 1], shape (1, 3, H, W)
+     image = image.convert('RGB')
+     image = image.resize((params['vision_encoder']['image_size'], params['vision_encoder']['image_size']))
+     image_tensor = torch.tensor(np.array(image)).permute(2, 0, 1).unsqueeze(0).float() / 255.0
+
+     # Tokenize the prompt and image together as a single chat request
+     tokenized = tokenizer.encode_chat_completion(
+         ChatCompletionRequest(
+             messages=[
+                 UserMessage(
+                     content=[
+                         TextChunk(text=prompt),
+                         ImageChunk(image=image),
+                     ]
+                 )
+             ],
+             model="pixtral",
+         )
+     )
+     tokens, text, images = tokenized.tokens, tokenized.text, tokenized.images
+
+     # Run the vision encoder; text decoding is still a stub
+     with torch.no_grad():
+         vision_output = model(image_tensor)
+         # Add text generation logic here
+         generated_text = f"Generated text based on the image and prompt: {prompt}"
+
+     return generated_text, len(tokens), len(images)
+
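+ # Example call outside Gradio (the file name is illustrative):
+ # text, n_tokens, n_images = process_image_and_text(Image.open("example.jpg"), "Describe this image")
+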
+ # Gradio interface
+ with gr.Blocks() as demo:
+     gr.Markdown("# Pixtral Image-to-Text Model Demo")
+     gr.Markdown("Upload an image and provide a prompt to generate text based on it.")
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             input_image = gr.Image(type="pil")
+             input_prompt = gr.Textbox(label="Prompt")
+             submit_btn = gr.Button("Generate Text")
+
+         with gr.Column(scale=1):
+             output_text = gr.Textbox(label="Generated Text")
+             token_count = gr.Number(label="Number of Tokens")
+             image_count = gr.Number(label="Number of Images")
+
+     submit_btn.click(
+         fn=process_image_and_text,
+         inputs=[input_image, input_prompt],
+         outputs=[output_text, token_count, image_count]
+     )
+
+     gr.Markdown("## How it works")
+     gr.Markdown("1. The image is processed by a Vision Encoder using 2D RoPE (Rotary Position Embeddings).")
+     gr.Markdown("2. The encoder uses GELU activation in its layers.")
+     gr.Markdown("3. The encoded image and the prompt are used to generate descriptive text.")
+
+     gr.Markdown("## Model Details")
+     gr.Markdown(f"- Vision Encoder Hidden Size: {params['vision_encoder']['hidden_size']}")
+     gr.Markdown(f"- Number of Vision Encoder Layers: {params['vision_encoder']['num_hidden_layers']}")
+     gr.Markdown(f"- Number of Attention Heads: {params['vision_encoder']['num_attention_heads']}")
+     gr.Markdown(f"- Image Size: {params['vision_encoder']['image_size']}x{params['vision_encoder']['image_size']}")
+     gr.Markdown(f"- Patch Size: {params['vision_encoder']['patch_size']}x{params['vision_encoder']['patch_size']}")
+
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ torch>=1.9.0
+ safetensors>=0.3.1
+ gradio>=3.32.0
+ Pillow>=9.0.0
+ numpy>=1.21.0
+ mistral_common
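
Install everything with pip install -r requirements.txt. The torch floor matters: 1.9.0 is the first release whose nn.TransformerEncoderLayer accepts batch_first=True, which app.py relies on.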