rondaravaol committed
Commit · 5a69a9a
1 Parent(s): 6a54372
Inference
- .gitattributes +1 -0
- api_inference.py +11 -0
- code/labels.py +51 -0
- code/utils.py +96 -0
- eval.py +77 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.extension filter=lfs diff=lfs merge=lfs -text
api_inference.py
ADDED
@@ -0,0 +1,11 @@
+from eval import eval
+
+def query(payload):
+    url = payload.get("url", "")
+    if not url:
+        return {"error": "No URL provided"}
+    try:
+        result = eval(url)
+        return result
+    except Exception as e:
+        return {"error": str(e)}
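
A minimal sketch of how this handler could be exercised locally. The payload shape ({"url": ...}) is read straight from query()'s code; the URL itself is a placeholder, and the models folder described in eval.py must be in place for the call to succeed.

# Hypothetical local smoke test; the URL is a placeholder.
from api_inference import query

payload = {"url": "https://example.com/some-article"}  # assumed payload shape
result = query(payload)
print(result.get("error") or result["title"])
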
code/labels.py
ADDED
@@ -0,0 +1,51 @@
+import torch
+from collections import Counter
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+id2label = {0: "none", 1: "title", 2: "content", 3: "author", 4: "date", 5: "header", 6: "footer", 7: "rail", 8: "advertisement", 9: "navigation"}
+label2id = {label: id for id, label in id2label.items()}
+
+label_list = ["B-" + x for x in list(id2label.values())]
+print(label_list)
+
+def get_class_weights_tensor(dataset):
+    all_labels = [label for data_item in dataset for label in data_item['labels']]
+
+    # Count the frequency of each label
+    label_counter = Counter(all_labels)
+
+    # Calculate the class weights
+    total_count = sum(label_counter.values())
+    class_weights = {label: total_count / count for label, count in label_counter.items()}
+
+    # Normalize the weights
+    sum_weights = sum(class_weights.values())
+    normalized_class_weights = {label: weight / sum_weights for label, weight in class_weights.items()}
+
+    # Convert class weights to a tensor
+    class_weights_list = [normalized_class_weights[label] for label in sorted(normalized_class_weights.keys())]
+    class_weights_tensor = torch.tensor(class_weights_list, dtype=torch.float).to(device)
+
+    return class_weights_tensor
+
+
+def get_labels(predictions, references):
+    # Transform predictions and references tensors to numpy arrays
+    if device.type == "cpu":
+        y_pred = predictions.detach().clone().numpy()
+        y_true = references.detach().clone().numpy()
+    else:
+        y_pred = predictions.detach().cpu().clone().numpy()
+        y_true = references.detach().cpu().clone().numpy()
+
+    # Remove ignored index (special tokens)
+    true_predictions = [
+        [label_list[p] for (p, l) in zip(pred, gold_label) if l != -100]
+        for pred, gold_label in zip(y_pred, y_true)
+    ]
+    true_labels = [
+        [label_list[l] for (p, l) in zip(pred, gold_label) if l != -100]
+        for pred, gold_label in zip(y_pred, y_true)
+    ]
+    return true_predictions, true_labels
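
A quick worked example of the weighting scheme, using the repository's own import style (from code import ...) and a toy dataset in the shape get_class_weights_tensor expects. With labels [0, 0, 0, 1] the counts are {0: 3, 1: 1}, the raw weights are {0: 4/3, 1: 4}, and normalizing gives {0: 0.25, 1: 0.75}: rarer classes get larger weights.

# Toy illustration of get_class_weights_tensor's math (dataset is made up).
from code import labels

toy_dataset = [{'labels': [0, 0, 0]}, {'labels': [1]}]
print(labels.get_class_weights_tensor(toy_dataset))  # tensor([0.2500, 0.7500]) on CPU
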
code/utils.py
ADDED
@@ -0,0 +1,96 @@
+import pandas as pd
+import requests
+from bs4 import BeautifulSoup
+import re
+import json
+import os
+from common import custom_feature_extraction_markuplm
+import glob
+
+def get_latest_file(directory):
+    # Get list of all files in the directory
+    list_of_files = glob.glob(os.path.join(directory, '*'))
+    print("Files in model folder\n" + str(list_of_files))
+
+    # Sort files based on creation time
+    latest_file = max(list_of_files, key=os.path.getctime)
+
+    return latest_file
+
+def split_sliding_data(items, window_size, overlap):
+    new_data = []
+    for obj in items:
+        #print (obj.keys(), '\n')
+        #print (obj, '\n')
+        nodes = obj['nodes']
+        num_elements = len(nodes)
+        #print(num_elements, '\n')
+        counter = 0
+        for i in range(0, num_elements, window_size - overlap):
+            start = i
+            end = min(i + window_size, num_elements)
+            #print (start, end)
+            new_obj = {
+                'Index': obj['Index'] if 'Index' in obj else 0,
+                'Index2': counter,
+                'Url': obj['Url'] if 'Url' in obj else None,
+                'Path': obj['Path'] if 'Path' in obj else None,
+                'nodes': obj['nodes'][start:end],
+                'xpaths': obj['xpaths'][start:end],
+                'xpaths_simple': obj['xpaths_simple'][start:end],
+                'labels': obj['labels'][start:end] if 'labels' in obj else None,
+            }
+            counter = counter + 1
+            #print (new_obj, '\n')
+            new_data.append(new_obj)
+
+    return new_data
+
+
+# Function to fetch HTML content from URL
+def get_html_content(url):
+    try:
+        response = requests.get(url)
+        if response.status_code == 200:
+            return response.text
+        else:
+            return None
+    except Exception as e:
+        print("Error fetching HTML content:", e)
+        return None
+
+# Function to clean HTML content
+def clean_html(html):
+    # Remove extra whitespaces, newlines, and tabs
+    soup = BeautifulSoup(html, "html.parser")
+
+    for data in soup(['style', 'script']):
+        # Remove tags
+        data.decompose()
+
+    html = str(soup)
+    clean_html = re.sub(r'\s+', ' ', html)
+    # Escape double quotes and wrap content in double quotes
+    #clean_html = clean_html.replace('"', '""')
+    #clean_html = f'"{clean_html}"'
+    return clean_html
+
+# Function to extract HTML content from URL and save to new dataset
+def extract_nodes_and_feautures(html_content):
+    if html_content:
+        soup = BeautifulSoup(html_content, 'html.parser')
+        cleaned_html = clean_html(str(soup))
+
+        feature_extractor = custom_feature_extraction_markuplm.CustomMarkupLMFeatureExtractor(None)
+
+        encoding = feature_extractor(cleaned_html)
+
+        #print(encoding.keys())
+        row = {}
+        row['nodes'] = encoding['nodes'][0]
+        row['xpaths'] = encoding['xpaths'][0]
+        row['xpaths_simple'] = encoding['xpaths_simple'][0]
+        row['labels'] = encoding['labels'][0]
+        return row
+    else:
+        return pd.Series()
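
To make the sliding window concrete, a small sketch under made-up input values: with window_size=4 and overlap=1 the stride is window_size - overlap = 3, so a 7-node item is split into windows [0:4], [3:7], and [6:7].

# Toy illustration of split_sliding_data's windows (input values are made up).
from code import utils

item = {
    'nodes': list(range(7)),
    'xpaths': list(range(7)),
    'xpaths_simple': list(range(7)),
    'labels': list(range(7)),
}
for chunk in utils.split_sliding_data([item], window_size=4, overlap=1):
    print(chunk['Index2'], chunk['nodes'])
# 0 [0, 1, 2, 3]
# 1 [3, 4, 5, 6]
# 2 [6]
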
eval.py
ADDED
@@ -0,0 +1,77 @@
+from transformers import MarkupLMForTokenClassification
+from transformers import MarkupLMProcessor
+from code import utils, labels
+import torch
+import os
+import numpy as np
+from sklearn.feature_extraction.text import TfidfVectorizer
+from sklearn.metrics.pairwise import cosine_similarity
+
+
+def rank_titles(titles, content):
+    vectorizer = TfidfVectorizer()
+    texts = titles + [content]
+
+    tfidf_matrix = vectorizer.fit_transform(texts)
+
+    cosine_similarities = cosine_similarity(tfidf_matrix[-1:], tfidf_matrix[:-1]).flatten()
+    ranked_titles_indices = np.argsort(cosine_similarities)[::-1]
+    ranked_titles = [titles[idx] for idx in ranked_titles_indices]
+    return ranked_titles
+
+
+def eval(url):
+    current_dir = os.path.dirname(os.path.abspath(__file__))
+
+    model_folder = os.path.join(current_dir, 'models')  # models folder is in the repository root
+    model_name = 'OxMarkupLM.pt'
+
+    processor = MarkupLMProcessor.from_pretrained("microsoft/markuplm-base")
+    processor.parse_html = False
+
+    model_path = os.path.join(model_folder, model_name)
+
+    model = MarkupLMForTokenClassification.from_pretrained(
+        model_path, id2label=labels.id2label, label2id=labels.label2id
+    )
+
+    html = utils.clean_html(utils.get_html_content(url))
+    data = [utils.extract_nodes_and_feautures(html)]
+    example = utils.split_sliding_data(data, 10, 0)
+
+    title, author, date, content = [], [], [], []
+    for splited in example:
+        nodes, xpaths = splited['nodes'], splited['xpaths']
+        encoding = processor(
+            nodes=nodes, xpaths=xpaths, return_offsets_mapping=True,
+            padding="max_length", truncation=True, max_length=512, return_tensors="pt"
+        )
+        offset_mapping = encoding.pop("offset_mapping")
+        with torch.no_grad():
+            logits = model(**encoding).logits
+
+        predictions = logits.argmax(-1)
+        processed_words = []
+
+        for pred_id, word_id, offset in zip(predictions[0].tolist(), encoding.word_ids(0), offset_mapping[0].tolist()):
+            if word_id is not None and offset[0] == 0:
+                if pred_id == 1:
+                    title.append(nodes[word_id])
+                elif pred_id == 2 and word_id not in processed_words:
+                    processed_words.append(word_id)
+                    content.append(nodes[word_id])
+                elif pred_id == 3:
+                    author.append(nodes[word_id])
+                elif pred_id == 4:
+                    date.append(nodes[word_id])
+
+    title = rank_titles(title, '\n'.join(content))
+    return {
+        "model_name": model_name,
+        "url": url,
+        "title": title,
+        "author": author,
+        "date": date,
+        "content": content,
+    }
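
End to end, the entry point might be exercised as below. The URL is a placeholder, and this assumes the checkpoint models/OxMarkupLM.pt is present next to eval.py and that the code/ package is importable; the returned dict's keys come directly from eval()'s return statement.

# Hypothetical invocation; requires models/OxMarkupLM.pt to exist locally.
from eval import eval

result = eval("https://example.com/news/some-article")
print(result["title"][:1])  # best-ranked title candidate per rank_titles
print(result["author"], result["date"])
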