# chat/model_manage.py
import google.generativeai as genai
import chat.arxiv_bot.arxiv_bot_utils as utils
import json
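# Module-level model handles, created lazily by get_model().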
model = None
model_retrieval = None
model_answer = None
RETRIEVAL_INSTRUCT = """You are an automated chatbot that responds with exactly one of the actions below, based on the user's question.
1. If the guest is asking about a science topic, respond with the information in the JSON schema below:
{
"keywords": [a list of string keywords about the topic],
"description": "a paragraph describing the topic in about 50 to 100 words"
}
2. If the guest is not asking for any information or documents, respond in the JSON schema below:
{
"answer": "your answer to the user question"
}"""
ANSWER_INSTRUCT = """You are a library assistant that helps answer customer questions based on the information given.
You always answer naturally and politely, in a conversational form.
You must introduce all the records given; each must contain the title, the authors, and the link to the PDF file."""
def create_model():
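    """Configure the Gemini API from apikey.txt and build the three model handles."""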
with open("apikey.txt","r") as apikey:
key = apikey.readline()
genai.configure(api_key=key)
for m in genai.list_models():
if 'generateContent' in m.supported_generation_methods:
print(m.name)
print("He was there")
config = genai.GenerationConfig(max_output_tokens=2048,
temperature=1.0)
safety_settings = [
{
"category": "HARM_CATEGORY_DANGEROUS",
"threshold": "BLOCK_NONE",
},
{
"category": "HARM_CATEGORY_HARASSMENT",
"threshold": "BLOCK_NONE",
},
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"threshold": "BLOCK_NONE",
},
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"threshold": "BLOCK_NONE",
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"threshold": "BLOCK_NONE",
},
]
global model, model_retrieval, model_answer
model = genai.GenerativeModel("gemini-1.5-pro-latest",
generation_config=config,
safety_settings=safety_settings)
model_retrieval = genai.GenerativeModel("gemini-1.5-pro-latest",
generation_config=config,
safety_settings=safety_settings,
system_instruction=RETRIEVAL_INSTRUCT)
model_answer = genai.GenerativeModel("gemini-1.5-pro-latest",
generation_config=config,
safety_settings=safety_settings,
system_instruction=ANSWER_INSTRUCT)
return model, model_answer, model_retrieval
def get_model():
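    """Return the shared models, creating them on first call."""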
    global model, model_answer, model_retrieval
    if model is None:
        # Initialize the models here on first use (lazy singleton).
        model, model_answer, model_retrieval = create_model()
    return model, model_answer, model_retrieval
def extract_keyword_prompt(query):
"""A prompt that return a JSON block as arguments for querying database"""
prompt = """[INST] SYSTEM: You are an auto chatbot that response with only one action below based on user question.
1. If the guest question is asking about a science topic, you need to respond the information in JSON schema below:
{
"keywords": [a list of string keywords about the topic],
"description": "a paragraph describing the topic in about 50 to 100 words"
}
2. If the guest is not asking for any informations or documents, you need to respond in JSON schema below:
{
"answer": "your answer to the user question"
}
QUESTION: """ + query + """[/INST]
ANSWER: """
return prompt
def make_answer_prompt(question, contexts):
    """Return the final-answer prompt, built from the retrieved context."""
    prompt = (
        """[INST] You are a library assistant that helps answer the customer's QUESTION based on the INFORMATION given.
You always answer naturally and politely, in a conversational form.
You must introduce all the records given; each must contain the title, the authors, and the link to the PDF file.
QUESTION: {question}
INFORMATION: '{contexts}'
[/INST]
ANSWER:
"""
    ).format(question=question, contexts=contexts)
return prompt
def retrieval_chat_template(question):
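    """Wrap a user question as a Gemini chat message for the retrieval model."""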
    return {
        "role": "user",
        "parts": [f"QUESTION: {question} \n ANSWER:"]
    }
def answer_chat_template(question, contexts):
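    """Wrap a question plus retrieved context as a Gemini chat message for the answer model."""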
    return {
        "role": "user",
        "parts": [f"QUESTION: {question} \n INFORMATION: {contexts} \n ANSWER:"]
    }
def response(args, db_instance):
"""Create response context, based on input arguments"""
    keys = list(args.keys())
    if "answer" in keys:
        return args['answer'], None  # direct answer, no document retrieval needed
if "keywords" in keys:
# perform query
query_texts = args["description"]
keywords = args["keywords"]
results = utils.db.query_relevant(keywords=keywords, query_texts=query_texts)
# print(results)
ids = results['metadatas'][0]
if len(ids) == 0:
# go crawl some
new_records = utils.crawl_arxiv(keyword_list=keywords, max_results=10)
print("Got new records: ",len(new_records))
            if isinstance(new_records, str):
                return "Error occurred, information not found", new_records
utils.db.add(new_records)
db_instance.add(new_records)
results = utils.db.query_relevant(keywords=keywords, query_texts=query_texts)
ids = results['metadatas'][0]
print("Re-queried on chromadb, results: ",ids)
    paper_id = [meta['paper_id'] for meta in ids]
paper_info = db_instance.query_id(paper_id)
print(paper_info)
records = [] # get title (2), author (3), link (6)
result_string = ""
if paper_info:
        for i in range(len(paper_info)):
            result_string += "Record no.{} - Title: {}, Author: {}, Link: {}, ".format(
                i + 1, paper_info[i][2], paper_info[i][3], paper_info[i][6])
            record_id = paper_info[i][0]
            selected_document = utils.db.query_exact(record_id)["documents"]
            doc_str = "Summary:"
            for doc in selected_document:
                doc_str += doc + " "
            result_string += doc_str
            records.append([paper_info[i][2], paper_info[i][3], paper_info[i][6]])
return result_string, records
else:
return "Information not found", "Information not found"
# invoke llm and return result
# if "title" in keys:
# title = args['title']
# authors = utils.authors_str_to_list(args['author'])
# paper_info = db_instance.query(title = title,author = authors)
# # if query not found then go crawl brh
# # print(paper_info)
# if len(paper_info) == 0:
# new_records = utils.crawl_exact_paper(title=title,author=authors)
# print("Got new records: ",len(new_records))
# if type(new_records) == str:
# # print(new_records)
# return "Error occured, information not found", "Information not found"
# utils.db.add(new_records)
# db_instance.add(new_records)
# paper_info = db_instance.query(title = title,author = authors)
# print("Re-queried on chromadb, results: ",paper_info)
# # -------------------------------------
# records = [] # get title (2), author (3), link (6)
# result_string = ""
# for i in range(len(paper_info)):
# result_string += "Title: {}, Author: {}, Link: {}".format(paper_info[i][2],paper_info[i][3],paper_info[i][6])
# records.append([paper_info[i][2],paper_info[i][3],paper_info[i][6]])
# # process results:
# if len(result_string) == 0:
# return "Information not found", "Information not found"
# return result_string, records
# invoke llm and return result
def full_chain_single_question(input_prompt, db_instance):
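    """Run the retrieve-then-answer chain on a single question.

    Assumes get_model() (or create_model()) has been called so the module-level
    models are initialized. Returns (retrieval-step output, final answer).
    """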
    temp_answer = None
    try:
        first_prompt = extract_keyword_prompt(input_prompt)
        temp_answer = model.generate_content(first_prompt).text
        args = json.loads(utils.trimming(temp_answer))
        contexts, results = response(args, db_instance)
        if not results:
            return "Random question, direct return", contexts
        else:
            output_prompt = make_answer_prompt(input_prompt, contexts)
            answer = model.generate_content(output_prompt).text
            return temp_answer, answer
    except Exception as e:
        # temp_answer is still None if the failure happened before the first model call
        return temp_answer if temp_answer else "No answer", "Error occurred: " + str(e)
def format_chat_history_from_web(chat_history: list):
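    """Convert web-style {"role", "content"} messages to Gemini's {"role", "parts"} format."""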
temp_chat = []
for message in chat_history:
temp_chat.append(
{
"role": message["role"],
"parts": [message["content"]]
}
)
return temp_chat
# def full_chain_history_question(chat_history: list, db_instance):
# try:
# temp_chat = format_chat_history_from_web(chat_history)
# print('Extracted temp chat: ',temp_chat)
# first_prompt = extract_keyword_prompt(temp_chat[-1]["parts"][0])
# temp_answer = model.generate_content(first_prompt).text
# args = json.loads(utils.trimming(temp_answer))
# contexts, results = response(args, db_instance)
# print('Context extracted: ',contexts)
# if not results:
# return "Random question, direct return", contexts
# else:
# QA_Prompt = make_answer_prompt(temp_chat[-1]["parts"][0], contexts)
# temp_chat[-1]["parts"] = QA_Prompt
# print(temp_chat)
# answer = model.generate_content(temp_chat).text
# return temp_answer, answer
# except Exception as e:
# # print(e)
# return temp_answer, "Error occured: " + str(e)
def full_chain_history_question(chat_history: list, db_instance):
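    """Run the retrieve-then-answer chain over a full chat history.

    Uses model_retrieval to extract query arguments from the conversation, then
    model_answer to compose the reply. Returns (retrieval-step output, final answer).
    """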
    first_answer = None
    try:
        temp_chat = format_chat_history_from_web(chat_history)
        question = temp_chat[-1]['parts'][0]
        first_answer = model_retrieval.generate_content(temp_chat).text
        print(first_answer)
        args = json.loads(utils.trimming(first_answer))
        contexts, results = response(args, db_instance)
        if not results:
            return "Random question, direct return", contexts
        else:
            print('Context to answers: ', contexts)
            answer_chat = answer_chat_template(question, contexts)
            temp_chat[-1] = answer_chat
            answer = model_answer.generate_content(temp_chat).text
            return first_answer, answer
    except Exception as e:
        # first_answer is still None if the retrieval call itself failed
        if first_answer:
            return first_answer, "Error occurred: " + str(e)
        return "No answer", "Error occurred: " + str(e)