H2OTest / llm_studio /src /utils /data_utils.py
elineve's picture
Upload 301 files
07423df
raw
history blame
23.1 kB
import logging
import math
import os
from typing import Any, DefaultDict, Dict, List, Optional, Tuple, Union, no_type_check
import networkx as nx
import numpy as np
import pandas as pd
import pyarrow.parquet as pq
import torch
from sklearn.model_selection import GroupShuffleSplit, train_test_split
from torch import distributed as dist
from torch.utils.data import DataLoader, Sampler, SequentialSampler
from llm_studio.src.datasets.conversation_chain_handler import ConversationChainHandler
from llm_studio.src.utils.exceptions import LLMDataException
from llm_studio.src.utils.gpu_utils import sync_across_processes
from llm_studio.src.utils.utils import PatchedAttribute, set_seed
# Module-level logger used by the helpers in this file.
logger = logging.getLogger(__name__)
def read_dataframe(
    path: str,
    n_rows: int = -1,
    meta_only: bool = False,
    non_missing_columns: Optional[List[str]] = None,
    verbose: bool = False,
    handling: str = "warn",
    fill_columns: Optional[List[str]] = None,
    fill_value: Any = "",
    mode: str = "",
) -> pd.DataFrame:
    """Reads a dataframe from a CSV or Parquet file.

    A companion "meta info" CSV (``__meta_info__<filename>.csv`` in the same
    directory) holding only the column names is written if it does not exist
    yet. With ``meta_only=True`` that file is read instead of the data when
    it exists.

    Args:
        path: path of the dataframe (``.csv``, ``.pq`` or ``.parquet``). A
            ``.json`` path or an empty string yields an empty dataframe.
        n_rows: number of rows to sample via ``sample_indices``; ``-1`` keeps
            all rows
        meta_only: return only meta information
        non_missing_columns: columns that cannot contain missing values; rows
            with a missing value in any of them are dropped. Columns that are
            not present in the file are ignored.
        verbose: if a warning about dropped rows should be logged
        handling: how to handle missing values; ``"error"`` raises a
            ``ValueError`` whenever rows had to be dropped
        fill_columns: columns whose missing values are filled with
            ``fill_value`` (used for empty text); columns that are not present
            in the file are ignored
        fill_value: value to fill empty columns with (used for empty text)
        mode: dataset type, used only for better exception/log information

    Returns:
        dataframe

    Raises:
        ValueError: if the file type is not supported, or if
            ``handling == "error"`` and rows with missing values were found.
    """
    non_missing_columns = [] if non_missing_columns is None else non_missing_columns
    fill_columns = [] if fill_columns is None else fill_columns

    # e.g. "dir/train.csv" -> "dir/__meta_info__train.csv.csv"
    head, tail = os.path.split(path)
    meta_info_path = os.path.join(head, "__meta_info__" + tail + ".csv")

    if meta_only and os.path.exists(meta_info_path):
        path = meta_info_path

    if path.endswith(".csv"):
        df = pd.read_csv(path, lineterminator="\n").reset_index(drop=True)
    elif path.endswith((".pq", ".parquet")):
        try:
            df = pd.read_parquet(path, engine="pyarrow").reset_index(drop=True)
        except Exception:
            # Deliberate fallback: retry with the other parquet engine.
            df = pd.read_parquet(path, engine="fastparquet").reset_index(drop=True)
    elif path.endswith(".json") or path == "":
        return pd.DataFrame()
    else:
        raise ValueError(
            f"Could not determine type of file {path}: "
            f"CSV (`.csv`) and Parquet (`.pq` and `.parquet`) are supported."
        )

    # Only fill columns that actually exist (mirrors non_missing_columns below).
    fill_columns = [col for col in fill_columns if col in df]
    if fill_columns:
        df[fill_columns] = df[fill_columns].fillna(fill_value)

    if meta_only and os.path.exists(meta_info_path):
        return df

    non_missing_columns = [col for col in non_missing_columns if col in df]
    if non_missing_columns:
        orig_size = df.shape[0]
        keep_mask = df[non_missing_columns].notna().all(axis=1)
        dropped_index = df.index[~keep_mask].tolist()
        df = df.loc[keep_mask].reset_index(drop=True)
        new_size = df.shape[0]

        if new_size < orig_size:
            if verbose:
                logger.warning(
                    f"Dropped {orig_size - new_size} rows when reading dataframe '{path}' "
                    f"due to missing values encountered in one of the following columns:"
                    f" {non_missing_columns} in the following rows: {dropped_index}"
                )

            # Raising must not depend on `verbose`: it only controls logging.
            if handling == "error":
                shown = dropped_index
                if len(shown) > 10:
                    shown = shown[:5] + ["..."] + shown[-5:]
                dropped_str = ", ".join(str(x) for x in shown)
                prefix = f"{mode} " if mode else ""
                error = (
                    f"{prefix}dataset contains {len(dropped_index)} rows with missing "
                    f"values in one of the following columns: {non_missing_columns} in "
                    f"the following rows: {dropped_str}"
                )
                # Upper-case only the first character; str.capitalize() would
                # also lowercase the column names embedded in the message.
                raise ValueError(error[0].upper() + error[1:])

    if n_rows > -1:
        df = df.iloc[sample_indices(len(df), n_indices=n_rows)]

    # create meta information dataframe if it does not exist
    if not os.path.exists(meta_info_path):
        pd.DataFrame(columns=df.columns).to_csv(meta_info_path, index=False)

    return df
