from transformers import Pipeline
import nltk
import requests
import torch
nltk.download("averaged_perceptron_tagger")
nltk.download("averaged_perceptron_tagger_eng")
NEL_MODEL = "nel-mgenre-multilingual"
def get_wikipedia_page_props(input_str: str):
    """
    Retrieve the Wikidata QID for a given Wikipedia page name in the specified
    language edition, using the MediaWiki API.

    Args:
        input_str (str): The input string in the format "page_name >> language".

    Returns:
        str: The QID, or "NIL" if the QID cannot be found.
    """
    try:
        # Split the model prediction into page title and language code
        page_name, language = input_str.split(" >> ")
        page_name = page_name.strip()
        language = language.strip()
    except ValueError:
        # Malformed prediction: treat it as an unlinkable entity
        return "NIL"
wikipedia_url = f"https://{language}.wikipedia.org/w/api.php"
wikipedia_params = {
"action": "query",
"prop": "pageprops",
"format": "json",
"titles": page_name,
}
qid = "NIL"
try:
# Attempt to fetch from Wikipedia API
response = requests.get(wikipedia_url, params=wikipedia_params)
response.raise_for_status()
data = response.json()
if "pages" in data["query"]:
page_id = list(data["query"]["pages"].keys())[0]
if "pageprops" in data["query"]["pages"][page_id]:
page_props = data["query"]["pages"][page_id]["pageprops"]
if "wikibase_item" in page_props:
return page_props["wikibase_item"]
else:
return qid
else:
return qid
except Exception as e:
return qid
def get_wikipedia_title(qid, language="en"):
    """
    Resolve a Wikidata QID to the corresponding Wikipedia page title and URL
    in the requested language. Returns ("NIL", "None") if no sitelink exists.
    """
    url = "https://www.wikidata.org/w/api.php"
    params = {
        "action": "wbgetentities",
        "format": "json",
        "ids": qid,
        "props": "sitelinks/urls",
        "sitefilter": f"{language}wiki",
    }
response = requests.get(url, params=params)
data = response.json()
try:
title = data["entities"][qid]["sitelinks"][f"{language}wiki"]["title"]
url = data["entities"][qid]["sitelinks"][f"{language}wiki"]["url"]
return title, url
except KeyError:
return "NIL", "None"
class NelPipeline(Pipeline):
def _sanitize_parameters(self, **kwargs):
preprocess_kwargs = {}
if "text" in kwargs:
preprocess_kwargs["text"] = kwargs["text"]
return preprocess_kwargs, {}, {}
def preprocess(self, text, **kwargs):
# Extract the entity between [START] and [END]
start_token = "[START]"
end_token = "[END]"
if start_token in text and end_token in text:
start_idx = text.index(start_token) + len(start_token)
end_idx = text.index(end_token)
enclosed_entity = text[start_idx:end_idx].strip()
lOffset = start_idx # left offset (start of the entity)
rOffset = end_idx # right offset (end of the entity)
else:
enclosed_entity = None
lOffset = None
rOffset = None
# Generate predictions using the model
outputs = self.model.generate(
**self.tokenizer([text], return_tensors="pt").to(self.device),
num_beams=1,
num_return_sequences=1,
max_new_tokens=30,
return_dict_in_generate=True,
output_scores=True,
)
token_ids, scores = outputs.sequences, outputs.sequences_scores
# Process scores and normalize
scores_tensor = scores.clone().detach()
probabilities = torch.exp(scores_tensor)
percentages = (probabilities * 100.0).cpu().numpy().tolist()
        # Decode the generated token ids into "page_name >> language" predictions
        wikipedia_predictions = self.tokenizer.batch_decode(
            token_ids, skip_special_tokens=True
        )
# Return the predictions along with the extracted entity, lOffset, and rOffset
return wikipedia_predictions, enclosed_entity, lOffset, rOffset, percentages
    def _forward(self, inputs):
        # Generation already happens in preprocess; just pass the results through
        return inputs
    def postprocess(self, outputs, **kwargs):
        """
        Turn the raw generation outputs into a list of linked-entity dictionaries,
        each with the surface form, Wikipedia page name, Wikidata QID, URL,
        character offsets, entity type and the model's confidence.
        """
wikipedia_predictions, enclosed_entity, lOffset, rOffset, percentages = outputs
results = []
        for idx, wikipedia_name in enumerate(wikipedia_predictions):
            # Map the predicted "page_name >> language" string to a Wikidata QID
            qid = get_wikipedia_page_props(wikipedia_name)
            # Resolve the QID to an English Wikipedia title and URL
            wkpedia_pagename, url = get_wikipedia_title(qid)
results.append(
{
# "id": f"{lOffset}:{rOffset}:{enclosed_entity}:{NEL_MODEL}",
"surface": enclosed_entity,
"wkpedia_pagename": wkpedia_pagename,
"wkd_id": qid,
"url": url,
"type": "UNK",
"confidence_nel": percentages[idx],
"lOffset": lOffset,
"rOffset": rOffset,
}
)
return results
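

if __name__ == "__main__":
    # Minimal usage sketch, not part of the pipeline itself. It assumes the pipeline
    # is used with an mGENRE-style seq2seq checkpoint; the checkpoint id below and
    # the direct instantiation of NelPipeline are illustrative assumptions, not
    # values taken from this repository.
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    model_id = "impresso-project/nel-mgenre-multilingual"  # hypothetical checkpoint id
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_id)

    nel = NelPipeline(model=model, tokenizer=tokenizer)

    # The entity to link must be wrapped in [START] ... [END], as preprocess() expects
    sample = "The capital of [START] Switzerland [END] is Bern."
    for prediction in nel(sample):
        print(prediction["surface"], prediction["wkd_id"], prediction["url"])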