EditGuard / models /IBSN.py
Ricoooo's picture
'folder'
5d21dd2
import logging
from collections import OrderedDict
import torch
import torch.nn as nn
from torch.nn.parallel import DataParallel, DistributedDataParallel
import models.networks as networks
import models.lr_scheduler as lr_scheduler
from .base_model import BaseModel
from models.modules.loss import ReconstructionLoss, ReconstructionMsgLoss
from models.modules.Quantization import Quantization
from .modules.common import DWT,IWT
from utils.jpegtest import JpegTest
from utils.JPEG import DiffJPEG
import utils.util as util
import numpy as np
import random
import cv2
import time
# Module-level logger obtained under the shared 'base' logger name.
logger = logging.getLogger('base')
# Module-level DWT / IWT instances; Model_VSN calls them as `dwt(...)` on network
# inputs and `iwt(...)` on network outputs (presumably a wavelet transform and
# its inverse -- confirm in modules/common).
dwt=DWT()
iwt=IWT()
from diffusers import StableDiffusionInpaintPipeline
from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, DDIMScheduler
from diffusers import StableDiffusionXLInpaintPipeline
from diffusers.utils import load_image
from diffusers import RePaintPipeline, RePaintScheduler
class Model_VSN(BaseModel):
def __init__(self, opt):
    """Build the generator, the optional tampering pipelines and the training state.

    Args:
        opt: Nested option mapping. Keys read directly here are 'dist',
            'gop', 'train', 'test', 'network_G', 'num_image', 'mode', 'hide',
            'sdinpaint', 'controlnetinpaint', 'sdxl', 'repaint' and
            'losstype'; ``self.load()`` additionally reads ``opt['path']``.

    Raises:
        ValueError: when training is enabled and ``opt['mode']`` is neither
            "image" nor "bit".
        NotImplementedError: when ``opt['train']['lr_scheme']`` is not one of
            the two supported schedulers.

    Note:
        ``self.device``, ``self.is_train``, ``self.optimizers`` and
        ``self.schedulers`` are not assigned in this method; presumably they
        come from ``BaseModel.__init__`` (confirm against base_model).
    """
    super(Model_VSN, self).__init__(opt)

    if opt['dist']:
        self.rank = torch.distributed.get_rank()
    else:
        self.rank = -1  # non dist training

    self.gop = opt['gop']
    train_opt = opt['train']
    test_opt = opt['test']
    self.opt = opt
    self.train_opt = train_opt
    self.test_opt = test_opt
    self.opt_net = opt['network_G']
    self.center = self.gop // 2
    self.num_image = opt['num_image']
    self.mode = opt["mode"]
    self.idxx = 0

    self.netG = networks.define_G_v2(opt).to(self.device)
    if opt['dist']:
        self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
    else:
        self.netG = DataParallel(self.netG)
    # print network
    self.print_network()
    self.load()

    self.Quantization = Quantization()

    if not self.opt['hide']:
        # One recorded message per line, each character being one bit; test()
        # looks the message up by image index when nothing is embedded.
        with open("bit_sequence.txt", "r") as file:
            self.msg_list = [[int(bit) for bit in line.strip()] for line in file]

    # Optional tampering pipelines, each loaded only when its flag is set.
    if self.opt['sdinpaint']:
        self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
            "stabilityai/stable-diffusion-2-inpainting",
            torch_dtype=torch.float16,
        ).to("cuda")

    if self.opt['controlnetinpaint']:
        controlnet = ControlNetModel.from_pretrained(
            "lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float32
        ).to("cuda")
        self.pipe_control = StableDiffusionControlNetInpaintPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float32
        ).to("cuda")

    if self.opt['sdxl']:
        self.pipe_sdxl = StableDiffusionXLInpaintPipeline.from_pretrained(
            "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
            torch_dtype=torch.float16,
            variant="fp16",
            use_safetensors=True,
        ).to("cuda")

    if self.opt['repaint']:
        self.scheduler = RePaintScheduler.from_pretrained("google/ddpm-ema-celebahq-256")
        self.pipe_repaint = RePaintPipeline.from_pretrained("google/ddpm-ema-celebahq-256", scheduler=self.scheduler)
        self.pipe_repaint = self.pipe_repaint.to("cuda")

    if self.is_train:
        self.netG.train()

        # loss
        self.Reconstruction_forw = ReconstructionLoss(losstype=self.train_opt['pixel_criterion_forw'])
        self.Reconstruction_back = ReconstructionLoss(losstype=self.train_opt['pixel_criterion_back'])
        self.Reconstruction_center = ReconstructionLoss(losstype="center")
        self.Reconstruction_msg = ReconstructionMsgLoss(losstype=self.opt['losstype'])

        # optimizers
        wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0

        # Each mode trains only the sub-networks whose parameter names start
        # with one of these prefixes.
        trainable_prefixes = {
            "image": ('module.irn', 'module.pm'),
            "bit": ('module.bitencoder', 'module.bitdecoder'),
        }
        if self.mode not in trainable_prefixes:
            # Fail here with a clear message; otherwise Adam would reject the
            # empty parameter list with a far less helpful ValueError.
            raise ValueError(
                'Unsupported mode [{}]: expected "image" or "bit".'.format(self.mode))
        prefixes = trainable_prefixes[self.mode]

        optim_params = []
        for k, v in self.netG.named_parameters():
            if k.startswith(prefixes) and v.requires_grad:
                optim_params.append(v)
            elif self.rank <= 0:
                logger.warning('Params [{:s}] will not optimize.'.format(k))

        self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
                                            weight_decay=wd_G,
                                            betas=(train_opt['beta1'], train_opt['beta2']))
        self.optimizers.append(self.optimizer_G)

        # schedulers
        if train_opt['lr_scheme'] == 'MultiStepLR':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'],
                                                     restarts=train_opt['restarts'],
                                                     weights=train_opt['restart_weights'],
                                                     gamma=train_opt['lr_gamma'],
                                                     clear_state=train_opt['clear_state']))
        elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.CosineAnnealingLR_Restart(
                        optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'],
                        restarts=train_opt['restarts'], weights=train_opt['restart_weights']))
        else:
            raise NotImplementedError('MultiStepLR learning rate scheme is enough.')

    # Created unconditionally so get_current_log() also works outside training.
    self.log_dict = OrderedDict()