def get_fill_columns(cfg: Any) -> List[str]:
    """Returns the prompt column(s) whose empty values should be filled.

    Args:
        cfg: input config; ``cfg.dataset.prompt_column`` may be a single
            column name or a list/tuple of names, or may be absent.

    Returns:
        The prompt column names as a new list, or ``[]`` when the dataset
        config defines no ``prompt_column``.
    """
    if not hasattr(cfg.dataset, "prompt_column"):
        return []
    prompt = cfg.dataset.prompt_column
    return list(prompt) if isinstance(prompt, (list, tuple)) else [prompt]
def read_dataframe_drop_missing_labels(path: str, cfg: Any) -> pd.DataFrame:
    """Reads a dataframe and normalizes its prompt/answer text columns.

    The prompt column(s) from ``cfg.dataset.prompt_column`` (a single name or
    a list/tuple of names) are passed to ``read_dataframe`` both as
    non-missing and as fill columns, then cast to ``str``. The answer column,
    if configured and present, also has missing values replaced by ``""`` and
    is cast to ``str``.

    Args:
        path: path of the dataframe
        cfg: input config

    Returns:
        The loaded dataframe.
    """
    prompt_column = cfg.dataset.prompt_column
    # Accept lists as well as tuples, consistent with get_fill_columns;
    # previously a list was wrapped as [list] and crashed downstream.
    if isinstance(prompt_column, (list, tuple)):
        input_cols = list(prompt_column)
    else:
        input_cols = [prompt_column]

    # Only the local rank 0 process logs dropped rows.
    verbose = cfg.environment._local_rank == 0
    fill_columns = get_fill_columns(cfg)
    df = read_dataframe(
        path,
        non_missing_columns=input_cols,
        verbose=verbose,
        fill_columns=fill_columns,
    )
    df[input_cols] = df[input_cols].fillna("").astype(str)

    if (
        hasattr(cfg.dataset, "answer_column")
        and cfg.dataset.answer_column in df.columns
    ):
        df[cfg.dataset.answer_column] = (
            df[cfg.dataset.answer_column].fillna("").astype(str)
        )
    return df
def is_valid_data_frame(path: str, csv_rows: int = 100) -> bool:
    """Checks whether ``path`` can be opened as a supported dataframe file.

    CSV files are validated by parsing their first ``csv_rows`` rows. Parquet
    files are validated by constructing a ``pyarrow.parquet.ParquetFile``.
    Any failure, including an unsupported extension, is logged as an error
    rather than raised.

    Args:
        path: path of the dataframe
        csv_rows: number of rows to parse when checking CSV files

    Returns:
        True if the file could be opened, False otherwise.
    """
    try:
        is_csv = path.endswith(".csv")
        is_parquet = path.endswith((".pq", ".parquet"))
        if not (is_csv or is_parquet):
            raise ValueError(
                f"Could not determine type of file {path}: "
                f"CSV (`.csv`) and Parquet (`.pq` and `.parquet`) are supported."
            )
        if is_csv:
            pd.read_csv(path, nrows=csv_rows, lineterminator="\n")
        else:
            pq.ParquetFile(path)
    except Exception as exc:
        logger.error(str(exc))
        return False
    return True
