VRAM-estimator / vram_helpers.py
tvosch's picture
quick qlora support
a905447
raw
history blame
4.54 kB
from dataclasses import dataclass, fields
from typing import Optional
# Storage size, in bytes, of a single value in each supported numeric format.
# Most formats are reachable under both a long name and a short alias.
PRECISION_TO_BYTES = {
    "float32": 4,
    "fp32": 4,
    "float16": 2,
    "fp16": 2,
    "bfloat16": 2,
    "bf16": 2,
    "int8": 1,
    "int4": 0.5,  # half a byte per value
}
@dataclass
class ModelConfig:
    """Architecture and precision settings of the model being estimated.

    When populated through ``overwrite_with_hf_config``, ``model_size`` is the
    parameter count in billions (raw count divided by 10**9, rounded to two
    decimals).
    """

    model_size: float
    hidden_size: int
    sequence_length: int
    num_layers: int
    num_heads: int
    mixed_precision: bool = False
    precision: str = "bf16"
    repo_id: Optional[str] = None

    def overwrite_with_hf_config(self, config: dict):
        """Overwrite the architecture fields in place from a Hugging Face config dict.

        ``model_size`` is recomputed from ``config`` via
        ``get_model_size_from_config``. ``mixed_precision``, ``precision`` and
        ``repo_id`` are left unchanged. A missing config key raises ``KeyError``.
        """
        param_count = get_model_size_from_config(config)
        self.model_size = round(param_count / 10**9, 2)
        # (dataclass attribute, Hugging Face config key), applied in this order.
        for attr_name, hf_key in (
            ("hidden_size", "hidden_size"),
            ("sequence_length", "max_position_embeddings"),
            ("num_layers", "num_hidden_layers"),
            ("num_heads", "num_attention_heads"),
        ):
            setattr(self, attr_name, config[hf_key])
@dataclass
class TrainingConfig:
    """Training-run settings that feed the VRAM estimate.

    The first four fields are required; the two boolean switches default to
    ``False``.
    """

    # Required run settings.
    micro_batch_size: int
    num_gpus: int
    optimizer: str
    zero_stage: int
    # Optional switches.
    qlora: bool = False
    gradient_checkpointing: bool = False
# Helper: keep only the parameters a dataclass can accept.
def filter_params_for_dataclass(dataclass_type, params):
    """Return the subset of ``params`` whose keys are field names of ``dataclass_type``.

    Keys are emitted in the dataclass's field order; keys of ``params`` that do
    not name a field are dropped.
    """
    selected = {}
    for dc_field in fields(dataclass_type):
        name = dc_field.name
        if name in params:
            selected[name] = params[name]
    return selected
def get_model_size_from_config(config: dict):
    """Estimate the parameter count of a Llama-style decoder from a Hugging Face config dict.

    Counts the following and returns their sum:
      - the input embedding;
      - per layer, the q/k/v/o attention projections, the gated MLP
        (gate/up/down projections) and the two per-layer norm weight vectors;
      - the final norm;
      - the output (LM head) projection.
    Biases are not counted.

    Args:
        config: Parsed ``config.json`` contents.
            Required keys: ``vocab_size``, ``hidden_size``, ``intermediate_size``,
            ``num_hidden_layers``, ``num_attention_heads``.
            Optional keys:
              - ``num_key_value_heads``: defaults to ``num_attention_heads``,
                i.e. plain multi-head attention.
              - ``head_dim``: defaults to ``hidden_size // num_attention_heads``.
              - ``tie_word_embeddings``: defaults to ``False``. When true, the
                output projection shares the embedding matrix and is not
                counted a second time.

    Returns:
        Total number of parameters (an ``int`` for integer config values).

    Raises:
        KeyError: if a required key is missing from ``config``.
    """
    hidden_size = config["hidden_size"]
    intermediate_size = config["intermediate_size"]
    num_heads = config["num_attention_heads"]
    # A missing or null entry means no grouped-query attention: one KV head per query head.
    num_kv_heads = config.get("num_key_value_heads") or num_heads
    # Some architectures set head_dim explicitly, so num_heads * head_dim != hidden_size.
    head_dim = config.get("head_dim") or hidden_size // num_heads

    embedding_params = config["vocab_size"] * hidden_size

    # Attention: q/o map between hidden_size and num_heads * head_dim. With
    # grouped-/multi-query attention, k/v only produce num_kv_heads heads.
    q_proj_params = hidden_size * num_heads * head_dim
    k_proj_params = hidden_size * num_kv_heads * head_dim
    v_proj_params = hidden_size * num_kv_heads * head_dim
    o_proj_params = num_heads * head_dim * hidden_size
    # Gated MLP: gate and up (hidden -> intermediate), down (intermediate -> hidden).
    mlp_params = 3 * hidden_size * intermediate_size
    # input_layernorm + post_attention_layernorm weight vectors.
    layer_norm_params = 2 * hidden_size

    single_layer_params = (
        q_proj_params + k_proj_params + v_proj_params + o_proj_params
        + mlp_params + layer_norm_params
    )
    total_transformer_params = config["num_hidden_layers"] * single_layer_params

    # Norm applied after the last decoder layer.
    final_norm_params = hidden_size
    # Tied LM head reuses the embedding matrix, so it adds no parameters.
    output_params = 0 if config.get("tie_word_embeddings", False) else embedding_params

    return embedding_params + total_transformer_params + final_norm_params + output_params
