File size: 2,208 Bytes
358c829
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import nltk
import whisper
from pytube import YouTube
import streamlit as st
from sentence_transformers import SentenceTransformer, util

# Download the 'punkt' tokenizer data at import time (no-op if cached).
# NOTE(review): no nltk tokenization is visible in this file — confirm
# 'punkt' is actually required (possibly by a downstream dependency).
nltk.download('punkt')


@st.experimental_singleton
def init_sentence_model(embedding_model):
    """Load and cache a SentenceTransformer for the given model name.

    Cached as a process-wide singleton so the (expensive) model load
    happens once per Streamlit server, not once per rerun.
    """
    transformer = SentenceTransformer(embedding_model)
    return transformer

@st.experimental_singleton
def init_whisper(whisper_size):
  """Load and cache a Whisper model of the requested size.

  Singleton-cached so each model size is loaded at most once per server.
  """
  loaded = whisper.load_model(whisper_size)
  return loaded

@st.experimental_memo
def inference(link):
  """Download the audio track of a YouTube video and transcribe it.

  Args:
    link: YouTube video URL.

  Returns:
    Whisper transcription segments — a list of dicts with at least
    'text' and 'start' keys.

  Note:
    Uses the module-level ``whisper_model``. Because the memo cache is
    keyed only on ``link``, changing the Whisper model size will NOT
    re-transcribe an already-cached link.
  """
  yt = YouTube(link)
  # Grab the first audio-only stream and save it locally for Whisper.
  path = yt.streams.filter(only_audio=True)[0].download(filename="audio.mp4")
  # Fix: removed a DecodingOptions object that was constructed but never
  # passed to transcribe() — dead code that misleadingly suggested the
  # transcription ran without timestamps.
  results = whisper_model.transcribe(path)
  return results['segments']

@st.experimental_memo
def get_embeddings(segments):
  """Encode each windowed transcript text into a dense embedding.

  Uses the module-level SentenceTransformer ``model``; memoized so the
  embeddings are computed once per distinct ``segments`` value.
  """
  texts = segments["text"]
  return model.encode(texts)

def format_segments(segments, window=10):
  """Group raw transcript segments into fixed-size windows.

  Args:
    segments: sequence of dicts, each with 'text' and 'start' keys.
    window: number of consecutive segments merged into one result entry.

  Returns:
    dict with two parallel lists:
      'text':  space-joined text of each window of ``window`` segments;
      'start': start time of the first segment in each window.
  """
  # Bug fix: the join previously sliced a hard-coded 5 segments
  # (segments[i:i+5]) while stepping by ``window`` (default 10),
  # silently dropping the text of every segment past the 5th per window.
  window_starts = range(0, len(segments), window)
  grouped = dict()
  grouped['text'] = [
      " ".join(seg['text'] for seg in segments[i:i + window])
      for i in window_starts
  ]
  grouped['start'] = [segments[i]['start'] for i in window_starts]
  return grouped

# --- UI: transcription settings form --------------------------------------
with st.form("transcribe"):
  yt_link = st.text_input("Youtube link")
  whisper_size = st.selectbox("Whisper model size", ("small", "base", "large"))
  embedding_model = st.text_input("Embedding model name", value='all-mpnet-base-v2')
  top_k = st.number_input("Number of query results", value=5)
  window = st.number_input("Number of segments per result", value=10)

  transcribe_submit = st.form_submit_button("Submit")

# Latch the one-shot submit into session state so the search UI below
# survives the script reruns triggered by the query text input.
if transcribe_submit and 'start_search' not in st.session_state:
  st.session_state.start_search = True

if 'start_search' in st.session_state:
  # Load (singleton-cached) models, then transcribe, window, and embed.
  model = init_sentence_model(embedding_model)

  whisper_model = init_whisper(whisper_size)

  segments = inference(yt_link)

  segments = format_segments(segments, window)

  embeddings = get_embeddings(segments)

  query = st.text_input('Enter a query')

  if query:
    # Embed the query and run similarity search over the windowed
    # transcript embeddings; show each hit with a timestamped link.
    query_embedding = model.encode(query)
    results = util.semantic_search(query_embedding, embeddings, top_k=top_k)
    # NOTE(review): appending "&t=...s" assumes yt_link already contains a
    # query string (e.g. ".../watch?v=..."); a youtu.be short link would
    # need "?t=" instead — confirm expected input format.
    st.markdown("\n\n".join([segments['text'][result['corpus_id']]+f"... [Watch at timestamp]({yt_link}&t={segments['start'][result['corpus_id']]}s)" for result in results[0]]), unsafe_allow_html=True)