|
import gradio as gr |
|
import matplotlib |
|
matplotlib.use('Agg') |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
from matplotlib.ticker import MultipleLocator |
|
|
|
# Markdown blurb rendered under the "Harm's law" plot in the UI.
HARM_INTRO = """
The Chinchilla scaling laws focus on optimally scaling training compute, but often we also care about inference cost.
This tool follows [Harm de Vries' blog post](https://www.harmdevries.com/post/model-size-vs-compute-overhead/) and visualizes the tradeoff between training compute and inference cost (i.e. model size).
"""
|
|
|
|
|
# Peak per-GPU throughput in FLOP/s. `compute` multiplies these by the
# utilization fraction and divides the total training FLOPs by the result
# to obtain GPU-seconds. The numbers correspond to NVIDIA's published dense
# (non-sparse) BF16 tensor-core peaks -- confirm if another precision/SKU is meant.
A100_flops = 3.12e14
H100_flops = 9.9e14
|
|
|
|
|
# Parameters of the Chinchilla-style parametric loss
#     L(N, D) = E + A / N**alpha + B / D**beta
# (N = parameter count, D = training tokens). E, the irreducible-loss term,
# is not used by the computations in this module but is kept for completeness.
E, A, B = 1.62, 406.4, 410.7
alpha, beta = 0.336, 0.283

# One billion: the UI takes model and dataset sizes in billions.
Bn = 1_000_000_000

# Prefactor of the compute-optimal split N_opt = G * (C / 6) ** (beta / (alpha + beta)),
# D_opt = (1 / G) * (C / 6) ** (alpha / (alpha + beta)).
G = (alpha * A / (beta * B)) ** (1 / (alpha + beta))
|
|
|
|
|
def to_flops(N, D):
    """Estimate training compute in FLOPs with the common 6 * N * D rule.

    Args:
        N: Number of model parameters.
        D: Number of training tokens.

    Returns:
        Approximate total training FLOPs.
    """
    flops_per_param_per_token = 6
    return flops_per_param_per_token * N * D
|
|
|
def n_opt(C):
    """Return the compute-optimal parameter count for a budget of ``C`` FLOPs.

    Evaluates N_opt = G * (C / 6) ** (beta / (alpha + beta)) using the
    module-level ``G``, ``alpha`` and ``beta``.
    """
    params_times_tokens = C / 6
    exponent = beta / (alpha + beta)
    return G * params_times_tokens ** exponent
|
|
|
def d_opt(C):
    """Return the compute-optimal number of training tokens for ``C`` FLOPs.

    Evaluates D_opt = (1 / G) * (C / 6) ** (alpha / (alpha + beta)) using the
    module-level ``G``, ``alpha`` and ``beta``.
    """
    inverse_prefactor = 1 / G
    exponent = alpha / (alpha + beta)
    return inverse_prefactor * (C / 6) ** exponent
|
|
|
def compute_kd(kn):
    """Return the data multiplier k_D that offsets a model-size multiplier k_N.

    Harm's law: a model with ``kn * N_opt`` parameters needs ``kd * D_opt``
    training tokens to reach the loss of the compute-optimal ``(N_opt, D_opt)``
    pair, where

        kd = (1 - (kn**-alpha - 1) * (A / B) * G**-(alpha + beta)) ** (-1 / beta)

    Args:
        kn: Model size as a fraction of the Chinchilla-optimal size; a scalar
            or an array (evaluated element-wise).

    Returns:
        ``kd`` as a NumPy scalar (scalar input) or array (array input). When
        ``kn`` is so small that the bracketed term is <= 0, no finite dataset
        reaches the target loss and ``inf`` is returned. ``nan`` is returned
        for negative or NaN ``kn``.
    """
    frac = (A / B) * (G ** (-alpha - beta))
    kn = np.asarray(kn, dtype=float)
    # For small kn (roughly kn < 0.097 with the current constants) the bracket is
    # <= 0. The raw power would then raise ZeroDivisionError or return a complex
    # number (Python floats), or NaN (NumPy). Silence those warnings and map the
    # case explicitly to its limit, +inf.
    with np.errstate(divide="ignore", invalid="ignore", over="ignore"):
        base = 1 - ((kn ** -alpha - 1) * frac)
        kd = np.where(base > 0, base ** (1 / (-beta)), np.inf)
    # Comparisons with NaN are False, so restore NaN for invalid kn (e.g. negative).
    kd = np.where(np.isnan(base), np.nan, kd)
    return kd[()] if kd.ndim == 0 else kd
