import json
from pathlib import Path

import torch
from torch import nn

from .gpt_model import Model, HParams

class GPTModel(nn.Module):
    """Predicts diacritics from the hidden states of one transformer block
    of a pretrained GPT model, optionally refined by a stack of
    bidirectional LSTMs."""

    def __init__(self, path, n_layer=-1, freeze=True, use_lstm=False):
        super().__init__()
        root = Path(path)
        params = json.loads((root / "params.json").read_text())
        hparams = params["hparams"]
        hparams.setdefault("n_hidden", hparams["n_embed"])
        self.model = Model(HParams(**hparams))
        state = torch.load(root / "model.pt", map_location="cpu")
        state_dict = self.fixed_state_dict(state["state_dict"])
        self.model.load_state_dict(state_dict)
        self.freeze = freeze
        self.n_layer = n_layer
        if self.freeze:
            # Use the pretrained GPT purely as a frozen feature extractor.
            for param in self.model.parameters():
                param.requires_grad = False
        self.activation = {}
        self.use_lstm = use_lstm
        self.set_hook(self.n_layer)
        # Each bidirectional LSTM emits 2 * 256 = 512 features; without the
        # LSTM stack the classifier consumes the raw 768-dim block activations.
        self.in_fc_layer = 512 if self.use_lstm else 768
        self.lstm1 = nn.LSTM(
            768,
            256,
            bidirectional=True,
            batch_first=True,
        )
        self.lstm2 = nn.LSTM(
            512,
            256,
            bidirectional=True,
            batch_first=True,
        )
        self.lstm3 = nn.LSTM(
            512,
            256,
            bidirectional=True,
            batch_first=True,
        )
        self.fc = nn.Linear(self.in_fc_layer, 17)

    def get_activation(self, name):
        """Returns a forward hook that caches the hooked module's output."""

        def hook(model, input, output):
            # output[0] holds the block's hidden states; detach so the cached
            # features never carry gradients back into the GPT.
            self.activation[name] = output[0].detach()

        return hook

    def set_hook(self, n_layer=0):
        self.model.blocks[n_layer].register_forward_hook(self.get_activation("feats"))

    def fixed_state_dict(self, state_dict):
        if all(k.startswith("module.") for k in state_dict):
            # Strip the "module." prefix that nn.DataParallel adds
            # (legacy multi-GPU checkpoint format).
            state_dict = {k[len("module.") :]: v for k, v in state_dict.items()}
        return state_dict

    def forward(self, src: torch.Tensor, lengths: torch.Tensor, target=None):
        # lengths and target are accepted for interface compatibility but
        # are not used here.
        # The GPT forward pass is run for its side effect: the hook registered
        # in set_hook() stores the hidden states of block `n_layer` in
        # self.activation["feats"]. The returned logits
        # (shape [batch_size, 256, 500]) are discarded.
        self.model(src)
        feats = self.activation["feats"]
        if self.use_lstm:
            # Three stacked bidirectional LSTMs refine the features;
            # the hidden and cell states are not needed.
            x, _ = self.lstm1(feats)
            x, _ = self.lstm2(x)
            x, _ = self.lstm3(x)
        else:
            x = feats
        predictions = self.fc(x)
        return {"diacritics": predictions}
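

# Hypothetical usage sketch (not part of the original module). It shows how
# this class might be instantiated and called, assuming a checkpoint
# directory containing params.json and model.pt as loaded in __init__ above.
# The directory path, batch size, sequence length, and vocabulary size below
# are illustrative assumptions, not values taken from the original code.
#
#     model = GPTModel("checkpoints/gpt", n_layer=-1, freeze=True, use_lstm=True)
#     src = torch.randint(0, 500, (8, 256))    # [batch_size, seq_len] token ids
#     lengths = torch.full((8,), 256)          # unused, kept for API compatibility
#     out = model(src, lengths)
#     out["diacritics"].shape                  # expected: torch.Size([8, 256, 17])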