WordsApp / chat.py
dht-tb16p
Commit 1st version
e60c070
raw
history blame
No virus
5.79 kB
import re
from typing import List
from loguru import logger
import config
from llm.call_llm import get_completion, get_completion_from_messages
from words_db import words_db
from create_db import get_similar_k_words
from prompts import trans_prompt, query_prompt, learn_prompt
from prompts import system_message_mapper
def format_common_prompt(raw_prompt, variable):
    """Substitute ``variable`` into a prompt template.

    Args:
        raw_prompt: Template string; ``variable`` is handed to its
            ``format`` method as the only positional argument.
        variable: Value to insert into the template.

    Returns:
        The formatted prompt.
    """
    filled_prompt = raw_prompt.format(variable)
    return filled_prompt
def format_chat_prompt(message, chat_history) -> str:
    """Render earlier turns plus the new message as a single prompt string.

    Args:
        message: The new user message.
        chat_history: Iterable of ``(user_message, bot_message)`` pairs,
            oldest first.

    Returns:
        One ``"\\nUser: ...\\nAssistant: ..."`` segment per past turn,
        followed by the new message and an open ``"Assistant:"`` line.
    """
    past_turns = [
        f"\nUser: {user_message}\nAssistant: {bot_message}"
        for user_message, bot_message in chat_history
    ]
    # The trailing "Assistant:" is left open for the model to complete.
    return "".join(past_turns) + f"\nUser: {message}\nAssistant:"
def respond(message, chat_history,
            llm="gpt-3.5-turbo", history_len=3, temperature=0.1, max_tokens=2048):
    """Handle one chat turn: run a command if there is one, else ask the LLM.

    The message is first offered to ``command_parser`` (explicit ``:command``
    input) and then to ``command_mapper`` (natural language mapped to a
    command). If neither produces a reply, the message is sent to the LLM
    together with the most recent history.

    Args:
        message: Text typed by the user; ``None`` or empty is a no-op.
        chat_history: List of ``(user_message, bot_message)`` tuples.
        llm: Model name forwarded to ``get_completion``.
        history_len: Number of most recent turns kept as context.
        temperature: Sampling temperature forwarded to ``get_completion``.
        max_tokens: Token limit forwarded to ``get_completion``.

    Returns:
        A ``(text, chat_history)`` tuple. ``text`` is ``""`` on success and
        the error message if the LLM chat raised.
    """
    # Check for an empty message before anything else: the command handlers
    # call ``str.startswith`` on it (which fails on ``None``), and
    # ``command_mapper`` would spend an LLM call on an empty input.
    if message is None or len(message) < 1:
        return "", chat_history

    # Explicit commands first, then natural language mapped to a command.
    for handler in (command_parser, command_mapper):
        respond_message = handler(message)
        if respond_message:
            chat_history.append((message, respond_message))
            return "", chat_history

    # No command matched, so chat with the LLM.
    try:
        # NOTE: this rebinds ``chat_history``, so the history returned to the
        # caller is trimmed to the last ``history_len`` turns as well (and is
        # reset to an empty list when ``history_len`` is not positive).
        chat_history = chat_history[-history_len:] if history_len > 0 else []
        formatted_prompt = format_chat_prompt(message, chat_history)
        bot_message = get_completion(
            formatted_prompt,
            llm,
            api_key=config.api_key,
            temperature=temperature, max_tokens=max_tokens)
        # The pattern matches a literal backslash followed by "n" (an escaped
        # newline in the reply text), not an actual newline character.
        bot_message = re.sub(r"\\n", '<br/>', bot_message)
        chat_history.append((message, bot_message))
        return "", chat_history
    except Exception as e:
        # Log and hand back text rather than the exception object itself, in
        # line with the other handlers in this module.
        logger.error(str(e))
        return str(e), chat_history
