from dataclasses import dataclass, fields
from typing import Optional


PRECISION_TO_BYTES = {"float32": 4,
                      "fp32": 4,
                      "float16": 2,
                      "fp16": 2,
                      "bfloat16": 2,
                      "bf16": 2,
                      "int8": 1,
                      "int4": 0.5}


@dataclass
class ModelConfig:
    model_size: float
    hidden_size: int
    sequence_length: int
    num_layers: int
    num_heads: int
    mixed_precision: bool = False
    precision: str = "bf16"
    repo_id: Optional[str] = None

    def overwrite_with_hf_config(self, config: dict):
        # Derive the architecture fields from a Hugging Face config dict;
        # model_size is stored in billions of parameters
        self.model_size = round(get_model_size_from_config(config) / 10**9, 2)
        self.hidden_size = config["hidden_size"]
        self.sequence_length = config["max_position_embeddings"]
        self.num_layers = config["num_hidden_layers"]
        self.num_heads = config["num_attention_heads"]
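
    # Example (sketch): populating a ModelConfig from a Hugging Face config dict.
    # Assumes the `transformers` package is available; it is not required by
    # this module, so the example is kept as a comment.
    #
    #   from transformers import AutoConfig
    #   repo_id = "meta-llama/Meta-Llama-3-8B"          # hypothetical choice of checkpoint
    #   hf_config = AutoConfig.from_pretrained(repo_id).to_dict()
    #   model_config = ModelConfig(model_size=0, hidden_size=0, sequence_length=0,
    #                              num_layers=0, num_heads=0, repo_id=repo_id)
    #   model_config.overwrite_with_hf_config(hf_config)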

@dataclass
class TrainingConfig:
    micro_batch_size: int
    num_gpus: int
    optimizer: str
    zero_stage: int
    qlora: bool = False
    gradient_checkpointing: bool = False

# Utility function to filter params based on dataclass fields
def filter_params_for_dataclass(dataclass_type, params):
    return {field.name: params[field.name] for field in fields(dataclass_type) if field.name in params}
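
# Example (sketch): build a TrainingConfig from a flat, user-supplied dict,
# silently dropping keys the dataclass does not define. The dict below is
# purely illustrative.
#
#   raw = {"micro_batch_size": 1, "num_gpus": 8, "optimizer": "adamw",
#          "zero_stage": 2, "learning_rate": 3e-4}
#   training_config = TrainingConfig(**filter_params_for_dataclass(TrainingConfig, raw))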

def get_model_size_from_config(config: dict):
    # Embedding parameters:
    embedding_params = config["vocab_size"] * config["hidden_size"]

    # Transformer layer parameters
    def transformer_layer_params(hidden_size, intermediate_size, num_attention_heads, num_key_value_heads):
        head_dim = hidden_size // num_attention_heads
        input_layernorm_params = hidden_size
        mlp_down_proj_params = hidden_size * intermediate_size
        mlp_gate_proj_params = intermediate_size * hidden_size
        mlp_up_proj_params = intermediate_size * hidden_size
        post_attention_layernorm_params = hidden_size
        # K and V projections map hidden_size -> num_key_value_heads * head_dim,
        # which is smaller than hidden_size when grouped-query attention is used
        self_attn_k_proj_params = hidden_size * num_key_value_heads * head_dim
        self_attn_v_proj_params = hidden_size * num_key_value_heads * head_dim
        self_attn_q_proj_params = hidden_size * hidden_size
        self_attn_o_proj_params = hidden_size * hidden_size

        total_layer_params = (
            input_layernorm_params + mlp_down_proj_params + mlp_gate_proj_params + mlp_up_proj_params +
            post_attention_layernorm_params + self_attn_k_proj_params + self_attn_o_proj_params +
            self_attn_q_proj_params + self_attn_v_proj_params
        )

        return total_layer_params

    # Total parameters for all transformer layers
    single_layer_params = transformer_layer_params(config["hidden_size"], config["intermediate_size"],
                                                   config["num_attention_heads"], config["num_key_value_heads"])
    total_transformer_params = config["num_hidden_layers"] * single_layer_params
    
    # Output layer parameters
    output_params = config["vocab_size"] * config["hidden_size"]
    
    # Total parameters
    total_params = embedding_params + total_transformer_params + output_params
    return total_params

def model_memory(parameters, precision = "bf16", mixed_precision = False):
    # Mixed precision keeps an fp32 master copy of the weights plus a
    # half-precision working copy -> 4 + 2 = 6 bytes per model parameter
    if mixed_precision:
        return parameters * (PRECISION_TO_BYTES["fp32"] + PRECISION_TO_BYTES["fp16"])
    return parameters * PRECISION_TO_BYTES[precision]
    

def gradients_memory(parameters, precision = "fp32"):
    return parameters * PRECISION_TO_BYTES[precision]

def optimizer_memory(parameters, optimizer= "adamw", precision = "fp32"):
    optimizer_choices = {"adam": 3,        # Adam: stores fp32 copies of the parameters, momentum, and variance -> 4 + 4 + 4 = 12 bytes per model parameter
                         "adamw": 3,       # AdamW: same as Adam
                         "sgd": 2,         # SGD: stores fp32 copies of the parameters and momentum -> 4 + 4 = 8 bytes per model parameter
                         "adam-8bit": 1.5, # Adam 8-bit: same states as Adam, kept in 8-bit -> 2 + 2 + 2 = 6 bytes per model parameter
                         }
    return optimizer_choices[optimizer] * parameters * PRECISION_TO_BYTES[precision]

def activations_memory(num_layers, sequence_length, micro_batch_size, hidden_size, num_heads):
    # Reference: https://arxiv.org/pdf/2205.05198
    # Activations are assumed to be stored in 16-bit floating point; returns GB
    bytes_per_layer = sequence_length * micro_batch_size * hidden_size * (34 + 5 * (num_heads * sequence_length / hidden_size))
    bytes_model = bytes_per_layer * num_layers
    return bytes_model / 10**9
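

# --- Illustrative usage sketch ---
# The config values below are hypothetical (Llama-3-8B-style) and only show how
# the helpers combine into a per-GPU estimate without ZeRO sharding. Note the
# unit difference: model/gradients/optimizer helpers return bytes, while
# activations_memory already returns GB.
if __name__ == "__main__":
    hf_config = {
        "vocab_size": 128256,
        "hidden_size": 4096,
        "intermediate_size": 14336,
        "num_hidden_layers": 32,
        "num_attention_heads": 32,
        "num_key_value_heads": 8,
        "max_position_embeddings": 8192,
    }
    params = get_model_size_from_config(hf_config)

    model_gb = model_memory(params, precision="bf16", mixed_precision=True) / 10**9
    grads_gb = gradients_memory(params) / 10**9
    optim_gb = optimizer_memory(params, optimizer="adamw") / 10**9
    acts_gb = activations_memory(
        num_layers=hf_config["num_hidden_layers"],
        sequence_length=2048,  # training sequence length, not the model maximum
        micro_batch_size=1,
        hidden_size=hf_config["hidden_size"],
        num_heads=hf_config["num_attention_heads"],
    )

    print(f"Parameters:  {params / 10**9:.2f} B")
    print(f"Weights:     {model_gb:.1f} GB")
    print(f"Gradients:   {grads_gb:.1f} GB")
    print(f"Optimizer:   {optim_gb:.1f} GB")
    print(f"Activations: {acts_gb:.1f} GB")
    print(f"Total:       {model_gb + grads_gb + optim_gb + acts_gb:.1f} GB")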