import math
import random

from distilabel.models import ClientvLLM, InferenceEndpointsLLM, OllamaLLM, OpenAILLM
from distilabel.steps.tasks import TextGeneration

from synthetic_dataset_generator.constants import (
    API_KEYS,
    DEFAULT_BATCH_SIZE,
    HUGGINGFACE_BASE_URL,
    MODEL,
    OLLAMA_BASE_URL,
    OPENAI_BASE_URL,
    TOKENIZER_ID,
    VLLM_BASE_URL,
)
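
# Round-robin position within API_KEYS, advanced by _get_next_api_key().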
TOKEN_INDEX = 0


def _get_next_api_key():
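    """Return the next API key, cycling through API_KEYS in round-robin order."""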
    global TOKEN_INDEX
    api_key = API_KEYS[TOKEN_INDEX % len(API_KEYS)]
    TOKEN_INDEX += 1
    return api_key


def _get_prompt_rewriter():
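    """Build and load a TextGeneration task that rewrites a prompt while keeping its structure."""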
    generation_kwargs = {
        "temperature": 1,
    }
    system_prompt = "You are a prompt rewriter. You are given a prompt and you need to rewrite it keeping the same structure but highlighting different aspects of the original without adding anything new."
    prompt_rewriter = TextGeneration(
        llm=_get_llm(generation_kwargs=generation_kwargs),
        system_prompt=system_prompt,
        use_system_prompt=True,
    )
    prompt_rewriter.load()
    return prompt_rewriter


def get_rewritten_prompts(prompt: str, num_rows: int):
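    """Return rewritten variants of `prompt`, roughly one per 100 requested rows.

    The original prompt is always kept as the first entry.
    """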
    prompt_rewriter = _get_prompt_rewriter()
    # create one rewrite request per ~100 requested rows
    inputs = [
        {"instruction": f"Original prompt: {prompt} \nRewritten prompt: "}
        for _ in range(math.floor(num_rows / 100))
    ]
    n_processed = 0
    prompt_rewrites = [prompt]
    while n_processed < num_rows:
        batch = list(
            prompt_rewriter.process(
                inputs=inputs[n_processed : n_processed + DEFAULT_BATCH_SIZE]
            )
        )
        prompt_rewrites += [entry["generation"] for entry in batch[0]]
        n_processed += DEFAULT_BATCH_SIZE
        # reseed the module-level RNG between batches
        random.seed(a=random.randint(0, 2**32 - 1))
    return prompt_rewrites


def _get_llm_class() -> str:
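    """Return the class name of the LLM that `_get_llm` would instantiate for the current configuration."""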
    if OPENAI_BASE_URL:
        return "OpenAILLM"
    elif OLLAMA_BASE_URL:
        return "OllamaLLM"
    elif HUGGINGFACE_BASE_URL:
        return "InferenceEndpointsLLM"
    elif VLLM_BASE_URL:
        return "ClientvLLM"
    else:
        return "InferenceEndpointsLLM"


def _get_llm(use_magpie_template=False, **kwargs):
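    """Instantiate the LLM that matches the configured base URL, adapting generation kwargs to each provider."""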
    if OPENAI_BASE_URL:
        # The OpenAI client expects `stop` instead of `stop_sequences` and has no
        # `do_sample` flag, so normalize the generation kwargs before building the LLM.
        if "generation_kwargs" in kwargs:
            if "stop_sequences" in kwargs["generation_kwargs"]:
                kwargs["generation_kwargs"]["stop"] = kwargs["generation_kwargs"][
                    "stop_sequences"
                ]
                del kwargs["generation_kwargs"]["stop_sequences"]
            if "do_sample" in kwargs["generation_kwargs"]:
                del kwargs["generation_kwargs"]["do_sample"]
        llm = OpenAILLM(
            model=MODEL,
            base_url=OPENAI_BASE_URL,
            api_key=_get_next_api_key(),
            **kwargs,
        )
    elif OLLAMA_BASE_URL:
        if "generation_kwargs" in kwargs:
            # Ollama expects `num_predict` and `stop` rather than `max_new_tokens`
            # and `stop_sequences`, and has no `do_sample` flag.
            if "max_new_tokens" in kwargs["generation_kwargs"]:
                kwargs["generation_kwargs"]["num_predict"] = kwargs[
                    "generation_kwargs"
                ]["max_new_tokens"]
                del kwargs["generation_kwargs"]["max_new_tokens"]
            if "stop_sequences" in kwargs["generation_kwargs"]:
                kwargs["generation_kwargs"]["stop"] = kwargs["generation_kwargs"][
                    "stop_sequences"
                ]
                del kwargs["generation_kwargs"]["stop_sequences"]
            if "do_sample" in kwargs["generation_kwargs"]:
                del kwargs["generation_kwargs"]["do_sample"]
            # OllamaLLM takes the sampling parameters nested under `options`.
            kwargs["generation_kwargs"] = {"options": kwargs["generation_kwargs"]}
        llm = OllamaLLM(
            model=MODEL,
            host=OLLAMA_BASE_URL,
            tokenizer_id=TOKENIZER_ID or MODEL,
            use_magpie_template=use_magpie_template,
            **kwargs,
        )
    elif HUGGINGFACE_BASE_URL:
        # Always sample when generating through the endpoint; guard against missing generation_kwargs.
        kwargs.setdefault("generation_kwargs", {})["do_sample"] = True
        llm = InferenceEndpointsLLM(
            api_key=_get_next_api_key(),
            base_url=HUGGINGFACE_BASE_URL,
            tokenizer_id=TOKENIZER_ID or MODEL,
            use_magpie_template=use_magpie_template,
            **kwargs,
        )
    elif VLLM_BASE_URL:
        if "generation_kwargs" in kwargs:
            # The vLLM OpenAI-compatible client has no `do_sample` flag.
            if "do_sample" in kwargs["generation_kwargs"]:
                del kwargs["generation_kwargs"]["do_sample"]
        llm = ClientvLLM(
            base_url=VLLM_BASE_URL,
            model=MODEL,
            tokenizer=TOKENIZER_ID or MODEL,
            api_key=_get_next_api_key(),
            use_magpie_template=use_magpie_template,
            **kwargs,
        )
    else:
        # No base URL configured: default to Hugging Face Inference Endpoints using MODEL as model_id.
        llm = InferenceEndpointsLLM(
            api_key=_get_next_api_key(),
            tokenizer_id=TOKENIZER_ID or MODEL,
            model_id=MODEL,
            use_magpie_template=use_magpie_template,
            **kwargs,
        )

    return llm


# Smoke test at import time: fail fast if the configured LLM cannot be loaded or queried.
try:
    llm = _get_llm()
    llm.load()
    llm.generate([[{"content": "Hello, world!", "role": "user"}]])
except Exception as e:
    raise Exception(f"Error loading {_get_llm_class()}: {e}") from e