terry-li-hm
U
e41a956
# coding=utf-8
import os
import tempfile

import gradio as gr
import numpy as np
import soundfile as sf
import spaces
import torch
import torchaudio
from gradio.themes import Base

from sv import process_audio
# Sample rate of the temporary WAV file handed to ``process_audio``.
_TARGET_SAMPLE_RATE = 16000


def _to_float32(samples):
    """Convert PCM samples to float32, scaling integer PCM into about [-1, 1].

    Signed integers are divided by their dtype's maximum (int16 -> 32767, the
    value the original code hard-coded). Unsigned integers (e.g. 8-bit WAV)
    are re-centred on their mid-point first. Float input is assumed to be
    already normalised and is only cast.
    """
    samples = np.asarray(samples)
    if np.issubdtype(samples.dtype, np.integer):
        info = np.iinfo(samples.dtype)
        if info.min == 0:
            mid = (info.max + 1) / 2
            return (samples.astype(np.float32) - mid) / mid
        return samples.astype(np.float32) / info.max
    return samples.astype(np.float32)


def _prepare_waveform(fs, samples):
    """Return mono float32 samples resampled from *fs* to _TARGET_SAMPLE_RATE."""
    samples = _to_float32(samples)
    if samples.ndim > 1:
        # Down-mix to mono; assumes channels are on the last axis.
        samples = samples.mean(-1)
    if fs != _TARGET_SAMPLE_RATE:
        resampler = torchaudio.transforms.Resample(fs, _TARGET_SAMPLE_RATE)
        # Resample expects a (channels, time) tensor, hence the added axis.
        samples = resampler(torch.from_numpy(samples).float()[None, :])[0].numpy()
    return samples


@spaces.GPU
def model_inference(input_wav, language):
    """Transcribe one audio clip with ``sv.process_audio``.

    Args:
        input_wav: Either a ``(sample_rate, samples)`` tuple, which is
            converted to mono float32 and resampled to 16 kHz, or samples
            already at 16 kHz mono. ``None`` (no audio provided) yields ``""``.
        language: Language code forwarded to ``process_audio``; a falsy
            value is replaced by ``"auto"``.

    Returns:
        The value returned by ``process_audio`` for the clip, or ``""`` when
        *input_wav* is ``None``.
    """
    language = language or "auto"

    if input_wav is None:
        # Nothing uploaded: avoid an opaque failure inside soundfile.
        return ""

    if isinstance(input_wav, tuple):
        fs, samples = input_wav
        input_wav = _prepare_waveform(fs, samples)

    # A private directory per call: a shared "temp.wav" in the working
    # directory could be overwritten by a concurrent request before
    # process_audio reads it, and would never be cleaned up.
    with tempfile.TemporaryDirectory() as tmp_dir:
        wav_path = os.path.join(tmp_dir, "input.wav")
        with sf.SoundFile(
            wav_path, "w", samplerate=_TARGET_SAMPLE_RATE, channels=1
        ) as f:
            f.write(input_wav)
        return process_audio(wav_path, language=language)
def launch():
    """Build the Cantonese call transcriber UI and start the Gradio server.

    Both the "Process Audio" button and the example audio send the input to
    ``model_inference`` with the language fixed to ``"yue"``.
    """
    # Custom CSS (not a theme): dark-grey text for the whole app container.
    custom_css = """
.gradio-container {color: rgb(70, 70, 70);}
"""

    def transcribe_cantonese(audio):
        """Gradio callback: transcribe *audio* as Cantonese ("yue")."""
        return model_inference(audio, "yue")

    with gr.Blocks(css=custom_css) as demo:
        gr.Markdown("# Cantonese Call Transcriber")
        gr.Markdown(
            """
This tool transcribes Cantonese audio calls into text.
## How to use:
1. Upload an audio file or use the example provided at the bottom of the page.
2. Click the 'Process Audio' button.
3. The transcription will appear in the output box.
"""
        )

        # Blocks renders a component where it is created, so the components
        # are created inside their columns. Merely referencing a component
        # created earlier does not move it into the layout.
        with gr.Row():
            with gr.Column(scale=2):
                audio_input = gr.Audio(label="Input")
                fn_button = gr.Button("Process Audio", variant="primary")
            with gr.Column(scale=3):
                text_output = gr.Textbox(lines=10, label="Output")

        fn_button.click(
            fn=transcribe_cantonese,
            inputs=[audio_input],
            outputs=[text_output],
        )

        # Created last so the examples appear at the bottom of the page,
        # matching the instructions above.
        gr.Examples(
            examples=[["example/scb.mp3"]],
            inputs=[audio_input],
            outputs=[text_output],
            fn=transcribe_cantonese,
            examples_per_page=1,
        )

    # Launch after the Blocks context has closed and the layout is final.
    demo.launch()
if __name__ == "__main__":
    # Build the UI and start the server only when run as a script, not on import.
    launch()