PoTaTo721's picture
Update to V1.5
b2eb230
raw
history blame
3.56 kB
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from vector_quantize_pytorch import GroupedResidualFSQ
from .firefly import ConvNeXtBlock, FishConvNet, FishTransConvNet
@dataclass()
class FSQResult:
    """Container for the outputs of ``DownsampleFiniteScalarQuantize.forward``.

    Field order is ``z``, ``codes``, ``latents``, so positional construction
    behaves exactly like keyword construction in that order.
    """

    # Quantized features; forward() replaces this with the upsampled version.
    z: torch.Tensor
    # Quantizer indices as returned by forward() (transposed with ``.mT``).
    codes: torch.Tensor
    # Downsampled features that were fed to the quantizer.
    latents: torch.Tensor
class DownsampleFiniteScalarQuantize(nn.Module):
def __init__(
self,
input_dim: int = 512,
n_codebooks: int = 9,
n_groups: int = 1,
levels: tuple[int] = (8, 5, 5, 5), # Approximate 2**10
downsample_factor: tuple[int] = (2, 2),
downsample_dims: tuple[int] | None = None,
):
super().__init__()
if downsample_dims is None:
downsample_dims = [input_dim for _ in range(len(downsample_factor))]
all_dims = (input_dim,) + tuple(downsample_dims)
self.residual_fsq = GroupedResidualFSQ(
dim=all_dims[-1],
levels=levels,
num_quantizers=n_codebooks,
groups=n_groups,
)
self.downsample_factor = downsample_factor
self.downsample_dims = downsample_dims
self.downsample = nn.Sequential(
*[
nn.Sequential(
FishConvNet(
all_dims[idx],
all_dims[idx + 1],
kernel_size=factor,
stride=factor,
),
ConvNeXtBlock(dim=all_dims[idx + 1]),
)
for idx, factor in enumerate(downsample_factor)
]
)
self.upsample = nn.Sequential(
*[
nn.Sequential(
FishTransConvNet(
all_dims[idx + 1],
all_dims[idx],
kernel_size=factor,
stride=factor,
),
ConvNeXtBlock(dim=all_dims[idx]),
)
for idx, factor in reversed(list(enumerate(downsample_factor)))
]
)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, (nn.Conv1d, nn.Linear)):
nn.init.trunc_normal_(m.weight, std=0.02)
nn.init.constant_(m.bias, 0)
def forward(self, z) -> FSQResult:
original_shape = z.shape
z = self.downsample(z)
quantized, indices = self.residual_fsq(z.mT)
result = FSQResult(
z=quantized.mT,
codes=indices.mT,
latents=z,
)
result.z = self.upsample(result.z)
# Pad or crop z to match original shape
diff = original_shape[-1] - result.z.shape[-1]
left = diff // 2
right = diff - left
if diff > 0:
result.z = F.pad(result.z, (left, right))
elif diff < 0:
result.z = result.z[..., -left:right]
return result
def encode(self, z):
z = self.downsample(z)
_, indices = self.residual_fsq(z.mT)
indices = rearrange(indices, "g b l r -> b (g r) l")
return indices
def decode(self, indices: torch.Tensor):
indices = rearrange(indices, "b (g r) l -> g b l r", g=self.residual_fsq.groups)
z_q = self.residual_fsq.get_output_from_indices(indices)
z_q = self.upsample(z_q.mT)
return z_q