Spaces:
Sleeping
Sleeping
BatuhanYilmaz
committed on
Commit
·
8335d37
1
Parent(s):
ec4b0ac
Upload video + transcript to get a subtitled video
Browse files
components/.env
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
COGNITO_DOMAIN = "xxx"
|
2 |
+
CLIENT_ID = "xxx"
|
3 |
+
CLIENT_SECRET = "xxx"
|
4 |
+
APP_URI = "xxx"
|
components/__init__.py
ADDED
File without changes
|
components/authenticate.py
ADDED
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import streamlit as st
|
3 |
+
from dotenv import load_dotenv
|
4 |
+
import requests
|
5 |
+
import base64
|
6 |
+
import json
|
7 |
+
|
8 |
+
load_dotenv()
|
9 |
+
COGNITO_DOMAIN = os.environ.get("COGNITO_DOMAIN")
|
10 |
+
CLIENT_ID = os.environ.get("CLIENT_ID")
|
11 |
+
CLIENT_SECRET = os.environ.get("CLIENT_SECRET")
|
12 |
+
APP_URI = os.environ.get("APP_URI")
|
13 |
+
|
14 |
+
|
15 |
+
def init_state():
    """Seed the Streamlit session with the authentication defaults.

    Keys already present in ``st.session_state`` are left untouched, so
    calling this on every script run does not reset an existing session.
    """
    defaults = {
        "auth_code": "",
        "authenticated": False,
        "user_cognito_groups": [],
    }
    for key, initial_value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = initial_value
|
22 |
+
|
23 |
+
# Get the authorization code after the user has logged in
|
24 |
+
def get_auth_code():
    """Return the ``code`` query parameter of the current URL, or ``""``.

    Returns:
        The first value of the ``code`` query parameter, or an empty string
        when the URL carries no such parameter.
    """
    query_params = st.experimental_get_query_params()
    try:
        return dict(query_params)["code"][0]
    except (KeyError, TypeError):
        # No usable ``code`` parameter in the URL.
        return ""
|
31 |
+
|
32 |
+
|
33 |
+
# Set the authorization code after the user has logged in
|
34 |
+
def set_auth_code():
    """Copy the authorization code from the URL into the session state."""
    init_state()
    st.session_state["auth_code"] = get_auth_code()
|
38 |
+
|
39 |
+
|
40 |
+
# Get the access token from the authorization code
|
41 |
+
def get_user_tokens(auth_code):
    """Exchange an OAuth2 authorization code for tokens at the token endpoint.

    Args:
        auth_code: Authorization code taken from the redirect URL. A falsy
            value means there is nothing to exchange.

    Returns:
        An ``(access_token, id_token)`` tuple of strings. Both are ``""``
        when no code was supplied, when the response body is not JSON, or
        when it does not contain both tokens.
    """
    if not auth_code:
        # Nothing to exchange: skip the HTTP round trip entirely.
        return "", ""

    token_url = f"{COGNITO_DOMAIN}/oauth2/token"
    # HTTP Basic credentials are base64("client_id:client_secret").
    client_secret_string = f"{CLIENT_ID}:{CLIENT_SECRET}"
    client_secret_encoded = base64.b64encode(
        client_secret_string.encode("utf-8")
    ).decode("utf-8")
    headers = {
        "Content-Type": "application/x-www-form-urlencoded",
        "Authorization": f"Basic {client_secret_encoded}",
    }
    body = {
        "grant_type": "authorization_code",
        "client_id": CLIENT_ID,
        "code": auth_code,
        "redirect_uri": APP_URI,
    }

    # Without a timeout a stalled endpoint would block the script forever.
    token_response = requests.post(
        token_url, headers=headers, data=body, timeout=10
    )
    try:
        tokens = token_response.json()
        return tokens["access_token"], tokens["id_token"]
    except (KeyError, TypeError, ValueError):
        # ValueError: the body was not JSON. KeyError/TypeError: the payload
        # (e.g. an error object) does not carry both tokens.
        return "", ""
