Spaces:
Running
on
Zero
Running
on
Zero
import numbers | |
from typing import Dict, Optional, Tuple | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from pdb import set_trace as st | |
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension.

    Computes ``x * rsqrt(mean(x**2, dim=-1) + eps)`` and, when
    ``elementwise_affine`` is True, scales the result by a learnable
    ``weight``. The output is always cast back to the input's dtype.

    Args:
        dim: Shape of ``weight``; an int is treated as ``(dim,)``. For the
            broadcast against the input to be meaningful this should match
            the input's trailing dimension(s).
        eps: Constant added to the mean square before ``rsqrt`` so an
            all-zero row does not divide by zero.
        elementwise_affine: If True, create a learnable ``weight``
            initialised to ones; otherwise no scaling is applied.
    """

    def __init__(self, dim, eps: float = 1e-5, elementwise_affine: bool = True):
        super().__init__()
        self.eps = eps
        if isinstance(dim, numbers.Integral):
            dim = (dim,)
        self.dim = torch.Size(dim)
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(dim))
        else:
            self.weight = None

    def extra_repr(self) -> str:
        """Summarise the configuration for ``repr(module)``."""
        return (
            f"{tuple(self.dim)}, eps={self.eps}, "
            f"elementwise_affine={self.weight is not None}"
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Normalize ``hidden_states`` over its last dimension.

        Args:
            hidden_states: Input tensor; the reduction is over dim ``-1``.

        Returns:
            Tensor of the same shape and dtype as ``hidden_states``.
        """
        input_dtype = hidden_states.dtype
        # Accumulate the mean square in at least float32: half-precision
        # inputs are upcast, while float64 inputs keep their full precision
        # (a hard-coded float32 cast would silently degrade them).
        stat_dtype = torch.promote_types(input_dtype, torch.float32)
        variance = hidden_states.to(stat_dtype).pow(2).mean(-1, keepdim=True)
        # Multiplying by the (>= float32) scale promotes the result accordingly.
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
        if self.weight is not None:
            # Match a half-precision weight before scaling.
            if self.weight.dtype in (torch.float16, torch.bfloat16):
                hidden_states = hidden_states.to(self.weight.dtype)
            hidden_states = hidden_states * self.weight
        # Single cast back covers both the affine and non-affine paths.
        return hidden_states.to(input_dtype)