BAAI
/

shunxing1234 committed on
Commit
9b97fea
1 Parent(s): 8c50400

Update README_zh.md

Browse files
Files changed (1) hide show
  1. README_zh.md +0 -103
README_zh.md CHANGED
@@ -63,109 +63,6 @@ with torch.no_grad():
63
  print(out)
64
  ```
65
 
66
- 利用[NBCE](https://github.com/bojone/NBCE/tree/main)进行推理
67
-
68
- ```python
69
- import json
70
- import torch
71
- from transformers import AutoTokenizer
72
- from transformers import AutoModelForCausalLM
73
- from transformers import TopPLogitsWarper, LogitsProcessorList
74
- import pdb
75
-
76
- # 加载tokenizer
77
- tokenizer = AutoTokenizer.from_pretrained(model_path)
78
- tokenizer.padding_side = 'left'
79
- tokenizer.pad_token = tokenizer.unk_token
80
-
81
- # 加载Aquila模型
82
- model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16)
83
- device = torch.device('cuda')
84
- model.to(device)
85
- # 加载示例Context
86
- from cyg_conversation import default_conversation
87
-
88
- conv = default_conversation.copy()
89
- contexts = json.load(open('code_text_2.json'))
90
-
91
- question = "请解释这段程序的功能:"
92
- batch = []
93
- conv.append_message(conv.roles[0], question)
94
- conv.append_message(conv.roles[1], None)
95
- batch.append(conv.get_prompt())
96
- # 拼接context和question
97
- for ci,context in enumerate(contexts):
98
- conv1 = default_conversation.copy()
99
- conv1.append_message(conv.roles[0], context+question)
100
- conv1.append_message(conv.roles[1], None)
101
- batch.append(conv1.get_prompt())
102
- print('Context长度分布:', [len(text) for text in batch])
103
- print('Context总长度:', sum([len(text) for text in batch]))
104
-
105
- # Top-P截断
106
- processors = LogitsProcessorList()
107
- processors.append(TopPLogitsWarper(0.95))
108
-
109
- # Copied from https://github.com/bojone/NBCE/blob/main/test.py#L51-L106
110
- @torch.inference_mode()
111
- def generate(max_tokens):
112
- """Naive Bayes-based Context Extension 演示代码
113
- """
114
- inputs = tokenizer(batch, padding='longest', return_tensors='pt').to(device)
115
- input_ids = inputs.input_ids
116
- attention_mask = inputs.attention_mask
117
-
118
- print('input_ids', input_ids.shape)
119
- past_key_values = None
120
- n = input_ids.shape[0]
121
-
122
- for i in range(max_tokens):
123
- # 模型输出
124
- outputs = model(input_ids=input_ids,
125
- attention_mask=attention_mask,
126
- return_dict=True,
127
- use_cache=True,
128
- past_key_values=past_key_values
129
- )
130
- past_key_values = outputs.past_key_values
131
-
132
- # ===== 核心代码开始 =====
133
- beta, eta = 0.25, 0.1
134
- logits = outputs.logits[:, -1]
135
- logits = logits - logits.logsumexp(dim=-1, keepdims=True)
136
- logits = processors(input_ids, logits)
137
- entropy = -(logits.exp() * logits.clip(-100, 0)).sum(dim=-1)
138
- if i > 0:
139
- entropy[k] -= eta
140
- k = entropy[1:].argmin() + 1
141
- logits_max = logits[k]
142
- logits_uncond = logits[0]
143
- logits_merged = (1 + beta) * logits_max - beta * logits_uncond
144
- logits = torch.where(logits_uncond > -100, logits_merged, logits_max)
145
- # ===== 核心代码结束 =====
146
-
147
- # 构建分布,采样
148
- # tau = 1是标准的随机采样,tau->0则是贪心搜索
149
- # 简单起见,这里没有实现topk、topp截断
150
- tau = 0.01
151
- probas = torch.nn.functional.softmax(logits[None] / tau , dim=-1)
152
- next_tokens = torch.multinomial(probas, num_samples=1).squeeze(1)
153
- if next_tokens[0] == tokenizer.eos_token_id:
154
- break
155
-
156
- ret = tokenizer.batch_decode(next_tokens)
157
- print(ret[0], flush=True, end='')
158
-
159
- # prepare for next iteration
160
- input_ids = next_tokens.unsqueeze(-1).tile(n, 1)
161
- attention_mask = torch.cat([attention_mask, torch.ones(n, 1, dtype=torch.long, device=device)], dim=-1)
162
-
163
-
164
- if __name__ == '__main__':
165
- generate(1000)
166
-
167
- ```
168
-
169
 
170
  ## 证书/License
171
 
 
63
  print(out)
64
  ```
65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
  ## 证书/License
68