Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,209 Bytes
ab907f9 1ffe71f ab907f9 1ffe71f ab907f9 1ffe71f ab907f9 1ffe71f ab907f9 6b6d9ba ab907f9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
import tempfile

import gradio as gr
import torch
from transformers import AutoProcessor, AutoModel
import scipy.io.wavfile as wavfile
import spaces
def load_model():
    """Load the Bark-small processor and model from the "suno/bark-small" checkpoint.

    Returns:
        A ``(processor, model)`` tuple. The model is put into evaluation
        mode because it is only used for inference.
    """
    checkpoint = "suno/bark-small"
    tts_processor = AutoProcessor.from_pretrained(checkpoint)
    # nn.Module.eval() returns the module itself, so it can be chained.
    tts_model = AutoModel.from_pretrained(checkpoint).eval()
    return tts_processor, tts_model


# Load once at import time so every request reuses the same processor/model.
print("Loading models...")
processor, model = load_model()
print("Models loaded successfully!")
@spaces.GPU  # Decorate the function to enable GPU usage
def text_to_speech(text):
    """Synthesize speech for *text* with Bark and return the path to a WAV file.

    Args:
        text: The text to speak, as entered in the Gradio textbox.

    Returns:
        Path to a newly created ``.wav`` file holding the generated audio at
        ``model.generation_config.sample_rate``.

    Raises:
        gr.Error: If *text* is empty/whitespace-only, or if generation fails.
            Raising (rather than returning a message) lets Gradio show the
            error to the user instead of trying to load the message as audio.
    """
    # Validate before the try block so this message is not re-wrapped below.
    if not text or not text.strip():
        raise gr.Error("Please enter some text to synthesize.")
    try:
        # Use the GPU when one is available, otherwise fall back to CPU.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Make sure the model lives on the same device as its inputs.
        model.to(device)
        inputs = processor(
            text=[text],
            return_tensors="pt",
        ).to(device)
        with torch.no_grad():  # inference only; no gradients needed
            speech_values = model.generate(**inputs, do_sample=True)
        # Bring the audio back to host memory as a NumPy array for writing.
        audio_data = speech_values.cpu().numpy().squeeze()
        sampling_rate = model.generation_config.sample_rate
        # A unique file per request: a fixed shared filename would let
        # concurrent requests overwrite each other's audio. delete=False
        # because Gradio reads the file after this function returns.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            temp_path = tmp.name
        wavfile.write(temp_path, sampling_rate, audio_data)
        return temp_path
    except Exception as e:
        raise gr.Error(f"Error generating speech: {e}") from e
# Gradio UI: a single textbox (with sample Hindi prompts) in, generated audio out.
# `inputs` may be a single component; gr.Interface wraps it in a list itself.
demo = gr.Interface(
    fn=text_to_speech,
    inputs=gr.Textbox(
        label="Enter text",
        placeholder="दिल्ली मेट्रो में आपका स्वागत है",
    ),
    outputs=gr.Audio(label="Generated Speech"),
    title="Bark TTS Test App",
    description="This app generates speech from text using the Bark TTS model.",
    examples=[
        ["दिल्ली मेट्रो में आपका स्वागत है"],
        ["अगला स्टेशन राजीव चौक है"],
    ],
    theme="default",
)
if __name__ == "__main__":
    # Start the Gradio server only when this file is executed as a script.
    demo.launch()