doevent commited on
Commit
ee506ee
1 Parent(s): 495fea9

Upload inference.py

Browse files
Files changed (1) hide show
  1. inference.py +105 -0
inference.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, glob, sys, logging
2
+ import argparse, datetime, time
3
+ import numpy as np
4
+ import cv2
5
+ from PIL import Image
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from models import model, basic
10
+ from utils import util
11
+
12
+
13
+ def setup_model(checkpt_path, device="cuda"):
14
+ #print('--------------', torch.cuda.is_available())
15
+ """Load the model into memory to make running multiple predictions efficient"""
16
+ colorLabeler = basic.ColorLabel(device=device)
17
+ colorizer = model.AnchorColorProb(inChannel=1, outChannel=313, enhanced=True, colorLabeler=colorLabeler)
18
+ colorizer = colorizer.to(device)
19
+ #checkpt_path = "./checkpoints/disco-beta.pth.rar"
20
+ assert os.path.exists(checkpt_path), "No checkpoint found!"
21
+ data_dict = torch.load(checkpt_path, map_location=torch.device('cpu'))
22
+ colorizer.load_state_dict(data_dict['state_dict'])
23
+ colorizer.eval()
24
+ return colorizer, colorLabeler
25
+
26
+
27
+ def resize_ab2l(gray_img, lab_imgs, vis=False):
28
+ H, W = gray_img.shape[:2]
29
+ reszied_ab = cv2.resize(lab_imgs[:,:,1:], (W,H), interpolation=cv2.INTER_LINEAR)
30
+ if vis:
31
+ gray_img = cv2.resize(lab_imgs[:,:,:1], (W,H), interpolation=cv2.INTER_LINEAR)
32
+ return np.concatenate((gray_img[:,:,np.newaxis], reszied_ab), axis=2)
33
+ else:
34
+ return np.concatenate((gray_img, reszied_ab), axis=2)
35
+
36
+ def prepare_data(rgb_img, target_res):
37
+ rgb_img = np.array(rgb_img / 255., np.float32)
38
+ lab_img = cv2.cvtColor(rgb_img, cv2.COLOR_RGB2LAB)
39
+ org_grays = (lab_img[:,:,[0]]-50.) / 50.
40
+ lab_img = cv2.resize(lab_img, target_res, interpolation=cv2.INTER_LINEAR)
41
+
42
+ lab_img = torch.from_numpy(lab_img.transpose((2, 0, 1)))
43
+ gray_img = (lab_img[0:1,:,:]-50.) / 50.
44
+ ab_chans = lab_img[1:3,:,:] / 110.
45
+ input_grays = gray_img.unsqueeze(0)
46
+ input_colors = ab_chans.unsqueeze(0)
47
+ return input_grays, input_colors, org_grays
48
+
49
+
50
+ def colorize_grayscale(colorizer, color_class, rgb_img, hint_img, n_anchors, is_high_res, is_editable, device="cuda"):
51
+ n_anchors = int(n_anchors)
52
+ n_anchors = max(n_anchors, 3)
53
+ n_anchors = min(n_anchors, 14)
54
+ target_res = (512,512) if is_high_res else (256,256)
55
+ input_grays, input_colors, org_grays = prepare_data(rgb_img, target_res)
56
+ input_grays = input_grays.to(device)
57
+ input_colors = input_colors.to(device)
58
+
59
+ if is_editable:
60
+ print('>>>:editable mode')
61
+ sampled_T = -1
62
+ _, input_colors, _ = prepare_data(hint_img, target_res)
63
+ input_colors = input_colors.to(device)
64
+ pal_logit, ref_logit, enhanced_ab, affinity_map, spix_colors, hint_mask = colorizer(input_grays, \
65
+ input_colors, n_anchors, sampled_T)
66
+ else:
67
+ print('>>>:automatic mode')
68
+ sampled_T = 0
69
+ pal_logit, ref_logit, enhanced_ab, affinity_map, spix_colors, hint_mask = colorizer(input_grays, \
70
+ input_colors, n_anchors, sampled_T)
71
+
72
+ pred_labs = torch.cat((input_grays,enhanced_ab), dim=1)
73
+ lab_imgs = basic.tensor2array(pred_labs).squeeze(axis=0)
74
+ lab_imgs = resize_ab2l(org_grays, lab_imgs)
75
+
76
+ lab_imgs[:,:,0] = lab_imgs[:,:,0] * 50.0 + 50.0
77
+ lab_imgs[:,:,1:3] = lab_imgs[:,:,1:3] * 110.0
78
+ rgb_output = cv2.cvtColor(lab_imgs[:,:,:], cv2.COLOR_LAB2RGB)
79
+ return (rgb_output*255.0).astype(np.uint8)
80
+
81
+
82
+ def predict_anchors(colorizer, color_class, rgb_img, n_anchors, is_high_res, is_editable, device="cuda"):
83
+ n_anchors = int(n_anchors)
84
+ n_anchors = max(n_anchors, 3)
85
+ n_anchors = min(n_anchors, 14)
86
+ target_res = (512,512) if is_high_res else (256,256)
87
+ input_grays, input_colors, org_grays = prepare_data(rgb_img, target_res)
88
+ input_grays = input_grays.to(device)
89
+ input_colors = input_colors.to(device)
90
+
91
+ sampled_T, sp_size = 0, 16
92
+ pal_logit, ref_logit, enhanced_ab, affinity_map, spix_colors, hint_mask = colorizer(input_grays, \
93
+ input_colors, n_anchors, sampled_T)
94
+ pred_probs = pal_logit
95
+ guided_colors = color_class.decode_ind2ab(ref_logit, T=0)
96
+ guided_colors = basic.upfeat(guided_colors, affinity_map, sp_size, sp_size)
97
+ anchor_masks = basic.upfeat(hint_mask, affinity_map, sp_size, sp_size)
98
+ marked_labs = basic.mark_color_hints(input_grays, guided_colors, anchor_masks, base_ABs=None)
99
+ lab_imgs = basic.tensor2array(marked_labs).squeeze(axis=0)
100
+ lab_imgs = resize_ab2l(org_grays, lab_imgs, vis=True)
101
+
102
+ lab_imgs[:,:,0] = lab_imgs[:,:,0] * 50.0 + 50.0
103
+ lab_imgs[:,:,1:3] = lab_imgs[:,:,1:3] * 110.0
104
+ rgb_output = cv2.cvtColor(lab_imgs[:,:,:], cv2.COLOR_LAB2RGB)
105
+ return (rgb_output*255.0).astype(np.uint8)