xu-song commited on
Commit
74a60bc
0 Parent(s):

Duplicate from eson/bert-perplexity-debug

Browse files
Files changed (6) hide show
  1. .gitattributes +34 -0
  2. .gitignore +16 -0
  3. README.md +13 -0
  4. app.py +58 -0
  5. perplexity.py +57 -0
  6. requirements.txt +2 -0
.gitattributes ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __pycache__/
2
+ *.py[cod]
3
+ *$py.class
4
+
5
+ # C extensions
6
+ *.so
7
+
8
+ # Distribution / packaging
9
+ flagged/
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ eggs/
15
+ .eggs/
16
+ .idea/
README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Bert Perplexity
3
+ emoji: 💩
4
+ colorFrom: blue
5
+ colorTo: green
6
+ sdk: gradio
7
+ sdk_version: 3.18.0
8
+ app_file: app.py
9
+ pinned: false
10
+ duplicated_from: eson/bert-perplexity-debug
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # author: xusong
3
+ # time: 2022/8/23 16:06
4
+
5
+ from perplexity import PerplexityPipeline
6
+ from transformers import BertTokenizer, BertForMaskedLM
7
+ import gradio as gr
8
+ import time
9
+
10
+ en_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
11
+ en_model = BertForMaskedLM.from_pretrained("bert-base-uncased")
12
+ en_pipeline = PerplexityPipeline(model=en_model, tokenizer=en_tokenizer)
13
+
14
+ zh_tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
15
+ zh_model = BertForMaskedLM.from_pretrained("bert-base-chinese")
16
+ zh_pipeline = PerplexityPipeline(model=zh_model, tokenizer=zh_tokenizer)
17
+
18
+
19
+ def ppl(model_version, text):
20
+ print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()), model_version, text)
21
+ if model_version == "bert-base-uncased":
22
+ result = en_pipeline(text)
23
+ else:
24
+ result = zh_pipeline(text)
25
+ return result["ppl"], result
26
+
27
+
28
+ examples = [
29
+ ["bert-base-uncased", "New York City is located in the northeastern United States."],
30
+ ["bert-base-uncased", "New York City is located in the western United States."],
31
+ ["bert-base-chinese", "少先队员因该为老人让坐"],
32
+ ]
33
+
34
+ css = "#json-container {height:: 400px; overflow: auto !important}"
35
+
36
+ corr_iface = gr.Interface(
37
+ fn=ppl,
38
+ inputs=[
39
+ # gr.Dropdown(["bert-base-uncased", "bert-base-chinese"], value="bert-base-uncased"), # TODO 调整大小和位置
40
+ gr.Radio(
41
+ ["bert-base-uncased", "bert-base-chinese"],
42
+ value="bert-base-uncased"
43
+ ),
44
+ gr.Textbox(
45
+ value="New York City is located in the northeastern United States.",
46
+ label="input text"
47
+ )],
48
+ outputs=[
49
+ gr.Textbox(label="Perplexity"),
50
+ gr.JSON(label="Tokens", elem_id="json-container")],
51
+ examples=examples,
52
+ title="BERT as Language Model",
53
+ description='',
54
+ css=css
55
+ )
56
+
57
+ if __name__ == "__main__":
58
+ corr_iface.launch()
perplexity.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # author: xusong
3
+ # time: 2022/8/22 12:06
4
+
5
+ import numpy as np
6
+ import torch
7
+ from transformers import FillMaskPipeline
8
+
9
+
10
+ class PerplexityPipeline(FillMaskPipeline):
11
+
12
+ def create_sequential_mask(self, input_data, mask_count=1):
13
+ _, seq_length = input_data["input_ids"].shape
14
+ mask_count = seq_length - 2
15
+
16
+ input_ids = input_data["input_ids"]
17
+
18
+ new_input_ids = torch.repeat_interleave(input_data["input_ids"], repeats=mask_count, dim=0)
19
+ token_type_ids = torch.repeat_interleave(input_data["token_type_ids"], repeats=mask_count, dim=0)
20
+ attention_mask = torch.repeat_interleave(input_data["attention_mask"], repeats=mask_count, dim=0)
21
+ masked_lm_labels = []
22
+ masked_lm_positions = list(range(1, mask_count + 1))
23
+ for i in masked_lm_positions:
24
+ new_input_ids[i - 1][i] = self.tokenizer.mask_token_id
25
+ masked_lm_labels.append(input_ids[0][i].item())
26
+ new_data = {"input_ids": new_input_ids, "token_type_ids": token_type_ids, "attention_mask": attention_mask}
27
+ return new_data, masked_lm_positions, masked_lm_labels
28
+
29
+ def __call__(self, input_text, *args, **kwargs):
30
+ """
31
+ Compute perplexity for given sentence.
32
+ """
33
+ if not isinstance(input_text, str):
34
+ return None
35
+ # 1. create sequential mask
36
+ model_inputs = self.tokenizer(input_text, return_tensors='pt')
37
+ new_data, masked_lm_positions, masked_lm_labels = self.create_sequential_mask(model_inputs.data)
38
+ model_inputs.data = new_data
39
+ labels = torch.tensor(masked_lm_labels)
40
+
41
+ # 2. predict
42
+ model_outputs = self.model(**model_inputs)
43
+
44
+ # 3. compute perplexity
45
+ sentence = {}
46
+ tokens = []
47
+ for i in range(len(labels)):
48
+ model_outputs_i = {}
49
+ model_outputs_i["input_ids"] = model_inputs["input_ids"][i:i + 1]
50
+ model_outputs_i["logits"] = model_outputs["logits"][i:i + 1]
51
+ outputs = self.postprocess(model_outputs_i, target_ids=labels[i:i + 1])
52
+ print(outputs)
53
+ tokens.append({"token": outputs[0]["token_str"],
54
+ "prob": outputs[0]["score"]})
55
+ sentence["tokens"] = tokens
56
+ sentence["ppl"] = float(np.exp(- sum(np.log(token["prob"]) for token in tokens) / len(tokens)))
57
+ return sentence
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ transformers>=4.21.1
2
+ torch