prithivMLmods committed (verified)
Commit 7be0e24 · Parent(s): 9604e47

Update app.py

Files changed (1):
  1. app.py +70 -62

app.py CHANGED
@@ -47,29 +47,6 @@ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
-# Define a helper function that returns HTML for a progress bar with a label.
-def progress_bar_html_with_label(label="Thinking..."):
-    return f"""
-    <div style="font-weight: bold; margin-bottom: 5px;">{label}</div>
-    <div id="progress-container" style="width: 100%; background-color: #eee; border-radius: 4px; overflow: hidden;">
-        <div id="progress-bar" style="width: 0%; height: 10px; background-color: limegreen; transition: width 0.1s;"></div>
-    </div>
-    <script>
-    (function() {{
-        let progressBar = document.getElementById("progress-bar");
-        let width = 0;
-        let interval = setInterval(function(){{
-            if(width < 100) {{
-                width += 1;
-                progressBar.style.width = width + "%";
-            }} else {{
-                clearInterval(interval);
-            }}
-        }}, 100);
-    }})();
-    </script>
-    """
-
 # Load text-only model and tokenizer
 model_id = "prithivMLmods/FastThink-0.5B-Tiny"
 tokenizer = AutoTokenizer.from_pretrained(model_id)
@@ -110,6 +87,22 @@ def clean_chat_history(chat_history):
         cleaned.append(msg)
     return cleaned
 
+# A helper function to render a progress bar using HTML.
+def render_progress_bar(label: str, progress: int, output_text: str = "") -> str:
+    """
+    Returns an HTML snippet containing a label, a progress bar (red background with a green inner bar),
+    and optionally some output text.
+    """
+    return f'''
+    <div style="margin-bottom: 10px;">
+        <div style="font-weight: bold; margin-bottom: 5px;">{label}</div>
+        <div style="width: 100%; background-color: red; border-radius: 5px; overflow: hidden; height: 10px;">
+            <div style="width: {progress}%; background-color: green; height: 100%; transition: width 0.3s;"></div>
+        </div>
+        <div style="margin-top: 10px;">{output_text}</div>
+    </div>
+    '''
+
 # Environment variables and parameters for Stable Diffusion XL
 MODEL_ID_SD = os.getenv("MODEL_VAL_PATH")  # SDXL Model repository path via env variable
 MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
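
Note: this helper replaces the JavaScript-animated bar removed above. It works because each yield from a Gradio generator function re-renders the message in place, so repeatedly yielding fresh HTML animates the bar server-side with no client script. A minimal sketch of that pattern (not from this commit; it assumes render_progress_bar as defined above):

import time

def demo_stream():
    # Each yielded HTML string replaces the previously rendered message,
    # so the inner bar appears to grow from 0% to 100%.
    for pct in range(0, 101, 10):
        yield render_progress_bar("Thinking...", pct, output_text=f"{pct}% complete")
        time.sleep(0.1)
    yield "done"  # a final plain-text yield removes the bar entirely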
@@ -214,29 +207,44 @@ def generate(
     Special commands:
     - "@tts1" or "@tts2": triggers text-to-speech.
     - "@image": triggers image generation using the SDXL pipeline.
+
+    Instead of yielding a simple "Thinking..." text, an animated progress bar is shown (via an HTML snippet)
+    that goes from red to green. When the inference is complete the progress bar is replaced by the final result.
     """
     text = input_dict["text"]
     files = input_dict.get("files", [])
 
-    # If the command is for image generation
+    # Image generation branch
     if text.strip().lower().startswith("@image"):
         prompt = text[len("@image"):].strip()
-        # Show animated progress bar with "Generating Image" label
-        yield gr.HTML(progress_bar_html_with_label("Generating Image"))
-        image_paths, used_seed = generate_image_fn(
-            prompt=prompt,
-            negative_prompt="",
-            use_negative_prompt=False,
-            seed=1,
-            width=1024,
-            height=1024,
-            guidance_scale=3,
-            num_inference_steps=25,
-            randomize_seed=True,
-            use_resolution_binning=True,
-            num_images=1,
-        )
-        # After generation, yield only the image (progress bar no longer shown)
+        # Use a container to capture the result from the thread.
+        result_container = []
+        def run_image():
+            result_container.append(generate_image_fn(
+                prompt=prompt,
+                negative_prompt="",
+                use_negative_prompt=False,
+                seed=1,
+                width=1024,
+                height=1024,
+                guidance_scale=3,
+                num_inference_steps=25,
+                randomize_seed=True,
+                use_resolution_binning=True,
+                num_images=1,
+            ))
+        thread = Thread(target=run_image)
+        thread.start()
+        start_time = time.time()
+        # Simulate progress bar updates while image generation is running.
+        while thread.is_alive():
+            progress = min(95, int((time.time() - start_time) / 5 * 95))
+            yield render_progress_bar("Generating Image", progress)
+            time.sleep(0.5)
+        thread.join()
+        # Final update before showing the result.
+        yield render_progress_bar("Generating Image", 100)
+        image_paths, used_seed = result_container[0]
         yield gr.Image(image_paths[0])
         return  # Exit early
 
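Note: the progress shown here is simulated rather than measured. min(95, int(elapsed / 5 * 95)) ramps the bar to 95% over roughly five seconds (at 2.5 s it reads int(47.5) = 47) and then holds until the worker finishes. A self-contained sketch of the same run-in-a-thread-and-poll pattern, with a hypothetical job callable standing in for generate_image_fn:

import time
from threading import Thread

def poll_with_progress(job, label="Working...", ramp_seconds=5.0):
    # A list captures the return value, since Thread targets cannot return one.
    result = []
    worker = Thread(target=lambda: result.append(job()))
    worker.start()
    start = time.time()
    while worker.is_alive():
        # Ramp to 95% over ramp_seconds, then hold until the job completes.
        pct = min(95, int((time.time() - start) / ramp_seconds * 95))
        yield render_progress_bar(label, pct)
        time.sleep(0.5)
    worker.join()
    yield render_progress_bar(label, 100)
    yield result[0]  # finally, yield the job's actual result
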
@@ -247,14 +255,16 @@ def generate(
     if is_tts and voice_index:
         voice = TTS_VOICES[voice_index - 1]
         text = text.replace(f"{tts_prefix}{voice_index}", "").strip()
+        # Clear previous chat history for a fresh TTS request.
         conversation = [{"role": "user", "content": text}]
     else:
         voice = None
+        # Remove any stray @tts tags and build the conversation history.
         text = text.replace(tts_prefix, "").strip()
         conversation = clean_chat_history(chat_history)
         conversation.append({"role": "user", "content": text})
 
-    # Multimodal generation (with file inputs)
+    # Multimodal (image + text) branch
     if files:
         if len(files) > 1:
             images = [load_image(image) for image in files]
@@ -277,21 +287,20 @@ def generate(
         thread.start()
 
         buffer = ""
-        # Show initial progress bar with label "Thinking..."
-        yield gr.HTML(progress_bar_html_with_label("Thinking..."))
+        start_time = time.time()
+        # Initial progress bar for multimodal inference.
+        yield render_progress_bar("Thinking...", 0)
         for new_text in streamer:
             buffer += new_text
             buffer = buffer.replace("<|im_end|>", "")
-            # Update the message to show both the progress bar and current text output.
-            html = f"""
-            {progress_bar_html_with_label("Thinking...")}
-            <div style="margin-top: 10px;">{buffer}</div>
-            """
-            yield gr.HTML(html)
-        # Final output: only the generated text without the progress bar.
+            progress = min(95, int((time.time() - start_time) / 5 * 95))
+            yield render_progress_bar("Thinking...", progress, output_text=buffer)
+        # Final progress update (100%).
+        yield render_progress_bar("Thinking...", 100, output_text=buffer)
+        # Then yield final response (progress bar update no longer shown).
         yield buffer
     else:
-        # Text-only generation
+        # Text-only generation branch.
        input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
        if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
            input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
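
Note: both streaming branches follow the usual transformers pattern: model.generate blocks, so it runs on a background thread while a TextIteratorStreamer hands decoded chunks to the main thread. A condensed sketch (model and tokenizer assumed already loaded, as they are in app.py):

from threading import Thread
from transformers import TextIteratorStreamer

def stream_reply(model, tokenizer, prompt, max_new_tokens=128):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    # generate() blocks until done, so it runs on its own thread.
    Thread(target=model.generate,
           kwargs=dict(**inputs, streamer=streamer, max_new_tokens=max_new_tokens)).start()
    buffer = ""
    for piece in streamer:
        buffer += piece
        yield buffer  # yield the accumulated text after every chunk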
@@ -313,19 +322,18 @@ def generate(
         t.start()
 
         outputs = []
-        buffer = ""
-        # Show initial progress bar with label "Thinking..."
-        yield gr.HTML(progress_bar_html_with_label("Thinking..."))
+        start_time = time.time()
+        # Initial progress bar update.
+        yield render_progress_bar("Thinking...", 0)
         for new_text in streamer:
             outputs.append(new_text)
-            buffer = "".join(outputs)
-            html = f"""
-            {progress_bar_html_with_label("Thinking...")}
-            <div style="margin-top: 10px;">{buffer}</div>
-            """
-            yield gr.HTML(html)
-        final_response = buffer
-        # Final output: just the final text.
+            current_text = "".join(outputs)
+            progress = min(95, int((time.time() - start_time) / 5 * 95))
+            yield render_progress_bar("Thinking...", progress, output_text=current_text)
+        final_response = "".join(outputs)
+        # Final update (100% progress).
+        yield render_progress_bar("Thinking...", 100, output_text=final_response)
+        # Finally, yield the final plain response so the progress bar disappears.
         yield final_response
 
     # If TTS was requested, convert the final response to speech.
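
Note: the TTS backend itself is not part of this diff; the voice list (TTS_VOICES) and this final conversion step suggest an edge-tts-style interface. Purely as an illustration of what that last step might look like (an assumption, not code from app.py):

import asyncio
import edge_tts

async def text_to_speech(text: str, voice: str, out_path: str = "output.mp3") -> str:
    # edge_tts.Communicate synthesizes speech; save() writes the audio file.
    await edge_tts.Communicate(text, voice).save(out_path)
    return out_path

# Usage (voice name is an example): asyncio.run(text_to_speech(final_response, "en-US-JennyNeural"))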
 
47
 
48
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  # Load text-only model and tokenizer
51
  model_id = "prithivMLmods/FastThink-0.5B-Tiny"
52
  tokenizer = AutoTokenizer.from_pretrained(model_id)
 
87
  cleaned.append(msg)
88
  return cleaned
89
 
90
+ # A helper function to render a progress bar using HTML.
91
+ def render_progress_bar(label: str, progress: int, output_text: str = "") -> str:
92
+ """
93
+ Returns an HTML snippet containing a label, a progress bar (red background with a green inner bar),
94
+ and optionally some output text.
95
+ """
96
+ return f'''
97
+ <div style="margin-bottom: 10px;">
98
+ <div style="font-weight: bold; margin-bottom: 5px;">{label}</div>
99
+ <div style="width: 100%; background-color: red; border-radius: 5px; overflow: hidden; height: 10px;">
100
+ <div style="width: {progress}%; background-color: green; height: 100%; transition: width 0.3s;"></div>
101
+ </div>
102
+ <div style="margin-top: 10px;">{output_text}</div>
103
+ </div>
104
+ '''
105
+
106
  # Environment variables and parameters for Stable Diffusion XL
107
  MODEL_ID_SD = os.getenv("MODEL_VAL_PATH") # SDXL Model repository path via env variable
108
  MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
 
207
  Special commands:
208
  - "@tts1" or "@tts2": triggers text-to-speech.
209
  - "@image": triggers image generation using the SDXL pipeline.
210
+
211
+ Instead of yielding a simple "Thinking..." text, an animated progress bar is shown (via an HTML snippet)
212
+ that goes from red to green. When the inference is complete the progress bar is replaced by the final result.
213
  """
214
  text = input_dict["text"]
215
  files = input_dict.get("files", [])
216
 
217
+ # Image generation branch
218
  if text.strip().lower().startswith("@image"):
219
  prompt = text[len("@image"):].strip()
220
+ # Use a container to capture the result from the thread.
221
+ result_container = []
222
+ def run_image():
223
+ result_container.append(generate_image_fn(
224
+ prompt=prompt,
225
+ negative_prompt="",
226
+ use_negative_prompt=False,
227
+ seed=1,
228
+ width=1024,
229
+ height=1024,
230
+ guidance_scale=3,
231
+ num_inference_steps=25,
232
+ randomize_seed=True,
233
+ use_resolution_binning=True,
234
+ num_images=1,
235
+ ))
236
+ thread = Thread(target=run_image)
237
+ thread.start()
238
+ start_time = time.time()
239
+ # Simulate progress bar updates while image generation is running.
240
+ while thread.is_alive():
241
+ progress = min(95, int((time.time() - start_time) / 5 * 95))
242
+ yield render_progress_bar("Generating Image", progress)
243
+ time.sleep(0.5)
244
+ thread.join()
245
+ # Final update before showing the result.
246
+ yield render_progress_bar("Generating Image", 100)
247
+ image_paths, used_seed = result_container[0]
248
  yield gr.Image(image_paths[0])
249
  return # Exit early
250
 
 
255
  if is_tts and voice_index:
256
  voice = TTS_VOICES[voice_index - 1]
257
  text = text.replace(f"{tts_prefix}{voice_index}", "").strip()
258
+ # Clear previous chat history for a fresh TTS request.
259
  conversation = [{"role": "user", "content": text}]
260
  else:
261
  voice = None
262
+ # Remove any stray @tts tags and build the conversation history.
263
  text = text.replace(tts_prefix, "").strip()
264
  conversation = clean_chat_history(chat_history)
265
  conversation.append({"role": "user", "content": text})
266
 
267
+ # Multimodal (image + text) branch
268
  if files:
269
  if len(files) > 1:
270
  images = [load_image(image) for image in files]
 
287
  thread.start()
288
 
289
  buffer = ""
290
+ start_time = time.time()
291
+ # Initial progress bar for multimodal inference.
292
+ yield render_progress_bar("Thinking...", 0)
293
  for new_text in streamer:
294
  buffer += new_text
295
  buffer = buffer.replace("<|im_end|>", "")
296
+ progress = min(95, int((time.time() - start_time) / 5 * 95))
297
+ yield render_progress_bar("Thinking...", progress, output_text=buffer)
298
+ # Final progress update (100%).
299
+ yield render_progress_bar("Thinking...", 100, output_text=buffer)
300
+ # Then yield final response (progress bar update no longer shown).
 
 
301
  yield buffer
302
  else:
303
+ # Text-only generation branch.
304
  input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
305
  if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
306
  input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
 
322
  t.start()
323
 
324
  outputs = []
325
+ start_time = time.time()
326
+ # Initial progress bar update.
327
+ yield render_progress_bar("Thinking...", 0)
328
  for new_text in streamer:
329
  outputs.append(new_text)
330
+ current_text = "".join(outputs)
331
+ progress = min(95, int((time.time() - start_time) / 5 * 95))
332
+ yield render_progress_bar("Thinking...", progress, output_text=current_text)
333
+ final_response = "".join(outputs)
334
+ # Final update (100% progress).
335
+ yield render_progress_bar("Thinking...", 100, output_text=final_response)
336
+ # Finally, yield the final plain response so the progress bar disappears.
 
337
  yield final_response
338
 
339
  # If TTS was requested, convert the final response to speech.