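"""Streamlit app for running GLiNER named-entity recognition over one column
of an uploaded CSV or Excel file.

Assumption (not verified here): `gliner_file` is a local helper module whose
`run_ner(model, text, labels)` returns a dict mapping each label to the list
of entity strings found in the text.
"""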
import os
import csv
import streamlit as st
import polars as pl
from io import BytesIO, StringIO
from gliner import GLiNER
from gliner_file import run_ner
import time

st.set_page_config(
    page_title="GliNER", page_icon="🧊", layout="wide", initial_sidebar_state="expanded"
)


# Load data from an uploaded Excel or CSV file; cached so Streamlit reruns do not re-parse it
@st.cache_data
def load_data(file):
    _, file_ext = os.path.splitext(file.name)
    if file_ext.lower() in [".xls", ".xlsx"]:
        return pl.read_excel(file)
    elif file_ext.lower() == ".csv":
        file.seek(0)  # Go back to the beginning of the file
        raw = file.read()
        # Decode the whole file once: try UTF-8 first, then fall back to latin1.
        # (Decoding only a fixed-size sample can fail when a multi-byte character
        # straddles the read boundary, which would wrongly flag valid UTF-8.)
        try:
            text = raw.decode("utf-8")
        except UnicodeDecodeError:
            text = raw.decode("latin1")

        dialect = csv.Sniffer().sniff(text[:4096])  # Detect the delimiter

        return pl.read_csv(
            StringIO(text),
            separator=dialect.delimiter,
            truncate_ragged_lines=True,
            ignore_errors=True,
        )
    else:
        raise ValueError("The uploaded file must be a CSV or Excel file.")


# Function to perform NER and update the UI
def perform_ner(filtered_df, selected_column, labels_list):
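    # One output list per label; each processed row appends a comma-joined
    # string of the entities found for that label.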
    ner_results_dict = {label: [] for label in labels_list}

    progress_bar = st.progress(0)
    progress_text = st.empty()

    start_time = time.time()  # Record start time for total runtime

    # Iterate over the selected column directly; pandas itertuples() would
    # mangle column names that are not valid Python identifiers.
    for index, text_to_analyze in enumerate(filtered_df[selected_column], 1):
        iteration_start_time = time.time()  # Start time for this iteration

        if st.session_state.stop_processing:
            progress_text.text("Process stopped by the user.")
            break

        ner_results = run_ner(
            st.session_state.gliner_model, text_to_analyze, labels_list
        )

        for label in labels_list:
            texts = ner_results.get(label, [])
            concatenated_texts = ", ".join(texts)
            ner_results_dict[label].append(concatenated_texts)

        progress = index / filtered_df.height
        progress_bar.progress(progress)

        iteration_time = (
            time.time() - iteration_start_time
        )  # Calculate runtime for this iteration
        total_time = time.time() - start_time  # Calculate total elapsed time so far

        progress_text.text(
            f"Progress: {index}/{filtered_df.height} - {progress * 100:.0f}% (Iteration: {iteration_time:.2f}s, Total: {total_time:.2f}s)"
        )

    total_execution_time = time.time() - start_time  # Total runtime

    if not st.session_state.stop_processing:
        progress_text.text(
            f"Processing complete! Total execution time: {total_execution_time:.2f}s"
        )

    # Attach one column per label; pad with empty strings if processing was
    # stopped early so every new column matches the dataframe height.
    for label, texts in ner_results_dict.items():
        texts = texts + [""] * (filtered_df.height - len(texts))
        filtered_df = filtered_df.with_columns(pl.Series(name=label, values=texts))

    return filtered_df


def main():
    st.title("Online NER with GliNER")
    st.markdown("Prototype v0.1")

    # Ensure the stop_processing flag is initialized
    if "stop_processing" not in st.session_state:
        st.session_state.stop_processing = False

    uploaded_file = st.sidebar.file_uploader("Choose a file")
    if uploaded_file is None:
        st.warning("Please upload a file.")
        return

    try:
        df = load_data(uploaded_file)
    except ValueError as e:
        st.error(str(e))
        return

    selected_column = st.selectbox("Select the column for NER:", df.columns, index=0)
    filter_text = st.text_input("Filter column by input text", "")
    ner_labels = st.text_input(
        "Enter all your different labels, separated by a comma", ""
    )

    # Case-insensitive substring filter, matched literally so characters such
    # as '(' or '.' in the filter text are not treated as regex syntax.
    filtered_df = (
        df.filter(
            pl.col(selected_column)
            .str.to_lowercase()
            .str.contains(filter_text.lower(), literal=True)
        )
        if filter_text
        else df
    )
    st.dataframe(filtered_df)

    if st.button("Start NER"):
        if not ner_labels:
            st.warning("Please enter some labels for NER.")
        else:
            # Load GLiNER model if not already loaded
            if "gliner_model" not in st.session_state:
                with st.spinner("Loading GLiNER model... Please wait."):
                    st.session_state.gliner_model = GLiNER.from_pretrained(
                        "urchade/gliner_largev2"
                    )
                    st.session_state.gliner_model.eval()

            # Strip whitespace around each label and drop empty entries; reset
            # the stop flag so an earlier "Stop Processing" click does not
            # abort this new run immediately.
            labels_list = [lbl.strip() for lbl in ner_labels.split(",") if lbl.strip()]
            st.session_state.stop_processing = False
            updated_df = perform_ner(filtered_df, selected_column, labels_list)
            st.dataframe(updated_df)

            def to_excel(df):
                output = BytesIO()
                df.to_pandas().to_excel(output, index=False, engine="openpyxl")
                return output.getvalue()

            df_excel = to_excel(updated_df)
            st.download_button(
                label="📥 Download Excel",
                data=df_excel,
                file_name="ner_results.xlsx",
                mime="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
            )

    st.button(
        "Stop Processing",
        on_click=lambda: setattr(st.session_state, "stop_processing", True),
    )


if __name__ == "__main__":
    main()
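
# To launch the app locally (assuming this file is saved as app.py):
#   streamlit run app.py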