Cyril666 committed
Commit 3bb6945 · 1 Parent(s): db445c8

First model version

app.py CHANGED
@@ -10,7 +10,7 @@ from demo import get_model, preprocess, postprocess, load
 from utils import Config, Logger, CharsetMapper
 
 def process_image(image):
-    config = Config('configs/train_abinet.yaml')
+    config = Config('configs/rec/train_abinet.yaml')
     config.model_vision_checkpoint = None
     model = get_model(config)
     model = load(model, 'workdir/train-abinet/best-train-abinet.pth')
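
For context, a minimal sketch of loading the recognizer against the relocated config, mirroring only the setup lines changed above (the preprocess/postprocess part of process_image is omitted):

from demo import get_model, load
from utils import Config

config = Config('configs/rec/train_abinet.yaml')                   # new location after this commit
config.model_vision_checkpoint = None                              # skip the separate vision checkpoint
model = get_model(config)
model = load(model, 'workdir/train-abinet/best-train-abinet.pth')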
configs/{template.yaml β†’ rec/template.yaml} RENAMED
File without changes
configs/{train_abinet.yaml β†’ rec/train_abinet.yaml} RENAMED
File without changes
modules/__init__.py DELETED
File without changes
modules/attention.py DELETED
@@ -1,97 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- from .transformer import PositionalEncoding
4
-
5
- class Attention(nn.Module):
6
- def __init__(self, in_channels=512, max_length=25, n_feature=256):
7
- super().__init__()
8
- self.max_length = max_length
9
-
10
- self.f0_embedding = nn.Embedding(max_length, in_channels)
11
- self.w0 = nn.Linear(max_length, n_feature)
12
- self.wv = nn.Linear(in_channels, in_channels)
13
- self.we = nn.Linear(in_channels, max_length)
14
-
15
- self.active = nn.Tanh()
16
- self.softmax = nn.Softmax(dim=2)
17
-
18
- def forward(self, enc_output):
19
- enc_output = enc_output.permute(0, 2, 3, 1).flatten(1, 2)
20
- reading_order = torch.arange(self.max_length, dtype=torch.long, device=enc_output.device)
21
- reading_order = reading_order.unsqueeze(0).expand(enc_output.size(0), -1) # (S,) -> (B, S)
22
- reading_order_embed = self.f0_embedding(reading_order) # b,25,512
23
-
24
- t = self.w0(reading_order_embed.permute(0, 2, 1)) # b,512,256
25
- t = self.active(t.permute(0, 2, 1) + self.wv(enc_output)) # b,256,512
26
-
27
- attn = self.we(t) # b,256,25
28
- attn = self.softmax(attn.permute(0, 2, 1)) # b,25,256
29
- g_output = torch.bmm(attn, enc_output) # b,25,512
30
- return g_output, attn.view(*attn.shape[:2], 8, 32)
31
-
32
-
33
- def encoder_layer(in_c, out_c, k=3, s=2, p=1):
34
- return nn.Sequential(nn.Conv2d(in_c, out_c, k, s, p),
35
- nn.BatchNorm2d(out_c),
36
- nn.ReLU(True))
37
-
38
- def decoder_layer(in_c, out_c, k=3, s=1, p=1, mode='nearest', scale_factor=None, size=None):
39
- align_corners = None if mode=='nearest' else True
40
- return nn.Sequential(nn.Upsample(size=size, scale_factor=scale_factor,
41
- mode=mode, align_corners=align_corners),
42
- nn.Conv2d(in_c, out_c, k, s, p),
43
- nn.BatchNorm2d(out_c),
44
- nn.ReLU(True))
45
-
46
-
47
- class PositionAttention(nn.Module):
48
- def __init__(self, max_length, in_channels=512, num_channels=64,
49
- h=8, w=32, mode='nearest', **kwargs):
50
- super().__init__()
51
- self.max_length = max_length
52
- self.k_encoder = nn.Sequential(
53
- encoder_layer(in_channels, num_channels, s=(1, 2)),
54
- encoder_layer(num_channels, num_channels, s=(2, 2)),
55
- encoder_layer(num_channels, num_channels, s=(2, 2)),
56
- encoder_layer(num_channels, num_channels, s=(2, 2))
57
- )
58
- self.k_decoder = nn.Sequential(
59
- decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode),
60
- decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode),
61
- decoder_layer(num_channels, num_channels, scale_factor=2, mode=mode),
62
- decoder_layer(num_channels, in_channels, size=(h, w), mode=mode)
63
- )
64
-
65
- self.pos_encoder = PositionalEncoding(in_channels, dropout=0, max_len=max_length)
66
- self.project = nn.Linear(in_channels, in_channels)
67
-
68
- def forward(self, x):
69
- N, E, H, W = x.size()
70
- k, v = x, x # (N, E, H, W)
71
-
72
- # calculate key vector
73
- features = []
74
- for i in range(0, len(self.k_encoder)):
75
- k = self.k_encoder[i](k)
76
- features.append(k)
77
- for i in range(0, len(self.k_decoder) - 1):
78
- k = self.k_decoder[i](k)
79
- k = k + features[len(self.k_decoder) - 2 - i]
80
- k = self.k_decoder[-1](k)
81
-
82
- # calculate query vector
83
- # TODO q=f(q,k)
84
- zeros = x.new_zeros((self.max_length, N, E)) # (T, N, E)
85
- q = self.pos_encoder(zeros) # (T, N, E)
86
- q = q.permute(1, 0, 2) # (N, T, E)
87
- q = self.project(q) # (N, T, E)
88
-
89
- # calculate attention
90
- attn_scores = torch.bmm(q, k.flatten(2, 3)) # (N, T, (H*W))
91
- attn_scores = attn_scores / (E ** 0.5)
92
- attn_scores = torch.softmax(attn_scores, dim=-1)
93
-
94
- v = v.permute(0, 2, 3, 1).view(N, -1, E) # (N, (H*W), E)
95
- attn_vecs = torch.bmm(attn_scores, v) # (N, T, E)
96
-
97
- return attn_vecs, attn_scores.view(N, -1, H, W)
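
For reference, a short shape check for the deleted PositionAttention; this is a sketch that assumes the pre-deletion modules/ package (e.g. from the parent commit db445c8) is still importable:

import torch
from modules.attention import PositionAttention  # pre-deletion module path

attn = PositionAttention(max_length=26)           # 25 characters + one stop token
feats = torch.randn(2, 512, 8, 32)                # (N, E, H, W) backbone features
attn_vecs, attn_scores = attn(feats)
print(attn_vecs.shape)                            # torch.Size([2, 26, 512])
print(attn_scores.shape)                          # torch.Size([2, 26, 8, 32])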
 
modules/backbone.py DELETED
@@ -1,36 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- from fastai.vision import *
4
-
5
- from modules.model import _default_tfmer_cfg
6
- from modules.resnet import resnet45
7
- from modules.transformer import (PositionalEncoding,
8
- TransformerEncoder,
9
- TransformerEncoderLayer)
10
-
11
-
12
- class ResTranformer(nn.Module):
13
- def __init__(self, config):
14
- super().__init__()
15
- self.resnet = resnet45()
16
-
17
- self.d_model = ifnone(config.model_vision_d_model, _default_tfmer_cfg['d_model'])
18
- nhead = ifnone(config.model_vision_nhead, _default_tfmer_cfg['nhead'])
19
- d_inner = ifnone(config.model_vision_d_inner, _default_tfmer_cfg['d_inner'])
20
- dropout = ifnone(config.model_vision_dropout, _default_tfmer_cfg['dropout'])
21
- activation = ifnone(config.model_vision_activation, _default_tfmer_cfg['activation'])
22
- num_layers = ifnone(config.model_vision_backbone_ln, 2)
23
-
24
- self.pos_encoder = PositionalEncoding(self.d_model, max_len=8*32)
25
- encoder_layer = TransformerEncoderLayer(d_model=self.d_model, nhead=nhead,
26
- dim_feedforward=d_inner, dropout=dropout, activation=activation)
27
- self.transformer = TransformerEncoder(encoder_layer, num_layers)
28
-
29
- def forward(self, images):
30
- feature = self.resnet(images)
31
- n, c, h, w = feature.shape
32
- feature = feature.view(n, c, -1).permute(2, 0, 1)
33
- feature = self.pos_encoder(feature)
34
- feature = self.transformer(feature)
35
- feature = feature.permute(1, 2, 0).view(n, c, h, w)
36
- return feature
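
The reshapes in ResTranformer.forward are easy to lose track of; here is a standalone sketch with a dummy feature map (no config or ResNet required):

import torch

feature = torch.randn(2, 512, 8, 32)               # (N, C, H, W), as produced by resnet45()
n, c, h, w = feature.shape
seq = feature.view(n, c, -1).permute(2, 0, 1)      # (H*W, N, C) = (256, 2, 512): the transformer input
restored = seq.permute(1, 2, 0).view(n, c, h, w)   # back to (N, C, H, W) after the encoder
print(seq.shape, restored.shape)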
 
modules/model.py DELETED
@@ -1,50 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
-
4
- from utils import CharsetMapper
5
-
6
-
7
- _default_tfmer_cfg = dict(d_model=512, nhead=8, d_inner=2048, # 1024
8
- dropout=0.1, activation='relu')
9
-
10
- class Model(nn.Module):
11
-
12
- def __init__(self, config):
13
- super().__init__()
14
- self.max_length = config.dataset_max_length + 1
15
- self.charset = CharsetMapper(config.dataset_charset_path, max_length=self.max_length)
16
-
17
- def load(self, source, device=None, strict=True):
18
- state = torch.load(source, map_location=device)
19
- self.load_state_dict(state['model'], strict=strict)
20
-
21
- def _get_length(self, logit, dim=-1):
22
- """ Greed decoder to obtain length from logit"""
23
- out = (logit.argmax(dim=-1) == self.charset.null_label)
24
- abn = out.any(dim)
25
- out = ((out.cumsum(dim) == 1) & out).max(dim)[1]
26
- out = out + 1 # additional end token
27
- out = torch.where(abn, out, out.new_tensor(logit.shape[1]))
28
- return out
29
-
30
- @staticmethod
31
- def _get_padding_mask(length, max_length):
32
- length = length.unsqueeze(-1)
33
- grid = torch.arange(0, max_length, device=length.device).unsqueeze(0)
34
- return grid >= length
35
-
36
- @staticmethod
37
- def _get_square_subsequent_mask(sz, device, diagonal=0, fw=True):
38
- r"""Generate a square mask for the sequence. The masked positions are filled with float('-inf').
39
- Unmasked positions are filled with float(0.0).
40
- """
41
- mask = (torch.triu(torch.ones(sz, sz, device=device), diagonal=diagonal) == 1)
42
- if fw: mask = mask.transpose(0, 1)
43
- mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
44
- return mask
45
-
46
- @staticmethod
47
- def _get_location_mask(sz, device=None):
48
- mask = torch.eye(sz, device=device)
49
- mask = mask.float().masked_fill(mask == 1, float('-inf'))
50
- return mask
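
The mask helpers on Model are static methods and can be sanity-checked in isolation; a sketch assuming the pre-deletion modules/ package is importable:

import torch
from modules.model import Model  # pre-deletion module path

lengths = torch.tensor([3, 5])
print(Model._get_padding_mask(lengths, max_length=6))
# tensor([[False, False, False,  True,  True,  True],
#         [False, False, False, False, False,  True]])

print(Model._get_location_mask(4))
# 4x4 zeros with -inf on the diagonal, so a position may not attend to itself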
 
modules/model_abinet.py DELETED
@@ -1,30 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- from fastai.vision import *
4
-
5
- from .model_vision import BaseVision
6
- from .model_language import BCNLanguage
7
- from .model_alignment import BaseAlignment
8
-
9
-
10
- class ABINetModel(nn.Module):
11
- def __init__(self, config):
12
- super().__init__()
13
- self.use_alignment = ifnone(config.model_use_alignment, True)
14
- self.max_length = config.dataset_max_length + 1 # additional stop token
15
- self.vision = BaseVision(config)
16
- self.language = BCNLanguage(config)
17
- if self.use_alignment: self.alignment = BaseAlignment(config)
18
-
19
- def forward(self, images, *args):
20
- v_res = self.vision(images)
21
- v_tokens = torch.softmax(v_res['logits'], dim=-1)
22
- v_lengths = v_res['pt_lengths'].clamp_(2, self.max_length) # TODO:move to langauge model
23
-
24
- l_res = self.language(v_tokens, v_lengths)
25
- if not self.use_alignment:
26
- return l_res, v_res
27
- l_feature, v_feature = l_res['feature'], v_res['feature']
28
-
29
- a_res = self.alignment(l_feature, v_feature)
30
- return a_res, l_res, v_res
 
modules/model_abinet_iter.py DELETED
@@ -1,34 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- from fastai.vision import *
4
-
5
- from .model_vision import BaseVision
6
- from .model_language import BCNLanguage
7
- from .model_alignment import BaseAlignment
8
-
9
-
10
- class ABINetIterModel(nn.Module):
11
- def __init__(self, config):
12
- super().__init__()
13
- self.iter_size = ifnone(config.model_iter_size, 1)
14
- self.max_length = config.dataset_max_length + 1 # additional stop token
15
- self.vision = BaseVision(config)
16
- self.language = BCNLanguage(config)
17
- self.alignment = BaseAlignment(config)
18
-
19
- def forward(self, images, *args):
20
- v_res = self.vision(images)
21
- a_res = v_res
22
- all_l_res, all_a_res = [], []
23
- for _ in range(self.iter_size):
24
- tokens = torch.softmax(a_res['logits'], dim=-1)
25
- lengths = a_res['pt_lengths']
26
- lengths.clamp_(2, self.max_length) # TODO:move to langauge model
27
- l_res = self.language(tokens, lengths)
28
- all_l_res.append(l_res)
29
- a_res = self.alignment(l_res['feature'], v_res['feature'])
30
- all_a_res.append(a_res)
31
- if self.training:
32
- return all_a_res, all_l_res, v_res
33
- else:
34
- return a_res, all_l_res[-1], v_res
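
The per-iteration handoff above (softmax over the previous logits, lengths clamped into [2, max_length]) can be illustrated with dummy tensors; the charset size of 37 below is purely hypothetical:

import torch

N, T, C = 2, 26, 37                        # batch, max_length (25 + stop token), hypothetical charset size
logits = torch.randn(N, T, C)              # stands in for a_res['logits']
tokens = torch.softmax(logits, dim=-1)     # next input to the language model
lengths = torch.randint(1, 40, (N,)).clamp_(2, T)
print(tokens.shape, lengths)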
 
modules/model_alignment.py DELETED
@@ -1,34 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- from fastai.vision import *
4
-
5
- from modules.model import Model, _default_tfmer_cfg
6
-
7
-
8
- class BaseAlignment(Model):
9
- def __init__(self, config):
10
- super().__init__(config)
11
- d_model = ifnone(config.model_alignment_d_model, _default_tfmer_cfg['d_model'])
12
-
13
- self.loss_weight = ifnone(config.model_alignment_loss_weight, 1.0)
14
- self.max_length = config.dataset_max_length + 1 # additional stop token
15
- self.w_att = nn.Linear(2 * d_model, d_model)
16
- self.cls = nn.Linear(d_model, self.charset.num_classes)
17
-
18
- def forward(self, l_feature, v_feature):
19
- """
20
- Args:
21
- l_feature: (N, T, E) where T is length, N is batch size and d is dim of model
22
- v_feature: (N, T, E) shape the same as l_feature
23
- l_lengths: (N,)
24
- v_lengths: (N,)
25
- """
26
- f = torch.cat((l_feature, v_feature), dim=2)
27
- f_att = torch.sigmoid(self.w_att(f))
28
- output = f_att * v_feature + (1 - f_att) * l_feature
29
-
30
- logits = self.cls(output) # (N, T, C)
31
- pt_lengths = self._get_length(logits)
32
-
33
- return {'logits': logits, 'pt_lengths': pt_lengths, 'loss_weight':self.loss_weight,
34
- 'name': 'alignment'}
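
The alignment step is a learned sigmoid gate between the two feature streams; with dummy tensors (d_model shrunk to 8 purely for readability) the fusion looks like this:

import torch
import torch.nn as nn

d_model, N, T = 8, 2, 26
w_att = nn.Linear(2 * d_model, d_model)
l_feature = torch.randn(N, T, d_model)                   # language features
v_feature = torch.randn(N, T, d_model)                   # vision features
f_att = torch.sigmoid(w_att(torch.cat((l_feature, v_feature), dim=2)))
output = f_att * v_feature + (1 - f_att) * l_feature     # gated fusion, shape (N, T, d_model)
print(output.shape)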
 
modules/model_language.py DELETED
@@ -1,67 +0,0 @@
1
- import logging
2
- import torch.nn as nn
3
- from fastai.vision import *
4
-
5
- from modules.model import _default_tfmer_cfg
6
- from modules.model import Model
7
- from modules.transformer import (PositionalEncoding,
8
- TransformerDecoder,
9
- TransformerDecoderLayer)
10
-
11
-
12
- class BCNLanguage(Model):
13
- def __init__(self, config):
14
- super().__init__(config)
15
- d_model = ifnone(config.model_language_d_model, _default_tfmer_cfg['d_model'])
16
- nhead = ifnone(config.model_language_nhead, _default_tfmer_cfg['nhead'])
17
- d_inner = ifnone(config.model_language_d_inner, _default_tfmer_cfg['d_inner'])
18
- dropout = ifnone(config.model_language_dropout, _default_tfmer_cfg['dropout'])
19
- activation = ifnone(config.model_language_activation, _default_tfmer_cfg['activation'])
20
- num_layers = ifnone(config.model_language_num_layers, 4)
21
- self.d_model = d_model
22
- self.detach = ifnone(config.model_language_detach, True)
23
- self.use_self_attn = ifnone(config.model_language_use_self_attn, False)
24
- self.loss_weight = ifnone(config.model_language_loss_weight, 1.0)
25
- self.max_length = config.dataset_max_length + 1 # additional stop token
26
- self.debug = ifnone(config.global_debug, False)
27
-
28
- self.proj = nn.Linear(self.charset.num_classes, d_model, False)
29
- self.token_encoder = PositionalEncoding(d_model, max_len=self.max_length)
30
- self.pos_encoder = PositionalEncoding(d_model, dropout=0, max_len=self.max_length)
31
- decoder_layer = TransformerDecoderLayer(d_model, nhead, d_inner, dropout,
32
- activation, self_attn=self.use_self_attn, debug=self.debug)
33
- self.model = TransformerDecoder(decoder_layer, num_layers)
34
-
35
- self.cls = nn.Linear(d_model, self.charset.num_classes)
36
-
37
- if config.model_language_checkpoint is not None:
38
- logging.info(f'Read language model from {config.model_language_checkpoint}.')
39
- self.load(config.model_language_checkpoint)
40
-
41
- def forward(self, tokens, lengths):
42
- """
43
- Args:
44
- tokens: (N, T, C) where T is length, N is batch size and C is classes number
45
- lengths: (N,)
46
- """
47
- if self.detach: tokens = tokens.detach()
48
- embed = self.proj(tokens) # (N, T, E)
49
- embed = embed.permute(1, 0, 2) # (T, N, E)
50
- embed = self.token_encoder(embed) # (T, N, E)
51
- padding_mask = self._get_padding_mask(lengths, self.max_length)
52
-
53
- zeros = embed.new_zeros(*embed.shape)
54
- qeury = self.pos_encoder(zeros)
55
- location_mask = self._get_location_mask(self.max_length, tokens.device)
56
- output = self.model(qeury, embed,
57
- tgt_key_padding_mask=padding_mask,
58
- memory_mask=location_mask,
59
- memory_key_padding_mask=padding_mask) # (T, N, E)
60
- output = output.permute(1, 0, 2) # (N, T, E)
61
-
62
- logits = self.cls(output) # (N, T, C)
63
- pt_lengths = self._get_length(logits)
64
-
65
- res = {'feature': output, 'logits': logits, 'pt_lengths': pt_lengths,
66
- 'loss_weight':self.loss_weight, 'name': 'language'}
67
- return res
 
modules/model_vision.py DELETED
@@ -1,47 +0,0 @@
1
- import logging
2
- import torch.nn as nn
3
- from fastai.vision import *
4
-
5
- from modules.attention import *
6
- from modules.backbone import ResTranformer
7
- from modules.model import Model
8
- from modules.resnet import resnet45
9
-
10
-
11
- class BaseVision(Model):
12
- def __init__(self, config):
13
- super().__init__(config)
14
- self.loss_weight = ifnone(config.model_vision_loss_weight, 1.0)
15
- self.out_channels = ifnone(config.model_vision_d_model, 512)
16
-
17
- if config.model_vision_backbone == 'transformer':
18
- self.backbone = ResTranformer(config)
19
- else: self.backbone = resnet45()
20
-
21
- if config.model_vision_attention == 'position':
22
- mode = ifnone(config.model_vision_attention_mode, 'nearest')
23
- self.attention = PositionAttention(
24
- max_length=config.dataset_max_length + 1, # additional stop token
25
- mode=mode,
26
- )
27
- elif config.model_vision_attention == 'attention':
28
- self.attention = Attention(
29
- max_length=config.dataset_max_length + 1, # additional stop token
30
- n_feature=8*32,
31
- )
32
- else:
33
- raise Exception(f'{config.model_vision_attention} is not valid.')
34
- self.cls = nn.Linear(self.out_channels, self.charset.num_classes)
35
-
36
- if config.model_vision_checkpoint is not None:
37
- logging.info(f'Read vision model from {config.model_vision_checkpoint}.')
38
- self.load(config.model_vision_checkpoint)
39
-
40
- def forward(self, images, *args):
41
- features = self.backbone(images) # (N, E, H, W)
42
- attn_vecs, attn_scores = self.attention(features) # (N, T, E), (N, T, H, W)
43
- logits = self.cls(attn_vecs) # (N, T, C)
44
- pt_lengths = self._get_length(logits)
45
-
46
- return {'feature': attn_vecs, 'logits': logits, 'pt_lengths': pt_lengths,
47
- 'attn_scores': attn_scores, 'loss_weight':self.loss_weight, 'name': 'vision'}
 
modules/resnet.py DELETED
@@ -1,104 +0,0 @@
1
- import math
2
-
3
- import torch.nn as nn
4
- import torch.nn.functional as F
5
- import torch.utils.model_zoo as model_zoo
6
-
7
-
8
- def conv1x1(in_planes, out_planes, stride=1):
9
- return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
10
-
11
-
12
- def conv3x3(in_planes, out_planes, stride=1):
13
- "3x3 convolution with padding"
14
- return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
15
- padding=1, bias=False)
16
-
17
-
18
- class BasicBlock(nn.Module):
19
- expansion = 1
20
-
21
- def __init__(self, inplanes, planes, stride=1, downsample=None):
22
- super(BasicBlock, self).__init__()
23
- self.conv1 = conv1x1(inplanes, planes)
24
- self.bn1 = nn.BatchNorm2d(planes)
25
- self.relu = nn.ReLU(inplace=True)
26
- self.conv2 = conv3x3(planes, planes, stride)
27
- self.bn2 = nn.BatchNorm2d(planes)
28
- self.downsample = downsample
29
- self.stride = stride
30
-
31
- def forward(self, x):
32
- residual = x
33
-
34
- out = self.conv1(x)
35
- out = self.bn1(out)
36
- out = self.relu(out)
37
-
38
- out = self.conv2(out)
39
- out = self.bn2(out)
40
-
41
- if self.downsample is not None:
42
- residual = self.downsample(x)
43
-
44
- out += residual
45
- out = self.relu(out)
46
-
47
- return out
48
-
49
-
50
- class ResNet(nn.Module):
51
-
52
- def __init__(self, block, layers):
53
- self.inplanes = 32
54
- super(ResNet, self).__init__()
55
- self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1,
56
- bias=False)
57
- self.bn1 = nn.BatchNorm2d(32)
58
- self.relu = nn.ReLU(inplace=True)
59
-
60
- self.layer1 = self._make_layer(block, 32, layers[0], stride=2)
61
- self.layer2 = self._make_layer(block, 64, layers[1], stride=1)
62
- self.layer3 = self._make_layer(block, 128, layers[2], stride=2)
63
- self.layer4 = self._make_layer(block, 256, layers[3], stride=1)
64
- self.layer5 = self._make_layer(block, 512, layers[4], stride=1)
65
-
66
- for m in self.modules():
67
- if isinstance(m, nn.Conv2d):
68
- n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
69
- m.weight.data.normal_(0, math.sqrt(2. / n))
70
- elif isinstance(m, nn.BatchNorm2d):
71
- m.weight.data.fill_(1)
72
- m.bias.data.zero_()
73
-
74
- def _make_layer(self, block, planes, blocks, stride=1):
75
- downsample = None
76
- if stride != 1 or self.inplanes != planes * block.expansion:
77
- downsample = nn.Sequential(
78
- nn.Conv2d(self.inplanes, planes * block.expansion,
79
- kernel_size=1, stride=stride, bias=False),
80
- nn.BatchNorm2d(planes * block.expansion),
81
- )
82
-
83
- layers = []
84
- layers.append(block(self.inplanes, planes, stride, downsample))
85
- self.inplanes = planes * block.expansion
86
- for i in range(1, blocks):
87
- layers.append(block(self.inplanes, planes))
88
-
89
- return nn.Sequential(*layers)
90
-
91
- def forward(self, x):
92
- x = self.conv1(x)
93
- x = self.bn1(x)
94
- x = self.relu(x)
95
- x = self.layer1(x)
96
- x = self.layer2(x)
97
- x = self.layer3(x)
98
- x = self.layer4(x)
99
- x = self.layer5(x)
100
- return x
101
-
102
-
103
- def resnet45():
104
- return ResNet(BasicBlock, [3, 4, 6, 6, 3])
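
A quick shape check for the deleted resnet45 backbone; a sketch assuming the pre-deletion modules/ package is importable (the 32x128 input is a typical recognition crop, not something fixed by this file):

import torch
from modules.resnet import resnet45  # pre-deletion module path

net = resnet45()
x = torch.randn(1, 3, 32, 128)        # (N, 3, H, W)
y = net(x)
print(y.shape)                         # torch.Size([1, 512, 8, 32]), matching the 8*32 positional encoding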
 
modules/transformer.py DELETED
@@ -1,901 +0,0 @@
1
- # pytorch 1.5.0
2
- import copy
3
- import math
4
- import warnings
5
- from typing import Optional
6
-
7
- import torch
8
- import torch.nn as nn
9
- from torch import Tensor
10
- from torch.nn import Dropout, LayerNorm, Linear, Module, ModuleList, Parameter
11
- from torch.nn import functional as F
12
- from torch.nn.init import constant_, xavier_uniform_
13
-
14
-
15
- def multi_head_attention_forward(query, # type: Tensor
16
- key, # type: Tensor
17
- value, # type: Tensor
18
- embed_dim_to_check, # type: int
19
- num_heads, # type: int
20
- in_proj_weight, # type: Tensor
21
- in_proj_bias, # type: Tensor
22
- bias_k, # type: Optional[Tensor]
23
- bias_v, # type: Optional[Tensor]
24
- add_zero_attn, # type: bool
25
- dropout_p, # type: float
26
- out_proj_weight, # type: Tensor
27
- out_proj_bias, # type: Tensor
28
- training=True, # type: bool
29
- key_padding_mask=None, # type: Optional[Tensor]
30
- need_weights=True, # type: bool
31
- attn_mask=None, # type: Optional[Tensor]
32
- use_separate_proj_weight=False, # type: bool
33
- q_proj_weight=None, # type: Optional[Tensor]
34
- k_proj_weight=None, # type: Optional[Tensor]
35
- v_proj_weight=None, # type: Optional[Tensor]
36
- static_k=None, # type: Optional[Tensor]
37
- static_v=None # type: Optional[Tensor]
38
- ):
39
- # type: (...) -> Tuple[Tensor, Optional[Tensor]]
40
- r"""
41
- Args:
42
- query, key, value: map a query and a set of key-value pairs to an output.
43
- See "Attention Is All You Need" for more details.
44
- embed_dim_to_check: total dimension of the model.
45
- num_heads: parallel attention heads.
46
- in_proj_weight, in_proj_bias: input projection weight and bias.
47
- bias_k, bias_v: bias of the key and value sequences to be added at dim=0.
48
- add_zero_attn: add a new batch of zeros to the key and
49
- value sequences at dim=1.
50
- dropout_p: probability of an element to be zeroed.
51
- out_proj_weight, out_proj_bias: the output projection weight and bias.
52
- training: apply dropout if is ``True``.
53
- key_padding_mask: if provided, specified padding elements in the key will
54
- be ignored by the attention. This is an binary mask. When the value is True,
55
- the corresponding value on the attention layer will be filled with -inf.
56
- need_weights: output attn_output_weights.
57
- attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
58
- the batches while a 3D mask allows to specify a different mask for the entries of each batch.
59
- use_separate_proj_weight: the function accept the proj. weights for query, key,
60
- and value in different forms. If false, in_proj_weight will be used, which is
61
- a combination of q_proj_weight, k_proj_weight, v_proj_weight.
62
- q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias.
63
- static_k, static_v: static key and value used for attention operators.
64
- Shape:
65
- Inputs:
66
- - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
67
- the embedding dimension.
68
- - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
69
- the embedding dimension.
70
- - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
71
- the embedding dimension.
72
- - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
73
- If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
74
- will be unchanged. If a BoolTensor is provided, the positions with the
75
- value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
76
- - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
77
- 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
78
- S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
79
- positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
80
- while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
81
- are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
82
- is provided, it will be added to the attention weight.
83
- - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
84
- N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
85
- - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length,
86
- N is the batch size, E is the embedding dimension. E/num_heads is the head dimension.
87
- Outputs:
88
- - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
89
- E is the embedding dimension.
90
- - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
91
- L is the target sequence length, S is the source sequence length.
92
- """
93
- # if not torch.jit.is_scripting():
94
- # tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v,
95
- # out_proj_weight, out_proj_bias)
96
- # if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
97
- # return handle_torch_function(
98
- # multi_head_attention_forward, tens_ops, query, key, value,
99
- # embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias,
100
- # bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight,
101
- # out_proj_bias, training=training, key_padding_mask=key_padding_mask,
102
- # need_weights=need_weights, attn_mask=attn_mask,
103
- # use_separate_proj_weight=use_separate_proj_weight,
104
- # q_proj_weight=q_proj_weight, k_proj_weight=k_proj_weight,
105
- # v_proj_weight=v_proj_weight, static_k=static_k, static_v=static_v)
106
- tgt_len, bsz, embed_dim = query.size()
107
- assert embed_dim == embed_dim_to_check
108
- assert key.size() == value.size()
109
-
110
- head_dim = embed_dim // num_heads
111
- assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
112
- scaling = float(head_dim) ** -0.5
113
-
114
- if not use_separate_proj_weight:
115
- if torch.equal(query, key) and torch.equal(key, value):
116
- # self-attention
117
- q, k, v = F.linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1)
118
-
119
- elif torch.equal(key, value):
120
- # encoder-decoder attention
121
- # This is inline in_proj function with in_proj_weight and in_proj_bias
122
- _b = in_proj_bias
123
- _start = 0
124
- _end = embed_dim
125
- _w = in_proj_weight[_start:_end, :]
126
- if _b is not None:
127
- _b = _b[_start:_end]
128
- q = F.linear(query, _w, _b)
129
-
130
- if key is None:
131
- assert value is None
132
- k = None
133
- v = None
134
- else:
135
-
136
- # This is inline in_proj function with in_proj_weight and in_proj_bias
137
- _b = in_proj_bias
138
- _start = embed_dim
139
- _end = None
140
- _w = in_proj_weight[_start:, :]
141
- if _b is not None:
142
- _b = _b[_start:]
143
- k, v = F.linear(key, _w, _b).chunk(2, dim=-1)
144
-
145
- else:
146
- # This is inline in_proj function with in_proj_weight and in_proj_bias
147
- _b = in_proj_bias
148
- _start = 0
149
- _end = embed_dim
150
- _w = in_proj_weight[_start:_end, :]
151
- if _b is not None:
152
- _b = _b[_start:_end]
153
- q = F.linear(query, _w, _b)
154
-
155
- # This is inline in_proj function with in_proj_weight and in_proj_bias
156
- _b = in_proj_bias
157
- _start = embed_dim
158
- _end = embed_dim * 2
159
- _w = in_proj_weight[_start:_end, :]
160
- if _b is not None:
161
- _b = _b[_start:_end]
162
- k = F.linear(key, _w, _b)
163
-
164
- # This is inline in_proj function with in_proj_weight and in_proj_bias
165
- _b = in_proj_bias
166
- _start = embed_dim * 2
167
- _end = None
168
- _w = in_proj_weight[_start:, :]
169
- if _b is not None:
170
- _b = _b[_start:]
171
- v = F.linear(value, _w, _b)
172
- else:
173
- q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight)
174
- len1, len2 = q_proj_weight_non_opt.size()
175
- assert len1 == embed_dim and len2 == query.size(-1)
176
-
177
- k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight)
178
- len1, len2 = k_proj_weight_non_opt.size()
179
- assert len1 == embed_dim and len2 == key.size(-1)
180
-
181
- v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight)
182
- len1, len2 = v_proj_weight_non_opt.size()
183
- assert len1 == embed_dim and len2 == value.size(-1)
184
-
185
- if in_proj_bias is not None:
186
- q = F.linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim])
187
- k = F.linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim:(embed_dim * 2)])
188
- v = F.linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2):])
189
- else:
190
- q = F.linear(query, q_proj_weight_non_opt, in_proj_bias)
191
- k = F.linear(key, k_proj_weight_non_opt, in_proj_bias)
192
- v = F.linear(value, v_proj_weight_non_opt, in_proj_bias)
193
- q = q * scaling
194
-
195
- if attn_mask is not None:
196
- assert attn_mask.dtype == torch.float32 or attn_mask.dtype == torch.float64 or \
197
- attn_mask.dtype == torch.float16 or attn_mask.dtype == torch.uint8 or attn_mask.dtype == torch.bool, \
198
- 'Only float, byte, and bool types are supported for attn_mask, not {}'.format(attn_mask.dtype)
199
- if attn_mask.dtype == torch.uint8:
200
- warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
201
- attn_mask = attn_mask.to(torch.bool)
202
-
203
- if attn_mask.dim() == 2:
204
- attn_mask = attn_mask.unsqueeze(0)
205
- if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
206
- raise RuntimeError('The size of the 2D attn_mask is not correct.')
207
- elif attn_mask.dim() == 3:
208
- if list(attn_mask.size()) != [bsz * num_heads, query.size(0), key.size(0)]:
209
- raise RuntimeError('The size of the 3D attn_mask is not correct.')
210
- else:
211
- raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim()))
212
- # attn_mask's dim is 3 now.
213
-
214
- # # convert ByteTensor key_padding_mask to bool
215
- # if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
216
- # warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
217
- # key_padding_mask = key_padding_mask.to(torch.bool)
218
-
219
- if bias_k is not None and bias_v is not None:
220
- if static_k is None and static_v is None:
221
- k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
222
- v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
223
- if attn_mask is not None:
224
- attn_mask = pad(attn_mask, (0, 1))
225
- if key_padding_mask is not None:
226
- key_padding_mask = pad(key_padding_mask, (0, 1))
227
- else:
228
- assert static_k is None, "bias cannot be added to static key."
229
- assert static_v is None, "bias cannot be added to static value."
230
- else:
231
- assert bias_k is None
232
- assert bias_v is None
233
-
234
- q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
235
- if k is not None:
236
- k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
237
- if v is not None:
238
- v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
239
-
240
- if static_k is not None:
241
- assert static_k.size(0) == bsz * num_heads
242
- assert static_k.size(2) == head_dim
243
- k = static_k
244
-
245
- if static_v is not None:
246
- assert static_v.size(0) == bsz * num_heads
247
- assert static_v.size(2) == head_dim
248
- v = static_v
249
-
250
- src_len = k.size(1)
251
-
252
- if key_padding_mask is not None:
253
- assert key_padding_mask.size(0) == bsz
254
- assert key_padding_mask.size(1) == src_len
255
-
256
- if add_zero_attn:
257
- src_len += 1
258
- k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1)
259
- v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1)
260
- if attn_mask is not None:
261
- attn_mask = pad(attn_mask, (0, 1))
262
- if key_padding_mask is not None:
263
- key_padding_mask = pad(key_padding_mask, (0, 1))
264
-
265
- attn_output_weights = torch.bmm(q, k.transpose(1, 2))
266
- assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len]
267
-
268
- if attn_mask is not None:
269
- if attn_mask.dtype == torch.bool:
270
- attn_output_weights.masked_fill_(attn_mask, float('-inf'))
271
- else:
272
- attn_output_weights += attn_mask
273
-
274
-
275
- if key_padding_mask is not None:
276
- attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
277
- attn_output_weights = attn_output_weights.masked_fill(
278
- key_padding_mask.unsqueeze(1).unsqueeze(2),
279
- float('-inf'),
280
- )
281
- attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len)
282
-
283
- attn_output_weights = F.softmax(
284
- attn_output_weights, dim=-1)
285
- attn_output_weights = F.dropout(attn_output_weights, p=dropout_p, training=training)
286
-
287
- attn_output = torch.bmm(attn_output_weights, v)
288
- assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
289
- attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
290
- attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias)
291
-
292
- if need_weights:
293
- # average attention weights over heads
294
- attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len)
295
- return attn_output, attn_output_weights.sum(dim=1) / num_heads
296
- else:
297
- return attn_output, None
298
-
299
- class MultiheadAttention(Module):
300
- r"""Allows the model to jointly attend to information
301
- from different representation subspaces.
302
- See reference: Attention Is All You Need
303
- .. math::
304
- \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
305
- \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
306
- Args:
307
- embed_dim: total dimension of the model.
308
- num_heads: parallel attention heads.
309
- dropout: a Dropout layer on attn_output_weights. Default: 0.0.
310
- bias: add bias as module parameter. Default: True.
311
- add_bias_kv: add bias to the key and value sequences at dim=0.
312
- add_zero_attn: add a new batch of zeros to the key and
313
- value sequences at dim=1.
314
- kdim: total number of features in key. Default: None.
315
- vdim: total number of features in value. Default: None.
316
- Note: if kdim and vdim are None, they will be set to embed_dim such that
317
- query, key, and value have the same number of features.
318
- Examples::
319
- >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
320
- >>> attn_output, attn_output_weights = multihead_attn(query, key, value)
321
- """
322
- # __annotations__ = {
323
- # 'bias_k': torch._jit_internal.Optional[torch.Tensor],
324
- # 'bias_v': torch._jit_internal.Optional[torch.Tensor],
325
- # }
326
- __constants__ = ['q_proj_weight', 'k_proj_weight', 'v_proj_weight', 'in_proj_weight']
327
-
328
- def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None):
329
- super(MultiheadAttention, self).__init__()
330
- self.embed_dim = embed_dim
331
- self.kdim = kdim if kdim is not None else embed_dim
332
- self.vdim = vdim if vdim is not None else embed_dim
333
- self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
334
-
335
- self.num_heads = num_heads
336
- self.dropout = dropout
337
- self.head_dim = embed_dim // num_heads
338
- assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
339
-
340
- if self._qkv_same_embed_dim is False:
341
- self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim))
342
- self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim))
343
- self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim))
344
- self.register_parameter('in_proj_weight', None)
345
- else:
346
- self.in_proj_weight = Parameter(torch.empty(3 * embed_dim, embed_dim))
347
- self.register_parameter('q_proj_weight', None)
348
- self.register_parameter('k_proj_weight', None)
349
- self.register_parameter('v_proj_weight', None)
350
-
351
- if bias:
352
- self.in_proj_bias = Parameter(torch.empty(3 * embed_dim))
353
- else:
354
- self.register_parameter('in_proj_bias', None)
355
- self.out_proj = Linear(embed_dim, embed_dim, bias=bias)
356
-
357
- if add_bias_kv:
358
- self.bias_k = Parameter(torch.empty(1, 1, embed_dim))
359
- self.bias_v = Parameter(torch.empty(1, 1, embed_dim))
360
- else:
361
- self.bias_k = self.bias_v = None
362
-
363
- self.add_zero_attn = add_zero_attn
364
-
365
- self._reset_parameters()
366
-
367
- def _reset_parameters(self):
368
- if self._qkv_same_embed_dim:
369
- xavier_uniform_(self.in_proj_weight)
370
- else:
371
- xavier_uniform_(self.q_proj_weight)
372
- xavier_uniform_(self.k_proj_weight)
373
- xavier_uniform_(self.v_proj_weight)
374
-
375
- if self.in_proj_bias is not None:
376
- constant_(self.in_proj_bias, 0.)
377
- constant_(self.out_proj.bias, 0.)
378
- if self.bias_k is not None:
379
- xavier_normal_(self.bias_k)
380
- if self.bias_v is not None:
381
- xavier_normal_(self.bias_v)
382
-
383
- def __setstate__(self, state):
384
- # Support loading old MultiheadAttention checkpoints generated by v1.1.0
385
- if '_qkv_same_embed_dim' not in state:
386
- state['_qkv_same_embed_dim'] = True
387
-
388
- super(MultiheadAttention, self).__setstate__(state)
389
-
390
- def forward(self, query, key, value, key_padding_mask=None,
391
- need_weights=True, attn_mask=None):
392
- # type: (Tensor, Tensor, Tensor, Optional[Tensor], bool, Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]]
393
- r"""
394
- Args:
395
- query, key, value: map a query and a set of key-value pairs to an output.
396
- See "Attention Is All You Need" for more details.
397
- key_padding_mask: if provided, specified padding elements in the key will
398
- be ignored by the attention. This is an binary mask. When the value is True,
399
- the corresponding value on the attention layer will be filled with -inf.
400
- need_weights: output attn_output_weights.
401
- attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
402
- the batches while a 3D mask allows to specify a different mask for the entries of each batch.
403
- Shape:
404
- - Inputs:
405
- - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
406
- the embedding dimension.
407
- - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
408
- the embedding dimension.
409
- - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
410
- the embedding dimension.
411
- - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
412
- If a ByteTensor is provided, the non-zero positions will be ignored while the position
413
- with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
414
- value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
415
- - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
416
- 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
417
- S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked
418
- positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
419
- while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
420
- is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
421
- is provided, it will be added to the attention weight.
422
- - Outputs:
423
- - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
424
- E is the embedding dimension.
425
- - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
426
- L is the target sequence length, S is the source sequence length.
427
- """
428
- if not self._qkv_same_embed_dim:
429
- return multi_head_attention_forward(
430
- query, key, value, self.embed_dim, self.num_heads,
431
- self.in_proj_weight, self.in_proj_bias,
432
- self.bias_k, self.bias_v, self.add_zero_attn,
433
- self.dropout, self.out_proj.weight, self.out_proj.bias,
434
- training=self.training,
435
- key_padding_mask=key_padding_mask, need_weights=need_weights,
436
- attn_mask=attn_mask, use_separate_proj_weight=True,
437
- q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
438
- v_proj_weight=self.v_proj_weight)
439
- else:
440
- return multi_head_attention_forward(
441
- query, key, value, self.embed_dim, self.num_heads,
442
- self.in_proj_weight, self.in_proj_bias,
443
- self.bias_k, self.bias_v, self.add_zero_attn,
444
- self.dropout, self.out_proj.weight, self.out_proj.bias,
445
- training=self.training,
446
- key_padding_mask=key_padding_mask, need_weights=need_weights,
447
- attn_mask=attn_mask)
448
-
449
-
450
- class Transformer(Module):
451
- r"""A transformer model. User is able to modify the attributes as needed. The architecture
452
- is based on the paper "Attention Is All You Need". Ashish Vaswani, Noam Shazeer,
453
- Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, Lukasz Kaiser, and
454
- Illia Polosukhin. 2017. Attention is all you need. In Advances in Neural Information
455
- Processing Systems, pages 6000-6010. Users can build the BERT(https://arxiv.org/abs/1810.04805)
456
- model with corresponding parameters.
457
-
458
- Args:
459
- d_model: the number of expected features in the encoder/decoder inputs (default=512).
460
- nhead: the number of heads in the multiheadattention models (default=8).
461
- num_encoder_layers: the number of sub-encoder-layers in the encoder (default=6).
462
- num_decoder_layers: the number of sub-decoder-layers in the decoder (default=6).
463
- dim_feedforward: the dimension of the feedforward network model (default=2048).
464
- dropout: the dropout value (default=0.1).
465
- activation: the activation function of encoder/decoder intermediate layer, relu or gelu (default=relu).
466
- custom_encoder: custom encoder (default=None).
467
- custom_decoder: custom decoder (default=None).
468
-
469
- Examples::
470
- >>> transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)
471
- >>> src = torch.rand((10, 32, 512))
472
- >>> tgt = torch.rand((20, 32, 512))
473
- >>> out = transformer_model(src, tgt)
474
-
475
- Note: A full example to apply nn.Transformer module for the word language model is available in
476
- https://github.com/pytorch/examples/tree/master/word_language_model
477
- """
478
-
479
- def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
480
- num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
481
- activation="relu", custom_encoder=None, custom_decoder=None):
482
- super(Transformer, self).__init__()
483
-
484
- if custom_encoder is not None:
485
- self.encoder = custom_encoder
486
- else:
487
- encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, activation)
488
- encoder_norm = LayerNorm(d_model)
489
- self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
490
-
491
- if custom_decoder is not None:
492
- self.decoder = custom_decoder
493
- else:
494
- decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, dropout, activation)
495
- decoder_norm = LayerNorm(d_model)
496
- self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm)
497
-
498
- self._reset_parameters()
499
-
500
- self.d_model = d_model
501
- self.nhead = nhead
502
-
503
- def forward(self, src, tgt, src_mask=None, tgt_mask=None,
504
- memory_mask=None, src_key_padding_mask=None,
505
- tgt_key_padding_mask=None, memory_key_padding_mask=None):
506
- # type: (Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor]) -> Tensor # noqa
507
- r"""Take in and process masked source/target sequences.
508
-
509
- Args:
510
- src: the sequence to the encoder (required).
511
- tgt: the sequence to the decoder (required).
512
- src_mask: the additive mask for the src sequence (optional).
513
- tgt_mask: the additive mask for the tgt sequence (optional).
514
- memory_mask: the additive mask for the encoder output (optional).
515
- src_key_padding_mask: the ByteTensor mask for src keys per batch (optional).
516
- tgt_key_padding_mask: the ByteTensor mask for tgt keys per batch (optional).
517
- memory_key_padding_mask: the ByteTensor mask for memory keys per batch (optional).
518
-
519
- Shape:
520
- - src: :math:`(S, N, E)`.
521
- - tgt: :math:`(T, N, E)`.
522
- - src_mask: :math:`(S, S)`.
523
- - tgt_mask: :math:`(T, T)`.
524
- - memory_mask: :math:`(T, S)`.
525
- - src_key_padding_mask: :math:`(N, S)`.
526
- - tgt_key_padding_mask: :math:`(N, T)`.
527
- - memory_key_padding_mask: :math:`(N, S)`.
528
-
529
- Note: [src/tgt/memory]_mask ensures that position i is allowed to attend the unmasked
530
- positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
531
- while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
532
- are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
533
- is provided, it will be added to the attention weight.
534
- [src/tgt/memory]_key_padding_mask provides specified elements in the key to be ignored by
535
- the attention. If a ByteTensor is provided, the non-zero positions will be ignored while the zero
536
- positions will be unchanged. If a BoolTensor is provided, the positions with the
537
- value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
538
-
539
- - output: :math:`(T, N, E)`.
540
-
541
- Note: Due to the multi-head attention architecture in the transformer model,
542
- the output sequence length of a transformer is same as the input sequence
543
- (i.e. target) length of the decode.
544
-
545
- where S is the source sequence length, T is the target sequence length, N is the
546
- batch size, E is the feature number
547
-
548
- Examples:
549
- >>> output = transformer_model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask)
550
- """
551
-
552
- if src.size(1) != tgt.size(1):
553
- raise RuntimeError("the batch number of src and tgt must be equal")
554
-
555
- if src.size(2) != self.d_model or tgt.size(2) != self.d_model:
556
- raise RuntimeError("the feature number of src and tgt must be equal to d_model")
557
-
558
- memory = self.encoder(src, mask=src_mask, src_key_padding_mask=src_key_padding_mask)
559
- output = self.decoder(tgt, memory, tgt_mask=tgt_mask, memory_mask=memory_mask,
560
- tgt_key_padding_mask=tgt_key_padding_mask,
561
- memory_key_padding_mask=memory_key_padding_mask)
562
- return output
563
-
564
- def generate_square_subsequent_mask(self, sz):
565
- r"""Generate a square mask for the sequence. The masked positions are filled with float('-inf').
566
- Unmasked positions are filled with float(0.0).
567
- """
568
- mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
569
- mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
570
- return mask
571
-
572
- def _reset_parameters(self):
573
- r"""Initiate parameters in the transformer model."""
574
-
575
- for p in self.parameters():
576
- if p.dim() > 1:
577
- xavier_uniform_(p)
578
-
579
-
580
- class TransformerEncoder(Module):
581
- r"""TransformerEncoder is a stack of N encoder layers
582
-
583
- Args:
584
- encoder_layer: an instance of the TransformerEncoderLayer() class (required).
585
- num_layers: the number of sub-encoder-layers in the encoder (required).
586
- norm: the layer normalization component (optional).
587
-
588
- Examples::
589
- >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
590
- >>> transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
591
- >>> src = torch.rand(10, 32, 512)
592
- >>> out = transformer_encoder(src)
593
- """
594
- __constants__ = ['norm']
595
-
596
- def __init__(self, encoder_layer, num_layers, norm=None):
597
- super(TransformerEncoder, self).__init__()
598
- self.layers = _get_clones(encoder_layer, num_layers)
599
- self.num_layers = num_layers
600
- self.norm = norm
601
-
602
- def forward(self, src, mask=None, src_key_padding_mask=None):
603
- # type: (Tensor, Optional[Tensor], Optional[Tensor]) -> Tensor
604
- r"""Pass the input through the encoder layers in turn.
605
-
606
- Args:
607
- src: the sequence to the encoder (required).
608
- mask: the mask for the src sequence (optional).
609
- src_key_padding_mask: the mask for the src keys per batch (optional).
610
-
611
- Shape:
612
- see the docs in Transformer class.
613
- """
614
- output = src
615
-
616
- for i, mod in enumerate(self.layers):
617
- output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask)
618
-
619
- if self.norm is not None:
620
- output = self.norm(output)
621
-
622
- return output
623
-
624
-
625
- class TransformerDecoder(Module):
626
- r"""TransformerDecoder is a stack of N decoder layers
627
-
628
- Args:
629
- decoder_layer: an instance of the TransformerDecoderLayer() class (required).
630
- num_layers: the number of sub-decoder-layers in the decoder (required).
631
- norm: the layer normalization component (optional).
632
-
633
- Examples::
634
- >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
635
- >>> transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
636
- >>> memory = torch.rand(10, 32, 512)
637
- >>> tgt = torch.rand(20, 32, 512)
638
- >>> out = transformer_decoder(tgt, memory)
639
- """
640
- __constants__ = ['norm']
641
-
642
- def __init__(self, decoder_layer, num_layers, norm=None):
643
- super(TransformerDecoder, self).__init__()
644
- self.layers = _get_clones(decoder_layer, num_layers)
645
- self.num_layers = num_layers
646
- self.norm = norm
647
-
648
- def forward(self, tgt, memory, memory2=None, tgt_mask=None,
649
- memory_mask=None, memory_mask2=None, tgt_key_padding_mask=None,
650
- memory_key_padding_mask=None, memory_key_padding_mask2=None):
651
- # type: (Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor]) -> Tensor
652
- r"""Pass the inputs (and mask) through the decoder layer in turn.
653
-
654
- Args:
655
- tgt: the sequence to the decoder (required).
656
- memory: the sequence from the last layer of the encoder (required).
657
- tgt_mask: the mask for the tgt sequence (optional).
658
- memory_mask: the mask for the memory sequence (optional).
659
- tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
660
- memory_key_padding_mask: the mask for the memory keys per batch (optional).
661
-
662
- Shape:
663
- see the docs in Transformer class.
664
- """
665
- output = tgt
666
-
667
- for mod in self.layers:
668
- output = mod(output, memory, memory2=memory2, tgt_mask=tgt_mask,
669
- memory_mask=memory_mask, memory_mask2=memory_mask2,
670
- tgt_key_padding_mask=tgt_key_padding_mask,
671
- memory_key_padding_mask=memory_key_padding_mask,
672
- memory_key_padding_mask2=memory_key_padding_mask2)
673
-
674
- if self.norm is not None:
675
- output = self.norm(output)
676
-
677
- return output
678
-
679
- class TransformerEncoderLayer(Module):
680
- r"""TransformerEncoderLayer is made up of self-attn and feedforward network.
681
- This standard encoder layer is based on the paper "Attention Is All You Need".
682
- Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
683
- Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
684
- Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
685
- in a different way during application.
686
-
687
- Args:
688
- d_model: the number of expected features in the input (required).
689
- nhead: the number of heads in the multiheadattention models (required).
690
- dim_feedforward: the dimension of the feedforward network model (default=2048).
691
- dropout: the dropout value (default=0.1).
692
- activation: the activation function of intermediate layer, relu or gelu (default=relu).
693
-
694
- Examples::
695
- >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
696
- >>> src = torch.rand(10, 32, 512)
697
- >>> out = encoder_layer(src)
698
- """
699
-
700
- def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
701
- activation="relu", debug=False):
702
- super(TransformerEncoderLayer, self).__init__()
703
- self.debug = debug
704
- self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
705
- # Implementation of Feedforward model
706
- self.linear1 = Linear(d_model, dim_feedforward)
707
- self.dropout = Dropout(dropout)
708
- self.linear2 = Linear(dim_feedforward, d_model)
709
-
710
- self.norm1 = LayerNorm(d_model)
711
- self.norm2 = LayerNorm(d_model)
712
- self.dropout1 = Dropout(dropout)
713
- self.dropout2 = Dropout(dropout)
714
-
715
- self.activation = _get_activation_fn(activation)
716
-
717
- def __setstate__(self, state):
718
- if 'activation' not in state:
719
- state['activation'] = F.relu
720
- super(TransformerEncoderLayer, self).__setstate__(state)
721
-
722
- def forward(self, src, src_mask=None, src_key_padding_mask=None):
723
- # type: (Tensor, Optional[Tensor], Optional[Tensor]) -> Tensor
724
- r"""Pass the input through the encoder layer.
725
-
726
- Args:
727
- src: the sequence to the encoder layer (required).
728
- src_mask: the mask for the src sequence (optional).
729
- src_key_padding_mask: the mask for the src keys per batch (optional).
730
-
731
- Shape:
732
- see the docs in Transformer class.
733
- """
734
- src2, attn = self.self_attn(src, src, src, attn_mask=src_mask,
735
- key_padding_mask=src_key_padding_mask)
736
- if self.debug: self.attn = attn
737
- src = src + self.dropout1(src2)
738
- src = self.norm1(src)
739
- src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
740
- src = src + self.dropout2(src2)
741
- src = self.norm2(src)
742
-
743
- return src
744
-
745
-
746
- class TransformerDecoderLayer(Module):
747
- r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network.
748
- This standard decoder layer is based on the paper "Attention Is All You Need".
749
- Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
750
- Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
751
- Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
752
- in a different way during application.
753
-
754
- Args:
755
- d_model: the number of expected features in the input (required).
756
- nhead: the number of heads in the multiheadattention models (required).
757
- dim_feedforward: the dimension of the feedforward network model (default=2048).
758
- dropout: the dropout value (default=0.1).
759
- activation: the activation function of intermediate layer, relu or gelu (default=relu).
760
-
761
- Examples::
762
- >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
763
- >>> memory = torch.rand(10, 32, 512)
764
- >>> tgt = torch.rand(20, 32, 512)
765
- >>> out = decoder_layer(tgt, memory)
766
- """
767
-
768
- def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
769
- activation="relu", self_attn=True, siamese=False, debug=False):
770
- super(TransformerDecoderLayer, self).__init__()
771
- self.has_self_attn, self.siamese = self_attn, siamese
772
- self.debug = debug
773
- if self.has_self_attn:
774
- self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
775
- self.norm1 = LayerNorm(d_model)
776
- self.dropout1 = Dropout(dropout)
777
- self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
778
- # Implementation of Feedforward model
779
- self.linear1 = Linear(d_model, dim_feedforward)
780
- self.dropout = Dropout(dropout)
781
- self.linear2 = Linear(dim_feedforward, d_model)
782
-
783
- self.norm2 = LayerNorm(d_model)
784
- self.norm3 = LayerNorm(d_model)
785
- self.dropout2 = Dropout(dropout)
786
- self.dropout3 = Dropout(dropout)
787
- if self.siamese:
788
- self.multihead_attn2 = MultiheadAttention(d_model, nhead, dropout=dropout)
789
-
790
- self.activation = _get_activation_fn(activation)
791
-
792
- def __setstate__(self, state):
793
- if 'activation' not in state:
794
- state['activation'] = F.relu
795
- super(TransformerDecoderLayer, self).__setstate__(state)
796
-
797
- def forward(self, tgt, memory, tgt_mask=None, memory_mask=None,
798
- tgt_key_padding_mask=None, memory_key_padding_mask=None,
799
- memory2=None, memory_mask2=None, memory_key_padding_mask2=None):
800
- # type: (Tensor, Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor]) -> Tensor
801
- r"""Pass the inputs (and mask) through the decoder layer.
802
-
803
- Args:
804
- tgt: the sequence to the decoder layer (required).
805
- memory: the sequence from the last layer of the encoder (required).
806
- tgt_mask: the mask for the tgt sequence (optional).
807
- memory_mask: the mask for the memory sequence (optional).
808
- tgt_key_padding_mask: the mask for the tgt keys per batch (optional).
809
- memory_key_padding_mask: the mask for the memory keys per batch (optional).
810
-
811
- Shape:
812
- see the docs in Transformer class.
813
- """
814
- if self.has_self_attn:
815
- tgt2, attn = self.self_attn(tgt, tgt, tgt, attn_mask=tgt_mask,
816
- key_padding_mask=tgt_key_padding_mask)
817
- tgt = tgt + self.dropout1(tgt2)
818
- tgt = self.norm1(tgt)
819
- if self.debug: self.attn = attn
820
- tgt2, attn2 = self.multihead_attn(tgt, memory, memory, attn_mask=memory_mask,
821
- key_padding_mask=memory_key_padding_mask)
822
- if self.debug: self.attn2 = attn2
823
-
824
- if self.siamese:
825
- tgt3, attn3 = self.multihead_attn2(tgt, memory2, memory2, attn_mask=memory_mask2,
826
- key_padding_mask=memory_key_padding_mask2)
827
- tgt = tgt + self.dropout2(tgt3)
828
- if self.debug: self.attn3 = attn3
829
-
830
- tgt = tgt + self.dropout2(tgt2)
831
- tgt = self.norm2(tgt)
832
- tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
833
- tgt = tgt + self.dropout3(tgt2)
834
- tgt = self.norm3(tgt)
835
-
836
- return tgt
837
-
838
-
839
- def _get_clones(module, N):
840
- return ModuleList([copy.deepcopy(module) for i in range(N)])
841
-
842
-
843
- def _get_activation_fn(activation):
844
- if activation == "relu":
845
- return F.relu
846
- elif activation == "gelu":
847
- return F.gelu
848
-
849
- raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
850
-
851
-
852
- class PositionalEncoding(nn.Module):
853
- r"""Inject some information about the relative or absolute position of the tokens
854
- in the sequence. The positional encodings have the same dimension as
855
- the embeddings, so that the two can be summed. Here, we use sine and cosine
856
- functions of different frequencies.
857
- .. math::
858
- \text{PosEncoder}(pos, 2i) = sin(pos/10000^(2i/d_model))
859
- \text{PosEncoder}(pos, 2i+1) = cos(pos/10000^(2i/d_model))
860
- \text{where pos is the word position and i is the embed idx}
861
- Args:
862
- d_model: the embed dim (required).
863
- dropout: the dropout value (default=0.1).
864
- max_len: the max. length of the incoming sequence (default=5000).
865
- Examples:
866
- >>> pos_encoder = PositionalEncoding(d_model)
867
- """
868
-
869
- def __init__(self, d_model, dropout=0.1, max_len=5000):
870
- super(PositionalEncoding, self).__init__()
871
- self.dropout = nn.Dropout(p=dropout)
872
-
873
- pe = torch.zeros(max_len, d_model)
874
- position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
875
- div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
876
- pe[:, 0::2] = torch.sin(position * div_term)
877
- pe[:, 1::2] = torch.cos(position * div_term)
878
- pe = pe.unsqueeze(0).transpose(0, 1)
879
- self.register_buffer('pe', pe)
880
-
881
- def forward(self, x):
882
- r"""Inputs of forward function
883
- Args:
884
- x: the sequence fed to the positional encoder model (required).
885
- Shape:
886
- x: [sequence length, batch size, embed dim]
887
- output: [sequence length, batch size, embed dim]
888
- Examples:
889
- >>> output = pos_encoder(x)
890
- """
891
-
892
- x = x + self.pe[:x.size(0), :]
893
- return self.dropout(x)
894
-
895
-
896
- if __name__ == '__main__':
897
- transformer_model = Transformer(nhead=16, num_encoder_layers=12)
898
- src = torch.rand((10, 32, 512))
899
- tgt = torch.rand((20, 32, 512))
900
- out = transformer_model(src, tgt)
901
- print(out)
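
The deleted transformer module above appears to be a local copy of PyTorch's transformer layers, extended with a debug flag that stores the attention maps, an optional second ("siamese") cross-attention over a second memory, and a sinusoidal PositionalEncoding. As a quick sanity check on the sine/cosine table and the broadcast in PositionalEncoding.forward, here is a minimal, self-contained sketch (torch only; written for this note, not part of the commit):

import math
import torch

# Rebuild the sinusoidal table exactly as the deleted PositionalEncoding does.
d_model, max_len = 512, 5000
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)   # [max_len, 1]
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)   # even embedding dims
pe[:, 1::2] = torch.cos(position * div_term)   # odd embedding dims
pe = pe.unsqueeze(0).transpose(0, 1)           # [max_len, 1, d_model]

# forward() adds pe[:seq_len] to x of shape [seq_len, batch, d_model];
# the singleton dimension broadcasts over the batch.
x = torch.rand(10, 32, d_model)
out = x + pe[:x.size(0), :]
assert out.shape == (10, 32, d_model)

The same [seq_len, batch, d_model] layout is what the encoder and decoder layers above expect, which is why the __main__ demo feeds src of shape (10, 32, 512) and tgt of shape (20, 32, 512).
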
utils.py CHANGED
@@ -202,7 +202,7 @@ class Config(object):
202
  assert os.path.exists(config_path), '%s does not exists!' % config_path
203
  with open(config_path) as file:
204
  config_dict = yaml.load(file, Loader=yaml.FullLoader)
205
- with open('configs/template.yaml') as file:
206
  default_config_dict = yaml.load(file, Loader=yaml.FullLoader)
207
  __dict2attr(default_config_dict)
208
  __dict2attr(config_dict)
 
202
  assert os.path.exists(config_path), '%s does not exists!' % config_path
203
  with open(config_path) as file:
204
  config_dict = yaml.load(file, Loader=yaml.FullLoader)
205
+ with open('configs/rec/template.yaml') as file:
206
  default_config_dict = yaml.load(file, Loader=yaml.FullLoader)
207
  __dict2attr(default_config_dict)
208
  __dict2attr(config_dict)
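
With this change, Config resolves its defaults from configs/rec/template.yaml before overlaying the per-experiment config. A rough sketch of the effective merge (written for this note, not the repo's code; the real implementation pushes both dicts through __dict2attr, so a simple last-wins overlay is assumed here):

import yaml

def load_config_dict(config_path):
    # Defaults come from the template, now located under configs/rec/.
    with open('configs/rec/template.yaml') as f:
        defaults = yaml.load(f, Loader=yaml.FullLoader)
    # Experiment-specific values override the defaults on conflict.
    with open(config_path) as f:
        overrides = yaml.load(f, Loader=yaml.FullLoader)
    return {**defaults, **overrides}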