Spaces:
Running
Running
# -*- coding: utf-8 -*- | |
""" Implement a pyTorch LSTM with hard sigmoid reccurent activation functions. | |
Adapted from the non-cuda variant of pyTorch LSTM at | |
https://github.com/pytorch/pytorch/blob/master/torch/nn/_functions/rnn.py | |
""" | |
from __future__ import print_function, division | |
import math | |
import torch | |
from torch.nn import Module | |
from torch.nn.parameter import Parameter | |
from torch.nn.utils.rnn import PackedSequence | |
import torch.nn.functional as F | |
class LSTMHardSigmoid(Module):
    """Multi-layer LSTM whose input, forget and output gates use a hard sigmoid.

    The cell candidate and the cell-state output still use ``tanh`` (see
    ``LSTMCell``). Parameters are registered as ``weight_ih_l{k}{suffix}``,
    ``weight_hh_l{k}{suffix}`` and, when ``bias`` is true,
    ``bias_ih_l{k}{suffix}`` / ``bias_hh_l{k}{suffix}``, where ``suffix`` is
    ``'_reverse'`` for the backward direction of a bidirectional layer.

    Args:
        input_size: number of features of each input step.
        hidden_size: number of features of the hidden and cell state.
        num_layers: number of stacked recurrent layers.
        bias: whether to create and use the bias parameters.
        batch_first: if true, padded (non-packed) input and output have the
            batch dimension first instead of the time dimension.
        dropout: dropout probability applied between stacked layers.
        bidirectional: if true, every layer runs a forward and a reverse pass.
    """

    def __init__(self, input_size, hidden_size,
                 num_layers=1, bias=True, batch_first=False,
                 dropout=0, bidirectional=False):
        super(LSTMHardSigmoid, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.batch_first = batch_first
        self.dropout = dropout
        self.dropout_state = {}
        self.bidirectional = bidirectional
        num_directions = 2 if bidirectional else 1
        # The four gates (input, forget, cell, output) are stacked row-wise.
        gate_size = 4 * hidden_size

        # One list of parameter *names* per (layer, direction), in the order
        # (w_ih, w_hh[, b_ih, b_hh]) expected by LSTMCell.
        self._all_weights = []
        for layer in range(num_layers):
            layer_input_size = input_size if layer == 0 else hidden_size * num_directions
            for direction in range(num_directions):
                suffix = '_reverse' if direction == 1 else ''
                layer_params = [
                    Parameter(torch.empty(gate_size, layer_input_size)),
                    Parameter(torch.empty(gate_size, hidden_size)),
                ]
                param_names = ['weight_ih_l{}{}', 'weight_hh_l{}{}']
                if bias:
                    # Only allocate bias tensors when they will be registered.
                    layer_params += [Parameter(torch.empty(gate_size)),
                                     Parameter(torch.empty(gate_size))]
                    param_names += ['bias_ih_l{}{}', 'bias_hh_l{}{}']
                param_names = [x.format(layer, suffix) for x in param_names]
                for name, param in zip(param_names, layer_params):
                    setattr(self, name, param)
                self._all_weights.append(param_names)

        self.flatten_parameters()
        self.reset_parameters()

    def flatten_parameters(self):
        """Reset the cached parameter data pointers.

        Kept for API compatibility. This is a no-op apart from clearing
        ``_data_ptrs``, since no CUDA flat-weight code path is used here.
        """
        self._data_ptrs = []

    def _apply(self, fn, *args, **kwargs):
        # Forward any extra arguments: newer Module versions may pass e.g.
        # ``recurse`` through ``_apply``.
        ret = super(LSTMHardSigmoid, self)._apply(fn, *args, **kwargs)
        self.flatten_parameters()
        return ret

    def reset_parameters(self):
        """Initialise every parameter uniformly in [-1/sqrt(H), 1/sqrt(H)]."""
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, stdv)

    def forward(self, input, hx=None):
        """Run the LSTM over ``input``.

        Args:
            input: a ``PackedSequence``, or a padded tensor that is time-major
                unless ``batch_first`` is set.
            hx: optional ``(h0, c0)`` pair. When omitted, both are zeros of
                shape ``(num_layers * num_directions, batch, hidden_size)``.

        Returns:
            ``(output, (h_n, c_n))``. ``output`` is re-wrapped in a
            ``PackedSequence`` when ``input`` was packed.
        """
        is_packed = isinstance(input, PackedSequence)
        if is_packed:
            # NOTE(review): sorted_indices / unsorted_indices of the incoming
            # PackedSequence (newer PyTorch) are neither applied to hx nor
            # propagated to the output. Confirm callers only pass length-sorted
            # packs.
            batch_sizes = input.batch_sizes
            input = input.data
            # batch_sizes may be a tensor; use a plain int as a size argument.
            max_batch_size = int(batch_sizes[0])
        else:
            batch_sizes = None
            max_batch_size = input.size(0) if self.batch_first else input.size(1)

        if hx is None:
            num_directions = 2 if self.bidirectional else 1
            zeros = input.new_zeros(self.num_layers * num_directions,
                                    max_batch_size,
                                    self.hidden_size)
            # The zero state is never modified in place, so h0 and c0 can share it.
            hx = (zeros, zeros)

        # flatten_parameters() always leaves _data_ptrs empty, so the old
        # flat-buffer branch could never run correctly (it relied on an
        # attribute that is never set). Always use the per-parameter path.
        flat_weight = None

        func = AutogradRNN(
            self.input_size,
            self.hidden_size,
            num_layers=self.num_layers,
            batch_first=self.batch_first,
            dropout=self.dropout,
            train=self.training,
            bidirectional=self.bidirectional,
            batch_sizes=batch_sizes,
            dropout_state=self.dropout_state,
            flat_weight=flat_weight
        )
        output, hidden = func(input, self.all_weights, hx)
        if is_packed:
            output = PackedSequence(output, batch_sizes)
        return output, hidden

    def __repr__(self):
        """Show the constructor arguments that differ from their defaults."""
        s = '{name}({input_size}, {hidden_size}'
        if self.num_layers != 1:
            s += ', num_layers={num_layers}'
        if self.bias is not True:
            s += ', bias={bias}'
        if self.batch_first is not False:
            s += ', batch_first={batch_first}'
        if self.dropout != 0:
            s += ', dropout={dropout}'
        if self.bidirectional is not False:
            s += ', bidirectional={bidirectional}'
        s += ')'
        return s.format(name=self.__class__.__name__, **self.__dict__)

    def __setstate__(self, d):
        """Restore pickled state, rebuilding the ``_all_weights`` name lists if needed."""
        super(LSTMHardSigmoid, self).__setstate__(d)
        self.__dict__.setdefault('_data_ptrs', [])
        # A state carrying an 'all_weights' entry (presumably from an older
        # version) overrides _all_weights.
        if 'all_weights' in d:
            self._all_weights = d['all_weights']
        # Already a list of parameter-name lists: nothing to rebuild.
        if isinstance(self._all_weights[0][0], str):
            return
        num_layers = self.num_layers
        num_directions = 2 if self.bidirectional else 1
        self._all_weights = []
        for layer in range(num_layers):
            for direction in range(num_directions):
                suffix = '_reverse' if direction == 1 else ''
                weights = ['weight_ih_l{}{}', 'weight_hh_l{}{}', 'bias_ih_l{}{}', 'bias_hh_l{}{}']
                weights = [x.format(layer, suffix) for x in weights]
                if self.bias:
                    self._all_weights += [weights]
                else:
                    self._all_weights += [weights[:2]]

    @property
    def all_weights(self):
        """Per (layer, direction) lists of parameter tensors, in LSTMCell order.

        This is a property because ``forward`` passes ``self.all_weights``
        (un-called) to the recurrence, which indexes it and takes its ``len``.
        """
        return [[getattr(self, weight) for weight in weights] for weights in self._all_weights]
