prithivMLmods committed on
Commit 3df271a · verified · 1 Parent(s): bd34d66

Update app.py

Files changed (1)
  1. app.py +112 -24
app.py CHANGED
@@ -4,6 +4,7 @@ import uuid
 import json
 import time
 import asyncio
+import tempfile
 from threading import Thread
 
 import gradio as gr
@@ -12,6 +13,7 @@ import torch
 import numpy as np
 from PIL import Image
 import edge_tts
+import trimesh
 
 from transformers import (
     AutoModelForCausalLM,
@@ -21,8 +23,75 @@ from transformers import (
     AutoProcessor,
 )
 from transformers.image_utils import load_image
+
 from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
+from diffusers import ShapEImg2ImgPipeline, ShapEPipeline
+from diffusers.utils import export_to_ply
+
+
+MAX_SEED = np.iinfo(np.int32).max
+
+def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
+    if randomize_seed:
+        seed = random.randint(0, MAX_SEED)
+    return seed
 
+class Model:
+    def __init__(self):
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.pipe = ShapEPipeline.from_pretrained("openai/shap-e", torch_dtype=torch.float16)
+        self.pipe.to(self.device)
+        # Ensure the text encoder is in half precision to avoid dtype mismatches.
+        if torch.cuda.is_available():
+            try:
+                self.pipe.text_encoder = self.pipe.text_encoder.half()
+            except AttributeError:
+                pass
+
+        self.pipe_img = ShapEImg2ImgPipeline.from_pretrained("openai/shap-e-img2img", torch_dtype=torch.float16)
+        self.pipe_img.to(self.device)
+        # Use getattr with a default value to avoid AttributeError if text_encoder is missing.
+        if torch.cuda.is_available():
+            text_encoder_img = getattr(self.pipe_img, "text_encoder", None)
+            if text_encoder_img is not None:
+                self.pipe_img.text_encoder = text_encoder_img.half()
+
+    def to_glb(self, ply_path: str) -> str:
+        mesh = trimesh.load(ply_path)
+        # Rotate the mesh for proper orientation
+        rot = trimesh.transformations.rotation_matrix(-np.pi / 2, [1, 0, 0])
+        mesh.apply_transform(rot)
+        rot = trimesh.transformations.rotation_matrix(np.pi, [0, 1, 0])
+        mesh.apply_transform(rot)
+        mesh_path = tempfile.NamedTemporaryFile(suffix=".glb", delete=False)
+        mesh.export(mesh_path.name, file_type="glb")
+        return mesh_path.name
+
+    def run_text(self, prompt: str, seed: int = 0, guidance_scale: float = 15.0, num_steps: int = 64) -> str:
+        generator = torch.Generator(device=self.device).manual_seed(seed)
+        images = self.pipe(
+            prompt,
+            generator=generator,
+            guidance_scale=guidance_scale,
+            num_inference_steps=num_steps,
+            output_type="mesh",
+        ).images
+        ply_path = tempfile.NamedTemporaryFile(suffix=".ply", delete=False, mode="w+b")
+        export_to_ply(images[0], ply_path.name)
+        return self.to_glb(ply_path.name)
+
+    def run_image(self, image: Image.Image, seed: int = 0, guidance_scale: float = 3.0, num_steps: int = 64) -> str:
+        generator = torch.Generator(device=self.device).manual_seed(seed)
+        images = self.pipe_img(
+            image,
+            generator=generator,
+            guidance_scale=guidance_scale,
+            num_inference_steps=num_steps,
+            output_type="mesh",
+        ).images
+        ply_path = tempfile.NamedTemporaryFile(suffix=".ply", delete=False, mode="w+b")
+        export_to_ply(images[0], ply_path.name)
+        return self.to_glb(ply_path.name)
 
 DESCRIPTION = """
 # QwQ Edge 💬
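A minimal standalone sketch of the text-to-3D path this hunk introduces (not part of the commit). It mirrors the diff's own calls — ShapEPipeline with output_type="mesh", export_to_ply, then trimesh for the GLB conversion — and assumes a CUDA machine with the openai/shap-e weights available.

    import tempfile
    import numpy as np
    import torch
    import trimesh
    from diffusers import ShapEPipeline
    from diffusers.utils import export_to_ply

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Same checkpoint and settings as Model.run_text above.
    pipe = ShapEPipeline.from_pretrained("openai/shap-e", torch_dtype=torch.float16).to(device)
    mesh = pipe("a birthday cupcake with cherry", guidance_scale=15.0, num_inference_steps=64, output_type="mesh").images[0]

    # Write the mesh to .ply, then convert to .glb the same way Model.to_glb does.
    ply_file = tempfile.NamedTemporaryFile(suffix=".ply", delete=False)
    export_to_ply(mesh, ply_file.name)
    scene = trimesh.load(ply_file.name)
    scene.apply_transform(trimesh.transformations.rotation_matrix(-np.pi / 2, [1, 0, 0]))
    glb_file = tempfile.NamedTemporaryFile(suffix=".glb", delete=False)
    scene.export(glb_file.name, file_type="glb")
    print(glb_file.name)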
@@ -48,7 +117,7 @@ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
-# Load text-only model and tokenizer
+# Load the text-only model and tokenizer (for pure text chat)
 model_id = "prithivMLmods/FastThink-0.5B-Tiny"
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 model = AutoModelForCausalLM.from_pretrained(
@@ -58,11 +127,13 @@ model = AutoModelForCausalLM.from_pretrained(
 )
 model.eval()
 
+# Voices for text-to-speech
 TTS_VOICES = [
     "en-US-JennyNeural",  # @tts1
     "en-US-GuyNeural",    # @tts2
 ]
 
+# Load multimodal processor and model (e.g. for OCR and image processing)
 MODEL_ID = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
 processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
 model_m = Qwen2VLForConditionalGeneration.from_pretrained(
@@ -88,14 +159,12 @@ def clean_chat_history(chat_history):
         cleaned.append(msg)
     return cleaned
 
-# Environment variables and parameters for Stable Diffusion XL
 MODEL_ID_SD = os.getenv("MODEL_VAL_PATH")  # SDXL Model repository path via env variable
 MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "4096"))
 USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
 ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
 BATCH_SIZE = int(os.getenv("BATCH_SIZE", "1"))  # For batched image generation
 
-# Load the SDXL pipeline
 sd_pipe = StableDiffusionXLPipeline.from_pretrained(
     MODEL_ID_SD,
     torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
@@ -104,31 +173,21 @@ sd_pipe = StableDiffusionXLPipeline.from_pretrained(
 ).to(device)
 sd_pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(sd_pipe.scheduler.config)
 
-# Ensure that the text encoder is in half-precision if using CUDA.
 if torch.cuda.is_available():
     sd_pipe.text_encoder = sd_pipe.text_encoder.half()
 
-# Optional: compile the model for speedup if enabled
 if USE_TORCH_COMPILE:
     sd_pipe.compile()
 
-# Optional: offload parts of the model to CPU if needed
 if ENABLE_CPU_OFFLOAD:
     sd_pipe.enable_model_cpu_offload()
 
-MAX_SEED = np.iinfo(np.int32).max
-
 def save_image(img: Image.Image) -> str:
     """Save a PIL image with a unique filename and return the path."""
     unique_name = str(uuid.uuid4()) + ".png"
     img.save(unique_name)
     return unique_name
 
-def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
-    if randomize_seed:
-        seed = random.randint(0, MAX_SEED)
-    return seed
-
 @spaces.GPU(duration=60, enable_queue=True)
 def generate_image_fn(
     prompt: str,
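Not part of the commit, but for context: the SDXL block above is configured entirely through environment variables. A hypothetical local setup might look like the following, where the repository id is only a placeholder.

    import os
    # Hypothetical values — MODEL_VAL_PATH must point at a real SDXL checkpoint for the pipeline to load.
    os.environ.setdefault("MODEL_VAL_PATH", "stabilityai/stable-diffusion-xl-base-1.0")
    os.environ.setdefault("USE_TORCH_COMPILE", "0")   # "1" enables sd_pipe.compile()
    os.environ.setdefault("ENABLE_CPU_OFFLOAD", "0")  # "1" enables enable_model_cpu_offload()
    os.environ.setdefault("BATCH_SIZE", "1")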
@@ -168,7 +227,6 @@ def generate_image_fn(
         batch_options["prompt"] = options["prompt"][i:i+BATCH_SIZE]
         if "negative_prompt" in batch_options and batch_options["negative_prompt"] is not None:
             batch_options["negative_prompt"] = options["negative_prompt"][i:i+BATCH_SIZE]
-        # Wrap the pipeline call in autocast if using CUDA
         if device.type == "cuda":
             with torch.autocast("cuda", dtype=torch.float16):
                 outputs = sd_pipe(**batch_options)
@@ -178,6 +236,23 @@ def generate_image_fn(
     image_paths = [save_image(img) for img in images]
     return image_paths, seed
 
+@spaces.GPU(duration=120, enable_queue=True)
+def generate_3d_fn(
+    prompt: str,
+    seed: int = 1,
+    guidance_scale: float = 15.0,
+    num_steps: int = 64,
+    randomize_seed: bool = False,
+):
+    """
+    Generate a 3D model from text using the ShapE pipeline.
+    Returns a tuple of (glb_file_path, used_seed).
+    """
+    seed = int(randomize_seed_fn(seed, randomize_seed))
+    model3d = Model()
+    glb_path = model3d.run_text(prompt, seed=seed, guidance_scale=guidance_scale, num_steps=num_steps)
+    return glb_path, seed
+
 @spaces.GPU
 def generate(
     input_dict: dict,
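As a quick sketch (not in the commit), the new helper can also be called on its own, using the signature added above; the return value is a .glb path plus the seed that was actually used.

    glb_path, used_seed = generate_3d_fn(
        prompt="a ceramic mug",
        seed=1,
        guidance_scale=15.0,
        num_steps=64,
        randomize_seed=True,
    )
    print(glb_path, used_seed)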
@@ -189,16 +264,34 @@ def generate(
     repetition_penalty: float = 1.2,
 ):
     """
-    Generates chatbot responses with support for multimodal input, TTS, and image generation.
+    Generates chatbot responses with support for multimodal input, TTS, image generation,
+    and 3D model generation.
+
     Special commands:
     - "@tts1" or "@tts2": triggers text-to-speech.
     - "@image": triggers image generation using the SDXL pipeline.
+    - "@3d": triggers 3D model generation using the ShapE pipeline.
     """
     text = input_dict["text"]
     files = input_dict.get("files", [])
 
+    # --- 3D Generation branch ---
+    if text.strip().lower().startswith("@3d"):
+        prompt = text[len("@3d"):].strip()
+        yield "Generating 3D model..."
+        glb_path, used_seed = generate_3d_fn(
+            prompt=prompt,
+            seed=1,
+            guidance_scale=15.0,
+            num_steps=64,
+            randomize_seed=True,
+        )
+        # Instead of returning as a file, yield a 3D model component so it displays inline.
+        yield gr.Model3D(value=glb_path, label="3D Model")
+        return
+
+    # --- Image Generation branch ---
     if text.strip().lower().startswith("@image"):
-        # Remove the "@image" tag and use the rest as prompt
         prompt = text[len("@image"):].strip()
         yield "Generating image..."
         image_paths, used_seed = generate_image_fn(
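Not part of the commit, but the same gr.Model3D component used in the @3d branch can be wired into a plain Blocks demo; a minimal sketch under that assumption:

    import gradio as gr

    with gr.Blocks() as demo3d:
        prompt = gr.Textbox(label="Prompt")
        viewer = gr.Model3D(label="3D Model")
        # generate_3d_fn returns (glb_path, seed); only the path is shown in the viewer.
        prompt.submit(lambda p: generate_3d_fn(prompt=p, randomize_seed=True)[0], inputs=prompt, outputs=viewer)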
@@ -214,10 +307,10 @@ def generate(
             use_resolution_binning=True,
             num_images=1,
         )
-        # Yield the generated image so that the chat interface displays it.
         yield gr.Image(image_paths[0])
-        return  # Exit early
+        return
 
+    # --- Text and TTS branch ---
     tts_prefix = "@tts"
     is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
     voice_index = next((i for i in range(1, 3) if text.strip().lower().startswith(f"{tts_prefix}{i}")), None)
@@ -225,11 +318,9 @@ def generate(
     if is_tts and voice_index:
         voice = TTS_VOICES[voice_index - 1]
         text = text.replace(f"{tts_prefix}{voice_index}", "").strip()
-        # Clear previous chat history for a fresh TTS request.
         conversation = [{"role": "user", "content": text}]
     else:
         voice = None
-        # Remove any stray @tts tags and build the conversation history.
         text = text.replace(tts_prefix, "").strip()
         conversation = clean_chat_history(chat_history)
         conversation.append({"role": "user", "content": text})
@@ -263,7 +354,6 @@ def generate(
             time.sleep(0.01)
             yield buffer
     else:
-
         input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
         if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
             input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
@@ -292,7 +382,6 @@ def generate(
         final_response = "".join(outputs)
         yield final_response
 
-        # If TTS was requested, convert the final response to speech.
         if is_tts and voice:
            output_file = asyncio.run(text_to_speech(final_response, voice))
            yield gr.Audio(output_file, autoplay=True)
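The text_to_speech coroutine itself is not shown in this diff; with the edge_tts import above it is typically a thin async wrapper along these lines (a sketch, not the author's code):

    import edge_tts

    async def text_to_speech(text: str, voice: str, output_file: str = "output.mp3") -> str:
        # Communicate streams the synthesized audio; save() writes it to disk.
        communicate = edge_tts.Communicate(text, voice)
        await communicate.save(output_file)
        return output_file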
@@ -308,12 +397,11 @@ demo = gr.ChatInterface(
     ],
     examples=[
         ["@tts1 Who is Nikola Tesla, and why did he die?"],
-        [{"text": "Extract JSON from the image", "files": ["examples/document.jpg"]}],
+        ["@3d A birthday cupcake with cherry"],
         [{"text": "summarize the letter", "files": ["examples/1.png"]}],
         ["@image Chocolate dripping from a donut against a yellow background, in the style of brocore, hyper-realistic"],
         ["Write a Python function to check if a number is prime."],
         ["@tts2 What causes rainbows to form?"],
-
     ],
     cache_examples=False,
     type="messages",
 