Deadmon committed on
Commit d2fcb60
1 Parent(s): 613b409

Update app.py

Files changed (1)
  1. app.py +166 -144
app.py CHANGED
@@ -1,160 +1,182 @@
- import torch
- import spaces
- from diffusers import StableDiffusionPipeline, DDIMScheduler, AutoencoderKL
- from transformers import AutoFeatureExtractor
- from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
- from ip_adapter.ip_adapter_faceid import IPAdapterFaceID, IPAdapterFaceIDPlus
- from huggingface_hub import hf_hub_download
- from insightface.app import FaceAnalysis
- from insightface.utils import face_align
  import gradio as gr
- import cv2
-
  base_model_paths = {
-     "RealisticVisionV4": "SG161222/Realistic_Vision_V4.0_noVAE",
-     "RealisticVisionV6": "SG161222/Realistic_Vision_V6.0_B1_noVAE",
      "Deliberate": "Yntec/Deliberate",
-     "DeliberateV2": "Yntec/Deliberate2",
-     "Dreamshaper8": "Lykon/dreamshaper-8",
-     "EpicRealism": "emilianJR/epiCRealism"
  }

-
- vae_model_path = "stabilityai/sd-vae-ft-mse"
- image_encoder_path = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
- ip_ckpt = hf_hub_download(repo_id="h94/IP-Adapter-FaceID", filename="ip-adapter-faceid_sd15.bin", repo_type="model")
- ip_plus_ckpt = hf_hub_download(repo_id="h94/IP-Adapter-FaceID", filename="ip-adapter-faceid-plusv2_sd15.bin", repo_type="model")
-
- safety_model_id = "CompVis/stable-diffusion-safety-checker"
- safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
- safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
-
- device = "cuda"
-
- noise_scheduler = DDIMScheduler(
-     num_train_timesteps=1000,
-     beta_start=0.00085,
-     beta_end=0.012,
-     beta_schedule="scaled_linear",
-     clip_sample=False,
-     set_alpha_to_one=False,
-     steps_offset=1,
- )
- vae = AutoencoderKL.from_pretrained(vae_model_path).to(dtype=torch.float16)
-
- def load_model(base_model_path):
-     pipe = StableDiffusionPipeline.from_pretrained(
-         base_model_path,
-         torch_dtype=torch.float16,
-         scheduler=noise_scheduler,
-         vae=vae,
-         feature_extractor=safety_feature_extractor,
-         safety_checker=None  # <--- Disable safety checker
-     ).to(device)
-     return pipe
-
- ip_model = None
- ip_model_plus = None
-
- app = FaceAnalysis(name="buffalo_l", providers=['CPUExecutionProvider'])
- app.prepare(ctx_id=0, det_size=(640, 640))
-
- cv2.setNumThreads(1)
-
- @spaces.GPU(enable_queue=True)
- def generate_image(images, prompt, negative_prompt, preserve_face_structure, face_strength, likeness_strength, nfaa_negative_prompt, base_model, num_inference_steps, guidance_scale, width, height, progress=gr.Progress(track_tqdm=True)):
-     global ip_model, ip_model_plus
-     base_model_path = base_model_paths[base_model]
-     pipe = load_model(base_model_path)
-     ip_model = IPAdapterFaceID(pipe, ip_ckpt, device)
-     ip_model_plus = IPAdapterFaceIDPlus(pipe, image_encoder_path, ip_plus_ckpt, device)
-
-     faceid_all_embeds = []
-     first_iteration = True
-     for image in images:
-         face = cv2.imread(image)
-         faces = app.get(face)
-         faceid_embed = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0)
-         faceid_all_embeds.append(faceid_embed)
-         if(first_iteration and preserve_face_structure):
-             face_image = face_align.norm_crop(face, landmark=faces[0].kps, image_size=224)  # you can also segment the face
-             first_iteration = False
-
-     average_embedding = torch.mean(torch.stack(faceid_all_embeds, dim=0), dim=0)
-
-     total_negative_prompt = f"{negative_prompt} {nfaa_negative_prompt}"
-
-     if(not preserve_face_structure):
-         print("Generating normal")
-         image = ip_model.generate(
-             prompt=prompt, negative_prompt=total_negative_prompt, faceid_embeds=average_embedding,
-             scale=likeness_strength, width=width, height=height, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale
-         )
-     else:
-         print("Generating plus")
-         image = ip_model_plus.generate(
-             prompt=prompt, negative_prompt=total_negative_prompt, faceid_embeds=average_embedding,
-             scale=likeness_strength, face_image=face_image, shortcut=True, s_scale=face_strength, width=width, height=height, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale
-         )
-     print(image)
-     return image

  def change_style(style):
      if style == "Photorealistic":
-         return(gr.update(value=True), gr.update(value=1.3), gr.update(value=1.0))
      else:
-         return(gr.update(value=True), gr.update(value=0.1), gr.update(value=0.8))

- def swap_to_gallery(images):
-     return gr.update(value=images, visible=True), gr.update(visible=True), gr.update(visible=False)

- def remove_back_to_files():
-     return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)

- css = '''
- h1{margin-bottom: 0 !important}
- footer{display:none !important}
- '''

- with gr.Blocks(css=css) as demo:
-     gr.Markdown("")
-     gr.Markdown("")
-     with gr.Row():
-         with gr.Column():
-             files = gr.Files(
-                 label="Drag 1 or more photos of your face",
-                 file_types=["image"]
-             )
-             uploaded_files = gr.Gallery(label="Your images", visible=False, columns=5, rows=1, height=125)
-             with gr.Column(visible=False) as clear_button:
-                 remove_and_reupload = gr.ClearButton(value="Remove and upload new ones", components=files, size="sm")
-             prompt = gr.Textbox(label="Prompt",
-                                 info="Try something like 'a photo of a man/woman/person'",
-                                 placeholder="A photo of a [man/woman/person]...")
-             negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="low quality")
-             style = gr.Radio(label="Generation type", info="For stylized try prompts like 'a watercolor painting of a woman'", choices=["Photorealistic", "Stylized"], value="Photorealistic")
-             base_model = gr.Dropdown(label="Base Model", choices=list(base_model_paths.keys()), value="Realistic_Vision_V4.0_noVAE")
-             submit = gr.Button("Submit")
-             with gr.Accordion(open=False, label="Advanced Options"):
-                 preserve = gr.Checkbox(label="Preserve Face Structure", info="Higher quality, less versatility (the face structure of your first photo will be preserved). Unchecking this will use the v1 model.", value=True)
-                 face_strength = gr.Slider(label="Face Structure strength", info="Only applied if preserve face structure is checked", value=1.3, step=0.1, minimum=0, maximum=3)
-                 likeness_strength = gr.Slider(label="Face Embed strength", value=1.0, step=0.1, minimum=0, maximum=5)
-                 nfaa_negative_prompts = gr.Textbox(label="Appended Negative Prompts", info="Negative prompts to steer generations towards safe for all audiences outputs", value="naked, bikini, skimpy, scanty, bare skin, lingerie, swimsuit, exposed, see-through")
-                 num_inference_steps = gr.Slider(label="Number of Inference Steps", value=30, step=1, minimum=10, maximum=100)
-                 guidance_scale = gr.Slider(label="Guidance Scale", value=7.5, step=0.1, minimum=1, maximum=20)
-                 width = gr.Slider(label="Width", value=512, step=64, minimum=256, maximum=1024)
-                 height = gr.Slider(label="Height", value=512, step=64, minimum=256, maximum=1024)
          with gr.Column():
-             gallery = gr.Gallery(label="Generated Images")
-     style.change(fn=change_style,
-                  inputs=style,
-                  outputs=[preserve, face_strength, likeness_strength])
-     files.upload(fn=swap_to_gallery, inputs=files, outputs=[uploaded_files, clear_button, files])
-     remove_and_reupload.click(fn=remove_back_to_files, outputs=[uploaded_files, clear_button, files])
-     submit.click(fn=generate_image,
-                  inputs=[files, prompt, negative_prompt, preserve, face_strength, likeness_strength, nfaa_negative_prompts, base_model, num_inference_steps, guidance_scale, width, height],
-                  outputs=gallery)
-
-     gr.Markdown("")
-
- demo.launch()

  import gradio as gr
+ import asyncio
+ import fal_client
+ from PIL import Image
+ import requests
+ import io
+ import os
+
+ # Set up your Fal API key as an environment variable
+ os.environ["FAL_KEY"] = "b6fa8d06-4225-4ec3-9aaf-4d01e960d899:cc6a52d0fc818c6f892b2760fd341ee4"
+ fal_client.api_key = os.environ["FAL_KEY"]
+
+ # Model choices (base models)
  base_model_paths = {
+     "Realistic Vision V4": "SG161222/Realistic_Vision_V4.0_noVAE",
+     "Realistic Vision V6": "SG161222/Realistic_Vision_V6.0_B1_noVAE",
      "Deliberate": "Yntec/Deliberate",
+     "Deliberate V2": "Yntec/Deliberate2",
+     "Dreamshaper 8": "Lykon/dreamshaper-8",
+     "Epic Realism": "emilianJR/epiCRealism"
  }

+ async def generate_image(image_url: str, prompt: str, negative_prompt: str, model_type: str, base_model: str, seed: int, guidance_scale: float, num_inference_steps: int, num_samples: int, width: int, height: int):
+     """
+     Submit the image generation process using the fal_client's submit method with the ip-adapter-face-id model.
+     """
+     try:
+         handler = fal_client.submit(
+             "fal-ai/ip-adapter-face-id",
+             arguments={
+                 "model_type": model_type,
+                 "prompt": prompt,
+                 "face_image_url": image_url,
+                 "negative_prompt": negative_prompt,
+                 "seed": seed,
+                 "guidance_scale": guidance_scale,
+                 "num_inference_steps": num_inference_steps,
+                 "num_samples": num_samples,
+                 "width": width,
+                 "height": height,
+                 "base_1_5_model_repo": base_model_paths[base_model],  # Base model selected by user
+                 "base_sdxl_model_repo": "SG161222/RealVisXL_V3.0",  # SDXL model as default
+             },
+         )
+         # Retrieve the result synchronously
+         result = handler.get()
+
+         if "image" in result and "url" in result["image"]:
+             return result["image"]  # Return the full image information dictionary
+         else:
+             return None
+     except Exception as e:
+         print(f"Error generating image: {e}")
+         return None
+
+ def fetch_image_from_url(url: str) -> Image.Image:
+     """
+     Download the image from the given URL and return it as a PIL Image.
+     """
+     response = requests.get(url)
+     return Image.open(io.BytesIO(response.content))
+
+ async def process_inputs(image: Image.Image, prompt: str, negative_prompt: str, model_type: str, base_model: str, seed: int, guidance_scale: float, num_inference_steps: int, num_samples: int, width: int, height: int):
+     """
+     Asynchronous function to handle image upload, prompt inputs and generate the final image.
+     """
+     # Upload the image and get a valid URL
+     image_url = await upload_image_to_server(image)

+     if not image_url:
+         return None

+     # Run the image generation
+     image_info = await generate_image(image_url, prompt, negative_prompt, model_type, base_model, seed, guidance_scale, num_inference_steps, num_samples, width, height)
+
+     if image_info and "url" in image_info:
+         return fetch_image_from_url(image_info["url"]), image_info  # Return both the image and the metadata
+
+     return None, None
+
+ async def upload_image_to_server(image: Image.Image) -> str:
+     """
+     Upload an image to the fal_client and return the uploaded image URL.
+     """
+     # Convert PIL image to byte stream for upload
+     byte_arr = io.BytesIO()
+     image.save(byte_arr, format='PNG')
+     byte_arr.seek(0)
+
+     # Convert BytesIO to a file-like object that fal_client can handle
+     with open("temp_image.png", "wb") as f:
+         f.write(byte_arr.getvalue())
+
+     # Upload the image using fal_client's asynchronous method
+     try:
+         upload_url = await fal_client.upload_file_async("temp_image.png")
+         return upload_url
+     except Exception as e:
+         print(f"Error uploading image: {e}")
+         return ""

  def change_style(style):
+     """
+     Changes the style for 'Photorealistic' or 'Stylized' generation type.
+     """
      if style == "Photorealistic":
+         return gr.update(value=True), gr.update(value=1.3), gr.update(value=1.0)
      else:
+         return gr.update(value=True), gr.update(value=0.1), gr.update(value=0.8)
+
+ def gradio_interface(image, prompt, negative_prompt, model_type, base_model, seed, guidance_scale, num_inference_steps, num_samples, width, height):
+     """
+     Wrapper function to run asynchronous code in a synchronous environment like Gradio.
+     """
+     loop = asyncio.new_event_loop()
+     asyncio.set_event_loop(loop)
+
+     # Execute the async process_inputs function
+     result_image, image_info = loop.run_until_complete(process_inputs(image, prompt, negative_prompt, model_type, base_model, seed, guidance_scale, num_inference_steps, num_samples, width, height))
+     if result_image:
+         # Display both the image and metadata
+         metadata = f"File Name: {image_info['file_name']}\nFile Size: {image_info['file_size']} bytes\nDimensions: {image_info['width']}x{image_info['height']} px\nSeed: {image_info.get('seed', 'N/A')}"
+         return result_image, metadata
+     return None, "Error generating image"
+
+ # Gradio Interface
+ with gr.Blocks() as demo:
+     gr.Markdown("## Image Generation with Fal API and Gradio")
+
+     with gr.Row():
+         with gr.Column():
+             # Image input
+             image_input = gr.Image(label="Upload Image", type="pil")
+
+             # Textbox for prompt
+             prompt_input = gr.Textbox(label="Prompt", placeholder="Describe the image you want to generate", lines=2)
+
+             # Textbox for negative prompt
+             negative_prompt_input = gr.Textbox(label="Negative Prompt", placeholder="Describe elements to avoid", lines=2)
+
+             # Radio buttons for model type (Photorealistic or Stylized)
+             style = gr.Radio(label="Generation type", choices=["Photorealistic", "Stylized"], value="Photorealistic")
+
+             # Dropdown for selecting the base model
+             base_model = gr.Dropdown(label="Base Model", choices=list(base_model_paths.keys()), value="Realistic Vision V4")

+             # Seed input
+             seed_input = gr.Number(label="Seed", value=42, precision=0)

+             # Guidance scale slider
+             guidance_scale = gr.Slider(label="Guidance Scale", value=7.5, step=0.1, minimum=1, maximum=20)

+             # Inference steps slider
+             num_inference_steps = gr.Slider(label="Number of Inference Steps", value=50, step=1, minimum=10, maximum=100)

+             # Samples slider
+             num_samples = gr.Slider(label="Number of Samples", value=4, step=1, minimum=1, maximum=10)
+
+             # Image dimensions sliders
+             width = gr.Slider(label="Width", value=1024, step=64, minimum=256, maximum=1024)
+             height = gr.Slider(label="Height", value=1024, step=64, minimum=256, maximum=1024)
+
+             # Button to trigger image generation
+             generate_button = gr.Button("Generate Image")
+
          with gr.Column():
+             # Display generated image and metadata
+             generated_image = gr.Image(label="Generated Image")
+             metadata_output = gr.Textbox(label="Image Metadata", interactive=False, lines=6)
+
+     # Style change functionality
+     style.change(fn=change_style, inputs=style, outputs=[guidance_scale, num_samples, width])
+
+     # Define the interaction between inputs and output
+     generate_button.click(
+         fn=gradio_interface,
+         inputs=[image_input, prompt_input, negative_prompt_input, style, base_model, seed_input, guidance_scale, num_inference_steps, num_samples, width, height],
+         outputs=[generated_image, metadata_output]
+     )
+
+     # Launch the Gradio interface
+     demo.launch()
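
The rewrite moves all diffusion work off the Space and onto fal.ai: app.py now only uploads the face photo, queues a job, and downloads the result. Below is a minimal sketch of that round trip outside Gradio, reusing only the fal_client calls that appear in the diff above; the "face.png" path, the prompt, and the "Photorealistic" model_type value are illustrative assumptions rather than confirmed endpoint values, and FAL_KEY is assumed to already be set in the environment.

import asyncio
import fal_client
import requests

async def main():
    # Upload a local photo, as upload_image_to_server does in app.py.
    image_url = await fal_client.upload_file_async("face.png")  # hypothetical local file

    # Queue the generation job; handler.get() blocks until it finishes,
    # mirroring generate_image above.
    handler = fal_client.submit(
        "fal-ai/ip-adapter-face-id",
        arguments={
            "prompt": "a photo of a person in a garden",  # illustrative prompt
            "face_image_url": image_url,
            "model_type": "Photorealistic",  # app.py forwards its Generation-type radio here; accepted values are an assumption
            "num_samples": 1,
        },
    )
    result = handler.get()

    # app.py reads result["image"]["url"] plus file_name/file_size/width/height
    # from the same dictionary to build its metadata string.
    with open("out.png", "wb") as f:
        f.write(requests.get(result["image"]["url"]).content)

asyncio.run(main())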
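One wrinkle worth noting: gradio_interface creates a fresh event loop on every click and never closes it. Assuming Gradio invokes the handler from a worker thread with no loop already running (its default for synchronous handlers), a sketch of an equivalent wrapper that lets asyncio manage the loop's lifecycle:

def gradio_interface(image, prompt, negative_prompt, model_type, base_model, seed,
                     guidance_scale, num_inference_steps, num_samples, width, height):
    # asyncio.run creates the loop, runs the coroutine, and closes the loop,
    # replacing the manual new_event_loop()/set_event_loop() pair.
    result_image, image_info = asyncio.run(process_inputs(
        image, prompt, negative_prompt, model_type, base_model, seed,
        guidance_scale, num_inference_steps, num_samples, width, height))
    if result_image:
        metadata = f"Seed: {image_info.get('seed', 'N/A')}"  # metadata trimmed for brevity
        return result_image, metadata
    return None, "Error generating image"

Note that process_inputs returns a bare None (not a pair) when the upload fails, so the tuple unpacking above, like the committed version, would raise a TypeError on that path; returning None, None there would keep both wrappers safe.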