File size: 5,793 Bytes
e60c070
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import re
from typing import List
from loguru import logger

import config
from llm.call_llm import get_completion, get_completion_from_messages
from words_db import words_db
from create_db import get_similar_k_words
from prompts import trans_prompt, query_prompt, learn_prompt
from prompts import system_message_mapper


def format_common_prompt(raw_prompt, variable):
    """Fill a prompt template with a single value.

    ``variable`` is passed to ``raw_prompt.format`` as the only positional
    argument, so it replaces the template's positional placeholder(s).
    Returns the formatted prompt.
    """
    filled_prompt = raw_prompt.format(variable)
    return filled_prompt

def format_chat_prompt(message, chat_history) -> str:
    """Render the chat history plus the new message as one prompt string.

    Each ``(user_message, bot_message)`` pair in ``chat_history`` becomes a
    ``User:`` / ``Assistant:`` pair of lines; the new ``message`` is appended
    as a final ``User:`` line followed by an open ``Assistant:`` cue for the
    model to complete.
    """
    past_turns = [
        f"\nUser: {asked}\nAssistant: {answered}"
        for asked, answered in chat_history
    ]
    return "".join(past_turns) + f"\nUser: {message}\nAssistant:"


def respond(message, chat_history,
            llm="gpt-3.5-turbo", history_len=3, temperature=0.1, max_tokens=2048):
    """Handle one chat turn: run a command if the message is one, else chat.

    The message is first offered to ``command_parser`` (explicit ``:command``
    syntax) and then to ``command_mapper`` (natural language mapped to a
    command through the LLM). Only when neither yields a reply is the message
    sent to the LLM as ordinary conversation.

    Args:
        message: Text entered by the user.
        chat_history: List of ``(user_message, bot_message)`` tuples. Command
            replies are appended to this list in place. For ordinary chat the
            list is first cut down to the last ``history_len`` turns and the
            new turn is appended to that shortened copy.
        llm: Model name forwarded to ``get_completion``.
        history_len: Number of most recent turns kept as context; a value
            of 0 or less drops all history.
        temperature: Forwarded to ``get_completion``.
        max_tokens: Forwarded to ``get_completion``.

    Returns:
        A ``(text, chat_history)`` tuple. ``text`` is ``""`` on success; if
        the chat completion raises, ``text`` is the caught exception object.
    """
    # Nothing to do for an empty message. This is checked first so that None
    # never reaches command_parser (which calls str methods on it) and no LLM
    # call is spent in command_mapper on empty input.
    if not message:
        return "", chat_history

    # deal with commands
    respond_message = command_parser(message)
    if respond_message:
        chat_history.append((message, respond_message))
        return "", chat_history

    # map natural language to command
    respond_message = command_mapper(message)
    if respond_message:
        chat_history.append((message, respond_message))
        return "", chat_history

    # no commands return, so chat with LLM
    try:
        chat_history = chat_history[-history_len:] if history_len > 0 else []  # constrain history length
        formatted_prompt = format_chat_prompt(message, chat_history)  # format prompt
        bot_message = get_completion(
            formatted_prompt,
            llm,
            api_key=config.api_key,
            temperature=temperature, max_tokens=max_tokens)
        # NOTE: this pattern matches a literal backslash followed by "n"
        # (an escaped newline in the text), not a real newline character.
        bot_message = re.sub(r"\\n", '<br/>', bot_message)
        chat_history.append((message, bot_message))
        return "", chat_history
    except Exception as e:
        return e, chat_history

def command_parser(input: str) -> str:
    """Parse and execute a ``:command`` line.

    Supported commands:
    1. :add <word1> <word2> ...
    2. :remove <word1> <word2> ...
    3. :learn <query_word>
    4. :query <word>
    5. :show
    6. :help

    Args:
        input: Raw text to inspect. Arguments are separated by any run of
            whitespace.

    Returns:
        The message describing the result of the command, or ``""`` when
        ``input`` is empty or does not start with a known command.
    """
    if not input:
        return ""

    # split() without a separator collapses repeated/trailing whitespace, so
    # no empty-string "words" are produced from inputs like ":add  foo ".
    tokens = input.split()
    arguments = tokens[1:]

    if input.startswith(":add"):
        return add_words(arguments)
    if input.startswith(":remove"):
        return remove_words(arguments)
    if input.startswith(":learn"):
        if len(tokens) != 2:
            return "学习模式将基于词库进行,请指定一个query单词"
        query = tokens[1]
        info = learn_words(query)
        return f"Based on your query word: {query} and dictionary, learning sentence is:\n{info}"
    if input.startswith(":query"):
        # Exactly one word is required; a bare ":query" used to raise
        # IndexError because only the "too many" case was checked.
        if len(tokens) != 2:
            return "查询模式仅支持单个单词,请使用:query <word>进行查询"
        word = tokens[1]
        info = query_word(word)
        return f"{word}\n{info}"
    if input.startswith(":show"):
        return show_all_words()
    if input.startswith(":help"):
        return "目前支持的指令有:\n:add <word1> <word2> ...\n:remove <word1> <word2> ...\n:learn <query_word>\n:query <word>\n:show\n:help"

    return ""


