import torch
import torch.nn as nn

### A mix of:
### Givens Rotation
### Householder Rotation
### Orthogonal Rotation
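### For intuition: a single Givens rotation G(i, j, theta) is the identity matrix with the
### 2x2 block [[cos t, -sin t], [sin t, cos t]] embedded at rows/columns i and j, and a
### Householder transform reflects about a vector v via H = I - 2 * v v^T / (v^T v).
### Both are orthogonal maps; the class below composes them with a rotary (RoPE-style)
### positional embedding driven by inv_freq.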
class CombinedRotaryEmbedding(nn.Module):
def __init__(self, base, dims, head, rotation_type='givens', theta_learnable=True,
rot_learnable=True, matrix_learnable=False, freq_learnable=True):
super(CombinedRotaryEmbedding, self).__init__()
self.base = base
self.dims = dims
self.head = head
self.rotation_type = rotation_type
self.h_dim = self.dims // self.head
self.rot = (self.dims // self.head) // 2
self.thetas = nn.Parameter(torch.zeros(self.rot))
        # (i, j) index pairs for each rotation, stored as floats and truncated to integer indices at use time.
        self.r_pairs = nn.Parameter(data=torch.rand(self.rot, 2) * self.h_dim)
self.theta_scale = nn.Parameter(torch.ones(1), requires_grad=theta_learnable)
self.rot_scale = nn.Parameter(torch.ones(1), requires_grad=rot_learnable)
self.r_matrix = nn.Parameter(torch.eye(n=self.h_dim), requires_grad=matrix_learnable)
freq_data = 1.0 / (self.base ** (torch.arange(start=0, end=self.h_dim, step=2).float() / self.h_dim))
self.inv_freq = nn.Parameter(freq_data, requires_grad=freq_learnable)
self.orthogonal_reg_weight = 0.01
if self.rotation_type == 'givens':
self.rotation_function = self.givens_rotation
elif self.rotation_type == 'householder':
self.rotation_function = self.householder_rotation
elif self.rotation_type == 'orthogonal':
self.rotation_function = self.orthogonal_rotation
else:
            raise ValueError(f'Invalid rotation type: {rotation_type}')
def givens_r_matrix(self, dims, i, j, theta):
G = torch.eye(dims).to(theta.device)
G[i, i] = torch.cos(theta)
G[i, j] = -torch.sin(theta)
G[j, i] = torch.sin(theta)
G[j, j] = torch.cos(theta)
return G
def givens_rotation(self, x):
        # Clamp the learned rotation count to the number of available (theta, pair) slots.
        adjusted_rot = max(0, min(self.rot, int(torch.round(self.rot_scale * self.rot))))
for k in range(adjusted_rot):
i, j = self.r_pairs[k].long()
theta = self.thetas[k] * self.theta_scale
G = self.givens_r_matrix(dims=self.h_dim, i=i, j=j, theta=theta)
x = torch.matmul(input=x, other=G)
return x
def householder_rotation(self, x):
        adjusted_rot = max(0, min(self.rot, int(torch.round(self.rot_scale * self.rot))))
for k in range(adjusted_rot):
i, j = self.r_pairs[k].long()
theta = self.thetas[k] * self.theta_scale
v = torch.zeros(self.h_dim).to(theta.device)
v[i] = torch.cos(theta)
v[j] = torch.sin(theta)
H = torch.eye(self.h_dim).to(theta.device) - 2 * torch.outer(v, v) / torch.dot(v, v)
x = torch.matmul(input=x, other=H)
return x
def orthogonal_rotation(self, x):
        adjusted_rot = max(0, min(self.rot, int(torch.round(self.rot_scale * self.rot))))
for k in range(adjusted_rot):
i, j = self.r_pairs[k].long()
theta = self.thetas[k] * self.theta_scale
R = torch.eye(self.h_dim).to(theta.device)
R[i, i] = torch.cos(theta)
R[i, j] = -torch.sin(theta)
R[j, i] = torch.sin(theta)
R[j, j] = torch.cos(theta)
x = torch.matmul(input=x, other=R)
return x
def update_base(self, new_base):
        if new_base is not None and new_base != self.base:
self.base = new_base
inv_freq = 1.0 / (self.base ** (torch.arange(start=0, end=self.h_dim, step=2).float() / self.h_dim))
self.inv_freq.data.copy_(inv_freq)
self.update_pairs()
def reset_parameters(self):
nn.init.orthogonal_(self.r_matrix)
nn.init.zeros_(self.thetas)
def orthogonal_regularization_term(self):
loss = torch.tensor(0.0, device=self.r_matrix.device)
if self.r_matrix.requires_grad:
product = torch.matmul(self.r_matrix, self.r_matrix.t())
identity = torch.eye(self.r_matrix.size(0)).to(self.r_matrix.device)
loss = ((product - identity) ** 2).sum()
return self.orthogonal_reg_weight * loss
def update_pairs(self):
pairs = []
while len(pairs) < self.rot:
            # randint's upper bound is exclusive, so use h_dim to make every index reachable.
            i, j = torch.randint(0, self.h_dim, (2,)).tolist()
            if i != j and (i, j) not in pairs and (j, i) not in pairs:
pairs.append((i, j))
self.r_pairs.data.copy_(torch.tensor(pairs, dtype=torch.float32))
def forward(self, x, global_step=None):
if x.dim() not in [3, 4]:
raise ValueError(f"Expected input tensor to be 3D or 4D, but got {x.dim()}D")
batch_size, seq_len, *rest = x.size()
if x.dim() == 3:
dims = rest[0]
            if dims != self.head * self.h_dim:
                raise ValueError(
                    f"Expected dims ({dims}) to equal head ({self.head}) * h_dim ({self.h_dim}) = {self.head * self.h_dim}")
else:
head, h_dim = rest
            if head != self.head or h_dim != self.h_dim:
raise ValueError(
f"For 4D input, expected head {self.head} and h_dim {self.h_dim}, but got head {head} and h_dim {h_dim}")
x = x.view(batch_size, seq_len, self.head, self.h_dim)
x = x.reshape(-1, self.h_dim)
x = self.rotation_function(x)
x = torch.matmul(input=x, other=self.r_matrix)
x = x.view(batch_size, seq_len, self.head, self.h_dim)
        positions = torch.arange(end=seq_len, device=x.device, dtype=self.inv_freq.dtype)
        sinusoid_inp = torch.einsum('i, j -> i j', positions, self.inv_freq.to(device=x.device))
sin = sinusoid_inp.sin()[None, :, None, :]
cos = sinusoid_inp.cos()[None, :, None, :]
x1, x2 = x[..., ::2], x[..., 1::2]
x = torch.cat(tensors=[x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
x = x.view(batch_size, seq_len, self.dims)
return x
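### NOTE: MultiRotationLayer below references GivensRotation, HouseholderRotation and
### OrthogonalRotation, which are not defined in this file. The three classes that follow
### are hypothetical placeholder sketches (not part of the original code) added so the
### layer can be exercised end to end. They assume the constructor argument `dim` names a
### leading block of features to transform, with the remaining features passed through
### unchanged. Replace them with the real implementations if those live in another module.
class GivensRotation(nn.Module):
    # Placeholder sketch: one learned planar (Givens) rotation applied to the first two features.
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.theta = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        if x.size(-1) < 2:
            return x
        c, s = torch.cos(self.theta), torch.sin(self.theta)
        x1, x2 = x[..., 0:1], x[..., 1:2]
        return torch.cat([x1 * c - x2 * s, x1 * s + x2 * c, x[..., 2:]], dim=-1)

class HouseholderRotation(nn.Module):
    # Placeholder sketch: reflection of the leading block about a learned direction,
    # H = I - 2 v v^T / (v^T v).
    def __init__(self, dim):
        super().__init__()
        self.v = nn.Parameter(torch.randn(dim))

    def forward(self, x):
        d = min(self.v.numel(), x.size(-1))
        v = self.v[:d]
        head = x[..., :d]
        head = head - 2 * (head @ v).unsqueeze(-1) * v / torch.dot(v, v)
        return torch.cat([head, x[..., d:]], dim=-1)

class OrthogonalRotation(nn.Module):
    # Placeholder sketch: multiplies the leading block by a learned, orthogonally initialized matrix.
    def __init__(self, dim):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(dim, dim))
        nn.init.orthogonal_(self.weight)

    def forward(self, x):
        d = self.weight.size(0)
        if x.size(-1) < d:
            return x
        head = x[..., :d] @ self.weight
        return torch.cat([head, x[..., d:]], dim=-1)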
class MultiRotationLayer(nn.Module):
def __init__(self, input_dim, output_dim, num_scales=3):
super(MultiRotationLayer, self).__init__()
self.num_scales = num_scales
self.rotations = nn.ModuleList([self._create_rotation(input_dim // (2**i)) for i in range(num_scales)])
self.scale = nn.Parameter(torch.ones(num_scales))
    def _create_rotation(self, input_dim):
        # Chain the three rotation types; each is constructed with a third of the (sub-sampled) feature dimension.
        return nn.Sequential(
            GivensRotation(input_dim // 3),
            HouseholderRotation(input_dim // 3),
            OrthogonalRotation(input_dim // 3)
        )
def forward(self, x):
outputs = []
for i, rotation in enumerate(self.rotations):
            # Each scale sees the input subsampled along the feature axis with stride 2**i.
            output = rotation(x[:, ::(2**i)])
outputs.append(output)
x = torch.cat(outputs, dim=1)
return x
    def update(self, loss):
        # Map the loss through a sigmoid and broadcast it as the scale for every rotation stage,
        # keeping the parameter's original (num_scales,) shape.
        self.scale.data.fill_(torch.sigmoid(torch.as_tensor(loss)).item())
        # Expose the per-stage scale on each rotation block for downstream use.
        for i, rotation in enumerate(self.rotations):
            rotation.scale = self.scale[i]
def get_scale(self):
return self.scale
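### A minimal, hedged usage sketch (not part of the original file). Shapes, hyperparameters
### and the dummy loss below are illustrative assumptions only.
if __name__ == "__main__":
    # Rotary embedding over a (batch, seq_len, dims) tensor with dims == head * h_dim.
    rotary = CombinedRotaryEmbedding(base=10000, dims=64, head=4, rotation_type='givens')
    x = torch.randn(2, 10, 64)
    y = rotary(x)                                   # -> (2, 10, 64)
    reg = rotary.orthogonal_regularization_term()   # add to the training loss if r_matrix is learnable
    print(y.shape, reg.item())

    # Multi-scale rotation stack; output feature size is the sum of the subsampled widths.
    layer = MultiRotationLayer(input_dim=96, output_dim=96, num_scales=3)
    z = layer(torch.randn(4, 96))                   # -> (4, 96 + 48 + 24)
    layer.update(z.pow(2).mean())                   # dummy scalar loss just to exercise update()
    print(z.shape, layer.get_scale())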