huseinzol05 committed
Commit
246f8ef
1 Parent(s): e9f61b1

Upload MistralForSequenceClassification

Files changed (2)
  1. classifier.py +88 -0
  2. config.json +4 -5
classifier.py ADDED
@@ -0,0 +1,88 @@
+ from bidirectional_mistral import MistralBiModel
+ from transformers import MistralPreTrainedModel
+ import torch
+ import numpy as np
+ from typing import Optional, List
+ from torch import nn
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+ from transformers.modeling_outputs import SequenceClassifierOutputWithPast
+
+
+ class MistralForSequenceClassification(MistralPreTrainedModel):
+     def __init__(self, config):
+         super().__init__(config)
+         self.num_labels = config.num_labels
+         self.model = MistralBiModel(config)
+         self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ):
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+             Labels for computing the sequence classification/regression loss. Indices should be in
+             `[0, ..., config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is
+             computed (Mean-Square loss); if `config.num_labels > 1` a classification loss is
+             computed (Cross-Entropy).
+         """
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         transformer_outputs = self.model(
+             input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+         # First-token pooling: position 0 sees the whole sequence here
+         # because MistralBiModel attends bidirectionally.
+         pooled_output = transformer_outputs[0][:, 0]
+         logits = self.score(pooled_output)
+
+         loss = None
+         if labels is not None:
+             # Infer the problem type from num_labels and the label dtype
+             # when it is not set explicitly on the config.
+             if self.config.problem_type is None:
+                 if self.num_labels == 1:
+                     self.config.problem_type = "regression"
+                 elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                     self.config.problem_type = "single_label_classification"
+                 else:
+                     self.config.problem_type = "multi_label_classification"
+
+             if self.config.problem_type == "regression":
+                 loss_fct = MSELoss()
+                 if self.num_labels == 1:
+                     loss = loss_fct(logits.squeeze(), labels.squeeze())
+                 else:
+                     loss = loss_fct(logits, labels)
+             elif self.config.problem_type == "single_label_classification":
+                 loss_fct = CrossEntropyLoss()
+                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+             elif self.config.problem_type == "multi_label_classification":
+                 loss_fct = BCEWithLogitsLoss()
+                 loss = loss_fct(logits, labels)
+         if not return_dict:
+             output = (logits,) + transformer_outputs[2:]
+             return ((loss,) + output) if loss is not None else output
+
+         return SequenceClassifierOutputWithPast(
+             loss=loss,
+             logits=logits,
+             past_key_values=transformer_outputs.past_key_values,
+             hidden_states=transformer_outputs.hidden_states,
+             attentions=transformer_outputs.attentions,
+         )
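
For orientation, a minimal forward-pass sketch of the head above. It assumes classifier.py and bidirectional_mistral.py from this repo are importable; num_hidden_layers, num_labels, and the inputs are placeholder assumptions, while hidden_size, intermediate_size, num_attention_heads, and max_position_embeddings match the config.json change below:

import torch
from transformers import MistralConfig

from classifier import MistralForSequenceClassification

# hidden_size / intermediate_size / num_attention_heads /
# max_position_embeddings mirror config.json; num_hidden_layers and
# num_labels are assumptions made only for this sketch.
config = MistralConfig(
    hidden_size=768,
    intermediate_size=3072,
    num_attention_heads=16,
    max_position_embeddings=4096,
    num_hidden_layers=2,
    num_labels=2,
)

model = MistralForSequenceClassification(config)  # randomly initialized
model.eval()

input_ids = torch.randint(0, config.vocab_size, (1, 16))
attention_mask = torch.ones_like(input_ids)

with torch.no_grad():
    out = model(input_ids=input_ids, attention_mask=attention_mask)

print(out.logits.shape)  # torch.Size([1, 2])

Note the design choice: unlike the causal MistralForSequenceClassification in transformers, which pools the last non-padding token, this variant pools the first token, which is valid only because the backbone is bidirectional.
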
config.json CHANGED
@@ -1,19 +1,18 @@
  {
- "_name_or_path": "mistral-191M-mlm/checkpoint-106000",
+ "_name_or_path": "mnli-mistral-191M-MLM",
  "architectures": [
    "MistralForSequenceClassification"
  ],
  "attention_dropout": 0.0,
+ "auto_map": {
+   "AutoModel": "classifier.MistralForSequenceClassification"
+ },
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_act": "silu",
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
- "label2id": {
-   "contradiction": 0,
-   "entailment": 1
- },
  "max_position_embeddings": 4096,
  "model_type": "mistral",
  "num_attention_heads": 16,