# train-llm / app.py
# leandro
# fix util
# 4f09294
import gradio as gr
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import MultipleLocator
# Markdown explainer rendered below the Harm's-law plot in the UI.
HARM_INTRO = """
The Chinchilla scaling laws focus on optimally scaling training compute, but often we also care about inference cost.
This tool follows [Harm de Vries' blog post](https://www.harmdevries.com/post/model-size-vs-compute-overhead/) and visualizes the tradeoff between training compute and inference cost (i.e. model size).
"""
# --- Per-GPU peak throughput in FLOP/s (presumably dense BF16 spec figures) ---
A100_flops, H100_flops = 312e12, 990e12

# --- Chinchilla scaling-law fit constants ---
# E is not referenced by any calculation in this file; A, B, alpha and beta are.
E, A, B = 1.62, 406.4, 410.7
alpha, beta = 0.336, 0.283

# One billion: converts the UI's "B parameters" / "B tokens" into raw counts.
Bn = 10**9

# Prefactor of the compute-optimal allocation used by n_opt() and d_opt().
G = (alpha * A / (beta * B)) ** (1 / (alpha + beta))
### FUNCTIONS
def to_flops(N, D):
    """Approximate training compute using 6 FLOPs per parameter per token.

    Args:
        N: Number of model parameters.
        D: Number of training tokens.

    Returns:
        Total training FLOPs, i.e. ``6 * N * D``.
    """
    flops_per_param_token = 6
    return flops_per_param_token * N * D
def n_opt(C):
    """Return the Chinchilla compute-optimal parameter count for ``C`` FLOPs.

    Evaluates ``G * (C / 6) ** (beta / (alpha + beta))`` using the
    module-level fit constants.
    """
    exponent = beta / (alpha + beta)
    return G * (C / 6) ** exponent
def d_opt(C):
    """Return the Chinchilla compute-optimal token count for ``C`` FLOPs.

    Evaluates ``(1 / G) * (C / 6) ** (alpha / (alpha + beta))``. Since the
    exponents here and in ``n_opt`` sum to one, ``6 * n_opt(C) * d_opt(C)``
    equals ``C`` (up to rounding).
    """
    exponent = alpha / (alpha + beta)
    inverse_g = 1 / G
    return inverse_g * (C / 6) ** exponent
def compute_kd(kn):
    """Return the dataset multiplier that keeps loss equal to the Chinchilla optimum.

    Following Harm de Vries' derivation, a model with ``kn`` times the
    Chinchilla-optimal parameter count needs ``kd`` times the
    Chinchilla-optimal token count, where
    ``kd ** -beta == 1 - (kn ** -alpha - 1) * frac``.

    Args:
        kn: Fraction of the Chinchilla-optimal model size (scalar or array-like).

    Returns:
        ``kd`` as a Python float for scalar input, or a NumPy array for array
        input. Where no positive real ``kd`` satisfies the relation above
        (the right-hand side is <= 0, which includes ``kn <= 0``), the result
        is NaN. Plain Python float arithmetic would instead produce a complex
        number or raise ZeroDivisionError in these cases.
    """
    # Algebraically equal to beta/alpha, because G**(alpha+beta) == alpha*A/(beta*B).
    frac = (A / B) * (G ** (-alpha - beta))
    kn = np.asarray(kn, dtype=float)
    with np.errstate(divide="ignore", invalid="ignore"):
        base = 1 - (kn ** -alpha - 1) * frac
        # kd**-beta must be positive for a real kd; mask the impossible region.
        base = np.where(base > 0, base, np.nan)
        kd = base ** (1 / (-beta))
    return float(kd) if np.ndim(kd) == 0 else kd
def compute_overhead(kn, kd):
    """Return the extra training compute relative to the Chinchilla optimum.

    Scaling parameters by ``kn`` and tokens by ``kd`` scales the ``6 * N * D``
    compute by ``kn * kd``. The overhead is that ratio minus one, so 0.0 means
    no extra compute and 0.25 means 25% more.
    """
    compute_ratio = kn * kd
    return compute_ratio - 1
### PRECOMPUTE CURVE:
# Training-compute overhead (in %) across model-size fractions, computed once
# at import time and reused by every call to plot_curve().
# compute_kd has no real solution below roughly kn ~ 0.1 with the constants
# above; kn_min presumably sits above that on purpose.
kn_min = 0.18
kn_max = 2
kns = np.linspace(kn_min, kn_max, 100)
# A comprehension keeps the loop variable scoped to the expression instead of
# leaving stray module-level `kn`/`kd` globals behind.
overheads = [compute_overhead(kn, compute_kd(kn)) * 100 for kn in kns]
def plot_curve(kn, kd):
    """Plot the Harm's-law overhead curve and mark the current configuration.

    Args:
        kn: Fraction of the Chinchilla-optimal model size for the current run.
        kd: Matching dataset-size multiplier (as produced by ``compute_kd``).

    Returns:
        A matplotlib ``Figure`` containing:
        - the precomputed ``kns``/``overheads`` curve;
        - a red point at ``(kn, 100 * compute_overhead(kn, kd))``;
        - a blue point for the Chinchilla optimum at (1, 0).
    """
    fig, ax = plt.subplots(dpi=200, figsize=(5, 3))
    # Draw only through this Axes: pyplot's implicit "current axes" is global
    # state that another handler could switch if requests run concurrently.
    ax.plot(kns, overheads, color="black", zorder=1)
    ax.scatter([kn], [compute_overhead(kn, kd) * 100], s=100, marker="o", c="red", label="You are here!", zorder=2)
    ax.scatter([1.0], [0.0], marker="o", s=100, c="blue", label="Chinchilla optimal", zorder=2)
    ax.set_xlabel("Fraction of Chinchilla optimal model size")
    ax.set_ylabel("Compute overhead (%)")
    ax.legend(loc="best")
    ax.grid(True, which="both")
    ax.grid(True, which="minor", alpha=0.5)
    ax.yaxis.set_minor_locator(MultipleLocator(10))
    fig.tight_layout()
    # Unregister from pyplot so one figure per request doesn't accumulate
    # (and leak memory) in a long-running app; the returned Figure object
    # can still be rendered/saved.
    plt.close(fig)
    return fig
