Commit
•
fd8a539
1
Parent(s):
a11f2aa
Update README.md
Browse files
README.md
CHANGED
@@ -26,4 +26,110 @@ datasets:
|
|
26 |
|
27 |
## 模型使用
|
28 |
```python
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
```
|
|
|
26 |
|
27 |
## 模型使用
|
28 |
```python
|
29 |
+
import time
|
30 |
+
from transformers import GenerationConfig
|
31 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig, StoppingCriteria, StoppingCriteriaList
|
32 |
+
import os
|
33 |
+
import json
|
34 |
+
import logging
|
35 |
+
import torch
|
36 |
+
|
37 |
+
log_path = 'your_log_path'
|
38 |
+
logging.basicConfig(filename=log_path, level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
39 |
+
|
40 |
+
logging.info('This is a log message.')
|
41 |
+
|
42 |
+
model_path = "/LiteAI/Hare-1.1B-Tool"
|
43 |
+
tokenizer = AutoTokenizer.from_pretrained(save_path)
|
44 |
+
model = AutoModelForCausalLM.from_pretrained(save_path)
|
45 |
+
|
46 |
+
class MyStoppingCriteria(StoppingCriteria):
|
47 |
+
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
48 |
+
keyword = tokenizer.decode(input_ids[0][-1])
|
49 |
+
return keyword in ["<api_end>"]
|
50 |
+
|
51 |
+
|
52 |
+
def chat(
|
53 |
+
messages,
|
54 |
+
model,
|
55 |
+
tokenizer,
|
56 |
+
generate_config=None,
|
57 |
+
max_length=512,
|
58 |
+
max_new_tokens=256,
|
59 |
+
):
|
60 |
+
if generate_config is None:
|
61 |
+
generate_config = GenerationConfig(
|
62 |
+
do_sample=False,
|
63 |
+
max_length=max_length,
|
64 |
+
max_new_tokens=max_new_tokens,
|
65 |
+
eos_token_id=32001,
|
66 |
+
)
|
67 |
+
|
68 |
+
if messages[0]["role"] == "system":
|
69 |
+
system = messages[0]["content"]
|
70 |
+
messages = messages[0:]
|
71 |
+
else:
|
72 |
+
system = "You are a helpful assistant."
|
73 |
+
|
74 |
+
n_token = max_length
|
75 |
+
system = "<round_start>system\n{}<round_end>\n".format(system)
|
76 |
+
system_token = tokenizer.encode(system, add_special_tokens=False)
|
77 |
+
n_token -= len(system_token)
|
78 |
+
|
79 |
+
query = messages[-1]["content"]
|
80 |
+
query = "<round_start>user\n{}<round_end>\n<round_start>assistant\n".format(query)
|
81 |
+
query_token = tokenizer.encode(query, add_special_tokens=False)
|
82 |
+
n_token -= len(query_token)
|
83 |
+
|
84 |
+
messages = messages[:-1]
|
85 |
+
conversations = []
|
86 |
+
for ids in range(len(messages)-1, 0, -2):
|
87 |
+
user = messages[ids - 1]["content"]
|
88 |
+
assistant = messages[ids]["content"]
|
89 |
+
|
90 |
+
round = "<round_start>user\n{}<round_end>\n<round_start>assistant\n{}<round_end>\n".format(user, assistant)
|
91 |
+
round_token = tokenizer.encode(round, add_special_tokens=False)
|
92 |
+
|
93 |
+
if n_token - len(round_token) > 0:
|
94 |
+
conversations = [round] + conversations
|
95 |
+
else:
|
96 |
+
break
|
97 |
+
|
98 |
+
prompt = system + "".join(conversations) + query
|
99 |
+
prompt_token = tokenizer(prompt, add_special_tokens=False, return_tensors="pt")
|
100 |
+
prompt_token.to(model.device)
|
101 |
+
|
102 |
+
|
103 |
+
response = model.generate(
|
104 |
+
generation_config=generate_config,
|
105 |
+
**prompt_token
|
106 |
+
)
|
107 |
+
|
108 |
+
output_tokens = response[0].cpu().numpy()[prompt_token.input_ids.size()[1]:]
|
109 |
+
output_string = tokenizer.decode(output_tokens, skip_special_tokens=True).replace("<round_end>", "")
|
110 |
+
return output_string, prompt
|
111 |
+
|
112 |
+
# ======================
|
113 |
+
# main
|
114 |
+
# ======================
|
115 |
+
|
116 |
+
test_query_path = "/home/sxw/sft_exper/dataset/query_to_test"
|
117 |
+
for file in os.listdir(test_query_path):
|
118 |
+
file_pth = os.path.join(test_query_path, file)
|
119 |
+
print(file_pth)
|
120 |
+
logging.info(file_pth)
|
121 |
+
with open(file_pth, 'r') as f:
|
122 |
+
for line in f:
|
123 |
+
data = json.loads(line)
|
124 |
+
# print(data)
|
125 |
+
query = data["human"]
|
126 |
+
messages = [
|
127 |
+
{"role": "system", "content": "Below is the query from the users, you need make full sense of user's intention based on the content of the sentence, then call the correct function and generate the parameters of the calling function."},
|
128 |
+
{"role": "user", "content": query}
|
129 |
+
]
|
130 |
+
response, input_prompt = chat(messages=messages, model=model, tokenizer=tokenizer)
|
131 |
+
|
132 |
+
logging.info(query)
|
133 |
+
logging.info(data["assistant"])
|
134 |
+
logging.info(response)
|
135 |
```
|