kevinwang676's picture
Duplicate from kevinwang676/ChatGLM2-SadTalker
9074b85
raw
history blame
9.49 kB
import logging
import os
import torch
import torch.distributed as dist
from torch.nn import Module
from torch.nn.functional import normalize, linear
from torch.nn.parameter import Parameter
class PartialFC(Module):
    """
    Model-parallel softmax layer with optional class-center sampling.

    Author: {Xiang An, Yang Xiao, XuHan Zhu} in DeepGlint,
    Partial FC: Training 10 Million Identities on a Single Machine
    See the original paper:
    https://arxiv.org/abs/2010.05222

    The `num_classes` class centers are split into contiguous shards, one per rank.
    Each rank keeps its shard in `self.weight` (with a matching momentum tensor in
    `self.weight_mom`) and, on every step, exposes the rows that take part in the
    softmax as the trainable parameter `self.sub_weight`.
    """

    @torch.no_grad()
    def __init__(self, rank, local_rank, world_size, batch_size, resume,
                 margin_softmax, num_classes, sample_rate=1.0, embedding_size=512, prefix="./"):
        """
        rank: int
            Unique process(GPU) ID from 0 to world_size - 1.
        local_rank: int
            Index of the CUDA device used by this process; tensors are created on
            `cuda:{local_rank}`.
        world_size: int
            Number of GPU.
        batch_size: int
            Batch size on current rank(GPU).
        resume: bool
            Select whether to restore the weight of softmax.
        margin_softmax: callable
            A function of margin softmax, eg: cosface, arcface. It is called as
            `margin_softmax(logits, total_label)`.
        num_classes: int
            Total number of classes over all ranks, required. This rank stores a contiguous
            shard of `num_classes // world_size` centers (one more on the first
            `num_classes % world_size` ranks), starting at global class id `class_start`.
        sample_rate: float
            The partial fc sampling rate, when the number of classes increases to more than 2 millions, Sampling
            can greatly speed up training, and reduce a lot of GPU memory, default is 1.0.
            Sampling is disabled when `int(sample_rate) == 1`.
        embedding_size: int
            The feature dimension, default is 512.
        prefix: str
            Path for save checkpoint, default is './'.
        """
        super().__init__()
        #
        self.num_classes: int = num_classes
        self.rank: int = rank
        self.local_rank: int = local_rank
        self.device: torch.device = torch.device("cuda:{}".format(self.local_rank))
        self.world_size: int = world_size
        self.batch_size: int = batch_size
        self.margin_softmax: callable = margin_softmax
        self.sample_rate: float = sample_rate
        self.embedding_size: int = embedding_size
        self.prefix: str = prefix
        # Shard layout: the remainder classes are spread over the lowest ranks.
        self.num_local: int = num_classes // world_size + int(rank < num_classes % world_size)
        self.class_start: int = num_classes // world_size * rank + min(rank, num_classes % world_size)
        self.num_sample: int = int(self.sample_rate * self.num_local)

        self.weight_name = os.path.join(self.prefix, "rank_{}_softmax_weight.pt".format(self.rank))
        self.weight_mom_name = os.path.join(self.prefix, "rank_{}_softmax_weight_mom.pt".format(self.rank))

        if resume:
            try:
                # map_location pins the restored tensors to this rank's device even if the
                # checkpoint was written from a different CUDA device index.
                self.weight: torch.Tensor = torch.load(self.weight_name, map_location=self.device)
                self.weight_mom: torch.Tensor = torch.load(self.weight_mom_name, map_location=self.device)
                # A checkpoint written for a different shard size cannot be reused.
                if self.weight.shape[0] != self.num_local or self.weight_mom.shape[0] != self.num_local:
                    raise IndexError
                logging.info("softmax weight resume successfully!")
                logging.info("softmax weight mom resume successfully!")
            except (FileNotFoundError, KeyError, IndexError):
                # Missing or mismatching checkpoint: fall back to a fresh initialisation.
                self.weight, self.weight_mom = self._new_weight_and_mom()
                logging.info("softmax weight init!")
                logging.info("softmax weight mom init!")
        else:
            self.weight, self.weight_mom = self._new_weight_and_mom()
            logging.info("softmax weight init successfully!")
            logging.info("softmax weight mom init successfully!")

        # Side stream on which `prepare` runs; `forward` waits on it.
        self.stream: torch.cuda.Stream = torch.cuda.Stream(local_rank)

        self.index = None
        if int(self.sample_rate) == 1:
            # No sampling: the parameter wraps the full shard directly, so there is nothing to
            # write back and `update` is replaced by a no-op on this instance.
            self.update = lambda: 0
            self.sub_weight = Parameter(self.weight)
            self.sub_weight_mom = self.weight_mom
        else:
            # Placeholder; the real sub-weight is built by `sample` on every step.
            self.sub_weight = Parameter(torch.empty((0, 0)).cuda(local_rank))

    def _new_weight_and_mom(self):
        """ Return a freshly initialised `(weight, weight_mom)` pair for this rank's shard:
        weights drawn from N(0, 0.01) with shape `(num_local, embedding_size)` and a zero
        momentum tensor of the same shape, both on `self.device`.
        """
        weight = torch.normal(0, 0.01, (self.num_local, self.embedding_size), device=self.device)
        return weight, torch.zeros_like(weight)

    def save_params(self):
        """ Save softmax weight for each rank on prefix
        """
        torch.save(self.weight.data, self.weight_name)
        torch.save(self.weight_mom, self.weight_mom_name)

    @torch.no_grad()
    def sample(self, total_label):
        """
        Sample all positive class centers in each rank, and random select neg class centers to filling a fixed
        `num_sample`.

        `total_label` is modified in place: labels outside this rank's shard become -1 and
        labels inside it are rewritten as row indices into the (sub-)weight used this step.
        When sampling is enabled this also sets `self.index`, `self.sub_weight` and
        `self.sub_weight_mom`.

        total_label: tensor
            Label after all gather, which cross all GPUs.
        """
        index_positive = (self.class_start <= total_label) & (total_label < self.class_start + self.num_local)
        total_label[~index_positive] = -1
        total_label[index_positive] -= self.class_start
        if int(self.sample_rate) != 1:
            positive = torch.unique(total_label[index_positive], sorted=True)
            if self.num_sample - positive.size(0) >= 0:
                # Random scores lie in [0, 1); giving positives a score of 2.0 guarantees that
                # top-k keeps every positive and fills the rest with random negatives.
                perm = torch.rand(size=[self.num_local], device=self.device)
                perm[positive] = 2.0
                index = torch.topk(perm, k=self.num_sample)[1]
                index = index.sort()[0]
            else:
                # More positives than the sampling budget: keep all of them.
                index = positive
            self.index = index
            # `index` is sorted, so searchsorted maps each local class id to its row in the
            # sampled sub-weight.
            total_label[index_positive] = torch.searchsorted(index, total_label[index_positive])
            self.sub_weight = Parameter(self.weight[index])
            self.sub_weight_mom = self.weight_mom[index]

    def forward(self, total_features, norm_weight):
        """ Partial fc forward, `logits = X * sample(W)`
        """
        # Work queued on `self.stream` (see `prepare`) must finish before the weights are used.
        torch.cuda.current_stream().wait_stream(self.stream)
        logits = linear(total_features, norm_weight)
        return logits

    @torch.no_grad()
    def update(self):
        """ Set updated weight and weight_mom to memory bank.

        Writes `sub_weight` / `sub_weight_mom` back into the rows of `weight` / `weight_mom`
        selected by `self.index`. When sampling is disabled, `__init__` replaces this method
        with a no-op on the instance.
        """
        self.weight_mom[self.index] = self.sub_weight_mom
        self.weight[self.index] = self.sub_weight

    def prepare(self, label, optimizer):
        """
        get sampled class centers for cal softmax.

        Gathers the labels of all ranks, runs `sample`, and swaps the resulting `sub_weight`
        (with its momentum buffer) into the optimizer in place of the first parameter of the
        optimizer's last parameter group. All of this is issued on `self.stream`.

        label: tensor
            Label tensor on each rank.
        optimizer: opt
            Optimizer for partial fc, which need to get weight mom.

        Returns:
        --------
        total_label: tensor
            Gathered labels after `sample` has rewritten them in place.
        norm_weight: tensor
            `normalize(self.sub_weight)`.
        """
        with torch.cuda.stream(self.stream):
            total_label = torch.zeros(
                size=[self.batch_size * self.world_size], device=self.device, dtype=torch.long)
            # The chunks are views, so all_gather fills `total_label` directly.
            dist.all_gather(list(total_label.chunk(self.world_size, dim=0)), label)
            self.sample(total_label)
            # Drop the optimizer state of the previous sub-weight and register the new one.
            optimizer.state.pop(optimizer.param_groups[-1]['params'][0], None)
            optimizer.param_groups[-1]['params'][0] = self.sub_weight
            optimizer.state[self.sub_weight]['momentum_buffer'] = self.sub_weight_mom
            norm_weight = normalize(self.sub_weight)
            return total_label, norm_weight

    def forward_backward(self, label, features, optimizer):
        """
        Partial fc forward and backward with model parallel

        label: tensor
            Label tensor on each rank(GPU)
        features: tensor
            Features tensor on each rank(GPU)
        optimizer: optimizer
            Optimizer for partial fc

        Returns:
        --------
        x_grad: tensor
            The gradient of features.
        loss_v: tensor
            Loss value for cross entropy.
        """
        total_label, norm_weight = self.prepare(label, optimizer)
        total_features = torch.zeros(
            size=[self.batch_size * self.world_size, self.embedding_size], device=self.device)
        # `features.data` is gathered, so the graph of the caller's backbone is not extended here;
        # the feature gradient is handed back explicitly as `x_grad`.
        dist.all_gather(list(total_features.chunk(self.world_size, dim=0)), features.data)
        total_features.requires_grad = True

        logits = self.forward(total_features, norm_weight)
        logits = self.margin_softmax(logits, total_label)

        with torch.no_grad():
            # Softmax over the classes of all ranks: the row maximum and the sum of
            # exponentials are combined across ranks with all-reduce.
            max_fc = torch.max(logits, dim=1, keepdim=True)[0]
            dist.all_reduce(max_fc, dist.ReduceOp.MAX)

            # calculate exp(logits) and all-reduce
            logits_exp = torch.exp(logits - max_fc)
            logits_sum_exp = logits_exp.sum(dim=1, keepdims=True)
            dist.all_reduce(logits_sum_exp, dist.ReduceOp.SUM)

            # calculate prob
            logits_exp.div_(logits_sum_exp)

            # get one-hot
            grad = logits_exp
            # Only rows whose label was not set to -1 by `sample` have a target column here.
            index = torch.where(total_label != -1)[0]
            one_hot = torch.zeros(size=[index.size()[0], grad.size()[1]], device=grad.device)
            one_hot.scatter_(1, total_label[index, None], 1)

            # calculate loss
            # Rows without a local target contribute 0; the SUM all-reduce combines the
            # per-rank target probabilities.
            loss = torch.zeros(grad.size()[0], 1, device=grad.device)
            loss[index] = grad[index].gather(1, total_label[index, None])
            dist.all_reduce(loss, dist.ReduceOp.SUM)
            loss_v = loss.clamp_min_(1e-30).log_().mean() * (-1)

            # calculate grad
            # Gradient w.r.t. the logits: (prob - one_hot) / global batch size.
            grad[index] -= one_hot
            grad.div_(self.batch_size * self.world_size)

        logits.backward(grad)
        if total_features.grad is not None:
            total_features.grad.detach_()
        x_grad: torch.Tensor = torch.zeros_like(features, requires_grad=True)
        # feature gradient all-reduce
        dist.reduce_scatter(x_grad, list(total_features.grad.chunk(self.world_size, dim=0)))
        # NOTE(review): presumably rescales for gradient averaging done elsewhere in the
        # training loop -- confirm against the caller.
        x_grad = x_grad * self.world_size
        # backward backbone
        return x_grad, loss_v