"""Load a pretrained LPN checkpoint and render decoder outputs as images.

At import time this module downloads a checkpoint from the Hugging Face Hub,
rebuilds the model, restores its parameters and exposes ``generate_image``,
which maps a 2D point to a decoded 4x4 output grid rendered as a PIL image.
"""

import os
import sys

# Makes the parent of the current working directory importable.
# NOTE(review): presumably needed so that the `src` package resolves - confirm.
sys.path.append("..")

from PIL import Image
import matplotlib.pyplot as plt
import hydra
import omegaconf
import jax
import jax.numpy as jnp
import optax
from flax.training.train_state import TrainState
from flax.serialization import from_bytes
from huggingface_hub import snapshot_download

# lpn imports
from src.models.lpn import LPN
from src.models.transformer import EncoderTransformer, DecoderTransformer
from src.visualization import display_grid
from utils import patch_target, ax_to_pil


checkpoint_name = "quiet-thunder-789--checkpoint:v0"

# Maps an image index to the flat index (within a 4x4 grid) of the single cell set to 1.
BLUE_LOCATION_INPUTS = {0: 1, 1: 5, 2: 5, 3: 10}

# Only the files of the selected checkpoint are fetched from the Hub repository.
local_dir = snapshot_download(
    repo_id="clement-bonnet/lpn-2d", allow_patterns=f"{checkpoint_name}/*"
)
with open(f"{local_dir}/{checkpoint_name}/config.yaml", "r") as f:
    cfg = omegaconf.OmegaConf.load(f)
# NOTE(review): `patch_target` is defined in `utils`; it is applied to the config before
# instantiation - presumably it rewrites hydra `_target_` paths. Confirm against `utils`.
patch_target(cfg)

encoder = EncoderTransformer(hydra.utils.instantiate(cfg.encoder_transformer))
decoder = DecoderTransformer(hydra.utils.instantiate(cfg.decoder_transformer))
lpn = LPN(encoder=encoder, decoder=decoder)

# Random placeholder inputs: they are only passed to `lpn.init` to build the parameter
# structure, the actual parameter values are replaced by the checkpoint below.
key = jax.random.PRNGKey(0)
grids = jax.random.randint(
    key,
    (1, 3, decoder.config.max_rows, decoder.config.max_cols, 2),
    minval=0,
    maxval=decoder.config.vocab_size,
)
shapes = jax.random.randint(
    key,
    (1, 3, 2, 2),
    minval=1,
    maxval=min(decoder.config.max_rows, decoder.config.max_cols) + 1,
)
variables = lpn.init(
    key, grids, shapes, dropout_eval=False, prior_kl_coeff=0.0, pairwise_kl_coeff=0.0, mode="mean"
)

# The optimizer is never stepped here (learning rate is 0); it is built so that the
# `TrainState` template has an optimizer state for `from_bytes` to restore into.
# NOTE(review): presumably this mirrors the training-time optimizer layout - confirm.
learning_rate, linear_warmup_steps = 0, 0
linear_warmup_scheduler = optax.warmup_exponential_decay_schedule(
    init_value=learning_rate / (linear_warmup_steps + 1),
    peak_value=learning_rate,
    warmup_steps=linear_warmup_steps,
    transition_steps=1,
    end_value=learning_rate,
    decay_rate=1.0,
)
optimizer = optax.chain(optax.clip_by_global_norm(1.0), optax.adamw(linear_warmup_scheduler))
optimizer = optax.MultiSteps(optimizer, every_k_schedule=1)
train_state = TrainState.create(apply_fn=lpn.apply, tx=optimizer, params=variables["params"])

# Restore the serialized training state, using `train_state` as the structural template.
with open(os.path.join(local_dir, checkpoint_name, "state.msgpack"), "rb") as data_file:
    byte_data = data_file.read()
loaded_state = from_bytes(train_state, byte_data)

# Jitted decoder call bound to the restored parameters. Callers pass the arguments by
# keyword, so the parameter names (including `input`) are part of the interface.
generate_output_from_context = jax.jit(
    lambda context, input, input_grid_shape: lpn.apply(
        {"params": loaded_state.params},
        context=context,
        input=input,
        input_grid_shape=input_grid_shape,
        dropout_eval=True,
        method=lpn._generate_output_from_context,
    )
)


def generate_image(image_idx: int, x: float, y: float, eps: float = 1e-4) -> Image.Image:
    """Render the decoder output for one input grid and one 2D point.

    Args:
        image_idx: Key into ``BLUE_LOCATION_INPUTS`` selecting which cell of the
            4x4 input grid is set to 1.
        x: Horizontal coordinate, clamped to ``[eps, 1 - eps]``.
        y: Vertical coordinate; it is flipped (``1 - y``) and then clamped to
            ``[eps, 1 - eps]``.
        eps: Clamping margin that keeps the coordinates strictly inside (0, 1),
            where the normal quantile function is finite.

    Returns:
        The rendered output grid as a PIL image.

    Raises:
        KeyError: If ``image_idx`` is not a key of ``BLUE_LOCATION_INPUTS``.
    """
    # Create the input image: all zeros except a single 1 at the selected flat index.
    input_grid = jnp.zeros(16, int).at[BLUE_LOCATION_INPUTS[image_idx]].set(1).reshape(4, 4)
    # Inverse the y coordinate
    y = 1 - y
    # Ensure x and y are in [eps, 1 - eps]
    x = min(1 - eps, max(eps, x))
    y = min(1 - eps, max(eps, y))
    # Convert x and y to context in R^2
    context = jax.scipy.stats.norm.ppf(jnp.array([x, y]))
    # `[None]` adds a leading axis of size 1 to every argument.
    output_grids, _ = generate_output_from_context(
        context=context[None], input=input_grid[None], input_grid_shape=jnp.array([4, 4])[None]
    )
    output_grid = output_grids[0]
    fig, ax = plt.subplots(1, 1, figsize=(4, 4))
    try:
        display_grid(ax=ax, grid=output_grid, grid_shape=jnp.array([4, 4]))
        return ax_to_pil(ax)
    finally:
        # pyplot keeps every figure registered until it is closed explicitly; without
        # this, each call would leak one figure in a long-running process.
        plt.close(fig)