vincenttruum
commited on
Commit
•
cbecf0e
1
Parent(s):
ee0c2e6
test
Browse files- newtest.py +6 -4
newtest.py
CHANGED
@@ -40,7 +40,7 @@ trainer = SetFitTrainer(
|
|
40 |
loss_class=CosineSimilarityLoss,
|
41 |
metric="accuracy",
|
42 |
batch_size=64,
|
43 |
-
num_iterations=
|
44 |
num_epochs=1, # The number of epochs to use for constrastive learning
|
45 |
)
|
46 |
|
@@ -53,10 +53,12 @@ print(f"train dataset: {len(train_dataset)} samples")
|
|
53 |
print(f"accuracy: {metrics['accuracy']}")
|
54 |
|
55 |
# Push model to the Hub
|
56 |
-
trainer.
|
57 |
|
58 |
# Download from Hub and run inference
|
59 |
-
model = SetFitModel.from_pretrained("
|
60 |
# Run inference
|
61 |
-
preds = model(["i loved
|
|
|
|
|
62 |
q = 1
|
|
|
40 |
loss_class=CosineSimilarityLoss,
|
41 |
metric="accuracy",
|
42 |
batch_size=64,
|
43 |
+
num_iterations= 20, # The number of text pairs to generate for contrastive learning
|
44 |
num_epochs=1, # The number of epochs to use for constrastive learning
|
45 |
)
|
46 |
|
|
|
53 |
print(f"accuracy: {metrics['accuracy']}")
|
54 |
|
55 |
# Push model to the Hub
|
56 |
+
trainer.model.save_pretrained("my_first_test")
|
57 |
|
58 |
# Download from Hub and run inference
|
59 |
+
model = SetFitModel.from_pretrained("my_first_test")
|
60 |
# Run inference
|
61 |
+
preds = model(["i loved France!", "pineapple on pizza is the worst when watching football"])
|
62 |
+
label = {'0': 'World','1': 'Sports', '2': 'Business', '3': 'Sci/Tech'}
|
63 |
+
output = [label[str(tt.item())] for tt in preds]
|
64 |
q = 1
|