Spaces:
Running
on
A10G
Running
on
A10G
Andranik Sargsyan
commited on
Commit
β’
da1e12f
1
Parent(s):
bfd34e9
enable fp16, move SR to cuda:1
Browse files- app.py +2 -2
- lib/methods/rasg.py +9 -6
- lib/methods/sd.py +8 -6
- lib/methods/sr.py +11 -9
- lib/models/ds_inp.py +7 -2
- lib/models/sam.py +3 -4
- lib/models/sd15_inp.py +6 -1
- lib/models/sd2_inp.py +8 -1
- lib/models/sd2_sr.py +11 -8
- lib/smplfusion/ddim.py +5 -3
- lib/smplfusion/models/unet.py +8 -1
- lib/utils/iimage.py +4 -0
app.py
CHANGED
@@ -64,8 +64,8 @@ inpainting_models = OrderedDict([
|
|
64 |
("Stable-Inpainting 2.0", models.sd2_inp.load_model()),
|
65 |
("Stable-Inpainting 1.5", models.sd15_inp.load_model())
|
66 |
])
|
67 |
-
sr_model = models.sd2_sr.load_model()
|
68 |
-
sam_predictor = models.sam.load_model()
|
69 |
|
70 |
inp_model = None
|
71 |
cached_inp_model_name = ''
|
|
|
64 |
("Stable-Inpainting 2.0", models.sd2_inp.load_model()),
|
65 |
("Stable-Inpainting 1.5", models.sd15_inp.load_model())
|
66 |
])
|
67 |
+
sr_model = models.sd2_sr.load_model(device='cuda:1')
|
68 |
+
sam_predictor = models.sam.load_model(device='cuda:0')
|
69 |
|
70 |
inp_model = None
|
71 |
cached_inp_model_name = ''
|
lib/methods/rasg.py
CHANGED
@@ -38,9 +38,11 @@ def run(ddim, method, prompt, image, mask, seed, eta, prefix, negative_prompt, p
|
|
38 |
unet_condition = ddim.get_inpainting_condition(image, mask)
|
39 |
share.set_mask(mask)
|
40 |
|
|
|
|
|
41 |
# Starting latent
|
42 |
seed_everything(seed)
|
43 |
-
zt = torch.randn((1,4) + unet_condition.shape[2:]).cuda()
|
44 |
|
45 |
# Setup unet for guidance
|
46 |
ddim.unet.requires_grad_(True)
|
@@ -58,11 +60,12 @@ def run(ddim, method, prompt, image, mask, seed, eta, prefix, negative_prompt, p
|
|
58 |
|
59 |
# Run the model
|
60 |
_zt = zt if unet_condition is None else torch.cat([zt, unet_condition], 1)
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
|
|
66 |
|
67 |
# Unconditional guidance
|
68 |
eps = (eps_uncond + guidance_scale * (eps - eps_uncond))
|
|
|
38 |
unet_condition = ddim.get_inpainting_condition(image, mask)
|
39 |
share.set_mask(mask)
|
40 |
|
41 |
+
dtype = unet_condition.dtype
|
42 |
+
|
43 |
# Starting latent
|
44 |
seed_everything(seed)
|
45 |
+
zt = torch.randn((1,4) + unet_condition.shape[2:]).cuda().to(dtype)
|
46 |
|
47 |
# Setup unet for guidance
|
48 |
ddim.unet.requires_grad_(True)
|
|
|
60 |
|
61 |
# Run the model
|
62 |
_zt = zt if unet_condition is None else torch.cat([zt, unet_condition], 1)
|
63 |
+
with torch.autocast('cuda'):
|
64 |
+
eps_uncond, eps = ddim.unet(
|
65 |
+
torch.cat([_zt, _zt]).to(dtype),
|
66 |
+
timesteps = torch.tensor([timestep, timestep]).cuda(),
|
67 |
+
context = context
|
68 |
+
).detach().chunk(2)
|
69 |
|
70 |
# Unconditional guidance
|
71 |
eps = (eps_uncond + guidance_scale * (eps - eps_uncond))
|
lib/methods/sd.py
CHANGED
@@ -43,11 +43,12 @@ def run(
|
|
43 |
|
44 |
# Image condition
|
45 |
unet_condition = ddim.get_inpainting_condition(image, mask)
|
|
|
46 |
share.set_mask(mask)
|
47 |
|
48 |
# Starting latent
|
49 |
seed_everything(seed)
|
50 |
-
zt = torch.randn((1,4) + unet_condition.shape[2:]).cuda()
|
51 |
|
52 |
# Turn off gradients
|
53 |
ddim.unet.requires_grad_(False)
|
@@ -58,11 +59,12 @@ def run(
|
|
58 |
if share.timestep <= 500: router.reset()
|
59 |
|
60 |
_zt = zt if unet_condition is None else torch.cat([zt, unet_condition], 1)
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
|
|
66 |
|
67 |
eps = (eps_uncond + guidance_scale * (eps - eps_uncond))
|
68 |
z0 = (zt - share.schedule.sqrt_one_minus_alphas[timestep] * eps) / share.schedule.sqrt_alphas[timestep]
|
|
|
43 |
|
44 |
# Image condition
|
45 |
unet_condition = ddim.get_inpainting_condition(image, mask)
|
46 |
+
dtype = unet_condition.dtype
|
47 |
share.set_mask(mask)
|
48 |
|
49 |
# Starting latent
|
50 |
seed_everything(seed)
|
51 |
+
zt = torch.randn((1,4) + unet_condition.shape[2:]).cuda().to(dtype)
|
52 |
|
53 |
# Turn off gradients
|
54 |
ddim.unet.requires_grad_(False)
|
|
|
59 |
if share.timestep <= 500: router.reset()
|
60 |
|
61 |
_zt = zt if unet_condition is None else torch.cat([zt, unet_condition], 1)
|
62 |
+
with torch.autocast('cuda'):
|
63 |
+
eps_uncond, eps = ddim.unet(
|
64 |
+
torch.cat([_zt, _zt]).to(dtype),
|
65 |
+
timesteps = torch.tensor([timestep, timestep]).cuda(),
|
66 |
+
context = context
|
67 |
+
).chunk(2)
|
68 |
|
69 |
eps = (eps_uncond + guidance_scale * (eps - eps_uncond))
|
70 |
z0 = (zt - share.schedule.sqrt_one_minus_alphas[timestep] * eps) / share.schedule.sqrt_alphas[timestep]
|
lib/methods/sr.py
CHANGED
@@ -59,8 +59,10 @@ def refine_mask(hr_image, hr_mask, lr_image, sam_predictor):
|
|
59 |
|
60 |
def run(ddim, sam_predictor, lr_image, hr_image, hr_mask, prompt = 'high resolution professional photo', noise_level=20,
|
61 |
blend_output = True, blend_trick = True, no_superres = False,
|
62 |
-
dt = 20, seed = 1, guidance_scale = 7.5, negative_prompt = '', use_sam_mask = False
|
63 |
torch.manual_seed(seed)
|
|
|
|
|
64 |
|
65 |
router.attention_forward = attentionpatch.default.forward_xformers
|
66 |
router.basic_transformer_forward = transformerpatch.default.forward
|
@@ -74,7 +76,7 @@ dt = 20, seed = 1, guidance_scale = 7.5, negative_prompt = '', use_sam_mask = Fa
|
|
74 |
hr_mask = hr_mask.padx(256, padding_mode='reflect').dilate(19)
|
75 |
hr_mask_orig = hr_mask
|
76 |
lr_image = lr_image.padx(64, padding_mode='reflect')
|
77 |
-
lr_mask = hr_mask.resize((lr_image.torch().shape[2], lr_image.torch().shape[3]), resample = Image.BICUBIC).alpha().torch(vmin=0).
|
78 |
lr_mask = TvF.gaussian_blur(lr_mask, kernel_size=19)
|
79 |
|
80 |
if no_superres:
|
@@ -89,18 +91,18 @@ dt = 20, seed = 1, guidance_scale = 7.5, negative_prompt = '', use_sam_mask = Fa
|
|
89 |
|
90 |
# encode hr image
|
91 |
with torch.no_grad():
|
92 |
-
hr_z0 = ddim.vae.encode(hr_image.torch().cuda().to(dtype)).mean * ddim.config.scale_factor
|
93 |
|
94 |
assert hr_z0.shape[2] == lr_image.torch().shape[2]
|
95 |
assert hr_z0.shape[3] == lr_image.torch().shape[3]
|
96 |
|
97 |
-
unet_condition = lr_image.cuda().torch().to(memory_format=torch.contiguous_format).to(dtype)
|
98 |
-
zT = torch.randn((1,4,unet_condition.shape[2], unet_condition.shape[3])).cuda().to(dtype)
|
99 |
|
100 |
with torch.no_grad():
|
101 |
context = ddim.encoder.encode([negative_prompt, prompt])
|
102 |
|
103 |
-
noise_level = torch.Tensor(1 * [noise_level]).to(
|
104 |
unet_condition, noise_level = ddim.low_scale_model(unet_condition, noise_level=noise_level)
|
105 |
|
106 |
with torch.autocast('cuda'), torch.no_grad():
|
@@ -110,13 +112,13 @@ dt = 20, seed = 1, guidance_scale = 7.5, negative_prompt = '', use_sam_mask = Fa
|
|
110 |
_zt = zt if unet_condition is None else torch.cat([zt, unet_condition], 1)
|
111 |
|
112 |
eps_uncond, eps = ddim.unet(
|
113 |
-
torch.cat([_zt, _zt]).to(dtype),
|
114 |
-
timesteps = torch.tensor([t, t]).
|
115 |
context = context,
|
116 |
y=torch.cat([noise_level]*2)
|
117 |
).chunk(2)
|
118 |
|
119 |
-
ts = torch.full((zt.shape[0],), t, device=
|
120 |
model_output = (eps_uncond + guidance_scale * (eps - eps_uncond))
|
121 |
eps = predict_eps_from_z_and_v(ddim.schedule, zt, ts, model_output).to(dtype)
|
122 |
z0 = predict_start_from_z_and_v(ddim.schedule, zt, ts, model_output).to(dtype)
|
|
|
59 |
|
60 |
def run(ddim, sam_predictor, lr_image, hr_image, hr_mask, prompt = 'high resolution professional photo', noise_level=20,
|
61 |
blend_output = True, blend_trick = True, no_superres = False,
|
62 |
+
dt = 20, seed = 1, guidance_scale = 7.5, negative_prompt = '', use_sam_mask = False):
|
63 |
torch.manual_seed(seed)
|
64 |
+
dtype = ddim.vae.encoder.conv_in.weight.dtype
|
65 |
+
device = ddim.vae.encoder.conv_in.weight.device
|
66 |
|
67 |
router.attention_forward = attentionpatch.default.forward_xformers
|
68 |
router.basic_transformer_forward = transformerpatch.default.forward
|
|
|
76 |
hr_mask = hr_mask.padx(256, padding_mode='reflect').dilate(19)
|
77 |
hr_mask_orig = hr_mask
|
78 |
lr_image = lr_image.padx(64, padding_mode='reflect')
|
79 |
+
lr_mask = hr_mask.resize((lr_image.torch().shape[2], lr_image.torch().shape[3]), resample = Image.BICUBIC).alpha().torch(vmin=0).to(device)
|
80 |
lr_mask = TvF.gaussian_blur(lr_mask, kernel_size=19)
|
81 |
|
82 |
if no_superres:
|
|
|
91 |
|
92 |
# encode hr image
|
93 |
with torch.no_grad():
|
94 |
+
hr_z0 = ddim.vae.encode(hr_image.torch().cuda().to(dtype=dtype, device=device)).mean * ddim.config.scale_factor
|
95 |
|
96 |
assert hr_z0.shape[2] == lr_image.torch().shape[2]
|
97 |
assert hr_z0.shape[3] == lr_image.torch().shape[3]
|
98 |
|
99 |
+
unet_condition = lr_image.cuda().torch().to(memory_format=torch.contiguous_format).to(dtype=dtype, device=device)
|
100 |
+
zT = torch.randn((1,4,unet_condition.shape[2], unet_condition.shape[3])).cuda().to(dtype=dtype, device=device)
|
101 |
|
102 |
with torch.no_grad():
|
103 |
context = ddim.encoder.encode([negative_prompt, prompt])
|
104 |
|
105 |
+
noise_level = torch.Tensor(1 * [noise_level]).to(device=device).long()
|
106 |
unet_condition, noise_level = ddim.low_scale_model(unet_condition, noise_level=noise_level)
|
107 |
|
108 |
with torch.autocast('cuda'), torch.no_grad():
|
|
|
112 |
_zt = zt if unet_condition is None else torch.cat([zt, unet_condition], 1)
|
113 |
|
114 |
eps_uncond, eps = ddim.unet(
|
115 |
+
torch.cat([_zt, _zt]).to(dtype=dtype, device=device),
|
116 |
+
timesteps = torch.tensor([t, t]).to(device=device),
|
117 |
context = context,
|
118 |
y=torch.cat([noise_level]*2)
|
119 |
).chunk(2)
|
120 |
|
121 |
+
ts = torch.full((zt.shape[0],), t, device=device, dtype=torch.long)
|
122 |
model_output = (eps_uncond + guidance_scale * (eps - eps_uncond))
|
123 |
eps = predict_eps_from_z_and_v(ddim.schedule, zt, ts, model_output).to(dtype)
|
124 |
z0 = predict_start_from_z_and_v(ddim.schedule, zt, ts, model_output).to(dtype)
|
lib/models/ds_inp.py
CHANGED
@@ -15,7 +15,7 @@ DOWNLOAD_URL = 'https://civitai.com/api/download/models/131004'
|
|
15 |
download_file(DOWNLOAD_URL, MODEL_PATH)
|
16 |
|
17 |
|
18 |
-
def load_model():
|
19 |
print ("Loading model: Dreamshaper Inpainting V8")
|
20 |
|
21 |
download_file(DOWNLOAD_URL, MODEL_PATH)
|
@@ -36,10 +36,15 @@ def load_model():
|
|
36 |
encoder.load_state_dict(encoder_state)
|
37 |
vae.load_state_dict(vae_state)
|
38 |
|
|
|
|
|
|
|
|
|
|
|
39 |
unet = unet.requires_grad_(False)
|
40 |
encoder = encoder.requires_grad_(False)
|
41 |
vae = vae.requires_grad_(False)
|
42 |
-
|
43 |
ddim = DDIM(config, vae, encoder, unet)
|
44 |
share.schedule = scheduler.linear(config.timesteps, config.linear_start, config.linear_end)
|
45 |
|
|
|
15 |
download_file(DOWNLOAD_URL, MODEL_PATH)
|
16 |
|
17 |
|
18 |
+
def load_model(dtype=torch.float16):
|
19 |
print ("Loading model: Dreamshaper Inpainting V8")
|
20 |
|
21 |
download_file(DOWNLOAD_URL, MODEL_PATH)
|
|
|
36 |
encoder.load_state_dict(encoder_state)
|
37 |
vae.load_state_dict(vae_state)
|
38 |
|
39 |
+
if dtype == torch.float16:
|
40 |
+
unet.convert_to_fp16()
|
41 |
+
vae.to(dtype)
|
42 |
+
encoder.to(dtype)
|
43 |
+
|
44 |
unet = unet.requires_grad_(False)
|
45 |
encoder = encoder.requires_grad_(False)
|
46 |
vae = vae.requires_grad_(False)
|
47 |
+
|
48 |
ddim = DDIM(config, vae, encoder, unet)
|
49 |
share.schedule = scheduler.linear(config.timesteps, config.linear_start, config.linear_end)
|
50 |
|
lib/models/sam.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
from segment_anything import sam_model_registry, SamPredictor
|
2 |
from .common import *
|
3 |
|
@@ -8,12 +9,10 @@ DOWNLOAD_URL = 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939
|
|
8 |
download_file(DOWNLOAD_URL, MODEL_PATH)
|
9 |
|
10 |
|
11 |
-
def load_model():
|
12 |
print ("Loading model: SAM")
|
13 |
download_file(DOWNLOAD_URL, MODEL_PATH)
|
14 |
-
|
15 |
-
device = "cuda"
|
16 |
-
sam = sam_model_registry[model_type](checkpoint=MODEL_PATH)
|
17 |
sam.to(device=device)
|
18 |
sam_predictor = SamPredictor(sam)
|
19 |
print ("SAM loaded")
|
|
|
1 |
+
import torch
|
2 |
from segment_anything import sam_model_registry, SamPredictor
|
3 |
from .common import *
|
4 |
|
|
|
9 |
download_file(DOWNLOAD_URL, MODEL_PATH)
|
10 |
|
11 |
|
12 |
+
def load_model(device='cuda:0'):
|
13 |
print ("Loading model: SAM")
|
14 |
download_file(DOWNLOAD_URL, MODEL_PATH)
|
15 |
+
sam = sam_model_registry["vit_h"](checkpoint=MODEL_PATH)
|
|
|
|
|
16 |
sam.to(device=device)
|
17 |
sam_predictor = SamPredictor(sam)
|
18 |
print ("SAM loaded")
|
lib/models/sd15_inp.py
CHANGED
@@ -12,7 +12,7 @@ MODEL_PATH = f'{MODEL_FOLDER}/sd-1-5-inpainting/sd-v1-5-inpainting.ckpt'
|
|
12 |
download_file(DOWNLOAD_URL, MODEL_PATH)
|
13 |
|
14 |
|
15 |
-
def load_model():
|
16 |
download_file(DOWNLOAD_URL, MODEL_PATH)
|
17 |
|
18 |
state_dict = torch.load(MODEL_PATH)['state_dict']
|
@@ -34,6 +34,11 @@ def load_model():
|
|
34 |
encoder.load_state_dict(encoder_state)
|
35 |
vae.load_state_dict(vae_state)
|
36 |
|
|
|
|
|
|
|
|
|
|
|
37 |
unet = unet.requires_grad_(False)
|
38 |
encoder = encoder.requires_grad_(False)
|
39 |
vae = vae.requires_grad_(False)
|
|
|
12 |
download_file(DOWNLOAD_URL, MODEL_PATH)
|
13 |
|
14 |
|
15 |
+
def load_model(dtype=torch.float16):
|
16 |
download_file(DOWNLOAD_URL, MODEL_PATH)
|
17 |
|
18 |
state_dict = torch.load(MODEL_PATH)['state_dict']
|
|
|
34 |
encoder.load_state_dict(encoder_state)
|
35 |
vae.load_state_dict(vae_state)
|
36 |
|
37 |
+
if dtype == torch.float16:
|
38 |
+
unet.convert_to_fp16()
|
39 |
+
vae.to(dtype)
|
40 |
+
encoder.to(dtype)
|
41 |
+
|
42 |
unet = unet.requires_grad_(False)
|
43 |
encoder = encoder.requires_grad_(False)
|
44 |
vae = vae.requires_grad_(False)
|
lib/models/sd2_inp.py
CHANGED
@@ -13,7 +13,7 @@ DOWNLOAD_URL = 'https://huggingface.co/stabilityai/stable-diffusion-2-inpainting
|
|
13 |
download_file(DOWNLOAD_URL, MODEL_PATH)
|
14 |
|
15 |
|
16 |
-
def load_model():
|
17 |
print ("Loading model: Stable-Inpainting 2.0")
|
18 |
|
19 |
download_file(DOWNLOAD_URL, MODEL_PATH)
|
@@ -36,6 +36,13 @@ def load_model():
|
|
36 |
encoder.load_state_dict(encoder_state)
|
37 |
vae.load_state_dict(vae_state)
|
38 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
unet = unet.requires_grad_(False)
|
40 |
encoder = encoder.requires_grad_(False)
|
41 |
vae = vae.requires_grad_(False)
|
|
|
13 |
download_file(DOWNLOAD_URL, MODEL_PATH)
|
14 |
|
15 |
|
16 |
+
def load_model(dtype=torch.float16, device='cuda:0'):
|
17 |
print ("Loading model: Stable-Inpainting 2.0")
|
18 |
|
19 |
download_file(DOWNLOAD_URL, MODEL_PATH)
|
|
|
36 |
encoder.load_state_dict(encoder_state)
|
37 |
vae.load_state_dict(vae_state)
|
38 |
|
39 |
+
if dtype == torch.float16:
|
40 |
+
unet.convert_to_fp16()
|
41 |
+
unet.to(device=device)
|
42 |
+
vae.to(dtype=dtype, device=device)
|
43 |
+
encoder.to(dtype=dtype, device=device)
|
44 |
+
encoder.device = device
|
45 |
+
|
46 |
unet = unet.requires_grad_(False)
|
47 |
encoder = encoder.requires_grad_(False)
|
48 |
vae = vae.requires_grad_(False)
|
lib/models/sd2_sr.py
CHANGED
@@ -39,15 +39,15 @@ def extract_into_tensor(a, t, x_shape):
|
|
39 |
|
40 |
def predict_eps_from_z_and_v(schedule, x_t, t, v):
|
41 |
return (
|
42 |
-
extract_into_tensor(schedule.sqrt_alphas.
|
43 |
-
extract_into_tensor(schedule.sqrt_one_minus_alphas.
|
44 |
)
|
45 |
|
46 |
|
47 |
def predict_start_from_z_and_v(schedule, x_t, t, v):
|
48 |
return (
|
49 |
-
extract_into_tensor(schedule.sqrt_alphas.
|
50 |
-
extract_into_tensor(schedule.sqrt_one_minus_alphas.
|
51 |
)
|
52 |
|
53 |
|
@@ -153,7 +153,7 @@ def load_obj(path):
|
|
153 |
return get_obj_from_str(objyaml['__class__'])(**objyaml.get("__init__", {}))
|
154 |
|
155 |
|
156 |
-
def load_model(dtype=torch.bfloat16):
|
157 |
print ("Loading model: SD2 superresolution...")
|
158 |
|
159 |
download_file(DOWNLOAD_URL, MODEL_PATH)
|
@@ -180,9 +180,10 @@ def load_model(dtype=torch.bfloat16):
|
|
180 |
encoder = encoder.requires_grad_(False)
|
181 |
vae = vae.requires_grad_(False)
|
182 |
|
183 |
-
unet.to(dtype)
|
184 |
-
vae.to(dtype)
|
185 |
-
encoder.to(dtype)
|
|
|
186 |
|
187 |
ddim = DDIM(config, vae, encoder, unet)
|
188 |
|
@@ -199,6 +200,8 @@ def load_model(dtype=torch.bfloat16):
|
|
199 |
for param in low_scale_model.parameters():
|
200 |
param.requires_grad = False
|
201 |
|
|
|
|
|
202 |
ddim.low_scale_model = low_scale_model
|
203 |
print('SD2 superresolution loaded')
|
204 |
return ddim
|
|
|
39 |
|
40 |
def predict_eps_from_z_and_v(schedule, x_t, t, v):
|
41 |
return (
|
42 |
+
extract_into_tensor(schedule.sqrt_alphas.to(x_t.device), t, x_t.shape) * v +
|
43 |
+
extract_into_tensor(schedule.sqrt_one_minus_alphas.to(x_t.device), t, x_t.shape) * x_t
|
44 |
)
|
45 |
|
46 |
|
47 |
def predict_start_from_z_and_v(schedule, x_t, t, v):
|
48 |
return (
|
49 |
+
extract_into_tensor(schedule.sqrt_alphas.to(x_t.device), t, x_t.shape) * x_t -
|
50 |
+
extract_into_tensor(schedule.sqrt_one_minus_alphas.to(x_t.device), t, x_t.shape) * v
|
51 |
)
|
52 |
|
53 |
|
|
|
153 |
return get_obj_from_str(objyaml['__class__'])(**objyaml.get("__init__", {}))
|
154 |
|
155 |
|
156 |
+
def load_model(dtype=torch.bfloat16, device='cuda:0'):
|
157 |
print ("Loading model: SD2 superresolution...")
|
158 |
|
159 |
download_file(DOWNLOAD_URL, MODEL_PATH)
|
|
|
180 |
encoder = encoder.requires_grad_(False)
|
181 |
vae = vae.requires_grad_(False)
|
182 |
|
183 |
+
unet.to(dtype=dtype, device=device)
|
184 |
+
vae.to(dtype=dtype, device=device)
|
185 |
+
encoder.to(dtype=dtype, device=device)
|
186 |
+
encoder.device = device
|
187 |
|
188 |
ddim = DDIM(config, vae, encoder, unet)
|
189 |
|
|
|
200 |
for param in low_scale_model.parameters():
|
201 |
param.requires_grad = False
|
202 |
|
203 |
+
low_scale_model = low_scale_model.to(dtype=dtype, device=device)
|
204 |
+
|
205 |
ddim.low_scale_model = low_scale_model
|
206 |
print('SD2 superresolution loaded')
|
207 |
return ddim
|
lib/smplfusion/ddim.py
CHANGED
@@ -43,11 +43,13 @@ class DDIM:
|
|
43 |
|
44 |
def get_inpainting_condition(self, image, mask):
|
45 |
latent_size = [x//8 for x in image.size]
|
|
|
46 |
with torch.no_grad():
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
|
|
|
51 |
condition_x0 += 0.01 * condition_mask * torch.randn_like(condition_mask)
|
52 |
return torch.cat([condition_mask, condition_x0], 1)
|
53 |
|
|
|
43 |
|
44 |
def get_inpainting_condition(self, image, mask):
|
45 |
latent_size = [x//8 for x in image.size]
|
46 |
+
dtype = self.vae.encoder.conv_in.weight.dtype
|
47 |
with torch.no_grad():
|
48 |
+
masked_image = image.torch().cuda() * ~mask.torch(0).bool().cuda()
|
49 |
+
masked_image = masked_image.to(dtype)
|
50 |
+
condition_x0 = self.vae.encode(masked_image).mean * self.config.scale_factor
|
51 |
|
52 |
+
condition_mask = mask.resize(latent_size[::-1]).cuda().torch(0).bool().to(dtype)
|
53 |
condition_x0 += 0.01 * condition_mask * torch.randn_like(condition_mask)
|
54 |
return torch.cat([condition_mask, condition_x0], 1)
|
55 |
|
lib/smplfusion/models/unet.py
CHANGED
@@ -14,7 +14,14 @@ from ..modules.attention.spatial_transformer import SpatialTransformer
|
|
14 |
|
15 |
|
16 |
# dummy replace
|
17 |
-
def convert_module_to_f16(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
def convert_module_to_f32(x): pass
|
19 |
|
20 |
|
|
|
14 |
|
15 |
|
16 |
# dummy replace
|
17 |
+
def convert_module_to_f16(param):
|
18 |
+
"""
|
19 |
+
Convert primitive modules to float16.
|
20 |
+
"""
|
21 |
+
if isinstance(param, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
|
22 |
+
param.weight.data = param.weight.data.half()
|
23 |
+
if param.bias is not None:
|
24 |
+
param.bias.data = param.bias.data.half()
|
25 |
def convert_module_to_f32(x): pass
|
26 |
|
27 |
|
lib/utils/iimage.py
CHANGED
@@ -59,6 +59,10 @@ class IImage:
|
|
59 |
data = self.data.transpose(0, 3, 1, 2) / 255.
|
60 |
return vmin + torch.from_numpy(data).float().to(self.device) * (vmax - vmin)
|
61 |
|
|
|
|
|
|
|
|
|
62 |
def cuda(self):
|
63 |
self.device = 'cuda'
|
64 |
return self
|
|
|
59 |
data = self.data.transpose(0, 3, 1, 2) / 255.
|
60 |
return vmin + torch.from_numpy(data).float().to(self.device) * (vmax - vmin)
|
61 |
|
62 |
+
def to(self, device):
|
63 |
+
self.device = device
|
64 |
+
return self
|
65 |
+
|
66 |
def cuda(self):
|
67 |
self.device = 'cuda'
|
68 |
return self
|