dsvilarko commited on
Commit
c4ebaf8
1 Parent(s): 7fa27b9

Initial commit

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. README.md +5 -5
  2. app.py +119 -0
  3. checkpoint-1464/config.json +20 -0
  4. checkpoint-1464/optimizer.pt +3 -0
  5. checkpoint-1464/pytorch_model.bin +3 -0
  6. checkpoint-1464/rng_state.pth +3 -0
  7. checkpoint-1464/scheduler.pt +3 -0
  8. checkpoint-1464/special_tokens_map.json +1 -0
  9. checkpoint-1464/tokenizer.json +0 -0
  10. checkpoint-1464/tokenizer_config.json +1 -0
  11. checkpoint-1464/trainer_state.json +106 -0
  12. checkpoint-1464/training_args.bin +3 -0
  13. checkpoint-1464/vocab.txt +0 -0
  14. checkpoint-150/config.json +58 -0
  15. checkpoint-150/optimizer.pt +3 -0
  16. checkpoint-150/pytorch_model.bin +3 -0
  17. checkpoint-150/rng_state.pth +3 -0
  18. checkpoint-150/scheduler.pt +3 -0
  19. checkpoint-150/special_tokens_map.json +110 -0
  20. checkpoint-150/spiece.model +3 -0
  21. checkpoint-150/tokenizer.json +0 -0
  22. checkpoint-150/tokenizer_config.json +117 -0
  23. checkpoint-150/trainer_state.json +56 -0
  24. checkpoint-150/training_args.bin +3 -0
  25. fudge/LICENSE +21 -0
  26. fudge/README.md +155 -0
  27. fudge/clickbait_classifier.py +128 -0
  28. fudge/constants.py +32 -0
  29. fudge/data.py +415 -0
  30. fudge/eval_formality_metrics.py +73 -0
  31. fudge/eval_poetry_metrics.py +135 -0
  32. fudge/eval_topic_metrics.py +134 -0
  33. fudge/evaluate_clickbait.py +200 -0
  34. fudge/evaluate_formality.py +104 -0
  35. fudge/evaluate_poetry.py +115 -0
  36. fudge/evaluate_topic.py +143 -0
  37. fudge/formality_data/README.md +2 -0
  38. fudge/formality_data/fisher_test_oracle.es +0 -0
  39. fudge/formality_data/test.noid.cleaned_0 +0 -0
  40. fudge/formality_data/test.noid.cleaned_1 +0 -0
  41. fudge/main.py +192 -0
  42. fudge/model.py +182 -0
  43. fudge/poetry_data/README.md +1 -0
  44. fudge/poetry_data/couplet_ends.txt +154 -0
  45. fudge/poetry_data/couplet_prefixes.txt +154 -0
  46. fudge/poetry_util.py +83 -0
  47. fudge/predict_clickbait.py +199 -0
  48. fudge/predict_formality.py +404 -0
  49. fudge/predict_poetry.py +219 -0
  50. fudge/predict_topic.py +126 -0
README.md CHANGED
@@ -1,13 +1,13 @@
1
  ---
2
  title: Clickbaitonator
3
- emoji: 😻
4
- colorFrom: gray
5
- colorTo: red
6
  sdk: gradio
7
- sdk_version: 3.0.26
8
  app_file: app.py
9
  pinned: false
10
- license: gpl
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
  title: Clickbaitonator
3
+ emoji: 💩
4
+ colorFrom: purple
5
+ colorTo: yellow
6
  sdk: gradio
7
+ sdk_version: 3.0.24
8
  app_file: app.py
9
  pinned: false
10
+ license: afl-3.0
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import os
2
+
3
+ # os.chdir('naacl-2021-fudge-controlled-generation/')
4
+
5
+ import gradio as gr
6
+ from fudge.predict_clickbait import generate_clickbait, tokenizer, classifier_tokenizer
7
+ from datasets import load_dataset,DatasetDict,Dataset
8
+ # from datasets import
9
+ from transformers import AutoTokenizer,AutoModelForSeq2SeqLM
10
+ import numpy as np
11
+ from sklearn.model_selection import train_test_split
12
+ import pandas as pd
13
+ from sklearn.utils.class_weight import compute_class_weight
14
+ import torch
15
+ import pandas as pd
16
+ from fudge.model import Model
17
+ import os
18
+ from argparse import ArgumentParser
19
+ from collections import namedtuple
20
+ import mock
21
+
22
+ from tqdm import tqdm
23
+ import numpy as np
24
+ import torch.nn as nn
25
+ import torch.nn.functional as F
26
+ from fudge.data import Dataset
27
+ from fudge.util import save_checkpoint, ProgressMeter, AverageMeter, num_params
28
+ from fudge.constants import *
29
+
30
+
31
+ device = 'cpu'
32
+ # imp.reload(model)
33
+ pretrained_model = "checkpoint-150/"
34
+ generation_model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model, return_dict=True).to(device)
35
+
36
+
37
+ pad_id = 0
38
+
39
+ generation_model.eval()
40
+
41
+ model_args = mock.Mock()
42
+ model_args.task = 'clickbait'
43
+ model_args.device = device
44
+ model_args.checkpoint = 'checkpoint-1464/'
45
+
46
+ # conditioning_model = Model(model_args, pad_id, len(dataset_info.index2word)) # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
47
+ conditioning_model = Model(model_args, pad_id, vocab_size=None) # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
48
+ conditioning_model = conditioning_model.to(device)
49
+ conditioning_model.eval()
50
+
51
+ condition_lambda = 5.0
52
+ length_cutoff = 50
53
+ precondition_topk = 200
54
+
55
+
56
+ conditioning_model.classifier
57
+
58
+ model_args.checkpoint
59
+
60
+ classifier_tokenizer = AutoTokenizer.from_pretrained(model_args.checkpoint, load_best_model_at_end=True)
61
+
62
+
63
+ def rate_title(input_text, model, tokenizer, device='cuda'):
64
+ # input_text = {
65
+ # "postText": input_text['postText'],
66
+ # "truthClass" : input_text['truthClass']
67
+ # }
68
+ tokenized_input = preprocess_function_title_only_classification(input_text,tokenizer=tokenizer)
69
+ # print(tokenized_input.items())
70
+ dict_tokenized_input = {k : torch.tensor([v]).to(device) for k,v in tokenized_input.items() if k != 'labels'}
71
+ predicted_class = float(model(**dict_tokenized_input).logits)
72
+ actual_class = input_text['truthClass']
73
+
74
+ # print(predicted_class, actual_class)
75
+ return {'predicted_class' : predicted_class}
76
+
77
+ def preprocess_function_title_only_classification(examples,tokenizer=None):
78
+ model_inputs = tokenizer(examples['postText'], padding="longest", truncation=True, max_length=25)
79
+
80
+ model_inputs['labels'] = examples['truthClass']
81
+
82
+ return model_inputs
83
+
84
+
85
+
86
+ def clickbait_generator(article_content, condition_lambda=5.0):
87
+ # result = "Hi {}! 😎. The Mulitple of {} is {}".format(name, number, round(number**2, 2))
88
+ results = generate_clickbait(model=generation_model,
89
+ tokenizer=tokenizer,
90
+ conditioning_model=conditioning_model,
91
+ input_text=[None],
92
+ dataset_info=None,
93
+ precondition_topk=precondition_topk,
94
+ length_cutoff=length_cutoff,
95
+ condition_lambda=condition_lambda,
96
+ article_content=article_content,
97
+ device=device)
98
+
99
+ return results[0].replace('</s>', '').replace('<pad>', '')
100
+
101
+ title = "Clickbaitinator - Controllable Clickbait generator"
102
+ description = """
103
+ Use the [Fudge](https://github.com/yangkevin2/naacl-2021-fudge-controlled-generation) implementation fine-tuned for our purposes to try and create news headline you are looking for! Use condition_lambda to steer your clickbaitiness higher (by increasing the slider value) or lower (by decreasing the slider value). <br/>
104
+ Note that this is using two Transformers and is executed with CPU-only, so it will take a minute or two to finish generating a title.
105
+ """
106
+
107
+ article = "Check out [the codebase for our model](https://github.com/dsvilarkovic/naacl-2021-fudge-controlled-generation) that this demo is based of. You need collaborator access, which you have been probably invited for."
108
+
109
+
110
+ app = gr.Interface(
111
+ title = title,
112
+ description = description,
113
+ label = 'Article content or paragraph',
114
+ fn = clickbait_generator,
115
+ inputs=["text", gr.Slider(0, 15, step=0.1, value=5.0)],
116
+ outputs="text",
117
+ article=article,
118
+ )
119
+ app.launch()
checkpoint-1464/config.json ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "BertClickbaitClassifier"
4
+ ],
5
+ "dropout": 0.2,
6
+ "freeze_bert": false,
7
+ "id2label": {
8
+ "0": "LABEL_0"
9
+ },
10
+ "inner_dim1": 256,
11
+ "inner_dim2": 32,
12
+ "label2id": {
13
+ "LABEL_0": 0
14
+ },
15
+ "load_pretrained": true,
16
+ "max_length": 25,
17
+ "pretrained_model": "sentence-transformers/all-mpnet-base-v2",
18
+ "torch_dtype": "float32",
19
+ "transformers_version": "4.19.2"
20
+ }
checkpoint-1464/optimizer.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e86af40c41b9c7bd860a34891976e7ccb5cb03e35f540bac3901e768a5e90947
3
+ size 872925589
checkpoint-1464/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:881e07cb4fc93116a9f5bee91fc6048ce366725537b34edf7fdf7f243d2ba240
3
+ size 438838053
checkpoint-1464/rng_state.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:af047df93f2e21bc6b802f06e57d980fb915986b57e8908ad2e2e43065125260
3
+ size 14503
checkpoint-1464/scheduler.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c454d2b62b7deada05505a5f5f4607c60de6caa71a7d7b0a6c0c1821f97993c7
3
+ size 623
checkpoint-1464/special_tokens_map.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"bos_token": "<s>", "eos_token": "</s>", "unk_token": "[UNK]", "sep_token": "</s>", "pad_token": "<pad>", "cls_token": "<s>", "mask_token": {"content": "<mask>", "single_word": false, "lstrip": true, "rstrip": false, "normalized": false}}
checkpoint-1464/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
checkpoint-1464/tokenizer_config.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"do_lower_case": true, "bos_token": "<s>", "eos_token": "</s>", "sep_token": "</s>", "cls_token": "<s>", "unk_token": "[UNK]", "pad_token": "<pad>", "mask_token": "<mask>", "tokenize_chinese_chars": true, "strip_accents": null, "model_max_length": 512, "special_tokens_map_file": null, "name_or_path": "sentence-transformers/all-mpnet-base-v2", "tokenizer_class": "MPNetTokenizer"}
checkpoint-1464/trainer_state.json ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "best_metric": 1.0750963687896729,
3
+ "best_model_checkpoint": "drive/MyDrive/nlp_lss_data/mpnet_clickbait_classification_maxlen25/checkpoint-488",
4
+ "epoch": 6.0,
5
+ "global_step": 1464,
6
+ "is_hyper_param_search": false,
7
+ "is_local_process_zero": true,
8
+ "is_world_process_zero": true,
9
+ "log_history": [
10
+ {
11
+ "epoch": 1.0,
12
+ "eval_accuracy": 0.8498845265588915,
13
+ "eval_balanced_accuracy": 0.7757283651550713,
14
+ "eval_f1": 0.670793472144063,
15
+ "eval_loss": 1.1807881593704224,
16
+ "eval_precision": 0.7146282973621103,
17
+ "eval_recall": 0.6320254506892895,
18
+ "eval_runtime": 34.3115,
19
+ "eval_samples_per_second": 113.577,
20
+ "eval_steps_per_second": 113.577,
21
+ "step": 244
22
+ },
23
+ {
24
+ "epoch": 2.0,
25
+ "eval_accuracy": 0.8545034642032333,
26
+ "eval_balanced_accuracy": 0.7932135085090511,
27
+ "eval_f1": 0.6916802610114193,
28
+ "eval_loss": 1.0750963687896729,
29
+ "eval_precision": 0.7098214285714286,
30
+ "eval_recall": 0.6744432661717922,
31
+ "eval_runtime": 33.949,
32
+ "eval_samples_per_second": 114.79,
33
+ "eval_steps_per_second": 114.79,
34
+ "step": 488
35
+ },
36
+ {
37
+ "epoch": 2.05,
38
+ "learning_rate": 3.533724340175953e-05,
39
+ "loss": 1.3113,
40
+ "step": 500
41
+ },
42
+ {
43
+ "epoch": 3.0,
44
+ "eval_accuracy": 0.8601488324352066,
45
+ "eval_balanced_accuracy": 0.7835817278869854,
46
+ "eval_f1": 0.6873207114170969,
47
+ "eval_loss": 1.1083189249038696,
48
+ "eval_precision": 0.74875,
49
+ "eval_recall": 0.6352067868504772,
50
+ "eval_runtime": 33.7352,
51
+ "eval_samples_per_second": 115.517,
52
+ "eval_steps_per_second": 115.517,
53
+ "step": 732
54
+ },
55
+ {
56
+ "epoch": 4.0,
57
+ "eval_accuracy": 0.8534770336156018,
58
+ "eval_balanced_accuracy": 0.7993947132812708,
59
+ "eval_f1": 0.6964380648591175,
60
+ "eval_loss": 1.1579805612564087,
61
+ "eval_precision": 0.6982942430703625,
62
+ "eval_recall": 0.694591728525981,
63
+ "eval_runtime": 33.9178,
64
+ "eval_samples_per_second": 114.895,
65
+ "eval_steps_per_second": 114.895,
66
+ "step": 976
67
+ },
68
+ {
69
+ "epoch": 4.1,
70
+ "learning_rate": 1.7008797653958943e-05,
71
+ "loss": 0.7869,
72
+ "step": 1000
73
+ },
74
+ {
75
+ "epoch": 5.0,
76
+ "eval_accuracy": 0.8552732871439569,
77
+ "eval_balanced_accuracy": 0.8009405080804215,
78
+ "eval_f1": 0.6993603411513859,
79
+ "eval_loss": 1.2740588188171387,
80
+ "eval_precision": 0.7031082529474812,
81
+ "eval_recall": 0.6956521739130435,
82
+ "eval_runtime": 34.1758,
83
+ "eval_samples_per_second": 114.028,
84
+ "eval_steps_per_second": 114.028,
85
+ "step": 1220
86
+ },
87
+ {
88
+ "epoch": 6.0,
89
+ "eval_accuracy": 0.8555298947908647,
90
+ "eval_balanced_accuracy": 0.793168635227608,
91
+ "eval_f1": 0.6925177498634627,
92
+ "eval_loss": 1.3905503749847412,
93
+ "eval_precision": 0.713963963963964,
94
+ "eval_recall": 0.672322375397667,
95
+ "eval_runtime": 33.4993,
96
+ "eval_samples_per_second": 116.331,
97
+ "eval_steps_per_second": 116.331,
98
+ "step": 1464
99
+ }
100
+ ],
101
+ "max_steps": 1464,
102
+ "num_train_epochs": 6,
103
+ "total_flos": 1204353585477900.0,
104
+ "trial_name": null,
105
+ "trial_params": null
106
+ }
checkpoint-1464/training_args.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6cf148502136d195161841d56eda30be54006d58e628e1439c9393fb6404ef4a
3
+ size 3311
checkpoint-1464/vocab.txt ADDED
The diff for this file is too large to render. See raw diff
 
