File size: 2,056 Bytes
6a06532 c2637f8 6a06532 043aa62 6a06532 043aa62 6a06532 043aa62 6a06532 043aa62 6a06532 cbecf0e 6a06532 043aa62 6a06532 043aa62 cbecf0e 043aa62 cbecf0e 043aa62 cbecf0e 6a06532 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
from datasets import load_dataset, concatenate_datasets
from sentence_transformers.losses import CosineSimilarityLoss
from setfit import SetFitModel, SetFitTrainer
def _sample_per_label(split, num_labels, per_label, seed):
    """Return a class-balanced subset of ``split``.

    For every label id in ``range(num_labels)`` the rows whose ``"label"``
    equals that id are kept, shuffled with ``seed`` and cut down to the first
    ``per_label`` rows. The per-label subsets are concatenated in label order.

    Args:
        split: Dataset split exposing ``filter``/``shuffle``/``select`` and a
            ``"label"`` column holding ids in ``0..num_labels - 1``.
        num_labels: Number of distinct label ids to sample.
        per_label: Number of rows to keep for each label id.
        seed: Seed forwarded to ``shuffle`` so the sample is reproducible.

    Returns:
        The concatenation of the per-label samples.
    """
    per_label_subsets = [
        # The filter runs immediately, so closing over ``label_id`` is safe.
        split.filter(lambda row: row["label"] == label_id)
        .shuffle(seed=seed)
        .select(range(per_label))
        for label_id in range(num_labels)
    ]
    return concatenate_datasets(per_label_subsets)


# Load the AG News dataset from the Hugging Face Hub
dataset = load_dataset("ag_news")

# Sampling configuration shared by the train and test subsets
seed = 20
labels = 4
samples_per_label = 8

# Few-shot training set: `samples_per_label` examples of each label
train_dataset = _sample_per_label(dataset["train"], labels, samples_per_label, seed)

# Evaluation set built the same way from the test split
test_dataset = _sample_per_label(dataset["test"], labels, samples_per_label, seed)
# Checkpoint used as the body of the SetFit model
model_id = "sentence-transformers/all-mpnet-base-v2"
model = SetFitModel.from_pretrained(model_id)

# Gather the trainer settings in one mapping, then build the trainer from it
trainer_settings = {
    "model": model,
    "train_dataset": train_dataset,
    "eval_dataset": test_dataset,
    "loss_class": CosineSimilarityLoss,
    "metric": "accuracy",
    "batch_size": 64,
    # The number of text pairs to generate for contrastive learning
    "num_iterations": 20,
    # The number of epochs to use for contrastive learning
    "num_epochs": 1,
}
trainer = SetFitTrainer(**trainer_settings)
# Fit the model, then compute the evaluation metrics
trainer.train()
metrics = trainer.evaluate()

# Report which model was used, the training-set size and the accuracy
print(f"model used: {model_id}")
train_size = len(train_dataset)
print(f"train dataset: {train_size} samples")
accuracy = metrics["accuracy"]
print(f"accuracy: {accuracy}")
# Save the trained model to a local directory (nothing is pushed to the Hub here)
save_dir = "my_first_test"
trainer.model.save_pretrained(save_dir)

# Reload the model from the path that was just written
model = SetFitModel.from_pretrained(save_dir)

# Run inference on two example sentences
preds = model(["i loved France!", "pineapple on pizza is the worst when watching football"])

# Translate predicted class ids into label names; int keys avoid a str() round trip
label = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}
output = [label[pred.item()] for pred in preds]
|