nikunjkdtechnoland committed
Commit ea71739
1 Parent(s): 7a6acb0

some changes

Files changed (3)
  1. app.py +3 -3
  2. only_gradio_server.py +0 -4
  3. trainer.py +0 -155
app.py CHANGED
@@ -9,10 +9,10 @@ options_list = list(object_names.values())
 
 # Create Gradio interface
 iface = gr.Interface(fn=process_images,
-                     inputs=[gr.Image(type='filepath', label='Main Image where object identify', width=300, height=300),
-                             gr.Image(type='filepath', label='Object Image which placed on Main Image', image_mode="RGBA", width=300, height=300),
+                     inputs=[gr.Image(type='filepath', label='Main Image where object identify', width="60%", height="60%"),
+                             gr.Image(type='filepath', label='Object Image which placed on Main Image (PNG file only RGBA Channel)', image_mode="RGBA", width="60%", height="60%"),
                              gr.Dropdown(options_list, label='Replace Object Name (Default = chair)')],
-                     outputs=gr.Image(type='numpy', label='Final Result', width=300, height=300),
+                     outputs=gr.Image(type='numpy', label='Final Result', width="70%", height="70%"),
                      title="AI Based Image Processing",
                      description="Object to Object Replacement (Note: due to limitation of free usage on this server task will take approx 3-5 minutes for process, But the actual speed of this process on pc or dedicated server is < 10 seconds)")
 
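Note: the commit replaces the fixed 300 px component sizes with CSS percentage strings so the components scale with the page. Below is a minimal, self-contained sketch of the interface as defined after this commit; the object_names mapping and process_images stub are hypothetical stand-ins for the real definitions elsewhere in the repo, and the percentage sizes assume a Gradio version whose gr.Image accepts CSS-unit strings for width/height.

import gradio as gr
import numpy as np

# Hypothetical stand-ins for the repository's own definitions (app.py builds
# options_list from object_names and uses the server module's process_images).
object_names = {0: "chair", 1: "sofa", 2: "table"}
options_list = list(object_names.values())

def process_images(main_image_path, object_image_path, object_name="chair"):
    # Placeholder only: the real function detects the chosen object in the main
    # image, inpaints it away, and composites the RGBA object image into the scene.
    return np.zeros((300, 300, 3), dtype=np.uint8)

# Interface as in the new version of app.py, with percentage sizes.
iface = gr.Interface(fn=process_images,
                     inputs=[gr.Image(type='filepath', label='Main Image where object identify',
                                      width="60%", height="60%"),
                             gr.Image(type='filepath',
                                      label='Object Image which placed on Main Image (PNG file only RGBA Channel)',
                                      image_mode="RGBA", width="60%", height="60%"),
                             gr.Dropdown(options_list, label='Replace Object Name (Default = chair)')],
                     outputs=gr.Image(type='numpy', label='Final Result', width="70%", height="70%"),
                     title="AI Based Image Processing",
                     description="Object to Object Replacement")

if __name__ == "__main__":
    iface.launch()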
only_gradio_server.py CHANGED
@@ -9,7 +9,6 @@ import numpy as np
 from PIL import Image
 from torchvision import transforms
 import imageio.v2 as imageio
-from trainer import Trainer
 from utils.tools import get_config
 import torch.nn.functional as F
 from iopaint.single_processing import batch_inpaint_cv2
@@ -151,9 +150,6 @@ def process_images(input_image, append_image, default_class="chair"):
 def repaitingAndMerge(append_image_path, model_path, config_path, width, height, xposition, yposition, input_base, mask_base):
     config = get_config(config_path)
     device = torch.device("cpu")
-    trainer = Trainer(config)
-    trainer.load_state_dict(load_weights(model_path, device), strict=False)
-    trainer.eval()
 
     # lama inpainting start
     print("lama inpainting start")
trainer.py DELETED
@@ -1,155 +0,0 @@
-import os
-import torch
-import torch.nn as nn
-from torch import autograd
-from model.networks import Generator, LocalDis, GlobalDis
-
-
-from utils.tools import get_model_list, local_patch, spatial_discounting_mask
-from utils.logger import get_logger
-
-logger = get_logger()
-
-
-class Trainer(nn.Module):
-    def __init__(self, config):
-        super(Trainer, self).__init__()
-        self.config = config
-        self.use_cuda = self.config['cuda']
-        self.device_ids = self.config['gpu_ids']
-
-        self.netG = Generator(self.config['netG'], self.use_cuda, self.device_ids)
-        self.localD = LocalDis(self.config['netD'], self.use_cuda, self.device_ids)
-        self.globalD = GlobalDis(self.config['netD'], self.use_cuda, self.device_ids)
-
-        self.optimizer_g = torch.optim.Adam(self.netG.parameters(), lr=self.config['lr'],
-                                            betas=(self.config['beta1'], self.config['beta2']))
-        d_params = list(self.localD.parameters()) + list(self.globalD.parameters())
-        self.optimizer_d = torch.optim.Adam(d_params, lr=config['lr'],
-                                            betas=(self.config['beta1'], self.config['beta2']))
-        if self.use_cuda:
-            self.netG.to(self.device_ids[0])
-            self.localD.to(self.device_ids[0])
-            self.globalD.to(self.device_ids[0])
-
-    def forward(self, x, bboxes, masks, ground_truth, compute_loss_g=False):
-        self.train()
-        l1_loss = nn.L1Loss()
-        losses = {}
-
-        x1, x2, offset_flow = self.netG(x, masks)
-        local_patch_gt = local_patch(ground_truth, bboxes)
-        x1_inpaint = x1 * masks + x * (1. - masks)
-        x2_inpaint = x2 * masks + x * (1. - masks)
-        local_patch_x1_inpaint = local_patch(x1_inpaint, bboxes)
-        local_patch_x2_inpaint = local_patch(x2_inpaint, bboxes)
-
-        # D part
-        # wgan d loss
-        local_patch_real_pred, local_patch_fake_pred = self.dis_forward(
-            self.localD, local_patch_gt, local_patch_x2_inpaint.detach())
-        global_real_pred, global_fake_pred = self.dis_forward(
-            self.globalD, ground_truth, x2_inpaint.detach())
-        losses['wgan_d'] = torch.mean(local_patch_fake_pred - local_patch_real_pred) + \
-            torch.mean(global_fake_pred - global_real_pred) * self.config['global_wgan_loss_alpha']
-        # gradients penalty loss
-        local_penalty = self.calc_gradient_penalty(
-            self.localD, local_patch_gt, local_patch_x2_inpaint.detach())
-        global_penalty = self.calc_gradient_penalty(self.globalD, ground_truth, x2_inpaint.detach())
-        losses['wgan_gp'] = local_penalty + global_penalty
-
-        # G part
-        if compute_loss_g:
-            sd_mask = spatial_discounting_mask(self.config)
-            losses['l1'] = l1_loss(local_patch_x1_inpaint * sd_mask, local_patch_gt * sd_mask) * \
-                self.config['coarse_l1_alpha'] + \
-                l1_loss(local_patch_x2_inpaint * sd_mask, local_patch_gt * sd_mask)
-            losses['ae'] = l1_loss(x1 * (1. - masks), ground_truth * (1. - masks)) * \
-                self.config['coarse_l1_alpha'] + \
-                l1_loss(x2 * (1. - masks), ground_truth * (1. - masks))
-
-            # wgan g loss
-            local_patch_real_pred, local_patch_fake_pred = self.dis_forward(
-                self.localD, local_patch_gt, local_patch_x2_inpaint)
-            global_real_pred, global_fake_pred = self.dis_forward(
-                self.globalD, ground_truth, x2_inpaint)
-            losses['wgan_g'] = - torch.mean(local_patch_fake_pred) - \
-                torch.mean(global_fake_pred) * self.config['global_wgan_loss_alpha']
-
-        return losses, x2_inpaint, offset_flow
-
-    def dis_forward(self, netD, ground_truth, x_inpaint):
-        assert ground_truth.size() == x_inpaint.size()
-        batch_size = ground_truth.size(0)
-        batch_data = torch.cat([ground_truth, x_inpaint], dim=0)
-        batch_output = netD(batch_data)
-        real_pred, fake_pred = torch.split(batch_output, batch_size, dim=0)
-
-        return real_pred, fake_pred
-
-    # Calculate gradient penalty
-    def calc_gradient_penalty(self, netD, real_data, fake_data):
-        batch_size = real_data.size(0)
-        alpha = torch.rand(batch_size, 1, 1, 1)
-        alpha = alpha.expand_as(real_data)
-        if self.use_cuda:
-            alpha = alpha.cuda()
-
-        interpolates = alpha * real_data + (1 - alpha) * fake_data
-        interpolates = interpolates.requires_grad_().clone()
-
-        disc_interpolates = netD(interpolates)
-        grad_outputs = torch.ones(disc_interpolates.size())
-
-        if self.use_cuda:
-            grad_outputs = grad_outputs.cuda()
-
-        gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
-                                  grad_outputs=grad_outputs, create_graph=True,
-                                  retain_graph=True, only_inputs=True)[0]
-
-        gradients = gradients.view(batch_size, -1)
-        gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
-
-        return gradient_penalty
-
-    def inference(self, x, masks):
-        self.eval()
-        x1, x2, offset_flow = self.netG(x, masks)
-        # x1_inpaint = x1 * masks + x * (1. - masks)
-        x2_inpaint = x2 * masks + x * (1. - masks)
-
-        return x2_inpaint, offset_flow
-
-    def save_model(self, checkpoint_dir, iteration):
-        # Save generators, discriminators, and optimizers
-        gen_name = os.path.join(checkpoint_dir, 'gen_%08d.pt' % iteration)
-        dis_name = os.path.join(checkpoint_dir, 'dis_%08d.pt' % iteration)
-        opt_name = os.path.join(checkpoint_dir, 'optimizer.pt')
-        torch.save(self.netG.state_dict(), gen_name)
-        torch.save({'localD': self.localD.state_dict(),
-                    'globalD': self.globalD.state_dict()}, dis_name)
-        torch.save({'gen': self.optimizer_g.state_dict(),
-                    'dis': self.optimizer_d.state_dict()}, opt_name)
-
-    def resume(self, checkpoint_dir, iteration=0, test=False):
-        # Load generators
-        last_model_name = get_model_list(checkpoint_dir, "gen", iteration=iteration)
-        self.netG.load_state_dict(torch.load(last_model_name))
-        iteration = int(last_model_name[-11:-3])
-
-        if not test:
-            # Load discriminators
-            last_model_name = get_model_list(checkpoint_dir, "dis", iteration=iteration)
-            state_dict = torch.load(last_model_name)
-            self.localD.load_state_dict(state_dict['localD'])
-            self.globalD.load_state_dict(state_dict['globalD'])
-            # Load optimizers
-            state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
-            self.optimizer_d.load_state_dict(state_dict['dis'])
-            self.optimizer_g.load_state_dict(state_dict['gen'])
-
-        print("Resume from {} at iteration {}".format(checkpoint_dir, iteration))
-        logger.info("Resume from {} at iteration {}".format(checkpoint_dir, iteration))
-
-        return iteration
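Note: the deleted Trainer bundled training (WGAN-GP losses for the local and global discriminators) and inference for the coarse-to-fine generative inpainting network; since the server now relies on the lama path only, none of it is reachable. If the generator path is ever needed again without the full Trainer, its inference step reduces to the composite below, a minimal sketch lifted from the deleted inference method, assuming netG is the Generator from model.networks and x/masks are batched tensors in the value range the generator was trained on.

import torch

def inpaint_with_generator(netG, x, masks):
    # Two-stage generator: x1 is the coarse result, x2 the refined one.
    netG.eval()
    with torch.no_grad():  # no_grad added here for inference; the original method did not use it
        x1, x2, offset_flow = netG(x, masks)
        # Keep original pixels outside the mask, generated pixels inside it.
        x2_inpaint = x2 * masks + x * (1. - masks)
    return x2_inpaint, offset_flow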