dht-tb16p
commited on
Commit
•
e60c070
1
Parent(s):
7a0d66d
Commit 1st version
Browse files- __init__.py +0 -0
- app.py +38 -0
- chat.py +171 -0
- config.py +1 -0
- create_db.py +149 -0
- database/__pycache__/create_db.cpython-311.pyc +0 -0
- database/word_database.db +0 -0
- embedding/__init__.py +1 -0
- embedding/__pycache__/__init__.cpython-310.pyc +0 -0
- embedding/__pycache__/__init__.cpython-311.pyc +0 -0
- embedding/__pycache__/__init__.cpython-39.pyc +0 -0
- embedding/__pycache__/call_embedding.cpython-310.pyc +0 -0
- embedding/__pycache__/call_embedding.cpython-311.pyc +0 -0
- embedding/__pycache__/call_embedding.cpython-39.pyc +0 -0
- embedding/__pycache__/zhipuai_embedding.cpython-310.pyc +0 -0
- embedding/__pycache__/zhipuai_embedding.cpython-311.pyc +0 -0
- embedding/__pycache__/zhipuai_embedding.cpython-39.pyc +0 -0
- embedding/call_embedding.py +20 -0
- embedding/zhipuai_embedding.py +112 -0
- llm/__pycache__/call_llm.cpython-310.pyc +0 -0
- llm/__pycache__/call_llm.cpython-311.pyc +0 -0
- llm/__pycache__/call_llm.cpython-39.pyc +0 -0
- llm/__pycache__/self_llm.cpython-310.pyc +0 -0
- llm/__pycache__/self_llm.cpython-311.pyc +0 -0
- llm/__pycache__/self_llm.cpython-39.pyc +0 -0
- llm/__pycache__/spark_llm.cpython-310.pyc +0 -0
- llm/__pycache__/spark_llm.cpython-311.pyc +0 -0
- llm/__pycache__/spark_llm.cpython-39.pyc +0 -0
- llm/__pycache__/wenxin_llm.cpython-310.pyc +0 -0
- llm/__pycache__/wenxin_llm.cpython-311.pyc +0 -0
- llm/__pycache__/wenxin_llm.cpython-39.pyc +0 -0
- llm/__pycache__/wenxin_llm_.cpython-310.pyc +0 -0
- llm/__pycache__/zhipuai_llm.cpython-310.pyc +0 -0
- llm/__pycache__/zhipuai_llm.cpython-311.pyc +0 -0
- llm/__pycache__/zhipuai_llm.cpython-39.pyc +0 -0
- llm/call_llm.py +330 -0
- llm/self_llm.py +47 -0
- llm/spark_llm.py +227 -0
- llm/test.ipynb +312 -0
- llm/wenxin_llm.py +90 -0
- llm/zhipuai_llm.py +217 -0
- prompts.py +84 -0
- requirements.txt +157 -0
- test/__pycache__/test_create_db.cpython-311-pytest-8.2.0.pyc +0 -0
- test/test_create_db.py +44 -0
- words_db.py +67 -0
__init__.py
ADDED
File without changes
|
app.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import openai
|
3 |
+
|
4 |
+
from create_db import create_db
|
5 |
+
from chat import respond
|
6 |
+
|
7 |
+
openai.api_base = "https://api.v36.cm/v1" # global setting is needed
|
8 |
+
|
9 |
+
def launch_app():
|
10 |
+
|
11 |
+
with gr.Blocks() as demo:
|
12 |
+
with gr.Row(equal_height=True):
|
13 |
+
gr.Markdown("## 英语单词学习工具")
|
14 |
+
|
15 |
+
with gr.Row():
|
16 |
+
with gr.Column(scale=4):
|
17 |
+
chatbot = gr.Chatbot(height=400)
|
18 |
+
msg = gr.Textbox(label="在此输入指令(以:起始)或对话")
|
19 |
+
btn = gr.Button("Submit")
|
20 |
+
gr.ClearButton(components=[msg, chatbot], value="清除对话")
|
21 |
+
btn.click(respond, inputs=[msg, chatbot], outputs=[msg, chatbot])
|
22 |
+
msg.submit(respond, inputs=[msg, chatbot], outputs=[msg, chatbot])
|
23 |
+
with gr.Column(scale=1):
|
24 |
+
file = gr.File(label='请导入制作词库的文件', file_count='single', file_types=['.md', '.pdf'])
|
25 |
+
with gr.Row():
|
26 |
+
init_vocab_by_file = gr.Button("使用文件生成个人词库")
|
27 |
+
text = gr.Textbox(label="在此粘贴制作词库的文本", lines=8)
|
28 |
+
with gr.Row():
|
29 |
+
init_vocab_by_text = gr.Button("使用文本生成个人词库")
|
30 |
+
init_vocab_by_file.click(create_db, inputs=[file, chatbot], outputs=[chatbot])
|
31 |
+
init_vocab_by_text.click(create_db, inputs=[text, chatbot], outputs=[chatbot])
|
32 |
+
|
33 |
+
gr.close_all()
|
34 |
+
demo.launch()
|
35 |
+
|
36 |
+
if __name__ == "__main__":
|
37 |
+
launch_app()
|
38 |
+
|
chat.py
ADDED
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
from typing import List
|
3 |
+
from loguru import logger
|
4 |
+
|
5 |
+
import config
|
6 |
+
from llm.call_llm import get_completion, get_completion_from_messages
|
7 |
+
from words_db import words_db
|
8 |
+
from create_db import get_similar_k_words
|
9 |
+
from prompts import trans_prompt, query_prompt, learn_prompt
|
10 |
+
from prompts import system_message_mapper
|
11 |
+
|
12 |
+
|
13 |
+
def format_common_prompt(raw_prompt, variable):
|
14 |
+
"""get format prompt by repalce variable in raw_prompt
|
15 |
+
"""
|
16 |
+
return raw_prompt.format(variable)
|
17 |
+
|
18 |
+
def format_chat_prompt(message, chat_history) -> str:
|
19 |
+
"""get format prompt
|
20 |
+
"""
|
21 |
+
prompt = ""
|
22 |
+
for turn in chat_history: # add history info
|
23 |
+
user_message, bot_message = turn
|
24 |
+
prompt = f"{prompt}\nUser: {user_message}\nAssistant: {bot_message}"
|
25 |
+
prompt = f"{prompt}\nUser: {message}\nAssistant:"
|
26 |
+
return prompt
|
27 |
+
|
28 |
+
|
29 |
+
def respond(message, chat_history,
|
30 |
+
llm="gpt-3.5-turbo", history_len=3, temperature=0.1, max_tokens=2048):
|
31 |
+
"""get respond from LLM
|
32 |
+
"""
|
33 |
+
|
34 |
+
# deal with commands
|
35 |
+
respond_message = command_parser(message)
|
36 |
+
if respond_message:
|
37 |
+
chat_history.append((message, respond_message))
|
38 |
+
respond_message = ""
|
39 |
+
return respond_message, chat_history
|
40 |
+
|
41 |
+
# map natural language to command
|
42 |
+
respond_message = command_mapper(message)
|
43 |
+
if respond_message:
|
44 |
+
chat_history.append((message, respond_message))
|
45 |
+
respond_message = ""
|
46 |
+
return respond_message, chat_history
|
47 |
+
|
48 |
+
# no commands return, so chat with LLM
|
49 |
+
if message is None or len(message) < 1:
|
50 |
+
return "", chat_history
|
51 |
+
try:
|
52 |
+
chat_history = chat_history[-history_len:] if history_len > 0 else [] # constrain history length
|
53 |
+
formatted_prompt = format_chat_prompt(message, chat_history) # format prompt
|
54 |
+
bot_message = get_completion(
|
55 |
+
formatted_prompt,
|
56 |
+
llm,
|
57 |
+
api_key=config.api_key,
|
58 |
+
temperature=temperature, max_tokens=max_tokens)
|
59 |
+
bot_message = re.sub(r"\\n", '<br/>', bot_message) # replace \n with <br/>
|
60 |
+
chat_history.append((message, bot_message))
|
61 |
+
return "", chat_history
|
62 |
+
except Exception as e:
|
63 |
+
return e, chat_history
|
64 |
+
|
65 |
+
def command_parser(input: str) -> str:
|
66 |
+
"""parse 4 type commands
|
67 |
+
1. :add
|
68 |
+
2. :remove
|
69 |
+
3. :learn
|
70 |
+
4. :query
|
71 |
+
return info of action to user
|
72 |
+
"""
|
73 |
+
if input.startswith(":add"):
|
74 |
+
words = input.split(" ")[1:]
|
75 |
+
info = add_words(words)
|
76 |
+
return info
|
77 |
+
if input.startswith(":remove"):
|
78 |
+
words = input.split(" ")[1:]
|
79 |
+
info = remove_words(words)
|
80 |
+
return info
|
81 |
+
if input.startswith(":learn"):
|
82 |
+
if len(input.split(" ")) != 2:
|
83 |
+
return "学习模式将基于词库进行,请指定一个query单词"
|
84 |
+
query = input.split(" ")[1]
|
85 |
+
info = learn_words(query)
|
86 |
+
return f"Based on your query word: {query} and dictionary, learning sentence is:\n{info}"
|
87 |
+
if input.startswith(":query"):
|
88 |
+
if len(input.split(" ")) > 2:
|
89 |
+
return "查询模式仅支持单个单词,请使用:query <word>进行查询"
|
90 |
+
word = input.split(" ")[1]
|
91 |
+
info = query_word(word)
|
92 |
+
return f"{word}\n{info}"
|
93 |
+
if input.startswith(":show"):
|
94 |
+
info = show_all_words()
|
95 |
+
return info
|
96 |
+
if input.startswith(":help"):
|
97 |
+
return "目前支持的指令有:\n:add <word1> <word2> ...\n:remove <word1> <word2> ...\n:learn <query_word>\n:query <word>"
|
98 |
+
|
99 |
+
return ""
|
100 |
+
|
101 |
+
|
102 |
+
def show_all_words() -> str:
|
103 |
+
"""show all words in db
|
104 |
+
"""
|
105 |
+
try:
|
106 |
+
all_words = words_db.query_word()
|
107 |
+
return f"目前词库中的所有单词:\n{all_words}"
|
108 |
+
except Exception as e:
|
109 |
+
logger.error(str(e))
|
110 |
+
return "查询失败"
|
111 |
+
|
112 |
+
def add_words(input: List[str]):
|
113 |
+
word_tuple_list = [
|
114 |
+
(word, get_completion(
|
115 |
+
prompt=format_common_prompt(trans_prompt, word),
|
116 |
+
api_key=config.api_key))
|
117 |
+
for word in input
|
118 |
+
]
|
119 |
+
try:
|
120 |
+
for word_tuple in word_tuple_list:
|
121 |
+
word, definition = word_tuple
|
122 |
+
words_db.add_word(word, definition)
|
123 |
+
logger.info(f"已经添加单词: {word} 和其释义: {definition}")
|
124 |
+
except Exception as e:
|
125 |
+
logger.error(str(e))
|
126 |
+
return f"添加单词失败: {input}"
|
127 |
+
return f"已添加单词: {input}"
|
128 |
+
|
129 |
+
def remove_words(input: List[str]):
|
130 |
+
try:
|
131 |
+
for word in input:
|
132 |
+
words_db.delete_word(word)
|
133 |
+
logger.info(f"已经删除单词: {word} 和其释义")
|
134 |
+
except Exception as e:
|
135 |
+
logger.error(str(e))
|
136 |
+
return f"删除单词失败: {input}"
|
137 |
+
return f"已删除单词: {input}"
|
138 |
+
|
139 |
+
def learn_words(query_word) -> str:
|
140 |
+
# get top 3 words from vec db and generate material
|
141 |
+
words = get_similar_k_words(query_word)
|
142 |
+
respond = get_completion(
|
143 |
+
prompt=format_common_prompt(learn_prompt, words),
|
144 |
+
api_key=config.api_key)
|
145 |
+
logger.info(f"进入学习模式,学习下列单词: {words}")
|
146 |
+
return respond
|
147 |
+
|
148 |
+
def query_word(input: str) -> str:
|
149 |
+
# just query infomation about a word from gpt
|
150 |
+
respond = get_completion(
|
151 |
+
prompt=format_common_prompt(query_prompt, input),
|
152 |
+
api_key=config.api_key)
|
153 |
+
logger.info(f"查询单词: {input}")
|
154 |
+
return respond
|
155 |
+
|
156 |
+
def command_mapper(input: str) -> str:
|
157 |
+
"""map natural language to command, return command function
|
158 |
+
"""
|
159 |
+
user_message = input
|
160 |
+
messages = [
|
161 |
+
{'role':'system',
|
162 |
+
'content': system_message_mapper},
|
163 |
+
{'role':'user',
|
164 |
+
'content': f"{user_message}"},
|
165 |
+
]
|
166 |
+
respond = get_completion_from_messages(messages, api_key=config)
|
167 |
+
mapped_command = command_parser(respond)
|
168 |
+
logger.info(f"用户输入: {user_message}\n指令解析器输出: {mapped_command}")
|
169 |
+
|
170 |
+
return mapped_command
|
171 |
+
|
config.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
api_key = "sk-wmeFmoPXQ9wlJKiP48B0F028E6534359A6980b9585Ba5bAc"
|
create_db.py
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
|
4 |
+
import tempfile
|
5 |
+
import config
|
6 |
+
import nltk
|
7 |
+
|
8 |
+
from typing import List
|
9 |
+
from nltk.corpus import words
|
10 |
+
from loguru import logger
|
11 |
+
from llm.call_llm import get_completion_from_messages
|
12 |
+
from embedding.call_embedding import get_embedding
|
13 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
14 |
+
from langchain.document_loaders import PyMuPDFLoader
|
15 |
+
from langchain.vectorstores import Chroma
|
16 |
+
|
17 |
+
from prompts import system_message_select
|
18 |
+
|
19 |
+
WORDS_DB_PATH = "../words_db"
|
20 |
+
VECTOR_DB_PATH = "./vector_db/chroma"
|
21 |
+
|
22 |
+
def parse_file(file_path):
|
23 |
+
docs = []
|
24 |
+
# check file type
|
25 |
+
file_type = file_path.split('.')[-1]
|
26 |
+
if file_type == 'pdf':
|
27 |
+
loader = PyMuPDFLoader(file_path)
|
28 |
+
content = loader.load()
|
29 |
+
docs.extend(content)
|
30 |
+
else:
|
31 |
+
return "File type not supported"
|
32 |
+
if len(docs) > 5:
|
33 |
+
return "File too large, please select a pdf file with less than 5 pages"
|
34 |
+
|
35 |
+
slices = split_text(docs) # split content into slices
|
36 |
+
words = extract_words(slices) # extract words from slices
|
37 |
+
try:
|
38 |
+
vectorize_words(words) # store words into vector database
|
39 |
+
except Exception as e:
|
40 |
+
logger.error(e)
|
41 |
+
|
42 |
+
return ""
|
43 |
+
|
44 |
+
def parse_text(input: str):
|
45 |
+
content = input
|
46 |
+
return content
|
47 |
+
|
48 |
+
def split_text(docs: List[object]):
|
49 |
+
"""Split text into slices"""
|
50 |
+
text_splitter = RecursiveCharacterTextSplitter(
|
51 |
+
chunk_size = 1500,
|
52 |
+
chunk_overlap = 150
|
53 |
+
)
|
54 |
+
splits = text_splitter.split_documents(docs)
|
55 |
+
logger.info(f"Split {len(docs)} pages document into {len(splits)} slices")
|
56 |
+
return splits
|
57 |
+
|
58 |
+
def extract_words(splits: List[object]):
|
59 |
+
"""Extract words from slices"""
|
60 |
+
all_words = []
|
61 |
+
for slice in splits:
|
62 |
+
tmp_content = slice.page_content
|
63 |
+
messages = [
|
64 |
+
{'role':'system',
|
65 |
+
'content': system_message_select},
|
66 |
+
{'role':'user',
|
67 |
+
'content': f"{tmp_content}"},
|
68 |
+
]
|
69 |
+
respond = get_completion_from_messages(messages, api_key=config.api_key)
|
70 |
+
words_list = respond.split(", ")
|
71 |
+
if len(words_list) == 0:
|
72 |
+
continue
|
73 |
+
else:
|
74 |
+
all_words.extend(words_list)
|
75 |
+
all_words = wash_words(all_words)
|
76 |
+
logger.info(f"Extract {len(all_words)} words from slices")
|
77 |
+
return all_words
|
78 |
+
|
79 |
+
def wash_words(input_words: list[str]):
|
80 |
+
"""Wash words into a list of correct english words"""
|
81 |
+
words_list = [word for word in input_words
|
82 |
+
if len(word) >= 3 and len(word) <= 30]
|
83 |
+
nltk.download('words')
|
84 |
+
english_words = set(words.words())
|
85 |
+
filtered_words = [word.lower() for word in words_list if word.lower() in english_words]
|
86 |
+
filtered_words = list(set(filtered_words))
|
87 |
+
logger.info(f"Wash {len(filtered_words)} words into a list of correct english words")
|
88 |
+
return filtered_words
|
89 |
+
|
90 |
+
def get_words_from_text(input: str):
|
91 |
+
words = input.split(' ')
|
92 |
+
return words
|
93 |
+
|
94 |
+
def store_words(input: str, db_path=WORDS_DB_PATH):
|
95 |
+
"""Store words into database"""
|
96 |
+
pass
|
97 |
+
|
98 |
+
def vectorize_words(input: list[str], embedding=None):
|
99 |
+
"""Vectorize words into vectors"""
|
100 |
+
model = get_embedding("openai", embedding_key=config.api_key)
|
101 |
+
persist_path = VECTOR_DB_PATH
|
102 |
+
vectordb = Chroma.from_texts(
|
103 |
+
texts=input,
|
104 |
+
embedding=model,
|
105 |
+
persist_directory=persist_path
|
106 |
+
)
|
107 |
+
vectordb.persist()
|
108 |
+
logger.info(f"Vectorized {len(input)} words into vectors")
|
109 |
+
return vectordb
|
110 |
+
|
111 |
+
def get_similar_k_words(query_word, k=3) -> List[str]:
|
112 |
+
# get 3 simlilar words from DB
|
113 |
+
model = get_embedding("openai", embedding_key=config.api_key)
|
114 |
+
vectordb = Chroma(persist_directory=VECTOR_DB_PATH, embedding_function=model)
|
115 |
+
similar_words = vectordb.max_marginal_relevance_search(query_word, k=k)
|
116 |
+
similar_words = [word.page_content for word in similar_words]
|
117 |
+
logger.info(f"Get {k} similar words {similar_words} from DB")
|
118 |
+
return similar_words
|
119 |
+
|
120 |
+
def create_db(input, chat_history):
|
121 |
+
"""The input is file or text"""
|
122 |
+
action_msg = "" # the description of user action: put file or text into database
|
123 |
+
# 1. for file upload
|
124 |
+
if isinstance(input, tempfile._TemporaryFileWrapper):
|
125 |
+
tmp_file_path = input.name
|
126 |
+
file_name = tmp_file_path.split('/')[-1]
|
127 |
+
action_msg = f"Add words from my file: {file_name} to database"
|
128 |
+
try:
|
129 |
+
parse_file(tmp_file_path) #TODO
|
130 |
+
output = f"Words from your file: {file_name} has been added to database"
|
131 |
+
except Exception as e:
|
132 |
+
logger.error(e)
|
133 |
+
output = f"Error: failed to use your file: {file_name} generate dictionary"
|
134 |
+
# 2. for text input
|
135 |
+
elif isinstance(input, str):
|
136 |
+
action_msg = f"Add words from my text: {input} to database"
|
137 |
+
try:
|
138 |
+
parse_text(input) #TODO
|
139 |
+
output = f"Words from your text: {input} has been added to database"
|
140 |
+
except Exception as e:
|
141 |
+
logger.error(e)
|
142 |
+
output = f"Error: failed to use your text: {input} generate dictionary"
|
143 |
+
chat_history.append((action_msg, output))
|
144 |
+
|
145 |
+
return chat_history
|
146 |
+
|
147 |
+
|
148 |
+
if __name__ == "__main__":
|
149 |
+
create_db(embeddings="m3e")
|
database/__pycache__/create_db.cpython-311.pyc
ADDED
Binary file (3.33 kB). View file
|
|
database/word_database.db
ADDED
Binary file (8.19 kB). View file
|
|
embedding/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
# from .zhipuai_embedding import ZhipuAIEmbeddings
|
embedding/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (163 Bytes). View file
|
|
embedding/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (158 Bytes). View file
|
|
embedding/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (152 Bytes). View file
|
|
embedding/__pycache__/call_embedding.cpython-310.pyc
ADDED
Binary file (956 Bytes). View file
|
|
embedding/__pycache__/call_embedding.cpython-311.pyc
ADDED
Binary file (1.48 kB). View file
|
|
embedding/__pycache__/call_embedding.cpython-39.pyc
ADDED
Binary file (937 Bytes). View file
|
|
embedding/__pycache__/zhipuai_embedding.cpython-310.pyc
ADDED
Binary file (4.23 kB). View file
|
|
embedding/__pycache__/zhipuai_embedding.cpython-311.pyc
ADDED
Binary file (5.45 kB). View file
|
|
embedding/__pycache__/zhipuai_embedding.cpython-39.pyc
ADDED
Binary file (4.2 kB). View file
|
|
embedding/call_embedding.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
|
4 |
+
|
5 |
+
from embedding.zhipuai_embedding import ZhipuAIEmbeddings
|
6 |
+
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
7 |
+
from langchain.embeddings.openai import OpenAIEmbeddings
|
8 |
+
from llm.call_llm import parse_llm_api_key
|
9 |
+
|
10 |
+
def get_embedding(embedding: str, embedding_key: str=None, env_file: str=None):
|
11 |
+
if embedding == 'm3e':
|
12 |
+
return HuggingFaceEmbeddings(model_name="moka-ai/m3e-base")
|
13 |
+
if embedding_key is None:
|
14 |
+
embedding_key = parse_llm_api_key(embedding)
|
15 |
+
if embedding == "openai":
|
16 |
+
return OpenAIEmbeddings(openai_api_key=embedding_key)
|
17 |
+
elif embedding == "zhipuai":
|
18 |
+
return ZhipuAIEmbeddings(zhipuai_api_key=embedding_key)
|
19 |
+
else:
|
20 |
+
raise ValueError(f"embedding {embedding} not support ")
|
embedding/zhipuai_embedding.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import logging
|
4 |
+
from typing import Any, Dict, List, Optional
|
5 |
+
|
6 |
+
from langchain.embeddings.base import Embeddings
|
7 |
+
from langchain.pydantic_v1 import BaseModel, root_validator
|
8 |
+
from langchain.utils import get_from_dict_or_env
|
9 |
+
|
10 |
+
logger = logging.getLogger(__name__)
|
11 |
+
|
12 |
+
|
13 |
+
class ZhipuAIEmbeddings(BaseModel, Embeddings):
|
14 |
+
"""`Zhipuai Embeddings` embedding models."""
|
15 |
+
|
16 |
+
zhipuai_api_key: Optional[str] = None
|
17 |
+
"""Zhipuai application apikey"""
|
18 |
+
|
19 |
+
@root_validator()
|
20 |
+
def validate_environment(cls, values: Dict) -> Dict:
|
21 |
+
"""
|
22 |
+
Validate whether zhipuai_api_key in the environment variables or
|
23 |
+
configuration file are available or not.
|
24 |
+
|
25 |
+
Args:
|
26 |
+
|
27 |
+
values: a dictionary containing configuration information, must include the
|
28 |
+
fields of zhipuai_api_key
|
29 |
+
Returns:
|
30 |
+
|
31 |
+
a dictionary containing configuration information. If zhipuai_api_key
|
32 |
+
are not provided in the environment variables or configuration
|
33 |
+
file, the original values will be returned; otherwise, values containing
|
34 |
+
zhipuai_api_key will be returned.
|
35 |
+
Raises:
|
36 |
+
|
37 |
+
ValueError: zhipuai package not found, please install it with `pip install
|
38 |
+
zhipuai`
|
39 |
+
"""
|
40 |
+
values["zhipuai_api_key"] = get_from_dict_or_env(
|
41 |
+
values,
|
42 |
+
"zhipuai_api_key",
|
43 |
+
"ZHIPUAI_API_KEY",
|
44 |
+
)
|
45 |
+
|
46 |
+
try:
|
47 |
+
import zhipuai
|
48 |
+
zhipuai.api_key = values["zhipuai_api_key"]
|
49 |
+
values["client"] = zhipuai.model_api
|
50 |
+
|
51 |
+
except ImportError:
|
52 |
+
raise ValueError(
|
53 |
+
"Zhipuai package not found, please install it with "
|
54 |
+
"`pip install zhipuai`"
|
55 |
+
)
|
56 |
+
return values
|
57 |
+
|
58 |
+
def _embed(self, texts: str) -> List[float]:
|
59 |
+
# send request
|
60 |
+
try:
|
61 |
+
resp = self.client.invoke(
|
62 |
+
model="text_embedding",
|
63 |
+
prompt=texts
|
64 |
+
)
|
65 |
+
except Exception as e:
|
66 |
+
raise ValueError(f"Error raised by inference endpoint: {e}")
|
67 |
+
|
68 |
+
if resp["code"] != 200:
|
69 |
+
raise ValueError(
|
70 |
+
"Error raised by inference API HTTP code: %s, %s"
|
71 |
+
% (resp["code"], resp["msg"])
|
72 |
+
)
|
73 |
+
embeddings = resp["data"]["embedding"]
|
74 |
+
return embeddings
|
75 |
+
|
76 |
+
def embed_query(self, text: str) -> List[float]:
|
77 |
+
"""
|
78 |
+
Embedding a text.
|
79 |
+
|
80 |
+
Args:
|
81 |
+
|
82 |
+
Text (str): A text to be embedded.
|
83 |
+
|
84 |
+
Return:
|
85 |
+
|
86 |
+
List [float]: An embedding list of input text, which is a list of floating-point values.
|
87 |
+
"""
|
88 |
+
resp = self.embed_documents([text])
|
89 |
+
return resp[0]
|
90 |
+
|
91 |
+
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
92 |
+
"""
|
93 |
+
Embeds a list of text documents.
|
94 |
+
|
95 |
+
Args:
|
96 |
+
texts (List[str]): A list of text documents to embed.
|
97 |
+
|
98 |
+
Returns:
|
99 |
+
List[List[float]]: A list of embeddings for each document in the input list.
|
100 |
+
Each embedding is represented as a list of float values.
|
101 |
+
"""
|
102 |
+
return [self._embed(text) for text in texts]
|
103 |
+
|
104 |
+
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
|
105 |
+
"""Asynchronous Embed search docs."""
|
106 |
+
raise NotImplementedError(
|
107 |
+
"Please use `embed_documents`. Official does not support asynchronous requests")
|
108 |
+
|
109 |
+
async def aembed_query(self, text: str) -> List[float]:
|
110 |
+
"""Asynchronous Embed query text."""
|
111 |
+
raise NotImplementedError(
|
112 |
+
"Please use `aembed_query`. Official does not support asynchronous requests")
|
llm/__pycache__/call_llm.cpython-310.pyc
ADDED
Binary file (8.17 kB). View file
|
|
llm/__pycache__/call_llm.cpython-311.pyc
ADDED
Binary file (14.1 kB). View file
|
|
llm/__pycache__/call_llm.cpython-39.pyc
ADDED
Binary file (8.17 kB). View file
|
|
llm/__pycache__/self_llm.cpython-310.pyc
ADDED
Binary file (1.57 kB). View file
|
|
llm/__pycache__/self_llm.cpython-311.pyc
ADDED
Binary file (2.11 kB). View file
|
|
llm/__pycache__/self_llm.cpython-39.pyc
ADDED
Binary file (1.56 kB). View file
|
|
llm/__pycache__/spark_llm.cpython-310.pyc
ADDED
Binary file (6.17 kB). View file
|
|
llm/__pycache__/spark_llm.cpython-311.pyc
ADDED
Binary file (10.5 kB). View file
|
|
llm/__pycache__/spark_llm.cpython-39.pyc
ADDED
Binary file (6.12 kB). View file
|
|
llm/__pycache__/wenxin_llm.cpython-310.pyc
ADDED
Binary file (2.93 kB). View file
|
|
llm/__pycache__/wenxin_llm.cpython-311.pyc
ADDED
Binary file (4.4 kB). View file
|
|
llm/__pycache__/wenxin_llm.cpython-39.pyc
ADDED
Binary file (2.99 kB). View file
|
|
llm/__pycache__/wenxin_llm_.cpython-310.pyc
ADDED
Binary file (3.76 kB). View file
|
|
llm/__pycache__/zhipuai_llm.cpython-310.pyc
ADDED
Binary file (5.86 kB). View file
|
|
llm/__pycache__/zhipuai_llm.cpython-311.pyc
ADDED
Binary file (8.4 kB). View file
|
|
llm/__pycache__/zhipuai_llm.cpython-39.pyc
ADDED
Binary file (5.6 kB). View file
|
|
llm/call_llm.py
ADDED
@@ -0,0 +1,330 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- encoding: utf-8 -*-
|
3 |
+
'''
|
4 |
+
@File : call_llm.py
|
5 |
+
@Time : 2023/10/18 10:45:00
|
6 |
+
@Author : Logan Zou
|
7 |
+
@Version : 1.0
|
8 |
+
@Contact : loganzou0421@163.com
|
9 |
+
@License : (C)Copyright 2017-2018, Liugroup-NLPR-CASIA
|
10 |
+
@Desc : 将各个大模型的原生接口封装在一个接口
|
11 |
+
'''
|
12 |
+
|
13 |
+
import openai
|
14 |
+
import json
|
15 |
+
import requests
|
16 |
+
import _thread as thread
|
17 |
+
import base64
|
18 |
+
# import datetime
|
19 |
+
from dotenv import load_dotenv, find_dotenv
|
20 |
+
import hashlib
|
21 |
+
import hmac
|
22 |
+
import os
|
23 |
+
import queue
|
24 |
+
from urllib.parse import urlparse
|
25 |
+
import ssl
|
26 |
+
from datetime import datetime
|
27 |
+
from time import mktime
|
28 |
+
from urllib.parse import urlencode
|
29 |
+
from wsgiref.handlers import format_date_time
|
30 |
+
import zhipuai
|
31 |
+
from langchain.utils import get_from_dict_or_env
|
32 |
+
|
33 |
+
import websocket # 使用websocket_client
|
34 |
+
|
35 |
+
def get_completion(prompt :str, model="gpt-3.5-turbo", temperature=0.1,api_key=None, secret_key=None, access_token=None, appid=None, api_secret=None, max_tokens=2048):
|
36 |
+
# 调用大模型获取回复,支持上述三种模型+gpt
|
37 |
+
# arguments:
|
38 |
+
# prompt: 输入提示
|
39 |
+
# model:模型名
|
40 |
+
# temperature: 温度系数
|
41 |
+
# api_key:如名
|
42 |
+
# secret_key, access_token:调用文心系列模型需要
|
43 |
+
# appid, api_secret: 调用星火系列模型需要
|
44 |
+
# max_tokens : 返回最长序列
|
45 |
+
# return: 模型返回,字符串
|
46 |
+
# 调用 GPT
|
47 |
+
if model in ["gpt-3.5-turbo", "gpt-3.5-turbo-16k-0613", "gpt-3.5-turbo-0613", "gpt-4", "gpt-4-32k"]:
|
48 |
+
return get_completion_gpt(prompt, model, temperature, api_key, max_tokens)
|
49 |
+
elif model in ["ERNIE-Bot", "ERNIE-Bot-4", "ERNIE-Bot-turbo"]:
|
50 |
+
return get_completion_wenxin(prompt, model, temperature, api_key, secret_key)
|
51 |
+
elif model in ["Spark-1.5", "Spark-2.0"]:
|
52 |
+
return get_completion_spark(prompt, model, temperature, api_key, appid, api_secret, max_tokens)
|
53 |
+
elif model in ["chatglm_pro", "chatglm_std", "chatglm_lite"]:
|
54 |
+
return get_completion_glm(prompt, model, temperature, api_key, max_tokens)
|
55 |
+
else:
|
56 |
+
return "不正确的模型"
|
57 |
+
|
58 |
+
def get_completion_gpt(prompt : str, model : str, temperature : float, api_key:str, max_tokens:int):
|
59 |
+
# 封装 OpenAI 原生接口
|
60 |
+
if api_key is None:
|
61 |
+
api_key = parse_llm_api_key("openai")
|
62 |
+
openai.api_key = api_key
|
63 |
+
# 具体调用
|
64 |
+
messages = [{"role": "user", "content": prompt}]
|
65 |
+
response = openai.ChatCompletion.create(
|
66 |
+
model=model,
|
67 |
+
messages=messages,
|
68 |
+
temperature=temperature, # 模型输出的温度系数,控制输出的随机程度
|
69 |
+
max_tokens = max_tokens, # 回复最大长度
|
70 |
+
)
|
71 |
+
# 调用 OpenAI 的 ChatCompletion 接口
|
72 |
+
return response.choices[0].message["content"]
|
73 |
+
|
74 |
+
def get_completion_from_messages(messages, api_key, temperature=0):
|
75 |
+
# 封装 OpenAI 原生接口
|
76 |
+
if api_key is None:
|
77 |
+
api_key = parse_llm_api_key("openai")
|
78 |
+
openai.api_key = api_key
|
79 |
+
response = openai.ChatCompletion.create(
|
80 |
+
model="gpt-3.5-turbo",
|
81 |
+
messages=messages,
|
82 |
+
temperature=temperature, # 控制模型输出的随机程度
|
83 |
+
)
|
84 |
+
return response.choices[0].message["content"]
|
85 |
+
|
86 |
+
def get_access_token(api_key, secret_key):
|
87 |
+
"""
|
88 |
+
使用 API Key,Secret Key 获取access_token,替换下列示例中的应用API Key、应用Secret Key
|
89 |
+
"""
|
90 |
+
# 指定网址
|
91 |
+
url = f"https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id={api_key}&client_secret={secret_key}"
|
92 |
+
# 设置 POST 访问
|
93 |
+
payload = json.dumps("")
|
94 |
+
headers = {
|
95 |
+
'Content-Type': 'application/json',
|
96 |
+
'Accept': 'application/json'
|
97 |
+
}
|
98 |
+
# 通过 POST 访问获取账户对应的 access_token
|
99 |
+
response = requests.request("POST", url, headers=headers, data=payload)
|
100 |
+
return response.json().get("access_token")
|
101 |
+
|
102 |
+
def get_completion_wenxin(prompt : str, model : str, temperature : float, api_key:str, secret_key : str):
|
103 |
+
# 封装百度文心原生接口
|
104 |
+
if api_key is None or secret_key is None:
|
105 |
+
api_key, secret_key = parse_llm_api_key("wenxin")
|
106 |
+
# 获取access_token
|
107 |
+
access_token = get_access_token(api_key, secret_key)
|
108 |
+
# 调用接口
|
109 |
+
url = f"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant?access_token={access_token}"
|
110 |
+
# 配置 POST 参数
|
111 |
+
payload = json.dumps({
|
112 |
+
"messages": [
|
113 |
+
{
|
114 |
+
"role": "user",# user prompt
|
115 |
+
"content": "{}".format(prompt)# 输入的 prompt
|
116 |
+
}
|
117 |
+
]
|
118 |
+
})
|
119 |
+
headers = {
|
120 |
+
'Content-Type': 'application/json'
|
121 |
+
}
|
122 |
+
# 发起请求
|
123 |
+
response = requests.request("POST", url, headers=headers, data=payload)
|
124 |
+
# 返回的是一个 Json 字符串
|
125 |
+
js = json.loads(response.text)
|
126 |
+
return js["result"]
|
127 |
+
|
128 |
+
def get_completion_spark(prompt : str, model : str, temperature : float, api_key:str, appid : str, api_secret : str, max_tokens : int):
|
129 |
+
if api_key is None or appid is None and api_secret is None:
|
130 |
+
api_key, appid, api_secret = parse_llm_api_key("spark")
|
131 |
+
|
132 |
+
# 配置 1.5 和 2 的不同环境
|
133 |
+
if model == "Spark-1.5":
|
134 |
+
domain = "general"
|
135 |
+
Spark_url = "ws://spark-api.xf-yun.com/v1.1/chat" # v1.5环境的地址
|
136 |
+
else:
|
137 |
+
domain = "generalv2" # v2.0版本
|
138 |
+
Spark_url = "ws://spark-api.xf-yun.com/v2.1/chat" # v2.0环境的地址
|
139 |
+
|
140 |
+
question = [{"role":"user", "content":prompt}]
|
141 |
+
response = spark_main(appid,api_key,api_secret,Spark_url,domain,question,temperature,max_tokens)
|
142 |
+
return response
|
143 |
+
|
144 |
+
def get_completion_glm(prompt : str, model : str, temperature : float, api_key:str, max_tokens : int):
|
145 |
+
# 获取GLM回答
|
146 |
+
if api_key is None:
|
147 |
+
api_key = parse_llm_api_key("zhipuai")
|
148 |
+
zhipuai.api_key = api_key
|
149 |
+
|
150 |
+
response = zhipuai.model_api.invoke(
|
151 |
+
model=model,
|
152 |
+
prompt=[{"role":"user", "content":prompt}],
|
153 |
+
temperature = temperature,
|
154 |
+
max_tokens=max_tokens
|
155 |
+
)
|
156 |
+
return response["data"]["choices"][0]["content"].strip('"').strip(" ")
|
157 |
+
|
158 |
+
# def getText(role, content, text = []):
|
159 |
+
# # role 是指定角色,content 是 prompt 内容
|
160 |
+
# jsoncon = {}
|
161 |
+
# jsoncon["role"] = role
|
162 |
+
# jsoncon["content"] = content
|
163 |
+
# text.append(jsoncon)
|
164 |
+
# return text
|
165 |
+
|
166 |
+
# 星火 API 调用使用
|
167 |
+
answer = ""
|
168 |
+
|
169 |
+
class Ws_Param(object):
|
170 |
+
# 初始化
|
171 |
+
def __init__(self, APPID, APIKey, APISecret, Spark_url):
|
172 |
+
self.APPID = APPID
|
173 |
+
self.APIKey = APIKey
|
174 |
+
self.APISecret = APISecret
|
175 |
+
self.host = urlparse(Spark_url).netloc
|
176 |
+
self.path = urlparse(Spark_url).path
|
177 |
+
self.Spark_url = Spark_url
|
178 |
+
# 自定义
|
179 |
+
self.temperature = 0
|
180 |
+
self.max_tokens = 2048
|
181 |
+
|
182 |
+
# 生成url
|
183 |
+
def create_url(self):
|
184 |
+
# 生成RFC1123格式的时间戳
|
185 |
+
now = datetime.now()
|
186 |
+
date = format_date_time(mktime(now.timetuple()))
|
187 |
+
|
188 |
+
# 拼接字符串
|
189 |
+
signature_origin = "host: " + self.host + "\n"
|
190 |
+
signature_origin += "date: " + date + "\n"
|
191 |
+
signature_origin += "GET " + self.path + " HTTP/1.1"
|
192 |
+
|
193 |
+
# 进行hmac-sha256进行加密
|
194 |
+
signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
|
195 |
+
digestmod=hashlib.sha256).digest()
|
196 |
+
|
197 |
+
signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')
|
198 |
+
|
199 |
+
authorization_origin = f'api_key="{self.APIKey}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
|
200 |
+
|
201 |
+
authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
|
202 |
+
|
203 |
+
# 将请求的鉴权参数组合为字典
|
204 |
+
v = {
|
205 |
+
"authorization": authorization,
|
206 |
+
"date": date,
|
207 |
+
"host": self.host
|
208 |
+
}
|
209 |
+
# 拼接鉴权参数,生成url
|
210 |
+
url = self.Spark_url + '?' + urlencode(v)
|
211 |
+
# 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致
|
212 |
+
return url
|
213 |
+
|
214 |
+
|
215 |
+
# 收到websocket错误的处理
|
216 |
+
def on_error(ws, error):
|
217 |
+
print("### error:", error)
|
218 |
+
|
219 |
+
|
220 |
+
# 收到websocket关闭的处理
|
221 |
+
def on_close(ws,one,two):
|
222 |
+
print(" ")
|
223 |
+
|
224 |
+
|
225 |
+
# 收到websocket连接建立的处理
|
226 |
+
def on_open(ws):
|
227 |
+
thread.start_new_thread(run, (ws,))
|
228 |
+
|
229 |
+
|
230 |
+
def run(ws, *args):
|
231 |
+
data = json.dumps(gen_params(appid=ws.appid, domain= ws.domain,question=ws.question, temperature = ws.temperature, max_tokens = ws.max_tokens))
|
232 |
+
ws.send(data)
|
233 |
+
|
234 |
+
|
235 |
+
# 收到websocket消息的处理
|
236 |
+
def on_message(ws, message):
|
237 |
+
# print(message)
|
238 |
+
data = json.loads(message)
|
239 |
+
code = data['header']['code']
|
240 |
+
if code != 0:
|
241 |
+
print(f'请求错误: {code}, {data}')
|
242 |
+
ws.close()
|
243 |
+
else:
|
244 |
+
choices = data["payload"]["choices"]
|
245 |
+
status = choices["status"]
|
246 |
+
content = choices["text"][0]["content"]
|
247 |
+
print(content,end ="")
|
248 |
+
global answer
|
249 |
+
answer += content
|
250 |
+
# print(1)
|
251 |
+
if status == 2:
|
252 |
+
ws.close()
|
253 |
+
|
254 |
+
|
255 |
+
def gen_params(appid, domain,question, temperature, max_tokens):
|
256 |
+
"""
|
257 |
+
通过appid和用户的提问来生成请参数
|
258 |
+
"""
|
259 |
+
data = {
|
260 |
+
"header": {
|
261 |
+
"app_id": appid,
|
262 |
+
"uid": "1234"
|
263 |
+
},
|
264 |
+
"parameter": {
|
265 |
+
"chat": {
|
266 |
+
"domain": domain,
|
267 |
+
"random_threshold": 0.5,
|
268 |
+
"max_tokens": max_tokens,
|
269 |
+
"temperature" : temperature,
|
270 |
+
"auditing": "default"
|
271 |
+
}
|
272 |
+
},
|
273 |
+
"payload": {
|
274 |
+
"message": {
|
275 |
+
"text": question
|
276 |
+
}
|
277 |
+
}
|
278 |
+
}
|
279 |
+
return data
|
280 |
+
|
281 |
+
|
282 |
+
def spark_main(appid, api_key, api_secret, Spark_url,domain, question, temperature, max_tokens):
|
283 |
+
# print("星火:")
|
284 |
+
output_queue = queue.Queue()
|
285 |
+
def on_message(ws, message):
|
286 |
+
data = json.loads(message)
|
287 |
+
code = data['header']['code']
|
288 |
+
if code != 0:
|
289 |
+
print(f'请求错误: {code}, {data}')
|
290 |
+
ws.close()
|
291 |
+
else:
|
292 |
+
choices = data["payload"]["choices"]
|
293 |
+
status = choices["status"]
|
294 |
+
content = choices["text"][0]["content"]
|
295 |
+
# print(content, end='')
|
296 |
+
# 将输出值放入队列
|
297 |
+
output_queue.put(content)
|
298 |
+
if status == 2:
|
299 |
+
ws.close()
|
300 |
+
|
301 |
+
wsParam = Ws_Param(appid, api_key, api_secret, Spark_url)
|
302 |
+
websocket.enableTrace(False)
|
303 |
+
wsUrl = wsParam.create_url()
|
304 |
+
ws = websocket.WebSocketApp(wsUrl, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open)
|
305 |
+
ws.appid = appid
|
306 |
+
ws.question = question
|
307 |
+
ws.domain = domain
|
308 |
+
ws.temperature = temperature
|
309 |
+
ws.max_tokens = max_tokens
|
310 |
+
ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
|
311 |
+
return ''.join([output_queue.get() for _ in range(output_queue.qsize())])
|
312 |
+
|
313 |
+
def parse_llm_api_key(model:str, env_file:dict()=None):
|
314 |
+
"""
|
315 |
+
通过 model 和 env_file 的来解析平台参数
|
316 |
+
"""
|
317 |
+
if env_file is None:
|
318 |
+
_ = load_dotenv(find_dotenv())
|
319 |
+
env_file = os.environ
|
320 |
+
if model == "openai":
|
321 |
+
return env_file["OPENAI_API_KEY"]
|
322 |
+
elif model == "wenxin":
|
323 |
+
return env_file["wenxin_api_key"], env_file["wenxin_secret_key"]
|
324 |
+
elif model == "spark":
|
325 |
+
return env_file["spark_api_key"], env_file["spark_appid"], env_file["spark_api_secret"]
|
326 |
+
elif model == "zhipuai":
|
327 |
+
return get_from_dict_or_env(env_file, "zhipuai_api_key", "ZHIPUAI_API_KEY")
|
328 |
+
# return env_file["ZHIPUAI_API_KEY"]
|
329 |
+
else:
|
330 |
+
raise ValueError(f"model{model} not support!!!")
|
llm/self_llm.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- encoding: utf-8 -*-
|
3 |
+
'''
|
4 |
+
@File : self_llm.py
|
5 |
+
@Time : 2023/10/16 18:48:08
|
6 |
+
@Author : Logan Zou
|
7 |
+
@Version : 1.0
|
8 |
+
@Contact : loganzou0421@163.com
|
9 |
+
@License : (C)Copyright 2017-2018, Liugroup-NLPR-CASIA
|
10 |
+
@Desc : 在 LangChain LLM 基础上封装的项目类,统一了 GPT、文心、讯飞、智谱多种 API 调用
|
11 |
+
'''
|
12 |
+
|
13 |
+
from langchain.llms.base import LLM
|
14 |
+
from typing import Dict, Any, Mapping
|
15 |
+
from pydantic import Field
|
16 |
+
|
17 |
+
class Self_LLM(LLM):
|
18 |
+
# 自定义 LLM
|
19 |
+
# 继承自 langchain.llms.base.LLM
|
20 |
+
# 原生接口地址
|
21 |
+
url : str = None
|
22 |
+
# 默认选用 GPT-3.5 模型,即目前一般所说的GPT
|
23 |
+
model_name: str = "gpt-3.5-turbo"
|
24 |
+
# 访问时延上限
|
25 |
+
request_timeout: float = None
|
26 |
+
# 温度系数
|
27 |
+
temperature: float = 0.1
|
28 |
+
# API_Key
|
29 |
+
api_key: str = None
|
30 |
+
# 必备的可选参数
|
31 |
+
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
32 |
+
|
33 |
+
# 定义一个返回默认参数的方法
|
34 |
+
@property
|
35 |
+
def _default_params(self) -> Dict[str, Any]:
|
36 |
+
"""获取调用默认参数。"""
|
37 |
+
normal_params = {
|
38 |
+
"temperature": self.temperature,
|
39 |
+
"request_timeout": self.request_timeout,
|
40 |
+
}
|
41 |
+
# print(type(self.model_kwargs))
|
42 |
+
return {**normal_params}
|
43 |
+
|
44 |
+
@property
|
45 |
+
def _identifying_params(self) -> Mapping[str, Any]:
|
46 |
+
"""Get the identifying parameters."""
|
47 |
+
return {**{"model_name": self.model_name}, **self._default_params}
|
llm/spark_llm.py
ADDED
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- encoding: utf-8 -*-
|
3 |
+
'''
|
4 |
+
@File : wenxin_llm.py
|
5 |
+
@Time : 2023/10/16 18:53:26
|
6 |
+
@Author : Logan Zou
|
7 |
+
@Version : 1.0
|
8 |
+
@Contact : loganzou0421@163.com
|
9 |
+
@License : (C)Copyright 2017-2018, Liugroup-NLPR-CASIA
|
10 |
+
@Desc : 基于讯飞星火大模型自定义 LLM 类
|
11 |
+
'''
|
12 |
+
|
13 |
+
from langchain.llms.base import LLM
|
14 |
+
from typing import Any, List, Mapping, Optional, Dict, Union, Tuple
|
15 |
+
from pydantic import Field
|
16 |
+
from llm.self_llm import Self_LLM
|
17 |
+
import json
|
18 |
+
import requests
|
19 |
+
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
20 |
+
import _thread as thread
|
21 |
+
import base64
|
22 |
+
import datetime
|
23 |
+
import hashlib
|
24 |
+
import hmac
|
25 |
+
import json
|
26 |
+
from urllib.parse import urlparse
|
27 |
+
import ssl
|
28 |
+
from datetime import datetime
|
29 |
+
from time import mktime
|
30 |
+
from urllib.parse import urlencode
|
31 |
+
from wsgiref.handlers import format_date_time
|
32 |
+
import websocket # 使用websocket_client
|
33 |
+
import queue
|
34 |
+
|
35 |
+
class Spark_LLM(Self_LLM):
|
36 |
+
# 讯飞星火大模型的自定义 LLM
|
37 |
+
# URL
|
38 |
+
url : str = "ws://spark-api.xf-yun.com/v1.1/chat"
|
39 |
+
# APPID
|
40 |
+
appid : str = None
|
41 |
+
# APISecret
|
42 |
+
api_secret : str = None
|
43 |
+
# Domain
|
44 |
+
domain :str = "general"
|
45 |
+
# max_token
|
46 |
+
max_tokens : int = 4096
|
47 |
+
|
48 |
+
def getText(self, role, content, text = []):
|
49 |
+
# role 是指定角色,content 是 prompt 内容
|
50 |
+
jsoncon = {}
|
51 |
+
jsoncon["role"] = role
|
52 |
+
jsoncon["content"] = content
|
53 |
+
text.append(jsoncon)
|
54 |
+
return text
|
55 |
+
|
56 |
+
def _call(self, prompt : str, stop: Optional[List[str]] = None,
|
57 |
+
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
58 |
+
**kwargs: Any):
|
59 |
+
if self.api_key == None or self.appid == None or self.api_secret == None:
|
60 |
+
# 三个 Key 均存在才可以正常调用
|
61 |
+
print("请填入 Key")
|
62 |
+
raise ValueError("Key 不存在")
|
63 |
+
# 将 Prompt 填充到星火格式
|
64 |
+
question = self.getText("user", prompt)
|
65 |
+
# 发起请求
|
66 |
+
try:
|
67 |
+
response = spark_main(self.appid,self.api_key,self.api_secret,self.url,self.domain,question, self.temperature, self.max_tokens)
|
68 |
+
return response
|
69 |
+
except Exception as e:
|
70 |
+
print(e)
|
71 |
+
print("请求失败")
|
72 |
+
return "请求失败"
|
73 |
+
|
74 |
+
@property
|
75 |
+
def _llm_type(self) -> str:
|
76 |
+
return "Spark"
|
77 |
+
|
78 |
+
answer = ""
|
79 |
+
|
80 |
+
class Ws_Param(object):
|
81 |
+
# 初始化
|
82 |
+
def __init__(self, APPID, APIKey, APISecret, Spark_url):
|
83 |
+
self.APPID = APPID
|
84 |
+
self.APIKey = APIKey
|
85 |
+
self.APISecret = APISecret
|
86 |
+
self.host = urlparse(Spark_url).netloc
|
87 |
+
self.path = urlparse(Spark_url).path
|
88 |
+
self.Spark_url = Spark_url
|
89 |
+
# 自定义
|
90 |
+
self.temperature = 0
|
91 |
+
self.max_tokens = 2048
|
92 |
+
|
93 |
+
# 生成url
|
94 |
+
def create_url(self):
|
95 |
+
# 生成RFC1123格式的时间戳
|
96 |
+
now = datetime.now()
|
97 |
+
date = format_date_time(mktime(now.timetuple()))
|
98 |
+
|
99 |
+
# 拼接字符串
|
100 |
+
signature_origin = "host: " + self.host + "\n"
|
101 |
+
signature_origin += "date: " + date + "\n"
|
102 |
+
signature_origin += "GET " + self.path + " HTTP/1.1"
|
103 |
+
|
104 |
+
# 进行hmac-sha256进行加密
|
105 |
+
signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
|
106 |
+
digestmod=hashlib.sha256).digest()
|
107 |
+
|
108 |
+
signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')
|
109 |
+
|
110 |
+
authorization_origin = f'api_key="{self.APIKey}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
|
111 |
+
|
112 |
+
authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
|
113 |
+
|
114 |
+
# 将请求的鉴权参数组合为字典
|
115 |
+
v = {
|
116 |
+
"authorization": authorization,
|
117 |
+
"date": date,
|
118 |
+
"host": self.host
|
119 |
+
}
|
120 |
+
# 拼接鉴权参数,生成url
|
121 |
+
url = self.Spark_url + '?' + urlencode(v)
|
122 |
+
# 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致
|
123 |
+
return url
|
124 |
+
|
125 |
+
|
126 |
+
# 收到websocket错误的处理
|
127 |
+
def on_error(ws, error):
|
128 |
+
print("### error:", error)
|
129 |
+
|
130 |
+
|
131 |
+
# 收到websocket关闭的处理
|
132 |
+
def on_close(ws,one,two):
|
133 |
+
print(" ")
|
134 |
+
|
135 |
+
|
136 |
+
# 收到websocket连接建立的处理
|
137 |
+
def on_open(ws):
|
138 |
+
thread.start_new_thread(run, (ws,))
|
139 |
+
|
140 |
+
|
141 |
+
def run(ws, *args):
|
142 |
+
data = json.dumps(gen_params(appid=ws.appid, domain= ws.domain,question=ws.question, temperature = ws.temperature, max_tokens = ws.max_tokens))
|
143 |
+
ws.send(data)
|
144 |
+
|
145 |
+
|
146 |
+
# 收到websocket消息的处理
|
147 |
+
def on_message(ws, message):
|
148 |
+
# print(message)
|
149 |
+
data = json.loads(message)
|
150 |
+
code = data['header']['code']
|
151 |
+
if code != 0:
|
152 |
+
print(f'请求错误: {code}, {data}')
|
153 |
+
ws.close()
|
154 |
+
else:
|
155 |
+
choices = data["payload"]["choices"]
|
156 |
+
status = choices["status"]
|
157 |
+
content = choices["text"][0]["content"]
|
158 |
+
print(content,end ="")
|
159 |
+
global answer
|
160 |
+
answer += content
|
161 |
+
# print(1)
|
162 |
+
if status == 2:
|
163 |
+
ws.close()
|
164 |
+
|
165 |
+
|
166 |
+
def gen_params(appid, domain,question, temperature, max_tokens):
|
167 |
+
"""
|
168 |
+
通过appid和用户的提问来生成请参数
|
169 |
+
"""
|
170 |
+
data = {
|
171 |
+
"header": {
|
172 |
+
"app_id": appid,
|
173 |
+
"uid": "1234"
|
174 |
+
},
|
175 |
+
"parameter": {
|
176 |
+
"chat": {
|
177 |
+
"domain": domain,
|
178 |
+
"random_threshold": 0.5,
|
179 |
+
"max_tokens": max_tokens,
|
180 |
+
"temperature" : temperature,
|
181 |
+
"auditing": "default"
|
182 |
+
}
|
183 |
+
},
|
184 |
+
"payload": {
|
185 |
+
"message": {
|
186 |
+
"text": question
|
187 |
+
}
|
188 |
+
}
|
189 |
+
}
|
190 |
+
return data
|
191 |
+
|
192 |
+
|
193 |
+
def spark_main(appid, api_key, api_secret, Spark_url,domain, question, temperature, max_tokens):
|
194 |
+
# print("星火:")
|
195 |
+
output_queue = queue.Queue()
|
196 |
+
def on_message(ws, message):
|
197 |
+
data = json.loads(message)
|
198 |
+
code = data['header']['code']
|
199 |
+
if code != 0:
|
200 |
+
print(f'请求错误: {code}, {data}')
|
201 |
+
ws.close()
|
202 |
+
else:
|
203 |
+
choices = data["payload"]["choices"]
|
204 |
+
status = choices["status"]
|
205 |
+
content = choices["text"][0]["content"]
|
206 |
+
# print(content, end='')
|
207 |
+
# 将输出值放入队列
|
208 |
+
output_queue.put(content)
|
209 |
+
if status == 2:
|
210 |
+
ws.close()
|
211 |
+
|
212 |
+
wsParam = Ws_Param(appid, api_key, api_secret, Spark_url)
|
213 |
+
websocket.enableTrace(False)
|
214 |
+
wsUrl = wsParam.create_url()
|
215 |
+
ws = websocket.WebSocketApp(wsUrl, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open)
|
216 |
+
ws.appid = appid
|
217 |
+
ws.question = question
|
218 |
+
ws.domain = domain
|
219 |
+
ws.temperature = temperature
|
220 |
+
ws.max_tokens = max_tokens
|
221 |
+
ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
|
222 |
+
return ''.join([output_queue.get() for _ in range(output_queue.qsize())])
|
223 |
+
|
224 |
+
|
225 |
+
|
226 |
+
|
227 |
+
|
llm/test.ipynb
ADDED
@@ -0,0 +1,312 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [],
|
8 |
+
"source": [
|
9 |
+
"from wenxin_llm import Wenxin_LLM"
|
10 |
+
]
|
11 |
+
},
|
12 |
+
{
|
13 |
+
"cell_type": "code",
|
14 |
+
"execution_count": 3,
|
15 |
+
"metadata": {},
|
16 |
+
"outputs": [],
|
17 |
+
"source": [
|
18 |
+
"from dotenv import find_dotenv, load_dotenv\n",
|
19 |
+
"import os\n",
|
20 |
+
"\n",
|
21 |
+
"# 读取本地/项目的环境变量。\n",
|
22 |
+
"\n",
|
23 |
+
"# find_dotenv()寻找并定位.env文件的路径\n",
|
24 |
+
"# load_dotenv()读取该.env文件,并将其中的环境变量加载到当前的运行环境中\n",
|
25 |
+
"# 如果你设置的是全局的环境变量,这行代码则没有任何作用。\n",
|
26 |
+
"_ = load_dotenv(find_dotenv())\n",
|
27 |
+
"\n",
|
28 |
+
"# 获取环境变量 OPENAI_API_KEY\n",
|
29 |
+
"wenxin_api_key = os.environ[\"wenxin_api_key\"]\n",
|
30 |
+
"wenxin_secret_key = os.environ[\"wenxin_secret_key\"]"
|
31 |
+
]
|
32 |
+
},
|
33 |
+
{
|
34 |
+
"cell_type": "code",
|
35 |
+
"execution_count": 4,
|
36 |
+
"metadata": {},
|
37 |
+
"outputs": [],
|
38 |
+
"source": [
|
39 |
+
"llm = Wenxin_LLM(model = \"ERNIE-Bot-turbo\", api_key=wenxin_api_key, secret_key=wenxin_secret_key)"
|
40 |
+
]
|
41 |
+
},
|
42 |
+
{
|
43 |
+
"cell_type": "code",
|
44 |
+
"execution_count": 5,
|
45 |
+
"metadata": {},
|
46 |
+
"outputs": [
|
47 |
+
{
|
48 |
+
"data": {
|
49 |
+
"text/plain": [
|
50 |
+
"'您好,我是百度研发的知识增强大语言模型,中文名是文心一言,英文名是ERNIE Bot。我能够与人对话互动,回答问题,协助创作,高效便捷地帮助人们获取信息、知识和灵感。\\n\\n如果您有任何问题,请随时告诉我。'"
|
51 |
+
]
|
52 |
+
},
|
53 |
+
"execution_count": 5,
|
54 |
+
"metadata": {},
|
55 |
+
"output_type": "execute_result"
|
56 |
+
}
|
57 |
+
],
|
58 |
+
"source": [
|
59 |
+
"llm(\"你是谁\")"
|
60 |
+
]
|
61 |
+
},
|
62 |
+
{
|
63 |
+
"cell_type": "code",
|
64 |
+
"execution_count": 6,
|
65 |
+
"metadata": {},
|
66 |
+
"outputs": [],
|
67 |
+
"source": [
|
68 |
+
"from spark_llm import Spark_LLM"
|
69 |
+
]
|
70 |
+
},
|
71 |
+
{
|
72 |
+
"cell_type": "code",
|
73 |
+
"execution_count": 7,
|
74 |
+
"metadata": {},
|
75 |
+
"outputs": [],
|
76 |
+
"source": [
|
77 |
+
"from dotenv import find_dotenv, load_dotenv\n",
|
78 |
+
"import os\n",
|
79 |
+
"\n",
|
80 |
+
"# 读取本地/项目的环境变量。\n",
|
81 |
+
"\n",
|
82 |
+
"# find_dotenv()寻找并定位.env文件的路径\n",
|
83 |
+
"# load_dotenv()读取该.env文件,并将其中的环境变量加载到当前的运行环境中\n",
|
84 |
+
"# 如果你设置的是全局的环境变量,这行代码则没有任何作用。\n",
|
85 |
+
"_ = load_dotenv(find_dotenv())\n",
|
86 |
+
"#填写控制台中获取的 APPID 信息\n",
|
87 |
+
"appid = os.environ[\"spark_appid\"]\n",
|
88 |
+
"#填写控制台中获取的 APISecret 信息\n",
|
89 |
+
"api_secret = os.environ[\"spark_api_secret\"]\n",
|
90 |
+
"#填写控制台中获取的 APIKey 信息\n",
|
91 |
+
"api_key = os.environ[\"spark_api_key\"]"
|
92 |
+
]
|
93 |
+
},
|
94 |
+
{
|
95 |
+
"cell_type": "code",
|
96 |
+
"execution_count": 8,
|
97 |
+
"metadata": {},
|
98 |
+
"outputs": [],
|
99 |
+
"source": [
|
100 |
+
"llm = Spark_LLM(model = \"spark\", appid=appid, api_secret=api_secret, api_key=api_key)"
|
101 |
+
]
|
102 |
+
},
|
103 |
+
{
|
104 |
+
"cell_type": "code",
|
105 |
+
"execution_count": 9,
|
106 |
+
"metadata": {},
|
107 |
+
"outputs": [
|
108 |
+
{
|
109 |
+
"data": {
|
110 |
+
"text/plain": [
|
111 |
+
"'\\n\\n我是一个AI语言模型,可以回答你的问题和提供帮助。'"
|
112 |
+
]
|
113 |
+
},
|
114 |
+
"execution_count": 9,
|
115 |
+
"metadata": {},
|
116 |
+
"output_type": "execute_result"
|
117 |
+
}
|
118 |
+
],
|
119 |
+
"source": [
|
120 |
+
"llm(\"你是谁\")"
|
121 |
+
]
|
122 |
+
},
|
123 |
+
{
|
124 |
+
"cell_type": "code",
|
125 |
+
"execution_count": 10,
|
126 |
+
"metadata": {},
|
127 |
+
"outputs": [],
|
128 |
+
"source": [
|
129 |
+
"from zhipuai_llm import ZhipuAILLM\n",
|
130 |
+
"\n",
|
131 |
+
"from dotenv import find_dotenv, load_dotenv\n",
|
132 |
+
"import os\n",
|
133 |
+
"\n",
|
134 |
+
"# 读取本地/项目的环境变量。\n",
|
135 |
+
"\n",
|
136 |
+
"# find_dotenv()寻找并定位.env文件的路径\n",
|
137 |
+
"# load_dotenv()读取该.env文件,并将其中的环境变量加载到当前的运行环境中\n",
|
138 |
+
"# 如果你设置的是全局的环境变量,这行代码则没有任何作用。\n",
|
139 |
+
"_ = load_dotenv(find_dotenv())\n",
|
140 |
+
"\n",
|
141 |
+
"api_key = os.environ[\"ZHIPUAI_API_KEY\"] #填写控制台中获取的 APIKey 信息"
|
142 |
+
]
|
143 |
+
},
|
144 |
+
{
|
145 |
+
"cell_type": "code",
|
146 |
+
"execution_count": 11,
|
147 |
+
"metadata": {},
|
148 |
+
"outputs": [
|
149 |
+
{
|
150 |
+
"data": {
|
151 |
+
"text/plain": [
|
152 |
+
"'我是一个名为 ChatGLM 的人工智能助手,由智谱 AI 公司于2023年训练的语言模型开发而成。我的任务是针对用户的问题和要求提供适当的答复和支持。'"
|
153 |
+
]
|
154 |
+
},
|
155 |
+
"execution_count": 11,
|
156 |
+
"metadata": {},
|
157 |
+
"output_type": "execute_result"
|
158 |
+
}
|
159 |
+
],
|
160 |
+
"source": [
|
161 |
+
"llm = ZhipuAILLM(model=\"chatglm_pro\", zhipuai_api_key=api_key, temperature=0.1)\n",
|
162 |
+
"llm(\"你是谁\") "
|
163 |
+
]
|
164 |
+
},
|
165 |
+
{
|
166 |
+
"cell_type": "markdown",
|
167 |
+
"metadata": {},
|
168 |
+
"source": [
|
169 |
+
"测试原生接口"
|
170 |
+
]
|
171 |
+
},
|
172 |
+
{
|
173 |
+
"cell_type": "code",
|
174 |
+
"execution_count": 12,
|
175 |
+
"metadata": {},
|
176 |
+
"outputs": [],
|
177 |
+
"source": [
|
178 |
+
"from call_llm import get_completion"
|
179 |
+
]
|
180 |
+
},
|
181 |
+
{
|
182 |
+
"cell_type": "code",
|
183 |
+
"execution_count": 13,
|
184 |
+
"metadata": {},
|
185 |
+
"outputs": [],
|
186 |
+
"source": [
|
187 |
+
"from dotenv import find_dotenv, load_dotenv\n",
|
188 |
+
"import os\n",
|
189 |
+
"\n",
|
190 |
+
"# 读取本地/项目的环境变量。\n",
|
191 |
+
"\n",
|
192 |
+
"# find_dotenv()寻找并定位.env文件的路径\n",
|
193 |
+
"# load_dotenv()读取该.env文件,并将其中的环境变量加载到当前的运行环境中\n",
|
194 |
+
"# 如果你设置的是全局的环境变量,这行代码则没有任何作用。\n",
|
195 |
+
"_ = load_dotenv(find_dotenv())\n",
|
196 |
+
"\n",
|
197 |
+
"# 获取环境变量 OPENAI_API_KEY\n",
|
198 |
+
"openai_api_key = os.environ[\"OPENAI_API_KEY\"]\n",
|
199 |
+
"wenxin_api_key = os.environ[\"wenxin_api_key\"]\n",
|
200 |
+
"wenxin_secret_key = os.environ[\"wenxin_secret_key\"]\n",
|
201 |
+
"spark_appid = os.environ[\"spark_appid\"]\n",
|
202 |
+
"spark_api_secret = os.environ[\"spark_api_secret\"]\n",
|
203 |
+
"spark_api_key = os.environ[\"spark_api_key\"]\n",
|
204 |
+
"zhipu_api_key = os.environ[\"ZHIPUAI_API_KEY\"]\n",
|
205 |
+
"\n",
|
206 |
+
"# os.environ['HTTPS_PROXY'] = 'http://127.0.0.1:7890'\n",
|
207 |
+
"# os.environ[\"HTTP_PROXY\"] = 'http://127.0.0.1:7890'"
|
208 |
+
]
|
209 |
+
},
|
210 |
+
{
|
211 |
+
"cell_type": "code",
|
212 |
+
"execution_count": 14,
|
213 |
+
"metadata": {},
|
214 |
+
"outputs": [
|
215 |
+
{
|
216 |
+
"data": {
|
217 |
+
"text/plain": [
|
218 |
+
"'我是一个人工智能助手,可以回答你的问题并提供帮助。有什么可以帮到你的吗?'"
|
219 |
+
]
|
220 |
+
},
|
221 |
+
"execution_count": 14,
|
222 |
+
"metadata": {},
|
223 |
+
"output_type": "execute_result"
|
224 |
+
}
|
225 |
+
],
|
226 |
+
"source": [
|
227 |
+
"get_completion(\"你是谁\",model=\"gpt-3.5-turbo\", api_key=openai_api_key)"
|
228 |
+
]
|
229 |
+
},
|
230 |
+
{
|
231 |
+
"cell_type": "code",
|
232 |
+
"execution_count": 15,
|
233 |
+
"metadata": {},
|
234 |
+
"outputs": [
|
235 |
+
{
|
236 |
+
"data": {
|
237 |
+
"text/plain": [
|
238 |
+
"'您好,我是百度研发的知识增强大语言模型,中文名是文心一言,英文名是ERNIE Bot。我能够与人对话互动,回答问题,协助创作,高效便捷地帮助人们获取信息、知识和灵感。'"
|
239 |
+
]
|
240 |
+
},
|
241 |
+
"execution_count": 15,
|
242 |
+
"metadata": {},
|
243 |
+
"output_type": "execute_result"
|
244 |
+
}
|
245 |
+
],
|
246 |
+
"source": [
|
247 |
+
"get_completion(\"你是谁\",model=\"ERNIE-Bot-turbo\", api_key=wenxin_api_key, secret_key=wenxin_secret_key)"
|
248 |
+
]
|
249 |
+
},
|
250 |
+
{
|
251 |
+
"cell_type": "code",
|
252 |
+
"execution_count": 16,
|
253 |
+
"metadata": {},
|
254 |
+
"outputs": [
|
255 |
+
{
|
256 |
+
"data": {
|
257 |
+
"text/plain": [
|
258 |
+
"'\\n\\n我是一个AI语言模型,可以回答你的问题和提供帮助。'"
|
259 |
+
]
|
260 |
+
},
|
261 |
+
"execution_count": 16,
|
262 |
+
"metadata": {},
|
263 |
+
"output_type": "execute_result"
|
264 |
+
}
|
265 |
+
],
|
266 |
+
"source": [
|
267 |
+
"get_completion(\"你是谁\",model=\"Spark-1.5\", appid=spark_appid, api_key=spark_api_key, api_secret=spark_api_secret)"
|
268 |
+
]
|
269 |
+
},
|
270 |
+
{
|
271 |
+
"cell_type": "code",
|
272 |
+
"execution_count": 17,
|
273 |
+
"metadata": {},
|
274 |
+
"outputs": [
|
275 |
+
{
|
276 |
+
"data": {
|
277 |
+
"text/plain": [
|
278 |
+
"'我是一个名为 ChatGLM 的人工智能助手,由智谱 AI 公司于2023年训练的语言模型开发而成。我的任务是针对用户的问题和要求提供适当的答复和支持。'"
|
279 |
+
]
|
280 |
+
},
|
281 |
+
"execution_count": 17,
|
282 |
+
"metadata": {},
|
283 |
+
"output_type": "execute_result"
|
284 |
+
}
|
285 |
+
],
|
286 |
+
"source": [
|
287 |
+
"get_completion(\"你是谁\",model=\"chatglm_std\", api_key=zhipu_api_key)"
|
288 |
+
]
|
289 |
+
}
|
290 |
+
],
|
291 |
+
"metadata": {
|
292 |
+
"kernelspec": {
|
293 |
+
"display_name": "langchain",
|
294 |
+
"language": "python",
|
295 |
+
"name": "python3"
|
296 |
+
},
|
297 |
+
"language_info": {
|
298 |
+
"codemirror_mode": {
|
299 |
+
"name": "ipython",
|
300 |
+
"version": 3
|
301 |
+
},
|
302 |
+
"file_extension": ".py",
|
303 |
+
"mimetype": "text/x-python",
|
304 |
+
"name": "python",
|
305 |
+
"nbconvert_exporter": "python",
|
306 |
+
"pygments_lexer": "ipython3",
|
307 |
+
"version": "3.10.0"
|
308 |
+
}
|
309 |
+
},
|
310 |
+
"nbformat": 4,
|
311 |
+
"nbformat_minor": 2
|
312 |
+
}
|
llm/wenxin_llm.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- encoding: utf-8 -*-
|
3 |
+
'''
|
4 |
+
@File : wenxin_llm.py
|
5 |
+
@Time : 2023/10/16 18:53:26
|
6 |
+
@Author : Logan Zou
|
7 |
+
@Version : 1.0
|
8 |
+
@Contact : loganzou0421@163.com
|
9 |
+
@License : (C)Copyright 2017-2018, Liugroup-NLPR-CASIA
|
10 |
+
@Desc : 基于百度文心大模型自定义 LLM 类
|
11 |
+
'''
|
12 |
+
|
13 |
+
from langchain.llms.base import LLM
|
14 |
+
from typing import Any, List, Mapping, Optional, Dict, Union, Tuple
|
15 |
+
from pydantic import Field
|
16 |
+
from llm.self_llm import Self_LLM
|
17 |
+
import json
|
18 |
+
import requests
|
19 |
+
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
20 |
+
# 调用文心 API 的工具函数
|
21 |
+
def get_access_token(api_key : str, secret_key : str):
|
22 |
+
"""
|
23 |
+
使用 API Key,Secret Key 获取access_token,替换下列示例中的应用API Key、应用Secret Key
|
24 |
+
"""
|
25 |
+
# 指定网址
|
26 |
+
url = f"https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id={api_key}&client_secret={secret_key}"
|
27 |
+
# 设置 POST 访问
|
28 |
+
payload = json.dumps("")
|
29 |
+
headers = {
|
30 |
+
'Content-Type': 'application/json',
|
31 |
+
'Accept': 'application/json'
|
32 |
+
}
|
33 |
+
# 通过 POST 访问获取账户对应的 access_token
|
34 |
+
response = requests.request("POST", url, headers=headers, data=payload)
|
35 |
+
return response.json().get("access_token")
|
36 |
+
|
37 |
+
class Wenxin_LLM(Self_LLM):
|
38 |
+
# 文心大模型的自定义 LLM
|
39 |
+
# URL
|
40 |
+
url : str = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant?access_token={}"
|
41 |
+
# Secret_Key
|
42 |
+
secret_key : str = None
|
43 |
+
# access_token
|
44 |
+
access_token: str = None
|
45 |
+
|
46 |
+
def init_access_token(self):
|
47 |
+
if self.api_key != None and self.secret_key != None:
|
48 |
+
# 两个 Key 均非空才可以获取 access_token
|
49 |
+
try:
|
50 |
+
self.access_token = get_access_token(self.api_key, self.secret_key)
|
51 |
+
except Exception as e:
|
52 |
+
print(e)
|
53 |
+
print("获取 access_token 失败,请检查 Key")
|
54 |
+
else:
|
55 |
+
print("API_Key 或 Secret_Key 为空,请检查 Key")
|
56 |
+
|
57 |
+
def _call(self, prompt : str, stop: Optional[List[str]] = None,
|
58 |
+
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
59 |
+
**kwargs: Any):
|
60 |
+
# 如果 access_token 为空,初始化 access_token
|
61 |
+
if self.access_token == None:
|
62 |
+
self.init_access_token()
|
63 |
+
# API 调用 url
|
64 |
+
url = self.url.format(self.access_token)
|
65 |
+
# 配置 POST 参数
|
66 |
+
payload = json.dumps({
|
67 |
+
"messages": [
|
68 |
+
{
|
69 |
+
"role": "user",# user prompt
|
70 |
+
"content": "{}".format(prompt)# 输入的 prompt
|
71 |
+
}
|
72 |
+
],
|
73 |
+
'temperature' : self.temperature
|
74 |
+
})
|
75 |
+
headers = {
|
76 |
+
'Content-Type': 'application/json'
|
77 |
+
}
|
78 |
+
# 发起请求
|
79 |
+
response = requests.request("POST", url, headers=headers, data=payload, timeout=self.request_timeout)
|
80 |
+
if response.status_code == 200:
|
81 |
+
# 返回的是一个 Json 字符串
|
82 |
+
js = json.loads(response.text)
|
83 |
+
# print(js)
|
84 |
+
return js["result"]
|
85 |
+
else:
|
86 |
+
return "请求失败"
|
87 |
+
|
88 |
+
@property
|
89 |
+
def _llm_type(self) -> str:
|
90 |
+
return "Wenxin"
|
llm/zhipuai_llm.py
ADDED
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- encoding: utf-8 -*-
|
3 |
+
'''
|
4 |
+
@File : zhipuai_llm.py
|
5 |
+
@Time : 2023/10/16 22:06:26
|
6 |
+
@Author : 0-yy-0
|
7 |
+
@Version : 1.0
|
8 |
+
@Contact : 310484121@qq.com
|
9 |
+
@License : (C)Copyright 2017-2018, Liugroup-NLPR-CASIA
|
10 |
+
@Desc : 基于智谱 AI 大模型自定义 LLM 类
|
11 |
+
'''
|
12 |
+
|
13 |
+
from __future__ import annotations
|
14 |
+
|
15 |
+
import logging
|
16 |
+
from typing import (
|
17 |
+
Any,
|
18 |
+
AsyncIterator,
|
19 |
+
Dict,
|
20 |
+
Iterator,
|
21 |
+
List,
|
22 |
+
Optional,
|
23 |
+
)
|
24 |
+
|
25 |
+
from langchain.callbacks.manager import (
|
26 |
+
AsyncCallbackManagerForLLMRun,
|
27 |
+
CallbackManagerForLLMRun,
|
28 |
+
)
|
29 |
+
from langchain.llms.base import LLM
|
30 |
+
from langchain.pydantic_v1 import Field, root_validator
|
31 |
+
from langchain.schema.output import GenerationChunk
|
32 |
+
from langchain.utils import get_from_dict_or_env
|
33 |
+
from llm.self_llm import Self_LLM
|
34 |
+
|
35 |
+
logger = logging.getLogger(__name__)
|
36 |
+
|
37 |
+
|
38 |
+
class ZhipuAILLM(Self_LLM):
|
39 |
+
"""Zhipuai hosted open source or customized models.
|
40 |
+
|
41 |
+
To use, you should have the ``zhipuai`` python package installed, and
|
42 |
+
the environment variable ``zhipuai_api_key`` set with
|
43 |
+
your API key and Secret Key.
|
44 |
+
|
45 |
+
zhipuai_api_key are required parameters which you could get from
|
46 |
+
https://open.bigmodel.cn/usercenter/apikeys
|
47 |
+
|
48 |
+
Example:
|
49 |
+
.. code-block:: python
|
50 |
+
|
51 |
+
from langchain.llms import ZhipuAILLM
|
52 |
+
zhipuai_model = ZhipuAILLM(model="chatglm_std", temperature=temperature)
|
53 |
+
|
54 |
+
"""
|
55 |
+
|
56 |
+
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
57 |
+
|
58 |
+
client: Any
|
59 |
+
|
60 |
+
model: str = "chatglm_std"
|
61 |
+
"""Model name in chatglm_pro, chatglm_std, chatglm_lite. """
|
62 |
+
|
63 |
+
zhipuai_api_key: Optional[str] = None
|
64 |
+
|
65 |
+
incremental: Optional[bool] = True
|
66 |
+
"""Whether to incremental the results or not."""
|
67 |
+
|
68 |
+
streaming: Optional[bool] = False
|
69 |
+
"""Whether to streaming the results or not."""
|
70 |
+
# streaming = -incremental
|
71 |
+
|
72 |
+
request_timeout: Optional[int] = 60
|
73 |
+
"""request timeout for chat http requests"""
|
74 |
+
|
75 |
+
top_p: Optional[float] = 0.8
|
76 |
+
temperature: Optional[float] = 0.95
|
77 |
+
request_id: Optional[float] = None
|
78 |
+
|
79 |
+
@root_validator()
|
80 |
+
def validate_enviroment(cls, values: Dict) -> Dict:
|
81 |
+
|
82 |
+
values["zhipuai_api_key"] = get_from_dict_or_env(
|
83 |
+
values,
|
84 |
+
"zhipuai_api_key",
|
85 |
+
"ZHIPUAI_API_KEY",
|
86 |
+
)
|
87 |
+
|
88 |
+
params = {
|
89 |
+
"zhipuai_api_key": values["zhipuai_api_key"],
|
90 |
+
"model": values["model"],
|
91 |
+
}
|
92 |
+
try:
|
93 |
+
import zhipuai
|
94 |
+
|
95 |
+
zhipuai.api_key = values["zhipuai_api_key"]
|
96 |
+
values["client"] = zhipuai.model_api
|
97 |
+
except ImportError:
|
98 |
+
raise ValueError(
|
99 |
+
"zhipuai package not found, please install it with "
|
100 |
+
"`pip install zhipuai`"
|
101 |
+
)
|
102 |
+
return values
|
103 |
+
|
104 |
+
@property
|
105 |
+
def _identifying_params(self) -> Dict[str, Any]:
|
106 |
+
return {
|
107 |
+
**{"model": self.model},
|
108 |
+
**super()._identifying_params,
|
109 |
+
}
|
110 |
+
|
111 |
+
@property
|
112 |
+
def _llm_type(self) -> str:
|
113 |
+
"""Return type of llm."""
|
114 |
+
return "zhipuai"
|
115 |
+
|
116 |
+
@property
|
117 |
+
def _default_params(self) -> Dict[str, Any]:
|
118 |
+
"""Get the default parameters for calling OpenAI API."""
|
119 |
+
normal_params = {
|
120 |
+
"streaming": self.streaming,
|
121 |
+
"top_p": self.top_p,
|
122 |
+
"temperature": self.temperature,
|
123 |
+
"request_id": self.request_id,
|
124 |
+
}
|
125 |
+
|
126 |
+
return {**normal_params, **self.model_kwargs}
|
127 |
+
|
128 |
+
def _convert_prompt_msg_params(
|
129 |
+
self,
|
130 |
+
prompt: str,
|
131 |
+
**kwargs: Any,
|
132 |
+
) -> dict:
|
133 |
+
return {
|
134 |
+
**{"prompt": prompt, "model": self.model},
|
135 |
+
**self._default_params,
|
136 |
+
**kwargs,
|
137 |
+
}
|
138 |
+
|
139 |
+
def _call(
|
140 |
+
self,
|
141 |
+
prompt: str,
|
142 |
+
stop: Optional[List[str]] = None,
|
143 |
+
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
144 |
+
**kwargs: Any,
|
145 |
+
) -> str:
|
146 |
+
"""Call out to an zhipuai models endpoint for each generation with a prompt.
|
147 |
+
Args:
|
148 |
+
prompt: The prompt to pass into the model.
|
149 |
+
Returns:
|
150 |
+
The string generated by the model.
|
151 |
+
|
152 |
+
Example:
|
153 |
+
.. code-block:: python
|
154 |
+
response = zhipuai_model("Tell me a joke.")
|
155 |
+
"""
|
156 |
+
if self.streaming:
|
157 |
+
completion = ""
|
158 |
+
for chunk in self._stream(prompt, stop, run_manager, **kwargs):
|
159 |
+
completion += chunk.text
|
160 |
+
return completion
|
161 |
+
params = self._convert_prompt_msg_params(prompt, **kwargs)
|
162 |
+
|
163 |
+
response_payload = self.client.invoke(**params)
|
164 |
+
return response_payload["data"]["choices"][-1]["content"].strip('"').strip(" ")
|
165 |
+
|
166 |
+
async def _acall(
|
167 |
+
self,
|
168 |
+
prompt: str,
|
169 |
+
stop: Optional[List[str]] = None,
|
170 |
+
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
171 |
+
**kwargs: Any,
|
172 |
+
) -> str:
|
173 |
+
if self.streaming:
|
174 |
+
completion = ""
|
175 |
+
async for chunk in self._astream(prompt, stop, run_manager, **kwargs):
|
176 |
+
completion += chunk.text
|
177 |
+
return completion
|
178 |
+
|
179 |
+
params = self._convert_prompt_msg_params(prompt, **kwargs)
|
180 |
+
|
181 |
+
response = await self.client.async_invoke(**params)
|
182 |
+
|
183 |
+
return response_payload
|
184 |
+
|
185 |
+
def _stream(
|
186 |
+
self,
|
187 |
+
prompt: str,
|
188 |
+
stop: Optional[List[str]] = None,
|
189 |
+
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
190 |
+
**kwargs: Any,
|
191 |
+
) -> Iterator[GenerationChunk]:
|
192 |
+
params = self._convert_prompt_msg_params(prompt, **kwargs)
|
193 |
+
|
194 |
+
for res in self.client.invoke(**params):
|
195 |
+
if res:
|
196 |
+
chunk = GenerationChunk(text=res)
|
197 |
+
yield chunk
|
198 |
+
if run_manager:
|
199 |
+
run_manager.on_llm_new_token(chunk.text)
|
200 |
+
|
201 |
+
async def _astream(
|
202 |
+
|
203 |
+
self,
|
204 |
+
prompt: str,
|
205 |
+
stop: Optional[List[str]] = None,
|
206 |
+
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
207 |
+
**kwargs: Any,
|
208 |
+
) -> AsyncIterator[GenerationChunk]:
|
209 |
+
params = self._convert_prompt_msg_params(prompt, **kwargs)
|
210 |
+
|
211 |
+
async for res in await self.client.ado(**params):
|
212 |
+
if res:
|
213 |
+
chunk = GenerationChunk(text=res["data"]["choices"]["content"])
|
214 |
+
|
215 |
+
yield chunk
|
216 |
+
if run_manager:
|
217 |
+
await run_manager.on_llm_new_token(chunk.text)
|
prompts.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""define some usefull prompts"""
|
2 |
+
|
3 |
+
trans_prompt = """
|
4 |
+
Please translate the word into chinese: {} """
|
5 |
+
|
6 |
+
query_prompt = """
|
7 |
+
Please translate the word into chinese \
|
8 |
+
and give english example sentence and sentence meaning in Chinese:
|
9 |
+
{} """
|
10 |
+
|
11 |
+
learn_prompt = """
|
12 |
+
Please give an english sentence contains these words: {} \
|
13 |
+
and give sentence meaning in Chinese.
|
14 |
+
"""
|
15 |
+
|
16 |
+
wash_prompt = """
|
17 |
+
Please help me to remove the duplicate words and meaningless words from
|
18 |
+
the words list: {} and return the result.
|
19 |
+
|
20 |
+
For example:
|
21 |
+
User: apple, red, green, yellow, red, blll><, :*&
|
22 |
+
Your return: apple, red, green, yellow
|
23 |
+
"""
|
24 |
+
#*******************************************************
|
25 |
+
#************** user prompts ***************************
|
26 |
+
#*******************************************************
|
27 |
+
user_message_mapper = """
|
28 |
+
Please classify the following categories and return the name of the category. If your think user's requirement does not belong to the following categories, return the number 0.
|
29 |
+
Here are four examples:
|
30 |
+
User: I want to know the meaning of the word "apple"
|
31 |
+
"""
|
32 |
+
|
33 |
+
#*******************************************************
|
34 |
+
#************** system prompts *************************
|
35 |
+
#*******************************************************
|
36 |
+
|
37 |
+
|
38 |
+
# Using few shots learning to map the user's requirement to a command
|
39 |
+
system_message_mapper = """
|
40 |
+
You will receive the user's requirement. \
|
41 |
+
Please classify the user's requirement into the following categories \
|
42 |
+
and return the name of the category.
|
43 |
+
If your think user's requirement does not belong to the \
|
44 |
+
following categories, return the number 0.
|
45 |
+
|
46 |
+
Here are four examples:
|
47 |
+
|
48 |
+
User: I want to know the meaning of the word "apple"
|
49 |
+
Your return: :query apple
|
50 |
+
|
51 |
+
User: I want to add the word "apple" to my dictionary
|
52 |
+
Your return: :add apple
|
53 |
+
|
54 |
+
User: I want to remove the word "apple" from my dictionary
|
55 |
+
Your return: :remove apple
|
56 |
+
|
57 |
+
User: I want to learn the word "apple" based words in my dictionary now
|
58 |
+
Your return: :learn
|
59 |
+
|
60 |
+
User: I already learnt the meaning of the word "apple"
|
61 |
+
Your return: :remove apple
|
62 |
+
|
63 |
+
User: I want to know the rest words in my dictionary
|
64 |
+
Your return: :show
|
65 |
+
|
66 |
+
Here are the meaning of four categories for you to reference:
|
67 |
+
:query word means user want to study the meaning of word
|
68 |
+
:add word means user want to add word into his dictionary
|
69 |
+
:remove word means user want to remove word from his dictionary
|
70 |
+
:learn word means user want to learn words in his dictionary related to the word
|
71 |
+
:show means user want to show all words in his dictionary
|
72 |
+
|
73 |
+
"""
|
74 |
+
|
75 |
+
# using few shots learning to select the important words in a sentence
|
76 |
+
system_message_select = """
|
77 |
+
You will receive the english sentence from user, please select the words in the sentence that you think
|
78 |
+
is important for the user to understand the meaning of the sentence.
|
79 |
+
If the sentence does not contain any english words, please just return number 0.
|
80 |
+
|
81 |
+
For example:
|
82 |
+
User: The apple is red, and it's green.
|
83 |
+
Your return: apple, red
|
84 |
+
"""
|
requirements.txt
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
aiofiles==23.2.1
|
2 |
+
aiohttp==3.8.6
|
3 |
+
aiosignal==1.3.1
|
4 |
+
altair==5.1.2
|
5 |
+
annotated-types==0.6.0
|
6 |
+
anyio==4.0.0
|
7 |
+
asttokens==2.2.1
|
8 |
+
async-timeout==4.0.3
|
9 |
+
attrs==23.1.0
|
10 |
+
backcall==0.2.0
|
11 |
+
backoff==2.2.1
|
12 |
+
cachetools==5.3.2
|
13 |
+
certifi==2023.7.22
|
14 |
+
chardet==5.2.0
|
15 |
+
charset-normalizer==3.3.2
|
16 |
+
chromadb==0.3.29
|
17 |
+
click==8.1.7
|
18 |
+
clickhouse-connect==0.6.20
|
19 |
+
coloredlogs==15.0.1
|
20 |
+
comm==0.1.3
|
21 |
+
contourpy==1.2.0
|
22 |
+
cycler==0.12.1
|
23 |
+
dataclasses==0.6
|
24 |
+
dataclasses-json==0.5.14
|
25 |
+
debugpy==1.6.6
|
26 |
+
decorator==5.1.1
|
27 |
+
duckdb==0.9.1
|
28 |
+
entrypoints==0.4
|
29 |
+
exceptiongroup==1.0.4
|
30 |
+
executing==1.2.0
|
31 |
+
fastapi==0.85.1
|
32 |
+
ffmpy==0.3.1
|
33 |
+
filelock==3.13.1
|
34 |
+
filetype==1.2.0
|
35 |
+
flatbuffers==23.5.26
|
36 |
+
fonttools==4.44.0
|
37 |
+
frozenlist==1.4.0
|
38 |
+
fsspec==2023.10.0
|
39 |
+
gradio==3.40.1
|
40 |
+
gradio_client==0.7.0
|
41 |
+
greenlet==3.0.1
|
42 |
+
h11==0.14.0
|
43 |
+
hnswlib==0.7.0
|
44 |
+
httpcore==1.0.2
|
45 |
+
httptools==0.6.1
|
46 |
+
httpx==0.25.1
|
47 |
+
huggingface-hub==0.17.3
|
48 |
+
humanfriendly==10.0
|
49 |
+
idna==3.4
|
50 |
+
importlib-metadata==6.0.0
|
51 |
+
importlib-resources==6.1.1
|
52 |
+
ipykernel==6.23.1
|
53 |
+
ipython==8.10.0
|
54 |
+
jedi==0.18.2
|
55 |
+
jieba==0.42.1
|
56 |
+
Jinja2==3.1.2
|
57 |
+
joblib==1.3.2
|
58 |
+
jsonschema==4.19.2
|
59 |
+
jsonschema-specifications==2023.7.1
|
60 |
+
jupyter_client==8.6.0
|
61 |
+
jupyter_core==5.3.0
|
62 |
+
kiwisolver==1.4.5
|
63 |
+
langchain==0.0.292
|
64 |
+
langsmith==0.0.63
|
65 |
+
linkify-it-py==2.0.2
|
66 |
+
lxml==4.9.3
|
67 |
+
lz4==4.3.2
|
68 |
+
Markdown==3.4.3
|
69 |
+
markdown-it-py==2.2.0
|
70 |
+
MarkupSafe==2.1.3
|
71 |
+
marshmallow==3.20.1
|
72 |
+
matplotlib==3.8.1
|
73 |
+
matplotlib-inline==0.1.3
|
74 |
+
mdit-py-plugins==0.3.3
|
75 |
+
mdurl==0.1.2
|
76 |
+
monotonic==1.6
|
77 |
+
mpmath==1.3.0
|
78 |
+
multidict==6.0.4
|
79 |
+
mypy-extensions==1.0.0
|
80 |
+
nest-asyncio==1.5.4
|
81 |
+
nltk==3.8.1
|
82 |
+
numexpr==2.8.7
|
83 |
+
numpy==1.26.2
|
84 |
+
onnxruntime==1.16.2
|
85 |
+
openai==0.27.6
|
86 |
+
orjson==3.9.10
|
87 |
+
overrides==7.4.0
|
88 |
+
packaging==23.1
|
89 |
+
pandas==2.1.3
|
90 |
+
parso==0.8.3
|
91 |
+
pexpect==4.8.0
|
92 |
+
pickleshare==0.7.5
|
93 |
+
Pillow==10.1.0
|
94 |
+
pip==23.3
|
95 |
+
platformdirs==3.0.0
|
96 |
+
posthog==3.0.2
|
97 |
+
prompt-toolkit==3.0.38
|
98 |
+
protobuf==4.25.0
|
99 |
+
psutil==5.9.4
|
100 |
+
ptyprocess==0.7.0
|
101 |
+
pulsar-client==3.3.0
|
102 |
+
pure-eval==0.2.2
|
103 |
+
pydantic==1.10.10
|
104 |
+
pydantic_core==2.10.1
|
105 |
+
pydub==0.25.1
|
106 |
+
Pygments==2.14.0
|
107 |
+
PyJWT==2.8.0
|
108 |
+
PyMuPDF==1.23.6
|
109 |
+
PyMuPDFb==1.23.6
|
110 |
+
pyparsing==3.1.1
|
111 |
+
python-dateutil==2.8.2
|
112 |
+
python-dotenv==1.0.0
|
113 |
+
python-magic==0.4.27
|
114 |
+
python-multipart==0.0.6
|
115 |
+
pytz==2023.3.post1
|
116 |
+
PyYAML==6.0.1
|
117 |
+
pyzmq==25.1.0
|
118 |
+
referencing==0.30.2
|
119 |
+
regex==2023.5.5
|
120 |
+
requests==2.31.0
|
121 |
+
rouge-chinese==1.0.3
|
122 |
+
rpds-py==0.12.0
|
123 |
+
scikit-learn==1.2.2
|
124 |
+
scipy==1.11.2
|
125 |
+
semantic-version==2.10.0
|
126 |
+
setuptools==68.0.0
|
127 |
+
six==1.16.0
|
128 |
+
sniffio==1.3.0
|
129 |
+
SQLAlchemy==2.0.23
|
130 |
+
stack-data==0.6.2
|
131 |
+
starlette==0.20.4
|
132 |
+
sympy==1.12
|
133 |
+
tabulate==0.9.0
|
134 |
+
tenacity==8.2.3
|
135 |
+
threadpoolctl==3.1.0
|
136 |
+
tokenizers==0.14.1
|
137 |
+
toolz==0.12.0
|
138 |
+
tornado==6.3.3
|
139 |
+
tqdm==4.66.1
|
140 |
+
traitlets==5.9.0
|
141 |
+
typing_extensions==4.7.1
|
142 |
+
typing-inspect==0.9.0
|
143 |
+
tzdata==2023.3
|
144 |
+
uc-micro-py==1.0.2
|
145 |
+
unstructured==0.9.0
|
146 |
+
urllib3==2.0.7
|
147 |
+
uvicorn==0.24.0.post1
|
148 |
+
uvloop==0.19.0
|
149 |
+
watchfiles==0.21.0
|
150 |
+
wcwidth==0.2.5
|
151 |
+
websocket-client==1.5.2
|
152 |
+
websockets==11.0.3
|
153 |
+
wheel==0.41.2
|
154 |
+
yarl==1.9.2
|
155 |
+
zhipuai==1.0.7
|
156 |
+
zipp==3.11.0
|
157 |
+
zstandard==0.22.0
|
test/__pycache__/test_create_db.cpython-311-pytest-8.2.0.pyc
ADDED
Binary file (7.87 kB). View file
|
|
test/test_create_db.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List
|
2 |
+
from pytest import fixture
|
3 |
+
from create_db import split_text
|
4 |
+
|
5 |
+
|
6 |
+
@fixture
|
7 |
+
def sample_text():
|
8 |
+
return [
|
9 |
+
"Lorem ipsum dolor sit amet, consectetur adipiscing elit. "
|
10 |
+
"Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. "
|
11 |
+
"Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. "
|
12 |
+
"Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. "
|
13 |
+
"Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.",
|
14 |
+
"Another long text string to demonstrate the splitting functionality. This text should also be split into multiple chunks."
|
15 |
+
]
|
16 |
+
|
17 |
+
|
18 |
+
def test_split_text(sample_text):
|
19 |
+
# Split the sample text into chunks
|
20 |
+
chunks = split_text(sample_text)
|
21 |
+
|
22 |
+
# Assert that the chunks are lists of strings
|
23 |
+
assert all(
|
24 |
+
isinstance(chunk, list) and all(
|
25 |
+
isinstance(text, str) for text in chunk) for chunk in chunks)
|
26 |
+
|
27 |
+
# Assert that the chunks are not empty
|
28 |
+
assert all(chunk for chunk in chunks)
|
29 |
+
|
30 |
+
# Assert that the chunks have the expected length (approx. 1500 characters with 150 overlap)
|
31 |
+
expected_length = 1500 - 150 # Subtracting the overlap size
|
32 |
+
assert all(expected_length <= len(''.join(chunk)) < 1500
|
33 |
+
for chunk in chunks)
|
34 |
+
|
35 |
+
# Assert that the chunks contain the original text
|
36 |
+
original_text = ' '.join(sample_text)
|
37 |
+
assert all(text in original_text for chunk in chunks for text in chunk)
|
38 |
+
|
39 |
+
# Assert that the chunks do not overlap (except for the overlap size)
|
40 |
+
for i in range(len(chunks) - 1):
|
41 |
+
previous_chunk = chunks[i]
|
42 |
+
next_chunk = chunks[i + 1]
|
43 |
+
overlap = ''.join(set(previous_chunk[-150:]) & set(next_chunk[:150]))
|
44 |
+
assert len(overlap) == 150 or not overlap
|
words_db.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sqlite3
|
2 |
+
from loguru import logger
|
3 |
+
|
4 |
+
DB_PATH = './database/word_database.db'
|
5 |
+
|
6 |
+
class WordsDB(object):
|
7 |
+
def __init__(self):
|
8 |
+
logger.info('Initialized words database.')
|
9 |
+
|
10 |
+
def _connect_db(self):
|
11 |
+
conn = sqlite3.connect(
|
12 |
+
DB_PATH,
|
13 |
+
timeout=10,
|
14 |
+
check_same_thread=False)
|
15 |
+
cursor = conn.cursor()
|
16 |
+
cursor.execute('''
|
17 |
+
CREATE TABLE IF NOT EXISTS words (
|
18 |
+
id INTEGER PRIMARY KEY,
|
19 |
+
word TEXT NOT NULL,
|
20 |
+
definition TEXT NOT NULL
|
21 |
+
)
|
22 |
+
''')
|
23 |
+
return conn, conn.cursor()
|
24 |
+
|
25 |
+
|
26 |
+
def add_word(self, word, definition):
|
27 |
+
self.conn, self.cursor = self._connect_db()
|
28 |
+
self.cursor.execute('INSERT INTO words (word, definition) VALUES (?, ?)', (word, definition))
|
29 |
+
self.conn.commit()
|
30 |
+
self.cursor.close()
|
31 |
+
self.conn.close()
|
32 |
+
|
33 |
+
|
34 |
+
def delete_word(self, word):
|
35 |
+
self.conn, self.cursor = self._connect_db()
|
36 |
+
self.cursor.execute('DELETE FROM words WHERE word = ?', (word,))
|
37 |
+
self.conn.commit()
|
38 |
+
self.cursor.close()
|
39 |
+
self.conn.close()
|
40 |
+
|
41 |
+
def update_word(self, word, new_definition):
|
42 |
+
self.conn, self.cursor = self._connect_db()
|
43 |
+
self.cursor.execute('UPDATE words SET definition = ? WHERE word = ?', (new_definition, word))
|
44 |
+
self.conn.commit()
|
45 |
+
self.cursor.close()
|
46 |
+
self.conn.close()
|
47 |
+
|
48 |
+
def query_word(self):
|
49 |
+
self.conn, self.cursor = self._connect_db()
|
50 |
+
self.cursor.execute('SELECT * FROM words')
|
51 |
+
res = self.cursor.fetchall()
|
52 |
+
self.cursor.close()
|
53 |
+
self.conn.close()
|
54 |
+
return res
|
55 |
+
|
56 |
+
|
57 |
+
words_db = WordsDB()
|
58 |
+
|
59 |
+
if __name__ == '__main__':
|
60 |
+
|
61 |
+
words_db.add_word('apple', '苹果')
|
62 |
+
words_db.add_word('banana', '香蕉')
|
63 |
+
words_db.update_word('banana', '新的香蕉')
|
64 |
+
result = words_db.query_word('banana')
|
65 |
+
print(result)
|
66 |
+
|
67 |
+
|