Olivier CARON committed on
Commit
6cd26a2
1 Parent(s): a6b90bc

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +141 -0
app.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os # Add this import to use os.path.splitext
2
+ import csv
3
+ import streamlit as st
4
+ import polars as pl
5
+ from io import BytesIO, StringIO
6
+ from gliner import GLiNER
7
+ from gliner_file import run_ner
8
+ import time
9
+
10
# Page-wide Streamlit settings: title, icon, wide layout, expanded sidebar.
st.set_page_config(
    page_title="GliNER",
    page_icon="🧊",
    layout="wide",
    initial_sidebar_state="expanded",
)
11
+
12
+ # Modified function to load data from either an Excel or CSV file
13
@st.cache_data
def load_data(file):
    """Load an uploaded Excel or CSV file into a Polars DataFrame.

    For CSV input the bytes are decoded as UTF-8, falling back to latin1
    when the content is not valid UTF-8, and the delimiter is guessed with
    ``csv.Sniffer``.

    Args:
        file: Binary file-like object exposing ``name``, ``seek`` and
            ``read`` (presumably a Streamlit ``UploadedFile`` -- confirm
            against the caller).

    Returns:
        The parsed ``pl.DataFrame``.

    Raises:
        ValueError: If the file extension is not ``.xls``, ``.xlsx`` or
            ``.csv``.
    """
    extension = os.path.splitext(file.name)[1].lower()

    if extension in ('.xls', '.xlsx'):
        return pl.read_excel(file)

    if extension != '.csv':
        raise ValueError("The uploaded file must be a CSV or Excel file.")

    # Decode the whole payload in one go. Decoding only a fixed-size byte
    # prefix can cut a multi-byte UTF-8 character in half and wrongly
    # trigger the latin1 fallback for a perfectly valid UTF-8 file.
    file.seek(0)
    raw_bytes = file.read()
    try:
        text = raw_bytes.decode('utf-8')
    except UnicodeDecodeError:
        # latin1 maps every byte to a character, so this cannot fail.
        text = raw_bytes.decode('latin1')

    # The sniffer raises csv.Error when it cannot infer a delimiter (for
    # example on a single-column file); fall back to a comma in that case
    # instead of letting an exception the caller does not handle escape.
    try:
        delimiter = csv.Sniffer().sniff(text[:4096]).delimiter
    except csv.Error:
        delimiter = ','

    return pl.read_csv(
        StringIO(text),
        separator=delimiter,
        truncate_ragged_lines=True,
        ignore_errors=True,
    )
43
+
44
+
45
+ # Function to perform NER and update the UI
46
def perform_ner(filtered_df, selected_column, labels_list):
    """Run NER on one column of a DataFrame and append one column per label.

    Each value of ``selected_column`` is passed to ``run_ner`` together with
    the model stored in ``st.session_state.gliner_model``. For every label,
    the entity texts returned for a row are joined with ``", "``. A progress
    bar and a status line are updated after every row, and the loop stops
    early when ``st.session_state.stop_processing`` is truthy.

    Args:
        filtered_df: Polars DataFrame holding the rows to analyse.
        selected_column: Name of the column whose values are analysed.
        labels_list: Entity labels to extract; duplicates are ignored.

    Returns:
        ``filtered_df`` with one extra string column per label. Rows that
        were not processed because the user stopped the run are null in
        those columns.
    """
    # Duplicate labels would append twice to the same result list and make
    # it longer than the frame, so keep the first occurrence of each.
    labels = list(dict.fromkeys(labels_list))
    results_by_label = {label: [] for label in labels}

    progress_bar = st.progress(0)
    progress_text = st.empty()

    total_rows = filtered_df.height
    stopped = False
    start_time = time.time()

    # Iterate the Polars column directly. Going through pandas itertuples()
    # renames columns that are not valid Python identifiers (spaces, leading
    # digits, ...), which broke attribute lookup by column name.
    column_values = filtered_df.get_column(selected_column)
    for index, text_to_analyze in enumerate(column_values, 1):
        iteration_start_time = time.time()

        if st.session_state.stop_processing:
            progress_text.text("Process stopped by the user.")
            stopped = True
            break

        if text_to_analyze is None:
            # Null cell: nothing to analyse. How run_ner treats None is not
            # visible from here, so the call is skipped.
            ner_results = {}
        else:
            ner_results = run_ner(st.session_state.gliner_model, text_to_analyze, labels)

        for label in labels:
            results_by_label[label].append(', '.join(ner_results.get(label, [])))

        progress = index / total_rows
        progress_bar.progress(progress)

        iteration_time = time.time() - iteration_start_time
        total_time = time.time() - start_time

        progress_text.text(f"Progress: {index}/{total_rows} - {progress * 100:.0f}% (Iteration: {iteration_time:.2f}s, Total: {total_time:.2f}s)")

    if not stopped:
        # Only report completion when the loop ran to the end; otherwise the
        # "stopped" message above must stay visible.
        total_execution_time = time.time() - start_time
        progress_text.text(f"Processing complete! Total execution time: {total_execution_time:.2f}s")

    for label, texts in results_by_label.items():
        # Pad unprocessed rows with nulls so the new column always matches
        # the frame height, even after an early stop.
        padding = [None] * (total_rows - len(texts))
        filtered_df = filtered_df.with_columns(
            pl.Series(name=label, values=texts + padding, dtype=pl.Utf8)
        )

    return filtered_df