def feed_data(self, data):
    """Store one batch on the model.

    ``data['LQ']`` and ``data['GT']`` are moved to ``self.device`` and kept as
    ``self.ref_L`` and ``self.real_H``; ``data['MES']`` is kept unchanged as
    ``self.mes``.
    """
    for attribute, key in (('ref_L', 'LQ'), ('real_H', 'GT')):
        setattr(self, attribute, data[key].to(self.device))
    self.mes = data['MES']
def init_hidden_state(self, z):
    """Create zero-filled recurrent state with the same shape as ``z``.

    Args:
        z: 4-D tensor; only its shape and device are used.

    Returns:
        Tuple ``(h_t, c_t, memory)``: ``h_t`` and ``c_t`` are lists holding
        ``self.opt_net['block_num_rbm']`` zero tensors each, and ``memory`` is
        a single zero tensor.
    """
    b, c, h, w = z.shape  # unpacking also rejects inputs that are not 4-D
    shape = [b, c, h, w]

    def zeros():
        # Allocate directly on z's device rather than on the CPU followed by
        # .cuda(): no extra host-to-device copy, and it works without CUDA.
        return torch.zeros(shape, device=z.device)

    h_t = []
    c_t = []
    for _ in range(self.opt_net['block_num_rbm']):
        h_t.append(zeros())
        c_t.append(zeros())
    memory = zeros()

    return h_t, c_t, memory
def loss_forward(self, out, y):
    """Return ``self.Reconstruction_forw(out, y)`` scaled by ``train_opt['lambda_fit_forw']``."""
    weight = self.train_opt['lambda_fit_forw']
    return weight * self.Reconstruction_forw(out, y)
def loss_back_rec(self, out, x):
    """Return ``self.Reconstruction_back(out, x)`` scaled by ``train_opt['lambda_rec_back']``."""
    weight = self.train_opt['lambda_rec_back']
    return weight * self.Reconstruction_back(out, x)
