CHEMISTral7Bv0.3 / utils /validate_data.py
Clemspace's picture
Initial model upload
cb9e677
import argparse
import json
from pathlib import Path
from typing import Dict
from mistral_common.exceptions import (
InvalidAssistantMessageException,
InvalidFunctionCallException,
InvalidMessageStructureException,
InvalidToolSchemaException,
TokenizerException,
)
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from tqdm import tqdm
from finetune.args import TrainArgs
from finetune.data.dataset import parse_data_sources
from finetune.data.tokenize import (
ConversationFormatError,
FunctionFormatError,
MessageFormatError,
SampleType,
ToolCallFormatError,
UnrecognizedRoleError,
build_instruct_sample,
get_pretrain_sample,
tokenize,
)
# Number of GPUs that the throughput / ETA estimates assume.
NUM_GPUS = 8

# Expected throughput per model id, measured for batch_size = 32768 per GPU
# on H100. get_train_stats divides the per-GPU token count by these values
# to obtain an estimated training time in seconds.
EXPECTED_WPS = {
    "open-mistral-7b": 5_720,
    "open-mixtral-8x7b": 2_966,
    "open-mixtral-8x22b": 1_007,
}

# Size bounds for a single JSONL dataset file (checked by verify_size).
MIN_NUM_JSONL_LINES = 10
MAX_NUM_JSONL_LINES = 10**7
MIN_BYTES = 10**3
MAX_BYTES = 10 * 10**9  # roughly 10 GB
def convert_seconds_to_hms(seconds: float) -> str:
    """Format a duration given in seconds as ``HH:MM:SS``.

    Fractional seconds are truncated. Hours are not wrapped at 24, so long
    durations render with more than two hour digits (e.g. ``"100:00:00"``).
    """
    whole_hours, remainder = divmod(seconds, 3600)
    whole_minutes, remainder = divmod(remainder, 60)
    return f"{int(whole_hours):02d}:{int(whole_minutes):02d}:{int(remainder):02d}"
def verify_size(jsonl_file: Path):
    """Check that a JSONL dataset file is within the allowed size limits.

    The file is read in binary mode, so ``num_bytes`` is the true on-disk byte
    count: it does not depend on the platform's default text encoding, and
    newline translation does not alter it. Every newline-terminated line
    (plus a final unterminated one, if any) counts as one conversation.

    Args:
        jsonl_file: Path to the ``.jsonl`` dataset file (``str`` is accepted too).

    Raises:
        ValueError: If the file has fewer than ``MIN_NUM_JSONL_LINES`` or more
            than ``MAX_NUM_JSONL_LINES`` lines, or fewer than ``MIN_BYTES`` or
            more than ``MAX_BYTES`` bytes.
    """
    jsonl_file = Path(jsonl_file)
    print(f"Verifying {jsonl_file} ...")
    num_lines = 0
    num_bytes = 0
    # Binary mode: len(line) is a byte count, not a count of decoded characters.
    with jsonl_file.open("rb") as f:
        for line in f:
            num_lines += 1
            num_bytes += len(line)
    if num_lines < MIN_NUM_JSONL_LINES:
        raise ValueError(
            f"{jsonl_file} has only {num_lines} conversations which is less than the minimum amount of conversations required per dataset file: {MIN_NUM_JSONL_LINES}. Please make sure that each dataset has at least {MIN_NUM_JSONL_LINES} conversations."
        )
    elif num_bytes < MIN_BYTES:
        raise ValueError(
            f"{jsonl_file} has only {num_bytes} bytes which is less than the minimum amount of bytes required per dataset file: {MIN_BYTES}. Please make sure that each dataset has at least {MIN_BYTES} bytes."
        )
    elif num_lines > MAX_NUM_JSONL_LINES:
        raise ValueError(
            f"{jsonl_file} has {num_lines} conversations which is more than the maximum amount of conversations allowed per dataset file: {MAX_NUM_JSONL_LINES}. Please make sure that each dataset has no more than {MAX_NUM_JSONL_LINES} conversations."
        )
    elif num_bytes > MAX_BYTES:
        raise ValueError(
            f"{jsonl_file} has {num_bytes} bytes which is more than the maximum amount of bytes allowed per dataset file: {MAX_BYTES}. Please make sure that each dataset has no more than {MAX_BYTES} bytes."
        )
    print(
        f"Dataset {jsonl_file} is valid. Dataset has {num_lines} conversations amounting to a total of {num_bytes} bytes."
    )
