import os

import gradio as gr
import librosa  # added dependency: loads and resamples the uploaded audio for Whisper
from gtts import gTTS
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    WhisperForConditionalGeneration,
    WhisperProcessor,
    pipeline,
)
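
# Assumed runtime dependencies for this Space (requirements.txt): gradio, transformers,
# torch, accelerate (needed for device_map="auto"), gTTS, librosa; the mpg321 player
# must be installed on the host for speak() to produce sound.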


class InteractiveChat:
    def __init__(self):
        # Speech-to-text: Whisper transcribes the spoken question.
        self.whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-large")
        self.whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large")
        # Text generation: Zephyr writes the chat reply.
        self.zephyr_tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
        self.zephyr_model = AutoModelForCausalLM.from_pretrained("HuggingFaceH4/zephyr-7b-beta", device_map="auto")

    def generate_response(self, audio_path):
        # Whisper expects 16 kHz mono audio; librosa loads and resamples the recording.
        audio, _ = librosa.load(audio_path, sr=16000)
        input_features = self.whisper_processor(audio, sampling_rate=16000, return_tensors="pt").input_features
        predicted_ids = self.whisper_model.generate(input_features)
        transcription = self.whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
        response = self.get_zephyr_response(transcription)
        self.speak(response)
        return response

    def get_zephyr_response(self, transcription):
        # Reuse the model and tokenizer loaded in __init__ instead of downloading a default model.
        zephyr_pipeline = pipeline("text-generation", model=self.zephyr_model, tokenizer=self.zephyr_tokenizer)
        result = zephyr_pipeline(transcription, max_new_tokens=256, return_full_text=False)
        return result[0]["generated_text"]

    def speak(self, text):
        # Synthesize the reply with gTTS and play it on the host machine.
        tts = gTTS(text=text, lang='en')
        tts.save("output.mp3")
        os.system("mpg321 output.mp3")
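
# Note: speak() plays audio on the machine running this script. On a headless host
# (e.g. a hosted Space) there is no audio device; one alternative would be to return
# "output.mp3" from the handler and expose it through an additional gr.Audio output.
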
# Create an instance of the InteractiveChat class
chat = InteractiveChat()

# Define a function that wraps the generate_response method
def generate_response_fn(input_data):
    return chat.generate_response(input_data)

# Use the function in gr.Interface
interface = gr.Interface(
    fn=generate_response_fn,           # handler: audio in, text reply out
    inputs=gr.Audio(type="filepath"),  # accept recorded or uploaded audio as a file path
    outputs=gr.Textbox(),
)
interface.launch()