Upload app.py with huggingface_hub
app.py
ADDED
@@ -0,0 +1,420 @@
import gradio as gr

# import shutil
# from pathlib import Path
# current_dir = Path(__file__).parent

header = """
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/twitter-bootstrap/4.1.3/css/bootstrap.min.css">
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.11.2/css/all.min.css">
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome-animation/0.2.1/font-awesome-animation.min.css">
"""

css = """

body
{
  background: #000e29;
}

.alert>.start-icon {
  margin-right: 0;
  min-width: 20px;
  text-align: center;
}

.alert>.start-icon {
  margin-right: 5px;
}

.greencross
{
  font-size: 18px;
  color: #25ff0b;
  text-shadow: none;
}

.alert-simple.alert-success
{
  border: 1px solid rgba(36, 241, 6, 0.46);
  background-color: rgba(7, 149, 66, 0.12156862745098039);
  box-shadow: 0px 0px 2px #259c08;
  color: #0ad406;
  text-shadow: 2px 1px #00040a;
  transition: 0.5s;
  cursor: pointer;
}

.alert-success:hover {
  background-color: rgba(7, 149, 66, 0.35);
  transition: 0.5s;
}

.alert-simple.alert-info
{
  border: 1px solid rgba(6, 44, 241, 0.46);
  background-color: rgba(7, 73, 149, 0.12156862745098039);
  box-shadow: 0px 0px 2px #0396ff;
  color: #0396ff;
  text-shadow: 2px 1px #00040a;
  transition: 0.5s;
  cursor: pointer;
}

.alert-info:hover
{
  background-color: rgba(7, 73, 149, 0.35);
  transition: 0.5s;
}

.blue-cross
{
  font-size: 18px;
  color: #0bd2ff;
  text-shadow: none;
}

.alert-simple.alert-warning
{
  border: 1px solid rgba(241, 142, 6, 0.81);
  background-color: rgba(220, 128, 1, 0.16);
  box-shadow: 0px 0px 2px #ffb103;
  color: #ffb103;
  text-shadow: 2px 1px #00040a;
  transition: 0.5s;
  cursor: pointer;
}

.alert-warning:hover {
  background-color: rgba(220, 128, 1, 0.33);
  transition: 0.5s;
}

.warning
{
  font-size: 18px;
  color: #ffb40b;
  text-shadow: none;
}

.alert-simple.alert-danger
{
  border: 1px solid rgba(241, 6, 6, 0.81);
  background-color: rgba(220, 17, 1, 0.16);
  box-shadow: 0px 0px 2px #ff0303;
  color: #ff0303;
  text-shadow: 2px 1px #00040a;
  transition: 0.5s;
  cursor: pointer;
}

.alert-danger:hover
{
  background-color: rgba(220, 17, 1, 0.33);
  transition: 0.5s;
}

.danger
{
  font-size: 18px;
  color: #ff0303;
  text-shadow: none;
}

.alert-simple.alert-primary
{
  border: 1px solid rgba(6, 241, 226, 0.81);
  background-color: rgba(1, 204, 220, 0.16);
  box-shadow: 0px 0px 2px #03fff5;
  color: #03d0ff;
  text-shadow: 2px 1px #00040a;
  transition: 0.5s;
  cursor: pointer;
}

.alert-primary:hover {
  background-color: rgba(1, 204, 220, 0.33);
  transition: 0.5s;
}

.alertprimary
{
  font-size: 18px;
  color: #03d0ff;
  text-shadow: none;
}

.square_box {
  position: absolute;
  -webkit-transform: rotate(45deg);
  -ms-transform: rotate(45deg);
  transform: rotate(45deg);
  border-top-left-radius: 45px;
  opacity: 0.302;
}

.square_box.box_three {
  background-image: -moz-linear-gradient(-90deg, #290a59 0%, #3d57f4 100%);
  background-image: -webkit-linear-gradient(-90deg, #290a59 0%, #3d57f4 100%);
  background-image: -ms-linear-gradient(-90deg, #290a59 0%, #3d57f4 100%);
  opacity: 0.059;
  left: -80px;
  top: -60px;
  width: 500px;
  height: 500px;
  border-radius: 45px;
}

.square_box.box_four {
  background-image: -moz-linear-gradient(-90deg, #290a59 0%, #3d57f4 100%);
  background-image: -webkit-linear-gradient(-90deg, #290a59 0%, #3d57f4 100%);
  background-image: -ms-linear-gradient(-90deg, #290a59 0%, #3d57f4 100%);
  opacity: 0.059;
  left: 150px;
  top: -25px;
  width: 550px;
  height: 550px;
  border-radius: 45px;
}

.alert:before {
  content: '';
  position: absolute;
  width: 0;
  height: calc(100% - 44px);
  border-left: 1px solid;
  border-right: 2px solid;
  border-bottom-right-radius: 3px;
  border-top-right-radius: 3px;
  left: 0;
  top: 50%;
  transform: translate(0, -50%);
  height: 20px;
}

.fa-times
{
  -webkit-animation: blink-1 2s infinite both;
  animation: blink-1 2s infinite both;
}

/**
 * ----------------------------------------
 * animation blink-1
 * ----------------------------------------
 */
@-webkit-keyframes blink-1 {
  0%,
  50%,
  100% {
    opacity: 1;
  }
  25%,
  75% {
    opacity: 0;
  }
}
@keyframes blink-1 {
  0%,
  50%,
  100% {
    opacity: 1;
  }
  25%,
  75% {
    opacity: 0;
  }
}

/**
Custom CSS for Gradio
*/

"""

info_alert_text = """
<div class="alert fade alert-simple alert-info alert-dismissible text-left font__family-montserrat font__size-16 font__weight-light brk-library-rendered rendered show" role="alert" data-brk-library="component__alert">
  <i class="start-icon fa fa-info-circle faa-shake animated"></i>
  <strong class="font__weight-semibold">Heads up!</strong>
  <p class="font__weight-light">
    The GPU memory estimate above only shows how much memory the model weights will take on the GPU. It is not the total memory needed to train the model or run inference with it.
    You can find more information <a href="https://amenalahassa.github.io/amenalahassa/posts/model_gpu.html" target="_blank">here</a>.
  </p>
</div>
"""

error_alert_text = """
<div class="alert fade alert-simple alert-danger alert-dismissible text-left font__family-montserrat font__size-16 font__weight-light brk-library-rendered rendered show" role="alert" data-brk-library="component__alert">
  <i class="start-icon far fa-times-circle faa-pulse animated"></i>
  <strong class="font__weight-semibold">Warning</strong>
  <p class="font__weight-light">{error}</p>
</div>
"""
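
# Note: the weights-only figure the app reports is a lower bound. Training
# typically needs several times more memory, since gradients, optimizer states
# (e.g. Adam keeps two extra tensors per parameter), and activations all live
# on the GPU as well.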

def get_model_size_from_checkpoint(file):
    """Count the parameters stored in an uploaded checkpoint file."""
    from pathlib import Path

    num_params = 0
    error = None
    filepath = Path(file)
    extension = filepath.suffix[1:]

    try:
        if extension in ["pth", "pt"]:
            import torch

            checkpoint = torch.load(file, weights_only=False)

            # If the checkpoint wraps its weights under a "state_dict" key, unwrap it
            if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
                state_dict = checkpoint["state_dict"]
            else:
                state_dict = checkpoint

            # Sum parameters over the entries: a nested module contributes its
            # own parameters, a plain tensor its element count
            for value in state_dict.values():
                if isinstance(value, torch.nn.Module):
                    num_params += sum(p.numel() for p in value.parameters())
                elif hasattr(value, "numel"):
                    num_params += value.numel()

        elif extension in ["h5", "hdf5"]:
            from tensorflow.keras.models import load_model

            model = load_model(file)
            num_params = model.count_params()

        elif extension in ["onnx"]:
            import onnx
            from onnx import numpy_helper

            model = onnx.load(file)
            num_params = sum(numpy_helper.to_array(tensor).size for tensor in model.graph.initializer)

        else:
            error = "Unsupported file format. Please upload a PyTorch/Keras/ONNX model checkpoint."

    except Exception as e:
        error = str(e)

    if num_params == 0 and error is None:
        error = "No parameters found in the model checkpoint"

    # Return the raw parameter count; the chosen precision is applied exactly
    # once, in compute_gpu_memory, rather than scaling the count here as well.
    return num_params, error
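
# A quick sanity check (hypothetical file name): a checkpoint holding a single
# Linear(10, 10) layer has a 10x10 weight plus a 10-element bias, so it should
# report 110 parameters:
#   torch.save(torch.nn.Linear(10, 10).state_dict(), "tiny.pt")
#   get_model_size_from_checkpoint("tiny.pt")  # -> (110, None)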

def get_model_size_from_hf(model_name):
    """Count the parameters of a model fetched from the Hugging Face Hub."""
    from transformers import AutoModel

    num_params = 0
    error = None

    try:
        model = AutoModel.from_pretrained(model_name)
        num_params = sum(param.numel() for param in model.parameters())
    except Exception as e:
        error = str(e)

    # As above, return the raw count; precision is handled in compute_gpu_memory.
    return num_params, error
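
# For example, get_model_size_from_hf("facebook/bart-large") should report a
# count on the order of 4e8 parameters (assuming the download succeeds).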

def compute_gpu_memory(input_model_size, model_precision):
    # memory in bytes = number of parameters (P) * bits per parameter (Q) / 8
    P = input_model_size
    Q = 32 if model_precision == "fp32" else 16 if model_precision == "fp16" else 8 if model_precision == "int8" else 4
    memory = P * Q / 8 / 1024 / 1024 / 1024
    return [f"{memory} GB", memory > 0]
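
# Worked example: a 7-billion-parameter model at fp16 (16 bits = 2 bytes per
# parameter) gives 7e9 * 16 / 8 / 1024**3 ~= 13.04 GB (binary gigabytes).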

# def delete_directory(req: gr.Request):
#     if not req.username:
#         return
#     user_dir: Path = current_dir / req.username
#     shutil.rmtree(str(user_dir))

with gr.Blocks(head=header, css=css, delete_cache=(43200, 43200)) as demo:
    model_precision = gr.State("fp32")
    model_source = gr.State("import")
    uploaded_file = gr.State(None)
    hf_model_name = gr.State()
    msg_error = gr.State()
    supported_file_types = ["pt", "pth", "h5", "hdf5", "onnx"]
    has_computed_gpu_memory = gr.State(False)

    gr.Markdown(
        """
        # Wondering how much memory your model will take?
        This app helps you estimate the GPU memory usage of a model.
        """
    )

    checkpoint_radio = gr.Radio(
        [("Import model checkpoint", "import"), ("Use model from Hugging Face", "hf")],
        value="import",
        label="Choose a model source",
    )

    checkpoint_radio.change(fn=lambda x: x, inputs=checkpoint_radio, outputs=model_source)

    @gr.render(inputs=[model_source, msg_error])
    def rendering(source, runtime_error):

        with gr.Row():
            with gr.Column():
                if source == "import":
                    gr.Markdown("Upload a model checkpoint file. Supported formats are PyTorch, Keras, and ONNX.")
                    uploader = gr.File(
                        label=f'Upload Model Checkpoint [{" | ".join(supported_file_types)}]',
                        file_types=[f".{ext}" for ext in supported_file_types],
                        file_count="single",
                        type="filepath",
                    )
                    uploader.upload(fn=lambda x: x, inputs=uploader, outputs=uploaded_file)
                else:
                    model_name_textbox = gr.Textbox(label="Model Name", placeholder="e.g. facebook/bart-large")
                    model_name_textbox.change(fn=lambda x: x, inputs=model_name_textbox, outputs=hf_model_name)

                precision_radio = gr.Radio(
                    [
                        ("FP32 (32-bit floating point)", "fp32"),
                        ("FP16/BF16 (16-bit floating point)", "fp16"),
                        ("INT8 (8-bit integer)", "int8"),
                        ("INT4 (4-bit integer)", "int4"),
                    ],
                    value=model_precision.value,
                    label="Select the precision of the model parameters",
                )
                precision_radio.change(fn=lambda x: x, inputs=precision_radio, outputs=model_precision)
                compute_btn = gr.Button("Compute")

            with gr.Column():

                num_params = gr.Number(label="Number of Parameters")
                gpu_memory = gr.Textbox(label="GPU memory expressed in gigabytes (GB)", show_copy_button=True)
                num_params.change(compute_gpu_memory, inputs=[num_params, model_precision], outputs=[gpu_memory, has_computed_gpu_memory])

                if runtime_error:
                    gr.HTML(error_alert_text.format(error=runtime_error))

                # Show the info alert only once a non-zero estimate has been computed
                info = gr.HTML(info_alert_text, visible=False)
                gpu_memory.change(
                    fn=lambda x: gr.HTML(info_alert_text, visible=x != "0.0 GB"),
                    inputs=gpu_memory,
                    outputs=info,
                )

        def compute_model_size(input_source, input_file, input_hf_model):
            if input_source == "import":
                model_size, error = get_model_size_from_checkpoint(input_file)
            else:
                model_size, error = get_model_size_from_hf(input_hf_model)

            return [model_size, error]

        compute_btn.click(compute_model_size, inputs=[model_source, uploaded_file, hf_model_name], outputs=[num_params, msg_error])

# demo.unload(delete_directory)

demo.launch()
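
A quick offline check of the memory formula above (a minimal sketch; the helper name estimate_gb and the asserted values are illustrative, following the binary-gigabyte convention of compute_gpu_memory, which divides by 1024**3):

    # Standalone reimplementation of the precision-to-bits mapping, for testing.
    def estimate_gb(num_params, precision):
        bits = {"fp32": 32, "fp16": 16, "int8": 8, "int4": 4}[precision]
        return num_params * bits / 8 / 1024**3

    assert round(estimate_gb(7_000_000_000, "fp16"), 2) == 13.04
    assert round(estimate_gb(7_000_000_000, "int4"), 2) == 3.26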