BAAI /

shunxing1234 committed on
Commit 21f1ae7 · 1 Parent(s): 6980f4b

Delete chat_test_NBCE.py

Files changed (1):
  1. chat_test_NBCE.py +0 -132
chat_test_NBCE.py DELETED
@@ -1,132 +0,0 @@
- #! -*- coding: utf-8 -*-
- # Naive Bayes-based Context Extension (NBCE)
- # Use naive Bayes to extend the context length an LLM can handle
- # Reference: https://kexue.fm/archives/9617
- # Tested with Torch 2.0
-
- import json
- import torch
- from transformers import AutoTokenizer
- from transformers import AquilaForCausalLM
- from transformers import TopPLogitsWarper, LogitsProcessorList
- import pdb
-
- # Path to the Aquila checkpoint. NOTE: model_path is never defined in the original
- # script; 'BAAI/AquilaChat-7B' is an assumed value, adjust it to your checkpoint.
- model_path = 'BAAI/AquilaChat-7B'
-
- # Load the tokenizer
- tokenizer = AutoTokenizer.from_pretrained(model_path)
- tokenizer.padding_side = 'left'
- tokenizer.pad_token = tokenizer.unk_token
-
- # Load the Aquila model
- model = AquilaForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16)
- device = torch.device('cuda')
- model.to(device)
- # Load example contexts
- from cyg_conversation import default_conversation
-
- conv = default_conversation.copy()
- contexts = json.load(open('code_text_2.json'))
-
- question = "请解释这段程序的功能:"  # "Please explain what this program does:"
- batch = []
- conv.append_message(conv.roles[0], question)
- conv.append_message(conv.roles[1], None)
- batch.append(conv.get_prompt())
- # Concatenate each context with the question
- for ci, context in enumerate(contexts):
-     conv1 = default_conversation.copy()
-     conv1.append_message(conv.roles[0], context + question)
-     conv1.append_message(conv.roles[1], None)
-     batch.append(conv1.get_prompt())
- print('Context length distribution:', [len(text) for text in batch])
- print('Total context length:', sum([len(text) for text in batch]))
-
- # Top-p truncation
- processors = LogitsProcessorList()
- processors.append(TopPLogitsWarper(0.95))
-
- # Copied from https://github.com/bojone/NBCE/blob/main/test.py#L51-L106
- @torch.inference_mode()
- def generate(max_tokens):
-     """Naive Bayes-based Context Extension demo code
-     """
-     inputs = tokenizer(batch, padding='longest', return_tensors='pt').to(device)
-     input_ids = inputs.input_ids
-     attention_mask = inputs.attention_mask
-
-     print('input_ids', input_ids.shape)
-     past_key_values = None
-     n = input_ids.shape[0]
-
-     for i in range(max_tokens):
-         # Model forward pass
-         outputs = model(input_ids=input_ids,
-                         attention_mask=attention_mask,
-                         return_dict=True,
-                         use_cache=True,
-                         past_key_values=past_key_values
-                         )
-         past_key_values = outputs.past_key_values
-
-         # ===== core code begins =====
-         beta, eta = 0.25, 0.1
-         logits = outputs.logits[:, -1]
-         logits = logits - logits.logsumexp(dim=-1, keepdims=True)
-         logits = processors(input_ids, logits)
-         entropy = -(logits.exp() * logits.clip(-100, 0)).sum(dim=-1)
-         if i > 0:
-             entropy[k] -= eta
-         k = entropy[1:].argmin() + 1
-         logits_max = logits[k]
-         logits_uncond = logits[0]
-         logits_merged = (1 + beta) * logits_max - beta * logits_uncond
-         logits = torch.where(logits_uncond > -100, logits_merged, logits_max)
-         # ===== core code ends =====
-
-         # Build the distribution and sample
-         # tau = 1 is standard random sampling; tau -> 0 approaches greedy search
-         # For simplicity, top-k/top-p truncation is not implemented at this step
-         tau = 0.01
-         probas = torch.nn.functional.softmax(logits[None] / tau, dim=-1)
-         next_tokens = torch.multinomial(probas, num_samples=1).squeeze(1)
-         if next_tokens[0] == tokenizer.eos_token_id:
-             break
-
-         ret = tokenizer.batch_decode(next_tokens)
-         print(ret[0], flush=True, end='')
-
-         # prepare for next iteration
-         input_ids = next_tokens.unsqueeze(-1).tile(n, 1)
-         attention_mask = torch.cat([attention_mask, torch.ones(n, 1, dtype=torch.long, device=device)], dim=-1)
-
-
- if __name__ == '__main__':
-     generate(1000)
-
-
- """
- ========= Sample output (for reference) =========
-
- 1. What share of the National Grid Corporation of the Philippines does China hold?
- Answer: State Grid Corporation of China holds a 40% stake in the National Grid Corporation of the Philippines.
-
- 2. How many employees does LinkedIn plan to lay off?
- Answer: LinkedIn plans to lay off 716 employees.
-
- 3. How much did Gilead pay to acquire Pharmasset?
- Answer: Gilead acquired Pharmasset for 11 billion US dollars.
-
- 4. In which year was the hepatitis C drug Sovaldi launched?
- Answer: The hepatitis C drug Sovaldi was launched in 2013.
-
- 5. Where will the China-Central Asia Summit be held, and who will chair it?
- Answer: The China-Central Asia Summit will be held in Xi'an, Shaanxi Province, chaired by President Xi Jinping.
-
- 6. Which performer was placed under investigation for insulting the People's Army?
- Answer: Li Haoshi was placed under investigation over remarks in his performance that insulted the People's Army.
-
- 7. Which project claimed a "tank-capable" road over water?
- Answer: The road-over-water project in Enshi, Hubei, which claimed it could carry tanks.
-
- 8. If you were the CEO of MSD (Merck), what would your top priority be?
- Answer: If I were the CEO of MSD, my top priority would be to make the core business more solid and achieve better growth through drug combinations.
- """
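For readers tracing the "core code" block above (the (1 + beta) * logits_max - beta * logits_uncond line), the following sketch of the naive Bayes derivation behind NBCE may help. It follows the linked post (https://kexue.fm/archives/9617) as I read it; the pooling choice and the role of beta are not spelled out in the deleted script itself.

Assuming the contexts S_1, ..., S_n are conditionally independent given the continuation T, Bayes' rule gives

    p(T \mid S_1,\dots,S_n) \;\propto\; p(T)^{1-n} \prod_{k=1}^{n} p(T \mid S_k),

so in log space

    \log p(T \mid S_1,\dots,S_n) = \sum_{k=1}^{n} \log p(T \mid S_k) - (n-1)\,\log p(T) + \text{const}.

NBCE replaces the sum with a pooled term and treats the weight as a hyperparameter:

    \log p(T \mid S_1,\dots,S_n) \approx (\beta+1)\,\overline{\log p(T \mid S)} - \beta\,\log p(T),

where the pooling here selects the context whose next-token distribution has the lowest entropy (logits_max), \log p(T) is the context-free row of the batch (logits_uncond), and the script uses \beta = 0.25 rather than n-1.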
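The merging step can also be exercised in isolation. Below is a minimal, self-contained sketch on random logits, assuming only torch; the function name merge_nbce and the dummy shapes are illustrative and not part of the original script, and the top-p truncation applied in the full script is omitted here.

import torch

def merge_nbce(logits, beta=0.25):
    # logits: (n, vocab) next-token logits; row 0 comes from the context-free prompt,
    # rows 1..n-1 from the per-context prompts. Returns merged (vocab,) log-scores.
    logp = logits - logits.logsumexp(dim=-1, keepdim=True)      # normalize to log-probs
    entropy = -(logp.exp() * logp.clamp(min=-100)).sum(dim=-1)  # per-row prediction entropy
    k = entropy[1:].argmin() + 1                                 # most confident context
    logp_max, logp_uncond = logp[k], logp[0]
    merged = (1 + beta) * logp_max - beta * logp_uncond
    # In the full script, tokens truncated by the top-p warper have -inf in the
    # context-free row; fall back to the pooled context's score there. Without
    # truncation (as in this dummy example) the condition is always true.
    return torch.where(logp_uncond > -100, merged, logp_max)

if __name__ == '__main__':
    torch.manual_seed(0)
    dummy = torch.randn(4, 10)        # 1 context-free row + 3 context rows, vocab size 10
    print(merge_nbce(dummy).shape)    # torch.Size([10])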