Spaces:
Sleeping
Sleeping
theekshana
commited on
Commit
•
38be0ae
1
Parent(s):
fce2427
Fix error LLMchain validation
Browse files- reggpt/agent/agent.py +0 -33
- reggpt/agents/agent.py +0 -0
- reggpt/api/router.py +2 -4
- reggpt/app.py +0 -2
- reggpt/chains/llmChain.py +9 -9
- reggpt/configs/api.py +8 -8
- reggpt/configs/{config.py → model.py} +0 -0
- reggpt/controller/agent.py +4 -19
- reggpt/controller/router.py +8 -25
- reggpt/logs/__init__.py +0 -0
- reggpt/logs/log +0 -0
- reggpt/prompts/general.py +5 -3
- reggpt/prompts/router.py +2 -2
- reggpt/routers/controller.py +2 -2
- reggpt/routers/general.py +5 -20
- reggpt/routers/out_of_domain.py +1 -28
- reggpt/routers/qa.py +4 -12
- reggpt/routers/qaPipeline.py +1 -1
reggpt/agent/agent.py
DELETED
@@ -1,33 +0,0 @@
|
|
1 |
-
import logging
|
2 |
-
logger = logging.getLogger(__name__)
|
3 |
-
from fastapi import HTTPException
|
4 |
-
import time
|
5 |
-
from routers.qaPipeline import run_router_chain, chain_selector
|
6 |
-
|
7 |
-
def run_agent(query):
|
8 |
-
try:
|
9 |
-
logger.info(f"run_agent : Question: {query}")
|
10 |
-
print(f"---------------- run_agent : Question: {query} ----------------")
|
11 |
-
# Get the answer from the chain
|
12 |
-
start = time.time()
|
13 |
-
chain_type = run_router_chain(query)
|
14 |
-
res = chain_selector(chain_type,query)
|
15 |
-
end = time.time()
|
16 |
-
|
17 |
-
# log the result
|
18 |
-
logger.error(f"---------------- Answer (took {round(end - start, 2)} s.) \n: {res}")
|
19 |
-
print(f" \n ---------------- Answer (took {round(end - start, 2)} s.): -------------- \n")
|
20 |
-
|
21 |
-
return res
|
22 |
-
|
23 |
-
except HTTPException as e:
|
24 |
-
print('HTTPException')
|
25 |
-
print(e)
|
26 |
-
logger.exception(e)
|
27 |
-
raise e
|
28 |
-
|
29 |
-
except Exception as e:
|
30 |
-
print('Exception')
|
31 |
-
print(e)
|
32 |
-
logger.exception(e)
|
33 |
-
raise e
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
reggpt/agents/agent.py
ADDED
File without changes
|
reggpt/api/router.py
CHANGED
@@ -1,18 +1,15 @@
|
|
1 |
import time
|
2 |
-
|
3 |
|
4 |
from fastapi import APIRouter, HTTPException, status
|
5 |
from fastapi import HTTPException, status
|
6 |
|
7 |
from schemas.schema import UserQuery, LoginRequest, UserModel
|
8 |
from routers.controller import get_QA_Answers, get_avaliable_models
|
9 |
-
|
10 |
from configs.api import API_ENDPOINT_LOGIN,API_ENDPOINT_CHAT, API_ENDPOINT_HEALTH, API_ENDPOINT_MODEL
|
11 |
-
import logging
|
12 |
|
13 |
logger = logging.getLogger(__name__)
|
14 |
|
15 |
-
|
16 |
class ChatAPI:
|
17 |
|
18 |
def __init__(self):
|
@@ -25,6 +22,7 @@ class ChatAPI:
|
|
25 |
self.router.add_api_route(API_ENDPOINT_CHAT, self.chat, methods=["POST"])
|
26 |
|
27 |
async def hello(self):
|
|
|
28 |
return "Hello there!"
|
29 |
|
30 |
async def avaliable_models(self):
|
|
|
1 |
import time
|
2 |
+
import logging
|
3 |
|
4 |
from fastapi import APIRouter, HTTPException, status
|
5 |
from fastapi import HTTPException, status
|
6 |
|
7 |
from schemas.schema import UserQuery, LoginRequest, UserModel
|
8 |
from routers.controller import get_QA_Answers, get_avaliable_models
|
|
|
9 |
from configs.api import API_ENDPOINT_LOGIN,API_ENDPOINT_CHAT, API_ENDPOINT_HEALTH, API_ENDPOINT_MODEL
|
|
|
10 |
|
11 |
logger = logging.getLogger(__name__)
|
12 |
|
|
|
13 |
class ChatAPI:
|
14 |
|
15 |
def __init__(self):
|
|
|
22 |
self.router.add_api_route(API_ENDPOINT_CHAT, self.chat, methods=["POST"])
|
23 |
|
24 |
async def hello(self):
|
25 |
+
print(API_ENDPOINT_HEALTH)
|
26 |
return "Hello there!"
|
27 |
|
28 |
async def avaliable_models(self):
|
reggpt/app.py
CHANGED
@@ -23,7 +23,6 @@ from fastapi import FastAPI
|
|
23 |
from fastapi.middleware.cors import CORSMiddleware
|
24 |
|
25 |
from configs.api import API_TITLE, API_VERSION, API_DESCRIPTION
|
26 |
-
|
27 |
from api.router import ChatAPI
|
28 |
|
29 |
def filer():
|
@@ -65,7 +64,6 @@ app.add_middleware(
|
|
65 |
)
|
66 |
|
67 |
if __name__ == "__main__":
|
68 |
-
|
69 |
# config = uvicorn.Config("server:app",host=host, port=port, log_config= logging.basicConfig())
|
70 |
config = uvicorn.Config("app:app", host=host, port=port)
|
71 |
server = uvicorn.Server(config)
|
|
|
23 |
from fastapi.middleware.cors import CORSMiddleware
|
24 |
|
25 |
from configs.api import API_TITLE, API_VERSION, API_DESCRIPTION
|
|
|
26 |
from api.router import ChatAPI
|
27 |
|
28 |
def filer():
|
|
|
64 |
)
|
65 |
|
66 |
if __name__ == "__main__":
|
|
|
67 |
# config = uvicorn.Config("server:app",host=host, port=port, log_config= logging.basicConfig())
|
68 |
config = uvicorn.Config("app:app", host=host, port=port)
|
69 |
server = uvicorn.Server(config)
|
reggpt/chains/llmChain.py
CHANGED
@@ -28,11 +28,11 @@ from langchain.chains import ConversationalRetrievalChain
|
|
28 |
# from conversationBufferWindowMemory import ConversationBufferWindowMemory
|
29 |
|
30 |
# from langchain.prompts import PromptTemplate
|
31 |
-
|
32 |
-
from prompts import
|
33 |
-
from prompts import
|
34 |
-
from prompts import
|
35 |
-
from prompts import
|
36 |
|
37 |
|
38 |
def get_qa_chain(model_type,retriever):
|
@@ -70,8 +70,8 @@ def get_general_qa_chain(model_type):
|
|
70 |
|
71 |
try:
|
72 |
general_qa_llm = get_model(model_type)
|
73 |
-
|
74 |
-
general_qa_chain = general_qa_chain_prompt | general_qa_llm
|
75 |
|
76 |
logger.info("general_qa_chain created")
|
77 |
return general_qa_chain
|
@@ -87,8 +87,8 @@ def get_router_chain(model_type):
|
|
87 |
|
88 |
try:
|
89 |
router_llm = get_model(model_type)
|
90 |
-
|
91 |
-
router_chain = router_prompt | router_llm
|
92 |
|
93 |
logger.info("router_chain created")
|
94 |
return router_chain
|
|
|
28 |
# from conversationBufferWindowMemory import ConversationBufferWindowMemory
|
29 |
|
30 |
# from langchain.prompts import PromptTemplate
|
31 |
+
from langchain.chains import LLMChain
|
32 |
+
from prompts.document_combine import document_combine_prompt
|
33 |
+
from prompts.retrieval import retrieval_qa_chain_prompt
|
34 |
+
from prompts.general import general_qa_chain_prompt
|
35 |
+
from prompts.router import router_prompt
|
36 |
|
37 |
|
38 |
def get_qa_chain(model_type,retriever):
|
|
|
70 |
|
71 |
try:
|
72 |
general_qa_llm = get_model(model_type)
|
73 |
+
general_qa_chain = LLMChain(llm=general_qa_llm, prompt=general_qa_chain_prompt)
|
74 |
+
# general_qa_chain = general_qa_chain_prompt | general_qa_llm
|
75 |
|
76 |
logger.info("general_qa_chain created")
|
77 |
return general_qa_chain
|
|
|
87 |
|
88 |
try:
|
89 |
router_llm = get_model(model_type)
|
90 |
+
router_chain = LLMChain(llm=router_llm, prompt=router_prompt)
|
91 |
+
# router_chain = router_prompt | router_llm
|
92 |
|
93 |
logger.info("router_chain created")
|
94 |
return router_chain
|
reggpt/configs/api.py
CHANGED
@@ -17,12 +17,12 @@ API_TITLE = "RegGPT Back End v1"
|
|
17 |
API_VERSION = "0.1.0"
|
18 |
API_DESCRIPTION = "API_DESC"
|
19 |
|
20 |
-
API_ENDPOINT_PREFIX = "/api/v1"
|
21 |
-
API_DOCS_URL = f"{API_ENDPOINT_PREFIX}
|
22 |
-
API_REDOC_URL = f"{API_ENDPOINT_PREFIX}
|
23 |
-
API_OPENAPI_URL = f"{API_ENDPOINT_PREFIX}
|
24 |
|
25 |
-
API_ENDPOINT_HEALTH = f"{API_ENDPOINT_PREFIX}
|
26 |
-
API_ENDPOINT_CHAT = f"{API_ENDPOINT_PREFIX}
|
27 |
-
API_ENDPOINT_MODEL = f"{API_ENDPOINT_PREFIX}
|
28 |
-
API_ENDPOINT_LOGIN = f"{API_ENDPOINT_PREFIX}
|
|
|
17 |
API_VERSION = "0.1.0"
|
18 |
API_DESCRIPTION = "API_DESC"
|
19 |
|
20 |
+
API_ENDPOINT_PREFIX = "/api/v1/"
|
21 |
+
API_DOCS_URL = f"{API_ENDPOINT_PREFIX}docs"
|
22 |
+
API_REDOC_URL = f"{API_ENDPOINT_PREFIX}redoc"
|
23 |
+
API_OPENAPI_URL = f"{API_ENDPOINT_PREFIX}openapi.json"
|
24 |
|
25 |
+
API_ENDPOINT_HEALTH = f"{API_ENDPOINT_PREFIX}health"
|
26 |
+
API_ENDPOINT_CHAT = f"{API_ENDPOINT_PREFIX}chat"
|
27 |
+
API_ENDPOINT_MODEL = f"{API_ENDPOINT_PREFIX}models"
|
28 |
+
API_ENDPOINT_LOGIN = f"{API_ENDPOINT_PREFIX}login"
|
reggpt/configs/{config.py → model.py}
RENAMED
File without changes
|
reggpt/controller/agent.py
CHANGED
@@ -4,30 +4,15 @@ import logging
|
|
4 |
logger = logging.getLogger(__name__)
|
5 |
from dotenv import load_dotenv
|
6 |
from fastapi import HTTPException
|
7 |
-
from chains.llmChain import get_qa_chain, get_general_qa_chain, get_router_chain
|
8 |
-
from output_parsers.output_parser import general_qa_chain_output_parser, qa_chain_output_parser, out_of_domain_chain_parser
|
9 |
|
10 |
-
from configs.config import QA_MODEL_TYPE, GENERAL_QA_MODEL_TYPE, ROUTER_MODEL_TYPE, Multi_Query_MODEL_TYPE
|
11 |
-
from utils.retriever import load_faiss_retriever, load_ensemble_retriever, load_multi_query_retriever
|
12 |
load_dotenv()
|
13 |
|
14 |
verbose = os.environ.get('VERBOSE')
|
15 |
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
# model_type="tiiuae/falcon-7b-instruct"
|
21 |
-
|
22 |
-
# retriever=load_faiss_retriever()
|
23 |
-
retriever=load_ensemble_retriever()
|
24 |
-
# retriever=load_multi_query_retriever(multi_query_model_type)
|
25 |
-
logger.info("retriever loaded:")
|
26 |
-
|
27 |
-
qa_chain= get_qa_chain(qa_model_type,retriever)
|
28 |
-
general_qa_chain= get_general_qa_chain(general_qa_model_type)
|
29 |
-
router_chain= get_router_chain(router_model_type)
|
30 |
-
|
31 |
|
32 |
def chain_selector(chain_type, query):
|
33 |
chain_type = chain_type.lower().strip()
|
|
|
4 |
logger = logging.getLogger(__name__)
|
5 |
from dotenv import load_dotenv
|
6 |
from fastapi import HTTPException
|
|
|
|
|
7 |
|
|
|
|
|
8 |
load_dotenv()
|
9 |
|
10 |
verbose = os.environ.get('VERBOSE')
|
11 |
|
12 |
+
from controller.router import run_router_chain
|
13 |
+
from routers.out_of_domain import run_out_of_domain_chain
|
14 |
+
from routers.general import run_general_qa_chain
|
15 |
+
from routers.qa import run_qa_chain
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
|
17 |
def chain_selector(chain_type, query):
|
18 |
chain_type = chain_type.lower().strip()
|
reggpt/controller/router.py
CHANGED
@@ -14,34 +14,17 @@
|
|
14 |
*************************************************************************/
|
15 |
"""
|
16 |
|
17 |
-
import os
|
18 |
import time
|
19 |
import logging
|
20 |
logger = logging.getLogger(__name__)
|
21 |
-
|
22 |
-
from
|
23 |
-
|
24 |
-
from
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
verbose = os.environ.get('VERBOSE')
|
31 |
-
|
32 |
-
qa_model_type=QA_MODEL_TYPE
|
33 |
-
general_qa_model_type=GENERAL_QA_MODEL_TYPE
|
34 |
-
router_model_type=ROUTER_MODEL_TYPE #"google/flan-t5-xxl"
|
35 |
-
multi_query_model_type=Multi_Query_MODEL_TYPE #"google/flan-t5-xxl"
|
36 |
-
# model_type="tiiuae/falcon-7b-instruct"
|
37 |
-
|
38 |
-
# retriever=load_faiss_retriever()
|
39 |
-
retriever=load_ensemble_retriever()
|
40 |
-
# retriever=load_multi_query_retriever(multi_query_model_type)
|
41 |
-
logger.info("retriever loaded:")
|
42 |
-
|
43 |
-
qa_chain= get_qa_chain(qa_model_type,retriever)
|
44 |
-
general_qa_chain= get_general_qa_chain(general_qa_model_type)
|
45 |
router_chain= get_router_chain(router_model_type)
|
46 |
|
47 |
def run_router_chain(query):
|
|
|
14 |
*************************************************************************/
|
15 |
"""
|
16 |
|
|
|
17 |
import time
|
18 |
import logging
|
19 |
logger = logging.getLogger(__name__)
|
20 |
+
|
21 |
+
from chains.llmChain import get_router_chain
|
22 |
+
|
23 |
+
from configs.model import ROUTER_MODEL_TYPE
|
24 |
+
|
25 |
+
|
26 |
+
router_model_type=ROUTER_MODEL_TYPE
|
27 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
router_chain= get_router_chain(router_model_type)
|
29 |
|
30 |
def run_router_chain(query):
|
reggpt/logs/__init__.py
ADDED
File without changes
|
reggpt/logs/log
ADDED
File without changes
|
reggpt/prompts/general.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
from langchain.prompts import PromptTemplate
|
2 |
|
3 |
-
general_qa_template_Mixtral_V0= """
|
4 |
You are the AI assistant of company 'boardpac' which provide services to company board members related to banking and financial sector.
|
5 |
you can answer Banking and Financial Services Sector like Banking & Financial regulations, legal framework, governance framework, compliance requirements as per Central Bank regulations related question .
|
6 |
|
@@ -21,5 +21,7 @@ Additionally, it's important to note that this AI assistant has access to an int
|
|
21 |
Question: {question}
|
22 |
"""
|
23 |
|
24 |
-
general_qa_chain_prompt = PromptTemplate(
|
25 |
-
|
|
|
|
|
|
1 |
from langchain.prompts import PromptTemplate
|
2 |
|
3 |
+
general_qa_template_Mixtral_V0 = """
|
4 |
You are the AI assistant of company 'boardpac' which provide services to company board members related to banking and financial sector.
|
5 |
you can answer Banking and Financial Services Sector like Banking & Financial regulations, legal framework, governance framework, compliance requirements as per Central Bank regulations related question .
|
6 |
|
|
|
21 |
Question: {question}
|
22 |
"""
|
23 |
|
24 |
+
# general_qa_chain_prompt = PromptTemplate(
|
25 |
+
# input_variables=["question"], template=general_qa_template_Mixtral_V0
|
26 |
+
# )
|
27 |
+
general_qa_chain_prompt = PromptTemplate.from_template(general_qa_template_Mixtral_V0)
|
reggpt/prompts/router.py
CHANGED
@@ -13,5 +13,5 @@ Give the correct name of question type. If you are not sure return "Not Sure" in
|
|
13 |
|
14 |
Question : {question}
|
15 |
"""
|
16 |
-
router_prompt=PromptTemplate(input_variables=["question"],template=router_template_Mixtral_V0)
|
17 |
-
|
|
|
13 |
|
14 |
Question : {question}
|
15 |
"""
|
16 |
+
# router_prompt=PromptTemplate(input_variables=["question"],template=router_template_Mixtral_V0)
|
17 |
+
router_prompt=PromptTemplate.from_template(router_template_Mixtral_V0)
|
reggpt/routers/controller.py
CHANGED
@@ -16,13 +16,13 @@
|
|
16 |
|
17 |
import logging
|
18 |
logger = logging.getLogger(__name__)
|
19 |
-
from configs.
|
20 |
|
21 |
# from qaPipeline import QAPipeline
|
22 |
# from qaPipeline_retriever_only import QAPipeline
|
23 |
# qaPipeline = QAPipeline()
|
24 |
|
25 |
-
from
|
26 |
|
27 |
def get_QA_Answers(userQuery):
|
28 |
# model=userQuery.model
|
|
|
16 |
|
17 |
import logging
|
18 |
logger = logging.getLogger(__name__)
|
19 |
+
from configs.model import AVALIABLE_MODELS , MEMORY_WINDOW_K
|
20 |
|
21 |
# from qaPipeline import QAPipeline
|
22 |
# from qaPipeline_retriever_only import QAPipeline
|
23 |
# qaPipeline = QAPipeline()
|
24 |
|
25 |
+
from controller.agent import run_agent
|
26 |
|
27 |
def get_QA_Answers(userQuery):
|
28 |
# model=userQuery.model
|
reggpt/routers/general.py
CHANGED
@@ -1,33 +1,18 @@
|
|
1 |
|
2 |
-
import os
|
3 |
import time
|
4 |
import logging
|
5 |
logger = logging.getLogger(__name__)
|
6 |
-
from dotenv import load_dotenv
|
7 |
-
from fastapi import HTTPException
|
8 |
-
from reggpt.chains.llmChain import get_qa_chain, get_general_qa_chain, get_router_chain
|
9 |
-
from reggpt.output_parsers.output_parser import general_qa_chain_output_parser, qa_chain_output_parser, out_of_domain_chain_parser
|
10 |
|
11 |
-
from
|
12 |
-
from
|
13 |
-
|
|
|
14 |
|
15 |
-
verbose = os.environ.get('VERBOSE')
|
16 |
|
17 |
-
qa_model_type=QA_MODEL_TYPE
|
18 |
-
general_qa_model_type=GENERAL_QA_MODEL_TYPE
|
19 |
-
router_model_type=ROUTER_MODEL_TYPE #"google/flan-t5-xxl"
|
20 |
-
multi_query_model_type=Multi_Query_MODEL_TYPE #"google/flan-t5-xxl"
|
21 |
-
# model_type="tiiuae/falcon-7b-instruct"
|
22 |
|
23 |
-
|
24 |
-
retriever=load_ensemble_retriever()
|
25 |
-
# retriever=load_multi_query_retriever(multi_query_model_type)
|
26 |
-
logger.info("retriever loaded:")
|
27 |
|
28 |
-
qa_chain= get_qa_chain(qa_model_type,retriever)
|
29 |
general_qa_chain= get_general_qa_chain(general_qa_model_type)
|
30 |
-
router_chain= get_router_chain(router_model_type)
|
31 |
|
32 |
def run_general_qa_chain(query):
|
33 |
try:
|
|
|
1 |
|
|
|
2 |
import time
|
3 |
import logging
|
4 |
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
5 |
|
6 |
+
from chains.llmChain import get_general_qa_chain
|
7 |
+
from output_parsers.output_parser import general_qa_chain_output_parser
|
8 |
+
|
9 |
+
from configs.model import GENERAL_QA_MODEL_TYPE
|
10 |
|
|
|
11 |
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
+
general_qa_model_type=GENERAL_QA_MODEL_TYPE
|
|
|
|
|
|
|
14 |
|
|
|
15 |
general_qa_chain= get_general_qa_chain(general_qa_model_type)
|
|
|
16 |
|
17 |
def run_general_qa_chain(query):
|
18 |
try:
|
reggpt/routers/out_of_domain.py
CHANGED
@@ -1,31 +1,4 @@
|
|
1 |
-
import
|
2 |
-
import time
|
3 |
-
import logging
|
4 |
-
logger = logging.getLogger(__name__)
|
5 |
-
from dotenv import load_dotenv
|
6 |
-
from fastapi import HTTPException
|
7 |
-
from reggpt.chains.llmChain import get_qa_chain, get_general_qa_chain, get_router_chain
|
8 |
-
from reggpt.output_parsers.output_parser import general_qa_chain_output_parser, qa_chain_output_parser, out_of_domain_chain_parser
|
9 |
|
10 |
-
from reggpt.configs.config import QA_MODEL_TYPE, GENERAL_QA_MODEL_TYPE, ROUTER_MODEL_TYPE, Multi_Query_MODEL_TYPE
|
11 |
-
from reggpt.utils.retriever import load_faiss_retriever, load_ensemble_retriever, load_multi_query_retriever
|
12 |
-
load_dotenv()
|
13 |
-
|
14 |
-
verbose = os.environ.get('VERBOSE')
|
15 |
-
|
16 |
-
qa_model_type=QA_MODEL_TYPE
|
17 |
-
general_qa_model_type=GENERAL_QA_MODEL_TYPE
|
18 |
-
router_model_type=ROUTER_MODEL_TYPE #"google/flan-t5-xxl"
|
19 |
-
multi_query_model_type=Multi_Query_MODEL_TYPE #"google/flan-t5-xxl"
|
20 |
-
# model_type="tiiuae/falcon-7b-instruct"
|
21 |
-
|
22 |
-
# retriever=load_faiss_retriever()
|
23 |
-
retriever=load_ensemble_retriever()
|
24 |
-
# retriever=load_multi_query_retriever(multi_query_model_type)
|
25 |
-
logger.info("retriever loaded:")
|
26 |
-
|
27 |
-
qa_chain= get_qa_chain(qa_model_type,retriever)
|
28 |
-
general_qa_chain= get_general_qa_chain(general_qa_model_type)
|
29 |
-
router_chain= get_router_chain(router_model_type)
|
30 |
def run_out_of_domain_chain(query):
|
31 |
return out_of_domain_chain_parser(query)
|
|
|
1 |
+
from output_parsers.output_parser import out_of_domain_chain_parser
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
def run_out_of_domain_chain(query):
|
4 |
return out_of_domain_chain_parser(query)
|
reggpt/routers/qa.py
CHANGED
@@ -19,30 +19,22 @@ import time
|
|
19 |
import logging
|
20 |
logger = logging.getLogger(__name__)
|
21 |
from dotenv import load_dotenv
|
22 |
-
from
|
23 |
-
from
|
24 |
-
from reggpt.output_parsers.output_parser import general_qa_chain_output_parser, qa_chain_output_parser, out_of_domain_chain_parser
|
25 |
|
26 |
-
from
|
27 |
-
from
|
28 |
load_dotenv()
|
29 |
|
30 |
verbose = os.environ.get('VERBOSE')
|
31 |
|
32 |
qa_model_type=QA_MODEL_TYPE
|
33 |
-
general_qa_model_type=GENERAL_QA_MODEL_TYPE
|
34 |
-
router_model_type=ROUTER_MODEL_TYPE #"google/flan-t5-xxl"
|
35 |
-
multi_query_model_type=Multi_Query_MODEL_TYPE #"google/flan-t5-xxl"
|
36 |
-
# model_type="tiiuae/falcon-7b-instruct"
|
37 |
-
|
38 |
# retriever=load_faiss_retriever()
|
39 |
retriever=load_ensemble_retriever()
|
40 |
# retriever=load_multi_query_retriever(multi_query_model_type)
|
41 |
logger.info("retriever loaded:")
|
42 |
|
43 |
qa_chain= get_qa_chain(qa_model_type,retriever)
|
44 |
-
general_qa_chain= get_general_qa_chain(general_qa_model_type)
|
45 |
-
router_chain= get_router_chain(router_model_type)
|
46 |
|
47 |
def run_qa_chain(query):
|
48 |
try:
|
|
|
19 |
import logging
|
20 |
logger = logging.getLogger(__name__)
|
21 |
from dotenv import load_dotenv
|
22 |
+
from chains.llmChain import get_qa_chain
|
23 |
+
from output_parsers.output_parser import qa_chain_output_parser
|
|
|
24 |
|
25 |
+
from configs.model import QA_MODEL_TYPE
|
26 |
+
from utils.retriever import load_ensemble_retriever
|
27 |
load_dotenv()
|
28 |
|
29 |
verbose = os.environ.get('VERBOSE')
|
30 |
|
31 |
qa_model_type=QA_MODEL_TYPE
|
|
|
|
|
|
|
|
|
|
|
32 |
# retriever=load_faiss_retriever()
|
33 |
retriever=load_ensemble_retriever()
|
34 |
# retriever=load_multi_query_retriever(multi_query_model_type)
|
35 |
logger.info("retriever loaded:")
|
36 |
|
37 |
qa_chain= get_qa_chain(qa_model_type,retriever)
|
|
|
|
|
38 |
|
39 |
def run_qa_chain(query):
|
40 |
try:
|
reggpt/routers/qaPipeline.py
CHANGED
@@ -23,7 +23,7 @@ from fastapi import HTTPException
|
|
23 |
from chains.llmChain import get_qa_chain, get_general_qa_chain, get_router_chain
|
24 |
from output_parsers.output_parser import general_qa_chain_output_parser, qa_chain_output_parser, out_of_domain_chain_parser
|
25 |
|
26 |
-
from configs.
|
27 |
from utils.retriever import load_faiss_retriever, load_ensemble_retriever, load_multi_query_retriever
|
28 |
load_dotenv()
|
29 |
|
|
|
23 |
from chains.llmChain import get_qa_chain, get_general_qa_chain, get_router_chain
|
24 |
from output_parsers.output_parser import general_qa_chain_output_parser, qa_chain_output_parser, out_of_domain_chain_parser
|
25 |
|
26 |
+
from configs.model import QA_MODEL_TYPE, GENERAL_QA_MODEL_TYPE, ROUTER_MODEL_TYPE, Multi_Query_MODEL_TYPE
|
27 |
from utils.retriever import load_faiss_retriever, load_ensemble_retriever, load_multi_query_retriever
|
28 |
load_dotenv()
|
29 |
|