doevent committed on
Commit d3ab762
1 Parent(s): ee506ee

Upload predict.py

Files changed (1)
  1. predict.py +104 -0
predict.py ADDED
@@ -0,0 +1,104 @@
+ # Prediction interface for Cog ⚙️
+ # https://github.com/replicate/cog/blob/main/docs/python.md
+
+ from cog import BasePredictor, Input, Path
+ import tempfile
+ import os, glob
+ import numpy as np
+ import cv2
+ from PIL import Image
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from models import model, basic
+ from utils import util
+
+ class Predictor(BasePredictor):
+     def setup(self):
+         """Load the model into memory to make running multiple predictions efficient"""
+         seed = 130
+         np.random.seed(seed)
+         torch.manual_seed(seed)
+         torch.cuda.manual_seed(seed)
+         self.colorizer = model.AnchorColorProb(inChannel=1, outChannel=313, enhanced=True)
+         self.colorizer = self.colorizer.cuda()
+         checkpt_path = "./checkpoints/disco-beta.pth.rar"
+         assert os.path.exists(checkpt_path)
+         data_dict = torch.load(checkpt_path, map_location=torch.device('cpu'))
+         self.colorizer.load_state_dict(data_dict['state_dict'])
+         self.colorizer.eval()
+         self.color_class = basic.ColorLabel(lambda_=0.5, device='cuda')
+
+     def resize_ab2l(self, gray_img, lab_imgs):
+         # Resize the predicted ab channels back to the original resolution and
+         # reattach them to the full-resolution L channel.
+         H, W = gray_img.shape[:2]
+         resized_ab = cv2.resize(lab_imgs[:, :, 1:], (W, H), interpolation=cv2.INTER_LINEAR)
+         return np.concatenate((gray_img, resized_ab), axis=2)
+
+     def predict(
+         self,
+         image: Path = Input(description="Input image. Output will be one or multiple colorized images."),
+         n_anchors: int = Input(description="Number of color anchors", ge=3, le=14, default=8),
+         multi_result: bool = Input(description="Generate diverse results", default=False),
+         vis_anchors: bool = Input(description="Visualize the anchor locations", default=False),
+     ) -> Path:
+         """Run a single prediction on the model"""
+         # Read the image, convert it to normalized Lab, and keep the full-resolution
+         # L channel so the predicted ab channels can be resized back onto it.
+         bgr_img = cv2.imread(str(image), cv2.IMREAD_COLOR)
+         rgb_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)
+         rgb_img = np.array(rgb_img / 255., np.float32)
+         lab_img = cv2.cvtColor(rgb_img, cv2.COLOR_RGB2LAB)
+         org_grays = (lab_img[:, :, [0]] - 50.) / 50.
+         lab_img = cv2.resize(lab_img, (256, 256), interpolation=cv2.INTER_LINEAR)
+
+         lab_img = torch.from_numpy(lab_img.transpose((2, 0, 1)))
+         gray_img = (lab_img[0:1, :, :] - 50.) / 50.
+         ab_chans = lab_img[1:3, :, :] / 110.
+         input_grays = gray_img.unsqueeze(0)
+         input_colors = ab_chans.unsqueeze(0)
+         input_grays = input_grays.cuda(non_blocking=True)
+         input_colors = input_colors.cuda(non_blocking=True)
+
+         # Run the anchor-based colorization network; sampled_T > 0 draws
+         # multiple diverse colorizations.
+         sampled_T = 2 if multi_result else 0
+         pal_logit, ref_logit, enhanced_ab, affinity_map, spix_colors, hint_mask = self.colorizer(
+             input_grays, input_colors, n_anchors, True, sampled_T)
+         pred_probs = pal_logit
+         guided_colors = self.color_class.decode_ind2ab(ref_logit, T=0)
+         sp_size = 16
+         guided_colors = basic.upfeat(guided_colors, affinity_map, sp_size, sp_size)
+         res_list = []
+         if multi_result:
+             for no in range(3):
+                 pred_labs = torch.cat((input_grays, enhanced_ab[no:no+1, :, :, :]), dim=1)
+                 lab_imgs = basic.tensor2array(pred_labs).squeeze(axis=0)
+                 lab_imgs = self.resize_ab2l(org_grays, lab_imgs)
+                 res_list.append(lab_imgs)
+         else:
+             pred_labs = torch.cat((input_grays, enhanced_ab), dim=1)
+             lab_imgs = basic.tensor2array(pred_labs).squeeze(axis=0)
+             lab_imgs = self.resize_ab2l(org_grays, lab_imgs)
+             res_list.append(lab_imgs)
+
+         if vis_anchors:
+             # Visualize the anchor locations on top of the colorized result.
+             anchor_masks = basic.upfeat(hint_mask, affinity_map, sp_size, sp_size)
+             marked_labs = basic.mark_color_hints(input_grays, enhanced_ab, anchor_masks, base_ABs=enhanced_ab)
+             hint_imgs = basic.tensor2array(marked_labs).squeeze(axis=0)
+             hint_imgs = self.resize_ab2l(org_grays, hint_imgs)
+             res_list.append(hint_imgs)
+
+         # Stack all results vertically, de-normalize Lab, and save as a PNG.
+         output = cv2.vconcat(res_list)
+         output[:, :, 0] = output[:, :, 0] * 50.0 + 50.0
+         output[:, :, 1:3] = output[:, :, 1:3] * 110.0
+         bgr_output = cv2.cvtColor(output, cv2.COLOR_LAB2BGR)
+         out_path = Path(tempfile.mkdtemp()) / "out.png"
+         cv2.imwrite(str(out_path), (bgr_output * 255.0).astype(np.uint8))
+         return out_path
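
For local testing, a minimal sketch of how this predictor would be invoked, assuming the repository also ships a cog.yaml pointing at this Predictor, a CUDA GPU is available, and the checkpoint sits at ./checkpoints/disco-beta.pth.rar (the input filename below is only a placeholder):

    cog predict -i image=@input.jpg -i n_anchors=8

The colorized result is written to out.png in a temporary directory and returned as the prediction output; the multi_result and vis_anchors options can be passed the same way with additional -i flags.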