def loss_back_rec_mul(self, out, x):
    """Return the weighted backward reconstruction loss.

    Uses the same weight (``train_opt['lambda_rec_back']``) and the same
    criterion (``self.Reconstruction_back``) as ``loss_back_rec``.
    """
    weight = self.train_opt['lambda_rec_back']
    return weight * self.Reconstruction_back(out, x)
def optimize_parameters(self, current_step):
    """Run one training step: embed, degrade, recover, back-propagate, update.

    A random +/-0.5 message is embedded together with the secret frames into
    the host frames, the resulting container is optionally degraded, quantized
    and passed back through the network in reverse. Losses are written to
    ``self.log_dict``.

    Args:
        current_step: Unused here; kept for interface compatibility.

    Note:
        ``random`` is deliberately resolved from the module-level import. A
        function-local ``import random`` would make the name local to the
        whole function and leave it unbound on the fixed-Poisson path.
    """
    self.optimizer_G.zero_grad()

    b, _, t, _, h, w = self.ref_L.shape
    center = t // 2
    intval = self.gop // 2

    message = torch.Tensor(
        np.random.choice([-0.5, 0.5], (self.ref_L.shape[0], self.opt['message_length']))
    ).to(self.device)

    add_noise = self.opt['addnoise']
    add_jpeg = self.opt['addjpeg']
    add_possion = self.opt['addpossion']
    degrade_shuffle = self.opt['degrade_shuffle']

    self.host = self.real_H[:, center - intval:center + intval + 1]
    self.secret = self.ref_L[:, :, center - intval:center + intval + 1]
    self.output, container = self.netG(
        x=dwt(self.host.reshape(b, -1, h, w)),
        x_h=dwt(self.secret[:, 0].reshape(b, -1, h, w)),
        message=message,
    )

    # The fitting loss is taken on the clean container, before any degradation.
    l_forw_fit = self.loss_forward(container, self.host[:, 0])

    # Either draw one degradation at random, or apply the first enabled one.
    if degrade_shuffle:
        attack = ('noise', 'jpeg', 'poisson')[random.randint(0, 2)]
    elif add_noise:
        attack = 'noise'
    elif add_jpeg:
        attack = 'jpeg'
    elif add_possion:
        attack = 'poisson'
    else:
        attack = None
    y_forw = self._degrade_for_training(container, attack)

    y = self.Quantization(y_forw)
    all_zero = torch.zeros(message.shape).to(self.device)

    if self.mode == "image":
        out_x, out_x_h, out_z, recmessage = self.netG(x=y, message=all_zero, rev=True)
        out_x = iwt(out_x)
        out_x_h = [iwt(out_x_h_i) for out_x_h_i in out_x_h]

        l_back_rec = self.loss_back_rec(out_x, self.host[:, 0])
        out_x_h = torch.stack(out_x_h, dim=1)

        l_center_x = self.loss_back_rec(out_x_h[:, 0], self.secret[:, 0].reshape(b, -1, h, w))

        recmessage = torch.clamp(recmessage, -0.5, 0.5)

        # l_msg is only logged in this mode; it is not part of the loss below.
        l_msg = self.Reconstruction_msg(message, recmessage)

        loss = l_forw_fit*2 + l_back_rec + l_center_x*4

        loss.backward()

        if self.train_opt['lambda_center'] != 0:
            self.log_dict['l_center_x'] = l_center_x.item()

        # set log
        self.log_dict['l_back_rec'] = l_back_rec.item()
        self.log_dict['l_forw_fit'] = l_forw_fit.item()
        self.log_dict['l_msg'] = l_msg.item()
        self.log_dict['l_h'] = (l_center_x*10).item()

        # gradient clipping
        if self.train_opt['gradient_clipping']:
            nn.utils.clip_grad_norm_(self.netG.parameters(), self.train_opt['gradient_clipping'])

        self.optimizer_G.step()

    elif self.mode == "bit":
        recmessage = self.netG(x=y, message=all_zero, rev=True)

        recmessage = torch.clamp(recmessage, -0.5, 0.5)

        l_msg = self.Reconstruction_msg(message, recmessage)

        lambda_msg = self.train_opt['lambda_msg']

        loss = l_msg * lambda_msg + l_forw_fit

        loss.backward()

        # set log
        self.log_dict['l_forw_fit'] = l_forw_fit.item()
        self.log_dict['l_msg'] = l_msg.item()

        # gradient clipping
        if self.train_opt['gradient_clipping']:
            nn.utils.clip_grad_norm_(self.netG.parameters(), self.train_opt['gradient_clipping'])

        self.optimizer_G.step()

