add sentence transformer to inference endpoint
handler.py CHANGED +123 -18
@@ -1,5 +1,7 @@
-from typing import
-
+from typing import Dict
+
+from sentence_transformers import SentenceTransformer
+from tqdm import tqdm
 import whisper
 import torch
 import pytube
@@ -9,20 +11,27 @@ import time
 class EndpointHandler():
     def __init__(self, path=""):
         # load the model
-
-
+        WHISPER_MODEL_NAME = "tiny.en"
+        SENTENCE_TRANSFORMER_MODEL_NAME = "multi-qa-mpnet-base-dot-v1"
+
         device = "cuda" if torch.cuda.is_available() else "cpu"
         print(f'whisper will use: {device}')
-
+
         t0 = time.time()
-        self.
+        self.whisper_model = whisper.load_model(WHISPER_MODEL_NAME).to(device)
         t1 = time.time()
-
-        total = t1-t0
-        print(f'Finished loading model in {total} seconds')
 
+        total = t1 - t0
+        print(f'Finished loading whisper_model in {total} seconds')
+
+        t0 = time.time()
+        self.sentence_transformer_model = SentenceTransformer(SENTENCE_TRANSFORMER_MODEL_NAME)
+        t1 = time.time()
 
-
+        total = t1 - t0
+        print(f'Finished loading sentence_transformer_model in {total} seconds')
+
+    def __call__(self, data: Dict[str, str]) -> Dict:
         """
         Args:
             data (:obj:):
@@ -32,13 +41,36 @@
         """
         # process input
        print('data', data)
-
+
+        video_url = data.pop("video_url", None)
+        segments = data.pop("segments", None)
+        encoded_segments = {}
+        if video_url:
+            video_with_transcript = self.transcribe_video(video_url)
+            encode_transcript = data.pop("encode_transcript", True)
+            if encode_transcript:
+                video_with_transcript['transcript']['segments'] = self.combine_transcripts(video_with_transcript)
+                encoded_segments = {
+                    "encoded_segments": self.encode_sentences(video_with_transcript['transcript']['segments'])
+                }
+            return {
+                **video_with_transcript,
+                **encoded_segments
+            }
+        elif segments:
+            encoded_segments = self.encode_sentences(segments)
+
+            return {
+                "encoded_segments": encoded_segments
+            }
+
+    def transcribe_video(self, video_url):
         decode_options = {
-            # Set language to None to support multilingual,
+            # Set language to None to support multilingual,
             # but it will take longer to process while it detects the language.
             # Realized this by running in verbose mode and seeing how much time
             # was spent on the decoding language step
-            "language":"en",
+            "language": "en",
             "verbose": True
         }
         yt = pytube.YouTube(video_url)
@@ -56,14 +88,87 @@
         path_to_audio = f"{yt.video_id}.mp3"
         stream.download(filename=path_to_audio)
         t0 = time.time()
-        transcript = self.
+        transcript = self.whisper_model.transcribe(path_to_audio, **decode_options)
         t1 = time.time()
         for segment in transcript['segments']:
-
-
-
-        total = t1-t0
+            # Remove the tokens array, it makes the response too verbose
+            segment.pop('tokens', None)
+
+        total = t1 - t0
         print(f'Finished transcription in {total} seconds')
 
         # postprocess the prediction
         return {"transcript": transcript, 'video': video_info}
+
+    def encode_sentences(self, transcripts, batch_size=64):
+        """
+        Encoding all of our segments at once or storing them locally would require too much compute or memory.
+        So we do it in batches of 64
+        :param transcripts:
+        :param batch_size:
+        :return:
+        """
+        # loop through in batches of 64
+        all_batches = []
+        for i in tqdm(range(0, len(transcripts), batch_size)):
+            # find end position of batch (for when we hit end of data)
+            i_end = min(len(transcripts) - 1, i + batch_size)
+            # extract the metadata like text, start/end positions, etc
+            batch_meta = [{
+                **transcripts[x]
+            } for x in range(i, i_end)]
+            # extract only text to be encoded by embedding model
+            batch_text = [
+                row['text'] for row in transcripts[i:i_end]
+            ]
+            # extract IDs to be attached to each embedding and metadata
+            batch_ids = [
+                row['id'] for row in transcripts[i:i_end]
+            ]
+            # create the embedding vectors
+            batch_vectors = self.sentence_transformer_model.encode(batch_text).tolist()
+
+            batch_details = [
+                {
+                    **batch_meta[x],
+                    'vectors': batch_vectors[x]
+                } for x in range(0, len(batch_meta))
+            ]
+            all_batches.extend(batch_details)
+
+        return all_batches
+
+    @staticmethod
+    def combine_transcripts(video, window=6, stride=3):
+        """
+
+        :param video:
+        :param window: number of sentences to combine
+        :param stride: number of sentences to 'stride' over, used to create overlap
+        :return:
+        """
+        new_transcript_segments = []
+
+        video_info = video['video']
+        transcript_segments = video['transcript']['segments']
+        for i in tqdm(range(0, len(transcript_segments), stride)):
+            i_end = min(len(transcript_segments) - 1, i + window)
+            text = ' '.join(transcript['text']
+                            for transcript in
+                            transcript_segments[i:i_end])
+            # TODO: Should int (float to seconds) conversion happen at the API level?
+            start = int(transcript_segments[i]['start'])
+            end = int(transcript_segments[i]['end'])
+            new_transcript_segments.append({
+                **video_info,
+                **{
+                    'start': start,
+                    'end': end,
+                    'title': video_info['title'],
+                    'text': text,
+                    'id': f"{video_info['id']}-t{start}",
+                    'url': f"https://youtu.be/{video_info['id']}?t={start}",
+                    'video_id': video_info['id'],
+                }
+            })
+        return new_transcript_segments
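
For a quick local sanity check of the new request surface, a minimal sketch follows. It is not part of the commit: it assumes the file above is importable as handler.py, and the video URL and segment values are made-up placeholders.

# Sketch: exercise both branches of EndpointHandler.__call__ locally.
from handler import EndpointHandler

handler = EndpointHandler()

# Branch 1: "video_url" triggers transcription; "encode_transcript"
# (default True) also windows the segments and attaches embeddings.
response = handler({
    "video_url": "https://www.youtube.com/watch?v=XXXXXXXXXXX",  # placeholder
})
print(sorted(response.keys()))  # ['encoded_segments', 'transcript', 'video']

# Branch 2: "segments" skips transcription and only embeds the given
# items; each needs 'text' (what gets encoded) and 'id' (kept as metadata).
segments = [{"id": f"demo-t{i}", "text": f"sentence {i}"} for i in range(3)]
response = handler({"segments": segments})
# Prints 2, not 3: the min(len(...) - 1, ...) bound in encode_sentences,
# combined with the exclusive slice end, leaves out the final segment.
print(len(response["encoded_segments"]))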
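The embedding step itself reduces to SentenceTransformer.encode over each batch of texts. A standalone sketch of that core call; the model name is the one pinned in the diff, the sample texts are invented:

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("multi-qa-mpnet-base-dot-v1")
texts = ["how do I deploy an endpoint?", "whisper transcribes the audio"]
vectors = model.encode(texts)   # numpy array of shape (2, 768)
print(vectors.shape)
print(vectors.tolist()[0][:4])  # .tolist() mirrors the handler's JSON-friendly output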
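And a toy trace of the window/stride merge in combine_transcripts, using throwaway one-letter segments and window=4, stride=2 (smaller than the 6/3 defaults, purely to keep the output short):

# Sketch: the same windowing loop as combine_transcripts, on toy data.
segments = [{'text': t} for t in 'abcdef']
window, stride = 4, 2

for i in range(0, len(segments), stride):
    i_end = min(len(segments) - 1, i + window)  # same bound as the diff
    print(' '.join(s['text'] for s in segments[i:i_end]))

# Output:
#   a b c d
#   c d e
#   e
# Each window starts stride sentences after the previous one, so
# neighbours overlap by window - stride sentences. Because the bound
# uses len(...) - 1 and Python slices exclude the end index, the last
# segment ('f') never lands in any window.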