# app_combined_prompt.py
import modules.app_constants as app_constants
from langchain_openai import ChatOpenAI
from langchain.chains import RetrievalQAWithSourcesChain
from openai import OpenAI
from modules import app_logger, common_utils, app_st_session_utils

# Use the shared application logger from the app_logger module
app_logger = app_logger.app_logger

# Define a function to query the language model
def query_llm(prompt, page="nav_private_ai", retriever=None, message_store=None, use_retrieval_chain=False, last_page=None, username=""):
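    """Query the configured language model and return a formatted response.

    When use_retrieval_chain is True, the prompt is answered by a
    RetrievalQAWithSourcesChain built on the supplied retriever; otherwise
    the accumulated page messages are sent directly to the OpenAI-compatible
    endpoint. Returns the formatted answer, or an error message on failure.
    """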
    try:
        # Choose the language model client based on the use_retrieval_chain flag
        if use_retrieval_chain:
            app_logger.info("Using ChatOpenAI with RetrievalQAWithSourcesChain")
            llm = ChatOpenAI(
                model_name=app_constants.MODEL_NAME,
                openai_api_key=app_constants.openai_api_key,
                base_url=app_constants.local_model_uri,
                streaming=True
            )
            qa = RetrievalQAWithSourcesChain.from_chain_type(
                llm=llm,
                chain_type=app_constants.RAG_TECHNIQUE,
                retriever=retriever,
                return_source_documents=False
            )
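            # The chain returns a dict with 'answer' and 'sources' keys,
            # which are unpacked below when formatting the response.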
        else:
            app_logger.info("Using direct OpenAI API call")
            llm = OpenAI(
                base_url=app_constants.local_model_uri,
                api_key=app_constants.openai_api_key
            )

        # Update page messages if there's a change in the page
        if last_page != page:
            app_logger.info(f"Updating messages for new page: {page}")
            common_utils.get_system_role(page, message_store)

        # Construct messages to send to the LLM, excluding timestamps
        messages_to_send = common_utils.construct_messages_to_send(page, message_store, prompt)
        app_logger.debug(messages_to_send)
        # Send the messages to the LLM and retrieve the response
        if use_retrieval_chain:
            response = qa.invoke(prompt)
        else:
            response = llm.chat.completions.create(
                model=app_constants.MODEL_NAME,
                messages=messages_to_send
            )

        # Extract the answer (and, for the retrieval chain, its sources)
        raw_msg = response.get('answer', '') if use_retrieval_chain else response.choices[0].message.content
        source_info = response.get('sources', '').strip() if use_retrieval_chain else ''
        formatted_msg = app_st_session_utils.format_response(
            f"{raw_msg}\n\nSource: {source_info}" if source_info else raw_msg
        )

        return formatted_msg

    except Exception as e:
        error_message = f"An error occurred while querying the language model: {e}"
        app_logger.error(error_message)
        return error_message
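
# Minimal usage sketch (illustrative only): assumes an OpenAI-compatible
# server is reachable at app_constants.local_model_uri and that the
# surrounding app supplies a message_store; the names below are hypothetical.
#
#     from app_combined_prompt import query_llm
#     answer = query_llm(
#         "Summarize the uploaded document.",
#         page="nav_private_ai",
#         message_store=message_store,  # app-provided message store
#         use_retrieval_chain=False,
#     )
#     print(answer)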