whisper-demo-french / run_demo_openai_layout.py
bofenghuang's picture
Use ct2 for inference
2eface4
raw
history blame
11.4 kB
#! /usr/bin/env python
# coding=utf-8
# Copyright 2022 Bofeng Huang
import datetime
import logging
import os
import re
import warnings
import gradio as gr
import pandas as pd
import psutil
import pytube as pt
import torch
import whisper
from huggingface_hub import hf_hub_download, model_info
from transformers.utils.logging import disable_progress_bar
import nltk
nltk.download("punkt")
from nltk.tokenize import sent_tokenize
# Keep the demo's logs readable: drop Python warnings and transformers' progress bars.
warnings.filterwarnings("ignore")
disable_progress_bar()

# Hub repo of the fine-tuned model, and the checkpoint file in it that `whisper.load_model` reads.
DEFAULT_MODEL_NAME = "bofenghuang/whisper-large-v2-cv11-french"
CHECKPOINT_FILENAME = "checkpoint_openai.pt"

# Keyword arguments forwarded to `model.transcribe`; the commented-out entries are optional
# decoding knobs kept around for experimentation.
GEN_KWARGS = dict(
    task="transcribe",
    language="fr",
    # without_timestamps=True,
    # decode options
    # beam_size=5,
    # patience=2,
    # disable fallback
    # compression_ratio_threshold=None,
    # logprob_threshold=None,
    # vad threshold
    # no_speech_threshold=None,
)

# Root handler format.
# NOTE(review): `%(asctime)s` is rendered in local time by default, so the literal trailing
# "Z" is only accurate when the host clock runs in UTC — confirm for the deployment target.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] [%(name)s] %(message)s",
    datefmt="%Y-%m-%dT%H:%M:%SZ",
)
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# Run on the first GPU when one is visible, otherwise on CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
logger.info(f"Model will be loaded on device `{device}`")

# Loaded Whisper models keyed by Hub repo id (filled lazily by `maybe_load_cached_pipeline`).
cached_models = dict()
def format_timestamp(seconds):
    """Render a duration in seconds as ``H:MM:SS``, rounded to the nearest whole second.

    Durations of one day or more use ``timedelta``'s ``"N day(s), H:MM:SS"`` form.
    """
    whole_seconds = round(seconds)
    duration = datetime.timedelta(seconds=whole_seconds)
    return str(duration)
def _return_yt_html_embed(yt_url):
video_id = yt_url.split("?v=")[-1]
HTML_str = (
f'<center> <iframe width="500" height="320" src="https://www.youtube.com/embed/{video_id}"> </iframe>' " </center>"
)
return HTML_str
def download_audio_from_youtube(yt_url, downloaded_filename="audio.wav"):
    """Save the first audio-only stream of a YouTube video to ``downloaded_filename``.

    NOTE(review): the stream is written as-is with no conversion step here, so the file
    content is presumably whatever container YouTube serves (not necessarily WAV despite the
    default name) — confirm that the downstream loader probes the real format.

    Returns:
        The path that was written (``downloaded_filename``).
    """
    audio_streams = pt.YouTube(yt_url).streams.filter(only_audio=True)
    first_audio_stream = audio_streams[0]
    first_audio_stream.download(filename=downloaded_filename)
    return downloaded_filename
def download_video_from_youtube(yt_url, downloaded_filename="video.mp4"):
    """Download the highest-resolution progressive MP4 stream of a YouTube video.

    Args:
        yt_url: URL of the YouTube video.
        downloaded_filename: Local path to write the video to.

    Returns:
        The path that was written (``downloaded_filename``).

    Raises:
        ValueError: If the video offers no progressive MP4 stream.
    """
    yt = pt.YouTube(yt_url)
    stream = (
        yt.streams.filter(progressive=True, file_extension="mp4")
        .order_by("resolution")
        .desc()
        .first()
    )
    # `.first()` yields None on an empty query; fail with a clear message instead of an
    # AttributeError on `None.download`.
    if stream is None:
        raise ValueError(f"No progressive MP4 stream is available for YouTube video {yt_url}")
    stream.download(filename=downloaded_filename)
    logger.info(f"Download YouTube video from {yt_url}")
    return downloaded_filename
def _print_memory_info():
    """Log the host's available RAM, percent of RAM in use, and total RAM (sizes in GiB)."""
    gib = 1024 ** 3
    vm = psutil.virtual_memory()
    free_gb = vm.available / gib
    total_gb = vm.total / gib
    logger.info(f"Memory info - Free: {free_gb:.2f} Gb, used: {vm.percent}%, total: {total_gb:.2f} Gb")
