# Source: H2OTest/llm_studio/src/datasets/conversation_chain_handler.py
# (Hugging Face file-viewer header: uploaded by elineve, "Upload 301 files",
# commit 07423df, 9.07 kB)
import logging
from typing import Dict, List
import numpy as np
from llm_studio.src.datasets.text_utils import get_texts
from llm_studio.src.utils.utils import PatchedAttribute
logger = logging.getLogger(__name__)
class ConversationChainHandler:
    """
    Partition a dataset into chains of conversations.

    Each chain is composed of a list of conversation rounds.
    Each round within a conversation is represented as a triplet:
    (system, prompt, answer).

    The resulting structure of the chains is conditional on
    the DataFrame's structure and configuration:

    - Without a 'parent_id' in the DataFrame, each conversation chain
      is a single round. So, for every `i`-th row in the DataFrame,
      0 <= `i` < len(df), the chain would look like:
      [(system_i, prompt_i, answer_i)]

    - With a 'parent_id' in the DataFrame and
      if `cfg.dataset.limit_chained_samples` is set to False,
      each chain encapsulates all preceding conversations
      for every `i`-th row in the DataFrame, 0 <= `i` < len(df).
      The resultant chain would take shape:
          [(system_start_conversation_i,
            prompt_start_conversation_i,
            answer_start_conversation_i),
           ...,
           (system_i, prompt_i, answer_i)]

    - With a 'parent_id' in the DataFrame and
      if `cfg.dataset.limit_chained_samples` is set to True,
      each conversation chain incorporates only full conversations.
      The chain hence condenses into:
          [(system_start_conversation_i,
            prompt_start_conversation_i,
            answer_start_conversation_i),
           ...,
           (system_end_conversation_i,
            prompt_end_conversation_i,
            answer_end_conversation_i)]
      where `i` represents complete conversations only.
    """

    def __init__(
        self,
        df,
        cfg,
    ):
        """
        Build the chain index and the per-row prompt/answer/system texts.

        Args:
            df: DataFrame holding one conversation round per row.
            cfg: Configuration object; only `cfg.dataset.*` is read here
                (and whatever `get_texts` reads from it).
        """
        # Do not set self.cfg = cfg, as ConversationChainHandler
        # will be used with PatchedAttribute context manager.
        self.conversation_chain_ids = self.get_conversation_chain_ids(cfg, df)
        # NOTE: prompts must be set before answers/systems, as both fall back
        # to len(self.prompts) when their column is missing.
        self.prompts = get_texts(df, cfg, separator="")
        self.answers = self.get_answers(df, cfg)
        self.systems = self.get_systems(cfg, df)

    def get_conversation_chain_ids(self, cfg, df):
        """
        Gets the conversation chain IDs for the given DataFrame.

        E.g. if conversation_chain_ids = [[13, 44, 8], ...],
        then the first conversation chain consists of
        [df.iloc[13], df.iloc[44], df.iloc[8]]
        with
            - df.iloc[13] denotes the first round of the conversation
            - df.iloc[44] denotes the second round of the conversation
            - df.iloc[8] denotes the end of the conversation
        if limit_chained_samples is True, df.iloc[13] will have no parent_id,
        i.e. it is the start of the conversation.

        Args:
            cfg: Configuration; reads `cfg.dataset.parent_id_column` and
                `cfg.dataset.limit_chained_samples`.
            df: DataFrame with an "id" column whenever the parent id column
                is configured and present.

        Returns:
            A list of chains, each a list of positional row indices ordered
            from the first round to the last round of the chain.

        Raises:
            AssertionError: If chaining is requested but `df` has no "id" column.
            ValueError: If a parent chain exceeds 1000 links (see
                `get_conversation_ids`).
        """
        parent_id_column = cfg.dataset.parent_id_column
        if (
            parent_id_column in ["None", None]
            # Handle case where train Dataframe has conversation chains,
            # but val Dataframe does not
            or parent_id_column not in df.columns
        ):
            # no parent id column, so each triplet (system_i, prompt_i, answer_i)
            # is a conversation chain
            return [[idx] for idx in range(len(df))]

        assert "id" in df.columns, (
            f"id column required for conversation chaining, "
            f"DataFrame only has {df.columns}."
        )
        # sample and parent ids can have any dtype, such as str, int, float, etc.
        # id column can be int, while parent_id column can be float
        # (as some values are NaN) so we cast id to the same dtype
        sample_ids = df["id"].astype(df[parent_id_column].dtype).tolist()
        parent_ids = df[parent_id_column].tolist()

        # Some datasets may include parent ids that are not in the dataset.
        sample_ids_set = set(sample_ids)
        parent_ids = [idx if idx in sample_ids_set else "None" for idx in parent_ids]

        # Keep only rows with a usable parent: drop missing markers and
        # non-finite floats (NaN / inf).
        id2parent_id = {
            idx: parent_id
            for idx, parent_id in zip(sample_ids, parent_ids)
            if parent_id not in [None, "None"]
            and (
                not isinstance(parent_id, float)
                or (not np.isnan(parent_id) and not np.isinf(parent_id))
            )
        }

        if cfg.dataset.limit_chained_samples:
            # end id == id is not a parent id of another conversation id
            valid_parent_ids = set(id2parent_id.values())
            conversation_end_ids = [
                idx for idx in sample_ids if idx not in valid_parent_ids
            ]
        else:
            conversation_end_ids = sample_ids

        conversation_chain_ids = [
            self.get_conversation_ids(id2parent_id, conversation_end_id)
            for conversation_end_id in conversation_end_ids
        ]
        # map from df["id"] to enumeration index
        dataframeid2idx = {
            sample_id: idx for idx, sample_id in enumerate(sample_ids)
        }
        return [
            [dataframeid2idx[conversation_id] for conversation_id in conversation_ids]
            for conversation_ids in conversation_chain_ids
        ]

    def get_answers(self, df, cfg):
        """
        Return one answer string per row.

        Uses `cfg.dataset.answer_column` cast to `str` when that column is in
        `df`; otherwise returns an empty string for every prompt.
        """
        answer_column = cfg.dataset.answer_column
        if answer_column in df.columns:
            return df[answer_column].astype(str).tolist()
        return ["" for _ in range(len(self.prompts))]

    def get_systems(self, cfg, df):
        """
        Return one system string per row.

        System prompts are disabled (empty strings) when
        `cfg.dataset.system_column` is "None"/None, or when the configured
        column is missing from `df` (a warning is logged in that case).
        """
        system_column = cfg.dataset.system_column
        # Accept both the "None" string and a real None, consistent with the
        # handling of parent_id_column.
        if system_column in ("None", None):
            return ["" for _ in range(len(self.prompts))]

        if system_column not in df.columns:
            logger.warning(
                "System column %s not found. Disabling functionality.",
                system_column,
            )
            return ["" for _ in range(len(self.prompts))]

        return df[system_column].astype(str).tolist()

    @staticmethod
    def get_conversation_ids(id2parent_id, end_id):
        """
        Gets the conversation chain that ends at a given conversation ID.

        Args:
            id2parent_id: A dictionary containing the mapping of IDs
                to its previous parent ID.
            end_id: The ID of the end of the conversation in the chain.

        Returns:
            A list of conversation IDs representing the conversation chain.
            The chain is ordered from the first conversation id to end_id
            in the chain.

        Raises:
            ValueError: If more than 1000 parent links are followed, which
                guards against circular parent chains.
        """
        # prevent infinite loops in case
        # of circular parent chains (dataframe issue)
        loop_counter = 0
        # Collect from end to start with O(1) appends, then reverse once,
        # instead of rebuilding the list on every prepend.
        chain_end_to_start = [end_id]
        current_id = end_id
        while current_id in id2parent_id:
            loop_counter += 1
            current_id = id2parent_id[current_id]
            chain_end_to_start.append(current_id)
            if loop_counter > 1000:
                raise ValueError(
                    f"Parent chain of sample with idx {end_id} "
                    f"exceeds max loop count of 1000. "
                    f"Please ensure that parent chain is not circular."
                )
        chain_end_to_start.reverse()
        return chain_end_to_start

    def __len__(self):
        """Return the number of conversation chains."""
        return len(self.conversation_chain_ids)

    def __getitem__(self, idx):
        """
        Gets a single conversation chain.

        The conversation may be:
        - a single (system, prompt, answer) round,
          if cfg.dataset.parent_id_column == "None" or
          there is no parent_id for the conversation
        - a conversation potentially starting somewhere in
          the middle of the conversation, if the conversation
          is chained and limit_chained_samples is set to False
        - always a complete conversation, if the conversation is chained
          and limit_chained_samples is True

        Returns:
            A dict with keys "prompts", "answers" and "systems", each a list
            with one entry per round of the chain.
        """
        chain = self.conversation_chain_ids[idx]
        return {
            "prompts": [self.prompts[i] for i in chain],
            "answers": [self.answers[i] for i in chain],
            "systems": [self.systems[i] for i in chain],
        }

    def get_conversation_end_ids(self):
        """
        Gets the end conversation IDs for each conversation chain.
        """
        return [
            conversation_chain[-1] for conversation_chain in self.conversation_chain_ids
        ]
def get_conversation_chains(
    df, cfg, limit_chained_samples=True
) -> List[Dict[str, List[str]]]:
    """
    Materialise every conversation chain of `df` as a list of dicts.

    The handler is constructed inside a `PatchedAttribute` context on
    `cfg.dataset.limit_chained_samples` (presumably a temporary override of
    that attribute with the given value -- confirm against PatchedAttribute).

    Args:
        df: DataFrame with one conversation round per row.
        cfg: Configuration object passed through to ConversationChainHandler.
        limit_chained_samples: Value handed to PatchedAttribute for
            `cfg.dataset.limit_chained_samples`.

    Returns:
        One dict per chain with the keys "prompts", "answers" and "systems".
    """
    with PatchedAttribute(cfg.dataset, "limit_chained_samples", limit_chained_samples):
        handler = ConversationChainHandler(df, cfg)

    # Index explicitly over the handler's length; this yields the same items,
    # in the same order, as iterating the handler directly.
    return [handler[chain_idx] for chain_idx in range(len(handler))]