# handler.py — last commit c4e0123 ("Update handler.py") by emanuelaboros; 3.82 kB.
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from typing import List, Dict, Any
import requests
import nltk
from transformers import pipeline
# Fetch the NLTK averaged-perceptron POS-tagger resources at import time.
# Both resource names are requested (presumably so that NLTK releases using
# either name find the tagger — confirm which one the remote pipeline needs).
for _nltk_resource in ("averaged_perceptron_tagger", "averaged_perceptron_tagger_eng"):
    nltk.download(_nltk_resource)
del _nltk_resource

# Short model name; stored on NelPipeline instances as ``model_name``.
NEL_MODEL = "nel-mgenre-multilingual"
def get_wikipedia_page_props(input_str: str, timeout: float = 10.0):
    """Resolve a Wikipedia page title to its Wikidata item id (QID).

    ``input_str`` is either a bare page title, looked up on English Wikipedia,
    or ``"<title> >> <lang>"``, where ``<lang>`` is a Wikipedia language code
    such as ``fr`` or ``zh-yue``. Input containing ``>>`` more than once is
    treated as a bare English title, and an empty language part defaults to
    ``en``.

    Args:
        input_str: Page title, optionally suffixed with ``>> <lang>``.
        timeout: Seconds to wait for the Wikipedia API before giving up.

    Returns:
        ``(qid, language)``. ``qid`` is the page's ``wikibase_item`` page
        property, or ``"NIL"`` in any of these cases: the page is missing or has
        no Wikidata item, the request fails, or the language code is not a
        plausible Wikipedia subdomain. ``language`` is the language code that
        was parsed from the input.
    """
    page_name, language = input_str, "en"
    if ">>" in input_str:
        parts = input_str.split(">>")
        if len(parts) == 2:
            page_name = parts[0].strip()
            # An empty language part would yield the host "https://.wikipedia.org".
            language = parts[1].strip() or "en"
        # With more than one ">>" the input is ambiguous; keep the raw string
        # as an English title (the original fallback behavior).

    # The language code becomes part of the request host. Only accept ASCII
    # letters/digits/hyphens (e.g. "en", "zh-yue") so untrusted input cannot
    # redirect the request to an arbitrary host.
    if not (language.isascii() and language.replace("-", "").isalnum()):
        return "NIL", language

    try:
        response = requests.get(
            f"https://{language}.wikipedia.org/w/api.php",
            params={
                "action": "query",
                "prop": "pageprops",
                "format": "json",
                "titles": page_name,
            },
            timeout=timeout,
        )
        response.raise_for_status()
        data = response.json()
    except (requests.exceptions.RequestException, ValueError):
        # Best-effort lookup: network/HTTP errors and non-JSON bodies mean "no QID".
        return "NIL", language

    try:
        # Exactly one title was requested, so the first page entry is the one we want.
        page = next(iter(data["query"]["pages"].values()))
        return page["pageprops"]["wikibase_item"], language
    except (KeyError, TypeError, AttributeError, StopIteration):
        # Missing page, page without a Wikidata item, or unexpected response shape.
        return "NIL", language
def get_wikipedia_title(qid, language="en", timeout: float = 10.0):
    """Look up the Wikipedia article linked to a Wikidata item.

    Args:
        qid: Wikidata item id (e.g. ``"Q90"``). ``"NIL"`` or an empty value
            returns the "not found" result immediately, without a request.
        language: Wikipedia language code; selects the ``<language>wiki`` sitelink.
        timeout: Seconds to wait for the Wikidata API before giving up.

    Returns:
        ``(title, url)`` of the sitelink, or ``("NIL", "None")`` when the item
        has no such sitelink or the request fails. Note that ``"None"`` is a
        string, not ``None``.
    """
    not_found = ("NIL", "None")
    # "NIL" is the sentinel produced by get_wikipedia_page_props; the API would
    # reject it anyway, so skip the network round-trip.
    if not qid or qid == "NIL":
        return not_found

    site = f"{language}wiki"
    try:
        # The request itself sits inside the try so that connection errors and
        # timeouts are handled like HTTP errors instead of escaping to the caller.
        response = requests.get(
            "https://www.wikidata.org/w/api.php",
            params={
                "action": "wbgetentities",
                "format": "json",
                "ids": qid,
                "props": "sitelinks/urls",
                "sitefilter": site,
            },
            timeout=timeout,
        )
        response.raise_for_status()
        data = response.json()
    except (requests.exceptions.RequestException, ValueError):
        return not_found

    try:
        sitelink = data["entities"][qid]["sitelinks"][site]
        return sitelink["title"], sitelink["url"]
    except (KeyError, TypeError):
        # Unknown item, no sitelink for this wiki, or unexpected response shape.
        return not_found
class NelPipeline:
    """Thin wrapper around the ``generic-nel`` transformers pipeline.

    The tokenizer is loaded from ``model_dir``. The pipeline is always built
    from the ``impresso-project/nel-mgenre-multilingual`` repository with
    ``trust_remote_code=True``. It runs on CUDA when available, otherwise on CPU.
    """

    def __init__(self, model_dir: str = "."):
        self.model_name = NEL_MODEL
        print(f"Loading {model_dir}")
        gpu_available = torch.cuda.is_available()
        self.device = "cuda" if gpu_available else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        # NOTE(review): the model repo id is hard-coded rather than taken from
        # model_dir, so tokenizer and model can come from different sources
        # when model_dir differs — confirm this is intended.
        pipeline_options = dict(
            model="impresso-project/nel-mgenre-multilingual",
            tokenizer=self.tokenizer,
            trust_remote_code=True,
            device=self.device,
        )
        self.model = pipeline("generic-nel", **pipeline_options)

    def preprocess(self, text: str):
        """Run the entity-linking pipeline on ``text`` and return its raw output.

        Despite the name, this is where the model is actually invoked.
        """
        return self.model(text)

    def postprocess(self, outputs):
        """Return ``outputs`` unchanged (identity pass-through)."""
        return outputs
class EndpointHandler:
    """Request handler that runs entity linking on one text input."""

    def __init__(self, path: str = None):
        # ``path`` is accepted but unused; the model repository id is fixed here.
        self.pipeline = NelPipeline("impresso-project/nel-mgenre-multilingual")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Link entities in ``data["inputs"]``.

        Args:
            data: Request payload. ``"inputs"`` holds the text and defaults to
                ``""`` when absent.

        Returns:
            The pipeline's ``postprocess`` result for the ``preprocess`` output.

        Raises:
            ValueError: If ``"inputs"`` is not a string.
        """
        text = data.get("inputs", "")
        if isinstance(text, str):
            raw_output = self.pipeline.preprocess(text)
            return self.pipeline.postprocess(raw_output)
        raise ValueError("Input must be a string.")