import torch
import torch.nn as nn
from model import (
    SwitchTransformer,
    SwitchTransformerLayer,
    MultiHeadAttention,
    SwitchFeedForward,
    FeedForward,
)
from transformers import AutoTokenizer

device = 'cpu'

# Expert feed-forward network and multi-head attention sub-modules
ff = FeedForward(768, 768 * 4)
attn = MultiHeadAttention(8, 768, 0.2)

# Switch (mixture-of-experts) feed-forward block with 4 experts
st_ff = SwitchFeedForward(
    capacity_factor=1.25,
    drop_tokens=False,
    n_experts=4,
    expert=ff,
    d_model=768,
    is_scale_prob=True,
)

# One transformer layer combining the attention and switch feed-forward blocks
st_layer = SwitchTransformerLayer(
    d_model=768,
    attn=attn,
    feed_forward=st_ff,
    dropout_prob=0.2,
)

# Stack 4 copies of the layer into the full Switch Transformer
model = SwitchTransformer(
    layer=st_layer,
    n_layers=4,
    n_experts=4,
    device=device,
    load_balancing_loss_ceof=0.05,
).to(device)

# Load the trained weights; map_location keeps the checkpoint on the CPU device
model.load_state_dict(torch.load("switch_transformer.pt", map_location=device))

tokenizer = AutoTokenizer.from_pretrained("Kyrmasch/kaz-roberta-squad2-kaz")
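
# Minimal usage sketch added for illustration. It assumes the model's forward
# pass accepts a tensor of token ids; the exact signature of
# SwitchTransformer.forward in model.py is not shown here, so that call is an
# assumption. The tokenizer appears to be for Kazakh text, so real inputs
# would normally be Kazakh.
text = "Example input text"
inputs = tokenizer(text, return_tensors="pt")
input_ids = inputs["input_ids"].to(device)

model.eval()
with torch.no_grad():
    # Hypothetical call: adjust to the actual forward signature in model.py
    outputs = model(input_ids)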