YuwanA55 commited on
Commit
84619a2
1 Parent(s): 8a35bc6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +87 -89
app.py CHANGED
@@ -1,38 +1,57 @@
1
- import gradio as gr
2
- import numpy as np
3
  import random
 
4
 
5
- # import spaces #[uncomment to use ZeroGPU]
6
- from diffusers import DiffusionPipeline
7
  import torch
 
 
 
8
 
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
- model_repo_id = "stabilityai/sdxl-turbo" # Replace to the model you would like to use
 
 
 
 
 
 
 
 
11
 
12
  if torch.cuda.is_available():
13
- torch_dtype = torch.float16
 
 
 
 
 
 
 
 
14
  else:
15
- torch_dtype = torch.float32
 
 
 
 
 
16
 
17
- pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
18
- pipe = pipe.to(device)
 
 
 
 
 
19
 
20
  MAX_SEED = np.iinfo(np.int32).max
21
  MAX_IMAGE_SIZE = 1024
 
22
 
23
 
24
- # @spaces.GPU #[uncomment to use ZeroGPU]
25
- def infer(
26
- prompt,
27
- negative_prompt,
28
- seed,
29
- randomize_seed,
30
- width,
31
- height,
32
- guidance_scale,
33
- num_inference_steps,
34
- progress=gr.Progress(track_tqdm=True),
35
- ):
36
  if randomize_seed:
37
  seed = random.randint(0, MAX_SEED)
38
 
