shyamgupta196 commited on
Commit
c27b177
·
1 Parent(s): bd0fc79

map to cpu

Browse files
Files changed (2) hide show
  1. CatVsDogTrain.py +3 -3
  2. app.py +1 -1
CatVsDogTrain.py CHANGED
@@ -224,7 +224,7 @@ def train_and_validate(model, loss_criterion, optimizer, epochs=25):
224
  )
225
 
226
  # Save if the model has best accuracy till now
227
- torch.save(model, "TrainLoopImproveCatsDogs.pth")
228
 
229
  return model, history
230
 
@@ -258,7 +258,7 @@ history = []
258
  def test():
259
  test = datasets.ImageFolder(root="PetTest/", transform=convert)
260
  testLoader = DataLoader(test, batch_size=16, shuffle=False)
261
- checkpoint = torch.load("catsvdogs.pth")
262
  alexnet.load_state_dict(checkpoint["state_dict"])
263
  optimizer.load_state_dict(checkpoint["optimizer"])
264
  for params in alexnet.parameters():
@@ -388,7 +388,7 @@ def predict(model, test_image_name):
388
 
389
  if PREDICT:
390
  checkpoint = torch.load(
391
- "ImprovedCatVsDogsModel.pth", map_location=torch.device("cpu")
392
  )
393
  alexnet.load_state_dict(checkpoint["state_dict"])
394
  alexnet = alexnet.to(device)
 
224
  )
225
 
226
  # Save if the model has best accuracy till now
227
+ torch.save(model, "CatVsDogsModel.pth")
228
 
229
  return model, history
230
 
 
258
  def test():
259
  test = datasets.ImageFolder(root="PetTest/", transform=convert)
260
  testLoader = DataLoader(test, batch_size=16, shuffle=False)
261
+ checkpoint = torch.load("CatVsDogsModel.pth")
262
  alexnet.load_state_dict(checkpoint["state_dict"])
263
  optimizer.load_state_dict(checkpoint["optimizer"])
264
  for params in alexnet.parameters():
 
388
 
389
  if PREDICT:
390
  checkpoint = torch.load(
391
+ "CatVsDogsModel.pth", map_location=torch.device("cpu")
392
  )
393
  alexnet.load_state_dict(checkpoint["state_dict"])
394
  alexnet = alexnet.to(device)
app.py CHANGED
@@ -5,7 +5,7 @@ from timm.data import resolve_data_config
5
  from timm.data.transforms_factory import create_transform
6
 
7
  LABELS = {0:'Cat', 1:'Dog'}
8
- model = torch.load('CatVsDogsModel.pth')
9
  model.eval()
10
  transform = create_transform(**resolve_data_config({},model=model))
11
 
 
5
  from timm.data.transforms_factory import create_transform
6
 
7
  LABELS = {0:'Cat', 1:'Dog'}
8
+ model = torch.load('CatVsDogsModel.pth',map_location='cpu')
9
  model.eval()
10
  transform = create_transform(**resolve_data_config({},model=model))
11