import sys import time from importlib.metadata import version import torch import torchaudio import torchaudio.transforms as T import gradio as gr import numpy as np from transformers import HubertForCTC, Wav2Vec2Processor # Config model_name = "Yehor/mHuBERT-147-uk" min_duration = 0.5 max_duration = 60 concurrency_limit = 5 use_torch_compile = False # Torch device = torch.device("cuda" if torch.cuda.is_available() else "cpu") torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32 # Load the model asr_model = HubertForCTC.from_pretrained(model_name, torch_dtype=torch_dtype, device_map=device) processor = Wav2Vec2Processor.from_pretrained(model_name) if use_torch_compile: asr_model = torch.compile(asr_model) # Elements examples = [ "example_1.wav", "example_2.wav", "example_3.wav", "example_4.wav", "example_5.wav", "example_6.wav", ] examples_table = """ | File | Text | | ------------- | ------------- | | `example_1.wav` | тема про яку не люблять говорити офіційні джерела у генштабі і міноборони це хімічна зброя окупанти вже тривалий час використовують хімічну зброю заборонену | | `example_2.wav` | всіма конвенціями якщо спочатку це були гранати з дронів то тепер фіксують випадки застосування | | `example_3.wav` | хімічних снарядів причому склад отруйної речовони різний а отже й наслідки для наших військових теж різні | | `example_4.wav` | використовує на фронті все що має і хімічна зброя не вийняток тож з чим маємо справу розбиралася марія моганисян | | `example_5.wav` | двох тисяч випадків застосування росіянами боєприпасів споряджених небезпечними хімічними речовинами | | `example_6.wav` | на всі писані норми марія моганисян олександр моторний спецкор марафон єдині новини | """.strip() # https://www.tablesgenerator.com/markdown_tables authors_table = """ ## Authors Follow them in social networks and **contact** if you need any help or have any questions: | **Yehor Smoliakov** | |-------------------------------------------------------------------------------------------------| | https://t.me/smlkw in Telegram | | https://x.com/yehor_smoliakov at X | | https://github.com/egorsmkv at GitHub | | https://huggingface.co/Yehor at Hugging Face | | or use egorsmkv@gmail.com | """.strip() description_head = f""" # Speech-to-Text for Ukrainian using HuBERT ## Overview This space uses https://huggingface.co/Yehor/mHuBERT-147-uk model to recognize audio files. > Due to resource limitations, audio duration **must not** exceed **{max_duration}** seconds. """.strip() description_foot = f""" ## Community - **Discord**: https://discord.gg/yVAjkBgmt4 - Speech Recognition: https://t.me/speech_recognition_uk - Speech Synthesis: https://t.me/speech_synthesis_uk ## More Check out other ASR models: https://github.com/egorsmkv/speech-recognition-uk {authors_table} """.strip() transcription_value = """ Recognized text will appear here. Choose **an example file** below the Recognize button, upload **your audio file**, or use **the microphone** to record own voice. """.strip() tech_env = f""" #### Environment - Python: {sys.version} - Torch device: {device} - Torch dtype: {torch_dtype} - Use torch.compile: {use_torch_compile} """.strip() tech_libraries = f""" #### Libraries - torch: {version('torch')} - torchaudio: {version('torchaudio')} - transformers: {version('transformers')} - accelerate: {version('accelerate')} - gradio: {version('gradio')} """.strip() def inference(audio_path, progress=gr.Progress()): if not audio_path: raise gr.Error("Please upload an audio file.") gr.Info("Starting recognition", duration=2) progress(0, desc="Recognizing") meta = torchaudio.info(audio_path) duration = meta.num_frames / meta.sample_rate if duration < min_duration: raise gr.Error( f"The duration of the file is less than {min_duration} seconds, it is {round(duration, 2)} seconds." ) if duration > max_duration: raise gr.Error(f"The duration of the file exceeds {max_duration} seconds.") paths = [ audio_path, ] results = [] for path in progress.tqdm(paths, desc="Recognizing...", unit="file"): t0 = time.time() meta = torchaudio.info(audio_path) audio_duration = meta.num_frames / meta.sample_rate audio_input, sr = torchaudio.load(path) if meta.num_channels > 1: audio_input = torch.mean(audio_input, dim=0, keepdim=True) if meta.sample_rate != 16_000: resampler = T.Resample(sr, 16_000, dtype=audio_input.dtype) audio_input = resampler(audio_input) audio_input = audio_input.squeeze(0).numpy() inputs = processor([audio_input], sampling_rate=16_000, padding=True).input_values features = torch.tensor(np.array(inputs), dtype=torch_dtype).to(device) with torch.inference_mode(): logits = asr_model(features).logits predicted_ids = torch.argmax(logits, dim=-1) predictions = processor.batch_decode(predicted_ids) if not predictions: predictions = "-" elapsed_time = round(time.time() - t0, 2) rtf = round(elapsed_time / audio_duration, 4) audio_duration = round(audio_duration, 2) results.append( { "path": path.split("/")[-1], "transcription": "\n".join(predictions), "audio_duration": audio_duration, "rtf": rtf, } ) gr.Info("Finished!", duration=2) result_texts = [] for result in results: result_texts.append(f'**{result["path"]}**') result_texts.append("\n\n") result_texts.append(f'> {result["transcription"]}') result_texts.append("\n\n") result_texts.append(f'**Audio duration**: {result["audio_duration"]}') result_texts.append("\n") result_texts.append(f'**Real-Time Factor**: {result["rtf"]}') return "\n".join(result_texts) demo = gr.Blocks( title="Speech-to-Text for Ukrainian", analytics_enabled=False, theme=gr.themes.Base(), ) with demo: gr.Markdown(description_head) gr.Markdown("## Usage") with gr.Row(): audio_file = gr.Audio(label="Audio file", type="filepath") transcription = gr.Markdown( label="Transcription", value=transcription_value, ) gr.Button("Recognize").click( inference, concurrency_limit=concurrency_limit, inputs=audio_file, outputs=transcription, ) with gr.Row(): gr.Examples(label="Choose an example", inputs=audio_file, examples=examples) gr.Markdown(examples_table) gr.Markdown(description_foot) gr.Markdown("### Gradio app uses the following technologies:") gr.Markdown(tech_env) gr.Markdown(tech_libraries) if __name__ == "__main__": demo.queue() demo.launch()