# ronniet's picture
# Update app.py
# ce16067
# raw
# history blame
# No virus
# 3.06 kB
import gradio as gr
from transformers import pipeline
import librosa
import numpy as np
import torch
from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech, SpeechT5HifiGan
from transformers import AutoProcessor, AutoModelForCausalLM
# Text-to-speech stack: the SpeechT5 processor and model share one checkpoint,
# and the HiFi-GAN vocoder is loaded from its own checkpoint.
checkpoint = "microsoft/speecht5_tts"
tts_processor = SpeechT5Processor.from_pretrained(checkpoint)
tts_model = SpeechT5ForTextToSpeech.from_pretrained(checkpoint)
vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")

# Visual question answering stack: processor and causal LM are loaded from
# the same GIT TextVQA checkpoint.
vqa_checkpoint = "microsoft/git-base-textvqa"
vqa_processor = AutoProcessor.from_pretrained(vqa_checkpoint)
vqa_model = AutoModelForCausalLM.from_pretrained(vqa_checkpoint)
def tts(text):
    """Synthesize speech for *text* with SpeechT5 using a fixed speaker voice.

    Args:
        text: The text to speak. A list or tuple of strings (for example the
            output of a processor's ``batch_decode``) is also accepted and is
            joined with single spaces before synthesis.

    Returns:
        A ``(sample_rate, samples)`` tuple. ``sample_rate`` is always 16000
        and ``samples`` is an ``np.int16`` array. When there is nothing to
        say (empty or whitespace-only text), ``samples`` is empty.
    """
    # Tolerate batched decoder output so callers need not unwrap it first.
    if isinstance(text, (list, tuple)):
        text = " ".join(text)

    if not text.strip():
        return (16000, np.zeros(0, dtype=np.int16))

    inputs = tts_processor(text=text, return_tensors="pt")

    # Limit input length to what the TTS model's config allows. The bound
    # must be read from ``tts_model``; there is no global named ``model``.
    input_ids = inputs["input_ids"][..., :tts_model.config.max_text_positions]

    # A single pre-computed speaker embedding is used for every request; the
    # .npy file is read relative to the current working directory.
    speaker_embedding = np.load("cmu_us_bdl_arctic-wav-arctic_a0009.npy")
    speaker_embedding = torch.tensor(speaker_embedding).unsqueeze(0)

    speech = tts_model.generate_speech(input_ids, speaker_embedding, vocoder=vocoder)

    # Scale to 16-bit PCM.
    # NOTE(review): assumes the vocoder output lies in [-1, 1] — confirm.
    speech = (speech.numpy() * 32767).astype(np.int16)
    return (16000, speech)
# captioner = pipeline(model="microsoft/git-base")
# tts = TTS(model_name="tts_models/multilingual/multi-dataset/your_tts", progress_bar=False, gpu=False)
def predict(image):
    """Describe *image* as text and as spoken audio.

    The image is passed to the VQA model together with a fixed question, the
    generated token ids are decoded to a string, and that string is handed
    to ``tts`` for speech synthesis.

    Args:
        image: The picture to describe (the Gradio input is configured with
            ``type="pil"``), or ``None`` when no image was supplied.

    Returns:
        A ``(text, audio)`` tuple: the decoded model output and whatever
        ``tts`` returns for it. When ``image`` is ``None`` the result is
        ``("", None)`` so both outputs are simply left empty.
    """
    # Nothing to describe: skip the models instead of failing inside them.
    if image is None:
        return "", None

    pixel_values = vqa_processor(images=image, return_tensors="pt").pixel_values

    prompt = "what is in the scene?"
    prompt_ids = vqa_processor(text=prompt, add_special_tokens=False).input_ids
    # Special tokens were disabled above, so the CLS token is prepended by hand.
    prompt_ids = [vqa_processor.tokenizer.cls_token_id] + prompt_ids
    prompt_ids = torch.tensor(prompt_ids).unsqueeze(0)

    text_ids = vqa_model.generate(pixel_values=pixel_values, input_ids=prompt_ids, max_length=50)

    # batch_decode returns a list with one string per generated sequence.
    # Only one sequence is generated here, so unwrap it: both the caption
    # textbox and tts() need a plain string, not a list.
    # NOTE(review): the decoded string may still include the prompt text —
    # confirm against the GIT documentation whether it should be stripped.
    decoded = vqa_processor.batch_decode(text_ids, skip_special_tokens=True)
    text = decoded[0]

    audio = tts(text)
    return text, audio
# UI components: one image input, and a caption textbox plus an audio player
# as outputs (matching the (text, audio) pair returned by predict).
image_input = gr.Image(type="pil", label="Environment")
caption_output = gr.Textbox(label="Caption")
audio_output = gr.Audio(type="numpy", label="Audio Feedback")

# Wire predict() into a Gradio interface and start serving it.
demo = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=[caption_output, audio_output],
    css=".gradio-container {background-color: #002A5B}",
    theme=gr.themes.Soft(),
)
demo.launch()