Carlosito16 committed
Commit • 986ac67
Parent(s): 4d045a8
Upload 3 files
Browse files
- pages/1_data.py +150 -0
- pages/2_model.py +51 -0
- pages/3_chat.py +28 -0
pages/1_data.py
ADDED
@@ -0,0 +1,150 @@
import streamlit as st
import pandas as pd
import copy
from googletrans import Translator
from langchain.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from streamlit_extras.row import row
import requests
from bs4 import BeautifulSoup
from urllib.parse import urlparse
from collections import Counter
import torch
from langchain.embeddings import HuggingFaceInstructEmbeddings
from langchain import HuggingFacePipeline
from langchain.chains import RetrievalQA

# Session-state slots shared with the model and chat pages.
if 'faiss_db' not in st.session_state:
    st.session_state['faiss_db'] = 0

if 'chunked_count_list' not in st.session_state:
    st.session_state['chunked_count_list'] = 0

if 'chunked_df' not in st.session_state:
    st.session_state['chunked_df'] = 0


def make_clickable(link):
    # Render a URL as an HTML anchor so the dataframe shows clickable links.
    text = link.split()[0]
    return f'<a target="_blank" href="{link}">{text}</a>'


user_agent = 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/36.0.1985.143 Safari/537.36'
headers = {'User-Agent': user_agent}


def scrape_url(url_list):
    # Fetch each page and concatenate the text of its <p> and <ol> elements.
    all_whole_text = []
    for url in url_list:
        html_doc = requests.get(url, headers=headers)
        soup = BeautifulSoup(html_doc.text, 'html.parser')
        whole_text = ""
        for paragraph in soup.find_all():
            if paragraph.name in ["p", "ol"]:
                whole_text += paragraph.text.replace('\xa0', '').replace('\n', '').strip()
        all_whole_text.append(whole_text)
    return all_whole_text


def create_count_list(chunked_text):
    # Count how many chunks each source document was split into.
    original_count_list = []
    for item in range(len(chunked_text)):
        original_count_list.append(chunked_text[item].metadata['document'])
    item_counts = Counter(original_count_list)
    count_list = list(item_counts.values())
    return count_list


def thai_to_eng(text):
    translated = translator.translate(text, src='th', dest='en')
    return translated


def eng_to_thai(text):
    translated = translator.translate(text, src='en', dest='th')
    return translated


# st.set_page_config(page_title=None, page_icon=None, layout="wide")

url_list = ["https://www.mindphp.com/คู่มือ/openerp-manual.html#google_vignette",
            "https://www.mindphp.com/คู่มือ/openerp-manual/7874-refund.html",
            "https://www.mindphp.com/คู่มือ/openerp-manual/8842-50-percent-discount-on-erp.html",
            "https://www.mindphp.com/คู่มือ/openerp-manual/7873-hr-payroll-account.html",
            "https://www.mindphp.com/คู่มือ/openerp-manual/4255-supplier-payments.html"]  # or whatever default

metadatas = [{"document": i, "url": j} for i, j in enumerate(url_list)]

scrape_list = scrape_url(url_list)
translator = Translator()

splitter_row = row([2, 2, 1], vertical_align="bottom")
var1 = splitter_row.number_input("Chunk Size", value=1200)
var2 = splitter_row.number_input("Chunk Overlap Size", value=100)

split_button = splitter_row.button("Split the data")

if split_button:
    # Chunk size and overlap come from the number inputs above.
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=var1,
        chunk_overlap=var2,
        length_function=len
    )

    chunked_text = text_splitter.create_documents(scrape_list, metadatas=metadatas)
    chunked_count_list = create_count_list(chunked_text)

    print(len(url_list), len(chunked_count_list))  # sanity check

    url_dataframe = pd.DataFrame({'link': url_list, 'number_of_chunks': chunked_count_list})
    url_dataframe['link'] = url_dataframe['link'].apply(make_clickable)
    url_dataframe = url_dataframe.to_html(escape=False)

    st.session_state['chunked_df'] = url_dataframe

    st.write(url_dataframe, unsafe_allow_html=True)
    # st.dataframe(url_dataframe)

    with st.expander("chunked items"):
        st.json(chunked_text)

    # Translate every chunk from Thai to English before embedding.
    translated_chunk_text = copy.deepcopy(chunked_text)
    for chunk in range(len(translated_chunk_text)):
        translated_chunk_text[chunk].page_content = thai_to_eng(translated_chunk_text[chunk].page_content).text

    # st.json(translated_chunk_text)

    embedding_model = HuggingFaceInstructEmbeddings(
        model_name='hkunlp/instructor-base',
        model_kwargs={'device': torch.device('cuda' if torch.cuda.is_available() else 'cpu')})

    faiss_db = FAISS.from_documents(translated_chunk_text, embedding_model)

    st.session_state['faiss_db'] = faiss_db
    st.session_state['chunked_count_list'] = chunked_count_list

    st.write('successfully preprocessed data ✅')
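Once this page runs, the FAISS index lives in st.session_state and is reachable from the other pages. A minimal sketch of querying it directly (the query string is illustrative, not from the app):

import streamlit as st

# Sketch: query the index built by pages/1_data.py (assumes that page has run).
faiss_db = st.session_state['faiss_db']
docs = faiss_db.similarity_search("How do I record a supplier payment?", k=3)
for doc in docs:
    st.write(doc.metadata['url'], doc.page_content[:200])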
pages/2_model.py
ADDED
@@ -0,0 +1,51 @@
import streamlit as st
import torch
from langchain import HuggingFacePipeline
from langchain.chains import RetrievalQA
from streamlit_extras.row import row

# Session-state slots shared with the chat page.
if 'model' not in st.session_state:
    st.session_state['model'] = 0
if 'max_length' not in st.session_state:
    st.session_state['max_length'] = 0
if 'temperature' not in st.session_state:
    st.session_state['temperature'] = 0
if 'repetition_penalty' not in st.session_state:
    st.session_state['repetition_penalty'] = 0


def load_llm_model(max_length, temperature, repetition_penalty):
    # llm = HuggingFacePipeline.from_model_id(model_id='lmsys/fastchat-t5-3b-v1.0',
    #                                         task='text2text-generation',
    #                                         model_kwargs={"device_map": "auto",
    #                                                       "load_in_8bit": True, "max_length": 256,
    #                                                       "temperature": 0, "repetition_penalty": 1.5})

    llm = HuggingFacePipeline.from_model_id(model_id='lmsys/fastchat-t5-3b-v1.0',
                                            task='text2text-generation',
                                            model_kwargs={"max_length": max_length,
                                                          "temperature": temperature,
                                                          "torch_dtype": torch.float32,
                                                          "repetition_penalty": repetition_penalty})
    return llm


model_row = row([2, 2, 2], vertical_align="bottom")
max_length = model_row.number_input("max_length", value=256)
temperature = model_row.number_input("temperature", value=0)
repetition_penalty = model_row.number_input("repetition_penalty", value=1.3)
load_model_button = st.button("load model")

if load_model_button:
    st.session_state['max_length'] = max_length
    st.session_state['temperature'] = temperature
    st.session_state['repetition_penalty'] = repetition_penalty

    st.session_state['model'] = load_llm_model(max_length, temperature, repetition_penalty)

    st.write('successfully loaded the model ✅')
    st.markdown(st.session_state['max_length'])
    st.markdown(st.session_state['temperature'])
    st.markdown(st.session_state['repetition_penalty'])
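With the pipeline stored in session state, a quick smoke test from any page might look like this sketch (the prompt is illustrative; a HuggingFacePipeline is a LangChain LLM, so it can be called on a raw string):

import streamlit as st

# Sketch: call the LLM loaded by pages/2_model.py (assumes "load model" was clicked).
llm = st.session_state['model']
st.write(llm("What modules does OpenERP provide?"))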
pages/3_chat.py
ADDED
@@ -0,0 +1,28 @@
import streamlit as st
from streamlit_extras.stateful_chat import chat, add_message
from langchain.chains import RetrievalQA


with st.expander("key information"):
    st.write(st.session_state['chunked_df'], unsafe_allow_html=True)
    st.markdown(st.session_state['max_length'])
    st.markdown(st.session_state['temperature'])
    st.markdown(st.session_state['repetition_penalty'])


# qa_retriever = RetrievalQA.from_chain_type(llm=st.session_state['model'], chain_type="stuff",
#                                            retriever=st.session_state['faiss_db'].as_retriever())


# with chat(key="my_chat"):
#     if prompt := st.chat_input():
#         add_message("user", prompt, avatar="🧑‍💻")
#         # def stream_echo():
#         #     for word in prompt.split():
#         #         yield word + " "
#         #         time.sleep(0.15)
#         add_message("assistant", "Echo: ", qa_retriever.run(prompt), avatar="🦜")

# query = "How to process documents about HR"
# docs = st.session_state['faiss_db'].similarity_search(query)
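The chat wiring above is still commented out in this commit. Uncommented and pointed at the objects the earlier pages store in session state, it might run roughly like this sketch (session-state keys follow pages/1_data.py and pages/2_model.py; the "Echo: " prefix is dropped here):

import streamlit as st
from streamlit_extras.stateful_chat import chat, add_message
from langchain.chains import RetrievalQA

# Build the retrieval-QA chain from the model and index the other pages stored.
qa_retriever = RetrievalQA.from_chain_type(llm=st.session_state['model'],
                                           chain_type="stuff",
                                           retriever=st.session_state['faiss_db'].as_retriever())

with chat(key="my_chat"):
    if prompt := st.chat_input():
        add_message("user", prompt, avatar="🧑‍💻")
        add_message("assistant", qa_retriever.run(prompt), avatar="🦜")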