k4tel commited on
Commit
6b5e3dd
·
1 Parent(s): 3d6f284
Files changed (1) hide show
  1. pipe.py +33 -0
pipe.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ from transformers import Pipeline
6
+
7
+ class RegressionPipeline(Pipeline):
8
+ def _sanitize_parameters(self, **kwargs):
9
+ preprocess_kwargs = {}
10
+ if "maybe_arg" in kwargs:
11
+ preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"]
12
+ return preprocess_kwargs, {}, {}
13
+
14
+ def preprocess(self, inputs, maybe_arg=2):
15
+ print(inputs)
16
+ encoded_corpus = self.tokenizer(text=inputs,
17
+ add_special_tokens=True,
18
+ padding='max_length',
19
+ truncation='longest_first',
20
+ max_length=300,
21
+ return_attention_mask=True)
22
+ return {"model_input": encoded_corpus}
23
+
24
+ def _forward(self, model_inputs):
25
+ print(model_inputs)
26
+ # model_inputs == {"model_input": model_input}
27
+ outputs = self.model(torch.tensor(model_inputs['model_input']['input_ids']).reshape(1, -1).to(torch.int64),
28
+ torch.tensor(model_inputs['model_input']['attention_mask']).reshape(1, -1).to(torch.int64))
29
+ return outputs
30
+
31
+ def postprocess(self, model_outputs):
32
+ print(model_outputs)
33
+ return model_outputs.numpy()