# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
Block modules
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import DropPath

from .conv import Conv, DWConv, GhostConv, LightConv, RepConv
# NOTE: the module-level TransformerBlock import is disabled; C3TR imports it lazily in __init__.
# from .transformer import TransformerBlock

__all__ = ('DFL', 'HGBlock', 'HGStem', 'SPP', 'SPPF', 'C1', 'C2', 'C3', 'C2f', 'C3x', 'C3TR', 'C3Ghost',
           'GhostBottleneck', 'Bottleneck', 'BottleneckCSP', 'Proto', 'RepC3')


class DFL(nn.Module):
    """
    Integral module of Distribution Focal Loss (DFL).

    Proposed in Generalized Focal Loss https://ieeexplore.ieee.org/document/9792391
    """

    def __init__(self, c1=16):
        """Create a frozen, bias-free 1x1 conv (c1 -> 1) whose weights are the values 0..c1-1."""
        super().__init__()
        self.conv = nn.Conv2d(c1, 1, 1, bias=False).requires_grad_(False)
        x = torch.arange(c1, dtype=torch.float)
        # Fixed weights 0, 1, ..., c1-1: applied to a softmax this yields the expected bin index.
        self.conv.weight.data[:] = x.view(1, c1, 1, 1)
        self.c1 = c1

    def forward(self, x):
        """Reshape (b, 4 * c1, a) to 4 groups of c1 bins, softmax over the bins and reduce them to (b, 4, a)."""
        b, _, a = x.shape  # batch, channels, anchors
        return self.conv(x.view(b, 4, self.c1, a).transpose(2, 1).softmax(1)).view(b, 4, a)


class Proto(nn.Module):
    """YOLOv8 mask Proto module for segmentation models."""

    def __init__(self, c1, c_=256, c2=32):  # ch_in, number of protos, number of masks
        """Build conv(k=3) -> 2x transposed-conv upsample -> conv(k=3) -> conv to c2 channels."""
        super().__init__()
        self.cv1 = Conv(c1, c_, k=3)
        self.upsample = nn.ConvTranspose2d(c_, c_, 2, 2, 0, bias=True)  # nn.Upsample(scale_factor=2, mode='nearest')
        self.cv2 = Conv(c_, c_, k=3)
        self.cv3 = Conv(c_, c2)

    def forward(self, x):
        """Run cv1, the learned 2x upsample, cv2 and cv3 in sequence."""
        return self.cv3(self.cv2(self.upsample(self.cv1(x))))


class HGStem(nn.Module):
    """
    StemBlock of PPHGNetV2 with 5 convolutions and one maxpool2d.

    https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/backbones/hgnet_v2.py
    """

    def __init__(self, c1, cm, c2):
        """Create the five ReLU convolutions and the 2x2 stride-1 max-pool of the stem."""
        super().__init__()
        self.stem1 = Conv(c1, cm, 3, 2, act=nn.ReLU())
        self.stem2a = Conv(cm, cm // 2, 2, 1, 0, act=nn.ReLU())
        self.stem2b = Conv(cm // 2, cm, 2, 1, 0, act=nn.ReLU())
        self.stem3 = Conv(cm * 2, cm, 3, 2, act=nn.ReLU())
        self.stem4 = Conv(cm, c2, 1, 1, act=nn.ReLU())
        self.pool = nn.MaxPool2d(kernel_size=2, stride=1, padding=0, ceil_mode=True)

    def forward(self, x):
        """Forward pass of the stem: stem1, then a conv branch and a pool branch concatenated, then stem3/stem4."""
        x = self.stem1(x)
        # Pad one pixel on the right and bottom so the 2x2 stride-1 pool returns the pre-pad spatial size
        # (presumably the same holds for the unpadded 2x2 convs below - depends on Conv's argument order).
        x = F.pad(x, [0, 1, 0, 1])
        x2 = self.stem2a(x)
        x2 = F.pad(x2, [0, 1, 0, 1])
        x2 = self.stem2b(x2)
        x1 = self.pool(x)
        x = torch.cat([x1, x2], dim=1)  # cm (pool branch) + cm (conv branch) channels
        x = self.stem3(x)
        x = self.stem4(x)
        return x


class HGBlock(nn.Module):
    """
    HG_Block of PPHGNetV2 with 2 convolutions and LightConv.

    https://github.com/PaddlePaddle/PaddleDetection/blob/develop/ppdet/modeling/backbones/hgnet_v2.py
    """

    def __init__(self, c1, cm, c2, k=3, n=6, lightconv=False, shortcut=False, act=nn.ReLU()):
        """Create n chained (Light)Conv layers plus the squeeze/excitation 1x1 convs applied to their concat."""
        super().__init__()
        block = LightConv if lightconv else Conv
        self.m = nn.ModuleList(block(c1 if i == 0 else cm, cm, k=k, act=act) for i in range(n))
        self.sc = Conv(c1 + n * cm, c2 // 2, 1, 1, act=act)  # squeeze conv
        self.ec = Conv(c2 // 2, c2, 1, 1, act=act)  # excitation conv
        self.add = shortcut and c1 == c2  # residual only when requested and channel counts match

    def forward(self, x):
        """Chain the n layers, concatenate the input with every intermediate output, then squeeze and excite."""
        y = [x]
        y.extend(m(y[-1]) for m in self.m)  # each layer consumes the previous layer's output
        y = self.ec(self.sc(torch.cat(y, 1)))
        return y + x if self.add else y


class SPP(nn.Module):
    """Spatial Pyramid Pooling (SPP) layer https://arxiv.org/abs/1406.4729."""

    def __init__(self, c1, c2, k=(5, 9, 13)):
        """Initialize the SPP layer with input/output channels and pooling kernel sizes."""
        super().__init__()
        c_ = c1 // 2  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1)
        self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k])

    def forward(self, x):
        """Forward pass of the SPP layer, performing spatial pyramid pooling."""
        x = self.cv1(x)
        # Concatenate the reduced features with one max-pooled copy per kernel size.
        return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1))