def model_memory(parameters, precision="bf16", mixed_precision=False):
    """Return the number of bytes needed to hold ``parameters`` model weights.

    With ``mixed_precision`` each weight is counted as an fp32 copy plus an
    fp16 copy, and ``precision`` is not consulted. Otherwise each weight costs
    ``PRECISION_TO_BYTES[precision]`` bytes.
    """
    bytes_per_param = (
        PRECISION_TO_BYTES["fp32"] + PRECISION_TO_BYTES["fp16"]
        if mixed_precision
        else PRECISION_TO_BYTES[precision]
    )
    return parameters * bytes_per_param
def gradients_memory(parameters, precision="fp32"):
    """Return the bytes needed to store one gradient value per parameter at ``precision``."""
    bytes_per_gradient = PRECISION_TO_BYTES[precision]
    return parameters * bytes_per_gradient
def optimizer_memory(parameters, optimizer="adamw", precision="fp32"):
    """Return the bytes of optimizer state kept for ``parameters`` weights.

    Each optimizer name maps to a count of ``precision``-sized values stored per
    parameter. The result is ``count * parameters * PRECISION_TO_BYTES[precision]``.
    At the default fp32 precision that gives 12 bytes/param for Adam and AdamW,
    8 for SGD, and 6 for 8-bit Adam.

    Raises:
        KeyError: for an unknown ``optimizer`` (checked first) or ``precision``.
    """
    states_per_param = {
        "adam": 3,  # parameter copy, momentum and variance (4 + 4 + 4 bytes at fp32)
        "adamw": 3,  # same state as Adam
        "sgd": 2,  # parameter copy and gradients (4 + 4 bytes at fp32)
        "adam-8bit": 1.5,  # Adam's three states at reduced size (2 + 2 + 2 bytes at fp32 scale)
    }[optimizer]
    return states_per_param * parameters * PRECISION_TO_BYTES[precision]
def activations_memory(num_layers, sequence_length, micro_batch_size, hidden_size, num_heads):
    """Estimate the activation memory of the whole model, in gigabytes (10**9 bytes).

    Uses the per-layer estimate ``s * b * h * (34 + 5 * a * s / h)`` bytes from
    https://arxiv.org/pdf/2205.05198, which assumes 16-bit activations:
      - s: sequence length
      - b: micro-batch size
      - h: hidden size
      - a: number of attention heads
    The per-layer figure is multiplied by ``num_layers``.

    Note: the value returned is in GB, not bytes.
    """
    tokens = sequence_length * micro_batch_size
    attention_term = 5 * (num_heads * sequence_length / hidden_size)
    per_layer_bytes = tokens * hidden_size * (34 + attention_term)
    total_bytes = per_layer_bytes * num_layers
    return total_bytes / 10**9