fix error in device index for summarizer
Browse files- handler.py +2 -3
handler.py
CHANGED
@@ -18,12 +18,11 @@ class EndpointHandler():
|
|
18 |
SENTENCE_TRANSFORMER_MODEL_NAME = "multi-qa-mpnet-base-dot-v1"
|
19 |
QUESTION_ANSWER_MODEL_NAME = "vblagoje/bart_lfqa"
|
20 |
SUMMARIZER_MODEL_NAME = "philschmid/bart-large-cnn-samsum"
|
21 |
-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
22 |
-
device_number = 0 if torch.cuda.is_available() else -1
|
23 |
|
24 |
def __init__(self, path=""):
|
25 |
|
26 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
27 |
print(f'whisper and question_answer_model will use: {device}')
|
28 |
|
29 |
t0 = time.time()
|
@@ -41,7 +40,7 @@ class EndpointHandler():
|
|
41 |
print(f'Finished loading sentence_transformer_model in {total} seconds')
|
42 |
|
43 |
t0 = time.time()
|
44 |
-
self.summarizer = pipeline("summarization", model=self.SUMMARIZER_MODEL_NAME, device=
|
45 |
t1 = time.time()
|
46 |
|
47 |
total = t1 - t0
|
|
|
18 |
SENTENCE_TRANSFORMER_MODEL_NAME = "multi-qa-mpnet-base-dot-v1"
|
19 |
QUESTION_ANSWER_MODEL_NAME = "vblagoje/bart_lfqa"
|
20 |
SUMMARIZER_MODEL_NAME = "philschmid/bart-large-cnn-samsum"
|
|
|
|
|
21 |
|
22 |
def __init__(self, path=""):
|
23 |
|
24 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
25 |
+
device_number = 0 if torch.cuda.is_available() else -1
|
26 |
print(f'whisper and question_answer_model will use: {device}')
|
27 |
|
28 |
t0 = time.time()
|
|
|
40 |
print(f'Finished loading sentence_transformer_model in {total} seconds')
|
41 |
|
42 |
t0 = time.time()
|
43 |
+
self.summarizer = pipeline("summarization", model=self.SUMMARIZER_MODEL_NAME, device=device_number)
|
44 |
t1 = time.time()
|
45 |
|
46 |
total = t1 - t0
|