from dataclasses import dataclass, fields
from typing import Optional

PRECISION_TO_BYTES = {"float32": 4,
                      "fp32": 4,
                      "float16": 2,
                      "fp16": 2,
                      "bfloat16": 2,
                      "bf16": 2,
                      "int8": 1,
                      "int4": 0.5}


@dataclass
class ModelConfig:
    model_size: float
    hidden_size: int
    sequence_length: int
    num_layers: int
    num_heads: int
    mixed_precision: bool = False
    precision: str = "bf16"
    repo_id: Optional[str] = None

    def overwrite_with_hf_config(self, config: dict):
        # Replace the hand-entered values with those from a Hugging Face config.json-style dict
        self.model_size = round(get_model_size_from_config(config) / 10**9, 2)
        self.hidden_size = config["hidden_size"]
        self.sequence_length = config["max_position_embeddings"]
        self.num_layers = config["num_hidden_layers"]
        self.num_heads = config["num_attention_heads"]
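
# Usage sketch (illustrative values, not taken from any real config.json):
#   cfg = ModelConfig(model_size=7.0, hidden_size=4096, sequence_length=4096,
#                     num_layers=32, num_heads=32)
#   cfg.overwrite_with_hf_config(hf_config)   # hf_config: dict loaded from the Hub
# overwrite_with_hf_config only reads the standard keys used above: "hidden_size",
# "max_position_embeddings", "num_hidden_layers", "num_attention_heads" (plus the
# keys consumed by get_model_size_from_config below).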


@dataclass
class TrainingConfig:
    micro_batch_size: int
    num_gpus: int
    optimizer: str
    zero_stage: int
    qlora: bool = False
    gradient_checkpointing: bool = False


# Utility function to filter params based on dataclass fields
def filter_params_for_dataclass(dataclass_type, params):
    return {field.name: params[field.name] for field in fields(dataclass_type) if field.name in params}
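
# Sketch of the intended use (hypothetical params dict): drop any keys that are not
# fields of the target dataclass before constructing it.
#   params = {"micro_batch_size": 1, "num_gpus": 8, "optimizer": "adamw",
#             "zero_stage": 2, "learning_rate": 3e-4}
#   train_cfg = TrainingConfig(**filter_params_for_dataclass(TrainingConfig, params))
#   # "learning_rate" is silently dropped because TrainingConfig has no such field.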


def get_model_size_from_config(config: dict):
    # Embedding parameters:
    embedding_params = config["vocab_size"] * config["hidden_size"]

    # Transformer layer parameters
    def transformer_layer_params(hidden_size, intermediate_size, num_key_value_heads):
        input_layernorm_params = hidden_size
        mlp_down_proj_params = hidden_size * intermediate_size
        mlp_gate_proj_params = intermediate_size * hidden_size
        mlp_up_proj_params = intermediate_size * hidden_size
        post_attention_layernorm_params = hidden_size
        self_attn_k_proj_params = (hidden_size // (num_key_value_heads // 2)) * hidden_size
        self_attn_o_proj_params = hidden_size * hidden_size
        self_attn_q_proj_params = hidden_size * hidden_size
        self_attn_v_proj_params = (hidden_size // (num_key_value_heads // 2)) * hidden_size
        total_layer_params = (
            input_layernorm_params + mlp_down_proj_params + mlp_gate_proj_params + mlp_up_proj_params +
            post_attention_layernorm_params + self_attn_k_proj_params + self_attn_o_proj_params +
            self_attn_q_proj_params + self_attn_v_proj_params
        )
        return total_layer_params

    # Total parameters for all transformer layers
    single_layer_params = transformer_layer_params(config["hidden_size"], config["intermediate_size"], config["num_key_value_heads"])
    total_transformer_params = config["num_hidden_layers"] * single_layer_params

    # Output layer parameters
    output_params = config["vocab_size"] * config["hidden_size"]

    # Total parameters
    total_params = embedding_params + total_transformer_params + output_params
    return total_params


def model_memory(parameters, precision="bf16", mixed_precision=False):
    if mixed_precision:
        return parameters * (PRECISION_TO_BYTES["fp32"] + PRECISION_TO_BYTES["fp16"])
    return parameters * PRECISION_TO_BYTES[precision]
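
# Note on the mixed-precision branch above: it charges 4 + 2 = 6 bytes per parameter
# (an fp32 master copy plus the fp16/bf16 working copy), versus 2 bytes for plain bf16.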


def gradients_memory(parameters, precision="fp32"):
    return parameters * PRECISION_TO_BYTES[precision]


def optimizer_memory(parameters, optimizer="adamw", precision="fp32"):
    optimizer_choices = {"adam": 3,        # Adam: stores a full-precision parameter copy, momentum, and variance -> 4 + 4 + 4 = 12 bytes per model parameter
                         "adamw": 3,       # AdamW: same as Adam
                         "sgd": 2,         # SGD: optimizer parameters and gradients -> 4 + 4 = 8 bytes per model parameter
                         "adam-8bit": 1.5, # Adam 8-bit: same layout as Adam -> 2 + 2 + 2 = 6 bytes per model parameter
                         }
    return optimizer_choices[optimizer] * parameters * PRECISION_TO_BYTES[precision]
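
# Worked example (illustrative): AdamW states for a 7e9-parameter model, kept in fp32,
# cost 3 * 7e9 * 4 bytes ≈ 84 GB, on top of the weights and gradients counted separately above.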


def activations_memory(num_layers, sequence_length, micro_batch_size, hidden_size, num_heads):
    # Reference: https://arxiv.org/pdf/2205.05198
    # Activations are assumed to be stored in 16-bit floating-point precision
    bytes_per_layer = sequence_length * micro_batch_size * hidden_size * (34 + 5 * (num_heads * sequence_length / hidden_size))
    bytes_model = bytes_per_layer * num_layers
    return bytes_model / 10**9  # GB
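

if __name__ == "__main__":
    # Minimal end-to-end sketch of how these helpers fit together. The config values
    # below are illustrative, Llama-7B-like numbers, not copied from any official
    # config.json; byte counts are converted to GB by dividing by 10**9.
    demo_config = {
        "vocab_size": 32000,
        "hidden_size": 4096,
        "intermediate_size": 11008,
        "num_hidden_layers": 32,
        "num_attention_heads": 32,
        "num_key_value_heads": 32,
        "max_position_embeddings": 4096,
    }

    model_cfg = ModelConfig(model_size=0.0, hidden_size=0, sequence_length=0,
                            num_layers=0, num_heads=0)
    model_cfg.overwrite_with_hf_config(demo_config)
    train_cfg = TrainingConfig(micro_batch_size=1, num_gpus=1, optimizer="adamw", zero_stage=0)

    params = get_model_size_from_config(demo_config)
    weights_gb = model_memory(params, precision=model_cfg.precision) / 10**9
    grads_gb = gradients_memory(params) / 10**9
    optim_gb = optimizer_memory(params, optimizer=train_cfg.optimizer) / 10**9
    activations_gb = activations_memory(model_cfg.num_layers, model_cfg.sequence_length,
                                        train_cfg.micro_batch_size, model_cfg.hidden_size,
                                        model_cfg.num_heads)

    print(f"Estimated model size: {model_cfg.model_size} B parameters")
    print(f"Weights: {weights_gb:.1f} GB | Gradients: {grads_gb:.1f} GB | "
          f"Optimizer: {optim_gb:.1f} GB | Activations: {activations_gb:.1f} GB")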