dht-tb16p commited on
Commit
e60c070
1 Parent(s): 7a0d66d

Commit 1st version

Browse files
Files changed (46) hide show
  1. __init__.py +0 -0
  2. app.py +38 -0
  3. chat.py +171 -0
  4. config.py +1 -0
  5. create_db.py +149 -0
  6. database/__pycache__/create_db.cpython-311.pyc +0 -0
  7. database/word_database.db +0 -0
  8. embedding/__init__.py +1 -0
  9. embedding/__pycache__/__init__.cpython-310.pyc +0 -0
  10. embedding/__pycache__/__init__.cpython-311.pyc +0 -0
  11. embedding/__pycache__/__init__.cpython-39.pyc +0 -0
  12. embedding/__pycache__/call_embedding.cpython-310.pyc +0 -0
  13. embedding/__pycache__/call_embedding.cpython-311.pyc +0 -0
  14. embedding/__pycache__/call_embedding.cpython-39.pyc +0 -0
  15. embedding/__pycache__/zhipuai_embedding.cpython-310.pyc +0 -0
  16. embedding/__pycache__/zhipuai_embedding.cpython-311.pyc +0 -0
  17. embedding/__pycache__/zhipuai_embedding.cpython-39.pyc +0 -0
  18. embedding/call_embedding.py +20 -0
  19. embedding/zhipuai_embedding.py +112 -0
  20. llm/__pycache__/call_llm.cpython-310.pyc +0 -0
  21. llm/__pycache__/call_llm.cpython-311.pyc +0 -0
  22. llm/__pycache__/call_llm.cpython-39.pyc +0 -0
  23. llm/__pycache__/self_llm.cpython-310.pyc +0 -0
  24. llm/__pycache__/self_llm.cpython-311.pyc +0 -0
  25. llm/__pycache__/self_llm.cpython-39.pyc +0 -0
  26. llm/__pycache__/spark_llm.cpython-310.pyc +0 -0
  27. llm/__pycache__/spark_llm.cpython-311.pyc +0 -0
  28. llm/__pycache__/spark_llm.cpython-39.pyc +0 -0
  29. llm/__pycache__/wenxin_llm.cpython-310.pyc +0 -0
  30. llm/__pycache__/wenxin_llm.cpython-311.pyc +0 -0
  31. llm/__pycache__/wenxin_llm.cpython-39.pyc +0 -0
  32. llm/__pycache__/wenxin_llm_.cpython-310.pyc +0 -0
  33. llm/__pycache__/zhipuai_llm.cpython-310.pyc +0 -0
  34. llm/__pycache__/zhipuai_llm.cpython-311.pyc +0 -0
  35. llm/__pycache__/zhipuai_llm.cpython-39.pyc +0 -0
  36. llm/call_llm.py +330 -0
  37. llm/self_llm.py +47 -0
  38. llm/spark_llm.py +227 -0
  39. llm/test.ipynb +312 -0
  40. llm/wenxin_llm.py +90 -0
  41. llm/zhipuai_llm.py +217 -0
  42. prompts.py +84 -0
  43. requirements.txt +157 -0
  44. test/__pycache__/test_create_db.cpython-311-pytest-8.2.0.pyc +0 -0
  45. test/test_create_db.py +44 -0
  46. words_db.py +67 -0
__init__.py ADDED
File without changes
app.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import openai
3
+
4
+ from create_db import create_db
5
+ from chat import respond
6
+
7
+ openai.api_base = "https://api.v36.cm/v1" # global setting is needed
8
+
9
+ def launch_app():
10
+
11
+ with gr.Blocks() as demo:
12
+ with gr.Row(equal_height=True):
13
+ gr.Markdown("## 英语单词学习工具")
14
+
15
+ with gr.Row():
16
+ with gr.Column(scale=4):
17
+ chatbot = gr.Chatbot(height=400)
18
+ msg = gr.Textbox(label="在此输入指令(以:起始)或对话")
19
+ btn = gr.Button("Submit")
20
+ gr.ClearButton(components=[msg, chatbot], value="清除对话")
21
+ btn.click(respond, inputs=[msg, chatbot], outputs=[msg, chatbot])
22
+ msg.submit(respond, inputs=[msg, chatbot], outputs=[msg, chatbot])
23
+ with gr.Column(scale=1):
24
+ file = gr.File(label='请导入制作词库的文件', file_count='single', file_types=['.md', '.pdf'])
25
+ with gr.Row():
26
+ init_vocab_by_file = gr.Button("使用文件生成个人词库")
27
+ text = gr.Textbox(label="在此粘贴制作词库的文本", lines=8)
28
+ with gr.Row():
29
+ init_vocab_by_text = gr.Button("使用文本生成个人词库")
30
+ init_vocab_by_file.click(create_db, inputs=[file, chatbot], outputs=[chatbot])
31
+ init_vocab_by_text.click(create_db, inputs=[text, chatbot], outputs=[chatbot])
32
+
33
+ gr.close_all()
34
+ demo.launch()
35
+
36
+ if __name__ == "__main__":
37
+ launch_app()
38
+
chat.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from typing import List
3
+ from loguru import logger
4
+
5
+ import config
6
+ from llm.call_llm import get_completion, get_completion_from_messages
7
+ from words_db import words_db
8
+ from create_db import get_similar_k_words
9
+ from prompts import trans_prompt, query_prompt, learn_prompt
10
+ from prompts import system_message_mapper
11
+
12
+
13
+ def format_common_prompt(raw_prompt, variable):
14
+ """get format prompt by repalce variable in raw_prompt
15
+ """
16
+ return raw_prompt.format(variable)
17
+
18
+ def format_chat_prompt(message, chat_history) -> str:
19
+ """get format prompt
20
+ """
21
+ prompt = ""
22
+ for turn in chat_history: # add history info
23
+ user_message, bot_message = turn
24
+ prompt = f"{prompt}\nUser: {user_message}\nAssistant: {bot_message}"
25
+ prompt = f"{prompt}\nUser: {message}\nAssistant:"
26
+ return prompt
27
+
28
+
29
+ def respond(message, chat_history,
30
+ llm="gpt-3.5-turbo", history_len=3, temperature=0.1, max_tokens=2048):
31
+ """get respond from LLM
32
+ """
33
+
34
+ # deal with commands
35
+ respond_message = command_parser(message)
36
+ if respond_message:
37
+ chat_history.append((message, respond_message))
38
+ respond_message = ""
39
+ return respond_message, chat_history
40
+
41
+ # map natural language to command
42
+ respond_message = command_mapper(message)
43
+ if respond_message:
44
+ chat_history.append((message, respond_message))
45
+ respond_message = ""
46
+ return respond_message, chat_history
47
+
48
+ # no commands return, so chat with LLM
49
+ if message is None or len(message) < 1:
50
+ return "", chat_history
51
+ try:
52
+ chat_history = chat_history[-history_len:] if history_len > 0 else [] # constrain history length
53
+ formatted_prompt = format_chat_prompt(message, chat_history) # format prompt
54
+ bot_message = get_completion(
55
+ formatted_prompt,
56
+ llm,
57
+ api_key=config.api_key,
58
+ temperature=temperature, max_tokens=max_tokens)
59
+ bot_message = re.sub(r"\\n", '<br/>', bot_message) # replace \n with <br/>
60
+ chat_history.append((message, bot_message))
61
+ return "", chat_history
62
+ except Exception as e:
63
+ return e, chat_history
64
+
65
+ def command_parser(input: str) -> str:
66
+ """parse 4 type commands
67
+ 1. :add
68
+ 2. :remove
69
+ 3. :learn
70
+ 4. :query
71
+ return info of action to user
72
+ """
73
+ if input.startswith(":add"):
74
+ words = input.split(" ")[1:]
75
+ info = add_words(words)
76
+ return info
77
+ if input.startswith(":remove"):
78
+ words = input.split(" ")[1:]
79
+ info = remove_words(words)
80
+ return info
81
+ if input.startswith(":learn"):
82
+ if len(input.split(" ")) != 2:
83
+ return "学习模式将基于词库进行,请指定一个query单词"
84
+ query = input.split(" ")[1]
85
+ info = learn_words(query)
86
+ return f"Based on your query word: {query} and dictionary, learning sentence is:\n{info}"
87
+ if input.startswith(":query"):
88
+ if len(input.split(" ")) > 2:
89
+ return "查询模式仅支持单个单词,请使用:query <word>进行查询"
90
+ word = input.split(" ")[1]
91
+ info = query_word(word)
92
+ return f"{word}\n{info}"
93
+ if input.startswith(":show"):
94
+ info = show_all_words()
95
+ return info
96
+ if input.startswith(":help"):
97
+ return "目前支持的指令有:\n:add <word1> <word2> ...\n:remove <word1> <word2> ...\n:learn <query_word>\n:query <word>"
98
+
99
+ return ""
100
+
101
+
102
+ def show_all_words() -> str:
103
+ """show all words in db
104
+ """
105
+ try:
106
+ all_words = words_db.query_word()
107
+ return f"目前词库中的所有单词:\n{all_words}"
108
+ except Exception as e:
109
+ logger.error(str(e))
110
+ return "查询失败"
111
+
112
+ def add_words(input: List[str]):
113
+ word_tuple_list = [
114
+ (word, get_completion(
115
+ prompt=format_common_prompt(trans_prompt, word),
116
+ api_key=config.api_key))
117
+ for word in input
118
+ ]
119
+ try:
120
+ for word_tuple in word_tuple_list:
121
+ word, definition = word_tuple
122
+ words_db.add_word(word, definition)
123
+ logger.info(f"已经添加单词: {word} 和其释义: {definition}")
124
+ except Exception as e:
125
+ logger.error(str(e))
126
+ return f"添加单词失败: {input}"
127
+ return f"已添加单词: {input}"
128
+
129
+ def remove_words(input: List[str]):
130
+ try:
131
+ for word in input:
132
+ words_db.delete_word(word)
133
+ logger.info(f"已经删除单词: {word} 和其释义")
134
+ except Exception as e:
135
+ logger.error(str(e))
136
+ return f"删除单词失败: {input}"
137
+ return f"已删除单词: {input}"
138
+
139
+ def learn_words(query_word) -> str:
140
+ # get top 3 words from vec db and generate material
141
+ words = get_similar_k_words(query_word)
142
+ respond = get_completion(
143
+ prompt=format_common_prompt(learn_prompt, words),
144
+ api_key=config.api_key)
145
+ logger.info(f"进入学习模式,学习下列单词: {words}")
146
+ return respond
147
+
148
+ def query_word(input: str) -> str:
149
+ # just query infomation about a word from gpt
150
+ respond = get_completion(
151
+ prompt=format_common_prompt(query_prompt, input),
152
+ api_key=config.api_key)
153
+ logger.info(f"查询单词: {input}")
154
+ return respond
155
+
156
+ def command_mapper(input: str) -> str:
157
+ """map natural language to command, return command function
158
+ """
159
+ user_message = input
160
+ messages = [
161
+ {'role':'system',
162
+ 'content': system_message_mapper},
163
+ {'role':'user',
164
+ 'content': f"{user_message}"},
165
+ ]
166
+ respond = get_completion_from_messages(messages, api_key=config)
167
+ mapped_command = command_parser(respond)
168
+ logger.info(f"用户输入: {user_message}\n指令解析器输出: {mapped_command}")
169
+
170
+ return mapped_command
171
+
config.py ADDED
@@ -0,0 +1 @@
 
 
1
+ api_key = "sk-wmeFmoPXQ9wlJKiP48B0F028E6534359A6980b9585Ba5bAc"
create_db.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ sys.path.append(os.path.dirname(os.path.dirname(__file__)))
4
+ import tempfile
5
+ import config
6
+ import nltk
7
+
8
+ from typing import List
9
+ from nltk.corpus import words
10
+ from loguru import logger
11
+ from llm.call_llm import get_completion_from_messages
12
+ from embedding.call_embedding import get_embedding
13
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
14
+ from langchain.document_loaders import PyMuPDFLoader
15
+ from langchain.vectorstores import Chroma
16
+
17
+ from prompts import system_message_select
18
+
19
+ WORDS_DB_PATH = "../words_db"
20
+ VECTOR_DB_PATH = "./vector_db/chroma"
21
+
22
+ def parse_file(file_path):
23
+ docs = []
24
+ # check file type
25
+ file_type = file_path.split('.')[-1]
26
+ if file_type == 'pdf':
27
+ loader = PyMuPDFLoader(file_path)
28
+ content = loader.load()
29
+ docs.extend(content)
30
+ else:
31
+ return "File type not supported"
32
+ if len(docs) > 5:
33
+ return "File too large, please select a pdf file with less than 5 pages"
34
+
35
+ slices = split_text(docs) # split content into slices
36
+ words = extract_words(slices) # extract words from slices
37
+ try:
38
+ vectorize_words(words) # store words into vector database
39
+ except Exception as e:
40
+ logger.error(e)
41
+
42
+ return ""
43
+
44
+ def parse_text(input: str):
45
+ content = input
46
+ return content
47
+
48
+ def split_text(docs: List[object]):
49
+ """Split text into slices"""
50
+ text_splitter = RecursiveCharacterTextSplitter(
51
+ chunk_size = 1500,
52
+ chunk_overlap = 150
53
+ )
54
+ splits = text_splitter.split_documents(docs)
55
+ logger.info(f"Split {len(docs)} pages document into {len(splits)} slices")
56
+ return splits
57
+
58
+ def extract_words(splits: List[object]):
59
+ """Extract words from slices"""
60
+ all_words = []
61
+ for slice in splits:
62
+ tmp_content = slice.page_content
63
+ messages = [
64
+ {'role':'system',
65
+ 'content': system_message_select},
66
+ {'role':'user',
67
+ 'content': f"{tmp_content}"},
68
+ ]
69
+ respond = get_completion_from_messages(messages, api_key=config.api_key)
70
+ words_list = respond.split(", ")
71
+ if len(words_list) == 0:
72
+ continue
73
+ else:
74
+ all_words.extend(words_list)
75
+ all_words = wash_words(all_words)
76
+ logger.info(f"Extract {len(all_words)} words from slices")
77
+ return all_words
78
+
79
+ def wash_words(input_words: list[str]):
80
+ """Wash words into a list of correct english words"""
81
+ words_list = [word for word in input_words
82
+ if len(word) >= 3 and len(word) <= 30]
83
+ nltk.download('words')
84
+ english_words = set(words.words())
85
+ filtered_words = [word.lower() for word in words_list if word.lower() in english_words]
86
+ filtered_words = list(set(filtered_words))
87
+ logger.info(f"Wash {len(filtered_words)} words into a list of correct english words")
88
+ return filtered_words
89
+
90
+ def get_words_from_text(input: str):
91
+ words = input.split(' ')
92
+ return words
93
+
94
+ def store_words(input: str, db_path=WORDS_DB_PATH):
95
+ """Store words into database"""
96
+ pass
97
+
98
+ def vectorize_words(input: list[str], embedding=None):
99
+ """Vectorize words into vectors"""
100
+ model = get_embedding("openai", embedding_key=config.api_key)
101
+ persist_path = VECTOR_DB_PATH
102
+ vectordb = Chroma.from_texts(
103
+ texts=input,
104
+ embedding=model,
105
+ persist_directory=persist_path
106
+ )
107
+ vectordb.persist()
108
+ logger.info(f"Vectorized {len(input)} words into vectors")
109
+ return vectordb
110
+
111
+ def get_similar_k_words(query_word, k=3) -> List[str]:
112
+ # get 3 simlilar words from DB
113
+ model = get_embedding("openai", embedding_key=config.api_key)
114
+ vectordb = Chroma(persist_directory=VECTOR_DB_PATH, embedding_function=model)
115
+ similar_words = vectordb.max_marginal_relevance_search(query_word, k=k)
116
+ similar_words = [word.page_content for word in similar_words]
117
+ logger.info(f"Get {k} similar words {similar_words} from DB")
118
+ return similar_words
119
+
120
+ def create_db(input, chat_history):
121
+ """The input is file or text"""
122
+ action_msg = "" # the description of user action: put file or text into database
123
+ # 1. for file upload
124
+ if isinstance(input, tempfile._TemporaryFileWrapper):
125
+ tmp_file_path = input.name
126
+ file_name = tmp_file_path.split('/')[-1]
127
+ action_msg = f"Add words from my file: {file_name} to database"
128
+ try:
129
+ parse_file(tmp_file_path) #TODO
130
+ output = f"Words from your file: {file_name} has been added to database"
131
+ except Exception as e:
132
+ logger.error(e)
133
+ output = f"Error: failed to use your file: {file_name} generate dictionary"
134
+ # 2. for text input
135
+ elif isinstance(input, str):
136
+ action_msg = f"Add words from my text: {input} to database"
137
+ try:
138
+ parse_text(input) #TODO
139
+ output = f"Words from your text: {input} has been added to database"
140
+ except Exception as e:
141
+ logger.error(e)
142
+ output = f"Error: failed to use your text: {input} generate dictionary"
143
+ chat_history.append((action_msg, output))
144
+
145
+ return chat_history
146
+
147
+
148
+ if __name__ == "__main__":
149
+ create_db(embeddings="m3e")
database/__pycache__/create_db.cpython-311.pyc ADDED
Binary file (3.33 kB). View file
 
