File size: 7,826 Bytes
19ec9f5 da5b2c1 19ec9f5 da5b2c1 19ec9f5 da5b2c1 19ec9f5 da5b2c1 19ec9f5 da5b2c1 19ec9f5 da5b2c1 19ec9f5 da5b2c1 19ec9f5 da5b2c1 19ec9f5 da5b2c1 19ec9f5 da5b2c1 19ec9f5 da5b2c1 19ec9f5 da5b2c1 19ec9f5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 |
# from https://huggingface.co/spaces/hysts/StyleGAN3/blob/main/model.py
import pathlib
import pickle
import sys
import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
import torch
import torchvision.utils as vutils
import matplotlib.pyplot as plt
from io import BytesIO
from PIL import Image
# --- Import path setup -------------------------------------------------------
# Folder holding this file; a "stylegan3" directory is expected next to it.
current_dir = pathlib.Path(__file__).parent
submodule_dir = current_dir.joinpath("stylegan3")
# Give that directory top priority on the import path (presumably so the
# pickled StyleGAN3 generators can import their defining modules -- confirm).
sys.path.insert(0, submodule_dir.as_posix())

# --- Hugging Face hosting ----------------------------------------------------
# Account name used as the prefix of every model repository id.
user = "ellemac"

# --- DCGAN configuration -----------------------------------------------------
dcgan_z_dim = 100  # size of the latent vector fed to the generator
dcgan_gen_feats = ngf = 64  # base feature-map width (ngf: legacy alias)
dcgan_img_size = 64  # height/width of the generated images
nc = 3  # number of image channels
# class Generator(nn.Module):
# def __init__(self, ngpu, nz):
# super(Generator, self).__init__()
# self.ngpu = ngpu
# self.main = nn.Sequential(
# # input is Z, going into a convolution
# nn.ConvTranspose2d( nz, ngf * 8, 4, 1, 0, bias=False),
# nn.BatchNorm2d(ngf * 8),
# nn.LeakyReLU(0.2, inplace=True),
# # state size. (ngf*8) x 4 x 4
# nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
# nn.BatchNorm2d(ngf * 4),
# nn.LeakyReLU(0.2, inplace=True),
# # state size. (ngf*4) x 8 x 8
# nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
# nn.BatchNorm2d(ngf * 2),
# nn.LeakyReLU(0.2, inplace=True),
# # state size. (ngf*2) x 16 x 16
# nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
# nn.BatchNorm2d(ngf),
# nn.LeakyReLU(0.2, inplace=True),
# # state size. (ngf) x 32 x 32
# nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False),
# nn.Tanh()
# # state size. (nc) x 64 x 64
# )
# def forward(self, input):
# return self.main(input)
class Generator(nn.Module):
    """DCGAN-style generator built from transposed convolutions.

    Four "conv-transpose -> batch-norm -> leaky-ReLU" stages are followed by
    a final transposed convolution and a tanh.  For a latent input of shape
    ``(N, z_dim, 1, 1)`` the output has shape ``(N, n_channels, 64, 64)``
    with values in ``[-1, 1]``.

    Args:
        n_gen_feats: Base number of feature maps; the stages use 8x, 4x, 2x
            and 1x this width.
        n_gpu: Stored on the instance as ``n_gpu``; not otherwise used here.
        z_dim: Number of channels of the latent input.
        n_channels: Number of channels of the generated image.
    """

    def __init__(self, n_gen_feats, n_gpu, z_dim, n_channels):
        super().__init__()
        self.n_gpu = n_gpu

        layers = []
        in_ch = z_dim
        # Spatial size after each stage: 4x4, 8x8, 16x16, 32x32.
        for stage, multiplier in enumerate((8, 4, 2, 1)):
            out_ch = n_gen_feats * multiplier
            # The first stage expands the 1x1 latent to 4x4; later stages
            # double the resolution.
            stride, padding = (1, 0) if stage == 0 else (2, 1)
            layers.append(
                nn.ConvTranspose2d(in_ch, out_ch, 4, stride, padding, bias=False)
            )
            layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            in_ch = out_ch

        # Final upsampling to 64x64 and squashing into [-1, 1].
        layers.append(nn.ConvTranspose2d(in_ch, n_channels, 4, 2, 1, bias=False))
        layers.append(nn.Tanh())

        # Positional Sequential keeps the parameter names "main.<index>.*".
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Run the latent batch ``input`` through the layer stack."""
        return self.main(input)
class Model:
    """Load the hosted generators and render images from them.

    One generator is active at a time (``self.model`` / ``self.model_name``).
    Keys containing ``"stylegan"`` are loaded from pickled checkpoints and
    rendered with :meth:`generate_image`; every other key is treated as a
    DCGAN state dict and rendered with :meth:`dcgan_generate_image`.
    """

    # Model key -> checkpoint file name and the repository (under ``user``)
    # that hosts it.
    MODEL_DICT = {
        "stylegan3-abstract": {"name": "abstract-560eps.pkl", "repo": "avantStyleGAN3"},
        "stylegan3-high-fidelity": {"name": "high-fidelity-1120eps.pkl", "repo": "avantStyleGAN3"},
        "ada-dcgan": {"name": "gen_6kepoch.pt", "repo": "avantGAN"},
    }

    def __init__(self):
        """Pick a device, pre-fetch all checkpoints and load the default model."""
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self._download_all_models()
        self.model_name = "ada-dcgan"  # alternative default: "stylegan3-abstract"
        self.model = self._load_model(self.model_name)

    def _download(self, model_name: str) -> str:
        """Fetch the checkpoint for ``model_name`` and return its local path.

        Raises:
            KeyError: If ``model_name`` is not a key of ``MODEL_DICT``.
        """
        spec = self.MODEL_DICT[model_name]
        return hf_hub_download(f"{user}/{spec['repo']}", spec["name"])  # model repo-type

    def _load_model(self, model_name: str) -> nn.Module:
        """Build the generator for ``model_name`` in eval mode on ``self.device``.

        Raises:
            KeyError: If ``model_name`` is not a key of ``MODEL_DICT``.
        """
        path = self._download(model_name)
        if "stylegan" in model_name:
            # SECURITY: unpickling executes arbitrary code from the file; only
            # point MODEL_DICT at repositories you trust.
            with open(path, "rb") as f:
                model = pickle.load(f)["G_ema"]
        else:
            # todo (elle): don't hardcode the config
            model = Generator(dcgan_gen_feats, 1, dcgan_z_dim, 3)
            # NOTE(review): torch.load also unpickles -- same trust caveat.
            model.load_state_dict(torch.load(path, map_location=self.device))
        model.eval()
        model.to(self.device)
        return model

    def set_model(self, model_name: str) -> None:
        """Make ``model_name`` the active model; no-op if it already is.

        The new model is loaded before any attribute is touched, so a failed
        load leaves ``model_name`` and ``model`` consistent with each other.
        """
        if model_name == self.model_name:
            return
        model = self._load_model(model_name)
        self.model_name = model_name
        self.model = model

    def _download_all_models(self):
        """Download every checkpoint listed in ``MODEL_DICT``.

        Only fetches the files; nothing is deserialised here, so a broken
        checkpoint surfaces when that model is first loaded.
        """
        for name in self.MODEL_DICT:
            self._download(name)

    @staticmethod
    def make_transform(translate: tuple[float, float] = (0,0), angle: float = 0) -> np.ndarray:
        """Return a 3x3 matrix combining a rotation and a translation.

        Args:
            translate: ``(tx, ty)`` written into the last column.
            angle: Rotation angle in degrees.
        """
        mat = np.eye(3)
        sin = np.sin(angle / 360 * np.pi * 2)
        cos = np.cos(angle / 360 * np.pi * 2)
        mat[0][0] = cos
        mat[0][1] = sin
        mat[0][2] = translate[0]
        mat[1][0] = -sin
        mat[1][1] = cos
        mat[1][2] = translate[1]
        return mat

    def generate_z(self, seed: int) -> torch.Tensor:
        """Return a ``(1, z_dim)`` float32 latent drawn from ``seed``.

        ``seed`` is clamped to the uint32 range; ``z_dim`` is read from the
        active model.
        """
        seed = int(np.clip(seed, 0, np.iinfo(np.uint32).max))
        z = np.random.RandomState(seed).randn(1, self.model.z_dim)
        return torch.from_numpy(z).float().to(self.device)

    def postprocess(self, tensor: torch.Tensor) -> np.ndarray:
        """Convert a 4-D image batch to a channels-last uint8 NumPy array.

        Axis 1 is moved to the end and values are mapped with
        ``x * 127.5 + 128`` and clamped to ``[0, 255]`` (input assumed to be
        roughly in ``[-1, 1]`` -- TODO confirm).
        """
        tensor = (tensor.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
        return tensor.cpu().numpy()

    def set_transform(self, tx: float = 0, ty: float = 0, angle: float = 0) -> None:
        """Write the inverse of the requested transform into the active model.

        Reads ``self.model.synthesis.input.transform``, which the DCGAN
        ``Generator`` does not define, so this only works for models that
        expose that attribute.
        """
        mat = self.make_transform((tx, ty), angle)
        mat = np.linalg.inv(mat)
        self.model.synthesis.input.transform.copy_(torch.from_numpy(mat))

    @torch.inference_mode()
    def generate(self, z: torch.Tensor, label: torch.Tensor, truncation_psi: float) -> torch.Tensor:
        """Call the active model on ``z`` and ``label`` under inference mode."""
        return self.model(z, label, truncation_psi=truncation_psi)

    def generate_image(self, seed: int, truncation_psi: float = 0, tx: float = 0, ty: float = 0, angle: float = 0) -> np.ndarray:
        """Render one image with the active model as a uint8 array.

        Applies the transform, draws the latent for ``seed``, uses an
        all-zero label of width ``self.model.c_dim`` and returns the first
        image of the post-processed batch.
        """
        self.set_transform(tx, ty, angle)
        z = self.generate_z(seed)
        label = torch.zeros([1, self.model.c_dim], device=self.device)
        out = self.generate(z, label, truncation_psi)
        out = self.postprocess(out)
        return out[0]

    def dcgan_generate_image(self, seed: int) -> "Image.Image":
        """Render one DCGAN sample for ``seed`` as a PIL image.

        The sample is drawn with Matplotlib and saved to an in-memory PNG,
        which is returned opened as a PIL image.  Seeds torch's global RNG.
        """
        torch.manual_seed(seed)
        # torch.device never equals the bare string "cuda"; test the type.
        if self.device.type == "cuda":
            torch.cuda.manual_seed(seed)

        with torch.no_grad():
            n_images = 1
            z = torch.randn(n_images, dcgan_z_dim, 1, 1, device=self.device)
            fake_images = self.model(z).cpu()
        fake_images = fake_images.view(fake_images.size(0), 3, dcgan_img_size, dcgan_img_size)

        # Create a grid of images
        grid = vutils.make_grid(fake_images, normalize=True)

        # Plot the grid and save it to a buffer.  The figure's own methods
        # are used instead of pyplot's global "current figure", and the
        # figure is always closed so repeated calls do not accumulate figures.
        fig, ax = plt.subplots()
        try:
            ax.imshow(grid.permute(1, 2, 0))  # Convert from CHW to HWC for imshow
            ax.axis("off")
            buf = BytesIO()
            fig.savefig(buf, format="png")
        finally:
            plt.close(fig)
        buf.seek(0)

        # Load the buffer into a PIL Image
        return Image.open(buf)

    def set_model_and_generate_image(
        self, model_name: str, seed: int, truncation_psi: float = 0, tx: float = 0, ty: float = 0, angle: float = 0
    ) -> "np.ndarray | Image.Image":
        """Activate ``model_name`` and render one image for ``seed``.

        Returns a uint8 array for ``"stylegan3"`` models and a PIL image for
        the DCGAN; the extra arguments are only used by the former.
        """
        self.set_model(model_name)
        if "stylegan3" in model_name:
            return self.generate_image(seed, truncation_psi, tx, ty, angle)
        return self.dcgan_generate_image(seed)
|