File size: 2,803 Bytes
8e29341 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
# app_combined_prompt.py
import modules.app_constants as app_constants # Ensure this is correctly referenced
from langchain_openai import ChatOpenAI
from langchain.chains import RetrievalQAWithSourcesChain
from openai import OpenAI
from modules import app_logger, common_utils, app_st_session_utils
# Rebind the name `app_logger` from the imported `modules.app_logger` module to
# the logger object that module exposes. After this line the module itself is
# no longer reachable under this name; only the logger object is.
app_logger = app_logger.app_logger
# Define a function to query the language model
def query_llm(prompt, page="nav_private_ai", retriever=None, message_store=None, use_retrieval_chain=False, last_page=None, username=""):
    """Send ``prompt`` to the configured language model and return the formatted reply.

    Two back ends are supported:

    * ``use_retrieval_chain=True``: a LangChain ``RetrievalQAWithSourcesChain``
      built on ``ChatOpenAI`` is invoked with ``prompt`` and ``retriever``.
    * otherwise: the OpenAI client's chat-completions endpoint is called with
      the messages built by ``common_utils.construct_messages_to_send``.

    Args:
        prompt: User prompt to answer.
        page: Page identifier handed to the ``common_utils`` helpers.
        retriever: Retriever given to the retrieval chain; only used when
            ``use_retrieval_chain`` is true.
        message_store: Message store handed to the ``common_utils`` helpers.
        use_retrieval_chain: Select the retrieval chain instead of the direct
            chat-completions call.
        last_page: Page to compare against ``page``; when they differ,
            ``common_utils.get_system_role`` is called for ``page``.
        username: Not used by this function; kept so existing callers that
            pass it keep working.

    Returns:
        The value returned by ``app_st_session_utils.format_response`` for the
        model's answer (with a trailing ``Source: ...`` paragraph when the
        retrieval chain reports sources). If anything raises, the error is
        logged and a human-readable error string is returned instead of the
        exception propagating.
    """
    try:
        # Choose the language model client based on the use_retrieval_chain flag.
        if use_retrieval_chain:
            app_logger.info("Using ChatOpenAI with RetrievalQAWithSourcesChain")
            llm = ChatOpenAI(
                model_name=app_constants.MODEL_NAME,
                openai_api_key=app_constants.openai_api_key,
                base_url=app_constants.local_model_uri,
                streaming=True
            )
            qa = RetrievalQAWithSourcesChain.from_chain_type(
                llm=llm,
                chain_type=app_constants.RAG_TECHNIQUE,
                retriever=retriever,
                return_source_documents=False
            )
        else:
            app_logger.info("Using direct OpenAI API call")
            llm = OpenAI(
                base_url=app_constants.local_model_uri,
                api_key=app_constants.openai_api_key
            )

        # Update page messages if there's a change in the page.
        if last_page != page:
            app_logger.info(f"Updating messages for new page: {page}")
            common_utils.get_system_role(page, message_store)

        # Construct messages to send to the LLM, excluding timestamps.
        # NOTE(review): this also runs on the retrieval path, where the result
        # is only logged; kept because the helper may update message_store --
        # confirm against common_utils before removing.
        messages_to_send = common_utils.construct_messages_to_send(page, message_store, prompt)
        app_logger.debug(messages_to_send)

        # Send the request and pull the answer (and sources, if any) out of the
        # back-end specific response shape.
        if use_retrieval_chain:
            response = qa.invoke(prompt)
            raw_msg = response.get('answer')
            # `or ''` also covers a 'sources' key that is present but None.
            source_info = (response.get('sources') or '').strip()
        else:
            response = llm.chat.completions.create(
                model=app_constants.MODEL_NAME,
                messages=messages_to_send
            )
            raw_msg = response.choices[0].message.content
            source_info = ''

        if source_info:
            # Put the sources in their own paragraph instead of gluing
            # "Source:" directly onto the last word of the answer, and tolerate
            # a missing answer rather than failing on None + str.
            raw_msg = f"{(raw_msg or '').rstrip()}\n\nSource: {source_info}"

        return app_st_session_utils.format_response(raw_msg)
    except Exception as e:
        # Boundary handler: callers receive a displayable string, never an exception.
        error_message = f"An error occurred while querying the language model: {e}"
        app_logger.error(error_message)
        return error_message
|