database/word_database.db ADDED
Binary file (8.19 kB). View file
 
embedding/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # from .zhipuai_embedding import ZhipuAIEmbeddings
embedding/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (163 Bytes). View file
 
embedding/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (158 Bytes). View file
 
embedding/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (152 Bytes). View file
 
embedding/__pycache__/call_embedding.cpython-310.pyc ADDED
Binary file (956 Bytes). View file
 
embedding/__pycache__/call_embedding.cpython-311.pyc ADDED
Binary file (1.48 kB). View file
 
embedding/__pycache__/call_embedding.cpython-39.pyc ADDED
Binary file (937 Bytes). View file
 
embedding/__pycache__/zhipuai_embedding.cpython-310.pyc ADDED
Binary file (4.23 kB). View file
 
embedding/__pycache__/zhipuai_embedding.cpython-311.pyc ADDED
Binary file (5.45 kB). View file
 
embedding/__pycache__/zhipuai_embedding.cpython-39.pyc ADDED
Binary file (4.2 kB). View file
 
embedding/call_embedding.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ sys.path.append(os.path.dirname(os.path.dirname(__file__)))
4
+
5
+ from embedding.zhipuai_embedding import ZhipuAIEmbeddings
6
+ from langchain.embeddings.huggingface import HuggingFaceEmbeddings
7
+ from langchain.embeddings.openai import OpenAIEmbeddings
8
+ from llm.call_llm import parse_llm_api_key
9
+
10
+ def get_embedding(embedding: str, embedding_key: str=None, env_file: str=None):
11
+ if embedding == 'm3e':
12
+ return HuggingFaceEmbeddings(model_name="moka-ai/m3e-base")
13
+ if embedding_key is None:
14
+ embedding_key = parse_llm_api_key(embedding)
15
+ if embedding == "openai":
16
+ return OpenAIEmbeddings(openai_api_key=embedding_key)
17
+ elif embedding == "zhipuai":
18
+ return ZhipuAIEmbeddings(zhipuai_api_key=embedding_key)
19
+ else:
20
+ raise ValueError(f"embedding {embedding} not support ")
embedding/zhipuai_embedding.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ from typing import Any, Dict, List, Optional
5
+
6
+ from langchain.embeddings.base import Embeddings
7
+ from langchain.pydantic_v1 import BaseModel, root_validator
8
+ from langchain.utils import get_from_dict_or_env
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
+ class ZhipuAIEmbeddings(BaseModel, Embeddings):
14
+ """`Zhipuai Embeddings` embedding models."""
15
+
16
+ zhipuai_api_key: Optional[str] = None
17
+ """Zhipuai application apikey"""
18
+
19
+ @root_validator()
20
+ def validate_environment(cls, values: Dict) -> Dict:
21
+ """
22
+ Validate whether zhipuai_api_key in the environment variables or
23
+ configuration file are available or not.
24
+
25
+ Args:
26
+
27
+ values: a dictionary containing configuration information, must include the
28
+ fields of zhipuai_api_key
29
+ Returns:
30
+
31
+ a dictionary containing configuration information. If zhipuai_api_key
32
+ are not provided in the environment variables or configuration
33
+ file, the original values will be returned; otherwise, values containing
34
+ zhipuai_api_key will be returned.
35
+ Raises:
36
+
37
+ ValueError: zhipuai package not found, please install it with `pip install
38
+ zhipuai`
39
+ """
40
+ values["zhipuai_api_key"] = get_from_dict_or_env(
41
+ values,
42
+ "zhipuai_api_key",
43
+ "ZHIPUAI_API_KEY",
44
+ )
45
+
46
+ try:
47
+ import zhipuai
48
+ zhipuai.api_key = values["zhipuai_api_key"]
49
+ values["client"] = zhipuai.model_api
50
+
51
+ except ImportError:
52
+ raise ValueError(
53
+ "Zhipuai package not found, please install it with "
54
+ "`pip install zhipuai`"
55
+ )
56
+ return values
57
+
58
+ def _embed(self, texts: str) -> List[float]:
59
+ # send request
60
+ try:
61
+ resp = self.client.invoke(
62
+ model="text_embedding",
63
+ prompt=texts
64
+ )
65
+ except Exception as e:
66
+ raise ValueError(f"Error raised by inference endpoint: {e}")
67
+
68
+ if resp["code"] != 200:
69
+ raise ValueError(
70
+ "Error raised by inference API HTTP code: %s, %s"
71
+ % (resp["code"], resp["msg"])
72
+ )
73
+ embeddings = resp["data"]["embedding"]
74
+ return embeddings
75
+
76
+ def embed_query(self, text: str) -> List[float]:
77
+ """
78
+ Embedding a text.
79
+
80
+ Args:
81
+
82
+ Text (str): A text to be embedded.
83
+
84
+ Return:
85
+
86
+ List [float]: An embedding list of input text, which is a list of floating-point values.
87
+ """
88
+ resp = self.embed_documents([text])
89
+ return resp[0]
90
+
91
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
92
+ """
93
+ Embeds a list of text documents.
94
+
95
+ Args:
96
+ texts (List[str]): A list of text documents to embed.
97
+
98
+ Returns:
99
+ List[List[float]]: A list of embeddings for each document in the input list.
100
+ Each embedding is represented as a list of float values.
101
+ """
102
+ return [self._embed(text) for text in texts]
103
+
104
+ async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
105
+ """Asynchronous Embed search docs."""
106
+ raise NotImplementedError(
107
+ "Please use `embed_documents`. Official does not support asynchronous requests")
108
+
109
+ async def aembed_query(self, text: str) -> List[float]:
110
+ """Asynchronous Embed query text."""
111
+ raise NotImplementedError(
112
+ "Please use `aembed_query`. Official does not support asynchronous requests")
llm/__pycache__/call_llm.cpython-310.pyc ADDED
Binary file (8.17 kB). View file
 