class SPPF(nn.Module):
    """Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher."""

    def __init__(self, c1, c2, k=5):  # equivalent to SPP(k=(5, 9, 13))
        """Initialize the SPPF layer with input/output channels and a single pooling kernel size."""
        super().__init__()
        c_ = c1 // 2  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c_ * 4, c2, 1, 1)
        self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)

    def forward(self, x):
        """Apply the same max-pool three times in series and fuse all four feature maps with cv2."""
        x = self.cv1(x)
        y1 = self.m(x)
        y2 = self.m(y1)
        return self.cv2(torch.cat((x, y1, y2, self.m(y2)), 1))


class C1(nn.Module):
    """CSP Bottleneck with 1 convolution."""

    def __init__(self, c1, c2, n=1):  # ch_in, ch_out, number
        """Create the 1x1 input conv and a sequence of n 3x3 convs."""
        super().__init__()
        self.cv1 = Conv(c1, c2, 1, 1)
        self.m = nn.Sequential(*(Conv(c2, c2, 3) for _ in range(n)))

    def forward(self, x):
        """Apply cv1, then add the output of the conv stack to cv1's output (residual)."""
        y = self.cv1(x)
        return self.m(y) + y


class C2(nn.Module):
    """CSP Bottleneck with 2 convolutions."""

    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        """Create the split conv, n bottlenecks for one half of the split, and the fusing output conv."""
        super().__init__()
        self.c = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, 2 * self.c, 1, 1)
        self.cv2 = Conv(2 * self.c, c2, 1)  # optional act=FReLU(c2)
        # self.attention = ChannelAttention(2 * self.c)  # or SpatialAttention()
        self.m = nn.Sequential(*(Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0) for _ in range(n)))

    def forward(self, x):
        """Forward pass through the CSP bottleneck with 2 convolutions."""
        a, b = self.cv1(x).chunk(2, 1)  # only `a` goes through the bottlenecks; `b` is passed through
        return self.cv2(torch.cat((self.m(a), b), 1))


class C2f(nn.Module):
    """Faster Implementation of CSP Bottleneck with 2 convolutions."""

    def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5, drop_path=None):
        """
        Create the split conv, n bottlenecks and the output conv.

        Args: ch_in, ch_out, number, shortcut, groups, expansion, and `drop_path`, a sequence holding one
        drop-path rate per bottleneck (indexed 0..n-1); None means 0.0 for every bottleneck.
        """
        super().__init__()
        if drop_path is None:
            drop_path = [0.0] * n
        self.c = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, 2 * self.c, 1, 1)
        self.cv2 = Conv((2 + n) * self.c, c2, 1)  # optional act=FReLU(c2)
        self.m = nn.ModuleList(
            Bottleneck(self.c, self.c, shortcut, g, k=((3, 3), (3, 3)), e=1.0, drop_path=drop_path[i])
            for i in range(n))

    def forward(self, x):
        """Forward pass through C2f layer."""
        y = list(self.cv1(x).chunk(2, 1))
        y.extend(m(y[-1]) for m in self.m)  # each bottleneck consumes the most recent feature map
        return self.cv2(torch.cat(y, 1))

    def forward_split(self, x):
        """Forward pass using split() instead of chunk()."""
        y = list(self.cv1(x).split((self.c, self.c), 1))
        y.extend(m(y[-1]) for m in self.m)
        return self.cv2(torch.cat(y, 1))


class C3(nn.Module):
    """CSP Bottleneck with 3 convolutions."""

    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        """Create two parallel 1x1 input convs, n bottlenecks on the first branch, and the fusing conv."""
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        self.cv3 = Conv(2 * c_, c2, 1)  # optional act=FReLU(c2)
        self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, k=((1, 1), (3, 3)), e=1.0) for _ in range(n)))

    def forward(self, x):
        """Forward pass through the CSP bottleneck with 3 convolutions."""
        return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))