def _find_root_ids(ids: List[Any], parent_mapping: Dict[Any, Any]) -> List[Any]:
    """Maps each id to the root of its parent chain (iterative, memoized).

    An id whose parent is missing/NaN, or whose parent is not itself a known
    id, ends the chain. A cycle is broken at the first repeated node, which
    is then treated as the root, so the walk always terminates.
    """
    root_of: Dict[Any, Any] = {}
    for node in ids:
        path = []
        on_path = set()
        current = node
        while current not in root_of:
            parent = parent_mapping.get(current)
            if parent is None or pd.isna(parent) or current in on_path:
                root_of[current] = current
                break
            path.append(current)
            on_path.add(current)
            current = parent
        root = root_of[current]
        for visited in path:
            root_of[visited] = root
    return [root_of[node] for node in ids]


def sample_data(cfg: Any, df: pd.DataFrame) -> pd.DataFrame:
    """Sample data from the dataframe.

    If a parent id column is configured and ``df`` has an ``"id"`` column,
    whole conversations are sampled. Every row is assigned the root of its
    parent chain, and a ``cfg.dataset.data_sample`` fraction of the distinct
    roots (at least one) is kept. Otherwise that fraction of rows (at least
    10, capped at ``len(df)``) is sampled. Both branches use the fixed seed
    7331 and do not touch numpy's global RNG state or the input ``df``.

    Args:
        cfg: input config
        df: dataframe to sample from

    Returns:
        The sampled dataframe.
    """
    parent_id_column = cfg.dataset.parent_id_column
    if parent_id_column != "None" and "id" in df.columns:
        # Use the configured parent column instead of a hard-coded name.
        parent_mapping = df.set_index("id")[parent_id_column].to_dict()
        row_roots = pd.Series(
            _find_root_ids(df["id"].tolist(), parent_mapping), index=df.index
        )

        # Sample root_ids without replacement
        root_ids = row_roots.unique()
        n_sampled_root_ids = min(
            len(root_ids), max(1, int(len(root_ids) * cfg.dataset.data_sample))
        )
        # A local RandomState yields the same draws as seeding the global RNG
        # with 7331, without clobbering global state.
        rng = np.random.RandomState(7331)
        sampled_root_ids = rng.choice(
            root_ids, size=n_sampled_root_ids, replace=False
        )

        # Filter the dataframe to only include rows with sampled root_ids
        df = df[row_roots.isin(sampled_root_ids)].reset_index(drop=True)
    else:
        # at least 10 observations
        n = max(10, int(len(df) * cfg.dataset.data_sample))
        df = df.sample(n=min(n, len(df)), random_state=7331, replace=False)
    return df
def load_mt_bench_data(
    cfg: Any, path: str = "prompts/mt-bench/question.jsonl"
) -> pd.DataFrame:
    """Loads MT-BENCH data.

    The ``turns`` field is renamed to the (first) prompt column and only the
    first turn of each question is kept. The ``reference`` field is renamed
    to the answer column and reduced to its first entry. Questions without a
    reference get an empty string.

    Args:
        cfg: input config; ``cfg.dataset.prompt_column`` may be a single
            column name or a list/tuple of names (the first one is used)
        path: path of the MT-BENCH questions JSON-lines file

    Returns:
        MT-BENCH DataFrame
    """
    prompt_column = cfg.dataset.prompt_column
    # Indexing a plain string with [0] would yield its first character.
    if isinstance(prompt_column, (list, tuple)):
        prompt_column = prompt_column[0]
    answer_column = cfg.dataset.answer_column

    df = pd.read_json(path, lines=True)
    df = df.rename(columns={"turns": prompt_column, "reference": answer_column})
    df[prompt_column] = df[prompt_column].apply(lambda turns: turns[0])

    # If no question in the file has a reference, the column is absent.
    if answer_column not in df.columns:
        df[answer_column] = ""
    df[answer_column] = (
        df[answer_column].fillna("").apply(lambda ref: ref[0] if ref != "" else ref)
    )
    return df
