tomaarsen HF staff commited on
Commit
09f6214
·
1 Parent(s): c9384b1

Upload train.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train.py +185 -0
train.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import random
3
+ import shutil
4
+ from datasets import load_dataset, concatenate_datasets, Features, Sequence, ClassLabel, Value, DatasetDict
5
+ from transformers import TrainingArguments
6
+ from span_marker import SpanMarkerModel, Trainer
7
+ from span_marker.model_card import SpanMarkerModelCardData
8
+ from huggingface_hub import upload_folder, upload_file
9
+
10
+
11
+ """
12
+ FEATURES = Features({"tokens": Sequence(feature=Value(dtype='string')), "ner_tags": Sequence(feature=ClassLabel(names=['O', 'B-ORG', 'I-ORG']))})
13
+
14
+
15
+ def load_fewnerd():
16
+ def mapper(sample):
17
+ sample["ner_tags"] = [int(tag == 5) for tag in sample["ner_tags"]]
18
+ sample["ner_tags"] = [2 if tag == 1 and idx > 0 and sample["ner_tags"][idx - 1] == 1 else tag for idx, tag in enumerate(sample["ner_tags"])]
19
+ return sample
20
+
21
+ dataset = load_dataset("DFKI-SLT/few-nerd", "supervised")
22
+ dataset = dataset.map(mapper, remove_columns=["id", "fine_ner_tags"])
23
+ dataset = dataset.cast(FEATURES)
24
+ return dataset
25
+
26
+
27
+ def load_conll():
28
+ label_mapping = {3: 1, 4: 2}
29
+ def mapper(sample):
30
+ sample["ner_tags"] = [label_mapping.get(tag, 0) for tag in sample["ner_tags"]]
31
+ return sample
32
+
33
+ dataset = load_dataset("conll2003")
34
+ dataset = dataset.map(mapper, remove_columns=["id", "pos_tags", "chunk_tags"])
35
+ dataset = dataset.cast(FEATURES)
36
+ return dataset
37
+
38
+
39
+ def load_ontonotes():
40
+ label_mapping = {11: 1, 12: 2}
41
+ def mapper(sample):
42
+ sample["ner_tags"] = [label_mapping.get(tag, 0) for tag in sample["ner_tags"]]
43
+ return sample
44
+
45
+ dataset = load_dataset("tner/ontonotes5")
46
+ dataset = dataset.rename_column("tags", "ner_tags")
47
+ dataset = dataset.map(mapper)
48
+ dataset = dataset.cast(FEATURES)
49
+ return dataset
50
+
51
+
52
+ def load_multinerd():
53
+ label_mapping = {5: 1, 6: 2}
54
+ def mapper(sample):
55
+ sample["ner_tags"] = [label_mapping.get(tag, 0) for tag in sample["ner_tags"]]
56
+ return sample
57
+
58
+ def lang_filter(sample):
59
+ return sample["lang"] == "en"
60
+
61
+ dataset = load_dataset("Babelscape/multinerd")
62
+ dataset = dataset.filter(lang_filter)
63
+ dataset = dataset.map(mapper, remove_columns="lang")
64
+ dataset = dataset.cast(FEATURES)
65
+ return dataset
66
+
67
+
68
+ def preprocess_raw_dataset(raw_dataset):
69
+ # Set the number of sentences without an org equal to the number of sentences with an org
70
+ def has_org(sample):
71
+ return bool(sum(sample["ner_tags"]))
72
+
73
+ def has_no_org(sample):
74
+ return not has_org(sample)
75
+
76
+ dataset_org = raw_dataset.filter(has_org)
77
+ dataset_no_org = raw_dataset.filter(has_no_org)
78
+ dataset_no_org = dataset_no_org.select(random.sample(range(len(dataset_no_org)), k=len(dataset_org)))
79
+ dataset = concatenate_datasets([dataset_org, dataset_no_org])
80
+ return dataset
81
+ """
82
+
83
+
84
+ def main() -> None:
85
+ # Load the dataset, ensure "tokens" and "ner_tags" columns, and get a list of labels
86
+ labels = ["O", "B-ORG", "I-ORG"]
87
+ """
88
+ fewnerd_dataset = load_fewnerd()
89
+ conll_dataset = load_conll()
90
+ ontonotes_dataset = load_ontonotes()
91
+ multinerd_dataset = load_multinerd()
92
+
93
+ raw_train_dataset = concatenate_datasets([fewnerd_dataset["train"], conll_dataset["train"], ontonotes_dataset["train"], multinerd_dataset["train"]])
94
+ raw_eval_dataset = concatenate_datasets([fewnerd_dataset["validation"], conll_dataset["validation"], ontonotes_dataset["validation"], multinerd_dataset["validation"]])
95
+ raw_test_dataset = concatenate_datasets([fewnerd_dataset["test"], conll_dataset["test"], ontonotes_dataset["test"], multinerd_dataset["test"]])
96
+
97
+ train_dataset = preprocess_raw_dataset(raw_train_dataset)
98
+ eval_dataset = preprocess_raw_dataset(raw_eval_dataset)
99
+ test_dataset = preprocess_raw_dataset(raw_test_dataset)
100
+
101
+ dataset_dict = DatasetDict({
102
+ "train": train_dataset,
103
+ "validation": eval_dataset,
104
+ "test": test_dataset,
105
+ })
106
+ dataset_dict.push_to_hub("ner-orgs", private=True)
107
+ """
108
+ # breakpoint()
109
+ dataset = load_dataset("tomaarsen/ner-orgs")
110
+
111
+ train_dataset = dataset["train"]
112
+ eval_dataset = dataset["validation"]
113
+ eval_dataset = eval_dataset.select(random.sample(range(len(eval_dataset)), k=3000))
114
+ test_dataset = dataset["test"]
115
+
116
+ # Initialize a SpanMarker model using a pretrained BERT-style encoder
117
+ encoder_id = "bert-base-cased"
118
+ model_id = f"tomaarsen/span-marker-bert-base-orgs"
119
+ model = SpanMarkerModel.from_pretrained(
120
+ encoder_id,
121
+ labels=labels,
122
+ # SpanMarker hyperparameters:
123
+ model_max_length=256,
124
+ marker_max_length=128,
125
+ entity_max_length=8,
126
+ # Model card variables
127
+ model_card_data=SpanMarkerModelCardData(
128
+ model_id=model_id,
129
+ encoder_id=encoder_id,
130
+ dataset_name="FewNERD, CoNLL2003, OntoNotes v5, and MultiNERD",
131
+ language=["en"],
132
+ ),
133
+ )
134
+
135
+ # Prepare the 🤗 transformers training arguments
136
+ output_dir = Path("models") / model_id
137
+ args = TrainingArguments(
138
+ output_dir=output_dir,
139
+ run_name=model_id,
140
+ # Training Hyperparameters:
141
+ learning_rate=5e-5,
142
+ per_device_train_batch_size=32,
143
+ per_device_eval_batch_size=32,
144
+ num_train_epochs=3,
145
+ weight_decay=0.01,
146
+ warmup_ratio=0.1,
147
+ bf16=True, # Replace `bf16` with `fp16` if your hardware can't use bf16.
148
+ # Other Training parameters
149
+ logging_first_step=True,
150
+ logging_steps=100,
151
+ evaluation_strategy="steps",
152
+ save_strategy="steps",
153
+ eval_steps=3000,
154
+ save_total_limit=1,
155
+ dataloader_num_workers=4,
156
+ )
157
+
158
+ # Initialize the trainer using our model, training args & dataset, and train
159
+ trainer = Trainer(
160
+ model=model,
161
+ args=args,
162
+ train_dataset=train_dataset,
163
+ eval_dataset=eval_dataset,
164
+ )
165
+ trainer.train()
166
+
167
+ # Compute & save the metrics on the test set
168
+ metrics = trainer.evaluate(test_dataset, metric_key_prefix="test")
169
+ trainer.save_metrics("test", metrics)
170
+
171
+ # Save the model & training script locally
172
+ trainer.save_model(output_dir / "checkpoint-final")
173
+ shutil.copy2(__file__, output_dir / "checkpoint-final" / "train.py")
174
+
175
+ # Upload everything to the Hub
176
+ breakpoint()
177
+ model.push_to_hub(model_id, private=True)
178
+ upload_folder(folder_path=output_dir / "runs", path_in_repo="runs", repo_id=model_id)
179
+ upload_file(path_or_fileobj=__file__, path_in_repo="train.py", repo_id=model_id)
180
+ upload_file(path_or_fileobj=output_dir / "all_results.json", path_in_repo="all_results.json", repo_id=model_id)
181
+ upload_file(path_or_fileobj=output_dir / "emissions.csv", path_in_repo="emissions.csv", repo_id=model_id)
182
+
183
+
184
+ if __name__ == "__main__":
185
+ main()