File size: 3,305 Bytes
999b913
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3e506b8
999b913
3e506b8
999b913
 
 
3e506b8
999b913
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3e506b8
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import os
import sys

sys.path.append("..")

from PIL import Image
import matplotlib.pyplot as plt
import hydra
import omegaconf
import jax
import jax.numpy as jnp
import optax
from flax.training.train_state import TrainState
from flax.serialization import from_bytes
from huggingface_hub import snapshot_download

# lpn imports
from src.models.lpn import LPN
from src.models.transformer import EncoderTransformer, DecoderTransformer
from src.visualization import display_grid

from utils import patch_target, ax_to_pil


# Module-level setup: download a checkpoint from the Hugging Face Hub, rebuild
# the LPN model from its saved config, restore the serialized train state and
# expose a jitted function that decodes an output grid from a latent context.
# Everything below runs at import time (including the download).

# Folder name of the checkpoint inside the Hub repository.
checkpoint_name = "quiet-thunder-789--checkpoint:v0"
# Maps an image index to the flat index (in a 16-cell, 4x4 grid) of the single
# cell that `generate_image` sets to 1.
# NOTE(review): the name suggests value 1 is rendered as blue — confirm in display_grid.
BLUE_LOCATION_INPUTS = {0: 13, 1: 9}

# Only fetch the files that live under the checkpoint folder.
local_dir = snapshot_download(repo_id="clement-bonnet/lpn-2d", allow_patterns=f"{checkpoint_name}/*")
with open(f"{local_dir}/{checkpoint_name}/config.yaml", "r") as f:
    cfg = omegaconf.OmegaConf.load(f)
# NOTE(review): patch_target comes from utils and is not visible here; presumably
# it rewrites the config so hydra can instantiate it in this project layout — confirm.
patch_target(cfg)

# Build the encoder/decoder from their instantiated configs and wrap them in LPN.
encoder = EncoderTransformer(hydra.utils.instantiate(cfg.encoder_transformer))
decoder = DecoderTransformer(hydra.utils.instantiate(cfg.decoder_transformer))
lpn = LPN(encoder=encoder, decoder=decoder)

# Dummy random inputs, used only to call `lpn.init` below. The same PRNG key is
# reused for the grids, the shapes and the init call.
# NOTE(review): axis meaning of `grids` / `shapes` is not shown here — presumably
# (batch, pairs, rows, cols, input/output) and (batch, pairs, rows/cols, input/output); confirm.
key = jax.random.PRNGKey(0)
grids = jax.random.randint(
    key,
    (1, 3, decoder.config.max_rows, decoder.config.max_cols, 2),
    minval=0,
    maxval=decoder.config.vocab_size,
)
shapes = jax.random.randint(
    key,
    (1, 3, 2, 2),
    minval=1,
    maxval=min(decoder.config.max_rows, decoder.config.max_cols) + 1,
)
variables = lpn.init(
    key, grids, shapes, dropout_eval=False, prior_kl_coeff=0.0, pairwise_kl_coeff=0.0, mode="mean"
)
# Optimizer with a learning rate of 0: this script never trains, the optimizer
# is only needed to create a TrainState that serves as the restore template.
# NOTE(review): the chain (clip -> adamw -> MultiSteps) presumably mirrors the
# training setup so the serialized optimizer state matches — confirm before changing.
learning_rate, linear_warmup_steps = 0, 0
linear_warmup_scheduler = optax.warmup_exponential_decay_schedule(
    init_value=learning_rate / (linear_warmup_steps + 1),
    peak_value=learning_rate,
    warmup_steps=linear_warmup_steps,
    transition_steps=1,
    end_value=learning_rate,
    decay_rate=1.0,
)
optimizer = optax.chain(optax.clip_by_global_norm(1.0), optax.adamw(linear_warmup_scheduler))
optimizer = optax.MultiSteps(optimizer, every_k_schedule=1)
train_state = TrainState.create(apply_fn=lpn.apply, tx=optimizer, params=variables["params"])

# Restore the checkpoint into the structure of `train_state`; only
# `loaded_state.params` is used afterwards.
with open(os.path.join(local_dir, checkpoint_name, "state.msgpack"), "rb") as data_file:
    byte_data = data_file.read()
loaded_state = from_bytes(train_state, byte_data)

# Jitted wrapper around `LPN._generate_output_from_context` that closes over the
# restored parameters and runs with `dropout_eval=True`. It is called with the
# keyword arguments `context`, `input` and `input_grid_shape`.
generate_output_from_context = jax.jit(
    lambda context, input, input_grid_shape: lpn.apply(
        {"params": loaded_state.params},
        context=context,
        input=input,
        input_grid_shape=input_grid_shape,
        dropout_eval=True,
        method=lpn._generate_output_from_context,
    )
)


def generate_image(image_idx: int, x: float, y: float, eps: float = 1e-4) -> Image.Image:
    """Decode the output grid for one latent point and render it as an image.

    The point ``(x, y)`` is clamped to ``[eps, 1 - eps]`` on each axis and
    mapped through the inverse standard-normal CDF to obtain a 2-D latent
    context, which is decoded together with a fixed 4x4 input grid.

    Args:
        image_idx: Key into ``BLUE_LOCATION_INPUTS`` selecting which cell of
            the 4x4 input grid is set to 1.
        x: First coordinate of the latent point, expected in ``[0, 1]``.
        y: Second coordinate of the latent point, expected in ``[0, 1]``.
        eps: Margin kept away from 0 and 1 when clamping ``x`` and ``y``.

    Returns:
        The rendered output grid, as returned by ``ax_to_pil``.

    Raises:
        KeyError: If ``image_idx`` is not a key of ``BLUE_LOCATION_INPUTS``.
    """
    # 4x4 input grid: all zeros except a single 1 at the registered flat index.
    input_grid = jnp.zeros(16, int).at[BLUE_LOCATION_INPUTS[image_idx]].set(1).reshape(4, 4)
    grid_shape = jnp.array([4, 4])

    # Clamp to [eps, 1 - eps]: the normal quantile function diverges at 0 and 1.
    x = min(1 - eps, max(eps, x))
    y = min(1 - eps, max(eps, y))
    # Convert the clamped point to a context in R^2.
    context = jax.scipy.stats.norm.ppf(jnp.array([x, y]))

    # The jitted wrapper expects a leading batch axis on every argument.
    output_grids, _ = generate_output_from_context(
        context=context[None], input=input_grid[None], input_grid_shape=grid_shape[None]
    )

    fig, ax = plt.subplots(1, 1, figsize=(4, 4))
    try:
        display_grid(ax=ax, grid=output_grids[0], grid_shape=grid_shape)
        return ax_to_pil(ax)
    finally:
        # pyplot keeps every figure it creates alive until it is closed;
        # without this, each call would leave one more open figure behind.
        # Closing a figure that is already closed is a no-op.
        plt.close(fig)