Commit 8dde699 by bofenghuang
Parent(s): 120532c

add timestamp option

Files changed (2):
  1. run_demo_openai.py +71 -22
  2. run_demo_openai_merged.py +169 -0
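The heart of this commit is the segment formatting in the new `infer` helper (see the first diff below). As a quick illustration of the output shape when the timestamp option is on — the `model_outputs` dict here is invented for the example and stands in for a real `model.transcribe()` result:

# Illustration only: `model_outputs` mimics the dict returned by
# openai-whisper's model.transcribe(); the segment values are invented.
model_outputs = {
    "segments": [
        {"id": 0, "start": 0.0, "end": 4.2, "text": " Bonjour à tous."},
        {"id": 1, "start": 4.2, "end": 9.8, "text": " Bienvenue dans cette démo."},
    ]
}

print(
    "\n\n".join(
        f'Segment {segment["id"] + 1} from {segment["start"]:.2f}s to {segment["end"]:.2f}s:\n{segment["text"].strip()}'
        for segment in model_outputs["segments"]
    )
)
# Segment 1 from 0.00s to 4.20s:
# Bonjour à tous.
#
# Segment 2 from 4.20s to 9.80s:
# Bienvenue dans cette démo.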
run_demo_openai.py CHANGED

@@ -4,22 +4,14 @@ import warnings
 import gradio as gr
 import pytube as pt
 import torch
+import whisper
 from huggingface_hub import hf_hub_download, model_info
 from transformers.utils.logging import disable_progress_bar
-import whisper
-

 warnings.filterwarnings("ignore")
 disable_progress_bar()

-logging.basicConfig(
-    format="%(asctime)s [%(levelname)s] [%(name)s] %(message)s",
-    datefmt="%Y-%m-%dT%H:%M:%SZ",
-)
-logger = logging.getLogger(__name__)
-logger.setLevel(logging.DEBUG)
-
-MODEL_NAME = "bofenghuang/whisper-large-v2-cv11-french"
+DEFAULT_MODEL_NAME = "bofenghuang/whisper-large-v2-cv11-french"
 CHECKPOINT_FILENAME = "checkpoint_openai.pt"

 GEN_KWARGS = {
@@ -36,17 +28,63 @@ GEN_KWARGS = {
     # "no_speech_threshold": None,
 }

+logging.basicConfig(
+    format="%(asctime)s [%(levelname)s] [%(name)s] %(message)s",
+    datefmt="%Y-%m-%dT%H:%M:%SZ",
+)
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+
 # device = 0 if torch.cuda.is_available() else "cpu"
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+logger.info(f"Model will be loaded on device `{device}`")
+
+cached_models = {}
+
+
+def print_cuda_memory_info():
+    used_mem, tot_mem = torch.cuda.mem_get_info()
+    logger.info(
+        f"CUDA memory info - Free: {used_mem / 1024 ** 3:.2f} Gb, used: {(tot_mem - used_mem) / 1024 ** 3:.2f} Gb, total: {tot_mem / 1024 ** 3:.2f} Gb"
+    )
+
+
+def print_memory_info():
+    # todo
+    if device == "cpu":
+        pass
+    else:
+        print_cuda_memory_info()
+

-downloaded_model_path = hf_hub_download(repo_id=MODEL_NAME, filename=CHECKPOINT_FILENAME)
-
-model = whisper.load_model(downloaded_model_path, device=device)
-logger.info(f"Model has been loaded on device `{device}`")
+def maybe_load_cached_pipeline(model_name):
+    model = cached_models.get(model_name)
+    if model is None:
+        downloaded_model_path = hf_hub_download(repo_id=model_name, filename=CHECKPOINT_FILENAME)
+
+        model = whisper.load_model(downloaded_model_path, device=device)
+        logger.info(f"`{model_name}` has been loaded on device `{device}`")
+
+        print_memory_info()
+
+        cached_models[model_name] = model
+    return model


-def transcribe(microphone, file_upload):
+def infer(model, filename, with_timestamps):
+    if with_timestamps:
+        model_outputs = model.transcribe(filename, **GEN_KWARGS)
+        return "\n\n".join(
+            [
+                f'Segment {segment["id"]+1} from {segment["start"]:.2f}s to {segment["end"]:.2f}s:\n{segment["text"].strip()}'
+                for segment in model_outputs["segments"]
+            ]
+        )
+    else:
+        return model.transcribe(filename, without_timestamps=True, **GEN_KWARGS)["text"]
+
+
+def transcribe(microphone, file_upload, with_timestamps, model_name=DEFAULT_MODEL_NAME):
     warn_output = ""
     if (microphone is not None) and (file_upload is not None):
         warn_output = (
@@ -59,9 +97,11 @@ def transcribe(microphone, file_upload):

     file = microphone if microphone is not None else file_upload

-    text = model.transcribe(file, **GEN_KWARGS)["text"]
+    model = maybe_load_cached_pipeline(model_name)
+    # text = model.transcribe(file, **GEN_KWARGS)["text"]
+    text = infer(model, file, with_timestamps)

-    logger.info(f"Transcription: {text}")
+    logger.info(f"Transcription by `{model_name}`: {text}")

     return warn_output + text

@@ -74,19 +114,24 @@ def _return_yt_html_embed(yt_url):
     return HTML_str


-def yt_transcribe(yt_url):
+def yt_transcribe(yt_url, with_timestamps, model_name=DEFAULT_MODEL_NAME):
     yt = pt.YouTube(yt_url)
     html_embed_str = _return_yt_html_embed(yt_url)
     stream = yt.streams.filter(only_audio=True)[0]
     stream.download(filename="audio.mp3")

-    text = model.transcribe("audio.mp3", **GEN_KWARGS)["text"]
+    model = maybe_load_cached_pipeline(model_name)
+    # text = model.transcribe("audio.mp3", **GEN_KWARGS)["text"]
+    text = infer(model, "audio.mp3", with_timestamps)

-    logger.info(f'Transcription of "{yt_url}": {text}')
+    logger.info(f'Transcription by `{model_name}` of "{yt_url}": {text}')

     return html_embed_str, text


+# load default model
+maybe_load_cached_pipeline(DEFAULT_MODEL_NAME)
+
 demo = gr.Blocks()

 mf_transcribe = gr.Interface(
@@ -94,6 +139,7 @@ mf_transcribe = gr.Interface(
     inputs=[
         gr.inputs.Audio(source="microphone", type="filepath", optional=True, label="Record"),
         gr.inputs.Audio(source="upload", type="filepath", optional=True, label="Upload File"),
+        gr.Checkbox(label="With timestamps?", value=True),
     ],
     # outputs="text",
     outputs=gr.outputs.Textbox(label="Transcription"),
@@ -102,7 +148,7 @@ mf_transcribe = gr.Interface(
     title="Whisper French Demo 🇫🇷 : Transcribe Audio",
     description=(
         "Transcribe long-form microphone or audio inputs with the click of a button! Demo uses the the fine-tuned"
-        f" checkpoint [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}) and 🤗 Transformers to transcribe audio files"
+        f" checkpoint [{DEFAULT_MODEL_NAME}](https://huggingface.co/{DEFAULT_MODEL_NAME}) and 🤗 Transformers to transcribe audio files"
         " of arbitrary length."
     ),
     allow_flagging="never",
@@ -110,7 +156,10 @@ mf_transcribe = gr.Interface(

 yt_transcribe = gr.Interface(
     fn=yt_transcribe,
-    inputs=[gr.inputs.Textbox(lines=1, placeholder="Paste the URL to a YouTube video here", label="YouTube URL")],
+    inputs=[
+        gr.inputs.Textbox(lines=1, placeholder="Paste the URL to a YouTube video here", label="YouTube URL"),
+        gr.Checkbox(label="With timestamps?", value=True),
+    ],
     # outputs=["html", "text"],
     outputs=[
         gr.outputs.HTML(label="YouTube Page"),
@@ -121,7 +170,7 @@ yt_transcribe = gr.Interface(
     title="Whisper French Demo 🇫🇷 : Transcribe YouTube",
     description=(
         "Transcribe long-form YouTube videos with the click of a button! Demo uses the the fine-tuned checkpoint:"
-        f" [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}) and 🤗 Transformers to transcribe audio files of"
+        f" [{DEFAULT_MODEL_NAME}](https://huggingface.co/{DEFAULT_MODEL_NAME}) and 🤗 Transformers to transcribe audio files of"
         " arbitrary length."
     ),
     allow_flagging="never",
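
The refactor above replaces the module-level model load with `maybe_load_cached_pipeline`, so each checkpoint is downloaded and loaded at most once per process, whichever tab or timestamp setting is used. A minimal sketch of that caching contract, assuming the definitions above are in scope and a local `sample.wav` exists (both are stand-ins, not part of the commit):

# Sketch under the stated assumptions: the functions above are in scope
# and sample.wav is a placeholder audio file.
model = maybe_load_cached_pipeline(DEFAULT_MODEL_NAME)       # downloads + loads once
same_model = maybe_load_cached_pipeline(DEFAULT_MODEL_NAME)  # served from cached_models
assert model is same_model

print(infer(model, "sample.wav", with_timestamps=True))   # "Segment N from ..." blocks
print(infer(model, "sample.wav", with_timestamps=False))  # plain transcript string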
run_demo_openai_merged.py ADDED

@@ -0,0 +1,169 @@
+import logging
+import warnings
+
+import gradio as gr
+import pytube as pt
+import torch
+import whisper
+from huggingface_hub import hf_hub_download, model_info
+from transformers.utils.logging import disable_progress_bar
+
+warnings.filterwarnings("ignore")
+disable_progress_bar()
+
+DEFAULT_MODEL_NAME = "bofenghuang/whisper-large-v2-cv11-french"
+CHECKPOINT_FILENAME = "checkpoint_openai.pt"
+
+GEN_KWARGS = {
+    "task": "transcribe",
+    "language": "fr",
+    # "without_timestamps": True,
+    # decode options
+    # "beam_size": 5,
+    # "patience": 2,
+    # disable fallback
+    # "compression_ratio_threshold": None,
+    # "logprob_threshold": None,
+    # vad threshold
+    # "no_speech_threshold": None,
+}
+
+logging.basicConfig(
+    format="%(asctime)s [%(levelname)s] [%(name)s] %(message)s",
+    datefmt="%Y-%m-%dT%H:%M:%SZ",
+)
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+
+# device = 0 if torch.cuda.is_available() else "cpu"
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+logger.info(f"Model will be loaded on device `{device}`")
+
+cached_models = {}
+
+
+def print_cuda_memory_info():
+    used_mem, tot_mem = torch.cuda.mem_get_info()
+    logger.info(
+        f"CUDA memory info - Free: {used_mem / 1024 ** 3:.2f} Gb, used: {(tot_mem - used_mem) / 1024 ** 3:.2f} Gb, total: {tot_mem / 1024 ** 3:.2f} Gb"
+    )
+
+
+def print_memory_info():
+    # todo
+    if device == "cpu":
+        pass
+    else:
+        print_cuda_memory_info()
+
+
+def maybe_load_cached_pipeline(model_name):
+    model = cached_models.get(model_name)
+    if model is None:
+        downloaded_model_path = hf_hub_download(repo_id=model_name, filename=CHECKPOINT_FILENAME)
+
+        model = whisper.load_model(downloaded_model_path, device=device)
+        logger.info(f"`{model_name}` has been loaded on device `{device}`")
+
+        print_memory_info()
+
+        cached_models[model_name] = model
+    return model
+
+
+def infer(model, filename, with_timestamps):
+    if with_timestamps:
+        model_outputs = model.transcribe(filename, **GEN_KWARGS)
+        return "\n\n".join(
+            [
+                f'Segment {segment["id"]+1} from {segment["start"]:.2f}s to {segment["end"]:.2f}s:\n{segment["text"].strip()}'
+                for segment in model_outputs["segments"]
+            ]
+        )
+    else:
+        return model.transcribe(filename, without_timestamps=True, **GEN_KWARGS)["text"]
+
+
+def download_from_youtube(yt_url, downloaded_filename="audio.wav"):
+    yt = pt.YouTube(yt_url)
+    stream = yt.streams.filter(only_audio=True)[0]
+    # stream.download(filename="audio.mp3")
+    stream.download(filename=downloaded_filename)
+    return downloaded_filename
+
+
+def transcribe(microphone, file_upload, yt_url, with_timestamps, model_name=DEFAULT_MODEL_NAME):
+    warn_output = ""
+    if (microphone is not None) and (file_upload is not None) and (yt_url is not None):
+        warn_output = (
+            "WARNING: You've uploaded an audio file, used the microphone, and pasted a YouTube URL. "
+            "The recorded file from the microphone will be used, the uploaded audio and the YouTube URL will be discarded.\n"
+        )
+
+    if (microphone is not None) and (file_upload is not None):
+        warn_output = (
+            "WARNING: You've uploaded an audio file and used the microphone. "
+            "The recorded file from the microphone will be used and the uploaded audio will be discarded.\n"
+        )
+
+    if (microphone is not None) and (yt_url is not None):
+        warn_output = (
+            "WARNING: You've used the microphone and pasted a YouTube URL. "
+            "The recorded file from the microphone will be used and the YouTube URL will be discarded.\n"
+        )
+
+    if (file_upload is not None) and (yt_url is not None):
+        warn_output = (
+            "WARNING: You've uploaded an audio file and pasted a YouTube URL. "
+            "The uploaded audio will be used and the YouTube URL will be discarded.\n"
+        )
+
+    elif (microphone is None) and (file_upload is None) or (yt_url is None):
+        return "ERROR: You have to either use the microphone, upload an audio file or paste a YouTube URL"
+
+    if microphone is not None:
+        file = microphone
+        logging_prefix = f"Transcription by `{model_name}` of microphone:"
+    elif file_upload is not None:
+        file = file_upload
+        logging_prefix = f"Transcription by `{model_name}` of uploaded file:"
+    else:
+        file = download_from_youtube(yt_url)
+        logging_prefix = f'Transcription by `{model_name}` of "{yt_url}":'
+
+    model = maybe_load_cached_pipeline(model_name)
+    # text = model.transcribe(file, **GEN_KWARGS)["text"]
+    text = infer(model, file, with_timestamps)
+
+    logger.info(logging_prefix + "\n" + text)
+
+    return warn_output + text
+
+
+# load default model
+maybe_load_cached_pipeline(DEFAULT_MODEL_NAME)
+
+demo = gr.Interface(
+    fn=transcribe,
+    inputs=[
+        gr.inputs.Audio(source="microphone", type="filepath", label="Record", optional=True),
+        gr.inputs.Audio(source="upload", type="filepath", label="Upload File", optional=True),
+        gr.inputs.Textbox(lines=1, placeholder="Paste the URL to a YouTube video here", label="YouTube URL", optional=True),
+        gr.Checkbox(label="With timestamps?", value=True),
+    ],
+    # outputs="text",
+    outputs=gr.outputs.Textbox(label="Transcription"),
+    layout="horizontal",
+    theme="huggingface",
+    title="Whisper French Demo 🇫🇷 : Transcribe Audio",
+    description=(
+        "Transcribe long-form microphone or audio inputs with the click of a button! Demo uses the the fine-tuned"
+        f" checkpoint [{DEFAULT_MODEL_NAME}](https://huggingface.co/{DEFAULT_MODEL_NAME}) and 🤗 Transformers to transcribe audio files"
+        " of arbitrary length."
+    ),
+    allow_flagging="never",
+)
+
+
+# demo.launch(server_name="0.0.0.0", debug=True, share=True)
+demo.launch(enable_queue=True)
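
To see how the merged script's pieces compose outside the Gradio UI: a hedged sketch chaining `download_from_youtube`, `maybe_load_cached_pipeline`, and `infer` as defined above. It assumes those definitions are in scope (running the script itself also launches the demo at the end), and the URL is a placeholder, not a real video:

# Sketch only: helper definitions above assumed in scope; the URL is a placeholder.
model = maybe_load_cached_pipeline(DEFAULT_MODEL_NAME)

audio_path = download_from_youtube("https://www.youtube.com/watch?v=VIDEO_ID")  # writes audio.wav
text = infer(model, audio_path, with_timestamps=True)  # segment-by-segment output
print(text)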