File size: 5,378 Bytes
e1234b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
266743a
e1234b6
 
 
266743a
e1234b6
266743a
e1234b6
 
266743a
e1234b6
d598c67
e1234b6
 
 
266743a
e1234b6
 
d598c67
e1234b6
 
 
 
 
 
 
86e91e4
e1234b6
7efe121
e1234b6
 
 
 
86e91e4
e1234b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import html

import pandas as pd
import streamlit as st
import torch
import torch.nn.functional as F
import trafilatura
from torch.nn.functional import softmax
from transformers import XLMRobertaTokenizerFast, AutoModelForTokenClassification

# Load model and tokenizer.
# Streamlit re-executes this whole script on every widget interaction, so the
# loading is wrapped in st.cache_resource: the tokenizer and model are created
# once per server process and the same objects are reused on every rerun.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# show_spinner=False: no spinner element is emitted before st.set_page_config.
@st.cache_resource(show_spinner=False)
def _load_tokenizer_and_model():
    """Load the XLM-R tokenizer and the LinkBERT token classifier onto ``device``."""
    loaded_tokenizer = XLMRobertaTokenizerFast.from_pretrained("xlm-roberta-large")
    loaded_model = AutoModelForTokenClassification.from_pretrained("dejanseo/LinkBERT-XL").to(device)
    loaded_model.eval()  # inference mode (disables dropout)
    return loaded_tokenizer, loaded_model


tokenizer, model = _load_tokenizer_and_model()

# Functions

def tokenize_with_indices(text: str, max_length: int = 512):
    """Tokenize *text* and return token ids together with character offsets.

    Special tokens are added, so the first and last entries of both returned
    lists belong to them rather than to *text*.

    The encoding is truncated to *max_length* tokens, which matches the
    512-token window that ``process_text`` budgets for. Without truncation, a
    chunk whose in-context tokenization exceeds the window overflows the
    model's position embeddings at inference time. An example is a single word
    with more sub-tokens than the chunk budget. With truncation, text past the
    cut-off simply has no tokens in the result.

    Args:
        text: The string to tokenize.
        max_length: Maximum number of tokens, special tokens included.

    Returns:
        ``(input_ids, offset_mapping)``: a list of token ids and a parallel
        list of ``(start, end)`` character offsets into *text*.
    """
    encoded = tokenizer(
        text,
        add_special_tokens=True,
        return_offsets_mapping=True,
        truncation=True,
        max_length=max_length,
    )
    return encoded['input_ids'], encoded['offset_mapping']

def fetch_and_extract_content(url: str):
    """Download *url* and return its main text as extracted by trafilatura.

    Comments and tables are left out of the extraction.

    Returns:
        ``None`` when the download yields nothing. Otherwise, whatever
        ``trafilatura.extract`` returns for the downloaded document (presumably
        ``None`` when nothing can be extracted; see the trafilatura docs).
    """
    page = trafilatura.fetch_url(url)
    if not page:
        return None
    return trafilatura.extract(
        page,
        include_comments=False,
        include_tables=False,
    )

def _split_into_chunks(words, max_tokens):
    """Greedily pack *words* into space-joined chunks of at most *max_tokens* sub-tokens.

    Token counts come from ``tokenizer.tokenize`` applied to each word on its
    own. A word that alone exceeds *max_tokens* becomes a chunk by itself.
    Empty chunks are never produced, so an empty *words* list yields ``[]``.
    """
    chunks = []
    current_words = []
    current_length = 0
    for word in words:
        word_length = len(tokenizer.tokenize(word))  # tokenize each word once
        if current_words and current_length + word_length > max_tokens:
            chunks.append(" ".join(current_words))
            current_words, current_length = [], 0
        current_words.append(word)
        current_length += word_length
    if current_words:
        chunks.append(" ".join(current_words))
    return chunks


def _escape_for_markdown(text):
    """Escape raw *text* for embedding in ``st.markdown(..., unsafe_allow_html=True)``.

    ``&``, ``<`` and ``>`` are HTML-escaped so input text cannot inject markup.
    ``$`` is backslash-escaped so it is not treated as a math delimiter.
    """
    return html.escape(text, quote=False).replace('$', '\\$')


