"""Gradio chat UI that answers questions about KCE using a Bedrock-backed RAG chain over a local FAISS index."""

import os
import sys
import time

import boto3
import gradio as gr
from langchain_aws import BedrockLLM
from langchain_community.embeddings import BedrockEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough

# NOTE(review): presumably lets sibling modules one directory up be imported;
# nothing visible in this file relies on it — confirm before removing.
module_path = ".."
sys.path.append(os.path.abspath(module_path))

# Region comes from AWS_DEFAULT_REGION, falling back to us-west-2.
bedrock_client = boto3.client(
    "bedrock-runtime",
    region_name=os.environ.get("AWS_DEFAULT_REGION", "us-west-2"),
)

modelId = "meta.llama3-1-70b-instruct-v1:0"

llm = BedrockLLM(
    model_id=modelId,
    client=bedrock_client,
)

br_embeddings = BedrockEmbeddings(
    model_id="cohere.embed-multilingual-v3",
    client=bedrock_client,
)

# SECURITY: allow_dangerous_deserialization permits unpickling the stored
# docstore; only ever load a 'faiss_index' directory you built yourself.
db = FAISS.load_local(
    "faiss_index",
    embeddings=br_embeddings,
    allow_dangerous_deserialization=True,
)

# The number of documents to fetch must go through search_kwargs; a bare
# `k=5` keyword is not forwarded to the similarity search and is ignored.
retriever = db.as_retriever(search_kwargs={"k": 5})

prompt = ChatPromptTemplate.from_messages([
    ('system',
     "Answer the questions with the provided context. "
     "Do not include phrases such as 'based on the context' or "
     "'based on the documents' in your answer. "
     "Please say you do not know if you do not know or cannot find the information needed."
     "\n Question: {question} \nContext: {context}"),
    ('user', "{question}"),
])

# NOTE(review): not referenced anywhere in this file; kept in case other code
# imports it.
chat_history = []


def format_docs(docs):
    """Join the page_content of each retrieved document, separated by blank lines."""
    return "\n\n".join(doc.page_content for doc in docs)


# The input string is fed both to the retriever (to build {context}) and,
# unchanged, to the prompt as {question}.
rag_chain = (
    {"context": retriever | format_docs, "question": RunnablePassthrough()}
    | prompt
    | llm
    | StrOutputParser()
)

# NOTE(review): this issues a real Bedrock call at import time and the result
# is never used; it looks like leftover debugging — consider removing.
response = rag_chain.invoke("Who are the board of directors in KCE company?")


def chat_gen(message, history):
    """Gradio ChatInterface callback: stream the RAG answer for *message*.

    Args:
        message: The user's latest chat message, used as the chain input.
        history: Prior conversation supplied by Gradio. It is not used; the
            chain answers each message independently.

    Yields:
        The answer accumulated so far, growing as chunks arrive from the chain.
    """
    partial_message = ""
    # Stream real chunks from the chain instead of replaying a finished
    # answer character by character with an artificial delay.
    for chunk in rag_chain.stream(message):
        partial_message += chunk
        yield partial_message


initial_msg = "Hello! I am KCE assistant. You can ask me anything about KCE. I am happy to assist you."
# Seed the chat window with the assistant greeting (pair format: [user, bot]).
chatbot = gr.Chatbot(value=[[None, initial_msg]])
demo = gr.ChatInterface(chat_gen, chatbot=chatbot).queue()

# Always release the server, whether launch returns normally, raises, or is
# interrupted; the exception (if any) propagates with its original traceback.
try:
    demo.launch(debug=False, share=False, show_api=False)
finally:
    demo.close()