jed351 committed on
Commit c7a59d6
1 Parent(s): 94a799b

Add test script

Files changed (1)
  1. test.py +85 -0
test.py ADDED
@@ -0,0 +1,85 @@
+ import transformers
+ from datasets import ClassLabel
+ import random
+ import pandas as pd
+
+
+ def tokenize_function(examples):
+     return tokenizer(examples['text'], add_special_tokens=True)
+
+
+ def group_texts(examples):
+     # Concatenate all texts.
+     concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
+     total_length = len(concatenated_examples[list(examples.keys())[0]])
+     # We drop the small remainder; we could instead pad the final chunk if the
+     # model supported padding. Customize this part to your needs.
+     total_length = (total_length // block_size) * block_size
+     # Split into chunks of block_size.
+     result = {
+         k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
+         for k, t in concatenated_examples.items()
+     }
+     result["labels"] = result["input_ids"].copy()
+     return result
+
+
+
+ block_size = 128
+
+ from datasets import load_dataset
+ datasets = load_dataset('jed351/cantonese-wikipedia')
+
+ from transformers import AutoTokenizer
+ model_checkpoint = "Ayaka/bart-base-cantonese"
+
+ tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)
+ tokenized_datasets = datasets.map(tokenize_function,
+                                   batched=True, num_proc=4, remove_columns=["text"])
+
+
+
+ lm_datasets = tokenized_datasets.map(
+     group_texts,
+     batched=True,
+     batch_size=1000,
+     num_proc=4,
+ )
+
+
+
+ from transformers import Trainer, TrainingArguments
+
+
+ from transformers import DataCollatorForLanguageModeling
+ data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)
+
+
+
+
+ from transformers import AutoModelForMaskedLM
+ model = AutoModelForMaskedLM.from_pretrained(model_checkpoint)
+
+
+ training_args = TrainingArguments(
+     "bart-finetuned-wikitext2",
+     evaluation_strategy="epoch",
+     learning_rate=2e-5,
+     weight_decay=0.01,
+     push_to_hub=False,
+     per_device_train_batch_size=72,
+     fp16=True,
+     save_steps=5000
+ )
+
+
+ trainer = Trainer(
+     model=model,
+     args=training_args,
+     train_dataset=lm_datasets["train"],
+     eval_dataset=lm_datasets["test"],
+     data_collator=data_collator,
+ )
+
+
+ trainer.train()
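
A minimal usage sketch (not part of the commit): once training finishes, the fine-tuned weights can be reloaded for mask filling. This assumes the model is saved explicitly, e.g. with trainer.save_model("bart-finetuned-wikitext2"); the tokenizer is reloaded from the base checkpoint because the script does not save it, and the Cantonese sentence is only a placeholder.

from transformers import pipeline

# Assumption: the fine-tuned weights were written out after training, e.g.
#   trainer.save_model("bart-finetuned-wikitext2")
# The tokenizer comes from the base checkpoint, since the script does not save it.
fill_mask = pipeline(
    "fill-mask",
    model="bart-finetuned-wikitext2",
    tokenizer="Ayaka/bart-base-cantonese",
)

# Use the tokenizer's own mask token rather than hard-coding one;
# the sentence below is only an illustrative placeholder.
masked = f"香港係一個{fill_mask.tokenizer.mask_token}。"
print(fill_mask(masked, top_k=5))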