tvosch committed
Commit ffa1281
1 Parent(s): 9223a1e

add minimal inference code

Files changed (3):
  1. app.py +79 -40
  2. estimate_train_vram.py +34 -4
  3. vram_helpers.py +25 -7
app.py CHANGED

@@ -5,11 +5,12 @@ from functools import partial
 import gradio as gr
 from transformers import AutoConfig
 
-from estimate_train_vram import vram_required
-from vram_helpers import ModelConfig, TrainingConfig, filter_params_for_dataclass
+from estimate_train_vram import training_vram_required, inference_vram_required
+from vram_helpers import ModelConfig, TrainingConfig, filter_params_for_dataclass, PRECISION_TO_BYTES
 
 ZERO_STAGES = [0, 1, 2, 3]
 BATCH_SIZES = [1, 2, 4, 8, 16, 32, 64]
+QUANTIZATION = PRECISION_TO_BYTES.keys()
 OPTIMIZERS = ["adam", "adamw", "adamw_8bit", "sgd"]
 HUGGINGFACE_URL_CONFIG = "https://huggingface.co/{}/resolve/main/config.json"
 
@@ -31,6 +32,9 @@ def parse_args():
     parser.add_argument("--num_gpus", type=int, default=4, help="Number of GPUs. Necessary for estimating ZeRO stages")
     parser.add_argument("--cache_dir", type=str, default=None, help="HuggingFace cache directory to download config from")
     parser.add_argument("--qlora", action="store_false", help="Enable QLoRA in case of finetuning")
+    parser.add_argument("--quantization", type=str, choices=QUANTIZATION, help="Type of quantization. Default is fp16/bf16")
+    parser.add_argument("--train", action="store_false", help="Flag to turn off train and run inference")
+    parser.add_argument("--total_sequence_length", type=int, default=0, help="Total sequence length (prompt + output) for inference")
 
     parser.add_argument("--no-app", action="store_true", help="Launch gradio app. Otherwise, commandline output")
     return parser
@@ -67,80 +71,110 @@ def scrape_config_from_hub(repo_id):
 
 def build_interface(estimate_vram_fn):
     with gr.Blocks() as app:
-        gr.Markdown("## Select either an existing HF model from a repository or choose your own model parameters")
-        option = gr.Radio(["Repo ID", "Model Parameters"], label="Select Input Type")
+        gr.Markdown("## 1. Select HuggingFace model from a repository or choose your own model parameters")
+        model_option = gr.Radio(["Repo ID", "Model Parameters"], label="Select Input Type")
         repo_id = gr.Textbox(label="Repo ID", visible=False, placeholder="mistralai/Mistral-7B-v0.1")
 
         with gr.Row(visible=False) as model_params_row:
             model_params = [gr.Slider(label="Model Size", minimum=0.1, maximum=400, step=0.1, value=7, info="Model size (in billion parameters)"),
                             gr.Slider(label="Hidden size", minimum=256, maximum=8192, step=128, value=4096, info="Hidden size"),
-                            gr.Slider(label="Sequence length", minimum=256, maximum=128_000, step=256, value=8192, info="Sequence length"),
+                            gr.Slider(label="Sequence length", minimum=128, maximum=128_000, step=256, value=8192, info="Sequence length"),
                             gr.Slider(label="Num layers", minimum=8, maximum=64, step=1, value=32, info="Number of layers"),
                             gr.Slider(label="Num heads", minimum=8, maximum=64, step=1, value=32, info="Number of attention heads")
                             ]
 
 
-        def update_visibility(selected_option):
-            if selected_option == "Repo ID":
-                return gr.update(visible=True), gr.update(visible=False),
-            elif selected_option == "Model Parameters":
-                return gr.update(visible=False), gr.update(visible=True)
-
-        option.change(
-            fn=update_visibility,
-            inputs=[option],
-            outputs=[repo_id, model_params_row]
-        )
+        def update_visibility_model_type(selected_option, choices):
+            """
+            Dynamically update the visibility of components based on the selected option.
+
+            :param selected_option: The currently selected option
+            :param choices: Variable number of tuples, each containing (option_value, component)
+            :return: List of gr.update() calls corresponding to each choice
+            """
+            updates = []
+            for option_value, _ in choices:
+                updates.append(gr.update(visible=(selected_option == option_value)))
+            return updates
+
+        model_option_choices = [("Repo ID", repo_id), ("Model Parameters", model_params_row)]
+        model_option.change(
+            fn=partial(update_visibility_model_type, choices=model_option_choices),
+            inputs=[model_option],
+            outputs=[repo_id, model_params_row],
+        )
 
-        gr.Markdown("## Select training parameters")
-        with gr.Row(equal_height=True):
+        gr.Markdown("## 2. Select training or inference parameters")
+        training_option = gr.Radio(["Training", "Inference"], label="Select Input Type")
+        with gr.Row(equal_height=True, visible=False) as training_params_row:
             training_params = [gr.Dropdown(label="Micro batch size", choices=BATCH_SIZES, value=4, info="Micro batch size (batch size per device/GPU)"),
                                gr.Dropdown(label="ZeRO stage", choices=ZERO_STAGES, value=0, info="ZeRO optimization stage"),
                                gr.Dropdown(label="Gradient checkpointing", choices=[True, False], value=True, info="Enable gradient checkpointing"),
                                gr.Dropdown(label="Mixed precision", choices=[False, True], value=False, info="Enable mixed precision for model training"),
                                gr.Dropdown(label="Optimizer", choices=OPTIMIZERS, value="adamw", info="Type of optimizer"),
                                gr.Dropdown(label="QLoRA", choices=[False, True], value=False, info="Finetune with QLoRA enabled"),
-                               gr.Slider(label="Num GPUs", minimum=1, maximum=64, step=1, value=4, info="Number of GPUs. Necessary for estimating ZeRO stages"),
-                               gr.Textbox(label="Cache dir", value=None, placeholder=".huggingface_configs", info="HuggingFace cache directory to download config to")
+                               gr.Slider(label="Num GPUs", minimum=1, maximum=256, step=1, value=4, info="Number of GPUs. Necessary for estimating ZeRO stages"),
                                ]
 
+        with gr.Row(equal_height=True, visible=False) as inference_params_row:
+            inference_params = [gr.Dropdown(label="Quantization", choices=QUANTIZATION, value="fp16", info="Quantization of model"),
+                                gr.Slider(label="Num GPUs", minimum=1, maximum=256, step=1, value=1, info="Number of GPUs"),
+                                gr.Dropdown(label="Micro batch size", choices=BATCH_SIZES, value=1, info="Micro batch size (batch size per device/GPU)"),
+                                gr.Slider(label="Total sequence length", minimum=128, maximum=128_000, value=0, info="Total sequence length to run (necessary for KV cache calculation")
+                                ]
+
+        training_option_choices = [("Training", inference_params_row), ("Inference", training_params_row)]
+        training_option.change(
+            fn=partial(update_visibility_model_type, choices=training_option_choices),
+            inputs=[training_option],
+            outputs=[training_params_row, inference_params_row],
+        )
 
-        submit_btn = gr.Button("Estimate!")
-
-        output = gr.Textbox(label="Total estimated VRAM per device/GPU (in GB)")
-
-        def create_combined_params_dict(repo_id, *values):
-            all_params = model_params + training_params
+        submit_btn = gr.Button("Estimate!")
+        output = gr.Textbox(label="Total estimated VRAM per device/GPU (in GB)")
+        def create_combined_params_dict(repo_id, training_option, *values):
+            all_params = model_params + training_params + inference_params
             combined_dict = {param.label.lower().replace(" ", "_"): value for param, value in zip(all_params, values)}
             combined_dict["repo_id"] = repo_id
+            combined_dict["train"] = True if training_option.lower() == "training" else False # False -> inference
             return combined_dict
 
         submit_btn.click(
-            fn=lambda repo_id, *values: estimate_vram_fn(create_combined_params_dict(repo_id, *values)),
-            inputs=[repo_id] + model_params + training_params,
-            outputs=[output]
+            fn=lambda repo_id, training_option, *values: estimate_vram_fn(create_combined_params_dict(repo_id, training_option, *values)),
+            inputs=[repo_id, training_option] + model_params + training_params + inference_params,
+            outputs=[output]
         )
     return app
 
 
 def estimate_vram(gradio_params):
+    print(gradio_params)
     model_config = ModelConfig(**filter_params_for_dataclass(ModelConfig, gradio_params))
     training_config = TrainingConfig(**filter_params_for_dataclass(TrainingConfig, gradio_params))
+
+
     # Update model config
     if not gradio_params["repo_id"]:
         return "No model selected!"
     # If cache directory set, then download config
-    if gradio_params["cache_dir"]:
-        config = scrape_config_from_hub(gradio_params["repo_id"])
-        model_config.overwrite_with_hf_config(config)
+    # if gradio_params["cache_dir"]:
+    #     config = scrape_config_from_hub(gradio_params["repo_id"])
+    #     model_config.overwrite_with_hf_config(config)
+    cache_dir="cache/"
     # By default, scrape config.json from hub
-    else:
-        config = download_config_from_hub(gradio_params["repo_id"], gradio_params["cache_dir"])
-        model_config.overwrite_with_hf_config(config.to_dict())
-
-    if gradio_params["qlora"]:
-        model_config.precision = "int4"
-    total_vram_dict = vram_required(model_config, training_config)
-    output_str = f"Total {total_vram_dict['total']}GB = {total_vram_dict['model']}GB (model) + {total_vram_dict['gradients']}GB (gradients) + {total_vram_dict['optimizer']}GB (optimizer) + {total_vram_dict['activations']}GB activations"
+    #else:
+    config = download_config_from_hub(gradio_params["repo_id"], cache_dir)# gradio_params["cache_dir"])
+    model_config.overwrite_with_hf_config(config.to_dict())
+
+    if training_config.train:
+        total_vram_dict = training_vram_required(model_config, training_config)
+        output_str = f"Total {total_vram_dict['total']}GB = {total_vram_dict['model']}GB (model) + {total_vram_dict['gradients']}GB (gradients) + {total_vram_dict['optimizer']}GB (optimizer) + {total_vram_dict['activations']}GB activations"
+    else: # inference
+        total_vram_dict = inference_vram_required(model_config, training_config)
+        output_str = f"Total {total_vram_dict['total']}GB = {total_vram_dict['model']}GB (model) + {total_vram_dict['kv_cache']}GB (KV cache) + {total_vram_dict['activations']}GB activations"
     return output_str
 
 if __name__ == "__main__":
@@ -166,5 +200,10 @@ if __name__ == "__main__":
     config = scrape_config_from_hub(args.repo_id)
     model_config.overwrite_with_hf_config(config)
 
-    total_vram_dict = vram_required(model_config, training_config)
-    print(f"Total {total_vram_dict['total']}GB = {total_vram_dict['model']}GB (model) + {total_vram_dict['gradients']}GB (gradients) + {total_vram_dict['optimizer']}GB (optimizer) + {total_vram_dict['activations']}GB (activations)")
+    if training_config.train:
+        total_vram_dict = training_vram_required(model_config, training_config)
+        output_str = f"Total {total_vram_dict['total']}GB = {total_vram_dict['model']}GB (model) + {total_vram_dict['gradients']}GB (gradients) + {total_vram_dict['optimizer']}GB (optimizer) + {total_vram_dict['activations']}GB activations"
+    else: # inference
+        total_vram_dict = inference_vram_required(model_config, training_config)
+        output_str = f"Total {total_vram_dict['total']}GB = {total_vram_dict['model']}GB (model) + {total_vram_dict['kv_cache']}GB (KV cache) + {total_vram_dict['activations']}GB activations"
+    print(output_str)
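
For reference, below is a minimal, self-contained sketch of the radio-driven visibility pattern that the updated build_interface relies on (a handler shared via functools.partial that returns one gr.update per registered row). The widget labels and choices here are illustrative placeholders, not the app's actual components.

import gradio as gr
from functools import partial

def update_visibility(selected_option, choices):
    # One gr.update per (option_value, component) pair; only the row whose
    # option matches the selected radio value becomes visible.
    return [gr.update(visible=(selected_option == option_value)) for option_value, _ in choices]

with gr.Blocks() as demo:
    option = gr.Radio(["Training", "Inference"], label="Select Input Type")
    with gr.Row(visible=False) as training_row:
        gr.Dropdown(label="Optimizer", choices=["adam", "adamw", "sgd"], value="adamw")
    with gr.Row(visible=False) as inference_row:
        gr.Dropdown(label="Quantization", choices=["fp16", "int8", "int4"], value="fp16")

    choices = [("Training", training_row), ("Inference", inference_row)]
    option.change(
        fn=partial(update_visibility, choices=choices),
        inputs=[option],
        outputs=[training_row, inference_row],
    )

if __name__ == "__main__":
    demo.launch()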
estimate_train_vram.py CHANGED

@@ -1,15 +1,20 @@
 
-from vram_helpers import model_memory, gradients_memory, optimizer_memory, activations_memory
+from vram_helpers import activations_memory_per_layer, \
+    model_memory, \
+    gradients_memory, \
+    optimizer_memory, \
+    activations_memory, \
+    kv_cache_memory
 
 
-def vram_required(model_config, training_config):
+def training_vram_required(model_config, training_config):
     # Reference: https://www.microsoft.com/en-us/research/blog/zero-deepspeed-new-system-optimizations-enable-training-models-with-over-100-billion-parameters/
 
     trainable_parameters = model_config.model_size
     if training_config.qlora:
         model_config.precision = "int4"
-        # Generally around 4-8% of trainable parameters so take upper bound
-        trainable_parameters = 0.1 * model_config.model_size
+        # 0.2% according to LoRA paper (https://arxiv.org/pdf/2106.09685)
+        trainable_parameters = 0.0002 * model_config.model_size
 
     model_vram = model_memory(parameters=trainable_parameters,
                               precision=model_config.precision,
@@ -38,6 +43,7 @@ def vram_required(model_config, training_config):
                                           training_config.micro_batch_size,
                                           model_config.hidden_size,
                                           model_config.num_heads)
+
     if training_config.gradient_checkpointing:
         activations_vram = round(activations_vram ** 0.5, 2)
 
@@ -48,4 +54,28 @@ def vram_required(model_config, training_config):
         "gradients": gradients_vram,
         "optimizer": optimizer_vram,
         "activations": activations_vram
+        }.items()}
+
+
+def inference_vram_required(model_config, training_config):
+    # Total inference VRAM = model size + KV cache size + activations + additional overhead
+    model_vram = model_memory(parameters=model_config.model_size,
+                              precision=model_config.precision,
+                              mixed_precision=model_config.mixed_precision)
+    kv_cache_vram = kv_cache_memory(batch_size=training_config.micro_batch_size,
+                                    total_sequence_length=model_config.total_sequence_length,
+                                    num_layers=model_config.num_layers,
+                                    num_heads=model_config.num_heads,
+                                    hidden_size=model_config.hidden_size,
+                                    precision=model_config.precision)
+    activations_vram = activations_memory_per_layer(sequence_length=model_config.sequence_length,
+                                                    micro_batch_size=training_config.micro_batch_size,
+                                                    hidden_size=model_config.hidden_size,
+                                                    num_heads=model_config.num_heads)
+    total_vram = model_vram + kv_cache_vram + activations_vram
+    return {k: round(v, 2) for k, v in {
+        "total": total_vram,
+        "model": model_vram,
+        "kv_cache": kv_cache_vram,
+        "activations": activations_vram
         }.items()}
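
As a rough illustration of how the new inference path composes its pieces, the sketch below calls the vram_helpers functions directly with hypothetical Mistral-7B-style numbers (7.24B parameters, 32 layers, hidden size 4096, 32 heads, 8192-token context). The numeric values, and the assumption that model_memory takes the parameter count in billions and returns GB (mirroring ModelConfig.model_size), are not taken from this commit.

# Hedged sketch: hypothetical Mistral-7B-style values, not read from any config.json.
from vram_helpers import model_memory, kv_cache_memory, activations_memory_per_layer

# Weights (parameter count expressed in billions, as in ModelConfig.model_size)
model_gb = model_memory(parameters=7.24, precision="fp16", mixed_precision=False)

# KV cache for a single sequence at the full assumed context length
kv_gb = kv_cache_memory(batch_size=1,
                        total_sequence_length=8192,
                        num_layers=32,
                        num_heads=32,
                        hidden_size=4096,
                        precision="fp16")

# Per-layer activations, used as the inference-time activation estimate
act_gb = activations_memory_per_layer(sequence_length=8192,
                                      micro_batch_size=1,
                                      hidden_size=4096,
                                      num_heads=32)

print(f"model={model_gb}GB kv_cache={kv_gb}GB activations={act_gb}GB total={model_gb + kv_gb + act_gb}GB")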
vram_helpers.py CHANGED

@@ -2,11 +2,8 @@ from dataclasses import dataclass, fields
 from typing import Optional
 
 
-PRECISION_TO_BYTES = {"float32": 4,
-                      "fp32": 4,
-                      "float16": 2,
+PRECISION_TO_BYTES = {"fp32": 4,
                       "fp16": 2,
-                      "bfloat16": 2,
                       "bf16": 2,
                       "int8": 1,
                       "int4": 0.5}
@@ -17,6 +14,7 @@ class ModelConfig:
     model_size: float
     hidden_size: int
     sequence_length: int
+    total_sequence_length: int # for inference = prompt + output tokens
     num_layers: int
     num_heads: int
     mixed_precision: bool = False
@@ -27,6 +25,8 @@ class ModelConfig:
         self.model_size = round(get_model_size_from_config(config) / 10**9, 2)
         self.hidden_size = config["hidden_size"]
         self.sequence_length = config["max_position_embeddings"]
+        if self.total_sequence_length == 0:
+            self.total_sequence_length = self.sequence_length
         self.num_layers = config["num_hidden_layers"]
         self.num_heads = config["num_attention_heads"]
 
@@ -38,6 +38,7 @@ class TrainingConfig:
     zero_stage: int
     qlora: bool = False
     gradient_checkpointing: bool = False
+    train: bool = True # False for inference
 
 # Utility function to filter params based on dataclass fields
 def filter_params_for_dataclass(dataclass_type, params):
@@ -88,16 +89,33 @@ def gradients_memory(parameters, precision = "fp32"):
     return parameters * PRECISION_TO_BYTES[precision]
 
 def optimizer_memory(parameters, optimizer= "adamw", precision = "fp32"):
-    optimizer_choices = {"adam": 3, # Adam: stores precision copies of the optimizer parameters, momentum, and variance -> 4 + 4 + 4 = 12 bytes per model parameter
+    optimizer_choices = {
+                         "adam": 3, # Adam: stores precision copies of the optimizer parameters, momentum, and variance -> 4 + 4 + 4 = 12 bytes per model parameter
                          "adamw": 3, # AdamW: Same for Adam
                          "sgd": 2, # For SGD: optimier parameters and gradients -> 4 + 4 = 8 bytes per model parameter
                          "adamw_8bit": 1.5, # Adam 8-bit: same for Adam-> 2 + 2 + 2 = 6 bytes per model parameter
                          }
     return optimizer_choices[optimizer] * parameters * PRECISION_TO_BYTES[precision]
 
+# def activations_memory_per_layer(sequence_length, micro_batch_size, hidden_size, num_heads):
+#     bytes_per_layer = sequence_length * micro_batch_size * hidden_size * (34 + 5 * (num_heads * sequence_length / hidden_size))
+#     return bytes_per_layer / 10**9
+
+def activations_memory_per_layer(sequence_length, micro_batch_size, hidden_size, num_heads):
+    precision = "fp32"
+    "Returns amount of GPU VRAM (in GB) required to store intermediate activations for traditional Transformer Encoder block"
+    mem_bytes = PRECISION_TO_BYTES[precision] * sequence_length * micro_batch_size * hidden_size * (
+        16 + 2/PRECISION_TO_BYTES[precision] + 2*num_heads*sequence_length/hidden_size + num_heads*sequence_length/(PRECISION_TO_BYTES[precision]*hidden_size))
+    return round(mem_bytes / 10**9, 2)
+
 def activations_memory(num_layers, sequence_length, micro_batch_size, hidden_size, num_heads):
     # Reference: https://arxiv.org/pdf/2205.05198
     # Activations assumed to be in 16-bit floating precision
-    bytes_per_layer = sequence_length * micro_batch_size * hidden_size * (34 + 5 * (num_heads * sequence_length / hidden_size))
+    bytes_per_layer = activations_memory_per_layer(sequence_length, micro_batch_size, hidden_size, num_heads)
     bytes_model = bytes_per_layer * num_layers
-    return bytes_model / 10**9
+    return bytes_model
+
+def kv_cache_memory(batch_size, total_sequence_length, num_layers, num_heads, hidden_size, precision):
+    # Total sequence length means input prompt length + completion so we assume the context size of the model as upper bound
+    kv_cache_memory = 2 * batch_size * total_sequence_length * num_layers * num_heads * hidden_size * PRECISION_TO_BYTES[precision]
+    return kv_cache_memory / 10**9
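
The trimmed PRECISION_TO_BYTES table is also what the new quantization dropdown and --quantization flag feed on. Below is a small sketch (not the repo's model_memory implementation, whose body is outside this diff) of how bytes-per-parameter translates into weight VRAM, assuming parameter counts expressed in billions, as ModelConfig.model_size is, and GB defined as 10**9 bytes to match the rest of the helpers.

from vram_helpers import PRECISION_TO_BYTES

def weights_memory_gb(parameters_billion, precision="fp16"):
    # 1e9 parameters * bytes-per-parameter / 1e9 bytes-per-GB cancels out, so
    # billions of parameters map directly to GB at 1 byte per parameter.
    return parameters_billion * PRECISION_TO_BYTES[precision]

# A 7B model: ~14 GB of weights in fp16/bf16, ~7 GB in int8, ~3.5 GB in int4
print(weights_memory_gb(7, "fp16"), weights_memory_gb(7, "int8"), weights_memory_gb(7, "int4"))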