def get_train_stats(
    num_tokens: Dict[str, int],
    datasets_proportion: Dict[str, float],
    train_args: TrainArgs,
    return_type: str,
):
    """Estimate training statistics (ETA, epochs, steps) for a data mix.

    Args:
        num_tokens: Token count per dataset path.
        datasets_proportion: Sampling weight per dataset path; each weight is
            multiplied by the total train tokens to get per-dataset train tokens
            (presumably the weights sum to 1 -- TODO confirm in parse_data_sources).
        train_args: Training config. Reads ``batch_size``, ``seq_len``,
            ``max_steps`` and ``model_id_or_path``.
        return_type: ``"expected"`` derives epochs from ``train_args.max_steps``;
            ``"recommended"`` derives ``max_steps`` for 3 epochs over the data.

    Returns:
        A dict with ``eta`` (HH:MM:SS), ``data_tokens``, ``train_tokens``,
        ``epochs``, ``max_steps`` and per-dataset token/epoch breakdowns.
        Epoch counts are reported as ``inf`` for data with zero tokens.

    Raises:
        ValueError: If a local model folder has an unrecognized ``dim``, if the
            model id has no entry in ``EXPECTED_WPS``, or if ``return_type`` is
            not one of ``"expected"`` / ``"recommended"``.
    """

    def _ratio(numerator, denominator):
        # Report infinite epochs instead of crashing when there are no tokens.
        return numerator / denominator if denominator else float("inf")

    dataset_tokens = sum(num_tokens.values())
    # Tokens consumed per optimizer step across all GPUs.
    tokens_per_step = train_args.batch_size * train_args.seq_len * NUM_GPUS

    model_path = Path(train_args.model_id_or_path)
    if model_path.is_dir():
        # Infer the model family from the local checkpoint's params.json.
        with (model_path / "params.json").open(encoding="utf-8") as f:
            params_config = json.load(f)
        dim = params_config["dim"]
        is_moe = params_config.get("moe") is not None
        if dim == 4096:
            model_id = "open-mixtral-8x7b" if is_moe else "open-mistral-7b"
        elif dim == 6144:
            model_id = "open-mixtral-8x22b"
        else:
            raise ValueError("Provided model folder seems incorrect.")
    else:
        model_id = train_args.model_id_or_path

    if model_id not in EXPECTED_WPS:
        raise ValueError(
            f"No expected throughput is known for model {model_id!r}; "
            f"expected one of {list(EXPECTED_WPS)}."
        )
    wps = EXPECTED_WPS[model_id]

    if return_type == "expected":
        max_steps = train_args.max_steps
        train_tokens = max_steps * tokens_per_step
        num_epochs = _ratio(train_tokens, dataset_tokens)
    elif return_type == "recommended":
        num_epochs = 3
        max_steps = int(dataset_tokens / tokens_per_step * num_epochs)
        train_tokens = max_steps * tokens_per_step
    else:
        raise ValueError(
            f"`return_type` is {return_type}, but has to be one of ['expected', 'recommended']"
        )

    expected_time_in_sec = train_tokens / NUM_GPUS / wps
    # Add 5min buffer for loading/init/ckpt/eval
    expected_time_in_sec += 300

    train_tokens_per_dataset = {
        k: train_tokens * v for k, v in datasets_proportion.items()
    }
    return {
        "eta": convert_seconds_to_hms(expected_time_in_sec),
        "data_tokens": dataset_tokens,
        "train_tokens": train_tokens,
        "epochs": f"{num_epochs:.2f}",
        "max_steps": max_steps,
        "data_tokens_per_dataset": {k: f"{v:.1f}" for k, v in num_tokens.items()},
        "train_tokens_per_dataset": {
            k: f"{v:.1f}" for k, v in train_tokens_per_dataset.items()
        },
        "epochs_per_dataset": {
            k: f"{_ratio(train_tokens_per_dataset[k], num_tokens[k]):.1f}"
            for k in num_tokens
        },
    }
