import re
from typing import Optional, List

import vllm
from fire import Fire
from pydantic import BaseModel
from transformers import PreTrainedTokenizer, AutoTokenizer, AutoModelForCausalLM


class ZeroShotChatTemplate:
    # This is the default template used in llama-factory for training
    texts: List[str] = []

    @staticmethod
    def make_prompt(prompt: str) -> str:
        return f"Human: {prompt}\nAssistant: "

    @staticmethod
    def get_stopping_words() -> List[str]:
        return ["Human:"]

    @staticmethod
    def extract_answer(text: str) -> str:
        # Keep only digits and spaces, then return the last number found;
        # fall back to the raw text if the completion contains no digits
        filtered = "".join([char for char in text if char.isdigit() or char == " "])
        if not filtered.strip():
            return text
        return re.findall(pattern=r"\d+", string=filtered)[-1]


class VLLMModel(BaseModel, arbitrary_types_allowed=True):
    path_model: str
    model: Optional[vllm.LLM] = None
    tokenizer: Optional[PreTrainedTokenizer] = None
    max_input_length: int = 512
    max_output_length: int = 512
    stopping_words: Optional[List[str]] = None

    def load(self):
        # Lazily initialize the engine and tokenizer on first use
        if self.model is None:
            self.model = vllm.LLM(model=self.path_model, trust_remote_code=True)
        if self.tokenizer is None:
            self.tokenizer = AutoTokenizer.from_pretrained(self.path_model)

    def format_prompt(self, prompt: str) -> str:
        self.load()
        prompt = prompt.rstrip(" ")  # Llama is sensitive (e.g. "Answer:" vs "Answer: ")
        return prompt

    def make_kwargs(self, do_sample: bool, **kwargs) -> dict:
        if self.stopping_words:
            kwargs.update(stop=self.stopping_words)
        params = vllm.SamplingParams(
            temperature=0.5 if do_sample else 0.0,  # temperature 0.0 selects greedy decoding
            max_tokens=self.max_output_length,
            **kwargs,
        )
        outputs = dict(sampling_params=params, use_tqdm=False)
        return outputs

    def run(self, prompt: str) -> str:
        prompt = self.format_prompt(prompt)
        outputs = self.model.generate([prompt], **self.make_kwargs(do_sample=False))
        pred = outputs[0].outputs[0].text
        pred = pred.split("<|endoftext|>")[0]
        return pred


def upload_to_hub(path: str, repo_id: str):
    tokenizer = AutoTokenizer.from_pretrained(path)
    model = AutoModelForCausalLM.from_pretrained(path)
    model.push_to_hub(repo_id)
    tokenizer.push_to_hub(repo_id)


def main(
    question: str = "Roger has 5 tennis balls. He buys 2 more cans of tennis balls. Each can has 3 tennis balls. How many tennis balls does he have now?",
    **kwargs,
):
    model = VLLMModel(**kwargs)
    demo = ZeroShotChatTemplate()
    model.stopping_words = demo.get_stopping_words()
    prompt = demo.make_prompt(question)
    raw_outputs = model.run(prompt)
    pred = demo.extract_answer(raw_outputs)
    print(dict(question=question, prompt=prompt, raw_outputs=raw_outputs, pred=pred))


"""
p run_demo.py upload_to_hub outputs_paths/gsm8k_paths_llama3_8b_beta_03_rank_128/final chiayewken/llama3-8b-gsm8k-rpo
p run_demo.py main --path_model chiayewken/llama3-8b-gsm8k-rpo
"""


if __name__ == "__main__":
    Fire()
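

# A minimal sanity check of ZeroShotChatTemplate.extract_answer (the helper name
# demo_extract_answer is illustrative, not part of the original interface): since
# the filter keeps only digits and spaces, the last number in the completion wins,
# and text with no digits at all is returned unchanged.
def demo_extract_answer():
    template = ZeroShotChatTemplate()
    assert template.extract_answer("He has 5 + 2 * 3 = 11 tennis balls.") == "11"
    assert template.extract_answer("So the answer is 11.") == "11"
    assert template.extract_answer("no digits here") == "no digits here"
    print("extract_answer sanity checks passed")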