llm/__pycache__/call_llm.cpython-311.pyc ADDED
Binary file (14.1 kB). View file
 
llm/__pycache__/call_llm.cpython-39.pyc ADDED
Binary file (8.17 kB). View file
 
llm/__pycache__/self_llm.cpython-310.pyc ADDED
Binary file (1.57 kB). View file
 
llm/__pycache__/self_llm.cpython-311.pyc ADDED
Binary file (2.11 kB). View file
 
llm/__pycache__/self_llm.cpython-39.pyc ADDED
Binary file (1.56 kB). View file
 
llm/__pycache__/spark_llm.cpython-310.pyc ADDED
Binary file (6.17 kB). View file
 
llm/__pycache__/spark_llm.cpython-311.pyc ADDED
Binary file (10.5 kB). View file
 
llm/__pycache__/spark_llm.cpython-39.pyc ADDED
Binary file (6.12 kB). View file
 
llm/__pycache__/wenxin_llm.cpython-310.pyc ADDED
Binary file (2.93 kB). View file
 
llm/__pycache__/wenxin_llm.cpython-311.pyc ADDED
Binary file (4.4 kB). View file
 
llm/__pycache__/wenxin_llm.cpython-39.pyc ADDED
Binary file (2.99 kB). View file
 
llm/__pycache__/wenxin_llm_.cpython-310.pyc ADDED
Binary file (3.76 kB). View file
 
llm/__pycache__/zhipuai_llm.cpython-310.pyc ADDED
Binary file (5.86 kB). View file
 
llm/__pycache__/zhipuai_llm.cpython-311.pyc ADDED
Binary file (8.4 kB). View file
 
llm/__pycache__/zhipuai_llm.cpython-39.pyc ADDED
Binary file (5.6 kB). View file
 
