Carlosito16 committed
Commit 986ac67 • 1 Parent(s): 4d045a8

Upload 3 files

Files changed (3)
  1. pages/1_data.py +150 -0
  2. pages/2_model.py +51 -0
  3. pages/3_chat.py +28 -0
pages/1_data.py ADDED
@@ -0,0 +1,150 @@
+ import streamlit as st
+ import pandas as pd
+ import copy
+ import requests
+ import torch
+ from collections import Counter
+ from urllib.parse import urlparse
+ from bs4 import BeautifulSoup
+ from googletrans import Translator
+ from langchain.vectorstores import FAISS
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain.embeddings import HuggingFaceInstructEmbeddings
+ from streamlit_extras.row import row
+
+ # Initialise session state so the other pages can read these keys later.
+ if 'faiss_db' not in st.session_state:
+     st.session_state['faiss_db'] = 0
+
+ if 'chunked_count_list' not in st.session_state:
+     st.session_state['chunked_count_list'] = 0
+
+ if 'chunked_df' not in st.session_state:
+     st.session_state['chunked_df'] = 0
+
+
+ def make_clickable(link):
+     # Render a URL as an HTML anchor for the dataframe view.
+     text = link.split()[0]
+     return f'<a target="_blank" href="{link}">{text}</a>'
+
+
+ user_agent = 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/36.0.1985.143 Safari/537.36'
+ headers = {'User-Agent': user_agent}
+
+
+ def scrape_url(url_list):
+     # Fetch each page and keep only the text inside <p> and <ol> elements.
+     all_whole_text = []
+     for url in url_list:
+         html_doc = requests.get(url, headers=headers)
+         soup = BeautifulSoup(html_doc.text, 'html.parser')
+         whole_text = ""
+
+         for paragraph in soup.find_all():
+             if paragraph.name in ["p", "ol"]:
+                 whole_text += paragraph.text.replace('\xa0', '').replace('\n', '').strip()
+
+         all_whole_text.append(whole_text)
+
+     return all_whole_text
+
+
+ def create_count_list(chunked_text):
+     # Count how many chunks each source document was split into.
+     original_count_list = [chunk.metadata['document'] for chunk in chunked_text]
+     item_counts = Counter(original_count_list)
+     count_list = list(item_counts.values())
+     return count_list
+
+
+ def thai_to_eng(text):
+     translated = translator.translate(text, src='th', dest='en')
+     return translated
+
+
+ def eng_to_thai(text):
+     translated = translator.translate(text, src='en', dest='th')
+     return translated
+
+
+ # st.set_page_config(page_title=None, page_icon=None, layout="wide")
+
+ url_list = ["https://www.mindphp.com/คู่มือ/openerp-manual.html#google_vignette",
+             "https://www.mindphp.com/คู่มือ/openerp-manual/7874-refund.html",
+             "https://www.mindphp.com/คู่มือ/openerp-manual/8842-50-percent-discount-on-erp.html",
+             "https://www.mindphp.com/คู่มือ/openerp-manual/7873-hr-payroll-account.html",
+             "https://www.mindphp.com/คู่มือ/openerp-manual/4255-supplier-payments.html"]  # or whatever default
+
+ metadatas = [{"document": i, "url": j} for i, j in enumerate(url_list)]
+
+ scrape_list = scrape_url(url_list)
+ translator = Translator()
+
+ splitter_row = row([2, 2, 1], vertical_align="bottom")
+ var1 = splitter_row.number_input("Chunk Size", value=1200)
+ var2 = splitter_row.number_input("Chunk Overlap Size", value=100)
+
+ split_button = splitter_row.button("Split the data")
+
+ if split_button:
+     text_splitter = RecursiveCharacterTextSplitter(
+         chunk_size=var1,
+         chunk_overlap=var2,
+         length_function=len
+     )
+
+     chunked_text = text_splitter.create_documents([doc for doc in scrape_list], metadatas=metadatas)
+     chunked_count_list = create_count_list(chunked_text)
+
+     print(len(url_list), len(chunked_count_list))
+
+     url_dataframe = pd.DataFrame({'link': url_list, 'number_of_chunks': chunked_count_list})
+     url_dataframe['link'] = url_dataframe['link'].apply(make_clickable)
+     url_dataframe = url_dataframe.to_html(escape=False)
+
+     st.session_state['chunked_df'] = url_dataframe
+
+     st.write(url_dataframe, unsafe_allow_html=True)
+     # st.dataframe(url_dataframe)
+
+     with st.expander("chunked items"):
+         st.json(chunked_text)
+
+     # Translate every chunk from Thai to English before embedding.
+     translated_chunk_text = copy.deepcopy(chunked_text)
+     for chunk in range(len(translated_chunk_text)):
+         translated_chunk_text[chunk].page_content = thai_to_eng(translated_chunk_text[chunk].page_content).text
+
+     # st.json(translated_chunk_text)
+
+     embedding_model = HuggingFaceInstructEmbeddings(model_name='hkunlp/instructor-base',
+                                                     model_kwargs={'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu')})
+
+     faiss_db = FAISS.from_documents(translated_chunk_text, embedding_model)
+
+     st.session_state['faiss_db'] = faiss_db
+     st.session_state['chunked_count_list'] = chunked_count_list
+
+     st.write('successfully preprocessed data ✅')
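Once this page has run, the FAISS index is available to every other page via st.session_state['faiss_db']. A minimal sketch of querying it (not part of this commit; the query string and k value are illustrative), using the same similarity_search call that 3_chat.py references:

    import streamlit as st

    # Illustrative query against the index built above; k caps the number of hits.
    query = "How do I issue a refund?"
    docs = st.session_state['faiss_db'].similarity_search(query, k=3)
    for doc in docs:
        st.write(doc.metadata['url'], doc.page_content[:200])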
pages/2_model.py ADDED
@@ -0,0 +1,51 @@
+ import streamlit as st
+ import torch
+ from langchain import HuggingFacePipeline
+ from streamlit_extras.row import row
+
+ if 'model' not in st.session_state:
+     st.session_state['model'] = 0
+ if 'max_length' not in st.session_state:
+     st.session_state['max_length'] = 0
+ if 'temperature' not in st.session_state:
+     st.session_state['temperature'] = 0
+ if 'repetition_penalty' not in st.session_state:
+     st.session_state['repetition_penalty'] = 0
+
+
+ def load_llm_model(max_length, temperature, repetition_penalty):
+     # 8-bit variant, kept for reference:
+     # llm = HuggingFacePipeline.from_model_id(model_id='lmsys/fastchat-t5-3b-v1.0',
+     #                                         task='text2text-generation',
+     #                                         model_kwargs={"device_map": "auto", "load_in_8bit": True,
+     #                                                       "max_length": 256, "temperature": 0,
+     #                                                       "repetition_penalty": 1.5})
+
+     llm = HuggingFacePipeline.from_model_id(model_id='lmsys/fastchat-t5-3b-v1.0',
+                                             task='text2text-generation',
+                                             model_kwargs={"max_length": max_length,
+                                                           "temperature": temperature,
+                                                           "torch_dtype": torch.float32,
+                                                           "repetition_penalty": repetition_penalty})
+     return llm
+
+
+ model_row = row([2, 2, 2], vertical_align="bottom")
+ max_length = model_row.number_input("max_length", value=256)
+ temperature = model_row.number_input("temperature", value=0.0)  # float default so fractional values are accepted
+ repetition_penalty = model_row.number_input("repetition_penalty", value=1.3)
+ load_model_button = st.button("load model")
+
+ if load_model_button:
+     st.session_state['max_length'] = max_length
+     st.session_state['temperature'] = temperature
+     st.session_state['repetition_penalty'] = repetition_penalty
+
+     st.session_state['model'] = load_llm_model(max_length, temperature, repetition_penalty)
+
+     st.write('model loaded successfully ✅')
+     st.markdown(st.session_state['max_length'])
+     st.markdown(st.session_state['temperature'])
+     st.markdown(st.session_state['repetition_penalty'])
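With the pipeline stored under st.session_state['model'] and the index from 1_data.py under st.session_state['faiss_db'], the two can be combined into a retrieval chain. A minimal sketch (not part of this commit), mirroring the commented-out call in 3_chat.py; the question is illustrative:

    import streamlit as st
    from langchain.chains import RetrievalQA

    # "stuff" packs the retrieved chunks directly into the model's prompt.
    qa = RetrievalQA.from_chain_type(llm=st.session_state['model'],
                                     chain_type="stuff",
                                     retriever=st.session_state['faiss_db'].as_retriever())
    answer = qa.run("How are supplier payments recorded?")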
pages/3_chat.py ADDED
@@ -0,0 +1,28 @@
+ import streamlit as st
+ from streamlit_extras.stateful_chat import chat, add_message
+ from langchain.chains import RetrievalQA
+
+
+ with st.expander("key information"):
+     st.write(st.session_state['chunked_df'], unsafe_allow_html=True)
+     st.markdown(st.session_state['max_length'])
+     st.markdown(st.session_state['temperature'])
+     st.markdown(st.session_state['repetition_penalty'])
+
+
+ # Note: 2_model.py stores the pipeline under the key 'model'.
+ # qa_retriever = RetrievalQA.from_chain_type(llm=st.session_state['model'], chain_type="stuff",
+ #                                            retriever=st.session_state['faiss_db'].as_retriever())
+
+
+ # with chat(key="my_chat"):
+ #     if prompt := st.chat_input():
+ #         add_message("user", prompt, avatar="🧑‍💻")
+ #         # def stream_echo():
+ #         #     for word in prompt.split():
+ #         #         yield word + " "
+ #         #         time.sleep(0.15)
+ #         add_message("assistant", "Echo: ", qa_retriever.run(prompt), avatar="🦜")
+
+ # query = "How to process documents about HR"
+ # docs = st.session_state['faiss_db'].similarity_search(query)
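For reference, an uncommented version of the chat loop above would look roughly like this (a sketch only, assuming qa_retriever has been built as in the commented lines):

    with chat(key="my_chat"):
        if prompt := st.chat_input():
            add_message("user", prompt, avatar="🧑‍💻")
            # Run the retrieval chain on the user's question and show the answer.
            add_message("assistant", qa_retriever.run(prompt), avatar="🦜")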