tomiwa1a committed
Commit b80a146
1 Parent(s): efe5d70

add summarizer to handler

Files changed (1)
handler.py  +28 -1
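For context, the summarizer this commit wires into the handler is a plain transformers summarization pipeline. Below is a minimal sketch of what that pipeline returns on its own, assuming the same model name the commit adds to handler.py; the sample text and the CPU/GPU index handling are illustrative and not part of the commit.

from transformers import pipeline
import torch

# Same model name the commit adds to handler.py; device index -1 = CPU, 0 = first GPU.
device_number = 0 if torch.cuda.is_available() else -1
summarizer = pipeline("summarization",
                      model="philschmid/bart-large-cnn-samsum",
                      device=device_number)

text = (
    "In this lecture we introduce gradient descent. Starting from an initial guess, "
    "we compute the gradient of the loss and repeatedly step in the opposite "
    "direction until the updates become small."
)

# The pipeline returns a list with one dict per input; the handler below
# reads the 'summary_text' field of the first entry.
result = summarizer(text)
print(result[0]["summary_text"])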
handler.py CHANGED
@@ -1,9 +1,12 @@
+"""
+https://huggingface.co/tomiwa1a/video-search
+"""
 from typing import Dict
 
 from sentence_transformers import SentenceTransformer
 from tqdm import tqdm
 import whisper
-from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
 import torch
 import pytube
 import time
@@ -14,7 +17,9 @@ class EndpointHandler():
     WHISPER_MODEL_NAME = "tiny.en"
     SENTENCE_TRANSFORMER_MODEL_NAME = "multi-qa-mpnet-base-dot-v1"
     QUESTION_ANSWER_MODEL_NAME = "vblagoje/bart_lfqa"
+    SUMMARIZER_MODEL_NAME = "philschmid/bart-large-cnn-samsum"
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    device_number = 0 if torch.cuda.is_available() else -1
 
     def __init__(self, path=""):
 
@@ -34,6 +39,13 @@ class EndpointHandler():
 
         total = t1 - t0
         print(f'Finished loading sentence_transformer_model in {total} seconds')
+
+        t0 = time.time()
+        self.summarizer = pipeline("summarization", model=self.SUMMARIZER_MODEL_NAME, device=device)
+        t1 = time.time()
+
+        total = t1 - t0
+        print(f'Finished loading summarizer in {total} seconds')
 
         self.question_answer_tokenizer = AutoTokenizer.from_pretrained(self.QUESTION_ANSWER_MODEL_NAME)
         t0 = time.time()
@@ -59,6 +71,7 @@
         video_url = data.pop("video_url", None)
         query = data.pop("query", None)
         long_form_answer = data.pop("long_form_answer", None)
+        summarize = data.pop("summarize", False)
         encoded_segments = {}
         if video_url:
             video_with_transcript = self.transcribe_video(video_url)
@@ -73,6 +86,9 @@
                 **video_with_transcript,
                 **encoded_segments
             }
+        elif summarize:
+            summary = self.summarize_video(data["segments"])
+            return {"summary": summary}
         elif query:
             if long_form_answer:
                 context = data.pop("context", None)
@@ -167,6 +183,17 @@
 
         return all_batches
 
+    def summarize_video(self, segments):
+        for index, segment in enumerate(segments):
+            segment['summary'] = self.summarizer(segment['text'])
+            segment['summary'] = segment['summary'][0]['summary_text']
+            print('index', index)
+            print('length', segment['length'])
+            print('text', segment['text'])
+            print('summary', segment['summary'])
+
+        return segments
+
     def generate_answer(self, query, documents):
 
         # concatenate question and support documents into BART input
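A hedged sketch, not part of the repo, of exercising the new summarize branch locally. The payload shape, a "summarize" flag plus a "segments" list whose items carry "text" and "length" keys, is inferred from the diff above and may need adjusting to the real endpoint contract.

from handler import EndpointHandler

# Loads Whisper, the sentence transformer, the LFQA model and the new summarizer,
# so construction takes a while on first run.
handler = EndpointHandler()

payload = {
    "summarize": True,
    "segments": [
        {"text": "First chunk of the video transcript...", "length": 120},
        {"text": "Second chunk of the video transcript...", "length": 95},
    ],
}

# The summarize branch returns {"summary": segments}, where each segment now
# also carries a 'summary' field produced by the summarization pipeline.
response = handler(payload)
print(response["summary"][0]["summary"])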