gokaygokay commited on
Commit
509539f
1 Parent(s): 5328b7c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -2
app.py CHANGED
@@ -50,10 +50,12 @@ download_file("https://huggingface.co/ai-forever/Real-ESRGAN/resolve/main/RealES
50
  # Download the model files
51
  ckpt_dir_pony = snapshot_download(repo_id="John6666/pony-realism-v21main-sdxl")
52
  ckpt_dir_cyber = snapshot_download(repo_id="John6666/cyberrealistic-pony-v61-sdxl")
 
53
 
54
  # Load the models
55
  vae_pony = AutoencoderKL.from_pretrained(os.path.join(ckpt_dir_pony, "vae"), torch_dtype=torch.float16)
56
  vae_cyber = AutoencoderKL.from_pretrained(os.path.join(ckpt_dir_cyber, "vae"), torch_dtype=torch.float16)
 
57
 
58
  pipe_pony = StableDiffusionXLPipeline.from_pretrained(
59
  ckpt_dir_pony,
@@ -69,12 +71,21 @@ pipe_cyber = StableDiffusionXLPipeline.from_pretrained(
69
  use_safetensors=True,
70
  variant="fp16"
71
  )
 
 
 
 
 
 
 
72
 
73
  pipe_pony = pipe_pony.to("cuda")
74
  pipe_cyber = pipe_cyber.to("cuda")
 
75
 
76
  pipe_pony.unet.set_attn_processor(AttnProcessor2_0())
77
  pipe_cyber.unet.set_attn_processor(AttnProcessor2_0())
 
78
 
79
  # Define samplers
80
  samplers = {
@@ -181,7 +192,12 @@ def generate_image(model_choice, additional_positive_prompt, additional_negative
181
  input_image=None, progress=gr.Progress(track_tqdm=True)):
182
 
183
  # Select the appropriate pipe based on the model choice
184
- pipe = pipe_pony if model_choice == "Pony Realism v21" else pipe_cyber
 
 
 
 
 
185
 
186
  if use_random_seed:
187
  seed = random.randint(0, 2**32 - 1)
@@ -286,7 +302,7 @@ with gr.Blocks(theme='bethecloud/storj_theme') as demo:
286
  with gr.Accordion("Advanced settings", open=False):
287
  height = gr.Slider(512, 2048, 1024, step=64, label="Height")
288
  width = gr.Slider(512, 2048, 1024, step=64, label="Width")
289
- num_inference_steps = gr.Slider(20, 50, 30, step=1, label="Number of Inference Steps")
290
  guidance_scale = gr.Slider(1, 20, 6, step=0.1, label="Guidance Scale")
291
  num_images_per_prompt = gr.Slider(1, 4, 1, step=1, label="Number of images per prompt")
292
  use_random_seed = gr.Checkbox(label="Use Random Seed", value=True)
 
50
  # Download the model files
51
  ckpt_dir_pony = snapshot_download(repo_id="John6666/pony-realism-v21main-sdxl")
52
  ckpt_dir_cyber = snapshot_download(repo_id="John6666/cyberrealistic-pony-v61-sdxl")
53
+ ckpt_dir_stallion = snapshot_download(repo_id="John6666/stallion-dreams-pony-realistic-v1-sdxl")
54
 
55
  # Load the models
56
  vae_pony = AutoencoderKL.from_pretrained(os.path.join(ckpt_dir_pony, "vae"), torch_dtype=torch.float16)
57
  vae_cyber = AutoencoderKL.from_pretrained(os.path.join(ckpt_dir_cyber, "vae"), torch_dtype=torch.float16)
58
+ vae_stallion = AutoencoderKL.from_pretrained(os.path.join(ckpt_dir_stallion, "vae"), torch_dtype=torch.float16)
59
 
60
  pipe_pony = StableDiffusionXLPipeline.from_pretrained(
61
  ckpt_dir_pony,
 
71
  use_safetensors=True,
72
  variant="fp16"
73
  )
74
+ pipe_stallion = StableDiffusionXLPipeline.from_pretrained(
75
+ ckpt_dir_stallion,
76
+ vae=vae_stallion,
77
+ torch_dtype=torch.float16,
78
+ use_safetensors=True,
79
+ variant="fp16"
80
+ )
81
 
82
  pipe_pony = pipe_pony.to("cuda")
83
  pipe_cyber = pipe_cyber.to("cuda")
84
+ pipe_stallion = pipe_stallion.to("cuda")
85
 
86
  pipe_pony.unet.set_attn_processor(AttnProcessor2_0())
87
  pipe_cyber.unet.set_attn_processor(AttnProcessor2_0())
88
+ pipe_stallion.unet.set_attn_processor(AttnProcessor2_0())
89
 
90
  # Define samplers
91
  samplers = {
 
192
  input_image=None, progress=gr.Progress(track_tqdm=True)):
193
 
194
  # Select the appropriate pipe based on the model choice
195
+ if model_choice == "Pony Realism v21":
196
+ pipe = pipe_pony
197
+ elif model_choice == "Cyber Realistic Pony v61":
198
+ pipe = pipe_cyber
199
+ else: # "Stallion Dreams Pony Realistic v1"
200
+ pipe = pipe_stallion
201
 
202
  if use_random_seed:
203
  seed = random.randint(0, 2**32 - 1)
 
302
  with gr.Accordion("Advanced settings", open=False):
303
  height = gr.Slider(512, 2048, 1024, step=64, label="Height")
304
  width = gr.Slider(512, 2048, 1024, step=64, label="Width")
305
+ num_inference_steps = gr.Slider(20, 100, 30, step=1, label="Number of Inference Steps")
306
  guidance_scale = gr.Slider(1, 20, 6, step=0.1, label="Guidance Scale")
307
  num_images_per_prompt = gr.Slider(1, 4, 1, step=1, label="Number of images per prompt")
308
  use_random_seed = gr.Checkbox(label="Use Random Seed", value=True)