def AutogradRNN(input_size, hidden_size, num_layers=1, batch_first=False,
                dropout=0, train=True, bidirectional=False, batch_sizes=None,
                dropout_state=None, flat_weight=None):
    """Assemble a (possibly bidirectional) stacked LSTM built from ``LSTMCell``.

    Padded input (``batch_sizes is None``) uses ``Recurrent``. Packed input
    uses the variable-length recurrences from ``variable_recurrent_factory``.
    ``input_size``, ``hidden_size``, ``dropout_state`` and ``flat_weight`` are
    accepted for signature compatibility but are not used here.

    Returns a closure ``forward(input, weight, hidden) -> (output, next_hidden)``.
    For padded input with ``batch_first`` set, dims 0 and 1 are swapped on
    the way in and swapped back on the way out.
    """
    padded = batch_sizes is None
    make_direction = Recurrent if padded else variable_recurrent_factory(batch_sizes)

    directions = [make_direction(LSTMCell)]
    if bidirectional:
        directions.append(make_direction(LSTMCell, reverse=True))

    stacked = StackedRNN(tuple(directions),
                         num_layers,
                         True,
                         dropout=dropout,
                         train=train)
    swap_axes = batch_first and padded

    def forward(input, weight, hidden):
        seq_major = input.transpose(0, 1) if swap_axes else input
        next_hidden, output = stacked(seq_major, hidden, weight)
        if swap_axes:
            output = output.transpose(0, 1)
        return output, next_hidden

    return forward
