# bling-1.4b-0.1 / generation_test_hf_script.py
# Source: Hugging Face Hub file page (author: doberst; commit 5000a4b, "Upload 2 files"; 2.38 kB)
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
def load_rag_benchmark_tester_ds():
    """Fetch the LLMWare RAG benchmark tester dataset and return its samples.

    Loads ``llmware/rag_instruct_benchmark_tester`` with
    ``datasets.load_dataset`` (described by its authors as a 200-question RAG
    benchmark), prints a short status line, and materializes the ``"train"``
    split.

    Returns:
        list: every record of the ``"train"`` split, in dataset order.
        ``run_test`` reads the ``"context"``, ``"query"`` and ``"answer"``
        fields of each record.
    """
    # Function-scope import: `datasets` is only needed when this loader runs.
    from datasets import load_dataset

    benchmark = load_dataset("llmware/rag_instruct_benchmark_tester")
    print("update: loading test dataset - ", benchmark)

    # Iterating the split yields one record per sample; collect them eagerly.
    return list(benchmark["train"])
def _build_prompt(entry):
    """Wrap one benchmark record in the "<human>: ... <bot>:" prompt packaging used in fine-tuning."""
    return "<human>: " + entry["context"] + "\n" + entry["query"] + "\n" + "<bot>:"


def _clean_output(text):
    """Strip potential fine-tuning artifacts from a decoded completion.

    First truncates at the first "<|endoftext|>" marker, then drops everything
    up to and including the first remaining "<bot>:" marker. Text without
    either marker is returned unchanged.
    """
    eot = text.find("<|endoftext|>")
    if eot > -1:
        text = text[:eot]
    bot = text.find("<bot>:")
    if bot > -1:
        text = text[bot + len("<bot>:"):]
    return text


def run_test(model_name, test_ds):
    """Generate an answer for every benchmark sample and print it next to the gold answer.

    Args:
        model_name: model identifier or path passed to
            ``AutoModelForCausalLM.from_pretrained`` / ``AutoTokenizer.from_pretrained``.
        test_ds: iterable of records providing "context", "query" and
            "answer" entries (e.g. the list from ``load_rag_benchmark_tester_ds``).

    Returns:
        int: always 0. Results are only printed, not returned.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
    # The inputs are placed on `device` below, so the model must live there too.
    # Otherwise generation fails with a device mismatch whenever CUDA is available.
    model.to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    for i, entries in enumerate(test_ds):
        new_prompt = _build_prompt(entries)
        inputs = tokenizer(new_prompt, return_tensors="pt").to(device)
        input_ids = inputs["input_ids"]
        # The generated sequence echoes the prompt; remember where new tokens start.
        start_of_output = input_ids.shape[1]

        # temperature: set at 0.3 for consistency of output
        # max_new_tokens: set at 100 - may prematurely stop a few of the summaries
        outputs = model.generate(
            input_ids,
            # pad_token_id == eos_token_id, so pass the mask explicitly rather
            # than letting generate() infer it from pad tokens.
            attention_mask=inputs.get("attention_mask"),
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=True,
            temperature=0.3,
            max_new_tokens=100,
        )

        output_only = tokenizer.decode(outputs[0][start_of_output:], skip_special_tokens=True)
        # quick/optional post-processing clean-up of potential fine-tuning artifacts
        output_only = _clean_output(output_only)

        print("\n")
        print(i, "llm_response - ", output_only)
        print(i, "gold_answer - ", entries["answer"])

    return 0
if __name__ == "__main__":
    # Script entry point: pull the benchmark samples, then run the BLING
    # 1.4B model over all of them (results are printed by run_test).
    samples = load_rag_benchmark_tester_ds()
    run_test("llmware/bling-1.4b-0.1", samples)