emanuelaboros committed
Commit 6ff6c37
1 Parent(s): 9f3ce07

Create handler.py

Files changed (1)
  handler.py  +164 -0
handler.py ADDED
@@ -0,0 +1,164 @@
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from typing import List, Dict, Any
import requests
import nltk

# Download required NLTK models
nltk.download("averaged_perceptron_tagger")
nltk.download("averaged_perceptron_tagger_eng")

# Define your model name
NEL_MODEL = "nel-mgenre-multilingual"

class NelPipeline:
    def __init__(self, model_name: str):
        self.model_name = model_name
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(self.device)

    def preprocess(self, text: str):
        # The entity to link is expected to be wrapped in [START] ... [END] markers.
        start_token = "[START]"
        end_token = "[END]"

        if start_token in text and end_token in text:
            start_idx = text.index(start_token) + len(start_token)
            end_idx = text.index(end_token)
            enclosed_entity = text[start_idx:end_idx].strip()
            # Character offsets of the mention within the marked-up input text
            lOffset = start_idx
            rOffset = end_idx
        else:
            enclosed_entity = None
            lOffset = None
            rOffset = None

        # Generate the Wikipedia page prediction for the marked entity
        outputs = self.model.generate(
            **self.tokenizer([text], return_tensors="pt").to(self.device),
            num_beams=1,
            num_return_sequences=1,
            max_new_tokens=30,
            return_dict_in_generate=True,
            output_scores=True,
        )
        wikipedia_prediction = self.tokenizer.batch_decode(
            outputs.sequences, skip_special_tokens=True
        )[0]

        # Turn the per-token transition scores into a single confidence percentage
        transition_scores = self.model.compute_transition_scores(
            outputs.sequences, outputs.scores, normalize_logits=True
        )
        log_prob_sum = transition_scores[0].sum()
        sequence_confidence = torch.exp(log_prob_sum)
        # .item() yields a plain Python float, keeping the result JSON-serializable
        percentage = sequence_confidence.item() * 100.0

        return wikipedia_prediction, enclosed_entity, lOffset, rOffset, percentage

    def postprocess(self, outputs):
        wikipedia_prediction, enclosed_entity, lOffset, rOffset, percentage = outputs

        # Resolve the predicted page to a Wikidata QID, then to a canonical title/URL
        qid, language = get_wikipedia_page_props(wikipedia_prediction)
        title, url = get_wikipedia_title(qid, language=language)

        results = [
            {
                "surface": enclosed_entity,
                "wkd_id": qid,
                "wkpedia_pagename": title,
                "wkpedia_url": url,
                "type": "UNK",
                "confidence_nel": round(percentage, 2),
                "lOffset": lOffset,
                "rOffset": rOffset,
            }
        ]
        return results


def get_wikipedia_page_props(input_str: str):
    # Predictions look like "Page name >> language code"; fall back to English
    # when no language suffix is present or the string cannot be split cleanly.
    if ">>" not in input_str:
        page_name = input_str
        language = "en"
    else:
        try:
            page_name, language = input_str.split(">>")
            page_name = page_name.strip()
            language = language.strip()
        except ValueError:
            page_name = input_str
            language = "en"

    wikipedia_url = f"https://{language}.wikipedia.org/w/api.php"
    wikipedia_params = {
        "action": "query",
        "prop": "pageprops",
        "format": "json",
        "titles": page_name,
    }

    # Default to "NIL" when no Wikidata item can be resolved
    qid = "NIL"
    try:
        response = requests.get(wikipedia_url, params=wikipedia_params, timeout=10)
        response.raise_for_status()
        data = response.json()

        if "pages" in data["query"]:
            page_id = list(data["query"]["pages"].keys())[0]

            if "pageprops" in data["query"]["pages"][page_id]:
                page_props = data["query"]["pages"][page_id]["pageprops"]

                if "wikibase_item" in page_props:
                    return page_props["wikibase_item"], language
                else:
                    return qid, language
            else:
                return qid, language
        else:
            return qid, language
    except Exception:
        return qid, language


def get_wikipedia_title(qid, language="en"):
    # Look up the sitelink (title and URL) for the QID on the requested Wikipedia
    url = "https://www.wikidata.org/w/api.php"
    params = {
        "action": "wbgetentities",
        "format": "json",
        "ids": qid,
        "props": "sitelinks/urls",
        "sitefilter": f"{language}wiki",
    }

    try:
        response = requests.get(url, params=params, timeout=10)
        response.raise_for_status()
        data = response.json()
    except (requests.exceptions.RequestException, ValueError):
        return "NIL", "None"

    try:
        title = data["entities"][qid]["sitelinks"][f"{language}wiki"]["title"]
        url = data["entities"][qid]["sitelinks"][f"{language}wiki"]["url"]
        return title, url
    except KeyError:
        return "NIL", "None"


class EndpointHandler:
    def __init__(self, path: str = None):
        # Load from the provided model path if one is given; otherwise fall back
        # to the model name defined above.
        self.pipeline = NelPipeline(path or NEL_MODEL)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        # Process incoming data
        inputs = data.get("inputs", "")
        if not isinstance(inputs, str):
            raise ValueError("Input must be a string.")

        # Preprocess, forward, and postprocess
        preprocessed = self.pipeline.preprocess(inputs)
        results = self.pipeline.postprocess(preprocessed)

        return results
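
For reference, a minimal local smoke test of the handler could look like the sketch below; it is not part of the commit. It assumes handler.py is importable as "handler" and that model weights are reachable at whatever path is passed to EndpointHandler; the example sentence and the "." placeholder path are illustrative only.

# Hypothetical local smoke test (not part of handler.py).
from handler import EndpointHandler

handler = EndpointHandler(path=".")  # "." stands in for the model directory
payload = {"inputs": "The [START] Eiffel Tower [END] was completed in 1889."}
predictions = handler(payload)
print(predictions)
# Expected shape: a single-element list whose dict carries "surface", "wkd_id",
# "wkpedia_pagename", "wkpedia_url", "type", "confidence_nel", "lOffset", "rOffset".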