Create predict.py
predict.py ADDED (+148 -0)
@@ -0,0 +1,148 @@
import os
import time

os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

from vllm import LLM, SamplingParams
import torch
from cog import BasePredictor, Input, ConcatenateIterator
import typing as t


MODEL_ID = "TheBloke/Mistral-7B-OpenOrca-AWQ"
PROMPT_TEMPLATE = """\
<|im_start|>system
You are MistralOrca, a large language model trained by Alignment Lab AI. Write out your reasoning step-by-step to be sure you get the right answers!
<|im_end|>
<|im_start|>user
{prompt}<|im_end|>
<|im_start|>assistant
"""

DEFAULT_MAX_NEW_TOKENS = 512
DEFAULT_TEMPERATURE = 0.8
DEFAULT_TOP_P = 0.95
DEFAULT_TOP_K = 50
DEFAULT_PRESENCE_PENALTY = 0.0  # 1.15
DEFAULT_FREQUENCY_PENALTY = 0.0  # 0.2

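
# Streaming helper: called below with the vLLM engine (self.llm.llm_engine)
# bound as `self`. It submits a single request and, after each engine step,
# yields the accumulated output text together with a running token count.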
def vllm_generate_iterator(
    self, prompt: str, /, *, echo: bool = False, stop: str = None, stop_token_ids: t.List[int] = None, sampling_params=None, **attrs: t.Any
) -> t.Iterator[t.Dict[str, t.Any]]:
    request_id: str = attrs.pop('request_id', None)
    if request_id is None: raise ValueError('request_id must not be None.')
    if stop_token_ids is None: stop_token_ids = []
    stop_token_ids.append(self.tokenizer.eos_token_id)
    stop_ = set()
    if isinstance(stop, str) and stop != '': stop_.add(stop)
    elif isinstance(stop, list) and stop != []: stop_.update(stop)
    for tid in stop_token_ids:
        if tid: stop_.add(self.tokenizer.decode(tid))

    # if self.config['temperature'] <= 1e-5: top_p = 1.0
    # else: top_p = self.config['top_p']
    # config = self.config.model_construct_env(stop=list(stop_), top_p=top_p, **attrs)
    self.add_request(request_id=request_id, prompt=prompt, sampling_params=sampling_params)

    token_cache = []
    print_len = 0

    while self.has_unfinished_requests():
        for request_output in self.step():
            # Add the new tokens to the cache
            for output in request_output.outputs:
                text = output.text
                yield {'text': text, 'error_code': 0, 'num_tokens': len(output.token_ids)}

            if request_output.finished: break

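
# Cog predictor: setup() loads the AWQ-quantized weights once at startup,
# and predict() streams the completion back as a ConcatenateIterator.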
class Predictor(BasePredictor):

    def setup(self):
        self.llm = LLM(
            model=MODEL_ID,
            quantization="awq",
            dtype="float16"
        )

    def predict(
        self,
        prompt: str,
        max_new_tokens: int = Input(
            description="The maximum number of tokens the model should generate as output.",
            default=DEFAULT_MAX_NEW_TOKENS,
        ),
        temperature: float = Input(
            description="The value used to modulate the next token probabilities.", default=DEFAULT_TEMPERATURE
        ),
        top_p: float = Input(
            description="A probability threshold for generating the output. If < 1.0, only keep the top tokens with cumulative probability >= top_p (nucleus filtering). Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751).",
            default=DEFAULT_TOP_P,
        ),
        top_k: int = Input(
            description="The number of highest probability tokens to consider for generating the output. If > 0, only keep the top k tokens with highest probability (top-k filtering).",
            default=DEFAULT_TOP_K,
        ),
        presence_penalty: float = Input(
            description="Presence penalty",
            default=DEFAULT_PRESENCE_PENALTY,
        ),
        frequency_penalty: float = Input(
            description="Frequency penalty",
            default=DEFAULT_FREQUENCY_PENALTY,
        ),
        prompt_template: str = Input(
            description="The template used to format the prompt. The input prompt is inserted into the template using the `{prompt}` placeholder.",
            default=PROMPT_TEMPLATE,
        )
    ) -> ConcatenateIterator:
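        # Pair the templated prompt with its per-request sampling parameters.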
        prompts = [
            (
                prompt_template.format(prompt=prompt),
                SamplingParams(
                    max_tokens=max_new_tokens,
                    temperature=temperature,
                    top_k=top_k,
                    top_p=top_p,
                    presence_penalty=presence_penalty,
                    frequency_penalty=frequency_penalty
                )
            )
        ]
        start = time.time()
        while True:
            if prompts:
                prompt, sampling_params = prompts.pop(0)
                gen = vllm_generate_iterator(self.llm.llm_engine, prompt, echo=False, stop=None, stop_token_ids=None, sampling_params=sampling_params, request_id=0)
                last = ""
                for _, x in enumerate(gen):
                    if x['text'] == "":
                        continue
                    yield x['text'][len(last):]
                    last = x["text"]
                    num_tokens = x["num_tokens"]
                print(f"\nGenerated {num_tokens} tokens in {time.time() - start} seconds.")

            if not (self.llm.llm_engine.has_unfinished_requests() or prompts):
                break

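
# Local smoke test: running `python predict.py` streams a sample completion
# directly, without going through Cog.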
if __name__ == '__main__':
    import sys
    p = Predictor()
    p.setup()
    gen = p.predict(
        "Write me an itinerary for my dog's birthday party.",
        512,
        0.8,
        0.95,
        50,
        1.0,
        0.2,
        PROMPT_TEMPLATE,
    )
    for out in gen:
        print(out, end="")
        sys.stdout.flush()