Spaces:
Sleeping
Sleeping
import streamlit as st | |
from langchain_google_genai import ChatGoogleGenerativeAI | |
llm = ChatGoogleGenerativeAI( | |
model="gemini-1.5-pro", | |
temperature=0, | |
max_tokens=None, | |
timeout=None, | |
max_retries=2, | |
# other params... | |
) | |
import os | |
fileNames = os.listdir("./data") | |
from langchain_text_splitters import RecursiveCharacterTextSplitter # type: ignore | |
#from langchain_community.document_loaders import PyPDFDirectoryLoader | |
from langchain_chroma import Chroma # type: ignore | |
from langchain_community.document_loaders import UnstructuredPDFLoader # type: ignore | |
path_dir = "data/" | |
docs = [] | |
for item in fileNames: | |
pdf_loader = UnstructuredPDFLoader(path_dir + item) | |
docs += pdf_loader.load() | |
chunk_size = 1000 | |
chunk_overlap = 200 | |
separators: list[str] = [ | |
"\n\n", | |
"\n", | |
" ", | |
".", | |
",", | |
"\u200b", # Zero-width space | |
"\uff0c", # Fullwidth comma | |
"\u3001", # Ideographic comma | |
"\uff0e", # Fullwidth full stop | |
"\u3002", # Ideographic full stop | |
"", | |
] | |
char_splitter = RecursiveCharacterTextSplitter( | |
chunk_size=chunk_size, | |
chunk_overlap=chunk_overlap, | |
length_function=len, | |
is_separator_regex=False, | |
separators=separators | |
) | |
text_output = char_splitter.split_documents(docs) | |
from langchain.embeddings import OpenAIEmbeddings, HuggingFaceInstructEmbeddings, HuggingFaceEmbeddings | |
from langchain.embeddings import OpenAIEmbeddings, HuggingFaceInstructEmbeddings, GPT4AllEmbeddings | |
embedding_model = HuggingFaceEmbeddings(model_name="keepitreal/vietnamese-sbert") #try to optimize | |
chroma_db = Chroma.from_documents(text_output, embedding=embedding_model) | |
retriever = chroma_db.as_retriever(search_kwargs = {"k":10}, max_tokens_limit=1024, search_type = "similarity") | |
def format_docs(docs): | |
for doc in docs: | |
return "\n\n".join(doc.page_content for doc in docs).strip() | |
#from langchain_core.runnables import RunnablePassthrough | |
from langchain_core.runnables import RunnablePassthrough, RunnableLambda | |
input_data = { | |
"context": retriever | format_docs, | |
"question": RunnablePassthrough() | |
} | |
from langchain import LLMChain | |
from langchain import PromptTemplate | |
template = "Bạn là một cố vấn học tập tại trường đại học Công Nghệ Thông Tin, một người nắm rõ các quy chế đào tạo, dựa vào ngữ cảnh sau để trả lời cho câu hỏi của sinh viên\n{context}\n### Câu hỏi:\n{question}\n### Trả lời:" | |
#prompt = PromptTemplate(template = template, input_variables=["context","chat_history","question"]) | |
prompt = PromptTemplate(template = template, input_variables=["context", "question"]) | |
rag_chain = ( | |
input_data | |
| prompt | |
| llm | |
#| str_parser | |
) | |
query = st.text_area("Bạn muốn hỏi cố vấn học tập điều gì?") | |
if query: | |
result = rag_chain.invoke(query) | |
st.json({ | |
"response": result.content | |
}) | |