clementchadebec commited on
Commit
d1eabfd
1 Parent(s): 620559b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -22
app.py CHANGED
@@ -5,25 +5,37 @@ from diffusers import StableDiffusionPipeline, LCMScheduler
5
  import torch
6
 
7
  device = "cuda" if torch.cuda.is_available() else "cpu"
8
- adapter_id = "jasperai/flash-sd"
 
 
 
 
 
 
 
 
 
 
9
 
10
  if torch.cuda.is_available():
11
  torch.cuda.max_memory_allocated(device=device)
12
- pipe = StableDiffusionPipeline.from_pretrained(
13
- "runwayml/stable-diffusion-v1-5",
14
- use_safetensors=True,
 
15
  )
16
  pipe.enable_xformers_memory_efficient_attention()
17
  pipe = pipe.to(device)
18
  else:
19
- pipe = StableDiffusionPipeline.from_pretrained(
20
- "runwayml/stable-diffusion-v1-5",
21
- use_safetensors=True,
 
22
  )
23
  pipe = pipe.to(device)
24
 
25
  pipe.scheduler = LCMScheduler.from_pretrained(
26
- "runwayml/stable-diffusion-v1-5",
27
  subfolder="scheduler",
28
  timestep_spacing="trailing",
29
  )
@@ -32,7 +44,8 @@ pipe.load_lora_weights(adapter_id)
32
  pipe.fuse_lora()
33
 
34
  MAX_SEED = np.iinfo(np.int32).max
35
- MAX_IMAGE_SIZE = 512
 
36
 
37
  def infer(prompt, seed, randomize_seed, num_inference_steps):
38
 
@@ -59,7 +72,7 @@ examples = [
59
  css="""
60
  #col-container {
61
  margin: 0 auto;
62
- max-width: 520px;
63
  }
64
  """
65
 
@@ -104,17 +117,6 @@ with gr.Blocks(css=css) as demo:
104
 
105
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
106
 
107
-
108
-
109
- with gr.Row():
110
-
111
- num_inference_steps = gr.Slider(
112
- label="Number of inference steps",
113
- minimum=2,
114
- maximum=8,
115
- step=1,
116
- value=4,
117
- )
118
 
119
  gr.Examples(
120
  examples = examples,
@@ -123,7 +125,7 @@ with gr.Blocks(css=css) as demo:
123
 
124
  run_button.click(
125
  fn = infer,
126
- inputs = [prompt, seed, randomize_seed, num_inference_steps],
127
  outputs = [result]
128
  )
129
 
 
5
  import torch
6
 
7
  device = "cuda" if torch.cuda.is_available() else "cpu"
8
+
9
+ transformer = Transformer2DModel.from_pretrained(
10
+ "PixArt-alpha/PixArt-XL-2-1024-MS",
11
+ subfolder="transformer",
12
+ torch_dtype=torch.float16
13
+ )
14
+ transformer = PeftModel.from_pretrained(
15
+ transformer,
16
+ "jasperai/flash-pixart"
17
+ )
18
+
19
 
20
  if torch.cuda.is_available():
21
  torch.cuda.max_memory_allocated(device=device)
22
+ pipe = PixArtAlphaPipeline.from_pretrained(
23
+ "PixArt-alpha/PixArt-XL-2-1024-MS",
24
+ transformer=transformer,
25
+ torch_dtype=torch.float16
26
  )
27
  pipe.enable_xformers_memory_efficient_attention()
28
  pipe = pipe.to(device)
29
  else:
30
+ pipe = PixArtAlphaPipeline.from_pretrained(
31
+ "PixArt-alpha/PixArt-XL-2-1024-MS",
32
+ transformer=transformer,
33
+ torch_dtype=torch.float16
34
  )
35
  pipe = pipe.to(device)
36
 
37
  pipe.scheduler = LCMScheduler.from_pretrained(
38
+ "PixArt-alpha/PixArt-XL-2-1024-MS",
39
  subfolder="scheduler",
40
  timestep_spacing="trailing",
41
  )
 
44
  pipe.fuse_lora()
45
 
46
  MAX_SEED = np.iinfo(np.int32).max
47
+ MAX_IMAGE_SIZE = 1024
48
+ NUM_INFERENCE_STEPS = 4
49
 
50
  def infer(prompt, seed, randomize_seed, num_inference_steps):
51
 
 
72
  css="""
73
  #col-container {
74
  margin: 0 auto;
75
+ max-width: 512px;
76
  }
77
  """
78
 
 
117
 
118
  randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
119
 
 
 
 
 
 
 
 
 
 
 
 
120
 
121
  gr.Examples(
122
  examples = examples,
 
125
 
126
  run_button.click(
127
  fn = infer,
128
+ inputs = [prompt, seed, randomize_seed, NUM_INFERENCE_STEPS],
129
  outputs = [result]
130
  )
131