# Source: cloudseg-models/models/unetmobv2.py (Hugging Face Hub, 773 bytes)
# Uploaded by XavierJiezou in commit 63a9590 ("Add files using upload-large-folder tool").
# -*- coding: utf-8 -*-
# @Time : 2024/8/6 3:44 PM
# @Author : xiaoshun
# @Email : 3038523973@qq.com
# @File : unetmobv2.py
# @Software: PyCharm
import segmentation_models_pytorch as smp
import torch
from torch import nn as nn
class UNetMobV2(nn.Module):
    """U-Net segmentation model with a MobileNetV2 encoder.

    This is a thin wrapper around ``segmentation_models_pytorch.Unet``. The
    encoder is created with ``encoder_weights=None``, so no pretrained
    encoder weights are requested.

    Args:
        num_classes: Forwarded to ``smp.Unet`` as ``classes``.
        in_channels: Forwarded to ``smp.Unet`` as ``in_channels``
            (defaults to 3).
    """

    def __init__(self, num_classes: int, in_channels: int = 3) -> None:
        super().__init__()
        unet_kwargs = {
            'encoder_name': 'mobilenet_v2',
            'encoder_weights': None,
            'in_channels': in_channels,
            'classes': num_classes,
        }
        # The attribute must stay named ``backbone``: as a registered
        # submodule its name is part of every parameter's state_dict key.
        self.backbone = smp.Unet(**unet_kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the wrapped U-Net and return its output as is."""
        return self.backbone(x)
if __name__ == '__main__':
    # Smoke test: feed one random 3-channel 224x224 image through a
    # two-class model and print the size of the result.
    dummy_batch = torch.rand((1, 3, 224, 224))
    net = UNetMobV2(num_classes=2)
    prediction = net(dummy_batch)
    print(prediction.size())