prithivMLmods committed (verified)
Commit 3a6718d · 1 Parent(s): 54f4624

Update app.py

Files changed (1): app.py (+30, -6)
app.py CHANGED
@@ -9,6 +9,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStream
 from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, TextIteratorStreamer
 from transformers.image_utils import load_image
 import time
+from gradio_client import Client  # For image generation API
 
 DESCRIPTION = """
 # QwQ Edge 💬
@@ -58,12 +59,20 @@ model_m = Qwen2VLForConditionalGeneration.from_pretrained(
     torch_dtype=torch.float16
 ).to("cuda").eval()
 
+# Image generation client
+image_gen_client = Client("prithivMLmods/STABLE-HAMSTER")
+
 async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
     """Convert text to speech using Edge TTS and save as MP3"""
     communicate = edge_tts.Communicate(text, voice)
     await communicate.save(output_file)
     return output_file
 
+def image_gen(prompt: str):
+    """Generate an image using the Stable Hamster API"""
+    result = image_gen_client.predict("Image Generation", None, prompt, api_name="/stable_hamster")
+    return result[1]  # Return the generated image
+
 def clean_chat_history(chat_history):
     """
     Filter out any chat entries whose "content" is not a string.
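For a quick sanity check of the new client wiring outside the chat UI, a minimal sketch could look like the following. It reuses the Space name and api_name from the hunk above; the prompt string is made up, and the assumption that the generated image sits at index 1 of the returned result comes from this commit, not from the Stable Hamster API documentation.

from gradio_client import Client

client = Client("prithivMLmods/STABLE-HAMSTER")

# Same positional arguments and endpoint as image_gen() above.
# result[1] is assumed (per the commit) to hold the generated image.
result = client.predict("Image Generation", None, "a watercolor fox", api_name="/stable_hamster")
print(result[1])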
@@ -86,8 +95,8 @@ def generate(
     repetition_penalty: float = 1.2,
 ):
     """
-    Generates chatbot responses with support for multimodal input and TTS.
-    If the query starts with an @tts command (e.g. "@tts1"), previous chat history is cleared.
+    Generates chatbot responses with support for multimodal input, TTS, and image generation.
+    If the query starts with an @tts or @image command, previous chat history is cleared.
     """
     text = input_dict["text"]
     files = input_dict.get("files", [])
@@ -100,22 +109,36 @@ def generate(
     else:
         images = []
 
+    # Check for TTS or Image Generation commands
     tts_prefix = "@tts"
+    image_prefix = "@image"
     is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
-    voice_index = next((i for i in range(1, 3) if text.strip().lower().startswith(f"{tts_prefix}{i}")), None)
+    is_image = text.strip().lower().startswith(image_prefix)
 
-    if is_tts and voice_index:
-        voice = TTS_VOICES[voice_index - 1]
+    if is_tts:
+        voice_index = next((i for i in range(1, 3) if text.strip().lower().startswith(f"{tts_prefix}{i}")), None)
+        voice = TTS_VOICES[voice_index - 1] if voice_index else None
         text = text.replace(f"{tts_prefix}{voice_index}", "").strip()
         # Clear any previous chat history to avoid concatenation issues
         conversation = [{"role": "user", "content": text}]
+    elif is_image:
+        text = text.replace(image_prefix, "").strip()
+        conversation = [{"role": "user", "content": text}]
     else:
         voice = None
         text = text.replace(tts_prefix, "").strip()
         conversation = clean_chat_history(chat_history)
         conversation.append({"role": "user", "content": text})
 
-    if images:
+    if is_image:
+        # Image generation branch
+        yield "Generating image, please wait..."
+        try:
+            image = image_gen(text)
+            yield gr.Image(image)
+        except Exception as e:
+            yield f"Failed to generate image: {str(e)}"
+    elif images:
         # Multimodal branch using the OCR model
         messages = [{
             "role": "user",
@@ -183,6 +206,7 @@ demo = gr.ChatInterface(
     ],
     examples=[
         ["@tts1 Who is Nikola Tesla, and why did he die?"],
+        ["@image A futuristic cityscape at sunset"],
         [{"text": "Extract JSON from the image", "files": ["examples/document.jpg"]}],
         [{"text": "summarize the letter", "files": ["examples/1.png"]}],
         ["A train travels 60 kilometers per hour. If it travels for 5 hours, how far will it travel in total?"],
 