File size: 2,600 Bytes
5112867
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
from typing import List
from torch import nn
import torch
from pathlib import Path
import json
from .gpt_model import Model, HParams


class GPTModel(nn.Module):
    """Per-position classifier on top of a pretrained GPT, using one block's output.

    Loads ``params.json`` and ``model.pt`` from ``path``. A forward hook captures
    the output of transformer block ``n_layer`` into ``self.activation["feats"]``.
    The features are optionally passed through a stack of three bidirectional
    LSTMs, then projected by ``fc`` to ``n_classes`` scores per position.

    Args:
        path: directory containing ``params.json`` (a JSON object with an
            ``"hparams"`` dict) and ``model.pt`` (a dict with a ``"state_dict"``).
        n_layer: index into ``self.model.blocks`` whose output is used as
            features; negative values index from the end (default: last block).
        freeze: if True, GPT parameters get ``requires_grad=False`` and the
            captured features are detached. If False, the features stay attached
            to the graph so the GPT model is fine-tuned along with the head.
        use_lstm: if True, run the features through ``lstm1``..``lstm3`` before
            ``fc``. The LSTMs are constructed regardless of this flag, so the
            module's state_dict keys do not depend on it.
        n_classes: number of output scores per position (default 17).
    """

    def __init__(self, path, n_layer=-1, freeze=True, use_lstm=False, n_classes=17):
        super().__init__()
        root = Path(path)

        params = json.loads((root / "params.json").read_text())
        hparams = params["hparams"]
        hparams.setdefault("n_hidden", hparams["n_embed"])
        self.model = Model(HParams(**hparams))
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        state = torch.load(root / "model.pt", map_location="cpu")
        self.model.load_state_dict(self.fixed_state_dict(state["state_dict"]))

        self.activation = {}
        self.freeze = freeze
        self.n_layer = n_layer
        self.use_lstm = use_lstm
        if self.freeze:
            for param in self.model.parameters():
                param.requires_grad = False
        self.set_hook(self.n_layer)

        # Width of the captured block output. Assumed to equal the residual
        # width n_embed (768 for the checkpoints this head was written for).
        # TODO confirm against gpt_model.Block.
        n_embed = hparams["n_embed"]
        lstm_hidden = 256
        lstm_out = 2 * lstm_hidden  # bidirectional: forward + backward states
        self.in_fc_layer = lstm_out if self.use_lstm else n_embed
        self.lstm1 = nn.LSTM(
            n_embed,
            lstm_hidden,
            bidirectional=True,
            batch_first=True,
        )
        self.lstm2 = nn.LSTM(
            lstm_out,
            lstm_hidden,
            bidirectional=True,
            batch_first=True,
        )
        self.lstm3 = nn.LSTM(
            lstm_out,
            lstm_hidden,
            bidirectional=True,
            batch_first=True,
        )
        self.fc = nn.Linear(self.in_fc_layer, n_classes)

    def get_activation(self, name):
        """Return a forward hook that stores the module's first output under ``name``.

        The stored tensor is detached only when ``self.freeze`` is True. Otherwise
        it stays in the autograd graph so gradients can reach the GPT model.
        ``self.freeze`` is read each time the hook fires.
        """

        def hook(module, inputs, output):
            # NOTE(review): assumes the block returns a tuple whose first element
            # is the hidden state; confirm against gpt_model.Block.
            feats = output[0]
            self.activation[name] = feats.detach() if self.freeze else feats

        return hook

    def set_hook(self, n_layer=0):
        """Capture the output of ``self.model.blocks[n_layer]`` as ``"feats"``.

        Any hook previously installed by this method is removed first, so calling
        it again switches layers. Without the removal, hooks would accumulate and
        whichever hooked block runs last would overwrite ``"feats"``.
        """
        previous = getattr(self, "_hook_handle", None)
        if previous is not None:
            previous.remove()
        self._hook_handle = self.model.blocks[n_layer].register_forward_hook(
            self.get_activation("feats")
        )

    def fixed_state_dict(self, state_dict):
        """Strip the ``module.`` key prefix when every key carries it.

        That prefix marks the legacy multi-GPU checkpoint format. Other
        state dicts are returned unchanged.
        """
        if all(k.startswith("module.") for k in state_dict):
            # legacy multi-GPU format
            state_dict = {k[len("module.") :]: v for k, v in state_dict.items()}
        return state_dict

    def forward(self, src: torch.Tensor, lengths: torch.Tensor, target=None):
        """Score every position of ``src``.

        Args:
            src: input passed unchanged to the GPT model (presumably token ids).
            lengths: unused; kept for interface compatibility.
            target: unused; kept for interface compatibility.

        Returns:
            ``{"diacritics": scores}``, where ``scores`` is the output of ``self.fc``.
        """
        # The GPT forward pass is run only to fire the hook; its own output
        # (including its logits) is not used here.
        self.model(src)
        feats = self.activation["feats"]

        if self.use_lstm:
            x, _ = self.lstm1(feats)
            x, _ = self.lstm2(x)
            x, _ = self.lstm3(x)
        else:
            x = feats

        return {"diacritics": self.fc(x)}