86
+
87
def main():
    """Render the Streamlit UI: upload a file, filter it, run NER, export.

    The flow is: read the uploaded file with ``load_data``, let the user
    pick a column, an optional filter text and a comma-separated list of
    labels, run ``perform_ner`` on the filtered rows when "Start NER" is
    clicked, then show the result and offer it as an Excel download.
    """
    st.title("Online NER with GliNER")
    st.markdown("Prototype v0.1")

    # Ensure the stop_processing flag is initialized
    if 'stop_processing' not in st.session_state:
        st.session_state.stop_processing = False

    uploaded_file = st.sidebar.file_uploader("Choose a file")
    if uploaded_file is None:
        st.warning("Please upload a file.")
        return

    try:
        df = load_data(uploaded_file)
    except ValueError as e:
        st.error(str(e))
        return

    selected_column = st.selectbox("Select the column for NER:", df.columns, index=0)
    filter_text = st.text_input("Filter column by input text", "")
    ner_labels = st.text_input("Enter all your different labels, separated by a comma", "")

    if filter_text:
        # Case-insensitive *literal* substring match. Interpolating the raw
        # user text into a regex made inputs such as "(" or "[" fail with a
        # pattern error. The cast lets the filter work on non-string columns.
        matches = (
            pl.col(selected_column)
            .cast(pl.Utf8)
            .str.to_lowercase()
            .str.contains(filter_text.lower(), literal=True)
        )
        filtered_df = df.filter(matches)
    else:
        filtered_df = df
    st.dataframe(filtered_df)

    if st.button("Start NER"):
        # Trim whitespace around each label and drop empty entries so that
        # "person, place," yields ["person", "place"].
        labels_list = [label.strip() for label in ner_labels.split(",") if label.strip()]
        if not labels_list:
            st.warning("Please enter some labels for NER.")
        else:
            # The flag is only ever set to True by the Stop button; clear it
            # here, otherwise one earlier stop would abort every later run
            # at its first row.
            st.session_state.stop_processing = False

            # Load GLiNER model if not already loaded
            if 'gliner_model' not in st.session_state:
                with st.spinner('Loading GLiNER model... Please wait.'):
                    st.session_state.gliner_model = GLiNER.from_pretrained("urchade/gliner_largev2")
                    st.session_state.gliner_model.eval()

            updated_df = perform_ner(filtered_df, selected_column, labels_list)
            st.dataframe(updated_df)

            def to_excel(df):
                """Serialize a Polars DataFrame to .xlsx bytes via pandas/openpyxl."""
                output = BytesIO()
                df.to_pandas().to_excel(output, index=False, engine='openpyxl')
                return output.getvalue()

            df_excel = to_excel(updated_df)
            st.download_button(label="📥 Download Excel",
                               data=df_excel,
                               file_name="ner_results.xlsx",
                               mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet")

    st.button("Stop Processing", on_click=lambda: setattr(st.session_state, 'stop_processing', True))
139
+
140
# Launch the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    main()