BlockDetail commited on
Commit
259a646
1 Parent(s): 4f7d543
Files changed (1) hide show
  1. app.py +115 -105
app.py CHANGED
@@ -8,112 +8,15 @@ import spaces
8
 
9
  negative_prompt = ""
10
  device = torch.device('cuda')
11
- pipe = None
12
-
13
- @spaces.GPU
14
- def load():
15
- global pipe
16
- controlnet = ControlNetModel.from_pretrained("BlockDetail/PartialSketchControlNet", torch_dtype=torch.float16).to(device)
17
- pipe = CustomStableDiffusionControlNetPipeline.from_pretrained(
18
- "runwayml/stable-diffusion-v1-5",
19
- controlnet=controlnet, torch_dtype=torch.float16
20
- ).to(device)
21
- pipe.safety_checker = None
22
- pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
23
-
24
- @spaces.GPU
25
- def sketch(curr_sketch_image, dilation_mask, prompt, seed, num_steps, dilation):
26
- global curr_num_samples
27
- global pipe
28
- generator = torch.Generator(device="cuda:0")
29
- generator.manual_seed(seed)
30
- negative_prompt = ""
31
- guidance_scale = 7
32
- controlnet_conditioning_scale = 1.0
33
- images = pipe([prompt]*curr_num_samples, [curr_sketch_image.convert("RGB").point( lambda p: 256 if p > 128 else 0)]*curr_num_samples, guidance_scale=guidance_scale, controlnet_conditioning_scale = controlnet_conditioning_scale, negative_prompt = [negative_prompt] * curr_num_samples, num_inference_steps=num_steps, generator=generator, key_image=None, neg_mask=None).images
34
- # run blended renoising if blocking strokes are provided
35
- if dilation_mask is not None:
36
- new_images = pipe.collage([prompt] * curr_num_samples, images, [dilation_mask] * curr_num_samples, num_inference_steps=50, strength=0.8)["images"]
37
- else:
38
- new_images = images
39
- return images, new_images
40
- def run_sketching(prompt, curr_sketch, sketch_states, dilation, contour_dilation=11):
41
- seed = sketch_states[k][1]
42
- if seed is None:
43
- seed = np.random.randint(1000)
44
- sketch_states[k][1] = seed
45
-
46
- curr_sketch_image = Image.fromarray(curr_sketch["composite"])
47
- curr_sketch = np.array(curr_sketch_image.resize((512, 512), resample=0))
48
- curr_sketch[:, :, 0][curr_sketch[:, :, -1] == 0] = 255
49
- curr_sketch[:, :, 2][curr_sketch[:, :, -1] == 0] = 255
50
- curr_sketch[:, :, 1][curr_sketch[:, :, -1] == 0] = 255
51
- curr_sketch_image = Image.fromarray(curr_sketch[:, :, 0]).resize((512, 512))
52
- curr_construction_image = Image.fromarray(255 - curr_sketch[:, :, 1] + curr_sketch[:, :, 0])
53
- if np.sum(255 - np.array(curr_construction_image)) == 0:
54
- curr_construction_image = None
55
- curr_detail_image = Image.fromarray(curr_sketch[:, :, 1]).resize((512, 512))
56
- if curr_construction_image is not None:
57
- dilation_mask = Image.fromarray(255 - np.array(curr_construction_image)).filter(ImageFilter.MaxFilter(dilation))
58
- dilation_mask = dilation_mask.point( lambda p: 256 if p > 0 else 25).filter(ImageFilter.GaussianBlur(radius = 5))
59
- neg_dilation_mask = Image.fromarray(255 - np.array(curr_detail_image)).filter(ImageFilter.MaxFilter(contour_dilation))
60
- neg_dilation_mask = np.array(neg_dilation_mask.point( lambda p: 256 if p > 0 else 0))
61
- dilation_mask = np.array(dilation_mask)
62
- dilation_mask[neg_dilation_mask > 0] = 25
63
- dilation_mask = Image.fromarray(dilation_mask).filter(ImageFilter.GaussianBlur(radius = 5))
64
- else:
65
- dilation_mask = None
66
-
67
- images, new_images = sketch(curr_sketch_image, dilation_mask, prompt, seed, num_steps = 40, dilation = dilation)
68
- save_sketch = np.array(Image.fromarray(curr_sketch).convert("RGBA"))
69
- save_sketch[:, :, 3][save_sketch[:, :, 0] > 128] = 0
70
- overlays = []
71
- for i in images:
72
- background = i.copy()
73
- background.putalpha(80)
74
- background = Image.alpha_composite(Image.fromarray(255 * np.ones((512, 512)).astype(np.uint8)).convert("RGBA"), background)
75
- overlay = Image.alpha_composite(background.resize((512, 512)), Image.fromarray(save_sketch).resize((512, 512)).convert("RGBA"))
76
- overlays.append(overlay.convert("RGB"))
77
- new_overlays = []
78
- for i in new_images:
79
- background = i.copy()
80
- background.putalpha(80)
81
- background = Image.alpha_composite(Image.fromarray(255 * np.ones((512, 512)).astype(np.uint8)).convert("RGBA"), background)
82
- overlay = Image.alpha_composite(background.resize((512, 512)), Image.fromarray(save_sketch).resize((512, 512)).convert("RGBA"))
83
- new_overlays.append(overlay.convert("RGB"))
84
-
85
- global all_gens
86
- all_gens = new_images
87
- return new_images, new_overlays, images, overlays
88
- def reset(sketch_states):
89
- for k in range(len(sketch_states)):
90
- sketch_states[k] = [None, None]
91
- return None, sketch_states
92
-
93
- # def change_color(stroke_type):
94
- # if stroke_type == "Blocking":
95
- # color = "#00FF00"
96
- # else:
97
- # color = "#000000"
98
- # return gr.Sketchpad(sources = (), width=512, brush = gr.Brush(colors=[color], default_size = 2, color_mode="fixed"), height=512)
99
-
100
- def change_background(option):
101
- global all_gens
102
- if option == "None" or len(all_gens) == 0:
103
- return None
104
- elif option == "Sample 0":
105
- image_overlay = all_gens[0].copy()
106
- elif option == "Sample 1":
107
- image_overlay = all_gens[0].copy()
108
- else:
109
- return None
110
- image_overlay.putalpha(80)
111
- return image_overlay
112
- def change_num_samples(change):
113
- global curr_num_samples
114
- curr_num_samples = change
115
- return None
116
 
 
 
 
 
 
 
 
 