def _degrade_for_training(self, y_forw, attack):
    """Apply the selected training-time degradation to the container ``y_forw``.

    Args:
        y_forw: Container tensor produced by the forward pass.
        attack: 'noise', 'jpeg', 'poisson', or None to return ``y_forw`` as is.

    Returns:
        The degraded tensor.
    """
    if attack == 'noise':
        # Gaussian noise with a standard deviation drawn from 1/255 .. 15/255.
        NL = float((np.random.randint(1, 16)) / 255)
        noise = np.random.normal(0, NL, y_forw.shape)
        torchnoise = torch.from_numpy(noise).cuda().float()
        return y_forw + torchnoise

    if attack == 'jpeg':
        # JPEG quality drawn from 70 .. 94.
        NL = int(np.random.randint(70, 95))
        self.DiffJPEG = DiffJPEG(differentiable=True, quality=int(NL)).cuda()
        return self.DiffJPEG(y_forw)

    if attack == 'poisson':
        vals = 10**4
        if random.random() < 0.5:
            noisy_img_tensor = torch.poisson(y_forw * vals) / vals
        else:
            # NOTE(review): dim=0 averages over the first axis of y_forw. If
            # y_forw is (batch, C, H, W) that is the batch axis, not the
            # channel axis -- confirm this is intended.
            img_gray_tensor = torch.mean(y_forw, dim=0, keepdim=True)
            noisy_gray_tensor = torch.poisson(img_gray_tensor * vals) / vals
            noisy_img_tensor = y_forw + (noisy_gray_tensor - img_gray_tensor)
        return torch.clamp(noisy_img_tensor, 0, 1)

    return y_forw
