Deadmon commited on
Commit
a60d15c
·
verified ·
1 Parent(s): 8b4565f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -34
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import torch
2
  import spaces
3
- from diffusers import StableDiffusionPipeline, DDIMScheduler, AutoencoderKL, StableDiffusionXLPipeline
4
  from transformers import AutoFeatureExtractor
5
  from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
6
  from ip_adapter.ip_adapter_faceid import IPAdapterFaceID, IPAdapterFaceIDPlus
@@ -9,7 +9,6 @@ from insightface.app import FaceAnalysis
9
  from insightface.utils import face_align
10
  import gradio as gr
11
  import cv2
12
- import os
13
 
14
  base_model_path = "SG161222/Realistic_Vision_V4.0_noVAE"
15
  vae_model_path = "stabilityai/sd-vae-ft-mse"
@@ -42,6 +41,9 @@ pipe = StableDiffusionPipeline.from_pretrained(
42
  safety_checker=None # <--- Disable safety checker
43
  ).to(device)
44
 
 
 
 
45
  ip_model = IPAdapterFaceID(pipe, ip_ckpt, device)
46
  ip_model_plus = IPAdapterFaceIDPlus(pipe, image_encoder_path, ip_plus_ckpt, device)
47
 
@@ -50,22 +52,8 @@ app.prepare(ctx_id=0, det_size=(640, 640))
50
 
51
  cv2.setNumThreads(1)
52
 
53
- # Download the SDXL model files
54
- ckpt_dir_pony = snapshot_download(repo_id="John6666/pony-realism-v21main-sdxl")
55
- ckpt_dir_cyber = snapshot_download(repo_id="John6666/cyberrealistic-pony-v61-sdxl")
56
- ckpt_dir_stallion = snapshot_download(repo_id="John6666/stallion-dreams-pony-realistic-v1-sdxl")
57
-
58
- # Load the SDXL models
59
- pipe_pony = StableDiffusionXLPipeline.from_pretrained(ckpt_dir_pony, torch_dtype=torch.float16)
60
- pipe_cyber = StableDiffusionXLPipeline.from_pretrained(ckpt_dir_cyber, torch_dtype=torch.float16)
61
- pipe_stallion = StableDiffusionXLPipeline.from_pretrained(ckpt_dir_stallion, torch_dtype=torch.float16)
62
-
63
- pipe_pony.to(device)
64
- pipe_cyber.to(device)
65
- pipe_stallion.to(device)
66
-
67
  @spaces.GPU(enable_queue=True)
68
- def generate_image(images, prompt, negative_prompt, preserve_face_structure, face_strength, likeness_strength, nfaa_negative_prompt, model_choice, progress=gr.Progress(track_tqdm=True)):
69
  faceid_all_embeds = []
70
  first_iteration = True
71
  for image in images:
@@ -81,24 +69,17 @@ def generate_image(images, prompt, negative_prompt, preserve_face_structure, fac
81
 
82
  total_negative_prompt = f"{negative_prompt} {nfaa_negative_prompt}"
83
 
84
- if model_choice == "Pony Realism v21":
85
- pipe = pipe_pony
86
- elif model_choice == "Cyber Realistic Pony v61":
87
- pipe = pipe_cyber
88
- else: # "Stallion Dreams Pony Realistic v1"
89
- pipe = pipe_stallion
90
-
91
  if(not preserve_face_structure):
92
  print("Generating normal")
93
  image = ip_model.generate(
94
  prompt=prompt, negative_prompt=total_negative_prompt, faceid_embeds=average_embedding,
95
- scale=likeness_strength, width=512, height=512, num_inference_steps=30, pipe=pipe
96
  )
97
  else:
98
  print("Generating plus")
99
  image = ip_model_plus.generate(
100
  prompt=prompt, negative_prompt=total_negative_prompt, faceid_embeds=average_embedding,
101
- scale=likeness_strength, face_image=face_image, shortcut=True, s_scale=face_strength, width=512, height=512, num_inference_steps=30, pipe=pipe
102
  )
103
  print(image)
104
  return image
@@ -114,12 +95,10 @@ def swap_to_gallery(images):
114
 
115
  def remove_back_to_files():
116
  return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)
117
-
118
  css = '''
119
  h1{margin-bottom: 0 !important}
120
  footer{display:none !important}
121
  '''
122
-
123
  with gr.Blocks(css=css) as demo:
124
  gr.Markdown("")
125
  gr.Markdown("")
@@ -137,11 +116,6 @@ with gr.Blocks(css=css) as demo:
137
  placeholder="A photo of a [man/woman/person]...")
138
  negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="low quality")
139
  style = gr.Radio(label="Generation type", info="For stylized try prompts like 'a watercolor painting of a woman'", choices=["Photorealistic", "Stylized"], value="Photorealistic")