llm/call_llm.py ADDED
@@ -0,0 +1,330 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+ '''
4
+ @File : call_llm.py
5
+ @Time : 2023/10/18 10:45:00
6
+ @Author : Logan Zou
7
+ @Version : 1.0
8
+ @Contact : loganzou0421@163.com
9
+ @License : (C)Copyright 2017-2018, Liugroup-NLPR-CASIA
10
+ @Desc : 将各个大模型的原生接口封装在一个接口
11
+ '''
12
+
13
+ import openai
14
+ import json
15
+ import requests
16
+ import _thread as thread
17
+ import base64
18
+ # import datetime
19
+ from dotenv import load_dotenv, find_dotenv
20
+ import hashlib
21
+ import hmac
22
+ import os
23
+ import queue
24
+ from urllib.parse import urlparse
25
+ import ssl
26
+ from datetime import datetime
27
+ from time import mktime
28
+ from urllib.parse import urlencode
29
+ from wsgiref.handlers import format_date_time
30
+ import zhipuai
31
+ from langchain.utils import get_from_dict_or_env
32
+
33
+ import websocket # 使用websocket_client
34
+
35
+ def get_completion(prompt :str, model="gpt-3.5-turbo", temperature=0.1,api_key=None, secret_key=None, access_token=None, appid=None, api_secret=None, max_tokens=2048):
36
+ # 调用大模型获取回复,支持上述三种模型+gpt
37
+ # arguments:
38
+ # prompt: 输入提示
39
+ # model:模型名
40
+ # temperature: 温度系数
41
+ # api_key:如名
42
+ # secret_key, access_token:调用文心系列模型需要
43
+ # appid, api_secret: 调用星火系列模型需要
44
+ # max_tokens : 返回最长序列
45
+ # return: 模型返回,字符串
46
+ # 调用 GPT
47
+ if model in ["gpt-3.5-turbo", "gpt-3.5-turbo-16k-0613", "gpt-3.5-turbo-0613", "gpt-4", "gpt-4-32k"]:
48
+ return get_completion_gpt(prompt, model, temperature, api_key, max_tokens)
49
+ elif model in ["ERNIE-Bot", "ERNIE-Bot-4", "ERNIE-Bot-turbo"]:
50
+ return get_completion_wenxin(prompt, model, temperature, api_key, secret_key)
51
+ elif model in ["Spark-1.5", "Spark-2.0"]:
52
+ return get_completion_spark(prompt, model, temperature, api_key, appid, api_secret, max_tokens)
53
+ elif model in ["chatglm_pro", "chatglm_std", "chatglm_lite"]:
54
+ return get_completion_glm(prompt, model, temperature, api_key, max_tokens)
55
+ else:
56
+ return "不正确的模型"
57
+
58
+ def get_completion_gpt(prompt : str, model : str, temperature : float, api_key:str, max_tokens:int):
59
+ # 封装 OpenAI 原生接口
60
+ if api_key is None:
61
+ api_key = parse_llm_api_key("openai")
62
+ openai.api_key = api_key
63
+ # 具体调用
64
+ messages = [{"role": "user", "content": prompt}]
65
+ response = openai.ChatCompletion.create(
66
+ model=model,
67
+ messages=messages,
68
+ temperature=temperature, # 模型输出的温度系数,控制输出的随机程度
69
+ max_tokens = max_tokens, # 回复最大长度
70
+ )
71
+ # 调用 OpenAI 的 ChatCompletion 接口
72
+ return response.choices[0].message["content"]
73
+
74
+ def get_completion_from_messages(messages, api_key, temperature=0):
75
+ # 封装 OpenAI 原生接口
76
+ if api_key is None:
77
+ api_key = parse_llm_api_key("openai")
78
+ openai.api_key = api_key
79
+ response = openai.ChatCompletion.create(
80
+ model="gpt-3.5-turbo",
81
+ messages=messages,
82
+ temperature=temperature, # 控制模型输出的随机程度
83
+ )
84
+ return response.choices[0].message["content"]
85
+
86
+ def get_access_token(api_key, secret_key):
87
+ """
88
+ 使用 API Key,Secret Key 获取access_token,替换下列示例中的应用API Key、应用Secret Key
89
+ """
90
+ # 指定网址
91
+ url = f"https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id={api_key}&client_secret={secret_key}"
92
+ # 设置 POST 访问
93
+ payload = json.dumps("")
94
+ headers = {
95
+ 'Content-Type': 'application/json',
96
+ 'Accept': 'application/json'
97
+ }
98
+ # 通过 POST 访问获取账户对应的 access_token
99
+ response = requests.request("POST", url, headers=headers, data=payload)
100
+ return response.json().get("access_token")
101
+
102
+ def get_completion_wenxin(prompt : str, model : str, temperature : float, api_key:str, secret_key : str):
103
+ # 封装百度文心原生接口
104
+ if api_key is None or secret_key is None:
105
+ api_key, secret_key = parse_llm_api_key("wenxin")
106
+ # 获取access_token
107
+ access_token = get_access_token(api_key, secret_key)
108
+ # 调用接口
109
+ url = f"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant?access_token={access_token}"
110
+ # 配置 POST 参数
111
+ payload = json.dumps({
112
+ "messages": [
113
+ {
114
+ "role": "user",# user prompt
115
+ "content": "{}".format(prompt)# 输入的 prompt
116
+ }
117
+ ]
118
+ })
119
+ headers = {
120
+ 'Content-Type': 'application/json'
121
+ }
122
+ # 发起请求
123
+ response = requests.request("POST", url, headers=headers, data=payload)
124
+ # 返回的是一个 Json 字符串
125
+ js = json.loads(response.text)
126
+ return js["result"]
127
+
128
+ def get_completion_spark(prompt : str, model : str, temperature : float, api_key:str, appid : str, api_secret : str, max_tokens : int):
129
+ if api_key is None or appid is None and api_secret is None:
130
+ api_key, appid, api_secret = parse_llm_api_key("spark")
131
+
132
+ # 配置 1.5 和 2 的不同环境
133
+ if model == "Spark-1.5":
134
+ domain = "general"
135
+ Spark_url = "ws://spark-api.xf-yun.com/v1.1/chat" # v1.5环境的地址
136
+ else:
137
+ domain = "generalv2" # v2.0版本
138
+ Spark_url = "ws://spark-api.xf-yun.com/v2.1/chat" # v2.0环境的地址
139
+
140
+ question = [{"role":"user", "content":prompt}]
141
+ response = spark_main(appid,api_key,api_secret,Spark_url,domain,question,temperature,max_tokens)
142
+ return response
143
+
144
+ def get_completion_glm(prompt : str, model : str, temperature : float, api_key:str, max_tokens : int):
145
+ # 获取GLM回答
146
+ if api_key is None:
147
+ api_key = parse_llm_api_key("zhipuai")
148
+ zhipuai.api_key = api_key
149
+
150
+ response = zhipuai.model_api.invoke(
151
+ model=model,
152
+ prompt=[{"role":"user", "content":prompt}],
153
+ temperature = temperature,
154
+ max_tokens=max_tokens
155
+ )
156
+ return response["data"]["choices"][0]["content"].strip('"').strip(" ")
157
+
158
+ # def getText(role, content, text = []):
159
+ # # role 是指定角色,content 是 prompt 内容
160
+ # jsoncon = {}
161
+ # jsoncon["role"] = role
162
+ # jsoncon["content"] = content
163
+ # text.append(jsoncon)
164
+ # return text
165
+
166
+ # 星火 API 调用使用
167
+ answer = ""
168
+
169
+ class Ws_Param(object):
170
+ # 初始化
171
+ def __init__(self, APPID, APIKey, APISecret, Spark_url):
172
+ self.APPID = APPID
173
+ self.APIKey = APIKey
174
+ self.APISecret = APISecret
175
+ self.host = urlparse(Spark_url).netloc
176
+ self.path = urlparse(Spark_url).path
177
+ self.Spark_url = Spark_url
178
+ # 自定义
179
+ self.temperature = 0
180
+ self.max_tokens = 2048
181
+
182
+ # 生成url
183
+ def create_url(self):
184
+ # 生成RFC1123格式的时间戳
185
+ now = datetime.now()
186
+ date = format_date_time(mktime(now.timetuple()))
187
+
188
+ # 拼接字符串
189
+ signature_origin = "host: " + self.host + "\n"
190
+ signature_origin += "date: " + date + "\n"
191
+ signature_origin += "GET " + self.path + " HTTP/1.1"
192
+
193
+ # 进行hmac-sha256进行加密
194
+ signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
195
+ digestmod=hashlib.sha256).digest()
196
+
197
+ signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')
198
+
199
+ authorization_origin = f'api_key="{self.APIKey}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
200
+
201
+ authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
202
+
203
+ # 将请求的鉴权参数组合为字典
204
+ v = {
205
+ "authorization": authorization,
206
+ "date": date,
207
+ "host": self.host
208
+ }
209
+ # 拼接鉴权参数,生成url
210
+ url = self.Spark_url + '?' + urlencode(v)
211
+ # 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致
212
+ return url
213
+
214
+
215
+ # 收到websocket错误的处理
216
+ def on_error(ws, error):
217
+ print("### error:", error)
218
+
219
+
220
+ # 收到websocket关闭的处理
221
+ def on_close(ws,one,two):
222
+ print(" ")
223
+
224
+
225
+ # 收到websocket连接建立的处理
226
+ def on_open(ws):
227
+ thread.start_new_thread(run, (ws,))
228
+
229
+
230
+ def run(ws, *args):
231
+ data = json.dumps(gen_params(appid=ws.appid, domain= ws.domain,question=ws.question, temperature = ws.temperature, max_tokens = ws.max_tokens))
232
+ ws.send(data)
233
+
234
+
235
+ # 收到websocket消息的处理
236
+ def on_message(ws, message):
237
+ # print(message)
238
+ data = json.loads(message)
239
+ code = data['header']['code']
240
+ if code != 0:
241
+ print(f'请求错误: {code}, {data}')
242
+ ws.close()
243
+ else:
244
+ choices = data["payload"]["choices"]
245
+ status = choices["status"]
246
+ content = choices["text"][0]["content"]
247
+ print(content,end ="")
248
+ global answer
249
+ answer += content
250
+ # print(1)
251
+ if status == 2:
252
+ ws.close()
253
+
254
+
255
+ def gen_params(appid, domain,question, temperature, max_tokens):
256
+ """
257
+ 通过appid和用户的提问来生成请参数
258
+ """
259
+ data = {
260
+ "header": {
261
+ "app_id": appid,
262
+ "uid": "1234"
263
+ },
264
+ "parameter": {
265
+ "chat": {
266
+ "domain": domain,
267
+ "random_threshold": 0.5,
268
+ "max_tokens": max_tokens,
269
+ "temperature" : temperature,
270
+ "auditing": "default"
271
+ }
272
+ },
273
+ "payload": {
274
+ "message": {
275
+ "text": question
276
+ }
277
+ }
278
+ }
279
+ return data
280
+
281
+
282
+ def spark_main(appid, api_key, api_secret, Spark_url,domain, question, temperature, max_tokens):
283
+ # print("星火:")
284
+ output_queue = queue.Queue()
285
+ def on_message(ws, message):
286
+ data = json.loads(message)
287
+ code = data['header']['code']
288
+ if code != 0:
289
+ print(f'请求错误: {code}, {data}')
290
+ ws.close()
291
+ else:
292
+ choices = data["payload"]["choices"]
293
+ status = choices["status"]
294
+ content = choices["text"][0]["content"]
295
+ # print(content, end='')
296
+ # 将输出值放入队列
297
+ output_queue.put(content)
298
+ if status == 2:
299
+ ws.close()
300
+
301
+ wsParam = Ws_Param(appid, api_key, api_secret, Spark_url)
302
+ websocket.enableTrace(False)
303
+ wsUrl = wsParam.create_url()
304
+ ws = websocket.WebSocketApp(wsUrl, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open)
305
+ ws.appid = appid
306
+ ws.question = question
307
+ ws.domain = domain
308
+ ws.temperature = temperature
309
+ ws.max_tokens = max_tokens
310
+ ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
311
+ return ''.join([output_queue.get() for _ in range(output_queue.qsize())])
312
+
313
+ def parse_llm_api_key(model:str, env_file:dict()=None):
314
+ """
315
+ 通过 model 和 env_file 的来解析平台参数
316
+ """
317
+ if env_file is None:
318
+ _ = load_dotenv(find_dotenv())
319
+ env_file = os.environ
320
+ if model == "openai":
321
+ return env_file["OPENAI_API_KEY"]
322
+ elif model == "wenxin":
323
+ return env_file["wenxin_api_key"], env_file["wenxin_secret_key"]
324
+ elif model == "spark":
325
+ return env_file["spark_api_key"], env_file["spark_appid"], env_file["spark_api_secret"]
326
+ elif model == "zhipuai":
327
+ return get_from_dict_or_env(env_file, "zhipuai_api_key", "ZHIPUAI_API_KEY")
328
+ # return env_file["ZHIPUAI_API_KEY"]
329
+ else:
330
+ raise ValueError(f"model{model} not support!!!")
llm/self_llm.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+ '''
4
+ @File : self_llm.py
5
+ @Time : 2023/10/16 18:48:08
6
+ @Author : Logan Zou
7
+ @Version : 1.0
8
+ @Contact : loganzou0421@163.com
9
+ @License : (C)Copyright 2017-2018, Liugroup-NLPR-CASIA
10
+ @Desc : 在 LangChain LLM 基础上封装的项目类,统一了 GPT、文心、讯飞、智谱多种 API 调用
11
+ '''
12
+
13
+ from langchain.llms.base import LLM
14
+ from typing import Dict, Any, Mapping
15
+ from pydantic import Field
16
+
17
+ class Self_LLM(LLM):
18
+ # 自定义 LLM
19
+ # 继承自 langchain.llms.base.LLM
20
+ # 原生接口地址
21
+ url : str = None
22
+ # 默认选用 GPT-3.5 模型,即目前一般所说的GPT
23
+ model_name: str = "gpt-3.5-turbo"
24
+ # 访问时延上限
25
+ request_timeout: float = None
26
+ # 温度系数
27
+ temperature: float = 0.1
28
+ # API_Key
29
+ api_key: str = None
30
+ # 必备的可选参数
31
+ model_kwargs: Dict[str, Any] = Field(default_factory=dict)
32
+
33
+ # 定义一个返回默认参数的方法
34
+ @property
35
+ def _default_params(self) -> Dict[str, Any]:
36
+ """获取调用默认参数。"""
37
+ normal_params = {
38
+ "temperature": self.temperature,
39
+ "request_timeout": self.request_timeout,
40
+ }
41
+ # print(type(self.model_kwargs))
42
+ return {**normal_params}
43
+
44
+ @property
45
+ def _identifying_params(self) -> Mapping[str, Any]:
46
+ """Get the identifying parameters."""
47
+ return {**{"model_name": self.model_name}, **self._default_params}
llm/spark_llm.py ADDED
@@ -0,0 +1,227 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+ '''
4
+ @File : wenxin_llm.py
5
+ @Time : 2023/10/16 18:53:26
6
+ @Author : Logan Zou
7
+ @Version : 1.0
8
+ @Contact : loganzou0421@163.com
9
+ @License : (C)Copyright 2017-2018, Liugroup-NLPR-CASIA
10
+ @Desc : 基于讯飞星火大模型自定义 LLM 类
11
+ '''
12
+
13
+ from langchain.llms.base import LLM
14
+ from typing import Any, List, Mapping, Optional, Dict, Union, Tuple
15
+ from pydantic import Field
16
+ from llm.self_llm import Self_LLM
17
+ import json
18
+ import requests
19
+ from langchain.callbacks.manager import CallbackManagerForLLMRun
20
+ import _thread as thread
21
+ import base64
22
+ import datetime
23
+ import hashlib
24
+ import hmac
25
+ import json
26
+ from urllib.parse import urlparse
27
+ import ssl
28
+ from datetime import datetime
29
+ from time import mktime
30
+ from urllib.parse import urlencode
31
+ from wsgiref.handlers import format_date_time
32
+ import websocket # 使用websocket_client
33
+ import queue
34
+
35
+ class Spark_LLM(Self_LLM):
36
+ # 讯飞星火大模型的自定义 LLM
37
+ # URL
38
+ url : str = "ws://spark-api.xf-yun.com/v1.1/chat"
39
+ # APPID
40
+ appid : str = None
41
+ # APISecret
42
+ api_secret : str = None
43
+ # Domain
44
+ domain :str = "general"
45
+ # max_token
46
+ max_tokens : int = 4096
47
+
48
+ def getText(self, role, content, text = []):
49
+ # role 是指定角色,content 是 prompt 内容
50
+ jsoncon = {}
51
+ jsoncon["role"] = role
52
+ jsoncon["content"] = content
53
+ text.append(jsoncon)
54
+ return text
55
+
56
+ def _call(self, prompt : str, stop: Optional[List[str]] = None,
57
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
58
+ **kwargs: Any):
59
+ if self.api_key == None or self.appid == None or self.api_secret == None:
60
+ # 三个 Key 均存在才可以正常调用
61
+ print("请填入 Key")
62
+ raise ValueError("Key 不存在")
63
+ # 将 Prompt 填充到星火格式
64
+ question = self.getText("user", prompt)
65
+ # 发起请求
66
+ try:
67
+ response = spark_main(self.appid,self.api_key,self.api_secret,self.url,self.domain,question, self.temperature, self.max_tokens)
68
+ return response
69
+ except Exception as e:
70
+ print(e)
71
+ print("请求失败")
72
+ return "请求失败"
73
+
74
+ @property
75
+ def _llm_type(self) -> str:
76
+ return "Spark"
77
+
78
+ answer = ""
79
+
80
+ class Ws_Param(object):
81
+ # 初始化
82
+ def __init__(self, APPID, APIKey, APISecret, Spark_url):
83
+ self.APPID = APPID
84
+ self.APIKey = APIKey
85
+ self.APISecret = APISecret
86
+ self.host = urlparse(Spark_url).netloc
87
+ self.path = urlparse(Spark_url).path
88
+ self.Spark_url = Spark_url
89
+ # 自定义
90
+ self.temperature = 0
91
+ self.max_tokens = 2048
92
+
93
+ # 生成url
94
+ def create_url(self):
95
+ # 生成RFC1123格式的时间戳
96
+ now = datetime.now()
97
+ date = format_date_time(mktime(now.timetuple()))
98
+
99
+ # 拼接字符串
100
+ signature_origin = "host: " + self.host + "\n"
101
+ signature_origin += "date: " + date + "\n"
102
+ signature_origin += "GET " + self.path + " HTTP/1.1"
103
+
104
+ # 进行hmac-sha256进行加密
105
+ signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
106
+ digestmod=hashlib.sha256).digest()
107
+
108
+ signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')
109
+
110
+ authorization_origin = f'api_key="{self.APIKey}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
111
+
112
+ authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
113
+
114
+ # 将请求的鉴权参数组合为字典
115
+ v = {
116
+ "authorization": authorization,
117
+ "date": date,
118
+ "host": self.host
119
+ }
120
+ # 拼接鉴权参数,生成url
121
+ url = self.Spark_url + '?' + urlencode(v)
122
+ # 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致
123
+ return url
124
+
125
+
126
+ # 收到websocket错误的处理
127
+ def on_error(ws, error):
128
+ print("### error:", error)
129
+
130
+
131
+ # 收到websocket关闭的处理
132
+ def on_close(ws,one,two):
133
+ print(" ")
134
+
135
+
136
+ # 收到websocket连接建立的处理
137
+ def on_open(ws):
138
+ thread.start_new_thread(run, (ws,))
139
+
140
+
141
+ def run(ws, *args):
142
+ data = json.dumps(gen_params(appid=ws.appid, domain= ws.domain,question=ws.question, temperature = ws.temperature, max_tokens = ws.max_tokens))
143
+ ws.send(data)
144
+
145
+
146
+ # 收到websocket消息的处理
147
+ def on_message(ws, message):
148
+ # print(message)
149
+ data = json.loads(message)
150
+ code = data['header']['code']
151
+ if code != 0:
152
+ print(f'请求错误: {code}, {data}')
153
+ ws.close()
154
+ else:
155
+ choices = data["payload"]["choices"]
156
+ status = choices["status"]
157
+ content = choices["text"][0]["content"]
158
+ print(content,end ="")
159
+ global answer
160
+ answer += content
161
+ # print(1)
162
+ if status == 2:
163
+ ws.close()
164
+
165
+
166
+ def gen_params(appid, domain,question, temperature, max_tokens):
167
+ """
168
+ 通过appid和用户的提问来生成请参数
169
+ """
170
+ data = {
171
+ "header": {
172
+ "app_id": appid,
173
+ "uid": "1234"
174
+ },
175
+ "parameter": {
176
+ "chat": {
177
+ "domain": domain,
178
+ "random_threshold": 0.5,
179
+ "max_tokens": max_tokens,
180
+ "temperature" : temperature,
181
+ "auditing": "default"
182
+ }
183
+ },
184
+ "payload": {
185
+ "message": {
186
+ "text": question
187
+ }
188
+ }
189
+ }
190
+ return data
191
+
192
+
193
+ def spark_main(appid, api_key, api_secret, Spark_url,domain, question, temperature, max_tokens):
194
+ # print("星火:")
195
+ output_queue = queue.Queue()
196
+ def on_message(ws, message):
197
+ data = json.loads(message)
198
+ code = data['header']['code']
199
+ if code != 0:
200
+ print(f'请求错误: {code}, {data}')
201
+ ws.close()
202
+ else:
203
+ choices = data["payload"]["choices"]
204
+ status = choices["status"]
205
+ content = choices["text"][0]["content"]
206
+ # print(content, end='')
207
+ # 将输出值放入队列
208
+ output_queue.put(content)
209
+ if status == 2:
210
+ ws.close()
211
+
212
+ wsParam = Ws_Param(appid, api_key, api_secret, Spark_url)
213
+ websocket.enableTrace(False)
214
+ wsUrl = wsParam.create_url()
215
+ ws = websocket.WebSocketApp(wsUrl, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open)
216
+ ws.appid = appid
217
+ ws.question = question
218
+ ws.domain = domain
219
+ ws.temperature = temperature
220
+ ws.max_tokens = max_tokens
221
+ ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
222
+ return ''.join([output_queue.get() for _ in range(output_queue.qsize())])
223
+
224
+
225
+
226
+
227
+
llm/test.ipynb ADDED
@@ -0,0 +1,312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "from wenxin_llm import Wenxin_LLM"
10
+ ]
11
+ },
12
+ {
13
+ "cell_type": "code",
14
+ "execution_count": 3,
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": [
18
+ "from dotenv import find_dotenv, load_dotenv\n",
19
+ "import os\n",
20
+ "\n",
21
+ "# 读取本地/项目的环境变量。\n",
22
+ "\n",
23
+ "# find_dotenv()寻找并定位.env文件的路径\n",
24
+ "# load_dotenv()读取该.env文件,并将其中的环境变量加载到当前的运行环境中\n",
25
+ "# 如果你设置的是全局的环境变量,这行代码则没有任何作用。\n",
26
+ "_ = load_dotenv(find_dotenv())\n",
27
+ "\n",
28
+ "# 获取环境变量 OPENAI_API_KEY\n",
29
+ "wenxin_api_key = os.environ[\"wenxin_api_key\"]\n",
30
+ "wenxin_secret_key = os.environ[\"wenxin_secret_key\"]"
31
+ ]
32
+ },
33
+ {
34
+ "cell_type": "code",
35
+ "execution_count": 4,
36
+ "metadata": {},
37
+ "outputs": [],
38
+ "source": [
39
+ "llm = Wenxin_LLM(model = \"ERNIE-Bot-turbo\", api_key=wenxin_api_key, secret_key=wenxin_secret_key)"
40
+ ]
41
+ },
42
+ {
43
+ "cell_type": "code",
44
+ "execution_count": 5,
45
+ "metadata": {},
46
+ "outputs": [
47
+ {
48
+ "data": {
49
+ "text/plain": [
50
+ "'您好,我是百度研发的知识增强大语言模型,中文名是文心一言,英文名是ERNIE Bot。我能够与人对话互动,回答问题,协助创作,高效便捷地帮助人们获取信息、知识和灵感。\\n\\n如果您有任何问题,请随时告诉我。'"
51
+ ]
52
+ },
53
+ "execution_count": 5,
54
+ "metadata": {},
55
+ "output_type": "execute_result"
56
+ }
57
+ ],
58
+ "source": [
59
+ "llm(\"你是谁\")"
60
+ ]
61
+ },
62
+ {
63
+ "cell_type": "code",
64
+ "execution_count": 6,
65
+ "metadata": {},
66
+ "outputs": [],
67
+ "source": [
68
+ "from spark_llm import Spark_LLM"
69
+ ]
70
+ },
71
+ {
72
+ "cell_type": "code",
73
+ "execution_count": 7,
74
+ "metadata": {},
75
+ "outputs": [],
76
+ "source": [
77
+ "from dotenv import find_dotenv, load_dotenv\n",
78
+ "import os\n",
79
+ "\n",
80
+ "# 读取本地/项目的环境变量。\n",
81
+ "\n",
82
+ "# find_dotenv()寻找并定位.env文件的路径\n",
83
+ "# load_dotenv()读取该.env文件,并将其中的环境变量加载到当前的运行环境中\n",
84
+ "# 如果你设置的是全局的环境变量,这行代码则没有任何作用。\n",
85
+ "_ = load_dotenv(find_dotenv())\n",
86
+ "#填写控制台中获取的 APPID 信息\n",
87
+ "appid = os.environ[\"spark_appid\"]\n",
88
+ "#填写控制台中获取的 APISecret 信息\n",
89
+ "api_secret = os.environ[\"spark_api_secret\"]\n",
90
+ "#填写控制台中获取的 APIKey 信息\n",
91
+ "api_key = os.environ[\"spark_api_key\"]"
92
+ ]
93
+ },
94
+ {
95
+ "cell_type": "code",
96
+ "execution_count": 8,
97
+ "metadata": {},
98
+ "outputs": [],
99
+ "source": [
100
+ "llm = Spark_LLM(model = \"spark\", appid=appid, api_secret=api_secret, api_key=api_key)"
101
+ ]
102
+ },
103
+ {
104
+ "cell_type": "code",
105
+ "execution_count": 9,
106
+ "metadata": {},
107
+ "outputs": [
108
+ {
109
+ "data": {
110
+ "text/plain": [
111
+ "'\\n\\n我是一个AI语言模型,可以回答你的问题和提供帮助。'"
112
+ ]
113
+ },
114
+ "execution_count": 9,
115
+ "metadata": {},
116
+ "output_type": "execute_result"
117
+ }
118
+ ],
119
+ "source": [
120
+ "llm(\"你是谁\")"
121
+ ]
122
+ },
123
+ {
124
+ "cell_type": "code",
125
+ "execution_count": 10,
126
+ "metadata": {},
127
+ "outputs": [],
128
+ "source": [
129
+ "from zhipuai_llm import ZhipuAILLM\n",
130
+ "\n",
131
+ "from dotenv import find_dotenv, load_dotenv\n",
132
+ "import os\n",
133
+ "\n",
134
+ "# 读取本地/项目的环境变量。\n",
135
+ "\n",
136
+ "# find_dotenv()寻找并定位.env文件的路径\n",
137
+ "# load_dotenv()读取该.env文件,并将其中的环境变量加载到当前的运行环境中\n",
138
+ "# 如果你设置的是全局的环境变量,这行代码则没有任何作用。\n",
139
+ "_ = load_dotenv(find_dotenv())\n",
140
+ "\n",
141
+ "api_key = os.environ[\"ZHIPUAI_API_KEY\"] #填写控制台中获取的 APIKey 信息"
142
+ ]
143
+ },
144
+ {
145
+ "cell_type": "code",
146
+ "execution_count": 11,
147
+ "metadata": {},
148
+ "outputs": [
149
+ {
150
+ "data": {
151
+ "text/plain": [
152
+ "'我是一个名为 ChatGLM 的人工智能助手,由智谱 AI 公司于2023年训练的语言模型开发而成。我的任务是针对用户的问题和要求提供适当的答复和支持。'"
153
+ ]
154
+ },
155
+ "execution_count": 11,
156
+ "metadata": {},
157
+ "output_type": "execute_result"
158
+ }
159
+ ],
160
+ "source": [
161
+ "llm = ZhipuAILLM(model=\"chatglm_pro\", zhipuai_api_key=api_key, temperature=0.1)\n",
162
+ "llm(\"你是谁\") "
163
+ ]
164
+ },
165
+ {
166
+ "cell_type": "markdown",
167
+ "metadata": {},
168
+ "source": [
169
+ "测试原生接口"
170
+ ]
171
+ },
172
+ {
173
+ "cell_type": "code",
174
+ "execution_count": 12,
175
+ "metadata": {},
176
+ "outputs": [],
177
+ "source": [
178
+ "from call_llm import get_completion"
179
+ ]
180
+ },
181
+ {
182
+ "cell_type": "code",
183
+ "execution_count": 13,
184
+ "metadata": {},
185
+ "outputs": [],
186
+ "source": [
187
+ "from dotenv import find_dotenv, load_dotenv\n",
188
+ "import os\n",
189
+ "\n",
190
+ "# 读取本地/项目的环境变量。\n",
191
+ "\n",
192
+ "# find_dotenv()寻找并定位.env文件的路径\n",
193
+ "# load_dotenv()读取该.env文件,并将其中的环境变量加载到当前的运行环境中\n",
194
+ "# 如果你设置的是全局的环境变量,这行代码则没有任何作用。\n",
195
+ "_ = load_dotenv(find_dotenv())\n",
196
+ "\n",
197
+ "# 获取环境变量 OPENAI_API_KEY\n",
198
+ "openai_api_key = os.environ[\"OPENAI_API_KEY\"]\n",
199
+ "wenxin_api_key = os.environ[\"wenxin_api_key\"]\n",
200
+ "wenxin_secret_key = os.environ[\"wenxin_secret_key\"]\n",
201
+ "spark_appid = os.environ[\"spark_appid\"]\n",
202
+ "spark_api_secret = os.environ[\"spark_api_secret\"]\n",
203
+ "spark_api_key = os.environ[\"spark_api_key\"]\n",
204
+ "zhipu_api_key = os.environ[\"ZHIPUAI_API_KEY\"]\n",
205
+ "\n",
206
+ "# os.environ['HTTPS_PROXY'] = 'http://127.0.0.1:7890'\n",
207
+ "# os.environ[\"HTTP_PROXY\"] = 'http://127.0.0.1:7890'"
208
+ ]
209
+ },
210
+ {
211
+ "cell_type": "code",
212
+ "execution_count": 14,
213
+ "metadata": {},
214
+ "outputs": [
215
+ {
216
+ "data": {
217
+ "text/plain": [
218
+ "'我是一个人工智能助手,可以回答你的问题并提供帮助。有什么可以帮到你的吗?'"
219
+ ]
220
+ },
221
+ "execution_count": 14,
222
+ "metadata": {},
223
+ "output_type": "execute_result"
224
+ }
225
+ ],
226
+ "source": [
227
+ "get_completion(\"你是谁\",model=\"gpt-3.5-turbo\", api_key=openai_api_key)"
228
+ ]
229
+ },
230
+ {
231
+ "cell_type": "code",
232
+ "execution_count": 15,
233
+ "metadata": {},
234
+ "outputs": [
235
+ {
236
+ "data": {
237
+ "text/plain": [
238
+ "'您好,我是百度研发的知识增强大语言模型,中文名是文心一言,英文名是ERNIE Bot。我能够与人对话互动,回答问题,协助创作,高效便捷地帮助人们获取信息、知识和灵感。'"
239
+ ]
240
+ },
241
+ "execution_count": 15,
242
+ "metadata": {},
243
+ "output_type": "execute_result"
244
+ }
245
+ ],
246
+ "source": [
247
+ "get_completion(\"你是谁\",model=\"ERNIE-Bot-turbo\", api_key=wenxin_api_key, secret_key=wenxin_secret_key)"
248
+ ]
249
+ },
250
+ {
251
+ "cell_type": "code",
252
+ "execution_count": 16,
253
+ "metadata": {},
254
+ "outputs": [
255
+ {
256
+ "data": {
257
+ "text/plain": [
258
+ "'\\n\\n我是一个AI语言模型,可以回答你的问题和提供帮助。'"
259
+ ]
260
+ },
261
+ "execution_count": 16,
262
+ "metadata": {},
263
+ "output_type": "execute_result"
264
+ }
265
+ ],
266
+ "source": [
267
+ "get_completion(\"你是谁\",model=\"Spark-1.5\", appid=spark_appid, api_key=spark_api_key, api_secret=spark_api_secret)"
268
+ ]
269
+ },
270
+ {
271
+ "cell_type": "code",
272
+ "execution_count": 17,
273
+ "metadata": {},
274
+ "outputs": [
275
+ {
276
+ "data": {
277
+ "text/plain": [
278
+ "'我是一个名为 ChatGLM 的人工智能助手,由智谱 AI 公司于2023年训练的语言模型开发而成。我的任务是针对用户的问题和要求提供适当的答复和支持。'"
279
+ ]
280
+ },
281
+ "execution_count": 17,
282
+ "metadata": {},
283
+ "output_type": "execute_result"
284
+ }
285
+ ],
286
+ "source": [
287
+ "get_completion(\"你是谁\",model=\"chatglm_std\", api_key=zhipu_api_key)"
288
+ ]
289
+ }
290
+ ],
291
+ "metadata": {
292
+ "kernelspec": {
293
+ "display_name": "langchain",
294
+ "language": "python",
295
+ "name": "python3"
296
+ },
297
+ "language_info": {
298
+ "codemirror_mode": {
299
+ "name": "ipython",
300
+ "version": 3
301
+ },
302
+ "file_extension": ".py",
303
+ "mimetype": "text/x-python",
304
+ "name": "python",
305
+ "nbconvert_exporter": "python",
306
+ "pygments_lexer": "ipython3",
307
+ "version": "3.10.0"
308
+ }
309
+ },
310
+ "nbformat": 4,
311
+ "nbformat_minor": 2
312
+ }
llm/wenxin_llm.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+ '''
4
+ @File : wenxin_llm.py
5
+ @Time : 2023/10/16 18:53:26
6
+ @Author : Logan Zou
7
+ @Version : 1.0
8
+ @Contact : loganzou0421@163.com
9
+ @License : (C)Copyright 2017-2018, Liugroup-NLPR-CASIA
10
+ @Desc : 基于百度文心大模型自定义 LLM 类
11
+ '''
12
+
13
+ from langchain.llms.base import LLM
14
+ from typing import Any, List, Mapping, Optional, Dict, Union, Tuple
15
+ from pydantic import Field
16
+ from llm.self_llm import Self_LLM
17
+ import json
18
+ import requests
19
+ from langchain.callbacks.manager import CallbackManagerForLLMRun
20
+ # 调用文心 API 的工具函数
21
+ def get_access_token(api_key : str, secret_key : str):
22
+ """
23
+ 使用 API Key,Secret Key 获取access_token,替换下列示例中的应用API Key、应用Secret Key
24
+ """
25
+ # 指定网址
26
+ url = f"https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id={api_key}&client_secret={secret_key}"
27
+ # 设置 POST 访问
28
+ payload = json.dumps("")
29
+ headers = {
30
+ 'Content-Type': 'application/json',
31
+ 'Accept': 'application/json'
32
+ }
33
+ # 通过 POST 访问获取账户对应的 access_token
34
+ response = requests.request("POST", url, headers=headers, data=payload)
35
+ return response.json().get("access_token")
36
+
37
+ class Wenxin_LLM(Self_LLM):
38
+ # 文心大模型的自定义 LLM
39
+ # URL
40
+ url : str = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant?access_token={}"
41
+ # Secret_Key
42
+ secret_key : str = None
43
+ # access_token
44
+ access_token: str = None
45
+
46
+ def init_access_token(self):
47
+ if self.api_key != None and self.secret_key != None:
48
+ # 两个 Key 均非空才可以获取 access_token
49
+ try:
50
+ self.access_token = get_access_token(self.api_key, self.secret_key)
51
+ except Exception as e:
52
+ print(e)
53
+ print("获取 access_token 失败,请检查 Key")
54
+ else:
55
+ print("API_Key 或 Secret_Key 为空,请检查 Key")
56
+
57
+ def _call(self, prompt : str, stop: Optional[List[str]] = None,
58
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
59
+ **kwargs: Any):
60
+ # 如果 access_token 为空,初始化 access_token
61
+ if self.access_token == None:
62
+ self.init_access_token()
63
+ # API 调用 url
64
+ url = self.url.format(self.access_token)
65
+ # 配置 POST 参数
66
+ payload = json.dumps({
67
+ "messages": [
68
+ {
69
+ "role": "user",# user prompt
70
+ "content": "{}".format(prompt)# 输入的 prompt
71
+ }
72
+ ],
73
+ 'temperature' : self.temperature
74
+ })
75
+ headers = {
76
+ 'Content-Type': 'application/json'
77
+ }
78
+ # 发起请求
79
+ response = requests.request("POST", url, headers=headers, data=payload, timeout=self.request_timeout)
80
+ if response.status_code == 200:
81
+ # 返回的是一个 Json 字符串
82
+ js = json.loads(response.text)
83
+ # print(js)
84
+ return js["result"]
85
+ else:
86
+ return "请求失败"
87
+
88
+ @property
89
+ def _llm_type(self) -> str:
90
+ return "Wenxin"
llm/zhipuai_llm.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- encoding: utf-8 -*-
3
+ '''
4
+ @File : zhipuai_llm.py
5
+ @Time : 2023/10/16 22:06:26
6
+ @Author : 0-yy-0
7
+ @Version : 1.0
8
+ @Contact : 310484121@qq.com
9
+ @License : (C)Copyright 2017-2018, Liugroup-NLPR-CASIA
10
+ @Desc : 基于智谱 AI 大模型自定义 LLM 类
11
+ '''
12
+
13
+ from __future__ import annotations
14
+
15
+ import logging
16
+ from typing import (
17
+ Any,
18
+ AsyncIterator,
19
+ Dict,
20
+ Iterator,
21
+ List,
22
+ Optional,
23
+ )
24
+
25
+ from langchain.callbacks.manager import (
26
+ AsyncCallbackManagerForLLMRun,
27
+ CallbackManagerForLLMRun,
28
+ )
29
+ from langchain.llms.base import LLM
30
+ from langchain.pydantic_v1 import Field, root_validator
31
+ from langchain.schema.output import GenerationChunk
32
+ from langchain.utils import get_from_dict_or_env
33
+ from llm.self_llm import Self_LLM
34
+
35
+ logger = logging.getLogger(__name__)
36
+
37
+
38
+ class ZhipuAILLM(Self_LLM):
39
+ """Zhipuai hosted open source or customized models.
40
+
41
+ To use, you should have the ``zhipuai`` python package installed, and
42
+ the environment variable ``zhipuai_api_key`` set with
43
+ your API key and Secret Key.
44
+
45
+ zhipuai_api_key are required parameters which you could get from
46
+ https://open.bigmodel.cn/usercenter/apikeys
47
+
48
+ Example:
49
+ .. code-block:: python
50
+
51
+ from langchain.llms import ZhipuAILLM
52
+ zhipuai_model = ZhipuAILLM(model="chatglm_std", temperature=temperature)
53
+
54
+ """
55
+
56
+ model_kwargs: Dict[str, Any] = Field(default_factory=dict)
57
+
58
+ client: Any
59
+
60
+ model: str = "chatglm_std"
61
+ """Model name in chatglm_pro, chatglm_std, chatglm_lite. """
62
+
63
+ zhipuai_api_key: Optional[str] = None
64
+
65
+ incremental: Optional[bool] = True
66
+ """Whether to incremental the results or not."""
67
+
68
+ streaming: Optional[bool] = False
69
+ """Whether to streaming the results or not."""
70
+ # streaming = -incremental
71
+
72
+ request_timeout: Optional[int] = 60
73
+ """request timeout for chat http requests"""
74
+
75
+ top_p: Optional[float] = 0.8
76
+ temperature: Optional[float] = 0.95
77
+ request_id: Optional[float] = None
78
+
79
+ @root_validator()
80
+ def validate_enviroment(cls, values: Dict) -> Dict:
81
+
82
+ values["zhipuai_api_key"] = get_from_dict_or_env(
83
+ values,
84
+ "zhipuai_api_key",
85
+ "ZHIPUAI_API_KEY",
86
+ )
87
+
88
+ params = {
89
+ "zhipuai_api_key": values["zhipuai_api_key"],
90
+ "model": values["model"],
91
+ }
92
+ try:
93
+ import zhipuai
94
+
95
+ zhipuai.api_key = values["zhipuai_api_key"]
96
+ values["client"] = zhipuai.model_api
97
+ except ImportError:
98
+ raise ValueError(
99
+ "zhipuai package not found, please install it with "
100
+ "`pip install zhipuai`"
101
+ )
102
+ return values
103
+
104
+ @property
105
+ def _identifying_params(self) -> Dict[str, Any]:
106
+ return {
107
+ **{"model": self.model},
108
+ **super()._identifying_params,
109
+ }
110
+
111
+ @property
112
+ def _llm_type(self) -> str:
113
+ """Return type of llm."""
114
+ return "zhipuai"
115
+
116
+ @property
117
+ def _default_params(self) -> Dict[str, Any]:
118
+ """Get the default parameters for calling OpenAI API."""
119
+ normal_params = {
120
+ "streaming": self.streaming,
121
+ "top_p": self.top_p,
122
+ "temperature": self.temperature,
123
+ "request_id": self.request_id,
124
+ }
125
+
126
+ return {**normal_params, **self.model_kwargs}
127
+
128
+ def _convert_prompt_msg_params(
129
+ self,
130
+ prompt: str,
131
+ **kwargs: Any,
132
+ ) -> dict:
133
+ return {
134
+ **{"prompt": prompt, "model": self.model},
135
+ **self._default_params,
136
+ **kwargs,
137
+ }
138
+
139
+ def _call(
140
+ self,
141
+ prompt: str,
142
+ stop: Optional[List[str]] = None,
143
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
144
+ **kwargs: Any,
145
+ ) -> str:
146
+ """Call out to an zhipuai models endpoint for each generation with a prompt.
147
+ Args:
148
+ prompt: The prompt to pass into the model.
149
+ Returns:
150
+ The string generated by the model.
151
+
152
+ Example:
153
+ .. code-block:: python
154
+ response = zhipuai_model("Tell me a joke.")
155
+ """
156
+ if self.streaming:
157
+ completion = ""
158
+ for chunk in self._stream(prompt, stop, run_manager, **kwargs):
159
+ completion += chunk.text
160
+ return completion
161
+ params = self._convert_prompt_msg_params(prompt, **kwargs)
162
+
163
+ response_payload = self.client.invoke(**params)
164
+ return response_payload["data"]["choices"][-1]["content"].strip('"').strip(" ")
165
+
166
+ async def _acall(
167
+ self,
168
+ prompt: str,
169
+ stop: Optional[List[str]] = None,
170
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
171
+ **kwargs: Any,
172
+ ) -> str:
173
+ if self.streaming:
174
+ completion = ""
175
+ async for chunk in self._astream(prompt, stop, run_manager, **kwargs):
176
+ completion += chunk.text
177
+ return completion
178
+
179
+ params = self._convert_prompt_msg_params(prompt, **kwargs)
180
+
181
+ response = await self.client.async_invoke(**params)
182
+
183
+ return response_payload
184
+
185
+ def _stream(
186
+ self,
187
+ prompt: str,
188
+ stop: Optional[List[str]] = None,
189
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
190
+ **kwargs: Any,
191
+ ) -> Iterator[GenerationChunk]:
192
+ params = self._convert_prompt_msg_params(prompt, **kwargs)
193
+
194
+ for res in self.client.invoke(**params):
195
+ if res:
196
+ chunk = GenerationChunk(text=res)
197
+ yield chunk
198
+ if run_manager:
199
+ run_manager.on_llm_new_token(chunk.text)
200
+
201
+ async def _astream(
202
+
203
+ self,
204
+ prompt: str,
205
+ stop: Optional[List[str]] = None,
206
+ run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
207
+ **kwargs: Any,
208
+ ) -> AsyncIterator[GenerationChunk]:
209
+ params = self._convert_prompt_msg_params(prompt, **kwargs)
210
+
211
+ async for res in await self.client.ado(**params):
212
+ if res:
213
+ chunk = GenerationChunk(text=res["data"]["choices"]["content"])
214
+
215
+ yield chunk
216
+ if run_manager:
217
+ await run_manager.on_llm_new_token(chunk.text)
prompts.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """define some usefull prompts"""
2
+
3
+ trans_prompt = """
4
+ Please translate the word into chinese: {} """
5
+
6
+ query_prompt = """
7
+ Please translate the word into chinese \
8
+ and give english example sentence and sentence meaning in Chinese:
9
+ {} """
10
+
11
+ learn_prompt = """
12
+ Please give an english sentence contains these words: {} \
13
+ and give sentence meaning in Chinese.
14
+ """
15
+
16
+ wash_prompt = """
17
+ Please help me to remove the duplicate words and meaningless words from
18
+ the words list: {} and return the result.
19
+
20
+ For example:
21
+ User: apple, red, green, yellow, red, blll><, :*&
22
+ Your return: apple, red, green, yellow
23
+ """
24
+ #*******************************************************
25
+ #************** user prompts ***************************
26
+ #*******************************************************
27
+ user_message_mapper = """
28
+ Please classify the following categories and return the name of the category. If your think user's requirement does not belong to the following categories, return the number 0.
29
+ Here are four examples:
30
+ User: I want to know the meaning of the word "apple"
31
+ """
32
+
33
+ #*******************************************************
34
+ #************** system prompts *************************
35
+ #*******************************************************
36
+
37
+
38
+ # Using few shots learning to map the user's requirement to a command
39
+ system_message_mapper = """
40
+ You will receive the user's requirement. \
41
+ Please classify the user's requirement into the following categories \
42
+ and return the name of the category.
43
+ If your think user's requirement does not belong to the \
44
+ following categories, return the number 0.
45
+
46
+ Here are four examples:
47
+
48
+ User: I want to know the meaning of the word "apple"
49
+ Your return: :query apple
50
+
51
+ User: I want to add the word "apple" to my dictionary
52
+ Your return: :add apple
53
+
54
+ User: I want to remove the word "apple" from my dictionary
55
+ Your return: :remove apple
56
+
57
+ User: I want to learn the word "apple" based words in my dictionary now
58
+ Your return: :learn
59
+
60
+ User: I already learnt the meaning of the word "apple"
61
+ Your return: :remove apple
62
+
63
+ User: I want to know the rest words in my dictionary
64
+ Your return: :show
65
+
66
+ Here are the meaning of four categories for you to reference:
67
+ :query word means user want to study the meaning of word
68
+ :add word means user want to add word into his dictionary
69
+ :remove word means user want to remove word from his dictionary
70
+ :learn word means user want to learn words in his dictionary related to the word
71
+ :show means user want to show all words in his dictionary
72
+
73
+ """
74
+
75
+ # using few shots learning to select the important words in a sentence
76
+ system_message_select = """
77
+ You will receive the english sentence from user, please select the words in the sentence that you think
78
+ is important for the user to understand the meaning of the sentence.
79
+ If the sentence does not contain any english words, please just return number 0.
80
+
81
+ For example:
82
+ User: The apple is red, and it's green.
83
+ Your return: apple, red
84
+ """
requirements.txt ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.2.1
2
+ aiohttp==3.8.6
3
+ aiosignal==1.3.1
4
+ altair==5.1.2
5
+ annotated-types==0.6.0
6
+ anyio==4.0.0
7
+ asttokens==2.2.1
8
+ async-timeout==4.0.3
9
+ attrs==23.1.0
10
+ backcall==0.2.0
11
+ backoff==2.2.1
12
+ cachetools==5.3.2
13
+ certifi==2023.7.22
14
+ chardet==5.2.0
15
+ charset-normalizer==3.3.2
16
+ chromadb==0.3.29
17
+ click==8.1.7
18
+ clickhouse-connect==0.6.20
19
+ coloredlogs==15.0.1
20
+ comm==0.1.3
21
+ contourpy==1.2.0
22
+ cycler==0.12.1
23
+ dataclasses==0.6
24
+ dataclasses-json==0.5.14
25
+ debugpy==1.6.6
26
+ decorator==5.1.1
27
+ duckdb==0.9.1
28
+ entrypoints==0.4
29
+ exceptiongroup==1.0.4
30
+ executing==1.2.0
31
+ fastapi==0.85.1
32
+ ffmpy==0.3.1
33
+ filelock==3.13.1
34
+ filetype==1.2.0
35
+ flatbuffers==23.5.26
36
+ fonttools==4.44.0
37
+ frozenlist==1.4.0
38
+ fsspec==2023.10.0
39
+ gradio==3.40.1
40
+ gradio_client==0.7.0
41
+ greenlet==3.0.1
42
+ h11==0.14.0
43
+ hnswlib==0.7.0
44
+ httpcore==1.0.2
45
+ httptools==0.6.1
46
+ httpx==0.25.1
47
+ huggingface-hub==0.17.3
48
+ humanfriendly==10.0
49
+ idna==3.4
50
+ importlib-metadata==6.0.0
51
+ importlib-resources==6.1.1
52
+ ipykernel==6.23.1
53
+ ipython==8.10.0
54
+ jedi==0.18.2
55
+ jieba==0.42.1
56
+ Jinja2==3.1.2
57
+ joblib==1.3.2
58
+ jsonschema==4.19.2
59
+ jsonschema-specifications==2023.7.1
60
+ jupyter_client==8.6.0
61
+ jupyter_core==5.3.0
62
+ kiwisolver==1.4.5
63
+ langchain==0.0.292
64
+ langsmith==0.0.63
65
+ linkify-it-py==2.0.2
66
+ lxml==4.9.3
67
+ lz4==4.3.2
68
+ Markdown==3.4.3
69
+ markdown-it-py==2.2.0
70
+ MarkupSafe==2.1.3
71
+ marshmallow==3.20.1
72
+ matplotlib==3.8.1
73
+ matplotlib-inline==0.1.3
74
+ mdit-py-plugins==0.3.3
75
+ mdurl==0.1.2
76
+ monotonic==1.6
77
+ mpmath==1.3.0
78
+ multidict==6.0.4
79
+ mypy-extensions==1.0.0
80
+ nest-asyncio==1.5.4
81
+ nltk==3.8.1
82
+ numexpr==2.8.7
83
+ numpy==1.26.2
84
+ onnxruntime==1.16.2
85
+ openai==0.27.6
86
+ orjson==3.9.10
87
+ overrides==7.4.0
88
+ packaging==23.1
89
+ pandas==2.1.3
90
+ parso==0.8.3
91
+ pexpect==4.8.0
92
+ pickleshare==0.7.5
93
+ Pillow==10.1.0
94
+ pip==23.3
95
+ platformdirs==3.0.0
96
+ posthog==3.0.2
97
+ prompt-toolkit==3.0.38
98
+ protobuf==4.25.0
99
+ psutil==5.9.4
100
+ ptyprocess==0.7.0
101
+ pulsar-client==3.3.0
102
+ pure-eval==0.2.2
103
+ pydantic==1.10.10
104
+ pydantic_core==2.10.1
105
+ pydub==0.25.1
106
+ Pygments==2.14.0
107
+ PyJWT==2.8.0
108
+ PyMuPDF==1.23.6
109
+ PyMuPDFb==1.23.6
110
+ pyparsing==3.1.1
111
+ python-dateutil==2.8.2
112
+ python-dotenv==1.0.0
113
+ python-magic==0.4.27
114
+ python-multipart==0.0.6
115
+ pytz==2023.3.post1
116
+ PyYAML==6.0.1
117
+ pyzmq==25.1.0
118
+ referencing==0.30.2
119
+ regex==2023.5.5
120
+ requests==2.31.0
121
+ rouge-chinese==1.0.3
122
+ rpds-py==0.12.0
123
+ scikit-learn==1.2.2
124
+ scipy==1.11.2
125
+ semantic-version==2.10.0
126
+ setuptools==68.0.0
127
+ six==1.16.0
128
+ sniffio==1.3.0
129
+ SQLAlchemy==2.0.23
130
+ stack-data==0.6.2
131
+ starlette==0.20.4
132
+ sympy==1.12
133
+ tabulate==0.9.0
134
+ tenacity==8.2.3
135
+ threadpoolctl==3.1.0
136
+ tokenizers==0.14.1
137
+ toolz==0.12.0
138
+ tornado==6.3.3
139
+ tqdm==4.66.1
140
+ traitlets==5.9.0
141
+ typing_extensions==4.7.1
142
+ typing-inspect==0.9.0
143
+ tzdata==2023.3
144
+ uc-micro-py==1.0.2
145
+ unstructured==0.9.0
146
+ urllib3==2.0.7
147
+ uvicorn==0.24.0.post1
148
+ uvloop==0.19.0
149
+ watchfiles==0.21.0
150
+ wcwidth==0.2.5
151
+ websocket-client==1.5.2
152
+ websockets==11.0.3
153
+ wheel==0.41.2
154
+ yarl==1.9.2
155
+ zhipuai==1.0.7
156
+ zipp==3.11.0
157
+ zstandard==0.22.0
test/__pycache__/test_create_db.cpython-311-pytest-8.2.0.pyc ADDED
Binary file (7.87 kB). View file
 