def get_data(cfg: Any) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Loads, optionally subsamples and preprocesses train and validation data.

    Steps, in order:

    1. Load both splits via ``load_train_valid_data``.
    2. If ``cfg.prediction.metric_gpt_template == "mt-bench"``, replace the
       validation split with MT-BENCH questions.
    3. If ``cfg.dataset.data_sample < 1.0``, subsample the splits named in
       ``cfg.dataset.data_sample_choice``.
    4. If ``cfg.training.train_validation_data``, append the validation split
       to the training split.
    5. Run ``cfg.dataset.dataset_class.preprocess_dataframe`` on both splits.

    Args:
        cfg: input config

    Returns:
        Train and validation DataFrames, each with a fresh RangeIndex.
    """
    train_df, val_df = load_train_valid_data(cfg)

    use_mt_bench = (
        hasattr(cfg.prediction, "metric_gpt_template")
        and cfg.prediction.metric_gpt_template == "mt-bench"
    )
    if use_mt_bench:
        if cfg.environment._local_rank == 0:
            logger.info(
                "Overwriting validation data with MT-BENCH data. Please note that "
                "respective metric is an approximation and might not fully match "
                "the original implementation."
            )
        val_df = load_mt_bench_data(cfg)

    if cfg.dataset.data_sample < 1.0:
        sample_choice = cfg.dataset.data_sample_choice
        if "Train" in sample_choice:
            train_df = sample_data(cfg, train_df)
        if "Validation" in sample_choice:
            val_df = sample_data(cfg, val_df)

    # NOTE(review): when MT-BENCH replaced val_df above, this concatenation
    # adds the MT-BENCH questions to the training data — confirm intended.
    if cfg.training.train_validation_data:
        train_df = pd.concat([train_df, val_df], axis=0)

    preprocess = cfg.dataset.dataset_class.preprocess_dataframe
    train_df = preprocess(train_df, cfg, mode="train")
    val_df = preprocess(val_df, cfg, mode="validation")

    return train_df.reset_index(drop=True), val_df.reset_index(drop=True)
def merge_on_common_items(lst):
    """Merges sublists that share at least one item into groups.

    Every item of a sublist is linked to that sublist's first item in an
    undirected graph; the first item also gets a self-loop, so single-item
    sublists still appear. Sublists connected through any shared item end up
    in the same connected component. Empty sublists contribute nothing.

    Args:
        lst: iterable of sublists of hashable items

    Returns:
        One list per connected component. Item order within a group follows
        networkx's component set and is not sorted.
    """
    graph = nx.Graph()
    graph.add_edges_from(
        (group[0], member) for group in lst for member in group
    )
    return [list(component) for component in nx.connected_components(graph)]
def load_train_valid_data(cfg) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Loads the train and validation dataframes.

    Behavior depends on ``cfg.dataset.validation_strategy``:

    - ``"custom"``: reads ``cfg.dataset.train_dataframe`` and
      ``cfg.dataset.validation_dataframe`` separately.
    - ``"automatic"``: reads ``cfg.dataset.train_dataframe`` and splits it
      with ``cfg.dataset.validation_size`` as test fraction and a fixed
      ``random_state`` of 1337. When a parent id column is configured and the
      data has an ``"id"`` column, rows are grouped by (merged) conversation
      chain, so a conversation never spans both splits.

    Args:
        cfg: input config

    Returns:
        Tuple of (train dataframe, validation dataframe). In the plain
        ``train_test_split`` branch the original row index is kept (not reset).

    Raises:
        LLMDataException: if the strategy is ``"custom"`` but no validation
            dataframe is configured, or if the strategy is unknown.
    """
    if cfg.dataset.validation_strategy == "custom":
        if cfg.dataset.validation_dataframe == "None":
            raise LLMDataException(
                "No validation dataframe provided. "
                "Please provide a validation dataframe or "
                "choose a different validation strategy."
            )
        train_df = read_dataframe_drop_missing_labels(cfg.dataset.train_dataframe, cfg)
        val_df = read_dataframe_drop_missing_labels(
            cfg.dataset.validation_dataframe, cfg
        )
    elif cfg.dataset.validation_strategy == "automatic":
        if cfg.environment._local_rank == 0:
            logger.info("Setting up automatic validation split...")
        df = read_dataframe_drop_missing_labels(cfg.dataset.train_dataframe, cfg)
        if cfg.dataset.parent_id_column != "None" and "id" in df.columns:
            # split based on conversation_chain_ids
            # this ensures that all samples from the
            # same conversation are in the same fold
            # limit_chained_samples is set to True only while the handler is
            # built (presumably restored by PatchedAttribute on exit — verify).
            with PatchedAttribute(cfg.dataset, "limit_chained_samples", True):
                conversation_chain_ids = ConversationChainHandler(
                    df=df, cfg=cfg
                ).conversation_chain_ids
            # Some conversations may have the same parent id, e.g. for OASST
            # 6aa548c6-65ad-4531-9411-76173ae060a3 and
            # 2a164c2a-4f0e-45aa-8990-e7dd3b51c06b
            # have the same parent a8df94e3-cfc7-4736-9587-0ec943d0fec3
            # We need to merge those into a single group
            conversation_chain_ids = merge_on_common_items(conversation_chain_ids)
            # One group label per chain entry, in the same nesting order as the
            # flattening below, so position k in both refers to the same row.
            conversation_chain_labels = [
                i
                for i, conversation_chain_id in enumerate(conversation_chain_ids)
                for _ in conversation_chain_id
            ]
            group_shuffle_split = GroupShuffleSplit(
                test_size=cfg.dataset.validation_size, n_splits=1, random_state=1337
            )
            # NOTE(review): assumes the merged chains cover every row of df
            # exactly once, so len(groups) == len(df) — confirm.
            train_idx, val_idx = next(
                group_shuffle_split.split(df, groups=conversation_chain_labels)
            )
            # flatten conversation_chain_ids
            flattened_conversation_chain_ids = np.array(
                [
                    idx
                    for conversation_chain_id in conversation_chain_ids
                    for idx in conversation_chain_id
                ]
            )
            # The split returns positions into the flattened list; map them back
            # to positional row indices of df.
            train_df = df.iloc[flattened_conversation_chain_ids[train_idx]].reset_index(
                drop=True
            )
            val_df = df.iloc[flattened_conversation_chain_ids[val_idx]].reset_index(
                drop=True
            )
        else:
            train_df, val_df = train_test_split(
                df, test_size=cfg.dataset.validation_size, random_state=1337
            )
    else:
        raise LLMDataException("No valid validation strategy provided.")
    return train_df, val_df
