|
import logging
import time
import typing
from collections import defaultdict

import tabulate
import torch
from detectron2.engine import AMPTrainer
from torch import nn

logger = logging.getLogger("detectron2")
|
|
def parameter_count(model: nn.Module, trainable_only: bool = False) -> typing.DefaultDict[str, int]:
    """
    Count parameters of a model and its submodules.

    Args:
        model: a torch module
        trainable_only (bool): if True, count only parameters with
            ``requires_grad=True``

    Returns:
        dict (str-> int): the key is either a parameter name or a module name.
            The value is the number of elements in the parameter, or in all
            parameters of the module. The key "" corresponds to the total
            number of parameters of the model.
    """
    r = defaultdict(int)
    for name, prm in model.named_parameters():
        if trainable_only and not prm.requires_grad:
            continue
        size = prm.numel()
        parts = name.split(".")
        # Credit the size to every prefix of the parameter's dotted name,
        # from "" (the whole model) down to the full parameter name itself.
        for k in range(len(parts) + 1):
            prefix = ".".join(parts[:k])
            r[prefix] += size
    return r
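
# A minimal, hedged usage sketch (illustration only; the Sequential below is
# a hypothetical stand-in, not part of the trainer): it shows the keys that
# parameter_count produces, e.g. "" for the whole model and "0.weight" for a
# leaf parameter.
def _demo_parameter_count() -> None:
    model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
    counts = parameter_count(model)
    # The "" key aggregates every parameter in the model.
    assert counts[""] == sum(p.numel() for p in model.parameters())
    # A submodule's count is the sum over its own parameters.
    assert counts["0"] == counts["0.weight"] + counts["0.bias"]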
|
|
|
|
|
def parameter_count_table(
    model: nn.Module, max_depth: int = 3, trainable_only: bool = False
) -> str:
    """
    Format the parameter count of the model (and its submodules or parameters)
    in a nice table. It looks like this:

    ::

        | name                             | #elements or shape   |
        |:---------------------------------|:---------------------|
        | model                            | 37.9M                |
        | backbone                         | 31.5M                |
        | backbone.fpn_lateral3            | 0.1M                 |
        | backbone.fpn_lateral3.weight     | (256, 512, 1, 1)     |
        | backbone.fpn_lateral3.bias       | (256,)               |
        | backbone.fpn_output3             | 0.6M                 |
        | backbone.fpn_output3.weight      | (256, 256, 3, 3)     |
        | backbone.fpn_output3.bias        | (256,)               |
        | backbone.fpn_lateral4            | 0.3M                 |
        | backbone.fpn_lateral4.weight     | (256, 1024, 1, 1)    |
        | backbone.fpn_lateral4.bias       | (256,)               |
        | backbone.fpn_output4             | 0.6M                 |
        | backbone.fpn_output4.weight      | (256, 256, 3, 3)     |
        | backbone.fpn_output4.bias        | (256,)               |
        | backbone.fpn_lateral5            | 0.5M                 |
        | backbone.fpn_lateral5.weight     | (256, 2048, 1, 1)    |
        | backbone.fpn_lateral5.bias       | (256,)               |
        | backbone.fpn_output5             | 0.6M                 |
        | backbone.fpn_output5.weight      | (256, 256, 3, 3)     |
        | backbone.fpn_output5.bias        | (256,)               |
        | backbone.top_block               | 5.3M                 |
        | backbone.top_block.p6            | 4.7M                 |
        | backbone.top_block.p7            | 0.6M                 |
        | backbone.bottom_up               | 23.5M                |
        | backbone.bottom_up.stem          | 9.4K                 |
        | backbone.bottom_up.res2          | 0.2M                 |
        | backbone.bottom_up.res3          | 1.2M                 |
        | backbone.bottom_up.res4          | 7.1M                 |
        | backbone.bottom_up.res5          | 14.9M                |
        | ......                           | .....                |

    Args:
        model: a torch module
        max_depth (int): maximum depth to recursively print submodules or
            parameters
        trainable_only (bool): if True, only count and tabulate parameters
            with ``requires_grad=True``

    Returns:
        str: the table to be printed
    """
|
    count: typing.DefaultDict[str, int] = parameter_count(model, trainable_only)
    param_shape: typing.Dict[str, typing.Tuple] = {
        k: tuple(v.shape) for k, v in model.named_parameters()
    }

    table: typing.List[typing.Tuple] = []

    def format_size(x: int) -> str:
        if x > 1e8:
            return "{:.1f}G".format(x / 1e9)
        if x > 1e5:
            return "{:.1f}M".format(x / 1e6)
        if x > 1e2:
            return "{:.1f}K".format(x / 1e3)
        return str(x)

    def fill(lvl: int, prefix: str) -> None:
        if lvl >= max_depth:
            return
        for name, v in count.items():
            if name.count(".") == lvl and name.startswith(prefix):
                indent = " " * (lvl + 1)
                if name in param_shape:
                    # Leaf parameter: show its shape instead of a count.
                    table.append((indent + name, indent + str(param_shape[name])))
                else:
                    # Submodule: show the aggregated count, then recurse.
                    table.append((indent + name, indent + format_size(v)))
                    fill(lvl + 1, name + ".")

    table.append(("model", format_size(count.pop(""))))
    fill(0, "")

    # tabulate strips leading whitespace by default, which would destroy the
    # indentation that encodes the module hierarchy; preserve it temporarily.
    old_ws = tabulate.PRESERVE_WHITESPACE
    tabulate.PRESERVE_WHITESPACE = True
    tab = tabulate.tabulate(table, headers=["name", "#elements or shape"], tablefmt="pipe")
    tabulate.PRESERVE_WHITESPACE = old_ws
    return tab
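
# Hedged illustration (the toy Sequential is a stand-in; with a real
# detectron2 model the output resembles the docstring example above):
def _demo_parameter_count_table() -> None:
    model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
    # max_depth=2 lists submodules ("0", "2") and their leaf parameters.
    print(parameter_count_table(model, max_depth=2))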
|
|
|
|
|
def cycle(iterable):
    """
    Loop over ``iterable`` forever. Unlike ``itertools.cycle``, this restarts
    iteration on every pass instead of replaying cached items, so a DataLoader
    with ``shuffle=True`` is reshuffled each epoch.
    """
    while True:
        for x in iterable:
            yield x
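
# Sketch of why cycle() is used instead of itertools.cycle (the toy loader is
# a hypothetical stand-in for the trainer's real data loader):
def _demo_cycle() -> None:
    from torch.utils.data import DataLoader

    loader = DataLoader(list(range(4)), batch_size=2, shuffle=True)
    batches = iter(cycle(loader))
    # Every two next() calls exhaust the loader once; the following pass
    # re-runs it, so shuffling happens per epoch rather than replaying the
    # first epoch's batches forever.
    for _ in range(6):
        next(batches)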
|
|
|
class MattingTrainer(AMPTrainer):
    def __init__(self, model, data_loader, optimizer, grad_scaler=None):
        # Forward the caller's grad_scaler; the previous code hard-coded
        # grad_scaler=None, silently discarding a user-provided scaler.
        super().__init__(model, data_loader, optimizer, grad_scaler=grad_scaler)
        # Wrap the loader so iteration never ends and reshuffles each epoch.
        self.data_loader_iter = iter(cycle(self.data_loader))

        logger.info("All parameters:\n" + parameter_count_table(model))
        logger.info(
            "Trainable parameters:\n"
            + parameter_count_table(model, trainable_only=True, max_depth=8)
        )
|
|
|
    def run_step(self):
        """
        Implement the AMP training logic.
        """
        assert self.model.training, "[MattingTrainer] model was changed to eval mode!"
        assert torch.cuda.is_available(), "[MattingTrainer] CUDA is required for AMP training!"
        from torch.cuda.amp import autocast

        # Time only the data-loading portion of the step.
        start = time.perf_counter()
        data = next(self.data_loader_iter)
        data_time = time.perf_counter() - start

        # Run the forward pass under autocast so eligible ops execute in fp16.
        with autocast():
            loss_dict = self.model(data)
            if isinstance(loss_dict, torch.Tensor):
                losses = loss_dict
                loss_dict = {"total_loss": loss_dict}
            else:
                losses = sum(loss_dict.values())

        self.optimizer.zero_grad()
        # Scale the loss to avoid fp16 gradient underflow before backward.
        self.grad_scaler.scale(losses).backward()

        self._write_metrics(loss_dict, data_time)

        # Unscale gradients, skip the step if any are inf/NaN, then update
        # the scale factor for the next iteration.
        self.grad_scaler.step(self.optimizer)
        self.grad_scaler.update()
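
# Hedged end-to-end sketch. Assumptions: detectron2's build_model and
# build_detection_train_loader, and a detectron2 CfgNode "cfg"; swap in this
# project's own builders where they differ. Only MattingTrainer itself comes
# from this file.
def _demo_matting_trainer(cfg) -> None:
    from detectron2.data import build_detection_train_loader
    from detectron2.modeling import build_model

    model = build_model(cfg)
    data_loader = build_detection_train_loader(cfg)
    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.SOLVER.BASE_LR)
    trainer = MattingTrainer(model, data_loader, optimizer)
    # TrainerBase.train drives run_step between the two iteration bounds.
    trainer.train(0, cfg.SOLVER.MAX_ITER)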