import re
from typing import Optional, List

import vllm
from fire import Fire
from pydantic import BaseModel
from transformers import PreTrainedTokenizer, AutoTokenizer, AutoModelForCausalLM


class ZeroShotChatTemplate:
    # This is the default template used in llama-factory for training
    texts: List[str] = []

    @staticmethod
    def make_prompt(prompt: str) -> str:
        return f"Human: {prompt}\nAssistant: "

    @staticmethod
    def get_stopping_words() -> List[str]:
        return ["Human:"]

    @staticmethod
    def extract_answer(text: str) -> str:
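        # Answer-extraction heuristic for GSM8K-style outputs: keep only digits
        # and spaces, then take the last number found; fall back to the raw text.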
        filtered = "".join([char for char in text if char.isdigit() or char == " "])
        if not filtered.strip():
            return text
        return re.findall(pattern=r"\d+", string=filtered)[-1]


class VLLMModel(BaseModel, arbitrary_types_allowed=True):
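    # Thin wrapper around a vLLM engine, configured via pydantic fields.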
    path_model: str
    model: Optional[vllm.LLM] = None
    tokenizer: Optional[PreTrainedTokenizer] = None
    max_input_length: int = 512
    max_output_length: int = 512
    stopping_words: Optional[List[str]] = None

    def load(self):
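        # Lazily construct the vLLM engine and tokenizer on first use.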
        if self.model is None:
            self.model = vllm.LLM(model=self.path_model, trust_remote_code=True)
        if self.tokenizer is None:
            self.tokenizer = AutoTokenizer.from_pretrained(self.path_model)

    def format_prompt(self, prompt: str) -> str:
        self.load()
        prompt = prompt.rstrip(" ")  # Llama is sensitive (eg "Answer:" vs "Answer: ")
        return prompt

    def make_kwargs(self, do_sample: bool, **kwargs) -> dict:
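        # Build generation kwargs: greedy decoding unless do_sample is set,
        # capped at max_output_length tokens, stopping at any configured stop words.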
        if self.stopping_words:
            kwargs.update(stop=self.stopping_words)
        params = vllm.SamplingParams(
            temperature=0.5 if do_sample else 0.0,
            max_tokens=self.max_output_length,
            **kwargs,
        )

        outputs = dict(sampling_params=params, use_tqdm=False)
        return outputs

    def run(self, prompt: str) -> str:
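        # format_prompt() also calls load(), so the engine is ready before generate().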
        prompt = self.format_prompt(prompt)
        outputs = self.model.generate([prompt], **self.make_kwargs(do_sample=False))
        pred = outputs[0].outputs[0].text
        pred = pred.split("<|endoftext|>")[0]
        return pred


def upload_to_hub(path: str, repo_id: str):
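    # Load a local checkpoint and push both the model and tokenizer to the Hugging Face Hub.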
    tokenizer = AutoTokenizer.from_pretrained(path)
    model = AutoModelForCausalLM.from_pretrained(path)
    model.push_to_hub(repo_id)
    tokenizer.push_to_hub(repo_id)


def main(
    question: str = "Roger has 5 tennis balls. He buys 2 more cans of tennis balls. Each can has 3 tennis balls. How many tennis balls does he have now?",
    **kwargs,
):
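    # Run a single zero-shot question through the model and print the prediction.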
    model = VLLMModel(**kwargs)
    demo = ZeroShotChatTemplate()
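    # Stop generation when the model begins a new "Human:" turn.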
    model.stopping_words = demo.get_stopping_words()

    prompt = demo.make_prompt(question)
    raw_outputs = model.run(prompt)
    pred = demo.extract_answer(raw_outputs)
    print(dict(question=question, prompt=prompt, raw_outputs=raw_outputs, pred=pred))


"""
p run_demo.py upload_to_hub outputs_paths/gsm8k_paths_llama3_8b_beta_03_rank_128/final chiayewken/llama3-8b-gsm8k-rpo
p run_demo.py main --path_model chiayewken/llama3-8b-gsm8k-rpo
"""


if __name__ == "__main__":
    Fire()