sagar007 committed on
Commit 78e7cbb · verified · 1 Parent(s): 9f22f0a

Update app.py

Files changed (1)
  1. app.py +89 -75
app.py CHANGED
@@ -41,12 +41,21 @@ text_model = AutoModelForCausalLM.from_pretrained(
     quantization_config=quantization_config
 )
 
-vision_model = AutoModelForCausalLM.from_pretrained(
-    VISION_MODEL_ID,
-    trust_remote_code=True,
-    torch_dtype="auto",
-    attn_implementation="flash_attention_2"
-).to(device).eval()
+try:
+    vision_model = AutoModelForCausalLM.from_pretrained(
+        VISION_MODEL_ID,
+        trust_remote_code=True,
+        torch_dtype="auto",
+        attn_implementation="flash_attention_2"
+    ).to(device).eval()
+except Exception as e:
+    print(f"Error loading model with flash attention: {e}")
+    print("Falling back to default attention implementation")
+    vision_model = AutoModelForCausalLM.from_pretrained(
+        VISION_MODEL_ID,
+        trust_remote_code=True,
+        torch_dtype="auto"
+    ).to(device).eval()
 
 vision_processor = AutoProcessor.from_pretrained(VISION_MODEL_ID, trust_remote_code=True)
 
@@ -55,80 +64,84 @@ tts_model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler
 tts_tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-mini-v1")
 
 # Helper functions
-# Helper functions
-@spaces.GPU
+@spaces.GPU(timeout=300)  # Increase timeout to 5 minutes
 def stream_text_chat(message, history, system_prompt, temperature=0.8, max_new_tokens=1024, top_p=1.0, top_k=20):
-    conversation = [{"role": "system", "content": system_prompt}]
-    for prompt, answer in history:
-        conversation.extend([
-            {"role": "user", "content": prompt},
-            {"role": "assistant", "content": answer},
-        ])
-    conversation.append({"role": "user", "content": message})
-
-    input_ids = text_tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt").to(text_model.device)
-    attention_mask = torch.ones_like(input_ids)  # Create attention mask
-    streamer = TextIteratorStreamer(text_tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
-
-    generate_kwargs = dict(
-        input_ids=input_ids,
-        attention_mask=attention_mask,  # Pass attention mask
-        max_new_tokens=max_new_tokens,
-        do_sample=temperature > 0,
-        top_p=top_p,
-        top_k=top_k,
-        temperature=temperature,
-        eos_token_id=[128001, 128008, 128009],
-        streamer=streamer,
-    )
-
-    with torch.no_grad():
-        thread = Thread(target=text_model.generate, kwargs=generate_kwargs)
-        thread.start()
-
-    buffer = ""
-    audio_buffer = np.array([])
-    for new_text in streamer:
-        buffer += new_text
-
-        # Generate speech for the new text
-        tts_input_ids = tts_tokenizer(new_text, return_tensors="pt").input_ids.to(device)
-        tts_description = "A clear and natural voice reads the text with moderate speed and expression."
-        tts_description_ids = tts_tokenizer(tts_description, return_tensors="pt").input_ids.to(device)
-
-        with torch.no_grad():
-            audio_generation = tts_model.generate(input_ids=tts_description_ids, prompt_input_ids=tts_input_ids)
-
-        new_audio = audio_generation.cpu().numpy().squeeze()
-        audio_buffer = np.concatenate((audio_buffer, new_audio))
-
-        yield history + [[message, buffer]], (tts_model.config.sampling_rate, audio_buffer)
+    try:
+        conversation = [{"role": "system", "content": system_prompt}]
+        for prompt, answer in history:
+            conversation.extend([
+                {"role": "user", "content": prompt},
+                {"role": "assistant", "content": answer},
+            ])
+        conversation.append({"role": "user", "content": message})
+
+        input_ids = text_tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt").to(text_model.device)
+        attention_mask = torch.ones_like(input_ids)
+        streamer = TextIteratorStreamer(text_tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
+
+        generate_kwargs = dict(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            max_new_tokens=max_new_tokens,
+            do_sample=temperature > 0,
+            top_p=top_p,
+            top_k=top_k,
+            temperature=temperature,
+            eos_token_id=[128001, 128008, 128009],
+            streamer=streamer,
+        )
+
+        with torch.no_grad():
+            thread = Thread(target=text_model.generate, kwargs=generate_kwargs)
+            thread.start()
+
+        buffer = ""
+        audio_buffer = np.array([])
+        for new_text in streamer:
+            buffer += new_text
+
+            # Generate speech for the new text
+            tts_input_ids = tts_tokenizer(new_text, return_tensors="pt").input_ids.to(device)
+            tts_description = "A clear and natural voice reads the text with moderate speed and expression."
+            tts_description_ids = tts_tokenizer(tts_description, return_tensors="pt").input_ids.to(device)
+
+            with torch.no_grad():
+                audio_generation = tts_model.generate(input_ids=tts_description_ids, prompt_input_ids=tts_input_ids)
+
+            new_audio = audio_generation.cpu().numpy().squeeze()
+            audio_buffer = np.concatenate((audio_buffer, new_audio))
+
+            yield history + [[message, buffer]], (tts_model.config.sampling_rate, audio_buffer)
+    except Exception as e:
+        print(f"An error occurred: {str(e)}")
+        yield history + [[message, f"An error occurred: {str(e)}"]], None
 
-@spaces.GPU
-def process_vision_query(image, text_input):
-    prompt = f"<|user|>\n<|image_1|>\n{text_input}<|end|>\n<|assistant|>\n"
-
-    # Ensure the image is in the correct format
-    if isinstance(image, np.ndarray):
-        # Convert numpy array to PIL Image
-        image = Image.fromarray(image).convert("RGB")
-    elif not isinstance(image, Image.Image):
-        raise ValueError("Invalid image type. Expected PIL.Image.Image or numpy.ndarray")
-
-    # Now process the image
-    inputs = vision_processor(prompt, images=image, return_tensors="pt").to(device)
-
-    with torch.no_grad():
-        generate_ids = vision_model.generate(
-            **inputs,
-            max_new_tokens=1000,
-            eos_token_id=vision_processor.tokenizer.eos_token_id
-        )
-
-    generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]
-    response = vision_processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
-    return response
+@spaces.GPU(timeout=300)  # Increase timeout to 5 minutes
+def process_vision_query(image, text_input):
+    try:
+        prompt = f"<|user|>\n<|image_1|>\n{text_input}<|end|>\n<|assistant|>\n"
+
+        # Ensure the image is in the correct format
+        if isinstance(image, np.ndarray):
+            image = Image.fromarray(image).convert("RGB")
+        elif not isinstance(image, Image.Image):
+            raise ValueError("Invalid image type. Expected PIL.Image.Image or numpy.ndarray")
+
+        inputs = vision_processor(prompt, images=image, return_tensors="pt").to(device)
+
+        with torch.no_grad():
+            generate_ids = vision_model.generate(
+                **inputs,
+                max_new_tokens=1000,
+                eos_token_id=vision_processor.tokenizer.eos_token_id
+            )
+
+        generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]
+        response = vision_processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        return response
+    except Exception as e:
+        print(f"An error occurred: {str(e)}")
+        return f"An error occurred: {str(e)}"
 
 # Custom CSS
 custom_css = """
@@ -206,6 +219,7 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Base().set(
 
         submit_btn.click(stream_text_chat, [msg, chatbot, system_prompt, temperature, max_new_tokens, top_p, top_k], [chatbot, audio_output])
         clear_btn.click(lambda: None, None, chatbot, queue=False)
+
     with gr.Tab("Vision Model (Phi-3.5-vision)"):
         with gr.Row():
             with gr.Column(scale=1):
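The second hunk's stream_text_chat is built on a producer/consumer loop: generate() blocks until completion, so it runs on a worker thread while the caller drains a TextIteratorStreamer and yields partial output. A minimal sketch of that pattern in isolation; the gpt2 model here is an arbitrary small placeholder, not the model used in the Space:

```python
# Sketch of the streaming pattern behind stream_text_chat:
# generate() runs in a background thread; the main thread consumes text chunks.
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

input_ids = tokenizer("Streaming generation works by", return_tensors="pt").input_ids
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() blocks until done, so it runs in a worker thread;
# the streamer then yields decoded text chunks as they are produced.
thread = Thread(target=model.generate, kwargs=dict(
    input_ids=input_ids,
    max_new_tokens=40,
    do_sample=False,
    pad_token_id=tokenizer.eos_token_id,
    streamer=streamer,
))
thread.start()

buffer = ""
for chunk in streamer:
    buffer += chunk
    print(chunk, end="", flush=True)  # a Gradio handler would `yield` here instead
thread.join()
```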