# lpn / inference.py
# Last commit f6ee8cd by clement-bonnet: "fix: x y coordinates"
import os
import sys
sys.path.append("..")
from PIL import Image
import matplotlib.pyplot as plt
import hydra
import omegaconf
import jax
import jax.numpy as jnp
import optax
from flax.training.train_state import TrainState
from flax.serialization import from_bytes
from huggingface_hub import snapshot_download
# lpn imports
from src.models.lpn import LPN
from src.models.transformer import EncoderTransformer, DecoderTransformer
from src.visualization import display_grid
from utils import patch_target, ax_to_pil
# Folder of the checkpoint (config.yaml + state.msgpack) inside the
# `clement-bonnet/lpn-2d` Hugging Face repository.
checkpoint_name = "quiet-thunder-789--checkpoint:v0"
# Example image index -> flat (row-major) index of the single cell set to 1 in
# the 4x4 input grid built by `generate_image`. Indices 1 and 2 share the same
# location.
BLUE_LOCATION_INPUTS = dict(enumerate((1, 5, 5, 10)))
# Download only the files belonging to the selected checkpoint.
local_dir = snapshot_download(repo_id="clement-bonnet/lpn-2d", allow_patterns=f"{checkpoint_name}/*")
# Hand OmegaConf the path rather than a text-mode handle opened with the
# platform's default encoding: given a path, OmegaConf opens the file as UTF-8.
cfg = omegaconf.OmegaConf.load(os.path.join(local_dir, checkpoint_name, "config.yaml"))
patch_target(cfg)
# Build the encoder/decoder from their config sections and assemble the LPN.
encoder = EncoderTransformer(hydra.utils.instantiate(cfg.encoder_transformer))
decoder = DecoderTransformer(hydra.utils.instantiate(cfg.decoder_transformer))
lpn = LPN(encoder=encoder, decoder=decoder)
# Initialise the model on a random dummy batch. The resulting params serve as
# the template into which `from_bytes` later restores the checkpoint weights.
key = jax.random.PRNGKey(0)
grids = jax.random.randint(
    key,
    shape=(1, 3, decoder.config.max_rows, decoder.config.max_cols, 2),
    minval=0,
    maxval=decoder.config.vocab_size,
)
# Per-grid (rows, cols) sizes in [1, min(max_rows, max_cols)].
shapes = jax.random.randint(
    key,
    shape=(1, 3, 2, 2),
    minval=1,
    maxval=1 + min(decoder.config.max_rows, decoder.config.max_cols),
)
variables = lpn.init(
    key,
    grids,
    shapes,
    dropout_eval=False,
    prior_kl_coeff=0.0,
    pairwise_kl_coeff=0.0,
    mode="mean",
)
# Template TrainState for `from_bytes`. The optimizer uses a zero learning rate
# and nothing in this script steps it. Presumably it only has to reproduce the
# optimizer-state structure saved in the checkpoint (TODO confirm against the
# training config).
learning_rate, linear_warmup_steps = 0, 0
linear_warmup_scheduler = optax.warmup_exponential_decay_schedule(
    init_value=learning_rate / (linear_warmup_steps + 1),
    peak_value=learning_rate,
    warmup_steps=linear_warmup_steps,
    transition_steps=1,
    end_value=learning_rate,
    decay_rate=1.0,
)
optimizer = optax.MultiSteps(
    optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.adamw(linear_warmup_scheduler),
    ),
    every_k_schedule=1,
)
train_state = TrainState.create(
    apply_fn=lpn.apply,
    params=variables["params"],
    tx=optimizer,
)
# Read the msgpack-serialised state and restore it into the template.
with open(os.path.join(local_dir, checkpoint_name, "state.msgpack"), mode="rb") as checkpoint_file:
    byte_data = checkpoint_file.read()
loaded_state = from_bytes(train_state, byte_data)
@jax.jit
def generate_output_from_context(context, input, input_grid_shape):
    """Apply ``lpn._generate_output_from_context`` with the restored weights.

    The checkpoint parameters (``loaded_state.params``) are bound here and the
    model runs with ``dropout_eval=True``. The return value is whatever that
    method returns; ``generate_image`` unpacks it as ``(output_grids, _)``.

    ``input`` shadows the builtin but keeps this name because it is passed by
    keyword.
    """
    return lpn.apply(
        {"params": loaded_state.params},
        context=context,
        input=input,
        input_grid_shape=input_grid_shape,
        dropout_eval=True,
        method=lpn._generate_output_from_context,
    )
def generate_image(image_idx: int, x: float, y: float, eps: float = 1e-4) -> Image.Image:
    """Decode an output grid from a 2-D latent point and render it as an image.

    The point ``(x, y)`` of the unit square is mapped to a latent context in R^2
    via the standard-normal quantile function (``norm.ppf``). The model then
    generates the output for a 4x4 input grid in which a single cell is set to 1.

    Args:
        image_idx: Key of ``BLUE_LOCATION_INPUTS`` selecting which input cell is set.
        x: Horizontal coordinate, clamped to ``[eps, 1 - eps]``.
        y: Vertical coordinate. It is flipped to ``1 - y``, then clamped to
            ``[eps, 1 - eps]``.
        eps: Margin that keeps the probabilities off 0 and 1, where ``norm.ppf``
            is infinite.

    Returns:
        The rendered output grid as a PIL image.

    Raises:
        KeyError: If ``image_idx`` is not a key of ``BLUE_LOCATION_INPUTS``.
    """
    grid_side = 4
    grid_shape = jnp.array([grid_side, grid_side])
    # Single-cell input grid: set one flat (row-major) index, then reshape to 4x4.
    input_grid = (
        jnp.zeros(grid_side * grid_side, int)
        .at[BLUE_LOCATION_INPUTS[image_idx]]
        .set(1)
        .reshape(grid_side, grid_side)
    )
    # Flip y. Presumably the caller's y axis points downward (pixel convention)
    # while the latent axis points upward. TODO confirm.
    y = 1 - y
    # Keep both coordinates strictly inside (0, 1) so the quantiles stay finite.
    x = min(1 - eps, max(eps, x))
    y = min(1 - eps, max(eps, y))
    context = jax.scipy.stats.norm.ppf(jnp.array([x, y]))
    # Add a leading batch axis of size 1 to every argument.
    output_grids, _ = generate_output_from_context(
        context=context[None], input=input_grid[None], input_grid_shape=grid_shape[None]
    )
    output_grid = output_grids[0]
    fig, ax = plt.subplots(1, 1, figsize=(4, 4))
    try:
        display_grid(ax=ax, grid=output_grid, grid_shape=grid_shape)
        return ax_to_pil(ax)
    finally:
        # Without this, every call leaves an open figure in pyplot's registry.
        plt.close(fig)