from transformers import PretrainedConfig

class BiLSTMConfig(PretrainedConfig):
    """Configuration for a BiLSTM model, following the Hugging Face
    PretrainedConfig pattern so it can be serialized to and loaded
    from a config.json."""

    def __init__(self, vocab_size=23626, embed_dim=100,
                 num_layers=1, hidden_dim=256, dropout=0.33,
                 output_dim=128, predict_output=10, device="cuda:0", **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size          # size of the token vocabulary
        self.embed_dim = embed_dim            # token embedding dimension
        self.num_layers = num_layers          # number of stacked LSTM layers
        self.hidden_dim = hidden_dim          # LSTM hidden state size
        self.dropout = dropout                # dropout rate
        self.output_dim = output_dim          # output feature dimension
        self.predict_output = predict_output  # prediction head output size
        self.device = device                  # target device string, e.g. "cuda:0"
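

# --- Usage sketch (not part of the original file; an illustrative
# example assuming only the standard PretrainedConfig save/load API) ---
if __name__ == "__main__":
    config = BiLSTMConfig(hidden_dim=512, dropout=0.2)
    config.save_pretrained("./bilstm-config")  # writes config.json
    reloaded = BiLSTMConfig.from_pretrained("./bilstm-config")
    assert reloaded.hidden_dim == 512 and reloaded.dropout == 0.2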