VinceItsMe
committed on
Commit
•
104ec88
1
Parent(s):
0d9b461
Create test.py
Browse files
test.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Few-shot text classification on AG News with SetFit.

Builds a class-balanced 8-shot-per-label training split from the AG News
dataset, fine-tunes a SetFit model on it with a contrastive
(cosine-similarity) objective, and reports accuracy on the full test split.
"""

from datasets import load_dataset, concatenate_datasets
from setfit import SetFitModel, SetFitTrainer
from sentence_transformers.losses import CosineSimilarityLoss

# Experiment configuration.
SEED = 20                   # shuffle seed, for reproducible sampling
NUM_LABELS = 4              # AG News has 4 topic classes
SAMPLES_PER_LABEL = 8       # few-shot budget per class
MODEL_ID = "sentence-transformers/all-mpnet-base-v2"


def build_few_shot_train_set(dataset, num_labels=NUM_LABELS,
                             samples_per_label=SAMPLES_PER_LABEL, seed=SEED):
    """Return a class-balanced few-shot training set.

    For each label, filters the train split to that label, shuffles with
    the given seed, and keeps ``samples_per_label`` examples; the per-label
    subsets are then concatenated.
    """
    sampled_datasets = [
        dataset["train"]
        # Bind the loop variable as a default argument so the predicate does
        # not rely on late-binding of ``label`` (the original lambda only
        # worked because ``.filter`` evaluates eagerly).
        .filter(lambda x, label=label: x["label"] == label)
        .shuffle(seed=seed)
        .select(range(samples_per_label))
        for label in range(num_labels)
    ]
    return concatenate_datasets(sampled_datasets)


def main():
    """Train a SetFit model on the few-shot split and print its accuracy."""
    dataset = load_dataset("ag_news")

    train_dataset = build_few_shot_train_set(dataset)
    test_dataset = dataset["test"]

    # Load a pretrained sentence-transformer body from the Hub.
    model = SetFitModel.from_pretrained(MODEL_ID)

    trainer = SetFitTrainer(
        model=model,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        loss_class=CosineSimilarityLoss,
        metric="accuracy",
        batch_size=64,
        num_iterations=20,  # number of text pairs to generate for contrastive learning
        num_epochs=1,       # number of epochs to use for contrastive learning
    )

    # Train and evaluate.
    trainer.train()
    metrics = trainer.evaluate()

    print(f"model used: {MODEL_ID}")
    print(f"train dataset: {len(train_dataset)} samples")
    print(f"accuracy: {metrics['accuracy']}")


if __name__ == "__main__":
    main()