import logging
from typing import Optional

import torch
from torch import nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput

from .configuration_minGRU import MinGRUConfig
from minGRU_pytorch.minGRU import minGRU

logger = logging.getLogger(__name__)


class MinGRUWrapped(nn.Module):
    """Thin wrapper that moves tensor inputs onto the wrapped minGRU's device.

    The target device is read from the wrapped module's parameters on every
    call, so it always follows ``.to()`` / ``.cuda()`` / ``.cpu()``. This holds
    even when those are called on a parent module, which moves children via
    ``_apply`` and never calls a child's ``to`` method.
    """

    def __init__(self, min_gru_model: nn.Module):
        super().__init__()
        self.min_gru_model = min_gru_model

    @property
    def device(self) -> torch.device:
        """Device of the wrapped module's parameters (CPU if it has none)."""
        first_param = next(self.min_gru_model.parameters(), None)
        return first_param.device if first_param is not None else torch.device("cpu")

    def forward(self, *args, **kwargs):
        """Move every tensor argument to ``self.device`` and call the wrapped module."""
        device = self.device
        args = tuple(arg.to(device) if isinstance(arg, torch.Tensor) else arg for arg in args)
        kwargs = {
            key: value.to(device) if isinstance(value, torch.Tensor) else value
            for key, value in kwargs.items()
        }
        return self.min_gru_model(*args, **kwargs)


class MinGRUPreTrainedModel(PreTrainedModel):
    """Shared base class hooking MinGRU models into the ``PreTrainedModel`` API."""

    config_class = MinGRUConfig
    base_model_prefix = "model"

    def _init_weights(self, module: nn.Module) -> None:
        """Initialise the parameters owned directly by ``module``.

        transformers invokes this hook for each submodule (via ``post_init`` and
        for weights missing from a checkpoint in ``from_pretrained``).

        Initialisation rules:

        * Linear / Embedding weights ~ N(0, ``config.initializer_range``).
        * Linear biases and the Embedding padding row are set to zero.
        * LayerNorm starts as the identity (weight 1, bias 0).

        Any remaining NaN in the module's own parameters is replaced by a small
        finite value.
        """
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            # LayerNorm(elementwise_affine=False) (or bias=False) has None here.
            if module.bias is not None:
                module.bias.data.zero_()
            if module.weight is not None:
                module.weight.data.fill_(1.0)

        # Safety net against NaN parameters. recurse=False: this hook is applied
        # to every submodule anyway, so recursing would rescan the same tensors
        # once per ancestor.
        for name, param in module.named_parameters(recurse=False):
            if param.is_meta:
                # Meta tensors carry no data; inspecting them would raise.
                continue
            if torch.isnan(param).any():
                logger.warning("NaN detected in parameter %s. Replacing with a safe number.", name)
                param.data = torch.nan_to_num(param.data, nan=1e-6)


class MinGRUForSequenceClassification(MinGRUPreTrainedModel):
    """Embedding -> minGRU -> mean pooling over the sequence -> linear classifier."""

    config_class = MinGRUConfig
    base_model_prefix = "model"

    def __init__(self, config: MinGRUConfig):
        super().__init__(config)
        self.embedding = nn.Embedding(config.vocab_size, config.d_model)
        raw_min_gru = minGRU(dim=config.d_model, expansion_factor=config.ff_mult)
        self.model = MinGRUWrapped(raw_min_gru)
        # Final linear layer mapping the pooled representation to class logits.
        self.classifier = nn.Linear(config.d_model, config.num_labels)
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = True,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """Classify a batch of token sequences.

        Args:
            input_ids: token ids fed to ``self.embedding``.
            labels: optional class indices; when given, cross-entropy loss is
                computed against the logits.
            return_dict: return a ``SequenceClassifierOutput`` when true, a
                tuple otherwise. ``None`` falls back to
                ``config.use_return_dict`` (True if the config lacks it).
            attention_mask: optional mask (1 = real token, 0 = padding). When
                given, padded positions are excluded from the mean pooling.
                NOTE(review): the recurrence still runs over padded positions,
                so left padding can still influence later hidden states.
            **kwargs: accepted and ignored, for Trainer/pipeline compatibility.

        Returns:
            ``SequenceClassifierOutput(loss, logits)``, or ``(loss, logits)`` /
            ``(logits,)`` when ``return_dict`` is false.
        """
        if return_dict is None:
            return_dict = getattr(self.config, "use_return_dict", True)

        embeddings = self.embedding(input_ids)
        # The original code pooled over dim=1 and fed the result into a d_model
        # -> num_labels Linear, so this is (batch, seq, d_model).
        hidden_states = self.model(embeddings)

        if attention_mask is None:
            pooled_output = hidden_states.mean(dim=1)
        else:
            mask = attention_mask.to(
                device=hidden_states.device, dtype=hidden_states.dtype
            ).unsqueeze(-1)
            # clamp avoids 0/0 for an all-padding row.
            pooled_output = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)

        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            labels = labels.to(logits.device)
            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))

        if not return_dict:
            return (loss, logits) if loss is not None else (logits,)

        return SequenceClassifierOutput(loss=loss, logits=logits)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Load a model from a checkpoint, then repair any non-finite parameter.

        Any parameter that still contains NaN/inf after loading is
        re-initialised:

        * xavier-normal for tensors with 2 or more dimensions;
        * zeros for 1-D tensors such as biases.

        A warning is logged for each repaired parameter. When
        ``output_loading_info=True`` the ``(model, loading_info)`` tuple from
        the parent method is returned unchanged apart from the repair.
        """
        loaded = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        model = loaded[0] if isinstance(loaded, tuple) else loaded

        with torch.no_grad():
            for name, param in model.named_parameters():
                if param.is_meta or torch.isfinite(param).all():
                    continue
                if param.dim() >= 2:
                    nn.init.xavier_normal_(param)  # Initialisation step
                else:
                    param.zero_()
                logger.warning("Initialized parameter %s manually.", name)

        return loaded

    def save_pretrained(self, save_directory, safe_serialization: Optional[bool] = True, **kwargs):
        """Save the model weights and configuration to a directory.

        Delegates to ``PreTrainedModel.save_pretrained``, so every parameter
        (including ``embedding.weight``) is stored under its ``state_dict()``
        key and ``from_pretrained`` can restore it.

        Args:
            save_directory (str): Directory to save the model into.
            safe_serialization (bool, optional): Write ``safetensors`` when
                true, a pickled ``pytorch_model.bin`` otherwise. Defaults to
                True; ``None`` is treated as True.
            kwargs: Forwarded to ``PreTrainedModel.save_pretrained`` (e.g.
                ``max_shard_size``).
        """
        # NOTE: a previous version of this method wrote only the minGRU and
        # classifier weights, under "model.<name>" keys that do not match the
        # real "model.min_gru_model.<name>" keys, and omitted embedding.weight.
        # Checkpoints produced that way cannot be fully restored.
        if safe_serialization is None:
            safe_serialization = True
        return super().save_pretrained(
            save_directory, safe_serialization=safe_serialization, **kwargs
        )