sparswan commited on
Commit
c40b98a
1 Parent(s): 9ed6514

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +228 -0
app.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ os.system("git clone --recursive https://github.com/JD-P/cloob-latent-diffusion")
4
+ os.system("cd cloob-latent-diffusion;pip install omegaconf pillow pytorch-lightning einops wandb ftfy regex ./CLIP")
5
+
6
+ import argparse
7
+ from functools import partial
8
+ from pathlib import Path
9
+ import sys
10
+ sys.path.append('./cloob-latent-diffusion')
11
+ sys.path.append('./cloob-latent-diffusion/cloob-training')
12
+ sys.path.append('./cloob-latent-diffusion/latent-diffusion')
13
+ sys.path.append('./cloob-latent-diffusion/taming-transformers')
14
+ sys.path.append('./cloob-latent-diffusion/v-diffusion-pytorch')
15
+ from omegaconf import OmegaConf
16
+ from PIL import Image
17
+ import torch
18
+ from torch import nn
19
+ from torch.nn import functional as F
20
+ from torchvision import transforms
21
+ from torchvision.transforms import functional as TF
22
+ from tqdm import trange
23
+ from CLIP import clip
24
+ from cloob_training import model_pt, pretrained
25
+ import ldm.models.autoencoder
26
+ from diffusion import sampling, utils
27
+ import train_latent_diffusion as train
28
+ from huggingface_hub import hf_hub_url, cached_download
29
+ import random
30
+
31
+ # Download the model files
32
+ checkpoint = cached_download(hf_hub_url("huggan/distill-ccld-wa", filename="model_student.ckpt"))
33
+ ae_model_path = cached_download(hf_hub_url("huggan/ccld_wa", filename="ae_model.ckpt"))
34
+ ae_config_path = cached_download(hf_hub_url("huggan/ccld_wa", filename="ae_model.yaml"))
35
+
36
+ # Define a few utility functions
37
+
38
+
39
+ def parse_prompt(prompt, default_weight=3.):
40
+ if prompt.startswith('http://') or prompt.startswith('https://'):
41
+ vals = prompt.rsplit(':', 2)
42
+ vals = [vals[0] + ':' + vals[1], *vals[2:]]
43
+ else:
44
+ vals = prompt.rsplit(':', 1)
45
+ vals = vals + ['', default_weight][len(vals):]
46
+ return vals[0], float(vals[1])
47
+
48
+
49
+ def resize_and_center_crop(image, size):
50
+ fac = max(size[0] / image.size[0], size[1] / image.size[1])
51
+ image = image.resize((int(fac * image.size[0]), int(fac * image.size[1])), Image.LANCZOS)
52
+ return TF.center_crop(image, size[::-1])
53
+
54
+
55
+ # Load the models
56
+ device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
57
+ print('Using device:', device)
58
+ print('loading models')
59
+
60
+ # autoencoder
61
+ ae_config = OmegaConf.load(ae_config_path)
62
+ ae_model = ldm.models.autoencoder.AutoencoderKL(**ae_config.model.params)
63
+ ae_model.eval().requires_grad_(False).to(device)
64
+ ae_model.load_state_dict(torch.load(ae_model_path))
65
+ n_ch, side_y, side_x = 4, 32, 32
66
+
67
+ # diffusion model
68
+ model = train.DiffusionModel(192, [1,1,2,2], autoencoder_scale=torch.tensor(4.3084))
69
+ model.load_state_dict(torch.load(checkpoint, map_location='cpu'))
70
+ model = model.to(device).eval().requires_grad_(False)
71
+
72
+ # CLOOB
73
+ cloob_config = pretrained.get_config('cloob_laion_400m_vit_b_16_16_epochs')
74
+ cloob = model_pt.get_pt_model(cloob_config)
75
+ checkpoint = pretrained.download_checkpoint(cloob_config)
76
+ cloob.load_state_dict(model_pt.get_pt_params(cloob_config, checkpoint))
77
+ cloob.eval().requires_grad_(False).to(device)
78
+
79
+
80
+ # The key function: returns a list of n PIL images
81
+ def generate(n=1, prompts=['a red circle'], images=[], seed=42, steps=15,
82
+ method='plms', eta=None):
83
+ zero_embed = torch.zeros([1, cloob.config['d_embed']], device=device)
84
+ target_embeds, weights = [zero_embed], []
85
+
86
+ for prompt in prompts:
87
+ txt, weight = parse_prompt(prompt)
88
+ target_embeds.append(cloob.text_encoder(cloob.tokenize(txt).to(device)).float())
89
+ weights.append(weight)
90
+
91
+ for prompt in images:
92
+ path, weight = parse_prompt(prompt)
93
+ img = Image.open(utils.fetch(path)).convert('RGB')
94
+ clip_size = cloob.config['image_encoder']['image_size']
95
+ img = resize_and_center_crop(img, (clip_size, clip_size))
96
+ batch = TF.to_tensor(img)[None].to(device)
97
+ embed = F.normalize(cloob.image_encoder(cloob.normalize(batch)).float(), dim=-1)
98
+ target_embeds.append(embed)
99
+ weights.append(weight)
100
+
101
+ weights = torch.tensor([1 - sum(weights), *weights], device=device)
102
+
103
+ torch.manual_seed(seed)
104
+
105
+ def cfg_model_fn(x, t):
106
+ n = x.shape[0]
107
+ n_conds = len(target_embeds)
108
+ x_in = x.repeat([n_conds, 1, 1, 1])
109
+ t_in = t.repeat([n_conds])
110
+ clip_embed_in = torch.cat([*target_embeds]).repeat_interleave(n, 0)
111
+ vs = model(x_in, t_in, clip_embed_in).view([n_conds, n, *x.shape[1:]])
112
+ v = vs.mul(weights[:, None, None, None, None]).sum(0)
113
+ return v
114
+
115
+ def run(x, steps):
116
+ if method == 'ddpm':
117
+ return sampling.sample(cfg_model_fn, x, steps, 1., {})
118
+ if method == 'ddim':
119
+ return sampling.sample(cfg_model_fn, x, steps, eta, {})
120
+ if method == 'prk':
121
+ return sampling.prk_sample(cfg_model_fn, x, steps, {})
122
+ if method == 'plms':
123
+ return sampling.plms_sample(cfg_model_fn, x, steps, {})
124
+ if method == 'pie':
125
+ return sampling.pie_sample(cfg_model_fn, x, steps, {})
126
+ if method == 'plms2':
127
+ return sampling.plms2_sample(cfg_model_fn, x, steps, {})
128
+ assert False
129
+
130
+ batch_size = n
131
+ x = torch.randn([n, n_ch, side_y, side_x], device=device)
132
+ t = torch.linspace(1, 0, steps + 1, device=device)[:-1]
133
+ steps = utils.get_spliced_ddpm_cosine_schedule(t)
134
+ pil_ims = []
135
+ for i in trange(0, n, batch_size):
136
+ cur_batch_size = min(n - i, batch_size)
137
+ out_latents = run(x[i:i+cur_batch_size], steps)
138
+ outs = ae_model.decode(out_latents * torch.tensor(2.55).to(device))
139
+ for j, out in enumerate(outs):
140
+ pil_ims.append(utils.to_pil_image(out))
141
+
142
+ return pil_ims
143
+
144
+
145
+ import gradio as gr
146
+
147
+ def gen_ims(prompt, im_prompt=None, seed=None, n_steps=10, method='plms'):
148
+ if seed == None :
149
+ seed = random.randint(0, 10000)
150
+ print( prompt, im_prompt, seed, n_steps)
151
+ prompts = [prompt]
152
+ im_prompts = []
153
+ if im_prompt != None:
154
+ im_prompts = [im_prompt]
155
+ pil_ims = generate(n=1, prompts=prompts, images=im_prompts, seed=seed, steps=n_steps, method=method)
156
+ return pil_ims[0]
157
+
158
+ iface = gr.Interface(fn=gen_ims,
159
+ inputs=[#gr.inputs.Slider(minimum=1, maximum=1, step=1, default=1,label="Number of images"),
160
+ #gr.inputs.Slider(minimum=0, maximum=200, step=1, label='Random seed', default=0),
161
+ gr.inputs.Textbox(label="Text prompt"),
162
+ gr.inputs.Image(optional=True, label="Image prompt", type='filepath'),
163
+ #gr.inputs.Slider(minimum=10, maximum=35, step=1, default=15,label="Number of steps")
164
+ ],
165
+ outputs=[gr.outputs.Image(type="pil", label="Generated Image")],
166
+ examples=[
167
+ ["Virgin and Child, in the style of Jacopo Bellini"],
168
+ ["Katsushika Hokusai, The Dragon of Smoke Escaping from Mount Fuji"],
169
+ ["Moon Light Sonata by Basuki Abdullah"],
170
+ ["Twon Tree by M.C. Escher"],
171
+ ["Futurism, in the style of Wassily Kandinsky"],
172
+ ["Art Nouveau, in the style of John Singer Sargent"],
173
+ ["Surrealism, in the style of Edgar Degas"],
174
+ ["Expressionism, in the style of Wassily Kandinsky"],
175
+ ["Futurism, in the style of Egon Schiele"],
176
+ ["Neoclassicism, in the style of Gustav Klimt"],
177
+ ["Cubism, in the style of Gustav Klimt"],
178
+ ["Op Art, in the style of Marc Chagall"],
179
+ ["Romanticism, in the style of M.C. Escher"],
180
+ ["Futurism, in the style of M.C. Escher"],
181
+ ["Abstract Art, in the style of M.C. Escher"],
182
+ ["Mannerism, in the style of Paul Klee"],
183
+ ["Romanesque Art, in the style of Leonardo da Vinci"],
184
+ ["High Renaissance, in the style of Rembrandt"],
185
+ ["Magic Realism, in the style of Gustave Dore"],
186
+ ["Realism, in the style of Jean-Michel Basquiat"],
187
+ ["Art Nouveau, in the style of Paul Gauguin"],
188
+ ["Avant-garde, in the style of Pierre-Auguste Renoir"],
189
+ ["Baroque, in the style of Edward Hopper"],
190
+ ["Post-Impressionism, in the style of Wassily Kandinsky"],
191
+ ["Naturalism, in the style of Rene Magritte"],
192
+ ["Constructivism, in the style of Paul Cezanne"],
193
+ ["Abstract Expressionism, in the style of Henri Matisse"],
194
+ ["Pop Art, in the style of Vincent van Gogh"],
195
+ ["Futurism, in the style of Wassily Kandinsky"],
196
+ ["Futurism, in the style of Zdzislaw Beksinski"],
197
+ ['Surrealism, in the style of Salvador Dali'],
198
+ ["Aaron Wacker, oil on canvas"],
199
+ ["abstract"],
200
+ ["landscape"],
201
+ ["portrait"],
202
+ ["sculpture"],
203
+ ["genre painting"],
204
+ ["installation"],
205
+ ["photo"],
206
+ ["figurative"],
207
+ ["illustration"],
208
+ ["still life"],
209
+ ["history painting"],
210
+ ["cityscape"],
211
+ ["marina"],
212
+ ["animal painting"],
213
+ ["design"],
214
+ ["calligraphy"],
215
+ ["symbolic painting"],
216
+ ["graffiti"],
217
+ ["performance"],
218
+ ["mythological painting"],
219
+ ["battle painting"],
220
+ ["self-portrait"],
221
+ ["Impressionism, oil on canvas"]
222
+ ],
223
+ title='Art Generator and Style Mixer from 🧠 Cloob and 🎨 WikiArt - Visual Art Encyclopedia:',
224
+ description="Trained on images from the [WikiArt](https://www.wikiart.org/) dataset, comprised of visual arts",
225
+ article = 'Model used is: [model card](https://huggingface.co/huggan/distill-ccld-wa)..'
226
+
227
+ )
228
+ iface.launch(enable_queue=True) # , debug=True for colab debugging