qgallouedec HF staff commited on
Commit
5ee5935
·
verified ·
1 Parent(s): ac8821a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -11
app.py CHANGED
@@ -2,7 +2,7 @@ import gradio as gr
2
  import matplotlib.pyplot as plt
3
 
4
 
5
- def plot_forecast(num_param, precision, batch_size, seq_len):
6
  # Convert number (input as B)
7
  num_param = float(num_param) * 1e9
8
 
@@ -19,6 +19,9 @@ def plot_forecast(num_param, precision, batch_size, seq_len):
19
  K = 4.6894e-4 * num_param + 1.8494e6
20
  y3 = batch_size * seq_len * K * precision / 1e9
21
 
 
 
 
22
  # Gradients: N×precision
23
  y4 = num_param * precision / 1e9
24
 
@@ -86,16 +89,26 @@ def plot_forecast(num_param, precision, batch_size, seq_len):
86
  return fig
87
 
88
 
89
- demo = gr.Interface(
90
- plot_forecast,
91
- [
92
- gr.Number(3, label="Number of parameters (B)"),
93
- gr.Radio(["float32", "float16", "bfloat16"], value="float32", label="Precision"),
94
- gr.Slider(1, 128, label="Batch size", step=1, value=8),
95
- gr.Slider(1, 1000, label="Sequence Length", step=1, value=256),
96
- ],
97
- gr.Plot(label="forecast", format="png"),
98
- )
 
 
 
 
 
 
 
 
 
 
99
 
100
  if __name__ == "__main__":
101
  demo.launch()
 
2
  import matplotlib.pyplot as plt
3
 
4
 
5
+ def plot_forecast(num_param, precision, grad_ckpt, batch_size, seq_len):
6
  # Convert number (input as B)
7
  num_param = float(num_param) * 1e9
8
 
 
19
  K = 4.6894e-4 * num_param + 1.8494e6
20
  y3 = batch_size * seq_len * K * precision / 1e9
21
 
22
+ if grad_ckpt:
23
+ y3 /= 5
24
+
25
  # Gradients: N×precision
26
  y4 = num_param * precision / 1e9
27
 
 
89
  return fig
90
 
91
 
92
+ with gr.Blocks() as demo:
93
+ with gr.Row():
94
+ with gr.Column():
95
+ with gr.Accordion("Model"):
96
+ num_param = gr.Number(3, label="Number of parameters (B)")
97
+ precision = gr.Radio(["float32", "float16", "bfloat16"], value="float32", label="Precision")
98
+ with gr.Accordion("Data"):
99
+ batch_size = gr.Slider(1, 128, label="Batch size", step=1, value=8)
100
+ seq_len = gr.Slider(1, 1000, label="Sequence Length", step=1, value=256)
101
+
102
+ with gr.Accordion("Advanced", open=False):
103
+ with gr.Accordion("Data"):
104
+ grad_ckpt = gr.Checkbox(False, label="Gradient Checkpointing")
105
+
106
+ submit = gr.Button("Submit")
107
+
108
+ with gr.Column():
109
+ plot = gr.Plot(label="forecast", format="png")
110
+
111
+ submit.click(plot_forecast, [num_param, precision, grad_ckpt, batch_size, seq_len], plot)
112
 
113
  if __name__ == "__main__":
114
  demo.launch()