def process_text(inputs: str, confidence_threshold: float):
    """Run the token classifier over *inputs* and build a highlighted rendering.

    The text is whitespace-normalised (``inputs.split()``) and packed into
    chunks that fit the model window. Each chunk is classified token by token.
    A word is highlighted when any of its sub-tokens is predicted as class 1
    with a softmax confidence (in percent) of at least *confidence_threshold*.

    Args:
        inputs: Raw text to analyse.
        confidence_threshold: Minimum confidence, in percent, that a class-1
            prediction needs in order to mark its word.

    Returns:
        A ``(reconstructed_text, df_tokens)`` pair.

        ``reconstructed_text`` is meant for
        ``st.markdown(..., unsafe_allow_html=True)``. Sub-tokens of marked
        words are wrapped in ``<span>`` tags, all raw text is HTML-escaped, and
        ``$`` is backslash-escaped.

        ``df_tokens`` has one row per sub-token, with these columns:

        - ``Word``: the raw sub-token text.
        - ``Prediction``: 0/1 for its word.
        - ``Confidence``: the word's maximum predicted-class confidence, in
          percent.
        - ``Start`` / ``End``: character offsets into
          ``" ".join(inputs.split())``.
    """
    max_chunk_length = 512 - 2  # leave room for the two special tokens
    chunk_texts = _split_into_chunks(inputs.split(), max_chunk_length)

    df_data = {
        'Word': [],
        'Prediction': [],
        'Confidence': [],
        'Start': [],
        'End': []
    }
    reconstructed_parts = []
    # Start of the current chunk within " ".join(inputs.split()).
    original_position_offset = 0

    for chunk_index, chunk in enumerate(chunk_texts):
        if chunk_index > 0:
            # Chunks were cut at whitespace. Restore the separating space so the
            # last word of one chunk is not glued to the first word of the next.
            reconstructed_parts.append(" ")

        input_ids, token_offsets = tokenize_with_indices(chunk)
        input_ids_tensor = torch.tensor(input_ids).unsqueeze(0).to(device)
        with torch.no_grad():
            logits = model(input_ids_tensor).logits[0]  # drop the batch dimension
        predictions = torch.argmax(logits, dim=-1).tolist()
        softmax_scores = F.softmax(logits, dim=-1).tolist()

        # Group sub-tokens by the character offset at which their word starts.
        word_info = {}
        last_token_idx = len(token_offsets) - 1
        for idx, (start, end) in enumerate(token_offsets):
            if idx == 0 or idx == last_token_idx:
                continue  # special tokens added by the tokenizer
            word_start = start
            while word_start > 0 and chunk[word_start - 1] != ' ':
                word_start -= 1
            info = word_info.setdefault(
                word_start, {'prediction': 0, 'confidence': 0.0, 'subtokens': []}
            )
            confidence_percentage = softmax_scores[idx][predictions[idx]] * 100
            # One confident class-1 sub-token is enough to mark the whole word.
            if predictions[idx] == 1 and confidence_percentage >= confidence_threshold:
                info['prediction'] = 1
            info['confidence'] = max(info['confidence'], confidence_percentage)
            info['subtokens'].append((start, end, chunk[start:end]))

        last_end = 0
        for word_start in sorted(word_info):
            word_data = word_info[word_start]
            for subtoken_start, subtoken_end, subtoken_text in word_data['subtokens']:
                if last_end < subtoken_start:
                    reconstructed_parts.append(_escape_for_markdown(chunk[last_end:subtoken_start]))
                escaped_subtoken_text = _escape_for_markdown(subtoken_text)
                if word_data['prediction'] == 1:
                    reconstructed_parts.append(f"<span style='background-color: rgba(0, 255, 0); display: inline;'>{escaped_subtoken_text}</span>")
                else:
                    reconstructed_parts.append(escaped_subtoken_text)
                last_end = subtoken_end

                # The table is not rendered as markdown, so store the raw text.
                df_data['Word'].append(subtoken_text)
                df_data['Prediction'].append(word_data['prediction'])
                df_data['Confidence'].append(word_data['confidence'])
                df_data['Start'].append(subtoken_start + original_position_offset)
                df_data['End'].append(subtoken_end + original_position_offset)

        reconstructed_parts.append(_escape_for_markdown(chunk[last_end:]))
        # Advance once per chunk; the +1 accounts for the separating space.
        original_position_offset += len(chunk) + 1

    df_tokens = pd.DataFrame(df_data)
    return "".join(reconstructed_parts), df_tokens

# Streamlit Interface

st.set_page_config(layout="wide")
st.title('SEO by DEJAN: LinkBERT')

# Minimum confidence (percent) a class-1 prediction needs before its word is
# highlighted; passed straight through to process_text.
confidence_threshold = st.slider('Confidence Threshold', 50, 100, 50)


def _render_results(text: str) -> None:
    """Run process_text on *text* and show the highlighted text and token table."""
    highlighted_text, df_tokens = process_text(text, confidence_threshold)
    # unsafe_allow_html is required for the <span> highlight tags process_text emits.
    st.markdown(highlighted_text, unsafe_allow_html=True)
    st.dataframe(df_tokens)


tab1, tab2 = st.tabs(["Text Input", "URL Input"])

with tab1:
    user_input = st.text_area("Enter text to process:")
    if st.button('Process Text'):
        if user_input.strip():
            _render_results(user_input)
        else:
            st.warning("Please enter some text to process.")

with tab2:
    url_input = st.text_input("Enter URL to process:")
    if st.button('Fetch and Process'):
        url = url_input.strip()  # tolerate pasted leading/trailing whitespace
        if not url:
            st.warning("Please enter a URL to process.")
        else:
            content = fetch_and_extract_content(url)
            if content:
                _render_results(content)
            else:
                st.error("Could not fetch content from the URL. Please check the URL and try again.")