pcuenca commited on
Commit
5bf185b
2 Parent(s): 0497ad3 5afc82f

Merge pull request #39 from borisdayma/chore-cleanup2

Browse files

refactor: reorganize files
Former-commit-id: 10c15512cf9e4095046bcd8dda622b4ad8fd8b80

README.md CHANGED
@@ -14,11 +14,11 @@ _Generate images from a text prompt_
14
 
15
  TODO: add some cool example
16
 
17
- ## [Create my own images with the demo →](TODO)
18
 
19
  ## How does it work?
20
 
21
- Refer to [our report](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA).
22
 
23
  ## Development
24
 
@@ -54,6 +54,8 @@ Use [patil-suraj/vqgan-jax](https://github.com/patil-suraj/vqgan-jax).
54
 
55
  Refer to `seq2seq` folder (some parameters may have been hardcoded for convenience when training on our TPU VM).
56
 
 
 
57
  ### Inference
58
 
59
  Refer to the demo notebooks.
 
14
 
15
  TODO: add some cool example
16
 
17
+ ## Create my own images with the demo → Coming soon
18
 
19
  ## How does it work?
20
 
21
+ Refer to [our report](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA?accessToken=2ua7j8ebc810fuxyv49wbipmq3fb2e78yq3rvs5dy4wew07wwm2csdo8zcuyr14e).
22
 
23
  ## Development
24
 
 
54
 
55
  Refer to `seq2seq` folder (some parameters may have been hardcoded for convenience when training on our TPU VM).
56
 
57
+ You can also adjust the [sweep configuration file](https://docs.wandb.ai/guides/sweeps) if you need to perform a hyperparameter search.
58
+
59
  ### Inference
60
 
61
  Refer to the demo notebooks.
app/app_gradio.py CHANGED
@@ -12,74 +12,28 @@ import flax.linen as nn
12
  from flax.training.common_utils import shard
13
  from flax.jax_utils import replicate, unreplicate
14
 
15
- from transformers.models.bart.modeling_flax_bart import *
16
  from transformers import BartTokenizer, FlaxBartForConditionalGeneration
17
 
18
-
19
- import requests
20
  from PIL import Image
21
  import numpy as np
22
  import matplotlib.pyplot as plt
23
 
24
 
25
  from dalle_mini.vqgan_jax.modeling_flax_vqgan import VQModel
 
26
 
27
  import gradio as gr
28
 
29
 
30
- # TODO: set those args in a config file
31
- OUTPUT_VOCAB_SIZE = 16384 + 1 # encoded image token space + 1 for bos
32
- OUTPUT_LENGTH = 256 + 1 # number of encoded tokens + 1 for bos
33
- BOS_TOKEN_ID = 16384
34
- BASE_MODEL = 'flax-community/dalle-mini'
35
-
36
- class CustomFlaxBartModule(FlaxBartModule):
37
- def setup(self):
38
- # we keep shared to easily load pre-trained weights
39
- self.shared = nn.Embed(
40
- self.config.vocab_size,
41
- self.config.d_model,
42
- embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
43
- dtype=self.dtype,
44
- )
45
- # a separate embedding is used for the decoder
46
- self.decoder_embed = nn.Embed(
47
- OUTPUT_VOCAB_SIZE,
48
- self.config.d_model,
49
- embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
50
- dtype=self.dtype,
51
- )
52
- self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
53
-
54
- # the decoder has a different config
55
- decoder_config = BartConfig(self.config.to_dict())
56
- decoder_config.max_position_embeddings = OUTPUT_LENGTH
57
- decoder_config.vocab_size = OUTPUT_VOCAB_SIZE
58
- self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)
59
-
60
- class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):
61
- def setup(self):
62
- self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)
63
- self.lm_head = nn.Dense(
64
- OUTPUT_VOCAB_SIZE,
65
- use_bias=False,
66
- dtype=self.dtype,
67
- kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
68
- )
69
- self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, OUTPUT_VOCAB_SIZE))
70
-
71
- class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):
72
- module_class = CustomFlaxBartForConditionalGenerationModule
73
-
74
- # create our model
75
- # FIXME: Save tokenizer to hub so we can load from there
76
- tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
77
- model = CustomFlaxBartForConditionalGeneration.from_pretrained(BASE_MODEL)
78
- model.config.force_bos_token_to_be_generated = False
79
- model.config.forced_bos_token_id = None
80
- model.config.forced_eos_token_id = None
81
-
82
- vqgan = VQModel.from_pretrained("flax-community/vqgan_f16_16384")
83
 
84
  def custom_to_pil(x):
85
  x = np.clip(x, 0., 1.)
 
12
  from flax.training.common_utils import shard
13
  from flax.jax_utils import replicate, unreplicate
14
 
 
15
  from transformers import BartTokenizer, FlaxBartForConditionalGeneration
16
 
 
 
17
  from PIL import Image
18
  import numpy as np
19
  import matplotlib.pyplot as plt
20
 
21
 
22
  from dalle_mini.vqgan_jax.modeling_flax_vqgan import VQModel
23
+ from dalle_mini.model import CustomFlaxBartForConditionalGeneration
24
 
25
  import gradio as gr
26
 
27
 
28
+ DALLE_REPO = 'flax-community/dalle-mini'
29
+ DALLE_COMMIT_ID = '4d34126d0df8bc4a692ae933e3b902a1fa8b6114'
30
+
31
+ VQGAN_REPO = 'flax-community/vqgan_f16_16384'
32
+ VQGAN_COMMIT_ID = '90cc46addd2dd8f5be21586a9a23e1b95aa506a9'
33
+
34
+ tokenizer = BartTokenizer.from_pretrained(DALLE_REPO, revision=DALLE_COMMIT_ID)
35
+ model = CustomFlaxBartForConditionalGeneration.from_pretrained(DALLE_REPO, revision=DALLE_COMMIT_ID)
36
+ vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
  def custom_to_pil(x):
39
  x = np.clip(x, 0., 1.)