test/test_create_db.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+ from pytest import fixture
3
+ from create_db import split_text
4
+
5
+
6
+ @fixture
7
+ def sample_text():
8
+ return [
9
+ "Lorem ipsum dolor sit amet, consectetur adipiscing elit. "
10
+ "Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. "
11
+ "Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. "
12
+ "Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. "
13
+ "Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.",
14
+ "Another long text string to demonstrate the splitting functionality. This text should also be split into multiple chunks."
15
+ ]
16
+
17
+
18
+ def test_split_text(sample_text):
19
+ # Split the sample text into chunks
20
+ chunks = split_text(sample_text)
21
+
22
+ # Assert that the chunks are lists of strings
23
+ assert all(
24
+ isinstance(chunk, list) and all(
25
+ isinstance(text, str) for text in chunk) for chunk in chunks)
26
+
27
+ # Assert that the chunks are not empty
28
+ assert all(chunk for chunk in chunks)
29
+
30
+ # Assert that the chunks have the expected length (approx. 1500 characters with 150 overlap)
31
+ expected_length = 1500 - 150 # Subtracting the overlap size
32
+ assert all(expected_length <= len(''.join(chunk)) < 1500
33
+ for chunk in chunks)
34
+
35
+ # Assert that the chunks contain the original text
36
+ original_text = ' '.join(sample_text)
37
+ assert all(text in original_text for chunk in chunks for text in chunk)
38
+
39
+ # Assert that the chunks do not overlap (except for the overlap size)
40
+ for i in range(len(chunks) - 1):
41
+ previous_chunk = chunks[i]
42
+ next_chunk = chunks[i + 1]
43
+ overlap = ''.join(set(previous_chunk[-150:]) & set(next_chunk[:150]))
44
+ assert len(overlap) == 150 or not overlap
words_db.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sqlite3
2
+ from loguru import logger
3
+
4
+ DB_PATH = './database/word_database.db'
5
+
6
+ class WordsDB(object):
7
+ def __init__(self):
8
+ logger.info('Initialized words database.')
9
+
10
+ def _connect_db(self):
11
+ conn = sqlite3.connect(
12
+ DB_PATH,
13
+ timeout=10,
14
+ check_same_thread=False)
15
+ cursor = conn.cursor()
16
+ cursor.execute('''
17
+ CREATE TABLE IF NOT EXISTS words (
18
+ id INTEGER PRIMARY KEY,
19
+ word TEXT NOT NULL,
20
+ definition TEXT NOT NULL
21
+ )
22
+ ''')
23
+ return conn, conn.cursor()
24
+
25
+
26
+ def add_word(self, word, definition):
27
+ self.conn, self.cursor = self._connect_db()
28
+ self.cursor.execute('INSERT INTO words (word, definition) VALUES (?, ?)', (word, definition))
29
+ self.conn.commit()
30
+ self.cursor.close()
31
+ self.conn.close()
32
+
33
+
34
+ def delete_word(self, word):
35
+ self.conn, self.cursor = self._connect_db()
36
+ self.cursor.execute('DELETE FROM words WHERE word = ?', (word,))
37
+ self.conn.commit()
38
+ self.cursor.close()
39
+ self.conn.close()
40
+
41
+ def update_word(self, word, new_definition):
42
+ self.conn, self.cursor = self._connect_db()
43
+ self.cursor.execute('UPDATE words SET definition = ? WHERE word = ?', (new_definition, word))
44
+ self.conn.commit()
45
+ self.cursor.close()
46
+ self.conn.close()
47
+
48
+ def query_word(self):
49
+ self.conn, self.cursor = self._connect_db()
50
+ self.cursor.execute('SELECT * FROM words')
51
+ res = self.cursor.fetchall()
52
+ self.cursor.close()
53
+ self.conn.close()
54
+ return res
55
+
56
+
57
+ words_db = WordsDB()
58
+
59
+ if __name__ == '__main__':
60
+
61
+ words_db.add_word('apple', '苹果')
62
+ words_db.add_word('banana', '香蕉')
63
+ words_db.update_word('banana', '新的香蕉')
64
+ result = words_db.query_word('banana')
65
+ print(result)
66
+
67
+