Update config/dtat_config.py
Browse files- config/dtat_config.py +8 -4
config/dtat_config.py
CHANGED
@@ -8,7 +8,7 @@ class DTATConfig:
|
|
8 |
self.n_layer = 12
|
9 |
self.n_head = 8 # Reduced from 12
|
10 |
self.n_embd = 512 # Reduced from 768
|
11 |
-
self.dropout = 0.
|
12 |
self.bias = True
|
13 |
|
14 |
# Sequence parameters
|
@@ -20,7 +20,7 @@ class DTATConfig:
|
|
20 |
self.min_lr = 1e-5 # Lower minimum to allow fine-tuning
|
21 |
self.warmup_iters = 367 # ~5% of max_iters (7,334 iterations)
|
22 |
self.max_iters = 7334 # Exactly 4 epochs with batch_size=24
|
23 |
-
self.weight_decay = 0.
|
24 |
self.beta1 = 0.9
|
25 |
self.beta2 = 0.95
|
26 |
self.grad_clip = 1.0
|
@@ -40,7 +40,7 @@ class DTATConfig:
|
|
40 |
|
41 |
# Sparse attention parameters
|
42 |
self.sparse_topk = 32 # Number of tokens to attend to
|
43 |
-
self.importance_dropout = 0.
|
44 |
|
45 |
# Mixed precision training
|
46 |
self.mixed_precision = True
|
@@ -48,12 +48,16 @@ class DTATConfig:
|
|
48 |
|
49 |
# Memory optimization
|
50 |
self.gradient_checkpointing = True
|
51 |
-
self.batch_size =
|
52 |
|
53 |
# System
|
54 |
self.device = 'cuda'
|
55 |
self.compile = True
|
56 |
|
|
|
|
|
|
|
|
|
57 |
# Git config for model versioning
|
58 |
self.git_name = "Your Name"
|
59 |
self.git_email = "your.email@example.com"
|
|
|
8 |
self.n_layer = 12
|
9 |
self.n_head = 8 # Reduced from 12
|
10 |
self.n_embd = 512 # Reduced from 768
|
11 |
+
self.dropout = 0.1 # Reduced for more stability
|
12 |
self.bias = True
|
13 |
|
14 |
# Sequence parameters
|
|
|
20 |
self.min_lr = 1e-5 # Lower minimum to allow fine-tuning
|
21 |
self.warmup_iters = 367 # ~5% of max_iters (7,334 iterations)
|
22 |
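self.max_iters = 7334 # Sized for exactly 4 epochs at batch_size=24; batch_size is now 32, so this covers more than 4 epochs - TODO confirm intended epoch count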
|
23 |
+
self.weight_decay = 0.1 # Reduced for more stability
|
24 |
self.beta1 = 0.9
|
25 |
self.beta2 = 0.95
|
26 |
self.grad_clip = 1.0
|
|
|
40 |
|
41 |
# Sparse attention parameters
|
42 |
self.sparse_topk = 32 # Number of tokens to attend to
|
43 |
+
self.importance_dropout = 0.1 # Reduced for more stability
|
44 |
|
45 |
# Mixed precision training
|
46 |
self.mixed_precision = True
|
|
|
48 |
|
49 |
# Memory optimization
|
50 |
self.gradient_checkpointing = True
|
51 |
+
self.batch_size = 32 # Increased for more stable gradients
|
52 |
|
53 |
# System
|
54 |
self.device = 'cuda'
|
55 |
self.compile = True
|
56 |
|
57 |
+
# Performance optimization
|
58 |
+
self.compile_model = True # Enable torch.compile (NOTE(review): overlaps with self.compile above - confirm which flag the training code reads)
|
59 |
+
self.cudnn_benchmark = True # Enable cuDNN benchmarking
|
60 |
+
|
61 |
# Git config for model versioning
|
62 |
self.git_name = "Your Name"
|
63 |
self.git_email = "your.email@example.com"
|