hubert-uk-demo / app.py
Yehor's picture
Update app.py
60e77cf verified
import sys
import time
from importlib.metadata import version
import torch
import torchaudio
import torchaudio.transforms as T
import gradio as gr
import numpy as np
from transformers import HubertForCTC, Wav2Vec2Processor
# Config
model_name = "Yehor/hubert-uk"
min_duration = 0.5
max_duration = 60
concurrency_limit = 5
use_torch_compile = False
# Torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
# Load the model
asr_model = HubertForCTC.from_pretrained(
model_name, torch_dtype=torch_dtype, device_map=device
)
processor = Wav2Vec2Processor.from_pretrained(model_name)
if use_torch_compile:
asr_model = torch.compile(asr_model)
# Elements
examples = [
"example_1.wav",
"example_2.wav",
"example_3.wav",
"example_4.wav",
"example_5.wav",
"example_6.wav",
]
examples_table = """
| File | Text |
| ------------- | ------------- |
| `example_1.wav` | тема про яку не люблять говорити офіційні джерела у генштабі і міноборони це хімічна зброя окупанти вже тривалий час використовують хімічну зброю заборонену |
| `example_2.wav` | всіма конвенціями якщо спочатку це були гранати з дронів то тепер фіксують випадки застосування |
| `example_3.wav` | хімічних снарядів причому склад отруйної речовони різний а отже й наслідки для наших військових теж різні |
| `example_4.wav` | використовує на фронті все що має і хімічна зброя не вийняток тож з чим маємо справу розбиралася марія моганисян |
| `example_5.wav` | двох тисяч випадків застосування росіянами боєприпасів споряджених небезпечними хімічними речовинами |
| `example_6.wav` | на всі писані норми марія моганисян олександр моторний спецкор марафон єдині новини |
""".strip()
# https://www.tablesgenerator.com/markdown_tables
authors_table = """
## Authors
Follow them in social networks and **contact** if you need any help or have any questions:
| <img src="https://avatars.githubusercontent.com/u/7875085?v=4" width="100"> **Yehor Smoliakov** |
|-------------------------------------------------------------------------------------------------|
| https://t.me/smlkw in Telegram |
| https://x.com/yehor_smoliakov at X |
| https://github.com/egorsmkv at GitHub |
| https://huggingface.co/Yehor at Hugging Face |
| or use egorsmkv@gmail.com |
""".strip()
description_head = f"""
# Speech-to-Text for Ukrainian using HuBERT
## Overview
This space uses https://huggingface.co/Yehor/hubert-uk model to recognize audio files.
> Due to resource limitations, audio duration **must not** exceed **{max_duration}** seconds.
""".strip()
description_foot = f"""
## Community
- **Discord**: https://discord.gg/yVAjkBgmt4
- Speech Recognition: https://t.me/speech_recognition_uk
- Speech Synthesis: https://t.me/speech_synthesis_uk
## More
Check out other ASR models: https://github.com/egorsmkv/speech-recognition-uk
{authors_table}
""".strip()
transcription_value = """
Recognized text will appear here.
Choose **an example file** below the Recognize button, upload **your audio file**, or use **the microphone** to record own voice.
""".strip()
tech_env = f"""
#### Environment
- Python: {sys.version}
- Torch device: {device}
- Torch dtype: {torch_dtype}
- Use torch.compile: {use_torch_compile}
""".strip()
tech_libraries = f"""
#### Libraries
- torch: {version('torch')}
- torchaudio: {version('torchaudio')}
- transformers: {version('transformers')}
- accelerate: {version('accelerate')}
- gradio: {version('gradio')}
""".strip()
def inference(audio_path, progress=gr.Progress()):
if not audio_path:
raise gr.Error("Please upload an audio file.")
gr.Info("Starting recognition", duration=2)
progress(0, desc="Recognizing")
meta = torchaudio.info(audio_path)
duration = meta.num_frames / meta.sample_rate
if duration < min_duration:
raise gr.Error(
f"The duration of the file is less than {min_duration} seconds, it is {round(duration, 2)} seconds."
)
if duration > max_duration:
raise gr.Error(f"The duration of the file exceeds {max_duration} seconds.")
paths = [
audio_path,
]
results = []
for path in progress.tqdm(paths, desc="Recognizing...", unit="file"):
t0 = time.time()
meta = torchaudio.info(audio_path)
audio_duration = meta.num_frames / meta.sample_rate
audio_input, sr = torchaudio.load(path)
if meta.num_channels > 1:
audio_input = torch.mean(audio_input, dim=0, keepdim=True)
if meta.sample_rate != 16_000:
resampler = T.Resample(sr, 16_000, dtype=audio_input.dtype)
audio_input = resampler(audio_input)
audio_input = audio_input.squeeze(0).numpy()
inputs = processor(
[audio_input], sampling_rate=16_000, padding=True
).input_values
features = torch.tensor(np.array(inputs), dtype=torch_dtype).to(device)
with torch.inference_mode():
logits = asr_model(features).logits
predicted_ids = torch.argmax(logits, dim=-1)
predictions = processor.batch_decode(predicted_ids)
if not predictions:
predictions = "-"
elapsed_time = round(time.time() - t0, 2)
rtf = round(elapsed_time / audio_duration, 4)
audio_duration = round(audio_duration, 2)
results.append(
{
"path": path.split("/")[-1],
"transcription": "\n".join(predictions),
"audio_duration": audio_duration,
"rtf": rtf,
}
)
gr.Info("Finished!", duration=2)
result_texts = []
for result in results:
result_texts.append(f'**{result["path"]}**')
result_texts.append("\n\n")
result_texts.append(f'> {result["transcription"]}')
result_texts.append("\n\n")
result_texts.append(f'**Audio duration**: {result["audio_duration"]}')
result_texts.append("\n")
result_texts.append(f'**Real-Time Factor**: {result["rtf"]}')
return "\n".join(result_texts)
demo = gr.Blocks(
title="Speech-to-Text for Ukrainian",
analytics_enabled=False,
theme=gr.themes.Base(),
)
with demo:
gr.Markdown(description_head)
gr.Markdown("## Usage")
with gr.Row():
audio_file = gr.Audio(label="Audio file", type="filepath")
transcription = gr.Markdown(
label="Transcription",
value=transcription_value,
)
gr.Button("Recognize").click(
inference,
concurrency_limit=concurrency_limit,
inputs=audio_file,
outputs=transcription,
)
with gr.Row():
gr.Examples(label="Choose an example", inputs=audio_file, examples=examples)
gr.Markdown(examples_table)
gr.Markdown(description_foot)
gr.Markdown("### Gradio app uses:")
gr.Markdown(tech_env)
gr.Markdown(tech_libraries)
if __name__ == "__main__":
demo.queue()
demo.launch()