dappyx commited on
Commit
cfd9f7e
1 Parent(s): 5cea6c9

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +41 -41
main.py CHANGED
@@ -1,41 +1,41 @@
1
- import torch
2
- import torch.nn as nn
3
- from model import (
4
- SwitchTransformer,
5
- SwitchTransformerLayer,
6
- MultiHeadAttention,
7
- SwitchFeedForward,
8
- FeedForward,
9
- )
10
- from transformers import AutoTokenizer
11
-
12
- device = 'cpu'
13
-
14
- ff = FeedForward(768, 768*4)
15
- attn = MultiHeadAttention(8, 768, 0.2)
16
- st_ff = SwitchFeedForward(
17
- capacity_factor=1.25,
18
- drop_tokens=False,
19
- n_experts=4,
20
- expert=ff,
21
- d_model=768,
22
- is_scale_prob=True,
23
- )
24
- st_layer = SwitchTransformerLayer(
25
- d_model=768,
26
- attn=attn,
27
- feed_forward=st_ff,
28
- dropout_prob=0.2
29
- )
30
- model = SwitchTransformer(
31
- layer=st_layer,
32
- n_layers=4,
33
- n_experts=4,
34
- device=device,
35
- load_balancing_loss_ceof=0.05,
36
- ).to(device)
37
-
38
- model.load_state_dict(torch.load("switch_transformer.pt"))
39
- tokenizer = AutoTokenizer.from_pretrained("Kyrmasch/kaz-roberta-squad2-kaz")
40
-
41
-
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from model import (
4
+ SwitchTransformer,
5
+ SwitchTransformerLayer,
6
+ MultiHeadAttention,
7
+ SwitchFeedForward,
8
+ FeedForward,
9
+ )
10
+ from transformers import AutoTokenizer
11
+
12
+ device = 'cpu'
13
+
14
+ ff = FeedForward(768, 768*4)
15
+ attn = MultiHeadAttention(8, 768, 0.2)
16
+ st_ff = SwitchFeedForward(
17
+ capacity_factor=1.25,
18
+ drop_tokens=False,
19
+ n_experts=4,
20
+ expert=ff,
21
+ d_model=768,
22
+ is_scale_prob=True,
23
+ )
24
+ st_layer = SwitchTransformerLayer(
25
+ d_model=768,
26
+ attn=attn,
27
+ feed_forward=st_ff,
28
+ dropout_prob=0.2
29
+ )
30
+ model = SwitchTransformer(
31
+ layer=st_layer,
32
+ n_layers=4,
33
+ n_experts=4,
34
+ device=device,
35
+ load_balancing_loss_ceof=0.05,
36
+ ).to(device)
37
+
38
+ model.load_state_dict(torch.load("switch_transformer.pt", map_location=torch.device('cpu')))
39
+ tokenizer = AutoTokenizer.from_pretrained("Kyrmasch/kaz-roberta-squad2-kaz")
40
+
41
+