117
  threshold = 250
118
  curr_num_samples = 2
119
 
@@ -148,6 +51,113 @@ with gr.Blocks() as demo:
148
  sketch_states = gr.State(start_state)
149
  checkbox_state = gr.State(True)
150
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  btn.click(run_sketching, [prompt_box, canvas, sketch_states, dilation_strength[0]], [gallery0, gallery1, gallery2, gallery3])
152
  btn2.click(reset, sketch_states, [canvas, sketch_states])
153
  # stroke_type[0].change(change_color, [stroke_type[0]], canvas)
 
8
 
9
  negative_prompt = ""
10
  device = torch.device('cuda')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
+ global pipe
13
+ controlnet = ControlNetModel.from_pretrained("BlockDetail/PartialSketchControlNet", torch_dtype=torch.float16).to(device)
14
+ pipe = CustomStableDiffusionControlNetPipeline.from_pretrained(
15
+ "runwayml/stable-diffusion-v1-5",
16
+ controlnet=controlnet, torch_dtype=torch.float16
17
+ ).to(device)
18
+ pipe.safety_checker = None
19
+ pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(pipe.scheduler.config)
20
  threshold = 250
21
  curr_num_samples = 2
22
 
 
51
  sketch_states = gr.State(start_state)
52
  checkbox_state = gr.State(True)
53
 
