add summarizer to handler
Browse files — handler.py: +28 −1
handler.py
CHANGED
@@ -1,9 +1,12 @@
|
|
|
|
|
|
|
|
1 |
from typing import Dict
|
2 |
|
3 |
from sentence_transformers import SentenceTransformer
|
4 |
from tqdm import tqdm
|
5 |
import whisper
|
6 |
-
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
7 |
import torch
|
8 |
import pytube
|
9 |
import time
|
@@ -14,7 +17,9 @@ class EndpointHandler():
|
|
14 |
WHISPER_MODEL_NAME = "tiny.en"
|
15 |
SENTENCE_TRANSFORMER_MODEL_NAME = "multi-qa-mpnet-base-dot-v1"
|
16 |
QUESTION_ANSWER_MODEL_NAME = "vblagoje/bart_lfqa"
|
|
|
17 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
18 |
|
19 |
def __init__(self, path=""):
|
20 |
|
@@ -34,6 +39,13 @@ class EndpointHandler():
|
|
34 |
|
35 |
total = t1 - t0
|
36 |
print(f'Finished loading sentence_transformer_model in {total} seconds')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
|
38 |
self.question_answer_tokenizer = AutoTokenizer.from_pretrained(self.QUESTION_ANSWER_MODEL_NAME)
|
39 |
t0 = time.time()
|
@@ -59,6 +71,7 @@ class EndpointHandler():
|
|
59 |
video_url = data.pop("video_url", None)
|
60 |
query = data.pop("query", None)
|
61 |
long_form_answer = data.pop("long_form_answer", None)
|
|
|
62 |
encoded_segments = {}
|
63 |
if video_url:
|
64 |
video_with_transcript = self.transcribe_video(video_url)
|
@@ -73,6 +86,9 @@ class EndpointHandler():
|
|
73 |
**video_with_transcript,
|
74 |
**encoded_segments
|
75 |
}
|
|
|
|
|
|
|
76 |
elif query:
|
77 |
if long_form_answer:
|
78 |
context = data.pop("context", None)
|
@@ -167,6 +183,17 @@ class EndpointHandler():
|
|
167 |
|
168 |
return all_batches
|
169 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
170 |
def generate_answer(self, query, documents):
|
171 |
|
172 |
# concatenate question and support documents into BART input
|
|
|
1 |
+
"""
|
2 |
+
https://huggingface.co/tomiwa1a/video-search
|
3 |
+
"""
|
4 |
from typing import Dict
|
5 |
|
6 |
from sentence_transformers import SentenceTransformer
|
7 |
from tqdm import tqdm
|
8 |
import whisper
|
9 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
|
10 |
import torch
|
11 |
import pytube
|
12 |
import time
|
|
|
17 |
WHISPER_MODEL_NAME = "tiny.en"
|
18 |
SENTENCE_TRANSFORMER_MODEL_NAME = "multi-qa-mpnet-base-dot-v1"
|
19 |
QUESTION_ANSWER_MODEL_NAME = "vblagoje/bart_lfqa"
|
20 |
+
SUMMARIZER_MODEL_NAME = "philschmid/bart-large-cnn-samsum"
|
21 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
22 |
+
device_number = 0 if torch.cuda.is_available() else -1
|
23 |
|
24 |
def __init__(self, path=""):
|
25 |
|
|
|
39 |
|
40 |
total = t1 - t0
|
41 |
print(f'Finished loading sentence_transformer_model in {total} seconds')
|
42 |
+
|
43 |
+
t0 = time.time()
|
44 |
+
self.summarizer = pipeline("summarization", model=self.SUMMARIZER_MODEL_NAME, device=device)
|
45 |
+
t1 = time.time()
|
46 |
+
|
47 |
+
total = t1 - t0
|
48 |
+
print(f'Finished loading summarizer in {total} seconds')
|
49 |
|
50 |
self.question_answer_tokenizer = AutoTokenizer.from_pretrained(self.QUESTION_ANSWER_MODEL_NAME)
|
51 |
t0 = time.time()
|
|
|
71 |
video_url = data.pop("video_url", None)
|
72 |
query = data.pop("query", None)
|
73 |
long_form_answer = data.pop("long_form_answer", None)
|
74 |
+
summarize = data.pop("summarize", False)
|
75 |
encoded_segments = {}
|
76 |
if video_url:
|
77 |
video_with_transcript = self.transcribe_video(video_url)
|
|
|
86 |
**video_with_transcript,
|
87 |
**encoded_segments
|
88 |
}
|
89 |
+
elif summarize:
|
90 |
+
summary = self.summarize_video(data["segments"])
|
91 |
+
return {"summary": summary}
|
92 |
elif query:
|
93 |
if long_form_answer:
|
94 |
context = data.pop("context", None)
|
|
|
183 |
|
184 |
return all_batches
|
185 |
|
186 |
+
def summarize_video(self, segments):
    """Attach an abstractive summary to every transcript segment.

    Mutates each segment dict in place: runs the summarization pipeline on
    the segment's 'text' and stores the resulting summary string under the
    'summary' key. Debug details (index, length, text, summary) are printed
    for each segment as it is processed.

    :param segments: list of dicts, each with at least 'text' and 'length'
                     keys — presumably produced by the transcription step;
                     TODO confirm against the caller.
    :return: the same ``segments`` list, now carrying 'summary' entries.
    """
    for position, chunk in enumerate(segments):
        # Pipeline output is a one-element list of dicts; keep only the text.
        pipeline_output = self.summarizer(chunk['text'])
        chunk['summary'] = pipeline_output[0]['summary_text']
        print('index', position)
        print('length', chunk['length'])
        print('text', chunk['text'])
        print('summary', chunk['summary'])

    return segments
|
196 |
+
|
197 |
def generate_answer(self, query, documents):
|
198 |
|
199 |
# concatenate question and support documents into BART input
|