|
import logging |
|
from typing import Dict, List |
|
|
|
import numpy as np |
|
|
|
from llm_studio.src.datasets.text_utils import get_texts |
|
from llm_studio.src.utils.utils import PatchedAttribute |
|
|
|
# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
|
class ConversationChainHandler:
    """
    Partitions a dataset into chains of conversation rounds.

    Each chain is a list of rounds, and each round is a
    (system, prompt, answer) triplet. The shape of the chains depends on
    the DataFrame and on the configuration:

    - Without a usable parent id column (``cfg.dataset.parent_id_column`` is
      ``"None"``/``None`` or not a column of the DataFrame), every row forms
      its own single-round chain. For every ``i``-th row,
      0 <= ``i`` < len(df), the chain is ``[(system_i, prompt_i, answer_i)]``.

    - With a parent id column and ``cfg.dataset.limit_chained_samples`` set
      to False, one chain is built per row. The chain of row ``i`` starts at
      the first round of its conversation and ends at row ``i``:

          [(system_start_conversation_i,
            prompt_start_conversation_i,
            answer_start_conversation_i),
           ...,
           (system_i, prompt_i, answer_i)]

    - With a parent id column and ``cfg.dataset.limit_chained_samples`` set
      to True, only complete conversations are kept, i.e. one chain per row
      that is not referenced as a parent by any other row:

          [(system_start_conversation_i,
            prompt_start_conversation_i,
            answer_start_conversation_i),
           ...,
           (system_end_conversation_i,
            prompt_end_conversation_i,
            answer_end_conversation_i)]
    """

    def __init__(
        self,
        df,
        cfg,
    ):
        self.conversation_chain_ids = self.get_conversation_chain_ids(cfg, df)
        # ``get_answers`` and ``get_systems`` size their empty fallbacks from
        # ``self.prompts``, so prompts must be assigned before them.
        self.prompts = get_texts(df, cfg, separator="")
        self.answers = self.get_answers(df, cfg)
        self.systems = self.get_systems(cfg, df)

    def get_conversation_chain_ids(self, cfg, df):
        """
        Gets the conversation chains of ``df`` as positional row indices.

        E.g. if conversation_chain_ids = [[13, 44, 8], ...],
        then the first conversation chain consists of
        [df.iloc[13], df.iloc[44], df.iloc[8]]
        with
            - df.iloc[13] being the first round of the conversation
              (its parent id does not refer to a sample present in ``df``),
            - df.iloc[44] being the second round of the conversation,
            - df.iloc[8] being the last round of the chain. If
              limit_chained_samples is True, it is also the end of the
              conversation (no other row uses it as its parent).

        Args:
            cfg: Configuration; reads ``cfg.dataset.parent_id_column`` and
                ``cfg.dataset.limit_chained_samples``.
            df: DataFrame holding the samples.

        Returns:
            A list of chains, each a list of positional indices into ``df``
            ordered from the first round to the last round.

        Raises:
            AssertionError: if chaining is enabled but ``df`` has no
                "id" column.
            ValueError: if a parent chain exceeds 1000 hops (e.g. a cycle).
        """
        parent_id_column = cfg.dataset.parent_id_column
        if parent_id_column in ["None", None] or parent_id_column not in df.columns:
            # No chaining possible: every row is a single-round conversation.
            return [[idx] for idx in range(len(df))]

        assert "id" in df.columns, (
            f"id column required for conversation chaining, "
            f"DataFrame only has {df.columns}."
        )
        # Cast ids to the parent id dtype so that equal values compare and
        # hash equally (e.g. int ids vs. a float parent column holding NaN).
        sample_ids = df["id"].astype(df[parent_id_column].dtype).tolist()
        parent_ids = df[parent_id_column].tolist()

        # A parent id that does not refer to an existing sample is treated as
        # missing, so the referencing row starts a new conversation.
        sample_ids_set = set(sample_ids)
        parent_ids = [idx if idx in sample_ids_set else "None" for idx in parent_ids]

        id2parent_id = {
            idx: parent_id
            for idx, parent_id in zip(sample_ids, parent_ids)
            if parent_id not in [None, "None"]
            and (
                not isinstance(parent_id, float)
                or (not np.isnan(parent_id) and not np.isinf(parent_id))
            )
        }
        if cfg.dataset.limit_chained_samples:
            # Only rows that no other row points to end a conversation.
            valid_parent_ids = set(id2parent_id.values())
            conversation_end_ids = [
                idx for idx in sample_ids if idx not in valid_parent_ids
            ]
        else:
            conversation_end_ids = sample_ids
        conversation_chain_ids = [
            self.get_conversation_ids(id2parent_id, conversation_end_id)
            for conversation_end_id in conversation_end_ids
        ]

        # Translate sample ids into positional row indices.
        sample_id2idx = {sample_id: idx for idx, sample_id in enumerate(sample_ids)}
        return [
            [sample_id2idx[conversation_id] for conversation_id in conversation_ids]
            for conversation_ids in conversation_chain_ids
        ]

    def get_answers(self, df, cfg):
        """
        Returns the answer of every row as a string.

        Falls back to one empty string per prompt when
        ``cfg.dataset.answer_column`` is not a column of ``df``.
        """
        answer_column = cfg.dataset.answer_column
        if answer_column in df.columns:
            return df[answer_column].astype(str).tolist()
        return ["" for _ in range(len(self.prompts))]

    def get_systems(self, cfg, df):
        """
        Returns the system prompt of every row as a string.

        Falls back to one empty string per prompt when no system column is
        configured (``"None"`` or ``None``), or when the configured column is
        missing from ``df`` (in which case a warning is logged).
        """
        system_column = cfg.dataset.system_column
        if system_column in ["None", None]:
            return ["" for _ in range(len(self.prompts))]
        if system_column not in df.columns:
            logger.warning(
                f"System column {system_column} not found. "
                f"Disabling functionality."
            )
            return ["" for _ in range(len(self.prompts))]
        return df[system_column].astype(str).tolist()

    @staticmethod
    def get_conversation_ids(id2parent_id, end_id):
        """
        Gets the conversation chain that ends at a given conversation ID.

        Args:
            id2parent_id: A dictionary mapping each ID to its parent ID.
            end_id: The ID of the last round of the chain.

        Returns:
            A list of conversation IDs representing the conversation chain,
            ordered from the first conversation id to ``end_id``.

        Raises:
            ValueError: if more than 1000 parent hops are needed, which
                indicates a circular parent chain.
        """
        loop_counter = 0
        # Collect the chain from end to start and reverse it once at the end;
        # prepending on every hop would be quadratic in the chain length.
        reversed_chain = [end_id]
        parent_id = end_id
        while parent_id in id2parent_id:
            loop_counter += 1
            parent_id = id2parent_id[parent_id]
            reversed_chain.append(parent_id)
            if loop_counter > 1000:
                raise ValueError(
                    f"Parent chain of sample with idx {end_id} "
                    f"exceeds max loop count of 1000. "
                    f"Please ensure that parent chain is not circular."
                )
        reversed_chain.reverse()
        return reversed_chain

    def __len__(self):
        return len(self.conversation_chain_ids)

    def __iter__(self):
        # Explicit iteration protocol rather than relying on the legacy
        # ``__getitem__``/``IndexError`` fallback.
        for idx in range(len(self)):
            yield self[idx]

    def __getitem__(self, idx):
        """
        Gets a single conversation chain.

        The conversation may be:
            - a single (system, prompt, answer) round,
              if cfg.dataset.parent_id_column == "None" or
              the row has no parent and is not chained
            - a conversation starting at its first round but potentially
              ending somewhere in the middle of the conversation, if the
              conversation is chained and limit_chained_samples is False
            - always a complete conversation, if the conversation is chained
              and limit_chained_samples is True

        Returns:
            A dict with keys "prompts", "answers" and "systems", each holding
            the per-round values of the chain in order.
        """
        chain = self.conversation_chain_ids[idx]
        return {
            "prompts": [self.prompts[i] for i in chain],
            "answers": [self.answers[i] for i in chain],
            "systems": [self.systems[i] for i in chain],
        }

    def get_conversation_end_ids(self):
        """
        Gets the positional index of the last round of each conversation chain.
        """
        return [
            conversation_chain[-1] for conversation_chain in self.conversation_chain_ids
        ]
|
|
|
|
|
def get_conversation_chains(
    df, cfg, limit_chained_samples=True
) -> List[Dict[str, List[str]]]:
    """
    Builds every conversation chain of ``df`` as a plain dictionary.

    Args:
        df: DataFrame holding the samples.
        cfg: Configuration object passed to ``ConversationChainHandler``.
        limit_chained_samples: Value that ``cfg.dataset.limit_chained_samples``
            is set to via ``PatchedAttribute`` while the chains are built
            (presumably restored on exiting the ``with`` block — see
            ``PatchedAttribute``).

    Returns:
        One dict per chain with the keys "prompts", "answers" and "systems",
        each holding the per-round values of that chain.
    """
    with PatchedAttribute(cfg.dataset, "limit_chained_samples", limit_chained_samples):
        conversation_chain_handler = ConversationChainHandler(df, cfg)
        conversations = list(conversation_chain_handler)
    return conversations
|
|