File size: 5,679 Bytes
f64588d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import os
import sys

from llama_index import SimpleDirectoryReader, ServiceContext, StorageContext, VectorStoreIndex, download_loader,load_index_from_storage
from llama_index.llms import HuggingFaceLLM
from llama_index.embeddings import HuggingFaceEmbedding
from llama_index.vector_stores import ChromaVectorStore
from llama_index.storage.index_store import SimpleIndexStore
from llama_index.indices.query.schema import QueryBundle, QueryType
import chromadb
import streamlit as st
import time


# Configure the page (title, icon, expanded sidebar, wide layout) and render
# the static header shown above the rest of the app.
st.set_page_config(page_title="Tesla Alert Analyzer", page_icon=":card_index_dividers:", initial_sidebar_state="expanded", layout="wide")

st.title(":card_index_dividers: Tesla Alert Analyzer")
# The instructions name the actual button labels used further down
# ('Process Document' in the sidebar, 'Submit' next to the question box).
st.info("""
Begin by uploading the case report in pptx format. Afterward, click on 'Process Document'. Once the document has been processed, you can enter a question and click 'Submit'; the system will answer your question.
""")


# process_doc gates the question form: it stays False until a document has
# been processed in this session.
if "process_doc" not in st.session_state:
    st.session_state.process_doc = False




def fmetadata(dummy: str) -> dict:
    """Return the fixed metadata dict used for every loaded file.

    Args:
        dummy: Ignored; present only so the function can be called with one
            positional argument.

    Returns:
        A new dict whose only entry is ``"file_path"`` mapped to ``""``.
    """
    metadata = dict(file_path="")
    return metadata

def load_files(file_dir):
    """Load the files found in *file_dir* as llama-index documents.

    ``.pptx`` files are handled by an instance of the ``PptxReader`` loader
    obtained through ``download_loader``; ``fmetadata`` is supplied as the
    per-file metadata callback.

    Args:
        file_dir: Directory handed to ``SimpleDirectoryReader`` as input.

    Returns:
        The documents returned by ``load_data()``, each with
        ``metadata["file_path"]`` set to an empty string.
    """
    pptx_reader_cls = download_loader("PptxReader")
    extractors = {".pptx": pptx_reader_cls()}
    reader = SimpleDirectoryReader(
        input_dir=file_dir,
        file_extractor=extractors,
        file_metadata=fmetadata,
    )
    loaded_docs = reader.load_data()

    # Overwrite file_path on every document, whatever the reader put there.
    for entry in loaded_docs:
        entry.metadata["file_path"] = ""

    return loaded_docs

# System prompt passed to the LLM wrapper. Each sentence ends with an explicit
# separating space so adjacent sentences do not run together.
system_prompt = (
    "You are a Q&A assistant. "
    "Your goal is to answer questions as accurately as possible based on the instructions and context provided. "
    "Please say you do not know if you do not find answer."
)

# This will wrap the default prompts that are internal to llama-index
# NOTE(review): these <|USER|>/<|ASSISTANT|> tags may not match the chat
# template of the Mistral instruct model loaded below — confirm against the
# model card before changing.
query_wrapper_prompt = "<|USER|>{query_str}<|ASSISTANT|>"



import torch
#torch.set_default_device('cuda')

@st.cache_resource
def llm_loading():
    """Construct the LLM and the embedding model used by the app.

    The function is wrapped in ``st.cache_resource``. Progress markers are
    printed before and after each model is constructed.

    Returns:
        A ``(llm, embed_model)`` tuple: a ``HuggingFaceLLM`` configured for
        ``mistralai/Mistral-7B-Instruct-v0.1`` and a ``HuggingFaceEmbedding``
        for ``thenlper/gte-base``.
    """
    print("before huggingfacellm")
    model_id = "mistralai/Mistral-7B-Instruct-v0.1"
    token_limit = 8000
    llm_settings = {
        "context_window": token_limit,
        "max_new_tokens": 500,
        "generate_kwargs": {"temperature": 0.1, "do_sample": True},
        "system_prompt": system_prompt,
        "query_wrapper_prompt": query_wrapper_prompt,
        "tokenizer_name": model_id,
        "model_name": model_id,
        "device_map": "auto",
        "tokenizer_kwargs": {"max_length": token_limit},
        # Weights are requested in half precision.
        "model_kwargs": {"torch_dtype": torch.float16},
    }
    language_model = HuggingFaceLLM(**llm_settings)
    print("after huggingfacellm")

    embedder = HuggingFaceEmbedding(model_name="thenlper/gte-base")
    print("after embed_model")

    return language_model, embedder

