yuta0306
first commit
565faca
raw
history blame
No virus
1.73 kB
import torch.nn as nn
from fam.llm.layers.attn import SelfAttention
from fam.llm.layers.layers import MLP, LayerNorm, RMSNorm
class Block(nn.Module):
    """
    A single residual block: self-attention followed by an MLP.

    Each sub-layer is applied to a normalized copy of its input, and the
    sub-layer output is added back onto that input (see ``forward``).

    Args:
        config (object): Configuration object for the block. This class reads
            ``config.norm_type`` ("rmsnorm" or "layernorm") and
            ``config.n_embd``, plus ``config.rmsnorm_eps`` when the norm type
            is "rmsnorm" or ``config.bias`` when it is "layernorm". The same
            object is passed unchanged to ``SelfAttention`` and ``MLP``, which
            may read further fields.

    Attributes:
        ln_1 (nn.Module): Normalization applied before the attention layer.
        ln_2 (nn.Module): Normalization applied before the feed-forward layer.
        attn (nn.Module): Self-attention layer.
        mlp (nn.Module): Multi-layer perceptron layer.

    Raises:
        ValueError: If ``config.norm_type`` is "rmsnorm" but
            ``config.rmsnorm_eps`` is None, or if ``config.norm_type`` is not
            one of the supported values.
    """

    def __init__(self, config):
        super().__init__()

        norm_type = config.norm_type
        if norm_type == "rmsnorm":
            if config.rmsnorm_eps is None:
                # ValueError is a subclass of Exception, so callers that
                # previously caught the generic Exception keep working.
                raise ValueError("RMSNorm requires rmsnorm_eps to be set")
            self.ln_1 = RMSNorm(config.n_embd, eps=config.rmsnorm_eps)  # attn norm
            self.ln_2 = RMSNorm(config.n_embd, eps=config.rmsnorm_eps)  # ffn norm
        elif norm_type == "layernorm":
            self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)  # attn norm
            self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)  # ffn norm
        else:
            raise ValueError(f"Unknown norm type: {norm_type}")

        # Sub-modules are created in the same order as before (norms, then
        # attention, then MLP) so module registration order is unchanged.
        self.attn = SelfAttention(config)
        self.mlp = MLP(config)

    def forward(self, x):
        """
        Run the block: attention sub-layer, then MLP sub-layer.

        Both sub-layers normalize their input first and add their output
        back onto the un-normalized input (residual connection).

        Args:
            x (tensor): Input tensor.
                NOTE(review): expected shape is not visible here; it must be
                whatever ``SelfAttention`` and ``MLP`` accept - confirm.

        Returns:
            tensor: Output tensor after passing through the block.
        """
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x