Sanshruth committed on
Commit
369659e
1 Parent(s): 258af09

initial_commit

Files changed (1)
  1. app.py +298 -0
app.py ADDED
import zipfile
import subprocess  # needed for the system `unzip` fallback below

def unzip_content():
    try:
        # First try using Python's zipfile
        print("Attempting to unzip content using Python...")
        with zipfile.ZipFile('./content.zip', 'r') as zip_ref:
            zip_ref.extractall('.')
    except Exception as e:
        print(f"Python unzip failed: {str(e)}")
        try:
            # Fall back to the system unzip command
            print("Attempting to unzip content using system command...")
            subprocess.run(['unzip', '-o', './content.zip'], check=True)
        except Exception as e:
            print(f"System unzip failed: {str(e)}")
            raise Exception("Failed to unzip content using both methods")
    print("Content successfully unzipped!")

# Try to unzip content at startup
try:
    unzip_content()
except Exception as e:
    print(f"Warning: Could not unzip content: {str(e)}")

import gradio as gr
import numpy as np
import torch
import torchvision
import torchvision.transforms
import torchvision.transforms.functional
import PIL
import matplotlib.pyplot as plt
import yaml
from omegaconf import OmegaConf
from CLIP import clip
import os
os.chdir('./taming-transformers')
from taming.models.vqgan import VQModel
os.chdir('..')
from PIL import Image
import cv2
import imageio

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

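# create_video stitches the PNG/JPG frames saved in ./generated into an MP4 at
# 10 fps using imageio and returns the output path (cv2 is only used to read
# the first frame's dimensions, which are not otherwise used).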
def create_video(image_folder='./generated', video_name='morphing_video.mp4'):
    images = sorted([img for img in os.listdir(image_folder) if img.endswith(".png") or img.endswith(".jpg")])
    if len(images) == 0:
        print("No images found in the folder.")
        return None

    frame = cv2.imread(os.path.join(image_folder, images[0]))
    height, width, layers = frame.shape
    video_writer = imageio.get_writer(video_name, fps=10)

    for image in images:
        img_path = os.path.join(image_folder, image)
        img = imageio.imread(img_path)
        video_writer.append_data(img)

    video_writer.close()
    return video_name

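# Helpers: save_from_tensors writes a CHW tensor in [0, 1] to a PNG, norm_data
# maps the decoder's [-1, 1] output to [0, 1], and the setup_* functions load
# CLIP ViT-B/32 and the VQGAN model defined by its OmegaConf config/checkpoint.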
def save_from_tensors(tensor, output_dir, filename):
    img = tensor.clone()
    img = img.mul(255).byte()
    img = img.cpu().numpy().transpose((1, 2, 0))
    os.makedirs(output_dir, exist_ok=True)
    Image.fromarray(img).save(os.path.join(output_dir, filename))

def norm_data(data):
    return (data.clip(-1, 1) + 1) / 2

def setup_clip_model():
    model, _ = clip.load('ViT-B/32', jit=False)
    model.eval().to(device)
    return model

def setup_vqgan_model(config_path, checkpoint_path):
    config = OmegaConf.load(config_path)
    model = VQModel(**config.model.params)
    state_dict = torch.load(checkpoint_path, map_location="cpu")["state_dict"]
    model.load_state_dict(state_dict, strict=False)
    return model.eval().to(device)

def generator(x, model):
    x = model.post_quant_conv(x)
    x = model.decoder(x)
    return x

def encode_text(text, clip_model):
    t = clip.tokenize(text).to(device)
    return clip_model.encode_text(t).detach().clone()

def create_encoding(include, exclude, extras, clip_model):
    include_enc = [encode_text(text, clip_model) for text in include]
    exclude_enc = [encode_text(text, clip_model) for text in exclude]
    extras_enc = [encode_text(text, clip_model) for text in extras]
    return include_enc, exclude_enc, extras_enc

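# create_crops draws num_crops randomly augmented, noise-jittered crops of the
# current image and resizes them to 224x224; scoring a batch of crops with CLIP
# (rather than the full frame) is the usual way to get a more stable guidance
# signal in VQGAN+CLIP pipelines.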
def create_crops(img, num_crops=32, size1=225, noise_factor=0.05):
    aug_transform = torch.nn.Sequential(
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.RandomAffine(30, translate=(0.1, 0.1), fill=0)
    ).to(device)

    p = size1 // 2
    img = torch.nn.functional.pad(img, (p, p, p, p), mode='constant', value=0)
    img = aug_transform(img)

    crop_set = []
    for _ in range(num_crops):
        gap1 = int(torch.normal(1.2, .3, ()).clip(.43, 1.9) * size1)
        offsetx = torch.randint(0, int(size1 * 2 - gap1), ())
        offsety = torch.randint(0, int(size1 * 2 - gap1), ())
        crop = img[:, :, offsetx:offsetx + gap1, offsety:offsety + gap1]
        crop = torch.nn.functional.interpolate(crop, (224, 224), mode='bilinear', align_corners=True)
        crop_set.append(crop)

    img_crops = torch.cat(crop_set, 0)
    randnormal = torch.randn_like(img_crops, requires_grad=False)
    randstotal = torch.rand((img_crops.shape[0], 1, 1, 1)).to(device)
    img_crops = img_crops + noise_factor * randstotal * randnormal

    return img_crops

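# optimize_result computes the CLIP-guided loss: decode the latent params with
# the VQGAN, encode augmented crops with CLIP, then reward cosine similarity to
# the weighted include + extras text encodings and penalize similarity to the
# exclude encoding.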
def optimize_result(params, prompt, vqgan_model, clip_model, w1, w2, extras_enc, exclude_enc):
    alpha = 1
    beta = 0.5
    out = generator(params, vqgan_model)
    out = norm_data(out)
    out = create_crops(out)
    out = torchvision.transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                                           (0.26862954, 0.26130258, 0.27577711))(out)

    img_enc = clip_model.encode_image(out)
    final_enc = w1 * prompt + w2 * extras_enc[0]
    final_text_include_enc = final_enc / final_enc.norm(dim=-1, keepdim=True)
    final_text_exclude_enc = exclude_enc[0]

    main_loss = torch.cosine_similarity(final_text_include_enc, img_enc, dim=-1)
    penalize_loss = torch.cosine_similarity(final_text_exclude_enc, img_enc, dim=-1)

    return -alpha * main_loss.mean() + beta * penalize_loss.mean()

def optimize(params, optimizer, prompt, vqgan_model, clip_model, w1, w2, extras_enc, exclude_enc):
    loss = optimize_result(params, prompt, vqgan_model, clip_model, w1, w2, extras_enc, exclude_enc)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss

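# training_loop runs total_iter optimization steps for each "include" prompt and
# snapshots the decoded image every show_step iterations; those snapshots become
# the frames of the morphing video.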
def training_loop(params, optimizer, include_enc, exclude_enc, extras_enc, vqgan_model, clip_model, w1, w2,
                  total_iter=200, show_step=1):
    res_img = []
    res_z = []

    for prompt in include_enc:
        for it in range(total_iter):
            loss = optimize(params, optimizer, prompt, vqgan_model, clip_model, w1, w2, extras_enc, exclude_enc)

            if it >= 0 and it % show_step == 0:
                with torch.no_grad():
                    generated = generator(params, vqgan_model)
                    new_img = norm_data(generated[0].to(device))
                    res_img.append(new_img)
                    res_z.append(params.clone().detach())
                    print(f"loss: {loss.item():.4f}\nno. of iteration: {it}")

        torch.cuda.empty_cache()
    return res_img, res_z

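# generate_art is the end-to-end pipeline: parse the comma-separated prompts,
# load CLIP and the VQGAN, encode ./gradient1.png as the starting latent, run
# the optimization loop, write the intermediate frames to ./generated, and
# return the morphing video built by create_video().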
def generate_art(include_text, exclude_text, extras_text, num_iterations):
    try:
        # Process the input prompts
        include = [x.strip() for x in include_text.split(',')]
        exclude = [x.strip() for x in exclude_text.split(',')]
        extras = [x.strip() for x in extras_text.split(',')]

        w1, w2 = 1.0, 0.9

        # Setup models
        clip_model = setup_clip_model()
        vqgan_model = setup_vqgan_model("./models/vqgan_imagenet_f16_16384/configs/model.yaml",
                                        "./models/vqgan_imagenet_f16_16384/checkpoints/last.ckpt")

        # Parameters
        learning_rate = 0.1
        batch_size = 1
        wd = 0.1
        size1, size2 = 225, 400

        # Initialize parameters
        initial_image = PIL.Image.open('./gradient1.png')
        initial_image = initial_image.resize((size2, size1))
        initial_image = torchvision.transforms.ToTensor()(initial_image).unsqueeze(0).to(device)

        with torch.no_grad():
            z, _, _ = vqgan_model.encode(initial_image)

        params = torch.nn.Parameter(z).to(device)
        optimizer = torch.optim.AdamW([params], lr=learning_rate, weight_decay=wd)
        params.data = params.data * 0.6 + torch.randn_like(params.data) * 0.4

        # Encode prompts
        include_enc, exclude_enc, extras_enc = create_encoding(include, exclude, extras, clip_model)

        # Run training loop
        res_img, res_z = training_loop(params, optimizer, include_enc, exclude_enc, extras_enc,
                                       vqgan_model, clip_model, w1, w2, total_iter=num_iterations)

        # Save results
        output_dir = "generated"
        # Create output directory if it doesn't exist
        os.makedirs(output_dir, exist_ok=True)

        # Clear any existing files in the output directory
        for file in os.listdir(output_dir):
            file_path = os.path.join(output_dir, file)
            if os.path.isfile(file_path):
                os.remove(file_path)

        for i, img in enumerate(res_img):
            save_from_tensors(img, output_dir, f"generated_image_{i:03d}.png")

        # Create video
        video_path = create_video()

        # Delete the generated folder and its contents after creating the video
        import shutil
        shutil.rmtree(output_dir)

        return video_path

    except Exception as e:
        # If there's an error, ensure the generated folder is cleaned up
        if os.path.exists("generated"):
            import shutil
            shutil.rmtree("generated")
        raise e  # Re-raise the exception to be handled by the calling function


def gradio_interface(include_text, exclude_text, extras_text, num_iterations):
    try:
        video_path = generate_art(include_text, exclude_text, extras_text, int(num_iterations))
        return video_path
    except Exception as e:
        return f"An error occurred: {str(e)}"

# Define and launch the Gradio app
iface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Textbox(label="Include Prompts (comma-separated)",
                   value="desert, heavy rain, cactus"),
        gr.Textbox(label="Exclude Prompts (comma-separated)",
                   value="confusing, blurry"),
        gr.Textbox(label="Extra Style Prompts (comma-separated)",
                   value="desert, clear, detailed, beautiful, good shape, detailed"),
        gr.Number(label="Number of Iterations",
                  value=200, minimum=1, maximum=1000)
    ],
    outputs=gr.Video(label="Generated Morphing Video"),
    title="VQGAN-CLIP Art Generator",
    description="""
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1ivRYvTaX90PRghQIqAdOyEawkY0YLefa?authuser=0#scrollTo=WE7aPQ0t1hd2)
[![Clone Space](https://huggingface.co/datasets/huggingface/badges/raw/main/clone-space-lg.svg)](https://huggingface.co/spaces/your-username/your-space-name?duplicate=true)

Generate artistic videos using VQGAN-CLIP.
Enter your prompts separated by commas and adjust the number of iterations.
The model will generate a morphing video based on your inputs.

**Note:** This application requires GPU access. Please either:
1. Use the Colab notebook (click the Colab badge above) with a GPU runtime, or
2. Clone this Space (click the Clone Space badge) and enable a GPU in your personal copy.""",
    css="""
    .gradio-container {
        font-family: 'IBM Plex Sans', sans-serif;
    }
    .gr-button {
        color: white;
        border-radius: 7px;
        background: linear-gradient(45deg, #7747FF, #FF3557);
        border: none;
        height: 46px;
    }
    a {
        text-decoration: none;
    }
    .maintenance-msg {
        color: #FF0000;
        font-size: 14px;
        margin-top: 10px;
    }
    """
)

if __name__ == "__main__":
    print("Checking GPU availability:", "GPU AVAILABLE" if torch.cuda.is_available() else "NO GPU FOUND")
    iface.launch()