sky24h's picture
init_commit
f3daba8
import numpy as np
import torch
import copy
import os
import numpy as np
from sklearn import svm
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
def linear_interpolate(latent_code, boundary, start_distance=-3, end_distance=3, steps=10):
    """Move a latent code along a boundary direction in evenly spaced steps.

    `steps` signed distances are sampled uniformly from `start_distance` to
    `end_distance` (both ends included) and one shifted copy of `latent_code`
    is produced per distance.

    * 2-D input, shape [1, latent_space_dim]: the code's current projection
      onto `boundary` is subtracted from every distance first, so for a
      unit-norm `boundary` the i-th output sits at the i-th sampled distance
      along `boundary`. Output shape is [steps, latent_space_dim].
    * 3-D input, shape [1, num_layers, latent_space_dim] (W+ space in
      Style GAN): every layer receives the same shift `distance * boundary`.
      Here the sampled distances act as offsets relative to the input code;
      its projection is not removed. Output shape is
      [steps, num_layers, latent_space_dim].

    NOTE: Distances are signed.

    Args:
      latent_code: The latent code to manipulate; its first dimension must
        be 1.
      boundary: The reference direction, with shape [1, latent_space_dim].
      start_distance: Distance used for the first output. (default: -3)
      end_distance: Distance used for the last output. (default: 3)
      steps: Number of outputs to generate. (default: 10)

    Returns:
      An array holding the `steps` manipulated latent codes.

    Raises:
      AssertionError: If the leading dimension of `latent_code` or `boundary`
        is not 1, `boundary` is not 2-D, or the last dimensions of the two
        inputs differ.
      ValueError: If `latent_code` passes the checks above but is neither
        2-D nor 3-D.
    """
    # Same checks, in the same order, as a single chained assertion.
    assert latent_code.shape[0] == 1
    assert boundary.shape[0] == 1
    assert len(boundary.shape) == 2
    assert boundary.shape[1] == latent_code.shape[-1]

    distances = np.linspace(start_distance, end_distance, steps)
    rank = len(latent_code.shape)

    if rank not in (2, 3):
        raise ValueError(
            f"Input `latent_code` should be with shape "
            f"[1, latent_space_dim] or [1, N, latent_space_dim] for "
            f"W+ space in Style GAN!\n"
            f"But {latent_code.shape} is received."
        )

    if rank == 2:
        # Remove the code's existing projection so the distances are absolute.
        projection = latent_code.dot(boundary.T)
        offsets = (distances - projection).reshape(-1, 1).astype(np.float32)
        return latent_code + offsets * boundary

    # rank == 3: broadcast the same shift over every layer.
    offsets = distances.reshape(-1, 1, 1).astype(np.float32)
    direction = boundary.reshape(1, 1, -1)
    return latent_code + offsets * direction
def get_code(domain, boundaries):
    """Sample a fresh random latent code.

    The code is a [1, 256] float32 tensor drawn with `torch.randn`. It is
    moved to the GPU when CUDA is available, otherwise it is returned as
    created.

    Args:
      domain: Currently unused. Kept so existing callers keep working; the
        sampled code does not depend on it.
      boundaries: Currently unused, for the same reason as `domain`.

    Returns:
      A torch tensor of shape [1, 256] and dtype float32.
    """
    # Sample first and move afterwards, so the values for a given seed stay
    # the same whether or not a GPU is present.
    code = torch.randn(1, 256, dtype=torch.float32)
    return code.cuda() if torch.cuda.is_available() else code
def modify_code(code, boundaries, domain, range):
    """Shift a latent code along the boundary of one domain.

    When `range` equals 0 the input `code` object is returned untouched.
    Otherwise the code is converted to a NumPy array, passed to
    `linear_interpolate` with `end_distance=range` and `steps=3`, and only
    the last interpolation step is kept. The result is wrapped in a torch
    tensor and moved to the GPU when CUDA is available.

    Args:
      code: Torch tensor holding the latent code.
      boundaries: Indexable collection of boundary arrays; the entry selected
        by `domain` is used.
      domain: One of "ink", "monet", "vangogh", "water" (mapped to indices
        0-3), or any other value, which is used as the index as-is.
      range: End distance for the interpolation; 0 means "no change".

    Returns:
      `code` itself when `range` equals 0, otherwise a new torch tensor with
      the shifted code.
    """
    if range == 0:
        return code

    # Only known domain names are translated; everything else (for example
    # an integer index) is passed through unchanged.
    name_to_index = {"ink": 0, "monet": 1, "vangogh": 2, "water": 3}
    if isinstance(domain, str):
        domain = name_to_index.get(domain, domain)

    latent = np.array(code.cpu().detach().numpy())
    interpolated = linear_interpolate(latent, boundaries[domain], end_distance=range, steps=3)
    shifted = torch.Tensor(interpolated[-1:])
    if torch.cuda.is_available():
        return shifted.cuda()
    return shifted
def load_boundries():
    """Load the boundary array of every supported domain from disk.

    For each of the domains "ink", "monet", "vangogh" and "water", the file
    `boundaries_amp_52/artwork_<domain>_boundary/boundary.npy` is read with
    `np.load`, resolved relative to the directory containing this module.

    Returns:
      A list with one loaded array per domain, ordered by alphabetically
      sorted domain name.
    """
    module_dir = os.path.dirname(__file__)
    loaded = []
    for domain in sorted(["ink", "monet", "vangogh", "water"]):
        relative_path = "boundaries_amp_52/artwork_" + domain + "_boundary/boundary.npy"
        loaded.append(np.load(os.path.join(module_dir, relative_path)))
    return loaded