@@ -40,33 +59,55 @@ def infer(
40
 
41
  image = pipe(
42
  prompt=prompt,
43
- negative_prompt=negative_prompt,
44
- guidance_scale=guidance_scale,
45
- num_inference_steps=num_inference_steps,
46
- width=width,
47
- height=height,
48
  generator=generator,
49
  ).images[0]
50
 
51
- return image, seed
52
 
53
 
54
  examples = [
55
- "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
56
- "An astronaut riding a green horse",
57
- "A delicious ceviche cheesecake slice",
 
 
 
 
 
 
58
  ]
59
 
60
  css = """
61
  #col-container {
62
  margin: 0 auto;
63
- max-width: 640px;
64
  }
65
  """
66
 
 
 
 
 
 
67
  with gr.Blocks(css=css) as demo:
68
  with gr.Column(elem_id="col-container"):
69
- gr.Markdown(" # Text-to-Image Gradio Template")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
  with gr.Row():
72
  prompt = gr.Text(
@@ -77,18 +118,11 @@ with gr.Blocks(css=css) as demo:
77
  container=False,
78
  )
79
 
80
- run_button = gr.Button("Run", scale=0, variant="primary")
81
 
82
  result = gr.Image(label="Result", show_label=False)
83
 
84
  with gr.Accordion("Advanced Settings", open=False):
85
- negative_prompt = gr.Text(
86
- label="Negative prompt",
87
- max_lines=1,
88
- placeholder="Enter a negative prompt",
89
- visible=False,
90
- )
91
-
92
  seed = gr.Slider(
93
  label="Seed",
94
  minimum=0,
@@ -99,56 +133,20 @@ with gr.Blocks(css=css) as demo:
99
 
100
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
101
 
102
- with gr.Row():
103
- width = gr.Slider(
104
- label="Width",
105
- minimum=256,
106
- maximum=MAX_IMAGE_SIZE,
107
- step=32,
108
- value=1024, # Replace with defaults that work for your model
109
- )
110
-
111
- height = gr.Slider(
112
- label="Height",
113
- minimum=256,
114
- maximum=MAX_IMAGE_SIZE,
115
- step=32,
116
- value=1024, # Replace with defaults that work for your model
117
- )
118
-
119
- with gr.Row():
120
- guidance_scale = gr.Slider(
121
- label="Guidance scale",
122
- minimum=0.0,
123
- maximum=10.0,
124
- step=0.1,
125
- value=0.0, # Replace with defaults that work for your model
126
- )
127
-
128
- num_inference_steps = gr.Slider(
129
- label="Number of inference steps",
130
- minimum=1,
131
- maximum=50,
132
- step=1,
133
- value=2, # Replace with defaults that work for your model
134
- )
135
-
136
- gr.Examples(examples=examples, inputs=[prompt])
137
  gr.on(
138
- triggers=[run_button.click, prompt.submit],
139
  fn=infer,
140
- inputs=[
141
- prompt,
142
- negative_prompt,
143
- seed,
144
- randomize_seed,
145
- width,
146
- height,
147
- guidance_scale,
148
- num_inference_steps,
149
- ],
150
- outputs=[result, seed],
151
  )
152
 
153
- if __name__ == "__main__":
154
- demo.launch()
 
 
 
1
  import random
2
+ import spaces
3
 
4
+ import gradio as gr
5
+ import numpy as np
6
  import torch
7
+ from diffusers import LCMScheduler, PixArtAlphaPipeline, Transformer2DModel
8
+ from peft import PeftModel
9
+ import os
10
 
11
  device = "cuda" if torch.cuda.is_available() else "cpu"
12
+ IS_SPACE = os.environ.get("SPACE_ID", None) is not None
13
+
14
+ transformer = Transformer2DModel.from_pretrained(
15
+ "PixArt-alpha/PixArt-XL-2-1024-MS",
16
+ subfolder="transformer",
17
+ torch_dtype=torch.float16,
18
+ )
19
+ transformer = PeftModel.from_pretrained(transformer, "jasperai/flash-pixart")
20
+
21
 
22
  if torch.cuda.is_available():
23
+ torch.cuda.max_memory_allocated(device=device)
24
+ pipe = PixArtAlphaPipeline.from_pretrained(
25
+ "PixArt-alpha/PixArt-XL-2-1024-MS",
26
+ transformer=transformer,
27
+ torch_dtype=torch.float16,
28
+ )
29
+ if not IS_SPACE:
30
+ pipe.enable_xformers_memory_efficient_attention()
31
+ pipe = pipe.to(device)
32
  else:
33
+ pipe = PixArtAlphaPipeline.from_pretrained(
34
+ "PixArt-alpha/PixArt-XL-2-1024-MS",
35
+ transformer=transformer,
36
+ torch_dtype=torch.float16,
37
+ )
38
+ pipe = pipe.to(device)
39
 
40
+ pipe.text_encoder.to_bettertransformer()
41
+
42
+ pipe.scheduler = LCMScheduler.from_pretrained(
43
+ "PixArt-alpha/PixArt-XL-2-1024-MS",
44
+ subfolder="scheduler",
45
+ timestep_spacing="trailing",
46
+ )
47
 
48
  MAX_SEED = np.iinfo(np.int32).max
49
  MAX_IMAGE_SIZE = 1024
50
+ NUM_INFERENCE_STEPS = 4
51
 
52
 
53
+ @spaces.GPU
54
+ def infer(prompt, seed, randomize_seed):
 
 
 
 
 
 
 
 
 
 
55
  if randomize_seed:
56
  seed = random.randint(0, MAX_SEED)
57
 
 
59
 
60
  image = pipe(
61
  prompt=prompt,
62
+ guidance_scale=0,
63
+ num_inference_steps=NUM_INFERENCE_STEPS,
 
 
 
64
  generator=generator,
65
  ).images[0]
66
 
67
+ return image
68
 
69
 
70
  examples = [
71
+ "The image showcases a freshly baked bread, possibly focaccia, with rosemary sprigs and red pepper flakes sprinkled on top. It's sliced and placed on a wire cooling rack, with a bowl of mixed peppercorns beside it.",
72
+ "A raccoon reading a book in a lush forest.",
73
+ "A small cactus with a happy face in the Sahara desert.",
74
+ "A super-realistic close-up of a snake eye",
75
+ "A cute cheetah looking amazed and surprised",
76
+ "Pirate ship sailing on a sea with the milky way galaxy in the sky and purple glow lights",
77
+ "a cute fluffy rabbit pilot walking on a military aircraft carrier, 8k, cinematic",
78
+ "A close up of an old elderly man with green eyes looking straight at the camera",
79
+ "A beautiful sunflower in rainy day",
80
  ]
81
 
82
  css = """
83
  #col-container {
84
  margin: 0 auto;
85
+ max-width: 512px;
86
  }
87
  """
88
 
89
+ if torch.cuda.is_available():
90
+ power_device = "GPU"
91
+ else:
92
+ power_device = "CPU"
93
+
94
  with gr.Blocks(css=css) as demo:
95
  with gr.Column(elem_id="col-container"):
96
+ gr.Markdown(
97
+ f"""
98
+ # ⚡ Flash Diffusion: FlashPixart ⚡
99
+ This is an interactive demo of [Flash Diffusion](https://gojasper.github.io/flash-diffusion-project/), a diffusion distillation method proposed in [Flash Diffusion: Accelerating Any Conditional
100
+ Diffusion Model for Few Steps Image Generation](http://arxiv.org/abs/2406.02347) *by Clément Chadebec, Onur Tasar, Eyal Benaroche and Benjamin Aubin* from Jasper Research.
101
+ [This model](https://huggingface.co/jasperai/flash-pixart) is a **66.5M** LoRA distilled version of [Pixart-α](https://huggingface.co/PixArt-alpha/PixArt-XL-2-1024-MS) model that is able to generate 1024x1024 images in **4 steps**.
102
+ Currently running on {power_device}.
103
+ """
104
+ )
105
+ gr.Markdown(
106
+ "If you enjoy the space, please also promote *open-source* by giving a ⭐ to the <a href='https://github.com/gojasper/flash-diffusion' target='_blank'>Github Repo</a>."
107
+ )
108
+ gr.Markdown(
109
+ "💡 *Hint:* To better appreciate the low latency of our method, run the demo locally !"
110
+ )
111
 
112
  with gr.Row():
113
  prompt = gr.Text(
 
118
  container=False,
119
  )
120
 
121
+ run_button = gr.Button("Run", scale=0)
122
 
123
  result = gr.Image(label="Result", show_label=False)
124
 
125
  with gr.Accordion("Advanced Settings", open=False):
 
 
 
 
 
 
 
126
  seed = gr.Slider(
127
  label="Seed",
128
  minimum=0,
 
133
 
134
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
135
 
136
+ examples = gr.Examples(examples=examples, inputs=[prompt])
137
+
138
+ gr.Markdown("**Disclaimer:**")
139
+ gr.Markdown(
140
+ "This demo is only for research purpose. Jasper cannot be held responsible for the generation of NSFW (Not Safe For Work) content through the use of this demo. Users are solely responsible for any content they create, and it is their obligation to ensure that it adheres to appropriate and ethical standards. Jasper provides the tools, but the responsibility for their use lies with the individual user."
141
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  gr.on(
143
+ [run_button.click, seed.change, randomize_seed.change, prompt.submit],
144
  fn=infer,
145
+ inputs=[prompt, seed, randomize_seed],
146
+ outputs=[result],
147
+ show_progress="minimal",
148
+ show_api=False,
149
+ trigger_mode="always_last",
 
 
 
 
 
 
150
  )
151
 
152
+ demo.queue().launch(show_api=False)