konradhugging committed on
Commit 562e37d · verified · 1 Parent(s): 2dbaced

Upload app.py with huggingface_hub

Files changed (1)
  1. app.py +420 -0
app.py ADDED
@@ -0,0 +1,420 @@
import gradio as gr

# import shutil
# from pathlib import Path
# current_dir = Path(__file__).parent

header = """
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/twitter-bootstrap/4.1.3/css/bootstrap.min.css">
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.11.2/css/all.min.css">
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome-animation/0.2.1/font-awesome-animation.min.css">
"""

css = """

body
{
  background: #000e29;
}

.alert>.start-icon {
  margin-right: 0;
  min-width: 20px;
  text-align: center;
}

.alert>.start-icon {
  margin-right: 5px;
}

.greencross
{
  font-size:18px;
  color: #25ff0b;
  text-shadow: none;
}

.alert-simple.alert-success
{
  border: 1px solid rgba(36, 241, 6, 0.46);
  background-color: rgba(7, 149, 66, 0.12156862745098039);
  box-shadow: 0px 0px 2px #259c08;
  color: #0ad406;
  text-shadow: 2px 1px #00040a;
  transition:0.5s;
  cursor:pointer;
}
.alert-success:hover{
  background-color: rgba(7, 149, 66, 0.35);
  transition:0.5s;
}
.alert-simple.alert-info
{
  border: 1px solid rgba(6, 44, 241, 0.46);
  background-color: rgba(7, 73, 149, 0.12156862745098039);
  box-shadow: 0px 0px 2px #0396ff;
  color: #0396ff;
  text-shadow: 2px 1px #00040a;
  transition:0.5s;
  cursor:pointer;
}

.alert-info:hover
{
  background-color: rgba(7, 73, 149, 0.35);
  transition:0.5s;
}

.blue-cross
{
  font-size: 18px;
  color: #0bd2ff;
  text-shadow: none;
}

.alert-simple.alert-warning
{
  border: 1px solid rgba(241, 142, 6, 0.81);
  background-color: rgba(220, 128, 1, 0.16);
  box-shadow: 0px 0px 2px #ffb103;
  color: #ffb103;
  text-shadow: 2px 1px #00040a;
  transition:0.5s;
  cursor:pointer;
}

.alert-warning:hover{
  background-color: rgba(220, 128, 1, 0.33);
  transition:0.5s;
}

.warning
{
  font-size: 18px;
  color: #ffb40b;
  text-shadow: none;
}

.alert-simple.alert-danger
{
  border: 1px solid rgba(241, 6, 6, 0.81);
  background-color: rgba(220, 17, 1, 0.16);
  box-shadow: 0px 0px 2px #ff0303;
  color: #ff0303;
  text-shadow: 2px 1px #00040a;
  transition:0.5s;
  cursor:pointer;
}

.alert-danger:hover
{
  background-color: rgba(220, 17, 1, 0.33);
  transition:0.5s;
}

.danger
{
  font-size: 18px;
  color: #ff0303;
  text-shadow: none;
}

.alert-simple.alert-primary
{
  border: 1px solid rgba(6, 241, 226, 0.81);
  background-color: rgba(1, 204, 220, 0.16);
  box-shadow: 0px 0px 2px #03fff5;
  color: #03d0ff;
  text-shadow: 2px 1px #00040a;
  transition:0.5s;
  cursor:pointer;
}

.alert-primary:hover{
  background-color: rgba(1, 204, 220, 0.33);
  transition:0.5s;
}

.alertprimary
{
  font-size: 18px;
  color: #03d0ff;
  text-shadow: none;
}

.square_box {
  position: absolute;
  -webkit-transform: rotate(45deg);
  -ms-transform: rotate(45deg);
  transform: rotate(45deg);
  border-top-left-radius: 45px;
  opacity: 0.302;
}

.square_box.box_three {
  background-image: -moz-linear-gradient(-90deg, #290a59 0%, #3d57f4 100%);
  background-image: -webkit-linear-gradient(-90deg, #290a59 0%, #3d57f4 100%);
  background-image: -ms-linear-gradient(-90deg, #290a59 0%, #3d57f4 100%);
  opacity: 0.059;
  left: -80px;
  top: -60px;
  width: 500px;
  height: 500px;
  border-radius: 45px;
}

.square_box.box_four {
  background-image: -moz-linear-gradient(-90deg, #290a59 0%, #3d57f4 100%);
  background-image: -webkit-linear-gradient(-90deg, #290a59 0%, #3d57f4 100%);
  background-image: -ms-linear-gradient(-90deg, #290a59 0%, #3d57f4 100%);
  opacity: 0.059;
  left: 150px;
  top: -25px;
  width: 550px;
  height: 550px;
  border-radius: 45px;
}

.alert:before {
  content: '';
  position: absolute;
  width: 0;
  height: calc(100% - 44px);
  border-left: 1px solid;
  border-right: 2px solid;
  border-bottom-right-radius: 3px;
  border-top-right-radius: 3px;
  left: 0;
  top: 50%;
  transform: translate(0,-50%);
  height: 20px;
}

.fa-times
{
  -webkit-animation: blink-1 2s infinite both;
  animation: blink-1 2s infinite both;
}


/**
 * ----------------------------------------
 * animation blink-1
 * ----------------------------------------
 */
@-webkit-keyframes blink-1 {
  0%,
  50%,
  100% {
    opacity: 1;
  }
  25%,
  75% {
    opacity: 0;
  }
}
@keyframes blink-1 {
  0%,
  50%,
  100% {
    opacity: 1;
  }
  25%,
  75% {
    opacity: 0;
  }
}

/**
Custom CSS for Gradio
*/

"""

info_alert_text = """
<div class="alert fade alert-simple alert-info alert-dismissible text-left font__family-montserrat font__size-16 font__weight-light brk-library-rendered rendered show" role="alert" data-brk-library="component__alert">
  <i class="start-icon fa fa-info-circle faa-shake animated"></i>
  <strong class="font__weight-semibold">Heads up!</strong>
  <p class="font__weight-light">
    The GPU memory estimate above only shows how much memory the model's weights will take on the GPU. It is not the actual memory needed to train the model or use it for inference.
    You can find more information <a href="https://amenalahassa.github.io/amenalahassa/posts/model_gpu.html" target="_blank">here</a>.
  </p>
</div>
"""

error_alert_text = """
<div class="alert fade alert-simple alert-danger alert-dismissible text-left font__family-montserrat font__size-16 font__weight-light brk-library-rendered rendered show" role="alert" data-brk-library="component__alert">
  <i class="start-icon far fa-times-circle faa-pulse animated"></i>
  <strong class="font__weight-semibold">Warning</strong>
  <p class="font__weight-light">{error}</p>
</div>
"""

def get_model_size(num_params, precision):
    # Return the parameter count unchanged: the chosen precision is applied once,
    # in compute_gpu_memory, when the count is converted to bytes, so scaling the
    # count here as well would discount the precision twice.
    return num_params

def get_model_size_from_checkpoint(file, precision):
    from pathlib import Path

    num_params = 0
    error = None
    filepath = Path(file)
    extension = filepath.suffix[1:]

    try:

        if extension in ["pth", "pt"]:
            import torch
            checkpoint = torch.load(file, weights_only=False)

            # If the checkpoint wraps its weights in a "state_dict" entry, unwrap it;
            # otherwise treat the whole checkpoint as the state_dict
            if "state_dict" in checkpoint:
                state_dict = checkpoint["state_dict"]
            else:
                state_dict = checkpoint

            # Calculate the total number of parameters
            # Assuming that the model is composed of multiple children modules/models
            for child in state_dict.values():
                # Check if the entry is a full model/module
                if isinstance(child, torch.nn.Module):
                    # Calculate the number of parameters in the model
                    num_params += sum(p.numel() for p in child.parameters())

            # Calculate the number of parameters of direct children/layers
            for param in state_dict.values():
                # Check if the parameter has the attribute `numel`
                if hasattr(param, "numel"):
                    num_params += param.numel()

        elif extension in ["h5", "hdf5"]:
            from tensorflow.keras.models import load_model

            model = load_model(file)
            model.compile()
            # Calculate the total number of parameters
            num_params = model.count_params()

        elif extension in ["onnx"]:
            import onnx
            from onnx import numpy_helper

            model = onnx.load(file)
            num_params = sum([numpy_helper.to_array(tensor).size for tensor in model.graph.initializer])

        else:
            error = "Unsupported file format. Please upload a PyTorch/Keras/ONNX model checkpoint."

    except Exception as e:
        error = str(e)

    if num_params == 0 and error is None:
        error = "No parameters found in the model checkpoint"

    return get_model_size(num_params, precision), error

def get_model_size_from_hf(model_name, precision):
    from transformers import AutoModel
    num_params = 0
    error = None

    try:
        model = AutoModel.from_pretrained(model_name)
        num_params = sum(param.numel() for param in model.parameters())
    except Exception as e:
        error = str(e)

    return get_model_size(num_params, precision), error

def compute_gpu_memory(input_model_size, model_precision):
    P = input_model_size
    # Bits per parameter for the selected precision.
    Q = 32 if model_precision == "fp32" else 16 if model_precision == "fp16" else 8 if model_precision == "int8" else 4
    # Parameters * bits per parameter / 8 gives bytes; divide by 1024**3 for GB.
    memory = P * Q / 8 / 1024 / 1024 / 1024
    return [f"{memory} GB", memory > 0]
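# Worked example of the arithmetic above (illustrative figures only, not used by
# the app): an assumed 7-billion-parameter model in fp16 needs 7e9 * 16 bits / 8
# = 14e9 bytes, i.e. roughly 13 GB of GPU memory for the weights alone.
_example_gb, _ = compute_gpu_memory(7_000_000_000, "fp16")  # ≈ "13.04 GB"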

# def delete_directory(req: gr.Request):
#     if not req.username:
#         return
#     user_dir: Path = current_dir / req.username
#     shutil.rmtree(str(user_dir))

with gr.Blocks(head=header, css=css, delete_cache=(43200, 43200)) as demo:
    model_precision = gr.State("fp32")
    model_source = gr.State("import")
    uploaded_file = gr.State(None)
    hf_model_name = gr.State()
    msg_error = gr.State()
    supported_file_types = ["pt", "pth", "h5", "hdf5", "onnx"]
    has_computed_gpu_memory = gr.State(False)

    gr.Markdown(
        """
        # Wondering how much memory your model will take?
        This app helps you estimate the memory usage of a model on GPU.
        """
    )

    checkpoint_radio = gr.Radio(
        [("Import model checkpoint", "import"), ("Use model from Hugging Face", "hf")],
        value="import",
        label="Choose a model source"
    )

    checkpoint_radio.change(fn=lambda x: x, inputs=checkpoint_radio, outputs=model_source)

    @gr.render(inputs=[model_source, msg_error])
    def rendering(source, runtime_error):

        with gr.Row():
            with gr.Column():
                if source == "import":
                    gr.Markdown("Upload a model checkpoint file. Supported formats are PyTorch, Keras, and ONNX.")
                    uploader = gr.File(label=f'Upload Model Checkpoint [{" | ".join(supported_file_types)}]', file_types=supported_file_types, file_count="single", type="filepath")
                    uploader.upload(fn=lambda x: x, inputs=uploader, outputs=uploaded_file)
                else:
                    model_name_textbox = gr.Textbox(label="Model Name", placeholder="e.g. facebook/bart-large")
                    model_name_textbox.change(fn=lambda x: x, inputs=model_name_textbox, outputs=hf_model_name)

                precision_radio = gr.Radio(
                    [
                        ("FP32 (32-bit floating point)", "fp32"),
                        ("FP16 (half/BF16) (16-bit floating point)", "fp16"),
                        ("INT8 (8-bit integer)", "int8"),
                        ("INT4 (4-bit integer)", "int4"),
                    ],
                    value=model_precision.value,
                    label="Select the precision (size) of the model parameters"
                )
                precision_radio.change(fn=lambda x: x, inputs=precision_radio, outputs=model_precision)
                compute_btn = gr.Button("Compute")

            with gr.Column():

                num_params = gr.Number(label="Number of Parameters")
                gpu_memory = gr.Textbox(label="GPU memory expressed in gigabytes (GB)", show_copy_button=True)
                num_params.change(compute_gpu_memory, inputs=[num_params, model_precision], outputs=[gpu_memory, has_computed_gpu_memory])

                if runtime_error:
                    gr.HTML(error_alert_text.format(error=runtime_error))

                info = gr.HTML(info_alert_text, visible=False)
                gpu_memory.change(fn=lambda x: gr.HTML(info_alert_text, visible=x != "0.0 GB"), inputs=gpu_memory, outputs=info)

        def compute_model_size(input_source, input_precision, input_file, input_hf_model):
            if input_source == "import":
                model_size, error = get_model_size_from_checkpoint(input_file, input_precision)
            else:
                model_size, error = get_model_size_from_hf(input_hf_model, input_precision)

            return [model_size, error]

        compute_btn.click(compute_model_size, inputs=[model_source, model_precision, uploaded_file, hf_model_name], outputs=[num_params, msg_error])

    # demo.unload(delete_directory)

demo.launch()
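The info alert in the app stresses that the figure covers only the stored weights. As a quick illustration of that caveat, here is a minimal sketch, separate from app.py and assuming PyTorch with a CUDA device is available (the torch.nn.Linear layer is just a stand-in for a real checkpoint), that reproduces the same parameters-times-bytes arithmetic and compares it with what PyTorch actually allocates once the model sits on the GPU:

import torch

def weights_only_estimate_gb(num_params: int, bytes_per_param: float) -> float:
    # Same arithmetic as compute_gpu_memory: parameters * bytes per parameter, in GB.
    return num_params * bytes_per_param / 1024 ** 3

model = torch.nn.Linear(4096, 4096)  # toy stand-in, ~16.8M fp32 parameters
num_params = sum(p.numel() for p in model.parameters())
print(f"weights-only estimate: {weights_only_estimate_gb(num_params, 4):.3f} GB")

if torch.cuda.is_available():
    model = model.cuda()
    torch.cuda.synchronize()
    # Allocated memory covers the weights plus allocator overhead; gradients,
    # optimizer state and activations used during training are not part of the estimate.
    print(f"actually allocated: {torch.cuda.memory_allocated() / 1024 ** 3:.3f} GB")

For a bare model the allocated figure tracks the estimate closely, while training the same model can need several times more memory, which is exactly the caveat the alert links to.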