def Recurrent(inner, reverse=False):
    """Turn a single-step cell into a full-sequence recurrence over padded input.

    The returned ``forward(input, hidden, weight)`` calls
    ``inner(input[t], hidden, *weight)`` for every ``t`` along dim 0 of
    ``input`` (last to first when ``reverse`` is true), threading the state
    through. For tuple states (LSTM) only the first element is recorded as
    the step output.

    Returns ``(final_hidden, output)``. ``output`` stacks the step outputs
    along a new leading dimension in time order, whatever the direction.
    """
    def forward(input, hidden, weight):
        seq_len = input.size(0)
        order = reversed(range(seq_len)) if reverse else range(seq_len)
        per_step = [None] * seq_len
        for t in order:
            hidden = inner(input[t], hidden, *weight)
            # LSTM cells return (h, c); the sequence output is h.
            per_step[t] = hidden[0] if isinstance(hidden, tuple) else hidden
        return hidden, torch.stack(per_step, 0)
    return forward
def variable_recurrent_factory(batch_sizes):
    """Return a direction factory bound to ``batch_sizes`` for packed input.

    The factory has the same call shape as ``Recurrent``
    (``fac(inner, reverse=False)``). It builds
    ``VariableRecurrentReverse(batch_sizes, inner)`` when ``reverse`` is true
    and ``VariableRecurrent(batch_sizes, inner)`` otherwise.
    """
    def fac(inner, reverse=False):
        build = VariableRecurrentReverse if reverse else VariableRecurrent
        return build(batch_sizes, inner)
    return fac
def VariableRecurrent(batch_sizes, inner):
    """Build a forward-direction recurrence over packed (variable-length) data.

    ``inner(step_input, hidden, *weight)`` is called once per entry of
    ``batch_sizes``. ``input`` is consumed as consecutive row slices: the
    first ``batch_sizes[0]`` rows form step 0, the next ``batch_sizes[1]``
    rows form step 1, and so on.

    NOTE(review): assumes ``batch_sizes`` is non-increasing, as for a
    ``PackedSequence`` built from length-sorted sequences (TODO confirm). An
    increasing step would leave ``hidden`` with fewer rows than ``step_input``.

    Returns a closure ``forward(input, hidden, weight) -> (hidden, output)``.
    ``output`` concatenates every step's ``hidden[0]`` along dim 0. The
    returned ``hidden`` has ``batch_sizes[0]`` rows (asserted).
    """
    def forward(input, hidden, weight):
        output = []
        input_offset = 0
        last_batch_size = batch_sizes[0]
        # Final states of sequences that ended early, collected as they finish.
        hiddens = []
        # A plain tensor state is wrapped in a 1-tuple so both cases share code.
        flat_hidden = not isinstance(hidden, tuple)
        if flat_hidden:
            hidden = (hidden,)
        for batch_size in batch_sizes:
            step_input = input[input_offset:input_offset + batch_size]
            input_offset += batch_size
            dec = last_batch_size - batch_size
            if dec > 0:
                # The last `dec` rows have no input at this step: stash their
                # state and drop them from the running batch.
                hiddens.append(tuple(h[-dec:] for h in hidden))
                hidden = tuple(h[:-dec] for h in hidden)
            last_batch_size = batch_size
            if flat_hidden:
                hidden = (inner(step_input, hidden[0], *weight),)
            else:
                hidden = inner(step_input, hidden, *weight)
            output.append(hidden[0])
        hiddens.append(hidden)
        # Stashes were taken from the end of the batch, so the list is in
        # reverse row order; reversing restores the original row order.
        hiddens.reverse()
        hidden = tuple(torch.cat(h, 0) for h in zip(*hiddens))
        assert hidden[0].size(0) == batch_sizes[0]
        if flat_hidden:
            hidden = hidden[0]
        output = torch.cat(output, 0)
        return hidden, output
    return forward
