Vokturz commited on
Commit
ae0011e
1 Parent(s): 563f0ef

improve the way results are displayed

Browse files
Files changed (1) hide show
  1. src/app.py +27 -1
src/app.py CHANGED
@@ -22,6 +22,24 @@ def get_gpu_specs():
22
  return pd.read_csv("data/gpu_specs.csv")
23
 
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  def get_name(index):
27
  row = gpu_specs.iloc[index]
@@ -93,10 +111,17 @@ _memory_table = _memory_table.apply(np.ceil).astype(int).drop(columns=['Paramete
93
  _memory_table.columns = ['Inference', 'Full Training Adam', 'LoRa Fine-tuning']
94
  _memory_table = _memory_table.stack().reset_index()
95
  _memory_table.columns = ['dtype', 'Variable', 'Number of GPUs']
96
-
97
  col1, col2 = st.columns([1,1.3])
98
  with col1:
99
  st.write(f"#### [{model_name}](https://huggingface.co/{model_name}) ({memory_table.iloc[3,0]:.1f}B)")
 
 
 
 
 
 
 
100
  st.write(memory_table.iloc[[0, 1, 2, 4]])
101
  with col2:
102
  num_colors= 4
@@ -106,3 +131,4 @@ with col2:
106
  , xaxis_tickfont_size=14, yaxis_tickfont_size=16, yaxis_dtick='1')
107
  st.plotly_chart(fig, use_container_width=True)
108
 
 
 
22
  return pd.read_csv("data/gpu_specs.csv")
23
 
24
 
25
+ def show_gpu_info(info, trainable_params=0):
26
+ for var in ['Inference', 'Full Training Adam', 'LoRa Fine-tuning']:
27
+ _info = info.loc[var]
28
+ if _info['Number of GPUs'] >= 3:
29
+ func = st.error
30
+ icon = "⛔"
31
+ elif _info['Number of GPUs'] == 2:
32
+ func = st.warning
33
+ icon = "⚠️"
34
+ else:
35
+ func = st.success
36
+ icon = "✅"
37
+
38
+ msg = f"You require **{_info['Number of GPUs']}** GPUs for **{var}**"
39
+ if var == 'LoRa Fine-tuning':
40
+ msg += f" ({trainable_params}%)"
41
+ func(msg, icon=icon)
42
+
43
 
44
  def get_name(index):
45
  row = gpu_specs.iloc[index]
 
111
  _memory_table.columns = ['Inference', 'Full Training Adam', 'LoRa Fine-tuning']
112
  _memory_table = _memory_table.stack().reset_index()
113
  _memory_table.columns = ['dtype', 'Variable', 'Number of GPUs']
114
+ st.write()
115
  col1, col2 = st.columns([1,1.3])
116
  with col1:
117
  st.write(f"#### [{model_name}](https://huggingface.co/{model_name}) ({memory_table.iloc[3,0]:.1f}B)")
118
+
119
+ dtypes = memory_table.columns.tolist()[::-1]
120
+ tabs = st.tabs(dtypes)
121
+ for dtype, tab in zip(dtypes, tabs):
122
+ with tab:
123
+ info = _memory_table[_memory_table['dtype'] == dtype].set_index('Variable')
124
+ show_gpu_info(info, lora_pct)
125
  st.write(memory_table.iloc[[0, 1, 2, 4]])
126
  with col2:
127
  num_colors= 4
 
131
  , xaxis_tickfont_size=14, yaxis_tickfont_size=16, yaxis_dtick='1')
132
  st.plotly_chart(fig, use_container_width=True)
133
 
134
+