prithivMLmods committed (verified)
Commit 4f97d6f · Parent(s): 7be0e24

Update app.py

Files changed (1): app.py (+54 -72)
app.py CHANGED
@@ -23,6 +23,7 @@ from transformers import (
 from transformers.image_utils import load_image
 from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
 
+
 DESCRIPTION = """
 # QwQ Edge 💬
 """
@@ -41,6 +42,23 @@ h1 {
 }
 '''
 
+def progress_bar_html(label: str) -> str:
+    """Return an HTML snippet with a label and an animated, thin light-blue progress bar."""
+    return f"""
+<div style="display: flex; align-items: center;">
+  <span style="margin-right: 8px;">{label}</span>
+  <div style="position: relative; width: 110px; height: 5px; background: #e0e0e0; border-radius: 5px; overflow: hidden;">
+    <div style="width: 100%; height: 100%; background-color: lightblue; animation: progress-bar-animation 1s linear infinite;"></div>
+  </div>
+</div>
+<style>
+@keyframes progress-bar-animation {{
+  0% {{ transform: translateX(-100%); }}
+  100% {{ transform: translateX(100%); }}
+}}
+</style>
+    """
+
 MAX_MAX_NEW_TOKENS = 2048
 DEFAULT_MAX_NEW_TOKENS = 1024
 MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
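Note on the new helper: the inner div always spans 100% of the track and is slid across it by the keyframe animation, so the bar loops indefinitely and no progress percentage has to be computed. A minimal sketch for previewing it in isolation (the helper below is a trimmed copy of progress_bar_html from the hunk above; the preview script itself is illustrative and not part of this commit):

```python
# Standalone preview of the animated indicator added in this commit.
# The helper is a trimmed copy of progress_bar_html above; the preview
# script itself is illustrative and not part of app.py.
import gradio as gr

def progress_bar_html(label: str) -> str:
    return f"""
<div style="display: flex; align-items: center;">
  <span style="margin-right: 8px;">{label}</span>
  <div style="width: 110px; height: 5px; background: #e0e0e0; border-radius: 5px; overflow: hidden;">
    <div style="width: 100%; height: 100%; background-color: lightblue;
                animation: progress-bar-animation 1s linear infinite;"></div>
  </div>
</div>
<style>
@keyframes progress-bar-animation {{
  0% {{ transform: translateX(-100%); }}
  100% {{ transform: translateX(100%); }}
}}
</style>
"""

with gr.Blocks() as demo:
    gr.HTML(progress_bar_html("Thinking..."))  # renders the looping bar

demo.launch()
```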
@@ -87,22 +105,6 @@ def clean_chat_history(chat_history):
         cleaned.append(msg)
     return cleaned
 
-# A helper function to render a progress bar using HTML.
-def render_progress_bar(label: str, progress: int, output_text: str = "") -> str:
-    """
-    Returns an HTML snippet containing a label, a progress bar (red background with a green inner bar),
-    and optionally some output text.
-    """
-    return f'''
-    <div style="margin-bottom: 10px;">
-        <div style="font-weight: bold; margin-bottom: 5px;">{label}</div>
-        <div style="width: 100%; background-color: red; border-radius: 5px; overflow: hidden; height: 10px;">
-            <div style="width: {progress}%; background-color: green; height: 100%; transition: width 0.3s;"></div>
-        </div>
-        <div style="margin-top: 10px;">{output_text}</div>
-    </div>
-    '''
-
 # Environment variables and parameters for Stable Diffusion XL
 MODEL_ID_SD = os.getenv("MODEL_VAL_PATH")  # SDXL Model repository path via env variable
 MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
@@ -183,6 +185,7 @@ def generate_image_fn(
         batch_options["prompt"] = options["prompt"][i:i+BATCH_SIZE]
         if "negative_prompt" in batch_options and batch_options["negative_prompt"] is not None:
             batch_options["negative_prompt"] = options["negative_prompt"][i:i+BATCH_SIZE]
+        # Wrap the pipeline call in autocast if using CUDA
        if device.type == "cuda":
            with torch.autocast("cuda", dtype=torch.float16):
                outputs = sd_pipe(**batch_options)
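The added comment documents existing behavior: on CUDA the SDXL call runs under torch.autocast, so matmuls and convolutions execute in float16 while numerically sensitive ops stay in float32; on CPU the pipeline is invoked without the wrapper. A self-contained sketch of the same pattern (run_pipeline is a hypothetical stand-in for the sd_pipe call, not a function in app.py):

```python
# Sketch of the CUDA/CPU dispatch used in generate_image_fn;
# `run_pipeline` is a hypothetical stand-in, not part of app.py.
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def run_pipeline(pipe, **batch_options):
    if device.type == "cuda":
        # Mixed precision: fp16 for matmuls/convs, fp32 where needed.
        with torch.autocast("cuda", dtype=torch.float16):
            return pipe(**batch_options)
    # CPU path: full precision, no autocast wrapper.
    return pipe(**batch_options)
```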
@@ -207,51 +210,36 @@ def generate(
     Special commands:
     - "@tts1" or "@tts2": triggers text-to-speech.
     - "@image": triggers image generation using the SDXL pipeline.
-
-    Instead of yielding a simple "Thinking..." text, an animated progress bar is shown (via an HTML snippet)
-    that goes from red to green. When the inference is complete the progress bar is replaced by the final result.
     """
     text = input_dict["text"]
     files = input_dict.get("files", [])
 
-    # Image generation branch
+    tts_prefix = "@tts"
+    is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
+    voice_index = next((i for i in range(1, 3) if text.strip().lower().startswith(f"{tts_prefix}{i}")), None)
+
     if text.strip().lower().startswith("@image"):
+        # Remove the "@image" tag and use the rest as prompt
         prompt = text[len("@image"):].strip()
-        # Use a container to capture the result from the thread.
-        result_container = []
-        def run_image():
-            result_container.append(generate_image_fn(
-                prompt=prompt,
-                negative_prompt="",
-                use_negative_prompt=False,
-                seed=1,
-                width=1024,
-                height=1024,
-                guidance_scale=3,
-                num_inference_steps=25,
-                randomize_seed=True,
-                use_resolution_binning=True,
-                num_images=1,
-            ))
-        thread = Thread(target=run_image)
-        thread.start()
-        start_time = time.time()
-        # Simulate progress bar updates while image generation is running.
-        while thread.is_alive():
-            progress = min(95, int((time.time() - start_time) / 5 * 95))
-            yield render_progress_bar("Generating Image", progress)
-            time.sleep(0.5)
-        thread.join()
-        # Final update before showing the result.
-        yield render_progress_bar("Generating Image", 100)
-        image_paths, used_seed = result_container[0]
+        # Yield progress bar for image generation
+        yield progress_bar_html("Generating Image")
+        image_paths, used_seed = generate_image_fn(
+            prompt=prompt,
+            negative_prompt="",
+            use_negative_prompt=False,
+            seed=1,
+            width=1024,
+            height=1024,
+            guidance_scale=3,
+            num_inference_steps=25,
+            randomize_seed=True,
+            use_resolution_binning=True,
+            num_images=1,
+        )
+        # Yield the generated image, replacing the progress bar
         yield gr.Image(image_paths[0])
         return  # Exit early
 
-    tts_prefix = "@tts"
-    is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
-    voice_index = next((i for i in range(1, 3) if text.strip().lower().startswith(f"{tts_prefix}{i}")), None)
-
     if is_tts and voice_index:
         voice = TTS_VOICES[voice_index - 1]
         text = text.replace(f"{tts_prefix}{voice_index}", "").strip()
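Two things change in this hunk. First, the @tts detection is hoisted above the @image branch, so voice parsing happens once per request instead of after the image early-return. Second, image generation becomes a direct synchronous call: the worker thread, result container, and simulated-percentage loop are gone, and a single indeterminate bar is yielded before generate_image_fn runs. The prefix matching itself is plain string handling; a self-contained sketch (parse_command is illustrative, not a helper in app.py):

```python
# Illustrative sketch of the prefix-command routing in generate();
# parse_command is not a helper in app.py.
def parse_command(text: str):
    lowered = text.strip().lower()
    if lowered.startswith("@image"):
        return "image", text.strip()[len("@image"):].strip()
    for i in (1, 2):  # mirrors range(1, 3) in the diff
        if lowered.startswith(f"@tts{i}"):
            return f"tts{i}", text.strip()[len(f"@tts{i}"):].strip()
    return "chat", text.strip()

assert parse_command("@image a red fox") == ("image", "a red fox")
assert parse_command("@tts2 hello")[0] == "tts2"
```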
@@ -264,7 +252,6 @@ def generate(
     conversation = clean_chat_history(chat_history)
     conversation.append({"role": "user", "content": text})
 
-    # Multimodal (image + text) branch
     if files:
         if len(files) > 1:
             images = [load_image(image) for image in files]
@@ -287,20 +274,17 @@ def generate(
         thread.start()
 
         buffer = ""
-        start_time = time.time()
-        # Initial progress bar for multimodal inference.
-        yield render_progress_bar("Thinking...", 0)
+        # Yield initial progress bar for multimodal generation
+        yield progress_bar_html("Thinking...")
         for new_text in streamer:
             buffer += new_text
             buffer = buffer.replace("<|im_end|>", "")
-            progress = min(95, int((time.time() - start_time) / 5 * 95))
-            yield render_progress_bar("Thinking...", progress, output_text=buffer)
-        # Final progress update (100%).
-        yield render_progress_bar("Thinking...", 100, output_text=buffer)
-        # Then yield final response (progress bar update no longer shown).
-        yield buffer
+            time.sleep(0.01)
+            # Update with partial text and progress bar
+            yield f"<div>{buffer}</div><div>{progress_bar_html('Thinking...')}</div>"
+        # Final output: remove progress bar
+        yield f"<div>{buffer}</div>"
     else:
-        # Text-only generation branch.
         input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
         if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
             input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
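The streaming loop keeps its shape: model.generate runs on a worker thread, TextIteratorStreamer hands decoded chunks back on the main thread, and each partial buffer is re-yielded. What changes is the payload, which now carries the animated bar appended as HTML. A condensed sketch of the producer/consumer wiring (assumes model and tokenizer are an already-loaded transformers causal LM pair):

```python
# Condensed sketch of the threaded streaming pattern used in both branches;
# assumes `model` and `tokenizer` are an already-loaded transformers pair.
from threading import Thread
from transformers import TextIteratorStreamer

def stream_reply(model, tokenizer, input_ids, max_new_tokens=256):
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    generation_kwargs = dict(
        input_ids=input_ids, streamer=streamer, max_new_tokens=max_new_tokens
    )
    Thread(target=model.generate, kwargs=generation_kwargs).start()  # producer
    buffer = ""
    for new_text in streamer:  # consumer: blocks until the next chunk
        buffer += new_text
        yield buffer           # caller decides how to decorate the partial
```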
@@ -321,20 +305,18 @@ def generate(
         t = Thread(target=model.generate, kwargs=generation_kwargs)
         t.start()
 
+        # Yield initial progress bar for text generation
+        yield progress_bar_html("Thinking...")
         outputs = []
-        start_time = time.time()
-        # Initial progress bar update.
-        yield render_progress_bar("Thinking...", 0)
         for new_text in streamer:
             outputs.append(new_text)
             current_text = "".join(outputs)
-            progress = min(95, int((time.time() - start_time) / 5 * 95))
-            yield render_progress_bar("Thinking...", progress, output_text=current_text)
+            time.sleep(0.01)
+            # Update message with partial text and progress bar
+            yield f"<div>{current_text}</div><div>{progress_bar_html('Thinking...')}</div>"
         final_response = "".join(outputs)
-        # Final update (100% progress).
-        yield render_progress_bar("Thinking...", 100, output_text=final_response)
-        # Finally, yield the final plain response so the progress bar disappears.
-        yield final_response
+        # Final output: only the final response text, progress bar removed.
+        yield f"<div>{final_response}</div>"
 
     # If TTS was requested, convert the final response to speech.
     if is_tts and voice:
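Both loops rely on the same Gradio contract: each yield from a chat function replaces the previous bot message rather than appending a new one. Appending the bar to every partial and omitting it from the final yield is what makes the indicator disappear cleanly once generation finishes. A self-contained toy demonstrating that replacement semantics (the hard-coded token list stands in for a real streamer):

```python
# Toy demo of the yield-replaces-message contract this commit relies on;
# the token list stands in for a real TextIteratorStreamer.
import time
import gradio as gr

BAR = '<div style="width: 110px; height: 5px; background: lightblue;"></div>'

def respond(message, history):
    buffer = ""
    for token in ["Stream", "ing ", "works."]:  # fake token stream
        buffer += token
        time.sleep(0.3)
        yield f"<div>{buffer}</div><div>{BAR}</div>"  # partial + indicator
    yield f"<div>{buffer}</div>"  # final yield: indicator removed

gr.ChatInterface(fn=respond, type="messages").launch()
```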
 