alexhotti commited on
Commit
0c4bdb6
·
verified ·
1 Parent(s): c382351

Upload folder using huggingface_hub

Browse files
README.md CHANGED
@@ -2,7 +2,19 @@
2
 
3
  This is a fine-tuned version of Qwen2.5-7B-Instruct optimized for agent tasks.
4
 
 
 
 
 
5
  ## Model Performance
 
 
 
 
 
 
 
 
 
6
 
7
- - Test Accuracy: 0.7983
8
- - Train Accuracy: 0.8371
 
2
 
3
  This is a fine-tuned version of Qwen2.5-7B-Instruct optimized for agent tasks.
4
 
5
+ ## Dataset Information
6
+ - Train Dataset Size: 387 examples
7
+ - Test Dataset Size: 96 examples
8
+
9
  ## Model Performance
10
+ - Test Accuracy: 0.0000
11
+ - Train Accuracy: 0.0000
12
+
13
+ ## Training Configuration
14
+ - Base Model: Qwen/Qwen2.5-VL-7B-Instruct
15
+ - Checkpoint: checkpoint-1261
16
+ - Dataset: AgentEvalDatapointDataset
17
+ - Training Script: [train_transformer.py](train_transformer.py)
18
+ - DeepSpeed Config: [deepspeed_config.json](deepspeed_config.json)
19
 
20
+ The training configuration files are included in this model repository for reproducibility.
 
deepspeed_config.json ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "zero_optimization": {
3
+ "stage": 2,
4
+ "offload_optimizer": {
5
+ "device": "cpu",
6
+ "pin_memory": true
7
+ },
8
+ "allgather_partitions": true,
9
+ "allgather_bucket_size": 2e8,
10
+ "overlap_comm": true,
11
+ "reduce_scatter": true,
12
+ "reduce_bucket_size": 2e8,
13
+ "contiguous_gradients": true
14
+ },
15
+ "activation_checkpointing": {
16
+ "partition_activations": true,
17
+ "contiguous_memory_optimization": true,
18
+ "cpu_checkpointing": true,
19
+ "number_checkpoints": 2
20
+ },
21
+ "bf16": {
22
+ "enabled": true
23
+ },
24
+ "optimizer": {
25
+ "type": "AdamW",
26
+ "params": {
27
+ "lr": "auto",
28
+ "betas": [0.9, 0.999],
29
+ "eps": 1e-8,
30
+ "weight_decay": "auto"
31
+ }
32
+ },
33
+ "gradient_clipping": 1.0,
34
+ "train_micro_batch_size_per_gpu": 1,
35
+ "train_batch_size": 4,
36
+ "steps_per_print": 1,
37
+ "wall_clock_breakdown": false,
38
+ "zero_allow_untested_optimizer": true,
39
+ "zero_force_ds_cpu_optimizer": false,
40
+ "dump_state": true,
41
+ "verbose": true,
42
+ "gradient_accumulation_steps": 2
43
+ }
model-00001-of-00004.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:17f77bfc008ad2c2867f6e03cbffc4d59569a665f29dca986239639bda31298b
3
  size 4968243304
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c404e6b3946f71290e21ef9f1b8be781ee0ecf99c2a71811db3aca53f48c86af
3
  size 4968243304
model-00002-of-00004.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:4ee9c83e1c5851045ae4f58c6cb8b181256b4663f524f9ef077b33981d57b45c
3
  size 4991495816
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bab02b40814094336b89d2fe093082da009e0ee55a69d83be7227c2c181302a2
3
  size 4991495816
model-00003-of-00004.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:289c24e6e2ea8eca57ee2a5b687348ee590875d9974eac89409767c5a39e902f
3
  size 4932751040
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:023252833ff6762a552366ded4c613f582ec57e5049c32cf9aba234d00e5435c
3
  size 4932751040
model-00004-of-00004.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:91881873609fc477db79a60fdb55376a0302c06f85a598875f449ca000184d84
3
  size 1691924384
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8f862573f529ee441622e8368c9c1af35b4a824a7ba6a515c840afdccec0163a
3
  size 1691924384
