import time
import torch
import spaces
import numpy as np
import gradio as gr
from gtts import gTTS
from transformers import pipeline
from huggingface_hub import InferenceClient

# Model names
ASR_MODEL_NAME = "openai/whisper-small"
LLM_MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.2"

# Initial system prompt
system_prompt = """<s>[INST] You are Friday, a helpful and conversational AI assistant, and you respond with one to two sentences. [/INST] Hello there! I'm Friday, how can I help you?</s>"""

# Global variables for conversation history
instruct_history = system_prompt
formatted_history = ""

# Create inference client for text generation
client = InferenceClient(LLM_MODEL_NAME)
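# Requests go through the Hugging Face Inference API; a gated or rate-limited
# model may require an access token, e.g.
#   client = InferenceClient(LLM_MODEL_NAME, token="hf_...")
# (passing a token is an assumption, not part of this app).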

# Device for the ASR pipeline: first GPU if available, otherwise CPU
device = 0 if torch.cuda.is_available() else "cpu"

# ASR pipeline
pipe = pipeline(
    task="automatic-speech-recognition",
    model=ASR_MODEL_NAME,
    device=device,
)
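
# The ASR pipeline accepts raw audio as a dict; a minimal sketch with dummy
# (silent) input, assuming 16 kHz mono float32 samples:
#   pipe({"sampling_rate": 16000, "raw": np.zeros(16000, dtype=np.float32)})["text"]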


def generate(instruct_history, temperature=0.1, max_new_tokens=128, top_p=0.95, repetition_penalty=1.0):
    # Clamp temperature away from zero, since sampling is undefined at 0
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )

    output = client.text_generation(
        instruct_history, **generate_kwargs, stream=False, details=False, return_full_text=False
    )

    return output
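

# Example call with a hypothetical prompt (reuses the sampling defaults above):
#   generate("<s>[INST] What can you do? [/INST] ")

# On ZeroGPU Spaces the inference handler typically carries the @spaces.GPU
# decorator; applying it to transcribe() is an assumption here, suggested by
# the otherwise-unused `spaces` import.
@spaces.GPU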
def transcribe(audio, past_history):
    global instruct_history, formatted_history
    time.sleep(1)
    sr, y = audio

    # Normalize the raw waveform to [-1, 1] before handing it to Whisper,
    # guarding against division by zero on silent input
    y = y.astype(np.float32)
    max_abs = np.max(np.abs(y))
    if max_abs > 0:
        y /= max_abs

    transcribed_user_audio = pipe({"sampling_rate": sr, "raw": y})["text"]

    # The global formatted_history already accumulates the whole conversation;
    # appending past_history (the textbox contents) on top of it would
    # duplicate the history on every turn, so only the new turn is added
    formatted_history += f"🧑 Human: {transcribed_user_audio}\n\n"
    instruct_history += f"<s>[INST] {transcribed_user_audio} [/INST] "

    # Generate LLM response
    llm_response = generate(instruct_history)
    instruct_history += f" {llm_response}</s>"
    formatted_history += f"🤖 Friday: {llm_response}\n\n"

    # Convert the AI response to audio
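    # gTTS defaults to English; the language can be pinned explicitly, e.g.
    # gTTS(llm_response, lang="en") (the lang value is illustrative)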
    audio_response = gTTS(llm_response)
    audio_response.save("response.mp3")

    print("Formatted History: ", formatted_history)

    # Return the full conversation history
    return "response.mp3", formatted_history


def clear_history(_):
    # `global` is needed so the module-level state is reset, not a local
    # shadow; the textbox input is ignored
    global instruct_history, formatted_history
    instruct_history = system_prompt
    formatted_history = ""
    return formatted_history


with gr.Blocks() as demo:
    gr.HTML("<center><h1>Friday: AI Virtual Assistant 🤖</h1></center>")

    with gr.Row():
        audio_input = gr.Audio(label="Human", sources=["microphone"])
        output_audio = gr.Audio(label="Friday", type="filepath", interactive=False, autoplay=True, elem_classes="audio")

    with gr.Row():
        send_btn = gr.Button("📤 Send")
        clear_btn = gr.Button("🗑️ Clear")

    # Textbox to display the full conversation history
    transcription_box = gr.Textbox(label="Transcription", lines=10, placeholder="Conversation History...")

    send_btn.click(fn=transcribe, inputs=[audio_input, transcription_box], outputs=[output_audio, transcription_box])
    clear_btn.click(fn=clear_history, inputs=[transcription_box], outputs=[transcription_box])

if __name__ == "__main__":
    demo.queue()
    demo.launch()