johnpaulbin's picture
yep
22d4f29
raw
history blame
12.4 kB
# -*- coding: utf-8 -*-
""" Implement a pyTorch LSTM with hard sigmoid reccurent activation functions.
Adapted from the non-cuda variant of pyTorch LSTM at
https://github.com/pytorch/pytorch/blob/master/torch/nn/_functions/rnn.py
"""
from __future__ import print_function, division
import math
import torch
from torch.nn import Module
from torch.nn.parameter import Parameter
from torch.nn.utils.rnn import PackedSequence
import torch.nn.functional as F
class LSTMHardSigmoid(Module):
    """LSTM whose input, forget and output gates use a hard sigmoid.

    Parameters are registered under the same names as ``torch.nn.LSTM``
    (``weight_ih_l{k}[_reverse]``, ``weight_hh_l{k}[_reverse]`` and, when
    ``bias`` is True, ``bias_ih_l{k}[_reverse]`` / ``bias_hh_l{k}[_reverse]``),
    but the recurrence is computed in Python by ``AutogradRNN``.

    Args:
        input_size: number of features in each input step.
        hidden_size: number of features in the hidden state.
        num_layers: number of stacked layers.
        bias: if False, no bias parameters are created.
        batch_first: if True, the batch dimension of unpacked input/output is
            dim 0 instead of dim 1.
        dropout: dropout probability applied between layers.
        bidirectional: if True, every layer also runs a reverse-direction pass.
    """

    def __init__(self, input_size, hidden_size,
                 num_layers=1, bias=True, batch_first=False,
                 dropout=0, bidirectional=False):
        super(LSTMHardSigmoid, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.batch_first = batch_first
        self.dropout = dropout
        self.dropout_state = {}
        self.bidirectional = bidirectional
        num_directions = 2 if bidirectional else 1
        # The four gate pre-activations (input, forget, cell, output) are
        # stacked along dim 0 of every weight and bias.
        gate_size = 4 * hidden_size

        # One list of attribute names per (layer, direction), in the order
        # w_ih, w_hh[, b_ih, b_hh].
        self._all_weights = []
        for layer in range(num_layers):
            layer_input_size = input_size if layer == 0 else hidden_size * num_directions
            for direction in range(num_directions):
                suffix = '_reverse' if direction == 1 else ''
                layer_params = [
                    Parameter(torch.empty(gate_size, layer_input_size)),
                    Parameter(torch.empty(gate_size, hidden_size)),
                ]
                param_names = ['weight_ih_l{}{}', 'weight_hh_l{}{}']
                if bias:
                    # Only allocate biases that are actually registered.
                    layer_params += [
                        Parameter(torch.empty(gate_size)),
                        Parameter(torch.empty(gate_size)),
                    ]
                    param_names += ['bias_ih_l{}{}', 'bias_hh_l{}{}']
                param_names = [x.format(layer, suffix) for x in param_names]
                for name, param in zip(param_names, layer_params):
                    setattr(self, name, param)
                self._all_weights.append(param_names)

        self.flatten_parameters()
        self.reset_parameters()

    def flatten_parameters(self):
        """Resets parameter data pointer so that they can use faster code paths.

        Right now, this is a no-op since we don't use CUDA acceleration.
        """
        self._data_ptrs = []

    def _apply(self, fn):
        ret = super(LSTMHardSigmoid, self)._apply(fn)
        self.flatten_parameters()
        return ret

    def reset_parameters(self):
        """Initialise every parameter uniformly in [-1/sqrt(hidden), 1/sqrt(hidden)]."""
        stdv = 1.0 / math.sqrt(self.hidden_size)
        with torch.no_grad():
            for weight in self.parameters():
                weight.uniform_(-stdv, stdv)

    @staticmethod
    def _permute_hidden(hx, permutation):
        """Reorder dim 1 (the batch dim) of every tensor in ``hx``.

        Returns ``hx`` unchanged when ``permutation`` is None.
        """
        if permutation is None:
            return hx
        return tuple(h.index_select(1, permutation) for h in hx)

    def forward(self, input, hx=None):
        """Run the LSTM.

        Args:
            input: a padded tensor or a ``PackedSequence``.
            hx: optional ``(h_0, c_0)`` pair in the caller's batch order;
                zeros are used when omitted.

        Returns:
            ``(output, (h_n, c_n))``. ``output`` is a ``PackedSequence``
            (carrying the input's sorted/unsorted indices) when ``input`` was
            one; ``h_n``/``c_n`` are returned in the caller's batch order.
        """
        is_packed = isinstance(input, PackedSequence)
        if is_packed:
            input, batch_sizes, sorted_indices, unsorted_indices = input
            # batch_sizes is a tensor; a plain int is needed as a size.
            max_batch_size = int(batch_sizes[0])
        else:
            batch_sizes = None
            sorted_indices = unsorted_indices = None
            max_batch_size = input.size(0) if self.batch_first else input.size(1)

        if hx is None:
            num_directions = 2 if self.bidirectional else 1
            zeros = input.new_zeros((self.num_layers * num_directions,
                                     max_batch_size,
                                     self.hidden_size))
            hx = (zeros, zeros)
        else:
            # The packed data is in sorted order; bring the caller's state into it.
            hx = self._permute_hidden(hx, sorted_indices)

        # flatten_parameters() never records data pointers, so there is no
        # flat-weight buffer to hand over.
        func = AutogradRNN(
            self.input_size,
            self.hidden_size,
            num_layers=self.num_layers,
            batch_first=self.batch_first,
            dropout=self.dropout,
            train=self.training,
            bidirectional=self.bidirectional,
            batch_sizes=batch_sizes,
            dropout_state=self.dropout_state,
            flat_weight=None,
        )
        output, hidden = func(input, self.all_weights, hx)
        if is_packed:
            output = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
        # Undo the sort so the final state lines up with the caller's batch order.
        return output, self._permute_hidden(hidden, unsorted_indices)

    def __repr__(self):
        s = '{name}({input_size}, {hidden_size}'
        if self.num_layers != 1:
            s += ', num_layers={num_layers}'
        if self.bias is not True:
            s += ', bias={bias}'
        if self.batch_first is not False:
            s += ', batch_first={batch_first}'
        if self.dropout != 0:
            s += ', dropout={dropout}'
        if self.bidirectional is not False:
            s += ', bidirectional={bidirectional}'
        s += ')'
        return s.format(name=self.__class__.__name__, **self.__dict__)

    def __setstate__(self, d):
        super(LSTMHardSigmoid, self).__setstate__(d)
        self.__dict__.setdefault('_data_ptrs', [])
        if 'all_weights' in d:
            self._all_weights = d['all_weights']
        # Newer pickles already store attribute names; nothing to rebuild.
        if isinstance(self._all_weights[0][0], str):
            return
        num_layers = self.num_layers
        num_directions = 2 if self.bidirectional else 1
        self._all_weights = []
        for layer in range(num_layers):
            for direction in range(num_directions):
                suffix = '_reverse' if direction == 1 else ''
                weights = ['weight_ih_l{}{}', 'weight_hh_l{}{}', 'bias_ih_l{}{}', 'bias_hh_l{}{}']
                weights = [x.format(layer, suffix) for x in weights]
                if self.bias:
                    self._all_weights += [weights]
                else:
                    self._all_weights += [weights[:2]]

    @property
    def all_weights(self):
        """Parameters grouped per (layer, direction), in registration order."""
        return [[getattr(self, weight) for weight in weights] for weights in self._all_weights]
def AutogradRNN(input_size, hidden_size, num_layers=1, batch_first=False,
                dropout=0, train=True, bidirectional=False, batch_sizes=None,
                dropout_state=None, flat_weight=None):
    """Assemble a stacked hard-sigmoid LSTM as a plain closure.

    ``input_size``, ``hidden_size``, ``dropout_state`` and ``flat_weight`` are
    accepted for signature compatibility but are not used here.

    Returns:
        ``forward(input, weight, hidden) -> (output, next_hidden)``. When
        ``batch_sizes`` is None (padded input) and ``batch_first`` is set, the
        first two dims of the input and output are swapped around the recurrence.
    """
    make_recurrent = Recurrent if batch_sizes is None else variable_recurrent_factory(batch_sizes)
    directions = [make_recurrent(LSTMCell)]
    if bidirectional:
        directions.append(make_recurrent(LSTMCell, reverse=True))
    stacked = StackedRNN(tuple(directions), num_layers, True,
                         dropout=dropout, train=train)
    # Packed input has no batch/time layout to swap.
    swap_axes = batch_first and batch_sizes is None

    def forward(input, weight, hidden):
        steps = input.transpose(0, 1) if swap_axes else input
        next_hidden, output = stacked(steps, hidden, weight)
        if swap_axes:
            output = output.transpose(0, 1)
        return output, next_hidden

    return forward
def Recurrent(inner, reverse=False):
    """Lift a single-step cell ``inner`` to a loop over dim 0 of ``input``.

    ``inner`` is called as ``inner(input[t], hidden, *weight)``. When it returns
    a tuple (an LSTM's ``(h, c)``), only the first element is collected as the
    step output.

    Returns:
        ``forward(input, hidden, weight) -> (final_hidden, outputs)`` where
        ``outputs`` stacks the per-step outputs along a new dim 0 in time order,
        even when ``reverse`` is True.
    """
    def forward(input, hidden, weight):
        seq_len = input.size(0)
        order = reversed(range(seq_len)) if reverse else range(seq_len)
        collected = []
        for t in order:
            hidden = inner(input[t], hidden, *weight)
            collected.append(hidden[0] if isinstance(hidden, tuple) else hidden)
        if reverse:
            # Restore chronological order for the reverse pass.
            collected = collected[::-1]
        return hidden, torch.stack(collected, 0)
    return forward
def variable_recurrent_factory(batch_sizes):
    """Return ``fac(inner, reverse=False)`` that binds ``batch_sizes`` for packed input.

    ``fac`` builds a ``VariableRecurrentReverse`` when ``reverse`` is true and a
    ``VariableRecurrent`` otherwise.
    """
    def fac(inner, reverse=False):
        builder = VariableRecurrentReverse if reverse else VariableRecurrent
        return builder(batch_sizes, inner)
    return fac
def VariableRecurrent(batch_sizes, inner):
    """Run ``inner`` forward in time over packed data.

    ``input`` holds the packed steps back to back. Step ``t`` occupies the next
    ``batch_sizes[t]`` rows. When the batch shrinks between steps, the trailing
    rows of the state are set aside as finished sequences. This presumably
    assumes ``batch_sizes`` is non-increasing, as in a ``PackedSequence``.

    Returns:
        ``forward(input, hidden, weight) -> (final_hidden, output)``.
        ``hidden`` may be a single tensor or a tuple (e.g. ``(h, c)``), and
        ``final_hidden`` has the same form with ``batch_sizes[0]`` rows.
        ``output`` concatenates, along dim 0, the first element of the state
        after every step.
    """
    def forward(input, hidden, weight):
        single = not isinstance(hidden, tuple)
        state = (hidden,) if single else hidden
        outputs = []
        retired = []
        position = 0
        previous = batch_sizes[0]
        for current in batch_sizes:
            chunk = input[position:position + current]
            position += current
            shrink = previous - current
            if shrink > 0:
                # The trailing `shrink` rows belong to sequences that just ended.
                retired.append(tuple(s[-shrink:] for s in state))
                state = tuple(s[:-shrink] for s in state)
            previous = current
            if single:
                state = (inner(chunk, state[0], *weight),)
            else:
                state = inner(chunk, state, *weight)
            outputs.append(state[0])
        retired.append(state)
        # Rows retired last sit first in the batch; reversing restores row order.
        final = tuple(torch.cat(parts, 0) for parts in zip(*reversed(retired)))
        assert final[0].size(0) == batch_sizes[0]
        return (final[0] if single else final), torch.cat(outputs, 0)
    return forward
def VariableRecurrentReverse(batch_sizes, inner):
    """Run ``inner`` backward in time over packed data.

    The pass starts from the last step with only the first ``batch_sizes[-1]``
    rows of the initial state. Each time the batch grows (walking backwards),
    the newly active rows join with their slice of the initial state.

    Returns:
        ``forward(input, hidden, weight) -> (final_hidden, output)``.
        ``hidden`` may be a single tensor or a tuple (e.g. ``(h, c)``), and
        ``final_hidden`` has the same form. ``output`` concatenates the per-step
        outputs along dim 0 in the original (forward) step order.
    """
    def forward(input, hidden, weight):
        single = not isinstance(hidden, tuple)
        initial = (hidden,) if single else hidden
        narrowest = batch_sizes[-1]
        state = tuple(s[:narrowest] for s in initial)
        outputs = []
        end = input.size(0)
        previous = narrowest
        for current in reversed(batch_sizes):
            if current - previous > 0:
                # Sequences joining here start from their initial state.
                state = tuple(torch.cat((s, init[previous:current]), 0)
                              for s, init in zip(state, initial))
            previous = current
            chunk = input[end - current:end]
            end -= current
            if single:
                state = (inner(chunk, state[0], *weight),)
            else:
                state = inner(chunk, state, *weight)
            outputs.append(state[0])
        outputs.reverse()
        return (state[0] if single else state), torch.cat(outputs, 0)
    return forward
def StackedRNN(inners, num_layers, lstm=False, dropout=0, train=True):
    """Stack ``num_layers`` layers of (possibly multi-direction) recurrences.

    Args:
        inners: one recurrence per direction. Each is called as
            ``inner(layer_input, layer_hidden, layer_weight)`` and must return
            ``(final_hidden, output_sequence)``.
        num_layers: number of stacked layers.
        lstm: if True, ``hidden`` is an ``(h, c)`` pair and so is the returned
            state; otherwise both are single tensors.
        dropout: dropout probability applied to every layer's output except the last.
        train: forwarded to ``F.dropout`` as ``training``.

    Returns:
        ``forward(input, hidden, weight) -> (next_hidden, output)``.
        ``weight`` holds one entry per (layer, direction), layer-major.
        ``next_hidden`` stacks the final states along a new dim 0 of size
        ``num_layers * len(inners)``.
    """
    num_directions = len(inners)
    total_layers = num_layers * num_directions

    def forward(input, hidden, weight):
        assert len(weight) == total_layers
        # For an LSTM, turn (h, c) into [(h_0, c_0), (h_1, c_1), ...].
        layer_hidden = list(zip(*hidden)) if lstm else hidden
        finals = []
        layer_input = input
        for depth in range(num_layers):
            sequences = []
            for offset, run in enumerate(inners):
                slot = depth * num_directions + offset
                final_state, sequence = run(layer_input, layer_hidden[slot], weight[slot])
                finals.append(final_state)
                sequences.append(sequence)
            # Directions are concatenated along the last (feature) dimension.
            layer_input = torch.cat(sequences, layer_input.dim() - 1)
            if dropout != 0 and depth < num_layers - 1:
                layer_input = F.dropout(layer_input, p=dropout, training=train, inplace=False)

        if lstm:
            final_h, final_c = zip(*finals)
            return (torch.stack(final_h, 0), torch.stack(final_c, 0)), layer_input
        return torch.stack(finals, 0), layer_input

    return forward
def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
    """
    A modified LSTM cell with hard sigmoid activation on the input, forget and output gates.

    Args:
        input: input for this step; presumably (batch, input_size), so that
            dim 1 is the feature dim split by ``chunk`` below. TODO confirm.
        hidden: ``(h, c)`` state from the previous step.
        w_ih, w_hh: input-to-hidden and hidden-to-hidden weights. The four gate
            blocks are stacked along dim 0 in the order input, forget, cell, output.
        b_ih, b_hh: optional biases matching ``w_ih`` / ``w_hh``.

    Returns:
        ``(h_next, c_next)``.
    """
    hx, cx = hidden
    gates = F.linear(input, w_ih, b_ih) + F.linear(hx, w_hh, b_hh)
    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
    ingate = hard_sigmoid(ingate)
    forgetgate = hard_sigmoid(forgetgate)
    # torch.tanh replaces the deprecated F.tanh (which emits a warning).
    cellgate = torch.tanh(cellgate)
    outgate = hard_sigmoid(outgate)
    cy = (forgetgate * cx) + (ingate * cellgate)
    hy = outgate * torch.tanh(cy)
    return hy, cy
def hard_sigmoid(x):
    """
    Computes element-wise hard sigmoid of x, i.e. 0.2 * x + 0.5 clipped to [0, 1].
    See e.g. https://github.com/Theano/Theano/blob/master/theano/tensor/nnet/sigm.py#L279
    """
    scaled = 0.2 * x + 0.5
    # F.threshold replaces values <= threshold with `value`. Applied to the
    # negated input it caps at 1; applied again after negating back it floors at 0.
    capped_negated = F.threshold(-scaled, -1, -1)
    return F.threshold(-capped_negated, 0, 0)