vincenttruum committed on
Commit 78223e6
1 Parent(s): 104ec88
Files changed (1)
  1. test_revised.py +47 -0
test_revised.py ADDED
@@ -0,0 +1,47 @@
+ from datasets import load_dataset, concatenate_datasets
+ from setfit import SetFitModel, SetFitTrainer
+ from sentence_transformers.losses import CosineSimilarityLoss
+
+
+ # Load the dataset
+ dataset = load_dataset("ag_news")
+
+ # Create the few-shot train dataset
+ seed = 20
+ labels = 4
+ samples_per_label = 8
+ sampled_datasets = []
+ # Sample the same number of examples for each label
+ for i in range(labels):
+     sampled_datasets.append(dataset["train"].filter(lambda x: x["label"] == i).shuffle(seed=seed).select(range(samples_per_label)))
+
+ # Concatenate the per-label samples into one train dataset
+ train_dataset = concatenate_datasets(sampled_datasets)
+
+ # Create test dataset
+ test_dataset = dataset["test"]
+
+ # Load a SetFit model from the Hub
+ model_id = "sentence-transformers/all-mpnet-base-v2"
+ model = SetFitModel.from_pretrained(model_id)
+
+ # Create trainer
+ trainer = SetFitTrainer(
+     model=model,
+     train_dataset=train_dataset,
+     eval_dataset=test_dataset,
+     loss_class=CosineSimilarityLoss,
+     metric="accuracy",
+     batch_size=64,
+     num_iterations=20,  # The number of text pairs to generate for contrastive learning
+     num_epochs=1,  # The number of epochs to use for contrastive learning
+ )
+
+ # Train and evaluate
+ trainer.train()
+ metrics = trainer.evaluate()
+
+ print(f"model used: {model_id}")
+ print(f"train dataset: {len(train_dataset)} samples")
+ print(f"accuracy: {metrics['accuracy']}")
+ print("Test")
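
The committed script only trains and prints evaluation metrics. As a minimal follow-up sketch (not part of this commit), the fine-tuned model could be saved and reused for inference via SetFit's standard save_pretrained/from_pretrained/predict API; the output directory name "setfit-ag-news" and the example headlines below are hypothetical.

# Sketch, assuming the script above has already run and trainer.model holds the fine-tuned SetFitModel.
# "setfit-ag-news" is a hypothetical local output directory.
trainer.model.save_pretrained("setfit-ag-news")

# Reload and classify a few headlines; predictions are ag_news label ids
# (0 = World, 1 = Sports, 2 = Business, 3 = Sci/Tech).
model = SetFitModel.from_pretrained("setfit-ag-news")
preds = model.predict([
    "Wall Street rallies as tech earnings beat expectations",
    "The national team clinched the championship in overtime",
])
print(preds)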