Rename pages/03_π_Upload_Audio_File.py to pages/04_π_Upload_Audio_File.py
43445ac
import whisper | |
import streamlit as st | |
from streamlit_lottie import st_lottie | |
from utils import write_vtt, write_srt | |
import ffmpeg | |
import requests | |
from typing import Iterator | |
from io import StringIO | |
import numpy as np | |
import pathlib | |
import os | |
# Page-wide Streamlit settings: browser tab title, tab icon and wide layout.
# NOTE(review): the icon character looks mis-encoded in this copy of the file;
# confirm the intended emoji before shipping.
st.set_page_config(
    page_title="Auto Transcriber",
    page_icon="π",
    layout="wide",
)
# Define a function that we can use to load lottie files from a link.
def load_lottieurl(url: str):
    """Download a Lottie animation and return its decoded JSON payload.

    Args:
        url: Address of the Lottie JSON document.

    Returns:
        The decoded JSON object, or ``None`` when the request fails, the
        server answers with a status other than 200, or the body is not
        valid JSON.
    """
    try:
        # A timeout keeps an unresponsive host from blocking the page forever.
        r = requests.get(url, timeout=10)
    except requests.RequestException:
        return None
    if r.status_code != 200:
        return None
    try:
        return r.json()
    except ValueError:
        # Raised by r.json() when the body cannot be decoded as JSON.
        return None
# Working directories used for uploaded and converted audio. They live next to
# this script and are created on start-up if they do not exist yet.
APP_DIR = pathlib.Path(__file__).parent.absolute()
LOCAL_DIR = APP_DIR / "local_audio"
save_dir = LOCAL_DIR / "output"

# The parent has to exist before the nested output directory can be created.
LOCAL_DIR.mkdir(exist_ok=True)
save_dir.mkdir(exist_ok=True)
# Page header: a narrow column with the animation and a wide column with the
# title and usage hints.
col1, col2 = st.columns([1, 3])
with col1:
    lottie = load_lottieurl("https://assets1.lottiefiles.com/packages/lf20_1xbk4d2v.json")
    # load_lottieurl() returns None when the download fails; skip the
    # decorative animation in that case instead of handing None to st_lottie.
    if lottie is not None:
        st_lottie(lottie)
with col2:
    # NOTE(review): the bullet characters in the text below look mis-encoded
    # in this copy of the file; confirm the intended symbols.
    st.write("""
## Auto Transcriber
##### Input an audio file and get a transcript.
###### β If you want to transcribe the audio in its original language, select the task as "Transcribe"
###### β If you want to translate the transcription to English, select the task as "Translate"
###### I recommend starting with the base model and then experimenting with the larger models, the small and medium models often work well. """)
# Load the "base" Whisper model as soon as the script runs.
# NOTE(review): main() binds its own local `loaded_model` through
# change_model(), so this module-level model does not appear to be read
# anywhere in the visible code. Confirm before removing it; the load is
# presumably repeated on every run of the script.
loaded_model = whisper.load_model("base")
# Value handed to change_model() as the "currently loaded" size. It is the
# literal string "None" (not the None object) and is never reassigned in the
# visible code, so it never equals one of the selectable model sizes.
current_size = "None"
def change_model(current_size, size):
    """Load and return the Whisper model named by ``size``.

    Args:
        current_size: Name of the size considered to be loaded already.
        size: Name of the size to load.

    Returns:
        The model object returned by ``whisper.load_model(size)``.

    Raises:
        Exception: If ``size`` equals ``current_size``.
    """
    if current_size == size:
        raise Exception("Model size is the same as the current size.")
    return whisper.load_model(size)
def inferecence(loaded_model, uploaded_file, task):
    """Run Whisper on an uploaded audio file and build subtitle files.

    The upload is written to ``save_dir/input.mp3``, converted with ffmpeg to
    ``save_dir/output.wav`` (``pcm_s16le`` codec, one channel, ``16k`` sample
    rate) and then passed to ``loaded_model.transcribe``.

    Args:
        loaded_model: Whisper model exposing ``transcribe``.
        uploaded_file: Binary file-like object holding the audio to process.
        task: ``"Transcribe"`` to keep the spoken language or ``"Translate"``
            to translate the transcript to English.

    Returns:
        A tuple ``(text, vtt, srt, language)`` with the plain transcript, the
        WebVTT and SubRip renderings of the segments, and the ``"language"``
        value reported by the model.

    Raises:
        ValueError: If ``task`` is not supported or no file was provided.
    """
    # Validate the arguments first so that bad input is rejected before any
    # file is written or the ffmpeg conversion is started.
    if task == "Transcribe":
        whisper_task = "transcribe"
    elif task == "Translate":
        whisper_task = "translate"
    else:
        raise ValueError("Task not supported")
    if uploaded_file is None:
        raise ValueError("No audio file was provided.")

    input_path = f"{save_dir}/input.mp3"
    output_path = f"{save_dir}/output.wav"

    with open(input_path, "wb") as f:
        f.write(uploaded_file.read())

    audio = ffmpeg.input(input_path)
    audio = ffmpeg.output(audio, output_path, acodec="pcm_s16le", ac=1, ar="16k")
    ffmpeg.run(audio, overwrite_output=True)

    options = dict(task=whisper_task, best_of=5)
    results = loaded_model.transcribe(output_path, **options)

    segments = results["segments"]
    vtt = getSubs(segments, "vtt", 80)
    srt = getSubs(segments, "srt", 80)
    lang = results["language"]
    return results["text"], vtt, srt, lang