checkpoint-150/config.json ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "google/pegasus-xsum",
3
+ "activation_dropout": 0.1,
4
+ "activation_function": "relu",
5
+ "add_bias_logits": false,
6
+ "add_final_layer_norm": true,
7
+ "architectures": [
8
+ "PegasusForConditionalGeneration"
9
+ ],
10
+ "attention_dropout": 0.1,
11
+ "bos_token_id": 0,
12
+ "classif_dropout": 0.0,
13
+ "classifier_dropout": 0.0,
14
+ "d_model": 1024,
15
+ "decoder_attention_heads": 16,
16
+ "decoder_ffn_dim": 4096,
17
+ "decoder_layerdrop": 0.0,
18
+ "decoder_layers": 16,
19
+ "decoder_start_token_id": 0,
20
+ "do_blenderbot_90_layernorm": false,
21
+ "dropout": 0.1,
22
+ "encoder_attention_heads": 16,
23
+ "encoder_ffn_dim": 4096,
24
+ "encoder_layerdrop": 0.0,
25
+ "encoder_layers": 16,
26
+ "eos_token_id": 1,
27
+ "extra_pos_embeddings": 0,
28
+ "force_bos_token_to_be_generated": false,
29
+ "forced_eos_token_id": 1,
30
+ "gradient_checkpointing": false,
31
+ "id2label": {
32
+ "0": "LABEL_0",
33
+ "1": "LABEL_1",
34
+ "2": "LABEL_2"
35
+ },
36
+ "init_std": 0.02,
37
+ "is_encoder_decoder": true,
38
+ "label2id": {
39
+ "LABEL_0": 0,
40
+ "LABEL_1": 1,
41
+ "LABEL_2": 2
42
+ },
43
+ "length_penalty": 0.6,
44
+ "max_length": 64,
45
+ "max_position_embeddings": 512,
46
+ "model_type": "pegasus",
47
+ "normalize_before": true,
48
+ "normalize_embedding": false,
49
+ "num_beams": 8,
50
+ "num_hidden_layers": 16,
51
+ "pad_token_id": 0,
52
+ "scale_embedding": true,
53
+ "static_position_embeddings": true,
54
+ "torch_dtype": "float32",
55
+ "transformers_version": "4.20.1",
56
+ "use_cache": true,
57
+ "vocab_size": 96103
58
+ }
checkpoint-150/optimizer.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b31a813daf8949431f72c9672f50293c37c937f8239655269de86409df0a04ad
3
+ size 5839694
checkpoint-150/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6247ec114255a90ed5b84b8a94e1f9c20e3ff778c1cb853fb3758706d58deb78
3
+ size 2279605745
checkpoint-150/rng_state.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:61da4aab34859d84193f6d36e1ca6db2cab8dd2449b8ba79a7f1af61aa8c44a5
3
+ size 14503
checkpoint-150/scheduler.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e43ff59eb3184ee5df5457b7f99569ad47a950434bbef23184de1d9025687c8e
3
+ size 623
checkpoint-150/special_tokens_map.json ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "additional_special_tokens": [
3
+ "<mask_1>",
4
+ "<unk_2>",
5
+ "<unk_3>",
6
+ "<unk_4>",
7
+ "<unk_5>",
8
+ "<unk_6>",
9
+ "<unk_7>",
10
+ "<unk_8>",
11
+ "<unk_9>",
12
+ "<unk_10>",
13
+ "<unk_11>",
14
+ "<unk_12>",
15
+ "<unk_13>",
16
+ "<unk_14>",
17
+ "<unk_15>",
18
+ "<unk_16>",
19
+ "<unk_17>",
20
+ "<unk_18>",
21
+ "<unk_19>",
22
+ "<unk_20>",
23
+ "<unk_21>",
24
+ "<unk_22>",
25
+ "<unk_23>",
26
+ "<unk_24>",
27
+ "<unk_25>",
28
+ "<unk_26>",
29
+ "<unk_27>",
30
+ "<unk_28>",
31
+ "<unk_29>",
32
+ "<unk_30>",
33
+ "<unk_31>",
34
+ "<unk_32>",
35
+ "<unk_33>",
36
+ "<unk_34>",
37
+ "<unk_35>",
38
+ "<unk_36>",
39
+ "<unk_37>",
40
+ "<unk_38>",
41
+ "<unk_39>",
42
+ "<unk_40>",
43
+ "<unk_41>",
44
+ "<unk_42>",
45
+ "<unk_43>",
46
+ "<unk_44>",
47
+ "<unk_45>",
48
+ "<unk_46>",
49
+ "<unk_47>",
50
+ "<unk_48>",
51
+ "<unk_49>",
52
+ "<unk_50>",
53
+ "<unk_51>",
54
+ "<unk_52>",
55
+ "<unk_53>",
56
+ "<unk_54>",
57
+ "<unk_55>",
58
+ "<unk_56>",
59
+ "<unk_57>",
60
+ "<unk_58>",
61
+ "<unk_59>",
62
+ "<unk_60>",
63
+ "<unk_61>",
64
+ "<unk_62>",
65
+ "<unk_63>",
66
+ "<unk_64>",
67
+ "<unk_65>",
68
+ "<unk_66>",
69
+ "<unk_67>",
70
+ "<unk_68>",
71
+ "<unk_69>",
72
+ "<unk_70>",
73
+ "<unk_71>",
74
+ "<unk_72>",
75
+ "<unk_73>",
76
+ "<unk_74>",
77
+ "<unk_75>",
78
+ "<unk_76>",
79
+ "<unk_77>",
80
+ "<unk_78>",
81
+ "<unk_79>",
82
+ "<unk_80>",
83
+ "<unk_81>",
84
+ "<unk_82>",
85
+ "<unk_83>",
86
+ "<unk_84>",
87
+ "<unk_85>",
88
+ "<unk_86>",
89
+ "<unk_87>",
90
+ "<unk_88>",
91
+ "<unk_89>",
92
+ "<unk_90>",
93
+ "<unk_91>",
94
+ "<unk_92>",
95
+ "<unk_93>",
96
+ "<unk_94>",
97
+ "<unk_95>",
98
+ "<unk_96>",
99
+ "<unk_97>",
100
+ "<unk_98>",
101
+ "<unk_99>",
102
+ "<unk_100>",
103
+ "<unk_101>",
104
+ "<unk_102>"
105
+ ],
106
+ "eos_token": "</s>",
107
+ "mask_token": "<mask_2>",
108
+ "pad_token": "<pad>",
109
+ "unk_token": "<unk>"
110
+ }
checkpoint-150/spiece.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0015189ef36359283fec8b93cf6d9ce51bca37eb1101defc68a53b394913b96c
3
+ size 1912529
checkpoint-150/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
checkpoint-150/tokenizer_config.json ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "additional_special_tokens": [
3
+ "<mask_1>",
4
+ "<unk_2>",
5
+ "<unk_3>",
6
+ "<unk_4>",
7
+ "<unk_5>",
8
+ "<unk_6>",
9
+ "<unk_7>",
10
+ "<unk_8>",
11
+ "<unk_9>",
12
+ "<unk_10>",
13
+ "<unk_11>",
14
+ "<unk_12>",
15
+ "<unk_13>",
16
+ "<unk_14>",
17
+ "<unk_15>",
18
+ "<unk_16>",
19
+ "<unk_17>",
20
+ "<unk_18>",
21
+ "<unk_19>",
22
+ "<unk_20>",
23
+ "<unk_21>",
24
+ "<unk_22>",
25
+ "<unk_23>",
26
+ "<unk_24>",
27
+ "<unk_25>",
28
+ "<unk_26>",
29
+ "<unk_27>",
30
+ "<unk_28>",
31
+ "<unk_29>",
32
+ "<unk_30>",
33
+ "<unk_31>",
34
+ "<unk_32>",
35
+ "<unk_33>",
36
+ "<unk_34>",
37
+ "<unk_35>",
38
+ "<unk_36>",
39
+ "<unk_37>",
40
+ "<unk_38>",
41
+ "<unk_39>",
42
+ "<unk_40>",
43
+ "<unk_41>",
44
+ "<unk_42>",
45
+ "<unk_43>",
46
+ "<unk_44>",
47
+ "<unk_45>",
48
+ "<unk_46>",
49
+ "<unk_47>",
50
+ "<unk_48>",
51
+ "<unk_49>",
52
+ "<unk_50>",
53
+ "<unk_51>",
54
+ "<unk_52>",
55
+ "<unk_53>",
56
+ "<unk_54>",
57
+ "<unk_55>",
58
+ "<unk_56>",
59
+ "<unk_57>",
60
+ "<unk_58>",
61
+ "<unk_59>",
62
+ "<unk_60>",
63
+ "<unk_61>",
64
+ "<unk_62>",
65
+ "<unk_63>",
66
+ "<unk_64>",
67
+ "<unk_65>",
68
+ "<unk_66>",
69
+ "<unk_67>",
70
+ "<unk_68>",
71
+ "<unk_69>",
72
+ "<unk_70>",
73
+ "<unk_71>",
74
+ "<unk_72>",
75
+ "<unk_73>",
76
+ "<unk_74>",
77
+ "<unk_75>",
78
+ "<unk_76>",
79
+ "<unk_77>",
80
+ "<unk_78>",
81
+ "<unk_79>",
82
+ "<unk_80>",
83
+ "<unk_81>",
84
+ "<unk_82>",
85
+ "<unk_83>",
86
+ "<unk_84>",
87
+ "<unk_85>",
88
+ "<unk_86>",
89
+ "<unk_87>",
90
+ "<unk_88>",
91
+ "<unk_89>",
92
+ "<unk_90>",
93
+ "<unk_91>",
94
+ "<unk_92>",
95
+ "<unk_93>",
96
+ "<unk_94>",
97
+ "<unk_95>",
98
+ "<unk_96>",
99
+ "<unk_97>",
100
+ "<unk_98>",
101
+ "<unk_99>",
102
+ "<unk_100>",
103
+ "<unk_101>",
104
+ "<unk_102>"
105
+ ],
106
+ "eos_token": "</s>",
107
+ "full_tokenizer_file": null,
108
+ "mask_token": "<mask_2>",
109
+ "mask_token_sent": "<mask_1>",
110
+ "model_max_length": 512,
111
+ "name_or_path": "google/pegasus-xsum",
112
+ "offset": 103,
113
+ "pad_token": "<pad>",
114
+ "special_tokens_map_file": null,
115
+ "tokenizer_class": "PegasusTokenizer",
116
+ "unk_token": "<unk>"
117
+ }
checkpoint-150/trainer_state.json ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "best_metric": null,
3
+ "best_model_checkpoint": null,
4
+ "epoch": 4.982725527831094,
5
+ "global_step": 150,
6
+ "is_hyper_param_search": false,
7
+ "is_local_process_zero": true,
8
+ "is_world_process_zero": true,
9
+ "log_history": [
10
+ {
11
+ "epoch": 0.98,
12
+ "eval_loss": 2.3803367614746094,
13
+ "eval_runtime": 213.2043,
14
+ "eval_samples_per_second": 18.33,
15
+ "eval_steps_per_second": 18.33,
16
+ "step": 30
17
+ },
18
+ {
19
+ "epoch": 1.98,
20
+ "eval_loss": 2.2591161727905273,
21
+ "eval_runtime": 212.2667,
22
+ "eval_samples_per_second": 18.411,
23
+ "eval_steps_per_second": 18.411,
24
+ "step": 60
25
+ },
26
+ {
27
+ "epoch": 2.98,
28
+ "eval_loss": 2.203186511993408,
29
+ "eval_runtime": 212.608,
30
+ "eval_samples_per_second": 18.381,
31
+ "eval_steps_per_second": 18.381,
32
+ "step": 90
33
+ },
34
+ {
35
+ "epoch": 3.98,
36
+ "eval_loss": 2.1706554889678955,
37
+ "eval_runtime": 212.4689,
38
+ "eval_samples_per_second": 18.393,
39
+ "eval_steps_per_second": 18.393,
40
+ "step": 120
41
+ },
42
+ {
43
+ "epoch": 4.98,
44
+ "eval_loss": 2.1453042030334473,
45
+ "eval_runtime": 213.096,
46
+ "eval_samples_per_second": 18.339,
47
+ "eval_steps_per_second": 18.339,
48
+ "step": 150
49
+ }
50
+ ],
51
+ "max_steps": 900,
52
+ "num_train_epochs": 30,
53
+ "total_flos": 1.076296204604375e+17,
54
+ "trial_name": null,
55
+ "trial_params": null
56
+ }
checkpoint-150/training_args.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:50f1c331773948f1aafffbec50322b3f07edf4e08988fbdb44798e3e9b3db9fd
3
+ size 3375
fudge/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2022 Kevin Yang
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
fudge/README.md ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # FUDGE: Controlled Text Generation With Future Discriminators
2
+
3
+ This repo contains code corresponding to the paper FUDGE: Controlled Text Generation With Future Discriminators (https://arxiv.org/abs/2104.05218) by Kevin Yang and Dan Klein, published at NAACL 2021.
4
+
5
+ You can also find a video presentation at http://somup.com/crhlVPFKN7 and the corresponding slides in `slides.pptx`.
6
+
7
+ ## Setup/Installation
8
+
9
+ We tested on Python 3.8.5 but earlier versions of Python 3 are almost certainly fine. To get the required packages (other versions likely to work too):
10
+
11
+ ```
12
+ pip install -r requirements.txt
13
+ ```
14
+
15
+ Additionally, to get our pre-trained predictor checkpoints and training data, run:
16
+
17
+ ```
18
+ wget https://naacl2021-fudge-files.s3.amazonaws.com/large_files.zip
19
+ ```
20
+
21
+ and extract the zip to the top-level `lm-prediction/` folder. (There should be three folders, `ckpt/`, `train_data/`, and `topic_human_evals/`. The zip is 7GB.) Note: the zip seems to not work for some people actually, if this is the case you can get the files directly from https://drive.google.com/drive/folders/1GZfOGqpQxDmIfD2RvuhUQla9eX2OHUXU?usp=sharing (13GB).
22
+
23
+ `ckpt/` contains predictor checkpoints for each task if you are just interested in running inference. (Note that for the paper results, we used predictors trained with an older version of the code, but the new checkpoints get similar results, so you are OK to use the new predictors provided here if e.g. you just want to use FUDGE as a baseline. You can just run the evaluation commands provided below; it should take maybe 5-60 minutes depending on the task and your compute, assuming you have a GPU.)
24
+
25
+ `train_data/` contains our GPT2-generated training data for the poetry and topic tasks' predictors. See https://github.com/raosudha89/GYAFC-corpus for instructions on gaining access to the GYAFC data used for the machine translation formality task; replace our dummy folders with the corresponding folders/files if you want to train our formality predictor.
26
+
27
+ ## Clickbait
28
+ To generate outputs, run:
29
+
30
+ ```
31
+ python -u evaluate_clickbait.py --ckpt ckpt/topic/future_word_predictor/model.pth.tar --dataset_info ckpt/topic/future_word_predictor/dataset_info --in_file topic_data/topic_prefixes.txt --condition_lambda 4.0 --verbose --precondition_topk 200 --length_cutoff 80 --device cpu
32
+
33
+ python -u evaluate_clickbait.py --ckpt ckpt/formality/predictor_gyafc_entertainment_music/model.pth.tar --dataset_info ckpt/formality/predictor_gyafc_entertainment_music/dataset_info --in_file formality_data/fisher_test_oracle.es
34
+
35
+ python -u evaluate_clickbait.py --ckpt ckpt/topic/future_word_predictor/model.pth.tar --dataset_info ckpt/topic/future_word_predictor/dataset_info --in_file topic_data/topic_prefixes.txt --condition_lambda 4.0 --verbose --precondition_topk 200 --sample_size 3 --max_sample_batch 1 --length_cutoff 80 --log_file clickbait_preds.log
36
+ ```
37
+
38
+ Then evaluate metrics using:
39
+
40
+ ```
41
+ python eval_topic_metrics.py --log_file topic_preds.log --tw_dir topic_data/test_wordlists
42
+ ```
43
+
44
+
45
+ ## Poetry Couplet Completion
46
+
47
+ ### Evaluation
48
+
49
+ To generate outputs, run:
50
+
51
+ ```
52
+ python -u evaluate_poetry.py --iambic_ckpt ckpt/poetry/iambic_predictor/model.pth.tar --rhyme_ckpt ckpt/poetry/rhyme_predictor/model.pth.tar --newline_ckpt ckpt/poetry/newline_predictor/model.pth.tar --dataset_info ckpt/poetry/rhyme_predictor/dataset_info --rhyme_info ckpt/poetry/rhyme_predictor/rhyme_info --prefix_file poetry_data/couplet_prefixes.txt --precondition_topk 200 > poetry_preds.log
53
+ ```
54
+
55
+ Then evaluate metrics using:
56
+
57
+ ```
58
+ python eval_poetry_metrics.py --pred_file poetry_preds.log --prefix_file poetry_data/couplet_prefixes.txt
59
+ ```
60
+
61
+ ### Training your own predictors
62
+
63
+ Example commands for all three predictors used in the poetry task below. (You actually probably don't need so many epochs for iambic and rhyme; in any case the commands will save intermediate ckpts so you can just stop them early if needed by inspecting the log.)
64
+
65
+ Iambic predictor:
66
+
67
+ ```
68
+ python -u main.py --task iambic --data_dir train_data/gpt2_generations --save_dir ckpt/poetry/iambic_retrain_predictor --num_workers 20 --batch_size 128 --epoch_max_len 100000 --validation_freq 10 --lr 2e-4 --epochs 1500 > iambic_retrain_predictor.log
69
+ ```
70
+
71
+ Rhyme predictor:
72
+
73
+ ```
74
+ python -u main.py --task rhyme --data_dir train_data/gpt2_generations --save_dir ckpt/poetry/rhyme_retrain_predictor --num_workers 20 --batch_size 128 --epoch_max_len 100000 --validation_freq 10 --lr 2e-4 --epochs 1500 > rhyme_retrain_predictor.log
75
+ ```
76
+
77
+ End of sentence predictor (referred to as "newline" in the code; 50 epochs is more than enough for this one):
78
+
79
+ ```
80
+ python -u main.py --task newline --data_dir train_data/gpt2_generations --save_dir ckpt/poetry/newline_retrain_predictor --num_workers 20 --batch_size 128 --epoch_max_len 100000 --validation_freq 10 --lr 2e-4 --epochs 50 > newline_retrain_predictor.log
81
+ ```
82
+
83
+ The same evaluation commands as before will work; just modify the paths in the command to point to `model_best.pth.tar`, `dataset_info`, and `rhyme_info` from your newly trained ckpt folders.
84
+
85
+ ## Topic Control
86
+
87
+ ### Evaluation
88
+
89
+ To generate outputs, run:
90
+
91
+ ```
92
+ python -u evaluate_topic.py --ckpt ckpt/topic/future_word_predictor/model.pth.tar --dataset_info ckpt/topic/future_word_predictor/dataset_info --prefix_file topic_data/topic_prefixes.txt --wordlist_dir topic_data/wordlists --condition_lambda 4.0 --verbose --precondition_topk 200 --topk 10 --sample_size 3 --max_sample_batch 1 --length_cutoff 80 --log_file topic_preds.log
93
+ ```
94
+
95
+ Then evaluate metrics using:
96
+
97
+ ```
98
+ python eval_topic_metrics.py --log_file topic_preds.log --tw_dir topic_data/test_wordlists
99
+ ```
100
+
101
+ You can also find our original generations and baselines in `topic_human_evals/`.
102
+
103
+ ### Training your own predictors
104
+
105
+ Example command below.
106
+
107
+ ```
108
+ python -u main.py --task topic --data_dir train_data/gpt2_generations --save_dir ckpt/topic/future_word_retrain_predictor --num_workers 20 --batch_size 128 --epoch_max_len 100000 --validation_freq 10 --lr 2e-4 --epochs 500 --glove_file train_data/glove.840B.300d.txt > future_word_retrain_predictor.log
109
+ ```
110
+
111
+ The same evaluation commands as before will work; just modify the paths in the command to point to `model_best.pth.tar`, `dataset_info`, and `rhyme_info` from your newly trained ckpt folders.
112
+
113
+ ## Machine Translation Formality
114
+
115
+ ### Evaluation
116
+
117
+ To generate outputs, run:
118
+
119
+ ```
120
+ python -u evaluate_formality.py --ckpt ckpt/formality/predictor_gyafc_entertainment_music/model.pth.tar --dataset_info ckpt/formality/predictor_gyafc_entertainment_music/dataset_info --in_file formality_data/fisher_test_oracle.es --model_path ckpt/formality/marian_finetune_fisher > formality_preds.log
121
+ ```
122
+
123
+ The above command generates predictions using the Marian model finetuned on the Fisher dataset; remove the `--model_path` argument to get predictions with the un-finetuned Marian model from HuggingFace (referred to as 0-shot in the paper)
124
+
125
+ Then evaluate metrics using:
126
+
127
+ ```
128
+ python eval_formality_metrics.py --pred formality_preds.log --ref formality_data/test.noid.cleaned_0 formality_data/test.noid.cleaned_1 --ckpt ckpt/formality/test_evaluator_gyafc_family_relationships/model.pth.tar --dataset_info ckpt/formality/test_evaluator_gyafc_family_relationships/dataset_info
129
+ ```
130
+
131
+ ### Training your own predictors
132
+
133
+ Example command below. (Reminder: you need to go get the GYAFC dataset following the instructions in https://github.com/raosudha89/GYAFC-corpus.)
134
+
135
+ ```
136
+ python -u main.py --task formality --data_dir train_data/GYAFC_Corpus/Entertainment_Music --save_dir ckpt/formality/formality_retrain_predictor --num_workers 20 --batch_size 32 --epoch_max_len 1000000 --validation_freq 1 --lr 2e-5 --epochs 20 > formality_retrain_predictor.log
137
+ ```
138
+
139
+ (The test-time formality evaluator is trained in the same way, just using the Family/Relationships half of the GYAFC dataset.)
140
+
141
+ The same evaluation commands as before will work; just modify the paths in the command to point to `model_best.pth.tar`, `dataset_info`, and `rhyme_info` from your newly trained ckpt folders.
142
+
143
+ ## Running FUDGE on your own data
144
+
145
+ The code has been refactored so that the iambic (poetry), rhyme (poetry), newline (poetry), future word (topic), and formality (machine translation) are controlled by the `--task` flag to `main.py`. You should add your task as another option here, then modify the data processing in `data.py` and the model in `model.py` as needed for your task. (In `data.py` you probably won't need all the entries of the tuple that is expected of the loader; you can just put dummy entries in the ones you don't need.) You might also need to modify the loss computation in the `train` and `validate` functions in `main.py`. You'll probably want to write new evaluation scripts, though the existing poetry/topic/formality ones are hopefully helpful as references.
146
+
147
+ Alternatively, the general FUDGE framework is pretty simple, so you could always try reimplementing things yourself. A few additional details based on questions I've received:
148
+
149
+ (1) The formality task setup is likely closest to what you want if you're just trying to run the simplest form of FUDGE (take a language model, and use a classifier to optimize toward a single attribute) although you may need to swap out the Marian translation model/tokenizer we use.
150
+
151
+ (2) When you construct your training data, if you have an example in your data e.g. "This movie is great!" for positive sentiment, you want to learn on all the pairs (This, +), (This movie, +), (This movie is, +), etc., as that's one of the main points of our approach.
152
+
153
+ (3) For computational efficiency, we first filter the base model's next token probabilities down to the top 200 (Sec. 3.1 in the paper), before adding the classifier logits. This way you only need to evaluate your classifier on 200 continuations. Then afterward, you filter down again to whatever top-k/greedy/nucleus sampling you're using for evaluation (we use top-k with k=10 for poetry and topic, greedy for formality).
154
+
155
+ (4) You can use a pretrained LM backbone instead of a simple LSTM backbone for the predictor as well. This should work better when your dataset is smaller.
fudge/clickbait_classifier.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import BertModel, BertConfig, PretrainedConfig, PreTrainedModel, AutoModel, AutoConfig
3
+ from typing import List, Optional, Tuple, Union
4
+ from transformers.modeling_outputs import TokenClassifierOutput,SequenceClassifierOutput
5
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss, BCELoss
6
+ import torch.nn as nn
7
+ # from modeling_mpnet import MPNetModel, MPnetConfig
8
+
9
+ class ClickbaitConfig(PretrainedConfig):
10
+ def __init__(
11
+ self,
12
+ model_type: str = "bert",
13
+ pretrained_model: str = "bert-base-uncased",
14
+ num_labels: int = 1,
15
+ dropout: float = 0.1,
16
+ inner_dim1: int = 256,
17
+ inner_dim2: int = 32,
18
+ max_length: int = 512,
19
+ load_pretrained: bool = True,
20
+ freeze_bert: bool = True,
21
+ **kwargs
22
+ ):
23
+ super(ClickbaitConfig, self).__init__(num_labels=num_labels, **kwargs)
24
+ self.model_type = model_type
25
+ self.pretrained_model = pretrained_model
26
+ self.dropout = dropout
27
+ self.inner_dim1 = inner_dim1
28
+ self.inner_dim2 = inner_dim2
29
+ self.max_length = max_length
30
+ self.load_pretrained = load_pretrained
31
+ self.freeze_bert = freeze_bert
32
+
33
+
34
+ class BertClickbaitClassifier(PreTrainedModel):
35
+ """
36
+ Taken and extended from BertforSequenceClassification : https://github.com/huggingface/transformers/blob/v4.19.2/src/transformers/models/bert/modeling_bert.py#L1508
37
+ """
38
+ config_class = ClickbaitConfig
39
+ def __init__(self, config: ClickbaitConfig):
40
+ super(BertClickbaitClassifier, self).__init__(config)
41
+ self.num_labels = config.num_labels
42
+ self.config = config
43
+ # self.bert_config = BertConfig.from_pretrained(config.pretrained_model)
44
+ self.bert_config = AutoConfig.from_pretrained(config.pretrained_model)
45
+
46
+ # self.bert = BertModel(self.bert_config)
47
+ self.bert = AutoModel.from_pretrained(config.pretrained_model, config=self.bert_config)
48
+ # self.bert = SentenceTransformer(config.pretrained_model, config=self.bert_config)
49
+ # self.bert = MPNetModel(config.pretrained_model, config=self.bert_config)
50
+ if config.load_pretrained:
51
+ print("Load pretrained weights from {}".format(config.pretrained_model))
52
+ self.bert = self.bert.from_pretrained(config.pretrained_model)
53
+ if config.freeze_bert:
54
+ print("Freeze weights in the BERT model. Just the classifier will be trained")
55
+ for param in self.bert.parameters():
56
+ param.requires_grad = False
57
+
58
+ self.linear_1 = nn.Linear(self.bert.config.hidden_size, config.inner_dim1)
59
+ self.dropout_1 = nn.Dropout(config.dropout)
60
+ self.relu_1 = nn.ReLU()
61
+ self.dropout_2 = nn.Dropout(config.dropout)
62
+ self.linear_2 = nn.Linear(config.inner_dim1, config.inner_dim2)
63
+ self.relu_2 = nn.ReLU()
64
+ self.dropout_3 = nn.Dropout(config.dropout)
65
+ self.classifier = nn.Linear(config.inner_dim2, config.num_labels)
66
+ self.sigmoid = nn.Sigmoid()
67
+
68
+
69
+ def forward(
70
+ self,
71
+ input_ids: Optional[torch.Tensor] = None,
72
+ attention_mask: Optional[torch.Tensor] = None,
73
+ token_type_ids: Optional[torch.Tensor] = None,
74
+ position_ids: Optional[torch.Tensor] = None,
75
+ head_mask: Optional[torch.Tensor] = None,
76
+ inputs_embeds: Optional[torch.Tensor] = None,
77
+ labels: Optional[torch.Tensor] = None,
78
+ output_attentions: Optional[bool] = None,
79
+ output_hidden_states: Optional[bool] = None,
80
+ return_dict: Optional[bool] = None,
81
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
82
+ r"""
83
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
84
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
85
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
86
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
87
+ """
88
+
89
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
90
+
91
+ outputs = self.bert(
92
+ input_ids,
93
+ attention_mask=attention_mask,
94
+ token_type_ids=token_type_ids,
95
+ position_ids=position_ids,
96
+ head_mask=head_mask,
97
+ inputs_embeds=inputs_embeds,
98
+ output_attentions=output_attentions,
99
+ output_hidden_states=output_hidden_states,
100
+ return_dict=return_dict,
101
+ )
102
+
103
+ output = outputs[0][:,0,:]
104
+
105
+ x = self.dropout_1(output)
106
+ x = self.linear_1(x)
107
+ x = self.relu_1(x)
108
+ x = self.dropout_2(x)
109
+ x = self.linear_2(x)
110
+ x = self.relu_2(x)
111
+ x = self.dropout_3(x)
112
+
113
+ logits = self.classifier(x)
114
+ logits = self.sigmoid(logits)
115
+
116
+ loss = None
117
+ if labels is not None:
118
+ loss_fct = BCELoss(weight=WEIGHT)
119
+ labels = 1.0*labels
120
+ loss = loss_fct(logits.view(-1), labels.view(-1))
121
+ if not return_dict:
122
+ output = (logits,) + outputs[2:]
123
+ return ((loss,) + output) if loss is not None else output
124
+
125
+ return SequenceClassifierOutput(
126
+ loss=loss,
127
+ logits=logits
128
+ )
fudge/constants.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ PAD_TOKEN = '[PAD]'
2
+ EOT_TOKEN = '<|endoftext|>'
3
+ SEP = 50256 # just use the weird eot token
4
+
5
+ TOPIC_MODEL_STRING = 'gpt2-medium'
6
+ FORMALITY_MODEL_STRING = 'Helsinki-NLP/opus-mt-es-en'
7
+
8
+ DIR_END_SPLIT_POSITIONS = 32
9
+
10
+ TOPIC_VAL_SIZE = 100000
11
+ FORMALITY_VAL_SIZE = 2000
12
+ VOCAB_SIZE = 50000
13
+
14
+ FORMALITY_MAX_LEN = 200
15
+
16
+ GLOVE_PRINT_PROGRESS_FREQ = 1000000
17
+ GLOVE_DIM = 300
18
+ HIDDEN_DIM = 300
19
+ RNN_DIM = 150
20
+
21
+ MIN_SENTENCE_LENGTH = 3
22
+
23
+ POETRY_LINE_SYLLABLES = 10
24
+ MAX_SYLLABLES_PER_WORD = 10 # no way anything is more
25
+ MAX_COUNT_SYLLABLE_DIST = 10
26
+ MAX_COUNT_SYLLABLE_INPUT_LENGTH = 25 # for just a couplet, shouldn't need more
27
+ COUNT_SYLLABLE_DIM = 100
28
+ UNKNOWN_RHYME_GROUP = 'UNKNOWN_RHYME_GROUP'
29
+ PHRASE_ENDS = '.?!'
30
+
31
+ POETRY_BANNED_TOKENS = [198, 50256, 628, 220] # newlines and eos and such
32
+
fudge/data.py ADDED
@@ -0,0 +1,415 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import math
3
+ import os
4
+ import pickle
5
+ from collections import defaultdict, namedtuple
6
+ import string
7
+
8
+ os.environ['TOKENIZERS_PARALLELISM'] = 'false' # turn off since we're using multiple threads for loading anyway
9
+
10
+ from transformers import AutoTokenizer, AutoModelWithLMHead, pipeline, set_seed, GPT2Tokenizer, GPT2Model
11
+ import numpy as np
12
+ from tqdm import tqdm
13
+ import torch
14
+
15
+ from fudge.util import suppress_stdout
16
+ from fudge.poetry_util import is_iambic, count_syllables, get_rhymes, get_rhyme_group
17
+ from fudge.constants import *
18
+
19
+ DatasetInfo = namedtuple('DatasetInfo',
20
+ ['index2word', 'word2index', 'total_words', 'vocab', 'glove_embeddings'])
21
+ RhymeInfo = namedtuple('RhymeInfo',
22
+ ['word2rhyme_group', 'rhyme_group_counts', 'rhyme_groups', 'index2rhyme_group', 'rhyme_group2index', 'total_rhyme_groups'])
23
+
24
+ def collate(batch):
25
+ pad_id = batch[0][4]
26
+ inputs = [b[0] for b in batch]
27
+ lengths = torch.LongTensor([b[1] for b in batch])
28
+ max_length = lengths.max()
29
+ for i in range(len(inputs)):
30
+ if len(inputs[i]) < max_length:
31
+ inputs[i] = torch.cat([inputs[i], torch.zeros(max_length - len(inputs[i])).long()], dim=0) # actually 0 is fine as pad since it's masked out
32
+ inputs = torch.stack(inputs, dim=0)
33
+ future_words = torch.LongTensor([b[2] for b in batch]).unsqueeze(0).expand(len(batch), -1).clone() # batch x N=batch
34
+ labels = torch.zeros_like(future_words).long()
35
+ labels = labels.scatter(1, torch.arange(len(batch)).unsqueeze(1), torch.ones(len(batch)).long().unsqueeze(1)).clone()
36
+ log_probs = torch.Tensor([b[3] for b in batch])
37
+ classification_labels = [b[5] for b in batch] # batch
38
+ if type(classification_labels[0]) == list:
39
+ for i in range(len(classification_labels)):
40
+ assert len(classification_labels[i]) == lengths[i]
41
+ if len(classification_labels[i]) < max_length:
42
+ classification_labels[i] = torch.cat([torch.LongTensor(classification_labels[i]), -1 + torch.zeros(max_length - len(classification_labels[i])).long()], dim=0)
43
+ else:
44
+ classification_labels[i] = torch.LongTensor(classification_labels[i])
45
+ classification_labels = torch.stack(classification_labels, dim=0) # batch x seq
46
+ else:
47
+ assert type(classification_labels[0]) == int
48
+ classification_labels = torch.LongTensor(classification_labels) # they're just int labels
49
+ syllables_to_go = torch.LongTensor([b[6] for b in batch])
50
+ future_word_num_syllables = torch.LongTensor([b[7] for b in batch])
51
+ rhyme_group_index = torch.LongTensor([b[8] for b in batch])
52
+ return (inputs, lengths, future_words, log_probs, labels, classification_labels, syllables_to_go, future_word_num_syllables, rhyme_group_index)
53
+
54
+
55
+ def load_rhyme_info(index2word, vocab):
56
+ word2rhyme_group = defaultdict(lambda: UNKNOWN_RHYME_GROUP)
57
+ rhyme_group_counts = defaultdict(lambda: 0)
58
+ rhyme_groups = set()
59
+ for word in index2word:
60
+ try:
61
+ rhyme_group = get_rhyme_group(word)
62
+ word2rhyme_group[word] = rhyme_group
63
+ rhyme_group_counts[rhyme_group] += (vocab[word] if word in vocab else 1) # for rare words not in vocab, just use 1
64
+ rhyme_groups.add(rhyme_group)
65
+ except:
66
+ rhyme_group_counts[UNKNOWN_RHYME_GROUP] += (vocab[word] if word in vocab else 1)
67
+ index2rhyme_group = [UNKNOWN_RHYME_GROUP] + sorted(list(rhyme_groups))
68
+ rhyme_group2index = {s: i for i, s in enumerate(index2rhyme_group)}
69
+ total_rhyme_groups = sum(rhyme_group_counts.values())
70
+
71
+ return RhymeInfo(word2rhyme_group=dict(word2rhyme_group),
72
+ rhyme_group_counts=dict(rhyme_group_counts),
73
+ rhyme_groups=rhyme_groups,
74
+ index2rhyme_group=index2rhyme_group,
75
+ rhyme_group2index=rhyme_group2index,
76
+ total_rhyme_groups=total_rhyme_groups)
77
+
78
+
79
+ class Dataset:
80
+ def __init__(self, args):
81
+ print('loading data')
82
+ random.seed(args.seed)
83
+ self.batch_size = args.batch_size
84
+ self.data_dir = args.data_dir
85
+ self.topic = args.task == 'topic'
86
+ self.formality = args.task == 'formality'
87
+ self.iambic = args.task == 'iambic'
88
+ self.rhyme = args.task == 'rhyme'
89
+ self.newline = args.task == 'newline'
90
+
91
+ self.tokenizer = AutoTokenizer.from_pretrained(FORMALITY_MODEL_STRING if self.formality else TOPIC_MODEL_STRING)
92
+ self.tokenizer.add_special_tokens({'pad_token': PAD_TOKEN})
93
+ self.gpt_pad_id = self.tokenizer.encode(PAD_TOKEN)[0] # actually just the vocab size
94
+ sentences = []
95
+ self.vocab = defaultdict(lambda: 0)
96
+ if self.formality:
97
+ self.vocab['placeholder'] = 1 # anything so we don't crash
98
+ train, val, test = [], [], []
99
+ for category, label in [('formal', 1), ('informal', 0)]:
100
+ with open(os.path.join(args.data_dir, 'train', category), 'r') as rf:
101
+ for i, line in enumerate(rf):
102
+ if len(line) > FORMALITY_MAX_LEN:
103
+ line = ' '.join(line.strip()[:FORMALITY_MAX_LEN].split()[:-1]) # cutoff words until below max len; chosen so only ~20 examples affected in dataset
104
+ if i < FORMALITY_VAL_SIZE // 2:
105
+ val.append((line.strip(), label))
106
+ else:
107
+ train.append((line.strip(), label))
108
+ with open(os.path.join(args.data_dir, 'test', category), 'r') as rf:
109
+ for line in rf:
110
+ if len(line) > FORMALITY_MAX_LEN:
111
+ line = ' '.join(line.strip()[:FORMALITY_MAX_LEN].split()[:-1]) # cutoff words until below max len
112
+ test.append((line.strip(), label))
113
+ self.splits = {}
114
+ self.splits['train'], self.splits['val'], self.splits['test'] = train, val, test
115
+ else: # topic / poetry
116
+ for root, _, filenames in os.walk(args.data_dir):
117
+ for fname in filenames:
118
+ with open(os.path.join(root, fname), 'r') as rf:
119
+ for line in rf:
120
+ sentences.append(line.strip())
121
+ for word in line.strip().split(' '):
122
+ self.vocab[word] += 1
123
+ random.shuffle(sentences)
124
+ self.splits = {}
125
+ if args.debug:
126
+ self.splits['val'] = sentences
127
+ self.splits['test'] = sentences
128
+ self.splits['train'] = sentences
129
+ else:
130
+ self.splits['val'] = sentences[:TOPIC_VAL_SIZE]
131
+ self.splits['test'] = sentences[TOPIC_VAL_SIZE:2*TOPIC_VAL_SIZE]
132
+ self.splits['train'] = sentences[2*TOPIC_VAL_SIZE:]
133
+
134
+ if args.dataset_info is not None:
135
+ print('loading dataset info from file')
136
+ with open(args.dataset_info, 'rb') as rf:
137
+ dataset_info = pickle.load(rf)
138
+ self.vocab, self.total_words, self.index2word, self.word2index, self.glove_embeddings = \
139
+ dataset_info.vocab, dataset_info.total_words, dataset_info.index2word, dataset_info.word2index, dataset_info.glove_embeddings
140
+ self.dataset_info = dataset_info
141
+ else:
142
+ print('generating dataset info from scratch')
143
+ words_values = list(self.vocab.items())
144
+ words_values = sorted(words_values, key=lambda x: x[1], reverse=True)
145
+ if args.glove_file is None:
146
+ print('no glove embeddings given')
147
+ for word, _ in words_values[VOCAB_SIZE:]: # only use somewhat common tokens
148
+ del self.vocab[word]
149
+ glove_embeddings = None
150
+ else:
151
+ print('loading glove embeddings')
152
+ glove_embeddings = {}
153
+ with open(args.glove_file, 'r') as rf:
154
+ for i, line in enumerate(rf):
155
+ if i % GLOVE_PRINT_PROGRESS_FREQ == 0:
156
+ print(i)
157
+ line = line.strip().split()
158
+ if len(line) != GLOVE_DIM + 1:
159
+ continue # skip multi-word embeddings which are rare anyway
160
+ glove_embeddings[line[0]] = [float(x) for x in line[1:]]
161
+ for word, _ in words_values:
162
+ if word not in glove_embeddings:
163
+ del self.vocab[word]
164
+ self.total_words = sum(self.vocab.values())
165
+ self.index2word = [PAD_TOKEN] + sorted(list(self.vocab.keys()))
166
+ self.word2index = {s: i for i, s in enumerate(self.index2word)}
167
+ self.vocab = dict(self.vocab) # so we can pickle later
168
+ if glove_embeddings is None:
169
+ self.glove_embeddings = None
170
+ else:
171
+ self.glove_embeddings = torch.stack([torch.zeros(GLOVE_DIM)] + [torch.Tensor(glove_embeddings[word]) for word in self.index2word[1:]], dim=0)
172
+
173
+ self.dataset_info = DatasetInfo(index2word=self.index2word,
174
+ word2index=self.word2index,
175
+ total_words=self.total_words,
176
+ vocab=self.vocab,
177
+ glove_embeddings=self.glove_embeddings)
178
+
179
+ if self.rhyme:
180
+ if args.rhyme_info is not None:
181
+ print('loading rhyme info from file')
182
+ with open(args.rhyme_info, 'rb') as rf:
183
+ self.rhyme_info = pickle.load(rf)
184
+ else:
185
+ self.rhyme_info = load_rhyme_info(self.index2word, self.vocab)
186
+ self.word2rhyme_group, self.rhyme_group_counts, self.rhyme_groups, self.index2rhyme_group, self.rhyme_group2index, self.total_rhyme_groups = \
187
+ defaultdict(lambda: UNKNOWN_RHYME_GROUP, self.rhyme_info.word2rhyme_group), self.rhyme_info.rhyme_group_counts, self.rhyme_info.rhyme_groups, self.rhyme_info.index2rhyme_group, self.rhyme_info.rhyme_group2index, self.rhyme_info.total_rhyme_groups
188
+
189
+ print('done loading data')
190
+ print('split sizes:')
191
+ for key in ['train', 'val', 'test']:
192
+ print(key, len(self.splits[key]))
193
+ if not self.formality:
194
+ print('total words', self.total_words)
195
+ print('vocab size', len(self.index2word))
196
+
197
+
198
+ def shuffle(self, split, seed=None):
199
+ assert split in ['train', 'val', 'test']
200
+ if seed is not None:
201
+ random.seed(seed)
202
+ random.shuffle(self.splits[split])
203
+
204
+
205
+ def loader(self, split, num_workers=20, indices=None):
206
+ assert split in ['train', 'val', 'test']
207
+ data = self.splits[split] if indices is None else [self.splits[split][i] for i in indices]
208
+ return torch.utils.data.DataLoader(SplitLoader(data, self), batch_size=self.batch_size, pin_memory=True, collate_fn=collate, num_workers=num_workers)
209
+
210
+
211
+ class SplitLoader(torch.utils.data.IterableDataset):
212
+ def __init__(self, data, parent):
213
+ super(SplitLoader).__init__()
214
+ self.data = data
215
+ self.pos = 0
216
+ self.parent = parent
217
+
218
+
219
+ def __len__(self):
220
+ return len(self.data)
221
+
222
+
223
+ def __iter__(self):
224
+ return self
225
+
226
+
227
+ def __next__(self):
228
+ increment = 1
229
+ worker_info = torch.utils.data.get_worker_info()
230
+ if worker_info is not None: # # in a worker process
231
+ increment = worker_info.num_workers
232
+ worker_id = worker_info.id
233
+ if self.pos == 0:
234
+ self.pos = worker_id
235
+ valid = False
236
+ while not valid:
237
+ if self.pos >= len(self):
238
+ raise StopIteration
239
+ if self.parent.topic:
240
+ failed = False
241
+ future_word_num_syllables, rhyme_group_index, syllables_to_go = -1, -1, -1
242
+ raw_sentence, classification_label = self.data[self.pos], -1
243
+ original_sentence = raw_sentence.split()
244
+ sentence = self.parent.tokenizer.encode(raw_sentence, return_tensors='pt')[0]
245
+ length = len(sentence)
246
+ min_sentence_length = MIN_SENTENCE_LENGTH
247
+ if len(sentence) > min_sentence_length: # set to 3. well, everything in data is > 3 for the bag of words task
248
+ pos_to_split = random.randint(1, length - 1) # for lm, learn all positions at once
249
+ inp = sentence[:pos_to_split]
250
+ length = len(inp)
251
+ num_words_in_input = len(self.parent.tokenizer.decode(inp).split())
252
+ if not failed and num_words_in_input < len(original_sentence):
253
+ future_word_position_max = len(original_sentence) - 1
254
+ future_word_position = random.randint(num_words_in_input-1, future_word_position_max) # allow the last possibly partial word though
255
+ future_word = original_sentence[future_word_position]
256
+ unstripped_future_word = future_word
257
+ future_word = future_word.strip().strip(string.punctuation) # NOTE: we didn't strip punctuation for the topic bag of words paper experiments for our method. it doesn't make much difference, though.
258
+ if not failed and future_word in self.parent.word2index.keys():
259
+ word_log_prob = math.log(self.parent.vocab[future_word] / self.parent.total_words) # roughly baseline prob of word under noise model
260
+ future_word = self.parent.word2index[future_word]
261
+ pad_id = self.parent.gpt_pad_id
262
+ example = (inp, length, future_word, word_log_prob, pad_id, classification_label, syllables_to_go, future_word_num_syllables, rhyme_group_index)
263
+ valid = not failed
264
+ elif self.parent.formality:
265
+ future_word_num_syllables, rhyme_group_index, syllables_to_go = -1, -1, -1
266
+ raw_sentence, classification_label = self.data[self.pos]
267
+ original_sentence = raw_sentence.split()
268
+ sentence = self.parent.tokenizer.encode(raw_sentence, return_tensors='pt')[0]
269
+ length = len(sentence)
270
+ min_sentence_length = MIN_SENTENCE_LENGTH
271
+ if len(sentence) > min_sentence_length: # set to 3. well, everything in data is > 3 for the bag of words task
272
+ pos_to_split = length # no need to split; we're going to train on all possible prefixes simultaneously for efficiency
273
+ inp = sentence[:pos_to_split]
274
+ length = len(inp)
275
+ num_words_in_input = len(self.parent.tokenizer.decode(inp).split())
276
+ # only look up to 10 words ahead if we're doing count syllables, since we'll filter out anything more than 10 syllables ahead anyway
277
+ future_word_position_max = len(original_sentence) - 1
278
+ future_word_position = 0
279
+ future_word = 'placeholder'
280
+ unstripped_future_word = future_word
281
+ future_word = future_word.strip().strip(string.punctuation) # NOTE: we didn't strip punctuation for the topic bag of words paper experiments for our method. it doesn't make much difference, though.
282
+ word_log_prob, future_word = 0, 0
283
+ pad_id = self.parent.gpt_pad_id
284
+ example = (inp, length, future_word, word_log_prob, pad_id, classification_label, syllables_to_go, future_word_num_syllables, rhyme_group_index)
285
+ valid = True
286
+ elif self.parent.iambic:
287
+ failed = False
288
+ future_word_num_syllables, rhyme_group_index, syllables_to_go = -1, -1, -1
289
+ raw_sentence, classification_label = self.data[self.pos], -1
290
+ original_sentence = raw_sentence.split()
291
+ sentence = self.parent.tokenizer.encode(raw_sentence, return_tensors='pt')[0]
292
+ length = len(sentence)
293
+ min_sentence_length = MIN_SENTENCE_LENGTH
294
+ if len(sentence) > min_sentence_length: # set to 3. well, everything in data is > 3 for the bag of words task
295
+ pos_to_split = random.randint(0, length - 1)
296
+ # try to get a subseq of exactly 10 syllables
297
+ inp = sentence[pos_to_split:]
298
+ num_syllables = 0
299
+ checked = False
300
+ for i in range(1, len(inp)):
301
+ decoded = self.parent.tokenizer.decode(inp[:i])
302
+ num_syllables = count_syllables(decoded)
303
+ if num_syllables > POETRY_LINE_SYLLABLES:
304
+ inp = inp[:i-1] # might get a few data points where the split is in the middle of a word, but it should be ok for learning.
305
+ last_line_length = i-1
306
+ decoded = self.parent.tokenizer.decode(inp)
307
+ num_syllables = count_syllables(decoded)
308
+ checked = True
309
+ break
310
+ if not checked or num_syllables != POETRY_LINE_SYLLABLES:
311
+ failed = True
312
+ length = len(inp)
313
+ num_words_in_input = len(self.parent.tokenizer.decode(inp).split())
314
+ classification_label = [is_iambic(self.parent.tokenizer.decode(inp)) for _ in range(length)] # predict for whole seq including future
315
+ # only look up to 10 words ahead if we're doing count syllables, since we'll filter out anything more than 10 syllables ahead anyway
316
+ future_word_position_max = len(original_sentence) - 1
317
+ future_word_position = 0
318
+ future_word = 'placeholder'
319
+ unstripped_future_word = future_word
320
+ future_word = future_word.strip().strip(string.punctuation) # NOTE: we didn't strip punctuation for the topic bag of words paper experiments for our method. it doesn't make much difference, though.
321
+ if not failed:
322
+ word_log_prob, future_word = 0, 0
323
+ pad_id = self.parent.gpt_pad_id
324
+ example = (inp, length, future_word, word_log_prob, pad_id, classification_label, syllables_to_go, future_word_num_syllables, rhyme_group_index)
325
+ valid = not failed
326
+ elif self.parent.rhyme:
327
+ failed = False
328
+ future_word_num_syllables, rhyme_group_index = -1, -1
329
+ raw_sentence, classification_label = self.data[self.pos], -1
330
+ original_sentence = raw_sentence.split()
331
+ sentence = self.parent.tokenizer.encode(raw_sentence, return_tensors='pt')[0]
332
+ length = len(sentence)
333
+ min_sentence_length = MIN_SENTENCE_LENGTH
334
+ if len(sentence) > min_sentence_length: # set to 3. well, everything in data is > 3 for the bag of words task
335
+ pos_to_split = random.randint(1, length - 1) # for lm, learn all positions at once
336
+ inp = sentence[:pos_to_split]
337
+ length = len(inp)
338
+ num_words_in_input = len(self.parent.tokenizer.decode(inp).split())
339
+ if not failed and num_words_in_input < len(original_sentence):
340
+ # only look up to 10 words ahead if we're doing count syllables, since we'll filter out anything more than 10 syllables ahead anyway
341
+ future_word_position_max = min(len(original_sentence) - 1, num_words_in_input + MAX_COUNT_SYLLABLE_DIST)
342
+ future_word_position = random.randint(num_words_in_input-1, future_word_position_max) # allow the last possibly partial word though
343
+ future_word = original_sentence[future_word_position]
344
+ unstripped_future_word = future_word
345
+ future_word = future_word.strip().strip(string.punctuation) # NOTE: we didn't strip punctuation for the topic bag of words paper experiments for our method. it doesn't make much difference, though.
346
+
347
+ words_in_between = original_sentence[num_words_in_input-1:future_word_position+1]
348
+ syllables_to_go = count_syllables(' '.join(words_in_between))
349
+ if syllables_to_go > MAX_COUNT_SYLLABLE_DIST:
350
+ failed = True
351
+ future_word_num_syllables = count_syllables(future_word)
352
+ rhyme_group = self.parent.word2rhyme_group[future_word]
353
+ rhyme_group_index = self.parent.rhyme_group2index[rhyme_group]
354
+ # truncate context a bit since we're just doing couplets. random length from 1 to max desired length for this purpose.
355
+ desired_length = random.randint(1, MAX_COUNT_SYLLABLE_INPUT_LENGTH)
356
+ inp = inp[-desired_length:]
357
+ length = len(inp)
358
+
359
+ if not failed and future_word in self.parent.word2index.keys():
360
+ word_log_prob = math.log(self.parent.rhyme_group_counts[rhyme_group] / self.parent.total_rhyme_groups)
361
+ future_word = rhyme_group_index # future conditioning is just the rhyme group in this case
362
+ pad_id = self.parent.gpt_pad_id
363
+ example = (inp, length, future_word, word_log_prob, pad_id, classification_label, syllables_to_go, future_word_num_syllables, rhyme_group_index)
364
+ valid = not failed
365
+ elif self.parent.newline:
366
+ failed = False
367
+ future_word_num_syllables, rhyme_group_index = -1, -1
368
+ raw_sentence, classification_label = self.data[self.pos], -1
369
+ original_sentence = raw_sentence.split()
370
+ sentence = self.parent.tokenizer.encode(raw_sentence, return_tensors='pt')[0]
371
+ length = len(sentence)
372
+ min_sentence_length = MIN_SENTENCE_LENGTH
373
+ if len(sentence) > min_sentence_length: # set to 3. well, everything in data is > 3 for the bag of words task
374
+ pos_to_split = random.randint(1, length - 1) # for lm, learn all positions at once
375
+ inp = sentence[:pos_to_split]
376
+ while pos_to_split < len(sentence):
377
+ if len(self.parent.tokenizer.decode(inp).split()) == len(self.parent.tokenizer.decode(sentence[:pos_to_split + 1]).split()):
378
+ pos_to_split += 1
379
+ inp = sentence[:pos_to_split]
380
+ else:
381
+ break
382
+ length = len(inp)
383
+ num_words_in_input = len(self.parent.tokenizer.decode(inp).split())
384
+ if not failed and num_words_in_input < len(original_sentence):
385
+ # only look up to 10 words ahead if we're doing count syllables, since we'll filter out anything more than 10 syllables ahead anyway
386
+ future_word_position_max = len(original_sentence) - 1
387
+ future_word_position = random.randint(num_words_in_input-1, future_word_position_max) # allow the last possibly partial word though
388
+ future_word = original_sentence[future_word_position]
389
+ unstripped_future_word = future_word
390
+ future_word = future_word.strip().strip(string.punctuation) # NOTE: we didn't strip punctuation for the topic bag of words paper experiments for our method. it doesn't make much difference, though.
391
+
392
+ # future_word = original_sentence[-1] # useful for debugging
393
+ words_in_between = original_sentence[num_words_in_input-1:future_word_position+1]
394
+ syllables_to_go = count_syllables(' '.join(words_in_between))
395
+ if syllables_to_go > MAX_COUNT_SYLLABLE_DIST:
396
+ failed = True
397
+ # truncate context a bit since we're just doing couplets. random length from 1 to max desired length for this purpose.
398
+ desired_length = random.randint(1, MAX_COUNT_SYLLABLE_INPUT_LENGTH)
399
+ # desired_length = 10 # useful for debugging
400
+ inp = inp[-desired_length:]
401
+ length = len(inp)
402
+ true_label = 1 if unstripped_future_word.strip()[-1] in PHRASE_ENDS else 0 # common ways to end a phrase
403
+ classification_label = [-1 for _ in range(length)]
404
+ classification_label[-1] = true_label # only learn at the last position
405
+ if not failed and future_word in self.parent.word2index.keys():
406
+ word_log_prob = math.log(self.parent.vocab[future_word] / self.parent.total_words) # roughly baseline prob of word under noise model
407
+ future_word = self.parent.word2index[future_word]
408
+ pad_id = self.parent.gpt_pad_id
409
+ example = (inp, length, future_word, word_log_prob, pad_id, classification_label, syllables_to_go, future_word_num_syllables, rhyme_group_index)
410
+ valid = not failed
411
+ else:
412
+ raise NotImplementedError
413
+
414
+ self.pos += increment
415
+ return example
fudge/eval_formality_metrics.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from argparse import ArgumentParser
2
+ import pickle
3
+ import os
4
+ import math
5
+
6
+ import sacrebleu
7
+ import numpy as np
8
+ import torch
9
+ from transformers import AutoTokenizer, AutoModelWithLMHead, pipeline, set_seed, GPT2Tokenizer, GPT2Model, MarianTokenizer, MarianMTModel
10
+
11
+ from constants import *
12
+ from model import Model
13
+ from util import save_checkpoint, ProgressMeter, AverageMeter, num_params
14
+
15
+ def avg_formality(preds, model, tokenizer, device='cuda'):
16
+ probs = []
17
+ for sent in preds:
18
+ encoded_input = tokenizer.encode(sent, return_tensors='pt').to(device)
19
+ lengths = torch.LongTensor([encoded_input.shape[1]]).to(device)
20
+ scores = model(encoded_input, lengths=lengths) # batch x seq
21
+ score = scores.flatten()[-1].item()
22
+ probs.append(math.exp(score) / (1 + math.exp(score))) # sigmoided score = prob
23
+ return np.mean(probs)
24
+
25
+ if __name__=='__main__':
26
+ parser = ArgumentParser()
27
+ parser.add_argument('--pred', type=str)
28
+ parser.add_argument('--ref', type=str, nargs='*', help='bleu refs')
29
+ parser.add_argument('--ckpt', type=str, help='formality classifier')
30
+ parser.add_argument('--dataset_info', type=str)
31
+ parser.add_argument('--device', type=str, default='cuda', choices=['cpu', 'cuda'])
32
+ parser.add_argument('--model_string', type=str, default='Helsinki-NLP/opus-mt-es-en')
33
+
34
+ args = parser.parse_args()
35
+
36
+ # refs = [['The dog bit the man.', 'It was not unexpected.', 'The man bit him first.'],
37
+ # ['The dog had bit the man.', 'No one was surprised.', 'The man had bitten the dog.']]
38
+ # sys = ['The dog bit the man.', "It wasn't surprising.", 'The man had just bitten him.']
39
+ print('num ref files', len(args.ref))
40
+ pred = []
41
+ with open(args.pred, 'r') as rf:
42
+ for line in rf:
43
+ pred.append(line.strip())
44
+ refs = []
45
+ for ref_file in args.ref:
46
+ ref = []
47
+ with open(ref_file, 'r') as rf:
48
+ for line in rf:
49
+ ref.append(line.strip())
50
+ assert len(ref) == len(pred)
51
+ refs.append(ref)
52
+ bleu = sacrebleu.corpus_bleu(pred, refs)
53
+ print('BLEU score:', bleu.score)
54
+
55
+ with open(args.dataset_info, 'rb') as rf:
56
+ dataset_info = pickle.load(rf)
57
+
58
+ tokenizer = MarianTokenizer.from_pretrained(args.model_string)
59
+ tokenizer.add_special_tokens({'pad_token': PAD_TOKEN})
60
+ pad_id = tokenizer.encode(PAD_TOKEN)[0]
61
+
62
+ checkpoint = torch.load(args.ckpt, map_location=args.device)
63
+ model_args = checkpoint['args']
64
+ conditioning_model = Model(model_args, pad_id, len(dataset_info.index2word)) # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
65
+ conditioning_model.load_state_dict(checkpoint['state_dict'])
66
+ conditioning_model = conditioning_model.to(args.device)
67
+ conditioning_model.eval()
68
+ print("=> loaded checkpoint '{}' (epoch {})"
69
+ .format(args.ckpt, checkpoint['epoch']))
70
+ print('num params', num_params(conditioning_model))
71
+
72
+ print('avg formality prob according to model', avg_formality(pred, conditioning_model, tokenizer, device=args.device))
73
+
fudge/eval_poetry_metrics.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from argparse import ArgumentParser
2
+ import math
3
+ import string
4
+
5
+ from tqdm import tqdm
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from transformers import AutoTokenizer, AutoModelWithLMHead, AutoModelForSequenceClassification
10
+
11
+ from poetry_util import is_iambic, perfect_rhyme_end, count_syllables
12
+ from constants import *
13
+
14
+
15
+ def conditional_perplexity(prefix, pred, tokenizer, model, device='cuda', sep_losses=False):
16
+ # calculate perplexity on pred only, conditioned on prefix
17
+ sentence = prefix + pred
18
+ sos_token = tokenizer.decode([0])
19
+ prefix_tensor_input = tokenizer.encode(sos_token + prefix.replace(EOT_TOKEN, ' ').strip(), return_tensors='pt').to(device)
20
+ full_tensor_input = tokenizer.encode(sos_token + sentence.replace(EOT_TOKEN, ' ').strip(), return_tensors='pt').to(device)
21
+ if sep_losses:
22
+ prefix_loss = model(prefix_tensor_input, labels=prefix_tensor_input)[0].sum()
23
+ full_loss = model(full_tensor_input, labels=full_tensor_input)[0].sum()
24
+ else:
25
+ prefix_loss = model(prefix_tensor_input, labels=prefix_tensor_input)[0] * (prefix_tensor_input.shape[1]-1) # neg log prob of prefix
26
+ full_loss = model(full_tensor_input, labels=full_tensor_input)[0] * (full_tensor_input.shape[1]-1) # neg log prob of full seq
27
+ pred_loss = full_loss - prefix_loss # neg log prob of preds given prefix
28
+ avg_pred_loss = pred_loss / (full_tensor_input.shape[1] - prefix_tensor_input.shape[1])
29
+ return math.exp(avg_pred_loss.item())
30
+
31
+
32
+ def grammaticality(sentences, tokenizer, model, device='cuda'):
33
+ with torch.no_grad():
34
+ total_good = 0
35
+ for sent in tqdm(sentences, total=len(sentences)):
36
+ good_prob = F.softmax(model(tokenizer.encode(sent, return_tensors='pt').to(device))[0].flatten(), dim=0)[1]
37
+ total_good += good_prob
38
+ return total_good / len(sentences) # avg probability of grammaticality according to model
39
+
40
+
41
+ def distinctness(sentences):
42
+ d1 = set()
43
+ d2 = set()
44
+ d3 = set()
45
+ total_words = 0
46
+ for sentence in sentences:
47
+ o = sentence.split(' ')
48
+ total_words += len(o)
49
+ d1.update(o)
50
+ for i in range(len(o) - 1):
51
+ d2.add(o[i] + '_' + o[i+1])
52
+ for i in range(len(o) - 2):
53
+ d3.add(o[i] + '_' + o[i+1] + '_' + o[i+2])
54
+ return len(d1) / total_words, len(d2) / total_words, len(d3) / total_words
55
+
56
+
57
+ if __name__=='__main__':
58
+ parser = ArgumentParser()
59
+ parser.add_argument('--pred_file', type=str)
60
+ parser.add_argument('--prefix_file', type=str)
61
+ parser.add_argument('--device', type=str, default='cuda', choices=['cpu', 'cuda'])
62
+ args = parser.parse_args()
63
+
64
+ preds = []
65
+ with open(args.pred_file, 'r') as rf:
66
+ for line in rf:
67
+ preds.append(line[:-1]) # drop \n but not beginning spaces if any
68
+ prefixes = []
69
+ with open(args.prefix_file, 'r') as rf:
70
+ for line in rf:
71
+ prefixes.append(line.strip())
72
+ assert len(prefixes) == len(preds)
73
+ rhymes = 0
74
+ iambic = 0
75
+ ten_syllables = 0
76
+ end = 0
77
+ diff_rhymes = 0
78
+ all_success = 0
79
+ total = len(prefixes)
80
+ for prefix, pred in zip(prefixes, preds):
81
+ if is_iambic(pred):
82
+ iambic += 1
83
+ if perfect_rhyme_end(prefix, pred):
84
+ rhymes += 1
85
+ if prefix.split()[-1].strip(string.punctuation) != pred.split()[-1].strip(string.punctuation):
86
+ diff_rhymes += 1
87
+ if count_syllables(pred) == 10:
88
+ ten_syllables += 1
89
+ if pred.strip()[-1] in PHRASE_ENDS:
90
+ end += 1
91
+ if is_iambic(pred) and perfect_rhyme_end(prefix, pred) and count_syllables(pred) == 10 and pred.strip()[-1] in PHRASE_ENDS:
92
+ all_success += 1
93
+ print('iambic', iambic, 'out of', total, ', frac', iambic / total)
94
+ print('rhymes', rhymes, 'out of', total, ', frac', rhymes / total)
95
+ print('end sentence', end, 'out of', total, ', frac', end / total)
96
+ print('10 syllables', ten_syllables, 'out of', total, ', frac', ten_syllables / total)
97
+ print('all success', all_success, 'out of', total, ', frac', all_success / total)
98
+ print('rhymes with diff word', diff_rhymes, 'out of', total, ', frac', diff_rhymes / total)
99
+
100
+ print('distinctness', distinctness(preds))
101
+
102
+ grammar_tokenizer = AutoTokenizer.from_pretrained('textattack/roberta-base-CoLA')
103
+ grammar_model = AutoModelForSequenceClassification.from_pretrained('textattack/roberta-base-CoLA').to(args.device)
104
+ grammar_model.eval()
105
+ print('grammaticality', grammaticality(preds, grammar_tokenizer, grammar_model, device=args.device))
106
+
107
+ perplexities = []
108
+ eval_tokenizer = AutoTokenizer.from_pretrained('transfo-xl-wt103')
109
+ eval_model = AutoModelWithLMHead.from_pretrained('transfo-xl-wt103').to(args.device)
110
+ eval_model.eval()
111
+ for prefix, pred in zip(prefixes, preds):
112
+ perplexities.append(conditional_perplexity(prefix, pred, eval_tokenizer, eval_model, device=args.device, sep_losses=True))
113
+ print('transformer xl perplexity', np.mean(perplexities), '+/-', np.std(perplexities))
114
+
115
+ perplexities = []
116
+ eval_tokenizer = AutoTokenizer.from_pretrained('openai-gpt')
117
+ eval_model = AutoModelWithLMHead.from_pretrained('openai-gpt').to(args.device)
118
+ eval_model.eval()
119
+ for prefix, pred in zip(prefixes, preds):
120
+ perplexities.append(conditional_perplexity(prefix, pred, eval_tokenizer, eval_model, device=args.device))
121
+ print('gpt perplexity', np.mean(perplexities), '+/-', np.std(perplexities))
122
+
123
+ # NOTE: uncomment this section with the path to the Shakespeare-finetuned GPT to evaluate this metric. it's in ckpt/poetry/gpt_finetune_shakespeare.pth.tar.
124
+ # eval_tokenizer = AutoTokenizer.from_pretrained('openai-gpt')
125
+ # eval_model = AutoModelWithLMHead.from_pretrained('openai-gpt').to(args.device)
126
+ # checkpoint = torch.load('***PATH_TO_SHAKESPEARE_FINETUNED_GPT***', map_location=args.device)
127
+ # mod_dict = {}
128
+ # for key in checkpoint['state_dict']:
129
+ # mod_dict[key.replace('classifier.', '')] = checkpoint['state_dict'][key]
130
+ # eval_model.load_state_dict(mod_dict)
131
+ # eval_model.eval()
132
+ # perplexities = []
133
+ # for prefix, pred in zip(prefixes, preds):
134
+ # perplexities.append(conditional_perplexity(prefix, pred, eval_tokenizer, eval_model, device=args.device))
135
+ # print('shakespeare finetuned perplexity', np.mean(perplexities), '+/-', np.std(perplexities))
fudge/eval_topic_metrics.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import time
4
+ import pickle
5
+ import math
6
+ from argparse import ArgumentParser
7
+ from collections import defaultdict
8
+ import string
9
+ import csv
10
+
11
+ from tqdm import tqdm
12
+ import numpy as np
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ from transformers import AutoTokenizer, AutoModelWithLMHead, AutoModelForSequenceClassification
17
+
18
+ from data import Dataset
19
+ from model import Model
20
+ from util import save_checkpoint, ProgressMeter, AverageMeter, num_params, pad_mask
21
+ from predict import predict
22
+ from constants import *
23
+
24
+ def tw_topic_eval(sentences, category, tw_dir, cap=None):
25
+ # num matches of distinct words
26
+ words = []
27
+ with open(os.path.join(tw_dir, category + '.txt'), 'r') as rf:
28
+ for line in rf:
29
+ words.append(line.strip().lower())
30
+ num_match = 0
31
+ for sent in sentences:
32
+ sent_match = 0
33
+ sent = sent.strip().lower().split()
34
+ sent = [tok.strip(string.punctuation) for tok in sent]
35
+ for word in words:
36
+ if word in sent:
37
+ sent_match += 1
38
+ if cap is None:
39
+ num_match += sent_match
40
+ else:
41
+ num_match += min(cap, sent_match)
42
+ return num_match
43
+
44
+
45
+ def perplexity(sentences, tokenizer, model, device='cuda'):
46
+ # calculate perplexity
47
+ with torch.no_grad():
48
+ ppl = []
49
+ sos_token = tokenizer.decode([0])
50
+ for sentence in tqdm(sentences, total=len(sentences)):
51
+ full_tensor_input = tokenizer.encode(sos_token + sentence.replace(EOT_TOKEN, ' ').strip(), return_tensors='pt').to(device)
52
+ full_loss = model(full_tensor_input, labels=full_tensor_input)[0].mean()
53
+ ppl.append(torch.exp(full_loss).flatten().cpu().item())
54
+ return np.mean(ppl), np.std(ppl)
55
+
56
+
57
+ def grammaticality(sentences, tokenizer, model, device='cuda'):
58
+ with torch.no_grad():
59
+ total_good = 0
60
+ for sent in tqdm(sentences, total=len(sentences)):
61
+ good_prob = F.softmax(model(tokenizer.encode(sent, return_tensors='pt').to(device))[0].flatten(), dim=0)[1]
62
+ total_good += good_prob
63
+ return total_good / len(sentences) # avg probability of grammaticality according to model
64
+
65
+
66
+ def distinctness(results):
67
+ d1, d2, d3 = defaultdict(lambda: set()), defaultdict(lambda: set()), defaultdict(lambda: set())
68
+ total_words = defaultdict(lambda: 0)
69
+ for cw, outputs in results.items():
70
+ for o in outputs:
71
+ o = o.replace(EOT_TOKEN, ' ').strip().split(' ')
72
+ o = [str(x) for x in o]
73
+ total_words[cw] += len(o)
74
+ d1[cw].update(o)
75
+ for i in range(len(o) - 1):
76
+ d2[cw].add(o[i] + ' ' + o[i+1])
77
+ for i in range(len(o) - 2):
78
+ d3[cw].add(o[i] + ' ' + o[i+1] + ' ' + o[i+2])
79
+ return_info = []
80
+ avg_d1, avg_d2, avg_d3 = 0, 0, 0
81
+ for cw in total_words.keys():
82
+ return_info.append((cw, 'DISTINCTNESS', len(d1[cw]) / total_words[cw], len(d2[cw]) / total_words[cw], len(d3[cw]) / total_words[cw]))
83
+ avg_d1 += len(d1[cw]) / total_words[cw]
84
+ avg_d2 += len(d2[cw]) / total_words[cw]
85
+ avg_d3 += len(d3[cw]) / total_words[cw]
86
+ avg_d1, avg_d2, avg_d3 = avg_d1 / len(total_words.keys()), avg_d2 / len(total_words.keys()), avg_d3 / len(total_words.keys())
87
+ return return_info, (avg_d1, avg_d2, avg_d3)
88
+
89
+
90
+ if __name__=='__main__':
91
+ parser = ArgumentParser()
92
+ parser.add_argument('--log_file', type=str, required=True, help='where to load results from')
93
+ parser.add_argument('--tw_dir', type=str, default='test_wordlists', help='test wordlists')
94
+ parser.add_argument('--batch_size', type=int, default=8, help='max samples at a time')
95
+ parser.add_argument('--cap_per_example', type=int, default=None, help='max matches to count per sentence')
96
+ parser.add_argument('--device', type=str, default='cuda', choices=['cpu', 'cuda'])
97
+ args = parser.parse_args()
98
+
99
+ tw_topic_match_c_total = 0
100
+ category_totals_c = defaultdict(lambda:0)
101
+ results = defaultdict(lambda: [])
102
+ with open(args.log_file, 'r') as rf:
103
+ data = list(csv.DictReader(rf))
104
+ for line in data:
105
+ results[line['category']].append(line['generation'])
106
+
107
+ all_c_sents = []
108
+ for category, condition_results in results.items():
109
+ tw_topic_match_c = tw_topic_eval(condition_results, category, args.tw_dir, cap=args.cap_per_example)
110
+ tw_topic_match_c_total += tw_topic_match_c
111
+ category_totals_c[category] += tw_topic_match_c
112
+ all_c_sents += condition_results
113
+
114
+ print('Test wordlist matches (divide by num outputs to get the Success metric):', tw_topic_match_c_total)
115
+ print('per category:', category_totals_c)
116
+
117
+ dist_info_by_category, dist_overall = distinctness(results)
118
+ print('Overall avg distinctness:', dist_overall)
119
+ print('per category:', dist_info_by_category)
120
+
121
+ grammar_tokenizer = AutoTokenizer.from_pretrained('textattack/roberta-base-CoLA')
122
+ grammar_model = AutoModelForSequenceClassification.from_pretrained('textattack/roberta-base-CoLA').to(args.device)
123
+ grammar_model.eval()
124
+ print('grammaticality:', grammaticality(all_c_sents, grammar_tokenizer, grammar_model, device=args.device))
125
+
126
+ eval_tokenizer = AutoTokenizer.from_pretrained('openai-gpt')
127
+ eval_model = AutoModelWithLMHead.from_pretrained('openai-gpt').to(args.device)
128
+ eval_model.eval()
129
+ print('GPT perplexity:', perplexity(all_c_sents, eval_tokenizer, eval_model))
130
+
131
+ eval_tokenizer = AutoTokenizer.from_pretrained('transfo-xl-wt103')
132
+ eval_model = AutoModelWithLMHead.from_pretrained('transfo-xl-wt103').to(args.device)
133
+ eval_model.eval()
134
+ print('TFXL perplexity:', perplexity(all_c_sents, eval_tokenizer, eval_model))
fudge/evaluate_clickbait.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import time
4
+ import pickle
5
+ import math
6
+ from argparse import ArgumentParser
7
+
8
+ from typing import Iterable, List, Optional, Tuple
9
+
10
+ from tqdm import tqdm
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ from transformers import AutoTokenizer, AutoModelWithLMHead
16
+ from torch import Tensor
17
+
18
+ from fudge.data import Dataset
19
+ from fudge.model import Model
20
+ from fudge.util import num_params
21
+ from fudge.constants import *
22
+
23
+
24
+
25
+ tokenizer = AutoTokenizer.from_pretrained('google/pegasus-xsum')
26
+ classifier_tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-mpnet-base-v2')
27
+
28
+
29
+ def main(args):
30
+ with open(args.dataset_info, 'rb') as rf:
31
+ dataset_info = pickle.load(rf)
32
+
33
+ article_content = """Australian actor Guy Pearce will return for the iconic soap Neighbours finale on August 1 to reprise his role as Mike Young.
34
+ Guy, 54, played the troubled Mike from 1986 to 1989, and is now set to make a comeback on the show after 33 years, Metro.co.uk reports.
35
+ The star's character arcs explored the implications of domestic abuse, student-teacher relationships and dealing with loss of loved ones.
36
+ Speaking to Metro.co.uk, Guy said: 'It is very exciting and surreal at the same time being back on set again, however it feels like coming home.
37
+ 'It's where it all started for me professionally. I've been asked to come back on occasions over the years and wondered if it was the right thing
38
+ to do, but once I knew the show was finishing, I knew I had to do it.'He added that there is 'nothing like being here all together again'
39
+ , even though he's had a chance to catch-up with other cast members."""
40
+
41
+ tokenizer.add_special_tokens({'pad_token': PAD_TOKEN})
42
+ pad_id = tokenizer.encode(PAD_TOKEN)[0]
43
+
44
+ #For loading Clickbait summarizer
45
+ model = AutoModelWithLMHead.from_pretrained(args.model_string, return_dict=True).to(args.device)
46
+
47
+ model.eval()
48
+
49
+ checkpoint = torch.load(args.ckpt, map_location=args.device)
50
+ model_args = checkpoint['args']
51
+ conditioning_model = Model(model_args, pad_id, len(dataset_info.index2word)) # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
52
+ conditioning_model.load_state_dict(checkpoint['state_dict'])
53
+ conditioning_model = conditioning_model.to(args.device)
54
+ conditioning_model.eval()
55
+ print("=> loaded checkpoint '{}' (epoch {})"
56
+ .format(args.ckpt, checkpoint['epoch']))
57
+ print('num params', num_params(conditioning_model))
58
+
59
+ while True:
60
+ results = generate_clickbait(model,
61
+ tokenizer,
62
+ conditioning_model,
63
+ [args.input_text],
64
+ dataset_info,
65
+ precondition_topk=args.precondition_topk,
66
+ do_sample=args.do_sample,
67
+ length_cutoff=args.length_cutoff,
68
+ condition_lambda=args.condition_lambda,
69
+ article_content=article_content,
70
+ device=args.device)
71
+ # print(results)
72
+ import pdb; pdb.set_trace()
73
+
74
+
75
+ def generate_clickbait(model,
76
+ tokenizer,
77
+ conditioning_model,
78
+ input_text,
79
+ dataset_info,
80
+ precondition_topk,
81
+ length_cutoff,
82
+ condition_lambda=1.0,
83
+ article_content=None,
84
+ device='cuda'):
85
+ with torch.no_grad():
86
+ batch_size = len(input_text)
87
+ # encoded_input_article = [tokenizer.encode(article_content, return_tensors='pt',add_special_tokens=False).to(device)] # batch x seq
88
+ encoded_input_article = tokenizer(article_content, return_tensors='pt',add_special_tokens=False, max_length=512).to(device) # batch x seq
89
+ # encoded_input_article = torch.cat(encoded_input_article, dim=0)
90
+ # attention_mask = encoded_input_article.new_ones(encoded_input_article.shape).to(device)
91
+
92
+ # CHANGE=ko
93
+ encoded_input = tokenizer('<pad>', return_tensors='pt',add_special_tokens=False).to(device) # batch x seq
94
+ # encoded_input = tokenizer('<pad>'+ input_text[0], return_tensors='pt',add_special_tokens=False).to(device) # batch x seq
95
+ # encoded_input = torch.cat(encoded_input, dim=0)
96
+ encoded_input = encoded_input['input_ids']
97
+
98
+
99
+ lengths = torch.LongTensor([encoded_input.shape[1]]).to(device)
100
+ # lengths = 1
101
+
102
+ past = None
103
+ use_cache = True
104
+
105
+ # CHANGE
106
+ # model_kwargs = {'encoder_outputs': model.get_encoder()(encoded_input_article, attention_mask=attention_mask)}
107
+ # print(encoded_input_article)
108
+ # print(encoded_input_article['input_ids'].shape, encoded_input_article['attention_mask'].shape)
109
+ model_kwargs = {'encoder_outputs': model.get_encoder()(input_ids=encoded_input_article['input_ids'],
110
+ attention_mask=encoded_input_article['attention_mask'],
111
+ return_dict=True,
112
+ output_attentions=False,
113
+ output_hidden_states=False),
114
+ }
115
+
116
+ while lengths.max() < length_cutoff:
117
+ model_inputs = model.prepare_inputs_for_generation(
118
+ input_ids = encoded_input_article['input_ids'],
119
+ decoder_input_ids=encoded_input,
120
+ # past=past,
121
+ attention_mask=encoded_input_article['attention_mask'],
122
+ use_cache=use_cache,
123
+ **model_kwargs
124
+ )
125
+
126
+ outputs = model(**model_inputs, return_dict=True)
127
+ logits = outputs.logits[:, -1, :]
128
+
129
+ if "past_key_values" in outputs:
130
+ model_kwargs["past"] = outputs.past_key_values
131
+
132
+ # logits = model(encoded_input)[0][:, -1, :] # batch x vocab
133
+ top_logits, top_indices = logits.topk(precondition_topk, dim=1) # batch x topk
134
+ new_input_candidates = torch.cat([encoded_input.unsqueeze(1).expand(-1, precondition_topk, -1), top_indices.unsqueeze(2)], dim=2) # batch x topk x seq+1
135
+ expanded_lengths = (lengths + 1).unsqueeze(1).expand(batch_size, precondition_topk) # batch x topk
136
+
137
+ if condition_lambda == 0:
138
+ condition_logits = torch.zeros_like(top_logits).float()
139
+ condition_logits = condition_logits.view(batch_size, precondition_topk, -1) # batch x topk x N
140
+ else:
141
+ decoded_outputs = tokenizer.batch_decode(new_input_candidates.view(-1, new_input_candidates.size(-1)), clean_up_tokenization_spaces=False)
142
+ resulting_tokenization = classifier_tokenizer(decoded_outputs, add_special_tokens=False, padding='longest')
143
+ encoded_with_classifier = resulting_tokenization['input_ids']
144
+ attention_mask = torch.tensor(resulting_tokenization['attention_mask']).to(model.device)
145
+ tplus1_candidates_classifier = torch.tensor(encoded_with_classifier).view(batch_size, precondition_topk, -1).to(model.device)
146
+
147
+ condition_logits = conditioning_model(tplus1_candidates_classifier.flatten(0, 1), # batch*topk x seq+1
148
+ expanded_lengths.flatten(0, 1), # batch*topk
149
+ None,
150
+ None,
151
+ None,
152
+ attention_mask=attention_mask
153
+ )
154
+ condition_logits = condition_logits.view(batch_size, precondition_topk, -1) # batch x topk x N
155
+ condition_logits = condition_logits - torch.log(1 + torch.exp(condition_logits)) # get correct log probs
156
+
157
+ condition_logits = torch.mean(condition_logits, dim=2)
158
+ full_logits = top_logits + condition_logits * condition_lambda # batch x topk
159
+ post_logits, post_indices = full_logits.topk(precondition_topk, dim=1)
160
+ post_probs = F.softmax(post_logits, dim=1)
161
+ # index_into_top_indices = post_indices[torch.arange(batch_size).to(post_indices.device), torch.multinomial(post_probs, 1).flatten()] # batch
162
+ index_into_top_indices = post_indices[:, torch.multinomial(post_probs, 1).flatten()] # batch
163
+
164
+ # next_indices = top_indices[torch.arange(batch_size).to(top_indices.device), index_into_top_indices] # batch
165
+ next_indices = top_indices[:, index_into_top_indices] # batch
166
+
167
+ # encoded_input = torch.cat([encoded_input, next_indices.unsqueeze(1)], dim=1) # batch x seq+1
168
+ encoded_input = torch.cat([encoded_input, next_indices.squeeze(1)], dim=1)
169
+ lengths = lengths + 1 # batch
170
+
171
+ # print(tokenizer.decode(encoded_input[0], add_special_tokens=False))
172
+ return [tokenizer.decode(s) for s in encoded_input]
173
+
174
+
175
+ if __name__=='__main__':
176
+ parser = ArgumentParser()
177
+
178
+ # DATA
179
+ parser.add_argument('--ckpt', type=str, required=True)
180
+ parser.add_argument('--dataset_info', type=str, required=True, help='saved dataset info')
181
+ parser.add_argument('--model_string', type=str, default='Helsinki-NLP/opus-mt-es-en')
182
+
183
+ parser.add_argument('--in_file', type=str, default=None, required=True, help='text to run pred on')
184
+
185
+ parser.add_argument('--precondition_topk', type=int, default=200, help='consider top k outputs from text generation at each step before conditioning and re-pruning')
186
+ parser.add_argument('--do_sample', action='store_true', default=False, help='sample instead of greedy')
187
+ parser.add_argument('--condition_lambda', type=float, default=1.0, help='lambda weight on conditioning model')
188
+ parser.add_argument('--length_cutoff', type=int, default=512, help='max length')
189
+
190
+ parser.add_argument('--seed', type=int, default=1, help='random seed')
191
+ parser.add_argument('--device', type=str, default='cuda', choices=['cpu', 'cuda'])
192
+ parser.add_argument('--debug', action='store_true', default=False)
193
+
194
+ args = parser.parse_args()
195
+
196
+ random.seed(args.seed)
197
+ np.random.seed(args.seed)
198
+ torch.manual_seed(args.seed)
199
+
200
+ main(args)
fudge/evaluate_formality.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import time
4
+ import pickle
5
+ import math
6
+ from argparse import ArgumentParser
7
+ from collections import namedtuple
8
+
9
+ from tqdm import tqdm
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from transformers import AutoTokenizer, AutoModelWithLMHead, pipeline, set_seed, GPT2Tokenizer, GPT2Model, MarianTokenizer, MarianMTModel
15
+
16
+ from data import Dataset
17
+ from model import Model
18
+ from util import save_checkpoint, ProgressMeter, AverageMeter, num_params
19
+ from constants import *
20
+ from predict_formality import predict_formality
21
+
22
+ def main(args):
23
+ with open(args.dataset_info, 'rb') as rf:
24
+ dataset_info = pickle.load(rf)
25
+ tokenizer = MarianTokenizer.from_pretrained(args.model_string)
26
+ tokenizer.add_special_tokens({'pad_token': PAD_TOKEN})
27
+ pad_id = tokenizer.encode(PAD_TOKEN)[0]
28
+ model = MarianMTModel.from_pretrained(args.model_string, return_dict=True).to(args.device)
29
+ if args.model_path is not None:
30
+ if os.path.isdir(args.model_path):
31
+ for _, _, files in os.walk(args.model_path):
32
+ for fname in files:
33
+ if fname.endswith('.ckpt'):
34
+ args.model_path = os.path.join(args.model_path, fname)
35
+ break
36
+ ckpt = torch.load(args.model_path, map_location=torch.device(args.device))
37
+ try:
38
+ model.load_state_dict(ckpt['state_dict'], strict=False)
39
+ except:
40
+ state_dict = {}
41
+ for key in ckpt['state_dict'].keys():
42
+ assert key.startswith('model.')
43
+ state_dict[key[6:]] = ckpt['state_dict'][key]
44
+ model.load_state_dict(state_dict)
45
+ model.eval()
46
+
47
+ checkpoint = torch.load(args.ckpt, map_location=args.device)
48
+ model_args = checkpoint['args']
49
+ conditioning_model = Model(model_args, pad_id, len(dataset_info.index2word)) # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
50
+ conditioning_model.load_state_dict(checkpoint['state_dict'])
51
+ conditioning_model = conditioning_model.to(args.device)
52
+ conditioning_model.eval()
53
+ if args.verbose:
54
+ print("=> loaded checkpoint '{}' (epoch {})"
55
+ .format(args.ckpt, checkpoint['epoch']))
56
+ print('num params', num_params(conditioning_model))
57
+
58
+ inputs = []
59
+ with open(args.in_file, 'r') as rf:
60
+ for line in rf:
61
+ inputs.append(line.strip())
62
+
63
+ for inp in tqdm(inputs, total=len(inputs)):
64
+ results = predict_formality(model,
65
+ tokenizer,
66
+ conditioning_model,
67
+ [inp],
68
+ dataset_info,
69
+ precondition_topk=args.precondition_topk,
70
+ do_sample=args.do_sample,
71
+ length_cutoff=args.length_cutoff,
72
+ condition_lambda=args.condition_lambda,
73
+ device=args.device)
74
+ print(results[0])
75
+
76
+
77
+ if __name__=='__main__':
78
+ parser = ArgumentParser()
79
+
80
+ # DATA
81
+ parser.add_argument('--ckpt', type=str, required=True)
82
+ parser.add_argument('--dataset_info', type=str, required=True, help='saved dataset info')
83
+ parser.add_argument('--model_string', type=str, default='Helsinki-NLP/opus-mt-es-en')
84
+ parser.add_argument('--model_path', type=str, default=None)
85
+
86
+ parser.add_argument('--in_file', type=str, default=None, required=True, help='file containing text to run pred on')
87
+
88
+ parser.add_argument('--precondition_topk', type=int, default=200, help='consider top k outputs from gpt at each step before conditioning and re-pruning')
89
+ parser.add_argument('--do_sample', action='store_true', default=False, help='sample or greedy; only greedy implemented')
90
+ parser.add_argument('--condition_lambda', type=float, default=1.0, help='lambda weight on conditioning model')
91
+ parser.add_argument('--length_cutoff', type=int, default=512, help='max length')
92
+
93
+ parser.add_argument('--seed', type=int, default=1, help='random seed')
94
+ parser.add_argument('--device', type=str, default='cuda', choices=['cpu', 'cuda'])
95
+ parser.add_argument('--debug', action='store_true', default=False)
96
+ parser.add_argument('--verbose', action='store_true', default=False)
97
+
98
+ args = parser.parse_args()
99
+
100
+ random.seed(args.seed)
101
+ np.random.seed(args.seed)
102
+ torch.manual_seed(args.seed)
103
+
104
+ main(args)
fudge/evaluate_poetry.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import time
4
+ import pickle
5
+ import math
6
+ from argparse import ArgumentParser
7
+ import string
8
+ from collections import defaultdict
9
+
10
+ from tqdm import tqdm
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ from transformers import AutoTokenizer, AutoModelWithLMHead, pipeline, set_seed, GPT2Tokenizer, GPT2Model
16
+
17
+ from data import Dataset, load_rhyme_info
18
+ from model import Model
19
+ from util import save_checkpoint, ProgressMeter, AverageMeter, num_params
20
+ from constants import *
21
+ from poetry_util import get_rhymes, count_syllables
22
+ from predict_poetry import predict_couplet
23
+
24
+ def main(args):
25
+ with open(args.dataset_info, 'rb') as rf:
26
+ dataset_info = pickle.load(rf)
27
+ gpt_tokenizer = AutoTokenizer.from_pretrained(args.model_string)
28
+ gpt_tokenizer.add_special_tokens({'pad_token': PAD_TOKEN})
29
+ gpt_pad_id = gpt_tokenizer.encode(PAD_TOKEN)[0]
30
+ gpt_model = AutoModelWithLMHead.from_pretrained(args.model_string).to(args.device)
31
+ gpt_model.eval()
32
+
33
+ checkpoint = torch.load(args.iambic_ckpt, map_location=args.device)
34
+ model_args = checkpoint['args']
35
+ iambic_model = Model(model_args, gpt_pad_id, len(dataset_info.index2word)) # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
36
+ iambic_model.load_state_dict(checkpoint['state_dict'])
37
+ iambic_model = iambic_model.to(args.device)
38
+ iambic_model.eval()
39
+ if args.verbose:
40
+ print("=> loaded checkpoint '{}' (epoch {})"
41
+ .format(args.iambic_ckpt, checkpoint['epoch']))
42
+ print('iambic model num params', num_params(iambic_model))
43
+
44
+ with open(args.rhyme_info, 'rb') as rf:
45
+ rhyme_info = pickle.load(rf)
46
+ checkpoint = torch.load(args.rhyme_ckpt, map_location=args.device)
47
+ model_args = checkpoint['args']
48
+ rhyme_model = Model(model_args, gpt_pad_id, len(dataset_info.index2word), rhyme_group_size=len(rhyme_info.index2rhyme_group), verbose=args.verbose) # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
49
+ rhyme_model.load_state_dict(checkpoint['state_dict'])
50
+ rhyme_model = rhyme_model.to(args.device)
51
+ rhyme_model.eval()
52
+ if args.verbose:
53
+ print("=> loaded checkpoint '{}' (epoch {})"
54
+ .format(args.rhyme_ckpt, checkpoint['epoch']))
55
+ print('rhyme model num params', num_params(rhyme_model))
56
+
57
+ checkpoint = torch.load(args.newline_ckpt, map_location=args.device)
58
+ model_args = checkpoint['args']
59
+ newline_model = Model(model_args, gpt_pad_id, len(dataset_info.index2word)) # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
60
+ newline_model.load_state_dict(checkpoint['state_dict'])
61
+ newline_model = newline_model.to(args.device)
62
+ newline_model.eval()
63
+ if args.verbose:
64
+ print("=> loaded checkpoint '{}' (epoch {})"
65
+ .format(args.newline_ckpt, checkpoint['epoch']))
66
+ print('iambic model num params', num_params(newline_model))
67
+
68
+ with open(args.prefix_file, 'r') as rf:
69
+ lines = rf.readlines()
70
+ for line in tqdm(lines, total=len(lines)):
71
+ couplet = predict_couplet(gpt_model,
72
+ gpt_tokenizer,
73
+ iambic_model,
74
+ rhyme_model,
75
+ newline_model,
76
+ [line],
77
+ dataset_info,
78
+ rhyme_info,
79
+ args.precondition_topk,
80
+ args.topk,
81
+ condition_lambda=args.condition_lambda,
82
+ device=args.device)
83
+ assert len(couplet) == 2
84
+ print(couplet[1].strip().replace('\n', ''))
85
+
86
+
87
+ if __name__=='__main__':
88
+ parser = ArgumentParser()
89
+
90
+ # DATA
91
+ parser.add_argument('--iambic_ckpt', type=str, required=True)
92
+ parser.add_argument('--rhyme_ckpt', type=str, required=True)
93
+ parser.add_argument('--newline_ckpt', type=str, required=True)
94
+ parser.add_argument('--dataset_info', type=str, required=True, help='saved dataset info')
95
+ parser.add_argument('--rhyme_info', type=str, required=True, help='saved rhyme info')
96
+ parser.add_argument('--model_string', type=str, default='gpt2-medium')
97
+
98
+ parser.add_argument('--prefix_file', type=str, default=None, required=True, help='file of prefix lines for couplets')
99
+
100
+ parser.add_argument('--precondition_topk', type=int, default=200, help='consider top k outputs from gpt at each step before conditioning and re-pruning')
101
+ parser.add_argument('--topk', type=int, default=10, help='consider top k outputs from gpt at each step')
102
+ parser.add_argument('--condition_lambda', type=float, default=1.0, help='lambda weight on conditioning model')
103
+
104
+ parser.add_argument('--seed', type=int, default=1, help='random seed')
105
+ parser.add_argument('--device', type=str, default='cuda', choices=['cpu', 'cuda'])
106
+ parser.add_argument('--debug', action='store_true', default=False)
107
+ parser.add_argument('--verbose', action='store_true', default=False)
108
+
109
+ args = parser.parse_args()
110
+
111
+ random.seed(args.seed)
112
+ np.random.seed(args.seed)
113
+ torch.manual_seed(args.seed)
114
+
115
+ main(args)
fudge/evaluate_topic.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import time
4
+ import pickle
5
+ import math
6
+ from argparse import ArgumentParser
7
+ from collections import defaultdict
8
+ import string
9
+ import csv
10
+
11
+ from tqdm import tqdm
12
+ import numpy as np
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ from transformers import AutoTokenizer, AutoModelWithLMHead, pipeline, set_seed, GPT2Tokenizer, GPT2Model
17
+
18
+ from data import Dataset
19
+ from model import Model
20
+ from util import save_checkpoint, ProgressMeter, AverageMeter, num_params, pad_mask
21
+ from predict_topic import predict
22
+ from constants import *
23
+
24
+
25
+ def main(args):
26
+ with open(args.dataset_info, 'rb') as rf:
27
+ dataset_info = pickle.load(rf)
28
+ gpt_tokenizer = AutoTokenizer.from_pretrained(args.model_string)
29
+ gpt_tokenizer.add_special_tokens({'pad_token': PAD_TOKEN})
30
+ gpt_pad_id = gpt_tokenizer.encode(PAD_TOKEN)[0]
31
+ gpt_model = AutoModelWithLMHead.from_pretrained(args.model_string).to(args.device)
32
+ gpt_model.eval()
33
+
34
+ checkpoint = torch.load(args.ckpt, map_location=args.device)
35
+ model_args = checkpoint['args']
36
+ conditioning_model = Model(model_args, gpt_pad_id, len(dataset_info.index2word)) # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
37
+ conditioning_model.load_state_dict(checkpoint['state_dict'])
38
+ conditioning_model = conditioning_model.to(args.device)
39
+ conditioning_model.eval()
40
+ if args.verbose:
41
+ print("=> loaded checkpoint '{}' (epoch {})"
42
+ .format(args.ckpt, checkpoint['epoch']))
43
+ print('num params', num_params(conditioning_model))
44
+
45
+ input_texts, conditions, categories = [], [], []
46
+
47
+ if args.condition_file is not None:
48
+ with open(args.condition_file, 'r') as rf:
49
+ for line in rf:
50
+ input_texts.append(line.strip().split('\t')[0])
51
+ conditions.append(line.strip().split('\t')[1])
52
+ categories.append(None)
53
+ for cw in conditions[-1].split():
54
+ assert cw in dataset_info.word2index
55
+ else:
56
+ prefixes = []
57
+ with open(args.prefix_file, 'r') as rf:
58
+ for line in rf:
59
+ prefixes.append(line.strip())
60
+ condition_wordlists = []
61
+ for root, _, files in os.walk(args.wordlist_dir):
62
+ for fname in files:
63
+ words = []
64
+ with open(os.path.join(root, fname), 'r') as rf:
65
+ for line in rf:
66
+ word = line.strip()
67
+ if word in dataset_info.word2index:
68
+ words.append(word)
69
+ else:
70
+ if args.verbose:
71
+ print('word not found:', word)
72
+ condition_wordlists.append((' '.join(words), fname.split('.')[0]))
73
+ for p in prefixes:
74
+ for c, category in condition_wordlists:
75
+ input_texts.append(p)
76
+ conditions.append(c)
77
+ categories.append(category)
78
+
79
+ all_cr = []
80
+ pair_num = 0
81
+ for input_text, condition_words, category in tqdm(zip(input_texts, conditions, categories), total=len(conditions)):
82
+ predict_function = predict
83
+ condition_results = []
84
+ for i in range(0, args.sample_size, args.max_sample_batch):
85
+ num_samples = min(args.max_sample_batch, args.sample_size - i)
86
+ condition_results += predict_function(gpt_model,
87
+ gpt_tokenizer,
88
+ conditioning_model,
89
+ [input_text for _ in range(num_samples)],
90
+ condition_words,
91
+ dataset_info,
92
+ args.precondition_topk,
93
+ args.topk,
94
+ args.length_cutoff,
95
+ condition_lambda=args.condition_lambda,
96
+ device=args.device)
97
+ all_cr.append((input_text, category, condition_results))
98
+ pair_num += 1
99
+ if args.max_pairs > 0 and pair_num >= args.max_pairs:
100
+ break
101
+ with open(args.log_file, 'w') as wf:
102
+ writer = csv.DictWriter(wf, fieldnames=['category', 'input_text', 'generation'])
103
+ writer.writeheader()
104
+ for cr_group in all_cr:
105
+ for cr in cr_group[2]:
106
+ writer.writerow({'category': cr_group[1], 'input_text': cr_group[0], 'generation': cr})
107
+
108
+
109
+ if __name__=='__main__':
110
+ parser = ArgumentParser()
111
+
112
+ # DATA
113
+ parser.add_argument('--ckpt', type=str, required=True)
114
+ parser.add_argument('--log_file', type=str, required=True, help='file to write outputs to (csv format)')
115
+ parser.add_argument('--dataset_info', type=str, required=True, help='saved dataset info')
116
+ parser.add_argument('--model_string', type=str, default='gpt2-medium')
117
+
118
+ parser.add_argument('--condition_file', type=str, default=None, help='file of inputs and conditions')
119
+ parser.add_argument('--prefix_file', type=str, default=None, help='prefix set')
120
+ parser.add_argument('--wordlist_dir', type=str, default=None, help='dir of bow wordlists for categories')
121
+ parser.add_argument('--sample_size', type=int, default=3, help='samples per input text-condition pair')
122
+ parser.add_argument('--max_sample_batch', type=int, default=3, help='max samples at a time')
123
+ parser.add_argument('--max_pairs', type=int, default=-1, help='max input-condition pairs, for debugging quickly')
124
+
125
+ parser.add_argument('--precondition_topk', type=int, default=200, help='consider top k outputs from gpt at each step before conditioning and re-pruning')
126
+ parser.add_argument('--topk', type=int, default=10, help='consider top k outputs from gpt at each step')
127
+ parser.add_argument('--condition_lambda', type=float, default=1.0, help='lambda weight on conditioning model')
128
+ parser.add_argument('--length_cutoff', type=int, default=80, help='max length')
129
+
130
+ parser.add_argument('--seed', type=int, default=1, help='random seed')
131
+ parser.add_argument('--device', type=str, default='cuda', choices=['cpu', 'cuda'])
132
+ parser.add_argument('--debug', action='store_true', default=False)
133
+ parser.add_argument('--verbose', action='store_true', default=False)
134
+
135
+ args = parser.parse_args()
136
+
137
+ assert (args.condition_file is not None) != (args.prefix_file is not None and args.wordlist_dir is not None) # one of two interfaces for specifying
138
+
139
+ random.seed(args.seed)
140
+ np.random.seed(args.seed)
141
+ torch.manual_seed(args.seed)
142
+
143
+ main(args)
fudge/formality_data/README.md ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ `fisher_test_oracle.es` is the source-side Spanish test set.
2
+ `test_noid.cleaned_0` and `test_noid.cleaned_1` are Salesky 2019's fluent English test-time references.
fudge/formality_data/fisher_test_oracle.es ADDED
The diff for this file is too large to render. See raw diff
 
fudge/formality_data/test.noid.cleaned_0 ADDED
The diff for this file is too large to render. See raw diff
 
fudge/formality_data/test.noid.cleaned_1 ADDED
The diff for this file is too large to render. See raw diff
 
fudge/main.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import time
4
+ import pickle
5
+ import math
6
+ from argparse import ArgumentParser
7
+
8
+ from tqdm import tqdm
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn as nn
12
+
13
+ from data import Dataset
14
+ from model import Model
15
+ from util import save_checkpoint, ProgressMeter, AverageMeter, num_params, pad_mask
16
+ from constants import *
17
+
18
+
19
+ def train(model, dataset, optimizer, criterion, epoch, args, data_start_index):
20
+ model.train()
21
+ if data_start_index == 0:
22
+ dataset.shuffle('train', seed=epoch + args.seed)
23
+ if args.epoch_max_len is not None:
24
+ data_end_index = min(data_start_index + args.epoch_max_len, len(dataset.splits['train']))
25
+ loader = dataset.loader('train', num_workers=args.num_workers, indices=list(range(data_start_index, data_end_index)))
26
+ data_start_index = data_end_index if data_end_index < len(dataset.splits['train']) else 0
27
+ else:
28
+ loader = dataset.loader('train', num_workers=args.num_workers)
29
+ loss_meter = AverageMeter('loss', ':6.4f')
30
+ total_length = len(loader)
31
+ progress = ProgressMeter(total_length, [loss_meter], prefix='Training: ')
32
+ for batch_num, batch in enumerate(tqdm(loader, total=len(loader))):
33
+ batch = [tensor.to(args.device) for tensor in batch]
34
+ inputs, lengths, future_words, log_probs, labels, classification_targets, syllables_to_go, future_word_num_syllables, rhyme_group_index = batch
35
+ if args.task not in ['formality', 'iambic']:
36
+ if not args.debug and len(inputs) != args.batch_size: # it'll screw up the bias...?
37
+ continue
38
+ scores = model(inputs, lengths, future_words, log_probs, syllables_to_go, future_word_num_syllables, rhyme_group_index, run_classifier=True)
39
+ if args.task == 'formality': # we're learning for all positions at once. scores are batch x seq
40
+ expanded_labels = classification_targets.unsqueeze(1).expand(-1, scores.shape[1]) # batch x seq
41
+ length_mask = pad_mask(lengths).permute(1, 0) # batch x seq
42
+ loss = criterion(scores.flatten()[length_mask.flatten()==1], expanded_labels.flatten().float()[length_mask.flatten()==1])
43
+ elif args.task in ['iambic', 'newline']:
44
+ use_indices = classification_targets.flatten() != -1
45
+ loss = criterion(scores.flatten()[use_indices], classification_targets.flatten().float()[use_indices])
46
+ else: # topic, rhyme
47
+ loss = criterion(scores.flatten(), labels.flatten().float())
48
+ optimizer.zero_grad()
49
+ loss.backward()
50
+ optimizer.step()
51
+ loss_meter.update(loss.detach(), len(labels))
52
+ if batch_num % args.train_print_freq == 0:
53
+ progress.display(batch_num)
54
+ progress.display(total_length)
55
+ return data_start_index
56
+
57
+
58
+ def validate(model, dataset, criterion, epoch, args):
59
+ model.eval()
60
+ random.seed(0)
61
+ loader = dataset.loader('val', num_workers=args.num_workers)
62
+ loss_meter = AverageMeter('loss', ':6.4f')
63
+ total_length = len(loader)
64
+ progress = ProgressMeter(total_length, [loss_meter], prefix='Validation: ')
65
+ with torch.no_grad():
66
+ for batch_num, batch in enumerate(tqdm(loader, total=len(loader))):
67
+ batch = [tensor.to(args.device) for tensor in batch]
68
+ inputs, lengths, future_words, log_probs, labels, classification_targets, syllables_to_go, future_word_num_syllables, rhyme_group_index = batch
69
+ if args.task not in ['formality', 'iambic']: # topic predictor
70
+ if not args.debug and len(inputs) != args.batch_size:
71
+ continue
72
+ scores = model(inputs, lengths, future_words, log_probs, syllables_to_go, future_word_num_syllables, rhyme_group_index, run_classifier=True)
73
+ if args.task == 'formality': # we're learning for all positions at once. scores are batch x seq
74
+ expanded_labels = classification_targets.unsqueeze(1).expand(-1, scores.shape[1]) # batch x seq
75
+ length_mask = pad_mask(lengths).permute(1, 0) # batch x seq
76
+ loss = criterion(scores.flatten()[length_mask.flatten()==1], expanded_labels.flatten().float()[length_mask.flatten()==1])
77
+ elif args.task in ['iambic', 'newline']:
78
+ use_indices = classification_targets.flatten() != -1
79
+ loss = criterion(scores.flatten()[use_indices], classification_targets.flatten().float()[use_indices])
80
+ else: # topic, rhyme
81
+ loss = criterion(scores.flatten(), labels.flatten().float())
82
+ loss_meter.update(loss.detach(), len(labels))
83
+ if batch_num % args.train_print_freq == 0:
84
+ progress.display(batch_num)
85
+ progress.display(total_length)
86
+ return loss_meter.avg
87
+
88
+
89
+ def main(args):
90
+ dataset = Dataset(args)
91
+ os.makedirs(args.save_dir, exist_ok=True)
92
+ with open(os.path.join(args.save_dir, 'dataset_info'), 'wb') as wf:
93
+ pickle.dump(dataset.dataset_info, wf)
94
+ if args.task == 'rhyme':
95
+ with open(os.path.join(args.save_dir, 'rhyme_info'), 'wb') as wf:
96
+ pickle.dump(dataset.rhyme_info, wf)
97
+ if args.ckpt:
98
+ checkpoint = torch.load(args.ckpt, map_location=args.device)
99
+ start_epoch = checkpoint['epoch'] + 1
100
+ best_val_metric = checkpoint['best_metric']
101
+ model_args = checkpoint['args']
102
+ model = Model(model_args, dataset.gpt_pad_id, len(dataset.index2word), rhyme_group_size=len(dataset.index2rhyme_group) if args.task == 'rhyme' else None) # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
103
+ model.load_state_dict(checkpoint['state_dict'])
104
+ model = model.to(args.device)
105
+ optimizer = torch.optim.Adam(model.parameters(), lr=model_args.lr)
106
+ optimizer.load_state_dict(checkpoint['optimizer'])
107
+ data_start_index = checkpoint['data_start_index']
108
+ print("=> loaded checkpoint '{}' (epoch {})"
109
+ .format(args.ckpt, checkpoint['epoch']))
110
+ # NOTE: just import pdb after loading the model here if you want to play with it, it's easy
111
+ # model.eval()
112
+ # import pdb; pdb.set_trace()
113
+ else:
114
+ model = Model(args, dataset.gpt_pad_id, len(dataset.index2word), rhyme_group_size=len(dataset.index2rhyme_group) if args.task == 'rhyme' else None, glove_embeddings=dataset.glove_embeddings)
115
+ model = model.to(args.device)
116
+ optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
117
+ best_val_metric = 1e8 # lower is better for BCE
118
+ data_start_index = 0
119
+ print('num params', num_params(model))
120
+ criterion = nn.BCEWithLogitsLoss().to(args.device)
121
+
122
+ if args.evaluate:
123
+ epoch = 0
124
+ validate(model, dataset, criterion, epoch, args)
125
+ return
126
+ for epoch in range(args.epochs):
127
+ print("TRAINING: Epoch {} at {}".format(epoch, time.ctime()))
128
+ data_start_index = train(model, dataset, optimizer, criterion, epoch, args, data_start_index)
129
+ if epoch % args.validation_freq == 0:
130
+ print("VALIDATION: Epoch {} at {}".format(epoch, time.ctime()))
131
+ metric = validate(model, dataset, criterion, epoch, args)
132
+
133
+ if not args.debug:
134
+ if metric < best_val_metric:
135
+ print('new best val metric', metric)
136
+ best_val_metric = metric
137
+ save_checkpoint({
138
+ 'epoch': epoch,
139
+ 'state_dict': model.state_dict(),
140
+ 'best_metric': best_val_metric,
141
+ 'optimizer': optimizer.state_dict(),
142
+ 'data_start_index': data_start_index,
143
+ 'args': args
144
+ }, os.path.join(args.save_dir, 'model_best.pth.tar'))
145
+ save_checkpoint({
146
+ 'epoch': epoch,
147
+ 'state_dict': model.state_dict(),
148
+ 'best_metric': metric,
149
+ 'optimizer': optimizer.state_dict(),
150
+ 'data_start_index': data_start_index,
151
+ 'args': args
152
+ }, os.path.join(args.save_dir, 'model_epoch' + str(epoch) + '.pth.tar'))
153
+
154
+
155
+ if __name__=='__main__':
156
+ parser = ArgumentParser()
157
+
158
+ # DATA
159
+ parser.add_argument('--task', type=str, required=True, choices=['iambic', 'rhyme', 'newline', 'topic', 'formality', 'clickbait'])
160
+ parser.add_argument('--data_dir', type=str, required=True)
161
+ parser.add_argument('--glove_file', type=str, help='glove embedding init, for topic task')
162
+
163
+ # SAVE/LOAD
164
+ parser.add_argument('--save_dir', type=str, required=True, help='where to save ckpts')
165
+ parser.add_argument('--ckpt', type=str, default=None, help='load ckpt from file if given')
166
+ parser.add_argument('--dataset_info', type=str, help='saved dataset info')
167
+ parser.add_argument('--rhyme_info', type=str, help='saved dataset rhyme info, for a ckpt with task==rhyme')
168
+
169
+ # TRAINING
170
+ parser.add_argument('--batch_size', type=int, default=128)
171
+ parser.add_argument('--epochs', type=int, default=100)
172
+ parser.add_argument('--epoch_max_len', type=int, default=None, help='max batches per epoch if set, for more frequent validation')
173
+ parser.add_argument('--validation_freq', type=int, default=1, help='validate every X epochs')
174
+ parser.add_argument('--lr', type=float, default=1e-3, help='Adam learning rate')
175
+ parser.add_argument('--seed', type=int, default=1, help='random seed')
176
+ parser.add_argument('--device', type=str, default='cuda', choices=['cpu', 'cuda'])
177
+ parser.add_argument('--num_workers', type=int, default=20, help='num workers for data loader')
178
+ parser.add_argument('--evaluate', action='store_true', default=False)
179
+ parser.add_argument('--debug', action='store_true', default=False)
180
+
181
+ # PRINTING
182
+ parser.add_argument('--train_print_freq', type=int, default=100, help='how often to print metrics (every X batches)')
183
+
184
+ args = parser.parse_args()
185
+
186
+ random.seed(args.seed)
187
+ np.random.seed(args.seed)
188
+ torch.manual_seed(args.seed)
189
+ if args.evaluate:
190
+ assert args.ckpt is not None
191
+
192
+ main(args)
fudge/model.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from torch.nn.utils.rnn import pad_sequence, pad_packed_sequence, pack_padded_sequence
7
+ from transformers import AutoTokenizer, AutoModelWithLMHead, pipeline, set_seed, GPT2Tokenizer, GPT2Model, GPT2LMHeadModel, GPT2Config, GPT2ForSequenceClassification, GPT2LMHeadModel, MarianTokenizer
8
+
9
+ from fudge.constants import *
10
+ from fudge.util import pad_mask
11
+ from fudge.clickbait_classifier import BertClickbaitClassifier, ClickbaitConfig
12
+
13
+ class Model(nn.Module):
14
+ def __init__(self, args, gpt_pad_id, vocab_size, rhyme_group_size=None, glove_embeddings=None, verbose=True):
15
+ super(Model, self).__init__()
16
+
17
+ # self.topic = args.task == 'topic'
18
+ self.formality = args.task == 'formality'
19
+ self.iambic = args.task == 'iambic'
20
+ self.rhyme = args.task == 'rhyme'
21
+ self.newline = args.task == 'newline'
22
+ self.clickbait = args.task == 'clickbait'
23
+ # if self.topic:
24
+ # self.gpt_embed = nn.Embedding(gpt_pad_id + 1, HIDDEN_DIM, padding_idx=gpt_pad_id) # these are subwords, not words
25
+ # if glove_embeddings is None:
26
+ # if verbose:
27
+ # print('initializing word embeddings from scratch')
28
+ # self.word_embed = nn.Embedding(vocab_size, GLOVE_DIM, padding_idx=0)
29
+ # else:
30
+ # if verbose:
31
+ # print('initializing word embeddings from glove')
32
+ # self.word_embed = nn.Embedding.from_pretrained(glove_embeddings, padding_idx=0)
33
+ # self.rnn = nn.LSTM(HIDDEN_DIM, RNN_DIM, num_layers=3, bidirectional=True)
34
+ # self.attention_linear = nn.Linear(HIDDEN_DIM, HIDDEN_DIM)
35
+ # large_hidden_dim = HIDDEN_DIM
36
+ # self.embed_key_linear = nn.Linear(large_hidden_dim, HIDDEN_DIM)
37
+ # self.attention_value_linear = nn.Linear(HIDDEN_DIM, HIDDEN_DIM)
38
+ # self.out_embed_linear = nn.Linear(HIDDEN_DIM, HIDDEN_DIM)
39
+ # self.out_linear = nn.Linear(HIDDEN_DIM, HIDDEN_DIM)
40
+ # self.out_linear2 = nn.Linear(HIDDEN_DIM + large_hidden_dim, HIDDEN_DIM)
41
+ # self.out_linear3 = nn.Linear(HIDDEN_DIM, 1)
42
+ # self.nonlinear = nn.ReLU()
43
+ # elif self.formality:
44
+ if self.formality:
45
+ self.marian_embed = nn.Embedding(gpt_pad_id + 1, HIDDEN_DIM, padding_idx=0) # 0 in marian is ''
46
+ self.rnn = nn.LSTM(HIDDEN_DIM, HIDDEN_DIM, num_layers=3, bidirectional=False, dropout=0.5) # want it to be causal so we can learn all positions
47
+ self.out_linear = nn.Linear(HIDDEN_DIM, 1)
48
+ elif self.iambic:
49
+ self.gpt_embed = nn.Embedding(gpt_pad_id + 1, HIDDEN_DIM, padding_idx=gpt_pad_id)
50
+ self.rnn = nn.LSTM(HIDDEN_DIM, HIDDEN_DIM, num_layers=3, bidirectional=False, dropout=0) # want it to be causal so we can learn all positions
51
+ self.out_linear = nn.Linear(HIDDEN_DIM, 1)
52
+ elif self.rhyme:
53
+ self.gpt_embed = nn.Embedding(gpt_pad_id + 1, HIDDEN_DIM, padding_idx=gpt_pad_id) # these are subwords, not words
54
+ self.word_embed = nn.Embedding(rhyme_group_size+1, GLOVE_DIM, padding_idx=0) # this embedding for future words will actually embed the rhyme group idx
55
+ self.rnn = nn.LSTM(HIDDEN_DIM, RNN_DIM, num_layers=3, bidirectional=True)
56
+ self.attention_linear = nn.Linear(HIDDEN_DIM, HIDDEN_DIM)
57
+ large_hidden_dim = HIDDEN_DIM + COUNT_SYLLABLE_DIM
58
+ self.embed_key_linear = nn.Linear(large_hidden_dim, HIDDEN_DIM)
59
+ self.attention_value_linear = nn.Linear(HIDDEN_DIM, HIDDEN_DIM)
60
+ self.out_embed_linear = nn.Linear(HIDDEN_DIM, HIDDEN_DIM)
61
+ self.out_linear = nn.Linear(HIDDEN_DIM, HIDDEN_DIM)
62
+ self.out_linear2 = nn.Linear(HIDDEN_DIM + large_hidden_dim, HIDDEN_DIM)
63
+ self.out_linear3 = nn.Linear(HIDDEN_DIM, 1)
64
+ self.count_syllable_embed = nn.Embedding(MAX_COUNT_SYLLABLE_DIST+1, COUNT_SYLLABLE_DIM)
65
+ self.nonlinear = nn.ReLU()
66
+ elif self.newline:
67
+ self.gpt_embed = nn.Embedding(gpt_pad_id + 1, HIDDEN_DIM, padding_idx=gpt_pad_id) # these are subwords, not words
68
+ self.rnn = nn.LSTM(HIDDEN_DIM, HIDDEN_DIM, num_layers=3, bidirectional=False)
69
+ self.count_syllable_embed = nn.Embedding(MAX_COUNT_SYLLABLE_DIST+1, COUNT_SYLLABLE_DIM)
70
+ self.out_linear = nn.Linear(HIDDEN_DIM + COUNT_SYLLABLE_DIM, HIDDEN_DIM)
71
+ self.out_linear2 = nn.Linear(HIDDEN_DIM, HIDDEN_DIM)
72
+ self.out_linear3 = nn.Linear(HIDDEN_DIM, 1)
73
+ self.nonlinear = nn.ReLU()
74
+ elif self.clickbait:
75
+ # mpnet_config = ClickbaitConfig(
76
+ # model_type="mpnet",
77
+ # pretrained_model="sentence-transformers/all-mpnet-base-v2",
78
+ # num_labels=1,
79
+ # dropout=0.2,
80
+ # inner_dim1=256,
81
+ # inner_dim2=32,
82
+ # max_length=25,
83
+ # load_pretrained=True,
84
+ # freeze_bert=False,
85
+ # )
86
+ #TODO add a checkpoint to Classifier
87
+ # print('add a checkpoint to Classifier')
88
+ checkpoint = args.checkpoint #'ckpt/clickbait_classifier/checkpoint-1464'
89
+ # self.classifier = BertClickbaitClassifier(config=mpnet_config).to(torch.device(args.device))
90
+ self.classifier = BertClickbaitClassifier.from_pretrained(checkpoint).to(torch.device(args.device))
91
+ else:
92
+ raise NotImplementedError # TODO honestly this can/should be refactored into different models
93
+
94
+
95
+ def forward(self, inputs, lengths=None, future_words=None, log_probs=None, syllables_to_go=None, future_word_num_syllables=None, rhyme_group_index=None, run_classifier=False, attention_mask=None):
96
+ """
97
+ inputs: token ids, batch x seq, right-padded with 0s
98
+ lengths: lengths of inputs; batch
99
+ future_words: batch x N words to check if not predict next token, else batch
100
+ log_probs: N
101
+ syllables_to_go: batch
102
+ """
103
+ # if self.topic:
104
+ # inputs = self.gpt_embed(inputs) # batch x seq x 300
105
+ # inputs = pack_padded_sequence(inputs.permute(1, 0, 2), lengths.cpu(), enforce_sorted=False)
106
+ # rnn_output, _ = self.rnn(inputs)
107
+ # rnn_output, _ = pad_packed_sequence(rnn_output)
108
+ # rnn_output = rnn_output.permute(1, 0, 2) # batch x seq x 300
109
+ # hidden = rnn_output
110
+ # attention_mask = pad_mask(lengths).permute(1, 0) # batch x seq
111
+ # embed = self.word_embed(future_words) # batch x N x 300
112
+ # embed_query = self.embed_key_linear(embed)
113
+ # attention_tensor = self.attention_linear(hidden).unsqueeze(2) * embed_query.unsqueeze(1) # batch x seq x N x 300
114
+ # attention_weights = F.softmax(attention_tensor.sum(dim=3), dim=1) # batch x seq x N
115
+ # attention_weights = attention_weights * attention_mask.unsqueeze(2)
116
+ # hidden = self.attention_value_linear(hidden)
117
+ # weighted_hidden = (hidden.unsqueeze(2) * attention_weights.unsqueeze(3)).sum(dim=1) # batch x seq x N x 768 -> batch x N x 768
118
+ # unnormalized_scores = (self.out_linear(weighted_hidden) * self.out_embed_linear(embed)) # batch x N x 300
119
+ # unnormalized_scores = torch.cat([unnormalized_scores, embed], dim=2)
120
+ # unnormalized_scores = self.nonlinear(self.out_linear2(self.nonlinear(unnormalized_scores)))
121
+ # unnormalized_scores = self.out_linear3(unnormalized_scores)
122
+ # scores = unnormalized_scores.squeeze(2) - log_probs.unsqueeze(0)
123
+ # return scores # batch x N of normalized scores or batch x
124
+ # elif self.formality:
125
+ if self.formality:
126
+ inputs = self.marian_embed(inputs)
127
+ inputs = pack_padded_sequence(inputs.permute(1, 0, 2), lengths.cpu(), enforce_sorted=False)
128
+ rnn_output, _ = self.rnn(inputs)
129
+ rnn_output, _ = pad_packed_sequence(rnn_output)
130
+ rnn_output = rnn_output.permute(1, 0, 2) # batch x seq x 300
131
+ return self.out_linear(rnn_output).squeeze(2)
132
+ elif self.iambic:
133
+ inputs = self.gpt_embed(inputs)
134
+ inputs = pack_padded_sequence(inputs.permute(1, 0, 2), lengths.cpu(), enforce_sorted=False)
135
+ rnn_output, _ = self.rnn(inputs)
136
+ rnn_output, _ = pad_packed_sequence(rnn_output)
137
+ rnn_output = rnn_output.permute(1, 0, 2) # batch x seq x 300
138
+ return self.out_linear(rnn_output).squeeze(2)
139
+ elif self.rhyme:
140
+ inputs = self.gpt_embed(inputs) # batch x seq x 300
141
+ inputs = pack_padded_sequence(inputs.permute(1, 0, 2), lengths.cpu(), enforce_sorted=False)
142
+ rnn_output, _ = self.rnn(inputs)
143
+ rnn_output, _ = pad_packed_sequence(rnn_output)
144
+ rnn_output = rnn_output.permute(1, 0, 2) # batch x seq x 300
145
+ hidden = rnn_output
146
+ attention_mask = pad_mask(lengths).permute(1, 0) # batch x seq
147
+ embed = self.word_embed(future_words) # batch x N x 300
148
+ embedded_syllables_to_go = self.count_syllable_embed(syllables_to_go).unsqueeze(1).expand(-1, embed.shape[1], -1) # batch x N x 100
149
+ auxiliary_embed = embedded_syllables_to_go
150
+ embed_query = self.embed_key_linear(torch.cat([embed, auxiliary_embed], dim=2))
151
+ attention_tensor = self.attention_linear(hidden).unsqueeze(2) * embed_query.unsqueeze(1) # batch x seq x N x 300
152
+ attention_weights = F.softmax(attention_tensor.sum(dim=3), dim=1) # batch x seq x N
153
+ attention_weights = attention_weights * attention_mask.unsqueeze(2)
154
+ hidden = self.attention_value_linear(hidden)
155
+ weighted_hidden = (hidden.unsqueeze(2) * attention_weights.unsqueeze(3)).sum(dim=1) # batch x seq x N x 768 -> batch x N x 768
156
+ unnormalized_scores = (self.out_linear(weighted_hidden) * self.out_embed_linear(embed)) # batch x N x 300
157
+ unnormalized_scores = torch.cat([unnormalized_scores, embed, auxiliary_embed], dim=2)
158
+ unnormalized_scores = self.nonlinear(self.out_linear2(self.nonlinear(unnormalized_scores)))
159
+ unnormalized_scores = self.out_linear3(unnormalized_scores)
160
+ scores = unnormalized_scores.squeeze(2) - log_probs.unsqueeze(0)
161
+ return scores # batch x N of normalized scores or batch x
162
+ elif self.newline:
163
+ inputs = self.gpt_embed(inputs) # batch x seq x 300
164
+ inputs = pack_padded_sequence(inputs.permute(1, 0, 2), lengths.cpu(), enforce_sorted=False)
165
+ rnn_output, _ = self.rnn(inputs)
166
+ rnn_output, _ = pad_packed_sequence(rnn_output)
167
+ rnn_output = rnn_output.permute(1, 0, 2) # batch x seq x 300
168
+ hidden = torch.cat([rnn_output, self.count_syllable_embed(syllables_to_go).unsqueeze(1).expand(-1, rnn_output.shape[1], -1)], dim=2)
169
+ return self.out_linear3(self.nonlinear(self.out_linear2(self.nonlinear(self.out_linear(hidden))))).squeeze(2)
170
+ elif self.clickbait:
171
+
172
+ input_ids = torch.tensor(inputs)
173
+ classifer_output = self.classifier(input_ids = input_ids, attention_mask = attention_mask).logits
174
+
175
+ classifer_output = classifer_output[None,:,:] # batch x seq x 300
176
+ # return self.out_linear(rnn_output).squeeze(2)
177
+ return classifer_output.squeeze(2)
178
+
179
+ else:
180
+ raise NotImplementedError
181
+
182
+
fudge/poetry_data/README.md ADDED
@@ -0,0 +1 @@
 
 
1
+ `couplet_prefixes.txt` contains the 13th line of each of Shakespeare's sonnets. `couplet_ends.txt` contains the 14th. (Each 14-line sonnet ends with a couplet in the last two lines). The prefixes are our test set prefixes for the couplet completion task; the ends are Shakespeare's outputs.
fudge/poetry_data/couplet_ends.txt ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ To eat the world's due, by the grave and thee.
2
+ And see thy blood warm when thou feel'st it cold.
3
+ Die single, and thine image dies with thee.
4
+ Which, used, lives th' executor to be.
5
+ Leese but their show; their substance still lives sweet.
6
+ To be death's conquest and make worms thine heir.
7
+ Unlook'd on diest, unless thou get a son.
8
+ Sings this to thee: 'thou single wilt prove none.'
9
+ That on himself such murderous shame commits.
10
+ That beauty still may live in thine or thee.
11
+ Thou shouldst print more, not let that copy die.
12
+ Save breed, to brave him when he takes thee hence.
13
+ You had a father: let your son say so.
14
+ Thy end is truth's and beauty's doom and date.
15
+ As he takes from you, I engraft you new.
16
+ And you must live, drawn by your own sweet skill.
17
+ You should live twice; in it and in my rhyme.
18
+ So long lives this and this gives life to thee.
19
+ My love shall in my verse ever live young.
20
+ Mine be thy love and thy love's use their treasure.
21
+ I will not praise that purpose not to sell.
22
+ Thou gavest me thine, not to give back again.
23
+ To hear with eyes belongs to love's fine wit.
24
+ They draw but what they see, know not the heart.
25
+ Where I may not remove nor be removed.
26
+ Till then not show my head where thou mayst prove me.
27
+ For thee and for myself no quiet find.
28
+ And night doth nightly make grief's strength seem stronger.
29
+ That then I scorn to change my state with kings.
30
+ All losses are restored and sorrows end.
31
+ And thou, all they, hast all the all of me.
32
+ Theirs for their style I'll read, his for his love.'
33
+ Suns of the world may stain when heaven's sun staineth.
34
+ And they are rich and ransom all ill deeds.
35
+ To that sweet thief which sourly robs from me.
36
+ As, thou being mine, mine is thy good report.
37
+ This wish I have; then ten times happy me!
38
+ The pain be mine, but thine shall be the praise.
39
+ By praising him here who doth hence remain!
40
+ Kill me with spites; yet we must not be foes.
41
+ Thine, by thy beauty being false to me.
42
+ Sweet flattery! then she loves but me alone.
43
+ And nights bright days when dreams do show thee me.
44
+ But heavy tears, badges of either's woe.
45
+ I send them back again and straight grow sad.
46
+ And my heart's right thy inward love of heart.
47
+ Awakes my heart to heart's and eye's delight.
48
+ For truth proves thievish for a prize so dear.
49
+ Since why to love I can allege no cause.
50
+ My grief lies onward and my joy behind.
51
+ Towards thee I'll run, and give him leave to go.
52
+ Being had, to triumph, being lack'd, to hope.
53
+ But you like none, none you, for constant heart.
54
+ When that shall fade, my verse distills your truth.
55
+ You live in this, and dwell in lover's eyes.
56
+ Makes summer's welcome thrice more wish'd, more rare.
57
+ Though you do any thing, he thinks no ill.
58
+ Not blame your pleasure, be it ill or well.
59
+ To subjects worse have given admiring praise.
60
+ Praising thy worth, despite his cruel hand.
61
+ From me far off, with others all too near.
62
+ Painting my age with beauty of thy days.
63
+ And they shall live, and he in them still green.
64
+ But weep to have that which it fears to lose.
65
+ That in black ink my love may still shine bright.
66
+ Save that, to die, I leave my love alone.
67
+ In days long since, before these last so bad.
68
+ To show false Art what beauty was of yore.
69
+ The solve is this, that thou dost common grow.
70
+ Then thou alone kingdoms of hearts shouldst owe.
71
+ And mock you with me after I am gone.
72
+ And so should you, to love things nothing worth.
73
+ To love that well which thou must leave ere long.
74
+ And that is this, and this with thee remains.
75
+ Or gluttoning on all, or all away.
76
+ So is my love still telling what is told.
77
+ Shall profit thee and much enrich thy book.
78
+ As high as learning my rude ignorance.
79
+ Since what he owes thee thou thyself dost pay.
80
+ The worst was this; my love was my decay.
81
+ Where breath most breathes, even in the mouths of men.
82
+ Where cheeks need blood; in thee it is abused.
83
+ Than both your poets can in praise devise.
84
+ Being fond on praise, which makes your praises worse.
85
+ Me for my dumb thoughts, speaking in effect.
86
+ Then lack'd I matter; that enfeebled mine.
87
+ In sleep a king, but waking no such matter.
88
+ That for thy right myself will bear all wrong.
89
+ For I must ne'er love him whom thou dost hate.
90
+ Compared with loss of thee will not seem so.
91
+ All this away and me most wretched make.
92
+ Thou mayst be false, and yet I know it not.
93
+ if thy sweet virtue answer not thy show!
94
+ Lilies that fester smell far worse than weeds.
95
+ The hardest knife ill-used doth lose his edge.
96
+ As, thou being mine, mine is thy good report.
97
+ That leaves look pale, dreading the winter's near.
98
+ As with your shadow I with these did play:
99
+ But sweet or colour it had stol'n from thee.
100
+ So thou prevent'st his scythe and crooked knife.
101
+ To make him seem long hence as he shows now.
102
+ Because I would not dull you with my song.
103
+ Your own glass shows you when you look in it.
104
+ Ere you were born was beauty's summer dead.
105
+ Which three till now never kept seat in one.
106
+ Had eyes to wonder, but lack tongues to praise.
107
+ When tyrants' crests and tombs of brass are spent.
108
+ Where time and outward form would show it dead.
109
+ Save thou, my rose; in it thou art my all.
110
+ Even to thy pure and most most loving breast.
111
+ Even that your pity is enough to cure me.
112
+ That all the world besides methinks are dead.
113
+ My most true mind thus makes mine eye untrue.
114
+ That mine eye loves it and doth first begin.
115
+ To give full growth to that which still doth grow?
116
+ I never writ, nor no man ever loved.
117
+ The constancy and virtue of your love.
118
+ Drugs poison him that so fell sick of you.
119
+ And gain by ill thrice more than I have spent.
120
+ Mine ransoms yours, and yours must ransom me.
121
+ All men are bad, and in their badness reign.
122
+ Were to import forgetfulness in me.
123
+ I will be true, despite thy scythe and thee.
124
+ Which die for goodness, who have lived for crime.
125
+ When most impeach'd stands least in thy control.
126
+ And her quietus is to render thee.
127
+ That every tongue says beauty should look so.
128
+ Give them thy fingers, me thy lips to kiss.
129
+ To shun the heaven that leads men to this hell.
130
+ As any she belied with false compare.
131
+ And thence this slander, as I think, proceeds.
132
+ And all they foul that thy complexion lack.
133
+ Perforce am thine, and all that is in me.
134
+ He pays the whole, and yet am I not free.
135
+ Think all but one, and me in that one 'Will.'
136
+ And then thou lovest me, for my name is 'Will.'
137
+ And to this false plague are they now transferr'd.
138
+ And in our faults by lies we flatter'd be.
139
+ Kill me outright with looks and rid my pain.
140
+ Bear thine eyes straight, though thy proud heart go wide.
141
+ That she that makes me sin awards me pain.
142
+ By self-example mayst thou be denied!
143
+ If thou turn back, and my loud crying still.
144
+ Till my bad angel fire my good one out.
145
+ And saved my life, saying 'not you.'
146
+ And Death once dead, there's no more dying then.
147
+ Who art as black as hell, as dark as night.
148
+ Lest eyes well-seeing thy foul faults should find.
149
+ Those that can see thou lovest, and I am blind.
150
+ More worthy I to be beloved of thee.
151
+ Her 'love' for whose dear love I rise and fall.
152
+ To swear against the truth so foul a lie!
153
+ Where Cupid got new fire--my mistress' eyes.
154
+ Love's fire heats water, water cools not love.
fudge/poetry_data/couplet_prefixes.txt ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Pity the world, or else this glutton be,
2
+ This were to be new made when thou art old,
3
+ But if thou live, remember'd not to be,
4
+ Thy unused beauty must be tomb'd with thee,
5
+ But flowers distill'd though they with winter meet,
6
+ Be not self-will'd, for thou art much too fair
7
+ So thou, thyself out-going in thy noon,
8
+ Whose speechless song, being many, seeming one,
9
+ No love toward others in that bosom sits
10
+ Make thee another self, for love of me,
11
+ She carved thee for her seal, and meant thereby
12
+ And nothing 'gainst Time's scythe can make defence
13
+ O, none but unthrifts! Dear my love, you know
14
+ Or else of thee this I prognosticate:
15
+ And all in war with Time for love of you,
16
+ To give away yourself keeps yourself still,
17
+ But were some child of yours alive that time,
18
+ So long as men can breathe or eyes can see,
19
+ Yet, do thy worst, old Time: despite thy wrong,
20
+ But since she prick'd thee out for women's pleasure,
21
+ Let them say more than like of hearsay well;
22
+ Presume not on thy heart when mine is slain;
23
+ O, learn to read what silent love hath writ:
24
+ Yet eyes this cunning want to grace their art;
25
+ Then happy I, that love and am beloved
26
+ Then may I dare to boast how I do love thee;
27
+ Lo! thus, by day my limbs, by night my mind,
28
+ But day doth daily draw my sorrows longer
29
+ For thy sweet love remember'd such wealth brings
30
+ But if the while I think on thee, dear friend,
31
+ Their images I loved I view in thee,
32
+ But since he died and poets better prove,
33
+ Yet him for this my love no whit disdaineth;
34
+ Ah! but those tears are pearl which thy love sheds,
35
+ That I an accessary needs must be
36
+ But do not so; I love thee in such sort
37
+ Look, what is best, that best I wish in thee:
38
+ If my slight Muse do please these curious days,
39
+ And that thou teachest how to make one twain,
40
+ Lascivious grace, in whom all ill well shows,
41
+ Hers by thy beauty tempting her to thee,
42
+ But here's the joy; my friend and I are one;
43
+ All days are nights to see till I see thee,
44
+ Receiving nought by elements so slow
45
+ This told, I joy; but then no longer glad,
46
+ As thus; mine eye's due is thy outward part,
47
+ Or, if they sleep, thy picture in my sight
48
+ And even thence thou wilt be stol'n, I fear,
49
+ To leave poor me thou hast the strength of laws,
50
+ For that same groan doth put this in my mind;
51
+ Since from thee going he went wilful-slow,
52
+ Blessed are you, whose worthiness gives scope,
53
+ In all external grace you have some part,
54
+ And so of you, beauteous and lovely youth,
55
+ So, till the judgment that yourself arise,
56
+ Else call it winter, which being full of care
57
+ So true a fool is love that in your will,
58
+ I am to wait, though waiting so be hell;
59
+ O, sure I am, the wits of former days
60
+ And yet to times in hope my verse shall stand,
61
+ For thee watch I whilst thou dost wake elsewhere,
62
+ 'Tis thee, myself, that for myself I praise,
63
+ His beauty shall in these black lines be seen,
64
+ This thought is as a death, which cannot choose
65
+ O, none, unless this miracle have might,
66
+ Tired with all these, from these would I be gone,
67
+ O, him she stores, to show what wealth she had
68
+ And him as for a map doth Nature store,
69
+ But why thy odour matcheth not thy show,
70
+ If some suspect of ill mask'd not thy show,
71
+ Lest the wise world should look into your moan
72
+ For I am shamed by that which I bring forth,
73
+ This thou perceivest, which makes thy love more strong,
74
+ The worth of that is that which it contains,
75
+ Thus do I pine and surfeit day by day,
76
+ For as the sun is daily new and old,
77
+ These offices, so oft as thou wilt look,
78
+ But thou art all my art and dost advance
79
+ Then thank him not for that which he doth say,
80
+ Then if he thrive and I be cast away,
81
+ You still shall live--such virtue hath my pen--
82
+ And their gross painting might be better used
83
+ There lives more life in one of your fair eyes
84
+ You to your beauteous blessings add a curse,
85
+ Then others for the breath of words respect,
86
+ But when your countenance fill'd up his line,
87
+ Thus have I had thee, as a dream doth flatter,
88
+ Such is my love, to thee I so belong,
89
+ For thee against myself I'll vow debate,
90
+ And other strains of woe, which now seem woe,
91
+ Wretched in this alone, that thou mayst take
92
+ But what's so blessed-fair that fears no blot?
93
+ How like Eve's apple doth thy beauty grow,
94
+ For sweetest things turn sourest by their deeds;
95
+ Take heed, dear heart, of this large privilege;
96
+ But do not so; I love thee in such sort
97
+ Or, if they sing, 'tis with so dull a cheer
98
+ Yet seem'd it winter still, and, you away,
99
+ More flowers I noted, yet I none could see
100
+ Give my love fame faster than Time wastes life;
101
+ Then do thy office, Muse; I teach thee how
102
+ Therefore like her I sometime hold my tongue,
103
+ And more, much more, than in my verse can sit
104
+ For fear of which, hear this, thou age unbred;
105
+ 'Fair, kind, and true,' have often lived alone,
106
+ For we, which now behold these present days,
107
+ And thou in this shalt find thy monument,
108
+ Finding the first conceit of love there bred
109
+ For nothing this wide universe I call,
110
+ Then give me welcome, next my heaven the best,
111
+ Pity me then, dear friend, and I assure ye
112
+ You are so strongly in my purpose bred
113
+ Incapable of more, replete with you,
114
+ If it be poison'd, 'tis the lesser sin
115
+ Love is a babe; then might I not say so,
116
+ If this be error and upon me proved,
117
+ Since my appeal says I did strive to prove
118
+ But thence I learn, and find the lesson true,
119
+ So I return rebuked to my content
120
+ But that your trespass now becomes a fee;
121
+ Unless this general evil they maintain,
122
+ To keep an adjunct to remember thee
123
+ This I do vow and this shall ever be;
124
+ To this I witness call the fools of time,
125
+ Hence, thou suborn'd informer! a true soul
126
+ Her audit, though delay'd, answer'd must be,
127
+ Yet so they mourn, becoming of their woe,
128
+ Since saucy jacks so happy are in this,
129
+ All this the world well knows; yet none knows well
130
+ And yet, by heaven, I think my love as rare
131
+ In nothing art thou black save in thy deeds,
132
+ Then will I swear beauty herself is black
133
+ And yet thou wilt; for I, being pent in thee,
134
+ Him have I lost; thou hast both him and me:
135
+ Let no unkind, no fair beseechers kill;
136
+ Make but my name thy love, and love that still,
137
+ In things right true my heart and eyes have erred,
138
+ Therefore I lie with her and she with me,
139
+ Yet do not so; but since I am near slain,
140
+ That I may not be so, nor thou belied,
141
+ Only my plague thus far I count my gain,
142
+ If thou dost seek to have what thou dost hide,
143
+ So will I pray that thou mayst have thy 'Will,'
144
+ Yet this shall I ne'er know, but live in doubt,
145
+ 'I hate' from hate away she threw,
146
+ So shalt thou feed on Death, that feeds on men,
147
+ For I have sworn thee fair and thought thee bright,
148
+ O cunning Love! with tears thou keep'st me blind,
149
+ But, love, hate on, for now I know thy mind;
150
+ If thy unworthiness raised love in me,
151
+ No want of conscience hold it that I call
152
+ For I have sworn thee fair; more perjured I,
153
+ But found no cure: the bath for my help lies
154
+ Came there for cure, and this by that I prove,
fudge/poetry_util.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import string
2
+
3
+ import pronouncing
4
+ from Phyme import Phyme
5
+ phyme = Phyme()
6
+
7
+ from fudge.constants import *
8
+
9
+ def is_iambic(phrase):
10
+ """
11
+ check that we satisfy iambic meter.
12
+ return 1 if so, otherwise 0.
13
+ definitely an imperfect check...
14
+ if we end up needing to check a word that's not in the CMU dictionary, just return 0.
15
+ """
16
+ meter = ''
17
+ for word in phrase.split():
18
+ word = word.strip().strip(string.punctuation).lower()
19
+ try:
20
+ phones_list = pronouncing.phones_for_word(word)
21
+ stresses = pronouncing.stresses(phones_list[0])
22
+ if len(stresses) == 1:
23
+ if stresses == '1':
24
+ stresses = '2' # allow ambiguity for 1-syllable words with stress 1
25
+ meter += stresses # just default to the first pronunciation if > 1 given
26
+ except:
27
+ return 0 # word not found
28
+ meter = [int(x) for x in meter]
29
+ even_stresses_full = [meter[i] for i in range(0, len(meter), 2)]
30
+ odd_stresses_full = [meter[i] for i in range(1, len(meter), 2)]
31
+ even_stresses = set(even_stresses_full)
32
+ odd_stresses = set(odd_stresses_full)
33
+ if 0 in odd_stresses:
34
+ return 0
35
+ if 1 in even_stresses:
36
+ return 0
37
+ return 1
38
+
39
+
40
+ def count_syllables(words):
41
+ syllables = 0
42
+ for word in words.split():
43
+ word = word.strip().strip(string.punctuation)
44
+ try:
45
+ phones_list = pronouncing.phones_for_word(word)
46
+ stresses = pronouncing.stresses(phones_list[0])
47
+ syllables += min(MAX_SYLLABLES_PER_WORD, len(stresses))
48
+ except:
49
+ # if we don't know, just do a quick approximation here; it shouldn't come up too often
50
+ syllables += min(MAX_SYLLABLES_PER_WORD, round(len(word) / 3))
51
+ return syllables
52
+
53
+
54
+ def get_rhymes(word):
55
+ # throws exception if word not in the rhyme dict (rare)
56
+ rhymes = []
57
+ rhyme_dict = phyme.get_perfect_rhymes(word)
58
+ for length_dict in rhyme_dict.values():
59
+ for word in length_dict:
60
+ if '(' in word: # sometimes you have stuff like preferred(1) where they indicate a particular pronunciation
61
+ rhymes.append(word.split('(')[0])
62
+ else:
63
+ rhymes.append(word)
64
+ return sorted(list(set(rhymes)))
65
+
66
+
67
+ def get_rhyme_group(word):
68
+ sorted_rhyme_list = get_rhymes(word)
69
+ return ' '.join(sorted_rhyme_list)
70
+
71
+
72
+ def perfect_rhyme_end(s1, s2):
73
+ ending_word1 = s1.split()[-1].strip(string.punctuation)
74
+ ending_word2 = s2.split()[-1].strip(string.punctuation)
75
+ try:
76
+ return get_rhyme_group(ending_word1) == get_rhyme_group(ending_word2)
77
+ except:
78
+ return False # unknown words
79
+
80
+ if __name__=='__main__':
81
+ result = is_iambic('Shall I compare thee to a summer day')
82
+ result2 = count_syllables('Shall I compare thee to a summer day')
83
+ import pdb; pdb.set_trace()
fudge/predict_clickbait.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import time
4
+ import pickle
5
+ import math
6
+ from argparse import ArgumentParser
7
+
8
+ from typing import Iterable, List, Optional, Tuple
9
+
10
+ from tqdm import tqdm
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ from transformers import AutoTokenizer, AutoModelWithLMHead
16
+ from torch import Tensor
17
+
18
+ from fudge.data import Dataset
19
+ from fudge.model import Model
20
+ from fudge.util import num_params
21
+ from fudge.constants import *
22
+
23
+
24
+
25
+ tokenizer = AutoTokenizer.from_pretrained('google/pegasus-xsum')
26
+ classifier_tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-mpnet-base-v2')
27
+
28
+
29
+ def main(args):
30
+ with open(args.dataset_info, 'rb') as rf:
31
+ dataset_info = pickle.load(rf)
32
+
33
+ article_content = """Australian actor Guy Pearce will return for the iconic soap Neighbours finale on August 1 to reprise his role as Mike Young.
34
+ Guy, 54, played the troubled Mike from 1986 to 1989, and is now set to make a comeback on the show after 33 years, Metro.co.uk reports.
35
+ The star's character arcs explored the implications of domestic abuse, student-teacher relationships and dealing with loss of loved ones.
36
+ Speaking to Metro.co.uk, Guy said: 'It is very exciting and surreal at the same time being back on set again, however it feels like coming home.
37
+ 'It's where it all started for me professionally. I've been asked to come back on occasions over the years and wondered if it was the right thing
38
+ to do, but once I knew the show was finishing, I knew I had to do it.'He added that there is 'nothing like being here all together again'
39
+ , even though he's had a chance to catch-up with other cast members."""
40
+
41
+ tokenizer.add_special_tokens({'pad_token': PAD_TOKEN})
42
+ pad_id = tokenizer.encode(PAD_TOKEN)[0]
43
+
44
+ #For loading Clickbait summarizer
45
+ model = AutoModelWithLMHead.from_pretrained(args.model_string, return_dict=True).to(args.device)
46
+
47
+ model.eval()
48
+
49
+ checkpoint = torch.load(args.ckpt, map_location=args.device)
50
+ model_args = checkpoint['args']
51
+ conditioning_model = Model(model_args, pad_id, len(dataset_info.index2word)) # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
52
+ conditioning_model.load_state_dict(checkpoint['state_dict'])
53
+ conditioning_model = conditioning_model.to(args.device)
54
+ conditioning_model.eval()
55
+ print("=> loaded checkpoint '{}' (epoch {})"
56
+ .format(args.ckpt, checkpoint['epoch']))
57
+ print('num params', num_params(conditioning_model))
58
+
59
+ while True:
60
+ results = generate_clickbait(model,
61
+ tokenizer,
62
+ conditioning_model,
63
+ [args.input_text],
64
+ dataset_info,
65
+ precondition_topk=args.precondition_topk,
66
+ do_sample=args.do_sample,
67
+ length_cutoff=args.length_cutoff,
68
+ condition_lambda=args.condition_lambda,
69
+ article_content=article_content,
70
+ device=args.device)
71
+ # print(results)
72
+ import pdb; pdb.set_trace()
73
+
74
+
75
+ def generate_clickbait(model,
76
+ tokenizer,
77
+ conditioning_model,
78
+ input_text,
79
+ dataset_info,
80
+ precondition_topk,
81
+ length_cutoff,
82
+ condition_lambda=1.0,
83
+ article_content=None,
84
+ device='cuda'):
85
+ with torch.no_grad():
86
+ batch_size = len(input_text)
87
+ # encoded_input_article = [tokenizer.encode(article_content, return_tensors='pt',add_special_tokens=False).to(device)] # batch x seq
88
+ max_input_length = 512
89
+ encoded_input_article = tokenizer(article_content, return_tensors='pt',add_special_tokens=False, max_length = max_input_length).to(device) # batch x seq
90
+ # encoded_input_article = torch.cat(encoded_input_article, dim=0)
91
+ # attention_mask = encoded_input_article.new_ones(encoded_input_article.shape).to(device)
92
+
93
+ # CHANGE=ko
94
+ encoded_input = tokenizer('<pad>', return_tensors='pt',add_special_tokens=False).to(device) # batch x seq
95
+ # encoded_input = tokenizer('<pad>'+ input_text[0], return_tensors='pt',add_special_tokens=False).to(device) # batch x seq
96
+ # encoded_input = torch.cat(encoded_input, dim=0)
97
+ encoded_input = encoded_input['input_ids']
98
+
99
+
100
+ lengths = torch.LongTensor([encoded_input.shape[1]]).to(device)
101
+ # lengths = 1
102
+
103
+ past = None
104
+ use_cache = True
105
+
106
+ # CHANGE
107
+ # model_kwargs = {'encoder_outputs': model.get_encoder()(encoded_input_article, attention_mask=attention_mask)}
108
+ model_kwargs = {'encoder_outputs': model.get_encoder()(input_ids=encoded_input_article['input_ids'],
109
+ attention_mask=encoded_input_article['attention_mask'],
110
+ return_dict=True,
111
+ output_attentions=False,
112
+ output_hidden_states=False),
113
+ }
114
+
115
+ while lengths.max() < length_cutoff:
116
+ model_inputs = model.prepare_inputs_for_generation(
117
+ input_ids = encoded_input_article['input_ids'],
118
+ decoder_input_ids=encoded_input,
119
+ # past=past,
120
+ attention_mask=encoded_input_article['attention_mask'],
121
+ use_cache=use_cache,
122
+ **model_kwargs
123
+ )
124
+
125
+ outputs = model(**model_inputs, return_dict=True)
126
+ logits = outputs.logits[:, -1, :]
127
+
128
+ if "past_key_values" in outputs:
129
+ model_kwargs["past"] = outputs.past_key_values
130
+
131
+ # logits = model(encoded_input)[0][:, -1, :] # batch x vocab
132
+ top_logits, top_indices = logits.topk(precondition_topk, dim=1) # batch x topk
133
+ new_input_candidates = torch.cat([encoded_input.unsqueeze(1).expand(-1, precondition_topk, -1), top_indices.unsqueeze(2)], dim=2) # batch x topk x seq+1
134
+ expanded_lengths = (lengths + 1).unsqueeze(1).expand(batch_size, precondition_topk) # batch x topk
135
+
136
+ if condition_lambda == 0:
137
+ condition_logits = torch.zeros_like(top_logits).float()
138
+ condition_logits = condition_logits.view(batch_size, precondition_topk, -1) # batch x topk x N
139
+ else:
140
+ decoded_outputs = tokenizer.batch_decode(new_input_candidates.view(-1, new_input_candidates.size(-1)), clean_up_tokenization_spaces=False)
141
+ resulting_tokenization = classifier_tokenizer(decoded_outputs, add_special_tokens=False, padding='longest')
142
+ encoded_with_classifier = resulting_tokenization['input_ids']
143
+ attention_mask = torch.tensor(resulting_tokenization['attention_mask']).to(model.device)
144
+ tplus1_candidates_classifier = torch.tensor(encoded_with_classifier).view(batch_size, precondition_topk, -1).to(model.device)
145
+
146
+ condition_logits = conditioning_model(tplus1_candidates_classifier.flatten(0, 1), # batch*topk x seq+1
147
+ expanded_lengths.flatten(0, 1), # batch*topk
148
+ None,
149
+ None,
150
+ None,
151
+ attention_mask=attention_mask
152
+ )
153
+ condition_logits = condition_logits.view(batch_size, precondition_topk, -1) # batch x topk x N
154
+ condition_logits = condition_logits - torch.log(1 + torch.exp(condition_logits)) # get correct log probs
155
+
156
+ condition_logits = torch.mean(condition_logits, dim=2)
157
+ full_logits = top_logits + condition_logits * condition_lambda # batch x topk
158
+ post_logits, post_indices = full_logits.topk(precondition_topk, dim=1)
159
+ post_probs = F.softmax(post_logits, dim=1)
160
+ # index_into_top_indices = post_indices[torch.arange(batch_size).to(post_indices.device), torch.multinomial(post_probs, 1).flatten()] # batch
161
+ index_into_top_indices = post_indices[:, torch.multinomial(post_probs, 1).flatten()] # batch
162
+
163
+ # next_indices = top_indices[torch.arange(batch_size).to(top_indices.device), index_into_top_indices] # batch
164
+ next_indices = top_indices[:, index_into_top_indices] # batch
165
+
166
+ # encoded_input = torch.cat([encoded_input, next_indices.unsqueeze(1)], dim=1) # batch x seq+1
167
+ encoded_input = torch.cat([encoded_input, next_indices.squeeze(1)], dim=1)
168
+ lengths = lengths + 1 # batch
169
+
170
+ # print(tokenizer.decode(encoded_input[0], add_special_tokens=False))
171
+ return [tokenizer.decode(s) for s in encoded_input]
172
+
173
+
174
+ if __name__=='__main__':
175
+ parser = ArgumentParser()
176
+
177
+ # DATA
178
+ parser.add_argument('--ckpt', type=str, required=True)
179
+ parser.add_argument('--dataset_info', type=str, required=True, help='saved dataset info')
180
+ parser.add_argument('--model_string', type=str, default='Helsinki-NLP/opus-mt-es-en')
181
+
182
+ parser.add_argument('--in_file', type=str, default=None, required=True, help='text to run pred on')
183
+
184
+ parser.add_argument('--precondition_topk', type=int, default=200, help='consider top k outputs from text generation at each step before conditioning and re-pruning')
185
+ parser.add_argument('--do_sample', action='store_true', default=False, help='sample instead of greedy')
186
+ parser.add_argument('--condition_lambda', type=float, default=1.0, help='lambda weight on conditioning model')
187
+ parser.add_argument('--length_cutoff', type=int, default=512, help='max length')
188
+
189
+ parser.add_argument('--seed', type=int, default=1, help='random seed')
190
+ parser.add_argument('--device', type=str, default='cuda', choices=['cpu', 'cuda'])
191
+ parser.add_argument('--debug', action='store_true', default=False)
192
+
193
+ args = parser.parse_args()
194
+
195
+ random.seed(args.seed)
196
+ np.random.seed(args.seed)
197
+ torch.manual_seed(args.seed)
198
+
199
+ main(args)
fudge/predict_formality.py ADDED
@@ -0,0 +1,404 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import time
4
+ import pickle
5
+ import math
6
+ from argparse import ArgumentParser
7
+
8
+ from typing import Iterable, List, Optional, Tuple
9
+
10
+ from tqdm import tqdm
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ from transformers import AutoTokenizer, AutoModelWithLMHead, pipeline, set_seed, GPT2Tokenizer, GPT2Model, MarianTokenizer, MarianMTModel
16
+ from torch import Tensor
17
+
18
+ from data import Dataset
19
+ from model import Model
20
+ from util import save_checkpoint, ProgressMeter, AverageMeter, num_params
21
+ from constants import *
22
+
23
+ def main(args):
24
+ with open(args.dataset_info, 'rb') as rf:
25
+ dataset_info = pickle.load(rf)
26
+ tokenizer = MarianTokenizer.from_pretrained(args.model_string)
27
+ tokenizer.add_special_tokens({'pad_token': PAD_TOKEN})
28
+ pad_id = tokenizer.encode(PAD_TOKEN)[0]
29
+ model = MarianMTModel.from_pretrained(args.model_string, return_dict=True).to(args.device)
30
+ model.eval()
31
+
32
+ checkpoint = torch.load(args.ckpt, map_location=args.device)
33
+ model_args = checkpoint['args']
34
+ conditioning_model = Model(model_args, pad_id, len(dataset_info.index2word)) # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
35
+ conditioning_model.load_state_dict(checkpoint['state_dict'])
36
+ conditioning_model = conditioning_model.to(args.device)
37
+ conditioning_model.eval()
38
+ print("=> loaded checkpoint '{}' (epoch {})"
39
+ .format(args.ckpt, checkpoint['epoch']))
40
+ print('num params', num_params(conditioning_model))
41
+
42
+ while True:
43
+ results = predict_formality(model,
44
+ tokenizer,
45
+ conditioning_model,
46
+ [args.input_text],
47
+ dataset_info,
48
+ precondition_topk=args.precondition_topk,
49
+ do_sample=args.do_sample,
50
+ length_cutoff=args.length_cutoff,
51
+ condition_lambda=args.condition_lambda,
52
+ device=args.device)
53
+ print(results)
54
+ import pdb; pdb.set_trace()
55
+
56
+
57
+ def predict_formality(model, tokenizer, conditioning_model, input_text, dataset_info, precondition_topk=200, do_sample=False, length_cutoff=512, condition_lambda=1.0, device='cuda'):
58
+ with torch.no_grad():
59
+ batch_size = len(input_text)
60
+
61
+ # assumes initially all same length.
62
+ # encode every x_i i \in [seq] word to respectable embedding
63
+ encoded_input = [tokenizer.encode(it, return_tensors='pt').to(device) for it in input_text] # batch x seq
64
+ encoded_input = torch.cat(encoded_input, dim=0)
65
+
66
+ input_ids = torch.LongTensor([[58100]]).to(device)
67
+ cur_len = 1
68
+ max_length = length_cutoff
69
+ min_length = 0
70
+ temperature = 1.0
71
+ top_k = 50
72
+ top_p = 1.0
73
+ repetition_penalty = 1.0
74
+ no_repeat_ngram_size = 0
75
+ bad_words_ids = [[58100]]
76
+ pad_token_id = 58100
77
+ eos_token_id = 0
78
+ effective_batch_size = batch_size
79
+ attention_mask = encoded_input.new_ones(encoded_input.shape)
80
+ use_cache = True
81
+ model_specific_kwargs = {'encoder_outputs': model.get_encoder()(encoded_input, attention_mask=attention_mask)}
82
+
83
+ output = _generate_no_beam_search(model,
84
+ conditioning_model,
85
+ condition_lambda,
86
+ precondition_topk,
87
+ input_ids,
88
+ cur_len,
89
+ max_length,
90
+ min_length,
91
+ do_sample,
92
+ temperature,
93
+ top_k,
94
+ top_p,
95
+ repetition_penalty,
96
+ no_repeat_ngram_size,
97
+ bad_words_ids,
98
+ pad_token_id,
99
+ eos_token_id,
100
+ batch_size,
101
+ attention_mask,
102
+ use_cache,
103
+ model_specific_kwargs)
104
+
105
+ return [tokenizer.decode(s[1:]) for s in output] # 1: to delete the pad token
106
+
107
+
108
+ # hack of code from transformers/generation_utils.py
109
+ # to get our conditioning
110
+ def postprocess_next_token_scores(
111
+ model,
112
+ scores,
113
+ input_ids,
114
+ no_repeat_ngram_size,
115
+ bad_words_ids,
116
+ cur_len,
117
+ min_length,
118
+ max_length,
119
+ eos_token_id,
120
+ repetition_penalty,
121
+ batch_size,
122
+ num_beams,
123
+ ):
124
+ # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
125
+ if repetition_penalty != 1.0:
126
+ model.enforce_repetition_penalty_(
127
+ scores,
128
+ batch_size,
129
+ num_beams,
130
+ input_ids,
131
+ repetition_penalty,
132
+ )
133
+
134
+ # set eos token prob to zero if min_length is not reached
135
+ if eos_token_id is not None and cur_len < min_length:
136
+ scores[:, eos_token_id] = -float("inf")
137
+
138
+ if no_repeat_ngram_size > 0:
139
+ # calculate a list of banned tokens to prevent repetitively generating the same ngrams
140
+ num_batch_hypotheses = batch_size * num_beams
141
+ # from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
142
+ banned_batch_tokens = calc_banned_ngram_tokens(
143
+ input_ids, num_batch_hypotheses, no_repeat_ngram_size, cur_len
144
+ )
145
+ for i, banned_tokens in enumerate(banned_batch_tokens):
146
+ scores[i, banned_tokens] = -float("inf")
147
+
148
+ if bad_words_ids is not None:
149
+ # Exclude EOS token (already processed)
150
+ bad_words_ids = list(filter(lambda bad_token_seq: bad_token_seq != [eos_token_id], bad_words_ids))
151
+ # calculate a list of banned tokens according to bad words
152
+ banned_tokens = calc_banned_bad_words_ids(input_ids.tolist(), bad_words_ids)
153
+ # Modify the scores in place by setting the banned tokens logits to `-inf`
154
+ set_scores_to_inf_for_banned_tokens(scores, banned_tokens)
155
+
156
+ return scores
157
+
158
+ def calc_banned_ngram_tokens(prev_input_ids: Tensor, num_hypos: int, no_repeat_ngram_size: int, cur_len: int) -> None:
159
+ """Copied from fairseq for no_repeat_ngram in beam_search"""
160
+ if cur_len + 1 < no_repeat_ngram_size:
161
+ # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
162
+ return [[] for _ in range(num_hypos)]
163
+ generated_ngrams = [{} for _ in range(num_hypos)]
164
+ for idx in range(num_hypos):
165
+ gen_tokens = prev_input_ids[idx].tolist()
166
+ generated_ngram = generated_ngrams[idx]
167
+ for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):
168
+ prev_ngram_tuple = tuple(ngram[:-1])
169
+ generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
170
+
171
+ def _get_generated_ngrams(hypo_idx):
172
+ # Before decoding the next token, prevent decoding of ngrams that have already appeared
173
+ start_idx = cur_len + 1 - no_repeat_ngram_size
174
+ ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].tolist())
175
+ return generated_ngrams[hypo_idx].get(ngram_idx, [])
176
+
177
+ banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]
178
+ return banned_tokens
179
+
180
+
181
+ def calc_banned_bad_words_ids(prev_input_ids: Iterable[int], bad_words_ids: Iterable[int]) -> Iterable[int]:
182
+ banned_tokens = []
183
+
184
+ def _tokens_match(prev_tokens, tokens):
185
+ if len(tokens) == 0:
186
+ # if bad word tokens is just one token always ban it
187
+ return True
188
+ if len(tokens) > len(prev_tokens):
189
+ # if bad word tokens are longer than prev tokens they can't be equal
190
+ return False
191
+
192
+ if prev_tokens[-len(tokens) :] == tokens:
193
+ # if tokens match
194
+ return True
195
+ else:
196
+ return False
197
+
198
+ for prev_input_ids_slice in prev_input_ids:
199
+ banned_tokens_slice = []
200
+
201
+ for banned_token_seq in bad_words_ids:
202
+ assert len(banned_token_seq) > 0, "Banned words token sequences {} cannot have an empty list".format(
203
+ bad_words_ids
204
+ )
205
+
206
+ if _tokens_match(prev_input_ids_slice, banned_token_seq[:-1]) is False:
207
+ # if tokens do not match continue
208
+ continue
209
+
210
+ banned_tokens_slice.append(banned_token_seq[-1])
211
+
212
+ banned_tokens.append(banned_tokens_slice)
213
+
214
+ return banned_tokens
215
+
216
+ def set_scores_to_inf_for_banned_tokens(scores: torch.Tensor, banned_tokens: List[List[int]]) -> None:
217
+ """Modifies the scores in place by setting the banned token positions to `-inf`. Banned token is expected to be
218
+ a list of list of banned tokens to ban in the format [[batch index, vocabulary position],...]
219
+ Args:
220
+ scores: logits distribution of shape (batch size, vocabulary size)
221
+ banned_tokens: list of list of tokens to ban of length (batch_size)
222
+ """
223
+ banned_mask_list = []
224
+ for idx, batch_banned_tokens in enumerate(banned_tokens):
225
+ for token in batch_banned_tokens:
226
+ banned_mask_list.append([idx, token])
227
+ if not banned_mask_list:
228
+ return
229
+ banned_mask = torch.LongTensor(banned_mask_list)
230
+ indices = torch.ones(len(banned_mask))
231
+ # A sparse tensor is generated from a list of coordinates: [[0, 1], [0, 2], [2, 0]]. A conversion to dense tensor generates:
232
+ # [ 0 1 1 ]
233
+ # [ 0 0 0 ]
234
+ # [ 1 0 0 ]
235
+
236
+ banned_mask = torch.sparse.LongTensor(banned_mask.t(), indices, scores.size()).to(scores.device).to_dense().bool()
237
+ scores.masked_fill_(banned_mask, -float("inf"))
238
+
239
+ def _generate_no_beam_search(
240
+ model,
241
+ conditioning_model,
242
+ condition_lambda,
243
+ precondition_topk,
244
+ input_ids,
245
+ cur_len,
246
+ max_length,
247
+ min_length,
248
+ do_sample,
249
+ temperature,
250
+ top_k,
251
+ top_p,
252
+ repetition_penalty,
253
+ no_repeat_ngram_size,
254
+ bad_words_ids,
255
+ pad_token_id,
256
+ eos_token_id,
257
+ batch_size,
258
+ attention_mask,
259
+ use_cache,
260
+ model_kwargs,
261
+ ):
262
+ """Generate sequences for each example without beam search (num_beams == 1).
263
+ All returned sequence are generated independantly.
264
+ """
265
+ # length of generated sentences / unfinished sentences
266
+ unfinished_sents = input_ids.new(batch_size).fill_(1)
267
+ sent_lengths = input_ids.new(batch_size).fill_(max_length)
268
+ past = None
269
+ while cur_len < max_length:
270
+ model_inputs = model.prepare_inputs_for_generation(
271
+ input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_kwargs
272
+ )
273
+
274
+ outputs = model(**model_inputs, return_dict=True)
275
+ next_token_logits = outputs.logits[:, -1, :]
276
+
277
+ # scores = model.postprocess_next_token_scores(
278
+ # scores=next_token_logits,
279
+ # input_ids=input_ids,
280
+ # no_repeat_ngram_size=no_repeat_ngram_size,
281
+ # bad_words_ids=bad_words_ids,
282
+ # cur_len=cur_len,
283
+ # min_length=min_length,
284
+ # max_length=max_length,
285
+ # eos_token_id=eos_token_id,
286
+ # repetition_penalty=repetition_penalty,
287
+ # batch_size=batch_size,
288
+ # num_beams=1,
289
+ # )
290
+
291
+ scores = postprocess_next_token_scores(
292
+ model=model,
293
+ scores=next_token_logits,
294
+ input_ids=input_ids,
295
+ no_repeat_ngram_size=no_repeat_ngram_size,
296
+ bad_words_ids=bad_words_ids,
297
+ cur_len=cur_len,
298
+ min_length=min_length,
299
+ max_length=max_length,
300
+ eos_token_id=eos_token_id,
301
+ repetition_penalty=repetition_penalty,
302
+ batch_size=batch_size,
303
+ num_beams=1,
304
+ )
305
+
306
+ # if model has past, then set the past variable to speed up decoding
307
+ if "past_key_values" in outputs:
308
+ past = outputs.past_key_values
309
+ elif "mems" in outputs:
310
+ past = outputs.mems
311
+
312
+ top_logits, top_indices = scores.topk(precondition_topk, dim=1) # batch x topk
313
+ tplus1_candidates = torch.cat([input_ids.unsqueeze(1).expand(-1, precondition_topk, -1), top_indices.unsqueeze(2)], dim=2)[:, :, 1:] # batch x topk x seq+1, with pad dropped
314
+ expanded_lengths = torch.LongTensor([[cur_len for _ in range(precondition_topk)] for _ in range(batch_size)]).to(scores.device)
315
+ if condition_lambda == 0:
316
+ condition_logits = torch.zeros_like(top_logits).float()
317
+ else:
318
+ condition_logits = conditioning_model(tplus1_candidates.flatten(0, 1), # batch*topk x seq+1
319
+ expanded_lengths.flatten(0, 1), # batch*topk
320
+ None,
321
+ None,
322
+ None)
323
+ condition_logits = condition_logits.view(batch_size, precondition_topk, -1)[:, :, -1] # batch x topk of last formality pred
324
+ condition_logits = condition_logits - torch.log(1 + torch.exp(condition_logits)) # get correct log probs
325
+ # condition_logits = - torch.log(1 + torch.exp(condition_logits)) # for informal
326
+ full_logits = top_logits + condition_lambda * condition_logits
327
+ if do_sample:
328
+ raise NotImplementedError
329
+ else:
330
+ # Greedy decoding
331
+ next_token = top_indices[torch.arange(batch_size).to(top_indices.device), torch.argmax(full_logits, dim=-1)]
332
+
333
+ # if do_sample:
334
+ # # Temperature (higher temperature => more likely to sample low probability tokens)
335
+ # if temperature != 1.0:
336
+ # scores = scores / temperature
337
+ # # Top-p/top-k filtering
338
+ # next_token_logscores = top_k_top_p_filtering(scores, top_k=top_k, top_p=top_p)
339
+ # # Sample
340
+ # probs = F.softmax(next_token_logscores, dim=-1)
341
+ # next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
342
+ # else:
343
+ # # Greedy decoding
344
+ # next_token = torch.argmax(next_token_logits, dim=-1)
345
+
346
+ # update generations and finished sentences
347
+ if eos_token_id is not None:
348
+ # pad finished sentences if eos_token_id exist
349
+ tokens_to_add = next_token * unfinished_sents + (pad_token_id) * (1 - unfinished_sents)
350
+ else:
351
+ tokens_to_add = next_token
352
+
353
+ # add token and increase length by one
354
+ input_ids = torch.cat([input_ids, tokens_to_add.unsqueeze(-1)], dim=-1)
355
+ cur_len = cur_len + 1
356
+
357
+ if eos_token_id is not None:
358
+ eos_in_sents = tokens_to_add == eos_token_id
359
+ # if sentence is unfinished and the token to add is eos, sent_lengths is filled with current length
360
+ is_sents_unfinished_and_token_to_add_is_eos = unfinished_sents.mul(eos_in_sents.long()).bool()
361
+ sent_lengths.masked_fill_(is_sents_unfinished_and_token_to_add_is_eos, cur_len)
362
+ # unfinished_sents is set to zero if eos in sentence
363
+ unfinished_sents.mul_((~eos_in_sents).long())
364
+
365
+ # stop when there is a </s> in each sentence, or if we exceed the maximul length
366
+ if unfinished_sents.max() == 0:
367
+ break
368
+
369
+ # extend attention_mask for new generated input if only decoder
370
+ if model.config.is_encoder_decoder is False:
371
+ attention_mask = torch.cat(
372
+ [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
373
+ )
374
+
375
+ return input_ids
376
+
377
+ if __name__=='__main__':
378
+ parser = ArgumentParser()
379
+
380
+ # DATA
381
+ parser.add_argument('--ckpt', type=str, required=True)
382
+ parser.add_argument('--dataset_info', type=str, required=True, help='saved dataset info')
383
+ parser.add_argument('--model_string', type=str, default='Helsinki-NLP/opus-mt-es-en')
384
+
385
+ parser.add_argument('--input_text', type=str, default=None, required=True, help='text to run pred on')
386
+
387
+ parser.add_argument('--precondition_topk', type=int, default=200, help='consider top k outputs from gpt at each step before conditioning and re-pruning')
388
+ parser.add_argument('--do_sample', action='store_true', default=False, help='sample instead of greedy')
389
+ parser.add_argument('--condition_lambda', type=float, default=1.0, help='lambda weight on conditioning model')
390
+ parser.add_argument('--length_cutoff', type=int, default=512, help='max length')
391
+
392
+ parser.add_argument('--seed', type=int, default=1, help='random seed')
393
+ parser.add_argument('--device', type=str, default='cuda', choices=['cpu', 'cuda'])
394
+ parser.add_argument('--debug', action='store_true', default=False)
395
+
396
+ args = parser.parse_args()
397
+
398
+ random.seed(args.seed)
399
+ np.random.seed(args.seed)
400
+ torch.manual_seed(args.seed)
401
+
402
+ main(args)
403
+
404
+
fudge/predict_poetry.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import time
4
+ import pickle
5
+ import math
6
+ from argparse import ArgumentParser
7
+ import string
8
+ from collections import defaultdict
9
+
10
+ from tqdm import tqdm
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ from transformers import AutoTokenizer, AutoModelWithLMHead, pipeline, set_seed, GPT2Tokenizer, GPT2Model
16
+
17
+ from data import Dataset, load_rhyme_info
18
+ from model import Model
19
+ from util import save_checkpoint, ProgressMeter, AverageMeter, num_params
20
+ from constants import *
21
+ from poetry_util import get_rhymes, count_syllables
22
+
23
+ def main(args):
24
+ with open(args.dataset_info, 'rb') as rf:
25
+ dataset_info = pickle.load(rf)
26
+ gpt_tokenizer = AutoTokenizer.from_pretrained(args.model_string)
27
+ gpt_tokenizer.add_special_tokens({'pad_token': PAD_TOKEN})
28
+ gpt_pad_id = gpt_tokenizer.encode(PAD_TOKEN)[0]
29
+ gpt_model = AutoModelWithLMHead.from_pretrained(args.model_string).to(args.device)
30
+ gpt_model.eval()
31
+
32
+ checkpoint = torch.load(args.iambic_ckpt, map_location=args.device)
33
+ model_args = checkpoint['args']
34
+ iambic_model = Model(model_args, gpt_pad_id, len(dataset_info.index2word)) # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
35
+ iambic_model.load_state_dict(checkpoint['state_dict'])
36
+ iambic_model = iambic_model.to(args.device)
37
+ iambic_model.eval()
38
+ print("=> loaded checkpoint '{}' (epoch {})"
39
+ .format(args.iambic_ckpt, checkpoint['epoch']))
40
+ print('iambic model num params', num_params(iambic_model))
41
+
42
+ with open(args.rhyme_info, 'rb') as rf:
43
+ rhyme_info = pickle.load(rf)
44
+ checkpoint = torch.load(args.rhyme_ckpt, map_location=args.device)
45
+ model_args = checkpoint['args']
46
+ rhyme_model = Model(model_args, gpt_pad_id, len(dataset_info.index2word), rhyme_group_size=len(rhyme_info.index2rhyme_group)) # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
47
+ rhyme_model.load_state_dict(checkpoint['state_dict'])
48
+ rhyme_model = rhyme_model.to(args.device)
49
+ rhyme_model.eval()
50
+ print("=> loaded checkpoint '{}' (epoch {})"
51
+ .format(args.rhyme_ckpt, checkpoint['epoch']))
52
+ print('rhyme model num params', num_params(rhyme_model))
53
+
54
+ checkpoint = torch.load(args.newline_ckpt, map_location=args.device)
55
+ model_args = checkpoint['args']
56
+ newline_model = Model(model_args, gpt_pad_id, len(dataset_info.index2word)) # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
57
+ newline_model.load_state_dict(checkpoint['state_dict'])
58
+ newline_model = newline_model.to(args.device)
59
+ newline_model.eval()
60
+ print("=> loaded checkpoint '{}' (epoch {})"
61
+ .format(args.newline_ckpt, checkpoint['epoch']))
62
+ print('iambic model num params', num_params(newline_model))
63
+
64
+ while True:
65
+ results = predict_couplet(gpt_model,
66
+ gpt_tokenizer,
67
+ iambic_model,
68
+ rhyme_model,
69
+ newline_model,
70
+ [args.input_text],
71
+ dataset_info,
72
+ rhyme_info,
73
+ args.precondition_topk,
74
+ args.topk,
75
+ condition_lambda=args.condition_lambda,
76
+ device=args.device)
77
+ for line in results:
78
+ print(line)
79
+ import pdb; pdb.set_trace()
80
+
81
+
82
+ def predict_couplet(gpt_model, gpt_tokenizer, iambic_model, rhyme_model, newline_model, input_text, dataset_info, rhyme_info, precondition_topk, postcondition_topk, condition_lambda=1.0, device='cuda'):
83
+ assert len(input_text) == 1 # only do one at a time for now
84
+ current_text = input_text[0]
85
+ current_line_text = ''
86
+ all_lines = [current_text]
87
+ ending_word = current_text.split()[-1].strip(string.punctuation)
88
+ word2rhyme_group = defaultdict(lambda: UNKNOWN_RHYME_GROUP, rhyme_info.word2rhyme_group)
89
+ rhyme_group = word2rhyme_group[ending_word]
90
+
91
+ line = predict_iambic_pentameter_line(gpt_model,
92
+ gpt_tokenizer,
93
+ iambic_model,
94
+ rhyme_model,
95
+ newline_model,
96
+ current_text,
97
+ current_line_text,
98
+ rhyme_group,
99
+ dataset_info,
100
+ rhyme_info,
101
+ precondition_topk,
102
+ postcondition_topk,
103
+ condition_lambda=condition_lambda,
104
+ device=device)
105
+ all_lines.append(line)
106
+
107
+ return all_lines
108
+
109
+
110
+ def predict_iambic_pentameter_line(gpt_model, gpt_tokenizer, iambic_model, rhyme_model, newline_model, current_text, current_line_text, rhyme_group, dataset_info, rhyme_info, precondition_topk, postcondition_topk, banned_tokens=POETRY_BANNED_TOKENS, condition_lambda=1.0, device='cuda', length_cutoff=30):
111
+ # TODO(poetry) delete banned tokens?
112
+ with torch.no_grad():
113
+ batch_size = 1
114
+
115
+ rhyme_group_index = rhyme_info.rhyme_group2index[rhyme_group]
116
+ future_words = torch.LongTensor([rhyme_group_index]).to(device) # 1
117
+ log_probs = torch.Tensor([math.log(rhyme_info.rhyme_group_counts[rhyme_group] / rhyme_info.total_rhyme_groups)]).to(device) # 1
118
+
119
+ # assumes initially all same length.
120
+ previous_encoded_text = [gpt_tokenizer.encode(it, return_tensors='pt').to(device) for it in [current_text]]
121
+ previous_enc_len = previous_encoded_text[0].shape[1]
122
+ encoded_input = [gpt_tokenizer.encode(it, return_tensors='pt').to(device) for it in [current_text + current_line_text]] # batch x seq
123
+ encoded_input = torch.cat(encoded_input, dim=0)
124
+ lengths = torch.LongTensor([encoded_input.shape[1]]).to(device)
125
+
126
+ line_syllable_count = count_syllables(current_line_text)
127
+ assert line_syllable_count < POETRY_LINE_SYLLABLES # assume we started with less than one full line
128
+ syllables_to_go = POETRY_LINE_SYLLABLES - line_syllable_count
129
+
130
+ for _ in range(length_cutoff): # really shouldn't have a line this long anyway
131
+ gpt_logits = gpt_model(encoded_input)[0][:, -1, :] # batch x vocab
132
+ gpt_logits[:, banned_tokens] = -1e8
133
+ top_logits, top_indices = gpt_logits.topk(precondition_topk, dim=1)
134
+
135
+ new_input_candidates = torch.cat([encoded_input.unsqueeze(1).expand(-1, precondition_topk, -1), top_indices.unsqueeze(2)], dim=2) # batch x topk x seq+1
136
+ expanded_lengths = (lengths + 1).unsqueeze(1).expand(batch_size, precondition_topk) # batch x topk
137
+ expanded_future_words = future_words.unsqueeze(0).unsqueeze(1).expand(batch_size, precondition_topk, -1) # batch x topk x N
138
+ candidate_syllables_to_go = []
139
+ for candidate in new_input_candidates[0]:
140
+ candidate_until_last_word_text = ' '.join(gpt_tokenizer.decode(candidate[previous_enc_len:]).split()[:-1])
141
+ candidate_syllables_to_go.append(10 - count_syllables(candidate_until_last_word_text))
142
+ # usually these are all the same, but run them all for correctness. could do more efficiently but it's not too slow anyway.
143
+ expanded_syllables_to_go = torch.LongTensor(candidate_syllables_to_go).to(device).view(1, precondition_topk)
144
+
145
+ if condition_lambda == 0:
146
+ iambic_logits = torch.zeros_like(expanded_lengths).float()
147
+ else:
148
+ # truncate prefix because we trained on single lines
149
+ iambic_logits = iambic_model(new_input_candidates[:, :, previous_enc_len:].flatten(0, 1), expanded_lengths.flatten(0, 1) - previous_enc_len, None, None, None)[:, -1] # batch*topk x seq+1 -> batch*topk
150
+ iambic_logits = iambic_logits.view(batch_size, precondition_topk)
151
+ iambic_logits = iambic_logits - torch.log(1 + torch.exp(iambic_logits))
152
+ if condition_lambda == 0:
153
+ rhyme_logits = torch.zeros_like(expanded_lengths).float()
154
+ else:
155
+ rhyme_logits = rhyme_model(new_input_candidates.flatten(0, 1), # batch*topk x seq+1
156
+ expanded_lengths.flatten(0, 1), # batch*topk
157
+ expanded_future_words.flatten(0, 1), # batch*topk x N
158
+ log_probs, # N
159
+ expanded_syllables_to_go.flatten(0, 1)) # batch*topk
160
+ rhyme_logits = rhyme_logits.view(batch_size, precondition_topk, -1) # batch x topk x N
161
+ rhyme_logits = rhyme_logits - torch.log(1 + torch.exp(rhyme_logits)) # batch x topk x N
162
+ rhyme_logits = rhyme_logits.squeeze(2) # batch x topk
163
+ if condition_lambda == 0:
164
+ newline_logits = torch.zeros_like(expanded_lengths).float()
165
+ else:
166
+ newline_logits = newline_model(new_input_candidates.flatten(0, 1), # batch*topk x seq+1
167
+ expanded_lengths.flatten(0, 1), # batch*topk
168
+ expanded_future_words.flatten(0, 1), # batch*topk x N
169
+ log_probs, # N
170
+ expanded_syllables_to_go.flatten(0, 1)) # batch*topk
171
+ newline_logits = newline_logits[:, -1].view(batch_size, precondition_topk, -1) # batch x topk x N
172
+ newline_logits = newline_logits - torch.log(1 + torch.exp(newline_logits)) # batch x topk x N
173
+ newline_logits = newline_logits.squeeze(2) # batch x topk
174
+
175
+ full_logits = top_logits + condition_lambda * iambic_logits + condition_lambda * rhyme_logits + condition_lambda * newline_logits
176
+ post_logits, post_indices = full_logits.topk(postcondition_topk, dim=1)
177
+ post_probs = F.softmax(post_logits, dim=1)
178
+ index_into_top_indices = post_indices[torch.arange(batch_size).to(post_indices.device), torch.multinomial(post_probs, 1).flatten()] # batch
179
+ next_indices = top_indices[torch.arange(batch_size).to(top_indices.device), index_into_top_indices] # batch
180
+ encoded_input = torch.cat([encoded_input, next_indices.unsqueeze(1)], dim=1) # batch x seq+1
181
+ lengths = lengths + 1
182
+ syllables_to_go = POETRY_LINE_SYLLABLES - count_syllables(gpt_tokenizer.decode(encoded_input[0][previous_enc_len:])) # if we get very unlucky with a partial word that the syllable counter doesn't recognize we might end early, but it's unlikely
183
+ if syllables_to_go <= 0 and [gpt_tokenizer.decode(s) for s in encoded_input][0][-1] in PHRASE_ENDS:
184
+ break
185
+ if syllables_to_go < 0:
186
+ # encoded_input = encoded_input[:, :-1]
187
+ break
188
+
189
+ return [gpt_tokenizer.decode(s) for s in encoded_input][0][len(current_text):]
190
+
191
+
192
+ if __name__=='__main__':
193
+ parser = ArgumentParser()
194
+
195
+ # DATA
196
+ parser.add_argument('--iambic_ckpt', type=str, required=True)
197
+ parser.add_argument('--rhyme_ckpt', type=str, required=True)
198
+ parser.add_argument('--newline_ckpt', type=str, required=True)
199
+ parser.add_argument('--dataset_info', type=str, required=True, help='saved dataset info')
200
+ parser.add_argument('--rhyme_info', type=str, required=True, help='saved rhyme info')
201
+ parser.add_argument('--model_string', type=str, default='gpt2-medium')
202
+
203
+ parser.add_argument('--input_text', type=str, default=None, required=True, help='initial text')
204
+
205
+ parser.add_argument('--precondition_topk', type=int, default=200, help='consider top k outputs from gpt at each step before conditioning and re-pruning')
206
+ parser.add_argument('--topk', type=int, default=10, help='consider top k outputs from gpt at each step')
207
+ parser.add_argument('--condition_lambda', type=float, default=1.0, help='lambda weight on conditioning model')
208
+
209
+ parser.add_argument('--seed', type=int, default=1, help='random seed')
210
+ parser.add_argument('--device', type=str, default='cuda', choices=['cpu', 'cuda'])
211
+ parser.add_argument('--debug', action='store_true', default=False)
212
+
213
+ args = parser.parse_args()
214
+
215
+ random.seed(args.seed)
216
+ np.random.seed(args.seed)
217
+ torch.manual_seed(args.seed)
218
+
219
+ main(args)
fudge/predict_topic.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import time
4
+ import pickle
5
+ import math
6
+ from argparse import ArgumentParser
7
+
8
+ from tqdm import tqdm
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from transformers import AutoTokenizer, AutoModelWithLMHead, pipeline, set_seed, GPT2Tokenizer, GPT2Model
14
+
15
+ from data import Dataset
16
+ from model import Model
17
+ from util import save_checkpoint, ProgressMeter, AverageMeter, num_params
18
+ from constants import *
19
+
20
+ def main(args):
21
+ with open(args.dataset_info, 'rb') as rf:
22
+ dataset_info = pickle.load(rf)
23
+ for cw in args.condition_words.split():
24
+ assert cw in dataset_info.word2index
25
+ gpt_tokenizer = AutoTokenizer.from_pretrained(args.model_string)
26
+ gpt_tokenizer.add_special_tokens({'pad_token': PAD_TOKEN})
27
+ gpt_pad_id = gpt_tokenizer.encode(PAD_TOKEN)[0]
28
+ gpt_model = AutoModelWithLMHead.from_pretrained(args.model_string).to(args.device)
29
+ gpt_model.eval()
30
+
31
+ checkpoint = torch.load(args.ckpt, map_location=args.device)
32
+ model_args = checkpoint['args']
33
+ conditioning_model = Model(model_args, gpt_pad_id, len(dataset_info.index2word)) # no need to get the glove embeddings when reloading since they're saved in model ckpt anyway
34
+ conditioning_model.load_state_dict(checkpoint['state_dict'])
35
+ conditioning_model = conditioning_model.to(args.device)
36
+ conditioning_model.eval()
37
+ print("=> loaded checkpoint '{}' (epoch {})"
38
+ .format(args.ckpt, checkpoint['epoch']))
39
+ print('num params', num_params(conditioning_model))
40
+
41
+ while True:
42
+ results = predict(gpt_model,
43
+ gpt_tokenizer,
44
+ conditioning_model,
45
+ [args.input_text],
46
+ args.condition_words,
47
+ dataset_info,
48
+ args.precondition_topk,
49
+ args.topk,
50
+ args.length_cutoff,
51
+ condition_lambda=args.condition_lambda,
52
+ device=args.device)
53
+ print(results)
54
+ import pdb; pdb.set_trace()
55
+
56
+ def predict(gpt_model, gpt_tokenizer, conditioning_model, input_text, condition_words, dataset_info, precondition_topk, postcondition_topk, length_cutoff, condition_lambda=1.0, device='cuda'):
57
+ with torch.no_grad():
58
+ batch_size = len(input_text)
59
+
60
+ condition_words = condition_words.split()
61
+ future_words = torch.LongTensor([dataset_info.word2index[cw] for cw in condition_words]).to(device) # N
62
+ log_probs = torch.Tensor([math.log(dataset_info.vocab[cw] / dataset_info.total_words) for cw in condition_words]).to(device) # N
63
+
64
+ # assumes initially all same length.
65
+ encoded_input = [gpt_tokenizer.encode(it, return_tensors='pt').to(device) for it in input_text] # batch x seq
66
+ encoded_input = torch.cat(encoded_input, dim=0)
67
+ lengths = torch.LongTensor([encoded_input.shape[1]]).to(device)
68
+
69
+ gpt_encoded_future_words = [gpt_tokenizer.encode(' ' + cw, return_tensors='pt')[0].to(device) for cw in condition_words]
70
+ while lengths.max() < length_cutoff:
71
+ tokens_left = torch.LongTensor([length_cutoff - lengths.max() for _ in range(batch_size)]).to(device)
72
+ gpt_logits = gpt_model(encoded_input)[0][:, -1, :] # batch x vocab
73
+ top_logits, top_indices = gpt_logits.topk(precondition_topk, dim=1) # batch x topk
74
+ new_input_candidates = torch.cat([encoded_input.unsqueeze(1).expand(-1, precondition_topk, -1), top_indices.unsqueeze(2)], dim=2) # batch x topk x seq+1
75
+ expanded_lengths = (lengths + 1).unsqueeze(1).expand(batch_size, precondition_topk) # batch x topk
76
+ expanded_future_words = future_words.unsqueeze(0).unsqueeze(1).expand(batch_size, precondition_topk, -1) # batch x topk x N
77
+ expanded_tokens_left = tokens_left.unsqueeze(1).expand(-1, precondition_topk) # batch x topk
78
+ if condition_lambda == 0:
79
+ condition_logits = torch.zeros_like(expanded_future_words).float()
80
+ else:
81
+ condition_logits = conditioning_model(new_input_candidates.flatten(0, 1), # batch*topk x seq+1
82
+ expanded_lengths.flatten(0, 1), # batch*topk
83
+ expanded_future_words.flatten(0, 1), # batch*topk x N
84
+ log_probs, # N
85
+ expanded_tokens_left.flatten(0, 1)) # batch*topk
86
+ condition_logits = condition_logits.view(batch_size, precondition_topk, -1) # batch x topk x N
87
+ condition_logits = condition_logits - torch.log(1 + torch.exp(condition_logits)) # get correct log probs
88
+
89
+ condition_logits = torch.mean(condition_logits, dim=2)
90
+ full_logits = top_logits + condition_logits * condition_lambda # batch x topk
91
+ post_logits, post_indices = full_logits.topk(postcondition_topk, dim=1)
92
+ post_probs = F.softmax(post_logits, dim=1)
93
+ index_into_top_indices = post_indices[torch.arange(batch_size).to(post_indices.device), torch.multinomial(post_probs, 1).flatten()] # batch
94
+ next_indices = top_indices[torch.arange(batch_size).to(top_indices.device), index_into_top_indices] # batch
95
+ encoded_input = torch.cat([encoded_input, next_indices.unsqueeze(1)], dim=1) # batch x seq+1
96
+ lengths = lengths + 1 # batch
97
+ return [gpt_tokenizer.decode(s) for s in encoded_input]
98
+
99
+
100
+ if __name__=='__main__':
101
+ parser = ArgumentParser()
102
+
103
+ # DATA
104
+ parser.add_argument('--ckpt', type=str, required=True)
105
+ parser.add_argument('--dataset_info', type=str, required=True, help='saved dataset info')
106
+ parser.add_argument('--model_string', type=str, default='gpt2-medium')
107
+
108
+ parser.add_argument('--input_text', type=str, default=None, required=True, help='initial text')
109
+ parser.add_argument('--condition_words', type=str, default=None, required=True, help='word(s) to optimize for')
110
+
111
+ parser.add_argument('--precondition_topk', type=int, default=200, help='consider top k outputs from gpt at each step before conditioning and re-pruning')
112
+ parser.add_argument('--topk', type=int, default=10, help='consider top k outputs from gpt at each step')
113
+ parser.add_argument('--condition_lambda', type=float, default=1.0, help='lambda weight on conditioning model')
114
+ parser.add_argument('--length_cutoff', type=int, default=80, help='max length')
115
+
116
+ parser.add_argument('--seed', type=int, default=1, help='random seed')
117
+ parser.add_argument('--device', type=str, default='cuda', choices=['cpu', 'cuda'])
118
+ parser.add_argument('--debug', action='store_true', default=False)
119
+
120
+ args = parser.parse_args()
121
+
122
+ random.seed(args.seed)
123
+ np.random.seed(args.seed)
124
+ torch.manual_seed(args.seed)
125
+
126
+ main(args)