Upload train.py with huggingface_hub
Browse files
train.py
ADDED
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
import random
|
3 |
+
import shutil
|
4 |
+
from datasets import load_dataset, concatenate_datasets, Features, Sequence, ClassLabel, Value, DatasetDict
|
5 |
+
from transformers import TrainingArguments
|
6 |
+
from span_marker import SpanMarkerModel, Trainer
|
7 |
+
from span_marker.model_card import SpanMarkerModelCardData
|
8 |
+
from huggingface_hub import upload_folder, upload_file
|
9 |
+
|
10 |
+
|
11 |
+
"""
|
12 |
+
FEATURES = Features({"tokens": Sequence(feature=Value(dtype='string')), "ner_tags": Sequence(feature=ClassLabel(names=['O', 'B-ORG', 'I-ORG']))})
|
13 |
+
|
14 |
+
|
15 |
+
def load_fewnerd():
|
16 |
+
def mapper(sample):
|
17 |
+
sample["ner_tags"] = [int(tag == 5) for tag in sample["ner_tags"]]
|
18 |
+
sample["ner_tags"] = [2 if tag == 1 and idx > 0 and sample["ner_tags"][idx - 1] == 1 else tag for idx, tag in enumerate(sample["ner_tags"])]
|
19 |
+
return sample
|
20 |
+
|
21 |
+
dataset = load_dataset("DFKI-SLT/few-nerd", "supervised")
|
22 |
+
dataset = dataset.map(mapper, remove_columns=["id", "fine_ner_tags"])
|
23 |
+
dataset = dataset.cast(FEATURES)
|
24 |
+
return dataset
|
25 |
+
|
26 |
+
|
27 |
+
def load_conll():
|
28 |
+
label_mapping = {3: 1, 4: 2}
|
29 |
+
def mapper(sample):
|
30 |
+
sample["ner_tags"] = [label_mapping.get(tag, 0) for tag in sample["ner_tags"]]
|
31 |
+
return sample
|
32 |
+
|
33 |
+
dataset = load_dataset("conll2003")
|
34 |
+
dataset = dataset.map(mapper, remove_columns=["id", "pos_tags", "chunk_tags"])
|
35 |
+
dataset = dataset.cast(FEATURES)
|
36 |
+
return dataset
|
37 |
+
|
38 |
+
|
39 |
+
def load_ontonotes():
|
40 |
+
label_mapping = {11: 1, 12: 2}
|
41 |
+
def mapper(sample):
|
42 |
+
sample["ner_tags"] = [label_mapping.get(tag, 0) for tag in sample["ner_tags"]]
|
43 |
+
return sample
|
44 |
+
|
45 |
+
dataset = load_dataset("tner/ontonotes5")
|
46 |
+
dataset = dataset.rename_column("tags", "ner_tags")
|
47 |
+
dataset = dataset.map(mapper)
|
48 |
+
dataset = dataset.cast(FEATURES)
|
49 |
+
return dataset
|
50 |
+
|
51 |
+
|
52 |
+
def load_multinerd():
|
53 |
+
label_mapping = {5: 1, 6: 2}
|
54 |
+
def mapper(sample):
|
55 |
+
sample["ner_tags"] = [label_mapping.get(tag, 0) for tag in sample["ner_tags"]]
|
56 |
+
return sample
|
57 |
+
|
58 |
+
def lang_filter(sample):
|
59 |
+
return sample["lang"] == "en"
|
60 |
+
|
61 |
+
dataset = load_dataset("Babelscape/multinerd")
|
62 |
+
dataset = dataset.filter(lang_filter)
|
63 |
+
dataset = dataset.map(mapper, remove_columns="lang")
|
64 |
+
dataset = dataset.cast(FEATURES)
|
65 |
+
return dataset
|
66 |
+
|
67 |
+
|
68 |
+
def preprocess_raw_dataset(raw_dataset):
|
69 |
+
# Set the number of sentences without an org equal to the number of sentences with an org
|
70 |
+
def has_org(sample):
|
71 |
+
return bool(sum(sample["ner_tags"]))
|
72 |
+
|
73 |
+
def has_no_org(sample):
|
74 |
+
return not has_org(sample)
|
75 |
+
|
76 |
+
dataset_org = raw_dataset.filter(has_org)
|
77 |
+
dataset_no_org = raw_dataset.filter(has_no_org)
|
78 |
+
dataset_no_org = dataset_no_org.select(random.sample(range(len(dataset_no_org)), k=len(dataset_org)))
|
79 |
+
dataset = concatenate_datasets([dataset_org, dataset_no_org])
|
80 |
+
return dataset
|
81 |
+
"""
|
82 |
+
|
83 |
+
|
84 |
+
def main() -> None:
|
85 |
+
# Load the dataset, ensure "tokens" and "ner_tags" columns, and get a list of labels
|
86 |
+
labels = ["O", "B-ORG", "I-ORG"]
|
87 |
+
"""
|
88 |
+
fewnerd_dataset = load_fewnerd()
|
89 |
+
conll_dataset = load_conll()
|
90 |
+
ontonotes_dataset = load_ontonotes()
|
91 |
+
multinerd_dataset = load_multinerd()
|
92 |
+
|
93 |
+
raw_train_dataset = concatenate_datasets([fewnerd_dataset["train"], conll_dataset["train"], ontonotes_dataset["train"], multinerd_dataset["train"]])
|
94 |
+
raw_eval_dataset = concatenate_datasets([fewnerd_dataset["validation"], conll_dataset["validation"], ontonotes_dataset["validation"], multinerd_dataset["validation"]])
|
95 |
+
raw_test_dataset = concatenate_datasets([fewnerd_dataset["test"], conll_dataset["test"], ontonotes_dataset["test"], multinerd_dataset["test"]])
|
96 |
+
|
97 |
+
train_dataset = preprocess_raw_dataset(raw_train_dataset)
|
98 |
+
eval_dataset = preprocess_raw_dataset(raw_eval_dataset)
|
99 |
+
test_dataset = preprocess_raw_dataset(raw_test_dataset)
|
100 |
+
|
101 |
+
dataset_dict = DatasetDict({
|
102 |
+
"train": train_dataset,
|
103 |
+
"validation": eval_dataset,
|
104 |
+
"test": test_dataset,
|
105 |
+
})
|
106 |
+
dataset_dict.push_to_hub("ner-orgs", private=True)
|
107 |
+
"""
|
108 |
+
# breakpoint()
|
109 |
+
dataset = load_dataset("tomaarsen/ner-orgs")
|
110 |
+
|
111 |
+
train_dataset = dataset["train"]
|
112 |
+
eval_dataset = dataset["validation"]
|
113 |
+
eval_dataset = eval_dataset.select(random.sample(range(len(eval_dataset)), k=3000))
|
114 |
+
test_dataset = dataset["test"]
|
115 |
+
|
116 |
+
# Initialize a SpanMarker model using a pretrained BERT-style encoder
|
117 |
+
encoder_id = "bert-base-cased"
|
118 |
+
model_id = f"tomaarsen/span-marker-bert-base-orgs"
|
119 |
+
model = SpanMarkerModel.from_pretrained(
|
120 |
+
encoder_id,
|
121 |
+
labels=labels,
|
122 |
+
# SpanMarker hyperparameters:
|
123 |
+
model_max_length=256,
|
124 |
+
marker_max_length=128,
|
125 |
+
entity_max_length=8,
|
126 |
+
# Model card variables
|
127 |
+
model_card_data=SpanMarkerModelCardData(
|
128 |
+
model_id=model_id,
|
129 |
+
encoder_id=encoder_id,
|
130 |
+
dataset_name="FewNERD, CoNLL2003, OntoNotes v5, and MultiNERD",
|
131 |
+
language=["en"],
|
132 |
+
),
|
133 |
+
)
|
134 |
+
|
135 |
+
# Prepare the 🤗 transformers training arguments
|
136 |
+
output_dir = Path("models") / model_id
|
137 |
+
args = TrainingArguments(
|
138 |
+
output_dir=output_dir,
|
139 |
+
run_name=model_id,
|
140 |
+
# Training Hyperparameters:
|
141 |
+
learning_rate=5e-5,
|
142 |
+
per_device_train_batch_size=32,
|
143 |
+
per_device_eval_batch_size=32,
|
144 |
+
num_train_epochs=3,
|
145 |
+
weight_decay=0.01,
|
146 |
+
warmup_ratio=0.1,
|
147 |
+
bf16=True, # Replace `bf16` with `fp16` if your hardware can't use bf16.
|
148 |
+
# Other Training parameters
|
149 |
+
logging_first_step=True,
|
150 |
+
logging_steps=100,
|
151 |
+
evaluation_strategy="steps",
|
152 |
+
save_strategy="steps",
|
153 |
+
eval_steps=3000,
|
154 |
+
save_total_limit=1,
|
155 |
+
dataloader_num_workers=4,
|
156 |
+
)
|
157 |
+
|
158 |
+
# Initialize the trainer using our model, training args & dataset, and train
|
159 |
+
trainer = Trainer(
|
160 |
+
model=model,
|
161 |
+
args=args,
|
162 |
+
train_dataset=train_dataset,
|
163 |
+
eval_dataset=eval_dataset,
|
164 |
+
)
|
165 |
+
trainer.train()
|
166 |
+
|
167 |
+
# Compute & save the metrics on the test set
|
168 |
+
metrics = trainer.evaluate(test_dataset, metric_key_prefix="test")
|
169 |
+
trainer.save_metrics("test", metrics)
|
170 |
+
|
171 |
+
# Save the model & training script locally
|
172 |
+
trainer.save_model(output_dir / "checkpoint-final")
|
173 |
+
shutil.copy2(__file__, output_dir / "checkpoint-final" / "train.py")
|
174 |
+
|
175 |
+
# Upload everything to the Hub
|
176 |
+
breakpoint()
|
177 |
+
model.push_to_hub(model_id, private=True)
|
178 |
+
upload_folder(folder_path=output_dir / "runs", path_in_repo="runs", repo_id=model_id)
|
179 |
+
upload_file(path_or_fileobj=__file__, path_in_repo="train.py", repo_id=model_id)
|
180 |
+
upload_file(path_or_fileobj=output_dir / "all_results.json", path_in_repo="all_results.json", repo_id=model_id)
|
181 |
+
upload_file(path_or_fileobj=output_dir / "emissions.csv", path_in_repo="emissions.csv", repo_id=model_id)
|
182 |
+
|
183 |
+
|
184 |
+
if __name__ == "__main__":
|
185 |
+
main()
|