class C3x(C3):
    """C3 module with cross-convolutions."""

    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
        """Initialize C3 and replace its bottlenecks with (1, 3) / (3, 1) cross-convolution bottlenecks."""
        super().__init__(c1, c2, n, shortcut, g, e)
        self.c_ = int(c2 * e)
        self.m = nn.Sequential(*(Bottleneck(self.c_, self.c_, shortcut, g, k=((1, 3), (3, 1)), e=1) for _ in range(n)))


class RepC3(nn.Module):
    """Rep C3."""

    def __init__(self, c1, c2, n=3, e=1.0):
        """
        Create two parallel 1x1 convs, a stack of n RepConv layers on the first branch, and an output conv.

        cv1 and cv2 emit the hidden width c_ so that they match the RepConv stack and cv3 for any `e`;
        with the default e=1.0, c_ equals c2 and cv3 is an identity.
        """
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        self.m = nn.Sequential(*[RepConv(c_, c_) for _ in range(n)])
        self.cv3 = Conv(c_, c2, 1, 1) if c_ != c2 else nn.Identity()

    def forward(self, x):
        """Sum the RepConv branch and the cv2 branch, then project with cv3."""
        return self.cv3(self.m(self.cv1(x)) + self.cv2(x))


class C3TR(C3):
    """C3 module with TransformerBlock()."""

    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
        """Initialize C3 and replace its bottleneck stack with a TransformerBlock."""
        super().__init__(c1, c2, n, shortcut, g, e)
        # Imported here because the module-level import is disabled in this file; without it the name
        # TransformerBlock would be undefined when C3TR is instantiated.
        # NOTE(review): assumes the sibling `.transformer` module provides TransformerBlock - confirm.
        from .transformer import TransformerBlock
        c_ = int(c2 * e)
        self.m = TransformerBlock(c_, c_, 4, n)


class C3Ghost(C3):
    """C3 module with GhostBottleneck()."""

    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
        """Initialize C3 and replace its bottleneck stack with n GhostBottleneck modules."""
        super().__init__(c1, c2, n, shortcut, g, e)
        c_ = int(c2 * e)  # hidden channels
        self.m = nn.Sequential(*(GhostBottleneck(c_, c_) for _ in range(n)))


class GhostBottleneck(nn.Module):
    """Ghost Bottleneck https://github.com/huawei-noah/ghostnet."""

    def __init__(self, c1, c2, k=3, s=1):  # ch_in, ch_out, kernel, stride
        """Create the GhostConv main path and the shortcut; depthwise convs are only used when s == 2."""
        super().__init__()
        c_ = c2 // 2
        self.conv = nn.Sequential(
            GhostConv(c1, c_, 1, 1),  # pw
            DWConv(c_, c_, k, s, act=False) if s == 2 else nn.Identity(),  # dw
            GhostConv(c_, c2, 1, 1, act=False))  # pw-linear
        self.shortcut = nn.Sequential(DWConv(c1, c1, k, s, act=False), Conv(c1, c2, 1, 1,
                                                                            act=False)) if s == 2 else nn.Identity()

    def forward(self, x):
        """Return the sum of the main path and the shortcut path."""
        return self.conv(x) + self.shortcut(x)


class Bottleneck(nn.Module):
    """Standard bottleneck."""

    def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5, drop_path=0.0):
        """
        Create the two convolutions and the optional drop-path layer.

        Args: ch_in, ch_out, shortcut, groups, kernels, expand, and `drop_path`, the rate passed to
        DropPath (an identity is used when the rate is not positive).
        """
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, k[0], 1)
        self.cv2 = Conv(c_, c2, k[1], 1, g=g)
        self.add = shortcut and c1 == c2  # residual only when requested and channel counts match
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        """Apply cv1 then cv2; on the residual path the branch goes through drop-path before being added to x."""
        y = self.cv2(self.cv1(x))
        # Drop-path is applied only together with the residual connection.
        return x + self.drop_path1(y) if self.add else y


class BottleneckCSP(nn.Module):
    """CSP Bottleneck https://github.com/WongKinYiu/CrossStagePartialNetworks."""

    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        """Create the bottleneck branch (cv1 -> n bottlenecks -> cv3), the bypass conv cv2 and the fusion layers."""
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = nn.Conv2d(c1, c_, 1, 1, bias=False)
        self.cv3 = nn.Conv2d(c_, c_, 1, 1, bias=False)
        self.cv4 = Conv(2 * c_, c2, 1, 1)
        self.bn = nn.BatchNorm2d(2 * c_)  # applied to cat(cv2, cv3)
        self.act = nn.SiLU()
        self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))

    def forward(self, x):
        """Concatenate the bottleneck branch with the bypass branch, then apply BatchNorm, SiLU and cv4."""
        y1 = self.cv3(self.m(self.cv1(x)))
        y2 = self.cv2(x)
        return self.cv4(self.act(self.bn(torch.cat((y1, y2), 1))))