"""Streamlit front end for a token-classification model.

The app takes text (typed in, or extracted from a URL with trafilatura),
splits it into chunks that fit the model's input window, classifies every
token, and shows the text with the flagged words highlighted next to a table
of per-token predictions.
"""

import html
from typing import List, Optional, Tuple

import pandas as pd
import streamlit as st
import torch
import torch.nn.functional as F
import trafilatura
from torch.nn.functional import softmax  # kept from the original import list
from transformers import AutoModelForTokenClassification, XLMRobertaTokenizerFast

# Input window used for chunking; two positions are reserved for the special
# tokens the tokenizer adds around every chunk.
MAX_MODEL_TOKENS = 512
MAX_CHUNK_TOKENS = MAX_MODEL_TOKENS - 2

# Load model and tokenizer
# NOTE(review): this runs at import time, and Streamlit re-executes the script
# on each interaction; consider caching the loaded objects if reloads are slow.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = XLMRobertaTokenizerFast.from_pretrained("xlm-roberta-large")
model = AutoModelForTokenClassification.from_pretrained("dejanseo/LinkBERT-XL").to(device)
model.eval()


# Functions
def tokenize_with_indices(text: str):
    """Tokenize *text* and return ``(input_ids, offset_mapping)``.

    Special tokens are included, so the first and last entries of both lists
    do not correspond to characters of *text*. Input longer than
    ``MAX_MODEL_TOKENS`` is truncated so the model is never fed a sequence
    beyond the window the chunking logic assumes.
    """
    encoded = tokenizer(
        text,
        return_offsets_mapping=True,
        add_special_tokens=True,
        truncation=True,
        max_length=MAX_MODEL_TOKENS,
    )
    return encoded["input_ids"], encoded["offset_mapping"]


def fetch_and_extract_content(url: str) -> Optional[str]:
    """Download *url* and return its extracted main text.

    Returns ``None`` when the download yields nothing; otherwise returns
    whatever ``trafilatura.extract`` produces (comments and tables excluded).
    """
    downloaded = trafilatura.fetch_url(url)
    if not downloaded:
        return None
    return trafilatura.extract(downloaded, include_comments=False, include_tables=False)


def _split_into_chunks(text: str, max_tokens: int) -> List[str]:
    """Split *text* on whitespace into space-joined chunks of at most *max_tokens* tokens."""
    chunks: List[str] = []
    current_words: List[str] = []
    current_length = 0
    for word in text.split():
        # Tokenize each word once; the count is needed for both the check and the tally.
        word_length = len(tokenizer.tokenize(word))
        # Only flush a non-empty chunk: an over-long first word must not emit
        # an empty chunk (it would shift every later character offset).
        if current_words and current_length + word_length > max_tokens:
            chunks.append(" ".join(current_words))
            current_words = []
            current_length = 0
        current_words.append(word)
        current_length += word_length
    if current_words:
        chunks.append(" ".join(current_words))
    return chunks


def _predict_tokens(input_ids) -> Tuple[list, list]:
    """Run the model on one chunk and return per-token ``(class ids, class probabilities)``."""
    input_tensor = torch.tensor(input_ids).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(input_tensor).logits
    # Index the single batch element explicitly instead of squeeze(), which
    # would also collapse any other dimension of size one.
    token_logits = logits[0]
    predictions = torch.argmax(token_logits, dim=-1).tolist()
    probabilities = F.softmax(token_logits, dim=-1).tolist()
    return predictions, probabilities


