Spaces:
Running
on
Zero
Running
on
Zero
File size: 1,218 Bytes
0f079b2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 |
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from craftsman.utils.typing import *
from craftsman.utils.checkpoint import checkpoint
from .utils import init_linear
from .attention import ResidualAttentionBlock
class Perceiver(nn.Module):
    """Sequential stack of ``layers`` identically configured ``ResidualAttentionBlock``s.

    Every block is constructed with the same keyword configuration, and
    ``forward`` threads the input through the blocks in order, feeding each
    block's output into the next one.

    Keyword-only args:
        n_ctx: Forwarded to every block (presumably the context/sequence
            length -- TODO confirm against ``ResidualAttentionBlock``).
        width: Forwarded to every block (presumably the feature width --
            TODO confirm).
        layers: Number of blocks to create; ``0`` makes ``forward`` return
            its input unchanged.
        heads: Forwarded to every block (presumably the attention head
            count -- TODO confirm).
        init_scale: Forwarded to every block.
        qkv_bias: Forwarded to every block.
        use_flash: Forwarded to every block.
        use_checkpoint: Forwarded to every block.
    """

    def __init__(
        self,
        *,
        n_ctx: int,
        width: int,
        layers: int,
        heads: int,
        init_scale: float = 0.25,
        qkv_bias: bool = True,
        use_flash: bool = False,
        use_checkpoint: bool = False
    ):
        super().__init__()
        # Hyper-parameters kept as plain attributes on the module.
        self.n_ctx = n_ctx
        self.width = width
        self.layers = layers

        # A single configuration shared by every block in the stack.
        block_config = dict(
            n_ctx=n_ctx,
            width=width,
            heads=heads,
            init_scale=init_scale,
            qkv_bias=qkv_bias,
            use_flash=use_flash,
            use_checkpoint=use_checkpoint,
        )
        # Blocks are built one after another as fresh instances. The attribute
        # name ``resblocks`` determines the state_dict key prefix, so keep it.
        stack = [ResidualAttentionBlock(**block_config) for _ in range(layers)]
        self.resblocks = nn.ModuleList(stack)

    def forward(self, x: torch.Tensor):
        """Pass ``x`` through every block in order and return the last output."""
        hidden = x
        for resblock in self.resblocks:
            hidden = resblock(hidden)
        return hidden