import math
import random

from distilabel.models import ClientvLLM, InferenceEndpointsLLM, OllamaLLM, OpenAILLM
from distilabel.steps.tasks import TextGeneration

from synthetic_dataset_generator.constants import (
    API_KEYS,
    DEFAULT_BATCH_SIZE,
    HUGGINGFACE_BASE_URL,
    MODEL,
    OLLAMA_BASE_URL,
    OPENAI_BASE_URL,
    TOKENIZER_ID,
    VLLM_BASE_URL,
)

TOKEN_INDEX = 0
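
# Rotate through the configured API keys in round-robin order so that
# consecutive requests are spread across the available tokens.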
def _get_next_api_key():
    global TOKEN_INDEX
    api_key = API_KEYS[TOKEN_INDEX % len(API_KEYS)]
    TOKEN_INDEX += 1
    return api_key
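
# Build a TextGeneration task that rewrites the seed prompt (temperature=1),
# keeping its structure while highlighting different aspects of it.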
def _get_prompt_rewriter():
    generation_kwargs = {
        "temperature": 1,
    }
    system_prompt = "You are a prompt rewriter. You are given a prompt and you need to rewrite it keeping the same structure but highlighting different aspects of the original without adding anything new."
    prompt_rewriter = TextGeneration(
        llm=_get_llm(generation_kwargs=generation_kwargs),
        system_prompt=system_prompt,
        use_system_prompt=True,
    )
    prompt_rewriter.load()
    return prompt_rewriter
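
# Produce roughly one rewritten prompt per 100 requested rows, processing the
# rewrite requests in batches of DEFAULT_BATCH_SIZE; the original prompt is
# always kept as the first entry.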
def get_rewritten_prompts(prompt: str, num_rows: int):
    prompt_rewriter = _get_prompt_rewriter()

    inputs = [
        {"instruction": f"Original prompt: {prompt} \nRewritten prompt: "}
        for _ in range(math.floor(num_rows / 100))
    ]
    n_processed = 0
    prompt_rewrites = [prompt]
    # Iterate over the prepared inputs (not num_rows) so we never slice past the list.
    while n_processed < len(inputs):
        batch = list(
            prompt_rewriter.process(
                inputs=inputs[n_processed : n_processed + DEFAULT_BATCH_SIZE]
            )
        )
        prompt_rewrites += [entry["generation"] for entry in batch[0]]
        n_processed += DEFAULT_BATCH_SIZE
        random.seed(a=random.randint(0, 2**32 - 1))
    return prompt_rewrites
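
# Return the class name of the LLM backend selected by the environment
# configuration; used for error reporting below.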
def _get_llm_class() -> str:
    if OPENAI_BASE_URL:
        return "OpenAILLM"
    elif OLLAMA_BASE_URL:
        return "OllamaLLM"
    elif HUGGINGFACE_BASE_URL:
        return "InferenceEndpointsLLM"
    elif VLLM_BASE_URL:
        return "ClientvLLM"
    else:
        return "InferenceEndpointsLLM"
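
# Instantiate the LLM for whichever backend base URL is configured, normalizing
# generation kwargs to each client's expected parameter names; falls back to
# serverless Hugging Face Inference Endpoints when no base URL is set.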
def _get_llm(use_magpie_template=False, **kwargs):
    if OPENAI_BASE_URL:
        # Rename/strip kwargs the OpenAI API does not accept before building the
        # client, so the sanitized values are actually used.
        if "generation_kwargs" in kwargs:
            if "stop_sequences" in kwargs["generation_kwargs"]:
                kwargs["generation_kwargs"]["stop"] = kwargs["generation_kwargs"][
                    "stop_sequences"
                ]
                del kwargs["generation_kwargs"]["stop_sequences"]
            if "do_sample" in kwargs["generation_kwargs"]:
                del kwargs["generation_kwargs"]["do_sample"]
        llm = OpenAILLM(
            model=MODEL,
            base_url=OPENAI_BASE_URL,
            api_key=_get_next_api_key(),
            **kwargs,
        )
    elif OLLAMA_BASE_URL:
        if "generation_kwargs" in kwargs:
            if "max_new_tokens" in kwargs["generation_kwargs"]:
                kwargs["generation_kwargs"]["num_predict"] = kwargs[
                    "generation_kwargs"
                ]["max_new_tokens"]
                del kwargs["generation_kwargs"]["max_new_tokens"]
            if "stop_sequences" in kwargs["generation_kwargs"]:
                kwargs["generation_kwargs"]["stop"] = kwargs["generation_kwargs"][
                    "stop_sequences"
                ]
                del kwargs["generation_kwargs"]["stop_sequences"]
            if "do_sample" in kwargs["generation_kwargs"]:
                del kwargs["generation_kwargs"]["do_sample"]
            # Ollama expects the sampling parameters nested under "options".
            kwargs["generation_kwargs"] = {"options": kwargs["generation_kwargs"]}
        llm = OllamaLLM(
            model=MODEL,
            host=OLLAMA_BASE_URL,
            tokenizer_id=TOKENIZER_ID or MODEL,
            use_magpie_template=use_magpie_template,
            **kwargs,
        )
    elif HUGGINGFACE_BASE_URL:
        # Ensure generation_kwargs exists before forcing sampling.
        kwargs.setdefault("generation_kwargs", {})["do_sample"] = True
        llm = InferenceEndpointsLLM(
            api_key=_get_next_api_key(),
            base_url=HUGGINGFACE_BASE_URL,
            tokenizer_id=TOKENIZER_ID or MODEL,
            use_magpie_template=use_magpie_template,
            **kwargs,
        )
    elif VLLM_BASE_URL:
        if "generation_kwargs" in kwargs:
            if "do_sample" in kwargs["generation_kwargs"]:
                del kwargs["generation_kwargs"]["do_sample"]
        llm = ClientvLLM(
            base_url=VLLM_BASE_URL,
            model=MODEL,
            tokenizer=TOKENIZER_ID or MODEL,
            api_key=_get_next_api_key(),
            use_magpie_template=use_magpie_template,
            **kwargs,
        )
    else:
        llm = InferenceEndpointsLLM(
            api_key=_get_next_api_key(),
            tokenizer_id=TOKENIZER_ID or MODEL,
            model_id=MODEL,
            use_magpie_template=use_magpie_template,
            **kwargs,
        )

    return llm
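
# Smoke test at import time: instantiate the configured LLM and run a single
# generation so misconfiguration surfaces immediately with a clear error.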
try:
    llm = _get_llm()
    llm.load()
    llm.generate([[{"content": "Hello, world!", "role": "user"}]])
except Exception as e:
    # Report via _get_llm_class(): if _get_llm() itself failed, `llm` is undefined.
    raise Exception(f"Error loading {_get_llm_class()}: {e}") from e
|