Add cleaner, fix data preparation
- src/data_utils.py +71 -0
- src/preparaing_recipe_nlg_dataset.py +62 -16
- src/run.sh +1 -1
- src/run_ed_recipe_nlg.py +0 -1
src/data_utils.py
ADDED
@@ -0,0 +1,71 @@
+from nltk.tokenize import wordpunct_tokenize as word_tokenize
+from nltk.tokenize import sent_tokenize
+
+import re
+import six
+import textwrap
+
+_whitelist = r"[0-9a-z\,\.\/\<\>]+"
+_regex = "0-9a-z\,\.\/\<\>"
+
+
+def filter_by_lang_regex(text, ratio=0.7, regex="0-9a-z\,\.\/\<\>"):
+    candidate_text = re.sub(r"[^" + regex + "]+", " ", six.ensure_str(text), flags=re.IGNORECASE).replace(" ", "")
+    text = text.replace(" ", "")
+
+    return (len(candidate_text) / len(text)) > ratio
+
+
+def filter_by_num_tokens(text, gt=64):
+    return len(word_tokenize(text)) > gt
+
+
+def filter_by_num_sents(text, gt=2):
+    return len(sent_tokenize(text)) > gt
+
+
+def filter_by_steps(text):
+    return re.search('(step|mix all)', text, re.IGNORECASE) is not None
+
+
+def filter_by_length(text, gt=40):
+    return len(text) > gt
+
+
+def filter_by_item(item_list, gt=4):
+    return len(item_list) > gt
+
+
+def chars_to_preserve(sentence, whitelist):
+    try:
+        tokenized = re.findall(whitelist, sentence, re.IGNORECASE)
+        return " ".join(tokenized)
+    except Exception as error:
+        print(
+            textwrap.dedent(
+                f"""
+                Bad characters range {whitelist},
+                {error}
+                """
+            )
+        )
+        raise
+
+
+def normalizer(text, whitelist=r"[0-9a-z\,\.\/\<\>]+", do_lowercase=False):
+    if do_lowercase:
+        text = text.lower()
+
+    text = chars_to_preserve(text, whitelist=whitelist)
+    text = " ".join([word.strip() for word in text.split() if word.strip()])
+    text = text.strip()
+
+    return text
+
+# _text = "Crust, Peanut Butter}Melt <sep> 1/2Butter, 2 c. Eggs, Filling, Semi- Sweet Chocolate Chips, Milk, Butter, " \
+#         "Frosting"
+# out = normalizer(_text)
+# print(out)
+#
+# _text = "step ... "
+# print(re.search('(step|mix all)', _text, re.IGNORECASE) != None)
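For reference, a quick sketch of how the new data_utils helpers compose on a raw, RecipeNLG-style string. The sample text and printed values below are illustrative only; the thresholds are the module defaults shown above.

    from data_utils import filter_by_lang_regex, filter_by_num_tokens, filter_by_steps, normalizer

    text = "Step 1: Mix all ingredients, 1/2 c. butter <sep> bake 20 min."

    print(filter_by_lang_regex(text))   # True: nearly every character is in the 0-9a-z,./<> whitelist
    print(filter_by_steps(text))        # True: matches "step" / "mix all", case-insensitive
    print(filter_by_num_tokens(text))   # False: well under the 64-token default
    print(normalizer(text, do_lowercase=True))
    # step 1 mix all ingredients, 1/2 c. butter <sep> bake 20 min.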
src/preparaing_recipe_nlg_dataset.py
CHANGED
@@ -5,6 +5,7 @@ import sys
 from dataclasses import dataclass, field

 import pandas as pd
+from sklearn.model_selection import train_test_split
 from tqdm import tqdm
 from typing import Dict, List, Optional, Tuple

@@ -13,6 +14,16 @@ from transformers import (
     HfArgumentParser,
 )

+from data_utils import (
+    filter_by_lang_regex,
+    filter_by_steps,
+    filter_by_length,
+    filter_by_item,
+    filter_by_num_sents,
+    filter_by_num_tokens,
+    normalizer
+)
+
 logger = logging.getLogger(__name__)


@@ -72,40 +83,75 @@ def main():

     def cleaning(text, item_type="ner"):
         # NOTE: DO THE CLEANING LATER
+        text = normalizer(text, do_lowercase=True)
         return text

     def recipe_preparation(item_dict):
+        ner = item_dict["ner"]
+        title = item_dict["title"]
+        ingredients = item_dict["ingredients"]
+        steps = item_dict["directions"]
+
+        condition_1 = filter_by_item(ner, 4)
+        condition_2 = filter_by_length(title, 10)
+        condition_3 = filter_by_item(ingredients, 4)
+        condition_4 = filter_by_item(steps, 2)
+        condition_5 = filter_by_steps(" ".join(steps))
+
+        if not all([condition_1, condition_2, condition_3, condition_4, condition_5]):
             return None

+        ner = ", ".join(ner)
+        ingredients = " <sep> ".join(ingredients)
+        steps = " <sep> ".join(steps)
+
+        # Cleaning
+        ner = cleaning(ner, "ner")
+        title = cleaning(title, "title")
+        ingredients = cleaning(ingredients, "ingredients")
+        steps = cleaning(steps, "steps")

         return {
             "inputs": ner,
+            "targets": f"title: {title} <section> ingredients: {ingredients} <section> directions: {steps}"
         }

+    if len(dataset.keys()) > 1:
+        for subset in dataset.keys():
+            data_dict = []
+            for item in tqdm(dataset[subset], position=0, total=len(dataset[subset])):
+                item = recipe_preparation(item)
+                if item:
+                    data_dict.append(item)
+
+            data_df = pd.DataFrame(data_dict)
+            logger.info(f"Preparation of [{subset}] set consists of {len(data_df)} records!")
+
+            output_path = os.path.join(data_args.output_dir, f"{subset}.csv")
+            os.makedirs(os.path.dirname(output_path), exist_ok=True)
+            data_df.to_csv(output_path, sep="\t", encoding="utf-8", index=False)
+            logger.info(f"Data saved here {output_path}")
+    else:
         data_dict = []
+        subset = list(dataset.keys())[0]
         for item in tqdm(dataset[subset], position=0, total=len(dataset[subset])):
             item = recipe_preparation(item)
             if item:
                 data_dict.append(item)

         data_df = pd.DataFrame(data_dict)
+        train, test = train_test_split(data_df, test_size=0.05, random_state=101)
+
+        train = train.reset_index(drop=True)
+        test = test.reset_index(drop=True)
+
+        logger.info(f"Preparation of [train] set consists of {len(train)} records!")
+        logger.info(f"Preparation of [test] set consists of {len(test)} records!")

+        os.makedirs(data_args.output_dir, exist_ok=True)
+        train.to_csv(os.path.join(data_args.output_dir, "train.csv"), sep="\t", encoding="utf-8", index=False)
+        test.to_csv(os.path.join(data_args.output_dir, "test.csv"), sep="\t", encoding="utf-8", index=False)
+        logger.info(f"Data saved here {data_args.output_dir}")


 if __name__ == '__main__':
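To make the new target format concrete, here is a rough, made-up example of what recipe_preparation would return for a toy RecipeNLG item that passes all five filters (field values invented for illustration):

    # Hypothetical RecipeNLG-style record, invented for illustration.
    item = {
        "title": "peanut butter bars",
        "ner": ["peanut butter", "butter", "sugar", "eggs", "chocolate chips"],
        "ingredients": ["1 c. peanut butter", "1/2 c. butter", "2 c. sugar", "2 eggs", "1 c. chocolate chips"],
        "directions": ["Step 1: cream butter and sugar.", "Step 2: mix all remaining ingredients.", "Bake 25 minutes."],
    }

    # recipe_preparation(item) should then return approximately:
    expected = {
        "inputs": "peanut butter, butter, sugar, eggs, chocolate chips",
        "targets": "title: peanut butter bars <section> "
                   "ingredients: 1 c. peanut butter <sep> 1/2 c. butter <sep> 2 c. sugar <sep> 2 eggs <sep> 1 c. chocolate chips <section> "
                   "directions: step 1 cream butter and sugar. <sep> step 2 mix all remaining ingredients. <sep> bake 25 minutes.",
    }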
src/run.sh
CHANGED
@@ -35,7 +35,7 @@ python run_ed_recipe_nlg.py \
 --max_target_length="$MAX_TARGET_LENGTH" \
 --model_name_or_path="$MODEL_NAME_OR_PATH" \
 --extra_tokens="" \
+--special_tokens="<sep>,<section>" \
 --per_device_train_batch_size=$PER_DEVICE_TRAIN_BATCH_SIZE \
 --per_device_eval_batch_size=$PER_DEVICE_EVAL_BATCH_SIZE \
 --gradient_accumulation_steps=$GRADIENT_ACCUMULATION_STEPS \
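For context, a minimal sketch of how a comma-separated --special_tokens value like the one above is typically registered with a transformers tokenizer. This is not taken from run_ed_recipe_nlg.py (its argument handling is outside this diff), and the checkpoint name is a placeholder.

    # Illustrative only; mirrors the --special_tokens flag in run.sh.
    from transformers import AutoTokenizer

    special_tokens = "<sep>,<section>".split(",")

    tokenizer = AutoTokenizer.from_pretrained("t5-base")  # placeholder checkpoint
    tokenizer.add_special_tokens({"additional_special_tokens": special_tokens})

    # Registering new tokens grows the vocabulary, so the model's token
    # embeddings generally need to stay in sync with len(tokenizer).
    print(len(tokenizer))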
src/run_ed_recipe_nlg.py
CHANGED
@@ -409,7 +409,6 @@ def main():
         config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
     )

-    model.resize_token_embeddings(len(tokenizer))
     if model.config.decoder_start_token_id is None:
         raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
