lpn / inference.py
clement-bonnet's picture
fix: pil image generation
3e506b8
raw
history blame
3.31 kB
import os
import sys
sys.path.append("..")
from PIL import Image
import matplotlib.pyplot as plt
import hydra
import omegaconf
import jax
import jax.numpy as jnp
import optax
from flax.training.train_state import TrainState
from flax.serialization import from_bytes
from huggingface_hub import snapshot_download
# lpn imports
from src.models.lpn import LPN
from src.models.transformer import EncoderTransformer, DecoderTransformer
from src.visualization import display_grid
from utils import patch_target, ax_to_pil
# Name of the checkpoint directory inside the Hugging Face repository.
checkpoint_name = "quiet-thunder-789--checkpoint:v0"
# Maps an image index to the flat index (within a 16-cell grid) of the single cell
# that `generate_image` sets to 1 when it builds its 4x4 input grid.
BLUE_LOCATION_INPUTS = {0: 13, 1: 9}

# Download only this checkpoint's files; the returned path is used below as the
# directory that contains the checkpoint folder.
local_dir = snapshot_download(
    repo_id="clement-bonnet/lpn-2d",
    allow_patterns=f"{checkpoint_name}/*",
)

# Decode the YAML explicitly as UTF-8 instead of relying on the platform's locale
# encoding, so the config parses the same way on every OS.
with open(f"{local_dir}/{checkpoint_name}/config.yaml", "r", encoding="utf-8") as f:
    cfg = omegaconf.OmegaConf.load(f)
# NOTE(review): patch_target comes from the local `utils` module and its return value
# is discarded, so it presumably adjusts `cfg` in place before instantiation -- confirm.
patch_target(cfg)
# Instantiate each transformer's config object from the loaded config, then
# assemble the encoder/decoder pair into the LPN model.
encoder_config = hydra.utils.instantiate(cfg.encoder_transformer)
encoder = EncoderTransformer(encoder_config)
decoder_config = hydra.utils.instantiate(cfg.decoder_transformer)
decoder = DecoderTransformer(decoder_config)
lpn = LPN(encoder=encoder, decoder=decoder)

# Placeholder inputs for `lpn.init`. The same key is reused for every call: the
# resulting parameters only serve as the template that the checkpoint is restored
# into further down, so their actual values are never used for inference.
key = jax.random.PRNGKey(0)
max_rows = decoder.config.max_rows
max_cols = decoder.config.max_cols
grids = jax.random.randint(
    key,
    (1, 3, max_rows, max_cols, 2),
    0,
    decoder.config.vocab_size,
)
shapes = jax.random.randint(
    key,
    (1, 3, 2, 2),
    1,
    min(max_rows, max_cols) + 1,
)
variables = lpn.init(
    key,
    grids,
    shapes,
    dropout_eval=False,
    prior_kl_coeff=0.0,
    pairwise_kl_coeff=0.0,
    mode="mean",
)
# The learning rate is zero: no training happens in this script. The optimizer is
# built only so that a TrainState can be created as the target that the serialized
# checkpoint is restored into.
# NOTE(review): presumably this mirrors the optimizer layout used at training time -- confirm.
learning_rate = 0
linear_warmup_steps = 0
linear_warmup_scheduler = optax.warmup_exponential_decay_schedule(
    init_value=learning_rate / (linear_warmup_steps + 1),
    peak_value=learning_rate,
    warmup_steps=linear_warmup_steps,
    transition_steps=1,
    decay_rate=1.0,
    end_value=learning_rate,
)
gradient_transform = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adamw(linear_warmup_scheduler),
)
optimizer = optax.MultiSteps(gradient_transform, every_k_schedule=1)
train_state = TrainState.create(
    apply_fn=lpn.apply,
    params=variables["params"],
    tx=optimizer,
)

# Read the serialized state and deserialize it using `train_state` as the template.
checkpoint_path = os.path.join(local_dir, checkpoint_name, "state.msgpack")
with open(checkpoint_path, "rb") as data_file:
    byte_data = data_file.read()
loaded_state = from_bytes(train_state, byte_data)
@jax.jit
def generate_output_from_context(context, input, input_grid_shape):
    """Run the LPN's `_generate_output_from_context` with the restored parameters.

    The parameter names (including `input`, which shadows the builtin) are kept
    as-is because `generate_image` passes all three by keyword.

    Args:
        context: forwarded unchanged to the model method.
        input: forwarded unchanged to the model method.
        input_grid_shape: forwarded unchanged to the model method.

    Returns:
        Whatever `LPN._generate_output_from_context` returns; `generate_image`
        unpacks it as a pair and uses the first element as the output grids.
    """
    restored_variables = {"params": loaded_state.params}
    return lpn.apply(
        restored_variables,
        context=context,
        input=input,
        input_grid_shape=input_grid_shape,
        dropout_eval=True,
        method=lpn._generate_output_from_context,
    )
def generate_image(image_idx: int, x: float, y: float, eps: float = 1e-4) -> Image.Image:
    """Render the model output for one input grid and a latent picked from the unit square.

    Args:
        image_idx: key into ``BLUE_LOCATION_INPUTS`` selecting which cell of the
            4x4 input grid is set to 1.
        x: first coordinate, clamped to ``[eps, 1 - eps]``.
        y: second coordinate, clamped to ``[eps, 1 - eps]``.
        eps: clamping margin; keeps the inverse normal CDF away from 0 and 1,
            where it would return infinite values.

    Returns:
        The value returned by ``ax_to_pil`` for the axes the output grid was drawn on.

    Raises:
        KeyError: if ``image_idx`` is not a key of ``BLUE_LOCATION_INPUTS``.
    """
    # Create the input image: 16 zeros with one cell set to 1, viewed as a 4x4 grid.
    input_grid = jnp.zeros(16, int).at[BLUE_LOCATION_INPUTS[image_idx]].set(1).reshape(4, 4)
    grid_shape = jnp.array([4, 4])
    # Ensure x and y are in [eps, 1 - eps]
    x = min(1 - eps, max(eps, x))
    y = min(1 - eps, max(eps, y))
    # Convert x and y to context in R^2 through the standard normal inverse CDF
    context = jax.scipy.stats.norm.ppf(jnp.array([x, y]))
    # `[None]` adds the leading batch axis of size 1 to every argument.
    output_grids, _ = generate_output_from_context(
        context=context[None], input=input_grid[None], input_grid_shape=grid_shape[None]
    )
    output_grid = output_grids[0]
    fig, ax = plt.subplots(1, 1, figsize=(4, 4))
    try:
        display_grid(ax=ax, grid=output_grid, grid_shape=grid_shape)
        pil_image = ax_to_pil(ax)
    finally:
        # pyplot keeps every figure it creates alive until it is closed; without
        # this, each call would leave one more figure in memory. Closing is a
        # no-op if the figure has already been closed by a helper.
        plt.close(fig)
    return pil_image