sasha / app.py
PolarO3O's picture
Update app.py
e1f30ae verified
raw
history blame contribute delete
No virus
4.73 kB
import gradio as gr
from transformers import pipeline
import requests
import json
import edge_tts
import asyncio
import tempfile
import os
import inflect
from huggingface_hub import InferenceClient
import re
import time
from streaming_stt_nemo import Model
# Language code of the speech-to-text engine that is created at import time.
default_lang = "en"

# Speech-to-text engines keyed by language code. Only the default language
# gets an entry here; `transcribe` looks its engine up in this mapping.
engines = {}
engines[default_lang] = Model(default_lang)
def transcribe(audio):
    """Run speech-to-text on ``audio`` and return the first result.

    The ``"en"`` entry of the module-level ``engines`` mapping is used
    regardless of input; ``audio`` is handed to that engine's ``stt_file``
    method and element ``0`` of whatever it returns is passed back to the
    caller (the caller treats it as the transcript text).
    """
    english_engine = engines["en"]
    results = english_engine.stt_file(audio)
    return results[0]
# Prompt prefix and inference client read by `model` at call time.
# NOTE: both `system_instructions1` and `client1` are assigned again further
# down this file (next to `base_loaded`), so by the time a UI handler calls
# `model`, the later values are the ones in effect.
system_instructions1 = "<s> Sasha is a friendly and enthusiastic person, born in Florida, United States. She enjoys casual and fun conversations. She responds in a brief and natural way, like a real person would. Feel free to ask questions, tell jokes, or just chat with her. [USER]"
client1 = InferenceClient(model="mistralai/Mixtral-8x7B-Instruct-v0.1")
def model(text):
    """Return the streamed completion for ``text`` as a single string.

    The prompt is ``system_instructions1 + text + "[SASHA]"``. It is sent to
    the module-level ``client1`` with streaming enabled, and the text of every
    streamed token is concatenated, skipping tokens whose text is exactly
    ``"</s>"``.
    """
    sampling = {
        "temperature": 1.0,
        "max_new_tokens": 128,  # def 512 very long sometimes, 256 semi large
        "top_p": 0.95,
        "repetition_penalty": 0.9,
        "do_sample": True,
        "seed": 42,
    }
    prompt = "".join([system_instructions1, text, "[SASHA]"])
    token_stream = client1.text_generation(
        prompt,
        stream=True,
        details=True,
        return_full_text=False,
        **sampling,
    )
    # Collect the token texts in one pass instead of growing a string.
    return "".join(
        chunk.token.text
        for chunk in token_stream
        if chunk.token.text != "</s>"
    )
async def respond(audio):
    """Turn a recorded question into a spoken reply.

    Pipeline: ``transcribe(audio)`` -> ``model(...)`` -> edge-tts synthesis
    saved to a temporary file whose path is yielded.

    Args:
        audio: Path of the recorded audio file, or ``None`` when the Audio
            input currently holds no recording (e.g. it was cleared).

    Yields:
        The path of the synthesized audio file, or ``None`` when there is
        nothing to say (no recording, or the model returned only whitespace).
        The temporary file is created with ``delete=False`` and is not
        removed here.
    """
    # The handler is wired with live=True, so it can be invoked without a
    # recording; bail out instead of passing None to the STT engine.
    if audio is None:
        yield None
        return

    user = transcribe(audio)
    reply = model(user)

    # Nothing speakable came back: skip synthesis rather than sending blank
    # text to the TTS service.
    if not reply.strip():
        yield None
        return

    communicate = edge_tts.Communicate(reply)
    # edge-tts writes MP3 data, so give the file a matching extension.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
        tmp_path = tmp_file.name
    # The handle is closed before saving so edge-tts can reopen the path on
    # platforms that do not allow a second open of a live temp file.
    await communicate.save(tmp_path)
    yield tmp_path
# Markdown snippets for the page. DESCRIPTION, FAST and MORE are passed to
# gr.Markdown in the UI section of this file.
# Page header.
DESCRIPTION = """ # <center><b>SASHA ✨</b></center>
### <center>Your AI Assistant, Sasha Bingham</center>
### <center>Let's chat!</center>
"""
# Footer with links to other Spaces.
MORE = """ ## TRY Other Models
### Instant Video: Create Amazing Videos in 5 Second -> https://huggingface.co/spaces/KingNish/Instant-Video
### Instant Image: 4k images in 5 Second -> https://huggingface.co/spaces/KingNish/Instant-Image
"""
# NOTE(review): BETA, Complex and Detail are not referenced anywhere else in
# the visible file — confirm whether they are still needed.
BETA = """ ### Voice Chat (BETA)"""
# Heading shown above the text-prompt section.
FAST = """## Fastest Model"""
Complex = """## Best in Complex Question"""
Detail = """## Best for Detailed Generation or Long Answers"""
# Model id served through the Hugging Face inference client.
base_loaded = "mistralai/Mixtral-8x7B-Instruct-v0.1"

