import torch
import torch.nn as nn
import json
def attention(Q, K, V):
    # scaled dot-product attention: softmax(Q @ K^T / sqrt(d_k)) @ V
    d = K.shape[-1]
    QK = Q @ K.transpose(-2, -1)            # (..., q_len, k_len)
    QK_d = QK / (d ** 0.5)                  # scale by sqrt(key dimension)
    weights = torch.softmax(QK_d, dim=-1)   # attention weights over keys
    outputs = weights @ V                   # weighted sum of values
    return outputs
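# Illustrative usage (not part of the original file): for a tensor x of shape
# (batch, seq_len, dim), attention(x, x, x) performs self-attention and
# returns a tensor of the same shape, e.g.
#   x = torch.randn(2, 5, 16)
#   attention(x, x, x).shape  # torch.Size([2, 5, 16])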
class Attention(torch.nn.Module):
    def __init__(self, emb_dim, n_heads):
        super(Attention, self).__init__()
        self.emb_dim = emb_dim
        self.n_heads = n_heads

    def forward(self, X):
        batch_size, seq_len, emb_dim = X.size()  # (batch_size, seq_len, emb_dim)
        n_heads = self.n_heads
        emb_dim_per_head = emb_dim // n_heads
        assert emb_dim == self.emb_dim
        assert emb_dim_per_head * n_heads == emb_dim

        # split the embedding dimension into heads:
        # (batch_size, seq_len, emb_dim) -> (batch_size, n_heads, seq_len, emb_dim_per_head)
        X = X.view(batch_size, seq_len, n_heads, emb_dim_per_head).transpose(1, 2)
        output = attention(X, X, X)      # (batch_size, n_heads, seq_len, emb_dim_per_head)
        output = output.transpose(1, 2)  # (batch_size, seq_len, n_heads, emb_dim_per_head)
        output = output.contiguous().view(batch_size, seq_len, emb_dim)  # (batch_size, seq_len, emb_dim)
        return output
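# Illustrative usage (not part of the original file): multi-head self-attention
# over a batch of feature sequences; emb_dim must be divisible by n_heads, e.g.
#   attn = Attention(emb_dim=64, n_heads=4)
#   attn(torch.randn(2, 10, 64)).shape  # torch.Size([2, 10, 64])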
class ClassifierAttention(nn.Module):
    def __init__(self, vocab_size, emb_dim, padding_idx, hidden_size, n_layers, attention_heads, hidden_layer_units, dropout):
        super(ClassifierAttention, self).__init__()
        # token embeddings
        self.embedding = nn.Embedding(
            num_embeddings = vocab_size,
            embedding_dim = emb_dim,
            padding_idx = padding_idx
        )
        # first recurrent block: embeddings -> hidden features
        self.rnn_1 = nn.LSTM(
            emb_dim,
            hidden_size,
            n_layers,
            bidirectional = False,
            batch_first = True,
        )
        # multi-head self-attention over the LSTM outputs
        self.attention = Attention(hidden_size, attention_heads)
        # second recurrent block: attended features -> summary hidden state
        self.rnn_2 = nn.LSTM(
            hidden_size,
            hidden_size,
            n_layers,
            bidirectional = False,
            batch_first = True,
        )
        self.dropout = nn.Dropout(dropout)
        # fully-connected classifier head: hidden_size -> hidden_layer_units -> 1
        hidden_layer_units = [hidden_size, *hidden_layer_units]
        self.hidden_layers = nn.ModuleList([])
        for in_unit, out_unit in zip(hidden_layer_units[:-1], hidden_layer_units[1:]):
            self.hidden_layers.append(nn.Linear(in_unit, out_unit))
            self.hidden_layers.append(nn.ReLU())
            self.hidden_layers.append(self.dropout)
        self.hidden_layers.append(nn.Linear(hidden_layer_units[-1], 1))
        self.sigmoid = nn.Sigmoid()
    def forward(self, x):
        # x: (batch_size, seq_len)
        out = self.embedding(x)  # (batch_size, seq_len, emb_dim)
        # rnn_1 returns out: (batch_size, seq_len, hidden_size) and
        # hidden_state / cell_state: (n_layers*n_directions, batch_size, hidden_size)
        out, (hidden_state, cell_state) = self.rnn_1(out)
        out = self.attention(out)  # (batch_size, seq_len, hidden_size)
        out = self.dropout(out)
        # rnn_2 outputs have the same shapes as rnn_1's
        out, (hidden_state, cell_state) = self.rnn_2(out)
        out = hidden_state[-1]  # final hidden state of the last layer: (batch_size, hidden_size)
        out = self.dropout(out)
        for layer in self.hidden_layers:
            out = layer(out)
        out = self.sigmoid(out)  # (batch_size, 1)
        out = out.squeeze(-1)    # (batch_size,)
        return out
def get_model(model_path, params_path):
    # params_path: JSON file holding the positional constructor arguments of ClassifierAttention
    with open(params_path, 'r') as f:
        params = json.load(f)
    model = ClassifierAttention(*params)
    # load the trained weights onto CPU so the model can be restored without a GPU
    model.load_state_dict(torch.load(model_path, map_location='cpu'))
    model.eval()
    return model
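# Minimal smoke test (illustrative only): the hyperparameters below are
# assumptions for demonstration, not the values of the saved checkpoint
# that get_model() restores from model_path / params_path.
if __name__ == "__main__":
    model = ClassifierAttention(
        vocab_size=10000,
        emb_dim=128,
        padding_idx=0,
        hidden_size=64,
        n_layers=1,
        attention_heads=4,
        hidden_layer_units=[32],
        dropout=0.3,
    )
    model.eval()
    dummy_batch = torch.randint(0, 10000, (2, 20))  # (batch_size, seq_len) of token ids
    with torch.no_grad():
        probs = model(dummy_batch)                  # (batch_size,) sigmoid probabilities
    print(probs.shape, probs)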