def test(self, image_id):
    """Evaluate one batch: optionally embed, tamper/degrade, then recover.

    Results are stored on the instance: ``self.forw_L`` (the attacked
    container), ``self.message`` / ``self.recmessage`` (binarised to 0/1) and,
    in "image" mode, ``self.fake_H`` / ``self.fake_H_h``. The network is put
    back into training mode before returning.

    Args:
        image_id: Index of the current sample. It selects the recorded message
            from ``self.msg_list`` when ``opt['hide']`` is false, and
            ``image_id + 1`` (zero-padded to four digits) names the mask file
            used by the inpainting attacks.

    Note:
        ``random`` is deliberately resolved from the module-level import. A
        function-local ``import random`` would make the name local to the
        whole function and leave it unbound on the fixed-Poisson path.
    """
    self.netG.eval()
    add_noise = self.opt['addnoise']
    add_jpeg = self.opt['addjpeg']
    add_possion = self.opt['addpossion']
    add_sdinpaint = self.opt['sdinpaint']
    add_controlnet = self.opt['controlnetinpaint']
    add_sdxl = self.opt['sdxl']
    add_repaint = self.opt['repaint']
    degrade_shuffle = self.opt['degrade_shuffle']

    with torch.no_grad():
        forw_L = []
        fake_H = []
        fake_H_h = []
        recmsglist = []
        msglist = []

        _, t, _, _, _ = self.real_H.shape
        center = t // 2
        intval = self.gop // 2
        # h and w are the frame size used by the reshapes further down; the
        # attack branches must not overwrite them.
        b, n, _, _, h, w = self.ref_L.shape

        # forward downscaling
        self.host = self.real_H[:, center - intval:center + intval + 1]
        self.secret = self.ref_L[:, :, center - intval:center + intval + 1]
        self.secret = [dwt(self.secret[:, i].reshape(b, -1, h, w)) for i in range(n)]

        messagenp = np.random.choice([-0.5, 0.5], (self.ref_L.shape[0], self.opt['message_length']))
        message = torch.Tensor(messagenp).to(self.device)

        if self.opt['bitrecord']:
            # Append the embedded bits as one line of '0'/'1' characters.
            mymsg = message.clone()
            mymsg[mymsg > 0] = 1
            mymsg[mymsg < 0] = 0
            mymsg = mymsg.squeeze(0).to(torch.int)

            bit_list = mymsg.tolist()
            bit_string = ''.join(map(str, bit_list))

            file_name = "bit_sequence.txt"
            with open(file_name, "a") as file:
                file.write(bit_string + "\n")

        if self.opt['hide']:
            self.output, container = self.netG(x=dwt(self.host.reshape(b, -1, h, w)), x_h=self.secret, message=message)
            y_forw = container
        else:
            # Nothing is embedded: the host is taken as the container and the
            # reference message comes from the recorded bit file.
            message = torch.tensor(self.msg_list[image_id]).unsqueeze(0).cuda()
            self.output = self.host
            y_forw = self.output.squeeze(1)

        if add_sdinpaint:
            from PIL import Image
            prompt = ""
            image_batch = y_forw.permute(0, 2, 3, 1).detach().cpu().numpy()
            mask_path = "../dataset/valAGE-Set-Mask/" + str(image_id + 1).zfill(4) + ".png"
            forw_list = []

            for j in range(y_forw.shape[0]):
                mask_image = Image.open(mask_path).convert("L")
                mask_image = mask_image.resize((512, 512))
                # PIL reports size as (width, height).
                mask_w, mask_h = mask_image.size

                image = image_batch[j, :, :, :]
                image_init = Image.fromarray((image * 255).astype(np.uint8), mode="RGB")
                image_inpaint = self.pipe(prompt=prompt, image=image_init, mask_image=mask_image, height=mask_h, width=mask_w).images[0]
                image_inpaint = np.array(image_inpaint) / 255.
                mask_image = np.array(mask_image)
                mask_image = np.stack([mask_image] * 3, axis=-1) / 255.
                # The uint8 cast keeps only fully-white mask pixels as 1.
                mask_image = mask_image.astype(np.uint8)
                image_fuse = image * (1 - mask_image) + image_inpaint * mask_image
                forw_list.append(torch.from_numpy(image_fuse).permute(2, 0, 1))

            y_forw = torch.stack(forw_list, dim=0).float().cuda()

        if add_controlnet:
            from PIL import Image
            forw_list = []
            image_batch = y_forw.permute(0, 2, 3, 1).detach().cpu().numpy()
            generator = torch.Generator(device="cuda").manual_seed(1)
            mask_path = "../dataset/valAGE-Set-Mask/" + str(image_id + 1).zfill(4) + ".png"

            for j in range(y_forw.shape[0]):
                mask_image = load_image(mask_path)
                mask_image = mask_image.resize((512, 512))
                image_init = image_batch[j, :, :, :]
                image_init1 = Image.fromarray((image_init * 255).astype(np.uint8), mode="RGB")
                image_mask = np.array(mask_image.convert("L")).astype(np.float32) / 255.0

                assert image_init.shape[0:2] == image_mask.shape[0:2], "image and image_mask must have the same image size"
                # Build the control input on a copy: the -1 markers must not
                # leak into the fused image computed from image_init below.
                control_np = image_init.copy()
                control_np[image_mask > 0.5] = -1.0  # set as masked pixel
                image = np.expand_dims(control_np, 0).transpose(0, 3, 1, 2)
                control_image = torch.from_numpy(image)

                # generate image
                image_inpaint = self.pipe_control(
                    "",
                    num_inference_steps=20,
                    generator=generator,
                    eta=1.0,
                    image=image_init1,
                    mask_image=image_mask,
                    control_image=control_image,
                ).images[0]

                image_inpaint = np.array(image_inpaint) / 255.
                image_mask = np.stack([image_mask] * 3, axis=-1)
                image_mask = image_mask.astype(np.uint8)
                image_fuse = image_init * (1 - image_mask) + image_inpaint * image_mask
                forw_list.append(torch.from_numpy(image_fuse).permute(2, 0, 1))

            y_forw = torch.stack(forw_list, dim=0).float().cuda()

        if add_sdxl:
            from PIL import Image
            prompt = ""
            image_batch = y_forw.permute(0, 2, 3, 1).detach().cpu().numpy()
            mask_path = "../dataset/valAGE-Set-Mask/" + str(image_id + 1).zfill(4) + ".png"
            forw_list = []

            for j in range(y_forw.shape[0]):
                mask_image = load_image(mask_path).convert("RGB")
                mask_image = mask_image.resize((512, 512))
                image = image_batch[j, :, :, :]
                image_init = Image.fromarray((image * 255).astype(np.uint8), mode="RGB")
                image_inpaint = self.pipe_sdxl(
                    prompt=prompt, image=image_init, mask_image=mask_image, num_inference_steps=50, strength=0.80, target_size=(512, 512)
                ).images[0]
                image_inpaint = image_inpaint.resize((512, 512))
                image_inpaint = np.array(image_inpaint) / 255.
                mask_image = np.array(mask_image) / 255.
                mask_image = mask_image.astype(np.uint8)
                image_fuse = image * (1 - mask_image) + image_inpaint * mask_image
                forw_list.append(torch.from_numpy(image_fuse).permute(2, 0, 1))

            y_forw = torch.stack(forw_list, dim=0).float().cuda()

        if add_repaint:
            from PIL import Image
            image_batch = y_forw.permute(0, 2, 3, 1).detach().cpu().numpy()
            mask_path = "../dataset/valAGE-Set-Mask/" + str(image_id + 1).zfill(4) + ".png"
            forw_list = []

            generator = torch.Generator(device="cuda").manual_seed(0)
            for j in range(y_forw.shape[0]):
                mask_image = Image.open(mask_path).convert("RGB")
                mask_image = mask_image.resize((256, 256))
                # The mask is inverted here, which is why the fusion below uses
                # mask / (1 - mask) the other way round from the other attacks.
                mask_image = Image.fromarray(255 - np.array(mask_image))
                image = image_batch[j, :, :, :]
                original_image = Image.fromarray((image * 255).astype(np.uint8), mode="RGB")
                original_image = original_image.resize((256, 256))
                output = self.pipe_repaint(
                    image=original_image,
                    mask_image=mask_image,
                    num_inference_steps=150,
                    eta=0.0,
                    jump_length=10,
                    jump_n_sample=10,
                    generator=generator,
                )
                image_inpaint = output.images[0]
                image_inpaint = image_inpaint.resize((512, 512))
                image_inpaint = np.array(image_inpaint) / 255.
                mask_image = mask_image.resize((512, 512))
                mask_image = np.array(mask_image) / 255.
                mask_image = mask_image.astype(np.uint8)
                image_fuse = image * mask_image + image_inpaint * (1 - mask_image)
                forw_list.append(torch.from_numpy(image_fuse).permute(2, 0, 1))

            y_forw = torch.stack(forw_list, dim=0).float().cuda()

        # Either draw one degradation at random, or apply the first enabled one.
        if degrade_shuffle:
            attack = ('noise', 'jpeg', 'poisson')[random.randint(0, 2)]
        elif add_noise:
            attack = 'noise'
        elif add_jpeg:
            attack = 'jpeg'
        elif add_possion:
            attack = 'poisson'
        else:
            attack = None

        if attack == 'noise':
            if degrade_shuffle:
                NL = float((np.random.randint(1, 5)) / 255)
            else:
                NL = self.opt['noisesigma'] / 255.0
            noise = np.random.normal(0, NL, y_forw.shape)
            torchnoise = torch.from_numpy(noise).cuda().float()
            y_forw = y_forw + torchnoise

        elif attack == 'jpeg':
            Q = 90 if degrade_shuffle else self.opt['jpegfactor']
            self.DiffJPEG = DiffJPEG(differentiable=True, quality=int(Q)).cuda()
            y_forw = self.DiffJPEG(y_forw)

        elif attack == 'poisson':
            vals = 10**4
            if random.random() < 0.5:
                noisy_img_tensor = torch.poisson(y_forw * vals) / vals
            else:
                # NOTE(review): dim=0 averages over the first axis of y_forw. If
                # y_forw is (batch, C, H, W) that is the batch axis, not the
                # channel axis -- confirm this is intended.
                img_gray_tensor = torch.mean(y_forw, dim=0, keepdim=True)
                noisy_gray_tensor = torch.poisson(img_gray_tensor * vals) / vals
                noisy_img_tensor = y_forw + (noisy_gray_tensor - img_gray_tensor)

            y_forw = torch.clamp(noisy_img_tensor, 0, 1)

        # backward upscaling
        if self.opt['hide']:
            y = self.Quantization(y_forw)
        else:
            y = y_forw

        if self.mode == "image":
            out_x, out_x_h, out_z, recmessage = self.netG(x=y, rev=True)
            out_x = iwt(out_x)
            out_x_h = [iwt(out_x_h_i) for out_x_h_i in out_x_h]

            out_x = out_x.reshape(-1, self.gop, 3, h, w)
            out_x_h = torch.stack(out_x_h, dim=1)
            out_x_h = out_x_h.reshape(-1, 1, self.gop, 3, h, w)

            forw_L.append(y_forw)
            fake_H.append(out_x[:, self.gop//2])
            fake_H_h.append(out_x_h[:, :, self.gop//2])
            recmsglist.append(recmessage)
            msglist.append(message)

        elif self.mode == "bit":
            recmessage = self.netG(x=y, rev=True)
            forw_L.append(y_forw)
            recmsglist.append(recmessage)
            msglist.append(message)

    if self.mode == "image":
        self.fake_H = torch.clamp(torch.stack(fake_H, dim=1), 0, 1)
        self.fake_H_h = torch.clamp(torch.stack(fake_H_h, dim=2), 0, 1)

    self.forw_L = torch.clamp(torch.stack(forw_L, dim=1), 0, 1)
    remesg = torch.clamp(torch.stack(recmsglist, dim=0), -0.5, 0.5)

    if self.opt['hide']:
        mesg = torch.clamp(torch.stack(msglist, dim=0), -0.5, 0.5)
    else:
        mesg = torch.stack(msglist, dim=0)

    # Binarise both messages: strictly positive -> 1, everything else -> 0.
    self.recmessage = remesg.clone()
    self.recmessage[remesg > 0] = 1
    self.recmessage[remesg <= 0] = 0

    self.message = mesg.clone()
    self.message[mesg > 0] = 1
    self.message[mesg <= 0] = 0

    self.netG.train()
def image_hiding(self, ):
    """Embed ``self.mes`` and the secret frames into the host frames.

    Runs the network forward in eval mode without gradients and returns the
    container, clamped to [0, 1] and converted with ``util.tensor2img``.
    """
    self.netG.eval()
    with torch.no_grad():
        _, num_frames, _, _, _ = self.real_H.shape
        mid = num_frames // 2
        half_gop = self.gop // 2
        b, n, _, _, h, w = self.ref_L.shape

        # forward downscaling
        window = slice(mid - half_gop, mid + half_gop + 1)
        self.host = self.real_H[:, window]
        self.secret = self.ref_L[:, :, window]

        secret_inputs = []
        for i in range(n):
            secret_inputs.append(dwt(self.secret[:, i].reshape(b, -1, h, w)))
        self.secret = secret_inputs

        message = torch.Tensor(self.mes).to(self.device)

        host_input = dwt(self.host.reshape(b, -1, h, w))
        self.output, container = self.netG(x=host_input, x_h=self.secret, message=message)

        return util.tensor2img(torch.clamp(container, 0, 1))
def image_recovery(self, number):
    """Recover the hidden content from ``self.real_H`` and derive a difference mask.

    Args:
        number: Threshold applied to the absolute difference between the
            reference secret (``self.ref_L`` window) and the recovered one.

    Returns:
        Tuple ``(mask, remesg)``: ``mask`` is the thresholded difference after
        ``util.tensor2img``, summed over axis 2; ``remesg`` is the recovered
        message clamped to [-0.5, 0.5] and binarised to 0/1.
    """
    self.netG.eval()
    with torch.no_grad():
        _, num_frames, _, _, _ = self.real_H.shape
        mid = num_frames // 2
        half_gop = self.gop // 2
        b, n, _, _, h, w = self.ref_L.shape

        # forward downscaling
        window = slice(mid - half_gop, mid + half_gop + 1)
        self.host = self.real_H[:, window]
        self.secret = self.ref_L[:, :, window]
        template = self.secret.reshape(b, -1, h, w)

        secret_inputs = []
        for i in range(n):
            secret_inputs.append(dwt(self.secret[:, i].reshape(b, -1, h, w)))
        self.secret = secret_inputs

        # The host frames are treated as the received container.
        self.output = self.host
        y = self.Quantization(self.output.squeeze(1))

        out_x, out_x_h, out_z, recmessage = self.netG(x=y, rev=True)
        out_x = iwt(out_x)
        recovered = [iwt(band) for band in out_x_h]

        out_x = out_x.reshape(-1, self.gop, 3, h, w)
        recovered = torch.stack(recovered, dim=1).reshape(-1, 1, self.gop, 3, h, w)
        rec_loc = recovered[:, :, self.gop // 2]

        difference = torch.abs(template - rec_loc)
        thresholded = (difference > number).float()
        mask = np.sum(util.tensor2img(thresholded), axis=2)

        # Binarise: strictly positive -> 1, everything else -> 0.
        remesg = torch.clamp(recmessage, -0.5, 0.5)
        remesg[remesg > 0] = 1
        remesg[remesg <= 0] = 0

        return mask, remesg
def get_current_log(self):
    """Return ``self.log_dict``, the mapping the training step writes its losses to."""
    current_log = self.log_dict
    return current_log
def get_current_visuals(self):
    """Collect the first sample of the current results as CPU float tensors.

    Returns:
        OrderedDict with 'LR_ref' (list of ``self.num_image`` tensors), 'LR',
        'GT', 'message' and 'recmessage'; in "image" mode also 'SR' and
        'SR_h' (list), inserted right after 'LR_ref'.
    """
    def first_sample(tensor):
        # Detach, keep batch element 0, and move it to the CPU as float.
        return tensor.detach()[0].float().cpu()

    def split_images(tensor):
        # One entry per image, with the leading chunk axis removed.
        pieces = torch.chunk(tensor, self.num_image, dim=0)
        return [piece.squeeze(0) for piece in pieces]

    _, _, t, _, _, _ = self.ref_L.shape
    mid = t // 2
    half_gop = self.gop // 2
    window = slice(mid - half_gop, mid + half_gop + 1)

    out_dict = OrderedDict()
    out_dict['LR_ref'] = split_images(first_sample(self.ref_L[:, :, window]))

    if self.mode == "image":
        out_dict['SR'] = first_sample(self.fake_H)
        out_dict['SR_h'] = split_images(first_sample(self.fake_H_h))

    out_dict['LR'] = first_sample(self.forw_L)
    out_dict['GT'] = first_sample(self.real_H[:, window])
    out_dict['message'] = self.message
    out_dict['recmessage'] = self.recmessage

    return out_dict
def print_network(self):
    """Log the generator's class name, parameter count and description.

    For a (Distributed)DataParallel wrapper the wrapped module's class name is
    reported as well. Only ranks <= 0 write to the log.
    """
    description, num_params = self.get_network_description(self.netG)

    outer_name = self.netG.__class__.__name__
    if isinstance(self.netG, (nn.DataParallel, DistributedDataParallel)):
        inner_name = self.netG.module.__class__.__name__
        net_struc_str = f'{outer_name} - {inner_name}'
    else:
        net_struc_str = f'{outer_name}'

    if self.rank > 0:
        return
    logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, num_params))
    logger.info(description)
def load(self):
    """Load generator weights from ``opt['path']['pretrain_model_G']`` when it is set."""
    load_path_G = self.opt['path']['pretrain_model_G']
    if load_path_G is None:
        # No checkpoint configured: keep the freshly built weights.
        return
    logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
    self.load_network(load_path_G, self.netG, self.opt['path']['strict_load'])
def load_test(self,load_path_G):
    """Load generator weights from ``load_path_G`` using ``opt['path']['strict_load']``."""
    self.load_network(
        load_path_G,
        self.netG,
        self.opt['path']['strict_load'],
    )
def save(self, iter_label):
    """Save the generator through ``self.save_network`` under the label 'G' and ``iter_label``."""
    network_label = 'G'
    self.save_network(self.netG, network_label, iter_label)