Spaces:
Runtime error
Runtime error
File size: 4,876 Bytes
5212a08 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
# Code for GP final layer adapted from this great repo:
# https://github.com/kimjeyoung/SNGP-BERT-Pytorch .
# We simplify things here a bit by removing the spectral
# normalisation as the authors of the Plex paper say that this
# isn't strictly necessary, so we just have a GP classification head on the model.
import torch
import math
import copy
from torch import nn
def RandomFeatureLinear(i_dim, o_dim, bias=True, require_grad=False):
    """Build an ``nn.Linear`` with random, (by default) frozen parameters.

    The weight is drawn from N(0, 0.05^2) and, when ``bias`` is truthy, the
    bias is drawn uniformly from [0, 2*pi].  Both tensors get their
    ``requires_grad`` flag set to ``require_grad``.

    Args:
        i_dim: Number of input features of the linear layer.
        o_dim: Number of output features of the linear layer.
        bias: Whether the layer has a bias term.
        require_grad: ``requires_grad`` value assigned to the weight (and to
            the bias, if present).  ``False`` keeps the projection fixed.

    Returns:
        The initialised ``nn.Linear`` module.
    """
    layer = nn.Linear(in_features=i_dim, out_features=o_dim, bias=bias)

    # Random projection weights; frozen unless require_grad is set.
    nn.init.normal_(layer.weight, mean=0.0, std=0.05)
    layer.weight.requires_grad = require_grad

    if not bias:
        return layer

    # Random offsets in [0, 2*pi]; frozen the same way as the weights.
    nn.init.uniform_(layer.bias, a=0.0, b=2.0 * math.pi)
    layer.bias.requires_grad = require_grad
    return layer
