import torch.nn as nn

from fam.llm.layers.attn import SelfAttention
from fam.llm.layers.layers import MLP, LayerNorm, RMSNorm


class Block(nn.Module):
    """
    Block class represents a single block in the model.

    Args:
        config (object): Configuration object containing parameters for the block.

    Attributes:
        ln_1 (object): Layer normalization for the attention layer.
        ln_2 (object): Layer normalization for the feed-forward layer.
        attn (object): Self-attention layer.
        mlp (object): Multi-layer perceptron layer.

    Methods:
        forward(x): Performs forward pass through the block.
    """

    def __init__(self, config):
        super().__init__()
        if config.norm_type == "rmsnorm":
            if config.rmsnorm_eps is None:
                raise Exception("RMSNorm requires rmsnorm_eps to be set")
            self.ln_1 = RMSNorm(config.n_embd, eps=config.rmsnorm_eps)  # attn norm
            self.ln_2 = RMSNorm(config.n_embd, eps=config.rmsnorm_eps)  # ffn norm
        elif config.norm_type == "layernorm":
            self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)  # attn norm
            self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)  # ffn norm
        else:
            raise Exception(f"Unknown norm type: {config.norm_type}")
        self.attn = SelfAttention(config)
        self.mlp = MLP(config)

    def forward(self, x):
        """
        Performs forward pass through the block.

        Args:
            x (tensor): Input tensor.

        Returns:
            tensor: Output tensor after passing through the block.
        """
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x
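

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the library): Block.forward above is a
# standard pre-norm residual transformer block, i.e. x = x + Attn(Norm(x))
# followed by x = x + MLP(Norm(x)). The snippet below reproduces that pattern
# with plain torch modules only, so it does not rely on this repo's config
# object or on the exact APIs of SelfAttention/MLP; the sizes used here
# (n_embd=64, n_head=4, batch=2, seq=16) are arbitrary assumptions.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch

    n_embd, n_head = 64, 4
    ln_1, ln_2 = nn.LayerNorm(n_embd), nn.LayerNorm(n_embd)
    attn = nn.MultiheadAttention(n_embd, n_head, batch_first=True)
    mlp = nn.Sequential(
        nn.Linear(n_embd, 4 * n_embd),
        nn.GELU(),
        nn.Linear(4 * n_embd, n_embd),
    )

    x = torch.randn(2, 16, n_embd)  # (batch, sequence, embedding)
    h = ln_1(x)
    x = x + attn(h, h, h, need_weights=False)[0]  # residual around attention
    x = x + mlp(ln_2(x))  # residual around feed-forward
    print(x.shape)  # torch.Size([2, 16, 64])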