Update README.md
README.md
@@ -16,8 +16,8 @@ I used the same dataset and reformatted it to apply the ChatML template.
 The code to train this model is available on Google Colab and GitHub.
 Fine-tuning took about an hour on a Google Colab A100 GPU with 40GB VRAM.
 
-
-
+# TRAINING SPECIFICATIONS
+> LoRA configuration
 peft_config = LoraConfig(
     r=16,
     lora_alpha=16,
@@ -27,7 +27,7 @@ peft_config = LoraConfig(
     target_modules=['k_proj', 'gate_proj', 'v_proj', 'up_proj', 'q_proj', 'o_proj', 'down_proj']
 )
 
-
+> Model to fine-tune
 model = AutoModelForCausalLM.from_pretrained(
     model_name,
     torch_dtype=torch.float16,
@@ -35,14 +35,14 @@ model = AutoModelForCausalLM.from_pretrained(
 )
 model.config.use_cache = False
 
-
+> Reference model
 ref_model = AutoModelForCausalLM.from_pretrained(
     model_name,
     torch_dtype=torch.float16,
     load_in_4bit=True
 )
 
-
+> Training arguments
 training_args = TrainingArguments(
     per_device_train_batch_size=4,
     gradient_accumulation_steps=4,
@@ -59,7 +59,7 @@ training_args = TrainingArguments(
     report_to="wandb",
 )
 
-
+> Create DPO trainer
 dpo_trainer = DPOTrainer(
     model,
     ref_model,
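The hunks above only show the lines surrounding the new section labels, so several pieces of the setup (the base model name, the preference dataset, the remaining LoRA, training, and DPO arguments) are cut off. The following is a rough, self-contained sketch of what the full script might look like, assuming the TRL 0.7-era `DPOTrainer` signature that matches the positional `model, ref_model` call in the diff; every name or value marked as an assumption is not taken from this commit.

```python
# Hypothetical end-to-end sketch; values marked "assumption" are NOT in the diff.
import torch
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import DPOTrainer

model_name = "base-model-to-fine-tune"  # assumption: the base checkpoint is not named in these hunks

# Assumption: a preference dataset already reformatted to the
# "prompt" / "chosen" / "rejected" columns that DPOTrainer expects.
dataset = load_dataset("your-username/chatml-dpo-pairs", split="train")

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

# LoRA configuration (r, lora_alpha, and target_modules are from the diff; the rest are assumptions)
peft_config = LoraConfig(
    r=16,
    lora_alpha=16,
    lora_dropout=0.05,   # assumption: hidden by the hunk
    bias="none",         # assumption
    task_type="CAUSAL_LM",
    target_modules=['k_proj', 'gate_proj', 'v_proj', 'up_proj', 'q_proj', 'o_proj', 'down_proj'],
)

# Model to fine-tune
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    load_in_4bit=True,   # assumption: the matching line is hidden by the hunk
)
model.config.use_cache = False

# Frozen reference model for the DPO loss
ref_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    load_in_4bit=True,
)

# Training arguments (batch size and gradient accumulation are from the diff; the rest are assumptions)
training_args = TrainingArguments(
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=5e-5,
    max_steps=200,
    logging_steps=1,
    warmup_steps=100,
    optim="paged_adamw_32bit",
    output_dir="./results",
    report_to="wandb",
)

# Create the DPO trainer and run fine-tuning
dpo_trainer = DPOTrainer(
    model,
    ref_model,
    args=training_args,
    train_dataset=dataset,
    tokenizer=tokenizer,
    peft_config=peft_config,
    beta=0.1,                # assumption: DPO regularization strength
    max_prompt_length=1024,  # assumption
    max_length=1536,         # assumption
)
dpo_trainer.train()
```

In this older TRL API, `beta` controls how strongly the fine-tuned policy is kept close to the frozen reference model; more recent TRL releases move these settings into a `DPOConfig` and rename `tokenizer` to `processing_class`, so the exact keywords depend on the installed version.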