pcuenq HF staff commited on
Commit
dc95170
·
1 Parent(s): f3566c2

Download from HF

Browse files
Files changed (2) hide show
  1. app.py +39 -34
  2. data/prompts_demo.tsv +101 -0
app.py CHANGED
@@ -1,5 +1,6 @@
1
  # For licensing see accompanying LICENSE file.
2
  # Copyright (C) 2024 Apple Inc. All rights reserved.
 
3
  import logging
4
  import os
5
  import shlex
@@ -14,6 +15,9 @@ from einops import rearrange, repeat
14
 
15
  import numpy as np
16
  import torch
 
 
 
17
  from torchvision.utils import make_grid
18
 
19
  from ml_mdm import helpers, reader
@@ -22,7 +26,9 @@ from ml_mdm.language_models import factory
22
 
23
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
 
25
- # Note that it is called add_arguments, not add_argument.
 
 
26
  logging.basicConfig(
27
  level=getattr(logging, "INFO", None),
28
  format="[%(asctime)s] {%(pathname)s:%(lineno)d} %(levelname)s - %(message)s",
@@ -30,6 +36,16 @@ logging.basicConfig(
30
  )
31
 
32
 
 
 
 
 
 
 
 
 
 
 
33
  def dividable(n):
34
  for i in range(int(np.sqrt(n)), 0, -1):
35
  if n % i == 0:
@@ -91,7 +107,6 @@ class GLOBAL_DATA:
91
  diffusion_model = None
92
  override_args = ""
93
  ckpt_name = ""
94
- config_file = ""
95
 
96
 
97
  global_config = GLOBAL_DATA()
@@ -110,9 +125,9 @@ def get_model_type(config_file):
110
  return d.get("model", d.get("vision_model", "unet"))
111
 
112
 
 
113
  def generate(
114
- config_file="cc12m_64x64.yaml",
115
- ckpt_name="vis_model_64x64.pth",
116
  prompt="a chair",
117
  input_template="",
118
  negative_prompt="",
@@ -140,28 +155,28 @@ def generate(
140
  negative_prompt = negative_prompt + negative_template
141
  print(f"Postive: {prompt} / Negative: {negative_prompt}")
142
 
143
- if not os.path.exists(ckpt_name):
144
- logging.info(f"Did not generate because {ckpt_name} does not exist")
145
- return None, None, f"{ckpt_name} does not exist", None, None
 
146
 
147
  if (
148
- global_config.config_file != config_file
149
- or global_config.ckpt_name != ckpt_name
150
  or global_config.override_args != override_args
151
  ):
152
  # Identify model type
153
- model_type = get_model_type(f"configs/models/{config_file}")
154
  # reload the arguments
155
  args = get_arguments(
156
  shlex.split(override_args + f" --model {model_type}"),
157
  mode="demo",
158
- additional_config_paths=[f"configs/models/{config_file}"],
159
  )
160
  helpers.print_args(args)
161
 
162
  # setup model when the parent task changed.
 
163
  tokenizer, language_model, diffusion_model = setup_models(args, device)
164
- vision_model_file = ckpt_name
165
  try:
166
  other_items = diffusion_model.model.load(vision_model_file)
167
  except Exception as e:
@@ -176,7 +191,6 @@ def generate(
176
  global_config.language_model = language_model
177
  global_config.diffusion_model = diffusion_model
178
  global_config.reader_config = args.reader_config
179
- global_config.config_file = config_file
180
  global_config.ckpt_name = ckpt_name
181
 
182
  else:
@@ -287,6 +301,8 @@ def generate(
287
 
288
 
289
  def main(args):
 
 
290
  # get the language model outputs
291
  example_texts = open("data/prompts_demo.tsv").readlines()
292
 
@@ -315,25 +331,15 @@ def main(args):
315
  pid = gr.State()
316
  with gr.Column(scale=2):
317
  with gr.Row(equal_height=False):
318
- with gr.Column(scale=1):
319
- config_file = gr.Dropdown(
320
- [
321
- "cc12m_64x64.yaml",
322
- "cc12m_256x256.yaml",
323
- "cc12m_1024x1024.yaml",
324
- ],
325
- value="cc12m_64x64.yaml",
326
- label="Select the config file",
327
- )
328
  with gr.Column(scale=1):
329
  ckpt_name = gr.Dropdown(
330
  [
331
- "vis_model_64x64.pth",
332
- "vis_model_256x256.pth",
333
- "vis_model_1024x1024.pth",
334
  ],
335
- value="vis_model_64x64.pth",
336
- label="Load checkpoint",
337
  )
338
  with gr.Row(equal_height=False):
339
  with gr.Column(scale=1):
@@ -363,7 +369,7 @@ def main(args):
363
  )
364
  with gr.Column(scale=1):
365
  batch_size = gr.Slider(
366
- value=16, minimum=1, maximum=128, step=1, label="Batch size"
367
  )
368
 
369
  with gr.Row(equal_height=False):
@@ -488,7 +494,6 @@ def main(args):
488
  run_event = run_btn.click(
489
  fn=generate,
490
  inputs=[
491
- config_file,
492
  ckpt_name,
493
  prompt_input,
494
  input_template,
@@ -526,11 +531,11 @@ def main(args):
526
  )
527
  example0 = gr.Examples(
528
  [
529
- ["cc12m_64x64.yaml", "vis_model_64x64.pth", 64, 50, 0],
530
- ["cc12m_256x256.yaml", "vis_model_256x256.pth", 16, 100, 0],
531
- ["cc12m_1024x1024.yaml", "vis_model_1024x1024.pth", 4, 250, 1],
532
  ],
533
- inputs=[config_file, ckpt_name, batch_size, num_inference_steps, eta],
534
  )
535
  example1 = gr.Examples(
536
  examples=[[t.strip()] for t in example_texts],
 
1
  # For licensing see accompanying LICENSE file.
2
  # Copyright (C) 2024 Apple Inc. All rights reserved.
3
+ import spaces
4
  import logging
5
  import os
6
  import shlex
 
15
 
16
  import numpy as np
17
  import torch
18
+ from huggingface_hub import snapshot_download
19
+ from pathlib import Path
20
+ from transformers import T5ForConditionalGeneration
21
  from torchvision.utils import make_grid
22
 
23
  from ml_mdm import helpers, reader
 
26
 
27
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28
 
29
+ # Download destination
30
+ models = Path("models")
31
+
32
  logging.basicConfig(
33
  level=getattr(logging, "INFO", None),
34
  format="[%(asctime)s] {%(pathname)s:%(lineno)d} %(levelname)s - %(message)s",
 
36
  )
37
 
38
 
39
+ def download_all_models():
40
+ # Cache language model in the standard location
41
+ _ = T5ForConditionalGeneration.from_pretrained("google/flan-t5-xl")
42
+
43
+ # Download the vision models we use in the demo
44
+ snapshot_download("pcuenq/mdm-flickr-64", local_dir=models/"mdm-flickr-64")
45
+ snapshot_download("pcuenq/mdm-flickr-256", local_dir=models/"mdm-flickr-256")
46
+ snapshot_download("pcuenq/mdm-flickr-1024", local_dir=models/"mdm-flickr-1024")
47
+
48
+
49
  def dividable(n):
50
  for i in range(int(np.sqrt(n)), 0, -1):
51
  if n % i == 0:
 
107
  diffusion_model = None
108
  override_args = ""
109
  ckpt_name = ""
 
110
 
111
 
112
  global_config = GLOBAL_DATA()
 
125
  return d.get("model", d.get("vision_model", "unet"))
126
 
127
 
128
+ @spaces.GPU
129
  def generate(
130
+ ckpt_name="mdm-flickr-64",
 
131
  prompt="a chair",
132
  input_template="",
133
  negative_prompt="",
 
155
  negative_prompt = negative_prompt + negative_template
156
  print(f"Postive: {prompt} / Negative: {negative_prompt}")
157
 
158
+ vision_model_file = models/ckpt_name/"vis_model.pth"
159
+ if not os.path.exists(vision_model_file):
160
+ logging.info(f"Did not generate because {vision_model_file} does not exist")
161
+ return None, None, f"{vision_model_file} does not exist", None, None
162
 
163
  if (
164
+ global_config.ckpt_name != ckpt_name
 
165
  or global_config.override_args != override_args
166
  ):
167
  # Identify model type
168
+ model_type = get_model_type(models/ckpt_name/"config.yaml")
169
  # reload the arguments
170
  args = get_arguments(
171
  shlex.split(override_args + f" --model {model_type}"),
172
  mode="demo",
173
+ additional_config_paths=[models/ckpt_name/"config.yaml"],
174
  )
175
  helpers.print_args(args)
176
 
177
  # setup model when the parent task changed.
178
+ args.vocab_file = str(models/ckpt_name/args.vocab_file)
179
  tokenizer, language_model, diffusion_model = setup_models(args, device)
 
180
  try:
181
  other_items = diffusion_model.model.load(vision_model_file)
182
  except Exception as e:
 
191
  global_config.language_model = language_model
192
  global_config.diffusion_model = diffusion_model
193
  global_config.reader_config = args.reader_config
 
194
  global_config.ckpt_name = ckpt_name
195
 
196
  else:
 
301
 
302
 
303
  def main(args):
304
+ download_all_models()
305
+
306
  # get the language model outputs
307
  example_texts = open("data/prompts_demo.tsv").readlines()
308
 
 
331
  pid = gr.State()
332
  with gr.Column(scale=2):
333
  with gr.Row(equal_height=False):
 
 
 
 
 
 
 
 
 
 
334
  with gr.Column(scale=1):
335
  ckpt_name = gr.Dropdown(
336
  [
337
+ "mdm-flickr-64",
338
+ "mdm-flickr-256",
339
+ "mdm-flickr-1024",
340
  ],
341
+ value="mdm-flickr-64",
342
+ label="Model",
343
  )
344
  with gr.Row(equal_height=False):
345
  with gr.Column(scale=1):
 
369
  )
370
  with gr.Column(scale=1):
371
  batch_size = gr.Slider(
372
+ value=64, minimum=1, maximum=128, step=1, label="Number of images"
373
  )
374
 
375
  with gr.Row(equal_height=False):
 
494
  run_event = run_btn.click(
495
  fn=generate,
496
  inputs=[
 
497
  ckpt_name,
498
  prompt_input,
499
  input_template,
 
531
  )
532
  example0 = gr.Examples(
533
  [
534
+ ["mdm-flickr-64", 64, 50, 0],
535
+ ["mdm-flickr-256", 16, 100, 0],
536
+ ["mdm-flickr-1024", 4, 250, 1],
537
  ],
538
+ inputs=[ckpt_name, batch_size, num_inference_steps, eta],
539
  )
540
  example1 = gr.Examples(
541
  examples=[[t.strip()] for t in example_texts],
data/prompts_demo.tsv ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ a corgi dog wearing sunglasses at the beach
2
+ A traditional Chinese garden in summer by Claude Monet
3
+ painting of an old man making sushi in the style of golden light, sandalpunk, realistic and hyper-detailed renderings | precisionist, romanticized landscapes, hyper-realistic detailed character illustrations
4
+ Cinematic photo of a fluffy baby Quokka with a knitted hat eating a large cup of popcorns, close up, studio lighting, screen reflecting in its eyes. 35mm photographs, film, bokeh, professional, 4k, highly detailed
5
+ Photography closeup portrait of an adorable rusty broken ­down steampunk llama-shaped robot covered in budding vegetation, surrounded by tall grass, misty futuristic sci-­fi forest environment.
6
+ Paying for a quarter-sized pizza with a pizza-sized quarter.
7
+ An oil painting of a couple in formal evening wear going home get caught in a heavy downpour with no umbrellas.
8
+ A grocery store refrigerator has pint cartons of milk on the top shelf, quart cartons on the middle shelf, and gallon plastic jugs on the bottom shelf.
9
+ In late afternoon in January in New England, a man stands in the shadow of a maple tree.
10
+ An elephant is behind a tree. You can see the trunk on one side and the back legs on the other.
11
+ A tomato has been put on top of a pumpkin on a kitchen stool. There is a fork sticking into the pumpkin. The scene is viewed from above.
12
+ A pear cut into seven pieces arranged in a ring.
13
+ A donkey and an octopus are playing a game. The donkey is holding a rope on one end, the octopus is holding onto the other. The donkey holds the rope in its mouth. A cat is jumping over the rope.
14
+ Supreme Court Justices play a baseball game with the FBI. The FBI is at bat, the justices are on the field.
15
+ Abraham Lincoln touches his toes while George Washington does chin-ups. Lincoln is barefoot. Washington is wearing boots.
16
+ A train on top of a surfboard.
17
+ A wine glass on top of a dog.
18
+ A bicycle on top of a boat.
19
+ An umbrella on top of a spoon.
20
+ A laptop on top of a teddy bear.
21
+ A giraffe underneath a microwave.
22
+ A donut underneath a toilet.
23
+ A hair drier underneath a sheep.
24
+ A tennis racket underneath a traffic light.
25
+ A zebra underneath a broccoli.
26
+ A banana on the left of an apple.
27
+ A couch on the left of a chair.
28
+ A car on the left of a bus.
29
+ A cat on the left of a dog.
30
+ A carrot on the left of a broccoli.
31
+ A pizza on the right of a suitcase.
32
+ A cat on the right of a tennis racket.
33
+ A stop sign on the right of a refrigerator.
34
+ A sheep to the right of a wine glass.
35
+ A zebra to the right of a fire hydrant.
36
+ Acersecomicke.
37
+ Jentacular.
38
+ Matutinal.
39
+ Peristeronic.
40
+ Artophagous.
41
+ Backlotter.
42
+ Octothorpe.
43
+ A church with stained glass windows depicting a hamburger and french fries.
44
+ Painting of the orange cat Otto von Garfield, Count of Bismarck-Schönhausen, Duke of Lauenburg, Minister-President of Prussia. Depicted wearing a Prussian Pickelhaube and eating his favorite meal - lasagna.
45
+ A baby fennec sneezing onto a strawberry, detailed, macro, studio light, droplets, backlit ears.
46
+ A photo of a confused grizzly bear in calculus class.
47
+ An ancient Egyptian painting depicting an argument over whose turn it is to take out the trash.
48
+ A fluffy baby sloth with a knitted hat trying to figure out a laptop, close up, highly detailed, studio lighting, screen reflecting in its eyes.
49
+ A tiger in a lab coat with a 1980s Miami vibe, turning a well oiled science content machine, digital art.
50
+ A 1960s yearbook photo with animals dressed as humans.
51
+ Lego Arnold Schwarzenegger.
52
+ A yellow and black bus cruising through the rainforest.
53
+ A medieval painting of the wifi not working.
54
+ An IT-guy trying to fix hardware of a PC tower is being tangled by the PC cables like Laokoon. Marble, copy after Hellenistic original from ca. 200 BC. Found in the Baths of Trajan, 1506.
55
+ 35mm macro shot a kitten licking a baby duck, studio lighting.
56
+ McDonalds Church.
57
+ Photo of an athlete cat explaining it's latest scandal at a press conference to journalists.
58
+ Greek statue of a man tripping over a cat.
59
+ An old photograph of a 1920s airship shaped like a pig, floating over a wheat field.
60
+ Photo of a cat singing in a barbershop quartet.
61
+ A painting by Grant Wood of an astronaut couple, american gothic style.
62
+ An oil painting portrait of the regal Burger King posing with a Whopper.
63
+ A keyboard made of water, the water is made of light, the light is turned off.
64
+ Painting of Mona Lisa but the view is from behind of Mona Lisa.
65
+ Hyper-realistic photo of an abandoned industrial site during a storm.
66
+ A screenshot of an iOS app for ordering different types of milk.
67
+ A real life photography of super mario, 8k Ultra HD.
68
+ Colouring page of large cats climbing the eifel tower in a cyberpunk future.
69
+ Photo of a mega Lego space station inside a kid's bedroom.
70
+ A spider with a moustache bidding an equally gentlemanly grasshopper a good day during his walk to work.
71
+ A photocopy of a photograph of a painting of a sculpture of a giraffe.
72
+ A bridge connecting Europe and North America on the Atlantic Ocean, bird's eye view.
73
+ A maglev train going vertically downward in high speed, New York Times photojournalism.
74
+ A magnifying glass over a page of a 1950s batman comic.
75
+ A car playing soccer, digital art.
76
+ Darth Vader playing with raccoon in Mars during sunset.
77
+ A 1960s poster warning against climate change.
78
+ Illustration of a mouse using a mushroom as an umbrella.
79
+ A realistic photo of a Pomeranian dressed up like a 1980s professional wrestler with neon green and neon orange face paint and bright green wrestling tights with bright orange boots.
80
+ A pyramid made of falafel with a partial solar eclipse in the background.
81
+ A storefront with 'Hello World' written on it.
82
+ A storefront with 'Diffusion' written on it.
83
+ A storefront with 'Text to Image' written on it.
84
+ A storefront with 'NeurIPS' written on it.
85
+ A storefront with 'Deep Learning' written on it.
86
+ A storefront with 'Google Brain Toronto' written on it.
87
+ A storefront with 'Google Research Pizza Cafe' written on it.
88
+ A sign that says 'Hello World'.
89
+ A sign that says 'Diffusion'.
90
+ A sign that says 'Text to Image'.
91
+ A sign that says 'NeurIPS'.
92
+ A sign that says 'Deep Learning'.
93
+ A sign that says 'Google Brain Toronto'.
94
+ A sign that says 'Google Research Pizza Cafe'.
95
+ New York Skyline with 'Hello World' written with fireworks on the sky.
96
+ New York Skyline with 'Diffusion' written with fireworks on the sky.
97
+ New York Skyline with 'Text to Image' written with fireworks on the sky.
98
+ New York Skyline with 'NeurIPS' written with fireworks on the sky.
99
+ New York Skyline with 'Deep Learning' written with fireworks on the sky.
100
+ New York Skyline with 'Google Brain Toronto' written with fireworks on the sky.
101
+ New York Skyline with 'Google Research Pizza Cafe' written with fireworks on the sky.