import re
from typing import Optional, List

import vllm
from fire import Fire
from pydantic import BaseModel
from transformers import PreTrainedTokenizer, AutoTokenizer, AutoModelForCausalLM


class ZeroShotChatTemplate:
    # This is the default template used in llama-factory for training
    texts: List[str] = []

    @staticmethod
    def make_prompt(prompt: str) -> str:
        return f"Human: {prompt}\nAssistant: "

    @staticmethod
    def get_stopping_words() -> List[str]:
        return ["Human:"]

    @staticmethod
    def extract_answer(text: str) -> str:
        # Keep only digits and spaces, then return the last number as the predicted answer;
        # fall back to the raw text if no number is found
        filtered = "".join([char for char in text if char.isdigit() or char == " "])
        if not filtered.strip():
            return text
        return re.findall(pattern=r"\d+", string=filtered)[-1]


class VLLMModel(BaseModel, arbitrary_types_allowed=True):
    # Thin pydantic wrapper around a vLLM engine and its tokenizer
    path_model: str
    model: Optional[vllm.LLM] = None
    tokenizer: Optional[PreTrainedTokenizer] = None
    max_input_length: int = 512
    max_output_length: int = 512
    stopping_words: Optional[List[str]] = None

    def load(self):
        # Lazily initialize the vLLM engine and tokenizer on first use
        if self.model is None:
            self.model = vllm.LLM(model=self.path_model, trust_remote_code=True)
        if self.tokenizer is None:
            self.tokenizer = AutoTokenizer.from_pretrained(self.path_model)

    def format_prompt(self, prompt: str) -> str:
        self.load()
        prompt = prompt.rstrip(" ")  # Llama is sensitive (eg "Answer:" vs "Answer: ")
        return prompt

    def make_kwargs(self, do_sample: bool, **kwargs) -> dict:
        if self.stopping_words:
            kwargs.update(stop=self.stopping_words)
        params = vllm.SamplingParams(
            temperature=0.5 if do_sample else 0.0,  # temperature 0.0 means greedy decoding
            max_tokens=self.max_output_length,
            **kwargs,
        )
        outputs = dict(sampling_params=params, use_tqdm=False)
        return outputs

    def run(self, prompt: str) -> str:
        prompt = self.format_prompt(prompt)
        outputs = self.model.generate([prompt], **self.make_kwargs(do_sample=False))
        pred = outputs[0].outputs[0].text
        pred = pred.split("<|endoftext|>")[0]  # truncate at the end-of-text token if present
        return pred


def upload_to_hub(path: str, repo_id: str):
    # Push a local checkpoint (model and tokenizer) to the Hugging Face Hub
    tokenizer = AutoTokenizer.from_pretrained(path)
    model = AutoModelForCausalLM.from_pretrained(path)
    model.push_to_hub(repo_id)
    tokenizer.push_to_hub(repo_id)


def main(
    question: str = "Roger has 5 tennis balls. He buys 2 more cans of tennis balls. Each can has 3 tennis balls. How many tennis balls does he have now?",
    **kwargs,
):
    model = VLLMModel(**kwargs)
    demo = ZeroShotChatTemplate()
    model.stopping_words = demo.get_stopping_words()
    prompt = demo.make_prompt(question)
    raw_outputs = model.run(prompt)
    pred = demo.extract_answer(raw_outputs)
    print(dict(question=question, prompt=prompt, raw_outputs=raw_outputs, pred=pred))


"""
p run_demo.py upload_to_hub outputs_paths/gsm8k_paths_llama3_8b_beta_03_rank_128/final chiayewken/llama3-8b-gsm8k-rpo
p run_demo.py main --path_model chiayewken/llama3-8b-gsm8k-rpo
"""
if __name__ == "__main__":
Fire()