File size: 6,610 Bytes
8b4c6c7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 |
from detectron2.engine import AMPTrainer
import torch
import time
import logging
logger = logging.getLogger("detectron2")
import typing
from collections import defaultdict
import tabulate
from torch import nn
def parameter_count(model: nn.Module, trainable_only: bool = False) -> typing.DefaultDict[str, int]:
    """
    Count parameters of a model and its submodules.

    Every parameter's element count is credited to each dotted prefix of its
    name, so module totals come for free.

    Args:
        model: a torch module
        trainable_only: when True, parameters with ``requires_grad == False``
            are skipped entirely.

    Returns:
        dict (str-> int): the key is either a parameter name or a module name.
        The value is the number of elements in the parameter, or in all
        parameters of the module. The key "" corresponds to the total
        number of parameters of the model.
    """
    counts = defaultdict(int)
    for full_name, param in model.named_parameters():
        if trainable_only and not param.requires_grad:
            continue
        numel = param.numel()
        parts = full_name.split(".")
        # depth 0 -> "" (whole model), depth len(parts) -> the parameter itself.
        for depth in range(len(parts) + 1):
            counts[".".join(parts[:depth])] += numel
    return counts
def parameter_count_table(
    model: nn.Module, max_depth: int = 3, trainable_only: bool = False
) -> str:
    """
    Format the parameter count of the model (and its submodules or parameters)
    in a nice table. It looks like this:

    ::

        | name                            | #elements or shape   |
        |:--------------------------------|:---------------------|
        | model                           | 37.9M                |
        | backbone                        | 31.5M                |
        | backbone.fpn_lateral3           | 0.1M                 |
        | backbone.fpn_lateral3.weight    | (256, 512, 1, 1)     |
        | backbone.fpn_lateral3.bias      | (256,)               |
        | backbone.fpn_output3            | 0.6M                 |
        | backbone.fpn_output3.weight     | (256, 256, 3, 3)     |
        | backbone.fpn_output3.bias       | (256,)               |
        | backbone.fpn_lateral4           | 0.3M                 |
        | backbone.fpn_lateral4.weight    | (256, 1024, 1, 1)    |
        | backbone.fpn_lateral4.bias      | (256,)               |
        | backbone.fpn_output4            | 0.6M                 |
        | backbone.fpn_output4.weight     | (256, 256, 3, 3)     |
        | backbone.fpn_output4.bias       | (256,)               |
        | backbone.fpn_lateral5           | 0.5M                 |
        | backbone.fpn_lateral5.weight    | (256, 2048, 1, 1)    |
        | backbone.fpn_lateral5.bias      | (256,)               |
        | backbone.fpn_output5            | 0.6M                 |
        | backbone.fpn_output5.weight     | (256, 256, 3, 3)     |
        | backbone.fpn_output5.bias       | (256,)               |
        | backbone.top_block              | 5.3M                 |
        | backbone.top_block.p6           | 4.7M                 |
        | backbone.top_block.p7           | 0.6M                 |
        | backbone.bottom_up              | 23.5M                |
        | backbone.bottom_up.stem         | 9.4K                 |
        | backbone.bottom_up.res2         | 0.2M                 |
        | backbone.bottom_up.res3         | 1.2M                 |
        | backbone.bottom_up.res4         | 7.1M                 |
        | backbone.bottom_up.res5         | 14.9M                |
        | ......                          | .....                |

    Args:
        model: a torch module
        max_depth (int): maximum depth to recursively print submodules or
            parameters
        trainable_only (bool): if True, only parameters with
            ``requires_grad=True`` are counted. If none qualify, the table
            contains just the ``model`` row with a count of 0.

    Returns:
        str: the table to be printed
    """
    count: typing.DefaultDict[str, int] = parameter_count(model, trainable_only)
    # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
    param_shape: typing.Dict[str, typing.Tuple] = {
        k: tuple(v.shape) for k, v in model.named_parameters()
    }
    # pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
    table: typing.List[typing.Tuple] = []

    def format_size(x: int) -> str:
        # Human-readable element count with G / M / K suffixes.
        if x > 1e8:
            return "{:.1f}G".format(x / 1e9)
        if x > 1e5:
            return "{:.1f}M".format(x / 1e6)
        if x > 1e2:
            return "{:.1f}K".format(x / 1e3)
        return str(x)

    def fill(lvl: int, prefix: str) -> None:
        # Depth-first: emit every entry at depth `lvl` under `prefix`, each
        # followed immediately by its own children.
        if lvl >= max_depth:
            return
        for name, v in count.items():
            if name.count(".") == lvl and name.startswith(prefix):
                indent = " " * (lvl + 1)
                if name in param_shape:
                    # Leaf parameter: show its shape instead of a count.
                    table.append((indent + name, indent + str(param_shape[name])))
                else:
                    table.append((indent + name, indent + format_size(v)))
                    fill(lvl + 1, name + ".")

    # parameter_count yields no "" key when nothing was counted (no parameters,
    # or trainable_only on a fully frozen model); defaultdict.pop ignores the
    # default factory, so give an explicit default instead of raising KeyError.
    table.append(("model", format_size(count.pop("", 0))))
    fill(0, "")

    old_ws = tabulate.PRESERVE_WHITESPACE
    tabulate.PRESERVE_WHITESPACE = True
    try:
        tab = tabulate.tabulate(table, headers=["name", "#elements or shape"], tablefmt="pipe")
    finally:
        # Restore the module-global flag even if tabulate raises.
        tabulate.PRESERVE_WHITESPACE = old_ws
    return tab