140
- model_choice = gr.Dropdown(
141
- ["Pony Realism v21", "Cyber Realistic Pony v61", "Stallion Dreams Pony Realistic v1"],
142
- label="Model Choice",
143
- value="Pony Realism v21"
144
- )
145
  submit = gr.Button("Submit")
146
  with gr.Accordion(open=False, label="Advanced Options"):
147
  preserve = gr.Checkbox(label="Preserve Face Structure", info="Higher quality, less versatility (the face structure of your first photo will be preserved). Unchecking this will use the v1 model.", value=True)
@@ -156,7 +130,7 @@ with gr.Blocks(css=css) as demo:
156
  files.upload(fn=swap_to_gallery, inputs=files, outputs=[uploaded_files, clear_button, files])
157
  remove_and_reupload.click(fn=remove_back_to_files, outputs=[uploaded_files, clear_button, files])
158
  submit.click(fn=generate_image,
159
- inputs=[files,prompt,negative_prompt,preserve, face_strength, likeness_strength, nfaa_negative_prompts, model_choice],
160
  outputs=gallery)
161
 
162
  gr.Markdown("")
 
1
  import torch
2
  import spaces
3
+ from diffusers import StableDiffusionPipeline, DDIMScheduler, AutoencoderKL
4
  from transformers import AutoFeatureExtractor
5
  from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
6
  from ip_adapter.ip_adapter_faceid import IPAdapterFaceID, IPAdapterFaceIDPlus
 
9
  from insightface.utils import face_align
10
  import gradio as gr
11
  import cv2
 
12
 
13
  base_model_path = "SG161222/Realistic_Vision_V4.0_noVAE"
14
  vae_model_path = "stabilityai/sd-vae-ft-mse"
 
41
  safety_checker=None # <--- Disable safety checker
42
  ).to(device)
43
 
44
+ #pipe.load_lora_weights("h94/IP-Adapter-FaceID", weight_name="ip-adapter-faceid-plusv2_sd15_lora.safetensors")
45
+ #pipe.fuse_lora()
46
+
47
  ip_model = IPAdapterFaceID(pipe, ip_ckpt, device)
48
  ip_model_plus = IPAdapterFaceIDPlus(pipe, image_encoder_path, ip_plus_ckpt, device)
49
 
 
52
 
53
  cv2.setNumThreads(1)
54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  @spaces.GPU(enable_queue=True)
56
+ def generate_image(images, prompt, negative_prompt, preserve_face_structure, face_strength, likeness_strength, nfaa_negative_prompt, progress=gr.Progress(track_tqdm=True)):
57
  faceid_all_embeds = []
58
  first_iteration = True
59
  for image in images:
 
69
 
70
  total_negative_prompt = f"{negative_prompt} {nfaa_negative_prompt}"
71
 
 
 
 
 
 
 
 
72
  if(not preserve_face_structure):
73
  print("Generating normal")
74
  image = ip_model.generate(
75
  prompt=prompt, negative_prompt=total_negative_prompt, faceid_embeds=average_embedding,
76
+ scale=likeness_strength, width=512, height=512, num_inference_steps=30
77
  )
78
  else:
79
  print("Generating plus")
80
  image = ip_model_plus.generate(
81
  prompt=prompt, negative_prompt=total_negative_prompt, faceid_embeds=average_embedding,
82
+ scale=likeness_strength, face_image=face_image, shortcut=True, s_scale=face_strength, width=512, height=512, num_inference_steps=30
83
  )
84
  print(image)
85
  return image
 
95
 
96
  def remove_back_to_files():
97
  return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)
 
98
  css = '''
99
  h1{margin-bottom: 0 !important}
100
  footer{display:none !important}
101
  '''
 
102
  with gr.Blocks(css=css) as demo:
103
  gr.Markdown("")
104
  gr.Markdown("")
 
116
  placeholder="A photo of a [man/woman/person]...")
117
  negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="low quality")
118
  style = gr.Radio(label="Generation type", info="For stylized try prompts like 'a watercolor painting of a woman'", choices=["Photorealistic", "Stylized"], value="Photorealistic")
 
 
 
 
 
119
  submit = gr.Button("Submit")
120
  with gr.Accordion(open=False, label="Advanced Options"):
121
  preserve = gr.Checkbox(label="Preserve Face Structure", info="Higher quality, less versatility (the face structure of your first photo will be preserved). Unchecking this will use the v1 model.", value=True)
 
130
  files.upload(fn=swap_to_gallery, inputs=files, outputs=[uploaded_files, clear_button, files])
131
  remove_and_reupload.click(fn=remove_back_to_files, outputs=[uploaded_files, clear_button, files])
132
  submit.click(fn=generate_image,
133
+ inputs=[files,prompt,negative_prompt,preserve, face_strength, likeness_strength, nfaa_negative_prompts],
134
  outputs=gallery)
135
 
136
  gr.Markdown("")