Spaces:
Build error
Build error
import inspect
import os
from decimal import Decimal

import gradio as gr
import nemo.collections.asr as nemo_asr
import torch
import wandb
from pydub.utils import mediainfo
# NOTE(review): not referenced anywhere in this file; presumably a model
# history look-back window in days — confirm it is still needed.
MODEL_HISTORY_DAYS = 180

# Weights & Biases coordinates used to locate the model artifact
# ("<entity>/<project>/<name:version>"). Each value can be overridden through
# the environment variable of the same name.
WANDB_ENTITY = os.getenv("WANDB_ENTITY", "tarteel")
WANDB_PROJECT_NAME = os.getenv("WANDB_PROJECT_NAME", "nemo-experiments")
MODEL_NAME = os.getenv("MODEL_NAME", "CfCtcLg-SpeUni1024-DI-EATLDN-CA:v0")
def get_device():
    """Pick the torch device string the model should be loaded onto.

    Returns:
        ``"cuda"`` when PyTorch reports an available CUDA device,
        otherwise ``"cpu"``.
    """
    return "cuda" if torch.cuda.is_available() else "cpu"
def _find_nemo_checkpoint(directory):
    """Return the path of the first ``.nemo`` file found under *directory*.

    Directories and files are visited in sorted order, so the choice is
    deterministic even if the artifact contains more than one checkpoint
    (plain ``os.walk`` order depends on the filesystem).

    Args:
        directory: Root directory to search recursively.

    Returns:
        Path to the selected ``.nemo`` file.

    Raises:
        FileNotFoundError: If no ``.nemo`` file exists under *directory*.
    """
    for root, dirs, files in os.walk(directory):
        # Sorting ``dirs`` in place makes os.walk descend in a stable order.
        dirs.sort()
        for name in sorted(files):
            if name.endswith(".nemo"):
                return os.path.join(root, name)
    raise FileNotFoundError(f"No .nemo checkpoint found under {directory!r}")


# Download the model artifact from Weights & Biases using the
# "<entity>/<project>/<name:version>" path built from WANDB_ENTITY,
# WANDB_PROJECT_NAME and MODEL_NAME.
wandb_api = wandb.Api(overrides={"entity": WANDB_ENTITY})
artifact = wandb_api.artifact(f"{WANDB_ENTITY}/{WANDB_PROJECT_NAME}/{MODEL_NAME}")
artifact_dir = artifact.download()

# Restore the NeMo CTC-BPE model from the checkpoint found in the artifact,
# mapping its weights onto the GPU when one is available.
model_path = _find_nemo_checkpoint(artifact_dir)
model = nemo_asr.models.EncDecCTCModelBPE.restore_from(
    model_path, map_location=get_device()
)
def transcribe(audio_file):
    """Transcribe one uploaded audio file with the loaded NeMo model.

    Args:
        audio_file: Path to the audio file, as delivered by the ``gr.Audio``
            input configured with ``type="filepath"``; Gradio passes ``None``
            when the user clicks without uploading anything.

    Returns:
        The transcription text.

    Raises:
        gr.Error: If no audio file was provided, so the UI shows a readable
            message instead of an error from deep inside NeMo.
    """
    if not audio_file:
        raise gr.Error("Please upload an audio file first.")
    result = model.transcribe([audio_file], verbose=False)[0]
    # Older NeMo releases return plain strings; newer ones return Hypothesis
    # objects that carry the decoded text in ``.text``.
    transcription = getattr(result, "text", result)
    print(f"{audio_file}: {transcription}")
    return transcription
def get_duration_ms(audio_file):
    """Return the duration of *audio_file* in whole milliseconds.

    The duration is read with ``pydub.utils.mediainfo``, which reports it as
    a string of seconds. The value is truncated (not rounded) to an integer
    number of milliseconds.

    Args:
        audio_file: Path to the audio file; ``None`` when nothing was uploaded.

    Returns:
        The duration in milliseconds as an ``int``.

    Raises:
        gr.Error: If no file was provided or its duration cannot be read.
    """
    if not audio_file:
        raise gr.Error("Please upload an audio file first.")
    duration = mediainfo(audio_file).get("duration")  # a string in seconds
    try:
        # Exact decimal arithmetic: with floats, int(float("1.005") * 1000)
        # evaluates to 1004, silently losing a millisecond.
        duration_ms = int(Decimal(duration) * 1000)
    except (TypeError, ValueError, ArithmeticError) as err:
        # TypeError: no "duration" reported; ArithmeticError covers
        # decimal.InvalidOperation for non-numeric text such as "N/A".
        raise gr.Error(f"Could not determine the duration of {audio_file}") from err
    print(f"{audio_file}: {duration_ms} ms")
    return duration_ms
def _upload_audio(**kwargs):
    """Build an upload-only ``gr.Audio`` input on both Gradio 3.x and 4.x+.

    Gradio 4 replaced ``source="upload"`` with ``sources=["upload"]`` and
    rejects the old keyword. This checks which keyword the installed
    version's constructor declares and uses that one.

    Args:
        **kwargs: Extra keyword arguments forwarded to ``gr.Audio``.

    Returns:
        The created ``gr.Audio`` component.
    """
    params = inspect.signature(gr.Audio.__init__).parameters
    if "sources" in params:
        return gr.Audio(sources=["upload"], **kwargs)
    return gr.Audio(source="upload", **kwargs)


demo = gr.Blocks()
with demo:
    gr.Markdown(
        """
        # ﷽
        """
    )
    with gr.Row():
        audio_file = _upload_audio(type="filepath", label="File")
    with gr.Row():
        output_file = gr.TextArea(label="Audio Transcription")
    b1 = gr.Button("Transcribe")
    # api_name also exposes this handler through the Gradio client API.
    b1.click(
        transcribe,
        inputs=[audio_file],
        outputs=[output_file],
        api_name="transcribe",
    )
    b2 = gr.Button("Get Duration")
    with gr.Row():
        duration = gr.TextArea(label="Duration")
    b2.click(
        get_duration_ms,
        inputs=[audio_file],
        outputs=[duration],
        api_name="get_duration_ms",
    )
demo.launch()