def _print_cuda_memory_info():
used_mem, tot_mem = torch.cuda.mem_get_info()
logger.info(
f"CUDA memory info - Free: {used_mem / 1024 ** 3:.2f} Gb, used: {(tot_mem - used_mem) / 1024 ** 3:.2f} Gb, total: {tot_mem / 1024 ** 3:.2f} Gb"
)
def print_memory_info():
    """Log host RAM usage, then CUDA device memory usage."""
    for report in (_print_memory_info, _print_cuda_memory_info):
        report()
def maybe_load_cached_pipeline(model_name):
    """Return the Whisper model for ``model_name``, downloading and loading it on first use.

    Loaded models are memoized in the module-level ``cached_models`` dict, so each Hub repo is
    fetched and loaded onto ``device`` at most once per process.

    Args:
        model_name: Hugging Face Hub repo id containing ``CHECKPOINT_FILENAME``.

    Returns:
        The loaded Whisper model.
    """
    if model_name in cached_models:
        return cached_models[model_name]
    checkpoint_path = hf_hub_download(repo_id=model_name, filename=CHECKPOINT_FILENAME)
    loaded = whisper.load_model(checkpoint_path, device=device)
    logger.info(f"`{model_name}` has been loaded on device `{device}`")
    print_memory_info()
    cached_models[model_name] = loaded
    return loaded
def _segments_to_dataframe(segments):
    """Tabulate Whisper segments as ``start``/``end`` (``H:MM:SS``) and stripped ``text`` columns."""
    # Passing `columns` keeps only the needed keys and yields an empty frame with the expected
    # columns when there are no segments (e.g. silent audio), instead of raising a KeyError.
    df = pd.DataFrame(segments, columns=["start", "end", "text"])
    df["start"] = df["start"].map(format_timestamp)
    df["end"] = df["end"].map(format_timestamp)
    df["text"] = df["text"].str.strip()
    return df


def _segments_to_text(segments):
    """Render Whisper segments as blank-line-separated "Segment N from Xs to Ys:" paragraphs."""
    return "\n\n".join(
        f'Segment {segment["id"]+1} from {segment["start"]:.2f}s to {segment["end"]:.2f}s:\n{segment["text"].strip()}'
        for segment in segments
    )


def infer(model, filename, with_timestamps, return_df=False):
    """Transcribe an audio file with a Whisper model and format the result.

    Args:
        model: Loaded Whisper model; only ``model.transcribe(filename, **kwargs)`` is used.
        filename: Path of the audio file to transcribe.
        with_timestamps: If true, keep per-segment timing; otherwise transcribe without timestamps.
        return_df: If true, return a ``pandas.DataFrame``; otherwise return a string.

    Returns:
        With timestamps: a DataFrame with ``start``/``end``/``text`` columns, or a string with one
        "Segment N from Xs to Ys:" paragraph per segment.
        Without timestamps: a single ``text`` column DataFrame with one sentence per row, or the
        raw transcript string.
    """
    if with_timestamps:
        segments = model.transcribe(filename, **GEN_KWARGS)["segments"]
        if return_df:
            return _segments_to_dataframe(segments)
        return _segments_to_text(segments)

    # Merge instead of passing `without_timestamps` alongside **GEN_KWARGS, so that enabling that
    # key in GEN_KWARGS cannot cause a "multiple values for keyword argument" TypeError.
    text = model.transcribe(filename, **{**GEN_KWARGS, "without_timestamps": True})["text"]
    if return_df:
        # nltk's sent_tokenize defaults to the English Punkt model; the transcripts are French
        # (GEN_KWARGS["language"] == "fr"), so use the French one.
        return pd.DataFrame({"text": sent_tokenize(text, language="french")})
    return text