|
|
|
def compute_overhead(kn, kd):
    """Return the extra training compute as a fraction of the optimal budget.

    Training FLOPs scale with parameters * tokens, so scaling them by ``kn`` and
    ``kd`` scales compute by ``kn * kd``. The overhead is that ratio minus one.
    """
    relative_compute = kn * kd
    return relative_compute - 1
|
|
|
|
|
# Range of model-size fractions (k_N) used for the reference curve in the plot.
# The overhead diverges as k_N approaches ~0.097 for the current constants
# (where the bracketed term in compute_kd reaches 0), so the lower bound stays
# well above that value.
kn_min = 0.18
kn_max = 2

kns = np.linspace(kn_min, kn_max, 100)

# Compute overhead (in percent) for every k_N in `kns`. A comprehension keeps the
# loop variables out of module scope. The old for-loop left module globals named
# `kn`/`kd` behind, which shadow-collide with the parameters of the same names
# used by the plotting and callback functions.
overheads = [compute_overhead(k, compute_kd(k)) * 100 for k in kns]
|
|
|
def plot_curve(kn, kd):
    """Plot the Harm's-law overhead curve and mark a configuration on it.

    Draws the precomputed reference curve (module-level ``kns`` / ``overheads``),
    a red "You are here!" marker at ``(kn, 100 * compute_overhead(kn, kd))`` and
    a blue marker at (1, 0) for the Chinchilla-optimal point.

    Args:
        kn: Model size as a fraction of the Chinchilla-optimal size.
        kd: Matching dataset-size multiplier.

    Returns:
        A standalone ``matplotlib.figure.Figure`` (5x3 inches, 200 dpi).
    """
    # Build the figure with the object-oriented API instead of pyplot. pyplot keeps a
    # global registry of open figures that was never closed here, so every request
    # leaked one. Its implicit "current figure" is also shared across concurrent Gradio
    # requests. A bare Figure is garbage-collected once Gradio has rendered it.
    from matplotlib.figure import Figure

    fig = Figure(dpi=200, figsize=(5, 3))
    ax = fig.subplots()

    ax.plot(kns, overheads, color="black", zorder=1)
    ax.scatter([kn], [compute_overhead(kn, kd) * 100], s=100, marker="o", c="red",
               label="You are here!", zorder=2)
    ax.scatter([1.0], [0.0], marker="o", s=100, c="blue",
               label="Chinchilla optimal", zorder=2)

    ax.set_xlabel("Fraction of Chinchilla optimal model size")
    ax.set_ylabel("Compute overhead (%)")
    ax.legend(loc="best")
    ax.grid(True, which="both")
    ax.grid(True, which="minor", alpha=0.5)
    ax.yaxis.set_minor_locator(MultipleLocator(10))
    fig.tight_layout()

    return fig
