shivi commited on
Commit
ec60ae9
·
verified ·
1 Parent(s): eadfa0f

add app setup file

Browse files
Files changed (1) hide show
  1. app.py +491 -0
app.py ADDED
@@ -0,0 +1,491 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ import os
3
+ import io
4
+ import re
5
+ import time
6
+ import uuid
7
+ import torch
8
+ import cohere
9
+ import secrets
10
+ import requests
11
+ import fasttext
12
+ import replicate
13
+ import numpy as np
14
+ import gradio as gr
15
+ from PIL import Image
16
+ from groq import Groq
17
+ from TTS.api import TTS
18
+ from elevenlabs import save
19
+ from gradio.themes.base import Base
20
+ from elevenlabs.client import ElevenLabs
21
+ from huggingface_hub import hf_hub_download
22
+ from gradio.themes.utils import colors, fonts, sizes
23
+ from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
24
+ from prompt_examples import TEXT_CHAT_EXAMPLES, IMG_GEN_PROMPT_EXAMPLES, AUDIO_EXAMPLES, TEXT_CHAT_EXAMPLES_LABELS, IMG_GEN_PROMPT_EXAMPLES_LABELS, AUDIO_EXAMPLES_LABELS
25
+ from preambles import CHAT_PREAMBLE, AUDIO_RESPONSE_PREAMBLE, IMG_DESCRIPTION_PREAMBLE
26
+ from constants import LID_LANGUAGES, NEETS_AI_LANGID_MAP, AYA_MODEL_NAME, BATCH_SIZE, USE_ELVENLABS, USE_REPLICATE
27
+
28
+ HF_API_TOKEN = os.getenv("HF_API_KEY")
29
+ ELEVEN_LABS_KEY = os.getenv("ELEVEN_LABS_KEY")
30
+ NEETS_AI_API_KEY = os.getenv("NEETS_AI_API_KEY")
31
+ GROQ_API_KEY = os.getenv("GROQ_API_KEY")
32
+ IMG_COHERE_API_KEY = os.getenv("IMG_COHERE_API_KEY")
33
+ AUDIO_COHERE_API_KEY = os.getenv("AUDIO_COHERE_API_KEY")
34
+ CHAT_COHERE_API_KEY = os.getenv("CHAT_COHERE_API_KEY")
35
+
36
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
37
+
38
+ # Initialize cohere clients
39
+ img_prompt_client = cohere.Client(
40
+ api_key=IMG_COHERE_API_KEY,
41
+ client_name="c4ai-aya-expanse-img"
42
+ )
43
+ chat_client = cohere.Client(
44
+ api_key=CHAT_COHERE_API_KEY,
45
+ client_name="c4ai-aya-expanse-chat"
46
+ )
47
+ audio_response_client = cohere.Client(
48
+ api_key=AUDIO_COHERE_API_KEY,
49
+ client_name="c4ai-aya-expanse-audio"
50
+ )
51
+
52
+ # Initialize the Groq client
53
+ groq_client = Groq(api_key=GROQ_API_KEY)
54
+
55
+ # Initialize the ElevenLabs client
56
+ eleven_labs_client = ElevenLabs(
57
+ api_key=ELEVEN_LABS_KEY,
58
+ )
59
+
60
+ # Language identification
61
+ lid_model_path = hf_hub_download(repo_id="facebook/fasttext-language-identification", filename="model.bin")
62
+ LID_model = fasttext.load_model(lid_model_path)
63
+
64
+ def predict_language(text):
65
+ text = re.sub("\n", " ", text)
66
+ label, logit = LID_model.predict(text)
67
+ label = label[0][len("__label__") :]
68
+ print("predicted language:", label)
69
+ return label
70
+
71
+ # Image Generation util functions
72
+ def get_hf_inference_api_response(payload, model_id):
73
+ headers = {"Authorization": f"Bearer {HF_API_TOKEN}"}
74
+ MODEL_API_URL = f"https://api-inference.huggingface.co/models/{model_id}"
75
+ response = requests.post(MODEL_API_URL, headers=headers, json=payload)
76
+ return response.content
77
+
78
+ def replicate_api_inference(input_prompt):
79
+ input_params={
80
+ "prompt": input_prompt,
81
+ "go_fast": True,
82
+ "megapixels": "1",
83
+ "num_outputs": 1,
84
+ "aspect_ratio": "1:1",
85
+ "output_format": "jpg",
86
+ "output_quality": 80,
87
+ "num_inference_steps": 4
88
+ }
89
+ image = replicate.run("black-forest-labs/flux-schnell",input=input_params)
90
+ image = Image.open(image[0])
91
+ return image
92
+
93
+ def generate_image(input_prompt, model_id="black-forest-labs/FLUX.1-schnell"):
94
+ if input_prompt!="":
95
+ if input_prompt=='Image generation blocked for prompts that include humans, kids, or children.':
96
+ return None
97
+ else:
98
+ if USE_REPLICATE:
99
+ print("using replicate for image generation")
100
+ image = replicate_api_inference(input_prompt)
101
+ else:
102
+ try:
103
+ print("using HF inference API for image generation")
104
+ image_bytes = get_hf_inference_api_response({ "inputs": input_prompt}, model_id)
105
+ image = np.array(Image.open(io.BytesIO(image_bytes)))
106
+ except Exception as e:
107
+ print("HF API error:", e)
108
+ # generate image with help replicate in case of error
109
+ image = replicate_api_inference(input_prompt)
110
+ return image
111
+ else:
112
+ return None
113
+
114
+ def generate_img_prompt(input_prompt):
115
+ # clean prompt before doing language detection
116
+ cleaned_prompt = clean_text(input_prompt, remove_bullets=True, remove_newline=True)
117
+ text_lang_code = predict_language(cleaned_prompt)
118
+ language = LID_LANGUAGES[text_lang_code]
119
+
120
+ gr.Info("Generating Image", duration=2)
121
+
122
+ if language!="english":
123
+ text = f"""
124
+ Translate the given input prompt to English.
125
+ Input Prompt: {input_prompt}
126
+ Once translated, use the English version of the prompt to create a detailed image description suitable for a text-to-image model.
127
+ Ensure the description is concise, limited to 2-3 lines, and integrates key elements from the translated prompt.
128
+ Add the prompt English translation to the image description, and respond with that.
129
+ """
130
+ else:
131
+ text = f"""Generate a detailed image description which can be used to generate an image using a text-to-image model based on the given input prompt:
132
+ Input Prompt: {input_prompt}
133
+ Do not use more than 3-4 lines for the description.
134
+ """
135
+
136
+ response = img_prompt_client.chat(message=text, preamble=IMG_DESCRIPTION_PREAMBLE, model=AYA_MODEL_NAME)
137
+ output = response.text
138
+
139
+ return output
140
+
141
+
142
+ # Chat with Aya util functions
143
+
144
+ def trigger_example(example):
145
+ chat, updated_history = generate_aya_chat_response(example)
146
+ return chat, updated_history
147
+
148
+ def generate_aya_chat_response(user_message, cid, token, history=None):
149
+ if not token:
150
+ raise gr.Error("Error loading.")
151
+
152
+ if history is None:
153
+ history = []
154
+ if cid == "" or None:
155
+ cid = str(uuid.uuid4())
156
+
157
+ print(f"cid: {cid} prompt:{user_message}")
158
+
159
+ history.append(user_message)
160
+
161
+ stream = chat_client.chat_stream(message=user_message, preamble=CHAT_PREAMBLE, conversation_id=cid, model=AYA_MODEL_NAME, connectors=[], temperature=0.3)
162
+ output = ""
163
+
164
+ for idx, response in enumerate(stream):
165
+ if response.event_type == "text-generation":
166
+ output += response.text
167
+ if idx == 0:
168
+ history.append(" " + output)
169
+ else:
170
+ history[-1] = output
171
+ chat = [
172
+ (history[i].strip(), history[i + 1].strip())
173
+ for i in range(0, len(history) - 1, 2)
174
+ ]
175
+ yield chat, history, cid
176
+
177
+ return chat, history, cid
178
+
179
+
180
+ def clear_chat():
181
+ return [], [], str(uuid.uuid4())
182
+
183
+ # Audio Pipeline util functions
184
+
185
+ def transcribe_and_stream(inputs, show_info="no", model_name="openai/whisper-large-v3-turbo", language="english"):
186
+ if inputs is not None and inputs!="":
187
+ if show_info=="show_info":
188
+ gr.Info("Processing Audio", duration=1)
189
+ if model_name != "groq_whisper":
190
+ print("DEVICE:", DEVICE)
191
+ pipe = pipeline(
192
+ task="automatic-speech-recognition",
193
+ model=model_name,
194
+ chunk_length_s=30,
195
+ DEVICE=DEVICE)
196
+ text = pipe(inputs, batch_size=BATCH_SIZE, return_timestamps=True)["text"]
197
+ else:
198
+ text = groq_whisper_tts(inputs)
199
+
200
+ # stream text output
201
+ for i in range(len(text)):
202
+ time.sleep(0.01)
203
+ yield text[: i + 10]
204
+ else:
205
+ return ""
206
+
207
+
208
+ def aya_speech_text_response(text):
209
+ if text is not None and text!="":
210
+ stream = audio_response_client.chat_stream(message=text,preamble=AUDIO_RESPONSE_PREAMBLE, model=AYA_MODEL_NAME)
211
+ output = ""
212
+
213
+ for event in stream:
214
+ if event:
215
+ if event.event_type == "text-generation":
216
+ output+=event.text
217
+ cleaned_output = clean_text(output)
218
+ yield cleaned_output
219
+ else:
220
+ return ""
221
+
222
+ def clean_text(text, remove_bullets=False, remove_newline=False):
223
+ # Remove bold formatting
224
+ cleaned_text = re.sub(r"\*\*", "", text)
225
+
226
+ if remove_bullets:
227
+ cleaned_text = re.sub(r"^- ", "", cleaned_text, flags=re.MULTILINE)
228
+
229
+ if remove_newline:
230
+ cleaned_text = re.sub(r"\n", " ", cleaned_text)
231
+
232
+ return cleaned_text
233
+
234
+ def convert_text_to_speech(text, language="english"):
235
+
236
+ # do language detection to determine voice of speech response
237
+ if text is not None and text!="":
238
+ # clean text before doing language detection
239
+ cleaned_text = clean_text(text, remove_bullets=True, remove_newline=True)
240
+ text_lang_code = predict_language(cleaned_text)
241
+ language = LID_LANGUAGES[text_lang_code]
242
+
243
+ if not USE_ELVENLABS:
244
+ if language!= "japanese":
245
+ audio_path = neetsai_tts(text, language)
246
+ else:
247
+ print("DEVICE:", DEVICE)
248
+ # if language is japanese then use XTTS for TTS since neets_ai doesn't support japanese voice
249
+ tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to(DEVICE)
250
+ speaker_wav="samples/ja-sample.wav"
251
+ lang_code="ja"
252
+ audio_path = "./output.wav"
253
+ tts.tts_to_file(text=text, speaker_wav=speaker_wav, language=lang_code, file_path=audio_path)
254
+ else:
255
+ # use elevenlabs for TTS
256
+ audio_path = elevenlabs_generate_audio(text)
257
+
258
+ return audio_path
259
+ else:
260
+ return None
261
+
262
+ def elevenlabs_generate_audio(text):
263
+ audio = eleven_labs_client.generate(
264
+ text=text,
265
+ voice="River",
266
+ model="eleven_turbo_v2_5", #"eleven_multilingual_v2"
267
+ )
268
+ # save audio
269
+ audio_path = "./audio.mp3"
270
+ save(audio, audio_path)
271
+ return audio_path
272
+
273
+ def neetsai_tts(input_text, language):
274
+
275
+ lang_id = NEETS_AI_LANGID_MAP[language]
276
+ neets_vits_voice_id = f"vits-{lang_id}"
277
+
278
+ response = requests.request(
279
+ method="POST",
280
+ url="https://api.neets.ai/v1/tts",
281
+ headers={
282
+ "Content-Type": "application/json",
283
+ "X-API-Key": NEETS_AI_API_KEY
284
+ },
285
+ json={
286
+ "text": input_text,
287
+ "voice_id": neets_vits_voice_id,
288
+ "params": {
289
+ "model": "vits"
290
+ }
291
+ }
292
+ )
293
+ # save audio file
294
+ audio_path = "neets_demo.mp3"
295
+ with open(audio_path, "wb") as f:
296
+ f.write(response.content)
297
+ return audio_path
298
+
299
+ def groq_whisper_tts(filename):
300
+ with open(filename, "rb") as file:
301
+ transcriptions = groq_client.audio.transcriptions.create(
302
+ file=(filename, file.read()),
303
+ model="whisper-large-v3-turbo",
304
+ response_format="json",
305
+ temperature=0.0
306
+ )
307
+ print("transcribed text:", transcriptions.text)
308
+ print("********************************")
309
+ return transcriptions.text
310
+
311
+
312
+ # setup gradio app theme
313
+ theme = gr.themes.Base(
314
+ primary_hue=gr.themes.colors.teal,
315
+ secondary_hue=gr.themes.colors.blue,
316
+ neutral_hue=gr.themes.colors.gray,
317
+ text_size=gr.themes.sizes.text_lg,
318
+ ).set(
319
+ # Primary Button Color
320
+ button_primary_background_fill="#114A56",
321
+ button_primary_background_fill_hover="#114A56",
322
+ # Block Labels
323
+ block_title_text_weight="600",
324
+ block_label_text_weight="600",
325
+ block_label_text_size="*text_md",
326
+ )
327
+
328
+
329
+ demo = gr.Blocks(theme=theme, analytics_enabled=False)
330
+
331
+ with demo:
332
+ with gr.Row(variant="panel"):
333
+ with gr.Column(scale=1):
334
+ gr.Image("aya-expanse.png", elem_id="logo-img", show_label=False, show_share_button=False, show_download_button=False, show_fullscreen_button=False)
335
+ with gr.Column(scale=30):
336
+ gr.Markdown("""C4AI Aya Expanse is a state-of-art model with highly advanced capabilities to connect the world across languages.
337
+ <br/>
338
+ You can use this space to chat, speak and visualize with Aya Expanse in 23 languages.
339
+
340
+ **Developed by**: [Cohere for AI](https://cohere.com/research) and [Cohere](https://cohere.com/)
341
+ """
342
+ )
343
+ # Text Chat
344
+ with gr.TabItem("Chat with Aya") as chat_with_aya:
345
+ cid = gr.State("")
346
+ token = gr.State(value=None)
347
+
348
+ with gr.Column():
349
+ with gr.Row():
350
+ chatbot = gr.Chatbot(show_label=False, show_share_button=False, show_copy_button=True, height=300)
351
+
352
+ with gr.Row():
353
+ user_message = gr.Textbox(lines=1, placeholder="Ask anything in our 23 languages ...", label="Input", show_label=False)
354
+
355
+
356
+ with gr.Row():
357
+ submit_button = gr.Button("Submit",variant="primary")
358
+ clear_button = gr.Button("Clear")
359
+
360
+
361
+ history = gr.State([])
362
+
363
+ user_message.submit(fn=generate_aya_chat_response, inputs=[user_message, cid, token, history], outputs=[chatbot, history, cid], concurrency_limit=32)
364
+ submit_button.click(fn=generate_aya_chat_response, inputs=[user_message, cid, token, history], outputs=[chatbot, history, cid], concurrency_limit=32)
365
+
366
+ clear_button.click(fn=clear_chat, inputs=None, outputs=[chatbot, history, cid], concurrency_limit=32)
367
+
368
+ user_message.submit(lambda x: gr.update(value=""), None, [user_message], queue=False)
369
+ submit_button.click(lambda x: gr.update(value=""), None, [user_message], queue=False)
370
+ clear_button.click(lambda x: gr.update(value=""), None, [user_message], queue=False)
371
+
372
+ with gr.Row():
373
+ gr.Examples(
374
+ examples=TEXT_CHAT_EXAMPLES,
375
+ inputs=user_message,
376
+ cache_examples=False,
377
+ fn=trigger_example,
378
+ outputs=[chatbot],
379
+ examples_per_page=25,
380
+ label="Load example prompt for:",
381
+ example_labels=TEXT_CHAT_EXAMPLES_LABELS,
382
+ )
383
+
384
+ # Audio Pipeline
385
+ with gr.TabItem("Speak with Aya") as speak_with_aya:
386
+
387
+ with gr.Row():
388
+ with gr.Column():
389
+ e2e_audio_file = gr.Audio(sources="microphone", type="filepath", min_length=None)
390
+
391
+ clear_button_microphone = gr.ClearButton()
392
+ gr.Examples(
393
+ examples=AUDIO_EXAMPLES,
394
+ inputs=e2e_audio_file,
395
+ cache_examples=False,
396
+ examples_per_page=25,
397
+ label="Load example audio for:",
398
+ example_labels=AUDIO_EXAMPLES_LABELS,
399
+ )
400
+
401
+ with gr.Column():
402
+ e2e_audio_file_trans = gr.Textbox(lines=3,label="Your Input", autoscroll=False, show_copy_button=True, interactive=False)
403
+ e2e_audio_file_aya_response = gr.Textbox(lines=3,label="Aya's Response", show_copy_button=True, container=True, interactive=False)
404
+ e2e_aya_audio_response = gr.Audio(type="filepath", label="Aya's Audio Response")
405
+
406
+ show_info = gr.Textbox(value="show_info", visible=False)
407
+ stt_model = gr.Textbox(value="groq_whisper", visible=False)
408
+
409
+ with gr.Accordion("See Details", open=False):
410
+ gr.Markdown("To enable voice interaction with Aya Expanse, this space uses [Whisper large-v3-turbo](https://huggingface.co/openai/whisper-large-v3-turbo) and [Groq](https://groq.com/) for STT and [neets.ai](http://neets.ai/) for TTS.")
411
+
412
+
413
+ # Image Generation
414
+ with gr.TabItem("Visualize with Aya") as visualize_with_aya:
415
+ with gr.Row():
416
+ with gr.Column():
417
+ input_img_prompt = gr.Textbox(placeholder="Ask anything in our 23 languages ...", label="Describe an image", lines=3)
418
+ # generated_img_desc = gr.Textbox(label="Image Description generated by Aya", interactive=False, lines=3, visible=False)
419
+ submit_button_img = gr.Button(value="Submit", variant="primary")
420
+ clear_button_img = gr.ClearButton()
421
+
422
+
423
+ with gr.Column():
424
+ generated_img = gr.Image(label="Generated Image", interactive=False)
425
+
426
+ with gr.Row():
427
+ gr.Examples(
428
+ examples=IMG_GEN_PROMPT_EXAMPLES,
429
+ inputs=input_img_prompt,
430
+ cache_examples=False,
431
+ examples_per_page=25,
432
+ label="Load example prompt for:",
433
+ example_labels=IMG_GEN_PROMPT_EXAMPLES_LABELS
434
+ )
435
+ generated_img_desc = gr.Textbox(label="Image Description generated by Aya", interactive=False, lines=3, visible=False)
436
+
437
+ # increase spacing between examples and Accordion components
438
+ with gr.Row():
439
+ pass
440
+ with gr.Row():
441
+ pass
442
+ with gr.Row():
443
+ pass
444
+
445
+ with gr.Row():
446
+ with gr.Accordion("See Details", open=False):
447
+ gr.Markdown("This space uses Aya Expanse for translating multilingual prompts and generating detailed image descriptions and [Flux Schnell](https://huggingface.co/black-forest-labs/FLUX.1-schnell) for Image Generation.")
448
+
449
+ # Image Generation
450
+ clear_button_img.click(lambda: None, None, input_img_prompt)
451
+ clear_button_img.click(lambda: None, None, generated_img_desc)
452
+ clear_button_img.click(lambda: None, None, generated_img)
453
+
454
+ submit_button_img.click(
455
+ generate_img_prompt,
456
+ inputs=[input_img_prompt],
457
+ outputs=[generated_img_desc],
458
+ )
459
+
460
+ generated_img_desc.change(
461
+ generate_image, #run_flux,
462
+ inputs=[generated_img_desc],
463
+ outputs=[generated_img],
464
+ show_progress="hidden",
465
+ )
466
+
467
+ # Audio Pipeline
468
+ clear_button_microphone.click(lambda: None, None, e2e_audio_file)
469
+ clear_button_microphone.click(lambda: None, None, e2e_audio_file_trans)
470
+ clear_button_microphone.click(lambda: None, None, e2e_aya_audio_response)
471
+
472
+ e2e_audio_file.change(
473
+ transcribe_and_stream,
474
+ inputs=[e2e_audio_file, show_info, stt_model],
475
+ outputs=[e2e_audio_file_trans],
476
+ show_progress="hidden",
477
+ ).then(
478
+ aya_speech_text_response,
479
+ inputs=[e2e_audio_file_trans],
480
+ outputs=[e2e_audio_file_aya_response],
481
+ show_progress="minimal",
482
+ ).then(
483
+ convert_text_to_speech,
484
+ inputs=[e2e_audio_file_aya_response],
485
+ outputs=[e2e_aya_audio_response],
486
+ show_progress="minimal",
487
+ )
488
+
489
+ demo.load(lambda: secrets.token_hex(16), None, token)
490
+
491
+ demo.queue(api_open=False, max_size=40).launch(show_api=False, allowed_paths=['/home/user/app'])