sagar007 committed
Commit e266b1f · verified · 1 Parent(s): 7b75090

Update app.py

Files changed (1)
  1. app.py +114 -289
app.py CHANGED
@@ -1,309 +1,134 @@
 
  import gradio as gr
  import torch
- from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPProcessor, CLIPModel
  from PIL import Image
- import logging
  import spaces
- import numpy
-
- # Setup logging
- logging.basicConfig(level=logging.INFO)
-
- class LLaVAPhiModel:
-     def __init__(self, model_id="sagar007/Lava_phi"):
-         self.device = "cuda"
-         self.model_id = model_id
-         logging.info("Initializing LLaVA-Phi model...")
-
-         # Initialize tokenizer
-         self.tokenizer = AutoTokenizer.from_pretrained(model_id)
-         if self.tokenizer.pad_token is None:
-             self.tokenizer.pad_token = self.tokenizer.eos_token
-
-         try:
-             # Use CLIPProcessor directly instead of AutoProcessor
-             self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
-             logging.info("Successfully loaded CLIP processor")
-         except Exception as e:
-             logging.error(f"Failed to load CLIP processor: {str(e)}")
-             self.processor = None
-
-         # Increase history length to retain more context
-         self.history = []
-         self.model = None
-         self.clip = None
-
-     @spaces.GPU
-     def ensure_models_loaded(self):
-         """Ensure models are loaded in GPU context"""
-         if self.model is None:
-             # Improved quantization config for better quality
-             from transformers import BitsAndBytesConfig
-             quantization_config = BitsAndBytesConfig(
-                 load_in_8bit=True,  # Changed from 4-bit to 8-bit for better quality
-                 bnb_8bit_compute_dtype=torch.float16,
-                 bnb_8bit_use_double_quant=False
-             )
-
-             try:
-                 self.model = AutoModelForCausalLM.from_pretrained(
-                     self.model_id,
-                     quantization_config=quantization_config,
-                     device_map="auto",
-                     torch_dtype=torch.bfloat16,
-                     trust_remote_code=True
-                 )
-                 self.model.config.pad_token_id = self.tokenizer.eos_token_id
-                 logging.info("Successfully loaded main model")
-             except Exception as e:
-                 logging.error(f"Failed to load main model: {str(e)}")
-                 raise
-
-         if self.clip is None:
-             try:
-                 # Use CLIPModel directly instead of AutoModel
-                 self.clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(self.device)
-                 logging.info("Successfully loaded CLIP model")
-             except Exception as e:
-                 logging.error(f"Failed to load CLIP model: {str(e)}")
-                 self.clip = None
-
-     @spaces.GPU
-     def process_image(self, image):
-         """Process image through CLIP if available"""
-         try:
-             self.ensure_models_loaded()
-
-             if self.clip is None or self.processor is None:
-                 logging.warning("CLIP model or processor not available")
-                 return None
-
-             # Convert image to correct format
-             if isinstance(image, str):
-                 image = Image.open(image)
-             elif isinstance(image, numpy.ndarray):
-                 image = Image.fromarray(image)
-
-             # Ensure image is in RGB mode
-             if image.mode != 'RGB':
-                 image = image.convert('RGB')
-
-             with torch.no_grad():
-                 try:
-                     # Process image with error handling
-                     image_inputs = self.processor(images=image, return_tensors="pt")
-                     image_features = self.clip.get_image_features(
-                         pixel_values=image_inputs.pixel_values.to(self.device)
-                     )
-                     logging.info("Successfully processed image through CLIP")
-                     return image_features
-                 except Exception as e:
-                     logging.error(f"Error during image processing: {str(e)}")
-                     return None
-         except Exception as e:
-             logging.error(f"Error in process_image: {str(e)}")
-             return None
-
-     @spaces.GPU(duration=120)
-     def generate_response(self, message, image=None):
-         try:
-             self.ensure_models_loaded()
-
-             if image is not None:
-                 image_features = self.process_image(image)
-                 has_image = image_features is not None
-                 if not has_image:
-                     message = "Note: Image processing is not available - continuing with text only.\n" + message
-
-                 prompt = f"human: {'<image>' if has_image else ''}\n{message}\ngpt:"
-
-                 # Include more history for better context (previous 5 turns instead of 3)
-                 context = ""
-                 for turn in self.history[-5:]:
-                     context += f"human: {turn[0]}\ngpt: {turn[1]}\n"
-
-                 full_prompt = context + prompt
-
-                 # Increased context window
-                 inputs = self.tokenizer(
-                     full_prompt,
-                     return_tensors="pt",
-                     padding=True,
-                     truncation=True,
-                     max_length=1024  # Increased from 512
-                 )
-                 inputs = {k: v.to(self.device) for k, v in inputs.items()}
-
-                 if has_image:
-                     inputs["image_features"] = image_features
-
-                 with torch.no_grad():
-                     # More conservative generation settings to reduce hallucinations
-                     outputs = self.model.generate(
-                         **inputs,
-                         max_new_tokens=256,
-                         min_length=20,
-                         temperature=0.3,  # Reduced from 0.7 for more deterministic output
-                         do_sample=True,
-                         top_p=0.92,
-                         top_k=50,
-                         repetition_penalty=1.2,  # Adjusted for more natural responses
-                         no_repeat_ngram_size=3,
-                         use_cache=True,
-                         pad_token_id=self.tokenizer.pad_token_id,
-                         eos_token_id=self.tokenizer.eos_token_id
-                     )
-             else:
-                 prompt = f"human: {message}\ngpt:"
-                 # Include more history
-                 context = ""
-                 for turn in self.history[-5:]:
-                     context += f"human: {turn[0]}\ngpt: {turn[1]}\n"
-
-                 full_prompt = context + prompt
-
-                 # Increased context window
-                 inputs = self.tokenizer(
-                     full_prompt,
-                     return_tensors="pt",
-                     padding=True,
-                     truncation=True,
-                     max_length=1024  # Increased from 512
-                 )
-                 inputs = {k: v.to(self.device) for k, v in inputs.items()}
-
-                 with torch.no_grad():
-                     # More conservative generation settings
-                     outputs = self.model.generate(
-                         **inputs,
-                         max_new_tokens=200,  # Slightly increased from 150
-                         min_length=20,
-                         temperature=0.3,  # Reduced from 0.6
-                         do_sample=True,
-                         top_p=0.92,
-                         top_k=50,
-                         repetition_penalty=1.2,
-                         no_repeat_ngram_size=4,
-                         use_cache=True,
-                         pad_token_id=self.tokenizer.pad_token_id,
-                         eos_token_id=self.tokenizer.eos_token_id
-                     )
-
-             response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-             # Clean up response
-             if "gpt:" in response:
-                 response = response.split("gpt:")[-1].strip()
-             if "human:" in response:
-                 response = response.split("human:")[0].strip()
-             if "<image>" in response:
-                 response = response.replace("<image>", "").strip()
-
-             self.history.append((message, response))
-             return response
-
-         except Exception as e:
-             logging.error(f"Error generating response: {str(e)}")
-             logging.error(f"Full traceback:", exc_info=True)
-             return f"Error: {str(e)}"
-
-     def clear_history(self):
-         self.history = []
-         return None
-
-     # Add new function to control generation parameters
-     def update_generation_params(self, temperature=0.3, top_p=0.92, top_k=50, repetition_penalty=1.2):
-         """Update generation parameters to control hallucination tendency"""
-         self.temperature = temperature
-         self.top_p = top_p
-         self.top_k = top_k
-         self.repetition_penalty = repetition_penalty
-         return f"Generation parameters updated: temp={temperature}, top_p={top_p}, top_k={top_k}, rep_penalty={repetition_penalty}"
-
- def create_demo():
-     try:
-         model = LLaVAPhiModel()
-
-         with gr.Blocks(css="footer {visibility: hidden}") as demo:
-             gr.Markdown(
-                 """
-                 # LLaVA-Phi Demo (Optimized for Accuracy)
-                 Chat with a vision-language model that can understand both text and images.
-                 """
-             )
-
-             chatbot = gr.Chatbot(height=400)
-             with gr.Row():
-                 with gr.Column(scale=0.7):
-                     msg = gr.Textbox(
-                         show_label=False,
-                         placeholder="Enter text and/or upload an image",
-                         container=False
-                     )
-                 with gr.Column(scale=0.15, min_width=0):
-                     clear = gr.Button("Clear")
-                 with gr.Column(scale=0.15, min_width=0):
-                     submit = gr.Button("Submit", variant="primary")
-
-             image = gr.Image(type="pil", label="Upload Image (Optional)")
-
-             # Add generation parameter controls
-             with gr.Accordion("Advanced Settings", open=False):
-                 gr.Markdown("Adjust these parameters to control hallucination tendency")
-                 temp_slider = gr.Slider(0.1, 1.0, value=0.3, step=0.1, label="Temperature (lower = more factual)")
-                 top_p_slider = gr.Slider(0.5, 1.0, value=0.92, step=0.01, label="Top-p (nucleus sampling)")
-                 top_k_slider = gr.Slider(10, 100, value=50, step=5, label="Top-k")
-                 rep_penalty_slider = gr.Slider(1.0, 2.0, value=1.2, step=0.1, label="Repetition Penalty")
-                 update_params = gr.Button("Update Parameters")
-
-             def respond(message, chat_history, image):
-                 if not message and image is None:
-                     return chat_history
-
-                 response = model.generate_response(message, image)
-                 chat_history.append((message, response))
-                 return "", chat_history
-
-             def clear_chat():
-                 model.clear_history()
-                 return None, None
-
-             def update_params_fn(temp, top_p, top_k, rep_penalty):
-                 return model.update_generation_params(temp, top_p, top_k, rep_penalty)
-
-             submit.click(
-                 respond,
-                 [msg, chatbot, image],
-                 [msg, chatbot],
-             )
-
-             clear.click(
-                 clear_chat,
-                 None,
-                 [chatbot, image],
-             )
-
-             msg.submit(
-                 respond,
-                 [msg, chatbot, image],
-                 [msg, chatbot],
-             )
-
-             update_params.click(
-                 update_params_fn,
-                 [temp_slider, top_p_slider, top_k_slider, rep_penalty_slider],
-                 None
-             )
-
-         return demo
      except Exception as e:
-         logging.error(f"Error creating demo: {str(e)}")
          raise

  if __name__ == "__main__":
-     demo = create_demo()
-     demo.launch(
-         server_name="0.0.0.0",
-         server_port=7860,
-         share=True
-     )
 
+ import os
  import gradio as gr
  import torch
+ from peft import LoraConfig, get_peft_model
+ import torch.nn as nn
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from peft import PeftModel, PeftConfig
+ import whisper
  from PIL import Image
+ import clip
  import spaces

+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ class MultimodalPhi(nn.Module):
+     def __init__(self, phi_model):
+         super().__init__()
+         self.phi_model = phi_model
+         self.embedding_projection = nn.Linear(512, phi_model.config.hidden_size)
+
+     def forward(self, image_embeddings, input_ids, attention_mask):
+         projected_embeddings = self.embedding_projection(image_embeddings).unsqueeze(1)
+         inputs_embeds = self.phi_model.get_input_embeddings()(input_ids)
+         combined_embeds = torch.cat([projected_embeddings, inputs_embeds], dim=1)
+         extended_attention_mask = torch.cat([torch.ones(attention_mask.shape[0], 1).to(attention_mask.device), attention_mask], dim=1)
+         outputs = self.phi_model(inputs_embeds=combined_embeds, attention_mask=extended_attention_mask)
+         return outputs.logits[:, 1:, :]  # Exclude the image token from output
+
+ def load_models():
+     try:
+         print("Loading models...")
+         peft_model_name = "sagar007/phi-1_5-finetuned"
+
+         # Manually load and create LoraConfig, ignoring unknown arguments
+         config_dict = LoraConfig.from_pretrained(peft_model_name).to_dict()
+         # Remove 'layer_replication' if present
+         config_dict.pop('layer_replication', None)
+         lora_config = LoraConfig(**config_dict)
+         print("PEFT config loaded")
+
+         base_model = AutoModelForCausalLM.from_pretrained("microsoft/phi-1_5", torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32)
+         print("Base model loaded")
+
+         phi_model = get_peft_model(base_model, lora_config)
+         phi_model.load_state_dict(torch.load(peft_model_name + '/adapter_model.bin', map_location=device), strict=False)
+         print("PEFT model loaded")
+
+         multimodal_model = MultimodalPhi(phi_model)
+         multimodal_model.load_state_dict(torch.load('multimodal_phi_small_gpu.pth', map_location=device))
+         multimodal_model.to(device)
+         multimodal_model.eval()
+         print("Multimodal model loaded")
+
+         tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1_5")
+         tokenizer.pad_token = tokenizer.eos_token
+         print("Tokenizer loaded")
+
+         audio_model = whisper.load_model("base").to(device)
+         print("Audio model loaded")
+
+         clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
+         print("CLIP model loaded")
+
+         return multimodal_model, tokenizer, audio_model, clip_model, clip_preprocess
      except Exception as e:
+         print(f"Error in load_models: {str(e)}")
          raise

+ model, tokenizer, audio_model, clip_model, clip_preprocess = load_models()
+
+ @spaces.GPU
+ def get_clip_embedding(image):
+     image = clip_preprocess(Image.open(image)).unsqueeze(0).to(device)
+     with torch.no_grad():
+         image_features = clip_model.encode_image(image)
+     return image_features.squeeze(0)
+
+ @spaces.GPU
+ def process_text(text):
+     try:
+         inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=128, padding='max_length').to(device)
+         dummy_image_embedding = torch.zeros(512).to(device)  # Dummy image embedding for text-only input
+         with torch.no_grad():
+             outputs = model(dummy_image_embedding.unsqueeze(0), inputs.input_ids, inputs.attention_mask)
+         return tokenizer.decode(outputs[0].argmax(dim=-1), skip_special_tokens=True)
+     except Exception as e:
+         return f"Error in process_text: {str(e)}"
+
+ @spaces.GPU
+ def process_image(image):
+     try:
+         clip_embedding = get_clip_embedding(image)
+         prompt = "Describe this image:"
+         inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=128, padding='max_length').to(device)
+         with torch.no_grad():
+             outputs = model(clip_embedding.unsqueeze(0), inputs.input_ids, inputs.attention_mask)
+         return tokenizer.decode(outputs[0].argmax(dim=-1), skip_special_tokens=True)
+     except Exception as e:
+         return f"Error in process_image: {str(e)}"
+
+ @spaces.GPU
+ def process_audio(audio):
+     try:
+         result = audio_model.transcribe(audio)
+         transcription = result["text"]
+         return process_text(f"Transcription: {transcription}\nPlease respond to this:")
+     except Exception as e:
+         return f"Error in process_audio: {str(e)}"
+
+ def chat(message, image, audio):
+     if audio is not None:
+         return process_audio(audio)
+     elif image is not None:
+         return process_image(image)
+     else:
+         return process_text(message)
+
+ iface = gr.Interface(
+     fn=chat,
+     inputs=[
+         gr.Textbox(placeholder="Enter text here..."),
+         gr.Image(type="pil"),
+         gr.Audio(type="filepath")
+     ],
+     outputs="text",
+     title="Multi-Modal Assistant",
+     description="Chat with an AI using text, images, or audio!"
+ )
+
  if __name__ == "__main__":
+     print("Starting Gradio interface...")
+     iface.launch(share=True)
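
Below is a small, illustrative sanity check (not part of the commit) of how the new MultimodalPhi wrapper stitches the CLIP embedding into the phi-1_5 input sequence: the 512-d image embedding is projected to the hidden size, prepended as one extra position, and then stripped from the logits so the remaining positions line up with the text tokens. It assumes the MultimodalPhi class from the new app.py above is in scope; StubPhi and all sizes here are made-up stand-ins, not the real model.

import torch
import torch.nn as nn

class StubPhi(nn.Module):
    # Hypothetical stand-in exposing only what MultimodalPhi touches:
    # config.hidden_size, get_input_embeddings(), and a forward returning an object with .logits.
    def __init__(self, hidden_size=64, vocab_size=100):
        super().__init__()
        self.config = type("Cfg", (), {"hidden_size": hidden_size})()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size)

    def get_input_embeddings(self):
        return self.embed

    def forward(self, inputs_embeds=None, attention_mask=None):
        return type("Out", (), {"logits": self.lm_head(inputs_embeds)})()

model_sketch = MultimodalPhi(StubPhi())
image_embedding = torch.zeros(1, 512)              # one CLIP ViT-B/32 feature vector
input_ids = torch.zeros(1, 128, dtype=torch.long)  # 128 padded text tokens, as in app.py
attention_mask = torch.ones(1, 128)
logits = model_sketch(image_embedding, input_ids, attention_mask)
print(logits.shape)  # torch.Size([1, 128, 100]): image slot removed, text positions preserved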