def getSubs(segments: Iterator[dict], format: str, maxLineWidth: int) -> str:
    """Render transcript segments as subtitle text.

    Args:
        segments: Segment dictionaries handed to the subtitle writer.
        format: ``"vtt"`` or ``"srt"``.
        maxLineWidth: Forwarded to the writer as ``maxLineWidth``.

    Returns:
        Everything the selected writer wrote, as one string.

    Raises:
        Exception: If ``format`` is neither ``"vtt"`` nor ``"srt"``.
    """
    if format not in ("vtt", "srt"):
        raise Exception("Unknown format " + format)

    writer = write_vtt if format == "vtt" else write_srt
    buffer = StringIO()
    writer(segments, file=buffer, maxLineWidth=maxLineWidth)
    return buffer.getvalue()
def _show_results(input_file, results):
    """Render the audio player, the three download buttons and the hints."""
    text, vtt, srt = results[0], results[1], results[2]

    col3, _ = st.columns(2)
    col5, col6, col7 = st.columns(3)
    col9, col10 = st.columns(2)

    with col3:
        st.audio(input_file)

    # Encode in memory rather than writing transcript files into the working
    # directory and reading them back; those files would be shared by every
    # session of the app.
    downloads = (
        (col5, "Download Transcript (.txt)", text, "transcript.txt"),
        (col6, "Download Transcript (.vtt)", vtt, "transcript.vtt"),
        (col7, "Download Transcript (.srt)", srt, "transcript.srt"),
    )
    for column, label, content, file_name in downloads:
        with column:
            st.download_button(label=label,
                               data=content.encode("utf-8"),
                               file_name=file_name)

    with col9:
        st.success("You can download the transcript in .srt format, edit it (if you need to) and upload it to YouTube to create subtitles for your video.")
    with col10:
        # NOTE(review): nothing in this function stores the result between
        # script runs, so the caching claim in this hint may be inaccurate;
        # confirm.
        st.info("Streamlit refreshes after the download button is clicked. The data is cached so you can download the transcript again without having to transcribe the video again.")


def main():
    """Build the page: model picker, file upload, task picker and results.

    The selected model is loaded through ``change_model``. When the action
    button is pressed, the upload is processed by ``inferecence`` and the
    transcript is offered as .txt, .vtt and .srt downloads. If no file has
    been uploaded yet, an error message is shown instead of running the model.
    """
    size = st.selectbox("Select Model Size (The larger the model, the more accurate the transcription will be, but it will take longer)", ["tiny", "base", "small", "medium", "large"], index=1)
    loaded_model = change_model(current_size, size)
    st.write(f"Model is {'multilingual' if loaded_model.is_multilingual else 'English-only'} "
             f"and has {sum(np.prod(p.shape) for p in loaded_model.parameters()):,} parameters.")

    input_file = st.file_uploader("Upload an audio file", type=["mp3", "wav", "m4a"])
    task = st.selectbox("Select Task", ["Transcribe", "Translate"], index=0)

    # Both tasks share one code path; only the button caption differs.
    button_label = "Transcribe" if task == "Transcribe" else "Translate to English"
    if not st.button(button_label):
        return

    # Without this guard a click before uploading would pass None to
    # inferecence() and crash the page.
    if input_file is None:
        st.error("Please upload an audio file first.")
        return

    results = inferecence(loaded_model, input_file, task)
    _show_results(input_file, results)
# Entry point: build the page, then append the author credit below it.
if __name__ == "__main__":
    main()
    # The author's surname was mis-encoded in this copy of the file ("Δ±" is
    # the UTF-8 byte pair of the dotless "ı"); it is restored here.
    st.markdown("###### Made with :heart: by [@BatuhanYılmaz](https://twitter.com/batuhan3326) [![this is an image link](https://i.imgur.com/thJhzOO.png)](https://www.buymeacoffee.com/batuhanylmz)")