File size: 623 Bytes
6c333b4 c9dfa9e 6c333b4 c9dfa9e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 |
from transformers import PretrainedConfig
import torch
class BiLSTMConfig(PretrainedConfig):
    """Configuration object holding the hyperparameters of a BiLSTM model.

    Every BiLSTM-specific constructor argument is stored unchanged as an
    attribute of the same name. All other keyword arguments are handed to
    ``PretrainedConfig.__init__``.

    Args:
        vocab_size: Presumably the number of token ids the embedding layer
            accepts (default 23626) -- confirm against the model code.
        embed_dim: Presumably the embedding vector size (default 100).
        num_layers: Presumably the number of stacked LSTM layers (default 1).
        hidden_dim: Presumably the LSTM hidden-state size (default 256).
        dropout: Dropout value stored for the model (default 0.33).
        output_dim: Presumably the size of an intermediate projection
            (default 128) -- verify against the model definition.
        predict_output: Presumably the number of prediction targets/classes
            (default 10) -- verify against the model definition.
        device: Device identifier stored on the config (default "cuda:0").
        **kwargs: Forwarded verbatim to ``PretrainedConfig.__init__``.
    """

    def __init__(
        self,
        vocab_size: int = 23626,
        embed_dim: int = 100,
        num_layers: int = 1,
        hidden_dim: int = 256,
        dropout: float = 0.33,
        output_dim: int = 128,
        predict_output: int = 10,
        device="cuda:0",
        **kwargs,
    ):
        # The base class handles the generic options first. The BiLSTM-specific
        # fields below are assigned afterwards, in the same order as before.
        super().__init__(**kwargs)

        # Input side: vocabulary and embedding sizes.
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim

        # Recurrent stack.
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.dropout = dropout

        # Output side.
        self.output_dim = output_dim
        self.predict_output = predict_output

        # NOTE(review): the "cuda:0" default is hard-coded. If the model moves
        # tensors to this device, it will fail on CPU-only hosts. Confirm how
        # callers consume `device`.
        self.device = device