Spaces:
Running
on
Zero
Running
on
Zero
import re | |
from typing import Optional, List | |
import vllm | |
from fire import Fire | |
from pydantic import BaseModel | |
from transformers import PreTrainedTokenizer, AutoTokenizer, AutoModelForCausalLM | |
class ZeroShotChatTemplate:
    """Zero-shot "Human:/Assistant:" prompt template with answer extraction.

    None of the helpers read instance state, so they are static methods and
    can be called on the class or on an instance alike.
    """

    # This is the default template used in llama-factory for training
    texts: List[str] = []

    @staticmethod
    def make_prompt(prompt: str) -> str:
        """Wrap *prompt* as the human turn and open the assistant turn."""
        return f"Human: {prompt}\nAssistant: "

    @staticmethod
    def get_stopping_words() -> List[str]:
        """Return the strings used as stop sequences for this template."""
        return ["Human:"]

    @staticmethod
    def extract_answer(text: str) -> str:
        """Return the last run of digits in *text*, or *text* if there is none.

        Every character except digits and spaces is dropped before searching,
        so separators inside a number are removed ("1,000" -> "1000").
        """
        filtered = "".join(char for char in text if char.isdigit() or char == " ")
        numbers = re.findall(r"\d+", filtered)
        # str.isdigit() accepts characters such as superscripts that \d does
        # not match, so the filtered text can be non-blank yet contain no
        # match; fall back to the raw text instead of indexing an empty list.
        if not numbers:
            return text
        return numbers[-1]
class VLLMModel(BaseModel, arbitrary_types_allowed=True):
    """Wrapper around a vLLM engine that generates text for a single prompt.

    The engine and tokenizer are created on first use by ``load()`` from
    ``path_model`` unless they were supplied at construction time.
    """

    path_model: str
    # Both handles default to None and are filled in by load(); the annotation
    # is Optional so that it matches the default.
    model: Optional[vllm.LLM] = None
    tokenizer: Optional[PreTrainedTokenizer] = None
    max_input_length: int = 512  # not read by any method of this class
    max_output_length: int = 512  # passed to SamplingParams as max_tokens
    stopping_words: Optional[List[str]] = None

    def load(self):
        """Create the vLLM engine and the tokenizer if they are not set yet."""
        if self.model is None:
            self.model = vllm.LLM(model=self.path_model, trust_remote_code=True)
        if self.tokenizer is None:
            self.tokenizer = AutoTokenizer.from_pretrained(self.path_model)

    def format_prompt(self, prompt: str) -> str:
        """Ensure the model is loaded and strip trailing spaces from *prompt*."""
        self.load()
        prompt = prompt.rstrip(" ")  # Llama is sensitive (eg "Answer:" vs "Answer: ")
        return prompt

    def make_kwargs(self, do_sample: bool, **kwargs) -> dict:
        """Build the keyword arguments passed to ``self.model.generate``.

        Temperature is 0.5 when *do_sample* is true and 0.0 otherwise. Extra
        *kwargs* are forwarded to ``vllm.SamplingParams``; when
        ``stopping_words`` is set it replaces any ``stop`` value in *kwargs*.
        """
        if self.stopping_words:
            kwargs.update(stop=self.stopping_words)
        params = vllm.SamplingParams(
            temperature=0.5 if do_sample else 0.0,
            max_tokens=self.max_output_length,
            **kwargs,
        )
        return dict(sampling_params=params, use_tqdm=False)

    def run(self, prompt: str) -> str:
        """Generate a completion for *prompt* with sampling disabled.

        Returns the text of the first output, cut at the first
        "<|endoftext|>" marker if one is present.
        """
        # format_prompt() also triggers load(), so self.model is set below.
        prompt = self.format_prompt(prompt)
        outputs = self.model.generate([prompt], **self.make_kwargs(do_sample=False))
        pred = outputs[0].outputs[0].text
        return pred.split("<|endoftext|>")[0]
def upload_to_hub(path: str, repo_id: str):
    """Load the tokenizer and causal-LM model from *path* and push both to *repo_id*."""
    tok = AutoTokenizer.from_pretrained(path)
    lm = AutoModelForCausalLM.from_pretrained(path)
    # The model is pushed first, then the tokenizer.
    for artifact in (lm, tok):
        artifact.push_to_hub(repo_id)
def main(
    question: str = "Roger has 5 tennis balls. He buys 2 more cans of tennis balls. Each can has 3 tennis balls. How many tennis balls does he have now?",
    **kwargs,
):
    """Answer *question* with the model described by *kwargs* and print the result.

    *kwargs* are forwarded to ``VLLMModel`` (for example ``path_model``). The
    printed dict holds the question, the formatted prompt, the raw model
    output and the extracted answer.
    """
    model = VLLMModel(**kwargs)
    # The template helpers take no `self`, so they are called through the
    # class; going through an instance would pass it as an extra argument.
    template = ZeroShotChatTemplate
    model.stopping_words = template.get_stopping_words()
    prompt = template.make_prompt(question)
    raw_outputs = model.run(prompt)
    pred = template.extract_answer(raw_outputs)
    print(dict(question=question, prompt=prompt, raw_outputs=raw_outputs, pred=pred))
""" | |
p run_demo.py upload_to_hub outputs_paths/gsm8k_paths_llama3_8b_beta_03_rank_128/final chiayewken/llama3-8b-gsm8k-rpo | |
p run_demo.py main --path_model chiayewken/llama3-8b-gsm8k-rpo | |
""" | |
# Script entry point: Fire() is called without arguments, and the usage notes
# in this file invoke both `upload_to_hub` and `main` as sub-commands.
if __name__ == "__main__":
    Fire()