"""Zero-shot text classification example using a GLiClass model.

Loads the ``knowledgator/gliclass-large-v1.0`` checkpoint together with its
tokenizer, runs a multi-label zero-shot classification pipeline on a single
sentence against a fixed list of candidate labels, and prints every returned
label with its score.
"""

from gliclass import GLiClassModel, ZeroShotClassificationPipeline
from transformers import AutoTokenizer

# The model and the tokenizer are loaded from the same checkpoint, so the
# identifier is named once instead of being repeated.
MODEL_NAME = "knowledgator/gliclass-large-v1.0"

model = GLiClassModel.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# NOTE(review): the device is hard-coded to the first CUDA GPU; confirm how
# the pipeline behaves on a machine without CUDA before relying on this.
pipeline = ZeroShotClassificationPipeline(
    model,
    tokenizer,
    classification_type="multi-label",
    device="cuda:0",
)

text = "One day I will see the world!"
labels = ["travel", "dreams", "sport", "science", "politics"]

# The pipeline call is indexed with [0] to get the results for ``text``;
# presumably it returns one entry per input text -- verify against the
# gliclass documentation.
results = pipeline(text, labels, threshold=0.5)[0]

for result in results:
    print(result["label"], "=>", result["score"])