import logging
import logging.config
import random
import subprocess
from datetime import datetime

import numpy as np
import torch
import torch.distributed as dist
import torch.optim
import torch.utils.data
from omegaconf import OmegaConf
from packaging import version
from torch.nn.parallel import DistributedDataParallel
from torch.nn.parallel.distributed import _find_tensors


def set_random_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
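
# Usage sketch (illustrative; `cfg.random_seed` is an assumed config field):
# seed each process once at startup, optionally offset by the rank so that
# e.g. data augmentation differs across ranks:
#
#     set_random_seed(cfg.random_seed + (dist.get_rank() if dist.is_initialized() else 0))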


def is_logging_process():
    return not dist.is_initialized() or dist.get_rank() == 0


def get_logger(cfg, name=None):
    # Only the logging process configures logging and gets a logger;
    # every other rank receives None.
    if is_logging_process():
        logging.config.dictConfig(
            OmegaConf.to_container(cfg.job_logging_config, resolve=True)
        )
        return logging.getLogger(name)
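
# Usage sketch (assumed call site):
#
#     logger = get_logger(cfg, name=__name__)
#     if is_logging_process():
#         logger.info("config: %s", OmegaConf.to_yaml(cfg))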


class SyncFunction(torch.autograd.Function):
    """All-gather a batch across ranks while keeping the result differentiable."""

    @staticmethod
    def forward(ctx, tensor):
        ctx.batch_size = tensor.shape[0]

        gathered_tensor = [
            torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())
        ]

        torch.distributed.all_gather(gathered_tensor, tensor)
        gathered_tensor = torch.cat(gathered_tensor, 0)

        return gathered_tensor

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = grad_output.clone()
        torch.distributed.all_reduce(
            grad_input, op=torch.distributed.ReduceOp.SUM, async_op=False
        )

        # Return only the gradient slice that belongs to this rank's local batch.
        idx_from = torch.distributed.get_rank() * ctx.batch_size
        idx_to = (torch.distributed.get_rank() + 1) * ctx.batch_size
        return grad_input[idx_from:idx_to]
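
# Usage sketch (illustrative): all-gather per-rank embeddings while keeping
# gradients flowing back to the local slice, e.g. before a contrastive loss;
# `features` is an assumed local batch of embeddings on this rank.
#
#     if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
#         all_features = SyncFunction.apply(features)
#     else:
#         all_features = features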


def get_timestamp():
    return datetime.now().strftime("%y%m%d-%H%M%S")


def get_commit_hash():
    # Fails with subprocess.CalledProcessError when not run inside a git checkout.
    message = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"])
    return message.strip().decode("utf-8")


class DDP(DistributedDataParallel):
    """
    DistributedDataParallel subclass whose forward() dispatches to the wrapped
    module's training_step, test_step, or validation_step (the same override
    PyTorch Lightning uses) instead of calling module.forward directly.
    See the usage sketch at the bottom of this module.
    """

    def forward(self, *inputs, **kwargs):
        # Compare against the public part of the version string; naive slicing
        # such as torch.__version__[:6] can cut mid-segment (e.g. "1.9.1+") and
        # fail to parse.
        if version.parse(torch.__version__.split("+")[0]) < version.parse("1.11"):
            self._sync_params()
            inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
            assert len(self.device_ids) == 1
            if self.module.training:
                output = self.module.training_step(*inputs[0], **kwargs[0])
            elif self.module.testing:
                output = self.module.test_step(*inputs[0], **kwargs[0])
            else:
                output = self.module.validation_step(*inputs[0], **kwargs[0])
            if torch.is_grad_enabled():
                # Find the tensors in the (freeform) output so the reducer knows
                # which parameters were used in this forward pass and can skip
                # reduction for unused ones.
                if self.find_unused_parameters:
                    self.reducer.prepare_for_backward(list(_find_tensors(output)))
                else:
                    self.reducer.prepare_for_backward([])
        else:
            # PyTorch >= 1.11: mirror torch's own DistributedDataParallel.forward,
            # replacing the plain module call with the *_step dispatch.
            from torch.nn.parallel.distributed import (
                Join,
                _DDPSink,
                _tree_flatten_with_rref,
                _tree_unflatten_with_rref,
                logging,
            )
            with torch.autograd.profiler.record_function("DistributedDataParallel.forward"):
                if torch.is_grad_enabled() and self.require_backward_grad_sync:
                    self.logger.set_runtime_stats_and_log()
                    self.num_iterations += 1
                    self.reducer.prepare_for_forward()

                # Notify the join context that this process has not joined, if applicable.
                work = Join.notify_join_context(self)
                if work:
                    self.reducer._set_forward_pass_work_handle(
                        work, self._divide_by_initial_world_size
                    )

                # Rebuild the gradient buckets once, before peak memory usage
                # grows during the forward computation.
                if torch.is_grad_enabled() and self.reducer._rebuild_buckets():
                    logging.info("Reducer buckets have been rebuilt in this iteration.")
                    self._has_rebuilt_buckets = True

                buffer_hook_registered = hasattr(self, "buffer_hook")  # kept from upstream (unused here)
                if self._check_sync_bufs_pre_fwd():
                    self._sync_buffers()

                if self._join_config.enable:
                    # Notify joined ranks whether they should sync in the backward pass.
                    self._check_global_requires_backward_grad_sync(is_joined_rank=False)

                inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
                if self.module.training:
                    output = self.module.training_step(*inputs[0], **kwargs[0])
                elif self.module.testing:
                    output = self.module.test_step(*inputs[0], **kwargs[0])
                else:
                    output = self.module.validation_step(*inputs[0], **kwargs[0])

                if self._check_sync_bufs_post_fwd():
                    self._sync_buffers()

                if torch.is_grad_enabled() and self.require_backward_grad_sync:
                    self.require_forward_param_sync = True
                    # Find the tensors in the (freeform) output so the reducer
                    # knows which parameters were used in this forward pass.
                    if self.find_unused_parameters and not self.static_graph:
                        # No need to populate this for a static graph.
                        self.reducer.prepare_for_backward(list(_find_tensors(output)))
                    else:
                        self.reducer.prepare_for_backward([])
                else:
                    self.require_forward_param_sync = False

            # DDPSink is used for unused parameter detection and for the first
            # iteration of static-graph training.
            if (self.find_unused_parameters and not self.static_graph) or (
                self.static_graph and self.num_iterations == 1
            ):
                state_dict = {
                    "static_graph": self.static_graph,
                    "num_iterations": self.num_iterations,
                }

                output_tensor_list, treespec, output_is_rref = _tree_flatten_with_rref(
                    output
                )
                output_placeholders = [None for _ in range(len(output_tensor_list))]
                # Do not touch tensors that have no grad_fn.
                for i, output in enumerate(output_tensor_list):
                    if torch.is_tensor(output) and output.grad_fn is None:
                        output_placeholders[i] = output

                # Route tensors that require grad through the _DDPSink backward pass
                # so the reducer is notified once autograd starts.
                passthrough_tensor_list = _DDPSink.apply(
                    self.reducer,
                    state_dict,
                    *output_tensor_list,
                )
                for i in range(len(output_placeholders)):
                    if output_placeholders[i] is None:
                        output_placeholders[i] = passthrough_tensor_list[i]

                output = _tree_unflatten_with_rref(
                    output_placeholders, treespec, output_is_rref
                )
        return output
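
# Usage sketch for the DDP wrapper above (assumptions: `net` defines
# training_step / validation_step / test_step, exposes a boolean `testing`
# attribute, and this process has been assigned a single GPU index `rank`):
#
#     net = net.to(rank)
#     model = DDP(net, device_ids=[rank])
#     loss = model(batch)  # dispatches to net.training_step(batch) while net.training is True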