# wanderJoy / app.py
# Author: HakimHa — commit 9c0c186 ("Update app.py"), 2.36 kB
import gradio as gr
from PIL import Image
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, AutoModelForCausalLM, AutoTokenizer
import soundfile as sf
import torch
# Conversational model (DialoGPT) used to generate all chatbot replies.
model_name_or_path = "microsoft/DialoGPT-large"
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, padding_side="left", use_fast=False)
# Reuse EOS as the pad token: the GPT-2 tokenizer DialoGPT ships with defines none.
tokenizer.pad_token = tokenizer.eos_token
# DialoGPT uses the stock GPT-2 architecture bundled with transformers, so
# trust_remote_code is not needed; enabling it would let the Hub repository
# execute arbitrary Python code at load time.
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    torch_dtype=torch.float32,
    device_map="auto",
)
# Load pre-trained model and processor for Wav2Vec2 (speech-to-text).
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
wav2vec2_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
# Function to handle text input
def handle_text(text):
    """Generate a single-turn DialoGPT reply to ``text``.

    Args:
        text: The user's message. ``None``, ``""`` or whitespace-only input
            returns ``""`` instead of prompting the model with nothing.

    Returns:
        The decoded reply (prompt tokens and special tokens removed).
    """
    if not text or not text.strip():
        return ""
    encoded = tokenizer(text + tokenizer.eos_token, return_tensors="pt")
    # Inputs must be on the model's device: device_map="auto" may have placed
    # the model on a GPU while the tokenizer always produces CPU tensors.
    input_ids = encoded["input_ids"].to(model.device)
    # Explicit mask: pad_token == eos_token, so generate() cannot tell padding
    # from the trailing EOS on its own.
    attention_mask = encoded["attention_mask"].to(model.device)
    chat_history_ids = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_length=1000,
        pad_token_id=tokenizer.eos_token_id,
    )
    # generate() returns prompt + continuation; keep only the new tokens.
    reply_ids = chat_history_ids[0, input_ids.shape[-1]:]
    return tokenizer.decode(reply_ids, skip_special_tokens=True)
# Function to handle image input
def handle_image(img):
    """Reply to an uploaded image with a fixed placeholder comment.

    The image content is not inspected yet; ``img`` is accepted only so this
    handler takes a single input like the text and audio handlers.
    """
    placeholder_reply = "This image seems nice!"
    return placeholder_reply
def _load_mono_audio(audio, target_sr):
    """Return ``audio`` as a 1-D float32 numpy waveform sampled at ``target_sr`` Hz.

    ``audio`` is either a path readable by ``soundfile`` or a
    ``(sample_rate, samples)`` tuple (what Gradio's ``Audio`` component yields
    with ``type="numpy"``). Multi-channel audio is averaged down to mono.
    """
    if isinstance(audio, (tuple, list)):
        sample_rate, samples = audio
    else:
        samples, sample_rate = sf.read(audio, dtype="float32")
    waveform = torch.as_tensor(samples)
    if not waveform.is_floating_point():
        # Integer PCM (e.g. int16 microphone data): scale into [-1, 1).
        waveform = waveform.to(torch.float32) / (torch.iinfo(waveform.dtype).max + 1)
    waveform = waveform.to(torch.float32)
    if waveform.dim() > 1:
        # soundfile and Gradio both lay multi-channel audio out as (frames, channels).
        waveform = waveform.mean(dim=-1)
    if sample_rate != target_sr and waveform.numel() > 1:
        # Linear-interpolation resampling (no anti-aliasing filter): crude, but
        # far better than feeding e.g. 48 kHz audio to a model expecting 16 kHz.
        new_length = max(1, round(waveform.numel() * target_sr / sample_rate))
        waveform = torch.nn.functional.interpolate(
            waveform.reshape(1, 1, -1),
            size=new_length,
            mode="linear",
            align_corners=False,
        ).reshape(-1)
    return waveform.numpy()


# Function to handle audio input
def handle_audio(audio):
    """Transcribe speech with Wav2Vec2, then reply to the transcript via ``handle_text``.

    Args:
        audio: Path to an audio file, or a ``(sample_rate, samples)`` tuple.

    Returns:
        The chatbot's reply to the transcript, or ``""`` when the recording is
        empty or the decoded transcript is empty.
    """
    # Resample to the rate the Wav2Vec2 feature extractor was configured for.
    target_sr = processor.feature_extractor.sampling_rate
    speech = _load_mono_audio(audio, target_sr)
    if speech.size == 0:
        return ""
    input_values = processor(speech, sampling_rate=target_sr, return_tensors="pt").input_values
    # Pure inference: skip autograd bookkeeping to save memory.
    with torch.no_grad():
        logits = wav2vec2_model(input_values).logits
    # Greedy CTC decoding: most likely token per frame.
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.decode(predicted_ids[0]).strip()
    if not transcription:
        return ""
    return handle_text(transcription)
def chatbot(text, img, audio):
    """Answer whichever inputs were provided and join the replies with newlines.

    Args:
        text: Typed message. ``None``, ``""`` and whitespace-only strings all
            count as "no text" (an untouched Gradio Textbox submits ``""``,
            not ``None``).
        img: Uploaded image, or ``None`` when nothing was uploaded.
        audio: Recorded audio, or ``None`` when nothing was recorded.

    Returns:
        The non-empty replies in text, image, audio order, one per line;
        ``""`` if no input was given.
    """
    has_text = bool(text) and bool(text.strip())
    text_output = handle_text(text) if has_text else ''
    img_output = handle_image(img) if img is not None else ''
    audio_output = handle_audio(audio) if audio is not None else ''
    outputs = [o for o in [text_output, img_output, audio_output] if o]
    return "\n".join(outputs)
# Gradio UI: three optional inputs fed positionally to `chatbot`
# (text, image, audio) and a single text output.
iface = gr.Interface(
    fn=chatbot,
    inputs=[
        gr.Textbox(lines=2, placeholder="Input Text here..."),
        gr.Image(label="Upload Image"),
        # type="filepath" passes a file path, which is what handle_audio's
        # sf.read() call expects; the default type="numpy" would instead
        # pass a (sample_rate, data) tuple.
        gr.Audio(source="microphone", type="filepath", label="Audio Input"),
    ],
    outputs=gr.Textbox(label="Output"),
    title="Multimodal Chatbot",
    description="This chatbot can handle text, image, and audio inputs. Try it out!",
)
iface.launch()