def compute(N, D, gpu_type, gpu_util, n_gpus, gpu_price):
    """Estimate training cost/time and the Chinchilla trade-off for one run.

    Args:
        N: Model size in billions of parameters.
        D: Dataset size in billions of tokens.
        gpu_type: "H100" selects H100 peak FLOP/s; any other value uses A100.
        gpu_util: GPU utilization in percent (e.g. 50 for 50%).
        n_gpus: Number of GPUs the GPU-hours are spread across.
        gpu_price: Price in dollars per GPU-hour.

    Returns:
        A ``(markdown, figure)`` tuple: the summary tables and the
        Harm's-law overhead plot.

    Raises:
        gr.Error: If N, D, gpu_util or n_gpus is missing or not positive
            (these would otherwise cause division by zero), or if gpu_price
            is missing or negative.
    """
    for label, value in (
        ("Model size", N),
        ("Dataset size", D),
        ("GPU utilization", gpu_util),
        ("Number of GPUs", n_gpus),
    ):
        if value is None or value <= 0:
            raise gr.Error(f"{label} must be a positive number.")
    if gpu_price is None or gpu_price < 0:
        raise gr.Error("GPU price must be zero or a positive number.")

    # Training compute and the Chinchilla-optimal allocation for that budget.
    C = to_flops(N * Bn, D * Bn)
    N_opt = n_opt(C)
    D_opt = d_opt(C)
    kn = Bn * N / N_opt
    kd = compute_kd(kn)
    overhead_pct = 100 * compute_overhead(kn, kd)
    savings_pct = 100 - kn * 100
    fig = plot_curve(kn, kd)

    # Effective per-GPU throughput: peak FLOP/s scaled by utilization.
    peak_flops = H100_flops if gpu_type == "H100" else A100_flops
    gpu_flops = peak_flops * (gpu_util / 100)
    gpu_hours = C / (gpu_flops * 3600)
    cost_musd = (gpu_hours * gpu_price) / 1e6
    train_days = gpu_hours / (24 * n_gpus)

    # C counts raw floating-point operations (6*N*D), not tera-FLOPs, hence
    # the "FLOPs" unit. Blank lines separate headings, paragraphs and tables
    # so the Markdown renders reliably.
    text = f"""\
## Training summary

|Training compute| Training cost | Training time | Total GPU hours |
|:----|:-------|:-------|:-------|
|{C:.2E} FLOPs | ${cost_musd:.2f}M | {train_days:.2f} days | {gpu_hours/1_000_000:.2f}M |

## Chinchilla and Training/Inference Trade-off

Optimal model/dataset size for training compute and how it translates to training overhead and inference savings according to Harm's law

|Chinchilla optimal model | Chinchilla optimal dataset | Training overhead | Inference savings|
|:----|:-------|:----|:-------|
| {N_opt/Bn:.2f}B parameters | {D_opt/Bn:.2f}B tokens | {overhead_pct:.2f}%| {savings_pct:.2f}% |
"""
    return text, fig
# Gradio UI: model/data inputs, cluster inputs, and the plot + summary outputs.
with gr.Blocks() as demo:
    gr.Markdown("# Train LLMs")
    gr.Markdown("## Training configuration")
    with gr.Row():
        N = gr.Number(value=7, label="Model size (in B parameters):")
        D = gr.Number(value=2000, label="Dataset size (in B tokens):")
    gr.Markdown("## Cluster configuration")
    with gr.Row():
        n_gpus = gr.Number(value=1000, label="Number of GPUs")
        gpu_type = gr.Dropdown(choices=["A100", "H100"], value="H100", label="GPU type")
        gpu_util = gr.Number(value=50, label="% GPU utilization")
        gpu_price = gr.Number(value=3.00, label="$/GPU/Hour")
    button = gr.Button("Compute!")
    with gr.Row():
        with gr.Column():
            gr.Markdown("## Harm's law")
            # No initial value: demo.load() below fills the plot on page load.
            # Passing the pyplot module here rendered pyplot's "current"
            # figure, which creates a stray empty figure.
            plot = gr.Plot()
            gr.Markdown(HARM_INTRO)
        with gr.Column():
            md = gr.Markdown("")
    # Shared wiring so the button and the initial page load stay in sync;
    # the order must match compute()'s signature.
    compute_inputs = [N, D, gpu_type, gpu_util, n_gpus, gpu_price]
    compute_outputs = [md, plot]
    button.click(fn=compute, inputs=compute_inputs, outputs=compute_outputs)
    demo.load(fn=compute, inputs=compute_inputs, outputs=compute_outputs)
demo.launch()