File size: 4,298 Bytes
61e8157
 
1b829a0
 
eb48411
 
 
1b829a0
61e8157
 
 
 
0d95b06
1b829a0
eede3bc
 
 
 
 
 
 
1b829a0
 
 
 
 
 
 
61e8157
 
 
 
 
 
 
 
 
456a8a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b7751ed
456a8a0
ea6de43
 
61e8157
456a8a0
01da3bb
b7751ed
ea6de43
456a8a0
 
 
 
 
 
 
61e8157
 
 
 
456a8a0
ea6de43
61e8157
 
 
ea6de43
61e8157
 
 
 
 
ea6de43
 
61e8157
 
01da3bb
 
 
 
ea6de43
61e8157
ea6de43
 
 
 
 
 
 
 
61e8157
 
 
ea6de43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61e8157
 
 
ea6de43
 
61e8157
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import gradio as gr
import os
import sys
import subprocess
import numpy as np
from PIL import Image
import cv2

import torch
from diffusers import StableDiffusion3Pipeline
from diffusers.models.controlnet_sd3 import ControlNetSD3Model
from diffusers.utils.torch_utils import randn_tensor
from diffusers.utils import load_image

# Clone the InstantX fork (branch `sd3_control`) that provides the SD3 ControlNet
# community pipeline. `git clone` refuses to write into an existing non-empty
# directory, so skip it when an earlier start of the app already cloned it.
repo_dir = 'diffusers_sd3_control'
if not os.path.isdir(repo_dir):
    subprocess.run(
        ["git", "clone", "-b", "sd3_control", "https://github.com/instantX-research/diffusers_sd3_control.git"],
        check=True,
    )

# Change directory to the cloned repository and install it in editable mode.
# `sys.executable -m pip` installs into the interpreter running this app rather
# than whichever `pip` happens to be first on PATH; `check=True` surfaces a
# failed install here instead of as a confusing import error later.
# NOTE(review): diffusers is already imported at the top of this file, so this
# process keeps the diffusers modules it loaded before the install; the editable
# install only affects interpreters started afterwards — confirm that is intended.
os.chdir(repo_dir)
subprocess.run([sys.executable, "-m", "pip", "install", "-e", "."], check=True)

# Make the fork's community pipelines (examples/community) importable.
sys.path.append(os.path.abspath('./examples/community'))

# Import the required pipeline
from pipeline_stable_diffusion_3_controlnet import StableDiffusion3CommonPipeline



# Build the SD3-medium pipeline with the InstantX Canny ControlNet attached
# (controlnet index 0), then move it to the first CUDA device in float16.
base_model = 'stabilityai/stable-diffusion-3-medium-diffusers'
pipe = StableDiffusion3CommonPipeline.from_pretrained(
    base_model, controlnet_list=['InstantX/SD3-Controlnet-Canny']
)
pipe.to('cuda:0', torch.float16)

def resize_image(input_path, output_path, target_height):
    """Resize the image at *input_path* to *target_height*, keeping its aspect ratio.

    The new width is ``int(target_height * width / height)`` (truncated, but
    never less than 1 pixel). The resized image is written to *output_path*;
    its extension selects the file format. When that format is JPEG, images in
    modes JPEG cannot store (e.g. RGBA, LA, P from PNG inputs) are converted to
    RGB first, which discards any alpha channel.

    Args:
        input_path: Path of the source image.
        output_path: Where the resized image is saved.
        target_height: Height in pixels of the output image.

    Returns:
        ``(output_path, new_width, target_height)``.
    """
    # The context manager closes the underlying file once the pixels are read.
    with Image.open(input_path) as img:
        original_width, original_height = img.size
        original_aspect_ratio = original_width / original_height
        # Guard against a 0-px width for extremely tall images (resize rejects it).
        new_width = max(1, int(target_height * original_aspect_ratio))
        resized = img.resize((new_width, target_height), Image.LANCZOS)

    # Pillow picks the save format from the (lower-cased) extension; JPEG only
    # accepts the modes below, so flatten anything else (alpha/palette) to RGB.
    extension = os.path.splitext(output_path)[1].lower()
    jpeg_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")
    if Image.registered_extensions().get(extension) == "JPEG" and resized.mode not in jpeg_modes:
        resized = resized.convert("RGB")

    resized.save(output_path)

    return output_path, new_width, target_height

