SVHN-Recognition / model.py
MuGeminorum
fix cpu
cadcfd6
raw
history blame
No virus
5.88 kB
import os
import glob
import torch
import torch.jit
import torch.nn as nn
class Model(torch.jit.ScriptModule):
    """Convolutional network with one length head and five digit heads.

    The network is eight conv/batch-norm/ReLU/max-pool/dropout stages, two
    fully connected layers of width 3072, and six independent linear heads
    (one with 7 outputs, five with 11 outputs each).

    The first fully connected layer expects a flattened ``192 * 7 * 7``
    feature map. By the layer arithmetic below, a 54x54 input produces that
    size. NOTE(review): the intended input resolution is not stated in this
    file -- confirm against the data pipeline.
    """

    CHECKPOINT_FILENAME_PATTERN = 'model-{}.pth'

    # NOTE(review): '_features' and '_classifier' are listed here but are
    # never assigned in __init__; kept unchanged for compatibility.
    __constants__ = [
        '_hidden1', '_hidden2', '_hidden3', '_hidden4', '_hidden5', '_hidden6',
        '_hidden7', '_hidden8', '_hidden9', '_hidden10', '_features', '_classifier',
        '_digit_length', '_digit1', '_digit2', '_digit3', '_digit4', '_digit5'
    ]

    @staticmethod
    def _conv_block(in_channels, out_channels, pool_stride):
        """Build one conv -> batch-norm -> ReLU -> max-pool -> dropout stage.

        The layer order inside the returned ``nn.Sequential`` is fixed so that
        ``state_dict`` keys (``<name>.0.weight``, ``<name>.1.weight``, ...)
        stay compatible with previously saved checkpoints.
        """
        return nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=5,
                padding=2
            ),
            nn.BatchNorm2d(num_features=out_channels),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=pool_stride, padding=1),
            nn.Dropout(0.2)
        )

    @staticmethod
    def _step_from_path(path_to_checkpoint_file):
        """Return the integer step encoded in a checkpoint file name.

        Both ``/`` and ``\\`` are treated as directory separators, so the
        result does not depend on the platform the path was written on.

        Raises:
            ValueError: if the text between the pattern's prefix and suffix
                lengths is not an integer.
        """
        filename = path_to_checkpoint_file.replace('\\', '/').split('/')[-1]
        # Derive the slice bounds from the pattern instead of hard-coding 6/4.
        prefix, suffix = Model.CHECKPOINT_FILENAME_PATTERN.split('{}')
        return int(filename[len(prefix):len(filename) - len(suffix)])

    def __init__(self):
        super(Model, self).__init__()

        # Pool stride alternates 2, 1, 2, 1, ...: stride-2 stages roughly
        # halve the spatial size, stride-1 stages grow it by one pixel.
        self._hidden1 = self._conv_block(3, 48, 2)
        self._hidden2 = self._conv_block(48, 64, 1)
        self._hidden3 = self._conv_block(64, 128, 2)
        self._hidden4 = self._conv_block(128, 160, 1)
        self._hidden5 = self._conv_block(160, 192, 2)
        self._hidden6 = self._conv_block(192, 192, 1)
        self._hidden7 = self._conv_block(192, 192, 2)
        self._hidden8 = self._conv_block(192, 192, 1)

        self._hidden9 = nn.Sequential(
            nn.Linear(192 * 7 * 7, 3072),
            nn.ReLU()
        )
        self._hidden10 = nn.Sequential(
            nn.Linear(3072, 3072),
            nn.ReLU()
        )

        # Output heads, all fed by the same 3072-wide feature vector.
        # NOTE(review): presumably 7 length classes and 10 digits + "blank"
        # per position -- confirm against the label encoding.
        self._digit_length = nn.Sequential(nn.Linear(3072, 7))
        self._digit1 = nn.Sequential(nn.Linear(3072, 11))
        self._digit2 = nn.Sequential(nn.Linear(3072, 11))
        self._digit3 = nn.Sequential(nn.Linear(3072, 11))
        self._digit4 = nn.Sequential(nn.Linear(3072, 11))
        self._digit5 = nn.Sequential(nn.Linear(3072, 11))

    @torch.jit.script_method
    def forward(self, x):
        # Runs the eight conv stages, flattens to (batch, 192 * 7 * 7), runs
        # the two fully connected layers and returns the logits of the six
        # heads as a tuple: (length, digit1, digit2, digit3, digit4, digit5).
        x = self._hidden1(x)
        x = self._hidden2(x)
        x = self._hidden3(x)
        x = self._hidden4(x)
        x = self._hidden5(x)
        x = self._hidden6(x)
        x = self._hidden7(x)
        x = self._hidden8(x)
        x = x.view(x.size(0), 192 * 7 * 7)
        x = self._hidden9(x)
        x = self._hidden10(x)

        length_logits = self._digit_length(x)
        digit1_logits = self._digit1(x)
        digit2_logits = self._digit2(x)
        digit3_logits = self._digit3(x)
        digit4_logits = self._digit4(x)
        digit5_logits = self._digit5(x)

        return length_logits, digit1_logits, digit2_logits, digit3_logits, digit4_logits, digit5_logits

    def store(self, path_to_dir, step, maximum=5):
        """Save the state dict as ``model-<step>.pth`` and prune old checkpoints.

        Before saving, the checkpoints with the smallest steps are deleted
        until fewer than ``maximum`` remain, so that at most ``maximum`` files
        exist after the new one is written.

        Args:
            path_to_dir: directory that holds the checkpoint files.
            step: value substituted into the checkpoint file name.
            maximum: number of checkpoints to keep (expected to be >= 1).

        Returns:
            The path of the checkpoint file that was written.

        Raises:
            ValueError: if an existing file matching the checkpoint pattern
                does not carry an integer step in its name.
        """
        pattern = Model.CHECKPOINT_FILENAME_PATTERN
        existing = glob.glob(os.path.join(path_to_dir, pattern.format('*')))

        # Pair each file with its step so the real path returned by glob is
        # removed, rather than a name rebuilt from the parsed integer.
        by_step = sorted(
            (Model._step_from_path(path), path) for path in existing
        )
        surplus = len(by_step) - maximum + 1
        if surplus > 0:
            for _, stale_path in by_step[:surplus]:
                os.remove(stale_path)

        path_to_checkpoint_file = os.path.join(
            path_to_dir, pattern.format(step)
        )
        torch.save(self.state_dict(), path_to_checkpoint_file)
        return path_to_checkpoint_file

    def restore(self, path_to_checkpoint_file):
        """Load a checkpoint onto the CPU and return the step in its file name.

        Args:
            path_to_checkpoint_file: path to a file written by ``store``.

        Returns:
            The integer step parsed from the file name.

        Raises:
            ValueError: if the file name does not carry an integer step.
        """
        self.load_state_dict(
            torch.load(
                path_to_checkpoint_file,
                # Always map to CPU so GPU-saved checkpoints load anywhere.
                map_location=torch.device('cpu')
            )
        )
        return Model._step_from_path(path_to_checkpoint_file)