|
68 |
+
|
69 |
+
|
70 |
+
# Use access token to retrieve user info
|
71 |
+
def get_user_info(access_token):
    """Fetch the user's profile from the ``/oauth2/userInfo`` endpoint.

    Args:
        access_token: Bearer token sent in the ``Authorization`` header.

    Returns:
        The decoded JSON body of the response.

    Raises:
        requests.RequestException: If the request fails or times out.
    """
    userinfo_url = f"{COGNITO_DOMAIN}/oauth2/userInfo"
    headers = {
        "Content-Type": "application/json;charset=UTF-8",
        "Authorization": f"Bearer {access_token}",
    }

    # Without a timeout a stalled endpoint would block the script forever.
    userinfo_response = requests.get(userinfo_url, headers=headers, timeout=10)

    return userinfo_response.json()
|
81 |
+
|
82 |
+
|
83 |
+
# Decode the ID token (a JWT) to get the user's Cognito groups
|
84 |
+
def pad_base64(data):
    """Return *data* right-padded with ``=`` to a multiple of four characters.

    The base64 decoder needs the padding that token segments omit.
    """
    remainder = len(data) % 4
    if remainder == 0:
        return data
    return data + "=" * (4 - remainder)


def get_user_cognito_groups(id_token):
    """Read the ``cognito:groups`` claim from an ID token.

    The token is split into its three dot-separated segments and the middle
    (payload) segment is base64url-decoded and parsed as JSON. The signature
    segment is not verified here.

    Args:
        id_token: The encoded token, or a falsy value when there is none.

    Returns:
        The list of group names, or ``[]`` when the token is missing,
        malformed, or carries no ``cognito:groups`` claim.
    """
    if not id_token:
        return []
    try:
        _header, payload, _signature = id_token.split(".")
        claims = json.loads(base64.urlsafe_b64decode(pad_base64(payload)))
        return list(dict(claims)["cognito:groups"])
    except (KeyError, TypeError, ValueError):
        # ValueError covers a wrong segment count, invalid base64 and invalid
        # JSON; KeyError/TypeError cover a missing or non-iterable claim.
        return []
|
102 |
+
|
103 |
+
|
104 |
+
# Set streamlit state variables
|
105 |
+
def set_st_state_vars():
    """Authenticate the session from the authorization code in the URL.

    Initializes the session defaults, then exchanges the URL's ``code``
    parameter for tokens. On success the session is marked authenticated and
    the user's Cognito groups are stored. On failure the existing session
    state is left as it was.
    """
    init_state()
    auth_code = get_auth_code()
    if not auth_code:
        # No code in the URL: nothing to exchange, keep the current state.
        return
    if (
        st.session_state["authenticated"]
        and st.session_state["auth_code"] == auth_code
    ):
        # This code was already exchanged by an earlier run of the script.
        # Authorization codes are single-use, so posting it again could only
        # fail and would cost an HTTP round trip on every rerun.
        return

    access_token, id_token = get_user_tokens(auth_code)
    if access_token != "":
        st.session_state["auth_code"] = auth_code
        st.session_state["authenticated"] = True
        st.session_state["user_cognito_groups"] = get_user_cognito_groups(id_token)
|
115 |
+
|
116 |
+
|
117 |
+
# Login/ Logout HTML components
|
118 |
+
# Login and logout URLs, assembled piecewise from the environment settings.
login_link = (
    f"{COGNITO_DOMAIN}/login"
    f"?client_id={CLIENT_ID}"
    "&response_type=code"
    "&scope=email+openid"
    f"&redirect_uri={APP_URI}"
)
logout_link = (
    f"{COGNITO_DOMAIN}/logout"
    f"?client_id={CLIENT_ID}"
    f"&logout_uri={APP_URI}"
)
|
120 |
+
|
121 |
+
html_css_login = """
|
122 |
+
<style>
|
123 |
+
.button-login {
|
124 |
+
background-color: skyblue;
|
125 |
+
color: white !important;
|
126 |
+
padding: 1em 1.5em;
|
127 |
+
text-decoration: none;
|
128 |
+
text-transform: uppercase;
|
129 |
+
}
|
130 |
+
|
131 |
+
.button-login:hover {
|
132 |
+
background-color: #555;
|
133 |
+
text-decoration: none;
|
134 |
+
}
|
135 |
+
|
136 |
+
.button-login:active {
|
137 |
+
background-color: black;
|
138 |
+
}
|
139 |
+
|
140 |
+
</style>
|
141 |
+
"""
|
142 |
+
|
143 |
+
# Each button is the shared stylesheet followed by one styled anchor tag.
html_button_login = f"{html_css_login}<a href='{login_link}' class='button-login' target='_self'>Log In</a>"
html_button_logout = f"{html_css_login}<a href='{logout_link}' class='button-login' target='_self'>Log Out</a>"
|
151 |
+
|
152 |
+
|
153 |
+
def button_login():
    """Render the "Log In" link in the sidebar and return the element."""
    return st.sidebar.markdown(html_button_login, unsafe_allow_html=True)
