from dataclasses import dataclass, field
from typing import Optional

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
from trl import ORPOConfig, ORPOTrainer

# This script is built on top of the example code from the Hugging Face TRL team.
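
# Example launch (illustrative; the script filename is a placeholder and each
# flag maps to a ScriptArguments field defined below):
#   python orpo_phi2.py --model_name microsoft/phi-2 \
#       --data_name argilla/ultrafeedback-binarized-preferences-cleaned \
#       --output_dir ./phi2-orpo --per_device_train_batch_size 4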

@dataclass
class ScriptArguments:
    model_name: Optional[str] = field(default="microsoft/phi-2", metadata={"help": "the model name"})
    optim: Optional[str] = field(default="adamw_torch", metadata={"help": "the optimizer"})
    data_name: Optional[str] = field(default="argilla/ultrafeedback-binarized-preferences-cleaned", metadata={"help": "the dataset name"})
    cache_dir: Optional[str] = field(default="", metadata={"help": "the cache directory for models and datasets"})
    log_with: Optional[str] = field(default='wandb', metadata={"help": "use 'wandb' to log with wandb"})
    output_dir: Optional[str] = field(default='', metadata={"help": "the output directory"})
    learning_rate: Optional[float] = field(default=1.41e-5, metadata={"help": "the learning rate"})
    lr_scheduler_type: Optional[str] = field(default='cosine', metadata={"help": "the learning rate scheduler"})
    per_device_train_batch_size: Optional[int] = field(default=4, metadata={"help": "the per-device batch size"})
    num_train_epochs: Optional[int] = field(default=5, metadata={"help": "the number of training epochs"})
    beta: Optional[float] = field(default=0.25, metadata={"help": "weighting hyperparameter for L_OR"})
    gradient_accumulation_steps: Optional[int] = field(
        default=1, metadata={"help": "the number of gradient accumulation steps"}
    )


parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]

# ORPO training configuration: prompts are capped at 1,024 tokens and full
# sequences at 2,048, matching phi-2's context window.
config = ORPOConfig(
    output_dir=script_args.output_dir,
    max_prompt_length=1024,
    max_length=2048,
    logging_steps=100,
    save_strategy='no',
    max_completion_length=2048,
    per_device_train_batch_size=script_args.per_device_train_batch_size,
    remove_unused_columns=False,
    gradient_accumulation_steps=script_args.gradient_accumulation_steps,
    learning_rate=script_args.learning_rate,
    optim=script_args.optim,
    lr_scheduler_type=script_args.lr_scheduler_type,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={'use_reentrant': True},
    beta=script_args.beta,
    report_to=script_args.log_with,
    num_train_epochs=script_args.num_train_epochs,
    bf16=True,
    do_eval=False,
)
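
# Note: the effective batch size per optimizer step is
# per_device_train_batch_size * gradient_accumulation_steps * num_processes
# (4 * 1 * 1 = 4 with the defaults on a single GPU).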

# Load the base model in bfloat16 with FlashAttention-2 (requires a compatible
# GPU and the flash-attn package).
model = AutoModelForCausalLM.from_pretrained(script_args.model_name,
                                             cache_dir=script_args.cache_dir,
                                             attn_implementation='flash_attention_2',
                                             torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(script_args.model_name,
                                          cache_dir=script_args.cache_dir)
# phi-2 ships without a pad token, so padding reuses the EOS token.
tokenizer.pad_token_id = tokenizer.eos_token_id
# Zephyr-style chat template: every turn is tagged <|user|>/<|system|>/<|assistant|>
# and terminated by the EOS token.
tokenizer.chat_template = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n'  + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
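
# Illustrative rendering (assumption: phi-2's EOS token is <|endoftext|>).
# A single user turn with add_generation_prompt=True renders roughly as:
#
#   <|user|>
#   Explain ORPO in one sentence.<|endoftext|>
#   <|assistant|>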

def build_dataset(tokenizer):
    ds_train = load_dataset(script_args.data_name, split="train",
                            cache_dir=script_args.cache_dir)

    def chat_template_to_text(sample):
        # Each 'chosen'/'rejected' entry is a two-turn conversation
        # [user, assistant]; keep only the assistant reply as plain text.
        sample["chosen"] = [item_chosen[1]['content'] for item_chosen in sample['chosen']]
        sample["rejected"] = [item_rejected[1]['content'] for item_rejected in sample['rejected']]
        # Render each prompt with the chat template so it ends with the
        # <|assistant|> generation tag the completions are trained to follow.
        sample['prompt'] = [tokenizer.apply_chat_template([{'role': 'user', 'content': item_prompt}], tokenize=False, add_generation_prompt=True) for item_prompt in sample['prompt']]

        return sample

    ds_train = ds_train.map(chat_template_to_text, batched=True, num_proc=8)

    return ds_train

train = build_dataset(tokenizer=tokenizer)
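
# Optional sanity check (assumption: columns are plain strings after mapping):
# print({k: train[0][k][:120] for k in ('prompt', 'chosen', 'rejected')})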

trainer = ORPOTrainer(
    model=model,
    args=config,
    tokenizer=tokenizer,
    train_dataset=train,
)

trainer.train()
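
# save_strategy is 'no', so nothing is checkpointed during training; persist
# the final weights and tokenizer explicitly (a minimal sketch, assuming
# output_dir points at a writable directory).
if script_args.output_dir:
    trainer.save_model(script_args.output_dir)
    tokenizer.save_pretrained(script_args.output_dir)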