rondaravaol committed on
Commit 5a69a9a · 1 Parent(s): 6a54372
Files changed (5)
  1. .gitattributes +1 -0
  2. api_inference.py +11 -0
  3. code/labels.py +51 -0
  4. code/utils.py +96 -0
  5. eval.py +77 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.extension filter=lfs diff=lfs merge=lfs -text
api_inference.py ADDED
@@ -0,0 +1,11 @@
+ from eval import eval
+
+ def query(payload):
+     url = payload.get("url", "")
+     if not url:
+         return {"error": "No URL provided"}
+     try:
+         result = eval(url)
+         return result
+     except Exception as e:
+         return {"error": str(e)}
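A minimal sketch of how the new `query` entry point could be exercised, assuming the repository's dependencies and the `models/OxMarkupLM.pt` checkpoint used by `eval.py` are available locally; the payload URL is a placeholder:

    from api_inference import query

    # Placeholder URL; query() forwards it to eval() and returns its dict, or an error dict.
    response = query({"url": "https://example.com/some-article"})
    print(response.get("title"), response.get("author"), response.get("date"))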
code/labels.py ADDED
@@ -0,0 +1,51 @@
+ import torch
+ from collections import Counter
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ id2label = {0: "none", 1: "title", 2: "content", 3: "author", 4: "date", 5: "header", 6: "footer", 7: "rail", 8: "advertisement", 9: "navigation"}
+ label2id = {label: id for id, label in id2label.items()}
+
+ label_list = ["B-" + x for x in list(id2label.values())]
+ print(label_list)
+
+ def get_class_weights_tensor(dataset):
+     all_labels = [label for data_item in dataset for label in data_item['labels']]
+
+     # Count the frequency of each label
+     label_counter = Counter(all_labels)
+
+     # Calculate the class weights
+     total_count = sum(label_counter.values())
+     class_weights = {label: total_count / count for label, count in label_counter.items()}
+
+     # Normalize the weights
+     sum_weights = sum(class_weights.values())
+     normalized_class_weights = {label: weight / sum_weights for label, weight in class_weights.items()}
+
+     # Convert class weights to a tensor
+     class_weights_list = [normalized_class_weights[label] for label in sorted(normalized_class_weights.keys())]
+     class_weights_tensor = torch.tensor(class_weights_list, dtype=torch.float).to(device)
+
+     return class_weights_tensor
+
+
+ def get_labels(predictions, references):
+     # Transform prediction and reference tensors to numpy arrays
+     if device.type == "cpu":
+         y_pred = predictions.detach().clone().numpy()
+         y_true = references.detach().clone().numpy()
+     else:
+         y_pred = predictions.detach().cpu().clone().numpy()
+         y_true = references.detach().cpu().clone().numpy()
+
+     # Remove ignored index (special tokens)
+     true_predictions = [
+         [label_list[p] for (p, l) in zip(pred, gold_label) if l != -100]
+         for pred, gold_label in zip(y_pred, y_true)
+     ]
+     true_labels = [
+         [label_list[l] for (p, l) in zip(pred, gold_label) if l != -100]
+         for pred, gold_label in zip(y_pred, y_true)
+     ]
+     return true_predictions, true_labels
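A short sketch of how these helpers might be used together, assuming a token-classification setup where each dataset item carries integer label ids under 'labels' and ignored positions are marked with -100; the toy tensors below are illustrative only:

    import torch
    from code.labels import get_class_weights_tensor, get_labels

    # Toy dataset: per-item integer label ids for each node.
    dataset = [{"labels": [0, 1, 2, 2]}, {"labels": [0, 2, 3, 4]}]
    weights = get_class_weights_tensor(dataset)  # normalized per-class weights, ordered by label id

    # Toy batch of one sequence; -100 marks positions get_labels() should skip.
    predictions = torch.tensor([[1, 2, 0]])
    references = torch.tensor([[1, 2, -100]])
    preds, golds = get_labels(predictions, references)  # ([['B-title', 'B-content']], [['B-title', 'B-content']])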
code/utils.py ADDED
@@ -0,0 +1,96 @@
+ import pandas as pd
+ import requests
+ from bs4 import BeautifulSoup
+ import re
+ import json
+ import os
+ from common import custom_feature_extraction_markuplm
+ import glob
+
+ def get_latest_file(directory):
+     # Get list of all files in the directory
+     list_of_files = glob.glob(os.path.join(directory, '*'))
+     print("Files in model folder\n" + str(list_of_files))
+
+     # Pick the most recently created file
+     latest_file = max(list_of_files, key=os.path.getctime)
+
+     return latest_file
+
+ def split_sliding_data(items, window_size, overlap):
+     new_data = []
+     for obj in items:
+         #print(obj.keys(), '\n')
+         #print(obj, '\n')
+         nodes = obj['nodes']
+         num_elements = len(nodes)
+         #print(num_elements, '\n')
+         counter = 0
+         for i in range(0, num_elements, window_size - overlap):
+             start = i
+             end = min(i + window_size, num_elements)
+             #print(start, end)
+             new_obj = {
+                 'Index': obj['Index'] if 'Index' in obj else 0,
+                 'Index2': counter,
+                 'Url': obj['Url'] if 'Url' in obj else None,
+                 'Path': obj['Path'] if 'Path' in obj else None,
+                 'nodes': obj['nodes'][start:end],
+                 'xpaths': obj['xpaths'][start:end],
+                 'xpaths_simple': obj['xpaths_simple'][start:end],
+                 'labels': obj['labels'][start:end] if 'labels' in obj else None,
+             }
+             counter = counter + 1
+             #print(new_obj, '\n')
+             new_data.append(new_obj)
+
+     return new_data
+
+
+ # Function to fetch HTML content from a URL
+ def get_html_content(url):
+     try:
+         response = requests.get(url)
+         if response.status_code == 200:
+             return response.text
+         else:
+             return None
+     except Exception as e:
+         print("Error fetching HTML content:", e)
+         return None
+
+ # Function to clean HTML content
+ def clean_html(html):
+     # Remove extra whitespace, newlines, and tabs
+     soup = BeautifulSoup(html, "html.parser")
+
+     for data in soup(['style', 'script']):
+         # Remove tags
+         data.decompose()
+
+     html = str(soup)
+     clean_html = re.sub(r'\s+', ' ', html)
+     # Escape double quotes and wrap content in double quotes
+     #clean_html = clean_html.replace('"', '""')
+     #clean_html = f'"{clean_html}"'
+     return clean_html
+
+ # Function to extract nodes and MarkupLM features from HTML content
+ def extract_nodes_and_feautures(html_content):
+     if html_content:
+         soup = BeautifulSoup(html_content, 'html.parser')
+         cleaned_html = clean_html(str(soup))
+
+         feature_extractor = custom_feature_extraction_markuplm.CustomMarkupLMFeatureExtractor(None)
+
+         encoding = feature_extractor(cleaned_html)
+
+         #print(encoding.keys())
+         row = {}
+         row['nodes'] = encoding['nodes'][0]
+         row['xpaths'] = encoding['xpaths'][0]
+         row['xpaths_simple'] = encoding['xpaths_simple'][0]
+         row['labels'] = encoding['labels'][0]
+         return row
+     else:
+         return pd.Series()
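A rough sketch of how these utilities chain together outside of `eval.py`, assuming `common.custom_feature_extraction_markuplm` from this repository is importable and the target page is reachable; the URL is a placeholder:

    from code import utils

    html = utils.get_html_content("https://example.com/some-article")  # placeholder URL
    if html is not None:
        row = utils.extract_nodes_and_feautures(utils.clean_html(html))
        # Windows of 10 nodes with no overlap, mirroring the call in eval.py.
        windows = utils.split_sliding_data([row], window_size=10, overlap=0)
        print(len(windows), "windows;", len(windows[0]["nodes"]), "nodes in the first")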
eval.py ADDED
@@ -0,0 +1,77 @@
+ from transformers import MarkupLMForTokenClassification
+ from transformers import MarkupLMProcessor
+ from code import utils, labels
+ import torch
+ import os
+ import numpy as np
+ from sklearn.feature_extraction.text import TfidfVectorizer
+ from sklearn.metrics.pairwise import cosine_similarity
+
+
+ def rank_titles(titles, content):
+     vectorizer = TfidfVectorizer()
+     texts = titles + [content]
+
+     tfidf_matrix = vectorizer.fit_transform(texts)
+
+     cosine_similarities = cosine_similarity(tfidf_matrix[-1:], tfidf_matrix[:-1]).flatten()
+     ranked_titles_indices = np.argsort(cosine_similarities)[::-1]
+     ranked_titles = [titles[idx] for idx in ranked_titles_indices]
+     return ranked_titles
+
+
+ def eval(url):
+     current_dir = os.path.dirname(os.path.abspath(__file__))
+
+     model_folder = os.path.join(current_dir, 'models')  # models folder is in the repository root
+     model_name = 'OxMarkupLM.pt'
+
+     processor = MarkupLMProcessor.from_pretrained("microsoft/markuplm-base")
+     processor.parse_html = False
+
+     model_path = os.path.join(model_folder, model_name)
+
+     model = MarkupLMForTokenClassification.from_pretrained(
+         model_path, id2label=labels.id2label, label2id=labels.label2id
+     )
+
+     html = utils.clean_html(utils.get_html_content(url))
+     data = [utils.extract_nodes_and_feautures(html)]
+     example = utils.split_sliding_data(data, 10, 0)
+
+     title, author, date, content = [], [], [], []
+     for splited in example:
+         nodes, xpaths = splited['nodes'], splited['xpaths']
+         encoding = processor(
+             nodes=nodes, xpaths=xpaths, return_offsets_mapping=True,
+             padding="max_length", truncation=True, max_length=512, return_tensors="pt"
+         )
+         offset_mapping = encoding.pop("offset_mapping")
+         with torch.no_grad():
+             logits = model(**encoding).logits
+
+         predictions = logits.argmax(-1)
+         processed_words = []
+
+         for pred_id, word_id, offset in zip(predictions[0].tolist(), encoding.word_ids(0), offset_mapping[0].tolist()):
+             if word_id is not None and offset[0] == 0:
+                 if pred_id == 1:
+                     title.append(nodes[word_id])
+                 elif pred_id == 2 and word_id not in processed_words:
+                     processed_words.append(word_id)
+                     content.append(nodes[word_id])
+                 elif pred_id == 3:
+                     author.append(nodes[word_id])
+                 elif pred_id == 4:
+                     date.append(nodes[word_id])
+
+     title = rank_titles(title, '\n'.join(content))
+     return {
+         "model_name": model_name,
+         "url": url,
+         "title": title,
+         "author": author,
+         "date": date,
+         "content": content,
+     }
+
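End to end, the new pipeline can be tried like this, assuming the fine-tuned `models/OxMarkupLM.pt` checkpoint sits next to `eval.py` and transformers, torch, scikit-learn, and beautifulsoup4 are installed; the URL is a placeholder:

    from eval import eval

    result = eval("https://example.com/some-article")  # placeholder URL
    print(result["title"][:1])                         # highest-ranked title candidate, if any
    print(result["author"], result["date"])
    print(" ".join(result["content"])[:500])           # first part of the extracted body text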