llm_studio/src/datasets/text_causal_classification_ds.py
import logging
from typing import Any, Dict
import numpy as np
import pandas as pd
import torch
from llm_studio.src.datasets.text_causal_language_modeling_ds import (
CustomDataset as TextCausalLanguageModelingCustomDataset,
)
from llm_studio.src.utils.exceptions import LLMDataException
logger = logging.getLogger(__name__)
class CustomDataset(TextCausalLanguageModelingCustomDataset):
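    """Dataset for text causal classification problems.

    Builds on the causal language modeling dataset and additionally casts the
    answer column to integer class labels, validates them against
    ``cfg.dataset.num_classes``, and exposes them per sample as ``class_label``.
    """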
def __init__(self, df: pd.DataFrame, cfg: Any, mode: str = "train"):
super().__init__(df=df, cfg=cfg, mode=mode)
check_for_non_int_answers(cfg, df)
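        # Cast labels to plain Python ints once; the checks below validate that
        # they form a non-negative, 0-indexed label set for num_classes.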
self.answers_int = df[cfg.dataset.answer_column].astype(int).values.tolist()
        if 1 < cfg.dataset.num_classes <= max(self.answers_int):
            raise LLMDataException(
                "Number of classes must be greater than the maximum label "
                f"{max(self.answers_int)}. Please increase the setting accordingly."
            )
elif cfg.dataset.num_classes == 1 and max(self.answers_int) > 1:
raise LLMDataException(
"For binary classification, max label should be 1 but is "
f"{max(self.answers_int)}."
)
if min(self.answers_int) < 0:
raise LLMDataException(
"Labels should be non-negative but min label is "
f"{min(self.answers_int)}."
)
if (
min(self.answers_int) != 0
or max(self.answers_int) != len(set(self.answers_int)) - 1
):
logger.warning(
"Labels should start at 0 and be continuous but are "
f"{sorted(set(self.answers_int))}."
)
if cfg.dataset.parent_id_column != "None":
raise LLMDataException(
"Parent ID column is not supported for classification datasets."
)
def __getitem__(self, idx: int) -> Dict:
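        """Return the parent sample for ``idx`` enriched with its integer class label."""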
sample = super().__getitem__(idx)
sample["class_label"] = self.answers_int[idx]
return sample
def postprocess_output(self, cfg, df: pd.DataFrame, output: Dict) -> Dict:
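        """Convert raw logits to string class predictions.

        With ``num_classes == 1`` (binary case) the single logit is thresholded
        at 0.0; otherwise the argmax over the class dimension is taken.
        """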
output["logits"] = output["logits"].float()
if cfg.dataset.num_classes == 1:
preds = output["logits"]
preds = np.array((preds > 0.0)).astype(int).astype(str).reshape(-1)
else:
preds = output["logits"]
preds = (
np.array(torch.argmax(preds, dim=1)) # type: ignore[arg-type]
.astype(str)
.reshape(-1)
)
output["predicted_text"] = preds
return super().postprocess_output(cfg, df, output)
def clean_output(self, output, cfg):
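        """Return the output unchanged; class index predictions need no text cleaning."""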
return output
@classmethod
def sanity_check(cls, df: pd.DataFrame, cfg: Any, mode: str = "train"):
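        """Check that the answer column only contains integer-castable labels."""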
# TODO: Dataset import in UI is currently using text_causal_language_modeling_ds
check_for_non_int_answers(cfg, df)
def check_for_non_int_answers(cfg, df):
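    """Raise an LLMDataException if the answer column contains non-integer values."""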
answers_non_int = [
x for x in df[cfg.dataset.answer_column].values if not is_castable_to_int(x)
]
if len(answers_non_int) > 0:
raise LLMDataException(
f"Column {cfg.dataset.answer_column} contains non int items. "
f"Sample values: {answers_non_int[:5]}."
)
def is_castable_to_int(s: Any) -> bool:
    """Return True if ``s`` can be cast to an int, False otherwise."""
    try:
        int(s)
        return True
    except (TypeError, ValueError):
        # TypeError covers values such as None that int() cannot handle.
        return False