def command_parser(input: str) -> str:
    """Run a ``:command`` typed by the user and return the message to show.

    Supported commands (matched by prefix):
        1. ``:add <word1> <word2> ...``     add words via ``add_words``
        2. ``:remove <word1> <word2> ...``  remove words via ``remove_words``
        3. ``:learn <query_word>``          learning text via ``learn_words``
        4. ``:query <word>``                word information via ``query_word``
        5. ``:show``                        list stored words
        6. ``:help``                        list the command syntax

    Args:
        input: Raw text to interpret.

    Returns:
        The information message for the executed command, or ``""`` when
        ``input`` is not one of the commands above.
    """
    # Split on any whitespace so repeated or trailing spaces (or a trailing
    # newline) never yield empty or padded "words".
    args = input.split()[1:]

    if input.startswith(":add"):
        return add_words(args)

    if input.startswith(":remove"):
        return remove_words(args)

    if input.startswith(":learn"):
        if len(args) != 1:
            return "学习模式将基于词库进行,请指定一个query单词"
        query = args[0]
        info = learn_words(query)
        return f"Based on your query word: {query} and dictionary, learning sentence is:\n{info}"

    if input.startswith(":query"):
        # Exactly one word is required; a bare ":query" used to fall through
        # to an out-of-range index.
        if len(args) != 1:
            return "查询模式仅支持单个单词,请使用:query <word>进行查询"
        word = args[0]
        info = query_word(word)
        return f"{word}\n{info}"

    if input.startswith(":show"):
        return show_all_words()

    if input.startswith(":help"):
        return "目前支持的指令有:\n:add <word1> <word2> ...\n:remove <word1> <word2> ...\n:learn <query_word>\n:query <word>"

    return ""
def show_all_words() -> str:
    """Return a message listing everything ``words_db.query_word()`` yields.

    Returns:
        The listing message, or a fixed failure message when the lookup
        raises (the error is logged).
    """
    try:
        listing = f"目前词库中的所有单词:\n{words_db.query_word()}"
    except Exception as err:
        logger.error(str(err))
        return "查询失败"
    else:
        return listing
def add_words(input: List[str]):
    """Look up a definition for each word with the LLM and store the pairs.

    Args:
        input: Words to add.

    Returns:
        A success message naming the words, or a failure message if either
        the LLM lookup or the database write raised (the error is logged).
    """
    try:
        # Fetch every definition before writing anything: if an LLM call
        # fails, no ``add_word`` call has been made yet. Keeping the lookups
        # inside the ``try`` turns such a failure into the normal failure
        # message instead of an uncaught exception.
        word_tuple_list = [
            (word, get_completion(
                prompt=format_common_prompt(trans_prompt, word),
                api_key=config.api_key))
            for word in input
        ]
        for word, definition in word_tuple_list:
            words_db.add_word(word, definition)
            logger.info(f"已经添加单词: {word} 和其释义: {definition}")
    except Exception as e:
        logger.error(str(e))
        return f"添加单词失败: {input}"
    return f"已添加单词: {input}"
def remove_words(input: List[str]):
    """Delete each given word through ``words_db.delete_word``.

    Args:
        input: Words to delete.

    Returns:
        A success message naming the words, or a failure message as soon as
        one deletion raises (the error is logged).
    """
    try:
        for target in input:
            words_db.delete_word(target)
            logger.info(f"已经删除单词: {target} 和其释义")
    except Exception as err:
        logger.error(str(err))
        return f"删除单词失败: {input}"
    else:
        return f"已删除单词: {input}"
def learn_words(query_word) -> str:
    """Generate learning material from words similar to ``query_word``.

    Args:
        query_word: Word used to look up similar words.

    Returns:
        The LLM completion for ``learn_prompt`` filled with those words.
    """
    # Similar words come from ``get_similar_k_words``; the LLM then writes
    # the learning text for them.
    similar_words = get_similar_k_words(query_word)
    learn_request = format_common_prompt(learn_prompt, similar_words)
    material = get_completion(prompt=learn_request, api_key=config.api_key)
    logger.info(f"进入学习模式,学习下列单词: {similar_words}")
    return material
def query_word(input: str) -> str:
    """Ask the LLM for information about a single word.

    Args:
        input: The word to look up.

    Returns:
        The LLM completion for ``query_prompt`` filled with the word.
    """
    lookup_prompt = format_common_prompt(query_prompt, input)
    answer = get_completion(prompt=lookup_prompt, api_key=config.api_key)
    logger.info(f"查询单词: {input}")
    return answer
def command_mapper(input: str) -> str:
    """Map natural language to a command and execute it.

    The text is sent to the LLM with ``system_message_mapper`` as the system
    message, and the reply is then run through ``command_parser``.

    Args:
        input: Free-form user text.

    Returns:
        Whatever ``command_parser`` returns for the LLM reply: the command's
        information message, or ``""`` when the reply is not a command.
    """
    user_message = input
    messages = [
        {'role': 'system',
         'content': system_message_mapper},
        {'role': 'user',
         'content': f"{user_message}"},
    ]
    # Pass the key itself, as every other LLM call in this module does; the
    # ``config`` module object is not an API key.
    llm_reply = get_completion_from_messages(messages, api_key=config.api_key)
    mapped_command = command_parser(llm_reply)
    logger.info(f"用户输入: {user_message}\n指令解析器输出: {mapped_command}")
    return mapped_command