# TADNE (This Anime Does Not Exist) model

The original TADNE site is https://thisanimedoesnotexist.ai/.

## Original TensorFlow model

The original TADNE model is provided on [this page](https://www.gwern.net/Faces#tadne-download) under a CC0 license. ([Google Drive](https://drive.google.com/file/d/1A-E_E32WAtTHRlOzjhhYhyyBDXLJN9_H))
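
The pickle can also be fetched from the command line; as one option (not part of the original instructions), the `gdown` package (`pip install gdown`) handles Google Drive downloads:

```bash
# Hypothetical convenience step: download the original TF pickle from Google Drive.
# The conversion step below expects the file aydao-anime-danbooru2019s-512-5268480.pkl.
gdown https://drive.google.com/uc?id=1A-E_E32WAtTHRlOzjhhYhyyBDXLJN9_H
```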

## Model Conversion

The model in the `models` directory was converted with the following repository:

https://github.com/rosinality/stylegan2-pytorch

### Apply patches

The TADNE model is wider than a standard config-f network: its channel counts and latent dimension are doubled. The first patch adds an `additional_multiplier` option to `model.py` to support this and switches the truncation check so that truncation is applied whenever a truncation latent is passed. The second patch teaches `convert_weight.py` to build the generator with the wider latent.
```diff
--- a/model.py
+++ b/model.py
@@ -395,6 +395,7 @@ class Generator(nn.Module):
         style_dim,
         n_mlp,
         channel_multiplier=2,
+        additional_multiplier=2,
         blur_kernel=[1, 3, 3, 1],
         lr_mlp=0.01,
     ):
@@ -426,6 +427,9 @@ class Generator(nn.Module):
             512: 32 * channel_multiplier,
             1024: 16 * channel_multiplier,
         }
+        if additional_multiplier > 1:
+            for k in list(self.channels.keys()):
+                self.channels[k] *= additional_multiplier
 
         self.input = ConstantInput(self.channels[4])
         self.conv1 = StyledConv(
@@ -518,7 +522,7 @@ class Generator(nn.Module):
             getattr(self.noises, f"noise_{i}") for i in range(self.num_layers)
         ]
 
-        if truncation < 1:
+        if truncation_latent is not None:
             style_t = []
 
             for style in styles:
```
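
For intuition, here is a small standalone sketch (not part of the repository) of what the patched channel table works out to under the TADNE settings, `channel_multiplier=2` and `additional_multiplier=2`:

```python
# Channel table from rosinality's Generator with the patch above applied.
channel_multiplier = 2
additional_multiplier = 2  # TADNE doubles every width

channels = {
    4: 512, 8: 512, 16: 512, 32: 512,
    64: 256 * channel_multiplier,
    128: 128 * channel_multiplier,
    256: 64 * channel_multiplier,
    512: 32 * channel_multiplier,
    1024: 16 * channel_multiplier,
}
if additional_multiplier > 1:
    for k in list(channels.keys()):
        channels[k] *= additional_multiplier

print(channels[4], channels[512])        # 1024 128
style_dim = 512 * additional_multiplier  # the latent is likewise 1024-dimensional
```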

```diff
--- a/convert_weight.py
+++ b/convert_weight.py
@@ -221,6 +221,7 @@ if __name__ == "__main__":
         default=2,
         help="channel multiplier factor. config-f = 2, else = 1",
     )
+    parser.add_argument("--additional_multiplier", type=int, default=2)
     parser.add_argument("path", metavar="PATH", help="path to the tensorflow weights")
 
     args = parser.parse_args()
@@ -243,7 +244,8 @@ if __name__ == "__main__":
         if layer[0].startswith('Dense'):
             n_mlp += 1
 
-    g = Generator(size, 512, n_mlp, channel_multiplier=args.channel_multiplier)
+    style_dim = 512 * args.additional_multiplier
+    g = Generator(size, style_dim, n_mlp, channel_multiplier=args.channel_multiplier, additional_multiplier=args.additional_multiplier)
     state_dict = g.state_dict()
     state_dict = fill_statedict(state_dict, g_ema.vars, size, n_mlp)
 
@@ -254,7 +256,7 @@ if __name__ == "__main__":
     ckpt = {"g_ema": state_dict, "latent_avg": latent_avg}
 
     if args.gen:
-        g_train = Generator(size, 512, n_mlp, channel_multiplier=args.channel_multiplier)
+        g_train = Generator(size, style_dim, n_mlp, channel_multiplier=args.channel_multiplier, additional_multiplier=args.additional_multiplier)
         g_train_state = g_train.state_dict()
         g_train_state = fill_statedict(g_train_state, generator.vars, size, n_mlp)
         ckpt["g"] = g_train_state
@@ -271,9 +273,12 @@ if __name__ == "__main__":
     batch_size = {256: 16, 512: 9, 1024: 4}
     n_sample = batch_size.get(size, 25)
 
+    if args.additional_multiplier > 1:
+        n_sample = 2
+
     g = g.to(device)
 
-    z = np.random.RandomState(0).randn(n_sample, 512).astype("float32")
+    z = np.random.RandomState(0).randn(n_sample, style_dim).astype("float32")
 
     with torch.no_grad():
         img_pt, _ = g(
```

### Build Docker image

Conversion needs both TensorFlow 1.15 (to unpickle the original weights) and PyTorch, so the image below installs both on a CUDA 10.0 base.

```dockerfile
FROM nvidia/cuda:10.0-cudnn7-devel-ubuntu18.04

ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get update -y && \
    apt-get install -y --no-install-recommends \
        git \
        ninja-build \
        # pyenv dependencies
        make \
        build-essential \
        libssl-dev \
        zlib1g-dev \
        libbz2-dev \
        libreadline-dev \
        libsqlite3-dev \
        wget \
        curl \
        llvm \
        libncursesw5-dev \
        xz-utils \
        tk-dev \
        libxml2-dev \
        libxmlsec1-dev \
        libffi-dev \
        liblzma-dev && \
    apt-get clean && \
    rm -rf /var/lib/apt/lists/*

ARG PYTHON_VERSION=3.7.12
ENV PYENV_ROOT /opt/pyenv
ENV PATH ${PYENV_ROOT}/shims:${PYENV_ROOT}/bin:${PATH}
RUN curl https://pyenv.run | bash
RUN pyenv install ${PYTHON_VERSION} && \
    pyenv global ${PYTHON_VERSION}
RUN pip install --no-cache-dir -U requests tqdm opencv-python-headless
RUN pip install --no-cache-dir -U tensorflow-gpu==1.15.4
RUN pip install --no-cache-dir -U torch==1.10.2+cu102 torchvision==0.11.3+cu102 -f https://download.pytorch.org/whl/torch/ -f https://download.pytorch.org/whl/torchvision/
RUN rm -rf ${HOME}/.cache/pip

WORKDIR /work
ENV PYTHONPATH /work/:${PYTHONPATH}
```

```bash
docker build . -t stylegan2_pytorch
```

### Convert

```bash
git clone https://github.com/NVLabs/stylegan2
docker run --rm -it -u $(id -u):$(id -g) -e XDG_CACHE_HOME=/work --ipc host --gpus all -w /work -v $(pwd):/work stylegan2_pytorch python convert_weight.py --repo stylegan2 aydao-anime-danbooru2019s-512-5268480.pkl
```
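
Conversion writes `aydao-anime-danbooru2019s-512-5268480.pt`. A quick sanity check of the result (a sketch; the key names come from the patched `convert_weight.py` above):

```python
import torch

ckpt = torch.load("aydao-anime-danbooru2019s-512-5268480.pt", map_location="cpu")
print(list(ckpt.keys()))         # ['g_ema', 'latent_avg']
print(ckpt["latent_avg"].shape)  # expected to match the doubled 1024-dim latent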
```

## Usage

### Apply patch

The following patch to `generate.py` splits sampling into smaller batches (so large sample grids fit in GPU memory), adds seeding, output-directory, and grid-layout options, and supports the stored average latent for truncation.
```diff
--- a/generate.py
+++ b/generate.py
@@ -6,21 +6,25 @@ from model import Generator
 from tqdm import tqdm
 
 
-def generate(args, g_ema, device, mean_latent):
+def generate(args, g_ema, device, mean_latent, randomize_noise):
 
     with torch.no_grad():
         g_ema.eval()
         for i in tqdm(range(args.pics)):
-            sample_z = torch.randn(args.sample, args.latent, device=device)
+            samples = []
+            for _ in range(args.split):
+                sample_z = torch.randn(args.sample // args.split, args.latent, device=device)
 
-            sample, _ = g_ema(
-                [sample_z], truncation=args.truncation, truncation_latent=mean_latent
-            )
+                sample, _ = g_ema(
+                    [sample_z], truncation=args.truncation, truncation_latent=mean_latent,
+                    randomize_noise=randomize_noise
+                )
+                samples.extend(sample)
 
             utils.save_image(
-                sample,
-                f"sample/{str(i).zfill(6)}.png",
-                nrow=1,
+                samples,
+                f"{args.output_dir}/{str(i).zfill(6)}.{args.ext}",
+                nrow=args.ncol,
                 normalize=True,
                 range=(-1, 1),
             )
@@ -30,6 +34,8 @@ if __name__ == "__main__":
     device = "cuda"
 
     parser = argparse.ArgumentParser(description="Generate samples from the generator")
+    parser.add_argument("--seed", type=int, default=0)
+    parser.add_argument("--output-dir", '-o', type=str, required=True)
 
     parser.add_argument(
         "--size", type=int, default=1024, help="output image size of the generator"
@@ -37,11 +43,14 @@ if __name__ == "__main__":
     parser.add_argument(
         "--sample",
         type=int,
-        default=1,
+        default=100,
         help="number of samples to be generated for each image",
     )
+    parser.add_argument("--ncol", type=int, default=10)
+    parser.add_argument("--split", type=int, default=4)
+    parser.add_argument("--ext", type=str, default='png')
     parser.add_argument(
-        "--pics", type=int, default=20, help="number of images to be generated"
+        "--pics", type=int, default=1, help="number of images to be generated"
    )
     parser.add_argument("--truncation", type=float, default=1, help="truncation ratio")
     parser.add_argument(
@@ -62,23 +71,31 @@ if __name__ == "__main__":
         default=2,
         help="channel multiplier of the generator. config-f = 2, else = 1",
     )
+    parser.add_argument("--additional_multiplier", type=int, default=1)
+    parser.add_argument("--load_latent_vec", action='store_true')
+    parser.add_argument("--no-randomize-noise", dest='randomize_noise', action='store_false')
+    parser.add_argument("--n_mlp", type=int, default=8)
 
     args = parser.parse_args()
 
-    args.latent = 512
-    args.n_mlp = 8
+    seed = args.seed
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+
+    args.latent = 512 * args.additional_multiplier
 
     g_ema = Generator(
-        args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
+        args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier,
+        additional_multiplier=args.additional_multiplier
     ).to(device)
     checkpoint = torch.load(args.ckpt)
 
-    g_ema.load_state_dict(checkpoint["g_ema"])
+    g_ema.load_state_dict(checkpoint["g_ema"], strict=True)
 
-    if args.truncation < 1:
+    if not args.load_latent_vec:
         with torch.no_grad():
             mean_latent = g_ema.mean_latent(args.truncation_mean)
     else:
-        mean_latent = None
+        mean_latent = checkpoint['latent_avg'].to(device)
 
-    generate(args, g_ema, device, mean_latent)
+    generate(args, g_ema, device, mean_latent, randomize_noise=args.randomize_noise)
```
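
For programmatic use, the patched pieces fit together roughly as follows; this is a sketch under the TADNE settings from the Run command below, not a verbatim excerpt:

```python
import torch
from model import Generator  # the patched model.py

device = "cuda"
# size=512, style_dim=512*2, n_mlp=4, doubled channels, as in the Run command.
g_ema = Generator(512, 1024, 4, channel_multiplier=2, additional_multiplier=2).to(device)
ckpt = torch.load("aydao-anime-danbooru2019s-512-5268480.pt")
g_ema.load_state_dict(ckpt["g_ema"])
g_ema.eval()

with torch.no_grad():
    z = torch.randn(1, 1024, device=device)
    img, _ = g_ema(
        [z],
        truncation=0.6,
        truncation_latent=ckpt["latent_avg"].to(device),
        randomize_noise=False,
    )  # img: (1, 3, 512, 512), values in [-1, 1]
```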

### Run

```bash
python generate.py --ckpt aydao-anime-danbooru2019s-512-5268480.pt --size 512 --n_mlp 4 --additional_multiplier 2 --load_latent_vec --no-randomize-noise -o out_images --truncation 0.6 --seed 333 --pics 1 --sample 48 --ncol 8 --ext jpg
```
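
With these options the script draws 48 samples at truncation 0.6 (in 4 batches of 12, per the default `--split 4`) and saves them as a single 8-column grid to `out_images/000000.jpg`. Note that the patch above does not create the output directory, so `out_images` should exist beforehand.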