class GPClassificationHead(nn.Module):
    """Classification head with a random-feature Gaussian-process output layer.

    Inputs are optionally layer-normalised, projected by a fixed random linear
    layer, passed through ``cos`` and mapped to class logits by a bias-free
    linear layer plus a constant output bias.  A precision matrix over the
    random features is tracked as a moving average and can be inverted to
    obtain a predictive covariance.

    Args:
        hidden_size: Size of the last dimension of the input features.
        gp_kernel_scale: Kernel scale; stored as
            ``gp_input_scale = 1 / sqrt(gp_kernel_scale)``.
        num_inducing: Number of random features (output size of the random
            projection and side length of the precision matrix).
        gp_output_bias: Constant, non-trainable value added to every logit.
        layer_norm_eps: ``eps`` of the input ``LayerNorm``.
        scale_random_features: If true, the cosine features are multiplied by
            ``gp_input_scale`` (see the NOTE in ``gp_layer``).
        normalize_input: If true, apply ``LayerNorm`` to the inputs first.
        gp_cov_momentum: Momentum of the precision-matrix moving average.
        gp_cov_ridge_penalty: Ridge factor; the precision matrix starts as
            ``gp_cov_ridge_penalty * I`` and the same factor scales the
            predictive covariance.
        epochs: Stored as ``final_epochs = epochs - 1``; not used inside this
            class.
        num_classes: Number of output logits.
        device: Device on which the constant tensors are created and to which
            the logits are moved in ``gp_layer``.
    """

    def __init__(
        self,
        hidden_size=768,
        gp_kernel_scale=1.0,
        num_inducing=1024,
        gp_output_bias=0.0,
        layer_norm_eps=1e-12,
        scale_random_features=True,
        normalize_input=True,
        gp_cov_momentum=0.999,
        gp_cov_ridge_penalty=1e-3,
        epochs=40,
        num_classes=3,
        device="cpu",
    ):
        super().__init__()
        self.final_epochs = epochs - 1
        self.gp_cov_ridge_penalty = gp_cov_ridge_penalty
        self.gp_cov_momentum = gp_cov_momentum

        self.pooled_output_dim = hidden_size
        self.gp_input_scale = 1.0 / math.sqrt(gp_kernel_scale)
        # Computed for parity with the reference implementation; currently
        # not read anywhere in this class.
        self.gp_feature_scale = math.sqrt(2.0 / float(num_inducing))
        self.gp_output_bias = gp_output_bias
        self.scale_random_features = scale_random_features
        self.normalize_input = normalize_input
        self.device = device

        self._gp_input_normalize_layer = torch.nn.LayerNorm(
            hidden_size, eps=layer_norm_eps
        )
        # The output bias is a constant tensor (not trainable), so the linear
        # layer itself carries no bias.
        self._gp_output_layer = nn.Linear(num_inducing, num_classes, bias=False)
        self._gp_output_bias = torch.tensor([self.gp_output_bias] * num_classes).to(
            device
        )
        self._random_feature = RandomFeatureLinear(self.pooled_output_dim, num_inducing)

        # Inverse covariance matrix corresponding to RFF-GP posterior.
        self.initial_precision_matrix = self.gp_cov_ridge_penalty * torch.eye(
            num_inducing
        ).to(device)
        # Registered as a frozen Parameter so it is saved in the state dict
        # and follows ``module.to(...)``.
        self.precision_matrix = torch.nn.Parameter(
            copy.deepcopy(self.initial_precision_matrix), requires_grad=False
        )

    def gp_layer(self, gp_inputs, update_cov=True):
        """Compute random features and logits for ``gp_inputs``.

        Args:
            gp_inputs: Input features whose last dimension is ``hidden_size``.
                Must be 2-D when ``update_cov`` is true, because the
                covariance update transposes the features with ``.t()``.
            update_cov: Whether to fold this batch into the precision matrix.

        Returns:
            Tuple ``(gp_feature, gp_output)`` of the cosine random features
            and the logits.
        """
        if self.normalize_input:
            gp_inputs = self._gp_input_normalize_layer(gp_inputs)

        gp_feature = torch.cos(self._random_feature(gp_inputs))

        # NOTE(review): the edward2 reference scales the *inputs* by
        # gp_input_scale and the *features* by gp_feature_scale; here the
        # features are multiplied by gp_input_scale and gp_feature_scale is
        # unused.  Left as-is so existing checkpoints keep their outputs --
        # confirm whether this is intended.
        if self.scale_random_features:
            gp_feature = gp_feature * self.gp_input_scale

        logits = self._gp_output_layer(gp_feature).to(self.device)
        gp_output = logits + self._gp_output_bias.to(self.device)

        if update_cov:
            self.update_cov(gp_feature)
        return gp_feature, gp_output

    def reset_cov(self):
        """Reset the precision matrix to its initial ridge-scaled identity.

        The values are copied in place, so the parameter keeps whatever
        device/dtype it currently has (e.g. after ``module.to(...)``) instead
        of being moved back to the device given at construction time.
        """
        with torch.no_grad():
            self.precision_matrix.copy_(self.initial_precision_matrix)

    def update_cov(self, gp_feature):
        """Fold a batch of features into the moving-average precision matrix.

        Computes ``momentum * P + (1 - momentum) * (F^T F / batch_size)`` and
        stores it in ``self.precision_matrix``.

        Args:
            gp_feature: 2-D tensor of random features, one row per example.
        """
        # https://github.com/google/edward2/blob/main/edward2/tensorflow/layers/random_feature.py#L346
        batch_size = gp_feature.size()[0]
        if batch_size == 0:
            # An empty batch would give 0/0 = NaN and poison the matrix.
            return

        # The precision matrix is a statistic, not a trainable quantity, so
        # no autograd graph is needed for this update.
        with torch.no_grad():
            precision_matrix_minibatch = (
                torch.matmul(gp_feature.t(), gp_feature) / batch_size
            )
            # Moving average update to the precision matrix.
            precision_matrix_new = (
                self.gp_cov_momentum * self.precision_matrix
                + (1.0 - self.gp_cov_momentum) * precision_matrix_minibatch
            )
            # In-place write keeps the same registered Parameter object.
            self.precision_matrix.copy_(precision_matrix_new)

    def compute_predictive_covariance(self, gp_feature):
        """Return the predictive covariance for a batch of random features.

        Args:
            gp_feature: 2-D tensor of random features, one row per example.

        Returns:
            ``F @ (P^-1 @ F^T * ridge)``, a square matrix with one row and
            column per example in ``gp_feature``.
        """
        # https://github.com/google/edward2/blob/main/edward2/tensorflow/layers/random_feature.py#L403
        # Covariance matrix of the feature coefficients.
        feature_cov_matrix = torch.linalg.inv(self.precision_matrix)

        # Predictive covariance matrix for the GP.
        cov_feature_product = (
            torch.matmul(feature_cov_matrix, gp_feature.t()) * self.gp_cov_ridge_penalty
        )
        gp_cov_matrix = torch.matmul(gp_feature, cov_feature_product)
        return gp_cov_matrix

    def forward(
        self,
        input_features,
        return_gp_cov: bool = False,
        update_cov: bool = True,
    ):
        """Compute logits and, optionally, the predictive covariance.

        Args:
            input_features: Input features passed to ``gp_layer``.
            return_gp_cov: If true, also return the predictive covariance.
            update_cov: If true, update the precision matrix with this batch.
                Callers will typically want ``False`` outside of training.

        Returns:
            ``gp_output`` alone, or ``(gp_output, gp_cov_matrix)`` when
            ``return_gp_cov`` is true.
        """
        gp_feature, gp_output = self.gp_layer(input_features, update_cov=update_cov)
        if return_gp_cov:
            gp_cov_matrix = self.compute_predictive_covariance(gp_feature)
            return gp_output, gp_cov_matrix
        return gp_output
|