def transcribe(microphone, file_upload, with_timestamps, model_name=DEFAULT_MODEL_NAME):
    """Transcribe a microphone recording or an uploaded audio file ("Transcribe Audio" tab).

    If both inputs are provided, the microphone recording is used and a warning is logged.

    Args:
        microphone: Path of the recorded audio, or None.
        file_upload: Path of the uploaded audio file, or None.
        with_timestamps: Whether to produce per-segment timestamps.
        model_name: Hub repo id of the model to use.

    Returns:
        The DataFrame produced by ``infer(..., return_df=True)``.

    Raises:
        ValueError: If neither a recording nor an uploaded file was provided.
    """
    if microphone is None and file_upload is None:
        # Raise rather than return a message: the output component is a gr.DataFrame, so a plain
        # string is not a displayable result (same convention as `video_transcribe`).
        raise ValueError("You have to either use the microphone or upload an audio file")
    if microphone is not None and file_upload is not None:
        # The UI output is a table with no room for a warning, so surface the conflict in the logs.
        logger.warning(
            "You've uploaded an audio file and used the microphone. "
            "The recorded file from the microphone will be used and the uploaded audio will be discarded."
        )

    audio_path = microphone if microphone is not None else file_upload
    model = maybe_load_cached_pipeline(model_name)
    text = infer(model, audio_path, with_timestamps, return_df=True)
    logger.info(f'Transcription by `{model_name}`:\n{text.to_json(orient="index", force_ascii=False, indent=2)}\n')
    return text
def yt_transcribe(yt_url, with_timestamps, model_name=DEFAULT_MODEL_NAME):
    """Download a YouTube video's audio and return its transcription as a DataFrame.

    Args:
        yt_url: URL of the YouTube video.
        with_timestamps: Whether to produce per-segment timestamps.
        model_name: Hub repo id of the model to use.

    Returns:
        The DataFrame produced by ``infer(..., return_df=True)``.
    """
    audio_path = download_audio_from_youtube(yt_url)
    whisper_model = maybe_load_cached_pipeline(model_name)
    transcript_df = infer(whisper_model, audio_path, with_timestamps, return_df=True)
    as_json = transcript_df.to_json(orient="index", force_ascii=False, indent=2)
    logger.info(f'Transcription by `{model_name}` of "{yt_url}":\n{as_json}\n')
    return transcript_df
def video_transcribe(video_file_path, with_timestamps, model_name=DEFAULT_MODEL_NAME):
    """Extract a video's audio track with ffmpeg and return its transcription as a DataFrame.

    The audio is written next to the video, with the same base name and a ``.wav`` extension,
    as 16 kHz mono 16-bit PCM.

    Args:
        video_file_path: Path of the video file, or None if nothing was provided.
        with_timestamps: Whether to produce per-segment timestamps.
        model_name: Hub repo id of the model to use.

    Returns:
        The DataFrame produced by ``infer(..., return_df=True)``.

    Raises:
        ValueError: If no video path was provided.
        subprocess.CalledProcessError: If ffmpeg fails to extract the audio.
    """
    import subprocess

    if video_file_path is None:
        raise ValueError("Failed to transcribe video as no video_file_path has been defined")

    # Swap whatever extension the video has for .wav, so the output never aliases the input.
    audio_file_path = os.path.splitext(video_file_path)[0] + ".wav"
    # Pass an argument list (no shell) so quotes or `$(...)` in the path are never interpreted.
    # `-y` overwrites a WAV left by an earlier run instead of prompting. stdin is closed so
    # ffmpeg never waits for keyboard input.
    subprocess.run(
        [
            "ffmpeg", "-y",
            "-i", video_file_path,
            "-ar", "16000", "-ac", "1", "-c:a", "pcm_s16le",
            audio_file_path,
        ],
        check=True,
        stdin=subprocess.DEVNULL,
    )

    model = maybe_load_cached_pipeline(model_name)
    text = infer(model, audio_file_path, with_timestamps, return_df=True)
    logger.info(f'Transcription by `{model_name}`:\n{text.to_json(orient="index", force_ascii=False, indent=2)}\n')
    return text
