Spaces:
Running
Running
import os | |
import sys | |
sys.path.append("..") | |
from PIL import Image | |
import matplotlib.pyplot as plt | |
import hydra | |
import omegaconf | |
import jax | |
import jax.numpy as jnp | |
import optax | |
from flax.training.train_state import TrainState | |
from flax.serialization import from_bytes | |
from huggingface_hub import snapshot_download | |
# lpn imports | |
from src.models.lpn import LPN | |
from src.models.transformer import EncoderTransformer, DecoderTransformer | |
from src.visualization import display_grid | |
from utils import patch_target, ax_to_pil | |
# Directory name of the checkpoint inside the Hugging Face repository.
checkpoint_name = "quiet-thunder-789--checkpoint:v0"

# For each selectable input index, the flat position (in a 16-cell, 4x4 grid)
# of the single cell that `generate_image` sets to 1.
BLUE_LOCATION_INPUTS = {0: 1, 1: 5, 2: 5, 3: 10}

# Fetch only the files that belong to this checkpoint.
local_dir = snapshot_download(
    repo_id="clement-bonnet/lpn-2d",
    allow_patterns=f"{checkpoint_name}/*",
)

# Load the config stored alongside the checkpoint.
with open(f"{local_dir}/{checkpoint_name}/config.yaml", "r") as config_file:
    cfg = omegaconf.OmegaConf.load(config_file)
# Project helper from `utils`, applied to the config before instantiation.
patch_target(cfg)

# Instantiate the encoder/decoder configs with hydra and assemble the LPN model.
encoder_config = hydra.utils.instantiate(cfg.encoder_transformer)
decoder_config = hydra.utils.instantiate(cfg.decoder_transformer)
encoder = EncoderTransformer(encoder_config)
decoder = DecoderTransformer(decoder_config)
lpn = LPN(encoder=encoder, decoder=decoder)
# Initialise the model on dummy inputs to obtain a parameter tree with the
# right structure. Inference later reads parameters from `loaded_state`, so the
# random values drawn here only serve as a structural template.
key = jax.random.PRNGKey(0)
# JAX PRNG keys are single-use: derive one independent subkey per consumer
# instead of passing the same key to every call.
grids_key, shapes_key, init_key = jax.random.split(key, 3)

# Random grid tokens in [0, vocab_size).
# NOTE(review): axes presumably are (batch, pairs, rows, cols, input/output) — confirm.
grids = jax.random.randint(
    grids_key,
    (1, 3, decoder.config.max_rows, decoder.config.max_cols, 2),
    minval=0,
    maxval=decoder.config.vocab_size,
)
# Random grid shapes in [1, min(max_rows, max_cols)].
shapes = jax.random.randint(
    shapes_key,
    (1, 3, 2, 2),
    minval=1,
    maxval=min(decoder.config.max_rows, decoder.config.max_cols) + 1,
)
variables = lpn.init(
    init_key,
    grids,
    shapes,
    dropout_eval=False,
    prior_kl_coeff=0.0,
    pairwise_kl_coeff=0.0,
    mode="mean",
)
# Build an optimizer and TrainState whose only role here is to act as the
# target structure when the serialized state is restored with `from_bytes`.
# The learning rate is zero, so the schedule is constant at 0.
learning_rate = 0
linear_warmup_steps = 0
linear_warmup_scheduler = optax.warmup_exponential_decay_schedule(
    init_value=learning_rate / (linear_warmup_steps + 1),
    peak_value=learning_rate,
    warmup_steps=linear_warmup_steps,
    transition_steps=1,
    end_value=learning_rate,
    decay_rate=1.0,
)

# Gradient clipping followed by AdamW, wrapped in MultiSteps (k=1).
# NOTE(review): presumably mirrors the optimizer layout used at training time
# so the saved optimizer state matches — confirm against the training script.
clipped_adamw = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adamw(linear_warmup_scheduler),
)
optimizer = optax.MultiSteps(clipped_adamw, every_k_schedule=1)

train_state = TrainState.create(
    apply_fn=lpn.apply,
    tx=optimizer,
    params=variables["params"],
)
# Read the serialized training state and restore it, using `train_state` as
# the structural template for deserialization.
state_path = os.path.join(local_dir, checkpoint_name, "state.msgpack")
with open(state_path, "rb") as data_file:
    byte_data = data_file.read()
loaded_state = from_bytes(train_state, byte_data)
def _generate_output_from_context(context, input, input_grid_shape):
    """Run the LPN's `_generate_output_from_context` method with the loaded parameters.

    `loaded_state.params` is taken from module scope rather than passed in.
    The parameter names (including `input`) are part of the calling
    convention: callers pass them by keyword.
    """
    return lpn.apply(
        {"params": loaded_state.params},
        context=context,
        input=input,
        input_grid_shape=input_grid_shape,
        dropout_eval=True,
        method=lpn._generate_output_from_context,
    )


# JIT-compiled entry point used for inference.
generate_output_from_context = jax.jit(_generate_output_from_context)
def generate_image(image_idx: int, x: float, y: float, eps: float = 1e-4) -> Image.Image:
    """Render the model output for one preset input grid at a chosen latent point.

    Args:
        image_idx: Key into `BLUE_LOCATION_INPUTS` selecting which cell of the
            4x4 input grid is set to 1.
        x: Horizontal coordinate, clamped to [eps, 1 - eps].
        y: Vertical coordinate; it is flipped (1 - y) and then clamped to
            [eps, 1 - eps].
        eps: Margin keeping the coordinates away from 0 and 1, where the
            normal inverse CDF is infinite.

    Returns:
        A PIL image of the generated output grid.

    Raises:
        KeyError: If `image_idx` is not a key of `BLUE_LOCATION_INPUTS`.
    """
    # Create the input image: a 4x4 grid of zeros with a single cell set to 1.
    input_grid = jnp.zeros(16, int).at[BLUE_LOCATION_INPUTS[image_idx]].set(1).reshape(4, 4)
    grid_shape = jnp.array([4, 4])

    # Inverse the y coordinate
    y = 1 - y
    # Ensure x and y are in [eps, 1 - eps]
    x = min(1 - eps, max(eps, x))
    y = min(1 - eps, max(eps, y))
    # Convert x and y to context in R^2 via the standard normal inverse CDF
    context = jax.scipy.stats.norm.ppf(jnp.array([x, y]))

    # `[None]` adds a leading axis of size 1 to every argument.
    output_grids, _ = generate_output_from_context(
        context=context[None], input=input_grid[None], input_grid_shape=grid_shape[None]
    )
    output_grid = output_grids[0]

    fig, ax = plt.subplots(1, 1, figsize=(4, 4))
    try:
        display_grid(ax=ax, grid=output_grid, grid_shape=grid_shape)
        return ax_to_pil(ax)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly, so
        # release it here even if rendering fails. Closing a figure that a
        # helper has already closed is a no-op.
        plt.close(fig)