Spaces:
Build error
Build error
File size: 2,788 Bytes
d3c7726 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 |
from torch import nn
from yolov6.layers.common import RepVGGBlock, RepBlock, SimSPPF
class EfficientRep(nn.Module):
    '''EfficientRep Backbone.

    EfficientRep is handcrafted by hardware-aware neural network design.
    With rep-style struct, EfficientRep is friendly to high-computation
    hardware (e.g. GPU).

    The network is a stem followed by four stages (``ERBlock_2`` ..
    ``ERBlock_5``). The stem and the first layer of every stage are
    ``RepVGGBlock`` modules built with ``kernel_size=3, stride=2``; each
    stage then applies a ``RepBlock``, and the last stage additionally ends
    with a ``SimSPPF`` layer.

    Args:
        in_channels: Channel count of the input passed to the stem.
        channels_list: Indexable collection of output channel counts.
            Indices 0..4 are read: index 0 for the stem and indices 1..4
            for ``ERBlock_2`` .. ``ERBlock_5``. Required.
        num_repeats: Indexable collection of ``n`` values handed to each
            stage's ``RepBlock``. Indices 1..4 are read; index 0 is not
            used by this class. Required.

    Raises:
        AssertionError: If ``channels_list`` or ``num_repeats`` is ``None``.
    '''

    def __init__(
        self,
        in_channels=3,
        channels_list=None,
        num_repeats=None,
    ):
        super().__init__()

        # Both arguments default to None only so they can be passed by
        # keyword; the backbone cannot be built without them.
        assert channels_list is not None, 'channels_list must be provided'
        assert num_repeats is not None, 'num_repeats must be provided'

        self.stem = RepVGGBlock(
            in_channels=in_channels,
            out_channels=channels_list[0],
            kernel_size=3,
            stride=2,
        )

        self.ERBlock_2 = nn.Sequential(
            RepVGGBlock(
                in_channels=channels_list[0],
                out_channels=channels_list[1],
                kernel_size=3,
                stride=2,
            ),
            RepBlock(
                in_channels=channels_list[1],
                out_channels=channels_list[1],
                n=num_repeats[1],
            ),
        )

        self.ERBlock_3 = nn.Sequential(
            RepVGGBlock(
                in_channels=channels_list[1],
                out_channels=channels_list[2],
                kernel_size=3,
                stride=2,
            ),
            RepBlock(
                in_channels=channels_list[2],
                out_channels=channels_list[2],
                n=num_repeats[2],
            ),
        )

        self.ERBlock_4 = nn.Sequential(
            RepVGGBlock(
                in_channels=channels_list[2],
                out_channels=channels_list[3],
                kernel_size=3,
                stride=2,
            ),
            RepBlock(
                in_channels=channels_list[3],
                out_channels=channels_list[3],
                n=num_repeats[3],
            ),
        )

        # The final stage appends a SimSPPF layer after its RepBlock.
        self.ERBlock_5 = nn.Sequential(
            RepVGGBlock(
                in_channels=channels_list[3],
                out_channels=channels_list[4],
                kernel_size=3,
                stride=2,
            ),
            RepBlock(
                in_channels=channels_list[4],
                out_channels=channels_list[4],
                n=num_repeats[4],
            ),
            SimSPPF(
                in_channels=channels_list[4],
                out_channels=channels_list[4],
                kernel_size=5,
            ),
        )

    def forward(self, x):
        '''Run the backbone and collect the outputs of the last three stages.

        Args:
            x: Input passed to the stem (presumably an image batch whose
                channel dimension equals ``in_channels`` -- confirm against
                the caller).

        Returns:
            A 3-tuple ``(ERBlock_3 output, ERBlock_4 output,
            ERBlock_5 output)``, in that order. The stem and ``ERBlock_2``
            outputs are intermediate only and are not returned.
        '''
        outputs = []
        x = self.stem(x)
        x = self.ERBlock_2(x)
        x = self.ERBlock_3(x)
        outputs.append(x)
        x = self.ERBlock_4(x)
        outputs.append(x)
        x = self.ERBlock_5(x)
        outputs.append(x)
        return tuple(outputs)
|