def VariableRecurrentReverse(batch_sizes, inner):
    """Build a reverse-direction recurrence over packed (variable-length) data.

    Walks ``batch_sizes`` from last to first, reading ``input`` from its end,
    so each sequence is processed from its last step back to its first. Rows
    of the initial ``hidden`` join the running batch only at the (reversed)
    step where their row index first falls inside ``batch_size``.

    Returns a closure ``forward(input, hidden, weight) -> (hidden, output)``.
    ``output`` is flipped back so its row order matches the forward packed
    layout of ``input``.
    """
    def forward(input, hidden, weight):
        output = []
        input_offset = input.size(0)
        last_batch_size = batch_sizes[-1]
        initial_hidden = hidden
        # A plain tensor state is wrapped in a 1-tuple so both cases share code.
        flat_hidden = not isinstance(hidden, tuple)
        if flat_hidden:
            hidden = (hidden,)
            initial_hidden = (initial_hidden,)
        # Start with only the rows that are active at the final step.
        hidden = tuple(h[:batch_sizes[-1]] for h in hidden)
        for batch_size in reversed(batch_sizes):
            inc = batch_size - last_batch_size
            if inc > 0:
                # More rows are active at this (earlier) step: append their
                # initial states taken from `initial_hidden`.
                hidden = tuple(torch.cat((h, ih[last_batch_size:batch_size]), 0)
                               for h, ih in zip(hidden, initial_hidden))
            last_batch_size = batch_size
            step_input = input[input_offset - batch_size:input_offset]
            input_offset -= batch_size
            if flat_hidden:
                hidden = (inner(step_input, hidden[0], *weight),)
            else:
                hidden = inner(step_input, hidden, *weight)
            output.append(hidden[0])
        # Steps were produced last-to-first; flip back to packed order.
        output.reverse()
        output = torch.cat(output, 0)
        if flat_hidden:
            hidden = hidden[0]
        return hidden, output
    return forward
def StackedRNN(inners, num_layers, lstm=False, dropout=0, train=True):
    """Compose per-direction recurrences into a multi-layer (stacked) RNN.

    ``inners`` holds one recurrence per direction. Layer ``i`` / direction
    ``j`` uses slot ``i * len(inners) + j`` of both ``hidden`` and ``weight``.
    Each layer's direction outputs are concatenated along the last dimension
    of that layer's input and fed to the next layer. Dropout (if non-zero) is
    applied to every layer output except the last.

    When ``lstm`` is true, ``hidden`` is an ``(h, c)`` pair indexed by slot
    along dim 0, and the returned state is also an ``(h, c)`` pair.

    Returns a closure ``forward(input, hidden, weight) -> (next_hidden, output)``.
    """
    num_directions = len(inners)
    total_layers = num_layers * num_directions

    def forward(input, hidden, weight):
        assert len(weight) == total_layers
        if lstm:
            # Regroup (all_h, all_c) into one (h_slot, c_slot) pair per slot.
            hidden = list(zip(*hidden))
        final_states = []
        for layer_idx in range(num_layers):
            concat_dim = input.dim() - 1
            direction_outputs = []
            for direction_idx, run_direction in enumerate(inners):
                slot = layer_idx * num_directions + direction_idx
                state, out = run_direction(input, hidden[slot], weight[slot])
                final_states.append(state)
                direction_outputs.append(out)
            input = torch.cat(direction_outputs, concat_dim)
            more_layers_follow = layer_idx < num_layers - 1
            if dropout != 0 and more_layers_follow:
                input = F.dropout(input, p=dropout, training=train, inplace=False)
        if lstm:
            last_h, last_c = zip(*final_states)
            return (torch.stack(last_h, 0), torch.stack(last_c, 0)), input
        return torch.stack(final_states, 0), input

    return forward
def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
    """
    A modified LSTM cell with hard sigmoid activation on the input, forget and output gates.

    The pre-activations ``linear(input, w_ih, b_ih) + linear(h, w_hh, b_hh)``
    are split into four equal chunks along dim 1, in the order input, forget,
    cell, output. The cell candidate and the new cell state go through
    ``tanh``. Returns the new ``(h, c)`` pair.
    """
    h_prev, c_prev = hidden
    pre_act = F.linear(input, w_ih, b_ih) + F.linear(h_prev, w_hh, b_hh)
    in_raw, forget_raw, cell_raw, out_raw = pre_act.chunk(4, 1)

    in_gate = hard_sigmoid(in_raw)
    forget_gate = hard_sigmoid(forget_raw)
    candidate = torch.tanh(cell_raw)
    out_gate = hard_sigmoid(out_raw)

    c_next = (forget_gate * c_prev) + (in_gate * candidate)
    h_next = out_gate * torch.tanh(c_next)
    return h_next, c_next
def hard_sigmoid(x, slope=0.2, shift=0.5):
    """
    Computes element-wise hard sigmoid of x: ``clamp(slope * x + shift, 0, 1)``.
    See e.g. https://github.com/Theano/Theano/blob/master/theano/tensor/nnet/sigm.py#L279

    Args:
        x: input tensor.
        slope: multiplier applied to ``x`` before clipping (default 0.2).
        shift: offset added after scaling (default 0.5).

    Returns:
        A tensor with values in [0, 1]; with scalar ``slope``/``shift`` it has
        the same shape as ``x``.
    """
    # One clamp replaces the previous pair of negated F.threshold calls. It is
    # the same piecewise-linear function (the clip used by the Theano
    # reference) and creates fewer intermediate tensors.
    return torch.clamp(slope * x + shift, 0.0, 1.0)