import torch
import torch.nn as nn

from model import (
    SwitchTransformer,
    SwitchTransformerLayer,
    MultiHeadAttention,
    SwitchFeedForward,
    FeedForward,
)
from transformers import AutoTokenizer

device = "cpu"

# Position-wise feed-forward network used as the expert template
# (d_model=768, d_ff=4 * 768).
ff = FeedForward(768, 768 * 4)

# Multi-head attention: 8 heads, model dimension 768, dropout 0.2.
attn = MultiHeadAttention(8, 768, 0.2)

# Switch feed-forward layer: routes each token to one of 4 experts.
# capacity_factor=1.25 gives each expert ~25% headroom over an even split;
# drop_tokens=False re-routes overflow tokens instead of discarding them.
st_ff = SwitchFeedForward(
    capacity_factor=1.25,
    drop_tokens=False,
    n_experts=4,
    expert=ff,
    d_model=768,
    is_scale_prob=True,
)

# A single transformer block combining the attention and switch-FFN modules.
st_layer = SwitchTransformerLayer(
    d_model=768,
    attn=attn,
    feed_forward=st_ff,
    dropout_prob=0.2,
)

# Stack 4 such layers into the full model.
model = SwitchTransformer(
    layer=st_layer,
    n_layers=4,
    n_experts=4,
    device=device,
    load_balancing_loss_ceof=0.05,  # keyword spelled as defined in model.py
).to(device)

# map_location keeps a GPU-saved checkpoint loadable on CPU.
model.load_state_dict(torch.load("switch_transformer.pt", map_location=device))
model.eval()

tokenizer = AutoTokenizer.from_pretrained("Kyrmasch/kaz-roberta-squad2-kaz")
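# --- Inference sketch (not from the original source) ---
# A minimal, hedged example of querying the loaded model. The exact forward
# signature lives in model.py; implementations derived from labml's Switch
# Transformer typically take the token tensor (and often a mask) and may also
# return routing statistics, so the call below is an assumption to adapt,
# not the confirmed API. The prompt string is a hypothetical example input.
prompt = "Қазақстан туралы айтып бер"  # hypothetical Kazakh prompt
input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to(device)

with torch.no_grad():
    # Assumed call shape; adjust to match model.py's forward().
    output = model(input_ids)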