Spaces:
Build error
Build error
shyamgupta196
commited on
Commit
·
c27b177
1
Parent(s):
bd0fc79
map to cpu
Browse files- CatVsDogTrain.py +3 -3
- app.py +1 -1
CatVsDogTrain.py
CHANGED
@@ -224,7 +224,7 @@ def train_and_validate(model, loss_criterion, optimizer, epochs=25):
|
|
224 |
)
|
225 |
|
226 |
# Save if the model has best accuracy till now
|
227 |
-
torch.save(model, "
|
228 |
|
229 |
return model, history
|
230 |
|
@@ -258,7 +258,7 @@ history = []
|
|
258 |
def test():
|
259 |
test = datasets.ImageFolder(root="PetTest/", transform=convert)
|
260 |
testLoader = DataLoader(test, batch_size=16, shuffle=False)
|
261 |
-
checkpoint = torch.load("
|
262 |
alexnet.load_state_dict(checkpoint["state_dict"])
|
263 |
optimizer.load_state_dict(checkpoint["optimizer"])
|
264 |
for params in alexnet.parameters():
|
@@ -388,7 +388,7 @@ def predict(model, test_image_name):
|
|
388 |
|
389 |
if PREDICT:
|
390 |
checkpoint = torch.load(
|
391 |
-
"
|
392 |
)
|
393 |
alexnet.load_state_dict(checkpoint["state_dict"])
|
394 |
alexnet = alexnet.to(device)
|
|
|
224 |
)
|
225 |
|
226 |
# Save if the model has best accuracy till now
|
227 |
+
torch.save(model, "CatVsDogsModel.pth")
|
228 |
|
229 |
return model, history
|
230 |
|
|
|
258 |
def test():
|
259 |
test = datasets.ImageFolder(root="PetTest/", transform=convert)
|
260 |
testLoader = DataLoader(test, batch_size=16, shuffle=False)
|
261 |
+
checkpoint = torch.load("CatVsDogsModel.pth")
|
262 |
alexnet.load_state_dict(checkpoint["state_dict"])
|
263 |
optimizer.load_state_dict(checkpoint["optimizer"])
|
264 |
for params in alexnet.parameters():
|
|
|
388 |
|
389 |
if PREDICT:
|
390 |
checkpoint = torch.load(
|
391 |
+
"CatVsDogsModel.pth", map_location=torch.device("cpu")
|
392 |
)
|
393 |
alexnet.load_state_dict(checkpoint["state_dict"])
|
394 |
alexnet = alexnet.to(device)
|
app.py
CHANGED
@@ -5,7 +5,7 @@ from timm.data import resolve_data_config
|
|
5 |
from timm.data.transforms_factory import create_transform
|
6 |
|
7 |
LABELS = {0:'Cat', 1:'Dog'}
|
8 |
-
model = torch.load('CatVsDogsModel.pth')
|
9 |
model.eval()
|
10 |
transform = create_transform(**resolve_data_config({},model=model))
|
11 |
|
|
|
5 |
from timm.data.transforms_factory import create_transform
|
6 |
|
7 |
LABELS = {0:'Cat', 1:'Dog'}
|
8 |
+
model = torch.load('CatVsDogsModel.pth',map_location='cpu')
|
9 |
model.eval()
|
10 |
transform = create_transform(**resolve_data_config({},model=model))
|
11 |
|