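"""Gradio demo for a lightweight EfficientNetB2-Transformer audio captioning model.

Loads a captioning model and its tokenizer from a checkpoint directory and
serves a simple web UI: upload an audio clip, pick a model (AudioCaps or
Clotho), and receive a generated caption.
"""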
from pathlib import Path
from functools import partial
import gradio as gr
import torch
from torchaudio.functional import resample

import utils.train_util as train_util


def load_model(cfg, ckpt_path, device):
    # Build the model from its config, load pretrained weights on CPU,
    # then move it to the target device for inference.
    model = train_util.init_model_from_config(cfg["model"])
    ckpt = torch.load(ckpt_path, map_location="cpu")
    train_util.load_pretrained_model(model, ckpt)
    model.eval()
    model = model.to(device)
    tokenizer = train_util.init_obj_from_dict(cfg["tokenizer"])
    if not tokenizer.loaded:
        tokenizer.load_state_dict(ckpt["tokenizer"])
    # Tell the model which token indices mark sequence start, end, and padding.
    model.set_index(tokenizer.bos, tokenizer.eos, tokenizer.pad)
    return model, tokenizer


def infer(file, runner):
    # Gradio delivers audio as a (sample_rate, numpy_array) tuple.
    sr, wav = file
    wav = torch.as_tensor(wav)
    # Normalize integer PCM (int16 / int32) to floats in [-1, 1].
    if wav.dtype == torch.short:
        wav = wav / 2 ** 15
    elif wav.dtype == torch.int:
        wav = wav / 2 ** 31
    # Downmix multi-channel audio to mono.
    if wav.ndim > 1:
        wav = wav.mean(1)
    # Resample to the rate the model was trained on (resample needs float input).
    wav = resample(wav.float(), sr, runner.target_sr)
    wav_len = len(wav)
    wav = wav.unsqueeze(0).to(runner.device)
    input_dict = {
        "mode": "inference",
        "wav": wav,
        "wav_len": [wav_len],
        "specaug": False,
        "sample_method": "beam",
        "beam_size": 3,
    }
    with torch.no_grad():
        output_dict = runner.model(input_dict)
        seq = output_dict["seq"].cpu().numpy()
        cap = runner.tokenizer.decode(seq)[0]
    return cap
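
# Example (hypothetical caller): captioning a one-second silent int16 clip
# without the UI, assuming checkpoints are in place:
#   import numpy as np
#   runner = InferRunner("AudioCaps")
#   caption = infer((16000, np.zeros(16000, dtype=np.int16)), runner)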

# def input_toggle(input_type):
#     if input_type == "file":
#         return gr.update(visible=True), gr.update(visible=False)
#     elif input_type == "mic":
#         return gr.update(visible=False), gr.update(visible=True)

class InferRunner:

    def __init__(self, model_name):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.change_model(model_name)

    def change_model(self, model_name):
        # Each model lives under ./checkpoints/<model_name> (lower-cased),
        # with its config and checkpoint side by side.
        exp_dir = Path(f"./checkpoints/{model_name.lower()}")
        cfg = train_util.load_config(exp_dir / "config.yaml")
        self.model, self.tokenizer = load_model(cfg, exp_dir / "ckpt.pth", self.device)
        self.target_sr = cfg["target_sr"]


def change_model(radio):
    # Radio-button callback: swap checkpoints in place on the shared runner.
    global infer_runner
    infer_runner.change_model(radio)


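# Expected on-disk layout, inferred from InferRunner and the radio choices below:
#   checkpoints/audiocaps/config.yaml   checkpoints/audiocaps/ckpt.pth
#   checkpoints/clotho/config.yaml      checkpoints/clotho/ckpt.pth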
with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown("# Lightweight EfficientNetB2-Transformer Audio Captioning")

    with gr.Row():
        gr.Markdown("""
            [![arXiv](https://img.shields.io/badge/arXiv-2407.14329-brightgreen.svg?style=flat-square)](https://arxiv.org/abs/2407.14329)
            
            [![github](https://img.shields.io/badge/GitHub-Code-blue?logo=Github&style=flat-square)](https://github.com/wsntxxn/AudioCaption?tab=readme-ov-file#lightweight-effb2-transformer-model)
        """)
    with gr.Row():
        with gr.Column():
            radio = gr.Radio(
                ["AudioCaps", "Clotho"],
                value="AudioCaps",
                label="Select model"
            )
            infer_runner = InferRunner(radio.value)
            file = gr.Audio(label="Input", visible=True)
            radio.change(fn=change_model, inputs=[radio])
            btn = gr.Button("Run")
        with gr.Column():
            output = gr.Textbox(label="Output")
        # infer_runner is bound once here; change_model() mutates it in place,
        # so this partial always uses the currently selected model.
        btn.click(
            fn=partial(infer, runner=infer_runner),
            inputs=[file],
            outputs=output
        )

demo.launch()