# Obtain the models; llm_loading is wrapped in st.cache_resource.
llm, embed_model = llm_loading()

# Sidebar uploader restricted to pptx; several files may be selected at once.
files_uploaded = st.sidebar.file_uploader(
    "Upload the case report in PPT format",
    type="pptx",
    accept_multiple_files=True,
)

st.sidebar.info(
    """
Example pptx reports you can upload here:
"""
)

# Sidebar action: save the uploaded files to disk, index them into the
# persistent Chroma collection, and unlock the question form.
if st.sidebar.button("Process Document"):

    if not files_uploaded:
        # Nothing was uploaded in this session; skip indexing instead of
        # running the directory reader with no new input.
        st.sidebar.warning("Please upload at least one pptx file before processing.")
    else:
        with st.spinner("Processing Document..."):

            data_dir = "data"
            os.makedirs(data_dir, exist_ok=True)

            for uploaded_file in files_uploaded:
                print(f'file named {uploaded_file.name}')
                # Keep only the final path component so the client-supplied
                # name cannot point outside data_dir.
                safe_name = os.path.basename(uploaded_file.name)
                fname = os.path.join(data_dir, safe_name)
                with open(fname, 'wb') as f:
                    # getvalue() returns the full buffer regardless of the
                    # stream's current read position.
                    f.write(uploaded_file.getvalue())

            # NOTE(review): data_dir and the Chroma collection are not cleared
            # between runs, so re-processing presumably re-indexes files from
            # earlier runs as well — confirm whether that is intended.
            documents = load_files(data_dir)

            collection_name = "tesla_report"
            chroma_client = chromadb.PersistentClient()
            chroma_collection = chroma_client.get_or_create_collection(collection_name)
            vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
            storage_context = StorageContext.from_defaults(vector_store=vector_store)
            service_context = ServiceContext.from_defaults(
                chunk_size=8000,
                llm=llm,
                embed_model=embed_model
            )
            index = VectorStoreIndex.from_documents(documents, service_context=service_context, storage_context=storage_context)
            index.storage_context.persist()

            # Enables the question form rendered further down the script.
            st.session_state.process_doc = True

            st.toast("Document Processsed!")

def clear_form():
    """Stash the typed question and reset the form fields.

    Copies the current ``question`` value in ``st.session_state`` to
    ``query_text``, then sets both ``question`` and ``response`` to empty
    strings.
    """
    state = st.session_state
    typed_question = state["question"]
    state.query_text = typed_question
    # Reset both form-related entries in one pass.
    for field in ("question", "response"):
        state[field] = ""

@st.cache_resource
def reload_index(_llm, _embed_model, col):
    """Open the persisted Chroma collection *col* as a ``VectorStoreIndex``.

    The function is wrapped in ``st.cache_resource``. The models are taken
    from the arguments rather than from module-level globals, so the returned
    index always uses what the caller passed in.

    Args:
        _llm: LLM to place in the service context.
        _embed_model: Embedding model to place in the service context.
        col: Name of the Chroma collection to open (created if missing).

    Returns:
        A ``VectorStoreIndex`` built on top of the collection's vector store.
    """
    chroma_client = chromadb.PersistentClient()
    chroma_collection = chroma_client.get_or_create_collection(col)
    vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
    service_context = ServiceContext.from_defaults(llm=_llm, embed_model=_embed_model)
    load_index = VectorStoreIndex.from_vector_store(service_context=service_context,
                                     vector_store=vector_store)
    return load_index

# Question form: only shown once a document has been processed.
if st.session_state.process_doc:
    # Sample question:
    # alert number looks like APP_wnnn where nnn is a number. Please list out all the alerts uploaded in these files!
    search_text = st.text_input("Enter your question", key='question')
    if st.button(label="Submit", on_click=clear_form):
        # clear_form (the on_click callback) stores the typed text under
        # query_text and empties the input box, so read the question from there.
        search_text = st.session_state.get("query_text", "")

        if not search_text.strip():
            # Do not send a blank question to the query engine.
            st.warning("Please enter a question before submitting.")
        else:
            index = reload_index(llm, embed_model, "tesla_report")
            query_engine = index.as_query_engine()
            st.write("Processing....")
            print(search_text)
            response = query_engine.query(search_text)
            st.write(response.response)

            st.toast("Report Analysis Complete!")