def _canny_edges(image, low_threshold=100, high_threshold=200):
    """Return the Canny edge map of *image* as a 3-channel PIL image.

    ``cv2.Canny`` yields a single-channel uint8 map; it is replicated into three
    identical channels so it can be used wherever an RGB image is expected.
    """
    edges = cv2.Canny(np.array(image), low_threshold, high_threshold)
    return Image.fromarray(np.repeat(edges[:, :, None], 3, axis=2))


def infer(image_in, prompt, inference_steps, guidance_scale, control_weight):
    """Generate an image from *prompt*, guided by the Canny edges of *image_in*.

    Args:
        image_in: File path of the uploaded reference image (Gradio
            ``type="filepath"``), or ``None`` when nothing was uploaded.
        prompt: Text prompt for the SD3 pipeline.
        inference_steps: Number of denoising steps (coerced to ``int``).
        guidance_scale: Classifier-free guidance scale.
        control_weight: Weight of the Canny ControlNet conditioning.

    Returns:
        ``(generated_image, canny_image)``. The generated image is resized to a
        1024 px height with the reference image's aspect ratio.

    Raises:
        gr.Error: If no reference image was provided.
    """
    if image_in is None:
        raise gr.Error("Please upload a reference image.")

    n_prompt = 'NSFW, nude, naked, porn, ugly'

    # Canny preprocessing on the loaded reference image.
    source = load_image(image_in)
    image_to_canny = _canny_edges(source)

    # controlnet config: a single Canny ControlNet at index 0 of controlnet_list.
    controlnet_conditioning = [
        dict(
            control_index=0,
            control_image=image_to_canny,
            control_weight=control_weight,
            control_pooled_projections='zeros'
        )
    ]

    image = pipe(
        prompt=prompt,
        negative_prompt=n_prompt,
        controlnet_conditioning=controlnet_conditioning,
        # Slider values may arrive as floats; step counts must be integers.
        num_inference_steps=int(inference_steps),
        guidance_scale=guidance_scale,
    ).images[0]

    # Match the reference aspect ratio at a fixed 1024 px height. The ratio is
    # taken from the same image the edges were computed from, so no scratch
    # file has to be written and re-read on every request.
    target_height = 1024
    source_width, source_height = source.size
    target_width = max(1, int(target_height * (source_width / source_height)))
    image = image.resize((target_width, target_height), Image.LANCZOS)

    return image, image_to_canny

# Center the app's main column and cap its width.
css = """
#col-container{
    margin: 0 auto;
    max-width: 1080px;
}
"""
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("""
        # SD3 ControlNet
        """)
        with gr.Row():
            # Left column: reference image, prompt and generation settings.
            with gr.Column():
                image_in = gr.Image(label="Image reference", sources=["upload"], type="filepath")
                prompt = gr.Textbox(label="Prompt")
                with gr.Accordion("Advanced settings", open=False):
                    with gr.Column():
                        with gr.Row():
                            inference_steps = gr.Slider(
                                label="Inference steps", minimum=1, maximum=50, step=1, value=25
                            )
                            guidance_scale = gr.Slider(
                                label="Guidance scale", minimum=1.0, maximum=10.0, step=0.1, value=7.0
                            )
                        control_weight = gr.Slider(
                            label="Control Weight", minimum=0.0, maximum=1.0, step=0.01, value=0.7
                        )
                submit_btn = gr.Button("Submit")
            # Right column: generated image and the edge map that guided it.
            with gr.Column():
                result = gr.Image(label="Result")
                canny_used = gr.Image(label="Preprocessed Canny")

    # Wire the button to `infer`; input order matches infer's parameters.
    submit_btn.click(
        fn=infer,
        inputs=[image_in, prompt, inference_steps, guidance_scale, control_weight],
        outputs=[result, canny_used],
        show_api=False,
    )

# Enable request queueing, then start the server.
demo.queue()
demo.launch()