def show_all_words() -> str:
    """List every word stored in the word database.

    Returns the listing message, or a fixed failure message (after logging
    the error) if querying or formatting the result raises.
    """
    try:
        listing = f"目前词库中的所有单词:\n{words_db.query_word()}"
    except Exception as err:
        logger.error(str(err))
        return "查询失败"
    return listing

def add_words(input: List[str]):
    """Look up a definition for each word via the LLM and store both in the db.

    Definitions for all words are fetched first, then the words are written
    to ``words_db`` one by one. Words written before a failure stay in the
    database.

    Args:
        input: Words to add.

    Returns:
        A success message listing ``input``, or a failure message if fetching
        a definition or writing to the database raises.
    """
    try:
        # The LLM calls are inside the try so that a failed completion is
        # logged and reported like a database failure instead of propagating
        # to the caller.
        word_tuple_list = [
            (word, get_completion(
                prompt=format_common_prompt(trans_prompt, word),
                api_key=config.api_key))
            for word in input
        ]
        for word, definition in word_tuple_list:
            words_db.add_word(word, definition)
            logger.info(f"已经添加单词: {word} 和其释义: {definition}")
    except Exception as e:
        logger.error(str(e))
        return f"添加单词失败: {input}"
    return f"已添加单词: {input}"

def remove_words(input: List[str]):
    """Delete each given word from the word database.

    Returns a success message listing ``input``, or a failure message (after
    logging the error) as soon as one deletion raises. Words deleted before
    the failure stay deleted.
    """
    try:
        for target in input:
            words_db.delete_word(target)
            logger.info(f"已经删除单词: {target} 和其释义")
    except Exception as err:
        logger.error(str(err))
        return f"删除单词失败: {input}"
    else:
        return f"已删除单词: {input}"

def learn_words(query_word) -> str:
    """Generate learning material from words similar to ``query_word``.

    The similar words come from ``get_similar_k_words``; they are inserted
    into ``learn_prompt`` and sent to the LLM. Returns the LLM's reply.
    """
    # How many words come back is decided by get_similar_k_words' defaults
    # (the original comment said "top 3" -- TODO confirm).
    similar_words = get_similar_k_words(query_word)
    learning_prompt = format_common_prompt(learn_prompt, similar_words)
    material = get_completion(prompt=learning_prompt, api_key=config.api_key)
    logger.info(f"进入学习模式,学习下列单词: {similar_words}")
    return material

def query_word(input: str) -> str:
    """Ask the LLM for information about a single word.

    ``input`` is inserted into ``query_prompt`` and the LLM's reply is
    returned as-is; the lookup is logged after the reply arrives.
    """
    lookup_prompt = format_common_prompt(query_prompt, input)
    answer = get_completion(prompt=lookup_prompt, api_key=config.api_key)
    logger.info(f"查询单词: {input}")
    return answer

def command_mapper(input: str) -> str:
    """Map a natural-language message to a command and execute it.

    The message is sent to the LLM together with ``system_message_mapper``;
    the LLM's reply is then handed to ``command_parser``.

    Args:
        input: The user's message.

    Returns:
        Whatever ``command_parser`` returns for the LLM's reply: the command's
        result message, or ``""`` when the reply is not a known command.
        Empty input returns ``""`` without calling the LLM.
    """
    if not input:
        return ""

    messages = [
        {'role': 'system', 'content': system_message_mapper},
        {'role': 'user', 'content': f"{input}"},
    ]
    # Pass the key itself; the original passed the whole ``config`` module.
    llm_reply = get_completion_from_messages(messages, api_key=config.api_key)
    mapped_command = command_parser(llm_reply)
    logger.info(f"用户输入: {input}\n指令解析器输出: {mapped_command}")

    return mapped_command