prithivMLmods committed on
Commit 6db128e · verified · 1 Parent(s): e8c0c12

Update app.py

Files changed (1)
  1. app.py +26 -90
app.py CHANGED
@@ -19,15 +19,15 @@ from transformers import (
     TextIteratorStreamer,
     Qwen2VLForConditionalGeneration,
     AutoProcessor,
-    AutoModelForImageTextToText,
 )
 from transformers.image_utils import load_image
 from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
 
-# Application description and CSS
+
 DESCRIPTION = """
 # QwQ Edge 💬
 """
+
 css = '''
 h1 {
   text-align: center;
@@ -48,9 +48,7 @@ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
-# -------------------------
-# Load Text-only Model
-# -------------------------
+# Load text-only model and tokenizer
 model_id = "prithivMLmods/FastThink-0.5B-Tiny"
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 model = AutoModelForCausalLM.from_pretrained(
@@ -60,14 +58,19 @@ model = AutoModelForCausalLM.from_pretrained(
 )
 model.eval()
 
-# -------------------------
-# TTS Settings
-# -------------------------
 TTS_VOICES = [
     "en-US-JennyNeural", # @tts1
     "en-US-GuyNeural", # @tts2
 ]
 
+MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
+processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
+model_m = Qwen2VLForConditionalGeneration.from_pretrained(
+    MODEL_ID,
+    trust_remote_code=True,
+    torch_dtype=torch.float16
+).to("cuda").eval()
+
 async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
     """Convert text to speech using Edge TTS and save as MP3"""
     communicate = edge_tts.Communicate(text, voice)
@@ -85,36 +88,14 @@ def clean_chat_history(chat_history):
             cleaned.append(msg)
     return cleaned
 
-# -------------------------
-# Load Multimodal Model (Qwen2-VL)
-# -------------------------
-MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
-processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
-model_m = Qwen2VLForConditionalGeneration.from_pretrained(
-    MODEL_ID,
-    trust_remote_code=True,
-    torch_dtype=torch.float16
-).to("cuda").eval()
-
-# -------------------------
-# Load Aya-Vision Model (New Feature)
-# -------------------------
-AYA_MODEL_ID = "CohereForAI/aya-vision-8b"
-aya_processor = AutoProcessor.from_pretrained(AYA_MODEL_ID)
-aya_model = AutoModelForImageTextToText.from_pretrained(
-    AYA_MODEL_ID, device_map="auto", torch_dtype=torch.float16
-)
-aya_tokenizer = AutoTokenizer.from_pretrained(AYA_MODEL_ID)
-
-# -------------------------
-# Stable Diffusion XL Settings & Pipeline
-# -------------------------
+# Environment variables and parameters for Stable Diffusion XL
 MODEL_ID_SD = os.getenv("MODEL_VAL_PATH") # SDXL Model repository path via env variable
 MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
 USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
 ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
 BATCH_SIZE = int(os.getenv("BATCH_SIZE", "1")) # For batched image generation
 
+# Load the SDXL pipeline
 sd_pipe = StableDiffusionXLPipeline.from_pretrained(
     MODEL_ID_SD,
     torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
@@ -123,12 +104,15 @@ sd_pipe = StableDiffusionXLPipeline.from_pretrained(
 ).to(device)
 sd_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(sd_pipe.scheduler.config)
 
+# Ensure that the text encoder is in half-precision if using CUDA.
 if torch.cuda.is_available():
     sd_pipe.text_encoder = sd_pipe.text_encoder.half()
 
+# Optional: compile the model for speedup if enabled
 if USE_TORCH_COMPILE:
     sd_pipe.compile()
 
+# Optional: offload parts of the model to CPU if needed
 if ENABLE_CPU_OFFLOAD:
     sd_pipe.enable_model_cpu_offload()
 
@@ -184,6 +168,7 @@ def generate_image_fn(
         batch_options["prompt"] = options["prompt"][i:i+BATCH_SIZE]
         if "negative_prompt" in batch_options and batch_options["negative_prompt"] is not None:
             batch_options["negative_prompt"] = options["negative_prompt"][i:i+BATCH_SIZE]
+        # Wrap the pipeline call in autocast if using CUDA
         if device.type == "cuda":
            with torch.autocast("cuda", dtype=torch.float16):
                outputs = sd_pipe(**batch_options)
@@ -208,55 +193,12 @@ def generate(
     Special commands:
     - "@tts1" or "@tts2": triggers text-to-speech.
     - "@image": triggers image generation using the SDXL pipeline.
-    - "@aya-vision": triggers image-text-to-text generation using the Aya-Vision model.
     """
     text = input_dict["text"]
     files = input_dict.get("files", [])
 
-    # -------------------------
-    # Aya-Vision Feature
-    # -------------------------
-    if text.strip().lower().startswith("@aya-vision"):
-        prompt = text[len("@aya-vision"):].strip()
-        if files:
-            if len(files) > 1:
-                images = [load_image(file) for file in files]
-            elif len(files) == 1:
-                images = [load_image(files[0])]
-            messages = [{
-                "role": "user",
-                "content": [
-                    *[{"type": "image", "image": image} for image in images],
-                    {"type": "text", "text": prompt},
-                ]
-            }]
-        else:
-            messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
-        yield "Processing with Aya-Vision..."
-        inputs = aya_processor.apply_chat_template(
-            messages,
-            padding=True,
-            add_generation_prompt=True,
-            tokenize=True,
-            return_dict=True,
-            return_tensors="pt"
-        ).to(aya_model.device)
-        # Remove deprecated parameter if present to avoid conflicts.
-        inputs.pop("num_logits_to_keep", None)
-        gen_tokens = aya_model.generate(
-            **inputs,
-            max_new_tokens=300,
-            do_sample=True,
-            temperature=0.3,
-        )
-        gen_text = aya_tokenizer.decode(gen_tokens[0], skip_special_tokens=True)
-        yield gen_text
-        return # Exit early after processing with Aya-Vision
-
-    # -------------------------
-    # Image Generation Feature (@image)
-    # -------------------------
     if text.strip().lower().startswith("@image"):
+        # Remove the "@image" tag and use the rest as prompt
         prompt = text[len("@image"):].strip()
         yield "Generating image..."
         image_paths, used_seed = generate_image_fn(
@@ -272,12 +214,10 @@ def generate(
             use_resolution_binning=True,
             num_images=1,
         )
+        # Yield the generated image so that the chat interface displays it.
         yield gr.Image(image_paths[0])
         return # Exit early
 
-    # -------------------------
-    # TTS Feature (@tts1 or @tts2)
-    # -------------------------
     tts_prefix = "@tts"
     is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
     voice_index = next((i for i in range(1, 3) if text.strip().lower().startswith(f"{tts_prefix}{i}")), None)
@@ -294,9 +234,6 @@ def generate(
     conversation = clean_chat_history(chat_history)
     conversation.append({"role": "user", "content": text})
 
-    # -------------------------
-    # Multimodal Input (with files) using Qwen2-VL
-    # -------------------------
     if files:
         if len(files) > 1:
             images = [load_image(image) for image in files]
@@ -326,9 +263,7 @@ def generate(
             time.sleep(0.01)
             yield buffer
     else:
-        # -------------------------
-        # Text-only Generation
-        # -------------------------
+
        input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
        if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
            input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
@@ -356,7 +291,8 @@ def generate(
 
         final_response = "".join(outputs)
         yield final_response
-
+
+        # If TTS was requested, convert the final response to speech.
         if is_tts and voice:
             output_file = asyncio.run(text_to_speech(final_response, voice))
             yield gr.Audio(output_file, autoplay=True)
@@ -371,12 +307,13 @@ demo = gr.ChatInterface(
         gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
     ],
     examples=[
-        [{"text": "@aya-vision Extract JSON from the image", "files": ["examples/document.jpg"]}],
-        [{"text": "@aya-vision Summarize the letter", "files": ["examples/1.png"]}],
         ["@tts1 Who is Nikola Tesla, and why did he die?"],
+        [{"text": "Extract JSON from the image", "files": ["examples/document.jpg"]}],
+        [{"text": "summarize the letter", "files": ["examples/1.png"]}],
         ["@image Chocolate dripping from a donut against a yellow background, in the style of brocore, hyper-realistic"],
         ["Write a Python function to check if a number is prime."],
         ["@tts2 What causes rainbows to form?"],
+
     ],
     cache_examples=False,
     type="messages",
@@ -389,5 +326,4 @@ demo = gr.ChatInterface(
 )
 
 if __name__ == "__main__":
-    # To create a public link, set share=True in launch().
-    demo.queue(max_size=20).launch(share=True)
+    demo.queue(max_size=20).launch(share=True)