sooks commited on
Commit
2e6b06f
1 Parent(s): 270ff3e

Create utils.py

Browse files
Files changed (1) hide show
  1. detector/utils.py +62 -0
detector/utils.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ from functools import reduce
3
+
4
+ from torch import nn
5
+ import torch.distributed as dist
6
+
7
+
8
+ def summary(model: nn.Module, file=sys.stdout):
9
+ def repr(model):
10
+ # We treat the extra repr like the sub-module, one item per line
11
+ extra_lines = []
12
+ extra_repr = model.extra_repr()
13
+ # empty string will be split into list ['']
14
+ if extra_repr:
15
+ extra_lines = extra_repr.split('\n')
16
+ child_lines = []
17
+ total_params = 0
18
+ for key, module in model._modules.items():
19
+ mod_str, num_params = repr(module)
20
+ mod_str = nn.modules.module._addindent(mod_str, 2)
21
+ child_lines.append('(' + key + '): ' + mod_str)
22
+ total_params += num_params
23
+ lines = extra_lines + child_lines
24
+
25
+ for name, p in model._parameters.items():
26
+ if hasattr(p, 'shape'):
27
+ total_params += reduce(lambda x, y: x * y, p.shape)
28
+
29
+ main_str = model._get_name() + '('
30
+ if lines:
31
+ # simple one-liner info, which most builtin Modules will use
32
+ if len(extra_lines) == 1 and not child_lines:
33
+ main_str += extra_lines[0]
34
+ else:
35
+ main_str += '\n ' + '\n '.join(lines) + '\n'
36
+
37
+ main_str += ')'
38
+ if file is sys.stdout:
39
+ main_str += ', \033[92m{:,}\033[0m params'.format(total_params)
40
+ else:
41
+ main_str += ', {:,} params'.format(total_params)
42
+ return main_str, total_params
43
+
44
+ string, count = repr(model)
45
+ if file is not None:
46
+ if isinstance(file, str):
47
+ file = open(file, 'w')
48
+ print(string, file=file)
49
+ file.flush()
50
+
51
+ return count
52
+
53
+
54
+ def grad_norm(model: nn.Module):
55
+ total_norm = 0
56
+ for p in model.parameters():
57
+ param_norm = p.grad.data.norm(2)
58
+ total_norm += param_norm.item() ** 2
59
+ return total_norm ** 0.5
60
+
61
+ def distributed():
62
+ return dist.is_available() and dist.is_initialized()