def _worker_seed(worker_id: int) -> int:
    """Derives the seed for a DataLoader worker, kept within [0, 2**32).

    Uses ``PYTHONHASHSEED`` when it holds an integer. That variable may also
    be ``"random"`` or empty, in which case the first word of numpy's legacy
    RNG state is used instead.
    """
    base = None
    env_seed = os.environ.get("PYTHONHASHSEED")
    if env_seed is not None:
        try:
            base = int(env_seed)
        except ValueError:
            base = None
    if base is None:
        base = int(np.random.get_state()[1][0])  # type: ignore
    # Wrap so base + worker_id never exceeds numpy's seed range.
    return (base + worker_id) % 2**32


def worker_init_fn(worker_id: int) -> None:
    """Sets the random seed for each worker.

    Args:
        worker_id: ID of the corresponding worker
    """
    set_seed(_worker_seed(worker_id))
def get_train_dataset(train_df: pd.DataFrame, cfg: Any, verbose=True):
    """Builds the training Dataset from a dataframe.

    The dataset is constructed as
    ``cfg.dataset.dataset_class(df=train_df, cfg=cfg, mode="train")``.

    Args:
        train_df: train DataFrame
        cfg: input config
        verbose: whether to log (only on local rank 0)

    Returns:
        Train Dataset
    """
    should_log = cfg.environment._local_rank == 0 and verbose
    if should_log:
        logger.info("Loading train dataset...")
    dataset_cls = cfg.dataset.dataset_class
    return dataset_cls(df=train_df, cfg=cfg, mode="train")
