oleksandrfluxon committed on
Commit
44092ff
1 Parent(s): e7b6e5f

Create pipeline.py

Files changed (1)
  1. pipeline.py +81 -0
pipeline.py ADDED
@@ -0,0 +1,81 @@
import torch
import transformers
from torch import cuda
from transformers import StoppingCriteria, StoppingCriteriaList
from typing import Dict, List, Any

# custom stopping criteria: stop generation once any stop token sequence
# appears at the end of the generated ids
class StopOnTokens(StoppingCriteria):
    def __init__(self, stop_token_ids: List[torch.LongTensor]):
        super().__init__()
        self.stop_token_ids = stop_token_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        for stop_ids in self.stop_token_ids:
            if torch.eq(input_ids[0][-len(stop_ids):], stop_ids).all():
                return True
        return False
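
# Illustrative note (not part of the original commit): if the tokenizer maps
# ['Human', ':'] to the ids [human_id, colon_id], then generation halts as
# soon as the last two generated ids equal that pair.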

class PreTrainedPipeline():
    def __init__(self, path=""):
        path = "oleksandrfluxon/mpt-7b-instruct-evaluate"
        print("===> path", path)

        device = f'cuda:{cuda.current_device()}' if cuda.is_available() else 'cpu'
        print("===> device", device)

        model = transformers.AutoModelForCausalLM.from_pretrained(
            'oleksandrfluxon/mpt-7b-instruct-evaluate',
            trust_remote_code=True,
            load_in_8bit=True,  # this requires the `bitsandbytes` library
            max_seq_len=8192,
            init_device=device
        )
        model.eval()
        print(f"===> Model loaded on {device}")

        tokenizer = transformers.AutoTokenizer.from_pretrained("mosaicml/mpt-7b")

        # build the list of stop token sequences ("Human:" and "AI:")
        stop_token_ids = [
            tokenizer.convert_tokens_to_ids(x) for x in [
                ['Human', ':'], ['AI', ':']
            ]
        ]
        stop_token_ids = [torch.LongTensor(x).to(device) for x in stop_token_ids]
        print("===> stop_token_ids", stop_token_ids)

        stopping_criteria = StoppingCriteriaList([StopOnTokens(stop_token_ids)])

        self.pipeline = transformers.pipeline(
            model=model, tokenizer=tokenizer,
            return_full_text=True,  # langchain expects the full text
            task='text-generation',
            # generation parameters are passed here too
            stopping_criteria=stopping_criteria,  # without this the model rambles during chat
            temperature=0.1,  # 'randomness' of outputs: 0.0 is the min and 1.0 the max
            top_p=0.15,  # sample from the top tokens whose probabilities add up to 15%
            top_k=0,  # 0 disables top-k filtering, so sampling relies on top_p
            max_new_tokens=128,  # max number of tokens to generate in the output
            repetition_penalty=1.1  # without this the output begins repeating
        )

        print("===> init finished")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        data args:
            inputs (:obj:`str`): the prompt to generate from
            parameters (:obj:`dict`): generation kwargs forwarded to the pipeline
        Return:
            A :obj:`list` of :obj:`dict` containing the generated text
        """
        # get inputs
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", {})
        date = data.pop("date", None)  # accepted in the payload but currently unused
        print("===> inputs", inputs)
        print("===> parameters", parameters)

        result = self.pipeline(inputs, **parameters)
        print("===> result", result)

        return result
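
A minimal usage sketch (not part of the commit): it assumes the handler is invoked the way a custom Hugging Face inference handler is typically called, with a dict carrying the `inputs` and `parameters` keys that `__call__` reads above.

from pipeline import PreTrainedPipeline

handler = PreTrainedPipeline()

# `parameters` is forwarded verbatim to the text-generation pipeline
payload = {
    "inputs": "Explain in one sentence what a stopping criterion does.",
    "parameters": {"max_new_tokens": 64},
}
result = handler(payload)

# text-generation pipelines return a list of dicts with a "generated_text" key
print(result[0]["generated_text"])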