sagar007 committed
Commit 7b75090 · verified · 1 Parent(s): 63d3bc6

Update app.py

Files changed (1)
  1. app.py +55 -19
app.py CHANGED

@@ -28,6 +28,7 @@ class LLaVAPhiModel:
            logging.error(f"Failed to load CLIP processor: {str(e)}")
            self.processor = None

+        # Increase history length to retain more context
        self.history = []
        self.model = None
        self.clip = None
@@ -36,13 +37,12 @@ class LLaVAPhiModel:
    def ensure_models_loaded(self):
        """Ensure models are loaded in GPU context"""
        if self.model is None:
-            # Load main model with updated quantization config
+            # Improved quantization config for better quality
            from transformers import BitsAndBytesConfig
            quantization_config = BitsAndBytesConfig(
-                load_in_4bit=True,
-                bnb_4bit_compute_dtype=torch.float16,
-                bnb_4bit_use_double_quant=True,
-                bnb_4bit_quant_type="nf4"
+                load_in_8bit=True,  # Changed from 4-bit to 8-bit for better quality
+                bnb_8bit_compute_dtype=torch.float16,
+                bnb_8bit_use_double_quant=False
            )

            try:
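A note on the new quantization block: BitsAndBytesConfig only exposes compute-dtype and double-quant options for the 4-bit path (the bnb_4bit_* arguments); there are no bnb_8bit_compute_dtype or bnb_8bit_use_double_quant parameters, so transformers will most likely ignore those kwargs. A minimal sketch of an 8-bit config that sticks to documented arguments (the llm_int8_threshold value is simply the library default, shown for illustration):

from transformers import BitsAndBytesConfig

# Sketch only, not part of this commit: 8-bit loading with documented arguments.
# The compute-dtype / double-quant knobs exist only for the 4-bit path (bnb_4bit_*).
quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,           # LLM.int8() weight quantization via bitsandbytes
    llm_int8_threshold=6.0,      # outlier threshold; 6.0 is the library default
)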
@@ -116,17 +116,21 @@ class LLaVAPhiModel:
                message = "Note: Image processing is not available - continuing with text only.\n" + message

            prompt = f"human: {'<image>' if has_image else ''}\n{message}\ngpt:"
+
+            # Include more history for better context (previous 5 turns instead of 3)
            context = ""
-            for turn in self.history[-3:]:
+            for turn in self.history[-5:]:
                context += f"human: {turn[0]}\ngpt: {turn[1]}\n"

            full_prompt = context + prompt
+
+            # Increased context window
            inputs = self.tokenizer(
                full_prompt,
                return_tensors="pt",
                padding=True,
                truncation=True,
-                max_length=512
+                max_length=1024  # Increased from 512
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

@@ -134,15 +138,16 @@ class LLaVAPhiModel:
                inputs["image_features"] = image_features

            with torch.no_grad():
+                # More conservative generation settings to reduce hallucinations
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=256,
                    min_length=20,
-                    temperature=0.7,
+                    temperature=0.3,  # Reduced from 0.7 for more deterministic output
                    do_sample=True,
-                    top_p=0.9,
-                    top_k=40,
-                    repetition_penalty=1.5,
+                    top_p=0.92,
+                    top_k=50,
+                    repetition_penalty=1.2,  # Adjusted for more natural responses
                    no_repeat_ngram_size=3,
                    use_cache=True,
                    pad_token_id=self.tokenizer.pad_token_id,
@@ -150,30 +155,34 @@ class LLaVAPhiModel:
                )
        else:
            prompt = f"human: {message}\ngpt:"
+            # Include more history
            context = ""
-            for turn in self.history[-3:]:
+            for turn in self.history[-5:]:
                context += f"human: {turn[0]}\ngpt: {turn[1]}\n"

            full_prompt = context + prompt
+
+            # Increased context window
            inputs = self.tokenizer(
                full_prompt,
                return_tensors="pt",
                padding=True,
                truncation=True,
-                max_length=512
+                max_length=1024  # Increased from 512
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            with torch.no_grad():
+                # More conservative generation settings
                outputs = self.model.generate(
                    **inputs,
-                    max_new_tokens=150,
+                    max_new_tokens=200,  # Slightly increased from 150
                    min_length=20,
-                    temperature=0.6,
+                    temperature=0.3,  # Reduced from 0.6
                    do_sample=True,
-                    top_p=0.85,
-                    top_k=30,
-                    repetition_penalty=1.8,
+                    top_p=0.92,
+                    top_k=50,
+                    repetition_penalty=1.2,
                    no_repeat_ngram_size=4,
                    use_cache=True,
                    pad_token_id=self.tokenizer.pad_token_id,
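One caveat with the longer history and max_length=1024: Hugging Face tokenizers truncate from the right by default, so if the assembled prompt still exceeds the limit, it is the newest turns and the current question that get cut first. A small, hypothetical adjustment (assuming self.tokenizer is a standard PreTrainedTokenizer) keeps the most recent context instead:

# Sketch only, not part of this commit: drop the oldest context on overflow
# instead of the newest, so the current question always survives truncation.
self.tokenizer.truncation_side = "left"
inputs = self.tokenizer(
    full_prompt,
    return_tensors="pt",
    padding=True,
    truncation=True,
    max_length=1024,
)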
@@ -202,6 +211,15 @@ class LLaVAPhiModel:
        self.history = []
        return None

+    # Add new function to control generation parameters
+    def update_generation_params(self, temperature=0.3, top_p=0.92, top_k=50, repetition_penalty=1.2):
+        """Update generation parameters to control hallucination tendency"""
+        self.temperature = temperature
+        self.top_p = top_p
+        self.top_k = top_k
+        self.repetition_penalty = repetition_penalty
+        return f"Generation parameters updated: temp={temperature}, top_p={top_p}, top_k={top_k}, rep_penalty={repetition_penalty}"
+
def create_demo():
    try:
        model = LLaVAPhiModel()
@@ -209,7 +227,7 @@ def create_demo():
        with gr.Blocks(css="footer {visibility: hidden}") as demo:
            gr.Markdown(
                """
-                # LLaVA-Phi Demo (ZeroGPU)
+                # LLaVA-Phi Demo (Optimized for Accuracy)
                Chat with a vision-language model that can understand both text and images.
                """
            )
@@ -229,6 +247,15 @@ def create_demo():

            image = gr.Image(type="pil", label="Upload Image (Optional)")

+            # Add generation parameter controls
+            with gr.Accordion("Advanced Settings", open=False):
+                gr.Markdown("Adjust these parameters to control hallucination tendency")
+                temp_slider = gr.Slider(0.1, 1.0, value=0.3, step=0.1, label="Temperature (lower = more factual)")
+                top_p_slider = gr.Slider(0.5, 1.0, value=0.92, step=0.01, label="Top-p (nucleus sampling)")
+                top_k_slider = gr.Slider(10, 100, value=50, step=5, label="Top-k")
+                rep_penalty_slider = gr.Slider(1.0, 2.0, value=1.2, step=0.1, label="Repetition Penalty")
+                update_params = gr.Button("Update Parameters")
+
            def respond(message, chat_history, image):
                if not message and image is None:
                    return chat_history
@@ -241,6 +268,9 @@ def create_demo():
                model.clear_history()
                return None, None

+            def update_params_fn(temp, top_p, top_k, rep_penalty):
+                return model.update_generation_params(temp, top_p, top_k, rep_penalty)
+
            submit.click(
                respond,
                [msg, chatbot, image],
@@ -259,6 +289,12 @@ def create_demo():
                [msg, chatbot],
            )

+            update_params.click(
+                update_params_fn,
+                [temp_slider, top_p_slider, top_k_slider, rep_penalty_slider],
+                None
+            )
+
        return demo
    except Exception as e:
        logging.error(f"Error creating demo: {str(e)}")
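Worth noting: update_generation_params stores the slider values as attributes on the model object, and the Advanced Settings accordion feeds it through update_params_fn, but the generate() calls in the hunks above still pass hard-coded literals (0.3, 0.92, 50, 1.2), so the sliders only change behaviour if the generate calls read those attributes. A hedged sketch of how that could look, with getattr defaults mirroring the committed literals:

# Sketch only, not part of this commit: let generate() pick up the values set
# by update_generation_params(), falling back to the committed defaults.
outputs = self.model.generate(
    **inputs,
    max_new_tokens=256,
    min_length=20,
    do_sample=True,
    temperature=getattr(self, "temperature", 0.3),
    top_p=getattr(self, "top_p", 0.92),
    top_k=getattr(self, "top_k", 50),
    repetition_penalty=getattr(self, "repetition_penalty", 1.2),
    no_repeat_ngram_size=3,
    use_cache=True,
    pad_token_id=self.tokenizer.pad_token_id,
)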
 