from transformers import Pipeline
import numpy as np
import torch
import torch.nn.functional as F
import nltk
import pysbd  # sentence boundary detection, used in segment_and_trim_sentences
import re
import string

nltk.download("averaged_perceptron_tagger")
nltk.download("averaged_perceptron_tagger_eng")
from nltk.chunk import conlltags2tree
from nltk import pos_tag
from nltk.tree import Tree


def tokenize(text):
    # Pad every punctuation mark with spaces so it becomes its own token
    for punctuation in string.punctuation:
        text = text.replace(punctuation, " " + punctuation + " ")
    return text.split()


def normalize_text(text):
    # Remove spaces and tabs for the search but keep newline characters
    return re.sub(r"[ \t]+", "", text)


def find_entity_indices(article_text, search_text):
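    """Find every occurrence of search_text in article_text while ignoring
    spaces and tabs, and return a list of (start, end) character offsets into
    the original, unnormalized article_text."""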
    # Normalize texts by removing spaces and tabs
    normalized_article = normalize_text(article_text)
    normalized_search = normalize_text(search_text)

    # Initialize a list to hold all start and end indices
    indices = []

    # Find all occurrences of the search text in the normalized article text
    start_index = 0
    while True:
        start_index = normalized_article.find(normalized_search, start_index)
        if start_index == -1:
            break

        # Calculate the actual start and end indices in the original article text
        original_chars = 0
        original_start_index = 0
        for i in range(start_index):
            while article_text[original_start_index] in (" ", "\t"):
                original_start_index += 1
            if article_text[original_start_index] not in (" ", "\t", "\n"):
                original_chars += 1
            original_start_index += 1

        original_end_index = original_start_index
        search_chars = 0
        while search_chars < len(normalized_search):
            if article_text[original_end_index] not in (" ", "\t", "\n"):
                search_chars += 1
            original_end_index += 1  # Increment to include the last character

        # Append the found indices to the list
        if article_text[original_start_index] == " ":
            original_start_index += 1
        indices.append((original_start_index, original_end_index))

        # Move start_index to the next position to continue searching
        start_index += 1

    return indices


def get_entities(tokens, tags, confidences, text):
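    """Group word-level IOB tags into entity dicts.

    The (token, POS, tag) triples are assembled into a CoNLL tree; for each
    entity subtree, character offsets are recovered from the original text via
    find_entity_indices, and the score is the average confidence over the
    entity's tokens, expressed as a percentage.
    """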

    tags = [tag.replace("S-", "B-").replace("E-", "I-") for tag in tags]
    pos_tags = [pos for token, pos in pos_tag(tokens)]

    conlltags = [(token, pos, tg) for token, pos, tg in zip(tokens, pos_tags, tags)]
    ne_tree = conlltags2tree(conlltags)

    entities = []
    idx: int = 0
    already_done = []
    for subtree in ne_tree:
        # skipping 'O' tags
        if isinstance(subtree, Tree):
            original_label = subtree.label()
            original_string = " ".join([token for token, pos in subtree.leaves()])

            for indices in find_entity_indices(text, original_string):
                entity_start_position = indices[0]
                entity_end_position = indices[1]
                if (
                    "_".join(
                        [original_label, original_string, str(entity_start_position)]
                    )
                    in already_done
                ):
                    continue
                else:
                    already_done.append(
                        "_".join(
                            [
                                original_label,
                                original_string,
                                str(entity_start_position),
                            ]
                        )
                    )
                entities.append(
                    {
                        "entity": original_label,
                        "score": round(
                            np.average(confidences[idx : idx + len(subtree)]) * 100, 2
                        ),
                        "index": (idx, idx + len(subtree)),
                        "word": text[
                            entity_start_position:entity_end_position
                        ],  # original_string,
                        "start": entity_start_position,
                        "end": entity_end_position,
                    }
                )

            idx += len(subtree)

        else:
            # Not a named entity: advance the running token index by one
            token, pos = subtree
            idx += 1

    return entities


def realign(
    text_sentence, out_label_preds, softmax_scores, tokenizer, reverted_label_map
):
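    """Map sub-token (word-piece) predictions back onto whitespace-level words.

    Each word inherits the label and confidence of its first word-piece; words
    beyond the model's maximum input length fall back to "O" with confidence 0.
    """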
    preds_list, words_list, confidence_list = [], [], []
    word_ids = tokenizer(text_sentence, is_split_into_words=True).word_ids()
    for idx, word in enumerate(text_sentence):
        beginning_index = word_ids.index(idx)
        try:
            preds_list.append(reverted_label_map[out_label_preds[beginning_index]])
            confidence_list.append(max(softmax_scores[beginning_index]))
        except Exception:  # the sentence was longer than max_length; fall back to "O"
            preds_list.append("O")
            confidence_list.append(0.0)
        words_list.append(word)

    return words_list, preds_list, confidence_list


def segment_and_trim_sentences(article, language, max_length):
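    """Segment an article into sentences with pysbd (falling back to English if
    the requested language is unsupported) and hard-wrap any sentence longer
    than max_length characters, preferring to cut at the last space."""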

    try:
        segmenter = pysbd.Segmenter(language=language, clean=False)
    except Exception:
        # Fall back to English if pysbd does not support the requested language
        segmenter = pysbd.Segmenter(language="en", clean=False)

    sentences = segmenter.segment(article)

    trimmed_sentences = []
    for sentence in sentences:
        while len(sentence) > max_length:
            # Find the last space within max_length
            cut_index = sentence.rfind(" ", 0, max_length)
            if cut_index == -1:
                # If no space found, forcibly cut at max_length
                cut_index = max_length

            # Cut the sentence and add the first part to trimmed sentences
            trimmed_sentences.append(sentence[:cut_index])

            # Update the sentence to be the remaining part
            sentence = sentence[cut_index:].lstrip()

        # Add the remaining part of the sentence if it's not empty
        if sentence:
            trimmed_sentences.append(sentence)

    return trimmed_sentences


# Additional punctuation marks not included in string.punctuation
additional_punctuation = "‘’“”„«»•–—―‣◦…§¶†‡‰′″〈〉"


def add_spaces_around_punctuation(text):
    # Add a space before and after all punctuation
    all_punctuation = string.punctuation + additional_punctuation
    return re.sub(r"([{}])".format(re.escape(all_punctuation)), r" \1 ", text)


class MultitaskTokenClassificationPipeline(Pipeline):
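    """Custom Pipeline for a model with several token-classification heads.

    For each task exposed in model.config.label_map, the pipeline returns a
    list of entities with character offsets, surface forms and confidence
    scores computed from that task's logits.
    """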

    def _sanitize_parameters(self, **kwargs):
        preprocess_kwargs = {}
        if "text" in kwargs:
            preprocess_kwargs["text"] = kwargs["text"]
        self.label_map = self.model.config.label_map
        self.id2label = {
            task: {id_: label for label, id_ in labels.items()}
            for task, labels in self.label_map.items()
        }
        return preprocess_kwargs, {}, {}

    def preprocess(self, text, **kwargs):

        tokenized_inputs = self.tokenizer(
            text, padding="max_length", truncation=True, max_length=512
        )

        text_sentence = tokenize(add_spaces_around_punctuation(text))
        return tokenized_inputs, text_sentence, text

    def _forward(self, inputs):
        inputs, text_sentences, text = inputs
        input_ids = torch.tensor([inputs["input_ids"]], dtype=torch.long).to(
            self.model.device
        )
        attention_mask = torch.tensor([inputs["attention_mask"]], dtype=torch.long).to(
            self.model.device
        )
        with torch.no_grad():
            outputs = self.model(input_ids, attention_mask)
        return outputs, text_sentences, text

    def is_within(self, entity1, entity2):
        """Check if entity1 is fully within the bounds of entity2."""
        return entity1["start"] >= entity2["start"] and entity1["end"] <= entity2["end"]

    def postprocess_entities(self, ner_results):
        # Collect all entities in one list for processing
        all_entities = []
        for key in ner_results:
            all_entities.extend(ner_results[key])

        # Sort entities by start position, then by end position (to handle nested structures)
        all_entities.sort(key=lambda x: (x["start"], -x["end"]))

        # Create a new list for final processed entities
        final_entities = []

        # Process each entity and check for nesting
        for entity in all_entities:
            nested = False

            # Compare the current entity with already processed entities
            for parent_entity in final_entities:
                if self.is_within(entity, parent_entity):
                    # If the current entity is nested, add it as a field in the parent entity
                    field_name = entity["entity"].split(".")[
                        -1
                    ]  # Last part of the label as the field
                    if field_name not in parent_entity:
                        parent_entity[field_name] = []
                    parent_entity[field_name].append(entity)
                    nested = True
                    break

            if not nested:
                # If not nested, add the entity as a new outermost entity
                final_entities.append(entity)

        return final_entities

    def postprocess(self, outputs, **kwargs):
        """
        Postprocess the outputs of the model
        :param outputs:
        :param kwargs:
        :return:
        """
        tokens_result, text_sentence, text = outputs

        predictions = {}
        confidence_scores = {}
        for task, logits in tokens_result.logits.items():
            predictions[task] = torch.argmax(logits, dim=-1).tolist()[0]
            confidence_scores[task] = F.softmax(logits, dim=-1).tolist()[0]

        entities = {}
        for task in predictions.keys():
            words_list, preds_list, confidence_list = realign(
                text_sentence,
                predictions[task],
                confidence_scores[task],
                self.tokenizer,
                self.id2label[task],
            )

            entities[task] = get_entities(words_list, preds_list, confidence_list, text)

        print(self.postprocess_entities(entities))
        return entities
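

# Minimal usage sketch (illustrative only). "org/multitask-ner-model" is a
# placeholder, not a real checkpoint: any model whose config defines
# `label_map` and whose forward pass returns one set of logits per task is
# assumed here.
if __name__ == "__main__":
    from transformers import AutoModel, AutoTokenizer

    model_name = "org/multitask-ner-model"  # hypothetical checkpoint name
    model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    nlp = MultitaskTokenClassificationPipeline(model=model, tokenizer=tokenizer)
    print(nlp("Albert Einstein was born in Ulm, Germany."))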