suayptalha commited on
Commit
0d741ac
·
verified ·
1 Parent(s): b975b95

Update modeling_minGRULM.py

Browse files
Files changed (1) hide show
  1. modeling_minGRULM.py +30 -7
modeling_minGRULM.py CHANGED
@@ -8,6 +8,26 @@ from .configuration_minGRULM import MinGRULMConfig
8
  from minGRU_pytorch.minGRULM import minGRULM
9
 
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  class MinGRULMPreTrainedModel(PreTrainedModel):
12
  config_class = MinGRULMConfig
13
  base_model_prefix = "model"
@@ -28,24 +48,27 @@ class MinGRULMForCausalLM(MinGRULMPreTrainedModel):
28
  def __init__(self, config: MinGRULMConfig):
29
  super().__init__(config)
30
 
31
- # Load model from minGRULM library
32
- self.model = minGRULM(
33
  num_tokens=config.vocab_size,
34
  dim=config.d_model,
35
  depth=config.n_layer,
36
  ff_mult=config.ff_mult,
37
- min_gru_expansion=config.expand,
38
  enable_conv=config.enable_conv,
39
  )
 
40
 
 
41
  self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
 
42
  self.post_init()
43
 
44
  def get_input_embeddings(self):
45
- return self.model.token_emb
46
 
47
  def set_input_embeddings(self, value):
48
- self.model.token_emb = value
49
 
50
  def get_output_embeddings(self):
51
  return self.lm_head
@@ -56,7 +79,7 @@ class MinGRULMForCausalLM(MinGRULMPreTrainedModel):
56
  labels: Optional[torch.LongTensor] = None,
57
  return_dict: Optional[bool] = True,
58
  ):
59
- # Forward pass through the model
60
  logits = self.model(input_ids)
61
 
62
  loss = None
@@ -75,4 +98,4 @@ class MinGRULMForCausalLM(MinGRULMPreTrainedModel):
75
  return CausalLMOutputWithPast(
76
  loss=loss,
77
  logits=logits,
78
- )
 
8
  from minGRU_pytorch.minGRULM import minGRULM
9
 
10
 
11
+ # Wrapper class for device compatibility
12
+ class MinGRULMWrapped(nn.Module):
13
+ def __init__(self, min_gru_model):
14
+ super().__init__()
15
+ self.min_gru_model = min_gru_model
16
+ self.device = torch.device("cpu") # Default device
17
+
18
+ def forward(self, *args, **kwargs):
19
+ # Move input tensors to the correct device
20
+ args = [arg.to(self.device) if isinstance(arg, torch.Tensor) else arg for arg in args]
21
+ kwargs = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
22
+ return self.min_gru_model(*args, **kwargs)
23
+
24
+ def to(self, device):
25
+ # Update device information
26
+ self.device = device
27
+ self.min_gru_model.to(device)
28
+ return self
29
+
30
+
31
  class MinGRULMPreTrainedModel(PreTrainedModel):
32
  config_class = MinGRULMConfig
33
  base_model_prefix = "model"
 
48
  def __init__(self, config: MinGRULMConfig):
49
  super().__init__(config)
50
 
51
+ # Load model from minGRULM library and wrap it
52
+ raw_min_gru = minGRULM(
53
  num_tokens=config.vocab_size,
54
  dim=config.d_model,
55
  depth=config.n_layer,
56
  ff_mult=config.ff_mult,
57
+ min_gru_expansion=config.min_gru_expansion,
58
  enable_conv=config.enable_conv,
59
  )
60
+ self.model = MinGRULMWrapped(raw_min_gru)
61
 
62
+ # Language modeling head
63
  self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
64
+
65
  self.post_init()
66
 
67
  def get_input_embeddings(self):
68
+ return self.model.min_gru_model.token_emb
69
 
70
  def set_input_embeddings(self, value):
71
+ self.model.min_gru_model.token_emb = value
72
 
73
  def get_output_embeddings(self):
74
  return self.lm_head
 
79
  labels: Optional[torch.LongTensor] = None,
80
  return_dict: Optional[bool] = True,
81
  ):
82
+ # Forward pass through the wrapped model
83
  logits = self.model(input_ids)
84
 
85
  loss = None
 
98
  return CausalLMOutputWithPast(
99
  loss=loss,
100
  logits=logits,
101
+ )