qgallouedec HF staff commited on
Commit
addbb37
·
verified ·
1 Parent(s): 008f6fd

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -0
app.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import matplotlib.pyplot as plt
3
+
4
+
5
+ def plot_forecast(num_param, batch_size, precision, seq_len):
6
+ # Convert number (input as B)
7
+ num_param = float(num_param) * 1e9
8
+
9
+ # Convert precision to bytes
10
+ precision = {"float32": 4, "float16": 2, "bfloat16": 2}[precision]
11
+
12
+ # Model Parameters: N×precision
13
+ y1 = num_param * precision / (1024**3)
14
+
15
+ # Optimizer States: 2×N×precision
16
+ y2 = 2 * num_param * precision / (1024**3)
17
+
18
+ # Activations: B×Sequence Length×K×precision
19
+ K = 4.6894e-04 * num_param + 1.8494e06
20
+ y3 = batch_size * seq_len * K * precision / (1024**3)
21
+
22
+ # Gradients: N×precision
23
+ y4 = num_param * precision / (1024**3)
24
+
25
+ fig = plt.figure(figsize=(4, 4))
26
+ ax = fig.add_subplot(111)
27
+
28
+ # Create stacked bars
29
+ ax.bar(0, y1, color="r")
30
+ ax.bar(0, y2, bottom=y1, color="b")
31
+ ax.bar(0, y3, bottom=y1 + y2, color="g")
32
+ ax.bar(0, y4, bottom=y1 + y2 + y3, color="y")
33
+
34
+ # Add text labels inside the bars
35
+ ax.text(0, y1 / 2, "Model Parameters", ha="center", va="center", color="white", fontweight="bold")
36
+ ax.text(0, y1 + y2 / 2, "Optimizer States", ha="center", va="center", color="white", fontweight="bold")
37
+ ax.text(0, y1 + y2 + y3 / 2, "Activations", ha="center", va="center", color="white", fontweight="bold")
38
+ ax.text(0, y1 + y2 + y3 + y4 / 2, "Gradients", ha="center", va="center", color="white", fontweight="bold")
39
+
40
+ # remove x axis
41
+ ax.xaxis.set_visible(False)
42
+
43
+ # Set GB as the unit for the y-axis
44
+ ax.set_ylabel("Memory (GB)")
45
+ fig.tight_layout()
46
+ return fig
47
+
48
+
49
+ demo = gr.Interface(
50
+ plot_forecast,
51
+ [
52
+ gr.Number(7, label="Number of parameters (B)"),
53
+ gr.Radio([1, 2, 4, 8, 16, 32, 64, 128], value=8, label="Batch size"),
54
+ gr.Radio(["float32", "float16", "bfloat16"], value="float32", label="Precision"),
55
+ gr.Slider(1, 1024, label="Sequence Length", step=1, value=128),
56
+ ],
57
+ gr.Plot(label="forecast", format="png"),
58
+ )
59
+
60
+ if __name__ == "__main__":
61
+ demo.launch()