radames committed
Commit 266dbe0 · 1 Parent(s): 2d6c501

add pix2pix turbo

server/pipelines/pix2pix/__init__.py ADDED
File without changes
server/pipelines/pix2pix/model.py ADDED
@@ -0,0 +1,59 @@
+# https://github.com/GaParmar/img2img-turbo/blob/main/src/model.py
+from diffusers import DDPMScheduler
+
+
+def make_1step_sched():
+    noise_scheduler_1step = DDPMScheduler.from_pretrained(
+        "stabilityai/sd-turbo", subfolder="scheduler"
+    )
+    noise_scheduler_1step.set_timesteps(1, device="cuda")
+    noise_scheduler_1step.alphas_cumprod = noise_scheduler_1step.alphas_cumprod.cuda()
+    return noise_scheduler_1step
+
+
+def my_vae_encoder_fwd(self, sample):
+    sample = self.conv_in(sample)
+    l_blocks = []
+    # down
+    for down_block in self.down_blocks:
+        l_blocks.append(sample)
+        sample = down_block(sample)
+    # middle
+    sample = self.mid_block(sample)
+    sample = self.conv_norm_out(sample)
+    sample = self.conv_act(sample)
+    sample = self.conv_out(sample)
+    self.current_down_blocks = l_blocks
+    return sample
+
+
+def my_vae_decoder_fwd(self, sample, latent_embeds=None):
+    sample = self.conv_in(sample)
+    upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
+    # middle
+    sample = self.mid_block(sample, latent_embeds)
+    sample = sample.to(upscale_dtype)
+    if not self.ignore_skip:
+        skip_convs = [
+            self.skip_conv_1,
+            self.skip_conv_2,
+            self.skip_conv_3,
+            self.skip_conv_4,
+        ]
+        # up
+        for idx, up_block in enumerate(self.up_blocks):
+            skip_in = skip_convs[idx](self.incoming_skip_acts[::-1][idx] * self.gamma)
+            # add skip
+            sample = sample + skip_in
+            sample = up_block(sample, latent_embeds)
+    else:
+        for idx, up_block in enumerate(self.up_blocks):
+            sample = up_block(sample, latent_embeds)
+    # post-process
+    if latent_embeds is None:
+        sample = self.conv_norm_out(sample)
+    else:
+        sample = self.conv_norm_out(sample, latent_embeds)
+    sample = self.conv_act(sample)
+    sample = self.conv_out(sample)
+    return sample
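
The two my_vae_*_fwd functions above are not called directly; they are meant to be bound onto an existing AutoencoderKL instance as replacement forward methods. A minimal sketch of that binding, mirroring what pix2pix_turbo.py below does (loading "stabilityai/sd-turbo" is the only assumption):

# Sketch: bind the patched forwards onto a stock SD-Turbo VAE instance.
from diffusers import AutoencoderKL
from pipelines.pix2pix.model import my_vae_encoder_fwd, my_vae_decoder_fwd

vae = AutoencoderKL.from_pretrained("stabilityai/sd-turbo", subfolder="vae")
# __get__ turns each plain function into a method bound to this specific instance.
vae.encoder.forward = my_vae_encoder_fwd.__get__(vae.encoder, vae.encoder.__class__)
vae.decoder.forward = my_vae_decoder_fwd.__get__(vae.decoder, vae.decoder.__class__)
vae.decoder.ignore_skip = True  # keep skips off until the skip_conv_* layers are attached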
server/pipelines/pix2pix/pix2pix_turbo.py ADDED
@@ -0,0 +1,212 @@
+# https://github.com/GaParmar/img2img-turbo/blob/main/src/pix2pix_turbo.py
+import os
+import requests
+import sys
+import pdb
+import copy
+from tqdm import tqdm
+import torch
+from transformers import AutoTokenizer, PretrainedConfig, CLIPTextModel
+from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler
+from diffusers.utils.peft_utils import set_weights_and_activate_adapters
+from peft import LoraConfig
+
+from pipelines.pix2pix.model import (
+    make_1step_sched,
+    my_vae_encoder_fwd,
+    my_vae_decoder_fwd,
+)
+
+
+class TwinConv(torch.nn.Module):
+    def __init__(self, convin_pretrained, convin_curr):
+        super(TwinConv, self).__init__()
+        self.conv_in_pretrained = copy.deepcopy(convin_pretrained)
+        self.conv_in_curr = copy.deepcopy(convin_curr)
+        self.r = None
+
+    def forward(self, x):
+        x1 = self.conv_in_pretrained(x).detach()
+        x2 = self.conv_in_curr(x)
+        return x1 * (1 - self.r) + x2 * (self.r)
+
+
+class Pix2Pix_Turbo(torch.nn.Module):
+    def __init__(self, name, ckpt_folder="checkpoints"):
+        super().__init__()
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            "stabilityai/sd-turbo", subfolder="tokenizer"
+        )
+        self.text_encoder = CLIPTextModel.from_pretrained(
+            "stabilityai/sd-turbo", subfolder="text_encoder"
+        ).cuda()
+        self.sched = make_1step_sched()
+
+        vae = AutoencoderKL.from_pretrained("stabilityai/sd-turbo", subfolder="vae")
+        unet = UNet2DConditionModel.from_pretrained(
+            "stabilityai/sd-turbo", subfolder="unet"
+        )
+
+        if name == "edge_to_image":
+            url = "https://www.cs.cmu.edu/~img2img-turbo/models/edge_to_image_loras.pkl"
+            os.makedirs(ckpt_folder, exist_ok=True)
+            outf = os.path.join(ckpt_folder, "edge_to_image_loras.pkl")
+            if not os.path.exists(outf):
+                print(f"Downloading checkpoint to {outf}")
+                response = requests.get(url, stream=True)
+                total_size_in_bytes = int(response.headers.get("content-length", 0))
+                block_size = 1024  # 1 Kibibyte
+                progress_bar = tqdm(
+                    total=total_size_in_bytes, unit="iB", unit_scale=True
+                )
+                with open(outf, "wb") as file:
+                    for data in response.iter_content(block_size):
+                        progress_bar.update(len(data))
+                        file.write(data)
+                progress_bar.close()
+                if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
+                    print("ERROR, something went wrong")
+                print(f"Downloaded successfully to {outf}")
+            p_ckpt = outf
+            sd = torch.load(p_ckpt, map_location="cpu")
+            unet_lora_config = LoraConfig(
+                r=sd["rank_unet"],
+                init_lora_weights="gaussian",
+                target_modules=sd["unet_lora_target_modules"],
+            )
+
+        if name == "sketch_to_image_stochastic":
+            # download from url
+            url = "https://www.cs.cmu.edu/~img2img-turbo/models/sketch_to_image_stochastic_lora.pkl"
+            os.makedirs(ckpt_folder, exist_ok=True)
+            outf = os.path.join(ckpt_folder, "sketch_to_image_stochastic_lora.pkl")
+            if not os.path.exists(outf):
+                print(f"Downloading checkpoint to {outf}")
+                response = requests.get(url, stream=True)
+                total_size_in_bytes = int(response.headers.get("content-length", 0))
+                block_size = 1024  # 1 Kibibyte
+                progress_bar = tqdm(
+                    total=total_size_in_bytes, unit="iB", unit_scale=True
+                )
+                with open(outf, "wb") as file:
+                    for data in response.iter_content(block_size):
+                        progress_bar.update(len(data))
+                        file.write(data)
+                progress_bar.close()
+                if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
+                    print("ERROR, something went wrong")
+                print(f"Downloaded successfully to {outf}")
+            p_ckpt = outf
+            sd = torch.load(p_ckpt, map_location="cpu")
+            unet_lora_config = LoraConfig(
+                r=sd["rank_unet"],
+                init_lora_weights="gaussian",
+                target_modules=sd["unet_lora_target_modules"],
+            )
+            convin_pretrained = copy.deepcopy(unet.conv_in)
+            unet.conv_in = TwinConv(convin_pretrained, unet.conv_in)
+
+        vae.encoder.forward = my_vae_encoder_fwd.__get__(
+            vae.encoder, vae.encoder.__class__
+        )
+        vae.decoder.forward = my_vae_decoder_fwd.__get__(
+            vae.decoder, vae.decoder.__class__
+        )
+        # add the skip connection convs
+        vae.decoder.skip_conv_1 = torch.nn.Conv2d(
+            512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False
+        ).cuda()
+        vae.decoder.skip_conv_2 = torch.nn.Conv2d(
+            256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False
+        ).cuda()
+        vae.decoder.skip_conv_3 = torch.nn.Conv2d(
+            128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False
+        ).cuda()
+        vae.decoder.skip_conv_4 = torch.nn.Conv2d(
+            128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False
+        ).cuda()
+        vae_lora_config = LoraConfig(
+            r=sd["rank_vae"],
+            init_lora_weights="gaussian",
+            target_modules=sd["vae_lora_target_modules"],
+        )
+        vae.decoder.ignore_skip = False
+        vae.add_adapter(vae_lora_config, adapter_name="vae_skip")
+        unet.add_adapter(unet_lora_config)
+        _sd_unet = unet.state_dict()
+        for k in sd["state_dict_unet"]:
+            _sd_unet[k] = sd["state_dict_unet"][k]
+        unet.load_state_dict(_sd_unet)
+        unet.enable_xformers_memory_efficient_attention()
+        _sd_vae = vae.state_dict()
+        for k in sd["state_dict_vae"]:
+            _sd_vae[k] = sd["state_dict_vae"][k]
+        vae.load_state_dict(_sd_vae)
+        unet.to("cuda")
+        vae.to("cuda")
+        unet.eval()
+        vae.eval()
+        self.unet, self.vae = unet, vae
+        self.vae.decoder.gamma = 1
+        self.timesteps = torch.tensor([999], device="cuda").long()
+        self.last_prompt = ""
+        self.caption_enc = None
+        self.device = "cuda"
+
+    def forward(self, c_t, prompt, deterministic=True, r=1.0, noise_map=1.0):
+        # encode the text prompt
+        if prompt != self.last_prompt:
+            caption_tokens = self.tokenizer(
+                prompt,
+                max_length=self.tokenizer.model_max_length,
+                padding="max_length",
+                truncation=True,
+                return_tensors="pt",
+            ).input_ids.cuda()
+            caption_enc = self.text_encoder(caption_tokens)[0]
+            self.caption_enc = caption_enc
+            self.last_prompt = prompt
+
+        if deterministic:
+            encoded_control = (
+                self.vae.encode(c_t).latent_dist.sample()
+                * self.vae.config.scaling_factor
+            )
+            model_pred = self.unet(
+                encoded_control,
+                self.timesteps,
+                encoder_hidden_states=self.caption_enc,
+            ).sample
+            x_denoised = self.sched.step(
+                model_pred, self.timesteps, encoded_control, return_dict=True
+            ).prev_sample
+            self.vae.decoder.incoming_skip_acts = self.vae.encoder.current_down_blocks
+            output_image = (
+                self.vae.decode(x_denoised / self.vae.config.scaling_factor).sample
+            ).clamp(-1, 1)
+        else:
+            # scale the lora weights based on the r value
+            self.unet.set_adapters(["default"], weights=[r])
+            set_weights_and_activate_adapters(self.vae, ["vae_skip"], [r])
+            encoded_control = (
+                self.vae.encode(c_t).latent_dist.sample()
+                * self.vae.config.scaling_factor
+            )
+            # combine the input and noise
+            unet_input = encoded_control * r + noise_map * (1 - r)
+            self.unet.conv_in.r = r
+            unet_output = self.unet(
+                unet_input,
+                self.timesteps,
+                encoder_hidden_states=self.caption_enc,
+            ).sample
+            self.unet.conv_in.r = None
+            x_denoised = self.sched.step(
+                unet_output, self.timesteps, unet_input, return_dict=True
+            ).prev_sample
+            self.vae.decoder.incoming_skip_acts = self.vae.encoder.current_down_blocks
+            self.vae.decoder.gamma = r
+            output_image = (
+                self.vae.decode(x_denoised / self.vae.config.scaling_factor).sample
+            ).clamp(-1, 1)
+        return output_image
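
For reference, a minimal sketch of how the Pix2Pix_Turbo module added above is driven; the tensor shape and prompt are illustrative assumptions, and the actual wiring lives in server/pipelines/pix2pixTurbo.py below:

# Sketch: one-step edge-to-image inference (assumes a CUDA device; the LoRA
# checkpoint is downloaded to ./checkpoints on first construction).
import torch
from pipelines.pix2pix.pix2pix_turbo import Pix2Pix_Turbo

model = Pix2Pix_Turbo("edge_to_image")
edges = torch.rand(1, 3, 512, 512).cuda()  # placeholder for a 3-channel edge map in [0, 1]
with torch.no_grad():
    image = model(edges, "close-up photo of the joker", deterministic=True, r=1.0)
# `image` is a (1, 3, 512, 512) tensor clamped to [-1, 1]; rescale with * 0.5 + 0.5 to view it.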
server/pipelines/pix2pixTurbo.py ADDED
@@ -0,0 +1,127 @@
+import torch
+from torchvision import transforms
+
+from config import Args
+from pydantic import BaseModel, Field
+from PIL import Image
+from pipelines.pix2pix.pix2pix_turbo import Pix2Pix_Turbo
+from pipelines.utils.canny_gpu import SobelOperator
+
+
+default_prompt = "close-up photo of the joker"
+page_content = """
+<h1 class="text-3xl font-bold">Real-Time pix2pix_turbo</h1>
+<h3 class="text-xl font-bold">pix2pix turbo</h3>
+<p class="text-sm">
+    This demo showcases
+    <a
+    href="https://github.com/GaParmar/img2img-turbo"
+    target="_blank"
+    class="text-blue-500 underline hover:no-underline">One-Step Image Translation with Text-to-Image Models
+    </a>
+</p>
+"""
+
+
+class Pipeline:
+    class Info(BaseModel):
+        name: str = "img2img"
+        title: str = "Image-to-Image SDXL"
+        description: str = "Generates an image from a text prompt"
+        input_mode: str = "image"
+        page_content: str = page_content
+
+    class InputParams(BaseModel):
+        prompt: str = Field(
+            default_prompt,
+            title="Prompt",
+            field="textarea",
+            id="prompt",
+        )
+
+        width: int = Field(
+            512, min=2, max=15, title="Width", disabled=True, hide=True, id="width"
+        )
+        height: int = Field(
+            512, min=2, max=15, title="Height", disabled=True, hide=True, id="height"
+        )
+        strength: float = Field(
+            1.0,
+            min=0.01,
+            max=10.0,
+            step=0.001,
+            title="Strength",
+            field="range",
+            hide=True,
+            id="strength",
+        )
+        deterministic: bool = Field(
+            True,
+            hide=True,
+            title="Deterministic",
+            field="checkbox",
+            id="deterministic",
+        )
+        canny_low_threshold: float = Field(
+            0.31,
+            min=0,
+            max=1.0,
+            step=0.001,
+            title="Canny Low Threshold",
+            field="range",
+            hide=True,
+            id="canny_low_threshold",
+        )
+        canny_high_threshold: float = Field(
+            0.125,
+            min=0,
+            max=1.0,
+            step=0.001,
+            title="Canny High Threshold",
+            field="range",
+            hide=True,
+            id="canny_high_threshold",
+        )
+        debug_canny: bool = Field(
+            False,
+            title="Debug Canny",
+            field="checkbox",
+            hide=True,
+            id="debug_canny",
+        )
+
+    def __init__(self, args: Args, device: torch.device, torch_dtype: torch.dtype):
+        self.model = Pix2Pix_Turbo("edge_to_image")
+        self.canny_torch = SobelOperator(device=device)
+        self.device = device
+
+    def predict(self, params: "Pipeline.InputParams") -> Image.Image:
+        # generator = torch.manual_seed(params.seed)
+        # pipe = self.pipes[params.base_model_id]
+
+        canny_pil, canny_tensor = self.canny_torch(
+            params.image,
+            params.canny_low_threshold,
+            params.canny_high_threshold,
+            output_type="pil,tensor",
+        )
+
+        with torch.no_grad():
+            canny_tensor = torch.cat((canny_tensor, canny_tensor, canny_tensor), dim=1)
+            output_image = self.model(
+                canny_tensor,
+                params.prompt,
+                params.deterministic,
+                params.strength,
+            )
+            output_pil = transforms.ToPILImage()(output_image[0].cpu() * 0.5 + 0.5)
+
+        result_image = output_pil
+        if params.debug_canny:
+            # paste control_image on top of result_image
+            w0, h0 = (200, 200)
+            control_image = canny_pil.resize((w0, h0))
+            w1, h1 = result_image.size
+            result_image.paste(control_image, (w1 - w0, h1 - h0))
+
+        return result_image
server/pipelines/utils/canny_gpu.py CHANGED
@@ -3,6 +3,7 @@ import torch.nn as nn
 from torchvision.transforms import ToTensor, ToPILImage
 from PIL import Image
 
+
 class SobelOperator(nn.Module):
     def __init__(self, device="cuda"):
         super(SobelOperator, self).__init__()
@@ -25,7 +26,13 @@ class SobelOperator(nn.Module):
         self.edge_conv_y.weight = nn.Parameter(sobel_kernel_y.view((1, 1, 3, 3)))
 
     @torch.no_grad()
-    def forward(self, image: Image.Image, low_threshold: float, high_threshold: float):
+    def forward(
+        self,
+        image: Image.Image,
+        low_threshold: float,
+        high_threshold: float,
+        output_type="pil",
+    ) -> Image.Image | torch.Tensor | tuple[Image.Image, torch.Tensor]:
         # Convert PIL image to PyTorch tensor
         image_gray = image.convert("L")
         image_tensor = ToTensor()(image_gray).unsqueeze(0).to(self.device)
@@ -41,4 +48,9 @@ class SobelOperator(nn.Module):
         edge[edge <= low_threshold] = 0.0
 
         # Convert the result back to a PIL image
-        return ToPILImage()(edge.squeeze(0).cpu())
+        if output_type == "pil":
+            return ToPILImage()(edge.squeeze(0).cpu())
+        elif output_type == "tensor":
+            return edge
+        elif output_type == "pil,tensor":
+            return ToPILImage()(edge.squeeze(0).cpu()), edge
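
The SobelOperator change above only adds an output_type switch to forward; a short usage sketch, assuming an arbitrary RGB input image (the pipeline above calls it with output_type="pil,tensor"):

# Sketch: the three output_type modes of the updated SobelOperator.forward.
from PIL import Image
from pipelines.utils.canny_gpu import SobelOperator

sobel = SobelOperator(device="cuda")
img = Image.open("input.png").convert("RGB")  # hypothetical input file
edge_pil = sobel(img, 0.31, 0.125)  # default output_type="pil" -> PIL image
edge_tensor = sobel(img, 0.31, 0.125, output_type="tensor")  # (1, 1, H, W) tensor on the device
edge_pil2, edge_tensor2 = sobel(img, 0.31, 0.125, output_type="pil,tensor")  # both at once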
server/requirements.txt CHANGED
@@ -1,4 +1,4 @@
-diffusers==0.27.1
+diffusers==0.25.1
 transformers==4.36.2
 --extra-index-url https://download.pytorch.org/whl/cu121;
 torch==2.2.0
@@ -8,10 +8,11 @@ Pillow==10.2.0
 accelerate==0.25.0
 compel==2.0.2
 controlnet-aux==0.0.7
-peft==0.6.0
+peft==0.9.0
 xformers; sys_platform != 'darwin' or platform_machine != 'arm64'
 markdown2
 safetensors
 stable_fast @ https://github.com/chengzeyi/stable-fast/releases/download/v1.0.4/stable_fast-1.0.4+torch220cu121-cp310-cp310-manylinux2014_x86_64.whl ; sys_platform != 'darwin' or platform_machine != 'arm64'
 oneflow @ https://github.com/siliconflow/oneflow_releases/releases/download/community_cu121/oneflow-0.9.1.dev20240316+cu121-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl ; sys_platform != 'darwin' or platform_machine != 'arm64'
-onediff @ git+https://github.com/siliconflow/onediff.git@main#egg=onediff ; sys_platform != 'darwin' or platform_machine != 'arm64'
+onediff @ git+https://github.com/siliconflow/onediff.git@main#egg=onediff ; sys_platform != 'darwin' or platform_machine != 'arm64'
+setuptools