from typing import Callable

import torch
import torch.nn as nn


class ModulateDiT(nn.Module):
    """Modulation layer for DiT: projects a conditioning embedding to
    ``factor * hidden_size`` features via a zero-initialized linear layer."""

    def __init__(
        self,
        hidden_size: int,
        factor: int,
        act_layer: Callable,
        dtype=None,
        device=None,
    ):
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()
        self.act = act_layer()
        self.linear = nn.Linear(
            hidden_size, factor * hidden_size, bias=True, **factory_kwargs
        )
        # Zero-initialize the projection so modulation starts as an identity
        # mapping (adaLN-Zero style), which stabilizes early training.
        nn.init.zeros_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(self.act(x))
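

# Usage sketch (illustrative: the sizes, the factor=6 split, and the SiLU
# activation are assumptions, not taken from this file):
#
#   mod = ModulateDiT(hidden_size=1152, factor=6, act_layer=nn.SiLU)
#   c = torch.randn(4, 1152)  # conditioning embedding
#   shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
#       mod(c).chunk(6, dim=-1)
#   )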


def modulate(x, shift=None, scale=None):
    """Modulate ``x`` by shift and scale: ``x * (1 + scale) + shift``.

    ``shift`` and ``scale`` are per-sample vectors; ``unsqueeze(1)`` broadcasts
    them across the sequence dimension of ``x``.

    Args:
        x (torch.Tensor): input tensor of shape (batch, seq_len, dim).
        shift (torch.Tensor, optional): shift tensor of shape (batch, dim).
            Defaults to None.
        scale (torch.Tensor, optional): scale tensor of shape (batch, dim).
            Defaults to None.

    Returns:
        torch.Tensor: the modulated output tensor.
    """
    if scale is None and shift is None:
        return x
    elif shift is None:
        return x * (1 + scale.unsqueeze(1))
    elif scale is None:
        return x + shift.unsqueeze(1)
    else:
        return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
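

# Usage sketch (shapes are assumptions for illustration):
#
#   x = torch.randn(4, 256, 1152)              # (batch, seq_len, dim)
#   shift = torch.randn(4, 1152)
#   scale = torch.randn(4, 1152)
#   y = modulate(x, shift=shift, scale=scale)  # same shape as x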


def apply_gate(x, gate=None, tanh=False):
    """Scale a tensor by a per-sample gate, optionally squashed with tanh.

    Args:
        x (torch.Tensor): input tensor of shape (batch, seq_len, dim).
        gate (torch.Tensor, optional): gate tensor of shape (batch, dim).
            Defaults to None.
        tanh (bool, optional): whether to pass the gate through tanh before
            scaling. Defaults to False.

    Returns:
        torch.Tensor: the gated output tensor.
    """
    if gate is None:
        return x
    if tanh:
        return x * gate.unsqueeze(1).tanh()
    else:
        return x * gate.unsqueeze(1)
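

# Usage sketch (assumed residual-branch usage; shapes for illustration):
#
#   x = torch.randn(4, 256, 1152)
#   attn_out = torch.randn(4, 256, 1152)
#   gate_msa = torch.randn(4, 1152)
#   x = x + apply_gate(attn_out, gate=gate_msa)  # gated residual update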


def ckpt_wrapper(module):
    """Wrap a module call in a plain function for use with
    ``torch.utils.checkpoint.checkpoint``."""

    def ckpt_forward(*inputs):
        # Forward positional inputs straight through to the wrapped module.
        outputs = module(*inputs)
        return outputs

    return ckpt_forward
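

# Usage sketch (illustrative; ``block`` is a stand-in for a checkpointed
# transformer block, and ``use_reentrant=False`` follows current PyTorch
# guidance):
#
#   import torch.utils.checkpoint
#
#   block = nn.Linear(1152, 1152)
#   x = torch.randn(4, 1152, requires_grad=True)
#   out = torch.utils.checkpoint.checkpoint(
#       ckpt_wrapper(block), x, use_reentrant=False
#   )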