import json
import logging
import os
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from transformers import AutoModel, PreTrainedModel
from transformers.modeling_outputs import TokenClassifierOutput

from .configuration_stacked import ImpressoConfig

logger = logging.getLogger(__name__)


def get_info(label_map):
    """Return a dict mapping each task name to its number of token labels."""
    num_token_labels_dict = {task: len(labels) for task, labels in label_map.items()}
    return num_token_labels_dict
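
# Illustrative note (assumption, not taken from this repository): label_map.json
# is expected to map each task name to its list of labels, e.g.
#
#     {
#         "ner": ["O", "B-PER", "I-PER", "B-LOC", "I-LOC"],
#         "nel": ["O", "B-ENT", "I-ENT"]
#     }
#
# in which case get_info(label_map) would return {"ner": 5, "nel": 3}.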


class ExtendedMultitaskModelForTokenClassification(PreTrainedModel):

    config_class = ImpressoConfig
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def __init__(self, config):
        super().__init__(config)
        print("Current folder path:", os.path.dirname(os.path.abspath(__file__)))
        # Get the directory of the current script
        current_dir = os.path.dirname(os.path.abspath(__file__))
        # Construct the full path to label_map.json
        label_map_path = os.path.join(current_dir, "label_map.json")

        label_map = json.load(open(label_map_path, "r"))
        self.num_token_labels_dict = get_info(label_map)
        self.config = config

        import pdb

        pdb.set_trace()
        self.bert = AutoModel.from_pretrained(
            config.pretrained_config["_name_or_path"], config=config.pretrained_config
        )
        if "classifier_dropout" not in config.__dict__:
            classifier_dropout = 0.1
        else:
            classifier_dropout = (
                config.classifier_dropout
                if config.classifier_dropout is not None
                else config.hidden_dropout_prob
            )
        self.dropout = nn.Dropout(classifier_dropout)

        # Additional transformer layers
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=config.hidden_size, nhead=config.num_attention_heads
            ),
            num_layers=2,
        )

        # For token classification, create a classifier for each task
        self.token_classifiers = nn.ModuleDict(
            {
                task: nn.Linear(config.hidden_size, num_labels)
                for task, num_labels in self.num_token_labels_dict.items()
            }
        )

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        token_labels: Optional[dict] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
        r"""
        token_labels (`dict` of `torch.LongTensor` of shape `(batch_size, seq_length)`, *optional*):
            Labels for computing the token classification loss. Keys should match the tasks.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        bert_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
            "position_ids": position_ids,
            "head_mask": head_mask,
            "inputs_embeds": inputs_embeds,
            "output_attentions": output_attentions,
            "output_hidden_states": output_hidden_states,
            "return_dict": return_dict,
        }

        # Some backbones (e.g. LLaMA, DeBERTa) do not accept all of these
        # arguments, so drop token_type_ids and head_mask before the call.
        if any(
            keyword in self.config.name_or_path.lower()
            for keyword in ["llama", "deberta"]
        ):
            bert_kwargs.pop("token_type_ids")
            bert_kwargs.pop("head_mask")

        outputs = self.bert(**bert_kwargs)

        # For token classification
        token_output = outputs[0]
        token_output = self.dropout(token_output)

        # Pass through the additional transformer layers. nn.TransformerEncoder
        # expects (seq_len, batch, hidden) by default, hence the transposes.
        token_output = self.transformer_encoder(
            token_output.transpose(0, 1)
        ).transpose(0, 1)

        # Collect the logits and compute the loss for each task
        task_logits = {}
        total_loss = None
        for task, classifier in self.token_classifiers.items():
            logits = classifier(token_output)
            task_logits[task] = logits
            if token_labels and task in token_labels:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(
                    logits.view(-1, self.num_token_labels_dict[task]),
                    token_labels[task].view(-1),
                )
                total_loss = loss if total_loss is None else total_loss + loss

        if not return_dict:
            output = (task_logits,) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return TokenClassifierOutput(
            loss=total_loss,
            logits=task_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
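
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; the checkpoint path and example text
# are placeholders, not values taken from this repository):
#
#     from transformers import AutoTokenizer
#
#     config = ImpressoConfig.from_pretrained("path/to/checkpoint")
#     model = ExtendedMultitaskModelForTokenClassification.from_pretrained(
#         "path/to/checkpoint", config=config
#     )
#     tokenizer = AutoTokenizer.from_pretrained("path/to/checkpoint")
#
#     inputs = tokenizer("Albert Einstein lived in Bern.", return_tensors="pt")
#     outputs = model(**inputs)
#     # outputs.logits is a dict mapping each task name to a tensor of shape
#     # (batch_size, seq_length, num_labels_for_that_task).
# ---------------------------------------------------------------------------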