arxiv_chatbot / chat /model_manage.py
Artteiv's picture
Refactoring gemini functions (#9)
3826b3b verified
raw
history blame
12.1 kB
# chat/model_manage.py
# Gemini model setup plus the prompt/response helpers used by the arXiv chatbot.
import json

import google.generativeai as genai

import chat.arxiv_bot.arxiv_bot_utils as utils

# Shared model handles; they stay None until the models are created.
model = model_retrieval = model_answer = None

# System instruction for the retrieval model: it must reply with exactly one of
# two JSON shapes (a keyword search request, or a direct answer).
RETRIEVAL_INSTRUCT = """You are an auto chatbot that response with only one action below based on user question.
1. If the guest question is asking about a science topic, you need to respond the information in JSON schema below:
{
"keywords": [a list of string keywords about the topic],
"description": "a paragraph describing the topic in about 50 to 100 words"
}
2. If the guest is not asking for any informations or documents, you need to respond in JSON schema below:
{
"answer": "your answer to the user question"
}"""

# System instruction for the model that writes the final, user-facing answer.
ANSWER_INSTRUCT = """You are a library assistant that help answering customer question based on the information given.
You always answer in a conversational form naturally and politely.
You must introduce all the records given, each must contain title, authors and the link to the pdf file."""
def create_model(api_key_path="apikey.txt", model_name="gemini-1.5-pro-latest"):
    """Configure the Gemini SDK and build the three chatbot models.

    Reads the API key from the first line of *api_key_path*, passes it to
    ``genai.configure`` and builds three ``genai.GenerativeModel`` instances
    sharing one generation config and one list of safety settings:

    * a plain model with no system instruction,
    * a retrieval model using ``RETRIEVAL_INSTRUCT``,
    * an answering model using ``ANSWER_INSTRUCT``.

    The three models are also stored in the module-level ``model``,
    ``model_retrieval`` and ``model_answer`` globals.

    Args:
        api_key_path: File whose first line holds the API key.
        model_name: Gemini model identifier used for all three models.

    Returns:
        ``(model, model_answer, model_retrieval)`` -- note the order.

    Raises:
        OSError: If the key file cannot be opened.
        ValueError: If the key file's first line is empty.
    """
    global model, model_retrieval, model_answer

    with open(api_key_path, "r", encoding="utf-8") as key_file:
        # readline() keeps the trailing newline; strip it so the newline is
        # not sent as part of the key.
        key = key_file.readline().strip()
    if not key:
        raise ValueError(f"No API key found in {api_key_path!r}")
    genai.configure(api_key=key)

    config = genai.GenerationConfig(max_output_tokens=2048, temperature=1.0)
    # Every listed category is set to BLOCK_NONE.
    # NOTE(review): the Gemini API docs list HARM_CATEGORY_DANGEROUS as a
    # legacy (PaLM) category -- confirm the SDK/model still accepts it.
    safety_settings = [
        {"category": category, "threshold": "BLOCK_NONE"}
        for category in (
            "HARM_CATEGORY_DANGEROUS",
            "HARM_CATEGORY_HARASSMENT",
            "HARM_CATEGORY_HATE_SPEECH",
            "HARM_CATEGORY_SEXUALLY_EXPLICIT",
            "HARM_CATEGORY_DANGEROUS_CONTENT",
        )
    ]

    def _build(**extra):
        # Shared constructor so all three models use the same name/config/safety.
        return genai.GenerativeModel(
            model_name,
            generation_config=config,
            safety_settings=safety_settings,
            **extra,
        )

    model = _build()
    model_retrieval = _build(system_instruction=RETRIEVAL_INSTRUCT)
    model_answer = _build(system_instruction=ANSWER_INSTRUCT)
    return model, model_answer, model_retrieval
def get_model():
    """Return the shared models, building them on the first call.

    Returns:
        ``(model, model_answer, model_retrieval)``, in the same order that
        ``create_model`` returns them.
    """
    global model, model_answer, model_retrieval
    if model is not None:
        # Already initialised: hand back the cached handles.
        return model, model_answer, model_retrieval
    # First use: build the models and cache them in the module globals.
    model, model_answer, model_retrieval = create_model()
    return model, model_answer, model_retrieval
def extract_keyword_prompt(query):
    """Wrap *query* in the ``[INST] ... [/INST]`` keyword-extraction prompt.

    The returned prompt text tells the LLM to reply with JSON: either
    ``{"keywords": [...], "description": ...}`` for a science topic or
    ``{"answer": ...}`` otherwise. The reply is parsed later into arguments
    for the paper search.
    """
    instruction_lines = [
        "[INST] SYSTEM: You are an auto chatbot that response with only one action below based on user question.",
        "1. If the guest question is asking about a science topic, you need to respond the information in JSON schema below:",
        "{",
        '"keywords": [a list of string keywords about the topic],',
        '"description": "a paragraph describing the topic in about 50 to 100 words"',
        "}",
        "2. If the guest is not asking for any informations or documents, you need to respond in JSON schema below:",
        "{",
        '"answer": "your answer to the user question"',
        "}",
        "QUESTION: ",
    ]
    head = "\n".join(instruction_lines)
    return head + query + "[/INST]\nANSWER: "
def make_answer_prompt(input, contexts):
    """Build the ``[INST]`` prompt that asks the LLM for the final answer.

    Args:
        input: The user's question.
        contexts: Retrieved record text; it is embedded in single quotes.

    Returns:
        The prompt string, ending with ``"ANSWER:\\n"``.
    """
    prompt_lines = [
        "[INST] You are a library assistant that help answering customer QUESTION based on the INFORMATION given.",
        "You always answer in a conversational form naturally and politely.",
        "You must introduce all the records given, each must contain title, authors and the link to the pdf file.",
        f"QUESTION: {input}",
        f"INFORMATION: '{contexts}'",
        "[/INST]",
        "ANSWER:",
        "",  # keeps the trailing newline after "ANSWER:"
    ]
    return "\n".join(prompt_lines)
def retrieval_chat_template(question):
    """Wrap *question* as a single user turn for the retrieval model."""
    content = f"QUESTION: {question} \n ANSWER:"
    return dict(role="user", parts=[content])
def answer_chat_template(question, contexts):
    """Wrap *question* plus the retrieved *contexts* as one user turn."""
    segments = [f"QUESTION: {question}", f"INFORMATION: {contexts}", "ANSWER:"]
    return dict(role="user", parts=[" \n ".join(segments)])
def _format_paper_context(paper_info):
    """Render paper rows plus their stored summaries as prompt context.

    Returns ``(context_string, records)`` where *records* holds
    ``[title, authors, link]`` for each row.
    """
    chunks = []
    records = []
    for number, row in enumerate(paper_info, start=1):
        # Columns used: 0 = id passed to utils.db.query_exact,
        # 2 = title, 3 = authors, 6 = link.
        title, authors, link = row[2], row[3], row[6]
        chunks.append(
            "Record no.{} - Title: {}, Author: {}, Link: {}, ".format(
                number, title, authors, link
            )
        )
        documents = utils.db.query_exact(row[0])["documents"]
        chunks.append("Summary:" + "".join(doc + " " for doc in documents))
        records.append([title, authors, link])
    # A single join avoids the quadratic cost of repeated `+=`.
    return "".join(chunks), records


def response(args, db_instance):
    """Turn the retrieval model's parsed JSON into answer context.

    Args:
        args: Parsed JSON from the retrieval step. Expected to be either
            ``{"answer": ...}`` for a direct reply or
            ``{"keywords": [...], "description": ...}`` for a paper search.
        db_instance: Paper store exposing ``add(records)`` and
            ``query_id(ids)``; rows are indexed by column position.

    Returns:
        ``(context, records)``:

        * direct answer -> ``(answer, None)``
        * search with hits -> ``(context_string, [[title, authors, link], ...])``
        * crawler error -> ``("Error occured, information not found", error_string)``
        * no matching rows -> ``("Information not found", "Information not found")``
        * unrecognised payload -> ``("Information not found", None)``
    """
    if "answer" in args:
        return args["answer"], None  # direct reply, no retrieval needed
    if "keywords" not in args:
        # Neither schema matched. Returning an explicit pair (instead of an
        # implicit None) keeps the caller's tuple unpacking working.
        return "Information not found", None

    keywords = args["keywords"]
    # Fall back to the keywords themselves if the model omitted a description.
    query_texts = args.get("description") or " ".join(keywords)
    results = utils.db.query_relevant(keywords=keywords, query_texts=query_texts)
    metadatas = results["metadatas"][0]

    if not metadatas:
        # Nothing indexed for this topic yet: crawl arXiv, store, and re-query once.
        new_records = utils.crawl_arxiv(keyword_list=keywords, max_results=10)
        if isinstance(new_records, str):
            # A str return from the crawler is treated as an error message.
            return "Error occured, information not found", new_records
        print("Got new records: ", len(new_records))
        utils.db.add(new_records)
        db_instance.add(new_records)
        results = utils.db.query_relevant(keywords=keywords, query_texts=query_texts)
        metadatas = results["metadatas"][0]
        print("Re-queried on chromadb, results: ", metadatas)

    paper_ids = [meta["paper_id"] for meta in metadatas]
    paper_info = db_instance.query_id(paper_ids)
    if not paper_info:
        return "Information not found", "Information not found"
    return _format_paper_context(paper_info)
def full_chain_single_question(input_prompt, db_instance):
    """Answer one standalone question with the retrieve-then-answer chain.

    Steps:

    1. Send the keyword-extraction prompt to the base model.
    2. Parse its JSON reply (after ``utils.trimming``).
    3. Run the search through ``response``.
    4. If records were found, ask the model again with them as context.

    Returns:
        ``(first_step_output, answer)``. For direct replies the first element
        is ``"Random question, direct return"``. On any failure the second
        element is ``"Error occured: <exception>"`` and the first element is
        the model's first reply if one was produced, else ``"No answer"``.
    """
    # Bound before the try so the except clause can always read it.
    temp_answer = None
    try:
        # get_model() builds the models lazily if nobody has initialised them yet.
        base_model, _, _ = get_model()
        first_prompt = extract_keyword_prompt(input_prompt)
        temp_answer = base_model.generate_content(first_prompt).text
        args = json.loads(utils.trimming(temp_answer))
        contexts, results = response(args, db_instance)
        if not results:
            return "Random question, direct return", contexts
        output_prompt = make_answer_prompt(input_prompt, contexts)
        answer = base_model.generate_content(output_prompt).text
        return temp_answer, answer
    except Exception as e:  # chat boundary: report the failure to the caller
        first = temp_answer if temp_answer is not None else "No answer"
        return first, "Error occured: " + str(e)
def format_chat_history_from_web(chat_history: list):
    """Convert web chat messages into Gemini ``contents`` entries.

    Each ``{"role": ..., "content": ...}`` message becomes
    ``{"role": ..., "parts": [content]}``. The role ``"assistant"`` is renamed
    to ``"model"``, the role Gemini uses for its own turns. All other roles
    pass through unchanged.

    Args:
        chat_history: Messages in conversation order.

    Returns:
        A new list; *chat_history* is not modified.
    """
    # NOTE(review): confirm which role name the web client sends for bot turns;
    # the mapping is a no-op if it already sends "model".
    role_aliases = {"assistant": "model"}
    return [
        {
            "role": role_aliases.get(message["role"], message["role"]),
            "parts": [message["content"]],
        }
        for message in chat_history
    ]
def full_chain_history_question(chat_history: list, db_instance):
    """Answer the latest message of a multi-turn chat using retrieval.

    Steps:

    1. Send the whole formatted history to the retrieval model.
    2. Parse its JSON reply (after ``utils.trimming``).
    3. Run the search through ``response``.
    4. If records were found, replace the last user turn with one carrying
       the records and ask the answering model.

    Args:
        chat_history: Web chat messages (``{"role", "content"}`` dicts); the
            last one is treated as the current question.
        db_instance: Paper store forwarded to ``response``.

    Returns:
        ``(first_step_output, answer)``. For direct replies the first element
        is ``"Random question, direct return"``. On any failure the second
        element is ``"Error occured: <exception>"`` and the first element is
        the retrieval reply if it was non-empty, else ``"No answer"``.
    """
    # Bound before the try so the except clause can always read it.
    first_answer = None
    try:
        # get_model() returns (model, model_answer, model_retrieval) and builds
        # them lazily if they have not been initialised yet.
        _, answer_model, retrieval_model = get_model()
        temp_chat = format_chat_history_from_web(chat_history)
        question = temp_chat[-1]["parts"][0]
        first_answer = retrieval_model.generate_content(temp_chat).text
        print(first_answer)
        args = json.loads(utils.trimming(first_answer))
        contexts, results = response(args, db_instance)
        if not results:
            return "Random question, direct return", contexts
        print('Context to answers: ', contexts)
        # Swap the last user turn for one that also carries the retrieved records.
        temp_chat[-1] = answer_chat_template(question, contexts)
        answer = answer_model.generate_content(temp_chat).text
        return first_answer, answer
    except Exception as e:  # chat boundary: report the failure to the caller
        if first_answer:
            return first_answer, "Error occured: " + str(e)
        return "No answer", "Error occured: " + str(e)