"""Streamlit app: transcribe a YouTube video with Whisper and search it semantically.

The user submits a YouTube link plus model settings. The audio track is
downloaded and transcribed with Whisper. The transcript segments are grouped
into windows of consecutive segments and embedded with a SentenceTransformer
model. Free-text queries are then matched against those windows. Each hit
links back to the video at its window's start time.
"""
import tempfile

import nltk
import streamlit as st
import whisper
from pytube import YouTube
from sentence_transformers import SentenceTransformer, util

# NOTE(review): nothing below uses NLTK's punkt tokenizer; confirm it is still
# needed before removing it (this call runs on every Streamlit rerun).
nltk.download('punkt')


@st.experimental_singleton
def init_sentence_model(embedding_model):
    """Load (once, via Streamlit's singleton cache) the SentenceTransformer *embedding_model*."""
    return SentenceTransformer(embedding_model)


@st.experimental_singleton
def init_whisper(whisper_size):
    """Load (once, via Streamlit's singleton cache) the Whisper model of size *whisper_size*."""
    return whisper.load_model(whisper_size)


@st.experimental_memo
def inference(link, whisper_size="small"):
    """Download the audio of YouTube video *link* and transcribe it with Whisper.

    *whisper_size* is a parameter (not a global model) so that it is part of
    the memo cache key: picking another model size re-transcribes instead of
    returning a transcript cached for a different model.

    Returns Whisper's ``segments`` list (dicts read below via ``text`` and
    ``start``). Raises ValueError if the video has no audio-only stream.
    """
    whisper_model = init_whisper(whisper_size)
    stream = YouTube(link).streams.filter(only_audio=True).first()
    if stream is None:
        raise ValueError(f"No audio-only stream found for {link}")
    # A private temporary directory keeps concurrent sessions from overwriting
    # a shared ./audio.mp4 and removes the file once transcription is done.
    with tempfile.TemporaryDirectory() as tmp_dir:
        path = stream.download(output_path=tmp_dir, filename="audio.mp4")
        results = whisper_model.transcribe(path)
    return results['segments']


@st.experimental_memo
def get_embeddings(segments, embedding_model='all-mpnet-base-v2'):
    """Embed each window text in *segments* (the dict built by ``format_segments``).

    *embedding_model* is part of the memo cache key. Without it, switching
    models would return vectors from the previous model, possibly of a
    different dimension than the query embedding.
    """
    return init_sentence_model(embedding_model).encode(segments['text'])


def format_segments(segments, window=10):
    """Group consecutive Whisper segments into non-overlapping windows of *window* segments.

    Returns a dict of parallel lists:
    ``text``  -- the stripped segment texts of each window joined by spaces;
    ``start`` -- the ``start`` value of each window's first segment.

    Raises ValueError if *window* is smaller than 1.
    """
    if window < 1:
        raise ValueError(f"window must be a positive integer, got {window}")
    window_starts = range(0, len(segments), window)
    return {
        # Slice by `window` (previously a hard-coded 5, which dropped text).
        'text': [
            " ".join(seg['text'].strip() for seg in segments[i:i + window])
            for i in window_starts
        ],
        'start': [segments[i]['start'] for i in window_starts],
    }


def timestamp_url(link, seconds):
    """Return *link* with a ``t=<seconds>s`` query parameter (whole seconds).

    Uses ``&`` when *link* already has a query string (``watch?v=...``) and
    ``?`` otherwise (e.g. ``youtu.be/<id>`` short links).
    """
    separator = "&" if "?" in link else "?"
    return f"{link}{separator}t={int(seconds)}s"


with st.form("transcribe"):
    yt_link = st.text_input("Youtube link")
    whisper_size = st.selectbox("Whisper model size", ("small", "base", "large"))
    embedding_model = st.text_input("Embedding model name", value='all-mpnet-base-v2')
    # min_value=1 keeps both counts positive integers (window=0 would crash range()).
    top_k = st.number_input("Number of query results", value=5, min_value=1, step=1)
    window = st.number_input("Number of segments per result", value=10, min_value=1, step=1)
    transcribe_submit = st.form_submit_button("Submit")

# Remember the first submission so the search UI stays visible on the reruns
# triggered by typing a query outside the form.
if transcribe_submit and 'start_search' not in st.session_state:
    st.session_state.start_search = True

if 'start_search' in st.session_state:
    if not yt_link:
        st.warning("Please enter a YouTube link.")
        st.stop()

    model = init_sentence_model(embedding_model)
    segments = format_segments(inference(yt_link, whisper_size), window)
    embeddings = get_embeddings(segments, embedding_model)

    query = st.text_input('Enter a query')
    if query:
        query_embedding = model.encode(query)
        hits = util.semantic_search(query_embedding, embeddings, top_k=top_k)[0]
        # Plain Markdown links need no raw HTML, so unsafe_allow_html is not used.
        st.markdown("\n\n".join(
            f"{segments['text'][hit['corpus_id']]}... "
            f"[Watch at timestamp]({timestamp_url(yt_link, segments['start'][hit['corpus_id']])})"
            for hit in hits
        ))