Randolphzeng
committed on
Commit
•
cf79171
1
Parent(s):
6188ee5
Update README.md
Browse files
README.md
CHANGED
@@ -26,70 +26,50 @@ A deep VAE model pretrained on Wudao dataset. Both encoder and decoder are based
|
|
26 |
|
27 |
## 模型信息 Model Information
|
28 |
|
29 |
-
|
30 |
-
|
|
|
31 |
|
32 |
|
33 |
## 使用 Usage
|
34 |
|
35 |
```python
|
36 |
# Checkout the latest Fengshenbang-LM directory and run following script under Fengshenbang-LM root directory
|
37 |
-
|
38 |
import torch
|
39 |
-
import argparse
|
40 |
from torch.nn.utils.rnn import pad_sequence
|
41 |
-
from fengshen.models.deepVAE.
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
max_length = 256
|
75 |
-
top_p = 0.5
|
76 |
-
top_k = 0
|
77 |
-
temperature = .7
|
78 |
-
repetition_penalty = 1.0
|
79 |
-
sample = False
|
80 |
-
device = 0
|
81 |
-
model = model.eval()
|
82 |
-
model = model.to(device)
|
83 |
-
|
84 |
-
outputs = model.inference(inputs.to(device), top_p=top_p, top_k=top_k, max_length=max_length, sample=sample,
|
85 |
-
temperature=temperature, repetition_penalty=repetition_penalty)
|
86 |
-
|
87 |
-
for gen_sent, orig_sent in zip(outputs, inputs):
|
88 |
-
print('orig_sent:', tokenizer.decode(orig_sent).replace(' ', ''))
|
89 |
-
print('gen_sent:', tokenizer.decode(gen_sent).replace(' ', ''))
|
90 |
-
print("-"*20)
|
91 |
-
|
92 |
-
|
93 |
|
94 |
|
95 |
```
|
|
|
26 |
|
27 |
## 模型信息 Model Information
|
28 |
|
29 |
+
参考论文 Reference Paper:[Fuse It More Deeply! A Variational Transformer with Layer-Wise Latent Variable Inference for Text Generation](https://arxiv.org/abs/2207.06130)
|
30 |
+
本模型使用了Della论文里的循环潜在向量架构,但对于解码器生成并未采用原论文的low-rank-tensor-product来进行信息融合,而是使用了简单的线性变换后逐位逐词添加的方式。该方式对于开放域数据集的预训练稳定性有较大正向作用。
|
31 |
+
Note that although we adopted the layer-wise recurrent latent variable structure from the paper, we did not use the low-rank-tensor-product to fuse the latent vectors into the decoder hidden states. Instead, we applied a simple linear transformation to the latent vectors and then added them to the hidden states independently. This approach noticeably improves pre-training stability on open-domain datasets.
|
32 |
|
33 |
|
34 |
## 使用 Usage
|
35 |
|
36 |
```python
|
37 |
# Checkout the latest Fengshenbang-LM directory and run following script under Fengshenbang-LM root directory
|
38 |
+
|
39 |
import torch
|
|
|
40 |
from torch.nn.utils.rnn import pad_sequence
|
41 |
+
from fengshen.models.deepVAE.deep_vae import Della
|
42 |
+
from transformers.models.bert.tokenization_bert import BertTokenizer
|
43 |
+
|
44 |
+
tokenizer = BertTokenizer.from_pretrained("IDEA-CCNL/Randeng-DELLA-226M-Chinese")
|
45 |
+
vae_model = Della.from_pretrained("IDEA-CCNL/Randeng-DELLA-226M-Chinese")
|
46 |
+
|
47 |
+
special_tokens_dict = {'bos_token': '<BOS>', 'eos_token': '<EOS>'}
|
48 |
+
tokenizer.add_special_tokens(special_tokens_dict)
|
49 |
+
sentence = "本模型是在通用数据集下预训练的VAE模型,如要获得最佳效果请在特定领域微调后使用。"
|
50 |
+
tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(sentence))
|
51 |
+
decoder_target = [tokenizer.bos_token_id] + tokenized_text + [tokenizer.eos_token_id]
|
52 |
+
inputs = []
|
53 |
+
inputs.append(torch.tensor(decoder_target, dtype=torch.long))
|
54 |
+
inputs = pad_sequence(inputs, batch_first=True, padding_value=0)
|
55 |
+
|
56 |
+
max_length = 256
|
57 |
+
top_p = 0.5
|
58 |
+
top_k = 0
|
59 |
+
temperature = .7
|
60 |
+
repetition_penalty = 1.0
|
61 |
+
sample = False
|
62 |
+
device = 0
|
63 |
+
model = vae_model.eval()
|
64 |
+
model = model.to(device)
|
65 |
+
|
66 |
+
outputs = model.model.inference(inputs.to(device), top_p=top_p, top_k=top_k, max_length=max_length, sample=sample,
|
67 |
+
temperature=temperature, repetition_penalty=repetition_penalty)
|
68 |
+
|
69 |
+
for gen_sent, orig_sent in zip(outputs, inputs):
|
70 |
+
print('orig_sent:', tokenizer.decode(orig_sent).replace(' ', ''))
|
71 |
+
print('gen_sent:', tokenizer.decode(gen_sent).replace(' ', ''))
|
72 |
+
print("-"*20)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
|
74 |
|
75 |
```
|