from dataclasses import dataclass, field
from typing import Optional

import torch
from accelerate import PartialState
from datasets import load_dataset
from tqdm.rich import tqdm
from transformers import AutoTokenizer, HfArgumentParser, TrainingArguments

from trl import ModelConfig, SFTTrainer
from trl.trainer.utils import get_kbit_device_map, get_peft_config, get_quantization_config


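# Register tqdm's `.progress_apply` hook with pandas (rich-styled progress bars).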
tqdm.pandas()


def hh_combine(examples):
    """Concatenate prompt and chosen response into the full training text.

    Works on a single example (str fields) or a batch (list fields), matching
    how the trainer may call a formatting function.
    """
    if isinstance(examples["chosen"], str):
        return examples["prompt"] + examples["chosen"]
    elif isinstance(examples["chosen"], list):
        return list(map(str.__add__, examples["prompt"], examples["chosen"]))
    else:
        raise ValueError(f"unexpected type for examples['chosen']: {type(examples['chosen'])}")


@dataclass
class ScriptArguments:
    task_type: str = field(default="hh", metadata={"help": "which dataset format to use: 'hh' or 'tldr'"})
    dataset_name: str = field(default="timdettmers/openassistant-guanaco", metadata={"help": "the dataset name"})
    dataset_train_split: str = field(default="train", metadata={"help": "the name of the training set of the dataset"})
    dataset_eval_split: str = field(default="test", metadata={"help": "the name of the evaluation set of the dataset"})
    output_model_name: str = field(default="", metadata={"help": "the Hub repo name to push the trained model to"})
    max_seq_length: int = field(default=512, metadata={"help": "the maximum sequence length for the SFT trainer"})
    packing: bool = field(default=False, metadata={"help": "Whether to apply data packing or not during training"})
    config: Optional[str] = field(default=None, metadata={"help": "path to an optional config file"})
    gradient_checkpointing_use_reentrant: bool = field(
        default=False, metadata={"help": "Whether to apply `use_reentrant` for gradient_checkpointing"}
    )
    sanity_check: bool = field(default=False, metadata={"help": "truncate all dataset splits to 100 rows for quick debugging"})


if __name__ == "__main__":
    parser = HfArgumentParser((ScriptArguments, TrainingArguments, ModelConfig))
    args, training_args, model_config = parser.parse_args_into_dataclasses()
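    # NOTE (assumption): forward the reentrant flag to TrainingArguments so it takes
    # effect; without this, `gradient_checkpointing_use_reentrant` is parsed but ignored.
    if training_args.gradient_checkpointing:
        training_args.gradient_checkpointing_kwargs = {"use_reentrant": args.gradient_checkpointing_use_reentrant}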

    ################
    # Model & Tokenizer
    ################
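    # Resolve the dtype string from the CLI: "auto"/None pass through to
    # from_pretrained; anything else (e.g. "bfloat16") maps to the torch dtype.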
    torch_dtype = (
        model_config.torch_dtype
        if model_config.torch_dtype in ["auto", None]
        else getattr(torch, model_config.torch_dtype)
    )
    quantization_config = get_quantization_config(model_config)
    model_kwargs = dict(
        revision=model_config.model_revision,
        trust_remote_code=model_config.trust_remote_code,
        attn_implementation=model_config.attn_implementation,
        torch_dtype=torch_dtype,
        use_cache=not training_args.gradient_checkpointing,  # KV cache is incompatible with gradient checkpointing
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )
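    # NOTE: "<|padding|>" may not exist in the base vocabulary; adding it grows the
    # tokenizer, and the model's embedding matrix may need resizing to match.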
    tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path, use_fast=True)
    tokenizer.add_special_tokens({"pad_token": "<|padding|>"})

    ################
    # Dataset
    ################
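    # Load every split of the dataset; split names come from the CLI arguments.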
    datasets = load_dataset(args.dataset_name)

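    # Sanity-check mode: shrink every split to 100 rows and skip pushing to the Hub.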
    if args.sanity_check:
        for key in datasets:
            datasets[key] = datasets[key].select(range(100))

        training_args.push_to_hub = False

    train_dataset = datasets[args.dataset_train_split]
    eval_dataset = datasets[args.dataset_eval_split]

    # Alternative: materialize a "text" column up front instead of using a
    # formatting function, e.g.
    # train_dataset = train_dataset.map(lambda ex: {"text": ex["prompt"] + ex["chosen"]})

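    # Pick how the trainer gets its text: "tldr" reads a pre-concatenated column,
    # "hh" builds prompt + chosen on the fly via hh_combine.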
    if args.task_type == "tldr":
        formatting_func = None
        dataset_text_field = "query_reference_response"
    elif args.task_type == "hh":
        formatting_func = hh_combine
        dataset_text_field = None
    else:
        raise ValueError(f"unknown task_type: {args.task_type!r} (expected 'tldr' or 'hh')")

    ################
    # Training
    ################
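    # Pass the model name (not an instantiated model) so SFTTrainer builds it with
    # model_init_kwargs, letting quantization and device placement happen in the trainer.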
    trainer = SFTTrainer(
        model=model_config.model_name_or_path,
        model_init_kwargs=model_kwargs,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        max_seq_length=args.max_seq_length,
        tokenizer=tokenizer,
        packing=args.packing,
        formatting_func=formatting_func,
        dataset_text_field=dataset_text_field,
        peft_config=get_peft_config(model_config),
    )

    trainer.train()

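    # Persist the final weights (PEFT adapters or the full model) to output_dir.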
    trainer.save_model(training_args.output_dir)

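    # With PEFT, merge the adapters into the base weights before pushing, and only
    # push from the main process to avoid duplicate uploads.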
    if PartialState().is_main_process and model_config.use_peft:
        model = trainer.model.merge_and_unload()
        model.push_to_hub(args.output_model_name)
        tokenizer.push_to_hub(args.output_model_name)