prithivMLmods committed
Commit f74b154 · verified · 1 parent: f8a9b16

Update app.py

Files changed (1):
  1. app.py +28 -19
app.py CHANGED
@@ -7,7 +7,7 @@ import torch
 import edge_tts
 import asyncio
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
-from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
+from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, TextIteratorStreamer
 from transformers.image_utils import load_image
 import time
 
@@ -35,6 +35,7 @@ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
+# Load the text-only model and tokenizer
 model_id = "prithivMLmods/FastThink-0.5B-Tiny"
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 model = AutoModelForCausalLM.from_pretrained(
@@ -53,6 +54,7 @@ TTS_VOICES = [
     "en-US-JasonNeural", # @tts6
 ]
 
+# Load the multimodal (OCR) model and processor
 MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
 processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
 model_m = Qwen2VLForConditionalGeneration.from_pretrained(
@@ -77,29 +79,39 @@ def generate(
     top_k: int = 50,
     repetition_penalty: float = 1.2,
 ):
-    """Generates chatbot response and handles TTS requests with multimodal input support"""
+    """
+    Generates chatbot response and handles TTS requests with multimodal input support.
+    If the query starts with a TTS command (e.g. '@tts1'), the chat history is cleared
+    to avoid non-text responses (like Audio) interfering with template rendering.
+    """
     text = input_dict["text"]
     files = input_dict.get("files", [])
 
     # Check if input includes image(s)
-    images = [load_image(image) for image in files] if files else []
+    if len(files) > 1:
+        images = [load_image(image) for image in files]
+    elif len(files) == 1:
+        images = [load_image(files[0])]
+    else:
+        images = []
 
-    # Check if message is for TTS
+    # Check if the message is for TTS
     tts_prefix = "@tts"
     is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 7))
     voice_index = next((i for i in range(1, 7) if text.strip().lower().startswith(f"{tts_prefix}{i}")), None)
-
+
     if is_tts and voice_index:
         voice = TTS_VOICES[voice_index - 1]
         text = text.replace(f"{tts_prefix}{voice_index}", "").strip()
+        # Clear conversation history to avoid issues with non-text outputs.
+        conversation = [{"role": "user", "content": text}]
     else:
         voice = None
         text = text.replace(tts_prefix, "").strip()
+        conversation = [*chat_history, {"role": "user", "content": text}]
 
-    conversation = [*chat_history, {"role": "user", "content": text}]
-
+    # If there are images, process multimodal input
     if images:
-        # Process multimodal input
         messages = [
             {"role": "user", "content": [
                 *[{"type": "image", "image": image} for image in images],
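
Note on the hunk above: a message like "@tts3 read this" selects TTS_VOICES[2] and has its prefix stripped, and TTS turns start from a fresh conversation so earlier gr.Audio outputs never reach the chat template. A standalone sketch of that parsing rule (parse_tts_command is a hypothetical helper, not part of app.py):

def parse_tts_command(text: str, num_voices: int = 6):
    """Return (voice_index or None, cleaned_text), mirroring the @tts1..@tts6 check."""
    stripped = text.strip()
    lowered = stripped.lower()
    for i in range(1, num_voices + 1):
        prefix = f"@tts{i}"
        if lowered.startswith(prefix):
            # Drop the command prefix; keep the rest of the message.
            return i, stripped[len(prefix):].strip()
    return None, stripped

# Example: a TTS command is split off; plain text passes through untouched.
assert parse_tts_command("@tts2 Hello there") == (2, "Hello there")
assert parse_tts_command("Hello") == (None, "Hello")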
@@ -109,9 +121,9 @@ def generate(
         prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
         inputs = processor(text=[prompt], images=images, return_tensors="pt", padding=True).to("cuda")
 
+        # Handle generation for multimodal input using model_m
         streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
         generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)
-
         thread = Thread(target=model_m.generate, kwargs=generation_kwargs)
         thread.start()
 
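
Both branches rely on the same transformers streaming idiom: generate() blocks, so it runs on a worker Thread while TextIteratorStreamer is drained as an iterator on the caller's side. A minimal self-contained sketch of the pattern (gpt2 is only a stand-in checkpoint, not the one this app uses):

from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tok = AutoTokenizer.from_pretrained("gpt2")
lm = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tok("Streaming demo:", return_tensors="pt")
streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)

# generate() blocks until done, so it runs in a background thread
# while the main thread consumes tokens as they are produced.
Thread(target=lm.generate, kwargs=dict(inputs, streamer=streamer, max_new_tokens=32)).start()

for chunk in streamer:
    print(chunk, end="", flush=True)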
@@ -124,7 +136,7 @@ def generate(
             yield buffer
 
     else:
-        # Process text-only input
+        # Process text-only input using model
         input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
         if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
             input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
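
The unchanged truncation guard above keeps only the trailing MAX_INPUT_TOKEN_LENGTH tokens, so the newest turns survive when the history overflows. A toy illustration of that slicing:

import torch

MAX_INPUT_TOKEN_LENGTH = 8  # tiny limit just for illustration

input_ids = torch.arange(12).unsqueeze(0)  # fake prompt ids, shape (1, 12)
if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
    # Tail-slice: keep the most recent tokens, drop the oldest history.
    input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]

assert input_ids.tolist() == [[4, 5, 6, 7, 8, 9, 10, 11]]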
@@ -147,21 +159,18 @@ def generate(
         t.start()
 
         outputs = []
-        for text in streamer:
-            outputs.append(text)
+        for new_text in streamer:
+            outputs.append(new_text)
             yield "".join(outputs)
 
         final_response = "".join(outputs)
 
-        # Yield text response first
-        yield final_response
+        # Yield text response first.
+        yield final_response
 
+    # If TTS was requested, yield audio output separately.
     if is_tts and voice:
-        loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(loop)
-        output_file = loop.run_until_complete(text_to_speech(final_response, voice))
-
-        # Separate yield for audio output
+        output_file = asyncio.run(text_to_speech(final_response, voice))
         yield gr.Audio(output_file, autoplay=True)
 
 demo = gr.ChatInterface(
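
The main behavioral fix in the last hunk is the TTS call: asyncio.run() replaces the manual new_event_loop()/run_until_complete() bookkeeping, creating and closing a fresh event loop per request. A sketch of the pattern with edge_tts (the text_to_speech body here is an assumption; app.py defines its own, and any voice from TTS_VOICES works):

import asyncio

import edge_tts

async def text_to_speech(text: str, voice: str, out_path: str = "output.mp3") -> str:
    # edge_tts.Communicate streams synthesized audio; save() writes it to disk.
    await edge_tts.Communicate(text, voice).save(out_path)
    return out_path

# Old pattern: loop = asyncio.new_event_loop(); asyncio.set_event_loop(loop);
# loop.run_until_complete(...) -- leaves a global loop set behind on every call.
# New pattern: asyncio.run owns the loop's full lifecycle in one call.
output_file = asyncio.run(text_to_speech("Hello from edge-tts.", "en-US-JennyNeural"))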