CVHT_DEMO1 / app.py
ThangNguyen27's picture
Update app.py
6db47da verified
raw
history blame
3 kB
import streamlit as st
from langchain_google_genai import ChatGoogleGenerativeAI
llm = ChatGoogleGenerativeAI(
model="gemini-1.5-pro",
temperature=0,
max_tokens=None,
timeout=None,
max_retries=2,
# other params...
)
import os
fileNames = os.listdir("./data")
from langchain_text_splitters import RecursiveCharacterTextSplitter # type: ignore
#from langchain_community.document_loaders import PyPDFDirectoryLoader
from langchain_chroma import Chroma # type: ignore
from langchain_community.document_loaders import UnstructuredPDFLoader # type: ignore
path_dir = "data/"
docs = []
for item in fileNames:
pdf_loader = UnstructuredPDFLoader(path_dir + item)
docs += pdf_loader.load()
chunk_size = 1000
chunk_overlap = 200
separators: list[str] = [
"\n\n",
"\n",
" ",
".",
",",
"\u200b", # Zero-width space
"\uff0c", # Fullwidth comma
"\u3001", # Ideographic comma
"\uff0e", # Fullwidth full stop
"\u3002", # Ideographic full stop
"",
]
char_splitter = RecursiveCharacterTextSplitter(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
length_function=len,
is_separator_regex=False,
separators=separators
)
text_output = char_splitter.split_documents(docs)
from langchain.embeddings import OpenAIEmbeddings, HuggingFaceInstructEmbeddings, HuggingFaceEmbeddings
from langchain.embeddings import OpenAIEmbeddings, HuggingFaceInstructEmbeddings, GPT4AllEmbeddings
embedding_model = HuggingFaceEmbeddings(model_name="keepitreal/vietnamese-sbert") #try to optimize
chroma_db = Chroma.from_documents(text_output, embedding=embedding_model)
retriever = chroma_db.as_retriever(search_kwargs = {"k":10}, max_tokens_limit=1024, search_type = "similarity")
def format_docs(docs):
for doc in docs:
return "\n\n".join(doc.page_content for doc in docs).strip()
#from langchain_core.runnables import RunnablePassthrough
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
input_data = {
"context": retriever | format_docs,
"question": RunnablePassthrough()
}
from langchain import LLMChain
from langchain import PromptTemplate
template = "Bạn là một cố vấn học tập tại trường đại học Công Nghệ Thông Tin, một người nắm rõ các quy chế đào tạo, dựa vào ngữ cảnh sau để trả lời cho câu hỏi của sinh viên\n{context}\nĐây là lịch sử chat:{chat hisotry}\n### Câu hỏi:\n{question}\n### Trả lời:"
#prompt = PromptTemplate(template = template, input_variables=["context","chat_history","question"])
prompt = PromptTemplate(template = template, input_variables=["context", "question","chat history"])
rag_chain = (
input_data
| prompt
| llm
#| str_parser
)
query = st.text_area("Bạn muốn hỏi cố vấn học tập điều gì?")
if query:
result = rag_chain.invoke(query)
st.json({
"response": result.content
})