import math
import random

from distilabel.models import ClientvLLM, InferenceEndpointsLLM, OllamaLLM, OpenAILLM
from distilabel.steps.tasks import TextGeneration

from synthetic_dataset_generator.constants import (
    API_KEYS,
    DEFAULT_BATCH_SIZE,
    HUGGINGFACE_BASE_URL,
    MODEL,
    OLLAMA_BASE_URL,
    OPENAI_BASE_URL,
    TOKENIZER_ID,
    VLLM_BASE_URL,
)

TOKEN_INDEX = 0


def _get_next_api_key():
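    # Rotate through the configured API_KEYS round-robin so repeated LLM
    # instantiations spread requests across the available keys.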
    global TOKEN_INDEX
    api_key = API_KEYS[TOKEN_INDEX % len(API_KEYS)]
    TOKEN_INDEX += 1
    return api_key


def _get_prompt_rewriter():
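    # Build a TextGeneration task that paraphrases a seed prompt; temperature 1
    # keeps the rewrites varied while the system prompt pins the structure.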
    generation_kwargs = {
        "temperature": 1,
    }
    system_prompt = "You are a prompt rewriter. You are given a prompt and you need to rewrite it keeping the same structure but highlighting different aspects of the original without adding anything new."
    prompt_rewriter = TextGeneration(
        llm=_get_llm(generation_kwargs=generation_kwargs),
        system_prompt=system_prompt,
        use_system_prompt=True,
    )
    prompt_rewriter.load()
    return prompt_rewriter


def get_rewritten_prompts(prompt: str, num_rows: int):
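    # Used to diversify generation: returns the original prompt plus a handful of
    # LLM-produced rewrites of it.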
    prompt_rewriter = _get_prompt_rewriter()
    # one rewrite request per 100 requested rows
    inputs = [
        {"instruction": f"Original prompt: {prompt} \nRewritten prompt: "}
        for _ in range(math.floor(num_rows / 100))
    ]
    n_processed = 0
    prompt_rewrites = [prompt]
    # process the rewrite requests in DEFAULT_BATCH_SIZE-sized batches
    while n_processed < len(inputs):
        batch = list(
            prompt_rewriter.process(
                inputs=inputs[n_processed : n_processed + DEFAULT_BATCH_SIZE]
            )
        )
        prompt_rewrites += [entry["generation"] for entry in batch[0]]
        n_processed += DEFAULT_BATCH_SIZE
        random.seed(a=random.randint(0, 2**32 - 1))
    return prompt_rewrites


def _get_llm_class() -> str:
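    # Name of the LLM class that _get_llm would instantiate, following the same
    # base-URL precedence (OpenAI > Ollama > Hugging Face endpoint > vLLM > default).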
    if OPENAI_BASE_URL:
        return "OpenAILLM"
    elif OLLAMA_BASE_URL:
        return "OllamaLLM"
    elif HUGGINGFACE_BASE_URL:
        return "InferenceEndpointsLLM"
    elif VLLM_BASE_URL:
        return "ClientvLLM"
    else:
        return "InferenceEndpointsLLM"


def _get_llm(use_magpie_template=False, **kwargs):
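    # Instantiate the LLM for whichever provider base URL is configured, translating
    # generic generation_kwargs into the names each backend expects.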
    if OPENAI_BASE_URL:
        # OpenAI-compatible APIs expect "stop" instead of "stop_sequences" and do not
        # accept "do_sample", so normalize generation_kwargs before building the client.
        if "generation_kwargs" in kwargs:
            if "stop_sequences" in kwargs["generation_kwargs"]:
                kwargs["generation_kwargs"]["stop"] = kwargs["generation_kwargs"][
                    "stop_sequences"
                ]
                del kwargs["generation_kwargs"]["stop_sequences"]
            if "do_sample" in kwargs["generation_kwargs"]:
                del kwargs["generation_kwargs"]["do_sample"]
        llm = OpenAILLM(
            model=MODEL,
            base_url=OPENAI_BASE_URL,
            api_key=_get_next_api_key(),
            **kwargs,
        )
    elif OLLAMA_BASE_URL:
        if "generation_kwargs" in kwargs:
            # Ollama uses "num_predict" and "stop", and nests all sampling
            # parameters under an "options" dict.
            if "max_new_tokens" in kwargs["generation_kwargs"]:
                kwargs["generation_kwargs"]["num_predict"] = kwargs[
                    "generation_kwargs"
                ]["max_new_tokens"]
                del kwargs["generation_kwargs"]["max_new_tokens"]
            if "stop_sequences" in kwargs["generation_kwargs"]:
                kwargs["generation_kwargs"]["stop"] = kwargs["generation_kwargs"][
                    "stop_sequences"
                ]
                del kwargs["generation_kwargs"]["stop_sequences"]
            if "do_sample" in kwargs["generation_kwargs"]:
                del kwargs["generation_kwargs"]["do_sample"]
            kwargs["generation_kwargs"] = {"options": kwargs["generation_kwargs"]}
        llm = OllamaLLM(
            model=MODEL,
            host=OLLAMA_BASE_URL,
            tokenizer_id=TOKENIZER_ID or MODEL,
            use_magpie_template=use_magpie_template,
            **kwargs,
        )
    elif HUGGINGFACE_BASE_URL:
        # Inference Endpoints need sampling enabled; guard against a missing
        # generation_kwargs entry when the caller passes none.
        kwargs.setdefault("generation_kwargs", {})["do_sample"] = True
        llm = InferenceEndpointsLLM(
            api_key=_get_next_api_key(),
            base_url=HUGGINGFACE_BASE_URL,
            tokenizer_id=TOKENIZER_ID or MODEL,
            use_magpie_template=use_magpie_template,
            **kwargs,
        )
    elif VLLM_BASE_URL:
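        # Drop "do_sample"; the OpenAI-compatible vLLM endpoint does not use it.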
        if "generation_kwargs" in kwargs:
            if "do_sample" in kwargs["generation_kwargs"]:
                del kwargs["generation_kwargs"]["do_sample"]
        llm = ClientvLLM(
            base_url=VLLM_BASE_URL,
            model=MODEL,
            tokenizer=TOKENIZER_ID or MODEL,
            api_key=_get_next_api_key(),
            use_magpie_template=use_magpie_template,
            **kwargs,
        )
    else:
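        # No provider base URL configured: default to Hugging Face inference for
        # MODEL via model_id.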
        llm = InferenceEndpointsLLM(
            api_key=_get_next_api_key(),
            tokenizer_id=TOKENIZER_ID or MODEL,
            model_id=MODEL,
            use_magpie_template=use_magpie_template,
            **kwargs,
        )

    return llm


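# Import-time smoke test: build the configured LLM and run one generation so a
# misconfigured provider or key fails fast.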
try:
    llm = _get_llm()
    llm.load()
    llm.generate([[{"content": "Hello, world!", "role": "user"}]])
except Exception as e:
    raise Exception(f"Error loading {_get_llm_class()}: {e}") from e