import numpy as np
import pandas as pd
import pytest
import torch
from tqdm import tqdm

from llm_studio.python_configs.text_causal_language_modeling_config import (
    ConfigNLPCausalLMTokenizer,
)
from llm_studio.python_configs.text_dpo_modeling_config import (
    ConfigDPODataset,
    ConfigProblemBase,
)
from llm_studio.src.datasets.text_dpo_modeling_ds import CustomDataset


@pytest.fixture
def df():
    return pd.DataFrame(
        {
            "prompt_column": [f"prompt {i}" for i in range(200)],
            "answer_column": [f"chosen_response {i}" for i in range(200)],
            "rejected_answer_column": [f"rejected_response {i}" for i in range(200)],
        }
    )


@pytest.fixture
def df_with_conversation_chain_ids():
    """
    Create a dataframe with conversation chain ids, e.g.:

       prompt_column      answer_column rejected_answer_column parent_id_column id
    0       prompt 1         response 1             response 1             None  1
    1       prompt 2         response 2             response 2                1  2
    2       prompt 3         response 3             response 3                2  3
    3       prompt 4         response 4             response 4                3  4
    4       prompt 5  chosen_response 5    rejected_response 5                4  5
    5       prompt 6         response 6             response 6             None  6
    """
    ids = [str(i + 1) for i in range(200)]

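    # Group the 200 ids into 40 chains of 5 turns each: within a chain, every
    # row's parent is the previous row's id, and the first row's parent is
    # "None", marking the start of the conversation.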
    parent_ids = np.array(ids, dtype=object).reshape(-1, 5)
    parent_ids[:, -1] = "None"
    parent_ids = np.roll(parent_ids, 1, 1).reshape(-1)

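    # Only the last turn of each chain (ids divisible by 5) differs between
    # the chosen and rejected columns; all earlier turns share one response.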
    chosen_responses = [
        f"chosen_response {idx}" if int(idx) % 5 == 0 else f"response {idx}"
        for idx in ids
    ]
    rejected_responses = [
        f"rejected_response {idx}" if int(idx) % 5 == 0 else f"response {idx}"
        for idx in ids
    ]
    return pd.DataFrame(
        {
            "prompt_column": [f"prompt {idx}" for idx in ids],
            "answer_column": chosen_responses,
            "rejected_answer_column": rejected_responses,
            "parent_id_column": parent_ids,
            "id": ids,
        }
    )


def test_dataset_conversation_chain_is_correct(df_with_conversation_chain_ids):
    cfg = ConfigProblemBase(
        dataset=ConfigDPODataset(
            prompt_column=("prompt_column",),
            answer_column="answer_column",
            rejected_answer_column="rejected_answer_column",
            parent_id_column="parent_id_column",
        )
    )
    dataset = CustomDataset(df_with_conversation_chain_ids, cfg, mode="train")

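    # 200 rows / 5 turns per chain = 40 chains; each handler should expose one
    # reconstructed conversation per chain, for both the chosen and the
    # rejected answer columns.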
    for idx in range(200 // 5):
        for name, conversation_chain_handler in zip(
            ["chosen", "rejected"],
            [
                dataset.conversation_chain_handler,
                dataset.conversation_chain_handler_rejected,
            ],
        ):
            input_text_dict = conversation_chain_handler[idx]
            expected = {
                "prompts": [f"prompt {i + 1}" for i in range(idx * 5, (idx + 1) * 5)],
                "answers": [
                    f"response {i + 1}" for i in range(idx * 5, (idx + 1) * 5 - 1)
                ]
                + [f"{name}_response {idx * 5 + 5}"],
                "systems": [""] * 5,
            }

            for key in expected:
                assert input_text_dict[key] == expected[key], (
                    input_text_dict[key],
                    expected[key],
                    name,
                )


def test_dataset_label_is_correct(df_with_conversation_chain_ids):
    cfg = ConfigProblemBase(
        dataset=ConfigDPODataset(
            prompt_column=("prompt_column",),
            answer_column="answer_column",
            rejected_answer_column="rejected_answer_column",
            parent_id_column="parent_id_column",
        )
    )
    dataset = CustomDataset(df_with_conversation_chain_ids, cfg, mode="train")

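    # For each chain, decode the supervised positions (labels != -100) back to
    # text: the label tokens should cover exactly the final chosen/rejected
    # response, while the prompt ids should cover the full conversation so far.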
    for idx, sample in enumerate(dataset):
        chosen_response = dataset.tokenizer.decode(
            sample["chosen_labels"][sample["chosen_labels"] != -100],
            skip_special_tokens=True,
        )
        rejected_response = dataset.tokenizer.decode(
            sample["rejected_labels"][sample["rejected_labels"] != -100],
            skip_special_tokens=True,
        )
        prompt = dataset.tokenizer.decode(
            sample["prompt_input_ids"][sample["prompt_input_ids"] != 0],
            skip_special_tokens=True,
        )

        assert (
            prompt == f"<|prompt|>prompt {idx * 5 + 1} "
            f"<|answer|> response {idx * 5 + 1} "
            f"<|prompt|>prompt {idx * 5 + 2} "
            f"<|answer|> response {idx * 5 + 2} "
            f"<|prompt|>prompt {idx * 5 + 3} "
            f"<|answer|> response {idx * 5 + 3} "
            f"<|prompt|>prompt {idx * 5 + 4} "
            f"<|answer|> response {idx * 5 + 4} "
            f"<|prompt|>prompt {idx * 5 + 5} "
            "<|answer|>"
        )
        assert chosen_response == f"chosen_response {idx * 5 + 5}"
        assert rejected_response == f"rejected_response {idx * 5 + 5}"


def test_dataloader_has_correct_keys(df):
    cfg = ConfigProblemBase(
        dataset=ConfigDPODataset(
            prompt_column=("prompt_column",),
            answer_column="answer_column",
            rejected_answer_column="rejected_answer_column",
            parent_id_column="None",
        )
    )

    dataset = CustomDataset(df, cfg, mode="train")
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True)

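    # Every full batch must have the configured batch size; only the last,
    # possibly partial batch (200 samples / 16 per batch) is exempt.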
    for idx, batch in tqdm(enumerate(dataloader), total=len(dataloader)):
        for key in batch:
            if idx != len(dataloader) - 1:
                assert batch[key].size(0) == 16, (
                    key,
                    batch[key].shape,
                )

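        # The collated batch must expose exactly the chosen/rejected/prompt
        # tensors expected for DPO training, nothing more and nothing less.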
        keys = [
            "chosen_input_ids",
            "chosen_attention_mask",
            "chosen_labels",
            "rejected_input_ids",
            "rejected_attention_mask",
            "rejected_labels",
            "prompt_input_ids",
            "prompt_attention_mask",
        ]
        assert set(batch.keys()) - set(keys) == set()
        assert set(keys) - set(batch.keys()) == set()


def test_empty_answer_dataset_throws_no_error(df):
    cfg = ConfigProblemBase(
        dataset=ConfigDPODataset(
            prompt_column=("prompt_column",),
            answer_column="answer_column",
            rejected_answer_column="rejected_answer_column",
            add_eos_token_to_answer=False,
            add_eos_token_to_prompt=False,
            add_eos_token_to_system=False,
        ),
    )

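    # Blank out one column at a time and check that every sample can still be
    # materialized without raising.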
    for column in ["prompt_column", "answer_column", "rejected_answer_column"]:
        values = df[column].values
        df[column] = ""
        dataset = CustomDataset(df, cfg, mode="train")
        [dataset[i] for i in range(len(dataset))]
        df[column] = values


@pytest.fixture
def df_single_prompt():
    prompt = """when ordering your sandstones, you select which colour scale you would want.
it could be e.g. a 100% from grey/sand mix, or 80% fra beige/yellow mixed with 20% from black/brown.
This is all lower case. Can you fix that?"""
    system = """You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can.
While performing the task think step-by-step and justify your steps."""
    answer = """When ordering your sandstones, you select which color scale you would want. It could be, for example, a 100% from grey/sand mix, or 80% from beige/yellow mixed with 20% from black/brown.
Step 1: Capitalize the first letter of the sentence.
Step 2: Correct the spelling of "color" (assuming American English usage).
Step 3: Replace ", e.g." with "for example" to clarify the sentence.
Step 4: Capitalize "a" in "100% from a grey/sand mix"
Step 5: Ensure the proper usage of words and punctuation throughout the revised sentence."""
    return pd.DataFrame(
        {
            "prompt": [prompt],
            "system": [system],
            "answer": [answer],
            "rejected_answer": ["I cannot do that."],
        }
    )


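# Helper: encode the same single-sample dataframe with the plain causal LM
# dataset pipeline, so its prompt encoding can be compared against the DPO one.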
def generate_causal_lm_model_input_ids(df):
    from llm_studio.python_configs.text_causal_language_modeling_config import (
        ConfigNLPCausalLMDataset,
    )
    from llm_studio.python_configs.text_causal_language_modeling_config import (
        ConfigProblemBase as ConfigCausalLMProblemBase,
    )
    from llm_studio.src.datasets.text_causal_language_modeling_ds import (
        CustomDataset as CausalLMCustomDataset,
    )

    cfg = ConfigCausalLMProblemBase(
        llm_backbone="h2oai/h2ogpt-4096-llama2-7b",
        dataset=ConfigNLPCausalLMDataset(
            system_column="system",
            prompt_column=("prompt",),
            answer_column="answer",
        ),
        tokenizer=ConfigNLPCausalLMTokenizer(
            max_length_prompt=256, max_length_answer=256, max_length=512
        ),
    )
    dataset = CausalLMCustomDataset(df, cfg, mode="train")
    return dataset[0]


def test_dataset_prompt_ids_are_the_same_as_for_causal_language_modeling(
    df_single_prompt,
):
    """
    The DPO dataset should produce the same prompt encoding as the causal
    language modeling dataset.
    """
    generated_text_causal_lm = generate_causal_lm_model_input_ids(df_single_prompt)

    cfg = ConfigProblemBase(
        llm_backbone="h2oai/h2ogpt-4096-llama2-7b",
        dataset=ConfigDPODataset(
            system_column="system",
            prompt_column=("prompt",),
            answer_column="answer",
            rejected_answer_column="rejected_answer",
        ),
        tokenizer=ConfigNLPCausalLMTokenizer(
            max_length_prompt=256, max_length_answer=256, max_length=512
        ),
    )
    dataset = CustomDataset(df_single_prompt, cfg, mode="train")
    generated_text = dataset[0]

    for key in ["prompt_input_ids", "prompt_attention_mask"]:
        assert torch.all(
            generated_text_causal_lm[key] == generated_text[key]
        ), f"{key} is not the same"