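"""Gradio demo for French speech transcription with a fine-tuned Whisper checkpoint.

Accepts audio from the microphone, an uploaded file, or a YouTube URL, and
returns the transcription, optionally with segment-level timestamps.
"""
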
import logging
import warnings

import gradio as gr
import pytube as pt
import psutil
import torch
import whisper
from huggingface_hub import hf_hub_download, model_info
from transformers.utils.logging import disable_progress_bar

warnings.filterwarnings("ignore")
disable_progress_bar()

DEFAULT_MODEL_NAME = "bofenghuang/whisper-large-v2-cv11-french"
CHECKPOINT_FILENAME = "checkpoint_openai.pt"

GEN_KWARGS = {
    "task": "transcribe",
    "language": "fr",
    # "without_timestamps": True,
    # decode options
    # "beam_size": 5,
    # "patience": 2,
    # disable fallback
    # "compression_ratio_threshold": None,
    # "logprob_threshold": None,
    # vad threshold
    # "no_speech_threshold": None,
}

logging.basicConfig(
    format="%(asctime)s [%(levelname)s] [%(name)s] %(message)s",
    datefmt="%Y-%m-%dT%H:%M:%SZ",
)
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# device = 0 if torch.cuda.is_available() else "cpu"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
logger.info(f"Model will be loaded on device `{device}`")

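# Loaded models keyed by model name, so each checkpoint is downloaded and loaded only once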
cached_models = {}


def _print_memory_info():
    memory = psutil.virtual_memory()
    logger.info(
        f"Memory info - Free: {memory.available / (1024 ** 3):.2f} Gb, used: {memory.percent}%, total: {memory.total / (1024 ** 3):.2f} Gb"
    )


def print_cuda_memory_info():
    # torch.cuda.mem_get_info() returns (free, total) in bytes
    free_mem, total_mem = torch.cuda.mem_get_info()
    logger.info(
        f"CUDA memory info - Free: {free_mem / 1024 ** 3:.2f} GB, used: {(total_mem - free_mem) / 1024 ** 3:.2f} GB, total: {total_mem / 1024 ** 3:.2f} GB"
    )


def print_memory_info():
    _print_memory_info()
    # mem_get_info() requires CUDA; skip it on CPU-only machines
    if torch.cuda.is_available():
        print_cuda_memory_info()


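# Download the OpenAI-format checkpoint from the Hugging Face Hub (if needed) and
# load it with openai-whisper, reusing a previously loaded model when possible.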
def maybe_load_cached_pipeline(model_name):
    model = cached_models.get(model_name)
    if model is None:
        downloaded_model_path = hf_hub_download(repo_id=model_name, filename=CHECKPOINT_FILENAME)

        model = whisper.load_model(downloaded_model_path, device=device)
        logger.info(f"`{model_name}` has been loaded on device `{device}`")

        print_memory_info()

        cached_models[model_name] = model
    return model


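# Transcribe a local audio file; when timestamps are requested, format each
# segment with its start/end times instead of returning the raw text.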
def infer(model, filename, with_timestamps):
    if with_timestamps:
        model_outputs = model.transcribe(filename, **GEN_KWARGS)
        return "\n\n".join(
            [
                f'Segment {segment["id"]+1} from {segment["start"]:.2f}s to {segment["end"]:.2f}s:\n{segment["text"].strip()}'
                for segment in model_outputs["segments"]
            ]
        )
    else:
        return model.transcribe(filename, without_timestamps=True, **GEN_KWARGS)["text"]


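# Download the first audio-only stream of a YouTube video. The ".wav" name is
# only a placeholder: pytube keeps the original container, and ffmpeg (used by
# whisper) detects the format from the content, not the extension.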
def download_from_youtube(yt_url, downloaded_filename="audio.wav"):
    yt = pt.YouTube(yt_url)
    stream = yt.streams.filter(only_audio=True)[0]
    # stream.download(filename="audio.mp3")
    stream.download(filename=downloaded_filename)
    return downloaded_filename


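# Gradio callback: pick a single input source (microphone > upload > YouTube URL),
# warn when several are provided, and return the transcription.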
def transcribe(microphone, file_upload, yt_url, with_timestamps, model_name=DEFAULT_MODEL_NAME):
    warn_output = ""
    if (microphone is not None) and (file_upload is not None) and yt_url:
        warn_output = (
            "WARNING: You've uploaded an audio file, used the microphone, and pasted a YouTube URL. "
            "The recorded file from the microphone will be used, the uploaded audio and the YouTube URL will be discarded.\n"
        )

    elif (microphone is not None) and (file_upload is not None):
        warn_output = (
            "WARNING: You've uploaded an audio file and used the microphone. "
            "The recorded file from the microphone will be used and the uploaded audio will be discarded.\n"
        )

    elif (microphone is not None) and yt_url:
        warn_output = (
            "WARNING: You've used the microphone and pasted a YouTube URL. "
            "The recorded file from the microphone will be used and the YouTube URL will be discarded.\n"
        )

    elif (file_upload is not None) and yt_url:
        warn_output = (
            "WARNING: You've uploaded an audio file and pasted a YouTube URL. "
            "The uploaded audio will be used and the YouTube URL will be discarded.\n"
        )

    elif (microphone is None) and (file_upload is None) and (not yt_url):
        return "ERROR: You have to either use the microphone, upload an audio file or paste a YouTube URL"

    if microphone is not None:
        file = microphone
        logging_prefix = f"Transcription by `{model_name}` of microphone:"
    elif file_upload is not None:
        file = file_upload
        logging_prefix = f"Transcription by `{model_name}` of uploaded file:"
    else:
        file = download_from_youtube(yt_url)
        logging_prefix = f'Transcription by `{model_name}` of "{yt_url}":'

    model = maybe_load_cached_pipeline(model_name)
    # text = model.transcribe(file, **GEN_KWARGS)["text"]
    text = infer(model, file, with_timestamps)

    logger.info(logging_prefix + "\n" + text + "\n")

    return warn_output + text


# load default model
maybe_load_cached_pipeline(DEFAULT_MODEL_NAME)

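# Note: this uses the legacy Gradio interface API (gr.inputs/gr.outputs, `source`,
# `optional`, `layout`, `enable_queue`), which requires an older Gradio release.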
demo = gr.Interface(
    fn=transcribe,
    inputs=[
        gr.inputs.Audio(source="microphone", type="filepath", label="Record", optional=True),
        gr.inputs.Audio(source="upload", type="filepath", label="Upload File", optional=True),
        gr.inputs.Textbox(lines=1, placeholder="Paste the URL to a YouTube video here", label="YouTube URL", optional=True),
        gr.Checkbox(label="With timestamps?"),
    ],
    outputs=gr.outputs.Textbox(label="Transcription"),
    layout="horizontal",
    theme="huggingface",
    title="Whisper French Demo 🇫🇷",
    description=(
        "**Transcribe long-form microphone, audio inputs or YouTube videos with the click of a button!** \n\nDemo uses the the fine-tuned"
        f" checkpoint [{DEFAULT_MODEL_NAME}](https://huggingface.co/{DEFAULT_MODEL_NAME}) and 🤗 Transformers to transcribe audio files"
        " of arbitrary length."
    ),
    allow_flagging="never",
)


# demo.launch(server_name="0.0.0.0", debug=True, share=True)
demo.launch(enable_queue=True)