import sys
import json
import torch
import numpy as np
import argparse
import cv2
from DRL.ddpg import decode
from utils.util import *
from PIL import Image
from torchvision import transforms, utils

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# training-time augmentation: random horizontal flip
aug = transforms.Compose(
    [transforms.ToPILImage(),
     transforms.RandomHorizontalFlip(),
     ])

width = 128  # side length of the square canvas / target images
convas_area = width * width

img_train = []
img_test = []
train_num = 0
test_num = 0

class Paint:
    def __init__(self, batch_size, max_step):
        self.batch_size = batch_size
        self.max_step = max_step  # number of painting steps per episode
        self.action_space = 13    # stroke parameters per action
        # note: observation() actually returns channel-first tensors of shape (B, 7, width, width)
        self.observation_space = (self.batch_size, width, width, 7)  # canvas (3) + target (3) + step (1)
        self.test = False

    def load_data(self):
        # CelebA: loads the first 200,000 images from ./data/img_align_celeba
        global train_num, test_num
        for i in range(200000):
            img_id = '%06d' % (i + 1)
            try:
                img = cv2.imread('./data/img_align_celeba/' + img_id + '.jpg', cv2.IMREAD_UNCHANGED)
                img = cv2.resize(img, (width, width))
                if i > 2000:
                    # everything after the first 2,001 images goes to the training set
                    train_num += 1
                    img_train.append(img)
                else:
                    test_num += 1
                    img_test.append(img)
            finally:
                if (i + 1) % 10000 == 0:
                    print('loaded {} images'.format(i + 1))
        print('finished loading data: {} training images, {} testing images'.format(train_num, test_num))

    def pre_data(self, id, test):
        if test:
            img = img_test[id]
        else:
            img = img_train[id]
        if not test:
            img = aug(img)  # random horizontal flip for training images only
        img = np.asarray(img)
        return np.transpose(img, (2, 0, 1))  # HWC -> CHW

    def reset(self, test=False, begin_num=0):
        self.test = test
        self.imgid = [0] * self.batch_size
        self.gt = torch.zeros([self.batch_size, 3, width, width], dtype=torch.uint8).to(device)  # target images
        for i in range(self.batch_size):
            if test:
                id = (i + begin_num) % test_num
            else:
                id = np.random.randint(train_num)
            self.imgid[i] = id
            self.gt[i] = torch.tensor(self.pre_data(id, test))
        # distance from an all-black canvas to the targets
        self.tot_reward = ((self.gt.float() / 255) ** 2).mean(1).mean(1).mean(1)
        self.stepnum = 0
        self.canvas = torch.zeros([self.batch_size, 3, width, width], dtype=torch.uint8).to(device)  # blank canvas
        self.lastdis = self.ini_dis = self.cal_dis()
        return self.observation()

    def observation(self):
        # canvas: B * 3 * width * width
        # gt:     B * 3 * width * width
        # T:      B * 1 * width * width (current step index, broadcast over the plane)
        T = torch.ones([self.batch_size, 1, width, width], dtype=torch.uint8) * self.stepnum
        return torch.cat((self.canvas, self.gt, T.to(device)), 1)  # canvas, target, step

    def cal_trans(self, s, t):
        # scale each image in s by its per-image factor in t (broadcast over channels and pixels)
        return (s.transpose(0, 3) * t).transpose(0, 3)

    def step(self, action):
        # apply the decoded stroke action to the canvas (kept in uint8)
        self.canvas = (decode(action, self.canvas.float() / 255) * 255).byte()
        self.stepnum += 1
        ob = self.observation()
        done = (self.stepnum == self.max_step)
        reward = self.cal_reward()
        return ob.detach(), reward, np.array([done] * self.batch_size), None

    def cal_dis(self):
        # per-image mean squared pixel distance between canvas and target, in [0, 1]
        return (((self.canvas.float() - self.gt.float()) / 255) ** 2).mean(1).mean(1).mean(1)

    def cal_reward(self):
        # reward is the decrease in distance this step, normalised by the initial distance
        dis = self.cal_dis()
        reward = (self.lastdis - dis) / (self.ini_dis + 1e-8)
        self.lastdis = dis
        return to_numpy(reward)
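
# --- Illustrative usage sketch ---
# Shows how the Paint environment is typically driven: reset() returns a uint8 observation
# of shape (B, 7, width, width) holding canvas (3), target (3) and the step index (1);
# step() renders a batch of stroke actions and returns (observation, reward, done, info).
# Assumptions: the CelebA images are under ./data/img_align_celeba and the renderer used by
# DRL.ddpg.decode is set up on the same device. The random actions below are a placeholder
# for the agent's policy, one 13-parameter stroke per image as declared by action_space;
# if your decode bundles several strokes per step, widen the action accordingly.
if __name__ == '__main__':
    env = Paint(batch_size=4, max_step=40)
    env.load_data()
    obs = env.reset()
    print('observation shape:', obs.shape)  # expected: (4, 7, 128, 128)
    done = np.array([False] * env.batch_size)
    while not done.all():
        action = torch.rand(env.batch_size, env.action_space, device=device)  # placeholder policy
        obs, reward, done, _ = env.step(action)
    print('final distance to targets:', to_numpy(env.cal_dis()))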