# Report sections, in the order they appear in the error report.
_ERROR_CATEGORIES = ("yaml", "conversation", "message", "tokenization")

# Exceptions caught around build_instruct_sample and reported as format errors.
_INSTRUCT_SAMPLE_ERRORS = (
    ConversationFormatError,
    UnrecognizedRoleError,
    MessageFormatError,
    ToolCallFormatError,
    FunctionFormatError,
    InvalidAssistantMessageException,
    InvalidFunctionCallException,
    InvalidMessageStructureException,
    InvalidToolSchemaException,
)


def _classify_instruct_error(error: Exception) -> str:
    """Return the report category ("conversation" or "message") for a sample-building error."""
    if isinstance(error, (ConversationFormatError, FunctionFormatError)):
        return "conversation"
    if isinstance(error, (MessageFormatError, UnrecognizedRoleError, ToolCallFormatError)):
        return "message"
    # mistral_common validation errors (function call, message structure,
    # assistant message, tool schema) are listed as conversation errors.
    return "conversation"


def _validate_dataset(dataset: str, data_type, instruct_tokenizer):
    """Parse, build and tokenize every line of one JSONL dataset.

    Returns:
        A tuple ``(errors, correct_lines, num_tokens)``. ``errors`` maps each
        category in ``_ERROR_CATEGORIES`` to a list of messages.
        ``correct_lines`` holds the raw lines that passed every check.
        ``num_tokens`` is the total token count of those lines.
    """
    errors = {category: [] for category in _ERROR_CATEGORIES}
    correct_lines = []
    num_tokens = 0

    with open(dataset, "r", encoding="utf-8") as f:
        lines = f.readlines()

    for idx, line in tqdm(enumerate(lines), total=len(lines)):
        location = f"line {idx + 1} of dataset {dataset}"
        format_prefix = f"The data in {location} is incorrectly formatted. "

        try:
            record = json.loads(line)
        except ValueError as e:
            # JSON decode failures are listed under "YAML data load errors".
            errors["yaml"].append(format_prefix + str(e))
            continue

        if data_type == SampleType.PRETRAIN:
            # TODO(Patrick) - Get good error message
            sample = get_pretrain_sample(record)
        else:
            try:
                sample = build_instruct_sample(record)
            except _INSTRUCT_SAMPLE_ERRORS as e:
                errors[_classify_instruct_error(e)].append(format_prefix + str(e))
                continue

        try:
            tokens = tokenize(sample, instruct_tokenizer).tokens
        except TokenizerException as e:
            errors["tokenization"].append(
                f"The data in {location} could not be tokenized. " + str(e)
            )
            # The line must be neither counted nor written to the corrected file.
            continue

        correct_lines.append(line)
        num_tokens += len(tokens)

    return errors, correct_lines, num_tokens


def _format_error_report(errors) -> str:
    """Render the accumulated error messages as a human-readable report."""
    sections = (
        ("YAML data load errors", errors["yaml"]),
        ("Conversation format errors", errors["conversation"]),
        ("Message format errors", errors["message"]),
        ("Tokenization errors", errors["tokenization"]),
    )
    parts = [
        "",
        "Data error report",
        "----------------------- \n",
        "The passed datasets contain some errors as listed below. "
        "Please make sure to fix these errors in order to start training.",
    ]
    for title, messages in sections:
        parts.append(f"{title}: \n\n " + "\n".join(messages) + " \n\n")
    return "\n".join(parts) + "\n"


