import torch
from allennlp.modules.feedforward import FeedForward
from allennlp.modules.seq2vec_encoders import PytorchSeq2VecWrapper
from higher.patch import monkeypatch as make_functional


# Predicts an update (shift) for a single model parameter, conditioned on a
# context vector and the gradient of that parameter.
class ConditionedParameter(torch.nn.Module):
    def __init__(self, parameter, condition_dim=1024, hidden_dim=128, max_scale=1):
        super().__init__()
        self.parameter_shape = parameter.shape
        if len(self.parameter_shape) == 2:
            # condition_dim is the size of the tensor coming from the LSTM; linear
            # layers then learn to map it back to 768 to form the updated parm_dict.
            self.conditioners = torch.nn.Sequential(
                torch.nn.utils.weight_norm(torch.nn.Linear(condition_dim, hidden_dim)),
                torch.nn.Tanh(),
                torch.nn.utils.weight_norm(
                    torch.nn.Linear(
                        hidden_dim, 2 * (parameter.shape[0] + parameter.shape[1]) + 1
                    )
                ),
            )
        elif len(self.parameter_shape) == 1:
            self.conditioners = torch.nn.Sequential(
                torch.nn.utils.weight_norm(torch.nn.Linear(condition_dim, hidden_dim)),
                torch.nn.Tanh(),
                torch.nn.utils.weight_norm(
                    torch.nn.Linear(hidden_dim, 2 * parameter.shape[0] + 1)
                ),
            )
        else:
            raise RuntimeError()
        self.max_scale = max_scale

    def forward(self, inputs, grad):
        if len(self.parameter_shape) == 2:
            (
                conditioner_cola,
                conditioner_rowa,
                conditioner_colb,
                conditioner_rowb,
                conditioner_norm,
            ) = self.conditioners(inputs).split(
                [
                    self.parameter_shape[1],
                    self.parameter_shape[0],
                    self.parameter_shape[1],
                    self.parameter_shape[0],
                    1,
                ],
                dim=-1,
            )
            a = conditioner_rowa.softmax(-1).T @ conditioner_cola
            b = conditioner_rowb.softmax(-1).T @ conditioner_colb
        elif len(self.parameter_shape) == 1:
            a, b, conditioner_norm = self.conditioners(inputs).split(
                [self.parameter_shape[0], self.parameter_shape[0], 1], dim=-1
            )
        else:
            raise RuntimeError()
        return (
            self.max_scale
            * torch.mean(conditioner_norm.sigmoid(), dim=0).squeeze()  # with multiple conditions we simply take the mean
            * (grad * a.squeeze() + b.squeeze())
        )


# Encodes a token sequence into a single condition vector: embedding ->
# bidirectional LSTM -> feed-forward projection to output_dim.
class LSTMConditioner(torch.nn.Module):
    def __init__(
        self,
        vocab_dim=30522,
        embedding_dim=768,
        hidden_dim=256,
        output_dim=1024,
        embedding_init=None,
    ):
        super().__init__()
        self.embedding = torch.nn.Embedding(
            num_embeddings=vocab_dim,
            embedding_dim=embedding_dim,
            padding_idx=0,
            _weight=embedding_init,
        )
        self.lstm = PytorchSeq2VecWrapper(
            torch.nn.LSTM(
                input_size=embedding_dim,
                hidden_size=hidden_dim,
                num_layers=1,
                bidirectional=True,
                batch_first=True,
            )
        )
        self.linear = FeedForward(
            input_dim=hidden_dim * 2,
            num_layers=1,
            hidden_dims=[output_dim],
            activations=[torch.nn.Tanh()],
        )

    def forward(self, inputs, masks):
        return self.linear(self.lstm(self.embedding(inputs), masks))  # (batch_size, output_dim)
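

# Shape sketch for LSTMConditioner (the toy batch below is an assumption): token
# ids and a boolean padding mask of shape (batch_size, seq_len) go in, and one
# condition vector per sequence of shape (batch_size, output_dim) comes out.
def _lstm_conditioner_demo():
    conditioner = LSTMConditioner(
        vocab_dim=30522, embedding_dim=768, hidden_dim=256, output_dim=1024
    )
    token_ids = torch.randint(1, 30522, (2, 16))  # two sequences of 16 tokens
    mask = torch.ones(2, 16, dtype=torch.bool)  # no padding in this toy batch
    condition = conditioner(token_ids, mask)
    assert condition.shape == (2, 1024)
    return condition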


# Hypernetwork: one ConditionedParameter per editable parameter of `model`,
# all driven by the same LSTM-encoded condition vector.
class OneShotLearner(torch.nn.Module):
    def __init__(
        self,
        model,
        vocab_dim=30522,
        embedding_dim=768,
        hidden_dim=128,
        condition_dim=1024,
        include_set={},
        max_scale=1e-3,
        embedding_init=None,
    ):
        super().__init__()
        self.param2conditioner_map = {
            n: "{}_conditioner".format(n).replace(".", "_")
            for n, p in model.named_parameters()
            if n in include_set
        }
        self.conditioners = torch.nn.ModuleDict(
            {
                self.param2conditioner_map[n]: ConditionedParameter(
                    p,
                    condition_dim,
                    hidden_dim,
                    max_scale=max_scale,
                )
                for n, p in model.named_parameters()
                if n in include_set
            }
        )
        self.condition = LSTMConditioner(
            vocab_dim,
            embedding_dim,
            hidden_dim,
            condition_dim,
            embedding_init=embedding_init,
        )

    def forward(self, inputs, masks, grads=None):
        condition = self.condition(inputs, masks)  # the LSTM outputs the condition vector
        return {
            p: self.conditioners[self.param2conditioner_map[p]](
                condition,
                grad=grads[p] if grads else None,
            )
            for p, c in self.param2conditioner_map.items()
        }
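

# Wiring sketch for OneShotLearner (`model`, `token_ids`, `mask` and `edit_loss`
# are assumed to come from surrounding training code): take the gradients of an
# editing loss w.r.t. the editable parameters, then let the hypernetwork predict
# a shift for each of them in a single forward pass.
def _one_shot_learner_demo(model, token_ids, mask, edit_loss):
    # Assumption: only 2-D weight matrices are made editable here.
    include_set = {n for n, p in model.named_parameters() if p.dim() == 2}
    learner = OneShotLearner(model, include_set=include_set, max_scale=1e-3)
    names = [n for n, p in model.named_parameters() if n in include_set]
    params = [p for n, p in model.named_parameters() if n in include_set]
    grads = dict(zip(names, torch.autograd.grad(edit_loss, params)))
    shifts = learner(token_ids, mask, grads=grads)  # {parameter name: shift tensor}
    # The shifts could then be applied to a functional copy of `model`, e.g. one
    # obtained with `make_functional` (imported above from `higher`).
    return shifts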