import nltk
import whisper
from pytube import YouTube
import streamlit as st
from sentence_transformers import SentenceTransformer, util
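
# Download NLTK's punkt tokenizer data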
nltk.download('punkt')
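

# Cache the sentence-embedding model across Streamlit reruns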
@st.experimental_singleton
def init_sentence_model(embedding_model):
    return SentenceTransformer(embedding_model)
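

# Cache the Whisper speech-to-text model across Streamlit reruns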
@st.experimental_singleton
def init_whisper(whisper_size):
    return whisper.load_model(whisper_size)
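

# Cache the transcription so the same link is only transcribed once per session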
@st.experimental_memo
def inference(link):
    # Download the audio-only stream of the video and transcribe it with Whisper
    yt = YouTube(link)
    path = yt.streams.filter(only_audio=True)[0].download(filename="audio.mp4")
    results = whisper_model.transcribe(path)
    return results['segments']
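

# Cache the embeddings of the transcript chunks, encoded with the sentence-transformers model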
@st.experimental_memo
def get_embeddings(segments):
    return model.encode(segments["text"])
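

# Merge every `window` consecutive Whisper segments into one searchable chunk,
# keeping the start timestamp of the first segment in each chunk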
def format_segments(segments, window=10):
    new_segments = dict()
    new_segments['text'] = [" ".join([seg['text'] for seg in segments[i:i+window]]) for i in range(0, len(segments), window)]
    new_segments['start'] = [segments[i]['start'] for i in range(0, len(segments), window)]
    return new_segments


st.markdown("""
# YouTube video transcription and search
You can run it on a Colab GPU for faster performance: [Link](https://colab.research.google.com/drive/1-6Lmvlfwxd5JAnKOBKtdR1YiooIm-rJf?usp=sharing)
""")
with st.form("transcribe"):
    yt_link = st.text_input("YouTube link")
    whisper_size = st.selectbox("Whisper model size", ("small", "base", "large"))
    embedding_model = st.text_input("Embedding model name", value='all-mpnet-base-v2')
    top_k = st.number_input("Number of query results", value=5)
    window = st.number_input("Number of segments per result", value=10)
    transcribe_submit = st.form_submit_button("Submit")

if transcribe_submit and 'start_search' not in st.session_state:
    st.session_state.start_search = True
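
# Once a link has been submitted, load the models, transcribe the video, and index the transcript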
if 'start_search' in st.session_state:
    model = init_sentence_model(embedding_model)
    whisper_model = init_whisper(whisper_size)
    segments = inference(yt_link)
    segments = format_segments(segments, window)
    embeddings = get_embeddings(segments)
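
    # Embed the query and show the most similar transcript chunks with timestamp links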
    query = st.text_input('Enter a query')
    if query:
        query_embedding = model.encode(query)
        results = util.semantic_search(query_embedding, embeddings, top_k=top_k)
        st.markdown("\n\n".join([segments['text'][result['corpus_id']] + f"... [Watch at timestamp]({yt_link}&t={int(segments['start'][result['corpus_id']])}s)" for result in results[0]]), unsafe_allow_html=True)