|
155 |
+
|
156 |
+
|
157 |
+
def button_logout():
    """Render the "Log Out" link in the sidebar and return the element."""
    return st.sidebar.markdown(html_button_logout, unsafe_allow_html=True)
|
pages/02_📼_Upload_Video_File.py
CHANGED
@@ -9,6 +9,8 @@ from io import StringIO
|
|
9 |
import numpy as np
|
10 |
import pathlib
|
11 |
import os
|
|
|
|
|
12 |
|
13 |
st.set_page_config(page_title="Auto Subtitled Video Generator", page_icon=":movie_camera:", layout="wide")
|
14 |
|
@@ -36,7 +38,7 @@ current_size = "None"
|
|
36 |
col1, col2 = st.columns([1, 3])
|
37 |
with col1:
|
38 |
lottie = load_lottieurl("https://assets1.lottiefiles.com/packages/lf20_HjK9Ol.json")
|
39 |
-
st_lottie(lottie
|
40 |
|
41 |
with col2:
|
42 |
st.write("""
|
@@ -49,8 +51,10 @@ with col2:
|
|
49 |
|
50 |
@st.cache(allow_output_mutation=True)
|
51 |
def change_model(current_size, size):
|
|
|
|
|
52 |
if current_size != size:
|
53 |
-
loaded_model = whisper.load_model(size)
|
54 |
return loaded_model
|
55 |
else:
|
56 |
raise Exception("Model size is the same as the current size.")
|
@@ -98,7 +102,7 @@ def getSubs(segments: Iterator[dict], format: str, maxLineWidth: int) -> str:
|
|
98 |
def generate_subtitled_video(video, audio, transcript):
|
99 |
video_file = ffmpeg.input(video)
|
100 |
audio_file = ffmpeg.input(audio)
|
101 |
-
ffmpeg.concat(video_file.filter("subtitles", transcript), audio_file, v=1, a=1).output("final.mp4").run(quiet=True, overwrite_output=True)
|
102 |
video_with_subs = open("final.mp4", "rb")
|
103 |
return video_with_subs
|
104 |
|
@@ -108,7 +112,7 @@ def main():
|
|
108 |
loaded_model = change_model(current_size, size)
|
109 |
st.write(f"Model is {'multilingual' if loaded_model.is_multilingual else 'English-only'} "
|
110 |
f"and has {sum(np.prod(p.shape) for p in loaded_model.parameters()):,} parameters.")
|
111 |
-
input_file = st.file_uploader("File", type=["mp4", "avi", "mov", "mkv"])
|
112 |
# get the name of the input_file
|
113 |
if input_file is not None:
|
114 |
filename = input_file.name[:-4]
|
@@ -226,5 +230,10 @@ def main():
|
|
226 |
|
227 |
|
228 |
if __name__ == "__main__":
|
229 |
-
|
230 |
-
st.
|
|
|
|
|
|
|
|
|
|
|
|
9 |
import numpy as np
|
10 |
import pathlib
|
11 |
import os
|
12 |
+
import components.authenticate as authenticate
|
13 |
+
import torch
|
14 |
|
15 |
st.set_page_config(page_title="Auto Subtitled Video Generator", page_icon=":movie_camera:", layout="wide")
|
16 |
|
|
|
38 |
col1, col2 = st.columns([1, 3])
|
39 |
with col1:
|
40 |
lottie = load_lottieurl("https://assets1.lottiefiles.com/packages/lf20_HjK9Ol.json")
|
41 |
+
st_lottie(lottie)
|
42 |
|
43 |
with col2:
|
44 |
st.write("""
|
|
|
51 |
|
52 |
@st.cache(allow_output_mutation=True)
def change_model(current_size, size):
    """Load the Whisper model of the requested size.

    Args:
        current_size: Size name of the model considered currently loaded.
        size: Size name of the model to load.

    Returns:
        The model returned by ``whisper.load_model``, placed on ``"cuda"``
        when ``torch.cuda.is_available()`` is true and on ``"cpu"`` otherwise.

    Raises:
        Exception: If ``size`` equals ``current_size``.
    """
    if current_size == size:
        raise Exception("Model size is the same as the current size.")
    # Probe for a GPU only once we know a model will actually be loaded.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return whisper.load_model(size, device=device)
|
|
|
102 |
def generate_subtitled_video(video, audio, transcript):
    """Burn *transcript* into *video*, combine it with *audio*, return the file.

    The result is written to ``final.mp4`` (overwriting any existing file)
    and ffmpeg is run with ``-report``.

    Returns:
        ``final.mp4`` opened in binary read mode; the caller must close it.
    """
    subtitled_stream = ffmpeg.input(video).filter("subtitles", transcript)
    audio_stream = ffmpeg.input(audio)
    job = (
        ffmpeg.concat(subtitled_stream, audio_stream, v=1, a=1)
        .output("final.mp4")
        .global_args('-report')
    )
    job.run(quiet=True, overwrite_output=True)
    return open("final.mp4", "rb")
|
108 |
|
|
|
112 |
loaded_model = change_model(current_size, size)
|
113 |
st.write(f"Model is {'multilingual' if loaded_model.is_multilingual else 'English-only'} "
|
114 |
f"and has {sum(np.prod(p.shape) for p in loaded_model.parameters()):,} parameters.")
|
115 |
+
input_file = st.file_uploader("Upload Video File", type=["mp4", "avi", "mov", "mkv"])
|
116 |
# get the name of the input_file
|
117 |
if input_file is not None:
|
118 |
filename = input_file.name[:-4]
|
|
|
230 |
|
231 |
|
232 |
if __name__ == "__main__":
    # Gate the page behind the login: anonymous visitors only get the prompt.
    authenticate.set_st_state_vars()
    if not st.session_state["authenticated"]:
        st.info("Please log in or sign up to use the app.")
        authenticate.button_login()
    else:
        main()
        authenticate.button_logout()
|
pages/03_📝_Upload_Video_File_and_Transcript.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from streamlit_lottie import st_lottie
|
3 |
+
from utils import write_vtt, write_srt
|
4 |
+
import ffmpeg
|
5 |
+
import requests
|
6 |
+
from typing import Iterator
|
7 |
+
from io import StringIO
|
8 |
+
import numpy as np
|
9 |
+
import pathlib
|
10 |
+
import os
|
11 |
+
import components.authenticate as authenticate
|
12 |
+
|
13 |
+
|
14 |
+
st.set_page_config(page_title="Auto Subtitled Video Generator", page_icon=":movie_camera:", layout="wide")
|
15 |
+
|
16 |
+
# Define a function that we can use to load lottie files from a link.
|
17 |
+
@st.cache(allow_output_mutation=True)
def load_lottieurl(url: str):
    """Download a Lottie animation and return its decoded JSON.

    Returns:
        The parsed JSON body, or ``None`` when the response status is not 200.
    """
    response = requests.get(url)
    return response.json() if response.status_code == 200 else None
|
23 |
+
|
24 |
+
|
25 |
+
APP_DIR = pathlib.Path(__file__).parent.absolute()
|
26 |
+
|
27 |
+
LOCAL_DIR = APP_DIR / "local_transcript"
|
28 |
+
LOCAL_DIR.mkdir(exist_ok=True)
|
29 |
+
save_dir = LOCAL_DIR / "output"
|
30 |
+
save_dir.mkdir(exist_ok=True)
|
31 |
+
|
32 |
+
|
33 |
+
col1, col2 = st.columns([1, 3])
|
34 |
+
with col1:
|
35 |
+
lottie = load_lottieurl("https://assets6.lottiefiles.com/packages/lf20_cjnxwrkt.json")
|
36 |
+
st_lottie(lottie)
|
37 |
+
|
38 |
+
with col2:
|
39 |
+
st.write("""
|
40 |
+
## Auto Subtitled Video Generator
|
41 |
+
##### ➠ Upload a video file and a transcript as .srt or .vtt file and get a video with subtitles.
|
42 |
+
##### ➠ Processing time will increase as the video length increases. """)
|
43 |
+
|
44 |
+
|
45 |
+
def getSubs(segments: Iterator[dict], format: str, maxLineWidth: int) -> str:
    """Serialize *segments* as subtitle text in the requested format.

    Args:
        segments: Transcript segments handed to the subtitle writer.
        format: ``'vtt'`` or ``'srt'``.
        maxLineWidth: Forwarded to the writer as ``maxLineWidth``.

    Returns:
        Everything the writer wrote, as one string.

    Raises:
        Exception: If *format* is neither ``'vtt'`` nor ``'srt'``.
    """
    if format not in ('vtt', 'srt'):
        raise Exception("Unknown format " + format)

    buffer = StringIO()
    if format == 'vtt':
        write_vtt(segments, file=buffer, maxLineWidth=maxLineWidth)
    else:
        write_srt(segments, file=buffer, maxLineWidth=maxLineWidth)
    return buffer.getvalue()
|
57 |
+
|
58 |
+
|
59 |
+
def split_video_audio(uploaded_file):
    """Save the upload as ``input.mp4`` and extract its audio to ``output.wav``.

    Both files are written inside ``save_dir``. The audio is encoded with
    ``pcm_s16le``, one channel, at a 16k sample rate, and an existing output
    file is overwritten.
    """
    video_path = f"{save_dir}/input.mp4"
    with open(video_path, "wb") as sink:
        sink.write(uploaded_file.read())
    extraction = ffmpeg.output(
        ffmpeg.input(video_path),
        f"{save_dir}/output.wav",
        acodec="pcm_s16le",
        ac=1,
        ar="16k",
    )
    ffmpeg.run(extraction, overwrite_output=True)
|
65 |
+
|
66 |
+
|
67 |
+
def _render_subtitled_video(uploaded_video, transcript_path, filename):
    """Burn the subtitles at *transcript_path* into the upload and show it."""
    with st.spinner("Generating Subtitled Video"):
        split_video_audio(uploaded_video)
        video_file = ffmpeg.input(f"{save_dir}/input.mp4")
        audio_file = ffmpeg.input(f"{save_dir}/output.wav")
        # ``-report`` is applied to both subtitle formats; previously only
        # the .vtt path passed it.
        (
            ffmpeg.concat(video_file.filter("subtitles", transcript_path), audio_file, v=1, a=1)
            .output("final.mp4")
            .global_args('-report')
            .run(quiet=True, overwrite_output=True)
        )
        # Read the result once and close the file instead of handing an
        # open handle to the widgets and never closing it.
        with open("final.mp4", "rb") as rendered:
            video_with_subs = rendered.read()
        col3, col4 = st.columns(2)
        with col3:
            st.video(uploaded_video)
        with col4:
            st.video(video_with_subs)
        st.download_button(label="Download Video with Subtitles",
                           data=video_with_subs,
                           file_name=f"{filename}_with_subs.mp4")


def main():
    """Page body: take a video and a transcript, produce a subtitled video.

    Shows the two uploaders. Once both files are present, the transcript is
    saved in the working directory as ``uploaded_transcript.<ext>`` and a
    button starts the rendering. The transcript extension is matched
    case-insensitively, and anything other than ``.srt`` or ``.vtt`` shows
    an error.
    """
    uploaded_video = st.file_uploader("Upload Video File", type=["mp4", "avi", "mov", "mkv"])
    transcript_file = st.file_uploader("Upload Transcript File", type=["srt", "vtt"])
    if uploaded_video is None or transcript_file is None:
        st.info("Please upload a video file and a transcript file")
        return

    # Base name of the upload, used to name the downloadable result.
    filename = os.path.splitext(uploaded_video.name)[0]
    extension = os.path.splitext(transcript_file.name)[1].lower().lstrip(".")
    if extension not in ("srt", "vtt"):
        st.error("Please upload a .srt or .vtt file")
        return

    transcript_path = f"uploaded_transcript.{extension}"
    with open(transcript_path, "wb") as f:
        # getvalue() returns the whole upload regardless of how far the
        # stream has already been read.
        f.write(transcript_file.getvalue())

    if st.button("Generate Video with Subtitles"):
        _render_subtitled_video(uploaded_video, transcript_path, filename)
|
127 |
+
|
128 |
+
|
129 |
+
if __name__ == "__main__":
    # Gate the page behind the login: anonymous visitors only get the prompt.
    authenticate.set_st_state_vars()
    if not st.session_state["authenticated"]:
        st.info("Please log in or sign up to use the app.")
        authenticate.button_login()
    else:
        main()
        authenticate.button_logout()
|
137 |
+
|
pages/{03_🔊_Upload_Audio_File.py → 04_🔊_Upload_Audio_File.py}
RENAMED
@@ -9,6 +9,8 @@ from io import StringIO
|
|
9 |
import numpy as np
|
10 |
import pathlib
|
11 |
import os
|
|
|
|
|
12 |
|
13 |
st.set_page_config(page_title="Auto Transcriber", page_icon="🔊", layout="wide")
|
14 |
|
@@ -32,7 +34,7 @@ save_dir.mkdir(exist_ok=True)
|
|
32 |
col1, col2 = st.columns([1, 3])
|
33 |
with col1:
|
34 |
lottie = load_lottieurl("https://assets1.lottiefiles.com/packages/lf20_1xbk4d2v.json")
|
35 |
-
st_lottie(lottie
|
36 |
|
37 |
with col2:
|
38 |
st.write("""
|
@@ -48,8 +50,10 @@ current_size = "None"
|
|
48 |
|
49 |
@st.cache(allow_output_mutation=True)
|
50 |
def change_model(current_size, size):
|
|
|
|
|
51 |
if current_size != size:
|
52 |
-
loaded_model = whisper.load_model(size)
|
53 |
return loaded_model
|
54 |
else:
|
55 |
raise Exception("Model size is the same as the current size.")
|
@@ -98,7 +102,7 @@ def main():
|
|
98 |
loaded_model = change_model(current_size, size)
|
99 |
st.write(f"Model is {'multilingual' if loaded_model.is_multilingual else 'English-only'} "
|
100 |
f"and has {sum(np.prod(p.shape) for p in loaded_model.parameters()):,} parameters.")
|
101 |
-
input_file = st.file_uploader("Upload
|
102 |
if input_file is not None:
|
103 |
filename = input_file.name[:-4]
|
104 |
else:
|
@@ -201,5 +205,10 @@ def main():
|
|
201 |
|
202 |
|
203 |
if __name__ == "__main__":
|
204 |
-
|
205 |
-
st.
|
|
|
|
|
|
|
|
|
|
|
|
9 |
import numpy as np
|
10 |
import pathlib
|
11 |
import os
|
12 |
+
import components.authenticate as authenticate
|
13 |
+
import torch
|
14 |
|
15 |
st.set_page_config(page_title="Auto Transcriber", page_icon="🔊", layout="wide")
|
16 |
|
|
|
34 |
col1, col2 = st.columns([1, 3])
|
35 |
with col1:
|
36 |
lottie = load_lottieurl("https://assets1.lottiefiles.com/packages/lf20_1xbk4d2v.json")
|
37 |
+
st_lottie(lottie)
|
38 |
|
39 |
with col2:
|
40 |
st.write("""
|
|
|
50 |
|
51 |
@st.cache(allow_output_mutation=True)
def change_model(current_size, size):
    """Load the Whisper model of the requested size.

    Args:
        current_size: Size name of the model considered currently loaded.
        size: Size name of the model to load.

    Returns:
        The model returned by ``whisper.load_model``, placed on ``"cuda"``
        when ``torch.cuda.is_available()`` is true and on ``"cpu"`` otherwise.

    Raises:
        Exception: If ``size`` equals ``current_size``.
    """
    if current_size == size:
        raise Exception("Model size is the same as the current size.")
    # Probe for a GPU only once we know a model will actually be loaded.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return whisper.load_model(size, device=device)
|
|
|
102 |
loaded_model = change_model(current_size, size)
|
103 |
st.write(f"Model is {'multilingual' if loaded_model.is_multilingual else 'English-only'} "
|
104 |
f"and has {sum(np.prod(p.shape) for p in loaded_model.parameters()):,} parameters.")
|
105 |
+
input_file = st.file_uploader("Upload Audio File", type=["mp3", "wav", "m4a"])
|
106 |
if input_file is not None:
|
107 |
filename = input_file.name[:-4]
|
108 |
else:
|
|
|
205 |
|
206 |
|
207 |
if __name__ == "__main__":
    # Gate the page behind the login: anonymous visitors only get the prompt.
    authenticate.set_st_state_vars()
    if not st.session_state["authenticated"]:
        st.info("Please log in or sign up to use the app.")
        authenticate.button_login()
    else:
        main()
        authenticate.button_logout()
|