|
import torch |
|
|
|
def MyUNet2DConditionModel_SD_forward(self, x): |
|
global skip |
|
x = self.conv_in(x) |
|
skip = [x] |
|
x = self.body(x) |
|
return x |
|
|
|
def MyCrossAttnDownBlock2D_SD_forward(self, x): |
|
for i in range(2): |
|
x = self.resnets[i](x) |
|
x = self.attentions[i](x) |
|
skip.append(x) |
|
if self.downsamplers is not None: |
|
x = self.downsamplers[0](x) |
|
skip.append(x) |
|
return x |
|
|
|
def MyCrossAttnUpBlock2D_SD_forward(self, x): |
|
for i in range(3): |
|
x = self.resnets[i](torch.cat([x, skip.pop()], dim=1)) |
|
x = self.attentions[i](x) |
|
if self.upsamplers is not None: |
|
x = self.upsamplers[0](x) |
|
return x |
|
|
|
def MyDownBlock2D_SD_forward(self, x): |
|
for i in range(2): |
|
x = self.resnets[i](x) |
|
skip.append(x) |
|
return x |
|
|
|
def MyUNetMidBlock2DCrossAttn_SD_forward(self, x): |
|
x = self.resnets[0](x) |
|
x = self.attentions[0](x) |
|
x = self.resnets[1](x) |
|
return x |
|
|
|
def MyUpBlock2D_SD_forward(self, x): |
|
for i in range(3): |
|
x = self.resnets[i](torch.cat([x, skip.pop()], dim=1)) |
|
x = self.upsamplers[0](x) |
|
return x |
|
|
|
def MyResnetBlock2D_SD_forward(self, x_in): |
|
x = self.norm1(x_in) |
|
x = self.nonlinearity(x) |
|
x = self.conv1(x) |
|
x = self.norm2(x) |
|
x = self.nonlinearity(x) |
|
x = self.conv2(x) |
|
if self.in_channels == self.out_channels: |
|
return x + x_in |
|
return x + self.conv_shortcut(x_in) |
|
|
|
def MyTransformer2DModel_SD_forward(self, x_in): |
|
b, c, h, w = x_in.shape |
|
x = self.norm(x_in) |
|
x = x.permute(0, 2, 3, 1).reshape(b, h * w, c).contiguous() |
|
x = self.proj_in(x) |
|
for block in self.transformer_blocks: |
|
x = x + block.attn1(block.norm1(x)) |
|
x = x + block.ff(block.norm3(x)) |
|
x = self.proj_out(x) |
|
x = x.reshape(b, h, w, c).permute(0, 3, 1, 2).contiguous() |
|
return x + x_in |