import torch

from allennlp.modules.feedforward import FeedForward
from allennlp.modules.seq2vec_encoders import PytorchSeq2VecWrapper
from higher.patch import monkeypatch as make_functional


class ConditionedParameter(torch.nn.Module):
    """Predicts an update with the same shape as a single model parameter,
    conditioned on a condition vector and on that parameter's gradient."""

    def __init__(self, parameter, condition_dim=1024, hidden_dim=128, max_scale=1):
        super().__init__()
        self.parameter_shape = parameter.shape

        # For a 2-D parameter the conditioner predicts two row-score vectors and
        # two column-score vectors plus a scalar gate; for a 1-D parameter it
        # predicts the coefficients a and b directly plus a scalar gate.
        if len(self.parameter_shape) == 2:
            self.conditioners = torch.nn.Sequential(
                torch.nn.utils.weight_norm(torch.nn.Linear(condition_dim, hidden_dim)),
                torch.nn.Tanh(),
                torch.nn.utils.weight_norm(
                    torch.nn.Linear(
                        hidden_dim, 2 * (parameter.shape[0] + parameter.shape[1]) + 1
                    )
                ),
            )
        elif len(self.parameter_shape) == 1:
            self.conditioners = torch.nn.Sequential(
                torch.nn.utils.weight_norm(torch.nn.Linear(condition_dim, hidden_dim)),
                torch.nn.Tanh(),
                torch.nn.utils.weight_norm(
                    torch.nn.Linear(hidden_dim, 2 * parameter.shape[0] + 1)
                ),
            )
        else:
            raise RuntimeError(
                "ConditionedParameter only supports 1-D and 2-D parameters, "
                f"got shape {tuple(self.parameter_shape)}"
            )

        self.max_scale = max_scale

    def forward(self, inputs, grad):
        if len(self.parameter_shape) == 2:
            (
                conditioner_cola,
                conditioner_rowa,
                conditioner_colb,
                conditioner_rowb,
                conditioner_norm,
            ) = self.conditioners(inputs).split(
                [
                    self.parameter_shape[1],
                    self.parameter_shape[0],
                    self.parameter_shape[1],
                    self.parameter_shape[0],
                    1,
                ],
                dim=-1,
            )

            # Combine the predicted row and column vectors into matrices with the
            # parameter's shape (an outer product when a single condition vector
            # is given).
            a = conditioner_rowa.softmax(-1).T @ conditioner_cola
            b = conditioner_rowb.softmax(-1).T @ conditioner_colb
        elif len(self.parameter_shape) == 1:
            a, b, conditioner_norm = self.conditioners(inputs).split(
                [self.parameter_shape[0], self.parameter_shape[0], 1], dim=-1
            )
        else:
            raise RuntimeError(
                "ConditionedParameter only supports 1-D and 2-D parameters, "
                f"got shape {tuple(self.parameter_shape)}"
            )

        # Affine transform of the gradient (grad * a + b), gated by a scale
        # bounded in (0, max_scale).
        return (
            self.max_scale
            * torch.mean(conditioner_norm.sigmoid(), dim=0).squeeze()
            * (grad * a.squeeze() + b.squeeze())
        )


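# Usage sketch (not part of the original module): wraps a hypothetical 768x768
# weight and checks that the predicted update has the parameter's shape. The
# function name, the sizes, the single condition vector, and the random
# gradient are illustrative assumptions.
def _demo_conditioned_parameter():
    weight = torch.nn.Parameter(torch.zeros(768, 768))
    conditioned = ConditionedParameter(weight, condition_dim=1024, hidden_dim=128)
    condition = torch.randn(1, 1024)  # one condition vector, i.e. one edit
    grad = torch.randn(768, 768)      # gradient of an editing loss w.r.t. the weight
    shift = conditioned(condition, grad)
    assert shift.shape == weight.shape

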
class LSTMConditioner(torch.nn.Module):
    """Encodes a padded token sequence into a fixed-size condition vector:
    embedding -> bidirectional LSTM -> feed-forward projection with Tanh."""

    def __init__(
        self,
        vocab_dim=30522,
        embedding_dim=768,
        hidden_dim=256,
        output_dim=1024,
        embedding_init=None,
    ):
        super().__init__()
        self.embedding = torch.nn.Embedding(
            num_embeddings=vocab_dim,
            embedding_dim=embedding_dim,
            padding_idx=0,
            _weight=embedding_init,
        )
        self.lstm = PytorchSeq2VecWrapper(
            torch.nn.LSTM(
                input_size=embedding_dim,
                hidden_size=hidden_dim,
                num_layers=1,
                bidirectional=True,
                batch_first=True,
            )
        )
        self.linear = FeedForward(
            input_dim=hidden_dim * 2,
            num_layers=1,
            hidden_dims=[output_dim],
            activations=[torch.nn.Tanh()],
        )

    def forward(self, inputs, masks):
        return self.linear(self.lstm(self.embedding(inputs), masks))


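# Usage sketch (not part of the original module): encodes a small batch of
# padded token ids into condition vectors. The BERT-base-sized vocabulary, the
# sequence length, and the all-True boolean mask are illustrative assumptions.
def _demo_lstm_conditioner():
    conditioner = LSTMConditioner(
        vocab_dim=30522, embedding_dim=768, hidden_dim=256, output_dim=1024
    )
    token_ids = torch.randint(1, 30522, (2, 16))  # two sequences of 16 token ids
    masks = torch.ones(2, 16, dtype=torch.bool)   # every position is a real token here
    condition = conditioner(token_ids, masks)
    assert condition.shape == (2, 1024)

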
class OneShotLearner(torch.nn.Module):
    """Builds one ConditionedParameter per parameter of ``model`` whose name is
    in ``include_set``, plus an LSTMConditioner that turns the input tokens into
    the shared condition vector. ``forward`` returns a dict mapping each selected
    parameter name to its predicted update."""

    def __init__(
        self,
        model,
        vocab_dim=30522,
        embedding_dim=768,
        hidden_dim=128,
        condition_dim=1024,
        include_set={},
        max_scale=1e-3,
        embedding_init=None,
    ):
        super().__init__()

        # Map each editable parameter name to its conditioner's name
        # (dots are not allowed in ModuleDict keys).
        self.param2conditioner_map = {
            n: "{}_conditioner".format(n).replace(".", "_")
            for n, p in model.named_parameters()
            if n in include_set
        }

        self.conditioners = torch.nn.ModuleDict(
            {
                self.param2conditioner_map[n]: ConditionedParameter(
                    p,
                    condition_dim,
                    hidden_dim,
                    max_scale=max_scale,
                )
                for n, p in model.named_parameters()
                if n in include_set
            }
        )

        self.condition = LSTMConditioner(
            vocab_dim,
            embedding_dim,
            hidden_dim,
            condition_dim,
            embedding_init=embedding_init,
        )

    def forward(self, inputs, masks, grads=None):
        # grads, when provided, maps parameter names to gradient tensors.
        condition = self.condition(inputs, masks)
        return {
            p: self.conditioners[c](
                condition,
                grad=grads[p] if grads else None,
            )
            for p, c in self.param2conditioner_map.items()
        }


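# Usage sketch (not part of the original module): builds a OneShotLearner
# around a toy two-layer model and asks it for updates to two named
# parameters. The toy model, the chosen parameter names, the all-True mask,
# and the zero gradients are illustrative assumptions.
def _demo_one_shot_learner():
    model = torch.nn.Sequential(torch.nn.Linear(768, 768), torch.nn.Linear(768, 2))
    include_set = {"0.weight", "1.bias"}
    learner = OneShotLearner(model, include_set=include_set)
    token_ids = torch.randint(1, 30522, (1, 16))
    masks = torch.ones(1, 16, dtype=torch.bool)
    grads = {
        n: torch.zeros_like(p)
        for n, p in model.named_parameters()
        if n in include_set
    }
    shifts = learner(token_ids, masks, grads=grads)
    assert set(shifts) == include_set
    assert shifts["0.weight"].shape == model[0].weight.shape
    assert shifts["1.bias"].shape == model[1].bias.shape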