Spaces:
Running
Running
Upload inference.py
Browse files- inference.py +105 -0
inference.py
ADDED
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, glob, sys, logging
|
2 |
+
import argparse, datetime, time
|
3 |
+
import numpy as np
|
4 |
+
import cv2
|
5 |
+
from PIL import Image
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from models import model, basic
|
10 |
+
from utils import util
|
11 |
+
|
12 |
+
|
13 |
+
def setup_model(checkpt_path, device="cuda"):
|
14 |
+
#print('--------------', torch.cuda.is_available())
|
15 |
+
"""Load the model into memory to make running multiple predictions efficient"""
|
16 |
+
colorLabeler = basic.ColorLabel(device=device)
|
17 |
+
colorizer = model.AnchorColorProb(inChannel=1, outChannel=313, enhanced=True, colorLabeler=colorLabeler)
|
18 |
+
colorizer = colorizer.to(device)
|
19 |
+
#checkpt_path = "./checkpoints/disco-beta.pth.rar"
|
20 |
+
assert os.path.exists(checkpt_path), "No checkpoint found!"
|
21 |
+
data_dict = torch.load(checkpt_path, map_location=torch.device('cpu'))
|
22 |
+
colorizer.load_state_dict(data_dict['state_dict'])
|
23 |
+
colorizer.eval()
|
24 |
+
return colorizer, colorLabeler
|
25 |
+
|
26 |
+
|
27 |
+
def resize_ab2l(gray_img, lab_imgs, vis=False):
|
28 |
+
H, W = gray_img.shape[:2]
|
29 |
+
reszied_ab = cv2.resize(lab_imgs[:,:,1:], (W,H), interpolation=cv2.INTER_LINEAR)
|
30 |
+
if vis:
|
31 |
+
gray_img = cv2.resize(lab_imgs[:,:,:1], (W,H), interpolation=cv2.INTER_LINEAR)
|
32 |
+
return np.concatenate((gray_img[:,:,np.newaxis], reszied_ab), axis=2)
|
33 |
+
else:
|
34 |
+
return np.concatenate((gray_img, reszied_ab), axis=2)
|
35 |
+
|
36 |
+
def prepare_data(rgb_img, target_res):
|
37 |
+
rgb_img = np.array(rgb_img / 255., np.float32)
|
38 |
+
lab_img = cv2.cvtColor(rgb_img, cv2.COLOR_RGB2LAB)
|
39 |
+
org_grays = (lab_img[:,:,[0]]-50.) / 50.
|
40 |
+
lab_img = cv2.resize(lab_img, target_res, interpolation=cv2.INTER_LINEAR)
|
41 |
+
|
42 |
+
lab_img = torch.from_numpy(lab_img.transpose((2, 0, 1)))
|
43 |
+
gray_img = (lab_img[0:1,:,:]-50.) / 50.
|
44 |
+
ab_chans = lab_img[1:3,:,:] / 110.
|
45 |
+
input_grays = gray_img.unsqueeze(0)
|
46 |
+
input_colors = ab_chans.unsqueeze(0)
|
47 |
+
return input_grays, input_colors, org_grays
|
48 |
+
|
49 |
+
|
50 |
+
def colorize_grayscale(colorizer, color_class, rgb_img, hint_img, n_anchors, is_high_res, is_editable, device="cuda"):
|
51 |
+
n_anchors = int(n_anchors)
|
52 |
+
n_anchors = max(n_anchors, 3)
|
53 |
+
n_anchors = min(n_anchors, 14)
|
54 |
+
target_res = (512,512) if is_high_res else (256,256)
|
55 |
+
input_grays, input_colors, org_grays = prepare_data(rgb_img, target_res)
|
56 |
+
input_grays = input_grays.to(device)
|
57 |
+
input_colors = input_colors.to(device)
|
58 |
+
|
59 |
+
if is_editable:
|
60 |
+
print('>>>:editable mode')
|
61 |
+
sampled_T = -1
|
62 |
+
_, input_colors, _ = prepare_data(hint_img, target_res)
|
63 |
+
input_colors = input_colors.to(device)
|
64 |
+
pal_logit, ref_logit, enhanced_ab, affinity_map, spix_colors, hint_mask = colorizer(input_grays, \
|
65 |
+
input_colors, n_anchors, sampled_T)
|
66 |
+
else:
|
67 |
+
print('>>>:automatic mode')
|
68 |
+
sampled_T = 0
|
69 |
+
pal_logit, ref_logit, enhanced_ab, affinity_map, spix_colors, hint_mask = colorizer(input_grays, \
|
70 |
+
input_colors, n_anchors, sampled_T)
|
71 |
+
|
72 |
+
pred_labs = torch.cat((input_grays,enhanced_ab), dim=1)
|
73 |
+
lab_imgs = basic.tensor2array(pred_labs).squeeze(axis=0)
|
74 |
+
lab_imgs = resize_ab2l(org_grays, lab_imgs)
|
75 |
+
|
76 |
+
lab_imgs[:,:,0] = lab_imgs[:,:,0] * 50.0 + 50.0
|
77 |
+
lab_imgs[:,:,1:3] = lab_imgs[:,:,1:3] * 110.0
|
78 |
+
rgb_output = cv2.cvtColor(lab_imgs[:,:,:], cv2.COLOR_LAB2RGB)
|
79 |
+
return (rgb_output*255.0).astype(np.uint8)
|
80 |
+
|
81 |
+
|
82 |
+
def predict_anchors(colorizer, color_class, rgb_img, n_anchors, is_high_res, is_editable, device="cuda"):
|
83 |
+
n_anchors = int(n_anchors)
|
84 |
+
n_anchors = max(n_anchors, 3)
|
85 |
+
n_anchors = min(n_anchors, 14)
|
86 |
+
target_res = (512,512) if is_high_res else (256,256)
|
87 |
+
input_grays, input_colors, org_grays = prepare_data(rgb_img, target_res)
|
88 |
+
input_grays = input_grays.to(device)
|
89 |
+
input_colors = input_colors.to(device)
|
90 |
+
|
91 |
+
sampled_T, sp_size = 0, 16
|
92 |
+
pal_logit, ref_logit, enhanced_ab, affinity_map, spix_colors, hint_mask = colorizer(input_grays, \
|
93 |
+
input_colors, n_anchors, sampled_T)
|
94 |
+
pred_probs = pal_logit
|
95 |
+
guided_colors = color_class.decode_ind2ab(ref_logit, T=0)
|
96 |
+
guided_colors = basic.upfeat(guided_colors, affinity_map, sp_size, sp_size)
|
97 |
+
anchor_masks = basic.upfeat(hint_mask, affinity_map, sp_size, sp_size)
|
98 |
+
marked_labs = basic.mark_color_hints(input_grays, guided_colors, anchor_masks, base_ABs=None)
|
99 |
+
lab_imgs = basic.tensor2array(marked_labs).squeeze(axis=0)
|
100 |
+
lab_imgs = resize_ab2l(org_grays, lab_imgs, vis=True)
|
101 |
+
|
102 |
+
lab_imgs[:,:,0] = lab_imgs[:,:,0] * 50.0 + 50.0
|
103 |
+
lab_imgs[:,:,1:3] = lab_imgs[:,:,1:3] * 110.0
|
104 |
+
rgb_output = cv2.cvtColor(lab_imgs[:,:,:], cv2.COLOR_LAB2RGB)
|
105 |
+
return (rgb_output*255.0).astype(np.uint8)
|