Update lama_inpaint.py

lama_inpaint.py  (+27 -15)  CHANGED
@@ -5,7 +5,6 @@ import torch
 import yaml
 import glob
 import argparse
-from PIL import Image
 from omegaconf import OmegaConf
 from pathlib import Path

@@ -20,6 +19,7 @@ sys.path.insert(0, str(Path(__file__).resolve().parent / "lama"))
 from saicinpainting.evaluation.utils import move_to_device
 from saicinpainting.training.trainers import load_checkpoint
 from saicinpainting.evaluation.data import pad_tensor_to_modulo
+from saicinpainting.evaluation.refinement import refine_predict

 from utils import load_img_to_array, save_array_to_img

@@ -53,8 +53,7 @@ def inpaint_img_with_lama(
         train_config, checkpoint_path, strict=False, map_location=device
     )
     model.freeze()
-
-    model.to(device)
+    model.to(device)

     batch = {}
     batch["image"] = img.permute(2, 0, 1).unsqueeze(0)
@@ -62,16 +61,30 @@ def inpaint_img_with_lama(
     unpad_to_size = [batch["image"].shape[2], batch["image"].shape[3]]
     batch["image"] = pad_tensor_to_modulo(batch["image"], mod)
     batch["mask"] = pad_tensor_to_modulo(batch["mask"], mod)
-    batch = move_to_device(batch, device)
-    batch["mask"] = (batch["mask"] > 0) * 1
-
-    batch = model(batch)
-    cur_res = batch[predict_config.out_key][0].permute(1, 2, 0)
-    cur_res = cur_res.detach().cpu().numpy()
-
-    if unpad_to_size is not None:
-        orig_height, orig_width = unpad_to_size
-        cur_res = cur_res[:orig_height, :orig_width]
+    # batch = move_to_device(batch, device)
+    # batch["mask"] = (batch["mask"] > 0) * 1
+
+    # batch = model(batch)
+    # cur_res = batch[predict_config.out_key][0].permute(1, 2, 0)
+    # cur_res = cur_res.detach().cpu().numpy()
+    if predict_config.get("refine", False):
+        batch["unpad_to_size"] = [torch.tensor([size]) for size in unpad_to_size]
+        cur_res = refine_predict(batch, model, **predict_config.refiner)
+        cur_res = cur_res[0].permute(1, 2, 0).detach().cpu().numpy()
+    else:
+        batch = move_to_device(batch, device)
+        batch["mask"] = (batch["mask"] > 0) * 1
+        batch = model(batch)
+        cur_res = batch[predict_config.out_key][0].permute(1, 2, 0)
+        cur_res = cur_res.detach().cpu().numpy()
+
+    if unpad_to_size is not None:
+        orig_height, orig_width = unpad_to_size
+        cur_res = cur_res[:orig_height, :orig_width]
+
+    # if unpad_to_size is not None:
+    #     orig_height, orig_width = unpad_to_size
+    #     cur_res = cur_res[:orig_height, :orig_width]

     cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
     return cur_res
@@ -98,8 +111,7 @@ def build_lama_model(config_p: str, ckpt_p: str, device="cuda"):
         train_config, checkpoint_path, strict=False, map_location=device
     )
     model.freeze()
-
-    model.to(device)
+    model.to(device)

     return model
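For reference, a minimal usage sketch of the refinement toggle added in this commit. The signature of inpaint_img_with_lama, the file paths, and the exact config layout are assumptions inferred from the surrounding code, not something this diff pins down:

# Sketch only: the argument names, file paths, and YAML keys below are
# assumptions inferred from the surrounding code, not fixed by this commit.
from lama_inpaint import inpaint_img_with_lama
from utils import load_img_to_array, save_array_to_img

img = load_img_to_array("./example/input.png")    # H x W x 3 image (hypothetical path)
mask = load_img_to_array("./example/mask.png")    # H x W mask, non-zero where inpainting is wanted

# The new branch is driven by the prediction config: the YAML passed as
# config_p needs `refine: true` plus a `refiner:` section whose entries are
# forwarded as keyword arguments to refine_predict(); without them the
# original single-forward-pass path runs unchanged.
result = inpaint_img_with_lama(
    img,
    mask,
    config_p="./lama/configs/prediction/default.yaml",   # assumed config path
    ckpt_p="./pretrained_models/big-lama",               # assumed checkpoint path
    device="cuda",
)
save_array_to_img(result, "./example/inpainted.png")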
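Both branches rely on the same pad-then-crop contract: the image and mask are padded so height and width become multiples of mod before the forward pass, and unpad_to_size records the original size so the result can be cropped back. The helper below is an illustrative stand-in for that contract only; it is not saicinpainting's pad_tensor_to_modulo:

# Illustrative stand-in for the shape contract; saicinpainting's
# pad_tensor_to_modulo may use a different padding mode.
import torch
import torch.nn.functional as F

def pad_to_modulo_sketch(t: torch.Tensor, mod: int) -> torch.Tensor:
    """Pad a (B, C, H, W) tensor on the bottom/right so H and W are multiples of mod."""
    _, _, h, w = t.shape
    pad_h = (mod - h % mod) % mod
    pad_w = (mod - w % mod) % mod
    return F.pad(t, (0, pad_w, 0, pad_h))

img = torch.rand(1, 3, 250, 333)
unpad_to_size = [img.shape[2], img.shape[3]]   # original H, W, as in the diff
padded = pad_to_modulo_sketch(img, mod=8)
print(padded.shape)                            # torch.Size([1, 3, 256, 336])

# After the forward pass, the result is cropped back to the original size,
# which is what the `if unpad_to_size is not None:` block in the diff does.
orig_height, orig_width = unpad_to_size
restored = padded[..., :orig_height, :orig_width]
assert list(restored.shape[2:]) == unpad_to_size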