guu980-dev commited on
Commit
760ce89
·
unverified ·
2 Parent(s): ee63d12 0601fa5

Merge pull request #1 from sommizzu/main

Browse files

feat: Add optimized train script with fp16 precision

Files changed (1) hide show
  1. article_base_train_fp16.py +169 -0
article_base_train_fp16.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, time, math
2
+ import pandas as pd
3
+ from datasets import Dataset
4
+ from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration, BitsAndBytesConfig, TrainingArguments, Trainer
5
+ import torch
6
+ from PIL import Image
7
+ from peft import get_peft_model, LoraConfig
8
+ import argparse
9
+
10
+
11
+ def load_custom_dataset_from_csv(csv_file, image_folder):
12
+ data = pd.read_csv(csv_file)
13
+
14
+ questions = data['question'].tolist()
15
+ images = [os.path.join(image_folder, img) for img in data['image'].tolist()]
16
+ answers = data['answer'].tolist()
17
+
18
+ return Dataset.from_dict({
19
+ 'question': questions,
20
+ 'image': images,
21
+ 'answer': answers
22
+ })
23
+
24
+
25
+ def load_custom_dataset_from_parquet(parquet_file, image_folder):
26
+ data = pd.read_parquet(parquet_file)
27
+
28
+ questions = data['question'].tolist()
29
+ images = [os.path.join(image_folder, img) for img in data['image'].tolist()]
30
+ answers = data['answer'].tolist()
31
+
32
+ return Dataset.from_dict({
33
+ 'question': questions,
34
+ 'image': images,
35
+ 'answer': answers
36
+ })
37
+
38
+
39
+ def load_dataset_by_type(metadata_type, dataset_dir, image_folder):
40
+ if metadata_type == "csv":
41
+ return load_custom_dataset_from_csv(
42
+ os.path.join(dataset_dir, 'train_samples.csv'),
43
+ image_folder
44
+ )
45
+ elif metadata_type == "parquet":
46
+ return load_custom_dataset_from_parquet(
47
+ os.path.join(dataset_dir, 'train.parquet'),
48
+ image_folder
49
+ )
50
+ else:
51
+ raise ValueError("Unsupported metadata type. Use 'csv' or 'parquet'.")
52
+
53
+
54
+ def load_model_and_args(use_qlora, model_id, device, output_dir):
55
+ if use_qlora:
56
+ bnb_config = BitsAndBytesConfig(
57
+ load_in_4bit=True,
58
+ bnb_4bit_quant_type="nf4",
59
+ bnb_4bit_compute_dtype=torch.float16 # Changed from bfloat16 to float16
60
+ )
61
+ lora_config = LoraConfig(
62
+ r=8,
63
+ target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
64
+ task_type="CAUSAL_LM"
65
+ )
66
+
67
+ model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, quantization_config=bnb_config, device_map={"": 0})
68
+ model = get_peft_model(model, lora_config)
69
+ model.print_trainable_parameters()
70
+
71
+ args = TrainingArguments(
72
+ output_dir=os.path.join(output_dir, f"{math.floor(time.time())}"),
73
+ num_train_epochs=2,
74
+ remove_unused_columns=False,
75
+ per_device_train_batch_size=1,
76
+ gradient_accumulation_steps=4,
77
+ warmup_steps=2,
78
+ learning_rate=2e-5,
79
+ weight_decay=1e-6,
80
+ logging_steps=100,
81
+ optim="adamw_hf",
82
+ save_strategy="steps",
83
+ save_steps=1000,
84
+ save_total_limit=1,
85
+ fp16=True, # Changed from bf16 to fp16
86
+ report_to=["tensorboard"],
87
+ dataloader_pin_memory=False
88
+ )
89
+
90
+ return model, args
91
+ else:
92
+ model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.float16).to(device) # Changed from bfloat16 to float16
93
+ for param in model.vision_tower.parameters():
94
+ param.requires_grad = False
95
+
96
+ for param in model.multi_modal_projector.parameters():
97
+ param.requires_grad = True
98
+
99
+ args = TrainingArguments(
100
+ output_dir=os.path.join(output_dir, f"{math.floor(time.time())}"),
101
+ num_train_epochs=2,
102
+ remove_unused_columns=False,
103
+ per_device_train_batch_size=4,
104
+ gradient_accumulation_steps=4,
105
+ warmup_steps=2,
106
+ learning_rate=2e-5,
107
+ weight_decay=1e-6,
108
+ logging_steps=100,
109
+ optim="paged_adamw_8bit",
110
+ save_strategy="steps",
111
+ save_steps=1000,
112
+ save_total_limit=1,
113
+ fp16=True, # Changed from bf16 to fp16
114
+ report_to=["tensorboard"],
115
+ dataloader_pin_memory=False
116
+ )
117
+
118
+ return model, args
119
+
120
+
121
+ def main(args):
122
+ dataset_dir = args.dataset_dir
123
+ model_id = args.model_id
124
+ output_dir = args.output_dir
125
+ metadata_type = args.metadata_type
126
+
127
+ dataset = load_dataset_by_type(metadata_type, dataset_dir, os.path.join(dataset_dir, 'images'))
128
+ train_val_split = dataset.train_test_split(test_size=0.1)
129
+
130
+ train_ds = train_val_split['train']
131
+ val_ds = train_val_split['test']
132
+
133
+ processor = PaliGemmaProcessor.from_pretrained(model_id)
134
+ device = "cuda"
135
+
136
+ model, args = load_model_and_args(args.use_qlora, model_id, device, output_dir)
137
+
138
+ def collate_fn(examples):
139
+ texts = [example["question"] for example in examples]
140
+ labels = [example['answer'] for example in examples]
141
+ images = [Image.open(example['image']).convert("RGB") for example in examples]
142
+ tokens = processor(text=texts, images=images, suffix=labels, return_tensors="pt", padding="longest")
143
+ tokens = tokens.to(torch.float16).to(device) # Changed from bfloat16 to float16
144
+ return tokens
145
+
146
+ trainer = Trainer(
147
+ model=model,
148
+ train_dataset=train_ds,
149
+ eval_dataset=val_ds,
150
+ data_collator=collate_fn,
151
+ args=args
152
+ )
153
+
154
+ trainer.train()
155
+
156
+
157
+ def parse_args():
158
+ parser = argparse.ArgumentParser(description="Train a model with custom dataset")
159
+ parser.add_argument('--dataset_dir', type=str, default='./dataset', help='Path to the folder containing the images')
160
+ parser.add_argument('--model_id', type=str, default='google/paligemma-3b-pt-224', help='Model ID to use for training')
161
+ parser.add_argument('--output_dir', type=str, default='./output', help='Directory to save the output')
162
+ parser.add_argument('--use_qlora', type=bool, default=False, help='Use QLoRA for training')
163
+ parser.add_argument('--metadata_type', type=str, default='parquet', choices=['csv', 'parquet'], help='Metadata format (csv or parquet)')
164
+ return parser.parse_args()
165
+
166
+
167
+ if __name__ == "__main__":
168
+ args = parse_args()
169
+ main(args)