# Copyright (C) 2022. Huawei Technologies Co., Ltd. All rights reserved.
# This program is free software; you can redistribute it and/or modify
# it under the terms of the MIT License.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# MIT License for more details.
import math
import torch
from einops import rearrange
from .base import BaseModule


class Mish(BaseModule):
    def forward(self, x):
        return x * torch.tanh(torch.nn.functional.softplus(x))


class Upsample(BaseModule):
    def __init__(self, dim):
        super(Upsample, self).__init__()
        self.conv = torch.nn.ConvTranspose2d(dim, dim, 4, 2, 1)

    def forward(self, x):
        return self.conv(x)


class Downsample(BaseModule):
    def __init__(self, dim):
        super(Downsample, self).__init__()
        self.conv = torch.nn.Conv2d(dim, dim, 3, 2, 1)

    def forward(self, x):
        return self.conv(x)


class Rezero(BaseModule):
    """Scales the wrapped module's output by a learnable gate initialized at
    zero, so the branch contributes nothing at the start of training
    (ReZero-style gating)."""

    def __init__(self, fn):
        super(Rezero, self).__init__()
        self.fn = fn
        self.g = torch.nn.Parameter(torch.zeros(1))

    def forward(self, x):
        return self.fn(x) * self.g


class Block(BaseModule):
    def __init__(self, dim, dim_out, groups=8):
        super(Block, self).__init__()
        self.block = torch.nn.Sequential(
            torch.nn.Conv2d(dim, dim_out, 3, padding=1),
            torch.nn.GroupNorm(groups, dim_out),
            Mish())

    def forward(self, x):
        return self.block(x)


class ResnetBlock(BaseModule):
    """Residual block conditioned on a time embedding: the embedding is
    projected to dim_out and added as a per-channel bias between the two
    convolutional sub-blocks."""

    def __init__(self, dim, dim_out, time_emb_dim, groups=8):
        super(ResnetBlock, self).__init__()
        self.mlp = torch.nn.Sequential(Mish(),
                                       torch.nn.Linear(time_emb_dim, dim_out))
        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        if dim != dim_out:
            self.res_conv = torch.nn.Conv2d(dim, dim_out, 1)
        else:
            self.res_conv = torch.nn.Identity()

    def forward(self, x, time_emb):
        h = self.block1(x)
        h += self.mlp(time_emb).unsqueeze(-1).unsqueeze(-1)
        h = self.block2(h)
        return h + self.res_conv(x)


class LinearAttention(BaseModule):
    """Linear (non-quadratic) self-attention over the flattened spatial map."""

    def __init__(self, dim, heads=4, dim_head=32, q_norm=True):
        super(LinearAttention, self).__init__()
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = torch.nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = torch.nn.Conv2d(hidden_dim, dim, 1)
        self.q_norm = q_norm

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x)
        q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)',
                            heads=self.heads, qkv=3)
        k = k.softmax(dim=-1)  # normalize keys over the spatial positions
        if self.q_norm:
            q = q.softmax(dim=-2)  # optionally normalize queries over channels
        context = torch.einsum('bhdn,bhen->bhde', k, v)
        out = torch.einsum('bhde,bhdn->bhen', context, q)
        out = rearrange(out, 'b heads c (h w) -> b (heads c) h w',
                        heads=self.heads, h=h, w=w)
        return self.to_out(out)
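
# Note: forming context = (softmax k) @ v^T before multiplying by q makes the
# cost linear in the number of positions N = h * w (roughly O(N * d^2) with
# d = dim_head), instead of the O(N^2 * d) of standard dot-product attention.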


class Residual(BaseModule):
    def __init__(self, fn):
        super(Residual, self).__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x


def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
):
    """
    Create sinusoidal timestep embeddings, matching the implementation in
    Denoising Diffusion Probabilistic Models.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
        These may be fractional.
    :param embedding_dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x embedding_dim] Tensor of positional embeddings.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
    half_dim = embedding_dim // 2
    exponent = -math.log(max_period) * torch.arange(
        start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
    )
    exponent = exponent / (half_dim - downscale_freq_shift)
    emb = torch.exp(exponent)
    emb = timesteps[:, None].float() * emb[None, :]
    # scale embeddings
    emb = scale * emb
    # concat sine and cosine embeddings
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
    # flip sine and cosine embeddings
    if flip_sin_to_cos:
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
    # zero pad if embedding_dim is odd
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
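
# Illustrative example (not from the original source): with embedding_dim=4,
# get_timestep_embedding(torch.tensor([0., 1.]), 4) yields a (2, 4) tensor
# whose first two columns are sines and last two cosines at two geometrically
# spaced frequencies; row 0 (t=0) is therefore [0, 0, 1, 1].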


class Timesteps(BaseModule):
    def __init__(self, num_channels: int, flip_sin_to_cos: bool,
                 downscale_freq_shift: float):
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift

    def forward(self, timesteps):
        t_emb = get_timestep_embedding(
            timesteps,
            self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
        )
        return t_emb


class PitchPosEmb(BaseModule):
    """Sinusoidal positional embedding of a per-frame pitch sequence."""

    def __init__(self, dim, flip_sin_to_cos=False, downscale_freq_shift=0):
        super(PitchPosEmb, self).__init__()
        self.dim = dim
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift

    def forward(self, x):
        # x: (B, L)
        b, l = x.shape
        x = rearrange(x, 'b l -> (b l)')
        emb = get_timestep_embedding(
            x,
            self.dim,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
        )
        emb = rearrange(emb, '(b l) d -> b d l', b=b, l=l)
        return emb
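
# Shape sketch (assuming a pitch contour of shape (B, L)): PitchPosEmb(dim)
# embeds every frame independently and returns (B, dim, L), channel-first so
# the result can be consumed directly by downstream convolutional layers.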


class TimbreBlock(BaseModule):
    """Convolutional encoder that maps a (B, 1, H, W) spectrogram-like input
    to a (B, out_dim) timbre vector by global average pooling. Each GLU(dim=1)
    halves the channel count produced by the preceding convolution."""

    def __init__(self, out_dim):
        super(TimbreBlock, self).__init__()
        base_dim = out_dim // 4
        self.block11 = torch.nn.Sequential(
            torch.nn.Conv2d(1, 2 * base_dim, 3, 1, 1),
            torch.nn.InstanceNorm2d(2 * base_dim, affine=True),
            torch.nn.GLU(dim=1))
        self.block12 = torch.nn.Sequential(
            torch.nn.Conv2d(base_dim, 2 * base_dim, 3, 1, 1),
            torch.nn.InstanceNorm2d(2 * base_dim, affine=True),
            torch.nn.GLU(dim=1))
        self.block21 = torch.nn.Sequential(
            torch.nn.Conv2d(base_dim, 4 * base_dim, 3, 1, 1),
            torch.nn.InstanceNorm2d(4 * base_dim, affine=True),
            torch.nn.GLU(dim=1))
        self.block22 = torch.nn.Sequential(
            torch.nn.Conv2d(2 * base_dim, 4 * base_dim, 3, 1, 1),
            torch.nn.InstanceNorm2d(4 * base_dim, affine=True),
            torch.nn.GLU(dim=1))
        self.block31 = torch.nn.Sequential(
            torch.nn.Conv2d(2 * base_dim, 8 * base_dim, 3, 1, 1),
            torch.nn.InstanceNorm2d(8 * base_dim, affine=True),
            torch.nn.GLU(dim=1))
        self.block32 = torch.nn.Sequential(
            torch.nn.Conv2d(4 * base_dim, 8 * base_dim, 3, 1, 1),
            torch.nn.InstanceNorm2d(8 * base_dim, affine=True),
            torch.nn.GLU(dim=1))
        self.final_conv = torch.nn.Conv2d(4 * base_dim, out_dim, 1)

    def forward(self, x):
        y = self.block11(x)
        y = self.block12(y)
        y = self.block21(y)
        y = self.block22(y)
        y = self.block31(y)
        y = self.block32(y)
        y = self.final_conv(y)
        # mean over the spatial dimensions -> one out_dim vector per item
        return y.sum((2, 3)) / (y.shape[2] * y.shape[3])
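

if __name__ == "__main__":
    # Minimal smoke test (illustrative only; the shapes below are arbitrary
    # and not taken from any training configuration in this repo). Because of
    # the relative import above, run it as a module: python -m <package>.<name>
    x = torch.randn(2, 64, 80, 100)   # (batch, channels, freq bins, frames)
    t = torch.rand(2)                 # one diffusion timestep per batch item
    t_emb = Timesteps(64, flip_sin_to_cos=False, downscale_freq_shift=0.0)(t)
    res = ResnetBlock(64, 64, time_emb_dim=64)
    print(res(x, t_emb).shape)        # torch.Size([2, 64, 80, 100])
    attn = Residual(Rezero(LinearAttention(64)))
    print(attn(x).shape)              # torch.Size([2, 64, 80, 100])
    timbre = TimbreBlock(out_dim=64)
    print(timbre(torch.randn(2, 1, 80, 100)).shape)  # torch.Size([2, 64])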