app/dalle_mini ADDED
@@ -0,0 +1 @@
 
 
1
+ ../dalle_mini/
app/dalle_mini/__init__.py DELETED
@@ -1 +0,0 @@
1
- __version__ = "0.0.1"
 
 
app/dalle_mini/dataset.py DELETED
@@ -1,122 +0,0 @@
1
- """
2
- An image-caption dataset dataloader.
3
- Luke Melas-Kyriazi, 2021
4
- """
5
- import warnings
6
- from typing import Optional, Callable
7
- from pathlib import Path
8
- import numpy as np
9
- import torch
10
- import pandas as pd
11
- from torch.utils.data import Dataset
12
- from torchvision.datasets.folder import default_loader
13
- from PIL import ImageFile
14
- from PIL.Image import DecompressionBombWarning
15
- ImageFile.LOAD_TRUNCATED_IMAGES = True
16
- warnings.filterwarnings("ignore", category=UserWarning)
17
- warnings.filterwarnings("ignore", category=DecompressionBombWarning)
18
-
19
-
20
- class CaptionDataset(Dataset):
21
- """
22
- A PyTorch Dataset class for (image, texts) tasks. Note that this dataset
23
- returns the raw text rather than tokens. This is done on purpose, because
24
- it's easy to tokenize a batch of text after loading it from this dataset.
25
- """
26
-
27
- def __init__(self, *, images_root: str, captions_path: str, text_transform: Optional[Callable] = None,
28
- image_transform: Optional[Callable] = None, image_transform_type: str = 'torchvision',
29
- include_captions: bool = True):
30
- """
31
- :param images_root: folder where images are stored
32
- :param captions_path: path to csv that maps image filenames to captions
33
- :param image_transform: image transform pipeline
34
- :param text_transform: image transform pipeline
35
- :param image_transform_type: image transform type, either `torchvision` or `albumentations`
36
- :param include_captions: Returns a dictionary with `image`, `text` if `true`; otherwise returns just the images.
37
- """
38
-
39
- # Base path for images
40
- self.images_root = Path(images_root)
41
-
42
- # Load captions as DataFrame
43
- self.captions = pd.read_csv(captions_path, delimiter='\t', header=0)
44
- self.captions['image_file'] = self.captions['image_file'].astype(str)
45
-
46
- # PyTorch transformation pipeline for the image (normalizing, etc.)
47
- self.text_transform = text_transform
48
- self.image_transform = image_transform
49
- self.image_transform_type = image_transform_type.lower()
50
- assert self.image_transform_type in ['torchvision', 'albumentations']
51
-
52
- # Total number of datapoints
53
- self.size = len(self.captions)
54
-
55
- # Return image+captions or just images
56
- self.include_captions = include_captions
57
-
58
- def verify_that_all_images_exist(self):
59
- for image_file in self.captions['image_file']:
60
- p = self.images_root / image_file
61
- if not p.is_file():
62
- print(f'file does not exist: {p}')
63
-
64
- def _get_raw_image(self, i):
65
- image_file = self.captions.iloc[i]['image_file']
66
- image_path = self.images_root / image_file
67
- image = default_loader(image_path)
68
- return image
69
-
70
- def _get_raw_text(self, i):
71
- return self.captions.iloc[i]['caption']
72
-
73
- def __getitem__(self, i):
74
- image = self._get_raw_image(i)
75
- caption = self._get_raw_text(i)
76
- if self.image_transform is not None:
77
- if self.image_transform_type == 'torchvision':
78
- image = self.image_transform(image)
79
- elif self.image_transform_type == 'albumentations':
80
- image = self.image_transform(image=np.array(image))['image']
81
- else:
82
- raise NotImplementedError(f"{self.image_transform_type=}")
83
- return {'image': image, 'text': caption} if self.include_captions else image
84
-
85
- def __len__(self):
86
- return self.size
87
-
88
-
89
- if __name__ == "__main__":
90
- import albumentations as A
91
- from albumentations.pytorch import ToTensorV2
92
- from transformers import AutoTokenizer
93
-
94
- # Paths
95
- images_root = './images'
96
- captions_path = './images-list-clean.tsv'
97
-
98
- # Create transforms
99
- tokenizer = AutoTokenizer.from_pretrained('distilroberta-base')
100
- def tokenize(text):
101
- return tokenizer(text, max_length=32, truncation=True, return_tensors='pt', padding='max_length')
102
- image_transform = A.Compose([
103
- A.Resize(256, 256), A.CenterCrop(256, 256),
104
- A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), ToTensorV2()])
105
-
106
- # Create dataset
107
- dataset = CaptionDataset(
108
- images_root=images_root,
109
- captions_path=captions_path,
110
- image_transform=image_transform,
111
- text_transform=tokenize,
112
- image_transform_type='albumentations')
113
-
114
- # Create dataloader
115
- dataloader = torch.utils.data.DataLoader(dataset, batch_size=2)
116
- batch = next(iter(dataloader))
117
- print({k: (v.shape if isinstance(v, torch.Tensor) else v) for k, v in batch.items()})
118
-
119
- # # (Optional) Check that all the images exist
120
- # dataset = CaptionDataset(images_root=images_root, captions_path=captions_path)
121
- # dataset.verify_that_all_images_exist()
122
- # print('Done')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app/dalle_mini/vqgan_jax/__init__.py DELETED
File without changes
app/dalle_mini/vqgan_jax/configuration_vqgan.py DELETED
@@ -1,40 +0,0 @@
1
- from typing import Tuple
2
-
3
- from transformers import PretrainedConfig
4
-
5
-
6
- class VQGANConfig(PretrainedConfig):
7
- def __init__(
8
- self,
9
- ch: int = 128,
10
- out_ch: int = 3,
11
- in_channels: int = 3,
12
- num_res_blocks: int = 2,
13
- resolution: int = 256,
14
- z_channels: int = 256,
15
- ch_mult: Tuple = (1, 1, 2, 2, 4),
16
- attn_resolutions: int = (16,),
17
- n_embed: int = 1024,
18
- embed_dim: int = 256,
19
- dropout: float = 0.0,
20
- double_z: bool = False,
21
- resamp_with_conv: bool = True,
22
- give_pre_end: bool = False,
23
- **kwargs,
24
- ):
25
- super().__init__(**kwargs)
26
- self.ch = ch
27
- self.out_ch = out_ch
28
- self.in_channels = in_channels
29
- self.num_res_blocks = num_res_blocks
30
- self.resolution = resolution
31
- self.z_channels = z_channels
32
- self.ch_mult = list(ch_mult)
33
- self.attn_resolutions = list(attn_resolutions)
34
- self.n_embed = n_embed
35
- self.embed_dim = embed_dim
36
- self.dropout = dropout
37
- self.double_z = double_z
38
- self.resamp_with_conv = resamp_with_conv
39
- self.give_pre_end = give_pre_end
40
- self.num_resolutions = len(ch_mult)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app/dalle_mini/vqgan_jax/convert_pt_model_to_jax.py DELETED
@@ -1,109 +0,0 @@
1
- import re
2
-
3
- import jax.numpy as jnp
4
- from flax.traverse_util import flatten_dict, unflatten_dict
5
-
6
- import torch
7
-
8
- from modeling_flax_vqgan import VQModel
9
- from configuration_vqgan import VQGANConfig
10
-
11
-
12
- regex = r"\w+[.]\d+"
13
-
14
-
15
- def rename_key(key):
16
- pats = re.findall(regex, key)
17
- for pat in pats:
18
- key = key.replace(pat, "_".join(pat.split(".")))
19
- return key
20
-
21
-
22
- # Adapted from https://github.com/huggingface/transformers/blob/ff5cdc086be1e0c3e2bbad8e3469b34cffb55a85/src/transformers/modeling_flax_pytorch_utils.py#L61
23
- def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
24
- # convert pytorch tensor to numpy
25
- pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
26
-
27
- random_flax_state_dict = flatten_dict(flax_model.params)
28
- flax_state_dict = {}
29
-
30
- remove_base_model_prefix = (flax_model.base_model_prefix not in flax_model.params) and (
31
- flax_model.base_model_prefix in set([k.split(".")[0] for k in pt_state_dict.keys()])
32
- )
33
- add_base_model_prefix = (flax_model.base_model_prefix in flax_model.params) and (
34
- flax_model.base_model_prefix not in set([k.split(".")[0] for k in pt_state_dict.keys()])
35
- )
36
-
37
- # Need to change some parameters name to match Flax names so that we don't have to fork any layer
38
- for pt_key, pt_tensor in pt_state_dict.items():
39
- pt_tuple_key = tuple(pt_key.split("."))
40
-
41
- has_base_model_prefix = pt_tuple_key[0] == flax_model.base_model_prefix
42
- require_base_model_prefix = (flax_model.base_model_prefix,) + pt_tuple_key in random_flax_state_dict
43
-
44
- if remove_base_model_prefix and has_base_model_prefix:
45
- pt_tuple_key = pt_tuple_key[1:]
46
- elif add_base_model_prefix and require_base_model_prefix:
47
- pt_tuple_key = (flax_model.base_model_prefix,) + pt_tuple_key
48
-
49
- # Correctly rename weight parameters
50
- if (
51
- "norm" in pt_key
52
- and (pt_tuple_key[-1] == "bias")
53
- and (pt_tuple_key[:-1] + ("bias",) in random_flax_state_dict)
54
- ):
55
- pt_tensor = pt_tensor[None, None, None, :]
56
- elif (
57
- "norm" in pt_key
58
- and (pt_tuple_key[-1] == "bias")
59
- and (pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict)
60
- ):
61
- pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
62
- pt_tensor = pt_tensor[None, None, None, :]
63
- elif pt_tuple_key[-1] in ["weight", "gamma"] and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict:
64
- pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
65
- pt_tensor = pt_tensor[None, None, None, :]
66
- if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("embedding",) in random_flax_state_dict:
67
- pt_tuple_key = pt_tuple_key[:-1] + ("embedding",)
68
- elif pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4 and pt_tuple_key not in random_flax_state_dict:
69
- # conv layer
70
- pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
71
- pt_tensor = pt_tensor.transpose(2, 3, 1, 0)
72
- elif pt_tuple_key[-1] == "weight" and pt_tuple_key not in random_flax_state_dict:
73
- # linear layer
74
- pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
75
- pt_tensor = pt_tensor.T
76
- elif pt_tuple_key[-1] == "gamma":
77
- pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
78
- elif pt_tuple_key[-1] == "beta":
79
- pt_tuple_key = pt_tuple_key[:-1] + ("bias",)
80
-
81
- if pt_tuple_key in random_flax_state_dict:
82
- if pt_tensor.shape != random_flax_state_dict[pt_tuple_key].shape:
83
- raise ValueError(
84
- f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
85
- f"{random_flax_state_dict[pt_tuple_key].shape}, but is {pt_tensor.shape}."
86
- )
87
-
88
- # also add unexpected weight so that warning is thrown
89
- flax_state_dict[pt_tuple_key] = jnp.asarray(pt_tensor)
90
-
91
- return unflatten_dict(flax_state_dict)
92
-
93
-
94
- def convert_model(config_path, pt_state_dict_path, save_path):
95
- config = VQGANConfig.from_pretrained(config_path)
96
- model = VQModel(config)
97
-
98
- state_dict = torch.load(pt_state_dict_path, map_location="cpu")["state_dict"]
99
- keys = list(state_dict.keys())
100
- for key in keys:
101
- if key.startswith("loss"):
102
- state_dict.pop(key)
103
- continue
104
- renamed_key = rename_key(key)
105
- state_dict[renamed_key] = state_dict.pop(key)
106
-
107
- state = convert_pytorch_state_dict_to_flax(state_dict, model)
108
- model.params = unflatten_dict(state)
109
- model.save_pretrained(save_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app/dalle_mini/vqgan_jax/modeling_flax_vqgan.py DELETED
@@ -1,609 +0,0 @@
1
- # JAX implementation of VQGAN from taming-transformers https://github.com/CompVis/taming-transformers
2
-
3
- from functools import partial
4
- from typing import Tuple
5
- import math
6
-
7
- import jax
8
- import jax.numpy as jnp
9
- import numpy as np
10
- import flax.linen as nn
11
- from flax.core.frozen_dict import FrozenDict
12
-
13
- from transformers.modeling_flax_utils import FlaxPreTrainedModel
14
-
15
- from .configuration_vqgan import VQGANConfig
16
-
17
-
18
- class Upsample(nn.Module):
19
- in_channels: int
20
- with_conv: bool
21
- dtype: jnp.dtype = jnp.float32
22
-
23
- def setup(self):
24
- if self.with_conv:
25
- self.conv = nn.Conv(
26
- self.in_channels,
27
- kernel_size=(3, 3),
28
- strides=(1, 1),
29
- padding=((1, 1), (1, 1)),
30
- dtype=self.dtype,
31
- )
32
-
33
- def __call__(self, hidden_states):
34
- batch, height, width, channels = hidden_states.shape
35
- hidden_states = jax.image.resize(
36
- hidden_states,
37
- shape=(batch, height * 2, width * 2, channels),
38
- method="nearest",
39
- )
40
- if self.with_conv:
41
- hidden_states = self.conv(hidden_states)
42
- return hidden_states
43
-
44
-
45
- class Downsample(nn.Module):
46
- in_channels: int
47
- with_conv: bool
48
- dtype: jnp.dtype = jnp.float32
49
-
50
- def setup(self):
51
- if self.with_conv:
52
- self.conv = nn.Conv(
53
- self.in_channels,
54
- kernel_size=(3, 3),
55
- strides=(2, 2),
56
- padding="VALID",
57
- dtype=self.dtype,
58
- )
59
-
60
- def __call__(self, hidden_states):
61
- if self.with_conv:
62
- pad = ((0, 0), (0, 1), (0, 1), (0, 0)) # pad height and width dim
63
- hidden_states = jnp.pad(hidden_states, pad_width=pad)
64
- hidden_states = self.conv(hidden_states)
65
- else:
66
- hidden_states = nn.avg_pool(hidden_states, window_shape=(2, 2), strides=(2, 2), padding="VALID")
67
- return hidden_states
68
-
69
-
70
- class ResnetBlock(nn.Module):
71
- in_channels: int
72
- out_channels: int = None
73
- use_conv_shortcut: bool = False
74
- temb_channels: int = 512
75
- dropout_prob: float = 0.0
76
- dtype: jnp.dtype = jnp.float32
77
-
78
- def setup(self):
79
- self.out_channels_ = self.in_channels if self.out_channels is None else self.out_channels
80
-
81
- self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-6)
82
- self.conv1 = nn.Conv(
83
- self.out_channels_,
84
- kernel_size=(3, 3),
85
- strides=(1, 1),
86
- padding=((1, 1), (1, 1)),
87
- dtype=self.dtype,
88
- )
89
-
90
- if self.temb_channels:
91
- self.temb_proj = nn.Dense(self.out_channels_, dtype=self.dtype)
92
-
93
- self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-6)
94
- self.dropout = nn.Dropout(self.dropout_prob)
95
- self.conv2 = nn.Conv(
96
- self.out_channels_,
97
- kernel_size=(3, 3),
98
- strides=(1, 1),
99
- padding=((1, 1), (1, 1)),
100
- dtype=self.dtype,
101
- )
102
-
103
- if self.in_channels != self.out_channels_:
104
- if self.use_conv_shortcut:
105
- self.conv_shortcut = nn.Conv(
106
- self.out_channels_,
107
- kernel_size=(3, 3),
108
- strides=(1, 1),
109
- padding=((1, 1), (1, 1)),
110
- dtype=self.dtype,
111
- )
112
- else:
113
- self.nin_shortcut = nn.Conv(
114
- self.out_channels_,
115
- kernel_size=(1, 1),
116
- strides=(1, 1),
117
- padding="VALID",
118
- dtype=self.dtype,
119
- )
120
-
121
- def __call__(self, hidden_states, temb=None, deterministic: bool = True):
122
- residual = hidden_states
123
- hidden_states = self.norm1(hidden_states)
124
- hidden_states = nn.swish(hidden_states)
125
- hidden_states = self.conv1(hidden_states)
126
-
127
- if temb is not None:
128
- hidden_states = hidden_states + self.temb_proj(nn.swish(temb))[:, :, None, None] # TODO: check shapes
129
-
130
- hidden_states = self.norm2(hidden_states)
131
- hidden_states = nn.swish(hidden_states)
132
- hidden_states = self.dropout(hidden_states, deterministic)
133
- hidden_states = self.conv2(hidden_states)
134
-
135
- if self.in_channels != self.out_channels_:
136
- if self.use_conv_shortcut:
137
- residual = self.conv_shortcut(residual)
138
- else:
139
- residual = self.nin_shortcut(residual)
140
-
141
- return hidden_states + residual
142
-
143
-
144
- class AttnBlock(nn.Module):
145
- in_channels: int
146
- dtype: jnp.dtype = jnp.float32
147
-
148
- def setup(self):
149
- conv = partial(
150
- nn.Conv, self.in_channels, kernel_size=(1, 1), strides=(1, 1), padding="VALID", dtype=self.dtype
151
- )
152
-
153
- self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-6)
154
- self.q, self.k, self.v = conv(), conv(), conv()
155
- self.proj_out = conv()
156
-
157
- def __call__(self, hidden_states):
158
- residual = hidden_states
159
- hidden_states = self.norm(hidden_states)
160
-
161
- query = self.q(hidden_states)
162
- key = self.k(hidden_states)
163
- value = self.v(hidden_states)
164
-
165
- # compute attentions
166
- batch, height, width, channels = query.shape
167
- query = query.reshape((batch, height * width, channels))
168
- key = key.reshape((batch, height * width, channels))
169
- attn_weights = jnp.einsum("...qc,...kc->...qk", query, key)
170
- attn_weights = attn_weights * (int(channels) ** -0.5)
171
- attn_weights = nn.softmax(attn_weights, axis=2)
172
-
173
- ## attend to values
174
- value = value.reshape((batch, height * width, channels))
175
- hidden_states = jnp.einsum("...kc,...qk->...qc", value, attn_weights)
176
- hidden_states = hidden_states.reshape((batch, height, width, channels))
177
-
178
- hidden_states = self.proj_out(hidden_states)
179
- hidden_states = hidden_states + residual
180
- return hidden_states
181
-
182
-
183
- class UpsamplingBlock(nn.Module):
184
- config: VQGANConfig
185
- curr_res: int
186
- block_idx: int
187
- dtype: jnp.dtype = jnp.float32
188
-
189
- def setup(self):
190
- if self.block_idx == self.config.num_resolutions - 1:
191
- block_in = self.config.ch * self.config.ch_mult[-1]
192
- else:
193
- block_in = self.config.ch * self.config.ch_mult[self.block_idx + 1]
194
-
195
- block_out = self.config.ch * self.config.ch_mult[self.block_idx]
196
- self.temb_ch = 0
197
-
198
- res_blocks = []
199
- attn_blocks = []
200
- for _ in range(self.config.num_res_blocks + 1):
201
- res_blocks.append(
202
- ResnetBlock(
203
- block_in, block_out, temb_channels=self.temb_ch, dropout_prob=self.config.dropout, dtype=self.dtype
204
- )
205
- )
206
- block_in = block_out
207
- if self.curr_res in self.config.attn_resolutions:
208
- attn_blocks.append(AttnBlock(block_in, dtype=self.dtype))
209
-
210
- self.block = res_blocks
211
- self.attn = attn_blocks
212
-
213
- self.upsample = None
214
- if self.block_idx != 0:
215
- self.upsample = Upsample(block_in, self.config.resamp_with_conv, dtype=self.dtype)
216
-
217
- def __call__(self, hidden_states, temb=None, deterministic: bool = True):
218
- for res_block in self.block:
219
- hidden_states = res_block(hidden_states, temb, deterministic=deterministic)
220
- for attn_block in self.attn:
221
- hidden_states = attn_block(hidden_states)
222
-
223
- if self.upsample is not None:
224
- hidden_states = self.upsample(hidden_states)
225
-
226
- return hidden_states
227
-
228
-
229
- class DownsamplingBlock(nn.Module):
230
- config: VQGANConfig
231
- curr_res: int
232
- block_idx: int
233
- dtype: jnp.dtype = jnp.float32
234
-
235
- def setup(self):
236
- in_ch_mult = (1,) + tuple(self.config.ch_mult)
237
- block_in = self.config.ch * in_ch_mult[self.block_idx]
238
- block_out = self.config.ch * self.config.ch_mult[self.block_idx]
239
- self.temb_ch = 0
240
-
241
- res_blocks = []
242
- attn_blocks = []
243
- for _ in range(self.config.num_res_blocks):
244
- res_blocks.append(
245
- ResnetBlock(
246
- block_in, block_out, temb_channels=self.temb_ch, dropout_prob=self.config.dropout, dtype=self.dtype
247
- )
248
- )
249
- block_in = block_out
250
- if self.curr_res in self.config.attn_resolutions:
251
- attn_blocks.append(AttnBlock(block_in, dtype=self.dtype))
252
-
253
- self.block = res_blocks
254
- self.attn = attn_blocks
255
-
256
- self.downsample = None
257
- if self.block_idx != self.config.num_resolutions - 1:
258
- self.downsample = Downsample(block_in, self.config.resamp_with_conv, dtype=self.dtype)
259
-
260
- def __call__(self, hidden_states, temb=None, deterministic: bool = True):
261
- for res_block in self.block:
262
- hidden_states = res_block(hidden_states, temb, deterministic=deterministic)
263
- for attn_block in self.attn:
264
- hidden_states = attn_block(hidden_states)
265
-
266
- if self.downsample is not None:
267
- hidden_states = self.downsample(hidden_states)
268
-
269
- return hidden_states
270
-
271
-
272
- class MidBlock(nn.Module):
273
- in_channels: int
274
- temb_channels: int
275
- dropout: float
276
- dtype: jnp.dtype = jnp.float32
277
-
278
- def setup(self):
279
- self.block_1 = ResnetBlock(
280
- self.in_channels,
281
- self.in_channels,
282
- temb_channels=self.temb_channels,
283
- dropout_prob=self.dropout,
284
- dtype=self.dtype,
285
- )
286
- self.attn_1 = AttnBlock(self.in_channels, dtype=self.dtype)
287
- self.block_2 = ResnetBlock(
288
- self.in_channels,
289
- self.in_channels,
290
- temb_channels=self.temb_channels,
291
- dropout_prob=self.dropout,
292
- dtype=self.dtype,
293
- )
294
-
295
- def __call__(self, hidden_states, temb=None, deterministic: bool = True):
296
- hidden_states = self.block_1(hidden_states, temb, deterministic=deterministic)
297
- hidden_states = self.attn_1(hidden_states)
298
- hidden_states = self.block_2(hidden_states, temb, deterministic=deterministic)
299
- return hidden_states
300
-
301
-
302
- class Encoder(nn.Module):
303
- config: VQGANConfig
304
- dtype: jnp.dtype = jnp.float32
305
-
306
- def setup(self):
307
- self.temb_ch = 0
308
-
309
- # downsampling
310
- self.conv_in = nn.Conv(
311
- self.config.ch,
312
- kernel_size=(3, 3),
313
- strides=(1, 1),
314
- padding=((1, 1), (1, 1)),
315
- dtype=self.dtype,
316
- )
317
-
318
- curr_res = self.config.resolution
319
- downsample_blocks = []
320
- for i_level in range(self.config.num_resolutions):
321
- downsample_blocks.append(DownsamplingBlock(self.config, curr_res, block_idx=i_level, dtype=self.dtype))
322
-
323
- if i_level != self.config.num_resolutions - 1:
324
- curr_res = curr_res // 2
325
- self.down = downsample_blocks
326
-
327
- # middle
328
- mid_channels = self.config.ch * self.config.ch_mult[-1]
329
- self.mid = MidBlock(mid_channels, self.temb_ch, self.config.dropout, dtype=self.dtype)
330
-
331
- # end
332
- self.norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-6)
333
- self.conv_out = nn.Conv(
334
- 2 * self.config.z_channels if self.config.double_z else self.config.z_channels,
335
- kernel_size=(3, 3),
336
- strides=(1, 1),
337
- padding=((1, 1), (1, 1)),
338
- dtype=self.dtype,
339
- )
340
-
341
- def __call__(self, pixel_values, deterministic: bool = True):
342
- # timestep embedding
343
- temb = None
344
-
345
- # downsampling
346
- hidden_states = self.conv_in(pixel_values)
347
- for block in self.down:
348
- hidden_states = block(hidden_states, temb, deterministic=deterministic)
349
-
350
- # middle
351
- hidden_states = self.mid(hidden_states, temb, deterministic=deterministic)
352
-
353
- # end
354
- hidden_states = self.norm_out(hidden_states)
355
- hidden_states = nn.swish(hidden_states)
356
- hidden_states = self.conv_out(hidden_states)
357
-
358
- return hidden_states
359
-
360
-
361
- class Decoder(nn.Module):
362
- config: VQGANConfig
363
- dtype: jnp.dtype = jnp.float32
364
-
365
- def setup(self):
366
- self.temb_ch = 0
367
-
368
- # compute in_ch_mult, block_in and curr_res at lowest res
369
- block_in = self.config.ch * self.config.ch_mult[self.config.num_resolutions - 1]
370
- curr_res = self.config.resolution // 2 ** (self.config.num_resolutions - 1)
371
- self.z_shape = (1, self.config.z_channels, curr_res, curr_res)
372
- print("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
373
-
374
- # z to block_in
375
- self.conv_in = nn.Conv(
376
- block_in,
377
- kernel_size=(3, 3),
378
- strides=(1, 1),
379
- padding=((1, 1), (1, 1)),
380
- dtype=self.dtype,
381
- )
382
-
383
- # middle
384
- self.mid = MidBlock(block_in, self.temb_ch, self.config.dropout, dtype=self.dtype)
385
-
386
- # upsampling
387
- upsample_blocks = []
388
- for i_level in reversed(range(self.config.num_resolutions)):
389
- upsample_blocks.append(UpsamplingBlock(self.config, curr_res, block_idx=i_level, dtype=self.dtype))
390
- if i_level != 0:
391
- curr_res = curr_res * 2
392
- self.up = list(reversed(upsample_blocks)) # reverse to get consistent order
393
-
394
- # end
395
- self.norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-6)
396
- self.conv_out = nn.Conv(
397
- self.config.out_ch,
398
- kernel_size=(3, 3),
399
- strides=(1, 1),
400
- padding=((1, 1), (1, 1)),
401
- dtype=self.dtype,
402
- )
403
-
404
- def __call__(self, hidden_states, deterministic: bool = True):
405
- # timestep embedding
406
- temb = None
407
-
408
- # z to block_in
409
- hidden_states = self.conv_in(hidden_states)
410
-
411
- # middle
412
- hidden_states = self.mid(hidden_states, temb, deterministic=deterministic)
413
-
414
- # upsampling
415
- for block in reversed(self.up):
416
- hidden_states = block(hidden_states, temb, deterministic=deterministic)
417
-
418
- # end
419
- if self.config.give_pre_end:
420
- return hidden_states
421
-
422
- hidden_states = self.norm_out(hidden_states)
423
- hidden_states = nn.swish(hidden_states)
424
- hidden_states = self.conv_out(hidden_states)
425
-
426
- return hidden_states
427
-
428
-
429
- class VectorQuantizer(nn.Module):
430
- """
431
- see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
432
- ____________________________________________
433
- Discretization bottleneck part of the VQ-VAE.
434
- Inputs:
435
- - n_e : number of embeddings
436
- - e_dim : dimension of embedding
437
- - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
438
- _____________________________________________
439
- """
440
-
441
- config: VQGANConfig
442
- dtype: jnp.dtype = jnp.float32
443
-
444
- def setup(self):
445
- self.embedding = nn.Embed(self.config.n_embed, self.config.embed_dim, dtype=self.dtype) # TODO: init
446
-
447
- def __call__(self, hidden_states):
448
- """
449
- Inputs the output of the encoder network z and maps it to a discrete
450
- one-hot vector that is the index of the closest embedding vector e_j
451
- z (continuous) -> z_q (discrete)
452
- z.shape = (batch, channel, height, width)
453
- quantization pipeline:
454
- 1. get encoder input (B,C,H,W)
455
- 2. flatten input to (B*H*W,C)
456
- """
457
- # flatten
458
- hidden_states_flattended = hidden_states.reshape((-1, self.config.embed_dim))
459
-
460
- # dummy op to init the weights, so we can access them below
461
- self.embedding(jnp.ones((1, 1), dtype="i4"))
462
-
463
- # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
464
- emb_weights = self.variables["params"]["embedding"]["embedding"]
465
- distance = (
466
- jnp.sum(hidden_states_flattended ** 2, axis=1, keepdims=True)
467
- + jnp.sum(emb_weights ** 2, axis=1)
468
- - 2 * jnp.dot(hidden_states_flattended, emb_weights.T)
469
- )
470
-
471
- # get quantized latent vectors
472
- min_encoding_indices = jnp.argmin(distance, axis=1)
473
- z_q = self.embedding(min_encoding_indices).reshape(hidden_states.shape)
474
-
475
- # reshape to (batch, num_tokens)
476
- min_encoding_indices = min_encoding_indices.reshape(hidden_states.shape[0], -1)
477
-
478
- # compute the codebook_loss (q_loss) outside the model
479
- # here we return the embeddings and indices
480
- return z_q, min_encoding_indices
481
-
482
- def get_codebook_entry(self, indices, shape=None):
483
- # indices are expected to be of shape (batch, num_tokens)
484
- # get quantized latent vectors
485
- batch, num_tokens = indices.shape
486
- z_q = self.embedding(indices)
487
- z_q = z_q.reshape(batch, int(math.sqrt(num_tokens)), int(math.sqrt(num_tokens)), -1)
488
- return z_q
489
-
490
-
491
- class VQModule(nn.Module):
492
- config: VQGANConfig
493
- dtype: jnp.dtype = jnp.float32
494
-
495
- def setup(self):
496
- self.encoder = Encoder(self.config, dtype=self.dtype)
497
- self.decoder = Decoder(self.config, dtype=self.dtype)
498
- self.quantize = VectorQuantizer(self.config, dtype=self.dtype)
499
- self.quant_conv = nn.Conv(
500
- self.config.embed_dim,
501
- kernel_size=(1, 1),
502
- strides=(1, 1),
503
- padding="VALID",
504
- dtype=self.dtype,
505
- )
506
- self.post_quant_conv = nn.Conv(
507
- self.config.z_channels,
508
- kernel_size=(1, 1),
509
- strides=(1, 1),
510
- padding="VALID",
511
- dtype=self.dtype,
512
- )
513
-
514
- def encode(self, pixel_values, deterministic: bool = True):
515
- hidden_states = self.encoder(pixel_values, deterministic=deterministic)
516
- hidden_states = self.quant_conv(hidden_states)
517
- quant_states, indices = self.quantize(hidden_states)
518
- return quant_states, indices
519
-
520
- def decode(self, hidden_states, deterministic: bool = True):
521
- hidden_states = self.post_quant_conv(hidden_states)
522
- hidden_states = self.decoder(hidden_states, deterministic=deterministic)
523
- return hidden_states
524
-
525
- def decode_code(self, code_b):
526
- hidden_states = self.quantize.get_codebook_entry(code_b)
527
- hidden_states = self.decode(hidden_states)
528
- return hidden_states
529
-
530
- def __call__(self, pixel_values, deterministic: bool = True):
531
- quant_states, indices = self.encode(pixel_values, deterministic)
532
- hidden_states = self.decode(quant_states, deterministic)
533
- return hidden_states, indices
534
-
535
-
536
- class VQGANPreTrainedModel(FlaxPreTrainedModel):
537
- """
538
- An abstract class to handle weights initialization and a simple interface
539
- for downloading and loading pretrained models.
540
- """
541
-
542
- config_class = VQGANConfig
543
- base_model_prefix = "model"
544
- module_class: nn.Module = None
545
-
546
- def __init__(
547
- self,
548
- config: VQGANConfig,
549
- input_shape: Tuple = (1, 256, 256, 3),
550
- seed: int = 0,
551
- dtype: jnp.dtype = jnp.float32,
552
- **kwargs,
553
- ):
554
- module = self.module_class(config=config, dtype=dtype, **kwargs)
555
- super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
556
-
557
- def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
558
- # init input tensors
559
- pixel_values = jnp.zeros(input_shape, dtype=jnp.float32)
560
- params_rng, dropout_rng = jax.random.split(rng)
561
- rngs = {"params": params_rng, "dropout": dropout_rng}
562
-
563
- return self.module.init(rngs, pixel_values)["params"]
564
-
565
- def encode(self, pixel_values, params: dict = None, dropout_rng: jax.random.PRNGKey = None, train: bool = False):
566
- # Handle any PRNG if needed
567
- rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
568
-
569
- return self.module.apply(
570
- {"params": params or self.params}, jnp.array(pixel_values), not train, rngs=rngs, method=self.module.encode
571
- )
572
-
573
- def decode(self, hidden_states, params: dict = None, dropout_rng: jax.random.PRNGKey = None, train: bool = False):
574
- # Handle any PRNG if needed
575
- rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
576
-
577
- return self.module.apply(
578
- {"params": params or self.params},
579
- jnp.array(hidden_states),
580
- not train,
581
- rngs=rngs,
582
- method=self.module.decode,
583
- )
584
-
585
- def decode_code(self, indices, params: dict = None):
586
- return self.module.apply(
587
- {"params": params or self.params}, jnp.array(indices, dtype="i4"), method=self.module.decode_code
588
- )
589
-
590
- def __call__(
591
- self,
592
- pixel_values,
593
- params: dict = None,
594
- dropout_rng: jax.random.PRNGKey = None,
595
- train: bool = False,
596
- ):
597
- # Handle any PRNG if needed
598
- rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
599
-
600
- return self.module.apply(
601
- {"params": params or self.params},
602
- jnp.array(pixel_values),
603
- not train,
604
- rngs=rngs,
605
- )
606
-
607
-
608
- class VQModel(VQGANPreTrainedModel):
609
- module_class = VQModule
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dalle_mini/model.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import jax
3
+ import flax.linen as nn
4
+
5
+ from transformers.models.bart.modeling_flax_bart import (
6
+ FlaxBartModule,
7
+ FlaxBartForConditionalGenerationModule,
8
+ FlaxBartForConditionalGeneration,
9
+ FlaxBartEncoder,
10
+ FlaxBartDecoder
11
+ )
12
+
13
+ from transformers import BartConfig
14
+
15
+
16
+ # Model hyperparameters, for convenience
17
+ OUTPUT_VOCAB_SIZE = 16384 + 1 # encoded image token space + 1 for bos
18
+ OUTPUT_LENGTH = 256 + 1 # number of encoded tokens + 1 for bos
19
+ BOS_TOKEN_ID = 16384
20
+ BASE_MODEL = 'facebook/bart-large-cnn' # we currently have issues with bart-large
21
+
22
+
23
+ class CustomFlaxBartModule(FlaxBartModule):
24
+ def setup(self):
25
+ # check config is valid, otherwise set default values
26
+ self.config.vocab_size_output = getattr(self.config, 'vocab_size_output', OUTPUT_VOCAB_SIZE)
27
+ self.config.max_position_embeddings_decoder = getattr(self.config, 'max_position_embeddings_decoder', OUTPUT_LENGTH)
28
+
29
+ # we keep shared to easily load pre-trained weights
30
+ self.shared = nn.Embed(
31
+ self.config.vocab_size,
32
+ self.config.d_model,
33
+ embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
34
+ dtype=self.dtype,
35
+ )
36
+ # a separate embedding is used for the decoder
37
+ self.decoder_embed = nn.Embed(
38
+ self.config.vocab_size_output,
39
+ self.config.d_model,
40
+ embedding_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
41
+ dtype=self.dtype,
42
+ )
43
+ self.encoder = FlaxBartEncoder(self.config, dtype=self.dtype, embed_tokens=self.shared)
44
+
45
+ # the decoder has a different config
46
+ decoder_config = BartConfig(self.config.to_dict())
47
+ decoder_config.max_position_embeddings = self.config.max_position_embeddings_decoder
48
+ decoder_config.vocab_size = self.config.vocab_size_output
49
+ self.decoder = FlaxBartDecoder(decoder_config, dtype=self.dtype, embed_tokens=self.decoder_embed)
50
+
51
+ class CustomFlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):
52
+ def setup(self):
53
+ # check config is valid, otherwise set default values
54
+ self.config.vocab_size_output = getattr(self.config, 'vocab_size_output', OUTPUT_VOCAB_SIZE)
55
+
56
+ self.model = CustomFlaxBartModule(config=self.config, dtype=self.dtype)
57
+ self.lm_head = nn.Dense(
58
+ self.config.vocab_size_output,
59
+ use_bias=False,
60
+ dtype=self.dtype,
61
+ kernel_init=jax.nn.initializers.normal(self.config.init_std, self.dtype),
62
+ )
63
+ self.final_logits_bias = self.param("final_logits_bias", self.bias_init, (1, self.config.vocab_size_output))
64
+
65
+ class CustomFlaxBartForConditionalGeneration(FlaxBartForConditionalGeneration):
66
+ module_class = CustomFlaxBartForConditionalGenerationModule
demo/dalle_mini DELETED
@@ -1 +0,0 @@
1
- ../dalle_mini
 
 
{data → dev/data}/CC12M_downloader.py RENAMED
File without changes
{data → dev/data}/CC3M_downloader.py RENAMED
File without changes
dev/data/README.md ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # Data
2
+
3
+ Utility scripts for downloading CC3M and CC12M.
dev/notebooks/README.md ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Notebooks
2
+
3
+ These notebooks were used during development.
4
+
5
+ TODO: This section requires some refactor and clean up.
{demo → dev/notebooks/demo}/CustomBARTv4b_model-generate.ipynb RENAMED
File without changes
{demo → dev/notebooks/demo}/demo_notebook.ipynb RENAMED
File without changes
{demo → dev/notebooks/demo}/model-sweep.py RENAMED
File without changes
{demo → dev/notebooks/demo}/tpu-demo.ipynb RENAMED
File without changes
{encoding → dev/notebooks/encoding}/vqgan-jax-encoding-with-captions.ipynb RENAMED
File without changes
{encoding → dev/notebooks/encoding}/vqgan-jax-encoding-yfcc100m.ipynb RENAMED
File without changes
{encoding → dev/notebooks/encoding}/vqgan-jax-encoding.ipynb RENAMED
File without changes
{model → dev/notebooks/model}/data-pipeline.ipynb RENAMED
File without changes
{seq2seq → dev/seq2seq}/do_big_run.sh RENAMED
File without changes
{seq2seq → dev/seq2seq}/do_small_run.sh RENAMED
File without changes
{seq2seq → dev/seq2seq}/environment.yaml RENAMED
File without changes
{seq2seq → dev/seq2seq}/requirements.txt RENAMED
File without changes
{seq2seq → dev/seq2seq}/run_seq2seq_flax.py RENAMED
File without changes
{seq2seq → dev/seq2seq}/sweep.yaml RENAMED
File without changes
{demo → dev}/wandb-examples.py RENAMED
@@ -75,8 +75,8 @@ import os
75
  os.environ["WANDB_SILENT"] = "true"
