File size: 10,751 Bytes
974f253
 
3cc9efa
974f253
3cc9efa
974f253
 
3cc9efa
 
974f253
 
 
 
3cc9efa
 
974f253
 
 
 
 
 
77d2366
73b392d
 
3cc9efa
73b392d
 
3cc9efa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203aa9d
 
3cc9efa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203aa9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3cc9efa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203aa9d
3cc9efa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203aa9d
 
3cc9efa
 
203aa9d
 
3cc9efa
 
 
 
203aa9d
 
 
 
3cc9efa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203aa9d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
import fileinput
import os

import pydantic

# Directory in which the Haystack package is expected to live: the sibling of
# the installed pydantic package (i.e. the same site-packages directory).
# The location is derived from pydantic's own package path with os.path
# instead of splitting the path string on 'pydantic', because a split cuts at
# the FIRST occurrence of that substring and therefore yields a wrong
# directory whenever an earlier path component (for example a virtualenv
# folder name) also contains 'pydantic'.
module_file_path = os.path.join(
    os.path.dirname(os.path.dirname(pydantic.__file__)), 'haystack'
)


def replace_string_in_files(folder_path, old_str, new_str):
    """Replace ``old_str`` with ``new_str`` in every ``.txt``/``.py`` file under a folder.

    The directory tree rooted at ``folder_path`` is walked recursively. Files
    with other extensions, and files that do not contain ``old_str``, are left
    untouched (they are not rewritten, so their modification time is kept).

    Args:
        folder_path: Root directory to walk.
        old_str: Substring to search for.
        new_str: Text that replaces every occurrence of ``old_str``.
    """
    for subdir, _dirs, files in os.walk(folder_path):
        for file in files:
            # Only text-like files are patched (adjust the tuple if needed).
            if not file.endswith((".txt", ".py")):
                continue

            file_path = os.path.join(subdir, file)

            # Read the whole file before writing anything, so a decoding
            # problem can never leave a half-written file behind.
            # - encoding/errors: UTF-8 with surrogateescape round-trips any
            #   byte sequence, independent of the platform's locale encoding.
            # - newline="": keep the file's original line endings as they are.
            with open(file_path, "r", encoding="utf-8",
                      errors="surrogateescape", newline="") as f:
                content = f.read()

            # Nothing to replace: do not rewrite the file at all.
            if old_str not in content:
                continue

            with open(file_path, "w", encoding="utf-8",
                      errors="surrogateescape", newline="") as f:
                f.write(content.replace(old_str, new_str))

# Read Haystack's schema.py and use it as the marker for whether the package
# has already been rewritten: unless it contains 'from pydantic.v1', replace
# 'from pydantic' with 'from pydantic.v1' in the files under the Haystack
# package directory.
# NOTE(review): presumably a compatibility shim for running Haystack against
# pydantic 2.x -- confirm against the pinned dependency versions.
schema_path = module_file_path+'/schema.py'
with open(schema_path, 'r') as schema_handle:
    haystack_schema_file = schema_handle.read()

needs_patch = 'from pydantic.v1' not in haystack_schema_file
if needs_patch:
    replace_string_in_files(
        module_file_path,
        'from pydantic',
        'from pydantic.v1',
    )


from operator import index
import streamlit as st
import logging
import os

from annotated_text import annotation
from json import JSONDecodeError
from markdown import markdown
from utils.config import parser
from utils.haystack import start_document_store, query, initialize_pipeline, start_preprocessor_node, start_retriever, start_reader
from utils.ui import reset_results, set_initial_state
import pandas as pd
import haystack


# Whether the file upload should be disabled. Controlled by the
# DISABLE_FILE_UPLOAD environment variable: when it is unset, empty, or an
# explicit negative ("0", "false", "no", "off"; case-insensitive, surrounding
# whitespace ignored) uploads stay enabled; any other value disables them.
# A plain bool() on the raw string is not used because every non-empty string
# is truthy, so "false" or "0" would disable uploads as well.
DISABLE_FILE_UPLOAD = os.getenv("DISABLE_FILE_UPLOAD", "").strip().lower() not in {"", "0", "false", "no", "off"}
# Define a function to handle file uploads
def upload_files():
    """Render the file uploader in the module-level ``upload_container``.

    The widget accepts multiple ``pdf``/``txt``/``docx`` files and has its
    label collapsed.

    Returns:
        Whatever ``upload_container.file_uploader`` returns for the current
        selection.
    """
    return upload_container.file_uploader(
        "upload",
        type=["pdf", "txt", "docx"],
        accept_multiple_files=True,
        label_visibility="collapsed",
    )

# Define a function to process a single file

