Lim0011's picture
Upload 251 files
85e3d20 verified
raw
history blame
538 Bytes
import os
import pandas as pd
from torchvision import datasets
def get_score(submission_folder = "../env"):
submission_path = os.path.join(submission_folder, "submission.csv")
submission = pd.read_csv(submission_path, index_col=0)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True)
acc = 0
for idx, (x, y) in enumerate(test_dataset):
pred = submission.loc[idx].argmax()
acc += int(pred == y)
return acc/len(test_dataset)
if __name__ == "__main__":
print(get_score())