76
  os.environ["WANDB_CONSOLE"] = "off"
77
 
78
- id = '1i5e8rlj' #'1coradc5'
79
- #run = wandb.init(id=id, project="dalle-mini-demo", resume="allow")
80
  run = wandb.init(id=id,
81
  entity='wandb',
82
  project="hf-flax-dalle-mini",
 
75
  os.environ["WANDB_SILENT"] = "true"
76
  os.environ["WANDB_CONSOLE"] = "off"
77
 
78
+ # set id to None so our latest images don't get overwritten
79
+ id = None
80
  run = wandb.init(id=id,
81
  entity='wandb',
82
  project="hf-flax-dalle-mini",
requirements.txt CHANGED
@@ -2,12 +2,5 @@
2
  -f https://storage.googleapis.com/jax-releases/jax_releases.html
3
  jax[cuda111]
4
  flax
5
- requests
6
- -e git+https://github.com/huggingface/transformers.git@master#egg=transformers
7
- -e git+https://github.com/huggingface/datasets.git@master#egg=datasets
8
- flax
9
- jupyter
10
- wandb
11
- ftfy
12
- streamlit
13
  gradio
 
2
  -f https://storage.googleapis.com/jax-releases/jax_releases.html
3
  jax[cuda111]
4
  flax
5
+ transformers
 
 
 
 
 
 
 
6
  gradio