shaojiang committed
Commit
aa81f3d
1 Parent(s): 68d5b33

Upload app.py

Files changed (1)
app.py +212 -0
app.py ADDED
@@ -0,0 +1,212 @@
import torch
import torch.nn.functional as F
import os
from tqdm import trange
from transformers import GPT2LMHeadModel
import gradio as gr


def is_word(word):
    # A token counts as an English "word" only if every character is a lowercase Latin letter
    for item in list(word):
        if item not in 'qwertyuiopasdfghjklzxcvbnm':
            return False
    return True


def _is_chinese_char(char):
    """Checks whether `char` is a CJK character."""
    # This defines a "chinese character" as anything in the CJK Unicode block:
    #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
    #
    # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
    # despite its name. The modern Korean Hangul alphabet is a different block,
    # as are Japanese Hiragana and Katakana. Those alphabets are used to write
    # space-separated words, so they are not treated specially and are handled
    # like all of the other languages.
    cp = ord(char)
    if ((cp >= 0x4E00 and cp <= 0x9FFF) or
            (cp >= 0x3400 and cp <= 0x4DBF) or
            (cp >= 0x20000 and cp <= 0x2A6DF) or
            (cp >= 0x2A700 and cp <= 0x2B73F) or
            (cp >= 0x2B740 and cp <= 0x2B81F) or
            (cp >= 0x2B820 and cp <= 0x2CEAF) or
            (cp >= 0xF900 and cp <= 0xFAFF) or
            (cp >= 0x2F800 and cp <= 0x2FA1F)):
        return True

    return False
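# For example, _is_chinese_char('中') returns True, while _is_chinese_char('a') and
# _is_chinese_char('。') (CJK punctuation, which lies outside the listed ideograph blocks) return False.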


def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
            logits: logits distribution shape (vocabulary size)
            top_k > 0: keep only top k tokens with highest probability (top-k filtering).
            top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
        From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    assert logits.dim() == 1  # batch size 1 for now - could be updated for more but the code would be less clear
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits
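# A quick illustration with made-up logits: for logits = torch.tensor([4.0, 3.0, 2.0, 1.0]),
# top_k_top_p_filtering(logits.clone(), top_k=2) keeps only the two largest entries (4.0 and 3.0)
# and sets the rest to -inf; with top_p=0.9, the smallest set of sorted tokens whose softmax mass
# reaches 0.9 is kept and every token after it in the sorted order is set to -inf.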


def sample_sequence(model, context, length, n_ctx, tokenizer, temperature=1.0, top_k=30, top_p=0.0,
                    repetition_penalty=1.0, device='cpu'):
    context = torch.tensor(context, dtype=torch.long, device=device)
    context = context.unsqueeze(0)
    generated = context
    with torch.no_grad():
        for _ in trange(length):
            # Feed at most the last (n_ctx - 1) tokens so the input stays within the model's context window
            inputs = {'input_ids': generated[0][-(n_ctx - 1):].unsqueeze(0)}
            outputs = model(**inputs)  # Note: we could also use cached hidden states ('past'/'past_key_values') with GPT-2/Transfo-XL/XLNet
            next_token_logits = outputs[0][0, -1, :]
            # Divide the logits of already-generated tokens by the repetition penalty to discourage repeats
            for token_id in set(generated[0].tolist()):
                next_token_logits[token_id] /= repetition_penalty
            next_token_logits = next_token_logits / temperature
            next_token_logits[tokenizer.convert_tokens_to_ids('[UNK]')] = -float('Inf')
            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
            next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
            generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
    return generated.tolist()[0]
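# Example call (with already-loaded, hypothetical model/tokenizer objects):
#   ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(prompt))
#   sample_sequence(model, ids, length=50, n_ctx=model.config.n_ctx, tokenizer=tokenizer, top_k=8, device=device)
# returns a plain Python list of token ids: the prompt followed by 50 sampled continuation tokens.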


def fast_sample_sequence(model, context, length, temperature=1.0, top_k=30, top_p=0.0, device='cpu'):
    inputs = torch.LongTensor(context).view(1, -1).to(device)
    if len(context) > 1:
        # Run the prompt once (all tokens but the last) to build the key/value cache
        _, past = model(inputs[:, :-1], None)[:2]
        prev = inputs[:, -1].view(1, -1)
    else:
        past = None
        prev = inputs
    generate = [] + context
    with torch.no_grad():
        for i in trange(length):
            # Only the newest token is fed in; 'past' carries the cached hidden states.
            # (The 'past' keyword matches older transformers releases; recent versions name it 'past_key_values'.)
            output = model(prev, past=past)
            output, past = output[:2]
            output = output[-1].squeeze(0) / temperature
            filtered_logits = top_k_top_p_filtering(output, top_k=top_k, top_p=top_p)
            next_token = torch.multinomial(torch.softmax(filtered_logits, dim=-1), num_samples=1)
            generate.append(next_token.item())
            prev = next_token.view(1, 1)
    return generate
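# With a recent transformers release the same cached-generation step would look roughly like this
# (a sketch only; not used above):
#   out = model(prev, past_key_values=past, use_cache=True)
#   logits, past = out.logits, out.past_key_values
#   next_logits = logits[0, -1, :] / temperature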


# The sampling mode is selected via the --fast_pattern flag of the original command-line script
# (passed here directly as is_fast_pattern).
def generate(n_ctx, model, context, length, tokenizer, temperature=1, top_k=0, top_p=0.0,
             repetition_penalty=1.0, device='cpu', is_fast_pattern=False):
    if is_fast_pattern:
        return fast_sample_sequence(model, context, length, temperature=temperature, top_k=top_k, top_p=top_p,
                                    device=device)
    else:
        return sample_sequence(model, context, length, n_ctx, tokenizer=tokenizer, temperature=temperature,
                               top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, device=device)


def smp_generate(pre_str):
    # Local module shipped alongside this app; provides a BERT-style vocab-file tokenizer
    from tokenizations import tokenization_bert

    os.environ["CUDA_VISIBLE_DEVICES"] = '0,1,2,3'  # select which GPUs the process is allowed to use
    length = 500
    batch_size = 1
    nsamples = 1
    temperature = 1
    topk = 8
    topp = 0
    repetition_penalty = 1.0
    model_path = 'pretrained'
    tokenizer_path = 'cache/vocab.txt'
    save_samples = False       # kept from the original script; samples are not written to disk here
    save_samples_path = '.'
    fast_pattern = True
    prefix = pre_str

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Note: the tokenizer and model are (re)loaded on every request; loading them once at module scope would be faster.
    tokenizer = tokenization_bert.BertTokenizer(vocab_file=tokenizer_path)
    model = GPT2LMHeadModel.from_pretrained(model_path)
    model.to(device)
    model.eval()

    n_ctx = model.config.n_ctx

    if length == -1:
        length = model.config.n_ctx

    while True:  # kept from the original interactive script; the function returns after the first sample
        raw_text = prefix
        context_tokens = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(raw_text))
        generated = 0
        for _ in range(nsamples // batch_size):
            out = generate(
                n_ctx=n_ctx,
                model=model,
                context=context_tokens,
                length=length,
                is_fast_pattern=fast_pattern, tokenizer=tokenizer,
                temperature=temperature, top_k=topk, top_p=topp,
                repetition_penalty=repetition_penalty, device=device
            )
            for i in range(batch_size):
                generated += 1
                text = tokenizer.convert_ids_to_tokens(out)
                for i, item in enumerate(text[:-1]):  # make sure adjacent English words keep a space between them
                    if is_word(item) and is_word(text[i + 1]):
                        text[i] = item + ' '
                for i, item in enumerate(text):
                    if item == '[MASK]':
                        text[i] = ''
                    elif item == '[CLS]':
                        text[i] = '\n\n'
                    elif item == '[SEP]':
                        text[i] = '\n'
                info = "=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40 + "\n"  # console banner from the original script (unused here)
                text = ''.join(text).replace('##', '').strip()
                return text
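# Typical use (with a hypothetical prefix): smp_generate('秦王') tokenizes the prefix, samples about 500
# continuation tokens with top-k = 8, and returns the decoded text, with '[CLS]' and '[SEP]' markers
# turned into blank lines and line breaks respectively.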


def format_text(text):
    # Helper that wraps generated text as HTML; not wired into the Gradio interface below
    return "<p>" + text.replace("\n", "<br>") + "</p>"


input_textbox = gr.inputs.Textbox(label="输入前缀")       # label: "input prefix"  (newer Gradio uses gr.Textbox)
output_textbox = gr.outputs.Textbox(label="生成文言文")   # label: "generated classical Chinese"

# Custom HTML and CSS layout. gr.Interface does not accept a raw HTML template, so this string is
# currently unused and kept only as a reference for a possible custom (e.g. gr.Blocks) layout.
html_content = """
<div style="display: flex; flex-direction: column-reverse;">
    <div style="flex-grow: 1; overflow-y: auto;">
        {output}
    </div>
    <div style="margin-top: 10px;">
        {input}
    </div>
</div>
"""

# Build the web UI: one text box in, one text box out.
iface = gr.Interface(fn=smp_generate, inputs=input_textbox, outputs=output_textbox,
                     title="文言文生成器")  # title: "Classical Chinese generator"

iface.launch()
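# Running this file directly (python app.py) assumes the 'pretrained' model directory, 'cache/vocab.txt'
# and the local 'tokenizations' package sit next to it, with torch, transformers, tqdm and gradio installed;
# iface.launch() then serves the demo locally (a Hugging Face Space runs app.py the same way).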