# OxMarkupLM / eval.py
# rondaravaol
# Inference
# 5a69a9a
import functools
import os

import numpy as np
import torch
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from transformers import MarkupLMForTokenClassification
from transformers import MarkupLMProcessor

from code import utils, labels
def rank_titles(titles, content):
    """Order candidate titles by TF-IDF cosine similarity to the content.

    Args:
        titles: Candidate title strings.
        content: Text the candidates are compared against.

    Returns:
        A new list holding the same titles, most similar to ``content``
        first. Zero or one candidate is returned unchanged, and the
        original order is kept when no TF-IDF vocabulary can be built
        from the texts.
    """
    titles = list(titles)

    # Nothing to rank. This also avoids handing scikit-learn an empty
    # candidate matrix, which it refuses to compare.
    if len(titles) < 2:
        return titles

    vectorizer = TfidfVectorizer()
    try:
        # The content goes last so it can be sliced off as the query row.
        tfidf_matrix = vectorizer.fit_transform(titles + [content])
    except ValueError:
        # Raised when none of the texts yields a token (empty vocabulary);
        # without features there is no meaningful ranking to compute.
        return titles

    similarities = cosine_similarity(tfidf_matrix[-1:], tfidf_matrix[:-1]).flatten()
    # argsort is ascending, so reverse it to get the best match first.
    best_first = np.argsort(similarities)[::-1]
    return [titles[idx] for idx in best_first]
@functools.lru_cache(maxsize=1)
def _load_model_and_processor(model_path):
    """Load the MarkupLM processor and the fine-tuned classifier once per process."""
    processor = MarkupLMProcessor.from_pretrained("microsoft/markuplm-base")
    # Nodes and xpaths are passed in explicitly below, so the processor
    # must not try to parse raw HTML itself.
    processor.parse_html = False
    model = MarkupLMForTokenClassification.from_pretrained(
        model_path, id2label=labels.id2label, label2id=labels.label2id
    )
    return processor, model


def eval(url):
    """Extract title, author, date and content nodes from the page at ``url``.

    The page is fetched and cleaned through ``utils``, split into windows,
    and every window is classified with the MarkupLM token-classification
    model. Nodes are collected per predicted label and the title candidates
    are ranked against the extracted content.

    Args:
        url: Address of the page to analyse.

    Returns:
        A dict with the keys ``model_name``, ``url``, ``title`` (candidates,
        best match first), ``author``, ``date`` and ``content`` (lists of
        node texts).
    """
    current_dir = os.path.dirname(os.path.abspath(__file__))
    # The models folder lives next to this file (the repository root).
    model_folder = os.path.join(current_dir, 'models')
    model_name = 'OxMarkupLM.pt'
    model_path = os.path.join(model_folder, model_name)

    # Cached so repeated calls do not reload the weights every time.
    processor, model = _load_model_and_processor(model_path)

    html = utils.clean_html(utils.get_html_content(url))
    data = [utils.extract_nodes_and_feautures(html)]
    # NOTE(review): the meaning of 10 and 0 is defined by
    # utils.split_sliding_data - presumably window size and overlap; confirm.
    example = utils.split_sliding_data(data, 10, 0)

    title, author, date, content = [], [], [], []
    for splited in example:
        nodes, xpaths = splited['nodes'], splited['xpaths']
        encoding = processor(
            nodes=nodes, xpaths=xpaths, return_offsets_mapping=True,
            padding="max_length", truncation=True, max_length=512, return_tensors="pt"
        )
        # Offsets are only needed for the filtering below; keep them out of
        # the keyword arguments forwarded to the model.
        offset_mapping = encoding.pop("offset_mapping")
        with torch.no_grad():
            logits = model(**encoding).logits
        predictions = logits.argmax(-1)

        # Node indices already added to ``content`` for this window.
        seen_content_nodes = set()
        for pred_id, word_id, offset in zip(
            predictions[0].tolist(), encoding.word_ids(0), offset_mapping[0].tolist()
        ):
            # Skip tokens that map to no node, and tokens whose offset does
            # not start at 0 (presumably keeps one token per node - confirm).
            if word_id is None or offset[0] != 0:
                continue

            # Label ids as used here: 1 title, 2 content, 3 author, 4 date;
            # any other prediction is ignored.
            # NOTE(review): presumably mirrors labels.id2label - confirm.
            if pred_id == 1:
                title.append(nodes[word_id])
            elif pred_id == 2:
                if word_id not in seen_content_nodes:
                    seen_content_nodes.add(word_id)
                    content.append(nodes[word_id])
            elif pred_id == 3:
                author.append(nodes[word_id])
            elif pred_id == 4:
                date.append(nodes[word_id])

    # Ranking needs at least one candidate; an empty list stays empty.
    if title:
        title = rank_titles(title, '\n'.join(content))

    return {
        "model_name": model_name,
        "url": url,
        "title": title,
        "author": author,
        "date": date,
        "content": content,
    }