|
from transformers import Tool |
|
from huggingface_hub import list_models |
|
|
|
|
|
class HFModelDownloadsTool(Tool):
    """Tool that returns the most downloaded Hub checkpoint for a given task.

    Calling an instance with a task/category name queries the Hugging Face Hub
    model listing, sorted by download count in descending order, and returns
    the identifier of the first result.
    """

    name = "model_download_counter"
    description = (
        "This is a tool that returns the most downloaded model of a given task on the Hugging Face Hub. "
        "It takes the name of the category (such as text-classification, depth-estimation, etc), and "
        "returns the name of the checkpoint."
    )

    # NOTE(review): presumably the input/output modalities read by the `Tool`
    # base class -- confirm against the installed transformers version.
    inputs = ["text"]
    outputs = ["text"]

    def __call__(self, task: str) -> str:
        """Return the id of the most downloaded model matching ``task``.

        Args:
            task: Name of the task/category used as the Hub filter
                (e.g. ``"text-classification"``).

        Returns:
            The ``id`` attribute of the top-ranked model.

        Raises:
            ValueError: If the Hub returns no model for ``task``.
        """
        # Only the first entry is needed, so cap the listing at one result
        # instead of letting the client fetch/paginate the full model list.
        candidates = list_models(filter=task, sort="downloads", direction=-1, limit=1)

        # Use a default so an empty listing does not leak a bare StopIteration
        # (which would turn into RuntimeError inside a generator, per PEP 479).
        top_model = next(iter(candidates), None)
        if top_model is None:
            raise ValueError(f"No model found on the Hugging Face Hub for task {task!r}.")
        return top_model.id
|
|