54
+ @spaces.GPU
55
+ def sketch(curr_sketch_image, dilation_mask, prompt, seed, num_steps, dilation):
56
+ global curr_num_samples
57
+ global pipe
58
+ generator = torch.Generator(device="cuda:0")
59
+ generator.manual_seed(seed)
60
+
61
+ negative_prompt = ""
62
+ guidance_scale = 7
63
+ controlnet_conditioning_scale = 1.0
64
+ images = pipe([prompt]*curr_num_samples, [curr_sketch_image.convert("RGB").point( lambda p: 256 if p > 128 else 0)]*curr_num_samples, guidance_scale=guidance_scale, controlnet_conditioning_scale = controlnet_conditioning_scale, negative_prompt = [negative_prompt] * curr_num_samples, num_inference_steps=num_steps, generator=generator, key_image=None, neg_mask=None).images
65
+
66
+ # run blended renoising if blocking strokes are provided
67
+ if dilation_mask is not None:
68
+ new_images = pipe.collage([prompt] * curr_num_samples, images, [dilation_mask] * curr_num_samples, num_inference_steps=50, strength=0.8)["images"]
69
+ else:
70
+ new_images = images
71
+ return images, new_images
72
+
73
+ def run_sketching(prompt, curr_sketch, sketch_states, dilation, contour_dilation=11):
74
+ seed = sketch_states[k][1]
75
+ if seed is None:
76
+ seed = np.random.randint(1000)
77
+ sketch_states[k][1] = seed
78
+
79
+ curr_sketch_image = Image.fromarray(curr_sketch["composite"])
80
+ curr_sketch = np.array(curr_sketch_image.resize((512, 512), resample=0))
81
+ curr_sketch[:, :, 0][curr_sketch[:, :, -1] == 0] = 255
82
+ curr_sketch[:, :, 2][curr_sketch[:, :, -1] == 0] = 255
83
+ curr_sketch[:, :, 1][curr_sketch[:, :, -1] == 0] = 255
84
+
85
+ curr_sketch_image = Image.fromarray(curr_sketch[:, :, 0]).resize((512, 512))
86
+
87
+ curr_construction_image = Image.fromarray(255 - curr_sketch[:, :, 1] + curr_sketch[:, :, 0])
88
+ if np.sum(255 - np.array(curr_construction_image)) == 0:
89
+ curr_construction_image = None
90
+
91
+ curr_detail_image = Image.fromarray(curr_sketch[:, :, 1]).resize((512, 512))
92
+
93
+ if curr_construction_image is not None:
94
+ dilation_mask = Image.fromarray(255 - np.array(curr_construction_image)).filter(ImageFilter.MaxFilter(dilation))
95
+ dilation_mask = dilation_mask.point( lambda p: 256 if p > 0 else 25).filter(ImageFilter.GaussianBlur(radius = 5))
96
+
97
+ neg_dilation_mask = Image.fromarray(255 - np.array(curr_detail_image)).filter(ImageFilter.MaxFilter(contour_dilation))
98
+ neg_dilation_mask = np.array(neg_dilation_mask.point( lambda p: 256 if p > 0 else 0))
99
+ dilation_mask = np.array(dilation_mask)
100
+ dilation_mask[neg_dilation_mask > 0] = 25
101
+ dilation_mask = Image.fromarray(dilation_mask).filter(ImageFilter.GaussianBlur(radius = 5))
102
+ else:
103
+ dilation_mask = None
104
+
105
+ images, new_images = sketch(curr_sketch_image, dilation_mask, prompt, seed, num_steps = 40, dilation = dilation)
106
+
107
+ save_sketch = np.array(Image.fromarray(curr_sketch).convert("RGBA"))
108
+ save_sketch[:, :, 3][save_sketch[:, :, 0] > 128] = 0
109
+
110
+ overlays = []
111
+ for i in images:
112
+ background = i.copy()
113
+ background.putalpha(80)
114
+ background = Image.alpha_composite(Image.fromarray(255 * np.ones((512, 512)).astype(np.uint8)).convert("RGBA"), background)
115
+ overlay = Image.alpha_composite(background.resize((512, 512)), Image.fromarray(save_sketch).resize((512, 512)).convert("RGBA"))
116
+ overlays.append(overlay.convert("RGB"))
117
+
118
+ new_overlays = []
119
+ for i in new_images:
120
+ background = i.copy()
121
+ background.putalpha(80)
122
+ background = Image.alpha_composite(Image.fromarray(255 * np.ones((512, 512)).astype(np.uint8)).convert("RGBA"), background)
123
+ overlay = Image.alpha_composite(background.resize((512, 512)), Image.fromarray(save_sketch).resize((512, 512)).convert("RGBA"))
124
+ new_overlays.append(overlay.convert("RGB"))
125
+
126
+ global all_gens
127
+ all_gens = new_images
128
+
129
+ return new_images, new_overlays, images, overlays
130
+
131
+ def reset(sketch_states):
132
+ for k in range(len(sketch_states)):
133
+ sketch_states[k] = [None, None]
134
+ return None, sketch_states
135
+
136
+ # def change_color(stroke_type):
137
+ # if stroke_type == "Blocking":
138
+ # color = "#00FF00"
139
+ # else:
140
+ # color = "#000000"
141
+ # return gr.Sketchpad(sources = (), width=512, brush = gr.Brush(colors=[color], default_size = 2, color_mode="fixed"), height=512)
142
+
143
+ def change_background(option):
144
+ global all_gens
145
+ if option == "None" or len(all_gens) == 0:
146
+ return None
147
+ elif option == "Sample 0":
148
+ image_overlay = all_gens[0].copy()
149
+ elif option == "Sample 1":
150
+ image_overlay = all_gens[0].copy()
151
+ else:
152
+ return None
153
+ image_overlay.putalpha(80)
154
+ return image_overlay
155
+
156
+ def change_num_samples(change):
157
+ global curr_num_samples
158
+ curr_num_samples = change
159
+ return None
160
+
161
  btn.click(run_sketching, [prompt_box, canvas, sketch_states, dilation_strength[0]], [gallery0, gallery1, gallery2, gallery3])
162
  btn2.click(reset, sketch_states, [canvas, sketch_states])
163
  # stroke_type[0].change(change_color, [stroke_type[0]], canvas)