Create app.py
app.py
ADDED
@@ -0,0 +1,136 @@
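# Streamlit demo for LinkBERT-XL: highlights the words the model classifies as link
# (anchor) text, from either pasted text or the main content of a fetched URL.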
import streamlit as st
import torch
import torch.nn.functional as F
from transformers import XLMRobertaTokenizerFast, AutoModelForTokenClassification
import pandas as pd
import trafilatura

# Load model and tokenizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = XLMRobertaTokenizerFast.from_pretrained("xlm-roberta-large")
model = AutoModelForTokenClassification.from_pretrained("dejanseo/LinkBERT-XL").to(device)
model.eval()

# Functions

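# Tokenize text and keep character offsets so predictions can be mapped back to the source string.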
def tokenize_with_indices(text: str):
    encoded = tokenizer.encode_plus(text, return_offsets_mapping=True, add_special_tokens=True)
    return encoded['input_ids'], encoded['offset_mapping']

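# Download a page and extract its main text with trafilatura; returns None if the fetch fails.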
def fetch_and_extract_content(url: str):
    downloaded = trafilatura.fetch_url(url)
    if downloaded:
        content = trafilatura.extract(downloaded, include_comments=False, include_tables=False)
        return content
    return None

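# Split the input into chunks that fit the 512-token context window, run token classification
# on each chunk, and rebuild the text with predicted link words highlighted. Also collects a
# per-subtoken table (word, prediction, confidence, character offsets) for inspection.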
def process_text(inputs: str, confidence_threshold: float):
    # Reserve two positions for the special tokens added by the tokenizer.
    max_chunk_length = 512 - 2
    words = inputs.split()
    chunk_texts = []
    current_chunk = []
    current_length = 0
    for word in words:
        if len(tokenizer.tokenize(word)) + current_length > max_chunk_length:
            chunk_texts.append(" ".join(current_chunk))
            current_chunk = [word]
            current_length = len(tokenizer.tokenize(word))
        else:
            current_chunk.append(word)
            current_length += len(tokenizer.tokenize(word))
    chunk_texts.append(" ".join(current_chunk))

    df_data = {
        'Word': [],
        'Prediction': [],
        'Confidence': [],
        'Start': [],
        'End': []
    }
    reconstructed_text = ""
    original_position_offset = 0

    for chunk in chunk_texts:
        input_ids, token_offsets = tokenize_with_indices(chunk)

        input_ids_tensor = torch.tensor(input_ids).unsqueeze(0).to(device)
        with torch.no_grad():
            outputs = model(input_ids_tensor)
            logits = outputs.logits
            predictions = torch.argmax(logits, dim=-1).squeeze().tolist()
            softmax_scores = F.softmax(logits, dim=-1).squeeze().tolist()

        # Group subtokens by the character offset of the word they belong to.
        word_info = {}

        for idx, (start, end) in enumerate(token_offsets):
            # Skip the special tokens at the start and end of the sequence.
            if idx == 0 or idx == len(token_offsets) - 1:
                continue

            # Walk back to the beginning of the word containing this subtoken.
            word_start = start
            while word_start > 0 and chunk[word_start - 1] != ' ':
                word_start -= 1

            if word_start not in word_info:
                word_info[word_start] = {'prediction': 0, 'confidence': 0.0, 'subtokens': []}

            confidence_percentage = softmax_scores[idx][predictions[idx]] * 100

            # A word counts as a link if any of its subtokens is predicted as class 1
            # with confidence at or above the threshold.
            if predictions[idx] == 1 and confidence_percentage >= confidence_threshold:
                word_info[word_start]['prediction'] = 1

            word_info[word_start]['confidence'] = max(word_info[word_start]['confidence'], confidence_percentage)
            word_info[word_start]['subtokens'].append((start, end, chunk[start:end]))

        # Rebuild the chunk, wrapping predicted link words in a highlight span.
        last_end = 0
        for word_start in sorted(word_info.keys()):
            word_data = word_info[word_start]
            for subtoken_start, subtoken_end, subtoken_text in word_data['subtokens']:
                if last_end < subtoken_start:
                    reconstructed_text += chunk[last_end:subtoken_start]
                if word_data['prediction'] == 1:
                    reconstructed_text += f"<span style='background-color: rgba(0, 255, 0); display: inline;'>{subtoken_text}</span>"
                else:
                    reconstructed_text += subtoken_text
                last_end = subtoken_end

                df_data['Word'].append(subtoken_text)
                df_data['Prediction'].append(word_data['prediction'])
                df_data['Confidence'].append(word_data['confidence'])
                df_data['Start'].append(subtoken_start + original_position_offset)
                df_data['End'].append(subtoken_end + original_position_offset)

        original_position_offset += len(chunk) + 1

        reconstructed_text += chunk[last_end:]

    df_tokens = pd.DataFrame(df_data)
    return reconstructed_text, df_tokens

# Streamlit Interface

st.set_page_config(layout="wide")
st.title('Text Processing with XLM-Roberta and LinkBERT-XL')

confidence_threshold = st.slider('Confidence Threshold', 50, 100, 75)

tab1, tab2 = st.tabs(["Text Input", "URL Input"])

with tab1:
    user_input = st.text_area("Enter text to process:", "Type here...")
    if st.button('Process Text'):
        highlighted_text, df_tokens = process_text(user_input, confidence_threshold)
        st.markdown(highlighted_text, unsafe_allow_html=True)
        st.dataframe(df_tokens)

with tab2:
    url_input = st.text_input("Enter URL to process:")
    if st.button('Fetch and Process'):
        content = fetch_and_extract_content(url_input)
        if content:
            highlighted_text, df_tokens = process_text(content, confidence_threshold)
            st.markdown(highlighted_text, unsafe_allow_html=True)
            st.dataframe(df_tokens)
        else:
            st.error("Could not fetch content from the URL. Please check the URL and try again.")