train_transformer.py ADDED
@@ -0,0 +1,619 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gc
3
+
4
+ torch.cuda.empty_cache()
5
+ import torch.distributed
6
+ from dataset import AgentDatapointDataset
7
+ import os
8
+ import wandb
9
+ from lightning.pytorch.loggers import WandbLogger
10
+ from peft import get_peft_model, LoraConfig
11
+ from transformers import TrainerCallback
12
+
13
+ from transformers import BitsAndBytesConfig
14
+
15
+ # from unsloth import is_bf16_supported
16
+
17
+ # This version of qwen requires more vram
18
+ from transformers import Qwen2_5_VLProcessor, Qwen2_5_VLForConditionalGeneration
19
+ from trl import SFTTrainer, SFTConfig
20
+
21
+ # This version of qwen requires less vram since is uses compiled componentsand also a fused cross entropy loss
22
+ # from model import Qwen2_5_VLForConditionalGeneration
23
+ from transformers import logging as transformers_logging
24
+ import logging
25
+
26
+ logging.basicConfig(level=logging.INFO)
27
+ logger = logging.getLogger(__name__)
28
+ transformers_logging.set_verbosity_error()
29
+ import argparse
30
+ from torch.optim import AdamW
31
+ from qwen_vl_utils import process_vision_info
32
+
33
+ torch.set_float32_matmul_precision("medium")
34
+
35
+ import json
36
+
37
+
38
+ from evaluate import evaluate_model
39
+
40
+ from dataset import AgentEvalDatapointDataset, AgentDatapointDataset
41
+
42
+ # Perhaps want to add back these later
43
+ # from unsloth.models._utils import prepare_model_for_kbit_training
44
+ # from gradient_checkpointing import patch_unsloth_smart_gradient_checkpointing
45
+
46
+
47
+ def train_collate_fn(examples, processor):
48
+
49
+ texts = [
50
+ processor.apply_chat_template(example["messages"], tokenize=False)
51
+ for example in examples
52
+ ]
53
+
54
+ image_inputs = [process_vision_info(example["messages"])[0] for example in examples]
55
+
56
+ model_inputs = processor(
57
+ text=texts, images=image_inputs, return_tensors="pt", padding=True
58
+ )
59
+
60
+ labels = model_inputs["input_ids"].clone()
61
+
62
+ # mask padding tokens in labels
63
+ labels[labels == processor.tokenizer.pad_token_id] = -100
64
+
65
+ if isinstance(processor, Qwen2_5_VLProcessor):
66
+ image_tokens = [151652, 151653, 151655]
67
+ else:
68
+ image_tokens = [
69
+ processor.tokenizer.convert_tokens_to_ids(processor.image_token)
70
+ ]
71
+
72
+ # mask image token IDs in the labels
73
+ for image_token_id in image_tokens:
74
+ labels[labels == image_token_id] = -100
75
+
76
+ # Return a dictionary instead of a tuple
77
+ return {
78
+ "input_ids": model_inputs["input_ids"],
79
+ "attention_mask": model_inputs["attention_mask"],
80
+ "pixel_values": model_inputs["pixel_values"],
81
+ "image_grid_thw": model_inputs["image_grid_thw"],
82
+ "labels": labels,
83
+ }
84
+
85
+
86
+ def _wrap_fast_inference(generate, device_type, dtype, model):
87
+ # Wraps inference with bfloat16 / float16
88
+ @torch.inference_mode
89
+ def _fast_generate(*args, **kwargs):
90
+ # For num_logits_to_keep
91
+ # kwargs["num_logits_to_keep"] = 1
92
+
93
+ # Remove token_type_ids
94
+ kwargs.pop("token_type_ids", None)
95
+
96
+ # Check pad_token
97
+ model_eos_token_id = getattr(model.config, "eos_token_id", None)
98
+ if model_eos_token_id is not None and hasattr(model_eos_token_id, "__iter__"):
99
+ model_eos_token_id = model_eos_token_id[0]
100
+
101
+ kwargs["pad_token_id"] = kwargs.pop("pad_token_id", model_eos_token_id)
102
+
103
+ try:
104
+ kwargs["pixel_values"] = kwargs["pixel_values"].to(model.dtype)
105
+ except:
106
+ pass
107
+
108
+ # Autocasted
109
+ with torch.autocast(device_type=device_type, dtype=dtype):
110
+ output = generate(*args, **kwargs)
111
+ pass
112
+ return output
113
+
114
+ pass
115
+ return _fast_generate
116
+
117
+
118
+ pass
119
+
120
+
121
+ def for_inference(model):
122
+ model.gradient_checkpointing = False
123
+ model.training = False
124
+
125
+ for name, module in model.named_modules():
126
+ if hasattr(module, "gradient_checkpointing"):
127
+ module.gradient_checkpointing = False
128
+ if hasattr(module, "training"):
129
+ module.training = False
130
+ pass
131
+
132
+ dtype = model.config.torch_dtype
133
+ if type(dtype) is str:
134
+ if dtype == "float16":
135
+ dtype = torch.float16
136
+ elif dtype == "bfloat16":
137
+ dtype = torch.bfloat16
138
+ pass
139
+ device_type = model.device.type
140
+
141
+ # Wrap model.generate
142
+ if model.generate.__name__ != "_fast_generate":
143
+ model._unwrapped_old_generate = model.generate
144
+ model.generate = _wrap_fast_inference(model.generate, device_type, dtype, model)
145
+ pass
146
+
147
+ # Patch tokenizer to pad to the left
148
+ internal_model = model
149
+ while hasattr(internal_model, "model"):
150
+ if hasattr(internal_model, "_saved_temp_tokenizer"):
151
+
152
+ internal_model._saved_temp_tokenizer.tokenizer.padding_side = "left"
153
+ pass
154
+ internal_model = internal_model.model
155
+ pass
156
+ if hasattr(internal_model, "_saved_temp_tokenizer"):
157
+ internal_model._saved_temp_tokenizer.tokenizer.padding_side = "left"
158
+ pass
159
+
160
+ # Also disable training for embeddings for NEFTune
161
+ if hasattr(model, "get_input_embeddings"):
162
+ embeddings = model.get_input_embeddings()
163
+ if hasattr(embeddings, "training"):
164
+ embeddings.training = False
165
+ pass
166
+ if hasattr(model, "get_output_embeddings"):
167
+ embeddings = model.get_output_embeddings()
168
+ if hasattr(embeddings, "training"):
169
+ embeddings.training = False
170
+ pass
171
+
172
+ return model
173
+
174
+
175
+ def for_training(model, use_gradient_checkpointing=True):
176
+ model.train()
177
+ model.gradient_checkpointing = use_gradient_checkpointing
178
+ model.training = True
179
+
180
+ for name, module in model.named_modules():
181
+ if hasattr(module, "gradient_checkpointing"):
182
+ module.gradient_checkpointing = use_gradient_checkpointing
183
+ if hasattr(module, "training"):
184
+ module.training = True
185
+ pass
186
+
187
+ # Also revert model.generate
188
+ if hasattr(model, "_unwrapped_old_generate"):
189
+ model.generate = model._unwrapped_old_generate
190
+ del model._unwrapped_old_generate
191
+ pass
192
+
193
+ # Patch tokenizer to pad to the right
194
+ internal_model = model
195
+ while hasattr(internal_model, "model"):
196
+ if hasattr(internal_model, "_saved_temp_tokenizer"):
197
+ internal_model._saved_temp_tokenizer.tokenizer.padding_side = "right"
198
+ pass
199
+ internal_model = internal_model.model
200
+ pass
201
+ if hasattr(internal_model, "_saved_temp_tokenizer"):
202
+ internal_model._saved_temp_tokenizer.tokenizer.padding_side = "right"
203
+ pass
204
+
205
+ # Also re-enable training for embeddings for NEFTune
206
+ if hasattr(model, "get_input_embeddings"):
207
+ embeddings = model.get_input_embeddings()
208
+ if hasattr(embeddings, "training"):
209
+ embeddings.training = True
210
+ pass
211
+ if hasattr(model, "get_output_embeddings"):
212
+ embeddings = model.get_output_embeddings()
213
+ if hasattr(embeddings, "training"):
214
+ embeddings.training = True
215
+ pass
216
+
217
+ return model
218
+
219
+
220
+ class CustomTrainingCallback(TrainerCallback):
221
+ def __init__(self, trainer, eval_epoch_interval=2):
222
+ self.trainer = trainer
223
+ self.eval_epoch_interval = eval_epoch_interval
224
+
225
+ def on_log(self, args, state, control, logs=None, **kwargs):
226
+ """Log metrics at each logging step"""
227
+ if logs is not None:
228
+ # Ensure wandb is initialized
229
+ import wandb
230
+
231
+ if not wandb.run:
232
+ wandb.init(
233
+ project="qwen-vl-trainer",
234
+ reinit=True,
235
+ name=f"{os.environ.get('RANK', '0')}-training",
236
+ group=os.environ.get("WANDB_RUN_GROUP", None),
237
+ )
238
+
239
+ # Log all metrics from the logs dictionary
240
+ step = state.global_step if hasattr(state, "global_step") else 0
241
+
242
+ # Extract and log training metrics
243
+ log_data = {}
244
+ for key, value in logs.items():
245
+ # Prefix training metrics to differentiate from eval metrics
246
+ if key not in ["eval_loss", "epoch", "learning_rate"]:
247
+ log_data[f"train/{key}"] = value
248
+ else:
249
+ log_data[key] = value
250
+
251
+ wandb.log(log_data, step=step)
252
+
253
+ def on_epoch_end(self, args, state, control, **kwargs):
254
+ print(f"Epoch {state.epoch + 1} ended")
255
+ was_training = self.trainer.model.training
256
+ for_inference(self.trainer.model)
257
+ self.trainer.model.eval()
258
+
259
+ if (state.epoch + 1) % self.eval_epoch_interval == 0 and state.epoch > 4:
260
+ self.trainer.evaluate_step(dataset=self.trainer.eval_dataset, split="test")
261
+ self.trainer.evaluate_step(
262
+ dataset=self.trainer.train_dataset_eval, split="train"
263
+ )
264
+
265
+ if was_training:
266
+ for_training(self.trainer.model)
267
+ self.trainer.model.train()
268
+
269
+
270
+ class CustomSFTTrainer(SFTTrainer):
271
+ def __init__(
272
+ self,
273
+ model,
274
+ tokenizer,
275
+ processor,
276
+ data_collator,
277
+ train_dataset=None,
278
+ train_dataset_eval=None,
279
+ eval_dataset=None,
280
+ eval_epoch_interval=2,
281
+ args=None,
282
+ ):
283
+ # train_dataset_eval=train_dataset_eval,
284
+ # train_dataset=train_dataset,
285
+ # eval_dataset=test_dataset,
286
+ self.custom_callback = CustomTrainingCallback(
287
+ self, eval_epoch_interval=eval_epoch_interval
288
+ )
289
+ callbacks = [self.custom_callback]
290
+
291
+ super().__init__(
292
+ model=model,
293
+ tokenizer=tokenizer,
294
+ data_collator=data_collator,
295
+ train_dataset=train_dataset,
296
+ eval_dataset=eval_dataset,
297
+ callbacks=callbacks,
298
+ args=args,
299
+ )
300
+ self.eval_dataset = eval_dataset
301
+ self.train_dataset_eval = train_dataset_eval
302
+ self.state = type("State", (), {"global_step": 0})()
303
+ self.processor = processor
304
+
305
+ def evaluate_step(self, dataset, split):
306
+ print(f"Evaluating {split} dataset")
307
+ try:
308
+ device = self.model.device
309
+
310
+ # The correct signature is: evaluate_model(model, processor, dataset, split, verbose=False)
311
+ accuracy = evaluate_model(self.model, self.processor, dataset, split)
312
+
313
+ # Initialize wandb if not already initialized
314
+ import wandb
315
+
316
+ if not wandb.run:
317
+ wandb.init(
318
+ project="qwen-vl-trainer",
319
+ reinit=True,
320
+ name=f"{os.environ.get('RANK', '0')}-evaluation",
321
+ group=os.environ.get("WANDB_RUN_GROUP", None),
322
+ )
323
+
324
+ wandb.log(
325
+ {
326
+ f"{split}/accuracy": accuracy,
327
+ }
328
+ )
329
+
330
+ # Don't finish wandb here to avoid conflicts with the training process
331
+
332
+ except Exception as e:
333
+ logger.error(f"Error evaluating: {e}")
334
+ raise
335
+
336
+ def cleanup(self):
337
+ """Cleanup method to ensure wandb runs are properly closed"""
338
+ import wandb
339
+
340
+ if wandb.run:
341
+ wandb.finish()
342
+
343
+
344
+ def load_model(MODEL_ID: str, USE_QLORA: bool, training_args):
345
+
346
+ # patch_unsloth_smart_gradient_checkpointing()
347
+ # Configure more aggressive quantization
348
+ bnb_config = BitsAndBytesConfig(
349
+ load_in_4bit=True,
350
+ bnb_4bit_use_double_quant=True,
351
+ bnb_4bit_quant_type="nf4",
352
+ bnb_4bit_compute_dtype=torch.bfloat16,
353
+ )
354
+
355
+ # More aggressive LoRA config
356
+ lora_config = LoraConfig(
357
+ r=200, # Increase rank for more expressiveness
358
+ lora_alpha=50, # Higher scaling factor
359
+ lora_dropout=0.001, # Moderate dropout
360
+ bias="lora_only",
361
+ target_modules=[
362
+ "qkv_proj",
363
+ "o_proj",
364
+ "gate_up_proj",
365
+ "down_proj",
366
+ "gate_proj",
367
+ "up_proj",
368
+ "down_proj",
369
+ "fc1",
370
+ "fc2",
371
+ "mlp.0",
372
+ "mlp.2",
373
+ ],
374
+ task_type="CAUSAL_LM",
375
+ inference_mode=False,
376
+ modules_to_save=None,
377
+ )
378
+
379
+ # Clear memory before model load
380
+ torch.cuda.empty_cache()
381
+ gc.collect()
382
+
383
+ # Load DeepSpeed config
384
+ with open(training_args.deepspeed, "r") as f:
385
+ ds_config = json.load(f)
386
+
387
+ # Set is_deepspeed_zero3_enabled flag for ZeRO-3
388
+ is_deepspeed_zero3_enabled = (
389
+ ds_config.get("zero_optimization", {}).get("stage", 0) == 3
390
+ )
391
+
392
+ # Pass DeepSpeed configuration to from_pretrained
393
+ model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
394
+ MODEL_ID,
395
+ # quantization_config=bnb_config if USE_QLORA else None, # Use the config
396
+ torch_dtype=torch.bfloat16,
397
+ # device_map=None, # Let DeepSpeed handle device mapping
398
+ use_cache=False,
399
+ attn_implementation="flash_attention_2",
400
+ )
401
+
402
+ # Reset generation config to avoid warnings
403
+ from transformers import GenerationConfig
404
+
405
+ model.generation_config = GenerationConfig.from_model_config(model.config)
406
+ # Ensure no conflicting generation parameters
407
+ model.generation_config.temperature = None
408
+ model.generation_config.top_p = None
409
+ model.generation_config.top_k = None
410
+ model.generation_config.early_stopping = False
411
+
412
+ processor = Qwen2_5_VLProcessor.from_pretrained(MODEL_ID)
413
+
414
+ model.enable_input_require_grads() # unsloth added this prior to loading peft
415
+ model = get_peft_model(model, lora_config)
416
+ model.gradient_checkpointing_enable()
417
+
418
+ model.config.use_cache = False
419
+ model.config.pretraining_tp = 1
420
+
421
+ # More aggressive gradient checkpointing
422
+ model.config.gradient_checkpointing = True
423
+ model.config.use_reentrant = False
424
+ model.config.gradient_checkpointing_kwargs = {
425
+ "use_reentrant": False,
426
+ "checkpoint_every_n_layers": 1,
427
+ "offload_to_cpu": True,
428
+ }
429
+
430
+ return model, processor
431
+
432
+
433
+ def main(args):
434
+ # Set CUDA device explicitly based on local_rank
435
+ if args.local_rank != -1:
436
+ torch.cuda.set_device(args.local_rank)
437
+
438
+ # Initialize process group with the correct device
439
+ if not torch.distributed.is_initialized():
440
+ # Get world size from environment if available
441
+ world_size = int(os.environ.get("WORLD_SIZE", torch.cuda.device_count()))
442
+ rank = int(os.environ.get("RANK", args.local_rank))
443
+ print(
444
+ f"Initializing process group with rank={rank}, world_size={world_size}"
445
+ )
446
+
447
+ try:
448
+ torch.distributed.init_process_group(
449
+ backend="nccl",
450
+ init_method="env://",
451
+ world_size=world_size,
452
+ rank=rank,
453
+ )
454
+ print(f"Successfully initialized process group for rank {rank}")
455
+ except Exception as e:
456
+ print(f"Could not initialize process group: {e}")
457
+
458
+ # Remove memory management env vars that might interfere with DeepSpeed
459
+ os.environ.pop("PYTORCH_CUDA_ALLOC_CONF", None)
460
+ os.environ.pop("MAX_JOBS", None)
461
+ os.environ.pop("CUDA_LAUNCH_BLOCKING", None)
462
+
463
+ # Set up DeepSpeed config path first
464
+ ds_config_path = "deepspeed_config.json"
465
+
466
+ # Set up wandb configuration
467
+ os.environ["WANDB_MODE"] = "online"
468
+
469
+ # Create a unique timestamp for this training run
470
+ import datetime
471
+
472
+ timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
473
+ run_id = timestamp
474
+ os.environ["WANDB_RUN_GROUP"] = f"qwen_training_{run_id}"
475
+
476
+ # Create a timestamped output directory
477
+ timestamped_output_dir = os.path.join(args.output_dir, f"run_{timestamp}")
478
+ os.makedirs(timestamped_output_dir, exist_ok=True)
479
+ print(f"Model checkpoints will be saved to: {timestamped_output_dir}")
480
+
481
+ # Configure wandb properly for Trainer
482
+ os.environ["WANDB_PROJECT"] = "qwen-vl-trainer"
483
+ os.environ["WANDB_LOG_MODEL"] = "end" # Changed from "true" to "end"
484
+ os.environ["WANDB_WATCH"] = "all" # Monitor all gradients and parameters
485
+ os.environ["WANDB_NAME"] = f"run_{timestamp}_rank{os.environ.get('RANK', '0')}"
486
+
487
+ # Initialize wandb only once at the beginning for the main process
488
+ if args.local_rank <= 0: # Only initialize on rank 0 or single GPU
489
+ import wandb
490
+
491
+ wandb.init(
492
+ project="qwen-vl-trainer",
493
+ name=f"transformer_training_{timestamp}",
494
+ group=os.environ.get("WANDB_RUN_GROUP"),
495
+ # Important: we're logging the model as an artifact
496
+ settings=wandb.Settings(_disable_stats=True, _disable_meta=True),
497
+ )
498
+ # Log config information
499
+ wandb.config.update(
500
+ {
501
+ "model_id": args.model_id,
502
+ "use_qlora": args.use_qlora,
503
+ "output_dir": timestamped_output_dir,
504
+ }
505
+ )
506
+ print(f"Initialized wandb with run ID: {wandb.run.id}")
507
+
508
+ # Create SFTConfig with DeepSpeed config before loading the model
509
+ training_args = SFTConfig(
510
+ per_device_train_batch_size=1, # Equivalent to train_micro_batch_size_per_gpu
511
+ gradient_accumulation_steps=2,
512
+ logging_steps=1, # Log every step
513
+ logging_strategy="steps", # Log based on steps
514
+ log_level="info",
515
+ num_train_epochs=2000, # Set to desired number of epochs
516
+ # eval_steps=100,
517
+ bf16=True,
518
+ optim="adamw_8bit",
519
+ lr_scheduler_type="linear",
520
+ seed=3407,
521
+ output_dir=timestamped_output_dir, # Use timestamped directory
522
+ overwrite_output_dir=True,
523
+ report_to="wandb", # Explicitly report to wandb
524
+ remove_unused_columns=False,
525
+ dataset_text_field="",
526
+ dataset_kwargs={"skip_prepare_dataset": True},
527
+ dataset_num_proc=4,
528
+ max_seq_length=800000,
529
+ save_strategy="epoch",
530
+ evaluation_strategy="no",
531
+ save_total_limit=2000,
532
+ deepspeed=ds_config_path, # Pass the DeepSpeed config
533
+ )
534
+
535
+ # Dynamically set devices based on availability
536
+ num_gpus = torch.cuda.device_count()
537
+ devices = list(range(num_gpus)) if num_gpus > 0 else None
538
+
539
+ # Pass training args to load_model function
540
+ model, processor = load_model(args.model_id, args.use_qlora, training_args)
541
+ # Train dataset
542
+ train_dataset = AgentDatapointDataset(split="train")
543
+ # Eval datasets
544
+ test_dataset = AgentEvalDatapointDataset(split="test")
545
+ train_dataset_eval = AgentEvalDatapointDataset(split="train")
546
+ for_training(model)
547
+
548
+ trainer = CustomSFTTrainer(
549
+ model=model,
550
+ processor=processor,
551
+ tokenizer=processor.tokenizer,
552
+ data_collator=lambda examples: train_collate_fn(examples, processor),
553
+ train_dataset_eval=train_dataset_eval,
554
+ train_dataset=train_dataset,
555
+ eval_dataset=test_dataset,
556
+ args=training_args,
557
+ )
558
+
559
+ training_stats = trainer.train()
560
+ logger.info("Training completed.")
561
+ print(f"Training Statistics: {training_stats}")
562
+
563
+ # Save the final model explicitly with timestamp
564
+ final_model_path = os.path.join(timestamped_output_dir, "final_model")
565
+ if args.local_rank <= 0: # Only save on rank 0 or single GPU
566
+ print(f"Saving final model to {final_model_path}")
567
+ trainer.save_model(final_model_path)
568
+ print(f"Final model saved to {final_model_path}")
569
+ # Also save the processor
570
+ processor.save_pretrained(final_model_path)
571
+
572
+ # Log the final model to wandb
573
+ # import wandb
574
+ # if wandb.run:
575
+ # model_artifact = wandb.Artifact(
576
+ # name=f"model_{timestamp}",
577
+ # type="model",
578
+ # description=f"Final trained model from run {timestamp}"
579
+ # )
580
+ # model_artifact.add_dir(final_model_path)
581
+ # wandb.log_artifact(model_artifact)
582
+ # print(f"Final model logged to wandb as artifact: model_{timestamp}")
583
+ #
584
+ # print(f"Final model saved to {final_model_path}")
585
+
586
+ # Ensure proper cleanup of wandb
587
+ trainer.cleanup()
588
+
589
+ # Final cleanup for the main process
590
+ if args.local_rank <= 0: # Only finalize on rank 0 or single GPU
591
+ import wandb
592
+
593
+ if wandb.run:
594
+ print("Finalizing main wandb run...")
595
+ wandb.finish()
596
+
597
+ print("Training process completed successfully.")
598
+
599
+
600
+ if __name__ == "__main__":
601
+ parser = argparse.ArgumentParser(description="Training configuration")
602
+ parser.add_argument(
603
+ "--model_id",
604
+ type=str,
605
+ default="Qwen/Qwen2.5-VL-7B-Instruct",
606
+ help="Model ID to use",
607
+ )
608
+ parser.add_argument(
609
+ "--use_qlora", type=bool, default=True, help="Whether to use QLoRA"
610
+ )
611
+ parser.add_argument(
612
+ "--output_dir", type=str, default="checkpoints_27feb", help="Output directory"
613
+ )
614
+ # Add local_rank argument for DeepSpeed
615
+ parser.add_argument(
616
+ "--local_rank", type=int, default=-1, help="Local rank for distributed training"
617
+ )
618
+ args = parser.parse_args()
619
+ main(args)