|
import gradio as gr |
|
from PIL import Image |
|
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC, AutoModelForCausalLM, AutoTokenizer |
|
import soundfile as sf |
|
import torch |
|
|
|
# Conversational model used to generate chat replies.
model_name_or_path = "microsoft/DialoGPT-large"

# Left padding keeps the prompt flush against the generated tokens; the
# tokenizer's EOS token is reused as the pad token.
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, padding_side="left", use_fast=False)
tokenizer.pad_token = tokenizer.eos_token

# DialoGPT uses the stock GPT-2 architecture that transformers supports
# natively, so trust_remote_code is not needed. Enabling it would allow the
# hub repository to execute arbitrary Python on this machine at load time.
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    torch_dtype=torch.float32,
    device_map="auto",  # may place the weights on a GPU when one is available
)

# Speech-to-text: wav2vec2 CTC model plus its matching processor
# (feature extractor + tokenizer for decoding predicted ids).
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
wav2vec2_model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
|
|
|
|
|
def handle_text(text):
    """Generate a single-turn DialoGPT reply to ``text``.

    The message is terminated with the tokenizer's EOS token, fed to
    ``model.generate``, and only the newly generated tokens are decoded.

    Args:
        text: The user's message.

    Returns:
        The model's reply with special tokens stripped.
    """
    encoded = tokenizer(text + tokenizer.eos_token, return_tensors="pt")
    # The model was loaded with device_map="auto" and may live on a GPU; the
    # inputs must be on the same device or generate() raises a mismatch error.
    input_ids = encoded["input_ids"].to(model.device)
    attention_mask = encoded["attention_mask"].to(model.device)

    chat_history_ids = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_length=1000,
        pad_token_id=tokenizer.eos_token_id,
    )
    # generate() returns prompt + continuation; keep only the continuation.
    reply_ids = chat_history_ids[0, input_ids.shape[-1]:]
    return tokenizer.decode(reply_ids, skip_special_tokens=True)
|
|
|
|
|
def handle_image(img):
    """Return a fixed placeholder reply for an uploaded image.

    The image itself is not inspected; ``img`` is accepted only so the
    signature matches the other input handlers.
    """
    canned_reply = "This image seems nice!"
    return canned_reply
|
|
|
|
|
def _load_waveform(audio):
    """Return ``(mono float32 waveform tensor, sample rate)`` for ``audio``.

    Accepts either something ``sf.read`` can open (e.g. a file path) or the
    ``(sample_rate, samples)`` tuple Gradio's Audio component produces with
    its default ``type="numpy"``.
    """
    if isinstance(audio, tuple):
        sample_rate, samples = audio
    else:
        samples, sample_rate = sf.read(audio)
    wave = torch.as_tensor(samples)
    if not wave.is_floating_point():
        # Integer PCM (e.g. int16 from a microphone) -> floats in [-1, 1).
        wave = wave.float() / (float(torch.iinfo(wave.dtype).max) + 1.0)
    wave = wave.float()
    if wave.dim() == 2:
        # Multi-channel audio arrives as (frames, channels); average to mono.
        wave = wave.mean(dim=1)
    return wave, int(sample_rate)


def _resample(wave, orig_sr, target_sr):
    """Linearly resample a 1-D waveform from ``orig_sr`` to ``target_sr`` Hz.

    NOTE: plain linear interpolation with no anti-aliasing filter; adequate
    for speech recognition here, but a proper resampler would be cleaner.
    """
    if orig_sr == target_sr or wave.numel() == 0:
        return wave
    new_len = max(1, round(wave.numel() * target_sr / orig_sr))
    return torch.nn.functional.interpolate(
        wave[None, None, :], size=new_len, mode="linear", align_corners=False
    )[0, 0]


def handle_audio(audio):
    """Transcribe ``audio`` with wav2vec2 and reply to the transcript.

    Args:
        audio: A path (or file-like) readable by ``sf.read``, or a
            ``(sample_rate, samples)`` tuple as produced by Gradio.

    Returns:
        The chatbot's reply to the transcribed speech, or ``""`` when the
        recording is empty or nothing was recognised.
    """
    wave, sample_rate = _load_waveform(audio)
    # wav2vec2 was trained at a fixed rate; feed it audio at that rate.
    target_sr = processor.feature_extractor.sampling_rate
    wave = _resample(wave, sample_rate, target_sr)
    if wave.numel() == 0:
        return ""

    input_values = processor(
        wave.numpy(), sampling_rate=target_sr, return_tensors="pt"
    ).input_values
    with torch.no_grad():  # inference only; no autograd graph needed
        logits = wav2vec2_model(input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.decode(predicted_ids[0])

    # Silence decodes to an empty string; don't chat about nothing.
    if not transcription.strip():
        return ""
    return handle_text(transcription)
|
|
|
def chatbot(text, img, audio):
    """Dispatch each provided input to its handler and join the replies.

    Args:
        text: Contents of the textbox. Gradio passes ``""`` rather than
            ``None`` when it is left empty, so emptiness is tested directly.
        img: The uploaded image, or ``None`` when no image was given.
        audio: The recorded audio, or ``None`` when nothing was recorded.

    Returns:
        The non-empty handler outputs (text, then image, then audio) joined
        by newlines; ``""`` when no input produced a reply.
    """
    # An empty or whitespace-only textbox must not trigger a text reply.
    text_output = handle_text(text) if text and text.strip() else ""
    img_output = handle_image(img) if img is not None else ""
    audio_output = handle_audio(audio) if audio is not None else ""

    outputs = [o for o in (text_output, img_output, audio_output) if o]
    return "\n".join(outputs)
|
|
|
|
|
|
|
# Wire the three inputs (text, image, microphone audio) to ``chatbot`` and
# show its combined reply in a single output textbox.
iface = gr.Interface(
    fn=chatbot,
    inputs=[
        gr.inputs.Textbox(lines=2, placeholder="Input Text here..."),
        gr.inputs.Image(label="Upload Image"),
        # type="filepath" hands the audio handler a path that sf.read can
        # open; the default type="numpy" yields a (sample_rate, samples) tuple.
        gr.inputs.Audio(source="microphone", type="filepath", label="Audio Input"),
    ],
    outputs=gr.outputs.Textbox(label="Output"),
    title="Multimodal Chatbot",
    description="This chatbot can handle text, image, and audio inputs. Try it out!",
)

iface.launch()
|
|