---
license: apache-2.0
datasets:
- dataautogpt3/Dalle3
- scrapegraphai/AQL-v1-QA
language:
- en
metrics:
- accuracy
base_model:
- microsoft/resnet-50
new_version: microsoft/resnet-50
pipeline_tag: image-classification
---
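```python
import torch

def load_model(model_path, num_classes):
    # create_model (not defined in this snippet) must rebuild the architecture the weights were saved from
    model = create_model(num_classes)
    # map_location="cpu" lets a GPU-saved checkpoint load on CPU-only machines
    model.load_state_dict(torch.load(model_path, map_location="cpu"))
    model.eval()  # inference mode: batch norm switches to its running statistics
    return model
```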
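`create_model` is not defined above. The following is a minimal sketch of one possible definition, assuming the checkpoint is a plain PyTorch `state_dict` saved from a Hugging Face `transformers` `ResNetForImageClassification` model with the `microsoft/resnet-50` architecture and a `num_classes`-way head. The `ResNetConfig` defaults describe the ResNet-50 layout, so nothing is downloaded. If the weights were instead saved from another library (for example torchvision or timm), the parameter names will not match and `create_model` has to build that library's model instead.

```python
import torch
from transformers import ResNetConfig, ResNetForImageClassification


def create_model(num_classes):
    # Assumption: the checkpoint comes from a transformers ResNetForImageClassification.
    # ResNetConfig's defaults describe the ResNet-50 layout; num_labels sizes the
    # final linear classifier. Weights start random and are replaced by load_model.
    config = ResNetConfig(num_labels=num_classes)
    return ResNetForImageClassification(config)


def load_model(model_path, num_classes):
    model = create_model(num_classes)
    # map_location="cpu" lets a GPU-saved checkpoint load on CPU-only machines
    state_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()  # inference mode: batch norm switches to its running statistics
    return model
```

For example, `model = load_model("model.pt", num_classes=10)` followed by `model(pixel_values=batch).logits`, where `batch` is a preprocessed float tensor of shape `(N, 3, 224, 224)`, returns one score per class for each image. The file name and class count here are placeholders.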