tomiwa1a committed
Commit 4678908
1 Parent(s): 847a8ee

add sentence transformer to inference endpoint

Files changed (1):
  1. handler.py +123 -18
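
Before the diff, a quick orientation: the updated __call__ no longer pops a single "inputs" field; it dispatches on keys in the request body. A minimal sketch of the two request shapes it handles, assuming the Inference Endpoint hands the deserialized JSON body to __call__ as data (the request wrapping is an assumption; the field names come from the diff below):

# Sketch only, not part of the commit.
# 1) Transcribe a YouTube video and (by default) embed its transcript:
payload_transcribe = {
    "video_url": "https://www.youtube.com/watch?v=<video_id>",  # placeholder URL
    "encode_transcript": True,  # optional; the handler defaults this to True
}

# 2) Skip transcription and embed already-prepared segments directly:
payload_encode_only = {
    "segments": [
        {"id": "<video_id>-t0", "text": "example segment text"},  # hypothetical segment
    ],
}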
handler.py CHANGED
@@ -1,5 +1,7 @@
-from typing import Dict
-from transformers.pipelines.audio_utils import ffmpeg_read
+from typing import Dict
+
+from sentence_transformers import SentenceTransformer
+from tqdm import tqdm
 import whisper
 import torch
 import pytube
@@ -9,20 +11,27 @@ import time
 class EndpointHandler():
     def __init__(self, path=""):
         # load the model
-        MODEL_NAME = "tiny.en"
-
+        WHISPER_MODEL_NAME = "tiny.en"
+        SENTENCE_TRANSFORMER_MODEL_NAME = "multi-qa-mpnet-base-dot-v1"
+
         device = "cuda" if torch.cuda.is_available() else "cpu"
         print(f'whisper will use: {device}')
-
+
         t0 = time.time()
-        self.model = whisper.load_model(MODEL_NAME).to(device)
+        self.whisper_model = whisper.load_model(WHISPER_MODEL_NAME).to(device)
         t1 = time.time()
-
-        total = t1-t0
-        print(f'Finished loading model in {total} seconds')
 
+        total = t1 - t0
+        print(f'Finished loading whisper_model in {total} seconds')
+
+        t0 = time.time()
+        self.sentence_transformer_model = SentenceTransformer(SENTENCE_TRANSFORMER_MODEL_NAME)
+        t1 = time.time()
 
+        total = t1 - t0
+        print(f'Finished loading sentence_transformer_model in {total} seconds')
+
-    def __call__(self, data: Dict[str, bytes]) -> Dict[str, str]:
+    def __call__(self, data: Dict[str, str]) -> Dict:
         """
         Args:
             data (:obj:):
@@ -32,13 +41,36 @@ class EndpointHandler():
         """
         # process input
         print('data', data)
-        video_url = data.pop("inputs", data)
+
+        video_url = data.pop("video_url", None)
+        segments = data.pop("segments", None)
+        encoded_segments = {}
+        if video_url:
+            video_with_transcript = self.transcribe_video(video_url)
+            encode_transcript = data.pop("encode_transcript", True)
+            if encode_transcript:
+                video_with_transcript['transcript']['segments'] = self.combine_transcripts(video_with_transcript)
+                encoded_segments = {
+                    "encoded_segments": self.encode_sentences(video_with_transcript['transcript']['segments'])
+                }
+            return {
+                **video_with_transcript,
+                **encoded_segments
+            }
+        elif segments:
+            encoded_segments = self.encode_sentences(segments)
+
+            return {
+                "encoded_segments": encoded_segments
+            }
+
+    def transcribe_video(self, video_url):
         decode_options = {
-            # Set language to None to support multilingual,
+            # Set language to None to support multilingual,
             # but it will take longer to process while it detects the language.
             # Realized this by running in verbose mode and seeing how much time
            # was spent on the decoding language step
-            "language":"en",
+            "language": "en",
             "verbose": True
         }
         yt = pytube.YouTube(video_url)
@@ -56,14 +88,87 @@ class EndpointHandler():
         path_to_audio = f"{yt.video_id}.mp3"
         stream.download(filename=path_to_audio)
         t0 = time.time()
-        transcript = self.model.transcribe(path_to_audio, **decode_options)
+        transcript = self.whisper_model.transcribe(path_to_audio, **decode_options)
         t1 = time.time()
         for segment in transcript['segments']:
-            # Remove the tokens array, it makes the response too verbose
-            segment.pop('tokens', None)
-
-        total = t1-t0
+            # Remove the tokens array, it makes the response too verbose
+            segment.pop('tokens', None)
+
+        total = t1 - t0
         print(f'Finished transcription in {total} seconds')
 
         # postprocess the prediction
         return {"transcript": transcript, 'video': video_info}
+
+    def encode_sentences(self, transcripts, batch_size=64):
+        """
+        Encoding all of our segments at once or storing them locally would require too much compute or memory.
+        So we do it in batches of 64
+        :param transcripts:
+        :param batch_size:
+        :return:
+        """
+        # loop through in batches of 64
+        all_batches = []
+        for i in tqdm(range(0, len(transcripts), batch_size)):
+            # find end position of batch (for when we hit end of data)
+            i_end = min(len(transcripts) - 1, i + batch_size)
+            # extract the metadata like text, start/end positions, etc
+            batch_meta = [{
+                **transcripts[x]
+            } for x in range(i, i_end)]
+            # extract only text to be encoded by embedding model
+            batch_text = [
+                row['text'] for row in transcripts[i:i_end]
+            ]
+            # extract IDs to be attached to each embedding and metadata
+            batch_ids = [
+                row['id'] for row in transcripts[i:i_end]
+            ]
+            # create the embedding vectors
+            batch_vectors = self.sentence_transformer_model.encode(batch_text).tolist()
+
+            batch_details = [
+                {
+                    **batch_meta[x],
+                    'vectors':batch_vectors[x]
+                } for x in range(0, len(batch_meta))
+            ]
+            all_batches.extend(batch_details)
+
+        return all_batches
+
+    @staticmethod
+    def combine_transcripts(video, window=6, stride=3):
+        """
+
+        :param video:
+        :param window: number of sentences to combine
+        :param stride: number of sentences to 'stride' over, used to create overlap
+        :return:
+        """
+        new_transcript_segments = []
+
+        video_info = video['video']
+        transcript_segments = video['transcript']['segments']
+        for i in tqdm(range(0, len(transcript_segments), stride)):
+            i_end = min(len(transcript_segments) - 1, i + window)
+            text = ' '.join(transcript['text']
+                            for transcript in
+                            transcript_segments[i:i_end])
+            # TODO: Should int (float to seconds) conversion happen at the API level?
+            start = int(transcript_segments[i]['start'])
+            end = int(transcript_segments[i]['end'])
+            new_transcript_segments.append({
+                **video_info,
+                **{
+                    'start': start,
+                    'end': end,
+                    'title': video_info['title'],
+                    'text': text,
+                    'id': f"{video_info['id']}-t{start}",
+                    'url': f"https://youtu.be/{video_info['id']}?t={start}",
+                    'video_id': video_info['id'],
+                }
+            })
+        return new_transcript_segments
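
A minimal local smoke test of the new code paths (a sketch, not part of the commit; it assumes handler.py is importable from the working directory and that pytube and the model downloads have network access):

from handler import EndpointHandler

handler = EndpointHandler()

# Path 1: download audio, transcribe with Whisper, merge segments into
# overlapping windows via combine_transcripts, then embed each window.
result = handler({
    "video_url": "https://www.youtube.com/watch?v=<video_id>",  # placeholder URL
})
print(result["video"])                  # video metadata returned by transcribe_video
print(len(result["encoded_segments"]))  # embedding vectors for the combined segments

# Path 2: embed pre-existing segments without transcribing anything.
segments = [
    {"id": "<video_id>-t0", "text": "first example segment"},   # hypothetical data
    {"id": "<video_id>-t3", "text": "second example segment"},
]
print(handler({"segments": segments})["encoded_segments"])

With the defaults window=6 and stride=3, consecutive windows overlap by three segments, so most transcript text lands in roughly two embedded chunks; encode_sentences then pushes the windows through multi-qa-mpnet-base-dot-v1 in batches of 64.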