def get_train_dataloader(train_ds: Any, cfg: Any, verbose=True):
    """Builds the training DataLoader.

    In distributed mode a shuffling ``DistributedSampler`` with
    ``drop_last=True`` is used. Otherwise the DataLoader shuffles on its own.
    If fewer samples than one batch are available while
    ``cfg.training.drop_last_batch`` is set, that flag is switched off on
    ``cfg`` (visible to callers) so the partial batch is kept.

    Args:
        train_ds: train Dataset; must provide ``get_train_collate_fn()``
        cfg: input config
        verbose: whether to log (only on local rank 0)

    Returns:
        Train Dataloader

    Raises:
        LLMDataException: if at most one sample is available.
    """
    env = cfg.environment
    sampler: Optional[Sampler] = None
    if env._distributed:
        sampler = torch.utils.data.distributed.DistributedSampler(
            train_ds,
            num_replicas=env._world_size,
            rank=env._local_rank,
            shuffle=True,
            seed=env._seed,
            drop_last=True,
        )
        n_available = len(sampler)
    else:
        n_available = len(train_ds)

    too_small_for_batch = n_available < cfg.training.batch_size
    if too_small_for_batch and cfg.training.drop_last_batch:
        logger.warning(
            "Training data too small when dropping last batch. Number of rows "
            "should be at least batch size multiplied by number of gpus. "
            "Forcing to keep last batch."
        )
        cfg.training.drop_last_batch = False
    if n_available <= 1:
        raise LLMDataException("Data too small to train model.")

    loader = DataLoader(
        train_ds,
        sampler=sampler,
        shuffle=sampler is None,
        batch_size=cfg.training.batch_size,
        num_workers=env.number_of_workers,
        pin_memory=True,
        collate_fn=train_ds.get_train_collate_fn(),
        drop_last=cfg.training.drop_last_batch,
        worker_init_fn=worker_init_fn,
    )

    if env._local_rank == 0 and verbose:
        logger.info(f"Number of observations in train dataset: {len(train_ds)}")
    return loader
def get_val_dataset(val_df: pd.DataFrame, cfg: Any, verbose: bool = True):
    """Builds the validation Dataset from a dataframe.

    The dataset is constructed as
    ``cfg.dataset.dataset_class(df=val_df, cfg=cfg, mode="validation")``.

    Args:
        val_df: validation DataFrame
        cfg: input config
        verbose: whether to log (only on local rank 0)

    Returns:
        Validation Dataset
    """
    if verbose and cfg.environment._local_rank == 0:
        logger.info("Loading validation dataset...")
    dataset_cls = cfg.dataset.dataset_class
    return dataset_cls(df=val_df, cfg=cfg, mode="validation")
def get_val_dataloader(val_ds: Any, cfg: Any, verbose: bool = True):
    """Builds the validation DataLoader.

    Validation data is never shuffled. With distributed inference enabled,
    each rank receives an ordered, contiguous shard via
    ``OrderedDistributedSampler``. Otherwise a ``SequentialSampler`` walks
    the whole dataset. The batch size comes from ``get_inference_batch_size``.

    Args:
        val_ds: validation Dataset; must provide ``get_validation_collate_fn()``
        cfg: input config
        verbose: whether to log (only on local rank 0)

    Returns:
        Validation Dataloader
    """
    env = cfg.environment
    sampler: Sampler
    if not (env._distributed and env._distributed_inference):
        sampler = SequentialSampler(val_ds)
    else:
        sampler = OrderedDistributedSampler(
            val_ds,
            num_replicas=env._world_size,
            rank=env._local_rank,
        )

    loader = DataLoader(
        val_ds,
        sampler=sampler,
        batch_size=get_inference_batch_size(cfg),
        num_workers=env.number_of_workers,
        pin_memory=True,
        collate_fn=val_ds.get_validation_collate_fn(),
        worker_init_fn=worker_init_fn,
    )

    if verbose and env._local_rank == 0:
        logger.info(f"Number of observations in validation dataset: {len(val_ds)}")
    return loader
