dappyx's picture
Upload 4 files
75c80a0 verified
raw
history blame
1.04 kB
import torch
import torch.nn as nn
from model import (
SwitchTransformer,
SwitchTransformerLayer,
MultiHeadAttention,
SwitchFeedForward,
FeedForward,
)
from transformers import AutoTokenizer
# Device the model is constructed on and moved to.
device = 'cpu'

# Hyperparameters shared by several components below. They are named once so
# the attention block, the expert network, the switch layer and the full model
# cannot silently drift apart when one of them is edited.
D_MODEL = 768
N_EXPERTS = 4
N_LAYERS = 4
DROPOUT_PROB = 0.2

# Expert network handed to SwitchFeedForward as `expert`.
# NOTE(review): positional args are presumably (d_model, d_ff) -- confirm
# against model.FeedForward.
ff = FeedForward(D_MODEL, D_MODEL * 4)

# NOTE(review): positional args are presumably (heads, d_model, dropout_prob)
# -- confirm against model.MultiHeadAttention.
attn = MultiHeadAttention(8, D_MODEL, DROPOUT_PROB)

# Switch (mixture-of-experts) feed-forward block built around `ff`.
st_ff = SwitchFeedForward(
    capacity_factor=1.25,
    drop_tokens=False,
    n_experts=N_EXPERTS,
    expert=ff,
    d_model=D_MODEL,
    is_scale_prob=True,
)

# One transformer layer combining the attention and switch feed-forward blocks.
st_layer = SwitchTransformerLayer(
    d_model=D_MODEL,
    attn=attn,
    feed_forward=st_ff,
    dropout_prob=DROPOUT_PROB,
)

# Full model built from `st_layer` with N_LAYERS layers.
# The keyword `load_balancing_loss_ceof` is spelled exactly as the external
# `model` module is called here; do not "fix" the spelling on this side only.
model = SwitchTransformer(
    layer=st_layer,
    n_layers=N_LAYERS,
    n_experts=N_EXPERTS,
    device=device,
    load_balancing_loss_ceof=0.05,
).to(device)
# Restore trained weights from the local checkpoint file.
# `map_location=device` remaps the stored tensors onto the configured device,
# so a checkpoint saved from CUDA tensors still loads on a CPU-only host
# (without it, torch.load tries to restore tensors to the device they were
# saved from and fails when that device is unavailable).
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a
# trusted source, or pass `weights_only=True` if the installed torch supports it.
state_dict = torch.load("switch_transformer.pt", map_location=device)
model.load_state_dict(state_dict)
# NOTE(review): the model is left in its default mode here; if this script is
# used for inference, the caller probably wants `model.eval()` -- confirm.

# Tokenizer fetched by name via Hugging Face `AutoTokenizer.from_pretrained`.
tokenizer = AutoTokenizer.from_pretrained("Kyrmasch/kaz-roberta-squad2-kaz")