# MyFirstModel/test.py
# Committed by VinceItsMe in 104ec88 ("Create test.py"); original size 1.34 kB.
from datasets import load_dataset,concatenate_datasets
from setfit import SetFitModel, SetFitTrainer
from sentence_transformers.losses import CosineSimilarityLoss
# Load the AG News dataset (train/test splits) from the Hugging Face Hub.
dataset = load_dataset("ag_news")
# Build a few-shot training set: `samples_per_label` examples drawn at random
# from each class, so every label is equally represented.
seed = 20
# Read the class count from the dataset schema instead of hard-coding it
# (the "label" column is a ClassLabel; AG News has 4 classes).
labels = dataset["train"].features["label"].num_classes
samples_per_label = 8
sampled_datasets = []
for i in range(labels):
    # Pass the target label through fn_kwargs instead of closing over the
    # loop variable `i`, so each filter call's label is explicit (and is part
    # of the datasets cache fingerprint). input_columns="label" hands the
    # function only the label value, so the news text is not read per row.
    per_label = dataset["train"].filter(
        lambda label, target: label == target,
        input_columns="label",
        fn_kwargs={"target": i},
    )
    # Shuffle deterministically, then keep the first `samples_per_label` rows.
    sampled_datasets.append(
        per_label.shuffle(seed=seed).select(range(samples_per_label))
    )
# Concatenate the per-label samples into a single training dataset.
train_dataset = concatenate_datasets(sampled_datasets)
# Evaluate on the full AG News test split.
test_dataset = dataset["test"]
# Load a pretrained sentence-transformers checkpoint from the Hub as a SetFit model.
model_id = "sentence-transformers/all-mpnet-base-v2"
model = SetFitModel.from_pretrained(model_id)
# Create the trainer. `seed` is the same value used to sample the training set;
# passing it here makes training use that seed too. Without it, SetFitTrainer
# falls back to its own default seed (42 per the setfit docs).
trainer = SetFitTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    loss_class=CosineSimilarityLoss,
    metric="accuracy",
    batch_size=64,
    num_iterations=20,  # iterations of sentence-pair generation (not the total number of pairs)
    num_epochs=1,  # number of epochs for contrastive fine-tuning
    seed=seed,
)
# Fine-tune the model, then score it on the evaluation split.
trainer.train()
metrics = trainer.evaluate()

# Summarize the run: one line each for model, training-set size and accuracy.
summary_lines = (
    f"model used: {model_id}",
    f"train dataset: {len(train_dataset)} samples",
    f"accuracy: {metrics['accuracy']}",
)
print(*summary_lines, sep="\n")