|
from datasets import load_dataset, concatenate_datasets |
|
from sentence_transformers.losses import CosineSimilarityLoss |
|
from setfit import SetFitModel, SetFitTrainer |
|
|
|
|
|
# Load the full AG News dataset (4 classes: World / Sports / Business / Sci-Tech).
dataset = load_dataset("ag_news")

# Few-shot setup: keep only a small, class-balanced sample from each split.
seed = 20
labels = 4
samples_per_label = 8


def _sample_per_label(split, num_labels, per_label, shuffle_seed):
    """Return a class-balanced subset of `split`.

    For each class id in ``range(num_labels)``: filter to that class,
    shuffle deterministically with ``shuffle_seed``, and keep the first
    ``per_label`` rows; the per-class subsets are then concatenated.

    Args:
        split: a ``datasets.Dataset`` with an integer ``label`` column.
        num_labels: number of distinct class ids (0-based, contiguous).
        per_label: how many examples to keep per class.
        shuffle_seed: seed for the deterministic shuffle.

    Returns:
        A ``datasets.Dataset`` with ``num_labels * per_label`` rows.
    """
    subsets = [
        split.filter(lambda x, label=i: x["label"] == label)  # bind i via default arg
        .shuffle(seed=shuffle_seed)
        .select(range(per_label))
        for i in range(num_labels)
    ]
    return concatenate_datasets(subsets)


# 32 training examples (8 per class) and an equally sized test sample.
train_dataset = _sample_per_label(dataset["train"], labels, samples_per_label, seed)
test_dataset = _sample_per_label(dataset["test"], labels, samples_per_label, seed)
|
|
|
|
|
# Pretrained sentence-embedding backbone; SetFit wraps it with a classification head.
model_id = "sentence-transformers/all-mpnet-base-v2"

model = SetFitModel.from_pretrained(model_id)
|
|
|
|
|
# Configure the SetFit trainer: contrastive fine-tuning of the sentence
# embeddings (pairs scored with cosine similarity), then a classification head.
trainer = SetFitTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    loss_class=CosineSimilarityLoss,  # pairwise loss for embedding fine-tuning
    metric="accuracy",                # metric reported by trainer.evaluate()
    batch_size=64,
    num_iterations=20,                # text pairs generated per training example
    num_epochs=1,
)
|
|
|
|
|
# Fine-tune the embeddings and fit the head, then score on the sampled test split.
trainer.train()

metrics = trainer.evaluate()


# Summarize the run: backbone id, few-shot training size, and test accuracy.
print(f"model used: {model_id}")

print(f"train dataset: {len(train_dataset)} samples")

print(f"accuracy: {metrics['accuracy']}")
|
|
|
|
|
# Persist the fine-tuned model (backbone + head) to a local directory.
trainer.model.save_pretrained("my_first_test")




# Reload from disk to confirm the saved artifacts round-trip.
model = SetFitModel.from_pretrained("my_first_test")
|
|
|
# Run inference on two example sentences; the model returns integer class ids.
preds = model(["i loved France!", "pineapple on pizza is the worst when watching football"])

# Map AG News class ids back to human-readable label names.
label = {'0': 'World', '1': 'Sports', '2': 'Business', '3': 'Sci/Tech'}

# .item() extracts the Python int from each predicted tensor element.
output = [label[str(tt.item())] for tt in preds]

# Fix: `output` was computed but never shown, and the leftover debug
# variable `q = 1` served no purpose — print the predictions instead.
print(f"predictions: {output}")
|
|