# Last commit not found
import gc
import glob
import os

import torch
import torchvision
import tqdm
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image, to_tensor
class TestModel(torch.nn.Module):
    """Small fully-convolutional image-to-image network.

    Maps a 3-channel image to a 3-channel image of the same spatial size
    (every convolution is 3x3, stride 1, padding 1) and clamps the result
    to [-1, 1]. Three residual 16-channel convolutions sit between two
    BatchNorm layers; there is no activation function between the convs.
    """

    def __init__(self):
        super().__init__()
        width = 16
        # Submodules are created in the same order as the original so the
        # registered names, parameter order and RNG-driven init are unchanged.
        self.start = torch.nn.Conv2d(3, width, kernel_size=3, stride=1, padding=1, bias=False)
        for name in ("conv1", "conv2", "conv3"):
            setattr(
                self,
                name,
                torch.nn.Conv2d(width, width, kernel_size=3, stride=1, padding=1, bias=False),
            )
        self.final = torch.nn.Conv2d(width, 3, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1, self.bn2 = torch.nn.BatchNorm2d(width), torch.nn.BatchNorm2d(width)

    def forward(self, x):
        """Run the network on a batch of 3-channel images; output is in [-1, 1]."""
        h = self.bn1(self.start(x))
        # Residual stack: each conv adds its output onto its input.
        for block in (self.conv1, self.conv2, self.conv3):
            h = block(h) + h
        h = self.final(self.bn2(h))
        return h.clamp(-1, 1)
class DS(Dataset):
    """Image dataset yielding RGB tensors scaled to [-1, 1].

    Files are discovered with ``glob`` at construction time and sorted, so the
    item order -- and therefore the image returned by :meth:`gettest` -- is
    reproducible across runs and filesystems.

    Args:
        pattern: glob pattern matching the image files (default ``"./15k/*"``).
        crop_size: side length of the square random crop applied by
            ``__getitem__`` (default 256).
    """

    def __init__(self, pattern="./15k/*", crop_size=256):
        super().__init__()
        # glob order is filesystem-dependent; sort so index 0 is stable.
        self.g = sorted(glob.glob(pattern))
        self.trans = transforms.Compose([
            # pad_if_needed: images smaller than the crop are zero-padded
            # instead of making RandomCrop raise.
            transforms.RandomCrop((crop_size, crop_size), pad_if_needed=True),
            transforms.ToTensor(),
        ])

    @staticmethod
    def _to_signed_unit(x):
        """Map a tensor in [0, 1] (as produced by ToTensor / to_tensor) to [-1, 1]."""
        # ToTensor already divides uint8 pixels by 255, so the previous
        # `x / 127.5 - 1` squashed every input into [-1, -0.992].
        return x * 2 - 1

    @staticmethod
    def _load_rgb(path):
        """Open the image at *path*, convert it to RGB and close the file."""
        with Image.open(path) as img:
            return img.convert("RGB")

    def __len__(self):
        return len(self.g)

    def __getitem__(self, idx):
        """Return a random crop of file *idx* as a 3-channel tensor in [-1, 1]."""
        return self._to_signed_unit(self.trans(self._load_rgb(self.g[idx])))

    def gettest(self):
        """Return the first (sorted) image, uncropped, as a 3-channel tensor in [-1, 1].

        Raises:
            IndexError: if the glob pattern matched no files.
        """
        if not self.g:
            raise IndexError("DS.gettest(): no images matched the glob pattern")
        return self._to_signed_unit(to_tensor(self._load_rgb(self.g[0])))
def main(batch_size=64, epochs=10, lr=1e-4, preview_dir="./test"):
    """Train TestModel to reconstruct its input images and save preview PNGs.

    Builds a ``DS`` dataset and trains with MSE reconstruction loss and Adam.
    It writes the model's output on ``dataset.gettest()`` to
    ``<preview_dir>/test.png`` before training and to
    ``<preview_dir>/<epoch>.png`` after each epoch.

    Args:
        batch_size: DataLoader mini-batch size (default 64).
        epochs: number of passes over the dataset (default 10).
        lr: Adam learning rate (default 1e-4).
        preview_dir: directory for preview images; created if missing.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = TestModel().to(device)
    dataset = DS()
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    criterion = torch.nn.MSELoss().to(device)
    # Build the optimizer after moving the model, as the torch.optim docs advise.
    optim = torch.optim.Adam(model.parameters(), lr=lr)
    model.train()

    # Saving previously failed with FileNotFoundError if ./test/ did not exist.
    os.makedirs(preview_dir, exist_ok=True)

    def log(tag):
        """Save the model's output on the fixed test image as <preview_dir>/<tag>.png."""
        model.eval()
        try:
            # Inference only: don't build an autograd graph for a full-size image.
            with torch.no_grad():
                x = dataset.gettest().to(device).unsqueeze(0)
                out = model(x)
            # Output is clamped to [-1, 1]; map it to [0, 1] for PIL.
            image = to_pil_image(((out[0] + 1) / 2).cpu())
            image.save(os.path.join(preview_dir, f"{tag}.png"))
        finally:
            model.train()

    log("test")
    for epoch in range(epochs):
        last_loss = None
        for step, batch in enumerate(tqdm.tqdm(dataloader)):
            batch = batch.to(device)
            optim.zero_grad()
            out = model(batch)
            loss = criterion(out, batch)
            loss.backward()
            optim.step()
            # Keep a detached tensor and call .item() once per epoch, which
            # avoids forcing a host/device sync on every step.
            last_loss = loss.detach()
            if step % 100 == 0:
                gc.collect()
                torch.cuda.empty_cache()
        print("EPOCH", epoch)
        print("LAST LOSS", last_loss.item() if last_loss is not None else float("nan"))
        log(epoch)
if __name__ == "__main__":
    # Start training only when this file is executed as a script, not on import.
    main()