from transformers import Pipeline
import nltk
import requests
import torch

# Make sure the NLTK POS-tagger resources are available (no-op if already downloaded)
nltk.download("averaged_perceptron_tagger")
nltk.download("averaged_perceptron_tagger_eng")

NEL_MODEL = "nel-mgenre-multilingual"


def get_wikipedia_page_props(input_str: str):
    """
    Retrieves the Wikidata QID for a given Wikipedia page name from the
    specified language Wikipedia, via the MediaWiki pageprops API.

    Args:
        input_str (str): The input string in the format "page_name >> language".

    Returns:
        str: The QID, or "NIL" if the QID cannot be determined.
    """
    try:
        # Split the prediction into page name and language code
        page_name, language = input_str.split(" >> ")
        page_name = page_name.strip()
        language = language.strip()
    except ValueError:
        # Malformed input: no " >> " separator
        return "NIL"

    wikipedia_url = f"https://{language}.wikipedia.org/w/api.php"
    wikipedia_params = {
        "action": "query",
        "prop": "pageprops",
        "format": "json",
        "titles": page_name,
    }

    qid = "NIL"
    try:
        # Query the Wikipedia API for the page properties
        response = requests.get(wikipedia_url, params=wikipedia_params, timeout=10)
        response.raise_for_status()
        data = response.json()

        pages = data["query"]["pages"]
        page_id = list(pages.keys())[0]
        # "wikibase_item" holds the Wikidata QID when the page is linked to Wikidata
        qid = pages[page_id].get("pageprops", {}).get("wikibase_item", "NIL")
    except Exception:
        # Network errors or unexpected response shapes fall back to "NIL"
        pass

    return qid


def get_wikipedia_title(qid, language="en"):
    """
    Resolves a Wikidata QID to the corresponding Wikipedia page title and URL
    in the requested language. Returns ("NIL", "None") if no sitelink exists.
    """
    url = "https://www.wikidata.org/w/api.php"
    params = {
        "action": "wbgetentities",
        "format": "json",
        "ids": qid,
        "props": "sitelinks/urls",
        "sitefilter": f"{language}wiki",
    }

    response = requests.get(url, params=params, timeout=10)
    data = response.json()

    try:
        title = data["entities"][qid]["sitelinks"][f"{language}wiki"]["title"]
        url = data["entities"][qid]["sitelinks"][f"{language}wiki"]["url"]
        return title, url
    except KeyError:
        return "NIL", "None"
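
# Example of how the two helpers compose (illustrative values only; "Paris >> fr"
# mimics a model prediction and Q90 is the Wikidata item for Paris):
#
#     qid = get_wikipedia_page_props("Paris >> fr")          # -> "Q90"
#     title, url = get_wikipedia_title(qid, language="en")   # -> ("Paris", "https://en.wikipedia.org/wiki/Paris")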


class NelPipeline(Pipeline):

    def _sanitize_parameters(self, **kwargs):
        preprocess_kwargs = {}
        if "text" in kwargs:
            preprocess_kwargs["text"] = kwargs["text"]

        return preprocess_kwargs, {}, {}

    def preprocess(self, text, **kwargs):
        # Extract the entity between [START] and [END]
        start_token = "[START]"
        end_token = "[END]"

        if start_token in text and end_token in text:
            start_idx = text.index(start_token) + len(start_token)
            end_idx = text.index(end_token)
            enclosed_entity = text[start_idx:end_idx].strip()
            lOffset = start_idx  # left offset (start of the entity)
            rOffset = end_idx  # right offset (end of the entity)
        else:
            enclosed_entity = None
            lOffset = None
            rOffset = None

        # Generate the "page_name >> language" prediction with the model
        outputs = self.model.generate(
            **self.tokenizer([text], return_tensors="pt").to(self.device),
            num_beams=1,
            num_return_sequences=1,
            max_new_tokens=30,
            return_dict_in_generate=True,
            output_scores=True,
        )

        token_ids = outputs.sequences

        # Greedy decoding (num_beams=1) does not populate `sequences_scores`,
        # so derive a per-sequence confidence from the per-step token scores:
        # the mean log-probability of the generated tokens.
        transition_scores = self.model.compute_transition_scores(
            outputs.sequences, outputs.scores, normalize_logits=True
        )
        seq_log_probs = transition_scores.mean(dim=-1)

        # Exponentiate to get the (geometric-mean) token probability, in percent
        probabilities = torch.exp(seq_log_probs)
        percentages = (probabilities * 100.0).cpu().numpy().tolist()

        # Decode the generated token ids into readable "page_name >> language" strings
        wikipedia_predictions = self.tokenizer.batch_decode(
            token_ids, skip_special_tokens=True
        )

        # Return the predictions along with the extracted entity, lOffset, and rOffset
        return wikipedia_predictions, enclosed_entity, lOffset, rOffset, percentages

    def _forward(self, inputs):
        # Generation already happens in `preprocess`; pass the results through unchanged.
        return inputs

    def postprocess(self, outputs, **kwargs):
        """
        Turn the raw generation output into a list of entity-linking results.

        Each result is a dict with the entity surface form ("surface"), the
        resolved Wikipedia page name and URL, the Wikidata QID ("NIL" when it
        cannot be resolved), the entity type ("UNK"), the character offsets of
        the entity markers, and the model confidence as a percentage.
        """
        wikipedia_predictions, enclosed_entity, lOffset, rOffset, percentages = outputs
        results = []
        for idx, wikipedia_name in enumerate(wikipedia_predictions):
            # Resolve the predicted "page_name >> language" string to a Wikidata QID
            qid = get_wikipedia_page_props(wikipedia_name)

            # Get Wikipedia title and URL
            wkpedia_pagename, url = get_wikipedia_title(qid)
            results.append(
                {
                    # "id": f"{lOffset}:{rOffset}:{enclosed_entity}:{NEL_MODEL}",
                    "surface": enclosed_entity,
                    "wkpedia_pagename": wkpedia_pagename,
                    "wkd_id": qid,
                    "url": url,
                    "type": "UNK",
                    "confidence_nel": percentages[idx],
                    "lOffset": lOffset,
                    "rOffset": rOffset,
                }
            )

        return results
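

# ---------------------------------------------------------------------------
# Minimal usage sketch (assumptions: the checkpoint id below is a placeholder
# for an mGENRE-style seq2seq NEL model, and the pipeline is instantiated
# directly rather than through `transformers.pipeline`).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    model_id = "path-or-hub-id-of-the-nel-checkpoint"  # placeholder, replace as needed
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_id)

    nel = NelPipeline(model=model, tokenizer=tokenizer)

    # The entity to link is wrapped in [START] ... [END] markers.
    sentence = "The [START] Eiffel Tower [END] was completed in 1889 in Paris."
    for prediction in nel(sentence):
        print(prediction)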