def process_file(data_file, preprocesor, document_store):
    """Preprocess one uploaded file and index it in the document store.

    The file is read and decoded as UTF-8, wrapped in a single document dict
    (``content`` plus ``meta['name']``) and, unless a document with the same
    name is already stored, preprocessed, written to the store and embedded
    with the module-level ``retriever``.

    Args:
        data_file: Uploaded file object exposing ``read()`` (bytes) and
            ``name``.
        preprocesor: Object whose ``process(docs)`` returns the documents to
            write.
        document_store: Object exposing ``get_all_documents()``,
            ``write_documents(docs)`` and ``update_embeddings(retriever)``.

    Raises:
        Exception: Any failure is logged with its traceback and re-raised, so
            the caller can report the file as failed instead of silently
            treating it as processed.
    """
    try:
        # read file and add content
        file_contents = data_file.read().decode("utf-8")
        docs = [{
            'content': str(file_contents),
            'meta': {'name': str(data_file.name)}
        }]

        # Names of everything already stored; a set keeps the lookup cheap.
        names = {item.meta.get('name') for item in document_store.get_all_documents()}
        if data_file.name in names:
            print(f"{data_file.name} already processed")
            return

        print(f'preprocessing uploaded doc {data_file.name}.......')
        preprocessed_docs = preprocesor.process(docs)
        print('writing to document store.......')
        document_store.write_documents(preprocessed_docs)
        print('updating emebdding.......')
        # `retriever` is not a parameter: it is the module-level retriever
        # created when the script body runs.
        document_store.update_embeddings(retriever)
    except Exception:
        logging.exception(
            "Failed to process uploaded file %s",
            getattr(data_file, "name", "<unknown>"),
        )
        raise

def reset_documents():
    """Button callback: call ``delete_documents()`` on the module-level document store."""
    print("Reseting documents list")
    store = document_store
    store.delete_documents()

def upload_document():
    """Process the currently selected uploads and write a status line per file.

    Button callback. Works entirely on module-level state: ``data_files``
    (the uploader selection), ``args``, ``preprocesor``, ``document_store``
    and ``upload_container``. For each file a success or failure marker is
    written to ``upload_container``; failures are also logged with their
    traceback, since the on-screen message refers the user to the logs.
    """
    if data_files is None:
        return

    for data_file in data_files:
        if not data_file:
            continue
        try:
            # Only the in-memory store is indexed from here; with any other
            # store type the file is reported as successful without being
            # processed.
            if args.store == 'inmemory':
                process_file(data_file, preprocesor, document_store)
            upload_container.write(str(data_file.name) + "    ✅ ")
        except Exception:
            logging.exception("Could not process uploaded file %s", data_file.name)
            upload_container.write(str(data_file.name) + "    ❌ ")
            upload_container.write("_This file could not be parsed, see the logs for more information._")

