shyamgupta196 commited on
Commit
ed21add
1 Parent(s): 098663c

added pipeline

Browse files
Files changed (1) hide show
  1. app.py +33 -3
app.py CHANGED
@@ -2,15 +2,45 @@ import requests
2
  import gradio as gr
3
  import torch
4
  from timm.data import resolve_data_config
5
- from timm.data.transforms_factory import create_transform
 
 
6
 
7
  LABELS = {0:'Cat', 1:'Dog'}
8
- model = torch.load('CatVsDogsModel.pth',map_location='cpu')
9
 
10
- transform = create_transform(**resolve_data_config({},model=model))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
 
13
  def predict(img):
 
14
  img = img.convert('RGB')
15
  img = transform(img).unsqueeze(0)
16
  with torch.no_grad():
 
2
  import gradio as gr
3
  import torch
4
  from timm.data import resolve_data_config
5
+ from torchvision.models import alexnet
6
+ import torch.nn as nn
7
+ from torchvision import transforms
8
 
9
  LABELS = {0:'Cat', 1:'Dog'}
 
10
 
11
+
12
+ model = alexnet(pretrained=True)
13
+ for param in model.parameters():
14
+ param.requires_grad = False
15
+
16
+ # Add a avgpool here
17
+ avgpool = nn.AdaptiveAvgPool2d((7, 7))
18
+
19
+ # Replace the classifier layer
20
+ # to customise it according to our output
21
+ model.classifier = nn.Sequential(
22
+ nn.Linear(256 * 7 * 7, 1024),
23
+ nn.Linear(1024, 256),
24
+ nn.Linear(256, 2),
25
+ )
26
+
27
+ checkpoint = torch.load(
28
+ "CatVsDogsModel.pth", map_location=torch.device("cpu")
29
+ )
30
+
31
+ model.load_state_dict(checkpoint["state_dict"])
32
+ model = model.to('cpu')
33
+
34
+
35
+
36
+ transform = transforms.Compose(
37
+ [transforms.Resize((128, 128)), transforms.ToTensor()]
38
+ )
39
+
40
 
41
 
42
  def predict(img):
43
+ img = transform(img).to('cpu')
44
  img = img.convert('RGB')
45
  img = transform(img).unsqueeze(0)
46
  with torch.no_grad():