# app.py — Hugging Face Space "Chat": voice assistant pipeline
# (Whisper speech-to-text -> Zephyr text generation -> gTTS speech output).
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, WhisperProcessor, WhisperForConditionalGeneration
from gtts import gTTS
import os
class InteractiveChat:
    """Voice chat: transcribe speech with Whisper, reply with Zephyr, speak via gTTS."""

    def __init__(self):
        # Speech-to-text (Whisper) and text-generation (Zephyr) models,
        # loaded once so every request reuses the same weights.
        self.whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-large")
        self.whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large")
        self.zephyr_tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
        self.zephyr_model = AutoModelForCausalLM.from_pretrained("HuggingFaceH4/zephyr-7b-beta", device_map="auto")

    def generate_response(self, input_data):
        """Transcribe *input_data* audio, generate a text reply, and speak it.

        Accepts either a raw sample array or a Gradio-style
        ``(sampling_rate, samples)`` tuple. Returns the reply text.
        """
        if isinstance(input_data, tuple):
            sampling_rate, samples = input_data
        else:
            # NOTE(review): assumes 16 kHz when no rate accompanies the
            # samples (Whisper's training rate) — confirm with the caller.
            sampling_rate, samples = 16000, input_data
        # BUG FIX: the processor returns a BatchFeature — the model needs its
        # `input_features` tensor, and the processor must be told the sampling
        # rate and to return PyTorch tensors.
        input_features = self.whisper_processor(
            samples, sampling_rate=sampling_rate, return_tensors="pt"
        ).input_features
        predicted_ids = self.whisper_model.generate(input_features)
        # BUG FIX: batch_decode returns a list of strings; take the first
        # (single-utterance) result and drop special tokens.
        transcription = self.whisper_processor.batch_decode(
            predicted_ids, skip_special_tokens=True
        )[0]
        response = self.get_zephyr_response(transcription)
        self.speak(response)
        return response

    def get_zephyr_response(self, transcription):
        """Generate a text reply to *transcription* with the preloaded Zephyr model.

        BUG FIX: the original built a default ``pipeline("text-generation")``
        on every call, ignoring the Zephyr model and tokenizer loaded in
        ``__init__``; wire them into the pipeline explicitly.
        """
        zephyr_pipeline = pipeline(
            "text-generation",
            model=self.zephyr_model,
            tokenizer=self.zephyr_tokenizer,
        )
        return zephyr_pipeline(transcription)[0]["generated_text"]

    def speak(self, text):
        """Synthesize *text* to output.mp3 and play it with mpg321."""
        tts = gTTS(text=text, lang='en')
        tts.save("output.mp3")
        # NOTE(review): plays on the *server*, not in the browser, and
        # requires the mpg321 binary to be installed — confirm intent.
        os.system("mpg321 output.mp3")
# Create a single shared chat instance (loads all models once at startup).
chat = InteractiveChat()

def generate_response_fn(input_data):
    """Gradio-facing wrapper around InteractiveChat.generate_response."""
    return chat.generate_response(input_data)

# BUG FIX: gr.Interface's positional order is (fn, inputs, outputs) — the
# original passed the input component first and the callback last, so Gradio
# received components where it expected the function. Keyword arguments make
# the wiring explicit. Audio is requested as numpy (sampling_rate, samples)
# so the Whisper processor gets raw samples rather than a file path.
interface = gr.Interface(
    fn=generate_response_fn,
    inputs=gr.Audio(type="numpy"),
    outputs=gr.Textbox(),
)
interface.launch()