# Code for the GP final layer, adapted from this great repo:
# https://github.com/kimjeyoung/SNGP-BERT-Pytorch .
# We simplify things a bit by dropping the spectral normalisation:
# the authors of the Plex paper note that it isn't strictly necessary,
# so the model simply gets a GP classification head.

import torch
import math
import copy
from torch import nn


def RandomFeatureLinear(i_dim, o_dim, bias=True, require_grad=False):
    """Frozen linear layer used as a random Fourier feature projection.

    The weights are Gaussian and the bias is uniform on [0, 2*pi]; both
    are frozen, so the layer computes a fixed random projection Wx + b
    that the caller passes through cos(.) to obtain the features.
    """
    m = nn.Linear(i_dim, o_dim, bias)
    nn.init.normal_(m.weight, mean=0.0, std=0.05)
    m.weight.requires_grad = require_grad  # Freeze the projection weights
    if bias:
        nn.init.uniform_(m.bias, a=0.0, b=2.0 * math.pi)
        m.bias.requires_grad = require_grad  # Freeze the random phases
    return m
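

# A minimal sanity-check sketch (not from the original repo; the function
# name and the sizes are ours): with many random features, the inner product
# of the scaled cosine features approximates an RBF kernel between the two
# inputs, with a lengthscale set by the 0.05 weight std above.
def _rff_kernel_demo(i_dim=8, num_features=4096):
    phi = RandomFeatureLinear(i_dim, num_features)
    x, y = torch.randn(1, i_dim), torch.randn(1, i_dim)
    fx = math.sqrt(2.0 / num_features) * torch.cos(phi(x))
    fy = math.sqrt(2.0 / num_features) * torch.cos(phi(y))
    return (fx @ fy.t()).item()  # ~ k(x, y)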


class GPClassificationHead(nn.Module):
    def __init__(
        self,
        hidden_size=768,
        gp_kernel_scale=1.0,
        num_inducing=1024,
        gp_output_bias=0.0,
        layer_norm_eps=1e-12,
        scale_random_features=True,
        normalize_input=True,
        gp_cov_momentum=0.999,
        gp_cov_ridge_penalty=1e-3,
        epochs=40,
        num_classes=3,
        device="cpu",
    ):
        super().__init__()
        # Index of the last training epoch; exposed so a training loop can
        # treat the final epoch specially (e.g. when resetting the covariance).
        self.final_epochs = epochs - 1
        self.gp_cov_ridge_penalty = gp_cov_ridge_penalty
        self.gp_cov_momentum = gp_cov_momentum

        self.pooled_output_dim = hidden_size
        self.gp_input_scale = 1.0 / math.sqrt(gp_kernel_scale)
        self.gp_feature_scale = math.sqrt(2.0 / float(num_inducing))
        self.gp_output_bias = gp_output_bias
        self.scale_random_features = scale_random_features
        self.normalize_input = normalize_input
        self.device = device

        self._gp_input_normalize_layer = torch.nn.LayerNorm(
            hidden_size, eps=layer_norm_eps
        )
        # Trainable output weights; the output bias is kept as a fixed
        # (non-trainable) tensor rather than a bias parameter.
        self._gp_output_layer = nn.Linear(num_inducing, num_classes, bias=False)
        self._gp_output_bias = torch.tensor([self.gp_output_bias] * num_classes).to(
            device
        )
        # Frozen random Fourier feature projection.
        self._random_feature = RandomFeatureLinear(self.pooled_output_dim, num_inducing)

        # Inverse covariance (precision) matrix of the RFF-GP posterior,
        # initialised to the ridge prior. Stored as a non-trainable Parameter
        # so it moves with .to(device) and is saved in the state dict.
        self.initial_precision_matrix = self.gp_cov_ridge_penalty * torch.eye(
            num_inducing
        ).to(device)
        self.precision_matrix = torch.nn.Parameter(
            copy.deepcopy(self.initial_precision_matrix), requires_grad=False
        )

    def gp_layer(self, gp_inputs, update_cov=True):
        if self.normalize_input:
            gp_inputs = self._gp_input_normalize_layer(gp_inputs)

        # Random Fourier features: cos(Wx + b) with frozen W and b.
        gp_feature = self._random_feature(gp_inputs)
        gp_feature = torch.cos(gp_feature)

        if self.scale_random_features:
            # Rescale by 1 / sqrt(gp_kernel_scale), as in the upstream
            # SNGP-BERT-Pytorch code (note that self.gp_feature_scale is
            # never applied).
            gp_feature = gp_feature * self.gp_input_scale

        gp_output = self._gp_output_layer(gp_feature).to(
            self.device
        ) + self._gp_output_bias.to(self.device)

        # Accumulate precision-matrix statistics (training time only).
        if update_cov:
            self.update_cov(gp_feature)
        return gp_feature, gp_output

    def reset_cov(self):
        # Reset the precision matrix to the ridge prior, e.g. at the start
        # of an epoch in which the covariance is re-estimated.
        self.precision_matrix = torch.nn.Parameter(
            copy.deepcopy(self.initial_precision_matrix), requires_grad=False
        )

    def update_cov(self, gp_feature):
        # Follows https://github.com/google/edward2/blob/main/edward2/tensorflow/layers/random_feature.py#L346
        batch_size = gp_feature.size()[0]
        # Per-batch precision contribution: Phi^T Phi
        precision_matrix_minibatch = torch.matmul(gp_feature.t(), gp_feature)

        # Moving-average update of the precision matrix
        precision_matrix_minibatch = precision_matrix_minibatch / batch_size
        precision_matrix_new = (
            self.gp_cov_momentum * self.precision_matrix
            + (1.0 - self.gp_cov_momentum) * precision_matrix_minibatch
        )

        # Re-wrap as a non-trainable Parameter so the updated matrix keeps
        # following the module across devices and checkpoints.
        self.precision_matrix = torch.nn.Parameter(
            precision_matrix_new, requires_grad=False
        )

    def compute_predictive_covariance(self, gp_feature):
        # Follows https://github.com/google/edward2/blob/main/edward2/tensorflow/layers/random_feature.py#L403
        # Posterior covariance of the feature coefficients: the inverse of
        # the accumulated precision matrix (num_inducing x num_inducing).
        feature_cov_matrix = torch.linalg.inv(self.precision_matrix)

        # Predictive covariance of the GP over the batch:
        # ridge * Phi Sigma Phi^T, a (batch x batch) matrix.
        cov_feature_product = (
            torch.matmul(feature_cov_matrix, gp_feature.t()) * self.gp_cov_ridge_penalty
        )
        gp_cov_matrix = torch.matmul(gp_feature, cov_feature_product)
        return gp_cov_matrix
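
    # A hedged sketch (not in the original file; the method name is ours):
    # the standard SNGP mean-field adjustment, which scales each logit by its
    # predictive variance before the softmax so that uncertain examples get
    # softer probabilities (mean_field_factor = pi / 8 is the common choice).
    def mean_field_logits(self, logits, gp_cov_matrix, mean_field_factor=math.pi / 8.0):
        variances = torch.diagonal(gp_cov_matrix)  # per-example predictive variance
        scale = torch.sqrt(1.0 + mean_field_factor * variances)
        return logits / scale.unsqueeze(-1)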

    def forward(
        self,
        input_features,
        return_gp_cov: bool = False,
        update_cov: bool = True,
    ):
        gp_feature, gp_output = self.gp_layer(input_features, update_cov=update_cov)
        if return_gp_cov:
            gp_cov_matrix = self.compute_predictive_covariance(gp_feature)
            return gp_output, gp_cov_matrix
        return gp_output
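

# A minimal usage sketch (not part of the original file; the batch size and
# training/eval split are assumptions). During training the precision matrix
# is accumulated as a side effect; at evaluation time updates are disabled
# and the predictive covariance can be requested for uncertainty estimates.
if __name__ == "__main__":
    head = GPClassificationHead(hidden_size=768, num_classes=3)
    features = torch.randn(16, 768)  # e.g. pooled encoder outputs

    # Training step: logits only; covariance statistics updated in passing.
    logits = head(features, update_cov=True)

    # Evaluation: freeze the covariance and also return the batch's
    # predictive covariance matrix (16 x 16 here).
    logits, gp_cov = head(features, return_gp_cov=True, update_cov=False)
    print(logits.shape, gp_cov.shape)  # torch.Size([16, 3]) torch.Size([16, 16])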