|
import numpy as np |
|
import torch |
|
import copy |
|
import os |
|
import numpy as np |
|
from sklearn import svm |
|
|
|
# Raise TensorFlow's native (C++) log threshold to its quietest level before
# anything else runs. NOTE(review): nothing visible here imports TensorFlow;
# presumably a transitively imported dependency does — confirm it is still needed.
os.environ.update(TF_CPP_MIN_LOG_LEVEL="3")
|
|
|
|
|
def linear_interpolate(latent_code, boundary, start_distance=-3, end_distance=3, steps=10):
    """Walk a latent code along the normal of a semantic boundary.

    Produces ``steps`` edited copies of ``latent_code``. Their offsets along
    ``boundary`` are evenly spaced from ``start_distance`` to ``end_distance``
    (inclusive, via ``np.linspace``). Distances are signed.

    Two input layouts are accepted:

    * ``[1, latent_space_dim]``: the code's current projection onto
      ``boundary`` is removed first. Output row ``i`` therefore has a dot
      product with ``boundary`` equal to the ``i``-th distance (exactly so
      when ``boundary`` has unit norm). Output shape is
      ``[steps, latent_space_dim]``.
    * ``[1, num_layers, latent_space_dim]`` (e.g. StyleGAN W+ space): every
      layer is shifted by the same offset, measured from the code's current
      position. The current projection is NOT removed in this layout. Output
      shape is ``[steps, num_layers, latent_space_dim]``.

    Args:
        latent_code: Array holding a single latent code (see layouts above).
            ``.dot`` is called on it, so presumably a NumPy array.
        boundary: Array of shape ``[1, latent_space_dim]``, expected to be
            unit-norm.
        start_distance: Offset of the first output. (default: -3)
        end_distance: Offset of the last output. (default: 3)
        steps: Number of outputs to generate. (default: 10)

    Returns:
        Array of manipulated codes, one per step along axis 0.

    Raises:
        AssertionError: If either leading dimension is not 1, ``boundary`` is
            not 2-D, or the trailing dimensions disagree. This is checked with
            ``assert``, so the check is skipped under ``python -O``.
        ValueError: If ``latent_code`` is neither 2-D nor 3-D.
    """
    assert (
        latent_code.shape[0] == 1
        and boundary.shape[0] == 1
        and len(boundary.shape) == 2
        and boundary.shape[1] == latent_code.shape[-1]
    )

    distances = np.linspace(start_distance, end_distance, steps)
    rank = len(latent_code.shape)

    if rank == 2:
        # The current signed distance has shape [1, 1] and broadcasts to
        # [1, steps]. Subtracting it makes each row land at the requested
        # distance.
        offsets = (distances - latent_code.dot(boundary.T)).reshape(-1, 1)
        return latent_code + offsets.astype(np.float32) * boundary

    if rank == 3:
        offsets = distances.reshape(-1, 1, 1).astype(np.float32)
        return latent_code + offsets * boundary.reshape(1, 1, -1)

    raise ValueError(
        f"Input `latent_code` should be with shape "
        f"[1, latent_space_dim] or [1, N, latent_space_dim] for "
        f"W+ space in Style GAN!\n"
        f"But {latent_code.shape} is received."
    )
|
|
|
|
|
def get_code(domain, boundaries):
    """Sample a fresh random latent code.

    Args:
        domain: Style domain name (e.g. "ink", "monet", "vangogh", "water").
            Accepted for interface compatibility only; it does not influence
            the sample.
        boundaries: Accepted for interface compatibility only; unused.

    Returns:
        A ``torch.float32`` tensor of shape ``[1, 256]`` drawn from a standard
        normal distribution. It is placed on the default CUDA device when one
        is available, otherwise on the CPU.
        NOTE(review): 256 presumably must match the boundary dimensionality —
        confirm.
    """
    # Sample directly in float32. There is no need to convert through NumPy,
    # and the RNG is consumed exactly once, as before.
    code = torch.randn(1, 256, dtype=torch.float32)
    return code.cuda() if torch.cuda.is_available() else code
|
|
|
|
|
def modify_code(code, boundaries, domain, range):
    """Push a latent code a signed distance along one domain's boundary.

    Args:
        code: Torch tensor holding one latent code. It is moved to the CPU and
            converted to NumPy before editing. Presumably shaped
            ``[1, latent_dim]`` or ``[1, N, latent_dim]`` — TODO confirm.
        boundaries: Integer-indexable collection of boundary arrays, e.g. the
            list returned by ``load_boundries``.
        domain: Domain name ("ink", "monet", "vangogh", "water"), mapped to
            indices 0-3. Any other value is used as the index as-is.
        range: Target signed distance along the boundary. ``0`` short-circuits
            and returns ``code`` itself, unchanged and not copied or moved.
            (Shadows the builtin ``range``; the name is kept for callers that
            pass it by keyword.)

    Returns:
        ``code`` itself when ``range == 0``. Otherwise a new float tensor with
        only the final interpolation step (distance ``range``), placed on CUDA
        when available.
    """
    # Compare in the same order as the names were originally listed; the first
    # match wins.
    for index, name in enumerate(("ink", "monet", "vangogh", "water")):
        if domain == name:
            domain = index
            break

    if range == 0:
        return code

    latent = np.array(code.cpu().detach().numpy())
    # Only the last of the 3 steps (which sits at end_distance == range) is
    # kept, so start_distance has no effect on the result.
    trajectory = linear_interpolate(latent, boundaries[domain], end_distance=range, steps=3)
    edited = torch.Tensor(trajectory[-1:])
    return edited.cuda() if torch.cuda.is_available() else edited
|
|
|
|
|
def load_boundries():
    """Load the per-domain boundary arrays stored next to this module.

    Each array is read with ``np.load`` from
    ``<module dir>/boundaries_amp_52/artwork_<domain>_boundary/boundary.npy``.

    Returns:
        A list of arrays in alphabetical domain order: ink, monet, vangogh,
        water. That order is what makes indices 0-3 line up with the
        name→index mapping used by ``modify_code``.
    """
    here = os.path.dirname(__file__)
    # Sorting pins each domain's list index regardless of how the names are
    # written above.
    names = sorted(["ink", "monet", "vangogh", "water"])
    return [
        np.load(os.path.join(here, f"boundaries_amp_52/artwork_{name}_boundary/boundary.npy"))
        for name in names
    ]
|
|