Sanket committed
Commit 98989c5
1 parent: ed18734

increased input features

Files changed (1):
  app.py +64 -35
app.py CHANGED
@@ -6,18 +6,20 @@ from PIL import Image
 
 norm_layer = nn.InstanceNorm2d
 
+
 class ResidualBlock(nn.Module):
     def __init__(self, in_features):
         super(ResidualBlock, self).__init__()
 
-        conv_block = [  nn.ReflectionPad2d(1),
-                        nn.Conv2d(in_features, in_features, 3),
-                        norm_layer(in_features),
-                        nn.ReLU(inplace=True),
-                        nn.ReflectionPad2d(1),
-                        nn.Conv2d(in_features, in_features, 3),
-                        norm_layer(in_features)
-                        ]
+        conv_block = [
+            nn.ReflectionPad2d(1),
+            nn.Conv2d(in_features, in_features, 3),
+            norm_layer(in_features),
+            nn.ReLU(inplace=True),
+            nn.ReflectionPad2d(1),
+            nn.Conv2d(in_features, in_features, 3),
+            norm_layer(in_features),
+        ]
 
         self.conv_block = nn.Sequential(*conv_block)
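Aside from one extra blank line, this first hunk only reformats conv_block (one module per line, trailing comma); the layers themselves are unchanged. Since ReflectionPad2d(1) followed by a 3x3 convolution preserves spatial size, the block maps a tensor to one of identical shape and can be stacked freely. A minimal self-contained sketch, assuming the conventional residual forward of x + conv_block(x) (the forward method sits outside this hunk):

    import torch
    import torch.nn as nn

    norm_layer = nn.InstanceNorm2d

    class ResidualBlock(nn.Module):
        def __init__(self, in_features):
            super(ResidualBlock, self).__init__()
            conv_block = [
                nn.ReflectionPad2d(1),
                nn.Conv2d(in_features, in_features, 3),  # pad 1 + 3x3 conv keeps H and W
                norm_layer(in_features),
                nn.ReLU(inplace=True),
                nn.ReflectionPad2d(1),
                nn.Conv2d(in_features, in_features, 3),
                norm_layer(in_features),
            ]
            self.conv_block = nn.Sequential(*conv_block)

        def forward(self, x):
            # Assumed skip connection; not shown in this diff.
            return x + self.conv_block(x)

    out = ResidualBlock(256)(torch.randn(1, 256, 64, 64))
    assert out.shape == (1, 256, 64, 64)  # shape preserved, so blocks chain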
 
@@ -30,22 +32,26 @@ class Generator(nn.Module):
         super(Generator, self).__init__()
 
         # Initial convolution block
-        model0 = [  nn.ReflectionPad2d(3),
-                    nn.Conv2d(input_nc, 64, 7),
-                    norm_layer(64),
-                    nn.ReLU(inplace=True) ]
+        model0 = [
+            nn.ReflectionPad2d(3),
+            nn.Conv2d(input_nc, 256, 7),
+            norm_layer(256),
+            nn.ReLU(inplace=True),
+        ]
         self.model0 = nn.Sequential(*model0)
 
         # Downsampling
         model1 = []
         in_features = 256
-        out_features = in_features*2
+        out_features = in_features * 2
         for _ in range(2):
-            model1 += [ nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
-                        norm_layer(out_features),
-                        nn.ReLU(inplace=True) ]
+            model1 += [
+                nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
+                norm_layer(out_features),
+                nn.ReLU(inplace=True),
+            ]
             in_features = out_features
-            out_features = in_features*2
+            out_features = in_features * 2
         self.model1 = nn.Sequential(*model1)
 
         model2 = []
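This hunk carries the substance of the commit message. Before, the stem emitted 64 channels while the downsampling loop below it started from in_features = 256, so the first stride-2 convolution expected 256-channel input that never arrived; widening the stem to 256 makes the two consistent. Tracing the loop variables as written in the diff:

    in_features = 256
    out_features = in_features * 2
    for _ in range(2):
        print(f"Conv2d({in_features} -> {out_features}, stride=2)")
        in_features = out_features
        out_features = in_features * 2
    # Conv2d(256 -> 512, stride=2)
    # Conv2d(512 -> 1024, stride=2)

After the loop, in_features is 1024, so the residual stage built next (model2, outside this hunk) presumably runs at 1024 channels.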
@@ -56,18 +62,21 @@ class Generator(nn.Module):
 
         # Upsampling
         model3 = []
-        out_features = in_features//2
+        out_features = in_features // 2
         for _ in range(2):
-            model3 += [ nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
-                        norm_layer(out_features),
-                        nn.ReLU(inplace=True) ]
+            model3 += [
+                nn.ConvTranspose2d(
+                    in_features, out_features, 3, stride=2, padding=1, output_padding=1
+                ),
+                norm_layer(out_features),
+                nn.ReLU(inplace=True),
+            ]
             in_features = out_features
-            out_features = in_features//2
+            out_features = in_features // 2
         self.model3 = nn.Sequential(*model3)
 
         # Output layer
-        model4 = [  nn.ReflectionPad2d(3),
-                    nn.Conv2d(64, output_nc, 7)]
+        model4 = [nn.ReflectionPad2d(3), nn.Conv2d(256, output_nc, 7)]
         if sigmoid:
             model4 += [nn.Sigmoid()]
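The upsampling path mirrors the widened encoder, halving channels twice (1024 -> 512 -> 256), which is why the output layer's Conv2d now takes 256 input channels instead of 64. With kernel 3, stride 2, padding 1, and output_padding 1, each ConvTranspose2d exactly doubles spatial size: out = (in - 1)*2 - 2*1 + (3 - 1) + 1 + 1 = 2*in. A quick check:

    import torch
    import torch.nn as nn

    up = nn.ConvTranspose2d(1024, 512, 3, stride=2, padding=1, output_padding=1)
    x = torch.randn(1, 1024, 17, 17)  # odd size on purpose
    print(up(x).shape)                # torch.Size([1, 512, 34, 34])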
 
@@ -82,23 +91,27 @@ class Generator(nn.Module):
 
         return out
 
+
 model1 = Generator(3, 1, 3)
-model1.load_state_dict(torch.load('model.pth', map_location=torch.device('cpu')))
+model1.load_state_dict(torch.load("model.pth", map_location=torch.device("cpu")))
 model1.eval()
 
 model2 = Generator(3, 1, 3)
-model2.load_state_dict(torch.load('model2.pth', map_location=torch.device('cpu')))
+model2.load_state_dict(torch.load("model2.pth", map_location=torch.device("cpu")))
 model2.eval()
 
+
 def predict(input_img, ver):
     input_img = Image.open(input_img)
-    transform = transforms.Compose([transforms.Resize(1080, Image.BICUBIC), transforms.ToTensor()])
+    transform = transforms.Compose(
+        [transforms.Resize(1080, Image.BICUBIC), transforms.ToTensor()]
+    )
     input_img = transform(input_img)
     input_img = torch.unsqueeze(input_img, 0)
 
     drawing = 0
     with torch.no_grad():
-        if ver == 'Simple Lines':
+        if ver == "Simple Lines":
             drawing = model2(input_img)[0].detach()
         else:
             drawing = model1(input_img)[0].detach()
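One operational note: widening the Generator changes every parameter shape downstream of the stem, so model.pth and model2.pth must be checkpoints exported for the new architecture; load_state_dict is strict by default and raises a RuntimeError on any shape mismatch with weights saved for the old 64-channel layout. A guarded load, sketched for illustration (the try/except is not part of the app):

    import torch

    model = Generator(3, 1, 3)
    try:
        state = torch.load("model.pth", map_location=torch.device("cpu"))
        model.load_state_dict(state)  # strict by default: keys and shapes must match
    except RuntimeError as err:
        print(f"Checkpoint does not match the widened Generator: {err}")
    model.eval()

Also note that transforms.Resize(1080, Image.BICUBIC), given a single int, scales the shorter edge to 1080 while keeping the aspect ratio, so input tensors vary in width and height; the fully convolutional Generator handles that.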
@@ -106,14 +119,30 @@ def predict(input_img, ver):
     drawing = transforms.ToPILImage()(drawing)
     return drawing
 
-title="Image to Line Drawings - Complex and Simple Portraits and Landscapes"
-examples=[
-['01.jpg', 'Complex Lines'], ['02.jpg', 'Complex Lines'], ['03.jpg', 'Simple Lines'],
-['04.jpg', 'Simple Lines'], ['05.jpg', 'Simple Lines'],
+
+title = "Image to Line Drawings - Complex and Simple Portraits and Landscapes"
+examples = [
+    ["01.jpg", "Complex Lines"],
+    ["02.jpg", "Simple Lines"],
+    ["03.jpg", "Simple Lines"],
+    ["04.jpg", "Simple Lines"],
+    ["05.jpg", "Simple Lines"],
 ]
 
-iface = gr.Interface(predict, [gr.inputs.Image(type='filepath'),
-                    gr.inputs.Radio(['Complex Lines','Simple Lines'], type="value", default='Simple Lines', label='version')],
-                    gr.outputs.Image(type="pil"), title=title,examples=examples)
+iface = gr.Interface(
+    predict,
+    [
+        gr.inputs.Image(type="filepath"),
+        gr.inputs.Radio(
+            ["Complex Lines", "Simple Lines"],
+            type="value",
+            default="Simple Lines",
+            label="version",
+        ),
+    ],
+    gr.outputs.Image(type="pil"),
+    title=title,
+    examples=examples,
+)
 
 iface.launch()
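The gr.inputs / gr.outputs namespaces and the default= keyword used here belong to the legacy Gradio API and were removed in later major releases. The reformatting keeps that API, which works as long as the Space pins an old Gradio; on a current release the equivalent wiring would look roughly like this (a sketch, not part of the commit):

    import gradio as gr

    iface = gr.Interface(
        fn=predict,
        inputs=[
            gr.Image(type="filepath"),
            gr.Radio(["Complex Lines", "Simple Lines"], value="Simple Lines", label="version"),
        ],
        outputs=gr.Image(type="pil"),
        title=title,
        examples=examples,
    )
    iface.launch()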
 