Siki-77 commited on
Commit
e16d99c
·
verified ·
1 Parent(s): 2a88abf

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +70 -1
README.md CHANGED
@@ -22,7 +22,76 @@ It achieves the following results on the evaluation set:
22
 
23
  ## Model description
24
 
25
- More information needed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  ## Intended uses & limitations
28
 
 
22
 
23
  ## Model description
24
 
25
+ Train and Test Code
26
+ ```python
27
+ from datasets import load_dataset
28
+ imdb = load_dataset("imdb")
29
+
30
+ import numpy as np
31
+ from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
32
+ import torch
33
+ from transformers import AutoTokenizer
34
+ from transformers import DataCollatorWithPadding
35
+ from transformers import EarlyStoppingCallback
36
+ import evaluate
37
+
38
+
39
+ # model_name = 'xlnet-large-cased'
40
+ model_name = 'roberta-large'
41
+
42
+ id2label = {0: "NEGATIVE", 1: "POSITIVE"}
43
+ label2id = {"NEGATIVE": 0, "POSITIVE": 1}
44
+ def compute_metrics(eval_pred):
45
+ predictions, labels = eval_pred
46
+ predictions = np.argmax(predictions, axis=1)
47
+ return accuracy.compute(predictions=predictions, references=labels)
48
+
49
+
50
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
51
+ def preprocess_function(examples):
52
+ return tokenizer(examples["text"], truncation=True)
53
+ tokenized_imdb = imdb.map(preprocess_function, batched=True)
54
+
55
+ data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
56
+ accuracy = evaluate.load("accuracy")
57
+
58
+ model = AutoModelForSequenceClassification.from_pretrained(
59
+ model_name, num_labels=2, id2label=id2label, label2id=label2id
60
+ )
61
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
62
+ model = model.to(device)
63
+
64
+
65
+ bts = 8
66
+ accumulated_step = 2
67
+ training_args = TrainingArguments(
68
+ output_dir=f"5imdb_{model_name.replace('-','_')}",
69
+ learning_rate=2e-5,
70
+ per_device_train_batch_size=bts,
71
+ per_device_eval_batch_size=bts,
72
+ num_train_epochs=2,
73
+ weight_decay=0.01,
74
+ evaluation_strategy="epoch",
75
+ save_strategy="epoch",
76
+ load_best_model_at_end=True,
77
+ push_to_hub=True,
78
+ gradient_accumulation_steps=accumulated_step,
79
+ )
80
+ # 创建 EarlyStoppingCallback 回调
81
+ early_stopping = EarlyStoppingCallback(early_stopping_patience=3)
82
+ trainer = Trainer(
83
+ model=model,
84
+ args=training_args,
85
+ train_dataset=tokenized_imdb["train"],
86
+ eval_dataset=tokenized_imdb["test"],
87
+ tokenizer=tokenizer,
88
+ data_collator=data_collator,
89
+ compute_metrics=compute_metrics,
90
+ callbacks=[early_stopping],
91
+ )
92
+
93
+ trainer.train()
94
+ ```
95
 
96
  ## Intended uses & limitations
97