def process_text(inputs: str, confidence_threshold: float) -> Tuple[str, pd.DataFrame]:
    """Classify every token of *inputs* and build the display text and token table.

    Args:
        inputs: Text to analyse. Whitespace runs are normalised to single
            spaces before processing.
        confidence_threshold: Minimum confidence, in percent, that a class-1
            token prediction needs for its word to be flagged.

    Returns:
        A ``(highlighted_html, df_tokens)`` pair. ``highlighted_html`` is the
        normalised text, HTML-escaped, with each sub-token of a flagged word
        wrapped in ``<mark>``. ``df_tokens`` has one row per sub-token with
        the columns ``Word``, ``Prediction``, ``Confidence``, ``Start`` and
        ``End``; the offsets index into the whitespace-normalised text.
        Empty or whitespace-only input yields ``""`` and an empty table.
    """
    df_data = {"Word": [], "Prediction": [], "Confidence": [], "Start": [], "End": []}
    rendered_chunks: List[str] = []
    position_offset = 0

    for chunk in _split_into_chunks(inputs, MAX_CHUNK_TOKENS):
        input_ids, token_offsets = tokenize_with_indices(chunk)
        predictions, probabilities = _predict_tokens(input_ids)

        # Group sub-tokens by the start offset of the word they belong to.
        word_info = {}
        last_index = len(token_offsets) - 1
        for idx, (start, end) in enumerate(token_offsets):
            if idx == 0 or idx == last_index:
                continue  # special tokens at both ends carry no text

            # Walk back to the preceding space to find where the word begins.
            word_start = start
            while word_start > 0 and chunk[word_start - 1] != " ":
                word_start -= 1

            entry = word_info.setdefault(
                word_start, {"prediction": 0, "confidence": 0.0, "subtokens": []}
            )
            # Confidence of whichever class the model picked for this token.
            confidence_percentage = probabilities[idx][predictions[idx]] * 100
            # One confident class-1 sub-token flags the whole word.
            if predictions[idx] == 1 and confidence_percentage >= confidence_threshold:
                entry["prediction"] = 1
            entry["confidence"] = max(entry["confidence"], confidence_percentage)
            entry["subtokens"].append((start, end, chunk[start:end]))

        pieces: List[str] = []
        last_end = 0
        for word_start in sorted(word_info):
            word_data = word_info[word_start]
            for sub_start, sub_end, sub_text in word_data["subtokens"]:
                if last_end < sub_start:
                    # Characters between tokens (the separating spaces).
                    pieces.append(html.escape(chunk[last_end:sub_start], quote=False))

                # The result is rendered as raw HTML, so text that may come
                # from a fetched page is escaped before any markup is added.
                safe_text = html.escape(sub_text, quote=False)
                if word_data["prediction"] == 1:
                    pieces.append(f"<mark>{safe_text}</mark>")
                else:
                    pieces.append(safe_text)
                last_end = sub_end

                df_data["Word"].append(sub_text)
                df_data["Prediction"].append(word_data["prediction"])
                df_data["Confidence"].append(word_data["confidence"])
                df_data["Start"].append(sub_start + position_offset)
                df_data["End"].append(sub_end + position_offset)

        # Anything after the last token (e.g. text dropped by truncation).
        pieces.append(html.escape(chunk[last_end:], quote=False))
        rendered_chunks.append("".join(pieces))
        # +1 accounts for the space that joins this chunk to the next one.
        position_offset += len(chunk) + 1

    # Join with the same single space the offsets assume, so words at chunk
    # boundaries are not glued together.
    return " ".join(rendered_chunks), pd.DataFrame(df_data)


def _show_results(text: str, threshold: float) -> None:
    """Process *text* and render the highlighted text followed by the token table."""
    highlighted_text, df_tokens = process_text(text, threshold)
    st.markdown(highlighted_text, unsafe_allow_html=True)
    st.dataframe(df_tokens)


# Streamlit Interface
st.set_page_config(layout="wide")
st.title('Text Processing with XLM-Roberta and LinkBERT-XL')

confidence_threshold = st.slider('Confidence Threshold', 50, 100, 75)

tab1, tab2 = st.tabs(["Text Input", "URL Input"])

with tab1:
    user_input = st.text_area("Enter text to process:", "Type here...")
    if st.button('Process Text'):
        _show_results(user_input, confidence_threshold)

with tab2:
    url_input = st.text_input("Enter URL to process:")
    if st.button('Fetch and Process'):
        content = fetch_and_extract_content(url_input)
        if content:
            _show_results(content, confidence_threshold)
        else:
            st.error("Could not fetch content from the URL. Please check the URL and try again.")