Download the checkpoint from the ModelScope hub and import the dependencies:

```python
import torch
import torch.nn as nn
from modelscope import snapshot_download
from torchvision.models import squeezenet1_1

# Fetch the "ccmusic-database/pianos" checkpoint into a local cache directory
MODEL_DIR = snapshot_download(
    "ccmusic-database/pianos",
    cache_dir="./__pycache__",
)
```
`Classifier` builds the classification head that replaces SqueezeNet's default classifier. Its hidden widths are interpolated geometrically between `output_size` and `cls_num`:

```python
def Classifier(cls_num=8, output_size=512, linear_output=False):
    # Geometric interpolation factor between the feature width and the class count
    q = (1.0 * output_size / cls_num) ** 0.25
    l1 = int(q * cls_num)
    l2 = int(q * l1)
    l3 = int(q * l2)
    if linear_output:
        # Fully-connected head: expects a flattened (N, output_size) feature vector
        return torch.nn.Sequential(
            nn.Dropout(),
            nn.Linear(output_size, l3),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(l3, l2),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(l2, l1),
            nn.ReLU(inplace=True),
            nn.Linear(l1, cls_num),
        )
    else:
        # Convolutional head: operates on the (N, output_size, H, W) feature map,
        # then pools to 1x1 before the fully-connected layers
        return torch.nn.Sequential(
            nn.Dropout(),
            nn.Conv2d(output_size, l3, kernel_size=(1, 1), stride=(1, 1)),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
            nn.Flatten(),
            nn.Linear(l3, l2),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(l2, l1),
            nn.ReLU(inplace=True),
            nn.Linear(l1, cls_num),
        )
```
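To make the width schedule concrete, the arithmetic below reproduces the default case (`cls_num=8`, `output_size=512`); it is only a sanity check, not part of the model:

```python
cls_num, output_size = 8, 512
q = (output_size / cls_num) ** 0.25  # 64 ** 0.25 ≈ 2.83
l1 = int(q * cls_num)  # 22
l2 = int(q * l1)       # 62
l3 = int(q * l2)       # 175
print(output_size, l3, l2, l1, cls_num)  # the head narrows 512 -> 175 -> 62 -> 22 -> 8
```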
`net` assembles a SqueezeNet 1.1 backbone with the custom head and loads the downloaded weights on the CPU:

```python
def net(weights=f"{MODEL_DIR}/save.pt"):
    model = squeezenet1_1(pretrained=False)  # backbone only; weights come from the checkpoint
    model.classifier = Classifier()
    model.load_state_dict(torch.load(weights, map_location=torch.device("cpu")))
    model.eval()  # inference mode: disables the Dropout layers in the head
    return model
```
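The snippet below is a minimal inference sketch, not part of the original card: it assumes the checkpoint loads via `net()` above and that the input is a 3-channel 224x224 image tensor (e.g. a spectrogram rendered as an RGB image). Verify the exact preprocessing and class labels against the ccmusic-database/pianos dataset card.

```python
model = net()                   # loads {MODEL_DIR}/save.pt on the CPU
x = torch.rand(1, 3, 224, 224)  # placeholder batch standing in for a real, preprocessed image
with torch.no_grad():
    logits = model(x)           # shape: (1, 8), one score per each of the 8 classes
    probs = torch.softmax(logits, dim=1)
print(probs.argmax(dim=1))      # index of the predicted class
```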