issam9 commited on
Commit
358c829
1 Parent(s): 331a526

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -0
app.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import nltk
2
+ import whisper
3
+ from pytube import YouTube
4
+ import streamlit as st
5
+ from sentence_transformers import SentenceTransformer, util
6
+
7
+ nltk.download('punkt')
8
+
9
+
10
+ @st.experimental_singleton
11
+ def init_sentence_model(embedding_model):
12
+ return SentenceTransformer(embedding_model)
13
+
14
+ @st.experimental_singleton
15
+ def init_whisper(whisper_size):
16
+ return whisper.load_model(whisper_size)
17
+
18
+ @st.experimental_memo
19
+ def inference(link):
20
+ yt = YouTube(link)
21
+ path = yt.streams.filter(only_audio=True)[0].download(filename="audio.mp4")
22
+ options = whisper.DecodingOptions(without_timestamps=True)
23
+ results = whisper_model.transcribe(path)
24
+ return results['segments']
25
+
26
+ @st.experimental_memo
27
+ def get_embeddings(segments):
28
+ return model.encode(segments["text"])
29
+
30
+ def format_segments(segments, window=10):
31
+ new_segments = dict()
32
+ new_segments['text'] = [" ".join([seg['text'] for seg in segments[i:i+5]]) for i in range(0, len(segments), window)]
33
+ new_segments['start'] = [segments[i]['start'] for i in range(0, len(segments), window)]
34
+
35
+ return new_segments
36
+
37
+ with st.form("transcribe"):
38
+ yt_link = st.text_input("Youtube link")
39
+ whisper_size = st.selectbox("Whisper model size", ("small", "base", "large"))
40
+ embedding_model = st.text_input("Embedding model name", value='all-mpnet-base-v2')
41
+ top_k = st.number_input("Number of query results", value=5)
42
+ window = st.number_input("Number of segments per result", value=10)
43
+
44
+ transcribe_submit = st.form_submit_button("Submit")
45
+
46
+ if transcribe_submit and 'start_search' not in st.session_state:
47
+ st.session_state.start_search = True
48
+
49
+ if 'start_search' in st.session_state:
50
+ model = init_sentence_model(embedding_model)
51
+
52
+ whisper_model = init_whisper(whisper_size)
53
+
54
+ segments = inference(yt_link)
55
+
56
+ segments = format_segments(segments, window)
57
+
58
+ embeddings = get_embeddings(segments)
59
+
60
+ query = st.text_input('Enter a query')
61
+
62
+ if query:
63
+ query_embedding = model.encode(query)
64
+ results = util.semantic_search(query_embedding, embeddings, top_k=top_k)
65
+ st.markdown("\n\n".join([segments['text'][result['corpus_id']]+f"... [Watch at timestamp]({yt_link}&t={segments['start'][result['corpus_id']]}s)" for result in results[0]]), unsafe_allow_html=True)