def main(args):
    """Validate every dataset referenced by a training YAML and report problems.

    The train split (pretrain + instruct data) is always validated. The eval
    split (instruct data only) is validated when ``eval_instruct_data`` is set.
    Errors accumulate across splits, so the report printed after the eval split
    also lists the train errors. When the train split has no errors, the
    expected training statistics are printed, and they are saved when
    ``--save_reports`` is passed.

    Args:
        args: Parsed CLI namespace with ``train_yaml``, ``error_report_txt``,
            ``train_stats_json``, ``save_reports`` and ``create_corrected``.
    """
    train_args = TrainArgs.load(args.train_yaml)
    errors = {category: [] for category in _ERROR_CATEGORIES}

    splits = [("train", train_args.data.data, train_args.data.instruct_data)]
    if train_args.data.eval_instruct_data:
        splits.append(("eval", "", train_args.data.eval_instruct_data))

    instruct_tokenizer = MistralTokenizer.v3().instruct_tokenizer

    for name, pretrain_file, instruct_file in splits:
        sources, weights = parse_data_sources(pretrain_file, instruct_file)
        data_types = [d.sample_type for d in sources]
        datasets = [str(d.path) for d in sources]
        datasets_proportion = dict(zip(datasets, weights))
        num_tokens = {k: 0 for k in datasets_proportion}

        # NOTE: size limits (verify_size) are currently not enforced here.
        for data_type, dataset in tqdm(zip(data_types, datasets), total=len(datasets)):
            print(f"Validating {dataset} ...")
            sub_errors, correct_lines, dataset_tokens = _validate_dataset(
                dataset, data_type, instruct_tokenizer
            )
            num_tokens[dataset] += dataset_tokens

            if args.create_corrected:
                if any(sub_errors.values()):
                    corrected_dataset = dataset + ".corrected"
                    with open(corrected_dataset, "w", encoding="utf-8") as f:
                        f.writelines(correct_lines)
                    print(f"Saved {corrected_dataset}.")
                else:
                    print(f"No error in {dataset} - no need to create a corrected version.")

            for category, messages in sub_errors.items():
                errors[category].extend(messages)

        is_error = any(errors.values())
        if is_error:
            error_report = _format_error_report(errors)
            if args.save_reports:
                with open(args.error_report_txt, "w", encoding="utf-8") as f:
                    f.write(error_report)
            print(error_report)
        else:
            print("No errors! Data is correctly formatted!")

        if name == "train" and not is_error:
            # Stats exist only for an error-free train split, so they are saved here.
            stats = {
                "expected": get_train_stats(
                    num_tokens, datasets_proportion, train_args, return_type="expected"
                ),
            }
            filenames = (
                f"{instruct_file}"
                if pretrain_file == ""
                else f"{instruct_file} and {pretrain_file}"
            )
            print(
                f"Stats for {filenames} \n {20 * '-'} \n {json.dumps(stats, indent=4)}"
            )
            if args.save_reports:
                with open(args.train_stats_json, "w", encoding="utf-8") as file:
                    json.dump(stats, file, indent=4)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Validate jsonl.")
    # Required: main() loads this with TrainArgs.load and cannot run without it.
    parser.add_argument(
        "--train_yaml",
        type=str,
        required=True,
        help="Path to the training YAML config (loaded with TrainArgs.load) listing the datasets to validate.",
    )
    parser.add_argument(
        "--error_report_txt",
        type=str,
        default="data_errors.txt",
        help="Path to the error report.",
    )
    parser.add_argument(
        "--train_stats_json",
        type=str,
        default="train_stats.json",
        help="Path to training statistics json file.",
    )
    parser.add_argument(
        "--save_reports", action="store_true", help="Save reports to disk"
    )
    parser.add_argument(
        "--create_corrected",
        action="store_true",
        help="Skip faulty lines and append all correct lines to `.corrected` datasets.",
    )
    args = parser.parse_args()
    main(args)