@no_type_check
def cat_batches(
    data: DefaultDict[str, Union[torch.Tensor, np.ndarray]]
) -> DefaultDict[str, Union[torch.Tensor, np.ndarray]]:
    """Merges per-batch outputs into one array/tensor per key, in place.

    The first element of each list decides the handling. Zero-dimensional
    elements are stacked into a new leading dimension; otherwise elements are
    concatenated along dimension 0. Torch tensors go through ``torch.*`` and
    everything else through ``numpy``.

    Args:
        data: dict mapping keys to lists of batch outputs

    Returns:
        The same ``data`` object, with every value replaced by the merged result.
    """
    for key, chunks in data.items():
        head = chunks[0]
        is_scalar = len(head.shape) == 0
        if isinstance(head, torch.Tensor):
            merged = torch.stack(chunks) if is_scalar else torch.cat(chunks, dim=0)
        else:
            merged = np.stack(chunks) if is_scalar else np.concatenate(chunks, axis=0)
        data[key] = merged
    return data
class OrderedDistributedSampler(Sampler):
    """Sampler that gives each distributed rank a contiguous, ordered shard.

    The index list ``0 .. len(dataset) - 1`` is padded with index ``0`` until
    it divides evenly across replicas. Rank ``r`` then receives the slice
    ``[r * num_samples, (r + 1) * num_samples)``, so order is preserved.
    It is intended as the ``sampler`` of a DataLoader used together with
    :class:`torch.nn.parallel.DistributedDataParallel`.

    Source:
    https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/distributed_sampler.py
    """

    def __init__(
        self,
        dataset: Any,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
    ):
        """
        Args:
            dataset: Dataset used for sampling
            num_replicas: Number of processes participating in distributed
                training; queried from ``torch.distributed`` when None
            rank: Rank of the current process within num_replicas; queried
                from ``torch.distributed`` when None
        """
        if num_replicas is None:
            num_replicas = self._query_process_group("get_world_size")
        if rank is None:
            rank = self._query_process_group("get_rank")

        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        # Every rank gets the same count; the tail is padded in __iter__.
        self.num_samples = int(math.ceil(len(self.dataset) / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas

    @staticmethod
    def _query_process_group(getter_name: str) -> int:
        """Calls ``torch.distributed.<getter_name>()`` after checking availability."""
        if not dist.is_available():
            raise RuntimeError("Requires distributed package to be available")
        return getattr(dist, getter_name)()

    def __iter__(self):
        n_items = len(self.dataset)
        padded = list(range(n_items)) + [0] * (self.total_size - n_items)
        assert len(padded) == self.total_size

        start = self.rank * self.num_samples
        shard = padded[start : start + self.num_samples]
        assert len(shard) == self.num_samples
        return iter(shard)

    def __len__(self):
        return self.num_samples
def sample_indices(length: int, n_indices: int = 10, seed: int = 1337) -> np.ndarray:
    """Samples random indices without replacement.

    A dedicated ``RandomState`` seeded with ``seed`` produces the same draws
    as seeding numpy's global legacy RNG. Unlike the previous save/seed/restore
    approach, global RNG state is never touched, not even when sampling raises.

    Args:
        length: length to sample from (indices are drawn from ``range(length)``)
        n_indices: number of indices to sample (capped at ``length``)
        seed: seed for sampling

    Returns:
        sampled indices
    """
    rng = np.random.RandomState(seed)
    return rng.choice(np.arange(length), size=min(length, n_indices), replace=False)
def get_inference_batch_size(cfg: Any) -> int:
    """Returns the batch size to use for inference.

    ``cfg.prediction.batch_size_inference`` wins unless it is 0, in which
    case the training batch size is reused.

    Args:
        cfg: config with all the hyperparameters

    Returns:
        Inference batch size
    """
    configured = cfg.prediction.batch_size_inference
    return cfg.training.batch_size if configured == 0 else configured
def sanity_check(cfg):
    """Runs the dataset class's sanity checks on the configured data.

    The train dataframe is always checked. The validation dataframe is
    checked only when ``cfg.dataset.validation_dataframe`` is a string naming
    an existing path.
    """
    df = read_dataframe_drop_missing_labels(cfg.dataset.train_dataframe, cfg)
    cfg.dataset.dataset_class.sanity_check(df=df, cfg=cfg, mode="train")

    valid_filename = cfg.dataset.validation_dataframe
    has_validation_file = isinstance(valid_filename, str) and os.path.exists(
        valid_filename
    )
    if not has_validation_file:
        return
    df = read_dataframe_drop_missing_labels(valid_filename, cfg)
    cfg.dataset.dataset_class.sanity_check(df=df, cfg=cfg, mode="validation")
def batch_padding(
    cfg: Any,
    batch: Dict,
    training: bool = True,
    mask_key: str = "attention_mask",
    pad_keys: Optional[List[str]] = None,
    padding_side: str = "left",
) -> Dict:
    """Pads a batch according to set quantile, or cuts it at maximum length.

    Padding columns are trimmed from the tensors in ``pad_keys``, based on
    where ``batch[mask_key] == 1``:

    - During training with ``cfg.tokenizer.padding_quantile < 1``, the cut
      point is a quantile of the per-row content boundaries (synchronized
      across processes when distributed). Some rows may therefore lose tokens.
    - Otherwise the cut keeps every non-padding token of the batch.

    The batch is returned unchanged when model compilation is enabled, when
    the mask is all zeros, or when ``padding_quantile == 0``.

    Args:
        cfg: config with environment and tokenizer settings
        batch: dict of tensors; ``batch[mask_key]`` is indexed as
            (rows, positions)
        training: whether the quantile-based trimming may be applied
        mask_key: key of the mask tensor
        pad_keys: keys to trim; defaults to
            ``["input_ids", "attention_mask", "special_tokens_mask"]``.
            Keys missing from the batch are skipped.
        padding_side: ``"left"`` trims leading columns, anything else trims
            trailing columns

    Returns:
        The (in-place modified) batch dict.
    """
    if pad_keys is None:
        pad_keys = ["input_ids", "attention_mask", "special_tokens_mask"]

    if cfg.environment.compile_model:
        # Batch padding not functional with torch compile.
        return batch
    elif batch[mask_key].sum() == 0:
        # continued pretraining
        return batch
    elif cfg.tokenizer.padding_quantile == 0:
        return batch
    elif training and cfg.tokenizer.padding_quantile < 1.0:
        is_token = (batch[mask_key] == 1).int()
        if padding_side == "left":
            # Index of the first content position per row. argmax returns the
            # first maximum; an all-padding row yields 0 (keep everything)
            # instead of crashing on min() of an empty tensor.
            lengths = is_token.argmax(dim=1).float()
            quantile = 1 - cfg.tokenizer.padding_quantile
        else:
            # Index of the last content position per row; an all-padding row
            # yields the last column (keep everything).
            seq_len = is_token.size(1)
            lengths = (seq_len - 1 - is_token.flip(dims=[1]).argmax(dim=1)).float()
            quantile = cfg.tokenizer.padding_quantile
        if cfg.environment._distributed:
            lengths = sync_across_processes(
                lengths, cfg.environment._world_size
            )  # type: ignore
        idx = int(torch.floor(torch.quantile(lengths, quantile)))
    else:
        if padding_side == "left":
            idx = int(torch.where(batch[mask_key] == 1)[1].min())
        else:
            idx = int(torch.where(batch[mask_key] == 1)[1].max())

    if padding_side == "left":
        for key in pad_keys:
            if key in batch:
                batch[key] = batch[key][:, idx:].contiguous()
    else:
        # idx is the last kept column; make the slice end inclusive.
        idx += 1
        for key in pad_keys:
            if key in batch:
                batch[key] = batch[key][:, :idx].contiguous()
    return batch