# llm-app / app.py — Hugging Face Space source by donb-hf
# (commit db459ed, "update app.py"; 2.32 kB)
import os
import chainlit as cl
from dotenv import load_dotenv
from operator import itemgetter
from langchain import hub
from langchain_groq import ChatGroq
from langchain_openai import OpenAIEmbeddings
from langchain_qdrant import Qdrant
from langchain_core.prompts import PromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough
from langchain.schema.runnable.config import RunnableConfig
from starters import set_starters
load_dotenv()

# Every environment variable this app requires. They are all checked up front
# so that a misconfigured deployment fails immediately, with one message naming
# every missing variable, before any network call (hub pull, Qdrant connect).
_REQUIRED_ENV_VARS = (
    "LANGCHAIN_PROJECT",
    "LANGCHAIN_ENDPOINT",
    "LANGCHAIN_API_KEY",
    "LANGCHAIN_TRACING_V2",
    "LANGCHAIN_HUB_PROMPT",
    "GROQ_API_KEY",
    "OPENAI_API_KEY",
    "QDRANT_API_KEY",
    "QDRANT_API_URL",
    "QDRANT_COLLECTION",
)
_missing_env_vars = [name for name in _REQUIRED_ENV_VARS if name not in os.environ]
if _missing_env_vars:
    raise RuntimeError(
        "Missing required environment variables: " + ", ".join(_missing_env_vars)
    )

# LangSmith / LangChain Hub settings. Apart from LANGCHAIN_HUB_PROMPT these are
# not passed anywhere in this file; presumably the LangChain libraries read
# them from the environment directly (TODO confirm). They are bound here so
# their absence is caught by the check above.
LANGCHAIN_PROJECT = os.environ["LANGCHAIN_PROJECT"]
LANGCHAIN_ENDPOINT = os.environ["LANGCHAIN_ENDPOINT"]
LANGCHAIN_API_KEY = os.environ["LANGCHAIN_API_KEY"]
LANGCHAIN_TRACING_V2 = os.environ["LANGCHAIN_TRACING_V2"]
LANGCHAIN_HUB_PROMPT = os.environ["LANGCHAIN_HUB_PROMPT"]

# Model provider credentials, passed explicitly to their clients below.
GROQ_API_KEY = os.environ["GROQ_API_KEY"]
OPENAI_API_KEY = os.environ["OPENAI_API_KEY"]

# Vector store location and collection name.
QDRANT_API_KEY = os.environ["QDRANT_API_KEY"]
QDRANT_API_URL = os.environ["QDRANT_API_URL"]
QDRANT_COLLECTION = os.environ["QDRANT_COLLECTION"]
collection = QDRANT_COLLECTION  # legacy alias, kept for backward compatibility

# Chat model (Groq, "llama3-70b-8192") and the RAG prompt pulled from the Hub.
llm = ChatGroq(
    model="llama3-70b-8192",
    temperature=0.3,
    groq_api_key=GROQ_API_KEY,
)
prompt = hub.pull(LANGCHAIN_HUB_PROMPT)

# Embedding model used to embed queries against the existing Qdrant collection.
embedding = OpenAIEmbeddings(
    model="text-embedding-3-small",
    openai_api_key=OPENAI_API_KEY,
)

# Connect to a pre-populated collection; this does not ingest any documents.
qdrant = Qdrant.from_existing_collection(
    embedding=embedding,
    collection_name=QDRANT_COLLECTION,
    url=QDRANT_API_URL,
    api_key=QDRANT_API_KEY,
    prefer_grpc=True,
)

# Return at most 5 documents, discarding any whose similarity score is < 0.5.
# With a strict threshold this can legitimately return an empty list.
retriever = qdrant.as_retriever(
    search_type="similarity_score_threshold",
    search_kwargs={"score_threshold": 0.5, "k": 5},
)
@cl.on_chat_start
async def start_chat():
    """Build the retrieval-augmented generation chain for a new chat session.

    The chain is invoked with ``{"question": str}`` and returns a dict with:

    * ``"response"`` -- the chat model's message for the hub prompt, filled
      with the retrieved documents (``context``) and the ``question``;
    * ``"context"`` -- the documents the retriever returned for the question,
      passed through so callers can inspect or cite the sources.

    The chain is stored in the Chainlit user session under ``"rag_chain"``,
    where the ``on_message`` handler looks it up.
    """
    # Step 1: keep the input keys and add the retrieved documents as
    # "context". ``.assign`` makes this a Runnable, so the ``|`` below composes
    # a chain; two plain dicts joined with ``|`` would be a dict merge instead.
    retrieve = RunnablePassthrough.assign(context=itemgetter("question") | retriever)

    # Step 2: generate the answer while also passing the documents through.
    rag_chain = retrieve | {
        "response": prompt | llm,
        "context": itemgetter("context"),
    }

    cl.user_session.set("rag_chain", rag_chain)
@cl.on_message
async def main(message: cl.Message):
    """Answer a user message with this session's RAG chain, streaming tokens.

    Looks up the chain stored under ``"rag_chain"`` in the user session and
    runs it on the message text. The LLM's output is forwarded to the UI
    chunk by chunk as it is generated, then the message is finalized.

    Args:
        message: The incoming Chainlit message; only ``message.content`` is used.
    """
    rag_chain = cl.user_session.get("rag_chain")
    msg = cl.Message(content="")

    # The chain ends in a parallel map, so each streamed chunk is a partial
    # dict: the "context" branch arrives as a whole, while the "response"
    # branch arrives as incremental message chunks from the LLM.
    async for chunk in rag_chain.astream(
        {"question": message.content},
        config=RunnableConfig(callbacks=[cl.LangchainCallbackHandler()]),
    ):
        if "response" in chunk:
            await msg.stream_token(chunk["response"].content)

    await msg.send()