# VRAM-estimator / vram_helpers.py
from dataclasses import dataclass, fields
from typing import Optional
PRECISION_TO_BYTES = {
    "fp32": 4,
    "fp16": 2,
    "bf16": 2,
    "int8": 1,
    "int4": 0.5,
}
@dataclass
class ModelConfig:
    model_size: float
    hidden_size: int
    sequence_length: int
    total_sequence_length: int  # for inference = prompt + output tokens
    num_layers: int
    num_heads: int
    mixed_precision: bool = False
    precision: str = "bf16"
    repo_id: Optional[str] = None

    def overwrite_with_hf_config(self, config: dict):
        self.model_size = round(get_model_size_from_config(config) / 10**9, 2)
        self.hidden_size = config["hidden_size"]
        self.sequence_length = config["max_position_embeddings"]
        if self.total_sequence_length == 0:
            self.total_sequence_length = self.sequence_length
        self.num_layers = config["num_hidden_layers"]
        self.num_heads = config["num_attention_heads"]
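# Illustrative sketch (not part of the original call path): one way to obtain the
# `config` dict that overwrite_with_hf_config expects is transformers' AutoConfig,
# whose to_dict() output carries the Llama-style keys used above ("hidden_size",
# "num_hidden_layers", ...). The repo id below is only an example.
#
#   from transformers import AutoConfig
#   cfg = ModelConfig(model_size=0, hidden_size=0, sequence_length=0,
#                     total_sequence_length=0, num_layers=0, num_heads=0,
#                     repo_id="meta-llama/Llama-2-7b-hf")
#   cfg.overwrite_with_hf_config(AutoConfig.from_pretrained(cfg.repo_id).to_dict())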
@dataclass
class TrainingConfig:
    micro_batch_size: int
    num_gpus: int
    optimizer: str
    zero_stage: int
    qlora: bool = False
    gradient_checkpointing: bool = False
    train: bool = True  # False for inference
# Utility function to filter params based on dataclass fields
def filter_params_for_dataclass(dataclass_type, params):
    return {field.name: params[field.name] for field in fields(dataclass_type) if field.name in params}
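# Example (illustrative): a UI callback can pass one big kwargs dict and let this
# helper drop anything a given dataclass does not declare.
#
#   params = {"micro_batch_size": 4, "num_gpus": 8, "optimizer": "adamw",
#             "zero_stage": 2, "some_unrelated_ui_field": 123}
#   training_config = TrainingConfig(**filter_params_for_dataclass(TrainingConfig, params))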
def get_model_size_from_config(config: dict):
    # Embedding parameters
    embedding_params = config["vocab_size"] * config["hidden_size"]

    # Transformer layer parameters (Llama-style decoder block, no biases)
    def transformer_layer_params(hidden_size, intermediate_size, num_attention_heads, num_key_value_heads):
        head_dim = hidden_size // num_attention_heads
        input_layernorm_params = hidden_size
        mlp_down_proj_params = hidden_size * intermediate_size
        mlp_gate_proj_params = intermediate_size * hidden_size
        mlp_up_proj_params = intermediate_size * hidden_size
        post_attention_layernorm_params = hidden_size
        # K/V projections map hidden_size -> num_key_value_heads * head_dim
        # (smaller than hidden_size when grouped-query attention is used)
        self_attn_k_proj_params = hidden_size * num_key_value_heads * head_dim
        self_attn_o_proj_params = hidden_size * hidden_size
        self_attn_q_proj_params = hidden_size * hidden_size
        self_attn_v_proj_params = hidden_size * num_key_value_heads * head_dim
        total_layer_params = (
            input_layernorm_params + mlp_down_proj_params + mlp_gate_proj_params + mlp_up_proj_params +
            post_attention_layernorm_params + self_attn_k_proj_params + self_attn_o_proj_params +
            self_attn_q_proj_params + self_attn_v_proj_params
        )
        return total_layer_params

    # Total parameters for all transformer layers
    single_layer_params = transformer_layer_params(
        config["hidden_size"], config["intermediate_size"],
        config["num_attention_heads"], config["num_key_value_heads"]
    )
    total_transformer_params = config["num_hidden_layers"] * single_layer_params

    # Output (LM head) parameters
    output_params = config["vocab_size"] * config["hidden_size"]

    # Total parameters
    total_params = embedding_params + total_transformer_params + output_params
    return total_params
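# Worked example (assuming a Llama-2-7B-style config: vocab_size=32000, hidden_size=4096,
# intermediate_size=11008, num_hidden_layers=32, num_attention_heads=32, num_key_value_heads=32):
#   embeddings + LM head: 2 * 32000 * 4096              = 262,144,000
#   per layer: 2*4096 (norms) + 3*4096*11008 (MLP)
#              + 4*4096*4096 (Q/K/V/O projections)       = 202,383,360
#   total: 262,144,000 + 32 * 202,383,360               = 6,738,411,520  (~6.74B parameters;
#   the final RMSNorm, 4,096 params, is not counted here)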
def model_memory(parameters, precision="bf16", mixed_precision=False):
    if mixed_precision:
        # Mixed precision keeps an fp32 master copy plus a half-precision working copy -> 4 + 2 = 6 bytes per parameter
        return parameters * (PRECISION_TO_BYTES["fp32"] + PRECISION_TO_BYTES["fp16"])
    return parameters * PRECISION_TO_BYTES[precision]
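# Example: with `parameters` expressed in billions (as ModelConfig.model_size stores it),
# the result reads directly in GB, e.g. model_memory(7, "bf16") -> 14 and
# model_memory(7, mixed_precision=True) -> 42 (6 bytes per parameter).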
def gradients_memory(parameters, precision="fp32"):
    return parameters * PRECISION_TO_BYTES[precision]
def optimizer_memory(parameters, optimizer="adamw", precision="fp32"):
    optimizer_choices = {
        "adam": 3,          # Adam: stores fp32 copies of the parameters, momentum, and variance -> 4 + 4 + 4 = 12 bytes per model parameter
        "adamw": 3,         # AdamW: same as Adam
        "sgd": 2,           # SGD (with momentum): fp32 parameter copy and momentum -> 4 + 4 = 8 bytes per model parameter
        "adamw_8bit": 1.5,  # AdamW 8-bit: same states as Adam, quantized -> 2 + 2 + 2 = 6 bytes per model parameter
    }
    return optimizer_choices[optimizer] * parameters * PRECISION_TO_BYTES[precision]
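# Example: for a 7B-parameter model (parameters given in billions, result in GB):
#   optimizer_memory(7, "adamw")      -> 3 * 7 * 4   = 84 GB of optimizer states
#   optimizer_memory(7, "adamw_8bit") -> 1.5 * 7 * 4 = 42 GB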
# Raw 16-bit formula from https://arxiv.org/pdf/2205.05198, kept for comparison:
# def activations_memory_per_layer(sequence_length, micro_batch_size, hidden_size, num_heads):
#     bytes_per_layer = sequence_length * micro_batch_size * hidden_size * (34 + 5 * (num_heads * sequence_length / hidden_size))
#     return bytes_per_layer / 10**9

def activations_memory_per_layer(sequence_length, micro_batch_size, hidden_size, num_heads):
    """Returns the GPU VRAM (in GB) required to store the intermediate activations of one transformer layer."""
    precision = "fp32"
    # Per-layer formula from https://arxiv.org/pdf/2205.05198, generalized to an arbitrary
    # activation precision; with 16-bit activations it reduces to s*b*h*(34 + 5*a*s/h) bytes.
    mem_bytes = PRECISION_TO_BYTES[precision] * sequence_length * micro_batch_size * hidden_size * (
        16 + 2/PRECISION_TO_BYTES[precision] + 2*num_heads*sequence_length/hidden_size
        + num_heads*sequence_length/(PRECISION_TO_BYTES[precision]*hidden_size))
    return round(mem_bytes / 10**9, 2)
def activations_memory(num_layers, sequence_length, micro_batch_size, hidden_size, num_heads):
    # Reference: https://arxiv.org/pdf/2205.05198
    # Per-layer activation memory is already in GB (precision is set in activations_memory_per_layer)
    gb_per_layer = activations_memory_per_layer(sequence_length, micro_batch_size, hidden_size, num_heads)
    gb_model = gb_per_layer * num_layers
    return gb_model
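# Example: a 32-layer model with hidden_size=4096, num_heads=32, micro_batch_size=1 and
# sequence_length=2048 gives roughly 1.76 GB of activations per layer with the fp32
# activation precision set above, i.e. about 56 GB in total without gradient checkpointing.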
def kv_cache_memory(batch_size, total_sequence_length, num_layers, num_heads, hidden_size, precision):
    # Total sequence length means input prompt length + completion, so the model's context size is the upper bound
    # K and V each hold num_heads * head_dim = hidden_size values per token per layer
    # (grouped-query attention would shrink this by num_heads / num_key_value_heads)
    head_dim = hidden_size // num_heads
    kv_cache_bytes = 2 * batch_size * total_sequence_length * num_layers * num_heads * head_dim * PRECISION_TO_BYTES[precision]
    return kv_cache_bytes / 10**9
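

if __name__ == "__main__":
    # Minimal smoke test (illustrative numbers only): a roughly 6.74B-parameter,
    # Llama-2-7B-shaped model served in fp16 with a single 4096-token sequence.
    params_in_billions = 6.74
    print("weights (GB):", model_memory(params_in_billions, precision="fp16"))
    print("kv cache (GB):", kv_cache_memory(batch_size=1, total_sequence_length=4096,
                                            num_layers=32, num_heads=32,
                                            hidden_size=4096, precision="fp16"))
    print("activations per layer (GB):", activations_memory_per_layer(4096, 1, 4096, 32))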