# Prompt prefix prepended to every user message by the generation functions.
system_instructions1 = "[SYSTEM] Sasha is a friendly and enthusiastic person, born in Florida, United States. She enjoys casual and fun conversations. She responds in a brief and natural way, like a real person would. Feel free to ask questions, tell jokes, or just chat with her. [USER]"

# Shared text-generation client. This rebinds the `client1` name assigned
# earlier in the file.
client1 = InferenceClient(model=base_loaded)
async def generate1(prompt):
    """Generate a reply to a typed prompt and synthesize it to speech.

    The prompt is wrapped as ``system_instructions1 + prompt + "[SASHA]"``,
    streamed through the module-level ``client1``, and the collected text is
    spoken with edge-tts into a temporary file.

    Args:
        prompt: The user's text.

    Yields:
        The path of the synthesized audio file, or ``None`` when the model
        produced only whitespace. The temporary file is created with
        ``delete=False`` and is not removed here.
    """
    generate_kwargs = dict(
        temperature=1.0,
        max_new_tokens=128,  # def 512 very long sometimes, 256 semi large
        top_p=0.95,
        repetition_penalty=0.9,
        do_sample=True,
        seed=42,
    )
    formatted_prompt = system_instructions1 + prompt + "[SASHA]"
    # return_full_text=False matches `model()`: only the newly generated text
    # should ever reach the TTS engine, never the prompt itself.
    stream = client1.text_generation(
        formatted_prompt,
        **generate_kwargs,
        stream=True,
        details=True,
        return_full_text=False,
    )
    # Join the streamed token texts, dropping the end-of-sequence marker.
    output = "".join(
        response.token.text
        for response in stream
        if response.token.text != "</s>"
    )

    # Nothing speakable came back: skip synthesis rather than sending blank
    # text to the TTS service.
    if not output.strip():
        yield None
        return

    communicate = edge_tts.Communicate(output)
    # edge-tts writes MP3 data, so give the file a matching extension.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
        tmp_path = tmp_file.name
    # The handle is closed before saving so edge-tts can reopen the path on
    # platforms that do not allow a second open of a live temp file.
    await communicate.save(tmp_path)
    yield tmp_path
# Gradio UI definition, bound to the name `demo`.
# NOTE(review): this copy of the file has lost its indentation, so which
# statements belong inside each `with` block cannot be read off the text
# below — restore the nesting from the original before running.
with gr.Blocks(css="style.css") as demo:
gr.Markdown(DESCRIPTION)
with gr.Row():
# Voice section: a microphone-sourced Audio input and a non-interactive,
# autoplaying Audio output, both exchanging file paths (type="filepath").
# NOTE(review): the name `input` shadows the Python builtin of the same name.
input = gr.Audio(label="Voice Chat (BETA)", sources="microphone", type="filepath", waveform_options=False)
output = gr.Audio(label="SASHA", type="filepath",
interactive=False,
autoplay=True,
elem_classes="audio")
# Connects `respond` to the two Audio components above. live=True is set, so
# presumably the handler runs on input changes rather than on a submit
# click — verify against the Gradio version in use.
gr.Interface(
fn=respond,
inputs=[input],
outputs=[output], live=True)
gr.Markdown(FAST)
with gr.Row():
# Text section: a prompt box pre-filled with an example question.
user_input = gr.Textbox(label="Prompt", value="What is Wikipedia")
# NOTE(review): `input_text` is not used as an input or output of any handler
# in the visible code.
input_text = gr.Textbox(label="Input Text", elem_id="important")
output_audio = gr.Audio(label="SASHA", type="filepath",
interactive=False,
autoplay=True,
elem_classes="audio")
with gr.Row():
# Clicking the button calls `generate1` with the prompt text and sends the
# result to `output_audio`; the endpoint is registered with
# api_name="translate".
translate_btn = gr.Button("Response")
translate_btn.click(fn=generate1, inputs=user_input,
outputs=output_audio, api_name="translate")
gr.Markdown(MORE)
# Script entry point: enable the request queue (capped at 200 pending
# requests) and start the Gradio server.
if __name__ == "__main__":
    queued_demo = demo.queue(max_size=200)
    queued_demo.launch()