# Load the default model at import time. `maybe_load_cached_pipeline` stores it in
# `cached_models`, so later requests for DEFAULT_MODEL_NAME reuse it instead of reloading.
maybe_load_cached_pipeline(DEFAULT_MODEL_NAME)
# default_text_output_df = pd.DataFrame(columns=["start", "end", "text"])
# Empty single-column table used as the initial value of the transcription output components.
default_text_output_df = pd.DataFrame(columns=["text"])
# Two-tab Gradio UI: audio transcription (microphone/upload) and video transcription
# (YouTube download or video file). Component creation order below defines the layout.
with gr.Blocks() as demo:
    # --- Tab 1: transcribe a microphone recording or an uploaded audio file ---
    with gr.Tab("Transcribe Audio"):
        gr.Markdown(
            f"""
<div>
<h1 style='text-align: center'>Whisper French Demo: Transcribe Audio</h1>
</div>
Transcribe long-form microphone or audio inputs!
Demo uses the fine-tuned checkpoint: <a href='https://huggingface.co/{DEFAULT_MODEL_NAME}' target='_blank'><b>{DEFAULT_MODEL_NAME}</b></a> to transcribe audio files of arbitrary length.
"""
        )
        # Both audio inputs are optional; `transcribe` decides which one to use.
        microphone_input = gr.inputs.Audio(source="microphone", type="filepath", label="Record", optional=True)
        upload_input = gr.inputs.Audio(source="upload", type="filepath", label="Upload File", optional=True)
        with_timestamps_input = gr.Checkbox(label="With timestamps?")
        microphone_transcribe_btn = gr.Button("Transcribe Audio")
        # gr.Markdown('''
        # Here you will get generated transcrit.
        # ''')
        # microphone_text_output = gr.outputs.Textbox(label="Transcription")
        # Results table: starts empty and paginates after 10 rows.
        text_output_df2 = gr.DataFrame(
            value=default_text_output_df,
            label="Transcription",
            row_count=(0, "dynamic"),
            max_rows=10,
            wrap=True,
            overflow_row_behaviour="paginate",
        )
        microphone_transcribe_btn.click(
            transcribe, inputs=[microphone_input, upload_input, with_timestamps_input], outputs=text_output_df2
        )
    # Disabled YouTube-only tab; YouTube links are currently handled in the "Transcribe Video" tab.
    # with gr.Tab("Transcribe YouTube"):
    #     gr.Markdown(
    #         f"""
    # <div>
    # <h1 style='text-align: center'>Whisper French Demo: Transcribe YouTube</h1>
    # </div>
    # Transcribe long-form YouTube videos!
    # Demo uses the fine-tuned checkpoint: <a href='https://huggingface.co/{DEFAULT_MODEL_NAME}' target='_blank'><b>{DEFAULT_MODEL_NAME}</b></a> to transcribe video files of arbitrary length.
    # """
    #     )
    #     yt_link_input2 = gr.inputs.Textbox(lines=1, placeholder="Paste the URL to a YouTube video here", label="YouTube URL")
    #     with_timestamps_input2 = gr.Checkbox(label="With timestamps?", value=True)
    #     yt_transcribe_btn = gr.Button("Transcribe YouTube")
    #     # yt_text_output = gr.outputs.Textbox(label="Transcription")
    #     text_output_df3 = gr.DataFrame(
    #         value=default_text_output_df,
    #         label="Transcription",
    #         row_count=(0, "dynamic"),
    #         max_rows=10,
    #         wrap=True,
    #         overflow_row_behaviour="paginate",
    #     )
    #     # yt_html_output = gr.outputs.HTML(label="YouTube Page")
    #     yt_transcribe_btn.click(yt_transcribe, inputs=[yt_link_input2, with_timestamps_input2], outputs=[text_output_df3])
    # --- Tab 2: download a YouTube video (or use a video file) and transcribe it ---
    with gr.Tab("Transcribe Video"):
        gr.Markdown(
            f"""
<div>
<h1 style='text-align: center'>Whisper French Demo: Transcribe Video</h1>
</div>
Transcribe long-form YouTube videos or uploaded video inputs!
Demo uses the fine-tuned checkpoint: <a href='https://huggingface.co/{DEFAULT_MODEL_NAME}' target='_blank'><b>{DEFAULT_MODEL_NAME}</b></a> to transcribe video files of arbitrary length.
"""
        )
        yt_link_input = gr.inputs.Textbox(lines=1, placeholder="Paste the URL to a YouTube video here", label="YouTube URL")
        download_youtube_btn = gr.Button("Download Youtube video")
        # The same Video component receives the downloaded file and is the input of `video_transcribe`.
        downloaded_video_output = gr.Video(label="Video file", mirror_webcam=False)
        download_youtube_btn.click(download_video_from_youtube, inputs=[yt_link_input], outputs=[downloaded_video_output])
        with_timestamps_input3 = gr.Checkbox(label="With timestamps?", value=True)
        video_transcribe_btn = gr.Button("Transcribe video")
        # Results table: starts empty and paginates after 10 rows.
        text_output_df = gr.DataFrame(
            value=default_text_output_df,
            label="Transcription",
            row_count=(0, "dynamic"),
            max_rows=10,
            wrap=True,
            overflow_row_behaviour="paginate",
        )
        video_transcribe_btn.click(video_transcribe, inputs=[downloaded_video_output, with_timestamps_input3], outputs=[text_output_df])
# demo.launch(server_name="0.0.0.0", debug=True)
# demo.launch(server_name="0.0.0.0", debug=True, share=True)
# Start the app with request queueing enabled.
demo.launch(enable_queue=True)