try:
    # Parse the command-line options and build the search components. These
    # names are module-level and are read by the callbacks defined above.
    args = parser.parse_args()
    preprocesor = start_preprocessor_node()
    document_store = start_document_store(type=args.store)
    retriever = start_retriever(document_store)
    reader = start_reader()
    st.set_page_config(
        page_title="MLReplySearch",
        layout="centered",
        page_icon=":shark:",
        menu_items={
            'Get Help': 'https://www.extremelycoolapp.com/help',
            'Report a bug': "https://www.extremelycoolapp.com/bug",
            'About': "# This is a header. This is an *extremely* cool app!"
        }
    )
    st.sidebar.image("ml_logo.png", use_column_width=True)

    # Sidebar for Task Selection
    st.sidebar.header('Options:')

    # OpenAI Key Input
    openai_key = st.sidebar.text_input("Enter OpenAI Key:", type="password")

    # The generative task is only offered once a key has been entered.
    if openai_key:
        task_options = ['Extractive', 'Generative']
    else:
        task_options = ['Extractive']

    task_selection = st.sidebar.radio('Select the task:', task_options)

    # Check the task and initialize pipeline accordingly
    if task_selection == 'Extractive':
        pipeline_extractive = initialize_pipeline("extractive", document_store, retriever, reader)
    elif task_selection == 'Generative' and openai_key:  # Check for openai_key to ensure user has entered it
        pipeline_rag = initialize_pipeline("rag", document_store, retriever, reader, openai_key=openai_key)

    set_initial_state()

    st.write('# ' + args.name)

    # File upload block
    if not DISABLE_FILE_UPLOAD:
        upload_container = st.sidebar.container()
        upload_container.write("## File Upload:")
        data_files = upload_files()

        upload_container.button('Upload Files', on_click=upload_document, args=())

    st.sidebar.button("Reset documents", on_click=reset_documents, args=())

    if "question" not in st.session_state:
        st.session_state.question = ""
    # Search bar
    question = st.text_input("", value=st.session_state.question, max_chars=100, on_change=reset_results)

    run_pressed = st.button("Run")

    # Run when the button was pressed or the text differs from the question
    # that produced the stored results.
    run_query = (
        run_pressed or question != st.session_state.question
    )

    # Get results for query
    if run_query and question:
        if task_selection == 'Extractive':
            reset_results()
            st.session_state.question = question
            with st.spinner("🔎    Running your pipeline"):
                try:
                    st.session_state.results_extractive = query(pipeline_extractive, question)
                    st.session_state.task = task_selection
                except JSONDecodeError:
                    st.error(
                        "👓    An error occurred reading the results. Is the document store working?"
                    )
                except Exception as e:
                    logging.exception(e)
                    st.error("🐞    An error occurred during the request.")

        elif task_selection == 'Generative':
            reset_results()
            st.session_state.question = question
            with st.spinner("🔎    Running your pipeline"):
                try:
                    st.session_state.results_generative = query(pipeline_rag, question)
                    st.session_state.task = task_selection
                except JSONDecodeError:
                    st.error(
                        "👓    An error occurred reading the results. Is the document store working?"
                    )
                except Exception as e:
                    logging.exception(e)
                    if "API key is invalid" in str(e):
                        st.error("🐞    incorrect API key provided. You can find your API key at https://platform.openai.com/account/api-keys.")
                    else:
                        st.error("🐞    An error occurred during the request.")

    # Display results
    if (st.session_state.results_extractive or st.session_state.results_generative) and run_query:

        # Handle Extractive Answers
        if task_selection == 'Extractive':
            results = st.session_state.results_extractive

            st.subheader("Extracted Answers:")

            if 'answers' in results:
                answers = results['answers']
                threshold = 0.2
                above_threshold = any(ans.score > threshold for ans in answers)
                if not above_threshold:
                    # Scale first, then round: int(threshold) * 100 truncates
                    # 0.2 to 0 and would always display "0%".
                    st.markdown(f"<span style='color:red'>Please note none of the answers achieved a score higher then {round(threshold * 100)}%. Which probably means that the desired answer is not in the searched documents.</span>", unsafe_allow_html=True)
                for count, answer in enumerate(answers):
                    if answer.answer:
                        text = answer.answer
                        # A missing context is treated as empty text.
                        context = answer.context or ""
                        score = round(answer.score, 3)
                        highlighted = str(annotation(body=text, label=f'SCORE {score}', background='#964448', color='#ffffff'))
                        start_idx = context.find(text)
                        if start_idx == -1:
                            # find() returns -1 when the answer does not occur
                            # verbatim in the context; slicing with -1 would
                            # cut the context at the wrong place, so show the
                            # highlighted answer followed by the full context.
                            rendered = highlighted + " " + context
                        else:
                            end_idx = start_idx + len(text)
                            rendered = context[:start_idx] + highlighted + context[end_idx:]
                        st.markdown(f"**Answer {count + 1}:**")
                        # NOTE(review): the context text is rendered with
                        # unsafe_allow_html=True; if it can contain markup
                        # from uploaded documents, consider escaping it.
                        st.markdown(
                            rendered,
                            unsafe_allow_html=True,
                        )
                    else:
                        st.info(
                            "🤔 &nbsp;&nbsp; Haystack is unsure whether any of the documents contain an answer to your question. Try to reformulate it!"
                        )

        # Handle Generative Answers
        elif task_selection == 'Generative':
            results = st.session_state.results_generative
            st.subheader("Generated Answer:")
            if 'results' in results:
                st.markdown("**Answer:**")
                st.write(results['results'][0])

        # Handle Retrieved Documents
        if 'documents' in results:
            retrieved_documents = results['documents']
            st.subheader("Retriever Results:")

            data = []
            for i, document in enumerate(retrieved_documents):
                # Truncate the content
                truncated_content = (document.content[:150] + '...') if len(document.content) > 150 else document.content
                # Documents without a 'name' entry get an empty name instead
                # of aborting the whole table with a KeyError.
                data.append([i + 1, document.meta.get('name', ''), truncated_content])

            # Convert data to DataFrame and display using Streamlit
            df = pd.DataFrame(data, columns=['Ranked Context', 'Document Name', 'Content'])
            st.table(df)

except SystemExit as e:
    # os._exit() only accepts an integer status, while SystemExit.code may be
    # None (meaning success) or a non-integer such as a message string.
    exit_code = e.code
    if exit_code is None:
        exit_code = 0
    elif not isinstance(exit_code, int):
        exit_code = 1
    os._exit(exit_code)