dejanseo committed on
Commit
e1234b6
1 Parent(s): 3acf4ee

Create app.py

Files changed (1)
app.py +136 -0
app.py ADDED
@@ -0,0 +1,136 @@
import streamlit as st
import torch
import torch.nn.functional as F
from transformers import XLMRobertaTokenizerFast, AutoModelForTokenClassification
import pandas as pd
import trafilatura

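# Dependencies, as inferred from the imports above: streamlit, torch,
# transformers, pandas and trafilatura.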
# Load model and tokenizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = XLMRobertaTokenizerFast.from_pretrained("xlm-roberta-large")
model = AutoModelForTokenClassification.from_pretrained("dejanseo/LinkBERT-XL").to(device)
model.eval()

# Functions

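# tokenize_with_indices() returns the token ids for a string together with the
# character offsets of each token, so that per-sub-token predictions can later
# be mapped back onto spans of the original text.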
def tokenize_with_indices(text: str):
    encoded = tokenizer.encode_plus(text, return_offsets_mapping=True, add_special_tokens=True)
    return encoded['input_ids'], encoded['offset_mapping']

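# fetch_and_extract_content() downloads a URL with trafilatura and returns the
# extracted main text (comments and tables excluded), or None if the page
# could not be fetched.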
def fetch_and_extract_content(url: str):
    downloaded = trafilatura.fetch_url(url)
    if downloaded:
        content = trafilatura.extract(downloaded, include_comments=False, include_tables=False)
        return content
    return None

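# process_text() splits the input into chunks that fit the 512-token context
# window (minus the two special tokens), runs token classification on each
# chunk, and aggregates sub-token predictions per word. Class index 1 is
# treated as the positive label (presumably link/anchor-text candidates, given
# the LinkBERT model); words predicted positive at or above the confidence
# threshold are wrapped in a highlight <span>. Returns the highlighted HTML
# and a per-sub-token DataFrame.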
def process_text(inputs: str, confidence_threshold: float):
    # Split the input into whitespace-delimited words and pack them into
    # chunks that fit the model's 512-token window (2 slots reserved for
    # the special tokens).
    max_chunk_length = 512 - 2
    words = inputs.split()
    chunk_texts = []
    current_chunk = []
    current_length = 0
    for word in words:
        if len(tokenizer.tokenize(word)) + current_length > max_chunk_length:
            chunk_texts.append(" ".join(current_chunk))
            current_chunk = [word]
            current_length = len(tokenizer.tokenize(word))
        else:
            current_chunk.append(word)
            current_length += len(tokenizer.tokenize(word))
    chunk_texts.append(" ".join(current_chunk))

    df_data = {
        'Word': [],
        'Prediction': [],
        'Confidence': [],
        'Start': [],
        'End': []
    }
    reconstructed_text = ""
    original_position_offset = 0

    for chunk in chunk_texts:
        input_ids, token_offsets = tokenize_with_indices(chunk)

        input_ids_tensor = torch.tensor(input_ids).unsqueeze(0).to(device)
        with torch.no_grad():
            outputs = model(input_ids_tensor)
            logits = outputs.logits
            predictions = torch.argmax(logits, dim=-1).squeeze().tolist()
            softmax_scores = F.softmax(logits, dim=-1).squeeze().tolist()

        # Group sub-token predictions by the character position at which their
        # word starts.
        word_info = {}

        for idx, (start, end) in enumerate(token_offsets):
            # Skip the special tokens at the start and end of the sequence.
            if idx == 0 or idx == len(token_offsets) - 1:
                continue

            # Walk back to the beginning of the word this sub-token belongs to.
            word_start = start
            while word_start > 0 and chunk[word_start - 1] != ' ':
                word_start -= 1

            if word_start not in word_info:
                word_info[word_start] = {'prediction': 0, 'confidence': 0.0, 'subtokens': []}

            confidence_percentage = softmax_scores[idx][predictions[idx]] * 100

            if predictions[idx] == 1 and confidence_percentage >= confidence_threshold:
                word_info[word_start]['prediction'] = 1

            word_info[word_start]['confidence'] = max(word_info[word_start]['confidence'], confidence_percentage)
            word_info[word_start]['subtokens'].append((start, end, chunk[start:end]))

        # Rebuild the chunk, wrapping positively classified words in a
        # highlight <span> and recording one DataFrame row per sub-token.
        last_end = 0
        for word_start in sorted(word_info.keys()):
            word_data = word_info[word_start]
            for subtoken_start, subtoken_end, subtoken_text in word_data['subtokens']:
                if last_end < subtoken_start:
                    reconstructed_text += chunk[last_end:subtoken_start]
                if word_data['prediction'] == 1:
                    reconstructed_text += f"<span style='background-color: rgba(0, 255, 0); display: inline;'>{subtoken_text}</span>"
                else:
                    reconstructed_text += subtoken_text
                last_end = subtoken_end

                df_data['Word'].append(subtoken_text)
                df_data['Prediction'].append(word_data['prediction'])
                df_data['Confidence'].append(word_data['confidence'])
                df_data['Start'].append(subtoken_start + original_position_offset)
                df_data['End'].append(subtoken_end + original_position_offset)

        original_position_offset += len(chunk) + 1

        reconstructed_text += chunk[last_end:]

    df_tokens = pd.DataFrame(df_data)
    return reconstructed_text, df_tokens

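# Illustrative usage outside Streamlit (hypothetical input string):
#   html, df = process_text("DEJAN released a new crawler last week.", 75)
#   html -> the input with predicted words wrapped in green <span> tags
#   df   -> one row per sub-token with Prediction, Confidence, Start, End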
# Streamlit Interface

st.set_page_config(layout="wide")
st.title('Text Processing with XLM-Roberta and LinkBERT-XL')

# Confidence threshold (in percent) applied to positive predictions in process_text().
confidence_threshold = st.slider('Confidence Threshold', 50, 100, 75)

tab1, tab2 = st.tabs(["Text Input", "URL Input"])

with tab1:
    user_input = st.text_area("Enter text to process:", "Type here...")
    if st.button('Process Text'):
        highlighted_text, df_tokens = process_text(user_input, confidence_threshold)
        st.markdown(highlighted_text, unsafe_allow_html=True)
        st.dataframe(df_tokens)

with tab2:
    url_input = st.text_input("Enter URL to process:")
    if st.button('Fetch and Process'):
        content = fetch_and_extract_content(url_input)
        if content:
            highlighted_text, df_tokens = process_text(content, confidence_threshold)
            st.markdown(highlighted_text, unsafe_allow_html=True)
            st.dataframe(df_tokens)
        else:
            st.error("Could not fetch content from the URL. Please check the URL and try again.")
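# To run the app locally, use the standard Streamlit invocation: streamlit run app.py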