|
""" |
|
Deprecated network.py module. This file exists only to support backwards
compatibility with old pickle files. See lib/__init__.py for more information.
|
""" |
|
|
|
from __future__ import print_function |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torch.autograd import Variable |
|
from torch.nn.parameter import Parameter

|
def choose(matrix, idxs): |
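    """
    Row-wise indexing: returns a 1-D tensor whose b-th entry is matrix[b, idxs[b]].

    e.g. choose(torch.Tensor([[1, 2], [3, 4]]), torch.LongTensor([0, 1]))
    returns [1., 4.]
    """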
|
if isinstance(idxs, Variable): |
|
idxs = idxs.data |
|
    assert matrix.ndimension() == 2
|
unrolled_idxs = idxs + \ |
|
torch.arange(0, matrix.size(0)).type_as(idxs) * matrix.size(1) |
|
return matrix.view(matrix.nelement())[unrolled_idxs] |
|
|
|
|
|
class Network(nn.Module): |
|
""" |
|
Todo: |
|
- Beam search |
|
- check if this is right? attend during P->FC rather than during softmax->P? |
|
- allow length 0 inputs/targets |
|
- give n_examples as input to FC |
|
- Initialise new weights randomly, rather than as zeroes |
|
""" |
|
|
|
def __init__( |
|
self, |
|
input_vocabulary, |
|
target_vocabulary, |
|
hidden_size=512, |
|
embedding_size=128, |
|
cell_type="LSTM"): |
|
""" |
|
:param list input_vocabulary: list of possible inputs |
|
:param list target_vocabulary: list of possible targets |
|
""" |
|
super(Network, self).__init__() |
|
self.h_input_encoder_size = hidden_size |
|
self.h_output_encoder_size = hidden_size |
|
self.h_decoder_size = hidden_size |
|
self.embedding_size = embedding_size |
|
self.input_vocabulary = input_vocabulary |
|
self.target_vocabulary = target_vocabulary |
|
|
|
self.v_input = len(input_vocabulary) |
|
|
|
self.v_target = len(target_vocabulary) |
|
|
|
self.cell_type = cell_type |
|
if cell_type == 'GRU': |
|
self.input_encoder_cell = nn.GRUCell( |
|
input_size=self.v_input + 1, |
|
hidden_size=self.h_input_encoder_size, |
|
bias=True) |
|
self.input_encoder_init = Parameter( |
|
torch.rand(1, self.h_input_encoder_size)) |
|
            self.output_encoder_cell = nn.GRUCell(
                input_size=self.v_input + 1 + self.h_input_encoder_size,
                hidden_size=self.h_output_encoder_size,
                bias=True)
|
self.decoder_cell = nn.GRUCell( |
|
input_size=self.v_target + 1, |
|
hidden_size=self.h_decoder_size, |
|
bias=True) |
|
if cell_type == 'LSTM': |
|
self.input_encoder_cell = nn.LSTMCell( |
|
input_size=self.v_input + 1, |
|
hidden_size=self.h_input_encoder_size, |
|
bias=True) |
|
            self.input_encoder_init = nn.ParameterList([
                Parameter(torch.rand(1, self.h_input_encoder_size)),
                Parameter(torch.rand(1, self.h_input_encoder_size))])
            self.output_encoder_cell = nn.LSTMCell(
                input_size=self.v_input + 1 + self.h_input_encoder_size,
                hidden_size=self.h_output_encoder_size,
                bias=True)
|
self.output_encoder_init_c = Parameter( |
|
torch.rand(1, self.h_output_encoder_size)) |
|
self.decoder_cell = nn.LSTMCell( |
|
input_size=self.v_target + 1, |
|
hidden_size=self.h_decoder_size, |
|
bias=True) |
|
self.decoder_init_c = Parameter(torch.rand(1, self.h_decoder_size)) |
|
|
|
self.W = nn.Linear( |
|
self.h_output_encoder_size + |
|
self.h_decoder_size, |
|
self.embedding_size) |
|
self.V = nn.Linear(self.embedding_size, self.v_target + 1) |
|
self.input_A = nn.Bilinear( |
|
self.h_input_encoder_size, |
|
self.h_output_encoder_size, |
|
1, |
|
bias=False) |
|
self.output_A = nn.Bilinear( |
|
self.h_output_encoder_size, |
|
self.h_decoder_size, |
|
1, |
|
bias=False) |
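        # EOS markers: trainable one-hot rows at the extra vocabulary slot
        # (index v). Note that outputs are encoded over the input vocabulary.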
|
self.input_EOS = torch.zeros(1, self.v_input + 1) |
|
self.input_EOS[:, -1] = 1 |
|
self.input_EOS = Parameter(self.input_EOS) |
|
self.output_EOS = torch.zeros(1, self.v_input + 1) |
|
self.output_EOS[:, -1] = 1 |
|
self.output_EOS = Parameter(self.output_EOS) |
|
self.target_EOS = torch.zeros(1, self.v_target + 1) |
|
self.target_EOS[:, -1] = 1 |
|
self.target_EOS = Parameter(self.target_EOS) |
|
|
|
    def __getstate__(self):
        # Replace the optimiser with its state dict when pickling.
        if hasattr(self, 'opt'):
            return dict([(k, v) for k, v in self.__dict__.items()
                         if k != 'opt'] + [('optstate', self.opt.state_dict())])
        else:
            return self.__dict__
|
|
|
def __setstate__(self, state): |
|
self.__dict__.update(state) |
|
|
|
if isinstance(self.input_encoder_init, tuple): |
|
self.input_encoder_init = nn.ParameterList( |
|
list(self.input_encoder_init)) |
|
|
|
def clear_optimiser(self): |
|
if hasattr(self, 'opt'): |
|
del self.opt |
|
if hasattr(self, 'optstate'): |
|
del self.optstate |
|
|
|
def get_optimiser(self): |
|
self.opt = torch.optim.Adam(self.parameters(), lr=0.001) |
|
if hasattr(self, 'optstate'): |
|
self.opt.load_state_dict(self.optstate) |
|
|
|
def optimiser_step(self, inputs, outputs, target): |
|
if not hasattr(self, 'opt'): |
|
self.get_optimiser() |
|
score = self.score(inputs, outputs, target, autograd=True).mean() |
|
(-score).backward() |
|
self.opt.step() |
|
self.opt.zero_grad() |
|
return score.data[0] |
|
|
|
def set_target_vocabulary(self, target_vocabulary): |
|
if target_vocabulary == self.target_vocabulary: |
|
return |
|
|
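        # Rebuild the output layer and the decoder's input weights for the new
        # vocabulary: rows/columns for symbols already in the old vocabulary
        # are copied over; unseen symbols get zero weights and a large negative
        # bias, so they start out unlikely.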
|
V_weight = [] |
|
V_bias = [] |
|
decoder_ih = [] |
|
|
|
for i in range(len(target_vocabulary)): |
|
if target_vocabulary[i] in self.target_vocabulary: |
|
j = self.target_vocabulary.index(target_vocabulary[i]) |
|
V_weight.append(self.V.weight.data[j:j + 1]) |
|
V_bias.append(self.V.bias.data[j:j + 1]) |
|
decoder_ih.append(self.decoder_cell.weight_ih.data[:, j:j + 1]) |
|
else: |
|
V_weight.append(torch.zeros(1, self.V.weight.size(1))) |
|
V_bias.append(torch.ones(1) * -10) |
|
decoder_ih.append( |
|
torch.zeros( |
|
self.decoder_cell.weight_ih.data.size(0), 1)) |
|
|
|
V_weight.append(self.V.weight.data[-1:]) |
|
V_bias.append(self.V.bias.data[-1:]) |
|
decoder_ih.append(self.decoder_cell.weight_ih.data[:, -1:]) |
|
|
|
self.target_vocabulary = target_vocabulary |
|
self.v_target = len(target_vocabulary) |
|
self.target_EOS.data = torch.zeros(1, self.v_target + 1) |
|
self.target_EOS.data[:, -1] = 1 |
|
|
|
self.V.weight.data = torch.cat(V_weight, dim=0) |
|
self.V.bias.data = torch.cat(V_bias, dim=0) |
|
self.V.out_features = self.V.bias.data.size(0) |
|
|
|
self.decoder_cell.weight_ih.data = torch.cat(decoder_ih, dim=1) |
|
self.decoder_cell.input_size = self.decoder_cell.weight_ih.data.size(1) |
|
|
|
self.clear_optimiser() |
|
|
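    # Cell-state helpers: a GRU state is a single hidden tensor, while an
    # LSTM state is an (h, c) tuple.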
|
def input_encoder_get_init(self, batch_size): |
|
if self.cell_type == "GRU": |
|
return self.input_encoder_init.repeat(batch_size, 1) |
|
if self.cell_type == "LSTM": |
|
return tuple(x.repeat(batch_size, 1) |
|
for x in self.input_encoder_init) |
|
|
|
def output_encoder_get_init(self, input_encoder_h): |
|
if self.cell_type == "GRU": |
|
return input_encoder_h |
|
if self.cell_type == "LSTM": |
|
return ( |
|
input_encoder_h, |
|
self.output_encoder_init_c.repeat( |
|
input_encoder_h.size(0), |
|
1)) |
|
|
|
def decoder_get_init(self, output_encoder_h): |
|
if self.cell_type == "GRU": |
|
return output_encoder_h |
|
if self.cell_type == "LSTM": |
|
return ( |
|
output_encoder_h, |
|
self.decoder_init_c.repeat( |
|
output_encoder_h.size(0), |
|
1)) |
|
|
|
def cell_get_h(self, cell_state): |
|
if self.cell_type == "GRU": |
|
return cell_state |
|
if self.cell_type == "LSTM": |
|
return cell_state[0] |
|
|
|
def score(self, inputs, outputs, target, autograd=False): |
|
inputs = self.inputsToTensors(inputs) |
|
outputs = self.inputsToTensors(outputs) |
|
target = self.targetToTensor(target) |
|
target, score = self.run(inputs, outputs, target=target, mode="score") |
|
|
|
if autograd: |
|
return score |
|
else: |
|
return score.data |
|
|
|
def sample(self, inputs, outputs): |
|
inputs = self.inputsToTensors(inputs) |
|
outputs = self.inputsToTensors(outputs) |
|
target, score = self.run(inputs, outputs, mode="sample") |
|
target = self.tensorToOutput(target) |
|
return target |
|
|
|
def sampleAndScore(self, inputs, outputs, nRepeats=None): |
|
inputs = self.inputsToTensors(inputs) |
|
outputs = self.inputsToTensors(outputs) |
|
if nRepeats is None: |
|
target, score = self.run(inputs, outputs, mode="sample") |
|
target = self.tensorToOutput(target) |
|
return target, score.data |
|
else: |
|
target = [] |
|
score = [] |
|
for i in range(nRepeats): |
|
|
|
t, s = self.run(inputs, outputs, mode="sample") |
|
t = self.tensorToOutput(t) |
|
target.extend(t) |
|
score.extend(list(s.data)) |
|
return target, score |
|
|
|
def run(self, inputs, outputs, target=None, mode="sample"): |
|
""" |
|
:param mode: "score" returns log p(target|input), "sample" returns target ~ p(-|input) |
|
:param List[LongTensor] inputs: n_examples * (max_length_input * batch_size) |
|
:param List[LongTensor] target: max_length_target * batch_size |
|
""" |
|
        assert (mode == "score" and target is not None) or mode == "sample"
|
|
|
n_examples = len(inputs) |
|
max_length_input = [inputs[j].size(0) for j in range(n_examples)] |
|
max_length_output = [outputs[j].size(0) for j in range(n_examples)] |
|
max_length_target = target.size(0) if target is not None else 10 |
|
batch_size = inputs[0].size(1) |
|
|
|
score = Variable(torch.zeros(batch_size)) |
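        # One-hot encode every sequence by scattering 1s at the symbol indices
        # along the vocabulary dimension.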
|
        inputs_scatter = [Variable(torch.zeros(
            max_length_input[j], batch_size, self.v_input + 1
        ).scatter_(2, inputs[j][:, :, None], 1)) for j in range(n_examples)]
        outputs_scatter = [Variable(torch.zeros(
            max_length_output[j], batch_size, self.v_input + 1
        ).scatter_(2, outputs[j][:, :, None], 1)) for j in range(n_examples)]
|
        if target is not None:
            target_scatter = Variable(torch.zeros(
                max_length_target, batch_size, self.v_target + 1
            ).scatter_(2, target[:, :, None], 1))
|
|
|
|
|
|
|
|
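        # Encode each input example. `active` tracks which batch elements have
        # not yet emitted EOS; each sequence's embedding is its last active
        # hidden state, and the log of `active` serves as an attention mask.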
|
input_H = [] |
|
input_embeddings = [] |
|
|
|
input_attention_mask = [] |
|
for j in range(n_examples): |
|
active = torch.Tensor(max_length_input[j], batch_size).byte() |
|
active[0, :] = 1 |
|
state = self.input_encoder_get_init(batch_size) |
|
hs = [] |
|
for i in range(max_length_input[j]): |
|
state = self.input_encoder_cell( |
|
inputs_scatter[j][i, :, :], state) |
|
if i + 1 < max_length_input[j]: |
|
active[i + 1, :] = active[i, :] * \ |
|
(inputs[j][i, :] != self.v_input) |
|
h = self.cell_get_h(state) |
|
hs.append(h[None, :, :]) |
|
input_H.append(torch.cat(hs, 0)) |
|
embedding_idx = active.sum(0).long() - 1 |
|
embedding = input_H[j].gather(0, Variable( |
|
embedding_idx[None, :, None].repeat(1, 1, self.h_input_encoder_size)))[0] |
|
input_embeddings.append(embedding) |
|
input_attention_mask.append(Variable(active.float().log())) |
|
|
|
|
|
|
|
def input_attend(j, h_out): |
|
""" |
|
'general' attention from https://arxiv.org/pdf/1508.04025.pdf |
|
:param j: Index of example |
|
:param h_out: batch_size * h_output_encoder_size |
|
""" |
|
            scores = self.input_A(
                input_H[j].view(max_length_input[j] * batch_size,
                                self.h_input_encoder_size),
                h_out.view(batch_size, self.h_output_encoder_size)
                     .repeat(max_length_input[j], 1)
            ).view(max_length_input[j], batch_size) + input_attention_mask[j]
|
c = (F.softmax(scores[:, :, None], dim=0) * input_H[j]).sum(0) |
|
return c |
|
|
|
|
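        # Encode each output example, attending over the corresponding input
        # encoding at every step.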
|
output_H = [] |
|
output_embeddings = [] |
|
|
|
output_attention_mask = [] |
|
for j in range(n_examples): |
|
active = torch.Tensor(max_length_output[j], batch_size).byte() |
|
active[0, :] = 1 |
|
state = self.output_encoder_get_init(input_embeddings[j]) |
|
hs = [] |
|
h = self.cell_get_h(state) |
|
for i in range(max_length_output[j]): |
|
state = self.output_encoder_cell(torch.cat( |
|
[outputs_scatter[j][i, :, :], input_attend(j, h)], 1), state) |
|
if i + 1 < max_length_output[j]: |
|
active[i + 1, :] = active[i, :] * \ |
|
(outputs[j][i, :] != self.v_input) |
|
h = self.cell_get_h(state) |
|
hs.append(h[None, :, :]) |
|
output_H.append(torch.cat(hs, 0)) |
|
embedding_idx = active.sum(0).long() - 1 |
|
embedding = output_H[j].gather(0, Variable( |
|
embedding_idx[None, :, None].repeat(1, 1, self.h_output_encoder_size)))[0] |
|
output_embeddings.append(embedding) |
|
output_attention_mask.append(Variable(active.float().log())) |
|
|
|
|
|
|
|
def output_attend(j, h_dec): |
|
""" |
|
'general' attention from https://arxiv.org/pdf/1508.04025.pdf |
|
:param j: Index of example |
|
:param h_dec: batch_size * h_decoder_size |
|
""" |
|
            scores = self.output_A(
                output_H[j].view(max_length_output[j] * batch_size,
                                 self.h_output_encoder_size),
                h_dec.view(batch_size, self.h_decoder_size)
                     .repeat(max_length_output[j], 1)
            ).view(max_length_output[j], batch_size) + output_attention_mask[j]
|
c = (F.softmax(scores[:, :, None], dim=0) * output_H[j]).sum(0) |
|
return c |
|
|
|
|
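        # Decode the target (or sample one symbol at a time), attending over
        # the output encodings.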
|
target = target if mode == "score" else torch.zeros( |
|
max_length_target, batch_size).long() |
|
        decoder_states = [self.decoder_get_init(output_embeddings[j])
                          for j in range(n_examples)]
|
active = torch.ones(batch_size).byte() |
|
for i in range(max_length_target): |
|
FC = [] |
|
for j in range(n_examples): |
|
h = self.cell_get_h(decoder_states[j]) |
|
p_aug = torch.cat([h, output_attend(j, h)], 1) |
|
                FC.append(torch.tanh(self.W(p_aug))[None, :, :])
|
|
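            # Pool the per-example decoder features by elementwise max before
            # predicting the next symbol.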
|
m = torch.max(torch.cat(FC, 0), 0)[0] |
|
logsoftmax = F.log_softmax(self.V(m), dim=1) |
|
if mode == "sample": |
|
target[i, :] = torch.multinomial( |
|
logsoftmax.data.exp(), 1)[:, 0] |
|
score = score + \ |
|
choose(logsoftmax, target[i, :]) * Variable(active.float()) |
|
active *= (target[i, :] != self.v_target) |
|
for j in range(n_examples): |
|
if mode == "score": |
|
target_char_scatter = target_scatter[i, :, :] |
|
elif mode == "sample": |
|
target_char_scatter = Variable(torch.zeros( |
|
batch_size, self.v_target + 1).scatter_(1, target[i, :, None], 1)) |
|
decoder_states[j] = self.decoder_cell( |
|
target_char_scatter, decoder_states[j]) |
|
return target, score |
|
|
|
def inputsToTensors(self, inputss): |
|
""" |
|
:param inputss: size = nBatch * nExamples |
|
""" |
|
tensors = [] |
|
for j in range(len(inputss[0])): |
|
inputs = [x[j] for x in inputss] |
|
maxlen = max(len(s) for s in inputs) |
|
t = torch.ones( |
|
1 if maxlen == 0 else maxlen + 1, |
|
len(inputs)).long() * self.v_input |
|
for i in range(len(inputs)): |
|
s = inputs[i] |
|
if len(s) > 0: |
|
t[:len(s), i] = torch.LongTensor( |
|
[self.input_vocabulary.index(x) for x in s]) |
|
tensors.append(t) |
|
return tensors |
|
|
|
def targetToTensor(self, targets): |
|
""" |
|
:param targets: |
|
""" |
|
maxlen = max(len(s) for s in targets) |
|
t = torch.ones( |
|
1 if maxlen == 0 else maxlen + 1, |
|
len(targets)).long() * self.v_target |
|
for i in range(len(targets)): |
|
s = targets[i] |
|
if len(s) > 0: |
|
t[:len(s), i] = torch.LongTensor( |
|
[self.target_vocabulary.index(x) for x in s]) |
|
return t |
|
|
|
def tensorToOutput(self, tensor): |
|
""" |
|
:param tensor: max_length * batch_size |
|
""" |
|
out = [] |
|
for i in range(tensor.size(1)): |
|
l = tensor[:, i].tolist() |
|
if l[0] == self.v_target: |
|
out.append([]) |
|
elif self.v_target in l: |
|
final = tensor[:, i].tolist().index(self.v_target) |
|
out.append([self.target_vocabulary[x] |
|
for x in tensor[:final, i]]) |
|
else: |
|
out.append([self.target_vocabulary[x] for x in tensor[:, i]]) |
|
return out |
|
|