oleksandrfluxon committed on
Commit
723f5ca
1 Parent(s): d461c7e

Create pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +46 -0
pipeline.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import transformers
3
+ from typing import Dict, List, Any
4
+
5
class PreTrainedPipeline():
    """Hugging Face Inference-Endpoint handler for MPT-7B-Instruct.

    Loads the model, tokenizer and a `text-generation` pipeline at
    construction time; `__call__` runs generation on an endpoint payload.
    """

    def __init__(self, path=""):
        """Load model/tokenizer and build the text-generation pipeline.

        Args:
            path: model repo id or local path. Falls back to the default
                  checkpoint when empty (the original always overrode it,
                  which made the parameter dead — this keeps the same
                  behavior for the default call while honoring a caller-
                  supplied path).
        """
        # Fall back to the default checkpoint only when no path is given.
        path = path or "oleksandrfluxon/mpt-7b-instruct-2"
        print("===> path", path)

        # BUG FIX: the original passed an undefined variable `name` here,
        # which raises NameError on every construction; `path` was intended.
        config = transformers.AutoConfig.from_pretrained(path, trust_remote_code=True)
        config.max_seq_len = 4096  # (input + output) tokens can now be up to 4096

        print("===> loading model")
        model = transformers.AutoModelForCausalLM.from_pretrained(
            path,
            config=config,
            torch_dtype=torch.bfloat16,  # Load model weights in bfloat16
            trust_remote_code=True,
            # NOTE(review): 4-bit quantization alongside torch_dtype=bfloat16
            # requires bitsandbytes; confirm both options are intended together.
            load_in_4bit=True,  # Load model in the lowest 4-bit precision quantization
        )
        print("===> model loaded")

        # NOTE(review): `device_map` is a model-loading kwarg; tokenizers
        # ignore it — confirm whether it can be dropped here.
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            'EleutherAI/gpt-neox-20b', padding_side="left", device_map="auto"
        )

        # self.pipeline is the only state `__call__` depends on.
        self.pipeline = transformers.pipeline('text-generation', model=model, tokenizer=tokenizer)
        print("===> init finished")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run text generation on an endpoint payload.

        data args:
            inputs (:obj: `str`): the prompt to generate from
            parameters (:obj: `dict`): extra kwargs forwarded to the pipeline
        Return:
            The pipeline's list of generation dicts.

        Note: pops "inputs", "parameters" and "date" from `data`,
        mutating the caller's dict.
        """
        # Get inputs; falls back to the whole payload when "inputs" is absent.
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", {})
        # Popped to strip it from the payload; otherwise unused.
        date = data.pop("date", None)
        print("===> inputs", inputs)
        print("===> parameters", parameters)

        result = self.pipeline(inputs, **parameters)
        print("===> result", result)

        return result