File size: 5,378 Bytes
e1234b6 266743a e1234b6 266743a e1234b6 266743a e1234b6 266743a e1234b6 d598c67 e1234b6 266743a e1234b6 d598c67 e1234b6 86e91e4 e1234b6 7efe121 e1234b6 86e91e4 e1234b6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
import html

import pandas as pd
import streamlit as st
import torch
import torch.nn.functional as F
import trafilatura
from torch.nn.functional import softmax
from transformers import XLMRobertaTokenizerFast, AutoModelForTokenClassification
# Load model and tokenizer
@st.cache_resource(show_spinner=False)
def _load_model_and_tokenizer():
    """Load the tokenizer and the LinkBERT token-classification model once.

    Streamlit re-executes this whole script on every widget interaction;
    ``st.cache_resource`` keeps the loaded objects in memory across reruns so
    the model is not re-instantiated each time. ``show_spinner=False`` keeps
    the cached call from emitting a UI element before ``st.set_page_config``.

    Returns:
        ``(device, tokenizer, model)`` -- the model is moved to ``device``
        and switched to evaluation mode.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = XLMRobertaTokenizerFast.from_pretrained("xlm-roberta-large")
    model = AutoModelForTokenClassification.from_pretrained("dejanseo/LinkBERT-XL").to(device)
    # Inference only (e.g. disables dropout).
    model.eval()
    return device, tokenizer, model


# Module-level names kept so the functions below can use them unchanged.
device, tokenizer, model = _load_model_and_tokenizer()
# Functions
def tokenize_with_indices(text: str):
    """Tokenize ``text`` with the module-level tokenizer.

    Special tokens are added. Returns a pair ``(input_ids, offset_mapping)``
    where each offset is a ``(start, end)`` character span into ``text``.
    """
    encoding = tokenizer(
        text,
        return_offsets_mapping=True,
        add_special_tokens=True,
    )
    ids = encoding['input_ids']
    offsets = encoding['offset_mapping']
    return ids, offsets
def fetch_and_extract_content(url: str):
    """Download ``url`` with trafilatura and return its extracted main text.

    Comments and tables are excluded from the extraction. Returns ``None``
    when the download yields nothing; otherwise returns whatever
    ``trafilatura.extract`` produces (which may itself be ``None``).
    """
    downloaded_page = trafilatura.fetch_url(url)
    if not downloaded_page:
        return None
    return trafilatura.extract(
        downloaded_page,
        include_comments=False,
        include_tables=False,
    )
def _chunk_words(words, max_chunk_length):
    """Greedily pack ``words`` into space-joined chunks for the model.

    Each chunk holds at most ``max_chunk_length`` tokens, counted by
    tokenizing every word on its own (special tokens are not counted). A
    word longer than the limit gets a chunk of its own; no empty chunk is
    ever produced.
    """
    chunk_texts = []
    current_chunk = []
    current_length = 0
    for word in words:
        word_length = len(tokenizer.tokenize(word))
        if current_chunk and current_length + word_length > max_chunk_length:
            chunk_texts.append(" ".join(current_chunk))
            current_chunk = []
            current_length = 0
        current_chunk.append(word)
        current_length += word_length
    if current_chunk:
        chunk_texts.append(" ".join(current_chunk))
    return chunk_texts


def _escape_for_markdown(text):
    """Escape plain ``text`` for ``st.markdown(..., unsafe_allow_html=True)``.

    HTML-escaping stops text taken from an untrusted page from being
    interpreted as markup; ``$`` is backslash-escaped so Streamlit's Markdown
    does not treat it as a LaTeX delimiter.
    """
    return html.escape(text).replace('$', '\\$')


def _predict_words(chunk, confidence_threshold, max_model_tokens=512):
    """Classify one chunk and group its sub-token predictions by word.

    Returns a dict keyed by the character offset (within ``chunk``) at which
    each word starts. Each value holds:
      * 'prediction': 1 if any sub-token was predicted as class 1 with a
        confidence >= ``confidence_threshold`` percent, else 0;
      * 'confidence': the maximum confidence (percent) over the word's
        sub-tokens, whatever class was predicted;
      * 'subtokens': a list of ``(start, end, text)`` spans into ``chunk``.
    """
    input_ids, token_offsets = tokenize_with_indices(chunk)
    if len(input_ids) > max_model_tokens:
        # Only reachable when a single word exceeds the window on its own
        # (see _chunk_words). Keep the leading tokens plus the final special
        # token; the untruncated tail is still rendered as plain text by the
        # caller.
        input_ids = input_ids[:max_model_tokens - 1] + input_ids[-1:]
        token_offsets = token_offsets[:max_model_tokens - 1] + token_offsets[-1:]

    input_ids_tensor = torch.tensor([input_ids], device=device)
    with torch.no_grad():
        logits = model(input_ids_tensor).logits[0]
    predictions = torch.argmax(logits, dim=-1).tolist()
    softmax_scores = F.softmax(logits, dim=-1).tolist()

    word_info = {}
    last_index = len(token_offsets) - 1
    for idx, (start, end) in enumerate(token_offsets):
        # Skip the first/last positions: the special tokens added by
        # add_special_tokens=True.
        if idx == 0 or idx == last_index:
            continue
        # Chunks are single-space joined, so a word starts right after the
        # previous space (or at the chunk start).
        word_start = start
        while word_start > 0 and chunk[word_start - 1] != ' ':
            word_start -= 1
        if word_start not in word_info:
            word_info[word_start] = {'prediction': 0, 'confidence': 0.0, 'subtokens': []}
        info = word_info[word_start]
        confidence_percentage = softmax_scores[idx][predictions[idx]] * 100
        if predictions[idx] == 1 and confidence_percentage >= confidence_threshold:
            info['prediction'] = 1
        info['confidence'] = max(info['confidence'], confidence_percentage)
        info['subtokens'].append((start, end, chunk[start:end]))
    return word_info


def process_text(inputs: str, confidence_threshold: float):
    """Predict link-worthy words in ``inputs`` and render them for Streamlit.

    The text is split on whitespace, re-joined with single spaces, packed
    into chunks of at most 510 tokens (512 minus the two special tokens), and
    classified token by token. A word is highlighted when any of its
    sub-tokens is predicted as class 1 with a softmax confidence (percent) of
    at least ``confidence_threshold``.

    Args:
        inputs: Raw text to analyse (may come from an arbitrary web page).
        confidence_threshold: Minimum confidence, in percent, for a class-1
            prediction to mark its word as highlighted.

    Returns:
        ``(reconstructed_text, df_tokens)``:
          * ``reconstructed_text`` -- string for
            ``st.markdown(..., unsafe_allow_html=True)``. Text is
            HTML-escaped with ``$`` backslash-escaped, and highlighted
            sub-tokens are wrapped in ``<span>`` elements.
          * ``df_tokens`` -- one row per sub-token with columns Word,
            Prediction, Confidence, Start, End. Start/End are character
            offsets into the whitespace-normalised text
            ``" ".join(inputs.split())``.
    """
    max_model_tokens = 512
    df_data = {
        'Word': [],
        'Prediction': [],
        'Confidence': [],
        'Start': [],
        'End': [],
    }
    words = inputs.split()
    if not words:
        # Nothing to classify; skip the model call entirely.
        return "", pd.DataFrame(df_data)

    rendered_parts = []
    original_position_offset = 0
    for chunk_index, chunk in enumerate(_chunk_words(words, max_model_tokens - 2)):
        if chunk_index:
            # Restore the single space between chunks so the last word of one
            # chunk is not glued to the first word of the next.
            rendered_parts.append(" ")
        word_info = _predict_words(chunk, confidence_threshold, max_model_tokens)
        last_end = 0
        for word_start in sorted(word_info):
            word_data = word_info[word_start]
            for subtoken_start, subtoken_end, subtoken_text in word_data['subtokens']:
                if last_end < subtoken_start:
                    rendered_parts.append(_escape_for_markdown(chunk[last_end:subtoken_start]))
                escaped_subtoken_text = _escape_for_markdown(subtoken_text)
                if word_data['prediction'] == 1:
                    rendered_parts.append(f"<span style='background-color: rgba(0, 255, 0); display: inline;'>{escaped_subtoken_text}</span>")
                else:
                    rendered_parts.append(escaped_subtoken_text)
                last_end = subtoken_end
                # The DataFrame is not rendered as Markdown: store raw text.
                df_data['Word'].append(subtoken_text)
                df_data['Prediction'].append(word_data['prediction'])
                df_data['Confidence'].append(word_data['confidence'])
                df_data['Start'].append(subtoken_start + original_position_offset)
                df_data['End'].append(subtoken_end + original_position_offset)
        rendered_parts.append(_escape_for_markdown(chunk[last_end:]))
        # +1 for the single space separating this chunk from the next one.
        original_position_offset += len(chunk) + 1

    return "".join(rendered_parts), pd.DataFrame(df_data)
# Streamlit Interface
def _show_results(text, threshold):
    """Classify ``text`` and render the highlighted text plus token table."""
    highlighted_text, df_tokens = process_text(text, threshold)
    st.markdown(highlighted_text, unsafe_allow_html=True)
    st.dataframe(df_tokens)


st.set_page_config(layout="wide")
st.title('SEO by DEJAN: LinkBERT')
confidence_threshold = st.slider('Confidence Threshold', 50, 100, 50)
text_tab, url_tab = st.tabs(["Text Input", "URL Input"])

with text_tab:
    user_input = st.text_area("Enter text to process:")
    if st.button('Process Text'):
        _show_results(user_input, confidence_threshold)

with url_tab:
    url_input = st.text_input("Enter URL to process:")
    if st.button('Fetch and Process'):
        content = fetch_and_extract_content(url_input)
        if not content:
            st.error("Could not fetch content from the URL. Please check the URL and try again.")
        else:
            _show_results(content, confidence_threshold)
|