serbog commited on
Commit
55990e0
1 Parent(s): dafd68e

Upload handler.py

Browse files
Files changed (1) hide show
  1. handler.py +10 -6
handler.py CHANGED
@@ -30,16 +30,20 @@ class EndpointHandler:
30
  inputs = data.pop("inputs", data)
31
 
32
  encodings = self.tokenizer(
33
- inputs, padding=False, truncation=False, max_length=514
34
  )
 
35
  truncated_input_ids = middle_truncate(
36
- encodings["input_ids"], 514, self.tokenizer
37
  )
38
- truncated_input_ids_array = np.array(truncated_input_ids)
39
- attention_masks = (truncated_input_ids_array != 1).astype(int)
 
 
 
40
  truncated_encodings = {
41
- "input_ids": truncated_input_ids,
42
- "attention_mask": attention_masks,
43
  }
44
 
45
  outputs = self.model(**truncated_encodings)
 
30
  inputs = data.pop("inputs", data)
31
 
32
  encodings = self.tokenizer(
33
+ inputs, padding=False, truncation=False, return_tensors="pt"
34
  )
35
+
36
  truncated_input_ids = middle_truncate(
37
+ encodings["input_ids"][0].tolist(), self.MAX_LENGTH, self.tokenizer
38
  )
39
+
40
+ attention_masks = [
41
+ int(token_id != self.tokenizer.pad_token_id)
42
+ for token_id in truncated_input_ids
43
+ ]
44
  truncated_encodings = {
45
+ "input_ids": torch.tensor([truncated_input_ids]),
46
+ "attention_mask": torch.tensor([attention_masks]),
47
  }
48
 
49
  outputs = self.model(**truncated_encodings)