|
|
|
|
|
def compute(N, D, gpu_type, gpu_util, n_gpus, gpu_price):
    """Gradio callback: training-cost summary plus the Harm's-law trade-off plot.

    Args:
        N: Model size in billions of parameters.
        D: Dataset size in billions of tokens.
        gpu_type: ``"H100"`` selects ``H100_flops``; any other value uses ``A100_flops``.
        gpu_util: Achieved utilization as a percentage of peak (50 means 50%).
        n_gpus: Number of GPUs used in parallel (only affects the wall-clock days).
        gpu_price: Dollars per GPU-hour.

    Returns:
        ``(markdown, figure)`` for the summary ``gr.Markdown`` and the ``gr.Plot``.

    Raises:
        gr.Error: If an input is missing (``None``) or out of range.
    """
    # Reject inputs that would otherwise fail further down with ZeroDivisionError
    # (N, D, gpu_util or n_gpus == 0) or TypeError (a missing value).
    must_be_positive = {
        "Model size": N,
        "Dataset size": D,
        "% GPU utilization": gpu_util,
        "Number of GPUs": n_gpus,
    }
    invalid = [label for label, value in must_be_positive.items()
               if value is None or value <= 0]
    if gpu_price is None or gpu_price < 0:
        invalid.append("$/GPU/Hour")
    if invalid:
        raise gr.Error(
            "Invalid input for: " + ", ".join(invalid)
            + ". Sizes, GPU count and utilization must be > 0; price must be >= 0."
        )

    # Scaling-law part: total FLOPs (raw parameter/token counts, not billions)
    # and the compute-optimal split of that same budget.
    C = to_flops(N * Bn, D * Bn)
    N_opt = n_opt(C)
    D_opt = d_opt(C)

    # Chosen model size relative to the compute-optimal one, and the matching
    # data multiplier from Harm's law.
    kn = Bn * N / N_opt
    kd = compute_kd(kn)
    overhead_pct = 100 * compute_overhead(kn, kd)

    fig = plot_curve(kn, kd)

    # Hardware part: effective throughput is the peak FLOP/s times utilization.
    peak_flops = H100_flops if gpu_type == "H100" else A100_flops
    gpu_flops = peak_flops * (gpu_util / 100)
    gpu_hours = C / (gpu_flops * 3600)

    # `C` is measured in FLOPs (6 * params * tokens), so label it as such.
    text = f"""\
## Training summary

|Training compute| Training cost | Training time | Total GPU hours |
|:----|:-------|:-------|:-------|
|{C:.2E} FLOPs | ${(gpu_hours * gpu_price)/1e6:.2f}M | {gpu_hours/(24*n_gpus):.2f} days | {gpu_hours/1_000_000:.2f}M |

## Chinchilla and Training/Inference Trade-off
Optimal model/dataset size for training compute and how it translates to training overhead and inference savings according to Harm's law
|Chinchilla optimal model | Chinchilla optimal dataset | Training overhead | Inference savings|
|:----|:-------|:----|:-------|
| {N_opt/Bn:.2f}B parameters | {D_opt/Bn:.2f}B tokens | {overhead_pct:.2f}%| {100 - kn*100:.2f}% |
"""

    return text, fig
|
|
|
with gr.Blocks() as demo:
    gr.Markdown("# Train LLMs")

    gr.Markdown("## Training configuration")
    with gr.Row():
        N = gr.Number(value=7, label="Model size (in B parameters):")
        D = gr.Number(value=2000, label="Dataset size (in B tokens):")

    gr.Markdown("## Cluster configuration")
    with gr.Row():
        n_gpus = gr.Number(value=1000, label="Number of GPUs")
        gpu_type = gr.Dropdown(choices=["A100", "H100"], value="H100", label="GPU type")
        gpu_util = gr.Number(value=50, label="% GPU utilization")
        gpu_price = gr.Number(value=3.00, label="$/GPU/Hour")
    button = gr.Button("Compute!")

    with gr.Row():
        with gr.Column():
            gr.Markdown("## Harm's law")
            # Start empty. Passing the pyplot module as the initial value made Gradio
            # render (and register) pyplot's implicit global figure at build time.
            # `demo.load` below fills the plot as soon as the page opens.
            plot = gr.Plot()
            gr.Markdown(HARM_INTRO)

        with gr.Column():
            md = gr.Markdown("")

    # Same callback wiring for the button and for the initial page load.
    inputs = [N, D, gpu_type, gpu_util, n_gpus, gpu_price]
    outputs = [md, plot]
    button.click(fn=compute, inputs=inputs, outputs=outputs)
    demo.load(fn=compute, inputs=inputs, outputs=outputs)

demo.launch()