Upload handler.py
Browse files- handler.py +10 -6
handler.py
CHANGED
@@ -30,16 +30,20 @@ class EndpointHandler:
|
|
30 |
inputs = data.pop("inputs", data)
|
31 |
|
32 |
encodings = self.tokenizer(
|
33 |
-
inputs, padding=False, truncation=False,
|
34 |
)
|
|
|
35 |
truncated_input_ids = middle_truncate(
|
36 |
-
encodings["input_ids"],
|
37 |
)
|
38 |
-
|
39 |
-
attention_masks =
|
|
|
|
|
|
|
40 |
truncated_encodings = {
|
41 |
-
"input_ids": truncated_input_ids,
|
42 |
-
"attention_mask": attention_masks,
|
43 |
}
|
44 |
|
45 |
outputs = self.model(**truncated_encodings)
|
|
|
30 |
inputs = data.pop("inputs", data)
|
31 |
|
32 |
encodings = self.tokenizer(
|
33 |
+
inputs, padding=False, truncation=False, return_tensors="pt"
|
34 |
)
|
35 |
+
|
36 |
truncated_input_ids = middle_truncate(
|
37 |
+
encodings["input_ids"][0].tolist(), self.MAX_LENGTH, self.tokenizer
|
38 |
)
|
39 |
+
|
40 |
+
attention_masks = [
|
41 |
+
int(token_id != self.tokenizer.pad_token_id)
|
42 |
+
for token_id in truncated_input_ids
|
43 |
+
]
|
44 |
truncated_encodings = {
|
45 |
+
"input_ids": torch.tensor([truncated_input_ids]),
|
46 |
+
"attention_mask": torch.tensor([attention_masks]),
|
47 |
}
|
48 |
|
49 |
outputs = self.model(**truncated_encodings)
|