LiteAI-Team committed on
Commit
fd8a539
1 Parent(s): a11f2aa

Update README.md

Files changed (1): README.md (+106, -0)
README.md CHANGED
@@ -26,4 +26,110 @@ datasets:

## Model Usage
```python
+ import os
+ import json
+ import logging
+
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig, StoppingCriteria, StoppingCriteriaList
+
+ log_path = 'your_log_path'
+ logging.basicConfig(filename=log_path, level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+
+ logging.info('This is a log message.')
+
+ model_path = "LiteAI/Hare-1.1B-Tool"  # Hub id, or a local checkpoint directory
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
+ model = AutoModelForCausalLM.from_pretrained(model_path)
+
+ # Stop generation as soon as the last generated token decodes to the tool-call terminator.
+ class MyStoppingCriteria(StoppingCriteria):
+     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+         keyword = tokenizer.decode(input_ids[0][-1])
+         return keyword == "<api_end>"
+
+
+ def chat(
+     messages,
+     model,
+     tokenizer,
+     generate_config=None,
+     max_length=512,
+     max_new_tokens=256,
+ ):
+     if generate_config is None:
+         generate_config = GenerationConfig(
+             do_sample=False,
+             max_length=max_length,
+             max_new_tokens=max_new_tokens,
+             eos_token_id=32001,
+         )
+
+     if messages[0]["role"] == "system":
+         system = messages[0]["content"]
+         messages = messages[1:]
+     else:
+         system = "You are a helpful assistant."
+
+     # Token budget: the system prompt and the current query are always kept.
+     n_token = max_length
+     system = "<round_start>system\n{}<round_end>\n".format(system)
+     system_token = tokenizer.encode(system, add_special_tokens=False)
+     n_token -= len(system_token)
+
+     query = messages[-1]["content"]
+     query = "<round_start>user\n{}<round_end>\n<round_start>assistant\n".format(query)
+     query_token = tokenizer.encode(query, add_special_tokens=False)
+     n_token -= len(query_token)
+
+     # Prepend history rounds, newest first, while they still fit in the budget.
+     messages = messages[:-1]
+     conversations = []
+     for ids in range(len(messages) - 1, 0, -2):
+         user = messages[ids - 1]["content"]
+         assistant = messages[ids]["content"]
+
+         round_text = "<round_start>user\n{}<round_end>\n<round_start>assistant\n{}<round_end>\n".format(user, assistant)
+         round_token = tokenizer.encode(round_text, add_special_tokens=False)
+
+         if n_token - len(round_token) > 0:
+             conversations = [round_text] + conversations
+             n_token -= len(round_token)
+         else:
+             break
+
+     prompt = system + "".join(conversations) + query
+     prompt_token = tokenizer(prompt, add_special_tokens=False, return_tensors="pt")
+     prompt_token = prompt_token.to(model.device)
+
+     response = model.generate(
+         generation_config=generate_config,
+         stopping_criteria=StoppingCriteriaList([MyStoppingCriteria()]),
+         **prompt_token
+     )
+
+     # Keep only the newly generated tokens, then strip round markers.
+     output_tokens = response[0].cpu().numpy()[prompt_token.input_ids.size()[1]:]
+     output_string = tokenizer.decode(output_tokens, skip_special_tokens=True).replace("<round_end>", "")
+     return output_string, prompt
+
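+ # For reference, the assembled prompt has this shape (illustrative, one history round shown):
+ #   <round_start>system\n{system}<round_end>\n
+ #   <round_start>user\n{past user}<round_end>\n<round_start>assistant\n{past assistant}<round_end>\n
+ #   <round_start>user\n{query}<round_end>\n<round_start>assistant\n
+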
+ # ======================
+ # main
+ # ======================
+
+ # Each test file is JSONL with one {"human": ..., "assistant": ...} record per line.
+ test_query_path = "your_test_query_dir"
+ for file in os.listdir(test_query_path):
+     file_pth = os.path.join(test_query_path, file)
+     print(file_pth)
+     logging.info(file_pth)
+     with open(file_pth, 'r') as f:
+         for line in f:
+             data = json.loads(line)
+             query = data["human"]
+             messages = [
+                 {"role": "system", "content": "Below is the query from the users, you need make full sense of user's intention based on the content of the sentence, then call the correct function and generate the parameters of the calling function."},
+                 {"role": "user", "content": query}
+             ]
+             response, input_prompt = chat(messages=messages, model=model, tokenizer=tokenizer)
+
+             logging.info(query)
+             logging.info(data["assistant"])
+             logging.info(response)
```
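
For a quick sanity check without a directory of test files, `chat` can also be called with a single query. The sketch below reuses `chat`, `model`, and `tokenizer` from the block above; the example query is illustrative and not from the model card:

```python
# Minimal smoke test for the chat() helper defined above.
# The query string is a hypothetical example; substitute any tool-use request.
messages = [
    {"role": "system", "content": "Below is the query from the users, you need make full sense of user's intention based on the content of the sentence, then call the correct function and generate the parameters of the calling function."},
    {"role": "user", "content": "Set an alarm for 7:00 tomorrow morning."},
]
response, input_prompt = chat(messages=messages, model=model, tokenizer=tokenizer)
print(response)  # the model is expected to emit a function call terminated by <api_end>
```

Since the default `GenerationConfig` uses `do_sample=False`, the output is deterministic for a given query.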