def cycle(iterable):
    """
    Yield the items of ``iterable`` endlessly, re-iterating it after each pass.

    Unlike ``itertools.cycle``, items are not cached: ``iterable`` is iterated
    afresh on every pass, so a re-iterable source (e.g. a data loader) produces
    new items each pass instead of replaying stored ones.

    The generator stops once a full pass yields nothing (an empty iterable, or
    a one-shot iterator that is already exhausted); looping again in that case
    would spin forever without ever yielding.

    Args:
        iterable: any iterable; ideally one that can be iterated repeatedly.

    Yields:
        The items of ``iterable``, pass after pass.
    """
    while True:
        produced = False
        for x in iterable:
            produced = True
            yield x
        if not produced:
            # An empty pass means every later pass is empty too: stop instead
            # of busy-looping.
            return
class MattingTrainer(AMPTrainer):
    """
    ``AMPTrainer`` subclass used for matting training.

    Differences from the parent visible in this class:

    * batches are drawn from ``cycle(self.data_loader)``, which re-iterates the
      data loader whenever a pass is exhausted;
    * tables of all parameters and of trainable parameters are logged once at
      construction time.
    """

    def __init__(self, model, data_loader, optimizer, grad_scaler=None):
        """
        Args:
            model, data_loader, optimizer: forwarded to ``AMPTrainer``.
            grad_scaler: optional gradient scaler forwarded to ``AMPTrainer``;
                when None, the parent class picks its own default.
        """
        # Forward the caller's scaler (it used to be hard-coded to None, which
        # silently discarded a user-supplied grad_scaler).
        super().__init__(model, data_loader, optimizer, grad_scaler=grad_scaler)
        self.data_loader_iter = iter(cycle(self.data_loader))

        # print model parameters
        logger.info("All parameters: \n" + parameter_count_table(model))
        logger.info(
            "Trainable parameters: \n"
            + parameter_count_table(model, trainable_only=True, max_depth=8)
        )

    def run_step(self):
        """
        Implement the AMP training logic.

        One iteration: fetch a batch, run the forward pass under ``autocast``,
        backpropagate through the grad scaler, write metrics, then step the
        optimizer via the scaler and update the scale.
        """
        assert self.model.training, "[AMPTrainer] model was changed to eval mode!"
        assert torch.cuda.is_available(), "[AMPTrainer] CUDA is required for AMP training!"
        from torch.cuda.amp import autocast

        # matting pass
        start = time.perf_counter()
        data = next(self.data_loader_iter)
        data_time = time.perf_counter() - start

        with autocast():
            loss_dict = self.model(data)
            # The model may return a single loss tensor or a dict of named losses.
            if isinstance(loss_dict, torch.Tensor):
                losses = loss_dict
                loss_dict = {"total_loss": loss_dict}
            else:
                losses = sum(loss_dict.values())

        self.optimizer.zero_grad()
        self.grad_scaler.scale(losses).backward()

        self._write_metrics(loss_dict, data_time)

        self.grad_scaler.step(self.optimizer)
        self.grad_scaler.update()